├── .gitignore
├── DigitRecognizer
├── Dataset
│ ├── test_dataset.zip
│ ├── train_dataset.zip
│ └── train_expanded_dataset.zip
├── DigitRecognizer.Core
│ ├── Data
│ │ ├── BoundingBox.cs
│ │ ├── Box.cs
│ │ ├── LinkedList.cs
│ │ ├── LinkedListNode.cs
│ │ └── MnistImageBatch.cs
│ ├── DigitRecognizer.Core.csproj
│ ├── Extensions
│ │ ├── ByteExtensions.cs
│ │ ├── DoubleExtensions.cs
│ │ ├── ImageExtensions.cs
│ │ ├── RandomExtensions.cs
│ │ ├── StreamExtensions.cs
│ │ └── VectorExtensions.cs
│ ├── IO
│ │ ├── ILabelReader.cs
│ │ ├── IPixelReader.cs
│ │ ├── ImagePreprocessor.cs
│ │ ├── LabelReader.cs
│ │ ├── MemoryStreamReader.cs
│ │ ├── NnBinaryDeserializer.cs
│ │ ├── NnBinarySerializer.cs
│ │ ├── NnSerializableBase.cs
│ │ ├── NnSerializationContext.cs
│ │ ├── NnSerializationContextInfo.cs
│ │ └── PixelReader.cs
│ ├── Properties
│ │ └── AssemblyInfo.cs
│ └── Utilities
│ │ ├── ConsoleUtility.cs
│ │ ├── Contracts.cs
│ │ ├── DirectoryHelper.cs
│ │ ├── ImageUtilities.cs
│ │ ├── MathUtilities.cs
│ │ └── VectorUtilities.cs
├── DigitRecognizer.DatasetExpansion
│ ├── Api
│ │ ├── DatasetExpander.cs
│ │ └── DatasetSerializer.cs
│ ├── App.config
│ ├── DigitRecognizer.DatasetExpansion.csproj
│ ├── Infrastructure
│ │ ├── AffineTransformation.cs
│ │ ├── DatasetProvider.cs
│ │ ├── FisherYatesShuffle.cs
│ │ └── MnistImage.cs
│ ├── Program.cs
│ ├── Properties
│ │ └── AssemblyInfo.cs
│ └── favico.ico
├── DigitRecognizer.Engine
│ ├── App.config
│ ├── DigitRecognizer.Engine.csproj
│ ├── Program.cs
│ ├── Properties
│ │ └── AssemblyInfo.cs
│ └── favico.ico
├── DigitRecognizer.MachineLearning
│ ├── DigitRecognizer.MachineLearning.csproj
│ ├── Infrastructure
│ │ ├── Data
│ │ │ └── CalculationCache.cs
│ │ ├── Dropout
│ │ │ ├── BinomialDistribution.cs
│ │ │ └── Dropout.cs
│ │ ├── Factories
│ │ │ ├── AbstractTypeFactory.cs
│ │ │ ├── FunctionFactory.cs
│ │ │ └── InitializerFactory.cs
│ │ ├── Functions
│ │ │ ├── CrossEntropy.cs
│ │ │ ├── ExponentialRelu.cs
│ │ │ ├── IActivationFunction.cs
│ │ │ ├── ICostFunction.cs
│ │ │ ├── IDifferentiableFunction.cs
│ │ │ ├── IFunction.cs
│ │ │ ├── LeakyRelu.cs
│ │ │ ├── MeanSquareError.cs
│ │ │ ├── Relu.cs
│ │ │ ├── Sigmoid.cs
│ │ │ ├── Softmax.cs
│ │ │ ├── Softplus.cs
│ │ │ └── Tanh.cs
│ │ ├── Initialization
│ │ │ ├── HeInitializer.cs
│ │ │ ├── IInitializer.cs
│ │ │ ├── InitializerType.cs
│ │ │ ├── RandomInitializer.cs
│ │ │ ├── XavierInitializer.cs
│ │ │ └── ZeroInitializer.cs
│ │ ├── Models
│ │ │ ├── ClusterPredictionModel.cs
│ │ │ ├── IPredictionModel.cs
│ │ │ └── PredictionModel.cs
│ │ └── NeuralNetwork
│ │ │ ├── BiasVector.cs
│ │ │ ├── INeuralNetwork.cs
│ │ │ ├── IValueAdjustable.cs
│ │ │ ├── NeuralNetwork.cs
│ │ │ ├── NnLayer.cs
│ │ │ └── WeightMatrix.cs
│ ├── Optimization
│ │ ├── LearningRateDecay
│ │ │ ├── ExponentailDecay.cs
│ │ │ ├── ILearningRateDecay.cs
│ │ │ ├── StepDecay.cs
│ │ │ └── TimeBasedDecay.cs
│ │ └── Optimizers
│ │ │ ├── BaseOptimizer.cs
│ │ │ ├── GradientDescentOptimizer.cs
│ │ │ ├── IOptimizer.cs
│ │ │ └── MomentumOptimizer.cs
│ ├── Pipeline
│ │ ├── ILearningPipelineDataLoader.cs
│ │ ├── ILearningPipelineItem.cs
│ │ ├── ILearningPipelineNeuralNetworkModel.cs
│ │ ├── ILearningPipelineOptimizer.cs
│ │ ├── LearningPipeline.cs
│ │ ├── PipelineExtensions.cs
│ │ └── PipelineSettings.cs
│ ├── Properties
│ │ └── AssemblyInfo.cs
│ ├── Providers
│ │ ├── BatchDataProvider.cs
│ │ ├── DataProviderBase.cs
│ │ └── IDataProvider.cs
│ └── Serialization
│ │ ├── INnSerializable.cs
│ │ ├── NnDeserializer.cs
│ │ └── NnSerializer.cs
├── DigitRecognizer.Presentation
│ ├── App.config
│ ├── Components
│ │ ├── ImageGrid.Designer.cs
│ │ ├── ImageGrid.cs
│ │ ├── ImageGrid.resx
│ │ ├── PredictionPane.Designer.cs
│ │ ├── PredictionPane.cs
│ │ └── PredictionPane.resx
│ ├── DigitRecognizer.Presentation.csproj
│ ├── Global.cs
│ ├── Infrastructure
│ │ ├── DependencyResolver.cs
│ │ ├── ExceptionHandlers.cs
│ │ ├── PanelDoubleBuffering.cs
│ │ └── PredictionModelLoader.cs
│ ├── Models
│ │ └── ImageGridModel.cs
│ ├── Presenters
│ │ ├── ApplicationPresenter.cs
│ │ ├── BenchmarkPresenter.cs
│ │ ├── DrawingPresenter.cs
│ │ ├── SlidingWindowPresenter.cs
│ │ └── UploadImagePresenter.cs
│ ├── Program.cs
│ ├── Properties
│ │ ├── AssemblyInfo.cs
│ │ ├── Resources.Designer.cs
│ │ ├── Resources.resx
│ │ ├── Settings.Designer.cs
│ │ └── Settings.settings
│ ├── Services
│ │ ├── ILoggingService.cs
│ │ ├── IMessageService.cs
│ │ ├── LoggingService.cs
│ │ └── MessageService.cs
│ ├── Startup.cs
│ ├── Views
│ │ ├── Implementations
│ │ │ ├── BenchmarkView.Designer.cs
│ │ │ ├── BenchmarkView.cs
│ │ │ ├── BenchmarkView.resx
│ │ │ ├── DrawingView.Designer.cs
│ │ │ ├── DrawingView.cs
│ │ │ ├── DrawingView.resx
│ │ │ ├── MainForm.Designer.cs
│ │ │ ├── MainForm.cs
│ │ │ ├── MainForm.resx
│ │ │ ├── SlidingWindowView.Designer.cs
│ │ │ ├── SlidingWindowView.cs
│ │ │ ├── SlidingWindowView.resx
│ │ │ ├── UploadImageView.Designer.cs
│ │ │ ├── UploadImageView.cs
│ │ │ └── UploadImageView.resx
│ │ └── Interfaces
│ │ │ ├── IBenchmarkView.cs
│ │ │ ├── IDrawingView.cs
│ │ │ ├── IMainFormView.cs
│ │ │ ├── ISlidingWindowView.cs
│ │ │ ├── IUploadImageView.cs
│ │ │ └── IView.cs
│ └── favico.ico
├── DigitRecognizer.sln
├── Images
│ ├── favico.ico
│ ├── left_arrow.png
│ └── right_arrow.png
└── Models
│ ├── 1f15fd63-8d73-4128-9d27-0090df8ac1ba-0.9849.nn
│ ├── 23682b7e-19b6-4dc2-95f6-563d13c35ffe-0.9841.nn
│ ├── 5bbc9588-713f-40f4-8d07-e349a089e0b2-0.9845.nn
│ ├── 780baa22-c1c6-44ec-b13c-e867dba48010-0.9845.nn
│ ├── 8edc3712-2eba-4f2c-879b-1a280687f9b4-0.9837.nn
│ ├── b7d29bbc-a55c-4f26-9655-db392a504105-0.9845.nn
│ ├── c0a9a2e0-0cca-4943-8755-068330811b1d-0.9827.nn
│ ├── c9234ebb-c75b-42c1-88ad-b75650d7102a-0.9857.nn
│ └── ffed716d-a8a3-48fd-b0eb-98900357be81-0.9823.nn
├── LICENSE
└── README.md
/DigitRecognizer/Dataset/test_dataset.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Dataset/test_dataset.zip
--------------------------------------------------------------------------------
/DigitRecognizer/Dataset/train_dataset.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Dataset/train_dataset.zip
--------------------------------------------------------------------------------
/DigitRecognizer/Dataset/train_expanded_dataset.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Dataset/train_expanded_dataset.zip
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Data/BoundingBox.cs:
--------------------------------------------------------------------------------
1 | using System.Drawing;
2 |
namespace DigitRecognizer.Core.Data
{
    /// <summary>
    /// Pairs an <see cref="System.Drawing.Image"/> with an X and a Y coordinate.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the coordinates presumably locate the image inside a larger source image — confirm with callers.
    /// </remarks>
    public class BoundingBox
    {
        private Image _image;
        private int _x;
        private int _y;

        /// <summary>
        /// Gets or sets the image held by this box.
        /// </summary>
        public Image Image
        {
            get => _image;
            set => _image = value;
        }

        /// <summary>
        /// Gets or sets the X coordinate.
        /// </summary>
        public int X
        {
            get => _x;
            set => _x = value;
        }

        /// <summary>
        /// Gets or sets the Y coordinate.
        /// </summary>
        public int Y
        {
            get => _y;
            set => _y = value;
        }
    }
}
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Data/Box.cs:
--------------------------------------------------------------------------------
namespace DigitRecognizer.Core.Data
{
    /// <summary>
    /// A mutable value type holding four integer edges: top, bottom, left and right.
    /// </summary>
    /// <remarks>
    /// In this codebase it is used both for bounding coordinates and for padding amounts
    /// (see <c>ImagePreprocessor.Preprocess</c>).
    /// </remarks>
    public struct Box
    {
        /// <summary>
        /// Initializes a new instance of the <see cref="Box"/> struct.
        /// </summary>
        /// <param name="top">The top value.</param>
        /// <param name="bottom">The bottom value.</param>
        /// <param name="left">The left value.</param>
        /// <param name="right">The right value.</param>
        public Box(int top, int bottom, int left, int right)
            : this()
        {
            Top = top;
            Bottom = bottom;
            Left = left;
            Right = right;
        }

        /// <summary>
        /// Gets or sets the top value.
        /// </summary>
        public int Top { get; set; }

        /// <summary>
        /// Gets or sets the bottom value.
        /// </summary>
        public int Bottom { get; set; }

        /// <summary>
        /// Gets or sets the left value.
        /// </summary>
        public int Left { get; set; }

        /// <summary>
        /// Gets or sets the right value.
        /// </summary>
        public int Right { get; set; }
    }
}
65 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Data/LinkedListNode.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Utilities;
2 |
namespace DigitRecognizer.Core.Data
{
    /// <summary>
    /// Doubly linked list node holding a value of type <typeparamref name="T"/>.
    /// </summary>
    /// <typeparam name="T">The type of the value stored in the node.</typeparam>
    /// <remarks>
    /// The type parameter is required: the stored item and <see cref="Value"/> are typed as
    /// <typeparamref name="T"/>, so a non-generic declaration cannot compile.
    /// </remarks>
    public class LinkedListNode<T>
    {
        private LinkedListNode<T> _previous;
        private LinkedListNode<T> _next;
        private readonly T _item;
        private int _depth;

        /// <summary>
        /// Initializes a new instance of the <see cref="LinkedListNode{T}"/> class.
        /// </summary>
        /// <param name="item">The value stored in the node; it cannot be replaced afterwards.</param>
        public LinkedListNode(T item)
        {
            _item = item;
        }

        /// <summary>
        /// Gets or sets the previous node, or <c>null</c> when there is none.
        /// </summary>
        public LinkedListNode<T> Previous
        {
            get => _previous;
            set => _previous = value;
        }

        /// <summary>
        /// Gets or sets the next node, or <c>null</c> when there is none.
        /// </summary>
        public LinkedListNode<T> Next
        {
            get => _next;
            set => _next = value;
        }

        /// <summary>
        /// Gets the value of the node. The value is fixed at construction.
        /// </summary>
        public T Value => _item;

        /// <summary>
        /// Gets or sets the depth of the node.
        /// </summary>
        /// <remarks>
        /// The setter validates the value with <c>Contracts.ValueGreaterThanZero</c>.
        /// NOTE(review): the backing field starts at 0, a value the setter presumably rejects — confirm this is intended.
        /// </remarks>
        public int Depth
        {
            get => _depth;
            set
            {
                Contracts.ValueGreaterThanZero(value, nameof(value));

                _depth = value;
            }
        }
    }
}
63 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Data/MnistImageBatch.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Utilities;
2 |
namespace DigitRecognizer.Core.Data
{
    /// <summary>
    /// Represents a batch of MNIST images as parallel arrays: one label and one pixel array per image.
    /// </summary>
    public class MnistImageBatch
    {
        /// <summary>
        /// Gets the labels; element <c>i</c> belongs to <see cref="Pixels"/>[i].
        /// </summary>
        public int[] Labels { get; }

        /// <summary>
        /// Gets the pixels, one array per image.
        /// </summary>
        public double[][] Pixels { get; }

        /// <summary>
        /// Initializes a new instance of the <see cref="MnistImageBatch"/> class.
        /// </summary>
        /// <param name="labels">The labels of the batch.</param>
        /// <param name="pixels">The pixels of the batch; must contain as many entries as <paramref name="labels"/>.</param>
        /// <exception cref="System.ArgumentNullException"><paramref name="labels"/> or <paramref name="pixels"/> is <c>null</c>.</exception>
        public MnistImageBatch(int[] labels, double[][] pixels)
        {
            // Guard explicitly so callers get ArgumentNullException instead of a
            // NullReferenceException from the .Length access below.
            if (labels == null)
            {
                throw new System.ArgumentNullException(nameof(labels));
            }

            if (pixels == null)
            {
                throw new System.ArgumentNullException(nameof(pixels));
            }

            Contracts.ValuesMatch(labels.Length, pixels.Length, nameof(labels.Length));

            Labels = labels;
            Pixels = pixels;
        }
    }
}
34 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/DigitRecognizer.Core.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Debug
6 | AnyCPU
7 | {BA726236-C7EC-41CA-ABB6-E4FE7E784939}
8 | Library
9 | Properties
10 | DigitRecognizer.Core
11 | DigitRecognizer.Core
12 | v4.8
13 | 512
14 |
15 |
16 |
17 | true
18 | full
19 | false
20 | bin\Debug\
21 | DEBUG;TRACE
22 | prompt
23 | 4
24 | true
25 |
26 |
27 | pdbonly
28 | true
29 | bin\Release\
30 | TRACE
31 | prompt
32 | 4
33 | true
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Extensions/ByteExtensions.cs:
--------------------------------------------------------------------------------
1 | using System;
2 |
namespace DigitRecognizer.Core.Extensions
{
    /// <summary>
    /// Provides extension methods for working with <see cref="byte"/> arrays.
    /// </summary>
    public static class ByteExtensions
    {
        /// <summary>
        /// Reinterprets the specified byte array as an array of doubles, utilizing the <see cref="Buffer"/> class.
        /// Bytes are copied in the machine's native byte order.
        /// </summary>
        /// <param name="array">The bytes to convert; the length must be a multiple of <c>sizeof(double)</c>.</param>
        /// <returns>The doubles, or <c>null</c> when <paramref name="array"/> is empty.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="array"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentException">The length of <paramref name="array"/> is not a multiple of <c>sizeof(double)</c>.</exception>
        public static double[] ToDoubles(this byte[] array)
        {
            if (array == null)
            {
                throw new ArgumentNullException(nameof(array));
            }

            if (array.Length == 0)
            {
                // Existing callers may rely on null for empty input, so this is kept.
                return default(double[]);
            }

            // Buffer.BlockCopy would also reject a ragged length, but with an opaque
            // "offset and length out of bounds" message; report the real problem instead.
            if (array.Length % sizeof(double) != 0)
            {
                throw new ArgumentException(
                    $"The array length ({array.Length}) must be a multiple of {sizeof(double)}.",
                    nameof(array));
            }

            var result = new double[array.Length / sizeof(double)];
            Buffer.BlockCopy(array, 0, result, 0, array.Length);
            return result;
        }
    }
}
28 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Extensions/DoubleExtensions.cs:
--------------------------------------------------------------------------------
1 | using System;
2 |
namespace DigitRecognizer.Core.Extensions
{
    /// <summary>
    /// Provides extension methods for working with the <see cref="double"/> struct.
    /// </summary>
    public static class DoubleExtensions
    {
        /// <summary>
        /// Divisor that maps the byte range [0, 255] onto [0, 1].
        /// </summary>
        private const double DoubleNorm = 255.0D;

        /// <summary>
        /// Half of <see cref="DoubleNorm"/>; dividing a byte by it yields a value in [0, 2].
        /// </summary>
        private const double DoubleOffsetNorm = 127.5D;

        /// <summary>
        /// Returns a double in the range [0, 1], computed as <paramref name="value"/> / 255.
        /// </summary>
        /// <param name="value">The byte to normalize.</param>
        /// <returns>A double in the range [0, 1].</returns>
        public static double FromByteNormalized(byte value)
        {
            return value / DoubleNorm;
        }

        /// <summary>
        /// Returns a double in the range [-1, 1], computed as <paramref name="value"/> / 127.5 - 1.
        /// 0 maps to -1 and 255 maps to 1.
        /// </summary>
        /// <param name="value">The byte to normalize.</param>
        /// <returns>A double in the range [-1, 1].</returns>
        public static double FromByteOffset(byte value)
        {
            return value / DoubleOffsetNorm - 1;
        }

        /// <summary>
        /// Gets the bytes of the double array in the machine's native byte order.
        /// </summary>
        /// <param name="array">The array of doubles that will be converted to bytes.</param>
        /// <returns>A byte array of length <c>array.Length * sizeof(double)</c>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="array"/> is <c>null</c>.</exception>
        public static byte[] GetBytes(this double[] array)
        {
            // BitConverter.GetBytes and Buffer.BlockCopy both use the native byte order, so one
            // bulk copy yields exactly the bytes of calling BitConverter.GetBytes on every element,
            // without allocating a temporary array per element.
            return ToBytes(array);
        }

        /// <summary>
        /// Gets the bytes of the double array, utilizing the <see cref="Buffer"/> class.
        /// </summary>
        /// <param name="array">The array of doubles that will be converted to bytes.</param>
        /// <returns>A byte array of length <c>array.Length * sizeof(double)</c>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="array"/> is <c>null</c>.</exception>
        public static byte[] ToBytes(this double[] array)
        {
            if (array == null)
            {
                throw new ArgumentNullException(nameof(array));
            }

            var result = new byte[array.Length * sizeof(double)];

            Buffer.BlockCopy(array, 0, result, 0, result.Length);

            return result;
        }
    }
}
70 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Extensions/ImageExtensions.cs:
--------------------------------------------------------------------------------
1 | using System.Drawing;
2 | using DigitRecognizer.Core.Utilities;
3 |
namespace DigitRecognizer.Core.Extensions
{
    /// <summary>
    /// Extension methods for <see cref="Image"/>.
    /// </summary>
    public static class ImageExtensions
    {
        /// <summary>
        /// Resizes <paramref name="image"/> to <paramref name="width"/> x <paramref name="height"/> pixels
        /// by delegating to <c>ImageUtilities.Resize</c>.
        /// </summary>
        /// <param name="image">The image to resize.</param>
        /// <param name="width">The target width in pixels.</param>
        /// <param name="height">The target height in pixels.</param>
        /// <returns>Whatever <c>ImageUtilities.Resize</c> returns for the same arguments.</returns>
        public static Image Resize(this Image image, int width, int height) =>
            ImageUtilities.Resize(image, width, height);
    }
}
24 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Extensions/RandomExtensions.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using DigitRecognizer.Core.Utilities;
3 |
namespace DigitRecognizer.Core.Extensions
{
    /// <summary>
    /// Extension methods for <see cref="Random"/>.
    /// </summary>
    public static class RandomExtensions
    {
        /// <summary>
        /// Fills a new array with successive <see cref="Random.NextDouble"/> results,
        /// i.e. floating-point numbers that are greater than or equal to 0.0 and less than 1.0.
        /// </summary>
        /// <param name="rnd">The random number generator to draw from.</param>
        /// <param name="length">How many numbers to draw; validated by <c>Contracts.ValueGreaterThanZero</c>.</param>
        /// <returns>An array of <paramref name="length"/> random numbers, in draw order.</returns>
        public static double[] NextDoubles(this Random rnd, int length)
        {
            Contracts.ValueGreaterThanZero(length, nameof(length));

            var doubles = new double[length];
            var position = 0;

            while (position < doubles.Length)
            {
                doubles[position++] = rnd.NextDouble();
            }

            return doubles;
        }
    }
}
32 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Extensions/StreamExtensions.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.IO;
3 |
namespace DigitRecognizer.Core.Extensions
{
    /// <summary>
    /// Contains extension methods for working with the <see cref="Stream"/> class.
    /// </summary>
    public static class StreamExtensions
    {
        /// <summary>
        /// Loads the entire contents of the specified file into a new <see cref="MemoryStream"/>.
        /// </summary>
        /// <param name="filename">The path of the file to load.</param>
        /// <returns>
        /// A <see cref="MemoryStream"/> positioned at 0 that contains the file's bytes,
        /// or <c>null</c> when <paramref name="filename"/> does not point to an existing file.
        /// </returns>
        public static MemoryStream GetMemoryStreamFromFile(string filename)
        {
            if (!File.Exists(filename))
            {
                return default(MemoryStream);
            }

            var memoryStream = new MemoryStream();

            using (var fileStream = File.OpenRead(filename))
            {
                // Pre-size the buffer to the file length; the stream position stays at 0.
                memoryStream.SetLength(fileStream.Length);

                // A single Stream.Read call may return fewer bytes than requested, leaving the tail
                // zero-filled. CopyTo keeps reading until end of file, overwriting the region from 0.
                fileStream.CopyTo(memoryStream);
            }

            memoryStream.Position = 0;

            return memoryStream;
        }

        /// <summary>
        /// Sets the position within the current stream to the specified value.
        /// Does nothing when the stream does not support seeking.
        /// </summary>
        /// <param name="stream">The <see cref="Stream"/> in which we are setting the position.</param>
        /// <param name="position">The position, relative to the beginning of the stream.</param>
        public static void SetPosition(this Stream stream, int position)
        {
            if (!stream.CanSeek)
            {
                return;
            }

            stream.Seek(position, SeekOrigin.Begin);
        }

        /// <summary>
        /// Sets the position within the current stream to the beginning of the stream.
        /// Does nothing when the stream does not support seeking.
        /// </summary>
        /// <param name="stream">The <see cref="Stream"/> that we are resetting.</param>
        public static void Reset(this Stream stream)
        {
            stream.SetPosition(0);
        }

        /// <summary>
        /// Reads up to the specified number of bytes from the stream, and advances the current position of the stream.
        /// </summary>
        /// <param name="stream">The <see cref="Stream"/> that we are reading from.</param>
        /// <param name="count">The number of bytes to read from the stream.</param>
        /// <returns>
        /// The bytes read. The result is shorter than <paramref name="count"/> only when the end of the stream is reached.
        /// It is empty when <paramref name="count"/> is negative or exceeds the total stream length.
        /// </returns>
        public static byte[] ReadBytes(this Stream stream, int count)
        {
            if (count < 0 || count > stream.Length)
            {
                return new byte[0];
            }

            var buffer = new byte[count];
            var totalRead = 0;

            // Stream.Read may legally return fewer bytes than requested before the end of the
            // stream, so keep reading until the buffer is full or Read reports end of stream (0).
            while (totalRead < count)
            {
                int read = stream.Read(buffer, totalRead, count - totalRead);

                if (read == 0)
                {
                    break;
                }

                totalRead += read;
            }

            if (totalRead == count)
            {
                return buffer;
            }

            var smallerBuffer = new byte[totalRead];
            Buffer.BlockCopy(buffer, 0, smallerBuffer, 0, totalRead);

            return smallerBuffer;
        }
    }
}
88 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/IO/ILabelReader.cs:
--------------------------------------------------------------------------------
namespace DigitRecognizer.Core.IO
{
    /// <summary>
    /// Contract for types that read integer labels from a source.
    /// </summary>
    public interface ILabelReader
    {
        /// <summary>
        /// Reads a single label.
        /// </summary>
        /// <returns>The label.</returns>
        int ReadLabel();

        /// <summary>
        /// Reads the specified number of labels.
        /// </summary>
        /// <param name="count">How many labels to read.</param>
        /// <returns>The labels, in source order.</returns>
        int[] ReadLabels(int count);
    }
}
12 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/IO/IPixelReader.cs:
--------------------------------------------------------------------------------
namespace DigitRecognizer.Core.IO
{
    /// <summary>
    /// Contract for types that read pixel values from a source.
    /// </summary>
    public interface IPixelReader
    {
        /// <summary>
        /// Reads the specified number of pixels as one flat array.
        /// </summary>
        /// <param name="count">How many pixels to read.</param>
        /// <returns>The pixel values.</returns>
        double[] ReadPixels(int count);

        /// <summary>
        /// Reads <paramref name="count"/> blocks of pixels, each <paramref name="blockSize"/> pixels long.
        /// </summary>
        /// <param name="count">How many blocks to read.</param>
        /// <param name="blockSize">How many pixels each block contains.</param>
        /// <returns>A jagged array with one entry per block.</returns>
        double[][] ReadPixels(int count, int blockSize);
    }
}
12 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/IO/ImagePreprocessor.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Drawing;
3 | using DigitRecognizer.Core.Data;
4 | using DigitRecognizer.Core.Utilities;
5 |
namespace DigitRecognizer.Core.IO
{
    /// <summary>
    /// Prepares images so that they can be run through the neural network.
    /// </summary>
    public class ImagePreprocessor
    {
        /// <summary>Side length, in pixels, of the square image the pipeline works on.</summary>
        private const int ImageSizeInPixels = 28;

        /// <summary>Threshold applied right after grayscale conversion and resizing.</summary>
        private const int InitialThreshold = 128;

        /// <summary>Threshold applied after centering, just before flattening.</summary>
        private const int FinalThreshold = 154;

        /// <summary>Value passed to the threshold helper as the black level.</summary>
        private const int Black = 0;

        /// <summary>Value passed to the threshold helper as the white level.</summary>
        private const int White = 255;

