"""drawing-classifier: classify your own drawings with machine learning.

NEURALNINE (c) 2019
Drawing Classifier ML Alpha v0.1

This script lets you define three custom classes and then draw examples of
each one. Clicking a class button saves the current drawing as one training
sample for that class. You can then train a model (choosing between several
scikit-learn classifiers) and use it to predict the class of new drawings.

This is an early prototype: the code is still rough, there may be bugs, and
many error conditions are not handled.

Project files: README.md, drawing_classifier.py, requirements.txt.
Third-party requirements (requirements.txt): pillow, opencv-python, numpy,
scikit-learn.
"""

import os.path
import pickle

import tkinter.messagebox
from tkinter import *
from tkinter import simpledialog, filedialog

import PIL
import PIL.Image
import PIL.ImageDraw
import cv2 as cv
import numpy as np

from sklearn.svm import LinearSVC
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression


class DrawingClassifier:
    """Tkinter app for drawing, labelling, training on and classifying sketches.

    Creating an instance prompts for a project name. The project is loaded from
    ``<project>/<project>_data.pickle`` when that file exists; otherwise a new
    project with one sample directory per class is created. The constructor then
    opens the drawing window and blocks in the Tk main loop until it is closed.
    """

    # Side length (pixels) of the off-screen PIL image that mirrors the canvas.
    IMAGE_SIZE = 500
    # Drawings are downscaled to fit THUMBNAIL_SIZE x THUMBNAIL_SIZE before being
    # saved or classified. The image is square, so every sample has N_FEATURES values.
    THUMBNAIL_SIZE = 50
    N_FEATURES = THUMBNAIL_SIZE * THUMBNAIL_SIZE
    DIALOG_TITLE = "NeuralNine Drawing Classifier"
    # Order in which "Change Model" cycles through the available classifiers.
    MODEL_CYCLE = (LinearSVC, KNeighborsClassifier, LogisticRegression,
                   DecisionTreeClassifier, RandomForestClassifier, GaussianNB)

    def __init__(self):
        """Prompt for the project, then build and run the GUI (blocks until closed)."""
        self.class1, self.class2, self.class3 = None, None, None
        # Next file index (1-based) to use when saving a sample of each class.
        self.class1_counter, self.class2_counter, self.class3_counter = None, None, None
        self.clf = None
        self.proj_name = None
        self.root = None
        self.image1 = None  # off-screen PIL copy of the drawing; this is what gets saved

        self.status_label = None
        self.canvas = None
        self.draw = None

        self.brush_width = 15

        self.classes_prompt()
        self.init_gui()

    # ------------------------------------------------------------------ helpers

    def _data_path(self) -> str:
        """Return the path of the pickle file that stores this project's state."""
        return os.path.join(self.proj_name, f"{self.proj_name}_data.pickle")

    def _sample_path(self, class_num: int, index: int) -> str:
        """Return the PNG path of sample *index* of class *class_num* (1-3)."""
        class_name = (self.class1, self.class2, self.class3)[class_num - 1]
        return os.path.join(self.proj_name, class_name, f"{index}.png")

    def _thumbnail(self):
        """Return a downscaled copy of the current drawing (the canvas is untouched)."""
        img = self.image1.copy()
        # LANCZOS is the filter formerly aliased as ANTIALIAS; that alias was removed in Pillow 10.
        img.thumbnail((self.THUMBNAIL_SIZE, self.THUMBNAIL_SIZE), PIL.Image.LANCZOS)
        return img

    def _update_status(self):
        """Show the name of the currently selected classifier in the status label."""
        self.status_label.config(text=f"Current Model: {type(self.clf).__name__}")

    # ------------------------------------------------------------ project setup

    def classes_prompt(self):
        """Ask for a project name, then load that project or create a new one.

        Raises:
            SystemExit: if the project name or any class name is not entered,
                or if two class names are identical.
        """
        # Temporary hidden root that parents the dialogs. It is destroyed afterwards
        # so that only the main window's Tk instance stays alive (and becomes the
        # default root for later dialogs).
        msg = Tk()
        msg.withdraw()
        try:
            self.proj_name = simpledialog.askstring("Project Name", "Please enter your project name down below!", parent=msg)
            if not self.proj_name:  # dialog cancelled (None) or empty name
                raise SystemExit("No project name entered.")

            data_path = self._data_path()
            if os.path.exists(data_path):
                # NOTE(security): unpickling can execute arbitrary code - only open trusted projects.
                with open(data_path, "rb") as f:
                    data = pickle.load(f)
                self.class1 = data['c1']
                self.class2 = data['c2']
                self.class3 = data['c3']
                self.class1_counter = data['c1c']
                self.class2_counter = data['c2c']
                self.class3_counter = data['c3c']
                self.clf = data['clf']
                self.proj_name = data['pname']
            else:
                self.class1 = simpledialog.askstring("Class 1", "What is the first class called?", parent=msg)
                self.class2 = simpledialog.askstring("Class 2", "What is the second class called?", parent=msg)
                self.class3 = simpledialog.askstring("Class 3", "What is the third class called?", parent=msg)
                class_names = (self.class1, self.class2, self.class3)
                if not all(class_names):
                    raise SystemExit("All three class names are required.")
                if len(set(class_names)) < len(class_names):
                    raise SystemExit("The three class names must be different.")

                self.class1_counter = 1
                self.class2_counter = 1
                self.class3_counter = 1

                self.clf = LinearSVC()

                # makedirs also creates the project directory. It raises
                # FileExistsError rather than silently reusing a class directory
                # whose samples the restarted counters would overwrite.
                for class_name in class_names:
                    os.makedirs(os.path.join(self.proj_name, class_name))
        finally:
            msg.destroy()

    # -------------------------------------------------------------------- GUI

    def init_gui(self):
        """Build the drawing window and its buttons, then run the Tk main loop."""
        size = self.IMAGE_SIZE

        self.root = Tk()
        self.root.title(f"NeuralNine Drawing Classifier Alpha v0.2 - {self.proj_name}")

        self.canvas = Canvas(self.root, width=size - 10, height=size - 10, bg="white")
        self.canvas.pack(expand=YES, fill=BOTH)
        # Paint while the left mouse button is dragged.
        self.canvas.bind("<B1-Motion>", self.paint)

        self.image1 = PIL.Image.new("RGB", (size, size), (255, 255, 255))
        self.draw = PIL.ImageDraw.Draw(self.image1)

        btn_frame = tkinter.Frame(self.root)
        btn_frame.pack(fill=X, side=BOTTOM)

        for column in range(3):
            btn_frame.columnconfigure(column, weight=1)

        # (label, callback) laid out row by row in a 3-column grid.
        buttons = [
            (self.class1, lambda: self.save(1)),
            (self.class2, lambda: self.save(2)),
            (self.class3, lambda: self.save(3)),
            ("Brush-", self.brushminus),
            ("Clear", self.clear),
            ("Brush+", self.brushplus),
            ("Train Model", self.train_model),
            ("Save Model", self.save_model),
            ("Load Model", self.load_model),
            ("Change Model", self.rotate_model),
            ("Predict", self.predict),
            ("Save Everything", self.save_everything),
        ]
        for index, (text, command) in enumerate(buttons):
            row, column = divmod(index, 3)
            Button(btn_frame, text=text, command=command).grid(row=row, column=column, sticky=W + E)

        self.status_label = Label(btn_frame)
        self.status_label.config(font=("Arial", 10))
        self.status_label.grid(row=4, column=1, sticky=W + E)
        self._update_status()

        self.root.protocol("WM_DELETE_WINDOW", self.on_closing)
        self.root.attributes("-topmost", True)
        self.root.mainloop()

    def paint(self, event):
        """Stamp a square brush at the mouse position on both the canvas and the PIL image."""
        x1, y1 = event.x - 1, event.y - 1
        x2, y2 = event.x + 1, event.y + 1
        self.canvas.create_rectangle(x1, y1, x2, y2, fill="black", width=self.brush_width)
        # Tk strokes the outline centred on the rectangle edge, so the visible stamp
        # reaches roughly half a brush width beyond (x1, y1)-(x2, y2). The same
        # footprint is drawn into the image that is saved and classified.
        half = self.brush_width // 2
        self.draw.rectangle([x1 - half, y1 - half, x2 + half, y2 + half], fill="black")

    def save(self, class_num):
        """Save the current drawing as the next sample of class *class_num* (1-3), then clear.

        Raises:
            ValueError: if *class_num* is not 1, 2 or 3.
        """
        if class_num not in (1, 2, 3):
            raise ValueError(f"class_num must be 1, 2 or 3, got {class_num!r}")
        counter_attr = f"class{class_num}_counter"
        index = getattr(self, counter_attr)
        self._thumbnail().save(self._sample_path(class_num, index), "PNG")
        setattr(self, counter_attr, index + 1)

        self.clear()

    def brushminus(self):
        """Shrink the brush by one pixel (minimum 1)."""
        if self.brush_width > 1:
            self.brush_width -= 1

    def brushplus(self):
        """Grow the brush by one pixel."""
        self.brush_width += 1

    def clear(self):
        """Erase the canvas and reset the off-screen image to white."""
        self.canvas.delete("all")
        self.draw.rectangle([0, 0, *self.image1.size], fill="white")

    # ------------------------------------------------------------ model actions

    def train_model(self):
        """Fit the current classifier on every saved sample of the three classes.

        Shows a message box when samples are missing or the classifier rejects
        the data, for example when only one class has samples.
        """
        features, labels = [], []
        for class_num in (1, 2, 3):
            for index in range(1, getattr(self, f"class{class_num}_counter")):
                path = self._sample_path(class_num, index)
                img = cv.imread(path)
                if img is None:  # cv.imread signals missing/unreadable files with None
                    tkinter.messagebox.showerror(self.DIALOG_TITLE, f"Could not read training image {path}", parent=self.root)
                    return
                # Channel 0 only: drawings are black on white, so all channels are equal.
                features.append(img[:, :, 0].reshape(self.N_FEATURES))
                labels.append(class_num)

        if not features:
            tkinter.messagebox.showwarning(self.DIALOG_TITLE, "Please save some drawings before training.", parent=self.root)
            return

        try:
            # Build the matrix once (the old per-sample np.append was quadratic).
            self.clf.fit(np.array(features, dtype=float), np.array(labels))
        except ValueError as err:
            tkinter.messagebox.showerror(self.DIALOG_TITLE, f"Training failed: {err}", parent=self.root)
            return
        tkinter.messagebox.showinfo("NeuralNine Drawing Classifier", "Model successfully trained!", parent=self.root)

    def predict(self):
        """Classify the current drawing and report the predicted class name."""
        # Channel 0 of the RGB thumbnail. The drawing is black on white, so this
        # equals channel 0 of the BGR arrays that train_model reads via cv.imread.
        features = np.asarray(self._thumbnail())[:, :, 0].reshape(self.N_FEATURES)
        try:
            prediction = self.clf.predict([features])
        except ValueError as err:  # includes sklearn's NotFittedError
            tkinter.messagebox.showerror(self.DIALOG_TITLE, f"Prediction failed: {err}", parent=self.root)
            return

        class_name = {1: self.class1, 2: self.class2, 3: self.class3}.get(prediction[0])
        if class_name is not None:
            tkinter.messagebox.showinfo("NeuralNine Drawing Classifier", f"The drawing is probably a {class_name}", parent=self.root)

    def rotate_model(self):
        """Replace the classifier with a fresh instance of the next type in MODEL_CYCLE."""
        for position, model_type in enumerate(self.MODEL_CYCLE):
            if isinstance(self.clf, model_type):
                self.clf = self.MODEL_CYCLE[(position + 1) % len(self.MODEL_CYCLE)]()
                break

        self._update_status()

    def save_model(self):
        """Pickle the current classifier to a user-chosen file (no-op if cancelled)."""
        file_path = filedialog.asksaveasfilename(defaultextension=".pickle", parent=self.root)
        if not file_path:  # dialog cancelled
            return
        with open(file_path, "wb") as f:
            pickle.dump(self.clf, f)
        tkinter.messagebox.showinfo("NeuralNine Drawing Classifier", "Model successfully saved!", parent=self.root)

    def load_model(self):
        """Load a pickled classifier from a user-chosen file (no-op if cancelled)."""
        file_path = filedialog.askopenfilename(parent=self.root)
        if not file_path:  # dialog cancelled
            return
        # NOTE(security): unpickling can execute arbitrary code - only load trusted files.
        with open(file_path, "rb") as f:
            self.clf = pickle.load(f)
        self._update_status()
        tkinter.messagebox.showinfo("NeuralNine Drawing Classifier", "Model successfully loaded!", parent=self.root)

    def save_everything(self):
        """Persist class names, sample counters, classifier and project name."""
        data = {"c1": self.class1, "c2": self.class2, "c3": self.class3, "c1c": self.class1_counter,
                "c2c": self.class2_counter, "c3c": self.class3_counter, "clf": self.clf, "pname": self.proj_name}
        with open(self._data_path(), "wb") as f:
            pickle.dump(data, f)
        tkinter.messagebox.showinfo("NeuralNine Drawing Classifier", "Project successfully saved!", parent=self.root)

    def on_closing(self):
        """Window-close handler: offer to save, then close (Cancel keeps the app open)."""
        answer = tkinter.messagebox.askyesnocancel("Quit?", "Do you want to save your work?", parent=self.root)
        if answer is None:
            return
        if answer:
            self.save_everything()
        # Destroying the root ends mainloop(), so the constructor returns.
        self.root.destroy()


if __name__ == "__main__":
    DrawingClassifier()