├── .gitignore ├── DigitRecognizer ├── Dataset │ ├── test_dataset.zip │ ├── train_dataset.zip │ └── train_expanded_dataset.zip ├── DigitRecognizer.Core │ ├── Data │ │ ├── BoundingBox.cs │ │ ├── Box.cs │ │ ├── LinkedList.cs │ │ ├── LinkedListNode.cs │ │ └── MnistImageBatch.cs │ ├── DigitRecognizer.Core.csproj │ ├── Extensions │ │ ├── ByteExtensions.cs │ │ ├── DoubleExtensions.cs │ │ ├── ImageExtensions.cs │ │ ├── RandomExtensions.cs │ │ ├── StreamExtensions.cs │ │ └── VectorExtensions.cs │ ├── IO │ │ ├── ILabelReader.cs │ │ ├── IPixelReader.cs │ │ ├── ImagePreprocessor.cs │ │ ├── LabelReader.cs │ │ ├── MemoryStreamReader.cs │ │ ├── NnBinaryDeserializer.cs │ │ ├── NnBinarySerializer.cs │ │ ├── NnSerializableBase.cs │ │ ├── NnSerializationContext.cs │ │ ├── NnSerializationContextInfo.cs │ │ └── PixelReader.cs │ ├── Properties │ │ └── AssemblyInfo.cs │ └── Utilities │ │ ├── ConsoleUtility.cs │ │ ├── Contracts.cs │ │ ├── DirectoryHelper.cs │ │ ├── ImageUtilities.cs │ │ ├── MathUtilities.cs │ │ └── VectorUtilities.cs ├── DigitRecognizer.DatasetExpansion │ ├── Api │ │ ├── DatasetExpander.cs │ │ └── DatasetSerializer.cs │ ├── App.config │ ├── DigitRecognizer.DatasetExpansion.csproj │ ├── Infrastructure │ │ ├── AffineTransformation.cs │ │ ├── DatasetProvider.cs │ │ ├── FisherYatesShuffle.cs │ │ └── MnistImage.cs │ ├── Program.cs │ ├── Properties │ │ └── AssemblyInfo.cs │ └── favico.ico ├── DigitRecognizer.Engine │ ├── App.config │ ├── DigitRecognizer.Engine.csproj │ ├── Program.cs │ ├── Properties │ │ └── AssemblyInfo.cs │ └── favico.ico ├── DigitRecognizer.MachineLearning │ ├── DigitRecognizer.MachineLearning.csproj │ ├── Infrastructure │ │ ├── Data │ │ │ └── CalculationCache.cs │ │ ├── Dropout │ │ │ ├── BinomialDistribution.cs │ │ │ └── Dropout.cs │ │ ├── Factories │ │ │ ├── AbstractTypeFactory.cs │ │ │ ├── FunctionFactory.cs │ │ │ └── InitializerFactory.cs │ │ ├── Functions │ │ │ ├── CrossEntropy.cs │ │ │ ├── ExponentialRelu.cs │ │ │ ├── IActivationFunction.cs │ │ │ ├── 
ICostFunction.cs │ │ │ ├── IDifferentiableFunction.cs │ │ │ ├── IFunction.cs │ │ │ ├── LeakyRelu.cs │ │ │ ├── MeanSquareError.cs │ │ │ ├── Relu.cs │ │ │ ├── Sigmoid.cs │ │ │ ├── Softmax.cs │ │ │ ├── Softplus.cs │ │ │ └── Tanh.cs │ │ ├── Initialization │ │ │ ├── HeInitializer.cs │ │ │ ├── IInitializer.cs │ │ │ ├── InitializerType.cs │ │ │ ├── RandomInitializer.cs │ │ │ ├── XavierInitializer.cs │ │ │ └── ZeroInitializer.cs │ │ ├── Models │ │ │ ├── ClusterPredictionModel.cs │ │ │ ├── IPredictionModel.cs │ │ │ └── PredictionModel.cs │ │ └── NeuralNetwork │ │ │ ├── BiasVector.cs │ │ │ ├── INeuralNetwork.cs │ │ │ ├── IValueAdjustable.cs │ │ │ ├── NeuralNetwork.cs │ │ │ ├── NnLayer.cs │ │ │ └── WeightMatrix.cs │ ├── Optimization │ │ ├── LearningRateDecay │ │ │ ├── ExponentailDecay.cs │ │ │ ├── ILearningRateDecay.cs │ │ │ ├── StepDecay.cs │ │ │ └── TimeBasedDecay.cs │ │ └── Optimizers │ │ │ ├── BaseOptimizer.cs │ │ │ ├── GradientDescentOptimizer.cs │ │ │ ├── IOptimizer.cs │ │ │ └── MomentumOptimizer.cs │ ├── Pipeline │ │ ├── ILearningPipelineDataLoader.cs │ │ ├── ILearningPipelineItem.cs │ │ ├── ILearningPipelineNeuralNetworkModel.cs │ │ ├── ILearningPipelineOptimizer.cs │ │ ├── LearningPipeline.cs │ │ ├── PipelineExtensions.cs │ │ └── PipelineSettings.cs │ ├── Properties │ │ └── AssemblyInfo.cs │ ├── Providers │ │ ├── BatchDataProvider.cs │ │ ├── DataProviderBase.cs │ │ └── IDataProvider.cs │ └── Serialization │ │ ├── INnSerializable.cs │ │ ├── NnDeserializer.cs │ │ └── NnSerializer.cs ├── DigitRecognizer.Presentation │ ├── App.config │ ├── Components │ │ ├── ImageGrid.Designer.cs │ │ ├── ImageGrid.cs │ │ ├── ImageGrid.resx │ │ ├── PredictionPane.Designer.cs │ │ ├── PredictionPane.cs │ │ └── PredictionPane.resx │ ├── DigitRecognizer.Presentation.csproj │ ├── Global.cs │ ├── Infrastructure │ │ ├── DependencyResolver.cs │ │ ├── ExceptionHandlers.cs │ │ ├── PanelDoubleBuffering.cs │ │ └── PredictionModelLoader.cs │ ├── Models │ │ └── ImageGridModel.cs │ ├── Presenters │ │ 
├── ApplicationPresenter.cs │ │ ├── BenchmarkPresenter.cs │ │ ├── DrawingPresenter.cs │ │ ├── SlidingWindowPresenter.cs │ │ └── UploadImagePresenter.cs │ ├── Program.cs │ ├── Properties │ │ ├── AssemblyInfo.cs │ │ ├── Resources.Designer.cs │ │ ├── Resources.resx │ │ ├── Settings.Designer.cs │ │ └── Settings.settings │ ├── Services │ │ ├── ILoggingService.cs │ │ ├── IMessageService.cs │ │ ├── LoggingService.cs │ │ └── MessageService.cs │ ├── Startup.cs │ ├── Views │ │ ├── Implementations │ │ │ ├── BenchmarkView.Designer.cs │ │ │ ├── BenchmarkView.cs │ │ │ ├── BenchmarkView.resx │ │ │ ├── DrawingView.Designer.cs │ │ │ ├── DrawingView.cs │ │ │ ├── DrawingView.resx │ │ │ ├── MainForm.Designer.cs │ │ │ ├── MainForm.cs │ │ │ ├── MainForm.resx │ │ │ ├── SlidingWindowView.Designer.cs │ │ │ ├── SlidingWindowView.cs │ │ │ ├── SlidingWindowView.resx │ │ │ ├── UploadImageView.Designer.cs │ │ │ ├── UploadImageView.cs │ │ │ └── UploadImageView.resx │ │ └── Interfaces │ │ │ ├── IBenchmarkView.cs │ │ │ ├── IDrawingView.cs │ │ │ ├── IMainFormView.cs │ │ │ ├── ISlidingWindowView.cs │ │ │ ├── IUploadImageView.cs │ │ │ └── IView.cs │ └── favico.ico ├── DigitRecognizer.sln ├── Images │ ├── favico.ico │ ├── left_arrow.png │ └── right_arrow.png └── Models │ ├── 1f15fd63-8d73-4128-9d27-0090df8ac1ba-0.9849.nn │ ├── 23682b7e-19b6-4dc2-95f6-563d13c35ffe-0.9841.nn │ ├── 5bbc9588-713f-40f4-8d07-e349a089e0b2-0.9845.nn │ ├── 780baa22-c1c6-44ec-b13c-e867dba48010-0.9845.nn │ ├── 8edc3712-2eba-4f2c-879b-1a280687f9b4-0.9837.nn │ ├── b7d29bbc-a55c-4f26-9655-db392a504105-0.9845.nn │ ├── c0a9a2e0-0cca-4943-8755-068330811b1d-0.9827.nn │ ├── c9234ebb-c75b-42c1-88ad-b75650d7102a-0.9857.nn │ └── ffed716d-a8a3-48fd-b0eb-98900357be81-0.9823.nn ├── LICENSE └── README.md /DigitRecognizer/Dataset/test_dataset.zip: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Dataset/test_dataset.zip
--------------------------------------------------------------------------------
/DigitRecognizer/Dataset/train_dataset.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Dataset/train_dataset.zip
--------------------------------------------------------------------------------
/DigitRecognizer/Dataset/train_expanded_dataset.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Dataset/train_expanded_dataset.zip
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Data/BoundingBox.cs:
--------------------------------------------------------------------------------
1 | using System.Drawing;
2 |
3 | namespace DigitRecognizer.Core.Data
4 | {
5 |     /// <summary>
6 |     /// Pairs an <see cref="System.Drawing.Image"/> with an integer (<see cref="X"/>, <see cref="Y"/>) position.
7 |     /// </summary>
8 |     /// <remarks>
9 |     /// NOTE(review): presumably the image is a region cut out of a larger image and (X, Y) is its
10 |     /// offset within that image; confirm against the callers before relying on it.
11 |     /// </remarks>
12 |     public class BoundingBox
13 |     {
14 |         /// <summary>
15 |         /// Gets or sets the image.
16 |         /// </summary>
17 |         public Image Image { get; set; }
18 |
19 |         /// <summary>
20 |         /// Gets or sets the X coordinate.
21 |         /// </summary>
22 |         public int X { get; set; }
23 |
24 |         /// <summary>
25 |         /// Gets or sets the Y coordinate.
26 |         /// </summary>
27 |         public int Y { get; set; }
28 |     }
29 | }
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Data/Box.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.Core.Data
2 | {
3 |     /// <summary>
4 |     /// Represents a box described by its top, bottom, left and right values.
5 |     /// </summary>
6 |     public struct Box
7 |     {
8 |         private int _top;
9 |         private int _bottom;
10 |         private int _left;
11 |         private int _right;
12 |
13 |         /// <summary>
14 |         /// Initializes a new instance of the <see cref="Box"/> struct (arguments in top, bottom, left, right order).
15 |         /// </summary>
16 |         /// <param name="top">The top value.</param>
17 |         /// <param name="bottom">The bottom value.</param>
18 |         /// <param name="left">The left value.</param>
19 |         /// <param name="right">The right value.</param>
20 |         public Box(int top, int bottom, int left, int right)
21 |         {
22 |             _top = top;
23 |             _bottom = bottom;
24 |             _left = left;
25 |             _right = right;
26 |         }
27 |
28 |         /// <summary>
29 |         /// Gets or sets the top.
30 |         /// </summary>
31 |         public int Top
32 |         {
33 |             get => _top;
34 |             set => _top = value;
35 |         }
36 |
37 |         /// <summary>
38 |         /// Gets or sets the bottom.
39 |         /// </summary>
40 |         public int Bottom
41 |         {
42 |             get => _bottom;
43 |             set => _bottom = value;
44 |         }
45 |
46 |         /// <summary>
47 |         /// Gets or sets the left.
48 |         /// </summary>
49 |         public int Left
50 |         {
51 |             get => _left;
52 |             set => _left = value;
53 |         }
54 |
55 |         /// <summary>
56 |         /// Gets or sets the right.
57 |         /// </summary>
58 |         public int Right
59 |         {
60 |             get => _right;
61 |             set => _right = value;
62 |         }
63 |     }
64 | }
65 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Data/LinkedListNode.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Utilities;
2 |
3 | namespace DigitRecognizer.Core.Data
4 | {
5 |     /// <summary>
6 |     /// Doubly linked list node of a generic type.
7 |     /// </summary>
8 |     /// <typeparam name="T">The type of the value stored in the node.</typeparam>
9 |     public class LinkedListNode<T>
10 |     {
11 |         private LinkedListNode<T> _previous;
12 |         private LinkedListNode<T> _next;
13 |         private readonly T _item;
14 |         private int _depth;
15 |
16 |         /// <summary>
17 |         /// Initializes a new instance of the <see cref="LinkedListNode{T}"/> class.
18 |         /// </summary>
19 |         /// <param name="item">The value stored in the node.</param>
20 |         public LinkedListNode(T item)
21 |         {
22 |             _item = item;
23 |         }
24 |
25 |         /// <summary>
26 |         /// Gets or sets the previous node; <c>null</c> until it is assigned.
27 |         /// </summary>
28 |         public LinkedListNode<T> Previous
29 |         {
30 |             get => _previous;
31 |             set => _previous = value;
32 |         }
33 |
34 |         /// <summary>
35 |         /// Gets or sets the next node; <c>null</c> until it is assigned.
36 |         /// </summary>
37 |         public LinkedListNode<T> Next
38 |         {
39 |             get => _next;
40 |             set => _next = value;
41 |         }
42 |
43 |         /// <summary>
44 |         /// Gets the value of the node, which is fixed at construction.
45 |         /// </summary>
46 |         public T Value => _item;
47 |
48 |         /// <summary>
49 |         /// Gets or sets the depth of the node. The setter validates the value with <c>Contracts.ValueGreaterThanZero</c>.
50 |         /// </summary>
51 |         public int Depth
52 |         {
53 |             get => _depth;
54 |             set
55 |             {
56 |                 Contracts.ValueGreaterThanZero(value, nameof(value));
57 |
58 |                 _depth = value;
59 |             }
60 |         }
61 |     }
62 | }
63 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Data/MnistImageBatch.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using DigitRecognizer.Core.Utilities;
3 |
4 | namespace DigitRecognizer.Core.Data
5 | {
6 |     /// <summary>
7 |     /// Represents a batch of MNIST images together with their labels.
8 |     /// </summary>
9 |     public class MnistImageBatch
10 |     {
11 |         /// <summary>
12 |         /// Gets the labels of the images in the batch.
13 |         /// </summary>
14 |         public int[] Labels { get; }
15 |
16 |         /// <summary>
17 |         /// Gets the pixels of the images in the batch; has the same length as <see cref="Labels"/>.
18 |         /// </summary>
19 |         public double[][] Pixels { get; }
20 |
21 |         /// <summary>
22 |         /// Initializes a new instance of the <see cref="MnistImageBatch"/> class.
23 |         /// </summary>
24 |         /// <param name="labels">The labels of the batch.</param>
25 |         /// <param name="pixels">The pixels of the batch; its length is checked against <paramref name="labels"/>.</param>
26 |         /// <exception cref="ArgumentNullException"><paramref name="labels"/> or <paramref name="pixels"/> is <c>null</c>.</exception>
27 |         public MnistImageBatch(int[] labels, double[][] pixels)
28 |         {
29 |             if (labels == null)
30 |             {
31 |                 throw new ArgumentNullException(nameof(labels));
32 |             }
33 |
34 |             if (pixels == null)
35 |             {
36 |                 throw new ArgumentNullException(nameof(pixels));
37 |             }
38 |
39 |             // Every image needs exactly one label; the length check is delegated to Contracts.ValuesMatch.
40 |             Contracts.ValuesMatch(labels.Length, pixels.Length, nameof(labels.Length));
41 |
42 |             Labels = labels;
43 |             Pixels = pixels;
44 |         }
45 |     }
46 | }
47 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/DigitRecognizer.Core.csproj:
--------------------------------------------------------------------------------
1 | ﻿
2 |
3 |
4 |
5 | Debug
6 | AnyCPU
7 | {BA726236-C7EC-41CA-ABB6-E4FE7E784939}
8 | Library
9 | Properties
10 | DigitRecognizer.Core
11 | DigitRecognizer.Core
12 | v4.8
13 | 512
14 |
15 |
16 |
17 | true
18 | full
19 | false
20 | bin\Debug\
21 | DEBUG;TRACE
22 | prompt
23 | 4
24 | true
25 |
26 |
27 | pdbonly
28 | true
29 | bin\Release\
30 | TRACE
31 | prompt
32 | 4
33 | true
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Extensions/ByteExtensions.cs:
--------------------------------------------------------------------------------
1 | using System;
2 |
3 | namespace DigitRecognizer.Core.Extensions
4 | {
5 |     /// <summary>
6 |     /// Provides extension methods for working with the <see cref="byte"/> struct.
7 |     /// </summary>
8 |     public static class ByteExtensions
9 |     {
10 |         /// <summary>
11 |         /// Converts the specified byte array to an array of doubles, utilizing the <see cref="Buffer"/> class.
12 |         /// Every <c>sizeof(double)</c> consecutive bytes are reinterpreted, in native byte order, as one
13 |         /// <see cref="double"/>, so this reverses <see cref="DoubleExtensions.ToBytes"/>.
14 |         /// </summary>
15 |         /// <param name="array">The array of bytes that will be converted to doubles.</param>
16 |         /// <returns>A double array, or <c>null</c> when <paramref name="array"/> is empty.</returns>
17 |         /// <exception cref="ArgumentNullException"><paramref name="array"/> is <c>null</c>.</exception>
18 |         /// <exception cref="ArgumentException">
19 |         /// The length of <paramref name="array"/> is not a multiple of <c>sizeof(double)</c>.
20 |         /// </exception>
21 |         public static double[] ToDoubles(this byte[] array)
22 |         {
23 |             if (array == null)
24 |             {
25 |                 throw new ArgumentNullException(nameof(array));
26 |             }
27 |
28 |             if (array.Length == 0)
29 |             {
30 |                 // Existing contract: an empty input yields null rather than an empty array.
31 |                 return default(double[]);
32 |             }
33 |
34 |             // A trailing partial double cannot be represented. Without this check Buffer.BlockCopy
35 |             // fails with an unhelpful ArgumentException because the destination array is too small.
36 |             if (array.Length % sizeof(double) != 0)
37 |             {
38 |                 throw new ArgumentException(
39 |                     $"The array length ({array.Length}) must be a multiple of {sizeof(double)}.",
40 |                     nameof(array));
41 |             }
42 |
43 |             var result = new double[array.Length / sizeof(double)];
44 |             Buffer.BlockCopy(array, 0, result, 0, array.Length);
45 |             return result;
46 |         }
47 |     }
48 | }
49 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Extensions/DoubleExtensions.cs:
--------------------------------------------------------------------------------
1 | using System;
2 |
3 | namespace DigitRecognizer.Core.Extensions
4 | {
5 |     /// <summary>
6 |     /// Provides extension methods for working with the <see cref="double"/> struct.
7 |     /// </summary>
8 |     public static class DoubleExtensions
9 |     {
10 |         private const double DoubleNorm = 255.0D;
11 |         private const double DoubleOffsetNorm = 127.5D;
12 |
13 |         /// <summary>
14 |         /// Returns a double in the range [0, 1], parsed from the specified byte.
15 |         /// </summary>
16 |         /// <param name="value">The value that will be parsed then normalized.</param>
17 |         /// <returns>A double in the range [0, 1]; 0 maps to 0 and 255 maps to 1.</returns>
18 |         public static double FromByteNormalized(byte value)
19 |         {
20 |             return value / DoubleNorm;
21 |         }
22 |
23 |         /// <summary>
24 |         /// Returns a double in the range [-1, 1], parsed from the specified byte.
25 |         /// </summary>
26 |         /// <param name="value">The value that will be parsed then normalized.</param>
27 |         /// <returns>A double in the range [-1, 1]; 0 maps to -1 and 255 maps to 1.</returns>
28 |         public static double FromByteOffset(byte value)
29 |         {
30 |             // value / 127.5 spans [0, 2]; subtracting 1 centres the result on zero.
31 |             return value / DoubleOffsetNorm - 1;
32 |         }
33 |
34 |         /// <summary>
35 |         /// Gets the bytes of the double array, in the machine's native byte order.
36 |         /// </summary>
37 |         /// <param name="array">The array of doubles that will be converted to bytes.</param>
38 |         /// <returns>A byte array of length <c>array.Length * sizeof(double)</c>.</returns>
39 |         /// <remarks>
40 |         /// Equivalent to <see cref="ToBytes"/>: converting element by element with
41 |         /// <see cref="BitConverter.GetBytes(double)"/> yields the same native-order bytes as a bulk
42 |         /// <see cref="Buffer.BlockCopy"/>, so this delegates to the bulk copy.
43 |         /// </remarks>
44 |         public static byte[] GetBytes(this double[] array)
45 |         {
46 |             return array.ToBytes();
47 |         }
48 |
49 |         /// <summary>
50 |         /// Gets the bytes of the double array, utilizing the <see cref="Buffer"/> class.
51 |         /// </summary>
52 |         /// <param name="array">The array of doubles that will be converted to bytes.</param>
53 |         /// <returns>A byte array of length <c>array.Length * sizeof(double)</c>, in native byte order.</returns>
54 |         public static byte[] ToBytes(this double[] array)
55 |         {
56 |             var result = new byte[array.Length * sizeof(double)];
57 |
58 |             // Buffer.BlockCopy counts bytes rather than elements, hence result.Length as the count.
59 |             Buffer.BlockCopy(array, 0, result, 0, result.Length);
60 |
61 |             return result;
62 |         }
63 |     }
64 | }
65 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Extensions/ImageExtensions.cs:
--------------------------------------------------------------------------------
1 | using System.Drawing;
2 | using DigitRecognizer.Core.Utilities;
3 |
4 | namespace DigitRecognizer.Core.Extensions
5 | {
6 |     /// <summary>
7 |     /// Contains extension methods for the <see cref="Image"/> class.
8 |     /// </summary>
9 |     public static class ImageExtensions
10 |     {
11 |         /// <summary>
12 |         /// Resizes the image to the specified width and height by delegating to <c>ImageUtilities.Resize</c>.
13 |         /// </summary>
14 |         /// <param name="image">The image being resized.</param>
15 |         /// <param name="width">The new width in pixels.</param>
16 |         /// <param name="height">The new height in pixels.</param>
17 |         /// <returns>The resized image.</returns>
18 |         public static Image Resize(this Image image, int width, int height)
19 |         {
20 |             return ImageUtilities.Resize(image, width, height);
21 |         }
22 |     }
23 | }
24 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Extensions/RandomExtensions.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using DigitRecognizer.Core.Utilities;
3 |
4 | namespace DigitRecognizer.Core.Extensions
5 | {
6 |     /// <summary>
7 |     /// Contains extension methods for the <see cref="Random"/> class.
8 |     /// </summary>
9 |     public static class RandomExtensions
10 |     {
11 |         /// <summary>
12 |         /// Returns an array of random floating-point numbers that are greater than or equal to 0.0 and less than 1.0.
13 |         /// </summary>
14 |         /// <param name="rnd">The random number generator.</param>
15 |         /// <param name="length">The length of the array; validated with <c>Contracts.ValueGreaterThanZero</c>.</param>
16 |         /// <returns>The array filled with random numbers, drawn in order from <paramref name="rnd"/>.</returns>
17 |         /// <exception cref="ArgumentNullException"><paramref name="rnd"/> is <c>null</c>.</exception>
18 |         public static double[] NextDoubles(this Random rnd, int length)
19 |         {
20 |             Contracts.ValueGreaterThanZero(length, nameof(length));
21 |
22 |             if (rnd == null)
23 |             {
24 |                 throw new ArgumentNullException(nameof(rnd));
25 |             }
26 |
27 |             var result = new double[length];
28 |
29 |             for (var i = 0; i < length; i++)
30 |             {
31 |                 result[i] = rnd.NextDouble();
32 |             }
33 |
34 |             return result;
35 |         }
36 |     }
37 | }
38 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Extensions/StreamExtensions.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.IO;
3 |
4 | namespace DigitRecognizer.Core.Extensions
5 | {
6 |     /// <summary>
7 |     /// Contains extension methods for working with the <see cref="Stream"/> class.
8 |     /// </summary>
9 |     public static class StreamExtensions
10 |     {
11 |         /// <summary>
12 |         /// Initializes a <see cref="MemoryStream"/> object from the specified file.
13 |         /// </summary>
14 |         /// <param name="filename">The path of the file whose contents will be loaded into memory.</param>
15 |         /// <returns>
16 |         /// A <see cref="MemoryStream"/> containing the whole file, positioned at its start,
17 |         /// or <c>null</c> if the file does not exist.
18 |         /// </returns>
19 |         public static MemoryStream GetMemoryStreamFromFile(string filename)
20 |         {
21 |             if (!File.Exists(filename))
22 |             {
23 |                 return default(MemoryStream);
24 |             }
25 |
26 |             var memoryStream = new MemoryStream();
27 |
28 |             using (FileStream fileStream = File.OpenRead(filename))
29 |             {
30 |                 // A single Stream.Read call may return fewer bytes than requested (leaving the rest of a
31 |                 // preallocated buffer zero-filled); CopyTo keeps reading until the end of the file.
32 |                 fileStream.CopyTo(memoryStream);
33 |             }
34 |
35 |             // Rewind so that callers start reading from the first byte of the file.
36 |             memoryStream.Position = 0;
37 |
38 |             return memoryStream;
39 |         }
40 |
41 |         /// <summary>
42 |         /// Sets the position within the current stream to the specified value.
43 |         /// Does nothing if the stream does not support seeking.
44 |         /// </summary>
45 |         /// <param name="stream">The <see cref="Stream"/> in which we are setting the position.</param>
46 |         /// <param name="position">The position, measured from the beginning of the stream.</param>
47 |         public static void SetPosition(this Stream stream, int position)
48 |         {
49 |             if (!stream.CanSeek)
50 |             {
51 |                 return;
52 |             }
53 |
54 |             stream.Seek(position, SeekOrigin.Begin);
55 |         }
56 |
57 |         /// <summary>
58 |         /// Sets the position within the current stream to the beginning of the stream.
59 |         /// Does nothing if the stream does not support seeking.
60 |         /// </summary>
61 |         /// <param name="stream">The <see cref="Stream"/> that we are resetting.</param>
62 |         public static void Reset(this Stream stream)
63 |         {
64 |             stream.SetPosition(0);
65 |         }
66 |
67 |         /// <summary>
68 |         /// Reads the specified number of bytes from the stream, and advances the current position of the stream.
69 |         /// </summary>
70 |         /// <param name="stream">The <see cref="Stream"/> that we are reading from.</param>
71 |         /// <param name="count">The number of bytes to read from the stream.</param>
72 |         /// <returns>
73 |         /// An array of <paramref name="count"/> bytes, or fewer if the end of the stream is reached first.
74 |         /// An empty array is returned when <paramref name="count"/> is negative or exceeds the stream length.
75 |         /// </returns>
76 |         public static byte[] ReadBytes(this Stream stream, int count)
77 |         {
78 |             if (count < 0 || count > stream.Length)
79 |             {
80 |                 return new byte[0];
81 |             }
82 |
83 |             var buffer = new byte[count];
84 |             var totalRead = 0;
85 |
86 |             // Stream.Read may legitimately return fewer bytes than requested before the end of the
87 |             // stream, so keep reading until the request is filled or Read returns 0 (end of stream).
88 |             while (totalRead < count)
89 |             {
90 |                 int bytesRead = stream.Read(buffer, totalRead, count - totalRead);
91 |                 if (bytesRead == 0)
92 |                 {
93 |                     break;
94 |                 }
95 |
96 |                 totalRead += bytesRead;
97 |             }
98 |
99 |             if (totalRead < count)
100 |             {
101 |                 // End of stream reached early: return only the bytes that were actually read.
102 |                 Array.Resize(ref buffer, totalRead);
103 |             }
104 |
105 |             return buffer;
106 |         }
107 |     }
108 | }
109 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/IO/ILabelReader.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.Core.IO
2 | {
3 |     /// <summary>
4 |     /// Interface that should be implemented by a label provider class.
5 |     /// </summary>
6 |     public interface ILabelReader
7 |     {
8 |         /// <summary>
9 |         /// Reads a single label.
10 |         /// </summary>
11 |         /// <returns>A label.</returns>
12 |         int ReadLabel();
13 |
14 |         /// <summary>
15 |         /// Reads the specified number of labels.
16 |         /// </summary>
17 |         /// <param name="count">The number of labels to read.</param>
18 |         /// <returns>An array of labels.</returns>
19 |         int[] ReadLabels(int count);
20 |     }
21 | }
22 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/IO/IPixelReader.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.Core.IO
2 | {
3 |     /// <summary>
4 |     /// Interface that should be implemented by a pixel provider class.
5 |     /// </summary>
6 |     public interface IPixelReader
7 |     {
8 |         /// <summary>
9 |         /// Reads the specified number of pixels.
10 |         /// </summary>
11 |         /// <param name="count">The number of pixels to read.</param>
12 |         /// <returns>An array of pixel values.</returns>
13 |         double[] ReadPixels(int count);
14 |
15 |         /// <summary>
16 |         /// Reads the specified number of blocks of pixels, each of the specified block size.
17 |         /// </summary>
18 |         /// <param name="count">The number of blocks to read.</param>
19 |         /// <param name="blockSize">The number of pixels in each block.</param>
20 |         /// <returns>A jagged array with one inner array per block.</returns>
21 |         double[][] ReadPixels(int count, int blockSize);
22 |     }
23 | }
24 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/IO/ImagePreprocessor.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Drawing;
3 | using DigitRecognizer.Core.Data;
4 | using DigitRecognizer.Core.Utilities;
5 |
6 | namespace DigitRecognizer.Core.IO
7 | {
8 |     /// <summary>
9 |     /// Represents a preprocessor for images that will be run through the neural network.
10 |     /// </summary>
11 |     public class ImagePreprocessor
12 |     {
13 |         /// <summary>
14 |         /// The width and height, in pixels, of the square image produced by the pipeline.
15 |         /// </summary>
16 |         private const int ImageSizeInPixels = 28;
17 |
18 |         /// <summary>
19 |         /// Preprocesses the specified image so that it can be fed through a neural network.
20 |         /// The image has to be an RGB or grayscale image with 0 being black, and 255 being white.
21 |         /// </summary>
22 |         /// <param name="image">The image being preprocessed.</param>
23 |         /// <returns>The flattened and clamped pixels of the preprocessed image.</returns>
24 |         /// <remarks>
25 |         /// First, the image is converted to grayscale, resized to 28x28 pixels, and a threshold is applied.
26 |         /// The box surrounding the pixels that make up the digit is then determined.
27 |         /// The image is scaled to that box, and the padding coordinates are extracted.
28 |         /// The image is then padded back to the required size and inverted.
29 |         /// Next, the center of mass is calculated and the image is translated so that the center of mass
30 |         /// lands on the center of the image, which shifts the digit towards the middle.
31 |         /// Finally, another threshold is applied, and the pixels are flattened and clamped.
32 |         /// NOTE(review): none of the intermediate <see cref="Image"/> instances are disposed. If every
33 |         /// <c>ImageUtilities</c> call returns a new image (not verifiable from this file), wrapping them in
34 |         /// <c>using</c> blocks would release their GDI+ resources sooner.
35 |         /// </remarks>
36 |         public double[] Preprocess(Image image)
37 |         {
38 |             Image grayscale = ImageUtilities.Grayscale(image);
39 |
40 |             Image resized = ImageUtilities.Resize(grayscale, ImageSizeInPixels, ImageSizeInPixels);
41 |
42 |             Image thresholded = ImageUtilities.Threshold(resized, 128, 0, 255);
43 |
44 |             Box digitBox = ImageUtilities.DetermineBox(thresholded);
45 |
46 |             (Image scaled, Box padding) = ImageUtilities.ScaleToBoxAndGetPaddingCoords(thresholded, digitBox);
47 |
48 |             Image padded = ImageUtilities.Pad(scaled, padding);
49 |
50 |             Image inverted = ImageUtilities.Invert(padded);
51 |
52 |             Point centerOfMass = ImageUtilities.CalculateCenterOfMass(inverted);
53 |
54 |             // Offsets that move the center of mass onto the geometric center of the image.
55 |             var shiftX = (int)Math.Round(ImageSizeInPixels / 2.0 - centerOfMass.X);
56 |             var shiftY = (int)Math.Round(ImageSizeInPixels / 2.0 - centerOfMass.Y);
57 |
58 |             Image centered = ImageUtilities.Translate(inverted, shiftX, shiftY);
59 |
60 |             Image final = ImageUtilities.Threshold(centered, 154, 0, 255, true);
61 |
62 |             return ImageUtilities.FlattenAndClamp(final);
63 |         }
64 |     }
65 | }
66 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/IO/LabelReader.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.IO;
3 |
4 | namespace DigitRecognizer.Core.IO
5 | {
6 |     /// <summary>
7 |     /// Provides methods for reading labels from a file, via <see cref="MemoryStreamReader"/>.
8 |     /// </summary>
9 |     public class LabelReader : MemoryStreamReader, ILabelReader
10 |     {
11 |         /// <summary>
12 |         /// A "magic number" of bytes that needs to be skipped at the beginning of the stream
13 |         /// (presumably the IDX file header; confirm against the dataset format).
14 |         /// </summary>
15 |         public const int InitialOffset = 8;
16 |
17 |         /// <summary>
18 |         /// Initializes a new instance of the <see cref="LabelReader"/> class.
19 |         /// </summary>
20 |         /// <param name="filename">The file path used for instantiating an internal stream object.</param>
21 |         public LabelReader(string filename) : base(filename, InitialOffset)
22 |         {
23 |         }
24 |
25 |         /// <summary>
26 |         /// Reads a single label from the file; each label occupies one byte.
27 |         /// </summary>
28 |         /// <returns>A label.</returns>
29 |         /// <exception cref="EndOfStreamException">No byte could be read.</exception>
30 |         public int ReadLabel()
31 |         {
32 |             byte[] bytes = Read(1);
33 |
34 |             // Assumes Read returns an empty array (rather than throwing) once the data is exhausted;
35 |             // report that explicitly instead of failing with an IndexOutOfRangeException.
36 |             if (bytes.Length == 0)
37 |             {
38 |                 throw new EndOfStreamException("No more labels can be read.");
39 |             }
40 |
41 |             return bytes[0];
42 |         }
43 |
44 |         /// <summary>
45 |         /// Reads the specified amount of labels from the file; each label occupies one byte.
46 |         /// </summary>
47 |         /// <param name="count">The number of labels to read.</param>
48 |         /// <returns>An array of labels; shorter than <paramref name="count"/> if the underlying read returned fewer bytes.</returns>
49 |         public int[] ReadLabels(int count)
50 |         {
51 |             byte[] bytes = Read(count);
52 |
53 |             // Guard against a short read instead of indexing past the end of the returned array.
54 |             int available = Math.Min(count, bytes.Length);
55 |             var result = new int[available];
56 |
57 |             for (var i = 0; i < available; i++)
58 |             {
59 |                 result[i] = bytes[i];
60 |             }
61 |
62 |             return result;
63 |         }
64 |     }
65 | }
66 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/IO/NnSerializableBase.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.IO;
3 | using DigitRecognizer.Core.Utilities;
4 |
5 | namespace DigitRecognizer.Core.IO
6 | {
7 |     /// <summary>
8 |     /// An abstract base class for serialization of neural network files.
9 |     /// Owns a <see cref="System.IO.FileStream"/> that is released by <see cref="Dispose()"/>.
10 |     /// </summary>
11 |     public abstract class NnSerializableBase : IDisposable
12 |     {
13 |         /// <summary>
14 |         /// The file extension of a neural network file.
15 |         /// </summary>
16 |         private const string NnFileExtension = ".nn";
17 |
18 |         /// <summary>
19 |         /// The internal <see cref="System.IO.FileStream"/> object.
20 |         /// </summary>
21 |         protected readonly FileStream FileStream;
22 |
23 |         /// <summary>
24 |         /// Initializes a new instance of the <see cref="NnSerializableBase"/> class.
25 |         /// </summary>
26 |         /// <param name="filename">The name of the file, used for opening a file stream; validated against the ".nn" extension.</param>
27 |         /// <param name="fileMode">
28 |         /// The <see cref="System.IO.FileMode"/> used to open the file. The file is always opened with
29 |         /// <see cref="FileAccess.ReadWrite"/>, so it must be writable.
30 |         /// </param>
31 |         protected NnSerializableBase(string filename, FileMode fileMode)
32 |         {
33 |             Contracts.StringNotNullOrEmpty(filename, nameof(filename));
34 |             Contracts.FileHasExtension(filename, nameof(filename));
35 |             Contracts.FileExtensionValid(filename, NnFileExtension, nameof(filename));
36 |             // NOTE(review): an existence check would reject modes such as FileMode.Create for files
37 |             // that do not exist yet, which is presumably why it is disabled.
38 |             // Contracts.FileExists(filename, nameof(filename));
39 |
40 |             FileStream = File.Open(filename, fileMode, FileAccess.ReadWrite);
41 |         }
42 |
43 |         /// <summary>
44 |         /// Indicates if the disposing operation has been completed or not.
45 |         /// </summary>
46 |         private bool _disposed;
47 |
48 |         /// <summary>
49 |         /// Releases all resources used by the <see cref="NnSerializableBase"/>.
50 |         /// </summary>
51 |         public void Dispose()
52 |         {
53 |             Dispose(true);
54 |             GC.SuppressFinalize(this);
55 |         }
56 |
57 |         /// <summary>
58 |         /// Releases the resources used by the <see cref="NnSerializableBase"/>. Safe to call more than once.
59 |         /// </summary>
60 |         /// <param name="disposing">
61 |         /// <c>true</c> to dispose the underlying file stream; <c>false</c> to only mark the instance as disposed.
62 |         /// </param>
63 |         protected virtual void Dispose(bool disposing)
64 |         {
65 |             if (_disposed)
66 |             {
67 |                 return;
68 |             }
69 |
70 |             if (disposing)
71 |             {
72 |                 FileStream?.Dispose();
73 |             }
74 |
75 |             _disposed = true;
76 |         }
77 |     }
78 | }
79 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/IO/NnSerializationContext.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.Core.IO
2 | {
3 |     /// <summary>
4 |     /// Model for neural network serialization: pairs the data array with its <see cref="NnSerializationContextInfo"/>.
5 |     /// </summary>
6 |     public class NnSerializationContext
7 |     {
8 |         private readonly NnSerializationContextInfo _serializationContextInfo;
9 |         private readonly double[] _data;
10 |
11 |         /// <summary>
12 |         /// Initializes a new instance of the <see cref="NnSerializationContext"/> class with the specified parameters.
13 |         /// </summary>
14 |         /// <param name="fileData">The data for the file.</param>
15 |         /// <param name="serializationContextInfo">The file info to associate with the file.</param>
16 |         public NnSerializationContext(double[] fileData, NnSerializationContextInfo serializationContextInfo)
17 |         {
18 |             _serializationContextInfo = serializationContextInfo;
19 |             _data = fileData;
20 |         }
21 |
22 |         /// <summary>
23 |         /// Gets the <see cref="NnSerializationContextInfo"/>.
24 |         /// </summary>
25 |         public NnSerializationContextInfo SerializationContextInfo => _serializationContextInfo;
26 |
27 |         /// <summary>
28 |         /// Gets the data.
29 |         /// </summary>
30 |         /// <remarks>
31 |         /// The data is a combined array of all weights and biases. The array is stored by reference, not copied.
32 |         /// </remarks>
33 |         public double[] FileData => _data;
34 |     }
35 | }
36 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/IO/NnSerializationContextInfo.cs:
--------------------------------------------------------------------------------
1 | using System.Text;
2 | using DigitRecognizer.Core.Utilities;
3 |
4 | namespace DigitRecognizer.Core.IO
5 | {
6 |     /// <summary>
7 |     /// Contains context information about a layer stored in a neural network file.
8 |     /// </summary>
9 |     public struct NnSerializationContextInfo
10 |     {
11 |         private readonly int _weightMatrixRowCount;
12 |         private readonly int _weightMatrixColCount;
13 |         private readonly int _biasLength;
14 |         private readonly string _activationFunctionName;
15 |
16 |         /// <summary>
17 |         /// Initializes a new instance of the <see cref="NnSerializationContextInfo"/> struct.
18 |         /// </summary>
19 |         /// <param name="weightMatrixRowCount">The number of rows of the weight matrix.</param>
20 |         /// <param name="weightMatrixColCount">The number of columns of the weight matrix.</param>
21 |         /// <param name="biasLength">The length of the bias.</param>
22 |         /// <param name="activationFunctionName">The name of the activation function for the layer.</param>
23 |         public NnSerializationContextInfo(int weightMatrixRowCount, int weightMatrixColCount, int biasLength, string activationFunctionName)
24 |         {
25 |             Contracts.ValueGreaterThanZero(weightMatrixRowCount, nameof(weightMatrixRowCount));
26 |             Contracts.ValueGreaterThanZero(weightMatrixColCount, nameof(weightMatrixColCount));
27 |             Contracts.ValueGreaterThanZero(biasLength, nameof(biasLength));
28 |
29 |             _weightMatrixRowCount = weightMatrixRowCount;
30 |             _weightMatrixColCount = weightMatrixColCount;
31 |             _biasLength = biasLength;
32 |             _activationFunctionName = activationFunctionName;
33 |         }
34 |
35 |         /// <summary>
36 |         /// Gets the weight matrix row count.
37 |         /// </summary>
38 |         public int WeightMatrixRowCount => _weightMatrixRowCount;
39 |
40 |         /// <summary>
41 |         /// Gets the weight matrix column count.
42 |         /// </summary>
43 |         public int WeightMatrixColCount => _weightMatrixColCount;
44 |
45 |         /// <summary>
46 |         /// Gets the bias length.
47 |         /// </summary>
48 |         public int BiasLength => _biasLength;
49 |
50 |         /// <summary>
51 |         /// Gets the name of the activation function.
52 |         /// </summary>
53 |         public string ActivationFunctionName => _activationFunctionName;
54 |
55 |         /// <summary>
56 |         /// Gets the number of bytes in the activation function name when encoded with <see cref="Encoding.Unicode"/> (UTF-16), or 0 when the name is <c>null</c> (e.g. for a default-initialized instance).
57 |         /// </summary>
58 |         public int ActivationFunctionNameSizeInBytes => _activationFunctionName == null ? 0 : Encoding.Unicode.GetByteCount(_activationFunctionName);
59 |
60 |         /// <summary>
61 |         /// Gets the length of the data array: the weight matrix element count (rows times columns) plus the bias length.
62 |         /// </summary>
63 |         public int DataLength => _weightMatrixRowCount * _weightMatrixColCount + _biasLength;
64 |
65 |         /// <summary>
66 |         /// Gets the size of the data array in bytes, with each element stored as a <see cref="double"/>.
67 |         /// </summary>
68 |         public int DataSizeInBytes => DataLength * sizeof(double);
69 |     }
70 | }
71 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/IO/PixelReader.cs:
--------------------------------------------------------------------------------
1 | using System;
2 |
3 | namespace DigitRecognizer.Core.IO
4 | {
5 |     /// <summary>
6 |     /// Provides methods for reading pixels from a file, via <see cref="MemoryStreamReader"/>.
7 |     /// </summary>
8 |     public class PixelReader : MemoryStreamReader, IPixelReader
9 |     {
10 |         /// <summary>
11 |         /// A "magic number" of bytes that needs to be skipped at the beginning of the stream
12 |         /// (presumably the IDX file header; confirm against the dataset format).
13 |         /// </summary>
14 |         public const int InitialOffset = 16;
15 |
16 |         /// <summary>
17 |         /// Initializes a new instance of the <see cref="PixelReader"/> class.
18 |         /// </summary>
19 |         /// <param name="filename">The file path used for instantiating an internal stream object.</param>
20 |         public PixelReader(string filename) : base(filename, InitialOffset)
21 |         {
22 |         }
23 |
24 |         /// <summary>
25 |         /// Reads the specified amount of pixels from the file; each pixel occupies one byte.
26 |         /// </summary>
27 |         /// <param name="count">The number of pixels to read.</param>
28 |         /// <returns>
29 |         /// An array of doubles holding the raw byte values (not normalized); shorter than
30 |         /// <paramref name="count"/> if the underlying read returned fewer bytes.
31 |         /// </returns>
32 |         public double[] ReadPixels(int count)
33 |         {
34 |             byte[] bytes = Read(count);
35 |
36 |             double[] result = ConvertByteArrayToDoubleArray(bytes, count);
37 |
38 |             return result;
39 |         }
40 |
41 |         /// <summary>
42 |         /// Reads the specified amount of blocks of pixels from the file, each of the specified block size.
43 |         /// </summary>
44 |         /// <param name="count">The number of blocks to read.</param>
45 |         /// <param name="blockSize">The size of each block.</param>
46 |         /// <returns>
47 |         /// A jagged array of doubles with one inner array per block; shorter than <paramref name="count"/>
48 |         /// if the underlying read returned fewer blocks.
49 |         /// </returns>
50 |         public double[][] ReadPixels(int count, int blockSize)
51 |         {
52 |             byte[][] blocks = Read(count, blockSize);
53 |
54 |             // Guard against a short read instead of indexing past the end of the returned array.
55 |             int available = Math.Min(count, blocks.Length);
56 |             var result = new double[available][];
57 |
58 |             for (var i = 0; i < available; i++)
59 |             {
60 |                 result[i] = ConvertByteArrayToDoubleArray(blocks[i], blockSize);
61 |             }
62 |
63 |             return result;
64 |         }
65 |
66 |         /// <summary>
67 |         /// Converts up to <paramref name="length"/> bytes to doubles, one double per byte.
68 |         /// </summary>
69 |         /// <param name="bytes">The array of bytes.</param>
70 |         /// <param name="length">The maximum number of elements to convert.</param>
71 |         /// <returns>
72 |         /// An array of <c>min(length, bytes.Length)</c> doubles; the values are not normalized.
73 |         /// </returns>
74 |         private static double[] ConvertByteArrayToDoubleArray(byte[] bytes, int length)
75 |         {
76 |             // Clamp so that a short read cannot cause an IndexOutOfRangeException.
77 |             int available = Math.Min(length, bytes.Length);
78 |             var result = new double[available];
79 |
80 |             for (var i = 0; i < available; i++)
81 |             {
82 |                 result[i] = bytes[i];
83 |             }
84 |
85 |             return result;
86 |         }
87 |     }
88 | }
89 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Properties/AssemblyInfo.cs:
--------------------------------------------------------------------------------
1 | using System.Reflection;
2 | using System.Runtime.InteropServices;
3 |
4 | // General Information about an assembly is controlled through the following
5 | // set of attributes. Change these attribute values to modify the information
6 | // associated with an assembly.
7 | [assembly: AssemblyTitle("DigitRecognizer.Core")] 8 | [assembly: AssemblyDescription("")] 9 | [assembly: AssemblyConfiguration("")] 10 | [assembly: AssemblyCompany("")] 11 | [assembly: AssemblyProduct("DigitRecognizer.Core")] 12 | [assembly: AssemblyCopyright("Copyright © 2018")] 13 | [assembly: AssemblyTrademark("")] 14 | [assembly: AssemblyCulture("")] 15 | 16 | // Setting ComVisible to false makes the types in this assembly not visible 17 | // to COM components. If you need to access a type in this assembly from 18 | // COM, set the ComVisible attribute to true on that type. 19 | [assembly: ComVisible(false)] 20 | 21 | // The following GUID is for the ID of the typelib if this project is exposed to COM 22 | [assembly: Guid("ba726236-c7ec-41ca-abb6-e4fe7e784939")] 23 | 24 | // Version information for an assembly consists of the following four values: 25 | // 26 | // Major Version 27 | // Minor Version 28 | // Build Number 29 | // Revision 30 | // 31 | // You can specify all the values or you can default the Build and Revision Numbers 32 | // by using the '*' as shown below: 33 | // [assembly: AssemblyVersion("1.0.*")] 34 | [assembly: AssemblyVersion("1.0.0.0")] 35 | [assembly: AssemblyFileVersion("1.0.0.0")] 36 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.Core/Utilities/ConsoleUtility.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Diagnostics; 3 | 4 | namespace DigitRecognizer.Core.Utilities 5 | { 6 | /// 7 | /// Contains utility methods for printing in a console window. 8 | /// 9 | public static class ConsoleUtility 10 | { 11 | /// 12 | /// Writes the specified message. 13 | /// 14 | /// The message. 
15 | public static void WriteLine(string message) 16 | { 17 | #if DEBUG 18 | Debug.WriteLine(message); 19 | #endif 20 | Console.WriteLine(message); 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.Core/Utilities/DirectoryHelper.cs: -------------------------------------------------------------------------------- 1 | namespace DigitRecognizer.Core.Utilities 2 | { 3 | /// 4 | /// Provides utility for accessing commonly required folders. 5 | /// 6 | public static class DirectoryHelper 7 | { 8 | /// 9 | /// The relative path to the dataset folder. 10 | /// 11 | private const string DatasetPath = @"../../../Dataset"; 12 | 13 | /// 14 | /// The relative path to the models folder. 15 | /// 16 | private const string ModelsPath = @"../../../../Models"; 17 | 18 | /// 19 | /// Gets the path to the training labels file. 20 | /// 21 | public static string TrainLabelsPath => $"{DatasetPath}/train-labels.idx1-ubyte"; 22 | 23 | /// 24 | /// Gets the path to the training images file. 25 | /// 26 | public static string TrainImagesPath => $"{DatasetPath}/train-images.idx3-ubyte"; 27 | 28 | /// 29 | /// Gets the path to the expanded training labels file. 30 | /// 31 | public static string ExpandedTrainLabelsPath => $"{DatasetPath}/train-exp-labels.idx1-ubyte"; 32 | 33 | /// 34 | /// Gets the path to the expanded training images file. 35 | /// 36 | public static string ExpandedTrainImagesPath => $"{DatasetPath}/train-exp-images.idx3-ubyte"; 37 | 38 | /// 39 | /// Gets the path to the testing labels file. 40 | /// 41 | public static string TestLabelsPath => $"{DatasetPath}/t10k-labels.idx1-ubyte"; 42 | 43 | /// 44 | /// Gets the path to the testing images file. 45 | /// 46 | public static string TestImagesPath => $"{DatasetPath}/t10k-images.idx3-ubyte"; 47 | 48 | /// 49 | /// Gets the path to the models folder.
50 | /// 51 | public static string ModelsFolder => ModelsPath; 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.DatasetExpansion/Api/DatasetExpander.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Linq; 4 | using DigitRecognizer.DatasetExpansion.Infrastructure; 5 | 6 | namespace DigitRecognizer.DatasetExpansion.Api 7 | { 8 | /// 9 | /// Expands the existing MNIST training dataset. 10 | /// 11 | public class DatasetExpander 12 | { 13 | /// 14 | /// Expands and shuffles the existing dataset with various affine transformations. 15 | /// 16 | /// The original, translated and rotated images combined and shuffled. 17 | public List ExpandDataset() 18 | { 19 | var result = new List(); 20 | 21 | Console.WriteLine("Reading images from disk"); 22 | 23 | List dataset = GetDataset(); 24 | 25 | Console.WriteLine("Applying translation transformation"); 26 | 27 | List translatedImages = ApplyTranslationTransformation(dataset); 28 | 29 | Console.WriteLine("Applying rotation transformation"); 30 | 31 | List rotatedImages = ApplyRotationTransformation(dataset); 32 | 33 | result.AddRange(dataset); 34 | 35 | result.AddRange(translatedImages); 36 | 37 | result.AddRange(rotatedImages); 38 | 39 | Console.WriteLine("Shuffling images"); 40 | 41 | return FisherYatesShuffle.Shuffle(result); 42 | } 43 | 44 | /// 45 | /// Gets the dataset from disk. 46 | /// 47 | /// A list of objects. 48 | private List GetDataset() 49 | { 50 | var provider = new DatasetProvider(); 51 | 52 | return provider.LoadDataset(); 53 | } 54 | 55 | /// 56 | /// Applies the translation transformation to the dataset images, producing four shifted copies with offsets (-1, 0), (1, 0), (0, -1) and (0, 1). 57 | /// 58 | /// The dataset. 59 | /// The translated images.
60 | private List ApplyTranslationTransformation(List dataset) 61 | { 62 | var result = new List(); 63 | 64 | result.AddRange(dataset.Select(x => new MnistImage(x.Label, AffineTransformation.Translate(x.Pixels, -1, 0)))); 65 | 66 | result.AddRange(dataset.Select(x => new MnistImage(x.Label, AffineTransformation.Translate(x.Pixels, 1, 0)))); 67 | 68 | result.AddRange(dataset.Select(x => new MnistImage(x.Label, AffineTransformation.Translate(x.Pixels, 0, -1)))); 69 | 70 | result.AddRange(dataset.Select(x => new MnistImage(x.Label, AffineTransformation.Translate(x.Pixels, 0, 1)))); 71 | 72 | return result; 73 | } 74 | 75 | /// 76 | /// Applies the rotation transformation to the dataset images, rotating each image by either 0.5 or -0.5, chosen at random. 77 | /// 78 | /// The dataset. 79 | /// The rotated images. 80 | private List ApplyRotationTransformation(List dataset) 81 | { 82 | var result = new List(); 83 | 84 | var random = new Random(); 85 | 86 | result.AddRange(dataset.Select(x=> new MnistImage(x.Label, AffineTransformation.Rotate(x.Pixels, random.NextDouble() > 0.5 ? 0.5 : -0.5)))); 87 | 88 | return result; 89 | } 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.DatasetExpansion/Api/DatasetSerializer.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.IO; 4 | using System.Linq; 5 | using DigitRecognizer.DatasetExpansion.Infrastructure; 6 | 7 | namespace DigitRecognizer.DatasetExpansion.Api 8 | { 9 | /// 10 | /// Represents a serializer for an MNIST dataset. 11 | /// 12 | public class DatasetSerializer 13 | { 14 | private const int LabelMagicNumber = 0x00000801; // IDX magic number of the labels file. 15 | 16 | private const int ImageMagicNumber = 0x00000803; // IDX magic number of the images file. 17 | 18 | private const int Length = 28; // Image row count and column count. 
19 | 20 | /// 21 | /// Serializes the specified dataset to the specified files. 22 | /// 23 | /// The name of the file with labels data.
24 | /// The name of the file with images data. 25 | /// The dataset. 26 | public void SerializeDataset(string labelsFilename, string imagesFilename, List dataset) 27 | { 28 | Console.WriteLine("Ensuring file integrity"); 29 | 30 | EnsureFileIntegrity(labelsFilename); 31 | EnsureFileIntegrity(imagesFilename); 32 | 33 | // Every IDX header integer must be MSB first, but BinaryWriter always writes little-endian, so all header values are byte-swapped. 34 | int datasetSize = ToBigEndian(dataset.Count); 35 | 36 | using (var lblWriter = new BinaryWriter(File.Open(labelsFilename, FileMode.Create, FileAccess.Write))) 37 | using (var imgWriter = new BinaryWriter(File.Open(imagesFilename, FileMode.Create, FileAccess.Write))) 38 | { 39 | // Write the labels header: magic number and dataset size. 40 | lblWriter.Write(ToBigEndian(LabelMagicNumber)); 41 | 42 | lblWriter.Write(datasetSize); 43 | 44 | // Write the images header: magic number, dataset size, row count and column count. 45 | imgWriter.Write(ToBigEndian(ImageMagicNumber)); 46 | 47 | imgWriter.Write(datasetSize); 48 | 49 | imgWriter.Write(ToBigEndian(Length)); 50 | 51 | imgWriter.Write(ToBigEndian(Length)); 52 | 53 | // Write bytes of each image to appropriate files. 54 | foreach (MnistImage img in dataset) 55 | { 56 | lblWriter.Write(img.Label); 57 | 58 | imgWriter.Write(img.Pixels); 59 | } 60 | } 61 | } 62 | private static int ToBigEndian(int value) => unchecked((int)(((uint)value >> 24) | (((uint)value >> 8) & 0x0000FF00u) | (((uint)value << 8) & 0x00FF0000u) | ((uint)value << 24))); // Reverses the byte order of value so that BinaryWriter's little-endian output is MSB first, independent of the machine's endianness. 63 | /// 64 | /// Ensures the integrity of the file with the specified file name. 65 | /// 66 | /// The filename.
67 | private void EnsureFileIntegrity(string filename) 68 | { 69 | const int retryCount = 5; 70 | 71 | for (var i = 0; i < retryCount; i++) 72 | { 73 | try 74 | { 75 | if (File.Exists(filename)) 76 | { 77 | File.Delete(filename); 78 | } 79 | 80 | return; 81 | } 82 | catch (Exception) 83 | { 84 | // ignored: the delete is retried immediately (no delay between attempts) and an exception is thrown below once every attempt has failed 85 | } 86 | } 87 | 88 | throw new Exception($"Failed to delete file {filename}"); 89 | } 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.DatasetExpansion/App.config: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.DatasetExpansion/DigitRecognizer.DatasetExpansion.csproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | Debug 6 | AnyCPU 7 | {757CF815-DD01-4F66-8E9F-DB3FF9B4A463} 8 | Exe 9 | DigitRecognizer.DatasetExpansion 10 | DigitRecognizer.DatasetExpansion 11 | v4.8 12 | 512 13 | true 14 | 15 | 16 | 17 | AnyCPU 18 | true 19 | full 20 | false 21 | bin\Debug\ 22 | DEBUG;TRACE 23 | prompt 24 | 4 25 | 26 | 27 | AnyCPU 28 | pdbonly 29 | true 30 | bin\Release\ 31 | TRACE 32 | prompt 33 | 4 34 | 35 | 36 | favico.ico 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | {BA726236-C7EC-41CA-ABB6-E4FE7E784939} 64 | DigitRecognizer.Core 65 | 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.DatasetExpansion/Infrastructure/DatasetProvider.cs: -------------------------------------------------------------------------------- 1 | using System.Collections.Generic; 2 | using DigitRecognizer.Core.IO; 3 | using DigitRecognizer.Core.Utilities; 4 | 5 | namespace
DigitRecognizer.DatasetExpansion.Infrastructure 6 | { 7 | /// 8 | /// Loads the current dataset from disk. 9 | /// 10 | public class DatasetProvider 11 | { 12 | private const int ImagePixelCount = 784; // 28 x 28 pixels per image. 13 | private const int DatasetSize = 60000; // Number of images read from the training files. 14 | 15 | private static readonly string TrainImagesFilePath = DirectoryHelper.TrainImagesPath; 16 | 17 | private static readonly string TrainLabelsFilePath = DirectoryHelper.TrainLabelsPath; 18 | 19 | /// 20 | /// Loads the MNIST training dataset from disk. 21 | /// 22 | /// The images read from the training files, in file order. 23 | public List LoadDataset() 24 | { 25 | var result = new List(); 26 | 27 | using (var labelReader = new MemoryStreamReader(TrainLabelsFilePath, LabelReader.InitialOffset)) 28 | using (var pixelReader = new MemoryStreamReader(TrainImagesFilePath, PixelReader.InitialOffset)) 29 | { 30 | for (var i = 0; i < DatasetSize; i++) 31 | { 32 | byte label = labelReader.Read(1)[0]; 33 | 34 | byte[] pixels = pixelReader.Read(ImagePixelCount); 35 | 36 | var img = new MnistImage(label, pixels); 37 | 38 | result.Add(img); 39 | } 40 | } 41 | 42 | return result; 43 | } 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.DatasetExpansion/Infrastructure/FisherYatesShuffle.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Linq; 4 | 5 | namespace DigitRecognizer.DatasetExpansion.Infrastructure 6 | { 7 | /// 8 | /// Implements the Fisher-Yates shuffling algorithm. 9 | /// 10 | public class FisherYatesShuffle 11 | { 12 | /// 13 | /// Shared random number generator used by the Shuffle methods. 14 | /// 15 | private static readonly Random Random = new Random(); 16 | 17 | /// 18 | /// Shuffles the array in place. 19 | /// 20 | /// Array element type. 21 | /// Array to shuffle.
22 | public static void Shuffle(T[] array) 23 | { 24 | int n = array.Length; 25 | 26 | for (var i = 0; i < n; i++) 27 | { 28 | // Use Next on random instance with an argument. 29 | // The argument is an exclusive bound. 30 | // So we will not go past the end of the array. 31 | int r = i + Random.Next(n - i); 32 | 33 | T t = array[r]; 34 | 35 | array[r] = array[i]; 36 | 37 | array[i] = t; 38 | } 39 | } 40 | 41 | /// 42 | /// Returns a shuffled copy of the list; the input list is left unchanged. 43 | /// 44 | /// List element type. 45 | /// List to shuffle. 46 | /// The shuffled list. 47 | public static List Shuffle(List list) 48 | { 49 | T[] array = list.ToArray(); 50 | 51 | Shuffle(array); 52 | 53 | return array.ToList(); 54 | } 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.DatasetExpansion/Infrastructure/MnistImage.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | 3 | namespace DigitRecognizer.DatasetExpansion.Infrastructure 4 | { 5 | /// 6 | /// Represents a raw MNIST image in binary format. 7 | /// 8 | public class MnistImage 9 | { 10 | private readonly byte _label; 11 | private readonly byte[] _pixels; 12 | 13 | /// 14 | /// Initializes a new instance of the class. 15 | /// 16 | /// The label. 17 | /// The pixels. 18 | public MnistImage(byte label, byte[] pixels) 19 | { 20 | _label = label; 21 | _pixels = pixels; 22 | } 23 | 24 | /// 25 | /// Gets the label. 26 | /// 27 | public byte Label => _label; 28 | 29 | /// 30 | /// Gets the pixels. 31 | /// 32 | public byte[] Pixels => _pixels; 33 | // Renders the image as 28 rows of 28 characters: '0' where the pixel value exceeds 10, a space otherwise. A StringBuilder avoids re-copying the growing string on every character. 
34 | public override string ToString() 35 | { 36 | var result = new System.Text.StringBuilder(); 37 | 38 | for (var i = 0; i < 28; i++) 39 | { 40 | for (var j = 0; j < 28; j++) 41 | { 42 | result.Append(_pixels[i * 28 + j] > 10 ?
'0' : ' '); 43 | } 44 | 45 | result.AppendLine(); 46 | } 47 | 48 | return result.ToString(); 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.DatasetExpansion/Program.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using DigitRecognizer.Core.Utilities; 4 | using DigitRecognizer.DatasetExpansion.Api; 5 | using DigitRecognizer.DatasetExpansion.Infrastructure; 6 | 7 | namespace DigitRecognizer.DatasetExpansion 8 | { 9 | internal class Program 10 | { 11 | private static void Main() 12 | { 13 | Console.Title = "Digit Recognizer - Dataset Expansion"; 14 | 15 | Console.WriteLine("Starting dataset expansion"); 16 | 17 | var expander = new DatasetExpander(); 18 | 19 | List expandedDataset = expander.ExpandDataset(); 20 | 21 | Console.WriteLine("Completed dataset expansion"); 22 | 23 | Console.WriteLine("Starting dataset serialization"); 24 | 25 | var serializer = new DatasetSerializer(); 26 | 27 | serializer.SerializeDataset(DirectoryHelper.ExpandedTrainLabelsPath, DirectoryHelper.ExpandedTrainImagesPath, expandedDataset); 28 | 29 | Console.WriteLine("Completed dataset serialization"); 30 | 31 | Console.ReadKey(); 32 | } 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.DatasetExpansion/Properties/AssemblyInfo.cs: -------------------------------------------------------------------------------- 1 | using System.Reflection; 2 | using System.Runtime.CompilerServices; 3 | using System.Runtime.InteropServices; 4 | 5 | // General Information about an assembly is controlled through the following 6 | // set of attributes. Change these attribute values to modify the information 7 | // associated with an assembly.
8 | [assembly: AssemblyTitle("DigitRecognizer.DatasetExpansion")] 9 | [assembly: AssemblyDescription("")] 10 | [assembly: AssemblyConfiguration("")] 11 | [assembly: AssemblyCompany("")] 12 | [assembly: AssemblyProduct("DigitRecognizer.DatasetExpansion")] 13 | [assembly: AssemblyCopyright("Copyright © 2018")] 14 | [assembly: AssemblyTrademark("")] 15 | [assembly: AssemblyCulture("")] 16 | 17 | // Setting ComVisible to false makes the types in this assembly not visible 18 | // to COM components. If you need to access a type in this assembly from 19 | // COM, set the ComVisible attribute to true on that type. 20 | [assembly: ComVisible(false)] 21 | 22 | // The following GUID is for the ID of the typelib if this project is exposed to COM 23 | [assembly: Guid("757cf815-dd01-4f66-8e9f-db3ff9b4a463")] 24 | 25 | // Version information for an assembly consists of the following four values: 26 | // 27 | // Major Version 28 | // Minor Version 29 | // Build Number 30 | // Revision 31 | // 32 | // You can specify all the values or you can default the Build and Revision Numbers 33 | // by using the '*' as shown below: 34 | // [assembly: AssemblyVersion("1.0.*")] 35 | [assembly: AssemblyVersion("1.0.0.0")] 36 | [assembly: AssemblyFileVersion("1.0.0.0")] 37 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.DatasetExpansion/favico.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/DigitRecognizer.DatasetExpansion/favico.ico -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.Engine/App.config: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | --------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Engine/DigitRecognizer.Engine.csproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | Debug 6 | AnyCPU 7 | {00CE9C72-5F92-4555-BB96-D0E77B9A2390} 8 | Exe 9 | DigitRecognizer.Engine 10 | DigitRecognizer.Engine 11 | v4.8 12 | 512 13 | true 14 | 15 | 16 | 17 | AnyCPU 18 | true 19 | full 20 | false 21 | bin\Debug\ 22 | DEBUG;TRACE 23 | prompt 24 | 4 25 | 26 | 27 | AnyCPU 28 | pdbonly 29 | true 30 | bin\Release\ 31 | TRACE 32 | prompt 33 | 4 34 | 35 | 36 | favico.ico 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | {BA726236-C7EC-41CA-ABB6-E4FE7E784939} 58 | DigitRecognizer.Core 59 | 60 | 61 | {BF82C3DF-0EFE-45DB-8E3D-1FFDD548AE73} 62 | DigitRecognizer.MachineLearning 63 | 64 | 65 | 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.Engine/Program.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.IO; 4 | using System.Linq; 5 | using DigitRecognizer.Core.Data; 6 | using DigitRecognizer.Core.Extensions; 7 | using DigitRecognizer.Core.Utilities; 8 | using DigitRecognizer.MachineLearning.Pipeline; 9 | using DigitRecognizer.MachineLearning.Infrastructure.Functions; 10 | using DigitRecognizer.MachineLearning.Infrastructure.Initialization; 11 | using DigitRecognizer.MachineLearning.Infrastructure.Models; 12 | using DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork; 13 | using DigitRecognizer.MachineLearning.Optimization.Optimizers; 14 | using DigitRecognizer.MachineLearning.Providers; 15 | 16 | namespace DigitRecognizer.Engine 17 | { 18 | internal class Program 19 | { 20 | private static void Main() 21 | { 22 | Console.Title = "Digit Recognizer - Engine"; 23 | 24 | var learningRate = 0.00035; 25 | var epochs = 10; 26 | var
regularizationFactor = 15.0; 27 | 28 | LearningPipeline pipeline = new LearningPipeline() 29 | .UseGradientClipping() 30 | .UseL2Regularization(regularizationFactor) 31 | .UseDropout(0.5) 32 | .SetWeightsInitializer(InitializerType.RandomInitialization) 33 | .SetEpochCount(epochs); 34 | 35 | var layers = new List 36 | { 37 | new NnLayer(784, 200, new LeakyRelu()), 38 | new NnLayer(200, 100, new LeakyRelu()), 39 | new NnLayer(100, 10, new Softmax()) 40 | }; 41 | 42 | var nn = new NeuralNetwork(layers, learningRate); 43 | 44 | var optimizer = new MomentumOptimizer(nn, new CrossEntropy(), 0.93); 45 | 46 | var provider = new BatchDataProvider(DirectoryHelper.ExpandedTrainLabelsPath, DirectoryHelper.ExpandedTrainImagesPath, 100); 47 | 48 | pipeline.Add(optimizer); 49 | 50 | pipeline.Add(nn); 51 | 52 | pipeline.Add(provider); 53 | 54 | PredictionModel model = pipeline.Run(); 55 | 56 | var provider1 = new BatchDataProvider(DirectoryHelper.TestLabelsPath, DirectoryHelper.TestImagesPath, 10000); 57 | var acc = 0.0; 58 | 59 | MnistImageBatch data = provider1.GetData(); 60 | 61 | List predictions = data.Pixels.Select(pixels => model.Predict(pixels)).ToList(); 62 | 63 | for (var i = 0; i < data.Labels.Length; i++) 64 | { 65 | if (data.Labels[i] == predictions[i].ArgMax()) 66 | { 67 | acc++; 68 | } 69 | } 70 | // Normalize by the number of test samples actually evaluated (not a hard-coded count); Math.Max guards against an empty batch. 71 | acc /= Math.Max(1, data.Labels.Length); 72 | 73 | Console.WriteLine($"Accuracy on the test data is: {acc:P2}"); 74 | // NOTE(review): no separator is inserted between the base directory and ModelsFolder, so the final location depends on how GetFullPath normalizes the joined segment; confirm the intended models folder before changing this. 75 | string basePath = Path.GetFullPath(Path.GetDirectoryName(AppDomain.CurrentDomain.BaseDirectory) + 76 | DirectoryHelper.ModelsFolder); 77 | 78 | string modelName = $"{Guid.NewGuid()}-{acc:N4}.nn"; 79 | 80 | string filename = $"{basePath}/{modelName}"; 81 | 82 | model.Save(filename); 83 | 84 | Console.ReadKey(); 85 | } 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.Engine/Properties/AssemblyInfo.cs: -------------------------------------------------------------------------------- 1 | using
System.Reflection; 2 | using System.Runtime.InteropServices; 3 | 4 | // General Information about an assembly is controlled through the following 5 | // set of attributes. Change these attribute values to modify the information 6 | // associated with an assembly. 7 | [assembly: AssemblyTitle("DigitRecognizer.Engine")] 8 | [assembly: AssemblyDescription("")] 9 | [assembly: AssemblyConfiguration("")] 10 | [assembly: AssemblyCompany("")] 11 | [assembly: AssemblyProduct("DigitRecognizer.Engine")] 12 | [assembly: AssemblyCopyright("Copyright © 2018")] 13 | [assembly: AssemblyTrademark("")] 14 | [assembly: AssemblyCulture("")] 15 | 16 | // Setting ComVisible to false makes the types in this assembly not visible 17 | // to COM components. If you need to access a type in this assembly from 18 | // COM, set the ComVisible attribute to true on that type. 19 | [assembly: ComVisible(false)] 20 | 21 | // The following GUID is for the ID of the typelib if this project is exposed to COM 22 | [assembly: Guid("00ce9c72-5f92-4555-bb96-d0e77b9a2390")] 23 | 24 | // Version information for an assembly consists of the following four values: 25 | // 26 | // Major Version 27 | // Minor Version 28 | // Build Number 29 | // Revision 30 | // 31 | // You can specify all the values or you can default the Build and Revision Numbers 32 | // by using the '*' as shown below: 33 | // [assembly: AssemblyVersion("1.0.*")] 34 | [assembly: AssemblyVersion("1.0.0.0")] 35 | [assembly: AssemblyFileVersion("1.0.0.0")] 36 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.Engine/favico.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/DigitRecognizer.Engine/favico.ico --------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Data/CalculationCache.cs: -------------------------------------------------------------------------------- 1 | using System.Collections.Generic; 2 | using DigitRecognizer.Core.Utilities; 3 | 4 | namespace DigitRecognizer.MachineLearning.Infrastructure.Data 5 | { 6 | /// 7 | /// Represents a simple cache for storing values that need to be saved. 8 | /// 9 | public class CalculationCache 10 | { 11 | private readonly List _cache; 12 | 13 | /// 14 | /// Initializes a new instance of the class. 15 | /// 16 | public CalculationCache() 17 | { 18 | _cache = new List(); 19 | } 20 | 21 | /// 22 | /// Gets the cache. 23 | /// 24 | /// 25 | public List GetCache() 26 | { 27 | return _cache; 28 | } 29 | 30 | /// 31 | /// Sets the specified value at the specified index. 32 | /// 33 | /// The value. 34 | /// The index. 35 | public void SetValue(double[][] value, int index) 36 | { 37 | Contracts.ValueNotNull(value, nameof(value)); 38 | Contracts.ValueWithinBounds(index, 0, _cache.Count, nameof(index)); 39 | 40 | if (index > _cache.Count - 1) 41 | { 42 | Add(value); 43 | } 44 | 45 | _cache[index] = value; 46 | } 47 | 48 | /// 49 | /// Addes the specified value to the cache. 50 | /// 51 | /// 52 | private void Add(double[][] value) 53 | { 54 | Contracts.ValueNotNull(value, nameof(value)); 55 | 56 | _cache.Add(value); 57 | } 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Dropout/BinomialDistribution.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using DigitRecognizer.Core.Extensions; 3 | 4 | namespace DigitRecognizer.MachineLearning.Infrastructure.Dropout 5 | { 6 | /// 7 | /// Represents a generator for a binomial distribution. 
8 | /// 9 | public class BinomialDistribution 10 | { 11 | private readonly double _p; 12 | private readonly Random _rnd; 13 | 14 | /// 15 | /// Initializes a new instance of the class. 16 | /// 17 | /// The probability of success. 18 | /// The random number generator. 19 | public BinomialDistribution(double p, Random rnd) 20 | { 21 | _p = p; 22 | _rnd = rnd; 23 | } 24 | 25 | /// 26 | /// Generates a random binomial distribution of the specified length. 27 | /// 28 | /// The length. 29 | /// The binomial distribution. 30 | public int[] Generate(int length) 31 | { 32 | double[] uniform = _rnd.NextDoubles(length); 33 | 34 | var binomial = new int[length]; 35 | 36 | for (var i = 0; i < length; i++) 37 | { 38 | binomial[i] = uniform[i] < _p ? 1 : 0; 39 | } 40 | 41 | return binomial; 42 | } 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Dropout/Dropout.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | 4 | namespace DigitRecognizer.MachineLearning.Infrastructure.Dropout 5 | { 6 | /// 7 | /// Implements the dropout regularization method. 8 | /// 9 | public class Dropout 10 | { 11 | private readonly double _keepProbability; 12 | private readonly BinomialDistribution _binomialDistributiion; 13 | 14 | /// 15 | /// Initializes a new instance of the class. 16 | /// 17 | /// The keep probability. 18 | public Dropout(double keepProbability) 19 | { 20 | if (keepProbability < 0 || keepProbability > 1.0) 21 | { 22 | throw new ArgumentOutOfRangeException(nameof(keepProbability)); 23 | } 24 | 25 | _keepProbability = keepProbability; 26 | 27 | _binomialDistributiion = new BinomialDistribution(keepProbability, new Random()); 28 | } 29 | 30 | /// 31 | /// Randomly generates a list of dropout vectors of the specified sizes. 32 | /// 33 | /// The array of sizes. 
34 | /// The list of dropout vectors. 35 | public List GenerateDropoutVectors(int[] sizes) 36 | { 37 | var result = new List(); 38 | 39 | foreach (int size in sizes) 40 | { 41 | var dropoutVector = new double[size]; 42 | 43 | int[] binomial = _binomialDistributiion.Generate(size); 44 | 45 | // The reason to divide with the keep probability is so that the expected output of the network 46 | // during test time is not affacted by the expected output of the network during training time. 47 | 48 | for (var i = 0; i < size; i++) 49 | { 50 | dropoutVector[i] = binomial[i] / _keepProbability; 51 | } 52 | 53 | result.Add(dropoutVector); 54 | } 55 | 56 | return result; 57 | } 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Factories/AbstractTypeFactory.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Collections.Generic; 3 | using System.Linq; 4 | using System.Reflection; 5 | 6 | namespace DigitRecognizer.MachineLearning.Infrastructure.Factories 7 | { 8 | /// 9 | /// Represents a factory for getting instances of based on a key. 10 | /// 11 | /// The instance type the abstract factory returns. 12 | /// The key used in the type dictionary. 13 | public abstract class AbstractTypeFactory 14 | { 15 | private readonly Dictionary _typeDictionary; 16 | 17 | private readonly string _propName; 18 | private readonly Type _baseType; 19 | 20 | /// 21 | /// Initializes a new instance of the class. 22 | /// 23 | /// The base 24 | /// The key of the property of the . 25 | protected AbstractTypeFactory(Type baseType, string propName) 26 | { 27 | _propName = propName; 28 | _baseType = baseType; 29 | 30 | _typeDictionary = new Dictionary(); 31 | 32 | FillTypeDictionary(); 33 | } 34 | 35 | /// 36 | /// Fills the internal dictionary with data. 
37 | /// 38 | private void FillTypeDictionary() 39 | { 40 | // Get all the types in the executing assembly that implement our interface, and are classes. 41 | IEnumerable types = Assembly 42 | .GetExecutingAssembly() 43 | .GetTypes() 44 | .Where(t => _baseType.IsAssignableFrom(t) && t.IsClass); 45 | 46 | // We get the value of the property that is on the interface and fill the dictionary with the property value and type. 47 | foreach (Type type in types) 48 | { 49 | PropertyInfo prop = type.GetProperty(_propName); 50 | 51 | if (prop == null) 52 | { 53 | throw new ArgumentNullException(nameof(prop), $"Property {_propName} was not found on type {type.Name}"); 54 | } 55 | 56 | object obj = Activator.CreateInstance(type); 57 | 58 | _typeDictionary.Add((TKey)prop.GetValue(obj), type); 59 | } 60 | } 61 | 62 | /// 63 | /// Gets the instance of the for the specified key. 64 | /// 65 | /// 66 | /// If no matching type is found an exception is thrown. 67 | /// This indicates something is wrong in the system. 68 | /// Generally, this should never be thrown. 69 | /// 70 | /// The key of the dictionary. 71 | /// The instance of the type, if found. 72 | public TInstance GetInstance(TKey key) 73 | { 74 | Type type = _typeDictionary[key]; 75 | 76 | object result = Activator.CreateInstance(type); 77 | 78 | return (TInstance)result; 79 | } 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Factories/FunctionFactory.cs: -------------------------------------------------------------------------------- 1 | using DigitRecognizer.MachineLearning.Infrastructure.Functions; 2 | 3 | namespace DigitRecognizer.MachineLearning.Infrastructure.Factories 4 | { 5 | /// 6 | /// Factory for getting instances of by providing the name of the function. 7 | /// 8 | public class FunctionFactory : AbstractTypeFactory 9 | { 10 | /// 11 | /// The singleton instance. 
12 | /// 13 | private static FunctionFactory _instance; 14 | 15 | /// 16 | /// The object locked while reading or lazily creating the singleton instance. Required for thread safety. 17 | /// 18 | private static readonly object Lock = new object(); 19 | 20 | /// 21 | /// Gets the singleton instance of the class, creating it on first access. 22 | /// 23 | public static FunctionFactory Instance 24 | { 25 | get 26 | { 27 | lock (Lock) 28 | { 29 | return _instance ?? (_instance = new FunctionFactory()); 30 | } 31 | } 32 | } 33 | 34 | /// 35 | /// Initializes a new instance of the class, registering the IFunction implementations found in this assembly under their Name property. 36 | /// 37 | private FunctionFactory() : base(typeof(IFunction), "Name") 38 | { 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Factories/InitializerFactory.cs: -------------------------------------------------------------------------------- 1 | using DigitRecognizer.MachineLearning.Infrastructure.Initialization; 2 | 3 | namespace DigitRecognizer.MachineLearning.Infrastructure.Factories 4 | { 5 | /// 6 | /// Factory for getting IInitializer implementations by providing their InitializerType. 7 | /// 8 | public class InitializerFactory : AbstractTypeFactory 9 | { 10 | /// 11 | /// The singleton instance. 12 | /// 13 | private static InitializerFactory _instance; 14 | 15 | /// 16 | /// The object locked while reading or lazily creating the singleton instance. Required for thread safety. 17 | /// 18 | private static readonly object Lock = new object(); 19 | 20 | /// 21 | /// Gets the singleton instance of the class, creating it on first access. 22 | /// 23 | public static InitializerFactory Instance 24 | { 25 | get 26 | { 27 | lock (Lock) 28 | { 29 | return _instance ?? (_instance = new InitializerFactory()); 30 | } 31 | } 32 | } 33 | 34 | /// 35 | /// Initializes a new instance of the class, registering the IInitializer implementations found in this assembly under their InitializerType property.
36 | /// 37 | private InitializerFactory() : base(typeof(IInitializer), "InitializerType") 38 | { 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/CrossEntropy.cs: -------------------------------------------------------------------------------- 1 | using DigitRecognizer.Core.Extensions; 2 | using DigitRecognizer.Core.Utilities; 3 | 4 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions 5 | { 6 | /// 7 | /// Implements the cross entropy cost function. 8 | /// 9 | public class CrossEntropy : ICostFunction 10 | { 11 | public string Name => "Cross Entropy"; 12 | 13 | /// 14 | /// Calculates the cost for the specified estimated and actual values. 15 | /// 16 | /// The estimated values. 17 | /// The actual values. 18 | /// The cost. 19 | public double Cost(double[] estimatedValues, double[] actualValues) 20 | { 21 | double cost = MathUtilities.CrossEntropy(estimatedValues, actualValues); 22 | 23 | return cost; 24 | } 25 | 26 | /// 27 | /// Gets the derivative of the function for the specified input: input[i] minus 1 at the one-hot index (found with ArgMax), input[i] elsewhere. 28 | /// 29 | /// The input. 30 | /// The one hot encoded array. 31 | /// The derivative with respect to each input. NOTE(review): input - oneHot is the cross-entropy gradient with respect to the softmax logits, so this presumably expects input to hold softmax outputs — confirm. 32 | public double[] Derivative(double[] input, double[] oneHot) 33 | { 34 | var result = new double[input.Length]; 35 | 36 | int oneHotIndex = oneHot.ArgMax(); 37 | 38 | for (var i = 0; i < input.Length; i++) 39 | { 40 | double delta = i == oneHotIndex ?
1.0 : 0.0; 41 | 42 | result[i] = input[i] - delta; 43 | } 44 | 45 | return result; 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/ExponentialRelu.cs: -------------------------------------------------------------------------------- 1 | using DigitRecognizer.Core.Utilities; 2 | 3 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions 4 | { 5 | /// 6 | /// Implements the exponential RELU activation function with Alpha = 0.01. 7 | /// 8 | public class ExponentialRelu : IActivationFunction 9 | { 10 | private const double Alpha = 0.01; 11 | 12 | public string Name => "Exponential Relu"; 13 | 14 | /// 15 | /// Applies the activation function on every element of the specified array. 16 | /// 17 | /// The array of values. 18 | /// The array with values fed through the activation function. 19 | public double[] Activate(double[] arr) 20 | { 21 | double[] activations = MathUtilities.ExponentialRelu(arr, Alpha); 22 | 23 | return activations; 24 | } 25 | 26 | /// 27 | /// Determines the derivative of the function for the specified inputs: 1 for positive inputs, otherwise ExponentialRelu(x, Alpha) + Alpha. 28 | /// 29 | /// The input. 30 | /// The one hot encoded element (not used by this function). 31 | /// The derivative of the function for every input. NOTE(review): for x <= 0 this equals Alpha * e^x only if MathUtilities.ExponentialRelu computes Alpha * (e^x - 1) — confirm. 32 | public double[] Derivative(double[] input, double[] oneHot) 33 | { 34 | var result = new double[input.Length]; 35 | 36 | for (var i = 0; i < input.Length; i++) 37 | { 38 | result[i] = input[i] > 0 ?
1.0 : MathUtilities.ExponentialRelu(input[i], Alpha) + Alpha; 39 | } 40 | 41 | return result; 42 | } 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/IActivationFunction.cs: -------------------------------------------------------------------------------- 1 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions 2 | { 3 | /// 4 | /// Interface that all activation functions must implement. 5 | /// 6 | public interface IActivationFunction : IDifferentiableFunction, IFunction 7 | { 8 | /// 9 | /// Applies the activation function on every element of the specified array. 10 | /// 11 | /// The array of values. 12 | /// The array with values fed through the activation function. 13 | double[] Activate(double[] arr); 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/ICostFunction.cs: -------------------------------------------------------------------------------- 1 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions 2 | { 3 | /// 4 | /// Interface that all cost functions must implement. 5 | /// 6 | public interface ICostFunction : IDifferentiableFunction, IFunction 7 | { 8 | /// 9 | /// Gets the cost for the specified estimated and actual values. 10 | /// 11 | /// The estimated values. 12 | /// The actual values. 13 | /// The cost.
14 | double Cost(double[] estimatedValues, double[] actualValues); 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/IDifferentiableFunction.cs: -------------------------------------------------------------------------------- 1 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions 2 | { 3 | /// 4 | /// This interface should be implemented by all differentiable functions. 5 | /// 6 | public interface IDifferentiableFunction 7 | { 8 | /// 9 | /// Determines the derivative of the function for the specified inputs. 10 | /// 11 | /// The input. 12 | /// The one hot encoded element; some implementations (e.g. CrossEntropy, Softmax) use it to locate the target index via ArgMax, others ignore it. 13 | /// The derivative of the function for every input. 14 | double[] Derivative(double[] input, double[] oneHot); 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/IFunction.cs: -------------------------------------------------------------------------------- 1 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions 2 | { 3 | /// 4 | /// Interface that all functions must implement. 5 | /// 6 | public interface IFunction 7 | { 8 | /// 9 | /// Gets the name of the function. FunctionFactory uses this value as the lookup key. 10 | /// 11 | string Name { get; } 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/LeakyRelu.cs: -------------------------------------------------------------------------------- 1 | using DigitRecognizer.Core.Utilities; 2 | 3 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions 4 | { 5 | /// 6 | /// Implements the leaky RELU activation function.
7 | /// 8 | public class LeakyRelu : IActivationFunction 9 | { 10 | private const double Alpha = 0.01; 11 | 12 | public string Name => "Leaky Relu"; 13 | 14 | /// 15 | /// Applies the activation function on every element of the specified array. 16 | /// 17 | /// The array of values. 18 | /// The array with values fed through the activation function. 19 | public double[] Activate(double[] arr) 20 | { 21 | double[] activations = MathUtilities.LeakyRelu(arr, Alpha); 22 | 23 | return activations; 24 | } 25 | 26 | /// 27 | /// Determines the derivative of the function for the specified inputs: 1 for positive inputs, otherwise Alpha. 28 | /// 29 | /// The input. 30 | /// The one hot encoded element (not used by this function). 31 | /// The derivative of the function for every input. 32 | public double[] Derivative(double[] input, double[] oneHot) 33 | { 34 | var result = new double[input.Length]; 35 | 36 | for (var i = 0; i < input.Length; i++) 37 | { 38 | result[i] = input[i] > 0 ? 1.0 : Alpha; 39 | } 40 | 41 | return result; 42 | } 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/MeanSquareError.cs: -------------------------------------------------------------------------------- 1 | using DigitRecognizer.Core.Extensions; 2 | using DigitRecognizer.Core.Utilities; 3 | 4 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions 5 | { 6 | /// 7 | /// Implements the mean square error function. 8 | /// 9 | public class MeanSquareError : ICostFunction 10 | { 11 | public string Name => "Mean Square Error"; 12 | 13 | /// 14 | /// Calculates the cost for the specified estimated and actual values. 15 | /// 16 | /// The estimated values. 17 | /// The actual values. 18 | /// The cost.
19 | public double Cost(double[] estimatedValues, double[] actualValues) 20 | { 21 | double cost = MathUtilities.MeanSquareErr(estimatedValues, actualValues); 22 | 23 | return cost; 24 | } 25 | 26 | /// 27 | /// Gets the derivative of the function for the specified input: input[i] minus 1 at the one-hot index (found with ArgMax), input[i] elsewhere. 28 | /// 29 | /// The input. 30 | /// The one hot encoded array. 31 | /// The derivative with respect to each input. NOTE(review): input - oneHot matches the MSE gradient only up to a constant factor that depends on how MathUtilities.MeanSquareErr normalizes — confirm. 32 | public double[] Derivative(double[] input, double[] oneHot) 33 | { 34 | var result = new double[input.Length]; 35 | 36 | int oneHotIndex = oneHot.ArgMax(); 37 | 38 | for (var i = 0; i < input.Length; i++) 39 | { 40 | double delta = i == oneHotIndex ? 1.0 : 0.0; 41 | 42 | result[i] = input[i] - delta; 43 | } 44 | 45 | return result; 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/Relu.cs: -------------------------------------------------------------------------------- 1 | using DigitRecognizer.Core.Utilities; 2 | 3 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions 4 | { 5 | /// 6 | /// Implements the RELU activation function. 7 | /// 8 | public class Relu : IActivationFunction 9 | { 10 | public string Name => "Relu"; 11 | 12 | /// 13 | /// Applies the activation function on every element of the specified array. 14 | /// 15 | /// The array of values. 16 | /// The array with values fed through the activation function. 17 | public double[] Activate(double[] arr) 18 | { 19 | double[] activations = MathUtilities.Relu(arr); 20 | 21 | return activations; 22 | } 23 | 24 | /// 25 | /// Determines the derivative of the function for the specified inputs: 1 for positive inputs, otherwise 0. 26 | /// 27 | /// The input. 28 | /// The one hot encoded element (not used by this function). 29 | /// The derivative of the function for every input.
30 | public double[] Derivative(double[] input, double[] oneHot) 31 | { 32 | var result = new double[input.Length]; 33 | 34 | for (var i = 0; i < input.Length; i++) 35 | { 36 | result[i] = input[i] > 0 ? 1.0 : 0.0; 37 | } 38 | 39 | return result; 40 | } 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/Sigmoid.cs: -------------------------------------------------------------------------------- 1 | using DigitRecognizer.Core.Utilities; 2 | 3 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions 4 | { 5 | /// 6 | /// Implements the sigmoid activation function. 7 | /// 8 | public class Sigmoid : IActivationFunction 9 | { 10 | public string Name => "Sigmoid"; 11 | 12 | /// 13 | /// Applies the activation function on every element of the specified array. 14 | /// 15 | /// The array of values. 16 | /// The array with values fed through the activation function. 17 | public double[] Activate(double[] arr) 18 | { 19 | double[] activations = MathUtilities.Sigmoid(arr); 20 | 21 | return activations; 22 | } 23 | 24 | /// 25 | /// Determines the derivative of the function for the specified inputs: sigmoid(x) * (1 - sigmoid(x)). 26 | /// 27 | /// The input (pre-activation values). 28 | /// The one hot encoded element (not used by this function). 29 | /// The derivative of the function for every input.
30 | public double[] Derivative(double[] input, double[] oneHot) 31 | { 32 | var result = new double[input.Length]; 33 | 34 | for (var i = 0; i < input.Length; i++) 35 | { 36 | double sigmoid = MathUtilities.Sigmoid(input[i]); 37 | // d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)); dividing instead would blow up as sigmoid(x) approaches 1. 38 | result[i] = sigmoid * (1 - sigmoid); 39 | } 40 | 41 | return result; 42 | } 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/Softmax.cs: -------------------------------------------------------------------------------- 1 | using DigitRecognizer.Core.Extensions; 2 | using DigitRecognizer.Core.Utilities; 3 | 4 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions 5 | { 6 | /// 7 | /// Implements the softmax activation function. 8 | /// 9 | public class Softmax : IActivationFunction 10 | { 11 | public string Name => "Softmax"; 12 | 13 | /// 14 | /// Applies the activation function on every element of the specified array. 15 | /// 16 | /// The array of values. 17 | /// The array with values fed through the activation function. 18 | public double[] Activate(double[] arr) 19 | { 20 | double[] activations = MathUtilities.Softmax(arr); 21 | 22 | return activations; 23 | } 24 | 25 | /// 26 | /// Determines the derivative of the function for the specified inputs. 27 | /// 28 | /// The input; presumably the softmax outputs rather than the logits — TODO confirm against callers. 29 | /// The one hot encoded element; its ArgMax selects the output k being differentiated. 30 | /// For each i, input[i] * (delta(i, k) - input[k]), which equals d(softmax_k)/d(z_i) when input holds softmax outputs. 31 | public double[] Derivative(double[] input, double[] oneHot) 32 | { 33 | var result = new double[input.Length]; 34 | 35 | int oneHotIndex = oneHot.ArgMax(); 36 | 37 | for (var i = 0; i < input.Length; i++) 38 | { 39 | double delta = i == oneHotIndex ?
1.0 : 0.0; 40 | 41 | result[i] = input[i] * (delta - input[oneHotIndex]); 42 | } 43 | 44 | return result; 45 | } 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/Softplus.cs: -------------------------------------------------------------------------------- 1 | using DigitRecognizer.Core.Utilities; 2 | 3 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions 4 | { 5 | /// 6 | /// Implements the softplus activation function. 7 | /// 8 | public class Softplus : IActivationFunction 9 | { 10 | public string Name => "Softplus"; 11 | 12 | /// 13 | /// Applies the activation function on every element of the specified array. 14 | /// 15 | /// The array of values. 16 | /// The array with values fed through the activation function. 17 | public double[] Activate(double[] arr) 18 | { 19 | double[] activations = MathUtilities.Softplus(arr); 20 | 21 | return activations; 22 | } 23 | 24 | /// 25 | /// Determines the derivative of the function for the specified inputs: the sigmoid of each input, which is the derivative of softplus ln(1 + e^x). 26 | /// 27 | /// The input. 28 | /// The one hot encoded element (not used by this function). 29 | /// The derivative of the function for every input. 30 | public double[] Derivative(double[] input, double[] oneHot) 31 | { 32 | var result = new double[input.Length]; 33 | 34 | for (var i = 0; i < input.Length; i++) 35 | { 36 | result[i] = MathUtilities.Sigmoid(input[i]); 37 | } 38 | 39 | return result; 40 | } 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/Tanh.cs: -------------------------------------------------------------------------------- 1 | using DigitRecognizer.Core.Utilities; 2 | 3 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions 4 | { 5 | /// 6 | /// Implements the hyperbolic tangent activation function.
7 | /// 8 | public class Tanh : IActivationFunction 9 | { 10 | public string Name => "Tanh"; 11 | 12 | /// 13 | /// Applies the activation function on every element of the specified array. 14 | /// 15 | /// The array of values. 16 | /// The array with values fed through the activation function. 17 | public double[] Activate(double[] arr) 18 | { 19 | double[] activations = MathUtilities.Tanh(arr); 20 | 21 | return activations; 22 | } 23 | 24 | /// 25 | /// Determines the derivative of the function for the specified inputs: 1 - tanh(x)^2. 26 | /// 27 | /// The input (pre-activation values). 28 | /// The one hot encoded element (not used by this function). 29 | /// The derivative of the function for every input. 30 | public double[] Derivative(double[] input, double[] oneHot) 31 | { 32 | var result = new double[input.Length]; 33 | 34 | for (var i = 0; i < input.Length; i++) 35 | { 36 | double tanh = MathUtilities.Tanh(input[i]); 37 | 38 | result[i] = 1 - tanh * tanh; 39 | } 40 | 41 | return result; 42 | } 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Initialization/HeInitializer.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using DigitRecognizer.Core.Utilities; 3 | 4 | namespace DigitRecognizer.MachineLearning.Infrastructure.Initialization 5 | { 6 | /// 7 | /// Implements the He initialization algorithm.
8 | /// 9 | public class HeInitializer : IInitializer 10 | { 11 | /// 12 | /// Gets the initializer type (HeInitialization). 13 | /// 14 | public InitializerType InitializerType => InitializerType.HeInitialization; 15 | 16 | /// 17 | /// Returns a matrix whose entries are drawn from a zero-mean normal distribution with standard deviation sqrt(2 / rowCount) (He et al.). 18 | /// 19 | /// The row count (used as the fan-in). 20 | /// The col count. 21 | /// The initialized matrix. 22 | public double[][] Initialize(int rowCount, int colCont) 23 | { 24 | double[][] result = VectorUtilities.CreateMatrix(rowCount, colCont); 25 | 26 | var random = new Random(Guid.NewGuid().GetHashCode()); 27 | // A per-call GUID seed: the parameterless Random constructor is time-seeded on .NET Framework, so matrices created in quick succession could repeat the same values. 28 | double factor = Math.Sqrt(2.0 / rowCount); 29 | // Box-Muller transform: two uniform samples give one standard normal sample, which is scaled to the He standard deviation. 30 | for (var i = 0; i < rowCount; i++) 31 | { 32 | for (var j = 0; j < colCont; j++) 33 | { 34 | result[i][j] = Math.Sqrt(-2.0 * Math.Log(1.0 - random.NextDouble())) * Math.Cos(2.0 * Math.PI * random.NextDouble()) * factor; 35 | } 36 | } 37 | 38 | return result; 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Initialization/IInitializer.cs: -------------------------------------------------------------------------------- 1 | namespace DigitRecognizer.MachineLearning.Infrastructure.Initialization 2 | { 3 | /// 4 | /// Interface for an initialization algorithm. 5 | /// 6 | public interface IInitializer 7 | { 8 | /// 9 | /// Gets the InitializerType this initializer is registered under in InitializerFactory. 10 | /// 11 | InitializerType InitializerType { get; } 12 | 13 | /// 14 | /// Returns an initialized matrix based on the specified parameters. 15 | /// 16 | /// The row count. 17 | /// The col count. 
18 | /// The initialized matrix. 19 | double[][] Initialize(int rowCount, int colCont); 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Initialization/InitializerType.cs: -------------------------------------------------------------------------------- 1 | namespace DigitRecognizer.MachineLearning.Infrastructure.Initialization 2 | { 3 | /// 4 | /// Represents the type of initialization method. 5 | /// 6 | public enum InitializerType 7 | { 8 | /// 9 | /// Initializes all values to zero. 10 | /// 11 | ZeroInitialization = 0, 12 | 13 | /// 14 | /// Initializes all values to random values. 15 | /// 16 | RandomInitialization = 1, 17 | 18 | /// 19 | /// Initializes all values using the Xavier initialization method.
20 | /// 21 | XavierInitialization = 2, 22 | 23 | /// 24 | /// Initializes all values using the He-et-al initialization method. 25 | /// 26 | HeInitialization = 3 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Initialization/RandomInitializer.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using DigitRecognizer.Core.Utilities; 3 | 4 | namespace DigitRecognizer.MachineLearning.Infrastructure.Initialization 5 | { 6 | /// 7 | /// Implements the random initialization algorithm. 8 | /// 9 | public class RandomInitializer : IInitializer 10 | { 11 | /// 12 | /// Gets the initializer type (RandomInitialization). 13 | /// 14 | public InitializerType InitializerType => InitializerType.RandomInitialization; 15 | 16 | /// 17 | /// Returns a matrix whose entries are drawn uniformly from [0, 0.01). 18 | /// 19 | /// The row count. 20 | /// The col count. 21 | /// The initialized matrix. 22 | public double[][] Initialize(int rowCount, int colCont) 23 | { 24 | double[][] result = VectorUtilities.CreateMatrix(rowCount, colCont); 25 | 26 | var random = new Random(Guid.NewGuid().GetHashCode()); 27 | // A per-call GUID seed: the parameterless Random constructor is time-seeded on .NET Framework, so matrices created in quick succession could repeat the same values. 28 | const double factor = 0.01; 29 | 30 | for (var i = 0; i < rowCount; i++) 31 | { 32 | for (var j = 0; j < colCont; j++) 33 | { 34 | result[i][j] = random.NextDouble() * factor; 35 | } 36 | } 37 | 38 | return result; 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Initialization/XavierInitializer.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using DigitRecognizer.Core.Utilities; 3 | 4 | namespace DigitRecognizer.MachineLearning.Infrastructure.Initialization 5 | { 6 | /// 7 | /// Implements the Xavier initialization algorithm.
8 | /// 9 | public class XavierInitializer : IInitializer 10 | { 11 | /// 12 | /// Gets the initializer type (XavierInitialization). 13 | /// 14 | public InitializerType InitializerType => InitializerType.XavierInitialization; 15 | 16 | /// 17 | /// Returns a matrix whose entries are drawn from a zero-mean normal distribution with standard deviation sqrt(1 / rowCount) (the fan-in variant of Xavier initialization). 18 | /// 19 | /// The row count (used as the fan-in). 20 | /// The col count. 21 | /// The initialized matrix. 22 | public double[][] Initialize(int rowCount, int colCont) 23 | { 24 | double[][] result = VectorUtilities.CreateMatrix(rowCount, colCont); 25 | 26 | var random = new Random(Guid.NewGuid().GetHashCode()); 27 | // A per-call GUID seed: the parameterless Random constructor is time-seeded on .NET Framework, so matrices created in quick succession could repeat the same values. 28 | double factor = Math.Sqrt(1.0 / rowCount); 29 | // Box-Muller transform: two uniform samples give one standard normal sample, which is scaled to the Xavier standard deviation. 30 | for (var i = 0; i < rowCount; i++) 31 | { 32 | for (var j = 0; j < colCont; j++) 33 | { 34 | result[i][j] = Math.Sqrt(-2.0 * Math.Log(1.0 - random.NextDouble())) * Math.Cos(2.0 * Math.PI * random.NextDouble()) * factor; 35 | } 36 | } 37 | 38 | return result; 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Initialization/ZeroInitializer.cs: -------------------------------------------------------------------------------- 1 | using DigitRecognizer.Core.Utilities; 2 | 3 | namespace DigitRecognizer.MachineLearning.Infrastructure.Initialization 4 | { 5 | /// 6 | /// Implements the zero initialization algorithm. 7 | /// 8 | public class ZeroInitializer : IInitializer 9 | { 10 | /// 11 | /// Gets the initializer type (ZeroInitialization). 
12 | /// 13 | public InitializerType InitializerType => InitializerType.ZeroInitialization; 14 | 15 | /// 16 | /// Returns the matrix created by VectorUtilities.CreateMatrix without further initialization (presumably all zeros — confirm). 17 | /// 18 | /// The row count. 19 | /// The col count. 20 | /// The initialized matrix.
21 | public double[][] Initialize(int rowCount, int colCont) 22 | { 23 | return VectorUtilities.CreateMatrix(rowCount, colCont); 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Models/ClusterPredictionModel.cs: -------------------------------------------------------------------------------- 1 | using System.Collections.Generic; 2 | using System.Linq; 3 | using DigitRecognizer.Core.Extensions; 4 | using DigitRecognizer.Core.Utilities; 5 | using DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork; 6 | 7 | namespace DigitRecognizer.MachineLearning.Infrastructure.Models 8 | { 9 | public class ClusterPredictionModel : IPredictionModel 10 | { 11 | private readonly List _cluster; 12 | // Creates the ensemble from the given prediction models (copied into a new list). 13 | public ClusterPredictionModel(IEnumerable models) 14 | { 15 | Contracts.ValueNotNull(models, nameof(models)); 16 | 17 | _cluster = new List(models); 18 | } 19 | // Creates the ensemble by wrapping each network in a PredictionModel. 20 | public ClusterPredictionModel(IEnumerable networks) 21 | { 22 | Contracts.ValueNotNull(networks, nameof(networks)); 23 | 24 | _cluster = new List(networks.Select(x=> new PredictionModel(x))); 25 | } 26 | 27 | /// 28 | /// Predicts the output for the specified input by querying every member model and combining the results with the Average() extension (presumably element-wise; behavior for an empty cluster depends on that extension — confirm). 29 | /// 30 | /// The input. 31 | /// The combined output vector of all member models.
32 | public double[] Predict(double[] input) 33 | { 34 | var predictions = new double[_cluster.Count][]; 35 | 36 | for (var i = 0; i < _cluster.Count; i++) 37 | { 38 | double[] prediction = _cluster[i].Predict(input); 39 | 40 | predictions[i] = prediction; 41 | } 42 | 43 | double[] result = predictions.Average(); 44 | 45 | return result; 46 | } 47 | // Loads one PredictionModel per file (via PredictionModel.FromFile) and combines them into a cluster. 48 | public static ClusterPredictionModel FromFiles(string[] filenames) 49 | { 50 | return new ClusterPredictionModel(filenames.Select(PredictionModel.FromFile)); 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Models/IPredictionModel.cs: -------------------------------------------------------------------------------- 1 | namespace DigitRecognizer.MachineLearning.Infrastructure.Models 2 | { 3 | /// 4 | /// Interface for a prediction model. 5 | /// 6 | public interface IPredictionModel 7 | { 8 | /// 9 | /// Predicts the output for the specified input. 10 | /// 11 | /// The input. 12 | /// The output as a vector of probabilities. 13 | double[] Predict(double[] input); 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Models/PredictionModel.cs: -------------------------------------------------------------------------------- 1 | using DigitRecognizer.Core.Extensions; 2 | using DigitRecognizer.Core.Utilities; 3 | using DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork; 4 | using DigitRecognizer.MachineLearning.Serialization; 5 | 6 | namespace DigitRecognizer.MachineLearning.Infrastructure.Models 7 | { 8 | /// 9 | /// Represents a model that runs an internal neural network to predict input values. 10 | /// 11 | public class PredictionModel : IPredictionModel 12 | { 13 | private readonly INeuralNetwork _neuralNetwork; 14 | 15 | /// 16 | /// Initializes a new instance of the class.
17 | /// 18 | /// The neural network used to produce predictions; must not be null. 19 | public PredictionModel(INeuralNetwork neuralNetwork) 20 | { 21 | Contracts.ValueNotNull(neuralNetwork, nameof(neuralNetwork)); 22 | 23 | _neuralNetwork = neuralNetwork; 24 | } 25 | 26 | /// 27 | /// Predicts the output for the specified input by feeding it through the network (converted with AsMatrix) and returning the first row of the result. 28 | /// 29 | /// The input. 30 | /// The output as a vector of probabilities. 31 | public double[] Predict(double[] input) 32 | { 33 | double[][] prediction = _neuralNetwork.FeedForward(input.AsMatrix()); 34 | 35 | return prediction[0]; 36 | } 37 | 38 | /// 39 | /// Saves the underlying neural network to the specified file using NnSerializer. 40 | /// 41 | /// The filename. 42 | public void Save(string filename) 43 | { 44 | var serializer = new NnSerializer(); 45 | 46 | serializer.Serialize(filename, _neuralNetwork); 47 | } 48 | 49 | /// 50 | /// Creates a PredictionModel from the specified file, which must exist and have the .nn extension. 51 | /// 52 | /// The filename. 53 | /// The deserialized prediction model. 54 | public static PredictionModel FromFile(string filename) 55 | { 56 | Contracts.FileExists(filename, nameof(filename)); 57 | Contracts.FileExtensionValid(filename, ".nn", nameof(filename)); 58 | 59 | var deserializer = new NnDeserializer(); 60 | 61 | NeuralNetwork.NeuralNetwork neuralNetwork = deserializer.Deserialize(filename); 62 | 63 | return new PredictionModel(neuralNetwork); 64 | } 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/NeuralNetwork/BiasVector.cs: -------------------------------------------------------------------------------- 1 | using DigitRecognizer.Core.Utilities; 2 | 3 | namespace DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork 4 | { 5 | /// 6 | /// Represents a bias vector that is part of a neural network. 
7 | /// 8 | public class BiasVector : IValueAdjustable 9 | { 10 | private double[] _bias; 11 | 12 | /// 13 | /// Initializes a new instance of the class with the specified bias values.
14 | /// 15 | /// 16 | public BiasVector(double[] bias) 17 | { 18 | _bias = bias; 19 | } 20 | 21 | /// 22 | /// Initializes a new instance of the class with the specified length. 23 | /// 24 | /// The length. 25 | public BiasVector(int length) 26 | { 27 | Contracts.ValueGreaterThanZero(length, nameof(length)); 28 | 29 | _bias = new double[length]; 30 | } 31 | 32 | /// 33 | /// Gets the bias of the . 34 | /// 35 | public double[] Bias => _bias; 36 | 37 | /// 38 | /// Gets the length of the . 39 | /// 40 | public int Length => _bias.Length; 41 | 42 | /// 43 | /// Gets the size of the in bytes. 44 | /// 45 | public int SizeInBytes => Length * sizeof(double); 46 | 47 | /// 48 | /// Adjusts the biases of the using the specified gradient. 49 | /// 50 | /// The gradient with respect to the biases. 51 | /// The learning rate. 52 | public void AdjustValue(double[][] gradient, double learningRate) 53 | { 54 | int rowCount = gradient.Length; 55 | int colCount = gradient[0].Length; 56 | 57 | const int rowIndex = 0; 58 | Contracts.ValuesMatch(1, rowCount, nameof(rowCount)); 59 | Contracts.ValuesMatch(_bias.Length, colCount, nameof(colCount)); 60 | 61 | for (var i = 0; i < colCount; i++) 62 | { 63 | _bias[i] = _bias[i] - gradient[rowIndex][i] * learningRate; 64 | } 65 | } 66 | 67 | public static implicit operator double[] (BiasVector b) 68 | { 69 | return b.Bias; 70 | } 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/NeuralNetwork/INeuralNetwork.cs: -------------------------------------------------------------------------------- 1 | using System.Collections.Generic; 2 | using DigitRecognizer.MachineLearning.Pipeline; 3 | 4 | namespace DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork 5 | { 6 | /// 7 | /// Contains basic methods that a neural network should implement. 
8 | /// 9 | public interface INeuralNetwork : ILearningPipelineNeuralNetworkModel 10 | { 11 | double LearningRate { get; set; } 12 | int NumberOfLayers { get; } 13 | 14 | Core.Data.LinkedList Layers { get; } 15 | List WeightedSumCache { get; } 16 | List ActivationCache { get; } 17 | 18 | double[][] FeedForward(double[][] input); 19 | void AddLayer(IEnumerable layers); 20 | void AddLayer(NnLayer layer); 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/NeuralNetwork/IValueAdjustable.cs: -------------------------------------------------------------------------------- 1 | namespace DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork 2 | { 3 | /// 4 | /// Interface that should be implemented by classes whose values should be adjusted. 5 | /// 6 | public interface IValueAdjustable 7 | { 8 | /// 9 | /// Adjusts the current values with specified gradient and learning rate. 10 | /// 11 | /// The gradient values. 12 | /// The learning rate. 13 | void AdjustValue(double[][] gradient, double learningRate); 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/NeuralNetwork/WeightMatrix.cs: -------------------------------------------------------------------------------- 1 | using DigitRecognizer.Core.Utilities; 2 | using DigitRecognizer.MachineLearning.Infrastructure.Factories; 3 | using DigitRecognizer.MachineLearning.Infrastructure.Initialization; 4 | using DigitRecognizer.MachineLearning.Pipeline; 5 | 6 | namespace DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork 7 | { 8 | /// 9 | /// Represents a weight matrix that is part of a . 10 | /// 11 | public class WeightMatrix : IValueAdjustable 12 | { 13 | private readonly double[][] _weights; 14 | 15 | /// 16 | /// Initializes a new instance of the class with the specified parameters. 
17 |         ///
18 |         /// The row count. Must be greater than zero.
19 |         /// The col count. Must be greater than zero.
20 |         public WeightMatrix(int rowCount, int colCount)
21 |         {
22 |             Contracts.ValueGreaterThanZero(rowCount, nameof(rowCount));
23 |             Contracts.ValueGreaterThanZero(colCount, nameof(colCount));
24 |
25 |             // The initializer type is read from the global pipeline settings at construction time.
26 |             IInitializer initializer = InitializerFactory.Instance.GetInstance(PipelineSettings.Instance.WeightsInitializerType);
27 |
28 |             _weights = initializer.Initialize(rowCount, colCount);
29 |         }
30 |
31 |         ///
32 |         /// Initializes a new instance of the WeightMatrix class that wraps the specified weights.
33 |         ///
34 |         /// The weights; the array is stored as-is (not copied).
35 |         public WeightMatrix(double[][] weights)
36 |         {
37 |             Contracts.ValueNotNull(weights, nameof(weights));
38 |
39 |             _weights = weights;
40 |         }
41 |
42 |         ///
43 |         /// Gets the row count.
44 |         ///
45 |         public int RowCount => _weights.Length;
46 |
47 |         ///
48 |         /// Gets the column count, taken from the first row.
49 |         ///
50 |         public int ColCount => _weights[0].Length;
51 |
52 |         ///
53 |         /// Gets the weights flattened by VectorUtilities.Flatten.
54 |         ///
55 |         public double[] Flattened => VectorUtilities.Flatten(_weights);
56 |
57 |         ///
58 |         /// Gets the size of the WeightMatrix in bytes (RowCount * ColCount * sizeof(double)).
59 |         ///
60 |         public int SizeInBytes => ColCount * RowCount * sizeof(double);
61 |
62 |         ///
63 |         /// Gets the weights of the WeightMatrix (the backing array itself, not a copy).
64 |         ///
65 |         public double[][] Weights => _weights;
66 |
67 |         ///
68 |         /// Adjusts the weights of the WeightMatrix by one gradient step, with optional L2 weight decay:
69 |         /// w = (1 - RegularizationFactor * learningRate / DatasetSize) * w - gradient * learningRate.
70 |         ///
71 |         /// The gradient with respect to the weights; must have RowCount rows and ColCount columns.
72 |         /// The learning rate.
73 |         public void AdjustValue(double[][] gradient, double learningRate)
74 |         {
75 |             Contracts.ValueNotNull(gradient, nameof(gradient));
76 |
77 |             // Check the row count before reading gradient[0], so an empty gradient fails
78 |             // the contract instead of throwing IndexOutOfRangeException.
79 |             int rowCount = gradient.Length;
80 |             Contracts.ValuesMatch(RowCount, rowCount, nameof(RowCount));
81 |
82 |             int colCount = gradient[0].Length;
83 |             Contracts.ValuesMatch(ColCount, colCount, nameof(ColCount));
84 |
85 |             var regularizationFactor = 1.0;
86 |
87 |             if (PipelineSettings.Instance.UseL2Regularization)
88 |             {
89 |                 // Weight-decay form of L2 regularization, scaled by the dataset size.
90 |                 regularizationFactor = 1 - PipelineSettings.Instance.RegularizationFactor * learningRate / PipelineSettings.Instance.DatasetSize;
91 |             }
92 |
93 |             for (var i = 0; i < rowCount; i++)
94 |             {
95 |                 for (var j = 0; j < colCount; j++)
96 |                 {
97 |                     _weights[i][j] = regularizationFactor * _weights[i][j] - gradient[i][j] * learningRate;
98 |                 }
99 |             }
100 |         }
101 |
102 |         ///
103 |         /// Implicitly converts a WeightMatrix to its backing weights array; a null WeightMatrix converts to null.
104 |         ///
105 |         public static implicit operator double[][] (WeightMatrix m)
106 |         {
107 |             return m?._weights;
108 |         }
109 |     }
110 | }
111 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Optimization/LearningRateDecay/ExponentailDecay.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using DigitRecognizer.Core.Utilities;
3 | using DigitRecognizer.MachineLearning.Pipeline;
4 |
5 | namespace DigitRecognizer.MachineLearning.Optimization.LearningRateDecay
6 | {
7 |     ///
8 |     /// Implements the exponential decay learning rate scheduler: lr = initialLearningRate * exp(-decayRate * currentIteration).
9 |     ///
10 |     public class ExponentailDecay : ILearningRateDecay
11 |     {
12 |         ///
13 |         /// The decay rate k in exp(-k * iteration).
14 |         ///
15 |         private readonly double _decayRate;
16 |
17 |         ///
18 |         /// The learning rate returned at iteration zero.
19 |         ///
20 |         private readonly double _initialLearningRate;
21 |
22 |         ///
23 |         /// Initializes a new instance of the ExponentailDecay class.
24 |         ///
25 |         /// The initial learning rate. Must be greater than zero.
26 |         /// The decay rate. Must be greater than zero.
27 |         public ExponentailDecay(double initialLearningRate, double decayRate)
28 |         {
29 |             // Both hyper-parameters must be strictly positive.
30 |             Contracts.ValueGreaterThanZero(initialLearningRate, nameof(initialLearningRate));
31 |             Contracts.ValueGreaterThanZero(decayRate, nameof(decayRate));
32 |
33 |             _decayRate = decayRate;
34 |             _initialLearningRate = initialLearningRate;
35 |         }
36 |
37 |         ///
38 |         /// Computes the exponentially decayed learning rate for the current training iteration.
39 |         ///
40 |         /// Ignored; the result depends only on the initial learning rate and PipelineSettings.Instance.CurrentIteration.
41 |         /// initialLearningRate * exp(-decayRate * CurrentIteration).
42 |         public double DecayLearningRate(double learningRate)
43 |         {
44 |             var iteration = PipelineSettings.Instance.CurrentIteration;
45 |             double decayFactor = Math.Exp(-_decayRate * iteration);
46 |
47 |             return _initialLearningRate * decayFactor;
48 |         }
49 |     }
50 | }
51 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Optimization/LearningRateDecay/ILearningRateDecay.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.MachineLearning.Optimization.LearningRateDecay
2 | {
3 |     ///
4 |     /// Contract for a learning rate schedule that produces the learning rate to use next.
5 |     ///
6 |     public interface ILearningRateDecay
7 |     {
8 |         ///
9 |         /// Computes the decayed learning rate.
10 |         ///
11 |         /// The current learning rate (some implementations ignore it).
12 |         /// The learning rate to use from now on.
13 |         double DecayLearningRate(double learningRate);
14 |     }
15 | }
16 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Optimization/LearningRateDecay/StepDecay.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using DigitRecognizer.Core.Utilities;
3 | using DigitRecognizer.MachineLearning.Pipeline;
4 |
5 | namespace DigitRecognizer.MachineLearning.Optimization.LearningRateDecay
6 | {
7 |     ///
8 |     /// Implements the step decay learning rate scheduler: the learning rate is multiplied by a fixed drop factor once every epochDrop epochs.
9 |     ///
10 |     public class StepDecay : ILearningRateDecay
11 |     {
12 |         ///
13 |         /// The learning rate returned at epoch zero.
14 |         ///
15 |         private readonly double _initialLearningRate;
16 |
17 |         ///
18 |         /// The factor the learning rate is multiplied by at each drop.
19 |         ///
20 |         private readonly double _drop;
21 |
22 |         ///
23 |         /// The number of epochs that must elapse between two consecutive drops.
24 |         ///
25 |         private readonly double _epochDrop;
26 |
27 |         ///
28 |         /// Initializes a new instance of the StepDecay class.
29 |         ///
30 |         /// The initial learning rate. Must be greater than zero.
31 |         /// The drop factor. Must be greater than zero.
32 |         /// The number of epochs per drop. Must be greater than zero.
33 |         public StepDecay(double initialLearningRate, double drop, double epochDrop)
34 |         {
35 |             Contracts.ValueGreaterThanZero(initialLearningRate, nameof(initialLearningRate));
36 |             Contracts.ValueGreaterThanZero(drop, nameof(drop));
37 |             Contracts.ValueGreaterThanZero(epochDrop, nameof(epochDrop));
38 |
39 |             _epochDrop = epochDrop;
40 |             _drop = drop;
41 |             _initialLearningRate = initialLearningRate;
42 |         }
43 |
44 |         ///
45 |         /// Computes the step-decayed learning rate for the current epoch.
46 |         ///
47 |         /// Ignored; the result depends only on the initial learning rate and PipelineSettings.Instance.CurrentEpoch.
48 |         /// initialLearningRate * drop ^ floor(CurrentEpoch / epochDrop).
49 |         public double DecayLearningRate(double learningRate)
50 |         {
51 |             double completedDrops = Math.Floor(PipelineSettings.Instance.CurrentEpoch / _epochDrop);
52 |
53 |             return _initialLearningRate * Math.Pow(_drop, completedDrops);
54 |         }
55 |     }
56 | }
57 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Optimization/LearningRateDecay/TimeBasedDecay.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Utilities;
2 | using DigitRecognizer.MachineLearning.Pipeline;
3 |
4 | namespace DigitRecognizer.MachineLearning.Optimization.LearningRateDecay
5 | {
6 |     ///
7 |     /// Implements the time-based learning rate scheduler: lr * 1 / (1 + decayRate * currentIteration).
8 |     ///
9 |     public class TimeBasedDecay : ILearningRateDecay
10 |     {
11 |         ///
12 |         /// The decay rate applied per iteration.
13 |         ///
14 |         private readonly double _decayRate;
15 |
16 |         ///
17 |         /// Initializes a new instance of the TimeBasedDecay class.
18 |         ///
19 |         /// The decay rate. Must be greater than zero.
20 |         public TimeBasedDecay(double decayRate)
21 |         {
22 |             Contracts.ValueGreaterThanZero(decayRate, nameof(decayRate));
23 |             _decayRate = decayRate;
24 |         }
25 |
26 |         ///
27 |         /// Scales the supplied learning rate by 1 / (1 + decayRate * PipelineSettings.Instance.CurrentIteration).
28 |         ///
29 |         /// The learning rate to scale. NOTE(review): unlike the other schedulers this scales the caller's value instead of a
30 |         /// stored initial rate, so feeding the previous result back in compounds the decay - confirm this is intended.
31 |         /// The scaled learning rate.
32 |         public double DecayLearningRate(double learningRate)
33 |         {
34 |             double scale = 1.0 / (1.0 + _decayRate * PipelineSettings.Instance.CurrentIteration);
35 |
36 |             return learningRate * scale;
37 |         }
38 |     }
39 | }
40 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Optimization/Optimizers/GradientDescentOptimizer.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.MachineLearning.Infrastructure.Functions;
2 | using DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork;
3 |
4 | namespace DigitRecognizer.MachineLearning.Optimization.Optimizers
5 | {
6 |     ///
7 |     /// Plain gradient descent: relies entirely on the update logic inherited from BaseOptimizer.
8 |     ///
9 |     public class GradientDescentOptimizer : BaseOptimizer
10 |     {
11 |         ///
12 |         /// Initializes a new instance of the GradientDescentOptimizer class.
13 |         ///
14 |         /// The neural network to optimize.
15 |         /// The cost function used to measure the error.
16 |         public GradientDescentOptimizer(INeuralNetwork neuralNetwork, ICostFunction costFunction)
17 |             : base(neuralNetwork, costFunction)
18 |         {
19 |             // Nothing else to set up; all behavior lives in BaseOptimizer.
20 |         }
21 |     }
22 | }
23 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Optimization/Optimizers/IOptimizer.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Data;
2 | using DigitRecognizer.MachineLearning.Infrastructure.Functions;
3 | using DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork;
4 | using DigitRecognizer.MachineLearning.Pipeline;
5 |
6 | namespace DigitRecognizer.MachineLearning.Optimization.Optimizers
7 | {
8 |     ///
9 |     /// Contract for an optimization algorithm that trains an INeuralNetwork against an ICostFunction.
10 |     ///
11 |     public interface IOptimizer : ILearningPipelineOptimizer
12 |     {
13 |         ///
14 |         /// Gets the neural network being optimized.
15 |         ///
16 |         INeuralNetwork NeuralNetwork { get; }
17 |
18 |         ///
19 |         /// Gets the cost function used to measure prediction error.
20 |         ///
21 |         ICostFunction CostFunction { get; }
22 |
23 |         ///
24 |         /// Calculates the cost of a single prediction.
25 |         ///
26 |         /// The predicted values.
27 |         /// The expected one-hot value (presumably the index of the correct class).
28 |         /// The cost.
29 |         double CalculateError(double[] prediction, int oneHot);
30 |
31 |         ///
32 |         /// Performs the backpropagation algorithm.
33 |         ///
34 |         /// The predictions.
35 |         /// The expected one-hot values.
36 |         void Backpropagate(double[][] predictions, int[] oneHots);
37 |
38 |         ///
39 |         /// Adjusts the parameters of the given layer node using the specified deltas and learning rate.
40 |         ///
41 |         /// The layer node.
42 |         /// The delta values.
43 |         /// The gradient values.
44 |         /// The learning rate.
45 |         void AdjustParameters(LinkedListNode node, double[][] delta, double[][] gradient, double learningRate);
46 |
47 |         ///
48 |         /// Calculates the derivative of the cost function with respect to the network output.
49 |         ///
50 |         /// The predictions.
51 |         /// The expected one-hot values.
52 |         /// The derivatives.
53 |         double[][] CalculateOutputDerivative(double[][] predictions, int[] oneHots);
54 |
55 |         ///
56 |         /// Calculates the derivative of the weighted sum with respect to the specified activation function.
57 |         ///
58 |         /// The activation function.
59 |         /// The depth of the layer node.
60 |         /// The expected one-hot values.
61 |         /// The derivatives.
62 |         double[][] WeightedSumDerivative(IActivationFunction activationFunction, int nodeDepth, int[] oneHots);
63 |     }
64 | }
65 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Optimization/Optimizers/MomentumOptimizer.cs:
--------------------------------------------------------------------------------
1 | using System.Collections.Generic;
2 | using DigitRecognizer.Core.Utilities;
3 | using DigitRecognizer.MachineLearning.Infrastructure.Functions;
4 | using DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork;
5 |
6 | namespace DigitRecognizer.MachineLearning.Optimization.Optimizers
7 | {
8 |     ///
9 |     /// Gradient descent with momentum: each layer keeps velocities v = momentum * v + delta,
10 |     /// and the velocities (instead of the raw deltas) are handed to BaseOptimizer for the update.
11 |     ///
12 |     public class MomentumOptimizer : BaseOptimizer
13 |     {
14 |         private readonly double _momentum;
15 |
16 |         // One velocity matrix per layer, looked up by the layer node's Depth.
17 |         private List _biasVelocities;
18 |         private List _weightVelocities;
19 |
20 |         ///
21 |         /// Initializes a new instance of the MomentumOptimizer class.
22 |         ///
23 |         /// The neural network.
24 |         /// The cost function.
25 |         /// The momentum coefficient. Must be greater than zero.
26 |         public MomentumOptimizer(INeuralNetwork neuralNetwork, ICostFunction costFunction, double momentum)
27 |             : base(neuralNetwork, costFunction)
28 |         {
29 |             Contracts.ValueGreaterThanZero(momentum, nameof(momentum));
30 |             Contracts.ValueNotNull(neuralNetwork, nameof(neuralNetwork));
31 |
32 |             _momentum = momentum;
33 |             InitializeVelocities(neuralNetwork);
34 |         }
35 |
36 |         ///
37 |         /// Creates, for each layer, a 1 x outputs bias velocity and an inputs x outputs weight velocity via VectorUtilities.CreateMatrix.
38 |         ///
39 |         /// The neural network whose layers determine the matrix shapes.
40 |         private void InitializeVelocities(INeuralNetwork neuralNetwork)
41 |         {
42 |             _biasVelocities = new List();
43 |             _weightVelocities = new List();
44 |
45 |             foreach (NnLayer layer in neuralNetwork.Layers.ToList())
46 |             {
47 |                 int outputs = layer.NumberOfOutputs;
48 |                 int inputs = layer.NumberOfInputs;
49 |
50 |                 _biasVelocities.Add(VectorUtilities.CreateMatrix(1, outputs));
51 |                 _weightVelocities.Add(VectorUtilities.CreateMatrix(inputs, outputs));
52 |             }
53 |         }
54 |
55 |         ///
56 |         /// Folds the deltas into the layer's velocities and applies the velocities through BaseOptimizer.
57 |         ///
58 |         /// The layer node; its Depth selects the velocity matrices.
59 |         /// The bias delta values (a single row).
60 |         /// The weight delta values.
61 |         /// The learning rate.
62 |         public override void AdjustParameters(Core.Data.LinkedListNode node, double[][] bDelta, double[][] wDelta, double learningRate)
63 |         {
64 |             double[][] biasVelocity = _biasVelocities[node.Depth];
65 |             double[][] weightVelocity = _weightVelocities[node.Depth];
66 |
67 |             // The bias velocity is a single row vector.
68 |             double[] biasRow = biasVelocity[0];
69 |             for (var col = 0; col < biasRow.Length; col++)
70 |             {
71 |                 biasRow[col] = _momentum * biasRow[col] + bDelta[0][col];
72 |             }
73 |
74 |             for (var row = 0; row < weightVelocity.Length; row++)
75 |             {
76 |                 double[] velocityRow = weightVelocity[row];
77 |                 for (var col = 0; col < weightVelocity[0].Length; col++)
78 |                 {
79 |                     velocityRow[col] = _momentum * velocityRow[col] + wDelta[row][col];
80 |                 }
81 |             }
82 |
83 |             base.AdjustParameters(node, biasVelocity, weightVelocity, learningRate);
84 |         }
85 |     }
86 | }
87 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Pipeline/ILearningPipelineDataLoader.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.MachineLearning.Pipeline
2 | {
3 |     ///
4 |     /// Pipeline item responsible for supplying data.
5 |     ///
6 |     public interface ILearningPipelineDataLoader : ILearningPipelineItem
7 |     {
8 |         ///
9 |         /// Loads the data from its source file.
10 |         ///
11 |         /// The loaded data object.
12 |         object LoadData();
13 |     }
14 | }
15 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Pipeline/ILearningPipelineItem.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.MachineLearning.Pipeline
2 | {
3 |     ///
4 |     /// Common marker interface for learning pipeline items; the data loader, model and optimizer interfaces derive from it.
5 |     ///
6 |     public interface ILearningPipelineItem
7 |     {
8 |     }
9 | }
10 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Pipeline/ILearningPipelineNeuralNetworkModel.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.MachineLearning.Pipeline
2 | {
3 |     ///
4 |     /// Pipeline item representing a neural network model that can make predictions.
5 |     ///
6 |     public interface ILearningPipelineNeuralNetworkModel : ILearningPipelineItem
7 |     {
8 |         ///
9 |         /// Produces a prediction for the specified input.
10 |         ///
11 |         /// The input.
12 |         /// The prediction.
13 |         double[][] Predict(double[][] input);
14 |     }
15 | }
16 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Pipeline/ILearningPipelineOptimizer.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.MachineLearning.Pipeline
2 | {
3 |     ///
4 |     /// Pipeline item representing an optimizer.
5 |     ///
6 |     public interface ILearningPipelineOptimizer : ILearningPipelineItem
7 |     {
8 |         ///
9 |         /// Runs one optimization step for the specified predictions.
10 |         ///
11 |         /// The predictions.
12 |         /// The expected one-hot values.
13 |         void Optimize(double[][] predictions, int[] oneHots);
14 |     }
15 | }
16 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Properties/AssemblyInfo.cs:
--------------------------------------------------------------------------------
1 | using System.Reflection;
2 | using System.Runtime.InteropServices;
3 |
4 | // Descriptive metadata embedded in the DigitRecognizer.MachineLearning assembly.
5 | // Edit these attribute values to change the information recorded for it.
6 | [assembly: AssemblyTitle("DigitRecognizer.MachineLearning")]
7 | [assembly: AssemblyDescription("")]
8 | [assembly: AssemblyConfiguration("")]
9 | [assembly: AssemblyCompany("")]
10 | [assembly: AssemblyProduct("DigitRecognizer.MachineLearning")]
11 | [assembly: AssemblyCopyright("Copyright © 2018")]
12 | [assembly: AssemblyTrademark("")]
13 | [assembly: AssemblyCulture("")]
14 |
15 | // The types in this assembly are hidden from COM by default.
16 | // To expose an individual type to COM, set ComVisible to true on that type.
17 | [assembly: ComVisible(false)]
18 |
19 | // Type library ID used if this project is ever exposed to COM.
20 | [assembly: Guid("bf82c3df-0efe-45db-8e3d-1ffdd548ae73")]
21 |
22 | // Assembly version, in the form Major.Minor.Build.Revision.
23 | // Build and Revision may be defaulted with '*', e.g. [assembly: AssemblyVersion("1.0.*")]
24 | [assembly: AssemblyVersion("1.0.0.0")]
25 | [assembly: AssemblyFileVersion("1.0.0.0")]
26 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Providers/BatchDataProvider.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Data;
2 |
3 | namespace DigitRecognizer.MachineLearning.Providers
4 | {
5 |     ///
6 |     /// Data provider that reads MNIST labels and images in batches of configurable size and returns them as an MnistImageBatch.
7 |     ///
8 |     public class BatchDataProvider : DataProviderBase
9 |     {
10 |         ///
11 |         /// Initializes a new instance of the BatchDataProvider class.
12 |         ///
13 |         /// The path to the label file.
14 |         /// The path to the image file.
15 |         /// The number of samples returned by each read.
16 |         public BatchDataProvider(string labelFilename, string imageFilename, int batchSize)
17 |             : base(labelFilename, imageFilename, batchSize)
18 |         {
19 |         }
20 |
21 |         ///
22 |         /// Reads a batch of BatchSize labels and images from the files.
23 |         ///
24 |         /// A batch whose pixel values have been divided by 255.
25 |         public override MnistImageBatch GetData()
26 |         {
27 |             int[] label = LabelReader.ReadLabels(BatchSize);
28 |
29 |             double[][] pixels = PixelReader.ReadPixels(BatchSize, ImageSizeInPixels);
30 |
31 |             pixels = NormalizePixels(pixels);
32 |
33 |             var result = new MnistImageBatch(label, pixels);
34 |
35 |             return result;
36 |         }
37 |
38 |         ///
39 |         /// Normalizes the specified matrix of pixels in place by dividing every value by 255,
40 |         /// which maps byte intensities 0..255 into the range [0, 1].
41 |         ///
42 |         /// The pixels; modified in place.
43 |         /// The same array instance, normalized.
44 |         private double[][] NormalizePixels(double[][] pixels)
45 |         {
46 |             foreach (double[] row in pixels)
47 |             {
48 |                 // Bound by each row's own length so jagged input is handled correctly.
49 |                 for (var j = 0; j < row.Length; j++)
50 |                 {
51 |                     row[j] /= 255d;
52 |                 }
53 |             }
54 |
55 |             return pixels;
56 |         }
57 |
58 |         ///
59 |         /// Loads the data from a source file.
60 |         ///
61 |         /// The batch returned by GetData, as an object.
62 |         public override object LoadData()
63 |         {
64 |             return GetData();
65 |         }
66 |     }
67 | }
68 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Providers/DataProviderBase.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.IO;
2 | using DigitRecognizer.Core.Utilities;
3 |
4 | namespace DigitRecognizer.MachineLearning.Providers
5 | {
6 |     ///
7 |     /// Represents a base class for a data provider that reads MNIST labels and pixels from files.
8 |     ///
9 |     /// The type of data the provider returns.
10 |     public abstract class DataProviderBase : IDataProvider
11 |     {
12 |         private readonly string _labelFilename;
13 |         private readonly string _imageFilename;
14 |         private readonly int _batchSize;
15 |
16 |         ///
17 |         /// The size of a single MNIST image in pixels (28 x 28).
18 |         ///
19 |         protected const int ImageSizeInPixels = 784;
20 |
21 |         ///
22 |         /// The ILabelReader instance used for fetching labels.
23 |         ///
24 |         protected readonly ILabelReader LabelReader;
25 |
26 |         ///
27 |         /// The IPixelReader instance used for fetching pixels.
28 |         ///
29 |         protected readonly IPixelReader PixelReader;
30 |
31 |         ///
32 |         /// Gets the batch size of the data provider.
33 |         ///
34 |         public int BatchSize => _batchSize;
35 |
36 |         ///
37 |         /// Gets the name of the file containing labels.
38 |         ///
39 |         public string LabelFilename => _labelFilename;
40 |
41 |         ///
42 |         /// Gets the name of the file containing images.
43 |         ///
44 |         public string ImageFilename => _imageFilename;
45 |
46 |         ///
47 |         /// Initializes a new instance of the DataProviderBase class.
48 |         ///
49 |         /// The path to the file containing the labels.
50 |         /// The path to the file containing the images.
51 |         /// The size of the batch to read. Must be greater than zero.
52 |         protected DataProviderBase(string labelFilename, string imageFilename, int batchSize)
53 |         {
54 |             // Validate every argument before any reader is created, and check for
55 |             // null/empty before touching the file system (same order for both files).
56 |             Contracts.StringNotNullOrEmpty(labelFilename, nameof(labelFilename));
57 |             Contracts.FileExists(labelFilename, nameof(labelFilename));
58 |
59 |             Contracts.StringNotNullOrEmpty(imageFilename, nameof(imageFilename));
60 |             Contracts.FileExists(imageFilename, nameof(imageFilename));
61 |
62 |             Contracts.ValueGreaterThanZero(batchSize, nameof(batchSize));
63 |
64 |             // NOTE(review): the readers are never disposed by this type; if they hold open
65 |             // file handles, consider making the provider IDisposable.
66 |             LabelReader = new LabelReader(labelFilename);
67 |             PixelReader = new PixelReader(imageFilename);
68 |             _labelFilename = labelFilename;
69 |             _imageFilename = imageFilename;
70 |             _batchSize = batchSize;
71 |         }
72 |
73 |         ///
74 |         /// Gets the data from the file system. The base implementation returns the default value of T.
75 |         ///
76 |         /// The data.
77 |         public virtual T GetData()
78 |         {
79 |             return default(T);
80 |         }
81 |
82 |         ///
83 |         /// Loads the data from a source file by delegating to GetData.
84 |         ///
85 |         /// The result of GetData, as an object.
86 |         public virtual object LoadData()
87 |         {
88 |             return GetData();
89 |         }
90 |     }
91 | }
92 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Providers/IDataProvider.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.MachineLearning.Pipeline;
2 |
3 | namespace DigitRecognizer.MachineLearning.Providers
4 | {
5 |     ///
6 |     /// Represents a data provider abstraction that a neural network will use for training.
7 |     ///
8 |     /// The type of data returned by GetData.
9 |     public interface IDataProvider : ILearningPipelineDataLoader
10 |     {
11 |         ///
12 |         /// Gets the data from a file.
13 |         ///
14 |         /// The data.
15 |         T GetData();
16 |     }
17 | }
18 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Serialization/INnSerializable.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.IO;
2 |
3 | namespace DigitRecognizer.MachineLearning.Serialization
4 | {
5 |     ///
6 |     /// Implemented by neural network layers to support serialization to a binary file.
7 |     ///
8 |     public interface INnSerializable
9 |     {
10 |         ///
11 |         /// Serializes the object to a NnSerializationContext.
12 |         ///
13 |         /// The serialization context containing information about the object.
14 |         NnSerializationContext Serialize();
15 |     }
16 | }
17 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Serialization/NnDeserializer.cs:
--------------------------------------------------------------------------------
1 | using System.Collections.Generic;
2 | using System.IO;
3 | using System.Linq;
4 | using DigitRecognizer.Core.IO;
5 | using DigitRecognizer.MachineLearning.Infrastructure.Factories;
6 | using DigitRecognizer.MachineLearning.Infrastructure.Functions;
7 | using DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork;
8 |
9 | namespace DigitRecognizer.MachineLearning.Serialization
10 | {
11 |     ///
12 |     /// Deserializer for reading NnLayer objects (and whole NeuralNetwork objects) from a binary file.
13 |     ///
14 |     public class NnDeserializer
15 |     {
16 |         ///
17 |         /// Deserializes the layers stored in the specified file into a new NeuralNetwork.
18 |         ///
19 |         /// The path of the file to read.
20 |         /// A NeuralNetwork built from the deserialized layers.
21 |         public NeuralNetwork Deserialize(string filename)
22 |         {
23 |             IEnumerable layers = DeserializeLayers(filename);
24 |
25 |             // NOTE(review): 1.0 is presumably the learning rate argument - confirm against NeuralNetwork's constructor.
26 |             var neuralNetwork = new NeuralNetwork(layers, 1.0);
27 |
28 |             return neuralNetwork;
29 |         }
30 |
31 |         ///
32 |         /// Deserializes data from the specified file into a collection of NnLayer objects.
33 |         ///
34 |         /// The filename.
35 |         /// A fully materialized collection of neural network layers.
36 |         public IEnumerable DeserializeLayers(string filename)
37 |         {
38 |             using (var deserializer = new NnBinaryDeserializer(filename, FileMode.Open))
39 |             {
40 |                 IEnumerable contexts = deserializer.Deserialize();
41 |
42 |                 // Select is lazy: materialize while the deserializer is still open. Otherwise a
43 |                 // lazily-read context sequence would be consumed after Dispose, and every
44 |                 // re-enumeration by the caller would build a fresh set of NnLayer objects.
45 |                 IEnumerable result = contexts.Select(DeserializeContext).ToList();
46 |
47 |                 return result;
48 |             }
49 |         }
50 |
51 |         ///
52 |         /// Deserializes the specified NnSerializationContext into a NnLayer object.
53 |         ///
54 |         /// The serialization context.
55 |         /// The deserialized neural network layer.
56 |         public NnLayer DeserializeContext(NnSerializationContext serializationContext)
57 |         {
58 |             NnSerializationContextInfo contextInfo = serializationContext.SerializationContextInfo;
59 |
60 |             double[] data = serializationContext.FileData;
61 |
62 |             // The activation function is re-created by name through the factory.
63 |             var activationFunction = (IActivationFunction) FunctionFactory.Instance.GetInstance(contextInfo.ActivationFunctionName);
64 |
65 |             var layer = new NnLayer(contextInfo.WeightMatrixRowCount, contextInfo.WeightMatrixColCount, contextInfo.BiasLength, data, activationFunction);
66 |
67 |             return layer;
68 |         }
69 |     }
70 | }
71 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Serialization/NnSerializer.cs:
--------------------------------------------------------------------------------
1 | using System.Collections.Generic;
2 | using System.IO;
3 | using System.Linq;
4 | using DigitRecognizer.Core.IO;
5 | using DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork;
6 |
7 | namespace DigitRecognizer.MachineLearning.Serialization
8 | {
9 |     ///
10 |     /// Serializer for writing NnLayer objects to a binary file.
11 |     ///
12 |     public class NnSerializer
13 |     {
14 |         ///
15 |         /// Serializes the layers of the specified INeuralNetwork to a file.
16 |         ///
17 |         /// The filename to write to.
18 |         /// The neural network.
19 |         public void Serialize(string filename, INeuralNetwork neuralNetwork)
20 |         {
21 |             Serialize(filename, neuralNetwork.Layers.ToList());
22 |         }
23 |
24 |         ///
25 |         /// Serializes the specified collection of NnLayer objects to a file, replacing any existing file.
26 |         ///
27 |         /// The filename to write to.
28 |         /// The collection of layers to serialize.
29 |         public void Serialize(string filename, IEnumerable collection)
30 |         {
31 |             // Build every serialization context before the file is opened with FileMode.Create,
32 |             // which truncates an existing file: a failure inside layer.Serialize() then cannot
33 |             // leave a previously saved model half-written (assuming NnBinarySerializer opens the file in its constructor).
34 |             IEnumerable contexts = collection.Select(layer => layer.Serialize()).ToList();
35 |
36 |             using (var serializer = new NnBinarySerializer(filename, FileMode.Create))
37 |             {
38 |                 serializer.Serialize(contexts);
39 |             }
40 |         }
41 |     }
42 | }
43 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/App.config:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Components/PredictionPane.Designer.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.Presentation.Components
2 | {
3 |     partial class PredictionPane
4 |     {
5 |         ///
6 |         /// Required designer variable.
7 |         ///
8 |         private System.ComponentModel.IContainer components = null;
9 |
10 |         ///
11 |         /// Clean up any resources being used.
12 |         ///
13 |         /// true if managed resources should be disposed; otherwise, false.
14 |         protected override void Dispose(bool disposing)
15 |         {
16 |             if (disposing && (components != null))
17 |             {
18 |                 components.Dispose();
19 |             }
20 |             base.Dispose(disposing);
21 |         }
22 |
23 |         #region Component Designer generated code
24 |
25 |         ///
26 |         /// Required method for Designer support - do not modify
27 |         /// the contents of this method with the code editor.
28 |         ///
29 |         private void InitializeComponent()
30 |         {
31 |             this.SuspendLayout();
32 |             //
33 |             // PredictionPane
34 |             //
35 |             this.AutoScaleDimensions = new System.Drawing.SizeF(6F, 13F);
36 |             this.AutoScaleMode = System.Windows.Forms.AutoScaleMode.Font;
37 |             this.BackColor = System.Drawing.Color.White;
38 |             this.Name = "PredictionPane";
39 |             this.Size = new System.Drawing.Size(550, 400);
40 |             this.ResumeLayout(false);
41 |
42 |         }
43 |
44 |         #endregion
45 |     }
46 | }
47 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Components/PredictionPane.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.Drawing;
4 | using System.Windows.Forms;
5 |
6 | namespace DigitRecognizer.Presentation.Components
7 | {
8 |     ///
9 |     /// Shows one row per class: a label with the digit, a progress bar and a percentage label for its prediction.
10 |     ///
11 |     public partial class PredictionPane : UserControl
12 |     {
13 |         #region Fields
14 |
15 |         private const string LabelName = "lbl";
16 |         private const string LabelPredictionName = "lblPrediction";
17 |         private const string ProgressBarName = "progressBar";
18 |         private const int ClassNumber = 10;
19 |
20 |         #endregion
21 |
22 |         #region Ctor
23 |
24 |         public PredictionPane()
25 |         {
26 |             InitializeComponent();
27 |
28 |             InitializeDisplay();
29 |         }
30 |
31 |         #endregion
32 |
33 |         #region Methods
34 |
35 |         ///
36 |         /// Displays the specified predictions: round(value * 100) on each progress bar and the value formatted as a percentage on each label.
37 |         ///
38 |         /// One value per class; must contain exactly ClassNumber values.
39 |         /// ArgumentNullException: predictions is null.
40 |         /// ArgumentException: predictions does not contain exactly ClassNumber values.
41 |         public void ProcessPrediction(double[] predictions)
42 |         {
43 |             if (predictions == null)
44 |             {
45 |                 throw new ArgumentNullException(nameof(predictions));
46 |             }
47 |
48 |             if (predictions.Length != ClassNumber)
49 |             {
50 |                 throw new ArgumentException($"Expected {ClassNumber} prediction values but received {predictions.Length}.", nameof(predictions));
51 |             }
52 |
53 |             for (var i = 0; i < predictions.Length; i++)
54 |             {
55 |                 ProgressBarHandler(GetProgressBarName(i), (int)Math.Round(predictions[i] * 100));
56 |
57 |                 PredictionLabelHandler(GetPredictionLabelName(i), predictions[i].ToString("P"));
58 |             }
59 |         }
60 |
61 |         ///
62 |         /// Resets every progress bar to 0 and clears every prediction label.
63 |         ///
64 |         public void Clear()
65 |         {
66 |             for (var i = 0; i < ClassNumber; i++)
67 |             {
68 |                 ProgressBarHandler(GetProgressBarName(i), 0);
69 |
70 |                 PredictionLabelHandler(GetPredictionLabelName(i), string.Empty);
71 |             }
72 |         }
73 |
74 |         ///
75 |         /// Sets the value of the named progress bar, clamped to the bar's [Minimum, Maximum] range.
76 |         ///
77 |         private void ProgressBarHandler(string progressBarKey, int value)
78 |         {
79 |             if (!Controls.ContainsKey(progressBarKey))
80 |             {
81 |                 throw new ArgumentException($"No progress bar named '{progressBarKey}' was found.", nameof(progressBarKey));
82 |             }
83 |
84 |             var progressBar = (ProgressBar)Controls[progressBarKey];
85 |
86 |             // ProgressBar.Value throws ArgumentOutOfRangeException outside [Minimum, Maximum];
87 |             // clamp so an out-of-range prediction cannot crash the UI.
88 |             progressBar.Value = Math.Max(progressBar.Minimum, Math.Min(progressBar.Maximum, value));
89 |         }
90 |
91 |         ///
92 |         /// Sets the text of the named prediction label.
93 |         ///
94 |         private void PredictionLabelHandler(string labelPredictionKey, string text)
95 |         {
96 |             if (!Controls.ContainsKey(labelPredictionKey))
97 |             {
98 |                 throw new ArgumentException($"No prediction label named '{labelPredictionKey}' was found.", nameof(labelPredictionKey));
99 |             }
100 |
101 |             ((Label)Controls[labelPredictionKey]).Text = text;
102 |         }
103 |
104 |         #endregion
105 |
106 |         #region Utilities
107 |
108 |         ///
109 |         /// Creates one row of controls (digit label, progress bar, prediction label) per class, 40 pixels apart, and adds them to the pane.
110 |         ///
111 |         private void InitializeDisplay()
112 |         {
113 |             var controls = new List();
114 |
115 |             for (var i = 0; i < ClassNumber; i++)
116 |             {
117 |                 int top = Top + 40 * i + 1;
118 |
119 |                 Label labelNum = GenerateLabel(GetLabelName(i), Left, top + 1, i.ToString(), 20);
120 |
121 |                 ProgressBar progressBar = GenerateProgressBar(GetProgressBarName(i), labelNum.Right + 5, top);
122 |
123 |                 Label labelPrediction = GenerateLabel(GetPredictionLabelName(i), progressBar.Right + 5, top + 1, "Prediction");
124 |
125 |                 controls.Add(labelNum);
126 |                 controls.Add(progressBar);
127 |                 controls.Add(labelPrediction);
128 |             }
129 |
130 |             Controls.AddRange(controls.ToArray());
131 |         }
132 |
133 |         private static string GetPredictionLabelName(int i) => $"{LabelPredictionName}{i}";
134 |
135 |         private static string GetProgressBarName(int i) => $"{ProgressBarName}{i}";
136 |
137 |         private static string GetLabelName(int i) => $"{LabelName}{i}";
138 |
139 |         private static ProgressBar GenerateProgressBar(string name, int left, int top)
140 |         {
141 |             return new ProgressBar
142 |             {
143 |                 Name = name,
144 |                 Value = 0,
145 |                 Left = left,
146 |                 Top = top,
147 |                 Width = 450
148 |             };
149 |         }
150 |
151 |         ///
152 |         /// Creates an auto-sized label with a 12pt sans-serif font at the given left/top position.
153 |         ///
154 |         private static Label GenerateLabel(string name, int left, int top, string text = "", int? width = null)
155 |         {
156 |             var lbl = new Label
157 |             {
158 |                 AutoSize = true,
159 |                 Name = name,
160 |                 Text = text,
161 |                 Left = left,
162 |                 Top = top,
163 |                 Font = new Font(FontFamily.GenericSansSerif, 12f)
164 |             };
165 |
166 |             if (width != null)
167 |             {
168 |                 lbl.Width = (int)width;
169 |             }
170 |
171 |             return lbl;
172 |         }
173 |
174 |         #endregion
175 |     }
176 | }
177 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Global.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.MachineLearning.Infrastructure.Models;
2 | using DigitRecognizer.Presentation.Infrastructure;
3 |
4 | namespace DigitRecognizer.Presentation
5 | {
6 |     ///
7 |     /// Application-wide shared state for the presentation layer.
8 |     ///
9 |     public static class Global
10 |     {
11 |         private static IPredictionModel _predictionModel;
12 |
13 |         ///
14 |         /// Gets the prediction model, loading it with PredictionModelLoader on first access and caching a non-null result.
15 |         /// NOTE(review): the lazy initialization is not synchronized, so concurrent first accesses could load the model more than once.
16 |         ///
17 |         public static IPredictionModel PredictionModel => _predictionModel ?? (_predictionModel = LoadModel());
18 |
19 |         ///
20 |         /// The image grid field count (100).
21 |         ///
22 |         public static readonly int ImageGridFieldCount = 100;
23 |
24 |         ///
25 |         /// Loads the prediction model through a new PredictionModelLoader.
26 |         ///
27 |         private static IPredictionModel LoadModel()
28 |         {
29 |             var loader = new PredictionModelLoader();
30 |
31 |             IPredictionModel model = loader.Load();
32 |
33 |             return model;
34 |         }
35 |     }
36 | }
37 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Infrastructure/DependencyResolver.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 |
4 | namespace DigitRecognizer.Presentation.Infrastructure
5 | {
6 |     ///
7 |     /// Minimal service locator mapping contract types to implementation types.
8 |     /// Resolve creates a new implementation instance (via Activator.CreateInstance) on every call.
9 |     /// NOTE(review): only the creation of Instance is locked; Register/Resolve access the dictionary without synchronization.
10 |     ///
11 |     public class DependencyResolver
12 |     {
13 |         private readonly Dictionary _dependencyDictionary;
14 |
15 |         private static DependencyResolver _instance;
16 |
17 |         ///
18 |         /// Initializes a new, empty DependencyResolver. The static Register and Resolve methods always operate on Instance.
19 |         ///
20 |         public DependencyResolver()
21 |         {
22 |             _dependencyDictionary = new Dictionary();
23 |         }
24 |
25 |         ///
26 |         /// The object used for locking. Required for thread safety.
27 |         ///
28 |         private static readonly object Lock = new object();
29 |
30 |         ///
31 |         /// Gets the singleton instance of the DependencyResolver class, creating it on first access.
32 |         ///
33 |         public static DependencyResolver Instance
34 |         {
35 |             get
36 |             {
37 |                 lock (Lock)
38 |                 {
39 |                     return _instance ?? (_instance = new DependencyResolver());
40 |                 }
41 |             }
42 |         }
43 |
44 |         ///
45 |         /// Registers TImplementation as the implementation of TContract.
46 |         ///
47 |         public static void Register()
48 |         {
49 |             Register(typeof(TContract), typeof(TImplementation));
50 |         }
51 |
52 |         ///
53 |         /// Registers the implementation type for the specified contract type.
54 |         ///
55 |         /// The contract type.
56 |         /// The implementation type; Resolve needs it to have a parameterless constructor.
57 |         /// ArgumentNullException: contract or implementation is null.
58 |         /// ArgumentException (from Dictionary.Add): the contract is already registered.
59 |         public static void Register(Type contract, Type implementation)
60 |         {
61 |             if (contract == null)
62 |             {
63 |                 throw new ArgumentNullException(nameof(contract));
64 |             }
65 |
66 |             // Reject null here; otherwise the failure would only surface later inside Resolve.
67 |             if (implementation == null)
68 |             {
69 |                 throw new ArgumentNullException(nameof(implementation));
70 |             }
71 |
72 |             Instance._dependencyDictionary.Add(contract, implementation);
73 |         }
74 |
75 |         ///
76 |         /// Creates a new instance of the implementation registered for T.
77 |         ///
78 |         public static T Resolve()
79 |         {
80 |             return (T)Resolve(typeof(T));
81 |         }
82 |
83 |         ///
84 |         /// Creates a new instance of the implementation registered for the specified contract type.
85 |         ///
86 |         /// The contract type.
87 |         /// A new instance of the registered implementation.
88 |         /// ArgumentNullException: contract is null.
89 |         /// KeyNotFoundException: no implementation is registered for contract.
90 |         public static object Resolve(Type contract)
91 |         {
92 |             if (contract == null)
93 |             {
94 |                 throw new ArgumentNullException(nameof(contract));
95 |             }
96 |
97 |             Type implementation;
98 |             if (!Instance._dependencyDictionary.TryGetValue(contract, out implementation))
99 |             {
100 |                 // Same exception type the dictionary indexer throws, but naming the missing contract.
101 |                 throw new KeyNotFoundException($"No implementation is registered for '{contract.FullName}'.");
102 |             }
103 |
104 |             return Activator.CreateInstance(implementation);
105 |         }
106 |     }
107 | }
108 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Infrastructure/ExceptionHandlers.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Threading;
3 | using System.Windows.Forms;
4 | using DigitRecognizer.Presentation.Services;
5 |
6 | namespace DigitRecognizer.Presentation.Infrastructure
7 | {
8 |     public static class ExceptionHandlers
9 |     {
10 |         public static void CurrentDomainOnUnhandledException(object sender, UnhandledExceptionEventArgs e)
11 |         {
12 |             var ex = (Exception) e.ExceptionObject;
13 |
14 |             string message = $"Something went wrong.{Environment.NewLine}{Environment.NewLine}{ex}";
15 |
16 |             DependencyResolver.Resolve().Log(ex);
17 |
18 |             DependencyResolver.Resolve().ShowMessage(message, "Unhandled Exception", icon: MessageBoxIcon.Error);
19 |         }
20 |
21 |         public static void ApplicationOnThreadException(object sender, ThreadExceptionEventArgs e)
22 |         {
23 |             string message = $"Something went wrong.{Environment.NewLine}{Environment.NewLine}{e.Exception}";
24 |
25 |
DependencyResolver.Resolve().Log(e.Exception); 26 | 27 | DependencyResolver.Resolve().ShowMessage(message, "Thread exception", icon: MessageBoxIcon.Information); 28 | } 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.Presentation/Infrastructure/PanelDoubleBuffering.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Reflection; 3 | using System.Windows.Forms; 4 | 5 | namespace DigitRecognizer.Presentation.Infrastructure 6 | { 7 | public static class PanelDoubleBuffering 8 | { 9 | private const string DoubleBufferedName = "DoubleBuffered"; 10 | 11 | public static void Enable(object target) 12 | { 13 | if (target == null) 14 | { 15 | throw new NullReferenceException(); 16 | } 17 | 18 | if (!(target is Panel)) 19 | { 20 | throw new ArgumentException($"Can not enable double buffering on object of type {target.GetType()}"); 21 | } 22 | 23 | typeof(Panel).InvokeMember(DoubleBufferedName, 24 | BindingFlags.SetProperty | BindingFlags.Instance | BindingFlags.NonPublic, 25 | null, 26 | target, 27 | new object[] { true }); 28 | } 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.Presentation/Infrastructure/PredictionModelLoader.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.IO; 3 | using System.Linq; 4 | using System.Windows.Forms; 5 | using DigitRecognizer.Core.Utilities; 6 | using DigitRecognizer.MachineLearning.Infrastructure.Models; 7 | 8 | namespace DigitRecognizer.Presentation.Infrastructure 9 | { 10 | public class PredictionModelLoader 11 | { 12 | private readonly OpenFileDialog _openFileDialog; 13 | 14 | public PredictionModelLoader() 15 | { 16 | _openFileDialog = new OpenFileDialog 17 | { 18 | Multiselect = true, 19 | CheckFileExists = true, 20 | InitialDirectory 
= Path.GetFullPath(Path.GetDirectoryName(AppDomain.CurrentDomain.BaseDirectory) + DirectoryHelper.ModelsFolder), 21 | Filter = @"Neural network file (*.nn)|*.nn" 22 | }; 23 | } 24 | 25 | public IPredictionModel Load() 26 | { 27 | if (DialogResult.OK != _openFileDialog.ShowDialog() || _openFileDialog.FileNames.Length == 0 || _openFileDialog.FileNames.Any(string.IsNullOrWhiteSpace)) 28 | { 29 | return null; 30 | } 31 | 32 | IPredictionModel model = ClusterPredictionModel.FromFiles(_openFileDialog.FileNames); 33 | 34 | return model; 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.Presentation/Models/ImageGridModel.cs: -------------------------------------------------------------------------------- 1 | using System.Collections.Generic; 2 | using System.Drawing; 3 | using System.Drawing.Imaging; 4 | using System.Threading.Tasks; 5 | using DigitRecognizer.Core.Data; 6 | using DigitRecognizer.Core.Extensions; 7 | 8 | namespace DigitRecognizer.Presentation.Models 9 | { 10 | public class ImageGridModel 11 | { 12 | private const int Size = 28; 13 | 14 | private static readonly Color PaleRed = Color.FromArgb(255, 239, 153, 153); 15 | 16 | public ImageGridModel(MnistImageBatch batch, int[] predictions) 17 | { 18 | Labels = new List(batch.Labels); 19 | 20 | Predictions = new List(predictions); 21 | 22 | Images = new List(ProcessImages(batch.Pixels)); 23 | } 24 | 25 | public ImageGridModel() 26 | { 27 | Labels = new List(); 28 | Predictions = new List(); 29 | Images = new List(); 30 | } 31 | 32 | private Image[] ProcessImages(double[][] pixelArray) 33 | { 34 | var images = new Image[pixelArray.Length]; 35 | 36 | Parallel.For(0, pixelArray.Length, (row) => 37 | { 38 | double[] pixels = pixelArray[row]; 39 | 40 | var bmp = new Bitmap(Size, Size); 41 | 42 | BitmapData data = bmp.LockBits(new Rectangle(0, 0, Size, Size), ImageLockMode.ReadWrite, PixelFormat.Format24bppRgb); 43 | 44 | unsafe 
45 | { 46 | var ptr = (byte*)data.Scan0.ToPointer(); 47 | 48 | for (var len = 0; len < Size * Size; len++) 49 | { 50 | var pixel = (byte) (pixels[len] > 0.2 ? 0 : 255); 51 | 52 | if (Labels[row] != Predictions[row]) 53 | { 54 | *(ptr++) = pixel == 255 ? PaleRed.B : pixel; // B 55 | *(ptr++) = pixel == 255 ? PaleRed.G : pixel; // G 56 | *(ptr++) = pixel == 255 ? PaleRed.R : pixel; // R 57 | } 58 | else 59 | { 60 | *(ptr++) = pixel; // B 61 | *(ptr++) = pixel; // G 62 | *(ptr++) = pixel; // R 63 | } 64 | } 65 | } 66 | 67 | bmp.UnlockBits(data); 68 | 69 | images[row] = bmp.Resize(56, 56); 70 | }); 71 | 72 | return images; 73 | } 74 | 75 | public List Images { get; set; } 76 | 77 | public List Labels { get; set; } 78 | 79 | public List Predictions { get; set; } 80 | 81 | public int Count => Images.Count; 82 | 83 | public static ImageGridModel operator +(ImageGridModel igm1, ImageGridModel igm2) 84 | { 85 | igm1.Images.AddRange(igm2.Images); 86 | igm1.Labels.AddRange(igm2.Labels); 87 | igm1.Predictions.AddRange(igm2.Predictions); 88 | 89 | return igm1; 90 | } 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.Presentation/Presenters/ApplicationPresenter.cs: -------------------------------------------------------------------------------- 1 | using DigitRecognizer.Presentation.Services; 2 | using DigitRecognizer.Presentation.Views.Interfaces; 3 | 4 | namespace DigitRecognizer.Presentation.Presenters 5 | { 6 | public class ApplicationPresenter 7 | { 8 | private readonly BenchmarkPresenter _benchmarkPresenter; 9 | private readonly DrawingPresenter _drawingPresenter; 10 | private readonly UploadImagePresenter _uploadImagePresenter; 11 | private readonly SlidingWindowPresenter _slidingWindowPresenter; 12 | 13 | public ApplicationPresenter(IMainFormView mainFormView, 14 | IMessageService messageService, 15 | ILoggingService loggingService) 16 | { 17 | loggingService.Log("Initializing 
presenters"); 18 | 19 | _benchmarkPresenter = new BenchmarkPresenter(mainFormView.BenchmarkView, messageService, loggingService); 20 | _drawingPresenter = new DrawingPresenter(mainFormView.DrawingView, messageService, loggingService); 21 | _uploadImagePresenter = new UploadImagePresenter(mainFormView.UploadImageView, messageService, loggingService); 22 | _slidingWindowPresenter = new SlidingWindowPresenter(mainFormView.SlidingWindowView, messageService, loggingService); 23 | } 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.Presentation/Presenters/BenchmarkPresenter.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Linq; 3 | using System.Threading; 4 | using System.Threading.Tasks; 5 | using System.Windows.Forms; 6 | using DigitRecognizer.Core.Data; 7 | using DigitRecognizer.Core.Extensions; 8 | using DigitRecognizer.Core.Utilities; 9 | using DigitRecognizer.MachineLearning.Infrastructure.Models; 10 | using DigitRecognizer.MachineLearning.Providers; 11 | using DigitRecognizer.Presentation.Models; 12 | using DigitRecognizer.Presentation.Services; 13 | using DigitRecognizer.Presentation.Views.Interfaces; 14 | 15 | namespace DigitRecognizer.Presentation.Presenters 16 | { 17 | public class BenchmarkPresenter 18 | { 19 | #region Fields 20 | 21 | private readonly IBenchmarkView _benchmarkView; 22 | private readonly IMessageService _messageService; 23 | private readonly ILoggingService _loggingService; 24 | private CancellationTokenSource _cancellationTokenSource; 25 | 26 | #endregion 27 | 28 | #region Ctor 29 | 30 | public BenchmarkPresenter(IBenchmarkView benchmarkView, IMessageService messageService, ILoggingService loggingService) 31 | { 32 | _messageService = messageService; 33 | _loggingService = loggingService; 34 | _benchmarkView = benchmarkView; 35 | 36 | _benchmarkView.RunBenchmark += OnRunBenchmark; 37 | 38 | 
_benchmarkView.CancelBenchmark += OnCancelBenchmark; 39 | 40 | _benchmarkView.IsBenchmarkRunning = false; 41 | } 42 | 43 | #endregion 44 | 45 | #region Methods 46 | 47 | private async void OnRunBenchmark(object sender, EventArgs e) 48 | { 49 | IPredictionModel predictionModel = Global.PredictionModel; 50 | 51 | if (predictionModel == null) 52 | { 53 | _messageService.ShowMessage("The prediction model must be loaded first.", "Prediction model", icon: MessageBoxIcon.Information); 54 | 55 | return; 56 | } 57 | 58 | _benchmarkView.ResetView(); 59 | 60 | _cancellationTokenSource = new CancellationTokenSource(); 61 | 62 | CancellationToken token = _cancellationTokenSource.Token; 63 | 64 | token.Register(() => { _benchmarkView.IsBenchmarkRunning = false; }); 65 | 66 | try 67 | { 68 | await Task.Run(() => { 69 | RunBenchmark(predictionModel); 70 | }, token); 71 | } 72 | catch (Exception exception) 73 | { 74 | _loggingService.Log(exception); 75 | 76 | _messageService.ShowMessage("An error ocurred while running the benchmark. 
Please try again.", "Benchmark error", icon: MessageBoxIcon.Information); 77 | } 78 | } 79 | 80 | private void RunBenchmark(IPredictionModel model) 81 | { 82 | _loggingService.Log("Running benchmark has started"); 83 | 84 | _benchmarkView.IsBenchmarkRunning = true; 85 | 86 | var provider = new BatchDataProvider(DirectoryHelper.TestLabelsPath, DirectoryHelper.TestImagesPath, 100); 87 | 88 | var acc = 0; 89 | 90 | for (var i = 0; i < 100; i++) 91 | { 92 | if (!_benchmarkView.IsBenchmarkRunning) 93 | { 94 | break; 95 | } 96 | 97 | MnistImageBatch data = provider.GetData(); 98 | 99 | int[] predictions = data.Pixels.Select(model.Predict).Select(x=> x.ArgMax()).ToArray(); 100 | 101 | acc += data.Labels.Where((lbl, pred) => lbl == predictions[pred]).Count(); 102 | 103 | _benchmarkView.PerformProgressStep(); 104 | 105 | _benchmarkView.DrawGrid(new ImageGridModel(data, predictions)); 106 | } 107 | 108 | _benchmarkView.SetAccuracy(acc); 109 | 110 | _benchmarkView.IsBenchmarkRunning = false; 111 | 112 | _loggingService.Log("Running benchmark has completed"); 113 | } 114 | 115 | private void OnCancelBenchmark(object sender, EventArgs e) 116 | { 117 | _loggingService.Log("Running benchmark was canceled"); 118 | 119 | _cancellationTokenSource.Cancel(); 120 | } 121 | 122 | #endregion 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.Presentation/Presenters/DrawingPresenter.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Windows.Forms; 3 | using DigitRecognizer.Core.IO; 4 | using DigitRecognizer.MachineLearning.Infrastructure.Models; 5 | using DigitRecognizer.Presentation.Services; 6 | using DigitRecognizer.Presentation.Views.Interfaces; 7 | 8 | namespace DigitRecognizer.Presentation.Presenters 9 | { 10 | public class DrawingPresenter 11 | { 12 | #region Fields 13 | 14 | private readonly IDrawingView _drawingView; 15 | 
private readonly IMessageService _messageService; 16 | private readonly ILoggingService _loggingService; 17 | 18 | #endregion 19 | 20 | #region Ctor 21 | 22 | public DrawingPresenter(IDrawingView drawingView, IMessageService messageService, ILoggingService loggingService) 23 | { 24 | _messageService = messageService; 25 | _loggingService = loggingService; 26 | _drawingView = drawingView; 27 | 28 | _drawingView.ClassifyDrawing += OnClassifyDrawing; 29 | 30 | _drawingView.ClearDrawing += OnClearDrawing; 31 | } 32 | 33 | #endregion 34 | 35 | #region Methods 36 | 37 | private void OnClassifyDrawing(object sender, EventArgs e) 38 | { 39 | try 40 | { 41 | _loggingService.Log("Classify drawing has started"); 42 | 43 | var imagePreprocessor = new ImagePreprocessor(); 44 | 45 | double[] pixels = imagePreprocessor.Preprocess(_drawingView.Drawing); 46 | 47 | IPredictionModel predictionModel = Global.PredictionModel; 48 | 49 | double[] prediction = predictionModel.Predict(pixels); 50 | 51 | _drawingView.ProcessPrediction(prediction); 52 | 53 | _loggingService.Log("Classify drawing has completed"); 54 | } 55 | catch (Exception exception) 56 | { 57 | _loggingService.Log(exception); 58 | 59 | _messageService.ShowMessage("An error ocurred while classyfing the drawing. 
Please try again.", "Classification error", icon: MessageBoxIcon.Information); 60 | } 61 | } 62 | 63 | private void OnClearDrawing(object sender, EventArgs e) 64 | { 65 | _drawingView.Clear(); 66 | } 67 | 68 | #endregion 69 | } 70 | } -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.Presentation/Presenters/SlidingWindowPresenter.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Drawing; 3 | using System.Windows.Forms; 4 | using DigitRecognizer.Core.Data; 5 | using DigitRecognizer.Core.Extensions; 6 | using DigitRecognizer.Core.IO; 7 | using DigitRecognizer.Core.Utilities; 8 | using DigitRecognizer.MachineLearning.Infrastructure.Models; 9 | using DigitRecognizer.Presentation.Services; 10 | using DigitRecognizer.Presentation.Views.Interfaces; 11 | 12 | namespace DigitRecognizer.Presentation.Presenters 13 | { 14 | public class SlidingWindowPresenter 15 | { 16 | #region Fields 17 | 18 | private readonly ISlidingWindowView _slidingWindowView; 19 | private readonly IMessageService _messageService; 20 | private readonly ILoggingService _loggingService; 21 | 22 | private static readonly Size[] WindowSizes = 23 | { 24 | new Size(280, 280), 25 | //new Size(140, 140), 26 | //new Size(112, 112), 27 | //new Size(84, 84), 28 | //new Size(56, 56) 29 | }; 30 | 31 | #endregion 32 | 33 | #region Ctor 34 | 35 | public SlidingWindowPresenter(ISlidingWindowView slidingWindowView, IMessageService messageService, ILoggingService loggingService) 36 | { 37 | _messageService = messageService; 38 | _loggingService = loggingService; 39 | _slidingWindowView = slidingWindowView; 40 | 41 | _slidingWindowView.ClassifyDrawing += OnClassifyDrawing; 42 | 43 | _slidingWindowView.ClearDrawing += OnClearDrawing; 44 | } 45 | 46 | #endregion 47 | 48 | #region Methods 49 | 50 | private void OnClassifyDrawing(object sender, EventArgs e) 51 | { 52 | try 53 | { 54 | 
_loggingService.Log("Classify drawing has started"); 55 | 56 | var imagePreprocessor = new ImagePreprocessor(); 57 | 58 | IPredictionModel predictionModel = Global.PredictionModel; 59 | 60 | Image img = _slidingWindowView.Drawing; 61 | 62 | foreach (Size windowSize in WindowSizes) 63 | { 64 | foreach (BoundingBox boundingBox in ImageUtilities.SlidingWindow(img, windowSize, 112)) 65 | { 66 | try 67 | { 68 | double[] pixels = imagePreprocessor.Preprocess(boundingBox.Image); 69 | 70 | double[] prediction = predictionModel.Predict(pixels); 71 | 72 | // If classification is over 99% draw a bounding box at this location 73 | int predicted = prediction.ArgMax(); 74 | double predictedAccuracy = prediction[prediction.ArgMax()]; 75 | 76 | if (predictedAccuracy >= 0.95) 77 | { 78 | _slidingWindowView.DrawBoundingBox(boundingBox, predicted, predictedAccuracy); 79 | } 80 | } 81 | catch (Exception exception) 82 | { 83 | _loggingService.Log(exception); 84 | } 85 | } 86 | } 87 | 88 | _loggingService.Log("Classify drawing has completed"); 89 | } 90 | catch (Exception exception) 91 | { 92 | _loggingService.Log(exception); 93 | 94 | _messageService.ShowMessage("An error ocurred while classyfing the drawing. 
Please try again.", "Classification error", icon: MessageBoxIcon.Information); 95 | } 96 | } 97 | 98 | private void OnClearDrawing(object sender, EventArgs e) 99 | { 100 | _slidingWindowView.Clear(); 101 | } 102 | 103 | #endregion 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.Presentation/Presenters/UploadImagePresenter.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Windows.Forms; 3 | using DigitRecognizer.Core.IO; 4 | using DigitRecognizer.MachineLearning.Infrastructure.Models; 5 | using DigitRecognizer.Presentation.Services; 6 | using DigitRecognizer.Presentation.Views.Interfaces; 7 | 8 | namespace DigitRecognizer.Presentation.Presenters 9 | { 10 | public class UploadImagePresenter 11 | { 12 | #region Fields 13 | 14 | private readonly IUploadImageView _uploadImageView; 15 | private readonly IMessageService _messageService; 16 | private readonly ILoggingService _loggingService; 17 | 18 | #endregion 19 | 20 | #region Ctor 21 | 22 | 23 | public UploadImagePresenter(IUploadImageView uploadImageView, IMessageService messageService, ILoggingService loggingService) 24 | { 25 | _uploadImageView = uploadImageView; 26 | _messageService = messageService; 27 | _loggingService = loggingService; 28 | 29 | _uploadImageView.ClassifyImage += OnClassifyImage; 30 | 31 | _uploadImageView.ClearImage += OnClearImage; 32 | } 33 | 34 | #endregion 35 | 36 | #region Methods 37 | 38 | private void OnClassifyImage(object sender, EventArgs e) 39 | { 40 | try 41 | { 42 | _loggingService.Log("Classify drawing has started"); 43 | 44 | var imagePreprocessor = new ImagePreprocessor(); 45 | 46 | double[] pixels = imagePreprocessor.Preprocess(_uploadImageView.Image); 47 | 48 | IPredictionModel predictionModel = Global.PredictionModel; 49 | 50 | double[] prediction = predictionModel.Predict(pixels); 51 | 52 | 
_uploadImageView.ProcessPrediction(prediction); 53 | 54 | _loggingService.Log("Classify drawing has completed"); 55 | } 56 | catch (NullReferenceException exception) 57 | { 58 | _loggingService.Log(exception); 59 | 60 | _messageService.ShowMessage("No image was uploaded. Please upload an image and try again.", "Upload error", icon: MessageBoxIcon.Information); 61 | } 62 | catch (Exception exception) 63 | { 64 | _loggingService.Log(exception); 65 | 66 | _messageService.ShowMessage("An error ocurred while classyfing the drawing. Please try again.", "Classification error", icon: MessageBoxIcon.Information); 67 | } 68 | } 69 | 70 | private void OnClearImage(object sender, EventArgs e) 71 | { 72 | _uploadImageView.Clear(); 73 | } 74 | 75 | #endregion 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.Presentation/Program.cs: -------------------------------------------------------------------------------- 1 | using System; 2 | using System.Windows.Forms; 3 | using DigitRecognizer.Presentation.Infrastructure; 4 | using DigitRecognizer.Presentation.Views.Implementations; 5 | 6 | namespace DigitRecognizer.Presentation 7 | { 8 | internal static class Program 9 | { 10 | /// 11 | /// The main entry point for the application. 12 | /// 13 | [STAThread] 14 | private static void Main() 15 | { 16 | Startup.RegisterDependencies(); 17 | 18 | Application.SetUnhandledExceptionMode(UnhandledExceptionMode.CatchException); 19 | 20 | // Option to continue with execution. 21 | Application.ThreadException += ExceptionHandlers.ApplicationOnThreadException; 22 | 23 | // Application will terminate. 
24 | AppDomain.CurrentDomain.UnhandledException += ExceptionHandlers.CurrentDomainOnUnhandledException; 25 | 26 | Application.EnableVisualStyles(); 27 | 28 | Application.SetCompatibleTextRenderingDefault(false); 29 | 30 | var mainForm = DependencyResolver.Resolve(); 31 | 32 | Startup.SetupPresenters(mainForm); 33 | 34 | Application.Run(mainForm); 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.Presentation/Properties/AssemblyInfo.cs: -------------------------------------------------------------------------------- 1 | using System.Reflection; 2 | using System.Runtime.InteropServices; 3 | 4 | // General Information about an assembly is controlled through the following 5 | // set of attributes. Change these attribute values to modify the information 6 | // associated with an assembly. 7 | [assembly: AssemblyTitle("DigitRecognizer.Presentation")] 8 | [assembly: AssemblyDescription("")] 9 | [assembly: AssemblyConfiguration("")] 10 | [assembly: AssemblyCompany("")] 11 | [assembly: AssemblyProduct("DigitRecognizer.Presentation")] 12 | [assembly: AssemblyCopyright("Copyright © 2018")] 13 | [assembly: AssemblyTrademark("")] 14 | [assembly: AssemblyCulture("")] 15 | 16 | // Setting ComVisible to false makes the types in this assembly not visible 17 | // to COM components. If you need to access a type in this assembly from 18 | // COM, set the ComVisible attribute to true on that type. 
19 | [assembly: ComVisible(false)] 20 | 21 | // The following GUID is for the ID of the typelib if this project is exposed to COM 22 | [assembly: Guid("6caf69ea-9c0c-4571-a808-268c159bdb5c")] 23 | 24 | // Version information for an assembly consists of the following four values: 25 | // 26 | // Major Version 27 | // Minor Version 28 | // Build Number 29 | // Revision 30 | // 31 | // You can specify all the values or you can default the Build and Revision Numbers 32 | // by using the '*' as shown below: 33 | // [assembly: AssemblyVersion("1.0.*")] 34 | [assembly: AssemblyVersion("1.0.0.0")] 35 | [assembly: AssemblyFileVersion("1.0.0.0")] 36 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.Presentation/Properties/Resources.Designer.cs: -------------------------------------------------------------------------------- 1 | //------------------------------------------------------------------------------ 2 | // 3 | // This code was generated by a tool. 4 | // Runtime Version:4.0.30319.42000 5 | // 6 | // Changes to this file may cause incorrect behavior and will be lost if 7 | // the code is regenerated. 8 | // 9 | //------------------------------------------------------------------------------ 10 | 11 | namespace DigitRecognizer.Presentation.Properties { 12 | using System; 13 | 14 | 15 | /// 16 | /// A strongly-typed resource class, for looking up localized strings, etc. 17 | /// 18 | // This class was auto-generated by the StronglyTypedResourceBuilder 19 | // class via a tool like ResGen or Visual Studio. 20 | // To add or remove a member, edit your .ResX file then rerun ResGen 21 | // with the /str option, or rebuild your VS project. 
22 | [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "17.0.0.0")] 23 | [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] 24 | [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] 25 | internal class Resources { 26 | 27 | private static global::System.Resources.ResourceManager resourceMan; 28 | 29 | private static global::System.Globalization.CultureInfo resourceCulture; 30 | 31 | [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] 32 | internal Resources() { 33 | } 34 | 35 | /// 36 | /// Returns the cached ResourceManager instance used by this class. 37 | /// 38 | [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] 39 | internal static global::System.Resources.ResourceManager ResourceManager { 40 | get { 41 | if (object.ReferenceEquals(resourceMan, null)) { 42 | global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("DigitRecognizer.Presentation.Properties.Resources", typeof(Resources).Assembly); 43 | resourceMan = temp; 44 | } 45 | return resourceMan; 46 | } 47 | } 48 | 49 | /// 50 | /// Overrides the current thread's CurrentUICulture property for all 51 | /// resource lookups using this strongly typed resource class. 
52 | /// 53 | [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] 54 | internal static global::System.Globalization.CultureInfo Culture { 55 | get { 56 | return resourceCulture; 57 | } 58 | set { 59 | resourceCulture = value; 60 | } 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.Presentation/Properties/Settings.Designer.cs: -------------------------------------------------------------------------------- 1 | //------------------------------------------------------------------------------ 2 | // 3 | // This code was generated by a tool. 4 | // Runtime Version:4.0.30319.42000 5 | // 6 | // Changes to this file may cause incorrect behavior and will be lost if 7 | // the code is regenerated. 8 | // 9 | //------------------------------------------------------------------------------ 10 | 11 | namespace DigitRecognizer.Presentation.Properties { 12 | 13 | 14 | [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] 15 | [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.Editors.SettingsDesigner.SettingsSingleFileGenerator", "17.9.0.0")] 16 | internal sealed partial class Settings : global::System.Configuration.ApplicationSettingsBase { 17 | 18 | private static Settings defaultInstance = ((Settings)(global::System.Configuration.ApplicationSettingsBase.Synchronized(new Settings()))); 19 | 20 | public static Settings Default { 21 | get { 22 | return defaultInstance; 23 | } 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.Presentation/Properties/Settings.settings: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- 
/DigitRecognizer/DigitRecognizer.Presentation/Services/ILoggingService.cs:
--------------------------------------------------------------------------------
using System;

namespace DigitRecognizer.Presentation.Services
{
    /// <summary>
    /// Records exceptions and informational messages.
    /// </summary>
    public interface ILoggingService
    {
        void Log(Exception e);

        void Log(string message);
    }
}
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Services/IMessageService.cs:
--------------------------------------------------------------------------------
using System.Windows.Forms;

namespace DigitRecognizer.Presentation.Services
{
    /// <summary>
    /// Shows messages to the user.
    /// </summary>
    public interface IMessageService
    {
        void ShowMessage(string text, string caption = "", MessageBoxButtons buttons = MessageBoxButtons.OK, MessageBoxIcon icon = MessageBoxIcon.None);
    }
}
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Services/LoggingService.cs:
--------------------------------------------------------------------------------
using System;
using System.IO;
using System.Text;

namespace DigitRecognizer.Presentation.Services
{
    /// <summary>
    /// Appends log entries to <c>log.txt</c>, resolved relative to the current working directory.
    /// </summary>
    public class LoggingService : ILoggingService
    {
        private const string LogFileName = "log.txt";

        /// <summary>
        /// Serializes appends to the log file. Presenters log both from the UI thread and from
        /// thread-pool work (the benchmark runs inside Task.Run), and two StreamWriters opening
        /// the same file for append at once can fail with an IOException.
        /// </summary>
        private static readonly object WriteLock = new object();

        /// <summary>
        /// Appends a timestamped entry with the exception's message followed by its full text.
        /// A null exception is ignored, mirroring how empty messages are ignored.
        /// </summary>
        public void Log(Exception e)
        {
            if (e == null)
            {
                return;
            }

            var builder = new StringBuilder($"{DateTime.Now} - {e.Message}{Environment.NewLine}");

            // Exception.ToString() already contains the type, message, stack trace and the whole
            // InnerException chain. Appending the outer exception once per inner level (as before)
            // repeated the same text for every level.
            builder.Append(e);

            builder.AppendLine();
            builder.AppendLine();

            Append(builder.ToString());
        }

        /// <summary>
        /// Appends a timestamped message followed by a blank line; null or whitespace messages are ignored.
        /// </summary>
        public void Log(string message)
        {
            if (string.IsNullOrWhiteSpace(message))
            {
                return;
            }

            var builder = new StringBuilder($"{DateTime.Now} - {message}{Environment.NewLine}");

            builder.AppendLine();

            Append(builder.ToString());
        }

        /// <summary>
        /// Appends <paramref name="text"/> to the log file under <see cref="WriteLock"/>.
        /// </summary>
        private static void Append(string text)
        {
            lock (WriteLock)
            {
                using (var writer = new StreamWriter(LogFileName, true))
                {
                    writer.Write(text);
                }
            }
        }
    }
}
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Services/MessageService.cs:
--------------------------------------------------------------------------------
using System.Windows.Forms;

namespace DigitRecognizer.Presentation.Services
{
    /// <summary>
    /// <see cref="IMessageService"/> implementation backed by <see cref="MessageBox"/>.
    /// </summary>
    public class MessageService : IMessageService
    {
        public void ShowMessage(string text, string caption = "", MessageBoxButtons buttons = MessageBoxButtons.OK, MessageBoxIcon icon = MessageBoxIcon.None)
        {
            MessageBox.Show(text, caption, buttons, icon);
        }
    }
}
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Startup.cs:
--------------------------------------------------------------------------------
using DigitRecognizer.Presentation.Infrastructure;
using DigitRecognizer.Presentation.Presenters;
using DigitRecognizer.Presentation.Services;
using DigitRecognizer.Presentation.Views.Implementations;
using DigitRecognizer.Presentation.Views.Interfaces;

namespace DigitRecognizer.Presentation
{
    /// <summary>
    /// Application wiring: dependency registration and presenter creation.
    /// </summary>
    public static class Startup
    {
        /// <summary>
        /// Registers the services and the main form with <see cref="DependencyResolver"/>.
        /// </summary>
        public static void RegisterDependencies()
        {
            DependencyResolver.Register<IMessageService, MessageService>();

            DependencyResolver.Register<ILoggingService, LoggingService>();

            // Program resolves the concrete MainForm, so it is registered under its own type.
            DependencyResolver.Register<MainForm, MainForm>();
        }

        /// <summary>
        /// Creates the presenters for the views exposed by <paramref name="mainFormView"/>.
        /// </summary>
        public static void SetupPresenters(IMainFormView mainFormView)
        {
            var messageService = DependencyResolver.Resolve<IMessageService>();
            var loggingService = DependencyResolver.Resolve<ILoggingService>();

            var _ = new ApplicationPresenter(mainFormView, messageService, loggingService);
        }
    }
}
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Views/Implementations/MainForm.cs:
--------------------------------------------------------------------------------
1 | using System.Collections.Generic;
2 | using System.Linq;
3 | using System.Windows.Forms;
4 | using DigitRecognizer.Presentation.Views.Interfaces;
5 | 
6 | namespace DigitRecognizer.Presentation.Views.Implementations
7 | {
8 |     /// <summary>
9 |     /// Main window. Hosts the benchmark, drawing, upload-image and sliding-window
10 |     /// views as child controls docked to fill the form and switches between them
11 |     /// from the menu.
12 |     /// </summary>
13 |     public partial class MainForm : Form, IMainFormView
14 |     {
15 |         #region Fields
16 | 
17 |         private BenchmarkView _benchmarkView;
18 |         private DrawingView _drawingView;
19 |         private UploadImageView _uploadImageView;
20 |         private SlidingWindowView _slidingWindowView;
21 |         private List<IView> _views;
22 | 
23 |         #endregion
24 | 
25 |         #region Ctor
26 | 
27 |         /// <summary>
28 |         /// Builds the designer controls and the child views, then hides the form itself.
29 |         /// </summary>
30 |         public MainForm()
31 |         {
32 |             InitializeComponent();
33 | 
34 |             InitializeViews();
35 | 
36 |             HideView();
37 |         }
38 | 
39 |         #endregion
40 | 
41 |         #region Properties
42 | 
43 |         /// <inheritdoc />
44 |         public IBenchmarkView BenchmarkView => _benchmarkView;
45 | 
46 |         /// <inheritdoc />
47 |         public IDrawingView DrawingView => _drawingView;
48 | 
49 |         /// <inheritdoc />
50 |         public IUploadImageView UploadImageView => _uploadImageView;
51 | 
52 |         /// <inheritdoc />
53 |         public ISlidingWindowView SlidingWindowView => _slidingWindowView;
54 | 
55 |         #endregion
56 | 
57 |         #region Methods
58 | 
59 |         /// <summary>
60 |         /// Creates and adds the child views, then shows the first one.
61 |         /// </summary>
62 |         private void InitializeViews()
63 |         {
64 |             FillControlsCollection();
65 | 
66 |             FillViewsCollection();
67 |         }
68 | 
69 |         /// <summary>
70 |         /// Instantiates each child view docked to fill the form and adds it to <see cref="Control.Controls"/>.
71 |         /// </summary>
72 |         private void FillControlsCollection()
73 |         {
74 |             _benchmarkView = new BenchmarkView { Dock = DockStyle.Fill };
75 |             _drawingView = new DrawingView { Dock = DockStyle.Fill };
76 |             _uploadImageView = new UploadImageView { Dock = DockStyle.Fill };
77 |             _slidingWindowView = new SlidingWindowView { Dock = DockStyle.Fill, Padding = new Padding(10) };
78 | 
79 |             Control[] controls =
80 |             {
81 |                 _benchmarkView,
82 |                 _drawingView,
83 |                 _uploadImageView,
84 |                 _slidingWindowView
85 |             };
86 | 
87 |             Controls.AddRange(controls);
88 |         }
89 | 
90 |         /// <summary>
91 |         /// Collects every view (including this form), hides them all and then shows
92 |         /// the first entry, the benchmark view.
93 |         /// </summary>
94 |         private void FillViewsCollection()
95 |         {
96 |             _views = new List<IView>
97 |             {
98 |                 _benchmarkView,
99 |                 _drawingView,
100 |                 _uploadImageView,
101 |                 _slidingWindowView,
102 |                 this
103 |             };
104 | 
105 |             HideViews();
106 | 
107 |             _views.First().ShowView();
108 |         }
109 | 
110 |         /// <summary>
111 |         /// Calls <see cref="IView.HideView"/> on every registered view.
112 |         /// </summary>
113 |         private void HideViews() => _views.ForEach(x => x.HideView());
114 | 
115 |         #endregion
116 | 
117 |         #region Event handlers
118 | 
119 |         private void BenchmarkToolStripMenuItem_Click(object sender, System.EventArgs e)
120 |         {
121 |             ToolstripMenuItem_Click_DisplayView(_benchmarkView);
122 |         }
123 | 
124 |         private void DrawingToolStripMenuItem_Click(object sender, System.EventArgs e)
125 |         {
126 |             ToolstripMenuItem_Click_DisplayView(_drawingView);
127 |         }
128 | 
129 |         private void UploadToolStripMenuItem_Click(object sender, System.EventArgs e)
130 |         {
131 |             ToolstripMenuItem_Click_DisplayView(_uploadImageView);
132 |         }
133 | 
134 |         private void SlidingWindowToolStripMenuItem_Click(object sender, System.EventArgs e)
135 |         {
136 |             ToolstripMenuItem_Click_DisplayView(_slidingWindowView);
137 |         }
138 | 
139 |         /// <summary>
140 |         /// Shows the selected view. Previously shown views are not hidden; because the child
141 |         /// views are docked to fill the form, the selected view presumably covers them once it
142 |         /// is brought to the front (as UploadImageView.ShowView does) — confirm for each view.
143 |         /// </summary>
144 |         private static void ToolstripMenuItem_Click_DisplayView(IView view)
145 |         {
146 |             view.ShowView();
147 |         }
148 | 
149 |         #endregion
150 | 
151 |         #region IView implementation
152 | 
153 |         /// <summary>
154 |         /// Brings the form to the front and shows it.
155 |         /// </summary>
156 |         public void ShowView()
157 |         {
158 |             BringToFront();
159 | 
160 |             Show();
161 |         }
162 | 
163 |         /// <summary>
164 |         /// Hides the form.
165 |         /// </summary>
166 |         public void HideView()
167 |         {
168 |             Hide();
169 |         }
170 | 
171 |         #endregion
172 |     }
173 | }
174 | 
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Views/Implementations/SlidingWindowView.Designer.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.Presentation.Views.Implementations
2 | {
3 |     partial class SlidingWindowView
4 |     {
5 |         /// <summary>
6 |         /// Required designer variable.
7 |         /// </summary>
8 |         private System.ComponentModel.IContainer components = null;
9 | 
10 |         /// <summary>
11 |         /// Clean up any resources being used.
12 |         /// </summary>
13 |         /// <param name="disposing">true if managed resources should be disposed; otherwise, false.</param>
14 |         protected override void Dispose(bool disposing)
15 |         {
16 |             if (disposing && (components != null))
17 |             {
18 |                 components.Dispose();
19 |             }
20 |             base.Dispose(disposing);
21 |         }
22 | 
23 |         #region Component Designer generated code
24 | 
25 |         /// <summary>
26 |         /// Required method for Designer support - do not modify
27 |         /// the contents of this method with the code editor.
28 |         /// </summary>
29 |         private void InitializeComponent()
30 |         {
31 |             this.panelDrawing = new System.Windows.Forms.Panel();
32 |             this.btnClearDrawing = new System.Windows.Forms.Button();
33 |             this.btnClassifyDrawing = new System.Windows.Forms.Button();
34 |             this.SuspendLayout();
35 |             // 
36 |             // panelDrawing
37 |             // 
38 |             this.panelDrawing.BorderStyle = System.Windows.Forms.BorderStyle.FixedSingle;
39 |             this.panelDrawing.Dock = System.Windows.Forms.DockStyle.Top;
40 |             this.panelDrawing.Location = new System.Drawing.Point(0, 0);
41 |             this.panelDrawing.Name = "panelDrawing";
42 |             this.panelDrawing.Size = new System.Drawing.Size(1178, 628);
43 |             this.panelDrawing.TabIndex = 0;
44 |             // 
45 |             // btnClearDrawing
46 |             // 
47 |             this.btnClearDrawing.Anchor = ((System.Windows.Forms.AnchorStyles)((System.Windows.Forms.AnchorStyles.Bottom | System.Windows.Forms.AnchorStyles.Left)));
48 |             this.btnClearDrawing.Font = new System.Drawing.Font("Microsoft Sans Serif", 9.75F, System.Drawing.FontStyle.Regular, System.Drawing.GraphicsUnit.Point, ((byte)(0)));
49 |             this.btnClearDrawing.Location = new System.Drawing.Point(136, 634);
50 |             this.btnClearDrawing.Name = "btnClearDrawing";
51 |             this.btnClearDrawing.Size = new System.Drawing.Size(127, 30);
52 |             this.btnClearDrawing.TabIndex = 24;
53 |             this.btnClearDrawing.Text = "Clear drawing";
54 |             this.btnClearDrawing.UseVisualStyleBackColor = true;
55 |             // 
56 |             // btnClassifyDrawing
57 |             // 
58 |             this.btnClassifyDrawing.Anchor = ((System.Windows.Forms.AnchorStyles)((System.Windows.Forms.AnchorStyles.Bottom | System.Windows.Forms.AnchorStyles.Left)));
59 |             this.btnClassifyDrawing.Font = new System.Drawing.Font("Microsoft Sans Serif", 9.75F, System.Drawing.FontStyle.Regular, System.Drawing.GraphicsUnit.Point, ((byte)(0)));
60 |             this.btnClassifyDrawing.Location = new System.Drawing.Point(3, 634);
61 |             this.btnClassifyDrawing.Name = "btnClassifyDrawing";
62 |             this.btnClassifyDrawing.Size = new System.Drawing.Size(127, 30);
63 |             this.btnClassifyDrawing.TabIndex = 23;
64 |             this.btnClassifyDrawing.Text = "Classify drawing";
65 |             this.btnClassifyDrawing.UseVisualStyleBackColor = true;
66 |             // 
67 |             // SlidingWindowView
68 |             // 
69 |             this.AutoScaleDimensions = new System.Drawing.SizeF(6F, 13F);
70 |             this.AutoScaleMode = System.Windows.Forms.AutoScaleMode.Font;
71 |             this.BackColor = System.Drawing.Color.White;
72 |             this.Controls.Add(this.btnClearDrawing);
73 |             this.Controls.Add(this.btnClassifyDrawing);
74 |             this.Controls.Add(this.panelDrawing);
75 |             this.Name = "SlidingWindowView";
76 |             this.Size = new System.Drawing.Size(1178, 667);
77 |             this.ResumeLayout(false);
78 | 
79 |         }
80 | 
81 |         #endregion
82 | 
83 |         private System.Windows.Forms.Panel panelDrawing;
84 |         private System.Windows.Forms.Button btnClearDrawing;
85 |         private System.Windows.Forms.Button btnClassifyDrawing;
86 |     }
87 | }
88 | 
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Views/Implementations/UploadImageView.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Drawing;
3 | using System.IO;
4 | using System.Windows.Forms;
5 | using DigitRecognizer.Core.Utilities;
6 | using DigitRecognizer.Presentation.Infrastructure;
7 | using DigitRecognizer.Presentation.Views.Interfaces;
8 | 
9 | namespace DigitRecognizer.Presentation.Views.Implementations
10 | {
11 |     /// <summary>
12 |     /// View that loads an image file chosen by the user, paints it (resized to the
13 |     /// upload panel) and raises events to classify or clear it.
14 |     /// </summary>
15 |     public partial class UploadImageView : UserControl, IUploadImageView
16 |     {
17 |         #region Fields
18 | 
19 |         private Image _image;
20 |         private readonly OpenFileDialog _openFileDialog;
21 | 
22 |         #endregion
23 | 
24 |         #region Ctor
25 | 
26 |         /// <summary>
27 |         /// Builds the designer controls, resets the view, configures the open-file dialog
28 |         /// (starts on the desktop, filters to bmp/png/jpeg/jpg) and subscribes the event handlers.
29 |         /// </summary>
30 |         public UploadImageView()
31 |         {
32 |             InitializeComponent();
33 | 
34 |             Clear();
35 | 
36 |             _openFileDialog = new OpenFileDialog
37 |             {
38 |                 InitialDirectory = Environment.GetFolderPath(Environment.SpecialFolder.Desktop),
39 |                 Filter = @"Image|*.bmp;*.png;*.jpeg;*.jpg"
40 |             };
41 | 
42 |             PanelDoubleBuffering.Enable(uploadPanel);
43 | 
44 |             btnClassifyImg.Click += BtnClassifyImgOnClick;
45 | 
46 |             btnClearImg.Click += BtnClearImgOnClick;
47 | 
48 |             btnUploadImg.Click += BtnUploadImgOnClick;
49 | 
50 |             uploadPanel.Paint += UploadPanelOnPaint;
51 |         }
52 | 
53 |         #endregion
54 | 
55 |         #region Properties
56 | 
57 |         /// <summary>
58 |         /// Gets the uploaded image, as returned by ImageUtilities.Resize for the upload panel's width and height.
59 |         /// </summary>
60 |         /// <exception cref="NullReferenceException">No image has been uploaded since construction or the last <see cref="Clear"/>.</exception>
61 |         public Image Image => _image ?? throw new NullReferenceException(nameof(_image));
62 | 
63 |         #endregion
64 | 
65 |         #region Event handlers
66 | 
67 |         /// <summary>Raised when the classify button is clicked.</summary>
68 |         public event EventHandler ClassifyImage;
69 | 
70 |         /// <summary>Raised when the clear button is clicked.</summary>
71 |         public event EventHandler ClearImage;
72 | 
73 |         #endregion
74 | 
75 |         #region Methods
76 | 
77 |         private void BtnClassifyImgOnClick(object sender, EventArgs e)
78 |         {
79 |             ClassifyImage?.Invoke(this, EventArgs.Empty);
80 |         }
81 | 
82 |         private void BtnClearImgOnClick(object sender, EventArgs e)
83 |         {
84 |             ClearImage?.Invoke(this, EventArgs.Empty);
85 |         }
86 | 
87 |         /// <summary>
88 |         /// Lets the user pick an image file; when confirmed, loads it, stores the version
89 |         /// resized to the upload panel, clears the previous prediction and repaints.
90 |         /// </summary>
91 |         private void BtnUploadImgOnClick(object sender, EventArgs e)
92 |         {
93 |             if (_openFileDialog.ShowDialog() != DialogResult.OK)
94 |             {
95 |                 return;
96 |             }
97 | 
98 |             // Image.FromFile keeps the source file locked (and holds GDI+ resources) until
99 |             // the image is disposed, so the loaded original is released as soon as the
100 |             // resized image exists, or if resizing fails.
101 |             Image rawImage = Image.FromFile(_openFileDialog.FileName);
102 |             Image resized = null;
103 |             try
104 |             {
105 |                 resized = ImageUtilities.Resize(rawImage, uploadPanel.Width, uploadPanel.Height);
106 |             }
107 |             finally
108 |             {
109 |                 // Defensive: never dispose the image that is about to be painted, in case
110 |                 // Resize hands back its input unchanged.
111 |                 if (!ReferenceEquals(resized, rawImage))
112 |                 {
113 |                     rawImage.Dispose();
114 |                 }
115 |             }
116 | 
117 |             _image = resized;
118 | 
119 |             predictionPane.Clear();
120 | 
121 |             uploadPanel.Invalidate();
122 |         }
123 | 
124 |         /// <summary>
125 |         /// Paints the current image at the panel's top-left corner; paints nothing when no image is loaded.
126 |         /// </summary>
127 |         private void UploadPanelOnPaint(object sender, PaintEventArgs e)
128 |         {
129 |             if (_image is null)
130 |             {
131 |                 return;
132 |             }
133 | 
134 |             e.Graphics.DrawImage(_image, Point.Empty);
135 |         }
136 | 
137 |         /// <summary>
138 |         /// Positions the upload panel near the middle of the control and lays out the
139 |         /// instructions label, buttons and prediction pane relative to it.
140 |         /// </summary>
141 |         protected override void OnResize(EventArgs e)
142 |         {
143 |             // NOTE(review): horizontal centring subtracts 35 from the combined width, while the
144 |             // prediction pane is placed 40px right of the panel below; confirm the intended offset.
145 |             uploadPanel.Top = (Height - uploadPanel.Height) / 2;
146 |             uploadPanel.Left = (Width - (uploadPanel.Width + predictionPane.Width - 35)) / 2;
147 | 
148 |             lblInstructions.Left = uploadPanel.Left - 5;
149 |             lblInstructions.Top = uploadPanel.Top - lblInstructions.Height - 5;
150 | 
151 |             int top = uploadPanel.Bottom + 5;
152 | 
153 |             btnClassifyImg.Left = uploadPanel.Left;
154 |             btnClassifyImg.Top = top;
155 | 
156 |             btnClearImg.Left = btnClassifyImg.Right + 5;
157 |             btnClearImg.Top = top;
158 | 
159 |             btnUploadImg.Left = uploadPanel.Right - btnUploadImg.Width;
160 |             btnUploadImg.Top = top;
161 | 
162 |             predictionPane.Left = uploadPanel.Right + 40;
163 |             predictionPane.Top = uploadPanel.Top;
164 | 
165 |             base.OnResize(e);
166 |         }
167 | 
168 |         #endregion
169 | 
170 |         #region IUploadImageView implementation
171 | 
172 |         /// <summary>
173 |         /// Forgets the current image, clears the prediction pane and repaints the upload panel.
174 |         /// </summary>
175 |         public void Clear()
176 |         {
177 |             _image = null;
178 | 
179 |             predictionPane.Clear();
180 | 
181 |             uploadPanel.Invalidate();
182 |         }
183 | 
184 |         /// <summary>
185 |         /// Forwards <paramref name="prediction"/> to the prediction pane.
186 |         /// </summary>
187 |         public void ProcessPrediction(double[] prediction)
188 |         {
189 |             predictionPane.ProcessPrediction(prediction);
190 |         }
191 | 
192 |         /// <summary>
193 |         /// Brings the view to the front and shows it.
194 |         /// </summary>
195 |         public void ShowView()
196 |         {
197 |             BringToFront();
198 | 
199 |             Show();
200 |         }
201 | 
202 |         /// <summary>
203 |         /// Hides the view.
204 |         /// </summary>
205 |         public void HideView()
206 |         {
207 |             Hide();
208 |         }
209 | 
210 |         #endregion
211 |     }
212 | }
213 | 
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Views/Interfaces/IBenchmarkView.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using DigitRecognizer.Presentation.Models;
3 | 
4 | namespace DigitRecognizer.Presentation.Views.Interfaces
5 | {
6 |     /// <summary>
7 |     /// Contract for the benchmark screen (implemented by BenchmarkView, which MainForm exposes through this interface).
8 |     /// </summary>
9 |     public interface IBenchmarkView : IView
10 |     {
11 |         event EventHandler RunBenchmark;
12 |         event EventHandler CancelBenchmark;
13 | 
14 |         bool IsBenchmarkRunning { get; set; }
15 | 
16 |         void PerformProgressStep();
17 |         void SetAccuracy(int accuracy);
18 |         void DrawGrid(ImageGridModel model);
19 |         void ResetView();
20 |     }
21 | }
22 | 
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Views/Interfaces/IDrawingView.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Drawing;
3 | 
4 | namespace DigitRecognizer.Presentation.Views.Interfaces
5 | {
6 |     /// <summary>
7 |     /// Contract for the drawing screen (implemented by DrawingView).
8 |     /// </summary>
9 |     public interface IDrawingView : IView
10 |     {
11 |         /// <summary>Gets the current drawing.</summary>
12 |         Image Drawing { get; }
13 | 
14 |         /// <summary>Raised to request classification of <see cref="Drawing"/>.</summary>
15 |         event EventHandler ClassifyDrawing;
16 | 
17 |         /// <summary>Raised to request that the drawing be cleared.</summary>
18 |         event EventHandler ClearDrawing;
19 | 
20 |         /// <summary>Resets the view.</summary>
21 |         void Clear();
22 | 
23 |         /// <summary>Receives a prediction vector for the drawing.</summary>
24 |         void ProcessPrediction(double[] prediction);
25 |     }
26 | }
27 | 
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Views/Interfaces/IMainFormView.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.Presentation.Views.Interfaces
2 | {
3 |     /// <summary>
4 |     /// Contract for the main window, exposing each of the child views it hosts.
5 |     /// </summary>
6 |     public interface IMainFormView : IView
7 |     {
8 |         /// <summary>
9 |         /// The hosted benchmark view.
10 |         /// </summary>
11 |         IBenchmarkView BenchmarkView { get; }
12 | 
13 |         /// <summary>
14 |         /// The hosted drawing view.
15 |         /// </summary>
16 |         IDrawingView DrawingView { get; }
17 | 
18 |         /// <summary>
19 |         /// The hosted upload-image view.
20 |         /// </summary>
21 |         IUploadImageView UploadImageView { get; }
22 | 
23 |         /// <summary>
24 |         /// The hosted sliding-window view.
25 |         /// </summary>
26 |         ISlidingWindowView SlidingWindowView { get; }
27 |     }
28 | }
29 | 
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Views/Interfaces/ISlidingWindowView.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Drawing;
3 | using DigitRecognizer.Core.Data;
4 | 
5 | namespace DigitRecognizer.Presentation.Views.Interfaces
6 | {
7 |     /// <summary>
8 |     /// Contract for the sliding-window screen (implemented by SlidingWindowView).
9 |     /// </summary>
10 |     public interface ISlidingWindowView : IView
11 |     {
12 |         /// <summary>Gets the current drawing.</summary>
13 |         Image Drawing { get; }
14 | 
15 |         /// <summary>Raised to request classification of <see cref="Drawing"/>.</summary>
16 |         event EventHandler ClassifyDrawing;
17 | 
18 |         /// <summary>Raised to request that the drawing be cleared.</summary>
19 |         event EventHandler ClearDrawing;
20 | 
21 |         /// <summary>Resets the view.</summary>
22 |         void Clear();
23 | 
24 |         /// <summary>
25 |         /// Draws <paramref name="boundingBox"/>; presumably annotated with
26 |         /// <paramref name="prediction"/> and <paramref name="predictionAccuracy"/> — confirm in the implementation.
27 |         /// </summary>
28 |         void DrawBoundingBox(BoundingBox boundingBox, int prediction, double predictionAccuracy);
29 |     }
30 | }
31 | 
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Views/Interfaces/IUploadImageView.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Drawing;
3 | 
4 | namespace DigitRecognizer.Presentation.Views.Interfaces
5 | {
6 |     /// <summary>
7 |     /// Contract for the upload-image screen (implemented by UploadImageView).
8 |     /// </summary>
9 |     public interface IUploadImageView : IView
10 |     {
11 |         /// <summary>
12 |         /// Gets the uploaded image. UploadImageView throws when no image is loaded.
13 |         /// </summary>
14 |         Image Image { get; }
15 | 
16 |         /// <summary>Raised to request classification of <see cref="Image"/>.</summary>
17 |         event EventHandler ClassifyImage;
18 | 
19 |         /// <summary>Raised to request that the image be cleared.</summary>
20 |         event EventHandler ClearImage;
21 | 
22 |         /// <summary>Resets the view.</summary>
23 |         void Clear();
24 | 
25 |         /// <summary>Receives a prediction vector for the image.</summary>
26 |         void ProcessPrediction(double[] prediction);
27 |     }
28 | }
29 | 
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Views/Interfaces/IView.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.Presentation.Views.Interfaces
2 | {
3 |     /// <summary>
4 |     /// Visibility contract shared by the main form and each of its child views.
5 |     /// </summary>
6 |     public interface IView
7 |     {
8 |         /// <summary>Makes the view visible.</summary>
9 |         void ShowView();
10 | 
11 |         /// <summary>Hides the view.</summary>
12 |         void HideView();
13 |     }
14 | }
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/favico.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/DigitRecognizer.Presentation/favico.ico -------------------------------------------------------------------------------- /DigitRecognizer/DigitRecognizer.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio 15 4 | VisualStudioVersion = 15.0.27703.2000 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "DigitRecognizer.Core", "DigitRecognizer.Core\DigitRecognizer.Core.csproj", "{BA726236-C7EC-41CA-ABB6-E4FE7E784939}" 7 | EndProject 8 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "DigitRecognizer.Engine", "DigitRecognizer.Engine\DigitRecognizer.Engine.csproj", "{00CE9C72-5F92-4555-BB96-D0E77B9A2390}" 9 | EndProject 10 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "DigitRecognizer.MachineLearning", "DigitRecognizer.MachineLearning\DigitRecognizer.MachineLearning.csproj", "{BF82C3DF-0EFE-45DB-8E3D-1FFDD548AE73}" 11 | EndProject 12 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "DigitRecognizer.Presentation", "DigitRecognizer.Presentation\DigitRecognizer.Presentation.csproj", "{6CAF69EA-9C0C-4571-A808-268C159BDB5C}" 13 | EndProject 14 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "DigitRecognizer.DatasetExpansion", "DigitRecognizer.DatasetExpansion\DigitRecognizer.DatasetExpansion.csproj", "{757CF815-DD01-4F66-8E9F-DB3FF9B4A463}" 15 | EndProject 16 | Global 17 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 18 | Debug|Any CPU = Debug|Any CPU 19 | Release|Any CPU = Release|Any CPU 20 | EndGlobalSection 21 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 22 | {BA726236-C7EC-41CA-ABB6-E4FE7E784939}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 23 | {BA726236-C7EC-41CA-ABB6-E4FE7E784939}.Debug|Any CPU.Build.0 = Debug|Any CPU 24 | 
{BA726236-C7EC-41CA-ABB6-E4FE7E784939}.Release|Any CPU.ActiveCfg = Release|Any CPU 25 | {BA726236-C7EC-41CA-ABB6-E4FE7E784939}.Release|Any CPU.Build.0 = Release|Any CPU 26 | {00CE9C72-5F92-4555-BB96-D0E77B9A2390}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 27 | {00CE9C72-5F92-4555-BB96-D0E77B9A2390}.Debug|Any CPU.Build.0 = Debug|Any CPU 28 | {00CE9C72-5F92-4555-BB96-D0E77B9A2390}.Release|Any CPU.ActiveCfg = Release|Any CPU 29 | {00CE9C72-5F92-4555-BB96-D0E77B9A2390}.Release|Any CPU.Build.0 = Release|Any CPU 30 | {BF82C3DF-0EFE-45DB-8E3D-1FFDD548AE73}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 31 | {BF82C3DF-0EFE-45DB-8E3D-1FFDD548AE73}.Debug|Any CPU.Build.0 = Debug|Any CPU 32 | {BF82C3DF-0EFE-45DB-8E3D-1FFDD548AE73}.Release|Any CPU.ActiveCfg = Release|Any CPU 33 | {BF82C3DF-0EFE-45DB-8E3D-1FFDD548AE73}.Release|Any CPU.Build.0 = Release|Any CPU 34 | {6CAF69EA-9C0C-4571-A808-268C159BDB5C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 35 | {6CAF69EA-9C0C-4571-A808-268C159BDB5C}.Debug|Any CPU.Build.0 = Debug|Any CPU 36 | {6CAF69EA-9C0C-4571-A808-268C159BDB5C}.Release|Any CPU.ActiveCfg = Release|Any CPU 37 | {6CAF69EA-9C0C-4571-A808-268C159BDB5C}.Release|Any CPU.Build.0 = Release|Any CPU 38 | {757CF815-DD01-4F66-8E9F-DB3FF9B4A463}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 39 | {757CF815-DD01-4F66-8E9F-DB3FF9B4A463}.Debug|Any CPU.Build.0 = Debug|Any CPU 40 | {757CF815-DD01-4F66-8E9F-DB3FF9B4A463}.Release|Any CPU.ActiveCfg = Release|Any CPU 41 | {757CF815-DD01-4F66-8E9F-DB3FF9B4A463}.Release|Any CPU.Build.0 = Release|Any CPU 42 | EndGlobalSection 43 | GlobalSection(SolutionProperties) = preSolution 44 | HideSolutionNode = FALSE 45 | EndGlobalSection 46 | GlobalSection(ExtensibilityGlobals) = postSolution 47 | SolutionGuid = {72EDBB0B-99B8-4D0F-A643-DEE415D96034} 48 | EndGlobalSection 49 | EndGlobal 50 | -------------------------------------------------------------------------------- /DigitRecognizer/Images/favico.ico: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Images/favico.ico -------------------------------------------------------------------------------- /DigitRecognizer/Images/left_arrow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Images/left_arrow.png -------------------------------------------------------------------------------- /DigitRecognizer/Images/right_arrow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Images/right_arrow.png -------------------------------------------------------------------------------- /DigitRecognizer/Models/1f15fd63-8d73-4128-9d27-0090df8ac1ba-0.9849.nn: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Models/1f15fd63-8d73-4128-9d27-0090df8ac1ba-0.9849.nn -------------------------------------------------------------------------------- /DigitRecognizer/Models/23682b7e-19b6-4dc2-95f6-563d13c35ffe-0.9841.nn: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Models/23682b7e-19b6-4dc2-95f6-563d13c35ffe-0.9841.nn -------------------------------------------------------------------------------- /DigitRecognizer/Models/5bbc9588-713f-40f4-8d07-e349a089e0b2-0.9845.nn: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Models/5bbc9588-713f-40f4-8d07-e349a089e0b2-0.9845.nn -------------------------------------------------------------------------------- /DigitRecognizer/Models/780baa22-c1c6-44ec-b13c-e867dba48010-0.9845.nn: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Models/780baa22-c1c6-44ec-b13c-e867dba48010-0.9845.nn -------------------------------------------------------------------------------- /DigitRecognizer/Models/8edc3712-2eba-4f2c-879b-1a280687f9b4-0.9837.nn: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Models/8edc3712-2eba-4f2c-879b-1a280687f9b4-0.9837.nn -------------------------------------------------------------------------------- /DigitRecognizer/Models/b7d29bbc-a55c-4f26-9655-db392a504105-0.9845.nn: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Models/b7d29bbc-a55c-4f26-9655-db392a504105-0.9845.nn -------------------------------------------------------------------------------- /DigitRecognizer/Models/c0a9a2e0-0cca-4943-8755-068330811b1d-0.9827.nn: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Models/c0a9a2e0-0cca-4943-8755-068330811b1d-0.9827.nn -------------------------------------------------------------------------------- /DigitRecognizer/Models/c9234ebb-c75b-42c1-88ad-b75650d7102a-0.9857.nn: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Models/c9234ebb-c75b-42c1-88ad-b75650d7102a-0.9857.nn -------------------------------------------------------------------------------- /DigitRecognizer/Models/ffed716d-a8a3-48fd-b0eb-98900357be81-0.9823.nn: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Models/ffed716d-a8a3-48fd-b0eb-98900357be81-0.9823.nn -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Milan Jovanović 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Digit Recognizer 2 | 3 | A neural network that classifies the MNIST dataset, along with a desktop application that visualizes the classification process. 4 | 5 | If you want to run this locally, unzip the archives in the `Dataset` folder. They contain the images required for training and testing the neural network. 6 | 7 | Check out this video for a more detailed explanation: [**I Built a Neural Network in C# From Scratch. Here's What I Learned...**](https://youtu.be/wgNZWnua-90) 8 | --------------------------------------------------------------------------------