        /// <summary>
        /// Preprocesses the specified image so that it can be fed through a neural network.
        /// The image has to be an RGB or grayscale image with 0 being black, and 255 being white.
        /// </summary>
        /// <param name="image">The image being preprocessed.</param>
        /// <returns>The flattened and clamped pixels of the preprocessed image.</returns>
        /// <remarks>
        /// Steps, all delegated to <c>ImageUtilities</c>:
        /// 1. Convert to grayscale, resize to 28x28 and apply a threshold at 128.
        /// 2. Determine the box surrounding the digit's pixels, scale the image to that box and obtain padding coordinates.
        /// 3. Pad the image back to the required size and invert it.
        /// 4. Compute the center of mass and translate the image so the digit's pixels move towards the center.
        /// 5. Apply a second threshold at 154, then flatten and clamp the pixels.
        /// NOTE(review): intermediate images are never disposed. Whether ImageUtilities returns new
        /// instances (rather than the input) is not visible here — confirm before adding disposal.
        /// </remarks>
        public double[] Preprocess(Image image)
        {
            Image binarized = ImageUtilities.Threshold(
                ImageUtilities.Resize(ImageUtilities.Grayscale(image), ImageSizeInPixels, ImageSizeInPixels),
                InitialThreshold,
                Black,
                White);

            Box digitBounds = ImageUtilities.DetermineBox(binarized);

            var (scaled, padding) = ImageUtilities.ScaleToBoxAndGetPaddingCoords(binarized, digitBounds);

            Image inverted = ImageUtilities.Invert(ImageUtilities.Pad(scaled, padding));

            Point centerOfMass = ImageUtilities.CalculateCenterOfMass(inverted);

            const double imageCenter = ImageSizeInPixels / 2.0;
            var shiftX = (int)Math.Round(imageCenter - centerOfMass.X);
            var shiftY = (int)Math.Round(imageCenter - centerOfMass.Y);

            Image centered = ImageUtilities.Translate(inverted, shiftX, shiftY);

            Image finalImage = ImageUtilities.Threshold(centered, FinalThreshold, Black, White, true);

            return ImageUtilities.FlattenAndClamp(finalImage);
        }
    }
}
62 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/IO/LabelReader.cs:
--------------------------------------------------------------------------------
namespace DigitRecognizer.Core.IO
{
    /// <summary>
    /// Reads labels from a file through <see cref="MemoryStreamReader"/>, treating each byte as one label.
    /// </summary>
    public class LabelReader : MemoryStreamReader, ILabelReader
    {
        /// <summary>
        /// Number of leading bytes skipped before the first label.
        /// NOTE(review): presumably the IDX label-file header (magic number plus item count) — confirm against the dataset format.
        /// </summary>
        public const int InitialOffset = 8;

        /// <summary>
        /// Initializes a new instance of the <see cref="LabelReader"/> class.
        /// </summary>
        /// <param name="filename">The path of the label file; passed, together with <see cref="InitialOffset"/>, to the base reader.</param>
        public LabelReader(string filename)
            : base(filename, InitialOffset)
        {
        }

        /// <summary>
        /// Reads a single label.
        /// </summary>
        /// <returns>The first byte returned by <c>Read(1)</c>, widened to <see cref="int"/>.</returns>
        public int ReadLabel() => Read(1)[0];

        /// <summary>
        /// Reads the specified amount of labels.
        /// </summary>
        /// <param name="count">The number of labels to read.</param>
        /// <returns>An array of <paramref name="count"/> labels, one per byte read.</returns>
        public int[] ReadLabels(int count)
        {
            byte[] raw = Read(count);
            var labels = new int[count];

            var index = 0;
            while (index < count)
            {
                labels[index] = raw[index];
                index++;
            }

            return labels;
        }
    }
}
52 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/IO/NnSerializableBase.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.IO;
3 | using DigitRecognizer.Core.Utilities;
4 |
namespace DigitRecognizer.Core.IO
{
    /// <summary>
    /// An abstract base class for serialization of neural network files.
    /// It opens a <see cref="System.IO.FileStream"/> over a ".nn" file and releases it on dispose.
    /// </summary>
    public abstract class NnSerializableBase : IDisposable
    {
        /// <summary>
        /// The file extension of a neural network file.
        /// </summary>
        private const string NnFileExtension = ".nn";

        /// <summary>
        /// The internal <see cref="System.IO.FileStream"/> object.
        /// </summary>
        protected readonly FileStream FileStream;

        /// <summary>
        /// Indicates if the disposing operation has been completed or not.
        /// </summary>
        private bool _disposed;

        /// <summary>
        /// Initializes a new instance of the <see cref="NnSerializableBase"/> class.
        /// </summary>
        /// <param name="filename">The name of the file, used for opening a file stream. It is validated by the
        /// <c>Contracts</c> helpers: not null or empty, has an extension, and the extension is ".nn".</param>
        /// <param name="fileMode">How the operating system should open the file.</param>
        /// <param name="fileAccess">
        /// The access requested on the file. Defaults to <see cref="FileAccess.ReadWrite"/>, matching existing callers.
        /// Readers can pass <see cref="FileAccess.Read"/> so that read-only files and locations can be opened.
        /// </param>
        protected NnSerializableBase(string filename, FileMode fileMode, FileAccess fileAccess = FileAccess.ReadWrite)
        {
            Contracts.StringNotNullOrEmpty(filename, nameof(filename));
            Contracts.FileHasExtension(filename, nameof(filename));
            Contracts.FileExtensionValid(filename, NnFileExtension, nameof(filename));

            FileStream = File.Open(filename, fileMode, fileAccess);
        }

        /// <summary>
        /// Releases all resources used by the <see cref="NnSerializableBase"/>.
        /// </summary>
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Releases the resources used by the <see cref="NnSerializableBase"/>. Repeated calls are no-ops.
        /// </summary>
        /// <param name="disposing"><c>true</c> when called from <see cref="Dispose()"/>; managed resources are only released then.</param>
        protected virtual void Dispose(bool disposing)
        {
            if (_disposed)
            {
                return;
            }

            if (disposing)
            {
                FileStream?.Dispose();
            }

            _disposed = true;
        }
    }
}
70 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/IO/NnSerializationContext.cs:
--------------------------------------------------------------------------------
namespace DigitRecognizer.Core.IO
{
    /// <summary>
    /// Model for neural network serialization: a flat data array plus the
    /// <see cref="NnSerializationContextInfo"/> that describes its layout.
    /// </summary>
    public class NnSerializationContext
    {
        /// <summary>
        /// Initializes a new instance of the <see cref="NnSerializationContext"/> class with the specified parameters.
        /// </summary>
        /// <param name="fileData">The data for the file (stored by reference, not copied).</param>
        /// <param name="serializationContextInfo">The context info to associate with the data.</param>
        public NnSerializationContext(double[] fileData, NnSerializationContextInfo serializationContextInfo)
        {
            FileData = fileData;
            SerializationContextInfo = serializationContextInfo;
        }

        /// <summary>
        /// Gets the <see cref="NnSerializationContextInfo"/>.
        /// </summary>
        public NnSerializationContextInfo SerializationContextInfo { get; }

        /// <summary>
        /// Gets the data.
        /// </summary>
        /// <remarks>
        /// The data is a combined array of all weights and biases.
        /// </remarks>
        public double[] FileData { get; }
    }
}
36 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/IO/NnSerializationContextInfo.cs:
--------------------------------------------------------------------------------
1 | using System.Text;
2 | using DigitRecognizer.Core.Utilities;
3 |
namespace DigitRecognizer.Core.IO
{
    /// <summary>
    /// Contains context information about a neural network file: the weight matrix dimensions,
    /// the bias length and the activation function name.
    /// </summary>
    public struct NnSerializationContextInfo
    {
        private readonly int _weightMatrixRowCount;
        private readonly int _weightMatrixColCount;
        private readonly int _biasLength;
        private readonly string _activationFunctionName;

        /// <summary>
        /// Initializes a new instance of the <see cref="NnSerializationContextInfo"/> struct.
        /// </summary>
        /// <param name="weightMatrixRowCount">The number of rows of the weight matrix; validated by <c>Contracts.ValueGreaterThanZero</c>.</param>
        /// <param name="weightMatrixColCount">The number of columns of the weight matrix; validated by <c>Contracts.ValueGreaterThanZero</c>.</param>
        /// <param name="biasLength">The length of the bias; validated by <c>Contracts.ValueGreaterThanZero</c>.</param>
        /// <param name="activationFunctionName">The name of the activation function for the layer.</param>
        /// <exception cref="System.ArgumentNullException"><paramref name="activationFunctionName"/> is <c>null</c>.</exception>
        public NnSerializationContextInfo(int weightMatrixRowCount, int weightMatrixColCount, int biasLength, string activationFunctionName)
        {
            Contracts.ValueGreaterThanZero(weightMatrixRowCount, nameof(weightMatrixRowCount));
            Contracts.ValueGreaterThanZero(weightMatrixColCount, nameof(weightMatrixColCount));
            Contracts.ValueGreaterThanZero(biasLength, nameof(biasLength));

            _weightMatrixRowCount = weightMatrixRowCount;
            _weightMatrixColCount = weightMatrixColCount;
            _biasLength = biasLength;

            // Fail here rather than later in ActivationFunctionNameSizeInBytes, where
            // Encoding.GetByteCount(null) would throw far away from the bad input.
            _activationFunctionName = activationFunctionName ?? throw new System.ArgumentNullException(nameof(activationFunctionName));
        }

        /// <summary>
        /// Gets the weight matrix row count.
        /// </summary>
        public int WeightMatrixRowCount => _weightMatrixRowCount;

        /// <summary>
        /// Gets the weight matrix column count.
        /// </summary>
        public int WeightMatrixColCount => _weightMatrixColCount;

        /// <summary>
        /// Gets the bias length.
        /// </summary>
        public int BiasLength => _biasLength;

        /// <summary>
        /// Gets the name of the activation function.
        /// </summary>
        public string ActivationFunctionName => _activationFunctionName;

        /// <summary>
        /// Gets the number of bytes in the activation function name when encoded as UTF-16 (<see cref="Encoding.Unicode"/>).
        /// </summary>
        /// <remarks>
        /// Returns 0 for a <c>default(NnSerializationContextInfo)</c> instance, whose name is <c>null</c>.
        /// </remarks>
        public int ActivationFunctionNameSizeInBytes =>
            _activationFunctionName == null ? 0 : Encoding.Unicode.GetByteCount(_activationFunctionName);

        /// <summary>
        /// Gets the length of the data array: rows * columns weights followed by the bias values.
        /// </summary>
        public int DataLength => _weightMatrixRowCount * _weightMatrixColCount + _biasLength;

        /// <summary>
        /// Gets the size of the data array in bytes.
        /// </summary>
        public int DataSizeInBytes => DataLength * sizeof(double);
    }
}
71 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/IO/PixelReader.cs:
--------------------------------------------------------------------------------
namespace DigitRecognizer.Core.IO
{
    /// <summary>
    /// Reads pixels from a file through <see cref="MemoryStreamReader"/>, one byte per pixel.
    /// Pixel values are copied as-is into doubles (no normalization).
    /// </summary>
    public class PixelReader : MemoryStreamReader, IPixelReader
    {
        /// <summary>
        /// Number of leading bytes skipped before the first pixel.
        /// NOTE(review): presumably the IDX image-file header (magic number, count, rows, columns) — confirm against the dataset format.
        /// </summary>
        public const int InitialOffset = 16;

        /// <summary>
        /// Initializes a new instance of the <see cref="PixelReader"/> class.
        /// </summary>
        /// <param name="filename">The path of the image file; passed, together with <see cref="InitialOffset"/>, to the base reader.</param>
        public PixelReader(string filename)
            : base(filename, InitialOffset)
        {
        }

        /// <summary>
        /// Reads the specified amount of pixels.
        /// </summary>
        /// <param name="count">The number of pixels to read.</param>
        /// <returns>An array of <paramref name="count"/> doubles, one per byte read.</returns>
        public double[] ReadPixels(int count) => WidenToDoubles(Read(count), count);

        /// <summary>
        /// Reads the specified amount of blocks of pixels, each of the specified block size.
        /// </summary>
        /// <param name="count">The number of blocks to read.</param>
        /// <param name="blockSize">The number of pixels in each block.</param>
        /// <returns>A jagged array with <paramref name="count"/> rows of <paramref name="blockSize"/> doubles.</returns>
        public double[][] ReadPixels(int count, int blockSize)
        {
            byte[][] blocks = Read(count, blockSize);
            var pixels = new double[count][];

            for (var block = 0; block < pixels.Length; block++)
            {
                pixels[block] = WidenToDoubles(blocks[block], blockSize);
            }

            return pixels;
        }

        /// <summary>
        /// Copies the first <paramref name="length"/> bytes of <paramref name="source"/> into a new double array.
        /// </summary>
        /// <param name="source">The bytes to widen.</param>
        /// <param name="length">How many bytes to copy; also the length of the result.</param>
        /// <returns>The widened values.</returns>
        private static double[] WidenToDoubles(byte[] source, int length)
        {
            var target = new double[length];

            for (var i = 0; i < target.Length; i++)
            {
                target[i] = source[i];
            }

            return target;
        }
    }
}
74 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Properties/AssemblyInfo.cs:
--------------------------------------------------------------------------------
1 | using System.Reflection;
2 | using System.Runtime.InteropServices;
3 |
4 | // General Information about an assembly is controlled through the following
5 | // set of attributes. Change these attribute values to modify the information
6 | // associated with an assembly.
// Descriptive metadata for the DigitRecognizer.Core assembly.
[assembly: AssemblyTitle("DigitRecognizer.Core"), AssemblyProduct("DigitRecognizer.Core")]
[assembly: AssemblyDescription(""), AssemblyConfiguration("")]
[assembly: AssemblyCompany(""), AssemblyCopyright("Copyright © 2018"), AssemblyTrademark("")]
[assembly: AssemblyCulture("")]

// Types in this assembly are hidden from COM. The GUID identifies the type library
// should the project ever be exposed to COM.
[assembly: ComVisible(false)]
[assembly: Guid("ba726236-c7ec-41ca-abb6-e4fe7e784939")]

// Fixed version numbers (Major.Minor.Build.Revision). Use "1.0.*" in AssemblyVersion
// to let the compiler generate the build and revision numbers instead.
[assembly: AssemblyVersion("1.0.0.0"), AssemblyFileVersion("1.0.0.0")]
36 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Utilities/ConsoleUtility.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Diagnostics;
3 |
namespace DigitRecognizer.Core.Utilities
{
    /// <summary>
    /// Helpers for writing messages to the console window.
    /// </summary>
    public static class ConsoleUtility
    {
        /// <summary>
        /// Writes <paramref name="message"/> followed by a line terminator to standard output.
        /// In DEBUG builds the message is also sent to the debug listeners.
        /// </summary>
        /// <param name="message">The text to write.</param>
        public static void WriteLine(string message)
        {
            // Mirror to the debugger output only when compiled with DEBUG.
#if DEBUG
            Debug.WriteLine(message);
#endif

            // Always echo to the console.
            Console.Out.WriteLine(message);
        }
    }
}
24 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Core/Utilities/DirectoryHelper.cs:
--------------------------------------------------------------------------------
namespace DigitRecognizer.Core.Utilities
{
    /// <summary>
    /// Centralizes the relative locations of the MNIST dataset files and the models folder.
    /// </summary>
    /// <remarks>
    /// All paths are relative and use '/' separators. What they resolve to depends on the base
    /// directory they are combined with (or the process working directory when used directly).
    /// </remarks>
    public static class DirectoryHelper
    {
        /// <summary>
        /// Relative path of the folder containing the dataset files.
        /// </summary>
        private const string DatasetPath = @"../../../Dataset";

        /// <summary>
        /// Relative path of the folder where trained models are stored.
        /// </summary>
        private const string ModelsPath = @"../../../../Models";

        /// <summary>
        /// Gets the path to the original training labels file.
        /// </summary>
        public static string TrainLabelsPath => InDataset("train-labels.idx1-ubyte");

        /// <summary>
        /// Gets the path to the original training images file.
        /// </summary>
        public static string TrainImagesPath => InDataset("train-images.idx3-ubyte");

        /// <summary>
        /// Gets the path to the expanded training labels file.
        /// </summary>
        public static string ExpandedTrainLabelsPath => InDataset("train-exp-labels.idx1-ubyte");

        /// <summary>
        /// Gets the path to the expanded training images file.
        /// </summary>
        public static string ExpandedTrainImagesPath => InDataset("train-exp-images.idx3-ubyte");

        /// <summary>
        /// Gets the path to the test labels file.
        /// </summary>
        public static string TestLabelsPath => InDataset("t10k-labels.idx1-ubyte");

        /// <summary>
        /// Gets the path to the test images file.
        /// </summary>
        public static string TestImagesPath => InDataset("t10k-images.idx3-ubyte");

        /// <summary>
        /// Gets the path to the models folder.
        /// </summary>
        public static string ModelsFolder => ModelsPath;

        /// <summary>
        /// Joins <paramref name="fileName"/> onto the dataset folder with a '/' separator.
        /// </summary>
        private static string InDataset(string fileName) => DatasetPath + "/" + fileName;
    }
}
54 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.DatasetExpansion/Api/DatasetExpander.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.Linq;
4 | using DigitRecognizer.DatasetExpansion.Infrastructure;
5 |
namespace DigitRecognizer.DatasetExpansion.Api
{
    /// <summary>
    /// Expands the existing MNIST training dataset with affine-transformed copies of every image.
    /// </summary>
    public class DatasetExpander
    {
        /// <summary>
        /// Offset argument pairs passed to <c>AffineTransformation.Translate</c>.
        /// The translation step produces one full copy of the dataset per pair, in this order.
        /// </summary>
        private static readonly int[][] TranslationOffsets =
        {
            new[] { -1, 0 },
            new[] { 1, 0 },
            new[] { 0, -1 },
            new[] { 0, 1 }
        };

        /// <summary>
        /// Magnitude of the angle passed to <c>AffineTransformation.Rotate</c>.
        /// The sign is chosen at random per image.
        /// </summary>
        private const double RotationAngle = 0.5;

        /// <summary>
        /// Loads the original dataset and appends translated and rotated copies of it.
        /// Returns everything in shuffled order.
        /// </summary>
        /// <returns>The original, translated and rotated images, shuffled together.</returns>
        public List<MnistImage> ExpandDataset()
        {
            Console.WriteLine("Reading images from disk");

            List<MnistImage> dataset = GetDataset();

            Console.WriteLine("Applying translation transformation");

            List<MnistImage> translatedImages = ApplyTranslationTransformation(dataset);

            Console.WriteLine("Applying rotation transformation");

            List<MnistImage> rotatedImages = ApplyRotationTransformation(dataset);

            // Size the result once instead of letting it regrow while ~6x the dataset is appended.
            var result = new List<MnistImage>(dataset.Count + translatedImages.Count + rotatedImages.Count);

            result.AddRange(dataset);
            result.AddRange(translatedImages);
            result.AddRange(rotatedImages);

            Console.WriteLine("Shuffling images");

            return FisherYatesShuffle.Shuffle(result);
        }

        /// <summary>
        /// Loads the dataset from disk.
        /// </summary>
        /// <returns>A list of <see cref="MnistImage"/> objects.</returns>
        private List<MnistImage> GetDataset()
        {
            var provider = new DatasetProvider();

            return provider.LoadDataset();
        }

        /// <summary>
        /// Applies every translation in <see cref="TranslationOffsets"/> to the dataset images.
        /// </summary>
        /// <param name="dataset">The dataset.</param>
        /// <returns>
        /// The translated images, grouped by offset: all images for the first offset, then the second, and so on.
        /// </returns>
        private List<MnistImage> ApplyTranslationTransformation(List<MnistImage> dataset)
        {
            var result = new List<MnistImage>(dataset.Count * TranslationOffsets.Length);

            foreach (int[] offset in TranslationOffsets)
            {
                // Appending directly avoids materializing a temporary list per offset.
                foreach (MnistImage image in dataset)
                {
                    result.Add(new MnistImage(image.Label, AffineTransformation.Translate(image.Pixels, offset[0], offset[1])));
                }
            }

            return result;
        }

        /// <summary>
        /// Rotates each dataset image by +/-<see cref="RotationAngle"/>.
        /// The sign is picked independently per image.
        /// </summary>
        /// <param name="dataset">The dataset.</param>
        /// <returns>The rotated images, in dataset order.</returns>
        private List<MnistImage> ApplyRotationTransformation(List<MnistImage> dataset)
        {
            var random = new Random();

            var result = new List<MnistImage>(dataset.Count);

            foreach (MnistImage image in dataset)
            {
                double angle = random.NextDouble() > 0.5 ? RotationAngle : -RotationAngle;

                result.Add(new MnistImage(image.Label, AffineTransformation.Rotate(image.Pixels, angle)));
            }

            return result;
        }
    }
}
92 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.DatasetExpansion/Api/DatasetSerializer.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.IO;
4 | using System.Linq;
5 | using DigitRecognizer.DatasetExpansion.Infrastructure;
6 |
namespace DigitRecognizer.DatasetExpansion.Api
{
    /// <summary>
    /// Serializes an MNIST dataset to a pair of IDX files: one for labels, one for images.
    /// </summary>
    /// <remarks>
    /// The IDX format stores every header integer big-endian (MSB first), as the original MNIST files do.
    /// <see cref="BinaryWriter"/> writes integers little-endian, so all header values go through
    /// <see cref="WriteBigEndian"/>. This makes the output byte-compatible with the MNIST format
    /// regardless of the host's endianness.
    /// </remarks>
    public class DatasetSerializer
    {
        private const int LabelMagicNumber = 0x00000801;

        private const int ImageMagicNumber = 0x00000803;

        /// <summary>
        /// Number of rows and of columns of each image, written to the image file header.
        /// </summary>
        private const int Length = 28;

        /// <summary>
        /// How many times deleting an existing output file is attempted.
        /// </summary>
        private const int DeleteRetryCount = 5;

        /// <summary>
        /// Base delay between delete attempts. It grows linearly with the attempt number.
        /// </summary>
        private const int DeleteRetryDelayMilliseconds = 100;

        /// <summary>
        /// Serializes the specified dataset to the specified files.
        /// </summary>
        /// <param name="labelsFilename">The name of the file with labels data.</param>
        /// <param name="imagesFilename">The name of the file with images data.</param>
        /// <param name="dataset">The dataset.</param>
        /// <exception cref="ArgumentNullException"><paramref name="dataset"/> is null.</exception>
        public void SerializeDataset(string labelsFilename, string imagesFilename, List<MnistImage> dataset)
        {
            // Validate before touching the disk so existing files are not deleted for nothing.
            if (dataset == null)
            {
                throw new ArgumentNullException(nameof(dataset));
            }

            Console.WriteLine("Ensuring file integrity");

            EnsureFileIntegrity(labelsFilename);
            EnsureFileIntegrity(imagesFilename);

            using (var lblWriter = new BinaryWriter(File.Open(labelsFilename, FileMode.Create, FileAccess.Write)))
            using (var imgWriter = new BinaryWriter(File.Open(imagesFilename, FileMode.Create, FileAccess.Write)))
            {
                // Label file header: magic number, item count.
                WriteBigEndian(lblWriter, LabelMagicNumber);
                WriteBigEndian(lblWriter, dataset.Count);

                // Image file header: magic number, image count, row count, column count.
                WriteBigEndian(imgWriter, ImageMagicNumber);
                WriteBigEndian(imgWriter, dataset.Count);
                WriteBigEndian(imgWriter, Length);
                WriteBigEndian(imgWriter, Length);

                // Body: one label byte per image, and the raw pixel bytes of each image.
                foreach (MnistImage img in dataset)
                {
                    lblWriter.Write(img.Label);

                    imgWriter.Write(img.Pixels);
                }
            }
        }

        /// <summary>
        /// Writes <paramref name="value"/> as four bytes, most significant byte first.
        /// </summary>
        /// <param name="writer">The destination writer.</param>
        /// <param name="value">The value to write.</param>
        private static void WriteBigEndian(BinaryWriter writer, int value)
        {
            writer.Write((byte)(value >> 24));
            writer.Write((byte)(value >> 16));
            writer.Write((byte)(value >> 8));
            writer.Write((byte)value);
        }

        /// <summary>
        /// Deletes the file with the specified name if it exists, retrying a few times on failure.
        /// </summary>
        /// <param name="filename">The filename.</param>
        /// <exception cref="Exception">
        /// The file could not be deleted after all attempts; the last failure is attached as the inner exception.
        /// </exception>
        private void EnsureFileIntegrity(string filename)
        {
            Exception lastError = null;

            for (var attempt = 1; attempt <= DeleteRetryCount; attempt++)
            {
                try
                {
                    if (File.Exists(filename))
                    {
                        File.Delete(filename);
                    }

                    return;
                }
                catch (Exception ex)
                {
                    // Deletion may fail transiently while another process holds the file.
                    // Remember the cause and back off briefly before the next attempt.
                    lastError = ex;

                    if (attempt < DeleteRetryCount)
                    {
                        System.Threading.Thread.Sleep(DeleteRetryDelayMilliseconds * attempt);
                    }
                }
            }

            throw new Exception($"Failed to delete file {filename}", lastError);
        }
    }
}
92 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.DatasetExpansion/App.config:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.DatasetExpansion/DigitRecognizer.DatasetExpansion.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Debug
6 | AnyCPU
7 | {757CF815-DD01-4F66-8E9F-DB3FF9B4A463}
8 | Exe
9 | DigitRecognizer.DatasetExpansion
10 | DigitRecognizer.DatasetExpansion
11 | v4.8
12 | 512
13 | true
14 |
15 |
16 |
17 | AnyCPU
18 | true
19 | full
20 | false
21 | bin\Debug\
22 | DEBUG;TRACE
23 | prompt
24 | 4
25 |
26 |
27 | AnyCPU
28 | pdbonly
29 | true
30 | bin\Release\
31 | TRACE
32 | prompt
33 | 4
34 |
35 |
36 | favico.ico
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 | {BA726236-C7EC-41CA-ABB6-E4FE7E784939}
64 | DigitRecognizer.Core
65 |
66 |
67 |
68 |
69 |
70 |
71 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.DatasetExpansion/Infrastructure/DatasetProvider.cs:
--------------------------------------------------------------------------------
1 | using System.Collections.Generic;
2 | using DigitRecognizer.Core.IO;
3 | using DigitRecognizer.Core.Utilities;
4 |
namespace DigitRecognizer.DatasetExpansion.Infrastructure
{
    /// <summary>
    /// Loads the original MNIST training set from disk as raw <see cref="MnistImage"/> instances.
    /// </summary>
    public class DatasetProvider
    {
        /// <summary>
        /// Number of pixel bytes read for every image.
        /// </summary>
        private const int ImagePixelCount = 784;

        /// <summary>
        /// Number of images (and labels) read from the training files.
        /// </summary>
        private const int DatasetSize = 60000;

        private static readonly string TrainImagesFilePath = DirectoryHelper.TrainImagesPath;

        private static readonly string TrainLabelsFilePath = DirectoryHelper.TrainLabelsPath;

        /// <summary>
        /// Reads <see cref="DatasetSize"/> label/pixel pairs from the training files.
        /// </summary>
        /// <returns>The images in file order.</returns>
        public List<MnistImage> LoadDataset()
        {
            var images = new List<MnistImage>();

            using (var labels = new MemoryStreamReader(TrainLabelsFilePath, LabelReader.InitialOffset))
            using (var pixels = new MemoryStreamReader(TrainImagesFilePath, PixelReader.InitialOffset))
            {
                for (var remaining = DatasetSize; remaining > 0; remaining--)
                {
                    // One label byte, then that image's pixel bytes, read in lock-step.
                    byte[] labelBytes = labels.Read(1);

                    images.Add(new MnistImage(labelBytes[0], pixels.Read(ImagePixelCount)));
                }
            }

            return images;
        }
    }
}
46 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.DatasetExpansion/Infrastructure/FisherYatesShuffle.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.Linq;
4 |
namespace DigitRecognizer.DatasetExpansion.Infrastructure
{
    /// <summary>
    /// Randomly permutes sequences using the Fisher-Yates algorithm.
    /// </summary>
    public class FisherYatesShuffle
    {
        /// <summary>
        /// Generator shared by every shuffle call.
        /// </summary>
        private static readonly Random Random = new Random();

