├── .gitignore ├── DigitsRecogniser ├── __init__.py ├── models │ ├── CNN.h5 │ ├── Softmax.pkl │ └── XGBOOST.bin └── recognise_digits.py ├── ExampleImages ├── image1.jpg ├── image2.jpg ├── image3.jpg ├── image4.jpg ├── image5.jpg ├── image6.jpg ├── image7.jpg └── image8.jpg ├── GridSolver ├── SudokuSolve.py ├── setup.py ├── sudokuGen.cpp ├── sudokuSolve.i ├── sudokuSolve_wrap.cpp └── svgHead.txt ├── Pipeline.png ├── Project_Report.pdf ├── PuzzleExtractor ├── __init__.py ├── digit_extraction.py ├── grid_extraction.py └── processing.py ├── README.md ├── Testing ├── CNN_DigitClassify.ipynb ├── DigitsClassifyTesting.ipynb ├── README.md ├── SudokuTest.ipynb ├── X_train.csv └── y_train.csv ├── __init__.py ├── app.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /DigitsRecogniser/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vaithak/Sudoku-Image-Solver/5d01d6e6c3921251c1f38050aa3e3e85610d66fc/DigitsRecogniser/__init__.py -------------------------------------------------------------------------------- /DigitsRecogniser/models/CNN.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vaithak/Sudoku-Image-Solver/5d01d6e6c3921251c1f38050aa3e3e85610d66fc/DigitsRecogniser/models/CNN.h5 -------------------------------------------------------------------------------- /DigitsRecogniser/models/Softmax.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vaithak/Sudoku-Image-Solver/5d01d6e6c3921251c1f38050aa3e3e85610d66fc/DigitsRecogniser/models/Softmax.pkl -------------------------------------------------------------------------------- /DigitsRecogniser/models/XGBOOST.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vaithak/Sudoku-Image-Solver/5d01d6e6c3921251c1f38050aa3e3e85610d66fc/DigitsRecogniser/models/XGBOOST.bin -------------------------------------------------------------------------------- /DigitsRecogniser/recognise_digits.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from sklearn.metrics import accuracy_score 4 | from sklearn.pipeline import make_pipeline 5 | from sklearn.preprocessing import StandardScaler 6 | import joblib 7 | import os 8 | import xgboost as xgb 9 | from tensorflow.keras.models import Sequential, load_model 10 | 11 | models = ["CNN", "XGBOOST", "Softmax", "RandomForest", "GNB"] 12 | files = ["CNN.h5", "XGBOOST.bin", "Softmax.pkl", "RandomForest.pkl", "GNB.pkl"] 13 | 14 | # Digits array contains digits of grid in column major order 15 | def predictDigits(digits: np.ndarray, model: int): 16 | assert (len(digits) == 81) and (model>=0) and (model<=4) 17 | 18 | res = "" 19 | for i in range(9): 20 | for j in range(9): 21 | digit = predictDigit(digits[j*9 + i], model) 22 | res += str(digit) 23 | 24 | return res 25 | #return "009000780830019000610000403001900027000040000590008300905000072000590048082000900" 26 | 27 | 28 | def preprocess(x_vec): 29 | x_vec = x_vec.astype(np.uint8) 30 | x_vec = 255 - x_vec 31 | hog = cv2.HOGDescriptor((28, 28), (14, 14), (7, 7), (14, 14), 12) 32 | return hog.compute(x_vec).reshape(1, -1) 33 | 34 | # Take decision based on probability vector of each class 35 | # cost_r: Cost of rejection (In our case we will mark it as empty or no digit = 0 in Sudoku) 36 | # cost_w: Cost of wrong classification (In our case we will mark it as empty or no digit = 0 in Sudoku) 37 | def 
take_decision(probabilities, cost_r=10, cost_w=30): 38 | assert cost_w != 0 39 | 40 | # Reference: https://www.cs.ubc.ca/~murphyk/Teaching/CS340-Fall07/dtheory.pdf 41 | pred_class = np.argmax(probabilities) 42 | if(probabilities[pred_class] > (1 - (cost_r/cost_w))): 43 | return pred_class 44 | 45 | # reject => No digit => 0 for our case 46 | return 0 47 | 48 | 49 | def preprocess_for_CNN(digit: np.ndarray): 50 | digit = 255 - digit 51 | digit = digit/255 52 | digit = digit.reshape((1, 28, 28, 1)) 53 | return digit 54 | 55 | def predictDigit(digit: np.ndarray, model): 56 | # Less than 10 pixels coloured 57 | if np.sum(digit) < 10*255: 58 | return 0 59 | 60 | if models[model]=="XGBOOST": 61 | clf = xgb.XGBClassifier(objective="multi:softmax", booster="gbtree", num_classes=10, ) 62 | clf.load_model("DigitsRecogniser/models/" + files[model]) 63 | prob = clf.predict_proba(preprocess(digit)) 64 | prob = np.array([x[1] for x in prob]) 65 | elif models[model]=="CNN": 66 | clf = load_model("DigitsRecogniser/models/" + files[model]) 67 | processed_digit = preprocess_for_CNN(digit) 68 | prob = clf.predict(processed_digit) 69 | prob = prob[0] 70 | else: 71 | clf = joblib.load("DigitsRecogniser/models/" + files[model]) 72 | prob = clf.predict_proba(preprocess(digit)) 73 | prob = prob[0] 74 | 75 | return take_decision(prob, 8, 10) 76 | -------------------------------------------------------------------------------- /ExampleImages/image1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vaithak/Sudoku-Image-Solver/5d01d6e6c3921251c1f38050aa3e3e85610d66fc/ExampleImages/image1.jpg -------------------------------------------------------------------------------- /ExampleImages/image2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vaithak/Sudoku-Image-Solver/5d01d6e6c3921251c1f38050aa3e3e85610d66fc/ExampleImages/image2.jpg -------------------------------------------------------------------------------- /ExampleImages/image3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vaithak/Sudoku-Image-Solver/5d01d6e6c3921251c1f38050aa3e3e85610d66fc/ExampleImages/image3.jpg -------------------------------------------------------------------------------- /ExampleImages/image4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vaithak/Sudoku-Image-Solver/5d01d6e6c3921251c1f38050aa3e3e85610d66fc/ExampleImages/image4.jpg -------------------------------------------------------------------------------- /ExampleImages/image5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vaithak/Sudoku-Image-Solver/5d01d6e6c3921251c1f38050aa3e3e85610d66fc/ExampleImages/image5.jpg -------------------------------------------------------------------------------- /ExampleImages/image6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vaithak/Sudoku-Image-Solver/5d01d6e6c3921251c1f38050aa3e3e85610d66fc/ExampleImages/image6.jpg -------------------------------------------------------------------------------- /ExampleImages/image7.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vaithak/Sudoku-Image-Solver/5d01d6e6c3921251c1f38050aa3e3e85610d66fc/ExampleImages/image7.jpg -------------------------------------------------------------------------------- /ExampleImages/image8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vaithak/Sudoku-Image-Solver/5d01d6e6c3921251c1f38050aa3e3e85610d66fc/ExampleImages/image8.jpg -------------------------------------------------------------------------------- /GridSolver/SudokuSolve.py: -------------------------------------------------------------------------------- 1 | # This file was automatically generated by SWIG (http://www.swig.org). 2 | # Version 3.0.8 3 | # 4 | # Do not make changes to this file unless you know what you are doing--modify 5 | # the SWIG interface file instead. 6 | 7 | 8 | 9 | 10 | 11 | from sys import version_info 12 | if version_info >= (2, 6, 0): 13 | def swig_import_helper(): 14 | from os.path import dirname 15 | import imp 16 | fp = None 17 | try: 18 | fp, pathname, description = imp.find_module('_SudokuSolve', [dirname(__file__)]) 19 | except ImportError: 20 | import _SudokuSolve 21 | return _SudokuSolve 22 | if fp is not None: 23 | try: 24 | _mod = imp.load_module('_SudokuSolve', fp, pathname, description) 25 | finally: 26 | fp.close() 27 | return _mod 28 | _SudokuSolve = swig_import_helper() 29 | del swig_import_helper 30 | else: 31 | import _SudokuSolve 32 | del version_info 33 | try: 34 | _swig_property = property 35 | except NameError: 36 | pass # Python < 2.2 doesn't have 'property'. 37 | 38 | 39 | def _swig_setattr_nondynamic(self, class_type, name, value, static=1): 40 | if (name == "thisown"): 41 | return self.this.own(value) 42 | if (name == "this"): 43 | if type(value).__name__ == 'SwigPyObject': 44 | self.__dict__[name] = value 45 | return 46 | method = class_type.__swig_setmethods__.get(name, None) 47 | if method: 48 | return method(self, value) 49 | if (not static): 50 | if _newclass: 51 | object.__setattr__(self, name, value) 52 | else: 53 | self.__dict__[name] = value 54 | else: 55 | raise AttributeError("You cannot add attributes to %s" % self) 56 | 57 | 58 | def _swig_setattr(self, class_type, name, value): 59 | return _swig_setattr_nondynamic(self, class_type, name, value, 0) 60 | 61 | 62 | def _swig_getattr_nondynamic(self, class_type, name, static=1): 63 | if (name == "thisown"): 64 | return self.this.own() 65 | method = class_type.__swig_getmethods__.get(name, None) 66 | if method: 67 | return method(self) 68 | if (not static): 69 | return object.__getattr__(self, name) 70 | else: 71 | raise AttributeError(name) 72 | 73 | def _swig_getattr(self, class_type, name): 74 | return _swig_getattr_nondynamic(self, class_type, name, 0) 75 | 76 | 77 | def _swig_repr(self): 78 | try: 79 | strthis = "proxy of " + self.this.__repr__() 80 | except Exception: 81 | strthis = "" 82 | return "<%s.%s; %s >" % (self.__class__.__module__, self.__class__.__name__, strthis,) 83 | 84 | try: 85 | _object = object 86 | _newclass = 1 87 | except AttributeError: 88 | class _object: 89 | pass 90 | _newclass = 0 91 | 92 | 93 | 94 | _SudokuSolve.UNASSIGNED_swigconstant(_SudokuSolve) 95 | UNASSIGNED = _SudokuSolve.UNASSIGNED 96 | class Sudoku(_object): 97 | __swig_setmethods__ = {} 98 | __setattr__ = lambda self, name, value: _swig_setattr(self, Sudoku, name, value) 99 | __swig_getmethods__ = {} 100 | __getattr__ = lambda self, name: _swig_getattr(self, Sudoku, name) 101 | __repr__ = _swig_repr 102 | 103 | def 
__init__(self, *args): 104 | this = _SudokuSolve.new_Sudoku(*args) 105 | try: 106 | self.this.append(this) 107 | except Exception: 108 | self.this = this 109 | 110 | def createSeed(self): 111 | return _SudokuSolve.Sudoku_createSeed(self) 112 | 113 | def printGrid(self): 114 | return _SudokuSolve.Sudoku_printGrid(self) 115 | 116 | def solveGrid(self): 117 | return _SudokuSolve.Sudoku_solveGrid(self) 118 | 119 | def getGrid(self): 120 | return _SudokuSolve.Sudoku_getGrid(self) 121 | 122 | def countSoln(self, number): 123 | return _SudokuSolve.Sudoku_countSoln(self, number) 124 | 125 | def genPuzzle(self): 126 | return _SudokuSolve.Sudoku_genPuzzle(self) 127 | 128 | def verifyGridStatus(self): 129 | return _SudokuSolve.Sudoku_verifyGridStatus(self) 130 | 131 | def printSVG(self, arg2): 132 | return _SudokuSolve.Sudoku_printSVG(self, arg2) 133 | 134 | def calculateDifficulty(self): 135 | return _SudokuSolve.Sudoku_calculateDifficulty(self) 136 | 137 | def branchDifficultyScore(self): 138 | return _SudokuSolve.Sudoku_branchDifficultyScore(self) 139 | __swig_destroy__ = _SudokuSolve.delete_Sudoku 140 | __del__ = lambda self: None 141 | Sudoku_swigregister = _SudokuSolve.Sudoku_swigregister 142 | Sudoku_swigregister(Sudoku) 143 | 144 | 145 | def genRandNum(maxLimit): 146 | return _SudokuSolve.genRandNum(maxLimit) 147 | genRandNum = _SudokuSolve.genRandNum 148 | 149 | def FindUnassignedLocation(grid, row, col): 150 | return _SudokuSolve.FindUnassignedLocation(grid, row, col) 151 | FindUnassignedLocation = _SudokuSolve.FindUnassignedLocation 152 | 153 | def UsedInRow(grid, row, num): 154 | return _SudokuSolve.UsedInRow(grid, row, num) 155 | UsedInRow = _SudokuSolve.UsedInRow 156 | 157 | def UsedInCol(grid, col, num): 158 | return _SudokuSolve.UsedInCol(grid, col, num) 159 | UsedInCol = _SudokuSolve.UsedInCol 160 | 161 | def UsedInBox(grid, boxStartRow, boxStartCol, num): 162 | return _SudokuSolve.UsedInBox(grid, boxStartRow, boxStartCol, num) 163 | UsedInBox = _SudokuSolve.UsedInBox 164 | 165 | def isSafe(grid, row, col, num): 166 | return _SudokuSolve.isSafe(grid, row, col, num) 167 | isSafe = _SudokuSolve.isSafe 168 | 169 | def main(argc, argv): 170 | return _SudokuSolve.main(argc, argv) 171 | main = _SudokuSolve.main 172 | # This file is compatible with both classic and new-style classes. 
173 | 174 | 175 | -------------------------------------------------------------------------------- /GridSolver/setup.py: -------------------------------------------------------------------------------- 1 | # File : setup.py 2 | 3 | #from distutils.core import setup, Extension 4 | #name of module 5 | #name = "SudokuSolve" 6 | 7 | #version of module 8 | #version = "1.0" 9 | 10 | # specify the name of the extension and source files 11 | # required to compile this 12 | #ext_modules = Extension(name='_SudokuSolve',sources=["sudokuSolve.i","sudokuGen.cpp"]) 13 | 14 | from distutils.core import * 15 | from setuptools import setup, Extension 16 | os.environ["CC"] = "g++" # force compiling c as c++ 17 | setup(name='SudokuSolve', 18 | version='1', 19 | ext_modules=[Extension('_SudokuSolve', sources=['sudokuSolve.i'], 20 | swig_opts=['-c++'], 21 | extra_compile_args=['--std=c++14'] 22 | )], 23 | ) 24 | -------------------------------------------------------------------------------- /GridSolver/sudokuGen.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #define UNASSIGNED 0 11 | 12 | using namespace std; 13 | 14 | class Sudoku { 15 | private: 16 | int grid[9][9]; 17 | int solnGrid[9][9]; 18 | int guessNum[9]; 19 | int gridPos[81]; 20 | int difficultyLevel; 21 | bool grid_status; 22 | 23 | public: 24 | Sudoku (); 25 | Sudoku (string, bool row_major=true); 26 | void createSeed(); 27 | void printGrid(); 28 | bool solveGrid(); 29 | string getGrid(); 30 | void countSoln(int &number); 31 | void genPuzzle(); 32 | bool verifyGridStatus(); 33 | void printSVG(string); 34 | void calculateDifficulty(); 35 | int branchDifficultyScore(); 36 | }; 37 | 38 | // START: Get grid as string in row major order 39 | string Sudoku::getGrid() 40 | { 41 | string s = ""; 42 | for(int row_num=0; row_num<9; ++row_num) 43 | { 44 | for(int col_num=0; col_num<9; ++col_num) 45 | { 46 | s = s + to_string(grid[row_num][col_num]); 47 | } 48 | } 49 | 50 | return s; 51 | } 52 | // END: Get grid as string in row major order 53 | 54 | 55 | // START: Generate random number 56 | int genRandNum(int maxLimit) 57 | { 58 | return rand()%maxLimit; 59 | } 60 | // END: Generate random number 61 | 62 | 63 | // START: Create seed grid 64 | void Sudoku::createSeed() 65 | { 66 | this->solveGrid(); 67 | 68 | // Saving the solution grid 69 | for(int i=0;i<9;i++) 70 | { 71 | for(int j=0;j<9;j++) 72 | { 73 | this->solnGrid[i][j] = this->grid[i][j]; 74 | } 75 | } 76 | } 77 | // END: Create seed grid 78 | 79 | 80 | // START: Intialising 81 | Sudoku::Sudoku() 82 | { 83 | 84 | // initialize difficulty level 85 | this->difficultyLevel = 0; 86 | 87 | // Randomly shuffling the array of removing grid positions 88 | for(int i=0;i<81;i++) 89 | { 90 | this->gridPos[i] = i; 91 | } 92 | 93 | random_shuffle(this->gridPos, (this->gridPos) + 81, genRandNum); 94 | 95 | // Randomly shuffling the guessing number array 96 | for(int i=0;i<9;i++) 97 | { 98 | this->guessNum[i]=i+1; 99 | } 100 | 101 | random_shuffle(this->guessNum, (this->guessNum) + 9, genRandNum); 102 | 103 | // Initialising the grid 104 | for(int i=0;i<9;i++) 105 | { 106 | for(int j=0;j<9;j++) 107 | { 108 | this->grid[i][j]=0; 109 | } 110 | } 111 | 112 | grid_status = true; 113 | } 114 | // END: Initialising 115 | 116 | 117 | // START: Custom Initialising with grid passed as argument 118 | Sudoku::Sudoku(string grid_str, bool row_major) 119 | { 120 | 
if(grid_str.length() != 81) 121 | { 122 | grid_status=false; 123 | return; 124 | } 125 | 126 | // First pass: Check if all cells are valid 127 | for(int i=0; i<81; ++i) 128 | { 129 | int curr_num = grid_str[i]-'0'; 130 | if(!((curr_num == UNASSIGNED) || (curr_num > 0 && curr_num < 10))) 131 | { 132 | grid_status=false; 133 | return; 134 | } 135 | 136 | if(row_major) grid[i/9][i%9] = curr_num; 137 | else grid[i%9][i/9] = curr_num; 138 | } 139 | 140 | // Second pass: Check if all columns are valid 141 | for (int col_num=0; col_num<9; ++col_num) 142 | { 143 | bool nums[10]={false}; 144 | for (int row_num=0; row_num<9; ++row_num) 145 | { 146 | int curr_num = grid[row_num][col_num]; 147 | if(curr_num!=UNASSIGNED && nums[curr_num]==true) 148 | { 149 | grid_status=false; 150 | return; 151 | } 152 | nums[curr_num] = true; 153 | } 154 | } 155 | 156 | // Third pass: Check if all rows are valid 157 | for (int row_num=0; row_num<9; ++row_num) 158 | { 159 | bool nums[10]={false}; 160 | for (int col_num=0; col_num<9; ++col_num) 161 | { 162 | int curr_num = grid[row_num][col_num]; 163 | if(curr_num!=UNASSIGNED && nums[curr_num]==true) 164 | { 165 | grid_status=false; 166 | return; 167 | } 168 | nums[curr_num] = true; 169 | } 170 | } 171 | 172 | // Fourth pass: Check if all blocks are valid 173 | for (int block_num=0; block_num<9; ++block_num) 174 | { 175 | bool nums[10]={false}; 176 | for (int cell_num=0; cell_num<9; ++cell_num) 177 | { 178 | int curr_num = grid[((int)(block_num/3))*3 + (cell_num/3)][((int)(block_num%3))*3 + (cell_num%3)]; 179 | if(curr_num!=UNASSIGNED && nums[curr_num]==true) 180 | { 181 | grid_status=false; 182 | return; 183 | } 184 | nums[curr_num] = true; 185 | } 186 | } 187 | 188 | // Randomly shuffling the guessing number array 189 | for(int i=0;i<9;i++) 190 | { 191 | this->guessNum[i]=i+1; 192 | } 193 | 194 | random_shuffle(this->guessNum, (this->guessNum) + 9, genRandNum); 195 | 196 | grid_status = true; 197 | } 198 | // END: Custom Initialising 199 | 200 | 201 | // START: Verification status of the custom grid passed 202 | bool Sudoku::verifyGridStatus() 203 | { 204 | return grid_status; 205 | } 206 | // END: Verification of the custom grid passed 207 | 208 | 209 | // START: Printing the grid 210 | void Sudoku::printGrid() 211 | { 212 | for(int i=0;i<9;i++) 213 | { 214 | for(int j=0;j<9;j++) 215 | { 216 | if(grid[i][j] == 0) 217 | cout<<"."; 218 | else 219 | cout<difficultyLevel; 226 | cout<grid, row, col)) 297 | return true; // success! 298 | 299 | // Consider digits 1 to 9 300 | for (int num = 0; num < 9; num++) 301 | { 302 | // if looks promising 303 | if (isSafe(this->grid, row, col, this->guessNum[num])) 304 | { 305 | // make tentative assignment 306 | this->grid[row][col] = this->guessNum[num]; 307 | 308 | // return, if success, yay! 
309 | if (solveGrid()) 310 | return true; 311 | 312 | // failure, unmake & try again 313 | this->grid[row][col] = UNASSIGNED; 314 | } 315 | } 316 | 317 | return false; // this triggers backtracking 318 | 319 | } 320 | // END: Modified Sudoku Solver 321 | 322 | 323 | // START: Check if the grid is uniquely solvable 324 | void Sudoku::countSoln(int &number) 325 | { 326 | int row, col; 327 | 328 | if(!FindUnassignedLocation(this->grid, row, col)) 329 | { 330 | number++; 331 | return ; 332 | } 333 | 334 | 335 | for(int i=0;i<9 && number<2;i++) 336 | { 337 | if( isSafe(this->grid, row, col, this->guessNum[i]) ) 338 | { 339 | this->grid[row][col] = this->guessNum[i]; 340 | countSoln(number); 341 | } 342 | 343 | this->grid[row][col] = UNASSIGNED; 344 | } 345 | 346 | } 347 | // END: Check if the grid is uniquely solvable 348 | 349 | 350 | // START: Gneerate puzzle 351 | void Sudoku::genPuzzle() 352 | { 353 | for(int i=0;i<81;i++) 354 | { 355 | int x = (this->gridPos[i])/9; 356 | int y = (this->gridPos[i])%9; 357 | int temp = this->grid[x][y]; 358 | this->grid[x][y] = UNASSIGNED; 359 | 360 | // If now more than 1 solution , replace the removed cell back. 361 | int check=0; 362 | countSoln(check); 363 | if(check!=1) 364 | { 365 | this->grid[x][y] = temp; 366 | } 367 | } 368 | } 369 | // END: Generate puzzle 370 | 371 | 372 | // START: Printing into SVG file 373 | void Sudoku::printSVG(string path="") 374 | { 375 | string fileName = path + "svgHead.txt"; 376 | ifstream file1(fileName.c_str()); 377 | stringstream svgHead; 378 | svgHead << file1.rdbuf(); 379 | 380 | ofstream outFile("puzzle.svg"); 381 | outFile << svgHead.rdbuf(); 382 | 383 | for(int i=0;i<9;i++) 384 | { 385 | for(int j=0;j<9;j++) 386 | { 387 | if(this->grid[i][j]!=0) 388 | { 389 | int x = 50*j + 16; 390 | int y = 50*i + 35; 391 | 392 | stringstream text; 393 | text<<""<grid[i][j]<<"\n"; 394 | 395 | outFile << text.rdbuf(); 396 | } 397 | } 398 | } 399 | 400 | outFile << "Difficulty Level (0 being easiest): " <difficultyLevel<<"\n"; 401 | outFile << ""; 402 | 403 | } 404 | // END: Printing into SVG file 405 | 406 | 407 | // START: Calculate branch difficulty score 408 | int Sudoku::branchDifficultyScore() 409 | { 410 | int emptyPositions = -1; 411 | int tempGrid[9][9]; 412 | int sum=0; 413 | 414 | for(int i=0;i<9;i++) 415 | { 416 | for(int j=0;j<9;j++) 417 | { 418 | tempGrid[i][j] = this->grid[i][j]; 419 | } 420 | } 421 | 422 | while(emptyPositions!=0) 423 | { 424 | vector > empty; 425 | 426 | for(int i=0;i<81;i++) 427 | { 428 | if(tempGrid[(int)(i/9)][(int)(i%9)] == 0) 429 | { 430 | vector temp; 431 | temp.push_back(i); 432 | 433 | for(int num=1;num<=9;num++) 434 | { 435 | if(isSafe(tempGrid,i/9,i%9,num)) 436 | { 437 | temp.push_back(num); 438 | } 439 | } 440 | 441 | empty.push_back(temp); 442 | } 443 | 444 | } 445 | 446 | if(empty.size() == 0) 447 | { 448 | cout<<"Hello: "<solnGrid[rowIndex][colIndex]; 466 | sum = sum + ((branchFactor-2) * (branchFactor-2)) ; 467 | 468 | emptyPositions = empty.size() - 1; 469 | } 470 | 471 | return sum; 472 | 473 | } 474 | // END: Finish branch difficulty score 475 | 476 | 477 | // START: Calculate difficulty level of current grid 478 | void Sudoku::calculateDifficulty() 479 | { 480 | int B = branchDifficultyScore(); 481 | int emptyCells = 0; 482 | 483 | for(int i=0;i<9;i++) 484 | { 485 | for(int j=0;j<9;j++) 486 | { 487 | if(this->grid[i][j] == 0) 488 | emptyCells++; 489 | } 490 | } 491 | 492 | this->difficultyLevel = B*100 + emptyCells; 493 | } 494 | // END: calculating difficulty level 495 | 496 | 
497 | // START: The main function 498 | int main(int argc, char const *argv[]) 499 | { 500 | // Initialising seed for random number generation 501 | srand(time(NULL)); 502 | 503 | // Creating an instance of Sudoku 504 | Sudoku *puzzle = new Sudoku(); 505 | 506 | // Creating a seed for puzzle generation 507 | puzzle->createSeed(); 508 | 509 | // Generating the puzzle 510 | puzzle->genPuzzle(); 511 | 512 | // Calculating difficulty of puzzle 513 | puzzle->calculateDifficulty(); 514 | 515 | // testing by printing the grid 516 | puzzle->printGrid(); 517 | 518 | // Printing the grid into SVG file 519 | string rem = "sudokuGen"; 520 | string path = argv[0]; 521 | path = path.substr(0,path.size() - rem.size()); 522 | puzzle->printSVG(path); 523 | cout<<"The above sudoku puzzle has been stored in puzzles.svg in current folder\n"; 524 | // freeing the memory 525 | delete puzzle; 526 | 527 | return 0; 528 | } 529 | // END: The main function 530 | -------------------------------------------------------------------------------- /GridSolver/sudokuSolve.i: -------------------------------------------------------------------------------- 1 | %module SudokuSolve 2 | %{ 3 | #include "sudokuGen.cpp" 4 | %} 5 | 6 | %include "std_string.i" 7 | %include "sudokuGen.cpp" 8 | -------------------------------------------------------------------------------- /GridSolver/svgHead.txt: -------------------------------------------------------------------------------- 1 | 2 | 6 | 7 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /Pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vaithak/Sudoku-Image-Solver/5d01d6e6c3921251c1f38050aa3e3e85610d66fc/Pipeline.png -------------------------------------------------------------------------------- /Project_Report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vaithak/Sudoku-Image-Solver/5d01d6e6c3921251c1f38050aa3e3e85610d66fc/Project_Report.pdf -------------------------------------------------------------------------------- /PuzzleExtractor/__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Package for extracting the sudoku puzzle from the image passed 3 | Also returns the status, meaning whether grid succesfully extracted. 
4 | 5 | ''' 6 | -------------------------------------------------------------------------------- /PuzzleExtractor/digit_extraction.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def get_square_centers(transformed_img): 6 | lines_X = np.linspace(0, transformed_img.shape[1], num=10, dtype=int) 7 | lines_Y = np.linspace(0, transformed_img.shape[0], num=10, dtype=int) 8 | centers_X = [(lines_X[i] + lines_X[i-1])//2 for i in range(1, len(lines_X))] 9 | centers_Y = [(lines_Y[i] + lines_Y[i-1])//2 for i in range(1, len(lines_Y))] 10 | 11 | return centers_X, centers_Y 12 | 13 | def extract_digit_from_cell(digit): 14 | if(np.sum(digit) < 255*5): 15 | return digit 16 | 17 | contours, _ = cv2.findContours(digit.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 18 | grid_cnt = np.array(sorted(contours, key=lambda x: cv2.contourArea(x), reverse=True)) 19 | mask = np.zeros_like(digit) 20 | cv2.drawContours(mask, grid_cnt, 0, 255, -1) # Draw filled contour in mask 21 | out = np.zeros_like(digit) # Extract out the object and place into output image 22 | out[mask == 255] = digit[mask == 255] 23 | 24 | # Now crop 25 | (y, x) = np.where(mask == 255) 26 | (topy, topx) = (np.min(y), np.min(x)) 27 | (bottomy, bottomx) = (np.max(y), np.max(x)) 28 | out = out[topy:bottomy+1, topx:bottomx+1] 29 | #out = cv2.resize(out, (16,16), interpolation=cv2.INTER_AREA) 30 | 31 | # Now place on top of black image of size same as passed image in center 32 | res = np.zeros_like(digit) 33 | hh, ww = res.shape[0], res.shape[1] 34 | h, w = out.shape[0], out.shape[1] 35 | yoff = round((hh-h)/2) 36 | xoff = round((ww-w)/2) 37 | 38 | # use numpy indexing to place the resized image in the center of background image 39 | res[yoff:yoff+h, xoff:xoff+w] = out 40 | return res 41 | 42 | def centering_se(shape: (int, int), shape_ones: (int, int)): 43 | x = np.zeros(shape) 44 | assert (shape_ones[0] < shape[0]) and (shape_ones[1] < shape[1]) 45 | width = shape_ones[0] 46 | height = shape_ones[1] 47 | 48 | rows, cols = shape 49 | for i in range(width): 50 | for j in range(height): 51 | x[i][j], x[rows-1-i][j], x[i][cols-1-j], x[rows-1-i][cols-1-j] = 1, 1, 1, 1 52 | 53 | return x 54 | 55 | 56 | def recentre(img: np.ndarray, prev_center: (int, int), h_se: np.ndarray, v_se: np.ndarray, h_mov_range: (int, int), v_mov_range: (int, int)) -> (int, int): 57 | # reference: https://web.stanford.edu/class/ee368/Project_Spring_1415/Reports/Wang.pdf 58 | max_res, max_center = 0, prev_center 59 | 60 | for i in range(v_mov_range[0], v_mov_range[1]): 61 | curr_center = (prev_center[0] + 0, prev_center[1] + i) 62 | start_row = max(curr_center[1] - v_se.shape[0]//2, 0) 63 | start_col = max(curr_center[0] - v_se.shape[1]//2, 0) 64 | partial = img[start_row:start_row+v_se.shape[0], start_col:start_col+v_se.shape[1]] 65 | 66 | curr_dot = np.sum(partial*(v_se[0:partial.shape[0], 0:partial.shape[1]])) 67 | # curr_dot = np.sum(img[x1:x1+v_se.shape[0], y1:y1+v_se.shape[1]]*(v_se)) 68 | # print(curr_center, curr_dot) 69 | if max_res <= curr_dot: 70 | max_res = curr_dot 71 | max_center = curr_center 72 | 73 | # # print("max_center after v_se: ", max_center) 74 | prev_center = max_center 75 | max_res = 0 76 | for i in range(h_mov_range[0], h_mov_range[1]): 77 | curr_center = (prev_center[0] + i, prev_center[1] + 0) 78 | start_row = max(curr_center[1] - h_se.shape[0]//2, 0) 79 | start_col = max(curr_center[0] - h_se.shape[1]//2, 0) 80 | partial = img[start_row:start_row+h_se.shape[0], 
start_col:start_col+h_se.shape[1]] 81 | 82 | curr_dot = np.sum(partial*(h_se[0:partial.shape[0], 0:partial.shape[1]])) 83 | # print(curr_center, curr_dot) 84 | if max_res <= curr_dot: 85 | max_res = curr_dot 86 | max_center = curr_center 87 | 88 | # print("max_center after h_se: ", max_center) 89 | return max_center 90 | 91 | 92 | def preprocess_digit(digit_img): 93 | # remove possible edges from border 94 | digit_img[0:3,:] = 0 95 | digit_img[:,0:3] = 0 96 | digit_img[-3:,:] = 0 97 | digit_img[:,-3:] = 0 98 | 99 | # dilating and eroding the digit 100 | if(np.sum(digit_img) < 255*30): 101 | return np.zeros_like(digit_img) 102 | 103 | return digit_img 104 | 105 | 106 | def extractDigits(transformed_img): 107 | centers_X, centers_Y = get_square_centers(transformed_img) 108 | centers = [(centers_X[i], centers_Y[j]) for i in range(len(centers_X)) for j in range(len(centers_Y))] 109 | kernel_shape = (centers_X[1] - centers_X[0], centers_Y[1] - centers_Y[0]) 110 | 111 | ones_length = (kernel_shape[0]+kernel_shape[1])//20 112 | v_se = centering_se(kernel_shape, (2,ones_length)) 113 | h_se = centering_se(kernel_shape, (ones_length,2)) 114 | new_centers = [] 115 | for i in range(len(centers)): 116 | v_mov_range, h_mov_range = (-kernel_shape[0]//8, kernel_shape[0]//8), (-kernel_shape[1]//8, kernel_shape[1]//8) 117 | if (i<9) : h_mov_range = (-kernel_shape[1]//32, kernel_shape[1]//8) 118 | elif (i>71) : h_mov_range = (-kernel_shape[1]//8, kernel_shape[1]//32) 119 | if (i%9 == 0) : v_mov_range = (-kernel_shape[0]//32, kernel_shape[0]//8) 120 | elif ((i+1)%9 == 0) : v_mov_range = (-kernel_shape[0]//8, kernel_shape[0]//32) 121 | new_centers.append(recentre(transformed_img, centers[i], h_se, v_se, h_mov_range, v_mov_range)) 122 | 123 | digits = [] 124 | for center in new_centers: 125 | top_l = [center[0]-kernel_shape[1]//2, center[1]-kernel_shape[0]//2] 126 | top_r = [center[0]+kernel_shape[1]//2, center[1]-kernel_shape[0]//2] 127 | bottom_l = [center[0]-kernel_shape[1]//2, center[1]+kernel_shape[0]//2] 128 | bottom_r = [center[0]+kernel_shape[1]//2, center[1]+kernel_shape[0]//2] 129 | 130 | M = cv2.getPerspectiveTransform(np.float32([top_l, top_r, bottom_l, bottom_r]), np.float32([[0,0], [28,0], [0,28], [28,28]])) 131 | dst = cv2.warpPerspective(transformed_img,M,(28,28)) 132 | dst = dst.astype('uint8') 133 | dst_mod = preprocess_digit(dst) 134 | dst_mod = extract_digit_from_cell(dst_mod) 135 | digits.append(dst_mod) 136 | 137 | return digits -------------------------------------------------------------------------------- /PuzzleExtractor/grid_extraction.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import operator 4 | 5 | # Note: Pass processed images only 6 | def find_largest_contour(img: np.ndarray) -> (bool, np.ndarray): 7 | # find contours in the edged image, keep only the largest 8 | contours, hierarchy = cv2.findContours(img.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 9 | grid_cnt = np.array(sorted(contours, key=lambda x: cv2.contourArea(x), reverse=True)) 10 | status, main_contour = False, np.array([]) 11 | if len(grid_cnt) != 0: 12 | status, main_contour = True, grid_cnt[0] 13 | 14 | return (status, main_contour) 15 | 16 | 17 | def perspective_transform(image: np.ndarray, corners: np.ndarray) -> np.ndarray: 18 | # Reference: https://stackoverflow.com/questions/57636399/how-to-detect-sudoku-grid-board-in-opencv 19 | 20 | def order_corner_points(corners): 21 | # Bottom-right point has the largest (x + y) 
value 22 | # Top-left has point smallest (x + y) value 23 | # Bottom-left point has smallest (x - y) value 24 | # Top-right point has largest (x - y) value 25 | bottom_r, _ = max(enumerate([pt[0][0] + pt[0][1] for pt in corners]), key=operator.itemgetter(1)) 26 | top_l = (bottom_r + 2)%4 27 | left_corners = [corners[i] for i in range(len(corners)) if((i!=bottom_r) and (i!=top_l))] 28 | bottom_l, _ = min(enumerate([pt[0][0] - pt[0][1] for pt in left_corners]), key=operator.itemgetter(1)) 29 | top_r = (bottom_l + 1)%2 30 | 31 | return (corners[top_l][0], left_corners[top_r][0], corners[bottom_r][0], left_corners[bottom_l][0]) 32 | 33 | # Order points in clockwise order 34 | ordered_corners = order_corner_points(corners) 35 | top_l, top_r, bottom_r, bottom_l = ordered_corners 36 | 37 | # Determine width of new image which is the max distance between 38 | # (bottom right and bottom left) or (top right and top left) x-coordinates 39 | width_A = np.sqrt(((bottom_r[0] - bottom_l[0]) ** 2) + ((bottom_r[1] - bottom_l[1]) ** 2)) 40 | width_B = np.sqrt(((top_r[0] - top_l[0]) ** 2) + ((top_r[1] - top_l[1]) ** 2)) 41 | width = max(int(width_A), int(width_B)) 42 | 43 | # Determine height of new image which is the max distance between 44 | # (top right and bottom right) or (top left and bottom left) y-coordinates 45 | height_A = np.sqrt(((top_r[0] - bottom_r[0]) ** 2) + ((top_r[1] - bottom_r[1]) ** 2)) 46 | height_B = np.sqrt(((top_l[0] - bottom_l[0]) ** 2) + ((top_l[1] - bottom_l[1]) ** 2)) 47 | height = max(int(height_A), int(height_B)) 48 | 49 | # Construct new points to obtain top-down view of image in 50 | # top_r, top_l, bottom_l, bottom_r order 51 | dimensions = np.array([[0, 0], [width - 1, 0], [width - 1, height - 1], 52 | [0, height - 1]], dtype = "float32") 53 | 54 | # Convert to Numpy format 55 | ordered_corners = np.array(ordered_corners, dtype="float32") 56 | 57 | # Find perspective transform matrix 58 | matrix = cv2.getPerspectiveTransform(ordered_corners, dimensions) 59 | 60 | # Return the transformed image 61 | return cv2.warpPerspective(image, matrix, (width, height)) 62 | 63 | 64 | # Main function for extracting the grid from the image 65 | def extractGrid(processed_img: np.ndarray) -> (bool, np.ndarray): 66 | status, main_contour = find_largest_contour(processed_img) 67 | if status == False: 68 | return (status, main_contour) 69 | 70 | peri = cv2.arcLength(main_contour, True) 71 | approx = cv2.approxPolyDP(main_contour, 0.01 * peri, True) 72 | transformed_processed = perspective_transform(processed_img, approx[0:4]) 73 | # For debugging 74 | # print(approx) 75 | # transformed_original = perspective_transform(orig_img, approx) 76 | # print(approx.shape, approx[0], approx[1], approx[2], approx[3]) 77 | 78 | return (True, transformed_processed) 79 | -------------------------------------------------------------------------------- /PuzzleExtractor/processing.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | def basic_preprocessing(img: np.ndarray) -> np.ndarray: 5 | # create a CLAHE object for Histogram equalisation and improvng the contrast. 
6 | img_plt = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 7 | clahe = cv2.createCLAHE(clipLimit=0.8, tileGridSize=(8,8)) 8 | enhanced = clahe.apply(img_plt) 9 | 10 | # Edge preserving smoother: 11 | # https://dsp.stackexchange.com/questions/60916/what-is-the-bilateral-filter-category-lpf-hpf-bpf-or-bsf 12 | x, y = max(img.shape[0]//200, 5), max(img.shape[1]//200, 5) 13 | blurred = cv2.GaussianBlur(enhanced, (x+(x+1)%2, y+(y+1)%2), 0) 14 | blurred = cv2.bilateralFilter(blurred,7,75,75) 15 | return blurred 16 | 17 | # requires a grayscale image as input 18 | def to_binary(img: np.ndarray) -> np.ndarray: 19 | # opening for clearing some noise 20 | se = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(3,3)) 21 | opened = cv2.morphologyEx(img, cv2.MORPH_OPEN, se) 22 | 23 | thresholded_img = cv2.adaptiveThreshold(opened, 255, cv2.ADAPTIVE_THRESH_MEAN_C,cv2.THRESH_BINARY,5,2) 24 | inverted = cv2.bitwise_not(thresholded_img) 25 | 26 | if(img.shape[0] > 1000 and img.shape[1] > 1000): 27 | se = np.ones((2,2)) 28 | eroded = cv2.erode(inverted, se, iterations=1) 29 | else: 30 | se = np.ones((2,2)) 31 | eroded = cv2.erode(inverted, se, iterations=1) 32 | 33 | return eroded 34 | 35 | 36 | def processImage(img: np.ndarray) -> np.ndarray: 37 | preprocessed = basic_preprocessing(img) 38 | binary = to_binary(preprocessed) 39 | 40 | if(img.shape[0] > 1000 and img.shape[1] > 1000): 41 | kernel = np.ones((3,3)) 42 | dilated = cv2.dilate(binary, kernel, iterations=3) 43 | eroded = cv2.erode(dilated, kernel, iterations=3) 44 | else: 45 | eroded = binary 46 | 47 | return eroded 48 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sudoku Image Solver 2 | A tool for solving sudoku puzzles by sending image of the puzzle as input. 3 | 4 | [Video Demo](https://youtu.be/zLT7nHLe0bs) 5 | 6 | ### Complete pipeline 7 | ![Complete Pipeline](Pipeline.png) 8 | 9 |
10 | 11 | ### Run locally 12 | Dependencies: Python 3.6+, gcc, SWIG. 13 | **Recommended:** Run the instructions below inside a Python virtual environment. 14 | 15 | 1) After cloning the repo, run `pip install -r requirements.txt` in the cloned folder. 16 | 2) Then, in the GridSolver folder, run `python setup.py build_ext --inplace`. 17 | 3) Now you can start the app from the main folder by running `streamlit run app.py`. 18 |
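For orientation, here is a minimal sketch of how the modules in this repo chain together on one image. The Streamlit glue in `app.py` is not reproduced in this dump, so the wiring below is illustrative, not the app's actual code: it only calls functions defined in `PuzzleExtractor`, `DigitsRecogniser`, and `GridSolver`, and it assumes the SWIG extension has been built (step 2 above) and that the script is run from the repository root so the relative model paths resolve.

```python
# Illustrative end-to-end sketch (not app.py itself): wire the pipeline stages together.
import cv2
from PuzzleExtractor.processing import processImage
from PuzzleExtractor.grid_extraction import extractGrid
from PuzzleExtractor.digit_extraction import extractDigits
from DigitsRecogniser.recognise_digits import predictDigits
from GridSolver.SudokuSolve import Sudoku

img = cv2.imread("ExampleImages/image1.jpg")   # photo of the puzzle (BGR)
processed = processImage(img)                  # grayscale, CLAHE, blur, adaptive threshold
ok, grid_img = extractGrid(processed)          # largest contour + perspective transform
if ok:
    digits = extractDigits(grid_img)           # 81 cell images, column-major order
    grid_str = predictDigits(digits, 0)        # 81-char string; model 0 = CNN, '0' = empty cell
    puzzle = Sudoku(grid_str, True)            # grid string interpreted row-major
    if puzzle.verifyGridStatus() and puzzle.solveGrid():
        print(puzzle.getGrid())                # solved grid as a row-major 81-char string
```

Note that `predictDigits` reloads the chosen model from `DigitsRecogniser/models/` for each cell, so the CNN path (model 0) is the slowest but most accurate of the bundled classifiers.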
19 | 20 | **If you can improve any part in the pipeline or fix any bug in the codebase, please make a PR** :smile: 21 | -------------------------------------------------------------------------------- /Testing/CNN_DigitClassify.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "CNN-DigitClassify.ipynb", 7 | "provenance": [] 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | }, 13 | "accelerator": "GPU" 14 | }, 15 | "cells": [ 16 | { 17 | "cell_type": "code", 18 | "metadata": { 19 | "id": "ytlgn0B7BRne", 20 | "colab_type": "code", 21 | "colab": {} 22 | }, 23 | "source": [ 24 | "# Simple CNN for the MNIST Dataset\n", 25 | "from keras.datasets import mnist\n", 26 | "from keras.models import Sequential\n", 27 | "from keras.layers import Dense\n", 28 | "from keras.layers import Dropout\n", 29 | "from keras.layers import Flatten\n", 30 | "from keras.layers.convolutional import Conv2D\n", 31 | "from keras.layers.convolutional import MaxPooling2D\n", 32 | "from keras.utils import np_utils\n", 33 | "import numpy as np" 34 | ], 35 | "execution_count": 0, 36 | "outputs": [] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "metadata": { 41 | "id": "vMJDjWFXCaaM", 42 | "colab_type": "code", 43 | "colab": {} 44 | }, 45 | "source": [ 46 | "from scipy.ndimage.interpolation import shift\n", 47 | "\n", 48 | "def shift_image(image, dx, dy):\n", 49 | " shifted_image = shift(image, [dy, dx], cval=0, mode=\"constant\")\n", 50 | " return shifted_image" 51 | ], 52 | "execution_count": 0, 53 | "outputs": [] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "metadata": { 58 | "id": "k-fhs1-_BipO", 59 | "colab_type": "code", 60 | "outputId": "2a0b2144-8362-4804-c911-8540ca35d64b", 61 | "colab": { 62 | "base_uri": "https://localhost:8080/", 63 | "height": 34 64 | } 65 | }, 66 | "source": [ 67 | "# load data\n", 68 | "(X_raw_train, y_train), (X_raw_test, y_test) = mnist.load_data()\n", 69 | "print(\"Creating Augmented Dataset...\")\n", 70 | "X_raw_train_augmented = [image for image in X_raw_train]\n", 71 | "y_train_augmented = [label for label in y_train]\n", 72 | "\n", 73 | "for dx, dy in ((1,0), (-1,0), (0,1), (0,-1), (1,1), (-1,1), (-1,-1), (1,-1)):\n", 74 | " for image, label in zip(X_raw_train, y_train):\n", 75 | " X_raw_train_augmented.append(shift_image(image, dx, dy))\n", 76 | " y_train_augmented.append(label)\n", 77 | "\n", 78 | "X_raw_train = np.array(X_raw_train_augmented, dtype=np.uint8)\n", 79 | "y_train_augmented = np.array(y_train_augmented, dtype=np.uint8)" 80 | ], 81 | "execution_count": 0, 82 | "outputs": [ 83 | { 84 | "output_type": "stream", 85 | "text": [ 86 | "Creating Augmented Dataset...\n" 87 | ], 88 | "name": "stdout" 89 | } 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "metadata": { 95 | "id": "WOT4VdZ7Jl0B", 96 | "colab_type": "code", 97 | "outputId": "69531a12-e70a-42c9-c25b-f6b7d10b843b", 98 | "colab": { 99 | "base_uri": "https://localhost:8080/", 100 | "height": 85 101 | } 102 | }, 103 | "source": [ 104 | "print(\"Adding hand labelled data for sudoku grid\")\n", 105 | "print(X_raw_train.shape, y_train_augmented.shape)\n", 106 | "X_raw_train_2 = np.loadtxt(\"X_train.csv\", dtype=np.uint8, delimiter=' ')\n", 107 | "y_train_2 = np.loadtxt(\"y_train.csv\", dtype=np.uint8, delimiter=' ')\n", 108 | "print(X_raw_train_2.shape, y_train_2.shape)\n", 109 | "\n", 110 | "X_raw_train_2 = np.reshape(X_raw_train_2, (-1, 28, 28))\n", 
111 | "X_raw_train_2 = 255 - X_raw_train_2 # inverting as hand labelled data set has number is white and back in black\n", 112 | "\n", 113 | "print(\"Creating Augmented Dataset...\")\n", 114 | "X_raw_train_augmented_2 = [image for image in X_raw_train_2]\n", 115 | "y_train_augmented_2 = [label for label in y_train_2]\n", 116 | "\n", 117 | "for dx, dy in ((1,0), (-1,0), (0,1), (0,-1), (1,1), (-1,1), (-1,-1), (1,-1)):\n", 118 | " for image, label in zip(X_raw_train_2, y_train_2):\n", 119 | " X_raw_train_augmented_2.append(shift_image(image, dx, dy))\n", 120 | " y_train_augmented_2.append(label)\n", 121 | "\n", 122 | "X_raw_train_augmented_2 = np.array(X_raw_train_augmented_2, dtype=np.uint8)\n", 123 | "y_train_augmented_2 = np.array(y_train_augmented_2, dtype=np.uint8)\n", 124 | "\n", 125 | "X_raw_train = np.append(X_raw_train, X_raw_train_augmented_2, axis=0)\n", 126 | "y_train_augmented = np.append(y_train_augmented, y_train_augmented_2)" 127 | ], 128 | "execution_count": 0, 129 | "outputs": [ 130 | { 131 | "output_type": "stream", 132 | "text": [ 133 | "Adding hand labelled data for sudoku grid\n", 134 | "(540000, 28, 28) (540000,)\n", 135 | "(313, 784) (313,)\n", 136 | "Creating Augmented Dataset...\n" 137 | ], 138 | "name": "stdout" 139 | } 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "metadata": { 145 | "id": "bXB-BPqDGJYt", 146 | "colab_type": "code", 147 | "colab": {} 148 | }, 149 | "source": [ 150 | "# reshape to be [samples][width][height][channels]\n", 151 | "X_train = X_raw_train.reshape((X_raw_train.shape[0], 28, 28, 1)).astype('float32')\n", 152 | "X_test = X_raw_test.reshape((X_raw_test.shape[0], 28, 28, 1)).astype('float32')\n", 153 | "\n", 154 | "# normalize inputs from 0-255 to 0-1\n", 155 | "X_train = X_train / 255\n", 156 | "X_test = X_test / 255\n", 157 | "\n", 158 | "# one hot encode outputs\n", 159 | "y_train = np_utils.to_categorical(y_train_augmented)\n", 160 | "y_test = np_utils.to_categorical(y_test)\n", 161 | "num_classes = y_test.shape[1]" 162 | ], 163 | "execution_count": 0, 164 | "outputs": [] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "metadata": { 169 | "id": "sRzetpDBBf4s", 170 | "colab_type": "code", 171 | "colab": {} 172 | }, 173 | "source": [ 174 | "# define a simple CNN model\n", 175 | "def baseline_model():\n", 176 | "\t# create model\n", 177 | "\tmodel = Sequential()\n", 178 | "\tmodel.add(Conv2D(32, (5, 5), input_shape=(28, 28, 1), activation='relu'))\n", 179 | "\tmodel.add(MaxPooling2D())\n", 180 | "\tmodel.add(Dropout(0.2))\n", 181 | "\tmodel.add(Flatten())\n", 182 | "\tmodel.add(Dense(128, activation='relu'))\n", 183 | "\tmodel.add(Dense(num_classes, activation='softmax'))\n", 184 | "\t\n", 185 | " # Compile model\n", 186 | "\tmodel.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n", 187 | "\treturn model" 188 | ], 189 | "execution_count": 0, 190 | "outputs": [] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "metadata": { 195 | "id": "6XEH-R6EBmOE", 196 | "colab_type": "code", 197 | "outputId": "c93178c4-34aa-41ce-bd16-e45d4c9b181f", 198 | "colab": { 199 | "base_uri": "https://localhost:8080/", 200 | "height": 357 201 | } 202 | }, 203 | "source": [ 204 | "# build the model\n", 205 | "model_simple = baseline_model()\n", 206 | "model_simple.summary()" 207 | ], 208 | "execution_count": 0, 209 | "outputs": [ 210 | { 211 | "output_type": "stream", 212 | "text": [ 213 | "Model: \"sequential_10\"\n", 214 | "_________________________________________________________________\n", 215 | "Layer (type) 
Output Shape Param # \n", 216 | "=================================================================\n", 217 | "conv2d_10 (Conv2D) (None, 24, 24, 32) 832 \n", 218 | "_________________________________________________________________\n", 219 | "max_pooling2d_10 (MaxPooling (None, 12, 12, 32) 0 \n", 220 | "_________________________________________________________________\n", 221 | "dropout_10 (Dropout) (None, 12, 12, 32) 0 \n", 222 | "_________________________________________________________________\n", 223 | "flatten_10 (Flatten) (None, 4608) 0 \n", 224 | "_________________________________________________________________\n", 225 | "dense_19 (Dense) (None, 128) 589952 \n", 226 | "_________________________________________________________________\n", 227 | "dense_20 (Dense) (None, 10) 1290 \n", 228 | "=================================================================\n", 229 | "Total params: 592,074\n", 230 | "Trainable params: 592,074\n", 231 | "Non-trainable params: 0\n", 232 | "_________________________________________________________________\n" 233 | ], 234 | "name": "stdout" 235 | } 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "metadata": { 241 | "id": "rmXSph8guGwg", 242 | "colab_type": "code", 243 | "outputId": "22c63cf6-18c9-43ab-a575-1570e2b2eb8a", 244 | "colab": { 245 | "base_uri": "https://localhost:8080/", 246 | "height": 428 247 | } 248 | }, 249 | "source": [ 250 | "model_simple.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=10, batch_size=200)\n", 251 | "\n", 252 | "# Final evaluation of the model\n", 253 | "scores = model_simple.evaluate(X_test, y_test, verbose=1)\n", 254 | "\n", 255 | "print(\"CNN Error: %.2f%%\" % (100-scores[1]*100))" 256 | ], 257 | "execution_count": 0, 258 | "outputs": [ 259 | { 260 | "output_type": "stream", 261 | "text": [ 262 | "Train on 542817 samples, validate on 10000 samples\n", 263 | "Epoch 1/10\n", 264 | "542817/542817 [==============================] - 14s 26us/step - loss: 0.0110 - accuracy: 0.9963 - val_loss: 0.0214 - val_accuracy: 0.9941\n", 265 | "Epoch 2/10\n", 266 | "542817/542817 [==============================] - 14s 25us/step - loss: 0.0082 - accuracy: 0.9972 - val_loss: 0.0193 - val_accuracy: 0.9942\n", 267 | "Epoch 3/10\n", 268 | "542817/542817 [==============================] - 14s 25us/step - loss: 0.0066 - accuracy: 0.9977 - val_loss: 0.0218 - val_accuracy: 0.9940\n", 269 | "Epoch 4/10\n", 270 | "542817/542817 [==============================] - 14s 25us/step - loss: 0.0055 - accuracy: 0.9981 - val_loss: 0.0277 - val_accuracy: 0.9929\n", 271 | "Epoch 5/10\n", 272 | "542817/542817 [==============================] - 14s 25us/step - loss: 0.0050 - accuracy: 0.9982 - val_loss: 0.0237 - val_accuracy: 0.9935\n", 273 | "Epoch 6/10\n", 274 | "542817/542817 [==============================] - 14s 25us/step - loss: 0.0041 - accuracy: 0.9986 - val_loss: 0.0277 - val_accuracy: 0.9930\n", 275 | "Epoch 7/10\n", 276 | "542817/542817 [==============================] - 14s 25us/step - loss: 0.0038 - accuracy: 0.9987 - val_loss: 0.0278 - val_accuracy: 0.9935\n", 277 | "Epoch 8/10\n", 278 | "542817/542817 [==============================] - 14s 26us/step - loss: 0.0037 - accuracy: 0.9988 - val_loss: 0.0307 - val_accuracy: 0.9928\n", 279 | "Epoch 9/10\n", 280 | "542817/542817 [==============================] - 14s 25us/step - loss: 0.0033 - accuracy: 0.9989 - val_loss: 0.0276 - val_accuracy: 0.9938\n", 281 | "Epoch 10/10\n", 282 | "542817/542817 [==============================] - 14s 25us/step - loss: 0.0032 - accuracy: 0.9989 - 
val_loss: 0.0271 - val_accuracy: 0.9941\n", 283 | "10000/10000 [==============================] - 1s 58us/step\n", 284 | "CNN Error: 0.59%\n" 285 | ], 286 | "name": "stdout" 287 | } 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "metadata": { 293 | "id": "JOyTpcKMwu-8", 294 | "colab_type": "code", 295 | "colab": {} 296 | }, 297 | "source": [ 298 | "def preprocess_for_CNN(digit: np.ndarray):\n", 299 | " digit = 255 - digit\n", 300 | " digit = digit/255\n", 301 | " digit = digit.reshape((1, 28, 28, 1))\n", 302 | " return digit" 303 | ], 304 | "execution_count": 0, 305 | "outputs": [] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "metadata": { 310 | "id": "6Tr8HdTfKgKl", 311 | "colab_type": "code", 312 | "outputId": "bf437b36-9193-465e-c2f7-b5d071703ae1", 313 | "colab": { 314 | "base_uri": "https://localhost:8080/", 315 | "height": 34 316 | } 317 | }, 318 | "source": [ 319 | "X_small_test = np.zeros((X_raw_train_2.shape[0], 28, 28, 1))\n", 320 | "for i in range(X_raw_train_2.shape[0]):\n", 321 | " X_small_test[i] = preprocess_for_CNN(X_raw_train_2[i])\n", 322 | "\n", 323 | "# processed_digit = preprocess_for_CNN(X_raw_train_2)\n", 324 | "print(X_small_test.shape)" 325 | ], 326 | "execution_count": 0, 327 | "outputs": [ 328 | { 329 | "output_type": "stream", 330 | "text": [ 331 | "(313, 28, 28, 1)\n" 332 | ], 333 | "name": "stdout" 334 | } 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "metadata": { 340 | "id": "P0baS2JqKuSU", 341 | "colab_type": "code", 342 | "colab": {} 343 | }, 344 | "source": [ 345 | "probabilities_mat = model_simple.predict(X_small_test)\n", 346 | "y_pred = np.zeros((probabilities_mat.shape[0], ), dtype=np.uint8)\n", 347 | "for i in range(len(probabilities_mat)):\n", 348 | " y_pred[i] = np.argmax(probabilities_mat[i])" 349 | ], 350 | "execution_count": 0, 351 | "outputs": [] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "metadata": { 356 | "id": "ah8lJT4iNK3B", 357 | "colab_type": "code", 358 | "outputId": "ea505109-16e2-4d00-e987-b7d4d07ab897", 359 | "colab": { 360 | "base_uri": "https://localhost:8080/", 361 | "height": 68 362 | } 363 | }, 364 | "source": [ 365 | "from sklearn.metrics import accuracy_score\n", 366 | "print(y_pred[0:15])\n", 367 | "print(y_train_2[0:15])\n", 368 | "print(accuracy_score(y_pred, y_train_2))" 369 | ], 370 | "execution_count": 0, 371 | "outputs": [ 372 | { 373 | "output_type": "stream", 374 | "text": [ 375 | "[1 8 9 3 4 2 5 3 9 6 4 3 7 1 9]\n", 376 | "[1 8 9 3 4 2 5 3 9 6 4 3 7 1 9]\n", 377 | "0.9648562300319489\n" 378 | ], 379 | "name": "stdout" 380 | } 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "metadata": { 386 | "id": "fbGHluH8NRbk", 387 | "colab_type": "code", 388 | "colab": {} 389 | }, 390 | "source": [ 391 | "model_simple.save('my_model.h5') # creates a HDF5 file 'my_model.h5'" 392 | ], 393 | "execution_count": 0, 394 | "outputs": [] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "metadata": { 399 | "id": "4XoapmTdQrQX", 400 | "colab_type": "code", 401 | "colab": {} 402 | }, 403 | "source": [ 404 | "import tensorflow" 405 | ], 406 | "execution_count": 0, 407 | "outputs": [] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "metadata": { 412 | "id": "lkZ-9qltCu6i", 413 | "colab_type": "code", 414 | "outputId": "5e1b2869-2973-4235-e931-a7735c5e4cfe", 415 | "colab": { 416 | "base_uri": "https://localhost:8080/", 417 | "height": 34 418 | } 419 | }, 420 | "source": [ 421 | "tensorflow.__version__" 422 | ], 423 | "execution_count": 0, 424 | "outputs": [ 425 | { 426 | "output_type": 
"execute_result", 427 | "data": { 428 | "text/plain": [ 429 | "'2.2.0'" 430 | ] 431 | }, 432 | "metadata": { 433 | "tags": [] 434 | }, 435 | "execution_count": 2 436 | } 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "metadata": { 442 | "id": "TmeuA8qyCz60", 443 | "colab_type": "code", 444 | "outputId": "48bfbed5-8632-45f5-c857-b7b5b057d568", 445 | "colab": { 446 | "base_uri": "https://localhost:8080/", 447 | "height": 51 448 | } 449 | }, 450 | "source": [ 451 | "import keras\n", 452 | "keras.__version__" 453 | ], 454 | "execution_count": 0, 455 | "outputs": [ 456 | { 457 | "output_type": "stream", 458 | "text": [ 459 | "Using TensorFlow backend.\n" 460 | ], 461 | "name": "stderr" 462 | }, 463 | { 464 | "output_type": "execute_result", 465 | "data": { 466 | "text/plain": [ 467 | "'2.3.1'" 468 | ] 469 | }, 470 | "metadata": { 471 | "tags": [] 472 | }, 473 | "execution_count": 4 474 | } 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "metadata": { 480 | "id": "G4X5eSrxC3BJ", 481 | "colab_type": "code", 482 | "colab": {} 483 | }, 484 | "source": [ 485 | "" 486 | ], 487 | "execution_count": 0, 488 | "outputs": [] 489 | } 490 | ] 491 | } -------------------------------------------------------------------------------- /Testing/DigitsClassifyTesting.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "DigitsClassifyTesting.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "toc_visible": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "code", 19 | "metadata": { 20 | "id": "nPWkwlMXmXn8", 21 | "colab_type": "code", 22 | "colab": {} 23 | }, 24 | "source": [ 25 | "import pandas as pd\n", 26 | "import numpy as np\n", 27 | "import cv2\n", 28 | "from tensorflow import keras\n", 29 | "from sklearn.metrics import accuracy_score\n", 30 | "from sklearn.pipeline import make_pipeline\n", 31 | "from sklearn.preprocessing import StandardScaler\n", 32 | "import joblib " 33 | ], 34 | "execution_count": 0, 35 | "outputs": [] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "metadata": { 40 | "id": "Bxr6wcY6wV_C", 41 | "colab_type": "code", 42 | "colab": {} 43 | }, 44 | "source": [ 45 | "def load_data():\n", 46 | " # df_train = pd.read_csv(train_data_file)\n", 47 | " # df_test = pd.read_csv(test_data_file) \n", 48 | " # train_features = df_train.iloc[:,1:]\n", 49 | " # train_labels = df_train.iloc[:,0]\n", 50 | " # train_features = np.array(train_features).astype(np.uint8)\n", 51 | " # test_features = np.array(test).astype(np.uint8) \n", 52 | " # return train_features, train_labels, test_features\n", 53 | " dataset = keras.datasets.mnist\n", 54 | " (X_raw_train, y_train), (X_raw_test, y_test) = dataset.load_data()\n", 55 | " return (X_raw_train, y_train, X_raw_test, y_test)" 56 | ], 57 | "execution_count": 0, 58 | "outputs": [] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "metadata": { 63 | "id": "DVof1SGCbk1f", 64 | "colab_type": "code", 65 | "colab": {} 66 | }, 67 | "source": [ 68 | "from scipy.ndimage.interpolation import shift\n", 69 | "import matplotlib.pyplot as plt\n", 70 | "\n", 71 | "def shift_image(image, dx, dy):\n", 72 | " shifted_image = shift(image, [dy, dx], cval=0, mode=\"constant\")\n", 73 | " return shifted_image" 74 | ], 75 | "execution_count": 0, 76 | "outputs": [] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "metadata": { 81 | "id": 
"lLAfrbfbjBZo", 82 | "colab_type": "code", 83 | "colab": {} 84 | }, 85 | "source": [ 86 | "def preprocess(X_raw_train, X_raw_test):\n", 87 | " X_train = np.zeros((X_raw_train.shape[0], 108))\n", 88 | " X_test = np.zeros((X_raw_test.shape[0], 108))\n", 89 | "\n", 90 | " hog = cv2.HOGDescriptor((28, 28), (14, 14), (7, 7), (14, 14), 12)\n", 91 | "\n", 92 | " for n in range(len(X_raw_train)):\n", 93 | " X_train[n] = hog.compute(X_raw_train[n]).reshape(1, -1)\n", 94 | " \n", 95 | " for n in range(len(X_raw_test)):\n", 96 | " X_test[n] = hog.compute(X_raw_test[n]).reshape(1, -1)\n", 97 | "\n", 98 | " return X_train, X_test" 99 | ], 100 | "execution_count": 0, 101 | "outputs": [] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "metadata": { 106 | "id": "ntQ5UgaLE9Ap", 107 | "colab_type": "code", 108 | "colab": {} 109 | }, 110 | "source": [ 111 | "# Classify the digits in X_test using the passed classifier 'clf'\n", 112 | "def classify_Digits(clf, X_test):\n", 113 | " # Returns the probability vector of length 10 for each input in X_raw_test\n", 114 | " return clf.predict_proba(X_test)" 115 | ], 116 | "execution_count": 0, 117 | "outputs": [] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "metadata": { 122 | "id": "r9lL6BewFPOa", 123 | "colab_type": "code", 124 | "colab": {} 125 | }, 126 | "source": [ 127 | "# Take decision based on probability vector of each class\n", 128 | "# cost_r: Cost of rejection (In our case we will mark it as empty or no digit = 0 in Sudoku)\n", 129 | "# cost_w: Cost of wrong classification (In our case we will mark it as empty or no digit = 0 in Sudoku)\n", 130 | "def take_decision(probabilities, cost_r=10, cost_w=20):\n", 131 | " assert cost_w != 0\n", 132 | "\n", 133 | " # Reference: https://www.cs.ubc.ca/~murphyk/Teaching/CS340-Fall07/dtheory.pdf\n", 134 | " pred_class = np.argmax(probabilities)\n", 135 | " if(probabilities[pred_class] > (1 - (cost_r/cost_w))):\n", 136 | " return pred_class\n", 137 | "\n", 138 | " # reject => No digit => 0 for our case\n", 139 | " return 0" 140 | ], 141 | "execution_count": 0, 142 | "outputs": [] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "metadata": { 147 | "id": "5HRAsw3qwoKC", 148 | "colab_type": "code", 149 | "outputId": "a234208b-6ba7-43aa-e6e1-b8cad49320be", 150 | "colab": { 151 | "base_uri": "https://localhost:8080/", 152 | "height": 68 153 | } 154 | }, 155 | "source": [ 156 | "X_raw_train, y_train, X_raw_test, y_test = load_data()\n", 157 | "\n", 158 | "print(\"Creating Augmented Dataset...\")\n", 159 | "X_raw_train_augmented = [image for image in X_raw_train]\n", 160 | "y_train_augmented = [image for image in y_train]\n", 161 | "\n", 162 | "for dx, dy in ((1,0), (-1,0), (0,1), (0,-1), (1,1), (-1,1), (-1,-1), (1,-1)):\n", 163 | " for image, label in zip(X_raw_train, y_train):\n", 164 | " X_raw_train_augmented.append(shift_image(image, dx, dy))\n", 165 | " y_train_augmented.append(label)" 166 | ], 167 | "execution_count": 7, 168 | "outputs": [ 169 | { 170 | "output_type": "stream", 171 | "text": [ 172 | "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n", 173 | "11493376/11490434 [==============================] - 0s 0us/step\n", 174 | "Creating Augmented Dataset...\n" 175 | ], 176 | "name": "stdout" 177 | } 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "metadata": { 183 | "id": "Gu6hozsS7IWt", 184 | "colab_type": "code", 185 | "colab": {} 186 | }, 187 | "source": [ 188 | "X_raw_train_augmented, y_train_augmented = np.array(X_raw_train_augmented), 
np.array(y_train_augmented)\n", 189 | "X_train, X_test = preprocess(X_raw_train_augmented, X_raw_test)" 190 | ], 191 | "execution_count": 0, 192 | "outputs": [] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": { 197 | "id": "IcfTU3HpFHXT", 198 | "colab_type": "text" 199 | }, 200 | "source": [ 201 | "# Gaussian Naive Bayes" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "metadata": { 207 | "id": "JmGwiWC9tJGA", 208 | "colab_type": "code", 209 | "colab": {} 210 | }, 211 | "source": [ 212 | "from sklearn.naive_bayes import GaussianNB" 213 | ], 214 | "execution_count": 0, 215 | "outputs": [] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "metadata": { 220 | "id": "45tt7rfAnzB_", 221 | "colab_type": "code", 222 | "colab": {} 223 | }, 224 | "source": [ 225 | "def train_GNB(X_train, y_train):\n", 226 | " clf = GaussianNB()\n", 227 | " clf.fit(X_train, y_train)\n", 228 | " return clf" 229 | ], 230 | "execution_count": 0, 231 | "outputs": [] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": { 236 | "id": "WZOhwbqJFS7I", 237 | "colab_type": "text" 238 | }, 239 | "source": [ 240 | "### Testing Gaussian Naive Bayes" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "metadata": { 246 | "id": "-U2310_i6clM", 247 | "colab_type": "code", 248 | "outputId": "1b01f956-0ea5-49a4-e842-0d4009594958", 249 | "colab": { 250 | "base_uri": "https://localhost:8080/", 251 | "height": 34 252 | } 253 | }, 254 | "source": [ 255 | "clf = train_GNB(X_train, y_train_augmented)\n", 256 | "print(clf.classes_)" 257 | ], 258 | "execution_count": 0, 259 | "outputs": [ 260 | { 261 | "output_type": "stream", 262 | "text": [ 263 | "[0 1 2 3 4 5 6 7 8 9]\n" 264 | ], 265 | "name": "stdout" 266 | } 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "metadata": { 272 | "id": "rzxqJQuS7IEd", 273 | "colab_type": "code", 274 | "colab": {} 275 | }, 276 | "source": [ 277 | "probabilities_mat = classify_Digits(clf, X_test)\n", 278 | "y_pred = np.zeros((probabilities_mat.shape[0], ))\n", 279 | "for i in range(len(probabilities_mat)):\n", 280 | " y_pred[i] = take_decision(probabilities_mat[i])" 281 | ], 282 | "execution_count": 0, 283 | "outputs": [] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "metadata": { 288 | "id": "oNZli63I9PAG", 289 | "colab_type": "code", 290 | "outputId": "2b07bcbe-8f6d-4bd2-b685-2580c69d3d91", 291 | "colab": { 292 | "base_uri": "https://localhost:8080/", 293 | "height": 51 294 | } 295 | }, 296 | "source": [ 297 | "print(y_pred[0:10])\n", 298 | "print(accuracy_score(y_pred, y_test))" 299 | ], 300 | "execution_count": 0, 301 | "outputs": [ 302 | { 303 | "output_type": "stream", 304 | "text": [ 305 | "[7. 2. 1. 0. 4. 1. 4. 9. 5. 
7.]\n", 306 | "0.915\n" 307 | ], 308 | "name": "stdout" 309 | } 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "metadata": { 315 | "id": "a0l1IqzTAA_V", 316 | "colab_type": "code", 317 | "outputId": "b517cb7b-5a72-4130-f475-7c0ad0821e08", 318 | "colab": { 319 | "base_uri": "https://localhost:8080/", 320 | "height": 34 321 | } 322 | }, 323 | "source": [ 324 | "# Save model\n", 325 | "joblib.dump(clf, 'GNB.pkl')" 326 | ], 327 | "execution_count": 0, 328 | "outputs": [ 329 | { 330 | "output_type": "execute_result", 331 | "data": { 332 | "text/plain": [ 333 | "['GNB.pkl']" 334 | ] 335 | }, 336 | "metadata": { 337 | "tags": [] 338 | }, 339 | "execution_count": 29 340 | } 341 | ] 342 | }, 343 | { 344 | "cell_type": "markdown", 345 | "metadata": { 346 | "id": "8pJwXJufEgmn", 347 | "colab_type": "text" 348 | }, 349 | "source": [ 350 | "# Random Forest model" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "metadata": { 356 | "id": "ePoGfqEXEkyc", 357 | "colab_type": "code", 358 | "colab": {} 359 | }, 360 | "source": [ 361 | "from sklearn.ensemble import RandomForestClassifier" 362 | ], 363 | "execution_count": 0, 364 | "outputs": [] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "metadata": { 369 | "id": "Q-cI7kBBFZPP", 370 | "colab_type": "code", 371 | "colab": {} 372 | }, 373 | "source": [ 374 | "def train_Random_Forest(X_train, y_train):\n", 375 | " rfc = RandomForestClassifier(n_jobs=-1, n_estimators=100, max_depth = 20)\n", 376 | " rfc.fit(X_train, y_train)\n", 377 | " return rfc" 378 | ], 379 | "execution_count": 0, 380 | "outputs": [] 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "metadata": { 385 | "id": "1W6Wp9ifHl5_", 386 | "colab_type": "text" 387 | }, 388 | "source": [ 389 | "### Testing Random Forest classifier" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "metadata": { 395 | "id": "IBmyH0zKAmp_", 396 | "colab_type": "code", 397 | "outputId": "e79915cd-feee-4eb3-fd52-32c0b8383e04", 398 | "colab": { 399 | "base_uri": "https://localhost:8080/", 400 | "height": 34 401 | } 402 | }, 403 | "source": [ 404 | "clf = train_Random_Forest(X_train, y_train_augmented)\n", 405 | "print(clf.classes_)" 406 | ], 407 | "execution_count": 0, 408 | "outputs": [ 409 | { 410 | "output_type": "stream", 411 | "text": [ 412 | "[0 1 2 3 4 5 6 7 8 9]\n" 413 | ], 414 | "name": "stdout" 415 | } 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "metadata": { 421 | "id": "t695PU41Eey1", 422 | "colab_type": "code", 423 | "colab": {} 424 | }, 425 | "source": [ 426 | "probabilities_mat = classify_Digits(clf, X_test)\n", 427 | "y_pred = np.zeros((probabilities_mat.shape[0], ))\n", 428 | "for i in range(len(probabilities_mat)):\n", 429 | " y_pred[i] = take_decision(probabilities_mat[i])" 430 | ], 431 | "execution_count": 0, 432 | "outputs": [] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "metadata": { 437 | "id": "Iv15gl_gqjJ9", 438 | "colab_type": "code", 439 | "outputId": "5e06f95a-df90-4935-fbac-cca3a6b68082", 440 | "colab": { 441 | "base_uri": "https://localhost:8080/", 442 | "height": 51 443 | } 444 | }, 445 | "source": [ 446 | "print(y_pred[0:10])\n", 447 | "print(accuracy_score(y_pred, y_test))" 448 | ], 449 | "execution_count": 0, 450 | "outputs": [ 451 | { 452 | "output_type": "stream", 453 | "text": [ 454 | "[7. 2. 1. 0. 4. 1. 4. 9. 5. 
9.]\n", 455 | "0.9372\n" 456 | ], 457 | "name": "stdout" 458 | } 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "metadata": { 464 | "id": "O9D7nfxdAZz9", 465 | "colab_type": "code", 466 | "outputId": "48919987-1c6b-4a77-d029-00e7395446e3", 467 | "colab": { 468 | "base_uri": "https://localhost:8080/", 469 | "height": 34 470 | } 471 | }, 472 | "source": [ 473 | "# Save model\n", 474 | "joblib.dump(clf, 'RandomForest.pkl')" 475 | ], 476 | "execution_count": 0, 477 | "outputs": [ 478 | { 479 | "output_type": "execute_result", 480 | "data": { 481 | "text/plain": [ 482 | "['RandomForest.pkl']" 483 | ] 484 | }, 485 | "metadata": { 486 | "tags": [] 487 | }, 488 | "execution_count": 45 489 | } 490 | ] 491 | }, 492 | { 493 | "cell_type": "markdown", 494 | "metadata": { 495 | "id": "gOE4tjR9JRfD", 496 | "colab_type": "text" 497 | }, 498 | "source": [ 499 | "# Softmax Regression (Logistic regression for > 2 classes)" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "metadata": { 505 | "id": "LrRwQr6DJTvE", 506 | "colab_type": "code", 507 | "colab": {} 508 | }, 509 | "source": [ 510 | "from sklearn.linear_model import LogisticRegression" 511 | ], 512 | "execution_count": 0, 513 | "outputs": [] 514 | }, 515 | { 516 | "cell_type": "code", 517 | "metadata": { 518 | "id": "sRLJT6uhcPU-", 519 | "colab_type": "code", 520 | "colab": {} 521 | }, 522 | "source": [ 523 | "def train_SoftmaxRegression(X_train, y_train, max_iterations=100):\n", 524 | " clf = LogisticRegression(penalty='l2', dual=False, solver='lbfgs', multi_class='multinomial')\n", 525 | " clf.fit(X_train, y_train)\n", 526 | " return clf" 527 | ], 528 | "execution_count": 0, 529 | "outputs": [] 530 | }, 531 | { 532 | "cell_type": "markdown", 533 | "metadata": { 534 | "id": "v5iKWrCIc_kk", 535 | "colab_type": "text" 536 | }, 537 | "source": [ 538 | "### Testing softmax regression model" 539 | ] 540 | }, 541 | { 542 | "cell_type": "code", 543 | "metadata": { 544 | "id": "T1uheAP5dE3w", 545 | "colab_type": "code", 546 | "outputId": "4cad2065-4493-43d9-d654-eaa749d766ec", 547 | "colab": { 548 | "base_uri": "https://localhost:8080/", 549 | "height": 190 550 | } 551 | }, 552 | "source": [ 553 | "clf = train_SoftmaxRegression(X_train, y_train_augmented, max_iterations=1000)\n", 554 | "print(clf.classes_)" 555 | ], 556 | "execution_count": 0, 557 | "outputs": [ 558 | { 559 | "output_type": "stream", 560 | "text": [ 561 | "[0 1 2 3 4 5 6 7 8 9]\n" 562 | ], 563 | "name": "stdout" 564 | }, 565 | { 566 | "output_type": "stream", 567 | "text": [ 568 | "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/_logistic.py:940: ConvergenceWarning: lbfgs failed to converge (status=1):\n", 569 | "STOP: TOTAL NO. 
of ITERATIONS REACHED LIMIT.\n", 570 | "\n", 571 | "Increase the number of iterations (max_iter) or scale the data as shown in:\n", 572 | " https://scikit-learn.org/stable/modules/preprocessing.html\n", 573 | "Please also refer to the documentation for alternative solver options:\n", 574 | " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", 575 | " extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n" 576 | ], 577 | "name": "stderr" 578 | } 579 | ] 580 | }, 581 | { 582 | "cell_type": "code", 583 | "metadata": { 584 | "id": "4XTFn_4ydo_d", 585 | "colab_type": "code", 586 | "colab": {} 587 | }, 588 | "source": [ 589 | "probabilities_mat = classify_Digits(clf, X_test)\n", 590 | "y_pred = np.zeros((probabilities_mat.shape[0], ))\n", 591 | "for i in range(len(probabilities_mat)):\n", 592 | " y_pred[i] = take_decision(probabilities_mat[i])" 593 | ], 594 | "execution_count": 0, 595 | "outputs": [] 596 | }, 597 | { 598 | "cell_type": "code", 599 | "metadata": { 600 | "id": "xY1aLIYmdrh7", 601 | "colab_type": "code", 602 | "outputId": "3e382bc5-41e1-4922-fa10-301f5a77c921", 603 | "colab": { 604 | "base_uri": "https://localhost:8080/" 605 | } 606 | }, 607 | "source": [ 608 | "print(y_pred[0:10])\n", 609 | "print(accuracy_score(y_pred, y_test))" 610 | ], 611 | "execution_count": 0, 612 | "outputs": [ 613 | { 614 | "output_type": "stream", 615 | "text": [ 616 | "[7. 2. 1. 0. 4. 1. 4. 9. 5. 9.]\n", 617 | "0.9668\n" 618 | ], 619 | "name": "stdout" 620 | } 621 | ] 622 | }, 623 | { 624 | "cell_type": "code", 625 | "metadata": { 626 | "id": "7m-YqWZkqQA0", 627 | "colab_type": "code", 628 | "outputId": "a5efb1fc-954b-4e74-bceb-7730400f091c", 629 | "colab": { 630 | "base_uri": "https://localhost:8080/", 631 | "height": 34 632 | } 633 | }, 634 | "source": [ 635 | "# Save model\n", 636 | "joblib.dump(clf, 'Softmax.pkl')" 637 | ], 638 | "execution_count": 0, 639 | "outputs": [ 640 | { 641 | "output_type": "execute_result", 642 | "data": { 643 | "text/plain": [ 644 | "['Softmax.pkl']" 645 | ] 646 | }, 647 | "metadata": { 648 | "tags": [] 649 | }, 650 | "execution_count": 56 651 | } 652 | ] 653 | }, 654 | { 655 | "cell_type": "markdown", 656 | "metadata": { 657 | "id": "v_kPGhhgJlJh", 658 | "colab_type": "text" 659 | }, 660 | "source": [ 661 | "# Gradient Boosted Softmax Regression (cross entropy loss)" 662 | ] 663 | }, 664 | { 665 | "cell_type": "code", 666 | "metadata": { 667 | "id": "I6uNRszeJo0i", 668 | "colab_type": "code", 669 | "colab": {} 670 | }, 671 | "source": [ 672 | "import xgboost" 673 | ], 674 | "execution_count": 0, 675 | "outputs": [] 676 | }, 677 | { 678 | "cell_type": "code", 679 | "metadata": { 680 | "id": "4ckfVFCF6OGj", 681 | "colab_type": "code", 682 | "colab": {} 683 | }, 684 | "source": [ 685 | "def train_XGBoostSoftmax(X_train, y_train, max_depth=6, n_estimators=50):\n", 686 | " model = xgboost.XGBClassifier(objective=\"multi:softmax\", booster=\"gbtree\", max_depth=max_depth, n_estimators=n_estimators, num_classes=10)\n", 687 | " model.fit(X_train, y_train)\n", 688 | " return model" 689 | ], 690 | "execution_count": 0, 691 | "outputs": [] 692 | }, 693 | { 694 | "cell_type": "code", 695 | "metadata": { 696 | "id": "_svF46mA6LvG", 697 | "colab_type": "code", 698 | "outputId": "0a091287-2022-4afd-d6aa-07db4b905f50", 699 | "colab": { 700 | "base_uri": "https://localhost:8080/", 701 | "height": 34 702 | } 703 | }, 704 | "source": [ 705 | "# clf = train_XGBoostSoftmax(X_train, y_train_augmented, max_depth=7, n_estimators=100)\n", 706 | "clf = 
joblib.load(\"XGBOOST.pkl\")\n", 707 | "print(clf.classes_)" 708 | ], 709 | "execution_count": 0, 710 | "outputs": [ 711 | { 712 | "output_type": "stream", 713 | "text": [ 714 | "[0 1 2 3 4 5 6 7 8 9]\n" 715 | ], 716 | "name": "stdout" 717 | } 718 | ] 719 | }, 720 | { 721 | "cell_type": "code", 722 | "metadata": { 723 | "id": "ju-DKNydr9Ch", 724 | "colab_type": "code", 725 | "outputId": "5598f90c-b3b7-45f0-b013-67fcbce31c1f", 726 | "colab": { 727 | "base_uri": "https://localhost:8080/", 728 | "height": 51 729 | } 730 | }, 731 | "source": [ 732 | "probabilities_mat = classify_Digits(clf, X_test)\n", 733 | "y_pred = np.zeros((probabilities_mat.shape[0], ))\n", 734 | "print(probabilities_mat[0])\n", 735 | "for i in range(len(probabilities_mat)):\n", 736 | " y_pred[i] = take_decision(probabilities_mat[i])" 737 | ], 738 | "execution_count": 0, 739 | "outputs": [ 740 | { 741 | "output_type": "stream", 742 | "text": [ 743 | "[2.3678804e-05 2.8760120e-05 4.0040057e-05 6.5543842e-05 3.8886657e-05\n", 744 | " 5.7991092e-05 1.9326704e-05 9.9962735e-01 3.2450516e-05 6.5977889e-05]\n" 745 | ], 746 | "name": "stdout" 747 | } 748 | ] 749 | }, 750 | { 751 | "cell_type": "code", 752 | "metadata": { 753 | "id": "xuwJc5AesDc7", 754 | "colab_type": "code", 755 | "outputId": "27b7dbcf-bf1e-44a4-f7d5-7b9c9d4ee50a", 756 | "colab": { 757 | "base_uri": "https://localhost:8080/", 758 | "height": 51 759 | } 760 | }, 761 | "source": [ 762 | "print(y_pred[0:10])\n", 763 | "print(accuracy_score(y_pred, y_test))" 764 | ], 765 | "execution_count": 0, 766 | "outputs": [ 767 | { 768 | "output_type": "stream", 769 | "text": [ 770 | "[7. 2. 1. 0. 4. 1. 4. 9. 5. 9.]\n", 771 | "0.9753\n" 772 | ], 773 | "name": "stdout" 774 | } 775 | ] 776 | }, 777 | { 778 | "cell_type": "code", 779 | "metadata": { 780 | "id": "oVTQ2I_VsRns", 781 | "colab_type": "code", 782 | "colab": {} 783 | }, 784 | "source": [ 785 | "# Save model\n", 786 | "# joblib.dump(clf, 'XGBOOST.pkl')\n", 787 | "clf.save_model(\"XGBOOST.bin\")" 788 | ], 789 | "execution_count": 0, 790 | "outputs": [] 791 | }, 792 | { 793 | "cell_type": "code", 794 | "metadata": { 795 | "id": "Fx8RhjAOj4AD", 796 | "colab_type": "code", 797 | "outputId": "81c40eed-3c27-4b8d-b57d-0da453fe1e09", 798 | "colab": { 799 | "base_uri": "https://localhost:8080/", 800 | "height": 268 801 | } 802 | }, 803 | "source": [ 804 | "# Standard scientific Python imports\n", 805 | "import matplotlib.pyplot as plt\n", 806 | "\n", 807 | "# Import datasets, classifiers and performance metrics\n", 808 | "from sklearn import datasets, svm, metrics\n", 809 | "from sklearn.model_selection import train_test_split\n", 810 | "\n", 811 | "# The digits dataset\n", 812 | "digits = datasets.load_digits()\n", 813 | "\n", 814 | "# The data that we are interested in is made of 8x8 images of digits, let's\n", 815 | "# have a look at the first 4 images, stored in the `images` attribute of the\n", 816 | "# dataset. If we were working from image files, we could load them using\n", 817 | "# matplotlib.pyplot.imread. Note that each image must have the same size. 
For these\n", 818 | "# images, we know which digit they represent: it is given in the 'target' of\n", 819 | "# the dataset.\n", 820 | "_, axes = plt.subplots(2, 4)\n", 821 | "images_and_labels = list(zip(digits.images, digits.target))\n", 822 | "for ax, (image, label) in zip(axes[0, :], images_and_labels[109:113]):\n", 823 | " ax.set_axis_off()\n", 824 | " ax.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')\n", 825 | " ax.set_title('Training: %i' % label)" 826 | ], 827 | "execution_count": 0, 828 | "outputs": [ 829 | { 830 | "output_type": "display_data", 831 | "data": { 832 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD7CAYAAABnoJM0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAZNklEQVR4nO3dfXAc9Z3n8fc36GTOgfUDuA5OEtgTEWNMCRGPA9RexQY2GEhFpmqNI/YhdljKsBduA3dLJZdczEMudd7dqxxxmdSGItjhuLXYJVuRdouY8OQkV7WOkbPA2gSw5Ye1tOTWYOMDcicj3/f+mJYYy5L6p9G05uH3eVV1ebr719O/+Viar7p7ftPm7oiISLw+UukOiIhIZakQiIhEToVARCRyKgQiIpFTIRARiZwKgYhI5KIrBGb2IzNbU+62UqB8s6NssxN7tlYL4wjM7L2i2ZnAIHAymb/d3f/H9PeqfMxsPnAAeL9o8Z+4+zemaf91nW8xM1sP3A982t2fnYb9Kdvs9lfX2ZrZ7wLfLVr0EeBfAnl331XOfTWU88my4u5nDT82s4PAbWP9oJlZg7sPTWffymx2JfofS75m9jHgZuDN6dqnss1OvWebFLKRYmZma4GvA78o975q+tSQmS03s34z+7KZ/QrYbGZzzOxvzeyImR1LHjcXbbPdzG5LHq81s/9pZv81aXvAzG4ose0CM/upmb1rZs+a2UNm9vg0xlF2dZjvQ8CXgRNTyaUclG126jDbYWuAxzyD0zg1XQgS5wFzgQuBdRRe0+Zk/gLg/wCbJtj+CuB14FzgT4HvmZmV0PYvgJ3AOcB9wO8Xb2hmr5jZ76S8lkPJD/BmMzs3pe10qYt8zexmYNDdn5qgr9NN2WanLrItanch8CngsbS2JXH3mpqAg8BvJY+XU/gL5MwJ2rcDx4rmt1M4hARYC+wrWjcTcOC8ybSl8IM1BMwsWv848HjgazoLyFM4VfevgCeBp5Vv2fI9G9gLzB/9GpWtsq3WbEf19+vA9qzyq4cjgiPu/n+HZ8xsppl918wOmdn/Bn4KzDazM8bZ/lfDD9z918nDsybZ9l8DR4uWARwOfQHu/p6797r7kLv/L+BO4DozOzv0OTJU8/lS+Evsv7v7wUlsMx2UbXbqIdtinwe+X+K2qeqhEIw+X/YfgIXAFe7+GxQOpwDGO6wrhzeBuWY2s2hZyxSeb/g1VcP/Tz3key3wR2b2q+SccQvwl2b25XJ2sgTKNjv1kC0AZvabFIrKk+Xq2GjV8EZTbmdTOP/3jpnNBe7NeofufgjoBe4zs0Yzuwr4bOj2ZnaFmS00s4+Y2TnARgqHgccz6vJU1Fy+FN6sLqVwOqAd+CfgdgoXOKuJss1OLWY7bA3wA3d/t6wdLFKPheBBCp+1fQvYAWybpv3+LnAV8Dbwn4EnKHyuGQAz22OFzwWPJUehn+8Cu5Ptbsm0t6WruXzd/W13/9XwROGz5sfc/b2x2leQss1OzWWbrD8TWE2Gp4WgRgaU1SIzewJ4zd0z/8sjRso3O8o2O9WabT0eEVSEmS01s48lp3euB1YCP6x0v+qF8s2Oss1OrWRbEyOLa8R5wF9T+LxwP/CH7v73le1SXVG+2VG22amJbHVqSEQkcjo1JCISuWo8NVSWQ5TZs2entjl+vHyfzly2bFlqmy1btqS2mT9/fsjupvLZ52k7BFy7dm1qmx/+MOx06TvvvDPF3kxKqfmWJduQ13rXXXeltgn5HQB48MEHg9qVSUWzfemll1Lb3Hfffalturu7g/Z32WWXpbYJ6VOgkt8XdEQgIhI5FQIRkcipEIiIRE6FQEQkcioEIiKRUyEQEYmcCoGISORUCEREIleNA8pSHTx4MLVNyGCxWbNmpbZpb28P6VKQaR4Ulambbroptc327dvL0iY2IQOaQgbihQ7WqxchA7OWL19eln2FDBQDePnll8uyv6zpiEBEJHIqBCIikVMhEBGJnAqBiEjkVAhERCKnQiAiEjkVAhGRyKkQiIhErm4HlIUIuctTyOCeehNyx6qQgWAhA5rKOWCvFoQMKvz2t7+d2iYk23INnqoVIQPKQn7nQ+6sF3pXt1oZRKojAhGRyKkQiIhEToVARCRyKgQiIpFTIRARiZwKgYhI5FQIREQip0IgIhK5mhxQVq67WpXzDlr1NHinXIOVQgaLhQwCgrA+hQwEmj9/ftD+shL6etMsW7YstU3oz25Iblu2bEltU+nfgZDXUa6BkCF5AMyePTu1TXd3d2qblStXBu2vVDoiEBGJnAqBiEjkVAhERCKnQiAiEjkVAhGRyKkQiIhEToVARCRyKgQiIpGryQFlIXcNCxk48pOf/CS1zdVXXx3QI3D3oHa1ICTfkDblHLx1/Pjx1DYhd4MKvbNUtbvppptS24Teye/QoUOpbULyr7SQ3/nQ3+dyCRlQduzYsWnoycR0RCAiEjkVAhGRyKkQiIhEToVARCRyKgQiIpFTIRARiZwKgYhI5FQIREQip0IgIhI5q8IRsdPWoZDbBl5++eVBz3XvvfemtgkZjRvIprDttOUb8nrvv//+oOcKybeMt6osNd+yZBsyQjrkVol333130P6+9KUvpbYp44jszLIt18jykFtVhowYhvLdVjdQye8LOiIQEYmcCoGISORUCEREIqdCICISORUCEZHIqRCIiEROhUBEJHIqBCIikavJW1WG3IIvZCBHyIAyKV3I/9OyZcuCnquMg/GqXshgpZBbVYYOKAt5rloQkttdd92V2iZkkOMLL7wQ1KdaoSMCEZHIqRCIiEROhUBEJHIqBCIikVMhEBGJnAqBiEjkVAhERCKnQiAiErmaHFBWroEjx48fT20za9asoD7Vy6CccgoZUNbe3p59R+pQyAC7yy67LOi5li9fPrXO1JCQn8mQ3/l6+7nVEYG
ISORUCEREIqdCICISORUCEZHIqRCIiEROhUBEJHIqBCIikVMhEBGJnLl7pfsgIiIVpCMCEZHIqRCIiEROhUBEJHIqBCIikVMhEBGJnAqBiEjkVAhERCKnQiAiEjkVAhGRyKkQiIhEToVARCRyKgQiIpFTIRARiZwKgYhI5FQIREQip0IgIhI5FQIRkcipEIiIRE6FQEQkcioEIiKRUyEQEYmcCoGISORUCEREIqdCICISORUCEZHIqRCIiEROhUBEJHIqBCIikVMhEBGJXGohMLNHzeyfzWz3OOvNzDaa2T4ze8XMPlG0bo2Z7U2mNeXseL1QvtlRttlRtnXG3SecgE8BnwB2j7P+RuBHgAFXAj9Pls8F9if/zkkez0nbX2yT8lW2tTgp2/qaUo8I3P2nwNEJmqwEHvOCHcBsMzsfWAE84+5H3f0Y8Axwfdr+YqN8s6Nss6Ns60tDGZ6jCThcNN+fLBtv+WnMbB2wDuCjH/3okosvvrgM3aodl156Kfv27SOfz/vodbNmzeK8885bl8/n/wLg7LPP5t13330V+CbKN9Vks21qavqn11577X1gQ1FTZTsGZVtddu3a9Za7zytp45DDBmA+4x8C/i3wb4rmnwPywB8D/6lo+deBP07b15IlSzw2Bw4c8MWLF4+57jOf+Yz/7Gc/G5m/5pprHHhV+YaZbLYvvviiUyiwyjaFsq0uQK9ndWoowADQUjTfnCwbb7lMQlNTE4cPf/iHf39/P8AHKN8pGyvbpqYmKOSrbKdA2daWchSCHuDzyacErgSOu/ubwNPAdWY2x8zmANcly2QSOjo6eOyxx3B3duzYwaxZs6Dwy6R8p2isbM8//3yA4yjbKVG2tSX1GoGZbQWWA+eaWT9wL/AvANz9z4GnKHxCYB/wa+ALybqjZvYN4MXkqR5w94kuLkXplltuYfv27bz11ls0Nzdz//3388EHHwBwxx13cOONN/LUU0/R2trKzJkz2bx5M0uXLlW+AUrJNnESULYTULb1xQqnlqpHPp/33t7eSnejqpnZLnfPl7Kt8k1Xar7KNp2yzc5U3hc0slhEJHIqBCIikVMhEBGJnAqBiEjkVAhERCKnQiAiEjkVAhGRyKkQiIhEToVARCRyKgQiIpFTIRARiZwKgYhI5FQIREQip0IgIhI5FQIRkcgFFQIzu97MXjezfWb2lTHW/zczeymZ3jCzd4rWnSxa11POzteDbdu2sXDhQlpbW9mwYcNp6++++27a29tpb2/n4x//OLNnzx5Zp2zTKd/sKNs6knZTY+AMoA/IAY3Ay8AlE7T/d8CjRfPvTeYmyjHdpHpoaMhzuZz39fX54OCgt7W1+Z49e8Ztv3HjRv/CF74wcpPqyWbryjfTfJWtsq0kMr55/SeBfe6+391PAF3Aygna3wJsnVw5itPOnTtpbW0ll8vR2NhIZ2cn3d3d47bfunUrt9xyyzT2sLYp3+wo2/oSUgiagMNF8/3JstOY2YXAAuD5osVnmlmvme0ws5tK7mkdGhgYoKWlZWS+ubmZgYGBMdseOnSIAwcOcM011xQvVrYTUL7ZUbb1JfXm9ZPUCTzp7ieLll3o7gNmlgOeN7N/cPe+4o3MbB2wDuCCCy4oc5fqQ1dXF6tWreKMM84oXpyaLSjfEKXmq2zTKdvqF3JEMAC0FM03J8vG0smo00LuPpD8ux/YDlw+eiN3f9jd8+6enzdvXkCX6kNTUxOHD394sNXf309T05gHW3R1dZ12aB2SbbJe+ZJNvsq2QNnWtpBC8CJwkZktMLNGCm/2p13lN7OLgTnA3xUtm2NmM5LH5wK/Cbxajo7Xg6VLl7J3714OHDjAiRMn6OrqoqOj47R2r732GseOHeOqq64aWaZs0ynf7Cjb+pJaCNx9CLgTeBr4JfCX7r7HzB4ws+L/+U6gK7l6PWwR0GtmLwMvABvcXf/hiYaGBjZt2sSKFStYtGgRq1evZvHixaxfv56eng9rbVdXF52dnZhZ8ebKNoXyzY6yrS926vt25eXzee/t7a10N6qame1y93wp2yrfdKXmq2zTKdvsTOV9QSOLRUQip0IgIhI5FQIRkcipEIiIRE6FQEQkcioEIiKRUyEQEYmcCoGISORUCEREIqdCICISORUCEZHIqRCIiEROhUBEJHIqBCIikVMhEBGJnAqBiEjkggqBmV1vZq+b2T4z+8oY69ea2REzeymZbitat8bM9ibTmnJ2vh5s27aNhQsX0trayoYNG05bv2XLFubNm0d7ezvt7e088sgjI+uUbTrlmx1lW0fcfcIJOAPoA3JAI/AycMmoNmuBTWNsOxfYn/w7J3k8Z6L9LVmyxGMxNDTkuVzO+/r6fHBw0Nva2nzPnj2ntNm8ebN/8YtfPGUZ0FtKtq58M81X2SrbSgJ6PeX3f7wp5Ijgk8A+d9/v7ieALmBlYJ1ZATzj7kfd/RjwDHB94LZ1b+fOnbS2tpLL5WhsbKSzs5Pu7u7QzZVtCuWbHWVbX0IKQRNwuGi+P1k22m+b2Stm9qSZtUxmWzNbZ2a9ZtZ75MiRwK7XvoGBAVpaWkbmm5ubGRgYOK3dD37wA9ra2li1ahWHD4/EGfr/onwTWeSrbAuUbW0r18XivwHmu3sbher+/cls7O4Pu3ve3fPz5s0rU5fqw2c/+1kOHjzIK6+8wqc//WnWrJn86VTlO76p5qtsx6dsa0dIIRgAWormm5NlI9z9bXcfTGYfAZaEbhuzpqam4r+S6O/vp6np1D+MzjnnHGbMmAHAbbfdxq5du4ZXKdsUyjc7yra+hBSCF4GLzGyBmTUCnUBPcQMzO79otgP4ZfL4aeA6M5tjZnOA65JlAixdupS9e/dy4MABTpw4QVdXFx0dHae0efPNN0ce9/T0sGjRouFZZZtC+WZH2daXhrQG7j5kZndS+I86A3jU3feY2QMUrlL3AH9kZh3AEHCUwqeIcPejZvYNCsUE4AF3P5rB66hJDQ0NbNq0iRUrVnDy5EluvfVWFi9ezPr168nn83R0dLBx40Z6enpoaGhg7ty5bNmyhUWLFinbAMo3O8q2vljhU0fVI5/Pe29vb6W7UdXMbJe750vZVvmmKzVfZZtO2WZnKu8LGlksIhI5FQIRkcipEIiIRE6FQEQkcioEIiKRUyEQEYmcCoGISORUCEREIqdCICISORUCEZHIqRCIiEROhUBEJHIqBCIikVMhEBGJnAqBiEjkVAhERCIXVAjM7Hoze93M9pnZV8ZY/+/N7FUze8XMnjOzC4vWnTSzl5KpZ/S2sdu2bRsLFy6ktbWVDRs2nLb+W9/6FpdccgltbW1ce+21HDp0aGSdsk2nfLOjbOuIu084Ubg9ZR+QAxqBl4FLRrW5GpiZPP5D4Imide+l7aN4WrJkicdiaGjIc7mc9/X1+eDgoLe1tfmePXtOafP888/7+++/7+7u3/nOd3z16tVO4Rahk87WlW+m+SpbZVtJw9mWMoUcEXwS2Ofu+939BNAFrBxVTF5w918nszuA5pKqUmR27txJa2sruVyOxsZGOjs76e7uPqXN1Vdfzc
yZMwG48sor6e/vr0RXa5LyzY6yrS8hhaAJOFw0358sG88fAD8qmj/TzHrNbIeZ3TTWBma2LmnTe+TIkYAu1YeBgQFaWlpG5pubmxkYGBi3/fe+9z1uuOGG4kWp2YLyHZZFvsq2QNnWtrJeLDaz3wPywJ8VLb7QCzdU/h3gQTP72Ojt3P1hd8+7e37evHnl7FLdePzxx+nt7eWee+4pXpyaLSjfEKXmq2zTKdvqF1IIBoCWovnmZNkpzOy3gK8BHe4+OLzc3QeSf/cD24HLp9DfutLU1MThwx8ebPX399PUdPrB1rPPPss3v/lNenp6mDFjxshyZTsx5ZsdZVtn0i4iAA3AfmABH14sXjyqzeUULihfNGr5HGBG8vhcYC+jLjSPnmK6KPTBBx/4ggULfP/+/SMX3Hbv3n1Km1/84heey+X8jTfeGFkG9JaSrSvfTPNVtsq2kpjCxeKGgEIxZGZ3Ak9T+ATRo+6+x8weSHbcQ+FU0FnAX5kZwD+6ewewCPiumf0/CkcfG9z91UnWqrrV0NDApk2bWLFiBSdPnuTWW29l8eLFrF+/nnw+T0dHB/fccw/vvfceN998MwAXXHDB8ObKNoXyzY6yrS9WKCTVI5/Pe29vb6W7UdXMbJcXzq9OmvJNV2q+yjadss3OVN4XNLJYRCRyKgQiIpFTIRARiZwKgYhI5FQIREQip0IgIhI5FQIRkcipEIiIRE6FQEQkcioEIiKRUyEQEYmcCoGISORUCEREIqdCICISORUCEZHIqRCIiEQuqBCY2fVm9rqZ7TOzr4yxfoaZPZGs/7mZzS9a9x+T5a+b2Yrydb0+bNu2jYULF9La2sqGDRtOWz84OMjnPvc5WltbueKKKzh48ODIOmWbTvlmR9nWkbR7WVK4PWUfkOPDexZfMqrNvwX+PHncCTyRPL4kaT+Dwj2P+4AzJtpfTPcmHRoa8lwu5319fSP3fd2zZ88pbR566CG//fbb3d1969atvnr16uH7vk46W1e+mearbJVtJTGFexaHHBF8Etjn7vvd/QTQBawc1WYl8P3k8ZPAtVa4efFKoMvdB939ALAveT4Bdu7cSWtrK7lcjsbGRjo7O+nu7j6lTXd3N2vWrAFg1apVPPfcc8OrlG0K5ZsdZVtfUm9eDzQBh4vm+4ErxmvjhZvdHwfOSZbvGLVt0+gdmNk6YF0yO2hmu4N6P33OBd7K4HnnAL9hZoeS+bnAWV/96lf/sajN4h//+MdvAB8k85cCFxOYLVR9vlllC9OQb5VnCzX8sxtxtqVaWOqGIYUgc+7+MPAwgJn1eok3YM5KVn0ys1XA9e5+WzL/+8AV7n5nUZvdwGfcvT+Z7wPencx+qjnfLPszHflWc7ZQ2z+7sWZbKjPrLXXbkFNDA0BL0XxzsmzMNmbWAMwC3g7cNmalZjsUuG3slG92lG0dCSkELwIXmdkCM2ukcDG4Z1SbHmBN8ngV8Hxy8aIH6Ew+VbQAuAjYWZ6u14WSsi1armwnpnyzo2zrScgVZeBG4A0KV/e/lix7AOhIHp8J/BWFiz47gVzRtl9LtnsduCFgX+tKvfKd1ZRln0rJdrg/k822GvPNuj/TmW+1ZZt1n5RtdfVpKv2x5AlERCRSGlksIhI5FQIRkchVrBBM5WsrKtintWZ2xMxeSqbbMuzLo2b2z+N9dtoKNiZ9fcXMPjHJ1zKt+VZTtsn+Ss5X2ab2p26yDexTzbwvjKtCFzVK/tqKCvdpLbBpmjL6FPAJYPc4628EfgQYcCXw82rNt9qynUq+yjaebKsx36m8L0w0VeqIYCpfW1HJPk0bd/8pcHSCJiuBx7xgBzDbzM5P1lVbvlWVLUwpX2Wboo6yJbBP02aK7wvjqlQhGOtrK0YPMT/layuA4a+tqGSfAH47OeR60sxaxlg/XSbqb7XlW2vZwvh9VrZTVyvZnrK/CfoE1ZNvaH9PoYvFk/M3wHx3bwOe4cO/TGTqlG12lG22aj7fShWCqXxtRcX65O5vu/tgMvsIsCTD/qSZqL/Vlm+tZQvj91nZTl2tZHvK/sbrU5XlW9LXd1SqEEzlaysq1qdR59o6gF9m2J80PcDnk08JXAkcd/c3k3XVlm+tZQvj56tsp65WsiWkT1WW70TvC+ObjivdE1zdLulrKyrYp/8C7KHwyYEXgIsz7MtW4E0KX+HbD/wBcAdwR7LegIeSvv4DkK/mfKsp26nmq2zjybba8p3q+8J4k75iQkQkcrpYLCISORUCEZHIqRCIiEROhUBEJHIqBCIikVMhEBGJnAqBiEjk/j+Wz11+MMfPsQAAAABJRU5ErkJggg==\n", 833 | "text/plain": [ 834 | "
" 835 | ] 836 | }, 837 | "metadata": { 838 | "tags": [], 839 | "needs_background": "light" 840 | } 841 | } 842 | ] 843 | }, 844 | { 845 | "cell_type": "code", 846 | "metadata": { 847 | "id": "HswmA43Oj4wd", 848 | "colab_type": "code", 849 | "colab": {} 850 | }, 851 | "source": [ 852 | "image = X_raw_train[1000]\n", 853 | "shifted_image_down = shift_image(image, 0, 1)\n", 854 | "shifted_image_left = shift_image(image, -1, 0)" 855 | ], 856 | "execution_count": 0, 857 | "outputs": [] 858 | }, 859 | { 860 | "cell_type": "code", 861 | "metadata": { 862 | "id": "DnPXTMoVlA2I", 863 | "colab_type": "code", 864 | "outputId": "e359aab2-c4c8-4bdb-dff5-a8b796cfdd2b", 865 | "colab": { 866 | "base_uri": "https://localhost:8080/", 867 | "height": 245 868 | } 869 | }, 870 | "source": [ 871 | "plt.figure(figsize=(12,3))\n", 872 | "plt.subplot(131)\n", 873 | "plt.title(\"original\", fontsize=14)\n", 874 | "plt.imshow(image.reshape(28, 28), interpolation=\"nearest\", cmap=\"Greys\")\n", 875 | "\n", 876 | "plt.subplot(132)\n", 877 | "plt.title(\"shifted down\", fontsize=14)\n", 878 | "plt.imshow(shifted_image_down.reshape(28, 28), interpolation=\"nearest\", cmap=\"Greys\")\n", 879 | "\n", 880 | "plt.subplot(133)\n", 881 | "plt.title(\"shifted left\", fontsize=14)\n", 882 | "plt.imshow(shifted_image_left.reshape(28, 28), interpolation=\"nearest\", cmap=\"Greys\")" 883 | ], 884 | "execution_count": 0, 885 | "outputs": [ 886 | { 887 | "output_type": "execute_result", 888 | "data": { 889 | "text/plain": [ 890 | "" 891 | ] 892 | }, 893 | "metadata": { 894 | "tags": [] 895 | }, 896 | "execution_count": 25 897 | }, 898 | { 899 | "output_type": "display_data", 900 | "data": { 901 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAp0AAADTCAYAAADDGKgLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAZKklEQVR4nO3df5AcdZnH8c+HBQ0HnkdkyYUQsoiAcICAe8FTPEKJdzmr7vAHIBTHwYEErsSCM5QCemeKEwUExLp4WFFjoM6fKBEElV8aAQUkcgiRoPwwQDCGTWFJwqFcyHN/dK8Ou73szM58Z6a736+qrd15pmf66d159vtMd397HBECAAAAUtqq1wkAAACg+mg6AQAAkBxNJwAAAJKj6QQAAEByNJ0AAABIjqYTAAAAydF09gHbJ9re1OJjFtlelSCXsH1kp58X6IZmaqloGdsLbD9ue4vtRR3OabHtFS0+ZjivxaFO5gKURRlr2fZQXrfDLT7vItvr88ee2G6e/Yymsz98VdKrW3zMxZIOTZALUHUvqjfbO0j6tKRPSJol6WLbK2wv7lF+AJpT+lq2va+kj0g6TdJMSV+1vcb2Wb3NLI2te51A3dneJiKek/RcK4+LiE2SWto7CkAqqLc5yv4XXhcR6yTJdi9SA9CCitTya/Lv34z803pKkPOUsaezw2y/3PZl+a7y39m+0/Yh+X3z8t3nb7P9Y9vPS/rbCQ4RnJM/xybbV9r+iO01Dfe/6PC67WW2r7N9hu0nbf/G9hds/0nDMvNt35bf97TtG2zvnf63AnSO7b/O62qT7d/mtbTvmGXeYnuV7Wdtf9/2bg33/aHe8kNZ/5Pf9When8uUHUV4b377D4e5be9j+3rbG20/ZfvLtv+84bkHbF+c19hvbF8maaCJbZpv+8H8f8ZtkvYsWOadtu+3/XvbT9j+kPPRyfZpth9sWPbwPO+zG2L/bftzjb+Dl/o9AalVsZYLtnHC9Tg7/L88X3RLnt8KZc3zJ0ZzbnWd/Yyms/MukvRuSSdJOlDS/ZK+a3tmwzIXSvqwpNdKumvsE9g+Rtnu9g9JOkjSaknvb2Ldb5a0r6TD8xzeIemMhvu3k3SZpLmS5kn6raRv2X5Z01sH9JDtrSVdI+l2Sa+TdLCy1/QLDYu9XNI5ymrwryT9maTPTPCUX5U0P/95rrLDW2dIukPSF/LbMyU9kdfwrZJW5cseLml7SdfYHv1fulDSKZJOzdc9IOm4SbZptqRvSrpJ0gGS/lPZ/5HGZV4v6SpJV0vaT9LZ+Taeni+yQtJeDYPmPEkb8u+jDs2XG9XK7wnoqCrWcsE2Traei/N1qCG/d0paK+m8hlh1RARfHfpS1tQ9L+mfGmIDkh6R9FFlA0BIeteYx50oaVPD7TskfWbMMjdKWtNwe5GkVQ23l0l6QtJAQ+yzkm6eJN8XJB3SEAtJR/b6d8kXX0Vfkqbnr9FDJ7j/xPz+vRpix0n6vSQ3LNNYb8P5Y4YaYiskLR7z3OdJumVMbIf8sXPz27+S9KGG+7eS9AtJK15imz6WL+OG2Icbc5L0RUnfG/O4RZLWNtxeJ+nY/OfbJX1Q2Sk4Wys7hBeSdmn298QXXym/KlrLQ/lzDLewniMlxZhl1kg6q9d/oxRf7OnsrN0lbSPph6OBiHhBWRO5T8NyKyd5ntdK+vGY2Lg9ogUeyNc36leSdhq9YXt321+y/YjtZyStV1ZIuzbx3EDPRcTTyt5g3ZAfsnq/7bGv399HxM8bbv9K0suU/bNvx+sl/XV+KHBTfljvify+3W2/UtleiTsa8t2iyWt3b0l3Rj7a5O4oWOaHY2K3S5pl+0/z2z
+QNM/ZKTV/qez3tCH/eZ6kRyJibcPjU/2egElVtJZbWk97m1BOTCTqnsYB5dlE6/i/gnU2vrG4Ttlu+1MlPSlps6QHlBUxUAoR8c/5+VXzJf2DpPNtvz0ibsgX2Tz2Ifn3dt9kbyXpeklFs0pH38D1wuj2rVB2Gs4bJT0cEevz88MOU/amd8WYx6X6PQFNqUEtT7ae2uGfS2c9ouzw+ptGA7YHlJ0P8kALz/Ogsr0Tjea2k5jtVynbg/qxiLg5IlZLeoV444ESioifRsSFETFPWTN1QodX8bzGTxq4R9JfSHosIh4e87UxIn6r7BD3G0YfkE/0max2V0s6eHRSUO4NBcu8aUzsEGWH1zfmt1dI2kPZIcgVDbHDNP58TqAvVKyWx3rJ9bSYcyXQdHZQRDwr6XJJFzqbob53fnuGpP9q4ak+JelE2yfZ3sP2B5SdZN3OLLbfKDvUdort19g+VNkJ2WPfSQJ9y/Zuti+w/Ubbc2wfJml/tfamrhlrJM11drHnHfOT/j8t6ZXKrqN3sO1X57PEl9h+Rf64T0n6gO0jbe+lbGLEZBMBPqPsXLDLbO/l7MMZThuzzCWSDnV21Yo9bR+nbKLDHyYcRcSDkn4t6R8lfT8Pr1B2aH0X0XSij1S0lsdqZj0T5fxm27Ns79jiOvsaTWfnfVDZLLovSLpXWRHNj/yaYc2IiK9I+g9JFyi7BMS+ygam3001qfx8lHfn+axSVgz/puykbKAs/lfZ5YSuUnZS/xXKJtlc2OH1XKxsb8MDkkYk7RoRv1K2t3GLpO9K+pmyOvq9/lhHlyir/c8pO/9rqzy/CUXE48pmrM6X9FNJ/6psdnrjMvdIOkrSu5TV7wX519iLXv9A2R6SH+SPW6PsVJqx53MCvVa5Wh6ryfUU+XdJs5UdPR1pZZ39bnQGGPqc7eWSto6Iv+91LgAAAK3ifL4+lM8+/Rdl74w2K9u7cUT+HQAAoHTY09mHbG8r6VvKLi6/raSHJF0YEV/qaWIAAABTRNMJAACA5JhIBAAAgOTaajptz7f9c9sP2z578kcA6CVqFigP6hVVM+XD6/lFz38h6a3KPuXmbmWf+zvhNbZ23HHHGBoamtL6gE5as2aNNmzY4MmXrI5Wa5Z6Rb+gXhljUS4T1Ww7s9fnKvuotUclyfZXlM2wnrAghoaGtHLlZB87DqQ3PDzc6xR6oaWapV7RL6hXxliUy0Q1287h9Vn64wfXS9k7sVltPB+AtKhZoDyoV1RO8olEthfYXml75chIpS6sD1QO9QqUCzWLMmmn6XxS2cc0jdolj71IRCyJiOGIGB4cHGxjdQDaNGnNUq9A32CMReW003TeLWkP27vZfpmkYyRd25m0ACRAzQLlQb2icqY8kSgiNts+XdINkgYkLY2In3UsMwAdRc0C5UG9oora+uz1iPi2pG93KBcAiVGzQHlQr6gaPpEIAAAAydF0AgAAIDmaTgAAACRH0wkAAIDkaDoBAACQHE0nAAAAkqPpBAAAQHI0nQAAAEiOphMAAADJ0XQCAAAgOZpOAAAAJEfTCQAAgORoOgEAAJAcTScAAACSo+kEAABAcjSdAAAASI6mEwAAAMnRdAIAACA5mk4AAAAkR9MJAACA5LZu58G210jaKOkFSZsjYrgTSeGlPf3004XxxYsXF8YXLVpUGI+IcbGtty5+Sdxwww2F8UMPPbQwPjAwUBhHb1GzQHlQr2mlHEul4vG07mNpW01n7rCI2NCB5wHQHdQsUB7UKyqDw+sAAABIrt2mMyTdaPsnthd0IiEASVGzQHlQr6iUdg+vHxIRT9reSdJNth+MiFsbF8gLZYEk7brrrm2uDkCbXrJmqVegrzDGolLa2tMZEU/m35+StFzS3IJllkTEcEQMDw4OtrM6AG2arGapV6B/MMaiaqa8p9P2dpK2ioiN+c9/I+m8jmVWI1u2bCmMf+973yuMH3/88YXx9evXt7TenXfeeVxs3bp1hcsefvjhhfENG4rPb58+fXpLuSA9arY3uNoEpoJ6bV0/jaVS8Xha97G0ncPrMyQttz36PF+KiO92JCsAKVCzQHlQr6icKTedEfGopNd1MBcACVGzQHlQr6giLpkEAACA5Gg6AQAAkBxNJwAAAJLrxMdgokm33XZbYfxHP/pRYfycc85p6flPOumkwvjChQsL4zNnzhwXO/roowuXvfnmmwvjCxYUX6/461//emEcKLt+miHL1SZQV0XjaT+NpVLxeFr3sZQ9nQAAAEiOphMAAADJ0XQCAAAgOZpOAAAAJEfTCQAAgOSYvZ7I1VdfPS521FFHFS5b9JnKkrTTTjsVxu++++7C+C677FIYzz9GrSnXXXddYXzatGmF8eXLlxfGf/nLXxbGd9ttt6ZzAXqJq00AvVc0lkrF42k/jaVS8Xha97GUPZ0AAABIjqYTAAAAydF0AgAAIDmaTgAAACRH0wkAAIDkmL3epueff74wft55542LTTSzbrvttiuM33nnnYXx2bNnN5ld6wYGBgrjBx10UGH8nnvuKYxPtK1AP+JqE9WaIYvyaWUslYrrsJ/GUql4PK37WMqeTgAAACRH0wkAAIDkaDoBAACQHE0nAAAAkpu06bS91PZTtlc1xKbbvsn2Q/n3HdKmCaBZ1CxQHtQr6qSZ2evLJC2WdGVD7GxJt0TEBbbPzm9/sPPp9b+JZtzdd999TT/H+eefXxgfGhqaSkptmWj2+sEHH1wYn2jGHXpqmajZQlxtoh4zZEtmmajXyo2lUnF91n0snXRPZ0TcKunpMeEjJF2R/3yFpLd3OC8AU0TNAuVBvaJOpnpO54yIWJf//GtJMzqUD4A0qFmgPKhXVFLbE4kiOy4z4bEZ2wtsr7S9cmRkpN3VAWjTS9Us9Qr0F8ZYVMlUm871tmdKUv79qYkWjIglETEcEcODg4NTXB2ANjVVs9Qr0BcYY1FJU/0YzGslnSDpgvz7NR3LqGQ2btzY9LLbb799Yfz444/vVDrARKhZVW+yAhP/Kqt29drKWCoVj6eMpf2vmUsmfVnSHZL2sr3W9snKCuGtth+SdHh+G0AfoGaB8qBeUSeT7umMiGMnuOstHc4FQAdQs0B5UK+oEz6RCAAAAMnRdAIAACA5mk4AAAAkN9XZ68gtX7686WVPO+20wvgOO/CxukA3cLUJoD+1MpZKxeMpY2n/Y08nAAAAkqPpBAAAQHI0nQAAAEiOphMAAADJ0XQCAAAgOWavN+m5554rjF900UVNP8chhxzSqXSS2bx5c2H8+uuv73ImQOdxtQmgtzoxlkrlHU/rPpaypxMAAADJ0XQCAAAgOZpOAAAAJEfTCQAAgORoOgEAAJAcs9ebtG7dusL4448/3vRzvOpVr+pUOslERGF8ou3cdtttC+PTpk3rWE5Aq7jaRL1nyKJ/dWIslco7ntZ9LGVPJwAAAJKj6QQAAEByNJ0AAABIjqYTAAAAyU3adNpeavsp26saYotsP2n73vzrbWnTBNAsahYoD+oVddLM7PVlkhZLunJM/JMRcXHHM6qwuXPn9jqFjttvv/0K4zvvv
HOXM0GDZap5zXK1iXrPkC2ZZap5vU5F1cbTuoylk+7pjIhbJT3dhVwAdAA1C5QH9Yo6aeecztNt35cfGtihYxkBSIWaBcqDekXlTLXpvFzS7pIOkLRO0iUTLWh7ge2VtleOjIxMcXUA2tRUzVKvQF9gjEUlTanpjIj1EfFCRGyR9FlJE55cERFLImI4IoYHBwenmieANjRbs9Qr0HuMsaiqKTWdtmc23HyHpFUTLQug96hZoDyoV1TVpLPXbX9Z0jxJO9peK+kjkubZPkBSSFoj6dSEOaKLvvOd77S0fCufZY3uoGY7o2qzY6X6zJAtE+q1uloZT+sylk7adEbEsQXhzyfIBUAHULNAeVCvqBM+kQgAAADJ0XQCAAAgOZpOAAAAJEfTCQAAgOSa+ex1SJozZ05hfO+99y6Mr169OmU6HbFp06ZxsdNPP72l5zjooIM6lQ6AFnG1CZRNXcZSqbXxtC5jKXs6AQAAkBxNJwAAAJKj6QQAAEByNJ0AAABIjolETRoYGCiMb7PNNl3OpHNWrRr/cb5r164tXHai7bfd0ZyATqjLZAUm/qFs6jKWSsXjad3HUvZ0AgAAIDmaTgAAACRH0wkAAIDkaDoBAACQHE0nAAAAkmP2ehdt3LixMD59+vSk63322WcL4wsXLhwXm2hm3Y033lgY32677aaeGJBIXWbIcrUJ1FXReNpPY6lUXId1H0vZ0wkAAIDkaDoBAACQHE0nAAAAkqPpBAAAQHI0nQAAAEhu0tnrtmdLulLSDEkhaUlEfMr2dElflTQkaY2koyPiN+lS7U+nnHJKYfx973vfuNi1115buOyJJ57YkVy2bNlSGF+8eHFh/I477hgXmz17duGyhx122NQTQ9dQr53D1SbQDdRsppWxVCoeT/tpLJWKx9O6j6XN7OncLGlhROwj6Q2S3mt7H0lnS7olIvaQdEt+G0BvUa9AuVCzqI1Jm86IWBcR9+Q/b5S0WtIsSUdIuiJf7ApJb0+VJIDmUK9AuVCzqJOWzum0PSTpQEl3SZoREevyu36t7NBA0WMW2F5pe+XIyEgbqQJoBfUKlAs1i6pruum0vb2kb0g6MyKeabwvIkLZuSjjRMSSiBiOiOHBwcG2kgXQHOoVKBdqFnXQVNNpextlxfDFiLg6D6+3PTO/f6akp9KkCKAV1CtQLtQs6qKZ2euW9HlJqyPi0oa7rpV0gqQL8u/XJMmwzw0PDze97Mc//vHC+DHHHFMYnzZtWku53H777YXxc845pzBe9K741ltvbWmd6C/U60vjahPoN9RsppWxVCoeT/tpLJUYT4tM2nRKepOk4yXdb/vePHauskL4mu2TJT0m6eg0KQJoAfUKlAs1i9qYtOmMiNsleYK739LZdAC0g3oFyoWaRZ3wiUQAAABIjqYTAAAAydF0AgAAILlmJhLhJey///6F8Z122mlc7KGHHipc9vLLLy+Mn3rqqYXxq666qjB+1llnFcYn8tGPfnRcbM6cOS09B1AmXG0C6E+tjKVS8XjaT2OpxHhahD2dAAAASI6mEwAAAMnRdAIAACA5mk4AAAAkR9MJAACA5Ji93qZtt922MH7XXXeNi+25556Fyy5cuLAwftFFFxXGR0ZGCuMTfZbze97znsL4ySefXBgHqoqrTQD9qZWxVCoeTxlL+x97OgEAAJAcTScAAACSo+kEAABAcjSdAAAASI6mEwAAAMkxez2RolmlS5cuLVz2zDPPLIyvX7++pXVeeumlhfEFCxYUxrfaivccqBeuNgGUy0RXaCgaTxlL+x+/KQAAACRH0wkAAIDkaDoBAACQHE0nAAAAkpt0IpHt2ZKulDRDUkhaEhGfsr1I0imSRs+SPzcivp0q0So47rjjWooDraJep4aJf+gVanZqisZNxtL+18zs9c2SFkbEPbZfIekntm/K7/tkRFycLj0ALaJegXKhZlEbkzadEbFO0rr85422V0ualToxAK2jXoFyoWZRJy0dr7E9JOlASaMXtTvd9n22l9reYYLHLLC90vbKia5XB6DzqFegXKhZVF3TTaft7SV9Q9KZEfGMpMsl7S7pAGXv0i4pelxELImI4YgYHhwc7EDKACZDvQLlQs2iDppqOm1vo6wYvhgRV0tSRKyPiBciYoukz0qamy5NAM2iXoFyoWZRF83MXrekz0taHRGXNsRn5ueiSNI7JK1KkyKAZlGvncPVJtAN1CzqpJnZ62+SdLyk+23fm8fOlXSs7QOUXeJhjaRTk2QIoBXUK1Au1Cxqo5nZ67dLcsFdXC8M6DPUK1Au1CzqhKsNAwAAIDmaTgAAACRH0wkAAIDkaDoBAACQHE0nAAAAkqPpBAAAQHI0nQAAAEiOphMAAADJ0XQCAAAgOUdE91Zmj0h6LL+5o6QNXVt577Cd/WlORAz2Ool+Rr1WXpm2lXptAjVbaWXbzsKa7WrT+aIV2ysjYrgnK+8ithNVUJe/b122U6rXttZRXf6+bGe5cHgdAAAAydF0AgAAILleNp1LerjubmI7UQV1+fvWZTulem1rHdXl78t2lkjPzukEAABAfXB4HQAAAMl1vem0Pd/2z20/bPvsbq8/JdtLbT9le1VDbLrtm2w/lH/foZc5doLt2ba/b/sB2z+zfUYer9y2oro1S71Wb1tR3XqV6lGzVa/XrjadtgckfVrS30naR9KxtvfpZg6JLZM0f0zsbEm3RMQekm7Jb5fdZkkLI2IfSW+Q9N7871jFba21itfsMlGvVdvWWqt4vUr1qNlK12u393TOlfRwRDwaEc9L+oqkI7qcQzIRcaukp8eEj5B0Rf7zFZLe3tWkEoiIdRFxT/7zRkmrJc1SBbcV1a1Z6rV624rq1qtUj5qter12u+mcJemJhttr81iVzYiIdfnPv5Y0o5fJdJrtIUkHSrpLFd/WmqpbzVb6NUy9Vl7d6lWq8Ou4ivXKRKIuiuxSAZW5XIDt7SV9Q9KZEfFM431V21bUT9Vew9Qrqq5Kr+Oq1mu3m84nJc1uuL1LHquy9bZnSlL+/ake59MRtrdRVhBfjIir83Alt7Xm6lazlXwNU6+1Ubd6lSr4Oq5yvXa76bxb0h62d7P9MknHSLq2yzl027WSTsh/PkHSNT3MpSNsW9LnJa2OiEsb7qrctqJ2NVu51zD1Wit1q1epYq/jqtdr1y8Ob/ttki6TNCBpaUSc39UEErL9ZUnzJO0oab2kj0j6pqSvSdpV0mOSjo6IsSdCl4rtQyTdJul+SVvy8LnKzjup1LaiujVLvVKvVVTVepXqUbNVr1c+kQgAAADJMZEIAAAAydF0AgAAIDmaTgAAACRH0wkAAIDkaDoBAACQHE0nAAAAkqPpBAAAQHI0nQAAAEju/wHwvY54lmnxrAAAAABJRU5ErkJggg==\n", 902 | "text/plain": [ 903 | "
" 904 | ] 905 | }, 906 | "metadata": { 907 | "tags": [], 908 | "needs_background": "light" 909 | } 910 | } 911 | ] 912 | }, 913 | { 914 | "cell_type": "code", 915 | "metadata": { 916 | "id": "fEP5Z19emNtB", 917 | "colab_type": "code", 918 | "outputId": "8dd3d614-ea0c-48c9-d75f-62b76a48ee10", 919 | "colab": { 920 | "base_uri": "https://localhost:8080/", 921 | "height": 190 922 | } 923 | }, 924 | "source": [ 925 | "clf_t = train_SoftmaxRegression(X_raw_train_augmented.reshape((-1, 28*28)), y_train_augmented, max_iterations=5000)\n", 926 | "print(clf_t.classes_)" 927 | ], 928 | "execution_count": 0, 929 | "outputs": [ 930 | { 931 | "output_type": "stream", 932 | "text": [ 933 | "[0 1 2 3 4 5 6 7 8 9]\n" 934 | ], 935 | "name": "stdout" 936 | }, 937 | { 938 | "output_type": "stream", 939 | "text": [ 940 | "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/_logistic.py:940: ConvergenceWarning: lbfgs failed to converge (status=1):\n", 941 | "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", 942 | "\n", 943 | "Increase the number of iterations (max_iter) or scale the data as shown in:\n", 944 | " https://scikit-learn.org/stable/modules/preprocessing.html\n", 945 | "Please also refer to the documentation for alternative solver options:\n", 946 | " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", 947 | " extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n" 948 | ], 949 | "name": "stderr" 950 | } 951 | ] 952 | }, 953 | { 954 | "cell_type": "code", 955 | "metadata": { 956 | "id": "LLR9eteZpDUb", 957 | "colab_type": "code", 958 | "colab": {} 959 | }, 960 | "source": [ 961 | "probabilities_mat = classify_Digits(clf_t, X_test_2)\n", 962 | "y_pred = np.zeros((probabilities_mat.shape[0], ))\n", 963 | "for i in range(len(probabilities_mat)):\n", 964 | " y_pred[i] = take_decision(probabilities_mat[i])" 965 | ], 966 | "execution_count": 0, 967 | "outputs": [] 968 | }, 969 | { 970 | "cell_type": "code", 971 | "metadata": { 972 | "id": "iFFS42ZFolYc", 973 | "colab_type": "code", 974 | "outputId": "b4f11fa5-abc5-4add-d13c-83d3d5aee01f", 975 | "colab": { 976 | "base_uri": "https://localhost:8080/", 977 | "height": 51 978 | } 979 | }, 980 | "source": [ 981 | "print(y_pred[0:10])\n", 982 | "print(accuracy_score(y_pred, y_test))" 983 | ], 984 | "execution_count": 0, 985 | "outputs": [ 986 | { 987 | "output_type": "stream", 988 | "text": [ 989 | "[7. 2. 1. 0. 4. 1. 4. 9. 6. 
9.]\n", 990 | "0.9035\n" 991 | ], 992 | "name": "stdout" 993 | } 994 | ] 995 | }, 996 | { 997 | "cell_type": "markdown", 998 | "metadata": { 999 | "id": "2jQiiDD4mOAx", 1000 | "colab_type": "text" 1001 | }, 1002 | "source": [ 1003 | "# Testing on Computer typed hand labelled dataset" 1004 | ] 1005 | }, 1006 | { 1007 | "cell_type": "code", 1008 | "metadata": { 1009 | "id": "Ng1V1lvfM-dp", 1010 | "colab_type": "code", 1011 | "outputId": "b5f9dfba-48e9-4d05-e93c-08f1ad720159", 1012 | "colab": { 1013 | "base_uri": "https://localhost:8080/", 1014 | "height": 68 1015 | } 1016 | }, 1017 | "source": [ 1018 | "X_raw_train_2 = np.loadtxt(\"X_train.csv\", dtype=np.uint8, delimiter=' ')\n", 1019 | "y_train_2 = np.loadtxt(\"y_train.csv\", dtype=np.uint8, delimiter=' ')\n", 1020 | "print(X_raw_train_2.shape, y_train_2.shape)\n", 1021 | "\n", 1022 | "X_raw_train_2 = np.reshape(X_raw_train_2, (-1, 28, 28))\n", 1023 | "\n", 1024 | "print(\"Creating Augmented Dataset...\")\n", 1025 | "X_raw_train_augmented_2 = [image for image in X_raw_train_2]\n", 1026 | "y_train_augmented_2 = [image for image in y_train_2]\n", 1027 | "\n", 1028 | "for dx, dy in ((1,0), (-1,0), (0,1), (0,-1), (1,1), (-1,1), (-1,-1), (1,-1)):\n", 1029 | " for image, label in zip(X_raw_train_2, y_train_2):\n", 1030 | " X_raw_train_augmented_2.append(shift_image(image, dx, dy))\n", 1031 | " y_train_augmented_2.append(label)\n", 1032 | "\n", 1033 | "X_raw_train_augmented_2 = np.array(X_raw_train_augmented_2, dtype=np.uint8)\n", 1034 | "y_train_augmented_2 = np.array(y_train_augmented_2, dtype=np.uint8)\n", 1035 | "\n", 1036 | "print(X_raw_train_augmented_2.shape, y_train_augmented_2.shape)" 1037 | ], 1038 | "execution_count": 10, 1039 | "outputs": [ 1040 | { 1041 | "output_type": "stream", 1042 | "text": [ 1043 | "(313, 784) (313,)\n", 1044 | "Creating Augmented Dataset...\n", 1045 | "(2817, 28, 28) (2817,)\n" 1046 | ], 1047 | "name": "stdout" 1048 | } 1049 | ] 1050 | }, 1051 | { 1052 | "cell_type": "code", 1053 | "metadata": { 1054 | "id": "A0-_0eUPiGiS", 1055 | "colab_type": "code", 1056 | "outputId": "e4efbc8c-5ecc-493d-a2b9-102446892dc4", 1057 | "colab": { 1058 | "base_uri": "https://localhost:8080/", 1059 | "height": 34 1060 | } 1061 | }, 1062 | "source": [ 1063 | "X_train_small, X_test_small = preprocess(X_raw_train_augmented_2[0:2000], X_raw_train_augmented_2[2000:])\n", 1064 | "y_train_small, y_test_small = y_train_augmented_2[0:2000], y_train_augmented_2[2000:]\n", 1065 | "print(X_train_small.shape)" 1066 | ], 1067 | "execution_count": 11, 1068 | "outputs": [ 1069 | { 1070 | "output_type": "stream", 1071 | "text": [ 1072 | "(2000, 108)\n" 1073 | ], 1074 | "name": "stdout" 1075 | } 1076 | ] 1077 | }, 1078 | { 1079 | "cell_type": "code", 1080 | "metadata": { 1081 | "id": "YSagbjh4NRWE", 1082 | "colab_type": "code", 1083 | "outputId": "1fdb9e5a-f7b2-44d3-bdad-9adead2000fd", 1084 | "colab": { 1085 | "base_uri": "https://localhost:8080/", 1086 | "height": 51 1087 | } 1088 | }, 1089 | "source": [ 1090 | "print(X_train.shape, y_train_augmented.shape)\n", 1091 | "l = np.append(X_train, X_train_small, axis=0)\n", 1092 | "m = np.append(y_train_augmented, y_train_small)\n", 1093 | "print(l.shape, m.shape)" 1094 | ], 1095 | "execution_count": 12, 1096 | "outputs": [ 1097 | { 1098 | "output_type": "stream", 1099 | "text": [ 1100 | "(540000, 108) (540000,)\n", 1101 | "(542000, 108) (542000,)\n" 1102 | ], 1103 | "name": "stdout" 1104 | } 1105 | ] 1106 | }, 1107 | { 1108 | "cell_type": "code", 1109 | "metadata": { 1110 | "id": "qsXIApJXPOdF", 1111 | 
"colab_type": "code", 1112 | "outputId": "8e5f055c-dea5-4963-ba2f-1bdf26fa3fc2", 1113 | "colab": { 1114 | "base_uri": "https://localhost:8080/", 1115 | "height": 190 1116 | } 1117 | }, 1118 | "source": [ 1119 | "clf_sf_mod = train_SoftmaxRegression(l, m, max_iterations=5000)\n", 1120 | "print(clf_sf_mod.classes_)" 1121 | ], 1122 | "execution_count": 18, 1123 | "outputs": [ 1124 | { 1125 | "output_type": "stream", 1126 | "text": [ 1127 | "[0 1 2 3 4 5 6 7 8 9]\n" 1128 | ], 1129 | "name": "stdout" 1130 | }, 1131 | { 1132 | "output_type": "stream", 1133 | "text": [ 1134 | "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/_logistic.py:940: ConvergenceWarning: lbfgs failed to converge (status=1):\n", 1135 | "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", 1136 | "\n", 1137 | "Increase the number of iterations (max_iter) or scale the data as shown in:\n", 1138 | " https://scikit-learn.org/stable/modules/preprocessing.html\n", 1139 | "Please also refer to the documentation for alternative solver options:\n", 1140 | " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", 1141 | " extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n" 1142 | ], 1143 | "name": "stderr" 1144 | } 1145 | ] 1146 | }, 1147 | { 1148 | "cell_type": "code", 1149 | "metadata": { 1150 | "id": "Y0H0waJTslVK", 1151 | "colab_type": "code", 1152 | "colab": {} 1153 | }, 1154 | "source": [ 1155 | "probabilities_mat = classify_Digits(clf_sf_mod, X_test_small)\n", 1156 | "y_pred = np.zeros((probabilities_mat.shape[0], ), dtype=np.uint8)\n", 1157 | "for i in range(len(probabilities_mat)):\n", 1158 | " y_pred[i] = take_decision(probabilities_mat[i], 10, 15)" 1159 | ], 1160 | "execution_count": 0, 1161 | "outputs": [] 1162 | }, 1163 | { 1164 | "cell_type": "code", 1165 | "metadata": { 1166 | "id": "Antu9iFVjuus", 1167 | "colab_type": "code", 1168 | "outputId": "a36a0a0c-3437-479e-fa51-2ba1eae06b89", 1169 | "colab": { 1170 | "base_uri": "https://localhost:8080/", 1171 | "height": 68 1172 | } 1173 | }, 1174 | "source": [ 1175 | "print(accuracy_score(y_pred, y_test_small))\n", 1176 | "print(y_pred[0:30])\n", 1177 | "print(y_test_small[0:30])" 1178 | ], 1179 | "execution_count": 21, 1180 | "outputs": [ 1181 | { 1182 | "output_type": "stream", 1183 | "text": [ 1184 | "0.8727050183598531\n", 1185 | "[1 6 7 4 5 7 1 7 4 0 5 5 4 2 8 3 4 8 1 9 3 7 1 8 8 8 5 3 3 7]\n", 1186 | "[1 6 7 4 5 7 1 7 4 9 5 5 4 2 8 3 4 8 1 9 3 7 1 6 6 8 5 8 6 7]\n" 1187 | ], 1188 | "name": "stdout" 1189 | } 1190 | ] 1191 | }, 1192 | { 1193 | "cell_type": "code", 1194 | "metadata": { 1195 | "id": "WSdQyfJvQ_mK", 1196 | "colab_type": "code", 1197 | "outputId": "58275192-8830-4826-92e2-39e9dd62e57a", 1198 | "colab": { 1199 | "base_uri": "https://localhost:8080/", 1200 | "height": 34 1201 | } 1202 | }, 1203 | "source": [ 1204 | "# Save model\n", 1205 | "joblib.dump(clf_sf_mod, 'Softmax.pkl')" 1206 | ], 1207 | "execution_count": 0, 1208 | "outputs": [ 1209 | { 1210 | "output_type": "execute_result", 1211 | "data": { 1212 | "text/plain": [ 1213 | "['Softmax.pkl']" 1214 | ] 1215 | }, 1216 | "metadata": { 1217 | "tags": [] 1218 | }, 1219 | "execution_count": 105 1220 | } 1221 | ] 1222 | }, 1223 | { 1224 | "cell_type": "code", 1225 | "metadata": { 1226 | "id": "9zA1QnspPL8H", 1227 | "colab_type": "code", 1228 | "colab": {} 1229 | }, 1230 | "source": [ 1231 | "clf_rf = train_Random_Forest(l, m)\n", 1232 | "print(clf_rf.classes_)" 1233 | ], 1234 | "execution_count": 0, 1235 | "outputs": [] 1236 | }, 1237 | { 1238 | "cell_type": "code", 1239 | 
"metadata": { 1240 | "id": "nLnCkxd7Prve", 1241 | "colab_type": "code", 1242 | "colab": {} 1243 | }, 1244 | "source": [ 1245 | "probabilities_mat = classify_Digits(clf_rf, X_test_small)\n", 1246 | "y_pred = np.zeros((probabilities_mat.shape[0], ), dtype=np.uint8)\n", 1247 | "for i in range(len(probabilities_mat)):\n", 1248 | " y_pred[i] = take_decision(probabilities_mat[i], 10, 15)" 1249 | ], 1250 | "execution_count": 0, 1251 | "outputs": [] 1252 | }, 1253 | { 1254 | "cell_type": "code", 1255 | "metadata": { 1256 | "id": "qriz8mK2mMIE", 1257 | "colab_type": "code", 1258 | "outputId": "986014d7-3201-447a-9ea6-fb172826f7c3", 1259 | "colab": { 1260 | "base_uri": "https://localhost:8080/" 1261 | } 1262 | }, 1263 | "source": [ 1264 | "print(accuracy_score(y_pred, y_test_small))\n", 1265 | "print(y_pred[0:10])\n", 1266 | "print(y_test[0:10])" 1267 | ], 1268 | "execution_count": 0, 1269 | "outputs": [ 1270 | { 1271 | "output_type": "stream", 1272 | "text": [ 1273 | "0.0\n", 1274 | "[0 0 0 0 0 0 0 0 0 0]\n", 1275 | "[7 2 1 0 4 1 4 9 5 9]\n" 1276 | ], 1277 | "name": "stdout" 1278 | } 1279 | ] 1280 | }, 1281 | { 1282 | "cell_type": "code", 1283 | "metadata": { 1284 | "id": "pRT2a8G7Ne-a", 1285 | "colab_type": "code", 1286 | "outputId": "88df8141-9627-4b84-e316-a4b45ce6997a", 1287 | "colab": { 1288 | "base_uri": "https://localhost:8080/", 1289 | "height": 68 1290 | } 1291 | }, 1292 | "source": [ 1293 | "print(accuracy_score(y_pred, y_test_small))\n", 1294 | "print(y_pred[0:10])\n", 1295 | "print(y_test[0:10])" 1296 | ], 1297 | "execution_count": 0, 1298 | "outputs": [ 1299 | { 1300 | "output_type": "stream", 1301 | "text": [ 1302 | "0.0\n", 1303 | "[0 0 0 0 0 0 0 0 0 0]\n", 1304 | "[7 2 1 0 4 1 4 9 5 9]\n" 1305 | ], 1306 | "name": "stdout" 1307 | } 1308 | ] 1309 | }, 1310 | { 1311 | "cell_type": "code", 1312 | "metadata": { 1313 | "id": "VNv1VMGuPun3", 1314 | "colab_type": "code", 1315 | "colab": {} 1316 | }, 1317 | "source": [ 1318 | "clf_xg = train_XGBoostSoftmax(l, m, max_depth=8, n_estimators=100)" 1319 | ], 1320 | "execution_count": 0, 1321 | "outputs": [] 1322 | }, 1323 | { 1324 | "cell_type": "code", 1325 | "metadata": { 1326 | "id": "mWcpxwertWGa", 1327 | "colab_type": "code", 1328 | "colab": {} 1329 | }, 1330 | "source": [ 1331 | "probabilities_mat = classify_Digits(clf_xg, X_test_small)\n", 1332 | "y_pred = np.zeros((probabilities_mat.shape[0], ), dtype=np.uint8)\n", 1333 | "for i in range(len(probabilities_mat)):\n", 1334 | " y_pred[i] = take_decision(probabilities_mat[i], 10, 10)" 1335 | ], 1336 | "execution_count": 0, 1337 | "outputs": [] 1338 | }, 1339 | { 1340 | "cell_type": "code", 1341 | "metadata": { 1342 | "id": "4KvlRwketJwl", 1343 | "colab_type": "code", 1344 | "colab": { 1345 | "base_uri": "https://localhost:8080/", 1346 | "height": 68 1347 | }, 1348 | "outputId": "ae152d8f-4d86-43a1-a2e8-2b50727bf1ef" 1349 | }, 1350 | "source": [ 1351 | "print(accuracy_score(y_pred, y_test_small))\n", 1352 | "print(y_pred[0:30])\n", 1353 | "print(y_test_small[0:30])" 1354 | ], 1355 | "execution_count": 32, 1356 | "outputs": [ 1357 | { 1358 | "output_type": "stream", 1359 | "text": [ 1360 | "0.8580171358629131\n", 1361 | "[1 5 7 8 5 7 1 7 6 0 5 5 2 2 8 3 4 8 1 9 3 7 1 8 6 8 5 3 3 7]\n", 1362 | "[1 6 7 4 5 7 1 7 4 9 5 5 4 2 8 3 4 8 1 9 3 7 1 6 6 8 5 8 6 7]\n" 1363 | ], 1364 | "name": "stdout" 1365 | } 1366 | ] 1367 | }, 1368 | { 1369 | "cell_type": "code", 1370 | "metadata": { 1371 | "id": "oWXJEs-8tYW3", 1372 | "colab_type": "code", 1373 | "colab": {} 1374 | }, 1375 | "source": [ 1376 | "" 1377 
| ], 1378 | "execution_count": 0, 1379 | "outputs": [] 1380 | } 1381 | ] 1382 | } -------------------------------------------------------------------------------- /Testing/README.md: -------------------------------------------------------------------------------- 1 | ## My testing attempt for parsing the puzzle grid 2 | -------------------------------------------------------------------------------- /Testing/y_train.csv: -------------------------------------------------------------------------------- 1 | 1.000000000000000000e+00 2 | 8.000000000000000000e+00 3 | 9.000000000000000000e+00 4 | 3.000000000000000000e+00 5 | 4.000000000000000000e+00 6 | 2.000000000000000000e+00 7 | 5.000000000000000000e+00 8 | 3.000000000000000000e+00 9 | 9.000000000000000000e+00 10 | 6.000000000000000000e+00 11 | 4.000000000000000000e+00 12 | 3.000000000000000000e+00 13 | 7.000000000000000000e+00 14 | 1.000000000000000000e+00 15 | 9.000000000000000000e+00 16 | 5.000000000000000000e+00 17 | 6.000000000000000000e+00 18 | 7.000000000000000000e+00 19 | 7.000000000000000000e+00 20 | 8.000000000000000000e+00 21 | 5.000000000000000000e+00 22 | 6.000000000000000000e+00 23 | 1.000000000000000000e+00 24 | 4.000000000000000000e+00 25 | 8.000000000000000000e+00 26 | 9.000000000000000000e+00 27 | 1.000000000000000000e+00 28 | 4.000000000000000000e+00 29 | 3.000000000000000000e+00 30 | 8.000000000000000000e+00 31 | 4.000000000000000000e+00 32 | 8.000000000000000000e+00 33 | 5.000000000000000000e+00 34 | 3.000000000000000000e+00 35 | 6.000000000000000000e+00 36 | 2.000000000000000000e+00 37 | 2.000000000000000000e+00 38 | 5.000000000000000000e+00 39 | 7.000000000000000000e+00 40 | 7.000000000000000000e+00 41 | 8.000000000000000000e+00 42 | 8.000000000000000000e+00 43 | 5.000000000000000000e+00 44 | 3.000000000000000000e+00 45 | 6.000000000000000000e+00 46 | 6.000000000000000000e+00 47 | 2.000000000000000000e+00 48 | 1.000000000000000000e+00 49 | 9.000000000000000000e+00 50 | 1.000000000000000000e+00 51 | 4.000000000000000000e+00 52 | 3.000000000000000000e+00 53 | 8.000000000000000000e+00 54 | 4.000000000000000000e+00 55 | 8.000000000000000000e+00 56 | 5.000000000000000000e+00 57 | 3.000000000000000000e+00 58 | 6.000000000000000000e+00 59 | 2.000000000000000000e+00 60 | 2.000000000000000000e+00 61 | 5.000000000000000000e+00 62 | 7.000000000000000000e+00 63 | 7.000000000000000000e+00 64 | 8.000000000000000000e+00 65 | 8.000000000000000000e+00 66 | 5.000000000000000000e+00 67 | 3.000000000000000000e+00 68 | 6.000000000000000000e+00 69 | 6.000000000000000000e+00 70 | 2.000000000000000000e+00 71 | 1.000000000000000000e+00 72 | 9.000000000000000000e+00 73 | 1.000000000000000000e+00 74 | 4.000000000000000000e+00 75 | 3.000000000000000000e+00 76 | 8.000000000000000000e+00 77 | 4.000000000000000000e+00 78 | 8.000000000000000000e+00 79 | 5.000000000000000000e+00 80 | 3.000000000000000000e+00 81 | 6.000000000000000000e+00 82 | 2.000000000000000000e+00 83 | 2.000000000000000000e+00 84 | 5.000000000000000000e+00 85 | 7.000000000000000000e+00 86 | 7.000000000000000000e+00 87 | 8.000000000000000000e+00 88 | 8.000000000000000000e+00 89 | 5.000000000000000000e+00 90 | 3.000000000000000000e+00 91 | 6.000000000000000000e+00 92 | 6.000000000000000000e+00 93 | 2.000000000000000000e+00 94 | 1.000000000000000000e+00 95 | 9.000000000000000000e+00 96 | 5.000000000000000000e+00 97 | 3.000000000000000000e+00 98 | 1.000000000000000000e+00 99 | 7.000000000000000000e+00 100 | 1.000000000000000000e+00 101 | 4.000000000000000000e+00 102 | 
5.000000000000000000e+00 103 | 2.000000000000000000e+00 104 | 9.000000000000000000e+00 105 | 6.000000000000000000e+00 106 | 3.000000000000000000e+00 107 | 8.000000000000000000e+00 108 | 5.000000000000000000e+00 109 | 1.000000000000000000e+00 110 | 5.000000000000000000e+00 111 | 7.000000000000000000e+00 112 | 9.000000000000000000e+00 113 | 3.000000000000000000e+00 114 | 7.000000000000000000e+00 115 | 5.000000000000000000e+00 116 | 8.000000000000000000e+00 117 | 2.000000000000000000e+00 118 | 8.000000000000000000e+00 119 | 6.000000000000000000e+00 120 | 3.000000000000000000e+00 121 | 5.000000000000000000e+00 122 | 2.000000000000000000e+00 123 | 1.000000000000000000e+00 124 | 6.000000000000000000e+00 125 | 7.000000000000000000e+00 126 | 4.000000000000000000e+00 127 | 5.000000000000000000e+00 128 | 7.000000000000000000e+00 129 | 1.000000000000000000e+00 130 | 7.000000000000000000e+00 131 | 4.000000000000000000e+00 132 | 9.000000000000000000e+00 133 | 5.000000000000000000e+00 134 | 5.000000000000000000e+00 135 | 4.000000000000000000e+00 136 | 2.000000000000000000e+00 137 | 8.000000000000000000e+00 138 | 3.000000000000000000e+00 139 | 4.000000000000000000e+00 140 | 8.000000000000000000e+00 141 | 1.000000000000000000e+00 142 | 9.000000000000000000e+00 143 | 3.000000000000000000e+00 144 | 7.000000000000000000e+00 145 | 1.000000000000000000e+00 146 | 6.000000000000000000e+00 147 | 6.000000000000000000e+00 148 | 8.000000000000000000e+00 149 | 5.000000000000000000e+00 150 | 8.000000000000000000e+00 151 | 6.000000000000000000e+00 152 | 7.000000000000000000e+00 153 | 4.000000000000000000e+00 154 | 5.000000000000000000e+00 155 | 1.000000000000000000e+00 156 | 7.000000000000000000e+00 157 | 8.000000000000000000e+00 158 | 2.000000000000000000e+00 159 | 7.000000000000000000e+00 160 | 3.000000000000000000e+00 161 | 4.000000000000000000e+00 162 | 7.000000000000000000e+00 163 | 2.000000000000000000e+00 164 | 8.000000000000000000e+00 165 | 9.000000000000000000e+00 166 | 3.000000000000000000e+00 167 | 4.000000000000000000e+00 168 | 7.000000000000000000e+00 169 | 8.000000000000000000e+00 170 | 9.000000000000000000e+00 171 | 9.000000000000000000e+00 172 | 4.000000000000000000e+00 173 | 1.000000000000000000e+00 174 | 5.000000000000000000e+00 175 | 9.000000000000000000e+00 176 | 2.000000000000000000e+00 177 | 5.000000000000000000e+00 178 | 1.000000000000000000e+00 179 | 9.000000000000000000e+00 180 | 8.000000000000000000e+00 181 | 9.000000000000000000e+00 182 | 1.000000000000000000e+00 183 | 3.000000000000000000e+00 184 | 9.000000000000000000e+00 185 | 5.000000000000000000e+00 186 | 6.000000000000000000e+00 187 | 8.000000000000000000e+00 188 | 5.000000000000000000e+00 189 | 7.000000000000000000e+00 190 | 6.000000000000000000e+00 191 | 6.000000000000000000e+00 192 | 1.000000000000000000e+00 193 | 3.000000000000000000e+00 194 | 2.000000000000000000e+00 195 | 5.000000000000000000e+00 196 | 4.000000000000000000e+00 197 | 9.000000000000000000e+00 198 | 4.000000000000000000e+00 199 | 3.000000000000000000e+00 200 | 2.000000000000000000e+00 201 | 3.000000000000000000e+00 202 | 8.000000000000000000e+00 203 | 6.000000000000000000e+00 204 | 2.000000000000000000e+00 205 | 5.000000000000000000e+00 206 | 4.000000000000000000e+00 207 | 8.000000000000000000e+00 208 | 9.000000000000000000e+00 209 | 4.000000000000000000e+00 210 | 7.000000000000000000e+00 211 | 2.000000000000000000e+00 212 | 9.000000000000000000e+00 213 | 7.000000000000000000e+00 214 | 4.000000000000000000e+00 215 | 1.000000000000000000e+00 216 | 
6.000000000000000000e+00 217 | 9.000000000000000000e+00 218 | 4.000000000000000000e+00 219 | 8.000000000000000000e+00 220 | 6.000000000000000000e+00 221 | 8.000000000000000000e+00 222 | 5.000000000000000000e+00 223 | 7.000000000000000000e+00 224 | 5.000000000000000000e+00 225 | 9.000000000000000000e+00 226 | 1.000000000000000000e+00 227 | 2.000000000000000000e+00 228 | 3.000000000000000000e+00 229 | 5.000000000000000000e+00 230 | 2.000000000000000000e+00 231 | 8.000000000000000000e+00 232 | 6.000000000000000000e+00 233 | 1.000000000000000000e+00 234 | 3.000000000000000000e+00 235 | 8.000000000000000000e+00 236 | 4.000000000000000000e+00 237 | 5.000000000000000000e+00 238 | 6.000000000000000000e+00 239 | 2.000000000000000000e+00 240 | 3.000000000000000000e+00 241 | 1.000000000000000000e+00 242 | 4.000000000000000000e+00 243 | 7.000000000000000000e+00 244 | 4.000000000000000000e+00 245 | 5.000000000000000000e+00 246 | 6.000000000000000000e+00 247 | 2.000000000000000000e+00 248 | 1.000000000000000000e+00 249 | 3.000000000000000000e+00 250 | 9.000000000000000000e+00 251 | 4.000000000000000000e+00 252 | 1.000000000000000000e+00 253 | 3.000000000000000000e+00 254 | 8.000000000000000000e+00 255 | 5.000000000000000000e+00 256 | 1.000000000000000000e+00 257 | 2.000000000000000000e+00 258 | 4.000000000000000000e+00 259 | 7.000000000000000000e+00 260 | 3.000000000000000000e+00 261 | 1.000000000000000000e+00 262 | 2.000000000000000000e+00 263 | 6.000000000000000000e+00 264 | 8.000000000000000000e+00 265 | 4.000000000000000000e+00 266 | 7.000000000000000000e+00 267 | 1.000000000000000000e+00 268 | 6.000000000000000000e+00 269 | 5.000000000000000000e+00 270 | 3.000000000000000000e+00 271 | 4.000000000000000000e+00 272 | 3.000000000000000000e+00 273 | 2.000000000000000000e+00 274 | 5.000000000000000000e+00 275 | 1.000000000000000000e+00 276 | 7.000000000000000000e+00 277 | 5.000000000000000000e+00 278 | 9.000000000000000000e+00 279 | 5.000000000000000000e+00 280 | 3.000000000000000000e+00 281 | 7.000000000000000000e+00 282 | 2.000000000000000000e+00 283 | 8.000000000000000000e+00 284 | 3.000000000000000000e+00 285 | 6.000000000000000000e+00 286 | 9.000000000000000000e+00 287 | 4.000000000000000000e+00 288 | 2.000000000000000000e+00 289 | 3.000000000000000000e+00 290 | 7.000000000000000000e+00 291 | 1.000000000000000000e+00 292 | 5.000000000000000000e+00 293 | 3.000000000000000000e+00 294 | 3.000000000000000000e+00 295 | 3.000000000000000000e+00 296 | 1.000000000000000000e+00 297 | 6.000000000000000000e+00 298 | 9.000000000000000000e+00 299 | 3.000000000000000000e+00 300 | 6.000000000000000000e+00 301 | 8.000000000000000000e+00 302 | 2.000000000000000000e+00 303 | 2.000000000000000000e+00 304 | 7.000000000000000000e+00 305 | 8.000000000000000000e+00 306 | 1.000000000000000000e+00 307 | 1.000000000000000000e+00 308 | 3.000000000000000000e+00 309 | 6.000000000000000000e+00 310 | 5.000000000000000000e+00 311 | 8.000000000000000000e+00 312 | 3.000000000000000000e+00 313 | 9.000000000000000000e+00 314 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vaithak/Sudoku-Image-Solver/5d01d6e6c3921251c1f38050aa3e3e85610d66fc/__init__.py -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import cv2 3 | import 
numpy as np 4 | from PIL import Image 5 | from io import StringIO, BytesIO 6 | import base64 7 | import matplotlib.pyplot as plt 8 | import matplotlib.patches 9 | from PuzzleExtractor.digit_extraction import extractDigits 10 | from PuzzleExtractor.processing import processImage 11 | from PuzzleExtractor.grid_extraction import extractGrid 12 | import GridSolver.SudokuSolve as Solver 13 | from DigitsRecogniser.recognise_digits import predictDigits 14 | 15 | def _max_width_(): 16 | max_width_str = f"max-width: 2000px;" 17 | st.markdown( 18 | f""" 19 | <style> 20 | .reportview-container .main .block-container{{ 21 | {max_width_str} 22 | }} 23 | </style> 24 | """, 25 | unsafe_allow_html=True, 26 | ) 27 | 28 | def PIL_image_to_bytes(img_obj: Image.Image): 29 | buf = BytesIO() 30 | img_obj.save(buf, format='JPEG') 31 | img_bytes = buf.getvalue() 32 | encoded = base64.b64encode(img_bytes).decode() 33 | return encoded 34 | 35 | 36 | def show_image_with_pyplot(img, caption, grayscale=True, ax=None): 37 | if(len(img.shape) == 3): img_plt = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 38 | else : img_plt = img 39 | 40 | if ax is None: 41 | fig = plt.figure(figsize=(14, 10)) 42 | ax = plt.axes() 43 | 44 | ax.set_title(caption) 45 | ax.set_xticks([]) 46 | ax.set_yticks([]) 47 | 48 | if grayscale: ax.imshow(img_plt, cmap='gray', vmin=0, vmax=255) 49 | else: ax.imshow(img_plt, vmin=0, vmax=255) 50 | 51 | 52 | # Plots the sudoku passed as a numpy array (n) on the given pyplot axis (ax) 53 | def plot_sudoku(n: np.ndarray, caption:str, ax): 54 | ax.set_xticks([]) 55 | ax.set_yticks([]) 56 | ax.set_title(caption) 57 | 58 | # Simple plotting statement that ingests a 9x9 array (n), and plots a sudoku-style grid around it. 59 | for y in range(10): 60 | ax.plot([-0.05,9.05],[y,y],color='black',linewidth=1) 61 | 62 | for y in range(0,10,3): 63 | ax.plot([-0.05,9.05],[y,y],color='black',linewidth=3) 64 | 65 | for x in range(10): 66 | ax.plot([x,x],[-0.05,9.05],color='black',linewidth=1) 67 | 68 | for x in range(0,10,3): 69 | ax.plot([x,x],[-0.05,9.05],color='black',linewidth=3) 70 | 71 | # plt.axis('image') 72 | # plt.axis('off') # drop the axes, they're not important here 73 | 74 | for x in range(9): 75 | for y in range(9): 76 | foo=n[8-y][x] # need to reverse the y-direction for plotting 77 | if foo > 0: # ignore the zeros 78 | T=str(foo) 79 | ax.text(x+0.3,y+0.2,T,fontsize=20) 80 | 81 | 82 | # Assumes digits in column major order 83 | def get_extracted_digits_img(digits: list, rows: int, cols: int) -> np.ndarray: 84 | combined = np.array([]) 85 | for i in range(cols): 86 | bordered_digit = cv2.copyMakeBorder(digits[i*rows],2,2,2,2,cv2.BORDER_CONSTANT,value=255) 87 | col_combined = bordered_digit 88 | for j in range(1, rows): 89 | bordered_digit = cv2.copyMakeBorder(digits[i*rows + j],2,2,2,2,cv2.BORDER_CONSTANT,value=255) 90 | col_combined = np.vstack((col_combined, bordered_digit)) 91 | 92 | if i == 0: 93 | combined = col_combined 94 | else: 95 | combined = np.hstack((combined, col_combined)) 96 | 97 | return combined 98 | 99 | 100 | def get_np_array_row_major(inp: str, rows: int, cols: int) -> np.ndarray: 101 | inp_array = np.fromstring(inp, dtype=np.int8, sep='') - 48 102 | return np.reshape(inp_array, (rows, cols)) 103 | 104 | 105 | def get_patches(ax1, ax2, ax3, ax4): 106 | limits_x = [axis.get_xlim() for axis in [ax1, ax2, ax3, ax4]] 107 | limits_y = [axis.get_ylim() for axis in [ax1, ax2, ax3, ax4]] 108 | patch_type = "-|>" 109 | 110 | patch1 = matplotlib.patches.ConnectionPatch( 111 | xyA=((limits_x[0][1]+limits_x[0][0])//2, limits_y[0][0]), 112 | xyB=((limits_x[1][1]+limits_x[1][0])//2, limits_y[1][1]), 113 | coordsA="data", 
114 | coordsB="data", 115 | axesA=ax1, 116 | axesB=ax2, 117 | arrowstyle=patch_type, 118 | color="green", 119 | shrinkA=10, 120 | mutation_scale=60, 121 | linewidth=10, 122 | alpha=0.8, 123 | clip_on=False, 124 | ) 125 | 126 | patch2 = matplotlib.patches.ConnectionPatch( 127 | xyA=(limits_x[1][1], (limits_y[1][0]+limits_y[1][1])//2), 128 | xyB=(limits_x[2][0], (limits_y[2][0]+limits_y[2][1])//2), 129 | coordsA="data", 130 | coordsB="data", 131 | axesA=ax2, 132 | axesB=ax3, 133 | arrowstyle=patch_type, 134 | color="green", 135 | shrinkA=10, 136 | mutation_scale=60, 137 | linewidth=10, 138 | alpha=0.8, 139 | clip_on=False, 140 | ) 141 | 142 | patch3 = matplotlib.patches.ConnectionPatch( 143 | xyA=((limits_x[2][0]+limits_x[2][1])//2, limits_y[2][1]), 144 | xyB=((limits_x[3][0]+limits_x[3][1])//2, limits_y[3][0]), 145 | coordsA="data", 146 | coordsB="data", 147 | axesA=ax3, 148 | axesB=ax4, 149 | arrowstyle=patch_type, 150 | color="green", 151 | shrinkA=10, 152 | mutation_scale=60, 153 | linewidth=10, 154 | alpha=0.8, 155 | clip_on=False, 156 | ) 157 | 158 | return (patch1, patch2, patch3) 159 | 160 | 161 | 162 | def get_matplotlib_figure(orig_img_array: np.ndarray, extracted_digits: np.ndarray, recognised_digits: np.ndarray, solved_grid: np.ndarray) -> plt.figure: 163 | fig, axes = plt.subplots(2,2) 164 | fig.set_size_inches(15,15) 165 | 166 | # Plotting original image 167 | caption="Original Image" 168 | show_image_with_pyplot(orig_img_array, caption, grayscale=False, ax=axes[0,0]) 169 | 170 | # Plotting extracted digits 171 | caption="Extracted digits from Image" 172 | digits_img = get_extracted_digits_img(extracted_digits, 9, 9) 173 | show_image_with_pyplot(digits_img, caption, grayscale=True, ax=axes[1, 0]) 174 | 175 | # Plotting recognised digits 176 | caption="Recognised digits using the selected model" 177 | plot_sudoku(recognised_digits, caption, axes[1,1]) 178 | 179 | if solved_grid is not None: 180 | # Plotting solved_grid 181 | fig.suptitle("Successfully solved the sudoku") 182 | caption = "Solved sudoku using Backtracking" 183 | plot_sudoku(solved_grid, caption, axes[0,1]) 184 | else: 185 | fig.suptitle("Error !!! 
The sudoku recognised is invalid and cannot be solved.") 186 | 187 | patches1, patches2, patches3 = get_patches(axes[0,0], axes[1,0], axes[1,1], axes[0,1]) 188 | axes[0,0].add_artist(patches1) 189 | axes[1,0].add_artist(patches2) 190 | axes[1,1].add_artist(patches3) 191 | 192 | return fig 193 | 194 | 195 | #def display_PIL_img(img_obj: Image.Image, caption: str, width_str: str): 196 | # st.write(caption) 197 | # header_html = "".format(PIL_image_to_bytes(img_obj)) 198 | # st.markdown(header_html, unsafe_allow_html=True) 199 | 200 | 201 | def str_to_np_arr(inp_str): 202 | try: 203 | res = np.zeros((9,9), dtype=np.int8) 204 | rows = inp_str[1:-1].split('\n ') 205 | #print(rows) 206 | for i in range(9): 207 | curr_row = rows[i][1:-1].split(' ') 208 | for j in range(9): 209 | res[i][j] = int(curr_row[j]) 210 | if(res[i][j] > 9 or res[i][j] < 0): 211 | return None 212 | 213 | return res 214 | except: 215 | return None 216 | 217 | 218 | def operate_on_image(img_array, model): 219 | binary_img_array = processImage(img_array) 220 | status, grid_array = extractGrid(binary_img_array) 221 | if status == False: 222 | st.write("The Sudoku grid cannot be extracted from the image") 223 | else: 224 | extracted_digits = extractDigits(grid_array) 225 | recognised_digits_str = predictDigits(extracted_digits, model) 226 | recognised_digits = get_np_array_row_major(recognised_digits_str, 9, 9) 227 | grid_text = st.sidebar.text_area("Modify the parsed sudoku grid to correct the parser", value=recognised_digits, max_chars=200) 228 | 229 | if grid_text != str(recognised_digits): 230 | grid_text_arr = str_to_np_arr(grid_text) 231 | if grid_text_arr is not None: 232 | recognised_digits = grid_text_arr 233 | recognised_digits_str = ''.join(c for c in grid_text if c not in '[]\n ') 234 | else: 235 | st.sidebar.markdown("Invalid grid entered") 236 | 237 | solver = Solver.Sudoku(recognised_digits_str) 238 | 239 | if((not solver.verifyGridStatus()) or (not solver.solveGrid())): 240 | st.write("The puzzle extracted is invalid") 241 | figure = get_matplotlib_figure(img_array, extracted_digits, recognised_digits, None) 242 | else: 243 | solved_grid_str = solver.getGrid() 244 | solved_grid = get_np_array_row_major(solved_grid_str, 9, 9) 245 | figure = get_matplotlib_figure(img_array, extracted_digits, recognised_digits, solved_grid) 246 | 247 | st.pyplot(figure) 248 | 249 | 250 | 251 | @st.cache(hash_funcs={StringIO: StringIO.getvalue}, suppress_st_warning=True) 252 | def read_example_image(name): 253 | img = cv2.imread("ExampleImages/" + name + ".jpg") 254 | return img 255 | 256 | 257 | _max_width_() 258 | 259 | st.title("Sudoku Image Solver") 260 | st.sidebar.markdown("

Sudoku Image Solver
", unsafe_allow_html=True) 261 | 262 | """ 263 | --- 264 | 265 | Tools and libraries used: 266 | OpenCV-Python : For using image processing algorithms for Grid Extraction and Digits Extraction. 267 | Keras, Scikit-Learn : For implementing ML models for recognizing digits extracted from the image. 268 | I have also used Swig for creating Python bindings of C++ code for solving the puzzle recognized from the image. 269 | 270 | Author: [Vaibhav Thakkar](https://github.com/vaithak) 271 | You can find the code [here](https://github.com/vaithak/SudokuImageSolver) 272 | 273 | --- 274 | """ 275 | 276 | 277 | images_list = ["image" + str(i) for i in range(1,8)] 278 | models = ["Convolutional Neural Networks", "XGBoost Algorithm", "Softmax Regression"] 279 | 280 | 281 | option = st.sidebar.radio('Please choose any of the following options',('Choose image from examples','Upload your own image')) 282 | model = st.sidebar.radio("Model for recognizing digits", models, 0) 283 | 284 | input_image = None 285 | if option == "Choose image from examples": 286 | example_image_name = st.sidebar.selectbox("Select an example image", images_list) 287 | input_image = read_example_image(example_image_name) 288 | else: 289 | uploaded_file = st.file_uploader("Upload your own image (supported types: jpg, jpeg, png)...", type=["jpg","jpeg","png"]) 290 | if uploaded_file is not None: 291 | img = Image.open(uploaded_file) 292 | input_image = np.array(img) 293 | input_image = input_image[:,:,::-1] 294 | 295 | 296 | if input_image is not None: 297 | operate_on_image(input_image, models.index(model)) 298 | 299 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | altair==4.1.0 3 | astor==0.8.1 4 | astunparse==1.6.3 5 | attrs==19.3.0 6 | backcall==0.1.0 7 | base58==2.0.1 8 | bleach==3.1.5 9 | blinker==1.4 10 | boto3==1.13.24 11 | botocore==1.16.24 12 | cachetools==4.1.0 13 | certifi==2020.4.5.1 14 | chardet==3.0.4 15 | click==7.1.2 16 | cycler==0.10.0 17 | decorator==4.4.2 18 | defusedxml==0.6.0 19 | docutils==0.15.2 20 | entrypoints==0.3 21 | enum-compat==0.0.3 22 | gast==0.3.3 23 | google-auth==1.17.1 24 | google-auth-oauthlib==0.4.1 25 | google-pasta==0.2.0 26 | grpcio==1.29.0 27 | h5py==2.10.0 28 | idna==2.9 29 | ipykernel==5.3.0 30 | ipython==7.15.0 31 | ipython-genutils==0.2.0 32 | ipywidgets==7.5.1 33 | jedi==0.17.0 34 | Jinja2==2.11.2 35 | jmespath==0.10.0 36 | joblib==0.15.1 37 | jsonschema==3.2.0 38 | jupyter-client==6.1.3 39 | jupyter-core==4.6.3 40 | Keras==2.3.1 41 | Keras-Applications==1.0.8 42 | Keras-Preprocessing==1.1.2 43 | kiwisolver==1.2.0 44 | Markdown==3.2.2 45 | MarkupSafe==1.1.1 46 | matplotlib==3.2.1 47 | mistune==0.8.4 48 | nbconvert==5.6.1 49 | nbformat==5.0.6 50 | notebook==6.0.3 51 | numpy==1.18.5 52 | oauthlib==3.1.0 53 | opencv-python==4.2.0.34 54 | opt-einsum==3.2.1 55 | packaging==20.4 56 | pandas==1.0.4 57 | pandocfilters==1.4.2 58 | parso==0.7.0 59 | pathtools==0.1.2 60 | pexpect==4.8.0 61 | pickleshare==0.7.5 62 | Pillow==7.1.2 63 | prometheus-client==0.8.0 64 | prompt-toolkit==3.0.5 65 | protobuf==3.12.2 66 | ptyprocess==0.6.0 67 | pyasn1==0.4.8 68 | pyasn1-modules==0.2.8 69 | pydeck==0.4.0b1 70 | Pygments==2.6.1 71 | pyparsing==2.4.7 72 | pyrsistent==0.16.0 73 | python-dateutil==2.8.1 74 | pytz==2020.1 75 | PyYAML==5.3.1 76 | pyzmq==19.0.1 77 | requests==2.23.0 78 | requests-oauthlib==1.3.0 79 | rsa==4.2 80 | s3transfer==0.3.3 81 | 
scikit-learn==0.22.2.post1 82 | scipy==1.4.1 83 | Send2Trash==1.5.0 84 | six==1.15.0 85 | streamlit==0.61.0 86 | tensorboard==2.2.2 87 | tensorboard-plugin-wit==1.6.0.post3 88 | tensorflow==2.2.0 89 | tensorflow-estimator==2.2.0 90 | termcolor==1.1.0 91 | terminado==0.8.3 92 | testpath==0.4.4 93 | threadpoolctl==2.1.0 94 | toml==0.10.1 95 | toolz==0.10.0 96 | tornado==5.1.1 97 | traitlets==4.3.3 98 | tzlocal==2.1 99 | urllib3==1.25.9 100 | validators==0.15.0 101 | watchdog==0.10.2 102 | wcwidth==0.2.3 103 | webencodings==0.5.1 104 | Werkzeug==1.0.1 105 | widgetsnbextension==3.5.1 106 | wrapt==1.12.1 107 | xgboost==0.90 108 | --------------------------------------------------------------------------------