├── .github └── workflows │ ├── publish-pip.yml │ ├── pylint.yml │ └── release.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── MANIFEST.in ├── README.md ├── README_CN.md ├── VERSION ├── assets ├── icon.png ├── icon_small.png ├── landmarks_5.jpg ├── landmarks_68.png ├── landmarks_98.png ├── test.jpg └── test2.jpg ├── facexlib ├── __init__.py ├── alignment │ ├── README.md │ ├── __init__.py │ ├── awing_arch.py │ └── convert_98_to_68_landmarks.py ├── assessment │ ├── __init__.py │ └── hyperiqa_net.py ├── detection │ ├── __init__.py │ ├── align_trans.py │ ├── matlab_cp2tform.py │ ├── retinaface.py │ ├── retinaface_net.py │ └── retinaface_utils.py ├── headpose │ ├── __init__.py │ └── hopenet_arch.py ├── matting │ ├── __init__.py │ ├── backbone.py │ ├── mobilenetv2.py │ └── modnet.py ├── parsing │ ├── __init__.py │ ├── bisenet.py │ ├── parsenet.py │ └── resnet.py ├── recognition │ ├── __init__.py │ └── arcface_arch.py ├── tracking │ ├── README.md │ ├── __init__.py │ ├── data_association.py │ ├── kalman_tracker.py │ └── sort.py ├── utils │ ├── __init__.py │ ├── face_restoration_helper.py │ ├── face_utils.py │ └── misc.py ├── visualization │ ├── __init__.py │ ├── vis_alignment.py │ ├── vis_detection.py │ └── vis_headpose.py └── weights │ └── README.md ├── inference ├── inference_alignment.py ├── inference_crop_standard_faces.py ├── inference_detection.py ├── inference_headpose.py ├── inference_hyperiqa.py ├── inference_matting.py ├── inference_parsing.py ├── inference_parsing_parsenet.py ├── inference_recognition.py └── inference_tracking.py ├── requirements.txt ├── scripts ├── crop_faces_5landmarks.py ├── extract_detection_info_ffhq.py └── get_ffhq_template.py ├── setup.cfg └── setup.py /.github/workflows/publish-pip.yml: -------------------------------------------------------------------------------- 1 | name: PyPI Publish 2 | 3 | on: push 4 | 5 | jobs: 6 | build-n-publish: 7 | runs-on: ubuntu-latest 8 | if: startsWith(github.event.ref, 'refs/tags') 9 | 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 3.8 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: 3.8 16 | - name: Upgrade pip 17 | run: | 18 | pip install pip --upgrade 19 | pip install wheel 20 | 21 | - name: Install PyTorch (cpu) 22 | run: pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html 23 | - name: Install dependencies 24 | run: pip install -r requirements.txt 25 | - name: Build and install 26 | run: rm -rf .eggs && pip install -e . 27 | - name: Build for distribution 28 | run: python setup.py sdist bdist_wheel 29 | - name: Publish distribution to PyPI 30 | uses: pypa/gh-action-pypi-publish@master 31 | with: 32 | password: ${{ secrets.PYPI_API_TOKEN }} 33 | -------------------------------------------------------------------------------- /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: PyLint 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: [3.8] 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install codespell flake8 isort yapf 24 | 25 | - name: Lint 26 | run: | 27 | codespell 28 | flake8 . 
29 | isort --check-only --diff facexlib/ inference/ scripts/ setup.py 30 | yapf -r -d facexlib/ inference/ scripts/ setup.py 31 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: release 2 | on: 3 | push: 4 | tags: 5 | - '*' 6 | 7 | jobs: 8 | build: 9 | permissions: write-all 10 | name: Create Release 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v2 15 | - name: Create Release 16 | id: create_release 17 | uses: actions/create-release@v1 18 | env: 19 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 20 | with: 21 | tag_name: ${{ github.ref }} 22 | release_name: FaceXlib ${{ github.ref }} Release Note 23 | body: | 24 | 🚀 See you again 😸 25 | 🚀Have a nice day 😸 and happy everyday 😃 26 | 🚀 Long time no see ☄️ 27 | 28 | ✨ **Highlights** 29 | ✅ [Features] Support ... 30 | 31 | 🐛 **Bug Fixes** 32 | 33 | 🌴 **Improvements** 34 | 35 | 📢📢📢 36 | 37 |

38 | 39 |

40 | draft: true 41 | prerelease: false 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | *.pth 3 | *.png 4 | *.jpg 5 | version.py 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # flake8 3 | - repo: https://github.com/PyCQA/flake8 4 | rev: 3.8.3 5 | hooks: 6 | - id: flake8 7 | args: ["--config=setup.cfg", "--ignore=W504, W503"] 8 | 9 | # modify known_third_party 10 | - repo: https://github.com/asottile/seed-isort-config 11 | rev: v2.2.0 12 | hooks: 13 | - id: seed-isort-config 14 | 15 | # isort 16 | - repo: https://github.com/timothycrosley/isort 17 | rev: 5.2.2 18 | hooks: 19 | - id: isort 20 | 21 | # yapf 22 | - repo: https://github.com/pre-commit/mirrors-yapf 23 | rev: v0.30.0 24 | hooks: 25 | - id: yapf 26 | 27 | # codespell 28 | - repo: https://github.com/codespell-project/codespell 29 | rev: v2.1.0 30 | hooks: 31 | - id: codespell 32 | 33 | # pre-commit-hooks 34 | - repo: https://github.com/pre-commit/pre-commit-hooks 35 | rev: v3.2.0 36 | hooks: 37 | - id: trailing-whitespace # Trim trailing whitespace 38 | - id: check-yaml # Attempt to load all yaml files to verify syntax 39 | - id: check-merge-conflict # Check for files that contain merge conflict strings 40 | - id: double-quote-string-fixer # Replace double quoted strings with single quoted strings 41 | - id: end-of-file-fixer # Make sure files end in a newline and only a newline 42 | - id: requirements-txt-fixer # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0 43 | - id: fix-encoding-pragma # Remove the coding pragma: # -*- coding: utf-8 -*- 44 | args: ["--remove"] 45 | - id: mixed-line-ending # Replace or check mixed line ending 46 | args: ["--fix=lf"] 47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Xintao Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include assets/*.png 2 | include assets/*.jpg 3 | include inference/*.py 4 | include scripts/*.py 5 | include VERSION 6 | include requirements.txt 7 | include facexlib/weights/README.md 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ![icon](assets/icon_small.png) FaceXLib 2 | 3 | [![PyPI](https://img.shields.io/pypi/v/facexlib)](https://pypi.org/project/facexlib/) 4 | [![download](https://img.shields.io/github/downloads/xinntao/facexlib/total.svg)](https://github.com/xinntao/facexlib/releases) 5 | [![Open issue](https://img.shields.io/github/issues/xinntao/facexlib)](https://github.com/xinntao/facexlib/issues) 6 | [![Closed issue](https://img.shields.io/github/issues-closed/xinntao/facexlib)](https://github.com/xinntao/facexlib/issues) 7 | [![LICENSE](https://img.shields.io/github/license/xinntao/facexlib.svg)](https://github.com/xinntao/facexlib/blob/master/LICENSE) 8 | [![python lint](https://github.com/xinntao/facexlib/actions/workflows/pylint.yml/badge.svg)](https://github.com/xinntao/facexlib/blob/master/.github/workflows/pylint.yml) 9 | [![Publish-pip](https://github.com/xinntao/facexlib/actions/workflows/publish-pip.yml/badge.svg)](https://github.com/xinntao/facexlib/blob/master/.github/workflows/publish-pip.yml) 10 | 11 | [English](README.md) **|** [简体中文](README_CN.md) 12 | 13 | --- 14 | 15 | **facexlib** aims at providing ready-to-use **face-related** functions based on current SOTA open-source methods.
16 | Only PyTorch reference code is available. For training or fine-tuning, please refer to their original repositories listed below.
17 | Note that we only provide a collection of these algorithms. You need to refer to their original LICENSEs for your intended use. 18 | 19 | If facexlib is helpful in your projects, please help to :star: this repo. Thanks:blush:
20 | Other recommended projects:   :arrow_forward: [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN)   :arrow_forward: [GFPGAN](https://github.com/TencentARC/GFPGAN)   :arrow_forward: [BasicSR](https://github.com/xinntao/BasicSR) 21 | 22 | --- 23 | 24 | ## :sparkles: Functions 25 | 26 | | Function | Sources | Original LICENSE | 27 | | :--- | :---: | :---: | 28 | | [Detection](facexlib/detection/README.md) | [Pytorch_Retinaface](https://github.com/biubug6/Pytorch_Retinaface) | MIT | 29 | | [Alignment](facexlib/alignment/README.md) |[AdaptiveWingLoss](https://github.com/protossw512/AdaptiveWingLoss) | Apache 2.0 | 30 | | [Recognition](facexlib/recognition/README.md) | [InsightFace_Pytorch](https://github.com/TreB1eN/InsightFace_Pytorch) | MIT | 31 | | [Parsing](facexlib/parsing/README.md) | [face-parsing.PyTorch](https://github.com/zllrunning/face-parsing.PyTorch) | MIT | 32 | | [Matting](facexlib/matting/README.md) | [MODNet](https://github.com/ZHKKKe/MODNet) | CC 4.0 | 33 | | [Headpose](facexlib/headpose/README.md) | [deep-head-pose](https://github.com/natanielruiz/deep-head-pose) | Apache 2.0 | 34 | | [Tracking](facexlib/tracking/README.md) | [SORT](https://github.com/abewley/sort) | GPL 3.0 | 35 | | [Assessment](facexlib/assessment/README.md) | [hyperIQA](https://github.com/SSL92/hyperIQA) | - | 36 | | [Utils](facexlib/utils/README.md) | Face Restoration Helper | - | 37 | 38 | ## :eyes: Demo and Tutorials 39 | 40 | ## :wrench: Dependencies and Installation 41 | 42 | - Python >= 3.7 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html)) 43 | - [PyTorch >= 1.7](https://pytorch.org/) 44 | - Option: NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads) 45 | 46 | ### Installation 47 | 48 | ```bash 49 | pip install facexlib 50 | ``` 51 | 52 | ### Pre-trained models 53 | 54 | It will **automatically** download pre-trained models at the first inference.
55 | If your network is not stable, you can download the models in advance (for example, with other download tools) and put them in the folder: `PACKAGE_ROOT_PATH/facexlib/weights` (see the usage sketch below). 56 | 57 | ## :scroll: License and Acknowledgement 58 | 59 | This project is released under the MIT license. 
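To make the installation and automatic weight download above concrete, here is a minimal usage sketch (an editorial illustration, not part of the original README). It uses the `init_detection_model` entry point from `facexlib/detection/__init__.py` shown later in this dump; the `detect_faces` call mirrors the repo's `inference/inference_detection.py` script, whose exact signature is not reproduced in this section and should be treated as an assumption.

```python
# Minimal usage sketch, assuming the repo's assets/test.jpg is available.
# The first call downloads detection_Resnet50_Final.pth into facexlib/weights automatically.
import cv2
import torch

from facexlib.detection import init_detection_model

device = 'cuda' if torch.cuda.is_available() else 'cpu'
det_net = init_detection_model('retinaface_resnet50', half=False, device=device)

img = cv2.imread('assets/test.jpg')  # BGR image from OpenCV
with torch.no_grad():
    bboxes = det_net.detect_faces(img, 0.97)  # per face: bounding box, confidence score, 5 landmarks
print(len(bboxes), 'face(s) detected')
```

The other `init_*_model` helpers in this repo follow the same pattern; passing `model_rootpath` redirects the weight download to a custom folder.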
60 | 61 | ## :e-mail: Contact 62 | 63 | If you have any questions, open an issue or email `xintao.wang@outlook.com`. 64 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | # ![icon](assets/icon_small.png) FaceXLib 2 | 3 | [English](README.md) **|** [简体中文](README_CN.md) 4 | 5 | --- 6 | 7 | `facexlib` is a **pytorch-based** library for **face-related** functions, such as detection, alignment, recognition, tracking, utils for face restoration, *etc*. 8 | It only provides inference (without training). 9 | This repo is based on current SOTA open-source methods (see [more details](#Functions)). 10 | 11 | ## :eyes: Demo 12 | 13 | ## :wrench: Dependencies and Installation 14 | 15 | - Python >= 3.7 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html)) 16 | - [PyTorch >= 1.3](https://pytorch.org/) 17 | - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads) 18 | 19 | ## :sparkles: Functions 20 | 21 | | Function | Description | Reference | 22 | | :--- | :---: | :---: | 23 | | Detection | ([More details](detection/README.md)) | [Pytorch_Retinaface](https://github.com/biubug6/Pytorch_Retinaface) | 24 | | Alignment | ([More details](alignment/README.md)) | [AdaptiveWingLoss](https://github.com/protossw512/AdaptiveWingLoss) | 25 | | Recognition | ([More details](recognition/README.md)) | [InsightFace_Pytorch](https://github.com/TreB1eN/InsightFace_Pytorch) | 26 | | Tracking | ([More details](tracking/README.md)) | [SORT](https://github.com/abewley/sort) | 27 | | Utils | ([More details](utils/README.md)) | | 28 | 29 | ## :scroll: License and Acknowledgement 30 | 31 | This project is released under the MIT license.
32 | 33 | ## :e-mail: Contact 34 | 35 | If you have any question, open an issue or email `xintao.wang@outlook.com`. 36 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 0.3.0 2 | -------------------------------------------------------------------------------- /assets/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinntao/facexlib/260620ae93990a300f4b16448df9bb459f1caba9/assets/icon.png -------------------------------------------------------------------------------- /assets/icon_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinntao/facexlib/260620ae93990a300f4b16448df9bb459f1caba9/assets/icon_small.png -------------------------------------------------------------------------------- /assets/landmarks_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinntao/facexlib/260620ae93990a300f4b16448df9bb459f1caba9/assets/landmarks_5.jpg -------------------------------------------------------------------------------- /assets/landmarks_68.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinntao/facexlib/260620ae93990a300f4b16448df9bb459f1caba9/assets/landmarks_68.png -------------------------------------------------------------------------------- /assets/landmarks_98.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinntao/facexlib/260620ae93990a300f4b16448df9bb459f1caba9/assets/landmarks_98.png -------------------------------------------------------------------------------- /assets/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinntao/facexlib/260620ae93990a300f4b16448df9bb459f1caba9/assets/test.jpg -------------------------------------------------------------------------------- /assets/test2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinntao/facexlib/260620ae93990a300f4b16448df9bb459f1caba9/assets/test2.jpg -------------------------------------------------------------------------------- /facexlib/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .alignment import * 3 | from .detection import * 4 | from .recognition import * 5 | from .tracking import * 6 | from .utils import * 7 | from .version import __gitsha__, __version__ 8 | from .visualization import * 9 | -------------------------------------------------------------------------------- /facexlib/alignment/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Landmarks 3 | 4 | - 5 landmarks 5 | 6 |

7 | 8 |

9 | 10 | - 68 landmarks 11 | 12 |

13 | 14 |

15 | 16 | - 98 landmarks 17 | 18 |

19 | 20 |

21 | -------------------------------------------------------------------------------- /facexlib/alignment/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from facexlib.utils import load_file_from_url 4 | from .awing_arch import FAN 5 | from .convert_98_to_68_landmarks import landmark_98_to_68 6 | 7 | __all__ = ['FAN', 'landmark_98_to_68'] 8 | 9 | 10 | def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None): 11 | if model_name == 'awing_fan': 12 | model = FAN(num_modules=4, num_landmarks=98, device=device) 13 | model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth' 14 | else: 15 | raise NotImplementedError(f'{model_name} is not implemented.') 16 | 17 | model_path = load_file_from_url( 18 | url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) 19 | model.load_state_dict(torch.load(model_path)['state_dict'], strict=True) 20 | model.eval() 21 | model = model.to(device) 22 | return model 23 | -------------------------------------------------------------------------------- /facexlib/alignment/awing_arch.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def calculate_points(heatmaps): 9 | # change heatmaps to landmarks 10 | B, N, H, W = heatmaps.shape 11 | HW = H * W 12 | BN_range = np.arange(B * N) 13 | 14 | heatline = heatmaps.reshape(B, N, HW) 15 | indexes = np.argmax(heatline, axis=2) 16 | 17 | preds = np.stack((indexes % W, indexes // W), axis=2) 18 | preds = preds.astype(np.float, copy=False) 19 | 20 | inr = indexes.ravel() 21 | 22 | heatline = heatline.reshape(B * N, HW) 23 | x_up = heatline[BN_range, inr + 1] 24 | x_down = heatline[BN_range, inr - 1] 25 | # y_up = heatline[BN_range, inr + W] 26 | 27 | if any((inr + W) >= 4096): 28 | y_up = heatline[BN_range, 4095] 29 | else: 30 | y_up = heatline[BN_range, inr + W] 31 | if any((inr - W) <= 0): 32 | y_down = heatline[BN_range, 0] 33 | else: 34 | y_down = heatline[BN_range, inr - W] 35 | 36 | think_diff = np.sign(np.stack((x_up - x_down, y_up - y_down), axis=1)) 37 | think_diff *= .25 38 | 39 | preds += think_diff.reshape(B, N, 2) 40 | preds += .5 41 | return preds 42 | 43 | 44 | class AddCoordsTh(nn.Module): 45 | 46 | def __init__(self, x_dim=64, y_dim=64, with_r=False, with_boundary=False): 47 | super(AddCoordsTh, self).__init__() 48 | self.x_dim = x_dim 49 | self.y_dim = y_dim 50 | self.with_r = with_r 51 | self.with_boundary = with_boundary 52 | 53 | def forward(self, input_tensor, heatmap=None): 54 | """ 55 | input_tensor: (batch, c, x_dim, y_dim) 56 | """ 57 | batch_size_tensor = input_tensor.shape[0] 58 | 59 | xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32, device=input_tensor.device) 60 | xx_ones = xx_ones.unsqueeze(-1) 61 | 62 | xx_range = torch.arange(self.x_dim, dtype=torch.int32, device=input_tensor.device).unsqueeze(0) 63 | xx_range = xx_range.unsqueeze(1) 64 | 65 | xx_channel = torch.matmul(xx_ones.float(), xx_range.float()) 66 | xx_channel = xx_channel.unsqueeze(-1) 67 | 68 | yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32, device=input_tensor.device) 69 | yy_ones = yy_ones.unsqueeze(1) 70 | 71 | yy_range = torch.arange(self.y_dim, dtype=torch.int32, device=input_tensor.device).unsqueeze(0) 72 | yy_range = yy_range.unsqueeze(-1) 73 | 74 | yy_channel = 
torch.matmul(yy_range.float(), yy_ones.float()) 75 | yy_channel = yy_channel.unsqueeze(-1) 76 | 77 | xx_channel = xx_channel.permute(0, 3, 2, 1) 78 | yy_channel = yy_channel.permute(0, 3, 2, 1) 79 | 80 | xx_channel = xx_channel / (self.x_dim - 1) 81 | yy_channel = yy_channel / (self.y_dim - 1) 82 | 83 | xx_channel = xx_channel * 2 - 1 84 | yy_channel = yy_channel * 2 - 1 85 | 86 | xx_channel = xx_channel.repeat(batch_size_tensor, 1, 1, 1) 87 | yy_channel = yy_channel.repeat(batch_size_tensor, 1, 1, 1) 88 | 89 | if self.with_boundary and heatmap is not None: 90 | boundary_channel = torch.clamp(heatmap[:, -1:, :, :], 0.0, 1.0) 91 | 92 | zero_tensor = torch.zeros_like(xx_channel) 93 | xx_boundary_channel = torch.where(boundary_channel > 0.05, xx_channel, zero_tensor) 94 | yy_boundary_channel = torch.where(boundary_channel > 0.05, yy_channel, zero_tensor) 95 | if self.with_boundary and heatmap is not None: 96 | xx_boundary_channel = xx_boundary_channel.to(input_tensor.device) 97 | yy_boundary_channel = yy_boundary_channel.to(input_tensor.device) 98 | ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1) 99 | 100 | if self.with_r: 101 | rr = torch.sqrt(torch.pow(xx_channel, 2) + torch.pow(yy_channel, 2)) 102 | rr = rr / torch.max(rr) 103 | ret = torch.cat([ret, rr], dim=1) 104 | 105 | if self.with_boundary and heatmap is not None: 106 | ret = torch.cat([ret, xx_boundary_channel, yy_boundary_channel], dim=1) 107 | return ret 108 | 109 | 110 | class CoordConvTh(nn.Module): 111 | """CoordConv layer as in the paper.""" 112 | 113 | def __init__(self, x_dim, y_dim, with_r, with_boundary, in_channels, first_one=False, *args, **kwargs): 114 | super(CoordConvTh, self).__init__() 115 | self.addcoords = AddCoordsTh(x_dim=x_dim, y_dim=y_dim, with_r=with_r, with_boundary=with_boundary) 116 | in_channels += 2 117 | if with_r: 118 | in_channels += 1 119 | if with_boundary and not first_one: 120 | in_channels += 2 121 | self.conv = nn.Conv2d(in_channels=in_channels, *args, **kwargs) 122 | 123 | def forward(self, input_tensor, heatmap=None): 124 | ret = self.addcoords(input_tensor, heatmap) 125 | last_channel = ret[:, -2:, :, :] 126 | ret = self.conv(ret) 127 | return ret, last_channel 128 | 129 | 130 | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False, dilation=1): 131 | '3x3 convolution with padding' 132 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=strd, padding=padding, bias=bias, dilation=dilation) 133 | 134 | 135 | class BasicBlock(nn.Module): 136 | expansion = 1 137 | 138 | def __init__(self, inplanes, planes, stride=1, downsample=None): 139 | super(BasicBlock, self).__init__() 140 | self.conv1 = conv3x3(inplanes, planes, stride) 141 | # self.bn1 = nn.BatchNorm2d(planes) 142 | self.relu = nn.ReLU(inplace=True) 143 | self.conv2 = conv3x3(planes, planes) 144 | # self.bn2 = nn.BatchNorm2d(planes) 145 | self.downsample = downsample 146 | self.stride = stride 147 | 148 | def forward(self, x): 149 | residual = x 150 | 151 | out = self.conv1(x) 152 | out = self.relu(out) 153 | 154 | out = self.conv2(out) 155 | 156 | if self.downsample is not None: 157 | residual = self.downsample(x) 158 | 159 | out += residual 160 | out = self.relu(out) 161 | 162 | return out 163 | 164 | 165 | class ConvBlock(nn.Module): 166 | 167 | def __init__(self, in_planes, out_planes): 168 | super(ConvBlock, self).__init__() 169 | self.bn1 = nn.BatchNorm2d(in_planes) 170 | self.conv1 = conv3x3(in_planes, int(out_planes / 2)) 171 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) 172 | self.conv2 = 
conv3x3(int(out_planes / 2), int(out_planes / 4), padding=1, dilation=1) 173 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) 174 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4), padding=1, dilation=1) 175 | 176 | if in_planes != out_planes: 177 | self.downsample = nn.Sequential( 178 | nn.BatchNorm2d(in_planes), 179 | nn.ReLU(True), 180 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False), 181 | ) 182 | else: 183 | self.downsample = None 184 | 185 | def forward(self, x): 186 | residual = x 187 | 188 | out1 = self.bn1(x) 189 | out1 = F.relu(out1, True) 190 | out1 = self.conv1(out1) 191 | 192 | out2 = self.bn2(out1) 193 | out2 = F.relu(out2, True) 194 | out2 = self.conv2(out2) 195 | 196 | out3 = self.bn3(out2) 197 | out3 = F.relu(out3, True) 198 | out3 = self.conv3(out3) 199 | 200 | out3 = torch.cat((out1, out2, out3), 1) 201 | 202 | if self.downsample is not None: 203 | residual = self.downsample(residual) 204 | 205 | out3 += residual 206 | 207 | return out3 208 | 209 | 210 | class HourGlass(nn.Module): 211 | 212 | def __init__(self, num_modules, depth, num_features, first_one=False): 213 | super(HourGlass, self).__init__() 214 | self.num_modules = num_modules 215 | self.depth = depth 216 | self.features = num_features 217 | self.coordconv = CoordConvTh( 218 | x_dim=64, 219 | y_dim=64, 220 | with_r=True, 221 | with_boundary=True, 222 | in_channels=256, 223 | first_one=first_one, 224 | out_channels=256, 225 | kernel_size=1, 226 | stride=1, 227 | padding=0) 228 | self._generate_network(self.depth) 229 | 230 | def _generate_network(self, level): 231 | self.add_module('b1_' + str(level), ConvBlock(256, 256)) 232 | 233 | self.add_module('b2_' + str(level), ConvBlock(256, 256)) 234 | 235 | if level > 1: 236 | self._generate_network(level - 1) 237 | else: 238 | self.add_module('b2_plus_' + str(level), ConvBlock(256, 256)) 239 | 240 | self.add_module('b3_' + str(level), ConvBlock(256, 256)) 241 | 242 | def _forward(self, level, inp): 243 | # Upper branch 244 | up1 = inp 245 | up1 = self._modules['b1_' + str(level)](up1) 246 | 247 | # Lower branch 248 | low1 = F.avg_pool2d(inp, 2, stride=2) 249 | low1 = self._modules['b2_' + str(level)](low1) 250 | 251 | if level > 1: 252 | low2 = self._forward(level - 1, low1) 253 | else: 254 | low2 = low1 255 | low2 = self._modules['b2_plus_' + str(level)](low2) 256 | 257 | low3 = low2 258 | low3 = self._modules['b3_' + str(level)](low3) 259 | 260 | up2 = F.interpolate(low3, scale_factor=2, mode='nearest') 261 | 262 | return up1 + up2 263 | 264 | def forward(self, x, heatmap): 265 | x, last_channel = self.coordconv(x, heatmap) 266 | return self._forward(self.depth, x), last_channel 267 | 268 | 269 | class FAN(nn.Module): 270 | 271 | def __init__(self, num_modules=1, end_relu=False, gray_scale=False, num_landmarks=68, device='cuda'): 272 | super(FAN, self).__init__() 273 | self.device = device 274 | self.num_modules = num_modules 275 | self.gray_scale = gray_scale 276 | self.end_relu = end_relu 277 | self.num_landmarks = num_landmarks 278 | 279 | # Base part 280 | if self.gray_scale: 281 | self.conv1 = CoordConvTh( 282 | x_dim=256, 283 | y_dim=256, 284 | with_r=True, 285 | with_boundary=False, 286 | in_channels=3, 287 | out_channels=64, 288 | kernel_size=7, 289 | stride=2, 290 | padding=3) 291 | else: 292 | self.conv1 = CoordConvTh( 293 | x_dim=256, 294 | y_dim=256, 295 | with_r=True, 296 | with_boundary=False, 297 | in_channels=3, 298 | out_channels=64, 299 | kernel_size=7, 300 | stride=2, 301 | padding=3) 302 | self.bn1 = 
nn.BatchNorm2d(64) 303 | self.conv2 = ConvBlock(64, 128) 304 | self.conv3 = ConvBlock(128, 128) 305 | self.conv4 = ConvBlock(128, 256) 306 | 307 | # Stacking part 308 | for hg_module in range(self.num_modules): 309 | if hg_module == 0: 310 | first_one = True 311 | else: 312 | first_one = False 313 | self.add_module('m' + str(hg_module), HourGlass(1, 4, 256, first_one)) 314 | self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256)) 315 | self.add_module('conv_last' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 316 | self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) 317 | self.add_module('l' + str(hg_module), nn.Conv2d(256, num_landmarks + 1, kernel_size=1, stride=1, padding=0)) 318 | 319 | if hg_module < self.num_modules - 1: 320 | self.add_module('bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) 321 | self.add_module('al' + str(hg_module), 322 | nn.Conv2d(num_landmarks + 1, 256, kernel_size=1, stride=1, padding=0)) 323 | 324 | def forward(self, x): 325 | x, _ = self.conv1(x) 326 | x = F.relu(self.bn1(x), True) 327 | # x = F.relu(self.bn1(self.conv1(x)), True) 328 | x = F.avg_pool2d(self.conv2(x), 2, stride=2) 329 | x = self.conv3(x) 330 | x = self.conv4(x) 331 | 332 | previous = x 333 | 334 | outputs = [] 335 | boundary_channels = [] 336 | tmp_out = None 337 | for i in range(self.num_modules): 338 | hg, boundary_channel = self._modules['m' + str(i)](previous, tmp_out) 339 | 340 | ll = hg 341 | ll = self._modules['top_m_' + str(i)](ll) 342 | 343 | ll = F.relu(self._modules['bn_end' + str(i)](self._modules['conv_last' + str(i)](ll)), True) 344 | 345 | # Predict heatmaps 346 | tmp_out = self._modules['l' + str(i)](ll) 347 | if self.end_relu: 348 | tmp_out = F.relu(tmp_out) # HACK: Added relu 349 | outputs.append(tmp_out) 350 | boundary_channels.append(boundary_channel) 351 | 352 | if i < self.num_modules - 1: 353 | ll = self._modules['bl' + str(i)](ll) 354 | tmp_out_ = self._modules['al' + str(i)](tmp_out) 355 | previous = previous + ll + tmp_out_ 356 | 357 | return outputs, boundary_channels 358 | 359 | def get_landmarks(self, img): 360 | H, W, _ = img.shape 361 | offset = W / 64, H / 64, 0, 0 362 | 363 | img = cv2.resize(img, (256, 256)) 364 | inp = img[..., ::-1] 365 | inp = torch.from_numpy(np.ascontiguousarray(inp.transpose((2, 0, 1)))).float() 366 | inp = inp.to(self.device) 367 | inp.div_(255.0).unsqueeze_(0) 368 | 369 | outputs, _ = self.forward(inp) 370 | out = outputs[-1][:, :-1, :, :] 371 | heatmaps = out.detach().cpu().numpy() 372 | 373 | pred = calculate_points(heatmaps).reshape(-1, 2) 374 | 375 | pred *= offset[:2] 376 | pred += offset[-2:] 377 | 378 | return pred 379 | -------------------------------------------------------------------------------- /facexlib/alignment/convert_98_to_68_landmarks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def load_txt_file(file_path): 5 | """Load data or string from txt file.""" 6 | 7 | with open(file_path, 'r') as cfile: 8 | content = cfile.readlines() 9 | cfile.close() 10 | content = [x.strip() for x in content] 11 | num_lines = len(content) 12 | return content, num_lines 13 | 14 | 15 | def anno_parser(anno_path, num_pts, line_offset=0): 16 | """Parse the annotation. 17 | Args: 18 | anno_path: path of anno file (suffix .txt) 19 | num_pts: number of landmarks. 20 | line_offset: first point starts, default: 0. 
21 | 22 | Returns: 23 | pts: num_pts x 2 (x, y) 24 | """ 25 | 26 | data, _ = load_txt_file(anno_path) 27 | n_points = num_pts 28 | # read points coordinate. 29 | pts = np.zeros((n_points, 2), dtype='float32') 30 | for point_index in range(n_points): 31 | try: 32 | pts_list = data[point_index + line_offset].split(',') 33 | pts[point_index, 0] = float(pts_list[0]) 34 | pts[point_index, 1] = float(pts_list[1]) 35 | except ValueError: 36 | print(f'Error in loading points in {anno_path}') 37 | return pts 38 | 39 | 40 | def landmark_98_to_68(landmark_98): 41 | """Transfer 98 landmark positions to 68 landmark positions. 42 | Args: 43 | landmark_98(numpy array): Polar coordinates of 98 landmarks, (98, 2) 44 | Returns: 45 | landmark_68(numpy array): Polar coordinates of 98 landmarks, (68, 2) 46 | """ 47 | 48 | landmark_68 = np.zeros((68, 2), dtype='float32') 49 | # cheek 50 | for i in range(0, 33): 51 | if i % 2 == 0: 52 | landmark_68[int(i / 2), :] = landmark_98[i, :] 53 | # nose 54 | for i in range(51, 60): 55 | landmark_68[i - 24, :] = landmark_98[i, :] 56 | # mouth 57 | for i in range(76, 96): 58 | landmark_68[i - 28, :] = landmark_98[i, :] 59 | # left eyebrow 60 | landmark_68[17, :] = landmark_98[33, :] 61 | landmark_68[18, :] = (landmark_98[34, :] + landmark_98[41, :]) / 2 62 | landmark_68[19, :] = (landmark_98[35, :] + landmark_98[40, :]) / 2 63 | landmark_68[20, :] = (landmark_98[36, :] + landmark_98[39, :]) / 2 64 | landmark_68[21, :] = (landmark_98[37, :] + landmark_98[38, :]) / 2 65 | # right eyebrow 66 | landmark_68[22, :] = (landmark_98[42, :] + landmark_98[50, :]) / 2 67 | landmark_68[23, :] = (landmark_98[43, :] + landmark_98[49, :]) / 2 68 | landmark_68[24, :] = (landmark_98[44, :] + landmark_98[48, :]) / 2 69 | landmark_68[25, :] = (landmark_98[45, :] + landmark_98[47, :]) / 2 70 | landmark_68[26, :] = landmark_98[46, :] 71 | # left eye 72 | LUT_landmark_68_left_eye = [36, 37, 38, 39, 40, 41] 73 | LUT_landmark_98_left_eye = [60, 61, 63, 64, 65, 67] 74 | for idx, landmark_98_index in enumerate(LUT_landmark_98_left_eye): 75 | landmark_68[LUT_landmark_68_left_eye[idx], :] = landmark_98[landmark_98_index, :] 76 | # right eye 77 | LUT_landmark_68_right_eye = [42, 43, 44, 45, 46, 47] 78 | LUT_landmark_98_right_eye = [68, 69, 71, 72, 73, 75] 79 | for idx, landmark_98_index in enumerate(LUT_landmark_98_right_eye): 80 | landmark_68[LUT_landmark_68_right_eye[idx], :] = landmark_98[landmark_98_index, :] 81 | 82 | return landmark_68 83 | -------------------------------------------------------------------------------- /facexlib/assessment/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from facexlib.utils import load_file_from_url 4 | from .hyperiqa_net import HyperIQA 5 | 6 | 7 | def init_assessment_model(model_name, half=False, device='cuda', model_rootpath=None): 8 | if model_name == 'hypernet': 9 | model = HyperIQA(16, 112, 224, 112, 56, 28, 14, 7) 10 | model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/assessment_hyperIQA.pth' 11 | else: 12 | raise NotImplementedError(f'{model_name} is not implemented.') 13 | 14 | # load the pre-trained hypernet model 15 | hypernet_model_path = load_file_from_url( 16 | url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) 17 | model.hypernet.load_state_dict((torch.load(hypernet_model_path, map_location=lambda storage, loc: storage))) 18 | model = model.eval() 19 | model = model.to(device) 20 | return model 21 | 
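Putting the alignment pieces above together (`init_alignment_model`, `FAN.get_landmarks` and `landmark_98_to_68`), a minimal 98-to-68 landmark sketch looks as follows. This is an editorial illustration: the input is assumed to be a roughly face-centered crop (in the full pipeline it would come from the detection module), and the CPU fallback is an assumption.

```python
# Minimal sketch of the 98 -> 68 landmark pipeline defined above.
# Assumption: 'face.jpg' is a face-centered crop; facexlib normally obtains it from the detector.
import cv2
import torch

from facexlib.alignment import init_alignment_model, landmark_98_to_68

device = 'cuda' if torch.cuda.is_available() else 'cpu'
align_net = init_alignment_model('awing_fan', device=device)

img = cv2.imread('face.jpg')  # BGR; FAN.get_landmarks flips channels to RGB internally
with torch.no_grad():
    landmarks_98 = align_net.get_landmarks(img)  # (98, 2) coordinates in the original image scale
landmarks_68 = landmark_98_to_68(landmarks_98)   # (68, 2) coordinates in the 68-point convention
print(landmarks_98.shape, landmarks_68.shape)
```

The `init_assessment_model` helper directly above follows the same loading pattern for the hyperIQA quality network.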
-------------------------------------------------------------------------------- /facexlib/assessment/hyperiqa_net.py: -------------------------------------------------------------------------------- 1 | import torch as torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class HyperIQA(nn.Module): 7 | """ 8 | Combine the hypernet and target network within a network. 9 | """ 10 | 11 | def __init__(self, *args): 12 | super(HyperIQA, self).__init__() 13 | self.hypernet = HyperNet(*args) 14 | 15 | def forward(self, img): 16 | net_params = self.hypernet(img) 17 | # build the target network 18 | target_net = TargetNet(net_params) 19 | for param in target_net.parameters(): 20 | param.requires_grad = False 21 | # predict the face quality 22 | pred = target_net(net_params['target_in_vec']) 23 | return pred 24 | 25 | 26 | class HyperNet(nn.Module): 27 | """ 28 | Hyper network for learning perceptual rules. 29 | Args: 30 | lda_out_channels: local distortion aware module output size. 31 | hyper_in_channels: input feature channels for hyper network. 32 | target_in_size: input vector size for target network. 33 | target_fc(i)_size: fully connection layer size of target network. 34 | feature_size: input feature map width/height for hyper network. 35 | Note: 36 | For size match, input args must satisfy: 'target_fc(i)_size * target_fc(i+1)_size' is divisible by 'feature_size ^ 2'. # noqa E501 37 | """ 38 | 39 | def __init__(self, lda_out_channels, hyper_in_channels, target_in_size, target_fc1_size, target_fc2_size, 40 | target_fc3_size, target_fc4_size, feature_size): 41 | super(HyperNet, self).__init__() 42 | 43 | self.hyperInChn = hyper_in_channels 44 | self.target_in_size = target_in_size 45 | self.f1 = target_fc1_size 46 | self.f2 = target_fc2_size 47 | self.f3 = target_fc3_size 48 | self.f4 = target_fc4_size 49 | self.feature_size = feature_size 50 | 51 | self.res = resnet50_backbone(lda_out_channels, target_in_size) 52 | 53 | self.pool = nn.AdaptiveAvgPool2d((1, 1)) 54 | 55 | # Conv layers for resnet output features 56 | self.conv1 = nn.Sequential( 57 | nn.Conv2d(2048, 1024, 1, padding=(0, 0)), nn.ReLU(inplace=True), nn.Conv2d(1024, 512, 1, padding=(0, 0)), 58 | nn.ReLU(inplace=True), nn.Conv2d(512, self.hyperInChn, 1, padding=(0, 0)), nn.ReLU(inplace=True)) 59 | 60 | # Hyper network part, conv for generating target fc weights, fc for generating target fc biases 61 | self.fc1w_conv = nn.Conv2d( 62 | self.hyperInChn, int(self.target_in_size * self.f1 / feature_size**2), 3, padding=(1, 1)) 63 | self.fc1b_fc = nn.Linear(self.hyperInChn, self.f1) 64 | 65 | self.fc2w_conv = nn.Conv2d(self.hyperInChn, int(self.f1 * self.f2 / feature_size**2), 3, padding=(1, 1)) 66 | self.fc2b_fc = nn.Linear(self.hyperInChn, self.f2) 67 | 68 | self.fc3w_conv = nn.Conv2d(self.hyperInChn, int(self.f2 * self.f3 / feature_size**2), 3, padding=(1, 1)) 69 | self.fc3b_fc = nn.Linear(self.hyperInChn, self.f3) 70 | 71 | self.fc4w_conv = nn.Conv2d(self.hyperInChn, int(self.f3 * self.f4 / feature_size**2), 3, padding=(1, 1)) 72 | self.fc4b_fc = nn.Linear(self.hyperInChn, self.f4) 73 | 74 | self.fc5w_fc = nn.Linear(self.hyperInChn, self.f4) 75 | self.fc5b_fc = nn.Linear(self.hyperInChn, 1) 76 | 77 | def forward(self, img): 78 | feature_size = self.feature_size 79 | 80 | res_out = self.res(img) 81 | 82 | # input vector for target net 83 | target_in_vec = res_out['target_in_vec'].view(-1, self.target_in_size, 1, 1) 84 | 85 | # input features for hyper net 86 | hyper_in_feat = 
self.conv1(res_out['hyper_in_feat']).view(-1, self.hyperInChn, feature_size, feature_size) 87 | 88 | # generating target net weights & biases 89 | target_fc1w = self.fc1w_conv(hyper_in_feat).view(-1, self.f1, self.target_in_size, 1, 1) 90 | target_fc1b = self.fc1b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f1) 91 | 92 | target_fc2w = self.fc2w_conv(hyper_in_feat).view(-1, self.f2, self.f1, 1, 1) 93 | target_fc2b = self.fc2b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f2) 94 | 95 | target_fc3w = self.fc3w_conv(hyper_in_feat).view(-1, self.f3, self.f2, 1, 1) 96 | target_fc3b = self.fc3b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f3) 97 | 98 | target_fc4w = self.fc4w_conv(hyper_in_feat).view(-1, self.f4, self.f3, 1, 1) 99 | target_fc4b = self.fc4b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f4) 100 | 101 | target_fc5w = self.fc5w_fc(self.pool(hyper_in_feat).squeeze()).view(-1, 1, self.f4, 1, 1) 102 | target_fc5b = self.fc5b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, 1) 103 | 104 | out = {} 105 | out['target_in_vec'] = target_in_vec 106 | out['target_fc1w'] = target_fc1w 107 | out['target_fc1b'] = target_fc1b 108 | out['target_fc2w'] = target_fc2w 109 | out['target_fc2b'] = target_fc2b 110 | out['target_fc3w'] = target_fc3w 111 | out['target_fc3b'] = target_fc3b 112 | out['target_fc4w'] = target_fc4w 113 | out['target_fc4b'] = target_fc4b 114 | out['target_fc5w'] = target_fc5w 115 | out['target_fc5b'] = target_fc5b 116 | 117 | return out 118 | 119 | 120 | class Bottleneck(nn.Module): 121 | expansion = 4 122 | 123 | def __init__(self, inplanes, planes, stride=1, downsample=None): 124 | super(Bottleneck, self).__init__() 125 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 126 | self.bn1 = nn.BatchNorm2d(planes) 127 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 128 | self.bn2 = nn.BatchNorm2d(planes) 129 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 130 | self.bn3 = nn.BatchNorm2d(planes * 4) 131 | self.relu = nn.ReLU(inplace=True) 132 | self.downsample = downsample 133 | self.stride = stride 134 | 135 | def forward(self, x): 136 | residual = x 137 | 138 | out = self.conv1(x) 139 | out = self.bn1(out) 140 | out = self.relu(out) 141 | 142 | out = self.conv2(out) 143 | out = self.bn2(out) 144 | out = self.relu(out) 145 | 146 | out = self.conv3(out) 147 | out = self.bn3(out) 148 | 149 | if self.downsample is not None: 150 | residual = self.downsample(x) 151 | 152 | out += residual 153 | out = self.relu(out) 154 | 155 | return out 156 | 157 | 158 | class ResNetBackbone(nn.Module): 159 | 160 | def __init__(self, lda_out_channels, in_chn, block, layers, num_classes=1000): 161 | super(ResNetBackbone, self).__init__() 162 | self.inplanes = 64 163 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 164 | self.bn1 = nn.BatchNorm2d(64) 165 | self.relu = nn.ReLU(inplace=True) 166 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 167 | self.layer1 = self._make_layer(block, 64, layers[0]) 168 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 169 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 170 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 171 | 172 | # local distortion aware module 173 | self.lda1_pool = nn.Sequential( 174 | nn.Conv2d(256, 16, kernel_size=1, stride=1, padding=0, bias=False), 175 | nn.AvgPool2d(7, stride=7), 176 | ) 177 | self.lda1_fc = nn.Linear(16 * 64, 
lda_out_channels) 178 | 179 | self.lda2_pool = nn.Sequential( 180 | nn.Conv2d(512, 32, kernel_size=1, stride=1, padding=0, bias=False), 181 | nn.AvgPool2d(7, stride=7), 182 | ) 183 | self.lda2_fc = nn.Linear(32 * 16, lda_out_channels) 184 | 185 | self.lda3_pool = nn.Sequential( 186 | nn.Conv2d(1024, 64, kernel_size=1, stride=1, padding=0, bias=False), 187 | nn.AvgPool2d(7, stride=7), 188 | ) 189 | self.lda3_fc = nn.Linear(64 * 4, lda_out_channels) 190 | 191 | self.lda4_pool = nn.AvgPool2d(7, stride=7) 192 | self.lda4_fc = nn.Linear(2048, in_chn - lda_out_channels * 3) 193 | 194 | def _make_layer(self, block, planes, blocks, stride=1): 195 | downsample = None 196 | if stride != 1 or self.inplanes != planes * block.expansion: 197 | downsample = nn.Sequential( 198 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 199 | nn.BatchNorm2d(planes * block.expansion), 200 | ) 201 | 202 | layers = [] 203 | layers.append(block(self.inplanes, planes, stride, downsample)) 204 | self.inplanes = planes * block.expansion 205 | for i in range(1, blocks): 206 | layers.append(block(self.inplanes, planes)) 207 | 208 | return nn.Sequential(*layers) 209 | 210 | def forward(self, x): 211 | x = self.conv1(x) 212 | x = self.bn1(x) 213 | x = self.relu(x) 214 | x = self.maxpool(x) 215 | x = self.layer1(x) 216 | 217 | # the same effect as lda operation in the paper, but save much more memory 218 | lda_1 = self.lda1_fc(self.lda1_pool(x).view(x.size(0), -1)) 219 | x = self.layer2(x) 220 | lda_2 = self.lda2_fc(self.lda2_pool(x).view(x.size(0), -1)) 221 | x = self.layer3(x) 222 | lda_3 = self.lda3_fc(self.lda3_pool(x).view(x.size(0), -1)) 223 | x = self.layer4(x) 224 | lda_4 = self.lda4_fc(self.lda4_pool(x).view(x.size(0), -1)) 225 | 226 | vec = torch.cat((lda_1, lda_2, lda_3, lda_4), 1) 227 | 228 | out = {} 229 | out['hyper_in_feat'] = x 230 | out['target_in_vec'] = vec 231 | 232 | return out 233 | 234 | 235 | def resnet50_backbone(lda_out_channels, in_chn, **kwargs): 236 | """Constructs a ResNet-50 model_hyper.""" 237 | model = ResNetBackbone(lda_out_channels, in_chn, Bottleneck, [3, 4, 6, 3], **kwargs) 238 | return model 239 | 240 | 241 | class TargetNet(nn.Module): 242 | """ 243 | Target network for quality prediction. 244 | """ 245 | 246 | def __init__(self, paras): 247 | super(TargetNet, self).__init__() 248 | self.l1 = nn.Sequential( 249 | TargetFC(paras['target_fc1w'], paras['target_fc1b']), 250 | nn.Sigmoid(), 251 | ) 252 | self.l2 = nn.Sequential( 253 | TargetFC(paras['target_fc2w'], paras['target_fc2b']), 254 | nn.Sigmoid(), 255 | ) 256 | 257 | self.l3 = nn.Sequential( 258 | TargetFC(paras['target_fc3w'], paras['target_fc3b']), 259 | nn.Sigmoid(), 260 | ) 261 | 262 | self.l4 = nn.Sequential( 263 | TargetFC(paras['target_fc4w'], paras['target_fc4b']), 264 | nn.Sigmoid(), 265 | TargetFC(paras['target_fc5w'], paras['target_fc5b']), 266 | ) 267 | 268 | def forward(self, x): 269 | q = self.l1(x) 270 | # q = F.dropout(q) 271 | q = self.l2(q) 272 | q = self.l3(q) 273 | q = self.l4(q).squeeze() 274 | return q 275 | 276 | 277 | class TargetFC(nn.Module): 278 | """ 279 | Fully connection operations for target net 280 | Note: 281 | Weights & biases are different for different images in a batch, 282 | thus here we use group convolution for calculating images in a batch with individual weights & biases. 
283 | """ 284 | 285 | def __init__(self, weight, bias): 286 | super(TargetFC, self).__init__() 287 | self.weight = weight 288 | self.bias = bias 289 | 290 | def forward(self, input_): 291 | 292 | input_re = input_.view(-1, input_.shape[0] * input_.shape[1], input_.shape[2], input_.shape[3]) 293 | weight_re = self.weight.view(self.weight.shape[0] * self.weight.shape[1], self.weight.shape[2], 294 | self.weight.shape[3], self.weight.shape[4]) 295 | bias_re = self.bias.view(self.bias.shape[0] * self.bias.shape[1]) 296 | out = F.conv2d(input=input_re, weight=weight_re, bias=bias_re, groups=self.weight.shape[0]) 297 | 298 | return out.view(input_.shape[0], self.weight.shape[1], input_.shape[2], input_.shape[3]) 299 | -------------------------------------------------------------------------------- /facexlib/detection/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from copy import deepcopy 3 | 4 | from facexlib.utils import load_file_from_url 5 | from .retinaface import RetinaFace 6 | 7 | 8 | def init_detection_model(model_name, half=False, device='cuda', model_rootpath=None): 9 | if model_name == 'retinaface_resnet50': 10 | model = RetinaFace(network_name='resnet50', half=half, device=device) 11 | model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth' 12 | elif model_name == 'retinaface_mobile0.25': 13 | model = RetinaFace(network_name='mobile0.25', half=half, device=device) 14 | model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth' 15 | else: 16 | raise NotImplementedError(f'{model_name} is not implemented.') 17 | 18 | model_path = load_file_from_url( 19 | url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) 20 | 21 | # TODO: clean pretrained model 22 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 23 | # remove unnecessary 'module.' 24 | for k, v in deepcopy(load_net).items(): 25 | if k.startswith('module.'): 26 | load_net[k[7:]] = v 27 | load_net.pop(k) 28 | model.load_state_dict(load_net, strict=True) 29 | model.eval() 30 | model = model.to(device) 31 | return model 32 | -------------------------------------------------------------------------------- /facexlib/detection/align_trans.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from .matlab_cp2tform import get_similarity_transform_for_cv2 5 | 6 | # reference facial points, a list of coordinates (x,y) 7 | REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278], 8 | [33.54930115, 92.3655014], [62.72990036, 92.20410156]] 9 | 10 | DEFAULT_CROP_SIZE = (96, 112) 11 | 12 | 13 | class FaceWarpException(Exception): 14 | 15 | def __str__(self): 16 | return 'In File {}:{}'.format(__file__, super.__str__(self)) 17 | 18 | 19 | def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False): 20 | """ 21 | Function: 22 | ---------- 23 | get reference 5 key points according to crop settings: 24 | 0. Set default crop_size: 25 | if default_square: 26 | crop_size = (112, 112) 27 | else: 28 | crop_size = (96, 112) 29 | 1. Pad the crop_size by inner_padding_factor in each side; 30 | 2. Resize crop_size into (output_size - outer_padding*2), 31 | pad into output_size with outer_padding; 32 | 3. 
Output reference_5point; 33 | Parameters: 34 | ---------- 35 | @output_size: (w, h) or None 36 | size of aligned face image 37 | @inner_padding_factor: (w_factor, h_factor) 38 | padding factor for inner (w, h) 39 | @outer_padding: (w_pad, h_pad) 40 | each row is a pair of coordinates (x, y) 41 | @default_square: True or False 42 | if True: 43 | default crop_size = (112, 112) 44 | else: 45 | default crop_size = (96, 112); 46 | !!! make sure, if output_size is not None: 47 | (output_size - outer_padding) 48 | = some_scale * (default crop_size * (1.0 + 49 | inner_padding_factor)) 50 | Returns: 51 | ---------- 52 | @reference_5point: 5x2 np.array 53 | each row is a pair of transformed coordinates (x, y) 54 | """ 55 | 56 | tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) 57 | tmp_crop_size = np.array(DEFAULT_CROP_SIZE) 58 | 59 | # 0) make the inner region a square 60 | if default_square: 61 | size_diff = max(tmp_crop_size) - tmp_crop_size 62 | tmp_5pts += size_diff / 2 63 | tmp_crop_size += size_diff 64 | 65 | if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]): 66 | 67 | return tmp_5pts 68 | 69 | if (inner_padding_factor == 0 and outer_padding == (0, 0)): 70 | if output_size is None: 71 | return tmp_5pts 72 | else: 73 | raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size)) 74 | 75 | # check output size 76 | if not (0 <= inner_padding_factor <= 1.0): 77 | raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') 78 | 79 | if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None): 80 | output_size = tmp_crop_size * \ 81 | (1 + inner_padding_factor * 2).astype(np.int32) 82 | output_size += np.array(outer_padding) 83 | if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]): 84 | raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])') 85 | 86 | # 1) pad the inner region according inner_padding_factor 87 | if inner_padding_factor > 0: 88 | size_diff = tmp_crop_size * inner_padding_factor * 2 89 | tmp_5pts += size_diff / 2 90 | tmp_crop_size += np.round(size_diff).astype(np.int32) 91 | 92 | # 2) resize the padded inner region 93 | size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 94 | 95 | if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]: 96 | raise FaceWarpException('Must have (output_size - outer_padding)' 97 | '= some_scale * (crop_size * (1.0 + inner_padding_factor)') 98 | 99 | scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] 100 | tmp_5pts = tmp_5pts * scale_factor 101 | # size_diff = tmp_crop_size * (scale_factor - min(scale_factor)) 102 | # tmp_5pts = tmp_5pts + size_diff / 2 103 | tmp_crop_size = size_bf_outer_pad 104 | 105 | # 3) add outer_padding to make output_size 106 | reference_5point = tmp_5pts + np.array(outer_padding) 107 | tmp_crop_size = output_size 108 | 109 | return reference_5point 110 | 111 | 112 | def get_affine_transform_matrix(src_pts, dst_pts): 113 | """ 114 | Function: 115 | ---------- 116 | get affine transform matrix 'tfm' from src_pts to dst_pts 117 | Parameters: 118 | ---------- 119 | @src_pts: Kx2 np.array 120 | source points matrix, each row is a pair of coordinates (x, y) 121 | @dst_pts: Kx2 np.array 122 | destination points matrix, each row is a pair of coordinates (x, y) 123 | Returns: 124 | ---------- 125 | @tfm: 2x3 np.array 126 | transform matrix from src_pts to dst_pts 
127 | """ 128 | 129 | tfm = np.float32([[1, 0, 0], [0, 1, 0]]) 130 | n_pts = src_pts.shape[0] 131 | ones = np.ones((n_pts, 1), src_pts.dtype) 132 | src_pts_ = np.hstack([src_pts, ones]) 133 | dst_pts_ = np.hstack([dst_pts, ones]) 134 | 135 | A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) 136 | 137 | if rank == 3: 138 | tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]]) 139 | elif rank == 2: 140 | tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]]) 141 | 142 | return tfm 143 | 144 | 145 | def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='smilarity'): 146 | """ 147 | Function: 148 | ---------- 149 | apply affine transform 'trans' to uv 150 | Parameters: 151 | ---------- 152 | @src_img: 3x3 np.array 153 | input image 154 | @facial_pts: could be 155 | 1)a list of K coordinates (x,y) 156 | or 157 | 2) Kx2 or 2xK np.array 158 | each row or col is a pair of coordinates (x, y) 159 | @reference_pts: could be 160 | 1) a list of K coordinates (x,y) 161 | or 162 | 2) Kx2 or 2xK np.array 163 | each row or col is a pair of coordinates (x, y) 164 | or 165 | 3) None 166 | if None, use default reference facial points 167 | @crop_size: (w, h) 168 | output face image size 169 | @align_type: transform type, could be one of 170 | 1) 'similarity': use similarity transform 171 | 2) 'cv2_affine': use the first 3 points to do affine transform, 172 | by calling cv2.getAffineTransform() 173 | 3) 'affine': use all points to do affine transform 174 | Returns: 175 | ---------- 176 | @face_img: output face image with size (w, h) = @crop_size 177 | """ 178 | 179 | if reference_pts is None: 180 | if crop_size[0] == 96 and crop_size[1] == 112: 181 | reference_pts = REFERENCE_FACIAL_POINTS 182 | else: 183 | default_square = False 184 | inner_padding_factor = 0 185 | outer_padding = (0, 0) 186 | output_size = crop_size 187 | 188 | reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding, 189 | default_square) 190 | 191 | ref_pts = np.float32(reference_pts) 192 | ref_pts_shp = ref_pts.shape 193 | if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: 194 | raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2') 195 | 196 | if ref_pts_shp[0] == 2: 197 | ref_pts = ref_pts.T 198 | 199 | src_pts = np.float32(facial_pts) 200 | src_pts_shp = src_pts.shape 201 | if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: 202 | raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2') 203 | 204 | if src_pts_shp[0] == 2: 205 | src_pts = src_pts.T 206 | 207 | if src_pts.shape != ref_pts.shape: 208 | raise FaceWarpException('facial_pts and reference_pts must have the same shape') 209 | 210 | if align_type == 'cv2_affine': 211 | tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) 212 | elif align_type == 'affine': 213 | tfm = get_affine_transform_matrix(src_pts, ref_pts) 214 | else: 215 | tfm = get_similarity_transform_for_cv2(src_pts, ref_pts) 216 | 217 | face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1])) 218 | 219 | return face_img 220 | -------------------------------------------------------------------------------- /facexlib/detection/matlab_cp2tform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.linalg import inv, lstsq 3 | from numpy.linalg import matrix_rank as rank 4 | from numpy.linalg import norm 5 | 6 | 7 | class MatlabCp2tormException(Exception): 8 | 9 | def __str__(self): 10 | 
11 | 
12 | 
13 | def tformfwd(trans, uv):
14 |     """
15 |     Function:
16 |     ----------
17 |         apply affine transform 'trans' to uv
18 | 
19 |     Parameters:
20 |     ----------
21 |         @trans: 3x3 np.array
22 |             transform matrix
23 |         @uv: Kx2 np.array
24 |             each row is a pair of coordinates (x, y)
25 | 
26 |     Returns:
27 |     ----------
28 |         @xy: Kx2 np.array
29 |             each row is a pair of transformed coordinates (x, y)
30 |     """
31 |     uv = np.hstack((uv, np.ones((uv.shape[0], 1))))
32 |     xy = np.dot(uv, trans)
33 |     xy = xy[:, 0:-1]
34 |     return xy
35 | 
36 | 
37 | def tforminv(trans, uv):
38 |     """
39 |     Function:
40 |     ----------
41 |         apply the inverse of affine transform 'trans' to uv
42 | 
43 |     Parameters:
44 |     ----------
45 |         @trans: 3x3 np.array
46 |             transform matrix
47 |         @uv: Kx2 np.array
48 |             each row is a pair of coordinates (x, y)
49 | 
50 |     Returns:
51 |     ----------
52 |         @xy: Kx2 np.array
53 |             each row is a pair of inverse-transformed coordinates (x, y)
54 |     """
55 |     Tinv = inv(trans)
56 |     xy = tformfwd(Tinv, uv)
57 |     return xy
58 | 
59 | 
60 | def findNonreflectiveSimilarity(uv, xy, options=None):
61 |     options = {'K': 2} if options is None else options
62 | 
63 |     K = options['K']
64 |     M = xy.shape[0]
65 |     x = xy[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
66 |     y = xy[:, 1].reshape((-1, 1))  # use reshape to keep a column vector
67 | 
68 |     tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
69 |     tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
70 |     X = np.vstack((tmp1, tmp2))
71 | 
72 |     u = uv[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
73 |     v = uv[:, 1].reshape((-1, 1))  # use reshape to keep a column vector
74 |     U = np.vstack((u, v))
75 | 
76 |     # We know that X * r = U
77 |     if rank(X) >= 2 * K:
78 |         r, _, _, _ = lstsq(X, U, rcond=-1)
79 |         r = np.squeeze(r)
80 |     else:
81 |         raise Exception('cp2tform:twoUniquePointsReq')
82 |     sc = r[0]
83 |     ss = r[1]
84 |     tx = r[2]
85 |     ty = r[3]
86 | 
87 |     Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
88 |     T = inv(Tinv)
89 |     T[:, 2] = np.array([0, 0, 1])
90 | 
91 |     return T, Tinv
92 | 
93 | 
94 | def findSimilarity(uv, xy, options=None):
95 |     options = {'K': 2} if options is None else options
96 | 
97 |     # uv = np.array(uv)
98 |     # xy = np.array(xy)
99 | 
100 |     # Solve for trans1
101 |     trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
102 | 
103 |     # Solve for trans2
104 | 
105 |     # manually reflect a copy of the xy data across the Y-axis (keep xy itself intact)
106 |     xyR = xy.copy()
107 |     xyR[:, 0] = -1 * xyR[:, 0]
108 | 
109 |     trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
110 | 
111 |     # manually reflect the tform to undo the reflection done on xyR
112 |     TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
113 | 
114 |     trans2 = np.dot(trans2r, TreflectY)
115 | 
116 |     # Figure out if trans1 or trans2 is better
117 |     xy1 = tformfwd(trans1, uv)
118 |     norm1 = norm(xy1 - xy)
119 | 
120 |     xy2 = tformfwd(trans2, uv)
121 |     norm2 = norm(xy2 - xy)
122 | 
123 |     if norm1 <= norm2:
124 |         return trans1, trans1_inv
125 |     else:
126 |         trans2_inv = inv(trans2)
127 |         return trans2, trans2_inv
128 | 
129 | 
130 | def get_similarity_transform(src_pts, dst_pts, reflective=True):
131 |     """
132 |     Function:
133 |     ----------
134 |         Find Similarity Transform Matrix 'trans':
135 |             u = src_pts[:, 0]
136 |             v = src_pts[:, 1]
137 |             x = dst_pts[:, 0]
138 |             y = dst_pts[:, 1]
139 |             [x, y, 1] = [u, v, 1] * trans
140 | 
141 |     Parameters:
142 |     ----------
143 |         @src_pts: Kx2 np.array
144 |             source points, each row is a pair of coordinates (x, y)
145 |         @dst_pts: Kx2 np.array
146 | 
destination points, each row is a pair of transformed 147 | coordinates (x, y) 148 | @reflective: True or False 149 | if True: 150 | use reflective similarity transform 151 | else: 152 | use non-reflective similarity transform 153 | 154 | Returns: 155 | ---------- 156 | @trans: 3x3 np.array 157 | transform matrix from uv to xy 158 | trans_inv: 3x3 np.array 159 | inverse of trans, transform matrix from xy to uv 160 | """ 161 | 162 | if reflective: 163 | trans, trans_inv = findSimilarity(src_pts, dst_pts) 164 | else: 165 | trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts) 166 | 167 | return trans, trans_inv 168 | 169 | 170 | def cvt_tform_mat_for_cv2(trans): 171 | """ 172 | Function: 173 | ---------- 174 | Convert Transform Matrix 'trans' into 'cv2_trans' which could be 175 | directly used by cv2.warpAffine(): 176 | u = src_pts[:, 0] 177 | v = src_pts[:, 1] 178 | x = dst_pts[:, 0] 179 | y = dst_pts[:, 1] 180 | [x, y].T = cv_trans * [u, v, 1].T 181 | 182 | Parameters: 183 | ---------- 184 | @trans: 3x3 np.array 185 | transform matrix from uv to xy 186 | 187 | Returns: 188 | ---------- 189 | @cv2_trans: 2x3 np.array 190 | transform matrix from src_pts to dst_pts, could be directly used 191 | for cv2.warpAffine() 192 | """ 193 | cv2_trans = trans[:, 0:2].T 194 | 195 | return cv2_trans 196 | 197 | 198 | def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True): 199 | """ 200 | Function: 201 | ---------- 202 | Find Similarity Transform Matrix 'cv2_trans' which could be 203 | directly used by cv2.warpAffine(): 204 | u = src_pts[:, 0] 205 | v = src_pts[:, 1] 206 | x = dst_pts[:, 0] 207 | y = dst_pts[:, 1] 208 | [x, y].T = cv_trans * [u, v, 1].T 209 | 210 | Parameters: 211 | ---------- 212 | @src_pts: Kx2 np.array 213 | source points, each row is a pair of coordinates (x, y) 214 | @dst_pts: Kx2 np.array 215 | destination points, each row is a pair of transformed 216 | coordinates (x, y) 217 | reflective: True or False 218 | if True: 219 | use reflective similarity transform 220 | else: 221 | use non-reflective similarity transform 222 | 223 | Returns: 224 | ---------- 225 | @cv2_trans: 2x3 np.array 226 | transform matrix from src_pts to dst_pts, could be directly used 227 | for cv2.warpAffine() 228 | """ 229 | trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective) 230 | cv2_trans = cvt_tform_mat_for_cv2(trans) 231 | 232 | return cv2_trans 233 | 234 | 235 | if __name__ == '__main__': 236 | """ 237 | u = [0, 6, -2] 238 | v = [0, 3, 5] 239 | x = [-1, 0, 4] 240 | y = [-1, -10, 4] 241 | 242 | # In Matlab, run: 243 | # 244 | # uv = [u'; v']; 245 | # xy = [x'; y']; 246 | # tform_sim=cp2tform(uv,xy,'similarity'); 247 | # 248 | # trans = tform_sim.tdata.T 249 | # ans = 250 | # -0.0764 -1.6190 0 251 | # 1.6190 -0.0764 0 252 | # -3.2156 0.0290 1.0000 253 | # trans_inv = tform_sim.tdata.Tinv 254 | # ans = 255 | # 256 | # -0.0291 0.6163 0 257 | # -0.6163 -0.0291 0 258 | # -0.0756 1.9826 1.0000 259 | # xy_m=tformfwd(tform_sim, u,v) 260 | # 261 | # xy_m = 262 | # 263 | # -3.2156 0.0290 264 | # 1.1833 -9.9143 265 | # 5.0323 2.8853 266 | # uv_m=tforminv(tform_sim, x,y) 267 | # 268 | # uv_m = 269 | # 270 | # 0.5698 1.3953 271 | # 6.0872 2.2733 272 | # -2.6570 4.3314 273 | """ 274 | u = [0, 6, -2] 275 | v = [0, 3, 5] 276 | x = [-1, 0, 4] 277 | y = [-1, -10, 4] 278 | 279 | uv = np.array((u, v)).T 280 | xy = np.array((x, y)).T 281 | 282 | print('\n--->uv:') 283 | print(uv) 284 | print('\n--->xy:') 285 | print(xy) 286 | 287 | trans, trans_inv = get_similarity_transform(uv, 
xy) 288 | 289 | print('\n--->trans matrix:') 290 | print(trans) 291 | 292 | print('\n--->trans_inv matrix:') 293 | print(trans_inv) 294 | 295 | print('\n---> apply transform to uv') 296 | print('\nxy_m = uv_augmented * trans') 297 | uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1)))) 298 | xy_m = np.dot(uv_aug, trans) 299 | print(xy_m) 300 | 301 | print('\nxy_m = tformfwd(trans, uv)') 302 | xy_m = tformfwd(trans, uv) 303 | print(xy_m) 304 | 305 | print('\n---> apply inverse transform to xy') 306 | print('\nuv_m = xy_augmented * trans_inv') 307 | xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1)))) 308 | uv_m = np.dot(xy_aug, trans_inv) 309 | print(uv_m) 310 | 311 | print('\nuv_m = tformfwd(trans_inv, xy)') 312 | uv_m = tformfwd(trans_inv, xy) 313 | print(uv_m) 314 | 315 | uv_m = tforminv(trans, xy) 316 | print('\nuv_m = tforminv(trans, xy)') 317 | print(uv_m) 318 | -------------------------------------------------------------------------------- /facexlib/detection/retinaface.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from PIL import Image 7 | from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter 8 | 9 | from facexlib.detection.align_trans import get_reference_facial_points, warp_and_crop_face 10 | from facexlib.detection.retinaface_net import FPN, SSH, MobileNetV1, make_bbox_head, make_class_head, make_landmark_head 11 | from facexlib.detection.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm, 12 | py_cpu_nms) 13 | 14 | 15 | def generate_config(network_name): 16 | 17 | cfg_mnet = { 18 | 'name': 'mobilenet0.25', 19 | 'min_sizes': [[16, 32], [64, 128], [256, 512]], 20 | 'steps': [8, 16, 32], 21 | 'variance': [0.1, 0.2], 22 | 'clip': False, 23 | 'loc_weight': 2.0, 24 | 'gpu_train': True, 25 | 'batch_size': 32, 26 | 'ngpu': 1, 27 | 'epoch': 250, 28 | 'decay1': 190, 29 | 'decay2': 220, 30 | 'image_size': 640, 31 | 'return_layers': { 32 | 'stage1': 1, 33 | 'stage2': 2, 34 | 'stage3': 3 35 | }, 36 | 'in_channel': 32, 37 | 'out_channel': 64 38 | } 39 | 40 | cfg_re50 = { 41 | 'name': 'Resnet50', 42 | 'min_sizes': [[16, 32], [64, 128], [256, 512]], 43 | 'steps': [8, 16, 32], 44 | 'variance': [0.1, 0.2], 45 | 'clip': False, 46 | 'loc_weight': 2.0, 47 | 'gpu_train': True, 48 | 'batch_size': 24, 49 | 'ngpu': 4, 50 | 'epoch': 100, 51 | 'decay1': 70, 52 | 'decay2': 90, 53 | 'image_size': 840, 54 | 'return_layers': { 55 | 'layer2': 1, 56 | 'layer3': 2, 57 | 'layer4': 3 58 | }, 59 | 'in_channel': 256, 60 | 'out_channel': 256 61 | } 62 | 63 | if network_name == 'mobile0.25': 64 | return cfg_mnet 65 | elif network_name == 'resnet50': 66 | return cfg_re50 67 | else: 68 | raise NotImplementedError(f'network_name={network_name}') 69 | 70 | 71 | class RetinaFace(nn.Module): 72 | 73 | def __init__(self, network_name='resnet50', half=False, phase='test', device=None): 74 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device 75 | 76 | super(RetinaFace, self).__init__() 77 | self.half_inference = half 78 | cfg = generate_config(network_name) 79 | self.backbone = cfg['name'] 80 | 81 | self.model_name = f'retinaface_{network_name}' 82 | self.cfg = cfg 83 | self.phase = phase 84 | self.target_size, self.max_size = 1600, 2150 85 | self.resize, self.scale, self.scale1 = 1., None, None 86 | self.mean_tensor = 
torch.tensor([[[[104.]], [[117.]], [[123.]]]], device=self.device) 87 | self.reference = get_reference_facial_points(default_square=True) 88 | # Build network. 89 | backbone = None 90 | if cfg['name'] == 'mobilenet0.25': 91 | backbone = MobileNetV1() 92 | self.body = IntermediateLayerGetter(backbone, cfg['return_layers']) 93 | elif cfg['name'] == 'Resnet50': 94 | import torchvision.models as models 95 | backbone = models.resnet50(pretrained=False) 96 | self.body = IntermediateLayerGetter(backbone, cfg['return_layers']) 97 | 98 | in_channels_stage2 = cfg['in_channel'] 99 | in_channels_list = [ 100 | in_channels_stage2 * 2, 101 | in_channels_stage2 * 4, 102 | in_channels_stage2 * 8, 103 | ] 104 | 105 | out_channels = cfg['out_channel'] 106 | self.fpn = FPN(in_channels_list, out_channels) 107 | self.ssh1 = SSH(out_channels, out_channels) 108 | self.ssh2 = SSH(out_channels, out_channels) 109 | self.ssh3 = SSH(out_channels, out_channels) 110 | 111 | self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg['out_channel']) 112 | self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg['out_channel']) 113 | self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg['out_channel']) 114 | 115 | self.to(self.device) 116 | self.eval() 117 | if self.half_inference: 118 | self.half() 119 | 120 | def forward(self, inputs): 121 | out = self.body(inputs) 122 | 123 | if self.backbone == 'mobilenet0.25' or self.backbone == 'Resnet50': 124 | out = list(out.values()) 125 | # FPN 126 | fpn = self.fpn(out) 127 | 128 | # SSH 129 | feature1 = self.ssh1(fpn[0]) 130 | feature2 = self.ssh2(fpn[1]) 131 | feature3 = self.ssh3(fpn[2]) 132 | features = [feature1, feature2, feature3] 133 | 134 | bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1) 135 | classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1) 136 | tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)] 137 | ldm_regressions = (torch.cat(tmp, dim=1)) 138 | 139 | if self.phase == 'train': 140 | output = (bbox_regressions, classifications, ldm_regressions) 141 | else: 142 | output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions) 143 | return output 144 | 145 | def __detect_faces(self, inputs): 146 | # get scale 147 | height, width = inputs.shape[2:] 148 | self.scale = torch.tensor([width, height, width, height], dtype=torch.float32, device=self.device) 149 | tmp = [width, height, width, height, width, height, width, height, width, height] 150 | self.scale1 = torch.tensor(tmp, dtype=torch.float32, device=self.device) 151 | 152 | # forawrd 153 | inputs = inputs.to(self.device) 154 | if self.half_inference: 155 | inputs = inputs.half() 156 | loc, conf, landmarks = self(inputs) 157 | 158 | # get priorbox 159 | priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:]) 160 | priors = priorbox.forward().to(self.device) 161 | 162 | return loc, conf, landmarks, priors 163 | 164 | # single image detection 165 | def transform(self, image, use_origin_size): 166 | # convert to opencv format 167 | if isinstance(image, Image.Image): 168 | image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) 169 | image = image.astype(np.float32) 170 | 171 | # testing scale 172 | im_size_min = np.min(image.shape[0:2]) 173 | im_size_max = np.max(image.shape[0:2]) 174 | resize = float(self.target_size) / float(im_size_min) 175 | 176 | # prevent bigger axis from being more than max_size 177 | if np.round(resize * im_size_max) > 
self.max_size: 178 | resize = float(self.max_size) / float(im_size_max) 179 | resize = 1 if use_origin_size else resize 180 | 181 | # resize 182 | if resize != 1: 183 | image = cv2.resize(image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR) 184 | 185 | # convert to torch.tensor format 186 | # image -= (104, 117, 123) 187 | image = image.transpose(2, 0, 1) 188 | image = torch.from_numpy(image).unsqueeze(0) 189 | 190 | return image, resize 191 | 192 | def detect_faces( 193 | self, 194 | image, 195 | conf_threshold=0.8, 196 | nms_threshold=0.4, 197 | use_origin_size=True, 198 | ): 199 | image, self.resize = self.transform(image, use_origin_size) 200 | image = image.to(self.device) 201 | if self.half_inference: 202 | image = image.half() 203 | image = image - self.mean_tensor 204 | 205 | loc, conf, landmarks, priors = self.__detect_faces(image) 206 | 207 | boxes = decode(loc.data.squeeze(0), priors.data, self.cfg['variance']) 208 | boxes = boxes * self.scale / self.resize 209 | boxes = boxes.cpu().numpy() 210 | 211 | scores = conf.squeeze(0).data.cpu().numpy()[:, 1] 212 | 213 | landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg['variance']) 214 | landmarks = landmarks * self.scale1 / self.resize 215 | landmarks = landmarks.cpu().numpy() 216 | 217 | # ignore low scores 218 | inds = np.where(scores > conf_threshold)[0] 219 | boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds] 220 | 221 | # sort 222 | order = scores.argsort()[::-1] 223 | boxes, landmarks, scores = boxes[order], landmarks[order], scores[order] 224 | 225 | # do NMS 226 | bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) 227 | keep = py_cpu_nms(bounding_boxes, nms_threshold) 228 | bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep] 229 | # self.t['forward_pass'].toc() 230 | # print(self.t['forward_pass'].average_time) 231 | # import sys 232 | # sys.stdout.flush() 233 | return np.concatenate((bounding_boxes, landmarks), axis=1) 234 | 235 | def __align_multi(self, image, boxes, landmarks, limit=None): 236 | 237 | if len(boxes) < 1: 238 | return [], [] 239 | 240 | if limit: 241 | boxes = boxes[:limit] 242 | landmarks = landmarks[:limit] 243 | 244 | faces = [] 245 | for landmark in landmarks: 246 | facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)] 247 | 248 | warped_face = warp_and_crop_face(np.array(image), facial5points, self.reference, crop_size=(112, 112)) 249 | faces.append(warped_face) 250 | 251 | return np.concatenate((boxes, landmarks), axis=1), faces 252 | 253 | def align_multi(self, img, conf_threshold=0.8, limit=None): 254 | 255 | rlt = self.detect_faces(img, conf_threshold=conf_threshold) 256 | boxes, landmarks = rlt[:, 0:5], rlt[:, 5:] 257 | 258 | return self.__align_multi(img, boxes, landmarks, limit) 259 | 260 | # batched detection 261 | def batched_transform(self, frames, use_origin_size): 262 | """ 263 | Arguments: 264 | frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c], 265 | type=np.float32, BGR format). 266 | use_origin_size: whether to use origin size. 
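        Returns:
            frames: torch.Tensor of shape [n, c, h, w], float32, in BGR channel
                order, resized but not yet mean-subtracted.
            resize: the scale factor that was actually applied
                (1 if use_origin_size is True).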
267 | """ 268 | from_PIL = True if isinstance(frames[0], Image.Image) else False 269 | 270 | # convert to opencv format 271 | if from_PIL: 272 | frames = [cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames] 273 | frames = np.asarray(frames, dtype=np.float32) 274 | 275 | # testing scale 276 | im_size_min = np.min(frames[0].shape[0:2]) 277 | im_size_max = np.max(frames[0].shape[0:2]) 278 | resize = float(self.target_size) / float(im_size_min) 279 | 280 | # prevent bigger axis from being more than max_size 281 | if np.round(resize * im_size_max) > self.max_size: 282 | resize = float(self.max_size) / float(im_size_max) 283 | resize = 1 if use_origin_size else resize 284 | 285 | # resize 286 | if resize != 1: 287 | if not from_PIL: 288 | frames = F.interpolate(frames, scale_factor=resize) 289 | else: 290 | frames = [ 291 | cv2.resize(frame, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR) 292 | for frame in frames 293 | ] 294 | 295 | # convert to torch.tensor format 296 | if not from_PIL: 297 | frames = frames.transpose(1, 2).transpose(1, 3).contiguous() 298 | else: 299 | frames = frames.transpose((0, 3, 1, 2)) 300 | frames = torch.from_numpy(frames) 301 | 302 | return frames, resize 303 | 304 | def batched_detect_faces(self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True): 305 | """ 306 | Arguments: 307 | frames: a list of PIL.Image, or np.array(shape=[n, h, w, c], 308 | type=np.uint8, BGR format). 309 | conf_threshold: confidence threshold. 310 | nms_threshold: nms threshold. 311 | use_origin_size: whether to use origin size. 312 | Returns: 313 | final_bounding_boxes: list of np.array ([n_boxes, 5], 314 | type=np.float32). 315 | final_landmarks: list of np.array ([n_boxes, 10], type=np.float32). 
316 | """ 317 | # self.t['forward_pass'].tic() 318 | frames, self.resize = self.batched_transform(frames, use_origin_size) 319 | frames = frames.to(self.device) 320 | frames = frames - self.mean_tensor 321 | 322 | b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames) 323 | 324 | final_bounding_boxes, final_landmarks = [], [] 325 | 326 | # decode 327 | priors = priors.unsqueeze(0) 328 | b_loc = batched_decode(b_loc, priors, self.cfg['variance']) * self.scale / self.resize 329 | b_landmarks = batched_decode_landm(b_landmarks, priors, self.cfg['variance']) * self.scale1 / self.resize 330 | b_conf = b_conf[:, :, 1] 331 | 332 | # index for selection 333 | b_indice = b_conf > conf_threshold 334 | 335 | # concat 336 | b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float() 337 | 338 | for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice): 339 | 340 | # ignore low scores 341 | pred, landm = pred[inds, :], landm[inds, :] 342 | if pred.shape[0] == 0: 343 | final_bounding_boxes.append(np.array([], dtype=np.float32)) 344 | final_landmarks.append(np.array([], dtype=np.float32)) 345 | continue 346 | 347 | # sort 348 | # order = score.argsort(descending=True) 349 | # box, landm, score = box[order], landm[order], score[order] 350 | 351 | # to CPU 352 | bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy() 353 | 354 | # NMS 355 | keep = py_cpu_nms(bounding_boxes, nms_threshold) 356 | bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep] 357 | 358 | # append 359 | final_bounding_boxes.append(bounding_boxes) 360 | final_landmarks.append(landmarks) 361 | # self.t['forward_pass'].toc(average=True) 362 | # self.batch_time += self.t['forward_pass'].diff 363 | # self.total_frame += len(frames) 364 | # print(self.batch_time / self.total_frame) 365 | 366 | return final_bounding_boxes, final_landmarks 367 | -------------------------------------------------------------------------------- /facexlib/detection/retinaface_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def conv_bn(inp, oup, stride=1, leaky=0): 7 | return nn.Sequential( 8 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), 9 | nn.LeakyReLU(negative_slope=leaky, inplace=True)) 10 | 11 | 12 | def conv_bn_no_relu(inp, oup, stride): 13 | return nn.Sequential( 14 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 15 | nn.BatchNorm2d(oup), 16 | ) 17 | 18 | 19 | def conv_bn1X1(inp, oup, stride, leaky=0): 20 | return nn.Sequential( 21 | nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup), 22 | nn.LeakyReLU(negative_slope=leaky, inplace=True)) 23 | 24 | 25 | def conv_dw(inp, oup, stride, leaky=0.1): 26 | return nn.Sequential( 27 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 28 | nn.BatchNorm2d(inp), 29 | nn.LeakyReLU(negative_slope=leaky, inplace=True), 30 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 31 | nn.BatchNorm2d(oup), 32 | nn.LeakyReLU(negative_slope=leaky, inplace=True), 33 | ) 34 | 35 | 36 | class SSH(nn.Module): 37 | 38 | def __init__(self, in_channel, out_channel): 39 | super(SSH, self).__init__() 40 | assert out_channel % 4 == 0 41 | leaky = 0 42 | if (out_channel <= 64): 43 | leaky = 0.1 44 | self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) 45 | 46 | self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky) 47 | self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, 
out_channel // 4, stride=1) 48 | 49 | self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky) 50 | self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) 51 | 52 | def forward(self, input): 53 | conv3X3 = self.conv3X3(input) 54 | 55 | conv5X5_1 = self.conv5X5_1(input) 56 | conv5X5 = self.conv5X5_2(conv5X5_1) 57 | 58 | conv7X7_2 = self.conv7X7_2(conv5X5_1) 59 | conv7X7 = self.conv7x7_3(conv7X7_2) 60 | 61 | out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) 62 | out = F.relu(out) 63 | return out 64 | 65 | 66 | class FPN(nn.Module): 67 | 68 | def __init__(self, in_channels_list, out_channels): 69 | super(FPN, self).__init__() 70 | leaky = 0 71 | if (out_channels <= 64): 72 | leaky = 0.1 73 | self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky) 74 | self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky) 75 | self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky) 76 | 77 | self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) 78 | self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) 79 | 80 | def forward(self, input): 81 | # names = list(input.keys()) 82 | # input = list(input.values()) 83 | 84 | output1 = self.output1(input[0]) 85 | output2 = self.output2(input[1]) 86 | output3 = self.output3(input[2]) 87 | 88 | up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest') 89 | output2 = output2 + up3 90 | output2 = self.merge2(output2) 91 | 92 | up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest') 93 | output1 = output1 + up2 94 | output1 = self.merge1(output1) 95 | 96 | out = [output1, output2, output3] 97 | return out 98 | 99 | 100 | class MobileNetV1(nn.Module): 101 | 102 | def __init__(self): 103 | super(MobileNetV1, self).__init__() 104 | self.stage1 = nn.Sequential( 105 | conv_bn(3, 8, 2, leaky=0.1), # 3 106 | conv_dw(8, 16, 1), # 7 107 | conv_dw(16, 32, 2), # 11 108 | conv_dw(32, 32, 1), # 19 109 | conv_dw(32, 64, 2), # 27 110 | conv_dw(64, 64, 1), # 43 111 | ) 112 | self.stage2 = nn.Sequential( 113 | conv_dw(64, 128, 2), # 43 + 16 = 59 114 | conv_dw(128, 128, 1), # 59 + 32 = 91 115 | conv_dw(128, 128, 1), # 91 + 32 = 123 116 | conv_dw(128, 128, 1), # 123 + 32 = 155 117 | conv_dw(128, 128, 1), # 155 + 32 = 187 118 | conv_dw(128, 128, 1), # 187 + 32 = 219 119 | ) 120 | self.stage3 = nn.Sequential( 121 | conv_dw(128, 256, 2), # 219 +3 2 = 241 122 | conv_dw(256, 256, 1), # 241 + 64 = 301 123 | ) 124 | self.avg = nn.AdaptiveAvgPool2d((1, 1)) 125 | self.fc = nn.Linear(256, 1000) 126 | 127 | def forward(self, x): 128 | x = self.stage1(x) 129 | x = self.stage2(x) 130 | x = self.stage3(x) 131 | x = self.avg(x) 132 | # x = self.model(x) 133 | x = x.view(-1, 256) 134 | x = self.fc(x) 135 | return x 136 | 137 | 138 | class ClassHead(nn.Module): 139 | 140 | def __init__(self, inchannels=512, num_anchors=3): 141 | super(ClassHead, self).__init__() 142 | self.num_anchors = num_anchors 143 | self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0) 144 | 145 | def forward(self, x): 146 | out = self.conv1x1(x) 147 | out = out.permute(0, 2, 3, 1).contiguous() 148 | 149 | return out.view(out.shape[0], -1, 2) 150 | 151 | 152 | class BboxHead(nn.Module): 153 | 154 | def __init__(self, inchannels=512, num_anchors=3): 155 | super(BboxHead, self).__init__() 156 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, 
padding=0) 157 | 158 | def forward(self, x): 159 | out = self.conv1x1(x) 160 | out = out.permute(0, 2, 3, 1).contiguous() 161 | 162 | return out.view(out.shape[0], -1, 4) 163 | 164 | 165 | class LandmarkHead(nn.Module): 166 | 167 | def __init__(self, inchannels=512, num_anchors=3): 168 | super(LandmarkHead, self).__init__() 169 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0) 170 | 171 | def forward(self, x): 172 | out = self.conv1x1(x) 173 | out = out.permute(0, 2, 3, 1).contiguous() 174 | 175 | return out.view(out.shape[0], -1, 10) 176 | 177 | 178 | def make_class_head(fpn_num=3, inchannels=64, anchor_num=2): 179 | classhead = nn.ModuleList() 180 | for i in range(fpn_num): 181 | classhead.append(ClassHead(inchannels, anchor_num)) 182 | return classhead 183 | 184 | 185 | def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2): 186 | bboxhead = nn.ModuleList() 187 | for i in range(fpn_num): 188 | bboxhead.append(BboxHead(inchannels, anchor_num)) 189 | return bboxhead 190 | 191 | 192 | def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2): 193 | landmarkhead = nn.ModuleList() 194 | for i in range(fpn_num): 195 | landmarkhead.append(LandmarkHead(inchannels, anchor_num)) 196 | return landmarkhead 197 | -------------------------------------------------------------------------------- /facexlib/headpose/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from facexlib.utils import load_file_from_url 4 | from .hopenet_arch import HopeNet 5 | 6 | 7 | def init_headpose_model(model_name, half=False, device='cuda', model_rootpath=None): 8 | if model_name == 'hopenet': 9 | model = HopeNet('resnet', [3, 4, 6, 3], 66) 10 | model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/headpose_hopenet.pth' 11 | else: 12 | raise NotImplementedError(f'{model_name} is not implemented.') 13 | 14 | model_path = load_file_from_url( 15 | url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) 16 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage)['params'] 17 | model.load_state_dict(load_net, strict=True) 18 | model.eval() 19 | model = model.to(device) 20 | return model 21 | -------------------------------------------------------------------------------- /facexlib/headpose/hopenet_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | 6 | class HopeNet(nn.Module): 7 | # Hopenet with 3 output layers for yaw, pitch and roll 8 | # Predicts Euler angles by binning and regression with the expected value 9 | def __init__(self, block, layers, num_bins): 10 | super(HopeNet, self).__init__() 11 | if block == 'resnet': 12 | block = torchvision.models.resnet.Bottleneck 13 | self.inplanes = 64 14 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 15 | self.bn1 = nn.BatchNorm2d(64) 16 | self.relu = nn.ReLU(inplace=True) 17 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 18 | self.layer1 = self._make_layer(block, 64, layers[0]) 19 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 20 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 21 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 22 | self.avgpool = nn.AvgPool2d(7) 23 | self.fc_yaw = nn.Linear(512 * block.expansion, num_bins) 24 | self.fc_pitch = nn.Linear(512 * 
block.expansion, num_bins) 25 | self.fc_roll = nn.Linear(512 * block.expansion, num_bins) 26 | 27 | self.idx_tensor = torch.arange(66).float() 28 | 29 | def _make_layer(self, block, planes, blocks, stride=1): 30 | downsample = None 31 | if stride != 1 or self.inplanes != planes * block.expansion: 32 | downsample = nn.Sequential( 33 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 34 | nn.BatchNorm2d(planes * block.expansion), 35 | ) 36 | 37 | layers = [] 38 | layers.append(block(self.inplanes, planes, stride, downsample)) 39 | self.inplanes = planes * block.expansion 40 | for i in range(1, blocks): 41 | layers.append(block(self.inplanes, planes)) 42 | return nn.Sequential(*layers) 43 | 44 | @staticmethod 45 | def softmax_temperature(tensor, temperature): 46 | result = torch.exp(tensor / temperature) 47 | result = torch.div(result, torch.sum(result, 1).unsqueeze(1).expand_as(result)) 48 | return result 49 | 50 | def bin2degree(self, predict): 51 | predict = self.softmax_temperature(predict, 1) 52 | return torch.sum(predict * self.idx_tensor.type_as(predict), 1) * 3 - 99 53 | 54 | def forward(self, x): 55 | x = self.relu(self.bn1(self.conv1(x))) 56 | x = self.maxpool(x) 57 | 58 | x = self.layer1(x) 59 | x = self.layer2(x) 60 | x = self.layer3(x) 61 | x = self.layer4(x) 62 | 63 | x = self.avgpool(x) 64 | x = x.view(x.size(0), -1) 65 | pre_yaw = self.fc_yaw(x) 66 | pre_pitch = self.fc_pitch(x) 67 | pre_roll = self.fc_roll(x) 68 | 69 | yaw = self.bin2degree(pre_yaw) 70 | pitch = self.bin2degree(pre_pitch) 71 | roll = self.bin2degree(pre_roll) 72 | return yaw, pitch, roll 73 | -------------------------------------------------------------------------------- /facexlib/matting/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from copy import deepcopy 3 | 4 | from facexlib.utils import load_file_from_url 5 | from .modnet import MODNet 6 | 7 | 8 | def init_matting_model(model_name='modnet', half=False, device='cuda', model_rootpath=None): 9 | if model_name == 'modnet': 10 | model = MODNet(backbone_pretrained=False) 11 | model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/matting_modnet_portrait.pth' 12 | else: 13 | raise NotImplementedError(f'{model_name} is not implemented.') 14 | 15 | model_path = load_file_from_url( 16 | url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) 17 | # TODO: clean pretrained model 18 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 19 | # remove unnecessary 'module.' 
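    # Checkpoints saved from a model wrapped in torch.nn.DataParallel store every
    # parameter under a 'module.' prefix; the loop below copies each such entry to
    # its un-prefixed key and drops the original, so the state dict keys match the
    # bare MODNet module built above.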
20 | for k, v in deepcopy(load_net).items(): 21 | if k.startswith('module.'): 22 | load_net[k[7:]] = v 23 | load_net.pop(k) 24 | model.load_state_dict(load_net, strict=True) 25 | model.eval() 26 | model = model.to(device) 27 | return model 28 | -------------------------------------------------------------------------------- /facexlib/matting/backbone.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | from .mobilenetv2 import MobileNetV2 6 | 7 | 8 | class BaseBackbone(nn.Module): 9 | """ Superclass of Replaceable Backbone Model for Semantic Estimation 10 | """ 11 | 12 | def __init__(self, in_channels): 13 | super(BaseBackbone, self).__init__() 14 | self.in_channels = in_channels 15 | 16 | self.model = None 17 | self.enc_channels = [] 18 | 19 | def forward(self, x): 20 | raise NotImplementedError 21 | 22 | def load_pretrained_ckpt(self): 23 | raise NotImplementedError 24 | 25 | 26 | class MobileNetV2Backbone(BaseBackbone): 27 | """ MobileNetV2 Backbone 28 | """ 29 | 30 | def __init__(self, in_channels): 31 | super(MobileNetV2Backbone, self).__init__(in_channels) 32 | 33 | self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None) 34 | self.enc_channels = [16, 24, 32, 96, 1280] 35 | 36 | def forward(self, x): 37 | # x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x) 38 | x = self.model.features[0](x) 39 | x = self.model.features[1](x) 40 | enc2x = x 41 | 42 | # x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x) 43 | x = self.model.features[2](x) 44 | x = self.model.features[3](x) 45 | enc4x = x 46 | 47 | # x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x) 48 | x = self.model.features[4](x) 49 | x = self.model.features[5](x) 50 | x = self.model.features[6](x) 51 | enc8x = x 52 | 53 | # x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x) 54 | x = self.model.features[7](x) 55 | x = self.model.features[8](x) 56 | x = self.model.features[9](x) 57 | x = self.model.features[10](x) 58 | x = self.model.features[11](x) 59 | x = self.model.features[12](x) 60 | x = self.model.features[13](x) 61 | enc16x = x 62 | 63 | # x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x) 64 | x = self.model.features[14](x) 65 | x = self.model.features[15](x) 66 | x = self.model.features[16](x) 67 | x = self.model.features[17](x) 68 | x = self.model.features[18](x) 69 | enc32x = x 70 | return [enc2x, enc4x, enc8x, enc16x, enc32x] 71 | 72 | def load_pretrained_ckpt(self): 73 | # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch 74 | ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt' 75 | if not os.path.exists(ckpt_path): 76 | print('cannot find the pretrained mobilenetv2 backbone') 77 | exit() 78 | 79 | ckpt = torch.load(ckpt_path) 80 | self.model.load_state_dict(ckpt) 81 | -------------------------------------------------------------------------------- /facexlib/matting/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """ This file is adapted from https://github.com/thuyngch/Human-Segmentation-PyTorch""" 2 | 3 | import math 4 | import torch 5 | from torch import nn 6 | 7 | # ------------------------------------------------------------------------------ 8 | # Useful functions 9 | # ------------------------------------------------------------------------------ 10 | 11 | 12 | def _make_divisible(v, 
divisor, min_value=None): 13 | if min_value is None: 14 | min_value = divisor 15 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 16 | # Make sure that round down does not go down by more than 10%. 17 | if new_v < 0.9 * v: 18 | new_v += divisor 19 | return new_v 20 | 21 | 22 | def conv_bn(inp, oup, stride): 23 | return nn.Sequential(nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), nn.ReLU6(inplace=True)) 24 | 25 | 26 | def conv_1x1_bn(inp, oup): 27 | return nn.Sequential(nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), nn.ReLU6(inplace=True)) 28 | 29 | 30 | # ------------------------------------------------------------------------------ 31 | # Class of Inverted Residual block 32 | # ------------------------------------------------------------------------------ 33 | 34 | 35 | class InvertedResidual(nn.Module): 36 | 37 | def __init__(self, inp, oup, stride, expansion, dilation=1): 38 | super(InvertedResidual, self).__init__() 39 | self.stride = stride 40 | assert stride in [1, 2] 41 | 42 | hidden_dim = round(inp * expansion) 43 | self.use_res_connect = self.stride == 1 and inp == oup 44 | 45 | if expansion == 1: 46 | self.conv = nn.Sequential( 47 | # dw 48 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False), 49 | nn.BatchNorm2d(hidden_dim), 50 | nn.ReLU6(inplace=True), 51 | # pw-linear 52 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 53 | nn.BatchNorm2d(oup), 54 | ) 55 | else: 56 | self.conv = nn.Sequential( 57 | # pw 58 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 59 | nn.BatchNorm2d(hidden_dim), 60 | nn.ReLU6(inplace=True), 61 | # dw 62 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False), 63 | nn.BatchNorm2d(hidden_dim), 64 | nn.ReLU6(inplace=True), 65 | # pw-linear 66 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 67 | nn.BatchNorm2d(oup), 68 | ) 69 | 70 | def forward(self, x): 71 | if self.use_res_connect: 72 | return x + self.conv(x) 73 | else: 74 | return self.conv(x) 75 | 76 | 77 | # ------------------------------------------------------------------------------ 78 | # Class of MobileNetV2 79 | # ------------------------------------------------------------------------------ 80 | 81 | 82 | class MobileNetV2(nn.Module): 83 | 84 | def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000): 85 | super(MobileNetV2, self).__init__() 86 | self.in_channels = in_channels 87 | self.num_classes = num_classes 88 | input_channel = 32 89 | last_channel = 1280 90 | interverted_residual_setting = [ 91 | # t, c, n, s 92 | [1, 16, 1, 1], 93 | [expansion, 24, 2, 2], 94 | [expansion, 32, 3, 2], 95 | [expansion, 64, 4, 2], 96 | [expansion, 96, 3, 1], 97 | [expansion, 160, 3, 2], 98 | [expansion, 320, 1, 1], 99 | ] 100 | 101 | # building first layer 102 | input_channel = _make_divisible(input_channel * alpha, 8) 103 | self.last_channel = _make_divisible(last_channel * alpha, 8) if alpha > 1.0 else last_channel 104 | self.features = [conv_bn(self.in_channels, input_channel, 2)] 105 | 106 | # building inverted residual blocks 107 | for t, c, n, s in interverted_residual_setting: 108 | output_channel = _make_divisible(int(c * alpha), 8) 109 | for i in range(n): 110 | if i == 0: 111 | self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t)) 112 | else: 113 | self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t)) 114 | input_channel = output_channel 115 | 116 | # building last 
several layers 117 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 118 | 119 | # make it nn.Sequential 120 | self.features = nn.Sequential(*self.features) 121 | 122 | # building classifier 123 | if self.num_classes is not None: 124 | self.classifier = nn.Sequential( 125 | nn.Dropout(0.2), 126 | nn.Linear(self.last_channel, num_classes), 127 | ) 128 | 129 | # Initialize weights 130 | self._init_weights() 131 | 132 | def forward(self, x): 133 | # Stage1 134 | x = self.features[0](x) 135 | x = self.features[1](x) 136 | # Stage2 137 | x = self.features[2](x) 138 | x = self.features[3](x) 139 | # Stage3 140 | x = self.features[4](x) 141 | x = self.features[5](x) 142 | x = self.features[6](x) 143 | # Stage4 144 | x = self.features[7](x) 145 | x = self.features[8](x) 146 | x = self.features[9](x) 147 | x = self.features[10](x) 148 | x = self.features[11](x) 149 | x = self.features[12](x) 150 | x = self.features[13](x) 151 | # Stage5 152 | x = self.features[14](x) 153 | x = self.features[15](x) 154 | x = self.features[16](x) 155 | x = self.features[17](x) 156 | x = self.features[18](x) 157 | 158 | # Classification 159 | if self.num_classes is not None: 160 | x = x.mean(dim=(2, 3)) 161 | x = self.classifier(x) 162 | 163 | # Output 164 | return x 165 | 166 | def _load_pretrained_model(self, pretrained_file): 167 | pretrain_dict = torch.load(pretrained_file, map_location='cpu') 168 | model_dict = {} 169 | state_dict = self.state_dict() 170 | print('[MobileNetV2] Loading pretrained model...') 171 | for k, v in pretrain_dict.items(): 172 | if k in state_dict: 173 | model_dict[k] = v 174 | else: 175 | print(k, 'is ignored') 176 | state_dict.update(model_dict) 177 | self.load_state_dict(state_dict) 178 | 179 | def _init_weights(self): 180 | for m in self.modules(): 181 | if isinstance(m, nn.Conv2d): 182 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 183 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 184 | if m.bias is not None: 185 | m.bias.data.zero_() 186 | elif isinstance(m, nn.BatchNorm2d): 187 | m.weight.data.fill_(1) 188 | m.bias.data.zero_() 189 | elif isinstance(m, nn.Linear): 190 | n = m.weight.size(1) 191 | m.weight.data.normal_(0, 0.01) 192 | m.bias.data.zero_() 193 | -------------------------------------------------------------------------------- /facexlib/matting/modnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .backbone import MobileNetV2Backbone 6 | 7 | # ------------------------------------------------------------------------------ 8 | # MODNet Basic Modules 9 | # ------------------------------------------------------------------------------ 10 | 11 | 12 | class IBNorm(nn.Module): 13 | """ Combine Instance Norm and Batch Norm into One Layer 14 | """ 15 | 16 | def __init__(self, in_channels): 17 | super(IBNorm, self).__init__() 18 | in_channels = in_channels 19 | self.bnorm_channels = int(in_channels / 2) 20 | self.inorm_channels = in_channels - self.bnorm_channels 21 | 22 | self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True) 23 | self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False) 24 | 25 | def forward(self, x): 26 | bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous()) 27 | in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous()) 28 | 29 | return torch.cat((bn_x, in_x), 1) 30 | 31 | 32 | class Conv2dIBNormRelu(nn.Module): 33 | """ Convolution + IBNorm + ReLu 34 | """ 35 | 36 | def __init__(self, 37 | in_channels, 38 | out_channels, 39 | kernel_size, 40 | stride=1, 41 | padding=0, 42 | dilation=1, 43 | groups=1, 44 | bias=True, 45 | with_ibn=True, 46 | with_relu=True): 47 | super(Conv2dIBNormRelu, self).__init__() 48 | 49 | layers = [ 50 | nn.Conv2d( 51 | in_channels, 52 | out_channels, 53 | kernel_size, 54 | stride=stride, 55 | padding=padding, 56 | dilation=dilation, 57 | groups=groups, 58 | bias=bias) 59 | ] 60 | 61 | if with_ibn: 62 | layers.append(IBNorm(out_channels)) 63 | if with_relu: 64 | layers.append(nn.ReLU(inplace=True)) 65 | 66 | self.layers = nn.Sequential(*layers) 67 | 68 | def forward(self, x): 69 | return self.layers(x) 70 | 71 | 72 | class SEBlock(nn.Module): 73 | """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf 74 | """ 75 | 76 | def __init__(self, in_channels, out_channels, reduction=1): 77 | super(SEBlock, self).__init__() 78 | self.pool = nn.AdaptiveAvgPool2d(1) 79 | self.fc = nn.Sequential( 80 | nn.Linear(in_channels, int(in_channels // reduction), bias=False), nn.ReLU(inplace=True), 81 | nn.Linear(int(in_channels // reduction), out_channels, bias=False), nn.Sigmoid()) 82 | 83 | def forward(self, x): 84 | b, c, _, _ = x.size() 85 | w = self.pool(x).view(b, c) 86 | w = self.fc(w).view(b, c, 1, 1) 87 | 88 | return x * w.expand_as(x) 89 | 90 | 91 | # ------------------------------------------------------------------------------ 92 | # MODNet Branches 93 | # ------------------------------------------------------------------------------ 94 | 95 | 96 | class LRBranch(nn.Module): 97 | """ Low Resolution Branch of MODNet 98 | """ 99 | 100 | def __init__(self, backbone): 101 | super(LRBranch, self).__init__() 102 | 103 | enc_channels = backbone.enc_channels 104 | 105 | self.backbone = backbone 106 | self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4) 107 | self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, 
padding=2) 108 | self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2) 109 | self.conv_lr = Conv2dIBNormRelu( 110 | enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False) 111 | 112 | def forward(self, img, inference): 113 | enc_features = self.backbone.forward(img) 114 | enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4] 115 | 116 | enc32x = self.se_block(enc32x) 117 | lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False) 118 | lr16x = self.conv_lr16x(lr16x) 119 | lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False) 120 | lr8x = self.conv_lr8x(lr8x) 121 | 122 | pred_semantic = None 123 | if not inference: 124 | lr = self.conv_lr(lr8x) 125 | pred_semantic = torch.sigmoid(lr) 126 | 127 | return pred_semantic, lr8x, [enc2x, enc4x] 128 | 129 | 130 | class HRBranch(nn.Module): 131 | """ High Resolution Branch of MODNet 132 | """ 133 | 134 | def __init__(self, hr_channels, enc_channels): 135 | super(HRBranch, self).__init__() 136 | 137 | self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0) 138 | self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1) 139 | 140 | self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0) 141 | self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1) 142 | 143 | self.conv_hr4x = nn.Sequential( 144 | Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1), 145 | Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), 146 | Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), 147 | ) 148 | 149 | self.conv_hr2x = nn.Sequential( 150 | Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1), 151 | Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1), 152 | Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), 153 | Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1), 154 | ) 155 | 156 | self.conv_hr = nn.Sequential( 157 | Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1), 158 | Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False), 159 | ) 160 | 161 | def forward(self, img, enc2x, enc4x, lr8x, inference): 162 | img2x = F.interpolate(img, scale_factor=1 / 2, mode='bilinear', align_corners=False) 163 | img4x = F.interpolate(img, scale_factor=1 / 4, mode='bilinear', align_corners=False) 164 | 165 | enc2x = self.tohr_enc2x(enc2x) 166 | hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1)) 167 | 168 | enc4x = self.tohr_enc4x(enc4x) 169 | hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1)) 170 | 171 | lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) 172 | hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1)) 173 | 174 | hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False) 175 | hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1)) 176 | 177 | pred_detail = None 178 | if not inference: 179 | hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False) 180 | hr = self.conv_hr(torch.cat((hr, img), dim=1)) 181 | pred_detail = torch.sigmoid(hr) 182 | 183 | return pred_detail, hr2x 184 | 185 | 186 | class FusionBranch(nn.Module): 187 | """ Fusion Branch of MODNet 188 | """ 189 | 190 | def __init__(self, 
hr_channels, enc_channels): 191 | super(FusionBranch, self).__init__() 192 | self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2) 193 | 194 | self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1) 195 | self.conv_f = nn.Sequential( 196 | Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1), 197 | Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False), 198 | ) 199 | 200 | def forward(self, img, lr8x, hr2x): 201 | lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False) 202 | lr4x = self.conv_lr4x(lr4x) 203 | lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False) 204 | 205 | f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1)) 206 | f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False) 207 | f = self.conv_f(torch.cat((f, img), dim=1)) 208 | pred_matte = torch.sigmoid(f) 209 | 210 | return pred_matte 211 | 212 | 213 | # ------------------------------------------------------------------------------ 214 | # MODNet 215 | # ------------------------------------------------------------------------------ 216 | 217 | 218 | class MODNet(nn.Module): 219 | """ Architecture of MODNet 220 | """ 221 | 222 | def __init__(self, in_channels=3, hr_channels=32, backbone_pretrained=True): 223 | super(MODNet, self).__init__() 224 | 225 | self.in_channels = in_channels 226 | self.hr_channels = hr_channels 227 | self.backbone_pretrained = backbone_pretrained 228 | 229 | self.backbone = MobileNetV2Backbone(self.in_channels) 230 | 231 | self.lr_branch = LRBranch(self.backbone) 232 | self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels) 233 | self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels) 234 | 235 | for m in self.modules(): 236 | if isinstance(m, nn.Conv2d): 237 | self._init_conv(m) 238 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d): 239 | self._init_norm(m) 240 | 241 | if self.backbone_pretrained: 242 | self.backbone.load_pretrained_ckpt() 243 | 244 | def forward(self, img, inference): 245 | pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference) 246 | pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference) 247 | pred_matte = self.f_branch(img, lr8x, hr2x) 248 | 249 | return pred_semantic, pred_detail, pred_matte 250 | 251 | def freeze_norm(self): 252 | norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d] 253 | for m in self.modules(): 254 | for n in norm_types: 255 | if isinstance(m, n): 256 | m.eval() 257 | continue 258 | 259 | def _init_conv(self, conv): 260 | nn.init.kaiming_uniform_(conv.weight, a=0, mode='fan_in', nonlinearity='relu') 261 | if conv.bias is not None: 262 | nn.init.constant_(conv.bias, 0) 263 | 264 | def _init_norm(self, norm): 265 | if norm.weight is not None: 266 | nn.init.constant_(norm.weight, 1) 267 | nn.init.constant_(norm.bias, 0) 268 | -------------------------------------------------------------------------------- /facexlib/parsing/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from facexlib.utils import load_file_from_url 4 | from .bisenet import BiSeNet 5 | from .parsenet import ParseNet 6 | 7 | 8 | def init_parsing_model(model_name='bisenet', half=False, device='cuda', model_rootpath=None): 9 | if model_name == 'bisenet': 10 | model = BiSeNet(num_class=19) 11 | model_url = 
'https://github.com/xinntao/facexlib/releases/download/v0.2.0/parsing_bisenet.pth' 12 | elif model_name == 'parsenet': 13 | model = ParseNet(in_size=512, out_size=512, parsing_ch=19) 14 | model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth' 15 | else: 16 | raise NotImplementedError(f'{model_name} is not implemented.') 17 | 18 | model_path = load_file_from_url( 19 | url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) 20 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage) 21 | model.load_state_dict(load_net, strict=True) 22 | model.eval() 23 | model = model.to(device) 24 | return model 25 | -------------------------------------------------------------------------------- /facexlib/parsing/bisenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .resnet import ResNet18 6 | 7 | 8 | class ConvBNReLU(nn.Module): 9 | 10 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1): 11 | super(ConvBNReLU, self).__init__() 12 | self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False) 13 | self.bn = nn.BatchNorm2d(out_chan) 14 | 15 | def forward(self, x): 16 | x = self.conv(x) 17 | x = F.relu(self.bn(x)) 18 | return x 19 | 20 | 21 | class BiSeNetOutput(nn.Module): 22 | 23 | def __init__(self, in_chan, mid_chan, num_class): 24 | super(BiSeNetOutput, self).__init__() 25 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) 26 | self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False) 27 | 28 | def forward(self, x): 29 | feat = self.conv(x) 30 | out = self.conv_out(feat) 31 | return out, feat 32 | 33 | 34 | class AttentionRefinementModule(nn.Module): 35 | 36 | def __init__(self, in_chan, out_chan): 37 | super(AttentionRefinementModule, self).__init__() 38 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) 39 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False) 40 | self.bn_atten = nn.BatchNorm2d(out_chan) 41 | self.sigmoid_atten = nn.Sigmoid() 42 | 43 | def forward(self, x): 44 | feat = self.conv(x) 45 | atten = F.avg_pool2d(feat, feat.size()[2:]) 46 | atten = self.conv_atten(atten) 47 | atten = self.bn_atten(atten) 48 | atten = self.sigmoid_atten(atten) 49 | out = torch.mul(feat, atten) 50 | return out 51 | 52 | 53 | class ContextPath(nn.Module): 54 | 55 | def __init__(self): 56 | super(ContextPath, self).__init__() 57 | self.resnet = ResNet18() 58 | self.arm16 = AttentionRefinementModule(256, 128) 59 | self.arm32 = AttentionRefinementModule(512, 128) 60 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 61 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 62 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) 63 | 64 | def forward(self, x): 65 | feat8, feat16, feat32 = self.resnet(x) 66 | h8, w8 = feat8.size()[2:] 67 | h16, w16 = feat16.size()[2:] 68 | h32, w32 = feat32.size()[2:] 69 | 70 | avg = F.avg_pool2d(feat32, feat32.size()[2:]) 71 | avg = self.conv_avg(avg) 72 | avg_up = F.interpolate(avg, (h32, w32), mode='nearest') 73 | 74 | feat32_arm = self.arm32(feat32) 75 | feat32_sum = feat32_arm + avg_up 76 | feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest') 77 | feat32_up = self.conv_head32(feat32_up) 78 | 79 | feat16_arm = self.arm16(feat16) 80 | feat16_sum = 
feat16_arm + feat32_up 81 | feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest') 82 | feat16_up = self.conv_head16(feat16_up) 83 | 84 | return feat8, feat16_up, feat32_up # x8, x8, x16 85 | 86 | 87 | class FeatureFusionModule(nn.Module): 88 | 89 | def __init__(self, in_chan, out_chan): 90 | super(FeatureFusionModule, self).__init__() 91 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 92 | self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False) 93 | self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.sigmoid = nn.Sigmoid() 96 | 97 | def forward(self, fsp, fcp): 98 | fcat = torch.cat([fsp, fcp], dim=1) 99 | feat = self.convblk(fcat) 100 | atten = F.avg_pool2d(feat, feat.size()[2:]) 101 | atten = self.conv1(atten) 102 | atten = self.relu(atten) 103 | atten = self.conv2(atten) 104 | atten = self.sigmoid(atten) 105 | feat_atten = torch.mul(feat, atten) 106 | feat_out = feat_atten + feat 107 | return feat_out 108 | 109 | 110 | class BiSeNet(nn.Module): 111 | 112 | def __init__(self, num_class): 113 | super(BiSeNet, self).__init__() 114 | self.cp = ContextPath() 115 | self.ffm = FeatureFusionModule(256, 256) 116 | self.conv_out = BiSeNetOutput(256, 256, num_class) 117 | self.conv_out16 = BiSeNetOutput(128, 64, num_class) 118 | self.conv_out32 = BiSeNetOutput(128, 64, num_class) 119 | 120 | def forward(self, x, return_feat=False): 121 | h, w = x.size()[2:] 122 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature 123 | feat_sp = feat_res8 # replace spatial path feature with res3b1 feature 124 | feat_fuse = self.ffm(feat_sp, feat_cp8) 125 | 126 | out, feat = self.conv_out(feat_fuse) 127 | out16, feat16 = self.conv_out16(feat_cp8) 128 | out32, feat32 = self.conv_out32(feat_cp16) 129 | 130 | out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True) 131 | out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True) 132 | out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True) 133 | 134 | if return_feat: 135 | feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True) 136 | feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True) 137 | feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True) 138 | return out, out16, out32, feat, feat16, feat32 139 | else: 140 | return out, out16, out32 141 | -------------------------------------------------------------------------------- /facexlib/parsing/parsenet.py: -------------------------------------------------------------------------------- 1 | """Modified from https://github.com/chaofengc/PSFRGAN 2 | """ 3 | import numpy as np 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class NormLayer(nn.Module): 9 | """Normalization Layers. 10 | 11 | Args: 12 | channels: input channels, for batch norm and instance norm. 13 | input_size: input shape without batch size, for layer norm. 
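        norm_type: normalization type, one of 'bn' | 'in' | 'gn' | 'pixel' |
            'layer' | 'none' ('layer' uses the shape above, passed in as
            `normalize_shape`).

    Example (illustrative sketch):
        >>> norm = NormLayer(64, norm_type='bn')    # BatchNorm2d over 64 channels
        >>> norm = NormLayer(64, norm_type='none')  # identity pass-through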
14 | """ 15 | 16 | def __init__(self, channels, normalize_shape=None, norm_type='bn'): 17 | super(NormLayer, self).__init__() 18 | norm_type = norm_type.lower() 19 | self.norm_type = norm_type 20 | if norm_type == 'bn': 21 | self.norm = nn.BatchNorm2d(channels, affine=True) 22 | elif norm_type == 'in': 23 | self.norm = nn.InstanceNorm2d(channels, affine=False) 24 | elif norm_type == 'gn': 25 | self.norm = nn.GroupNorm(32, channels, affine=True) 26 | elif norm_type == 'pixel': 27 | self.norm = lambda x: F.normalize(x, p=2, dim=1) 28 | elif norm_type == 'layer': 29 | self.norm = nn.LayerNorm(normalize_shape) 30 | elif norm_type == 'none': 31 | self.norm = lambda x: x * 1.0 32 | else: 33 | assert 1 == 0, f'Norm type {norm_type} not support.' 34 | 35 | def forward(self, x, ref=None): 36 | if self.norm_type == 'spade': 37 | return self.norm(x, ref) 38 | else: 39 | return self.norm(x) 40 | 41 | 42 | class ReluLayer(nn.Module): 43 | """Relu Layer. 44 | 45 | Args: 46 | relu type: type of relu layer, candidates are 47 | - ReLU 48 | - LeakyReLU: default relu slope 0.2 49 | - PRelu 50 | - SELU 51 | - none: direct pass 52 | """ 53 | 54 | def __init__(self, channels, relu_type='relu'): 55 | super(ReluLayer, self).__init__() 56 | relu_type = relu_type.lower() 57 | if relu_type == 'relu': 58 | self.func = nn.ReLU(True) 59 | elif relu_type == 'leakyrelu': 60 | self.func = nn.LeakyReLU(0.2, inplace=True) 61 | elif relu_type == 'prelu': 62 | self.func = nn.PReLU(channels) 63 | elif relu_type == 'selu': 64 | self.func = nn.SELU(True) 65 | elif relu_type == 'none': 66 | self.func = lambda x: x * 1.0 67 | else: 68 | assert 1 == 0, f'Relu type {relu_type} not support.' 69 | 70 | def forward(self, x): 71 | return self.func(x) 72 | 73 | 74 | class ConvLayer(nn.Module): 75 | 76 | def __init__(self, 77 | in_channels, 78 | out_channels, 79 | kernel_size=3, 80 | scale='none', 81 | norm_type='none', 82 | relu_type='none', 83 | use_pad=True, 84 | bias=True): 85 | super(ConvLayer, self).__init__() 86 | self.use_pad = use_pad 87 | self.norm_type = norm_type 88 | if norm_type in ['bn']: 89 | bias = False 90 | 91 | stride = 2 if scale == 'down' else 1 92 | 93 | self.scale_func = lambda x: x 94 | if scale == 'up': 95 | self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest') 96 | 97 | self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) 
/ 2))) 98 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias) 99 | 100 | self.relu = ReluLayer(out_channels, relu_type) 101 | self.norm = NormLayer(out_channels, norm_type=norm_type) 102 | 103 | def forward(self, x): 104 | out = self.scale_func(x) 105 | if self.use_pad: 106 | out = self.reflection_pad(out) 107 | out = self.conv2d(out) 108 | out = self.norm(out) 109 | out = self.relu(out) 110 | return out 111 | 112 | 113 | class ResidualBlock(nn.Module): 114 | """ 115 | Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html 116 | """ 117 | 118 | def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'): 119 | super(ResidualBlock, self).__init__() 120 | 121 | if scale == 'none' and c_in == c_out: 122 | self.shortcut_func = lambda x: x 123 | else: 124 | self.shortcut_func = ConvLayer(c_in, c_out, 3, scale) 125 | 126 | scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']} 127 | scale_conf = scale_config_dict[scale] 128 | 129 | self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type) 130 | self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none') 131 | 132 | def forward(self, x): 133 | identity = self.shortcut_func(x) 134 | 135 | res = self.conv1(x) 136 | res = self.conv2(res) 137 | return identity + res 138 | 139 | 140 | class ParseNet(nn.Module): 141 | 142 | def __init__(self, 143 | in_size=128, 144 | out_size=128, 145 | min_feat_size=32, 146 | base_ch=64, 147 | parsing_ch=19, 148 | res_depth=10, 149 | relu_type='LeakyReLU', 150 | norm_type='bn', 151 | ch_range=[32, 256]): 152 | super().__init__() 153 | self.res_depth = res_depth 154 | act_args = {'norm_type': norm_type, 'relu_type': relu_type} 155 | min_ch, max_ch = ch_range 156 | 157 | ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731 158 | min_feat_size = min(in_size, min_feat_size) 159 | 160 | down_steps = int(np.log2(in_size // min_feat_size)) 161 | up_steps = int(np.log2(out_size // min_feat_size)) 162 | 163 | # =============== define encoder-body-decoder ==================== 164 | self.encoder = [] 165 | self.encoder.append(ConvLayer(3, base_ch, 3, 1)) 166 | head_ch = base_ch 167 | for i in range(down_steps): 168 | cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2) 169 | self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args)) 170 | head_ch = head_ch * 2 171 | 172 | self.body = [] 173 | for i in range(res_depth): 174 | self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args)) 175 | 176 | self.decoder = [] 177 | for i in range(up_steps): 178 | cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2) 179 | self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args)) 180 | head_ch = head_ch // 2 181 | 182 | self.encoder = nn.Sequential(*self.encoder) 183 | self.body = nn.Sequential(*self.body) 184 | self.decoder = nn.Sequential(*self.decoder) 185 | self.out_img_conv = ConvLayer(ch_clip(head_ch), 3) 186 | self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch) 187 | 188 | def forward(self, x): 189 | feat = self.encoder(x) 190 | x = feat + self.body(feat) 191 | x = self.decoder(x) 192 | out_img = self.out_img_conv(x) 193 | out_mask = self.out_mask_conv(x) 194 | return out_mask, out_img 195 | -------------------------------------------------------------------------------- /facexlib/parsing/resnet.py: -------------------------------------------------------------------------------- 1 | import 
torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1): 6 | """3x3 convolution with padding""" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | 12 | def __init__(self, in_chan, out_chan, stride=1): 13 | super(BasicBlock, self).__init__() 14 | self.conv1 = conv3x3(in_chan, out_chan, stride) 15 | self.bn1 = nn.BatchNorm2d(out_chan) 16 | self.conv2 = conv3x3(out_chan, out_chan) 17 | self.bn2 = nn.BatchNorm2d(out_chan) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.downsample = None 20 | if in_chan != out_chan or stride != 1: 21 | self.downsample = nn.Sequential( 22 | nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False), 23 | nn.BatchNorm2d(out_chan), 24 | ) 25 | 26 | def forward(self, x): 27 | residual = self.conv1(x) 28 | residual = F.relu(self.bn1(residual)) 29 | residual = self.conv2(residual) 30 | residual = self.bn2(residual) 31 | 32 | shortcut = x 33 | if self.downsample is not None: 34 | shortcut = self.downsample(x) 35 | 36 | out = shortcut + residual 37 | out = self.relu(out) 38 | return out 39 | 40 | 41 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 42 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 43 | for i in range(bnum - 1): 44 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 45 | return nn.Sequential(*layers) 46 | 47 | 48 | class ResNet18(nn.Module): 49 | 50 | def __init__(self): 51 | super(ResNet18, self).__init__() 52 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 53 | self.bn1 = nn.BatchNorm2d(64) 54 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 55 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 56 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 57 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 58 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 59 | 60 | def forward(self, x): 61 | x = self.conv1(x) 62 | x = F.relu(self.bn1(x)) 63 | x = self.maxpool(x) 64 | 65 | x = self.layer1(x) 66 | feat8 = self.layer2(x) # 1/8 67 | feat16 = self.layer3(feat8) # 1/16 68 | feat32 = self.layer4(feat16) # 1/32 69 | return feat8, feat16, feat32 70 | -------------------------------------------------------------------------------- /facexlib/recognition/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from facexlib.utils import load_file_from_url 4 | from .arcface_arch import Backbone 5 | 6 | 7 | def init_recognition_model(model_name, half=False, device='cuda', model_rootpath=None): 8 | if model_name == 'arcface': 9 | model = Backbone(num_layers=50, drop_ratio=0.6, mode='ir_se').to('cuda').eval() 10 | model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/recognition_arcface_ir_se50.pth' 11 | else: 12 | raise NotImplementedError(f'{model_name} is not implemented.') 13 | 14 | model_path = load_file_from_url( 15 | url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath) 16 | model.load_state_dict(torch.load(model_path), strict=True) 17 | model.eval() 18 | model = model.to(device) 19 | return model 20 | -------------------------------------------------------------------------------- /facexlib/recognition/arcface_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import namedtuple 3 | 
from torch.nn import (AdaptiveAvgPool2d, BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, MaxPool2d, Module, PReLU, 4 | ReLU, Sequential, Sigmoid) 5 | 6 | # Original Arcface Model 7 | 8 | 9 | class Flatten(Module): 10 | 11 | def forward(self, input): 12 | return input.view(input.size(0), -1) 13 | 14 | 15 | def l2_norm(input, axis=1): 16 | norm = torch.norm(input, 2, axis, True) 17 | output = torch.div(input, norm) 18 | return output 19 | 20 | 21 | class SEModule(Module): 22 | 23 | def __init__(self, channels, reduction): 24 | super(SEModule, self).__init__() 25 | self.avg_pool = AdaptiveAvgPool2d(1) 26 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 27 | self.relu = ReLU(inplace=True) 28 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 29 | self.sigmoid = Sigmoid() 30 | 31 | def forward(self, x): 32 | module_input = x 33 | x = self.avg_pool(x) 34 | x = self.fc1(x) 35 | x = self.relu(x) 36 | x = self.fc2(x) 37 | x = self.sigmoid(x) 38 | return module_input * x 39 | 40 | 41 | class bottleneck_IR(Module): 42 | 43 | def __init__(self, in_channel, depth, stride): 44 | super(bottleneck_IR, self).__init__() 45 | if in_channel == depth: 46 | self.shortcut_layer = MaxPool2d(1, stride) 47 | else: 48 | self.shortcut_layer = Sequential(Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth)) 49 | self.res_layer = Sequential( 50 | BatchNorm2d(in_channel), Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 51 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)) 52 | 53 | def forward(self, x): 54 | shortcut = self.shortcut_layer(x) 55 | res = self.res_layer(x) 56 | return res + shortcut 57 | 58 | 59 | class bottleneck_IR_SE(Module): 60 | 61 | def __init__(self, in_channel, depth, stride): 62 | super(bottleneck_IR_SE, self).__init__() 63 | if in_channel == depth: 64 | self.shortcut_layer = MaxPool2d(1, stride) 65 | else: 66 | self.shortcut_layer = Sequential(Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth)) 67 | self.res_layer = Sequential( 68 | BatchNorm2d(in_channel), Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 69 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth), SEModule(depth, 16)) 70 | 71 | def forward(self, x): 72 | shortcut = self.shortcut_layer(x) 73 | res = self.res_layer(x) 74 | return res + shortcut 75 | 76 | 77 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 78 | '''A named tuple describing a ResNet block.''' 79 | 80 | 81 | def get_block(in_channel, depth, num_units, stride=2): 82 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 83 | 84 | 85 | def get_blocks(num_layers): 86 | if num_layers == 50: 87 | blocks = [ 88 | get_block(in_channel=64, depth=64, num_units=3), 89 | get_block(in_channel=64, depth=128, num_units=4), 90 | get_block(in_channel=128, depth=256, num_units=14), 91 | get_block(in_channel=256, depth=512, num_units=3) 92 | ] 93 | elif num_layers == 100: 94 | blocks = [ 95 | get_block(in_channel=64, depth=64, num_units=3), 96 | get_block(in_channel=64, depth=128, num_units=13), 97 | get_block(in_channel=128, depth=256, num_units=30), 98 | get_block(in_channel=256, depth=512, num_units=3) 99 | ] 100 | elif num_layers == 152: 101 | blocks = [ 102 | get_block(in_channel=64, depth=64, num_units=3), 103 | get_block(in_channel=64, depth=128, num_units=8), 104 | 
get_block(in_channel=128, depth=256, num_units=36), 105 | get_block(in_channel=256, depth=512, num_units=3) 106 | ] 107 | return blocks 108 | 109 | 110 | class Backbone(Module): 111 | 112 | def __init__(self, num_layers, drop_ratio, mode='ir'): 113 | super(Backbone, self).__init__() 114 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152' 115 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se' 116 | blocks = get_blocks(num_layers) 117 | if mode == 'ir': 118 | unit_module = bottleneck_IR 119 | elif mode == 'ir_se': 120 | unit_module = bottleneck_IR_SE 121 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64)) 122 | self.output_layer = Sequential( 123 | BatchNorm2d(512), Dropout(drop_ratio), Flatten(), Linear(512 * 7 * 7, 512), BatchNorm1d(512)) 124 | modules = [] 125 | for block in blocks: 126 | for bottleneck in block: 127 | modules.append(unit_module(bottleneck.in_channel, bottleneck.depth, bottleneck.stride)) 128 | self.body = Sequential(*modules) 129 | 130 | def forward(self, x): 131 | x = self.input_layer(x) 132 | x = self.body(x) 133 | x = self.output_layer(x) 134 | return l2_norm(x) 135 | 136 | 137 | # MobileFaceNet 138 | 139 | 140 | class Conv_block(Module): 141 | 142 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 143 | super(Conv_block, self).__init__() 144 | self.conv = Conv2d( 145 | in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False) 146 | self.bn = BatchNorm2d(out_c) 147 | self.prelu = PReLU(out_c) 148 | 149 | def forward(self, x): 150 | x = self.conv(x) 151 | x = self.bn(x) 152 | x = self.prelu(x) 153 | return x 154 | 155 | 156 | class Linear_block(Module): 157 | 158 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 159 | super(Linear_block, self).__init__() 160 | self.conv = Conv2d( 161 | in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False) 162 | self.bn = BatchNorm2d(out_c) 163 | 164 | def forward(self, x): 165 | x = self.conv(x) 166 | x = self.bn(x) 167 | return x 168 | 169 | 170 | class Depth_Wise(Module): 171 | 172 | def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1): 173 | super(Depth_Wise, self).__init__() 174 | self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) 175 | self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride) 176 | self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) 177 | self.residual = residual 178 | 179 | def forward(self, x): 180 | if self.residual: 181 | short_cut = x 182 | x = self.conv(x) 183 | x = self.conv_dw(x) 184 | x = self.project(x) 185 | if self.residual: 186 | output = short_cut + x 187 | else: 188 | output = x 189 | return output 190 | 191 | 192 | class Residual(Module): 193 | 194 | def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)): 195 | super(Residual, self).__init__() 196 | modules = [] 197 | for _ in range(num_block): 198 | modules.append( 199 | Depth_Wise(c, c, residual=True, kernel=kernel, padding=padding, stride=stride, groups=groups)) 200 | self.model = Sequential(*modules) 201 | 202 | def forward(self, x): 203 | return self.model(x) 204 | 205 | 206 | class MobileFaceNet(Module): 207 | 208 | def __init__(self, embedding_size): 209 | super(MobileFaceNet, 
self).__init__() 210 | self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1)) 211 | self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64) 212 | self.conv_23 = Depth_Wise(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128) 213 | self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)) 214 | self.conv_34 = Depth_Wise(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256) 215 | self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)) 216 | self.conv_45 = Depth_Wise(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512) 217 | self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)) 218 | self.conv_6_sep = Conv_block(128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) 219 | self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)) 220 | self.conv_6_flatten = Flatten() 221 | self.linear = Linear(512, embedding_size, bias=False) 222 | self.bn = BatchNorm1d(embedding_size) 223 | 224 | def forward(self, x): 225 | out = self.conv1(x) 226 | out = self.conv2_dw(out) 227 | out = self.conv_23(out) 228 | out = self.conv_3(out) 229 | out = self.conv_34(out) 230 | out = self.conv_4(out) 231 | out = self.conv_45(out) 232 | out = self.conv_5(out) 233 | out = self.conv_6_sep(out) 234 | out = self.conv_6_dw(out) 235 | out = self.conv_6_flatten(out) 236 | out = self.linear(out) 237 | out = self.bn(out) 238 | return l2_norm(out) 239 | -------------------------------------------------------------------------------- /facexlib/tracking/README.md: -------------------------------------------------------------------------------- 1 | https://github.com/abewley/sort 2 | -------------------------------------------------------------------------------- /facexlib/tracking/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinntao/facexlib/260620ae93990a300f4b16448df9bb459f1caba9/facexlib/tracking/__init__.py -------------------------------------------------------------------------------- /facexlib/tracking/data_association.py: -------------------------------------------------------------------------------- 1 | """ 2 | For each detected item, it computes the intersection over union (IOU) w.r.t. 3 | each tracked object. (IOU matrix) 4 | Then, it applies the Hungarian algorithm (via linear_assignment) to assign each 5 | det. item to the best possible tracked item (i.e. 
to the one with max IOU) 6 | """ 7 | 8 | import numpy as np 9 | from numba import jit 10 | from scipy.optimize import linear_sum_assignment as linear_assignment 11 | 12 | 13 | @jit 14 | def iou(bb_test, bb_gt): 15 | """Computes IOU between two bboxes in the form [x1,y1,x2,y2] 16 | """ 17 | xx1 = np.maximum(bb_test[0], bb_gt[0]) 18 | yy1 = np.maximum(bb_test[1], bb_gt[1]) 19 | xx2 = np.minimum(bb_test[2], bb_gt[2]) 20 | yy2 = np.minimum(bb_test[3], bb_gt[3]) 21 | w = np.maximum(0., xx2 - xx1) 22 | h = np.maximum(0., yy2 - yy1) 23 | wh = w * h 24 | o = wh / ((bb_test[2] - bb_test[0]) * (bb_test[3] - bb_test[1]) + (bb_gt[2] - bb_gt[0]) * 25 | (bb_gt[3] - bb_gt[1]) - wh) 26 | return (o) 27 | 28 | 29 | def associate_detections_to_trackers(detections, trackers, iou_threshold=0.25): 30 | """Assigns detections to tracked object (both represented as bounding boxes) 31 | 32 | Returns: 33 | 3 lists of matches, unmatched_detections and unmatched_trackers. 34 | """ 35 | if len(trackers) == 0: 36 | return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 5), dtype=int) 37 | 38 | iou_matrix = np.zeros((len(detections), len(trackers)), dtype=np.float32) 39 | 40 | for d, det in enumerate(detections): 41 | for t, trk in enumerate(trackers): 42 | iou_matrix[d, t] = iou(det, trk) 43 | # The linear assignment module tries to minimize the total assignment cost. 44 | # In our case we pass -iou_matrix as we want to maximise the total IOU 45 | # between track predictions and the frame detection. 46 | row_ind, col_ind = linear_assignment(-iou_matrix) 47 | 48 | unmatched_detections = [] 49 | for d, det in enumerate(detections): 50 | if d not in row_ind: 51 | unmatched_detections.append(d) 52 | unmatched_trackers = [] 53 | for t, trk in enumerate(trackers): 54 | if t not in col_ind: 55 | unmatched_trackers.append(t) 56 | 57 | # filter out matched with low IOU 58 | matches = [] 59 | for row, col in zip(row_ind, col_ind): 60 | if iou_matrix[row, col] < iou_threshold: 61 | unmatched_detections.append(row) 62 | unmatched_trackers.append(col) 63 | else: 64 | matches.append(np.array([[row, col]])) 65 | 66 | if len(matches) == 0: 67 | matches = np.empty((0, 2), dtype=int) 68 | else: 69 | matches = np.concatenate(matches, axis=0) 70 | 71 | return matches, np.array(unmatched_detections), np.array(unmatched_trackers) 72 | -------------------------------------------------------------------------------- /facexlib/tracking/kalman_tracker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from filterpy.kalman import KalmanFilter 3 | 4 | 5 | def convert_bbox_to_z(bbox): 6 | """Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form 7 | [x,y,s,r] where x,y is the centre of the box and s is the scale/area and 8 | r is the aspect ratio 9 | """ 10 | w = bbox[2] - bbox[0] 11 | h = bbox[3] - bbox[1] 12 | x = bbox[0] + w / 2. 13 | y = bbox[1] + h / 2. 
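    # Illustrative worked example for the remaining two steps: with
    # bbox = [10, 20, 50, 100], w = 40 and h = 80, so the centre is (30, 60),
    # the area s = 40 * 80 = 3200 and the aspect ratio r = 40 / 80 = 0.5.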
14 | s = w * h # scale is just area 15 | r = w / float(h) 16 | return np.array([x, y, s, r]).reshape((4, 1)) 17 | 18 | 19 | def convert_x_to_bbox(x, score=None): 20 | """Takes a bounding box in the centre form [x,y,s,r] and returns it in 21 | the form [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom 22 | right 23 | """ 24 | w = np.sqrt(x[2] * x[3]) 25 | h = x[2] / w 26 | if score is None: 27 | return np.array([x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2.]).reshape((1, 4)) 28 | else: 29 | return np.array([x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2., score]).reshape((1, 5)) 30 | 31 | 32 | class KalmanBoxTracker(object): 33 | """This class represents the internal state of individual tracked objects 34 | observed as bbox. 35 | doc: https://filterpy.readthedocs.io/en/latest/kalman/KalmanFilter.html 36 | """ 37 | count = 0 38 | 39 | def __init__(self, bbox): 40 | """Initialize a tracker using initial bounding box. 41 | """ 42 | # define constant velocity model 43 | # TODO: x: what is the meanning of x[4:7], v? 44 | self.kf = KalmanFilter(dim_x=7, dim_z=4) 45 | # F (dim_x, dim_x): state transition matrix 46 | self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1, 0], [0, 0, 1, 0, 0, 0, 47 | 1], [0, 0, 0, 1, 0, 0, 0], 48 | [0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 0, 1]]) 49 | # H (dim_z, dim_x): measurement function 50 | self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], 51 | [0, 0, 0, 1, 0, 0, 0]]) 52 | # R (dim_z, dim_z): measurement uncertainty/noise 53 | self.kf.R[2:, 2:] *= 10. 54 | # P (dim_x, dim_x): covariance matrix 55 | # give high uncertainty to the unobservable initial velocities 56 | self.kf.P[4:, 4:] *= 1000. 57 | self.kf.P *= 10. 58 | # Q (dim_x, dim_x): Process uncertainty/noise 59 | self.kf.Q[-1, -1] *= 0.01 60 | self.kf.Q[4:, 4:] *= 0.01 61 | # x (dim_x, 1): filter state estimate 62 | self.kf.x[:4] = convert_bbox_to_z(bbox) 63 | 64 | self.time_since_update = 0 65 | self.id = KalmanBoxTracker.count 66 | KalmanBoxTracker.count += 1 67 | self.history = [] 68 | self.hits = 0 69 | self.hit_streak = 0 70 | self.age = 0 71 | 72 | # 解决画面中无人脸检测到时而导致的原有追踪器人像预测的漂移bug 73 | self.predict_num = 0 # 连续预测的数目 74 | 75 | # additional fields 76 | self.face_attributes = [] 77 | 78 | def update(self, bbox): 79 | """Updates the state vector with observed bbox. 80 | """ 81 | self.time_since_update = 0 82 | self.history = [] 83 | self.hits += 1 84 | self.hit_streak += 1 # 连续命中 85 | if bbox != []: 86 | self.kf.update(convert_bbox_to_z(bbox)) 87 | self.predict_num = 0 88 | else: 89 | self.predict_num += 1 90 | 91 | def predict(self): 92 | """Advances the state vector and returns the predicted bounding box 93 | estimate. 
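
        Note: under the constant velocity model defined by F in __init__, the
        state x is [x, y, s, r, vx, vy, vs], so the guard below zeroes the area
        velocity vs whenever it would push the predicted area s non-positive.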
94 | """ 95 | 96 | if (self.kf.x[6] + self.kf.x[2]) <= 0: 97 | self.kf.x[6] *= 0.0 98 | self.kf.predict() 99 | self.age += 1 100 | if self.time_since_update > 0: 101 | self.hit_streak = 0 102 | self.time_since_update += 1 103 | self.history.append(convert_x_to_bbox(self.kf.x)) 104 | return self.history[-1][0] 105 | 106 | def get_state(self): 107 | """Returns the current bounding box estimate.""" 108 | return convert_x_to_bbox(self.kf.x)[0] 109 | -------------------------------------------------------------------------------- /facexlib/tracking/sort.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from facexlib.tracking.data_association import associate_detections_to_trackers 4 | from facexlib.tracking.kalman_tracker import KalmanBoxTracker 5 | 6 | 7 | class SORT(object): 8 | """SORT: A Simple, Online and Realtime Tracker. 9 | 10 | Ref: https://github.com/abewley/sort 11 | """ 12 | 13 | def __init__(self, max_age=1, min_hits=3, iou_threshold=0.3): 14 | self.max_age = max_age 15 | self.min_hits = min_hits # 最小的连续命中, 只有满足的才会被返回 16 | self.iou_threshold = iou_threshold 17 | self.trackers = [] 18 | self.frame_count = 0 19 | 20 | def update(self, dets, img_size, additional_attr, detect_interval): 21 | """This method must be called once for each frame even with 22 | empty detections. 23 | NOTE:as in practical realtime MOT, the detector doesn't run on every 24 | single frame. 25 | 26 | Args: 27 | dets (Numpy array): detections in the format 28 | [[x0,y0,x1,y1,score], [x0,y0,x1,y1,score], ...] 29 | 30 | Returns: 31 | a similar array, where the last column is the object ID. 32 | """ 33 | self.frame_count += 1 34 | 35 | # get predicted locations from existing trackers 36 | trks = np.zeros((len(self.trackers), 5)) 37 | to_del = [] # To be deleted 38 | ret = [] 39 | # predict tracker position using Kalman filter 40 | for t, trk in enumerate(trks): 41 | pos = self.trackers[t].predict() # Kalman predict ,very fast ,<1ms 42 | trk[:] = [pos[0], pos[1], pos[2], pos[3], 0] 43 | if np.any(np.isnan(pos)): 44 | to_del.append(t) 45 | trks = np.ma.compress_rows(np.ma.masked_invalid(trks)) 46 | for t in reversed(to_del): 47 | self.trackers.pop(t) 48 | 49 | if dets != []: 50 | matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers( # noqa: E501 51 | dets, trks) 52 | 53 | # update matched trackers with assigned detections 54 | for t, trk in enumerate(self.trackers): 55 | if t not in unmatched_trks: 56 | d = matched[np.where(matched[:, 1] == t)[0], 0] 57 | trk.update(dets[d, :][0]) 58 | trk.face_attributes.append(additional_attr[d[0]]) 59 | 60 | # create and initialize new trackers for unmatched detections 61 | for i in unmatched_dets: 62 | trk = KalmanBoxTracker(dets[i, :]) 63 | trk.face_attributes.append(additional_attr[i]) 64 | print(f'New tracker: {trk.id + 1}.') 65 | self.trackers.append(trk) 66 | 67 | i = len(self.trackers) 68 | for trk in reversed(self.trackers): 69 | if dets == []: 70 | trk.update([]) 71 | 72 | d = trk.get_state() 73 | # get return tracklet 74 | # 1) time_since_update < 1: detected 75 | # 2) i) hit_streak >= min_hits: 最小的连续命中 76 | # ii) frame_count <= min_hits: 最开始的几帧 77 | if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits): 78 | ret.append(np.concatenate((d, [trk.id + 1])).reshape(1, -1)) # +1 as MOT benchmark requires positive 79 | i -= 1 80 | 81 | # remove dead tracklet 82 | # 1) time_since_update >= max_age: 多久没有更新了 83 | # 2) predict_num: 连续预测的帧数 84 | 
# 3) out of image size 85 | if (trk.time_since_update >= self.max_age) or (trk.predict_num >= detect_interval) or ( 86 | d[2] < 0 or d[3] < 0 or d[0] > img_size[1] or d[1] > img_size[0]): 87 | print(f'Remove tracker: {trk.id + 1}') 88 | self.trackers.pop(i) 89 | if len(ret) > 0: 90 | return np.concatenate(ret) 91 | else: 92 | return np.empty((0, 5)) 93 | -------------------------------------------------------------------------------- /facexlib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back 2 | from .misc import img2tensor, load_file_from_url, scandir 3 | 4 | __all__ = [ 5 | 'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url', 'paste_face_back', 6 | 'img2tensor', 'scandir' 7 | ] 8 | -------------------------------------------------------------------------------- /facexlib/utils/face_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def compute_increased_bbox(bbox, increase_area, preserve_aspect=True): 7 | left, top, right, bot = bbox 8 | width = right - left 9 | height = bot - top 10 | 11 | if preserve_aspect: 12 | width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width)) 13 | height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height)) 14 | else: 15 | width_increase = height_increase = increase_area 16 | left = int(left - width_increase * width) 17 | top = int(top - height_increase * height) 18 | right = int(right + width_increase * width) 19 | bot = int(bot + height_increase * height) 20 | return (left, top, right, bot) 21 | 22 | 23 | def get_valid_bboxes(bboxes, h, w): 24 | left = max(bboxes[0], 0) 25 | top = max(bboxes[1], 0) 26 | right = min(bboxes[2], w) 27 | bottom = min(bboxes[3], h) 28 | return (left, top, right, bottom) 29 | 30 | 31 | def align_crop_face_landmarks(img, 32 | landmarks, 33 | output_size, 34 | transform_size=None, 35 | enable_padding=True, 36 | return_inverse_affine=False, 37 | shrink_ratio=(1, 1)): 38 | """Align and crop face with landmarks. 39 | 40 | The output_size and transform_size are based on width. The height is 41 | adjusted based on shrink_ratio_h/shring_ration_w. 42 | 43 | Modified from: 44 | https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py 45 | 46 | Args: 47 | img (Numpy array): Input image. 48 | landmarks (Numpy array): 5 or 68 or 98 landmarks. 49 | output_size (int): Output face size. 50 | transform_size (ing): Transform size. Usually the four time of 51 | output_size. 52 | enable_padding (float): Default: True. 53 | shrink_ratio (float | tuple[float] | list[float]): Shring the whole 54 | face for height and width (crop larger area). Default: (1, 1). 55 | 56 | Returns: 57 | (Numpy array): Cropped face. 
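        (Numpy array | None): Inverse affine matrix for pasting the face back;
            None unless return_inverse_affine is True.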
58 | """ 59 | lm_type = 'retinaface_5' # Options: dlib_5, retinaface_5 60 | 61 | if isinstance(shrink_ratio, (float, int)): 62 | shrink_ratio = (shrink_ratio, shrink_ratio) 63 | if transform_size is None: 64 | transform_size = output_size * 4 65 | 66 | # Parse landmarks 67 | lm = np.array(landmarks) 68 | if lm.shape[0] == 5 and lm_type == 'retinaface_5': 69 | eye_left = lm[0] 70 | eye_right = lm[1] 71 | mouth_avg = (lm[3] + lm[4]) * 0.5 72 | elif lm.shape[0] == 5 and lm_type == 'dlib_5': 73 | lm_eye_left = lm[2:4] 74 | lm_eye_right = lm[0:2] 75 | eye_left = np.mean(lm_eye_left, axis=0) 76 | eye_right = np.mean(lm_eye_right, axis=0) 77 | mouth_avg = lm[4] 78 | elif lm.shape[0] == 68: 79 | lm_eye_left = lm[36:42] 80 | lm_eye_right = lm[42:48] 81 | eye_left = np.mean(lm_eye_left, axis=0) 82 | eye_right = np.mean(lm_eye_right, axis=0) 83 | mouth_avg = (lm[48] + lm[54]) * 0.5 84 | elif lm.shape[0] == 98: 85 | lm_eye_left = lm[60:68] 86 | lm_eye_right = lm[68:76] 87 | eye_left = np.mean(lm_eye_left, axis=0) 88 | eye_right = np.mean(lm_eye_right, axis=0) 89 | mouth_avg = (lm[76] + lm[82]) * 0.5 90 | 91 | eye_avg = (eye_left + eye_right) * 0.5 92 | eye_to_eye = eye_right - eye_left 93 | eye_to_mouth = mouth_avg - eye_avg 94 | 95 | # Get the oriented crop rectangle 96 | # x: half width of the oriented crop rectangle 97 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 98 | # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise 99 | # norm with the hypotenuse: get the direction 100 | x /= np.hypot(*x) # get the hypotenuse of a right triangle 101 | rect_scale = 1 # TODO: you can edit it to get larger rect 102 | x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale) 103 | # y: half height of the oriented crop rectangle 104 | y = np.flipud(x) * [-1, 1] 105 | 106 | x *= shrink_ratio[1] # width 107 | y *= shrink_ratio[0] # height 108 | 109 | # c: center 110 | c = eye_avg + eye_to_mouth * 0.1 111 | # quad: (left_top, left_bottom, right_bottom, right_top) 112 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 113 | # qsize: side length of the square 114 | qsize = np.hypot(*x) * 2 115 | 116 | quad_ori = np.copy(quad) 117 | # Shrink, for large face 118 | # TODO: do we really need shrink 119 | shrink = int(np.floor(qsize / output_size * 0.5)) 120 | if shrink > 1: 121 | h, w = img.shape[0:2] 122 | rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink))) 123 | img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA) 124 | quad /= shrink 125 | qsize /= shrink 126 | 127 | # Crop 128 | h, w = img.shape[0:2] 129 | border = max(int(np.rint(qsize * 0.1)), 3) 130 | crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), 131 | int(np.ceil(max(quad[:, 1])))) 132 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h)) 133 | if crop[2] - crop[0] < w or crop[3] - crop[1] < h: 134 | img = img[crop[1]:crop[3], crop[0]:crop[2], :] 135 | quad -= crop[0:2] 136 | 137 | # Pad 138 | # pad: (width_left, height_top, width_right, height_bottom) 139 | h, w = img.shape[0:2] 140 | pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), 141 | int(np.ceil(max(quad[:, 1])))) 142 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - w + border, 0), max(pad[3] - h + border, 0)) 143 | if enable_padding and max(pad) > border - 4: 144 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 145 | img = 
np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 146 | h, w = img.shape[0:2] 147 | y, x, _ = np.ogrid[:h, :w, :1] 148 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], 149 | np.float32(w - 1 - x) / pad[2]), 150 | 1.0 - np.minimum(np.float32(y) / pad[1], 151 | np.float32(h - 1 - y) / pad[3])) 152 | blur = int(qsize * 0.02) 153 | if blur % 2 == 0: 154 | blur += 1 155 | blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur)) 156 | 157 | img = img.astype('float32') 158 | img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 159 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) 160 | img = np.clip(img, 0, 255) # float32, [0, 255] 161 | quad += pad[:2] 162 | 163 | # Transform use cv2 164 | h_ratio = shrink_ratio[0] / shrink_ratio[1] 165 | dst_h, dst_w = int(transform_size * h_ratio), transform_size 166 | template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]]) 167 | # use cv2.LMEDS method for the equivalence to skimage transform 168 | # ref: https://blog.csdn.net/yichxi/article/details/115827338 169 | affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0] 170 | cropped_face = cv2.warpAffine( 171 | img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132)) # gray 172 | 173 | if output_size < transform_size: 174 | cropped_face = cv2.resize( 175 | cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR) 176 | 177 | if return_inverse_affine: 178 | dst_h, dst_w = int(output_size * h_ratio), output_size 179 | template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]]) 180 | # use cv2.LMEDS method for the equivalence to skimage transform 181 | # ref: https://blog.csdn.net/yichxi/article/details/115827338 182 | affine_matrix = cv2.estimateAffinePartial2D( 183 | quad_ori, np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]), method=cv2.LMEDS)[0] 184 | inverse_affine = cv2.invertAffineTransform(affine_matrix) 185 | else: 186 | inverse_affine = None 187 | return cropped_face, inverse_affine 188 | 189 | 190 | def paste_face_back(img, face, inverse_affine): 191 | h, w = img.shape[0:2] 192 | face_h, face_w = face.shape[0:2] 193 | inv_restored = cv2.warpAffine(face, inverse_affine, (w, h)) 194 | mask = np.ones((face_h, face_w, 3), dtype=np.float32) 195 | inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h)) 196 | # remove the black borders 197 | inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8)) 198 | inv_restored_remove_border = inv_mask_erosion * inv_restored 199 | total_face_area = np.sum(inv_mask_erosion) // 3 200 | # compute the fusion edge based on the area of face 201 | w_edge = int(total_face_area**0.5) // 20 202 | erosion_radius = w_edge * 2 203 | inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) 204 | blur_size = w_edge * 2 205 | inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) 206 | img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img 207 | # float32, [0, 255] 208 | return img 209 | 210 | 211 | if __name__ == '__main__': 212 | import os 213 | 214 | from facexlib.detection import init_detection_model 215 | from facexlib.utils.face_restoration_helper import get_largest_face 216 | from facexlib.visualization import visualize_detection 217 | 218 | img_path = '/home/wxt/datasets/ffhq/ffhq_wild/00009.png' 219 | img_name = os.splitext(os.path.basename(img_path))[0] 220 | 221 | # initialize model 
222 | det_net = init_detection_model('retinaface_resnet50', half=False) 223 | img_ori = cv2.imread(img_path) 224 | h, w = img_ori.shape[0:2] 225 | # if larger than 800, scale it 226 | scale = max(h / 800, w / 800) 227 | if scale > 1: 228 | img = cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR) 229 | 230 | with torch.no_grad(): 231 | bboxes = det_net.detect_faces(img, 0.97) 232 | if scale > 1: 233 | bboxes *= scale # the score is incorrect 234 | bboxes = get_largest_face(bboxes, h, w)[0] 235 | visualize_detection(img_ori, [bboxes], f'tmp/{img_name}_det.png') 236 | 237 | landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i in range(5, 15, 2)]) 238 | 239 | cropped_face, inverse_affine = align_crop_face_landmarks( 240 | img_ori, 241 | landmarks, 242 | output_size=512, 243 | transform_size=None, 244 | enable_padding=True, 245 | return_inverse_affine=True, 246 | shrink_ratio=(1, 1)) 247 | 248 | cv2.imwrite(f'tmp/{img_name}_cropeed_face.png', cropped_face) 249 | img = paste_face_back(img_ori, cropped_face, inverse_affine) 250 | cv2.imwrite(f'tmp/{img_name}_back.png', img) 251 | -------------------------------------------------------------------------------- /facexlib/utils/misc.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import os.path as osp 4 | import torch 5 | from torch.hub import download_url_to_file, get_dir 6 | from urllib.parse import urlparse 7 | 8 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 9 | 10 | 11 | def imwrite(img, file_path, params=None, auto_mkdir=True): 12 | """Write image to file. 13 | 14 | Args: 15 | img (ndarray): Image array to be written. 16 | file_path (str): Image file path. 17 | params (None or list): Same as opencv's :func:`imwrite` interface. 18 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 19 | whether to create it automatically. 20 | 21 | Returns: 22 | bool: Successful or not. 23 | """ 24 | if auto_mkdir: 25 | dir_name = os.path.abspath(os.path.dirname(file_path)) 26 | os.makedirs(dir_name, exist_ok=True) 27 | return cv2.imwrite(file_path, img, params) 28 | 29 | 30 | def img2tensor(imgs, bgr2rgb=True, float32=True): 31 | """Numpy array to tensor. 32 | 33 | Args: 34 | imgs (list[ndarray] | ndarray): Input images. 35 | bgr2rgb (bool): Whether to change bgr to rgb. 36 | float32 (bool): Whether to change to float32. 37 | 38 | Returns: 39 | list[tensor] | tensor: Tensor images. If returned results only have 40 | one element, just return tensor. 
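
    Example (illustrative):
        >>> import numpy as np
        >>> img = (np.random.rand(64, 64, 3) * 255).astype('float32')  # HWC, BGR
        >>> img2tensor(img).shape  # CHW after the transpose below
        torch.Size([3, 64, 64])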
41 | """ 42 | 43 | def _totensor(img, bgr2rgb, float32): 44 | if img.shape[2] == 3 and bgr2rgb: 45 | if img.dtype == 'float64': 46 | img = img.astype('float32') 47 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 48 | img = torch.from_numpy(img.transpose(2, 0, 1)) 49 | if float32: 50 | img = img.float() 51 | return img 52 | 53 | if isinstance(imgs, list): 54 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 55 | else: 56 | return _totensor(imgs, bgr2rgb, float32) 57 | 58 | 59 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None, save_dir=None): 60 | """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 61 | """ 62 | if model_dir is None: 63 | hub_dir = get_dir() 64 | model_dir = os.path.join(hub_dir, 'checkpoints') 65 | 66 | if save_dir is None: 67 | save_dir = os.path.join(ROOT_DIR, model_dir) 68 | os.makedirs(save_dir, exist_ok=True) 69 | 70 | parts = urlparse(url) 71 | filename = os.path.basename(parts.path) 72 | if file_name is not None: 73 | filename = file_name 74 | cached_file = os.path.abspath(os.path.join(save_dir, filename)) 75 | if not os.path.exists(cached_file): 76 | print(f'Downloading: "{url}" to {cached_file}\n') 77 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 78 | return cached_file 79 | 80 | 81 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 82 | """Scan a directory to find the interested files. 83 | Args: 84 | dir_path (str): Path of the directory. 85 | suffix (str | tuple(str), optional): File suffix that we are 86 | interested in. Default: None. 87 | recursive (bool, optional): If set to True, recursively scan the 88 | directory. Default: False. 89 | full_path (bool, optional): If set to True, include the dir_path. 90 | Default: False. 91 | Returns: 92 | A generator for all the interested files with relative paths. 
93 | """ 94 | 95 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 96 | raise TypeError('"suffix" must be a string or tuple of strings') 97 | 98 | root = dir_path 99 | 100 | def _scandir(dir_path, suffix, recursive): 101 | for entry in os.scandir(dir_path): 102 | if not entry.name.startswith('.') and entry.is_file(): 103 | if full_path: 104 | return_path = entry.path 105 | else: 106 | return_path = osp.relpath(entry.path, root) 107 | 108 | if suffix is None: 109 | yield return_path 110 | elif return_path.endswith(suffix): 111 | yield return_path 112 | else: 113 | if recursive: 114 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 115 | else: 116 | continue 117 | 118 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 119 | -------------------------------------------------------------------------------- /facexlib/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | from .vis_alignment import visualize_alignment 2 | from .vis_detection import visualize_detection 3 | from .vis_headpose import visualize_headpose 4 | 5 | __all__ = ['visualize_detection', 'visualize_alignment', 'visualize_headpose'] 6 | -------------------------------------------------------------------------------- /facexlib/visualization/vis_alignment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def visualize_alignment(img, landmarks, save_path=None, to_bgr=False): 6 | img = np.copy(img) 7 | h, w = img.shape[0:2] 8 | circle_size = int(max(h, w) / 150) 9 | if to_bgr: 10 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 11 | 12 | for landmarks_face in landmarks: 13 | for lm in landmarks_face: 14 | cv2.circle(img, (int(lm[0]), int(lm[1])), 1, (0, 150, 0), circle_size) 15 | 16 | # save img 17 | if save_path is not None: 18 | cv2.imwrite(save_path, img) 19 | -------------------------------------------------------------------------------- /facexlib/visualization/vis_detection.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | def visualize_detection(img, bboxes_and_landmarks, save_path=None, to_bgr=False): 6 | """Visualize detection results. 7 | 8 | Args: 9 | img (Numpy array): Input image. CHW, BGR, [0, 255], uint8. 
10 | """ 11 | img = np.copy(img) 12 | if to_bgr: 13 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 14 | 15 | for b in bboxes_and_landmarks: 16 | # confidence 17 | cv2.putText(img, f'{b[4]:.4f}', (int(b[0]), int(b[1] + 12)), cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255)) 18 | # bounding boxes 19 | b = list(map(int, b)) 20 | cv2.rectangle(img, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2) 21 | # landmarks (for retinaface) 22 | cv2.circle(img, (b[5], b[6]), 1, (0, 0, 255), 4) 23 | cv2.circle(img, (b[7], b[8]), 1, (0, 255, 255), 4) 24 | cv2.circle(img, (b[9], b[10]), 1, (255, 0, 255), 4) 25 | cv2.circle(img, (b[11], b[12]), 1, (0, 255, 0), 4) 26 | cv2.circle(img, (b[13], b[14]), 1, (255, 0, 0), 4) 27 | # save img 28 | if save_path is not None: 29 | cv2.imwrite(save_path, img) 30 | -------------------------------------------------------------------------------- /facexlib/visualization/vis_headpose.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from math import cos, sin 4 | 5 | 6 | def draw_axis(img, yaw, pitch, roll, tdx=None, tdy=None, size=100): 7 | """draw head pose axis.""" 8 | 9 | pitch = pitch * np.pi / 180 10 | yaw = -yaw * np.pi / 180 11 | roll = roll * np.pi / 180 12 | 13 | if tdx is None or tdy is None: 14 | height, width = img.shape[:2] 15 | tdx = width / 2 16 | tdy = height / 2 17 | 18 | # X axis pointing to right, drawn in red 19 | x1 = size * (cos(yaw) * cos(roll)) + tdx 20 | y1 = size * (cos(pitch) * sin(roll) + cos(roll) * sin(pitch) * sin(yaw)) + tdy 21 | # Y axis pointing downside, drawn in green 22 | x2 = size * (-cos(yaw) * sin(roll)) + tdx 23 | y2 = size * (cos(pitch) * cos(roll) - sin(pitch) * sin(yaw) * sin(roll)) + tdy 24 | # Z axis, out of the screen, drawn in blue 25 | x3 = size * (sin(yaw)) + tdx 26 | y3 = size * (-cos(yaw) * sin(pitch)) + tdy 27 | 28 | cv2.line(img, (int(tdx), int(tdy)), (int(x1), int(y1)), (0, 0, 255), 3) 29 | cv2.line(img, (int(tdx), int(tdy)), (int(x2), int(y2)), (0, 255, 0), 3) 30 | cv2.line(img, (int(tdx), int(tdy)), (int(x3), int(y3)), (255, 0, 0), 2) 31 | 32 | return img 33 | 34 | 35 | def draw_pose_cube(img, yaw, pitch, roll, tdx=None, tdy=None, size=150.): 36 | """draw head pose cube. 37 | Where (tdx, tdy) is the translation of the face. 
38 | For pose we have [pitch yaw roll tdx tdy tdz scale_factor] 39 | """ 40 | 41 | p = pitch * np.pi / 180 42 | y = -yaw * np.pi / 180 43 | r = roll * np.pi / 180 44 | if tdx is not None and tdy is not None: 45 | face_x = tdx - 0.50 * size 46 | face_y = tdy - 0.50 * size 47 | else: 48 | height, width = img.shape[:2] 49 | face_x = width / 2 - 0.5 * size 50 | face_y = height / 2 - 0.5 * size 51 | 52 | x1 = size * (cos(y) * cos(r)) + face_x 53 | y1 = size * (cos(p) * sin(r) + cos(r) * sin(p) * sin(y)) + face_y 54 | x2 = size * (-cos(y) * sin(r)) + face_x 55 | y2 = size * (cos(p) * cos(r) - sin(p) * sin(y) * sin(r)) + face_y 56 | x3 = size * (sin(y)) + face_x 57 | y3 = size * (-cos(y) * sin(p)) + face_y 58 | 59 | # Draw base in red 60 | cv2.line(img, (int(face_x), int(face_y)), (int(x1), int(y1)), (0, 0, 255), 3) 61 | cv2.line(img, (int(face_x), int(face_y)), (int(x2), int(y2)), (0, 0, 255), 3) 62 | cv2.line(img, (int(x2), int(y2)), (int(x2 + x1 - face_x), int(y2 + y1 - face_y)), (0, 0, 255), 3) 63 | cv2.line(img, (int(x1), int(y1)), (int(x1 + x2 - face_x), int(y1 + y2 - face_y)), (0, 0, 255), 3) 64 | # Draw pillars in blue 65 | cv2.line(img, (int(face_x), int(face_y)), (int(x3), int(y3)), (255, 0, 0), 2) 66 | cv2.line(img, (int(x1), int(y1)), (int(x1 + x3 - face_x), int(y1 + y3 - face_y)), (255, 0, 0), 2) 67 | cv2.line(img, (int(x2), int(y2)), (int(x2 + x3 - face_x), int(y2 + y3 - face_y)), (255, 0, 0), 2) 68 | cv2.line(img, (int(x2 + x1 - face_x), int(y2 + y1 - face_y)), 69 | (int(x3 + x1 + x2 - 2 * face_x), int(y3 + y2 + y1 - 2 * face_y)), (255, 0, 0), 2) 70 | # Draw top in green 71 | cv2.line(img, (int(x3 + x1 - face_x), int(y3 + y1 - face_y)), 72 | (int(x3 + x1 + x2 - 2 * face_x), int(y3 + y2 + y1 - 2 * face_y)), (0, 255, 0), 2) 73 | cv2.line(img, (int(x2 + x3 - face_x), int(y2 + y3 - face_y)), 74 | (int(x3 + x1 + x2 - 2 * face_x), int(y3 + y2 + y1 - 2 * face_y)), (0, 255, 0), 2) 75 | cv2.line(img, (int(x3), int(y3)), (int(x3 + x1 - face_x), int(y3 + y1 - face_y)), (0, 255, 0), 2) 76 | cv2.line(img, (int(x3), int(y3)), (int(x3 + x2 - face_x), int(y3 + y2 - face_y)), (0, 255, 0), 2) 77 | 78 | return img 79 | 80 | 81 | def visualize_headpose(img, yaw, pitch, roll, save_path=None, to_bgr=False): 82 | img = np.copy(img) 83 | if to_bgr: 84 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 85 | show_string = (f'y {yaw[0].item():.2f}, p {pitch[0].item():.2f}, ' + f'r {roll[0].item():.2f}') 86 | cv2.putText(img, show_string, (30, img.shape[0] - 30), fontFace=1, fontScale=1, color=(0, 0, 255), thickness=2) 87 | draw_pose_cube(img, yaw[0], pitch[0], roll[0], size=100) 88 | draw_axis(img, yaw[0], pitch[0], roll[0], tdx=50, tdy=50, size=100) 89 | # save img 90 | if save_path is not None: 91 | cv2.imwrite(save_path, img) 92 | -------------------------------------------------------------------------------- /facexlib/weights/README.md: -------------------------------------------------------------------------------- 1 | # Weights 2 | 3 | Put the downloaded weights to this folder. 
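
The `init_*_model` helpers shown in this repository (for example the parsing and recognition ones) call `load_file_from_url(..., model_dir='facexlib/weights')`, so released checkpoints such as `parsing_bisenet.pth`, `parsing_parsenet.pth` and `recognition_arcface_ir_se50.pth` are downloaded into this folder automatically on first use, unless a custom `model_rootpath` is passed. A minimal sketch (assuming a CUDA device, the default):

```python
from facexlib.parsing import init_parsing_model

# Downloads parsing_bisenet.pth into facexlib/weights on first use.
net = init_parsing_model(model_name='bisenet')
```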
4 | -------------------------------------------------------------------------------- /inference/inference_alignment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import torch 4 | 5 | from facexlib.alignment import init_alignment_model, landmark_98_to_68 6 | from facexlib.visualization import visualize_alignment 7 | 8 | 9 | def main(args): 10 | # initialize model 11 | align_net = init_alignment_model(args.model_name, device=args.device) 12 | 13 | img = cv2.imread(args.img_path) 14 | with torch.no_grad(): 15 | landmarks = align_net.get_landmarks(img) 16 | if args.to68: 17 | landmarks = landmark_98_to_68(landmarks) 18 | visualize_alignment(img, [landmarks], args.save_path) 19 | 20 | 21 | if __name__ == '__main__': 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--img_path', type=str, default='assets/test2.jpg') 24 | parser.add_argument('--save_path', type=str, default='test_alignment.png') 25 | parser.add_argument('--model_name', type=str, default='awing_fan') 26 | parser.add_argument('--device', type=str, default='cuda') 27 | parser.add_argument('--to68', action='store_true') 28 | args = parser.parse_args() 29 | 30 | main(args) 31 | -------------------------------------------------------------------------------- /inference/inference_crop_standard_faces.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | 4 | from facexlib.detection import init_detection_model 5 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper 6 | 7 | input_img = '/home/wxt/datasets/ffhq/ffhq_wild/00028.png' 8 | # initialize face helper 9 | face_helper = FaceRestoreHelper( 10 | upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png') 11 | 12 | face_helper.clean_all() 13 | 14 | det_net = init_detection_model('retinaface_resnet50', half=False) 15 | img = cv2.imread(input_img) 16 | with torch.no_grad(): 17 | bboxes = det_net.detect_faces(img, 0.97) 18 | # x0, y0, x1, y1, confidence_score, five points (x, y) 19 | print(bboxes.shape) 20 | bboxes = bboxes[3] 21 | 22 | bboxes[0] -= 100 23 | bboxes[1] -= 100 24 | bboxes[2] += 100 25 | bboxes[3] += 100 26 | img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :] 27 | 28 | face_helper.read_image(img) 29 | # get face landmarks for each face 30 | face_helper.get_face_landmarks_5(only_center_face=True, pad_blur=False) 31 | # align and warp each face 32 | # save_crop_path = os.path.join(save_root, 'cropped_faces', img_name) 33 | save_crop_path = '00028_cvwarp.png' 34 | face_helper.align_warp_face(save_crop_path) 35 | 36 | # for i in range(50): 37 | # img = cv2.imread(f'inputs/ffhq_512/{i:08d}.png') 38 | # cv2.circle(img, (193, 240), 1, (0, 0, 255), 4) 39 | # cv2.circle(img, (319, 240), 1, (0, 255, 255), 4) 40 | # cv2.circle(img, (257, 314), 1, (255, 0, 255), 4) 41 | # cv2.circle(img, (201, 371), 1, (0, 255, 0), 4) 42 | # cv2.circle(img, (313, 371), 1, (255, 0, 0), 4) 43 | 44 | # cv2.imwrite(f'ffhq_lm/{i:08d}_lm.png', img) 45 | 46 | # [875.5 719.83333333] [1192.5 715.66666667] [1060. 997.] 
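
# Note: the 100-pixel margin added to `bboxes` above is applied without clamping;
# with NumPy slicing, a negative start index counts from the opposite border, so
# faces near the image edge could produce a wrong crop. An illustrative guard,
# to be placed before the crop (kept commented out here):
# bboxes[0] = max(bboxes[0], 0)
# bboxes[1] = max(bboxes[1], 0)
# bboxes[2] = min(bboxes[2], img.shape[1])
# bboxes[3] = min(bboxes[3], img.shape[0])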
47 | -------------------------------------------------------------------------------- /inference/inference_detection.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import torch 4 | 5 | from facexlib.detection import init_detection_model 6 | from facexlib.visualization import visualize_detection 7 | 8 | 9 | def main(args): 10 | # initialize model 11 | det_net = init_detection_model(args.model_name, half=args.half) 12 | 13 | img = cv2.imread(args.img_path) 14 | with torch.no_grad(): 15 | bboxes = det_net.detect_faces(img, 0.97) 16 | # x0, y0, x1, y1, confidence_score, five points (x, y) 17 | print(bboxes) 18 | visualize_detection(img, bboxes, args.save_path) 19 | 20 | 21 | if __name__ == '__main__': 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--img_path', type=str, default='assets/test.jpg') 24 | parser.add_argument('--save_path', type=str, default='test_detection.png') 25 | parser.add_argument( 26 | '--model_name', type=str, default='retinaface_resnet50', help='retinaface_resnet50 | retinaface_mobile0.25') 27 | parser.add_argument('--half', action='store_true') 28 | args = parser.parse_args() 29 | 30 | main(args) 31 | -------------------------------------------------------------------------------- /inference/inference_headpose.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import numpy as np 4 | import torch 5 | from torchvision.transforms.functional import normalize 6 | 7 | from facexlib.detection import init_detection_model 8 | from facexlib.headpose import init_headpose_model 9 | from facexlib.utils.misc import img2tensor 10 | from facexlib.visualization import visualize_headpose 11 | 12 | 13 | def main(args): 14 | # initialize model 15 | det_net = init_detection_model(args.detection_model_name, half=args.half) 16 | headpose_net = init_headpose_model(args.headpose_model_name, half=args.half) 17 | 18 | img = cv2.imread(args.img_path) 19 | with torch.no_grad(): 20 | bboxes = det_net.detect_faces(img, 0.97) 21 | # x0, y0, x1, y1, confidence_score, five points (x, y) 22 | bbox = list(map(int, bboxes[0])) 23 | # crop face region 24 | thld = 10 25 | h, w, _ = img.shape 26 | top = max(bbox[1] - thld, 0) 27 | bottom = min(bbox[3] + thld, h) 28 | left = max(bbox[0] - thld, 0) 29 | right = min(bbox[2] + thld, w) 30 | 31 | det_face = img[top:bottom, left:right, :].astype(np.float32) / 255. 
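        # The bbox has been padded by `thld` pixels and clamped to the image
        # bounds above; the crop is scaled to [0, 1] here, then resized to
        # 224x224 and normalized with the ImageNet mean/std below before being
        # fed to the head pose network.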
32 | 33 | # resize 34 | det_face = cv2.resize(det_face, (224, 224), interpolation=cv2.INTER_LINEAR) 35 | det_face = img2tensor(np.copy(det_face), bgr2rgb=False) 36 | 37 | # normalize 38 | normalize(det_face, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225], inplace=True) 39 | det_face = det_face.unsqueeze(0).cuda() 40 | 41 | yaw, pitch, roll = headpose_net(det_face) 42 | visualize_headpose(img, yaw, pitch, roll, args.save_path) 43 | 44 | 45 | if __name__ == '__main__': 46 | parser = argparse.ArgumentParser(description='Head pose estimation using the Hopenet network.') 47 | parser.add_argument('--img_path', type=str, default='assets/test.jpg') 48 | parser.add_argument('--save_path', type=str, default='assets/test_headpose.png') 49 | parser.add_argument('--detection_model_name', type=str, default='retinaface_resnet50') 50 | parser.add_argument('--headpose_model_name', type=str, default='hopenet') 51 | parser.add_argument('--half', action='store_true') 52 | args = parser.parse_args() 53 | 54 | main(args) 55 | -------------------------------------------------------------------------------- /inference/inference_hyperiqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import numpy as np 4 | import os 5 | import torch 6 | import torchvision 7 | from PIL import Image 8 | 9 | from facexlib.assessment import init_assessment_model 10 | from facexlib.detection import init_detection_model 11 | 12 | 13 | def main(args): 14 | """Scripts about evaluating face quality. 15 | Two steps: 16 | 1) detect the face region and crop the face 17 | 2) evaluate the face quality by hyperIQA 18 | """ 19 | # initialize model 20 | det_net = init_detection_model(args.detection_model_name, half=False) 21 | assess_net = init_assessment_model(args.assess_model_name, half=False) 22 | 23 | # specified face transformation in original hyperIQA 24 | transforms = torchvision.transforms.Compose([ 25 | torchvision.transforms.Resize((512, 384)), 26 | torchvision.transforms.RandomCrop(size=224), 27 | torchvision.transforms.ToTensor(), 28 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 29 | ]) 30 | 31 | img = cv2.imread(args.img_path) 32 | img_name = os.path.basename(args.img_path) 33 | basename, _ = os.path.splitext(img_name) 34 | with torch.no_grad(): 35 | bboxes = det_net.detect_faces(img, 0.97) 36 | box = list(map(int, bboxes[0])) 37 | pred_scores = [] 38 | # BRG -> RGB 39 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 40 | 41 | for i in range(10): 42 | detect_face = img[box[1]:box[3], box[0]:box[2], :] 43 | detect_face = Image.fromarray(detect_face) 44 | 45 | detect_face = transforms(detect_face) 46 | detect_face = torch.tensor(detect_face.cuda()).unsqueeze(0) 47 | 48 | pred = assess_net(detect_face) 49 | pred_scores.append(float(pred.item())) 50 | score = np.mean(pred_scores) 51 | # quality score ranges from 0-100, a higher score indicates a better quality 52 | print(f'{basename} {score:.4f}') 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument('--img_path', type=str, default='assets/test2.jpg') 58 | parser.add_argument('--detection_model_name', type=str, default='retinaface_resnet50') 59 | parser.add_argument('--assess_model_name', type=str, default='hypernet') 60 | parser.add_argument('--half', action='store_true') 61 | args = parser.parse_args() 62 | 63 | main(args) 64 | -------------------------------------------------------------------------------- 
/inference/inference_matting.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from torchvision.transforms.functional import normalize 6 | 7 | from facexlib.matting import init_matting_model 8 | from facexlib.utils import img2tensor 9 | 10 | 11 | def main(args): 12 | modnet = init_matting_model() 13 | 14 | # read image 15 | img = cv2.imread(args.img_path) / 255. 16 | # unify image channels to 3 17 | if len(img.shape) == 2: 18 | img = img[:, :, None] 19 | if img.shape[2] == 1: 20 | img = np.repeat(img, 3, axis=2) 21 | elif img.shape[2] == 4: 22 | img = img[:, :, 0:3] 23 | 24 | img_t = img2tensor(img, bgr2rgb=True, float32=True) 25 | normalize(img_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) 26 | img_t = img_t.unsqueeze(0).cuda() 27 | 28 | # resize image for input 29 | _, _, im_h, im_w = img_t.shape 30 | ref_size = 512 31 | if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size: 32 | if im_w >= im_h: 33 | im_rh = ref_size 34 | im_rw = int(im_w / im_h * ref_size) 35 | elif im_w < im_h: 36 | im_rw = ref_size 37 | im_rh = int(im_h / im_w * ref_size) 38 | else: 39 | im_rh = im_h 40 | im_rw = im_w 41 | im_rw = im_rw - im_rw % 32 42 | im_rh = im_rh - im_rh % 32 43 | img_t = F.interpolate(img_t, size=(im_rh, im_rw), mode='area') 44 | 45 | # inference 46 | _, _, matte = modnet(img_t, True) 47 | 48 | # resize and save matte 49 | matte = F.interpolate(matte, size=(im_h, im_w), mode='area') 50 | matte = matte[0][0].data.cpu().numpy() 51 | cv2.imwrite(args.save_path, (matte * 255).astype('uint8')) 52 | 53 | # get foreground 54 | matte = matte[:, :, None] 55 | foreground = img * matte + np.full(img.shape, 1) * (1 - matte) 56 | cv2.imwrite(args.save_path.replace('.png', '_fg.png'), foreground * 255) 57 | 58 | 59 | if __name__ == '__main__': 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument('--img_path', type=str, default='assets/test.jpg') 62 | parser.add_argument('--save_path', type=str, default='test_matting.png') 63 | args = parser.parse_args() 64 | 65 | main(args) 66 | -------------------------------------------------------------------------------- /inference/inference_parsing.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import numpy as np 4 | import os 5 | import torch 6 | from torchvision.transforms.functional import normalize 7 | 8 | from facexlib.parsing import init_parsing_model 9 | from facexlib.utils.misc import img2tensor 10 | 11 | 12 | def vis_parsing_maps(img, parsing_anno, stride, save_anno_path=None, save_vis_path=None): 13 | # Colors for all face parts 14 | part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 0, 85], [255, 0, 170], [0, 255, 0], [85, 255, 0], 15 | [170, 255, 0], [0, 255, 85], [0, 255, 170], [0, 0, 255], [85, 0, 255], [170, 0, 255], [0, 85, 255], 16 | [0, 170, 255], [255, 255, 0], [255, 255, 85], [255, 255, 170], [255, 0, 255], [255, 85, 255], 17 | [255, 170, 255], [0, 255, 255], [85, 255, 255], [170, 255, 255]] 18 | # 0: 'background' 19 | # attributes = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', 20 | # 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', 10 'nose', 21 | # 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 22 | # 16 'cloth', 17 'hair', 18 'hat'] 23 | vis_parsing_anno = parsing_anno.copy().astype(np.uint8) 24 | vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride,
interpolation=cv2.INTER_NEAREST) 25 | if save_anno_path is not None: 26 | cv2.imwrite(save_anno_path, vis_parsing_anno) 27 | 28 | if save_vis_path is not None: 29 | vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255 30 | num_of_class = np.max(vis_parsing_anno) 31 | for pi in range(1, num_of_class + 1): 32 | index = np.where(vis_parsing_anno == pi) 33 | vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi] 34 | 35 | vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) 36 | vis_im = cv2.addWeighted(img, 0.4, vis_parsing_anno_color, 0.6, 0) 37 | 38 | cv2.imwrite(save_vis_path, vis_im) 39 | 40 | 41 | def main(img_path, output): 42 | net = init_parsing_model(model_name='bisenet') 43 | 44 | img_name = os.path.basename(img_path) 45 | img_basename = os.path.splitext(img_name)[0] 46 | 47 | img_input = cv2.imread(img_path) 48 | img_input = cv2.resize(img_input, (512, 512), interpolation=cv2.INTER_LINEAR) 49 | img = img2tensor(img_input.astype('float32') / 255., bgr2rgb=True, float32=True) 50 | normalize(img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True) 51 | img = torch.unsqueeze(img, 0).cuda() 52 | 53 | with torch.no_grad(): 54 | out = net(img)[0] 55 | out = out.squeeze(0).cpu().numpy().argmax(0) 56 | 57 | vis_parsing_maps( 58 | img_input, 59 | out, 60 | stride=1, 61 | save_anno_path=os.path.join(output, f'{img_basename}.png'), 62 | save_vis_path=os.path.join(output, f'{img_basename}_vis.png')) 63 | 64 | 65 | if __name__ == '__main__': 66 | parser = argparse.ArgumentParser() 67 | 68 | parser.add_argument('--input', type=str, default='datasets/ffhq/ffhq_512/00000000.png') 69 | parser.add_argument('--output', type=str, default='results', help='output folder') 70 | args = parser.parse_args() 71 | 72 | os.makedirs(args.output, exist_ok=True) 73 | 74 | main(args.input, args.output) 75 | -------------------------------------------------------------------------------- /inference/inference_parsing_parsenet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import numpy as np 4 | import os 5 | import torch 6 | from torchvision.transforms.functional import normalize 7 | 8 | from facexlib.parsing import init_parsing_model 9 | from facexlib.utils.misc import img2tensor 10 | 11 | 12 | def vis_parsing_maps(img, parsing_anno, stride, save_anno_path=None, save_vis_path=None): 13 | # Colors for all parts 14 | part_colors = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], 15 | [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], 16 | [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]] 17 | # 0: 'background' 1: 'skin' 2: 'nose' 18 | # 3: 'eye_g' 4: 'l_eye' 5: 'r_eye' 19 | # 6: 'l_brow' 7: 'r_brow' 8: 'l_ear' 20 | # 9: 'r_ear' 10: 'mouth' 11: 'u_lip' 21 | # 12: 'l_lip' 13: 'hair' 14: 'hat' 22 | # 15: 'ear_r' 16: 'neck_l' 17: 'neck' 23 | # 18: 'cloth' 24 | vis_parsing_anno = parsing_anno.copy().astype(np.uint8) 25 | vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST) 26 | if save_anno_path is not None: 27 | cv2.imwrite(save_anno_path, vis_parsing_anno) 28 | 29 | if save_vis_path is not None: 30 | vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255 31 | num_of_class = np.max(vis_parsing_anno) 32 | for pi in range(1, num_of_class + 1): 33 | index = 
np.where(vis_parsing_anno == pi) 34 | vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi] 35 | 36 | vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8) 37 | vis_im = cv2.addWeighted(img, 0.4, vis_parsing_anno_color, 0.6, 0) 38 | 39 | cv2.imwrite(save_vis_path, vis_im) 40 | 41 | 42 | def main(img_path, output): 43 | net = init_parsing_model(model_name='parsenet') 44 | 45 | img_name = os.path.basename(img_path) 46 | img_basename = os.path.splitext(img_name)[0] 47 | 48 | img_input = cv2.imread(img_path) 49 | # resize to 512 x 512 for better performance 50 | img_input = cv2.resize(img_input, (512, 512), interpolation=cv2.INTER_LINEAR) 51 | img = img2tensor(img_input.astype('float32') / 255., bgr2rgb=True, float32=True) 52 | normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) 53 | img = torch.unsqueeze(img, 0).cuda() 54 | 55 | with torch.no_grad(): 56 | out = net(img)[0] 57 | out = out.squeeze(0).cpu().numpy().argmax(0) 58 | 59 | vis_parsing_maps( 60 | img_input, 61 | out, 62 | stride=1, 63 | save_anno_path=os.path.join(output, f'{img_basename}.png'), 64 | save_vis_path=os.path.join(output, f'{img_basename}_vis.png')) 65 | 66 | 67 | if __name__ == '__main__': 68 | parser = argparse.ArgumentParser() 69 | 70 | parser.add_argument('--input', type=str, default='datasets/ffhq/ffhq_512/00000000.png') 71 | parser.add_argument('--output', type=str, default='results', help='output folder') 72 | args = parser.parse_args() 73 | 74 | os.makedirs(args.output, exist_ok=True) 75 | 76 | main(args.input, args.output) 77 | -------------------------------------------------------------------------------- /inference/inference_recognition.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import math 4 | import numpy as np 5 | import os 6 | import torch 7 | 8 | from facexlib.recognition import ResNetArcFace, cosin_metric, load_image 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--folder1', type=str) 13 | parser.add_argument('--folder2', type=str) 14 | parser.add_argument('--model_path', type=str, default='facexlib/recognition/weights/arcface_resnet18.pth') 15 | 16 | args = parser.parse_args() 17 | 18 | img_list1 = sorted(glob.glob(os.path.join(args.folder1, '*'))) 19 | img_list2 = sorted(glob.glob(os.path.join(args.folder2, '*'))) 20 | print(img_list1, img_list2) 21 | model = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False) 22 | model.load_state_dict(torch.load(args.model_path)) 23 | model.to(torch.device('cuda')) 24 | model.eval() 25 | 26 | dist_list = [] 27 | identical_count = 0 28 | for idx, (img_path1, img_path2) in enumerate(zip(img_list1, img_list2)): 29 | basename = os.path.splitext(os.path.basename(img_path1))[0] 30 | img1 = load_image(img_path1) 31 | img2 = load_image(img_path2) 32 | 33 | data = torch.stack([img1, img2], dim=0) 34 | data = data.to(torch.device('cuda')) 35 | output = model(data) 36 | print(output.size()) 37 | output = output.data.cpu().numpy() 38 | dist = cosin_metric(output[0], output[1]) 39 | dist = np.arccos(dist) / math.pi * 180 40 | print(f'{idx} - {dist} o : {basename}') 41 | if dist < 1: 42 | print(f'{basename} is almost identical to original.') 43 | identical_count += 1 44 | else: 45 | dist_list.append(dist) 46 | 47 | print(f'Result dist: {sum(dist_list) / len(dist_list):.6f}') 48 | print(f'identical count: {identical_count}') 49 | 
-------------------------------------------------------------------------------- /inference/inference_tracking.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import glob 4 | import numpy as np 5 | import os 6 | import torch 7 | from tqdm import tqdm 8 | 9 | from facexlib.detection import init_detection_model 10 | from facexlib.tracking.sort import SORT 11 | 12 | 13 | def main(args): 14 | detect_interval = args.detect_interval 15 | margin = args.margin 16 | face_score_threshold = args.face_score_threshold 17 | 18 | save_frame = True 19 | if save_frame: 20 | colors = np.random.rand(32, 3) 21 | 22 | # init detection model and tracker 23 | det_net = init_detection_model('retinaface_resnet50', half=False) 24 | tracker = SORT(max_age=1, min_hits=2, iou_threshold=0.2) 25 | print('Start tracking...') 26 | 27 | # track over all frames 28 | frame_paths = sorted(glob.glob(os.path.join(args.input_folder, '*.jpg'))) 29 | pbar = tqdm(total=len(frame_paths), unit='frames', desc='Extract') 30 | for idx, path in enumerate(frame_paths): 31 | img_basename = os.path.basename(path) 32 | frame = cv2.imread(path) 33 | img_size = frame.shape[0:2] 34 | 35 | # detect face bboxes 36 | with torch.no_grad(): 37 | bboxes = det_net.detect_faces(frame, 0.97) 38 | 39 | additional_attr = [] 40 | face_list = [] 41 | 42 | for idx_bb, bbox in enumerate(bboxes): 43 | score = bbox[4] 44 | if score > face_score_threshold: 45 | bbox = bbox[0:5] 46 | det = bbox[0:4] 47 | 48 | # face rectangle 49 | det[0] = np.maximum(det[0] - margin, 0) 50 | det[1] = np.maximum(det[1] - margin, 0) 51 | det[2] = np.minimum(det[2] + margin, img_size[1]) 52 | det[3] = np.minimum(det[3] + margin, img_size[0]) 53 | face_list.append(bbox) 54 | additional_attr.append([score]) 55 | trackers = tracker.update(np.array(face_list), img_size, additional_attr, detect_interval) 56 | 57 | pbar.update(1) 58 | pbar.set_description(f'{idx}: detect {len(bboxes)} faces in {img_basename}') 59 | 60 | # save frame 61 | if save_frame: 62 | for d in trackers: 63 | d = d.astype(np.int32) 64 | cv2.rectangle(frame, (d[0], d[1]), (d[2], d[3]), colors[d[4] % 32, :] * 255, 3) 65 | if len(face_list) != 0: 66 | cv2.putText(frame, 'ID : %d DETECT' % (d[4]), (d[0] - 10, d[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 67 | 0.75, colors[d[4] % 32, :] * 255, 2) 68 | cv2.putText(frame, 'DETECTOR', (5, 45), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (1, 1, 1), 2) 69 | else: 70 | cv2.putText(frame, 'ID : %d' % (d[4]), (d[0] - 10, d[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.75, 71 | colors[d[4] % 32, :] * 255, 2) 72 | save_path = os.path.join(args.save_folder, img_basename) 73 | cv2.imwrite(save_path, frame) 74 | 75 | 76 | if __name__ == '__main__': 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument('--input_folder', help='Path to the input folder', type=str) 79 | parser.add_argument('--save_folder', help='Path to save visualized frames', type=str, default=None) 80 | 81 | parser.add_argument( 82 | '--detect_interval', 83 | help=('run face detection every N frames, a trade-off ' 84 | 'between performance and smoothness'), 85 | type=int, 86 | default=1) 87 | # if the faces are large in your video, you can set a larger margin for easier tracking 88 | parser.add_argument('--margin', help='add margin for face', type=int, default=20) 89 | parser.add_argument( 90 | '--face_score_threshold', help='Score threshold for extracted faces, range 0 < x <= 1', type=float, default=0.85) 91 | 92 | args = parser.parse_args() 93 | os.makedirs(args.save_folder,
exist_ok=True) 94 | main(args) 95 | 96 | # add verification 97 | # remove last few frames 98 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | filterpy 2 | numba 3 | numpy 4 | opencv-python 5 | Pillow 6 | scipy 7 | torch 8 | torchvision 9 | tqdm 10 | -------------------------------------------------------------------------------- /scripts/crop_faces_5landmarks.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import facexlib.utils.face_restoration_helper as face_restoration_helper 5 | 6 | 7 | def crop_one_img(img, save_cropped_path=None): 8 | FaceRestoreHelper.clean_all() 9 | FaceRestoreHelper.read_image(img) 10 | # get face landmarks 11 | FaceRestoreHelper.get_face_landmarks_5() 12 | FaceRestoreHelper.align_warp_face(save_cropped_path) 13 | 14 | 15 | if __name__ == '__main__': 16 | # initialize face helper 17 | FaceRestoreHelper = face_restoration_helper.FaceRestoreHelper(upscale_factor=1) 18 | 19 | img_paths = glob.glob('/home/wxt/Projects/test/*') 20 | save_path = 'test' 21 | for idx, path in enumerate(img_paths): 22 | print(idx, path) 23 | file_name = os.path.basename(path) 24 | save_cropped_path = os.path.join(save_path, file_name) 25 | crop_one_img(path, save_cropped_path=save_cropped_path) 26 | -------------------------------------------------------------------------------- /scripts/extract_detection_info_ffhq.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import glob 3 | import numpy as np 4 | import os 5 | import torch 6 | from PIL import Image 7 | from tqdm import tqdm 8 | 9 | from facexlib.detection import init_detection_model 10 | 11 | 12 | def draw_and_save(image, bboxes_and_landmarks, save_path, order_type=1): 13 | """Visualize results 14 | """ 15 | if isinstance(image, Image.Image): 16 | image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) 17 | image = image.astype(np.float32) 18 | for b in bboxes_and_landmarks: 19 | # confidence 20 | cv2.putText(image, '{:.4f}'.format(b[4]), (int(b[0]), int(b[1] + 12)), cv2.FONT_HERSHEY_DUPLEX, 0.5, 21 | (255, 255, 255)) 22 | # bounding boxes 23 | b = list(map(int, b)) 24 | cv2.rectangle(image, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2) 25 | # landmarks 26 | if order_type == 0: # mtcnn 27 | cv2.circle(image, (b[5], b[10]), 1, (0, 0, 255), 4) 28 | cv2.circle(image, (b[6], b[11]), 1, (0, 255, 255), 4) 29 | cv2.circle(image, (b[7], b[12]), 1, (255, 0, 255), 4) 30 | cv2.circle(image, (b[8], b[13]), 1, (0, 255, 0), 4) 31 | cv2.circle(image, (b[9], b[14]), 1, (255, 0, 0), 4) 32 | else: # retinaface, centerface 33 | cv2.circle(image, (b[5], b[6]), 1, (0, 0, 255), 4) 34 | cv2.circle(image, (b[7], b[8]), 1, (0, 255, 255), 4) 35 | cv2.circle(image, (b[9], b[10]), 1, (255, 0, 255), 4) 36 | cv2.circle(image, (b[11], b[12]), 1, (0, 255, 0), 4) 37 | cv2.circle(image, (b[13], b[14]), 1, (255, 0, 0), 4) 38 | # save image 39 | cv2.imwrite(save_path, image) 40 | 41 | 42 | det_net = init_detection_model('retinaface_resnet50') 43 | half = False 44 | 45 | det_net.cuda().eval() 46 | if half: 47 | det_net = det_net.half() 48 | 49 | img_list = sorted(glob.glob('../../BasicSR-private/datasets/ffhq/ffhq_512/*')) 50 | 51 | 52 | def get_center_landmark(landmarks, center): 53 | center = np.array(center) 54 | center_dist = [] 55 | for landmark in landmarks: 56 | landmark_center =
np.array([(landmark[0] + landmark[2]) / 2, (landmark[1] + landmark[3]) / 2]) 57 | dist = np.linalg.norm(landmark_center - center) 58 | center_dist.append(dist) 59 | center_idx = center_dist.index(min(center_dist)) 60 | return landmarks[center_idx] 61 | 62 | 63 | pbar = tqdm(total=len(img_list), unit='image') 64 | save_np = [] 65 | for idx, path in enumerate(img_list): 66 | img_name = os.path.basename(path) 67 | pbar.update(1) 68 | pbar.set_description(path) 69 | img = Image.open(path) 70 | with torch.no_grad(): 71 | bboxes, warped_face_list = det_net.align_multi(img, 0.97, half=half) 72 | if len(bboxes) > 1: 73 | bboxes = [get_center_landmark(bboxes, (256, 256))] 74 | save_np.append(bboxes) 75 | # draw_and_save(img, bboxes, os.path.join('tmp', img_name), 1) 76 | np.save('ffhq_det_info.npy', save_np) 77 | -------------------------------------------------------------------------------- /scripts/get_ffhq_template.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from PIL import Image 4 | 5 | bboxes = np.load('ffhq_det_info.npy', allow_pickle=True) 6 | 7 | bboxes = np.array(bboxes).squeeze(1) 8 | 9 | bboxes = np.mean(bboxes, axis=0) 10 | 11 | print(bboxes) 12 | 13 | 14 | def draw_and_save(image, bboxes_and_landmarks, save_path, order_type=1): 15 | """Visualize results 16 | """ 17 | if isinstance(image, Image.Image): 18 | image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) 19 | image = image.astype(np.float32) 20 | for b in bboxes_and_landmarks: 21 | # confidence 22 | cv2.putText(image, '{:.4f}'.format(b[4]), (int(b[0]), int(b[1] + 12)), cv2.FONT_HERSHEY_DUPLEX, 0.5, 23 | (255, 255, 255)) 24 | # bounding boxes 25 | b = list(map(int, b)) 26 | cv2.rectangle(image, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2) 27 | # landmarks 28 | if order_type == 0: # mtcnn 29 | cv2.circle(image, (b[5], b[10]), 1, (0, 0, 255), 4) 30 | cv2.circle(image, (b[6], b[11]), 1, (0, 255, 255), 4) 31 | cv2.circle(image, (b[7], b[12]), 1, (255, 0, 255), 4) 32 | cv2.circle(image, (b[8], b[13]), 1, (0, 255, 0), 4) 33 | cv2.circle(image, (b[9], b[14]), 1, (255, 0, 0), 4) 34 | else: # retinaface, centerface 35 | cv2.circle(image, (b[5], b[6]), 1, (0, 0, 255), 4) 36 | cv2.circle(image, (b[7], b[8]), 1, (0, 255, 255), 4) 37 | cv2.circle(image, (b[9], b[10]), 1, (255, 0, 255), 4) 38 | cv2.circle(image, (b[11], b[12]), 1, (0, 255, 0), 4) 39 | cv2.circle(image, (b[13], b[14]), 1, (255, 0, 0), 4) 40 | # save image 41 | cv2.imwrite(save_path, image) 42 | 43 | 44 | img = Image.open('inputs/00000000.png') 45 | # bboxes = np.array([ 46 | # 118.177826 * 2, 92.759514 * 2, 394.95926 * 2, 472.53278 * 2, 0.9995705 * 2, # noqa: E501 47 | # 686.77227723, 488.62376238, 586.77227723, 493.59405941, 337.91089109, 48 | # 488.38613861, 437.95049505, 493.51485149, 513.58415842, 678.5049505 49 | # ]) 50 | # bboxes = bboxes / 2 51 | draw_and_save(img, [bboxes], 'template_detall.png', 1) 52 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | # line break before binary operator (W503) 4 | W503, 5 | # line break after binary operator (W504) 6 | W504, 7 | max-line-length=120 8 | 9 | [yapf] 10 | based_on_style = pep8 11 | column_limit = 120 12 | blank_line_before_nested_class_or_def = true 13 | split_before_expression_after_opening_paren = true 14 | 15 | [isort] 16 | line_length = 120 17 | multi_line_output = 0 18 | 
known_standard_library = pkg_resources,setuptools 19 | known_first_party = facexlib 20 | known_third_party = PIL,cv2,filterpy,numba,numpy,scipy,torch,torchvision,tqdm 21 | no_lines_before = STDLIB,LOCALFOLDER 22 | default_section = THIRDPARTY 23 | 24 | [codespell] 25 | skip = .git,./docs/build,*.cfg 26 | count = 27 | quiet-level = 3 28 | ignore-words-list = mot,ans 29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | import os 6 | import subprocess 7 | import time 8 | 9 | version_file = 'facexlib/version.py' 10 | 11 | 12 | def readme(): 13 | with open('README.md', encoding='utf-8') as f: 14 | content = f.read() 15 | return content 16 | 17 | 18 | def get_git_hash(): 19 | 20 | def _minimal_ext_cmd(cmd): 21 | # construct minimal environment 22 | env = {} 23 | for k in ['SYSTEMROOT', 'PATH', 'HOME']: 24 | v = os.environ.get(k) 25 | if v is not None: 26 | env[k] = v 27 | # LANGUAGE is used on win32 28 | env['LANGUAGE'] = 'C' 29 | env['LANG'] = 'C' 30 | env['LC_ALL'] = 'C' 31 | out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 32 | return out 33 | 34 | try: 35 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) 36 | sha = out.strip().decode('ascii') 37 | except OSError: 38 | sha = 'unknown' 39 | 40 | return sha 41 | 42 | 43 | def get_hash(): 44 | if os.path.exists('.git'): 45 | sha = get_git_hash()[:7] 46 | else: 47 | sha = 'unknown' 48 | 49 | return sha 50 | 51 | 52 | def write_version_py(): 53 | content = """# GENERATED VERSION FILE 54 | # TIME: {} 55 | __version__ = '{}' 56 | __gitsha__ = '{}' 57 | version_info = ({}) 58 | """ 59 | sha = get_hash() 60 | with open('VERSION', 'r') as f: 61 | SHORT_VERSION = f.read().strip() 62 | VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) 63 | 64 | version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) 65 | with open(version_file, 'w') as f: 66 | f.write(version_file_str) 67 | 68 | 69 | def get_version(): 70 | with open(version_file, 'r') as f: 71 | exec(compile(f.read(), version_file, 'exec')) 72 | return locals()['__version__'] 73 | 74 | 75 | def get_requirements(filename='requirements.txt'): 76 | here = os.path.dirname(os.path.realpath(__file__)) 77 | with open(os.path.join(here, filename), 'r') as f: 78 | requires = [line.replace('\n', '') for line in f.readlines()] 79 | return requires 80 | 81 | 82 | if __name__ == '__main__': 83 | write_version_py() 84 | setup( 85 | name='facexlib', 86 | version=get_version(), 87 | description='Basic face library', 88 | long_description=readme(), 89 | long_description_content_type='text/markdown', 90 | author='Xintao Wang', 91 | author_email='xintao.wang@outlook.com', 92 | keywords='computer vision, face, detection, landmark, alignment', 93 | url='https://github.com/xinntao/facexlib', 94 | include_package_data=True, 95 | packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')), 96 | classifiers=[ 97 | 'Development Status :: 4 - Beta', 98 | 'License :: OSI Approved :: Apache Software License', 99 | 'Operating System :: OS Independent', 100 | 'Programming Language :: Python :: 3', 101 | 'Programming Language :: Python :: 3.7', 102 | 'Programming Language :: Python :: 3.8', 103 | ], 104 | license='Apache License 2.0', 105 | setup_requires=['cython', 'numpy'], 106 | 
install_requires=get_requirements(), 107 | zip_safe=False) 108 | --------------------------------------------------------------------------------