├── .github
│   └── workflows
│       ├── release.yaml
│       └── unittest.yaml
├── .gitignore
├── CITATION.cff
├── LICENSE
├── MODEL_ZOO.md
├── README.md
├── assets
│   └── teaser.svg
├── config
│   ├── celebv_hq
│   │   ├── action
│   │   │   ├── celebvhq_marlin_action_ft.yaml
│   │   │   └── celebvhq_marlin_action_lp.yaml
│   │   └── appearance
│   │       ├── celebvhq_marlin_appearance_ft.yaml
│   │       └── celebvhq_marlin_appearance_lp.yaml
│   └── pretrain
│       ├── marlin_vit_base.yaml
│       ├── marlin_vit_large.yaml
│       └── marlin_vit_small.yaml
├── dataset
│   ├── __init__.py
│   ├── celebv_hq.py
│   ├── misc
│   │   └── youtube_face
│   │       ├── train_set.csv
│   │       └── val_set.csv
│   └── youtube_face.py
├── evaluate.py
├── hf_src
│   ├── marlin_configs
│   │   ├── vit_base.py
│   │   ├── vit_large.py
│   │   └── vit_small.py
│   ├── marlin_huggingface
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── decoder.py
│   │   ├── encoder.py
│   │   ├── marlin.py
│   │   ├── modules.py
│   │   └── positional_embedding.py
│   └── publish_hf.py
├── init.py
├── model
│   ├── __init__.py
│   ├── classifier.py
│   └── marlin.py
├── preprocess
│   ├── celebvhq_extract.py
│   ├── celebvhq_preprocess.py
│   └── ytf_preprocess.py
├── requirements.lib.txt
├── requirements.txt
├── setup.py
├── src
│   └── marlin_pytorch
│       ├── __init__.py
│       ├── config.py
│       ├── face_detector.py
│       ├── model
│       │   ├── __init__.py
│       │   ├── decoder.py
│       │   ├── encoder.py
│       │   ├── marlin.py
│       │   ├── modules.py
│       │   └── positional_embedding.py
│       └── util.py
├── test
│   ├── __init__.py
│   ├── input_sample
│   │   ├── cropped01.mp4
│   │   ├── cropped02.mp4
│   │   ├── cropped03.mp4
│   │   ├── cropped04.mp4
│   │   ├── cropped05.mp4
│   │   ├── video01.mp4
│   │   ├── video02.mp4
│   │   ├── video03.mp4
│   │   ├── video04.mp4
│   │   └── video05.mp4
│   ├── output_sample
│   │   └── marlin_vit_base
│   │       ├── cropped01.npy
│   │       ├── cropped02.npy
│   │       ├── cropped03.npy
│   │       ├── cropped04.npy
│   │       ├── cropped05.npy
│   │       ├── video01.npy
│   │       ├── video02.npy
│   │       ├── video03.npy
│   │       ├── video04.npy
│   │       └── video05.npy
│   ├── test_marlin_pytorch.py
│   ├── test_marlin_vit_base.py
│   ├── test_marlin_vit_large.py
│   ├── test_marlin_vit_small.py
│   └── test_version.py
├── train.py
├── util
│   ├── __init__.py
│   ├── earlystop_lr.py
│   ├── face_sdk
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── logging.conf
│   │   │   └── model_conf.yaml
│   │   ├── core
│   │   │   ├── image_cropper
│   │   │   │   ├── BaseImageCropper.py
│   │   │   │   └── arcface_cropper
│   │   │   │       └── FaceRecImageCropper.py
│   │   │   ├── model_handler
│   │   │   │   ├── BaseModelHandler.py
│   │   │   │   ├── face_alignment
│   │   │   │   │   └── FaceAlignModelHandler.py
│   │   │   │   ├── face_detection
│   │   │   │   │   └── FaceDetModelHandler.py
│   │   │   │   ├── face_parsing
│   │   │   │   │   └── FaceParsingModelHandler.py
│   │   │   │   └── face_recognition
│   │   │   │       └── FaceRecModelHandler.py
│   │   │   └── model_loader
│   │   │       ├── BaseModelLoader.py
│   │   │       ├── face_alignment
│   │   │       │   └── FaceAlignModelLoader.py
│   │   │       ├── face_detection
│   │   │       │   └── FaceDetModelLoader.py
│   │   │       ├── face_parsing
│   │   │       │   └── FaceParsingModelLoader.py
│   │   │       └── face_recognition
│   │   │           └── FaceRecModelLoader.py
│   │   ├── face_crop.py
│   │   ├── face_parse.py
│   │   ├── models
│   │   │   ├── face_alignment
│   │   │   │   ├── face_alignment_1.0
│   │   │   │   │   ├── face_landmark_pfld.pkl
│   │   │   │   │   └── model_meta.json
│   │   │   │   └── face_alignment_2.0
│   │   │   │       ├── face_landmark_pfld.pkl
│   │   │   │       └── model_meta.json
│   │   │   ├── face_detection
│   │   │   │   ├── face_detection_1.0
│   │   │   │   │   ├── face_detection_retina.pkl
│   │   │   │   │   └── model_meta.json
│   │   │   │   └── face_detection_2.0
│   │   │   │       ├── face_detection_retina.pkl
│   │   │   │       └── model_meta.json
│   │   │   ├── face_parsing
│   │   │   │   ├── README.md
│   │   │   │   └── face_parsing_1.0
│   │   │   │       └── model_meta.json
│   │   │   ├── face_recognition
│   │   │   │   ├── face_recognition_1.0
│   │   │   │   │   ├── face_recognition_mv.pkl
│   │   │   │   │   └── model_meta.json
│   │   │   │   └── face_recognition_2.0
│   │   │   │       ├── 
face_recognition_mv.pkl │ │ │ │ └── model_meta.json │ │ └── network_def │ │ │ ├── mobilefacenet_def.py │ │ │ ├── mobilev3_pfld.py │ │ │ └── retinaface_def.py │ └── utils │ │ ├── BuzException.py │ │ ├── draw.py │ │ ├── lms_trans.py │ │ ├── show.py │ │ └── transform.py ├── lr_logger.py ├── misc.py ├── seed.py └── system_stats_logger.py └── version.txt /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | name: Release 2 | on: 3 | push: 4 | branches: 5 | - "master" 6 | jobs: 7 | 8 | check-version: 9 | name: Check Version 10 | runs-on: ubuntu-20.04 11 | outputs: 12 | local-version: ${{ steps.get-local-version.outputs.version }} 13 | remote-version: ${{ steps.get-remote-version.outputs.version }} 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: Get Local Version 17 | id: get-local-version 18 | run: echo "version=$(cat version.txt)" >> $GITHUB_OUTPUT 19 | - name: Get Remote Version 20 | id: get-remote-version 21 | run: echo "version=$(curl -s https://pypi.org/pypi/marlin_pytorch/json | jq -r '.info.version')" >> $GITHUB_OUTPUT 22 | 23 | release: 24 | runs-on: ubuntu-20.04 25 | needs: check-version 26 | if: needs.check-version.outputs.local-version != needs.check-version.outputs.remote-version 27 | 28 | strategy: 29 | matrix: 30 | python-version: ["3.9"] 31 | 32 | steps: 33 | - uses: actions/checkout@v2 34 | 35 | - name: Set up Python 36 | uses: actions/setup-python@v2 37 | with: 38 | python-version: ${{ matrix.python-version }} 39 | architecture: x64 40 | 41 | - name: Install dependencies 42 | run: | 43 | pip install -r requirements.lib.txt 44 | python init.py 45 | 46 | - name: Build package for marlin_pytorch 47 | run: python setup.py sdist bdist_wheel 48 | 49 | - name: Release marlin_pytorch to PyPI 50 | uses: pypa/gh-action-pypi-publish@release/v1 51 | with: 52 | user: __token__ 53 | password: ${{ secrets.PYPI_API_TOKEN }} 54 | 55 | - name: Get the version 56 | run: | 57 | VER=$(cat version.txt) 58 | echo "VERSION=$VER" >> $GITHUB_ENV 59 | 60 | - name: Release to GitHub Release 61 | uses: marvinpinto/action-automatic-releases@latest 62 | with: 63 | repo_token: "${{ secrets.GITHUB_TOKEN }}" 64 | automatic_release_tag: "${{ env.VERSION }}" 65 | title: "[${{ env.VERSION }}] Marlin-PyTorch Release" 66 | prerelease: false 67 | files: "dist/*" 68 | draft: true 69 | -------------------------------------------------------------------------------- /.github/workflows/unittest.yaml: -------------------------------------------------------------------------------- 1 | name: Unittest 2 | on: 3 | push: 4 | branches-ignore: 5 | - "master" 6 | pull_request: 7 | 8 | jobs: 9 | unittest: 10 | name: Unittest 11 | 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | python-version: ["3.6", "3.7", "3.8", "3.9", "3.10", "3.11"] 16 | torch-version: ["1.8.*", "1.9.*", "1.10.*", "1.11.*", "1.12.*", "1.13.*", "2.0.*", "2.1.*"] 17 | include: 18 | - torch-version: "1.8.*" 19 | torchvision-version: "0.9.*" 20 | - torch-version: "1.9.*" 21 | torchvision-version: "0.10.*" 22 | - torch-version: "1.10.*" 23 | torchvision-version: "0.11.*" 24 | - torch-version: "1.11.*" 25 | torchvision-version: "0.12.*" 26 | - torch-version: "1.12.*" 27 | torchvision-version: "0.13.*" 28 | - torch-version: "1.13.*" 29 | torchvision-version: "0.14.*" 30 | - torch-version: "2.0.*" 31 | torchvision-version: "0.15.*" 32 | - torch-version: "2.1.*" 33 | torchvision-version: "0.16.*" 34 | exclude: 35 | - python-version: "3.6" 36 | torch-version: "1.11.*" 37 | - python-version: 
"3.6" 38 | torch-version: "1.12.*" 39 | - python-version: "3.6" 40 | torch-version: "1.13.*" 41 | - python-version: "3.6" 42 | torch-version: "2.0.*" 43 | - python-version: "3.6" 44 | torch-version: "2.1.*" 45 | 46 | - python-version: "3.7" 47 | torch-version: "2.0.*" 48 | - python-version: "3.7" 49 | torch-version: "2.1.*" 50 | 51 | - python-version: "3.10" 52 | torch-version: "1.8.*" 53 | - python-version: "3.10" 54 | torch-version: "1.9.*" 55 | - python-version: "3.10" 56 | torch-version: "1.10.*" 57 | 58 | - python-version: "3.11" 59 | torch-version: "1.8.*" 60 | - python-version: "3.11" 61 | torch-version: "1.9.*" 62 | - python-version: "3.11" 63 | torch-version: "1.10.*" 64 | - python-version: "3.11" 65 | torch-version: "1.11.*" 66 | - python-version: "3.11" 67 | torch-version: "1.12.*" 68 | - python-version: "3.11" 69 | torch-version: "1.13.*" 70 | 71 | runs-on: ubuntu-20.04 72 | 73 | steps: 74 | - uses: actions/checkout@v3 75 | 76 | - name: Set Swap Space 77 | uses: pierotofy/set-swap-space@master 78 | with: 79 | swap-size-gb: 10 80 | 81 | - name: Set up Python 82 | uses: actions/setup-python@v4 83 | with: 84 | python-version: ${{ matrix.python-version }} 85 | architecture: x64 86 | 87 | - name: Install PyAV Dependencies for Python 3.6 88 | if: matrix.python-version == '3.6' 89 | run: | 90 | sudo apt install -y libavformat-dev libavdevice-dev 91 | pip install "av==6.*" 92 | 93 | - name: Install dependencies 94 | run: | 95 | sudo apt install -y ffmpeg wget 96 | pip install torch==${{ matrix.torch-version }} 97 | pip install torchvision==${{ matrix.torchvision-version }} 98 | pip install -r requirements.lib.txt 99 | python init.py 100 | 101 | - name: Download model checkpoints 102 | run: | 103 | mkdir test/model 104 | wget -O test/model/marlin_vit_base_ytf.encoder.pt https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_base_ytf.encoder.pt 105 | wget -O test/model/marlin_vit_base_ytf.full.pt https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_base_ytf.full.pt 106 | wget -O test/model/marlin_vit_small_ytf.encoder.pt https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_small_ytf.encoder.pt 107 | wget -O test/model/marlin_vit_small_ytf.full.pt https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_small_ytf.full.pt 108 | wget -O test/model/marlin_vit_large_ytf.encoder.pt https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_large_ytf.encoder.pt 109 | wget -O test/model/marlin_vit_large_ytf.full.pt https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_large_ytf.full.pt 110 | 111 | - name: Set PYTHONPATH 112 | run: echo "PYTHONPATH=$(pwd)/src" >> $GITHUB_ENV 113 | 114 | - name: Run Test 115 | run: | # python -m unittest discover test 116 | python -m unittest test/test_version.py 117 | python -m unittest test/test_marlin_vit_base.py 118 | python -m unittest test/test_marlin_vit_small.py 119 | python -m unittest test/test_marlin_vit_large.py 120 | 121 | coverage: 122 | # Run coverage and report to coveralls 123 | name: Coverage 124 | needs: [unittest] 125 | runs-on: ubuntu-20.04 126 | 127 | steps: 128 | - uses: actions/checkout@v2 129 | 130 | - name: Set up Python 131 | uses: actions/setup-python@v2 132 | with: 133 | python-version: "3.10" 134 | architecture: x64 135 | 136 | - name: Install dependencies 137 | run: | 138 | sudo apt install -y ffmpeg wget 139 | pip install torch=="1.13.*" 140 | pip install torchvision=="0.14.*" 141 | pip install -r 
requirements.lib.txt 142 | python init.py 143 | pip install coverage pytest coveralls 144 | 145 | - name: Download model checkpoints 146 | run: | 147 | mkdir test/model 148 | wget -O test/model/marlin_vit_base_ytf.encoder.pt https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_base_ytf.encoder.pt 149 | wget -O test/model/marlin_vit_base_ytf.full.pt https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_base_ytf.full.pt 150 | wget -O test/model/marlin_vit_small_ytf.encoder.pt https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_small_ytf.encoder.pt 151 | wget -O test/model/marlin_vit_small_ytf.full.pt https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_small_ytf.full.pt 152 | wget -O test/model/marlin_vit_large_ytf.encoder.pt https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_large_ytf.encoder.pt 153 | wget -O test/model/marlin_vit_large_ytf.full.pt https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_large_ytf.full.pt 154 | 155 | - name: Set PYTHONPATH 156 | run: echo "PYTHONPATH=$(pwd)/src" >> $GITHUB_ENV 157 | 158 | - name: Run Coverage 159 | run: coverage run --source=marlin_pytorch -m unittest discover 160 | 161 | - name: Coveralls 162 | run: coveralls --service=github 163 | env: 164 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 165 | COVERALLS_FLAG_NAME: marlin_pytorch 166 | COVERALLS_PARALLEL: true 167 | 168 | coveralls_finish: 169 | name: Coveralls Finish 170 | needs: [coverage] 171 | runs-on: ubuntu-20.04 172 | container: python:3-slim 173 | 174 | steps: 175 | - name: Finished 176 | run: | 177 | pip3 install --upgrade coveralls 178 | coveralls --service=github --finish 179 | env: 180 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 181 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python,pycharm 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,pycharm 3 | 4 | ### PyCharm ### 5 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 6 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 7 | 8 | # User-specific stuff 9 | .idea/**/workspace.xml 10 | .idea/**/tasks.xml 11 | .idea/**/usage.statistics.xml 12 | .idea/**/dictionaries 13 | .idea/**/shelf 14 | 15 | # Generated files 16 | .idea/**/contentModel.xml 17 | 18 | # Sensitive or high-churn files 19 | .idea/**/dataSources/ 20 | .idea/**/dataSources.ids 21 | .idea/**/dataSources.local.xml 22 | .idea/**/sqlDataSources.xml 23 | .idea/**/dynamic.xml 24 | .idea/**/uiDesigner.xml 25 | .idea/**/dbnavigator.xml 26 | 27 | # Gradle 28 | .idea/**/gradle.xml 29 | .idea/**/libraries 30 | 31 | # Gradle and Maven with auto-import 32 | # When using Gradle or Maven with auto-import, you should exclude module files, 33 | # since they will be recreated, and may cause churn. Uncomment if using 34 | # auto-import. 
35 | # .idea/artifacts 36 | # .idea/compiler.xml 37 | # .idea/jarRepositories.xml 38 | # .idea/modules.xml 39 | # .idea/*.iml 40 | # .idea/modules 41 | # *.iml 42 | # *.ipr 43 | 44 | # CMake 45 | cmake-build-*/ 46 | 47 | # Mongo Explorer plugin 48 | .idea/**/mongoSettings.xml 49 | 50 | # File-based project format 51 | *.iws 52 | 53 | # IntelliJ 54 | out/ 55 | 56 | # mpeltonen/sbt-idea plugin 57 | .idea_modules/ 58 | 59 | # JIRA plugin 60 | atlassian-ide-plugin.xml 61 | 62 | # Cursive Clojure plugin 63 | .idea/replstate.xml 64 | 65 | # Crashlytics plugin (for Android Studio and IntelliJ) 66 | com_crashlytics_export_strings.xml 67 | crashlytics.properties 68 | crashlytics-build.properties 69 | fabric.properties 70 | 71 | # Editor-based Rest Client 72 | .idea/httpRequests 73 | 74 | # Android studio 3.1+ serialized cache file 75 | .idea/caches/build_file_checksums.ser 76 | 77 | ### PyCharm Patch ### 78 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 79 | 80 | # *.iml 81 | # modules.xml 82 | # .idea/misc.xml 83 | # *.ipr 84 | 85 | # Sonarlint plugin 86 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 87 | .idea/**/sonarlint/ 88 | 89 | # SonarQube Plugin 90 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 91 | .idea/**/sonarIssues.xml 92 | 93 | # Markdown Navigator plugin 94 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 95 | .idea/**/markdown-navigator.xml 96 | .idea/**/markdown-navigator-enh.xml 97 | .idea/**/markdown-navigator/ 98 | 99 | # Cache file creation bug 100 | # See https://youtrack.jetbrains.com/issue/JBR-2257 101 | .idea/$CACHE_FILE$ 102 | 103 | # CodeStream plugin 104 | # https://plugins.jetbrains.com/plugin/12206-codestream 105 | .idea/codestream.xml 106 | 107 | ### Python ### 108 | # Byte-compiled / optimized / DLL files 109 | __pycache__/ 110 | *.py[cod] 111 | *$py.class 112 | 113 | # C extensions 114 | *.so 115 | 116 | # Distribution / packaging 117 | .Python 118 | build/ 119 | develop-eggs/ 120 | dist/ 121 | downloads/ 122 | eggs/ 123 | .eggs/ 124 | parts/ 125 | sdist/ 126 | var/ 127 | wheels/ 128 | pip-wheel-metadata/ 129 | share/python-wheels/ 130 | *.egg-info/ 131 | .installed.cfg 132 | *.egg 133 | MANIFEST 134 | 135 | # PyInstaller 136 | # Usually these files are written by a python script from a template 137 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 138 | *.manifest 139 | *.spec 140 | 141 | # Installer logs 142 | pip-log.txt 143 | pip-delete-this-directory.txt 144 | 145 | # Unit test / coverage reports 146 | htmlcov/ 147 | .tox/ 148 | .nox/ 149 | .coverage 150 | .coverage.* 151 | .cache 152 | nosetests.xml 153 | coverage.xml 154 | *.cover 155 | *.py,cover 156 | .hypothesis/ 157 | .pytest_cache/ 158 | pytestdebug.log 159 | 160 | # Translations 161 | *.mo 162 | *.pot 163 | 164 | # Django stuff: 165 | *.log 166 | local_settings.py 167 | db.sqlite3 168 | db.sqlite3-journal 169 | 170 | # Flask stuff: 171 | instance/ 172 | .webassets-cache 173 | 174 | # Scrapy stuff: 175 | .scrapy 176 | 177 | # Sphinx documentation 178 | docs/_build/ 179 | 180 | # PyBuilder 181 | target/ 182 | 183 | # Jupyter Notebook 184 | .ipynb_checkpoints 185 | 186 | # IPython 187 | profile_default/ 188 | ipython_config.py 189 | 190 | # pyenv 191 | .python-version 192 | 193 | # pipenv 194 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
195 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 196 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 197 | # install all needed dependencies. 198 | #Pipfile.lock 199 | 200 | # poetry 201 | #poetry.lock 202 | 203 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 204 | __pypackages__/ 205 | 206 | # Celery stuff 207 | celerybeat-schedule 208 | celerybeat.pid 209 | 210 | # SageMath parsed files 211 | *.sage.py 212 | 213 | # Environments 214 | # .env 215 | .env/ 216 | .venv/ 217 | env/ 218 | venv/ 219 | ENV/ 220 | env.bak/ 221 | venv.bak/ 222 | pythonenv* 223 | 224 | # Spyder project settings 225 | .spyderproject 226 | .spyproject 227 | 228 | # Rope project settings 229 | .ropeproject 230 | 231 | # mkdocs documentation 232 | /site 233 | 234 | # mypy 235 | .mypy_cache/ 236 | .dmypy.json 237 | dmypy.json 238 | 239 | # Pyre type checker 240 | .pyre/ 241 | 242 | # pytype static type analyzer 243 | .pytype/ 244 | 245 | # operating system-related files 246 | # file properties cache/storage on macOS 247 | *.DS_Store 248 | # thumbnail cache on Windows 249 | Thumbs.db 250 | 251 | # profiling data 252 | .prof 253 | 254 | 255 | # End of https://www.toptal.com/developers/gitignore/api/python,pycharm 256 | 257 | # User-defined 258 | node_modules 259 | .idea 260 | scratch*.py 261 | scratch*.ipynb 262 | /src/marlin_pytorch/version.txt 263 | marlin.encoder.pt 264 | marlin.full.pt 265 | marlin.ckpt 266 | .marlin 267 | face_parsing.farl.lapa.main_ema_136500_jit191.pt 268 | logs 269 | *.pth 270 | *.ckpt 271 | *.pt 272 | lightning_logs 273 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you find this work useful in your research, please cite it." 
3 | preferred-citation:
4 |   type: conference-paper
5 |   title: "MARLIN: Masked Autoencoder for facial video Representation LearnINg"
6 |   authors:
7 |     - family-names: "Cai"
8 |       given-names: "Zhixi"
9 |     - family-names: "Ghosh"
10 |       given-names: "Shreya"
11 |     - family-names: "Stefanov"
12 |       given-names: "Kalin"
13 |     - family-names: "Dhall"
14 |       given-names: "Abhinav"
15 |     - family-names: "Cai"
16 |       given-names: "Jianfei"
17 |     - family-names: "Rezatofighi"
18 |       given-names: "Hamid"
19 |     - family-names: "Haffari"
20 |       given-names: "Reza"
21 |     - family-names: "Hayat"
22 |       given-names: "Munawar"
23 |   collection-title: "Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition"
24 |   year: 2023
25 |   location:
26 |     name: "Vancouver, Canada"
27 |   start: 1493
28 |   end: 1504
29 |   doi: 10.1109/CVPR52729.2023.00150
30 |
--------------------------------------------------------------------------------
/MODEL_ZOO.md:
--------------------------------------------------------------------------------
1 | ## MARLIN Model Zoo
2 |
3 | | Name                 | Method | Backbone | Dataset | Epoch | Embedding | Encoder Params | Encoder MACs | Download |
4 | |----------------------|:------:|:--------:|:-------:|:-----:|:---------:|:--------------:|:------------:|----------|
5 | | marlin_vit_small_ytf | MARLIN | ViT-S    | YTF     | 2000  | 384       | 22.48M         | 25.96G       | [Encoder](https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_small_ytf.encoder.pt)/[Full](https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_small_ytf.full.pt) |
6 | | marlin_vit_base_ytf  | MARLIN | ViT-B    | YTF     | 2000  | 768       | 87.43M         | 101.85G      | [Encoder](https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_base_ytf.encoder.pt)/[Full](https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_base_ytf.full.pt) |
7 | | marlin_vit_large_ytf | MARLIN | ViT-L    | YTF     | 2000  | 1024      | 305.47M        | 357.92G      | [Encoder](https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_large_ytf.encoder.pt)/[Full](https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_large_ytf.full.pt) |
--------------------------------------------------------------------------------
/config/celebv_hq/action/celebvhq_marlin_action_ft.yaml:
--------------------------------------------------------------------------------
1 | model_name: "celebvhq_marlin_action_ft"
2 | backbone: "marlin_vit_base_ytf"
3 | dataset: "celebvhq"
4 | task: "action"
5 | temporal_reduction: "mean"
6 | learning_rate: 1.0e-4
7 | seq_mean_pool: true
8 | finetune: true
--------------------------------------------------------------------------------
/config/celebv_hq/action/celebvhq_marlin_action_lp.yaml:
--------------------------------------------------------------------------------
1 | model_name: "celebvhq_marlin_action_lp"
2 | backbone: "marlin_vit_base_ytf"
3 | dataset: "celebvhq"
4 | task: "action"
5 | temporal_reduction: "mean"
6 | learning_rate: 1.0e-4
7 | seq_mean_pool: true
8 | finetune: false
--------------------------------------------------------------------------------
/config/celebv_hq/appearance/celebvhq_marlin_appearance_ft.yaml:
--------------------------------------------------------------------------------
1 | model_name: "celebvhq_marlin_appearance_ft"
2 | backbone: "marlin_vit_base_ytf"
3 | dataset: "celebvhq"
4 | task: "appearance"
5 | temporal_reduction: "mean"
6 | learning_rate: 1.0e-4
7 | seq_mean_pool: true
8 | finetune: true
--------------------------------------------------------------------------------
/config/celebv_hq/appearance/celebvhq_marlin_appearance_lp.yaml:
--------------------------------------------------------------------------------
1 | model_name: "celebvhq_marlin_appearance_lp"
2 | backbone: "marlin_vit_base_ytf"
3 | dataset: "celebvhq"
4 | task: "appearance"
5 | temporal_reduction: "mean"
6 | learning_rate: 1.0e-4
7 | seq_mean_pool: true
8 | finetune: false
--------------------------------------------------------------------------------
/config/pretrain/marlin_vit_base.yaml:
--------------------------------------------------------------------------------
1 | model_name: "marlin_vit_base"
2 | img_size: 224
3 | patch_size: 16
4 | clip_frames: 16
5 | tubelet_size: 2
6 | mask_strategy: "fasking"
7 | temporal_sample_rate: 2
8 | mask_percentage_target: 0.9
9 | mlp_ratio: 4.0
10 | qkv_bias: true
11 | qk_scale: null
12 | drop_rate: 0.0
13 | attn_drop_rate: 0.0
14 | norm_layer: "LayerNorm"
15 | init_values: 0.0
16 | weight_decay: 0.0
17 | feature_dir: "Marlin_Features_Vit_Base"
18 | adv_loss: true
19 | adv_weight: 0.01
20 | gp_weight: 0.0
21 | d_steps: 1
22 | g_steps: 1
23 |
24 | learning_rate:
25 |   base: 1.5e-4
26 |   warmup: 1.0e-6
27 |   min: 1.0e-5
28 |   warmup_epochs: 40
29 |
30 | optimizer:
31 |   type: "AdamW"
32 |   eps: 1.0e-8
33 |   betas: [0.9, 0.95]
34 |
35 | encoder:
36 |   embed_dim: 768
37 |   depth: 12
38 |   num_heads: 12
39 |
40 | decoder:
41 |   embed_dim: 384
42 |   depth: 4
43 |   num_heads: 6
44 |
--------------------------------------------------------------------------------
/config/pretrain/marlin_vit_large.yaml:
--------------------------------------------------------------------------------
1 | model_name: "marlin_vit_large"
2 | img_size: 224
3 | patch_size: 16
4 | clip_frames: 16
5 | tubelet_size: 2
6 | mask_strategy: "fasking"
7 | temporal_sample_rate: 2
8 | mask_percentage_target: 0.9
9 | mlp_ratio: 4.0
10 | qkv_bias: true
11 | qk_scale: null
12 | drop_rate: 0.0
13 | attn_drop_rate: 0.0
14 | norm_layer: "LayerNorm"
15 | init_values: 0.0
16 | weight_decay: 0.0
17 | feature_dir: "Marlin_Features_Vit_Large"
18 | adv_loss: true
19 | adv_weight: 0.01
20 | gp_weight: 0.0
21 | d_steps: 1
22 | g_steps: 1
23 |
24 | learning_rate:
25 |   base: 1.5e-4
26 |   warmup: 1.0e-6
27 |   min: 1.0e-5
28 |   warmup_epochs: 40
29 |
30 | optimizer:
31 |   type: "AdamW"
32 |   eps: 1.0e-8
33 |   betas: [0.9, 0.95]
34 |
35 | encoder:
36 |   embed_dim: 1024
37 |   depth: 24
38 |   num_heads: 16
39 |
40 | decoder:
41 |   embed_dim: 512
42 |   depth: 12
43 |   num_heads: 8
44 |
--------------------------------------------------------------------------------
/config/pretrain/marlin_vit_small.yaml:
--------------------------------------------------------------------------------
1 | model_name: "marlin_vit_small"
2 | img_size: 224
3 | patch_size: 16
4 | clip_frames: 16
5 | tubelet_size: 2
6 | mask_strategy: "fasking"
7 | temporal_sample_rate: 2
8 | mask_percentage_target: 0.9
9 | mlp_ratio: 4.0
10 | qkv_bias: true
11 | qk_scale: null
12 | drop_rate: 0.0
13 | attn_drop_rate: 0.0
14 | norm_layer: "LayerNorm"
15 | init_values: 0.0
16 | weight_decay: 0.0
17 | feature_dir: "Marlin_Features_Vit_Small"
18 | adv_loss: true
19 | adv_weight: 0.01
20 | gp_weight: 0.0
21 | d_steps: 1
22 | g_steps: 1
23 |
24 | learning_rate:
25 |   base: 1.5e-4
26 |   warmup: 1.0e-6
27 |   min: 1.0e-5
28 |   warmup_epochs: 40
29 |
30 | optimizer:
31 | type: "AdamW" 32 | eps: 1.0e-8 33 | betas: [0.9, 0.95] 34 | 35 | encoder: 36 | embed_dim: 384 37 | depth: 12 38 | num_heads: 6 39 | 40 | decoder: 41 | embed_dim: 192 42 | depth: 4 43 | num_heads: 3 44 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/celebv_hq.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABC, abstractmethod 3 | from itertools import islice 4 | from typing import Optional 5 | 6 | import ffmpeg 7 | import numpy as np 8 | import torch 9 | import torchvision 10 | from pytorch_lightning import LightningDataModule 11 | from torch.utils.data import DataLoader 12 | 13 | from marlin_pytorch.util import read_video, padding_video 14 | from util.misc import sample_indexes, read_text, read_json 15 | 16 | 17 | class CelebvHqBase(LightningDataModule, ABC): 18 | 19 | def __init__(self, data_root: str, split: str, task: str, data_ratio: float = 1.0, take_num: int = None): 20 | super().__init__() 21 | self.data_root = data_root 22 | self.split = split 23 | assert task in ("appearance", "action") 24 | self.task = task 25 | self.take_num = take_num 26 | 27 | self.name_list = list( 28 | filter(lambda x: x != "", read_text(os.path.join(data_root, f"{self.split}.txt")).split("\n"))) 29 | self.metadata = read_json(os.path.join(data_root, "celebvhq_info.json")) 30 | 31 | if data_ratio < 1.0: 32 | self.name_list = self.name_list[:int(len(self.name_list) * data_ratio)] 33 | if take_num is not None: 34 | self.name_list = self.name_list[:self.take_num] 35 | 36 | print(f"Dataset {self.split} has {len(self.name_list)} videos") 37 | 38 | @abstractmethod 39 | def __getitem__(self, index: int): 40 | pass 41 | 42 | def __len__(self): 43 | return len(self.name_list) 44 | 45 | 46 | # for fine-tuning 47 | class CelebvHq(CelebvHqBase): 48 | 49 | def __init__(self, 50 | root_dir: str, 51 | split: str, 52 | task: str, 53 | clip_frames: int, 54 | temporal_sample_rate: int, 55 | data_ratio: float = 1.0, 56 | take_num: Optional[int] = None 57 | ): 58 | super().__init__(root_dir, split, task, data_ratio, take_num) 59 | self.clip_frames = clip_frames 60 | self.temporal_sample_rate = temporal_sample_rate 61 | 62 | def __getitem__(self, index: int): 63 | y = self.metadata["clips"][self.name_list[index]]["attributes"][self.task] 64 | video_path = os.path.join(self.data_root, "cropped", self.name_list[index] + ".mp4") 65 | 66 | probe = ffmpeg.probe(video_path)["streams"][0] 67 | n_frames = int(probe["nb_frames"]) 68 | 69 | if n_frames <= self.clip_frames: 70 | video = read_video(video_path, channel_first=True).video / 255 71 | # pad frames to 16 72 | video = padding_video(video, self.clip_frames, "same") # (T, C, H, W) 73 | video = video.permute(1, 0, 2, 3) # (C, T, H, W) 74 | return video, torch.tensor(y, dtype=torch.long) 75 | elif n_frames <= self.clip_frames * self.temporal_sample_rate: 76 | # reset a lower temporal sample rate 77 | sample_rate = n_frames // self.clip_frames 78 | else: 79 | sample_rate = self.temporal_sample_rate 80 | # sample frames 81 | video_indexes = sample_indexes(n_frames, self.clip_frames, sample_rate) 82 | reader = torchvision.io.VideoReader(video_path) 83 | fps = 
reader.get_metadata()["video"]["fps"][0] 84 | reader.seek(video_indexes[0].item() / fps, True) 85 | frames = [] 86 | for frame in islice(reader, 0, self.clip_frames * sample_rate, sample_rate): 87 | frames.append(frame["data"]) 88 | video = torch.stack(frames) / 255 # (T, C, H, W) 89 | video = video.permute(1, 0, 2, 3) # (C, T, H, W) 90 | assert video.shape[1] == self.clip_frames, video_path 91 | return video, torch.tensor(y, dtype=torch.long).bool() 92 | 93 | 94 | # For linear probing 95 | class CelebvHqFeatures(CelebvHqBase): 96 | 97 | def __init__(self, root_dir: str, 98 | feature_dir: str, 99 | split: str, 100 | task: str, 101 | temporal_reduction: str, 102 | data_ratio: float = 1.0, 103 | take_num: Optional[int] = None 104 | ): 105 | super().__init__(root_dir, split, task, data_ratio, take_num) 106 | self.feature_dir = feature_dir 107 | self.temporal_reduction = temporal_reduction 108 | 109 | def __getitem__(self, index: int): 110 | feat_path = os.path.join(self.data_root, self.feature_dir, self.name_list[index] + ".npy") 111 | 112 | x = torch.from_numpy(np.load(feat_path)).float() 113 | 114 | if x.size(0) == 0: 115 | x = torch.zeros(1, 768, dtype=torch.float32) 116 | 117 | if self.temporal_reduction == "mean": 118 | x = x.mean(dim=0) 119 | elif self.temporal_reduction == "max": 120 | x = x.max(dim=0)[0] 121 | elif self.temporal_reduction == "min": 122 | x = x.min(dim=0)[0] 123 | else: 124 | raise ValueError(self.temporal_reduction) 125 | 126 | y = self.metadata["clips"][self.name_list[index]]["attributes"][self.task] 127 | 128 | return x, torch.tensor(y, dtype=torch.long).bool() 129 | 130 | 131 | class CelebvHqDataModule(LightningDataModule): 132 | 133 | def __init__(self, root_dir: str, 134 | load_raw: bool, 135 | task: str, 136 | batch_size: int, 137 | num_workers: int = 0, 138 | clip_frames: int = None, 139 | temporal_sample_rate: int = None, 140 | feature_dir: str = None, 141 | temporal_reduction: str = "mean", 142 | data_ratio: float = 1.0, 143 | take_train: Optional[int] = None, 144 | take_val: Optional[int] = None, 145 | take_test: Optional[int] = None 146 | ): 147 | super().__init__() 148 | self.root_dir = root_dir 149 | self.task = task 150 | self.batch_size = batch_size 151 | self.num_workers = num_workers 152 | self.clip_frames = clip_frames 153 | self.temporal_sample_rate = temporal_sample_rate 154 | self.feature_dir = feature_dir 155 | self.temporal_reduction = temporal_reduction 156 | self.load_raw = load_raw 157 | self.data_ratio = data_ratio 158 | self.take_train = take_train 159 | self.take_val = take_val 160 | self.take_test = take_test 161 | 162 | if load_raw: 163 | assert clip_frames is not None 164 | assert temporal_sample_rate is not None 165 | else: 166 | assert feature_dir is not None 167 | assert temporal_reduction is not None 168 | 169 | self.train_dataset = None 170 | self.val_dataset = None 171 | self.test_dataset = None 172 | 173 | def setup(self, stage: Optional[str] = None): 174 | if self.load_raw: 175 | self.train_dataset = CelebvHq(self.root_dir, "train", self.task, self.clip_frames, 176 | self.temporal_sample_rate, self.data_ratio, self.take_train) 177 | self.val_dataset = CelebvHq(self.root_dir, "val", self.task, self.clip_frames, 178 | self.temporal_sample_rate, self.data_ratio, self.take_val) 179 | self.test_dataset = CelebvHq(self.root_dir, "test", self.task, self.clip_frames, 180 | self.temporal_sample_rate, 1.0, self.take_test) 181 | else: 182 | self.train_dataset = CelebvHqFeatures(self.root_dir, self.feature_dir, "train", self.task, 183 | 
self.temporal_reduction, self.data_ratio, self.take_train) 184 | self.val_dataset = CelebvHqFeatures(self.root_dir, self.feature_dir, "val", self.task, 185 | self.temporal_reduction, self.data_ratio, self.take_val) 186 | self.test_dataset = CelebvHqFeatures(self.root_dir, self.feature_dir, "test", self.task, 187 | self.temporal_reduction, 1.0, self.take_test) 188 | 189 | def train_dataloader(self): 190 | return DataLoader( 191 | self.train_dataset, 192 | batch_size=self.batch_size, 193 | shuffle=True, 194 | num_workers=self.num_workers, 195 | pin_memory=True, 196 | drop_last=True 197 | ) 198 | 199 | def val_dataloader(self): 200 | return DataLoader( 201 | self.val_dataset, 202 | batch_size=self.batch_size, 203 | shuffle=False, 204 | num_workers=self.num_workers, 205 | pin_memory=True 206 | ) 207 | 208 | def test_dataloader(self): 209 | return DataLoader( 210 | self.test_dataset, 211 | batch_size=self.batch_size, 212 | shuffle=False, 213 | num_workers=self.num_workers, 214 | pin_memory=True 215 | ) 216 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from pytorch_lightning import Trainer 5 | from pytorch_lightning.callbacks import ModelCheckpoint 6 | from tqdm.auto import tqdm 7 | 8 | from dataset.celebv_hq import CelebvHqDataModule 9 | from marlin_pytorch.config import resolve_config 10 | from marlin_pytorch.util import read_yaml 11 | from model.classifier import Classifier 12 | from util.earlystop_lr import EarlyStoppingLR 13 | from util.lr_logger import LrLogger 14 | from util.seed import Seed 15 | from util.system_stats_logger import SystemStatsLogger 16 | 17 | 18 | def train_celebvhq(args, config): 19 | data_path = args.data_path 20 | resume_ckpt = args.resume 21 | n_gpus = args.n_gpus 22 | max_epochs = args.epochs 23 | 24 | finetune = config["finetune"] 25 | learning_rate = config["learning_rate"] 26 | task = config["task"] 27 | 28 | if task == "appearance": 29 | num_classes = 40 30 | elif task == "action": 31 | num_classes = 35 32 | else: 33 | raise ValueError(f"Unknown task {task}") 34 | 35 | if finetune: 36 | backbone_config = resolve_config(config["backbone"]) 37 | 38 | model = Classifier( 39 | num_classes, config["backbone"], True, args.marlin_ckpt, "multilabel", config["learning_rate"], 40 | args.n_gpus > 1, 41 | ) 42 | 43 | dm = CelebvHqDataModule( 44 | data_path, finetune, task, 45 | batch_size=args.batch_size, 46 | num_workers=args.num_workers, 47 | clip_frames=backbone_config.n_frames, 48 | temporal_sample_rate=2 49 | ) 50 | 51 | else: 52 | model = Classifier( 53 | num_classes, config["backbone"], False, 54 | None, "multilabel", config["learning_rate"], args.n_gpus > 1, 55 | ) 56 | 57 | dm = CelebvHqDataModule( 58 | data_path, finetune, task, 59 | batch_size=args.batch_size, 60 | num_workers=args.num_workers, 61 | feature_dir=config["backbone"], 62 | temporal_reduction=config["temporal_reduction"] 63 | ) 64 | 65 | if args.skip_train: 66 | dm.setup() 67 | return resume_ckpt, dm 68 | 69 | strategy = None if n_gpus <= 1 else "ddp" 70 | accelerator = "cpu" if n_gpus == 0 else "gpu" 71 | 72 | ckpt_filename = config["model_name"] + "-{epoch}-{val_auc:.3f}" 73 | ckpt_monitor = "val_auc" 74 | 75 | try: 76 | precision = int(args.precision) 77 | except ValueError: 78 | precision = args.precision 79 | 80 | ckpt_callback = ModelCheckpoint(dirpath=f"ckpt/{config['model_name']}", save_last=True, 81 | 
filename=ckpt_filename, 82 | monitor=ckpt_monitor, 83 | mode="max") 84 | 85 | trainer = Trainer(log_every_n_steps=1, devices=n_gpus, accelerator=accelerator, benchmark=True, 86 | logger=True, precision=precision, max_epochs=max_epochs, 87 | strategy=strategy, resume_from_checkpoint=resume_ckpt, 88 | callbacks=[ckpt_callback, LrLogger(), EarlyStoppingLR(1e-6), SystemStatsLogger()]) 89 | 90 | trainer.fit(model, dm) 91 | 92 | return ckpt_callback.best_model_path, dm 93 | 94 | 95 | def evaluate_celebvhq(args, ckpt, dm): 96 | print("Load checkpoint", ckpt) 97 | model = Classifier.load_from_checkpoint(ckpt) 98 | accelerator = "cpu" if args.n_gpus == 0 else "gpu" 99 | trainer = Trainer(log_every_n_steps=1, devices=1 if args.n_gpus > 0 else 0, accelerator=accelerator, benchmark=True, 100 | logger=False, enable_checkpointing=False) 101 | Seed.set(42) 102 | model.eval() 103 | 104 | # collect predictions 105 | preds = trainer.predict(model, dm.test_dataloader()) 106 | preds = torch.cat(preds) 107 | 108 | # collect ground truth 109 | ys = torch.zeros_like(preds, dtype=torch.bool) 110 | for i, (_, y) in enumerate(tqdm(dm.test_dataloader())): 111 | ys[i * args.batch_size: (i + 1) * args.batch_size] = y 112 | 113 | preds = preds.sigmoid() 114 | acc = ((preds > 0.5) == ys).float().mean() 115 | auc = model.auc_fn(preds, ys) 116 | results = { 117 | "acc": acc, 118 | "auc": auc 119 | } 120 | print(results) 121 | 122 | 123 | def evaluate(args): 124 | config = read_yaml(args.config) 125 | dataset_name = config["dataset"] 126 | 127 | if dataset_name == "celebvhq": 128 | ckpt, dm = train_celebvhq(args, config) 129 | evaluate_celebvhq(args, ckpt, dm) 130 | else: 131 | raise NotImplementedError(f"Dataset {dataset_name} not implemented") 132 | 133 | 134 | if __name__ == '__main__': 135 | parser = argparse.ArgumentParser("CelebV-HQ evaluation") 136 | parser.add_argument("--config", type=str, help="Path to CelebV-HQ evaluation config file.") 137 | parser.add_argument("--data_path", type=str, help="Path to CelebV-HQ dataset.") 138 | parser.add_argument("--marlin_ckpt", type=str, default=None, 139 | help="Path to MARLIN checkpoint. 
Default: None, load from online.") 140 | parser.add_argument("--n_gpus", type=int, default=1) 141 | parser.add_argument("--precision", type=str, default="32") 142 | parser.add_argument("--num_workers", type=int, default=8) 143 | parser.add_argument("--batch_size", type=int, default=32) 144 | parser.add_argument("--epochs", type=int, default=2000, help="Max epochs to train.") 145 | parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint to resume training.") 146 | parser.add_argument("--skip_train", action="store_true", default=False, 147 | help="Skip training and evaluate only.") 148 | 149 | args = parser.parse_args() 150 | if args.skip_train: 151 | assert args.resume is not None 152 | 153 | evaluate(args) 154 | -------------------------------------------------------------------------------- /hf_src/marlin_configs/vit_base.py: -------------------------------------------------------------------------------- 1 | from marlin_huggingface import MarlinConfig 2 | 3 | 4 | vit_base_config = MarlinConfig( 5 | img_size=224, 6 | patch_size=16, 7 | n_frames=16, 8 | mlp_ratio=4.0, 9 | qkv_bias=True, 10 | qk_scale=None, 11 | drop_rate=0.0, 12 | attn_drop_rate=0.0, 13 | norm_layer="LayerNorm", 14 | init_values=0.0, 15 | tubelet_size=2, 16 | encoder_embed_dim=768, 17 | encoder_depth=12, 18 | encoder_num_heads=12, 19 | decoder_embed_dim=384, 20 | decoder_depth=4, 21 | decoder_num_heads=6, 22 | ) 23 | -------------------------------------------------------------------------------- /hf_src/marlin_configs/vit_large.py: -------------------------------------------------------------------------------- 1 | from marlin_huggingface import MarlinConfig 2 | 3 | 4 | vit_large_config = MarlinConfig( 5 | img_size=224, 6 | patch_size=16, 7 | n_frames=16, 8 | mlp_ratio=4.0, 9 | qkv_bias=True, 10 | qk_scale=None, 11 | drop_rate=0.0, 12 | attn_drop_rate=0.0, 13 | norm_layer="LayerNorm", 14 | init_values=0.0, 15 | tubelet_size=2, 16 | encoder_embed_dim=1024, 17 | encoder_depth=24, 18 | encoder_num_heads=16, 19 | decoder_embed_dim=512, 20 | decoder_depth=12, 21 | decoder_num_heads=8, 22 | ) 23 | -------------------------------------------------------------------------------- /hf_src/marlin_configs/vit_small.py: -------------------------------------------------------------------------------- 1 | from marlin_huggingface import MarlinConfig 2 | 3 | vit_small_config = MarlinConfig( 4 | img_size=224, 5 | patch_size=16, 6 | n_frames=16, 7 | mlp_ratio=4., 8 | qkv_bias=True, 9 | qk_scale=None, 10 | drop_rate=0., 11 | attn_drop_rate=0., 12 | norm_layer="LayerNorm", 13 | init_values=0., 14 | tubelet_size=2, 15 | encoder_embed_dim=384, 16 | encoder_depth=12, 17 | encoder_num_heads=6, 18 | decoder_embed_dim=192, 19 | decoder_depth=4, 20 | decoder_num_heads=3, 21 | ) 22 | -------------------------------------------------------------------------------- /hf_src/marlin_huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModel, AutoConfig 2 | 3 | from .config import MarlinConfig 4 | from .marlin import Marlin, MarlinModel 5 | 6 | MarlinConfig.register_for_auto_class() 7 | MarlinModel.register_for_auto_class() 8 | AutoConfig.register("marlin", MarlinConfig) 9 | AutoModel.register(MarlinConfig, MarlinModel) 10 | 11 | __all__ = ["Marlin", "MarlinModel", "MarlinConfig"] 12 | -------------------------------------------------------------------------------- /hf_src/marlin_huggingface/config.py: 
-------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | 4 | class MarlinConfig(PretrainedConfig): 5 | model_type = "marlin" 6 | 7 | def __init__(self, **kwargs): 8 | self.img_size = kwargs.pop("img_size", None) 9 | self.patch_size = kwargs.pop("patch_size", None) 10 | self.n_frames = kwargs.pop("n_frames", None) 11 | self.encoder_embed_dim = kwargs.pop("encoder_embed_dim", None) 12 | self.encoder_depth = kwargs.pop("encoder_depth", None) 13 | self.encoder_num_heads = kwargs.pop("encoder_num_heads", None) 14 | self.decoder_embed_dim = kwargs.pop("decoder_embed_dim", None) 15 | self.decoder_depth = kwargs.pop("decoder_depth", None) 16 | self.decoder_num_heads = kwargs.pop("decoder_num_heads", None) 17 | self.mlp_ratio = kwargs.pop("mlp_ratio", None) 18 | self.qkv_bias = kwargs.pop("qkv_bias", None) 19 | self.qk_scale = kwargs.pop("qk_scale", None) 20 | self.drop_rate = kwargs.pop("drop_rate", None) 21 | self.attn_drop_rate = kwargs.pop("attn_drop_rate", None) 22 | self.norm_layer = kwargs.pop("norm_layer", None) 23 | self.init_values = kwargs.pop("init_values", None) 24 | self.tubelet_size = kwargs.pop("tubelet_size", None) 25 | self.as_feature_extractor = kwargs.pop("as_feature_extractor", True) 26 | 27 | super().__init__(**kwargs) 28 | -------------------------------------------------------------------------------- /hf_src/marlin_huggingface/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import nn, Tensor 4 | from torch.nn import LayerNorm, Linear, ModuleList 5 | 6 | from .modules import Block, no_grad_trunc_normal_ 7 | from .positional_embedding import SinCosPositionalEmbedding 8 | 9 | 10 | class MarlinDecoder(nn.Module): 11 | 12 | def __init__(self, img_size=224, patch_size=16, n_frames=16, embed_dim=384, depth=8, 13 | num_heads=6, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 14 | norm_layer="LayerNorm", init_values=1., tubelet_size=2 15 | ): 16 | super().__init__() 17 | output_dim = 3 * tubelet_size * patch_size * patch_size 18 | self.patch_size = patch_size 19 | self.tubelet_size = tubelet_size 20 | self.n_patch_h = img_size // patch_size 21 | self.n_patch_w = img_size // patch_size 22 | self.embed_dim = embed_dim 23 | if norm_layer == "LayerNorm": 24 | self.norm_layer = LayerNorm 25 | self.norm = self.norm_layer(embed_dim) 26 | else: 27 | raise NotImplementedError("Only LayerNorm is supported") 28 | 29 | # sine-cosine positional embeddings 30 | self.pos_embedding = SinCosPositionalEmbedding( 31 | (self.n_patch_h * self.n_patch_w * (n_frames // tubelet_size), embed_dim), dropout_rate=0.) 
32 | self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 33 | 34 | self.blocks = ModuleList([ 35 | Block( 36 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 37 | drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=self.norm_layer, 38 | init_values=init_values 39 | ) for _ in range(depth)]) 40 | 41 | self.head = Linear(embed_dim, output_dim) 42 | self.apply(self._init_weights) 43 | no_grad_trunc_normal_(self.mask_token, mean=0., std=0.02, a=-0.02, b=0.02) 44 | 45 | @staticmethod 46 | def _init_weights(m): 47 | if isinstance(m, nn.Linear): 48 | nn.init.xavier_uniform_(m.weight) 49 | if isinstance(m, nn.Linear) and m.bias is not None: 50 | nn.init.constant_(m.bias, 0) 51 | elif isinstance(m, nn.LayerNorm): 52 | nn.init.constant_(m.bias, 0) 53 | nn.init.constant_(m.weight, 1.0) 54 | 55 | def unpatch_to_img(self, x: Tensor) -> Tensor: 56 | # x: (Batch, No. batches, Prod of cube size * C) 57 | x = rearrange(x, "b n (c p) -> b n p c", c=3) 58 | # x: (Batch, No. batches, Prod of cube size, C) 59 | x = rearrange(x, "b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)", p0=self.tubelet_size, 60 | p1=self.patch_size, p2=self.patch_size, h=self.n_patch_h, w=self.n_patch_w) 61 | # x: (B, C, T, H, W) 62 | return x 63 | 64 | def forward_features(self, x, return_token_num=0): 65 | for block in self.blocks: 66 | x = block(x) 67 | 68 | if return_token_num > 0: 69 | x = x[:, -return_token_num:] 70 | 71 | x = self.norm(x) 72 | x = self.head(x) 73 | # x: (B, N_mask, C) 74 | return x 75 | 76 | def forward(self, x, mask): 77 | # mask: 0 -> masked, 1 -> visible 78 | b, n, c = x.shape 79 | expand_pos_embed = self.pos_embedding.emb.data.expand(b, -1, -1) 80 | pos_emb_vis = expand_pos_embed[mask].view(b, -1, c) 81 | pos_emb_mask = expand_pos_embed[~mask].view(b, -1, c) 82 | x = torch.cat([x + pos_emb_vis, self.mask_token + pos_emb_mask], dim=1) 83 | 84 | mask_num = pos_emb_mask.shape[1] 85 | 86 | x = self.forward_features(x, return_token_num=mask_num) 87 | return x 88 | -------------------------------------------------------------------------------- /hf_src/marlin_huggingface/encoder.py: -------------------------------------------------------------------------------- 1 | from torch import nn, Tensor 2 | from torch.nn import ModuleList, LayerNorm 3 | 4 | from .modules import PatchEmbedding3d, Block 5 | from .positional_embedding import SinCosPositionalEmbedding 6 | 7 | 8 | class MarlinEncoder(nn.Module): 9 | 10 | def __init__(self, img_size=224, patch_size=16, n_frames=16, embed_dim=768, depth=12, 11 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 12 | norm_layer="LayerNorm", init_values=0., tubelet_size=2 13 | ): 14 | super().__init__() 15 | 16 | self.embed_dim = embed_dim 17 | self.patch_embedding = PatchEmbedding3d( 18 | input_size=(3, n_frames, img_size, img_size), 19 | patch_size=(tubelet_size, patch_size, patch_size), 20 | embedding=embed_dim 21 | ) 22 | num_patches = (img_size // patch_size) * (img_size // patch_size) * (n_frames // tubelet_size) 23 | 24 | # sine-cosine positional embeddings 25 | self.pos_embedding = SinCosPositionalEmbedding((num_patches, embed_dim), dropout_rate=0.) 
26 | 27 | if norm_layer == "LayerNorm": 28 | self.norm_layer = LayerNorm 29 | self.norm = self.norm_layer(embed_dim) 30 | else: 31 | raise NotImplementedError("Only LayerNorm is supported") 32 | 33 | self.blocks = ModuleList([ 34 | Block( 35 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 36 | drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=self.norm_layer, 37 | init_values=init_values) 38 | for _ in range(depth) 39 | ]) 40 | 41 | self.apply(self._init_weights) 42 | 43 | @staticmethod 44 | def _init_weights(m): 45 | if isinstance(m, nn.Linear): 46 | nn.init.xavier_uniform_(m.weight) 47 | if isinstance(m, nn.Linear) and m.bias is not None: 48 | nn.init.constant_(m.bias, 0) 49 | elif isinstance(m, nn.LayerNorm): 50 | nn.init.constant_(m.bias, 0) 51 | nn.init.constant_(m.weight, 1.0) 52 | 53 | def forward_features(self, x): 54 | for block in self.blocks: 55 | x = block(x) 56 | x = self.norm(x) 57 | return x 58 | 59 | def forward(self, x: Tensor, mask: Tensor) -> Tensor: 60 | # mask: (B, T, N) with boolean values, 0 -> masked, 1 -> visible 61 | assert len(x.shape) == 5, "x must be 5D" 62 | emb = self.patch_embedding(x) 63 | emb = self.pos_embedding(emb) 64 | b, _, c = emb.shape 65 | emb = emb[mask].view(b, -1, c) # only visible patches are used 66 | emb = self.forward_features(emb) 67 | return emb 68 | 69 | def extract_features(self, x: Tensor, seq_mean_pool: bool) -> Tensor: 70 | x = self.patch_embedding(x) 71 | x = self.pos_embedding(x) 72 | for block in self.blocks: 73 | x = block(x) 74 | 75 | if seq_mean_pool: 76 | x = x.mean(dim=1) 77 | x = self.norm(x) 78 | return x 79 | -------------------------------------------------------------------------------- /hf_src/marlin_huggingface/marlin.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import Linear, Module 6 | from transformers import PreTrainedModel 7 | 8 | from .encoder import MarlinEncoder 9 | from .decoder import MarlinDecoder 10 | 11 | from .config import MarlinConfig 12 | 13 | 14 | class Marlin(Module): 15 | def __init__( 16 | self, 17 | img_size: int, 18 | patch_size: int, 19 | n_frames: int, 20 | encoder_embed_dim: int, 21 | encoder_depth: int, 22 | encoder_num_heads: int, 23 | decoder_embed_dim: int, 24 | decoder_depth: int, 25 | decoder_num_heads: int, 26 | mlp_ratio: float, 27 | qkv_bias: bool, 28 | qk_scale: Optional[float], 29 | drop_rate: float, 30 | attn_drop_rate: float, 31 | norm_layer: str, 32 | init_values: float, 33 | tubelet_size: int, 34 | as_feature_extractor: bool = True, 35 | ): 36 | super().__init__() 37 | self.encoder = MarlinEncoder( 38 | img_size=img_size, 39 | patch_size=patch_size, 40 | n_frames=n_frames, 41 | embed_dim=encoder_embed_dim, 42 | depth=encoder_depth, 43 | num_heads=encoder_num_heads, 44 | mlp_ratio=mlp_ratio, 45 | qkv_bias=qkv_bias, 46 | qk_scale=qk_scale, 47 | drop_rate=drop_rate, 48 | attn_drop_rate=attn_drop_rate, 49 | norm_layer=norm_layer, 50 | init_values=init_values, 51 | tubelet_size=tubelet_size, 52 | ) 53 | self.as_feature_extractor = as_feature_extractor 54 | self.clip_frames = n_frames 55 | if as_feature_extractor: 56 | self.enc_dec_proj = None 57 | self.decoder = None 58 | else: 59 | self.decoder = MarlinDecoder( 60 | img_size=img_size, 61 | patch_size=patch_size, 62 | embed_dim=decoder_embed_dim, 63 | depth=decoder_depth, 64 | num_heads=decoder_num_heads, 65 | mlp_ratio=mlp_ratio, 66 | 
qkv_bias=qkv_bias, 67 | qk_scale=qk_scale, 68 | drop_rate=drop_rate, 69 | attn_drop_rate=attn_drop_rate, 70 | norm_layer=norm_layer, 71 | init_values=init_values, 72 | tubelet_size=tubelet_size, 73 | ) 74 | 75 | self.enc_dec_proj = Linear(encoder_embed_dim, decoder_embed_dim, bias=False) 76 | 77 | def forward(self, x: Tensor, mask: Tensor = None) -> Tensor: 78 | if self.as_feature_extractor: 79 | raise RuntimeError( 80 | "For feature extraction, please use `extract_features` or `extract_video`." 81 | ) 82 | else: 83 | assert mask is not None 84 | x = self.encoder(x, mask) 85 | x = self.enc_dec_proj(x) 86 | x = self.decoder(x, mask) 87 | return x 88 | 89 | @property 90 | def device(self): 91 | return self.encoder.norm.weight.device 92 | 93 | def extract_features(self, x: Tensor, keep_seq: bool = True): 94 | """Extract features for one video clip (v)""" 95 | if self.training: 96 | return self.encoder.extract_features(x, seq_mean_pool=not keep_seq) 97 | else: 98 | with torch.no_grad(): 99 | return self.encoder.extract_features(x, seq_mean_pool=not keep_seq) 100 | 101 | 102 | class MarlinModel(PreTrainedModel): 103 | config_class = MarlinConfig 104 | 105 | def __init__(self, config: MarlinConfig): 106 | super().__init__(config) 107 | self.config = config 108 | self.marlin = Marlin( 109 | img_size=config.img_size, 110 | patch_size=config.patch_size, 111 | n_frames=config.n_frames, 112 | encoder_embed_dim=config.encoder_embed_dim, 113 | encoder_depth=config.encoder_depth, 114 | encoder_num_heads=config.encoder_num_heads, 115 | decoder_embed_dim=config.decoder_embed_dim, 116 | decoder_depth=config.decoder_depth, 117 | decoder_num_heads=config.decoder_num_heads, 118 | mlp_ratio=config.mlp_ratio, 119 | qkv_bias=config.qkv_bias, 120 | qk_scale=config.qk_scale, 121 | drop_rate=config.drop_rate, 122 | attn_drop_rate=config.attn_drop_rate, 123 | norm_layer=config.norm_layer, 124 | init_values=config.init_values, 125 | tubelet_size=config.tubelet_size, 126 | ) 127 | 128 | def forward(self, x: Tensor, keep_seq: bool = True): 129 | return self.marlin.extract_features(x, keep_seq=keep_seq) 130 | -------------------------------------------------------------------------------- /hf_src/marlin_huggingface/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from typing import Union, Optional, Callable, Tuple, List, Sequence 4 | 5 | import torch 6 | from einops.layers.torch import Rearrange 7 | from torch import Tensor, nn, Size 8 | from torch.nn import Conv3d, ModuleList 9 | from torch.nn import functional as F 10 | 11 | Shape = Union[Size, List[int], Tuple[int, ...]] 12 | ModuleFactory = Union[Callable[[], nn.Module], Callable[[int], nn.Module]] 13 | 14 | 15 | class PatchEmbedding3d(nn.Module): 16 | 17 | def __init__(self, input_size: Shape, patch_size: Union[int, Shape], embedding: int, 18 | strides: Optional[Union[int, Shape]] = None, 19 | build_normalization: Optional[ModuleFactory] = None 20 | ): 21 | super().__init__() 22 | # channel, time, height, width 23 | c, t, h, w = input_size 24 | # patch_time, patch_height, patch_width 25 | pt, ph, pw = (patch_size, patch_size, patch_size) if type(patch_size) is int else patch_size 26 | 27 | # configure the strides for conv3d 28 | if strides is None: 29 | # no specified means no overlap and gap between patches 30 | strides = (pt, ph, pw) 31 | elif type(strides) is int: 32 | # transform the side length of strides to 3D 33 | strides = (strides, strides, strides) 34 | 35 | 
self.projection = Conv3d(c, embedding, kernel_size=(pt, ph, pw), stride=strides) 36 | self.has_norm = build_normalization is not None 37 | if self.has_norm: 38 | self.normalization = build_normalization() 39 | self.rearrange = Rearrange("b d nt nh nw -> b (nt nh nw) d") 40 | 41 | def forward(self, x: Tensor) -> Tensor: 42 | x = self.projection(x) 43 | x = self.rearrange(x) 44 | if self.has_norm: 45 | x = self.normalization(x) 46 | return x 47 | 48 | 49 | class Linear(nn.Module): 50 | 51 | def __init__(self, in_features: int, out_features: int, bias: bool = True, 52 | build_activation: Optional[ModuleFactory] = None, 53 | build_normalization: Optional[ModuleFactory] = None, 54 | normalization_after_activation: bool = False, 55 | dropout_rate: float = 0. 56 | ): 57 | super().__init__() 58 | self.linear = nn.Linear(in_features, out_features, bias) 59 | 60 | self.has_act = build_activation is not None 61 | if self.has_act: 62 | self.activation = build_activation() 63 | else: 64 | self.activation = None 65 | 66 | self.has_norm = build_normalization is not None 67 | if self.has_norm: 68 | self.normalization = build_normalization() 69 | self.norm_after_act = normalization_after_activation 70 | else: 71 | self.normalization = None 72 | 73 | self.has_dropout = dropout_rate > 0 74 | if self.has_dropout: 75 | self.dropout = nn.Dropout(dropout_rate) 76 | 77 | def forward(self, x: Tensor) -> Tensor: 78 | x = self.linear(x) 79 | if self.has_act and self.has_norm: 80 | if self.norm_after_act: 81 | x = self.activation(x) 82 | x = self.normalization(x) 83 | else: 84 | x = self.normalization(x) 85 | x = self.activation(x) 86 | elif self.has_act and not self.has_norm: 87 | x = self.activation(x) 88 | elif not self.has_act and self.has_norm: 89 | x = self.normalization(x) 90 | 91 | if self.has_dropout: 92 | x = self.dropout(x) 93 | return x 94 | 95 | 96 | class MLP(nn.Module): 97 | 98 | def __init__(self, neurons: Sequence[int], 99 | build_activation: Optional[ModuleFactory] = None, dropout_rate: float = 0. 
100 | ): 101 | super().__init__() 102 | n_features = neurons[1:] 103 | self.layers: ModuleList[Linear] = ModuleList( 104 | [Linear(neurons[i], neurons[i + 1], True, build_activation, None, 105 | False, dropout_rate 106 | ) for i in range(len(n_features) - 1) 107 | ] + [ 108 | Linear(neurons[-2], neurons[-1], True) 109 | ] 110 | ) 111 | 112 | def forward(self, x: Tensor) -> Tensor: 113 | for layer in self.layers: 114 | x = layer(x) 115 | return x 116 | 117 | 118 | class Attention(nn.Module): 119 | 120 | def __init__( 121 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., 122 | proj_drop=0., attn_head_dim=None 123 | ): 124 | super().__init__() 125 | self.num_heads = num_heads 126 | head_dim = dim // num_heads 127 | if attn_head_dim is not None: 128 | head_dim = attn_head_dim 129 | all_head_dim = head_dim * self.num_heads 130 | self.scale = qk_scale or head_dim ** -0.5 131 | 132 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) 133 | if qkv_bias: 134 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) 135 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) 136 | else: 137 | self.q_bias = None 138 | self.v_bias = None 139 | 140 | self.attn_drop = nn.Dropout(attn_drop) 141 | self.proj = nn.Linear(all_head_dim, dim) 142 | self.proj_drop = nn.Dropout(proj_drop) 143 | 144 | def forward(self, x): 145 | B, N, C = x.shape 146 | qkv_bias = None 147 | if self.q_bias is not None: 148 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) 149 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 150 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 151 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 152 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 153 | 154 | q = q * self.scale 155 | attn = (q @ k.transpose(-2, -1)) 156 | 157 | attn = attn.softmax(dim=-1) 158 | attn = self.attn_drop(attn) 159 | 160 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 161 | x = self.proj(x) 162 | x = self.proj_drop(x) 163 | return x 164 | 165 | 166 | class Block(nn.Module): 167 | 168 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 169 | init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, 170 | attn_head_dim=None 171 | ): 172 | super().__init__() 173 | self.norm1 = norm_layer(dim) 174 | self.attn = Attention( 175 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 176 | attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim) 177 | self.norm2 = norm_layer(dim) 178 | mlp_hidden_dim = int(dim * mlp_ratio) 179 | self.mlp = MLP( 180 | neurons=[dim, mlp_hidden_dim, dim], 181 | build_activation=act_layer, 182 | dropout_rate=drop 183 | ) 184 | 185 | if init_values > 0: 186 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) 187 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) 188 | else: 189 | self.gamma_1, self.gamma_2 = None, None 190 | 191 | def forward(self, x): 192 | if self.gamma_1 is None: 193 | x = x + self.attn(self.norm1(x)) 194 | x = x + self.mlp(self.norm2(x)) 195 | else: 196 | x = x + (self.gamma_1 * self.attn(self.norm1(x))) 197 | x = x + (self.gamma_2 * self.mlp(self.norm2(x))) 198 | return x 199 | 200 | 201 | def no_grad_trunc_normal_(tensor, mean, std, a, b): 202 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 203 | 
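# The idea: draw uniform samples between norm_cdf(a) and norm_cdf(b), then push them
    # through the inverse error function, which yields samples from a normal(mean, std)
    # truncated to [a, b] without tracking gradients.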
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 204 | def norm_cdf(x): 205 | # Computes standard normal cumulative distribution function 206 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 207 | 208 | if (mean < a - 2 * std) or (mean > b + 2 * std): 209 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 210 | "The distribution of values may be incorrect.", 211 | stacklevel=2) 212 | 213 | with torch.no_grad(): 214 | # Values are generated by using a truncated uniform distribution and 215 | # then using the inverse CDF for the normal distribution. 216 | # Get upper and lower cdf values 217 | l = norm_cdf((a - mean) / std) 218 | u = norm_cdf((b - mean) / std) 219 | 220 | # Uniformly fill tensor with values from [l, u], then translate to 221 | # [2l-1, 2u-1]. 222 | tensor.uniform_(2 * l - 1, 2 * u - 1) 223 | 224 | # Use inverse cdf transform for normal distribution to get truncated 225 | # standard normal 226 | tensor.erfinv_() 227 | 228 | # Transform to proper mean, std 229 | tensor.mul_(std * math.sqrt(2.)) 230 | tensor.add_(mean) 231 | 232 | # Clamp to ensure it's in the proper range 233 | tensor.clamp_(min=a, max=b) 234 | return tensor 235 | -------------------------------------------------------------------------------- /hf_src/marlin_huggingface/positional_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | 4 | from .modules import Shape 5 | 6 | 7 | class PositionalEmbedding(nn.Module): 8 | 9 | def __init__(self, input_shape: Shape, dropout_rate: float = 0.5, trainable: bool = True): 10 | super().__init__() 11 | self.input_shape = input_shape 12 | self.emb = nn.Parameter(torch.zeros(1, *input_shape), requires_grad=trainable) 13 | self.use_dropout = dropout_rate is not None and dropout_rate != 0. 
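# Dropout is only instantiated for a non-zero rate; the MARLIN encoder and decoder
        # construct this embedding with dropout_rate=0., so the positional term is added as-is.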
14 | if self.use_dropout: 15 | self.dropout = nn.Dropout(dropout_rate) 16 | 17 | def forward(self, x: Tensor) -> Tensor: 18 | x = x + self.emb 19 | if self.use_dropout: 20 | x = self.dropout(x) 21 | return x 22 | 23 | @property 24 | def trainable(self): 25 | return self.emb.requires_grad 26 | 27 | @trainable.setter 28 | def trainable(self, value: bool): 29 | self.emb.requires_grad = value 30 | 31 | 32 | class SinCosPositionalEmbedding(PositionalEmbedding): 33 | 34 | def __init__(self, input_shape: Shape, dropout_rate: float = 0.5): 35 | super().__init__(input_shape, dropout_rate, trainable=False) 36 | self.emb.data = self.make_embedding().unsqueeze(0) 37 | 38 | def make_embedding(self) -> Tensor: 39 | n_position, d_hid = self.input_shape 40 | 41 | def get_position_angle_vec(position): 42 | return position / torch.tensor(10000).pow( 43 | 2 * torch.div(torch.arange(d_hid), 2, rounding_mode='trunc') / d_hid) 44 | 45 | sinusoid_table = torch.stack([get_position_angle_vec(pos_i) for pos_i in range(n_position)], 0) 46 | sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2]) # dim 2i 47 | sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2]) # dim 2i+1 48 | 49 | return sinusoid_table.float() 50 | -------------------------------------------------------------------------------- /hf_src/publish_hf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("--model", type=str, default="small", choices=["small", "base", "large"]) 6 | args = parser.parse_args() 7 | 8 | match args.model: 9 | case "small": 10 | from marlin_configs.vit_small import vit_small_config as config 11 | case "base": 12 | from marlin_configs.vit_base import vit_base_config as config 13 | case "large": 14 | from marlin_configs.vit_large import vit_large_config as config 15 | from marlin_huggingface import MarlinModel 16 | 17 | WEIGHT_FILE = f"./marlin_vit_{args.model}_ytf.encoder.pt" 18 | 19 | model = MarlinModel(config) 20 | state_dict = torch.load(WEIGHT_FILE, map_location='cpu') 21 | model.marlin.load_state_dict(state_dict) 22 | 23 | model.save_pretrained( 24 | f"marlin_vit_{args.model}_ytf", 25 | config=config, 26 | safe_serialization=True, 27 | ) 28 | model.push_to_hub(f"marlin_vit_{args.model}_ytf") 29 | -------------------------------------------------------------------------------- /init.py: -------------------------------------------------------------------------------- 1 | def read_version() -> str: 2 | with open("version.txt", "r") as file: 3 | version = file.read() 4 | return version 5 | 6 | 7 | def write_version(version: str) -> None: 8 | with open("src/marlin_pytorch/version.txt", "w") as file: 9 | file.write(version) 10 | 11 | 12 | def init_version() -> None: 13 | version = read_version() 14 | write_version(version) 15 | 16 | 17 | if __name__ == '__main__': 18 | init_version() 19 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/model/__init__.py -------------------------------------------------------------------------------- /model/classifier.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Sequence, Dict, Literal, Any 2 | 3 | from pytorch_lightning import LightningModule 4 | 
from torch import Tensor 5 | from torch.nn import CrossEntropyLoss, Linear, Identity, BCEWithLogitsLoss 6 | from torch.optim import Adam 7 | from torch.optim.lr_scheduler import ReduceLROnPlateau 8 | from torchmetrics import Accuracy, AUROC 9 | 10 | from marlin_pytorch import Marlin 11 | from marlin_pytorch.config import resolve_config 12 | 13 | 14 | class Classifier(LightningModule): 15 | 16 | def __init__(self, num_classes: int, backbone: str, finetune: bool, 17 | marlin_ckpt: Optional[str] = None, 18 | task: Literal["binary", "multiclass", "multilabel"] = "binary", 19 | learning_rate: float = 1e-4, distributed: bool = False 20 | ): 21 | super().__init__() 22 | self.save_hyperparameters() 23 | 24 | if finetune: 25 | if marlin_ckpt is None: 26 | self.model = Marlin.from_online(backbone).encoder 27 | else: 28 | self.model = Marlin.from_file(backbone, marlin_ckpt).encoder 29 | else: 30 | self.model = None 31 | 32 | config = resolve_config(backbone) 33 | 34 | self.fc = Linear(config.encoder_embed_dim, num_classes) 35 | self.learning_rate = learning_rate 36 | self.distributed = distributed 37 | self.task = task 38 | if task in "binary": 39 | self.loss_fn = BCEWithLogitsLoss() 40 | self.acc_fn = Accuracy(task=task, num_classes=1) 41 | self.auc_fn = AUROC(task=task, num_classes=1) 42 | elif task == "multiclass": 43 | self.loss_fn = CrossEntropyLoss() 44 | self.acc_fn = Accuracy(task=task, num_classes=num_classes) 45 | self.auc_fn = AUROC(task=task, num_classes=num_classes) 46 | elif task == "multilabel": 47 | self.loss_fn = BCEWithLogitsLoss() 48 | self.acc_fn = Accuracy(task="binary", num_classes=1) 49 | self.auc_fn = AUROC(task="binary", num_classes=1) 50 | 51 | @classmethod 52 | def from_module(cls, model, learning_rate: float = 1e-4, distributed=False): 53 | return cls(model, learning_rate, distributed) 54 | 55 | def forward(self, x): 56 | if self.model is not None: 57 | feat = self.model.extract_features(x, True) 58 | else: 59 | feat = x 60 | return self.fc(feat) 61 | 62 | def step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]]) -> Dict[str, Tensor]: 63 | x, y = batch 64 | y_hat = self(x) 65 | if self.task == "multilabel": 66 | y_hat = y_hat.flatten() 67 | y = y.flatten() 68 | loss = self.loss_fn(y_hat, y.float()) 69 | prob = y_hat.sigmoid() 70 | acc = self.acc_fn(prob, y) 71 | auc = self.auc_fn(prob, y) 72 | return {"loss": loss, "acc": acc, "auc": auc} 73 | 74 | def training_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None, 75 | optimizer_idx: Optional[int] = None, hiddens: Optional[Tensor] = None 76 | ) -> Dict[str, Tensor]: 77 | loss_dict = self.step(batch) 78 | self.log_dict({f"train_{k}": v for k, v in loss_dict.items()}, on_step=True, on_epoch=True, 79 | prog_bar=False, sync_dist=self.distributed) 80 | return loss_dict["loss"] 81 | 82 | def validation_step(self, batch: Optional[Union[Tensor, Sequence[Tensor]]] = None, batch_idx: Optional[int] = None, 83 | dataloader_idx: Optional[int] = None 84 | ) -> Dict[str, Tensor]: 85 | loss_dict = self.step(batch) 86 | self.log_dict({f"val_{k}": v for k, v in loss_dict.items()}, on_step=True, on_epoch=True, 87 | prog_bar=True, sync_dist=self.distributed) 88 | return loss_dict["loss"] 89 | 90 | def predict_step(self, batch: Tensor, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any: 91 | return self(batch[0]) 92 | 93 | def configure_optimizers(self): 94 | optimizer = Adam(self.parameters(), lr=self.learning_rate, betas=(0.5, 0.9)) 95 | return { 96 | "optimizer": optimizer, 97 
| "lr_scheduler": { 98 | "scheduler": ReduceLROnPlateau(optimizer, factor=0.5, patience=7, verbose=True, min_lr=1e-8), 99 | "monitor": "train_loss" 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /preprocess/celebvhq_extract.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | from tqdm.auto import tqdm 9 | 10 | from marlin_pytorch import Marlin 11 | from marlin_pytorch.config import resolve_config 12 | 13 | sys.path.append(".") 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser("CelebV-HQ Feature Extraction") 17 | parser.add_argument("--backbone", type=str) 18 | parser.add_argument("--data_dir", type=str) 19 | args = parser.parse_args() 20 | 21 | model = Marlin.from_online(args.backbone) 22 | config = resolve_config(args.backbone) 23 | feat_dir = args.backbone 24 | 25 | model.cuda() 26 | model.eval() 27 | 28 | raw_video_path = os.path.join(args.data_dir, "cropped") 29 | all_videos = sorted(list(filter(lambda x: x.endswith(".mp4"), os.listdir(raw_video_path)))) 30 | Path(os.path.join(args.data_dir, feat_dir)).mkdir(parents=True, exist_ok=True) 31 | for video_name in tqdm(all_videos): 32 | video_path = os.path.join(raw_video_path, video_name) 33 | save_path = os.path.join(args.data_dir, feat_dir, video_name.replace(".mp4", ".npy")) 34 | try: 35 | feat = model.extract_video( 36 | video_path, crop_face=False, 37 | sample_rate=config.tubelet_size, stride=config.n_frames, 38 | keep_seq=False, reduction="none") 39 | 40 | except Exception as e: 41 | print(f"Video {video_path} error.", e) 42 | feat = torch.zeros(0, model.encoder.embed_dim, dtype=torch.float32) 43 | np.save(save_path, feat.cpu().numpy()) 44 | -------------------------------------------------------------------------------- /preprocess/celebvhq_preprocess.py: -------------------------------------------------------------------------------- 1 | # parsing labels, segment and crop raw videos. 
2 | import argparse 3 | import os 4 | import sys 5 | 6 | sys.path.append(os.getcwd()) 7 | 8 | 9 | def crop_face(root: str): 10 | from util.face_sdk.face_crop import process_videos 11 | source_dir = os.path.join(root, "downloaded") 12 | target_dir = os.path.join(root, "cropped") 13 | process_videos(source_dir, target_dir, ext="mp4") 14 | 15 | 16 | def gen_split(root: str): 17 | videos = list(filter(lambda x: x.endswith('.mp4'), os.listdir(os.path.join(root, 'cropped')))) 18 | total_num = len(videos) 19 | 20 | with open(os.path.join(root, "train.txt"), "w") as f: 21 | for i in range(int(total_num * 0.8)): 22 | f.write(videos[i][:-4] + "\n") 23 | 24 | with open(os.path.join(root, "val.txt"), "w") as f: 25 | for i in range(int(total_num * 0.8), int(total_num * 0.9)): 26 | f.write(videos[i][:-4] + "\n") 27 | 28 | with open(os.path.join(root, "test.txt"), "w") as f: 29 | for i in range(int(total_num * 0.9), total_num): 30 | f.write(videos[i][:-4] + "\n") 31 | 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("--data_dir", help="Root directory of CelebV-HQ") 35 | args = parser.parse_args() 36 | 37 | if __name__ == '__main__': 38 | data_root = args.data_dir 39 | crop_face(data_root) 40 | 41 | if not os.path.exists(os.path.join(data_root, "train.txt")) or \ 42 | not os.path.exists(os.path.join(data_root, "val.txt")) or \ 43 | not os.path.exists(os.path.join(data_root, "test.txt")): 44 | gen_split(data_root) 45 | -------------------------------------------------------------------------------- /preprocess/ytf_preprocess.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import sys 3 | import argparse 4 | import shutil 5 | 6 | parser = argparse.ArgumentParser("Preprocess YTF dataset") 7 | parser.add_argument("--data_dir", type=str) 8 | parser.add_argument("--max_workers", type=int, default=8) 9 | 10 | if __name__ == '__main__': 11 | 12 | args = parser.parse_args() 13 | 14 | # copy the metadata (split) to the data_dir 15 | shutil.copy(os.path.join(os.path.dirname(__file__), "..", "dataset", "misc", "youtube_face", "train_set.csv"), 16 | args.data_dir) 17 | shutil.copy(os.path.join(os.path.dirname(__file__), "..", "dataset", "misc", "youtube_face", "val_set.csv"), 18 | args.data_dir) 19 | 20 | # Crop faces from videos 21 | sys.path.append(".") 22 | if not os.path.exists("logs"): 23 | os.mkdir("logs") 24 | 25 | from util.face_sdk.face_crop import process_images as face_crop_process_images 26 | face_crop_process_images( 27 | os.path.join(args.data_dir, "frame_images_DB"), 28 | os.path.join(args.data_dir, "crop_images_DB"), 29 | args.max_workers, 30 | ) 31 | 32 | # Face parsing based on these cropped faces 33 | from util.face_sdk.face_parse import process_images as face_parse_process_images 34 | face_parse_process_images( 35 | os.path.join(args.data_dir, "crop_images_DB"), 36 | os.path.join(args.data_dir, "face_parsing_images_DB") 37 | ) 38 | -------------------------------------------------------------------------------- /requirements.lib.txt: -------------------------------------------------------------------------------- 1 | torch >= 1.8.0 2 | torchvision >= 0.9.0 3 | numpy >= 1.10 4 | einops >= 0.1 5 | ffmpeg-python >= 0.2.0 6 | opencv-python >= 4.3 7 | av >= 6.0 8 | tqdm >= 4.0 9 | pyyaml >= 5.0 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python>=4.6 2 | numpy>=1.23 3 | 
einops~=0.4 4 | torchvision>=0.12.0 5 | torch 6 | pyyaml>=6.0 7 | tqdm>=4.64.0 8 | scikit-image>=0.19.3 9 | matplotlib>=3.5.2 10 | pillow>=9.2.0 11 | pandas~=1.4.3 12 | marlin_pytorch==0.3.4 13 | pytorch_lightning==1.7.* 14 | ffmpeg-python>=0.2.0 15 | torchmetrics == 0.11.* 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import init 2 | import setuptools 3 | 4 | with open("README.md", "r", encoding="UTF-8") as file: 5 | long_description = file.read() 6 | 7 | requirements = [] 8 | with open("requirements.lib.txt", "r", encoding="UTF-8") as file: 9 | for line in file: 10 | requirements.append(line.strip()) 11 | 12 | 13 | version = init.read_version() 14 | init.write_version(version) 15 | 16 | setuptools.setup( 17 | name="marlin_pytorch", 18 | version=version, 19 | author="ControlNet", 20 | author_email="smczx@hotmail.com", 21 | description="Official pytorch implementation for MARLIN.", 22 | long_description=long_description, 23 | long_description_content_type="text/markdown", 24 | url="https://github.com/ControlNet/MARLIN", 25 | project_urls={ 26 | "Bug Tracker": "https://github.com/ControlNet/MARLIN/issues", 27 | "Source Code": "https://github.com/ControlNet/MARLIN", 28 | }, 29 | keywords=["deep learning", "pytorch", "AI"], 30 | package_dir={"": "src"}, 31 | packages=setuptools.find_packages(where="src", include=["marlin_pytorch", "marlin_pytorch.*"]), 32 | package_data={ 33 | "marlin_pytorch": [ 34 | "version.txt" 35 | ] 36 | }, 37 | python_requires=">=3.6", 38 | install_requires=requirements, 39 | license="CC BY-NC 4.0", 40 | classifiers=[ 41 | "Programming Language :: Python :: 3", 42 | "Programming Language :: Python :: 3.6", 43 | "Programming Language :: Python :: 3.7", 44 | "Programming Language :: Python :: 3.8", 45 | "Programming Language :: Python :: 3.9", 46 | "Programming Language :: Python :: 3.10", 47 | "Programming Language :: Python :: 3.11", 48 | "License :: Other/Proprietary License", 49 | "Operating System :: OS Independent", 50 | "Intended Audience :: Developers", 51 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 52 | "Topic :: Multimedia :: Video", 53 | ], 54 | ) 55 | -------------------------------------------------------------------------------- /src/marlin_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | from .model import Marlin 4 | 5 | __all__ = [ 6 | "Marlin", 7 | ] 8 | 9 | with open(os.path.join(os.path.dirname(__file__), "version.txt"), "r") as file: 10 | __version__ = file.read() 11 | -------------------------------------------------------------------------------- /src/marlin_pytorch/config.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Optional, Type, TypeVar 4 | 5 | from marlin_pytorch.util import read_yaml, Singleton, NoArgInit 6 | 7 | 8 | @dataclass 9 | class MarlinConfig: 10 | img_size: int 11 | patch_size: int 12 | n_frames: int 13 | encoder_embed_dim: int 14 | encoder_depth: int 15 | encoder_num_heads: int 16 | decoder_embed_dim: int 17 | decoder_depth: int 18 | decoder_num_heads: int 19 | mlp_ratio: float 20 | qkv_bias: bool 21 | qk_scale: Optional[float] 22 | drop_rate: float 23 | attn_drop_rate: float 24 | norm_layer: str 25 | init_values: float 26 | tubelet_size: int 27 | 28 | 29 | class 
Downloadable(ABC): 30 | 31 | @property 32 | @abstractmethod 33 | def full_model_url(self) -> str: 34 | pass 35 | 36 | @property 37 | @abstractmethod 38 | def encoder_model_url(self) -> str: 39 | pass 40 | 41 | 42 | T = TypeVar("T", bound=MarlinConfig) 43 | 44 | _configs = {} 45 | 46 | 47 | def register_model(name: str): 48 | def wrapper(cls: Type[T]): 49 | _configs[name] = cls 50 | return cls 51 | 52 | return wrapper 53 | 54 | 55 | class SharedConfig(MarlinConfig): 56 | img_size = 224 57 | patch_size = 16 58 | n_frames = 16 59 | mlp_ratio = 4. 60 | qkv_bias = True 61 | qk_scale = None 62 | drop_rate = 0. 63 | attn_drop_rate = 0. 64 | norm_layer = "LayerNorm" 65 | init_values = 0. 66 | tubelet_size = 2 67 | 68 | 69 | @register_model("marlin_vit_base_ytf") 70 | @Singleton 71 | class MarlinVitBaseConfig(NoArgInit, SharedConfig, Downloadable): 72 | encoder_embed_dim = 768 73 | encoder_depth = 12 74 | encoder_num_heads = 12 75 | decoder_embed_dim = 384 76 | decoder_depth = 4 77 | decoder_num_heads = 6 78 | full_model_url = "https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_base_ytf.full.pt" 79 | encoder_model_url = "https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_base_ytf.encoder.pt" 80 | 81 | 82 | @register_model("marlin_vit_small_ytf") 83 | @Singleton 84 | class MarlinVitSmallConfig(NoArgInit, SharedConfig, Downloadable): 85 | encoder_embed_dim = 384 86 | encoder_depth = 12 87 | encoder_num_heads = 6 88 | decoder_embed_dim = 192 89 | decoder_depth = 4 90 | decoder_num_heads = 3 91 | full_model_url = \ 92 | "https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_small_ytf.full.pt" 93 | encoder_model_url = \ 94 | "https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_small_ytf.encoder.pt" 95 | 96 | 97 | @register_model("marlin_vit_large_ytf") 98 | @Singleton 99 | class MarlinVitLargeConfig(NoArgInit, SharedConfig, Downloadable): 100 | encoder_embed_dim = 1024 101 | encoder_depth = 24 102 | encoder_num_heads = 16 103 | decoder_embed_dim = 512 104 | decoder_depth = 12 105 | decoder_num_heads = 8 106 | full_model_url = \ 107 | "https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_large_ytf.full.pt" 108 | encoder_model_url = \ 109 | "https://github.com/ControlNet/MARLIN/releases/download/model_v1/marlin_vit_large_ytf.encoder.pt" 110 | 111 | 112 | def register_model_from_yaml(name: str, path: str) -> None: 113 | config = read_yaml(path) 114 | marlin_config = MarlinConfig( 115 | img_size=config["img_size"], 116 | patch_size=config["patch_size"], 117 | n_frames=config["clip_frames"], 118 | encoder_embed_dim=config["encoder"]["embed_dim"], 119 | encoder_depth=config["encoder"]["depth"], 120 | encoder_num_heads=config["encoder"]["num_heads"], 121 | decoder_embed_dim=config["decoder"]["embed_dim"], 122 | decoder_depth=config["decoder"]["depth"], 123 | decoder_num_heads=config["decoder"]["num_heads"], 124 | mlp_ratio=config["mlp_ratio"], 125 | qkv_bias=config["qkv_bias"], 126 | qk_scale=config["qk_scale"], 127 | drop_rate=config["drop_rate"], 128 | attn_drop_rate=config["attn_drop_rate"], 129 | norm_layer=config["norm_layer"], 130 | init_values=config["init_values"], 131 | tubelet_size=config["tubelet_size"] 132 | ) 133 | _configs[name] = marlin_config 134 | 135 | 136 | def resolve_config(name: str) -> MarlinConfig: 137 | if name in _configs: 138 | return _configs[name] 139 | else: 140 | raise ValueError(f"Model {name} not found. Please register it first. 
The current registered models are: " 141 | f"{_configs.keys()}") 142 | -------------------------------------------------------------------------------- /src/marlin_pytorch/face_detector.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os.path 3 | import sys 4 | from typing import Optional, Union, Tuple 5 | 6 | import cv2 7 | from numpy import ndarray 8 | 9 | from .util import read_yaml, Singleton, crop_with_padding 10 | 11 | 12 | @Singleton 13 | class FaceXZooFaceDetector: 14 | 15 | def __init__(self): 16 | self.faceDetModelHandler = None 17 | self.inited = False 18 | 19 | def init(self, face_sdk_path: Optional[str] = None, device: str = "cuda:0"): 20 | if face_sdk_path is not None: 21 | sys.path.append(face_sdk_path) 22 | else: 23 | if os.path.exists("FaceX-Zoo"): 24 | face_sdk_path = "FaceX-Zoo" 25 | sys.path.append(face_sdk_path) 26 | try: 27 | from core.model_handler.face_detection.FaceDetModelHandler import FaceDetModelHandler 28 | from core.model_loader.face_detection.FaceDetModelLoader import FaceDetModelLoader 29 | except ImportError: 30 | raise ImportError("FaceX-Zoo cannot be imported, please specify the path to the face_sdk path of FaceXZoo" 31 | " or put it in the working directory.") 32 | 33 | model_conf = read_yaml(os.path.join(face_sdk_path, "config", "model_conf.yaml")) 34 | model_path = os.path.join(face_sdk_path, 'models') 35 | scene = 'non-mask' 36 | model_category = 'face_detection' 37 | model_name = model_conf[scene][model_category] 38 | 39 | faceDetModelLoader = FaceDetModelLoader(model_path, model_category, model_name) 40 | model, cfg = faceDetModelLoader.load_model() 41 | self.faceDetModelHandler = FaceDetModelHandler(model, device, cfg) 42 | self.inited = True 43 | 44 | @staticmethod 45 | def install(path: Optional[str] = None) -> str: 46 | """ 47 | Install FaceX-Zoo by clone from GitHub. 48 | 49 | Args: 50 | path (``str``, optional): The path to install FaceX-Zoo, default is "./FaceX-Zoo". 51 | 52 | Returns: 53 | ``str``: The path to the installed FaceX-Zoo. 
54 | 55 | """ 56 | path = path or "FaceX-Zoo" 57 | if os.path.exists(path): 58 | return path 59 | 60 | os.system(f"git clone --depth=1 https://github.com/ControlNet/FaceX-Zoo {path or ''}") 61 | return path 62 | 63 | def detect_face(self, image: ndarray): 64 | assert image.ndim == 3 and image.shape[2] == 3, "frame should be 3-dim" 65 | dets = self.faceDetModelHandler.inference_on_image(image) 66 | return dets 67 | 68 | def crop_face(self, frame, margin=1, x=0, y=0) -> Tuple[ndarray, int, int, int]: 69 | dets = self.detect_face(frame) 70 | if len(dets) > 0: 71 | x1, y1, x2, y2, confidence = dets[0] 72 | # center 73 | x, y = (int((x1 + x2) / 2), int((y1 + y2) / 2)) 74 | margin = int(max(abs(x2 - x1), abs(y2 - y1)) / 2) 75 | # crop face 76 | face = crop_with_padding(frame, x - margin, x + margin, y - margin, y + margin, 0) 77 | face = cv2.resize(face, (224, 224)) 78 | return face, margin, x, y 79 | 80 | def crop_image(self, image_path: str, out_path: str, max_faces=1, margin=0) -> None: 81 | if max_faces > 1: 82 | raise NotImplementedError("Multiple faces are not supported yet.") 83 | 84 | frame = cv2.imread(image_path) 85 | dets = self.detect_face(frame) 86 | for det in dets[:max_faces]: 87 | x1, y1, x2, y2, _ = det 88 | cropped = crop_with_padding(frame, int(x1 - margin), int(x2 + margin), int(y1 - margin), int(y2 + margin)) 89 | cv2.imwrite(out_path, cropped) 90 | 91 | def crop_video(self, video_path: str, out_path: str, frame_size: Optional[Union[int, Tuple[int, int]]] = None, 92 | margin=0, fourcc="mp4v" 93 | ) -> None: 94 | video = cv2.VideoCapture(video_path) 95 | if not video.isOpened(): 96 | raise IOError("Cannot open video file: " + video_path) 97 | 98 | # infer frame size 99 | if frame_size is None: 100 | frame_size = self._infer_frame_size(video_path, margin) 101 | 102 | if type(frame_size) is int: 103 | frame_size = frame_size, frame_size 104 | 105 | fps = video.get(cv2.CAP_PROP_FPS) 106 | writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*fourcc), fps, frame_size) 107 | x1, y1, x2, y2 = 0, 0, 1, 1 108 | while True: 109 | ret, frame = video.read() 110 | if not ret: 111 | break 112 | dets = self.detect_face(frame) 113 | if len(dets) > 0: 114 | x1, y1, x2, y2, confidence = dets[0] 115 | # center 116 | x, y = int((x1 + x2) / 2), int((y1 + y2) / 2) 117 | side = int(max(abs(x2 - x1), abs(y2 - y1))) 118 | x1 = x - side // 2 119 | x2 = x + side // 2 120 | y1 = y - side // 2 121 | y2 = y + side // 2 122 | 123 | cropped = crop_with_padding(frame, int(x1 - margin), int(x2 + margin), int(y1 - margin), int(y2 + margin)) 124 | resized = cv2.resize(cropped, frame_size) 125 | writer.write(resized) 126 | video.release() 127 | writer.release() 128 | 129 | def crop_image_dir(self, image_dir: str, out_dir: str, pattern="*.jpg", *args, **kwargs) -> None: 130 | all_images = glob.glob(os.path.join(image_dir, pattern), root_dir=image_dir) 131 | for image_path in all_images: 132 | out_path = os.path.join(out_dir, image_path) 133 | self.crop_image(image_path, out_path, *args, **kwargs) 134 | 135 | def crop_video_dir(self, video_dir: str, out_dir: str, pattern="*.mp4", *args, **kwargs) -> None: 136 | all_videos = glob.glob(os.path.join(video_dir, pattern), root_dir=video_dir) 137 | for video_path in all_videos: 138 | out_path = os.path.join(out_dir, video_path) 139 | self.crop_video(video_path, out_path, *args, **kwargs) 140 | 141 | def _infer_frame_size(self, video_path: str, margin: int = 0 142 | ) -> Tuple[int, int]: 143 | video = cv2.VideoCapture(video_path) 144 | while True: 145 | ret, frame 
= video.read() 146 | if not ret: 147 | break 148 | dets = self.detect_face(frame) 149 | if len(dets) > 0: 150 | x1, y1, x2, y2, confidence = dets[0] 151 | # center 152 | side = int(max(abs(x2 - x1), abs(y2 - y1))) 153 | video.release() 154 | return side + 2 * margin, side + 2 * margin 155 | -------------------------------------------------------------------------------- /src/marlin_pytorch/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .marlin import Marlin 2 | 3 | __all__ = ["Marlin"] 4 | -------------------------------------------------------------------------------- /src/marlin_pytorch/model/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import nn, Tensor 4 | from torch.nn import LayerNorm, Linear, ModuleList 5 | 6 | from .modules import Block, no_grad_trunc_normal_ 7 | from .positional_embedding import SinCosPositionalEmbedding 8 | 9 | 10 | class MarlinDecoder(nn.Module): 11 | 12 | def __init__(self, img_size=224, patch_size=16, n_frames=16, embed_dim=384, depth=8, 13 | num_heads=6, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 14 | norm_layer="LayerNorm", init_values=1., tubelet_size=2 15 | ): 16 | super().__init__() 17 | output_dim = 3 * tubelet_size * patch_size * patch_size 18 | self.patch_size = patch_size 19 | self.tubelet_size = tubelet_size 20 | self.n_patch_h = img_size // patch_size 21 | self.n_patch_w = img_size // patch_size 22 | self.embed_dim = embed_dim 23 | if norm_layer == "LayerNorm": 24 | self.norm_layer = LayerNorm 25 | self.norm = self.norm_layer(embed_dim) 26 | else: 27 | raise NotImplementedError("Only LayerNorm is supported") 28 | 29 | # sine-cosine positional embeddings 30 | self.pos_embedding = SinCosPositionalEmbedding( 31 | (self.n_patch_h * self.n_patch_w * (n_frames // tubelet_size), embed_dim), dropout_rate=0.) 32 | self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 33 | 34 | self.blocks = ModuleList([ 35 | Block( 36 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 37 | drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=self.norm_layer, 38 | init_values=init_values 39 | ) for _ in range(depth)]) 40 | 41 | self.head = Linear(embed_dim, output_dim) 42 | self.apply(self._init_weights) 43 | no_grad_trunc_normal_(self.mask_token, mean=0., std=0.02, a=-0.02, b=0.02) 44 | 45 | @staticmethod 46 | def _init_weights(m): 47 | if isinstance(m, nn.Linear): 48 | nn.init.xavier_uniform_(m.weight) 49 | if isinstance(m, nn.Linear) and m.bias is not None: 50 | nn.init.constant_(m.bias, 0) 51 | elif isinstance(m, nn.LayerNorm): 52 | nn.init.constant_(m.bias, 0) 53 | nn.init.constant_(m.weight, 1.0) 54 | 55 | def unpatch_to_img(self, x: Tensor) -> Tensor: 56 | # x: (Batch, No. batches, Prod of cube size * C) 57 | x = rearrange(x, "b n (c p) -> b n p c", c=3) 58 | # x: (Batch, No. 
batches, Prod of cube size, C) 59 | x = rearrange(x, "b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)", p0=self.tubelet_size, 60 | p1=self.patch_size, p2=self.patch_size, h=self.n_patch_h, w=self.n_patch_w) 61 | # x: (B, C, T, H, W) 62 | return x 63 | 64 | def forward_features(self, x, return_token_num=0): 65 | for block in self.blocks: 66 | x = block(x) 67 | 68 | if return_token_num > 0: 69 | x = x[:, -return_token_num:] 70 | 71 | x = self.norm(x) 72 | x = self.head(x) 73 | # x: (B, N_mask, C) 74 | return x 75 | 76 | def forward(self, x, mask): 77 | # mask: 0 -> masked, 1 -> visible 78 | b, n, c = x.shape 79 | expand_pos_embed = self.pos_embedding.emb.data.expand(b, -1, -1) 80 | pos_emb_vis = expand_pos_embed[mask].view(b, -1, c) 81 | pos_emb_mask = expand_pos_embed[~mask].view(b, -1, c) 82 | x = torch.cat([x + pos_emb_vis, self.mask_token + pos_emb_mask], dim=1) 83 | 84 | mask_num = pos_emb_mask.shape[1] 85 | 86 | x = self.forward_features(x, return_token_num=mask_num) 87 | return x 88 | -------------------------------------------------------------------------------- /src/marlin_pytorch/model/encoder.py: -------------------------------------------------------------------------------- 1 | from torch import nn, Tensor 2 | from torch.nn import ModuleList, LayerNorm 3 | 4 | from .modules import PatchEmbedding3d, Block 5 | from .positional_embedding import SinCosPositionalEmbedding 6 | 7 | 8 | class MarlinEncoder(nn.Module): 9 | 10 | def __init__(self, img_size=224, patch_size=16, n_frames=16, embed_dim=768, depth=12, 11 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 12 | norm_layer="LayerNorm", init_values=0., tubelet_size=2 13 | ): 14 | super().__init__() 15 | 16 | self.embed_dim = embed_dim 17 | self.patch_embedding = PatchEmbedding3d( 18 | input_size=(3, n_frames, img_size, img_size), 19 | patch_size=(tubelet_size, patch_size, patch_size), 20 | embedding=embed_dim 21 | ) 22 | num_patches = (img_size // patch_size) * (img_size // patch_size) * (n_frames // tubelet_size) 23 | 24 | # sine-cosine positional embeddings 25 | self.pos_embedding = SinCosPositionalEmbedding((num_patches, embed_dim), dropout_rate=0.) 
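# With the ViT-Base defaults (img_size=224, patch_size=16, n_frames=16, tubelet_size=2)
        # num_patches = 14 * 14 * 8 = 1568, which matches the (B, 1568, embed_dim) feature
        # shape asserted in test/test_marlin_pytorch.py.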
26 | 27 | if norm_layer == "LayerNorm": 28 | self.norm_layer = LayerNorm 29 | self.norm = self.norm_layer(embed_dim) 30 | else: 31 | raise NotImplementedError("Only LayerNorm is supported") 32 | 33 | self.blocks = ModuleList([ 34 | Block( 35 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 36 | drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=self.norm_layer, 37 | init_values=init_values) 38 | for _ in range(depth) 39 | ]) 40 | 41 | self.apply(self._init_weights) 42 | 43 | @staticmethod 44 | def _init_weights(m): 45 | if isinstance(m, nn.Linear): 46 | nn.init.xavier_uniform_(m.weight) 47 | if isinstance(m, nn.Linear) and m.bias is not None: 48 | nn.init.constant_(m.bias, 0) 49 | elif isinstance(m, nn.LayerNorm): 50 | nn.init.constant_(m.bias, 0) 51 | nn.init.constant_(m.weight, 1.0) 52 | 53 | def forward_features(self, x): 54 | for block in self.blocks: 55 | x = block(x) 56 | x = self.norm(x) 57 | return x 58 | 59 | def forward(self, x: Tensor, mask: Tensor) -> Tensor: 60 | # mask: (B, T, N) with boolean values, 0 -> masked, 1 -> visible 61 | assert len(x.shape) == 5, "x must be 5D" 62 | emb = self.patch_embedding(x) 63 | emb = self.pos_embedding(emb) 64 | b, _, c = emb.shape 65 | emb = emb[mask].view(b, -1, c) # only visible patches are used 66 | emb = self.forward_features(emb) 67 | return emb 68 | 69 | def extract_features(self, x: Tensor, seq_mean_pool: bool) -> Tensor: 70 | x = self.patch_embedding(x) 71 | x = self.pos_embedding(x) 72 | for block in self.blocks: 73 | x = block(x) 74 | 75 | if seq_mean_pool: 76 | x = x.mean(dim=1) 77 | x = self.norm(x) 78 | return x 79 | -------------------------------------------------------------------------------- /src/marlin_pytorch/model/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from typing import Union, Optional, Callable, Tuple, List, Sequence 4 | 5 | import torch 6 | from einops.layers.torch import Rearrange 7 | from torch import Tensor, nn, Size 8 | from torch.nn import Conv3d, ModuleList 9 | from torch.nn import functional as F 10 | 11 | Shape = Union[Size, List[int], Tuple[int, ...]] 12 | ModuleFactory = Union[Callable[[], nn.Module], Callable[[int], nn.Module]] 13 | 14 | 15 | class PatchEmbedding3d(nn.Module): 16 | 17 | def __init__(self, input_size: Shape, patch_size: Union[int, Shape], embedding: int, 18 | strides: Optional[Union[int, Shape]] = None, 19 | build_normalization: Optional[ModuleFactory] = None 20 | ): 21 | super().__init__() 22 | # channel, time, height, width 23 | c, t, h, w = input_size 24 | # patch_time, patch_height, patch_width 25 | pt, ph, pw = (patch_size, patch_size, patch_size) if type(patch_size) is int else patch_size 26 | 27 | # configure the strides for conv3d 28 | if strides is None: 29 | # no specified means no overlap and gap between patches 30 | strides = (pt, ph, pw) 31 | elif type(strides) is int: 32 | # transform the side length of strides to 3D 33 | strides = (strides, strides, strides) 34 | 35 | self.projection = Conv3d(c, embedding, kernel_size=(pt, ph, pw), stride=strides) 36 | self.has_norm = build_normalization is not None 37 | if self.has_norm: 38 | self.normalization = build_normalization() 39 | self.rearrange = Rearrange("b d nt nh nw -> b (nt nh nw) d") 40 | 41 | def forward(self, x: Tensor) -> Tensor: 42 | x = self.projection(x) 43 | x = self.rearrange(x) 44 | if self.has_norm: 45 | x = self.normalization(x) 46 | return x 47 | 48 | 49 | class 
Linear(nn.Module): 50 | 51 | def __init__(self, in_features: int, out_features: int, bias: bool = True, 52 | build_activation: Optional[ModuleFactory] = None, 53 | build_normalization: Optional[ModuleFactory] = None, 54 | normalization_after_activation: bool = False, 55 | dropout_rate: float = 0. 56 | ): 57 | super().__init__() 58 | self.linear = nn.Linear(in_features, out_features, bias) 59 | 60 | self.has_act = build_activation is not None 61 | if self.has_act: 62 | self.activation = build_activation() 63 | else: 64 | self.activation = None 65 | 66 | self.has_norm = build_normalization is not None 67 | if self.has_norm: 68 | self.normalization = build_normalization() 69 | self.norm_after_act = normalization_after_activation 70 | else: 71 | self.normalization = None 72 | 73 | self.has_dropout = dropout_rate > 0 74 | if self.has_dropout: 75 | self.dropout = nn.Dropout(dropout_rate) 76 | 77 | def forward(self, x: Tensor) -> Tensor: 78 | x = self.linear(x) 79 | if self.has_act and self.has_norm: 80 | if self.norm_after_act: 81 | x = self.activation(x) 82 | x = self.normalization(x) 83 | else: 84 | x = self.normalization(x) 85 | x = self.activation(x) 86 | elif self.has_act and not self.has_norm: 87 | x = self.activation(x) 88 | elif not self.has_act and self.has_norm: 89 | x = self.normalization(x) 90 | 91 | if self.has_dropout: 92 | x = self.dropout(x) 93 | return x 94 | 95 | 96 | class MLP(nn.Module): 97 | 98 | def __init__(self, neurons: Sequence[int], 99 | build_activation: Optional[ModuleFactory] = None, dropout_rate: float = 0. 100 | ): 101 | super().__init__() 102 | n_features = neurons[1:] 103 | self.layers: ModuleList[Linear] = ModuleList( 104 | [Linear(neurons[i], neurons[i + 1], True, build_activation, None, 105 | False, dropout_rate 106 | ) for i in range(len(n_features) - 1) 107 | ] + [ 108 | Linear(neurons[-2], neurons[-1], True) 109 | ] 110 | ) 111 | 112 | def forward(self, x: Tensor) -> Tensor: 113 | for layer in self.layers: 114 | x = layer(x) 115 | return x 116 | 117 | 118 | class Attention(nn.Module): 119 | 120 | def __init__( 121 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., 122 | proj_drop=0., attn_head_dim=None 123 | ): 124 | super().__init__() 125 | self.num_heads = num_heads 126 | head_dim = dim // num_heads 127 | if attn_head_dim is not None: 128 | head_dim = attn_head_dim 129 | all_head_dim = head_dim * self.num_heads 130 | self.scale = qk_scale or head_dim ** -0.5 131 | 132 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) 133 | if qkv_bias: 134 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) 135 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) 136 | else: 137 | self.q_bias = None 138 | self.v_bias = None 139 | 140 | self.attn_drop = nn.Dropout(attn_drop) 141 | self.proj = nn.Linear(all_head_dim, dim) 142 | self.proj_drop = nn.Dropout(proj_drop) 143 | 144 | def forward(self, x): 145 | B, N, C = x.shape 146 | qkv_bias = None 147 | if self.q_bias is not None: 148 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) 149 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 150 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 151 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 152 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 153 | 154 | q = q * self.scale 155 | attn = (q @ k.transpose(-2, -1)) 156 | 157 | attn = attn.softmax(dim=-1) 158 | attn = 
self.attn_drop(attn) 159 | 160 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 161 | x = self.proj(x) 162 | x = self.proj_drop(x) 163 | return x 164 | 165 | 166 | class Block(nn.Module): 167 | 168 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 169 | init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, 170 | attn_head_dim=None 171 | ): 172 | super().__init__() 173 | self.norm1 = norm_layer(dim) 174 | self.attn = Attention( 175 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 176 | attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim) 177 | self.norm2 = norm_layer(dim) 178 | mlp_hidden_dim = int(dim * mlp_ratio) 179 | self.mlp = MLP( 180 | neurons=[dim, mlp_hidden_dim, dim], 181 | build_activation=act_layer, 182 | dropout_rate=drop 183 | ) 184 | 185 | if init_values > 0: 186 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) 187 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) 188 | else: 189 | self.gamma_1, self.gamma_2 = None, None 190 | 191 | def forward(self, x): 192 | if self.gamma_1 is None: 193 | x = x + self.attn(self.norm1(x)) 194 | x = x + self.mlp(self.norm2(x)) 195 | else: 196 | x = x + (self.gamma_1 * self.attn(self.norm1(x))) 197 | x = x + (self.gamma_2 * self.mlp(self.norm2(x))) 198 | return x 199 | 200 | 201 | def no_grad_trunc_normal_(tensor, mean, std, a, b): 202 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 203 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 204 | def norm_cdf(x): 205 | # Computes standard normal cumulative distribution function 206 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 207 | 208 | if (mean < a - 2 * std) or (mean > b + 2 * std): 209 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 210 | "The distribution of values may be incorrect.", 211 | stacklevel=2) 212 | 213 | with torch.no_grad(): 214 | # Values are generated by using a truncated uniform distribution and 215 | # then using the inverse CDF for the normal distribution. 216 | # Get upper and lower cdf values 217 | l = norm_cdf((a - mean) / std) 218 | u = norm_cdf((b - mean) / std) 219 | 220 | # Uniformly fill tensor with values from [l, u], then translate to 221 | # [2l-1, 2u-1]. 222 | tensor.uniform_(2 * l - 1, 2 * u - 1) 223 | 224 | # Use inverse cdf transform for normal distribution to get truncated 225 | # standard normal 226 | tensor.erfinv_() 227 | 228 | # Transform to proper mean, std 229 | tensor.mul_(std * math.sqrt(2.)) 230 | tensor.add_(mean) 231 | 232 | # Clamp to ensure it's in the proper range 233 | tensor.clamp_(min=a, max=b) 234 | return tensor 235 | -------------------------------------------------------------------------------- /src/marlin_pytorch/model/positional_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor, nn 3 | 4 | from .modules import Shape 5 | 6 | 7 | class PositionalEmbedding(nn.Module): 8 | 9 | def __init__(self, input_shape: Shape, dropout_rate: float = 0.5, trainable: bool = True): 10 | super().__init__() 11 | self.input_shape = input_shape 12 | self.emb = nn.Parameter(torch.zeros(1, *input_shape), requires_grad=trainable) 13 | self.use_dropout = dropout_rate is not None and dropout_rate != 0. 
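# self.emb has shape (1, *input_shape); the leading singleton dimension broadcasts over
        # the batch when the embedding is added to x in forward().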
14 | if self.use_dropout: 15 | self.dropout = nn.Dropout(dropout_rate) 16 | 17 | def forward(self, x: Tensor) -> Tensor: 18 | x = x + self.emb 19 | if self.use_dropout: 20 | x = self.dropout(x) 21 | return x 22 | 23 | @property 24 | def trainable(self): 25 | return self.emb.requires_grad 26 | 27 | @trainable.setter 28 | def trainable(self, value: bool): 29 | self.emb.requires_grad = value 30 | 31 | 32 | class SinCosPositionalEmbedding(PositionalEmbedding): 33 | 34 | def __init__(self, input_shape: Shape, dropout_rate: float = 0.5): 35 | super().__init__(input_shape, dropout_rate, trainable=False) 36 | self.emb.data = self.make_embedding().unsqueeze(0) 37 | 38 | def make_embedding(self) -> Tensor: 39 | n_position, d_hid = self.input_shape 40 | 41 | def get_position_angle_vec(position): 42 | return position / torch.tensor(10000).pow( 43 | 2 * torch.div(torch.arange(d_hid), 2, rounding_mode='trunc') / d_hid) 44 | 45 | sinusoid_table = torch.stack([get_position_angle_vec(pos_i) for pos_i in range(n_position)], 0) 46 | sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2]) # dim 2i 47 | sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2]) # dim 2i+1 48 | 49 | return sinusoid_table.float() 50 | -------------------------------------------------------------------------------- /src/marlin_pytorch/util.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any, TypeVar, Union 2 | from typing import Type, Dict 3 | 4 | import numpy as np 5 | import torchvision 6 | import yaml 7 | from einops import rearrange 8 | from numpy import ndarray 9 | from torch import Tensor 10 | from torch.nn import functional as F 11 | from tqdm.auto import tqdm 12 | 13 | 14 | def read_video(path: str, channel_first: bool = True): 15 | video, audio, info = torchvision.io.read_video(path) 16 | if channel_first: 17 | video = rearrange(video, 'T H W C -> T C H W') 18 | return video 19 | 20 | 21 | def read_yaml(path: str) -> Dict[str, Any]: 22 | with open(path, "r") as file: 23 | return yaml.load(file, Loader=yaml.Loader) 24 | 25 | 26 | def padding_video(tensor: Tensor, target: int, padding_method: str = "zero", padding_position: str = "tail") -> Tensor: 27 | t, c, h, w = tensor.shape 28 | padding_size = target - t 29 | 30 | pad = _get_padding_pair(padding_size, padding_position) 31 | 32 | if padding_method == "zero": 33 | return F.pad(tensor, pad=[0, 0, 0, 0, 0, 0] + pad) 34 | elif padding_method == "same": 35 | tensor = rearrange(tensor, "t c h w -> c h w t") 36 | tensor = F.pad(tensor, pad=pad + [0, 0], mode="replicate") 37 | return rearrange(tensor, "c h w t -> t c h w") 38 | else: 39 | raise ValueError("Wrong padding method. It should be zero or tail or average.") 40 | 41 | 42 | def _get_padding_pair(padding_size: int, padding_position: str) -> List[int]: 43 | if padding_position == "tail": 44 | pad = [0, padding_size] 45 | elif padding_position == "head": 46 | pad = [padding_size, 0] 47 | elif padding_position == "average": 48 | padding_head = padding_size // 2 49 | padding_tail = padding_size - padding_head 50 | pad = [padding_head, padding_tail] 51 | else: 52 | raise ValueError("Wrong padding position. 
It should be zero or tail or average.") 53 | return pad 54 | 55 | 56 | class DownloadProgressBar(tqdm): 57 | total: int 58 | 59 | def update_to(self, b=1, bsize=1, tsize=None): 60 | if tsize is not None: 61 | self.total = tsize 62 | self.update(b * bsize - self.n) 63 | 64 | 65 | T = TypeVar("T") 66 | 67 | 68 | class Singleton: 69 | all_instances: Dict[Type, object] = {} 70 | 71 | def __new__(cls, clazz: Type[T]) -> T: 72 | cls.all_instances[clazz] = clazz() 73 | return cls.all_instances[clazz] 74 | 75 | 76 | def crop_with_padding(image: ndarray, x1: int, x2: int, y1: int, y2: int, pad_value: Union[int, float] = 0., 77 | batch: bool = False 78 | ) -> ndarray: 79 | assert y2 > y1 and x2 > x1, "Should follow y2 > y1 and x2 > x1" 80 | 81 | if not batch: 82 | image = image[np.newaxis, ...] 83 | 84 | crop_shape = np.array([y2 - y1, x2 - x1]) 85 | 86 | if len(image.shape) == 3: 87 | b, h, w = image.shape 88 | cropped = np.full((b, *crop_shape), pad_value, dtype=image.dtype) 89 | elif len(image.shape) == 4: 90 | b, h, w, c = image.shape 91 | cropped = np.full((b, *crop_shape, c), pad_value, dtype=image.dtype) 92 | else: 93 | raise ValueError("Invalid shape, the image should be one of following shapes: ([B,] H, W) or ([B,] H, W, C)") 94 | 95 | # compute cropped index of image 96 | image_y_start, image_x_start = np.clip([y1, x1], 0, [h, w]) 97 | image_y_end, image_x_end = np.clip([y2, x2], 0, [h, w]) 98 | 99 | # compute target index of output 100 | crop_y_start, crop_x_start = np.clip([-y1, -x1], 0, crop_shape) 101 | crop_y_end, crop_x_end = crop_shape - np.clip([y2 - h, x2 - w], 0, crop_shape) 102 | 103 | # assign values 104 | cropped[:, crop_y_start:crop_y_end, crop_x_start:crop_x_end] = \ 105 | image[:, image_y_start:image_y_end, image_x_start:image_x_end] 106 | 107 | return cropped if batch else cropped[0] 108 | 109 | 110 | class NoArgInit: 111 | def __init__(self): 112 | pass 113 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/__init__.py -------------------------------------------------------------------------------- /test/input_sample/cropped01.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/input_sample/cropped01.mp4 -------------------------------------------------------------------------------- /test/input_sample/cropped02.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/input_sample/cropped02.mp4 -------------------------------------------------------------------------------- /test/input_sample/cropped03.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/input_sample/cropped03.mp4 -------------------------------------------------------------------------------- /test/input_sample/cropped04.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/input_sample/cropped04.mp4 
-------------------------------------------------------------------------------- /test/input_sample/cropped05.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/input_sample/cropped05.mp4 -------------------------------------------------------------------------------- /test/input_sample/video01.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/input_sample/video01.mp4 -------------------------------------------------------------------------------- /test/input_sample/video02.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/input_sample/video02.mp4 -------------------------------------------------------------------------------- /test/input_sample/video03.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/input_sample/video03.mp4 -------------------------------------------------------------------------------- /test/input_sample/video04.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/input_sample/video04.mp4 -------------------------------------------------------------------------------- /test/input_sample/video05.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/input_sample/video05.mp4 -------------------------------------------------------------------------------- /test/output_sample/marlin_vit_base/cropped01.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/output_sample/marlin_vit_base/cropped01.npy -------------------------------------------------------------------------------- /test/output_sample/marlin_vit_base/cropped02.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/output_sample/marlin_vit_base/cropped02.npy -------------------------------------------------------------------------------- /test/output_sample/marlin_vit_base/cropped03.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/output_sample/marlin_vit_base/cropped03.npy -------------------------------------------------------------------------------- /test/output_sample/marlin_vit_base/cropped04.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/output_sample/marlin_vit_base/cropped04.npy -------------------------------------------------------------------------------- /test/output_sample/marlin_vit_base/cropped05.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/output_sample/marlin_vit_base/cropped05.npy


--------------------------------------------------------------------------------
/test/output_sample/marlin_vit_base/video01.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/output_sample/marlin_vit_base/video01.npy


--------------------------------------------------------------------------------
/test/output_sample/marlin_vit_base/video02.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/output_sample/marlin_vit_base/video02.npy


--------------------------------------------------------------------------------
/test/output_sample/marlin_vit_base/video03.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/output_sample/marlin_vit_base/video03.npy


--------------------------------------------------------------------------------
/test/output_sample/marlin_vit_base/video04.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/output_sample/marlin_vit_base/video04.npy


--------------------------------------------------------------------------------
/test/output_sample/marlin_vit_base/video05.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/test/output_sample/marlin_vit_base/video05.npy


--------------------------------------------------------------------------------
/test/test_marlin_pytorch.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | from typing import Optional, Callable
 3 | 
 4 | import numpy as np
 5 | import torch.cuda
 6 | 
 7 | from marlin_pytorch import Marlin
 8 | 
 9 | 
10 | class TestMarlinPytorch:
11 |     assertTrue: Callable
12 |     CROP_VIDEOS = [f"cropped{str(i).zfill(2)}" for i in range(1, 6)]
13 |     WILD_VIDEOS = [f"video{str(i).zfill(2)}" for i in range(1, 6)]
14 |     USE_GPU = torch.cuda.is_available()
15 | 
16 |     MODEL_NAME: Optional[str] = None
17 |     MODEL_ENCODER_PATH: Optional[str] = None
18 |     MODEL_FULL_PATH: Optional[str] = None
19 |     EMBEDDING_SIZE: Optional[int] = None
20 | 
21 |     def test_load_full_model_from_file(self):
22 |         Marlin.from_file(self.MODEL_NAME, self.MODEL_FULL_PATH)
23 |         self.assertTrue(True)
24 | 
25 |     def test_load_encoder_from_file(self):
26 |         Marlin.from_file(self.MODEL_NAME, self.MODEL_ENCODER_PATH)
27 |         self.assertTrue(True)
28 | 
29 |     def test_load_full_model_from_online(self):
30 |         Marlin.from_online(self.MODEL_NAME, full_model=True)
31 |         self.assertTrue(True)
32 | 
33 |     def test_load_encoder_from_online(self):
34 |         Marlin.from_online(self.MODEL_NAME, full_model=False)
35 |         self.assertTrue(True)
36 | 
37 |     def test_extract_wild_video(self):
38 |         if not os.path.exists(os.path.join("test", "output_sample", self.MODEL_NAME)):
39 |             return
40 | 
41 |         model = Marlin.from_file(self.MODEL_NAME, self.MODEL_ENCODER_PATH)
42 |         if self.USE_GPU:
43 |             model.cuda()
44 | 
45 |         for video in self.WILD_VIDEOS:
46 |             feat = model.extract_video(os.path.join("test", "input_sample", f"{video}.mp4"), crop_face=True)
47 |             feat = feat.cpu().numpy()
48 |             true = np.load(os.path.join("test", "output_sample", self.MODEL_NAME, f"{video}.npy"))
49 |             diff = np.abs(feat - true).mean()
50 |             self.assertTrue(diff < 1.5e-4)
51 | 
52 |     def test_extract_cropped_video(self):
53 |         if not os.path.exists(os.path.join("test", "output_sample", self.MODEL_NAME)):
54 |             return
55 | 
56 |         model = Marlin.from_file(self.MODEL_NAME, self.MODEL_ENCODER_PATH)
57 |         if self.USE_GPU:
58 |             model.cuda()
59 | 
60 |         for video in self.CROP_VIDEOS:
61 |             feat = model.extract_video(os.path.join("test", "input_sample", f"{video}.mp4"))
62 |             feat = feat.cpu().numpy()
63 |             true = np.load(os.path.join("test", "output_sample", self.MODEL_NAME, f"{video}.npy"))
64 |             diff = np.abs(feat - true).mean()
65 |             self.assertTrue(diff < 1.5e-4)
66 | 
67 |     def test_extract_cropped_clip(self):
68 |         model = Marlin.from_file(self.MODEL_NAME, self.MODEL_ENCODER_PATH)
69 |         if self.USE_GPU:
70 |             model.cuda()
71 | 
72 |         x = torch.rand(1, 3, 16, 224, 224).to(model.device)
73 |         self.assertTrue(model.extract_features(x).shape == (1, 1568, self.EMBEDDING_SIZE))
74 |         self.assertTrue(model.extract_features(x, keep_seq=False).shape == (1, self.EMBEDDING_SIZE))
75 | 
76 |     def test_reconstruct_clip(self):
77 |         model = Marlin.from_file(self.MODEL_NAME, self.MODEL_FULL_PATH)
78 |         if self.USE_GPU:
79 |             model.cuda()
80 | 
81 |         mask = torch.zeros((1, 1568)).to(model.device).bool()
82 |         mask[:, :392] = True
83 |         x = torch.rand(1, 3, 16, 224, 224).to(model.device)
84 |         pred = model(x, mask)
85 |         self.assertTrue(pred.shape == (1, 1176, 1536))
86 | 


--------------------------------------------------------------------------------
/test/test_marlin_vit_base.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import unittest
 3 | from typing import Optional
 4 | 
 5 | from .test_marlin_pytorch import TestMarlinPytorch
 6 | 
 7 | 
 8 | class MarlinViTBase(unittest.TestCase, TestMarlinPytorch):
 9 |     MODEL_NAME: Optional[str] = "marlin_vit_base_ytf"
10 |     MODEL_ENCODER_PATH: Optional[str] = os.path.join("test", "model", f"marlin_vit_base_ytf.encoder.pt")
11 |     MODEL_FULL_PATH: Optional[str] = os.path.join("test", "model", "marlin_vit_base_ytf.full.pt")
12 |     EMBEDDING_SIZE: Optional[int] = 768
13 | 


--------------------------------------------------------------------------------
/test/test_marlin_vit_large.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import unittest
 3 | from typing import Optional
 4 | 
 5 | from .test_marlin_pytorch import TestMarlinPytorch
 6 | 
 7 | 
 8 | class MarlinViTLarge(unittest.TestCase, TestMarlinPytorch):
 9 |     MODEL_NAME: Optional[str] = "marlin_vit_large_ytf"
10 |     MODEL_ENCODER_PATH: Optional[str] = os.path.join("test", "model", f"marlin_vit_large_ytf.encoder.pt")
11 |     MODEL_FULL_PATH: Optional[str] = os.path.join("test", "model", "marlin_vit_large_ytf.full.pt")
12 |     EMBEDDING_SIZE: Optional[int] = 1024
13 | 


--------------------------------------------------------------------------------
/test/test_marlin_vit_small.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import unittest
 3 | from typing import Optional
 4 | 
 5 | from .test_marlin_pytorch import TestMarlinPytorch
 6 | 
 7 | 
 8 | class MarlinViTSmall(unittest.TestCase, TestMarlinPytorch):
 9 |     MODEL_NAME: Optional[str] = "marlin_vit_small_ytf"
10 |     MODEL_ENCODER_PATH: Optional[str] = os.path.join("test", "model", f"marlin_vit_small_ytf.encoder.pt")
11 |     MODEL_FULL_PATH: Optional[str] = os.path.join("test", "model", "marlin_vit_small_ytf.full.pt")
12 |     EMBEDDING_SIZE: Optional[int] = 384
13 | 


--------------------------------------------------------------------------------
/test/test_version.py:
--------------------------------------------------------------------------------
 1 | import re
 2 | import unittest
 3 | 
 4 | 
 5 | class TestVersion(unittest.TestCase):
 6 | 
 7 |     def test_package_version(self):
 8 |         version_pattern = r"""
 9 |             v?
10 |             (?:
11 |                 (?:(?P<epoch>[0-9]+)!)?                           # epoch
12 |                 (?P<release>[0-9]+(?:\.[0-9]+)*)                  # release segment
13 |                 (?P<pre>                                          # pre-release
14 |                     [-_\.]?
15 |                     (?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
16 |                     [-_\.]?
17 |                     (?P<pre_n>[0-9]+)?
18 |                 )?
19 |                 (?P<post>                                         # post release
20 |                     (?:-(?P<post_n1>[0-9]+))
21 |                     |
22 |                     (?:
23 |                         [-_\.]?
24 |                         (?P<post_l>post|rev|r)
25 |                         [-_\.]?
26 |                         (?P<post_n2>[0-9]+)?
27 |                     )
28 |                 )?
29 |                 (?P<dev>                                          # dev release
30 |                     [-_\.]?
31 |                     (?P<dev_l>dev)
32 |                     [-_\.]?
33 |                     (?P<dev_n>[0-9]+)?
34 |                 )?
35 |             )
36 |             (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
37 |         """
38 | 
39 |         regex = re.compile(
40 |             r"^\s*" + version_pattern + r"\s*$",
41 |             re.VERBOSE | re.IGNORECASE,
42 |         )
43 | 
44 |         # read setup.py file
45 |         def read(file):
46 |             with open(file, "r", encoding="UTF-8") as input_file:
47 |                 text = input_file.read()
48 |             return text
49 | 
50 |         try:
51 |             version = read("version.txt")
52 |         except FileNotFoundError:
53 |             version = read("../version.txt")
54 | 
55 |         self.assertTrue(regex.match(version) is not None)
56 | 
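
The raw string above is the standard PEP 440 version pattern (the same grammar used by the packaging library), so a version.txt value such as 0.3.4 or 1.0.0rc1 matches while a malformed string such as 1..0 does not. A minimal self-contained sketch of the same idea, using a deliberately simplified subset of the pattern rather than the full regex from the test:

    import re

    # Release segment plus an optional pre-release tag -- a simplified subset of PEP 440.
    SIMPLE_VERSION = r"[0-9]+(?:\.[0-9]+)*(?:(a|b|rc)[0-9]+)?"

    assert re.fullmatch(SIMPLE_VERSION, "0.3.4") is not None
    assert re.fullmatch(SIMPLE_VERSION, "1.0.0rc1") is not None
    assert re.fullmatch(SIMPLE_VERSION, "1..0") is None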


--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | 
  3 | from pytorch_lightning.callbacks import ModelCheckpoint
  4 | from pytorch_lightning.trainer import Trainer
  5 | 
  6 | from dataset.youtube_face import YoutubeFaceDataModule
  7 | from marlin_pytorch.util import read_yaml
  8 | from util.misc import load_official_pretrain_model
  9 | 
 10 | parser = argparse.ArgumentParser("MARLIN pretraining")
 11 | parser.add_argument("--config", type=str)
 12 | parser.add_argument("--data_dir", type=str)
 13 | parser.add_argument("--n_gpus", type=int, default=1)
 14 | parser.add_argument("--num_workers", type=int, default=8)
 15 | parser.add_argument("--batch_size", type=int, default=16)
 16 | parser.add_argument("--epochs", type=int, default=2000)
 17 | parser.add_argument("--official_pretrained", type=str, default=None)
 18 | parser.add_argument("--resume", type=str, default=None)
 19 | 
 20 | if __name__ == '__main__':
 21 |     args = parser.parse_args()
 22 |     config_path = args.config
 23 |     data_path = args.data_dir
 24 |     resume_ckpt = args.resume
 25 |     config = read_yaml(config_path)
 26 | 
 27 |     batch_size = args.batch_size
 28 |     max_epochs = args.epochs
 29 |     num_workers = args.num_workers
 30 |     official_pretrained = args.official_pretrained
 31 | 
 32 |     model_name = config["model_name"]
 33 |     learning_rate = config["learning_rate"]["base"]
 34 |     warmup_lr = config["learning_rate"]["warmup"]
 35 |     min_lr = config["learning_rate"]["min"]
 36 |     warmup_epochs = config["learning_rate"]["warmup_epochs"]
 37 |     n_gpus = args.n_gpus
 38 |     img_size = config["img_size"]
 39 |     patch_size = config["patch_size"]
 40 |     clip_frames = config["clip_frames"]
 41 |     tubelet_size = config["tubelet_size"]
 42 |     mask_strategy = config["mask_strategy"]
 43 |     temporal_sample_rate = config["temporal_sample_rate"]
 44 |     mask_percentage_target = config["mask_percentage_target"]
 45 |     encoder_embed_dim = config["encoder"]["embed_dim"]
 46 |     encoder_depth = config["encoder"]["depth"]
 47 |     encoder_num_heads = config["encoder"]["num_heads"]
 48 |     decoder_embed_dim = config["decoder"]["embed_dim"]
 49 |     decoder_depth = config["decoder"]["depth"]
 50 |     decoder_num_heads = config["decoder"]["num_heads"]
 51 |     mlp_ratio = config["mlp_ratio"]
 52 |     qkv_bias = config["qkv_bias"]
 53 |     qk_scale = config["qk_scale"]
 54 |     drop_rate = config["drop_rate"]
 55 |     attn_drop_rate = config["attn_drop_rate"]
 56 |     norm_layer = config["norm_layer"]
 57 |     init_values = config["init_values"]
 58 |     optimizer_type = config["optimizer"]["type"]
 59 |     optimizer_eps = config["optimizer"]["eps"]
 60 |     optimizer_betas = config["optimizer"]["betas"]
 61 |     weight_decay = config["weight_decay"]
 62 |     adv_loss = config["adv_loss"]
 63 | 
 64 |     total_batch_size = batch_size * n_gpus
 65 |     learning_rate = learning_rate * total_batch_size / 256
 66 |     warmup_lr = warmup_lr * total_batch_size / 256
 67 |     min_lr = min_lr * total_batch_size / 256
 68 | 
 69 |     dm = YoutubeFaceDataModule(
 70 |         root_dir=data_path,
 71 |         batch_size=batch_size,
 72 |         clip_frames=clip_frames,
 73 |         temporal_sample_rate=temporal_sample_rate,
 74 |         patch_size=patch_size,
 75 |         tubelet_size=tubelet_size,
 76 |         mask_percentage_target=mask_percentage_target,
 77 |         mask_strategy=mask_strategy,
 78 |         num_workers=num_workers,
 79 |         take_train=None,
 80 |         take_val=None
 81 |     )
 82 |     dm.setup()
 83 | 
 84 |     if adv_loss:
 85 |         from model.marlin import Marlin
 86 |     else:
 87 |         raise NotImplementedError
 88 | 
 89 |     model = Marlin(
 90 |         img_size=img_size,
 91 |         patch_size=patch_size,
 92 |         n_frames=clip_frames,
 93 |         encoder_embed_dim=encoder_embed_dim,
 94 |         encoder_depth=encoder_depth,
 95 |         encoder_num_heads=encoder_num_heads,
 96 |         decoder_embed_dim=decoder_embed_dim,
 97 |         decoder_depth=decoder_depth,
 98 |         decoder_num_heads=decoder_num_heads,
 99 |         mlp_ratio=mlp_ratio,
100 |         qkv_bias=qkv_bias,
101 |         qk_scale=qk_scale,
102 |         drop_rate=drop_rate,
103 |         attn_drop_rate=attn_drop_rate,
104 |         norm_layer=norm_layer,
105 |         init_values=init_values,
106 |         tubelet_size=tubelet_size,
107 |         optimizer_type=optimizer_type,
108 |         optimizer_eps=optimizer_eps,
109 |         optimizer_betas=optimizer_betas,
110 |         weight_decay=weight_decay,
111 |         learning_rate=learning_rate,
112 |         warmup_lr=warmup_lr,
113 |         min_lr=min_lr,
114 |         warmup_epochs=warmup_epochs,
115 |         max_epochs=max_epochs,
116 |         iter_per_epoch=len(dm.train_dataloader()),
117 |         distributed=n_gpus > 1,
118 |         name=model_name
119 |     )
120 | 
121 |     if adv_loss:
122 |         model.adv_weight = config["adv_weight"]
123 |         model.gp_weight = config["gp_weight"]
124 |         model.d_steps = config["d_steps"]
125 |         model.g_steps = config["g_steps"]
126 | 
127 |     if official_pretrained is not None:
128 |         print(load_official_pretrain_model(model, official_pretrained))
129 | 
130 |     accelerator = None if n_gpus <= 1 else "ddp"
131 |     device = "gpu" if n_gpus > 0 else "cpu"
132 |     n_gpus = n_gpus if n_gpus > 0 else None
133 | 
134 |     trainer = Trainer(log_every_n_steps=1, devices=n_gpus, accelerator=device,
135 |         logger=True, precision=32, max_epochs=max_epochs,
136 |         strategy=accelerator, resume_from_checkpoint=resume_ckpt,
137 |         callbacks=[ModelCheckpoint(dirpath=f"ckpt/{model_name}", save_last=True,
138 |             filename=model.name + "-{epoch}-{val_loss:.3f}",
139 |             monitor="val_loss", mode="min")])
140 | 
141 |     trainer.fit(model, dm)
142 | 
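
Lines 64-67 above apply the usual linear learning-rate scaling rule: the configured base, warmup and minimum rates are multiplied by total_batch_size / 256, so the effective rates grow with the per-GPU batch size and the number of GPUs. A quick numeric illustration with assumed values (not taken from any shipped config):

    base_lr = 1.5e-4                                    # assumed base learning rate, for illustration only
    batch_size, n_gpus = 16, 4                          # per-GPU batch size and GPU count
    total_batch_size = batch_size * n_gpus              # 64
    learning_rate = base_lr * total_batch_size / 256    # 3.75e-05
    print(learning_rate)

A typical launch would then look like python train.py --config <pretrain_config>.yaml --data_dir <dataset_root> --n_gpus 4, with the placeholders replaced by a real config file and dataset path; the same scaling is applied to the warmup and minimum rates.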


--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/util/__init__.py


--------------------------------------------------------------------------------
/util/earlystop_lr.py:
--------------------------------------------------------------------------------
 1 | from pytorch_lightning import Trainer, LightningModule
 2 | from pytorch_lightning.callbacks import Callback
 3 | import re
 4 | 
 5 | 
 6 | class EarlyStoppingLR(Callback):
 7 | 
 8 |     def __init__(self, lr_threshold: float, mode="all"):
 9 |         self.lr_threshold = lr_threshold
10 | 
11 |         if mode in ("any", "all"):
12 |             self.mode = mode
13 |         else:
14 |             raise ValueError(f"mode must be one of ('any', 'all')")
15 | 
16 |     def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
17 |         self._run_early_stop_checking(trainer)
18 | 
19 |     def _run_early_stop_checking(self, trainer: Trainer) -> None:
20 |         metrics = trainer._logger_connector.callback_metrics
21 |         if len(metrics) == 0:
22 |             return
23 |         all_lr = []
24 |         for key, value in metrics.items():
25 |             if re.match(r"opt\d+_lr\d+", key):
26 |                 all_lr.append(value)
27 | 
28 |         if len(all_lr) == 0:
29 |             return
30 | 
31 |         if self.mode == "all":
32 |             if all(lr <= self.lr_threshold for lr in all_lr):
33 |                 trainer.should_stop = True
34 |         elif self.mode == "any":
35 |             if any(lr <= self.lr_threshold for lr in all_lr):
36 |                 trainer.should_stop = True
37 | 
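
A minimal usage sketch for the callback above: training stops once the learning rates logged to callback_metrics under keys matching opt\d+_lr\d+ fall to the threshold. MyLitModule and MyDataModule are placeholders for a LightningModule and LightningDataModule that log their learning rates under such keys.

    from pytorch_lightning import Trainer

    # Stop once every logged learning rate has decayed to 1e-6 or below.
    early_stop = EarlyStoppingLR(lr_threshold=1e-6, mode="all")
    trainer = Trainer(max_epochs=2000, callbacks=[early_stop])
    trainer.fit(MyLitModule(), MyDataModule())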


--------------------------------------------------------------------------------
/util/face_sdk/README.md:
--------------------------------------------------------------------------------
1 | ## Face SDK
2 | 
3 | From [FaceXZoo](https://github.com/JDAI-CV/FaceX-Zoo).
4 | 


--------------------------------------------------------------------------------
/util/face_sdk/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/util/face_sdk/__init__.py


--------------------------------------------------------------------------------
/util/face_sdk/config/logging.conf:
--------------------------------------------------------------------------------
 1 | [loggers] # loggers object list
 2 |     keys = root, sdk, api
 3 | 
 4 | [handlers] # handlers object list
 5 |     keys = consoleHandlers, fileHandlers
 6 | 
 7 | [formatters] # formatters list
 8 |     keys = fmt
 9 | 
10 | [logger_root]
11 |     level = DEBUG
12 |     handlers = consoleHandlers, fileHandlers
13 | 
14 | [logger_sdk] # sdk logger
15 |     level = DEBUG
16 |     handlers = fileHandlers
17 |     qualname = sdk
18 |     propagate = 0
19 | 
20 | [logger_api] # api logger
21 |     level = DEBUG
22 |     handlers = consoleHandlers
23 |     qualname = api
24 |     propagate = 0
25 | 
26 | [handler_consoleHandlers]# consoleHandlers.
27 |     class = StreamHandler
28 |     level = DEBUG
29 |     formatter = fmt
30 |     args = (sys.stdout,)
31 | 
32 | [handler_fileHandlers]# fileHandlers
33 |     class = logging.handlers.RotatingFileHandler
34 |     level = DEBUG
35 |     formatter = fmt
36 |     args = ('logs/sdk.log', 'a', 10000, 3, 'UTF-8')
37 | 
38 | [formatter_fmt] # fmt format
39 |     format = %(levelname)s %(asctime)s %(filename)s: %(lineno)d] %(message)s
40 |     datefmt = %Y-%m-%d %H:%M:%S


--------------------------------------------------------------------------------
/util/face_sdk/config/model_conf.yaml:
--------------------------------------------------------------------------------
 1 | non-mask:
 2 |     face_detection: face_detection_1.0
 3 |     face_alignment: face_alignment_1.0
 4 |     face_recognition: face_recognition_1.0
 5 |     face_parsing: face_parsing_1.0
 6 | mask:
 7 |     face_detection: face_detection_2.0
 8 |     face_alignment: face_alignment_2.0
 9 |     face_recognition: face_recognition_2.0
10 |     


--------------------------------------------------------------------------------
/util/face_sdk/core/image_cropper/BaseImageCropper.py:
--------------------------------------------------------------------------------
 1 | """
 2 | @author: JiXuan Xu, Jun Wang
 3 | @date: 20201015
 4 | @contact: jun21wangustc@gmail.com 
 5 | """
 6 | from abc import ABCMeta, abstractmethod
 7 | 
 8 | 
 9 | class BaseImageCropper(metaclass=ABCMeta):
10 |     """Base class for all image croppers.
11 |     All image cropping classes need to inherit this base class.
12 |     """
13 | 
14 |     def __init__(self):
15 |         pass
16 | 
17 |     @abstractmethod
18 |     def crop_image_by_mat(self, image, landmarks):
19 |         """Should be overridden by all subclasses.
20 |         Used for online image cropping: take the original Mat as input
21 |         and return the Mat obtained from cropping the image.
22 |         """
23 |         pass
24 | 


--------------------------------------------------------------------------------
/util/face_sdk/core/image_cropper/arcface_cropper/FaceRecImageCropper.py:
--------------------------------------------------------------------------------
  1 | """
  2 | @author: JiXuan Xu, Jun Wang
  3 | @date: 20201015
  4 | @contact: jun21wangustc@gmail.com 
  5 | """
  6 | # based on:
  7 | # https://github.com/deepinsight/insightface/blob/master/recognition/common/face_align.py
  8 | 
  9 | import cv2
 10 | import numpy as np
 11 | from skimage import transform as trans
 12 | 
 13 | from util.face_sdk.core.image_cropper.BaseImageCropper import BaseImageCropper
 14 | from util.face_sdk.utils.lms_trans import lms106_2_lms5, lms25_2_lms5
 15 | 
 16 | src1 = np.array([
 17 |     [51.642, 50.115],
 18 |     [57.617, 49.990],
 19 |     [35.740, 69.007],
 20 |     [51.157, 89.050],
 21 |     [57.025, 89.702]], dtype=np.float32)
 22 | # <--left
 23 | src2 = np.array([
 24 |     [45.031, 50.118],
 25 |     [65.568, 50.872],
 26 |     [39.677, 68.111],
 27 |     [45.177, 86.190],
 28 |     [64.246, 86.758]], dtype=np.float32)
 29 | 
 30 | # ---frontal
 31 | src3 = np.array([
 32 |     [39.730, 51.138],
 33 |     [72.270, 51.138],
 34 |     [56.000, 68.493],
 35 |     [42.463, 87.010],
 36 |     [69.537, 87.010]], dtype=np.float32)
 37 | 
 38 | # -->right
 39 | src4 = np.array([
 40 |     [46.845, 50.872],
 41 |     [67.382, 50.118],
 42 |     [72.737, 68.111],
 43 |     [48.167, 86.758],
 44 |     [67.236, 86.190]], dtype=np.float32)
 45 | 
 46 | # -->right profile
 47 | src5 = np.array([
 48 |     [54.796, 49.990],
 49 |     [60.771, 50.115],
 50 |     [76.673, 69.007],
 51 |     [55.388, 89.702],
 52 |     [61.257, 89.050]], dtype=np.float32)
 53 | 
 54 | src = np.array([src1, src2, src3, src4, src5])
 55 | src_map = {112: src, 224: src * 2}
 56 | 
 57 | arcface_src = np.array([
 58 |     [38.2946, 51.6963],
 59 |     [73.5318, 51.5014],
 60 |     [56.0252, 71.7366],
 61 |     [41.5493, 92.3655],
 62 |     [70.7299, 92.2041]], dtype=np.float32)
 63 | 
 64 | arcface_src = np.expand_dims(arcface_src, axis=0)
 65 | 
 66 | 
 67 | # In[66]:
 68 | 
 69 | # lmk is prediction; src is template
 70 | def estimate_norm(lmk, image_size=112, mode='arcface'):
 71 |     assert lmk.shape == (5, 2)
 72 |     tform = trans.SimilarityTransform()
 73 |     lmk_tran = np.insert(lmk, 2, values=np.ones(5), axis=1)
 74 |     min_M = []
 75 |     min_index = []
 76 |     min_error = float('inf')
 77 |     if mode == 'arcface':
 78 |         assert image_size == 112
 79 |         src = arcface_src
 80 |     else:
 81 |         src = src_map[image_size]
 82 |     for i in np.arange(src.shape[0]):
 83 |         tform.estimate(lmk, src[i])
 84 |         M = tform.params[0:2, :]
 85 |         results = np.dot(M, lmk_tran.T)
 86 |         results = results.T
 87 |         error = np.sum(np.sqrt(np.sum((results - src[i]) ** 2, axis=1)))
 88 |         #         print(error)
 89 |         if error < min_error:
 90 |             min_error = error
 91 |             min_M = M
 92 |             min_index = i
 93 |     return min_M, min_index
 94 | 
 95 | 
 96 | def norm_crop(img, landmark, image_size=112, mode='arcface'):
 97 |     M, pose_index = estimate_norm(landmark, image_size, mode)
 98 |     warped = cv2.warpAffine(img, M, (image_size, image_size), borderValue=0.0)
 99 |     return warped
100 | 
101 | 
102 | # my class wrapper
103 | class FaceRecImageCropper(BaseImageCropper):
104 |     """Implementation of image cropper
105 | 
106 |     Attributes:
107 |         image: the input image.
108 |         landmarks: using landmarks information to crop.
109 |     """
110 | 
111 |     def __init__(self):
112 |         super().__init__()
113 | 
114 |     def crop_image_by_mat(self, image, landmarks):
115 |         if len(landmarks) == 106 * 2:
116 |             landmarks = lms106_2_lms5(landmarks)
117 |         if len(landmarks) == 25 * 2:
118 |             landmarks = lms25_2_lms5(landmarks)
119 |         assert (len(landmarks) == 5 * 2)
120 |         landmarks = np.array(landmarks)
121 |         height, width, channel = image.shape
122 |         if channel != 3:
123 |             print('Error input.')
124 |         landmarks = landmarks.reshape((5, 2))
125 |         cropped_image = norm_crop(image, landmarks)
126 |         return cropped_image
127 | 
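
A minimal, self-contained sketch of the cropper above: given a BGR image and five landmarks flattened to ten numbers, crop_image_by_mat returns a 112 x 112 aligned face. The image and landmark values below are random placeholders rather than real detections.

    import numpy as np

    from util.face_sdk.core.image_cropper.arcface_cropper.FaceRecImageCropper import FaceRecImageCropper

    cropper = FaceRecImageCropper()
    image = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)    # stand-in for a video frame
    lms5 = [90, 100, 160, 100, 125, 135, 100, 170, 150, 170]        # five (x, y) landmarks, flattened
    aligned = cropper.crop_image_by_mat(image, lms5)
    print(aligned.shape)                                            # (112, 112, 3)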


--------------------------------------------------------------------------------
/util/face_sdk/core/model_handler/BaseModelHandler.py:
--------------------------------------------------------------------------------
 1 | """
 2 | @author: JiXuan Xu, Jun Wang
 3 | @date: 20201015
 4 | @contact: jun21wangustc@gmail.com 
 5 | """
 6 | from abc import ABCMeta, abstractmethod
 7 | import torch
 8 | 
 9 | 
10 | class BaseModelHandler(metaclass=ABCMeta):
11 |     """Base class for all neural network model handlers.
12 |     All the model handlers need to inherit this base class,
13 |     and each new model needs to implement the "inference_on_image" method.
14 |     """
15 | 
16 |     def __init__(self, model, device, cfg):
17 |         """
18 |         Generate the model by loading the configuration file.
19 |         :param cfg: Cfg Node
20 |         """
21 |         self.model = model
22 |         self.model.eval()
23 |         self.cfg = cfg
24 |         self.device = torch.device(device)
25 | 
26 |     @abstractmethod
27 |     def inference_on_image(self, image):
28 |         pass
29 | 
30 |     def _preprocess(self, image):
31 |         pass
32 | 
33 |     def _postprocess(self, output):
34 |         pass
35 | 


--------------------------------------------------------------------------------
/util/face_sdk/core/model_handler/face_alignment/FaceAlignModelHandler.py:
--------------------------------------------------------------------------------
  1 | """
  2 | @author: JiXuan Xu, Jun Wang
  3 | @date: 20201023
  4 | @contact: jun21wangustc@gmail.com 
  5 | """
  6 | import logging.config
  7 | import os.path
  8 | 
  9 | logging.config.fileConfig(os.path.join("util", "face_sdk", "config", "logging.conf"))
 10 | logger = logging.getLogger('sdk')
 11 | 
 12 | import cv2
 13 | import torch
 14 | import numpy as np
 15 | import torch.backends.cudnn as cudnn
 16 | 
 17 | from util.face_sdk.core.model_handler.BaseModelHandler import BaseModelHandler
 18 | from util.face_sdk.utils.BuzException import *
 19 | from torchvision import transforms
 20 | 
 21 | 
 22 | class FaceAlignModelHandler(BaseModelHandler):
 23 |     """Implementation of face landmark model handler
 24 | 
 25 |     Attributes:
 26 |         model: the face landmark model.
 27 |         device: use cpu or gpu to process.
 28 |         cfg(dict): testing config, inherit from the parent class.
 29 |     """
 30 | 
 31 |     def __init__(self, model, device, cfg):
 32 |         """
 33 |         Init FaceLmsModelHandler settings. 
 34 |         """
 35 |         super().__init__(model, device, cfg)
 36 |         self.img_size = self.cfg['img_size']
 37 | 
 38 |     def inference_on_image(self, image, dets):
 39 |         """Get the inference of the image and process the inference result.
 40 | 
 41 |         Returns:
 42 |             A numpy array, the predicted landmarks mapped back to the original image coordinates, shape: (106, 2).
 43 |         """
 44 |         cudnn.benchmark = True
 45 |         try:
 46 |             image_pre = self._preprocess(image, dets)
 47 |         except Exception as e:
 48 |             raise e
 49 |         self.model = self.model.to(self.device)
 50 |         image_pre = image_pre.unsqueeze(0)
 51 |         with torch.no_grad():
 52 |             image_pre = image_pre.to(self.device)
 53 |             _, landmarks_normal = self.model(image_pre)
 54 |         landmarks = self._postprocess(landmarks_normal)
 55 |         return landmarks
 56 | 
 57 |     # Adapted from https://github.com/Hsintao/pfld_106_face_landmarks/blob/master/data/prepare.py
 58 |     def _preprocess(self, image, det):
 59 |         """Preprocess the input image by cropping it around the detected face.
 60 |         The face detection result (det) gives the face position in the input image.
 61 |         After determining the face center and box size, the image is cropped
 62 |         and resized to the preset size.
 63 | 
 64 |         Returns:
 65 |            A torch tensor, the image after preprocessing, shape: (3, 112, 112).
 66 |         """
 67 |         if not isinstance(image, np.ndarray):
 68 |             logger.error('The input should be the ndarray read by cv2!')
 69 |             raise InputError()
 70 |         img = image.copy()
 71 |         self.image_org = image.copy()
 72 |         img = np.float32(img)
 73 | 
 74 |         xy = np.array([det[0], det[1]])
 75 |         zz = np.array([det[2], det[3]])
 76 |         wh = zz - xy + 1
 77 |         center = (xy + wh / 2).astype(np.int32)
 78 |         boxsize = int(np.max(wh) * 1.2)
 79 |         xy = center - boxsize // 2
 80 |         self.xy = xy
 81 |         self.boxsize = boxsize
 82 |         x1, y1 = xy
 83 |         x2, y2 = xy + boxsize
 84 |         height, width, _ = img.shape
 85 |         dx = max(0, -x1)
 86 |         dy = max(0, -y1)
 87 |         x1 = max(0, x1)
 88 |         y1 = max(0, y1)
 89 |         edx = max(0, x2 - width)
 90 |         edy = max(0, y2 - height)
 91 |         x2 = min(width, x2)
 92 |         y2 = min(height, y2)
 93 |         imageT = image[y1:y2, x1:x2]
 94 |         if dx > 0 or dy > 0 or edx > 0 or edy > 0:
 95 |             imageT = cv2.copyMakeBorder(
 96 |                 imageT, dy, edy, dx, edx, cv2.BORDER_CONSTANT, 0)
 97 | 
 98 |         imageT = cv2.resize(imageT, (self.img_size, self.img_size))
 99 |         t = transforms.Compose([transforms.ToTensor()])
100 |         img_after = t(imageT)
101 |         return img_after
102 | 
103 |     def _postprocess(self, landmarks_normal):
104 |         """Process the predicted landmarks into the form of the original image.
105 | 
106 |         Returns:
107 |             A numpy array, the landmarks mapped back to the original image coordinates, shape: (106, 2).
108 |         """
109 |         landmarks_normal = landmarks_normal.cpu().numpy()
110 |         landmarks_normal = landmarks_normal.reshape(landmarks_normal.shape[0], -1, 2)
111 |         landmarks = landmarks_normal[0] * [self.boxsize, self.boxsize] + self.xy
112 |         return landmarks
113 | 


--------------------------------------------------------------------------------
/util/face_sdk/core/model_handler/face_detection/FaceDetModelHandler.py:
--------------------------------------------------------------------------------
  1 | """
  2 | @author: JiXuan Xu, Jun Wang
  3 | @date: 20201019
  4 | @contact: jun21wangustc@gmail.com 
  5 | """
  6 | 
  7 | import logging.config
  8 | import os
  9 | 
 10 | logging.config.fileConfig(os.path.join("util", "face_sdk", "config", "logging.conf"))
 11 | logger = logging.getLogger('sdk')
 12 | 
 13 | import torch
 14 | import numpy as np
 15 | from math import ceil
 16 | from itertools import product as product
 17 | import torch.backends.cudnn as cudnn
 18 | 
 19 | from util.face_sdk.core.model_handler.BaseModelHandler import BaseModelHandler
 20 | from util.face_sdk.utils.BuzException import *
 21 | 
 22 | 
 23 | class FaceDetModelHandler(BaseModelHandler):
 24 |     """Implementation of face detection model handler
 25 | 
 26 |     Attributes:
 27 |         model: the face detection model.
 28 |         device: use cpu or gpu to process.
 29 |         cfg(dict): testing config, inherit from the parent class.
 30 |     """
 31 | 
 32 |     def __init__(self, model, device, cfg):
 33 |         """
 34 |         Init FaceDetModelHandler settings. 
 35 |         """
 36 |         super().__init__(model, device, cfg)
 37 |         self.variance = self.cfg['variance']
 38 | 
 39 |     def inference_on_image(self, image):
 40 |         """Get the inference of the image and process the inference result.
 41 | 
 42 |         Returns:
 43 |             A numpy array, the shape is N * (x, y, w, h, confidence), 
 44 |             N is the number of detection boxes.
 45 |         """
 46 |         cudnn.benchmark = True
 47 |         input_height, input_width, _ = image.shape
 48 |         try:
 49 |             image, scale = self._preprocess(image)
 50 |         except Exception as e:
 51 |             raise e
 52 |         self.model = self.model.to(self.device)
 53 |         image = torch.from_numpy(image).unsqueeze(0)
 54 |         with torch.no_grad():
 55 |             image = image.to(self.device)
 56 |             scale = scale.to(self.device)
 57 |             loc, conf, landms = self.model(image)
 58 |         dets = self._postprocess(loc, conf, scale, input_height, input_width)
 59 |         return dets
 60 | 
 61 |     def _preprocess(self, image):
 62 |         """Preprocess the image, such as standardization and other operations.
 63 | 
 64 |         Returns:
 65 |             A numpy array, the shape is channel * h * w.
 66 |             A tensor of shape (4,), the image scale (w, h, w, h).
 67 |         """
 68 |         if not isinstance(image, np.ndarray):
 69 |             logger.error('The input should be the ndarray read by cv2!')
 70 |             raise InputError()
 71 |         img = np.float32(image)
 72 |         scale = torch.Tensor([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
 73 |         img -= (104, 117, 123)
 74 |         img = img.transpose(2, 0, 1)
 75 |         return img, scale
 76 | 
 77 |     def _postprocess(self, loc, conf, scale, input_height, input_width):
 78 |         """Postprocess the prediction result.
 79 |         Decode detection result, set the confidence threshold and do the NMS
 80 |         to keep the appropriate detection boxes.
 81 | 
 82 |         Returns:
 83 |             A numpy array, the shape is N * (x, y, w, h, confidence), 
 84 |             N is the number of detection boxes.
 85 |         """
 86 |         priorbox = PriorBox(self.cfg, image_size=(input_height, input_width))
 87 |         priors = priorbox.forward()
 88 |         priors = priors.to(self.device)
 89 |         prior_data = priors.data
 90 |         boxes = self.decode(loc.data.squeeze(0), prior_data, self.cfg['variance'])
 91 |         boxes = boxes * scale
 92 |         boxes = boxes.cpu().numpy()
 93 |         scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
 94 | 
 95 |         # ignore low scores
 96 |         inds = np.where(scores > self.cfg['confidence_threshold'])[0]
 97 |         boxes = boxes[inds]
 98 |         scores = scores[inds]
 99 | 
100 |         # keep top-K before NMS
101 |         order = scores.argsort()[::-1]
102 |         boxes = boxes[order]
103 |         scores = scores[order]
104 | 
105 |         # do NMS
106 |         nms_threshold = 0.2
107 |         dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
108 |         keep = self.py_cpu_nms(dets, nms_threshold)
109 |         dets = dets[keep, :]
110 |         return dets
111 | 
112 |     # Adapted from https://github.com/chainer/chainercv
113 |     def decode(self, loc, priors, variances):
114 |         """Decode locations from predictions using priors to undo
115 |         the encoding we did for offset regression at train time.
116 |         Args:
117 |             loc (tensor): location predictions for loc layers,
118 |                 Shape: [num_priors,4]
119 |             priors (tensor): Prior boxes in center-offset form.
120 |                 Shape: [num_priors,4].
121 |             variances: (list[float]) Variances of priorboxes
122 | 
123 |         Return:
124 |             decoded bounding box predictions
125 |         """
126 |         boxes = torch.cat((priors[:, :2], priors[:, 2:]), 1)
127 |         boxes[:, :2] = priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:]
128 |         boxes[:, 2:] = priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])
129 |         boxes[:, :2] -= boxes[:, 2:] / 2
130 |         boxes[:, 2:] += boxes[:, :2]
131 |         return boxes
132 | 
133 |     # Adapted from https://github.com/biubug6/Pytorch_Retinaface
134 |     def py_cpu_nms(self, dets, thresh):
135 |         """Python version NMS.
136 | 
137 |         Returns:
138 |             The kept index after NMS.
139 |         """
140 |         x1 = dets[:, 0]
141 |         y1 = dets[:, 1]
142 |         x2 = dets[:, 2]
143 |         y2 = dets[:, 3]
144 |         scores = dets[:, 4]
145 |         areas = (x2 - x1 + 1) * (y2 - y1 + 1)
146 |         order = scores.argsort()[::-1]
147 |         keep = []
148 |         while order.size > 0:
149 |             i = order[0]
150 |             keep.append(i)
151 |             xx1 = np.maximum(x1[i], x1[order[1:]])
152 |             yy1 = np.maximum(y1[i], y1[order[1:]])
153 |             xx2 = np.minimum(x2[i], x2[order[1:]])
154 |             yy2 = np.minimum(y2[i], y2[order[1:]])
155 |             w = np.maximum(0.0, xx2 - xx1 + 1)
156 |             h = np.maximum(0.0, yy2 - yy1 + 1)
157 |             inter = w * h
158 |             ovr = inter / (areas[i] + areas[order[1:]] - inter)
159 |             inds = np.where(ovr <= thresh)[0]
160 |             order = order[inds + 1]
161 |         return keep
162 | 
163 | 
164 | # Adapted from https://github.com/biubug6/Pytorch_Retinaface
165 | class PriorBox(object):
166 |     """Compute the suitable parameters of anchors for later decode operation
167 | 
168 |     Attributes:
169 |         cfg(dict): testing config.
170 |         image_size(tuple): the input image size.
171 |     """
172 | 
173 |     def __init__(self, cfg, image_size=None):
174 |         """
175 |         Init priorBox settings related to the generation of anchors. 
176 |         """
177 |         super(PriorBox, self).__init__()
178 |         self.min_sizes = cfg['min_sizes']
179 |         self.steps = cfg['steps']
180 |         self.image_size = image_size
181 |         self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps]
182 |         self.name = "s"
183 | 
184 |     def forward(self):
185 |         anchors = []
186 |         for k, f in enumerate(self.feature_maps):
187 |             min_sizes = self.min_sizes[k]
188 |             for i, j in product(range(f[0]), range(f[1])):
189 |                 for min_size in min_sizes:
190 |                     s_kx = min_size / self.image_size[1]
191 |                     s_ky = min_size / self.image_size[0]
192 |                     dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
193 |                     dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
194 |                     for cy, cx in product(dense_cy, dense_cx):
195 |                         anchors += [cx, cy, s_kx, s_ky]
196 |         # back to torch land
197 |         output = torch.Tensor(anchors).view(-1, 4)
198 |         return output
199 | 
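
The decode step above shifts each prior's centre by the predicted offsets (scaled by variance[0] and the prior size), rescales the prior width/height by the exponential term (scaled by variance[1]), and finally converts to corner form. A small standalone check of that arithmetic on a single prior with zero offsets, which should give back the prior itself as (x1, y1, x2, y2); the handler method is not called directly here because it needs a loaded model:

    import torch

    prior = torch.tensor([[0.5, 0.5, 0.2, 0.4]])    # (cx, cy, w, h) in normalised coordinates
    loc = torch.zeros(1, 4)                         # zero regression offsets
    variances = [0.1, 0.2]

    boxes = torch.cat((prior[:, :2], prior[:, 2:]), 1)
    boxes[:, :2] = prior[:, :2] + loc[:, :2] * variances[0] * prior[:, 2:]
    boxes[:, 2:] = prior[:, 2:] * torch.exp(loc[:, 2:] * variances[1])
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    print(boxes)                                    # tensor([[0.4000, 0.3000, 0.6000, 0.7000]])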


--------------------------------------------------------------------------------
/util/face_sdk/core/model_handler/face_parsing/FaceParsingModelHandler.py:
--------------------------------------------------------------------------------
  1 | """
  2 | @author: fengyu, wangjun
  3 | @date: 20220620
  4 | @contact: fengyu_cnyc@163.com
  5 | """
  6 | 
  7 | # based on:
  8 | # https://github.com/FacePerceiver/facer/blob/main/facer/face_parsing/farl.py
  9 | import logging.config
 10 | import os
 11 | 
 12 | from util.face_sdk.utils.BuzException import InputError
 13 | 
 14 | logging.config.fileConfig(os.path.join("util", "face_sdk", "config", "logging.conf"))
 15 | logger = logging.getLogger('sdk')
 16 | 
 17 | import numpy as np
 18 | import torch.backends.cudnn as cudnn
 19 | from torch.nn import functional as F
 20 | 
 21 | from util.face_sdk.core.model_handler.BaseModelHandler import BaseModelHandler
 22 | from util.face_sdk.utils.transform import *
 23 | 
 24 | pretrain_settings = {
 25 |     'lapa/448': {
 26 |         'matrix_src_tag': 'points',
 27 |         'get_matrix_fn': functools.partial(get_face_align_matrix,
 28 |                                            target_shape=(448, 448), target_face_scale=1.0),
 29 |         'get_grid_fn': functools.partial(make_tanh_warp_grid,
 30 |                                          warp_factor=0.8, warped_shape=(448, 448)),
 31 |         'get_inv_grid_fn': functools.partial(make_inverted_tanh_warp_grid,
 32 |                                              warp_factor=0.8, warped_shape=(448, 448)),
 33 |         'label_names': ['background', 'face', 'rb', 'lb', 're',
 34 |             'le', 'nose', 'ulip', 'imouth', 'llip', 'hair']
 35 |     }
 36 | }
 37 | 
 38 | 
 39 | class FaceParsingModelHandler(BaseModelHandler):
 40 | 
 41 |     def __init__(self, model=None, device=None, cfg=None):
 42 |         super().__init__(model, device, cfg)
 43 | 
 44 |         self.model = model.to(self.device)
 45 | 
 46 |     def _preprocess(self, image, face_nums):
 47 |         """Preprocess the image, such as standardization and other operations.
 48 | 
 49 |         Returns:
 50 |             A tensor, the shape is face_nums x 3 x h x w
 51 |             (the input image replicated once per detected face).
 52 |         """
 53 |         if not isinstance(image, np.ndarray):
 54 |             logger.error('The input should be the ndarray read by cv2!')
 55 |             raise InputError()
 56 |         img = np.float32(image)
 57 |         img = img.transpose(2, 0, 1)
 58 |         img = np.expand_dims(img, 0).repeat(face_nums, axis=0)
 59 |         return torch.from_numpy(img)
 60 | 
 61 |     def inference_on_image(self, face_nums: int, images: torch.Tensor, landmarks):
 62 |         """Get the inference of the image and process the inference result.
 63 | 
 64 |         Returns:
 65 |             A dict with key 'seg' containing the parsing logits and their label names.
 66 |         """
 67 |         cudnn.benchmark = True
 68 |         try:
 69 |             image_pre = self._preprocess(images, face_nums)
 70 |         except Exception as e:
 71 |             raise e
 72 |         setting = pretrain_settings['lapa/448']
 73 |         images = image_pre.float() / 255.0
 74 |         _, _, h, w = images.shape
 75 |         simages = images.to(self.device)
 76 |         matrix = setting['get_matrix_fn'](landmarks.to(self.device))
 77 |         grid = setting['get_grid_fn'](matrix=matrix, orig_shape=(h, w))
 78 |         inv_grid = setting['get_inv_grid_fn'](matrix=matrix, orig_shape=(h, w))
 79 | 
 80 |         w_images = F.grid_sample(
 81 |             simages, grid, mode='bilinear', align_corners=False)
 82 | 
 83 |         w_seg_logits, _ = self.model(w_images)  # (b*n) x c x h x w
 84 | 
 85 |         seg_logits = F.grid_sample(
 86 |             w_seg_logits, inv_grid, mode='bilinear', align_corners=False)
 87 |         data_pre = {}
 88 |         data_pre['seg'] = {
 89 |             'logits': seg_logits,
 90 |             'label_names': setting['label_names']
 91 |         }
 92 |         return data_pre
 93 | 
 94 |     def _postprocess(self, loc, conf, scale, input_height, input_width):
 95 |         """Postprocess the prediction result.
 96 |         Decode detection result, set the confidence threshold and do the NMS
 97 |         to keep the appropriate detection boxes.
 98 | 
 99 |         Returns:
100 |             A numpy array, the shape is N * (x, y, w, h, confidence), 
101 |             N is the number of detection box.
102 |         """
103 |         pass
104 | 


--------------------------------------------------------------------------------
/util/face_sdk/core/model_handler/face_recognition/FaceRecModelHandler.py:
--------------------------------------------------------------------------------
 1 | """
 2 | @author: JiXuan Xu, Jun Wang
 3 | @date: 20201015
 4 | @contact: jun21wangustc@gmail.com 
 5 | """
 6 | import logging.config
 7 | import os
 8 | 
 9 | logging.config.fileConfig(os.path.join("util", "face_sdk", "config", "logging.conf"))
10 | logger = logging.getLogger('sdk')
11 | 
12 | import numpy as np
13 | import torch
14 | 
15 | from util.face_sdk.core.model_handler.BaseModelHandler import BaseModelHandler
16 | from util.face_sdk.utils.BuzException import *
17 | 
18 | 
19 | class FaceRecModelHandler(BaseModelHandler):
20 |     """Implementation of face recognition model handler
21 | 
22 |     Attributes:
23 |         model: the face recognition model.
24 |         device: use cpu or gpu to process.
25 |         cfg(dict): testing config, inherit from the parent class.
26 |     """
27 | 
28 |     def __init__(self, model, device, cfg):
29 |         """
30 |         Init FaceRecModelHandler settings. 
31 |         """
32 |         super().__init__(model, device, cfg)
33 |         self.mean = self.cfg['mean']
34 |         self.std = self.cfg['std']
35 |         self.input_height = self.cfg['input_height']
36 |         self.input_width = self.cfg['input_width']
37 | 
38 |     def inference_on_image(self, image):
39 |         """Get the inference of the image.
40 | 
41 |         Returns:
42 |             A numpy array, the output feature, shape (512,).
43 |         """
44 |         try:
45 |             image = self._preprocess(image)
46 |         except Exception as e:
47 |             raise e
48 |         image = torch.unsqueeze(image, 0)
49 |         image = image.to(self.device)
50 |         with torch.no_grad():
51 |             feature = self.model(image).cpu().numpy()
52 |         feature = np.squeeze(feature)
53 |         return feature
54 | 
55 |     def _preprocess(self, image):
56 |         """Preprocess the input image.
57 | 
58 |         Returns:
59 |            A torch tensor, the input after preprocessing, shape: (3, 112, 112).
60 |         """
61 |         if not isinstance(image, np.ndarray):
62 |             logger.error('The input should be the ndarray read by cv2!')
63 |             raise InputError()
64 |         height, width, channels = image.shape
65 |         if height != self.input_height or width != self.input_width:
66 |             raise FalseImageSizeError()
67 |         if image.ndim == 2:
68 |             image = image[:, :, np.newaxis]
69 |         if image.ndim == 4:
70 |             image = image[:, :, :3]
71 |         if image.ndim > 4:
72 |             raise FaseChannelError(image.ndim)
73 |         image = (image.transpose((2, 0, 1)) - self.mean) / self.std
74 |         image = image.astype(np.float32)
75 |         image = torch.from_numpy(image)
76 |         return image
77 | 


--------------------------------------------------------------------------------
/util/face_sdk/core/model_loader/BaseModelLoader.py:
--------------------------------------------------------------------------------
 1 | """
 2 | @author: JiXuan Xu, Jun Wang
 3 | @date: 20201015
 4 | @contact: jun21wangustc@gmail.com 
 5 | """
 6 | import os
 7 | import sys
 8 | 
 9 | sys.path.append(os.path.join("util", "face_sdk", "models", "network_def"))
10 | import logging.config
11 | 
12 | logging.config.fileConfig(os.path.join("util", "face_sdk", "config", "logging.conf"))
13 | logger = logging.getLogger('sdk')
14 | from abc import ABCMeta, abstractmethod
15 | 
16 | import json
17 | 
18 | 
19 | class BaseModelLoader(metaclass=ABCMeta):
20 |     """Base class for all model loader.
21 |     All the model loaders need to inherit this base class, 
22 |     and each new model needs to implement the "load model" method
23 |     """
24 | 
25 |     def __init__(self, model_path, model_category, model_name, meta_file='model_meta.json'):
26 |         model_root_dir = os.path.join(model_path, model_category, model_name)
27 |         meta_file_path = os.path.join(model_root_dir, meta_file)
28 |         self.cfg = {}
29 |         try:
30 |             self.meta_conf = json.load(open(meta_file_path, 'r'))
31 |         except IOError as e:
32 |             logger.error('The configuration file meta.json was not found or failed to parse the file!')
33 |             raise e
34 |         except Exception as e:
35 |             logger.info('The configuration file format is wrong!')
36 |             raise e
37 |         else:
38 |             logger.info('Successfully parsed the model configuration file meta.json!')
39 |         # common configs for all model
40 |         self.cfg['model_path'] = model_path
41 |         self.cfg['model_category'] = model_category
42 |         self.cfg['model_name'] = model_name
43 |         self.cfg['model_type'] = self.meta_conf['model_type']
44 |         self.cfg['model_info'] = self.meta_conf['model_info']
45 |         self.cfg['model_file_path'] = os.path.join(model_root_dir, self.meta_conf['model_file'])
46 |         self.cfg['release_date'] = self.meta_conf['release_date']
47 |         self.cfg['input_height'] = self.meta_conf['input_height']
48 |         self.cfg['input_width'] = self.meta_conf['input_width']
49 | 
50 |     @abstractmethod
51 |     def load_model(self):
52 |         """Should be overridden by all subclasses.
53 |         Different models may have different configuration information,
54 |         such as mean, so each model implements its own loader
55 |         """
56 |         pass
57 | 
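
From the keys read in __init__ above, every bundled model_meta.json (for example the one under the face_detection_1.0 directory that face_crop.py ends up loading) must provide at least the fields sketched below; the values here are illustrative placeholders, and subclasses read extra keys on top of these (FaceDetModelLoader, for instance, also expects min_sizes, steps, variance, in_channel, out_channel and confidence_threshold).

    # Illustrative shape of a model_meta.json, written as a Python dict; all values are placeholders.
    example_meta = {
        "model_type": "face_detection",
        "model_info": "short human-readable description",
        "model_file": "model.pkl",          # file name inside the model directory
        "release_date": "20201019",
        "input_height": 120,
        "input_width": 120,
    }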


--------------------------------------------------------------------------------
/util/face_sdk/core/model_loader/face_alignment/FaceAlignModelLoader.py:
--------------------------------------------------------------------------------
 1 | """
 2 | @author: JiXuan Xu, Jun Wang
 3 | @date: 20201023
 4 | @contact: jun21wangustc@gmail.com 
 5 | """
 6 | import logging.config
 7 | import os
 8 | 
 9 | logging.config.fileConfig(os.path.join("util", "face_sdk", "config", "logging.conf"))
10 | logger = logging.getLogger('sdk')
11 | 
12 | import torch
13 | 
14 | from util.face_sdk.core.model_loader.BaseModelLoader import BaseModelLoader
15 | 
16 | 
17 | class FaceAlignModelLoader(BaseModelLoader):
18 | 
19 |     def __init__(self, model_path, model_category, model_name, meta_file='model_meta.json'):
20 |         logger.info('Start to analyze the face landmark model, model path: %s, model category: %s, model name: %s' %
21 |                     (model_path, model_category, model_name))
22 |         super().__init__(model_path, model_category, model_name, meta_file)
23 |         self.cfg['img_size'] = self.meta_conf['input_width']
24 | 
25 |     def load_model(self):
26 |         try:
27 |             model = torch.load(self.cfg['model_file_path'])
28 |         except Exception as e:
29 |             logger.error('The model failed to load, please check the model path: %s!'
30 |                          % self.cfg['model_file_path'])
31 |             raise e
32 |         else:
33 |             logger.info('Successfully loaded the face landmark model!')
34 |             return model, self.cfg
35 | 


--------------------------------------------------------------------------------
/util/face_sdk/core/model_loader/face_detection/FaceDetModelLoader.py:
--------------------------------------------------------------------------------
 1 | """
 2 | @author: JiXuan Xu, Jun Wang
 3 | @date: 20201019
 4 | @contact: jun21wangustc@gmail.com 
 5 | """
 6 | import logging.config
 7 | import os
 8 | 
 9 | logging.config.fileConfig(os.path.join("util", "face_sdk", "config", "logging.conf"))
10 | logger = logging.getLogger('sdk')
11 | 
12 | import torch
13 | 
14 | from util.face_sdk.core.model_loader.BaseModelLoader import BaseModelLoader
15 | 
16 | 
17 | class FaceDetModelLoader(BaseModelLoader):
18 | 
19 |     def __init__(self, model_path, model_category, model_name, meta_file='model_meta.json'):
20 |         logger.info('Start to analyze the face detection model, model path: %s, model category: %s, model name: %s' %
21 |                     (model_path, model_category, model_name))
22 |         super().__init__(model_path, model_category, model_name, meta_file)
23 |         self.cfg['min_sizes'] = self.meta_conf['min_sizes']
24 |         self.cfg['steps'] = self.meta_conf['steps']
25 |         self.cfg['variance'] = self.meta_conf['variance']
26 |         self.cfg['in_channel'] = self.meta_conf['in_channel']
27 |         self.cfg['out_channel'] = self.meta_conf['out_channel']
28 |         self.cfg['confidence_threshold'] = self.meta_conf['confidence_threshold']
29 | 
30 |     def load_model(self):
31 |         model = torch.load(self.cfg['model_file_path'])
32 |         return model, self.cfg
33 | 


--------------------------------------------------------------------------------
/util/face_sdk/core/model_loader/face_parsing/FaceParsingModelLoader.py:
--------------------------------------------------------------------------------
 1 | """
 2 | @author: fengyu, wangjun
 3 | @date: 20220620
 4 | @contact: fengyu_cnyc@163.com
 5 | """
 6 | 
 7 | import logging.config
 8 | import os
 9 | 
10 | logging.config.fileConfig(os.path.join("util", "face_sdk", "config", "logging.conf"))
11 | logger = logging.getLogger('sdk')
12 | 
13 | import torch
14 | 
15 | from util.face_sdk.core.model_loader.BaseModelLoader import BaseModelLoader
16 | 
17 | class FaceParsingModelLoader(BaseModelLoader):
18 |     def __init__(self, model_path, model_category, model_name, meta_file='model_meta.json'):
19 |         logger.info('Start to analyze the face parsing model, model path: %s, model category: %s, model name: %s' %
20 |                     (model_path, model_category, model_name))
21 |         super().__init__(model_path, model_category, model_name, meta_file)
22 | 
23 |         self.cfg['input_height'] = self.meta_conf['input_height']
24 |         self.cfg['input_width'] = self.meta_conf['input_width']
25 | 
26 |         
27 |     def load_model(self):
28 |         try:
29 |             model = torch.jit.load(self.cfg['model_file_path'])
30 |         except Exception as e:
31 |             logger.error('The model failed to load, please check the model path: %s!'
32 |                          % self.cfg['model_file_path'])
33 |             raise e
34 |         else:
35 |             logger.info('Successfully loaded the face parsing model!')
36 |             return model, self.cfg


--------------------------------------------------------------------------------
/util/face_sdk/core/model_loader/face_recognition/FaceRecModelLoader.py:
--------------------------------------------------------------------------------
 1 | """
 2 | @author: JiXuan Xu, Jun Wang
 3 | @date: 20201015
 4 | @contact: jun21wangustc@gmail.com 
 5 | """
 6 | import logging.config
 7 | import os
 8 | 
 9 | logging.config.fileConfig(os.path.join("util", "face_sdk", "config", "logging.conf"))
10 | logger = logging.getLogger('sdk')
11 | 
12 | import torch
13 | 
14 | from util.face_sdk.core.model_loader.BaseModelLoader import BaseModelLoader
15 | 
16 | 
17 | class FaceRecModelLoader(BaseModelLoader):
18 | 
19 |     def __init__(self, model_path, model_category, model_name, meta_file='model_meta.json'):
20 |         logger.info('Start to analyze the face recognition model, model path: %s, model category: %s, model name: %s' %
21 |                     (model_path, model_category, model_name))
22 |         super().__init__(model_path, model_category, model_name, meta_file)
23 |         self.cfg['mean'] = self.meta_conf['mean']
24 |         self.cfg['std'] = self.meta_conf['std']
25 | 
26 |     def load_model(self):
27 |         try:
28 |             model = torch.load(self.cfg['model_file_path'])
29 |         except Exception as e:
30 |             logger.error('The model failed to load, please check the model path: %s!'
31 |                          % self.cfg['model_file_path'])
32 |             raise e
33 |         else:
34 |             logger.info('Successfully loaded the face recognition model!')
35 |             return model, self.cfg
36 | 


--------------------------------------------------------------------------------
/util/face_sdk/face_crop.py:
--------------------------------------------------------------------------------
  1 | import glob
  2 | import logging.config
  3 | import os
  4 | import sys
  5 | from concurrent.futures import ProcessPoolExecutor
  6 | from pathlib import Path
  7 | from typing import Tuple
  8 | 
  9 | import cv2
 10 | import ffmpeg
 11 | import yaml
 12 | from numpy import ndarray
 13 | from tqdm.auto import tqdm
 14 | 
 15 | from marlin_pytorch.util import crop_with_padding
 16 | from util.face_sdk.core.model_handler.face_detection.FaceDetModelHandler import FaceDetModelHandler
 17 | from util.face_sdk.core.model_loader.face_detection.FaceDetModelLoader import FaceDetModelLoader
 18 | 
 19 | logging.config.fileConfig(os.path.join("util", "face_sdk", "config", "logging.conf"))
 20 | logger = logging.getLogger('api')
 21 | 
 22 | with open(os.path.join("util", "face_sdk", "config", "model_conf.yaml")) as f:
 23 |     model_conf = yaml.load(f, Loader=yaml.FullLoader)
 24 | 
 25 | # Common settings for all models; no need to modify.
 26 | model_path = os.path.join("util", "face_sdk", 'models')
 27 | 
 28 | # Model settings; modify along with the model.
 29 | scene = 'non-mask'
 30 | model_category = 'face_detection'
 31 | model_name = model_conf[scene][model_category]
 32 | 
 33 | logger.info('Start to load the face detection model...')
 34 | # load model
 35 | sys.path.append(os.path.join("util", "face_sdk"))
 36 | faceDetModelLoader = FaceDetModelLoader(model_path, model_category, model_name)
 37 | model, cfg = faceDetModelLoader.load_model()
 38 | faceDetModelHandler = FaceDetModelHandler(model, 'cuda:0', cfg)
 39 | 
 40 | 
 41 | def crop_face(frame, margin=1, x=0, y=0) -> Tuple[ndarray, int, int, int]:
 42 |     assert frame.ndim == 3 and frame.shape[2] == 3, "frame should be an HxWx3 color image"
 43 |     dets = faceDetModelHandler.inference_on_image(frame)
 44 |     if len(dets) > 0:
 45 |         x1, y1, x2, y2, confidence = dets[0]
 46 |         # center
 47 |         x, y = (int((x1 + x2) / 2), int((y1 + y2) / 2))
 48 |         margin = int(max(abs(x2 - x1), abs(y2 - y1)) / 2)
 49 |     # crop face
 50 |     face = crop_with_padding(frame, x - margin, x + margin, y - margin, y + margin, 0)
 51 |     face = cv2.resize(face, (224, 224))
 52 |     return face, margin, x, y
 53 | 
 54 | 
 55 | def crop_face_video(video_path: str, save_path: str, fourcc=cv2.VideoWriter_fourcc(*"mp4v"), fps=30) -> None:
 56 |     cap = cv2.VideoCapture(video_path)
 57 |     writer = cv2.VideoWriter(save_path, fourcc=fourcc, fps=fps, frameSize=(224, 224))
 58 |     x, y = 0, 0
 59 |     margin = 1
 60 | 
 61 |     while True:
 62 |         ret, frame = cap.read()
 63 |         if ret:
 64 |             face, margin, x, y = crop_face(frame, margin, x, y)
 65 |             writer.write(face)
 66 |         else:
 67 |             break
 68 | 
 69 |     cap.release()
 70 |     writer.release()
 71 | 
 72 | 
 73 | def crop_face_img(img_path: str, save_path: str):
 74 |     frame = cv2.imread(img_path)
 75 |     face = crop_face(frame)[0]
 76 |     cv2.imwrite(save_path, face)
 77 | 
 78 | 
 79 | def process_videos(video_path, output_path, ext="mp4", max_workers=8):
 80 |     if ext == "mp4":
 81 |         fourcc = cv2.VideoWriter_fourcc(*"mp4v")
 82 |     elif ext == "avi":
 83 |         fourcc = cv2.VideoWriter_fourcc(*"XVID")
 84 |     else:
 85 |         raise ValueError("ext should be mp4 or avi")
 86 | 
 87 |     Path(output_path).mkdir(parents=True, exist_ok=True)
 88 | 
 89 |     files = os.listdir(video_path)
 90 |     with ProcessPoolExecutor(max_workers=max_workers) as executor:
 91 |         futures = []
 92 | 
 93 |         for f_name in tqdm(files):
 94 |             if f_name.endswith('.' + ext):
 95 |                 source_path = os.path.join(video_path, f_name)
 96 |                 target_path = os.path.join(output_path, f_name)
 97 |                 fps = eval(ffmpeg.probe(source_path)["streams"][0]["avg_frame_rate"])
 98 |                 futures.append(executor.submit(crop_face_video, source_path, target_path, fourcc,
 99 |                     fps))
100 | 
101 |         for future in tqdm(futures):
102 |             future.result()
103 | 
104 | 
105 | def process_images(image_path: str, output_path: str, max_workers: int = 8):
106 |     Path(output_path).mkdir(parents=True, exist_ok=True)
107 |     files = glob.glob(f"{image_path}/*/*/*.jpg")
108 |     with ProcessPoolExecutor(max_workers=max_workers) as executor:
109 |         futures = []
110 | 
111 |         for file in tqdm(files):
112 |             save_path = file.replace(image_path, output_path)
113 |             Path("/".join(save_path.split("/")[:-1])).mkdir(parents=True, exist_ok=True)
114 |             futures.append(executor.submit(crop_face_img, file, save_path))
115 | 
116 |         for future in tqdm(futures):
117 |             future.result()
118 | 
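
A usage sketch for the helpers above (assumptions: the repository root is the working directory, a CUDA device is available because the detection handler above is created on cuda:0, and the data paths are placeholders):

# Sketch only: crop faces from a single image and from a folder of videos.
from util.face_sdk.face_crop import crop_face_img, process_videos

# Single image (placeholder paths); the output is a 224x224 face crop.
crop_face_img("data/raw/frame_0001.jpg", "data/cropped/frame_0001.jpg")

# Every .mp4 under a placeholder folder; outputs keep the original frame rate.
process_videos("data/raw_videos", "data/cropped_videos", ext="mp4", max_workers=4)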


--------------------------------------------------------------------------------
/util/face_sdk/face_parse.py:
--------------------------------------------------------------------------------
 1 | import glob
 2 | import logging.config
 3 | import os.path
 4 | import sys
 5 | from multiprocessing import set_start_method
 6 | from pathlib import Path
 7 | 
 8 | import cv2
 9 | import numpy as np
10 | import torch
11 | import yaml
12 | from tqdm.auto import tqdm
13 | 
14 | from util.face_sdk.core.model_handler.face_alignment.FaceAlignModelHandler import FaceAlignModelHandler
15 | from util.face_sdk.core.model_handler.face_detection.FaceDetModelHandler import FaceDetModelHandler
16 | from util.face_sdk.core.model_handler.face_parsing.FaceParsingModelHandler import FaceParsingModelHandler
17 | from util.face_sdk.core.model_loader.face_alignment.FaceAlignModelLoader import FaceAlignModelLoader
18 | from util.face_sdk.core.model_loader.face_detection.FaceDetModelLoader import FaceDetModelLoader
19 | from util.face_sdk.core.model_loader.face_parsing.FaceParsingModelLoader import FaceParsingModelLoader
20 | 
21 | mpl_logger = logging.getLogger('matplotlib')
22 | mpl_logger.setLevel(logging.WARNING)
23 | 
24 | logging.config.fileConfig(os.path.join("util", "face_sdk", "config", "logging.conf"))
25 | logger = logging.getLogger('api')
26 | 
27 | with open(os.path.join("util", "face_sdk", "config", "model_conf.yaml")) as f:
28 |     model_conf = yaml.load(f, Loader=yaml.FullLoader)
29 | 
30 | # Common settings for all models; no need to modify.
31 | model_path = os.path.join("util", "face_sdk", "models")
32 | sys.path.append(os.path.join("util", "face_sdk"))
33 | # face detection model setting.
34 | scene = 'non-mask'
35 | model_category = 'face_detection'
36 | model_name = model_conf[scene][model_category]
37 | logger.info('Start to load the face detection model...')
38 | faceDetModelLoader = FaceDetModelLoader(model_path, model_category, model_name)
39 | model, cfg = faceDetModelLoader.load_model()
40 | faceDetModelHandler = FaceDetModelHandler(model, 'cuda:0', cfg)
41 | 
42 | # face landmark model setting.
43 | model_category = 'face_alignment'
44 | model_name = model_conf[scene][model_category]
45 | logger.info('Start to load the face landmark model...')
46 | faceAlignModelLoader = FaceAlignModelLoader(model_path, model_category, model_name)
47 | model, cfg = faceAlignModelLoader.load_model()
48 | faceAlignModelHandler = FaceAlignModelHandler(model, 'cuda:0', cfg)
49 | 
50 | # face parsing model setting.
51 | scene = 'non-mask'
52 | model_category = 'face_parsing'
53 | model_name = model_conf[scene][model_category]
54 | logger.info('Start to load the face parsing model...')
55 | faceParsingModelLoader = FaceParsingModelLoader(model_path, model_category, model_name)
56 | model, cfg = faceParsingModelLoader.load_model()
57 | faceParsingModelHandler = FaceParsingModelHandler(model, 'cuda:0', cfg)
58 | 
59 | 
60 | def parse_face_img(img_path: str, output_path: str):
61 |     image = cv2.imread(img_path, cv2.IMREAD_COLOR)
62 |     dets = faceDetModelHandler.inference_on_image(image)
63 |     face_nums = dets.shape[0]
64 |     with torch.no_grad():
65 |         for i in range(face_nums):
66 |             landmarks = faceAlignModelHandler.inference_on_image(image, dets[i])
67 | 
68 |             landmarks = torch.from_numpy(landmarks[[104, 105, 54, 84, 90]]).float()
69 |             if i == 0:
70 |                 landmarks_five = landmarks.unsqueeze(0)
71 |             else:
72 |                 landmarks_five = torch.cat([landmarks_five, landmarks.unsqueeze(0)], dim=0)
73 |         try:
74 |             faces = faceParsingModelHandler.inference_on_image(face_nums, image, landmarks_five)["seg"]["logits"].cpu()
75 |             faces = faces.softmax(dim=1).argmax(dim=1).numpy()
76 |         except UnboundLocalError:
77 |             faces = np.zeros((0, 224, 224), dtype="int64")
78 |         np.save(output_path, faces)
79 | 
80 | 
81 | def check_exists(output_path: str):
82 |     return os.path.exists(output_path)
83 | 
84 | 
85 | def process_images(image_path: str, output_path: str):
86 |     Path(output_path).mkdir(parents=True, exist_ok=True)
87 |     files = glob.glob(f"{image_path}/*/*/*.jpg")
88 | 
89 |     for i, file in enumerate(tqdm(files)):
90 |         save_path = file.replace(image_path, output_path).replace(".jpg", ".npy")
91 |         Path("/".join(save_path.split("/")[:-1])).mkdir(parents=True, exist_ok=True)
92 |         parse_face_img(file, save_path)
93 | 
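
A usage sketch for process_images above (assumptions: the repository root is the working directory, CUDA is available for the three handlers created at import time, and the placeholder image tree matches the <image_path>/*/*/*.jpg layout it globs for):

# Sketch only: parse every cropped face image and save one .npy mask per image.
from util.face_sdk.face_parse import process_images

process_images("data/cropped_images", "data/parsed_masks")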


--------------------------------------------------------------------------------
/util/face_sdk/models/face_alignment/face_alignment_1.0/face_landmark_pfld.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/util/face_sdk/models/face_alignment/face_alignment_1.0/face_landmark_pfld.pkl


--------------------------------------------------------------------------------
/util/face_sdk/models/face_alignment/face_alignment_1.0/model_meta.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "model_type" : "pfld face landmark nets",
 3 |     "model_info" : "some model info",
 4 |     "model_file" : "face_landmark_pfld.pkl",
 5 |     "release_date" : "20201023",
 6 |     "input_height" : 112,
 7 |     "input_width" : 112
 8 | }
 9 | 
10 | 


--------------------------------------------------------------------------------
/util/face_sdk/models/face_alignment/face_alignment_2.0/face_landmark_pfld.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/util/face_sdk/models/face_alignment/face_alignment_2.0/face_landmark_pfld.pkl


--------------------------------------------------------------------------------
/util/face_sdk/models/face_alignment/face_alignment_2.0/model_meta.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "model_type" : "pfld mask face landmark nets",
 3 |     "model_info" : "some model info",
 4 |     "model_file" : "face_landmark_pfld.pkl",
 5 |     "release_date" : "20201229",
 6 |     "input_height" : 112,
 7 |     "input_width" : 112
 8 | }
 9 | 
10 | 


--------------------------------------------------------------------------------
/util/face_sdk/models/face_detection/face_detection_1.0/face_detection_retina.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/util/face_sdk/models/face_detection/face_detection_1.0/face_detection_retina.pkl


--------------------------------------------------------------------------------
/util/face_sdk/models/face_detection/face_detection_1.0/model_meta.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "model_type" : "retina face detect nets",
 3 |     "model_info" : "some model info",
 4 |     "model_file" : "face_detection_retina.pkl",
 5 |     "release_date" : "20201019",
 6 |     "input_height" : 120,
 7 |     "input_width" : 120,
 8 |     "min_sizes": [[16, 32], [64, 128], [256, 512]],
 9 |     "steps": [8, 16, 32],
10 |     "variance": [0.1, 0.2],
11 |     "in_channel": 256,
12 |     "out_channel": 256,
13 |     "confidence_threshold": 0.7
14 | }
15 | 
16 | 
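
The meta file above is what the model loaders read into self.meta_conf; a small sketch of inspecting it directly (the path assumes the repository root as the working directory):

# Sketch only: read the detection model's metadata as plain JSON.
import json
import os

meta_path = os.path.join("util", "face_sdk", "models", "face_detection",
                         "face_detection_1.0", "model_meta.json")
with open(meta_path) as f:
    meta = json.load(f)

print(meta["model_file"])                          # face_detection_retina.pkl
print(meta["input_height"], meta["input_width"])   # 120 120
print(meta["confidence_threshold"])                # 0.7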


--------------------------------------------------------------------------------
/util/face_sdk/models/face_detection/face_detection_2.0/face_detection_retina.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/util/face_sdk/models/face_detection/face_detection_2.0/face_detection_retina.pkl


--------------------------------------------------------------------------------
/util/face_sdk/models/face_detection/face_detection_2.0/model_meta.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "model_type" : "retina mask face detect nets",
 3 |     "model_info" : "some model info",
 4 |     "model_file" : "face_detection_retina.pkl",
 5 |     "release_date" : "20201219",
 6 |     "input_height" : 120,
 7 |     "input_width" : 120,
 8 |     "min_sizes": [[16, 32], [64, 128], [256, 512]],
 9 |     "steps": [8, 16, 32],
10 |     "variance": [0.1, 0.2],
11 |     "in_channel": 256,
12 |     "out_channel": 256,
13 |     "confidence_threshold": 0.4
14 | }
15 | 
16 | 


--------------------------------------------------------------------------------
/util/face_sdk/models/face_parsing/README.md:
--------------------------------------------------------------------------------
1 | # Face Parsing Model
2 | [face_parsing.farl.lapa](https://github.com/FacePerceiver/facer/releases/download/models-v1/face_parsing.farl.lapa.main_ema_136500_jit191.pt)
3 | Please put the pre-trained model under `util/face_sdk/models/face_parsing/face_parsing_1.0/`.
4 | 


--------------------------------------------------------------------------------
/util/face_sdk/models/face_parsing/face_parsing_1.0/model_meta.json:
--------------------------------------------------------------------------------
1 | {
2 |     "model_type" : "face_parsing.farl.lapa",
3 |     "model_info" : "some model info",
4 |     "model_file" : "face_parsing.farl.lapa.main_ema_136500_jit191.pt",
5 |     "release_date" : "20220226",
6 |     "input_height" : 448,
7 |     "input_width" : 448
8 | }
9 | 


--------------------------------------------------------------------------------
/util/face_sdk/models/face_recognition/face_recognition_1.0/face_recognition_mv.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/util/face_sdk/models/face_recognition/face_recognition_1.0/face_recognition_mv.pkl


--------------------------------------------------------------------------------
/util/face_sdk/models/face_recognition/face_recognition_1.0/model_meta.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "model_type" : "mobile face nets",
 3 |     "model_info" : "some model info",
 4 |     "model_file" : "face_recognition_mv.pkl",
 5 |     "release_date" : "20200630",
 6 |     "input_height" : 112,
 7 |     "input_width" : 112,
 8 |     "mean" : 127.5,
 9 |     "std" : 128.0
10 | }
11 | 


--------------------------------------------------------------------------------
/util/face_sdk/models/face_recognition/face_recognition_2.0/face_recognition_mv.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ControlNet/MARLIN/c9d5698f98799828b5bcd0528e408a06e3f58502/util/face_sdk/models/face_recognition/face_recognition_2.0/face_recognition_mv.pkl


--------------------------------------------------------------------------------
/util/face_sdk/models/face_recognition/face_recognition_2.0/model_meta.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "model_type" : "mobile face nets",
 3 |     "model_info" : "some model info",
 4 |     "model_file" : "face_recognition_mv.pkl",
 5 |     "release_date" : "20200630",
 6 |     "input_height" : 112,
 7 |     "input_width" : 112,
 8 |     "mean" : 127.5,
 9 |     "std" : 128.0
10 | }
11 | 


--------------------------------------------------------------------------------
/util/face_sdk/models/network_def/mobilefacenet_def.py:
--------------------------------------------------------------------------------
  1 | """
  2 | @author: Jun Wang 
  3 | @date: 20201019
  4 | @contact: jun21wangustc@gmail.com
  5 | """
  6 | # based on:
  7 | # https://github.com/TreB1eN/InsightFace_Pytorch/blob/master/model.py
  8 | 
  9 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
 10 | import torch
 11 | 
 12 | 
 13 | class Flatten(Module):
 14 | 
 15 |     def forward(self, input):
 16 |         return input.view(input.size(0), -1)
 17 | 
 18 | 
 19 | def l2_norm(input, axis=1):
 20 |     norm = torch.norm(input, 2, axis, True)
 21 |     output = torch.div(input, norm)
 22 |     return output
 23 | 
 24 | 
 25 | class Conv_block(Module):
 26 | 
 27 |     def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
 28 |         super(Conv_block, self).__init__()
 29 |         self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding,
 30 |                            bias=False)
 31 |         self.bn = BatchNorm2d(out_c)
 32 |         self.prelu = PReLU(out_c)
 33 | 
 34 |     def forward(self, x):
 35 |         x = self.conv(x)
 36 |         x = self.bn(x)
 37 |         x = self.prelu(x)
 38 |         return x
 39 | 
 40 | 
 41 | class Linear_block(Module):
 42 | 
 43 |     def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
 44 |         super(Linear_block, self).__init__()
 45 |         self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding,
 46 |                            bias=False)
 47 |         self.bn = BatchNorm2d(out_c)
 48 | 
 49 |     def forward(self, x):
 50 |         x = self.conv(x)
 51 |         x = self.bn(x)
 52 |         return x
 53 | 
 54 | 
 55 | class Depth_Wise(Module):
 56 | 
 57 |     def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
 58 |         super(Depth_Wise, self).__init__()
 59 |         self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
 60 |         self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
 61 |         self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
 62 |         self.residual = residual
 63 | 
 64 |     def forward(self, x):
 65 |         if self.residual:
 66 |             short_cut = x
 67 |         x = self.conv(x)
 68 |         x = self.conv_dw(x)
 69 |         x = self.project(x)
 70 |         if self.residual:
 71 |             output = short_cut + x
 72 |         else:
 73 |             output = x
 74 |         return output
 75 | 
 76 | 
 77 | class Residual(Module):
 78 | 
 79 |     def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
 80 |         super(Residual, self).__init__()
 81 |         modules = []
 82 |         for _ in range(num_block):
 83 |             modules.append(
 84 |                 Depth_Wise(c, c, residual=True, kernel=kernel, padding=padding, stride=stride, groups=groups))
 85 |         self.model = Sequential(*modules)
 86 | 
 87 |     def forward(self, x):
 88 |         return self.model(x)
 89 | 
 90 | 
 91 | class MobileFaceNet(Module):
 92 | 
 93 |     def __init__(self, embedding_size, out_h, out_w):
 94 |         super(MobileFaceNet, self).__init__()
 95 |         self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
 96 |         self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
 97 |         self.conv_23 = Depth_Wise(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
 98 |         self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
 99 |         self.conv_34 = Depth_Wise(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
100 |         self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
101 |         self.conv_45 = Depth_Wise(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
102 |         self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
103 |         self.conv_6_sep = Conv_block(128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
104 |         # self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(7,7), stride=(1, 1), padding=(0, 0))
105 |         # self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(4,7), stride=(1, 1), padding=(0, 0))
106 |         self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(out_h, out_w), stride=(1, 1), padding=(0, 0))
107 |         self.conv_6_flatten = Flatten()
108 |         self.linear = Linear(512, embedding_size, bias=False)
109 |         self.bn = BatchNorm1d(embedding_size)
110 | 
111 |     def forward(self, x):
112 |         out = self.conv1(x)
113 |         out = self.conv2_dw(out)
114 |         out = self.conv_23(out)
115 |         out = self.conv_3(out)
116 |         out = self.conv_34(out)
117 |         out = self.conv_4(out)
118 |         out = self.conv_45(out)
119 |         out = self.conv_5(out)
120 |         out = self.conv_6_sep(out)
121 |         out = self.conv_6_dw(out)
122 |         out = self.conv_6_flatten(out)
123 |         out = self.linear(out)
124 |         out = self.bn(out)
125 |         return l2_norm(out)
126 | 
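
A shape-check sketch for MobileFaceNet as defined above; with a 112x112 input the feature map reaching conv_6_dw is 7x7, so out_h = out_w = 7 (run from the repository root; the weights are random and the 512-d embedding size is only an illustrative choice):

# Sketch only: instantiate MobileFaceNet and verify the embedding shape.
import torch

from util.face_sdk.models.network_def.mobilefacenet_def import MobileFaceNet

net = MobileFaceNet(embedding_size=512, out_h=7, out_w=7).eval()  # eval(): BatchNorm1d needs it for batch size 1
with torch.no_grad():
    emb = net(torch.randn(1, 3, 112, 112))  # one 112x112 face crop
print(emb.shape)        # torch.Size([1, 512])
print(emb.norm(dim=1))  # ~1.0 because of the final l2_norm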


--------------------------------------------------------------------------------
/util/face_sdk/models/network_def/mobilev3_pfld.py:
--------------------------------------------------------------------------------
  1 | # derived from:
  2 | # https://github.com/Hsintao/pfld_106_face_landmarks/blob/master/models/mobilev3_pfld.py
  3 | 
  4 | import torch
  5 | import torch.nn as nn
  6 | import torch.nn.functional as F
  7 | 
  8 | 
  9 | def conv_bn(inp, oup, kernel_size, stride, padding=1, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d,
 10 |     nlin_layer=nn.ReLU
 11 | ):
 12 |     return nn.Sequential(
 13 |         conv_layer(inp, oup, kernel_size, stride, padding, bias=False),
 14 |         norm_layer(oup),
 15 |         nlin_layer(inplace=True)
 16 |     )
 17 | 
 18 | 
 19 | def conv_1x1_bn(inp, oup, conv_layer=nn.Conv2d, norm_layer=nn.BatchNorm2d, nlin_layer=nn.ReLU):
 20 |     return nn.Sequential(
 21 |         conv_layer(inp, oup, 1, 1, 0, bias=False),
 22 |         norm_layer(oup),
 23 |         nlin_layer(inplace=True)
 24 |     )
 25 | 
 26 | 
 27 | class Hswish(nn.Module):
 28 | 
 29 |     def __init__(self, inplace=True):
 30 |         super(Hswish, self).__init__()
 31 |         self.inplace = inplace
 32 | 
 33 |     def forward(self, x):
 34 |         return x * F.relu6(x + 3., inplace=self.inplace) / 6.
 35 | 
 36 | 
 37 | class Hsigmoid(nn.Module):
 38 | 
 39 |     def __init__(self, inplace=True):
 40 |         super(Hsigmoid, self).__init__()
 41 |         self.inplace = inplace
 42 | 
 43 |     def forward(self, x):
 44 |         return F.relu6(x + 3., inplace=self.inplace) / 6.
 45 | 
 46 | 
 47 | class SEModule(nn.Module):
 48 | 
 49 |     def __init__(self, channel, reduction=4):
 50 |         super(SEModule, self).__init__()
 51 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
 52 |         self.fc = nn.Sequential(
 53 |             nn.Linear(channel, channel // reduction, bias=False),
 54 |             nn.ReLU(inplace=True),
 55 |             nn.Linear(channel // reduction, channel, bias=False),
 56 |             Hsigmoid()
 57 |         )
 58 | 
 59 |     def forward(self, x):
 60 |         b, c, h, w = x.size()
 61 |         # F.avg_pool2d()
 62 |         y = self.avg_pool(x).view(b, c)
 63 |         y = self.fc(y).view(b, c, 1, 1)
 64 |         return x * y
 65 | 
 66 | 
 67 | class Identity(nn.Module):
 68 | 
 69 |     def __init__(self, channel):
 70 |         super(Identity, self).__init__()
 71 | 
 72 |     def forward(self, x):
 73 |         return x
 74 | 
 75 | 
 76 | class MobileBottleneck(nn.Module):
 77 | 
 78 |     def __init__(self, inp, oup, kernel, stride, exp, se=False, nl='RE'):
 79 |         super(MobileBottleneck, self).__init__()
 80 |         assert stride in [1, 2]
 81 |         assert kernel in [3, 5]
 82 |         padding = (kernel - 1) // 2
 83 |         self.use_res_connect = stride == 1 and inp == oup
 84 | 
 85 |         conv_layer = nn.Conv2d
 86 |         norm_layer = nn.BatchNorm2d
 87 |         if nl == 'RE':
 88 |             nlin_layer = nn.ReLU  # or ReLU6
 89 |         elif nl == 'HS':
 90 |             nlin_layer = Hswish
 91 |         else:
 92 |             raise NotImplementedError
 93 |         if se:
 94 |             SELayer = SEModule
 95 |         else:
 96 |             SELayer = Identity
 97 | 
 98 |         self.conv = nn.Sequential(
 99 |             # pw
100 |             conv_layer(inp, exp, 1, 1, 0, bias=False),
101 |             norm_layer(exp),
102 |             nlin_layer(inplace=True),
103 |             # dw
104 |             conv_layer(exp, exp, kernel, stride, padding, groups=exp, bias=False),
105 |             norm_layer(exp),
106 |             SELayer(exp),
107 |             nlin_layer(inplace=True),
108 |             # pw-linear
109 |             conv_layer(exp, oup, 1, 1, 0, bias=False),
110 |             norm_layer(oup),
111 |         )
112 | 
113 |     def forward(self, x):
114 |         if self.use_res_connect:
115 |             return x + self.conv(x)
116 |         else:
117 |             return self.conv(x)
118 | 
119 | 
120 | class PFLDInference(nn.Module):
121 | 
122 |     def __init__(self):
123 |         super(PFLDInference, self).__init__()
124 |         self.use_attention = True
125 |         self.conv_bn1 = conv_bn(3, 16, 3, stride=1, nlin_layer=Hswish)
126 |         self.conv_bn2 = MobileBottleneck(16, 16, 3, 1, 16, False, 'RE')
127 | 
128 |         self.conv3_1 = MobileBottleneck(16, 24, 3, 2, 64, False, 'RE')
129 | 
130 |         self.block3_2 = MobileBottleneck(24, 24, 3, 1, 72, False, "RE")
131 |         self.block3_3 = MobileBottleneck(24, 40, 5, 2, 72, self.use_attention, "RE")
132 |         self.block3_4 = MobileBottleneck(40, 40, 5, 1, 120, self.use_attention, "RE")
133 |         self.block3_5 = MobileBottleneck(40, 40, 5, 1, 120, self.use_attention, "RE")
134 | 
135 |         self.conv4_1 = MobileBottleneck(40, 80, 3, 2, 240, False, "RE")
136 | 
137 |         self.conv5_1 = MobileBottleneck(80, 80, 3, 1, 200, False, "HS")
138 |         self.block5_2 = MobileBottleneck(80, 112, 3, 1, 480, self.use_attention, "HS")
139 |         self.block5_3 = MobileBottleneck(112, 112, 3, 1, 672, self.use_attention, "HS")
140 |         self.block5_4 = MobileBottleneck(112, 160, 3, 1, 672, self.use_attention, "HS")
141 | 
142 |         self.conv6_1 = MobileBottleneck(160, 16, 3, 1, 320, False, "HS")  # [16, 14, 14]
143 | 
144 |         self.conv7 = nn.Conv2d(16, 32, 3, 2, padding=1)
145 |         self.conv8 = nn.Conv2d(32, 128, 7, 1, 0)
146 |         self.avg_pool1 = nn.AvgPool2d(14)
147 |         self.avg_pool2 = nn.AvgPool2d(7)
148 |         self.fc = nn.Linear(176, 106 * 2)
149 | 
150 |     def forward(self, x):  # x: 3, 112, 112
151 |         x = self.conv_bn1(x)  # [64, 56, 56]
152 |         x = self.conv_bn2(x)  # [64, 56, 56]
153 |         x = self.conv3_1(x)
154 |         x = self.block3_2(x)
155 |         x = self.block3_3(x)
156 |         x = self.block3_4(x)
157 |         out1 = self.block3_5(x)
158 | 
159 |         x = self.conv4_1(out1)
160 | 
161 |         x = self.conv5_1(x)
162 |         x = self.block5_2(x)
163 |         x = self.block5_3(x)
164 |         x = self.block5_4(x)
165 |         x = self.conv6_1(x)
166 |         x1 = self.avg_pool1(x)
167 |         x1 = x1.view(x1.size(0), -1)
168 | 
169 |         x = self.conv7(x)
170 |         x2 = self.avg_pool2(x)
171 |         x2 = x2.view(x2.size(0), -1)
172 | 
173 |         x3 = self.conv8(x)
174 |         x3 = x3.view(x1.size(0), -1)
175 | 
176 |         multi_scale = torch.cat([x1, x2, x3], 1)
177 |         landmarks = self.fc(multi_scale)
178 | 
179 |         return out1, landmarks
180 | 
181 | 
182 | class AuxiliaryNet(nn.Module):
183 | 
184 |     def __init__(self):
185 |         super(AuxiliaryNet, self).__init__()
186 |         self.conv1 = conv_bn(40, 128, 3, 2)
187 |         self.conv2 = conv_bn(128, 128, 3, 1)
188 |         self.conv3 = conv_bn(128, 32, 3, 2)
189 |         self.conv4 = conv_bn(32, 128, 3, 1, padding=0)
190 |         self.max_pool1 = nn.MaxPool2d(5)
191 |         self.fc1 = nn.Linear(128, 32)
192 |         self.fc2 = nn.Linear(32, 3)
193 | 
194 |     def forward(self, x):
195 |         x = self.conv1(x)
196 |         x = self.conv2(x)
197 |         x = self.conv3(x)
198 |         x = self.conv4(x)
199 |         x = self.max_pool1(x)
200 |         x = x.view(x.size(0), -1)
201 |         x = self.fc1(x)
202 |         x = self.fc2(x)
203 | 
204 |         return x
205 | 
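
A shape-check sketch for the PFLD networks above; with a 112x112 input, PFLDInference returns the 40-channel feature map consumed by AuxiliaryNet plus 106 (x, y) landmarks, and AuxiliaryNet predicts three pose values (run from the repository root; weights are random):

# Sketch only: run randomly initialized PFLD nets and check the output shapes.
import torch

from util.face_sdk.models.network_def.mobilev3_pfld import AuxiliaryNet, PFLDInference

backbone = PFLDInference().eval()
aux = AuxiliaryNet().eval()
with torch.no_grad():
    x = torch.randn(1, 3, 112, 112)     # one aligned 112x112 face crop
    features, landmarks = backbone(x)
    angles = aux(features)
print(features.shape)   # torch.Size([1, 40, 28, 28])
print(landmarks.shape)  # torch.Size([1, 212]) -> 106 (x, y) pairs
print(angles.shape)     # torch.Size([1, 3])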


--------------------------------------------------------------------------------
/util/face_sdk/models/network_def/retinaface_def.py:
--------------------------------------------------------------------------------
  1 | """
  2 | @author: JiXuan Xu, Jun Wang
  3 | @date: 20201019
  4 | @contact: jun21wangustc@gmail.com 
  5 | """
  6 | 
  7 | # based on:
  8 | # https://github.com/biubug6/Pytorch_Retinaface/blob/master/models/retinaface.py
  9 | 
 10 | import torch
 11 | import torch.nn as nn
 12 | import torch.nn.functional as F
 13 | import torchvision.models._utils as _utils
 14 | 
 15 | 
 16 | def conv_bn(inp, oup, stride=1, leaky=0):
 17 |     return nn.Sequential(
 18 |         nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
 19 |         nn.BatchNorm2d(oup),
 20 |         nn.LeakyReLU(negative_slope=leaky, inplace=True)
 21 |     )
 22 | 
 23 | 
 24 | def conv_bn_no_relu(inp, oup, stride):
 25 |     return nn.Sequential(
 26 |         nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
 27 |         nn.BatchNorm2d(oup),
 28 |     )
 29 | 
 30 | 
 31 | def conv_bn1X1(inp, oup, stride, leaky=0):
 32 |     return nn.Sequential(
 33 |         nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False),
 34 |         nn.BatchNorm2d(oup),
 35 |         nn.LeakyReLU(negative_slope=leaky, inplace=True)
 36 |     )
 37 | 
 38 | 
 39 | def conv_dw(inp, oup, stride, leaky=0.1):
 40 |     return nn.Sequential(
 41 |         nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
 42 |         nn.BatchNorm2d(inp),
 43 |         nn.LeakyReLU(negative_slope=leaky, inplace=True),
 44 | 
 45 |         nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
 46 |         nn.BatchNorm2d(oup),
 47 |         nn.LeakyReLU(negative_slope=leaky, inplace=True),
 48 |     )
 49 | 
 50 | 
 51 | class SSH(nn.Module):
 52 | 
 53 |     def __init__(self, in_channel, out_channel):
 54 |         super(SSH, self).__init__()
 55 |         assert out_channel % 4 == 0
 56 |         leaky = 0
 57 |         if (out_channel <= 64):
 58 |             leaky = 0.1
 59 |         self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)
 60 | 
 61 |         self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky)
 62 |         self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
 63 | 
 64 |         self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky)
 65 |         self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
 66 | 
 67 |     def forward(self, input):
 68 |         conv3X3 = self.conv3X3(input)
 69 | 
 70 |         conv5X5_1 = self.conv5X5_1(input)
 71 |         conv5X5 = self.conv5X5_2(conv5X5_1)
 72 | 
 73 |         conv7X7_2 = self.conv7X7_2(conv5X5_1)
 74 |         conv7X7 = self.conv7x7_3(conv7X7_2)
 75 | 
 76 |         out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
 77 |         out = F.relu(out)
 78 |         return out
 79 | 
 80 | 
 81 | class FPN(nn.Module):
 82 | 
 83 |     def __init__(self, in_channels_list, out_channels):
 84 |         super(FPN, self).__init__()
 85 |         leaky = 0
 86 |         if (out_channels <= 64):
 87 |             leaky = 0.1
 88 |         self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
 89 |         self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
 90 |         self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)
 91 | 
 92 |         self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
 93 |         self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)
 94 | 
 95 |     def forward(self, input):
 96 |         # names = list(input.keys())
 97 |         input = list(input.values())
 98 | 
 99 |         output1 = self.output1(input[0])
100 |         output2 = self.output2(input[1])
101 |         output3 = self.output3(input[2])
102 | 
103 |         up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode="nearest")
104 |         output2 = output2 + up3
105 |         output2 = self.merge2(output2)
106 | 
107 |         up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode="nearest")
108 |         output1 = output1 + up2
109 |         output1 = self.merge1(output1)
110 | 
111 |         out = [output1, output2, output3]
112 |         return out
113 | 
114 | 
115 | class MobileNetV1(nn.Module):
116 | 
117 |     def __init__(self):
118 |         super(MobileNetV1, self).__init__()
119 |         self.stage1 = nn.Sequential(
120 |             conv_bn(3, 8, 2, leaky=0.1),  # 3
121 |             conv_dw(8, 16, 1),  # 7
122 |             conv_dw(16, 32, 2),  # 11
123 |             conv_dw(32, 32, 1),  # 19
124 |             conv_dw(32, 64, 2),  # 27
125 |             conv_dw(64, 64, 1),  # 43
126 |         )
127 |         self.stage2 = nn.Sequential(
128 |             conv_dw(64, 128, 2),  # 43 + 16 = 59
129 |             conv_dw(128, 128, 1),  # 59 + 32 = 91
130 |             conv_dw(128, 128, 1),  # 91 + 32 = 123
131 |             conv_dw(128, 128, 1),  # 123 + 32 = 155
132 |             conv_dw(128, 128, 1),  # 155 + 32 = 187
133 |             conv_dw(128, 128, 1),  # 187 + 32 = 219
134 |         )
135 |         self.stage3 = nn.Sequential(
136 |             conv_dw(128, 256, 2),  # 219 + 32 = 251
137 |             conv_dw(256, 256, 1),  # 251 + 64 = 315
138 |         )
139 |         self.avg = nn.AdaptiveAvgPool2d((1, 1))
140 |         self.fc = nn.Linear(256, 1000)
141 | 
142 |     def forward(self, x):
143 |         x = self.stage1(x)
144 |         x = self.stage2(x)
145 |         x = self.stage3(x)
146 |         x = self.avg(x)
147 |         # x = self.model(x)
148 |         x = x.view(-1, 256)
149 |         x = self.fc(x)
150 |         return x
151 | 
152 | 
153 | class ClassHead(nn.Module):
154 | 
155 |     def __init__(self, inchannels=512, num_anchors=3):
156 |         super(ClassHead, self).__init__()
157 |         self.num_anchors = num_anchors
158 |         self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0)
159 | 
160 |     def forward(self, x):
161 |         out = self.conv1x1(x)
162 |         out = out.permute(0, 2, 3, 1).contiguous()
163 | 
164 |         return out.view(out.shape[0], -1, 2)
165 | 
166 | 
167 | class BboxHead(nn.Module):
168 | 
169 |     def __init__(self, inchannels=512, num_anchors=3):
170 |         super(BboxHead, self).__init__()
171 |         self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0)
172 | 
173 |     def forward(self, x):
174 |         out = self.conv1x1(x)
175 |         out = out.permute(0, 2, 3, 1).contiguous()
176 | 
177 |         return out.view(out.shape[0], -1, 4)
178 | 
179 | 
180 | class LandmarkHead(nn.Module):
181 | 
182 |     def __init__(self, inchannels=512, num_anchors=3):
183 |         super(LandmarkHead, self).__init__()
184 |         self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0)
185 | 
186 |     def forward(self, x):
187 |         out = self.conv1x1(x)
188 |         out = out.permute(0, 2, 3, 1).contiguous()
189 | 
190 |         return out.view(out.shape[0], -1, 10)
191 | 
192 | 
193 | class RetinaFace(nn.Module):
194 | 
195 |     def __init__(self, cfg=None, phase='train'):
196 |         """
197 |         :param cfg:  Network related settings.
198 |         :param phase: train or test.
199 |         """
200 |         super(RetinaFace, self).__init__()
201 |         self.phase = phase
202 |         backbone = MobileNetV1()
203 | 
204 |         self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers'])
205 |         in_channels_stage2 = cfg['in_channel']
206 |         in_channels_list = [
207 |             in_channels_stage2 * 2,
208 |             in_channels_stage2 * 4,
209 |             in_channels_stage2 * 8,
210 |         ]
211 |         out_channels = cfg['out_channel']
212 |         self.fpn = FPN(in_channels_list, out_channels)
213 |         self.ssh1 = SSH(out_channels, out_channels)
214 |         self.ssh2 = SSH(out_channels, out_channels)
215 |         self.ssh3 = SSH(out_channels, out_channels)
216 | 
217 |         self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
218 |         self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
219 |         self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])
220 | 
221 |     def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2):
222 |         classhead = nn.ModuleList()
223 |         for i in range(fpn_num):
224 |             classhead.append(ClassHead(inchannels, anchor_num))
225 |         return classhead
226 | 
227 |     def _make_bbox_head(self, fpn_num=3, inchannels=64, anchor_num=2):
228 |         bboxhead = nn.ModuleList()
229 |         for i in range(fpn_num):
230 |             bboxhead.append(BboxHead(inchannels, anchor_num))
231 |         return bboxhead
232 | 
233 |     def _make_landmark_head(self, fpn_num=3, inchannels=64, anchor_num=2):
234 |         landmarkhead = nn.ModuleList()
235 |         for i in range(fpn_num):
236 |             landmarkhead.append(LandmarkHead(inchannels, anchor_num))
237 |         return landmarkhead
238 | 
239 |     def forward(self, inputs):
240 |         out = self.body(inputs)
241 | 
242 |         # FPN
243 |         fpn = self.fpn(out)
244 | 
245 |         # SSH
246 |         feature1 = self.ssh1(fpn[0])
247 |         feature2 = self.ssh2(fpn[1])
248 |         feature3 = self.ssh3(fpn[2])
249 |         features = [feature1, feature2, feature3]
250 | 
251 |         bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
252 |         classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1)
253 |         ldm_regressions = torch.cat([self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1)
254 | 
255 |         if self.phase == 'train':
256 |             output = (bbox_regressions, classifications, ldm_regressions)
257 |         else:
258 |             output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
259 |         return output
260 | 
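
A sketch instantiating RetinaFace above with a config mirroring the upstream Pytorch_Retinaface MobileNet settings; the return_layers, in_channel and out_channel values below are assumptions taken from that upstream config, not from this repository's meta files, and the weights are random (run from the repository root):

# Sketch only: build RetinaFace and check the three output tensors in test phase.
import torch

from util.face_sdk.models.network_def.retinaface_def import RetinaFace

cfg = {
    "return_layers": {"stage1": 1, "stage2": 2, "stage3": 3},  # assumed, from the upstream mobilenet config
    "in_channel": 32,   # the MobileNetV1 stages then expose 64/128/256 channels
    "out_channel": 64,
}
net = RetinaFace(cfg=cfg, phase="test").eval()
with torch.no_grad():
    bbox_reg, class_scores, ldm_reg = net(torch.randn(1, 3, 120, 120))  # input size from model_meta.json
print(bbox_reg.shape, class_scores.shape, ldm_reg.shape)
# torch.Size([1, 610, 4]) torch.Size([1, 610, 2]) torch.Size([1, 610, 10])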


--------------------------------------------------------------------------------
/util/face_sdk/utils/BuzException.py:
--------------------------------------------------------------------------------
 1 | """
 2 | @author: JiXuan Xu, Jun Wang
 3 | @date: 20201015
 4 | @contact: jun21wangustc@gmail.com 
 5 | """
 6 | 
 7 | 
 8 | # All self-defined exceptions are derived from BuzException.
 9 | class BuzException(Exception):
10 |     pass
11 | 
12 | 
13 | class InputError(BuzException):
14 | 
15 |     def __init__(self):
16 |         pass
17 | 
18 |     def __str__(self):
19 |         return ("Input type error!")
20 | 
21 | 
22 | ###############################################
23 | # All image-related exceptions.
24 | ###############################################
25 | class ImageException(BuzException):
26 |     pass
27 | 
28 | 
29 | class EmptyImageError(ImageException):
30 | 
31 |     def __init__(self):
32 |         pass
33 | 
34 |     def __str__(self):
35 |         return ("The input image is empty.")
36 | 
37 | 
38 | class FalseImageSizeError(ImageException):
39 | 
40 |     def __init__(self):
41 |         pass
42 | 
43 |     def __str__(self):
44 |         return ("The input image size is false.")
45 | 
46 | 
47 | class FaseChannelError(ImageException):
48 | 
49 |     def __init__(self, channel):
50 |         self.channel = channel
51 | 
52 |     def __str__(self):
53 |         return ("Input channel {} is invalid(only 2, 3, 4 channel is support.),".format(repr(self.channel)))
54 | 


--------------------------------------------------------------------------------
/util/face_sdk/utils/draw.py:
--------------------------------------------------------------------------------
  1 | # based on:
  2 | # https://github.com/FacePerceiver/facer/blob/main/facer/draw.py
  3 | import colorsys
  4 | import random
  5 | from typing import Dict
  6 | 
  7 | import numpy as np
  8 | import torch
  9 | from skimage.draw import circle_perimeter_aa
 10 | 
 11 | 
 12 | def _gen_random_colors(N, bright=True):
 13 |     brightness = 1.0 if bright else 0.7
 14 |     hsv = [(i / N, 1, brightness) for i in range(N)]
 15 |     colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
 16 |     random.shuffle(colors)
 17 |     return colors
 18 | 
 19 | 
 20 | _static_label_colors = [
 21 |                            np.array((1.0, 1.0, 1.0), np.float32),
 22 |                            np.array((255, 250, 79), np.float32) / 255.0,  # face
 23 |                            np.array([255, 125, 138], np.float32) / 255.0,  # lb
 24 |                            np.array([213, 32, 29], np.float32) / 255.0,  # rb
 25 |                            np.array([0, 144, 187], np.float32) / 255.0,  # le
 26 |                            np.array([0, 196, 253], np.float32) / 255.0,  # re
 27 |                            np.array([255, 129, 54], np.float32) / 255.0,  # nose
 28 |                            np.array([88, 233, 135], np.float32) / 255.0,  # ulip
 29 |                            np.array([0, 117, 27], np.float32) / 255.0,  # llip
 30 |                            np.array([255, 76, 249], np.float32) / 255.0,  # imouth
 31 |                            np.array((1.0, 0.0, 0.0), np.float32),  # hair
 32 |                            np.array((255, 250, 100), np.float32) / 255.0,  # lr
 33 |                            np.array((255, 250, 100), np.float32) / 255.0,  # rr
 34 |                            np.array((250, 245, 50), np.float32) / 255.0,  # neck
 35 |                            np.array((0.0, 1.0, 0.5), np.float32),  # cloth
 36 |                            np.array((1.0, 0.0, 0.5), np.float32),
 37 |                        ] + _gen_random_colors(256)
 38 | 
 39 | _names_in_static_label_colors = [
 40 |     'background', 'face', 'lb', 'rb', 'le', 're', 'nose',
 41 |     'ulip', 'llip', 'imouth', 'hair', 'lr', 'rr', 'neck',
 42 |     'cloth', 'eyeg', 'hat', 'earr'
 43 | ]
 44 | 
 45 | 
 46 | def select_data(selection, data):
 47 |     if isinstance(data, dict):
 48 |         return {name: select_data(selection, val) for name, val in data.items()}
 49 |     elif isinstance(data, (list, tuple)):
 50 |         return [select_data(selection, val) for val in data]
 51 |     elif isinstance(data, torch.Tensor):
 52 |         return data[selection]
 53 |     return data
 54 | 
 55 | 
 56 | def _blend_labels(image, labels, label_names_dict=None,
 57 |     default_alpha=0.6, color_offset=None
 58 | ):
 59 |     assert labels.ndim == 2
 60 |     bg_mask = labels == 0
 61 |     if label_names_dict is None:
 62 |         colors = _static_label_colors
 63 |     else:
 64 |         colors = [np.array((1.0, 1.0, 1.0), np.float32)]
 65 |         for i in range(1, labels.max() + 1):
 66 |             if isinstance(label_names_dict, dict) and i not in label_names_dict:
 67 |                 bg_mask = np.logical_or(bg_mask, labels == i)
 68 |                 colors.append(np.zeros((3)))
 69 |                 continue
 70 |             label_name = label_names_dict[i]
 71 |             if label_name in _names_in_static_label_colors:
 72 |                 color = _static_label_colors[
 73 |                     _names_in_static_label_colors.index(
 74 |                         label_name)]
 75 |             else:
 76 |                 color = np.array((1.0, 1.0, 1.0), np.float32)
 77 |             colors.append(color)
 78 | 
 79 |     if color_offset is not None:
 80 |         ncolors = []
 81 |         for c in colors:
 82 |             nc = np.array(c)
 83 |             if (nc != np.zeros(3)).any():
 84 |                 nc += color_offset
 85 |             ncolors.append(nc)
 86 |         colors = ncolors
 87 | 
 88 |     if image is None:
 89 |         image = orig_image = np.zeros(
 90 |             [labels.shape[0], labels.shape[1], 3], np.float32)
 91 |         alpha = 1.0
 92 |     else:
 93 |         orig_image = image / np.max(image)
 94 |         image = orig_image * (1.0 - default_alpha)
 95 |         alpha = default_alpha
 96 |     for i in range(1, np.max(labels) + 1):
 97 |         image += alpha * \
 98 |                  np.tile(
 99 |                      np.expand_dims(
100 |                          (labels == i).astype(np.float32), -1),
101 |                      [1, 1, 3]) * colors[(i) % len(colors)]
102 |     image[np.where(image > 1.0)] = 1.0
103 |     image[np.where(image < 0)] = 0.0
104 |     image[np.where(bg_mask)] = orig_image[np.where(bg_mask)]
105 |     return image
106 | 
107 | 
108 | def _draw_hwc(image: torch.Tensor, data: Dict[str, torch.Tensor]):
109 |     dtype = image.dtype
110 |     h, w, _ = image.shape
111 | 
112 |     for tag, batch_content in data.items():
113 |         if tag == 'points':
114 |             for content in batch_content:
115 |                 # content: npoints x 2
116 |                 for x, y in content:
117 |                     x = max(min(int(x), w - 1), 0)
118 |                     y = max(min(int(y), h - 1), 0)
119 |                     rr, cc, val = circle_perimeter_aa(y, x, 1)
120 |                     valid = np.all([rr >= 0, rr < h, cc >= 0, cc < w], axis=0)
121 |                     rr = rr[valid]
122 |                     cc = cc[valid]
123 |                     val = val[valid]
124 |                     val = val[:, None][:, [0, 0, 0]]
125 |                     image[rr, cc] = image[rr, cc] * (1.0 - val) + val * 255
126 | 
127 |         if tag == 'seg':
128 |             label_names = batch_content['label_names']
129 |             for seg_logits in batch_content['logits']:
130 |                 # content: nclasses x h x w
131 |                 seg_probs = seg_logits.softmax(dim=0)
132 |                 seg_labels = seg_probs.argmax(dim=0).cpu().numpy()
133 |                 image = (_blend_labels(image.astype(np.float32) /
134 |                                        255, seg_labels,
135 |                                        label_names_dict=label_names) * 255).astype(dtype)
136 | 
137 |     return torch.from_numpy(image).cuda()
138 | 
139 | 
140 | def draw_bchw(images, data):
141 |     image = _draw_hwc(images, data).permute(2, 0, 1).unsqueeze(0)
142 |     return image
143 | 


--------------------------------------------------------------------------------
/util/face_sdk/utils/lms_trans.py:
--------------------------------------------------------------------------------
 1 | """
 2 | @author: JiXuan Xu, Jun Wang
 3 | @date: 20201015
 4 | @contact: jun21wangustc@gmail.com 
 5 | """
 6 | # it's an approximate map
 7 | # 15 --> (99+103)/2
 8 | # 17, 19; 20, 22; 16; 9 will be used in face crop (25 points)
 9 | lms25_2_lms106 = {
10 |     1: 105, 2: 106, 3: 34, 4: 38, 5: 43,
11 |     6: 47, 7: 52, 8: 55, 9: 88, 10: 94,
12 |     11: 85, 12: 91, 13: 63, 14: 59, 15: 99,
13 |     16: 61, 17: 71, 18: 73, 19: 67, 20: 80,
14 |     21: 82, 22: 76, 23: 36, 24: 45, 25: 17
15 | }
16 | 
17 | # 1: left eye center
18 | # 2: right eye center
19 | # 3: nose tip
20 | # 4: left mouth corner
21 | # 5: right mouth corner
22 | lms5_2_lms25 = {1: 1, 2: 2, 3: 8, 4: 11, 5: 12}
23 | lms5_2_lms106 = {1: 105, 2: 106, 3: 55, 4: 85, 5: 91}
24 | 
25 | 
26 | def lms106_2_lms25(lms_106):
27 |     lms25 = []
28 |     for cur_point_index in range(25):
29 |         cur_point_id = cur_point_index + 1
30 |         point_id_106 = lms25_2_lms106[cur_point_id]
31 |         cur_point_index_106 = point_id_106 - 1
32 |         cur_point_x = lms_106[cur_point_index_106 * 2]
33 |         cur_point_y = lms_106[cur_point_index_106 * 2 + 1]
34 |         lms25.append(cur_point_x)
35 |         lms25.append(cur_point_y)
36 |     return lms25
37 | 
38 | 
39 | def lms106_2_lms5(lms_106):
40 |     lms5 = []
41 |     for cur_point_index in range(5):
42 |         cur_point_id = cur_point_index + 1
43 |         point_id_106 = lms5_2_lms106[cur_point_id]
44 |         cur_point_index_106 = point_id_106 - 1
45 |         cur_point_x = lms_106[cur_point_index_106 * 2]
46 |         cur_point_y = lms_106[cur_point_index_106 * 2 + 1]
47 |         lms5.append(cur_point_x)
48 |         lms5.append(cur_point_y)
49 |     return lms5
50 | 
51 | 
52 | def lms25_2_lms5(lms_25):
53 |     lms5 = []
54 |     for cur_point_index in range(5):
55 |         cur_point_id = cur_point_index + 1
56 |         point_id_25 = lms5_2_lms25[cur_point_id]
57 |         cur_point_index_25 = point_id_25 - 1
58 |         cur_point_x = lms_25[cur_point_index_25 * 2]
59 |         cur_point_y = lms_25[cur_point_index_25 * 2 + 1]
60 |         lms5.append(cur_point_x)
61 |         lms5.append(cur_point_y)
62 |     return lms5
63 | 
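
A tiny sketch of the re-indexing helpers above, using a flat [x1, y1, x2, y2, ...] list of 106 placeholder points (run from the repository root):

# Sketch only: map a flat 106-point landmark list down to 25 and 5 points.
import random

from util.face_sdk.utils.lms_trans import lms106_2_lms25, lms106_2_lms5

lms106 = [random.uniform(0, 112) for _ in range(106 * 2)]  # placeholder coordinates
print(len(lms106_2_lms25(lms106)))  # 50 -> 25 (x, y) pairs
print(len(lms106_2_lms5(lms106)))   # 10 -> 5 (x, y) pairs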


--------------------------------------------------------------------------------
/util/face_sdk/utils/show.py:
--------------------------------------------------------------------------------
 1 | # based on:
 2 | # https://github.com/FacePerceiver/facer/blob/main/facer/show.py
 3 | import math
 4 | from typing import Optional
 5 | 
 6 | import matplotlib.pyplot as plt
 7 | import torch
 8 | from PIL import Image
 9 | 
10 | 
11 | def bchw2hwc(images: torch.Tensor, nrows: Optional[int] = None, border: int = 2,
12 |     background_value: float = 0
13 | ) -> torch.Tensor:
14 |     """ make a grid image from an image batch.
15 |     Args:
16 |         images (torch.Tensor): input image batch.
17 |         nrows: rows of grid.
18 |         border: border size in pixel.
19 |         background_value: color value of background.
20 |     """
21 |     assert images.ndim == 4  # n x c x h x w
22 |     images = images.permute(0, 2, 3, 1)  # n x h x w x c
23 |     n, h, w, c = images.shape
24 |     if nrows is None:
25 |         nrows = max(int(math.sqrt(n)), 1)
26 |     ncols = (n + nrows - 1) // nrows
27 |     result = torch.full([(h + border) * nrows - border,
28 |                             (w + border) * ncols - border, c], background_value,
29 |                         device=images.device,
30 |                         dtype=images.dtype)
31 | 
32 |     for i, single_image in enumerate(images):
33 |         row = i // ncols
34 |         col = i % ncols
35 |         yy = (h + border) * row
36 |         xx = (w + border) * col
37 |         result[yy:(yy + h), xx:(xx + w), :] = single_image
38 |     return result
39 | 
40 | 
41 | def show_hwc(image: torch.Tensor):
42 |     if image.dtype != torch.uint8:
43 |         image = image.to(torch.uint8)
44 |     if image.size(2) == 1:
45 |         image = image.repeat(1, 1, 3)
46 |     pimage = Image.fromarray(image.cpu().numpy())
47 |     pimage.save("test.jpg")
48 |     plt.imshow(pimage)
49 |     plt.show()
50 | 
51 | 
52 | def show_bchw(image: torch.Tensor):
53 |     show_hwc(bchw2hwc(image))
54 | 


--------------------------------------------------------------------------------
/util/lr_logger.py:
--------------------------------------------------------------------------------
 1 | from pytorch_lightning import Callback, Trainer, LightningModule
 2 | 
 3 | 
 4 | class LrLogger(Callback):
 5 |     """Log learning rate in each epoch start."""
 6 | 
 7 |     def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
 8 |         for i, optimizer in enumerate(trainer.optimizers):
 9 |             for j, params in enumerate(optimizer.param_groups):
10 |                 key = f"opt{i}_lr{j}"
11 |                 value = params["lr"]
12 |                 pl_module.logger.log_metrics({key: value}, step=trainer.global_step)
13 |                 pl_module.log(key, value, logger=False, sync_dist=pl_module.distributed)
14 | 
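
A sketch of wiring this callback into a PyTorch Lightning Trainer; it reads pl_module.distributed, so the project's LightningModule subclasses are assumed to expose that attribute (the model and datamodule below are placeholders):

# Sketch only: register LrLogger so learning rates are logged every epoch.
from pytorch_lightning import Trainer

from util.lr_logger import LrLogger

trainer = Trainer(max_epochs=10, callbacks=[LrLogger()])
# trainer.fit(model, datamodule=datamodule)  # placeholders, defined elsewhere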


--------------------------------------------------------------------------------
/util/misc.py:
--------------------------------------------------------------------------------
 1 | from __future__ import annotations
 2 | 
 3 | import json
 4 | import re
 5 | 
 6 | import torch
 7 | from torch import Tensor
 8 | 
 9 | 
10 | def rename_key(dictionary, old_key, new_key):
11 |     dictionary[new_key] = dictionary.pop(old_key)
12 | 
13 | 
14 | def load_official_pretrain_model(model, pth_path):
15 |     state_dict = torch.load(pth_path)["model"]
16 | 
17 |     mutation_keys = [
18 |         ("mask_token", "decoder.mask_token"),
19 |     ]
20 | 
21 |     for old_key in state_dict:
22 |         m = re.match(r"(en|de)coder\.blocks\.(\d+?)\.mlp\.fc([1-2])\.(weight|bias)", old_key)
23 |         if m:
24 |             new_key = f"{m[1]}coder.blocks.{m[2]}.mlp.layers.{int(m[3]) - 1}.linear.{m[4]}"
25 |             mutation_keys.append((old_key, new_key))
26 | 
27 |         if old_key.startswith("encoder_to_decoder"):
28 |             new_key = old_key.replace("encoder_to_decoder", "enc_dec_proj")
29 |             mutation_keys.append((old_key, new_key))
30 | 
31 |         if old_key.startswith("encoder.patch_embed"):
32 |             new_key = old_key.replace("patch_embed.proj", "patch_embedding.projection")
33 |             mutation_keys.append((old_key, new_key))
34 | 
35 |     for old_key, new_key in mutation_keys:
36 |         rename_key(state_dict, old_key, new_key)
37 | 
38 |     return model.load_state_dict(state_dict, strict=False)
39 | 
40 | 
41 | def sample_indexes(total_frames: int, n_frames: int, temporal_sample_rate: int) -> Tensor:
42 |     try:
43 |         start_ind = torch.randint(0, total_frames - (n_frames * temporal_sample_rate) + 1, ())
44 |     except RuntimeError as e:
45 |         print(f"total_frames: {total_frames}, n_frames: {n_frames}, temporal_sample_rate: {temporal_sample_rate}")
46 |         raise e
47 |     return torch.arange(n_frames) * temporal_sample_rate + start_ind
48 | 
49 | 
50 | def read_text(path: str, encoding: str = "UTF-8") -> str:
51 |     with open(path, "r", encoding=encoding) as file:
52 |         text = file.read()
53 |     return text
54 | 
55 | 
56 | def read_json(path: str):
57 |     with open(path, "r") as file:
58 |         return json.load(file)
59 | 
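
A quick sketch of sample_indexes above, which draws a random clip start and returns evenly spaced frame indices (run from the repository root):

# Sketch only: sample 16 frame indices from a 300-frame video at stride 2.
from util.misc import sample_indexes

indexes = sample_indexes(total_frames=300, n_frames=16, temporal_sample_rate=2)
print(indexes.shape)       # torch.Size([16])
print(int(indexes.max()))  # always < 300 given how start_ind is drawn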


--------------------------------------------------------------------------------
/util/seed.py:
--------------------------------------------------------------------------------
 1 | import random
 2 | from typing import Callable
 3 | 
 4 | import numpy as np
 5 | import torch
 6 | from torch import Generator
 7 | 
 8 | 
 9 | class Seed:
10 |     seed: int = None
11 | 
12 |     @classmethod
13 |     def torch(cls, seed: int) -> None:
14 |         torch.manual_seed(seed)
15 |         torch.cuda.manual_seed(seed)
16 | 
17 |     @classmethod
18 |     def python(cls, seed: int) -> None:
19 |         random.seed(seed)
20 | 
21 |     @classmethod
22 |     def numpy(cls, seed: int) -> None:
23 |         np.random.seed(seed)
24 | 
25 |     @classmethod
26 |     def set(cls, seed: int, use_deterministic_algorithms: bool = False) -> None:
27 |         cls.torch(seed)
28 |         cls.python(seed)
29 |         cls.numpy(seed)
30 |         cls.seed = seed
31 |         torch.use_deterministic_algorithms(use_deterministic_algorithms)
32 | 
33 |     @classmethod
34 |     def _is_set(cls) -> bool:
35 |         return cls.seed is not None
36 | 
37 |     @classmethod
38 |     def get_loader_worker_init(cls) -> Callable[[int], None]:
39 |         def seed_worker(worker_id):
40 |             worker_seed = torch.initial_seed() % 2 ** 32
41 |             np.random.seed(worker_seed)
42 |             random.seed(worker_seed)
43 | 
44 |         if cls._is_set():
45 |             return seed_worker
46 |         else:
47 |             return lambda x: None
48 | 
49 |     @classmethod
50 |     def get_torch_generator(cls, device="cpu") -> Generator:
51 |         g = torch.Generator(device)
52 |         g.manual_seed(cls.seed)
53 |         return g
54 | 
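
A sketch of seeding plus DataLoader wiring with the helpers above (the dataset is a placeholder):

# Sketch only: fix the global seeds, then make DataLoader shuffling and worker
# seeding reproducible via the helpers on Seed.
import torch
from torch.utils.data import DataLoader, TensorDataset

from util.seed import Seed

Seed.set(42)
dataset = TensorDataset(torch.arange(100).float())  # placeholder dataset
loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    num_workers=2,
    worker_init_fn=Seed.get_loader_worker_init(),
    generator=Seed.get_torch_generator(),
)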


--------------------------------------------------------------------------------
/util/system_stats_logger.py:
--------------------------------------------------------------------------------
 1 | from pytorch_lightning import Callback, Trainer, LightningModule
 2 | 
 3 | 
 4 | class SystemStatsLogger(Callback):
 5 |     """Log system stats for each training epoch"""
 6 | 
 7 |     def __init__(self):
 8 |         try:
 9 |             import psutil
10 |         except ImportError:
11 |             raise ImportError("psutil is required to use SystemStatsLogger")
12 |         self.psutil = psutil
13 | 
14 |     def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
15 |         cpu_usage = self.psutil.cpu_percent()
16 |         memory_usage = self.psutil.virtual_memory().percent
17 |         logged_info = {
18 |             "cpu_usage": cpu_usage,
19 |             "memory_usage": memory_usage
20 |         }
21 |         pl_module.logger.log_metrics(logged_info, step=trainer.global_step)
22 |         pl_module.log_dict(logged_info, logger=False, sync_dist=pl_module.distributed)
23 | 


--------------------------------------------------------------------------------
/version.txt:
--------------------------------------------------------------------------------
1 | 0.3.4


--------------------------------------------------------------------------------