        /// <summary>
        /// Shuffles <paramref name="array"/> in place.
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="array">The array to shuffle; it is modified.</param>
        public static void Shuffle<T>(T[] array)
        {
            int length = array.Length;

            for (var current = 0; current < length; current++)
            {
                // Choose from the not-yet-placed suffix [current, length); Next's bound is exclusive.
                int pick = current + Random.Next(length - current);

                T held = array[pick];
                array[pick] = array[current];
                array[current] = held;
            }
        }

        /// <summary>
        /// Returns a shuffled copy of <paramref name="list"/>. The input list is left untouched.
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="list">The list to shuffle.</param>
        /// <returns>A new list holding the same elements in random order.</returns>
        public static List<T> Shuffle<T>(List<T> list)
        {
            T[] buffer = list.ToArray();

            Shuffle(buffer);

            return new List<T>(buffer);
        }
    }
}
57 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.DatasetExpansion/Infrastructure/MnistImage.cs:
--------------------------------------------------------------------------------
using System;
using System.Text;
2 |
namespace DigitRecognizer.DatasetExpansion.Infrastructure
{
    /// <summary>
    /// Represents a raw MNIST image: a label byte plus the image's pixel bytes.
    /// </summary>
    public class MnistImage
    {
        /// <summary>
        /// Width and height, in pixels, that <see cref="ToString"/> assumes the image has.
        /// </summary>
        private const int Side = 28;

        /// <summary>
        /// Pixels strictly brighter than this value are drawn as ink by <see cref="ToString"/>.
        /// </summary>
        private const byte InkThreshold = 10;

        private readonly byte _label;
        private readonly byte[] _pixels;

        /// <summary>
        /// Initializes a new instance of the <see cref="MnistImage"/> class.
        /// </summary>
        /// <param name="label">The label.</param>
        /// <param name="pixels">The pixels. The array is stored by reference, not copied.</param>
        public MnistImage(byte label, byte[] pixels)
        {
            _label = label;
            _pixels = pixels;
        }

        /// <summary>
        /// Gets the label.
        /// </summary>
        public byte Label => _label;

        /// <summary>
        /// Gets the pixels.
        /// </summary>
        public byte[] Pixels => _pixels;

        /// <summary>
        /// Renders the image as ASCII art.
        /// </summary>
        /// <returns>
        /// 28 lines of 28 characters each, every line followed by <see cref="Environment.NewLine"/>.
        /// A '0' marks a pixel brighter than <see cref="InkThreshold"/>; a space marks any other pixel.
        /// </returns>
        public override string ToString()
        {
            // A single StringBuilder avoids allocating a new string for each of the 784 cells.
            var builder = new StringBuilder(Side * (Side + Environment.NewLine.Length));

            for (var row = 0; row < Side; row++)
            {
                for (var column = 0; column < Side; column++)
                {
                    builder.Append(_pixels[row * Side + column] > InkThreshold ? '0' : ' ');
                }

                builder.AppendLine();
            }

            return builder.ToString();
        }
    }
}
52 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.DatasetExpansion/Program.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using DigitRecognizer.Core.Utilities;
4 | using DigitRecognizer.DatasetExpansion.Api;
5 | using DigitRecognizer.DatasetExpansion.Infrastructure;
6 |
namespace DigitRecognizer.DatasetExpansion
{
    /// <summary>
    /// Console entry point that builds the expanded training dataset and writes it to disk.
    /// </summary>
    internal class Program
    {
        /// <summary>
        /// Expands the dataset, serializes it to the expanded-dataset paths and waits for a key press.
        /// </summary>
        private static void Main()
        {
            Console.Title = "Digit Recognizer - Dataset Expansion";

            // Step 1: build the expanded, shuffled dataset in memory.
            Console.WriteLine("Starting dataset expansion");
            List<MnistImage> expandedDataset = new DatasetExpander().ExpandDataset();
            Console.WriteLine("Completed dataset expansion");

            // Step 2: write the labels and images files.
            Console.WriteLine("Starting dataset serialization");
            new DatasetSerializer().SerializeDataset(
                DirectoryHelper.ExpandedTrainLabelsPath,
                DirectoryHelper.ExpandedTrainImagesPath,
                expandedDataset);
            Console.WriteLine("Completed dataset serialization");

            // Keep the console window open until a key is pressed.
            Console.ReadKey();
        }
    }
}
35 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.DatasetExpansion/Properties/AssemblyInfo.cs:
--------------------------------------------------------------------------------
1 | using System.Reflection;
2 | using System.Runtime.CompilerServices;
3 | using System.Runtime.InteropServices;
4 |
5 | // General Information about an assembly is controlled through the following
6 | // set of attributes. Change these attribute values to modify the information
7 | // associated with an assembly.
// Descriptive metadata for the DigitRecognizer.DatasetExpansion assembly.
[assembly: AssemblyTitle("DigitRecognizer.DatasetExpansion"), AssemblyProduct("DigitRecognizer.DatasetExpansion")]
[assembly: AssemblyDescription(""), AssemblyConfiguration("")]
[assembly: AssemblyCompany(""), AssemblyCopyright("Copyright © 2018"), AssemblyTrademark("")]
[assembly: AssemblyCulture("")]

// Types in this assembly are hidden from COM. The GUID identifies the type library
// should the project ever be exposed to COM.
[assembly: ComVisible(false)]
[assembly: Guid("757cf815-dd01-4f66-8e9f-db3ff9b4a463")]

// Fixed version numbers (Major.Minor.Build.Revision). Use "1.0.*" in AssemblyVersion
// to let the compiler generate the build and revision numbers instead.
[assembly: AssemblyVersion("1.0.0.0"), AssemblyFileVersion("1.0.0.0")]
37 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.DatasetExpansion/favico.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/DigitRecognizer.DatasetExpansion/favico.ico
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Engine/App.config:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Engine/DigitRecognizer.Engine.csproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Debug
6 | AnyCPU
7 | {00CE9C72-5F92-4555-BB96-D0E77B9A2390}
8 | Exe
9 | DigitRecognizer.Engine
10 | DigitRecognizer.Engine
11 | v4.8
12 | 512
13 | true
14 |
15 |
16 |
17 | AnyCPU
18 | true
19 | full
20 | false
21 | bin\Debug\
22 | DEBUG;TRACE
23 | prompt
24 | 4
25 |
26 |
27 | AnyCPU
28 | pdbonly
29 | true
30 | bin\Release\
31 | TRACE
32 | prompt
33 | 4
34 |
35 |
36 | favico.ico
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 | {BA726236-C7EC-41CA-ABB6-E4FE7E784939}
58 | DigitRecognizer.Core
59 |
60 |
61 | {BF82C3DF-0EFE-45DB-8E3D-1FFDD548AE73}
62 | DigitRecognizer.MachineLearning
63 |
64 |
65 |
66 |
67 |
68 |
69 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Engine/Program.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.IO;
4 | using System.Linq;
5 | using DigitRecognizer.Core.Data;
6 | using DigitRecognizer.Core.Extensions;
7 | using DigitRecognizer.Core.Utilities;
8 | using DigitRecognizer.MachineLearning.Pipeline;
9 | using DigitRecognizer.MachineLearning.Infrastructure.Functions;
10 | using DigitRecognizer.MachineLearning.Infrastructure.Initialization;
11 | using DigitRecognizer.MachineLearning.Infrastructure.Models;
12 | using DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork;
13 | using DigitRecognizer.MachineLearning.Optimization.Optimizers;
14 | using DigitRecognizer.MachineLearning.Providers;
15 |
namespace DigitRecognizer.Engine
{
    /// <summary>
    /// Console entry point.
    /// Trains the digit-recognition network on the expanded dataset, reports its accuracy
    /// on the test set and saves the trained model.
    /// </summary>
    internal class Program
    {
        private static void Main()
        {
            Console.Title = "Digit Recognizer - Engine";

            var learningRate = 0.00035;
            var epochs = 10;
            var regularizationFactor = 15.0;

            LearningPipeline pipeline = new LearningPipeline()
                .UseGradientClipping()
                .UseL2Regularization(regularizationFactor)
                .UseDropout(0.5)
                .SetWeightsInitializer(InitializerType.RandomInitialization)
                .SetEpochCount(epochs);

            var layers = new List<NnLayer>
            {
                new NnLayer(784, 200, new LeakyRelu()),
                new NnLayer(200, 100, new LeakyRelu()),
                new NnLayer(100, 10, new Softmax())
            };

            var nn = new NeuralNetwork(layers, learningRate);

            var optimizer = new MomentumOptimizer(nn, new CrossEntropy(), 0.93);

            var provider = new BatchDataProvider(DirectoryHelper.ExpandedTrainLabelsPath, DirectoryHelper.ExpandedTrainImagesPath, 100);

            pipeline.Add(optimizer);

            pipeline.Add(nn);

            pipeline.Add(provider);

            PredictionModel model = pipeline.Run();

            var testProvider = new BatchDataProvider(DirectoryHelper.TestLabelsPath, DirectoryHelper.TestImagesPath, 10000);

            MnistImageBatch testData = testProvider.GetData();

            double acc = EvaluateAccuracy(model, testData);

            Console.WriteLine($"Accuracy on the test data is: {acc:P2}");

            SaveModel(model, acc);

            Console.ReadKey();
        }

        /// <summary>
        /// Computes the fraction of samples whose label equals the index of the largest value in the model's prediction.
        /// </summary>
        /// <param name="model">The trained model.</param>
        /// <param name="data">The labelled samples to evaluate.</param>
        /// <returns>The accuracy in [0, 1], or 0 when <paramref name="data"/> has no labels.</returns>
        private static double EvaluateAccuracy(PredictionModel model, MnistImageBatch data)
        {
            var predictions = data.Pixels.Select(pixels => model.Predict(pixels)).ToList();

            int total = data.Labels.Length;

            if (total == 0)
            {
                return 0.0;
            }

            var correct = 0;

            for (var i = 0; i < total; i++)
            {
                if (data.Labels[i] == predictions[i].ArgMax())
                {
                    correct++;
                }
            }

            // Divide by the real sample count (previously hard-coded as 10000) so the figure is right for any batch.
            return (double)correct / total;
        }

        /// <summary>
        /// Saves <paramref name="model"/> to the models folder under a unique file name that embeds its accuracy.
        /// </summary>
        /// <param name="model">The model to save.</param>
        /// <param name="accuracy">The test accuracy, embedded in the file name.</param>
        private static void SaveModel(PredictionModel model, double accuracy)
        {
            // ModelsFolder is concatenated directly (no separator) onto the output directory and then normalized.
            // This is kept as-is so models keep landing in the same location as before.
            string basePath = Path.GetFullPath(Path.GetDirectoryName(AppDomain.CurrentDomain.BaseDirectory) +
                                               DirectoryHelper.ModelsFolder);

            // Ensure the folder exists so a completed training run is not lost to a missing directory.
            Directory.CreateDirectory(basePath);

            string modelName = $"{Guid.NewGuid()}-{accuracy:N4}.nn";

            string filename = $"{basePath}/{modelName}";

            model.Save(filename);
        }
    }
}
88 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Engine/Properties/AssemblyInfo.cs:
--------------------------------------------------------------------------------
1 | using System.Reflection;
2 | using System.Runtime.InteropServices;
3 |
4 | // General Information about an assembly is controlled through the following
5 | // set of attributes. Change these attribute values to modify the information
6 | // associated with an assembly.
// Descriptive metadata for the DigitRecognizer.Engine assembly.
[assembly: AssemblyTitle("DigitRecognizer.Engine"), AssemblyProduct("DigitRecognizer.Engine")]
[assembly: AssemblyDescription(""), AssemblyConfiguration("")]
[assembly: AssemblyCompany(""), AssemblyCopyright("Copyright © 2018"), AssemblyTrademark("")]
[assembly: AssemblyCulture("")]

// Types in this assembly are hidden from COM. The GUID identifies the type library
// should the project ever be exposed to COM.
[assembly: ComVisible(false)]
[assembly: Guid("00ce9c72-5f92-4555-bb96-d0e77b9a2390")]

// Fixed version numbers (Major.Minor.Build.Revision). Use "1.0.*" in AssemblyVersion
// to let the compiler generate the build and revision numbers instead.
[assembly: AssemblyVersion("1.0.0.0"), AssemblyFileVersion("1.0.0.0")]
36 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Engine/favico.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/DigitRecognizer.Engine/favico.ico
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Data/CalculationCache.cs:
--------------------------------------------------------------------------------
1 | using System.Collections.Generic;
2 | using DigitRecognizer.Core.Utilities;
3 |
namespace DigitRecognizer.MachineLearning.Infrastructure.Data
{
    /// <summary>
    /// An index-addressable store of <c>double[][]</c> values kept for later reuse.
    /// </summary>
    public class CalculationCache
    {
        private readonly List<double[][]> _cache = new List<double[][]>();

        /// <summary>
        /// Initializes a new, empty instance of the <see cref="CalculationCache"/> class.
        /// </summary>
        public CalculationCache()
        {
        }

        /// <summary>
        /// Gets the cache.
        /// </summary>
        /// <returns>The live backing list (not a copy); changes made through it affect the cache.</returns>
        public List<double[][]> GetCache() => _cache;

        /// <summary>
        /// Stores <paramref name="value"/> at <paramref name="index"/>. An index at the end of the cache appends the value.
        /// </summary>
        /// <param name="value">The value; must not be null.</param>
        /// <param name="index">The index; validated by <c>Contracts.ValueWithinBounds</c> against [0, Count].</param>
        public void SetValue(double[][] value, int index)
        {
            Contracts.ValueNotNull(value, nameof(value));
            Contracts.ValueWithinBounds(index, 0, _cache.Count, nameof(index));

            // An index past the current last slot grows the cache first.
            // NOTE(review): assumes the bounds check admits at most index == Count — confirm in Contracts.
            if (index >= _cache.Count)
            {
                Add(value);
            }

            _cache[index] = value;
        }

        /// <summary>
        /// Appends the specified value to the cache.
        /// </summary>
        /// <param name="value">The value; must not be null.</param>
        private void Add(double[][] value)
        {
            Contracts.ValueNotNull(value, nameof(value));

            _cache.Add(value);
        }
    }
}
60 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Dropout/BinomialDistribution.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using DigitRecognizer.Core.Extensions;
3 |
namespace DigitRecognizer.MachineLearning.Infrastructure.Dropout
{
    /// <summary>
    /// Generates arrays of independent 0/1 draws, each equal to 1 with probability p.
    /// </summary>
    public class BinomialDistribution
    {
        private readonly double _p;
        private readonly Random _rnd;

        /// <summary>
        /// Initializes a new instance of the <see cref="BinomialDistribution"/> class.
        /// </summary>
        /// <param name="p">The probability that a single draw is 1.</param>
        /// <param name="rnd">The random number generator to draw from.</param>
        public BinomialDistribution(double p, Random rnd)
        {
            _rnd = rnd;
            _p = p;
        }

        /// <summary>
        /// Produces <paramref name="length"/> draws.
        /// </summary>
        /// <param name="length">Number of draws.</param>
        /// <returns>
        /// An array whose element k is 1 when the k-th generated double is below p, and 0 otherwise.
        /// NOTE(review): assumes <c>NextDoubles</c> yields uniform values in [0, 1) — confirm in RandomExtensions.
        /// </returns>
        public int[] Generate(int length)
        {
            double[] samples = _rnd.NextDoubles(length);

            var outcomes = new int[length];

            for (var k = 0; k < outcomes.Length; k++)
            {
                // Array elements start at 0, so only successes need writing.
                if (samples[k] < _p)
                {
                    outcomes[k] = 1;
                }
            }

            return outcomes;
        }
    }
}
45 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Dropout/Dropout.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 |
namespace DigitRecognizer.MachineLearning.Infrastructure.Dropout
{
    /// <summary>
    /// Implements the dropout regularization method.
    /// It generates per-layer mask vectors whose kept entries are scaled by 1 / keepProbability.
    /// </summary>
    public class Dropout
    {
        private readonly double _keepProbability;
        private readonly BinomialDistribution _binomialDistribution;

        /// <summary>
        /// Initializes a new instance of the <see cref="Dropout"/> class.
        /// </summary>
        /// <param name="keepProbability">The probability that a unit is kept; must be in [0, 1].</param>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="keepProbability"/> is outside [0, 1] or is NaN.
        /// </exception>
        public Dropout(double keepProbability)
        {
            // NaN compares false against both bounds, so it must be rejected explicitly.
            if (double.IsNaN(keepProbability) || keepProbability < 0 || keepProbability > 1.0)
            {
                throw new ArgumentOutOfRangeException(nameof(keepProbability), keepProbability, "The keep probability must be in the range [0, 1].");
            }

            _keepProbability = keepProbability;

            _binomialDistribution = new BinomialDistribution(keepProbability, new Random());
        }

        /// <summary>
        /// Randomly generates one dropout vector per entry in <paramref name="sizes"/>.
        /// </summary>
        /// <param name="sizes">The length of each vector to generate.</param>
        /// <returns>
        /// The dropout vectors, in the order of <paramref name="sizes"/>.
        /// Each element is either 0 or 1 / keepProbability.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="sizes"/> is null.</exception>
        public List<double[]> GenerateDropoutVectors(int[] sizes)
        {
            if (sizes == null)
            {
                throw new ArgumentNullException(nameof(sizes));
            }

            var result = new List<double[]>(sizes.Length);

            foreach (int size in sizes)
            {
                var dropoutVector = new double[size];

                int[] binomial = _binomialDistribution.Generate(size);

                // Dividing by the keep probability keeps the expected activation during training
                // equal to the activation at test time, when no units are dropped.
                // With a keep probability of 0 every unit is dropped; emit zeros instead of 0 / 0 = NaN.
                for (var i = 0; i < size; i++)
                {
                    dropoutVector[i] = _keepProbability > 0.0 ? binomial[i] / _keepProbability : 0.0;
                }

                result.Add(dropoutVector);
            }

            return result;
        }
    }
}
60 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Factories/AbstractTypeFactory.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.Linq;
4 | using System.Reflection;
5 |
namespace DigitRecognizer.MachineLearning.Infrastructure.Factories
{
    /// <summary>
    /// Base class for factories that map a key to a concrete type and create a new instance of it on each request.
    /// </summary>
    /// <typeparam name="TInstance">The instance type the abstract factory returns.</typeparam>
    /// <typeparam name="TKey">The key used in the type dictionary.</typeparam>
    public abstract class AbstractTypeFactory<TInstance, TKey>
    {
        private readonly Dictionary<TKey, Type> _typeDictionary;

        private readonly string _propName;
        private readonly Type _baseType;

        /// <summary>
        /// Initializes a new instance of the <see cref="AbstractTypeFactory{TInstance, TKey}"/> class.
        /// It scans the executing assembly for implementations of <paramref name="baseType"/>.
        /// </summary>
        /// <param name="baseType">The base type (class or interface) whose concrete implementations are registered.</param>
        /// <param name="propName">The name of the property whose value is used as each type's key.</param>
        protected AbstractTypeFactory(Type baseType, string propName)
        {
            _propName = propName;
            _baseType = baseType;

            _typeDictionary = new Dictionary<TKey, Type>();

            FillTypeDictionary();
        }

        /// <summary>
        /// Registers every concrete class in the executing assembly that is assignable to the base type.
        /// Each one is keyed by the value of its key property.
        /// </summary>
        /// <exception cref="ArgumentNullException">A matching type lacks the key property.</exception>
        /// <exception cref="ArgumentException">Two types report the same key.</exception>
        private void FillTypeDictionary()
        {
            // Only concrete classes can be created with Activator.CreateInstance.
            // An abstract class implementing the base type would otherwise make construction throw.
            IEnumerable<Type> types = Assembly
                .GetExecutingAssembly()
                .GetTypes()
                .Where(t => t.IsClass && !t.IsAbstract && _baseType.IsAssignableFrom(t));

            foreach (Type type in types)
            {
                PropertyInfo prop = type.GetProperty(_propName);

                if (prop == null)
                {
                    throw new ArgumentNullException(nameof(prop), $"Property {_propName} was not found on type {type.Name}");
                }

                // The key is read from a throw-away instance because the property may be an instance property.
                object obj = Activator.CreateInstance(type);

                _typeDictionary.Add((TKey)prop.GetValue(obj), type);
            }
        }

