├── .gitignore ├── LICENSE ├── Models ├── DecisionTreeModel-Mushrooms.mlmodel └── RandomForestModel-Mushrooms.mlmodel ├── Playgrounds ├── DecisionTreeClassifier.playground │ ├── Contents.swift │ ├── Resources │ │ └── Mushrooms.csv │ └── contents.xcplayground ├── ImageClassifierBuilder.playground │ ├── Contents.swift │ └── contents.xcplayground ├── RandomForestClassifier.playground │ ├── Contents.swift │ ├── Resources │ │ └── Mushrooms.csv │ └── contents.xcplayground ├── TextClassifier.playground │ ├── Contents.swift │ ├── Sources │ │ └── ModelPerformance.swift │ └── contents.xcplayground └── WordTagger.playground │ ├── Contents.swift │ └── contents.xcplayground └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | # Xcode 2 | # 3 | # gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore 4 | 5 | ## Build generated 6 | build/ 7 | DerivedData/ 8 | 9 | ## Various settings 10 | *.pbxuser 11 | !default.pbxuser 12 | *.mode1v3 13 | !default.mode1v3 14 | *.mode2v3 15 | !default.mode2v3 16 | *.perspectivev3 17 | !default.perspectivev3 18 | xcuserdata/ 19 | 20 | ## Other 21 | *.moved-aside 22 | *.xccheckout 23 | *.xcscmblueprint 24 | 25 | ## Obj-C/Swift specific 26 | *.hmap 27 | *.ipa 28 | *.dSYM.zip 29 | *.dSYM 30 | 31 | ## Playgrounds 32 | timeline.xctimeline 33 | playground.xcworkspace 34 | 35 | # Swift Package Manager 36 | # 37 | # Add this line if you want to avoid checking in source code from Swift Package Manager dependencies. 38 | # Packages/ 39 | # Package.pins 40 | # Package.resolved 41 | .build/ 42 | 43 | # CocoaPods 44 | # 45 | # We recommend against adding the Pods directory to your .gitignore. 
However 46 | # you should judge for yourself, the pros and cons are mentioned at: 47 | # https://guides.cocoapods.org/using/using-cocoapods.html#should-i-check-the-pods-directory-into-source-control 48 | # 49 | # Pods/ 50 | 51 | # Carthage 52 | # 53 | # Add this line if you want to avoid checking in source code from Carthage dependencies. 54 | # Carthage/Checkouts 55 | 56 | Carthage/Build 57 | 58 | # fastlane 59 | # 60 | # It is recommended to not store the screenshots in the git repo. Instead, use fastlane to re-generate the 61 | # screenshots whenever they are needed. 62 | # For more information about the recommended setup visit: 63 | # https://docs.fastlane.tools/best-practices/source-control/#source-control 64 | 65 | fastlane/report.xml 66 | fastlane/Preview.html 67 | fastlane/screenshots/**/*.png 68 | fastlane/test_output 69 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Paolo Di Lorenzo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Models/DecisionTreeModel-Mushrooms.mlmodel: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pdil/ClassifierKit/44c10e38017f2ff0e210a70c3556b3b3e1920388/Models/DecisionTreeModel-Mushrooms.mlmodel -------------------------------------------------------------------------------- /Models/RandomForestModel-Mushrooms.mlmodel: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pdil/ClassifierKit/44c10e38017f2ff0e210a70c3556b3b3e1920388/Models/RandomForestModel-Mushrooms.mlmodel -------------------------------------------------------------------------------- /Playgrounds/DecisionTreeClassifier.playground/Contents.swift: -------------------------------------------------------------------------------- 1 | /** 2 | 3 | Mushrooms.csv Dataset 4 | 5 | Source: 6 | 7 | https://archive.ics.uci.edu/ml/datasets/Mushroom 8 | 9 | Citation: 10 | 11 | Dua, D. and Karra Taniskidou, E. (2017). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, School of Information and Computer Science. 12 | 13 | Attribute information: 14 | 15 | 1. cap-shape: bell=b,conical=c,convex=x,flat=f,knobbed=k,sunken=s 16 | 2. cap-surface: fibrous=f,grooves=g,scaly=y,smooth=s 17 | 3. cap-color: brown=n,buff=b,cinnamon=c,gray=g,green=r,pink=p,purple=u,red=e,white=w,yellow=y 18 | 4. bruises?: bruises=t,no=f 19 | 5. odor: almond=a,anise=l,creosote=c,fishy=y,foul=f,musty=m,none=n,pungent=p,spicy=s 20 | 6. gill-attachment: attached=a,descending=d,free=f,notched=n 21 | 7. 
gill-spacing: close=c,crowded=w,distant=d 22 | 8. gill-size: broad=b,narrow=n 23 | 9. gill-color: black=k,brown=n,buff=b,chocolate=h,gray=g,green=r,orange=o,pink=p,purple=u,red=e,white=w,yellow=y 24 | 10. stalk-shape: enlarging=e,tapering=t 25 | 11. stalk-root: bulbous=b,club=c,cup=u,equal=e,rhizomorphs=z,rooted=r,missing=? 26 | 12. stalk-surface-above-ring: fibrous=f,scaly=y,silky=k,smooth=s 27 | 13. stalk-surface-below-ring: fibrous=f,scaly=y,silky=k,smooth=s 28 | 14. stalk-color-above-ring: brown=n,buff=b,cinnamon=c,gray=g,orange=o,pink=p,red=e,white=w,yellow=y 29 | 15. stalk-color-below-ring: brown=n,buff=b,cinnamon=c,gray=g,orange=o,pink=p,red=e,white=w,yellow=y 30 | 16. veil-type: partial=p,universal=u 31 | 17. veil-color: brown=n,orange=o,white=w,yellow=y 32 | 18. ring-number: none=n,one=o,two=t 33 | 19. ring-type: cobwebby=c,evanescent=e,flaring=f,large=l,none=n,pendant=p,sheathing=s,zone=z 34 | 20. spore-print-color: black=k,brown=n,buff=b,chocolate=h,green=r,orange=o,purple=u,white=w,yellow=y 35 | 21. population: abundant=a,clustered=c,numerous=n,scattered=s,several=v,solitary=y 36 | 22. habitat: grasses=g,leaves=l,meadows=m,paths=p,urban=u,waste=w,woods=d 37 | 38 | */ 39 | 40 | /// File path containing data (JSON or CSV). 41 | let dataFileName: String = "Mushrooms" 42 | 43 | /// Column within data table to train model on. 
44 | let targetColumn: String = "class" 45 | 46 | /// Output .mlmodel file path 47 | let mlmodelFileName: String = <#/path/to/.mlmodel output file#> 48 | 49 | /// ============================================= 50 | 51 | import Foundation 52 | import CreateML 53 | 54 | // Import data 55 | guard let dataFilePath = Bundle.main.path(forResource: dataFileName, ofType: "csv") else { 56 | fatalError("\(dataFileName) not found in Resources folder.") 57 | } 58 | 59 | let data = try MLDataTable(contentsOf: URL(fileURLWithPath: dataFilePath)) 60 | 61 | // Split into training/testing data 62 | let (trainingData, testingData) = data.randomSplit(by: 0.8, seed: 1000) 63 | 64 | // Train decision tree classifier 65 | let decisionTreeClassifier = try MLDecisionTreeClassifier(trainingData: trainingData, targetColumn: targetColumn) 66 | 67 | // Export Core ML file 68 | let metadata = MLModelMetadata( 69 | author: "ClassifierKit", 70 | shortDescription: "A decision tree classification model for identifying whether a mushroom with the given attributes is edible or poisonous.", 71 | license: "MIT", 72 | version: "1.0", 73 | additional: nil 74 | ) 75 | try decisionTreeClassifier.write(toFile: mlmodelFileName, metadata: metadata) 76 | -------------------------------------------------------------------------------- /Playgrounds/DecisionTreeClassifier.playground/contents.xcplayground: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /Playgrounds/ImageClassifierBuilder.playground/Contents.swift: -------------------------------------------------------------------------------- 1 | /*: Explanation 2 | 3 | # Image Classifier Builder 4 | 5 | ## Introduction 6 | 7 | This playground uses the `MLImageClassifierBuilder` class of `Create ML` to render a drag-and-drop interface in the playground's live view. 
The interface accepts a set of folders containing images and labels to train the model with. 8 | 9 | ## Presenting the live view 10 | 11 | To present the interface for training the model from within the Playground, first run `builder.showInLiveView()` by clicking the play arrow at the beginning of the line. 12 | 13 | Next, open the Assistant Editor (⌥ ⌘ ↩︎) to see the Live View. 14 | 15 | ## Training the model 16 | 17 | Drag the "Training Data" folder from the Resources folder to the "Drop Images to Begin Training" section of the Live View. 18 | 19 | `Create ML` will begin training the model using the images. 20 | 21 | ## Testing the model 22 | 23 | Drag the "Testing Data" folder from the Resources folder to the Live View once again ("Drop Images to Begin Testing") to test the model. 24 | 25 | ## Saving the trained model 26 | 27 | If the testing results are satisfactory, the trained model can either be saved to your local computer as a `.mlmodel` file or dragged directly into an Xcode project for quick integration into an app. 28 | 29 | */ 30 | 31 | import CreateMLUI 32 | 33 | let builder = MLImageClassifierBuilder() 34 | builder.showInLiveView() 35 | -------------------------------------------------------------------------------- /Playgrounds/ImageClassifierBuilder.playground/contents.xcplayground: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /Playgrounds/RandomForestClassifier.playground/Contents.swift: -------------------------------------------------------------------------------- 1 | /** 2 | 3 | Mushrooms.csv Dataset 4 | 5 | Source: 6 | 7 | https://archive.ics.uci.edu/ml/datasets/Mushroom 8 | 9 | Citation: 10 | 11 | Dua, D. and Karra Taniskidou, E. (2017). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, School of Information and Computer Science. 
12 | 13 | Attribute information: 14 | 15 | 1. cap-shape: bell=b,conical=c,convex=x,flat=f,knobbed=k,sunken=s 16 | 2. cap-surface: fibrous=f,grooves=g,scaly=y,smooth=s 17 | 3. cap-color: brown=n,buff=b,cinnamon=c,gray=g,green=r,pink=p,purple=u,red=e,white=w,yellow=y 18 | 4. bruises?: bruises=t,no=f 19 | 5. odor: almond=a,anise=l,creosote=c,fishy=y,foul=f,musty=m,none=n,pungent=p,spicy=s 20 | 6. gill-attachment: attached=a,descending=d,free=f,notched=n 21 | 7. gill-spacing: close=c,crowded=w,distant=d 22 | 8. gill-size: broad=b,narrow=n 23 | 9. gill-color: black=k,brown=n,buff=b,chocolate=h,gray=g,green=r,orange=o,pink=p,purple=u,red=e,white=w,yellow=y 24 | 10. stalk-shape: enlarging=e,tapering=t 25 | 11. stalk-root: bulbous=b,club=c,cup=u,equal=e,rhizomorphs=z,rooted=r,missing=? 26 | 12. stalk-surface-above-ring: fibrous=f,scaly=y,silky=k,smooth=s 27 | 13. stalk-surface-below-ring: fibrous=f,scaly=y,silky=k,smooth=s 28 | 14. stalk-color-above-ring: brown=n,buff=b,cinnamon=c,gray=g,orange=o,pink=p,red=e,white=w,yellow=y 29 | 15. stalk-color-below-ring: brown=n,buff=b,cinnamon=c,gray=g,orange=o,pink=p,red=e,white=w,yellow=y 30 | 16. veil-type: partial=p,universal=u 31 | 17. veil-color: brown=n,orange=o,white=w,yellow=y 32 | 18. ring-number: none=n,one=o,two=t 33 | 19. ring-type: cobwebby=c,evanescent=e,flaring=f,large=l,none=n,pendant=p,sheathing=s,zone=z 34 | 20. spore-print-color: black=k,brown=n,buff=b,chocolate=h,green=r,orange=o,purple=u,white=w,yellow=y 35 | 21. population: abundant=a,clustered=c,numerous=n,scattered=s,several=v,solitary=y 36 | 22. habitat: grasses=g,leaves=l,meadows=m,paths=p,urban=u,waste=w,woods=d 37 | 38 | */ 39 | 40 | /// File path containing data (JSON or CSV). 41 | let dataFileName: String = "Mushrooms" 42 | 43 | /// Column within data table to train model on. 
44 | let targetColumn: String = "class" 45 | 46 | /// Output .mlmodel file path 47 | let mlmodelFileName: String = <#/path/to/.mlmodel output file#> 48 | 49 | /// ============================================= 50 | 51 | import Foundation 52 | import CreateML 53 | 54 | // Import data 55 | guard let dataFilePath = Bundle.main.path(forResource: dataFileName, ofType: "csv") else { 56 | fatalError("\(dataFileName) not found in Resources folder.") 57 | } 58 | 59 | let data = try MLDataTable(contentsOf: URL(fileURLWithPath: dataFilePath)) 60 | 61 | // Split into training/testing data 62 | let (trainingData, testingData) = data.randomSplit(by: 0.8, seed: 1000) 63 | 64 | // Train random forest classifier 65 | let randomForestClassifier = try MLRandomForestClassifier(trainingData: trainingData, targetColumn: targetColumn) 66 | 67 | // Export Core ML file 68 | let metadata = MLModelMetadata( 69 | author: "ClassifierKit", 70 | shortDescription: "A random forest classification model for identifying whether a mushroom with the given attributes is edible or poisonous.", 71 | license: "MIT", 72 | version: "1.0", 73 | additional: nil 74 | ) 75 | try randomForestClassifier.write(toFile: mlmodelFileName, metadata: metadata) 76 | -------------------------------------------------------------------------------- /Playgrounds/RandomForestClassifier.playground/contents.xcplayground: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /Playgrounds/TextClassifier.playground/Contents.swift: -------------------------------------------------------------------------------- 1 | /// Enter the following values: 2 | 3 | /// The file path containing the data table to train 4 | /// the classifier from. Can be a JSON or CSV file. 
5 | let dataTablePath: String = <#data/table/path.csv#> 6 | 7 | /// The name of the column in the data table that contains 8 | /// the text used to train the model. 9 | let textColumnName: String = <#textColumnName#> 10 | 11 | /// The name of the column in the data table that contains 12 | /// the labels corresponding to the text in `textColumnName`. 13 | let labelColumnName: String = <#labelColumnName#> 14 | 15 | /// The file path to write the trained model to once 16 | /// training is complete. Must be a .mlmodel file. 17 | let modelPath: String = <#completed/model/path.mlmodel#> 18 | 19 | /// Fraction of the data (between 0 and 1) to be used as training data 20 | /// (the rest is used for testing the model). 21 | let trainingPercentage: Double = <#0.8#> 22 | 23 | /// The random seed used for splitting data into 24 | /// training and testing data sets. Use this to 25 | /// replicate results whenever re-running the model 26 | /// or change it to obtain different training/testing sets. 27 | let seed: Int = <#1000#> 28 | 29 | /// Model metadata 30 | let author: String = <#Your name#> 31 | let shortDescription: String = <#A description of this model#> 32 | let license: String? = <#License (optional)#> 33 | let version: String = <#Model version#> 34 | let additional: [String: String]? 
= <#Any additional metadata for this model (optional)#> 35 | 36 | /// ======================================================= 37 | 38 | import Foundation 39 | import CreateML 40 | 41 | // Import data from `dataTablePath` 42 | let data = try MLDataTable(contentsOf: URL(fileURLWithPath: dataTablePath)) 43 | 44 | // Create training/testing data 45 | let (trainingData, testingData) = data.randomSplit(by: trainingPercentage, seed: seed) 46 | 47 | // Train text classifier model 48 | let textClassifier = try MLTextClassifier(trainingData: trainingData, textColumn: textColumnName, labelColumn: labelColumnName) 49 | 50 | // Output trained model 51 | let metadata = MLModelMetadata(author: author, shortDescription: shortDescription, license: license, version: version, additional: additional) 52 | try textClassifier.write(to: URL(fileURLWithPath: modelPath), metadata: metadata) 53 | 54 | // Performance metrics 55 | print("Training accuracy: \(textClassifier.trainingAccuracy())%") 56 | print("Validation accuracy: \(textClassifier.validationAccuracy())%") 57 | -------------------------------------------------------------------------------- /Playgrounds/TextClassifier.playground/Sources/ModelPerformance.swift: -------------------------------------------------------------------------------- 1 | 2 | import CreateML 3 | 4 | public extension MLTextClassifier { 5 | /// Returns the training accuracy as a percentage. 6 | public func trainingAccuracy() -> Double { 7 | return (1.0 - trainingMetrics.classificationError) * 100 8 | } 9 | 10 | /// Returns the validation accuracy as a percentage. 
11 | public func validationAccuracy() -> Double { 12 | return (1.0 - validationMetrics.classificationError) * 100 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /Playgrounds/TextClassifier.playground/contents.xcplayground: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /Playgrounds/WordTagger.playground/Contents.swift: -------------------------------------------------------------------------------- 1 | import Cocoa 2 | 3 | var str = "Hello, playground" 4 | 5 | 6 | -------------------------------------------------------------------------------- /Playgrounds/WordTagger.playground/contents.xcplayground: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ClassifierKit 2 | 🤖 A suite of tools and examples for training Core ML models with Create ML. 3 | 4 | ## 📄 Requirements 5 | * macOS 10.14 (Mojave) or later ([download](https://developer.apple.com/download/)) 6 | * Xcode 10 or later ([download](https://developer.apple.com/download/)) 7 | * Swift 4.2 or later 8 | 9 | **Important Note:** [`Create ML`](https://developer.apple.com/documentation/create_ml) is not available for the iOS SDK. It can only be used on macOS to train models and is not intended for on-device training. Instead, it is used to train models with data (which may take minutes to hours depending on computing power, size of dataset, and model). When the model is trained, a `.mlmodel` file can be exported and implemented in iOS/tvOS/watchOS/macOS apps using [`Core ML`](https://developer.apple.com/documentation/coreml). 
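As a rough illustration of the note above, the sketch below loads an exported `.mlmodel` file with Core ML and asks it for a prediction. It is a minimal sketch under stated assumptions, not part of ClassifierKit: the model path is a placeholder, and the feature names and codes (`odor`, `gill-size`) are assumed to match the column headers of `Mushrooms.csv`. A real call must supply every input the model declares, so the sketch prints that list first.

```swift
import CoreML
import Foundation

// Placeholder path (an assumption): point this at a model exported by a playground.
let modelURL = URL(fileURLWithPath: "/path/to/DecisionTreeModel-Mushrooms.mlmodel")

do {
    // Compile the .mlmodel at run time. Xcode performs this step at build
    // time when the file is added to an app target.
    let compiledURL = try MLModel.compileModel(at: modelURL)
    let model = try MLModel(contentsOf: compiledURL)

    // List the inputs the model expects, so the dictionary below can be completed.
    let requiredInputs = model.modelDescription.inputDescriptionsByName.keys.sorted()
    print("Required inputs: \(requiredInputs)")

    // Assumed feature names and codes. Supply a value for every required
    // input, otherwise `prediction(from:)` throws.
    let features: [String: Any] = ["odor": "n", "gill-size": "b"]
    let input = try MLDictionaryFeatureProvider(dictionary: features)
    let output = try model.prediction(from: input)

    // Read the predicted label from the model's declared output feature.
    if let labelName = model.modelDescription.predictedFeatureName,
       let label = output.featureValue(for: labelName)?.stringValue {
        print("Predicted class: \(label)")
    }
} catch {
    print("Prediction failed: \(error)")
}
```

Inside an app, adding the `.mlmodel` file to an Xcode target generates a typed Swift class for the model, which is usually more convenient than the dictionary-based provider shown here.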
10 | 11 | ## ⚙️ Models 12 | The following models are available as example Playgrounds: 13 | 14 | | Model | Associated Type | Playground | 15 | | --- | --- | :---:| 16 | | 🌅 Image Classifier | [`MLImageClassifier`](https://developer.apple.com/documentation/create_ml/mlimageclassifier) | | 17 | | 🌅 Image Classifier Builder | [`MLImageClassifierBuilder`](https://developer.apple.com/documentation/create_ml/mlimageclassifierbuilder) | [🔗](https://github.com/pdil/ClassifierKit/tree/master/Playgrounds/ImageClassifierBuilder.playground) | 18 | | 📄 Text Classifier | [`MLTextClassifier`](https://developer.apple.com/documentation/create_ml/mltextclassifier) | [🔗](https://github.com/pdil/ClassifierKit/tree/master/Playgrounds/TextClassifier.playground) | 19 | | 🏷️ Word Tagger | [`MLWordTagger`](https://developer.apple.com/documentation/create_ml/mlwordtagger) | [🔗](https://github.com/pdil/ClassifierKit/tree/master/Playgrounds/WordTagger.playground) | 20 | | 📊 Decision Tree Classifier | [`MLDecisionTreeClassifier`](https://developer.apple.com/documentation/create_ml/mldecisiontreeclassifier) | [🔗](https://github.com/pdil/ClassifierKit/tree/master/Playgrounds/DecisionTreeClassifier.playground) | 21 | | 📊 Random Forest Classifier | [`MLRandomForestClassifier`](https://developer.apple.com/documentation/create_ml/mlrandomforestclassifier) | [🔗](https://github.com/pdil/ClassifierKit/tree/master/Playgrounds/RandomForestClassifier.playground) | 22 | | 📊 Boosted Tree Classifier | [`MLBoostedTreeClassifier`](https://developer.apple.com/documentation/create_ml/mlboostedtreeclassifier) | | 23 | | 📊 Logistic Regression Classifier | [`MLLogisticRegressionClassifier`](https://developer.apple.com/documentation/create_ml/mllogisticregressionclassifier) | | 24 | | 📊 Support Vector Classifier | [`MLSupportVectorClassifier`](https://developer.apple.com/documentation/create_ml/mlsupportvectorclassifier) | | 25 | | 📈 Linear Regressor | 
[`MLLinearRegressor`](https://developer.apple.com/documentation/create_ml/mllinearregressor) | | 26 | | 📈 Decision Tree Regressor | [`MLDecisionTreeRegressor`](https://developer.apple.com/documentation/create_ml/mldecisiontreeregressor) | | 27 | | 📈 Boosted Tree Regressor | [`MLBoostedTreeRegressor`](https://developer.apple.com/documentation/create_ml/mlboostedtreeregressor) | | 28 | | 📈 Random Forest Regressor | [`MLRandomForestRegressor`](https://developer.apple.com/documentation/create_ml/mlrandomforestregressor) | | 29 | 30 | **Note:** Some of these are incomplete and are currently being added. The goal is to eventually have comprehensive example playgrounds for each model type in Create ML, including sample data and explanations. See [Project #1](https://github.com/pdil/ClassifierKit/projects/1) to track the progress of these playgrounds. 31 | 32 | ## 📝 Usage 33 | 34 | The easiest way to begin using ClassifierKit is to clone it directly onto your computer. 35 | 36 | 1. Navigate to the desired directory on your local filesystem. 37 | ``` 38 | $ cd Desktop/or/any/other/folder 39 | ``` 40 | 2. Clone this repository: 41 | ``` 42 | $ git clone https://github.com/pdil/ClassifierKit.git 43 | ``` 44 | 3. Begin! The [`Playgrounds`](https://github.com/pdil/ClassifierKit/tree/master/Playgrounds) folder contains Swift Playgrounds for the many models contained within Create ML that will allow you to set parameters and begin training the models, either with the provided sample data or your own data. 45 | 46 | ## 🗃️ References 47 | * [Introducing Create ML · WWDC 2018 · Session 703](https://developer.apple.com/videos/play/wwdc2018/703/) 48 | * [Create ML · Apple Developer Documentation](https://developer.apple.com/documentation/create_ml) 49 | 50 | ### Datasets 51 | * [Mushroom Data Set · UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/Mushroom) 52 | --------------------------------------------------------------------------------