├── HISTORY.md ├── requirements ├── tests.txt ├── development.txt └── production.txt ├── pehchaan ├── config │ ├── base.py │ ├── __init__.py │ └── app.py ├── data │ ├── encoders │ │ └── DHCD-LE.pkl │ └── models │ │ └── DHCD-SVC-96.61,10-96.30,0.60.pkl ├── __init__.py ├── _util │ ├── __init__.py │ ├── const.py │ └── _util.py ├── __main__.py └── app.py ├── requirements.txt ├── .gitignore ├── README.md ├── Makefile ├── package.py ├── LICENSE ├── setup.py └── notebooks └── Devanagari Character Recognition.ipynb /HISTORY.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements/tests.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | -------------------------------------------------------------------------------- /requirements/development.txt: -------------------------------------------------------------------------------- 1 | jupyter 2 | pandas 3 | -------------------------------------------------------------------------------- /pehchaan/config/base.py: -------------------------------------------------------------------------------- 1 | class BaseConfig(object): 2 | """Common base class for the application's configuration classes (e.g. AppConfig).""" 3 | -------------------------------------------------------------------------------- /requirements/production.txt: -------------------------------------------------------------------------------- 1 | future 2 | numpy 3 | scipy 4 | scikit-learn 5 | Pillow 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jupyter 2 | pandas 3 | future 4 | numpy 5 | scipy 6 | scikit-learn 7 | Pillow 8 | pytest 9 | -------------------------------------------------------------------------------- /pehchaan/data/encoders/DHCD-LE.pkl: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/achillesrasquinha/pehchaan/HEAD/pehchaan/data/encoders/DHCD-LE.pkl -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # virtualenvs 2 | .venv 3 | 4 | # IPython 5 | # checkpoints 6 | notebooks/.ipynb_checkpoints 7 | # data 8 | notebooks/data 9 | -------------------------------------------------------------------------------- /pehchaan/data/models/DHCD-SVC-96.61,10-96.30,0.60.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/achillesrasquinha/pehchaan/HEAD/pehchaan/data/models/DHCD-SVC-96.61,10-96.30,0.60.pkl -------------------------------------------------------------------------------- /pehchaan/__init__.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility imports 2 | from __future__ import absolute_import 3 | 4 | # imports - pehchaan 5 | from pehchaan.config import AppConfig 6 | from pehchaan.app import App 7 | -------------------------------------------------------------------------------- /pehchaan/_util/__init__.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility packages 2 | from __future__ import absolute_import 3 | 4 | # imports - pehchaan 5 | from pehchaan._util._util import _get_version_string, _image_to_input 6 | -------------------------------------------------------------------------------- /pehchaan/config/__init__.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility imports 2 | from __future__ import absolute_import 3 | 4 | # imports - pehchaan 5 | from pehchaan.config.base import BaseConfig 6 | from pehchaan.config.app import AppConfig 7 | -------------------------------------------------------------------------------- 
/pehchaan/__main__.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility packages 2 | from __future__ import absolute_import 3 | 4 | # imports - pehchaan 5 | from pehchaan import App 6 | 7 | def main(): 8 | """Start the pehchaan GUI and block until its window is closed.""" 9 | app = App() 10 | app.run() 11 | 12 | if __name__ == '__main__': 13 | main() 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pehchaan 2 | > Devanagari Character Recognition Using Machine Learning 3 | 4 | [![](http://img.youtube.com/vi/wmdhn-sWQQw/0.jpg)](http://www.youtube.com/watch?v=wmdhn-sWQQw) 5 | 6 | ### Usage 7 | Install the dependencies and the package, then launch the drawing app: 8 | 9 | ``` 10 | make install 11 | python -m pehchaan 12 | ``` 13 | 14 | Draw a character on the black canvas, press **Predict** to classify it and **Clear** to start over. 15 | 16 | ### License 17 | This code has been released under the [MIT License](LICENSE). 18 | -------------------------------------------------------------------------------- /pehchaan/config/app.py: -------------------------------------------------------------------------------- 1 | from pehchaan.config import BaseConfig 2 | 3 | class AppConfig(BaseConfig): 4 | """Static settings for the pehchaan GUI application.""" 5 | NAME = 'pehchaan' 6 | VERSION = (0, 1, 0) # keep in sync with version in package.py 7 | 8 | WINDOW_WIDTH = 320 # by default also the side length of the square drawing canvas 9 | WINDOW_HEIGHT = 480 10 | 11 | CANVAS_BACKGROUND_COLOR = '#000000' 12 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: install docs test clean 2 | 3 | PYTHON = python 4 | 5 | install: 6 | cat requirements/*.txt > requirements.txt 7 | pip install -r requirements.txt --no-cache-dir 8 | 9 | $(PYTHON) setup.py install 10 | 11 | docs: 12 | cd docs && make html 13 | 14 | test: 15 | $(PYTHON) setup.py test 16 | 17 | clean: 18 | $(PYTHON) setup.py clean 19 | -------------------------------------------------------------------------------- /pehchaan/_util/const.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ABSPATH_ROOT =
os.path.abspath(os.path.dirname(os.path.dirname(__file__))) 4 | ABSPATH_DATA = os.path.join(ABSPATH_ROOT, 'data') 5 | ABSPATH_MODELS = os.path.join(ABSPATH_DATA, 'models') 6 | ABSPATH_ENCODERS = os.path.join(ABSPATH_DATA, 'encoders') 7 | 8 | ABSPATH_ENCODER_DHCD = os.path.join(ABSPATH_ENCODERS, 'DHCD-LE.pkl') 9 | 10 | ABSPATH_MODEL_DHCD_SVC = os.path.join(ABSPATH_MODELS, 'DHCD-SVC-96.61,10-96.30,0.60.pkl') 11 | 12 | DHCD_INPUT_SIZE = (32, 32) 13 | -------------------------------------------------------------------------------- /pehchaan/_util/_util.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility packages 2 | from __future__ import absolute_import 3 | 4 | # imports - third party 5 | import numpy as np 6 | 7 | # imports - pehchaan 8 | from pehchaan import AppConfig 9 | 10 | def _get_version_string(): 11 | """Return AppConfig.VERSION joined with dots, e.g. (0, 1, 0) -> '0.1.0'.""" 12 | version = '.'.join(map(str, AppConfig.VERSION)) 13 | 14 | return version 15 | 16 | def _to_grayscale(r, g, b): 17 | """Return the Rec. 709 luma of the red, green and blue values (scalars or NumPy arrays). 18 | 19 | The weights sum to 1, so the result stays on the same 0-255 scale as the inputs. 20 | """ 21 | gray = 0.2126 * r + 0.7152 * g + 0.0722 * b 22 | 23 | return gray 24 | 25 | def _image_to_input(image): 26 | """Convert a PIL image into the flat binary feature vector fed to the model. 27 | 28 | Pixels whose luma is >= 128 become 1.0 and all others 0.0; the result is a 29 | 1-D float array with one entry per pixel. 30 | """ 31 | arr = np.array(image.convert('RGB')) 32 | gray = _to_grayscale(arr[:, :, 0], arr[:, :, 1], arr[:, :, 2]) 33 | 34 | # Binarise at the midpoint of the 0-255 range. 35 | bw = np.where(gray >= 128, 1.0, 0.0) 36 | 37 | return bw.flatten() 38 | -------------------------------------------------------------------------------- /package.py: -------------------------------------------------------------------------------- 1 | # Inspired by npm's package.json file 2 | name = 'pehchaan' 3 | version = '0.1.0' 4 | release = '0.1.0' 5 | description = 'A character recognition suite' 6 | long_description = ['README.md', 'HISTORY.md'] 7 | keywords = ['image', 'character', 'recognition', 'machine', 'deep', 'learning'] 8 | authors = [ 9 | { 'name': 'Achilles Rasquinha', 'email': 'achillesrasquinha@gmail.com' } 10 | ] 11 | maintainers = [ 12 | { 'name': 'Achilles Rasquinha', 'email':
'achillesrasquinha@gmail.com' } 13 | ] 14 | license = 'MIT' # must agree with the LICENSE file and README 15 | modules = [ 16 | 'pehchaan', 17 | 'pehchaan.config', 18 | 'pehchaan._util', 19 | 'pehchaan.data', # NOTE(review): no pehchaan/data/__init__.py is listed and the .pkl files are not declared as package data - confirm they get installed 20 | ] 21 | homepage = 'https://achillesrasquinha.github.io/pehchaan' 22 | github_username = 'achillesrasquinha' 23 | github_repository = 'pehchaan' 24 | github_url = '{baseurl}/{username}/{repository}'.format( 25 | baseurl = 'https://github.com', 26 | username = github_username, 27 | repository = github_repository) 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2017 Achilles Rasquinha 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | import os 5 | import shlex 6 | import shutil 7 | import codecs 8 | 9 | from distutils.core import Command 10 | from distutils.command.clean import clean as Clean 11 | 12 | import package 13 | 14 | ABSPATH_ROOTDIR = os.path.dirname(os.path.abspath(__file__)) 15 | RELPATH_FILES_CLEAN = ['build', 'dist', '{name}.egg-info'.format(name = package.name), '.cache'] 16 | RELPATH_WALK_FILES_EXT_CLEAN = ['.pyc'] 17 | RELPATH_WALK_DIRS_CLEAN = ['__pycache__'] 18 | 19 | class CleanCommand(Clean): 20 | def run(self): 21 | Clean.run(self) 22 | 23 | for filename in RELPATH_FILES_CLEAN: 24 | if os.path.exists(filename): 25 | shutil.rmtree(filename) 26 | 27 | for dirpath, dirnames, filenames in os.walk(ABSPATH_ROOTDIR): 28 | for filename in filenames: 29 | for extension in RELPATH_WALK_FILES_EXT_CLEAN: 30 | if filename.endswith(extension): 31 | path = os.path.join(dirpath, filename) 32 | os.unlink(path) 33 | 34 | for dirname in dirnames: 35 | if dirname in RELPATH_WALK_DIRS_CLEAN: 36 | path = os.path.join(dirpath, dirname) 37 | shutil.rmtree(path) 38 | 39 | class TestCommand(Command): 40 | """Run the test suite through pytest; ``--pytest``/``-a`` passes extra arguments.""" 41 | user_options = [('pytest=', 'a', 'arguments to be passed to pytest')] 42 | 43 | def initialize_options(self): 44 | # distutils stores --pytest/-a on ``self.pytest`` and only accepts options 45 | # whose attribute has been created here. 46 | self.pytest = '' 47 | self.args_pytest = [ ] 48 | 49 | def finalize_options(self): 50 | # Split the single option string into the argv list pytest.main() expects. 51 | self.args_pytest = shlex.split(self.pytest) if self.pytest else [ ] 52 | 53 | def run(self): 54 | import pytest 55 | 56 | errno = pytest.main(self.args_pytest) 57 | 58 | sys.exit(errno) 59 | 60 | def get_long_description(filepaths): 61 | """Return the text of the non-empty files in ``filepaths`` joined by blank lines. 62 | 63 | ``filepaths`` may be a single path or a list of paths; files are read as UTF-8. 64 | Raises FileNotFoundError for a missing path and ValueError for a path that 65 | is not a regular file. 66 | """ 67 | content = '' 68 | filepaths = filepaths if isinstance(filepaths, list) else [filepaths] 69 | 70 | for filepath in filepaths: 71 | if not os.path.exists(filepath): 72 | raise FileNotFoundError('No such file found: {filepath}'.format(filepath = filepath)) 73 | if not os.path.isfile(filepath): 74 | raise ValueError('Not a file: {filepath}'.format(filepath = filepath)) 75 | if os.path.getsize(filepath) == 0: 76 | continue 77 |
78 | with codecs.open(filepath, mode = 'r', encoding = 'utf-8') as f: 79 | raw = f.read() 80 | 81 | # Only separate from text already collected, so an empty first file adds no gap. 82 | content += '{prepend}{content}'.format(prepend = '\n\n' if content else '', content = raw) 83 | 84 | return content 85 | 86 | def main(): 87 | try: 88 | from setuptools import setup 89 | args_setuptools = dict( 90 | keywords = ', '.join(package.keywords) 91 | ) 92 | except ImportError: 93 | from distutils.core import setup 94 | args_setuptools = dict() 95 | 96 | metadata = dict( 97 | name = package.name, 98 | version = package.version, 99 | description = package.description, 100 | long_description = get_long_description(package.long_description), 101 | author = ','.join([author['name'] for author in package.authors]), 102 | author_email = ','.join([author['email'] for author in package.authors]), 103 | maintainer = ','.join([maintainer['name'] for maintainer in package.maintainers]), 104 | maintainer_email = ','.join([maintainer['email'] for maintainer in package.maintainers]), 105 | license = package.license, 106 | packages = package.modules, 107 | url = package.homepage, 108 | cmdclass = { 109 | 'clean': CleanCommand, 110 | 'test': TestCommand 111 | }, 112 | **args_setuptools 113 | ) 114 | 115 | setup(**metadata) 116 | 117 | if __name__ == '__main__': 118 | main() 119 | -------------------------------------------------------------------------------- /pehchaan/app.py: -------------------------------------------------------------------------------- 1 | # imports - compatibility packages 2 | from __future__ import absolute_import 3 | 4 | # imports - standard packages 5 | from collections import defaultdict 6 | try: 7 | import Tkinter as tk 8 | except ImportError: 9 | import tkinter as tk 10 | import pickle 11 | 12 | # imports - third party 13 | from PIL import Image, ImageDraw 14 | import numpy as np 15 | 16 | # imports - pehchaan 17
| from pehchaan import AppConfig 18 | from pehchaan._util.const import ABSPATH_MODEL_DHCD_SVC, ABSPATH_ENCODER_DHCD, DHCD_INPUT_SIZE 19 | from pehchaan._util import _get_version_string, _image_to_input 20 | 21 | class App(object): 22 | """Tk application: draw a character on a canvas and classify it with the bundled model.""" 23 | class Frame(tk.Frame): 24 | """Drawing canvas, Clear/Predict buttons and output canvas, gridded directly on ``master``.""" 25 | BUTTON_CLEAR = 'clear' 26 | BUTTON_PREDICT = 'predict' 27 | 28 | def __init__(self, 29 | master = None, 30 | windowSize = (AppConfig.WINDOW_WIDTH, AppConfig.WINDOW_HEIGHT)): 31 | tk.Frame.__init__(self, master) 32 | self.windowSize = windowSize 33 | 34 | self.createUI() 35 | 36 | def createUI(self): 37 | """Create and grid the drawing canvas, the two buttons and the output canvas.""" 38 | currrow = 0 39 | width, height = self.windowSize 40 | self.canvas = tk.Canvas(self.master, 41 | width = width, 42 | height = width, 43 | highlightthickness = 0, 44 | background = AppConfig.CANVAS_BACKGROUND_COLOR) 45 | self.canvas.grid(row = currrow, 46 | column = 0, 47 | columnspan = 2, 48 | sticky = tk.E + tk.W) 49 | currrow += 1 50 | # Paint at the pointer while the left mouse button is held and moved. 51 | self.canvas.bind('<B1-Motion>', lambda event: self.paintPoint(event.x, event.y, thickness = 6)) 52 | 53 | self.createPILImage() 54 | 55 | self.button = defaultdict(tk.Button) 56 | self.button[App.Frame.BUTTON_CLEAR] = tk.Button(self.master, 57 | text = App.Frame.BUTTON_CLEAR.capitalize(), 58 | command = self.clear) 59 | self.button[App.Frame.BUTTON_CLEAR].grid(row = currrow, 60 | column = 0, 61 | sticky = tk.E + tk.W) 62 | 63 | # The Predict command is attached by App, which owns the model. 64 | self.button[App.Frame.BUTTON_PREDICT] = tk.Button(self.master, 65 | text = App.Frame.BUTTON_PREDICT.capitalize()) 66 | self.button[App.Frame.BUTTON_PREDICT].grid(row = currrow, 67 | column = 1, 68 | sticky = tk.E + tk.W) 69 | currrow += 1 70 | 71 | self.output = tk.Canvas(self.master, 72 | width = width, 73 | height = width, 74 | highlightthickness = 0, 75 | background = '#FFFFFF') 76 | self.output.grid(row = currrow, 77 | column = 0, 78 | columnspan = 2, 79 | sticky = tk.E + tk.W) 80 | 81 | def setOutput(self, output): 82 | """Replace the output canvas contents with ``output`` drawn as text.""" 83 | width, height = self.windowSize 84 | self.clearOutput() 85 | 86 | self.output.create_text(width/2, height * 0.25 /
2, 87 | font = 'Helvetica 30', 88 | text = output) 89 | 90 | def clearCanvas(self): 91 | self.canvas.delete('all') 92 | 93 | def clearOutput(self): 94 | self.output.delete('all') 95 | 96 | def paintPoint(self, x, y, 97 | thickness = 0, 98 | color = '#FFFFFF'): 99 | """Draw a filled circle of radius ``thickness`` at (x, y) on both the Tk canvas and the PIL image.""" 100 | a, b = x - thickness, y - thickness 101 | c, d = x + thickness, y + thickness 102 | points = [a, b, c, d] 103 | 104 | self.canvas.create_oval(points, fill = color) 105 | self.imageDraw.ellipse(points, color) 106 | 107 | def clear(self): 108 | """Reset the drawing, the output and the backing PIL image.""" 109 | self.clearCanvas() 110 | self.clearOutput() 111 | self.createPILImage() 112 | 113 | def createPILImage(self): 114 | """Create a blank width x width RGB image that mirrors the drawing canvas.""" 115 | width, height = self.windowSize 116 | self.image = Image.new('RGB', (width, width), AppConfig.CANVAS_BACKGROUND_COLOR) 117 | self.imageDraw = ImageDraw.Draw(self.image) 118 | 119 | def __init__(self, 120 | windowSize = (AppConfig.WINDOW_WIDTH, AppConfig.WINDOW_HEIGHT)): 121 | """Build the root window and UI, then unpickle the bundled model and label encoder.""" 122 | self.root = tk.Tk() 123 | self.windowSize = windowSize 124 | 125 | self.root.title('{name} v{version}'.format( 126 | name = AppConfig.NAME, 127 | version = _get_version_string() 128 | )) 129 | width, height = self.windowSize 130 | self.root.geometry('{width}x{height}'.format(width = width, height = height)) 131 | self.root.resizable(width = False, 132 | height = False) 133 | 134 | self.frame = App.Frame(self.root, self.windowSize) 135 | self.frame.button[App.Frame.BUTTON_PREDICT].config(command = self.predict) 136 | 137 | with open(ABSPATH_MODEL_DHCD_SVC, 'rb') as f: 138 | self.model = pickle.load(f) 139 | 140 | with open(ABSPATH_ENCODER_DHCD, 'rb') as f: 141 | self.encoder = pickle.load(f) 142 | 143 | def predict(self): 144 | """Classify the current drawing and display the decoded symbol.""" 145 | # Work on a copy: thumbnail() resizes in place, which would otherwise replace 146 | # the frame's full-size drawing surface after the first prediction. 147 | image = self.frame.image.copy() 148 | image.thumbnail(DHCD_INPUT_SIZE, Image.BICUBIC) 149 | 150 | arr = _image_to_input(image) 151 | arr = np.reshape(arr, (1, arr.size)) 152 | out = self.model.predict(arr) 153 | # LabelEncoder.inverse_transform takes a 1-D array of labels, not a scalar. 154 | sym = self.encoder.inverse_transform(out)[0] 155 | 156 | self.frame.setOutput(sym) 157 | 158 | def run(self): 159 | """Enter the Tk main loop; returns when the window is closed.""" 160 | self.root.mainloop() 161 |
-------------------------------------------------------------------------------- /notebooks/Devanagari Character Recognition.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": false, 8 | "deletable": true, 9 | "editable": true 10 | }, 11 | "outputs": [], 12 | "source": [ 13 | "import os\n", 14 | "import time\n", 15 | "import pickle\n", 16 | "import multiprocessing\n", 17 | "\n", 18 | "from PIL import Image\n", 19 | "\n", 20 | "import numpy as np\n", 21 | "import matplotlib.pyplot as pplt\n", 22 | "import matplotlib.cm as cmap\n", 23 | "import pandas as pd\n", 24 | "\n", 25 | "from sklearn import datasets\n", 26 | "from sklearn import preprocessing\n", 27 | "from sklearn import model_selection\n", 28 | "from sklearn import metrics\n", 29 | "from sklearn import neural_network as nn\n", 30 | "\n", 31 | "from sklearn import svm\n", 32 | "from sklearn.multiclass import OneVsRestClassifier as OVRC\n", 33 | "from sklearn.ensemble import BaggingClassifier as BC\n", 34 | "\n", 35 | "% matplotlib inline" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": { 42 | "collapsed": false, 43 | "deletable": true, 44 | "editable": true 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "RELPATH_DATA = 'data'\n", 49 | "RELPATH_TRAIN = os.path.join(RELPATH_DATA, 'train.csv')\n", 50 | "RELPATH_TEST = os.path.join(RELPATH_DATA, 'test.csv')\n", 51 | "\n", 52 | "DEVANAGARI_DIGIT_CHARSET = ['०','१','२','३','४','५','६','७','८','९']" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": { 59 | "collapsed": false, 60 | "deletable": true, 61 | "editable": true 62 | }, 63 | "outputs": [], 64 | "source": [ 65 | "NUM_CORES = multiprocessing.cpu_count() - 2\n", 66 | "NUM_CORES" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | 
"metadata": { 73 | "collapsed": false, 74 | "deletable": true, 75 | "editable": true 76 | }, 77 | "outputs": [], 78 | "source": [ 79 | "data = pd.concat([pd.read_csv(RELPATH_TRAIN), pd.read_csv(RELPATH_TEST)])" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": { 86 | "collapsed": true, 87 | "deletable": true, 88 | "editable": true 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "sample = data.sample()" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": { 99 | "collapsed": false, 100 | "deletable": true, 101 | "editable": true 102 | }, 103 | "outputs": [], 104 | "source": [ 105 | "sample.sample()" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": { 112 | "collapsed": false, 113 | "deletable": true, 114 | "editable": true 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "arr = np.array(sample.ix[:,:-1])\n", 119 | "arr = np.reshape(arr, (32, 32))" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": { 126 | "collapsed": false, 127 | "deletable": true, 128 | "editable": true 129 | }, 130 | "outputs": [], 131 | "source": [ 132 | "pplt.imshow(arr, cmap = cmap.gray)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "metadata": { 139 | "collapsed": true, 140 | "deletable": true, 141 | "editable": true 142 | }, 143 | "outputs": [], 144 | "source": [ 145 | "digits = data[data['character'].isin(DEVANAGARI_DIGIT_CHARSET)]\n", 146 | "X, y = digits.ix[:,:-1], digits.ix[:,-1]" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": { 153 | "collapsed": true, 154 | "deletable": true, 155 | "editable": true 156 | }, 157 | "outputs": [], 158 | "source": [ 159 | "# X, y = data.ix[:,:-1], data.ix[:,-1]" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": { 166 | 
"collapsed": false, 167 | "deletable": true, 168 | "editable": true 169 | }, 170 | "outputs": [], 171 | "source": [ 172 | "encoder = preprocessing.LabelEncoder()\n", 173 | "y = encoder.fit_transform(y)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": { 180 | "collapsed": true, 181 | "deletable": true, 182 | "editable": true 183 | }, 184 | "outputs": [], 185 | "source": [ 186 | "with open('DHCD-LE.pkl', 'wb') as f:\n", 187 | " pickle.dump(encoder, f)" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": { 194 | "collapsed": false, 195 | "deletable": true, 196 | "editable": true 197 | }, 198 | "outputs": [], 199 | "source": [ 200 | "Xtrain, Xtest, ytrain, ytest = model_selection.train_test_split(X, y, train_size = 0.60)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": { 207 | "collapsed": false, 208 | "deletable": true, 209 | "editable": true 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "def stratified_cross_validate(X, y, class_, n_folds = 10, *args, **kwargs):\n", 214 | " model = class_(*args, **kwargs)\n", 215 | " crossval = model_selection.StratifiedKFold(n_splits = n_folds)\n", 216 | " accuracy = 0\n", 217 | " \n", 218 | " accuracies = [ ] \n", 219 | " for ii, jj in crossval.split(X, y):\n", 220 | " Xtrain, Xtest = X[ii], X[jj]\n", 221 | " ytrain, ytest = y[ii], y[jj]\n", 222 | " \n", 223 | " model = model.fit(Xtrain, ytrain)\n", 224 | " predict = model.predict(Xtest)\n", 225 | " acc = metrics.accuracy_score(ytest, predict)\n", 226 | " accuracy += acc\n", 227 | " \n", 228 | " accuracies.append(acc)\n", 229 | " \n", 230 | " figure = pplt.figure(figsize = (20, 15))\n", 231 | " axes = figure.add_subplot(111)\n", 232 | " axes.plot(list(range(1, n_folds + 1)), accuracies)\n", 233 | " \n", 234 | " return figure, axes, (accuracy / n_folds), model\n", 235 | "# return (accuracy / n_folds), model" 236 | ] 237 | 
}, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "metadata": { 242 | "collapsed": false, 243 | "deletable": true, 244 | "editable": true 245 | }, 246 | "outputs": [], 247 | "source": [ 248 | "pplt.style.use('fivethirtyeight')\n", 249 | "\n", 250 | "start = time.time()\n", 251 | "fig, ax, acc, model = stratified_cross_validate(np.asarray(Xtrain), np.asarray(ytrain), svm.SVC, verbose = 3)\n", 252 | "end = time.time()\n", 253 | "# acc, model = stratified_cross_validate(np.asarray(Xtrain), np.asarray(ytrain), svm.SVC, verbose = 3)\n", 254 | "\n", 255 | "print('time ellapsed: %.2fs' % (end - start))" 256 | ] 257 | }, 258 | { 259 | "cell_type": "code", 260 | "execution_count": null, 261 | "metadata": { 262 | "collapsed": false, 263 | "deletable": true, 264 | "editable": true 265 | }, 266 | "outputs": [], 267 | "source": [ 268 | "acc" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": null, 274 | "metadata": { 275 | "collapsed": false, 276 | "deletable": true, 277 | "editable": true 278 | }, 279 | "outputs": [], 280 | "source": [ 281 | "predict = model.predict(Xtest)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "metadata": { 288 | "collapsed": false, 289 | "deletable": true, 290 | "editable": true 291 | }, 292 | "outputs": [], 293 | "source": [ 294 | "overacc = metrics.accuracy_score(ytest, predict)\n", 295 | "overacc" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": null, 301 | "metadata": { 302 | "collapsed": true, 303 | "deletable": true, 304 | "editable": true 305 | }, 306 | "outputs": [], 307 | "source": [ 308 | "with open('SVC-{crossacc},10-{accuracy},0.60.pkl'.format(crossacc = acc, accuracy = overacc), 'wb') as f:\n", 309 | " pickle.dump(model, f)" 310 | ] 311 | } 312 | ], 313 | "metadata": { 314 | "kernelspec": { 315 | "display_name": "Python 3", 316 | "language": "python", 317 | "name": "python3" 318 | }, 319 | "language_info": { 320 | 
"codemirror_mode": { 321 | "name": "ipython", 322 | "version": 3 323 | }, 324 | "file_extension": ".py", 325 | "mimetype": "text/x-python", 326 | "name": "python", 327 | "nbconvert_exporter": "python", 328 | "pygments_lexer": "ipython3", 329 | "version": "3.5.2" 330 | } 331 | }, 332 | "nbformat": 4, 333 | "nbformat_minor": 2 334 | } 335 | --------------------------------------------------------------------------------