├── .gitignore ├── LICENSE_mediapipe ├── README.md ├── blendshape_info.py ├── face_landmarker_v2_with_blendshapes.task ├── image.jpg ├── mediapipe_blendshapes_model_to_pytorch.py ├── mlp_mixer.py ├── requirements.txt └── test_converted_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | # MediaPipe Face Mesh files 2 | face_landmarker_v2_with_blendshapes.task 3 | face_blendshapes.pth 4 | face_blendshapes.tflite 5 | face_detector.tflite 6 | face_landmarks_detector.tflite 7 | geometry_pipeline_metadata_landmarks.binarypb 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | 139 | # meta data in macOS 140 | .DS_Store 141 | -------------------------------------------------------------------------------- /LICENSE_mediapipe: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 
180 |    To apply the Apache License to your work, attach the following
181 |    boilerplate notice, with the fields enclosed by brackets "[]"
182 |    replaced with your own identifying information. (Don't include
183 |    the brackets!) The text should be enclosed in the appropriate
184 |    comment syntax for the file format. We also recommend that a
185 |    file or class name and description of purpose be included on the
186 |    same "printed page" as the copyright notice for easier
187 |    identification within third-party archives.
188 | 
189 |    Copyright [yyyy] [name of copyright owner]
190 | 
191 |    Licensed under the Apache License, Version 2.0 (the "License");
192 |    you may not use this file except in compliance with the License.
193 |    You may obtain a copy of the License at
194 | 
195 |        http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 |    Unless required by applicable law or agreed to in writing, software
198 |    distributed under the License is distributed on an "AS IS" BASIS,
199 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 |    See the License for the specific language governing permissions and
201 |    limitations under the License.
202 | 
203 | ===========================================================================
204 | For files under tasks/cc/text/language_detector/custom_ops/utils/utf/
205 | ===========================================================================
206 | /*
207 |  * The authors of this software are Rob Pike and Ken Thompson.
208 |  * Copyright (c) 2002 by Lucent Technologies.
209 |  * Permission to use, copy, modify, and distribute this software for any
210 |  * purpose without fee is hereby granted, provided that this entire notice
211 |  * is included in all copies of any software which is or includes a copy
212 |  * or modification of this software and in all copies of the supporting
213 |  * documentation for such software.
214 |  * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
215 |  * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
216 |  * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
217 |  * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
218 |  */
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # `convert-mediapipe-blendshapes-model-to-pytorch`
2 | 
3 | This repository contains the code for converting the blendshapes component of MediaPipe's face mesh model to PyTorch.
4 | 
5 | ## Converting the model
6 | 
7 | Run the following commands:
8 | 
9 | ```shell
10 | conda create -y -n deconstruct-mediapipe python=3.9
11 | conda activate deconstruct-mediapipe
12 | pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cpu
13 | pip install -r requirements.txt
14 | python mediapipe_blendshapes_model_to_pytorch.py
15 | ```
16 | 
17 | ## Checking the converted model
18 | 
19 | If you want to check the converted model, you need to `pip install mediapipe`. Then run `python test_converted_model.py`.
20 | 
21 | The output should look as follows, showing that the PyTorch model gives the same results (first 12 blendshapes shown) as both MediaPipe itself and the raw TFLite model.
22 | 
23 | ```shell
24 | --------------------------------- Face 1 --------------------------------
25 | Blendshapes from MediaPipe:
26 | [0. 0.242 0.217 0.001 0.018 0.014 0. 0. 0. 0.098 0.048 0.017]
27 | Blendshapes from PyTorch:
28 | [0. 0.242 0.217 0.001 0.018 0.014 0. 0. 0. 0.098 0.048 0.017]
29 | Blendshapes from TFLite:
30 | [0. 0.242 0.217 0.001 0.018 0.014 0. 0. 0. 0.098 0.048 0.017]
31 | --------------------------------- Face 2 --------------------------------
32 | Blendshapes from MediaPipe:
33 | [0. 0.085 0.169 0.008 0.032 0.013 0. 0. 0. 0.063 0.092 0.216]
34 | Blendshapes from PyTorch:
35 | [0. 0.085 0.169 0.008 0.032 0.013 0. 0. 0. 0.063 0.092 0.216]
36 | Blendshapes from TFLite:
37 | [0. 0.085 0.169 0.008 0.032 0.013 0. 0. 0. 0.063 0.092 0.216]
38 | ```
39 | 
--------------------------------------------------------------------------------
/blendshape_info.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | def parse_prototxt(path):
4 |     with open(path, 'r') as f:
5 |         d = f.readlines()
6 |     out = []
7 |     out_all = []
8 |     out_blendshapes = [None] * 52
9 |     for line in d:
10 |         if line.strip().startswith('x: ') or line.strip().startswith('y: ') or line.strip().startswith('z: '):
11 |             out.append(float(line.strip().split(': ')[1]))
12 |         if line.strip().startswith('z: '):
13 |             out_all.append(list(out))
14 |             out = []
15 |         if line.strip().startswith('index: '):
16 |             curr_index = int(line.strip().split(': ')[1])
17 |         if line.strip().startswith('score: '):
18 |             out_blendshapes[curr_index] = float(line.strip().split(': ')[1])
19 |     if out_blendshapes[0] is not None:
20 |         out_all = out_blendshapes
21 |     return np.array(out_all).astype('float32')
22 | 
23 | 
24 | BLENDSHAPE_MODEL_LANDMARKS_SUBSET = np.array([0, 1, 4, 5, 6, 7, 8, 10, 13, 14, 17, 21, 33, 37, 39,
25 |                                               40, 46, 52, 53, 54, 55, 58, 61, 63, 65, 66, 67, 70, 78, 80,
26 |                                               81, 82, 84, 87, 88, 91, 93, 95, 103, 105, 107, 109, 127, 132, 133,
27 |                                               136, 144, 145, 146, 148, 149, 150, 152, 153, 154, 155, 157, 158, 159, 160,
28 |                                               161, 162, 163, 168, 172, 173, 176, 178, 181, 185, 191, 195, 197, 234, 246,
29 |                                               249, 251, 263, 267, 269, 270, 276, 282, 283, 284, 285, 288, 291, 293, 295,
30 |                                               296, 297, 300, 308, 310, 311, 312, 314, 317, 318, 321, 323, 324, 332, 334,
31 |                                               336, 338, 356, 361, 362, 365, 373, 374, 375, 377, 378, 379, 380, 381, 382,
32 |                                               384, 385, 386, 387, 388, 389, 390, 397, 398, 400, 402, 405, 409, 415, 454,
33 |                                               466, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477])
34 | 
35 | BLENDSHAPE_NAMES = np.array([
36 |     "_neutral",
37 |     "browDownLeft",
38 |     "browDownRight",
39 |     "browInnerUp",
40 |     "browOuterUpLeft",
41 |     "browOuterUpRight",
42 |     "cheekPuff",
43 |     "cheekSquintLeft",
44 |     "cheekSquintRight",
45 |     "eyeBlinkLeft",
46 |     "eyeBlinkRight",
47 |     "eyeLookDownLeft",
48 |     "eyeLookDownRight",
49 |     "eyeLookInLeft",
50 |     "eyeLookInRight",
51 |     "eyeLookOutLeft",
52 |     "eyeLookOutRight",
53 |     "eyeLookUpLeft",
54 |     "eyeLookUpRight",
55 |     "eyeSquintLeft",
56 |     "eyeSquintRight",
57 |     "eyeWideLeft",
58 |     "eyeWideRight",
59 |     "jawForward",
60 |     "jawLeft",
61 |     "jawOpen",
62 |     "jawRight",
63 |     "mouthClose",
64 |     "mouthDimpleLeft",
65 |     "mouthDimpleRight",
66 |     "mouthFrownLeft",
67 |     "mouthFrownRight",
68 |     "mouthFunnel",
69 |     "mouthLeft",
70 |     "mouthLowerDownLeft",
71 |     "mouthLowerDownRight",
72 |     "mouthPressLeft",
73 |     "mouthPressRight",
74 |     "mouthPucker",
75 |     "mouthRight",
76 |     "mouthRollLower",
77 |     "mouthRollUpper",
78 |     "mouthShrugLower",
79 |     "mouthShrugUpper",
80 |     "mouthSmileLeft",
81 |     "mouthSmileRight",
82 |     "mouthStretchLeft",
83 |     "mouthStretchRight",
84 |     "mouthUpperUpLeft",
85 |     "mouthUpperUpRight",
86 |     "noseSneerLeft",
87 |     "noseSneerRight"])
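A minimal sketch of how these two constants pair with the converted model: `BLENDSHAPE_MODEL_LANDMARKS_SUBSET` selects the 146 input landmarks out of MediaPipe's 478, and `BLENDSHAPE_NAMES` labels the 52 output scores. The random `landmarks` array below is a stand-in for real MediaPipe output, and `face_blendshapes.pth` is assumed to already exist from running the converter script.

```python
import numpy as np
import torch

from blendshape_info import BLENDSHAPE_MODEL_LANDMARKS_SUBSET, BLENDSHAPE_NAMES
from mlp_mixer import MediaPipeBlendshapesMLPMixer

# Stand-in for one face's 478 landmarks from MediaPipe, already multiplied
# by (image_width, image_height) on x and y as the model expects.
landmarks = np.random.rand(478, 3).astype("float32")

model = MediaPipeBlendshapesMLPMixer()
model.load_state_dict(torch.load("face_blendshapes.pth"))  # produced by the converter
model.eval()

subset = landmarks[BLENDSHAPE_MODEL_LANDMARKS_SUBSET, :2]  # (146, 2): keep only x, y
with torch.no_grad():
    scores = model(torch.from_numpy(subset[None]))  # (1, 52, 1, 1)

# Pair each of the 52 scores with its blendshape name.
for name, score in zip(BLENDSHAPE_NAMES, scores.squeeze().tolist()):
    print(f"{name}: {score:.3f}")
```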
--------------------------------------------------------------------------------
/face_landmarker_v2_with_blendshapes.task:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nlml/deconstruct-mediapipe/c357d5eeeecd79e58277923e7a081c1061590e0c/face_landmarker_v2_with_blendshapes.task
--------------------------------------------------------------------------------
/image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nlml/deconstruct-mediapipe/c357d5eeeecd79e58277923e7a081c1061590e0c/image.jpg
--------------------------------------------------------------------------------
/mediapipe_blendshapes_model_to_pytorch.py:
--------------------------------------------------------------------------------
1 | """
2 | This script converts the MediaPipe blendshapes model to PyTorch.
3 | Usage:
4 |     python mediapipe_blendshapes_model_to_pytorch.py --tflite_path ./face_blendshapes.tflite --output_path ./face_blendshapes.pth
5 | 
6 | See README.md for more details.
7 | """
8 | 
9 | import os
10 | 
11 | os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
12 | import tf2onnx
13 | import requests
14 | import zipfile
15 | import argparse
16 | import numpy as np
17 | import onnx
18 | import torch
19 | from onnx.numpy_helper import to_array as _to_array
20 | 
21 | from mlp_mixer import MediaPipeBlendshapesMLPMixer
22 | 
23 | 
24 | def to_array(tensor):
25 |     # Override to_array to suppress PyTorch warning about non-writable tensor
26 |     return np.copy(_to_array(tensor))
27 | 
28 | 
29 | def get_model_weight(onnx_model, name):
30 |     for init in onnx_model.graph.initializer:
31 |         if init.name == name:
32 |             return init
33 |     raise KeyError(f"No initializer named {name!r} in the ONNX graph")
34 | 
35 | 
36 | def conv_node_to_w_b(onnx_model, node):
37 |     w_node = node.input[1]
38 |     w = to_array(get_model_weight(onnx_model, w_node))
39 |     assert len(w.shape) == 4, w.shape
40 |     b_node = node.input[2]
41 |     b = to_array(get_model_weight(onnx_model, b_node))
42 |     assert len(b.shape) == 1, b.shape
43 |     w, b = [torch.from_numpy(_).float() for _ in [w, b]]
44 |     return w, b
45 | 
46 | 
47 | def get_node_weight_by_output_name(onnx_model, search_str, input_idx):
48 |     fold_op_name = [n for n in onnx_model.graph.node if search_str == n.output[0]]
49 |     fold_op_name = fold_op_name[0].input[input_idx]
50 |     return get_model_weight(onnx_model, fold_op_name)
51 | 
52 | 
53 | def get_layernorm_weight(onnx_model, mixer_block_idx, norm_idx):
54 |     search_str = f"model_1/GhumMarkerPoserMlpMixerGeneral/MLPMixer/MixerBlock_{mixer_block_idx}/layer_norm{norm_idx}/batchnorm/mul"
55 |     # In the exported graph each LayerNorm appears as batchnorm-style ops; the
56 |     # scale vector is the second input of the "batchnorm/mul" node.
57 |     # (There is no LayerNorm bias in this model, hence bias=False in mlp_mixer.py.)
58 |     return get_node_weight_by_output_name(onnx_model, search_str, 1)
59 | 
60 | 
61 | def get_conv_layer_weight_bias(onnx_model, mixer_block_idx, is_token_mixer, mlp_idx):
62 |     assert mlp_idx in (1, 2)
63 |     assert mixer_block_idx in (0, 1, 2, 3)
64 |     search_str = """
65 |     model_1/GhumMarkerPoserMlpMixerGeneral/MLPMixer/MixerBlock_0/mlp_token_mixing/Mlp_1/Relu;
66 |     model_1/GhumMarkerPoserMlpMixerGeneral/MLPMixer/MixerBlock_0/mlp_token_mixing/Mlp_1/BiasAdd;
67 |     model_1/GhumMarkerPoserMlpMixerGeneral/MLPMixer/MixerBlock_3/mlp_token_mixing/Mlp_1/Conv2D;
68 |     model_1/GhumMarkerPoserMlpMixerGeneral/MLPMixer/MixerBlock_0/mlp_token_mixing/Mlp_1/Conv2D;
69 |     model_1/GhumMarkerPoserMlpMixerGeneral/MLPMixer/MixerBlock_0/mlp_token_mixing/Mlp_1/BiasAdd/ReadVariableOp
70 |     """
71 |     if mixer_block_idx == 3:
72 |         search_str = """
73 |         model_1/GhumMarkerPoserMlpMixerGeneral/MLPMixer/MixerBlock_0/mlp_token_mixing/Mlp_1/Relu;
74 |         model_1/GhumMarkerPoserMlpMixerGeneral/MLPMixer/MixerBlock_0/mlp_token_mixing/Mlp_1/BiasAdd;
75 |         model_1/GhumMarkerPoserMlpMixerGeneral/MLPMixer/MixerBlock_0/mlp_token_mixing/Mlp_1/Conv2D;
76 |         model_1/GhumMarkerPoserMlpMixerGeneral/MLPMixer/MixerBlock_0/mlp_token_mixing/Mlp_1/BiasAdd/ReadVariableOp
77 |         """
78 |     search_str = search_str.replace("\n", "").replace(" ", "").strip()
79 | 
80 |     search_str = search_str.replace("MixerBlock_0", f"MixerBlock_{mixer_block_idx}")
81 |     if mlp_idx == 2:
82 |         replace_str = "model_1/GhumMarkerPoserMlpMixerGeneral/MLPMixer/MixerBlock_0/mlp_token_mixing/Mlp_1/Relu;"
83 |         replace_str = replace_str.replace(
84 |             "MixerBlock_0", f"MixerBlock_{mixer_block_idx}"
85 |         )
86 |         search_str = search_str.replace(replace_str, "")
87 |     search_str = search_str.replace("Mlp_1", f"Mlp_{mlp_idx}")
88 |     if not is_token_mixer:
89 |         search_str = search_str.replace("mlp_token_mixing", "mlp_channel_mixing")
90 |     _, node = [
91 |         (i, n) for i, n in enumerate(onnx_model.graph.node) if search_str == n.output[0]
92 |     ][0]
93 |     w, b = conv_node_to_w_b(onnx_model, node)
94 |     mlpname = "mlp_token_mixing" if is_token_mixer else "mlp_channel_mixing"
95 |     idx = 0 if mlp_idx == 1 else 3
96 |     return {
97 |         f"mlpmixer_blocks.{mixer_block_idx}.{mlpname}.{idx}.weight": w,
98 |         f"mlpmixer_blocks.{mixer_block_idx}.{mlpname}.{idx}.bias": b,
99 |     }
100 | 
101 | 
102 | def get_state_dict_mlp_mixer_layer(onnx_model, mixer_block_idx):
103 |     state_dict = {}
104 |     norm1_weight = get_layernorm_weight(onnx_model, mixer_block_idx, 1)
105 |     state_dict.update(
106 |         {
107 |             f"mlpmixer_blocks.{mixer_block_idx}.norm1.weight": torch.from_numpy(
108 |                 to_array(norm1_weight).reshape(-1)
109 |             ).float()
110 |         }
111 |     )
112 |     state_dict.update(get_conv_layer_weight_bias(onnx_model, mixer_block_idx, True, 1))
113 |     state_dict.update(get_conv_layer_weight_bias(onnx_model, mixer_block_idx, True, 2))
114 |     norm2_weight = get_layernorm_weight(onnx_model, mixer_block_idx, 2)
115 |     state_dict.update(
116 |         {
117 |             f"mlpmixer_blocks.{mixer_block_idx}.norm2.weight": torch.from_numpy(
118 |                 to_array(norm2_weight).reshape(-1)
119 |             ).float()
120 |         }
121 |     )
122 |     state_dict.update(get_conv_layer_weight_bias(onnx_model, mixer_block_idx, False, 1))
123 |     state_dict.update(get_conv_layer_weight_bias(onnx_model, mixer_block_idx, False, 2))
124 |     return state_dict
125 | 
126 | 
127 | def conv_w_b_from_search_str(onnx_model, search_str):
128 |     _, node = [
129 |         (i, n) for i, n in enumerate(onnx_model.graph.node) if search_str == n.output[0]
130 |     ][0]
131 |     return conv_node_to_w_b(onnx_model, node)
132 | 
133 | 
134 | def get_state_dict(onnx_model):
135 |     state_dict = {}
136 |     search_str = "model_1/GhumMarkerPoserMlpMixerGeneral/conv2d/BiasAdd;model_1/GhumMarkerPoserMlpMixerGeneral/conv2d/Conv2D;model_1/GhumMarkerPoserMlpMixerGeneral/conv2d/BiasAdd/ReadVariableOp"
137 |     w, b = conv_w_b_from_search_str(onnx_model, search_str)
138 |     state_dict.update({"conv1.weight": w, "conv1.bias": b})
139 |     search_str = "model_1/GhumMarkerPoserMlpMixerGeneral/MLPMixer/input_tokens_embedding/BiasAdd;model_1/GhumMarkerPoserMlpMixerGeneral/MLPMixer/MixerBlock_3/mlp_channel_mixing/Mlp_2/Conv2D;model_1/GhumMarkerPoserMlpMixerGeneral/MLPMixer/input_tokens_embedding/Conv2D;model_1/GhumMarkerPoserMlpMixerGeneral/MLPMixer/input_tokens_embedding/BiasAdd/ReadVariableOp"
140 |     w, b = conv_w_b_from_search_str(onnx_model, search_str)
141 |     state_dict.update({"conv2.weight": w, "conv2.bias": b})
142 |     search_str = "model_1/GhumMarkerPoserMlpMixerGeneral/MLPMixer/AddExtraTokens/concat"
143 |     extra_token = get_node_weight_by_output_name(onnx_model, search_str, 0)
144 |     state_dict.update({"extra_token": torch.from_numpy(to_array(extra_token)).float()})
145 |     # MLP Mixer layers
146 |     state_dict.update(get_state_dict_mlp_mixer_layer(onnx_model, 0))
147 |     state_dict.update(get_state_dict_mlp_mixer_layer(onnx_model, 1))
148 |     state_dict.update(get_state_dict_mlp_mixer_layer(onnx_model, 2))
149 |     state_dict.update(get_state_dict_mlp_mixer_layer(onnx_model, 3))
150 |     search_str = "model_1/GhumMarkerPoserMlpMixerGeneral/output_blendweights/BiasAdd;model_1/GhumMarkerPoserMlpMixerGeneral/output_blendweights/Conv2D;model_1/GhumMarkerPoserMlpMixerGeneral/output_blendweights/BiasAdd/ReadVariableOp"
151 |     w, b = conv_w_b_from_search_str(onnx_model, search_str)
152 |     state_dict.update({"output_mlp.weight": w, "output_mlp.bias": b})
153 |     return state_dict
154 | 
155 | 
156 | def download_and_unzip_blendshapes_model():
157 |     # Equivalent to:
158 |     # wget -O face_landmarker_v2_with_blendshapes.task -q https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task
159 |     # unzip face_landmarker_v2_with_blendshapes.task
160 |     url = "https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task"
161 |     if not os.path.exists("face_landmarker_v2_with_blendshapes.task"):
162 |         print("Downloading face_landmarker_v2_with_blendshapes.task")
163 |         with open("face_landmarker_v2_with_blendshapes.task", "wb") as f:
164 |             f.write(requests.get(url, allow_redirects=True).content)
165 |     print("Unzipping face_landmarker_v2_with_blendshapes.task")
166 |     with zipfile.ZipFile("face_landmarker_v2_with_blendshapes.task", "r") as zip_ref:
167 |         zip_ref.extractall("./")
168 | 
169 | 
170 | if __name__ == "__main__":
171 |     parser = argparse.ArgumentParser()
172 |     parser.add_argument("--tflite_path", type=str, default="./face_blendshapes.tflite")
173 |     parser.add_argument("--output_path", type=str, default="./face_blendshapes.pth")
174 |     args = parser.parse_args()
175 |     if not os.path.exists(args.tflite_path):
176 |         download_and_unzip_blendshapes_model()
177 |     model_proto, external_tensor_storage = tf2onnx.convert.from_tflite(args.tflite_path)
178 |     onnx.checker.check_model(model_proto)  # raises if the exported graph is invalid
179 |     state_dict = get_state_dict(model_proto)
180 |     MediaPipeBlendshapesMLPMixer().load_state_dict(state_dict)  # sanity check: keys and shapes match
181 |     torch.save(state_dict, args.output_path)
182 |     print(f"Saved model to {args.output_path}")
183 | 
--------------------------------------------------------------------------------
/mlp_mixer.py:
--------------------------------------------------------------------------------
1 | """
2 | Re-implementation of the MLP-Mixer blendshapes prediction model from
3 | MediaPipe. See README.md for more details.
4 | 
5 | The MediaPipeBlendshapesMLPMixer class's forward method has an expected
6 | input shape of: (batch_size, 146, 2).
7 | 146 here refers to this subset of face mesh landmarks output by MediaPipe:
8 | 0, 1, 4, 5, 6, 7, 8, 10, 13, 14, 17, 21, 33, 37, 39,
9 | 40, 46, 52, 53, 54, 55, 58, 61, 63, 65, 66, 67, 70, 78, 80,
10 | 81, 82, 84, 87, 88, 91, 93, 95, 103, 105, 107, 109, 127, 132, 133,
11 | 136, 144, 145, 146, 148, 149, 150, 152, 153, 154, 155, 157, 158, 159, 160,
12 | 161, 162, 163, 168, 172, 173, 176, 178, 181, 185, 191, 195, 197, 234, 246,
13 | 249, 251, 263, 267, 269, 270, 276, 282, 283, 284, 285, 288, 291, 293, 295,
14 | 296, 297, 300, 308, 310, 311, 312, 314, 317, 318, 321, 323, 324, 332, 334,
15 | 336, 338, 356, 361, 362, 365, 373, 374, 375, 377, 378, 379, 380, 381, 382,
16 | 384, 385, 386, 387, 388, 389, 390, 397, 398, 400, 402, 405, 409, 415, 454,
17 | 466, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477
18 | 
19 | The 2 represents the x and y coordinates of the landmarks.
20 | The normalized landmarks that are output by MediaPipe need to be multiplied
21 | by the image height and image width before being passed to this model.
22 | The scale doesn't matter, as it is normalized out by this model, but the
23 | aspect ratio of the x to y landmark coordinates needs to be correct.
24 | 
25 | Example usage:
26 | ```
27 | blendshape_model = MediaPipeBlendshapesMLPMixer()
28 | mesh_detector = init_mpipe_blendshapes_model()
29 | mesh_results = mesh_detector.detect(image_mp)
30 | landmarks_np = []
31 | for face_idx in range(len(mesh_results.face_landmarks)):
32 |     landmarks_np.append(np.array([[i.x, i.y, i.z] for i in mesh_results.face_landmarks[face_idx]]))
33 | landmarks_np = np.array(landmarks_np).astype('float32')
34 | lmks_tensor = landmarks_np[:1, BLENDSHAPE_MODEL_LANDMARKS_SUBSET, :2]
35 | img_size = np.array([image_mp.width, image_mp.height])[None, None].astype('float32')
36 | scaled_lmks_tensor = torch.from_numpy(lmks_tensor * img_size)
37 | output = blendshape_model(scaled_lmks_tensor)
38 | ```
39 | """
40 | 
41 | import torch
42 | import torch.nn as nn
43 | 
44 | 
45 | class MLPMixerLayer(nn.Module):
46 |     def __init__(
47 |         self,
48 |         in_dim,
49 |         num_patches,
50 |         hidden_units_mlp1,
51 |         hidden_units_mlp2,
52 |         dropout_rate=0.0,
53 |         eps1=0.0000010132789611816406,
54 |         eps2=0.0000010132789611816406,
55 |     ):
56 |         super().__init__()
57 |         self.mlp_token_mixing = nn.Sequential(
58 |             nn.Conv2d(num_patches, hidden_units_mlp1, 1),
59 |             nn.ReLU(),
60 |             nn.Dropout(dropout_rate),
61 |             nn.Conv2d(hidden_units_mlp1, num_patches, 1),
62 |         )
63 |         self.mlp_channel_mixing = nn.Sequential(
64 |             nn.Conv2d(in_dim, hidden_units_mlp2, 1),
65 |             nn.ReLU(),
66 |             nn.Dropout(dropout_rate),
67 |             nn.Conv2d(hidden_units_mlp2, in_dim, 1),
68 |         )
69 |         self.norm1 = nn.LayerNorm(in_dim, bias=False, elementwise_affine=True, eps=eps1)
70 |         self.norm2 = nn.LayerNorm(in_dim, bias=False, elementwise_affine=True, eps=eps2)
71 | 
72 |     def forward(self, x):
73 |         x_1 = self.norm1(x)
74 |         mlp1_outputs = self.mlp_token_mixing(x_1)  # mixes information across the tokens
75 |         x = x + mlp1_outputs
76 |         x_2 = self.norm2(x)
77 |         mlp2_outputs = self.mlp_channel_mixing(x_2.permute(0, 3, 2, 1))  # mixes across the channels
78 |         x = x + mlp2_outputs.permute(0, 3, 2, 1)
79 |         return x
80 | 
81 | 
82 | class MediaPipeBlendshapesMLPMixer(nn.Module):
83 |     def __init__(
84 |         self,
85 |         in_dim=64,
86 |         num_patches=97,
87 |         hidden_units_mlp1=384,
88 |         hidden_units_mlp2=256,
89 |         num_blocks=4,
90 |         dropout_rate=0.0,
91 |         output_dim=52,
92 |     ):
93 |         super().__init__()
94 |         self.conv1 = nn.Conv2d(146, 96, kernel_size=1)  # 146 landmarks -> 96 tokens
95 |         self.conv2 = nn.Conv2d(2, 64, kernel_size=1)  # (x, y) -> 64 channels
96 |         self.extra_token = nn.Parameter(torch.randn(1, 64, 1, 1), requires_grad=True)  # learned 97th token
97 |         self.mlpmixer_blocks = nn.Sequential(
98 |             *[
99 |                 MLPMixerLayer(
100 |                     in_dim,
101 |                     num_patches,
102 |                     hidden_units_mlp1,
103 |                     hidden_units_mlp2,
104 |                     dropout_rate,
105 |                 )
106 |                 for _ in range(num_blocks)
107 |             ]
108 |         )
109 |         self.output_mlp = nn.Conv2d(in_dim, output_dim, 1)
110 | 
111 |     def forward(self, x):
112 |         x = x - x.mean(1, keepdim=True)  # center the landmarks on their centroid
113 |         x = x / x.norm(dim=2, keepdim=True).mean(1, keepdim=True)  # divide out scale; aspect ratio is preserved
114 |         x = x.unsqueeze(-2) * 0.5  # (B, 146, 1, 2)
115 |         x = self.conv1(x)  # (B, 96, 1, 2)
116 |         x = x.permute(0, 3, 2, 1)  # (B, 2, 1, 96)
117 |         x = self.conv2(x)  # (B, 64, 1, 96)
118 |         x = torch.cat([self.extra_token.expand(x.shape[0], -1, -1, -1), x], dim=3)  # prepend the learned token: 96 -> 97 (expand so batch_size > 1 works)
119 |         x = x.permute(0, 3, 2, 1)  # (B, 97, 1, 64)
120 |         x = self.mlpmixer_blocks(x)
121 |         x = x.permute(0, 3, 2, 1)  # (B, 64, 1, 97)
122 |         x = x[:, :, :, :1]  # keep only the extra token's features
123 |         x = self.output_mlp(x)  # (B, 52, 1, 1)
124 |         x = torch.sigmoid(x)
125 |         return x
126 | 
127 | 
128 | if __name__ == "__main__":
129 |     model = MediaPipeBlendshapesMLPMixer()
130 |     print(model)
131 |     input_tensor = torch.randn(1, 146, 2)
132 |     output = model(input_tensor)
133 |     print(output.shape)
134 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | onnx==1.15.0
2 | onnxruntime==1.17.1
3 | tf2onnx==1.16.1
4 | tensorflow==2.16.1
5 | 
--------------------------------------------------------------------------------
/test_converted_model.py:
--------------------------------------------------------------------------------
1 | import mediapipe as mp
2 | from mediapipe.tasks import python
3 | from mediapipe.tasks.python import vision
4 | import tensorflow as tf
5 | import numpy as np
6 | import torch
7 | 
8 | from blendshape_info import BLENDSHAPE_MODEL_LANDMARKS_SUBSET, BLENDSHAPE_NAMES
9 | from mlp_mixer import MediaPipeBlendshapesMLPMixer
10 | 
11 | 
12 | class TFLiteModel:
13 |     def __init__(self, model_path: str):
14 |         self.interpreter = tf.lite.Interpreter(model_path)
15 |         self.interpreter.allocate_tensors()
16 |         self.input_details = self.interpreter.get_input_details()
17 |         self.output_details = self.interpreter.get_output_details()
18 | 
19 |     def predict(self, *data_args):
20 |         assert len(data_args) == len(self.input_details)
21 |         for data, details in zip(data_args, self.input_details):
22 |             self.interpreter.set_tensor(details["index"], data)
23 |         self.interpreter.invoke()
24 |         return self.interpreter.get_tensor(self.output_details[0]["index"])
25 | 
26 | 
27 | def init_mpipe_blendshapes_model():
28 |     base_options = python.BaseOptions(
29 |         model_asset_path="face_landmarker_v2_with_blendshapes.task",
30 |         # delegate=mp.tasks.BaseOptions.Delegate.GPU,
31 |         delegate=mp.tasks.BaseOptions.Delegate.CPU,
32 |     )
33 |     mp_mode = mp.tasks.vision.FaceDetectorOptions.running_mode.IMAGE
34 |     options = vision.FaceLandmarkerOptions(
35 |         base_options=base_options,
36 |         running_mode=mp_mode,
37 |         output_face_blendshapes=True,
38 |         output_facial_transformation_matrixes=True,
39 |         num_faces=2,
40 |         # result_callback is only used in LIVE_STREAM running mode
41 |     )
42 |     return vision.FaceLandmarker.create_from_options(options)
43 | 
44 | 
45 | def get_blendshape_score_by_index(blendshapes, i):
46 |     return [_ for _ in blendshapes if _.index == i][0].score
47 | 
48 | 
49 | if __name__ == "__main__":
50 |     # Init MediaPipe model
51 |     mesh_detector = init_mpipe_blendshapes_model()
52 |     # Init TFLite model
53 |     tflite_model = TFLiteModel("face_blendshapes.tflite")
54 |     # Init PyTorch model
55 |     blendshape_model = MediaPipeBlendshapesMLPMixer()
56 |     blendshape_model.load_state_dict(torch.load("face_blendshapes.pth"))
57 | 
58 |     # Run the image through MediaPipe
59 |     IMAGE_FILE = "image.jpg"
60 |     image_mp = mp.Image.create_from_file(IMAGE_FILE)
61 |     mesh_results = mesh_detector.detect(image_mp)
62 |     # Convert landmarks to numpy
63 |     landmarks_np = []
64 |     for face_idx in range(len(mesh_results.face_landmarks)):
65 |         landmarks_np.append(
66 |             np.array([[i.x, i.y, i.z] for i in mesh_results.face_landmarks[face_idx]])
67 |         )
68 |     landmarks_np = np.array(landmarks_np).astype("float32")
69 |     # Convert blendshapes to numpy
70 |     blendshapes_np = np.array(
71 |         [
72 |             [
73 |                 get_blendshape_score_by_index(
74 |                     mesh_results.face_blendshapes[face_idx], i
75 |                 )
76 |                 for i in range(len(BLENDSHAPE_NAMES))
77 |             ]
78 |             for face_idx in range(len(mesh_results.face_landmarks))
79 |         ]
80 |     )
81 |     img_size = np.array([image_mp.width, image_mp.height])[None, None].astype("float32")
82 |     # Compare the results
83 |     for face_idx in range(len(mesh_results.face_landmarks)):
84 |         print("-" * 33 + f" Face {face_idx + 1} " + "-" * 32)
85 |         print("Blendshapes from MediaPipe:")
86 |         print(blendshapes_np[face_idx].round(3)[:12])
87 |         # Run the image through PyTorch
88 |         lmks_tensor = landmarks_np[
89 |             face_idx : face_idx + 1, BLENDSHAPE_MODEL_LANDMARKS_SUBSET, :2
90 |         ]
91 |         scaled_lmks_tensor = lmks_tensor * img_size
92 |         with torch.no_grad():
93 |             pytorch_output = blendshape_model(torch.from_numpy(scaled_lmks_tensor))
94 |         print("Blendshapes from PyTorch:")
95 |         print(pytorch_output.squeeze().detach().numpy().round(3)[:12])
96 |         # Run the image through TFLite
97 |         label = tflite_model.predict(scaled_lmks_tensor)
98 |         print("Blendshapes from TFLite:")
99 |         print(label.round(3)[:12])
100 | 
--------------------------------------------------------------------------------
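The test script compares the three backends by printing rounded scores; for an automated check, a tolerance-based assertion along these lines could be used. The helper name and the `atol=1e-3` tolerance are assumptions, with the tolerance chosen to absorb float16/float32 rounding differences between backends.

```python
import numpy as np

def assert_blendshapes_match(mediapipe_scores, pytorch_output, tflite_output, atol=1e-3):
    # mediapipe_scores: (52,) array of scores from the FaceLandmarker task;
    # pytorch_output: (1, 52, 1, 1) tensor from MediaPipeBlendshapesMLPMixer;
    # tflite_output: the raw TFLite output array.
    pytorch_np = pytorch_output.squeeze().detach().numpy()
    np.testing.assert_allclose(mediapipe_scores, pytorch_np, atol=atol)
    np.testing.assert_allclose(tflite_output.squeeze(), pytorch_np, atol=atol)

# Inside the per-face loop of test_converted_model.py this could be called as:
# assert_blendshapes_match(blendshapes_np[face_idx], pytorch_output, label)
```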