├── .gitignore ├── LICENSE ├── MNIST ├── t10k-images-idx3-ubyte ├── t10k-labels-idx1-ubyte ├── train-images-idx3-ubyte └── train-labels-idx1-ubyte ├── Package.pins ├── Package.swift ├── README.md └── Sources ├── MNISTManager.swift └── main.swift /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | /.build 3 | /Packages 4 | /*.xcodeproj 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Swift-AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MNIST/t10k-images-idx3-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Swift-AI/NeuralNet-MNIST/3d73f8e98eb16e8b6de808b5cceb69e49506e20d/MNIST/t10k-images-idx3-ubyte -------------------------------------------------------------------------------- /MNIST/t10k-labels-idx1-ubyte: -------------------------------------------------------------------------------- 1 | '                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             -------------------------------------------------------------------------------- /MNIST/train-images-idx3-ubyte: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Swift-AI/NeuralNet-MNIST/3d73f8e98eb16e8b6de808b5cceb69e49506e20d/MNIST/train-images-idx3-ubyte -------------------------------------------------------------------------------- /MNIST/train-labels-idx1-ubyte: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Swift-AI/NeuralNet-MNIST/3d73f8e98eb16e8b6de808b5cceb69e49506e20d/MNIST/train-labels-idx1-ubyte -------------------------------------------------------------------------------- /Package.pins: -------------------------------------------------------------------------------- 1 | { 2 | "autoPin": true, 3 | "pins": [ 4 | { 5 | "package": "NeuralNet", 6 | "reason": null, 7 | "repositoryURL": "https://github.com/Swift-AI/NeuralNet.git", 8 | "version": "0.3.0" 9 | } 10 | ], 11 | "version": 1 12 | } -------------------------------------------------------------------------------- /Package.swift: -------------------------------------------------------------------------------- 1 | // swift-tools-version:3.1 2 | 3 | import PackageDescription 4 | 5 | let package = Package( 6 | name: "NeuralNet-MNIST", 7 | dependencies: [ 8 | .Package(url: "https://github.com/Swift-AI/NeuralNet.git", majorVersion: 0, minor: 3) 9 | ] 10 | ) 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Banner](https://github.com/Swift-AI/Swift-AI/blob/master/SiteAssets/banner.png) 2 | 3 | # NeuralNet-MNIST 4 | 5 | An MNIST handwriting training example for the [NeuralNet](https://github.com/Swift-AI/NeuralNet) package. 6 | 7 | This application is part of the Swift AI project. Full details on the project can be found in the [main repo](https://github.com/Swift-AI/Swift-AI). 8 | 9 | ## Installation 10 | 11 | 1. Clone the repository: 12 | 13 | ```sh 14 | git clone https://github.com/Swift-AI/NeuralNet-MNIST.git 15 | ``` 16 | 17 | 2. Generate Xcode project: 18 | 19 | ```sh 20 | swift package generate-xcodeproj 21 | ``` 22 | 23 | ## Training 24 | 25 | ### Setup 26 | 27 | This project comes packaged with training data and a pre-built routine. 
The only thing you need to do is edit the following lines at the top of `main.swift`: 28 | 29 | ```swift 30 | // Path to MNIST dataset directory. 31 | let mnistDataDir = "/PATH/TO/NeuralNet-MNIST/MNIST" 32 | 33 | // Full filepath for trained network output file. 34 | let outputFilepath = "/PATH/TO/neuralnet-mnist-trained" 35 | ``` 36 | 37 | You should set `mnistDataDir` to the absolute path of the training data directory (included in this repository). This is necessary because Swift Package Manager currently doesn't support app bundles. 38 | 39 | `outputFilepath` will be the location where the final, trained network is stored. You may set it to whatever you like. 40 | 41 | ### Run 42 | 43 | Once these paths are set, just hit run and watch! 44 | 45 | ***Always run the trainer in release mode!*** or it will be a long day :) 46 | 47 | ### Customization 48 | 49 | You can customize a number of parameters in the trainer and neural network. The [NeuralNet](https://github.com/Swift-AI/NeuralNet) package contains more information on how to construct a neural net, so we won't go into detail here. 50 | 51 | See the information at the top of `main.swift` for more inspiration. 52 |   53 | ## Data 54 | 55 | The [MNIST dataset](http://yann.lecun.com/exdb/mnist/) is used for training. This includes 70,000 handwriting samples of the digits 0-9. 56 | 57 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /Sources/MNISTManager.swift: -------------------------------------------------------------------------------- 1 | // 2 | // MNISTManager.swift 3 | // NeuralNet-MNIST 4 | // 5 | // Created by Collin Hundley on 4/12/17. 6 | // 7 | // 8 | 9 | import Foundation 10 | 11 | 12 | class MNISTManager { 13 | 14 | /// All data files in the MNIST dataset. 
15 | fileprivate enum File: String { 16 | case trainImages = "train-images-idx3-ubyte" 17 | case trainLabels = "train-labels-idx1-ubyte" 18 | case validationImages = "t10k-images-idx3-ubyte" 19 | case validationLabels = "t10k-labels-idx1-ubyte" 20 | } 21 | 22 | 23 | // MARK: Data caches 24 | 25 | let trainImages: [[[Float]]] 26 | let trainLabels: [[[Float]]] 27 | let validationImages: [[[Float]]] 28 | let validationLabels: [[[Float]]] 29 | 30 | 31 | // MARK: One-hot encoding helper 32 | // Note: This is easy, fast and convenient since MNIST only has 10 classifications 33 | 34 | fileprivate static let labelEncodings: [[Float]] = [ 35 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], 36 | [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], 37 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], 38 | [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], 39 | [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], 40 | [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], 41 | [0, 0, 0, 0, 0, 0, 1, 0, 0, 0], 42 | [0, 0, 0, 0, 0, 0, 0, 1, 0, 0], 43 | [0, 0, 0, 0, 0, 0, 0, 0, 1, 0], 44 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 1] 45 | ] 46 | 47 | 48 | // MARK: Initialization 49 | 50 | /// Initializes an instance of `MNISTManager` and caches the full dataset for quick access. 51 | /// 52 | /// - Parameter directory: The filepath leading to the MNIST data directory. 53 | /// Example: "/Users/Bob/Documents/NeuralNet-MNIST/MNIST/" 54 | /// This is necessary because SwiftPM does not yet support bundles. 55 | /// - Parameter pixelRange: A range [Float, Float] to scale image pixels. 56 | /// The desired range will depend on the activation functions used in the neural network. 57 | /// For example, a network with Logistic hidden activation might want to scale pixels to [-1, 1]. 
58 | init(directory: String, pixelRange: (min: Float, max: Float), batchSize: Int) throws { 59 | // Cache training images 60 | trainImages = try MNISTManager.extractImages(from: .trainImages, directory: directory, range: pixelRange, batchSize: batchSize) 61 | // Cache training labels 62 | trainLabels = try MNISTManager.extractLabels(from: .trainLabels, directory: directory, batchSize: batchSize) 63 | // Cache validation images 64 | validationImages = try MNISTManager.extractImages(from: .validationImages, directory: directory, range: pixelRange, batchSize: batchSize) 65 | // Cache validation labels 66 | validationLabels = try MNISTManager.extractLabels(from: .validationLabels, directory: directory, batchSize: batchSize) 67 | } 68 | 69 | } 70 | 71 | 72 | // MARK: Data and file management 73 | 74 | private extension MNISTManager { 75 | 76 | /// Extracts image data from the given file, with all bytes scaled to the given range. 77 | static func extractImages(from file: File, directory: String, range: (min: Float, max: Float), batchSize: Int) throws -> [[[Float]]] { 78 | /// Scales a byte to the correct range. 79 | func scale(x: UInt8) -> Float { 80 | return (range.max - range.min) * Float(x) / 255 + range.min 81 | } 82 | // Read data from file and drop header data 83 | let url = URL(fileURLWithPath: directory).appendingPathComponent(file.rawValue) 84 | let data = try readFile(url: url).dropFirst(16) 85 | // Convert UInt8 array to Float array, scaled to the specified range 86 | let array = data.map{scale(x: $0)} 87 | // Split array into segments of length 784 (1 image = 28x28 pixels) 88 | return createBatches(stride(from: 0, to: array.count, by: 784).map{Array(array[$0.. 
[[[Float]]] { 94 | // Read data from file and drop header data 95 | let url = URL(fileURLWithPath: directory).appendingPathComponent(file.rawValue) 96 | let data = try readFile(url: url).dropFirst(8) 97 | // Lookup one-hot encodings in our key 98 | return createBatches(data.map{labelEncodings[Int($0)]}, size: batchSize) 99 | } 100 | 101 | /// Attempts to read the file with the given path, and returns its raw data. 102 | private static func readFile(url: URL) throws -> Data { 103 | return try Data(contentsOf: url) 104 | } 105 | 106 | /// Groups the given set of data into batches of the specified size. 107 | private static func createBatches(_ set: [[Float]], size: Int) -> [[[Float]]] { 108 | var output = [[[Float]]]() 109 | let numBatches = set.count / size 110 | for batchIdx in 0.. Bool in 83 | 84 | // Log progress 85 | let percCorrect = (1 - err) * 100 86 | let percError = err * 100 87 | print("\nEpoch \(epoch)") 88 | print("Accuracy:\t\(percCorrect)%") 89 | print("Error:\t\t\(percError)%") 90 | 91 | // Decay learning rate and momentum 92 | nn.learningRate *= 0.97 93 | nn.momentumFactor *= 0.97 94 | 95 | // Allow training to continue 96 | return true 97 | } 98 | 99 | // Save net to disk and log result 100 | try nn.save(to: URL(fileURLWithPath: outputFilepath)) 101 | print("\n--------------------------- DONE ---------------------------") 102 | print("\nFinal accuracy: \((1 - error) * 100)%") 103 | print("Trained network stored at: \(outputFilepath)\n\n") 104 | 105 | } catch { 106 | print(error) 107 | } 108 | 109 | --------------------------------------------------------------------------------