├── +tests ├── +integration │ └── NetworkTrainingTest.m ├── +unit │ ├── MapToCornerTest.m │ ├── ReadDataTest.m │ └── UndistortTest.m ├── 0001_01.jpg ├── Label_1.png ├── bad_data1 ├── bad_data2 ├── data1 └── data2 ├── .gitignore ├── LICENSE ├── README.md ├── docs ├── classification_training.PNG ├── deep_sudoku_solver.mlx ├── example.jpg ├── final_segmentations.jpg ├── result.jpg ├── segmentation_training.PNG ├── solver_result.jpg └── sudoku-solver.gif ├── install.m ├── setupDataGitHub.m ├── src └── sudoku │ ├── +sudoku │ ├── +synth │ │ ├── Mnist.m │ │ ├── addBlur.m │ │ ├── addDistortion.m │ │ ├── addLines.m │ │ ├── addNoise.m │ │ ├── addPrintedDigit.m │ │ ├── addShading.m │ │ ├── addSharpening.m │ │ ├── getMnistData.m │ │ ├── insertWrittenDigit.m │ │ ├── makePaper.m │ │ ├── makeSyntheticDigit.m │ │ └── rescaleChannels.m │ ├── +training │ │ ├── extractNumberData.m │ │ ├── generateSyntheticNumberData.m │ │ ├── getNumberData.m │ │ ├── getSudokuData.m │ │ ├── parseFilename.m │ │ ├── readNumberLabels.m │ │ ├── resnetLike.m │ │ ├── vggLike.m │ │ └── weightLossByFrequency.m │ ├── PuzzleSolver.m │ ├── extractNumbers.m │ ├── fetchMnistData.m │ ├── findPrimaryRegion.m │ ├── getMapLines.m │ ├── intersect.m │ ├── intersectAll.m │ ├── intersectionsFromLabel.m │ ├── postProcessMask.m │ ├── prepareMnistData.m │ ├── segmentPuzzle.m │ ├── selectAndSort.m │ ├── sudokuEngine.m │ ├── thresholdImage.m │ ├── trainNumberNetwork.m │ ├── trainSemanticSegmentation.m │ ├── undistort.m │ └── visualiseClassifier.m │ └── sudokuRoot.m └── uninstall.m /+tests/+integration/NetworkTrainingTest.m: -------------------------------------------------------------------------------- 1 | classdef NetworkTrainingTest < matlab.unittest.TestCase 2 | 3 | % Copyright 2018 The MathWorks, Inc. 
4 | 5 | methods (Test) 6 | function testNumberNetwork(testCase) 7 | nSamples = 10; 8 | initialChannels = 4; 9 | imageSize = [64, 64]; 10 | 11 | [train, ~] = sudoku.training.getNumberData(nSamples, false); 12 | 13 | options = trainingOptions('sgdm', ... 14 | 'ExecutionEnvironment', 'cpu', ... 15 | 'MaxEpochs', 2, ... 16 | 'MiniBatchSize', 64); 17 | 18 | layers = sudoku.training.vggLike(initialChannels, imageSize); 19 | net = trainNetwork(train, layers, options); 20 | end 21 | 22 | function testSegmentationNetwork(testCase) 23 | inputSize = [64, 64, 3]; 24 | numClasses = 2; 25 | networkDepth = 2; 26 | trainFraction = 0.1; 27 | 28 | %% Get the training data 29 | [imagesTrain, labelsTrain] = sudoku.training.getSudokuData(trainFraction, false); 30 | 31 | train = pixelLabelImageDatastore(imagesTrain, labelsTrain, ... 32 | 'OutputSize', inputSize(1:2)); 33 | 34 | %% Setup the network 35 | layers = segnetLayers(inputSize, numClasses, networkDepth); 36 | layers = sudoku.training.weightLossByFrequency(layers, train); 37 | 38 | opts = trainingOptions('sgdm', ... 39 | 'InitialLearnRate', 0.005, ... 40 | 'MaxEpochs', 2, ... 41 | 'MiniBatchSize', 2); 42 | 43 | %% Train 44 | net = trainNetwork(train, layers, opts); %#ok 45 | 46 | 47 | end 48 | end 49 | end -------------------------------------------------------------------------------- /+tests/+unit/MapToCornerTest.m: -------------------------------------------------------------------------------- 1 | classdef MapToCornerTest < matlab.unittest.TestCase 2 | 3 | % Copyright 2018 The MathWorks, Inc. 
4 | 5 | methods (Test) 6 | function testLineIntersectionOrigin(testCase) 7 | line1 = [0, 0]; 8 | line2 = [0, pi/2]; 9 | 10 | intersection = sudoku.intersect(line1, line2); 11 | 12 | testCase.verifyEqual(intersection, [0, 0]); 13 | end 14 | 15 | function testLineIntersectionPositive(testCase) 16 | line1 = [100, 0]; 17 | line2 = [100, pi/2]; 18 | 19 | intersection = sudoku.intersect(line1, line2); 20 | 21 | testCase.verifyEqual(intersection, [100, 100]); 22 | end 23 | 24 | function testLineIntersectionNegative(testCase) 25 | line1 = [100, 0]; 26 | line2 = [-100, -pi/2]; 27 | 28 | intersection = sudoku.intersect(line1, line2); 29 | 30 | testCase.verifyEqual(intersection, [100, 100]); 31 | end 32 | 33 | function testLineIntersectionParallel(testCase) 34 | line1 = [0, 0]; 35 | line2 = [100, 0]; 36 | 37 | intersect = @() sudoku.intersect(line1, line2); 38 | 39 | testCase.verifyError(intersect, 'sudoku:intersect:noIntersection'); 40 | end 41 | 42 | function testMultiLineIntersect(testCase) 43 | lines = [0, 0; 44 | 100, 0; 45 | 0, pi/2; 46 | 100, pi/2]; 47 | 48 | intersections = sudoku.intersectAll(lines); 49 | 50 | testCase.verifyEqual(size(intersections, 1), 6); 51 | testCase.verifyTrue(testCase.pointInArray(intersections, [0, 0], 1e-6)); 52 | testCase.verifyTrue(testCase.pointInArray(intersections, [100, 0], 1e-6)); 53 | testCase.verifyTrue(testCase.pointInArray(intersections, [0, 100], 1e-6)); 54 | testCase.verifyTrue(testCase.pointInArray(intersections, [100, 100], 1e-6)); 55 | end 56 | 57 | function testMultiLineIntersectParallel(testCase) 58 | lines = [0, 0; 59 | 100, 0;]; 60 | 61 | intersections = sudoku.intersectAll(lines); 62 | 63 | testCase.verifyEqual(size(intersections, 1), 1); 64 | testCase.verifyEqual(intersections, [NaN, NaN]); 65 | end 66 | 67 | function testMapToIntersections(testCase) 68 | im = imread('+tests/Label_1.png'); 69 | map = im == 1; 70 | 71 | lines = sudoku.getMapLines(map); 72 | intersections = sudoku.intersectAll(lines); 73 | 74 | 
testCase.verifyEqual(size(intersections, 1), 6); 75 | testCase.verifyTrue(testCase.pointInArray(intersections, [2180, 788], 1)); 76 | testCase.verifyTrue(testCase.pointInArray(intersections, [2781, 943], 1)); 77 | testCase.verifyTrue(testCase.pointInArray(intersections, [1933, 1215], 1)); 78 | testCase.verifyTrue(testCase.pointInArray(intersections, [2596, 1412], 1)); 79 | end 80 | 81 | function testMapToSortedIntersections(testCase) 82 | im = imread('+tests/Label_1.png'); 83 | map = im == 1; 84 | expectedIntersections = [2180, 788; 85 | 2781, 943; 86 | 2596, 1412; 87 | 1933, 1215;]; 88 | 89 | lines = sudoku.getMapLines(map); 90 | intersections = sudoku.intersectAll(lines); 91 | intersections = sudoku.selectAndSort(intersections); 92 | 93 | testCase.verifyEqual(size(intersections, 1), 4); 94 | testCase.verifyEqual(intersections, expectedIntersections, 'AbsTol', 1); 95 | end 96 | 97 | function testSelectAndSortRemoveParallel(testCase) 98 | intersections = [0, 0; 99 | 100, 0; 100 | NaN, NaN; 101 | 100, 100]; 102 | 103 | intersections = sudoku.selectAndSort(intersections); 104 | 105 | testCase.verifyEqual(size(intersections, 1), 3); 106 | end 107 | 108 | function testSelectAndSortRemoveDistant(testCase) 109 | intersections = [0, 0; 110 | 100, 0; 111 | 0, 100; 112 | 100, 100; 113 | 500, 1000]; 114 | 115 | intersections = sudoku.selectAndSort(intersections); 116 | 117 | testCase.verifyEqual(size(intersections, 1), 4); 118 | end 119 | end 120 | 121 | methods (Static) 122 | function tf = pointInArray(points, testPoint, tolerance) 123 | testResult = sum(sum(abs(points - testPoint) < tolerance, 2) == 2); 124 | tf = testResult == 1; 125 | end 126 | end 127 | 128 | end -------------------------------------------------------------------------------- /+tests/+unit/ReadDataTest.m: -------------------------------------------------------------------------------- 1 | classdef ReadDataTest < matlab.unittest.TestCase 2 | 3 | % Copyright 2018 The MathWorks, Inc. 
4 | 5 | methods (Test) 6 | function testImportOneLabel(testCase) 7 | data = sudoku.training.readNumberLabels('+tests/data1'); 8 | 9 | testCase.verifyEqual(length(data), 1); 10 | testCase.verifyEqual(size(data('0001')), [9, 9]); 11 | end 12 | 13 | function testImportTwoLabels(testCase) 14 | data = sudoku.training.readNumberLabels('+tests/data2'); 15 | 16 | testCase.verifyEqual(length(data), 2); 17 | testCase.verifyEqual(size(data('0001')), [9, 9]); 18 | testCase.verifyEqual(size(data('0002')), [9, 9]); 19 | end 20 | 21 | function testImportBadNumbers(testCase) 22 | importData = @() sudoku.training.readNumberLabels('+tests/bad_data1'); 23 | 24 | testCase.verifyError(importData, 'sudoku:BadNumberData'); 25 | end 26 | 27 | function testImportBadLabels(testCase) 28 | importData = @() sudoku.training.readNumberLabels('+tests/bad_data2'); 29 | 30 | testCase.verifyError(importData, 'sudoku:DuplicateLabel'); 31 | end 32 | 33 | function testNameParseing(testCase) 34 | testName = 'C:\one\two\0003_04.jpg'; 35 | 36 | [sudokuNumber, repeat] = sudoku.training.parseFilename(testName); 37 | 38 | testCase.verifyEqual(sudokuNumber, '0003'); 39 | testCase.verifyEqual(repeat, '04'); 40 | end 41 | end 42 | end 43 | -------------------------------------------------------------------------------- /+tests/+unit/UndistortTest.m: -------------------------------------------------------------------------------- 1 | classdef UndistortTest < matlab.unittest.TestCase 2 | 3 | % Copyright 2018 The MathWorks, Inc. 4 | 5 | methods (Test) 6 | 7 | function testUndistortion(testCase) 8 | im = imread('+tests/0001_01.jpg'); 9 | imagePoints = [2180, 788; ... 10 | 2781, 943; ... 11 | 1933, 1215; ... 12 | 2596, 1412]; 13 | 14 | imagePoints = imagePoints + 20*(0.5 - 1*rand(4,2)); 15 | 16 | boxWidth = 32; 17 | fullWidth = 9*boxWidth; 18 | worldPoints = [0, 0; ... 19 | fullWidth, 0; ... 20 | 0, fullWidth; ... 
21 | fullWidth, fullWidth]; 22 | outputImage = sudoku.undistort(im, imagePoints, worldPoints); 23 | 24 | newIm = mat2cell(outputImage, ... 25 | repmat(boxWidth, 1, 9), ... 26 | repmat(boxWidth, 1, 9), ... 27 | 3); 28 | montage(newIm) 29 | % TODO finish this test 30 | end 31 | 32 | end 33 | 34 | end -------------------------------------------------------------------------------- /+tests/0001_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/deep-sudoku-solver/616f25edffa9154426850d808fa079e0a70e3356/+tests/0001_01.jpg -------------------------------------------------------------------------------- /+tests/Label_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/deep-sudoku-solver/616f25edffa9154426850d808fa079e0a70e3356/+tests/Label_1.png -------------------------------------------------------------------------------- /+tests/bad_data1: -------------------------------------------------------------------------------- 1 | 0001 2 | -5 -3 4 6 -7 8 9 1 2 3 | -6 7 2 -1 -9 -5 3 4 8 4 | 1 -9 -8 3 4 2 5 -6 7 5 | -8 5 9 7 -6 4 2 -3 6 | -4 2 6 -8 5 -3 7 9 -1 7 | -7 1 3 9 -2 4 8 5 -6 8 | 9 -6 1 5 3 7 -2 -8 4 9 | 2 8 7 -4 -1 -9 6 3 -5 10 | 3 4 5 2 -8 6 1 -7 -9 -------------------------------------------------------------------------------- /+tests/bad_data2: -------------------------------------------------------------------------------- 1 | 0001 2 | -5 -3 4 6 -7 8 9 1 2 3 | -6 7 2 -1 -9 -5 3 4 8 4 | 1 -9 -8 3 4 2 5 -6 7 5 | -8 5 9 7 -6 1 4 2 -3 6 | -4 2 6 -8 5 -3 7 9 -1 7 | -7 1 3 9 -2 4 8 5 -6 8 | 9 -6 1 5 3 7 -2 -8 4 9 | 2 8 7 -4 -1 -9 6 3 -5 10 | 3 4 5 2 -8 6 1 -7 -9 11 | 12 | 0001 13 | -5 -3 4 6 -7 8 9 1 2 14 | -6 7 2 -1 -9 -5 3 4 8 15 | 1 -9 -8 3 4 2 5 -6 7 16 | -8 5 9 7 -6 1 4 2 -3 17 | -4 2 6 -8 5 -3 7 9 -1 18 | -7 1 3 9 -2 4 8 5 -6 19 | 9 -6 1 5 3 7 -2 -8 4 20 | 2 8 7 -4 -1 -9 6 3 -5 21 
| 3 4 5 2 -8 6 1 -7 -9 22 | -------------------------------------------------------------------------------- /+tests/data1: -------------------------------------------------------------------------------- 1 | 0001 2 | -5 -3 4 6 -7 8 9 1 2 3 | -6 7 2 -1 -9 -5 3 4 8 4 | 1 -9 -8 3 4 2 5 -6 7 5 | -8 5 9 7 -6 1 4 2 -3 6 | -4 2 6 -8 5 -3 7 9 -1 7 | -7 1 3 9 -2 4 8 5 -6 8 | 9 -6 1 5 3 7 -2 -8 4 9 | 2 8 7 -4 -1 -9 6 3 -5 10 | 3 4 5 2 -8 6 1 -7 -9 -------------------------------------------------------------------------------- /+tests/data2: -------------------------------------------------------------------------------- 1 | 0001 2 | -5 -3 4 6 -7 8 9 1 2 3 | -6 7 2 -1 -9 -5 3 4 8 4 | 1 -9 -8 3 4 2 5 -6 7 5 | -8 5 9 7 -6 1 4 2 -3 6 | -4 2 6 -8 5 -3 7 9 -1 7 | -7 1 3 9 -2 4 8 5 -6 8 | 9 -6 1 5 3 7 -2 -8 4 9 | 2 8 7 -4 -1 -9 6 3 -5 10 | 3 4 5 2 -8 6 1 -7 -9 11 | 12 | 0002 13 | -5 -3 4 6 -7 8 9 1 2 14 | -6 7 2 -1 -9 -5 3 4 8 15 | 1 -9 -8 3 4 2 5 -6 7 16 | -8 5 9 7 -6 1 4 2 -3 17 | -4 2 6 -8 5 -3 7 9 -1 18 | -7 1 3 9 -2 4 8 5 -6 19 | 9 -6 1 5 3 7 -2 -8 4 20 | 2 8 7 -4 -1 -9 6 3 -5 21 | 3 4 5 2 -8 6 1 -7 -9 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.mp4 2 | *.asv 3 | *.avi 4 | models/ 5 | raw_data/ 6 | number_data/ 7 | checkpoints/ 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, The MathWorks, Inc. 2 | All rights reserved. 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 5 | 2. 
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 6 | 3. In all cases, the software is, and all modifications and derivatives of the software shall be, licensed to you solely for use in conjunction with MathWorks products and service offerings. 7 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Sudoku Solver 2 | 3 | __Takes an uncontrolled image of a sudoku puzzle, identifies the location, reads the puzzle, and solves it.__ 4 | 5 | This example was originally put together for the [UK MATLAB Expo](https://www.matlabexpo.com/uk) 2018, for a talk entitled _Computer Vision and Image processing with MATLAB ([video](https://www.mathworks.com/videos/image-processing-and-computer-vision-with-matlab-1541003708736.html), [blog post](https://blogs.mathworks.com/deep-learning/2018/11/15/sudoku-solver-image-processing-and-deep-learning/))_. 
It is intended to demonstrate the use of a combination of deep learning and image processing to solve a computer vision problem. 6 | 7 | ![](docs/sudoku-solver.gif) 8 | 9 | ## Getting started 10 | 11 | - Get a copy of the code either by cloning the repository or downloading a .zip 12 | - Run the example live script getting_started.mlx, or see the [usage](#usage) section below 13 | 14 | ## Details 15 | 16 | Broadly, the algorithm is divided into four distinct steps: 17 | 18 | 1. Find the sudoku puzzle in an image using deep learning (semantic segmentation) 19 | 2. Extract each of the 81 number boxes in the puzzle using image processing. 20 | 3. Read the number contained in each box using deep learning. 21 | 4. Solve the puzzle using optimisation. 22 | 23 | For more details see the original [Expo talk](https://www.mathworks.com/videos/image-processing-and-computer-vision-with-matlab-1541003708736.html). 24 | 25 | __Input:__ 26 | 27 | ![](docs/example.jpg) 28 | 29 | __Result:__ 30 | 31 | ![](docs/result.jpg) 32 | 33 | ## Usage 34 | 35 | - In MATLAB set the top level directory as the working directory then run `install()` to add the required folders to the MATLAB path. 36 | - Run `setupDataGitHub()` to fetch the required training data from GitHub. The data is ~70 MB; downloading and extracting this can take a few minutes. 37 | - Run `sudoku.trainSemanticSegmentation()`. This will train the semantic segmentation network and save the trained network in the `models/` folder. 38 | - Run `sudoku.trainNumberNetwork()`. This will train the number classification network and save the trained network in the `models/` folder. 
39 | - Once both networks have been trained, you can process an image as follows: 40 | 41 | ```matlab 42 | im = imread("docs/example.jpg"); 43 | % reduce the size of the example image 44 | im = imresize(im, 0.5); 45 | solver = sudoku.PuzzleSolver(); 46 | solution = solver.process(im) 47 | ``` 48 | 49 | ## Training 50 | 51 | For reference, the training curves for the two networks should look as follows: 52 | 53 | __Classification training__ 54 | 55 | ![](docs/classification_training.PNG) 56 | 57 | __Segmentation training__ 58 | 59 | ![](docs/segmentation_training.PNG) 60 | 61 | By the end of training, the segmentation results should look something like this: 62 | 63 | ![](docs/final_segmentations.jpg) 64 | 65 | ## Contributing 66 | 67 | Please file any bug reports or questions as [GitHub issues](https://github.com/mathworks/deep-sudoku-solver/issues). 68 | 69 | _Copyright 2018-2019 The MathWorks, Inc._ -------------------------------------------------------------------------------- /docs/classification_training.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/deep-sudoku-solver/616f25edffa9154426850d808fa079e0a70e3356/docs/classification_training.PNG -------------------------------------------------------------------------------- /docs/deep_sudoku_solver.mlx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/deep-sudoku-solver/616f25edffa9154426850d808fa079e0a70e3356/docs/deep_sudoku_solver.mlx -------------------------------------------------------------------------------- /docs/example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/deep-sudoku-solver/616f25edffa9154426850d808fa079e0a70e3356/docs/example.jpg -------------------------------------------------------------------------------- 
/docs/final_segmentations.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/deep-sudoku-solver/616f25edffa9154426850d808fa079e0a70e3356/docs/final_segmentations.jpg -------------------------------------------------------------------------------- /docs/result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/deep-sudoku-solver/616f25edffa9154426850d808fa079e0a70e3356/docs/result.jpg -------------------------------------------------------------------------------- /docs/segmentation_training.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/deep-sudoku-solver/616f25edffa9154426850d808fa079e0a70e3356/docs/segmentation_training.PNG -------------------------------------------------------------------------------- /docs/solver_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/deep-sudoku-solver/616f25edffa9154426850d808fa079e0a70e3356/docs/solver_result.jpg -------------------------------------------------------------------------------- /docs/sudoku-solver.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/deep-sudoku-solver/616f25edffa9154426850d808fa079e0a70e3356/docs/sudoku-solver.gif -------------------------------------------------------------------------------- /install.m: -------------------------------------------------------------------------------- 1 | function install() 2 | 3 | % Copyright 2018 The MathWorks, Inc. 
4 | 5 | thisPath = fileparts(mfilename('fullpath')); 6 | addpath(fullfile(thisPath, 'src', 'sudoku')); 7 | end -------------------------------------------------------------------------------- /setupDataGitHub.m: -------------------------------------------------------------------------------- 1 | function setupDataGitHub() 2 | % setupDataGitHub Copy all the required data files. 3 | 4 | % Copyright 2019 The MathWorks, Inc. 5 | 6 | version = 'v0.0'; 7 | baseURL = 'https://github.com/mathworks/deep-sudoku-solver/releases/download'; 8 | 9 | segmentationURL = sprintf('%s/%s/%s', baseURL, version, 'segmentation_data.zip'); 10 | numberURL = sprintf('%s/%s/%s', baseURL, version, 'number_data.zip'); 11 | 12 | outputRoot = fullfile(sudokuRoot(), 'data'); 13 | 14 | fprintf('Downloading segmentation data...'); 15 | unzip(segmentationURL, outputRoot); 16 | fprintf('done.\n'); 17 | 18 | fprintf('Downloading synthetic number data...'); 19 | unzip(numberURL, outputRoot); 20 | fprintf('done.\n'); 21 | 22 | sudoku.prepareMnistData(); 23 | end -------------------------------------------------------------------------------- /src/sudoku/+sudoku/+synth/Mnist.m: -------------------------------------------------------------------------------- 1 | classdef Mnist < handle 2 | 3 | % Copyright 2018 The MathWorks, Inc. 
4 | 5 | properties (Dependent) 6 | Cache 7 | end 8 | 9 | properties 10 | Cache_ 11 | end 12 | 13 | %% Public Methods 14 | methods (Static) 15 | function output = getInstance() 16 | persistent singletonObj 17 | if isempty(singletonObj) || ~isvalid(singletonObj) 18 | singletonObj = sudoku.synth.Mnist(); 19 | end 20 | output = singletonObj; 21 | end 22 | end 23 | 24 | 25 | %% Private helper functions 26 | methods (Access = private) 27 | function obj = Mnist() 28 | end 29 | end 30 | 31 | methods 32 | function clearCache(obj) 33 | obj.Cache_ = []; 34 | end 35 | 36 | function val = get.Cache(obj) 37 | if isempty(obj.Cache_) 38 | dataFile = fullfile(sudokuRoot(), 'data', 'number_data', 'mnist.mat'); 39 | mnist = load(dataFile, 'training'); 40 | obj.Cache_ = mnist.training; 41 | end 42 | val = obj.Cache_; 43 | end 44 | end 45 | end -------------------------------------------------------------------------------- /src/sudoku/+sudoku/+synth/addBlur.m: -------------------------------------------------------------------------------- 1 | function imOut = addBlur(im, maxBlur) 2 | 3 | % Copyright 2018 The MathWorks, Inc. 4 | 5 | imOut = imgaussfilt(im, maxBlur*rand(1)); 6 | end -------------------------------------------------------------------------------- /src/sudoku/+sudoku/+synth/addDistortion.m: -------------------------------------------------------------------------------- 1 | function imOut = addDistortion(im, resolution, border, finalSize, maxError, maxShift) 2 | 3 | % Copyright 2018 The MathWorks, Inc. 4 | 5 | imagePoints = [border, border; 6 | border, resolution - border; 7 | resolution - border, resolution - border 8 | resolution - border, border]; 9 | imagePoints = imagePoints + maxError*(rand(4, 2)-0.5); 10 | % add shift 11 | imagePoints = imagePoints + maxShift*(rand(1, 2) - 0.5); 12 | worldPoints = [0, 0; ... 13 | 0, finalSize; ... 14 | finalSize, finalSize; ... 
15 | finalSize, 0]; 16 | imOut = sudoku.undistort(im, imagePoints, worldPoints); 17 | -------------------------------------------------------------------------------- /src/sudoku/+sudoku/+synth/addLines.m: -------------------------------------------------------------------------------- 1 | function im = addLines(im, border, lineWidth, maxColour) 2 | 3 | % Copyright 2018 The MathWorks, Inc. 4 | 5 | resolution = size(im, 1); 6 | positions = [border, 0, border, resolution; 7 | resolution - border, 0, resolution - border, resolution, 8 | 0, border, resolution, border; 9 | 0, resolution - border, resolution, resolution - border]; 10 | lineColour = maxColour*rand(1, 3); 11 | im = insertShape(im, 'line', positions, ... 12 | 'Color', lineColour, ... 13 | 'LineWidth', lineWidth); 14 | end -------------------------------------------------------------------------------- /src/sudoku/+sudoku/+synth/addNoise.m: -------------------------------------------------------------------------------- 1 | function imOut = addNoise(im, maxNoise) 2 | 3 | % Copyright 2018 The MathWorks, Inc. 4 | 5 | noiseMag = maxNoise*rand(1); 6 | imOut = im - uint8(noiseMag*rand(size(im))); 7 | end -------------------------------------------------------------------------------- /src/sudoku/+sudoku/+synth/addPrintedDigit.m: -------------------------------------------------------------------------------- 1 | function im = addPrintedDigit(im, digit, minSize, maxColour, border, maxOffset) 2 | 3 | % Copyright 2018 The MathWorks, Inc. 4 | 5 | possibleFonts = {'Arial', ... 6 | 'Garamond', ... 7 | 'Courier New', ... 8 | 'Gill Sans MT', ... 9 | 'Comic Sans MS', ... 10 | 'Calibri', ... 11 | 'Tahoma', ... 12 | 'Times New Roman', ... 13 | 'Times New Roman Bold', ... 14 | 'Times New Roman Italic', ... 15 | 'Verdana', ... 
16 | 'Verdana Bold'}; 17 | font = possibleFonts{randi(numel(possibleFonts))}; 18 | resolution = size(im, 1); 19 | maxSize = resolution - 2*border; 20 | fontSize = round((maxSize - minSize)*rand(1) + minSize); 21 | 22 | textColour = maxColour*rand(1, 3); 23 | position = [resolution/2, resolution/2] + maxOffset.*(rand(1, 2) - 0.5); 24 | 25 | im = insertText(im, position, digit, ... 26 | 'BoxOpacity', 0, ... 27 | 'Font', font, ... 28 | 'FontSize', fontSize, ... 29 | 'TextColor', textColour, ... 30 | 'AnchorPoint', 'center'); 31 | clear insertText % Avoid a strange memory issue 32 | end -------------------------------------------------------------------------------- /src/sudoku/+sudoku/+synth/addShading.m: -------------------------------------------------------------------------------- 1 | function im = addShading(im, maxShade) 2 | 3 | % Copyright 2018 The MathWorks, Inc. 4 | 5 | resolution = size(im, 1); 6 | [xGrid, yGrid] = meshgrid(1:resolution, 1:resolution); 7 | shading = interp2([0, resolution], [0, resolution], ... 8 | maxShade*rand(2), xGrid, yGrid); 9 | im = im - uint8(rand(1)*shading); 10 | end -------------------------------------------------------------------------------- /src/sudoku/+sudoku/+synth/addSharpening.m: -------------------------------------------------------------------------------- 1 | function imOut = addSharpening(im, maxAmount) 2 | 3 | % Copyright 2018 The MathWorks, Inc. 4 | 5 | imOut = imsharpen(im, 'Amount', maxAmount*rand(1)); 6 | end -------------------------------------------------------------------------------- /src/sudoku/+sudoku/+synth/getMnistData.m: -------------------------------------------------------------------------------- 1 | function data = getMnistData() 2 | 3 | % Copyright 2018 The MathWorks, Inc. 
4 | 5 | persistent digitData 6 | 7 | if isempty(digitData) 8 | mnist = load('number_data/mnist.mat', 'training'); 9 | digitData = mnist.training; 10 | end 11 | 12 | data = digitData; 13 | 14 | end -------------------------------------------------------------------------------- /src/sudoku/+sudoku/+synth/insertWrittenDigit.m: -------------------------------------------------------------------------------- 1 | function imOut = insertWrittenDigit(im, digit, resolution, border, padding) 2 | 3 | % Copyright 2018 The MathWorks, Inc. 4 | 5 | digitData = sudoku.synth.Mnist.getInstance.Cache; 6 | 7 | matchingDigits = digitData.images(:, :, digitData.labels == digit); 8 | digit = matchingDigits(:,:,randi(size(matchingDigits, 3))); 9 | 10 | textSize = [resolution - 2*border - 2*padding, resolution - 2*border - 2*padding]; 11 | textImage = imresize(digit, textSize); 12 | textImage = repmat(textImage, 1, 1, 3); 13 | textImage = textImage.*(0.5 + 0.5*rand(1, 1, 3)); 14 | textImage = uint8(255*textImage); 15 | textImage = padarray(textImage, ... 16 | [border + padding, border + padding], 0); 17 | textImage = imerode(textImage, strel('disk', round(4*rand(1)))); 18 | imOut = im - textImage; 19 | 20 | end -------------------------------------------------------------------------------- /src/sudoku/+sudoku/+synth/makePaper.m: -------------------------------------------------------------------------------- 1 | function im = makePaper(resolution, minWhite, maxNoise) 2 | 3 | % Copyright 2018 The MathWorks, Inc. 
4 | 5 | baseColour = minWhite + rand(1)*(255 - minWhite); 6 | im = 255*ones(resolution, resolution, 3, 'uint8'); 7 | noiseValue = maxNoise*rand(); 8 | im = im - uint8(noiseValue*rand(size(im))); 9 | end -------------------------------------------------------------------------------- /src/sudoku/+sudoku/+synth/makeSyntheticDigit.m: -------------------------------------------------------------------------------- 1 | function im = makeSyntheticDigit(iDigit, isHandwritten) 2 | 3 | % Copyright 2018 The MathWorks, Inc. 4 | 5 | resolution = 300; 6 | border = 50; 7 | im = sudoku.synth.makePaper(resolution, 225, 50); 8 | im = sudoku.synth.addLines(im, border, 2+randi(11), 70); 9 | if iDigit > 0 10 | if isHandwritten 11 | im = sudoku.synth.insertWrittenDigit(im, iDigit, resolution, border, randi(60)); 12 | else 13 | im = sudoku.synth.addPrintedDigit(im, num2str(iDigit), 2*border, 70, border, 20); 14 | end 15 | end 16 | im = sudoku.synth.addShading(im, 155); 17 | im = sudoku.synth.addBlur(im, 2.6); 18 | im = sudoku.synth.addNoise(im, 40); 19 | im = sudoku.synth.rescaleChannels(im, 0.11); 20 | im = sudoku.synth.addSharpening(im, 2); 21 | im = sudoku.synth.addDistortion(im, resolution, border, 64, 30, 60); 22 | end -------------------------------------------------------------------------------- /src/sudoku/+sudoku/+synth/rescaleChannels.m: -------------------------------------------------------------------------------- 1 | function imOut = rescaleChannels(im, maxDecrease) 2 | 3 | % Copyright 2018 The MathWorks, Inc. 4 | 5 | imOut = uint8(double(im).*(1-maxDecrease*rand(1, 1, 3))); 6 | end -------------------------------------------------------------------------------- /src/sudoku/+sudoku/+training/extractNumberData.m: -------------------------------------------------------------------------------- 1 | function extractNumberData(outputFolder) 2 | % extractNumberData Saves individual number images from sudoku images. 
3 | % 4 | % sudoku.training.extractNumberData(outputFolder) 5 | % 6 | % Args: 7 | % outputFolder (char): output directory to save the data 8 | 9 | % Copyright 2018, The MathWorks, Inc. 10 | 11 | %% Load data 12 | [images, labels, ~, ~] = sudoku.training.getSudokuData(1, false); 13 | pximds = pixelLabelImageDatastore(images, labels); 14 | dataFile = fullfile(sudokuRoot(), 'data', 'labels', 'numbers.txt'); 15 | digitLabels = sudoku.training.readNumberLabels(dataFile); 16 | 17 | directoryNames = 0:9; 18 | numberSize = 64; 19 | 20 | for iDirectory = 1:numel(directoryNames) 21 | thisDir = fullfile(outputFolder, num2str(directoryNames(iDirectory))); 22 | if ~isfolder(thisDir) 23 | mkdir(thisDir); 24 | end 25 | end 26 | 27 | for iImage = 1:numel(pximds.Images) 28 | fprintf('Extracting number image for image %d.\n', iImage); 29 | data = pximds.readByIndex(iImage); 30 | filename = pximds.Images{iImage}; 31 | 32 | intersections = sudoku.intersectionsFromLabel(data.pixelLabelImage{1} == "sudoku"); 33 | 34 | % add some realistic line error 35 | puzzleArea = polyarea(intersections(:,1), intersections(:,2)); 36 | maxError = sqrt(puzzleArea)/9/10; 37 | imagePoints = intersections + maxError*(0.5 - 1*rand(4, 2)); 38 | 39 | numberImages = extractNumbers(data.inputImage{1}, imagePoints, numberSize); 40 | 41 | % Save images 42 | [puzzle, repeat] = sudoku.training.parseFilename(filename); 43 | if digitLabels.isKey(puzzle) 44 | puzzleLabel = abs(digitLabels(puzzle)); 45 | for iDigitImage = 1:81 46 | thisNumber = puzzleLabel(iDigitImage); 47 | imageName = sprintf('%s_%s_%d.tif', puzzle, repeat, iDigitImage); 48 | filename = fullfile(outputFolder, num2str(thisNumber), imageName); 49 | imwrite(numberImages{iDigitImage}, filename); 50 | end 51 | end 52 | end 53 | end 54 | 55 | function numberImages = extractNumbers(im, imagePoints, numberSize) 56 | % Generate rectified puzzle 57 | fullWidth = 9*numberSize; 58 | worldPoints = [0, 0; ... 59 | fullWidth, 0; ... 60 | fullWidth, fullWidth; ... 
61 | 0, fullWidth; ... 62 | ]; 63 | outputImage = sudoku.undistort(im, imagePoints, worldPoints); 64 | 65 | % Convert to individual number images 66 | numberImages = mat2cell(outputImage, ... 67 | repmat(numberSize, 1, 9), ... 68 | repmat(numberSize, 1, 9), ... 69 | 3); 70 | end 71 | -------------------------------------------------------------------------------- /src/sudoku/+sudoku/+training/generateSyntheticNumberData.m: -------------------------------------------------------------------------------- 1 | function generateSyntheticNumberData(outputFolder, nSamples) 2 | % generateSyntheticNumberData generate synthetic number images 3 | % 4 | % sudoku.training.generateSyntheticNumberData(outputFolder, nSamples) 5 | % 6 | % Args: 7 | % outputFolder (char): output directory to save the data 8 | % nSamples (double): number of synthetic images to generate for 9 | % each digit 10 | % 11 | % Note: 12 | % The output folder must be empty. 13 | % The synthetic data generate is carried out in parallel. 14 | 15 | % Copyright 2018, The MathWorks, Inc. 16 | 17 | if isfolder(outputFolder) 18 | existingFiles = dir(fullfile(outputFolder, '*')); 19 | assert(numel(existingFiles) == 2, ... 20 | 'sudoku:generateSyntheticNumberData', ... 21 | 'Output folder is not empty.'); 22 | end 23 | 24 | parfor iDigit = 0:9 25 | fprintf('Generating digit %d', iDigit); 26 | dataDir = fullfile(outputFolder, num2str(iDigit)); 27 | if ~isfolder(dataDir) 28 | mkdir(dataDir); 29 | end 30 | for iSample = 1:nSamples 31 | handWritten = iSample < nSamples/2; 32 | im = sudoku.synth.makeSyntheticDigit(iDigit, handWritten); 33 | imwrite(im, fullfile(dataDir, sprintf('%0.5d.jpg', iSample)), ... 
function [train, test] = getNumberData(nSamples, force)
% getNumberData retrieve number classification data
%
%   [train, test] = sudoku.training.getNumberData(nSamples, force)
%
%   Args:
%       nSamples (double):  number of synthetic images to generate for
%                           each digit; 0 skips the training set
%       force (boolean):    force a regenerate of the data
%
%   Returns:
%       train (imageDatastore): synthetic samples for training, or [] when
%                               nSamples is 0
%       test (imageDatastore):  real samples for testing.
%
%   Note:
%       Both sets are cached under <sudokuRoot>/data/number_data and are
%       only rebuilt when their folder is missing or when force is true.
%       The cached training set is regenerated when a digit class is
%       absent or holds fewer than nSamples images, and is subsampled
%       when it holds more.

% Copyright 2018, The MathWorks, Inc.

    % generateSyntheticNumberData writes one sub-folder per digit 0-9
    nDigitClasses = 10;

    numberDataRoot = fullfile(sudokuRoot, 'data', 'number_data');
    trainLocation = fullfile(numberDataRoot, 'train');
    testLocation = fullfile(numberDataRoot, 'test');

    if force
        removeFolder(trainLocation);
        removeFolder(testLocation);
    end
    if nSamples > 0 && ~isfolder(trainLocation)
        refreshSyntheticData(trainLocation, nSamples);
    end
    if ~isfolder(testLocation)
        sudoku.training.extractNumberData(testLocation);
    end

    if nSamples > 0
        % fetch the data that is already there
        train = labelledDatastore(trainLocation);
        samples = countEachLabel(train);

        % countEachLabel only lists classes that exist on disk, so a
        % missing digit folder must be detected separately from a short one
        classMissing = height(samples) < nDigitClasses;
        if classMissing || any(samples.Count < nSamples)
            % Not enough data, regenerate it all
            refreshSyntheticData(trainLocation, nSamples);
            train = labelledDatastore(trainLocation);
        elseif any(samples.Count > nSamples)
            % Too much data, subsample
            train = splitEachLabel(train, nSamples);
        end
    else
        train = [];
    end
    test = labelledDatastore(testLocation);

end

function refreshSyntheticData(trainLocation, nSamples)
% Wipe any existing synthetic data and generate a fresh set.
    removeFolder(trainLocation);
    sudoku.training.generateSyntheticNumberData(trainLocation, nSamples);
end

function removeFolder(location)
% Recursively delete a folder if it exists.
    if isfolder(location)
        rmdir(location, 's');
    end
end

function ds = labelledDatastore(location)
% Image datastore over all sub-folders, labelled by folder name.
    ds = imageDatastore(location, ...
        'IncludeSubfolders', true, ...
        'LabelSource', 'foldernames');
end
function data = readNumberLabels(filename)
% readNumberLabels Read per-puzzle digit labels from a text file.
%
%   data = sudoku.training.readNumberLabels(filename)
%
%   Args:
%       filename (char): path to the label file. Each record is a
%                        whitespace-delimited puzzle name followed by 81
%                        integers.
%
%   Returns:
%       data (containers.Map): maps each puzzle name to a 9x9 integer
%                              matrix. The 81 values of a record fill the
%                              matrix row by row.
%
%   Raises:
%       sudoku:FileNotFound   if the file cannot be opened
%       sudoku:BadNumberData  if a record has fewer than 81 numbers
%       sudoku:DuplicateLabel if a puzzle name appears more than once

% Copyright 2018, The MathWorks, Inc.

    [fid, reason] = fopen(filename, 'r');
    assert(fid ~= -1, ...
        'sudoku:FileNotFound', ...
        'Unable to open number label file ''%s'': %s', filename, reason);
    % close the file on every exit path, including failed assertions below
    closeFile = onCleanup(@() fclose(fid));

    data = containers.Map();

    while ~feof(fid)
        nameCell = textscan(fid, '%s', 1);
        if isempty(nameCell{1})
            % only trailing whitespace was left before end-of-file
            continue
        end
        newName = nameCell{1}{1};

        numbers = textscan(fid, '%d', 81);
        assert(numel(numbers{1}) == 81, ...
            'sudoku:BadNumberData', ...
            'Number labels for ''%s'' are of incorrect size.', newName);

        assert(~isKey(data, newName), ...
            'sudoku:DuplicateLabel', ...
            'The label ''%s'' is duplicated.', newName);

        % reshape fills column-wise, so transpose to lay the values out
        % in the row order in which they appear in the file
        data(newName) = reshape(numbers{1}, 9, 9)';
    end
end
14 | convolution2dLayer([3, 3], n, 'Padding', 'same') 15 | reluLayer 16 | batchNormalizationLayer 17 | maxPooling2dLayer([2, 2], 'Stride', [2, 2], 'Padding', 'same') 18 | ... 19 | convolution2dLayer([3, 3], 2*n, 'Padding', 'same') 20 | reluLayer 21 | batchNormalizationLayer 22 | ... 23 | convolution2dLayer([3, 3], 2*n, 'Padding', 'same') 24 | reluLayer 25 | batchNormalizationLayer 26 | maxPooling2dLayer([2, 2], 'Stride', [2, 2], 'Padding', 'same') 27 | ... 28 | convolution2dLayer([3, 3], 4*n, 'Padding', 'same') 29 | reluLayer 30 | batchNormalizationLayer 31 | ... 32 | convolution2dLayer([3, 3], 4*n, 'Padding', 'same') 33 | reluLayer 34 | batchNormalizationLayer 35 | maxPooling2dLayer([2, 2], 'Stride', [2, 2], 'Padding', 'same') 36 | ... 37 | convolution2dLayer([3, 3], 8*n, 'Padding', 'same') 38 | reluLayer 39 | batchNormalizationLayer 40 | ... 41 | convolution2dLayer([3, 3], 8*n, 'Padding', 'same') 42 | reluLayer 43 | batchNormalizationLayer 44 | ... 45 | convolution2dLayer([3, 3], 8*n, 'Padding', 'same') 46 | reluLayer 47 | batchNormalizationLayer 48 | maxPooling2dLayer([2, 2], 'Stride', [2, 2], 'Padding', 'same') 49 | ... 50 | fullyConnectedLayer(8*n) 51 | dropoutLayer(0.5) 52 | fullyConnectedLayer(10) 53 | softmaxLayer 54 | classificationLayer]; 55 | end -------------------------------------------------------------------------------- /src/sudoku/+sudoku/+training/weightLossByFrequency.m: -------------------------------------------------------------------------------- 1 | function lgraph = weightLossByFrequency(lgraph, pixelImageDataStore) 2 | 3 | % Copyright 2018, The MathWorks, Inc. 4 | 5 | pixelCounts = countEachLabel(pixelImageDataStore); 6 | imageFreq = pixelCounts.PixelCount./pixelCounts.ImagePixelCount; 7 | classWeights = mean(imageFreq)./imageFreq; 8 | pxLayer = pixelClassificationLayer('Name', 'labels', ... 9 | 'ClassNames', pixelCounts.Name, ... 
classdef PuzzleSolver < handle
    % PuzzleSolver   Solves sudoku puzzles from uncontrolled images.
    %
    %   The solver chains four steps: find the puzzle (segmentation
    %   network plus mask clean-up), extract the 81 number boxes, classify
    %   each box, and solve the resulting grid.

    % Copyright 2018, The MathWorks, Inc.


    properties
        SegmentationNetwork         % network loaded from the segmentation .mat file
        SegmentationInputSize       % InputSize of that network's first layer
        ClassificationNetwork       % network loaded from the number .mat file
        ClassificationInputSize     % InputSize of that network's first layer
    end

    methods
        function obj = PuzzleSolver(segmentationNetworkPath, numberNetworkPath)
            % PuzzleSolver   Constructor.
            %
            %   solver = sudoku.PuzzleSolver();
            %   solver = sudoku.PuzzleSolver(segmentationNetworkPath);
            %   solver = sudoku.PuzzleSolver(segmentationNetworkPath, ...
            %                                   numberNetworkPath);
            %
            %   Args:
            %       segmentationNetworkPath:    Path to a saved
            %                                   segmentation network
            %       numberNetworkPath:          Path to a saved
            %                                   classification network
            %
            %   Notes:
            %       The paths to networks should point to .mat files which
            %       contain an appropriate network with the variable name
            %       'net'.
            %
            %       Any path that is omitted (or passed as empty) falls
            %       back to the default file name in a 'models' directory
            %       in the current working directory.
            %
            %   See also: sudoku.trainSemanticSegmentation,
            %   sudoku.trainNumberNetwork

            % Default each path independently so that supplying only the
            % segmentation network does not leave the number path undefined.
            if nargin < 1 || isempty(segmentationNetworkPath)
                segmentationNetworkPath = fullfile('models', 'segmentation_network.mat');
            end
            if nargin < 2 || isempty(numberNetworkPath)
                numberNetworkPath = fullfile('models', 'number_network.mat');
            end

            obj.loadNetworks(segmentationNetworkPath, numberNetworkPath);
        end

        function [solution, intermediateOutputs] = process(obj, im)
            % process   Solves a sudoku from an input image.
            %
            %   [solution, intermediateOutputs] = solver.process(im);
            %
            %   Args:
            %       im: input image of a sudoku puzzle
            %
            %   Returns:
            %       solution (9x9 double):      solved puzzle
            %       intermediateOutputs (cell): outputs from intermediate
            %                                   steps in the algorithm, in
            %                                   order: network mask,
            %                                   post-processed mask,
            %                                   thresholded image, final
            %                                   mask, intersections, number
            %                                   images, number data.


            [mask, findSteps] = obj.findPuzzle(im);
            [numberImages, extractSteps] = obj.extractNumbers(im, mask);
            numberData = obj.readNumbers(numberImages);
            solution = obj.solve(numberData);

            intermediateOutputs = [findSteps, ...
                                    {mask}, ...
                                    extractSteps, ...
                                    {numberImages}, ...
                                    {numberData}];
        end

        function [mask, intermediateOutputs] = findPuzzle(obj, im)
            % findPuzzle    Find the sudoku puzzle in an image
            %
            %   Step 1 - find the puzzle using deep learning
            %
            %   Returns the final puzzle mask and a cell array holding the
            %   raw network mask, the post-processed mask and the
            %   thresholded image.

            networkMask = sudoku.segmentPuzzle(im, obj.SegmentationNetwork);
            maskPostProcessed = sudoku.postProcessMask(networkMask);
            thresholdedImage = sudoku.thresholdImage(im, maskPostProcessed);
            mask = sudoku.findPrimaryRegion(thresholdedImage);

            intermediateOutputs = {networkMask, maskPostProcessed, thresholdedImage};
        end

        function [numberImages, intermediateOutputs] = extractNumbers(obj, im, mask)
            % extractNumbers    Extract the 81 number boxes from the puzzle
            %
            %   Step 2 - find the number boxes using image processing
            %
            %   Each box is cut to the side length expected by the
            %   classification network. The corner intersections are
            %   returned as the only intermediate output.

            intersections = sudoku.intersectionsFromLabel(mask);
            numberSize = obj.ClassificationInputSize(1);
            numberImages = sudoku.extractNumbers(im, intersections, numberSize);

            intermediateOutputs = {intersections};
        end

        function numberData = readNumbers(obj, numberImages)
            % readNumbers   Read the contents of the number boxes
            %
            %   Step 3 - read the numbers using deep learning

            % stack the boxes along the 4th dimension to classify them in
            % a single batch, then convert the class labels to numbers
            numbers = obj.ClassificationNetwork.classify(cat(4, numberImages{:}));
            numberData = str2double(string(numbers));
        end

        function solution = solve(~, numberData)
            % solve     Solve a sudoku puzzle
            %
            %   Step 4 - solve the puzzle using optimisation

            solution = sudoku.sudokuEngine(reshape(numberData, [9, 9]));
        end

        function loadNetworks(obj, segmentationNetworkPath, numberNetworkPath)
            % loadNetworks  Load both networks from .mat files.
            %
            %   Each file must contain a variable named 'net'. The input
            %   size of each network is read from its first layer.

            % load the segmentation network
            loadedData = load(segmentationNetworkPath, 'net');
            obj.SegmentationNetwork = loadedData.net;
            obj.SegmentationInputSize = obj.SegmentationNetwork.Layers(1).InputSize;

            % load the number network
            loadedData = load(numberNetworkPath, 'net');
            obj.ClassificationNetwork = loadedData.net;
            obj.ClassificationInputSize = obj.ClassificationNetwork.Layers(1).InputSize;
        end
    end
end
function mask = findPrimaryRegion(thresholdedImage)
% findPrimaryRegion     Finds the largest region in a thresholded image.
%
%   mask = findPrimaryRegion(thresholdedImage);
%
%   Args:
%       thresholdedImage:   binary image for analysis
%
%   Returns:
%       mask:   double array of the same size as the input, set to 1 over
%               the filled (and eroded by a 3x3 element) area of the
%               region with the largest filled area, 0 elsewhere.
%
%   Raises:
%       sudoku:findPrimaryRegion:noRegion if the image contains no regions

% Copyright 2018, The MathWorks, Inc.

    % Analyse the regions; columns are addressed by name so the result
    % does not depend on the order in which regionprops lays them out
    regions = regionprops('table', thresholdedImage, ...
        'FilledArea', 'FilledImage', 'BoundingBox');

    if height(regions) == 0
        error('sudoku:findPrimaryRegion:noRegion', ...
            'No regions were found in the thresholded image.');
    end

    % Stable ascending sort, last entry = largest filled area
    [~, order] = sort(regions.FilledArea);
    largest = order(end);

    filledImages = regions.FilledImage;
    if iscell(filledImages)
        filledRegion = filledImages{largest};
    else
        filledRegion = filledImages;
    end
    filledRegion = imerode(filledRegion, ones(3));
    boundingBox = regions.BoundingBox(largest, :);

    % Generate the new mask. The bounding box is [x, y, width, height]
    % with its corner on the pixel edge, hence the ceil/floor pairing.
    rows = ceil(boundingBox(2)):floor(boundingBox(2) + boundingBox(4));
    cols = ceil(boundingBox(1)):floor(boundingBox(1) + boundingBox(3));
    mask = zeros(size(thresholdedImage));
    mask(rows, cols) = filledRegion;
end
function intersection = intersect(line1, line2)
% intersect     Calculate intersection of two lines.
%
%   intersection = intersect(line1, line2)
%
%   Args:
%       line1:  1x2 array defining line to intersect
%       line2:  1x2 array defining line to intersect
%
%   Returns:
%       intersection: 1x2 array defining the intersection co-ordinate.
%
%   Note:
%       Each line should be in the format [rho (pixels), theta (rads)],
%       i.e. the set of points with x*cos(theta) + y*sin(theta) = rho.
%
%       If the lines are parallel or coincident (the solution is not
%       finite) a 'sudoku:intersect:noIntersection' exception will be
%       raised.
%
%   See also: sudoku.getMapLines

% Copyright 2018, The MathWorks, Inc.

    % Silence the singular-matrix warning for the solve below and put the
    % caller's previous setting back afterwards (rather than forcing 'on')
    previousState = warning('off', 'MATLAB:singularMatrix');
    restoreWarning = onCleanup(@() warning(previousState));

    thetas = [line1(2); line2(2)];
    M = [cos(thetas), sin(thetas)];
    b = [line1(1); line2(1)];
    intersection = (M\b)';

    % A singular system can yield Inf or NaN depending on the inputs
    if ~all(isfinite(intersection))
        error('sudoku:intersect:noIntersection', ...
            'Infinte point, lines may be parallel');
    end
end
function maskDilated = postProcessMask(mask)
% postProcessMask   Clean up a segmentation mask.
%
%   maskDilated = sudoku.postProcessMask(mask)
%
%   Regions touching the image border are cleared, the result is eroded
%   with a square element and then dilated with a disk. Both element
%   sizes scale with the shorter side of the mask.

% Copyright 2018-2019 The MathWorks, Inc.

    shortestSide = min(size(mask));
    erosionElement = ones(ceil(shortestSide/150));
    dilationElement = strel('disk', ceil(shortestSide/20));

    interior = imclearborder(mask);
    eroded = imerode(interior, erosionElement);
    maskDilated = imdilate(eroded, dilationElement);
end
function sortedIntersections = selectAndSort(intersections)
% selectAndSort     Select the most feasible intersections.
%
%   sortedIntersections = sudoku.selectAndSort(intersections)
%
%   Args:
%       intersections: nx2 array of intersection points
%
%   Returns:
%       sortedIntersections: mx2 array of the retained points
%
%   Note:
%       Points containing NaN or a negative co-ordinate are discarded, as
%       are points lying more than twice the median distance from the
%       median point.
%
%       The survivors are ordered by ascending atan2 angle about their
%       mean point. With the y axis pointing up this is anti-clockwise;
%       in image co-ordinates (y pointing down) it appears clockwise.
%
%   See also: sudoku.intersect, sudoku.intersectAll

% Copyright 2018, The MathWorks, Inc.

    % Keep only finite-valued, non-negative points
    hasNan = any(isnan(intersections), 2);
    hasNegative = any(intersections < 0, 2);
    candidates = intersections(~(hasNegative | hasNan), :);

    % Keep only points reasonably close to the median point
    offsets = candidates - median(candidates);
    radii = sqrt(sum(offsets.^2, 2));
    isNear = ~(radii > 2*median(radii));
    candidates = candidates(isNear, :);

    % Order by angle about the centroid
    centred = candidates - mean(candidates);
    [~, order] = sort(atan2(centred(:, 2), centred(:, 1)));
    sortedIntersections = candidates(order, :);
end
If B is a 10 | % 9-by-9 matrix, the function first converts it to 3-column form. 11 | 12 | % Copyright 2014 The MathWorks, Inc. 13 | 14 | if isequal(size(B),[9,9]) % 9-by-9 clues 15 | % Convert to 81-by-3 16 | [SM,SN] = meshgrid(1:9); % make i,j entries 17 | B = [SN(:),SM(:),B(:)]; % i,j,k rows 18 | % Now delete zero rows 19 | [rrem,~] = find(B(:,3) == 0); 20 | B(rrem,:) = []; 21 | end 22 | 23 | if size(B,2) ~= 3 || length(size(B)) > 2 24 | error('The input matrix must be N-by-3 or 9-by-9') 25 | end 26 | 27 | if sum([any(B ~= round(B)),any(B < 1),any(B > 9)]) % enforces entries 1-9 28 | error('Entries must be integers from 1 to 9') 29 | end 30 | 31 | %% The rules of Sudoku: 32 | N = 9^3; % number of independent variables in x, a 9-by-9-by-9 array 33 | M = 4*9^2; % number of constraints, see the construction of Aeq 34 | Aeq = zeros(M,N); % allocate equality constraint matrix Aeq*x = beq 35 | beq = ones(M,1); % allocate constant vector beq 36 | f = (1:N)'; % the objective can be anything, but having nonconstant f can speed the solver 37 | lb = zeros(9,9,9); % an initial zero array 38 | ub = lb+1; % upper bound array to give binary variables 39 | 40 | counter = 1; 41 | for j = 1:9 % one in each row 42 | for k = 1:9 43 | Astuff = lb; % clear Astuff 44 | Astuff(1:end,j,k) = 1; % one row in Aeq*x = beq 45 | Aeq(counter,:) = Astuff(:)'; % put Astuff in a row of Aeq 46 | counter = counter + 1; 47 | end 48 | end 49 | 50 | for i = 1:9 % one in each column 51 | for k = 1:9 52 | Astuff = lb; 53 | Astuff(i,1:end,k) = 1; 54 | Aeq(counter,:) = Astuff(:)'; 55 | counter = counter + 1; 56 | end 57 | end 58 | 59 | for U = 0:3:6 % one in each square 60 | for V = 0:3:6 61 | for k = 1:9 62 | Astuff = lb; 63 | Astuff(U+(1:3),V+(1:3),k) = 1; 64 | Aeq(counter,:) = Astuff(:)'; 65 | counter = counter + 1; 66 | end 67 | end 68 | end 69 | 70 | for i = 1:9 % one in each depth 71 | for j = 1:9 72 | Astuff = lb; 73 | Astuff(i,j,1:end) = 1; 74 | Aeq(counter,:) = Astuff(:)'; 75 | counter = counter 
function thresholded = thresholdImage(im, mask)
% thresholdImage    Applies a binary threshold to an image.
%
%   thresholded = thresholdImage(im, mask)
%
%   Args:
%       im:     image to threshold, either RGB (m-by-n-by-3) or already
%               single channel (m-by-n)
%       mask:   existing binary mask
%
%   Returns:
%       thresholded:    output binary mask, true where the image is darker
%                       than its adaptive threshold and mask is true.

% Copyright 2018, The MathWorks, Inc.

    % rgb2gray only accepts true-colour input, so pass single-channel
    % images straight through
    if ndims(im) == 3 && size(im, 3) == 3
        imGray = rgb2gray(im);
    else
        imGray = im;
    end

    bw = imbinarize(imGray, 'adaptive', ...
        'Sensitivity', 0.7);
    % keep the dark pixels that fall inside the existing mask
    thresholded = ~bw & mask;

end
45 | 'MiniBatchSize', 64); 46 | 47 | %% Setup the network 48 | layers = sudoku.training.vggLike(initialChannels, imageSize); 49 | 50 | %% Train 51 | net = trainNetwork(train, layers, options); 52 | 53 | %% Save the output model 54 | if ~isfolder(modelDirectory) 55 | mkdir(modelDirectory) 56 | end 57 | outputFile = fullfile(modelDirectory, outputName); 58 | save(outputFile); 59 | 60 | end 61 | -------------------------------------------------------------------------------- /src/sudoku/+sudoku/trainSemanticSegmentation.m: -------------------------------------------------------------------------------- 1 | function net = trainSemanticSegmentation(outputName, checkpoints) 2 | % trainSemanticSegmentation Train the sudoku segmentation network. 3 | % 4 | % net = sudoku.trainSemanticSegmentation(); 5 | % net = sudoku.trainSemanticSegmentation(outputName); 6 | % net = sudoku.trainSemanticSegmentation(outputName, checkpoints); 7 | % 8 | % Args: 9 | % outputName (optional): Name under which to save the network 10 | % checkpoints (optional boolean): Whether to save training 11 | % checkpoints 12 | % 13 | % Returns: 14 | % net: Trained neural network 15 | % 16 | % Note: 17 | % The output model will be saved in a 'models' folder in the current 18 | % working directory. If no name is provided a default of 19 | % 'segmentation_network.mat' will be used. 20 | 21 | % Copyright 2018, The MathWorks, Inc. 22 | 23 | if nargin < 2 24 | checkpoints = false; 25 | end 26 | 27 | if nargin < 1 28 | outputName = 'segmentation_network'; 29 | end 30 | 31 | %% Parameters 32 | modelDirectory = 'models'; 33 | inputSize = [512, 512, 3]; 34 | numClasses = 2; 35 | networkDepth = 'vgg16'; 36 | trainFraction = 0.7; 37 | 38 | %% Get the training data 39 | [imagesTrain, labelsTrain, imagesTest, labelsTest] = sudoku.training.getSudokuData(trainFraction, false); 40 | 41 | augmenter = imageDataAugmenter( ... 42 | 'RandXReflection',false, ... 43 | 'RandYReflection',false, ... 
44 | 'RandRotation', [15, 15], ... 45 | 'RandXScale', [0.75, 1.25], ... 46 | 'RandYScale', [0.75, 1.25], ... 47 | 'RandXTranslation', [-100, 100], ... 48 | 'RandYTranslation', [-100, 100]); 49 | 50 | train = pixelLabelImageDatastore(imagesTrain, labelsTrain, ... 51 | 'OutputSize', inputSize(1:2), ... 52 | 'DataAugmentation', augmenter); 53 | test = pixelLabelImageDatastore(imagesTest, labelsTest, ... 54 | 'OutputSize', inputSize(1:2)); 55 | 56 | %% Setup the network 57 | layers = segnetLayers(inputSize, numClasses, networkDepth); 58 | layers = sudoku.training.weightLossByFrequency(layers, train); 59 | 60 | %% Set up the training options 61 | if checkpoints 62 | checkpointPath = fullfile('checkpoints', outputName); 63 | if ~isfolder(checkpointPath) 64 | mkdir(checkpointPath) 65 | end 66 | else 67 | checkpointPath = ''; 68 | end 69 | 70 | opts = trainingOptions('sgdm', ... 71 | 'InitialLearnRate', 0.005, ... 72 | 'LearnRateDropFactor', 0.1, ... 73 | 'LearnRateDropPeriod', 20, ... 74 | 'LearnRateSchedule', 'piecewise', ... 75 | 'ValidationData', test, ... 76 | 'ValidationPatience', Inf, ... 77 | 'MaxEpochs', 40, ... 78 | 'MiniBatchSize', 2, ... 79 | 'Shuffle', 'every-epoch', ... 80 | 'Plots', 'training-progress', ... 81 | 'CheckpointPath', checkpointPath); 82 | 83 | %% Train 84 | net = trainNetwork(train, layers, opts); 85 | 86 | %% Save 87 | if ~isfolder(modelDirectory) 88 | mkdir(modelDirectory) 89 | end 90 | outputFile = fullfile(modelDirectory, outputName); 91 | save(outputFile); 92 | end 93 | -------------------------------------------------------------------------------- /src/sudoku/+sudoku/undistort.m: -------------------------------------------------------------------------------- 1 | function outputImage = undistort(im, imagePoints, worldPoints) 2 | % undistort Undistort an image. 
3 | % 4 | % outputImage = undistort(im, imagePoints, worldPoints) 5 | % 6 | % Args: 7 | % im: image for undistortion 8 | % imagePoints: nx2 array of points in the image 9 | % worldPoints: nx2 array of corresponding world points 10 | % 11 | % Returns: 12 | % outputImage: rectified image 13 | 14 | % Copyright 2018, The MathWorks, Inc. 15 | 16 | 17 | transform = fitgeotrans(imagePoints, worldPoints, 'projective'); 18 | outputImage = imwarp(im, transform, ... 19 | 'Interpolation', 'cubic', ... 20 | 'OutputView', imref2d(max(worldPoints))); 21 | 22 | end -------------------------------------------------------------------------------- /src/sudoku/+sudoku/visualiseClassifier.m: -------------------------------------------------------------------------------- 1 | function visualiseClassifier(net) 2 | % visualiseClassifier Generate tsne based visualisation of classifier. 3 | % 4 | % sudoku.visualiseClassifier(net); 5 | % 6 | % Args: 7 | % net: Classifier network 8 | 9 | % Copyright 2018, The MathWorks, Inc. 10 | 11 | % Parameters 12 | perplexity = 25; 13 | canvasScale = 15; 14 | imageSize = 10; 15 | 16 | % Get test data 17 | [~, test] = sudoku.training.getNumberData(0, false); 18 | 19 | % Get the activations 20 | X = activations(net, test, 'fc_1', 'OutputAs', 'channels'); 21 | predictions = net.classify(test); 22 | X = squeeze(X); 23 | Y = tsne(X', 'Perplexity', perplexity); 24 | 25 | % Generate the figure 26 | figure 27 | hold on 28 | cmap = colormap('hsv'); 29 | cmap = cmap(7:end-6, :); 30 | 31 | for iNum = 1:numel(test.Labels) 32 | im = test.readimage(iNum); 33 | label = test.Labels(iNum); 34 | num = str2double(char(label)); 35 | index = ceil(size(cmap, 1)*(num+1)/10); 36 | x = canvasScale*Y(iNum,1); 37 | y = canvasScale*Y(iNum,2); 38 | image('XData', [x, x+imageSize], ... 39 | 'YData', [y, y-imageSize], ... 
40 | 'CData', im) 41 | 42 | if label ~= predictions(iNum) 43 | edgeColor = 'r'; 44 | else 45 | edgeColor = cmap(index, :); 46 | end 47 | 48 | rectangle('Position', [x, y-imageSize, imageSize, imageSize], ... 49 | 'EdgeColor', edgeColor, ... 50 | 'LineWidth', 2); 51 | 52 | end 53 | grid on 54 | axis equal 55 | set(gca,'YColor','none') 56 | set(gca,'XColor','none') 57 | currentXLim = get(gca,'XLim'); 58 | xlim([currentXLim(1) - 100, currentXLim(2) + 100]) 59 | end -------------------------------------------------------------------------------- /src/sudoku/sudokuRoot.m: -------------------------------------------------------------------------------- 1 | function thisPath = sudokuRoot() 2 | 3 | % Copyright 2018 The MathWorks, Inc. 4 | 5 | thisPath = fileparts(mfilename('fullpath')); 6 | end 7 | -------------------------------------------------------------------------------- /uninstall.m: -------------------------------------------------------------------------------- 1 | function uninstall() 2 | 3 | % Copyright 2018 The MathWorks, Inc. 4 | 5 | thisPath = fileparts(mfilename('fullpath')); 6 | rmpath(fullfile(thisPath, 'src', 'sudoku')); 7 | 8 | end --------------------------------------------------------------------------------