        /// <summary>
        /// Creates a new instance of the type registered for the specified key.
        /// </summary>
        /// <remarks>
        /// A missing key indicates something is wrong in the system and should generally never happen.
        /// </remarks>
        /// <param name="key">The key of the dictionary.</param>
        /// <returns>A new instance of the registered type.</returns>
        /// <exception cref="KeyNotFoundException">No type is registered for <paramref name="key"/>.</exception>
        public TInstance GetInstance(TKey key)
        {
            Type type;

            if (!_typeDictionary.TryGetValue(key, out type))
            {
                throw new KeyNotFoundException($"No type assignable to {_baseType.Name} is registered for key '{key}'.");
            }

            object result = Activator.CreateInstance(type);

            return (TInstance)result;
        }
    }
}
82 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Factories/FunctionFactory.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.MachineLearning.Infrastructure.Functions;
2 |
3 | namespace DigitRecognizer.MachineLearning.Infrastructure.Factories
4 | {
/// <summary>
/// Singleton factory that creates <see cref="IFunction"/> implementations from their <c>Name</c> value.
/// </summary>
public class FunctionFactory : AbstractTypeFactory<string, IFunction>
{
    // Synchronizes creation of the shared instance across threads.
    private static readonly object Lock = new object();

    // The shared instance; remains null until first requested through Instance.
    private static FunctionFactory _instance;

    /// <summary>
    /// Gets the shared <see cref="FunctionFactory"/>, creating it on first access.
    /// </summary>
    public static FunctionFactory Instance
    {
        get
        {
            lock (Lock)
            {
                if (_instance == null)
                {
                    _instance = new FunctionFactory();
                }

                return _instance;
            }
        }
    }

    /// <summary>
    /// Creates the factory, keying every <see cref="IFunction"/> implementation by its <c>Name</c> property.
    /// </summary>
    private FunctionFactory()
        : base(typeof(IFunction), "Name")
    {
    }
}
41 | }
42 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Factories/InitializerFactory.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.MachineLearning.Infrastructure.Initialization;
2 |
3 | namespace DigitRecognizer.MachineLearning.Infrastructure.Factories
4 | {
/// <summary>
/// Singleton factory that creates <see cref="IInitializer"/> implementations from their <see cref="InitializerType"/>.
/// </summary>
public class InitializerFactory : AbstractTypeFactory<InitializerType, IInitializer>
{
    // Synchronizes creation of the shared instance across threads.
    private static readonly object Lock = new object();

    // The shared instance; remains null until first requested through Instance.
    private static InitializerFactory _instance;

    /// <summary>
    /// Gets the shared <see cref="InitializerFactory"/>, creating it on first access.
    /// </summary>
    public static InitializerFactory Instance
    {
        get
        {
            lock (Lock)
            {
                if (_instance == null)
                {
                    _instance = new InitializerFactory();
                }

                return _instance;
            }
        }
    }

    /// <summary>
    /// Creates the factory, keying every <see cref="IInitializer"/> implementation by its <c>InitializerType</c> property.
    /// </summary>
    private InitializerFactory()
        : base(typeof(IInitializer), "InitializerType")
    {
    }
}
41 | }
42 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/CrossEntropy.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Extensions;
2 | using DigitRecognizer.Core.Utilities;
3 |
4 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions
5 | {
/// <summary>
/// Cross entropy cost function.
/// </summary>
public class CrossEntropy : ICostFunction
{
    /// <summary>
    /// Gets the display name of the function.
    /// </summary>
    public string Name => "Cross Entropy";

    /// <summary>
    /// Computes the cross entropy between the estimated and the actual values.
    /// </summary>
    /// <param name="estimatedValues">The values predicted by the network.</param>
    /// <param name="actualValues">The expected values.</param>
    /// <returns>The cost computed by <c>MathUtilities.CrossEntropy</c>.</returns>
    public double Cost(double[] estimatedValues, double[] actualValues)
        => MathUtilities.CrossEntropy(estimatedValues, actualValues);

    /// <summary>
    /// Returns <paramref name="input"/> minus the one-hot target. The target position is the
    /// arg-max of <paramref name="oneHot"/>; 1 is subtracted there and 0 everywhere else.
    /// </summary>
    /// <remarks>
    /// NOTE(review): this equals the gradient with respect to the softmax logits when
    /// <paramref name="input"/> holds softmax probabilities — confirm that is how callers use it.
    /// </remarks>
    /// <param name="input">The estimated values.</param>
    /// <param name="oneHot">The one-hot encoded expected values.</param>
    /// <returns>One derivative per input element.</returns>
    public double[] Derivative(double[] input, double[] oneHot)
    {
        var gradient = new double[input.Length];

        int target = oneHot.ArgMax();

        for (var index = 0; index < input.Length; index++)
        {
            double subtrahend = index == target ? 1.0 : 0.0;
            gradient[index] = input[index] - subtrahend;
        }

        return gradient;
    }
}
48 | }
49 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/ExponentialRelu.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Utilities;
2 |
3 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions
4 | {
/// <summary>
/// Exponential linear unit (ELU) activation function with a fixed alpha.
/// </summary>
public class ExponentialRelu : IActivationFunction
{
    // Alpha passed to MathUtilities.ExponentialRelu for the non-positive branch.
    private const double Alpha = 0.01;

    /// <summary>
    /// Gets the display name of the function.
    /// </summary>
    public string Name => "Exponential Relu";

    /// <summary>
    /// Applies the activation element-wise.
    /// </summary>
    /// <param name="arr">The pre-activation values.</param>
    /// <returns>The activated values.</returns>
    public double[] Activate(double[] arr) => MathUtilities.ExponentialRelu(arr, Alpha);

    /// <summary>
    /// Computes the element-wise derivative: 1 for positive inputs, otherwise
    /// <c>MathUtilities.ExponentialRelu(x, Alpha) + Alpha</c>.
    /// </summary>
    /// <param name="input">The pre-activation values.</param>
    /// <param name="oneHot">Unused by this function.</param>
    /// <returns>One derivative per input element.</returns>
    public double[] Derivative(double[] input, double[] oneHot)
    {
        var slopes = new double[input.Length];

        for (var k = 0; k < input.Length; k++)
        {
            double x = input[k];

            if (x > 0)
            {
                slopes[k] = 1.0;
            }
            else
            {
                // Presumably f(x) + alpha == alpha * e^x, the ELU slope for x <= 0 — depends on MathUtilities.ExponentialRelu.
                slopes[k] = MathUtilities.ExponentialRelu(x, Alpha) + Alpha;
            }
        }

        return slopes;
    }
}
44 | }
45 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/IActivationFunction.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions
2 | {
/// <summary>
/// Interface that all activation functions must implement.
/// </summary>
public interface IActivationFunction : IDifferentiableFunction, IFunction
{
    /// <summary>
    /// Applies the activation function to every element of the specified array.
    /// </summary>
    /// <param name="arr">The values to activate.</param>
    /// <returns>An array holding the activated values.</returns>
    double[] Activate(double[] arr);
}
15 | }
16 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/ICostFunction.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions
2 | {
/// <summary>
/// Contract for cost (loss) functions: differentiable, named, and able to score predictions.
/// </summary>
public interface ICostFunction : IDifferentiableFunction, IFunction
{
    /// <summary>
    /// Scores how far the estimated values are from the actual values.
    /// </summary>
    /// <param name="estimatedValues">The values predicted by the model.</param>
    /// <param name="actualValues">The expected values.</param>
    /// <returns>The cost.</returns>
    double Cost(double[] estimatedValues, double[] actualValues);
}
16 | }
17 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/IDifferentiableFunction.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions
2 | {
/// <summary>
/// Contract for functions that can report their derivative.
/// </summary>
public interface IDifferentiableFunction
{
    /// <summary>
    /// Computes the derivative of the function for each element of <paramref name="input"/>.
    /// </summary>
    /// <param name="input">The values at which to evaluate the derivative.</param>
    /// <param name="oneHot">The one-hot encoded target; implementations that do not need it ignore it.</param>
    /// <returns>One derivative per input element.</returns>
    double[] Derivative(double[] input, double[] oneHot);
}
16 | }
17 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/IFunction.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions
2 | {
/// <summary>
/// Interface that all functions must implement.
/// </summary>
public interface IFunction
{
    /// <summary>
    /// Gets the name of the function. Factories use it as the lookup key, so it should be unique
    /// among implementations.
    /// </summary>
    string Name { get; }
}
13 | }
14 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/LeakyRelu.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Utilities;
2 |
3 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions
4 | {
/// <summary>
/// Leaky rectified linear unit activation function.
/// </summary>
public class LeakyRelu : IActivationFunction
{
    // Slope used for non-positive inputs.
    private const double Alpha = 0.01;

    /// <summary>
    /// Gets the display name of the function.
    /// </summary>
    public string Name => "Leaky Relu";

    /// <summary>
    /// Applies the activation element-wise.
    /// </summary>
    /// <param name="arr">The pre-activation values.</param>
    /// <returns>The activated values.</returns>
    public double[] Activate(double[] arr) => MathUtilities.LeakyRelu(arr, Alpha);

    /// <summary>
    /// Computes the element-wise derivative: 1 for positive inputs, <c>Alpha</c> otherwise.
    /// </summary>
    /// <param name="input">The pre-activation values.</param>
    /// <param name="oneHot">Unused by this function.</param>
    /// <returns>One derivative per input element.</returns>
    public double[] Derivative(double[] input, double[] oneHot)
    {
        var slopes = new double[input.Length];

        for (var k = 0; k < input.Length; k++)
        {
            if (input[k] > 0)
            {
                slopes[k] = 1.0;
            }
            else
            {
                slopes[k] = Alpha;
            }
        }

        return slopes;
    }
}
44 | }
45 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/MeanSquareError.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Extensions;
2 | using DigitRecognizer.Core.Utilities;
3 |
4 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions
5 | {
6 | ///
7 | /// Implements the mean square error function.
8 | ///
9 | public class MeanSquareError : ICostFunction
10 | {
11 | public string Name => "Mean Square Error";
12 |
13 | ///
14 | /// Calculates the cost for the specified estimated and actual values.
15 | ///
16 | /// The estimated values.
17 | /// The actual values.
18 | /// The cost.
19 | public double Cost(double[] estimatedValues, double[] actualValues)
20 | {
21 | double cost = MathUtilities.MeanSquareErr(estimatedValues, actualValues);
22 |
23 | return cost;
24 | }
25 |
26 | ///
27 | /// Gets the derivative of the function for the specified input.
28 | ///
29 | /// The input.
30 | /// The one hot encoded array.
31 | /// The derivative with respect to each input.
32 | public double[] Derivative(double[] input, double[] oneHot)
33 | {
34 | var result = new double[input.Length];
35 |
36 | int oneHotIndex = oneHot.ArgMax();
37 |
38 | for (var i = 0; i < input.Length; i++)
39 | {
40 | double delta = i == oneHotIndex ? 1.0 : 0.0;
41 |
42 | result[i] = input[i] - delta;
43 | }
44 |
45 | return result;
46 | }
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/Relu.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Utilities;
2 |
3 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions
4 | {
/// <summary>
/// Rectified linear unit activation function.
/// </summary>
public class Relu : IActivationFunction
{
    /// <summary>
    /// Gets the display name of the function.
    /// </summary>
    public string Name => "Relu";

    /// <summary>
    /// Applies the activation element-wise.
    /// </summary>
    /// <param name="arr">The pre-activation values.</param>
    /// <returns>The activated values.</returns>
    public double[] Activate(double[] arr) => MathUtilities.Relu(arr);

    /// <summary>
    /// Computes the element-wise derivative: 1 for positive inputs, 0 otherwise.
    /// </summary>
    /// <param name="input">The pre-activation values.</param>
    /// <param name="oneHot">Unused by this function.</param>
    /// <returns>One derivative per input element.</returns>
    public double[] Derivative(double[] input, double[] oneHot)
    {
        var slopes = new double[input.Length];

        for (var k = 0; k < input.Length; k++)
        {
            if (input[k] > 0)
            {
                slopes[k] = 1.0;
            }
            else
            {
                slopes[k] = 0.0;
            }
        }

        return slopes;
    }
}
42 | }
43 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/Sigmoid.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Utilities;
2 |
3 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions
4 | {
/// <summary>
/// Implements the logistic sigmoid activation function.
/// </summary>
public class Sigmoid : IActivationFunction
{
    /// <summary>
    /// Gets the display name of the function.
    /// </summary>
    public string Name => "Sigmoid";

    /// <summary>
    /// Applies the activation function on every element of the specified array.
    /// </summary>
    /// <param name="arr">The pre-activation values.</param>
    /// <returns>The array with values fed through the activation function.</returns>
    public double[] Activate(double[] arr)
    {
        double[] activations = MathUtilities.Sigmoid(arr);

        return activations;
    }

    /// <summary>
    /// Determines the derivative of the sigmoid for each of the specified pre-activation inputs.
    /// </summary>
    /// <param name="input">The pre-activation values; the sigmoid is re-evaluated on each one.</param>
    /// <param name="oneHot">Unused by this function.</param>
    /// <returns>The derivative sigmoid(x) * (1 - sigmoid(x)) for every input.</returns>
    public double[] Derivative(double[] input, double[] oneHot)
    {
        var result = new double[input.Length];

        for (var i = 0; i < input.Length; i++)
        {
            double sigmoid = MathUtilities.Sigmoid(input[i]);

            // The sigmoid derivative is a product, sigma * (1 - sigma); the former quotient
            // sigma / (1 - sigma) was wrong and diverged as sigma approached 1.
            result[i] = sigmoid * (1 - sigmoid);
        }

        return result;
    }
}
44 | }
45 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/Softmax.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Extensions;
2 | using DigitRecognizer.Core.Utilities;
3 |
4 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions
5 | {
/// <summary>
/// Softmax activation function.
/// </summary>
public class Softmax : IActivationFunction
{
    /// <summary>
    /// Gets the display name of the function.
    /// </summary>
    public string Name => "Softmax";

    /// <summary>
    /// Applies the softmax to the whole array.
    /// </summary>
    /// <param name="arr">The pre-activation values.</param>
    /// <returns>The activated values.</returns>
    public double[] Activate(double[] arr) => MathUtilities.Softmax(arr);

    /// <summary>
    /// For k = arg-max of <paramref name="oneHot"/>, returns <c>input[j] * (delta_jk - input[k])</c>
    /// for every j.
    /// </summary>
    /// <remarks>
    /// Unlike Sigmoid/Tanh, which re-apply their function to <paramref name="input"/>, this treats
    /// <paramref name="input"/> directly as softmax outputs p, which makes the result one column of the
    /// softmax Jacobian. NOTE(review): confirm callers pass activations, not pre-activations, here.
    /// </remarks>
    /// <param name="input">The values used as p.</param>
    /// <param name="oneHot">The one-hot encoded target selecting the column k.</param>
    /// <returns>One value per input element.</returns>
    public double[] Derivative(double[] input, double[] oneHot)
    {
        var column = new double[input.Length];

        int k = oneHot.ArgMax();

        for (var j = 0; j < input.Length; j++)
        {
            double kronecker;

            if (j == k)
            {
                kronecker = 1.0;
            }
            else
            {
                kronecker = 0.0;
            }

            column[j] = input[j] * (kronecker - input[k]);
        }

        return column;
    }
}
47 | }
48 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/Softplus.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Utilities;
2 |
3 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions
4 | {
/// <summary>
/// Implements the softplus activation function.
/// </summary>
public class Softplus : IActivationFunction
{
    /// <summary>
    /// Gets the display name of the function.
    /// </summary>
    public string Name => "Softplus";

    /// <summary>
    /// Applies the activation function on every element of the specified array.
    /// </summary>
    /// <param name="arr">The pre-activation values.</param>
    /// <returns>The array with values fed through the activation function.</returns>
    public double[] Activate(double[] arr)
    {
        double[] activations = MathUtilities.Softplus(arr);

        return activations;
    }

    /// <summary>
    /// Determines the derivative of the function for the specified inputs.
    /// </summary>
    /// <param name="input">The pre-activation values.</param>
    /// <param name="oneHot">Unused by this function.</param>
    /// <returns>The derivative of the function for every input.</returns>
    public double[] Derivative(double[] input, double[] oneHot)
    {
        var result = new double[input.Length];

        for (var i = 0; i < input.Length; i++)
        {
            // d/dx ln(1 + e^x) is the logistic sigmoid.
            result[i] = MathUtilities.Sigmoid(input[i]);
        }

        return result;
    }
}
42 | }
43 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Functions/Tanh.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Utilities;
2 |
3 | namespace DigitRecognizer.MachineLearning.Infrastructure.Functions
4 | {
/// <summary>
/// Hyperbolic tangent activation function.
/// </summary>
public class Tanh : IActivationFunction
{
    /// <summary>
    /// Gets the display name of the function.
    /// </summary>
    public string Name => "Tanh";

    /// <summary>
    /// Applies the activation element-wise.
    /// </summary>
    /// <param name="arr">The pre-activation values.</param>
    /// <returns>The activated values.</returns>
    public double[] Activate(double[] arr) => MathUtilities.Tanh(arr);

    /// <summary>
    /// Computes the element-wise derivative <c>1 - tanh(x)^2</c>, re-evaluating tanh on each input.
    /// </summary>
    /// <param name="input">The pre-activation values.</param>
    /// <param name="oneHot">Unused by this function.</param>
    /// <returns>One derivative per input element.</returns>
    public double[] Derivative(double[] input, double[] oneHot)
    {
        var slopes = new double[input.Length];

        for (var n = 0; n < input.Length; n++)
        {
            double t = MathUtilities.Tanh(input[n]);
            double squared = t * t;
            slopes[n] = 1 - squared;
        }

        return slopes;
    }
}
44 | }
45 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Initialization/HeInitializer.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using DigitRecognizer.Core.Utilities;
3 |
4 | namespace DigitRecognizer.MachineLearning.Infrastructure.Initialization
5 | {
/// <summary>
/// Implements the He initialization algorithm: each weight is drawn from a zero-mean normal
/// distribution with standard deviation sqrt(2 / rowCount).
/// </summary>
public class HeInitializer : IInitializer
{
    // Source of per-call seeds. A parameterless `new Random()` is seeded from the system clock on
    // .NET Framework, so matrices initialized in quick succession could receive identical values.
    // Random is not thread-safe, hence the lock in CreateRandom.
    private static readonly Random SeedSource = new Random();

    /// <summary>
    /// Gets the <see cref="InitializerType"/>.
    /// </summary>
    public InitializerType InitializerType => InitializerType.HeInitialization;

    /// <summary>
    /// Returns a matrix whose entries are drawn from N(0, 2 / rowCount).
    /// </summary>
    /// <remarks>
    /// As in the original implementation, <paramref name="rowCount"/> is used as the fan-in.
    /// </remarks>
    /// <param name="rowCount">The row count.</param>
    /// <param name="colCont">The column count.</param>
    /// <returns>The initialized matrix.</returns>
    public double[][] Initialize(int rowCount, int colCont)
    {
        double[][] result = VectorUtilities.CreateMatrix(rowCount, colCont);

        Random random = CreateRandom();

        double standardDeviation = Math.Sqrt(2.0 / rowCount);

        for (var i = 0; i < rowCount; i++)
        {
            for (var j = 0; j < colCont; j++)
            {
                // He initialization needs zero-mean Gaussian samples; NextDouble() alone is
                // uniform on [0, 1) and would make every weight non-negative.
                result[i][j] = NextStandardNormal(random) * standardDeviation;
            }
        }

        return result;
    }

    /// <summary>
    /// Creates a generator with a distinct seed taken from the shared seed source.
    /// </summary>
    private static Random CreateRandom()
    {
        lock (SeedSource)
        {
            return new Random(SeedSource.Next());
        }
    }

    /// <summary>
    /// Draws a standard normal sample using the Box-Muller transform.
    /// </summary>
    private static double NextStandardNormal(Random random)
    {
        // 1 - NextDouble() lies in (0, 1], which keeps the logarithm finite.
        double u1 = 1.0 - random.NextDouble();
        double u2 = random.NextDouble();

        return Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);
    }
}
41 | }
42 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Initialization/IInitializer.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.MachineLearning.Infrastructure.Initialization
2 | {
/// <summary>
/// Contract for weight-matrix initialization strategies.
/// </summary>
public interface IInitializer
{
    /// <summary>
    /// Gets the <see cref="InitializerType"/> identifying this strategy.
    /// </summary>
    InitializerType InitializerType { get; }

    /// <summary>
    /// Builds a matrix of the requested size filled according to the strategy.
    /// </summary>
    /// <param name="rowCount">The number of rows.</param>
    /// <param name="colCont">The number of columns.</param>
    /// <returns>The initialized matrix.</returns>
    double[][] Initialize(int rowCount, int colCont);
}
21 | }
22 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Initialization/InitializerType.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.MachineLearning.Infrastructure.Initialization
2 | {
/// <summary>
/// Identifies a weight initialization strategy.
/// </summary>
public enum InitializerType
{
    /// <summary>Every value starts at zero.</summary>
    ZeroInitialization = 0,

    /// <summary>Values are filled with random numbers.</summary>
    RandomInitialization = 1,

    /// <summary>Values are filled using the Xavier method.</summary>
    XavierInitialization = 2,

    /// <summary>Values are filled using the He-et-al method.</summary>
    HeInitialization = 3
}
28 | }
29 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Initialization/RandomInitializer.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using DigitRecognizer.Core.Utilities;
3 |
4 | namespace DigitRecognizer.MachineLearning.Infrastructure.Initialization
5 | {
/// <summary>
/// Implements the random initialization algorithm: each weight is drawn uniformly from [0, 0.01).
/// </summary>
public class RandomInitializer : IInitializer
{
    // Source of per-call seeds. A parameterless `new Random()` is seeded from the system clock on
    // .NET Framework, so matrices initialized in quick succession could receive identical values.
    // Random is not thread-safe, hence the lock in CreateRandom.
    private static readonly Random SeedSource = new Random();

    /// <summary>
    /// Gets the <see cref="InitializerType"/>.
    /// </summary>
    public InitializerType InitializerType => InitializerType.RandomInitialization;

    /// <summary>
    /// Returns a matrix whose entries are uniform random values in [0, 0.01).
    /// </summary>
    /// <param name="rowCount">The row count.</param>
    /// <param name="colCont">The column count.</param>
    /// <returns>The initialized matrix.</returns>
    public double[][] Initialize(int rowCount, int colCont)
    {
        double[][] result = VectorUtilities.CreateMatrix(rowCount, colCont);

        Random random = CreateRandom();

        const double factor = 0.01;

        for (var i = 0; i < rowCount; i++)
        {
            for (var j = 0; j < colCont; j++)
            {
                result[i][j] = random.NextDouble() * factor;
            }
        }

        return result;
    }

    /// <summary>
    /// Creates a generator with a distinct seed taken from the shared seed source.
    /// </summary>
    private static Random CreateRandom()
    {
        lock (SeedSource)
        {
            return new Random(SeedSource.Next());
        }
    }
}
41 | }
42 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Initialization/XavierInitializer.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using DigitRecognizer.Core.Utilities;
3 |
4 | namespace DigitRecognizer.MachineLearning.Infrastructure.Initialization
5 | {
/// <summary>
/// Implements the Xavier initialization algorithm (fan-in variant): each weight is drawn from a
/// zero-mean normal distribution with standard deviation sqrt(1 / rowCount).
/// </summary>
public class XavierInitializer : IInitializer
{
    // Source of per-call seeds. A parameterless `new Random()` is seeded from the system clock on
    // .NET Framework, so matrices initialized in quick succession could receive identical values.
    // Random is not thread-safe, hence the lock in CreateRandom.
    private static readonly Random SeedSource = new Random();

    /// <summary>
    /// Gets the <see cref="InitializerType"/>.
    /// </summary>
    public InitializerType InitializerType => InitializerType.XavierInitialization;

    /// <summary>
    /// Returns a matrix whose entries are drawn from N(0, 1 / rowCount).
    /// </summary>
    /// <remarks>
    /// As in the original implementation, <paramref name="rowCount"/> is used as the fan-in.
    /// </remarks>
    /// <param name="rowCount">The row count.</param>
    /// <param name="colCont">The column count.</param>
    /// <returns>The initialized matrix.</returns>
    public double[][] Initialize(int rowCount, int colCont)
    {
        double[][] result = VectorUtilities.CreateMatrix(rowCount, colCont);

        Random random = CreateRandom();

        double standardDeviation = Math.Sqrt(1.0 / rowCount);

        for (var i = 0; i < rowCount; i++)
        {
            for (var j = 0; j < colCont; j++)
            {
                // Xavier initialization needs zero-mean Gaussian samples; NextDouble() alone is
                // uniform on [0, 1) and would make every weight non-negative.
                result[i][j] = NextStandardNormal(random) * standardDeviation;
            }
        }

        return result;
    }

    /// <summary>
    /// Creates a generator with a distinct seed taken from the shared seed source.
    /// </summary>
    private static Random CreateRandom()
    {
        lock (SeedSource)
        {
            return new Random(SeedSource.Next());
        }
    }

    /// <summary>
    /// Draws a standard normal sample using the Box-Muller transform.
    /// </summary>
    private static double NextStandardNormal(Random random)
    {
        // 1 - NextDouble() lies in (0, 1], which keeps the logarithm finite.
        double u1 = 1.0 - random.NextDouble();
        double u2 = random.NextDouble();

        return Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);
    }
}
41 | }
42 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Initialization/ZeroInitializer.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Utilities;
2 |
3 | namespace DigitRecognizer.MachineLearning.Infrastructure.Initialization
4 | {
/// <summary>
/// Initialization strategy that returns the matrix straight from <c>VectorUtilities.CreateMatrix</c>
/// without writing any values (presumably all zeros — depends on CreateMatrix).
/// </summary>
public class ZeroInitializer : IInitializer
{
    /// <summary>
    /// Gets the <see cref="InitializerType"/> identifying this strategy.
    /// </summary>
    public InitializerType InitializerType => InitializerType.ZeroInitialization;

    /// <summary>
    /// Allocates a matrix of the requested size.
    /// </summary>
    /// <param name="rowCount">The number of rows.</param>
    /// <param name="colCont">The number of columns.</param>
    /// <returns>The freshly created matrix.</returns>
    public double[][] Initialize(int rowCount, int colCont)
    {
        double[][] zeros = VectorUtilities.CreateMatrix(rowCount, colCont);

        return zeros;
    }
}
26 | }
27 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Models/ClusterPredictionModel.cs:
--------------------------------------------------------------------------------
1 | using System.Collections.Generic;
2 | using System.Linq;
3 | using DigitRecognizer.Core.Extensions;
4 | using DigitRecognizer.Core.Utilities;
5 | using DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork;
6 |
7 | namespace DigitRecognizer.MachineLearning.Infrastructure.Models
8 | {
/// <summary>
/// Prediction model that runs several models on the same input and averages their outputs.
/// </summary>
public class ClusterPredictionModel : IPredictionModel
{
    // The member models; never empty once construction succeeds.
    private readonly List<IPredictionModel> _cluster;

    /// <summary>
    /// Initializes a new instance of the <see cref="ClusterPredictionModel"/> class from existing models.
    /// </summary>
    /// <param name="models">The models to combine; must not be null or empty.</param>
    public ClusterPredictionModel(IEnumerable<IPredictionModel> models)
    {
        Contracts.ValueNotNull(models, nameof(models));

        _cluster = new List<IPredictionModel>(models);

        // An empty cluster has nothing to average; reject it here instead of failing later in Predict.
        Contracts.ValueGreaterThanZero(_cluster.Count, nameof(models));
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="ClusterPredictionModel"/> class, wrapping each
    /// network in a <see cref="PredictionModel"/>.
    /// </summary>
    /// <param name="networks">The networks to combine; must not be null or empty.</param>
    public ClusterPredictionModel(IEnumerable<INeuralNetwork> networks)
    {
        Contracts.ValueNotNull(networks, nameof(networks));

        _cluster = new List<IPredictionModel>(networks.Select(x => new PredictionModel(x)));

        Contracts.ValueGreaterThanZero(_cluster.Count, nameof(networks));
    }

    /// <summary>
    /// Predicts the output for the specified input by averaging the predictions of every member model.
    /// </summary>
    /// <param name="input">The input.</param>
    /// <returns>The averaged output vector.</returns>
    public double[] Predict(double[] input)
    {
        double[][] predictions = _cluster
            .Select(model => model.Predict(input))
            .ToArray();

        double[] result = predictions.Average();

        return result;
    }

    /// <summary>
    /// Creates a cluster by loading one <see cref="PredictionModel"/> per file.
    /// </summary>
    /// <param name="filenames">The model files to load; must not be null or empty.</param>
    /// <returns>The cluster of loaded models.</returns>
    public static ClusterPredictionModel FromFiles(string[] filenames)
    {
        Contracts.ValueNotNull(filenames, nameof(filenames));

        return new ClusterPredictionModel(filenames.Select(PredictionModel.FromFile));
    }
}
53 | }
54 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Models/IPredictionModel.cs:
--------------------------------------------------------------------------------
1 | namespace DigitRecognizer.MachineLearning.Infrastructure.Models
2 | {
/// <summary>
/// Contract for anything that maps an input vector to an output vector of predictions.
/// </summary>
public interface IPredictionModel
{
    /// <summary>
    /// Produces the model's output for the given input.
    /// </summary>
    /// <param name="input">The input vector.</param>
    /// <returns>The output vector of probabilities.</returns>
    double[] Predict(double[] input);
}
15 | }
16 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/Models/PredictionModel.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Extensions;
2 | using DigitRecognizer.Core.Utilities;
3 | using DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork;
4 | using DigitRecognizer.MachineLearning.Serialization;
5 |
6 | namespace DigitRecognizer.MachineLearning.Infrastructure.Models
7 | {
/// <summary>
/// Represents a model that runs an internal neural network to predict input values.
/// </summary>
public class PredictionModel : IPredictionModel
{
    // The network used for every prediction; assigned once in the constructor.
    private readonly INeuralNetwork _neuralNetwork;

    /// <summary>
    /// Initializes a new instance of the <see cref="PredictionModel"/> class.
    /// </summary>
    /// <param name="neuralNetwork">The network used to compute predictions; checked with Contracts.ValueNotNull.</param>
    public PredictionModel(INeuralNetwork neuralNetwork)
    {
        Contracts.ValueNotNull(neuralNetwork, nameof(neuralNetwork));

        _neuralNetwork = neuralNetwork;
    }

    /// <summary>
    /// Predicts the output for the specified input by feeding it through the network.
    /// </summary>
    /// <param name="input">The input vector; converted to a matrix with AsMatrix before feeding.</param>
    /// <returns>The first row of the network output.</returns>
    public double[] Predict(double[] input)
    {
        double[][] prediction = _neuralNetwork.FeedForward(input.AsMatrix());

        // NOTE(review): assumes AsMatrix yields a single-row matrix, so row 0 is the whole prediction — confirm.
        return prediction[0];
    }

    /// <summary>
    /// Saves the underlying network to the specified file using NnSerializer.
    /// </summary>
    /// <param name="filename">The target filename.</param>
    public void Save(string filename)
    {
        var serializer = new NnSerializer();

        serializer.Serialize(filename, _neuralNetwork);
    }

    /// <summary>
    /// Creates a <see cref="PredictionModel"/> from the specified file.
    /// </summary>
    /// <param name="filename">The filename; it must exist and have the ".nn" extension.</param>
    /// <returns>The deserialized prediction model.</returns>
    public static PredictionModel FromFile(string filename)
    {
        Contracts.FileExists(filename, nameof(filename));
        Contracts.FileExtensionValid(filename, ".nn", nameof(filename));

        var deserializer = new NnDeserializer();

        NeuralNetwork.NeuralNetwork neuralNetwork = deserializer.Deserialize(filename);

        return new PredictionModel(neuralNetwork);
    }
}
66 | }
67 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/NeuralNetwork/BiasVector.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Utilities;
2 |
3 | namespace DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork
4 | {
/// <summary>
/// Represents a bias vector that is part of a neural network layer.
/// </summary>
public class BiasVector : IValueAdjustable
{
    // Assigned only in the constructors; AdjustValue updates the elements in place.
    private readonly double[] _bias;

    /// <summary>
    /// Initializes a new instance of the <see cref="BiasVector"/> class that wraps the specified array.
    /// </summary>
    /// <param name="bias">The bias values; the array is used directly, not copied. Must not be null.</param>
    public BiasVector(double[] bias)
    {
        Contracts.ValueNotNull(bias, nameof(bias));

        _bias = bias;
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="BiasVector"/> class with the specified length.
    /// </summary>
    /// <param name="length">The length; must be greater than zero.</param>
    public BiasVector(int length)
    {
        Contracts.ValueGreaterThanZero(length, nameof(length));

        _bias = new double[length];
    }

    /// <summary>
    /// Gets the bias values of the <see cref="BiasVector"/>.
    /// </summary>
    public double[] Bias => _bias;

    /// <summary>
    /// Gets the length of the <see cref="BiasVector"/>.
    /// </summary>
    public int Length => _bias.Length;

    /// <summary>
    /// Gets the size of the <see cref="BiasVector"/> in bytes.
    /// </summary>
    public int SizeInBytes => Length * sizeof(double);

    /// <summary>
    /// Adjusts the biases in place: <c>bias[i] -= gradient[0][i] * learningRate</c>.
    /// </summary>
    /// <param name="gradient">The gradient with respect to the biases; must have exactly one row of length <see cref="Length"/>.</param>
    /// <param name="learningRate">The learning rate.</param>
    public void AdjustValue(double[][] gradient, double learningRate)
    {
        Contracts.ValueNotNull(gradient, nameof(gradient));

        // Validate the row count before indexing gradient[0], so a malformed gradient fails the
        // contract check instead of raising IndexOutOfRangeException.
        int rowCount = gradient.Length;
        Contracts.ValuesMatch(1, rowCount, nameof(rowCount));

        const int rowIndex = 0;
        double[] row = gradient[rowIndex];
        Contracts.ValueNotNull(row, nameof(gradient));

        int colCount = row.Length;
        Contracts.ValuesMatch(_bias.Length, colCount, nameof(colCount));

        for (var i = 0; i < colCount; i++)
        {
            _bias[i] -= row[i] * learningRate;
        }
    }

    /// <summary>
    /// Converts the <see cref="BiasVector"/> to its underlying array.
    /// </summary>
    public static implicit operator double[] (BiasVector b)
    {
        return b.Bias;
    }
}
72 | }
73 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/NeuralNetwork/INeuralNetwork.cs:
--------------------------------------------------------------------------------
1 | using System.Collections.Generic;
2 | using DigitRecognizer.MachineLearning.Pipeline;
3 |
4 | namespace DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork
5 | {
/// <summary>
/// Contains basic members that a neural network should implement.
/// </summary>
public interface INeuralNetwork : ILearningPipelineNeuralNetworkModel
{
    // Read/write learning rate.
    double LearningRate { get; set; }

    // Number of layers in the network.
    int NumberOfLayers { get; }

    // The network's layers, stored in the project's own linked list type.
    Core.Data.LinkedList Layers { get; }

    // NOTE(review): presumably weighted sums recorded per layer during FeedForward — confirm with the implementation.
    List WeightedSumCache { get; }

    // NOTE(review): presumably activations recorded per layer during FeedForward — confirm with the implementation.
    List ActivationCache { get; }

    // Runs the input matrix through the network and returns the output matrix.
    double[][] FeedForward(double[][] input);

    // Appends several layers.
    void AddLayer(IEnumerable layers);

    // Appends a single layer.
    void AddLayer(NnLayer layer);
}
22 | }
23 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/NeuralNetwork/IValueAdjustable.cs:
--------------------------------------------------------------------------------
namespace DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork
{
    /// <summary>
    /// Implemented by parameter holders (for example <see cref="WeightMatrix"/>) whose values are
    /// updated from a gradient during training.
    /// </summary>
    public interface IValueAdjustable
    {
        /// <summary>
        /// Updates the current values using the supplied gradient and learning rate.
        /// </summary>
        /// <param name="gradient">The gradient values, as a jagged matrix.</param>
        /// <param name="learningRate">The learning rate that scales the update.</param>
        void AdjustValue(double[][] gradient, double learningRate);
    }
}
16 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Infrastructure/NeuralNetwork/WeightMatrix.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Utilities;
2 | using DigitRecognizer.MachineLearning.Infrastructure.Factories;
3 | using DigitRecognizer.MachineLearning.Infrastructure.Initialization;
4 | using DigitRecognizer.MachineLearning.Pipeline;
5 |
namespace DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork
{
    /// <summary>
    /// Represents a weight matrix that is part of a neural network layer.
    /// </summary>
    public class WeightMatrix : IValueAdjustable
    {
        private readonly double[][] _weights;

        /// <summary>
        /// Initializes a new instance of the <see cref="WeightMatrix"/> class with the specified dimensions.
        /// The values come from the initializer selected by <c>PipelineSettings.Instance.WeightsInitializerType</c>.
        /// </summary>
        /// <param name="rowCount">The row count; must be greater than zero.</param>
        /// <param name="colCount">The column count; must be greater than zero.</param>
        public WeightMatrix(int rowCount, int colCount)
        {
            Contracts.ValueGreaterThanZero(rowCount, nameof(rowCount));
            Contracts.ValueGreaterThanZero(colCount, nameof(colCount));

            IInitializer initializer = InitializerFactory.Instance.GetInstance(PipelineSettings.Instance.WeightsInitializerType);

            _weights = initializer.Initialize(rowCount, colCount);
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="WeightMatrix"/> class that wraps the specified weights.
        /// </summary>
        /// <param name="weights">The weights. The array is stored as-is (not copied), so later changes are shared.</param>
        public WeightMatrix(double[][] weights)
        {
            // Fail fast here instead of deferring a NullReferenceException to the first property access.
            Contracts.ValueNotNull(weights, nameof(weights));

            _weights = weights;
        }

        /// <summary>
        /// Gets the row count.
        /// </summary>
        public int RowCount => _weights.Length;

        /// <summary>
        /// Gets the column count, taken from the first row.
        /// </summary>
        public int ColCount => _weights[0].Length;

        /// <summary>
        /// Gets the weights flattened by <c>VectorUtilities.Flatten</c>; recomputed on every access.
        /// </summary>
        public double[] Flattened => VectorUtilities.Flatten(_weights);

        /// <summary>
        /// Gets the size of the <see cref="WeightMatrix"/> in bytes (<see cref="RowCount"/> × <see cref="ColCount"/> × <c>sizeof(double)</c>).
        /// </summary>
        public int SizeInBytes => ColCount * RowCount * sizeof(double);

        /// <summary>
        /// Gets the weights of the <see cref="WeightMatrix"/> (the live array, not a copy).
        /// </summary>
        public double[][] Weights => _weights;

        /// <summary>
        /// Adjusts the weights of the <see cref="WeightMatrix"/> using the specified gradient:
        /// <c>w[i][j] = factor * w[i][j] - gradient[i][j] * learningRate</c>.
        /// </summary>
        /// <remarks>
        /// <c>factor</c> is 1 unless <c>PipelineSettings.Instance.UseL2Regularization</c> is set. In that case it is
        /// <c>1 - RegularizationFactor * learningRate / DatasetSize</c> (L2 weight decay).
        /// </remarks>
        /// <param name="gradient">The gradient with respect to the weights; must have the same shape as the matrix.</param>
        /// <param name="learningRate">The learning rate.</param>
        public void AdjustValue(double[][] gradient, double learningRate)
        {
            Contracts.ValueNotNull(gradient, nameof(gradient));

            // Check the row count before reading gradient[0], so an empty gradient is reported by
            // the contract rather than by an IndexOutOfRangeException.
            int rowCount = gradient.Length;
            Contracts.ValuesMatch(RowCount, rowCount, nameof(RowCount));

            int colCount = gradient[0].Length;
            Contracts.ValuesMatch(ColCount, colCount, nameof(ColCount));

            var regularizationFactor = 1.0;

            if (PipelineSettings.Instance.UseL2Regularization)
            {
                regularizationFactor = 1 - PipelineSettings.Instance.RegularizationFactor * learningRate / PipelineSettings.Instance.DatasetSize;
            }

            for (var i = 0; i < rowCount; i++)
            {
                for (var j = 0; j < colCount; j++)
                {
                    _weights[i][j] = regularizationFactor * _weights[i][j] - gradient[i][j] * learningRate;
                }
            }
        }

        /// <summary>
        /// Implicitly converts a <see cref="WeightMatrix"/> to its underlying weights.
        /// </summary>
        /// <param name="m">The matrix to convert; may be <c>null</c>.</param>
        /// <returns>The live weights array, or <c>null</c> when <paramref name="m"/> is <c>null</c>.</returns>
        public static implicit operator double[][](WeightMatrix m)
        {
            // Implicit conversions should not throw, so null maps to null.
            return m?._weights;
        }
    }
}
98 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Optimization/LearningRateDecay/ExponentailDecay.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using DigitRecognizer.Core.Utilities;
3 | using DigitRecognizer.MachineLearning.Pipeline;
4 |
namespace DigitRecognizer.MachineLearning.Optimization.LearningRateDecay
{
    /// <summary>
    /// Exponential-decay learning rate schedule:
    /// <c>lr = initialLearningRate * exp(-decayRate * iteration)</c>.
    /// The iteration is read from <c>PipelineSettings.Instance.CurrentIteration</c>.
    /// </summary>
    /// <remarks>The type name's spelling ("Exponentail") is kept for compatibility with existing callers.</remarks>
    public class ExponentailDecay : ILearningRateDecay
    {
        /// <summary>
        /// The learning rate the schedule starts from.
        /// </summary>
        private readonly double _initialLearningRate;

        /// <summary>
        /// The decay rate applied in the exponent.
        /// </summary>
        private readonly double _decayRate;

        /// <summary>
        /// Initializes a new instance of the <see cref="ExponentailDecay"/> class.
        /// </summary>
        /// <param name="initialLearningRate">The initial learning rate; must be greater than zero.</param>
        /// <param name="decayRate">The decay rate; must be greater than zero.</param>
        public ExponentailDecay(double initialLearningRate, double decayRate)
        {
            Contracts.ValueGreaterThanZero(initialLearningRate, nameof(initialLearningRate));
            Contracts.ValueGreaterThanZero(decayRate, nameof(decayRate));

            _initialLearningRate = initialLearningRate;
            _decayRate = decayRate;
        }

        /// <summary>
        /// Computes the learning rate for the current iteration.
        /// </summary>
        /// <param name="learningRate">The current learning rate. It is ignored, because the schedule is computed from the initial learning rate.</param>
        /// <returns>The decayed learning rate.</returns>
        public double DecayLearningRate(double learningRate)
        {
            double exponent = -_decayRate * PipelineSettings.Instance.CurrentIteration;

            return _initialLearningRate * Math.Exp(exponent);
        }
    }
}
49 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Optimization/LearningRateDecay/ILearningRateDecay.cs:
--------------------------------------------------------------------------------
namespace DigitRecognizer.MachineLearning.Optimization.LearningRateDecay
{
    /// <summary>
    /// Contract for a learning rate schedule that produces the learning rate to use next.
    /// </summary>
    public interface ILearningRateDecay
    {
        /// <summary>
        /// Computes the decayed learning rate.
        /// </summary>
        /// <param name="learningRate">The current learning rate; some implementations ignore it.</param>
        /// <returns>The new learning rate.</returns>
        double DecayLearningRate(double learningRate);
    }
}
16 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Optimization/LearningRateDecay/StepDecay.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using DigitRecognizer.Core.Utilities;
3 | using DigitRecognizer.MachineLearning.Pipeline;
4 |
namespace DigitRecognizer.MachineLearning.Optimization.LearningRateDecay
{
    /// <summary>
    /// Step-decay learning rate schedule. Every <c>epochDrop</c> epochs the learning rate is multiplied
    /// by <c>drop</c>: <c>lr = initialLearningRate * drop^floor(epoch / epochDrop)</c>.
    /// The epoch is read from <c>PipelineSettings.Instance.CurrentEpoch</c>.
    /// </summary>
    public class StepDecay : ILearningRateDecay
    {
        /// <summary>
        /// The learning rate before any drop has been applied.
        /// </summary>
        private readonly double _initialLearningRate;

        /// <summary>
        /// The factor the learning rate is multiplied by at each drop.
        /// </summary>
        private readonly double _drop;

        /// <summary>
        /// The number of epochs that must elapse between two drops.
        /// </summary>
        private readonly double _epochDrop;

        /// <summary>
        /// Initializes a new instance of the <see cref="StepDecay"/> class.
        /// </summary>
        /// <param name="initialLearningRate">The initial learning rate; must be greater than zero.</param>
        /// <param name="drop">The drop factor; must be greater than zero.</param>
        /// <param name="epochDrop">The number of epochs between drops; must be greater than zero.</param>
        public StepDecay(double initialLearningRate, double drop, double epochDrop)
        {
            Contracts.ValueGreaterThanZero(initialLearningRate, nameof(initialLearningRate));
            Contracts.ValueGreaterThanZero(drop, nameof(drop));
            Contracts.ValueGreaterThanZero(epochDrop, nameof(epochDrop));

            _initialLearningRate = initialLearningRate;
            _drop = drop;
            _epochDrop = epochDrop;
        }

        /// <summary>
        /// Computes the learning rate for the current epoch.
        /// </summary>
        /// <param name="learningRate">The current learning rate. It is ignored, because the schedule is computed from the initial learning rate.</param>
        /// <returns>The decayed learning rate.</returns>
        public double DecayLearningRate(double learningRate)
        {
            // Number of whole drop intervals that have elapsed so far.
            double completedSteps = Math.Floor(PipelineSettings.Instance.CurrentEpoch / _epochDrop);

            return _initialLearningRate * Math.Pow(_drop, completedSteps);
        }
    }
}
57 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Optimization/LearningRateDecay/TimeBasedDecay.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Utilities;
2 | using DigitRecognizer.MachineLearning.Pipeline;
3 |
namespace DigitRecognizer.MachineLearning.Optimization.LearningRateDecay
{
    /// <summary>
    /// Time-based learning rate schedule. It scales the learning rate it is given by
    /// <c>1 / (1 + decayRate * iteration)</c>, reading the iteration from
    /// <c>PipelineSettings.Instance.CurrentIteration</c>.
    /// </summary>
    /// <remarks>
    /// Unlike <see cref="StepDecay"/> and <see cref="ExponentailDecay"/>, this schedule keeps no initial
    /// learning rate; it scales whatever value is passed in.
    /// NOTE(review): if callers feed the returned rate back in on every iteration, the decay compounds —
    /// confirm this is intended.
    /// </remarks>
    public class TimeBasedDecay : ILearningRateDecay
    {
        /// <summary>
        /// The decay rate applied per iteration.
        /// </summary>
        private readonly double _decayRate;

        /// <summary>
        /// Initializes a new instance of the <see cref="TimeBasedDecay"/> class.
        /// </summary>
        /// <param name="decayRate">The decay rate; must be greater than zero.</param>
        public TimeBasedDecay(double decayRate)
        {
            Contracts.ValueGreaterThanZero(decayRate, nameof(decayRate));

            _decayRate = decayRate;
        }

        /// <summary>
        /// Computes the decayed learning rate.
        /// </summary>
        /// <param name="learningRate">The current learning rate, which is scaled down.</param>
        /// <returns>The new learning rate.</returns>
        public double DecayLearningRate(double learningRate)
        {
            double divisor = 1.0 + _decayRate * PipelineSettings.Instance.CurrentIteration;

            return learningRate * (1.0 / divisor);
        }
    }
}
40 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Optimization/Optimizers/GradientDescentOptimizer.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.MachineLearning.Infrastructure.Functions;
2 | using DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork;
3 |
namespace DigitRecognizer.MachineLearning.Optimization.Optimizers
{
    /// <summary>
    /// Plain gradient descent. It adds no behavior of its own and relies entirely on the logic
    /// inherited from <see cref="BaseOptimizer"/>.
    /// </summary>
    public class GradientDescentOptimizer : BaseOptimizer
    {
        /// <summary>
        /// Creates a gradient descent optimizer for the given network and cost function.
        /// </summary>
        /// <param name="neuralNetwork">The neural network to optimize.</param>
        /// <param name="costFunction">The cost function to minimize.</param>
        public GradientDescentOptimizer(INeuralNetwork neuralNetwork, ICostFunction costFunction)
            : base(neuralNetwork, costFunction)
        {
            // Intentionally empty: all behavior lives in the base class.
        }
    }
}
22 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Optimization/Optimizers/IOptimizer.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Data;
2 | using DigitRecognizer.MachineLearning.Infrastructure.Functions;
3 | using DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork;
4 | using DigitRecognizer.MachineLearning.Pipeline;
5 |
namespace DigitRecognizer.MachineLearning.Optimization.Optimizers
{
    /// <summary>
    /// Interface that an optimization algorithm should implement.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the generic type argument of <c>LinkedListNode</c> in <see cref="AdjustParameters"/>
    /// appears to have been stripped from this listing — confirm it against the original source.
    /// </remarks>
    public interface IOptimizer : ILearningPipelineOptimizer
    {
        /// <summary>
        /// Gets the neural network being optimized.
        /// </summary>
        INeuralNetwork NeuralNetwork { get; }

        /// <summary>
        /// Gets the cost function used by the optimizer.
        /// </summary>
        ICostFunction CostFunction { get; }

        /// <summary>
        /// Calculates the cost of the specified prediction.
        /// </summary>
        /// <param name="prediction">The prediction.</param>
        /// <param name="oneHot">The one-hot value (presumably the index of the expected class — confirm).</param>
        /// <returns>The cost.</returns>
        double CalculateError(double[] prediction, int oneHot);

        /// <summary>
        /// Performs the backpropagation algorithm.
        /// </summary>
        /// <param name="predictions">The predictions.</param>
        /// <param name="oneHots">The one-hot values.</param>
        void Backpropagate(double[][] predictions, int[] oneHots);

        /// <summary>
        /// Adjusts the parameters of the layer held by the specified node, using the given deltas and learning rate.
        /// </summary>
        /// <param name="node">The node; implementations use its <c>Depth</c> to identify the layer.</param>
        /// <param name="delta">The delta values (presumably the bias deltas; an override names this parameter <c>bDelta</c> — confirm).</param>
        /// <param name="gradient">The gradient values (presumably the weight deltas; an override names this parameter <c>wDelta</c> — confirm).</param>
        /// <param name="learningRate">The learning rate.</param>
        void AdjustParameters(LinkedListNode node, double[][] delta, double[][] gradient, double learningRate);

        /// <summary>
        /// Calculates the output derivative with respect to the cost function of the optimizer.
        /// </summary>
        /// <param name="predictions">The predictions.</param>
        /// <param name="oneHots">The one-hot values.</param>
        /// <returns>The derivatives.</returns>
        double[][] CalculateOutputDerivative(double[][] predictions, int[] oneHots);

        /// <summary>
        /// Calculates the derivative of the weighted sum with respect to the specified activation function.
        /// </summary>
        /// <param name="activationFunction">The activation function.</param>
        /// <param name="nodeDepth">The node depth.</param>
        /// <param name="oneHots">The one-hot values.</param>
        /// <returns>The derivatives.</returns>
        double[][] WeightedSumDerivative(IActivationFunction activationFunction, int nodeDepth, int[] oneHots);
    }
}
65 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Optimization/Optimizers/MomentumOptimizer.cs:
--------------------------------------------------------------------------------
1 | using System.Collections.Generic;
2 | using DigitRecognizer.Core.Utilities;
3 | using DigitRecognizer.MachineLearning.Infrastructure.Functions;
4 | using DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork;
5 |
namespace DigitRecognizer.MachineLearning.Optimization.Optimizers
{
    /// <summary>
    /// Gradient descent with momentum. For every layer it keeps one velocity matrix for the biases and
    /// one for the weights. Each update computes <c>v = momentum * v + delta</c> and then hands the
    /// velocities to the base optimizer.
    /// </summary>
    public class MomentumOptimizer : BaseOptimizer
    {
        private readonly double _momentum;

        private List<double[][]> _biasVelocities;
        private List<double[][]> _weightVelocities;

        /// <summary>
        /// Initializes a new instance of the <see cref="MomentumOptimizer"/> class.
        /// </summary>
        /// <param name="neuralNetwork">The neural network; must not be null.</param>
        /// <param name="costFunction">The cost function.</param>
        /// <param name="momentum">The momentum value; must be greater than zero.</param>
        public MomentumOptimizer(INeuralNetwork neuralNetwork, ICostFunction costFunction, double momentum) : base(neuralNetwork, costFunction)
        {
            Contracts.ValueGreaterThanZero(momentum, nameof(momentum));
            Contracts.ValueNotNull(neuralNetwork, nameof(neuralNetwork));

            _momentum = momentum;

            InitializeVelocities(neuralNetwork);
        }

        /// <summary>
        /// Builds one bias velocity (1 × outputs) and one weight velocity (inputs × outputs) per layer,
        /// in layer order, using <c>VectorUtilities.CreateMatrix</c>.
        /// </summary>
        /// <param name="neuralNetwork">The neural network whose layers define the shapes.</param>
        private void InitializeVelocities(INeuralNetwork neuralNetwork)
        {
            var biasVelocities = new List<double[][]>();
            var weightVelocities = new List<double[][]>();

            foreach (NnLayer nnLayer in neuralNetwork.Layers.ToList())
            {
                biasVelocities.Add(VectorUtilities.CreateMatrix(1, nnLayer.NumberOfOutputs));
                weightVelocities.Add(VectorUtilities.CreateMatrix(nnLayer.NumberOfInputs, nnLayer.NumberOfOutputs));
            }

            _biasVelocities = biasVelocities;
            _weightVelocities = weightVelocities;
        }

        /// <summary>
        /// Updates the velocities of the node's layer and applies them through the base optimizer.
        /// </summary>
        /// <param name="node">The node; its <c>Depth</c> selects the velocity matrices.</param>
        /// <param name="bDelta">The bias delta values (a single row).</param>
        /// <param name="wDelta">The weight delta values.</param>
        /// <param name="learningRate">The learning rate.</param>
        public override void AdjustParameters(Core.Data.LinkedListNode<NnLayer> node, double[][] bDelta, double[][] wDelta, double learningRate)
        {
            int depth = node.Depth;
            double[][] biasVelocity = _biasVelocities[depth];
            double[][] weightVelocity = _weightVelocities[depth];

            // Bias velocities live in a single row.
            const int biasRow = 0;
            double[] biasVelocityRow = biasVelocity[biasRow];
            for (var col = 0; col < biasVelocityRow.Length; col++)
            {
                biasVelocityRow[col] = _momentum * biasVelocityRow[col] + bDelta[biasRow][col];
            }

            for (var row = 0; row < weightVelocity.Length; row++)
            {
                double[] weightVelocityRow = weightVelocity[row];
                int columnCount = weightVelocity[0].Length;
                for (var col = 0; col < columnCount; col++)
                {
                    weightVelocityRow[col] = _momentum * weightVelocityRow[col] + wDelta[row][col];
                }
            }

            base.AdjustParameters(node, biasVelocity, weightVelocity, learningRate);
        }
    }
}
78 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Pipeline/ILearningPipelineDataLoader.cs:
--------------------------------------------------------------------------------
namespace DigitRecognizer.MachineLearning.Pipeline
{
    /// <summary>
    /// A learning-pipeline item responsible for loading data.
    /// </summary>
    public interface ILearningPipelineDataLoader : ILearningPipelineItem
    {
        /// <summary>
        /// Loads data from the underlying source file.
        /// </summary>
        /// <returns>The loaded data, as an untyped object.</returns>
        object LoadData();
    }
}
15 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Pipeline/ILearningPipelineItem.cs:
--------------------------------------------------------------------------------
namespace DigitRecognizer.MachineLearning.Pipeline
{
    /// <summary>
    /// Marker interface shared by every item that can take part in a learning pipeline.
    /// It declares no members.
    /// </summary>
    public interface ILearningPipelineItem
    {
    }
}
10 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Pipeline/ILearningPipelineNeuralNetworkModel.cs:
--------------------------------------------------------------------------------
namespace DigitRecognizer.MachineLearning.Pipeline
{
    /// <summary>
    /// A learning-pipeline item representing a neural network model that can produce predictions.
    /// </summary>
    public interface ILearningPipelineNeuralNetworkModel : ILearningPipelineItem
    {
        /// <summary>
        /// Produces a prediction for the specified input.
        /// </summary>
        /// <param name="input">The input matrix.</param>
        /// <returns>The prediction matrix.</returns>
        double[][] Predict(double[][] input);
    }
}
16 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Pipeline/ILearningPipelineOptimizer.cs:
--------------------------------------------------------------------------------
namespace DigitRecognizer.MachineLearning.Pipeline
{
    /// <summary>
    /// A learning-pipeline item that optimizes a model.
    /// </summary>
    public interface ILearningPipelineOptimizer : ILearningPipelineItem
    {
        /// <summary>
        /// Optimizes the model using the specified predictions and their one-hot values.
        /// </summary>
        /// <param name="predictions">The predictions.</param>
        /// <param name="oneHots">The one-hot values.</param>
        void Optimize(double[][] predictions, int[] oneHots);
    }
}
16 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Properties/AssemblyInfo.cs:
--------------------------------------------------------------------------------
1 | using System.Reflection;
2 | using System.Runtime.InteropServices;
3 |
// Assembly-level metadata for DigitRecognizer.MachineLearning. These attributes describe the
// assembly (title, product, copyright, ...); change the values here to change that information.
[assembly: AssemblyTitle("DigitRecognizer.MachineLearning")]
[assembly: AssemblyDescription("")]
[assembly: AssemblyConfiguration("")]
[assembly: AssemblyCompany("")]
[assembly: AssemblyProduct("DigitRecognizer.MachineLearning")]
[assembly: AssemblyCopyright("Copyright © 2018")]
[assembly: AssemblyTrademark("")]
[assembly: AssemblyCulture("")]

// ComVisible(false) hides this assembly's types from COM. Set ComVisible(true) on an individual
// type if that type must be reachable from COM.
[assembly: ComVisible(false)]

// ID of the type library, used if this project is exposed to COM.
[assembly: Guid("bf82c3df-0efe-45db-8e3d-1ffdd548ae73")]

// Version numbers have four parts: Major.Minor.Build.Revision.
// Build and Revision can be defaulted with a wildcard, e.g.:
// [assembly: AssemblyVersion("1.0.*")]
[assembly: AssemblyVersion("1.0.0.0")]
[assembly: AssemblyFileVersion("1.0.0.0")]
36 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Providers/BatchDataProvider.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.Data;
2 |
namespace DigitRecognizer.MachineLearning.Providers
{
    /// <summary>
    /// A data provider that reads a <see cref="MnistImageBatch"/> from the label and image files.
    /// The size of the batch is configurable.
    /// </summary>
    public class BatchDataProvider : DataProviderBase<MnistImageBatch>
    {
        /// <summary>
        /// Largest raw pixel intensity; dividing by it maps pixels into the range [0, 1].
        /// </summary>
        private const double MaxPixelValue = 255d;

        /// <summary>
        /// Initializes a new instance of the <see cref="BatchDataProvider"/> class.
        /// </summary>
        /// <param name="labelFilename">The path to the file containing the labels.</param>
        /// <param name="imageFilename">The path to the file containing the images.</param>
        /// <param name="batchSize">The size of the batch to read.</param>
        public BatchDataProvider(string labelFilename, string imageFilename, int batchSize)
            : base(labelFilename, imageFilename, batchSize)
        {
        }

        /// <summary>
        /// Reads <c>BatchSize</c> labels and <c>BatchSize</c> images and returns them as a batch
        /// whose pixels are normalized to [0, 1].
        /// </summary>
        /// <returns>The batch of labels and normalized pixels.</returns>
        public override MnistImageBatch GetData()
        {
            int[] label = LabelReader.ReadLabels(BatchSize);

            double[][] pixels = PixelReader.ReadPixels(BatchSize, ImageSizeInPixels);

            pixels = NormalizePixels(pixels);

            var result = new MnistImageBatch(label, pixels);

            return result;
        }

        /// <summary>
        /// Normalizes the specified matrix of pixels, in place, to the range [0, 1].
        /// </summary>
        /// <param name="pixels">The pixels (presumably one row per image).</param>
        /// <returns>The same array, with every value divided by <see cref="MaxPixelValue"/>.</returns>
        private static double[][] NormalizePixels(double[][] pixels)
        {
            foreach (double[] row in pixels)
            {
                // Use each row's own length rather than pixels[0].Length, so ragged input is
                // neither under- nor over-indexed.
                for (var j = 0; j < row.Length; j++)
                {
                    row[j] /= MaxPixelValue;
                }
            }

            return pixels;
        }

        /// <summary>
        /// Loads the data from a source file.
        /// </summary>
        /// <returns>The <see cref="MnistImageBatch"/> produced by <see cref="GetData"/>.</returns>
        public override object LoadData()
        {
            return GetData();
        }
    }
}
66 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Providers/DataProviderBase.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.IO;
2 | using DigitRecognizer.Core.Utilities;
3 |
namespace DigitRecognizer.MachineLearning.Providers
{
    /// <summary>
    /// Represents a base class for a data provider that reads labels and pixels from a pair of files.
    /// </summary>
    /// <typeparam name="T">The type of data the provider returns.</typeparam>
    public abstract class DataProviderBase<T> : IDataProvider<T>
    {
        private readonly string _labelFilename;
        private readonly string _imageFilename;
        private readonly int _batchSize;

        /// <summary>
        /// The size of a single MNIST image in pixels (28 × 28).
        /// </summary>
        protected const int ImageSizeInPixels = 784;

        /// <summary>
        /// The <see cref="ILabelReader"/> instance, used for fetching labels.
        /// </summary>
        protected readonly ILabelReader LabelReader;

        /// <summary>
        /// The <see cref="IPixelReader"/> instance, used for fetching pixels.
        /// </summary>
        protected readonly IPixelReader PixelReader;

        /// <summary>
        /// Gets the batch size of the data provider.
        /// </summary>
        public int BatchSize => _batchSize;

        /// <summary>
        /// Gets the name of the file containing labels.
        /// </summary>
        public string LabelFilename => _labelFilename;

        /// <summary>
        /// Gets the name of the file containing images.
        /// </summary>
        public string ImageFilename => _imageFilename;

        /// <summary>
        /// Initializes a new instance of the <see cref="DataProviderBase{T}"/> class.
        /// </summary>
        /// <param name="labelFilename">The path to the file containing the labels; must be non-empty and exist.</param>
        /// <param name="imageFilename">The path to the file containing the images; must be non-empty and exist.</param>
        /// <param name="batchSize">The size of the batch to read; must be greater than zero.</param>
        protected DataProviderBase(string labelFilename, string imageFilename, int batchSize)
        {
            // Check each path for null/empty before checking that the file exists, so a missing
            // argument is reported as such rather than as a missing file.
            Contracts.StringNotNullOrEmpty(labelFilename, nameof(labelFilename));
            Contracts.FileExists(labelFilename, nameof(labelFilename));

            Contracts.StringNotNullOrEmpty(imageFilename, nameof(imageFilename));
            Contracts.FileExists(imageFilename, nameof(imageFilename));

            Contracts.ValueGreaterThanZero(batchSize, nameof(batchSize));

            LabelReader = new LabelReader(labelFilename);
            PixelReader = new PixelReader(imageFilename);
            _labelFilename = labelFilename;
            _imageFilename = imageFilename;
            _batchSize = batchSize;
        }

        /// <summary>
        /// Gets the data from the file system. The base implementation returns <c>default(T)</c>;
        /// derived providers override it.
        /// </summary>
        /// <returns>The data read by the provider.</returns>
        public virtual T GetData()
        {
            return default(T);
        }

        /// <summary>
        /// Loads the data for the learning pipeline by delegating to <see cref="GetData"/>.
        /// </summary>
        /// <returns>The result of <see cref="GetData"/>, as an object.</returns>
        public virtual object LoadData()
        {
            return GetData();
        }
    }
}
88 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Providers/IDataProvider.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.MachineLearning.Pipeline;
2 |
namespace DigitRecognizer.MachineLearning.Providers
{
    /// <summary>
    /// A data-provider abstraction that a neural network uses to obtain training data.
    /// </summary>
    /// <typeparam name="T">The type of data produced by the provider.</typeparam>
    public interface IDataProvider<T> : ILearningPipelineDataLoader
    {
        /// <summary>
        /// Reads data from the provider's source file.
        /// </summary>
        /// <returns>The data read, as <typeparamref name="T"/>.</returns>
        T GetData();
    }
}
18 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Serialization/INnSerializable.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Core.IO;
2 |
namespace DigitRecognizer.MachineLearning.Serialization
{
    /// <summary>
    /// Implemented by neural network layers so that they can be written to a binary file.
    /// </summary>
    public interface INnSerializable
    {
        /// <summary>
        /// Serializes the object into an <see cref="NnSerializationContext"/>.
        /// </summary>
        /// <returns>The serialization context containing information about the object.</returns>
        NnSerializationContext Serialize();
    }
}
17 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Serialization/NnDeserializer.cs:
--------------------------------------------------------------------------------
1 | using System.Collections.Generic;
2 | using System.IO;
3 | using System.Linq;
4 | using DigitRecognizer.Core.IO;
5 | using DigitRecognizer.MachineLearning.Infrastructure.Factories;
6 | using DigitRecognizer.MachineLearning.Infrastructure.Functions;
7 | using DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork;
8 |
namespace DigitRecognizer.MachineLearning.Serialization
{
    /// <summary>
    /// Deserializer that rebuilds neural networks and their <see cref="NnLayer"/> objects from a binary file.
    /// </summary>
    public class NnDeserializer
    {
        /// <summary>
        /// Deserializes data from the specified file to a <see cref="NeuralNetwork"/>.
        /// </summary>
        /// <param name="filename">The file to read.</param>
        /// <returns>A network built from the deserialized layers.</returns>
        public NeuralNetwork Deserialize(string filename)
        {
            // This type holds no state, so there is no need for a second instance.
            IEnumerable<NnLayer> layers = DeserializeLayers(filename);

            // NOTE(review): the second constructor argument (1.0) is presumably the learning rate — confirm.
            var neuralNetwork = new NeuralNetwork(layers, 1.0);

            return neuralNetwork;
        }

        /// <summary>
        /// Deserializes data from the specified file to a collection of <see cref="NnLayer"/> objects.
        /// </summary>
        /// <param name="filename">The filename.</param>
        /// <returns>A fully materialized collection of neural network layers.</returns>
        public IEnumerable<NnLayer> DeserializeLayers(string filename)
        {
            using (var deserializer = new NnBinaryDeserializer(filename, FileMode.Open))
            {
                IEnumerable<NnSerializationContext> contexts = deserializer.Deserialize();

                // Materialize while the deserializer is still open. A lazy Select would run only
                // after Dispose (breaking if Deserialize() streams from the file), and it would
                // build a fresh set of layer objects on every enumeration.
                List<NnLayer> result = contexts.Select(DeserializeContext).ToList();

                return result;
            }
        }

        /// <summary>
        /// Deserializes the specified <see cref="NnSerializationContext"/> to an <see cref="NnLayer"/> object.
        /// </summary>
        /// <param name="serializationContext">The serialization context.</param>
        /// <returns>The deserialized neural network layer.</returns>
        public NnLayer DeserializeContext(NnSerializationContext serializationContext)
        {
            NnSerializationContextInfo contextInfo = serializationContext.SerializationContextInfo;

            double[] data = serializationContext.FileData;

            // The activation function is recreated from its serialized name.
            var activationFunction = (IActivationFunction) FunctionFactory.Instance.GetInstance(contextInfo.ActivationFunctionName);

            var layer = new NnLayer(contextInfo.WeightMatrixRowCount, contextInfo.WeightMatrixColCount, contextInfo.BiasLength, data, activationFunction);

            return layer;
        }
    }
}
68 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.MachineLearning/Serialization/NnSerializer.cs:
--------------------------------------------------------------------------------
1 | using System.Collections.Generic;
2 | using System.IO;
3 | using System.Linq;
4 | using DigitRecognizer.Core.IO;
5 | using DigitRecognizer.MachineLearning.Infrastructure.NeuralNetwork;
6 |
namespace DigitRecognizer.MachineLearning.Serialization
{
    /// <summary>
    /// Serializer that writes <see cref="NnLayer"/> objects to a binary file.
    /// </summary>
    public class NnSerializer
    {
        /// <summary>
        /// Serializes the layers of the specified <see cref="INeuralNetwork"/> to a file.
        /// </summary>
        /// <param name="filename">The filename to write to.</param>
        /// <param name="neuralNetwork">The neural network.</param>
        public void Serialize(string filename, INeuralNetwork neuralNetwork)
        {
            Serialize(filename, neuralNetwork.Layers.ToList());
        }

        /// <summary>
        /// Serializes the specified collection of <see cref="NnLayer"/> objects to a file.
        /// Any existing file is overwritten.
        /// </summary>
        /// <param name="filename">The filename to write to.</param>
        /// <param name="collection">The collection of layers to serialize.</param>
        public void Serialize(string filename, IEnumerable<NnLayer> collection)
        {
            // Serialize every layer before opening the file. FileMode.Create truncates an existing
            // file as soon as it is opened, so a failure inside layer.Serialize() would otherwise
            // leave a partially written (or empty) model behind.
            List<NnSerializationContext> contexts = collection.Select(layer => layer.Serialize()).ToList();

            using (var serializer = new NnBinarySerializer(filename, FileMode.Create))
            {
                serializer.Serialize(contexts);
            }
        }
    }
}
40 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/App.config:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Components/PredictionPane.Designer.cs:
--------------------------------------------------------------------------------
namespace DigitRecognizer.Presentation.Components
{
    /// <summary>
    /// Designer-generated part of <see cref="PredictionPane"/>. It owns the component container and applies
    /// the control's base appearance (white background, 550 × 400 size).
    /// </summary>
    partial class PredictionPane
    {
        /// <summary>
        /// Required designer variable.
        /// </summary>
        private System.ComponentModel.IContainer components = null;

        /// <summary>
        /// Clean up any resources being used.
        /// </summary>
        /// <param name="disposing">true if managed resources should be disposed; otherwise, false.</param>
        protected override void Dispose(bool disposing)
        {
            if (disposing && (components != null))
            {
                components.Dispose();
            }
            base.Dispose(disposing);
        }

        #region Component Designer generated code

        /// <summary>
        /// Required method for Designer support - do not modify
        /// the contents of this method with the code editor.
        /// </summary>
        private void InitializeComponent()
        {
            this.SuspendLayout();
            //
            // PredictionPane
            //
            this.AutoScaleDimensions = new System.Drawing.SizeF(6F, 13F);
            this.AutoScaleMode = System.Windows.Forms.AutoScaleMode.Font;
            this.BackColor = System.Drawing.Color.White;
            this.Name = "PredictionPane";
            this.Size = new System.Drawing.Size(550, 400);
            this.ResumeLayout(false);

        }

        #endregion
    }
}
47 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Components/PredictionPane.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 | using System.Drawing;
4 | using System.Windows.Forms;
5 |
6 | namespace DigitRecognizer.Presentation.Components
7 | {
8 | public partial class PredictionPane : UserControl
9 | {
10 | #region Fields
11 |
// Name prefix of the digit labels; the class index is appended (e.g. "lbl0").
private const string LabelName = "lbl";
// Name prefix of the prediction-text labels; the class index is appended (e.g. "lblPrediction0").
private const string LabelPredictionName = "lblPrediction";
// Name prefix of the progress bars; the class index is appended (e.g. "progressBar0").
private const string ProgressBarName = "progressBar";
// Number of classes (digits); ProcessPrediction requires exactly this many values.
private const int ClassNumber = 10;
16 |
17 | #endregion
18 |
19 | #region Ctor
20 |
/// <summary>
/// Initializes a new <see cref="PredictionPane"/>. It runs the designer-generated set-up first, then
/// creates the per-digit labels and progress bars.
/// </summary>
public PredictionPane()
{
    // Designer set-up must run before the dynamic controls are added.
    InitializeComponent();
    InitializeDisplay();
}
27 |
28 | #endregion
29 |
30 | #region Methods
31 |
/// <summary>
/// Displays one prediction per class. Each progress bar shows the value as a rounded percentage,
/// and the neighbouring label shows it formatted with the "P" (percent) format specifier.
/// </summary>
/// <param name="predictions">One value per class; exactly <c>ClassNumber</c> values are required.</param>
/// <exception cref="ArgumentNullException"><paramref name="predictions"/> is <c>null</c>.</exception>
/// <exception cref="ArgumentException"><paramref name="predictions"/> does not contain exactly <c>ClassNumber</c> values.</exception>
public void ProcessPrediction(double[] predictions)
{
    if (predictions == null)
    {
        throw new ArgumentNullException(nameof(predictions));
    }

    if (predictions.Length != ClassNumber)
    {
        // The single-argument ArgumentException constructor takes a message, not a parameter name.
        throw new ArgumentException($"Expected exactly {ClassNumber} predictions but got {predictions.Length}.", nameof(predictions));
    }

    for (var i = 0; i < predictions.Length; i++)
    {
        ProgressBarHandler(GetProgressBarName(i), (int)Math.Round(predictions[i] * 100));

        PredictionLabelHandler(GetPredictionLabelName(i), predictions[i].ToString("P"));
    }
}
46 |
47 | public void Clear()
48 | {
49 | for (var i = 0; i < ClassNumber; i++)
50 | {
51 | ProgressBarHandler(GetProgressBarName(i), 0);
52 |
53 | PredictionLabelHandler(GetPredictionLabelName(i), string.Empty);
54 | }
55 | }
56 |
57 | private void ProgressBarHandler(string progressBarKey, int value)
58 | {
59 | if (Controls.ContainsKey(progressBarKey))
60 | {
61 | ((ProgressBar)Controls[progressBarKey]).Value = value;
62 | }
63 | else
64 | {
65 | throw new ArgumentException(nameof(progressBarKey));
66 | }
67 | }
68 |
69 | private void PredictionLabelHandler(string labelPredictionKey, string text)
70 | {
71 | if (Controls.ContainsKey(labelPredictionKey))
72 | {
73 | ((Label)Controls[labelPredictionKey]).Text = text;
74 | }
75 | else
76 | {
77 | throw new ArgumentException(nameof(labelPredictionKey));
78 | }
79 | }
80 |
81 | #endregion
82 |
83 | #region Utilities
84 |
85 | private void InitializeDisplay()
86 | {
87 | var controls = new List();
88 |
89 | for (var i = 0; i < 10; i++)
90 | {
91 | int top = Top + 40 * i + 1;
92 |
93 | Label labelNum = GenerateLabel(GetLabelName(i), Left, top + 1, i.ToString(), 20);
94 |
95 | ProgressBar progressBar = GenerateProgressBar(GetProgressBarName(i), labelNum.Right + 5, top);
96 |
97 | Label labelPrediction = GenerateLabel(GetPredictionLabelName(i), progressBar.Right + 5, top + 1, "Prediction");
98 |
99 | controls.Add(labelNum);
100 | controls.Add(progressBar);
101 | controls.Add(labelPrediction);
102 | }
103 |
104 | Controls.AddRange(controls.ToArray());
105 | }
106 |
107 | private static string GetPredictionLabelName(int i) => $"{LabelPredictionName}{i}";
108 |
109 | private static string GetProgressBarName(int i) => $"{ProgressBarName}{i}";
110 |
111 | private static string GetLabelName(int i) => $"{LabelName}{i}";
112 |
113 | private static ProgressBar GenerateProgressBar(string name, int left, int top)
114 | {
115 | return new ProgressBar
116 | {
117 | Name = name,
118 | Value = 0,
119 | Left = left,
120 | Top = top,
121 | Width = 450
122 | };
123 | }
124 |
125 | private static Label GenerateLabel(string name, int left, int right, string text = "", int? width = null)
126 | {
127 | var lbl = new Label
128 | {
129 | AutoSize = true,
130 | Name = name,
131 | Text = text,
132 | Left = left,
133 | Top = right,
134 | Font = new Font(FontFamily.GenericSansSerif, 12f)
135 | };
136 |
137 | if (width != null)
138 | {
139 | lbl.Width = (int)width;
140 | }
141 |
142 | return lbl;
143 | }
144 |
145 | #endregion
146 | }
147 | }
148 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Global.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.MachineLearning.Infrastructure.Models;
2 | using DigitRecognizer.Presentation.Infrastructure;
3 |
namespace DigitRecognizer.Presentation
{
    /// <summary>
    /// Application-wide shared state for the presentation layer.
    /// </summary>
    public static class Global
    {
        /// <summary>
        /// Grid field count; presumably the number of cells in the benchmark image grid — confirm with its users.
        /// </summary>
        public static readonly int ImageGridFieldCount = 100;

        private static IPredictionModel _predictionModel;

        /// <summary>
        /// The prediction model, loaded through <see cref="PredictionModelLoader"/> on first access.
        /// While the loaded value is null (e.g. the file dialog was cancelled), every access tries to load again.
        /// </summary>
        public static IPredictionModel PredictionModel
        {
            get
            {
                if (_predictionModel == null)
                {
                    var loader = new PredictionModelLoader();

                    _predictionModel = loader.Load();
                }

                return _predictionModel;
            }
        }
    }
}
24 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Infrastructure/DependencyResolver.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Collections.Generic;
3 |
namespace DigitRecognizer.Presentation.Infrastructure
{
    /// <summary>
    /// Minimal type registry mapping a contract type to an implementation type. Every resolve
    /// creates a new implementation instance through its parameterless constructor.
    /// </summary>
    public class DependencyResolver
    {
        /// <summary>
        /// The object used for locking. Required for thread safety.
        /// </summary>
        private static readonly object Lock = new object();

        private static DependencyResolver _instance;

        private readonly Dictionary<Type, Type> _dependencyDictionary;

        public DependencyResolver()
        {
            _dependencyDictionary = new Dictionary<Type, Type>();
        }

        /// <summary>
        /// Gets the singleton instance of the class, creating it on first access.
        /// </summary>
        public static DependencyResolver Instance
        {
            get
            {
                lock (Lock)
                {
                    return _instance ?? (_instance = new DependencyResolver());
                }
            }
        }

        /// <summary>
        /// Registers <typeparamref name="TImplementation"/> as the implementation of <typeparamref name="TContract"/>.
        /// </summary>
        public static void Register<TContract, TImplementation>()
        {
            Register(typeof(TContract), typeof(TImplementation));
        }

        /// <summary>
        /// Registers <paramref name="implementation"/> as the implementation of <paramref name="contract"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException">Either argument is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="implementation"/> is not assignable to <paramref name="contract"/>, or
        /// <paramref name="contract"/> is already registered.
        /// </exception>
        public static void Register(Type contract, Type implementation)
        {
            if (contract == null)
            {
                throw new ArgumentNullException(nameof(contract));
            }

            if (implementation == null)
            {
                throw new ArgumentNullException(nameof(implementation));
            }

            // Fail at registration time instead of with an InvalidCastException on the first Resolve<T>().
            if (!contract.IsAssignableFrom(implementation))
            {
                throw new ArgumentException($"Type {implementation} cannot be used as an implementation of {contract}.", nameof(implementation));
            }

            Instance._dependencyDictionary.Add(contract, implementation);
        }

        /// <summary>
        /// Creates a new instance of the implementation registered for <typeparamref name="T"/>.
        /// </summary>
        public static T Resolve<T>()
        {
            return (T)Resolve(typeof(T));
        }

        /// <summary>
        /// Creates a new instance of the implementation registered for <paramref name="contract"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="contract"/> is null.</exception>
        /// <exception cref="KeyNotFoundException">Nothing is registered for <paramref name="contract"/>.</exception>
        public static object Resolve(Type contract)
        {
            if (contract == null)
            {
                throw new ArgumentNullException(nameof(contract));
            }

            Type implementation;
            if (!Instance._dependencyDictionary.TryGetValue(contract, out implementation))
            {
                throw new KeyNotFoundException($"No implementation is registered for {contract}.");
            }

            return Activator.CreateInstance(implementation);
        }
    }
}
57 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Infrastructure/ExceptionHandlers.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Threading;
3 | using System.Windows.Forms;
4 | using DigitRecognizer.Presentation.Services;
5 |
namespace DigitRecognizer.Presentation.Infrastructure
{
    /// <summary>
    /// Last-chance exception handlers. Each handler logs the exception through
    /// <see cref="ILoggingService"/> and shows it through <see cref="IMessageService"/>.
    /// </summary>
    public static class ExceptionHandlers
    {
        public static void CurrentDomainOnUnhandledException(object sender, UnhandledExceptionEventArgs e)
        {
            // ExceptionObject is typed as object; wrap anything that is not an Exception so the
            // handler itself cannot fail with an InvalidCastException.
            Exception ex = e.ExceptionObject as Exception
                           ?? new Exception($"Non-exception object thrown: {e.ExceptionObject}");

            Report(ex, "Unhandled Exception", MessageBoxIcon.Error);
        }

        public static void ApplicationOnThreadException(object sender, ThreadExceptionEventArgs e)
        {
            Report(e.Exception, "Thread exception", MessageBoxIcon.Information);
        }

        // Logs the exception (best effort) and shows it to the user with the given caption and icon.
        private static void Report(Exception exception, string caption, MessageBoxIcon icon)
        {
            string message = $"Something went wrong.{Environment.NewLine}{Environment.NewLine}{exception}";

            try
            {
                DependencyResolver.Resolve<ILoggingService>().Log(exception);
            }
            catch (Exception loggingException)
            {
                // A failing logger (e.g. locked log file) must not stop the user from seeing the error.
                message += $"{Environment.NewLine}{Environment.NewLine}The error could not be logged: {loggingException.Message}";
            }

            DependencyResolver.Resolve<IMessageService>().ShowMessage(message, caption, icon: icon);
        }
    }
}
31 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Infrastructure/PanelDoubleBuffering.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Reflection;
3 | using System.Windows.Forms;
4 |
namespace DigitRecognizer.Presentation.Infrastructure
{
    /// <summary>
    /// Enables double buffering on <see cref="Panel"/> instances. The DoubleBuffered property is
    /// not public on Panel, so it is set through reflection with non-public binding flags.
    /// </summary>
    public static class PanelDoubleBuffering
    {
        private const string DoubleBufferedName = "DoubleBuffered";

        /// <summary>
        /// Sets DoubleBuffered to true on <paramref name="target"/>.
        /// </summary>
        /// <param name="target">The panel to modify.</param>
        /// <exception cref="ArgumentNullException"><paramref name="target"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="target"/> is not a <see cref="Panel"/>.</exception>
        public static void Enable(object target)
        {
            if (target == null)
            {
                throw new ArgumentNullException(nameof(target));
            }

            var panel = target as Panel;

            if (panel == null)
            {
                throw new ArgumentException($"Can not enable double buffering on object of type {target.GetType()}", nameof(target));
            }

            typeof(Panel).InvokeMember(DoubleBufferedName,
                BindingFlags.SetProperty | BindingFlags.Instance | BindingFlags.NonPublic,
                null,
                panel,
                new object[] { true });
        }
    }
}
31 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Infrastructure/PredictionModelLoader.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.IO;
3 | using System.Linq;
4 | using System.Windows.Forms;
5 | using DigitRecognizer.Core.Utilities;
6 | using DigitRecognizer.MachineLearning.Infrastructure.Models;
7 |
namespace DigitRecognizer.Presentation.Infrastructure
{
    /// <summary>
    /// Lets the user pick one or more neural network files (*.nn) and loads them as a
    /// <see cref="ClusterPredictionModel"/>.
    /// </summary>
    public class PredictionModelLoader
    {
        private const string NeuralNetworkFilter = @"Neural network file (*.nn)|*.nn";

        // Directory the file dialog opens in.
        private readonly string _initialDirectory;

        public PredictionModelLoader()
        {
            // Plain concatenation (not Path.Combine) is kept on purpose: Path.Combine would discard the
            // base directory if ModelsFolder starts with a directory separator.
            _initialDirectory = Path.GetFullPath(Path.GetDirectoryName(AppDomain.CurrentDomain.BaseDirectory) + DirectoryHelper.ModelsFolder);
        }

        /// <summary>
        /// Shows a file dialog and loads the selected model files.
        /// </summary>
        /// <returns>
        /// The loaded model, or null when the dialog is not confirmed with OK or the selection
        /// contains no usable file name.
        /// </returns>
        public IPredictionModel Load()
        {
            // OpenFileDialog is IDisposable (it is a Component); create it per call and dispose it.
            using (var openFileDialog = new OpenFileDialog
            {
                Multiselect = true,
                CheckFileExists = true,
                InitialDirectory = _initialDirectory,
                Filter = NeuralNetworkFilter
            })
            {
                if (openFileDialog.ShowDialog() != DialogResult.OK)
                {
                    return null;
                }

                string[] fileNames = openFileDialog.FileNames;

                if (fileNames.Length == 0 || fileNames.Any(string.IsNullOrWhiteSpace))
                {
                    return null;
                }

                return ClusterPredictionModel.FromFiles(fileNames);
            }
        }
    }
}
38 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Models/ImageGridModel.cs:
--------------------------------------------------------------------------------
1 | using System.Collections.Generic;
2 | using System.Drawing;
3 | using System.Drawing.Imaging;
4 | using System.Threading.Tasks;
5 | using DigitRecognizer.Core.Data;
6 | using DigitRecognizer.Core.Extensions;
7 |
namespace DigitRecognizer.Presentation.Models
{
    /// <summary>
    /// Display model for a grid of digit images with their labels and predicted classes.
    /// Images of samples whose label differs from the prediction get a pale red background.
    /// </summary>
    public class ImageGridModel
    {
        // Side length, in pixels, of a source image; each pixel row of the batch holds Size * Size values.
        private const int Size = 28;

        // Side length requested from the Resize extension for the displayed images.
        private const int DisplaySize = 56;

        // Source values above this threshold are drawn as ink (black).
        private const double InkThreshold = 0.2;

        private static readonly Color PaleRed = Color.FromArgb(255, 239, 153, 153);

        public ImageGridModel(MnistImageBatch batch, int[] predictions)
        {
            Labels = new List<int>(batch.Labels);

            Predictions = new List<int>(predictions);

            Images = new List<Image>(ProcessImages(batch.Pixels));
        }

        public ImageGridModel()
        {
            Labels = new List<int>();
            Predictions = new List<int>();
            Images = new List<Image>();
        }

        /// <summary>Rendered sample images, in sample order.</summary>
        public List<Image> Images { get; set; }

        /// <summary>Labels of the samples, in sample order.</summary>
        public List<int> Labels { get; set; }

        /// <summary>Predicted classes of the samples, in sample order.</summary>
        public List<int> Predictions { get; set; }

        /// <summary>Number of images in the model.</summary>
        public int Count => Images.Count;

        /// <summary>
        /// Appends the images, labels and predictions of <paramref name="igm2"/> to <paramref name="igm1"/>.
        /// </summary>
        /// <remarks>
        /// Unlike most + operators this mutates and returns <paramref name="igm1"/>; no new model is created.
        /// </remarks>
        public static ImageGridModel operator +(ImageGridModel igm1, ImageGridModel igm2)
        {
            igm1.Images.AddRange(igm2.Images);
            igm1.Labels.AddRange(igm2.Labels);
            igm1.Predictions.AddRange(igm2.Predictions);

            return igm1;
        }

        // Renders every pixel row in parallel; each iteration works on its own bitmap.
        private Image[] ProcessImages(double[][] pixelArray)
        {
            var images = new Image[pixelArray.Length];

            Parallel.For(0, pixelArray.Length, row =>
            {
                bool misclassified = Labels[row] != Predictions[row];

                images[row] = CreateImage(pixelArray[row], misclassified);
            });

            return images;
        }

        // Renders one sample to a Size x Size bitmap and returns it resized to DisplaySize x DisplaySize.
        private static Image CreateImage(double[] pixels, bool misclassified)
        {
            var bitmap = new Bitmap(Size, Size);

            try
            {
                WritePixels(bitmap, pixels, misclassified);

                Image resized = bitmap.Resize(DisplaySize, DisplaySize);

                // The intermediate bitmap is not needed once Resize has produced a separate image.
                if (!ReferenceEquals(resized, bitmap))
                {
                    bitmap.Dispose();
                }

                return resized;
            }
            catch
            {
                bitmap.Dispose();
                throw;
            }
        }

        // Writes ink pixels as black and all other pixels as white (or pale red when misclassified).
        private static unsafe void WritePixels(Bitmap bitmap, double[] pixels, bool misclassified)
        {
            BitmapData data = bitmap.LockBits(new Rectangle(0, 0, Size, Size), ImageLockMode.ReadWrite, PixelFormat.Format24bppRgb);

            try
            {
                var scan0 = (byte*)data.Scan0.ToPointer();

                for (var y = 0; y < Size; y++)
                {
                    // Rows may be padded, so every row starts Stride bytes after the previous one.
                    byte* ptr = scan0 + y * data.Stride;

                    for (var x = 0; x < Size; x++)
                    {
                        if (pixels[y * Size + x] > InkThreshold)
                        {
                            *(ptr++) = 0; // B
                            *(ptr++) = 0; // G
                            *(ptr++) = 0; // R
                        }
                        else if (misclassified)
                        {
                            *(ptr++) = PaleRed.B; // B
                            *(ptr++) = PaleRed.G; // G
                            *(ptr++) = PaleRed.R; // R
                        }
                        else
                        {
                            *(ptr++) = 255; // B
                            *(ptr++) = 255; // G
                            *(ptr++) = 255; // R
                        }
                    }
                }
            }
            finally
            {
                bitmap.UnlockBits(data);
            }
        }
    }
}
93 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Presenters/ApplicationPresenter.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Presentation.Services;
2 | using DigitRecognizer.Presentation.Views.Interfaces;
3 |
namespace DigitRecognizer.Presentation.Presenters
{
    /// <summary>
    /// Root presenter: creates one presenter for each sub-view exposed by the main form view.
    /// </summary>
    public class ApplicationPresenter
    {
        // The child presenters subscribe to their views' events in their constructors;
        // these fields only keep references to them.
        private readonly BenchmarkPresenter _benchmarkPresenter;
        private readonly DrawingPresenter _drawingPresenter;
        private readonly UploadImagePresenter _uploadImagePresenter;
        private readonly SlidingWindowPresenter _slidingWindowPresenter;

        /// <summary>
        /// Creates the benchmark, drawing, upload-image and sliding-window presenters, all sharing
        /// the given message and logging services.
        /// </summary>
        public ApplicationPresenter(IMainFormView mainFormView, IMessageService messageService, ILoggingService loggingService)
        {
            loggingService.Log("Initializing presenters");

            var benchmarkView = mainFormView.BenchmarkView;
            _benchmarkPresenter = new BenchmarkPresenter(benchmarkView, messageService, loggingService);

            var drawingView = mainFormView.DrawingView;
            _drawingPresenter = new DrawingPresenter(drawingView, messageService, loggingService);

            var uploadImageView = mainFormView.UploadImageView;
            _uploadImagePresenter = new UploadImagePresenter(uploadImageView, messageService, loggingService);

            var slidingWindowView = mainFormView.SlidingWindowView;
            _slidingWindowPresenter = new SlidingWindowPresenter(slidingWindowView, messageService, loggingService);
        }
    }
}
26 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Presenters/BenchmarkPresenter.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Linq;
3 | using System.Threading;
4 | using System.Threading.Tasks;
5 | using System.Windows.Forms;
6 | using DigitRecognizer.Core.Data;
7 | using DigitRecognizer.Core.Extensions;
8 | using DigitRecognizer.Core.Utilities;
9 | using DigitRecognizer.MachineLearning.Infrastructure.Models;
10 | using DigitRecognizer.MachineLearning.Providers;
11 | using DigitRecognizer.Presentation.Models;
12 | using DigitRecognizer.Presentation.Services;
13 | using DigitRecognizer.Presentation.Views.Interfaces;
14 |
namespace DigitRecognizer.Presentation.Presenters
{
    /// <summary>
    /// Presenter for the benchmark view: classifies the test set batch by batch on a background
    /// task and reports progress, the classified images and the number of correct predictions.
    /// </summary>
    public class BenchmarkPresenter
    {
        #region Fields

        // Number of images requested per batch from the test-set provider.
        private const int BatchSize = 100;

        // Number of batches processed per benchmark run.
        private const int BatchCount = 100;

        private readonly IBenchmarkView _benchmarkView;
        private readonly IMessageService _messageService;
        private readonly ILoggingService _loggingService;

        // Source of the run in progress; null when no run is active.
        private CancellationTokenSource _cancellationTokenSource;

        #endregion

        #region Ctor

        public BenchmarkPresenter(IBenchmarkView benchmarkView, IMessageService messageService, ILoggingService loggingService)
        {
            _messageService = messageService;
            _loggingService = loggingService;
            _benchmarkView = benchmarkView;

            _benchmarkView.RunBenchmark += OnRunBenchmark;

            _benchmarkView.CancelBenchmark += OnCancelBenchmark;

            _benchmarkView.IsBenchmarkRunning = false;
        }

        #endregion

        #region Methods

        // Handles the view's RunBenchmark event (async void because it is an event handler).
        private async void OnRunBenchmark(object sender, EventArgs e)
        {
            IPredictionModel predictionModel = Global.PredictionModel;

            if (predictionModel == null)
            {
                _messageService.ShowMessage("The prediction model must be loaded first.", "Prediction model", icon: MessageBoxIcon.Information);

                return;
            }

            _benchmarkView.ResetView();

            var cancellationTokenSource = new CancellationTokenSource();
            _cancellationTokenSource = cancellationTokenSource;

            CancellationToken token = cancellationTokenSource.Token;

            token.Register(() => { _benchmarkView.IsBenchmarkRunning = false; });

            // Set before the background task starts, on the thread that raised the event (presumably
            // the UI thread), so an early cancel is not overwritten by the task.
            _benchmarkView.IsBenchmarkRunning = true;

            try
            {
                await Task.Run(() => RunBenchmark(predictionModel, token), token);
            }
            catch (OperationCanceledException)
            {
                // Cancelled before the background task started; OnCancelBenchmark already logged it.
            }
            catch (Exception exception)
            {
                _loggingService.Log(exception);

                _messageService.ShowMessage("An error occurred while running the benchmark. Please try again.", "Benchmark error", icon: MessageBoxIcon.Information);
            }
            finally
            {
                // Only reset shared state if a newer run has not replaced this one in the meantime.
                if (ReferenceEquals(_cancellationTokenSource, cancellationTokenSource))
                {
                    _cancellationTokenSource = null;

                    // Always leave the view in the "not running" state, including after a failure.
                    _benchmarkView.IsBenchmarkRunning = false;
                }

                cancellationTokenSource.Dispose();
            }
        }

        // Runs on a thread-pool thread: classifies BatchCount batches and reports to the view.
        private void RunBenchmark(IPredictionModel model, CancellationToken token)
        {
            _loggingService.Log("Running benchmark has started");

            var provider = new BatchDataProvider(DirectoryHelper.TestLabelsPath, DirectoryHelper.TestImagesPath, BatchSize);

            var correctPredictions = 0;

            for (var batch = 0; batch < BatchCount; batch++)
            {
                if (token.IsCancellationRequested || !_benchmarkView.IsBenchmarkRunning)
                {
                    break;
                }

                MnistImageBatch data = provider.GetData();

                int[] predictions = data.Pixels.Select(model.Predict).Select(x => x.ArgMax()).ToArray();

                // Count the samples whose label equals the prediction at the same index.
                correctPredictions += data.Labels.Where((label, index) => label == predictions[index]).Count();

                _benchmarkView.PerformProgressStep();

                _benchmarkView.DrawGrid(new ImageGridModel(data, predictions));
            }

            _benchmarkView.SetAccuracy(correctPredictions);

            _loggingService.Log("Running benchmark has completed");
        }

        // Handles the view's CancelBenchmark event; does nothing when no run is active.
        private void OnCancelBenchmark(object sender, EventArgs e)
        {
            CancellationTokenSource source = _cancellationTokenSource;

            if (source == null)
            {
                return;
            }

            _loggingService.Log("Running benchmark was canceled");

            source.Cancel();
        }

        #endregion
    }
}
125 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Presenters/DrawingPresenter.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Windows.Forms;
3 | using DigitRecognizer.Core.IO;
4 | using DigitRecognizer.MachineLearning.Infrastructure.Models;
5 | using DigitRecognizer.Presentation.Services;
6 | using DigitRecognizer.Presentation.Views.Interfaces;
7 |
namespace DigitRecognizer.Presentation.Presenters
{
    /// <summary>
    /// Presenter for the drawing view: classifies the user's drawing and clears it on request.
    /// </summary>
    public class DrawingPresenter
    {
        #region Fields

        private readonly IDrawingView _drawingView;
        private readonly IMessageService _messageService;
        private readonly ILoggingService _loggingService;

        #endregion

        #region Ctor

        public DrawingPresenter(IDrawingView drawingView, IMessageService messageService, ILoggingService loggingService)
        {
            _messageService = messageService;
            _loggingService = loggingService;
            _drawingView = drawingView;

            _drawingView.ClassifyDrawing += OnClassifyDrawing;

            _drawingView.ClearDrawing += OnClearDrawing;
        }

        #endregion

        #region Methods

        // Preprocesses the drawing, runs the model on it and passes the prediction to the view.
        private void OnClassifyDrawing(object sender, EventArgs e)
        {
            try
            {
                _loggingService.Log("Classify drawing has started");

                var imagePreprocessor = new ImagePreprocessor();

                double[] pixels = imagePreprocessor.Preprocess(_drawingView.Drawing);

                IPredictionModel predictionModel = Global.PredictionModel;

                // Global.PredictionModel is null when no model file was loaded (e.g. the dialog was cancelled).
                if (predictionModel == null)
                {
                    _messageService.ShowMessage("The prediction model must be loaded first.", "Prediction model", icon: MessageBoxIcon.Information);

                    return;
                }

                double[] prediction = predictionModel.Predict(pixels);

                _drawingView.ProcessPrediction(prediction);

                _loggingService.Log("Classify drawing has completed");
            }
            catch (Exception exception)
            {
                _loggingService.Log(exception);

                _messageService.ShowMessage("An error occurred while classifying the drawing. Please try again.", "Classification error", icon: MessageBoxIcon.Information);
            }
        }

        private void OnClearDrawing(object sender, EventArgs e)
        {
            _drawingView.Clear();
        }

        #endregion
    }
}
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Presenters/SlidingWindowPresenter.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Drawing;
3 | using System.Windows.Forms;
4 | using DigitRecognizer.Core.Data;
5 | using DigitRecognizer.Core.Extensions;
6 | using DigitRecognizer.Core.IO;
7 | using DigitRecognizer.Core.Utilities;
8 | using DigitRecognizer.MachineLearning.Infrastructure.Models;
9 | using DigitRecognizer.Presentation.Services;
10 | using DigitRecognizer.Presentation.Views.Interfaces;
11 |
namespace DigitRecognizer.Presentation.Presenters
{
    /// <summary>
    /// Presenter for the sliding-window view: classifies windows of the drawing and draws a bounding
    /// box wherever the winning class reaches <see cref="ConfidenceThreshold"/>.
    /// </summary>
    public class SlidingWindowPresenter
    {
        #region Fields

        // Minimum value of the winning class for a window to be reported.
        private const double ConfidenceThreshold = 0.95;

        // Third argument of ImageUtilities.SlidingWindow; presumably the step between window positions — confirm.
        private const int WindowStep = 112;

        private readonly ISlidingWindowView _slidingWindowView;
        private readonly IMessageService _messageService;
        private readonly ILoggingService _loggingService;

        private static readonly Size[] WindowSizes =
        {
            new Size(280, 280),
            //new Size(140, 140),
            //new Size(112, 112),
            //new Size(84, 84),
            //new Size(56, 56)
        };

        #endregion

        #region Ctor

        public SlidingWindowPresenter(ISlidingWindowView slidingWindowView, IMessageService messageService, ILoggingService loggingService)
        {
            _messageService = messageService;
            _loggingService = loggingService;
            _slidingWindowView = slidingWindowView;

            _slidingWindowView.ClassifyDrawing += OnClassifyDrawing;

            _slidingWindowView.ClearDrawing += OnClearDrawing;
        }

        #endregion

        #region Methods

        // Classifies every window of every configured size and draws the confident ones.
        private void OnClassifyDrawing(object sender, EventArgs e)
        {
            try
            {
                _loggingService.Log("Classify drawing has started");

                var imagePreprocessor = new ImagePreprocessor();

                IPredictionModel predictionModel = Global.PredictionModel;

                // Without this check every window would fail with a NullReferenceException that is only logged.
                if (predictionModel == null)
                {
                    _messageService.ShowMessage("The prediction model must be loaded first.", "Prediction model", icon: MessageBoxIcon.Information);

                    return;
                }

                Image img = _slidingWindowView.Drawing;

                foreach (Size windowSize in WindowSizes)
                {
                    foreach (BoundingBox boundingBox in ImageUtilities.SlidingWindow(img, windowSize, WindowStep))
                    {
                        try
                        {
                            double[] pixels = imagePreprocessor.Preprocess(boundingBox.Image);

                            double[] prediction = predictionModel.Predict(pixels);

                            int predicted = prediction.ArgMax();
                            double confidence = prediction[predicted];

                            if (confidence >= ConfidenceThreshold)
                            {
                                _slidingWindowView.DrawBoundingBox(boundingBox, predicted, confidence);
                            }
                        }
                        catch (Exception exception)
                        {
                            // A window that cannot be classified is logged and skipped; the others still run.
                            _loggingService.Log(exception);
                        }
                    }
                }

                _loggingService.Log("Classify drawing has completed");
            }
            catch (Exception exception)
            {
                _loggingService.Log(exception);

                _messageService.ShowMessage("An error occurred while classifying the drawing. Please try again.", "Classification error", icon: MessageBoxIcon.Information);
            }
        }

        private void OnClearDrawing(object sender, EventArgs e)
        {
            _slidingWindowView.Clear();
        }

        #endregion
    }
}
106 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Presenters/UploadImagePresenter.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Windows.Forms;
3 | using DigitRecognizer.Core.IO;
4 | using DigitRecognizer.MachineLearning.Infrastructure.Models;
5 | using DigitRecognizer.Presentation.Services;
6 | using DigitRecognizer.Presentation.Views.Interfaces;
7 |
namespace DigitRecognizer.Presentation.Presenters
{
    /// <summary>
    /// Presenter for the upload-image view: classifies the uploaded image and clears it on request.
    /// </summary>
    public class UploadImagePresenter
    {
        #region Fields

        private readonly IUploadImageView _uploadImageView;
        private readonly IMessageService _messageService;
        private readonly ILoggingService _loggingService;

        #endregion

        #region Ctor

        public UploadImagePresenter(IUploadImageView uploadImageView, IMessageService messageService, ILoggingService loggingService)
        {
            _uploadImageView = uploadImageView;
            _messageService = messageService;
            _loggingService = loggingService;

            _uploadImageView.ClassifyImage += OnClassifyImage;

            _uploadImageView.ClearImage += OnClearImage;
        }

        #endregion

        #region Methods

        // Preprocesses the uploaded image, runs the model on it and passes the prediction to the view.
        // Missing image and missing model are checked explicitly instead of being inferred from a
        // NullReferenceException, which previously mislabelled any null failure as "no image".
        private void OnClassifyImage(object sender, EventArgs e)
        {
            try
            {
                _loggingService.Log("Classify drawing has started");

                var image = _uploadImageView.Image;

                if (image == null)
                {
                    _loggingService.Log("Classification requested without an uploaded image");

                    _messageService.ShowMessage("No image was uploaded. Please upload an image and try again.", "Upload error", icon: MessageBoxIcon.Information);

                    return;
                }

                var imagePreprocessor = new ImagePreprocessor();

                double[] pixels = imagePreprocessor.Preprocess(image);

                IPredictionModel predictionModel = Global.PredictionModel;

                if (predictionModel == null)
                {
                    _messageService.ShowMessage("The prediction model must be loaded first.", "Prediction model", icon: MessageBoxIcon.Information);

                    return;
                }

                double[] prediction = predictionModel.Predict(pixels);

                _uploadImageView.ProcessPrediction(prediction);

                _loggingService.Log("Classify drawing has completed");
            }
            catch (Exception exception)
            {
                _loggingService.Log(exception);

                _messageService.ShowMessage("An error occurred while classifying the image. Please try again.", "Classification error", icon: MessageBoxIcon.Information);
            }
        }

        private void OnClearImage(object sender, EventArgs e)
        {
            _uploadImageView.Clear();
        }

        #endregion
    }
}
78 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Program.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Windows.Forms;
3 | using DigitRecognizer.Presentation.Infrastructure;
4 | using DigitRecognizer.Presentation.Views.Implementations;
5 |
namespace DigitRecognizer.Presentation
{
    internal static class Program
    {
        /// <summary>
        /// The main entry point for the application.
        /// </summary>
        /// <remarks>
        /// Registers dependencies, installs the global exception handlers, configures WinForms
        /// rendering, then resolves the main form, attaches the presenters and runs the message loop.
        /// The statement order matters: WinForms requires SetUnhandledExceptionMode and
        /// SetCompatibleTextRenderingDefault to be called before the first control/window is created.
        /// </remarks>
        [STAThread]
        private static void Main()
        {
            Startup.RegisterDependencies();

            // Route UI-thread exceptions to Application.ThreadException instead of crashing.
            Application.SetUnhandledExceptionMode(UnhandledExceptionMode.CatchException);

            // UI-thread exceptions: the handler runs and the application can continue.
            Application.ThreadException += ExceptionHandlers.ApplicationOnThreadException;

            // Exceptions not handled elsewhere: the handler runs, then the application terminates.
            AppDomain.CurrentDomain.UnhandledException += ExceptionHandlers.CurrentDomainOnUnhandledException;

            Application.EnableVisualStyles();

            Application.SetCompatibleTextRenderingDefault(false);

            // Presumably resolves the main form type registered in Startup.RegisterDependencies — confirm.
            var mainForm = DependencyResolver.Resolve();

            Startup.SetupPresenters(mainForm);

            Application.Run(mainForm);
        }
    }
}
38 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Properties/AssemblyInfo.cs:
--------------------------------------------------------------------------------
1 | using System.Reflection;
2 | using System.Runtime.InteropServices;
3 |
// General information about this assembly. These attribute values are embedded
// in the compiled assembly's metadata; change them to change that information.
[assembly: AssemblyTitle("DigitRecognizer.Presentation")]
[assembly: AssemblyDescription("")]
[assembly: AssemblyConfiguration("")]
[assembly: AssemblyCompany("")]
[assembly: AssemblyProduct("DigitRecognizer.Presentation")]
[assembly: AssemblyCopyright("Copyright © 2018")]
[assembly: AssemblyTrademark("")]
[assembly: AssemblyCulture("")]

// ComVisible(false) hides the types in this assembly from COM components.
// To use a single type from COM, set the ComVisible attribute to true on that type.
[assembly: ComVisible(false)]

// The ID of the type library if this project is exposed to COM.
[assembly: Guid("6caf69ea-9c0c-4571-a808-268c159bdb5c")]

// Version information has four parts: Major.Minor.Build.Revision.
// Build and Revision can be generated automatically with '*', for example:
// [assembly: AssemblyVersion("1.0.*")]
[assembly: AssemblyVersion("1.0.0.0")]
[assembly: AssemblyFileVersion("1.0.0.0")]
36 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Properties/Resources.Designer.cs:
--------------------------------------------------------------------------------
1 | //------------------------------------------------------------------------------
2 | //
3 | // This code was generated by a tool.
4 | // Runtime Version:4.0.30319.42000
5 | //
6 | // Changes to this file may cause incorrect behavior and will be lost if
7 | // the code is regenerated.
8 | //
9 | //------------------------------------------------------------------------------
10 |
namespace DigitRecognizer.Presentation.Properties {
    using System;
    
    
    /// <summary>
    ///   A strongly-typed resource class, for looking up localized strings, etc.
    /// </summary>
    // This class was auto-generated by the StronglyTypedResourceBuilder
    // class via a tool like ResGen or Visual Studio.
    // To add or remove a member, edit your .ResX file then rerun ResGen
    // with the /str option, or rebuild your VS project.
    // NOTE: this file is regenerated from the .resx; manual edits (including these comments) are lost.
    [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "17.0.0.0")]
    [global::System.Diagnostics.DebuggerNonUserCodeAttribute()]
    [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()]
    internal class Resources {
        
        private static global::System.Resources.ResourceManager resourceMan;
        
        private static global::System.Globalization.CultureInfo resourceCulture;
        
        [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")]
        internal Resources() {
        }
        
        /// <summary>
        ///   Returns the cached ResourceManager instance used by this class.
        ///   It is created on first access for the "DigitRecognizer.Presentation.Properties.Resources" resources.
        /// </summary>
        [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)]
        internal static global::System.Resources.ResourceManager ResourceManager {
            get {
                if (object.ReferenceEquals(resourceMan, null)) {
                    global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("DigitRecognizer.Presentation.Properties.Resources", typeof(Resources).Assembly);
                    resourceMan = temp;
                }
                return resourceMan;
            }
        }
        
        /// <summary>
        ///   Overrides the current thread's CurrentUICulture property for all
        ///   resource lookups using this strongly typed resource class.
        /// </summary>
        [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)]
        internal static global::System.Globalization.CultureInfo Culture {
            get {
                return resourceCulture;
            }
            set {
                resourceCulture = value;
            }
        }
    }
}
64 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Properties/Settings.Designer.cs:
--------------------------------------------------------------------------------
1 | //------------------------------------------------------------------------------
2 | //
3 | // This code was generated by a tool.
4 | // Runtime Version:4.0.30319.42000
5 | //
6 | // Changes to this file may cause incorrect behavior and will be lost if
7 | // the code is regenerated.
8 | //
9 | //------------------------------------------------------------------------------
10 |
namespace DigitRecognizer.Presentation.Properties {
    
    
    /// <summary>
    ///   Generated application settings class. This file is regenerated by the settings
    ///   designer, so manual edits (including these comments) are lost.
    /// </summary>
    [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()]
    [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.Editors.SettingsDesigner.SettingsSingleFileGenerator", "17.9.0.0")]
    internal sealed partial class Settings : global::System.Configuration.ApplicationSettingsBase {
        
        // Single instance wrapped by ApplicationSettingsBase.Synchronized.
        private static Settings defaultInstance = ((Settings)(global::System.Configuration.ApplicationSettingsBase.Synchronized(new Settings())));
        
        /// <summary>
        ///   Gets the shared, synchronized settings instance.
        /// </summary>
        public static Settings Default {
            get {
                return defaultInstance;
            }
        }
    }
}
27 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Properties/Settings.settings:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Services/ILoggingService.cs:
--------------------------------------------------------------------------------
1 | using System;
2 |
namespace DigitRecognizer.Presentation.Services
{
    /// <summary>
    /// Contract for recording diagnostic output produced by the presentation layer.
    /// </summary>
    public interface ILoggingService
    {
        /// <summary>
        /// Records a free-form text message.
        /// </summary>
        /// <param name="message">The text to record.</param>
        void Log(string message);

        /// <summary>
        /// Records the details of an exception.
        /// </summary>
        /// <param name="e">The exception to record.</param>
        void Log(Exception e);
    }
}
12 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Services/IMessageService.cs:
--------------------------------------------------------------------------------
1 | using System.Windows.Forms;
2 |
namespace DigitRecognizer.Presentation.Services
{
    /// <summary>
    /// Contract for presenting a message box-style notification.
    /// </summary>
    public interface IMessageService
    {
        /// <summary>
        /// Shows <paramref name="text"/> with an optional caption, button set and icon.
        /// </summary>
        /// <param name="text">The message body.</param>
        /// <param name="caption">The title; empty by default.</param>
        /// <param name="buttons">The buttons to offer; <see cref="MessageBoxButtons.OK"/> by default.</param>
        /// <param name="icon">The icon to display; <see cref="MessageBoxIcon.None"/> by default.</param>
        void ShowMessage(
            string text,
            string caption = "",
            MessageBoxButtons buttons = MessageBoxButtons.OK,
            MessageBoxIcon icon = MessageBoxIcon.None);
    }
}
10 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Services/LoggingService.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.IO;
3 | using System.Text;
4 |
namespace DigitRecognizer.Presentation.Services
{
    /// <summary>
    /// <see cref="ILoggingService"/> implementation that appends timestamped entries to a plain-text
    /// file named <c>log.txt</c>, resolved relative to the process's current working directory.
    /// </summary>
    public class LoggingService : ILoggingService
    {
        private const string LogFileName = "log.txt";

        /// <summary>
        /// Logs an exception followed by every exception in its <see cref="Exception.InnerException"/> chain.
        /// Each chain level is written exactly once.
        /// </summary>
        /// <param name="e">The exception to log. A <c>null</c> value is ignored.</param>
        public void Log(Exception e)
        {
            if (e is null)
            {
                return;
            }

            var builder = new StringBuilder();
            builder.AppendLine($"{DateTime.Now} - {e.Message}");

            // Describe each level by its own type, message and stack trace. Exception.ToString() is
            // deliberately not used here: it already embeds the inner exceptions, so appending it per
            // level would repeat the same trace several times.
            bool isInner = false;
            for (Exception current = e; current != null; current = current.InnerException)
            {
                if (isInner)
                {
                    builder.Append(" ---> ");
                }

                builder.AppendLine($"{current.GetType().FullName}: {current.Message}");

                // StackTrace is null for exceptions that were never thrown.
                if (!string.IsNullOrEmpty(current.StackTrace))
                {
                    builder.AppendLine(current.StackTrace);
                }

                isInner = true;
            }

            // Blank line separating entries.
            builder.AppendLine();

            WriteEntry(builder.ToString());
        }

        /// <summary>
        /// Logs a timestamped text message.
        /// </summary>
        /// <param name="message">The message to log. Null, empty or whitespace-only messages are ignored.</param>
        public void Log(string message)
        {
            if (string.IsNullOrWhiteSpace(message))
            {
                return;
            }

            // Entry line followed by a blank separator line.
            WriteEntry($"{DateTime.Now} - {message}{Environment.NewLine}{Environment.NewLine}");
        }

        /// <summary>
        /// Appends <paramref name="text"/> to the log file, creating the file if it does not exist.
        /// </summary>
        private static void WriteEntry(string text)
        {
            using (var writer = new StreamWriter(LogFileName, true))
            {
                writer.Write(text);
            }
        }
    }
}
50 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Services/MessageService.cs:
--------------------------------------------------------------------------------
1 | using System.Windows.Forms;
2 |
namespace DigitRecognizer.Presentation.Services
{
    /// <summary>
    /// <see cref="IMessageService"/> implementation backed by the WinForms <see cref="MessageBox"/>.
    /// </summary>
    public class MessageService : IMessageService
    {
        /// <summary>
        /// Displays a modal message box; the <see cref="DialogResult"/> it returns is discarded.
        /// </summary>
        public void ShowMessage(
            string text,
            string caption = "",
            MessageBoxButtons buttons = MessageBoxButtons.OK,
            MessageBoxIcon icon = MessageBoxIcon.None) =>
            MessageBox.Show(text, caption, buttons, icon);
    }
}
13 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Startup.cs:
--------------------------------------------------------------------------------
1 | using DigitRecognizer.Presentation.Infrastructure;
2 | using DigitRecognizer.Presentation.Presenters;
3 | using DigitRecognizer.Presentation.Services;
4 | using DigitRecognizer.Presentation.Views.Implementations;
5 | using DigitRecognizer.Presentation.Views.Interfaces;
6 |
namespace DigitRecognizer.Presentation
{
    /// <summary>
    /// Wires up the presentation layer: registers dependencies with <c>DependencyResolver</c>
    /// and creates the application presenter for the main form.
    /// </summary>
    public static class Startup
    {
        /// <summary>
        /// Registers the application's dependencies with <c>DependencyResolver</c>.
        /// </summary>
        public static void RegisterDependencies()
        {
            // NOTE(review): the generic type arguments of these three Register calls appear to have been
            // stripped in this copy of the file (the calls are textually identical); confirm against the
            // original source.
            DependencyResolver.Register();

            DependencyResolver.Register();

            DependencyResolver.Register();
        }

        /// <summary>
        /// Resolves the message and logging services and creates an <c>ApplicationPresenter</c> bound to
        /// <paramref name="mainFormView"/>.
        /// </summary>
        /// <param name="mainFormView">The main form the presenter drives.</param>
        public static void SetupPresenters(IMainFormView mainFormView)
        {
            // NOTE(review): Resolve's type arguments also appear stripped here; presumably
            // IMessageService and ILoggingService, judging by the variable names — confirm.
            var messageService = DependencyResolver.Resolve();
            var loggingService = DependencyResolver.Resolve();

            // The presenter instance is discarded; presumably it stays alive through the view event
            // subscriptions it makes in its constructor — verify in ApplicationPresenter.
            var _ = new ApplicationPresenter(mainFormView, messageService, loggingService);
        }
    }
}
29 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Views/Implementations/MainForm.cs:
--------------------------------------------------------------------------------
1 | using System.Collections.Generic;
2 | using System.Linq;
3 | using System.Windows.Forms;
4 | using DigitRecognizer.Presentation.Views.Interfaces;
5 |
namespace DigitRecognizer.Presentation.Views.Implementations
{
    /// <summary>
    /// Main application window. Hosts the benchmark, drawing, upload-image and sliding-window views as
    /// fill-docked child controls and switches between them from the menu click handlers.
    /// </summary>
    public partial class MainForm : Form, IMainFormView
    {
        #region Fields

        private BenchmarkView _benchmarkView;
        private DrawingView _drawingView;
        private UploadImageView _uploadImageView;
        private SlidingWindowView _slidingWindowView;
        // NOTE(review): the element type of this list appears to have been stripped in this copy of the
        // file; given the items added in FillViewsCollection it is presumably List<IView> — confirm.
        private List _views;

        #endregion

        #region Ctor

        /// <summary>
        /// Creates the form and its child views, leaving the benchmark view as the visible child.
        /// </summary>
        public MainForm()
        {
            InitializeComponent();

            // Must run after InitializeComponent so the child views are added to the initialized form.
            InitializeViews();

            // Hides the form itself (HideView calls Hide()); presumably it is made visible later by the
            // code that runs it (e.g. Application.Run) — confirm in Program.
            HideView();
        }

        #endregion

        #region Properties

        /// <summary>Gets the benchmark view.</summary>
        public IBenchmarkView BenchmarkView => _benchmarkView;

        /// <summary>Gets the drawing view.</summary>
        public IDrawingView DrawingView => _drawingView;

        /// <summary>Gets the upload image view.</summary>
        public IUploadImageView UploadImageView => _uploadImageView;

        /// <summary>Gets the sliding window view.</summary>
        public ISlidingWindowView SlidingWindowView => _slidingWindowView;

        #endregion

        #region Methods

        /// <summary>
        /// Creates the child controls, then builds the views list. The order matters: the list is built
        /// from the fields that FillControlsCollection assigns.
        /// </summary>
        private void InitializeViews()
        {
            FillControlsCollection();

            FillViewsCollection();
        }

        /// <summary>
        /// Instantiates the four child views, each docked to fill the form, and adds them to the form's
        /// Controls collection.
        /// </summary>
        private void FillControlsCollection()
        {
            _benchmarkView = new BenchmarkView { Dock = DockStyle.Fill };
            _drawingView = new DrawingView { Dock = DockStyle.Fill };
            _uploadImageView = new UploadImageView { Dock = DockStyle.Fill };
            _slidingWindowView = new SlidingWindowView { Dock = DockStyle.Fill, Padding = new Padding(10)};

            Control[] controls =
            {
                _benchmarkView,
                _drawingView,
                _uploadImageView,
                _slidingWindowView
            };

            Controls.AddRange(controls);
        }

        /// <summary>
        /// Builds the list of views (the four child views plus this form), hides all of them, then shows
        /// the first entry, which is the benchmark view.
        /// </summary>
        private void FillViewsCollection()
        {
            // NOTE(review): type argument appears stripped here as well (see the _views field).
            _views = new List
            {
                _benchmarkView,
                _drawingView,
                _uploadImageView,
                _slidingWindowView,
                this
            };

            HideViews();

            _views.First().ShowView();
        }

        /// <summary>
        /// Calls HideView on every registered view. Because this form is in the list, the form is hidden too.
        /// </summary>
        private void HideViews() => _views.ForEach(x => x.HideView());

        #endregion

        #region Event handlers

        private void BenchmarkToolStripMenuItem_Click(object sender, System.EventArgs e)
        {
            ToolstripMenuItem_Click_DisplayView(_benchmarkView);
        }

        private void DrawingToolStripMenuItem_Click(object sender, System.EventArgs e)
        {
            ToolstripMenuItem_Click_DisplayView(_drawingView);
        }

        private void UploadToolStripMenuItem_Click(object sender, System.EventArgs e)
        {
            ToolstripMenuItem_Click_DisplayView(_uploadImageView);
        }

        private void SlidingWindowToolStripMenuItem_Click(object sender, System.EventArgs e)
        {
            ToolstripMenuItem_Click_DisplayView(_slidingWindowView);
        }

        /// <summary>
        /// Shows the selected view. Previously shown views are not hidden here; the selected view is
        /// expected to cover them (all child views are fill-docked).
        /// </summary>
        /// <param name="view">The view to display.</param>
        private static void ToolstripMenuItem_Click_DisplayView(IView view)
        {
            view.ShowView();
        }

        #endregion

        #region IView implementation

        /// <summary>
        /// Brings the form to the front of the z-order and shows it.
        /// </summary>
        public void ShowView()
        {
            BringToFront();

            Show();
        }

        /// <summary>
        /// Hides the form.
        /// </summary>
        public void HideView()
        {
            Hide();
        }

        #endregion
    }
}
138 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Views/Implementations/SlidingWindowView.Designer.cs:
--------------------------------------------------------------------------------
namespace DigitRecognizer.Presentation.Views.Implementations
{
    // NOTE: Windows Forms designer-generated partial class; regenerate through the designer rather
    // than editing InitializeComponent by hand.
    partial class SlidingWindowView
    {
        /// <summary>
        /// Required designer variable.
        /// </summary>
        private System.ComponentModel.IContainer components = null;

        /// <summary>
        /// Clean up any resources being used.
        /// </summary>
        /// <param name="disposing">true if managed resources should be disposed; otherwise, false.</param>
        protected override void Dispose(bool disposing)
        {
            if (disposing && (components != null))
            {
                components.Dispose();
            }
            base.Dispose(disposing);
        }

        #region Component Designer generated code

        /// <summary>
        /// Required method for Designer support - do not modify
        /// the contents of this method with the code editor.
        /// </summary>
        private void InitializeComponent()
        {
            this.panelDrawing = new System.Windows.Forms.Panel();
            this.btnClearDrawing = new System.Windows.Forms.Button();
            this.btnClassifyDrawing = new System.Windows.Forms.Button();
            this.SuspendLayout();
            // 
            // panelDrawing: bordered panel docked to the top edge of the control.
            // 
            this.panelDrawing.BorderStyle = System.Windows.Forms.BorderStyle.FixedSingle;
            this.panelDrawing.Dock = System.Windows.Forms.DockStyle.Top;
            this.panelDrawing.Location = new System.Drawing.Point(0, 0);
            this.panelDrawing.Name = "panelDrawing";
            this.panelDrawing.Size = new System.Drawing.Size(1178, 628);
            this.panelDrawing.TabIndex = 0;
            // 
            // btnClearDrawing: anchored bottom-left, placed to the right of btnClassifyDrawing.
            // 
            this.btnClearDrawing.Anchor = ((System.Windows.Forms.AnchorStyles)((System.Windows.Forms.AnchorStyles.Bottom | System.Windows.Forms.AnchorStyles.Left)));
            this.btnClearDrawing.Font = new System.Drawing.Font("Microsoft Sans Serif", 9.75F, System.Drawing.FontStyle.Regular, System.Drawing.GraphicsUnit.Point, ((byte)(0)));
            this.btnClearDrawing.Location = new System.Drawing.Point(136, 634);
            this.btnClearDrawing.Name = "btnClearDrawing";
            this.btnClearDrawing.Size = new System.Drawing.Size(127, 30);
            this.btnClearDrawing.TabIndex = 24;
            this.btnClearDrawing.Text = "Clear drawing";
            this.btnClearDrawing.UseVisualStyleBackColor = true;
            // 
            // btnClassifyDrawing: anchored bottom-left, directly below the drawing panel.
            // 
            this.btnClassifyDrawing.Anchor = ((System.Windows.Forms.AnchorStyles)((System.Windows.Forms.AnchorStyles.Bottom | System.Windows.Forms.AnchorStyles.Left)));
            this.btnClassifyDrawing.Font = new System.Drawing.Font("Microsoft Sans Serif", 9.75F, System.Drawing.FontStyle.Regular, System.Drawing.GraphicsUnit.Point, ((byte)(0)));
            this.btnClassifyDrawing.Location = new System.Drawing.Point(3, 634);
            this.btnClassifyDrawing.Name = "btnClassifyDrawing";
            this.btnClassifyDrawing.Size = new System.Drawing.Size(127, 30);
            this.btnClassifyDrawing.TabIndex = 23;
            this.btnClassifyDrawing.Text = "Classify drawing";
            this.btnClassifyDrawing.UseVisualStyleBackColor = true;
            // 
            // SlidingWindowView
            // 
            this.AutoScaleDimensions = new System.Drawing.SizeF(6F, 13F);
            this.AutoScaleMode = System.Windows.Forms.AutoScaleMode.Font;
            this.BackColor = System.Drawing.Color.White;
            this.Controls.Add(this.btnClearDrawing);
            this.Controls.Add(this.btnClassifyDrawing);
            this.Controls.Add(this.panelDrawing);
            this.Name = "SlidingWindowView";
            this.Size = new System.Drawing.Size(1178, 667);
            this.ResumeLayout(false);

        }

        #endregion

        private System.Windows.Forms.Panel panelDrawing;
        private System.Windows.Forms.Button btnClearDrawing;
        private System.Windows.Forms.Button btnClassifyDrawing;
    }
}
88 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Views/Implementations/UploadImageView.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Drawing;
3 | using System.IO;
4 | using System.Windows.Forms;
5 | using DigitRecognizer.Core.Utilities;
6 | using DigitRecognizer.Presentation.Infrastructure;
7 | using DigitRecognizer.Presentation.Views.Interfaces;
8 |
namespace DigitRecognizer.Presentation.Views.Implementations
{
    /// <summary>
    /// User control that lets the user pick an image file, displays it resized to the upload panel and
    /// shows the prediction passed to <see cref="ProcessPrediction"/>.
    /// </summary>
    public partial class UploadImageView : UserControl, IUploadImageView
    {
        #region Fields

        // The uploaded image, already resized to the upload panel; null when nothing is loaded.
        private Image _image;

        // Reused for every upload; released when this control is disposed.
        private readonly OpenFileDialog _openFileDialog;

        #endregion

        #region Ctor

        /// <summary>
        /// Initializes the designer controls, the file dialog and the event wiring.
        /// </summary>
        public UploadImageView()
        {
            InitializeComponent();

            Clear();

            _openFileDialog = new OpenFileDialog
            {
                InitialDirectory = Environment.GetFolderPath(Environment.SpecialFolder.Desktop),
                Filter = @"Image|*.bmp;*.png;*.jpeg;*.jpg"
            };

            // The dialog is created here rather than added to a components container, so release it
            // explicitly when this control is disposed.
            Disposed += (disposedSender, disposedArgs) => _openFileDialog.Dispose();

            PanelDoubleBuffering.Enable(uploadPanel);

            btnClassifyImg.Click += BtnClassifyImgOnClick;

            btnClearImg.Click += BtnClearImgOnClick;

            btnUploadImg.Click += BtnUploadImgOnClick;

            uploadPanel.Paint += UploadPanelOnPaint;
        }

        #endregion

        #region Properties

        /// <summary>
        /// Gets the currently loaded image, resized to the upload panel.
        /// </summary>
        /// <exception cref="NullReferenceException">No image has been uploaded, or it was cleared.</exception>
        public Image Image => _image ?? throw new NullReferenceException(nameof(_image));

        #endregion

        #region Events

        /// <summary>Raised when the classify button is clicked.</summary>
        public event EventHandler ClassifyImage;

        /// <summary>Raised when the clear button is clicked.</summary>
        public event EventHandler ClearImage;

        #endregion

        #region Methods

        private void BtnClassifyImgOnClick(object sender, EventArgs e)
        {
            ClassifyImage?.Invoke(this, EventArgs.Empty);
        }

        private void BtnClearImgOnClick(object sender, EventArgs e)
        {
            ClearImage?.Invoke(this, EventArgs.Empty);
        }

        /// <summary>
        /// Lets the user choose an image file, then stores a copy resized to the upload panel, clears the
        /// previous prediction and repaints.
        /// </summary>
        private void BtnUploadImgOnClick(object sender, EventArgs e)
        {
            DialogResult dlgResult = _openFileDialog.ShowDialog();

            if (dlgResult != DialogResult.OK)
            {
                return;
            }

            // Image.FromFile keeps the source file locked until the image is disposed. Only the resized
            // copy is kept, so release the raw image as soon as that copy exists.
            Image rawImage = Image.FromFile(_openFileDialog.FileName);
            Image resized = null;
            try
            {
                resized = ImageUtilities.Resize(rawImage, uploadPanel.Width, uploadPanel.Height);
            }
            finally
            {
                // Guard in case Resize ever hands back the very instance it was given.
                if (!ReferenceEquals(resized, rawImage))
                {
                    rawImage.Dispose();
                }
            }

            _image = resized;

            predictionPane.Clear();

            uploadPanel.Invalidate();
        }

        /// <summary>
        /// Draws the loaded image at the panel's top-left corner; draws nothing when no image is loaded.
        /// </summary>
        private void UploadPanelOnPaint(object sender, PaintEventArgs e)
        {
            if (_image is null)
            {
                return;
            }

            e.Graphics.DrawImage(_image, Point.Empty);
        }

        /// <summary>
        /// Re-centers the upload panel and lays out the instructions label, buttons and prediction pane
        /// relative to it.
        /// </summary>
        protected override void OnResize(EventArgs e)
        {
            uploadPanel.Top = (Height - uploadPanel.Height) / 2;
            uploadPanel.Left = (Width - (uploadPanel.Width + predictionPane.Width - 35)) / 2;

            lblInstructions.Left = uploadPanel.Left - 5;
            lblInstructions.Top = uploadPanel.Top - lblInstructions.Height - 5;

            int top = uploadPanel.Bottom + 5;

            // Classify and clear buttons sit side by side under the panel's left edge.
            btnClassifyImg.Left = uploadPanel.Left;
            btnClassifyImg.Top = top;

            btnClearImg.Left = btnClassifyImg.Right + 5;
            btnClearImg.Top = top;

            // Upload button is right-aligned with the panel.
            btnUploadImg.Left = uploadPanel.Right - btnUploadImg.Width;
            btnUploadImg.Top = top;

            predictionPane.Left = uploadPanel.Right + 40;
            predictionPane.Top = uploadPanel.Top;

            base.OnResize(e);
        }

        #endregion

        #region IUploadImageView implementation

        /// <summary>
        /// Forgets the loaded image, clears the prediction pane and repaints the upload panel.
        /// </summary>
        public void Clear()
        {
            // NOTE(review): the previous image is not disposed here because consumers of the Image
            // property may still hold the reference; confirm ownership before adding disposal.
            _image = null;

            predictionPane.Clear();

            uploadPanel.Invalidate();
        }

        /// <summary>
        /// Forwards a prediction vector to the prediction pane for display.
        /// </summary>
        /// <param name="prediction">The prediction values to display.</param>
        public void ProcessPrediction(double[] prediction)
        {
            predictionPane.ProcessPrediction(prediction);
        }

        /// <summary>
        /// Brings the control to the front of the z-order and shows it.
        /// </summary>
        public void ShowView()
        {
            BringToFront();

            Show();
        }

        /// <summary>
        /// Hides the control.
        /// </summary>
        public void HideView()
        {
            Hide();
        }

        #endregion
    }
}
156 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Views/Interfaces/IBenchmarkView.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using DigitRecognizer.Presentation.Models;
3 |
namespace DigitRecognizer.Presentation.Views.Interfaces
{
    /// <summary>
    /// View contract for the benchmark screen.
    /// </summary>
    public interface IBenchmarkView : IView
    {
        /// <summary>Gets or sets whether a benchmark is currently in progress.</summary>
        bool IsBenchmarkRunning { get; set; }

        /// <summary>Raised to request that a benchmark run be started.</summary>
        event EventHandler RunBenchmark;

        /// <summary>Raised to request that the running benchmark be cancelled.</summary>
        event EventHandler CancelBenchmark;

        /// <summary>Advances the view's progress indicator by one step.</summary>
        void PerformProgressStep();

        /// <summary>Displays the benchmark accuracy (units/scale defined by the caller — confirm).</summary>
        /// <param name="accuracy">The accuracy value to display.</param>
        void SetAccuracy(int accuracy);

        /// <summary>Renders the supplied image grid.</summary>
        /// <param name="model">The grid to render.</param>
        void DrawGrid(ImageGridModel model);

        /// <summary>Returns the view to its initial state.</summary>
        void ResetView();
    }
}
19 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Views/Interfaces/IDrawingView.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Drawing;
3 |
namespace DigitRecognizer.Presentation.Views.Interfaces
{
    /// <summary>
    /// View contract for the drawing screen.
    /// </summary>
    public interface IDrawingView : IView
    {
        /// <summary>Gets the current drawing as an image.</summary>
        Image Drawing { get; }

        /// <summary>Raised to request classification of the current drawing.</summary>
        event EventHandler ClassifyDrawing;

        /// <summary>Raised to request that the drawing be cleared.</summary>
        event EventHandler ClearDrawing;

        /// <summary>Clears the view.</summary>
        void Clear();

        /// <summary>Displays a prediction computed for the drawing.</summary>
        /// <param name="prediction">The prediction values (presumably one score per digit — confirm).</param>
        void ProcessPrediction(double[] prediction);
    }
}
17 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Views/Interfaces/IMainFormView.cs:
--------------------------------------------------------------------------------
namespace DigitRecognizer.Presentation.Views.Interfaces
{
    /// <summary>
    /// Contract for the main window, which exposes each of the application's child views.
    /// </summary>
    public interface IMainFormView : IView
    {
        /// <summary>Gets the benchmark view.</summary>
        IBenchmarkView BenchmarkView { get; }

        /// <summary>Gets the drawing view.</summary>
        IDrawingView DrawingView { get; }

        /// <summary>Gets the upload image view.</summary>
        IUploadImageView UploadImageView { get; }

        /// <summary>Gets the sliding window view.</summary>
        ISlidingWindowView SlidingWindowView { get; }
    }
}
29 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Views/Interfaces/ISlidingWindowView.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Drawing;
3 | using DigitRecognizer.Core.Data;
4 |
namespace DigitRecognizer.Presentation.Views.Interfaces
{
    /// <summary>
    /// View contract for the sliding-window screen.
    /// </summary>
    public interface ISlidingWindowView : IView
    {
        /// <summary>Gets the current drawing as an image.</summary>
        Image Drawing { get; }

        /// <summary>Raised to request classification of the current drawing.</summary>
        event EventHandler ClassifyDrawing;

        /// <summary>Raised to request that the drawing be cleared.</summary>
        event EventHandler ClearDrawing;

        /// <summary>
        /// Draws <paramref name="boundingBox"/> annotated with a prediction and its score.
        /// </summary>
        /// <param name="boundingBox">The region to outline.</param>
        /// <param name="prediction">The predicted class for the region (presumably a digit — confirm).</param>
        /// <param name="predictionAccuracy">The score associated with <paramref name="prediction"/>.</param>
        void DrawBoundingBox(BoundingBox boundingBox, int prediction, double predictionAccuracy);

        /// <summary>Clears the view.</summary>
        void Clear();
    }
}
18 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Views/Interfaces/IUploadImageView.cs:
--------------------------------------------------------------------------------
1 | using System;
2 | using System.Drawing;
3 |
namespace DigitRecognizer.Presentation.Views.Interfaces
{
    /// <summary>
    /// View contract for the image-upload screen.
    /// </summary>
    public interface IUploadImageView : IView
    {
        /// <summary>Gets the uploaded image.</summary>
        Image Image { get; }

        /// <summary>Raised to request classification of the uploaded image.</summary>
        event EventHandler ClassifyImage;

        /// <summary>Raised to request that the uploaded image be cleared.</summary>
        event EventHandler ClearImage;

        /// <summary>Clears the view.</summary>
        void Clear();

        /// <summary>Displays a prediction computed for the uploaded image.</summary>
        /// <param name="prediction">The prediction values to display.</param>
        void ProcessPrediction(double[] prediction);
    }
}
17 |
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/Views/Interfaces/IView.cs:
--------------------------------------------------------------------------------
namespace DigitRecognizer.Presentation.Views.Interfaces
{
    /// <summary>
    /// Base contract shared by every view: it can be shown and hidden.
    /// </summary>
    public interface IView
    {
        /// <summary>Makes the view visible.</summary>
        void ShowView();

        /// <summary>Hides the view.</summary>
        void HideView();
    }
}
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.Presentation/favico.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/DigitRecognizer.Presentation/favico.ico
--------------------------------------------------------------------------------
/DigitRecognizer/DigitRecognizer.sln:
--------------------------------------------------------------------------------
1 |
2 | Microsoft Visual Studio Solution File, Format Version 12.00
3 | # Visual Studio 15
4 | VisualStudioVersion = 15.0.27703.2000
5 | MinimumVisualStudioVersion = 10.0.40219.1
6 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "DigitRecognizer.Core", "DigitRecognizer.Core\DigitRecognizer.Core.csproj", "{BA726236-C7EC-41CA-ABB6-E4FE7E784939}"
7 | EndProject
8 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "DigitRecognizer.Engine", "DigitRecognizer.Engine\DigitRecognizer.Engine.csproj", "{00CE9C72-5F92-4555-BB96-D0E77B9A2390}"
9 | EndProject
10 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "DigitRecognizer.MachineLearning", "DigitRecognizer.MachineLearning\DigitRecognizer.MachineLearning.csproj", "{BF82C3DF-0EFE-45DB-8E3D-1FFDD548AE73}"
11 | EndProject
12 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "DigitRecognizer.Presentation", "DigitRecognizer.Presentation\DigitRecognizer.Presentation.csproj", "{6CAF69EA-9C0C-4571-A808-268C159BDB5C}"
13 | EndProject
14 | Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "DigitRecognizer.DatasetExpansion", "DigitRecognizer.DatasetExpansion\DigitRecognizer.DatasetExpansion.csproj", "{757CF815-DD01-4F66-8E9F-DB3FF9B4A463}"
15 | EndProject
16 | Global
17 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
18 | Debug|Any CPU = Debug|Any CPU
19 | Release|Any CPU = Release|Any CPU
20 | EndGlobalSection
21 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
22 | {BA726236-C7EC-41CA-ABB6-E4FE7E784939}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
23 | {BA726236-C7EC-41CA-ABB6-E4FE7E784939}.Debug|Any CPU.Build.0 = Debug|Any CPU
24 | {BA726236-C7EC-41CA-ABB6-E4FE7E784939}.Release|Any CPU.ActiveCfg = Release|Any CPU
25 | {BA726236-C7EC-41CA-ABB6-E4FE7E784939}.Release|Any CPU.Build.0 = Release|Any CPU
26 | {00CE9C72-5F92-4555-BB96-D0E77B9A2390}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
27 | {00CE9C72-5F92-4555-BB96-D0E77B9A2390}.Debug|Any CPU.Build.0 = Debug|Any CPU
28 | {00CE9C72-5F92-4555-BB96-D0E77B9A2390}.Release|Any CPU.ActiveCfg = Release|Any CPU
29 | {00CE9C72-5F92-4555-BB96-D0E77B9A2390}.Release|Any CPU.Build.0 = Release|Any CPU
30 | {BF82C3DF-0EFE-45DB-8E3D-1FFDD548AE73}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
31 | {BF82C3DF-0EFE-45DB-8E3D-1FFDD548AE73}.Debug|Any CPU.Build.0 = Debug|Any CPU
32 | {BF82C3DF-0EFE-45DB-8E3D-1FFDD548AE73}.Release|Any CPU.ActiveCfg = Release|Any CPU
33 | {BF82C3DF-0EFE-45DB-8E3D-1FFDD548AE73}.Release|Any CPU.Build.0 = Release|Any CPU
34 | {6CAF69EA-9C0C-4571-A808-268C159BDB5C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
35 | {6CAF69EA-9C0C-4571-A808-268C159BDB5C}.Debug|Any CPU.Build.0 = Debug|Any CPU
36 | {6CAF69EA-9C0C-4571-A808-268C159BDB5C}.Release|Any CPU.ActiveCfg = Release|Any CPU
37 | {6CAF69EA-9C0C-4571-A808-268C159BDB5C}.Release|Any CPU.Build.0 = Release|Any CPU
38 | {757CF815-DD01-4F66-8E9F-DB3FF9B4A463}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
39 | {757CF815-DD01-4F66-8E9F-DB3FF9B4A463}.Debug|Any CPU.Build.0 = Debug|Any CPU
40 | {757CF815-DD01-4F66-8E9F-DB3FF9B4A463}.Release|Any CPU.ActiveCfg = Release|Any CPU
41 | {757CF815-DD01-4F66-8E9F-DB3FF9B4A463}.Release|Any CPU.Build.0 = Release|Any CPU
42 | EndGlobalSection
43 | GlobalSection(SolutionProperties) = preSolution
44 | HideSolutionNode = FALSE
45 | EndGlobalSection
46 | GlobalSection(ExtensibilityGlobals) = postSolution
47 | SolutionGuid = {72EDBB0B-99B8-4D0F-A643-DEE415D96034}
48 | EndGlobalSection
49 | EndGlobal
50 |
--------------------------------------------------------------------------------
/DigitRecognizer/Images/favico.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Images/favico.ico
--------------------------------------------------------------------------------
/DigitRecognizer/Images/left_arrow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Images/left_arrow.png
--------------------------------------------------------------------------------
/DigitRecognizer/Images/right_arrow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Images/right_arrow.png
--------------------------------------------------------------------------------
/DigitRecognizer/Models/1f15fd63-8d73-4128-9d27-0090df8ac1ba-0.9849.nn:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Models/1f15fd63-8d73-4128-9d27-0090df8ac1ba-0.9849.nn
--------------------------------------------------------------------------------
/DigitRecognizer/Models/23682b7e-19b6-4dc2-95f6-563d13c35ffe-0.9841.nn:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Models/23682b7e-19b6-4dc2-95f6-563d13c35ffe-0.9841.nn
--------------------------------------------------------------------------------
/DigitRecognizer/Models/5bbc9588-713f-40f4-8d07-e349a089e0b2-0.9845.nn:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Models/5bbc9588-713f-40f4-8d07-e349a089e0b2-0.9845.nn
--------------------------------------------------------------------------------
/DigitRecognizer/Models/780baa22-c1c6-44ec-b13c-e867dba48010-0.9845.nn:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Models/780baa22-c1c6-44ec-b13c-e867dba48010-0.9845.nn
--------------------------------------------------------------------------------
/DigitRecognizer/Models/8edc3712-2eba-4f2c-879b-1a280687f9b4-0.9837.nn:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Models/8edc3712-2eba-4f2c-879b-1a280687f9b4-0.9837.nn
--------------------------------------------------------------------------------
/DigitRecognizer/Models/b7d29bbc-a55c-4f26-9655-db392a504105-0.9845.nn:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Models/b7d29bbc-a55c-4f26-9655-db392a504105-0.9845.nn
--------------------------------------------------------------------------------
/DigitRecognizer/Models/c0a9a2e0-0cca-4943-8755-068330811b1d-0.9827.nn:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Models/c0a9a2e0-0cca-4943-8755-068330811b1d-0.9827.nn
--------------------------------------------------------------------------------
/DigitRecognizer/Models/c9234ebb-c75b-42c1-88ad-b75650d7102a-0.9857.nn:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Models/c9234ebb-c75b-42c1-88ad-b75650d7102a-0.9857.nn
--------------------------------------------------------------------------------
/DigitRecognizer/Models/ffed716d-a8a3-48fd-b0eb-98900357be81-0.9823.nn:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/m-jovanovic/digit-recognizer/98aa36d4a9aed230aaddce24c10d9496843e8879/DigitRecognizer/Models/ffed716d-a8a3-48fd-b0eb-98900357be81-0.9823.nn
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Milan Jovanović
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Digit Recognizer
2 |
A neural network that classifies the MNIST dataset, along with a desktop application that visualizes the classification process.
4 |
To run this locally, first unzip the archives in the `Dataset` folder. They contain the images required to train and test the neural network.
6 |
Check out this video for a more detailed explanation: [**I Built a Neural Network in C# From Scratch. Here's What I Learned...**](https://youtu.be/wgNZWnua-90)
8 |
--------------------------------------------------------------------------------