├── .gitignore ├── LICENSE ├── README.md ├── micromobilenet ├── __init__.py ├── architectures │ ├── BaseMobileNet.py │ ├── Config.py │ ├── MicroMobileNet.py │ ├── MilliMobileNet.py │ ├── MobileNet.py │ ├── NanoMobileNet.py │ ├── PicoMobileNet.py │ └── __init__.py ├── convert │ ├── Environment.py │ ├── LayerData.py │ ├── Loader.py │ ├── MobileNetConverter.py │ ├── __init__.py │ └── templates │ │ ├── BaseMobileNet.jinja │ │ ├── ops │ │ ├── argmax.jinja │ │ ├── conv3x3x1.jinja │ │ ├── depthwise_conv.jinja │ │ ├── dot.jinja │ │ ├── maxpool.jinja │ │ ├── mult3x3.jinja │ │ ├── pad.jinja │ │ ├── pointwise_conv.jinja │ │ └── softmax.jinja │ │ └── predict_file.jinja ├── converters.py ├── load.py ├── runner.py └── utils.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | .idea 3 | __pycache__ 4 | dist 5 | publish.sh 6 | setup_template.py 7 | micromobilenet.egg-info -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 eloquentarduino 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MicroMobileNet 2 | 3 | Variations of MobileNetV1 meant to run on resource-constrained embedded hardware (a.k.a. microcontrollers). 
4 | 5 | Blog post at: https://eloquentarduino.com/posts/micro-mobilenet 6 | 7 | 8 | ## Install 9 | 10 | ```bash 11 | pip install -U micromobilenet 12 | ``` 13 | 14 | ## Run 15 | 16 | ```python 17 | from micromobilenet import PicoMobileNet 18 | from micromobilenet import NanoMobileNet 19 | from micromobilenet import MicroMobileNet 20 | from micromobilenet import MilliMobileNet 21 | from micromobilenet import MobileNet 22 | 23 | 24 | if __name__ == '__main__': 25 | net = PicoMobileNet(num_classes=num_classes) 26 | net.config.learning_rate = 0.001 27 | net.config.batch_size = 32 28 | net.config.verbosity = 1 29 | net.config.checkpoint_min_accuracy = 0.65 30 | net.config.loss = "categorical_crossentropy" 31 | net.config.metrics = ["categorical_accuracy"] 32 | net.config.checkpoint_path = "checkpoints/pico" 33 | 34 | net.build() 35 | net.compile() 36 | # train_x is of shape (None, 96, 96, 1) 37 | # train_y is one-hot encoded 38 | net.fit(train_x, train_y, val_x, val_y, epochs=30) 39 | 40 | print(net.convert.to_cpp()) 41 | 42 | """ 43 | /** 44 | * "Compiled" implementation of modified MobileNet 45 | */ 46 | class PicoMobileNet { 47 | public: 48 | const uint16_t numInputs = 9216; 49 | const uint16_t numOutputs = 4; 50 | float outputs[4]; 51 | float arena[6936]; 52 | uint16_t output; 53 | float proba; 54 | 55 | /** 56 | * 57 | */ 58 | MobileNet() : output(0), proba(0) { 59 | for (uint16_t i = 0; i < numOutputs; i++) 60 | outputs[i] = 0; 61 | } 62 | 63 | /** 64 | * 65 | * @param input 66 | */ 67 | uint16_t predict(float *input) { 68 | ... 69 | }; 70 | """ 71 | ``` 72 | 73 | ## Deploy 74 | 75 | ```c++ 76 | // sample image is a float[96 * 96] array 77 | #include "sample_image.h" 78 | #include "MobileNet.h" 79 | 80 | MobileNet net; 81 | 82 | void setup() { 83 | Serial.begin(115200); 84 | Serial.println("MobileNet demo"); 85 | 86 | // no complicated setup! 
87 | } 88 | 89 | void loop() { 90 | size_t start = micros(); 91 | net.predict(sample_image); 92 | 93 | Serial.print("Predicted output = "); 94 | Serial.println(net.output); 95 | Serial.print("It took "); 96 | Serial.print(micros() - start); 97 | Serial.println(" us to run MobileNet"); 98 | delay(2000); 99 | } 100 | ``` -------------------------------------------------------------------------------- /micromobilenet/__init__.py: -------------------------------------------------------------------------------- 1 | from micromobilenet.architectures import * -------------------------------------------------------------------------------- /micromobilenet/architectures/BaseMobileNet.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os.path 3 | from typing import List, Generator, Iterable 4 | 5 | import numpy as np 6 | from cached_property import cached_property 7 | from keras import Sequential 8 | from keras.optimizers import Adam 9 | from keras.callbacks import ModelCheckpoint 10 | from keras.layers import Input, Conv2D, Reshape, Softmax, ZeroPadding2D, ReLU, DepthwiseConv2D 11 | from micromobilenet.convert.MobileNetConverter import MobileNetConverter 12 | from micromobilenet.architectures.Config import Config 13 | 14 | 15 | class BaseMobileNet: 16 | """ 17 | Base class for BaseMobileNet architectures 18 | """ 19 | def __init__(self, num_classes: int): 20 | """ 21 | 22 | """ 23 | self.num_classes = num_classes 24 | self.history = None 25 | self.layers = [] 26 | self.config = Config() 27 | self.model = None 28 | self.i = 1 29 | 30 | def __repr__(self): 31 | """ 32 | 33 | :return: 34 | """ 35 | self.model.summary() 36 | return str(self.model) 37 | 38 | @property 39 | def weights_file(self) -> str: 40 | """ 41 | Get path to weights file 42 | :return: 43 | """ 44 | return f"{self.config.checkpoint_path}.weights.h5" if self.config.checkpoint_path != "" else "" 45 | 46 | @cached_property 47 | def convert(self) -> MobileNetConverter: 48 | """ 49 | Get instance of C++ converter 50 | :return: 51 | """ 52 | return MobileNetConverter(self) 53 | 54 | def build(self): 55 | """ 56 | Generate model 57 | :return: 58 | """ 59 | self.i = 1 60 | self.model = Sequential() 61 | self.add(Input(shape=(96, 96, 1), name="input")) 62 | 63 | # add middle layers 64 | for layers in self.make_layers(): 65 | if isinstance(layers, Iterable): 66 | layers = list(layers) 67 | 68 | if not isinstance(layers, List): 69 | layers = [layers] 70 | 71 | for l in layers: 72 | self.add(l) 73 | 74 | # head 75 | self.add(Conv2D(self.num_classes, (1, 1), padding="same", name="conv2d_last")) 76 | self.add(Reshape((self.num_classes,), name="reshape")) 77 | self.add(Softmax(name="softmax")) 78 | 79 | def add(self, layer): 80 | """ 81 | Add layer 82 | :param layer: 83 | :return: 84 | """ 85 | self.model.add(layer) 86 | self.layers.append(layer) 87 | 88 | def load_weights(self, abort_on_fail: bool = True): 89 | """ 90 | Load checkpoint 91 | :return: 92 | """ 93 | assert self.config.checkpoint_path != "", "you must set net.config.checkpoint_path!" 
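        # weights are restored from f"{config.checkpoint_path}.weights.h5" (the same file ModelCheckpoint writes during fit()); a missing file only logs a warning unless abort_on_fail is True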
94 | 95 | if os.path.isfile(self.weights_file): 96 | self.model.load_weights(self.weights_file) 97 | else: 98 | logging.warning(f"Cannot load weight file {self.weights_file}") 99 | 100 | if abort_on_fail: 101 | raise FileNotFoundError(self.weights_file) 102 | 103 | return self 104 | 105 | def compile(self): 106 | """ 107 | Compile model 108 | :return: 109 | """ 110 | if self.model is None: 111 | self.build() 112 | 113 | self.model.compile( 114 | optimizer=Adam(learning_rate=self.config.learning_rate), 115 | loss=self.config.loss, 116 | metrics=self.config.metrics, 117 | ) 118 | 119 | return self 120 | 121 | def fit(self, train_x: np.ndarray, train_y: np.ndarray, val_x: np.ndarray, val_y: np.ndarray, epochs: int = 100): 122 | """ 123 | Fit model 124 | :param train_x: 125 | :param train_y: 126 | :param val_x: 127 | :param val_y: 128 | :return: 129 | """ 130 | callbacks = [] 131 | 132 | if self.weights_file != "": 133 | callbacks.append(ModelCheckpoint( 134 | self.weights_file, 135 | monitor=f"val_{self.config.metrics[0]}", 136 | verbose=1, 137 | save_best_only=True, 138 | save_weights_only=True, 139 | initial_value_threshold=self.config.checkpoint_min_accuracy 140 | )) 141 | 142 | self.history = self.model.fit( 143 | train_x, 144 | train_y, 145 | validation_data=(val_x, val_y), 146 | batch_size=self.config.batch_size, 147 | epochs=epochs, 148 | verbose=self.config.verbosity, 149 | callbacks=callbacks 150 | ) 151 | 152 | return self 153 | 154 | def predict(self, xs: np.ndarray) -> np.ndarray: 155 | """ 156 | Predict 157 | :param xs: 158 | :return: 159 | """ 160 | return self.model.predict(xs) 161 | 162 | def make_depthwise(self, filters: int, stride: int = 1, padding: str = "same") -> Generator: 163 | """ 164 | Generate depthwise + pointwise layers 165 | :param padding: 166 | :param filters: 167 | :param stride: 168 | :return: 169 | """ 170 | i = self.i 171 | self.i += 1 172 | 173 | if padding == "same": 174 | yield ZeroPadding2D(name=f"hidden_{i}__padding") 175 | 176 | yield DepthwiseConv2D((3, 3), padding="valid", strides=(stride, stride), use_bias=False, name=f"hidden_{i}__dw") 177 | yield ReLU(6., name=f"hidden_{i}__relu_1") 178 | yield Conv2D(filters, (1, 1), padding="same", strides=(1, 1), use_bias=False, name=f"hidden_{i}__pw") 179 | yield ReLU(6., name=f"hidden_{i}__relu_2") 180 | -------------------------------------------------------------------------------- /micromobilenet/architectures/Config.py: -------------------------------------------------------------------------------- 1 | class Config: 2 | """ 3 | BaseMobileNet config object 4 | """ 5 | def __init__(self): 6 | """ 7 | 8 | """ 9 | self.learning_rate = 0.001 10 | self.loss = "sparse_categorical_crossentropy" 11 | self.metrics = ["sparse_categorical_accuracy"] 12 | self.checkpoint_min_accuracy = 0.7 13 | self.batch_size = 32 14 | self.verbosity = 1 15 | self.checkpoint_path = "" 16 | -------------------------------------------------------------------------------- /micromobilenet/architectures/MicroMobileNet.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Conv2D, MaxPool2D, Dropout 2 | from micromobilenet.architectures.BaseMobileNet import BaseMobileNet 3 | 4 | 5 | class MicroMobileNet(BaseMobileNet): 6 | def make_layers(self): 7 | yield Conv2D(3, (3, 3), padding="valid", use_bias=False, strides=(2, 2), name="conv2d_0") 8 | yield self.make_depthwise(filters=6) 9 | yield self.make_depthwise(filters=12, stride=2) 10 | yield self.make_depthwise(filters=12) 11 | 
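        # the remaining blocks trade resolution for width: stride-2 depthwise convs shrink the feature map while the pointwise convs widen it to 24 channels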
yield self.make_depthwise(filters=24, stride=2) 12 | yield self.make_depthwise(filters=24) 13 | yield self.make_depthwise(filters=24, stride=2) 14 | yield self.make_depthwise(filters=24) 15 | yield self.make_depthwise(filters=24, stride=2) 16 | yield MaxPool2D((3, 3), name="maxpool_last") 17 | yield Dropout(0.1, name="dropout") 18 | -------------------------------------------------------------------------------- /micromobilenet/architectures/MilliMobileNet.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Conv2D, MaxPool2D, Dropout 2 | from micromobilenet.architectures.BaseMobileNet import BaseMobileNet 3 | 4 | 5 | class MilliMobileNet(BaseMobileNet): 6 | def make_layers(self): 7 | """ 8 | 9 | """ 10 | yield Conv2D(3, (3, 3), padding="valid", use_bias=False, strides=(2, 2), name="conv2d_0") 11 | yield self.make_depthwise(filters=6) 12 | yield self.make_depthwise(filters=12, stride=2) 13 | yield self.make_depthwise(filters=12) 14 | yield self.make_depthwise(filters=24, stride=2) 15 | yield self.make_depthwise(filters=24) 16 | yield self.make_depthwise(filters=48, stride=2) 17 | yield self.make_depthwise(filters=48) 18 | yield self.make_depthwise(filters=48, stride=2) 19 | yield self.make_depthwise(filters=48) 20 | yield MaxPool2D((3, 3), name="maxpool_last") 21 | yield Dropout(0.1, name="dropout") 22 | 23 | -------------------------------------------------------------------------------- /micromobilenet/architectures/MobileNet.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Conv2D, MaxPool2D, Dropout 2 | from micromobilenet.architectures.BaseMobileNet import BaseMobileNet 3 | 4 | 5 | class MobileNet(BaseMobileNet): 6 | def make_layers(self): 7 | """ 8 | 9 | """ 10 | yield Conv2D(3, (3, 3), padding="valid", use_bias=False, strides=(2, 2), name="conv2d_0") 11 | yield self.make_depthwise(filters=6) 12 | yield self.make_depthwise(filters=12, stride=2) 13 | yield self.make_depthwise(filters=12) 14 | yield self.make_depthwise(filters=24, stride=2) 15 | yield self.make_depthwise(filters=24) 16 | yield self.make_depthwise(filters=48, stride=2) 17 | yield self.make_depthwise(filters=48) 18 | yield self.make_depthwise(filters=48) 19 | yield self.make_depthwise(filters=48) 20 | yield self.make_depthwise(filters=48) 21 | yield self.make_depthwise(filters=96, stride=2) 22 | yield self.make_depthwise(filters=96) 23 | yield MaxPool2D((3, 3), name="maxpool_last") 24 | yield Dropout(0.1, name="dropout") 25 | 26 | -------------------------------------------------------------------------------- /micromobilenet/architectures/NanoMobileNet.py: -------------------------------------------------------------------------------- 1 | from keras.layers import Conv2D, MaxPool2D, Dropout 2 | from micromobilenet.architectures.BaseMobileNet import BaseMobileNet 3 | 4 | 5 | class NanoMobileNet(BaseMobileNet): 6 | def make_layers(self): 7 | yield Conv2D(3, (3, 3), padding="valid", use_bias=False, strides=(2, 2), name="conv2d_0") 8 | yield self.make_depthwise(filters=6, stride=2) 9 | yield self.make_depthwise(filters=12, stride=2) 10 | yield self.make_depthwise(filters=24, stride=2) 11 | yield self.make_depthwise(filters=24, stride=2) 12 | yield MaxPool2D((3, 3), name="maxpool_last") 13 | yield Dropout(0.1, name="dropout") 14 | -------------------------------------------------------------------------------- /micromobilenet/architectures/PicoMobileNet.py: 
-------------------------------------------------------------------------------- 1 | from keras.layers import Conv2D, MaxPool2D, Dropout 2 | from micromobilenet.architectures.BaseMobileNet import BaseMobileNet 3 | 4 | 5 | class PicoMobileNet(BaseMobileNet): 6 | def make_layers(self): 7 | yield Conv2D(3, (3, 3), padding="valid", use_bias=False, strides=(3, 3), name="conv2d_0") 8 | yield self.make_depthwise(filters=6, stride=2) 9 | yield self.make_depthwise(filters=12, stride=2) 10 | yield self.make_depthwise(filters=24, stride=2) 11 | yield MaxPool2D((4, 4), name="maxpool_last") 12 | yield Dropout(0.1, name="dropout") 13 | -------------------------------------------------------------------------------- /micromobilenet/architectures/__init__.py: -------------------------------------------------------------------------------- 1 | from micromobilenet.architectures.PicoMobileNet import PicoMobileNet 2 | from micromobilenet.architectures.NanoMobileNet import NanoMobileNet 3 | from micromobilenet.architectures.MicroMobileNet import MicroMobileNet 4 | from micromobilenet.architectures.MobileNet import MobileNet 5 | from micromobilenet.architectures.MilliMobileNet import MilliMobileNet 6 | -------------------------------------------------------------------------------- /micromobilenet/convert/Environment.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from jinja2 import Environment as Base 4 | from os.path import normpath, join, dirname 5 | from math import ceil, floor 6 | import numpy as np 7 | 8 | 9 | class Environment(Base): 10 | """ 11 | Override default Environment 12 | """ 13 | def __init__(self, *args, **kwargs): 14 | """ 15 | 16 | :param args: 17 | :param kwargs: 18 | """ 19 | filters = kwargs.pop("filters", {}) 20 | globals = kwargs.pop("globals", {}) 21 | kwargs.setdefault("extensions", []) 22 | 23 | super().__init__(*args, **kwargs) 24 | self._add_filters() 25 | self._add_globals() 26 | self.filters.update(filters) 27 | self.globals.update(globals) 28 | 29 | def join_path(self, template: str, parent: str) -> str: 30 | """ 31 | 32 | :param template: 33 | :param parent: 34 | :return: 35 | """ 36 | return normpath(join(dirname(parent), template)) 37 | 38 | def _add_filters(self): 39 | """ 40 | Add language-agnostic filters 41 | :return: 42 | """ 43 | def to_array(arr) -> str: 44 | values = ", ".join("%.11f" % x for x in arr.flatten()) 45 | return f"{{{values}}}" 46 | 47 | def to_weights_shape(weights: np.ndarray) -> str: 48 | h, w, c, d = weights.shape 49 | 50 | if d == 1: 51 | # depthwise kernel 52 | return f"[{c}][{h * w}]" 53 | 54 | return f"[{d}][{h * w * c}]" 55 | 56 | def to_weights_array(weights: np.ndarray) -> str: 57 | h, w, c, d = weights.shape 58 | 59 | if d == 1: 60 | # depthwise kernel 61 | values = ",\n".join(to_array(weights[:, :, i]) for i in range(c)) 62 | else: 63 | values = ",\n".join(to_array(weights[:, :, :, i]) for i in range(d)) 64 | 65 | return f"{{{values}}}" 66 | 67 | self.filters.update({ 68 | "ceil": ceil, 69 | "floor": floor, 70 | "to_array": to_array, 71 | "to_weights_shape": to_weights_shape, 72 | "to_weights_array": to_weights_array 73 | }) 74 | 75 | def _add_globals(self): 76 | """ 77 | Add language-agnostic globals 78 | :return: 79 | """ 80 | self.globals.update({ 81 | "np": np, 82 | "len": len, 83 | "zip": zip, 84 | "int": int, 85 | "ceil": ceil, 86 | "eps": 0.0001, 87 | "floor": floor, 88 | "range": range, 89 | "sorted": sorted, 90 | "enumerate": enumerate, 91 | "isinstance": 
isinstance 92 | }) 93 | -------------------------------------------------------------------------------- /micromobilenet/convert/LayerData.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LayerData: 5 | """ 6 | Wrap layer to get its data 7 | """ 8 | def __init__(self, layer): 9 | """ 10 | 11 | :param layer: 12 | """ 13 | self.layer = layer 14 | 15 | def __repr__(self): 16 | """ 17 | Proxy 18 | :return: 19 | """ 20 | return repr(self.layer) 21 | 22 | def __getattr__(self, item): 23 | """ 24 | Proxy 25 | :param item: 26 | :return: 27 | """ 28 | return getattr(self.layer, item) 29 | 30 | @property 31 | def io(self): 32 | return getattr(self.layer, "_io", None) 33 | 34 | @property 35 | def input_shape(self): 36 | return self.layer.input.shape[1:] 37 | 38 | @property 39 | def output_shape(self): 40 | return self.layer.output.shape[1:] 41 | 42 | @property 43 | def weights(self): 44 | return self.io["weights"] if self.io is not None else np.asarray(self.layer.weights[0]) 45 | 46 | @property 47 | def bias(self): 48 | return self.io["bias"] if self.io is not None else self.layer.bias.numpy() 49 | 50 | -------------------------------------------------------------------------------- /micromobilenet/convert/Loader.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple 2 | from os.path import sep 3 | from jinja2 import FileSystemLoader as Base 4 | 5 | 6 | class Loader(Base): 7 | """ 8 | Override default FileSystemLoader 9 | """ 10 | def get_source(self, environment: "Environment", template: str) -> Tuple[str, str, Callable[[], bool]]: 11 | """ 12 | 13 | :param environment: 14 | :param template: 15 | :return: 16 | """ 17 | # normalize path separator for Windows and Unix 18 | template = template.replace(sep, "/") 19 | 20 | if not template.endswith(".jinja"): 21 | template = f"{template}.jinja" 22 | 23 | return super().get_source(environment, template) -------------------------------------------------------------------------------- /micromobilenet/convert/MobileNetConverter.py: -------------------------------------------------------------------------------- 1 | from itertools import groupby 2 | from os.path import join, dirname, realpath 3 | from typing import Dict 4 | 5 | import numpy as np 6 | 7 | from micromobilenet.convert.Environment import Environment 8 | from micromobilenet.convert.Loader import Loader 9 | from micromobilenet.convert.LayerData import LayerData 10 | 11 | 12 | class MobileNetConverter: 13 | """ 14 | Convert BaseMobileNet to C++ 15 | """ 16 | def __init__(self, net: "BaseMobileNet"): 17 | """ 18 | 19 | :param net: 20 | """ 21 | self.net = net 22 | 23 | def to_cpp(self, classname: str = None) -> str: 24 | """ 25 | Convert to C++ 26 | :param classname: 27 | :return: 28 | """ 29 | root = join(dirname(realpath(__file__)), "templates") 30 | loader = Loader(root) 31 | env = Environment(loader=loader) 32 | template = env.get_template("BaseMobileNet") 33 | data = self.get_data() 34 | 35 | if classname is not None: 36 | data.update(classname=classname) 37 | 38 | # render template 39 | output = template.render(data) 40 | 41 | return output 42 | 43 | def get_data(self) -> Dict: 44 | """ 45 | Get data for code generation 46 | :return: 47 | """ 48 | model = self.net.model 49 | classname = self.net.__class__.__name__ 50 | layers = [LayerData(l) for l in self.net.layers] 51 | inputs = layers[0] 52 | conv_0 = LayerData(model.get_layer("conv2d_0")) 53 | 
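        # LayerData proxies each Keras layer so the Jinja templates can read its input/output shapes, weights and bias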
maxpool = LayerData(model.get_layer("maxpool_last")) 54 | conv_last = LayerData(model.get_layer("conv2d_last")) 55 | softmax = LayerData(model.get_layer("softmax")) 56 | 57 | # group hidden layers into chunks 58 | hidden_layers = [l for l in layers if l.name.startswith("hidden_")] 59 | hidden_layers = [list(ll) for _, ll in groupby(hidden_layers, key=lambda l: l.name.split("__")[0])] 60 | hidden_layers = [{l.name.split("__")[1]: l for l in chunk} for chunk in hidden_layers] 61 | 62 | num_inputs = np.product(inputs.shape[1:]) 63 | num_outputs = softmax.output_shape[-1] 64 | output_sizes = [np.product(l.output_shape) for l in layers[1:]] 65 | arena_size = max(output_sizes) 66 | 67 | return locals() -------------------------------------------------------------------------------- /micromobilenet/convert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eloquentarduino/micromobilenet/d3e2f8aab84a814d4a3cb03f0b69572f5bcd52fa/micromobilenet/convert/__init__.py -------------------------------------------------------------------------------- /micromobilenet/convert/templates/BaseMobileNet.jinja: -------------------------------------------------------------------------------- 1 | /** 2 | * "Compiled" implementation of modified MobileNet 3 | */ 4 | class {{ classname }} { 5 | public: 6 | const uint16_t numInputs = {{ num_inputs }}; 7 | const uint16_t numOutputs = {{ num_outputs }}; 8 | float outputs[{{ num_outputs }}]; 9 | float arena[{{ arena_size * 2 }}]; 10 | uint16_t output; 11 | float proba; 12 | 13 | /** 14 | * 15 | */ 16 | {{ classname }}() : output(0), proba(0) { 17 | for (uint16_t i = 0; i < numOutputs; i++) 18 | outputs[i] = 0; 19 | } 20 | 21 | /** 22 | * 23 | * @param input 24 | */ 25 | uint16_t predict(float *input) { 26 | float *ping = arena; 27 | float *pong = arena + {{ arena_size }}; 28 | 29 | // conv2d (0) 30 | for (int16_t d = 0; d < {{ conv_0.output_shape[2] }}; d++) 31 | this->conv2d_3x3x1(input, ping + {{ conv_0.output_shape[0] }} * {{ conv_0.output_shape[1] }} * d, conv2d_0_weights[d], {{ conv_0.input_shape[0] }}, {{ conv_0.strides[0] }}); 32 | 33 | {% for i, hidden in enumerate(hidden_layers) %} 34 | {% if 'padding' in hidden %} 35 | // padding ({{ i + 1 }}) 36 | for (int16_t d = 0; d < {{ hidden['padding'].input_shape[2] }}; d++) 37 | this->pad(ping + {{ hidden['padding'].input_shape[0] }} * {{ hidden['padding'].input_shape[1] }} * d, pong + {{ hidden['padding'].output_shape[0] }} * {{ hidden['padding'].output_shape[1] }} * d, {{ hidden['padding'].input_shape[0] }}); 38 | 39 | memcpy(ping, pong, sizeof(float) * {{ hidden['padding'].output_shape[0] }} * {{ hidden['padding'].output_shape[1] }} * {{ hidden['padding'].output_shape[2] }}); 40 | {% endif %} 41 | 42 | // depthwise ({{ i + 1 }}) 43 | for (int16_t d = 0; d < {{ hidden['dw'].input_shape[2] }}; d++) 44 | this->depthwise_conv(ping + {{ hidden['dw'].input_shape[0] }} * {{ hidden['dw'].input_shape[1] }} * d, pong + {{ hidden['pw'].input_shape[0] }} * {{ hidden['pw'].input_shape[1] }} * d, depthwise_{{ i + 1 }}_weights[d], {{ hidden['dw'].input_shape[0] }}, {{ hidden['dw'].strides[0] }}); 45 | 46 | // pointwise ({{ i + 1 }}) 47 | for (int16_t d = 0; d < {{ hidden['pw'].output_shape[2] }}; d++) 48 | this->pointwise_conv(pong, ping + {{ hidden['pw'].input_shape[0] }} * {{ hidden['pw'].input_shape[1] }} * d, pointwise_{{ i + 1 }}_weights[d], {{ hidden['dw'].output_shape[0] }}, {{ hidden['dw'].output_shape[2] }}); 49 | {% endfor %} 50 | 51 | 
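        // at this point ping holds the last pointwise feature maps: max-pool each channel into pong, reduce with the final 1x1 conv (one dot product per class), then softmax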
this->maxpool(ping, pong, {{ maxpool.input_shape[0] }}, {{ maxpool.input_shape[-1] }}); 52 | 53 | for (uint16_t d = 0; d < numOutputs; d++) 54 | this->dot(pong, ping + d, conv2d_last_weights[d], conv2d_last_bias[d], {{ conv_last.input_shape[-1] }}); 55 | 56 | this->softmax(ping, outputs, numOutputs); 57 | 58 | return this->argmax(); 59 | } 60 | 61 | {% include './ops/argmax' %} 62 | 63 | protected: 64 | const float conv2d_0_weights{{ conv_0.weights | to_weights_shape }} = {{ conv_0.weights | to_weights_array }}; 65 | {% for i, hidden in enumerate(hidden_layers) %} 66 | const float depthwise_{{ i + 1 }}_weights{{ hidden['dw'].weights | to_weights_shape }} = {{ hidden['dw'].weights | to_weights_array }}; 67 | const float pointwise_{{ i + 1 }}_weights{{ hidden['pw'].weights | to_weights_shape }} = {{ hidden['pw'].weights | to_weights_array }}; 68 | {% endfor %} 69 | const float conv2d_last_weights{{ conv_last.weights | to_weights_shape }} = {{ conv_last.weights | to_weights_array }}; 70 | const float conv2d_last_bias[{{ conv_last.bias | length }}] = {{ conv_last.bias | to_array }}; 71 | 72 | {% include './ops/mult3x3' %} 73 | {% include './ops/pad' %} 74 | {% include './ops/conv3x3x1' %} 75 | {% include './ops/depthwise_conv' %} 76 | {% include './ops/pointwise_conv' %} 77 | {% include './ops/maxpool' %} 78 | {% include './ops/dot' %} 79 | {% include './ops/softmax' %} 80 | }; -------------------------------------------------------------------------------- /micromobilenet/convert/templates/ops/argmax.jinja: -------------------------------------------------------------------------------- 1 | /** 2 | * Get index of max output 3 | */ 4 | uint16_t argmax() { 5 | this->output = 0; 6 | this->proba = outputs[0]; 7 | 8 | for (uint16_t i = 1; i < numOutputs; i++) { 9 | if (outputs[i] > this->proba) { 10 | this->proba = outputs[i]; 11 | this->output = i; 12 | } 13 | } 14 | 15 | return this->output; 16 | } 17 | -------------------------------------------------------------------------------- /micromobilenet/convert/templates/ops/conv3x3x1.jinja: -------------------------------------------------------------------------------- 1 | /** 2 | * Depthwise 3x3 convolution without ReLU 3 | * 4 | * @param input 5 | * @param output 6 | * @param kernel 7 | * @param width 8 | * @param stride 9 | */ 10 | void conv2d_3x3x1(float *input, float *output, const float *kernel, const uint16_t width, uint8_t stride) { 11 | uint16_t o = 0; 12 | 13 | for (uint16_t y = 0; y <= width - 3; y += stride) { 14 | const uint16_t offset = y * width; 15 | float *i = input + offset; 16 | 17 | for (uint16_t x = 0; x <= width - 3; x += stride) { 18 | output[o++] = this->mult3x3(i + x, kernel, width); 19 | } 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /micromobilenet/convert/templates/ops/depthwise_conv.jinja: -------------------------------------------------------------------------------- 1 | /** 2 | * Depthwise 3x3 convolution with ReLU 3 | * 4 | * @param inputs 5 | * @param outputs 6 | * @param kernel 7 | * @param width 8 | * @param stride 9 | */ 10 | void depthwise_conv(float *inputs, float *outputs, const float *kernel, const uint16_t width, uint8_t stride) { 11 | uint16_t o = 0; 12 | 13 | for (uint16_t y = 0; y <= width - 3; y += stride) { 14 | const uint16_t offset = y * width; 15 | float *i = inputs + offset; 16 | 17 | for (uint16_t x = 0; x <= width - 3; x += stride) { 18 | float val = this->mult3x3(i + x, kernel, width); 19 | 20 | if (val < 0) val = 0; 21 | else if (val > 6) 
val = 6; 22 | 23 | outputs[o++] = val; 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /micromobilenet/convert/templates/ops/dot.jinja: -------------------------------------------------------------------------------- 1 | /** 2 | * Dot product with bias 3 | * 4 | * @param inputs 5 | * @param outputs 6 | * @param kernel 7 | * @param bias 8 | * @param length 9 | */ 10 | void dot(float *inputs, float *outputs, const float *weights, const float bias, const uint16_t length) { 11 | float sum = 0; 12 | 13 | for (uint16_t i = 0; i < length; i++) 14 | sum += inputs[i] * weights[i]; 15 | 16 | outputs[0] = sum + bias; 17 | } 18 | -------------------------------------------------------------------------------- /micromobilenet/convert/templates/ops/maxpool.jinja: -------------------------------------------------------------------------------- 1 | /** 2 | * (Global) MaxPooling 3 | * 4 | * @param inputs 5 | * @param outputs 6 | * @param width 7 | * @param channels 8 | */ 9 | void maxpool(float *inputs, float *outputs, const uint16_t width, const uint16_t channels) { 10 | const uint16_t size = width * width; 11 | 12 | for (uint16_t c = 0; c < channels; c++) { 13 | const uint16_t offset = size * c; 14 | float *in = inputs + offset; 15 | float greatest = in[0]; 16 | 17 | for (uint16_t j = 1; j < size; j++) 18 | if (in[j] > greatest) 19 | greatest = in[j]; 20 | 21 | outputs[c] = greatest; 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /micromobilenet/convert/templates/ops/mult3x3.jinja: -------------------------------------------------------------------------------- 1 | /** 2 | * Multiply 3x3 kernel on single 3x3 image patch 3 | * 4 | * @param inputs 5 | * @param kernel 6 | * @param width 7 | */ 8 | inline float mult3x3(float *inputs, const float kernel[9], const uint16_t width) { 9 | const float *i1 = inputs; 10 | const float *i2 = inputs + width; 11 | const float *i3 = inputs + width + width; 12 | 13 | return i1[0] * kernel[0] + 14 | i1[1] * kernel[1] + 15 | i1[2] * kernel[2] + 16 | i2[0] * kernel[3] + 17 | i2[1] * kernel[4] + 18 | i2[2] * kernel[5] + 19 | i3[0] * kernel[6] + 20 | i3[1] * kernel[7] + 21 | i3[2] * kernel[8]; 22 | } 23 | -------------------------------------------------------------------------------- /micromobilenet/convert/templates/ops/pad.jinja: -------------------------------------------------------------------------------- 1 | /** 2 | * Zero padding 2D 3 | * 4 | * @param inputs 5 | * @param outputs 6 | * @param width 7 | */ 8 | void pad(float *inputs, float *outputs, uint16_t width) { 9 | const uint16_t paddedWidth = width + 2; 10 | uint16_t i = 0; 11 | uint16_t o = 0; 12 | 13 | // first row of zeros 14 | for (uint16_t x = 0; x < paddedWidth; x++) 15 | outputs[o++] = 0; 16 | 17 | for (uint16_t y = 0; y < width; y++) { 18 | outputs[o++] = 0; 19 | 20 | for (uint16_t x = 0; x < width; x++) 21 | outputs[o++] = inputs[i++]; 22 | 23 | outputs[o++] = 0; 24 | } 25 | 26 | // last row of zeros 27 | for (uint16_t x = 0; x < paddedWidth; x++) 28 | outputs[o++] = 0; 29 | } 30 | -------------------------------------------------------------------------------- /micromobilenet/convert/templates/ops/pointwise_conv.jinja: -------------------------------------------------------------------------------- 1 | /** 2 | * Pointwise 1x1 convolution with ReLU 3 | * 4 | * @param inputs 5 | * @param outputs 6 | * @param kernel 7 | * @param width 8 | * @param channels 9 | */ 10 | void pointwise_conv(float 
*inputs, float *outputs, const float *kernel, const uint16_t width, const uint16_t channels) { 11 | const uint16_t size = width * width; 12 | uint16_t o = 0; 13 | 14 | for (uint16_t y = 0; y < width; y += 1) { 15 | const uint16_t offset = y * width; 16 | for (uint16_t x = 0; x < width; x += 1) { 17 | float val = 0; 18 | 19 | for (uint16_t c = 0; c < channels; c++) 20 | val += inputs[(offset + x) + size * c] * kernel[c]; 21 | 22 | if (val < 0) val = 0; 23 | else if (val > 6) val = 6; 24 | 25 | outputs[o++] = val; 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /micromobilenet/convert/templates/ops/softmax.jinja: -------------------------------------------------------------------------------- 1 | /** 2 | * Softmax activation 3 | * 4 | * @param inputs 5 | * @param outputs 6 | * @param numOutputs 7 | */ 8 | void softmax(float *inputs, float *outputs, uint16_t numOutputs) { 9 | float sum = 0; 10 | 11 | for (uint16_t i = 0; i < numOutputs; i++) { 12 | const float e = exp(inputs[i]); 13 | outputs[i] = e; 14 | sum += e; 15 | } 16 | 17 | for (uint16_t i = 0; i < numOutputs; i++) 18 | outputs[i] /= sum; 19 | } 20 | -------------------------------------------------------------------------------- /micromobilenet/convert/templates/predict_file.jinja: -------------------------------------------------------------------------------- 1 | #define NUM_INPUTS 9216 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include "MobileNet.h" 8 | 9 | using namespace std; 10 | 11 | 12 | int main() { 13 | MobileNet net; 14 | unsigned char buffer[NUM_INPUTS]; 15 | float im[NUM_INPUTS]; 16 | FILE *file = fopen("X.bin", "rb"); 17 | 18 | while (fread(buffer, NUM_INPUTS, 1, file)) { 19 | for (int i = 0; i < NUM_INPUTS; i++) 20 | im[i] = buffer[i] / 255.0f; 21 | 22 | cout << net.predict(im) << endl; 23 | } 24 | 25 | fclose(file); 26 | return 0; 27 | } -------------------------------------------------------------------------------- /micromobilenet/converters.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def convert_xs(xs: np.ndarray, ys: np.ndarray) -> str: 5 | """ 6 | Convert one sample for each y 7 | :param ys: 8 | :param xs: 9 | :return: 10 | """ 11 | samples = [] 12 | ys = ys.argmax(axis=1) 13 | 14 | for y in range(ys.max()): 15 | sample = xs[ys == y][-1].flatten() 16 | data = ", ".join("%.4f" % xi for xi in sample) 17 | samples.append(f"float x{y}[{len(sample)}] = {{ {data} }};") 18 | 19 | return "\n".join(samples) -------------------------------------------------------------------------------- /micromobilenet/load.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import numpy as np 4 | from os import listdir 5 | from glob import glob 6 | from PIL import Image 7 | 8 | 9 | def load_folder(folder: str): 10 | """ 11 | Load images from folder as [0, 1] floats 12 | :param folder: 13 | :return: 14 | """ 15 | for filename in sorted(glob(f"{folder}/*.jpg") + glob(f"{folder}/*.jpeg")): 16 | yield np.asarray(Image.open(filename).convert("L"), dtype=float) / 255. 
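        # note: images are used at their on-disk resolution; the architectures in this package expect 96x96 grayscale inputs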
17 | 18 | 19 | def load_split(root: str, split_name: str): 20 | """ 21 | Load images from train/val/test folder 22 | :param root: 23 | :param split_name: 24 | :return: 25 | """ 26 | X = [] 27 | Y = [] 28 | folders = listdir(f"{root}/{split_name}") 29 | folders = [f"{root}/{split_name}/{f}" for f in folders if os.path.isdir(f"{root}/{split_name}/{f}")] 30 | 31 | for k, folder in enumerate(sorted(folders)): 32 | folder_x = list(load_folder(folder)) 33 | X += folder_x 34 | Y += [k] * len(folder_x) 35 | 36 | # shuffle inputs 37 | shuffle_mask = np.random.permutation(len(X)) 38 | X = np.asarray(X)[shuffle_mask] 39 | Y = np.asarray(Y)[shuffle_mask] 40 | 41 | return X, Y 42 | -------------------------------------------------------------------------------- /micromobilenet/runner.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os.path 3 | import warnings 4 | from tempfile import gettempdir 5 | from subprocess import check_output, check_call 6 | 7 | import numpy as np 8 | 9 | 10 | class Runner: 11 | """ 12 | Run C++ MobileNet 13 | """ 14 | def __init__(self, net): 15 | """ 16 | 17 | :param net: 18 | """ 19 | self.net = net 20 | 21 | def predict(self, X: np.ndarray) -> np.ndarray: 22 | """ 23 | Predict samples 24 | :param X: 25 | :return: 26 | """ 27 | root = os.path.abspath(gettempdir()) 28 | logging.warning(f"setting CWD={root}") 29 | 30 | # save input to binary file 31 | with open(os.path.join(root, "X.bin"), "wb") as file: 32 | file.write(X.flatten().astype(np.uint8).tobytes("C")) 33 | 34 | # save net to file 35 | with open(os.path.join(root, "MobileNet.h"), "w") as file: 36 | file.write(self.net.convert.to_cpp(classname="MobileNet")) 37 | 38 | # create C++ main file 39 | src = os.path.join(os.path.dirname(__file__), "convert", "templates", "predict_file.jinja") 40 | dest = os.path.join(root, "mobilenet_test.cpp") 41 | 42 | with open(src) as fin, open(dest, "w") as fout: 43 | fout.write(fin.read()) 44 | 45 | # compile (disable compilation warnings) 46 | with warnings.catch_warnings(): 47 | warnings.simplefilter("ignore") 48 | 49 | if check_call(["g++", "mobilenet_test.cpp", "-o", "mobilenet_test"], cwd=root) == 0: 50 | output = check_output(["./mobilenet_test"], cwd=root).decode() 51 | return np.asarray([int(x) for x in output.split("\n") if x.strip()]) 52 | -------------------------------------------------------------------------------- /micromobilenet/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import numpy as np 4 | 5 | global_vars = {} 6 | 7 | 8 | def update_globals(**kwargs): 9 | global global_vars 10 | 11 | global_vars.update(**kwargs) 12 | 13 | 14 | def get_globals(): 15 | global global_vars 16 | 17 | return global_vars 18 | 19 | 20 | def parse_npy(x: str): 21 | """ 22 | Parse Numpy output as array 23 | :param x: 24 | :return: 25 | """ 26 | x = re.sub(r"(\d)\s+([-0-9])", lambda m: f"{m.group(1)}, {m.group(2)}", x) 27 | x = re.sub(r"\]\s+\[", "],\n[", x) 28 | x = json.loads(x) 29 | 30 | return np.asarray(x) 31 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | 4 | setup( 5 | name='micromobilenet', 6 | packages=['micromobilenet'], 7 | version='1.0.0', 8 | license='MIT', 9 | description='Variations of MobileNetV1 for emdedded CPUs', 10 | author='Simone Salerno', 11 | 
author_email='support@eloquentarduino.com', 12 | url='https://github.com/eloquentarduino/micromobilenet', 13 | keywords=[ 14 | 'ML', 15 | 'Edge AI' 16 | ], 17 | install_requires=[ 18 | 'numpy', 19 | 'keras', 20 | 'tensorflow', 21 | 'Jinja2', 22 | 'Pillow', 23 | 'cached_property' 24 | ] 25 | ) --------------------------------------------------------------------------------
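For reference, a minimal end-to-end sketch tying the pieces above together (the `data/` folder layout, the epoch count, and the output paths are placeholders, not part of the package):

```python
import os

from micromobilenet import MilliMobileNet
from micromobilenet.load import load_split

# load 96x96 grayscale images arranged as data/<split>/<class_name>/*.jpg
train_x, train_y = load_split("data", "train")
val_x, val_y = load_split("data", "val")

# add the channel axis expected by the (96, 96, 1) input layer
train_x = train_x[..., None]
val_x = val_x[..., None]

net = MilliMobileNet(num_classes=int(train_y.max()) + 1)
os.makedirs("checkpoints", exist_ok=True)
net.config.checkpoint_path = "checkpoints/milli"

net.build()
net.compile()
# the default Config uses sparse_categorical_crossentropy, so the integer labels work as-is
net.fit(train_x, train_y, val_x, val_y, epochs=30)

# write the self-contained C++ implementation for the embedded project
with open("MobileNet.h", "w") as f:
    f.write(net.convert.to_cpp(classname="MobileNet"))
```

The generated header can then be dropped next to the Deploy sketch shown in the README above.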