├── .gitignore ├── MNIST ├── t10k-images-idx3-ubyte.lzma ├── t10k-labels-idx1-ubyte.lzma ├── train-images-idx3-ubyte.lzma └── train-labels-idx1-ubyte.lzma ├── README.md ├── assets ├── logo-tessilab.png ├── logo.png ├── screen-capture1.jpg └── screen-capture2.jpg ├── benchmark.py ├── main.py ├── ml_evaluate.py ├── ml_training.py ├── requirements.txt ├── requirements_macos_intel.txt └── time_util.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea/ 3 | __pycache__/ 4 | auto/ 5 | checkpoints/ 6 | 28.png 7 | saved.png 8 | assets/logo.xcf 9 | sweet_spot.txt 10 | -------------------------------------------------------------------------------- /MNIST/t10k-images-idx3-ubyte.lzma: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tessi-lab/ml-bench/046d9b4a17084b2f2f41a3127c3c36a6ecfe7e1d/MNIST/t10k-images-idx3-ubyte.lzma -------------------------------------------------------------------------------- /MNIST/t10k-labels-idx1-ubyte.lzma: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tessi-lab/ml-bench/046d9b4a17084b2f2f41a3127c3c36a6ecfe7e1d/MNIST/t10k-labels-idx1-ubyte.lzma -------------------------------------------------------------------------------- /MNIST/train-images-idx3-ubyte.lzma: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tessi-lab/ml-bench/046d9b4a17084b2f2f41a3127c3c36a6ecfe7e1d/MNIST/train-images-idx3-ubyte.lzma -------------------------------------------------------------------------------- /MNIST/train-labels-idx1-ubyte.lzma: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tessi-lab/ml-bench/046d9b4a17084b2f2f41a3127c3c36a6ecfe7e1d/MNIST/train-labels-idx1-ubyte.lzma 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ML-Benchmark 2 | 3 | ### Visuals 4 | Main Screen 5 | 6 | ![Main screen](/assets/screen-capture1.jpg?raw=true "Main screen") 7 | 8 | 9 | Test Screen 10 | 11 | ![Test screen](/assets/screen-capture2.jpg?raw=true "Test screen") 12 | 13 | ### Description 14 | 15 | This project aims to compare the performance 16 | of different computers on the same training task. 17 | 18 | It's based on the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset. 19 | Several parameters can be tuned. 20 | 21 | **Epochs**: a positive integer, 20 by default. 22 | 23 | **Neural network size**: controls the number of filters applied to the first 2 layers 24 | - `light : `[2, 4], 25 | - `basic : `[4, 8] enough to make the network converge, 26 | - `normal: `[8, 16] enough to make the network converge with more data, 27 | - `heavy : `[16, 32] too many filters for such a simple task, 28 | - `too-heavy:`[64, 128] idem, but quadrupled; will probably not converge, 29 | - `stupid: `[128, 256] idem, but stupidly huge, 30 | - `insane: `[256, 512] idem, but insanely huge. 31 | 32 | The button **Launch training !** does just that; each step is timed. 33 | It will display something like: 34 | - 'load_test' 745.21 ms 35 | - 'train' 13899.57 ms 36 | 37 | Once training is over, you can copy-paste the results into a file. 38 | Then you may click **Play with it !** to draw digits in the black area 39 | and see the result in the title bar. 40 | Right click removes the last line drawn.
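The two numbers in each size preset are the filter counts of the first and second convolutional layers (5×5 and 3×3 kernels in `ml_training.py`). As a rough sketch of how quickly the presets grow, the snippet below counts the convolution weights each preset implies; it mirrors the `level` table from `ml_training.py` but is only an illustration, not part of the benchmark itself:

```python
# Filter-count presets for the two conv layers (mirrors `level` in ml_training.py).
LEVELS = {
    'light': [2, 4],
    'basic': [4, 8],
    'normal': [8, 16],
    'heavy': [16, 32],
    'too-heavy': [64, 128],
    'stupid': [128, 256],
    'insane': [256, 512],
}


def conv_params(level_name: str) -> int:
    """Weights + biases of the two conv layers (5x5 then 3x3 kernels)."""
    f1, f2 = LEVELS[level_name]
    layer1 = (5 * 5 * 1 + 1) * f1   # one input channel: grayscale MNIST
    layer2 = (3 * 3 * f1 + 1) * f2  # f1 input channels coming from layer 1
    return layer1 + layer2


for name in LEVELS:
    print(f"{name:>9}: {conv_params(name):>9,} conv parameters")
```

For instance, `normal` implies 1,376 convolution parameters while `insane` implies over a million, which is why the larger presets are flagged as overkill for such a simple task.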
41 | 42 | ### Installation 43 | Install Python 3.8 for your system and create a virtual env: 44 | 45 | `python3 -m venv venv-ml-bench` 46 | 47 | Use the right activate script for your shell (for example, bash): 48 | 49 | `source venv-ml-bench/bin/activate` 50 | 51 | Download the sources from GitHub: 52 | 53 | `git clone https://github.com/tessi-lab/ml-bench.git` 54 | 55 | `cd ml-bench` 56 | 57 | Install the packages for your system. 58 | For macOS with Apple Silicon see this [article](https://developer.apple.com/metal/tensorflow-plugin/). 59 | 60 | _With an M1 Apple Silicon Mac_: 61 | 62 | bash Miniforge3-MacOSX-arm64.sh 63 | . miniforge3/bin/activate 64 | conda install -c apple tensorflow-deps 65 | pip install tensorflow-macos 66 | pip install tensorflow-metal 67 | conda install -c apple wxpython 68 | conda install -c apple opencv 69 | pip install python-mnist 70 | 71 | _With an Intel Mac_: 72 | 73 | Example: 74 | 75 | `python -m pip install -r requirements_macos_intel.txt` 76 | 77 | Launch the tool: 78 | 79 | `python benchmark.py` 80 | 81 | or 82 | 83 | `pythonw benchmark.py` 84 | 85 | ### Results for size normal and 20 epochs 86 | #### AlienWare Ryzen Edition AMD Ryzen 9 5950X 16-Core, 64 GB RAM, SSD PM9A1 2 TB 87 | - LOAD_TRAIN: 3380 ms 88 | - CREATE_TRAIN: 6354 ms 89 | - LOAD_TEST: 547 ms 90 | - CREATE_TEST: 569 ms 91 | - TRAIN: 116817 ms 92 | 93 | #### iMac 2015, Core i5, AMD Radeon R9 M390 2 GB, 24 GB RAM, SSD 1 TB 94 | ##### Run on CPU only (requirements.txt) 95 | - 'load_train' 5802.60 ms 96 | - 'create_train' 12206.21 ms 97 | - 'load_test' 943.89 ms 98 | - 'create_test' 991.74 ms 99 | - 'train' 328884.54 ms 100 | 101 | ##### Run on GPU + CPU (requirements_macos_intel.txt) 102 | - 'load_train' 5908.59 ms 103 | - 'create_train' 11447.35 ms 104 | - 'load_test' 944.63 ms 105 | - 'create_test' 992.16 ms 106 | - 'train' 263666.66 ms 107 | 108 | #### iMac 2020 24" M1 - 16 GB RAM - SSD 1 TB Metal (8 GPU cores) 109 | - LOAD_TRAIN: 3326 ms 110 | - CREATE_TRAIN: 6384 ms 111 | - 
LOAD_TEST: 523 ms 112 | - CREATE_TEST: 540 ms 113 | - TRAIN: 129372 ms 114 | 115 | #### iMac 2020 24" M1 - 16 GB RAM - SSD 1 TB CPU Only 116 | (same computer, with the tensorflow-metal package uninstalled) 117 | 118 | - LOAD_TRAIN: 3242 ms 119 | - CREATE_TRAIN: 6304 ms 120 | - LOAD_TEST: 518 ms 121 | - CREATE_TEST: 535 ms 122 | - TRAIN: 217383 ms 123 | 124 | #### MacBook 2019 16" - i9 - 32 GB RAM - SSD 1 TB on CPU and battery (turbo boost switcher) 125 | - LOAD_TRAIN: 8470 ms 126 | - CREATE_TRAIN: 16285 ms 127 | - LOAD_TEST: 1308 ms 128 | - CREATE_TEST: 1373 ms 129 | - TRAIN: 1105780 ms 130 | 131 | #### MacBook 2019 16" - i9 - 32 GB RAM - SSD 1 TB on CPU and plugged into the wall 132 | - LOAD_TRAIN: 4760 ms 133 | - CREATE_TRAIN: 9552 ms 134 | - LOAD_TEST: 746 ms 135 | - CREATE_TEST: 782 ms 136 | - TRAIN: 437925 ms 137 | 138 | #### MacBook 2019 16" - i9 - 32 GB RAM - SSD 1 TB on GPU 5500M (8 GB) and battery (turbo boost switcher) 139 | - LOAD_TRAIN: 8492 ms 140 | - CREATE_TRAIN: 17309 ms 141 | - LOAD_TEST: 1343 ms 142 | - CREATE_TEST: 1409 ms 143 | - TRAIN: 222868 ms 144 | 145 | #### MacBook 2019 16" - i9 - 32 GB RAM - SSD 1 TB on GPU 5500M (8 GB) and plugged into the wall 146 | - LOAD_TRAIN: 4625 ms 147 | - CREATE_TRAIN: 10190 ms 148 | - LOAD_TEST: 722 ms 149 | - CREATE_TEST: 759 ms 150 | - TRAIN: 153928 ms 151 | 152 | #### MacBook 2021 16" - M1 Max - 64 GB RAM - SSD 2 TB and plugged into the wall 153 | - LOAD_TRAIN: 3224 ms 154 | - CREATE_TRAIN: 6292 ms 155 | - LOAD_TEST: 517 ms 156 | - CREATE_TEST: 533 ms 157 | - TRAIN: 111895 ms 158 | -------------------------------------------------------------------------------- /assets/logo-tessilab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tessi-lab/ml-bench/046d9b4a17084b2f2f41a3127c3c36a6ecfe7e1d/assets/logo-tessilab.png -------------------------------------------------------------------------------- /assets/logo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tessi-lab/ml-bench/046d9b4a17084b2f2f41a3127c3c36a6ecfe7e1d/assets/logo.png -------------------------------------------------------------------------------- /assets/screen-capture1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tessi-lab/ml-bench/046d9b4a17084b2f2f41a3127c3c36a6ecfe7e1d/assets/screen-capture1.jpg -------------------------------------------------------------------------------- /assets/screen-capture2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tessi-lab/ml-bench/046d9b4a17084b2f2f41a3127c3c36a6ecfe7e1d/assets/screen-capture2.jpg -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | import wx 2 | import wx.adv 3 | 4 | 5 | class Benchmark(wx.Frame): 6 | def __init__(self): 7 | wx.Frame.__init__(self, None, wx.ID_ANY, "Basic ML bench") 8 | pass 9 | 10 | 11 | if __name__ == "__main__": 12 | app = wx.App(False) 13 | bitmap = wx.Bitmap('./assets/logo-tessilab.png', wx.BITMAP_TYPE_PNG) 14 | splash = wx.adv.SplashScreen(bitmap, wx.adv.SPLASH_CENTRE_ON_SCREEN | wx.adv.SPLASH_TIMEOUT, 15 | 6000, None, -1, wx.DefaultPosition, wx.DefaultSize, 16 | wx.BORDER_SIMPLE | wx.STAY_ON_TOP) 17 | 18 | wx.Yield() 19 | import main 20 | main.run(app) 21 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # This is a sample Python script. 2 | 3 | # Press ⌃R to execute it or replace it with your code. 4 | # Press Double ⇧ to search everywhere for classes, files, tool windows, actions, and settings. 
5 | import sys 6 | from pathlib import Path 7 | from threading import Thread 8 | from typing import List, Tuple, Dict, Optional 9 | 10 | import cv2 11 | import wx 12 | from keras.models import Model 13 | from wx import wxEVT_LEFT_DOWN, wxEVT_LEFT_UP, wxEVT_RIGHT_UP 14 | 15 | import ml_training 16 | import ml_evaluate 17 | 18 | modulus = 400 19 | log_time: Dict[str, int] = {} 20 | 21 | 22 | class RedirectText(object): 23 | def __init__(self, aWxTextCtrl: wx.TextCtrl): 24 | self.out = aWxTextCtrl 25 | 26 | def write(self, string): 27 | wx.CallAfter(self.out.WriteText, string) 28 | self.flush() 29 | 30 | def flush(self): 31 | if wx.GetApp() is not None: 32 | wx.CallAfter(self.out.Refresh) 33 | 34 | 35 | class Train(Thread): 36 | def __init__(self, this, level_name: str, epochs: int = 20, batch_test: bool = False): 37 | super().__init__() 38 | self.this: MainForm = this 39 | self.epochs = epochs 40 | self.level_name = level_name 41 | self.batch_test = batch_test 42 | 43 | def run(self) -> None: 44 | ml_training.epoch = self.epochs 45 | ml_training.do_it(self.level_name, tune_bs=self.batch_test, log_time=log_time, log_verbose=True) 46 | wx.CallAfter(self.this.train_done, self.batch_test) 47 | 48 | @staticmethod 49 | def get_epochs() -> int: 50 | return ml_training.epoch 51 | 52 | @staticmethod 53 | def get_level_names() -> List[str]: 54 | return list(ml_training.level.keys()) 55 | 56 | 57 | class MainForm(wx.Frame): 58 | def __init__(self): 59 | wx.Frame.__init__(self, None, wx.ID_ANY, "TessiLab's Basic ML bench") 60 | 61 | main_panel = wx.Panel(self, wx.ID_ANY) 62 | log = wx.TextCtrl(main_panel, wx.ID_ANY, size=(300, 100), 63 | style=wx.TE_MULTILINE | wx.TE_READONLY | wx.HSCROLL) 64 | info = wx.FontInfo(12.0) 65 | info.Family(wx.FONTFAMILY_TELETYPE) 66 | info.Weight(wx.FONTWEIGHT_NORMAL) 67 | log.SetFont(wx.Font(info)) 68 | 69 | south_panel = wx.Panel(main_panel, wx.ID_ANY) 70 | self.batch_detect = wx.Button(south_panel, wx.ID_ANY, 'Detect batch size') 71 | 
self.batch_detect.SetToolTip("Detect the sweet spot for this CPU/GPU.\n" 72 | "It may crash, but it records the last good one.") 73 | self.Bind(wx.EVT_BUTTON, self.onBSButton, self.batch_detect) 74 | self.btn = wx.Button(south_panel, wx.ID_ANY, 'Launch training !') 75 | self.Bind(wx.EVT_BUTTON, self.onButton, self.btn) 76 | self.play_btn = wx.Button(south_panel, wx.ID_ANY, 'Play with it !') 77 | self.Bind(wx.EVT_BUTTON, self.onButtonPlay, self.play_btn) 78 | 79 | lbl = wx.StaticText(south_panel, wx.ID_ANY, 'Epochs') 80 | self.epochs = wx.TextCtrl(south_panel, wx.ID_ANY, f'{Train.get_epochs()}') 81 | self.radios: List[wx.RadioButton] = [] 82 | 83 | # Add widgets to a sizer 84 | sizer = wx.BoxSizer(wx.VERTICAL) 85 | sizer.Add(log, 1, wx.ALL | wx.EXPAND, 5) 86 | sizer.Add(south_panel, 0, wx.ALL | wx.ALIGN_LEFT, 5) 87 | 88 | out_net_panel = wx.Panel(south_panel, wx.ID_ANY) 89 | out_net_sizer = wx.BoxSizer(wx.VERTICAL) 90 | # net_panel = wx.StaticBox(out_net_panel, wx.ID_ANY, 'Neural network size') 91 | # box = wx.RadioBox(south_panel, wx.ID_ANY, 'Neural network size') 92 | # net_sizer.Add(box) 93 | # net_sizer = wx.StaticBoxSizer(wx.VERTICAL, net_panel) 94 | 95 | net_panel = out_net_panel 96 | net_sizer = out_net_sizer 97 | text = wx.StaticText(net_panel, wx.ID_ANY, 'Neural network size ') 98 | font: wx.Font = text.GetFont()  # wx.Font(18, wx.DECORATIVE, wx.ITALIC, wx.NORMAL) 99 | font.SetStyle(wx.FONTSTYLE_ITALIC) 100 | text.SetFont(font) 101 | net_sizer.Add(text) 102 | 103 | for label in Train.get_level_names(): 104 | radio = wx.RadioButton(net_panel, wx.ID_ANY, label, name=label) 105 | net_sizer.Add(radio) 106 | self.radios.append(radio) 107 | net_panel.SetSizer(net_sizer) 108 | # out_net_sizer.Add(net_panel) 109 | out_net_panel.SetSizer(out_net_sizer) 110 | self.radios[2].SetValue(True) 111 | 112 | south_west = wx.Panel(south_panel, wx.ID_ANY) 113 | south_west_sizer = wx.BoxSizer(wx.HORIZONTAL) 114 | south_west.SetSizer(south_west_sizer) 115 | 116 | png = 
wx.Image('./assets/logo.png', wx.BITMAP_TYPE_ANY) 117 | png = png.Rescale(png.GetWidth()/2, png.GetHeight()/2, wx.IMAGE_QUALITY_BICUBIC) 118 | png = png.ConvertToBitmap() 119 | w = png.GetWidth() 120 | h = png.GetHeight() 121 | logo = wx.StaticBitmap(south_west, wx.ID_ANY, png) 122 | logo.Bind(wx.EVT_MOUSE_EVENTS, self._logo_clicked) 123 | south_west.Bind(wx.EVT_MOUSE_EVENTS, self._logo_clicked) 124 | south_west_sizer.Add(logo) 125 | spacer = wx.StaticText(south_west, wx.ID_ANY, size=(w, h)) 126 | south_west_sizer.Add(spacer) 127 | 128 | south_sizer = wx.BoxSizer(wx.HORIZONTAL) 129 | south_sizer.Add(south_west, wx.ALIGN_LEFT) 130 | south_sizer.Add(out_net_panel) #, 0, wx.ALL, wx.LEFT, 5) 131 | south_sizer.Add(lbl) 132 | south_sizer.Add(self.epochs) 133 | south_sizer.Add(self.btn) 134 | south_sizer.Add(self.batch_detect) 135 | south_sizer.Add(self.play_btn) 136 | south_panel.SetSizer(south_sizer) 137 | 138 | main_panel.SetSizer(sizer) 139 | 140 | # redirect text here 141 | redir = RedirectText(log) 142 | sys.stdout = redir 143 | # sys.stderr = redir 144 | self.train: Train = None 145 | 146 | self.Maximize(True) 147 | print("Click a button below.") 148 | self._reload_batch_size() 149 | wx.CallLater(1000, print, "Waiting...") 150 | 151 | def _reload_batch_size(self): 152 | sweet_spot = Path('sweet_spot.txt') 153 | if sweet_spot.exists(): 154 | with open(sweet_spot, 'r') as f: 155 | selected_bs = -1 156 | min_duration = 1e300 157 | for l in f.readlines(): 158 | l: str 159 | if l.startswith('#'): 160 | continue 161 | else: 162 | s = l.split(', ') 163 | bs = int(s[0]) 164 | duration = int(s[1]) 165 | if duration < min_duration: 166 | selected_bs = bs 167 | min_duration = duration 168 | if selected_bs > 0: 169 | ml_training.batch_size = selected_bs 170 | print(f"Selected batch size is {ml_training.batch_size} from previous 'sweet_spot.txt'") 171 | self._flash_button(self.btn) 172 | else: 173 | print("You should start by choosing the right batch size with button 'Detect 
batch size'") 174 | self._flash_button(self.batch_detect) 175 | else: 176 | print("You should start by choosing the right batch size with button 'Detect batch size'") 177 | self._flash_button(self.batch_detect) 178 | 179 | def check_model_and_enable_ui(self): 180 | auto_path = Path("./auto") 181 | if auto_path.exists() and auto_path.is_dir(): 182 | wx.CallAfter(self.play_btn.Enable, True) 183 | else: 184 | wx.CallAfter(self.play_btn.Enable, False) 185 | 186 | def train_done(self, batch_detect: bool = False): 187 | self.check_model_and_enable_ui() 188 | print() 189 | print('************') 190 | print('**  DONE  **') 191 | print('************') 192 | for k, v in log_time.items(): 193 | print(f"{k}: {v} ms") 194 | print() 195 | self.batch_detect.Enable(True) 196 | self.btn.Enable(True) 197 | if batch_detect: 198 | self._flash_button(self.btn) 199 | else: 200 | self.btn.Enable(True) 201 | print("Training is over. You may now play with the model: click the button below.") 202 | print("Left click draws a shape (digit).") 203 | print("Right click deletes the last segment.") 204 | print("The result is displayed in the window title.") 205 | self._flash_button(self.play_btn) 206 | 207 | def onBSButton(self, event): 208 | self.btn.Enable(False) 209 | self.batch_detect.Enable(False) 210 | self.play_btn.Enable(False) 211 | try: 212 | ep = int(self.epochs.GetValue()) 213 | name = 'basic' 214 | for r in self.radios: 215 | if r.GetValue(): 216 | name = r.GetName() 217 | break 218 | print(f"Will search for the best batch size with neural network size {name}") 219 | self.train = Train(self, name, ep, True) 220 | self.train.start() 221 | except: 222 | self.btn.Enable(True) 223 | self.batch_detect.Enable(True) 224 | self.play_btn.Enable(True) 225 | self.epochs.SetFocus() 226 | 227 | 228 | def onButton(self, event): 229 | global frame 230 | self.btn.Enable(False) 231 | self.batch_detect.Enable(False) 232 | self.play_btn.Enable(False) 233 | try: 234 | ep = int(self.epochs.GetValue()) 235 | name = 
'basic' 236 | for r in self.radios: 237 | if r.GetValue(): 238 | name = r.GetName() 239 | break 240 | print(f"Will train with neural network size {name} and batch size = {ml_training.batch_size}") 241 | self.train = Train(self, name, ep) 242 | self.train.start() 243 | except: 244 | self.btn.Enable(True) 245 | self.batch_detect.Enable(True) 246 | self.play_btn.Enable(True) 247 | self.epochs.SetFocus() 248 | 249 | def onButtonPlay(self, event): 250 | global frame 251 | self.play_btn.Enable(False) 252 | frame = DrawPanel() 253 | frame.Show() 254 | self.Show(False) 255 | 256 | __normal_color: Optional[wx.Colour] = None 257 | 258 | def _flash_button(self, button: wx.Button, times: int = 5, will_hi_light: bool = True): 259 | if will_hi_light: 260 | if self.__normal_color is None: 261 | self.__normal_color = button.GetForegroundColour() 262 | button.SetForegroundColour(wx.GREEN) 263 | else: 264 | button.SetForegroundColour(self.__normal_color) 265 | times -= 1 266 | 267 | if times > 0: 268 | wx.CallLater(800, self._flash_button, button, times, not will_hi_light) 269 | 270 | def _logo_clicked(self, event: wx.MoveEvent): 271 | if event.GetEventType() == wxEVT_LEFT_UP: 272 | from webbrowser import open 273 | open('https://www.tessi.eu/en/innovation-by-tessi/#tessi-lab') 274 | 275 | 276 | class DrawPanel(wx.Frame): 277 | 278 | button_state: int 279 | points: List[List[Tuple[int, int]]] 280 | model: Model 281 | 282 | """Draw to a panel.""" 283 | 284 | def __init__(self): 285 | wx.Frame.__init__(self, parent=None, id=wx.ID_ANY, title="Draw a digit on Panel", size=(modulus, modulus + 28)) 286 | self.Bind(wx.EVT_PAINT, self.OnPaint) 287 | self.Bind(wx.EVT_MOUSE_EVENTS, self.onMouseMove) 288 | 289 | self.button_state = 0 290 | self.points = [] 291 | self.model = ml_evaluate.load_model('./auto') 292 | self.Bind(wx.EVT_CLOSE, self.close) 293 | 294 | def close(self, event): 295 | print("Closing") 296 | exit(0) 297 | 298 | def onMouseMove(self, event: wx.MouseEvent): 299 | if 
event.GetButton() == wx.MOUSE_BTN_LEFT: 300 | if event.GetEventType() == wxEVT_LEFT_UP: 301 | self.button_state = 0 302 | self.predict() 303 | elif event.GetEventType() == wxEVT_LEFT_DOWN: 304 | self.button_state = 1 305 | self.points.append([]) 306 | self.points[-1].append((event.GetX(), event.GetY())) 307 | else: 308 | print(f"Event type {event.GetEventType()} <> {wxEVT_LEFT_DOWN} <> {wxEVT_LEFT_UP}") 309 | elif event.GetEventType() == wxEVT_RIGHT_UP and self.button_state == 0: 310 | if self.points: 311 | self.points.pop() 312 | if not self.points: 313 | self.SetTitle("Draw a digit on Panel") 314 | else: 315 | self.predict() 316 | 317 | elif self.button_state == 1: 318 | self.points[-1].append((event.GetX(), event.GetY())) 319 | self.Refresh() 320 | 321 | def OnPaint(self, event=None) -> wx.PaintDC: 322 | # dc.DrawLine(0,0, 100, 100) 323 | dc = wx.PaintDC(self) 324 | dc.SetBrush(wx.Brush("BLACK", wx.SOLID)) 325 | size = self.GetSize() 326 | width = size.width 327 | dc.DrawRectangle(0, 0, width, size.height) 328 | for points in self.points: 329 | ep = int(0.5 + width / 10) 330 | dc.SetPen(wx.Pen(wx.WHITE, ep)) 331 | previous = None 332 | for point in points: 333 | if previous is None: 334 | previous = point 335 | continue 336 | else: 337 | dc.DrawLine(previous[0], previous[1], point[0], point[1]) 338 | previous = point 339 | return dc 340 | 341 | def predict(self): 342 | image = self.paint_to_image() 343 | modulus_ = int(0.5 + modulus / 10) 344 | modulus_ = modulus_ if modulus_ % 2 == 1 else modulus_ + 1 345 | image = cv2.GaussianBlur(image, (modulus_, modulus_), 1.0) 346 | image = cv2.resize(image, (28, 28), interpolation=cv2.INTER_CUBIC) 347 | cv2.imwrite("28.png", image) 348 | digit = ml_evaluate.predict(self.model, image) 349 | self.SetTitle(f"You drew a {digit} on Panel") 350 | 351 | def paint_to_image(self): 352 | dcSource = self.OnPaint() 353 | size = dcSource.Size 354 | 355 | # Create a Bitmap that will later on hold the screenshot image 356 | # Note that 
the Bitmap must have a size big enough to hold the screenshot 357 | # -1 means using the current default colour depth 358 | bmp = wx.EmptyBitmap(size.width, size.height) 359 | 360 | # Create a memory DC that will be used for actually taking the screenshot 361 | memDC = wx.MemoryDC() 362 | 363 | # Tell the memory DC to use our Bitmap 364 | # all drawing action on the memory DC will go to the Bitmap now 365 | memDC.SelectObject(bmp) 366 | 367 | # Blit (in this case copy) the actual screen on the memory DC 368 | # and thus the Bitmap 369 | memDC.Blit(0, # Copy to this X coordinate 370 | 0, # Copy to this Y coordinate 371 | size.width, # Copy this width 372 | size.height, # Copy this height 373 | dcSource, # From where do we copy? 374 | 0, # What's the X offset in the original DC? 375 | 0 # What's the Y offset in the original DC? 376 | ) 377 | 378 | # Select the Bitmap out of the memory DC by selecting a new 379 | # uninitialized Bitmap 380 | memDC.SelectObject(wx.NullBitmap) 381 | 382 | img = bmp.ConvertToImage() 383 | img.SaveFile('saved.png', wx.BITMAP_TYPE_PNG) 384 | return cv2.imread("saved.png", cv2.IMREAD_GRAYSCALE) 385 | 386 | 387 | def run(app: wx.App): 388 | frame = MainForm() 389 | frame.Show() 390 | 391 | app.MainLoop() 392 | 393 | 394 | # Press the green button in the gutter to run the script. 395 | if __name__ == '__main__': 396 | app = wx.App(False) 397 | run(app) 398 | -------------------------------------------------------------------------------- /ml_evaluate.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | from keras.models import Model 6 | from tensorflow.keras.models import load_model 7 | 8 | 9 | def load_mnist_model(filename: str) -> Any: 10 | return load_model(filename) 11 | 12 | 13 | def normalize_img(image: np.ndarray): 14 | """Normalizes images: `uint8` -> `float32`.""" 15 | return tf.cast(image, tf.float32) / 255. 
16 | 17 | 18 | def predict(model: Model, image: np.ndarray) -> int: 19 | img = normalize_img(image) 20 | img = np.expand_dims(img, 0) 21 | img = np.expand_dims(img, -1) 22 | return model.predict(img).argmax() -------------------------------------------------------------------------------- /ml_training.py: -------------------------------------------------------------------------------- 1 | import lzma 2 | import random 3 | import time 4 | from typing import Dict, List 5 | 6 | import cv2 7 | from mnist import MNIST 8 | import tensorflow as tf 9 | import numpy as np 10 | 11 | from time_util import timeit 12 | 13 | reshape = True 14 | epoch = 20 15 | batch_size = 128 16 | 17 | level: Dict[str, List[int]] = { 18 | 'light': [2, 4], 19 | 'basic': [4, 8], 20 | 'normal': [8, 16], 21 | 'heavy': [16, 32], 22 | 'too-heavy': [64, 128], 23 | 'stupid': [128, 256], 24 | 'insane': [256, 512], 25 | } 26 | 27 | 28 | def print_sample(sample: np.ndarray, **kwargs): 29 | print(f"+{'-' * 28}+") 30 | for y in range(28): 31 | ss = "" 32 | for x in range(28): 33 | v = sample[y * 28 + x] 34 | # print(v, sep=None) 35 | if v == 0: 36 | s = ' ' 37 | elif v < 50: 38 | s = '.' 
39 | elif v < 128: 40 | s = 'o' 41 | elif v < 200: 42 | s = "X" 43 | else: 44 | s = "W" 45 | ss += s 46 | print(f"|{ss}|") 47 | print(f"+{'-' * 28}+") 48 | if kwargs is not None and 'num' in kwargs and kwargs.get("save_it") is True: 49 | cv2.imwrite(f"sample-{kwargs['num']}.png", np.reshape(sample, (28, 28))) 50 | 51 | 52 | class MNIST_lzma(MNIST): 53 | def __init__(self, path='.', mode='vanilla', return_type='lists', gz=False, lzma=False): 54 | super().__init__(path, mode, return_type, gz) 55 | self.lzma = lzma 56 | 57 | def opener(self, path_fn, *args, **kwargs): 58 | if self.lzma: 59 | return lzma.open(path_fn + '.lzma', *args, **kwargs) 60 | else: 61 | return super().opener(path_fn, *args, **kwargs) 62 | 63 | 64 | @timeit 65 | def load_train(**kwargs): 66 | print("Loading train data") 67 | mndata = MNIST_lzma('./MNIST/', return_type='numpy', lzma=True) 68 | 69 | train, tr_lbl = mndata.load_training() 70 | rnd = int(len(tr_lbl) * random.random()) 71 | print(f"This is a {tr_lbl[rnd]}") 72 | print_sample(train[rnd], save_it=False, num=tr_lbl[rnd]) 73 | return train, tr_lbl 74 | 75 | 76 | @timeit 77 | def load_test(**kwargs): 78 | print("Loading test data") 79 | mndata = MNIST_lzma('./MNIST/', return_type='numpy', lzma=True) 80 | 81 | test, te_lbl = mndata.load_testing() 82 | rnd = int(len(te_lbl) * random.random()) 83 | print(f"This is a {te_lbl[rnd]}") 84 | print_sample(test[rnd], save_it=False, num=te_lbl[rnd]) 85 | return test, te_lbl 86 | 87 | 88 | def to_dataset(images, labels): 89 | ds = [] 90 | for i, lbl in enumerate(labels): 91 | if reshape: 92 | image = np.reshape(images[i], (28, 28, 1)) 93 | else: 94 | image = images[i] 95 | ds.append((image, int(lbl))) 96 | 97 | return tf.data.Dataset.from_generator(lambda: ds, (tf.int32, tf.int32)) 98 | 99 | 100 | def normalize_img(image, label): 101 | """Normalizes images: `uint8` -> `float32`.""" 102 | print(f"label: {label}") 103 | return tf.cast(image, tf.float32) / 255., label 104 | 105 | 106 | @timeit 107 | def 
create_train(**kwargs): 108 | train, tr_lbl = load_train(**kwargs) 109 | ds_train = to_dataset(train, tr_lbl).map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE) 110 | ds_train = ds_train.cache() 111 | # ds_train = ds_train.shuffle(len(tr_lbl)) #ds_info.splits['train'].num_examples) 112 | ds_train = ds_train.batch(batch_size) 113 | ds_train = ds_train.prefetch(tf.data.AUTOTUNE) 114 | print(len(list(ds_train.as_numpy_iterator()))) 115 | return ds_train 116 | 117 | 118 | @timeit 119 | def create_test(**kwargs): 120 | test, te_lbl = load_test(**kwargs) 121 | ds_test = to_dataset(test, te_lbl).map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE) 122 | ds_test = ds_test.cache() 123 | ds_test = ds_test.batch(batch_size) 124 | ds_test = ds_test.prefetch(tf.data.AUTOTUNE) 125 | return ds_test 126 | 127 | 128 | @timeit 129 | def train(train, test, level_name: str, **kwargs): 130 | model = tf.keras.models.Sequential([ 131 | tf.keras.layers.Input(shape=(28, 28, 1)), 132 | tf.keras.layers.Conv2D(level[level_name][0], (5, 5), activation='relu'), 133 | tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding='valid'), 134 | tf.keras.layers.Conv2D(level[level_name][1], (3, 3), activation='relu'), 135 | tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding='valid'), 136 | 137 | tf.keras.layers.Flatten(), 138 | tf.keras.layers.Dense(256, activation='relu'), 139 | 140 | tf.keras.layers.Dropout(0.8), 141 | tf.keras.layers.Dense(10) 142 | ]) 143 | 144 | model.compile( 145 | optimizer=tf.keras.optimizers.Adam(0.01), 146 | loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 147 | metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], 148 | ) 149 | 150 | if not reshape: 151 | callbacks = [] 152 | else: 153 | plateau = tf.keras.callbacks.ReduceLROnPlateau( 154 | monitor='val_loss', factor=0.1, patience=3, verbose=0, 155 | mode='auto', min_delta=0.0001, cooldown=0, min_lr=0 156 | ) 157 | checkpoint = 
tf.keras.callbacks.ModelCheckpoint( 158 | './checkpoints', monitor='val_loss', verbose=0, save_best_only=False, 159 | save_weights_only=False, mode='auto', save_freq='epoch', 160 | options=None 161 | ) 162 | stopping = tf.keras.callbacks.EarlyStopping( 163 | monitor='val_loss', min_delta=0, patience=5, verbose=0, 164 | mode='auto', baseline=None, restore_best_weights=True 165 | ) 166 | 167 | callbacks = [plateau, checkpoint] # , stopping] 168 | 169 | history = model.fit( 170 | train, 171 | epochs=epoch, 172 | validation_data=test, 173 | callbacks=callbacks, 174 | verbose=2 175 | ) 176 | model.save("./auto") 177 | return history, model 178 | 179 | 180 | def create_dataset(x, y): 181 | ds_train = to_dataset(x, y).map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE) 182 | ds_train = ds_train.cache() 183 | ds_train = ds_train.prefetch(tf.data.AUTOTUNE) 184 | return ds_train 185 | 186 | 187 | def do_it(level_name: str, **kwargs): 188 | if kwargs.get("tune_bs") == True: 189 | tr, tr_lbl = load_train(**kwargs) 190 | te, te_lbl = load_test(**kwargs) 191 | ds_train = create_dataset(tr, tr_lbl) 192 | ds_test = create_dataset(te, te_lbl) 193 | 194 | tune_batch_size(ds_test, ds_train, level_name) 195 | return None, None 196 | else: 197 | ds_train = create_train(**kwargs) 198 | ds_test = create_test(**kwargs) 199 | 200 | return train(ds_train, ds_test, level_name, **kwargs) 201 | 202 | 203 | @timeit 204 | def tune_batch_size(ds_test, ds_train, level_name): 205 | global batch_size, epoch 206 | batch_size = 16 207 | epoch = 2 208 | duration = 1e300 209 | selected_bs: int = batch_size 210 | with open('sweet_spot.txt', 'w') as f: 211 | for i in range(10): 212 | batch_size = batch_size * 2 213 | print(f"testing with batch_size = {batch_size}") 214 | current_train_ds = ds_train.batch(batch_size) 215 | current_test_ds = ds_test.batch(batch_size) 216 | start = time.time_ns() 217 | train(current_train_ds, current_test_ds, level_name) 218 | end = time.time_ns() 219 | d = end - start 
220 | if d < duration: 221 | duration = d 222 | selected_bs = batch_size 223 | duration = min(d, duration) 224 | f.write(f"{batch_size}, {d}\n") 225 | f.flush() 226 | f.write(f"#selected: {selected_bs}") 227 | print(f"Selected batch size: {selected_bs}") 228 | batch_size = selected_bs 229 | 230 | 231 | if __name__ == "__main__": 232 | # h, m = do_it('basic') 233 | h, m = do_it('basic', tune_bs=True) 234 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.15.0 2 | anyio==3.3.4 3 | appnope==0.1.2 4 | argon2-cffi==21.1.0 5 | astunparse==1.6.3 6 | attrs==21.2.0 7 | Babel==2.9.1 8 | backcall==0.2.0 9 | bleach==4.1.0 10 | cachetools==4.2.4 11 | certifi==2021.10.8 12 | cffi==1.15.0 13 | charset-normalizer==2.0.7 14 | clang==5.0 15 | cycler==0.11.0 16 | debugpy==1.5.1 17 | decorator==5.1.0 18 | defusedxml==0.7.1 19 | entrypoints==0.3 20 | flatbuffers==1.12 21 | gast==0.4.0 22 | google-auth==2.3.2 23 | google-auth-oauthlib==0.4.6 24 | google-pasta==0.2.0 25 | grpcio==1.41.1 26 | h5py==3.1.0 27 | idna==3.3 28 | ipykernel==6.4.2 29 | ipython==7.28.0 30 | ipython-genutils==0.2.0 31 | jedi==0.18.0 32 | Jinja2==3.0.2 33 | joblib==1.1.0 34 | json5==0.9.6 35 | jsonschema==4.1.2 36 | jupyter-client==7.0.6 37 | jupyter-core==4.9.1 38 | jupyter-server==1.11.1 39 | jupyterlab==3.2.1 40 | jupyterlab-pygments==0.1.2 41 | jupyterlab-server==2.8.2 42 | keras==2.6.0 43 | Keras-Preprocessing==1.1.2 44 | kiwisolver==1.3.2 45 | Markdown==3.3.4 46 | MarkupSafe==2.0.1 47 | matplotlib==3.4.3 48 | matplotlib-inline==0.1.3 49 | mistune==0.8.4 50 | nbclassic==0.3.4 51 | nbclient==0.5.4 52 | nbconvert==6.2.0 53 | nbformat==5.1.3 54 | nest-asyncio==1.5.1 55 | notebook==6.4.5 56 | numpy==1.19.5 57 | oauthlib==3.1.1 58 | opencv-python==4.5.4.58 59 | opt-einsum==3.3.0 60 | packaging==21.2 61 | pandocfilters==1.5.0 62 | parso==0.8.2 63 | pexpect==4.8.0 64 | 
pickleshare==0.7.5 65 | Pillow==8.4.0 66 | prometheus-client==0.11.0 67 | prompt-toolkit==3.0.21 68 | protobuf==3.19.1 69 | ptyprocess==0.7.0 70 | pyasn1==0.4.8 71 | pyasn1-modules==0.2.8 72 | pycparser==2.20 73 | Pygments==2.10.0 74 | pyparsing==2.4.7 75 | pyrsistent==0.18.0 76 | python-dateutil==2.8.2 77 | python-mnist==0.7 78 | pytz==2021.3 79 | pyzmq==22.3.0 80 | requests==2.26.0 81 | requests-oauthlib==1.3.0 82 | requests-unixsocket==0.2.0 83 | rsa==4.7.2 84 | scikit-learn==1.0.1 85 | scipy==1.7.1 86 | Send2Trash==1.8.0 87 | six==1.15.0 88 | sklearn==0.0 89 | sniffio==1.2.0 90 | tensorboard==2.7.0 91 | tensorboard-data-server==0.6.1 92 | tensorboard-plugin-wit==1.8.0 93 | tensorflow-addons==0.14.0 94 | tensorflow-estimator==2.6.0 95 | tensorflow==2.6.0 96 | termcolor==1.1.0 97 | terminado==0.12.1 98 | testpath==0.5.0 99 | tf2crf==0.1.33 100 | threadpoolctl==3.0.0 101 | tornado==6.1 102 | traitlets==5.1.1 103 | typeguard==2.13.0 104 | typing-extensions==3.7.4.3 105 | urllib3==1.26.7 106 | wcwidth==0.2.5 107 | webencodings==0.5.1 108 | websocket-client==1.2.1 109 | Werkzeug==2.0.2 110 | wrapt==1.12.1 111 | wxPython==4.1.1 112 | -------------------------------------------------------------------------------- /requirements_macos_intel.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.15.0 2 | anyio==3.3.4 3 | appnope==0.1.2 4 | argon2-cffi==21.1.0 5 | astunparse==1.6.3 6 | attrs==21.2.0 7 | Babel==2.9.1 8 | backcall==0.2.0 9 | bleach==4.1.0 10 | cachetools==4.2.4 11 | certifi==2021.10.8 12 | cffi==1.15.0 13 | charset-normalizer==2.0.7 14 | clang==5.0 15 | cycler==0.11.0 16 | debugpy==1.5.1 17 | decorator==5.1.0 18 | defusedxml==0.7.1 19 | entrypoints==0.3 20 | flatbuffers==1.12 21 | gast==0.4.0 22 | google-auth==2.3.2 23 | google-auth-oauthlib==0.4.6 24 | google-pasta==0.2.0 25 | grpcio==1.41.1 26 | h5py==3.1.0 27 | idna==3.3 28 | ipykernel==6.4.2 29 | ipython==7.28.0 30 | ipython-genutils==0.2.0 31 | jedi==0.18.0 
32 | Jinja2==3.0.2 33 | joblib==1.1.0 34 | json5==0.9.6 35 | jsonschema==4.1.2 36 | jupyter-client==7.0.6 37 | jupyter-core==4.9.1 38 | jupyter-server==1.11.1 39 | jupyterlab==3.2.1 40 | jupyterlab-pygments==0.1.2 41 | jupyterlab-server==2.8.2 42 | keras==2.6.0 43 | Keras-Preprocessing==1.1.2 44 | kiwisolver==1.3.2 45 | Markdown==3.3.4 46 | MarkupSafe==2.0.1 47 | matplotlib==3.4.3 48 | matplotlib-inline==0.1.3 49 | mistune==0.8.4 50 | nbclassic==0.3.4 51 | nbclient==0.5.4 52 | nbconvert==6.2.0 53 | nbformat==5.1.3 54 | nest-asyncio==1.5.1 55 | notebook==6.4.5 56 | numpy==1.19.5 57 | oauthlib==3.1.1 58 | opencv-python==4.5.4.58 59 | opt-einsum==3.3.0 60 | packaging==21.2 61 | pandocfilters==1.5.0 62 | parso==0.8.2 63 | pexpect==4.8.0 64 | pickleshare==0.7.5 65 | Pillow==8.4.0 66 | prometheus-client==0.11.0 67 | prompt-toolkit==3.0.21 68 | protobuf==3.19.1 69 | ptyprocess==0.7.0 70 | pyasn1==0.4.8 71 | pyasn1-modules==0.2.8 72 | pycparser==2.20 73 | Pygments==2.10.0 74 | pyparsing==2.4.7 75 | pyrsistent==0.18.0 76 | python-dateutil==2.8.2 77 | python-mnist==0.7 78 | pytz==2021.3 79 | pyzmq==22.3.0 80 | requests==2.26.0 81 | requests-oauthlib==1.3.0 82 | requests-unixsocket==0.2.0 83 | rsa==4.7.2 84 | scikit-learn==1.0.1 85 | scipy==1.7.1 86 | Send2Trash==1.8.0 87 | six==1.15.0 88 | sklearn==0.0 89 | sniffio==1.2.0 90 | tensorboard==2.7.0 91 | tensorboard-data-server==0.6.1 92 | tensorboard-plugin-wit==1.8.0 93 | tensorflow-addons==0.14.0 94 | tensorflow-estimator==2.6.0 95 | tensorflow-macos==2.6.0 96 | tensorflow-metal==0.2.0 97 | termcolor==1.1.0 98 | terminado==0.12.1 99 | testpath==0.5.0 100 | tf2crf==0.1.33 101 | threadpoolctl==3.0.0 102 | tornado==6.1 103 | traitlets==5.1.1 104 | typeguard==2.13.0 105 | typing-extensions==3.7.4.3 106 | urllib3==1.26.7 107 | wcwidth==0.2.5 108 | webencodings==0.5.1 109 | websocket-client==1.2.1 110 | Werkzeug==2.0.2 111 | wrapt==1.12.1 112 | wxPython==4.1.1 113 | 
-------------------------------------------------------------------------------- /time_util.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | def timeit(method): 5 | def timed(*args, **kw): 6 | ts = time.time() 7 | result = method(*args, **kw) 8 | te = time.time() 9 | in_kw = 'log_time' in kw 10 | if in_kw: 11 | name = kw.get('log_name', method.__name__.upper()) 12 | kw['log_time'][name] = int(0.5 + (te - ts) * 1000) 13 | if not in_kw or kw.get('log_verbose') is True: 14 | print('%r %2.2f ms' % (method.__name__, (te - ts) * 1000)) 15 | return result 16 | return timed --------------------------------------------------------------------------------
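The `timeit` decorator in `time_util.py` above has two modes: with no extra keyword arguments it prints a line like the README's `'train' 13899.57 ms`, and when a `log_time` dict is passed it records the rounded duration in milliseconds instead. A minimal, self-contained usage sketch (the decorator body is reproduced from `time_util.py` so the snippet runs standalone; `slow_add`, `timings`, and the `'add'` key are hypothetical names for illustration):

```python
import time


def timeit(method):
    # Reproduced from time_util.py so this sketch runs standalone.
    def timed(*args, **kw):
        ts = time.time()
        result = method(*args, **kw)
        te = time.time()
        in_kw = 'log_time' in kw
        if in_kw:
            # Record the rounded duration (ms) under log_name,
            # or under the upper-cased function name by default.
            name = kw.get('log_name', method.__name__.upper())
            kw['log_time'][name] = int(0.5 + (te - ts) * 1000)
        if not in_kw or kw.get('log_verbose') is True:
            print('%r %2.2f ms' % (method.__name__, (te - ts) * 1000))
        return result
    return timed


@timeit
def slow_add(a, b, **kw):
    # The decorated function must accept **kw, because timed() forwards
    # log_time/log_name/log_verbose along with the real arguments.
    time.sleep(0.05)
    return a + b


timings = {}
result = slow_add(1, 2, log_time=timings, log_name='add')
# timings now holds the elapsed milliseconds under 'add' (roughly 50 here);
# nothing is printed unless log_verbose=True is also passed.
```

Note that passing `log_time` silences the print unless `log_verbose=True` is given, which is how the training code can collect step timings quietly and report them all at the end.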