├── 1face.png ├── 3faces.png ├── 4faces.png ├── Anchors.ipynb ├── Convert.ipynb ├── Inference.ipynb ├── LICENSE ├── README.markdown ├── anchors.npy ├── anchorsback.npy ├── blazeface.pth ├── blazeface.py └── blazefaceback.pth /1face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hollance/BlazeFace-PyTorch/852bfd8e3d44ed6775761105bdcead4ef389a538/1face.png -------------------------------------------------------------------------------- /3faces.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hollance/BlazeFace-PyTorch/852bfd8e3d44ed6775761105bdcead4ef389a538/3faces.png -------------------------------------------------------------------------------- /4faces.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hollance/BlazeFace-PyTorch/852bfd8e3d44ed6775761105bdcead4ef389a538/4faces.png -------------------------------------------------------------------------------- /Anchors.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Create anchor boxes\n", 8 | "\n", 9 | "This is the **SsdAnchorsCalculator** stage from the MediaPipe graph. It computes a list of anchor boxes. This only needs to be done once, so we will store these anchors into a lookup table.\n", 10 | "\n", 11 | "\n", 12 | "Using conda environnement:\n", 13 | "```\n", 14 | "conda create -c pytorch -c conda-forge -n BlazeConv 'pytorch=1.6' jupyter opencv matplotlib\n", 15 | "```\n", 16 | "```\n", 17 | "conda activate BlazeConv\n", 18 | "```\n", 19 | "```\n", 20 | "pip install tflite\n", 21 | "```" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 3, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import numpy as np" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "These are the options from [face_detection_mobile_gpu.pbtxt](https://github.com/google/mediapipe/blob/master/mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt).\n", 38 | "\n", 39 | "To understand what these options mean, see [ssd_anchors_calculator.proto](https://github.com/google/mediapipe/blob/master/mediapipe/calculators/tflite/ssd_anchors_calculator.proto)." 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 4, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "anchor_options = {\n", 49 | " \"num_layers\": 4,\n", 50 | " \"min_scale\": 0.1484375,\n", 51 | " \"max_scale\": 0.75,\n", 52 | " \"input_size_height\": 128,\n", 53 | " \"input_size_width\": 128,\n", 54 | " \"anchor_offset_x\": 0.5,\n", 55 | " \"anchor_offset_y\": 0.5,\n", 56 | " \"strides\": [8, 16, 16, 16],\n", 57 | " \"aspect_ratios\": [1.0],\n", 58 | " \"reduce_boxes_in_lowest_layer\": False,\n", 59 | " \"interpolated_scale_aspect_ratio\": 1.0,\n", 60 | " \"fixed_anchor_size\": True,\n", 61 | "}" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "These are the options from [face_detection_back_mobile_gpu.pbtxt](https://github.com/google/mediapipe/blob/master/mediapipe/graphs/face_detection/face_detection_back_mobile_gpu.pbtxt)." 
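,
    "\n",
    "The per-layer scales are interpolated between `min_scale` and `max_scale`, as in `calculate_scale` further down; a quick sketch with the back-model values from the next cell:\n",
    "\n",
    "```python\n",
    "min_scale, max_scale = 0.15625, 0.75\n",
    "strides = [16, 32, 32, 32]\n",
    "scales = [min_scale + (max_scale - min_scale) * i / (len(strides) - 1)\n",
    "          for i in range(len(strides))]\n",
    "print(scales)  # [0.15625, 0.3541..., 0.5520..., 0.75]\n",
    "```"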
69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 5, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "anchor_back_options = {\n", 78 | " \"num_layers\": 4,\n", 79 | " \"min_scale\": 0.15625,\n", 80 | " \"max_scale\": 0.75,\n", 81 | " \"input_size_height\": 256,\n", 82 | " \"input_size_width\": 256,\n", 83 | " \"anchor_offset_x\": 0.5,\n", 84 | " \"anchor_offset_y\": 0.5,\n", 85 | " \"strides\": [16, 32, 32, 32],\n", 86 | " \"aspect_ratios\": [1.0],\n", 87 | " \"reduce_boxes_in_lowest_layer\": False,\n", 88 | " \"interpolated_scale_aspect_ratio\": 1.0,\n", 89 | " \"fixed_anchor_size\": True,\n", 90 | "}" 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": {}, 96 | "source": [ 97 | "This is a literal translation of [ssd_anchors_calculator.cc](https://github.com/google/mediapipe/blob/master/mediapipe/calculators/tflite/ssd_anchors_calculator.cc):" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 6, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "def calculate_scale(min_scale, max_scale, stride_index, num_strides):\n", 107 | " return min_scale + (max_scale - min_scale) * stride_index / (num_strides - 1.0)\n", 108 | "\n", 109 | "\n", 110 | "def generate_anchors(options):\n", 111 | " strides_size = len(options[\"strides\"])\n", 112 | " assert options[\"num_layers\"] == strides_size\n", 113 | "\n", 114 | " anchors = []\n", 115 | " layer_id = 0\n", 116 | " while layer_id < strides_size:\n", 117 | " anchor_height = []\n", 118 | " anchor_width = []\n", 119 | " aspect_ratios = []\n", 120 | " scales = []\n", 121 | "\n", 122 | " # For same strides, we merge the anchors in the same order.\n", 123 | " last_same_stride_layer = layer_id\n", 124 | " while (last_same_stride_layer < strides_size) and \\\n", 125 | " (options[\"strides\"][last_same_stride_layer] == options[\"strides\"][layer_id]):\n", 126 | " scale = calculate_scale(options[\"min_scale\"],\n", 127 | " options[\"max_scale\"],\n", 128 | " last_same_stride_layer,\n", 129 | " strides_size)\n", 130 | "\n", 131 | " if last_same_stride_layer == 0 and options[\"reduce_boxes_in_lowest_layer\"]:\n", 132 | " # For first layer, it can be specified to use predefined anchors.\n", 133 | " aspect_ratios.append(1.0)\n", 134 | " aspect_ratios.append(2.0)\n", 135 | " aspect_ratios.append(0.5)\n", 136 | " scales.append(0.1)\n", 137 | " scales.append(scale)\n", 138 | " scales.append(scale) \n", 139 | " else:\n", 140 | " for aspect_ratio in options[\"aspect_ratios\"]:\n", 141 | " aspect_ratios.append(aspect_ratio)\n", 142 | " scales.append(scale)\n", 143 | "\n", 144 | " if options[\"interpolated_scale_aspect_ratio\"] > 0.0:\n", 145 | " scale_next = 1.0 if last_same_stride_layer == strides_size - 1 \\\n", 146 | " else calculate_scale(options[\"min_scale\"],\n", 147 | " options[\"max_scale\"],\n", 148 | " last_same_stride_layer + 1,\n", 149 | " strides_size)\n", 150 | " scales.append(np.sqrt(scale * scale_next))\n", 151 | " aspect_ratios.append(options[\"interpolated_scale_aspect_ratio\"])\n", 152 | "\n", 153 | " last_same_stride_layer += 1\n", 154 | "\n", 155 | " for i in range(len(aspect_ratios)):\n", 156 | " ratio_sqrts = np.sqrt(aspect_ratios[i])\n", 157 | " anchor_height.append(scales[i] / ratio_sqrts)\n", 158 | " anchor_width.append(scales[i] * ratio_sqrts) \n", 159 | " \n", 160 | " stride = options[\"strides\"][layer_id]\n", 161 | " feature_map_height = int(np.ceil(options[\"input_size_height\"] / stride))\n", 162 | " feature_map_width = 
int(np.ceil(options[\"input_size_width\"] / stride))\n", 163 | "\n", 164 | " for y in range(feature_map_height):\n", 165 | " for x in range(feature_map_width):\n", 166 | " for anchor_id in range(len(anchor_height)):\n", 167 | " x_center = (x + options[\"anchor_offset_x\"]) / feature_map_width\n", 168 | " y_center = (y + options[\"anchor_offset_y\"]) / feature_map_height\n", 169 | "\n", 170 | " new_anchor = [x_center, y_center, 0, 0]\n", 171 | " if options[\"fixed_anchor_size\"]:\n", 172 | " new_anchor[2] = 1.0\n", 173 | " new_anchor[3] = 1.0\n", 174 | " else:\n", 175 | " new_anchor[2] = anchor_width[anchor_id]\n", 176 | " new_anchor[3] = anchor_height[anchor_id]\n", 177 | " anchors.append(new_anchor)\n", 178 | "\n", 179 | " layer_id = last_same_stride_layer\n", 180 | "\n", 181 | " return anchors" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 7, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "anchors = generate_anchors(anchor_options)\n", 191 | "\n", 192 | "assert len(anchors) == 896\n", 193 | "\n", 194 | "anchors_back = generate_anchors(anchor_back_options)\n", 195 | "\n", 196 | "assert len(anchors_back) == 896\n" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "metadata": {}, 202 | "source": [ 203 | "Each anchor is `[x_center, y_center, width, height]` in normalized coordinates. For our use case, the width and height are always 1." 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 8, 209 | "metadata": { 210 | "scrolled": true 211 | }, 212 | "outputs": [ 213 | { 214 | "data": { 215 | "text/plain": [ 216 | "[[0.03125, 0.03125, 1.0, 1.0],\n", 217 | " [0.03125, 0.03125, 1.0, 1.0],\n", 218 | " [0.09375, 0.03125, 1.0, 1.0],\n", 219 | " [0.09375, 0.03125, 1.0, 1.0],\n", 220 | " [0.15625, 0.03125, 1.0, 1.0],\n", 221 | " [0.15625, 0.03125, 1.0, 1.0],\n", 222 | " [0.21875, 0.03125, 1.0, 1.0],\n", 223 | " [0.21875, 0.03125, 1.0, 1.0],\n", 224 | " [0.28125, 0.03125, 1.0, 1.0],\n", 225 | " [0.28125, 0.03125, 1.0, 1.0]]" 226 | ] 227 | }, 228 | "execution_count": 8, 229 | "metadata": {}, 230 | "output_type": "execute_result" 231 | } 232 | ], 233 | "source": [ 234 | "anchors[:10]" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 9, 240 | "metadata": {}, 241 | "outputs": [ 242 | { 243 | "data": { 244 | "text/plain": [ 245 | "[[0.03125, 0.03125, 1.0, 1.0],\n", 246 | " [0.03125, 0.03125, 1.0, 1.0],\n", 247 | " [0.09375, 0.03125, 1.0, 1.0],\n", 248 | " [0.09375, 0.03125, 1.0, 1.0],\n", 249 | " [0.15625, 0.03125, 1.0, 1.0],\n", 250 | " [0.15625, 0.03125, 1.0, 1.0],\n", 251 | " [0.21875, 0.03125, 1.0, 1.0],\n", 252 | " [0.21875, 0.03125, 1.0, 1.0],\n", 253 | " [0.28125, 0.03125, 1.0, 1.0],\n", 254 | " [0.28125, 0.03125, 1.0, 1.0]]" 255 | ] 256 | }, 257 | "execution_count": 9, 258 | "metadata": {}, 259 | "output_type": "execute_result" 260 | } 261 | ], 262 | "source": [ 263 | "anchors_back[:10]" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": {}, 269 | "source": [ 270 | "Run the \"FaceDetectionConfig\" test case from the MediaPipe repo:" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 6, 276 | "metadata": {}, 277 | "outputs": [ 278 | { 279 | "name": "stdout", 280 | "output_type": "stream", 281 | "text": [ 282 | "Number of errors: 0\n" 283 | ] 284 | } 285 | ], 286 | "source": [ 287 | "anchor_options_test = {\n", 288 | " \"num_layers\": 5,\n", 289 | " \"min_scale\": 0.1171875,\n", 290 | " \"max_scale\": 0.75,\n", 291 | " 
\"input_size_height\": 256,\n", 292 | " \"input_size_width\": 256,\n", 293 | " \"anchor_offset_x\": 0.5,\n", 294 | " \"anchor_offset_y\": 0.5,\n", 295 | " \"strides\": [8, 16, 32, 32, 32],\n", 296 | " \"aspect_ratios\": [1.0],\n", 297 | " \"reduce_boxes_in_lowest_layer\": False,\n", 298 | " \"interpolated_scale_aspect_ratio\": 1.0,\n", 299 | " \"fixed_anchor_size\": True,\n", 300 | "}\n", 301 | "\n", 302 | "anchors_test = generate_anchors(anchor_options_test)\n", 303 | "anchors_golden = np.loadtxt(\"./mediapipe/mediapipe/calculators/tflite/testdata/anchor_golden_file_0.txt\")\n", 304 | "\n", 305 | "assert len(anchors_test) == len(anchors_golden)\n", 306 | "print(\"Number of errors:\", (np.abs(anchors_test - anchors_golden) > 1e-5).sum())" 307 | ] 308 | }, 309 | { 310 | "cell_type": "markdown", 311 | "metadata": {}, 312 | "source": [ 313 | "Run the \"MobileSSDConfig\" test case from the MediaPipe repo:" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": 7, 319 | "metadata": {}, 320 | "outputs": [ 321 | { 322 | "name": "stdout", 323 | "output_type": "stream", 324 | "text": [ 325 | "Number of errors: 0\n" 326 | ] 327 | } 328 | ], 329 | "source": [ 330 | "anchor_options_test = {\n", 331 | " \"num_layers\": 6,\n", 332 | " \"min_scale\": 0.2,\n", 333 | " \"max_scale\": 0.95,\n", 334 | " \"input_size_height\": 300,\n", 335 | " \"input_size_width\": 300,\n", 336 | " \"anchor_offset_x\": 0.5,\n", 337 | " \"anchor_offset_y\": 0.5,\n", 338 | " \"strides\": [16, 32, 64, 128, 256, 512],\n", 339 | " \"aspect_ratios\": [1.0, 2.0, 0.5, 3.0, 0.3333],\n", 340 | " \"reduce_boxes_in_lowest_layer\": True,\n", 341 | " \"interpolated_scale_aspect_ratio\": 1.0,\n", 342 | " \"fixed_anchor_size\": False,\n", 343 | "}\n", 344 | "\n", 345 | "anchors_test = generate_anchors(anchor_options_test)\n", 346 | "anchors_golden = np.loadtxt(\"./mediapipe/mediapipe/calculators/tflite/testdata/anchor_golden_file_1.txt\")\n", 347 | "\n", 348 | "assert len(anchors_test) == len(anchors_golden)\n", 349 | "print(\"Number of errors:\", (np.abs(anchors_test - anchors_golden) > 1e-5).sum())" 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": {}, 355 | "source": [ 356 | "Save the anchors to a file:" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 10, 362 | "metadata": {}, 363 | "outputs": [], 364 | "source": [ 365 | "np.save(\"anchors.npy\", anchors)\n", 366 | "np.save(\"anchorsback.npy\", anchors_back)" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": null, 372 | "metadata": {}, 373 | "outputs": [], 374 | "source": [] 375 | } 376 | ], 377 | "metadata": { 378 | "kernelspec": { 379 | "display_name": "Python 3", 380 | "language": "python", 381 | "name": "python3" 382 | }, 383 | "language_info": { 384 | "codemirror_mode": { 385 | "name": "ipython", 386 | "version": 3 387 | }, 388 | "file_extension": ".py", 389 | "mimetype": "text/x-python", 390 | "name": "python", 391 | "nbconvert_exporter": "python", 392 | "pygments_lexer": "ipython3", 393 | "version": "3.8.5" 394 | } 395 | }, 396 | "nbformat": 4, 397 | "nbformat_minor": 2 398 | } 399 | -------------------------------------------------------------------------------- /Convert.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Convert TFLite model to PyTorch\n", 8 | "\n", 9 | "This uses the model **face_detection_front.tflite** from 
[MediaPipe](https://github.com/google/mediapipe/tree/master/mediapipe/models).\n", 10 | "\n", 11 | "Using conda environnement:\n", 12 | "```\n", 13 | "conda create -c pytorch -c conda-forge -n BlazeConv 'pytorch=1.6' jupyter opencv matplotlib\n", 14 | "```\n", 15 | "```\n", 16 | "conda activate BlazeConv\n", 17 | "```\n", 18 | "```\n", 19 | "pip install tflite\n", 20 | "```" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "## Convert front camera TFLite model" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 1, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "import os\n", 37 | "import numpy as np\n", 38 | "from collections import OrderedDict" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "### Get the weights from the TFLite file" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "Load the TFLite model using the FlatBuffers library:" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 2, 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "name": "stdout", 62 | "output_type": "stream", 63 | "text": [ 64 | "--2021-02-09 23:17:46-- https://github.com/google/mediapipe/raw/master/mediapipe/models/face_detection_front.tflite\n", 65 | "Résolution de github.com (github.com)… 140.82.121.3\n", 66 | "Connexion à github.com (github.com)|140.82.121.3|:443… connecté.\n", 67 | "requête HTTP transmise, en attente de la réponse… 302 Found\n", 68 | "Emplacement : https://raw.githubusercontent.com/google/mediapipe/master/mediapipe/models/face_detection_front.tflite [suivant]\n", 69 | "--2021-02-09 23:17:46-- https://raw.githubusercontent.com/google/mediapipe/master/mediapipe/models/face_detection_front.tflite\n", 70 | "Résolution de raw.githubusercontent.com (raw.githubusercontent.com)… 151.101.120.133\n", 71 | "Connexion à raw.githubusercontent.com (raw.githubusercontent.com)|151.101.120.133|:443… connecté.\n", 72 | "requête HTTP transmise, en attente de la réponse… 200 OK\n", 73 | "Taille : 229032 (224K) [application/octet-stream]\n", 74 | "Enregistre : «face_detection_front.tflite»\n", 75 | "\n", 76 | "face_detection_fron 100%[===================>] 223,66K --.-KB/s ds 0,01s \n", 77 | "\n", 78 | "En-tête de dernière modification manquant — horodatage arrêté.\n", 79 | "2021-02-09 23:17:46 (18,7 MB/s) - «face_detection_front.tflite» enregistré [229032/229032]\n", 80 | "\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "!wget -N https://github.com/google/mediapipe/raw/master/mediapipe/models/face_detection_front.tflite" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 3, 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "from tflite import Model\n", 95 | "\n", 96 | "front_data = open(\"./face_detection_front.tflite\", \"rb\").read()\n", 97 | "front_model = Model.GetRootAsModel(front_data, 0)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 4, 103 | "metadata": {}, 104 | "outputs": [ 105 | { 106 | "data": { 107 | "text/plain": [ 108 | "b'keras2tflite_facedetector-front.tflite.generated'" 109 | ] 110 | }, 111 | "execution_count": 4, 112 | "metadata": {}, 113 | "output_type": "execute_result" 114 | } 115 | ], 116 | "source": [ 117 | "front_subgraph = front_model.Subgraphs(0)\n", 118 | "front_subgraph.Name()" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 5, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "def 
get_shape(tensor):\n", 128 | " return [tensor.Shape(i) for i in range(tensor.ShapeLength())]" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "List all the tensors in the graph:" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 6, 141 | "metadata": {}, 142 | "outputs": [ 143 | { 144 | "name": "stdout", 145 | "output_type": "stream", 146 | "text": [ 147 | " 0 b'input' 0 0 [1, 128, 128, 3]\n", 148 | " 1 b'conv2d/Kernel' 1 1 [24, 5, 5, 3]\n", 149 | " 2 b'conv2d/Bias' 1 2 [24]\n", 150 | " 3 b'conv2d' 0 0 [1, 64, 64, 24]\n", 151 | " 4 b'activation' 0 0 [1, 64, 64, 24]\n", 152 | " 5 b'depthwise_conv2d/Kernel' 1 3 [1, 3, 3, 24]\n", 153 | " 6 b'depthwise_conv2d/Bias' 1 4 [24]\n", 154 | " 7 b'depthwise_conv2d' 0 0 [1, 64, 64, 24]\n", 155 | " 8 b'conv2d_1/Kernel' 1 5 [24, 1, 1, 24]\n", 156 | " 9 b'conv2d_1/Bias' 1 6 [24]\n", 157 | " 10 b'conv2d_1' 0 0 [1, 64, 64, 24]\n", 158 | " 11 b'add__xeno_compat__1' 0 0 [1, 64, 64, 24]\n", 159 | " 12 b'activation_1' 0 0 [1, 64, 64, 24]\n", 160 | " 13 b'depthwise_conv2d_1/Kernel' 1 7 [1, 3, 3, 24]\n", 161 | " 14 b'depthwise_conv2d_1/Bias' 1 8 [24]\n", 162 | " 15 b'depthwise_conv2d_1' 0 0 [1, 64, 64, 24]\n", 163 | " 16 b'conv2d_2/Kernel' 1 9 [28, 1, 1, 24]\n", 164 | " 17 b'conv2d_2/Bias' 1 10 [28]\n", 165 | " 18 b'conv2d_2' 0 0 [1, 64, 64, 28]\n", 166 | " 19 b'channel_padding/Paddings' 2 11 [4, 2]\n", 167 | " 20 b'channel_padding' 0 0 [1, 64, 64, 28]\n", 168 | " 21 b'add_1__xeno_compat__1' 0 0 [1, 64, 64, 28]\n", 169 | " 22 b'activation_2' 0 0 [1, 64, 64, 28]\n", 170 | " 23 b'depthwise_conv2d_2/Kernel' 1 12 [1, 3, 3, 28]\n", 171 | " 24 b'depthwise_conv2d_2/Bias' 1 13 [28]\n", 172 | " 25 b'depthwise_conv2d_2' 0 0 [1, 32, 32, 28]\n", 173 | " 26 b'max_pooling2d' 0 0 [1, 32, 32, 28]\n", 174 | " 27 b'conv2d_3/Kernel' 1 14 [32, 1, 1, 28]\n", 175 | " 28 b'conv2d_3/Bias' 1 15 [32]\n", 176 | " 29 b'conv2d_3' 0 0 [1, 32, 32, 32]\n", 177 | " 30 b'channel_padding_1/Paddings' 2 16 [4, 2]\n", 178 | " 31 b'channel_padding_1' 0 0 [1, 32, 32, 32]\n", 179 | " 32 b'add_2__xeno_compat__1' 0 0 [1, 32, 32, 32]\n", 180 | " 33 b'activation_3' 0 0 [1, 32, 32, 32]\n", 181 | " 34 b'depthwise_conv2d_3/Kernel' 1 17 [1, 3, 3, 32]\n", 182 | " 35 b'depthwise_conv2d_3/Bias' 1 18 [32]\n", 183 | " 36 b'depthwise_conv2d_3' 0 0 [1, 32, 32, 32]\n", 184 | " 37 b'conv2d_4/Kernel' 1 19 [36, 1, 1, 32]\n", 185 | " 38 b'conv2d_4/Bias' 1 20 [36]\n", 186 | " 39 b'conv2d_4' 0 0 [1, 32, 32, 36]\n", 187 | " 40 b'channel_padding_2/Paddings' 2 21 [4, 2]\n", 188 | " 41 b'channel_padding_2' 0 0 [1, 32, 32, 36]\n", 189 | " 42 b'add_3__xeno_compat__1' 0 0 [1, 32, 32, 36]\n", 190 | " 43 b'activation_4' 0 0 [1, 32, 32, 36]\n", 191 | " 44 b'depthwise_conv2d_4/Kernel' 1 22 [1, 3, 3, 36]\n", 192 | " 45 b'depthwise_conv2d_4/Bias' 1 23 [36]\n", 193 | " 46 b'depthwise_conv2d_4' 0 0 [1, 32, 32, 36]\n", 194 | " 47 b'conv2d_5/Kernel' 1 24 [42, 1, 1, 36]\n", 195 | " 48 b'conv2d_5/Bias' 1 25 [42]\n", 196 | " 49 b'conv2d_5' 0 0 [1, 32, 32, 42]\n", 197 | " 50 b'channel_padding_3/Paddings' 2 26 [4, 2]\n", 198 | " 51 b'channel_padding_3' 0 0 [1, 32, 32, 42]\n", 199 | " 52 b'add_4__xeno_compat__1' 0 0 [1, 32, 32, 42]\n", 200 | " 53 b'activation_5' 0 0 [1, 32, 32, 42]\n", 201 | " 54 b'depthwise_conv2d_5/Kernel' 1 27 [1, 3, 3, 42]\n", 202 | " 55 b'depthwise_conv2d_5/Bias' 1 28 [42]\n", 203 | " 56 b'depthwise_conv2d_5' 0 0 [1, 16, 16, 42]\n", 204 | " 57 b'max_pooling2d_1' 0 0 [1, 16, 16, 42]\n", 205 | " 58 b'conv2d_6/Kernel' 1 29 [48, 1, 1, 42]\n", 206 | " 
59 b'conv2d_6/Bias' 1 30 [48]\n", 207 | " 60 b'conv2d_6' 0 0 [1, 16, 16, 48]\n", 208 | " 61 b'channel_padding_4/Paddings' 2 31 [4, 2]\n", 209 | " 62 b'channel_padding_4' 0 0 [1, 16, 16, 48]\n", 210 | " 63 b'add_5__xeno_compat__1' 0 0 [1, 16, 16, 48]\n", 211 | " 64 b'activation_6' 0 0 [1, 16, 16, 48]\n", 212 | " 65 b'depthwise_conv2d_6/Kernel' 1 32 [1, 3, 3, 48]\n", 213 | " 66 b'depthwise_conv2d_6/Bias' 1 33 [48]\n", 214 | " 67 b'depthwise_conv2d_6' 0 0 [1, 16, 16, 48]\n", 215 | " 68 b'conv2d_7/Kernel' 1 34 [56, 1, 1, 48]\n", 216 | " 69 b'conv2d_7/Bias' 1 35 [56]\n", 217 | " 70 b'conv2d_7' 0 0 [1, 16, 16, 56]\n", 218 | " 71 b'channel_padding_5/Paddings' 2 36 [4, 2]\n", 219 | " 72 b'channel_padding_5' 0 0 [1, 16, 16, 56]\n", 220 | " 73 b'add_6__xeno_compat__1' 0 0 [1, 16, 16, 56]\n", 221 | " 74 b'activation_7' 0 0 [1, 16, 16, 56]\n", 222 | " 75 b'depthwise_conv2d_7/Kernel' 1 37 [1, 3, 3, 56]\n", 223 | " 76 b'depthwise_conv2d_7/Bias' 1 38 [56]\n", 224 | " 77 b'depthwise_conv2d_7' 0 0 [1, 16, 16, 56]\n", 225 | " 78 b'conv2d_8/Kernel' 1 39 [64, 1, 1, 56]\n", 226 | " 79 b'conv2d_8/Bias' 1 40 [64]\n", 227 | " 80 b'conv2d_8' 0 0 [1, 16, 16, 64]\n", 228 | " 81 b'channel_padding_6/Paddings' 2 41 [4, 2]\n", 229 | " 82 b'channel_padding_6' 0 0 [1, 16, 16, 64]\n", 230 | " 83 b'add_7__xeno_compat__1' 0 0 [1, 16, 16, 64]\n", 231 | " 84 b'activation_8' 0 0 [1, 16, 16, 64]\n", 232 | " 85 b'depthwise_conv2d_8/Kernel' 1 42 [1, 3, 3, 64]\n", 233 | " 86 b'depthwise_conv2d_8/Bias' 1 43 [64]\n", 234 | " 87 b'depthwise_conv2d_8' 0 0 [1, 16, 16, 64]\n", 235 | " 88 b'conv2d_9/Kernel' 1 44 [72, 1, 1, 64]\n", 236 | " 89 b'conv2d_9/Bias' 1 45 [72]\n", 237 | " 90 b'conv2d_9' 0 0 [1, 16, 16, 72]\n", 238 | " 91 b'channel_padding_7/Paddings' 2 46 [4, 2]\n", 239 | " 92 b'channel_padding_7' 0 0 [1, 16, 16, 72]\n", 240 | " 93 b'add_8__xeno_compat__1' 0 0 [1, 16, 16, 72]\n", 241 | " 94 b'activation_9' 0 0 [1, 16, 16, 72]\n", 242 | " 95 b'depthwise_conv2d_9/Kernel' 1 47 [1, 3, 3, 72]\n", 243 | " 96 b'depthwise_conv2d_9/Bias' 1 48 [72]\n", 244 | " 97 b'depthwise_conv2d_9' 0 0 [1, 16, 16, 72]\n", 245 | " 98 b'conv2d_10/Kernel' 1 49 [80, 1, 1, 72]\n", 246 | " 99 b'conv2d_10/Bias' 1 50 [80]\n", 247 | "100 b'conv2d_10' 0 0 [1, 16, 16, 80]\n", 248 | "101 b'channel_padding_8/Paddings' 2 51 [4, 2]\n", 249 | "102 b'channel_padding_8' 0 0 [1, 16, 16, 80]\n", 250 | "103 b'add_9__xeno_compat__1' 0 0 [1, 16, 16, 80]\n", 251 | "104 b'activation_10' 0 0 [1, 16, 16, 80]\n", 252 | "105 b'depthwise_conv2d_10/Kernel' 1 52 [1, 3, 3, 80]\n", 253 | "106 b'depthwise_conv2d_10/Bias' 1 53 [80]\n", 254 | "107 b'depthwise_conv2d_10' 0 0 [1, 16, 16, 80]\n", 255 | "108 b'conv2d_11/Kernel' 1 54 [88, 1, 1, 80]\n", 256 | "109 b'conv2d_11/Bias' 1 55 [88]\n", 257 | "110 b'conv2d_11' 0 0 [1, 16, 16, 88]\n", 258 | "111 b'channel_padding_9/Paddings' 2 56 [4, 2]\n", 259 | "112 b'channel_padding_9' 0 0 [1, 16, 16, 88]\n", 260 | "113 b'add_10__xeno_compat__1' 0 0 [1, 16, 16, 88]\n", 261 | "114 b'activation_11' 0 0 [1, 16, 16, 88]\n", 262 | "115 b'depthwise_conv2d_11/Kernel' 1 57 [1, 3, 3, 88]\n", 263 | "116 b'depthwise_conv2d_11/Bias' 1 58 [88]\n", 264 | "117 b'depthwise_conv2d_11' 0 0 [1, 8, 8, 88]\n", 265 | "118 b'max_pooling2d_2' 0 0 [1, 8, 8, 88]\n", 266 | "119 b'conv2d_12/Kernel' 1 59 [96, 1, 1, 88]\n", 267 | "120 b'conv2d_12/Bias' 1 60 [96]\n", 268 | "121 b'conv2d_12' 0 0 [1, 8, 8, 96]\n", 269 | "122 b'channel_padding_10/Paddings' 2 61 [4, 2]\n", 270 | "123 b'channel_padding_10' 0 0 [1, 8, 8, 96]\n", 271 | "124 b'add_11__xeno_compat__1' 0 0 [1, 8, 8, 96]\n", 
272 | "125 b'activation_12' 0 0 [1, 8, 8, 96]\n", 273 | "126 b'depthwise_conv2d_12/Kernel' 1 62 [1, 3, 3, 96]\n", 274 | "127 b'depthwise_conv2d_12/Bias' 1 63 [96]\n", 275 | "128 b'depthwise_conv2d_12' 0 0 [1, 8, 8, 96]\n", 276 | "129 b'conv2d_13/Kernel' 1 64 [96, 1, 1, 96]\n", 277 | "130 b'conv2d_13/Bias' 1 65 [96]\n", 278 | "131 b'conv2d_13' 0 0 [1, 8, 8, 96]\n", 279 | "132 b'add_12__xeno_compat__1' 0 0 [1, 8, 8, 96]\n", 280 | "133 b'activation_13' 0 0 [1, 8, 8, 96]\n", 281 | "134 b'depthwise_conv2d_13/Kernel' 1 66 [1, 3, 3, 96]\n", 282 | "135 b'depthwise_conv2d_13/Bias' 1 67 [96]\n", 283 | "136 b'depthwise_conv2d_13' 0 0 [1, 8, 8, 96]\n", 284 | "137 b'conv2d_14/Kernel' 1 68 [96, 1, 1, 96]\n", 285 | "138 b'conv2d_14/Bias' 1 69 [96]\n", 286 | "139 b'conv2d_14' 0 0 [1, 8, 8, 96]\n", 287 | "140 b'add_13__xeno_compat__1' 0 0 [1, 8, 8, 96]\n", 288 | "141 b'activation_14' 0 0 [1, 8, 8, 96]\n", 289 | "142 b'depthwise_conv2d_14/Kernel' 1 70 [1, 3, 3, 96]\n", 290 | "143 b'depthwise_conv2d_14/Bias' 1 71 [96]\n", 291 | "144 b'depthwise_conv2d_14' 0 0 [1, 8, 8, 96]\n", 292 | "145 b'conv2d_15/Kernel' 1 72 [96, 1, 1, 96]\n", 293 | "146 b'conv2d_15/Bias' 1 73 [96]\n", 294 | "147 b'conv2d_15' 0 0 [1, 8, 8, 96]\n", 295 | "148 b'add_14__xeno_compat__1' 0 0 [1, 8, 8, 96]\n", 296 | "149 b'activation_15' 0 0 [1, 8, 8, 96]\n", 297 | "150 b'depthwise_conv2d_15/Kernel' 1 74 [1, 3, 3, 96]\n", 298 | "151 b'depthwise_conv2d_15/Bias' 1 75 [96]\n", 299 | "152 b'depthwise_conv2d_15' 0 0 [1, 8, 8, 96]\n", 300 | "153 b'conv2d_16/Kernel' 1 76 [96, 1, 1, 96]\n", 301 | "154 b'conv2d_16/Bias' 1 77 [96]\n", 302 | "155 b'conv2d_16' 0 0 [1, 8, 8, 96]\n", 303 | "156 b'add_15__xeno_compat__1' 0 0 [1, 8, 8, 96]\n", 304 | "157 b'activation_16' 0 0 [1, 8, 8, 96]\n", 305 | "158 b'classificator_8/Kernel' 1 78 [2, 1, 1, 88]\n", 306 | "159 b'classificator_8/Bias' 1 79 [2]\n", 307 | "160 b'classificator_8' 0 0 [1, 16, 16, 2]\n", 308 | "161 b'classificator_16/Kernel' 1 80 [6, 1, 1, 96]\n", 309 | "162 b'classificator_16/Bias' 1 81 [6]\n", 310 | "163 b'classificator_16' 0 0 [1, 8, 8, 6]\n", 311 | "164 b'regressor_8/Kernel' 1 82 [32, 1, 1, 88]\n", 312 | "165 b'regressor_8/Bias' 1 83 [32]\n", 313 | "166 b'regressor_8' 0 0 [1, 16, 16, 32]\n", 314 | "167 b'regressor_16/Kernel' 1 84 [96, 1, 1, 96]\n", 315 | "168 b'regressor_16/Bias' 1 85 [96]\n", 316 | "169 b'regressor_16' 0 0 [1, 8, 8, 96]\n", 317 | "170 b'reshape' 0 0 [1, 512, 1]\n", 318 | "171 b'reshape_2' 0 0 [1, 384, 1]\n", 319 | "172 b'reshape_1' 0 0 [1, 512, 16]\n", 320 | "173 b'reshape_3' 0 0 [1, 384, 16]\n", 321 | "174 b'classificators' 0 0 [1, 896, 1]\n", 322 | "175 b'regressors' 0 0 [1, 896, 16]\n", 323 | "176 b'conv2d_3/Bias_dequantize' 0 0 [32]\n", 324 | "177 b'conv2d_16/Bias_dequantize' 0 0 [96]\n", 325 | "178 b'conv2d_2/Bias_dequantize' 0 0 [28]\n", 326 | "179 b'depthwise_conv2d_3/Bias_dequantize' 0 0 [32]\n", 327 | "180 b'depthwise_conv2d_14/Bias_dequantize' 0 0 [96]\n", 328 | "181 b'classificator_16/Kernel_dequantize' 0 0 [6, 1, 1, 96]\n", 329 | "182 b'conv2d_9/Bias_dequantize' 0 0 [72]\n", 330 | "183 b'regressor_16/Bias_dequantize' 0 0 [96]\n", 331 | "184 b'depthwise_conv2d_2/Bias_dequantize' 0 0 [28]\n", 332 | "185 b'depthwise_conv2d/Bias_dequantize' 0 0 [24]\n", 333 | "186 b'depthwise_conv2d_15/Kernel_dequantize' 0 0 [1, 3, 3, 96]\n", 334 | "187 b'conv2d_8/Kernel_dequantize' 0 0 [64, 1, 1, 56]\n", 335 | "188 b'depthwise_conv2d_9/Bias_dequantize' 0 0 [72]\n", 336 | "189 b'depthwise_conv2d_1/Kernel_dequantize' 0 0 [1, 3, 3, 24]\n", 337 | "190 b'regressor_8/Kernel_dequantize' 0 
0 [32, 1, 1, 88]\n", 338 | "191 b'conv2d_15/Bias_dequantize' 0 0 [96]\n", 339 | "192 b'depthwise_conv2d_8/Kernel_dequantize' 0 0 [1, 3, 3, 64]\n", 340 | "193 b'conv2d/Bias_dequantize' 0 0 [24]\n", 341 | "194 b'conv2d_16/Kernel_dequantize' 0 0 [96, 1, 1, 96]\n", 342 | "195 b'conv2d_10/Bias_dequantize' 0 0 [80]\n", 343 | "196 b'depthwise_conv2d_13/Bias_dequantize' 0 0 [96]\n", 344 | "197 b'conv2d_4/Bias_dequantize' 0 0 [36]\n", 345 | "198 b'depthwise_conv2d_14/Kernel_dequantize' 0 0 [1, 3, 3, 96]\n", 346 | "199 b'conv2d_9/Kernel_dequantize' 0 0 [72, 1, 1, 64]\n", 347 | "200 b'depthwise_conv2d_10/Bias_dequantize' 0 0 [80]\n", 348 | "201 b'conv2d_3/Kernel_dequantize' 0 0 [32, 1, 1, 28]\n", 349 | "202 b'depthwise_conv2d_4/Bias_dequantize' 0 0 [36]\n", 350 | "203 b'conv2d_1/Bias_dequantize' 0 0 [24]\n", 351 | "204 b'conv2d_6/Bias_dequantize' 0 0 [48]\n", 352 | "205 b'depthwise_conv2d_9/Kernel_dequantize' 0 0 [1, 3, 3, 72]\n", 353 | "206 b'depthwise_conv2d_3/Kernel_dequantize' 0 0 [1, 3, 3, 32]\n", 354 | "207 b'conv2d_2/Kernel_dequantize' 0 0 [28, 1, 1, 24]\n", 355 | "208 b'regressor_16/Kernel_dequantize' 0 0 [96, 1, 1, 96]\n", 356 | "209 b'conv2d_12/Bias_dequantize' 0 0 [96]\n", 357 | "210 b'conv2d_5/Bias_dequantize' 0 0 [42]\n", 358 | "211 b'depthwise_conv2d_6/Bias_dequantize' 0 0 [48]\n", 359 | "212 b'depthwise_conv2d_2/Kernel_dequantize' 0 0 [1, 3, 3, 28]\n", 360 | "213 b'conv2d_14/Bias_dequantize' 0 0 [96]\n", 361 | "214 b'depthwise_conv2d/Kernel_dequantize' 0 0 [1, 3, 3, 24]\n", 362 | "215 b'conv2d_4/Kernel_dequantize' 0 0 [36, 1, 1, 32]\n", 363 | "216 b'depthwise_conv2d_5/Bias_dequantize' 0 0 [42]\n", 364 | "217 b'conv2d_11/Bias_dequantize' 0 0 [88]\n", 365 | "218 b'depthwise_conv2d_12/Bias_dequantize' 0 0 [96]\n", 366 | "219 b'depthwise_conv2d_4/Kernel_dequantize' 0 0 [1, 3, 3, 36]\n", 367 | "220 b'conv2d_15/Kernel_dequantize' 0 0 [96, 1, 1, 96]\n", 368 | "221 b'conv2d_1/Kernel_dequantize' 0 0 [24, 1, 1, 24]\n", 369 | "222 b'classificator_8/Bias_dequantize' 0 0 [2]\n", 370 | "223 b'depthwise_conv2d_13/Kernel_dequantize' 0 0 [1, 3, 3, 96]\n", 371 | "224 b'conv2d/Kernel_dequantize' 0 0 [24, 5, 5, 3]\n", 372 | "225 b'conv2d_10/Kernel_dequantize' 0 0 [80, 1, 1, 72]\n", 373 | "226 b'depthwise_conv2d_11/Bias_dequantize' 0 0 [88]\n", 374 | "227 b'conv2d_7/Bias_dequantize' 0 0 [56]\n", 375 | "228 b'depthwise_conv2d_10/Kernel_dequantize' 0 0 [1, 3, 3, 80]\n", 376 | "229 b'conv2d_12/Kernel_dequantize' 0 0 [96, 1, 1, 88]\n", 377 | "230 b'conv2d_14/Kernel_dequantize' 0 0 [96, 1, 1, 96]\n", 378 | "231 b'conv2d_13/Bias_dequantize' 0 0 [96]\n", 379 | "232 b'conv2d_6/Kernel_dequantize' 0 0 [48, 1, 1, 42]\n", 380 | "233 b'depthwise_conv2d_7/Bias_dequantize' 0 0 [56]\n", 381 | "234 b'classificator_16/Bias_dequantize' 0 0 [6]\n", 382 | "235 b'conv2d_11/Kernel_dequantize' 0 0 [88, 1, 1, 80]\n", 383 | "236 b'depthwise_conv2d_12/Kernel_dequantize' 0 0 [1, 3, 3, 96]\n", 384 | "237 b'conv2d_5/Kernel_dequantize' 0 0 [42, 1, 1, 36]\n", 385 | "238 b'depthwise_conv2d_6/Kernel_dequantize' 0 0 [1, 3, 3, 48]\n", 386 | "239 b'depthwise_conv2d_15/Bias_dequantize' 0 0 [96]\n", 387 | "240 b'conv2d_8/Bias_dequantize' 0 0 [64]\n", 388 | "241 b'depthwise_conv2d_11/Kernel_dequantize' 0 0 [1, 3, 3, 88]\n", 389 | "242 b'depthwise_conv2d_5/Kernel_dequantize' 0 0 [1, 3, 3, 42]\n", 390 | "243 b'conv2d_7/Kernel_dequantize' 0 0 [56, 1, 1, 48]\n", 391 | "244 b'depthwise_conv2d_8/Bias_dequantize' 0 0 [64]\n", 392 | "245 b'classificator_8/Kernel_dequantize' 0 0 [2, 1, 1, 88]\n", 393 | "246 b'depthwise_conv2d_7/Kernel_dequantize' 0 0 [1, 
3, 3, 56]\n", 394 | "247 b'depthwise_conv2d_1/Bias_dequantize' 0  0 [24]\n", 395 | "248 b'regressor_8/Bias_dequantize' 0  0 [32]\n", 396 | "249 b'conv2d_13/Kernel_dequantize' 0  0 [96, 1, 1, 96]\n" 397 | ] 398 | } 399 | ], 400 | "source": [ 401 | "def print_graph(graph):\n", 402 | "    for i in range(0, graph.TensorsLength()):\n", 403 | "        tensor = graph.Tensors(i)\n", 404 | "        print(\"%3d %30s %d %2d %s\" % (i, tensor.Name(), tensor.Type(), tensor.Buffer(), \n", 405 | "              get_shape(graph.Tensors(i))))\n", 406 | "\n", 407 | "print_graph(front_subgraph)" 408 | ] 409 | }, 410 | { 411 | "cell_type": "markdown", 412 | "metadata": {}, 413 | "source": [ 414 | "Make a look-up table that lets us get the tensor index based on the tensor name:" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": 7, 420 | "metadata": {}, 421 | "outputs": [], 422 | "source": [ 423 | "front_tensor_dict = {(front_subgraph.Tensors(i).Name().decode(\"utf8\")): i \n", 424 | "                     for i in range(front_subgraph.TensorsLength())}" 425 | ] 426 | }, 427 | { 428 | "cell_type": "markdown", 429 | "metadata": {}, 430 | "source": [ 431 | "Grab only the tensors that represent weights and biases." 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": 8, 437 | "metadata": {}, 438 | "outputs": [ 439 | { 440 | "data": { 441 | "text/plain": [ 442 | "85" 443 | ] 444 | }, 445 | "execution_count": 8, 446 | "metadata": {}, 447 | "output_type": "execute_result" 448 | } 449 | ], 450 | "source": [ 451 | "def get_parameters(graph):\n", 452 | "    parameters = {}\n", 453 | "    for i in range(graph.TensorsLength()):\n", 454 | "        tensor = graph.Tensors(i)\n", 455 | "        if tensor.Buffer() > 0:\n", 456 | "            name = tensor.Name().decode(\"utf8\")\n", 457 | "            parameters[name] = tensor.Buffer()\n", 458 | "    return parameters\n", 459 | "\n", 460 | "front_parameters = get_parameters(front_subgraph)\n", 461 | "len(front_parameters)" 462 | ] 463 | }, 464 | { 465 | "cell_type": "markdown", 466 | "metadata": {}, 467 | "source": [ 468 | "The buffers are simply arrays of bytes. As the docs say,\n", 469 | "\n", 470 | "> The data_buffer itself is an opaque container, with the assumption that the\n", 471 | "> target device is little-endian. In addition, all builtin operators assume\n", 472 | "> the memory is ordered such that if `shape` is [4, 3, 2], then index\n", 473 | "> [i, j, k] maps to `data_buffer[i*3*2 + j*2 + k]`.\n", 474 | "\n", 475 | "For float32 weights and biases, we need to interpret every 4 bytes as one float. On my machine, the native byte ordering is already little-endian, so we don't need to do anything special for that.\n", 476 | "\n", 477 | "However, some weights and biases turn out to be stored as float16 instead of float32, i.e. tensor type 1 instead of type 0."
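,
    "\n",
    "For example, viewing a raw byte buffer with NumPy (a minimal sketch; `buf` here is just a stand-in for a real buffer obtained via `model.Buffers(...).DataAsNumpy()`):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "buf = np.zeros(8, dtype=np.uint8)  # stand-in for a raw TFLite weight buffer\n",
    "as_f32 = buf.view(np.float32)      # type 0: every 4 bytes -> one float32 (2 values)\n",
    "as_f16 = buf.view(np.float16)      # type 1: every 2 bytes -> one float16 (4 values)\n",
    "print(as_f32.shape, as_f16.shape)  # (2,) (4,)\n",
    "```"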
478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "execution_count": 9, 483 | "metadata": {}, 484 | "outputs": [], 485 | "source": [ 486 | "def get_weights(model, graph, tensor_dict, tensor_name):\n", 487 | " i = tensor_dict[tensor_name]\n", 488 | " tensor = graph.Tensors(i)\n", 489 | " buffer = tensor.Buffer()\n", 490 | " shape = get_shape(tensor)\n", 491 | " assert(tensor.Type() == 0 or tensor.Type() == 1) # FLOAT32\n", 492 | " \n", 493 | " W = model.Buffers(buffer).DataAsNumpy()\n", 494 | " if tensor.Type() == 0:\n", 495 | " W = W.view(dtype=np.float32)\n", 496 | " elif tensor.Type() == 1:\n", 497 | " W = W.view(dtype=np.float16)\n", 498 | " W = W.reshape(shape)\n", 499 | " return W" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": 10, 505 | "metadata": {}, 506 | "outputs": [ 507 | { 508 | "data": { 509 | "text/plain": [ 510 | "((24, 5, 5, 3), (24,))" 511 | ] 512 | }, 513 | "execution_count": 10, 514 | "metadata": {}, 515 | "output_type": "execute_result" 516 | } 517 | ], 518 | "source": [ 519 | "W = get_weights(front_model, front_subgraph, front_tensor_dict, \"conv2d/Kernel\")\n", 520 | "b = get_weights(front_model, front_subgraph, front_tensor_dict, \"conv2d/Bias\")\n", 521 | "W.shape, b.shape" 522 | ] 523 | }, 524 | { 525 | "cell_type": "markdown", 526 | "metadata": {}, 527 | "source": [ 528 | "Now we can get the weights for all the layers and copy them into our PyTorch model." 529 | ] 530 | }, 531 | { 532 | "cell_type": "markdown", 533 | "metadata": {}, 534 | "source": [ 535 | "### Convert the weights to PyTorch format" 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": 11, 541 | "metadata": { 542 | "scrolled": true 543 | }, 544 | "outputs": [], 545 | "source": [ 546 | "import torch\n", 547 | "from blazeface import BlazeFace" 548 | ] 549 | }, 550 | { 551 | "cell_type": "code", 552 | "execution_count": 12, 553 | "metadata": {}, 554 | "outputs": [], 555 | "source": [ 556 | "front_net = BlazeFace()" 557 | ] 558 | }, 559 | { 560 | "cell_type": "code", 561 | "execution_count": 13, 562 | "metadata": {}, 563 | "outputs": [ 564 | { 565 | "data": { 566 | "text/plain": [ 567 | "BlazeFace(\n", 568 | " (backbone1): Sequential(\n", 569 | " (0): Conv2d(3, 24, kernel_size=(5, 5), stride=(2, 2))\n", 570 | " (1): ReLU(inplace=True)\n", 571 | " (2): BlazeBlock(\n", 572 | " (convs): Sequential(\n", 573 | " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)\n", 574 | " (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))\n", 575 | " )\n", 576 | " (act): ReLU(inplace=True)\n", 577 | " )\n", 578 | " (3): BlazeBlock(\n", 579 | " (convs): Sequential(\n", 580 | " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)\n", 581 | " (1): Conv2d(24, 28, kernel_size=(1, 1), stride=(1, 1))\n", 582 | " )\n", 583 | " (act): ReLU(inplace=True)\n", 584 | " )\n", 585 | " (4): BlazeBlock(\n", 586 | " (max_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 587 | " (convs): Sequential(\n", 588 | " (0): Conv2d(28, 28, kernel_size=(3, 3), stride=(2, 2), groups=28)\n", 589 | " (1): Conv2d(28, 32, kernel_size=(1, 1), stride=(1, 1))\n", 590 | " )\n", 591 | " (act): ReLU(inplace=True)\n", 592 | " )\n", 593 | " (5): BlazeBlock(\n", 594 | " (convs): Sequential(\n", 595 | " (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)\n", 596 | " (1): Conv2d(32, 36, kernel_size=(1, 1), stride=(1, 1))\n", 597 | " )\n", 598 | " (act): ReLU(inplace=True)\n", 
599 | " )\n", 600 | " (6): BlazeBlock(\n", 601 | " (convs): Sequential(\n", 602 | " (0): Conv2d(36, 36, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=36)\n", 603 | " (1): Conv2d(36, 42, kernel_size=(1, 1), stride=(1, 1))\n", 604 | " )\n", 605 | " (act): ReLU(inplace=True)\n", 606 | " )\n", 607 | " (7): BlazeBlock(\n", 608 | " (max_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 609 | " (convs): Sequential(\n", 610 | " (0): Conv2d(42, 42, kernel_size=(3, 3), stride=(2, 2), groups=42)\n", 611 | " (1): Conv2d(42, 48, kernel_size=(1, 1), stride=(1, 1))\n", 612 | " )\n", 613 | " (act): ReLU(inplace=True)\n", 614 | " )\n", 615 | " (8): BlazeBlock(\n", 616 | " (convs): Sequential(\n", 617 | " (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48)\n", 618 | " (1): Conv2d(48, 56, kernel_size=(1, 1), stride=(1, 1))\n", 619 | " )\n", 620 | " (act): ReLU(inplace=True)\n", 621 | " )\n", 622 | " (9): BlazeBlock(\n", 623 | " (convs): Sequential(\n", 624 | " (0): Conv2d(56, 56, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=56)\n", 625 | " (1): Conv2d(56, 64, kernel_size=(1, 1), stride=(1, 1))\n", 626 | " )\n", 627 | " (act): ReLU(inplace=True)\n", 628 | " )\n", 629 | " (10): BlazeBlock(\n", 630 | " (convs): Sequential(\n", 631 | " (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)\n", 632 | " (1): Conv2d(64, 72, kernel_size=(1, 1), stride=(1, 1))\n", 633 | " )\n", 634 | " (act): ReLU(inplace=True)\n", 635 | " )\n", 636 | " (11): BlazeBlock(\n", 637 | " (convs): Sequential(\n", 638 | " (0): Conv2d(72, 72, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=72)\n", 639 | " (1): Conv2d(72, 80, kernel_size=(1, 1), stride=(1, 1))\n", 640 | " )\n", 641 | " (act): ReLU(inplace=True)\n", 642 | " )\n", 643 | " (12): BlazeBlock(\n", 644 | " (convs): Sequential(\n", 645 | " (0): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=80)\n", 646 | " (1): Conv2d(80, 88, kernel_size=(1, 1), stride=(1, 1))\n", 647 | " )\n", 648 | " (act): ReLU(inplace=True)\n", 649 | " )\n", 650 | " )\n", 651 | " (backbone2): Sequential(\n", 652 | " (0): BlazeBlock(\n", 653 | " (max_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 654 | " (convs): Sequential(\n", 655 | " (0): Conv2d(88, 88, kernel_size=(3, 3), stride=(2, 2), groups=88)\n", 656 | " (1): Conv2d(88, 96, kernel_size=(1, 1), stride=(1, 1))\n", 657 | " )\n", 658 | " (act): ReLU(inplace=True)\n", 659 | " )\n", 660 | " (1): BlazeBlock(\n", 661 | " (convs): Sequential(\n", 662 | " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96)\n", 663 | " (1): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))\n", 664 | " )\n", 665 | " (act): ReLU(inplace=True)\n", 666 | " )\n", 667 | " (2): BlazeBlock(\n", 668 | " (convs): Sequential(\n", 669 | " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96)\n", 670 | " (1): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))\n", 671 | " )\n", 672 | " (act): ReLU(inplace=True)\n", 673 | " )\n", 674 | " (3): BlazeBlock(\n", 675 | " (convs): Sequential(\n", 676 | " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96)\n", 677 | " (1): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))\n", 678 | " )\n", 679 | " (act): ReLU(inplace=True)\n", 680 | " )\n", 681 | " (4): BlazeBlock(\n", 682 | " (convs): Sequential(\n", 683 | " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), 
groups=96)\n", 684 | " (1): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))\n", 685 | " )\n", 686 | " (act): ReLU(inplace=True)\n", 687 | " )\n", 688 | " )\n", 689 | " (classifier_8): Conv2d(88, 2, kernel_size=(1, 1), stride=(1, 1))\n", 690 | " (classifier_16): Conv2d(96, 6, kernel_size=(1, 1), stride=(1, 1))\n", 691 | " (regressor_8): Conv2d(88, 32, kernel_size=(1, 1), stride=(1, 1))\n", 692 | " (regressor_16): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))\n", 693 | ")" 694 | ] 695 | }, 696 | "execution_count": 13, 697 | "metadata": {}, 698 | "output_type": "execute_result" 699 | } 700 | ], 701 | "source": [ 702 | "front_net" 703 | ] 704 | }, 705 | { 706 | "cell_type": "markdown", 707 | "metadata": {}, 708 | "source": [ 709 | "Make a lookup table that maps the layer names between the two models. We're going to assume here that the tensors will be in the same order in both models. If not, we should get an error because shapes don't match." 710 | ] 711 | }, 712 | { 713 | "cell_type": "code", 714 | "execution_count": 14, 715 | "metadata": {}, 716 | "outputs": [ 717 | { 718 | "data": { 719 | "text/plain": [ 720 | "['conv2d/Kernel',\n", 721 | " 'conv2d/Bias',\n", 722 | " 'depthwise_conv2d/Kernel',\n", 723 | " 'depthwise_conv2d/Bias',\n", 724 | " 'conv2d_1/Kernel']" 725 | ] 726 | }, 727 | "execution_count": 14, 728 | "metadata": {}, 729 | "output_type": "execute_result" 730 | } 731 | ], 732 | "source": [ 733 | "def get_probable_names(graph):\n", 734 | " probable_names = []\n", 735 | " for i in range(0, graph.TensorsLength()):\n", 736 | " tensor = graph.Tensors(i)\n", 737 | " if tensor.Buffer() > 0 and (tensor.Type() == 0 or tensor.Type() == 1):\n", 738 | " probable_names.append(tensor.Name().decode(\"utf-8\"))\n", 739 | " return probable_names\n", 740 | "\n", 741 | "front_probable_names = get_probable_names(front_subgraph)\n", 742 | " \n", 743 | "front_probable_names[:5]" 744 | ] 745 | }, 746 | { 747 | "cell_type": "code", 748 | "execution_count": 15, 749 | "metadata": {}, 750 | "outputs": [], 751 | "source": [ 752 | "def get_convert(net, probable_names):\n", 753 | " convert = {}\n", 754 | " i = 0\n", 755 | " for name, params in net.state_dict().items():\n", 756 | " convert[name] = probable_names[i]\n", 757 | " i += 1\n", 758 | " return convert\n", 759 | "\n", 760 | "front_convert = get_convert(front_net, front_probable_names)" 761 | ] 762 | }, 763 | { 764 | "cell_type": "markdown", 765 | "metadata": {}, 766 | "source": [ 767 | "Copy the weights into the layers.\n", 768 | "\n", 769 | "Note that the ordering of the weights is different between PyTorch and TFLite, so we need to transpose them.\n", 770 | "\n", 771 | "Convolution weights:\n", 772 | "\n", 773 | " TFLite: (out_channels, kernel_height, kernel_width, in_channels)\n", 774 | " PyTorch: (out_channels, in_channels, kernel_height, kernel_width)\n", 775 | "\n", 776 | "Depthwise convolution weights:\n", 777 | "\n", 778 | " TFLite: (1, kernel_height, kernel_width, channels)\n", 779 | " PyTorch: (channels, 1, kernel_height, kernel_width)" 780 | ] 781 | }, 782 | { 783 | "cell_type": "code", 784 | "execution_count": 16, 785 | "metadata": {}, 786 | "outputs": [ 787 | { 788 | "name": "stdout", 789 | "output_type": "stream", 790 | "text": [ 791 | "backbone1.0.weight conv2d/Kernel (24, 5, 5, 3) torch.Size([24, 3, 5, 5])\n", 792 | "backbone1.0.bias conv2d/Bias (24,) torch.Size([24])\n", 793 | "backbone1.2.convs.0.weight depthwise_conv2d/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])\n", 794 | "backbone1.2.convs.0.bias depthwise_conv2d/Bias (24,) 
torch.Size([24])\n", 795 | "backbone1.2.convs.1.weight conv2d_1/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])\n", 796 | "backbone1.2.convs.1.bias conv2d_1/Bias (24,) torch.Size([24])\n", 797 | "backbone1.3.convs.0.weight depthwise_conv2d_1/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])\n", 798 | "backbone1.3.convs.0.bias depthwise_conv2d_1/Bias (24,) torch.Size([24])\n", 799 | "backbone1.3.convs.1.weight conv2d_2/Kernel (28, 1, 1, 24) torch.Size([28, 24, 1, 1])\n", 800 | "backbone1.3.convs.1.bias conv2d_2/Bias (28,) torch.Size([28])\n", 801 | "backbone1.4.convs.0.weight depthwise_conv2d_2/Kernel (1, 3, 3, 28) torch.Size([28, 1, 3, 3])\n", 802 | "backbone1.4.convs.0.bias depthwise_conv2d_2/Bias (28,) torch.Size([28])\n", 803 | "backbone1.4.convs.1.weight conv2d_3/Kernel (32, 1, 1, 28) torch.Size([32, 28, 1, 1])\n", 804 | "backbone1.4.convs.1.bias conv2d_3/Bias (32,) torch.Size([32])\n", 805 | "backbone1.5.convs.0.weight depthwise_conv2d_3/Kernel (1, 3, 3, 32) torch.Size([32, 1, 3, 3])\n", 806 | "backbone1.5.convs.0.bias depthwise_conv2d_3/Bias (32,) torch.Size([32])\n", 807 | "backbone1.5.convs.1.weight conv2d_4/Kernel (36, 1, 1, 32) torch.Size([36, 32, 1, 1])\n", 808 | "backbone1.5.convs.1.bias conv2d_4/Bias (36,) torch.Size([36])\n", 809 | "backbone1.6.convs.0.weight depthwise_conv2d_4/Kernel (1, 3, 3, 36) torch.Size([36, 1, 3, 3])\n", 810 | "backbone1.6.convs.0.bias depthwise_conv2d_4/Bias (36,) torch.Size([36])\n", 811 | "backbone1.6.convs.1.weight conv2d_5/Kernel (42, 1, 1, 36) torch.Size([42, 36, 1, 1])\n", 812 | "backbone1.6.convs.1.bias conv2d_5/Bias (42,) torch.Size([42])\n", 813 | "backbone1.7.convs.0.weight depthwise_conv2d_5/Kernel (1, 3, 3, 42) torch.Size([42, 1, 3, 3])\n", 814 | "backbone1.7.convs.0.bias depthwise_conv2d_5/Bias (42,) torch.Size([42])\n", 815 | "backbone1.7.convs.1.weight conv2d_6/Kernel (48, 1, 1, 42) torch.Size([48, 42, 1, 1])\n", 816 | "backbone1.7.convs.1.bias conv2d_6/Bias (48,) torch.Size([48])\n", 817 | "backbone1.8.convs.0.weight depthwise_conv2d_6/Kernel (1, 3, 3, 48) torch.Size([48, 1, 3, 3])\n", 818 | "backbone1.8.convs.0.bias depthwise_conv2d_6/Bias (48,) torch.Size([48])\n", 819 | "backbone1.8.convs.1.weight conv2d_7/Kernel (56, 1, 1, 48) torch.Size([56, 48, 1, 1])\n", 820 | "backbone1.8.convs.1.bias conv2d_7/Bias (56,) torch.Size([56])\n", 821 | "backbone1.9.convs.0.weight depthwise_conv2d_7/Kernel (1, 3, 3, 56) torch.Size([56, 1, 3, 3])\n", 822 | "backbone1.9.convs.0.bias depthwise_conv2d_7/Bias (56,) torch.Size([56])\n", 823 | "backbone1.9.convs.1.weight conv2d_8/Kernel (64, 1, 1, 56) torch.Size([64, 56, 1, 1])\n", 824 | "backbone1.9.convs.1.bias conv2d_8/Bias (64,) torch.Size([64])\n", 825 | "backbone1.10.convs.0.weight depthwise_conv2d_8/Kernel (1, 3, 3, 64) torch.Size([64, 1, 3, 3])\n", 826 | "backbone1.10.convs.0.bias depthwise_conv2d_8/Bias (64,) torch.Size([64])\n", 827 | "backbone1.10.convs.1.weight conv2d_9/Kernel (72, 1, 1, 64) torch.Size([72, 64, 1, 1])\n", 828 | "backbone1.10.convs.1.bias conv2d_9/Bias (72,) torch.Size([72])\n", 829 | "backbone1.11.convs.0.weight depthwise_conv2d_9/Kernel (1, 3, 3, 72) torch.Size([72, 1, 3, 3])\n", 830 | "backbone1.11.convs.0.bias depthwise_conv2d_9/Bias (72,) torch.Size([72])\n", 831 | "backbone1.11.convs.1.weight conv2d_10/Kernel (80, 1, 1, 72) torch.Size([80, 72, 1, 1])\n", 832 | "backbone1.11.convs.1.bias conv2d_10/Bias (80,) torch.Size([80])\n", 833 | "backbone1.12.convs.0.weight depthwise_conv2d_10/Kernel (1, 3, 3, 80) torch.Size([80, 1, 3, 3])\n", 834 | "backbone1.12.convs.0.bias 
depthwise_conv2d_10/Bias (80,) torch.Size([80])\n", 835 | "backbone1.12.convs.1.weight conv2d_11/Kernel (88, 1, 1, 80) torch.Size([88, 80, 1, 1])\n", 836 | "backbone1.12.convs.1.bias conv2d_11/Bias (88,) torch.Size([88])\n", 837 | "backbone2.0.convs.0.weight depthwise_conv2d_11/Kernel (1, 3, 3, 88) torch.Size([88, 1, 3, 3])\n", 838 | "backbone2.0.convs.0.bias depthwise_conv2d_11/Bias (88,) torch.Size([88])\n", 839 | "backbone2.0.convs.1.weight conv2d_12/Kernel (96, 1, 1, 88) torch.Size([96, 88, 1, 1])\n", 840 | "backbone2.0.convs.1.bias conv2d_12/Bias (96,) torch.Size([96])\n", 841 | "backbone2.1.convs.0.weight depthwise_conv2d_12/Kernel (1, 3, 3, 96) torch.Size([96, 1, 3, 3])\n", 842 | "backbone2.1.convs.0.bias depthwise_conv2d_12/Bias (96,) torch.Size([96])\n", 843 | "backbone2.1.convs.1.weight conv2d_13/Kernel (96, 1, 1, 96) torch.Size([96, 96, 1, 1])\n", 844 | "backbone2.1.convs.1.bias conv2d_13/Bias (96,) torch.Size([96])\n", 845 | "backbone2.2.convs.0.weight depthwise_conv2d_13/Kernel (1, 3, 3, 96) torch.Size([96, 1, 3, 3])\n", 846 | "backbone2.2.convs.0.bias depthwise_conv2d_13/Bias (96,) torch.Size([96])\n", 847 | "backbone2.2.convs.1.weight conv2d_14/Kernel (96, 1, 1, 96) torch.Size([96, 96, 1, 1])\n", 848 | "backbone2.2.convs.1.bias conv2d_14/Bias (96,) torch.Size([96])\n", 849 | "backbone2.3.convs.0.weight depthwise_conv2d_14/Kernel (1, 3, 3, 96) torch.Size([96, 1, 3, 3])\n", 850 | "backbone2.3.convs.0.bias depthwise_conv2d_14/Bias (96,) torch.Size([96])\n", 851 | "backbone2.3.convs.1.weight conv2d_15/Kernel (96, 1, 1, 96) torch.Size([96, 96, 1, 1])\n", 852 | "backbone2.3.convs.1.bias conv2d_15/Bias (96,) torch.Size([96])\n", 853 | "backbone2.4.convs.0.weight depthwise_conv2d_15/Kernel (1, 3, 3, 96) torch.Size([96, 1, 3, 3])\n", 854 | "backbone2.4.convs.0.bias depthwise_conv2d_15/Bias (96,) torch.Size([96])\n", 855 | "backbone2.4.convs.1.weight conv2d_16/Kernel (96, 1, 1, 96) torch.Size([96, 96, 1, 1])\n", 856 | "backbone2.4.convs.1.bias conv2d_16/Bias (96,) torch.Size([96])\n", 857 | "classifier_8.weight classificator_8/Kernel (2, 1, 1, 88) torch.Size([2, 88, 1, 1])\n", 858 | "classifier_8.bias classificator_8/Bias (2,) torch.Size([2])\n", 859 | "classifier_16.weight classificator_16/Kernel (6, 1, 1, 96) torch.Size([6, 96, 1, 1])\n", 860 | "classifier_16.bias classificator_16/Bias (6,) torch.Size([6])\n", 861 | "regressor_8.weight regressor_8/Kernel (32, 1, 1, 88) torch.Size([32, 88, 1, 1])\n", 862 | "regressor_8.bias regressor_8/Bias (32,) torch.Size([32])\n", 863 | "regressor_16.weight regressor_16/Kernel (96, 1, 1, 96) torch.Size([96, 96, 1, 1])\n", 864 | "regressor_16.bias regressor_16/Bias (96,) torch.Size([96])\n" 865 | ] 866 | }, 867 | { 868 | "name": "stderr", 869 | "output_type": "stream", 870 | "text": [ 871 | ":14: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. 
(Triggered internally at /opt/conda/conda-bld/pytorch_1595629395347/work/torch/csrc/utils/tensor_numpy.cpp:141.)\n", 872 | " new_state_dict[dst] = torch.from_numpy(W)\n" 873 | ] 874 | } 875 | ], 876 | "source": [ 877 | "def build_state_dict(model, graph, tensor_dict, net, convert):\n", 878 | " new_state_dict = OrderedDict()\n", 879 | "\n", 880 | " for dst, src in convert.items():\n", 881 | " W = get_weights(model, graph, tensor_dict, src)\n", 882 | " print(dst, src, W.shape, net.state_dict()[dst].shape)\n", 883 | "\n", 884 | " if W.ndim == 4:\n", 885 | " if W.shape[0] == 1:\n", 886 | " W = W.transpose((3, 0, 1, 2)) # depthwise conv\n", 887 | " else:\n", 888 | " W = W.transpose((0, 3, 1, 2)) # regular conv\n", 889 | " \n", 890 | " new_state_dict[dst] = torch.from_numpy(W)\n", 891 | " return new_state_dict\n", 892 | "\n", 893 | "front_state_dict = build_state_dict(front_model, front_subgraph, front_tensor_dict, front_net, front_convert)" 894 | ] 895 | }, 896 | { 897 | "cell_type": "code", 898 | "execution_count": 17, 899 | "metadata": {}, 900 | "outputs": [ 901 | { 902 | "data": { 903 | "text/plain": [ 904 | "" 905 | ] 906 | }, 907 | "execution_count": 17, 908 | "metadata": {}, 909 | "output_type": "execute_result" 910 | } 911 | ], 912 | "source": [ 913 | "front_net.load_state_dict(front_state_dict, strict=True)" 914 | ] 915 | }, 916 | { 917 | "cell_type": "markdown", 918 | "metadata": {}, 919 | "source": [ 920 | "No errors? Then the conversion was successful!" 921 | ] 922 | }, 923 | { 924 | "cell_type": "markdown", 925 | "metadata": {}, 926 | "source": [ 927 | "### Save the checkpoint" 928 | ] 929 | }, 930 | { 931 | "cell_type": "code", 932 | "execution_count": 18, 933 | "metadata": {}, 934 | "outputs": [], 935 | "source": [ 936 | "torch.save(front_net.state_dict(), \"blazeface.pth\")" 937 | ] 938 | }, 939 | { 940 | "cell_type": "markdown", 941 | "metadata": {}, 942 | "source": [ 943 | "## Convert back camera TFLite model" 944 | ] 945 | }, 946 | { 947 | "cell_type": "code", 948 | "execution_count": 19, 949 | "metadata": {}, 950 | "outputs": [ 951 | { 952 | "name": "stdout", 953 | "output_type": "stream", 954 | "text": [ 955 | "--2021-02-09 23:19:58-- https://github.com/google/mediapipe/raw/master/mediapipe/models/face_detection_back.tflite\n", 956 | "Résolution de github.com (github.com)… 140.82.121.3\n", 957 | "Connexion à github.com (github.com)|140.82.121.3|:443… connecté.\n", 958 | "requête HTTP transmise, en attente de la réponse… 302 Found\n", 959 | "Emplacement : https://raw.githubusercontent.com/google/mediapipe/master/mediapipe/models/face_detection_back.tflite [suivant]\n", 960 | "--2021-02-09 23:19:58-- https://raw.githubusercontent.com/google/mediapipe/master/mediapipe/models/face_detection_back.tflite\n", 961 | "Résolution de raw.githubusercontent.com (raw.githubusercontent.com)… 151.101.120.133\n", 962 | "Connexion à raw.githubusercontent.com (raw.githubusercontent.com)|151.101.120.133|:443… connecté.\n", 963 | "requête HTTP transmise, en attente de la réponse… 200 OK\n", 964 | "Taille : 315332 (308K) [application/octet-stream]\n", 965 | "Enregistre : «face_detection_back.tflite»\n", 966 | "\n", 967 | "face_detection_back 100%[===================>] 307,94K --.-KB/s ds 0,02s \n", 968 | "\n", 969 | "En-tête de dernière modification manquant — horodatage arrêté.\n", 970 | "2021-02-09 23:19:58 (17,0 MB/s) - «face_detection_back.tflite» enregistré [315332/315332]\n", 971 | "\n" 972 | ] 973 | } 974 | ], 975 | "source": [ 976 | "!wget -N 
https://github.com/google/mediapipe/raw/master/mediapipe/models/face_detection_back.tflite" 977 | ] 978 | }, 979 | { 980 | "cell_type": "code", 981 | "execution_count": 20, 982 | "metadata": {}, 983 | "outputs": [ 984 | { 985 | "data": { 986 | "text/plain": [ 987 | "b'keras2tflite_facedetector-back.tflite.generated'" 988 | ] 989 | }, 990 | "execution_count": 20, 991 | "metadata": {}, 992 | "output_type": "execute_result" 993 | } 994 | ], 995 | "source": [ 996 | "back_data = open(\"./face_detection_back.tflite\", \"rb\").read()\n", 997 | "back_model = Model.GetRootAsModel(back_data, 0)\n", 998 | "back_subgraph = back_model.Subgraphs(0)\n", 999 | "back_subgraph.Name()" 1000 | ] 1001 | }, 1002 | { 1003 | "cell_type": "code", 1004 | "execution_count": 21, 1005 | "metadata": {}, 1006 | "outputs": [ 1007 | { 1008 | "name": "stdout", 1009 | "output_type": "stream", 1010 | "text": [ 1011 | " 0 b'input' 0 0 [1, 256, 256, 3]\n", 1012 | " 1 b'conv2d/Kernel' 1 1 [24, 5, 5, 3]\n", 1013 | " 2 b'conv2d/Bias' 1 2 [24]\n", 1014 | " 3 b'conv2d' 0 0 [1, 128, 128, 24]\n", 1015 | " 4 b'activation' 0 0 [1, 128, 128, 24]\n", 1016 | " 5 b'depthwise_conv2d/Kernel' 1 3 [1, 3, 3, 24]\n", 1017 | " 6 b'depthwise_conv2d/Bias' 1 4 [24]\n", 1018 | " 7 b'depthwise_conv2d' 0 0 [1, 128, 128, 24]\n", 1019 | " 8 b'conv2d_1/Kernel' 1 5 [24, 1, 1, 24]\n", 1020 | " 9 b'conv2d_1/Bias' 1 6 [24]\n", 1021 | " 10 b'conv2d_1' 0 0 [1, 128, 128, 24]\n", 1022 | " 11 b'add' 0 0 [1, 128, 128, 24]\n", 1023 | " 12 b'activation_1' 0 0 [1, 128, 128, 24]\n", 1024 | " 13 b'depthwise_conv2d_1/Kernel' 1 7 [1, 3, 3, 24]\n", 1025 | " 14 b'depthwise_conv2d_1/Bias' 1 8 [24]\n", 1026 | " 15 b'depthwise_conv2d_1' 0 0 [1, 128, 128, 24]\n", 1027 | " 16 b'conv2d_2/Kernel' 1 9 [24, 1, 1, 24]\n", 1028 | " 17 b'conv2d_2/Bias' 1 10 [24]\n", 1029 | " 18 b'conv2d_2' 0 0 [1, 128, 128, 24]\n", 1030 | " 19 b'add_1' 0 0 [1, 128, 128, 24]\n", 1031 | " 20 b'activation_2' 0 0 [1, 128, 128, 24]\n", 1032 | " 21 b'depthwise_conv2d_2/Kernel' 1 11 [1, 3, 3, 24]\n", 1033 | " 22 b'depthwise_conv2d_2/Bias' 1 12 [24]\n", 1034 | " 23 b'depthwise_conv2d_2' 0 0 [1, 128, 128, 24]\n", 1035 | " 24 b'conv2d_3/Kernel' 1 13 [24, 1, 1, 24]\n", 1036 | " 25 b'conv2d_3/Bias' 1 14 [24]\n", 1037 | " 26 b'conv2d_3' 0 0 [1, 128, 128, 24]\n", 1038 | " 27 b'add_2' 0 0 [1, 128, 128, 24]\n", 1039 | " 28 b'activation_3' 0 0 [1, 128, 128, 24]\n", 1040 | " 29 b'depthwise_conv2d_3/Kernel' 1 15 [1, 3, 3, 24]\n", 1041 | " 30 b'depthwise_conv2d_3/Bias' 1 16 [24]\n", 1042 | " 31 b'depthwise_conv2d_3' 0 0 [1, 128, 128, 24]\n", 1043 | " 32 b'conv2d_4/Kernel' 1 17 [24, 1, 1, 24]\n", 1044 | " 33 b'conv2d_4/Bias' 1 18 [24]\n", 1045 | " 34 b'conv2d_4' 0 0 [1, 128, 128, 24]\n", 1046 | " 35 b'add_3' 0 0 [1, 128, 128, 24]\n", 1047 | " 36 b'activation_4' 0 0 [1, 128, 128, 24]\n", 1048 | " 37 b'depthwise_conv2d_4/Kernel' 1 19 [1, 3, 3, 24]\n", 1049 | " 38 b'depthwise_conv2d_4/Bias' 1 20 [24]\n", 1050 | " 39 b'depthwise_conv2d_4' 0 0 [1, 128, 128, 24]\n", 1051 | " 40 b'conv2d_5/Kernel' 1 21 [24, 1, 1, 24]\n", 1052 | " 41 b'conv2d_5/Bias' 1 22 [24]\n", 1053 | " 42 b'conv2d_5' 0 0 [1, 128, 128, 24]\n", 1054 | " 43 b'add_4' 0 0 [1, 128, 128, 24]\n", 1055 | " 44 b'activation_5' 0 0 [1, 128, 128, 24]\n", 1056 | " 45 b'depthwise_conv2d_5/Kernel' 1 23 [1, 3, 3, 24]\n", 1057 | " 46 b'depthwise_conv2d_5/Bias' 1 24 [24]\n", 1058 | " 47 b'depthwise_conv2d_5' 0 0 [1, 128, 128, 24]\n", 1059 | " 48 b'conv2d_6/Kernel' 1 25 [24, 1, 1, 24]\n", 1060 | " 49 b'conv2d_6/Bias' 1 26 [24]\n", 1061 | " 50 b'conv2d_6' 0 0 [1, 128, 128, 24]\n", 
1062 | " 51 b'add_5' 0 0 [1, 128, 128, 24]\n", 1063 | " 52 b'activation_6' 0 0 [1, 128, 128, 24]\n", 1064 | " 53 b'depthwise_conv2d_6/Kernel' 1 27 [1, 3, 3, 24]\n", 1065 | " 54 b'depthwise_conv2d_6/Bias' 1 28 [24]\n", 1066 | " 55 b'depthwise_conv2d_6' 0 0 [1, 128, 128, 24]\n", 1067 | " 56 b'conv2d_7/Kernel' 1 29 [24, 1, 1, 24]\n", 1068 | " 57 b'conv2d_7/Bias' 1 30 [24]\n", 1069 | " 58 b'conv2d_7' 0 0 [1, 128, 128, 24]\n", 1070 | " 59 b'add_6' 0 0 [1, 128, 128, 24]\n", 1071 | " 60 b'activation_7' 0 0 [1, 128, 128, 24]\n", 1072 | " 61 b'depthwise_conv2d_7/Kernel' 1 31 [1, 3, 3, 24]\n", 1073 | " 62 b'depthwise_conv2d_7/Bias' 1 32 [24]\n", 1074 | " 63 b'depthwise_conv2d_7' 0 0 [1, 64, 64, 24]\n", 1075 | " 64 b'conv2d_8/Kernel' 1 33 [24, 1, 1, 24]\n", 1076 | " 65 b'conv2d_8/Bias' 1 34 [24]\n", 1077 | " 66 b'conv2d_8' 0 0 [1, 64, 64, 24]\n", 1078 | " 67 b'max_pooling2d' 0 0 [1, 64, 64, 24]\n", 1079 | " 68 b'add_7' 0 0 [1, 64, 64, 24]\n", 1080 | " 69 b'activation_8' 0 0 [1, 64, 64, 24]\n", 1081 | " 70 b'depthwise_conv2d_8/Kernel' 1 35 [1, 3, 3, 24]\n", 1082 | " 71 b'depthwise_conv2d_8/Bias' 1 36 [24]\n", 1083 | " 72 b'depthwise_conv2d_8' 0 0 [1, 64, 64, 24]\n", 1084 | " 73 b'conv2d_9/Kernel' 1 37 [24, 1, 1, 24]\n", 1085 | " 74 b'conv2d_9/Bias' 1 38 [24]\n", 1086 | " 75 b'conv2d_9' 0 0 [1, 64, 64, 24]\n", 1087 | " 76 b'add_8' 0 0 [1, 64, 64, 24]\n", 1088 | " 77 b'activation_9' 0 0 [1, 64, 64, 24]\n", 1089 | " 78 b'depthwise_conv2d_9/Kernel' 1 39 [1, 3, 3, 24]\n", 1090 | " 79 b'depthwise_conv2d_9/Bias' 1 40 [24]\n", 1091 | " 80 b'depthwise_conv2d_9' 0 0 [1, 64, 64, 24]\n", 1092 | " 81 b'conv2d_10/Kernel' 1 41 [24, 1, 1, 24]\n", 1093 | " 82 b'conv2d_10/Bias' 1 42 [24]\n", 1094 | " 83 b'conv2d_10' 0 0 [1, 64, 64, 24]\n", 1095 | " 84 b'add_9' 0 0 [1, 64, 64, 24]\n", 1096 | " 85 b'activation_10' 0 0 [1, 64, 64, 24]\n", 1097 | " 86 b'depthwise_conv2d_10/Kernel' 1 43 [1, 3, 3, 24]\n", 1098 | " 87 b'depthwise_conv2d_10/Bias' 1 44 [24]\n", 1099 | " 88 b'depthwise_conv2d_10' 0 0 [1, 64, 64, 24]\n", 1100 | " 89 b'conv2d_11/Kernel' 1 45 [24, 1, 1, 24]\n", 1101 | " 90 b'conv2d_11/Bias' 1 46 [24]\n", 1102 | " 91 b'conv2d_11' 0 0 [1, 64, 64, 24]\n", 1103 | " 92 b'add_10' 0 0 [1, 64, 64, 24]\n", 1104 | " 93 b'activation_11' 0 0 [1, 64, 64, 24]\n", 1105 | " 94 b'depthwise_conv2d_11/Kernel' 1 47 [1, 3, 3, 24]\n", 1106 | " 95 b'depthwise_conv2d_11/Bias' 1 48 [24]\n", 1107 | " 96 b'depthwise_conv2d_11' 0 0 [1, 64, 64, 24]\n", 1108 | " 97 b'conv2d_12/Kernel' 1 49 [24, 1, 1, 24]\n", 1109 | " 98 b'conv2d_12/Bias' 1 50 [24]\n", 1110 | " 99 b'conv2d_12' 0 0 [1, 64, 64, 24]\n", 1111 | "100 b'add_11' 0 0 [1, 64, 64, 24]\n", 1112 | "101 b'activation_12' 0 0 [1, 64, 64, 24]\n", 1113 | "102 b'depthwise_conv2d_12/Kernel' 1 51 [1, 3, 3, 24]\n", 1114 | "103 b'depthwise_conv2d_12/Bias' 1 52 [24]\n", 1115 | "104 b'depthwise_conv2d_12' 0 0 [1, 64, 64, 24]\n", 1116 | "105 b'conv2d_13/Kernel' 1 53 [24, 1, 1, 24]\n", 1117 | "106 b'conv2d_13/Bias' 1 54 [24]\n", 1118 | "107 b'conv2d_13' 0 0 [1, 64, 64, 24]\n", 1119 | "108 b'add_12' 0 0 [1, 64, 64, 24]\n", 1120 | "109 b'activation_13' 0 0 [1, 64, 64, 24]\n", 1121 | "110 b'depthwise_conv2d_13/Kernel' 1 55 [1, 3, 3, 24]\n", 1122 | "111 b'depthwise_conv2d_13/Bias' 1 56 [24]\n", 1123 | "112 b'depthwise_conv2d_13' 0 0 [1, 64, 64, 24]\n", 1124 | "113 b'conv2d_14/Kernel' 1 57 [24, 1, 1, 24]\n", 1125 | "114 b'conv2d_14/Bias' 1 58 [24]\n", 1126 | "115 b'conv2d_14' 0 0 [1, 64, 64, 24]\n", 1127 | "116 b'add_13' 0 0 [1, 64, 64, 24]\n", 1128 | "117 b'activation_14' 0 0 [1, 64, 64, 24]\n", 1129 | "118 
b'depthwise_conv2d_14/Kernel' 1 59 [1, 3, 3, 24]\n", 1130 | "119 b'depthwise_conv2d_14/Bias' 1 60 [24]\n", 1131 | "120 b'depthwise_conv2d_14' 0 0 [1, 64, 64, 24]\n", 1132 | "121 b'conv2d_15/Kernel' 1 61 [24, 1, 1, 24]\n", 1133 | "122 b'conv2d_15/Bias' 1 62 [24]\n", 1134 | "123 b'conv2d_15' 0 0 [1, 64, 64, 24]\n", 1135 | "124 b'add_14' 0 0 [1, 64, 64, 24]\n", 1136 | "125 b'activation_15' 0 0 [1, 64, 64, 24]\n", 1137 | "126 b'depthwise_conv2d_15/Kernel' 1 63 [1, 3, 3, 24]\n", 1138 | "127 b'depthwise_conv2d_15/Bias' 1 64 [24]\n", 1139 | "128 b'depthwise_conv2d_15' 0 0 [1, 32, 32, 24]\n", 1140 | "129 b'max_pooling2d_1' 0 0 [1, 32, 32, 24]\n", 1141 | "130 b'conv2d_16/Kernel' 1 65 [48, 1, 1, 24]\n", 1142 | "131 b'conv2d_16/Bias' 1 66 [48]\n", 1143 | "132 b'conv2d_16' 0 0 [1, 32, 32, 48]\n", 1144 | "133 b'channel_padding/Paddings' 2 67 [4, 2]\n", 1145 | "134 b'channel_padding' 0 0 [1, 32, 32, 48]\n", 1146 | "135 b'add_15' 0 0 [1, 32, 32, 48]\n", 1147 | "136 b'activation_16' 0 0 [1, 32, 32, 48]\n", 1148 | "137 b'depthwise_conv2d_16/Kernel' 1 68 [1, 3, 3, 48]\n", 1149 | "138 b'depthwise_conv2d_16/Bias' 1 69 [48]\n", 1150 | "139 b'depthwise_conv2d_16' 0 0 [1, 32, 32, 48]\n", 1151 | "140 b'conv2d_17/Kernel' 1 70 [48, 1, 1, 48]\n", 1152 | "141 b'conv2d_17/Bias' 1 71 [48]\n", 1153 | "142 b'conv2d_17' 0 0 [1, 32, 32, 48]\n", 1154 | "143 b'add_16' 0 0 [1, 32, 32, 48]\n", 1155 | "144 b'activation_17' 0 0 [1, 32, 32, 48]\n", 1156 | "145 b'depthwise_conv2d_17/Kernel' 1 72 [1, 3, 3, 48]\n", 1157 | "146 b'depthwise_conv2d_17/Bias' 1 73 [48]\n", 1158 | "147 b'depthwise_conv2d_17' 0 0 [1, 32, 32, 48]\n", 1159 | "148 b'conv2d_18/Kernel' 1 74 [48, 1, 1, 48]\n", 1160 | "149 b'conv2d_18/Bias' 1 75 [48]\n", 1161 | "150 b'conv2d_18' 0 0 [1, 32, 32, 48]\n", 1162 | "151 b'add_17' 0 0 [1, 32, 32, 48]\n", 1163 | "152 b'activation_18' 0 0 [1, 32, 32, 48]\n", 1164 | "153 b'depthwise_conv2d_18/Kernel' 1 76 [1, 3, 3, 48]\n", 1165 | "154 b'depthwise_conv2d_18/Bias' 1 77 [48]\n", 1166 | "155 b'depthwise_conv2d_18' 0 0 [1, 32, 32, 48]\n", 1167 | "156 b'conv2d_19/Kernel' 1 78 [48, 1, 1, 48]\n", 1168 | "157 b'conv2d_19/Bias' 1 79 [48]\n", 1169 | "158 b'conv2d_19' 0 0 [1, 32, 32, 48]\n", 1170 | "159 b'add_18' 0 0 [1, 32, 32, 48]\n", 1171 | "160 b'activation_19' 0 0 [1, 32, 32, 48]\n", 1172 | "161 b'depthwise_conv2d_19/Kernel' 1 80 [1, 3, 3, 48]\n", 1173 | "162 b'depthwise_conv2d_19/Bias' 1 81 [48]\n", 1174 | "163 b'depthwise_conv2d_19' 0 0 [1, 32, 32, 48]\n", 1175 | "164 b'conv2d_20/Kernel' 1 82 [48, 1, 1, 48]\n", 1176 | "165 b'conv2d_20/Bias' 1 83 [48]\n", 1177 | "166 b'conv2d_20' 0 0 [1, 32, 32, 48]\n", 1178 | "167 b'add_19' 0 0 [1, 32, 32, 48]\n", 1179 | "168 b'activation_20' 0 0 [1, 32, 32, 48]\n", 1180 | "169 b'depthwise_conv2d_20/Kernel' 1 84 [1, 3, 3, 48]\n", 1181 | "170 b'depthwise_conv2d_20/Bias' 1 85 [48]\n", 1182 | "171 b'depthwise_conv2d_20' 0 0 [1, 32, 32, 48]\n", 1183 | "172 b'conv2d_21/Kernel' 1 86 [48, 1, 1, 48]\n", 1184 | "173 b'conv2d_21/Bias' 1 87 [48]\n", 1185 | "174 b'conv2d_21' 0 0 [1, 32, 32, 48]\n", 1186 | "175 b'add_20' 0 0 [1, 32, 32, 48]\n", 1187 | "176 b'activation_21' 0 0 [1, 32, 32, 48]\n", 1188 | "177 b'depthwise_conv2d_21/Kernel' 1 88 [1, 3, 3, 48]\n", 1189 | "178 b'depthwise_conv2d_21/Bias' 1 89 [48]\n", 1190 | "179 b'depthwise_conv2d_21' 0 0 [1, 32, 32, 48]\n", 1191 | "180 b'conv2d_22/Kernel' 1 90 [48, 1, 1, 48]\n", 1192 | "181 b'conv2d_22/Bias' 1 91 [48]\n", 1193 | "182 b'conv2d_22' 0 0 [1, 32, 32, 48]\n", 1194 | "183 b'add_21' 0 0 [1, 32, 32, 48]\n", 1195 | "184 b'activation_22' 0 0 [1, 32, 32, 
48]\n", 1196 | "185 b'depthwise_conv2d_22/Kernel' 1 92 [1, 3, 3, 48]\n", 1197 | "186 b'depthwise_conv2d_22/Bias' 1 93 [48]\n", 1198 | "187 b'depthwise_conv2d_22' 0 0 [1, 32, 32, 48]\n", 1199 | "188 b'conv2d_23/Kernel' 1 94 [48, 1, 1, 48]\n", 1200 | "189 b'conv2d_23/Bias' 1 95 [48]\n", 1201 | "190 b'conv2d_23' 0 0 [1, 32, 32, 48]\n", 1202 | "191 b'add_22' 0 0 [1, 32, 32, 48]\n", 1203 | "192 b'activation_23' 0 0 [1, 32, 32, 48]\n", 1204 | "193 b'depthwise_conv2d_23/Kernel' 1 96 [1, 3, 3, 48]\n", 1205 | "194 b'depthwise_conv2d_23/Bias' 1 97 [48]\n", 1206 | "195 b'depthwise_conv2d_23' 0 0 [1, 16, 16, 48]\n", 1207 | "196 b'max_pooling2d_2' 0 0 [1, 16, 16, 48]\n", 1208 | "197 b'conv2d_24/Kernel' 1 98 [96, 1, 1, 48]\n", 1209 | "198 b'conv2d_24/Bias' 1 99 [96]\n", 1210 | "199 b'conv2d_24' 0 0 [1, 16, 16, 96]\n", 1211 | "200 b'channel_padding_1/Paddings' 2 100 [4, 2]\n", 1212 | "201 b'channel_padding_1' 0 0 [1, 16, 16, 96]\n", 1213 | "202 b'add_23' 0 0 [1, 16, 16, 96]\n", 1214 | "203 b'activation_24' 0 0 [1, 16, 16, 96]\n", 1215 | "204 b'depthwise_conv2d_24/Kernel' 1 101 [1, 3, 3, 96]\n", 1216 | "205 b'depthwise_conv2d_24/Bias' 1 102 [96]\n", 1217 | "206 b'depthwise_conv2d_24' 0 0 [1, 16, 16, 96]\n", 1218 | "207 b'conv2d_25/Kernel' 1 103 [96, 1, 1, 96]\n", 1219 | "208 b'conv2d_25/Bias' 1 104 [96]\n", 1220 | "209 b'conv2d_25' 0 0 [1, 16, 16, 96]\n", 1221 | "210 b'add_24' 0 0 [1, 16, 16, 96]\n", 1222 | "211 b'activation_25' 0 0 [1, 16, 16, 96]\n", 1223 | "212 b'depthwise_conv2d_25/Kernel' 1 105 [1, 3, 3, 96]\n", 1224 | "213 b'depthwise_conv2d_25/Bias' 1 106 [96]\n", 1225 | "214 b'depthwise_conv2d_25' 0 0 [1, 16, 16, 96]\n", 1226 | "215 b'conv2d_26/Kernel' 1 107 [96, 1, 1, 96]\n", 1227 | "216 b'conv2d_26/Bias' 1 108 [96]\n", 1228 | "217 b'conv2d_26' 0 0 [1, 16, 16, 96]\n", 1229 | "218 b'add_25' 0 0 [1, 16, 16, 96]\n", 1230 | "219 b'activation_26' 0 0 [1, 16, 16, 96]\n", 1231 | "220 b'depthwise_conv2d_26/Kernel' 1 109 [1, 3, 3, 96]\n", 1232 | "221 b'depthwise_conv2d_26/Bias' 1 110 [96]\n", 1233 | "222 b'depthwise_conv2d_26' 0 0 [1, 16, 16, 96]\n", 1234 | "223 b'conv2d_27/Kernel' 1 111 [96, 1, 1, 96]\n", 1235 | "224 b'conv2d_27/Bias' 1 112 [96]\n", 1236 | "225 b'conv2d_27' 0 0 [1, 16, 16, 96]\n", 1237 | "226 b'add_26' 0 0 [1, 16, 16, 96]\n", 1238 | "227 b'activation_27' 0 0 [1, 16, 16, 96]\n", 1239 | "228 b'depthwise_conv2d_27/Kernel' 1 113 [1, 3, 3, 96]\n", 1240 | "229 b'depthwise_conv2d_27/Bias' 1 114 [96]\n", 1241 | "230 b'depthwise_conv2d_27' 0 0 [1, 16, 16, 96]\n", 1242 | "231 b'conv2d_28/Kernel' 1 115 [96, 1, 1, 96]\n", 1243 | "232 b'conv2d_28/Bias' 1 116 [96]\n", 1244 | "233 b'conv2d_28' 0 0 [1, 16, 16, 96]\n", 1245 | "234 b'add_27' 0 0 [1, 16, 16, 96]\n", 1246 | "235 b'activation_28' 0 0 [1, 16, 16, 96]\n", 1247 | "236 b'depthwise_conv2d_28/Kernel' 1 117 [1, 3, 3, 96]\n", 1248 | "237 b'depthwise_conv2d_28/Bias' 1 118 [96]\n", 1249 | "238 b'depthwise_conv2d_28' 0 0 [1, 16, 16, 96]\n", 1250 | "239 b'conv2d_29/Kernel' 1 119 [96, 1, 1, 96]\n", 1251 | "240 b'conv2d_29/Bias' 1 120 [96]\n", 1252 | "241 b'conv2d_29' 0 0 [1, 16, 16, 96]\n", 1253 | "242 b'add_28' 0 0 [1, 16, 16, 96]\n", 1254 | "243 b'activation_29' 0 0 [1, 16, 16, 96]\n", 1255 | "244 b'depthwise_conv2d_29/Kernel' 1 121 [1, 3, 3, 96]\n", 1256 | "245 b'depthwise_conv2d_29/Bias' 1 122 [96]\n", 1257 | "246 b'depthwise_conv2d_29' 0 0 [1, 16, 16, 96]\n", 1258 | "247 b'conv2d_30/Kernel' 1 123 [96, 1, 1, 96]\n", 1259 | "248 b'conv2d_30/Bias' 1 124 [96]\n", 1260 | "249 b'conv2d_30' 0 0 [1, 16, 16, 96]\n", 1261 | "250 b'add_29' 0 0 [1, 16, 16, 
96]\n", 1262 | "251 b'activation_30' 0 0 [1, 16, 16, 96]\n", 1263 | "252 b'depthwise_conv2d_30/Kernel' 1 125 [1, 3, 3, 96]\n", 1264 | "253 b'depthwise_conv2d_30/Bias' 1 126 [96]\n", 1265 | "254 b'depthwise_conv2d_30' 0 0 [1, 16, 16, 96]\n", 1266 | "255 b'conv2d_31/Kernel' 1 127 [96, 1, 1, 96]\n", 1267 | "256 b'conv2d_31/Bias' 1 128 [96]\n", 1268 | "257 b'conv2d_31' 0 0 [1, 16, 16, 96]\n", 1269 | "258 b'add_30' 0 0 [1, 16, 16, 96]\n", 1270 | "259 b'activation_31' 0 0 [1, 16, 16, 96]\n", 1271 | "260 b'separable_conv2d__xeno_compat__depthwise/Kernel' 1 129 [1, 3, 3, 96]\n", 1272 | "261 b'separable_conv2d__xeno_compat__depthwise/Bias' 1 130 [96]\n", 1273 | "262 b'separable_conv2d__xeno_compat__depthwise' 0 0 [1, 8, 8, 96]\n", 1274 | "263 b'separable_conv2d/Kernel' 1 131 [96, 1, 1, 96]\n", 1275 | "264 b'separable_conv2d/Bias' 1 132 [96]\n", 1276 | "265 b'separable_conv2d' 0 0 [1, 8, 8, 96]\n", 1277 | "266 b'activation_32' 0 0 [1, 8, 8, 96]\n", 1278 | "267 b'classificator_16/Kernel' 1 133 [2, 1, 1, 96]\n", 1279 | "268 b'classificator_16/Bias' 1 134 [2]\n", 1280 | "269 b'classificator_16' 0 0 [1, 16, 16, 2]\n", 1281 | "270 b'classificator_32/Kernel' 1 135 [6, 1, 1, 96]\n", 1282 | "271 b'classificator_32/Bias' 1 136 [6]\n", 1283 | "272 b'classificator_32' 0 0 [1, 8, 8, 6]\n", 1284 | "273 b'regressor_16/Kernel' 1 137 [32, 1, 1, 96]\n", 1285 | "274 b'regressor_16/Bias' 1 138 [32]\n", 1286 | "275 b'regressor_16' 0 0 [1, 16, 16, 32]\n", 1287 | "276 b'regressor_32/Kernel' 1 139 [96, 1, 1, 96]\n", 1288 | "277 b'regressor_32/Bias' 1 140 [96]\n", 1289 | "278 b'regressor_32' 0 0 [1, 8, 8, 96]\n", 1290 | "279 b'reshape' 0 0 [1, 512, 1]\n", 1291 | "280 b'reshape_2' 0 0 [1, 384, 1]\n", 1292 | "281 b'reshape_1' 0 0 [1, 512, 16]\n", 1293 | "282 b'reshape_3' 0 0 [1, 384, 16]\n", 1294 | "283 b'classificators' 0 0 [1, 896, 1]\n", 1295 | "284 b'regressors' 0 0 [1, 896, 16]\n", 1296 | "285 b'depthwise_conv2d_22/Kernel_dequantize' 0 0 [1, 3, 3, 48]\n", 1297 | "286 b'conv2d_18/Bias_dequantize' 0 0 [48]\n", 1298 | "287 b'conv2d_14/Kernel_dequantize' 0 0 [24, 1, 1, 24]\n", 1299 | "288 b'separable_conv2d/Bias_dequantize' 0 0 [96]\n", 1300 | "289 b'conv2d_5/Bias_dequantize' 0 0 [24]\n", 1301 | "290 b'depthwise_conv2d_27/Kernel_dequantize' 0 0 [1, 3, 3, 96]\n", 1302 | "291 b'depthwise_conv2d_21/Bias_dequantize' 0 0 [48]\n", 1303 | "292 b'conv2d_13/Bias_dequantize' 0 0 [24]\n", 1304 | "293 b'depthwise_conv2d_26/Bias_dequantize' 0 0 [96]\n", 1305 | "294 b'depthwise_conv2d_8/Kernel_dequantize' 0 0 [1, 3, 3, 24]\n", 1306 | "295 b'conv2d_25/Kernel_dequantize' 0 0 [96, 1, 1, 96]\n", 1307 | "296 b'conv2d_7/Kernel_dequantize' 0 0 [24, 1, 1, 24]\n", 1308 | "297 b'depthwise_conv2d_28/Kernel_dequantize' 0 0 [1, 3, 3, 96]\n", 1309 | "298 b'conv2d_6/Bias_dequantize' 0 0 [24]\n", 1310 | "299 b'depthwise_conv2d_1/Kernel_dequantize' 0 0 [1, 3, 3, 24]\n", 1311 | "300 b'conv2d_20/Kernel_dequantize' 0 0 [48, 1, 1, 48]\n", 1312 | "301 b'conv2d_14/Bias_dequantize' 0 0 [24]\n", 1313 | "302 b'depthwise_conv2d_9/Kernel_dequantize' 0 0 [1, 3, 3, 24]\n", 1314 | "303 b'depthwise_conv2d_27/Bias_dequantize' 0 0 [96]\n", 1315 | "304 b'conv2d_19/Bias_dequantize' 0 0 [48]\n", 1316 | "305 b'depthwise_conv2d/Bias_dequantize' 0 0 [24]\n", 1317 | "306 b'depthwise_conv2d_23/Kernel_dequantize' 0 0 [1, 3, 3, 48]\n", 1318 | "307 b'conv2d_15/Kernel_dequantize' 0 0 [24, 1, 1, 24]\n", 1319 | "308 b'depthwise_conv2d_8/Bias_dequantize' 0 0 [24]\n", 1320 | "309 b'depthwise_conv2d_22/Bias_dequantize' 0 0 [48]\n", 1321 | "310 b'conv2d_8/Kernel_dequantize' 0 0 [24, 1, 1, 
24]\n", 1322 | "311 b'conv2d_26/Kernel_dequantize' 0 0 [96, 1, 1, 96]\n", 1323 | "312 b'depthwise_conv2d_1/Bias_dequantize' 0 0 [24]\n", 1324 | "313 b'conv2d_20/Bias_dequantize' 0 0 [48]\n", 1325 | "314 b'conv2d_7/Bias_dequantize' 0 0 [24]\n", 1326 | "315 b'conv2d_21/Kernel_dequantize' 0 0 [48, 1, 1, 48]\n", 1327 | "316 b'depthwise_conv2d_2/Kernel_dequantize' 0 0 [1, 3, 3, 24]\n", 1328 | "317 b'conv2d_25/Bias_dequantize' 0 0 [96]\n", 1329 | "318 b'depthwise_conv2d_23/Bias_dequantize' 0 0 [48]\n", 1330 | "319 b'depthwise_conv2d_29/Kernel_dequantize' 0 0 [1, 3, 3, 96]\n", 1331 | "320 b'conv2d_15/Bias_dequantize' 0 0 [24]\n", 1332 | "321 b'depthwise_conv2d_28/Bias_dequantize' 0 0 [96]\n", 1333 | "322 b'depthwise_conv2d_10/Kernel_dequantize' 0 0 [1, 3, 3, 24]\n", 1334 | "323 b'regressor_16/Kernel_dequantize' 0 0 [32, 1, 1, 96]\n", 1335 | "324 b'conv2d_27/Kernel_dequantize' 0 0 [96, 1, 1, 96]\n", 1336 | "325 b'depthwise_conv2d_9/Bias_dequantize' 0 0 [24]\n", 1337 | "326 b'conv2d_26/Bias_dequantize' 0 0 [96]\n", 1338 | "327 b'conv2d_8/Bias_dequantize' 0 0 [24]\n", 1339 | "328 b'depthwise_conv2d_3/Kernel_dequantize' 0 0 [1, 3, 3, 24]\n", 1340 | "329 b'conv2d_22/Kernel_dequantize' 0 0 [48, 1, 1, 48]\n", 1341 | "330 b'depthwise_conv2d_11/Kernel_dequantize' 0 0 [1, 3, 3, 24]\n", 1342 | "331 b'depthwise_conv2d_29/Bias_dequantize' 0 0 [96]\n", 1343 | "332 b'depthwise_conv2d_2/Bias_dequantize' 0 0 [24]\n", 1344 | "333 b'depthwise_conv2d_16/Kernel_dequantize' 0 0 [1, 3, 3, 48]\n", 1345 | "334 b'conv2d_21/Bias_dequantize' 0 0 [48]\n", 1346 | "335 b'depthwise_conv2d_30/Kernel_dequantize' 0 0 [1, 3, 3, 96]\n", 1347 | "336 b'depthwise_conv2d_10/Bias_dequantize' 0 0 [24]\n", 1348 | "337 b'regressor_16/Bias_dequantize' 0 0 [32]\n", 1349 | "338 b'conv2d_16/Kernel_dequantize' 0 0 [48, 1, 1, 24]\n", 1350 | "339 b'classificator_16/Kernel_dequantize' 0 0 [2, 1, 1, 96]\n", 1351 | "340 b'conv2d_28/Kernel_dequantize' 0 0 [96, 1, 1, 96]\n", 1352 | "341 b'conv2d_1/Kernel_dequantize' 0 0 [24, 1, 1, 24]\n", 1353 | "342 b'depthwise_conv2d_17/Kernel_dequantize' 0 0 [1, 3, 3, 48]\n", 1354 | "343 b'separable_conv2d__xeno_compat__depthwise/Kernel_dequantize' 0 0 [1, 3, 3, 96]\n", 1355 | "344 b'conv2d_9/Kernel_dequantize' 0 0 [24, 1, 1, 24]\n", 1356 | "345 b'depthwise_conv2d_4/Kernel_dequantize' 0 0 [1, 3, 3, 24]\n", 1357 | "346 b'conv2d_27/Bias_dequantize' 0 0 [96]\n", 1358 | "347 b'conv2d/Kernel_dequantize' 0 0 [24, 5, 5, 3]\n", 1359 | "348 b'conv2d_23/Kernel_dequantize' 0 0 [48, 1, 1, 48]\n", 1360 | "349 b'depthwise_conv2d_12/Kernel_dequantize' 0 0 [1, 3, 3, 24]\n", 1361 | "350 b'depthwise_conv2d_16/Bias_dequantize' 0 0 [48]\n", 1362 | "351 b'depthwise_conv2d_30/Bias_dequantize' 0 0 [96]\n", 1363 | "352 b'conv2d_22/Bias_dequantize' 0 0 [48]\n", 1364 | "353 b'depthwise_conv2d_3/Bias_dequantize' 0 0 [24]\n", 1365 | "354 b'conv2d_2/Kernel_dequantize' 0 0 [24, 1, 1, 24]\n", 1366 | "355 b'conv2d_16/Bias_dequantize' 0 0 [48]\n", 1367 | "356 b'depthwise_conv2d_11/Bias_dequantize' 0 0 [24]\n", 1368 | "357 b'depthwise_conv2d_5/Kernel_dequantize' 0 0 [1, 3, 3, 24]\n", 1369 | "358 b'conv2d_1/Bias_dequantize' 0 0 [24]\n", 1370 | "359 b'conv2d_29/Kernel_dequantize' 0 0 [96, 1, 1, 96]\n", 1371 | "360 b'conv2d_9/Bias_dequantize' 0 0 [24]\n", 1372 | "361 b'depthwise_conv2d_4/Bias_dequantize' 0 0 [24]\n", 1373 | "362 b'conv2d_23/Bias_dequantize' 0 0 [48]\n", 1374 | "363 b'conv2d/Bias_dequantize' 0 0 [24]\n", 1375 | "364 b'depthwise_conv2d_18/Kernel_dequantize' 0 0 [1, 3, 3, 48]\n", 1376 | "365 b'conv2d_28/Bias_dequantize' 0 0 [96]\n", 1377 
| "366 b'conv2d_10/Kernel_dequantize' 0 0 [24, 1, 1, 24]\n", 1378 | "367 b'depthwise_conv2d_12/Bias_dequantize' 0 0 [24]\n", 1379 | "368 b'classificator_16/Bias_dequantize' 0 0 [2]\n", 1380 | "369 b'depthwise_conv2d_17/Bias_dequantize' 0 0 [48]\n", 1381 | "370 b'separable_conv2d__xeno_compat__depthwise/Bias_dequantize' 0 0 [96]\n", 1382 | "371 b'depthwise_conv2d_13/Kernel_dequantize' 0 0 [1, 3, 3, 24]\n", 1383 | "372 b'conv2d_30/Kernel_dequantize' 0 0 [96, 1, 1, 96]\n", 1384 | "373 b'conv2d_3/Kernel_dequantize' 0 0 [24, 1, 1, 24]\n", 1385 | "374 b'regressor_32/Kernel_dequantize' 0 0 [96, 1, 1, 96]\n", 1386 | "375 b'conv2d_11/Kernel_dequantize' 0 0 [24, 1, 1, 24]\n", 1387 | "376 b'conv2d_29/Bias_dequantize' 0 0 [96]\n", 1388 | "377 b'depthwise_conv2d_6/Kernel_dequantize' 0 0 [1, 3, 3, 24]\n", 1389 | "378 b'depthwise_conv2d_24/Kernel_dequantize' 0 0 [1, 3, 3, 96]\n", 1390 | "379 b'conv2d_2/Bias_dequantize' 0 0 [24]\n", 1391 | "380 b'depthwise_conv2d_18/Bias_dequantize' 0 0 [48]\n", 1392 | "381 b'depthwise_conv2d_14/Kernel_dequantize' 0 0 [1, 3, 3, 24]\n", 1393 | "382 b'conv2d_10/Bias_dequantize' 0 0 [24]\n", 1394 | "383 b'conv2d_24/Kernel_dequantize' 0 0 [96, 1, 1, 48]\n", 1395 | "384 b'depthwise_conv2d_5/Bias_dequantize' 0 0 [24]\n", 1396 | "385 b'depthwise_conv2d_19/Kernel_dequantize' 0 0 [1, 3, 3, 48]\n", 1397 | "386 b'conv2d_4/Kernel_dequantize' 0 0 [24, 1, 1, 24]\n", 1398 | "387 b'depthwise_conv2d_13/Bias_dequantize' 0 0 [24]\n", 1399 | "388 b'conv2d_3/Bias_dequantize' 0 0 [24]\n", 1400 | "389 b'conv2d_17/Kernel_dequantize' 0 0 [48, 1, 1, 48]\n", 1401 | "390 b'conv2d_31/Kernel_dequantize' 0 0 [96, 1, 1, 96]\n", 1402 | "391 b'depthwise_conv2d_6/Bias_dequantize' 0 0 [24]\n", 1403 | "392 b'depthwise_conv2d_24/Bias_dequantize' 0 0 [96]\n", 1404 | "393 b'depthwise_conv2d_20/Kernel_dequantize' 0 0 [1, 3, 3, 48]\n", 1405 | "394 b'conv2d_12/Kernel_dequantize' 0 0 [24, 1, 1, 24]\n", 1406 | "395 b'depthwise_conv2d_25/Kernel_dequantize' 0 0 [1, 3, 3, 96]\n", 1407 | "396 b'depthwise_conv2d_7/Kernel_dequantize' 0 0 [1, 3, 3, 24]\n", 1408 | "397 b'conv2d_30/Bias_dequantize' 0 0 [96]\n", 1409 | "398 b'depthwise_conv2d_19/Bias_dequantize' 0 0 [48]\n", 1410 | "399 b'conv2d_24/Bias_dequantize' 0 0 [96]\n", 1411 | "400 b'regressor_32/Bias_dequantize' 0 0 [96]\n", 1412 | "401 b'depthwise_conv2d_15/Kernel_dequantize' 0 0 [1, 3, 3, 24]\n", 1413 | "402 b'conv2d_11/Bias_dequantize' 0 0 [24]\n", 1414 | "403 b'separable_conv2d/Kernel_dequantize' 0 0 [96, 1, 1, 96]\n", 1415 | "404 b'conv2d_5/Kernel_dequantize' 0 0 [24, 1, 1, 24]\n", 1416 | "405 b'depthwise_conv2d_14/Bias_dequantize' 0 0 [24]\n", 1417 | "406 b'classificator_32/Kernel_dequantize' 0 0 [6, 1, 1, 96]\n", 1418 | "407 b'conv2d_31/Bias_dequantize' 0 0 [96]\n", 1419 | "408 b'conv2d_13/Kernel_dequantize' 0 0 [24, 1, 1, 24]\n", 1420 | "409 b'depthwise_conv2d_26/Kernel_dequantize' 0 0 [1, 3, 3, 96]\n", 1421 | "410 b'conv2d_4/Bias_dequantize' 0 0 [24]\n", 1422 | "411 b'conv2d_18/Kernel_dequantize' 0 0 [48, 1, 1, 48]\n", 1423 | "412 b'conv2d_12/Bias_dequantize' 0 0 [24]\n", 1424 | "413 b'depthwise_conv2d_7/Bias_dequantize' 0 0 [24]\n", 1425 | "414 b'depthwise_conv2d_21/Kernel_dequantize' 0 0 [1, 3, 3, 48]\n", 1426 | "415 b'depthwise_conv2d_25/Bias_dequantize' 0 0 [96]\n", 1427 | "416 b'conv2d_17/Bias_dequantize' 0 0 [48]\n", 1428 | "417 b'depthwise_conv2d_15/Bias_dequantize' 0 0 [24]\n", 1429 | "418 b'depthwise_conv2d_20/Bias_dequantize' 0 0 [48]\n", 1430 | "419 b'conv2d_19/Kernel_dequantize' 0 0 [48, 1, 1, 48]\n", 1431 | "420 
b'depthwise_conv2d/Kernel_dequantize' 0 0 [1, 3, 3, 24]\n", 1432 | "421 b'classificator_32/Bias_dequantize' 0 0 [6]\n", 1433 | "422 b'conv2d_6/Kernel_dequantize' 0 0 [24, 1, 1, 24]\n" 1434 | ] 1435 | } 1436 | ], 1437 | "source": [ 1438 | "print_graph(back_subgraph)" 1439 | ] 1440 | }, 1441 | { 1442 | "cell_type": "code", 1443 | "execution_count": 22, 1444 | "metadata": {}, 1445 | "outputs": [], 1446 | "source": [ 1447 | "back_tensor_dict = {(back_subgraph.Tensors(i).Name().decode(\"utf8\")): i \n", 1448 | " for i in range(back_subgraph.TensorsLength())}" 1449 | ] 1450 | }, 1451 | { 1452 | "cell_type": "code", 1453 | "execution_count": 23, 1454 | "metadata": {}, 1455 | "outputs": [ 1456 | { 1457 | "data": { 1458 | "text/plain": [ 1459 | "140" 1460 | ] 1461 | }, 1462 | "execution_count": 23, 1463 | "metadata": {}, 1464 | "output_type": "execute_result" 1465 | } 1466 | ], 1467 | "source": [ 1468 | "back_parameters = get_parameters(back_subgraph)\n", 1469 | "len(back_parameters)" 1470 | ] 1471 | }, 1472 | { 1473 | "cell_type": "code", 1474 | "execution_count": 24, 1475 | "metadata": {}, 1476 | "outputs": [ 1477 | { 1478 | "data": { 1479 | "text/plain": [ 1480 | "((24, 5, 5, 3), (24,))" 1481 | ] 1482 | }, 1483 | "execution_count": 24, 1484 | "metadata": {}, 1485 | "output_type": "execute_result" 1486 | } 1487 | ], 1488 | "source": [ 1489 | "W = get_weights(back_model, back_subgraph, back_tensor_dict, \"conv2d/Kernel\")\n", 1490 | "b = get_weights(back_model, back_subgraph, back_tensor_dict, \"conv2d/Bias\")\n", 1491 | "W.shape, b.shape" 1492 | ] 1493 | }, 1494 | { 1495 | "cell_type": "code", 1496 | "execution_count": 25, 1497 | "metadata": {}, 1498 | "outputs": [], 1499 | "source": [ 1500 | "back_net = BlazeFace(back_model=True)" 1501 | ] 1502 | }, 1503 | { 1504 | "cell_type": "code", 1505 | "execution_count": 26, 1506 | "metadata": {}, 1507 | "outputs": [ 1508 | { 1509 | "data": { 1510 | "text/plain": [ 1511 | "BlazeFace(\n", 1512 | " (backbone): Sequential(\n", 1513 | " (0): Conv2d(3, 24, kernel_size=(5, 5), stride=(2, 2))\n", 1514 | " (1): ReLU(inplace=True)\n", 1515 | " (2): BlazeBlock(\n", 1516 | " (convs): Sequential(\n", 1517 | " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)\n", 1518 | " (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))\n", 1519 | " )\n", 1520 | " (act): ReLU(inplace=True)\n", 1521 | " )\n", 1522 | " (3): BlazeBlock(\n", 1523 | " (convs): Sequential(\n", 1524 | " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)\n", 1525 | " (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))\n", 1526 | " )\n", 1527 | " (act): ReLU(inplace=True)\n", 1528 | " )\n", 1529 | " (4): BlazeBlock(\n", 1530 | " (convs): Sequential(\n", 1531 | " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)\n", 1532 | " (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))\n", 1533 | " )\n", 1534 | " (act): ReLU(inplace=True)\n", 1535 | " )\n", 1536 | " (5): BlazeBlock(\n", 1537 | " (convs): Sequential(\n", 1538 | " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)\n", 1539 | " (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))\n", 1540 | " )\n", 1541 | " (act): ReLU(inplace=True)\n", 1542 | " )\n", 1543 | " (6): BlazeBlock(\n", 1544 | " (convs): Sequential(\n", 1545 | " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)\n", 1546 | " (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))\n", 1547 | " )\n", 1548 | " (act): 
ReLU(inplace=True)\n", 1549 | " )\n", 1550 | " (7): BlazeBlock(\n", 1551 | " (convs): Sequential(\n", 1552 | " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)\n", 1553 | " (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))\n", 1554 | " )\n", 1555 | " (act): ReLU(inplace=True)\n", 1556 | " )\n", 1557 | " (8): BlazeBlock(\n", 1558 | " (convs): Sequential(\n", 1559 | " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)\n", 1560 | " (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))\n", 1561 | " )\n", 1562 | " (act): ReLU(inplace=True)\n", 1563 | " )\n", 1564 | " (9): BlazeBlock(\n", 1565 | " (max_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 1566 | " (convs): Sequential(\n", 1567 | " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(2, 2), groups=24)\n", 1568 | " (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))\n", 1569 | " )\n", 1570 | " (act): ReLU(inplace=True)\n", 1571 | " )\n", 1572 | " (10): BlazeBlock(\n", 1573 | " (convs): Sequential(\n", 1574 | " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)\n", 1575 | " (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))\n", 1576 | " )\n", 1577 | " (act): ReLU(inplace=True)\n", 1578 | " )\n", 1579 | " (11): BlazeBlock(\n", 1580 | " (convs): Sequential(\n", 1581 | " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)\n", 1582 | " (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))\n", 1583 | " )\n", 1584 | " (act): ReLU(inplace=True)\n", 1585 | " )\n", 1586 | " (12): BlazeBlock(\n", 1587 | " (convs): Sequential(\n", 1588 | " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)\n", 1589 | " (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))\n", 1590 | " )\n", 1591 | " (act): ReLU(inplace=True)\n", 1592 | " )\n", 1593 | " (13): BlazeBlock(\n", 1594 | " (convs): Sequential(\n", 1595 | " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)\n", 1596 | " (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))\n", 1597 | " )\n", 1598 | " (act): ReLU(inplace=True)\n", 1599 | " )\n", 1600 | " (14): BlazeBlock(\n", 1601 | " (convs): Sequential(\n", 1602 | " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)\n", 1603 | " (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))\n", 1604 | " )\n", 1605 | " (act): ReLU(inplace=True)\n", 1606 | " )\n", 1607 | " (15): BlazeBlock(\n", 1608 | " (convs): Sequential(\n", 1609 | " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)\n", 1610 | " (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))\n", 1611 | " )\n", 1612 | " (act): ReLU(inplace=True)\n", 1613 | " )\n", 1614 | " (16): BlazeBlock(\n", 1615 | " (convs): Sequential(\n", 1616 | " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24)\n", 1617 | " (1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1))\n", 1618 | " )\n", 1619 | " (act): ReLU(inplace=True)\n", 1620 | " )\n", 1621 | " (17): BlazeBlock(\n", 1622 | " (max_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 1623 | " (convs): Sequential(\n", 1624 | " (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(2, 2), groups=24)\n", 1625 | " (1): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1))\n", 1626 | " )\n", 1627 | " (act): ReLU(inplace=True)\n", 1628 | " )\n", 1629 | " (18): BlazeBlock(\n", 1630 | " (convs): Sequential(\n", 1631 | " (0): 
Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48)\n", 1632 | " (1): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1))\n", 1633 | " )\n", 1634 | " (act): ReLU(inplace=True)\n", 1635 | " )\n", 1636 | " (19): BlazeBlock(\n", 1637 | " (convs): Sequential(\n", 1638 | " (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48)\n", 1639 | " (1): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1))\n", 1640 | " )\n", 1641 | " (act): ReLU(inplace=True)\n", 1642 | " )\n", 1643 | " (20): BlazeBlock(\n", 1644 | " (convs): Sequential(\n", 1645 | " (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48)\n", 1646 | " (1): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1))\n", 1647 | " )\n", 1648 | " (act): ReLU(inplace=True)\n", 1649 | " )\n", 1650 | " (21): BlazeBlock(\n", 1651 | " (convs): Sequential(\n", 1652 | " (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48)\n", 1653 | " (1): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1))\n", 1654 | " )\n", 1655 | " (act): ReLU(inplace=True)\n", 1656 | " )\n", 1657 | " (22): BlazeBlock(\n", 1658 | " (convs): Sequential(\n", 1659 | " (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48)\n", 1660 | " (1): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1))\n", 1661 | " )\n", 1662 | " (act): ReLU(inplace=True)\n", 1663 | " )\n", 1664 | " (23): BlazeBlock(\n", 1665 | " (convs): Sequential(\n", 1666 | " (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48)\n", 1667 | " (1): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1))\n", 1668 | " )\n", 1669 | " (act): ReLU(inplace=True)\n", 1670 | " )\n", 1671 | " (24): BlazeBlock(\n", 1672 | " (convs): Sequential(\n", 1673 | " (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48)\n", 1674 | " (1): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1))\n", 1675 | " )\n", 1676 | " (act): ReLU(inplace=True)\n", 1677 | " )\n", 1678 | " (25): BlazeBlock(\n", 1679 | " (max_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", 1680 | " (convs): Sequential(\n", 1681 | " (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(2, 2), groups=48)\n", 1682 | " (1): Conv2d(48, 96, kernel_size=(1, 1), stride=(1, 1))\n", 1683 | " )\n", 1684 | " (act): ReLU(inplace=True)\n", 1685 | " )\n", 1686 | " (26): BlazeBlock(\n", 1687 | " (convs): Sequential(\n", 1688 | " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96)\n", 1689 | " (1): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))\n", 1690 | " )\n", 1691 | " (act): ReLU(inplace=True)\n", 1692 | " )\n", 1693 | " (27): BlazeBlock(\n", 1694 | " (convs): Sequential(\n", 1695 | " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96)\n", 1696 | " (1): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))\n", 1697 | " )\n", 1698 | " (act): ReLU(inplace=True)\n", 1699 | " )\n", 1700 | " (28): BlazeBlock(\n", 1701 | " (convs): Sequential(\n", 1702 | " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96)\n", 1703 | " (1): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))\n", 1704 | " )\n", 1705 | " (act): ReLU(inplace=True)\n", 1706 | " )\n", 1707 | " (29): BlazeBlock(\n", 1708 | " (convs): Sequential(\n", 1709 | " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96)\n", 1710 | " (1): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))\n", 1711 | " )\n", 1712 | " (act): ReLU(inplace=True)\n", 
1713 | " )\n", 1714 | " (30): BlazeBlock(\n", 1715 | " (convs): Sequential(\n", 1716 | " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96)\n", 1717 | " (1): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))\n", 1718 | " )\n", 1719 | " (act): ReLU(inplace=True)\n", 1720 | " )\n", 1721 | " (31): BlazeBlock(\n", 1722 | " (convs): Sequential(\n", 1723 | " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96)\n", 1724 | " (1): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))\n", 1725 | " )\n", 1726 | " (act): ReLU(inplace=True)\n", 1727 | " )\n", 1728 | " (32): BlazeBlock(\n", 1729 | " (convs): Sequential(\n", 1730 | " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96)\n", 1731 | " (1): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))\n", 1732 | " )\n", 1733 | " (act): ReLU(inplace=True)\n", 1734 | " )\n", 1735 | " )\n", 1736 | " (final): FinalBlazeBlock(\n", 1737 | " (convs): Sequential(\n", 1738 | " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), groups=96)\n", 1739 | " (1): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))\n", 1740 | " )\n", 1741 | " (act): ReLU(inplace=True)\n", 1742 | " )\n", 1743 | " (classifier_8): Conv2d(96, 2, kernel_size=(1, 1), stride=(1, 1))\n", 1744 | " (classifier_16): Conv2d(96, 6, kernel_size=(1, 1), stride=(1, 1))\n", 1745 | " (regressor_8): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1))\n", 1746 | " (regressor_16): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1))\n", 1747 | ")" 1748 | ] 1749 | }, 1750 | "execution_count": 26, 1751 | "metadata": {}, 1752 | "output_type": "execute_result" 1753 | } 1754 | ], 1755 | "source": [ 1756 | "back_net" 1757 | ] 1758 | }, 1759 | { 1760 | "cell_type": "code", 1761 | "execution_count": 27, 1762 | "metadata": {}, 1763 | "outputs": [ 1764 | { 1765 | "data": { 1766 | "text/plain": [ 1767 | "['conv2d/Kernel',\n", 1768 | " 'conv2d/Bias',\n", 1769 | " 'depthwise_conv2d/Kernel',\n", 1770 | " 'depthwise_conv2d/Bias',\n", 1771 | " 'conv2d_1/Kernel']" 1772 | ] 1773 | }, 1774 | "execution_count": 27, 1775 | "metadata": {}, 1776 | "output_type": "execute_result" 1777 | } 1778 | ], 1779 | "source": [ 1780 | "back_probable_names = get_probable_names(back_subgraph)\n", 1781 | "back_probable_names[:5]" 1782 | ] 1783 | }, 1784 | { 1785 | "cell_type": "code", 1786 | "execution_count": 28, 1787 | "metadata": {}, 1788 | "outputs": [], 1789 | "source": [ 1790 | "back_convert = get_convert(back_net, back_probable_names)" 1791 | ] 1792 | }, 1793 | { 1794 | "cell_type": "code", 1795 | "execution_count": 29, 1796 | "metadata": {}, 1797 | "outputs": [ 1798 | { 1799 | "name": "stdout", 1800 | "output_type": "stream", 1801 | "text": [ 1802 | "backbone.0.weight conv2d/Kernel (24, 5, 5, 3) torch.Size([24, 3, 5, 5])\n", 1803 | "backbone.0.bias conv2d/Bias (24,) torch.Size([24])\n", 1804 | "backbone.2.convs.0.weight depthwise_conv2d/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])\n", 1805 | "backbone.2.convs.0.bias depthwise_conv2d/Bias (24,) torch.Size([24])\n", 1806 | "backbone.2.convs.1.weight conv2d_1/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])\n", 1807 | "backbone.2.convs.1.bias conv2d_1/Bias (24,) torch.Size([24])\n", 1808 | "backbone.3.convs.0.weight depthwise_conv2d_1/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])\n", 1809 | "backbone.3.convs.0.bias depthwise_conv2d_1/Bias (24,) torch.Size([24])\n", 1810 | "backbone.3.convs.1.weight conv2d_2/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])\n", 1811 | "backbone.3.convs.1.bias conv2d_2/Bias 
(24,) torch.Size([24])\n", 1812 | "backbone.4.convs.0.weight depthwise_conv2d_2/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])\n", 1813 | "backbone.4.convs.0.bias depthwise_conv2d_2/Bias (24,) torch.Size([24])\n", 1814 | "backbone.4.convs.1.weight conv2d_3/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])\n", 1815 | "backbone.4.convs.1.bias conv2d_3/Bias (24,) torch.Size([24])\n", 1816 | "backbone.5.convs.0.weight depthwise_conv2d_3/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])\n", 1817 | "backbone.5.convs.0.bias depthwise_conv2d_3/Bias (24,) torch.Size([24])\n", 1818 | "backbone.5.convs.1.weight conv2d_4/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])\n", 1819 | "backbone.5.convs.1.bias conv2d_4/Bias (24,) torch.Size([24])\n", 1820 | "backbone.6.convs.0.weight depthwise_conv2d_4/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])\n", 1821 | "backbone.6.convs.0.bias depthwise_conv2d_4/Bias (24,) torch.Size([24])\n", 1822 | "backbone.6.convs.1.weight conv2d_5/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])\n", 1823 | "backbone.6.convs.1.bias conv2d_5/Bias (24,) torch.Size([24])\n", 1824 | "backbone.7.convs.0.weight depthwise_conv2d_5/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])\n", 1825 | "backbone.7.convs.0.bias depthwise_conv2d_5/Bias (24,) torch.Size([24])\n", 1826 | "backbone.7.convs.1.weight conv2d_6/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])\n", 1827 | "backbone.7.convs.1.bias conv2d_6/Bias (24,) torch.Size([24])\n", 1828 | "backbone.8.convs.0.weight depthwise_conv2d_6/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])\n", 1829 | "backbone.8.convs.0.bias depthwise_conv2d_6/Bias (24,) torch.Size([24])\n", 1830 | "backbone.8.convs.1.weight conv2d_7/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])\n", 1831 | "backbone.8.convs.1.bias conv2d_7/Bias (24,) torch.Size([24])\n", 1832 | "backbone.9.convs.0.weight depthwise_conv2d_7/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])\n", 1833 | "backbone.9.convs.0.bias depthwise_conv2d_7/Bias (24,) torch.Size([24])\n", 1834 | "backbone.9.convs.1.weight conv2d_8/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])\n", 1835 | "backbone.9.convs.1.bias conv2d_8/Bias (24,) torch.Size([24])\n", 1836 | "backbone.10.convs.0.weight depthwise_conv2d_8/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])\n", 1837 | "backbone.10.convs.0.bias depthwise_conv2d_8/Bias (24,) torch.Size([24])\n", 1838 | "backbone.10.convs.1.weight conv2d_9/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])\n", 1839 | "backbone.10.convs.1.bias conv2d_9/Bias (24,) torch.Size([24])\n", 1840 | "backbone.11.convs.0.weight depthwise_conv2d_9/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])\n", 1841 | "backbone.11.convs.0.bias depthwise_conv2d_9/Bias (24,) torch.Size([24])\n", 1842 | "backbone.11.convs.1.weight conv2d_10/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])\n", 1843 | "backbone.11.convs.1.bias conv2d_10/Bias (24,) torch.Size([24])\n", 1844 | "backbone.12.convs.0.weight depthwise_conv2d_10/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])\n", 1845 | "backbone.12.convs.0.bias depthwise_conv2d_10/Bias (24,) torch.Size([24])\n", 1846 | "backbone.12.convs.1.weight conv2d_11/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])\n", 1847 | "backbone.12.convs.1.bias conv2d_11/Bias (24,) torch.Size([24])\n", 1848 | "backbone.13.convs.0.weight depthwise_conv2d_11/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])\n", 1849 | "backbone.13.convs.0.bias depthwise_conv2d_11/Bias (24,) torch.Size([24])\n", 1850 | "backbone.13.convs.1.weight conv2d_12/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])\n", 1851 | 
"backbone.13.convs.1.bias conv2d_12/Bias (24,) torch.Size([24])\n", 1852 | "backbone.14.convs.0.weight depthwise_conv2d_12/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])\n", 1853 | "backbone.14.convs.0.bias depthwise_conv2d_12/Bias (24,) torch.Size([24])\n", 1854 | "backbone.14.convs.1.weight conv2d_13/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])\n", 1855 | "backbone.14.convs.1.bias conv2d_13/Bias (24,) torch.Size([24])\n", 1856 | "backbone.15.convs.0.weight depthwise_conv2d_13/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])\n", 1857 | "backbone.15.convs.0.bias depthwise_conv2d_13/Bias (24,) torch.Size([24])\n", 1858 | "backbone.15.convs.1.weight conv2d_14/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])\n", 1859 | "backbone.15.convs.1.bias conv2d_14/Bias (24,) torch.Size([24])\n", 1860 | "backbone.16.convs.0.weight depthwise_conv2d_14/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])\n", 1861 | "backbone.16.convs.0.bias depthwise_conv2d_14/Bias (24,) torch.Size([24])\n", 1862 | "backbone.16.convs.1.weight conv2d_15/Kernel (24, 1, 1, 24) torch.Size([24, 24, 1, 1])\n", 1863 | "backbone.16.convs.1.bias conv2d_15/Bias (24,) torch.Size([24])\n", 1864 | "backbone.17.convs.0.weight depthwise_conv2d_15/Kernel (1, 3, 3, 24) torch.Size([24, 1, 3, 3])\n", 1865 | "backbone.17.convs.0.bias depthwise_conv2d_15/Bias (24,) torch.Size([24])\n", 1866 | "backbone.17.convs.1.weight conv2d_16/Kernel (48, 1, 1, 24) torch.Size([48, 24, 1, 1])\n", 1867 | "backbone.17.convs.1.bias conv2d_16/Bias (48,) torch.Size([48])\n", 1868 | "backbone.18.convs.0.weight depthwise_conv2d_16/Kernel (1, 3, 3, 48) torch.Size([48, 1, 3, 3])\n", 1869 | "backbone.18.convs.0.bias depthwise_conv2d_16/Bias (48,) torch.Size([48])\n", 1870 | "backbone.18.convs.1.weight conv2d_17/Kernel (48, 1, 1, 48) torch.Size([48, 48, 1, 1])\n", 1871 | "backbone.18.convs.1.bias conv2d_17/Bias (48,) torch.Size([48])\n", 1872 | "backbone.19.convs.0.weight depthwise_conv2d_17/Kernel (1, 3, 3, 48) torch.Size([48, 1, 3, 3])\n", 1873 | "backbone.19.convs.0.bias depthwise_conv2d_17/Bias (48,) torch.Size([48])\n", 1874 | "backbone.19.convs.1.weight conv2d_18/Kernel (48, 1, 1, 48) torch.Size([48, 48, 1, 1])\n", 1875 | "backbone.19.convs.1.bias conv2d_18/Bias (48,) torch.Size([48])\n", 1876 | "backbone.20.convs.0.weight depthwise_conv2d_18/Kernel (1, 3, 3, 48) torch.Size([48, 1, 3, 3])\n", 1877 | "backbone.20.convs.0.bias depthwise_conv2d_18/Bias (48,) torch.Size([48])\n", 1878 | "backbone.20.convs.1.weight conv2d_19/Kernel (48, 1, 1, 48) torch.Size([48, 48, 1, 1])\n", 1879 | "backbone.20.convs.1.bias conv2d_19/Bias (48,) torch.Size([48])\n", 1880 | "backbone.21.convs.0.weight depthwise_conv2d_19/Kernel (1, 3, 3, 48) torch.Size([48, 1, 3, 3])\n", 1881 | "backbone.21.convs.0.bias depthwise_conv2d_19/Bias (48,) torch.Size([48])\n", 1882 | "backbone.21.convs.1.weight conv2d_20/Kernel (48, 1, 1, 48) torch.Size([48, 48, 1, 1])\n", 1883 | "backbone.21.convs.1.bias conv2d_20/Bias (48,) torch.Size([48])\n", 1884 | "backbone.22.convs.0.weight depthwise_conv2d_20/Kernel (1, 3, 3, 48) torch.Size([48, 1, 3, 3])\n", 1885 | "backbone.22.convs.0.bias depthwise_conv2d_20/Bias (48,) torch.Size([48])\n", 1886 | "backbone.22.convs.1.weight conv2d_21/Kernel (48, 1, 1, 48) torch.Size([48, 48, 1, 1])\n", 1887 | "backbone.22.convs.1.bias conv2d_21/Bias (48,) torch.Size([48])\n", 1888 | "backbone.23.convs.0.weight depthwise_conv2d_21/Kernel (1, 3, 3, 48) torch.Size([48, 1, 3, 3])\n", 1889 | "backbone.23.convs.0.bias depthwise_conv2d_21/Bias (48,) torch.Size([48])\n", 1890 | 
"backbone.23.convs.1.weight conv2d_22/Kernel (48, 1, 1, 48) torch.Size([48, 48, 1, 1])\n", 1891 | "backbone.23.convs.1.bias conv2d_22/Bias (48,) torch.Size([48])\n", 1892 | "backbone.24.convs.0.weight depthwise_conv2d_22/Kernel (1, 3, 3, 48) torch.Size([48, 1, 3, 3])\n", 1893 | "backbone.24.convs.0.bias depthwise_conv2d_22/Bias (48,) torch.Size([48])\n", 1894 | "backbone.24.convs.1.weight conv2d_23/Kernel (48, 1, 1, 48) torch.Size([48, 48, 1, 1])\n", 1895 | "backbone.24.convs.1.bias conv2d_23/Bias (48,) torch.Size([48])\n", 1896 | "backbone.25.convs.0.weight depthwise_conv2d_23/Kernel (1, 3, 3, 48) torch.Size([48, 1, 3, 3])\n", 1897 | "backbone.25.convs.0.bias depthwise_conv2d_23/Bias (48,) torch.Size([48])\n", 1898 | "backbone.25.convs.1.weight conv2d_24/Kernel (96, 1, 1, 48) torch.Size([96, 48, 1, 1])\n", 1899 | "backbone.25.convs.1.bias conv2d_24/Bias (96,) torch.Size([96])\n", 1900 | "backbone.26.convs.0.weight depthwise_conv2d_24/Kernel (1, 3, 3, 96) torch.Size([96, 1, 3, 3])\n", 1901 | "backbone.26.convs.0.bias depthwise_conv2d_24/Bias (96,) torch.Size([96])\n", 1902 | "backbone.26.convs.1.weight conv2d_25/Kernel (96, 1, 1, 96) torch.Size([96, 96, 1, 1])\n", 1903 | "backbone.26.convs.1.bias conv2d_25/Bias (96,) torch.Size([96])\n", 1904 | "backbone.27.convs.0.weight depthwise_conv2d_25/Kernel (1, 3, 3, 96) torch.Size([96, 1, 3, 3])\n", 1905 | "backbone.27.convs.0.bias depthwise_conv2d_25/Bias (96,) torch.Size([96])\n", 1906 | "backbone.27.convs.1.weight conv2d_26/Kernel (96, 1, 1, 96) torch.Size([96, 96, 1, 1])\n", 1907 | "backbone.27.convs.1.bias conv2d_26/Bias (96,) torch.Size([96])\n", 1908 | "backbone.28.convs.0.weight depthwise_conv2d_26/Kernel (1, 3, 3, 96) torch.Size([96, 1, 3, 3])\n", 1909 | "backbone.28.convs.0.bias depthwise_conv2d_26/Bias (96,) torch.Size([96])\n", 1910 | "backbone.28.convs.1.weight conv2d_27/Kernel (96, 1, 1, 96) torch.Size([96, 96, 1, 1])\n", 1911 | "backbone.28.convs.1.bias conv2d_27/Bias (96,) torch.Size([96])\n", 1912 | "backbone.29.convs.0.weight depthwise_conv2d_27/Kernel (1, 3, 3, 96) torch.Size([96, 1, 3, 3])\n", 1913 | "backbone.29.convs.0.bias depthwise_conv2d_27/Bias (96,) torch.Size([96])\n", 1914 | "backbone.29.convs.1.weight conv2d_28/Kernel (96, 1, 1, 96) torch.Size([96, 96, 1, 1])\n", 1915 | "backbone.29.convs.1.bias conv2d_28/Bias (96,) torch.Size([96])\n", 1916 | "backbone.30.convs.0.weight depthwise_conv2d_28/Kernel (1, 3, 3, 96) torch.Size([96, 1, 3, 3])\n", 1917 | "backbone.30.convs.0.bias depthwise_conv2d_28/Bias (96,) torch.Size([96])\n", 1918 | "backbone.30.convs.1.weight conv2d_29/Kernel (96, 1, 1, 96) torch.Size([96, 96, 1, 1])\n", 1919 | "backbone.30.convs.1.bias conv2d_29/Bias (96,) torch.Size([96])\n", 1920 | "backbone.31.convs.0.weight depthwise_conv2d_29/Kernel (1, 3, 3, 96) torch.Size([96, 1, 3, 3])\n", 1921 | "backbone.31.convs.0.bias depthwise_conv2d_29/Bias (96,) torch.Size([96])\n", 1922 | "backbone.31.convs.1.weight conv2d_30/Kernel (96, 1, 1, 96) torch.Size([96, 96, 1, 1])\n", 1923 | "backbone.31.convs.1.bias conv2d_30/Bias (96,) torch.Size([96])\n", 1924 | "backbone.32.convs.0.weight depthwise_conv2d_30/Kernel (1, 3, 3, 96) torch.Size([96, 1, 3, 3])\n", 1925 | "backbone.32.convs.0.bias depthwise_conv2d_30/Bias (96,) torch.Size([96])\n", 1926 | "backbone.32.convs.1.weight conv2d_31/Kernel (96, 1, 1, 96) torch.Size([96, 96, 1, 1])\n", 1927 | "backbone.32.convs.1.bias conv2d_31/Bias (96,) torch.Size([96])\n", 1928 | "final.convs.0.weight separable_conv2d__xeno_compat__depthwise/Kernel (1, 3, 3, 96) torch.Size([96, 1, 3, 
3])\n", 1929 | "final.convs.0.bias separable_conv2d__xeno_compat__depthwise/Bias (96,) torch.Size([96])\n", 1930 | "final.convs.1.weight separable_conv2d/Kernel (96, 1, 1, 96) torch.Size([96, 96, 1, 1])\n", 1931 | "final.convs.1.bias separable_conv2d/Bias (96,) torch.Size([96])\n", 1932 | "classifier_8.weight classificator_16/Kernel (2, 1, 1, 96) torch.Size([2, 96, 1, 1])\n", 1933 | "classifier_8.bias classificator_16/Bias (2,) torch.Size([2])\n", 1934 | "classifier_16.weight classificator_32/Kernel (6, 1, 1, 96) torch.Size([6, 96, 1, 1])\n", 1935 | "classifier_16.bias classificator_32/Bias (6,) torch.Size([6])\n", 1936 | "regressor_8.weight regressor_16/Kernel (32, 1, 1, 96) torch.Size([32, 96, 1, 1])\n", 1937 | "regressor_8.bias regressor_16/Bias (32,) torch.Size([32])\n", 1938 | "regressor_16.weight regressor_32/Kernel (96, 1, 1, 96) torch.Size([96, 96, 1, 1])\n", 1939 | "regressor_16.bias regressor_32/Bias (96,) torch.Size([96])\n" 1940 | ] 1941 | } 1942 | ], 1943 | "source": [ 1944 | "back_state_dict = build_state_dict(back_model, back_subgraph, back_tensor_dict, back_net, back_convert)" 1945 | ] 1946 | }, 1947 | { 1948 | "cell_type": "code", 1949 | "execution_count": 30, 1950 | "metadata": {}, 1951 | "outputs": [ 1952 | { 1953 | "data": { 1954 | "text/plain": [ 1955 | "" 1956 | ] 1957 | }, 1958 | "execution_count": 30, 1959 | "metadata": {}, 1960 | "output_type": "execute_result" 1961 | } 1962 | ], 1963 | "source": [ 1964 | "back_net.load_state_dict(back_state_dict, strict=True)" 1965 | ] 1966 | }, 1967 | { 1968 | "cell_type": "code", 1969 | "execution_count": 31, 1970 | "metadata": {}, 1971 | "outputs": [], 1972 | "source": [ 1973 | "torch.save(back_net.state_dict(), \"blazefaceback.pth\")" 1974 | ] 1975 | }, 1976 | { 1977 | "cell_type": "code", 1978 | "execution_count": null, 1979 | "metadata": {}, 1980 | "outputs": [], 1981 | "source": [] 1982 | } 1983 | ], 1984 | "metadata": { 1985 | "kernelspec": { 1986 | "display_name": "Python 3", 1987 | "language": "python", 1988 | "name": "python3" 1989 | }, 1990 | "language_info": { 1991 | "codemirror_mode": { 1992 | "name": "ipython", 1993 | "version": 3 1994 | }, 1995 | "file_extension": ".py", 1996 | "mimetype": "text/x-python", 1997 | "name": "python", 1998 | "nbconvert_exporter": "python", 1999 | "pygments_lexer": "ipython3", 2000 | "version": "3.8.5" 2001 | } 2002 | }, 2003 | "nbformat": 4, 2004 | "nbformat_minor": 2 2005 | } 2006 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Conversion of BlazeFace from TFLite to PyTorch done by Matthijs Hollemans 2 | in December 2019. Website: https://machinethink.net 3 | 4 | This work is licensed under the same terms as MediaPipe (Apache License 2.0) 5 | https://github.com/google/mediapipe/blob/master/LICENSE 6 | -------------------------------------------------------------------------------- /README.markdown: -------------------------------------------------------------------------------- 1 | # BlazeFace in Python 2 | 3 | BlazeFace is a fast, light-weight face detector from Google Research. [Read more](https://sites.google.com/view/perception-cv4arvr/blazeface), [Paper on arXiv](https://arxiv.org/abs/1907.05047) 4 | 5 | A pretrained model is available as part of Google's [MediaPipe](https://github.com/google/mediapipe/blob/master/mediapipe/docs/face_detection_mobile_gpu.md) framework. 
6 | 7 | ![](https://raw.githubusercontent.com/google/mediapipe/master/mediapipe/docs/images/realtime_face_detection.gif) 8 | 9 | Besides a bounding box, BlazeFace also predicts 6 keypoints for face landmarks (2x eyes, 2x ears, nose, mouth). 10 | 11 | Because BlazeFace is designed for use on mobile devices, the pretrained model is in TFLite format. However, I wanted to use it from PyTorch and so I converted it. 12 | 13 | > **NOTE:** The MediaPipe model is slightly different from the model described in the BlazeFace paper. It uses depthwise convolutions with a 3x3 kernel, not 5x5. And it only uses "single" BlazeBlocks, not "double" ones. 14 | 15 | The BlazeFace paper mentions that there are two versions of the model: one for the front-facing camera and one for the back-facing camera. This repo includes both (blazeface.pth for the front camera and blazefaceback.pth for the back camera). The main difference between the two models is the dataset they were trained on. As the paper says, 16 | 17 | > For the frontal camera model, only faces that occupy more than 20% of the image area were considered due to the intended use case (the threshold for the rear-facing camera model was 5%). 18 | 19 | This means the front-camera model will not be able to detect faces that are relatively small. It's really intended for selfies, not for general-purpose face detection. 20 | 21 | ## Inside this repo 22 | 23 | Essential files: 24 | 25 | - **blazeface.py**: defines the `BlazeFace` class that does all the work 26 | 27 | - **blazeface.pth** and **blazefaceback.pth**: the weights for the trained front and back models 28 | 29 | - **anchors.npy** and **anchorsback.npy**: lookup tables with the anchor boxes for each model 30 | 31 | Notebooks: 32 | 33 | - **Anchors.ipynb**: creates the anchor boxes and saves them as binary files (anchors.npy, anchorsback.npy) 34 | 35 | - **Convert.ipynb**: loads the weights from the TFLite models and converts them to PyTorch format (blazeface.pth, blazefaceback.pth) 36 | 37 | - **Inference.ipynb**: shows how to use the `BlazeFace` class to make face detections 38 | 39 | ## Detections 40 | 41 | Each face detection is a PyTorch `Tensor` consisting of 17 numbers (see the usage sketch at the end of this README for how to unpack them): 42 | 43 | - The first 4 numbers describe the bounding box corners: 44 | - `ymin, xmin, ymax, xmax` 45 | - These are normalized coordinates (between 0 and 1). 46 | 47 | - The next 12 numbers are the x,y-coordinates of the 6 facial landmark keypoints: 48 | - `right_eye_x, right_eye_y` 49 | - `left_eye_x, left_eye_y` 50 | - `nose_x, nose_y` 51 | - `mouth_x, mouth_y` 52 | - `right_ear_x, right_ear_y` 53 | - `left_ear_x, left_ear_y` 54 | - Tip: these are labeled as seen from the perspective of the person, so their right is your left. 55 | 56 | - The final number is the confidence score that this detection really is a face. 57 | 58 | ## Image credits 59 | 60 | Included for testing are the following images: 61 | 62 | - **1face.png**. Fei-Fei Li by [ITU Pictures](https://www.flickr.com/photos/itupictures/35011409612/), CC BY 2.0 63 | 64 | - **3faces.png**. Geoffrey Hinton, Yoshua Bengio, Yann LeCun. Found at [AIBuilders](https://aibuilders.ai/le-prix-turing-recompense-trois-pionniers-de-lintelligence-artificielle-yann-lecun-yoshua-bengio-et-geoffrey-hinton/) 65 | 66 | - **4faces.png** from Andrew Ng’s Facebook page / [KDnuggets](https://www.kdnuggets.com/2015/03/talking-machine-deep-learning-gurus-p1.html) 67 | 68 | These images were scaled down to 128x128 pixels, as that is the expected input size of the front-camera model (the back-camera model expects 256x256 input).
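## Usage sketch

The following is an illustrative sketch (not a file in this repo) of how the pieces above fit together for the front-camera model. It assumes Pillow is installed for image loading and uses `1face.png`, which is already 128x128:

```python
import numpy as np
from PIL import Image              # assumption: Pillow is available for loading the test image
from blazeface import BlazeFace

# Load the front-camera model, its weights, and its anchor boxes.
# (For the back-camera model, use BlazeFace(back_model=True) with
# "blazefaceback.pth", "anchorsback.npy", and 256x256 input images.)
net = BlazeFace()
net.load_weights("blazeface.pth")
net.load_anchors("anchors.npy")

# The front model expects a 128x128 RGB image as an (H, W, 3) array.
img = np.asarray(Image.open("1face.png").convert("RGB"))

detections = net.predict_on_image(img)    # tensor of shape (num_faces, 17)

h, w, _ = img.shape
for det in detections:
    det = det.cpu().numpy()
    ymin, xmin, ymax, xmax = det[:4]      # normalized box corners
    keypoints = det[4:16].reshape(6, 2)   # six (x, y) pairs: eyes, nose, mouth, ears
    nose_x, nose_y = keypoints[2]         # keypoint order as listed under "Detections"
    score = det[16]                       # confidence that this detection is a face
    print(f"face {score:.2f}: box=({xmin * w:.1f}, {ymin * h:.1f})-({xmax * w:.1f}, {ymax * h:.1f}), "
          f"nose=({nose_x * w:.1f}, {nose_y * h:.1f})")
```

`predict_on_batch` works the same way but takes a batch of images and returns one such detection tensor per image.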
69 | -------------------------------------------------------------------------------- /anchors.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hollance/BlazeFace-PyTorch/852bfd8e3d44ed6775761105bdcead4ef389a538/anchors.npy -------------------------------------------------------------------------------- /anchorsback.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hollance/BlazeFace-PyTorch/852bfd8e3d44ed6775761105bdcead4ef389a538/anchorsback.npy -------------------------------------------------------------------------------- /blazeface.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hollance/BlazeFace-PyTorch/852bfd8e3d44ed6775761105bdcead4ef389a538/blazeface.pth -------------------------------------------------------------------------------- /blazeface.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BlazeBlock(nn.Module): 8 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1): 9 | super(BlazeBlock, self).__init__() 10 | 11 | self.stride = stride 12 | self.channel_pad = out_channels - in_channels 13 | 14 | # TFLite uses slightly different padding than PyTorch 15 | # on the depthwise conv layer when the stride is 2. 16 | if stride == 2: 17 | self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) 18 | padding = 0 19 | else: 20 | padding = (kernel_size - 1) // 2 21 | 22 | self.convs = nn.Sequential( 23 | nn.Conv2d(in_channels=in_channels, out_channels=in_channels, 24 | kernel_size=kernel_size, stride=stride, padding=padding, 25 | groups=in_channels, bias=True), 26 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 27 | kernel_size=1, stride=1, padding=0, bias=True), 28 | ) 29 | 30 | self.act = nn.ReLU(inplace=True) 31 | 32 | def forward(self, x): 33 | if self.stride == 2: 34 | h = F.pad(x, (0, 2, 0, 2), "constant", 0) 35 | x = self.max_pool(x) 36 | else: 37 | h = x 38 | 39 | if self.channel_pad > 0: 40 | x = F.pad(x, (0, 0, 0, 0, 0, self.channel_pad), "constant", 0) 41 | 42 | return self.act(self.convs(h) + x) 43 | 44 | class FinalBlazeBlock(nn.Module): 45 | def __init__(self, channels, kernel_size=3): 46 | super(FinalBlazeBlock, self).__init__() 47 | # TFLite uses slightly different padding than PyTorch 48 | # on the depthwise conv layer when the stride is 2. 49 | self.convs = nn.Sequential( 50 | nn.Conv2d(in_channels=channels, out_channels=channels, 51 | kernel_size=kernel_size, stride=2, padding=0, 52 | groups=channels, bias=True), 53 | nn.Conv2d(in_channels=channels, out_channels=channels, 54 | kernel_size=1, stride=1, padding=0, bias=True), 55 | ) 56 | 57 | self.act = nn.ReLU(inplace=True) 58 | 59 | def forward(self, x): 60 | h = F.pad(x, (0, 2, 0, 2), "constant", 0) 61 | 62 | return self.act(self.convs(h)) 63 | 64 | 65 | class BlazeFace(nn.Module): 66 | """The BlazeFace face detection model from MediaPipe. 67 | 68 | The version from MediaPipe is simpler than the one in the paper; 69 | it does not use the "double" BlazeBlocks. 70 | 71 | Because we won't be training this model, it doesn't need to have 72 | batchnorm layers. These have already been "folded" into the conv 73 | weights by TFLite. 
74 | 75 | The conversion to PyTorch is fairly straightforward, but there are 76 | some small differences between TFLite and PyTorch in how they handle 77 | padding on conv layers with stride 2. 78 | 79 | This version works on batches, while the MediaPipe version can only 80 | handle a single image at a time. 81 | 82 | Based on code from https://github.com/tkat0/PyTorch_BlazeFace/ and 83 | https://github.com/google/mediapipe/ 84 | """ 85 | def __init__(self, back_model=False): 86 | super(BlazeFace, self).__init__() 87 | 88 | # These are the settings from the MediaPipe example graphs 89 | # mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt 90 | # and mediapipe/graphs/face_detection/face_detection_back_mobile_gpu.pbtxt 91 | self.num_classes = 1 92 | self.num_anchors = 896 93 | self.num_coords = 16 94 | self.score_clipping_thresh = 100.0 95 | self.back_model = back_model 96 | if back_model: 97 | self.x_scale = 256.0 98 | self.y_scale = 256.0 99 | self.h_scale = 256.0 100 | self.w_scale = 256.0 101 | self.min_score_thresh = 0.65 102 | else: 103 | self.x_scale = 128.0 104 | self.y_scale = 128.0 105 | self.h_scale = 128.0 106 | self.w_scale = 128.0 107 | self.min_score_thresh = 0.75 108 | self.min_suppression_threshold = 0.3 109 | 110 | self._define_layers() 111 | 112 | def _define_layers(self): 113 | if self.back_model: 114 | self.backbone = nn.Sequential( 115 | nn.Conv2d(in_channels=3, out_channels=24, kernel_size=5, stride=2, padding=0, bias=True), 116 | nn.ReLU(inplace=True), 117 | 118 | BlazeBlock(24, 24), 119 | BlazeBlock(24, 24), 120 | BlazeBlock(24, 24), 121 | BlazeBlock(24, 24), 122 | BlazeBlock(24, 24), 123 | BlazeBlock(24, 24), 124 | BlazeBlock(24, 24), 125 | BlazeBlock(24, 24, stride=2), 126 | BlazeBlock(24, 24), 127 | BlazeBlock(24, 24), 128 | BlazeBlock(24, 24), 129 | BlazeBlock(24, 24), 130 | BlazeBlock(24, 24), 131 | BlazeBlock(24, 24), 132 | BlazeBlock(24, 24), 133 | BlazeBlock(24, 48, stride=2), 134 | BlazeBlock(48, 48), 135 | BlazeBlock(48, 48), 136 | BlazeBlock(48, 48), 137 | BlazeBlock(48, 48), 138 | BlazeBlock(48, 48), 139 | BlazeBlock(48, 48), 140 | BlazeBlock(48, 48), 141 | BlazeBlock(48, 96, stride=2), 142 | BlazeBlock(96, 96), 143 | BlazeBlock(96, 96), 144 | BlazeBlock(96, 96), 145 | BlazeBlock(96, 96), 146 | BlazeBlock(96, 96), 147 | BlazeBlock(96, 96), 148 | BlazeBlock(96, 96), 149 | ) 150 | self.final = FinalBlazeBlock(96) 151 | self.classifier_8 = nn.Conv2d(96, 2, 1, bias=True) 152 | self.classifier_16 = nn.Conv2d(96, 6, 1, bias=True) 153 | 154 | self.regressor_8 = nn.Conv2d(96, 32, 1, bias=True) 155 | self.regressor_16 = nn.Conv2d(96, 96, 1, bias=True) 156 | else: 157 | self.backbone1 = nn.Sequential( 158 | nn.Conv2d(in_channels=3, out_channels=24, kernel_size=5, stride=2, padding=0, bias=True), 159 | nn.ReLU(inplace=True), 160 | 161 | BlazeBlock(24, 24), 162 | BlazeBlock(24, 28), 163 | BlazeBlock(28, 32, stride=2), 164 | BlazeBlock(32, 36), 165 | BlazeBlock(36, 42), 166 | BlazeBlock(42, 48, stride=2), 167 | BlazeBlock(48, 56), 168 | BlazeBlock(56, 64), 169 | BlazeBlock(64, 72), 170 | BlazeBlock(72, 80), 171 | BlazeBlock(80, 88), 172 | ) 173 | 174 | self.backbone2 = nn.Sequential( 175 | BlazeBlock(88, 96, stride=2), 176 | BlazeBlock(96, 96), 177 | BlazeBlock(96, 96), 178 | BlazeBlock(96, 96), 179 | BlazeBlock(96, 96), 180 | ) 181 | self.classifier_8 = nn.Conv2d(88, 2, 1, bias=True) 182 | self.classifier_16 = nn.Conv2d(96, 6, 1, bias=True) 183 | 184 | self.regressor_8 = nn.Conv2d(88, 32, 1, bias=True) 185 | self.regressor_16 = nn.Conv2d(96, 96, 1, 
bias=True) 186 | 187 | def forward(self, x): 188 | # TFLite uses slightly different padding on the first conv layer 189 | # than PyTorch, so do it manually. 190 | x = F.pad(x, (1, 2, 1, 2), "constant", 0) 191 | 192 | b = x.shape[0] # batch size, needed for reshaping later 193 | 194 | if self.back_model: 195 | x = self.backbone(x) # (b, 16, 16, 96) 196 | h = self.final(x) # (b, 8, 8, 96) 197 | else: 198 | x = self.backbone1(x) # (b, 88, 16, 16) 199 | h = self.backbone2(x) # (b, 96, 8, 8) 200 | 201 | # Note: Because PyTorch is NCHW but TFLite is NHWC, we need to 202 | # permute the output from the conv layers before reshaping it. 203 | 204 | c1 = self.classifier_8(x) # (b, 2, 16, 16) 205 | c1 = c1.permute(0, 2, 3, 1) # (b, 16, 16, 2) 206 | c1 = c1.reshape(b, -1, 1) # (b, 512, 1) 207 | 208 | c2 = self.classifier_16(h) # (b, 6, 8, 8) 209 | c2 = c2.permute(0, 2, 3, 1) # (b, 8, 8, 6) 210 | c2 = c2.reshape(b, -1, 1) # (b, 384, 1) 211 | 212 | c = torch.cat((c1, c2), dim=1) # (b, 896, 1) 213 | 214 | r1 = self.regressor_8(x) # (b, 32, 16, 16) 215 | r1 = r1.permute(0, 2, 3, 1) # (b, 16, 16, 32) 216 | r1 = r1.reshape(b, -1, 16) # (b, 512, 16) 217 | 218 | r2 = self.regressor_16(h) # (b, 96, 8, 8) 219 | r2 = r2.permute(0, 2, 3, 1) # (b, 8, 8, 96) 220 | r2 = r2.reshape(b, -1, 16) # (b, 384, 16) 221 | 222 | r = torch.cat((r1, r2), dim=1) # (b, 896, 16) 223 | return [r, c] 224 | 225 | def _device(self): 226 | """Which device (CPU or GPU) is being used by this model?""" 227 | return self.classifier_8.weight.device 228 | 229 | def load_weights(self, path): 230 | self.load_state_dict(torch.load(path)) 231 | self.eval() 232 | 233 | def load_anchors(self, path): 234 | self.anchors = torch.tensor(np.load(path), dtype=torch.float32, device=self._device()) 235 | assert(self.anchors.ndimension() == 2) 236 | assert(self.anchors.shape[0] == self.num_anchors) 237 | assert(self.anchors.shape[1] == 4) 238 | 239 | def _preprocess(self, x): 240 | """Converts the image pixels to the range [-1, 1].""" 241 | return x.float() / 127.5 - 1.0 242 | 243 | def predict_on_image(self, img): 244 | """Makes a prediction on a single image. 245 | 246 | Arguments: 247 | img: a NumPy array of shape (H, W, 3) or a PyTorch tensor of 248 | shape (3, H, W). The image's height and width should be 249 | 128 pixels (256 for the back model). 250 | 251 | Returns: 252 | A tensor with face detections; see predict_on_batch for the format of each detection. 253 | """ 254 | if isinstance(img, np.ndarray): 255 | img = torch.from_numpy(img).permute((2, 0, 1)) 256 | 257 | return self.predict_on_batch(img.unsqueeze(0))[0] 258 | 259 | def predict_on_batch(self, x): 260 | """Makes a prediction on a batch of images. 261 | 262 | Arguments: 263 | x: a NumPy array of shape (b, H, W, 3) or a PyTorch tensor of 264 | shape (b, 3, H, W). The height and width should be 128 pixels (256 for the back model). 265 | 266 | Returns: 267 | A list containing a tensor of face detections for each image in 268 | the batch. If no faces are found for an image, returns a tensor 269 | of shape (0, 17). 270 | 271 | Each face detection is a PyTorch tensor consisting of 17 numbers: 272 | - ymin, xmin, ymax, xmax 273 | - x,y-coordinates for the 6 keypoints 274 | - confidence score 275 | """ 276 | if isinstance(x, np.ndarray): 277 | x = torch.from_numpy(x).permute((0, 3, 1, 2)) 278 | 279 | assert x.shape[1] == 3 280 | if self.back_model: 281 | assert x.shape[2] == 256 282 | assert x.shape[3] == 256 283 | else: 284 | assert x.shape[2] == 128 285 | assert x.shape[3] == 128 286 | 287 | # 1.
Preprocess the images into tensors: 288 | x = x.to(self._device()) 289 | x = self._preprocess(x) 290 | 291 | # 2. Run the neural network: 292 | with torch.no_grad(): 293 | out = self.__call__(x) 294 | 295 | # 3. Postprocess the raw predictions: 296 | detections = self._tensors_to_detections(out[0], out[1], self.anchors) 297 | 298 | # 4. Non-maximum suppression to remove overlapping detections: 299 | filtered_detections = [] 300 | for i in range(len(detections)): 301 | faces = self._weighted_non_max_suppression(detections[i]) 302 | faces = torch.stack(faces) if len(faces) > 0 else torch.zeros((0, 17)) 303 | filtered_detections.append(faces) 304 | 305 | return filtered_detections 306 | 307 | def _tensors_to_detections(self, raw_box_tensor, raw_score_tensor, anchors): 308 | """The output of the neural network is a tensor of shape (b, 896, 16) 309 | containing the bounding box regressor predictions, as well as a tensor 310 | of shape (b, 896, 1) with the classification confidences. 311 | 312 | This function converts these two "raw" tensors into proper detections. 313 | Returns a list of (num_detections, 17) tensors, one for each image in 314 | the batch. 315 | 316 | This is based on the source code from: 317 | mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc 318 | mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.proto 319 | """ 320 | assert raw_box_tensor.ndimension() == 3 321 | assert raw_box_tensor.shape[1] == self.num_anchors 322 | assert raw_box_tensor.shape[2] == self.num_coords 323 | 324 | assert raw_score_tensor.ndimension() == 3 325 | assert raw_score_tensor.shape[1] == self.num_anchors 326 | assert raw_score_tensor.shape[2] == self.num_classes 327 | 328 | assert raw_box_tensor.shape[0] == raw_score_tensor.shape[0] 329 | 330 | detection_boxes = self._decode_boxes(raw_box_tensor, anchors) 331 | 332 | thresh = self.score_clipping_thresh 333 | raw_score_tensor = raw_score_tensor.clamp(-thresh, thresh) 334 | detection_scores = raw_score_tensor.sigmoid().squeeze(dim=-1) 335 | 336 | # Note: we stripped off the last dimension from the scores tensor 337 | # because there is only one class. Now we can simply use a mask 338 | # to filter out the boxes with too low confidence. 339 | mask = detection_scores >= self.min_score_thresh 340 | 341 | # Because each image from the batch can have a different number of 342 | # detections, process them one at a time using a loop. 343 | output_detections = [] 344 | for i in range(raw_box_tensor.shape[0]): 345 | boxes = detection_boxes[i, mask[i]] 346 | scores = detection_scores[i, mask[i]].unsqueeze(dim=-1) 347 | output_detections.append(torch.cat((boxes, scores), dim=-1)) 348 | 349 | return output_detections 350 | 351 | def _decode_boxes(self, raw_boxes, anchors): 352 | """Converts the predictions into actual coordinates using 353 | the anchor boxes. Processes the entire batch at once. 354 | """ 355 | boxes = torch.zeros_like(raw_boxes) 356 | 357 | x_center = raw_boxes[..., 0] / self.x_scale * anchors[:, 2] + anchors[:, 0] 358 | y_center = raw_boxes[..., 1] / self.y_scale * anchors[:, 3] + anchors[:, 1] 359 | 360 | w = raw_boxes[..., 2] / self.w_scale * anchors[:, 2] 361 | h = raw_boxes[..., 3] / self.h_scale * anchors[:, 3] 362 | 363 | boxes[..., 0] = y_center - h / 2. # ymin 364 | boxes[..., 1] = x_center - w / 2. # xmin 365 | boxes[..., 2] = y_center + h / 2. # ymax 366 | boxes[..., 3] = x_center + w / 2.
# xmax 367 | 368 | for k in range(6): 369 | offset = 4 + k*2 370 | keypoint_x = raw_boxes[..., offset ] / self.x_scale * anchors[:, 2] + anchors[:, 0] 371 | keypoint_y = raw_boxes[..., offset + 1] / self.y_scale * anchors[:, 3] + anchors[:, 1] 372 | boxes[..., offset ] = keypoint_x 373 | boxes[..., offset + 1] = keypoint_y 374 | 375 | return boxes 376 | 377 | def _weighted_non_max_suppression(self, detections): 378 | """The alternative NMS method as mentioned in the BlazeFace paper: 379 | 380 | "We replace the suppression algorithm with a blending strategy that 381 | estimates the regression parameters of a bounding box as a weighted 382 | mean between the overlapping predictions." 383 | 384 | The original MediaPipe code assigns the score of the most confident 385 | detection to the weighted detection, but we take the average score 386 | of the overlapping detections. 387 | 388 | The input detections should be a Tensor of shape (count, 17). 389 | 390 | Returns a list of PyTorch tensors, one for each detected face. 391 | 392 | This is based on the source code from: 393 | mediapipe/calculators/util/non_max_suppression_calculator.cc 394 | mediapipe/calculators/util/non_max_suppression_calculator.proto 395 | """ 396 | if len(detections) == 0: return [] 397 | 398 | output_detections = [] 399 | 400 | # Sort the detections from highest to lowest score. 401 | remaining = torch.argsort(detections[:, 16], descending=True) 402 | 403 | while len(remaining) > 0: 404 | detection = detections[remaining[0]] 405 | 406 | # Compute the overlap between the first box and the other 407 | # remaining boxes. (Note that the other_boxes also include 408 | # the first_box.) 409 | first_box = detection[:4] 410 | other_boxes = detections[remaining, :4] 411 | ious = overlap_similarity(first_box, other_boxes) 412 | 413 | # If two detections don't overlap enough, they are considered 414 | # to be from different faces. 415 | mask = ious > self.min_suppression_threshold 416 | overlapping = remaining[mask] 417 | remaining = remaining[~mask] 418 | 419 | # Take an average of the coordinates from the overlapping 420 | # detections, weighted by their confidence scores. 421 | weighted_detection = detection.clone() 422 | if len(overlapping) > 1: 423 | coordinates = detections[overlapping, :16] 424 | scores = detections[overlapping, 16:17] 425 | total_score = scores.sum() 426 | weighted = (coordinates * scores).sum(dim=0) / total_score 427 | weighted_detection[:16] = weighted 428 | weighted_detection[16] = total_score / len(overlapping) 429 | 430 | output_detections.append(weighted_detection) 431 | 432 | return output_detections 433 | 434 | 435 | # IOU code from https://github.com/amdegroot/ssd.pytorch/blob/master/layers/box_utils.py 436 | 437 | def intersect(box_a, box_b): 438 | """ We resize both tensors to [A,B,2] without new malloc: 439 | [A,2] -> [A,1,2] -> [A,B,2] 440 | [B,2] -> [1,B,2] -> [A,B,2] 441 | Then we compute the area of intersect between box_a and box_b. 442 | Args: 443 | box_a: (tensor) bounding boxes, Shape: [A,4]. 444 | box_b: (tensor) bounding boxes, Shape: [B,4]. 445 | Return: 446 | (tensor) intersection area, Shape: [A,B]. 
447 | """ 448 | A = box_a.size(0) 449 | B = box_b.size(0) 450 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), 451 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) 452 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), 453 | box_b[:, :2].unsqueeze(0).expand(A, B, 2)) 454 | inter = torch.clamp((max_xy - min_xy), min=0) 455 | return inter[:, :, 0] * inter[:, :, 1] 456 | 457 | 458 | def jaccard(box_a, box_b): 459 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 460 | is simply the intersection over union of two boxes. Here we operate on 461 | ground truth boxes and default boxes. 462 | E.g.: 463 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 464 | Args: 465 | box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] 466 | box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] 467 | Return: 468 | jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] 469 | """ 470 | inter = intersect(box_a, box_b) 471 | area_a = ((box_a[:, 2]-box_a[:, 0]) * 472 | (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] 473 | area_b = ((box_b[:, 2]-box_b[:, 0]) * 474 | (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] 475 | union = area_a + area_b - inter 476 | return inter / union # [A,B] 477 | 478 | 479 | def overlap_similarity(box, other_boxes): 480 | """Computes the IOU between a bounding box and set of other boxes.""" 481 | return jaccard(box.unsqueeze(0), other_boxes).squeeze(0) 482 | -------------------------------------------------------------------------------- /blazefaceback.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hollance/BlazeFace-PyTorch/852bfd8e3d44ed6775761105bdcead4ef389a538/blazefaceback.pth --------------------------------------------------------------------------------
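As a quick smoke test of blazeface.py, the sketch below shows one way to instantiate the model, load the checkpoint and anchor files shipped in this repository (blazeface.pth / anchors.npy for the 128×128 front-camera model, blazefaceback.pth / anchorsback.npy for the 256×256 back-camera model), and call predict_on_image. This is a minimal sketch, assuming it is run from the repository root; the random input array is only a placeholder for a real RGB image resized to the expected resolution.

```
# Minimal usage sketch (assumption: blazeface.pth and anchors.npy are in the
# current working directory).
import numpy as np
from blazeface import BlazeFace

# Front-camera model: expects 128x128 RGB input with pixel values in [0, 255];
# _preprocess() rescales them to [-1, 1].
net = BlazeFace()
net.load_weights("blazeface.pth")
net.load_anchors("anchors.npy")

# Placeholder image; replace with a real (H, W, 3) uint8 array resized to 128x128.
img = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)

detections = net.predict_on_image(img)
print(detections.shape)  # (num_faces, 17), or (0, 17) if nothing was found

# Back-camera model: same API, but 256x256 input and the "back" checkpoint/anchors.
back_net = BlazeFace(back_model=True)
back_net.load_weights("blazefaceback.pth")
back_net.load_anchors("anchorsback.npy")
```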
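Each row of the returned detection tensor packs 17 numbers: the ymin, xmin, ymax, xmax box corners, six (x, y) keypoints, and the confidence score (see the predict_on_batch docstring). The plotting helper below is a sketch that assumes the decoded coordinates are normalized to the input image (the raw offsets are divided by x_scale/y_scale in _decode_boxes and combined with the anchor grid), so it scales them back to pixels before drawing; the function name draw_detections and the use of matplotlib are illustrative, not part of this repository.

```
# Hypothetical helper for visualizing a (num_faces, 17) detection tensor.
import matplotlib.pyplot as plt
import matplotlib.patches as patches


def draw_detections(img, detections):
    """Draw boxes and keypoints on `img`, assuming normalized coordinates."""
    h, w = img.shape[:2]
    fig, ax = plt.subplots()
    ax.imshow(img)

    for det in detections.cpu().numpy():
        ymin, xmin, ymax, xmax = det[:4]
        ax.add_patch(patches.Rectangle(
            (xmin * w, ymin * h),                  # top-left corner in pixels
            (xmax - xmin) * w, (ymax - ymin) * h,  # width and height in pixels
            fill=False, linewidth=1))

        # The six keypoints follow the box as (x, y) pairs: indices 4..15.
        for k in range(6):
            ax.plot(det[4 + k * 2] * w, det[4 + k * 2 + 1] * h,
                    marker="o", markersize=2)

    plt.show()
```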