├── .gitignore ├── .swift-format ├── LICENSE.md ├── Package.resolved ├── Package.swift ├── README.md ├── Sources ├── Embeddings │ ├── Bert │ │ ├── BertModel.swift │ │ └── BertUtils.swift │ ├── Clip │ │ ├── ClipModel.swift │ │ └── ClipUtils.swift │ ├── EmbeddingsUtils.swift │ ├── Model2Vec │ │ ├── Model2VecModel.swift │ │ └── Model2VecUtils.swift │ ├── Roberta │ │ ├── RobertaModel.swift │ │ └── RobertaUtils.swift │ ├── StaticEmbeddings │ │ ├── StaticEmbeddingsModel.swift │ │ └── StaticEmbeddingsUtils.swift │ ├── Tokenizer │ │ ├── ClipTokenizer.swift │ │ ├── TextTokenizer.swift │ │ └── XLMRobetaTokenizer.swift │ ├── Word2Vec │ │ ├── Word2VecModel.swift │ │ └── Word2VecUtils.swift │ └── XLMRoberta │ │ ├── XLMRobertaModel.swift │ │ └── XLMRobertaUtils.swift ├── EmbeddingsCLI │ ├── Commands │ │ ├── BertCommand.swift │ │ ├── ClipCommand.swift │ │ ├── Model2VecCommand.swift │ │ ├── RobertaCommand.swift │ │ ├── StaticEmbeddingsCommad.swift │ │ ├── Word2VecCommand.swift │ │ └── XLMRobertaCommand.swift │ └── EmbeddingsCLI.swift ├── MLTensorUtils │ ├── Activations.swift │ ├── Functions.swift │ └── Layers.swift └── TestingUtils │ └── TestingUtils.swift └── Tests ├── AccuracyTests ├── AccuracyTests.swift └── Scripts │ └── generate.py ├── EmbeddingsTests ├── BertTests.swift ├── Model2VecTests.swift ├── Resources │ ├── merges.txt │ └── vocab.json ├── StaticEmbeddingsTests.swift ├── TokenizerTests.swift ├── UtilsTests.swift └── Word2VecTests.swift └── MLTensorUtilsTests ├── ActivationTests.swift ├── FunctionTests.swift └── LayerTests.swift /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | /.build 3 | /Packages 4 | xcuserdata/ 5 | DerivedData/ 6 | .swiftpm/configuration/registries.json 7 | .swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata 8 | .netrc 9 | .swiftpm/ 10 | .vscode/ 11 | .benchmarkBaselines/ 12 | -------------------------------------------------------------------------------- /.swift-format: 
-------------------------------------------------------------------------------- 1 | { 2 | "fileScopedDeclarationPrivacy" : { 3 | "accessLevel" : "private" 4 | }, 5 | "indentation" : { 6 | "spaces" : 4 7 | }, 8 | "indentConditionalCompilationBlocks" : true, 9 | "indentSwitchCaseLabels" : false, 10 | "lineBreakAroundMultilineExpressionChainComponents" : false, 11 | "lineBreakBeforeControlFlowKeywords" : false, 12 | "lineBreakBeforeEachArgument" : false, 13 | "lineBreakBeforeEachGenericRequirement" : false, 14 | "lineLength" : 100, 15 | "maximumBlankLines" : 1, 16 | "multiElementCollectionTrailingCommas" : true, 17 | "noAssignmentInExpressions" : { 18 | "allowedFunctions" : [ 19 | "XCTAssertNoThrow" 20 | ] 21 | }, 22 | "prioritizeKeepingFunctionOutputTogether" : false, 23 | "respectsExistingLineBreaks" : true, 24 | "rules" : { 25 | "AllPublicDeclarationsHaveDocumentation" : false, 26 | "AlwaysUseLiteralForEmptyCollectionInit" : false, 27 | "AlwaysUseLowerCamelCase" : true, 28 | "AmbiguousTrailingClosureOverload" : true, 29 | "BeginDocumentationCommentWithOneLineSummary" : false, 30 | "DoNotUseSemicolons" : true, 31 | "DontRepeatTypeInStaticProperties" : true, 32 | "FileScopedDeclarationPrivacy" : true, 33 | "FullyIndirectEnum" : true, 34 | "GroupNumericLiterals" : true, 35 | "IdentifiersMustBeASCII" : true, 36 | "NeverForceUnwrap" : false, 37 | "NeverUseForceTry" : false, 38 | "NeverUseImplicitlyUnwrappedOptionals" : false, 39 | "NoAccessLevelOnExtensionDeclaration" : true, 40 | "NoAssignmentInExpressions" : true, 41 | "NoBlockComments" : true, 42 | "NoCasesWithOnlyFallthrough" : true, 43 | "NoEmptyTrailingClosureParentheses" : true, 44 | "NoLabelsInCasePatterns" : true, 45 | "NoLeadingUnderscores" : false, 46 | "NoParensAroundConditions" : true, 47 | "NoPlaygroundLiterals" : true, 48 | "NoVoidReturnOnFunctionSignature" : true, 49 | "OmitExplicitReturns" : false, 50 | "OneCasePerLine" : true, 51 | "OneVariableDeclarationPerLine" : true, 52 | 
"OnlyOneTrailingClosureArgument" : true, 53 | "OrderedImports" : true, 54 | "ReplaceForEachWithForLoop" : true, 55 | "ReturnVoidInsteadOfEmptyTuple" : true, 56 | "TypeNamesShouldBeCapitalized" : true, 57 | "UseEarlyExits" : false, 58 | "UseExplicitNilCheckInConditions" : true, 59 | "UseLetInEveryBoundCaseVariable" : true, 60 | "UseShorthandTypeNames" : true, 61 | "UseSingleLinePropertyGetter" : true, 62 | "UseSynthesizedInitializer" : true, 63 | "UseTripleSlashForDocumentationComments" : true, 64 | "UseWhereClausesInForLoops" : false, 65 | "ValidateDocumentationComments" : false 66 | }, 67 | "spacesAroundRangeFormationOperators" : false, 68 | "tabWidth" : 8, 69 | "version" : 1 70 | } 71 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jan Krukowski 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Package.resolved: -------------------------------------------------------------------------------- 1 | { 2 | "originHash" : "805e5db73cce11acc62b4d7eb123a118645408e88cbc6cf52771e21fb7d84ae3", 3 | "pins" : [ 4 | { 5 | "identity" : "command", 6 | "kind" : "remoteSourceControl", 7 | "location" : "https://github.com/tuist/Command.git", 8 | "state" : { 9 | "revision" : "079a7803b581d3022469b3a331bccd51d48d2fc0", 10 | "version" : "0.13.0" 11 | } 12 | }, 13 | { 14 | "identity" : "jinja", 15 | "kind" : "remoteSourceControl", 16 | "location" : "https://github.com/johnmai-dev/Jinja", 17 | "state" : { 18 | "revision" : "bbddb92fc51ae420b87300298370fd1dfc308f73", 19 | "version" : "1.1.1" 20 | } 21 | }, 22 | { 23 | "identity" : "mockable", 24 | "kind" : "remoteSourceControl", 25 | "location" : "https://github.com/Kolos65/Mockable", 26 | "state" : { 27 | "revision" : "118a0b8934e585b80952586db30bcb72aef45a74", 28 | "version" : "0.3.2" 29 | } 30 | }, 31 | { 32 | "identity" : "path", 33 | "kind" : "remoteSourceControl", 34 | "location" : "https://github.com/tuist/Path", 35 | "state" : { 36 | "revision" : "7c74ac435e03a927c3a73134c48b61e60221abcb", 37 | "version" : "0.3.8" 38 | } 39 | }, 40 | { 41 | "identity" : "swift-argument-parser", 42 | "kind" : "remoteSourceControl", 43 | "location" : "https://github.com/apple/swift-argument-parser.git", 44 | "state" : { 45 | "revision" : "0fbc8848e389af3bb55c182bc19ca9d5dc2f255b", 46 | "version" : "1.4.0" 47 | } 48 | }, 49 | { 50 | "identity" : "swift-collections", 51 | "kind" : "remoteSourceControl", 52 | "location" : "https://github.com/apple/swift-collections.git", 53 | "state" : 
{ 54 | "revision" : "671108c96644956dddcd89dd59c203dcdb36cec7", 55 | "version" : "1.1.4" 56 | } 57 | }, 58 | { 59 | "identity" : "swift-log", 60 | "kind" : "remoteSourceControl", 61 | "location" : "https://github.com/apple/swift-log", 62 | "state" : { 63 | "revision" : "96a2f8a0fa41e9e09af4585e2724c4e825410b91", 64 | "version" : "1.6.2" 65 | } 66 | }, 67 | { 68 | "identity" : "swift-numerics", 69 | "kind" : "remoteSourceControl", 70 | "location" : "https://github.com/apple/swift-numerics", 71 | "state" : { 72 | "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b", 73 | "version" : "1.0.2" 74 | } 75 | }, 76 | { 77 | "identity" : "swift-safetensors", 78 | "kind" : "remoteSourceControl", 79 | "location" : "https://github.com/jkrukowski/swift-safetensors", 80 | "state" : { 81 | "revision" : "718b0f38f912e0bf9d92130fa1e1fe2ae5136dd6", 82 | "version" : "0.0.7" 83 | } 84 | }, 85 | { 86 | "identity" : "swift-sentencepiece", 87 | "kind" : "remoteSourceControl", 88 | "location" : "https://github.com/jkrukowski/swift-sentencepiece", 89 | "state" : { 90 | "revision" : "36a8b2b45733f6adb3092100f16e4c7d38a10a7c", 91 | "version" : "0.0.6" 92 | } 93 | }, 94 | { 95 | "identity" : "swift-syntax", 96 | "kind" : "remoteSourceControl", 97 | "location" : "https://github.com/swiftlang/swift-syntax", 98 | "state" : { 99 | "revision" : "0687f71944021d616d34d922343dcef086855920", 100 | "version" : "600.0.1" 101 | } 102 | }, 103 | { 104 | "identity" : "swift-transformers", 105 | "kind" : "remoteSourceControl", 106 | "location" : "https://github.com/huggingface/swift-transformers.git", 107 | "state" : { 108 | "revision" : "c2f302a74cca59cbde683b1425ab43c05685515a", 109 | "version" : "0.1.21" 110 | } 111 | }, 112 | { 113 | "identity" : "xctest-dynamic-overlay", 114 | "kind" : "remoteSourceControl", 115 | "location" : "https://github.com/pointfreeco/xctest-dynamic-overlay", 116 | "state" : { 117 | "revision" : "a3f634d1a409c7979cabc0a71b3f26ffa9fc8af1", 118 | "version" : "1.4.3" 119 | } 120 
| } 121 | ], 122 | "version" : 3 123 | } 124 | -------------------------------------------------------------------------------- /Package.swift: -------------------------------------------------------------------------------- 1 | // swift-tools-version: 6.0 2 | 3 | import PackageDescription 4 | 5 | let package = Package( 6 | name: "swift-embeddings", 7 | platforms: [ 8 | .macOS(.v14), 9 | .iOS(.v17), 10 | .tvOS(.v17), 11 | .visionOS(.v1), 12 | .watchOS(.v10), 13 | ], 14 | products: [ 15 | .executable( 16 | name: "embeddings-cli", 17 | targets: ["EmbeddingsCLI"] 18 | ), 19 | .library( 20 | name: "Embeddings", 21 | targets: ["Embeddings"]), 22 | .library( 23 | name: "MLTensorUtils", 24 | targets: ["MLTensorUtils"]), 25 | ], 26 | dependencies: [ 27 | .package( 28 | url: "https://github.com/apple/swift-numerics.git", 29 | from: "1.0.2" 30 | ), 31 | .package( 32 | url: "https://github.com/huggingface/swift-transformers.git", 33 | from: "0.1.21" 34 | ), 35 | .package( 36 | url: "https://github.com/jkrukowski/swift-safetensors.git", 37 | from: "0.0.7" 38 | ), 39 | .package( 40 | url: "https://github.com/apple/swift-argument-parser.git", 41 | from: "1.4.0" 42 | ), 43 | .package( 44 | url: "https://github.com/jkrukowski/swift-sentencepiece", 45 | from: "0.0.6" 46 | ), 47 | .package( 48 | url: "https://github.com/tuist/Command.git", 49 | from: "0.13.0" 50 | ), 51 | ], 52 | targets: [ 53 | .executableTarget( 54 | name: "EmbeddingsCLI", 55 | dependencies: [ 56 | "Embeddings", 57 | "MLTensorUtils", 58 | .product(name: "Safetensors", package: "swift-safetensors"), 59 | .product(name: "ArgumentParser", package: "swift-argument-parser"), 60 | ] 61 | ), 62 | .target( 63 | name: "Embeddings", 64 | dependencies: [ 65 | "MLTensorUtils", 66 | .product(name: "Safetensors", package: "swift-safetensors"), 67 | .product(name: "Transformers", package: "swift-transformers"), 68 | .product(name: "SentencepieceTokenizer", package: "swift-sentencepiece"), 69 | ] 70 | ), 71 | .target( 72 | name: 
"MLTensorUtils"), 73 | .target( 74 | name: "TestingUtils", 75 | dependencies: [ 76 | .product(name: "Numerics", package: "swift-numerics") 77 | ] 78 | ), 79 | .testTarget( 80 | name: "EmbeddingsTests", 81 | dependencies: [ 82 | "Embeddings", 83 | "MLTensorUtils", 84 | "TestingUtils", 85 | .product(name: "Safetensors", package: "swift-safetensors"), 86 | ], 87 | resources: [ 88 | .copy("Resources") 89 | ] 90 | ), 91 | .testTarget( 92 | name: "AccuracyTests", 93 | dependencies: [ 94 | "Embeddings", 95 | "TestingUtils", 96 | .product(name: "Command", package: "Command"), 97 | ], 98 | resources: [ 99 | .copy("Scripts") 100 | ] 101 | ), 102 | .testTarget( 103 | name: "MLTensorUtilsTests", 104 | dependencies: [ 105 | "MLTensorUtils", 106 | "TestingUtils", 107 | ] 108 | ), 109 | ] 110 | ) 111 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `swift-embeddings` 2 | 3 | [![](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fjkrukowski%2Fswift-embeddings%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/jkrukowski/swift-embeddings) 4 | [![](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fjkrukowski%2Fswift-embeddings%2Fbadge%3Ftype%3Dplatforms)](https://swiftpackageindex.com/jkrukowski/swift-embeddings) 5 | 6 | Run embedding models locally in `Swift` using `MLTensor`. 7 | Inspired by [mlx-embeddings](https://github.com/Blaizzy/mlx-embeddings). 
8 | 9 | ## Supported Models Architectures 10 | 11 | ### BERT (Bidirectional Encoder Representations from Transformers) 12 | 13 | Some of the supported models on `Hugging Face`: 14 | 15 | - [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) 16 | - [sentence-transformers/msmarco-bert-base-dot-v5](https://huggingface.co/sentence-transformers/msmarco-bert-base-dot-v5) 17 | - [sentence-transformers/LaBSE](https://huggingface.co/sentence-transformers/LaBSE) 18 | - [thenlper/gte-base](https://huggingface.co/thenlper/gte-base) 19 | - [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) 20 | 21 | NOTE: `google-bert/bert-base-uncased` is supported but `weightKeyTransform` must be provided in the `LoadConfig`: 22 | 23 | ```swift 24 | let modelBundle = try await Bert.loadModelBundle( 25 | from: "google-bert/bert-base-uncased", 26 | loadConfig: .googleBert 27 | ) 28 | ``` 29 | 30 | ### RoBERTa (Robustly Optimized BERT Approach) 31 | 32 | Some of the supported models on `Hugging Face`: 33 | 34 | - [FacebookAI/roberta-base](https://huggingface.co/FacebookAI/roberta-base) 35 | 36 | NOTE: Weights in `FacebookAI/roberta-base` must be prefixed with `roberta.`, this has to be provided in the `LoadConfig`: 37 | 38 | ```swift 39 | let modelBundle = try await Roberta.loadModelBundle( 40 | from: "FacebookAI/roberta-base", 41 | loadConfig: .addWeightKeyPrefix("roberta.") 42 | ) 43 | ``` 44 | 45 | ### XLM-RoBERTa (Cross-lingual Language Model - Robustly Optimized BERT Approach) 46 | 47 | Some of the supported models on `Hugging Face`: 48 | 49 | - [FacebookAI/xlm-roberta-base](https://huggingface.co/FacebookAI/xlm-roberta-base) 50 | - [sentence-transformers/paraphrase-multilingual-mpnet-base-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-mpnet-base-v2) 51 | - 
[tomaarsen/xlm-roberta-base-multilingual-en-ar-fr-de-es-tr-it](https://huggingface.co/tomaarsen/xlm-roberta-base-multilingual-en-ar-fr-de-es-tr-it) 52 | 53 | NOTE: Weights in `FacebookAI/xlm-roberta-base` must be prefixed with `roberta.`, this has to be provided in the `LoadConfig`: 54 | 55 | ```swift 56 | let modelBundle = try await XLMRoberta.loadModelBundle( 57 | from: "FacebookAI/xlm-roberta-base", 58 | loadConfig: .addWeightKeyPrefix("roberta.") 59 | ) 60 | ``` 61 | 62 | ### CLIP (Contrastive Language–Image Pre-training) 63 | 64 | NOTE: only text encoding is supported for now. 65 | Some of the supported models on `Hugging Face`: 66 | 67 | - [jkrukowski/clip-vit-base-patch16](https://huggingface.co/jkrukowski/clip-vit-base-patch16) 68 | - [jkrukowski/clip-vit-base-patch32](https://huggingface.co/jkrukowski/clip-vit-base-patch32) 69 | - [jkrukowski/clip-vit-large-patch14](https://huggingface.co/jkrukowski/clip-vit-large-patch14) 70 | 71 | ### Word2Vec 72 | 73 | NOTE: it's a word embedding model. It loads and keeps the whole model in memory. 74 | For the more memory efficient solution, you might want to use [SQLiteVec](https://github.com/jkrukowski/SQLiteVec). 75 | Some of the supported models on `Hugging Face`: 76 | 77 | - [jkrukowski/glove-twitter-25](https://huggingface.co/jkrukowski/glove-twitter-25) 78 | - [jkrukowski/glove-twitter-50](https://huggingface.co/jkrukowski/glove-twitter-50) 79 | - [jkrukowski/glove-twitter-100](https://huggingface.co/jkrukowski/glove-twitter-100) 80 | - [jkrukowski/glove-twitter-200](https://huggingface.co/jkrukowski/glove-twitter-200) 81 | 82 | ### Model2Vec 83 | 84 | More info [here](https://huggingface.co/blog/Pringled/model2vec). 
85 | 86 | Some of the supported models on `Hugging Face`: 87 | 88 | - [minishlab/potion-base-2M](https://huggingface.co/minishlab/potion-base-2M) 89 | - [minishlab/potion-base-4M](https://huggingface.co/minishlab/potion-base-4M) 90 | - [minishlab/potion-base-8M](https://huggingface.co/minishlab/potion-base-8M) 91 | - [minishlab/potion-retrieval-32M](https://huggingface.co/minishlab/potion-retrieval-32M) 92 | - [minishlab/potion-base-32M](https://huggingface.co/minishlab/potion-base-32M) 93 | - [minishlab/M2V_base_output](https://huggingface.co/minishlab/M2V_base_output) 94 | 95 | ### Static Embeddings 96 | 97 | More info [here](https://huggingface.co/blog/static-embeddings). 98 | 99 | Some of the supported models on `Hugging Face`: 100 | 101 | - [sentence-transformers/static-retrieval-mrl-en-v1](https://huggingface.co/sentence-transformers/static-retrieval-mrl-en-v1) 102 | - [sentence-transformers/static-similarity-mrl-multilingual-v1](https://huggingface.co/sentence-transformers/static-similarity-mrl-multilingual-v1) 103 | 104 | ## Installation 105 | 106 | Add the following to your `Package.swift` file. 
In the package dependencies add: 107 | 108 | ```swift 109 | dependencies: [ 110 | .package(url: "https://github.com/jkrukowski/swift-embeddings", from: "0.0.16") 111 | ] 112 | ``` 113 | 114 | In the target dependencies add: 115 | 116 | ```swift 117 | dependencies: [ 118 | .product(name: "Embeddings", package: "swift-embeddings") 119 | ] 120 | ``` 121 | 122 | ## Usage 123 | 124 | ### Encoding 125 | 126 | ```swift 127 | import Embeddings 128 | 129 | // load model and tokenizer from Hugging Face 130 | let modelBundle = try await Bert.loadModelBundle( 131 | from: "sentence-transformers/all-MiniLM-L6-v2" 132 | ) 133 | 134 | // encode text 135 | let encoded = modelBundle.encode("The cat is black") 136 | let result = await encoded.cast(to: Float.self).shapedArray(of: Float.self).scalars 137 | 138 | // print result 139 | print(result) 140 | ``` 141 | 142 | ### Batch Encoding 143 | 144 | ```swift 145 | import Embeddings 146 | import MLTensorUtils 147 | 148 | let texts = [ 149 | "The cat is black", 150 | "The dog is black", 151 | "The cat sleeps well" 152 | ] 153 | let modelBundle = try await Bert.loadModelBundle( 154 | from: "sentence-transformers/all-MiniLM-L6-v2" 155 | ) 156 | let encoded = modelBundle.batchEncode(texts) 157 | let distance = cosineDistance(encoded, encoded) 158 | let result = await distance.cast(to: Float.self).shapedArray(of: Float.self).scalars 159 | print(result) 160 | ``` 161 | 162 | ## Command Line Demo 163 | 164 | To run the command line demo, use the following command: 165 | 166 | ```bash 167 | swift run embeddings-cli [--model-id ] [--model-file ] [--text ] [--max-length ] 168 | ``` 169 | 170 | Subcommands: 171 | 172 | ```bash 173 | bert Encode text using BERT model 174 | clip Encode text using CLIP model 175 | model2vec Encode text using Model2Vec model 176 | roberta Encode text using RoBERTa model 177 | static-embeddings Encode text using Static Embeddings model 178 | xlm-roberta Encode text using XLMRoberta model 179 | word2vec Encode word 
using Word2Vec model 180 | ``` 181 | 182 | Command line options: 183 | 184 | ```bash 185 | --model-id Id of the model to use 186 | --model-file Path to the model file (only for `Word2Vec`) 187 | --text Text to encode 188 | --max-length Maximum length of the input (not for `Word2Vec`) 189 | -h, --help Show help information. 190 | ``` 191 | 192 | ## Code Formatting 193 | 194 | This project uses [swift-format](https://github.com/swiftlang/swift-format). To format the code run: 195 | 196 | ```bash 197 | swift format . -i -r --configuration .swift-format 198 | ``` 199 | 200 | ## Acknowledgements 201 | 202 | This project is based on and uses some of the code from: 203 | 204 | - [mlx-embeddings](https://github.com/Blaizzy/mlx-embeddings) 205 | -------------------------------------------------------------------------------- /Sources/Embeddings/Bert/BertModel.swift: -------------------------------------------------------------------------------- 1 | import CoreML 2 | import Foundation 3 | import MLTensorUtils 4 | @preconcurrency import Tokenizers 5 | 6 | public enum Bert {} 7 | 8 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 9 | extension Bert { 10 | public struct ModelConfig: Codable { 11 | public var modelType: String 12 | public var numHiddenLayers: Int 13 | public var numAttentionHeads: Int 14 | public var hiddenSize: Int 15 | public var intermediateSize: Int 16 | public var maxPositionEmbeddings: Int 17 | public var hiddenDropoutProb: Float 18 | public var attentionProbsDropoutProb: Float 19 | public var typeVocabSize: Int 20 | public var initializerRange: Float 21 | public var layerNormEps: Float 22 | public var vocabSize: Int 23 | 24 | public init( 25 | modelType: String, 26 | numHiddenLayers: Int, 27 | numAttentionHeads: Int, 28 | hiddenSize: Int, 29 | intermediateSize: Int, 30 | maxPositionEmbeddings: Int, 31 | hiddenDropoutProb: Float = 0.1, 32 | attentionProbsDropoutProb: Float = 0.1, 33 | typeVocabSize: Int = 2, 34 | 
initializerRange: Float = 0.02, 35 | layerNormEps: Float = 1e-12, 36 | vocabSize: Int = 30522 37 | ) { 38 | self.modelType = modelType 39 | self.numHiddenLayers = numHiddenLayers 40 | self.numAttentionHeads = numAttentionHeads 41 | self.hiddenSize = hiddenSize 42 | self.intermediateSize = intermediateSize 43 | self.maxPositionEmbeddings = maxPositionEmbeddings 44 | self.hiddenDropoutProb = hiddenDropoutProb 45 | self.attentionProbsDropoutProb = attentionProbsDropoutProb 46 | self.typeVocabSize = typeVocabSize 47 | self.initializerRange = initializerRange 48 | self.layerNormEps = layerNormEps 49 | self.vocabSize = vocabSize 50 | } 51 | } 52 | } 53 | 54 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 55 | extension Bert { 56 | public struct Pooler: Sendable { 57 | let dense: MLTensorUtils.Layer 58 | 59 | public init(dense: @escaping MLTensorUtils.Layer) { 60 | self.dense = dense 61 | } 62 | 63 | public func callAsFunction(_ hiddenStates: MLTensor) -> MLTensor { 64 | let firstTokenTensor = hiddenStates[0..., 0] 65 | let pooledOutput = dense(firstTokenTensor) 66 | return pooledOutput.tanh() 67 | } 68 | } 69 | } 70 | 71 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 72 | extension Bert { 73 | public struct Embeddings: Sendable { 74 | let wordEmbeddings: MLTensorUtils.Layer 75 | let positionEmbeddings: MLTensorUtils.Layer 76 | let tokenTypeEmbeddings: MLTensorUtils.Layer 77 | let layerNorm: MLTensorUtils.Layer 78 | 79 | public init( 80 | wordEmbeddings: @escaping MLTensorUtils.Layer, 81 | positionEmbeddings: @escaping MLTensorUtils.Layer, 82 | tokenTypeEmbeddings: @escaping MLTensorUtils.Layer, 83 | layerNorm: @escaping MLTensorUtils.Layer 84 | ) { 85 | self.wordEmbeddings = wordEmbeddings 86 | self.positionEmbeddings = positionEmbeddings 87 | self.tokenTypeEmbeddings = tokenTypeEmbeddings 88 | self.layerNorm = layerNorm 89 | } 90 | 91 | public func callAsFunction( 92 | inputIds: MLTensor, 93 | tokenTypeIds: 
MLTensor? = nil, 94 | positionIds: MLTensor? = nil 95 | ) -> MLTensor { 96 | let seqLength = inputIds.shape[1] 97 | let positionIds = 98 | positionIds 99 | ?? MLTensor( 100 | shape: [1, seqLength], 101 | scalars: 0.. MLTensor { 137 | let dense = dense(hiddenStates) 138 | let layerNormInput = dense + inputTensor 139 | return layerNorm(layerNormInput) 140 | } 141 | } 142 | } 143 | 144 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 145 | extension Bert { 146 | public struct Intermediate: Sendable { 147 | let dense: MLTensorUtils.Layer 148 | 149 | public init(dense: @escaping MLTensorUtils.Layer) { 150 | self.dense = dense 151 | } 152 | 153 | public func callAsFunction(hiddenStates: MLTensor) -> MLTensor { 154 | let dense = dense(hiddenStates) 155 | return gelu(dense) 156 | } 157 | } 158 | } 159 | 160 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 161 | extension Bert { 162 | public struct SelfOutput: Sendable { 163 | let dense: MLTensorUtils.Layer 164 | let layerNorm: MLTensorUtils.Layer 165 | 166 | public init( 167 | dense: @escaping MLTensorUtils.Layer, 168 | layerNorm: @escaping MLTensorUtils.Layer 169 | ) { 170 | self.dense = dense 171 | self.layerNorm = layerNorm 172 | } 173 | 174 | public func callAsFunction( 175 | hiddenStates: MLTensor, 176 | inputTensor: MLTensor 177 | ) -> MLTensor { 178 | let dense = dense(hiddenStates) 179 | let layerNormInput = dense + inputTensor 180 | return layerNorm(layerNormInput) 181 | } 182 | } 183 | } 184 | 185 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 186 | extension Bert { 187 | public struct SelfAttention: Sendable { 188 | let query: MLTensorUtils.Layer 189 | let key: MLTensorUtils.Layer 190 | let value: MLTensorUtils.Layer 191 | let numAttentionHeads: Int 192 | let attentionHeadSize: Int 193 | let allHeadSize: Int 194 | 195 | public init( 196 | query: @escaping MLTensorUtils.Layer, 197 | key: @escaping MLTensorUtils.Layer, 198 | 
value: @escaping MLTensorUtils.Layer, 199 | numAttentionHeads: Int, 200 | attentionHeadSize: Int, 201 | allHeadSize: Int 202 | ) { 203 | self.query = query 204 | self.key = key 205 | self.value = value 206 | self.numAttentionHeads = numAttentionHeads 207 | self.attentionHeadSize = attentionHeadSize 208 | self.allHeadSize = allHeadSize 209 | } 210 | 211 | private func transposeForScores(_ x: MLTensor) -> MLTensor { 212 | let newShape = x.shape.dropLast() + [numAttentionHeads, attentionHeadSize] 213 | return x.reshaped(to: Array(newShape)).transposed(permutation: 0, 2, 1, 3) 214 | } 215 | 216 | public func callAsFunction( 217 | hiddenStates: MLTensor, 218 | attentionMask: MLTensor? 219 | ) -> MLTensor { 220 | let mixedQueryLayer = query(hiddenStates) 221 | let mixedKeyLayer = key(hiddenStates) 222 | let mixedValueLayer = value(hiddenStates) 223 | 224 | let queryLayer = transposeForScores(mixedQueryLayer) 225 | let keyLayer = transposeForScores(mixedKeyLayer) 226 | let valueLayer = transposeForScores(mixedValueLayer) 227 | 228 | var attentionScores = queryLayer.matmul(keyLayer.transposed(permutation: 0, 1, 3, 2)) 229 | attentionScores = attentionScores / sqrt(Float(attentionHeadSize)) 230 | if let attentionMask { 231 | attentionScores = attentionScores + attentionMask 232 | } 233 | let attentionProbs = attentionScores.softmax(alongAxis: -1) 234 | var contextLayer = attentionProbs.matmul(valueLayer) 235 | contextLayer = contextLayer.transposed(permutation: [0, 2, 1, 3]) 236 | let newContextLayerShape = contextLayer.shape.dropLast(2) + [allHeadSize] 237 | return contextLayer.reshaped(to: Array(newContextLayerShape)) 238 | } 239 | } 240 | } 241 | 242 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 243 | extension Bert { 244 | public struct Attention: Sendable { 245 | let selfAttention: Bert.SelfAttention 246 | let output: Bert.SelfOutput 247 | 248 | public init( 249 | selfAttention: Bert.SelfAttention, 250 | output: Bert.SelfOutput 251 | ) { 
252 | self.selfAttention = selfAttention 253 | self.output = output 254 | } 255 | 256 | public func callAsFunction( 257 | hiddenStates: MLTensor, 258 | attentionMask: MLTensor? 259 | ) -> MLTensor { 260 | let selfOutputs = selfAttention( 261 | hiddenStates: hiddenStates, 262 | attentionMask: attentionMask 263 | ) 264 | return output( 265 | hiddenStates: selfOutputs, 266 | inputTensor: hiddenStates 267 | ) 268 | } 269 | } 270 | } 271 | 272 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 273 | extension Bert { 274 | public struct Layer: Sendable { 275 | let attention: Bert.Attention 276 | let intermediate: Bert.Intermediate 277 | let output: Bert.Output 278 | 279 | public init( 280 | attention: Bert.Attention, 281 | intermediate: Bert.Intermediate, 282 | output: Bert.Output 283 | ) { 284 | self.attention = attention 285 | self.intermediate = intermediate 286 | self.output = output 287 | } 288 | 289 | public func callAsFunction( 290 | hiddenStates: MLTensor, 291 | attentionMask: MLTensor? 292 | ) -> MLTensor { 293 | let attentionOutput = attention( 294 | hiddenStates: hiddenStates, 295 | attentionMask: attentionMask 296 | ) 297 | let intermediateOutput = intermediate( 298 | hiddenStates: attentionOutput 299 | ) 300 | return output( 301 | hiddenStates: intermediateOutput, 302 | inputTensor: attentionOutput 303 | ) 304 | } 305 | } 306 | } 307 | 308 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 309 | extension Bert { 310 | public struct Encoder: Sendable { 311 | let layers: [Bert.Layer] 312 | 313 | public init(layers: [Bert.Layer]) { 314 | self.layers = layers 315 | } 316 | 317 | public func callAsFunction( 318 | hiddenStates: MLTensor, 319 | attentionMask: MLTensor? 
320 | ) -> MLTensor { 321 | var hiddenStates = hiddenStates 322 | for layer in layers { 323 | hiddenStates = layer( 324 | hiddenStates: hiddenStates, 325 | attentionMask: attentionMask 326 | ) 327 | } 328 | return hiddenStates 329 | } 330 | } 331 | } 332 | 333 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 334 | extension Bert { 335 | public struct Model: Sendable { 336 | let embeddings: Bert.Embeddings 337 | let encoder: Bert.Encoder 338 | let pooler: Bert.Pooler 339 | 340 | public init( 341 | embeddings: Bert.Embeddings, 342 | encoder: Bert.Encoder, 343 | pooler: Bert.Pooler 344 | ) { 345 | self.embeddings = embeddings 346 | self.encoder = encoder 347 | self.pooler = pooler 348 | } 349 | 350 | public func callAsFunction( 351 | inputIds: MLTensor, 352 | tokenTypeIds: MLTensor? = nil, 353 | attentionMask: MLTensor? = nil 354 | ) -> (sequenceOutput: MLTensor, pooledOutput: MLTensor) { 355 | let embeddingOutput = embeddings(inputIds: inputIds, tokenTypeIds: tokenTypeIds) 356 | let mask: MLTensor? 
= 357 | if let attentionMask { 358 | (1.0 - attentionMask.expandingShape(at: 1, 1)) * -10000.0 359 | } else { 360 | nil 361 | } 362 | let encoderOutput = encoder(hiddenStates: embeddingOutput, attentionMask: mask) 363 | let pooledOutput = pooler(encoderOutput) 364 | return (encoderOutput, pooledOutput) 365 | } 366 | } 367 | } 368 | 369 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 370 | extension Bert { 371 | public struct ModelBundle: Sendable { 372 | public let model: Bert.Model 373 | public let tokenizer: any TextTokenizer 374 | 375 | public init( 376 | model: Bert.Model, 377 | tokenizer: any TextTokenizer 378 | ) { 379 | self.model = model 380 | self.tokenizer = tokenizer 381 | } 382 | 383 | public func encode( 384 | _ text: String, 385 | maxLength: Int = 512 386 | ) throws -> MLTensor { 387 | let tokens = try tokenizer.tokenizeText(text, maxLength: maxLength) 388 | let inputIds = MLTensor(shape: [1, tokens.count], scalars: tokens) 389 | let result = model(inputIds: inputIds) 390 | return result.sequenceOutput[0..., 0, 0...] 391 | } 392 | 393 | public func batchEncode( 394 | _ texts: [String], 395 | padTokenId: Int = 0, 396 | maxLength: Int = 512 397 | ) throws -> MLTensor { 398 | let encodedTexts = try tokenizer.tokenizeTextsPaddingToLongest( 399 | texts, padTokenId: padTokenId, maxLength: maxLength) 400 | let inputIds = MLTensor( 401 | shape: [encodedTexts.count, encodedTexts[0].count], 402 | scalars: encodedTexts.flatMap { $0 }) 403 | return model(inputIds: inputIds).sequenceOutput[0..., 0, 0...] 
404 | } 405 | } 406 | } 407 | -------------------------------------------------------------------------------- /Sources/Embeddings/Bert/BertUtils.swift: -------------------------------------------------------------------------------- 1 | import CoreML 2 | import Foundation 3 | import Hub 4 | import MLTensorUtils 5 | import Safetensors 6 | @preconcurrency import Tokenizers 7 | 8 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 9 | extension Bert { 10 | public static func loadConfig(at url: URL) throws -> Bert.ModelConfig { 11 | try loadConfigFromFile(at: url) 12 | } 13 | } 14 | 15 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 16 | extension Bert { 17 | public static func loadModelBundle( 18 | from hubRepoId: String, 19 | downloadBase: URL? = nil, 20 | useBackgroundSession: Bool = false, 21 | loadConfig: LoadConfig = LoadConfig() 22 | ) async throws -> Bert.ModelBundle { 23 | let modelFolder = try await downloadModelFromHub( 24 | from: hubRepoId, 25 | downloadBase: downloadBase, 26 | useBackgroundSession: useBackgroundSession 27 | ) 28 | return try await loadModelBundle( 29 | from: modelFolder, 30 | loadConfig: loadConfig 31 | ) 32 | } 33 | 34 | public static func loadModelBundle( 35 | from modelFolder: URL, 36 | loadConfig: LoadConfig = LoadConfig() 37 | ) async throws -> Bert.ModelBundle { 38 | let tokenizer = try await AutoTokenizer.from(modelFolder: modelFolder) 39 | let weightsUrl = modelFolder.appendingPathComponent(loadConfig.modelConfig.weightsFileName) 40 | let configUrl = modelFolder.appendingPathComponent(loadConfig.modelConfig.configFileName) 41 | let config = try Bert.loadConfig(at: configUrl) 42 | let model = try Bert.loadModel( 43 | weightsUrl: weightsUrl, 44 | config: config, 45 | loadConfig: loadConfig 46 | ) 47 | return Bert.ModelBundle(model: model, tokenizer: TokenizerWrapper(tokenizer)) 48 | } 49 | } 50 | 51 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 52 | 
extension Bert { 53 | // NOTE: this is a simple key transformation that is required for the Google BERT weights. 54 | // Model available here: [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) 55 | public static func googleWeightsKeyTransform(_ key: String) -> String { 56 | "bert.\(key)" 57 | .replace(suffix: ".LayerNorm.weight", with: ".LayerNorm.gamma") 58 | .replace(suffix: ".LayerNorm.bias", with: ".LayerNorm.beta") 59 | } 60 | } 61 | 62 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 63 | extension Bert { 64 | public static func loadModel( 65 | weightsUrl: URL, 66 | config: Bert.ModelConfig, 67 | loadConfig: LoadConfig = LoadConfig() 68 | ) throws -> Bert.Model { 69 | // NOTE: just `safetensors` support for now 70 | let safetensors = try Safetensors.read(at: weightsUrl) 71 | let pooler = try Bert.Pooler( 72 | dense: MLTensorUtils.linear( 73 | weight: safetensors.mlTensor( 74 | forKey: loadConfig.modelConfig.weightKeyTransform("pooler.dense.weight")), 75 | bias: safetensors.mlTensor( 76 | forKey: loadConfig.modelConfig.weightKeyTransform("pooler.dense.bias")))) 77 | 78 | let wordEmbeddings = try MLTensorUtils.embedding( 79 | weight: safetensors.mlTensor( 80 | forKey: loadConfig.modelConfig.weightKeyTransform( 81 | "embeddings.word_embeddings.weight"))) 82 | 83 | let tokenTypeEmbeddings = try MLTensorUtils.embedding( 84 | weight: safetensors.mlTensor( 85 | forKey: loadConfig.modelConfig.weightKeyTransform( 86 | "embeddings.token_type_embeddings.weight"))) 87 | 88 | let positionEmbeddings = try MLTensorUtils.embedding( 89 | weight: safetensors.mlTensor( 90 | forKey: loadConfig.modelConfig.weightKeyTransform( 91 | "embeddings.position_embeddings.weight"))) 92 | 93 | let layerNorm = try MLTensorUtils.layerNorm( 94 | weight: safetensors.mlTensor( 95 | forKey: loadConfig.modelConfig.weightKeyTransform("embeddings.LayerNorm.weight")), 96 | bias: safetensors.mlTensor( 97 | forKey: 
loadConfig.modelConfig.weightKeyTransform("embeddings.LayerNorm.bias")), 98 | epsilon: config.layerNormEps) 99 | 100 | let embeddings = Bert.Embeddings( 101 | wordEmbeddings: wordEmbeddings, 102 | positionEmbeddings: positionEmbeddings, 103 | tokenTypeEmbeddings: tokenTypeEmbeddings, 104 | layerNorm: layerNorm) 105 | 106 | var layers = [Bert.Layer]() 107 | for layer in 0.. MLTensor { 109 | let embeddings = tokenEmbedding(x) 110 | return embeddings + positionEmbeddingWeight[0.. MLTensor { 132 | fc2(gelu(fc1(x), approximation: .fast)) 133 | } 134 | } 135 | } 136 | 137 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 138 | extension Clip { 139 | public struct Attention: Sendable { 140 | let qProj: MLTensorUtils.Layer 141 | let kProj: MLTensorUtils.Layer 142 | let vProj: MLTensorUtils.Layer 143 | let outProj: MLTensorUtils.Layer 144 | private let numHeads: Int 145 | 146 | public init( 147 | qProj: @escaping MLTensorUtils.Layer, 148 | kProj: @escaping MLTensorUtils.Layer, 149 | vProj: @escaping MLTensorUtils.Layer, 150 | outProj: @escaping MLTensorUtils.Layer, 151 | numHeads: Int 152 | ) { 153 | self.qProj = qProj 154 | self.kProj = kProj 155 | self.vProj = vProj 156 | self.outProj = outProj 157 | self.numHeads = numHeads 158 | } 159 | 160 | public func callAsFunction( 161 | queries: MLTensor, 162 | keys: MLTensor, 163 | values: MLTensor, 164 | mask: MLTensor? 
= nil
        ) -> MLTensor {
            // Project the inputs, split into heads, run scaled dot-product attention,
            // then merge the heads and apply the output projection.
            var q = qProj(queries)
            var k = kProj(keys)
            var v = vProj(values)
            let batch = q.shape[0]
            let targetLength = q.shape[1]
            let sourceLength = k.shape[1]
            // [B, L, H*D] -> [B, H, L, D]; keys become [B, H, D, S] so matmul gives scores.
            q = q.reshaped(to: [batch, targetLength, numHeads, -1])
                .transposed(permutation: 0, 2, 1, 3)
            k = k.reshaped(to: [batch, sourceLength, numHeads, -1])
                .transposed(permutation: 0, 2, 3, 1)
            v = v.reshaped(to: [batch, sourceLength, numHeads, -1])
                .transposed(permutation: 0, 2, 1, 3)
            let scale = sqrt(1.0 / Float(q.shape.last!))
            var scores = (q * scale).matmul(k)
            if let mask {
                scores = scores + mask.cast(like: scores)
            }
            let weights = scores.softmax(alongAxis: -1)
            let merged = weights.matmul(v)
                .transposed(permutation: 0, 2, 1, 3)
                .reshaped(to: [batch, targetLength, -1])
            return outProj(merged)
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension Clip {
    /// Pre-norm transformer block: LayerNorm -> self-attention -> residual,
    /// then LayerNorm -> MLP -> residual.
    public struct EncoderLayer: Sendable {
        // NOTE(review): "selfAttnention" is misspelled but is part of the public
        // initializer signature, so it is kept for source compatibility.
        let selfAttnention: Attention
        let mlp: MLP
        let layerNorm1: MLTensorUtils.Layer
        let layerNorm2: MLTensorUtils.Layer

        public init(
            selfAttnention: Attention,
            mlp: MLP,
            layerNorm1: @escaping MLTensorUtils.Layer,
            layerNorm2: @escaping MLTensorUtils.Layer
        ) {
            self.selfAttnention = selfAttnention
            self.mlp = mlp
            self.layerNorm1 = layerNorm1
            self.layerNorm2 = layerNorm2
        }

        public func callAsFunction(
            x: MLTensor,
            mask: MLTensor? = nil
        ) -> MLTensor {
            let normed = layerNorm1(x)
            let attended =
                x + selfAttnention(queries: normed, keys: normed, values: normed, mask: mask)
            return attended + mlp(x: layerNorm2(attended))
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension Clip {
    /// Stack of encoder layers applied in sequence.
    public struct Encoder: Sendable {
        let layers: [EncoderLayer]

        public init(layers: [EncoderLayer]) {
            self.layers = layers
        }

        /// Applies every layer in order, threading the same mask through each.
        public func callAsFunction(
            x: MLTensor,
            mask: MLTensor? = nil
        ) -> MLTensor {
            layers.reduce(x) { hidden, layer in
                layer(x: hidden, mask: mask)
            }
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension Clip {
    /// CLIP text tower: embeddings, causally-masked transformer encoder, final
    /// LayerNorm. The `textProjection` is applied by callers to the pooled output.
    public struct TextModel: Sendable {
        let embeddings: Embeddings
        let encoder: Encoder
        let finalLayerNorm: MLTensorUtils.Layer
        let textProjection: MLTensorUtils.Layer

        public init(
            embeddings: Embeddings,
            encoder: Encoder,
            finalLayerNorm: @escaping MLTensorUtils.Layer,
            textProjection: @escaping MLTensorUtils.Layer
        ) {
            self.embeddings = embeddings
            self.encoder = encoder
            self.finalLayerNorm = finalLayerNorm
            self.textProjection = textProjection
        }

        public func callAsFunction(
            inputIds: MLTensor
        ) -> (lastHiddenState: MLTensor, poolerOutput: MLTensor) {
            let sequenceLength = Int32(inputIds.shape[1])
            // Pooled output is taken at the token found via argmax over the ids —
            // presumably the EOT token has the highest id in the vocabulary
            // (matches OpenAI CLIP); TODO confirm against the tokenizer.
            let eotTokens = inputIds.argmax(alongAxis: -1)
            var hidden = embeddings(x: inputIds)
            let causalMask = additiveCausalMask(sequenceLength, scalarType: hidden.scalarType)
            hidden = encoder(x: hidden, mask: causalMask)
            let lastHiddenState = finalLayerNorm(hidden)
            let poolerOutput = lastHiddenState.gathering(atIndices: eotTokens, alongAxis: 1)
            return (lastHiddenState: lastHiddenState, poolerOutput: poolerOutput)
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension Clip {
    /// Bundles the CLIP text tower with its tokenizer and exposes helpers that
    /// return L2-normalized text embeddings.
    public struct ModelBundle: Sendable {
        public let textModel: Clip.TextModel
        public let tokenizer: any TextTokenizer

        public init(
            textModel: Clip.TextModel,
            tokenizer: any TextTokenizer
        ) {
            self.textModel = textModel
            self.tokenizer = tokenizer
        }

        /// Encodes a single text. The default `maxLength` of 77 is CLIP's context length.
        /// - Throws: Rethrows tokenizer errors.
        public func encode(_ text: String, maxLength: Int = 77) throws -> MLTensor {
            let tokens = try tokenizer.tokenizeText(text, maxLength: maxLength)
            let inputIds = MLTensor(shape: [1, tokens.count], scalars: tokens)
            return normalizedEmbeddings(for: inputIds)
        }

        /// Encodes a batch of texts padded to the longest sequence.
        /// - Throws: Rethrows tokenizer errors.
        public func batchEncode(
            _ texts: [String],
            padTokenId: Int = 0,
            maxLength: Int = 77
        ) throws -> MLTensor {
            let encodedTexts = try tokenizer.tokenizeTextsPaddingToLongest(
                texts, padTokenId: padTokenId, maxLength: maxLength)
            let inputIds = MLTensor(
                shape: [encodedTexts.count, encodedTexts[0].count],
                scalars: encodedTexts.flatMap { $0 })
            return normalizedEmbeddings(for: inputIds)
        }

        /// Shared path for both encoders: run the text model, project the pooled
        /// output, and L2-normalize along the feature axis. Extracted to remove
        /// the projection/normalization code duplicated in `encode`/`batchEncode`.
        private func normalizedEmbeddings(for inputIds: MLTensor) -> MLTensor {
            let modelOutput = textModel(inputIds: inputIds)
            let textEmbeddings = textModel.textProjection(modelOutput.poolerOutput)
            return textEmbeddings / norm(textEmbeddings, alongAxes: -1, keepRank: true)
        }
    }
}
--------------------------------------------------------------------------------
/Sources/Embeddings/Clip/ClipUtils.swift:
--------------------------------------------------------------------------------
import CoreML
import Foundation
import Hub
import MLTensorUtils
import Safetensors

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension Clip {
    public static func loadConfig(at url: URL) throws ->
Clip.ModelConfig { 10 | try loadConfigFromFile(at: url) 11 | } 12 | } 13 | 14 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 15 | extension Clip { 16 | public static func loadModelBundle( 17 | from hubRepoId: String, 18 | downloadBase: URL? = nil, 19 | useBackgroundSession: Bool = false, 20 | loadConfig: LoadConfig = LoadConfig() 21 | ) async throws -> Clip.ModelBundle { 22 | let modelFolder = try await downloadModelFromHub( 23 | from: hubRepoId, 24 | downloadBase: downloadBase, 25 | useBackgroundSession: useBackgroundSession 26 | ) 27 | return try await loadModelBundle( 28 | from: modelFolder, 29 | loadConfig: loadConfig 30 | ) 31 | } 32 | 33 | public static func loadModelBundle( 34 | from modelFolder: URL, 35 | loadConfig: LoadConfig = LoadConfig() 36 | ) async throws -> Clip.ModelBundle { 37 | let tokenizer = try loadClipTokenizer(at: modelFolder) 38 | let weightsUrl = modelFolder.appendingPathComponent(loadConfig.modelConfig.weightsFileName) 39 | let configUrl = modelFolder.appendingPathComponent(loadConfig.modelConfig.configFileName) 40 | let config = try Clip.loadConfig(at: configUrl) 41 | let textModel = try Clip.loadModel( 42 | weightsUrl: weightsUrl, 43 | config: config, 44 | loadConfig: loadConfig 45 | ) 46 | return Clip.ModelBundle(textModel: textModel, tokenizer: tokenizer) 47 | } 48 | } 49 | 50 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 51 | extension Clip { 52 | public static func loadModel( 53 | weightsUrl: URL, 54 | config: Clip.ModelConfig, 55 | loadConfig: LoadConfig = LoadConfig() 56 | ) throws -> Clip.TextModel { 57 | let safetensors = try Safetensors.read(at: weightsUrl) 58 | let embeddings = try Clip.Embeddings( 59 | tokenEmbedding: MLTensorUtils.embedding( 60 | weight: safetensors.mlTensor( 61 | forKey: loadConfig.modelConfig.weightKeyTransform( 62 | "text_model.embeddings.token_embedding.weight"))), 63 | positionEmbeddingWeight: safetensors.mlTensor( 64 | forKey: 
loadConfig.modelConfig.weightKeyTransform( 65 | "text_model.embeddings.position_embedding.weight"))) 66 | var encoderLayers = [Clip.EncoderLayer]() 67 | encoderLayers.reserveCapacity(config.textConfig.numHiddenLayers) 68 | for i in 0.. URL { 11 | let hubApi = HubApi(downloadBase: downloadBase, useBackgroundSession: useBackgroundSession) 12 | let repo = Hub.Repo(id: hubRepoId, type: .models) 13 | return try await hubApi.snapshot( 14 | from: repo, 15 | matching: globs 16 | ) 17 | } 18 | 19 | enum EmbeddingsError: Error { 20 | case fileNotFound 21 | case invalidFile 22 | } 23 | 24 | enum Constants { 25 | static let modelGlobs = [ 26 | "*.json", 27 | "*.safetensors", 28 | "*.py", 29 | "tokenizer.model", 30 | "sentencepiece*.model", 31 | "*.tiktoken", 32 | "*.txt", 33 | ] 34 | } 35 | 36 | func loadConfigFromFile(at url: URL) throws -> Config { 37 | let configData = try Data(contentsOf: url) 38 | let decoder = JSONDecoder() 39 | decoder.keyDecodingStrategy = .convertFromSnakeCase 40 | return try decoder.decode(Config.self, from: configData) 41 | } 42 | 43 | extension String { 44 | func replace(suffix: String, with string: String) -> String { 45 | guard hasSuffix(suffix) else { return self } 46 | return String(dropLast(suffix.count) + string) 47 | } 48 | } 49 | 50 | public enum TokenizerConfigType { 51 | case filePath(String) 52 | case data([String: Any]) 53 | } 54 | 55 | public struct TokenizerConfig { 56 | public let data: TokenizerConfigType 57 | public let config: TokenizerConfigType 58 | 59 | public init( 60 | data: TokenizerConfigType = .filePath("tokenizer.json"), 61 | config: TokenizerConfigType = .filePath("tokenizer_config.json") 62 | ) { 63 | self.data = data 64 | self.config = config 65 | } 66 | } 67 | 68 | public struct ModelConfig { 69 | public let configFileName: String 70 | public let weightsFileName: String 71 | public let weightKeyTransform: ((String) -> String) 72 | 73 | public init( 74 | configFileName: String = "config.json", 75 | weightsFileName: 
String = "model.safetensors", 76 | weightKeyTransform: @escaping ((String) -> String) = { $0 } 77 | ) { 78 | self.configFileName = configFileName 79 | self.weightsFileName = weightsFileName 80 | self.weightKeyTransform = weightKeyTransform 81 | } 82 | } 83 | 84 | public struct LoadConfig { 85 | public let modelConfig: ModelConfig 86 | public let tokenizerConfig: TokenizerConfig? 87 | 88 | public init( 89 | modelConfig: ModelConfig = ModelConfig(), 90 | tokenizerConfig: TokenizerConfig? = nil 91 | ) { 92 | self.modelConfig = modelConfig 93 | self.tokenizerConfig = tokenizerConfig 94 | } 95 | } 96 | 97 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 98 | extension LoadConfig { 99 | public static var googleBert: LoadConfig { 100 | LoadConfig( 101 | modelConfig: ModelConfig( 102 | weightKeyTransform: Bert.googleWeightsKeyTransform 103 | ) 104 | ) 105 | } 106 | 107 | public static func addWeightKeyPrefix(_ prefix: String) -> LoadConfig { 108 | LoadConfig( 109 | modelConfig: ModelConfig( 110 | weightKeyTransform: { key in 111 | "\(prefix)\(key)" 112 | } 113 | ) 114 | ) 115 | } 116 | 117 | public static var staticEmbeddings: LoadConfig { 118 | LoadConfig( 119 | modelConfig: ModelConfig( 120 | weightsFileName: "0_StaticEmbedding/model.safetensors" 121 | ), 122 | // In case of `StaticEmbeddings` tokenizer `data` is loaded from `0_StaticEmbedding/tokenizer.json` file 123 | // and tokenizer `config` is a dictionary with a single key `tokenizerClass` and value `BertTokenizer`. 
124 | tokenizerConfig: TokenizerConfig( 125 | data: .filePath("0_StaticEmbedding/tokenizer.json"), 126 | config: .data(["tokenizerClass": "BertTokenizer"]) 127 | ) 128 | ) 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /Sources/Embeddings/Model2Vec/Model2VecModel.swift: -------------------------------------------------------------------------------- 1 | import CoreML 2 | import Foundation 3 | import MLTensorUtils 4 | 5 | public enum Model2Vec {} 6 | 7 | extension Model2Vec { 8 | public struct ModelConfig: Codable { 9 | public var normalize: Bool? 10 | 11 | public init(normalize: Bool? = nil) { 12 | self.normalize = normalize 13 | } 14 | } 15 | } 16 | 17 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 18 | extension Model2Vec { 19 | public struct ModelBundle: Sendable { 20 | public let model: Model2Vec.Model 21 | public let tokenizer: any TextTokenizer 22 | 23 | public init( 24 | model: Model2Vec.Model, 25 | tokenizer: any TextTokenizer 26 | ) { 27 | self.model = model 28 | self.tokenizer = tokenizer 29 | } 30 | 31 | public func encode( 32 | _ text: String, 33 | normalize: Bool = false, 34 | maxLength: Int? = nil 35 | ) throws -> MLTensor { 36 | try batchEncode([text], normalize: normalize, maxLength: maxLength) 37 | } 38 | 39 | public func batchEncode( 40 | _ texts: [String], 41 | normalize: Bool = false, 42 | maxLength: Int? 
= nil 43 | ) throws -> MLTensor { 44 | let inputIdsBatch = try texts.map { try tokenize($0, maxLength: maxLength) } 45 | let embeddingsBatch = inputIdsBatch.map { inputIds in 46 | if let inputIds { 47 | model.embeddings 48 | .gathering(atIndices: inputIds, alongAxis: 0) 49 | .mean(alongAxes: 0) 50 | } else { 51 | MLTensor(zeros: [model.dimienstion], scalarType: Int32.self) 52 | } 53 | } 54 | let embeddings = MLTensor(stacking: embeddingsBatch, alongAxis: 0).cast(to: Float.self) 55 | if normalize { 56 | let norm = norm(embeddings, alongAxes: 1, keepRank: true) + Float.ulpOfOne 57 | return embeddings / norm 58 | } else { 59 | return embeddings 60 | } 61 | } 62 | 63 | private func tokenize( 64 | _ text: String, 65 | maxLength: Int? 66 | ) throws -> MLTensor? { 67 | let tokensIds = try tokenizer.tokenizeText( 68 | text, maxLength: maxLength, addSpecialTokens: false) 69 | let tokens = 70 | if let unknownTokenId = tokenizer.unknownTokenId { 71 | tokensIds.filter { $0 != unknownTokenId } 72 | } else { 73 | tokensIds 74 | } 75 | return tokens.isEmpty ? 
nil : MLTensor(shape: [tokens.count], scalars: tokens)
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension Model2Vec {
    /// Static token-embedding model (Model2Vec): a lookup table used with mean
    /// pooling by `ModelBundle`; there is no transformer forward pass.
    public struct Model: Sendable {
        /// Embedding matrix; the dimension below is read from `shape[1]`.
        public let embeddings: MLTensor
        // NOTE: misspelled, but public API used elsewhere in this module —
        // kept for source compatibility. Prefer `dimension`.
        public let dimienstion: Int
        // Loaded from the model config; note `ModelBundle.encode`/`batchEncode`
        // take their own `normalize` parameter.
        public let normalize: Bool

        /// Correctly-spelled accessor for `dimienstion`; added as a
        /// backward-compatible alias so new code can avoid the typo.
        public var dimension: Int { dimienstion }

        public init(embeddings: MLTensor, normalize: Bool = false) {
            self.embeddings = embeddings
            self.dimienstion = embeddings.shape[1]
            self.normalize = normalize
        }
    }
}
--------------------------------------------------------------------------------
/Sources/Embeddings/Model2Vec/Model2VecUtils.swift:
--------------------------------------------------------------------------------
import CoreML
import Foundation
import Hub
import MLTensorUtils
import Safetensors
@preconcurrency import Tokenizers

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension Model2Vec {
    /// Decodes a `Model2Vec.ModelConfig` from the JSON file at `url`.
    public static func loadConfig(at url: URL) throws -> Model2Vec.ModelConfig {
        try loadConfigFromFile(at: url)
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension Model2Vec {
    public static func loadModelBundle(
        from hubRepoId: String,
        downloadBase: URL?
= nil, 20 | useBackgroundSession: Bool = false, 21 | loadConfig: LoadConfig = LoadConfig() 22 | ) async throws -> Model2Vec.ModelBundle { 23 | let modelFolder = try await downloadModelFromHub( 24 | from: hubRepoId, 25 | downloadBase: downloadBase, 26 | useBackgroundSession: useBackgroundSession 27 | ) 28 | return try await loadModelBundle( 29 | from: modelFolder, 30 | loadConfig: loadConfig 31 | ) 32 | } 33 | 34 | public static func loadModelBundle( 35 | from modelFolder: URL, 36 | loadConfig: LoadConfig = LoadConfig() 37 | ) async throws -> Model2Vec.ModelBundle { 38 | let tokenizer = try await AutoTokenizer.from(modelFolder: modelFolder) 39 | let weightsUrl = modelFolder.appendingPathComponent(loadConfig.modelConfig.weightsFileName) 40 | let configUrl = modelFolder.appendingPathComponent(loadConfig.modelConfig.configFileName) 41 | let config = try Model2Vec.loadConfig(at: configUrl) 42 | let model = try Model2Vec.loadModel( 43 | weightsUrl: weightsUrl, 44 | normalize: config.normalize ?? 
false, 45 | loadConfig: loadConfig 46 | ) 47 | return Model2Vec.ModelBundle( 48 | model: model, 49 | tokenizer: TokenizerWrapper(tokenizer) 50 | ) 51 | } 52 | } 53 | 54 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 55 | extension Model2Vec { 56 | public static func loadModel( 57 | weightsUrl: URL, 58 | normalize: Bool, 59 | loadConfig: LoadConfig = LoadConfig() 60 | ) throws -> Model2Vec.Model { 61 | let data = try Safetensors.read(at: weightsUrl) 62 | let embeddings = try data.mlTensor( 63 | forKey: loadConfig.modelConfig.weightKeyTransform("embeddings")) 64 | return Model2Vec.Model(embeddings: embeddings, normalize: normalize) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /Sources/Embeddings/Roberta/RobertaModel.swift: -------------------------------------------------------------------------------- 1 | import CoreML 2 | import Foundation 3 | import MLTensorUtils 4 | @preconcurrency import Tokenizers 5 | 6 | public enum Roberta {} 7 | 8 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 9 | extension Roberta { 10 | public struct ModelConfig: Codable { 11 | public var modelType: String 12 | public var numHiddenLayers: Int 13 | public var numAttentionHeads: Int 14 | public var hiddenSize: Int 15 | public var intermediateSize: Int 16 | public var maxPositionEmbeddings: Int 17 | public var hiddenDropoutProb: Float 18 | public var attentionProbsDropoutProb: Float 19 | public var typeVocabSize: Int 20 | public var initializerRange: Float 21 | public var layerNormEps: Float 22 | public var vocabSize: Int 23 | public var padTokenId: Int 24 | 25 | public init( 26 | modelType: String, 27 | numHiddenLayers: Int, 28 | numAttentionHeads: Int, 29 | hiddenSize: Int, 30 | intermediateSize: Int, 31 | maxPositionEmbeddings: Int, 32 | hiddenDropoutProb: Float = 0.1, 33 | attentionProbsDropoutProb: Float = 0.1, 34 | typeVocabSize: Int = 1, 35 | initializerRange: Float = 
0.02, 36 | layerNormEps: Float = 1e-05, 37 | vocabSize: Int = 50265, 38 | padTokenId: Int = 1 39 | ) { 40 | self.modelType = modelType 41 | self.numHiddenLayers = numHiddenLayers 42 | self.numAttentionHeads = numAttentionHeads 43 | self.hiddenSize = hiddenSize 44 | self.intermediateSize = intermediateSize 45 | self.maxPositionEmbeddings = maxPositionEmbeddings 46 | self.hiddenDropoutProb = hiddenDropoutProb 47 | self.attentionProbsDropoutProb = attentionProbsDropoutProb 48 | self.typeVocabSize = typeVocabSize 49 | self.initializerRange = initializerRange 50 | self.layerNormEps = layerNormEps 51 | self.vocabSize = vocabSize 52 | self.padTokenId = padTokenId 53 | } 54 | } 55 | } 56 | 57 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 58 | extension Roberta { 59 | public struct Embeddings: Sendable { 60 | let wordEmbeddings: MLTensorUtils.Layer 61 | let positionEmbeddings: MLTensorUtils.Layer 62 | let tokenTypeEmbeddings: MLTensorUtils.Layer 63 | let layerNorm: MLTensorUtils.Layer 64 | let paddingIdx: Int32 65 | 66 | public init( 67 | wordEmbeddings: @escaping MLTensorUtils.Layer, 68 | positionEmbeddings: @escaping MLTensorUtils.Layer, 69 | tokenTypeEmbeddings: @escaping MLTensorUtils.Layer, 70 | layerNorm: @escaping MLTensorUtils.Layer, 71 | paddingIdx: Int32 72 | ) { 73 | self.wordEmbeddings = wordEmbeddings 74 | self.positionEmbeddings = positionEmbeddings 75 | self.tokenTypeEmbeddings = tokenTypeEmbeddings 76 | self.layerNorm = layerNorm 77 | self.paddingIdx = paddingIdx 78 | } 79 | 80 | public func callAsFunction( 81 | inputIds: MLTensor, 82 | tokenTypeIds: MLTensor? = nil, 83 | positionIds: MLTensor? = nil 84 | ) -> MLTensor { 85 | let positionIds = 86 | positionIds 87 | ?? createPositionIdsFromInputIds( 88 | inputIds: inputIds, 89 | paddingIdx: paddingIdx 90 | ) 91 | let tokenTypeIds = 92 | tokenTypeIds 93 | ?? 
MLTensor(
                    zeros: inputIds.shape,
                    scalarType: Int32.self
                )
            // Sum the three embedding streams, then apply LayerNorm.
            let tokens = wordEmbeddings(inputIds)
            let positions = positionEmbeddings(positionIds)
            let segments = tokenTypeEmbeddings(tokenTypeIds)
            return layerNorm(tokens + positions + segments)
        }

        /// Position ids count only non-padding tokens (cumulative sum over the
        /// sequence axis), offset by `paddingIdx`; padding positions map to
        /// `paddingIdx` itself. Mirrors Hugging Face RoBERTa's
        /// `create_position_ids_from_input_ids` — TODO confirm parity.
        private func createPositionIdsFromInputIds(
            inputIds: MLTensor,
            paddingIdx: Int32
        ) -> MLTensor {
            let nonPadding = (inputIds .!= paddingIdx).cast(to: Int32.self)
            let positions = nonPadding.cumulativeSum(alongAxis: 1).cast(like: nonPadding)
            return positions * nonPadding + paddingIdx
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension Roberta {
    /// Post-FFN projection: dense -> residual add -> LayerNorm.
    /// (Structurally identical to `SelfOutput`; kept separate to mirror the
    /// upstream layer layout.)
    public struct Output: Sendable {
        let dense: MLTensorUtils.Layer
        let layerNorm: MLTensorUtils.Layer

        public init(
            dense: @escaping MLTensorUtils.Layer,
            layerNorm: @escaping MLTensorUtils.Layer
        ) {
            self.dense = dense
            self.layerNorm = layerNorm
        }

        public func callAsFunction(
            hiddenStates: MLTensor,
            inputTensor: MLTensor
        ) -> MLTensor {
            let projected = dense(hiddenStates)
            return layerNorm(projected + inputTensor)
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension Roberta {
    /// Feed-forward expansion: dense projection followed by GELU.
    public struct Intermediate: Sendable {
        let dense: MLTensorUtils.Layer

        public init(dense: @escaping MLTensorUtils.Layer) {
            self.dense = dense
        }

        public func callAsFunction(hiddenStates: MLTensor) -> MLTensor {
            gelu(dense(hiddenStates))
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension Roberta {
    public
struct SelfOutput: Sendable { 159 | let dense: MLTensorUtils.Layer 160 | let layerNorm: MLTensorUtils.Layer 161 | 162 | public init( 163 | dense: @escaping MLTensorUtils.Layer, 164 | layerNorm: @escaping MLTensorUtils.Layer 165 | ) { 166 | self.dense = dense 167 | self.layerNorm = layerNorm 168 | } 169 | 170 | public func callAsFunction( 171 | hiddenStates: MLTensor, 172 | inputTensor: MLTensor 173 | ) -> MLTensor { 174 | let dense = dense(hiddenStates) 175 | let layerNormInput = dense + inputTensor 176 | return layerNorm(layerNormInput) 177 | } 178 | } 179 | } 180 | 181 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 182 | extension Roberta { 183 | public struct SelfAttention: Sendable { 184 | let query: MLTensorUtils.Layer 185 | let key: MLTensorUtils.Layer 186 | let value: MLTensorUtils.Layer 187 | let numAttentionHeads: Int 188 | let attentionHeadSize: Int 189 | let allHeadSize: Int 190 | 191 | public init( 192 | query: @escaping MLTensorUtils.Layer, 193 | key: @escaping MLTensorUtils.Layer, 194 | value: @escaping MLTensorUtils.Layer, 195 | numAttentionHeads: Int, 196 | attentionHeadSize: Int, 197 | allHeadSize: Int 198 | ) { 199 | self.query = query 200 | self.key = key 201 | self.value = value 202 | self.numAttentionHeads = numAttentionHeads 203 | self.attentionHeadSize = attentionHeadSize 204 | self.allHeadSize = allHeadSize 205 | } 206 | 207 | private func transposeForScores(_ x: MLTensor) -> MLTensor { 208 | let newShape = x.shape.dropLast() + [numAttentionHeads, attentionHeadSize] 209 | return x.reshaped(to: Array(newShape)).transposed(permutation: 0, 2, 1, 3) 210 | } 211 | 212 | public func callAsFunction( 213 | hiddenStates: MLTensor, 214 | attentionMask: MLTensor? 
        ) -> MLTensor {
            // Project the hidden states into query/key/value spaces.
            let mixedQueryLayer = query(hiddenStates)
            let mixedKeyLayer = key(hiddenStates)
            let mixedValueLayer = value(hiddenStates)

            // Split each projection into attention heads:
            // [batch, seq, allHeadSize] -> [batch, heads, seq, headSize]
            // (see `transposeForScores`).
            let queryLayer = transposeForScores(mixedQueryLayer)
            let keyLayer = transposeForScores(mixedKeyLayer)
            let valueLayer = transposeForScores(mixedValueLayer)

            // Scaled dot-product attention: scores = Q·Kᵀ / sqrt(headSize).
            var attentionScores = queryLayer.matmul(keyLayer.transposed(permutation: 0, 1, 3, 2))
            attentionScores = attentionScores / sqrt(Float(attentionHeadSize))
            // The mask is additive (large negative values on masked positions,
            // built in `Roberta.Model`), so masked entries vanish after softmax.
            if let attentionMask {
                attentionScores = attentionScores + attentionMask
            }
            let attentionProbs = attentionScores.softmax(alongAxis: -1)
            var contextLayer = attentionProbs.matmul(valueLayer)
            // Merge the heads back:
            // [batch, heads, seq, headSize] -> [batch, seq, allHeadSize].
            contextLayer = contextLayer.transposed(permutation: [0, 2, 1, 3])
            let newContextLayerShape = contextLayer.shape.dropLast(2) + [allHeadSize]
            return contextLayer.reshaped(to: Array(newContextLayerShape))
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension Roberta {
    /// One full attention sub-layer: self-attention followed by the
    /// projection + residual + layer-norm step (`SelfOutput`).
    public struct Attention: Sendable {
        let selfAttention: Roberta.SelfAttention
        let output: Roberta.SelfOutput

        public init(
            selfAttention: Roberta.SelfAttention,
            output: Roberta.SelfOutput
        ) {
            self.selfAttention = selfAttention
            self.output = output
        }

        public func callAsFunction(
            hiddenStates: MLTensor,
            attentionMask: MLTensor?
255 | ) -> MLTensor { 256 | let selfOutputs = selfAttention( 257 | hiddenStates: hiddenStates, 258 | attentionMask: attentionMask 259 | ) 260 | return output( 261 | hiddenStates: selfOutputs, 262 | inputTensor: hiddenStates 263 | ) 264 | } 265 | } 266 | } 267 | 268 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 269 | extension Roberta { 270 | public struct Layer: Sendable { 271 | let attention: Roberta.Attention 272 | let intermediate: Roberta.Intermediate 273 | let output: Roberta.Output 274 | 275 | public init( 276 | attention: Roberta.Attention, 277 | intermediate: Roberta.Intermediate, 278 | output: Roberta.Output 279 | ) { 280 | self.attention = attention 281 | self.intermediate = intermediate 282 | self.output = output 283 | } 284 | 285 | public func callAsFunction( 286 | hiddenStates: MLTensor, 287 | attentionMask: MLTensor? 288 | ) -> MLTensor { 289 | let attentionOutput = attention( 290 | hiddenStates: hiddenStates, 291 | attentionMask: attentionMask 292 | ) 293 | let intermediateOutput = intermediate( 294 | hiddenStates: attentionOutput 295 | ) 296 | return output( 297 | hiddenStates: intermediateOutput, 298 | inputTensor: attentionOutput 299 | ) 300 | } 301 | } 302 | } 303 | 304 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 305 | extension Roberta { 306 | public struct Encoder: Sendable { 307 | let layers: [Roberta.Layer] 308 | 309 | public init(layers: [Roberta.Layer]) { 310 | self.layers = layers 311 | } 312 | 313 | public func callAsFunction( 314 | hiddenStates: MLTensor, 315 | attentionMask: MLTensor? 
        ) -> MLTensor {
            // Thread the hidden states through every transformer layer in order.
            var hiddenStates = hiddenStates
            for layer in layers {
                hiddenStates = layer(
                    hiddenStates: hiddenStates,
                    attentionMask: attentionMask
                )
            }
            return hiddenStates
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension Roberta {
    /// The full RoBERTa model: embedding lookup followed by the encoder stack.
    public struct Model: Sendable {
        let embeddings: Roberta.Embeddings
        let encoder: Roberta.Encoder

        public init(
            embeddings: Roberta.Embeddings,
            encoder: Roberta.Encoder
        ) {
            self.embeddings = embeddings
            self.encoder = encoder
        }

        public func callAsFunction(
            inputIds: MLTensor,
            tokenTypeIds: MLTensor? = nil,
            attentionMask: MLTensor? = nil
        ) -> MLTensor {
            let embeddingOutput = embeddings(inputIds: inputIds, tokenTypeIds: tokenTypeIds)
            // Convert a 0/1 attention mask into an additive mask: kept (1)
            // positions become 0, masked (0) positions become -10000 so they
            // contribute nothing after the softmax inside the attention layers.
            let mask: MLTensor? =
                if let attentionMask {
                    (1.0 - attentionMask.expandingShape(at: 1, 1)) * -10000.0
                } else {
                    nil
                }
            return encoder(hiddenStates: embeddingOutput, attentionMask: mask)
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension Roberta {
    /// A model paired with its tokenizer, exposing text-to-embedding helpers.
    public struct ModelBundle: Sendable {
        public let model: Roberta.Model
        public let tokenizer: any TextTokenizer

        public init(
            model: Roberta.Model,
            tokenizer: any TextTokenizer
        ) {
            self.model = model
            self.tokenizer = tokenizer
        }

        /// Encodes a single text and returns the hidden state at sequence
        /// position 0 (the BOS/CLS token) for each batch element.
        public func encode(
            _ text: String,
            maxLength: Int = 512
        ) throws -> MLTensor {
            let tokens = try tokenizer.tokenizeText(text, maxLength: maxLength)
            let inputIds = MLTensor(shape: [1, tokens.count], scalars: tokens)
            let result = model(inputIds: inputIds)
            // Slice out position 0 of every sequence.
            return result[0..., 0, 0...]
382 | } 383 | 384 | public func batchEncode( 385 | _ texts: [String], 386 | padTokenId: Int = 0, 387 | maxLength: Int = 512 388 | ) throws -> MLTensor { 389 | let encodedTexts = try tokenizer.tokenizeTextsPaddingToLongest( 390 | texts, padTokenId: padTokenId, maxLength: maxLength) 391 | let inputIds = MLTensor( 392 | shape: [encodedTexts.count, encodedTexts[0].count], 393 | scalars: encodedTexts.flatMap { $0 }) 394 | return model(inputIds: inputIds)[0..., 0, 0...] 395 | } 396 | } 397 | } 398 | -------------------------------------------------------------------------------- /Sources/Embeddings/Roberta/RobertaUtils.swift: -------------------------------------------------------------------------------- 1 | import CoreML 2 | import Foundation 3 | import Hub 4 | import MLTensorUtils 5 | import Safetensors 6 | @preconcurrency import Tokenizers 7 | 8 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 9 | extension Roberta { 10 | public static func loadConfig(at url: URL) throws -> Roberta.ModelConfig { 11 | try loadConfigFromFile(at: url) 12 | } 13 | } 14 | 15 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 16 | extension Roberta { 17 | public static func loadModelBundle( 18 | from hubRepoId: String, 19 | downloadBase: URL? 
= nil, 20 | useBackgroundSession: Bool = false, 21 | loadConfig: LoadConfig = LoadConfig() 22 | ) async throws -> Roberta.ModelBundle { 23 | let modelFolder = try await downloadModelFromHub( 24 | from: hubRepoId, 25 | downloadBase: downloadBase, 26 | useBackgroundSession: useBackgroundSession 27 | ) 28 | return try await loadModelBundle( 29 | from: modelFolder, 30 | loadConfig: loadConfig 31 | ) 32 | } 33 | 34 | public static func loadModelBundle( 35 | from modelFolder: URL, 36 | loadConfig: LoadConfig = LoadConfig() 37 | ) async throws -> Roberta.ModelBundle { 38 | let tokenizer = try await AutoTokenizer.from(modelFolder: modelFolder) 39 | let weightsUrl = modelFolder.appendingPathComponent(loadConfig.modelConfig.weightsFileName) 40 | let configUrl = modelFolder.appendingPathComponent(loadConfig.modelConfig.configFileName) 41 | let config = try Roberta.loadConfig(at: configUrl) 42 | let model = try Roberta.loadModel( 43 | weightsUrl: weightsUrl, 44 | config: config, 45 | loadConfig: loadConfig 46 | ) 47 | return Roberta.ModelBundle(model: model, tokenizer: TokenizerWrapper(tokenizer)) 48 | } 49 | } 50 | 51 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 52 | extension Roberta { 53 | public static func loadModel( 54 | weightsUrl: URL, 55 | config: Roberta.ModelConfig, 56 | loadConfig: LoadConfig = LoadConfig() 57 | ) throws -> Roberta.Model { 58 | // NOTE: just `safetensors` support for now 59 | let safetensors = try Safetensors.read(at: weightsUrl) 60 | let wordEmbeddings = try MLTensorUtils.embedding( 61 | weight: safetensors.mlTensor( 62 | forKey: loadConfig.modelConfig.weightKeyTransform( 63 | "embeddings.word_embeddings.weight"))) 64 | 65 | let tokenTypeEmbeddings = try MLTensorUtils.embedding( 66 | weight: safetensors.mlTensor( 67 | forKey: loadConfig.modelConfig.weightKeyTransform( 68 | "embeddings.token_type_embeddings.weight"))) 69 | 70 | let positionEmbeddings = try MLTensorUtils.embedding( 71 | weight: safetensors.mlTensor( 
72 | forKey: loadConfig.modelConfig.weightKeyTransform( 73 | "embeddings.position_embeddings.weight"))) 74 | 75 | let layerNorm = try MLTensorUtils.layerNorm( 76 | weight: safetensors.mlTensor( 77 | forKey: loadConfig.modelConfig.weightKeyTransform( 78 | "embeddings.LayerNorm.weight")), 79 | bias: safetensors.mlTensor( 80 | forKey: loadConfig.modelConfig.weightKeyTransform( 81 | "embeddings.LayerNorm.bias")), 82 | epsilon: config.layerNormEps) 83 | 84 | let embeddings = Roberta.Embeddings( 85 | wordEmbeddings: wordEmbeddings, 86 | positionEmbeddings: positionEmbeddings, 87 | tokenTypeEmbeddings: tokenTypeEmbeddings, 88 | layerNorm: layerNorm, 89 | paddingIdx: Int32(config.padTokenId)) 90 | 91 | var layers = [Roberta.Layer]() 92 | for layer in 0.. MLTensor { 27 | try batchEncode( 28 | [text], 29 | normalize: normalize, 30 | maxLength: maxLength, 31 | truncateDimension: truncateDimension 32 | ) 33 | } 34 | 35 | public func batchEncode( 36 | _ texts: [String], 37 | normalize: Bool = false, 38 | maxLength: Int? = nil, 39 | truncateDimension: Int? = nil 40 | ) throws -> MLTensor { 41 | let dimension = 42 | if let truncateDimension { 43 | min(truncateDimension, model.dimension) 44 | } else { 45 | model.dimension 46 | } 47 | precondition(dimension > 0, "Dimension must be greater than 0") 48 | let inputIdsBatch = try texts.map { try tokenize($0, maxLength: maxLength) } 49 | let embeddingsBatch = inputIdsBatch.map { inputIds in 50 | if let inputIds { 51 | model.embeddings 52 | .gathering(atIndices: inputIds, alongAxis: 0) 53 | .mean(alongAxes: 0)[0.. MLTensor? { 71 | let tokens = try tokenizer.tokenizeText( 72 | text, maxLength: maxLength, addSpecialTokens: false) 73 | return tokens.isEmpty ? 
nil : MLTensor(shape: [tokens.count], scalars: tokens) 74 | } 75 | } 76 | } 77 | 78 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 79 | extension StaticEmbeddings { 80 | public struct Model: Sendable { 81 | public let embeddings: MLTensor 82 | public let dimension: Int 83 | 84 | public init(embeddings: MLTensor) { 85 | self.embeddings = embeddings 86 | self.dimension = embeddings.shape[1] 87 | } 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /Sources/Embeddings/StaticEmbeddings/StaticEmbeddingsUtils.swift: -------------------------------------------------------------------------------- 1 | import CoreML 2 | import Foundation 3 | import Hub 4 | import MLTensorUtils 5 | import Safetensors 6 | @preconcurrency import Tokenizers 7 | 8 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 9 | extension StaticEmbeddings { 10 | public static func loadModelBundle( 11 | from hubRepoId: String, 12 | downloadBase: URL? 
= nil, 13 | useBackgroundSession: Bool = false, 14 | loadConfig: LoadConfig = LoadConfig() 15 | ) async throws -> StaticEmbeddings.ModelBundle { 16 | let modelFolder = try await downloadModelFromHub( 17 | from: hubRepoId, 18 | downloadBase: downloadBase, 19 | useBackgroundSession: useBackgroundSession 20 | ) 21 | return try await loadModelBundle( 22 | from: modelFolder, 23 | loadConfig: loadConfig 24 | ) 25 | } 26 | 27 | public static func loadModelBundle( 28 | from modelFolder: URL, 29 | loadConfig: LoadConfig = LoadConfig() 30 | ) async throws -> StaticEmbeddings.ModelBundle { 31 | let tokenizer = 32 | if let tokenizerConfig = loadConfig.tokenizerConfig { 33 | try AutoTokenizer.from( 34 | modelFolder: modelFolder, 35 | tokenizerData: tokenizerConfig.data, 36 | tokenizerConfig: tokenizerConfig.config 37 | ) 38 | } else { 39 | try await AutoTokenizer.from(modelFolder: modelFolder) 40 | } 41 | let weightsUrl = modelFolder.appendingPathComponent(loadConfig.modelConfig.weightsFileName) 42 | let model = try StaticEmbeddings.loadModel( 43 | weightsUrl: weightsUrl, 44 | loadConfig: loadConfig 45 | ) 46 | return StaticEmbeddings.ModelBundle( 47 | model: model, 48 | tokenizer: TokenizerWrapper(tokenizer) 49 | ) 50 | } 51 | } 52 | 53 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 54 | extension StaticEmbeddings { 55 | public static func loadModel( 56 | weightsUrl: URL, 57 | loadConfig: LoadConfig = LoadConfig() 58 | ) throws -> StaticEmbeddings.Model { 59 | let data = try Safetensors.read(at: weightsUrl) 60 | let embeddings = try data.mlTensor( 61 | forKey: loadConfig.modelConfig.weightKeyTransform("embedding.weight")) 62 | return StaticEmbeddings.Model(embeddings: embeddings) 63 | } 64 | } 65 | 66 | extension AutoTokenizer { 67 | static func from( 68 | modelFolder: URL, 69 | tokenizerData: TokenizerConfigType, 70 | tokenizerConfig: TokenizerConfigType 71 | ) throws -> any Tokenizer { 72 | let tokenizerConfig = try 
resolveConfig(tokenizerConfig, in: modelFolder) 73 | let tokenizerData = try resolveConfig(tokenizerData, in: modelFolder) 74 | return try AutoTokenizer.from( 75 | tokenizerConfig: tokenizerConfig, 76 | tokenizerData: tokenizerData 77 | ) 78 | } 79 | } 80 | 81 | func resolveConfig(_ tokenizerConfig: TokenizerConfigType, in modelFolder: URL) throws -> Config { 82 | switch tokenizerConfig { 83 | case .filePath(let filePath): 84 | let fileURL = modelFolder.appendingPathComponent(filePath) 85 | let data = try loadJSONConfig(at: fileURL) 86 | return Config(data as [NSString: Any]) 87 | case .data(let data): 88 | return Config(data as [NSString: Any]) 89 | } 90 | } 91 | 92 | func loadJSONConfig(at filePath: URL) throws -> [String: Any] { 93 | let data = try Data(contentsOf: filePath) 94 | let parsedData = try JSONSerialization.jsonObject(with: data, options: []) 95 | guard let config = parsedData as? [String: Any] else { 96 | throw EmbeddingsError.invalidFile 97 | } 98 | return config 99 | } 100 | -------------------------------------------------------------------------------- /Sources/Embeddings/Tokenizer/ClipTokenizer.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import Synchronization 3 | 4 | extension Regex: @retroactive @unchecked Sendable {} 5 | 6 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 7 | final class ClipTokenizer: Sendable { 8 | let bos: String 9 | let bosToken: Int 10 | let eos: String 11 | let eosToken: Int 12 | let unk: String 13 | let unkToken: Int 14 | private let bpeRanks: [Pair: Int] 15 | private let vocab: [String: Int] 16 | private let splitStringPattern: Regex 17 | private let emptyStringPattern: Regex 18 | private let cache: Mutex<[String: [String]]> 19 | 20 | init( 21 | bpeRanks: [Pair: Int], 22 | vocab: [String: Int] 23 | ) throws { 24 | self.bpeRanks = bpeRanks 25 | self.vocab = vocab 26 | self.splitStringPattern = try Regex( 27 | 
"<\\|startoftext\\|>|<\\|endoftext\\|>|'s|'t|'re|'ve|'m|'ll|'d|[\\p{L}]+|[\\p{N}]|[^\\s\\p{L}\\p{N}]+" 28 | ) 29 | self.emptyStringPattern = try Regex("\\s+") 30 | self.bos = "<|startoftext|>" 31 | self.bosToken = vocab[bos]! 32 | self.eos = "<|endoftext|>" 33 | self.eosToken = vocab[eos]! 34 | self.unk = "<|endoftext|>" 35 | self.unkToken = vocab[unk]! 36 | self.cache = Mutex([:]) 37 | } 38 | 39 | func tokenize( 40 | _ text: String, 41 | maxLength: Int?, 42 | padToLength: Int? = nil, 43 | addSpecialTokens: Bool 44 | ) -> [Int] { 45 | let cleanText = text.lowercased().replacing(emptyStringPattern, with: " ") 46 | let tokens = cleanText.ranges(of: splitStringPattern).map { String(cleanText[$0]) } 47 | let bpeTokens = tokens.flatMap { bpe($0) } 48 | let tokenIds = bpeTokens.map { vocab[$0]! } 49 | var result = addSpecialTokens ? [bosToken] : [] 50 | if let maxLength { 51 | if addSpecialTokens { 52 | precondition( 53 | maxLength >= 2, "maxLength must be at least 2 to accommodate BOS and EOS tokens" 54 | ) 55 | // Truncate to maxLength - 2 to make space for bos and eos tokens 56 | result.append(contentsOf: tokenIds.prefix(maxLength - 2)) 57 | } else { 58 | result.append(contentsOf: tokenIds.prefix(maxLength)) 59 | } 60 | } else { 61 | result.append(contentsOf: tokenIds) 62 | } 63 | if addSpecialTokens { 64 | result.append(eosToken) 65 | } 66 | // If padToLength is provided, pad the tokenIds with 0s 67 | if let padToLength { 68 | precondition(padToLength - 2 >= 0, "padToLength must be greater than or equal to 2") 69 | if let maxLength { 70 | if padToLength <= maxLength { 71 | result.append( 72 | contentsOf: Array(repeating: 0, count: padToLength - result.count)) 73 | } 74 | } else { 75 | result.append( 76 | contentsOf: Array(repeating: 0, count: padToLength - result.count)) 77 | } 78 | } 79 | return result 80 | } 81 | 82 | private func bpe(_ text: String) -> [String] { 83 | let cachedValue = cache.withLock { 84 | $0[text] 85 | } 86 | if let cachedValue { 87 | return 
cachedValue 88 | } 89 | var unigrams = text.dropLast().map { String($0) } + ["\(text.suffix(1))"] 90 | var uniqueBigrams = uniquePairs(from: unigrams) 91 | while !uniqueBigrams.isEmpty { 92 | guard let lowestMergePair = findLowestMergePair(in: uniqueBigrams, using: bpeRanks) 93 | else { 94 | break 95 | } 96 | var newUnigrams = [String]() 97 | var skip = false 98 | for (first, second) in zip(unigrams, unigrams.dropFirst()) { 99 | if skip { 100 | skip = false 101 | continue 102 | } 103 | let pair = Pair(first: first, second: second) 104 | if pair == lowestMergePair { 105 | newUnigrams.append(first + second) 106 | skip = true 107 | } else { 108 | newUnigrams.append(first) 109 | } 110 | } 111 | 112 | if !skip { 113 | newUnigrams.append(unigrams.last!) 114 | } 115 | 116 | unigrams = newUnigrams 117 | uniqueBigrams = uniquePairs(from: unigrams) 118 | } 119 | 120 | cache.withLock { 121 | $0[text] = unigrams 122 | } 123 | return unigrams 124 | } 125 | } 126 | 127 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 128 | extension ClipTokenizer: TextTokenizer { 129 | var unknownTokenId: Int? 
{ 130 | unkToken 131 | } 132 | 133 | func tokenizeText( 134 | _ text: String, 135 | maxLength: Int?, 136 | addSpecialTokens: Bool 137 | ) throws -> [Int32] { 138 | tokenize( 139 | text, 140 | maxLength: maxLength, 141 | padToLength: nil, 142 | addSpecialTokens: addSpecialTokens 143 | ).map { Int32($0) } 144 | } 145 | } 146 | 147 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 148 | func loadClipTokenizer(at url: URL) throws -> ClipTokenizer { 149 | let mergesData = try String( 150 | contentsOf: url.appendingPathComponent("merges.txt"), 151 | encoding: .utf8) 152 | let merges = mergesData.split(separator: "\n").dropFirst() 153 | var bpeRanks = [Pair: Int]() 154 | for (index, line) in merges.enumerated() { 155 | let pair = line.trimmingCharacters(in: .whitespacesAndNewlines).components(separatedBy: " ") 156 | if pair.count != 2 { 157 | fatalError("Malformed data on line \(line)") 158 | } 159 | bpeRanks[Pair(first: pair[0], second: pair[1])] = index 160 | } 161 | let vocabData = try JSONSerialization.jsonObject( 162 | with: Data(contentsOf: url.appendingPathComponent("vocab.json"))) 163 | guard let vocab = vocabData as? [String: Int] else { 164 | fatalError("Malformed vocab data") 165 | } 166 | return try ClipTokenizer(bpeRanks: bpeRanks, vocab: vocab) 167 | } 168 | 169 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 170 | func uniquePairs(from arr: [String]) -> Set> { 171 | Set(zip(arr, arr.dropFirst()).map { Pair($0) }) 172 | } 173 | 174 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 175 | func findLowestMergePair( 176 | in keySet: Set>, 177 | using bpeRanks: [Pair: Int] 178 | ) -> Pair? { 179 | var pair: Pair? 180 | var index: Int? 
181 | for key in keySet { 182 | guard let mergeIndex = bpeRanks[key] else { 183 | continue 184 | } 185 | if let currentIndex = index { 186 | if mergeIndex < currentIndex { 187 | index = mergeIndex 188 | pair = key 189 | } 190 | } else { 191 | index = mergeIndex 192 | pair = key 193 | } 194 | } 195 | guard let pair else { 196 | return nil 197 | } 198 | return pair 199 | } 200 | 201 | struct Pair { 202 | let first: T 203 | let second: T 204 | 205 | init(first: T, second: T) { 206 | self.first = first 207 | self.second = second 208 | } 209 | } 210 | 211 | extension Pair { 212 | init(_ pair: (T, T)) { 213 | self.init(first: pair.0, second: pair.1) 214 | } 215 | } 216 | 217 | extension Pair: Equatable where T: Equatable {} 218 | extension Pair: Hashable where T: Hashable {} 219 | extension Pair: Sendable where T: Sendable {} 220 | -------------------------------------------------------------------------------- /Sources/Embeddings/Tokenizer/TextTokenizer.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | @preconcurrency import Tokenizers 3 | 4 | public protocol TextTokenizer: Sendable { 5 | var unknownTokenId: Int? { get } 6 | 7 | func tokenizeText(_ text: String) throws -> [Int32] 8 | func tokenizeText(_ text: String, maxLength: Int?) throws -> [Int32] 9 | func tokenizeText(_ text: String, maxLength: Int?, addSpecialTokens: Bool) throws -> [Int32] 10 | func tokenizeTextsPaddingToLongest(_ texts: [String], padTokenId: Int) throws -> [[Int32]] 11 | func tokenizeTextsPaddingToLongest( 12 | _ texts: [String], padTokenId: Int, maxLength: Int? 
13 | ) throws -> [[Int32]] 14 | func tokenizeTextsPaddingToLongest( 15 | _ texts: [String], padTokenId: Int, maxLength: Int?, addSpecialTokens: Bool 16 | ) throws -> [[Int32]] 17 | } 18 | 19 | extension TextTokenizer { 20 | public func tokenizeText(_ text: String) throws -> [Int32] { 21 | try tokenizeText(text, maxLength: nil, addSpecialTokens: true) 22 | } 23 | 24 | public func tokenizeText(_ text: String, maxLength: Int?) throws -> [Int32] { 25 | try tokenizeText(text, maxLength: maxLength, addSpecialTokens: true) 26 | } 27 | 28 | public func tokenizeTextsPaddingToLongest( 29 | _ texts: [String], 30 | padTokenId: Int 31 | ) throws -> [[Int32]] { 32 | try tokenizeTextsPaddingToLongest( 33 | texts, padTokenId: padTokenId, maxLength: nil, addSpecialTokens: true) 34 | } 35 | 36 | public func tokenizeTextsPaddingToLongest( 37 | _ texts: [String], 38 | padTokenId: Int, 39 | maxLength: Int? 40 | ) throws -> [[Int32]] { 41 | try tokenizeTextsPaddingToLongest( 42 | texts, padTokenId: padTokenId, maxLength: maxLength, addSpecialTokens: true) 43 | } 44 | 45 | public func tokenizeTextsPaddingToLongest( 46 | _ texts: [String], 47 | padTokenId: Int, 48 | maxLength: Int?, 49 | addSpecialTokens: Bool 50 | ) throws -> [[Int32]] { 51 | var longest = 0 52 | var result = [[Int32]]() 53 | result.reserveCapacity(texts.count) 54 | for text in texts { 55 | let encoded = try tokenizeText( 56 | text, 57 | maxLength: maxLength, 58 | addSpecialTokens: addSpecialTokens 59 | ) 60 | longest = max(longest, encoded.count) 61 | result.append(encoded) 62 | } 63 | return result.map { 64 | if $0.count < longest { 65 | return $0 + Array(repeating: Int32(padTokenId), count: longest - $0.count) 66 | } else { 67 | return $0 68 | } 69 | } 70 | } 71 | } 72 | 73 | public struct TokenizerWrapper { 74 | private let tokenizer: any Tokenizers.Tokenizer 75 | 76 | public var unknownTokenId: Int? 
{ 77 | tokenizer.unknownTokenId 78 | } 79 | 80 | public init(_ tokenizer: any Tokenizers.Tokenizer) { 81 | self.tokenizer = tokenizer 82 | } 83 | } 84 | 85 | extension TokenizerWrapper: TextTokenizer { 86 | public func tokenizeText( 87 | _ text: String, 88 | maxLength: Int?, 89 | addSpecialTokens: Bool 90 | ) throws -> [Int32] { 91 | var encoded = tokenizer.encode(text: text, addSpecialTokens: addSpecialTokens) 92 | if let maxLength, encoded.count > maxLength { 93 | encoded.removeLast(encoded.count - maxLength) 94 | } 95 | return encoded.map { Int32($0) } 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /Sources/Embeddings/Tokenizer/XLMRobetaTokenizer.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import SentencepieceTokenizer 3 | import Synchronization 4 | 5 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 6 | final class XLMRobetaTokenizer: Sendable { 7 | /// TODO: Make `SentencepieceTokenizer` conform to `Sendable` 8 | private let tokenizer: Mutex 9 | private let addedTokens: [String: Int] 10 | 11 | init( 12 | tokenizerModelUrl: URL, 13 | addedTokens: [String: Int], 14 | tokenOffset: Int = 1 15 | ) throws { 16 | let sentencepieceTokenizer = try SentencepieceTokenizer( 17 | modelPath: tokenizerModelUrl.path, 18 | tokenOffset: tokenOffset 19 | ) 20 | self.tokenizer = Mutex(sentencepieceTokenizer) 21 | self.addedTokens = addedTokens 22 | } 23 | 24 | func tokenize( 25 | _ text: String, 26 | maxLength: Int?, 27 | padToLength: Int? = nil, 28 | addSpecialTokens: Bool 29 | ) throws -> [Int] { 30 | let tokenIds = try tokenizer.withLock { 31 | try $0.encode(text) 32 | } 33 | var result = addSpecialTokens ? 
[bosTokenId] : [] 34 | if let maxLength { 35 | if addSpecialTokens { 36 | precondition( 37 | maxLength >= 2, "maxLength must be at least 2 to accommodate BOS and EOS tokens" 38 | ) 39 | // Truncate to maxLength - 2 to make space for bos and eos tokens 40 | result.append(contentsOf: tokenIds.prefix(maxLength - 2)) 41 | } else { 42 | result.append(contentsOf: tokenIds.prefix(maxLength)) 43 | } 44 | } else { 45 | result.append(contentsOf: tokenIds) 46 | } 47 | if addSpecialTokens { 48 | result.append(eosTokenId) 49 | } 50 | // If padToLength is provided, pad the tokenIds with padTokenId 51 | if let padToLength { 52 | precondition(padToLength - 2 >= 0, "padToLength must be greater than or equal to 2") 53 | if let maxLength { 54 | if padToLength <= maxLength { 55 | result.append( 56 | contentsOf: Array(repeating: 0, count: padToLength - result.count)) 57 | } 58 | } else { 59 | result.append( 60 | contentsOf: Array(repeating: 0, count: padToLength - result.count)) 61 | } 62 | } 63 | return result 64 | } 65 | 66 | var bosTokenId: Int { 67 | if let bosTokenId = addedTokens[""] { 68 | return bosTokenId 69 | } 70 | return tokenizer.withLock { 71 | $0.bosTokenId 72 | } 73 | } 74 | 75 | var eosTokenId: Int { 76 | if let eosTokenId = addedTokens[""] { 77 | return eosTokenId 78 | } 79 | return tokenizer.withLock { 80 | $0.eosTokenId 81 | } 82 | } 83 | 84 | var padTokenId: Int { 85 | if let padTokenId = addedTokens[""] { 86 | return padTokenId 87 | } 88 | return tokenizer.withLock { 89 | $0.padTokenId 90 | } 91 | } 92 | 93 | var unkTokenId: Int { 94 | if let unkTokenId = addedTokens[""] { 95 | return unkTokenId 96 | } 97 | return tokenizer.withLock { 98 | $0.unkTokenId 99 | } 100 | } 101 | } 102 | 103 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 104 | extension XLMRobetaTokenizer: TextTokenizer { 105 | var unknownTokenId: Int? 
{ 106 | unkTokenId 107 | } 108 | 109 | func tokenizeText( 110 | _ text: String, 111 | maxLength: Int?, 112 | addSpecialTokens: Bool 113 | ) throws -> [Int32] { 114 | try tokenize( 115 | text, 116 | maxLength: maxLength, 117 | padToLength: nil, 118 | addSpecialTokens: addSpecialTokens 119 | ).map { Int32($0) } 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /Sources/Embeddings/Word2Vec/Word2VecModel.swift: -------------------------------------------------------------------------------- 1 | import CoreML 2 | import Foundation 3 | import MLTensorUtils 4 | 5 | public enum Word2Vec {} 6 | 7 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 8 | extension Word2Vec { 9 | public struct ModelBundle: Sendable { 10 | public let keyToIndex: [String: Int] 11 | public let indexToKey: [Int: String] 12 | public let embeddings: MLTensor 13 | 14 | public init( 15 | keyToIndex: [String: Int], 16 | indexToKey: [Int: String], 17 | embeddings: MLTensor 18 | ) { 19 | self.keyToIndex = keyToIndex 20 | self.indexToKey = indexToKey 21 | self.embeddings = embeddings 22 | } 23 | 24 | public func encode(_ word: String) -> MLTensor? { 25 | guard let index = keyToIndex[word] else { 26 | return nil 27 | } 28 | return embeddings[index] 29 | } 30 | 31 | public func batchEncode(_ words: [String]) -> MLTensor? 
{ 32 | let indices = words.compactMap { keyToIndex[$0] } 33 | let rows = indices.map { embeddings[$0] } 34 | return MLTensor(stacking: rows, alongAxis: 0) 35 | } 36 | 37 | public func mostSimilar( 38 | to word: String, 39 | topK: Int = 1 40 | ) async -> [(word: String, score: Float)] { 41 | guard let wordIndex = keyToIndex[word] else { 42 | return [] 43 | } 44 | // Get the embedding vector for the input word 45 | let wordVector = embeddings[wordIndex] 46 | 47 | // Normalize the word vector 48 | let wordVectorNorm = wordVector / (norm(wordVector, alongAxes: 0) + Float.ulpOfOne) 49 | 50 | // Normalize all embedding vectors 51 | let norms = norm(embeddings, alongAxes: 1) + Float.ulpOfOne 52 | let normalizedEmbeddings = embeddings / norms.expandingShape(at: 1) 53 | 54 | // Compute similarity 55 | let similarities = normalizedEmbeddings.matmul(wordVectorNorm.transposed()) 56 | 57 | // +1 to account for the input word. NOTE: using `topK` function results in a hard crash 58 | let indices = similarities.argsort(descendingOrder: true)[ 59 | .. Word2Vec.ModelBundle { 14 | let modelFolder = try await downloadModelFromHub( 15 | from: hubRepoId, 16 | downloadBase: downloadBase, 17 | useBackgroundSession: useBackgroundSession, 18 | globs: [loadConfig.modelConfig.weightsFileName] 19 | ) 20 | let modelFile = modelFolder.appendingPathComponent(loadConfig.modelConfig.weightsFileName) 21 | return try await loadModelBundle(from: modelFile) 22 | } 23 | 24 | public static func loadModelBundle( 25 | from modelFile: URL 26 | ) async throws -> Word2Vec.ModelBundle { 27 | let data = try Data(contentsOf: modelFile, options: .mappedIfSafe) 28 | let lines = String(decoding: data, as: UTF8.self).components(separatedBy: .newlines) 29 | var lineCount: Int? 30 | var vectorSize: Int? 
31 | for line in lines.prefix(1) { 32 | let parts = line.components(separatedBy: .whitespaces) 33 | if parts.count == 2 { 34 | lineCount = Int(parts[0]) 35 | vectorSize = Int(parts[1]) 36 | } 37 | } 38 | guard let lineCount, let vectorSize else { 39 | throw EmbeddingsError.invalidFile 40 | } 41 | var keyToIndex = [String: Int]() 42 | keyToIndex.reserveCapacity(lineCount) 43 | var indexToKey = [Int: String]() 44 | indexToKey.reserveCapacity(lineCount) 45 | var vectors = [Float]() 46 | vectors.reserveCapacity(lineCount * vectorSize) 47 | var currentIndex = 0 48 | for line in lines.dropFirst() where !line.isEmpty { 49 | let parts = line.components(separatedBy: .whitespaces) 50 | let word = parts[0] 51 | let vector = parts.dropFirst().map { Float($0)! } 52 | if vector.count != vectorSize { 53 | throw EmbeddingsError.invalidFile 54 | } 55 | keyToIndex[word] = currentIndex 56 | indexToKey[currentIndex] = word 57 | vectors.append(contentsOf: vector) 58 | currentIndex += 1 59 | } 60 | let embeddings = MLTensor(shape: [lineCount, vectorSize], scalars: vectors) 61 | return Word2Vec.ModelBundle( 62 | keyToIndex: keyToIndex, 63 | indexToKey: indexToKey, 64 | embeddings: embeddings 65 | ) 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /Sources/Embeddings/XLMRoberta/XLMRobertaModel.swift: -------------------------------------------------------------------------------- 1 | import CoreML 2 | import Foundation 3 | import MLTensorUtils 4 | @preconcurrency import Tokenizers 5 | 6 | public enum XLMRoberta {} 7 | 8 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 9 | extension XLMRoberta { 10 | public struct ModelConfig: Codable { 11 | public let hiddenSize: Int 12 | public let numHiddenLayers: Int 13 | public let intermediateSize: Int 14 | public let numAttentionHeads: Int 15 | public let maxPositionEmbeddings: Int 16 | public let layerNormEps: Float 17 | public let vocabSize: Int 18 | public let 
addPoolingLayer: Bool? 19 | public let attentionProbsDropoutProb: Float 20 | public let hiddenDropoutProb: Float 21 | public let typeVocabSize: Int 22 | public let outputPast: Bool 23 | public let padTokenId: Int 24 | public let positionEmbeddingType: String 25 | public let poolingConfig: [String: String]? 26 | 27 | public init( 28 | hiddenSize: Int, 29 | numHiddenLayers: Int, 30 | intermediateSize: Int, 31 | numAttentionHeads: Int, 32 | maxPositionEmbeddings: Int, 33 | layerNormEps: Float, 34 | vocabSize: Int, 35 | addPoolingLayer: Bool?, 36 | attentionProbsDropoutProb: Float, 37 | hiddenDropoutProb: Float, 38 | typeVocabSize: Int, 39 | outputPast: Bool, 40 | padTokenId: Int, 41 | positionEmbeddingType: String, 42 | poolingConfig: [String: String]? 43 | ) { 44 | self.hiddenSize = hiddenSize 45 | self.numHiddenLayers = numHiddenLayers 46 | self.intermediateSize = intermediateSize 47 | self.numAttentionHeads = numAttentionHeads 48 | self.maxPositionEmbeddings = maxPositionEmbeddings 49 | self.layerNormEps = layerNormEps 50 | self.vocabSize = vocabSize 51 | self.addPoolingLayer = addPoolingLayer 52 | self.attentionProbsDropoutProb = attentionProbsDropoutProb 53 | self.hiddenDropoutProb = hiddenDropoutProb 54 | self.typeVocabSize = typeVocabSize 55 | self.outputPast = outputPast 56 | self.padTokenId = padTokenId 57 | self.positionEmbeddingType = positionEmbeddingType 58 | self.poolingConfig = poolingConfig 59 | } 60 | } 61 | } 62 | 63 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 64 | extension XLMRoberta { 65 | public struct Pooler: Sendable { 66 | let dense: MLTensorUtils.Layer 67 | 68 | public init(dense: @escaping MLTensorUtils.Layer) { 69 | self.dense = dense 70 | } 71 | 72 | public func callAsFunction(hiddenStates: MLTensor) -> MLTensor { 73 | let firstTokenTensor = hiddenStates[0..., 0] 74 | let pooledOutput = dense(firstTokenTensor) 75 | return pooledOutput.tanh() 76 | } 77 | } 78 | } 79 | 80 | @available(macOS 15.0, iOS 18.0, 
tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 81 | extension XLMRoberta { 82 | public struct Embeddings: Sendable { 83 | let wordEmbeddings: MLTensorUtils.Layer 84 | let positionEmbeddings: MLTensorUtils.Layer 85 | let tokenTypeEmbeddings: MLTensorUtils.Layer 86 | let layerNorm: MLTensorUtils.Layer 87 | let paddingIndex: Int32 88 | 89 | public init( 90 | wordEmbeddings: @escaping MLTensorUtils.Layer, 91 | positionEmbeddings: @escaping MLTensorUtils.Layer, 92 | tokenTypeEmbeddings: @escaping MLTensorUtils.Layer, 93 | layerNorm: @escaping MLTensorUtils.Layer, 94 | paddingIndex: Int32 95 | ) { 96 | self.wordEmbeddings = wordEmbeddings 97 | self.positionEmbeddings = positionEmbeddings 98 | self.tokenTypeEmbeddings = tokenTypeEmbeddings 99 | self.layerNorm = layerNorm 100 | self.paddingIndex = paddingIndex 101 | } 102 | 103 | private func createPositionIds( 104 | from inputIds: MLTensor, 105 | pastKeyValuesLength: Int32 106 | ) -> MLTensor { 107 | let mask = (inputIds .!= paddingIndex).cast(to: Int32.self) 108 | let incrementalIndices = (mask.cumulativeSum(alongAxis: 1) + pastKeyValuesLength) * mask 109 | return incrementalIndices + paddingIndex 110 | } 111 | 112 | public func callAsFunction( 113 | inputIds: MLTensor, 114 | tokenTypeIds: MLTensor? = nil, 115 | positionIds: MLTensor? = nil, 116 | inputsEmbeds: MLTensor? = nil, 117 | pastKeyValuesLength: Int32 = 0 118 | ) -> MLTensor { 119 | let positionIds = 120 | positionIds 121 | ?? createPositionIds(from: inputIds, pastKeyValuesLength: pastKeyValuesLength) 122 | let tokenTypeIds = 123 | tokenTypeIds 124 | ?? MLTensor( 125 | zeros: inputIds.shape, 126 | scalarType: Int32.self 127 | ) 128 | let inputEmbeddings = inputsEmbeds ?? 
                wordEmbeddings(inputIds)
            let positionEmbeddings = positionEmbeddings(positionIds)
            let tokenTypeEmbeddings = tokenTypeEmbeddings(tokenTypeIds)
            // BERT-style embedding: sum of the three lookups, then LayerNorm
            // (no dropout at inference time).
            let embeddings = inputEmbeddings + tokenTypeEmbeddings + positionEmbeddings
            return layerNorm(embeddings)
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension XLMRoberta {
    /// First half of the feed-forward block: dense projection followed by GELU.
    public struct Intermediate: Sendable {
        let dense: MLTensorUtils.Layer

        public init(dense: @escaping MLTensorUtils.Layer) {
            self.dense = dense
        }

        public func callAsFunction(hiddenStates: MLTensor) -> MLTensor {
            let dense = dense(hiddenStates)
            return gelu(dense)
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension XLMRoberta {
    /// Projects self-attention output, then applies a residual connection
    /// followed by LayerNorm.
    public struct SelfOutput: Sendable {
        let dense: MLTensorUtils.Layer
        let layerNorm: MLTensorUtils.Layer

        public init(
            dense: @escaping MLTensorUtils.Layer,
            layerNorm: @escaping MLTensorUtils.Layer
        ) {
            self.dense = dense
            self.layerNorm = layerNorm
        }

        public func callAsFunction(
            hiddenStates: MLTensor,
            inputTensor: MLTensor
        ) -> MLTensor {
            let dense = dense(hiddenStates)
            // Residual connection before normalization.
            let layerNormInput = dense + inputTensor
            return layerNorm(layerNormInput)
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension XLMRoberta {
    /// Multi-head scaled dot-product self-attention.
    public struct SelfAttention: Sendable {
        let query: MLTensorUtils.Layer
        let key: MLTensorUtils.Layer
        let value: MLTensorUtils.Layer
        let numAttentionHeads: Int
        let attentionHeadSize: Int
        let allHeadSize: Int
        // NOTE(review): `scale` is stored but never read — the score scaling
        // below uses sqrt(attentionHeadSize) directly. Confirm whether this
        // property can be removed or should replace the inline computation.
        let scale: Float

        public init(
            query: @escaping MLTensorUtils.Layer,
            key: @escaping MLTensorUtils.Layer,
            value: @escaping MLTensorUtils.Layer,
            numAttentionHeads: Int,
            attentionHeadSize: Int,
            allHeadSize: Int,
            scale: Float
        ) {
            self.query = query
            self.key = key
            self.value = value
            self.numAttentionHeads = numAttentionHeads
            self.attentionHeadSize = attentionHeadSize
            self.allHeadSize = allHeadSize
            self.scale = scale
        }

        // Splits the last axis into heads and moves the head axis forward
        // so that attention is computed per head.
        private func transposeForScores(_ x: MLTensor) -> MLTensor {
            let newShape = x.shape.dropLast() + [numAttentionHeads, attentionHeadSize]
            return x.reshaped(to: Array(newShape)).transposed(permutation: 0, 2, 1, 3)
        }

        public func callAsFunction(
            hiddenStates: MLTensor,
            attentionMask: MLTensor?,
            headMask: MLTensor?
        ) -> MLTensor {
            let queries = query(hiddenStates)
            let keys = key(hiddenStates)
            let values = value(hiddenStates)

            let queryLayer = transposeForScores(queries)
            let keyLayer = transposeForScores(keys)
            let valueLayer = transposeForScores(values)

            // Scaled dot-product scores; `attentionMask` is additive
            // (large negative values suppress masked positions).
            var attentionScores = queryLayer.matmul(keyLayer.transposed(permutation: 0, 1, 3, 2))
            attentionScores = attentionScores / sqrt(Float(attentionHeadSize))
            if let attentionMask {
                attentionScores = attentionScores + attentionMask
            }
            var attentionProbs = attentionScores.softmax(alongAxis: -1)
            if let headMask {
                attentionProbs = attentionProbs * headMask
            }
            var contextLayer = attentionProbs.matmul(valueLayer)
            // Merge the heads back into a single hidden axis.
            contextLayer = contextLayer.transposed(permutation: 0, 2, 1, 3)
            let newContextLayerShape = contextLayer.shape.dropLast(2) + [allHeadSize]
            return contextLayer.reshaped(to: Array(newContextLayerShape))
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension XLMRoberta {
    /// Self-attention followed by its output projection + residual LayerNorm.
    public struct Attention: Sendable {
        let selfAttention: XLMRoberta.SelfAttention
        let output: XLMRoberta.SelfOutput

        public init(
            selfAttention: XLMRoberta.SelfAttention,
            output: XLMRoberta.SelfOutput
        ) {
            self.selfAttention = selfAttention
            self.output = output
        }

        public func callAsFunction(
            hiddenStates: MLTensor,
            attentionMask: MLTensor?,
            headMask: MLTensor?
        ) -> MLTensor {
            let selfOutputs = selfAttention(
                hiddenStates: hiddenStates,
                attentionMask: attentionMask,
                headMask: headMask
            )
            return output(
                hiddenStates: selfOutputs,
                inputTensor: hiddenStates
            )
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension XLMRoberta {
    /// Second half of the feed-forward block: dense projection with residual
    /// connection + LayerNorm.
    public struct Output: Sendable {
        let dense: MLTensorUtils.Layer
        let layerNorm: MLTensorUtils.Layer

        public init(
            dense: @escaping MLTensorUtils.Layer,
            layerNorm: @escaping MLTensorUtils.Layer
        ) {
            self.dense = dense
            self.layerNorm = layerNorm
        }

        public func callAsFunction(
            hiddenStates: MLTensor,
            inputTensor: MLTensor
        ) -> MLTensor {
            let hiddenStates = dense(hiddenStates)
            return layerNorm(hiddenStates + inputTensor)
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension XLMRoberta {
    /// One transformer encoder layer: attention block + feed-forward block.
    public struct Layer: Sendable {
        let attention: XLMRoberta.Attention
        let intermediate: XLMRoberta.Intermediate
        let output: XLMRoberta.Output

        public init(
            attention: XLMRoberta.Attention,
            intermediate: XLMRoberta.Intermediate,
            output: XLMRoberta.Output
        ) {
            self.attention = attention
            self.intermediate = intermediate
            self.output = output
        }

        public func callAsFunction(
            hiddenStates: MLTensor,
            attentionMask: MLTensor?,
            headMask: MLTensor?
        ) -> MLTensor {
            let attentionOutput = attention(
                hiddenStates: hiddenStates,
                attentionMask: attentionMask,
                headMask: headMask
            )
            let intermediateOutput = intermediate(
                hiddenStates: attentionOutput
            )
            // The feed-forward output block takes the attention output as its
            // residual input.
            return output(
                hiddenStates: intermediateOutput,
                inputTensor: attentionOutput
            )
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension XLMRoberta {
    /// Stack of encoder layers applied sequentially.
    public struct Encoder: Sendable {
        let layers: [XLMRoberta.Layer]

        public init(layers: [XLMRoberta.Layer]) {
            self.layers = layers
        }

        public func callAsFunction(
            hiddenStates: MLTensor,
            attentionMask: MLTensor?,
            headMask: MLTensor?
        ) -> MLTensor {
            var hiddenStates = hiddenStates
            for (index, layer) in layers.enumerated() {
                // When a head mask is supplied, each layer gets its own slice.
                let layerHeadMask: MLTensor? =
                    if let headMask {
                        headMask[index]
                    } else {
                        nil
                    }
                hiddenStates = layer(
                    hiddenStates: hiddenStates,
                    attentionMask: attentionMask,
                    headMask: layerHeadMask
                )
            }
            return hiddenStates
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension XLMRoberta {
    /// Complete XLM-RoBERTa encoder: embeddings + encoder stack + optional pooler.
    public struct Model: Sendable {
        let embeddings: XLMRoberta.Embeddings
        let encoder: XLMRoberta.Encoder
        let pooler: XLMRoberta.Pooler?
        let numHiddenLayers: Int

        public init(
            embeddings: XLMRoberta.Embeddings,
            encoder: XLMRoberta.Encoder,
            pooler: XLMRoberta.Pooler?,
            numHiddenLayers: Int
        ) {
            self.embeddings = embeddings
            self.encoder = encoder
            self.pooler = pooler
            self.numHiddenLayers = numHiddenLayers
        }

        // Converts a rank-2 or rank-3 0/1 attention mask into an additive mask
        // broadcastable over attention scores: 1 -> 0, 0 -> -10000.
        private func extendedAttentionMask(_ attentionMask: MLTensor) -> MLTensor {
            let attentionMask: MLTensor =
                if attentionMask.rank == 3 {
                    attentionMask.expandingShape(at: 1)
                } else if attentionMask.rank == 2 {
                    attentionMask.expandingShape(at: 1, 1)
                } else {
                    fatalError("Wrong shape for attentionMask (shape \(attentionMask.shape))")
                }
            return (1.0 - attentionMask) * -10000.0
        }

        public func callAsFunction(
            inputIds: MLTensor,
            tokenTypeIds: MLTensor? = nil,
            attentionMask: MLTensor? = nil,
            positionIds: MLTensor? = nil
        ) -> (sequenceOutput: MLTensor, pooledOutput: MLTensor?) {
            // Defaults: attend to every position; all token types are 0.
            let attentionMask =
                attentionMask
                ?? MLTensor(
                    ones: inputIds.shape,
                    scalarType: Float32.self
                )
            let tokenTypeIds =
                tokenTypeIds
                ?? MLTensor(
                    zeros: inputIds.shape,
                    scalarType: Int32.self
                )
            // All-ones head mask: no attention head is disabled.
            let headMask = MLTensor(
                repeating: 1 as Int32, shape: [numHiddenLayers])
            let embeddingOutput = embeddings(
                inputIds: inputIds,
                tokenTypeIds: tokenTypeIds,
                positionIds: positionIds
            )
            let encoderOutput = encoder(
                hiddenStates: embeddingOutput,
                attentionMask: extendedAttentionMask(attentionMask),
                headMask: headMask
            )
            let pooledOutput: MLTensor? =
                if let pooler {
                    pooler(hiddenStates: encoderOutput)
                } else {
                    nil
                }
            return (encoderOutput, pooledOutput)
        }
    }
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
extension XLMRoberta {
    /// Model + tokenizer pair, providing text-to-embedding convenience APIs.
    public struct ModelBundle: Sendable {
        public let model: XLMRoberta.Model
        public let tokenizer: any TextTokenizer

        public init(
            model: XLMRoberta.Model,
            tokenizer: any TextTokenizer
        ) {
            self.model = model
            self.tokenizer = tokenizer
        }

        /// Encodes `text` and returns the first-token (CLS) hidden state.
        public func encode(_ text: String, maxLength: Int = 512) throws -> MLTensor {
            let tokens = try tokenizer.tokenizeText(text, maxLength: maxLength)
            let inputIds = MLTensor(shape: [1, tokens.count], scalars: tokens)
            let result = model(inputIds: inputIds)
            return result.sequenceOutput[0..., 0, 0...]
        }

        /// Encodes several texts padded to the longest, returning one CLS
        /// vector per row.
        /// NOTE(review): `encodedTexts[0]` traps on an empty `texts` array,
        /// and no attention mask is passed to the model, so padded positions
        /// are attended to — confirm both are acceptable for callers.
        public func batchEncode(
            _ texts: [String],
            padTokenId: Int = 0,
            maxLength: Int = 512
        ) throws -> MLTensor {
            let encodedTexts = try tokenizer.tokenizeTextsPaddingToLongest(
                texts, padTokenId: padTokenId, maxLength: maxLength)
            let inputIds = MLTensor(
                shape: [encodedTexts.count, encodedTexts[0].count],
                scalars: encodedTexts.flatMap { $0 })
            return model(inputIds: inputIds).sequenceOutput[0..., 0, 0...]
474 | } 475 | } 476 | } 477 | -------------------------------------------------------------------------------- /Sources/Embeddings/XLMRoberta/XLMRobertaUtils.swift: -------------------------------------------------------------------------------- 1 | import CoreML 2 | import Foundation 3 | import Hub 4 | import MLTensorUtils 5 | import Safetensors 6 | @preconcurrency import Tokenizers 7 | 8 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 9 | extension XLMRoberta { 10 | public static func loadConfig(at url: URL) throws -> XLMRoberta.ModelConfig { 11 | try loadConfigFromFile(at: url) 12 | } 13 | } 14 | 15 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 16 | extension XLMRoberta { 17 | public static func loadModelBundle( 18 | from hubRepoId: String, 19 | downloadBase: URL? = nil, 20 | useBackgroundSession: Bool = false, 21 | loadConfig: LoadConfig = LoadConfig() 22 | ) async throws -> XLMRoberta.ModelBundle { 23 | let modelFolder = try await downloadModelFromHub( 24 | from: hubRepoId, 25 | downloadBase: downloadBase, 26 | useBackgroundSession: useBackgroundSession 27 | ) 28 | return try await loadModelBundle( 29 | from: modelFolder, 30 | loadConfig: loadConfig 31 | ) 32 | } 33 | 34 | public static func loadModelBundle( 35 | from modelFolder: URL, 36 | loadConfig: LoadConfig = LoadConfig() 37 | ) async throws -> XLMRoberta.ModelBundle { 38 | let addedTokens = try await loadAddedTokens(from: modelFolder) 39 | let tokenizerModelUrl = try findSentencePieceModel(in: modelFolder) 40 | let tokenizer = try XLMRobetaTokenizer( 41 | tokenizerModelUrl: tokenizerModelUrl, 42 | addedTokens: addedTokens 43 | ) 44 | let weightsUrl = modelFolder.appendingPathComponent(loadConfig.modelConfig.weightsFileName) 45 | let configUrl = modelFolder.appendingPathComponent(loadConfig.modelConfig.configFileName) 46 | let config = try XLMRoberta.loadConfig(at: configUrl) 47 | let model = try XLMRoberta.loadModel( 48 | weightsUrl: 
weightsUrl, 49 | config: config, 50 | loadConfig: loadConfig 51 | ) 52 | return XLMRoberta.ModelBundle(model: model, tokenizer: tokenizer) 53 | } 54 | 55 | private static func loadAddedTokens(from modelFolder: URL) async throws -> [String: Int] { 56 | let hubConfiguration = LanguageModelConfigurationFromHub(modelFolder: modelFolder) 57 | let addedTokens = try await hubConfiguration.tokenizerData.addedTokens?.array()?.compactMap { 58 | $0.dictionary() 59 | } 60 | guard let addedTokens else { 61 | return [:] 62 | } 63 | var result = [String: Int]() 64 | for addedToken in addedTokens { 65 | if let content = addedToken["content"]?.string(), let id = addedToken["id"]?.integer() { 66 | result[content] = id 67 | } 68 | } 69 | return result 70 | } 71 | 72 | private static func findSentencePieceModel(in folder: URL) throws -> URL { 73 | let fileManager = FileManager.default 74 | let contents = try fileManager.contentsOfDirectory( 75 | at: folder, includingPropertiesForKeys: nil) 76 | for url in contents { 77 | if url.pathExtension == "model", url.lastPathComponent.contains("sentencepiece") { 78 | return url 79 | } 80 | } 81 | throw EmbeddingsError.fileNotFound 82 | } 83 | } 84 | 85 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 86 | extension XLMRoberta { 87 | public static func loadModel( 88 | weightsUrl: URL, 89 | config: XLMRoberta.ModelConfig, 90 | loadConfig: LoadConfig = LoadConfig() 91 | ) throws -> XLMRoberta.Model { 92 | let safetensors = try Safetensors.read(at: weightsUrl) 93 | let wordEmbeddings = try MLTensorUtils.embedding( 94 | weight: safetensors.mlTensor( 95 | forKey: loadConfig.modelConfig.weightKeyTransform( 96 | "embeddings.word_embeddings.weight"))) 97 | 98 | let tokenTypeEmbeddings = try MLTensorUtils.embedding( 99 | weight: safetensors.mlTensor( 100 | forKey: loadConfig.modelConfig.weightKeyTransform( 101 | "embeddings.token_type_embeddings.weight"))) 102 | 103 | let positionEmbeddings = try MLTensorUtils.embedding( 104 | 
weight: safetensors.mlTensor( 105 | forKey: loadConfig.modelConfig.weightKeyTransform( 106 | "embeddings.position_embeddings.weight"))) 107 | 108 | let layerNorm = try MLTensorUtils.layerNorm( 109 | weight: safetensors.mlTensor( 110 | forKey: loadConfig.modelConfig.weightKeyTransform( 111 | "embeddings.LayerNorm.weight")), 112 | bias: safetensors.mlTensor( 113 | forKey: loadConfig.modelConfig.weightKeyTransform( 114 | "embeddings.LayerNorm.bias")), 115 | epsilon: config.layerNormEps) 116 | 117 | let embeddings = XLMRoberta.Embeddings( 118 | wordEmbeddings: wordEmbeddings, 119 | positionEmbeddings: positionEmbeddings, 120 | tokenTypeEmbeddings: tokenTypeEmbeddings, 121 | layerNorm: layerNorm, 122 | paddingIndex: Int32(config.padTokenId)) 123 | 124 | var layers = [XLMRoberta.Layer]() 125 | for layer in 0.. MLTensor { 12 | switch approximation { 13 | case .none: 14 | return x * (1 + erf(x / sqrt(2 as Float))) / 2 15 | case .fast: 16 | return x * sigmoid(1.702 * x) 17 | case .precise, .tanh: 18 | return 0.5 * x * (1 + (sqrt(2 / Float.pi) * (x + 0.044715 * x.pow(3))).tanh()) 19 | } 20 | } 21 | 22 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 23 | public func sigmoid(_ x: MLTensor) -> MLTensor { 24 | 1 / (1 + (-x).exp()) 25 | } 26 | 27 | // Ref: https://en.wikipedia.org/wiki/Error_function#Numerical_approximations 28 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 29 | public func erf(_ x: MLTensor) -> MLTensor { 30 | let a1: Float = 0.254829592 31 | let a2: Float = -0.284496736 32 | let a3: Float = 1.421413741 33 | let a4: Float = -1.453152027 34 | let a5: Float = 1.061405429 35 | let p: Float = 0.3275911 36 | 37 | let sign = x.sign() 38 | let x = x.abs() 39 | 40 | let t = 1 / (1 + p * x) 41 | let y = 1 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp() 42 | 43 | return sign * y 44 | } 45 | -------------------------------------------------------------------------------- 
/Sources/MLTensorUtils/Functions.swift:
--------------------------------------------------------------------------------
import CoreML
import Foundation

/// L2 norm of `x` summed along `alongAxes` (defaults to axis 1, i.e. per row).
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public func norm(_ x: MLTensor, alongAxes: Int = 1, keepRank: Bool = false) -> MLTensor {
    x.squared().sum(alongAxes: alongAxes, keepRank: keepRank).squareRoot()
}

/// Cosine similarity between the rows of `x` and `y`.
/// NOTE(review): `normX * normY` is an element-wise product of two rank-1
/// tensors; for x: [n, d] and y: [m, d] with n != m this does not broadcast
/// to the [n, m] matmul result — confirm callers only pass equal-row-count
/// (typically single-row) inputs.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public func cosineSimilarity(_ x: MLTensor, _ y: MLTensor, alongAxes: Int = 1) -> MLTensor {
    let normX = norm(x, alongAxes: alongAxes)
    let normY = norm(y, alongAxes: alongAxes)
    return x.matmul(y.transposed()) / (normX * normY)
}

/// Dot product computed as `x` transposed matmul `y`.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public func dotProduct(_ x: MLTensor, _ y: MLTensor) -> MLTensor {
    x.transposed().matmul(y)
}

/// Cosine distance: `1 - cosineSimilarity(x, y)`.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public func cosineDistance(_ x: MLTensor, _ y: MLTensor, alongAxes: Int = 1) -> MLTensor {
    1 - cosineSimilarity(x, y, alongAxes: alongAxes)
}

/// Euclidean (L2) distance between `x` and `y` along `alongAxes`.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public func euclideanDistance(_ x: MLTensor, _ y: MLTensor, alongAxes: Int = 1) -> MLTensor {
    (x - y).squared().sum(alongAxes: alongAxes).squareRoot()
}

/// Additive causal attention mask of size n x n.
/// NOTE(review): the generic parameter clause (e.g. `<Scalar: ...>`) appears
/// to have been lost in this excerpt — restore it from the original file.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
public func additiveCausalMask(
    _ n: Int32,
    scalarType: Scalar.Type = Float.self
) -> MLTensor {
    let indices = MLTensor(0..
MLTensor 5 | 6 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 7 | public func embedding(weight: MLTensor) -> Layer { 8 | { x in 9 | weight.gathering(atIndices: x, alongAxis: 0) 10 | } 11 | } 12 | 13 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 14 | public func linear(weight: MLTensor, bias: MLTensor? = nil) -> Layer { 15 | { x in 16 | if let bias { 17 | x.matmul(weight.transposed()) + bias 18 | } else { 19 | x.matmul(weight.transposed()) 20 | } 21 | } 22 | } 23 | 24 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 25 | public func layerNorm(weight: MLTensor, bias: MLTensor, epsilon: Float) -> Layer { 26 | { x in 27 | let mean = x.mean(alongAxes: -1, keepRank: true) 28 | let xshift = x - mean 29 | let variance = xshift.squared().mean(alongAxes: -1, keepRank: true) 30 | let invstd = (variance + epsilon).rsqrt() 31 | let norm = xshift * invstd 32 | return norm * weight + bias 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /Sources/TestingUtils/TestingUtils.swift: -------------------------------------------------------------------------------- 1 | import CoreML 2 | import Numerics 3 | import Testing 4 | 5 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 6 | extension MLTensor { 7 | package func scalars( 8 | of scalarType: Scalar.Type 9 | ) async -> [Scalar] where Scalar: MLShapedArrayScalar, Scalar: MLTensorScalar { 10 | await shapedArray(of: scalarType).scalars 11 | } 12 | } 13 | 14 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 15 | extension MLTensor { 16 | package static func float(shape: [Int]) -> MLTensor { 17 | let count = shape.reduce(1, *) 18 | return MLTensor(shape: shape, scalars: (0.. 
MLTensor { 22 | let count = shape.reduce(1, *) 23 | return MLTensor(shape: shape, scalars: (0..( 28 | _ lhs: [T], 29 | _ rhs: [T], 30 | absoluteTolerance: T.Magnitude = T.Magnitude.ulpOfOne.squareRoot() 31 | * T.Magnitude.leastNormalMagnitude, 32 | relativeTolerance: T.Magnitude = T.Magnitude.ulpOfOne.squareRoot() 33 | ) -> Bool where T.Magnitude: FloatingPoint { 34 | guard lhs.count == rhs.count else { 35 | Issue.record("Expected \(lhs) to be approximately equal to \(rhs), but sizes differ") 36 | return false 37 | } 38 | for (l, r) in zip(lhs, rhs) { 39 | guard 40 | l.isApproximatelyEqual( 41 | to: r, 42 | absoluteTolerance: absoluteTolerance, 43 | relativeTolerance: relativeTolerance 44 | ) 45 | else { 46 | Issue.record("Expected \(lhs) to be approximately equal to \(rhs), but \(l) != \(r)") 47 | return false 48 | } 49 | } 50 | return true 51 | } 52 | -------------------------------------------------------------------------------- /Tests/AccuracyTests/AccuracyTests.swift: -------------------------------------------------------------------------------- 1 | import Command 2 | import CoreML 3 | import Testing 4 | import TestingUtils 5 | 6 | @testable import Embeddings 7 | 8 | /* 9 | NOTE: 10 | The following test are testing the accuracy of the embeddings generated by the Swift models 11 | against the embeddings generated by the Python transformers library. They they are slow 12 | and require the [uv](https://github.com/astral-sh/uv) command line tool to be available. 
13 | 14 | This suite can be run using the following command from the command line: 15 | 16 | ``` 17 | PYTORCH_ENABLE_MPS_FALLBACK=1 UV_PATH=$(which uv) swift test --filter AccuracyTests 18 | ``` 19 | 20 | */ 21 | 22 | func generateUsingTransformers( 23 | modelPath: String, 24 | text: String, 25 | modelType: ModelType 26 | ) async throws -> [Float] { 27 | let scriptUrl = try #require( 28 | Bundle.module.path(forResource: "generate", ofType: "py", inDirectory: "Scripts"), 29 | "Script not found" 30 | ) 31 | let uvPath = try #require(ProcessInfo.processInfo.environment["UV_PATH"], "UV_PATH not found") 32 | let arguments = [uvPath, "run", "--quiet", scriptUrl, modelPath, text, modelType.rawValue] 33 | let result = 34 | try await Command 35 | .run(arguments: arguments) 36 | .concatenatedString() 37 | return 38 | result 39 | .components(separatedBy: .newlines) 40 | .filter { !$0.isEmpty } 41 | .map { stringValue -> Float in 42 | guard let value = Float(stringValue) else { 43 | fatalError("Invalid float value in stdout: \(stringValue)") 44 | } 45 | return value 46 | } 47 | } 48 | 49 | func modelPath(modelId: String, cacheDirectory: URL) -> String { 50 | cacheDirectory 51 | .appendingPathComponent("models") 52 | .appendingPathComponent(modelId) 53 | .path() 54 | } 55 | 56 | enum ModelType: String { 57 | case bert 58 | case clip 59 | case model2Vec = "model2vec" 60 | case roberta 61 | case staticEmbeddings = "static-embeddings" 62 | case xlmRoberta = "xlm-roberta" 63 | } 64 | 65 | @Suite(.enabled(if: ProcessInfo.processInfo.environment["UV_PATH"] != nil)) 66 | struct AccuracyTests { 67 | let cacheDirectory: URL = 68 | if let modelPath = ProcessInfo.processInfo.environment["MODEL_PATH"] { 69 | URL(fileURLWithPath: modelPath) 70 | } else { 71 | FileManager.default.temporaryDirectory 72 | } 73 | 74 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 75 | @Test("Bert Accuracy", arguments: ["Text to encode", "", "❤️"]) 76 | func bertAccuracy(_ text: 
String) async throws { 77 | let modelId = "google-bert/bert-base-uncased" 78 | let modelBundle = try await Bert.loadModelBundle( 79 | from: modelId, 80 | downloadBase: cacheDirectory, 81 | loadConfig: .googleBert 82 | ) 83 | let encoded = try modelBundle.encode(text) 84 | let swiftData = await encoded.cast(to: Float.self).scalars(of: Float.self) 85 | let modelPath = modelPath(modelId: modelId, cacheDirectory: cacheDirectory) 86 | let pythonData = try await generateUsingTransformers( 87 | modelPath: modelPath, 88 | text: text, 89 | modelType: .bert 90 | ) 91 | 92 | #expect(allClose(pythonData, swiftData, absoluteTolerance: 1e-5) == true) 93 | } 94 | 95 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 96 | @Test("Clip Accuracy") 97 | func clipAccuracy() async throws { 98 | let text = "a photo of a dog" 99 | let modelId = "jkrukowski/clip-vit-base-patch32" 100 | let modelBundle = try await Clip.loadModelBundle( 101 | from: modelId, 102 | downloadBase: cacheDirectory 103 | ) 104 | let tokens = try modelBundle.tokenizer.tokenizeText(text, maxLength: 77) 105 | let inputIds = MLTensor(shape: [1, tokens.count], scalars: tokens) 106 | let modelOutput = modelBundle.textModel(inputIds: inputIds) 107 | let swiftData = 108 | await modelOutput 109 | .poolerOutput 110 | .cast(to: Float.self) 111 | .scalars(of: Float.self) 112 | let modelPath = modelPath(modelId: modelId, cacheDirectory: cacheDirectory) 113 | let pythonData = try await generateUsingTransformers( 114 | modelPath: modelPath, 115 | text: text, 116 | modelType: .clip 117 | ) 118 | 119 | #expect(allClose(pythonData, swiftData, absoluteTolerance: 1e-5) == true) 120 | } 121 | 122 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 123 | @Test("Roberta Accuracy", arguments: ["Text to encode", "", "❤️"]) 124 | func robertaAccuracy(_ text: String) async throws { 125 | let modelId = "FacebookAI/roberta-base" 126 | let modelBundle = try await Roberta.loadModelBundle( 127 | 
from: modelId, 128 | downloadBase: cacheDirectory, 129 | loadConfig: .addWeightKeyPrefix("roberta.") 130 | ) 131 | let encoded = try modelBundle.encode(text) 132 | let swiftData = await encoded.cast(to: Float.self).scalars(of: Float.self) 133 | let modelPath = modelPath(modelId: modelId, cacheDirectory: cacheDirectory) 134 | let pythonData = try await generateUsingTransformers( 135 | modelPath: modelPath, 136 | text: text, 137 | modelType: .roberta 138 | ) 139 | 140 | #expect(allClose(pythonData, swiftData, absoluteTolerance: 1e-5) == true) 141 | } 142 | 143 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 144 | @Test("XLM Roberta Accuracy", arguments: ["Text to encode", "", "❤️"]) 145 | func xlmRobertaAccuracy(_ text: String) async throws { 146 | let modelId = "FacebookAI/xlm-roberta-base" 147 | let modelBundle = try await XLMRoberta.loadModelBundle( 148 | from: modelId, 149 | downloadBase: cacheDirectory, 150 | loadConfig: .addWeightKeyPrefix("roberta.") 151 | ) 152 | let encoded = try modelBundle.encode(text) 153 | let swiftData = await encoded.cast(to: Float.self).scalars(of: Float.self) 154 | let modelPath = modelPath(modelId: modelId, cacheDirectory: cacheDirectory) 155 | let pythonData = try await generateUsingTransformers( 156 | modelPath: modelPath, 157 | text: text, 158 | modelType: .xlmRoberta 159 | ) 160 | 161 | #expect(allClose(pythonData, swiftData, absoluteTolerance: 1e-5) == true) 162 | } 163 | 164 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 165 | @Test("Model2Vec Accuracy", arguments: ["Text to encode", "", "❤️"]) 166 | func model2VecAccuracy(_ text: String) async throws { 167 | let modelId = "minishlab/potion-base-2M" 168 | let modelBundle = try await Model2Vec.loadModelBundle( 169 | from: modelId, 170 | downloadBase: cacheDirectory 171 | ) 172 | let encoded = try modelBundle.encode(text, normalize: modelBundle.model.normalize) 173 | let swiftData = await encoded.cast(to: 
Float.self).scalars(of: Float.self) 174 | let modelPath = modelPath(modelId: modelId, cacheDirectory: cacheDirectory) 175 | let pythonData = try await generateUsingTransformers( 176 | modelPath: modelPath, 177 | text: text, 178 | modelType: .model2Vec 179 | ) 180 | 181 | #expect(allClose(pythonData, swiftData, absoluteTolerance: 1e-5) == true) 182 | } 183 | 184 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 185 | @Test("Static Embeddings Accuracy", arguments: ["Text to encode", "", "❤️"]) 186 | func staticEmbeddingsAccuracy(_ text: String) async throws { 187 | let modelId = "sentence-transformers/static-retrieval-mrl-en-v1" 188 | let modelBundle = try await StaticEmbeddings.loadModelBundle( 189 | from: modelId, 190 | downloadBase: cacheDirectory, 191 | loadConfig: LoadConfig.staticEmbeddings 192 | ) 193 | let encoded = try modelBundle.encode(text, normalize: true, truncateDimension: 1023) 194 | let swiftData = await encoded.cast(to: Float.self).scalars(of: Float.self) 195 | let modelPath = modelPath(modelId: modelId, cacheDirectory: cacheDirectory) 196 | let pythonData = try await generateUsingTransformers( 197 | modelPath: modelPath, 198 | text: text, 199 | modelType: .staticEmbeddings 200 | ) 201 | 202 | #expect(allClose(pythonData, swiftData, absoluteTolerance: 1e-5) == true) 203 | } 204 | } 205 | -------------------------------------------------------------------------------- /Tests/AccuracyTests/Scripts/generate.py: -------------------------------------------------------------------------------- 1 | # /// script 2 | # requires-python = "==3.12" 3 | # dependencies = [ 4 | # "torch", 5 | # "transformers", 6 | # "model2vec>=0.5.0", 7 | # "sentence-transformers", 8 | # ] 9 | # /// 10 | 11 | 12 | import warnings 13 | from transformers import AutoTokenizer, AutoModel 14 | from transformers import CLIPModel 15 | from transformers import logging 16 | from model2vec import StaticModel 17 | from sentence_transformers import 
SentenceTransformer 18 | import argparse 19 | 20 | 21 | def embeddings(model_dir, text): 22 | tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True) 23 | model = AutoModel.from_pretrained(model_dir, local_files_only=True) 24 | encoded_input = tokenizer(text, return_tensors="pt") 25 | output = model(**encoded_input) 26 | return output[0][:, 0, :].flatten().tolist() 27 | 28 | 29 | def clip_embeddings(model_dir, text): 30 | tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True) 31 | model = CLIPModel.from_pretrained(model_dir, local_files_only=True) 32 | encoded_input = tokenizer(text, return_tensors="pt") 33 | output = model.text_model(**encoded_input) 34 | return output.pooler_output.flatten().tolist() 35 | 36 | 37 | def model2vec_embeddings(model_dir, text): 38 | model = StaticModel.from_pretrained(model_dir) 39 | output = model.encode(text) 40 | return output.flatten().tolist() 41 | 42 | 43 | def static_embeddings(model_dir, text): 44 | model = SentenceTransformer(model_dir, truncate_dim=1023) 45 | output = model.encode(text, normalize_embeddings=True) 46 | return output.flatten().tolist() 47 | 48 | 49 | def main(model_dir, text, emb_type="bert"): 50 | if emb_type == "bert" or emb_type == "xlm-roberta" or emb_type == "roberta": 51 | values = embeddings(model_dir, text) 52 | elif emb_type == "clip": 53 | values = clip_embeddings(model_dir, text) 54 | elif emb_type == "model2vec": 55 | values = model2vec_embeddings(model_dir, text) 56 | elif emb_type == "static-embeddings": 57 | values = static_embeddings(model_dir, text) 58 | else: 59 | raise ValueError(f"Unknown emb_type: {emb_type}") 60 | print("\n".join([str(x) for x in values])) 61 | 62 | 63 | # run e.g: `uv run generate.py "./cache/google-bert/bert-base-uncased" "Text to encode"` bert 64 | if __name__ == "__main__": 65 | logging.set_verbosity_error() 66 | warnings.filterwarnings("ignore") 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument("model_dir", 
type=str, help="Model local dir") 69 | parser.add_argument("text", type=str, help="Text to embed") 70 | parser.add_argument("type", type=str, help="Embedding type") 71 | args = parser.parse_args() 72 | main(args.model_dir, args.text, args.type) 73 | -------------------------------------------------------------------------------- /Tests/EmbeddingsTests/BertTests.swift: -------------------------------------------------------------------------------- 1 | import CoreML 2 | import MLTensorUtils 3 | import Testing 4 | import TestingUtils 5 | import XCTest 6 | 7 | @testable import Embeddings 8 | 9 | struct BertTests { 10 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 11 | @Test func pooler() async { 12 | let pooler1 = Bert.Pooler( 13 | dense: MLTensorUtils.linear( 14 | weight: MLTensor.float(shape: [5, 5]), 15 | bias: nil 16 | ) 17 | ) 18 | let result1 = pooler1( 19 | MLTensor.float(shape: [1, 3, 5]) 20 | ) 21 | let data1 = await result1.scalars(of: Float.self) 22 | 23 | #expect(result1.shape == [1, 5]) 24 | #expect(allClose(data1, [1, 1, 1, 1, 1]) == true) 25 | 26 | let pooler2 = Bert.Pooler( 27 | dense: MLTensorUtils.linear( 28 | weight: MLTensor.float(shape: [5, 5]), 29 | bias: MLTensor.float(shape: [5]) 30 | ) 31 | ) 32 | let result2 = pooler2( 33 | MLTensor.float(shape: [1, 3, 5]) 34 | ) 35 | let data2 = await result2.scalars(of: Float.self) 36 | 37 | #expect(result2.shape == [1, 5]) 38 | #expect(allClose(data2, [1, 1, 1, 1, 1]) == true) 39 | } 40 | 41 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 42 | @Test func intermediate() async { 43 | let intermediate1 = Bert.Intermediate( 44 | dense: MLTensorUtils.linear( 45 | weight: MLTensor.float(shape: [2, 3]), 46 | bias: nil 47 | ) 48 | ) 49 | let result1 = intermediate1( 50 | hiddenStates: MLTensor.float(shape: [1, 2, 3]) 51 | ) 52 | let data1 = await result1.scalars(of: Float.self) 53 | 54 | #expect(result1.shape == [1, 2, 2]) 55 | #expect(allClose(data1, 
[5.0, 14.0, 14.0, 50.0]) == true) 56 | 57 | let intermediate2 = Bert.Intermediate( 58 | dense: MLTensorUtils.linear( 59 | weight: MLTensor.float(shape: [2, 3]), 60 | bias: MLTensor.float(shape: [2]) 61 | ) 62 | ) 63 | let result2 = intermediate2( 64 | hiddenStates: MLTensor.float(shape: [1, 2, 3]) 65 | ) 66 | let data2 = await result2.scalars(of: Float.self) 67 | 68 | #expect(result2.shape == [1, 2, 2]) 69 | #expect(allClose(data2, [5.0, 15.0, 14.0, 51.0]) == true) 70 | } 71 | 72 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 73 | @Test func output() async { 74 | let output1 = Bert.Output( 75 | dense: MLTensorUtils.linear( 76 | weight: MLTensor.float(shape: [4, 4]), 77 | bias: nil 78 | ), 79 | layerNorm: MLTensorUtils.layerNorm( 80 | weight: MLTensor.float(shape: [4]), 81 | bias: MLTensor.float(shape: [4]), 82 | epsilon: 1e-5 83 | ) 84 | ) 85 | 86 | let result1 = output1( 87 | hiddenStates: MLTensor.float(shape: [1, 2, 4]), 88 | inputTensor: MLTensor.float(shape: [1, 2, 4]) 89 | ) 90 | let data1 = await result1.scalars(of: Float.self) 91 | 92 | #expect(result1.shape == [1, 2, 4]) 93 | #expect( 94 | allClose( 95 | data1, 96 | [0.0, 0.5527864, 2.8944273, 7.0249224, 0.0, 0.5527864, 2.8944273, 7.0249224]) 97 | == true) 98 | 99 | let output2 = Bert.Output( 100 | dense: MLTensorUtils.linear( 101 | weight: MLTensor.float(shape: [4, 4]), 102 | bias: MLTensor.float(shape: [4]) 103 | ), 104 | layerNorm: MLTensorUtils.layerNorm( 105 | weight: MLTensor.float(shape: [4]), 106 | bias: MLTensor.float(shape: [4]), 107 | epsilon: 1e-5 108 | ) 109 | ) 110 | 111 | let result2 = output2( 112 | hiddenStates: MLTensor.float(shape: [1, 2, 4]), 113 | inputTensor: MLTensor.float(shape: [1, 2, 4]) 114 | ) 115 | let data2 = await result2.scalars(of: Float.self) 116 | 117 | #expect(result2.shape == [1, 2, 4]) 118 | #expect( 119 | allClose( 120 | data2, 121 | [0.0, 0.55278635, 2.8944273, 7.0249224, 0.0, 0.5527864, 2.8944273, 7.0249224]) 122 | == true) 123 | } 124 
| } 125 | 126 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 127 | final class BertEmbeddingTests: XCTestCase { 128 | // NOTE: this test is not stable when running using `Testing` library, not sure why 129 | func testEmbeddings() async { 130 | let wordEmbeddings = MLTensorUtils.embedding(weight: MLTensor.float(shape: [2, 4])) 131 | let positionEmbeddings = MLTensorUtils.embedding(weight: MLTensor.float(shape: [1, 4])) 132 | let tokenTypeEmbeddings = MLTensorUtils.embedding(weight: MLTensor.float(shape: [2, 4])) 133 | let embeddings = Bert.Embeddings( 134 | wordEmbeddings: wordEmbeddings, 135 | positionEmbeddings: positionEmbeddings, 136 | tokenTypeEmbeddings: tokenTypeEmbeddings, 137 | layerNorm: MLTensorUtils.layerNorm( 138 | weight: MLTensor.float(shape: [4]), 139 | bias: MLTensor.float(shape: [4]), 140 | epsilon: 1e-5 141 | ) 142 | ) 143 | 144 | let result = embeddings(inputIds: MLTensor.int32(shape: [1, 2])) 145 | let data = await result.scalars(of: Float.self) 146 | 147 | XCTAssertEqual(result.shape, [1, 2, 4]) 148 | XCTAssertTrue( 149 | allClose(data, [0, 0.552787, 2.89443, 7.02492, 0, 0.552787, 2.89443, 7.02492])) 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /Tests/EmbeddingsTests/Model2VecTests.swift: -------------------------------------------------------------------------------- 1 | import CoreML 2 | import MLTensorUtils 3 | import Testing 4 | import TestingUtils 5 | 6 | @testable import Embeddings 7 | 8 | struct Model2VecTests { 9 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 10 | func modelBundle( 11 | tokenizedValues: [Int32], 12 | unknownTokenId: Int? 
= nil 13 | ) -> Model2Vec.ModelBundle { 14 | let data: [Float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] 15 | let embeddings = MLTensor(shape: [3, 3], scalars: data) 16 | let model = Model2Vec.Model(embeddings: embeddings) 17 | let tokenizer = TextTokenizerMock( 18 | tokenizedValues: tokenizedValues, 19 | unknownTokenId: unknownTokenId 20 | ) 21 | return Model2Vec.ModelBundle(model: model, tokenizer: tokenizer) 22 | } 23 | 24 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 25 | @Test func model2VecEmbeddings() async throws { 26 | let modelBundle = modelBundle(tokenizedValues: [0, 1, 2]) 27 | let encoded = try modelBundle.encode("Text") 28 | let result = await encoded.scalars(of: Float.self) 29 | 30 | #expect(allClose(result, [0.4, 0.5, 0.6]) == true) 31 | } 32 | 33 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 34 | @Test func model2VecWhenNormalize() async throws { 35 | let modelBundle = modelBundle(tokenizedValues: [0, 1, 2]) 36 | let encoded = try modelBundle.encode("Text", normalize: true) 37 | let result = await encoded.scalars(of: Float.self) 38 | 39 | #expect(allClose(result, [0.45584226, 0.56980276, 0.6837634]) == true) 40 | } 41 | 42 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 43 | @Test func model2VecWhenUnknownTokenId() async throws { 44 | let modelBundle = modelBundle(tokenizedValues: [0, 1, 2], unknownTokenId: 0) 45 | let encoded = try modelBundle.encode("Text") 46 | let result = await encoded.scalars(of: Float.self) 47 | 48 | #expect(allClose(result, [0.55, 0.65, 0.75]) == true) 49 | } 50 | 51 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 52 | @Test func model2VecWhenTokenizerReturnsEmpty() async throws { 53 | let modelBundle = modelBundle(tokenizedValues: []) 54 | let encoded = try modelBundle.encode("Text") 55 | let result = await encoded.scalars(of: Float.self) 56 | 57 | #expect(allClose(result, [0.0, 0.0, 0.0]) == true) 
58 | } 59 | } 60 | -------------------------------------------------------------------------------- /Tests/EmbeddingsTests/StaticEmbeddingsTests.swift: -------------------------------------------------------------------------------- 1 | import CoreML 2 | import MLTensorUtils 3 | import Testing 4 | import TestingUtils 5 | 6 | @testable import Embeddings 7 | 8 | struct StaticEmbeddingsTests { 9 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 10 | func modelBundle(tokenizedValues: [Int32]) -> StaticEmbeddings.ModelBundle { 11 | let data: [Float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] 12 | let embeddings = MLTensor(shape: [3, 3], scalars: data) 13 | let model = StaticEmbeddings.Model(embeddings: embeddings) 14 | let tokenizer = TextTokenizerMock(tokenizedValues: tokenizedValues) 15 | return StaticEmbeddings.ModelBundle(model: model, tokenizer: tokenizer) 16 | } 17 | 18 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 19 | @Test func staticEmbeddings() async throws { 20 | let modelBundle = modelBundle(tokenizedValues: [0, 1, 2]) 21 | let encoded = try modelBundle.encode("Text") 22 | let result = await encoded.scalars(of: Float.self) 23 | 24 | #expect(allClose(result, [0.4, 0.5, 0.6]) == true) 25 | } 26 | 27 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 28 | @Test func staticEmbeddingsWhenNormalize() async throws { 29 | let modelBundle = modelBundle(tokenizedValues: [0, 1, 2]) 30 | let encoded = try modelBundle.encode("Text", normalize: true) 31 | let result = await encoded.scalars(of: Float.self) 32 | 33 | #expect(allClose(result, [0.45584226, 0.56980276, 0.6837634]) == true) 34 | } 35 | 36 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 37 | @Test func staticEmbeddingsWhenTruncateDimension() async throws { 38 | let modelBundle = modelBundle(tokenizedValues: [0, 1, 2]) 39 | let encoded = try modelBundle.encode("Text", truncateDimension: 2) 40 | 
let result = await encoded.scalars(of: Float.self) 41 | 42 | #expect(allClose(result, [0.4, 0.5]) == true) 43 | } 44 | 45 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 46 | @Test func staticEmbeddingsWhenTokenizerReturnsEmpty() async throws { 47 | let modelBundle = modelBundle(tokenizedValues: []) 48 | let encoded = try modelBundle.encode("Text") 49 | let result = await encoded.scalars(of: Float.self) 50 | 51 | #expect(allClose(result, [0.0, 0.0, 0.0]) == true) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /Tests/EmbeddingsTests/TokenizerTests.swift: -------------------------------------------------------------------------------- 1 | import Foundation 2 | import Testing 3 | 4 | @testable import Embeddings 5 | 6 | @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *) 7 | @Test func clipTokenizer() throws { 8 | let bundleUrl = Bundle.module 9 | .url(forResource: "merges", withExtension: "txt", subdirectory: "Resources")? 
10 | .deletingLastPathComponent() 11 | let url = try #require(bundleUrl, "Wrong bundle URL") 12 | let tokenizer = try loadClipTokenizer(at: url) 13 | 14 | #expect(tokenizer.tokenize("", maxLength: 128, addSpecialTokens: true) == [49406, 49407]) 15 | #expect(tokenizer.tokenize("", maxLength: 128, addSpecialTokens: false) == []) 16 | #expect( 17 | tokenizer.tokenize("a photo of a cat", maxLength: 128, addSpecialTokens: true) 18 | == [49406, 320, 1125, 539, 320, 2368, 49407]) 19 | #expect( 20 | tokenizer.tokenize("a photo of a cat", maxLength: 128, addSpecialTokens: false) 21 | == [320, 1125, 539, 320, 2368]) 22 | #expect( 23 | tokenizer.tokenize("a photo of a cat", maxLength: 5, addSpecialTokens: true) 24 | == [49406, 320, 1125, 539, 49407]) 25 | #expect( 26 | tokenizer.tokenize("a photo of a cat", maxLength: 5, addSpecialTokens: false) 27 | == [320, 1125, 539, 320, 2368]) 28 | #expect( 29 | tokenizer.tokenize( 30 | "a photo of a cat", maxLength: 128, padToLength: 10, addSpecialTokens: true) 31 | == [49406, 320, 1125, 539, 320, 2368, 49407, 0, 0, 0]) 32 | #expect( 33 | tokenizer.tokenize( 34 | "a photo of a cat", maxLength: 128, padToLength: 10, addSpecialTokens: false) 35 | == [320, 1125, 539, 320, 2368, 0, 0, 0, 0, 0]) 36 | #expect( 37 | tokenizer.tokenize( 38 | "a photo of a cat", maxLength: 5, padToLength: 10, addSpecialTokens: true) 39 | == [49406, 320, 1125, 539, 49407]) 40 | #expect( 41 | tokenizer.tokenize( 42 | "a photo of a cat", maxLength: 5, padToLength: 10, addSpecialTokens: false) 43 | == [320, 1125, 539, 320, 2368]) 44 | #expect( 45 | tokenizer.tokenize("a photo of a cat", maxLength: 128, addSpecialTokens: true) 46 | == tokenizer.tokenize( 47 | " a photo of a cat ", maxLength: 128, addSpecialTokens: true) 48 | ) 49 | #expect( 50 | tokenizer.tokenize("a photo of a cat", maxLength: 128, addSpecialTokens: true) 51 | == tokenizer.tokenize("A pHotO of a CaT", maxLength: 128, addSpecialTokens: true) 52 | ) 53 | } 54 | 
--------------------------------------------------------------------------------
/Tests/EmbeddingsTests/UtilsTests.swift:
--------------------------------------------------------------------------------
import CoreML
import Testing

@testable import Embeddings

// Verifies the Google-checkpoint key mapping: every key gains the "bert." prefix,
// and LayerNorm weight/bias map to the gamma/beta names used in Google weights.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func googleWeightsKeyTransform() async {
    #expect(Bert.googleWeightsKeyTransform("some.weight.key") == "bert.some.weight.key")
    #expect(Bert.googleWeightsKeyTransform("some.LayerNorm.weight") == "bert.some.LayerNorm.gamma")
    #expect(Bert.googleWeightsKeyTransform("some.LayerNorm.bias") == "bert.some.LayerNorm.beta")
    #expect(Bert.googleWeightsKeyTransform("some.Embedding.weight") == "bert.some.Embedding.weight")
    #expect(Bert.googleWeightsKeyTransform("some.Embedding.bias") == "bert.some.Embedding.bias")
}

/// Test double for `TextTokenizer`: always returns the canned `tokenizedValues`,
/// ignoring the input text and options.
final class TextTokenizerMock: TextTokenizer {
    let tokenizedValues: [Int32]
    let unknownTokenId: Int?

    init(tokenizedValues: [Int32], unknownTokenId: Int? = nil) {
        self.tokenizedValues = tokenizedValues
        self.unknownTokenId = unknownTokenId
    }

    func tokenizeText(_ text: String, maxLength: Int?, addSpecialTokens: Bool) throws -> [Int32] {
        tokenizedValues
    }
}
--------------------------------------------------------------------------------
/Tests/EmbeddingsTests/Word2VecTests.swift:
--------------------------------------------------------------------------------
import CoreML
import MLTensorUtils
import Testing
import TestingUtils

@testable import Embeddings

struct Word2VecTests {
    // Nearest-neighbour lookup: results exclude the query word itself and are
    // ordered by descending similarity score.
    @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
    @Test func mostSimilar() async {
        let modelBundle = Word2Vec.ModelBundle(
            keyToIndex: ["a": 0, "b": 1, "c": 2, "d": 3],
            indexToKey: [0: "a", 1: "b", 2: "c", 3: "d"],
            embeddings: MLTensor.float(shape: [4, 2])
        )
        let result1 = await modelBundle.mostSimilar(to: "a", topK: 3)
        let words1 = result1.map(\.word)
        let scores1 = result1.map(\.score)
        #expect(words1 == ["b", "c", "d"])
        #expect(allClose(scores1, [0.8320502, 0.78086865, 0.7592565]) == true)

        // Default topK returns a single neighbour.
        let result2 = await modelBundle.mostSimilar(to: "c")
        let words2 = result2.map(\.word)
        let scores2 = result2.map(\.score)
        #expect(words2 == ["d"])
        #expect(allClose(scores2, [0.9994259]) == true)
    }

    // encode returns nil for out-of-vocabulary words; batchEncode silently
    // drops unknown words ("e") and stacks the remaining rows.
    @available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
    @Test func encode() async throws {
        let modelBundle = Word2Vec.ModelBundle(
            keyToIndex: ["a": 0, "b": 1, "c": 2, "d": 3],
            indexToKey: [0: "a", 1: "b", 2: "c", 3: "d"],
            embeddings: MLTensor.float(shape: [4, 2])
        )
        #expect(modelBundle.encode("e") == nil)

        let encoded = try #require(modelBundle.encode("a"))
        #expect(encoded.shape == [2])
        let encodedShapedArray = await encoded.shapedArray(of: Float.self).scalars
        #expect(encodedShapedArray == [0.0, 1.0])

        let batchEncoded = try #require(modelBundle.batchEncode(["a", "c", "d", "e"]))
        #expect(batchEncoded.shape == [3, 2])
        let batchEncodedShapedArray = await batchEncoded.shapedArray(of: Float.self).scalars
        #expect(batchEncodedShapedArray == [0.0, 1.0, 4.0, 5.0, 6.0, 7.0])
    }
}
--------------------------------------------------------------------------------
/Tests/MLTensorUtilsTests/ActivationTests.swift:
--------------------------------------------------------------------------------
import CoreML
import Numerics
import Testing
import TestingUtils

@testable import MLTensorUtils

// erf at its fixed points (0, ±infinity) and against reference values for
// moderate inputs.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func erf() async {
    let input1 = MLTensor(shape: [1], scalars: [0], scalarType: Float.self)
    let result1 = await erf(input1).shapedArray(of: Float.self).scalars
    #expect(result1 == [0])

    let input2 = MLTensor(shape: [1], scalars: [.infinity], scalarType: Float.self)
    let result2 = await erf(input2).shapedArray(of: Float.self).scalars
    #expect(result2 == [1])

    let input3 = MLTensor(shape: [1], scalars: [-.infinity], scalarType: Float.self)
    let result3 = await erf(input3).shapedArray(of: Float.self).scalars
    #expect(result3 == [-1])

    let input4 = MLTensor(
        shape: [6], scalars: [0.9, 0.5, 0.1, -0.1, -0.5, -0.9], scalarType: Float.self)
    let result4 = await erf(input4).shapedArray(of: Float.self).scalars
    let expected4: [Float] = [
        0.7969082124228322,
        0.5204998778130465,
        0.1124629160182849,
        -0.1124629160182849,
        -0.5204998778130465,
        -0.7969082124228322,
    ]
    #expect(allClose(result4, expected4) == true)
}

// Default gelu: identity-like at large positive inputs, ~0 at large negative.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func gelu() async {
    let input1 = MLTensor(shape: [1], scalars: [0], scalarType: Float.self)
    let result1 = await gelu(input1).shapedArray(of: Float.self).scalars
    #expect(result1 == [0])

    let input2 = MLTensor(shape: [1], scalars: [100], scalarType: Float.self)
    let result2 = await gelu(input2).shapedArray(of: Float.self).scalars
    #expect(result2 == [100])

    let input3 = MLTensor(shape: [1], scalars: [-100], scalarType: Float.self)
    let result3 = await gelu(input3).shapedArray(of: Float.self).scalars
    #expect(result3 == [0])

    let input4 = MLTensor(
        shape: [6], scalars: [0.9, 0.5, 0.1, -0.1, -0.5, -0.9], scalarType: Float.self)
    let result4 = await gelu(input4).shapedArray(of: Float.self).scalars
    let expected4: [Float] = [
        0.734346, 0.345731, 0.0539828, -0.0460172, -0.154269, -0.165654,
    ]
    #expect(allClose(result4, expected4) == true)
}

// `.fast` approximation: same limits, slightly different mid-range values
// than the default gelu above.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func geluApproximationFast() async {
    let input1 = MLTensor(shape: [1], scalars: [0], scalarType: Float.self)
    let result1 = await gelu(input1, approximation: .fast).shapedArray(of: Float.self).scalars
    #expect(result1 == [0])

    let input2 = MLTensor(shape: [1], scalars: [100], scalarType: Float.self)
    let result2 = await gelu(input2, approximation: .fast).shapedArray(of: Float.self).scalars
    #expect(result2 == [100])

    let input3 = MLTensor(shape: [1], scalars: [-100], scalarType: Float.self)
    let result3 = await gelu(input3, approximation: .fast).shapedArray(of: Float.self).scalars
    #expect(result3 == [0])

    let input4 = MLTensor(
        shape: [6], scalars: [0.9, 0.5, 0.1, -0.1, -0.5, -0.9], scalarType: Float.self)
    let result4 = await gelu(input4, approximation: .fast).shapedArray(of: Float.self).scalars
    let expected4: [Float] = [
        0.740043, 0.350388, 0.0542448, -0.0457552, -0.149612, -0.159957,
    ]
    #expect(allClose(result4, expected4) == true)
}

// `.precise` approximation: mid-range values are closer to the default gelu
// than `.fast` is.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func geluApproximationPrecise() async {
    let input1 = MLTensor(shape: [1], scalars: [0], scalarType: Float.self)
    let result1 = await gelu(input1, approximation: .precise).shapedArray(of: Float.self).scalars
    #expect(result1 == [0])

    let input2 = MLTensor(shape: [1], scalars: [100], scalarType: Float.self)
    let result2 = await gelu(input2, approximation: .precise).shapedArray(of: Float.self).scalars
    #expect(result2 == [100])

    let input3 = MLTensor(shape: [1], scalars: [-100], scalarType: Float.self)
    let result3 = await gelu(input3, approximation: .precise).shapedArray(of: Float.self).scalars
    #expect(result3 == [0])

    let input4 = MLTensor(
        shape: [6], scalars: [0.9, 0.5, 0.1, -0.1, -0.5, -0.9], scalarType: Float.self)
    let result4 = await gelu(input4, approximation: .precise).shapedArray(of: Float.self).scalars
    let expected4: [Float] = [
        0.734228, 0.345714, 0.0539828, -0.0460172, -0.154286, -0.165772,
    ]
    #expect(allClose(result4, expected4) == true)
}
--------------------------------------------------------------------------------
/Tests/MLTensorUtilsTests/FunctionTests.swift:
--------------------------------------------------------------------------------
import CoreML
import Numerics
import Testing
import TestingUtils

@testable import MLTensorUtils

// Renamed from `additiveCasualMask` (typo): the function under test is
// `additiveCausalMask` — a lower-triangular additive attention mask
// (0 on and below the diagonal, -1e9 above).
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func additiveCausalMask() async {
    let result = additiveCausalMask(3)

    #expect(result.shape == [3, 3])
    let resultArray = await result.scalars(of: Float.self)
    let expectedArray: [Float] = [0, -1e9, -1e9, 0, 0, -1e9, 0, 0, 0]
    #expect(allClose(resultArray, expectedArray) == true)
}

// Row-wise L2 norm: sqrt(14) and sqrt(77) for the two rows.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func norm() async {
    let x = MLTensor(
        shape: [2, 3],
        scalars: [1, 2, 3, 4, 5, 6],
        scalarType: Float.self
    )
    let result = norm(x)

    #expect(result.shape == [2])
    let resultArray = await result.scalars(of: Float.self)
    let expectedArray: [Float] = [3.7417, 8.7749]
    #expect(allClose(resultArray, expectedArray) == true)
}

@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func cosineSimilarity1D() async {
    let x = MLTensor(
        shape: [1, 6],
        scalars: [1, 2, 3, 4, 5, 6],
        scalarType: Float.self
    )
    let y = MLTensor(
        shape: [1, 6],
        scalars: [4, 5, 6, 7, 8, 9],
        scalarType: Float.self
    )
    let result = cosineSimilarity(x, y)

    #expect(result.shape == [1, 1])
    let resultArray = await result.shapedArray(of: Float.self).scalars
    let expectedArray: [Float] = [0.980653]
    #expect(allClose(resultArray, expectedArray) == true)
}

// NOTE(review): expected value 2.3452077 lies outside [-1, 1], which true
// cosine similarity cannot produce — presumably this pins the current
// implementation's cross-pair normalization; verify against the
// cosineSimilarity implementation before relying on it.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func cosineSimilarity2D() async {
    let x = MLTensor(
        shape: [2, 3],
        scalars: [1, 2, 3, 4, 5, 6],
        scalarType: Float.self
    )
    let y = MLTensor(
        shape: [2, 3],
        scalars: [4, 5, 6, 7, 8, 9],
        scalarType: Float.self
    )
    let result = cosineSimilarity(x, y)

    #expect(result.shape == [2, 2])
    let resultArray = await result.shapedArray(of: Float.self).scalars
    let expectedArray: [Float] = [0.9746318, 0.4090946, 2.3452077, 0.9981909]
    #expect(allClose(resultArray, expectedArray) == true)
}

// Identical rows: every pairwise similarity is exactly 1.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func cosineSimilaritySameTensor() async {
    let x = MLTensor(
        shape: [2, 3],
        scalars: [1, 2, 3, 1, 2, 3],
        scalarType: Float.self
    )
    let result = cosineSimilarity(x, x)

    #expect(result.shape == [2, 2])
    let resultArray = await result.shapedArray(of: Float.self).scalars
    let expectedArray: [Float] = [1.0, 1.0, 1.0, 1.0]
    #expect(allClose(resultArray, expectedArray) == true)
}

// 1-D dot product reduces to a scalar (shape []).
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func dotProduct1D() async {
    let x = MLTensor(
        shape: [6],
        scalars: [1, 2, 3, 4, 5, 6],
        scalarType: Float.self
    )
    let y = MLTensor(
        shape: [6],
        scalars: [4, 5, 6, 7, 8, 9],
        scalarType: Float.self
    )
    let result = dotProduct(x, y)

    #expect(result.shape == [])
    let resultArray = await result.shapedArray(of: Float.self).scalars
    let expectedArray: [Float] = [154]
    #expect(allClose(resultArray, expectedArray) == true)
}

// 2-D dot product with the identity matrix returns y unchanged.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func dotProduct2D() async {
    let x = MLTensor(
        shape: [2, 2],
        scalars: [1, 0, 0, 1],
        scalarType: Float.self
    )
    let y = MLTensor(
        shape: [2, 2],
        scalars: [4, 5, 6, 7],
        scalarType: Float.self
    )
    let result = dotProduct(x, y)

    #expect(result.shape == [2, 2])
    let resultArray = await result.shapedArray(of: Float.self).scalars
    let expectedArray: [Float] = [4, 5, 6, 7]
    #expect(allClose(resultArray, expectedArray) == true)
}

// Distance along axis 0 of two 1-D tensors: sqrt(3 * 3^2) = 7.34846...
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func euclideanDistance1D() async {
    let x = MLTensor(
        shape: [6],
        scalars: [1, 2, 3, 4, 5, 6],
        scalarType: Float.self
    )
    let y = MLTensor(
        shape: [6],
        scalars: [4, 5, 6, 7, 8, 9],
        scalarType: Float.self
    )
    let result = euclideanDistance(x, y, alongAxes: 0)

    #expect(result.shape == [])
    let resultArray = await result.shapedArray(of: Float.self).scalars
    let expectedArray: [Float] = [7.34846]
    #expect(allClose(resultArray, expectedArray) == true)
}

// Per-row distance along axis 1: one scalar per row pair.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func euclideanDistance2D() async {
    let x = MLTensor(
        shape: [2, 2],
        scalars: [1, 0, 0, 1],
        scalarType: Float.self
    )
    let y = MLTensor(
        shape: [2, 2],
        scalars: [4, 5, 6, 7],
        scalarType: Float.self
    )
    let result = euclideanDistance(x, y, alongAxes: 1)

    #expect(result.shape == [2])
    let resultArray = await result.shapedArray(of: Float.self).scalars
    let expectedArray: [Float] = [5.83095, 8.48528]
    #expect(allClose(resultArray, expectedArray) == true)
}

// Distance of a tensor to itself must be exactly zero.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func euclideanDistanceSameTensor() async {
    let x = MLTensor(
        shape: [6],
        scalars: [1, 2, 3, 4, 5, 6],
        scalarType: Float.self
    )
    let result = euclideanDistance(x, x, alongAxes: 0)

    #expect(result.shape == [])
    let resultArray = await result.shapedArray(of: Float.self).scalars
    let expectedArray: [Float] = [0.0]
    #expect(allClose(resultArray, expectedArray) == true)
}
--------------------------------------------------------------------------------
/Tests/MLTensorUtilsTests/LayerTests.swift:
--------------------------------------------------------------------------------
import CoreML
import Numerics
import Testing
import TestingUtils

@testable import MLTensorUtils

// 1-D weight table: each index selects a single scalar.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func embeddingLayer1D() async {
    let embedding = embedding(weight: MLTensor.float(shape: [12]))
    let result = embedding(MLTensor([0, 2, 4] as [Int32]))

    #expect(result.shape == [3])
    let resultArray = await result.scalars(of: Float.self)
    #expect(resultArray == [0, 2, 4])
}

// 2-D weight table: each index selects a row of length 2.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func embeddingLayer2D() async {
    let embedding = embedding(weight: MLTensor.float(shape: [6, 2]))
    let result = embedding(MLTensor([0, 2, 4] as [Int32]))

    #expect(result.shape == [3, 2])
    let resultArray = await result.scalars(of: Float.self)
    #expect(resultArray == [0, 1, 4, 5, 8, 9])
}

// 3-D weight table: each index selects a 2x2 slice.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func embeddingLayer3D() async {
    let embedding = embedding(weight: MLTensor.float(shape: [2, 2, 2]))
    let result = embedding(MLTensor([0, 1] as [Int32]))

    #expect(result.shape == [2, 2, 2])
    let resultArray = await result.scalars(of: Float.self)
    #expect(resultArray == [0, 1, 2, 3, 4, 5, 6, 7])
}

// 1-D weight/bias broadcast over a batched input: both rows normalize to the
// same values because they have the same mean/variance pattern.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func layerNorm1D() async {
    let weight = MLTensor(
        shape: [3],
        scalars: [1, 2, 3],
        scalarType: Float.self
    )
    let bias = MLTensor(
        shape: [3],
        scalars: [4, 5, 6],
        scalarType: Float.self
    )
    let layerNorm = layerNorm(weight: weight, bias: bias, epsilon: 1e-5)
    let input = MLTensor(
        shape: [2, 3],
        scalars: [1, 2, 3, 4, 5, 6],
        scalarType: Float.self
    )
    let result = layerNorm(input)

    #expect(result.shape == [2, 3])
    let resultArray = await result.scalars(of: Float.self)
    let expectedArray: [Float] = [2.7753, 5.0000, 9.6742, 2.7753, 5.0000, 9.6742]
    #expect(allClose(resultArray, expectedArray) == true)
}

// 2-D weight/bias: the second row gets different scale/shift, so the two
// normalized rows differ (unlike layerNorm1D).
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func layerNorm2D() async {
    let weight = MLTensor(
        shape: [2, 3],
        scalars: [1, 2, 3, 4, 5, 6],
        scalarType: Float.self
    )
    let bias = MLTensor(
        shape: [2, 3],
        scalars: [4, 5, 6, 7, 8, 9],
        scalarType: Float.self
    )
    let layerNorm = layerNorm(weight: weight, bias: bias, epsilon: 1e-5)
    let input = MLTensor(
        shape: [1, 2, 3],
        scalars: [1, 2, 3, 4, 5, 6],
        scalarType: Float.self
    )
    let result = layerNorm(input)

    #expect(result.shape == [1, 2, 3])
    let resultArray = await result.shapedArray(of: Float.self).scalars
    let expectedArray: [Float] = [2.7753, 5.0000, 9.6742, 2.1011, 8.0000, 16.3484]
    #expect(allClose(resultArray, expectedArray) == true)
}

// 3-D weight/bias with the same values as layerNorm2D: identical expectations,
// confirming the leading singleton axis does not change the result.
@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, watchOS 11.0, *)
@Test func layerNorm3D() async {
    let weight = MLTensor(
        shape: [1, 2, 3],
        scalars: [1, 2, 3, 4, 5, 6],
        scalarType: Float.self
    )
    let bias = MLTensor(
        shape: [1, 2, 3],
        scalars: [4, 5, 6, 7, 8, 9],
        scalarType: Float.self
    )
    let layerNorm = layerNorm(weight: weight, bias: bias, epsilon: 1e-5)
    let input = MLTensor(
        shape: [1, 2, 3],
        scalars: [1, 2, 3, 4, 5, 6],
        scalarType: Float.self
    )
    let result = layerNorm(input)

    #expect(result.shape == [1, 2, 3])
    let resultArray = await result.shapedArray(of: Float.self).scalars
    let expectedArray: [Float] = [2.7753, 5.0000, 9.6742, 2.1011, 8.0000, 16.3484]
    #expect(allClose(resultArray, expectedArray) == true)
}
--------------------------------------------------------------------------------