├── .github
│   └── workflows
│       ├── publish-pip.yml
│       ├── pylint.yml
│       └── release.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── MANIFEST.in
├── README.md
├── README_CN.md
├── VERSION
├── assets
│   ├── icon.png
│   ├── icon_small.png
│   ├── landmarks_5.jpg
│   ├── landmarks_68.png
│   ├── landmarks_98.png
│   ├── test.jpg
│   └── test2.jpg
├── facexlib
│   ├── __init__.py
│   ├── alignment
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── awing_arch.py
│   │   └── convert_98_to_68_landmarks.py
│   ├── assessment
│   │   ├── __init__.py
│   │   └── hyperiqa_net.py
│   ├── detection
│   │   ├── __init__.py
│   │   ├── align_trans.py
│   │   ├── matlab_cp2tform.py
│   │   ├── retinaface.py
│   │   ├── retinaface_net.py
│   │   └── retinaface_utils.py
│   ├── headpose
│   │   ├── __init__.py
│   │   └── hopenet_arch.py
│   ├── matting
│   │   ├── __init__.py
│   │   ├── backbone.py
│   │   ├── mobilenetv2.py
│   │   └── modnet.py
│   ├── parsing
│   │   ├── __init__.py
│   │   ├── bisenet.py
│   │   ├── parsenet.py
│   │   └── resnet.py
│   ├── recognition
│   │   ├── __init__.py
│   │   └── arcface_arch.py
│   ├── tracking
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── data_association.py
│   │   ├── kalman_tracker.py
│   │   └── sort.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── face_restoration_helper.py
│   │   ├── face_utils.py
│   │   └── misc.py
│   ├── visualization
│   │   ├── __init__.py
│   │   ├── vis_alignment.py
│   │   ├── vis_detection.py
│   │   └── vis_headpose.py
│   └── weights
│       └── README.md
├── inference
│   ├── inference_alignment.py
│   ├── inference_crop_standard_faces.py
│   ├── inference_detection.py
│   ├── inference_headpose.py
│   ├── inference_hyperiqa.py
│   ├── inference_matting.py
│   ├── inference_parsing.py
│   ├── inference_parsing_parsenet.py
│   ├── inference_recognition.py
│   └── inference_tracking.py
├── requirements.txt
├── scripts
│   ├── crop_faces_5landmarks.py
│   ├── extract_detection_info_ffhq.py
│   └── get_ffhq_template.py
├── setup.cfg
└── setup.py
/.github/workflows/publish-pip.yml:
--------------------------------------------------------------------------------
1 | name: PyPI Publish
2 |
3 | on: push
4 |
5 | jobs:
6 | build-n-publish:
7 | runs-on: ubuntu-latest
8 | if: startsWith(github.event.ref, 'refs/tags')
9 |
10 | steps:
11 | - uses: actions/checkout@v2
12 | - name: Set up Python 3.8
13 | uses: actions/setup-python@v1
14 | with:
15 | python-version: 3.8
16 | - name: Upgrade pip
17 | run: |
18 | pip install pip --upgrade
19 | pip install wheel
20 |
21 | - name: Install PyTorch (cpu)
22 | run: pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
23 | - name: Install dependencies
24 | run: pip install -r requirements.txt
25 | - name: Build and install
26 | run: rm -rf .eggs && pip install -e .
27 | - name: Build for distribution
28 | run: python setup.py sdist bdist_wheel
29 | - name: Publish distribution to PyPI
30 | uses: pypa/gh-action-pypi-publish@master
31 | with:
32 | password: ${{ secrets.PYPI_API_TOKEN }}
33 |
--------------------------------------------------------------------------------
/.github/workflows/pylint.yml:
--------------------------------------------------------------------------------
1 | name: PyLint
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | build:
7 |
8 | runs-on: ubuntu-latest
9 | strategy:
10 | matrix:
11 | python-version: [3.8]
12 |
13 | steps:
14 | - uses: actions/checkout@v2
15 | - name: Set up Python ${{ matrix.python-version }}
16 | uses: actions/setup-python@v2
17 | with:
18 | python-version: ${{ matrix.python-version }}
19 |
20 | - name: Install dependencies
21 | run: |
22 | python -m pip install --upgrade pip
23 | pip install codespell flake8 isort yapf
24 |
25 | - name: Lint
26 | run: |
27 | codespell
28 | flake8 .
29 | isort --check-only --diff facexlib/ inference/ scripts/ setup.py
30 | yapf -r -d facexlib/ inference/ scripts/ setup.py
31 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: release
2 | on:
3 | push:
4 | tags:
5 | - '*'
6 |
7 | jobs:
8 | build:
9 | permissions: write-all
10 | name: Create Release
11 | runs-on: ubuntu-latest
12 | steps:
13 | - name: Checkout code
14 | uses: actions/checkout@v2
15 | - name: Create Release
16 | id: create_release
17 | uses: actions/create-release@v1
18 | env:
19 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
20 | with:
21 | tag_name: ${{ github.ref }}
22 | release_name: FaceXlib ${{ github.ref }} Release Note
23 | body: |
24 | 🚀 See you again 😸
25 | 🚀Have a nice day 😸 and happy everyday 😃
26 | 🚀 Long time no see ☄️
27 |
28 | ✨ **Highlights**
29 | ✅ [Features] Support ...
30 |
31 | 🐛 **Bug Fixes**
32 |
33 | 🌴 **Improvements**
34 |
35 | 📢📢📢
36 |
37 |
38 |
39 |
40 | draft: true
41 | prerelease: false
42 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | *.pth
3 | *.png
4 | *.jpg
5 | version.py
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | pip-wheel-metadata/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
101 | __pypackages__/
102 |
103 | # Celery stuff
104 | celerybeat-schedule
105 | celerybeat.pid
106 |
107 | # SageMath parsed files
108 | *.sage.py
109 |
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 |
119 | # Spyder project settings
120 | .spyderproject
121 | .spyproject
122 |
123 | # Rope project settings
124 | .ropeproject
125 |
126 | # mkdocs documentation
127 | /site
128 |
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 |
134 | # Pyre type checker
135 | .pyre/
136 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | # flake8
3 | - repo: https://github.com/PyCQA/flake8
4 | rev: 3.8.3
5 | hooks:
6 | - id: flake8
7 | args: ["--config=setup.cfg", "--ignore=W504, W503"]
8 |
9 | # modify known_third_party
10 | - repo: https://github.com/asottile/seed-isort-config
11 | rev: v2.2.0
12 | hooks:
13 | - id: seed-isort-config
14 |
15 | # isort
16 | - repo: https://github.com/timothycrosley/isort
17 | rev: 5.2.2
18 | hooks:
19 | - id: isort
20 |
21 | # yapf
22 | - repo: https://github.com/pre-commit/mirrors-yapf
23 | rev: v0.30.0
24 | hooks:
25 | - id: yapf
26 |
27 | # codespell
28 | - repo: https://github.com/codespell-project/codespell
29 | rev: v2.1.0
30 | hooks:
31 | - id: codespell
32 |
33 | # pre-commit-hooks
34 | - repo: https://github.com/pre-commit/pre-commit-hooks
35 | rev: v3.2.0
36 | hooks:
37 | - id: trailing-whitespace # Trim trailing whitespace
38 | - id: check-yaml # Attempt to load all yaml files to verify syntax
39 | - id: check-merge-conflict # Check for files that contain merge conflict strings
40 | - id: double-quote-string-fixer # Replace double quoted strings with single quoted strings
41 | - id: end-of-file-fixer # Make sure files end in a newline and only a newline
42 | - id: requirements-txt-fixer # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0
43 | - id: fix-encoding-pragma # Remove the coding pragma: # -*- coding: utf-8 -*-
44 | args: ["--remove"]
45 | - id: mixed-line-ending # Replace or check mixed line ending
46 | args: ["--fix=lf"]
47 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Xintao Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include assets/*.png
2 | include assets/*.jpg
3 | include inference/*.py
4 | include scripts/*.py
5 | include VERSION
6 | include requirements.txt
7 | include facexlib/weights/README.md
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | #  FaceXLib
2 |
3 | [](https://pypi.org/project/facexlib/)
4 | [](https://github.com/xinntao/facexlib/releases)
5 | [](https://github.com/xinntao/facexlib/issues)
6 | [](https://github.com/xinntao/facexlib/issues)
7 | [](https://github.com/xinntao/facexlib/blob/master/LICENSE)
8 | [](https://github.com/xinntao/facexlib/blob/master/.github/workflows/pylint.yml)
9 | [](https://github.com/xinntao/facexlib/blob/master/.github/workflows/publish-pip.yml)
10 |
11 | [English](README.md) **|** [简体中文](README_CN.md)
12 |
13 | ---
14 |
15 | **facexlib** aims at providing ready-to-use **face-related** functions based on current SOTA open-source methods.
16 | Only PyTorch reference code is provided. For training or fine-tuning, please refer to their original repositories listed below.
17 | Note that we only provide a collection of these algorithms; please refer to their original licenses for your intended use.
18 |
19 | If facexlib is helpful in your projects, please help to :star: this repo. Thanks:blush:
20 | Other recommended projects: :arrow_forward: [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) :arrow_forward: [GFPGAN](https://github.com/TencentARC/GFPGAN) :arrow_forward: [BasicSR](https://github.com/xinntao/BasicSR)
21 |
22 | ---
23 |
24 | ## :sparkles: Functions
25 |
26 | | Function | Sources | Original LICENSE |
27 | | :--- | :---: | :---: |
28 | | [Detection](facexlib/detection/README.md) | [Pytorch_Retinaface](https://github.com/biubug6/Pytorch_Retinaface) | MIT |
29 | | [Alignment](facexlib/alignment/README.md) |[AdaptiveWingLoss](https://github.com/protossw512/AdaptiveWingLoss) | Apache 2.0 |
30 | | [Recognition](facexlib/recognition/README.md) | [InsightFace_Pytorch](https://github.com/TreB1eN/InsightFace_Pytorch) | MIT |
31 | | [Parsing](facexlib/parsing/README.md) | [face-parsing.PyTorch](https://github.com/zllrunning/face-parsing.PyTorch) | MIT |
32 | | [Matting](facexlib/matting/README.md) | [MODNet](https://github.com/ZHKKKe/MODNet) | CC 4.0 |
33 | | [Headpose](facexlib/headpose/README.md) | [deep-head-pose](https://github.com/natanielruiz/deep-head-pose) | Apache 2.0 |
34 | | [Tracking](facexlib/tracking/README.md) | [SORT](https://github.com/abewley/sort) | GPL 3.0 |
35 | | [Assessment](facexlib/assessment/README.md) | [hyperIQA](https://github.com/SSL92/hyperIQA) | - |
36 | | [Utils](facexlib/utils/README.md) | Face Restoration Helper | - |
37 |
38 | ## :eyes: Demo and Tutorials
39 |
40 | ## :wrench: Dependencies and Installation
41 |
42 | - Python >= 3.7 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
43 | - [PyTorch >= 1.7](https://pytorch.org/)
44 | - Option: NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
45 |
46 | ### Installation
47 |
48 | ```bash
49 | pip install facexlib
50 | ```
51 |
52 | ### Pre-trained models
53 |
54 | Pre-trained models will be **automatically** downloaded at the first inference.
55 | If your network is unstable, you can download them in advance (perhaps with other download tools) and put them in the folder `PACKAGE_ROOT_PATH/facexlib/weights`.
56 |
57 | ## :scroll: License and Acknowledgement
58 |
59 | This project is released under the MIT license.
60 |
61 | ## :e-mail: Contact
62 |
63 | If you have any questions, please open an issue or email `xintao.wang@outlook.com`.
64 |
--------------------------------------------------------------------------------
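The "Pre-trained models" section of the README above notes that weights are fetched automatically on first use. Below is a minimal sketch of what that looks like with the detection entry point defined in `facexlib/detection/__init__.py`; the actual face-detection call on the returned model is omitted, since the `RetinaFace` inference API is not shown in this dump.

```python
import torch

from facexlib.detection import init_detection_model

# Downloads detection_Resnet50_Final.pth into facexlib/weights on the first call,
# or into `model_rootpath` if one is given.
model = init_detection_model(
    'retinaface_resnet50',
    half=False,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    model_rootpath=None)
print(type(model))  # <class 'facexlib.detection.retinaface.RetinaFace'>
```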
/README_CN.md:
--------------------------------------------------------------------------------
1 | #  FaceXLib
2 |
3 | [English](README.md) **|** [简体中文](README_CN.md)
4 |
5 | ---
6 |
7 | `facexlib` is a **PyTorch-based** library for **face-related** functions, such as detection, alignment, recognition, tracking, utilities for face restoration, *etc*.
8 | It only provides inference (without training).
9 | This repo is based on current SOTA open-source methods (see [more details](#Functions)).
10 |
11 | ## :eyes: Demo
12 |
13 | ## :wrench: Dependencies and Installation
14 |
15 | - Python >= 3.7 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
16 | - [PyTorch >= 1.3](https://pytorch.org/)
17 | - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
18 |
19 | ## :sparkles: Functions
20 |
21 | | Function | Description | Reference |
22 | | :--- | :---: | :---: |
23 | | Detection | [More details](detection/README.md) | [Pytorch_Retinaface](https://github.com/biubug6/Pytorch_Retinaface) |
24 | | Alignment | [More details](alignment/README.md) | [AdaptiveWingLoss](https://github.com/protossw512/AdaptiveWingLoss) |
25 | | Recognition | [More details](recognition/README.md) | [InsightFace_Pytorch](https://github.com/TreB1eN/InsightFace_Pytorch) |
26 | | Tracking | [More details](tracking/README.md) | [SORT](https://github.com/abewley/sort) |
27 | | Utils | [More details](utils/README.md) | |
28 |
29 | ## :scroll: License and Acknowledgement
30 |
31 | This project is released under the MIT license.
32 |
33 | ## :e-mail: Contact
34 |
35 | If you have any questions, please open an issue or email `xintao.wang@outlook.com`.
36 |
--------------------------------------------------------------------------------
/VERSION:
--------------------------------------------------------------------------------
1 | 0.3.0
2 |
--------------------------------------------------------------------------------
/assets/icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xinntao/facexlib/260620ae93990a300f4b16448df9bb459f1caba9/assets/icon.png
--------------------------------------------------------------------------------
/assets/icon_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xinntao/facexlib/260620ae93990a300f4b16448df9bb459f1caba9/assets/icon_small.png
--------------------------------------------------------------------------------
/assets/landmarks_5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xinntao/facexlib/260620ae93990a300f4b16448df9bb459f1caba9/assets/landmarks_5.jpg
--------------------------------------------------------------------------------
/assets/landmarks_68.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xinntao/facexlib/260620ae93990a300f4b16448df9bb459f1caba9/assets/landmarks_68.png
--------------------------------------------------------------------------------
/assets/landmarks_98.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xinntao/facexlib/260620ae93990a300f4b16448df9bb459f1caba9/assets/landmarks_98.png
--------------------------------------------------------------------------------
/assets/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xinntao/facexlib/260620ae93990a300f4b16448df9bb459f1caba9/assets/test.jpg
--------------------------------------------------------------------------------
/assets/test2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xinntao/facexlib/260620ae93990a300f4b16448df9bb459f1caba9/assets/test2.jpg
--------------------------------------------------------------------------------
/facexlib/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from .alignment import *
3 | from .detection import *
4 | from .recognition import *
5 | from .tracking import *
6 | from .utils import *
7 | from .version import __gitsha__, __version__
8 | from .visualization import *
9 |
--------------------------------------------------------------------------------
/facexlib/alignment/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## Landmarks
3 |
4 | - 5 landmarks
5 |
6 |
7 |
8 |
9 |
10 | - 68 landmarks
11 |
12 |
13 |
14 |
15 |
16 | - 98 landmarks
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/facexlib/alignment/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from facexlib.utils import load_file_from_url
4 | from .awing_arch import FAN
5 | from .convert_98_to_68_landmarks import landmark_98_to_68
6 |
7 | __all__ = ['FAN', 'landmark_98_to_68']
8 |
9 |
10 | def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
11 | if model_name == 'awing_fan':
12 | model = FAN(num_modules=4, num_landmarks=98, device=device)
13 | model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
14 | else:
15 | raise NotImplementedError(f'{model_name} is not implemented.')
16 |
17 | model_path = load_file_from_url(
18 | url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
19 | model.load_state_dict(torch.load(model_path)['state_dict'], strict=True)
20 | model.eval()
21 | model = model.to(device)
22 | return model
23 |
--------------------------------------------------------------------------------
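A minimal usage sketch for the alignment model initialized above, assuming a BGR image as read by `cv2.imread` (typically a cropped face region); `get_landmarks` and `landmark_98_to_68` are defined in `awing_arch.py` and `convert_98_to_68_landmarks.py` below.

```python
import cv2

from facexlib.alignment import init_alignment_model, landmark_98_to_68

# Downloads alignment_WFLW_4HG.pth on first use.
model = init_alignment_model('awing_fan', device='cuda')

img = cv2.imread('assets/test.jpg')             # BGR, HxWx3
landmarks_98 = model.get_landmarks(img)         # (98, 2), scaled back to the input image size
landmarks_68 = landmark_98_to_68(landmarks_98)  # (68, 2)
print(landmarks_98.shape, landmarks_68.shape)
```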
/facexlib/alignment/awing_arch.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | def calculate_points(heatmaps):
9 | # change heatmaps to landmarks
10 | B, N, H, W = heatmaps.shape
11 | HW = H * W
12 | BN_range = np.arange(B * N)
13 |
14 | heatline = heatmaps.reshape(B, N, HW)
15 | indexes = np.argmax(heatline, axis=2)
16 |
17 | preds = np.stack((indexes % W, indexes // W), axis=2)
18 | preds = preds.astype(np.float, copy=False)
19 |
20 | inr = indexes.ravel()
21 |
22 | heatline = heatline.reshape(B * N, HW)
23 | x_up = heatline[BN_range, inr + 1]
24 | x_down = heatline[BN_range, inr - 1]
25 | # y_up = heatline[BN_range, inr + W]
26 |
27 | if any((inr + W) >= 4096):
28 | y_up = heatline[BN_range, 4095]
29 | else:
30 | y_up = heatline[BN_range, inr + W]
31 | if any((inr - W) <= 0):
32 | y_down = heatline[BN_range, 0]
33 | else:
34 | y_down = heatline[BN_range, inr - W]
35 |
36 | think_diff = np.sign(np.stack((x_up - x_down, y_up - y_down), axis=1))
37 | think_diff *= .25
38 |
39 | preds += think_diff.reshape(B, N, 2)
40 | preds += .5
41 | return preds
42 |
43 |
44 | class AddCoordsTh(nn.Module):
45 |
46 | def __init__(self, x_dim=64, y_dim=64, with_r=False, with_boundary=False):
47 | super(AddCoordsTh, self).__init__()
48 | self.x_dim = x_dim
49 | self.y_dim = y_dim
50 | self.with_r = with_r
51 | self.with_boundary = with_boundary
52 |
53 | def forward(self, input_tensor, heatmap=None):
54 | """
55 | input_tensor: (batch, c, x_dim, y_dim)
56 | """
57 | batch_size_tensor = input_tensor.shape[0]
58 |
59 | xx_ones = torch.ones([1, self.y_dim], dtype=torch.int32, device=input_tensor.device)
60 | xx_ones = xx_ones.unsqueeze(-1)
61 |
62 | xx_range = torch.arange(self.x_dim, dtype=torch.int32, device=input_tensor.device).unsqueeze(0)
63 | xx_range = xx_range.unsqueeze(1)
64 |
65 | xx_channel = torch.matmul(xx_ones.float(), xx_range.float())
66 | xx_channel = xx_channel.unsqueeze(-1)
67 |
68 | yy_ones = torch.ones([1, self.x_dim], dtype=torch.int32, device=input_tensor.device)
69 | yy_ones = yy_ones.unsqueeze(1)
70 |
71 | yy_range = torch.arange(self.y_dim, dtype=torch.int32, device=input_tensor.device).unsqueeze(0)
72 | yy_range = yy_range.unsqueeze(-1)
73 |
74 | yy_channel = torch.matmul(yy_range.float(), yy_ones.float())
75 | yy_channel = yy_channel.unsqueeze(-1)
76 |
77 | xx_channel = xx_channel.permute(0, 3, 2, 1)
78 | yy_channel = yy_channel.permute(0, 3, 2, 1)
79 |
80 | xx_channel = xx_channel / (self.x_dim - 1)
81 | yy_channel = yy_channel / (self.y_dim - 1)
82 |
83 | xx_channel = xx_channel * 2 - 1
84 | yy_channel = yy_channel * 2 - 1
85 |
86 | xx_channel = xx_channel.repeat(batch_size_tensor, 1, 1, 1)
87 | yy_channel = yy_channel.repeat(batch_size_tensor, 1, 1, 1)
88 |
89 | if self.with_boundary and heatmap is not None:
90 | boundary_channel = torch.clamp(heatmap[:, -1:, :, :], 0.0, 1.0)
91 |
92 | zero_tensor = torch.zeros_like(xx_channel)
93 | xx_boundary_channel = torch.where(boundary_channel > 0.05, xx_channel, zero_tensor)
94 | yy_boundary_channel = torch.where(boundary_channel > 0.05, yy_channel, zero_tensor)
95 | if self.with_boundary and heatmap is not None:
96 | xx_boundary_channel = xx_boundary_channel.to(input_tensor.device)
97 | yy_boundary_channel = yy_boundary_channel.to(input_tensor.device)
98 | ret = torch.cat([input_tensor, xx_channel, yy_channel], dim=1)
99 |
100 | if self.with_r:
101 | rr = torch.sqrt(torch.pow(xx_channel, 2) + torch.pow(yy_channel, 2))
102 | rr = rr / torch.max(rr)
103 | ret = torch.cat([ret, rr], dim=1)
104 |
105 | if self.with_boundary and heatmap is not None:
106 | ret = torch.cat([ret, xx_boundary_channel, yy_boundary_channel], dim=1)
107 | return ret
108 |
109 |
110 | class CoordConvTh(nn.Module):
111 | """CoordConv layer as in the paper."""
112 |
113 | def __init__(self, x_dim, y_dim, with_r, with_boundary, in_channels, first_one=False, *args, **kwargs):
114 | super(CoordConvTh, self).__init__()
115 | self.addcoords = AddCoordsTh(x_dim=x_dim, y_dim=y_dim, with_r=with_r, with_boundary=with_boundary)
116 | in_channels += 2
117 | if with_r:
118 | in_channels += 1
119 | if with_boundary and not first_one:
120 | in_channels += 2
121 | self.conv = nn.Conv2d(in_channels=in_channels, *args, **kwargs)
122 |
123 | def forward(self, input_tensor, heatmap=None):
124 | ret = self.addcoords(input_tensor, heatmap)
125 | last_channel = ret[:, -2:, :, :]
126 | ret = self.conv(ret)
127 | return ret, last_channel
128 |
129 |
130 | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False, dilation=1):
131 | '3x3 convolution with padding'
132 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=strd, padding=padding, bias=bias, dilation=dilation)
133 |
134 |
135 | class BasicBlock(nn.Module):
136 | expansion = 1
137 |
138 | def __init__(self, inplanes, planes, stride=1, downsample=None):
139 | super(BasicBlock, self).__init__()
140 | self.conv1 = conv3x3(inplanes, planes, stride)
141 | # self.bn1 = nn.BatchNorm2d(planes)
142 | self.relu = nn.ReLU(inplace=True)
143 | self.conv2 = conv3x3(planes, planes)
144 | # self.bn2 = nn.BatchNorm2d(planes)
145 | self.downsample = downsample
146 | self.stride = stride
147 |
148 | def forward(self, x):
149 | residual = x
150 |
151 | out = self.conv1(x)
152 | out = self.relu(out)
153 |
154 | out = self.conv2(out)
155 |
156 | if self.downsample is not None:
157 | residual = self.downsample(x)
158 |
159 | out += residual
160 | out = self.relu(out)
161 |
162 | return out
163 |
164 |
165 | class ConvBlock(nn.Module):
166 |
167 | def __init__(self, in_planes, out_planes):
168 | super(ConvBlock, self).__init__()
169 | self.bn1 = nn.BatchNorm2d(in_planes)
170 | self.conv1 = conv3x3(in_planes, int(out_planes / 2))
171 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
172 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4), padding=1, dilation=1)
173 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
174 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4), padding=1, dilation=1)
175 |
176 | if in_planes != out_planes:
177 | self.downsample = nn.Sequential(
178 | nn.BatchNorm2d(in_planes),
179 | nn.ReLU(True),
180 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False),
181 | )
182 | else:
183 | self.downsample = None
184 |
185 | def forward(self, x):
186 | residual = x
187 |
188 | out1 = self.bn1(x)
189 | out1 = F.relu(out1, True)
190 | out1 = self.conv1(out1)
191 |
192 | out2 = self.bn2(out1)
193 | out2 = F.relu(out2, True)
194 | out2 = self.conv2(out2)
195 |
196 | out3 = self.bn3(out2)
197 | out3 = F.relu(out3, True)
198 | out3 = self.conv3(out3)
199 |
200 | out3 = torch.cat((out1, out2, out3), 1)
201 |
202 | if self.downsample is not None:
203 | residual = self.downsample(residual)
204 |
205 | out3 += residual
206 |
207 | return out3
208 |
209 |
210 | class HourGlass(nn.Module):
211 |
212 | def __init__(self, num_modules, depth, num_features, first_one=False):
213 | super(HourGlass, self).__init__()
214 | self.num_modules = num_modules
215 | self.depth = depth
216 | self.features = num_features
217 | self.coordconv = CoordConvTh(
218 | x_dim=64,
219 | y_dim=64,
220 | with_r=True,
221 | with_boundary=True,
222 | in_channels=256,
223 | first_one=first_one,
224 | out_channels=256,
225 | kernel_size=1,
226 | stride=1,
227 | padding=0)
228 | self._generate_network(self.depth)
229 |
230 | def _generate_network(self, level):
231 | self.add_module('b1_' + str(level), ConvBlock(256, 256))
232 |
233 | self.add_module('b2_' + str(level), ConvBlock(256, 256))
234 |
235 | if level > 1:
236 | self._generate_network(level - 1)
237 | else:
238 | self.add_module('b2_plus_' + str(level), ConvBlock(256, 256))
239 |
240 | self.add_module('b3_' + str(level), ConvBlock(256, 256))
241 |
242 | def _forward(self, level, inp):
243 | # Upper branch
244 | up1 = inp
245 | up1 = self._modules['b1_' + str(level)](up1)
246 |
247 | # Lower branch
248 | low1 = F.avg_pool2d(inp, 2, stride=2)
249 | low1 = self._modules['b2_' + str(level)](low1)
250 |
251 | if level > 1:
252 | low2 = self._forward(level - 1, low1)
253 | else:
254 | low2 = low1
255 | low2 = self._modules['b2_plus_' + str(level)](low2)
256 |
257 | low3 = low2
258 | low3 = self._modules['b3_' + str(level)](low3)
259 |
260 | up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
261 |
262 | return up1 + up2
263 |
264 | def forward(self, x, heatmap):
265 | x, last_channel = self.coordconv(x, heatmap)
266 | return self._forward(self.depth, x), last_channel
267 |
268 |
269 | class FAN(nn.Module):
270 |
271 | def __init__(self, num_modules=1, end_relu=False, gray_scale=False, num_landmarks=68, device='cuda'):
272 | super(FAN, self).__init__()
273 | self.device = device
274 | self.num_modules = num_modules
275 | self.gray_scale = gray_scale
276 | self.end_relu = end_relu
277 | self.num_landmarks = num_landmarks
278 |
279 | # Base part
280 | if self.gray_scale:
281 | self.conv1 = CoordConvTh(
282 | x_dim=256,
283 | y_dim=256,
284 | with_r=True,
285 | with_boundary=False,
286 | in_channels=3,
287 | out_channels=64,
288 | kernel_size=7,
289 | stride=2,
290 | padding=3)
291 | else:
292 | self.conv1 = CoordConvTh(
293 | x_dim=256,
294 | y_dim=256,
295 | with_r=True,
296 | with_boundary=False,
297 | in_channels=3,
298 | out_channels=64,
299 | kernel_size=7,
300 | stride=2,
301 | padding=3)
302 | self.bn1 = nn.BatchNorm2d(64)
303 | self.conv2 = ConvBlock(64, 128)
304 | self.conv3 = ConvBlock(128, 128)
305 | self.conv4 = ConvBlock(128, 256)
306 |
307 | # Stacking part
308 | for hg_module in range(self.num_modules):
309 | if hg_module == 0:
310 | first_one = True
311 | else:
312 | first_one = False
313 | self.add_module('m' + str(hg_module), HourGlass(1, 4, 256, first_one))
314 | self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
315 | self.add_module('conv_last' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
316 | self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
317 | self.add_module('l' + str(hg_module), nn.Conv2d(256, num_landmarks + 1, kernel_size=1, stride=1, padding=0))
318 |
319 | if hg_module < self.num_modules - 1:
320 | self.add_module('bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
321 | self.add_module('al' + str(hg_module),
322 | nn.Conv2d(num_landmarks + 1, 256, kernel_size=1, stride=1, padding=0))
323 |
324 | def forward(self, x):
325 | x, _ = self.conv1(x)
326 | x = F.relu(self.bn1(x), True)
327 | # x = F.relu(self.bn1(self.conv1(x)), True)
328 | x = F.avg_pool2d(self.conv2(x), 2, stride=2)
329 | x = self.conv3(x)
330 | x = self.conv4(x)
331 |
332 | previous = x
333 |
334 | outputs = []
335 | boundary_channels = []
336 | tmp_out = None
337 | for i in range(self.num_modules):
338 | hg, boundary_channel = self._modules['m' + str(i)](previous, tmp_out)
339 |
340 | ll = hg
341 | ll = self._modules['top_m_' + str(i)](ll)
342 |
343 | ll = F.relu(self._modules['bn_end' + str(i)](self._modules['conv_last' + str(i)](ll)), True)
344 |
345 | # Predict heatmaps
346 | tmp_out = self._modules['l' + str(i)](ll)
347 | if self.end_relu:
348 | tmp_out = F.relu(tmp_out) # HACK: Added relu
349 | outputs.append(tmp_out)
350 | boundary_channels.append(boundary_channel)
351 |
352 | if i < self.num_modules - 1:
353 | ll = self._modules['bl' + str(i)](ll)
354 | tmp_out_ = self._modules['al' + str(i)](tmp_out)
355 | previous = previous + ll + tmp_out_
356 |
357 | return outputs, boundary_channels
358 |
359 | def get_landmarks(self, img):
360 | H, W, _ = img.shape
361 | offset = W / 64, H / 64, 0, 0
362 |
363 | img = cv2.resize(img, (256, 256))
364 | inp = img[..., ::-1]
365 | inp = torch.from_numpy(np.ascontiguousarray(inp.transpose((2, 0, 1)))).float()
366 | inp = inp.to(self.device)
367 | inp.div_(255.0).unsqueeze_(0)
368 |
369 | outputs, _ = self.forward(inp)
370 | out = outputs[-1][:, :-1, :, :]
371 | heatmaps = out.detach().cpu().numpy()
372 |
373 | pred = calculate_points(heatmaps).reshape(-1, 2)
374 |
375 | pred *= offset[:2]
376 | pred += offset[-2:]
377 |
378 | return pred
379 |
--------------------------------------------------------------------------------
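`calculate_points` above turns per-landmark heatmaps into sub-pixel coordinates: it takes the argmax of each 64x64 map, nudges it by a quarter pixel toward the higher neighbouring value, and adds 0.5. A small self-contained check on synthetic heatmaps (not real model output):

```python
import numpy as np

from facexlib.alignment.awing_arch import calculate_points

# One batch, 98 landmarks, 64x64 heatmaps with a single peak at (x=20, y=32).
heatmaps = np.zeros((1, 98, 64, 64), dtype=np.float32)
heatmaps[:, :, 32, 20] = 1.0

points = calculate_points(heatmaps)  # shape (1, 98, 2), (x, y) order
assert points.shape == (1, 98, 2)
assert np.all(np.abs(points[..., 0] - 20) <= 1) and np.all(np.abs(points[..., 1] - 32) <= 1)
```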
/facexlib/alignment/convert_98_to_68_landmarks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def load_txt_file(file_path):
5 | """Load data or string from txt file."""
6 |
7 | with open(file_path, 'r') as cfile:
8 | content = cfile.readlines()
9 | cfile.close()
10 | content = [x.strip() for x in content]
11 | num_lines = len(content)
12 | return content, num_lines
13 |
14 |
15 | def anno_parser(anno_path, num_pts, line_offset=0):
16 | """Parse the annotation.
17 | Args:
18 | anno_path: path of anno file (suffix .txt)
19 | num_pts: number of landmarks.
20 | line_offset: first point starts, default: 0.
21 |
22 | Returns:
23 | pts: num_pts x 2 (x, y)
24 | """
25 |
26 | data, _ = load_txt_file(anno_path)
27 | n_points = num_pts
28 | # read points coordinate.
29 | pts = np.zeros((n_points, 2), dtype='float32')
30 | for point_index in range(n_points):
31 | try:
32 | pts_list = data[point_index + line_offset].split(',')
33 | pts[point_index, 0] = float(pts_list[0])
34 | pts[point_index, 1] = float(pts_list[1])
35 | except ValueError:
36 | print(f'Error in loading points in {anno_path}')
37 | return pts
38 |
39 |
40 | def landmark_98_to_68(landmark_98):
41 | """Transfer 98 landmark positions to 68 landmark positions.
42 | Args:
43 | landmark_98(numpy array): (x, y) coordinates of 98 landmarks, shape (98, 2).
44 | Returns:
45 | landmark_68(numpy array): (x, y) coordinates of 68 landmarks, shape (68, 2).
46 | """
47 |
48 | landmark_68 = np.zeros((68, 2), dtype='float32')
49 | # cheek
50 | for i in range(0, 33):
51 | if i % 2 == 0:
52 | landmark_68[int(i / 2), :] = landmark_98[i, :]
53 | # nose
54 | for i in range(51, 60):
55 | landmark_68[i - 24, :] = landmark_98[i, :]
56 | # mouth
57 | for i in range(76, 96):
58 | landmark_68[i - 28, :] = landmark_98[i, :]
59 | # left eyebrow
60 | landmark_68[17, :] = landmark_98[33, :]
61 | landmark_68[18, :] = (landmark_98[34, :] + landmark_98[41, :]) / 2
62 | landmark_68[19, :] = (landmark_98[35, :] + landmark_98[40, :]) / 2
63 | landmark_68[20, :] = (landmark_98[36, :] + landmark_98[39, :]) / 2
64 | landmark_68[21, :] = (landmark_98[37, :] + landmark_98[38, :]) / 2
65 | # right eyebrow
66 | landmark_68[22, :] = (landmark_98[42, :] + landmark_98[50, :]) / 2
67 | landmark_68[23, :] = (landmark_98[43, :] + landmark_98[49, :]) / 2
68 | landmark_68[24, :] = (landmark_98[44, :] + landmark_98[48, :]) / 2
69 | landmark_68[25, :] = (landmark_98[45, :] + landmark_98[47, :]) / 2
70 | landmark_68[26, :] = landmark_98[46, :]
71 | # left eye
72 | LUT_landmark_68_left_eye = [36, 37, 38, 39, 40, 41]
73 | LUT_landmark_98_left_eye = [60, 61, 63, 64, 65, 67]
74 | for idx, landmark_98_index in enumerate(LUT_landmark_98_left_eye):
75 | landmark_68[LUT_landmark_68_left_eye[idx], :] = landmark_98[landmark_98_index, :]
76 | # right eye
77 | LUT_landmark_68_right_eye = [42, 43, 44, 45, 46, 47]
78 | LUT_landmark_98_right_eye = [68, 69, 71, 72, 73, 75]
79 | for idx, landmark_98_index in enumerate(LUT_landmark_98_right_eye):
80 | landmark_68[LUT_landmark_68_right_eye[idx], :] = landmark_98[landmark_98_index, :]
81 |
82 | return landmark_68
83 |
--------------------------------------------------------------------------------
/facexlib/assessment/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from facexlib.utils import load_file_from_url
4 | from .hyperiqa_net import HyperIQA
5 |
6 |
7 | def init_assessment_model(model_name, half=False, device='cuda', model_rootpath=None):
8 | if model_name == 'hypernet':
9 | model = HyperIQA(16, 112, 224, 112, 56, 28, 14, 7)
10 | model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/assessment_hyperIQA.pth'
11 | else:
12 | raise NotImplementedError(f'{model_name} is not implemented.')
13 |
14 | # load the pre-trained hypernet model
15 | hypernet_model_path = load_file_from_url(
16 | url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
17 | model.hypernet.load_state_dict((torch.load(hypernet_model_path, map_location=lambda storage, loc: storage)))
18 | model = model.eval()
19 | model = model.to(device)
20 | return model
21 |
--------------------------------------------------------------------------------
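A shape-level sketch of running the assessment model initialized above. The real preprocessing (cropping, normalization) lives in `inference/inference_hyperiqa.py`, which is not shown here, so a random 224x224 tensor is used purely to illustrate the expected input and output shapes.

```python
import torch

from facexlib.assessment import init_assessment_model

model = init_assessment_model('hypernet', device='cpu')

# feature_size=7 in HyperIQA(16, 112, 224, 112, 56, 28, 14, 7) implies a 224x224 input.
img = torch.rand(1, 3, 224, 224)  # stand-in for a preprocessed face crop
with torch.no_grad():
    score = model(img)
print(float(score))  # single predicted quality score
```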
/facexlib/assessment/hyperiqa_net.py:
--------------------------------------------------------------------------------
1 | import torch as torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class HyperIQA(nn.Module):
7 | """
8 | Combine the hypernet and target network within a network.
9 | """
10 |
11 | def __init__(self, *args):
12 | super(HyperIQA, self).__init__()
13 | self.hypernet = HyperNet(*args)
14 |
15 | def forward(self, img):
16 | net_params = self.hypernet(img)
17 | # build the target network
18 | target_net = TargetNet(net_params)
19 | for param in target_net.parameters():
20 | param.requires_grad = False
21 | # predict the face quality
22 | pred = target_net(net_params['target_in_vec'])
23 | return pred
24 |
25 |
26 | class HyperNet(nn.Module):
27 | """
28 | Hyper network for learning perceptual rules.
29 | Args:
30 | lda_out_channels: local distortion aware module output size.
31 | hyper_in_channels: input feature channels for hyper network.
32 | target_in_size: input vector size for target network.
33 | target_fc(i)_size: fully connected layer size of the target network.
34 | feature_size: input feature map width/height for hyper network.
35 | Note:
36 | For size match, input args must satisfy: 'target_fc(i)_size * target_fc(i+1)_size' is divisible by 'feature_size ^ 2'. # noqa E501
37 | """
38 |
39 | def __init__(self, lda_out_channels, hyper_in_channels, target_in_size, target_fc1_size, target_fc2_size,
40 | target_fc3_size, target_fc4_size, feature_size):
41 | super(HyperNet, self).__init__()
42 |
43 | self.hyperInChn = hyper_in_channels
44 | self.target_in_size = target_in_size
45 | self.f1 = target_fc1_size
46 | self.f2 = target_fc2_size
47 | self.f3 = target_fc3_size
48 | self.f4 = target_fc4_size
49 | self.feature_size = feature_size
50 |
51 | self.res = resnet50_backbone(lda_out_channels, target_in_size)
52 |
53 | self.pool = nn.AdaptiveAvgPool2d((1, 1))
54 |
55 | # Conv layers for resnet output features
56 | self.conv1 = nn.Sequential(
57 | nn.Conv2d(2048, 1024, 1, padding=(0, 0)), nn.ReLU(inplace=True), nn.Conv2d(1024, 512, 1, padding=(0, 0)),
58 | nn.ReLU(inplace=True), nn.Conv2d(512, self.hyperInChn, 1, padding=(0, 0)), nn.ReLU(inplace=True))
59 |
60 | # Hyper network part, conv for generating target fc weights, fc for generating target fc biases
61 | self.fc1w_conv = nn.Conv2d(
62 | self.hyperInChn, int(self.target_in_size * self.f1 / feature_size**2), 3, padding=(1, 1))
63 | self.fc1b_fc = nn.Linear(self.hyperInChn, self.f1)
64 |
65 | self.fc2w_conv = nn.Conv2d(self.hyperInChn, int(self.f1 * self.f2 / feature_size**2), 3, padding=(1, 1))
66 | self.fc2b_fc = nn.Linear(self.hyperInChn, self.f2)
67 |
68 | self.fc3w_conv = nn.Conv2d(self.hyperInChn, int(self.f2 * self.f3 / feature_size**2), 3, padding=(1, 1))
69 | self.fc3b_fc = nn.Linear(self.hyperInChn, self.f3)
70 |
71 | self.fc4w_conv = nn.Conv2d(self.hyperInChn, int(self.f3 * self.f4 / feature_size**2), 3, padding=(1, 1))
72 | self.fc4b_fc = nn.Linear(self.hyperInChn, self.f4)
73 |
74 | self.fc5w_fc = nn.Linear(self.hyperInChn, self.f4)
75 | self.fc5b_fc = nn.Linear(self.hyperInChn, 1)
76 |
77 | def forward(self, img):
78 | feature_size = self.feature_size
79 |
80 | res_out = self.res(img)
81 |
82 | # input vector for target net
83 | target_in_vec = res_out['target_in_vec'].view(-1, self.target_in_size, 1, 1)
84 |
85 | # input features for hyper net
86 | hyper_in_feat = self.conv1(res_out['hyper_in_feat']).view(-1, self.hyperInChn, feature_size, feature_size)
87 |
88 | # generating target net weights & biases
89 | target_fc1w = self.fc1w_conv(hyper_in_feat).view(-1, self.f1, self.target_in_size, 1, 1)
90 | target_fc1b = self.fc1b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f1)
91 |
92 | target_fc2w = self.fc2w_conv(hyper_in_feat).view(-1, self.f2, self.f1, 1, 1)
93 | target_fc2b = self.fc2b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f2)
94 |
95 | target_fc3w = self.fc3w_conv(hyper_in_feat).view(-1, self.f3, self.f2, 1, 1)
96 | target_fc3b = self.fc3b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f3)
97 |
98 | target_fc4w = self.fc4w_conv(hyper_in_feat).view(-1, self.f4, self.f3, 1, 1)
99 | target_fc4b = self.fc4b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, self.f4)
100 |
101 | target_fc5w = self.fc5w_fc(self.pool(hyper_in_feat).squeeze()).view(-1, 1, self.f4, 1, 1)
102 | target_fc5b = self.fc5b_fc(self.pool(hyper_in_feat).squeeze()).view(-1, 1)
103 |
104 | out = {}
105 | out['target_in_vec'] = target_in_vec
106 | out['target_fc1w'] = target_fc1w
107 | out['target_fc1b'] = target_fc1b
108 | out['target_fc2w'] = target_fc2w
109 | out['target_fc2b'] = target_fc2b
110 | out['target_fc3w'] = target_fc3w
111 | out['target_fc3b'] = target_fc3b
112 | out['target_fc4w'] = target_fc4w
113 | out['target_fc4b'] = target_fc4b
114 | out['target_fc5w'] = target_fc5w
115 | out['target_fc5b'] = target_fc5b
116 |
117 | return out
118 |
119 |
120 | class Bottleneck(nn.Module):
121 | expansion = 4
122 |
123 | def __init__(self, inplanes, planes, stride=1, downsample=None):
124 | super(Bottleneck, self).__init__()
125 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
126 | self.bn1 = nn.BatchNorm2d(planes)
127 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
128 | self.bn2 = nn.BatchNorm2d(planes)
129 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
130 | self.bn3 = nn.BatchNorm2d(planes * 4)
131 | self.relu = nn.ReLU(inplace=True)
132 | self.downsample = downsample
133 | self.stride = stride
134 |
135 | def forward(self, x):
136 | residual = x
137 |
138 | out = self.conv1(x)
139 | out = self.bn1(out)
140 | out = self.relu(out)
141 |
142 | out = self.conv2(out)
143 | out = self.bn2(out)
144 | out = self.relu(out)
145 |
146 | out = self.conv3(out)
147 | out = self.bn3(out)
148 |
149 | if self.downsample is not None:
150 | residual = self.downsample(x)
151 |
152 | out += residual
153 | out = self.relu(out)
154 |
155 | return out
156 |
157 |
158 | class ResNetBackbone(nn.Module):
159 |
160 | def __init__(self, lda_out_channels, in_chn, block, layers, num_classes=1000):
161 | super(ResNetBackbone, self).__init__()
162 | self.inplanes = 64
163 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
164 | self.bn1 = nn.BatchNorm2d(64)
165 | self.relu = nn.ReLU(inplace=True)
166 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
167 | self.layer1 = self._make_layer(block, 64, layers[0])
168 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
169 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
170 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
171 |
172 | # local distortion aware module
173 | self.lda1_pool = nn.Sequential(
174 | nn.Conv2d(256, 16, kernel_size=1, stride=1, padding=0, bias=False),
175 | nn.AvgPool2d(7, stride=7),
176 | )
177 | self.lda1_fc = nn.Linear(16 * 64, lda_out_channels)
178 |
179 | self.lda2_pool = nn.Sequential(
180 | nn.Conv2d(512, 32, kernel_size=1, stride=1, padding=0, bias=False),
181 | nn.AvgPool2d(7, stride=7),
182 | )
183 | self.lda2_fc = nn.Linear(32 * 16, lda_out_channels)
184 |
185 | self.lda3_pool = nn.Sequential(
186 | nn.Conv2d(1024, 64, kernel_size=1, stride=1, padding=0, bias=False),
187 | nn.AvgPool2d(7, stride=7),
188 | )
189 | self.lda3_fc = nn.Linear(64 * 4, lda_out_channels)
190 |
191 | self.lda4_pool = nn.AvgPool2d(7, stride=7)
192 | self.lda4_fc = nn.Linear(2048, in_chn - lda_out_channels * 3)
193 |
194 | def _make_layer(self, block, planes, blocks, stride=1):
195 | downsample = None
196 | if stride != 1 or self.inplanes != planes * block.expansion:
197 | downsample = nn.Sequential(
198 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
199 | nn.BatchNorm2d(planes * block.expansion),
200 | )
201 |
202 | layers = []
203 | layers.append(block(self.inplanes, planes, stride, downsample))
204 | self.inplanes = planes * block.expansion
205 | for i in range(1, blocks):
206 | layers.append(block(self.inplanes, planes))
207 |
208 | return nn.Sequential(*layers)
209 |
210 | def forward(self, x):
211 | x = self.conv1(x)
212 | x = self.bn1(x)
213 | x = self.relu(x)
214 | x = self.maxpool(x)
215 | x = self.layer1(x)
216 |
217 | # the same effect as lda operation in the paper, but save much more memory
218 | lda_1 = self.lda1_fc(self.lda1_pool(x).view(x.size(0), -1))
219 | x = self.layer2(x)
220 | lda_2 = self.lda2_fc(self.lda2_pool(x).view(x.size(0), -1))
221 | x = self.layer3(x)
222 | lda_3 = self.lda3_fc(self.lda3_pool(x).view(x.size(0), -1))
223 | x = self.layer4(x)
224 | lda_4 = self.lda4_fc(self.lda4_pool(x).view(x.size(0), -1))
225 |
226 | vec = torch.cat((lda_1, lda_2, lda_3, lda_4), 1)
227 |
228 | out = {}
229 | out['hyper_in_feat'] = x
230 | out['target_in_vec'] = vec
231 |
232 | return out
233 |
234 |
235 | def resnet50_backbone(lda_out_channels, in_chn, **kwargs):
236 | """Constructs a ResNet-50 model_hyper."""
237 | model = ResNetBackbone(lda_out_channels, in_chn, Bottleneck, [3, 4, 6, 3], **kwargs)
238 | return model
239 |
240 |
241 | class TargetNet(nn.Module):
242 | """
243 | Target network for quality prediction.
244 | """
245 |
246 | def __init__(self, paras):
247 | super(TargetNet, self).__init__()
248 | self.l1 = nn.Sequential(
249 | TargetFC(paras['target_fc1w'], paras['target_fc1b']),
250 | nn.Sigmoid(),
251 | )
252 | self.l2 = nn.Sequential(
253 | TargetFC(paras['target_fc2w'], paras['target_fc2b']),
254 | nn.Sigmoid(),
255 | )
256 |
257 | self.l3 = nn.Sequential(
258 | TargetFC(paras['target_fc3w'], paras['target_fc3b']),
259 | nn.Sigmoid(),
260 | )
261 |
262 | self.l4 = nn.Sequential(
263 | TargetFC(paras['target_fc4w'], paras['target_fc4b']),
264 | nn.Sigmoid(),
265 | TargetFC(paras['target_fc5w'], paras['target_fc5b']),
266 | )
267 |
268 | def forward(self, x):
269 | q = self.l1(x)
270 | # q = F.dropout(q)
271 | q = self.l2(q)
272 | q = self.l3(q)
273 | q = self.l4(q).squeeze()
274 | return q
275 |
276 |
277 | class TargetFC(nn.Module):
278 | """
279 | Fully connected operations for the target net.
280 | Note:
281 | Weights & biases are different for different images in a batch,
282 | thus here we use group convolution for calculating images in a batch with individual weights & biases.
283 | """
284 |
285 | def __init__(self, weight, bias):
286 | super(TargetFC, self).__init__()
287 | self.weight = weight
288 | self.bias = bias
289 |
290 | def forward(self, input_):
291 |
292 | input_re = input_.view(-1, input_.shape[0] * input_.shape[1], input_.shape[2], input_.shape[3])
293 | weight_re = self.weight.view(self.weight.shape[0] * self.weight.shape[1], self.weight.shape[2],
294 | self.weight.shape[3], self.weight.shape[4])
295 | bias_re = self.bias.view(self.bias.shape[0] * self.bias.shape[1])
296 | out = F.conv2d(input=input_re, weight=weight_re, bias=bias_re, groups=self.weight.shape[0])
297 |
298 | return out.view(input_.shape[0], self.weight.shape[1], input_.shape[2], input_.shape[3])
299 |
--------------------------------------------------------------------------------
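The `TargetFC` docstring above explains that per-image weights and biases are applied via a grouped convolution, one group per sample in the batch. A small sketch (with toy sizes, not the real layer dimensions) verifying that this trick matches a per-sample fully connected layer:

```python
import torch
import torch.nn.functional as F

batch, c_in, c_out = 4, 6, 3
x = torch.randn(batch, c_in, 1, 1)         # per-sample input vectors
w = torch.randn(batch, c_out, c_in, 1, 1)  # per-sample FC weights
b = torch.randn(batch, c_out)              # per-sample FC biases

# Grouped convolution: fold the batch into the channel dimension, one group per sample.
grouped = F.conv2d(
    x.view(1, batch * c_in, 1, 1),
    w.view(batch * c_out, c_in, 1, 1),
    b.view(batch * c_out),
    groups=batch).view(batch, c_out)

# Reference: apply each sample's own weight matrix and bias in a loop.
looped = torch.stack([w[i, :, :, 0, 0] @ x[i, :, 0, 0] + b[i] for i in range(batch)])
assert torch.allclose(grouped, looped, atol=1e-5)
```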
/facexlib/detection/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from copy import deepcopy
3 |
4 | from facexlib.utils import load_file_from_url
5 | from .retinaface import RetinaFace
6 |
7 |
8 | def init_detection_model(model_name, half=False, device='cuda', model_rootpath=None):
9 | if model_name == 'retinaface_resnet50':
10 | model = RetinaFace(network_name='resnet50', half=half, device=device)
11 | model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth'
12 | elif model_name == 'retinaface_mobile0.25':
13 | model = RetinaFace(network_name='mobile0.25', half=half, device=device)
14 | model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth'
15 | else:
16 | raise NotImplementedError(f'{model_name} is not implemented.')
17 |
18 | model_path = load_file_from_url(
19 | url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
20 |
21 | # TODO: clean pretrained model
22 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
23 | # remove unnecessary 'module.'
24 | for k, v in deepcopy(load_net).items():
25 | if k.startswith('module.'):
26 | load_net[k[7:]] = v
27 | load_net.pop(k)
28 | model.load_state_dict(load_net, strict=True)
29 | model.eval()
30 | model = model.to(device)
31 | return model
32 |
--------------------------------------------------------------------------------
/facexlib/detection/align_trans.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | from .matlab_cp2tform import get_similarity_transform_for_cv2
5 |
6 | # reference facial points, a list of coordinates (x,y)
7 | REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278],
8 | [33.54930115, 92.3655014], [62.72990036, 92.20410156]]
9 |
10 | DEFAULT_CROP_SIZE = (96, 112)
11 |
12 |
13 | class FaceWarpException(Exception):
14 |
15 | def __str__(self):
16 | return 'In File {}:{}'.format(__file__, super.__str__(self))
17 |
18 |
19 | def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False):
20 | """
21 | Function:
22 | ----------
23 | get reference 5 key points according to crop settings:
24 | 0. Set default crop_size:
25 | if default_square:
26 | crop_size = (112, 112)
27 | else:
28 | crop_size = (96, 112)
29 | 1. Pad the crop_size by inner_padding_factor in each side;
30 | 2. Resize crop_size into (output_size - outer_padding*2),
31 | pad into output_size with outer_padding;
32 | 3. Output reference_5point;
33 | Parameters:
34 | ----------
35 | @output_size: (w, h) or None
36 | size of aligned face image
37 | @inner_padding_factor: (w_factor, h_factor)
38 | padding factor for inner (w, h)
39 | @outer_padding: (w_pad, h_pad)
40 | each row is a pair of coordinates (x, y)
41 | @default_square: True or False
42 | if True:
43 | default crop_size = (112, 112)
44 | else:
45 | default crop_size = (96, 112);
46 | !!! make sure, if output_size is not None:
47 | (output_size - outer_padding)
48 | = some_scale * (default crop_size * (1.0 +
49 | inner_padding_factor))
50 | Returns:
51 | ----------
52 | @reference_5point: 5x2 np.array
53 | each row is a pair of transformed coordinates (x, y)
54 | """
55 |
56 | tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
57 | tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
58 |
59 | # 0) make the inner region a square
60 | if default_square:
61 | size_diff = max(tmp_crop_size) - tmp_crop_size
62 | tmp_5pts += size_diff / 2
63 | tmp_crop_size += size_diff
64 |
65 | if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]):
66 |
67 | return tmp_5pts
68 |
69 | if (inner_padding_factor == 0 and outer_padding == (0, 0)):
70 | if output_size is None:
71 | return tmp_5pts
72 | else:
73 | raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
74 |
75 | # check output size
76 | if not (0 <= inner_padding_factor <= 1.0):
77 | raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
78 |
79 | if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None):
80 | output_size = tmp_crop_size * \
81 | (1 + inner_padding_factor * 2).astype(np.int32)
82 | output_size += np.array(outer_padding)
83 | if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]):
84 | raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])')
85 |
86 | # 1) pad the inner region according inner_padding_factor
87 | if inner_padding_factor > 0:
88 | size_diff = tmp_crop_size * inner_padding_factor * 2
89 | tmp_5pts += size_diff / 2
90 | tmp_crop_size += np.round(size_diff).astype(np.int32)
91 |
92 | # 2) resize the padded inner region
93 | size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
94 |
95 | if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
96 | raise FaceWarpException('Must have (output_size - outer_padding)'
97 | '= some_scale * (crop_size * (1.0 + inner_padding_factor)')
98 |
99 | scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
100 | tmp_5pts = tmp_5pts * scale_factor
101 | # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
102 | # tmp_5pts = tmp_5pts + size_diff / 2
103 | tmp_crop_size = size_bf_outer_pad
104 |
105 | # 3) add outer_padding to make output_size
106 | reference_5point = tmp_5pts + np.array(outer_padding)
107 | tmp_crop_size = output_size
108 |
109 | return reference_5point
110 |
111 |
112 | def get_affine_transform_matrix(src_pts, dst_pts):
113 | """
114 | Function:
115 | ----------
116 | get affine transform matrix 'tfm' from src_pts to dst_pts
117 | Parameters:
118 | ----------
119 | @src_pts: Kx2 np.array
120 | source points matrix, each row is a pair of coordinates (x, y)
121 | @dst_pts: Kx2 np.array
122 | destination points matrix, each row is a pair of coordinates (x, y)
123 | Returns:
124 | ----------
125 | @tfm: 2x3 np.array
126 | transform matrix from src_pts to dst_pts
127 | """
128 |
129 | tfm = np.float32([[1, 0, 0], [0, 1, 0]])
130 | n_pts = src_pts.shape[0]
131 | ones = np.ones((n_pts, 1), src_pts.dtype)
132 | src_pts_ = np.hstack([src_pts, ones])
133 | dst_pts_ = np.hstack([dst_pts, ones])
134 |
135 | A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
136 |
137 | if rank == 3:
138 | tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
139 | elif rank == 2:
140 | tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])
141 |
142 | return tfm
143 |
144 |
145 | def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='similarity'):
146 | """
147 | Function:
148 | ----------
149 | warp src_img so that facial_pts align with reference_pts, then crop to crop_size
150 | Parameters:
151 | ----------
152 | @src_img: HxWxC np.array
153 | input image
154 | @facial_pts: could be
155 | 1)a list of K coordinates (x,y)
156 | or
157 | 2) Kx2 or 2xK np.array
158 | each row or col is a pair of coordinates (x, y)
159 | @reference_pts: could be
160 | 1) a list of K coordinates (x,y)
161 | or
162 | 2) Kx2 or 2xK np.array
163 | each row or col is a pair of coordinates (x, y)
164 | or
165 | 3) None
166 | if None, use default reference facial points
167 | @crop_size: (w, h)
168 | output face image size
169 | @align_type: transform type, could be one of
170 | 1) 'similarity': use similarity transform
171 | 2) 'cv2_affine': use the first 3 points to do affine transform,
172 | by calling cv2.getAffineTransform()
173 | 3) 'affine': use all points to do affine transform
174 | Returns:
175 | ----------
176 | @face_img: output face image with size (w, h) = @crop_size
177 | """
178 |
179 | if reference_pts is None:
180 | if crop_size[0] == 96 and crop_size[1] == 112:
181 | reference_pts = REFERENCE_FACIAL_POINTS
182 | else:
183 | default_square = False
184 | inner_padding_factor = 0
185 | outer_padding = (0, 0)
186 | output_size = crop_size
187 |
188 | reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding,
189 | default_square)
190 |
191 | ref_pts = np.float32(reference_pts)
192 | ref_pts_shp = ref_pts.shape
193 | if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
194 | raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2')
195 |
196 | if ref_pts_shp[0] == 2:
197 | ref_pts = ref_pts.T
198 |
199 | src_pts = np.float32(facial_pts)
200 | src_pts_shp = src_pts.shape
201 | if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
202 | raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2')
203 |
204 | if src_pts_shp[0] == 2:
205 | src_pts = src_pts.T
206 |
207 | if src_pts.shape != ref_pts.shape:
208 | raise FaceWarpException('facial_pts and reference_pts must have the same shape')
209 |
210 | if align_type == 'cv2_affine':
211 | tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
212 | elif align_type == 'affine':
213 | tfm = get_affine_transform_matrix(src_pts, ref_pts)
214 | else:
215 | tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)
216 |
217 | face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))
218 |
219 | return face_img
220 |
--------------------------------------------------------------------------------
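A minimal usage sketch for the alignment helpers above. The image path comes from the repository's assets folder, but the five landmark coordinates are placeholder values for illustration only, not taken from the repository.

import cv2
import numpy as np

from facexlib.detection.align_trans import get_reference_facial_points, warp_and_crop_face

img = cv2.imread('assets/test.jpg')  # BGR image
# five landmarks (x, y): left eye, right eye, nose, left mouth corner, right mouth corner
facial5points = np.array([[182., 205.], [262., 202.], [222., 255.], [192., 296.], [258., 293.]])

reference = get_reference_facial_points(default_square=True)  # square template, used with a 112x112 crop
face = warp_and_crop_face(img, facial5points, reference_pts=reference, crop_size=(112, 112))
cv2.imwrite('face_112.png', face)
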
/facexlib/detection/matlab_cp2tform.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy.linalg import inv, lstsq
3 | from numpy.linalg import matrix_rank as rank
4 | from numpy.linalg import norm
5 |
6 |
7 | class MatlabCp2tormException(Exception):
8 |
9 | def __str__(self):
10 |         return 'In File {}:{}'.format(__file__, super().__str__())
11 |
12 |
13 | def tformfwd(trans, uv):
14 | """
15 | Function:
16 | ----------
17 | apply affine transform 'trans' to uv
18 |
19 | Parameters:
20 | ----------
21 | @trans: 3x3 np.array
22 | transform matrix
23 | @uv: Kx2 np.array
24 | each row is a pair of coordinates (x, y)
25 |
26 | Returns:
27 | ----------
28 | @xy: Kx2 np.array
29 | each row is a pair of transformed coordinates (x, y)
30 | """
31 | uv = np.hstack((uv, np.ones((uv.shape[0], 1))))
32 | xy = np.dot(uv, trans)
33 | xy = xy[:, 0:-1]
34 | return xy
35 |
36 |
37 | def tforminv(trans, uv):
38 | """
39 | Function:
40 | ----------
41 | apply the inverse of affine transform 'trans' to uv
42 |
43 | Parameters:
44 | ----------
45 | @trans: 3x3 np.array
46 | transform matrix
47 | @uv: Kx2 np.array
48 | each row is a pair of coordinates (x, y)
49 |
50 | Returns:
51 | ----------
52 | @xy: Kx2 np.array
53 | each row is a pair of inverse-transformed coordinates (x, y)
54 | """
55 | Tinv = inv(trans)
56 | xy = tformfwd(Tinv, uv)
57 | return xy
58 |
59 |
60 | def findNonreflectiveSimilarity(uv, xy, options=None):
61 |     options = {'K': 2} if options is None else options
62 |
63 | K = options['K']
64 | M = xy.shape[0]
65 | x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
66 | y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
67 |
68 | tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
69 | tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
70 | X = np.vstack((tmp1, tmp2))
71 |
72 | u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
73 | v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
74 | U = np.vstack((u, v))
75 |
76 | # We know that X * r = U
77 | if rank(X) >= 2 * K:
78 | r, _, _, _ = lstsq(X, U, rcond=-1)
79 | r = np.squeeze(r)
80 | else:
81 | raise Exception('cp2tform:twoUniquePointsReq')
82 | sc = r[0]
83 | ss = r[1]
84 | tx = r[2]
85 | ty = r[3]
86 |
87 | Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
88 | T = inv(Tinv)
89 | T[:, 2] = np.array([0, 0, 1])
90 |
91 | return T, Tinv
92 |
93 |
94 | def findSimilarity(uv, xy, options=None):
95 |     options = {'K': 2} if options is None else options
96 |
97 | # uv = np.array(uv)
98 | # xy = np.array(xy)
99 |
100 | # Solve for trans1
101 | trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
102 |
103 | # Solve for trans2
104 |
105 | # manually reflect the xy data across the Y-axis
106 |     xyR = xy.copy()  # use a copy so the original xy is not reflected in place
107 | xyR[:, 0] = -1 * xyR[:, 0]
108 |
109 | trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
110 |
111 | # manually reflect the tform to undo the reflection done on xyR
112 | TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
113 |
114 | trans2 = np.dot(trans2r, TreflectY)
115 |
116 | # Figure out if trans1 or trans2 is better
117 | xy1 = tformfwd(trans1, uv)
118 | norm1 = norm(xy1 - xy)
119 |
120 | xy2 = tformfwd(trans2, uv)
121 | norm2 = norm(xy2 - xy)
122 |
123 | if norm1 <= norm2:
124 | return trans1, trans1_inv
125 | else:
126 | trans2_inv = inv(trans2)
127 | return trans2, trans2_inv
128 |
129 |
130 | def get_similarity_transform(src_pts, dst_pts, reflective=True):
131 | """
132 | Function:
133 | ----------
134 | Find Similarity Transform Matrix 'trans':
135 | u = src_pts[:, 0]
136 | v = src_pts[:, 1]
137 | x = dst_pts[:, 0]
138 | y = dst_pts[:, 1]
139 | [x, y, 1] = [u, v, 1] * trans
140 |
141 | Parameters:
142 | ----------
143 | @src_pts: Kx2 np.array
144 | source points, each row is a pair of coordinates (x, y)
145 | @dst_pts: Kx2 np.array
146 | destination points, each row is a pair of transformed
147 | coordinates (x, y)
148 | @reflective: True or False
149 | if True:
150 | use reflective similarity transform
151 | else:
152 | use non-reflective similarity transform
153 |
154 | Returns:
155 | ----------
156 | @trans: 3x3 np.array
157 | transform matrix from uv to xy
158 | trans_inv: 3x3 np.array
159 | inverse of trans, transform matrix from xy to uv
160 | """
161 |
162 | if reflective:
163 | trans, trans_inv = findSimilarity(src_pts, dst_pts)
164 | else:
165 | trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
166 |
167 | return trans, trans_inv
168 |
169 |
170 | def cvt_tform_mat_for_cv2(trans):
171 | """
172 | Function:
173 | ----------
174 | Convert Transform Matrix 'trans' into 'cv2_trans' which could be
175 | directly used by cv2.warpAffine():
176 | u = src_pts[:, 0]
177 | v = src_pts[:, 1]
178 | x = dst_pts[:, 0]
179 | y = dst_pts[:, 1]
180 | [x, y].T = cv_trans * [u, v, 1].T
181 |
182 | Parameters:
183 | ----------
184 | @trans: 3x3 np.array
185 | transform matrix from uv to xy
186 |
187 | Returns:
188 | ----------
189 | @cv2_trans: 2x3 np.array
190 | transform matrix from src_pts to dst_pts, could be directly used
191 | for cv2.warpAffine()
192 | """
193 | cv2_trans = trans[:, 0:2].T
194 |
195 | return cv2_trans
196 |
197 |
198 | def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
199 | """
200 | Function:
201 | ----------
202 | Find Similarity Transform Matrix 'cv2_trans' which could be
203 | directly used by cv2.warpAffine():
204 | u = src_pts[:, 0]
205 | v = src_pts[:, 1]
206 | x = dst_pts[:, 0]
207 | y = dst_pts[:, 1]
208 | [x, y].T = cv_trans * [u, v, 1].T
209 |
210 | Parameters:
211 | ----------
212 | @src_pts: Kx2 np.array
213 | source points, each row is a pair of coordinates (x, y)
214 | @dst_pts: Kx2 np.array
215 | destination points, each row is a pair of transformed
216 | coordinates (x, y)
217 | reflective: True or False
218 | if True:
219 | use reflective similarity transform
220 | else:
221 | use non-reflective similarity transform
222 |
223 | Returns:
224 | ----------
225 | @cv2_trans: 2x3 np.array
226 | transform matrix from src_pts to dst_pts, could be directly used
227 | for cv2.warpAffine()
228 | """
229 | trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
230 | cv2_trans = cvt_tform_mat_for_cv2(trans)
231 |
232 | return cv2_trans
233 |
234 |
235 | if __name__ == '__main__':
236 | """
237 | u = [0, 6, -2]
238 | v = [0, 3, 5]
239 | x = [-1, 0, 4]
240 | y = [-1, -10, 4]
241 |
242 | # In Matlab, run:
243 | #
244 | # uv = [u'; v'];
245 | # xy = [x'; y'];
246 | # tform_sim=cp2tform(uv,xy,'similarity');
247 | #
248 | # trans = tform_sim.tdata.T
249 | # ans =
250 | # -0.0764 -1.6190 0
251 | # 1.6190 -0.0764 0
252 | # -3.2156 0.0290 1.0000
253 | # trans_inv = tform_sim.tdata.Tinv
254 | # ans =
255 | #
256 | # -0.0291 0.6163 0
257 | # -0.6163 -0.0291 0
258 | # -0.0756 1.9826 1.0000
259 | # xy_m=tformfwd(tform_sim, u,v)
260 | #
261 | # xy_m =
262 | #
263 | # -3.2156 0.0290
264 | # 1.1833 -9.9143
265 | # 5.0323 2.8853
266 | # uv_m=tforminv(tform_sim, x,y)
267 | #
268 | # uv_m =
269 | #
270 | # 0.5698 1.3953
271 | # 6.0872 2.2733
272 | # -2.6570 4.3314
273 | """
274 | u = [0, 6, -2]
275 | v = [0, 3, 5]
276 | x = [-1, 0, 4]
277 | y = [-1, -10, 4]
278 |
279 | uv = np.array((u, v)).T
280 | xy = np.array((x, y)).T
281 |
282 | print('\n--->uv:')
283 | print(uv)
284 | print('\n--->xy:')
285 | print(xy)
286 |
287 | trans, trans_inv = get_similarity_transform(uv, xy)
288 |
289 | print('\n--->trans matrix:')
290 | print(trans)
291 |
292 | print('\n--->trans_inv matrix:')
293 | print(trans_inv)
294 |
295 | print('\n---> apply transform to uv')
296 | print('\nxy_m = uv_augmented * trans')
297 | uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1))))
298 | xy_m = np.dot(uv_aug, trans)
299 | print(xy_m)
300 |
301 | print('\nxy_m = tformfwd(trans, uv)')
302 | xy_m = tformfwd(trans, uv)
303 | print(xy_m)
304 |
305 | print('\n---> apply inverse transform to xy')
306 | print('\nuv_m = xy_augmented * trans_inv')
307 | xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1))))
308 | uv_m = np.dot(xy_aug, trans_inv)
309 | print(uv_m)
310 |
311 | print('\nuv_m = tformfwd(trans_inv, xy)')
312 | uv_m = tformfwd(trans_inv, xy)
313 | print(uv_m)
314 |
315 | uv_m = tforminv(trans, xy)
316 | print('\nuv_m = tforminv(trans, xy)')
317 | print(uv_m)
318 |
--------------------------------------------------------------------------------
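A small sketch of how the 2x3 matrix from get_similarity_transform_for_cv2 is typically consumed with cv2.warpAffine; the point sets below are arbitrary illustrative values.

import cv2
import numpy as np

from facexlib.detection.matlab_cp2tform import get_similarity_transform_for_cv2

src_pts = np.array([[30., 40.], [70., 38.], [50., 60.], [35., 80.], [65., 78.]], dtype=np.float32)
dst_pts = np.array([[38., 51.], [73., 51.], [56., 71.], [41., 92.], [70., 92.]], dtype=np.float32)

cv2_trans = get_similarity_transform_for_cv2(src_pts, dst_pts)  # 2x3, ready for warpAffine
warped = cv2.warpAffine(np.zeros((112, 112, 3), dtype=np.uint8), cv2_trans, (112, 112))
print(cv2_trans)
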
/facexlib/detection/retinaface.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from PIL import Image
7 | from torchvision.models._utils import IntermediateLayerGetter
8 |
9 | from facexlib.detection.align_trans import get_reference_facial_points, warp_and_crop_face
10 | from facexlib.detection.retinaface_net import FPN, SSH, MobileNetV1, make_bbox_head, make_class_head, make_landmark_head
11 | from facexlib.detection.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm,
12 | py_cpu_nms)
13 |
14 |
15 | def generate_config(network_name):
16 |
17 | cfg_mnet = {
18 | 'name': 'mobilenet0.25',
19 | 'min_sizes': [[16, 32], [64, 128], [256, 512]],
20 | 'steps': [8, 16, 32],
21 | 'variance': [0.1, 0.2],
22 | 'clip': False,
23 | 'loc_weight': 2.0,
24 | 'gpu_train': True,
25 | 'batch_size': 32,
26 | 'ngpu': 1,
27 | 'epoch': 250,
28 | 'decay1': 190,
29 | 'decay2': 220,
30 | 'image_size': 640,
31 | 'return_layers': {
32 | 'stage1': 1,
33 | 'stage2': 2,
34 | 'stage3': 3
35 | },
36 | 'in_channel': 32,
37 | 'out_channel': 64
38 | }
39 |
40 | cfg_re50 = {
41 | 'name': 'Resnet50',
42 | 'min_sizes': [[16, 32], [64, 128], [256, 512]],
43 | 'steps': [8, 16, 32],
44 | 'variance': [0.1, 0.2],
45 | 'clip': False,
46 | 'loc_weight': 2.0,
47 | 'gpu_train': True,
48 | 'batch_size': 24,
49 | 'ngpu': 4,
50 | 'epoch': 100,
51 | 'decay1': 70,
52 | 'decay2': 90,
53 | 'image_size': 840,
54 | 'return_layers': {
55 | 'layer2': 1,
56 | 'layer3': 2,
57 | 'layer4': 3
58 | },
59 | 'in_channel': 256,
60 | 'out_channel': 256
61 | }
62 |
63 | if network_name == 'mobile0.25':
64 | return cfg_mnet
65 | elif network_name == 'resnet50':
66 | return cfg_re50
67 | else:
68 | raise NotImplementedError(f'network_name={network_name}')
69 |
70 |
71 | class RetinaFace(nn.Module):
72 |
73 | def __init__(self, network_name='resnet50', half=False, phase='test', device=None):
74 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
75 |
76 | super(RetinaFace, self).__init__()
77 | self.half_inference = half
78 | cfg = generate_config(network_name)
79 | self.backbone = cfg['name']
80 |
81 | self.model_name = f'retinaface_{network_name}'
82 | self.cfg = cfg
83 | self.phase = phase
84 | self.target_size, self.max_size = 1600, 2150
85 | self.resize, self.scale, self.scale1 = 1., None, None
86 | self.mean_tensor = torch.tensor([[[[104.]], [[117.]], [[123.]]]], device=self.device)
87 | self.reference = get_reference_facial_points(default_square=True)
88 | # Build network.
89 | backbone = None
90 | if cfg['name'] == 'mobilenet0.25':
91 | backbone = MobileNetV1()
92 | self.body = IntermediateLayerGetter(backbone, cfg['return_layers'])
93 | elif cfg['name'] == 'Resnet50':
94 | import torchvision.models as models
95 | backbone = models.resnet50(pretrained=False)
96 | self.body = IntermediateLayerGetter(backbone, cfg['return_layers'])
97 |
98 | in_channels_stage2 = cfg['in_channel']
99 | in_channels_list = [
100 | in_channels_stage2 * 2,
101 | in_channels_stage2 * 4,
102 | in_channels_stage2 * 8,
103 | ]
104 |
105 | out_channels = cfg['out_channel']
106 | self.fpn = FPN(in_channels_list, out_channels)
107 | self.ssh1 = SSH(out_channels, out_channels)
108 | self.ssh2 = SSH(out_channels, out_channels)
109 | self.ssh3 = SSH(out_channels, out_channels)
110 |
111 | self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
112 | self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
113 | self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])
114 |
115 | self.to(self.device)
116 | self.eval()
117 | if self.half_inference:
118 | self.half()
119 |
120 | def forward(self, inputs):
121 | out = self.body(inputs)
122 |
123 | if self.backbone == 'mobilenet0.25' or self.backbone == 'Resnet50':
124 | out = list(out.values())
125 | # FPN
126 | fpn = self.fpn(out)
127 |
128 | # SSH
129 | feature1 = self.ssh1(fpn[0])
130 | feature2 = self.ssh2(fpn[1])
131 | feature3 = self.ssh3(fpn[2])
132 | features = [feature1, feature2, feature3]
133 |
134 | bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
135 | classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1)
136 | tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)]
137 | ldm_regressions = (torch.cat(tmp, dim=1))
138 |
139 | if self.phase == 'train':
140 | output = (bbox_regressions, classifications, ldm_regressions)
141 | else:
142 | output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
143 | return output
144 |
145 | def __detect_faces(self, inputs):
146 | # get scale
147 | height, width = inputs.shape[2:]
148 | self.scale = torch.tensor([width, height, width, height], dtype=torch.float32, device=self.device)
149 | tmp = [width, height, width, height, width, height, width, height, width, height]
150 | self.scale1 = torch.tensor(tmp, dtype=torch.float32, device=self.device)
151 |
152 |         # forward
153 | inputs = inputs.to(self.device)
154 | if self.half_inference:
155 | inputs = inputs.half()
156 | loc, conf, landmarks = self(inputs)
157 |
158 | # get priorbox
159 | priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:])
160 | priors = priorbox.forward().to(self.device)
161 |
162 | return loc, conf, landmarks, priors
163 |
164 | # single image detection
165 | def transform(self, image, use_origin_size):
166 | # convert to opencv format
167 | if isinstance(image, Image.Image):
168 | image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
169 | image = image.astype(np.float32)
170 |
171 | # testing scale
172 | im_size_min = np.min(image.shape[0:2])
173 | im_size_max = np.max(image.shape[0:2])
174 | resize = float(self.target_size) / float(im_size_min)
175 |
176 | # prevent bigger axis from being more than max_size
177 | if np.round(resize * im_size_max) > self.max_size:
178 | resize = float(self.max_size) / float(im_size_max)
179 | resize = 1 if use_origin_size else resize
180 |
181 | # resize
182 | if resize != 1:
183 | image = cv2.resize(image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
184 |
185 | # convert to torch.tensor format
186 | # image -= (104, 117, 123)
187 | image = image.transpose(2, 0, 1)
188 | image = torch.from_numpy(image).unsqueeze(0)
189 |
190 | return image, resize
191 |
192 | def detect_faces(
193 | self,
194 | image,
195 | conf_threshold=0.8,
196 | nms_threshold=0.4,
197 | use_origin_size=True,
198 | ):
199 | image, self.resize = self.transform(image, use_origin_size)
200 | image = image.to(self.device)
201 | if self.half_inference:
202 | image = image.half()
203 | image = image - self.mean_tensor
204 |
205 | loc, conf, landmarks, priors = self.__detect_faces(image)
206 |
207 | boxes = decode(loc.data.squeeze(0), priors.data, self.cfg['variance'])
208 | boxes = boxes * self.scale / self.resize
209 | boxes = boxes.cpu().numpy()
210 |
211 | scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
212 |
213 | landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg['variance'])
214 | landmarks = landmarks * self.scale1 / self.resize
215 | landmarks = landmarks.cpu().numpy()
216 |
217 | # ignore low scores
218 | inds = np.where(scores > conf_threshold)[0]
219 | boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds]
220 |
221 | # sort
222 | order = scores.argsort()[::-1]
223 | boxes, landmarks, scores = boxes[order], landmarks[order], scores[order]
224 |
225 | # do NMS
226 | bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
227 | keep = py_cpu_nms(bounding_boxes, nms_threshold)
228 | bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep]
229 | # self.t['forward_pass'].toc()
230 | # print(self.t['forward_pass'].average_time)
231 | # import sys
232 | # sys.stdout.flush()
233 | return np.concatenate((bounding_boxes, landmarks), axis=1)
234 |
235 | def __align_multi(self, image, boxes, landmarks, limit=None):
236 |
237 | if len(boxes) < 1:
238 | return [], []
239 |
240 | if limit:
241 | boxes = boxes[:limit]
242 | landmarks = landmarks[:limit]
243 |
244 | faces = []
245 | for landmark in landmarks:
246 | facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)]
247 |
248 | warped_face = warp_and_crop_face(np.array(image), facial5points, self.reference, crop_size=(112, 112))
249 | faces.append(warped_face)
250 |
251 | return np.concatenate((boxes, landmarks), axis=1), faces
252 |
253 | def align_multi(self, img, conf_threshold=0.8, limit=None):
254 |
255 | rlt = self.detect_faces(img, conf_threshold=conf_threshold)
256 | boxes, landmarks = rlt[:, 0:5], rlt[:, 5:]
257 |
258 | return self.__align_multi(img, boxes, landmarks, limit)
259 |
260 | # batched detection
261 | def batched_transform(self, frames, use_origin_size):
262 | """
263 | Arguments:
264 | frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c],
265 | type=np.float32, BGR format).
266 | use_origin_size: whether to use origin size.
267 | """
268 |         from_PIL = isinstance(frames[0], Image.Image)
269 |
270 | # convert to opencv format
271 | if from_PIL:
272 | frames = [cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames]
273 | frames = np.asarray(frames, dtype=np.float32)
274 |
275 | # testing scale
276 | im_size_min = np.min(frames[0].shape[0:2])
277 | im_size_max = np.max(frames[0].shape[0:2])
278 | resize = float(self.target_size) / float(im_size_min)
279 |
280 | # prevent bigger axis from being more than max_size
281 | if np.round(resize * im_size_max) > self.max_size:
282 | resize = float(self.max_size) / float(im_size_max)
283 | resize = 1 if use_origin_size else resize
284 |
285 | # resize
286 | if resize != 1:
287 | if not from_PIL:
288 | frames = F.interpolate(frames, scale_factor=resize)
289 | else:
290 | frames = [
291 | cv2.resize(frame, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
292 | for frame in frames
293 | ]
294 |
295 | # convert to torch.tensor format
296 | if not from_PIL:
297 | frames = frames.transpose(1, 2).transpose(1, 3).contiguous()
298 | else:
299 | frames = frames.transpose((0, 3, 1, 2))
300 | frames = torch.from_numpy(frames)
301 |
302 | return frames, resize
303 |
304 | def batched_detect_faces(self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True):
305 | """
306 | Arguments:
307 | frames: a list of PIL.Image, or np.array(shape=[n, h, w, c],
308 | type=np.uint8, BGR format).
309 | conf_threshold: confidence threshold.
310 | nms_threshold: nms threshold.
311 | use_origin_size: whether to use origin size.
312 | Returns:
313 | final_bounding_boxes: list of np.array ([n_boxes, 5],
314 | type=np.float32).
315 | final_landmarks: list of np.array ([n_boxes, 10], type=np.float32).
316 | """
317 | # self.t['forward_pass'].tic()
318 | frames, self.resize = self.batched_transform(frames, use_origin_size)
319 | frames = frames.to(self.device)
320 | frames = frames - self.mean_tensor
321 |
322 | b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames)
323 |
324 | final_bounding_boxes, final_landmarks = [], []
325 |
326 | # decode
327 | priors = priors.unsqueeze(0)
328 | b_loc = batched_decode(b_loc, priors, self.cfg['variance']) * self.scale / self.resize
329 | b_landmarks = batched_decode_landm(b_landmarks, priors, self.cfg['variance']) * self.scale1 / self.resize
330 | b_conf = b_conf[:, :, 1]
331 |
332 | # index for selection
333 | b_indice = b_conf > conf_threshold
334 |
335 | # concat
336 | b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float()
337 |
338 | for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice):
339 |
340 | # ignore low scores
341 | pred, landm = pred[inds, :], landm[inds, :]
342 | if pred.shape[0] == 0:
343 | final_bounding_boxes.append(np.array([], dtype=np.float32))
344 | final_landmarks.append(np.array([], dtype=np.float32))
345 | continue
346 |
347 | # sort
348 | # order = score.argsort(descending=True)
349 | # box, landm, score = box[order], landm[order], score[order]
350 |
351 | # to CPU
352 | bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy()
353 |
354 | # NMS
355 | keep = py_cpu_nms(bounding_boxes, nms_threshold)
356 | bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep]
357 |
358 | # append
359 | final_bounding_boxes.append(bounding_boxes)
360 | final_landmarks.append(landmarks)
361 | # self.t['forward_pass'].toc(average=True)
362 | # self.batch_time += self.t['forward_pass'].diff
363 | # self.total_frame += len(frames)
364 | # print(self.batch_time / self.total_frame)
365 |
366 | return final_bounding_boxes, final_landmarks
367 |
--------------------------------------------------------------------------------
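A hedged end-to-end detection sketch. It assumes facexlib/detection/__init__.py exposes an init_detection_model loader with a 'retinaface_resnet50' entry, analogous to the init_* helpers shown for headpose, matting and parsing; if it does not, construct RetinaFace directly and load the weights yourself.

import cv2
import torch

from facexlib.detection import init_detection_model

det_net = init_detection_model('retinaface_resnet50', half=False)
img = cv2.imread('assets/test.jpg')  # BGR, as expected by detect_faces

with torch.no_grad():
    dets = det_net.detect_faces(img, conf_threshold=0.8)
# each row: [x1, y1, x2, y2, score] followed by five landmark (x, y) pairs
for det in dets:
    box, score, landmarks = det[0:4], det[4], det[5:15]
    print(box, score)
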
/facexlib/detection/retinaface_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def conv_bn(inp, oup, stride=1, leaky=0):
7 | return nn.Sequential(
8 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup),
9 | nn.LeakyReLU(negative_slope=leaky, inplace=True))
10 |
11 |
12 | def conv_bn_no_relu(inp, oup, stride):
13 | return nn.Sequential(
14 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
15 | nn.BatchNorm2d(oup),
16 | )
17 |
18 |
19 | def conv_bn1X1(inp, oup, stride, leaky=0):
20 | return nn.Sequential(
21 | nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup),
22 | nn.LeakyReLU(negative_slope=leaky, inplace=True))
23 |
24 |
25 | def conv_dw(inp, oup, stride, leaky=0.1):
26 | return nn.Sequential(
27 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
28 | nn.BatchNorm2d(inp),
29 | nn.LeakyReLU(negative_slope=leaky, inplace=True),
30 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
31 | nn.BatchNorm2d(oup),
32 | nn.LeakyReLU(negative_slope=leaky, inplace=True),
33 | )
34 |
35 |
36 | class SSH(nn.Module):
37 |
38 | def __init__(self, in_channel, out_channel):
39 | super(SSH, self).__init__()
40 | assert out_channel % 4 == 0
41 | leaky = 0
42 | if (out_channel <= 64):
43 | leaky = 0.1
44 | self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)
45 |
46 | self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky)
47 | self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
48 |
49 | self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky)
50 | self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
51 |
52 | def forward(self, input):
53 | conv3X3 = self.conv3X3(input)
54 |
55 | conv5X5_1 = self.conv5X5_1(input)
56 | conv5X5 = self.conv5X5_2(conv5X5_1)
57 |
58 | conv7X7_2 = self.conv7X7_2(conv5X5_1)
59 | conv7X7 = self.conv7x7_3(conv7X7_2)
60 |
61 | out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
62 | out = F.relu(out)
63 | return out
64 |
65 |
66 | class FPN(nn.Module):
67 |
68 | def __init__(self, in_channels_list, out_channels):
69 | super(FPN, self).__init__()
70 | leaky = 0
71 | if (out_channels <= 64):
72 | leaky = 0.1
73 | self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
74 | self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
75 | self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)
76 |
77 | self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
78 | self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)
79 |
80 | def forward(self, input):
81 | # names = list(input.keys())
82 | # input = list(input.values())
83 |
84 | output1 = self.output1(input[0])
85 | output2 = self.output2(input[1])
86 | output3 = self.output3(input[2])
87 |
88 | up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest')
89 | output2 = output2 + up3
90 | output2 = self.merge2(output2)
91 |
92 | up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest')
93 | output1 = output1 + up2
94 | output1 = self.merge1(output1)
95 |
96 | out = [output1, output2, output3]
97 | return out
98 |
99 |
100 | class MobileNetV1(nn.Module):
101 |
102 | def __init__(self):
103 | super(MobileNetV1, self).__init__()
104 | self.stage1 = nn.Sequential(
105 | conv_bn(3, 8, 2, leaky=0.1), # 3
106 | conv_dw(8, 16, 1), # 7
107 | conv_dw(16, 32, 2), # 11
108 | conv_dw(32, 32, 1), # 19
109 | conv_dw(32, 64, 2), # 27
110 | conv_dw(64, 64, 1), # 43
111 | )
112 | self.stage2 = nn.Sequential(
113 | conv_dw(64, 128, 2), # 43 + 16 = 59
114 | conv_dw(128, 128, 1), # 59 + 32 = 91
115 | conv_dw(128, 128, 1), # 91 + 32 = 123
116 | conv_dw(128, 128, 1), # 123 + 32 = 155
117 | conv_dw(128, 128, 1), # 155 + 32 = 187
118 | conv_dw(128, 128, 1), # 187 + 32 = 219
119 | )
120 | self.stage3 = nn.Sequential(
121 | conv_dw(128, 256, 2), # 219 +3 2 = 241
122 | conv_dw(256, 256, 1), # 241 + 64 = 301
123 | )
124 | self.avg = nn.AdaptiveAvgPool2d((1, 1))
125 | self.fc = nn.Linear(256, 1000)
126 |
127 | def forward(self, x):
128 | x = self.stage1(x)
129 | x = self.stage2(x)
130 | x = self.stage3(x)
131 | x = self.avg(x)
132 | # x = self.model(x)
133 | x = x.view(-1, 256)
134 | x = self.fc(x)
135 | return x
136 |
137 |
138 | class ClassHead(nn.Module):
139 |
140 | def __init__(self, inchannels=512, num_anchors=3):
141 | super(ClassHead, self).__init__()
142 | self.num_anchors = num_anchors
143 | self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0)
144 |
145 | def forward(self, x):
146 | out = self.conv1x1(x)
147 | out = out.permute(0, 2, 3, 1).contiguous()
148 |
149 | return out.view(out.shape[0], -1, 2)
150 |
151 |
152 | class BboxHead(nn.Module):
153 |
154 | def __init__(self, inchannels=512, num_anchors=3):
155 | super(BboxHead, self).__init__()
156 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0)
157 |
158 | def forward(self, x):
159 | out = self.conv1x1(x)
160 | out = out.permute(0, 2, 3, 1).contiguous()
161 |
162 | return out.view(out.shape[0], -1, 4)
163 |
164 |
165 | class LandmarkHead(nn.Module):
166 |
167 | def __init__(self, inchannels=512, num_anchors=3):
168 | super(LandmarkHead, self).__init__()
169 | self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0)
170 |
171 | def forward(self, x):
172 | out = self.conv1x1(x)
173 | out = out.permute(0, 2, 3, 1).contiguous()
174 |
175 | return out.view(out.shape[0], -1, 10)
176 |
177 |
178 | def make_class_head(fpn_num=3, inchannels=64, anchor_num=2):
179 | classhead = nn.ModuleList()
180 | for i in range(fpn_num):
181 | classhead.append(ClassHead(inchannels, anchor_num))
182 | return classhead
183 |
184 |
185 | def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2):
186 | bboxhead = nn.ModuleList()
187 | for i in range(fpn_num):
188 | bboxhead.append(BboxHead(inchannels, anchor_num))
189 | return bboxhead
190 |
191 |
192 | def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2):
193 | landmarkhead = nn.ModuleList()
194 | for i in range(fpn_num):
195 | landmarkhead.append(LandmarkHead(inchannels, anchor_num))
196 | return landmarkhead
197 |
--------------------------------------------------------------------------------
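A quick shape check for the modules above (random weights, a single SSH reused across levels just for the check): three 64-channel feature maps at strides 8/16/32 of a 640x640 input give 2 anchors per location, i.e. 16800 predictions in total.

import torch

from facexlib.detection.retinaface_net import SSH, make_bbox_head, make_class_head, make_landmark_head

feats = [torch.randn(1, 64, 80, 80), torch.randn(1, 64, 40, 40), torch.randn(1, 64, 20, 20)]
ssh = SSH(64, 64)
class_head = make_class_head(fpn_num=3, inchannels=64, anchor_num=2)
bbox_head = make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2)
ldm_head = make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2)

outs = [ssh(f) for f in feats]
cls = torch.cat([class_head[i](f) for i, f in enumerate(outs)], dim=1)
box = torch.cat([bbox_head[i](f) for i, f in enumerate(outs)], dim=1)
ldm = torch.cat([ldm_head[i](f) for i, f in enumerate(outs)], dim=1)
print(cls.shape, box.shape, ldm.shape)  # [1, 16800, 2], [1, 16800, 4], [1, 16800, 10]
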
/facexlib/headpose/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from facexlib.utils import load_file_from_url
4 | from .hopenet_arch import HopeNet
5 |
6 |
7 | def init_headpose_model(model_name, half=False, device='cuda', model_rootpath=None):
8 | if model_name == 'hopenet':
9 | model = HopeNet('resnet', [3, 4, 6, 3], 66)
10 | model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/headpose_hopenet.pth'
11 | else:
12 | raise NotImplementedError(f'{model_name} is not implemented.')
13 |
14 | model_path = load_file_from_url(
15 | url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
16 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage)['params']
17 | model.load_state_dict(load_net, strict=True)
18 | model.eval()
19 | model = model.to(device)
20 | return model
21 |
--------------------------------------------------------------------------------
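A hedged usage sketch for init_headpose_model. The preprocessing follows the common Hopenet recipe (224x224 RGB crop with ImageNet mean/std); the exact pipeline in inference/inference_headpose.py may differ slightly, and the input should ideally be an already-cropped face.

import cv2
import numpy as np
import torch

from facexlib.headpose import init_headpose_model

model = init_headpose_model('hopenet', device='cpu')

face = cv2.imread('assets/test.jpg')
face = cv2.resize(face, (224, 224))
face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
face = (face - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
face = torch.from_numpy(face.astype(np.float32)).permute(2, 0, 1).unsqueeze(0)  # 1x3x224x224

with torch.no_grad():
    yaw, pitch, roll = model(face)
print(float(yaw), float(pitch), float(roll))
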
/facexlib/headpose/hopenet_arch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 |
5 |
6 | class HopeNet(nn.Module):
7 | # Hopenet with 3 output layers for yaw, pitch and roll
8 | # Predicts Euler angles by binning and regression with the expected value
9 | def __init__(self, block, layers, num_bins):
10 | super(HopeNet, self).__init__()
11 | if block == 'resnet':
12 | block = torchvision.models.resnet.Bottleneck
13 | self.inplanes = 64
14 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
15 | self.bn1 = nn.BatchNorm2d(64)
16 | self.relu = nn.ReLU(inplace=True)
17 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
18 | self.layer1 = self._make_layer(block, 64, layers[0])
19 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
20 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
21 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
22 | self.avgpool = nn.AvgPool2d(7)
23 | self.fc_yaw = nn.Linear(512 * block.expansion, num_bins)
24 | self.fc_pitch = nn.Linear(512 * block.expansion, num_bins)
25 | self.fc_roll = nn.Linear(512 * block.expansion, num_bins)
26 |
27 | self.idx_tensor = torch.arange(66).float()
28 |
29 | def _make_layer(self, block, planes, blocks, stride=1):
30 | downsample = None
31 | if stride != 1 or self.inplanes != planes * block.expansion:
32 | downsample = nn.Sequential(
33 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
34 | nn.BatchNorm2d(planes * block.expansion),
35 | )
36 |
37 | layers = []
38 | layers.append(block(self.inplanes, planes, stride, downsample))
39 | self.inplanes = planes * block.expansion
40 | for i in range(1, blocks):
41 | layers.append(block(self.inplanes, planes))
42 | return nn.Sequential(*layers)
43 |
44 | @staticmethod
45 | def softmax_temperature(tensor, temperature):
46 | result = torch.exp(tensor / temperature)
47 | result = torch.div(result, torch.sum(result, 1).unsqueeze(1).expand_as(result))
48 | return result
49 |
50 | def bin2degree(self, predict):
51 | predict = self.softmax_temperature(predict, 1)
52 | return torch.sum(predict * self.idx_tensor.type_as(predict), 1) * 3 - 99
53 |
54 | def forward(self, x):
55 | x = self.relu(self.bn1(self.conv1(x)))
56 | x = self.maxpool(x)
57 |
58 | x = self.layer1(x)
59 | x = self.layer2(x)
60 | x = self.layer3(x)
61 | x = self.layer4(x)
62 |
63 | x = self.avgpool(x)
64 | x = x.view(x.size(0), -1)
65 | pre_yaw = self.fc_yaw(x)
66 | pre_pitch = self.fc_pitch(x)
67 | pre_roll = self.fc_roll(x)
68 |
69 | yaw = self.bin2degree(pre_yaw)
70 | pitch = self.bin2degree(pre_pitch)
71 | roll = self.bin2degree(pre_roll)
72 | return yaw, pitch, roll
73 |
--------------------------------------------------------------------------------
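The decoding in bin2degree is the softmax-weighted expectation over 66 bins of 3 degrees spanning [-99, 99). A tiny check with synthetic logits:

import torch

logits = torch.zeros(1, 66)
logits[0, 33] = 10.0  # put almost all probability mass on bin 33

prob = torch.softmax(logits, dim=1)
idx = torch.arange(66).float()
deg = torch.sum(prob * idx, dim=1) * 3 - 99
print(deg)  # close to 33 * 3 - 99 = 0 degrees
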
/facexlib/matting/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from copy import deepcopy
3 |
4 | from facexlib.utils import load_file_from_url
5 | from .modnet import MODNet
6 |
7 |
8 | def init_matting_model(model_name='modnet', half=False, device='cuda', model_rootpath=None):
9 | if model_name == 'modnet':
10 | model = MODNet(backbone_pretrained=False)
11 | model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/matting_modnet_portrait.pth'
12 | else:
13 | raise NotImplementedError(f'{model_name} is not implemented.')
14 |
15 | model_path = load_file_from_url(
16 | url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
17 | # TODO: clean pretrained model
18 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
19 | # remove unnecessary 'module.'
20 | for k, v in deepcopy(load_net).items():
21 | if k.startswith('module.'):
22 | load_net[k[7:]] = v
23 | load_net.pop(k)
24 | model.load_state_dict(load_net, strict=True)
25 | model.eval()
26 | model = model.to(device)
27 | return model
28 |
--------------------------------------------------------------------------------
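A hedged matting sketch: MODNet takes RGB input normalized to [-1, 1], and 512x512 keeps every internal downsampling clean. The preprocessing only approximates inference/inference_matting.py.

import cv2
import numpy as np
import torch

from facexlib.matting import init_matting_model

modnet = init_matting_model('modnet', device='cpu')

img = cv2.imread('assets/test.jpg')
img = cv2.resize(img, (512, 512))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
img = (img - 0.5) / 0.5  # [-1, 1]
img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)  # 1x3x512x512

with torch.no_grad():
    _, _, matte = modnet(img, inference=True)  # semantic/detail branches return None at inference
cv2.imwrite('matte.png', (matte[0, 0].numpy() * 255).astype(np.uint8))
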
/facexlib/matting/backbone.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 |
5 | from .mobilenetv2 import MobileNetV2
6 |
7 |
8 | class BaseBackbone(nn.Module):
9 | """ Superclass of Replaceable Backbone Model for Semantic Estimation
10 | """
11 |
12 | def __init__(self, in_channels):
13 | super(BaseBackbone, self).__init__()
14 | self.in_channels = in_channels
15 |
16 | self.model = None
17 | self.enc_channels = []
18 |
19 | def forward(self, x):
20 | raise NotImplementedError
21 |
22 | def load_pretrained_ckpt(self):
23 | raise NotImplementedError
24 |
25 |
26 | class MobileNetV2Backbone(BaseBackbone):
27 | """ MobileNetV2 Backbone
28 | """
29 |
30 | def __init__(self, in_channels):
31 | super(MobileNetV2Backbone, self).__init__(in_channels)
32 |
33 | self.model = MobileNetV2(self.in_channels, alpha=1.0, expansion=6, num_classes=None)
34 | self.enc_channels = [16, 24, 32, 96, 1280]
35 |
36 | def forward(self, x):
37 | # x = reduce(lambda x, n: self.model.features[n](x), list(range(0, 2)), x)
38 | x = self.model.features[0](x)
39 | x = self.model.features[1](x)
40 | enc2x = x
41 |
42 | # x = reduce(lambda x, n: self.model.features[n](x), list(range(2, 4)), x)
43 | x = self.model.features[2](x)
44 | x = self.model.features[3](x)
45 | enc4x = x
46 |
47 | # x = reduce(lambda x, n: self.model.features[n](x), list(range(4, 7)), x)
48 | x = self.model.features[4](x)
49 | x = self.model.features[5](x)
50 | x = self.model.features[6](x)
51 | enc8x = x
52 |
53 | # x = reduce(lambda x, n: self.model.features[n](x), list(range(7, 14)), x)
54 | x = self.model.features[7](x)
55 | x = self.model.features[8](x)
56 | x = self.model.features[9](x)
57 | x = self.model.features[10](x)
58 | x = self.model.features[11](x)
59 | x = self.model.features[12](x)
60 | x = self.model.features[13](x)
61 | enc16x = x
62 |
63 | # x = reduce(lambda x, n: self.model.features[n](x), list(range(14, 19)), x)
64 | x = self.model.features[14](x)
65 | x = self.model.features[15](x)
66 | x = self.model.features[16](x)
67 | x = self.model.features[17](x)
68 | x = self.model.features[18](x)
69 | enc32x = x
70 | return [enc2x, enc4x, enc8x, enc16x, enc32x]
71 |
72 | def load_pretrained_ckpt(self):
73 | # the pre-trained model is provided by https://github.com/thuyngch/Human-Segmentation-PyTorch
74 | ckpt_path = './pretrained/mobilenetv2_human_seg.ckpt'
75 | if not os.path.exists(ckpt_path):
76 |             raise FileNotFoundError('cannot find the pretrained mobilenetv2 backbone '
77 |                                     f'at {ckpt_path}')
78 |
79 | ckpt = torch.load(ckpt_path)
80 | self.model.load_state_dict(ckpt)
81 |
--------------------------------------------------------------------------------
/facexlib/matting/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | """ This file is adapted from https://github.com/thuyngch/Human-Segmentation-PyTorch"""
2 |
3 | import math
4 | import torch
5 | from torch import nn
6 |
7 | # ------------------------------------------------------------------------------
8 | # Useful functions
9 | # ------------------------------------------------------------------------------
10 |
11 |
12 | def _make_divisible(v, divisor, min_value=None):
13 | if min_value is None:
14 | min_value = divisor
15 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
16 | # Make sure that round down does not go down by more than 10%.
17 | if new_v < 0.9 * v:
18 | new_v += divisor
19 | return new_v
20 |
21 |
22 | def conv_bn(inp, oup, stride):
23 | return nn.Sequential(nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), nn.ReLU6(inplace=True))
24 |
25 |
26 | def conv_1x1_bn(inp, oup):
27 | return nn.Sequential(nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), nn.ReLU6(inplace=True))
28 |
29 |
30 | # ------------------------------------------------------------------------------
31 | # Class of Inverted Residual block
32 | # ------------------------------------------------------------------------------
33 |
34 |
35 | class InvertedResidual(nn.Module):
36 |
37 | def __init__(self, inp, oup, stride, expansion, dilation=1):
38 | super(InvertedResidual, self).__init__()
39 | self.stride = stride
40 | assert stride in [1, 2]
41 |
42 | hidden_dim = round(inp * expansion)
43 | self.use_res_connect = self.stride == 1 and inp == oup
44 |
45 | if expansion == 1:
46 | self.conv = nn.Sequential(
47 | # dw
48 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
49 | nn.BatchNorm2d(hidden_dim),
50 | nn.ReLU6(inplace=True),
51 | # pw-linear
52 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
53 | nn.BatchNorm2d(oup),
54 | )
55 | else:
56 | self.conv = nn.Sequential(
57 | # pw
58 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
59 | nn.BatchNorm2d(hidden_dim),
60 | nn.ReLU6(inplace=True),
61 | # dw
62 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, dilation=dilation, bias=False),
63 | nn.BatchNorm2d(hidden_dim),
64 | nn.ReLU6(inplace=True),
65 | # pw-linear
66 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
67 | nn.BatchNorm2d(oup),
68 | )
69 |
70 | def forward(self, x):
71 | if self.use_res_connect:
72 | return x + self.conv(x)
73 | else:
74 | return self.conv(x)
75 |
76 |
77 | # ------------------------------------------------------------------------------
78 | # Class of MobileNetV2
79 | # ------------------------------------------------------------------------------
80 |
81 |
82 | class MobileNetV2(nn.Module):
83 |
84 | def __init__(self, in_channels, alpha=1.0, expansion=6, num_classes=1000):
85 | super(MobileNetV2, self).__init__()
86 | self.in_channels = in_channels
87 | self.num_classes = num_classes
88 | input_channel = 32
89 | last_channel = 1280
90 | interverted_residual_setting = [
91 | # t, c, n, s
92 | [1, 16, 1, 1],
93 | [expansion, 24, 2, 2],
94 | [expansion, 32, 3, 2],
95 | [expansion, 64, 4, 2],
96 | [expansion, 96, 3, 1],
97 | [expansion, 160, 3, 2],
98 | [expansion, 320, 1, 1],
99 | ]
100 |
101 | # building first layer
102 | input_channel = _make_divisible(input_channel * alpha, 8)
103 | self.last_channel = _make_divisible(last_channel * alpha, 8) if alpha > 1.0 else last_channel
104 | self.features = [conv_bn(self.in_channels, input_channel, 2)]
105 |
106 | # building inverted residual blocks
107 | for t, c, n, s in interverted_residual_setting:
108 | output_channel = _make_divisible(int(c * alpha), 8)
109 | for i in range(n):
110 | if i == 0:
111 | self.features.append(InvertedResidual(input_channel, output_channel, s, expansion=t))
112 | else:
113 | self.features.append(InvertedResidual(input_channel, output_channel, 1, expansion=t))
114 | input_channel = output_channel
115 |
116 | # building last several layers
117 | self.features.append(conv_1x1_bn(input_channel, self.last_channel))
118 |
119 | # make it nn.Sequential
120 | self.features = nn.Sequential(*self.features)
121 |
122 | # building classifier
123 | if self.num_classes is not None:
124 | self.classifier = nn.Sequential(
125 | nn.Dropout(0.2),
126 | nn.Linear(self.last_channel, num_classes),
127 | )
128 |
129 | # Initialize weights
130 | self._init_weights()
131 |
132 | def forward(self, x):
133 | # Stage1
134 | x = self.features[0](x)
135 | x = self.features[1](x)
136 | # Stage2
137 | x = self.features[2](x)
138 | x = self.features[3](x)
139 | # Stage3
140 | x = self.features[4](x)
141 | x = self.features[5](x)
142 | x = self.features[6](x)
143 | # Stage4
144 | x = self.features[7](x)
145 | x = self.features[8](x)
146 | x = self.features[9](x)
147 | x = self.features[10](x)
148 | x = self.features[11](x)
149 | x = self.features[12](x)
150 | x = self.features[13](x)
151 | # Stage5
152 | x = self.features[14](x)
153 | x = self.features[15](x)
154 | x = self.features[16](x)
155 | x = self.features[17](x)
156 | x = self.features[18](x)
157 |
158 | # Classification
159 | if self.num_classes is not None:
160 | x = x.mean(dim=(2, 3))
161 | x = self.classifier(x)
162 |
163 | # Output
164 | return x
165 |
166 | def _load_pretrained_model(self, pretrained_file):
167 | pretrain_dict = torch.load(pretrained_file, map_location='cpu')
168 | model_dict = {}
169 | state_dict = self.state_dict()
170 | print('[MobileNetV2] Loading pretrained model...')
171 | for k, v in pretrain_dict.items():
172 | if k in state_dict:
173 | model_dict[k] = v
174 | else:
175 | print(k, 'is ignored')
176 | state_dict.update(model_dict)
177 | self.load_state_dict(state_dict)
178 |
179 | def _init_weights(self):
180 | for m in self.modules():
181 | if isinstance(m, nn.Conv2d):
182 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
183 | m.weight.data.normal_(0, math.sqrt(2. / n))
184 | if m.bias is not None:
185 | m.bias.data.zero_()
186 | elif isinstance(m, nn.BatchNorm2d):
187 | m.weight.data.fill_(1)
188 | m.bias.data.zero_()
189 | elif isinstance(m, nn.Linear):
190 | n = m.weight.size(1)
191 | m.weight.data.normal_(0, 0.01)
192 | m.bias.data.zero_()
193 |
--------------------------------------------------------------------------------
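A tiny check of the channel-rounding helper above: values are rounded to the nearest multiple of the divisor, but never end up more than 10% below the requested value.

from facexlib.matting.mobilenetv2 import _make_divisible

print(_make_divisible(32, 8))  # 32
print(_make_divisible(91, 8))  # 88 (within the 10% tolerance)
print(_make_divisible(18, 8))  # 24 (16 would be a >10% drop, so round up)
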
/facexlib/matting/modnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .backbone import MobileNetV2Backbone
6 |
7 | # ------------------------------------------------------------------------------
8 | # MODNet Basic Modules
9 | # ------------------------------------------------------------------------------
10 |
11 |
12 | class IBNorm(nn.Module):
13 | """ Combine Instance Norm and Batch Norm into One Layer
14 | """
15 |
16 | def __init__(self, in_channels):
17 | super(IBNorm, self).__init__()
18 |         # split the channels: first half goes to BatchNorm, the rest to InstanceNorm
19 | self.bnorm_channels = int(in_channels / 2)
20 | self.inorm_channels = in_channels - self.bnorm_channels
21 |
22 | self.bnorm = nn.BatchNorm2d(self.bnorm_channels, affine=True)
23 | self.inorm = nn.InstanceNorm2d(self.inorm_channels, affine=False)
24 |
25 | def forward(self, x):
26 | bn_x = self.bnorm(x[:, :self.bnorm_channels, ...].contiguous())
27 | in_x = self.inorm(x[:, self.bnorm_channels:, ...].contiguous())
28 |
29 | return torch.cat((bn_x, in_x), 1)
30 |
31 |
32 | class Conv2dIBNormRelu(nn.Module):
33 | """ Convolution + IBNorm + ReLu
34 | """
35 |
36 | def __init__(self,
37 | in_channels,
38 | out_channels,
39 | kernel_size,
40 | stride=1,
41 | padding=0,
42 | dilation=1,
43 | groups=1,
44 | bias=True,
45 | with_ibn=True,
46 | with_relu=True):
47 | super(Conv2dIBNormRelu, self).__init__()
48 |
49 | layers = [
50 | nn.Conv2d(
51 | in_channels,
52 | out_channels,
53 | kernel_size,
54 | stride=stride,
55 | padding=padding,
56 | dilation=dilation,
57 | groups=groups,
58 | bias=bias)
59 | ]
60 |
61 | if with_ibn:
62 | layers.append(IBNorm(out_channels))
63 | if with_relu:
64 | layers.append(nn.ReLU(inplace=True))
65 |
66 | self.layers = nn.Sequential(*layers)
67 |
68 | def forward(self, x):
69 | return self.layers(x)
70 |
71 |
72 | class SEBlock(nn.Module):
73 | """ SE Block Proposed in https://arxiv.org/pdf/1709.01507.pdf
74 | """
75 |
76 | def __init__(self, in_channels, out_channels, reduction=1):
77 | super(SEBlock, self).__init__()
78 | self.pool = nn.AdaptiveAvgPool2d(1)
79 | self.fc = nn.Sequential(
80 | nn.Linear(in_channels, int(in_channels // reduction), bias=False), nn.ReLU(inplace=True),
81 | nn.Linear(int(in_channels // reduction), out_channels, bias=False), nn.Sigmoid())
82 |
83 | def forward(self, x):
84 | b, c, _, _ = x.size()
85 | w = self.pool(x).view(b, c)
86 | w = self.fc(w).view(b, c, 1, 1)
87 |
88 | return x * w.expand_as(x)
89 |
90 |
91 | # ------------------------------------------------------------------------------
92 | # MODNet Branches
93 | # ------------------------------------------------------------------------------
94 |
95 |
96 | class LRBranch(nn.Module):
97 | """ Low Resolution Branch of MODNet
98 | """
99 |
100 | def __init__(self, backbone):
101 | super(LRBranch, self).__init__()
102 |
103 | enc_channels = backbone.enc_channels
104 |
105 | self.backbone = backbone
106 | self.se_block = SEBlock(enc_channels[4], enc_channels[4], reduction=4)
107 | self.conv_lr16x = Conv2dIBNormRelu(enc_channels[4], enc_channels[3], 5, stride=1, padding=2)
108 | self.conv_lr8x = Conv2dIBNormRelu(enc_channels[3], enc_channels[2], 5, stride=1, padding=2)
109 | self.conv_lr = Conv2dIBNormRelu(
110 | enc_channels[2], 1, kernel_size=3, stride=2, padding=1, with_ibn=False, with_relu=False)
111 |
112 | def forward(self, img, inference):
113 | enc_features = self.backbone.forward(img)
114 | enc2x, enc4x, enc32x = enc_features[0], enc_features[1], enc_features[4]
115 |
116 | enc32x = self.se_block(enc32x)
117 | lr16x = F.interpolate(enc32x, scale_factor=2, mode='bilinear', align_corners=False)
118 | lr16x = self.conv_lr16x(lr16x)
119 | lr8x = F.interpolate(lr16x, scale_factor=2, mode='bilinear', align_corners=False)
120 | lr8x = self.conv_lr8x(lr8x)
121 |
122 | pred_semantic = None
123 | if not inference:
124 | lr = self.conv_lr(lr8x)
125 | pred_semantic = torch.sigmoid(lr)
126 |
127 | return pred_semantic, lr8x, [enc2x, enc4x]
128 |
129 |
130 | class HRBranch(nn.Module):
131 | """ High Resolution Branch of MODNet
132 | """
133 |
134 | def __init__(self, hr_channels, enc_channels):
135 | super(HRBranch, self).__init__()
136 |
137 | self.tohr_enc2x = Conv2dIBNormRelu(enc_channels[0], hr_channels, 1, stride=1, padding=0)
138 | self.conv_enc2x = Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=2, padding=1)
139 |
140 | self.tohr_enc4x = Conv2dIBNormRelu(enc_channels[1], hr_channels, 1, stride=1, padding=0)
141 | self.conv_enc4x = Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1)
142 |
143 | self.conv_hr4x = nn.Sequential(
144 | Conv2dIBNormRelu(3 * hr_channels + 3, 2 * hr_channels, 3, stride=1, padding=1),
145 | Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
146 | Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
147 | )
148 |
149 | self.conv_hr2x = nn.Sequential(
150 | Conv2dIBNormRelu(2 * hr_channels, 2 * hr_channels, 3, stride=1, padding=1),
151 | Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1),
152 | Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
153 | Conv2dIBNormRelu(hr_channels, hr_channels, 3, stride=1, padding=1),
154 | )
155 |
156 | self.conv_hr = nn.Sequential(
157 | Conv2dIBNormRelu(hr_channels + 3, hr_channels, 3, stride=1, padding=1),
158 | Conv2dIBNormRelu(hr_channels, 1, kernel_size=1, stride=1, padding=0, with_ibn=False, with_relu=False),
159 | )
160 |
161 | def forward(self, img, enc2x, enc4x, lr8x, inference):
162 | img2x = F.interpolate(img, scale_factor=1 / 2, mode='bilinear', align_corners=False)
163 | img4x = F.interpolate(img, scale_factor=1 / 4, mode='bilinear', align_corners=False)
164 |
165 | enc2x = self.tohr_enc2x(enc2x)
166 | hr4x = self.conv_enc2x(torch.cat((img2x, enc2x), dim=1))
167 |
168 | enc4x = self.tohr_enc4x(enc4x)
169 | hr4x = self.conv_enc4x(torch.cat((hr4x, enc4x), dim=1))
170 |
171 | lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
172 | hr4x = self.conv_hr4x(torch.cat((hr4x, lr4x, img4x), dim=1))
173 |
174 | hr2x = F.interpolate(hr4x, scale_factor=2, mode='bilinear', align_corners=False)
175 | hr2x = self.conv_hr2x(torch.cat((hr2x, enc2x), dim=1))
176 |
177 | pred_detail = None
178 | if not inference:
179 | hr = F.interpolate(hr2x, scale_factor=2, mode='bilinear', align_corners=False)
180 | hr = self.conv_hr(torch.cat((hr, img), dim=1))
181 | pred_detail = torch.sigmoid(hr)
182 |
183 | return pred_detail, hr2x
184 |
185 |
186 | class FusionBranch(nn.Module):
187 | """ Fusion Branch of MODNet
188 | """
189 |
190 | def __init__(self, hr_channels, enc_channels):
191 | super(FusionBranch, self).__init__()
192 | self.conv_lr4x = Conv2dIBNormRelu(enc_channels[2], hr_channels, 5, stride=1, padding=2)
193 |
194 | self.conv_f2x = Conv2dIBNormRelu(2 * hr_channels, hr_channels, 3, stride=1, padding=1)
195 | self.conv_f = nn.Sequential(
196 | Conv2dIBNormRelu(hr_channels + 3, int(hr_channels / 2), 3, stride=1, padding=1),
197 | Conv2dIBNormRelu(int(hr_channels / 2), 1, 1, stride=1, padding=0, with_ibn=False, with_relu=False),
198 | )
199 |
200 | def forward(self, img, lr8x, hr2x):
201 | lr4x = F.interpolate(lr8x, scale_factor=2, mode='bilinear', align_corners=False)
202 | lr4x = self.conv_lr4x(lr4x)
203 | lr2x = F.interpolate(lr4x, scale_factor=2, mode='bilinear', align_corners=False)
204 |
205 | f2x = self.conv_f2x(torch.cat((lr2x, hr2x), dim=1))
206 | f = F.interpolate(f2x, scale_factor=2, mode='bilinear', align_corners=False)
207 | f = self.conv_f(torch.cat((f, img), dim=1))
208 | pred_matte = torch.sigmoid(f)
209 |
210 | return pred_matte
211 |
212 |
213 | # ------------------------------------------------------------------------------
214 | # MODNet
215 | # ------------------------------------------------------------------------------
216 |
217 |
218 | class MODNet(nn.Module):
219 | """ Architecture of MODNet
220 | """
221 |
222 | def __init__(self, in_channels=3, hr_channels=32, backbone_pretrained=True):
223 | super(MODNet, self).__init__()
224 |
225 | self.in_channels = in_channels
226 | self.hr_channels = hr_channels
227 | self.backbone_pretrained = backbone_pretrained
228 |
229 | self.backbone = MobileNetV2Backbone(self.in_channels)
230 |
231 | self.lr_branch = LRBranch(self.backbone)
232 | self.hr_branch = HRBranch(self.hr_channels, self.backbone.enc_channels)
233 | self.f_branch = FusionBranch(self.hr_channels, self.backbone.enc_channels)
234 |
235 | for m in self.modules():
236 | if isinstance(m, nn.Conv2d):
237 | self._init_conv(m)
238 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
239 | self._init_norm(m)
240 |
241 | if self.backbone_pretrained:
242 | self.backbone.load_pretrained_ckpt()
243 |
244 | def forward(self, img, inference):
245 | pred_semantic, lr8x, [enc2x, enc4x] = self.lr_branch(img, inference)
246 | pred_detail, hr2x = self.hr_branch(img, enc2x, enc4x, lr8x, inference)
247 | pred_matte = self.f_branch(img, lr8x, hr2x)
248 |
249 | return pred_semantic, pred_detail, pred_matte
250 |
251 | def freeze_norm(self):
252 | norm_types = [nn.BatchNorm2d, nn.InstanceNorm2d]
253 | for m in self.modules():
254 | for n in norm_types:
255 | if isinstance(m, n):
256 | m.eval()
257 | continue
258 |
259 | def _init_conv(self, conv):
260 | nn.init.kaiming_uniform_(conv.weight, a=0, mode='fan_in', nonlinearity='relu')
261 | if conv.bias is not None:
262 | nn.init.constant_(conv.bias, 0)
263 |
264 | def _init_norm(self, norm):
265 | if norm.weight is not None:
266 | nn.init.constant_(norm.weight, 1)
267 | nn.init.constant_(norm.bias, 0)
268 |
--------------------------------------------------------------------------------
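The inference flag only gates the auxiliary heads: with inference=False all three predictions are returned, with inference=True the semantic and detail branches return None. A shape check with random weights (no pretrained backbone):

import torch

from facexlib.matting.modnet import MODNet

net = MODNet(backbone_pretrained=False).eval()
x = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    sem, det, matte = net(x, inference=False)
    print(sem.shape, det.shape, matte.shape)  # 1x1x32x32, 1x1x512x512, 1x1x512x512
    sem, det, matte = net(x, inference=True)
    print(sem, det, matte.shape)  # None None torch.Size([1, 1, 512, 512])
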
/facexlib/parsing/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from facexlib.utils import load_file_from_url
4 | from .bisenet import BiSeNet
5 | from .parsenet import ParseNet
6 |
7 |
8 | def init_parsing_model(model_name='bisenet', half=False, device='cuda', model_rootpath=None):
9 | if model_name == 'bisenet':
10 | model = BiSeNet(num_class=19)
11 | model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.0/parsing_bisenet.pth'
12 | elif model_name == 'parsenet':
13 | model = ParseNet(in_size=512, out_size=512, parsing_ch=19)
14 | model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth'
15 | else:
16 | raise NotImplementedError(f'{model_name} is not implemented.')
17 |
18 | model_path = load_file_from_url(
19 | url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
20 | load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
21 | model.load_state_dict(load_net, strict=True)
22 | model.eval()
23 | model = model.to(device)
24 | return model
25 |
--------------------------------------------------------------------------------
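A hedged parsing sketch: the network returns 19-class logits at input resolution, and an argmax over channels gives the per-pixel label map. The normalization constants are the usual ImageNet values and only approximate inference/inference_parsing.py.

import cv2
import numpy as np
import torch

from facexlib.parsing import init_parsing_model

net = init_parsing_model('bisenet', device='cpu')

img = cv2.imread('assets/test.jpg')
img = cv2.resize(img, (512, 512))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.
img = (img - np.array([0.485, 0.456, 0.406], np.float32)) / np.array([0.229, 0.224, 0.225], np.float32)
img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)

with torch.no_grad():
    out = net(img)[0]  # main output, 1x19x512x512
labels = out.squeeze(0).argmax(0).numpy()  # values in [0, 18]
print(np.unique(labels))
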
/facexlib/parsing/bisenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .resnet import ResNet18
6 |
7 |
8 | class ConvBNReLU(nn.Module):
9 |
10 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
11 | super(ConvBNReLU, self).__init__()
12 | self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False)
13 | self.bn = nn.BatchNorm2d(out_chan)
14 |
15 | def forward(self, x):
16 | x = self.conv(x)
17 | x = F.relu(self.bn(x))
18 | return x
19 |
20 |
21 | class BiSeNetOutput(nn.Module):
22 |
23 | def __init__(self, in_chan, mid_chan, num_class):
24 | super(BiSeNetOutput, self).__init__()
25 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
26 | self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False)
27 |
28 | def forward(self, x):
29 | feat = self.conv(x)
30 | out = self.conv_out(feat)
31 | return out, feat
32 |
33 |
34 | class AttentionRefinementModule(nn.Module):
35 |
36 | def __init__(self, in_chan, out_chan):
37 | super(AttentionRefinementModule, self).__init__()
38 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
39 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
40 | self.bn_atten = nn.BatchNorm2d(out_chan)
41 | self.sigmoid_atten = nn.Sigmoid()
42 |
43 | def forward(self, x):
44 | feat = self.conv(x)
45 | atten = F.avg_pool2d(feat, feat.size()[2:])
46 | atten = self.conv_atten(atten)
47 | atten = self.bn_atten(atten)
48 | atten = self.sigmoid_atten(atten)
49 | out = torch.mul(feat, atten)
50 | return out
51 |
52 |
53 | class ContextPath(nn.Module):
54 |
55 | def __init__(self):
56 | super(ContextPath, self).__init__()
57 | self.resnet = ResNet18()
58 | self.arm16 = AttentionRefinementModule(256, 128)
59 | self.arm32 = AttentionRefinementModule(512, 128)
60 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
61 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
62 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
63 |
64 | def forward(self, x):
65 | feat8, feat16, feat32 = self.resnet(x)
66 | h8, w8 = feat8.size()[2:]
67 | h16, w16 = feat16.size()[2:]
68 | h32, w32 = feat32.size()[2:]
69 |
70 | avg = F.avg_pool2d(feat32, feat32.size()[2:])
71 | avg = self.conv_avg(avg)
72 | avg_up = F.interpolate(avg, (h32, w32), mode='nearest')
73 |
74 | feat32_arm = self.arm32(feat32)
75 | feat32_sum = feat32_arm + avg_up
76 | feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest')
77 | feat32_up = self.conv_head32(feat32_up)
78 |
79 | feat16_arm = self.arm16(feat16)
80 | feat16_sum = feat16_arm + feat32_up
81 | feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest')
82 | feat16_up = self.conv_head16(feat16_up)
83 |
84 | return feat8, feat16_up, feat32_up # x8, x8, x16
85 |
86 |
87 | class FeatureFusionModule(nn.Module):
88 |
89 | def __init__(self, in_chan, out_chan):
90 | super(FeatureFusionModule, self).__init__()
91 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
92 | self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False)
93 | self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
94 | self.relu = nn.ReLU(inplace=True)
95 | self.sigmoid = nn.Sigmoid()
96 |
97 | def forward(self, fsp, fcp):
98 | fcat = torch.cat([fsp, fcp], dim=1)
99 | feat = self.convblk(fcat)
100 | atten = F.avg_pool2d(feat, feat.size()[2:])
101 | atten = self.conv1(atten)
102 | atten = self.relu(atten)
103 | atten = self.conv2(atten)
104 | atten = self.sigmoid(atten)
105 | feat_atten = torch.mul(feat, atten)
106 | feat_out = feat_atten + feat
107 | return feat_out
108 |
109 |
110 | class BiSeNet(nn.Module):
111 |
112 | def __init__(self, num_class):
113 | super(BiSeNet, self).__init__()
114 | self.cp = ContextPath()
115 | self.ffm = FeatureFusionModule(256, 256)
116 | self.conv_out = BiSeNetOutput(256, 256, num_class)
117 | self.conv_out16 = BiSeNetOutput(128, 64, num_class)
118 | self.conv_out32 = BiSeNetOutput(128, 64, num_class)
119 |
120 | def forward(self, x, return_feat=False):
121 | h, w = x.size()[2:]
122 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature
123 | feat_sp = feat_res8 # replace spatial path feature with res3b1 feature
124 | feat_fuse = self.ffm(feat_sp, feat_cp8)
125 |
126 | out, feat = self.conv_out(feat_fuse)
127 | out16, feat16 = self.conv_out16(feat_cp8)
128 | out32, feat32 = self.conv_out32(feat_cp16)
129 |
130 | out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True)
131 | out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True)
132 | out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True)
133 |
134 | if return_feat:
135 | feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True)
136 | feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True)
137 | feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True)
138 | return out, out16, out32, feat, feat16, feat32
139 | else:
140 | return out, out16, out32
141 |
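A quick shape check for the network above (a sketch; 19 classes is the usual face-parsing label set). The main output is upsampled back to the input resolution, so the argmax over channels gives a per-pixel label map:

import torch

net = BiSeNet(num_class=19).eval()
x = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    out, out16, out32 = net(x)
print(out.shape)           # torch.Size([1, 19, 512, 512])
mask = out.argmax(dim=1)   # per-pixel class indices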
--------------------------------------------------------------------------------
/facexlib/parsing/parsenet.py:
--------------------------------------------------------------------------------
1 | """Modified from https://github.com/chaofengc/PSFRGAN
2 | """
3 | import numpy as np
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 |
8 | class NormLayer(nn.Module):
9 | """Normalization Layers.
10 |
11 | Args:
12 | channels: input channels, for batch norm and instance norm.
13 | input_size: input shape without batch size, for layer norm.
14 | """
15 |
16 | def __init__(self, channels, normalize_shape=None, norm_type='bn'):
17 | super(NormLayer, self).__init__()
18 | norm_type = norm_type.lower()
19 | self.norm_type = norm_type
20 | if norm_type == 'bn':
21 | self.norm = nn.BatchNorm2d(channels, affine=True)
22 | elif norm_type == 'in':
23 | self.norm = nn.InstanceNorm2d(channels, affine=False)
24 | elif norm_type == 'gn':
25 | self.norm = nn.GroupNorm(32, channels, affine=True)
26 | elif norm_type == 'pixel':
27 | self.norm = lambda x: F.normalize(x, p=2, dim=1)
28 | elif norm_type == 'layer':
29 | self.norm = nn.LayerNorm(normalize_shape)
30 | elif norm_type == 'none':
31 | self.norm = lambda x: x * 1.0
32 | else:
33 |             raise NotImplementedError(f'Norm type {norm_type} is not supported.')
34 |
35 | def forward(self, x, ref=None):
36 | if self.norm_type == 'spade':
37 | return self.norm(x, ref)
38 | else:
39 | return self.norm(x)
40 |
41 |
42 | class ReluLayer(nn.Module):
43 | """Relu Layer.
44 |
45 | Args:
46 | relu type: type of relu layer, candidates are
47 | - ReLU
48 | - LeakyReLU: default relu slope 0.2
49 | - PRelu
50 | - SELU
51 | - none: direct pass
52 | """
53 |
54 | def __init__(self, channels, relu_type='relu'):
55 | super(ReluLayer, self).__init__()
56 | relu_type = relu_type.lower()
57 | if relu_type == 'relu':
58 | self.func = nn.ReLU(True)
59 | elif relu_type == 'leakyrelu':
60 | self.func = nn.LeakyReLU(0.2, inplace=True)
61 | elif relu_type == 'prelu':
62 | self.func = nn.PReLU(channels)
63 | elif relu_type == 'selu':
64 | self.func = nn.SELU(True)
65 | elif relu_type == 'none':
66 | self.func = lambda x: x * 1.0
67 | else:
68 |             raise NotImplementedError(f'Relu type {relu_type} is not supported.')
69 |
70 | def forward(self, x):
71 | return self.func(x)
72 |
73 |
74 | class ConvLayer(nn.Module):
75 |
76 | def __init__(self,
77 | in_channels,
78 | out_channels,
79 | kernel_size=3,
80 | scale='none',
81 | norm_type='none',
82 | relu_type='none',
83 | use_pad=True,
84 | bias=True):
85 | super(ConvLayer, self).__init__()
86 | self.use_pad = use_pad
87 | self.norm_type = norm_type
88 | if norm_type in ['bn']:
89 | bias = False
90 |
91 | stride = 2 if scale == 'down' else 1
92 |
93 | self.scale_func = lambda x: x
94 | if scale == 'up':
95 | self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest')
96 |
97 | self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) / 2)))
98 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
99 |
100 | self.relu = ReluLayer(out_channels, relu_type)
101 | self.norm = NormLayer(out_channels, norm_type=norm_type)
102 |
103 | def forward(self, x):
104 | out = self.scale_func(x)
105 | if self.use_pad:
106 | out = self.reflection_pad(out)
107 | out = self.conv2d(out)
108 | out = self.norm(out)
109 | out = self.relu(out)
110 | return out
111 |
112 |
113 | class ResidualBlock(nn.Module):
114 | """
115 | Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
116 | """
117 |
118 | def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'):
119 | super(ResidualBlock, self).__init__()
120 |
121 | if scale == 'none' and c_in == c_out:
122 | self.shortcut_func = lambda x: x
123 | else:
124 | self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)
125 |
126 | scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']}
127 | scale_conf = scale_config_dict[scale]
128 |
129 | self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type)
130 | self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none')
131 |
132 | def forward(self, x):
133 | identity = self.shortcut_func(x)
134 |
135 | res = self.conv1(x)
136 | res = self.conv2(res)
137 | return identity + res
138 |
139 |
140 | class ParseNet(nn.Module):
141 |
142 | def __init__(self,
143 | in_size=128,
144 | out_size=128,
145 | min_feat_size=32,
146 | base_ch=64,
147 | parsing_ch=19,
148 | res_depth=10,
149 | relu_type='LeakyReLU',
150 | norm_type='bn',
151 | ch_range=[32, 256]):
152 | super().__init__()
153 | self.res_depth = res_depth
154 | act_args = {'norm_type': norm_type, 'relu_type': relu_type}
155 | min_ch, max_ch = ch_range
156 |
157 | ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731
158 | min_feat_size = min(in_size, min_feat_size)
159 |
160 | down_steps = int(np.log2(in_size // min_feat_size))
161 | up_steps = int(np.log2(out_size // min_feat_size))
162 |
163 | # =============== define encoder-body-decoder ====================
164 | self.encoder = []
165 | self.encoder.append(ConvLayer(3, base_ch, 3, 1))
166 | head_ch = base_ch
167 | for i in range(down_steps):
168 | cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
169 | self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args))
170 | head_ch = head_ch * 2
171 |
172 | self.body = []
173 | for i in range(res_depth):
174 | self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args))
175 |
176 | self.decoder = []
177 | for i in range(up_steps):
178 | cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
179 | self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args))
180 | head_ch = head_ch // 2
181 |
182 | self.encoder = nn.Sequential(*self.encoder)
183 | self.body = nn.Sequential(*self.body)
184 | self.decoder = nn.Sequential(*self.decoder)
185 | self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
186 | self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)
187 |
188 | def forward(self, x):
189 | feat = self.encoder(x)
190 | x = feat + self.body(feat)
191 | x = self.decoder(x)
192 | out_img = self.out_img_conv(x)
193 | out_mask = self.out_mask_conv(x)
194 | return out_mask, out_img
195 |
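ParseNet predicts a parsing map and a coarse RGB reconstruction in a single pass. A small sanity-check sketch for the class above, using the default encoder/decoder depths:

import torch

net = ParseNet(in_size=512, out_size=512, parsing_ch=19).eval()
with torch.no_grad():
    out_mask, out_img = net(torch.randn(1, 3, 512, 512))
print(out_mask.shape, out_img.shape)  # [1, 19, 512, 512] and [1, 3, 512, 512]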
--------------------------------------------------------------------------------
/facexlib/parsing/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | def conv3x3(in_planes, out_planes, stride=1):
6 | """3x3 convolution with padding"""
7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
8 |
9 |
10 | class BasicBlock(nn.Module):
11 |
12 | def __init__(self, in_chan, out_chan, stride=1):
13 | super(BasicBlock, self).__init__()
14 | self.conv1 = conv3x3(in_chan, out_chan, stride)
15 | self.bn1 = nn.BatchNorm2d(out_chan)
16 | self.conv2 = conv3x3(out_chan, out_chan)
17 | self.bn2 = nn.BatchNorm2d(out_chan)
18 | self.relu = nn.ReLU(inplace=True)
19 | self.downsample = None
20 | if in_chan != out_chan or stride != 1:
21 | self.downsample = nn.Sequential(
22 | nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False),
23 | nn.BatchNorm2d(out_chan),
24 | )
25 |
26 | def forward(self, x):
27 | residual = self.conv1(x)
28 | residual = F.relu(self.bn1(residual))
29 | residual = self.conv2(residual)
30 | residual = self.bn2(residual)
31 |
32 | shortcut = x
33 | if self.downsample is not None:
34 | shortcut = self.downsample(x)
35 |
36 | out = shortcut + residual
37 | out = self.relu(out)
38 | return out
39 |
40 |
41 | def create_layer_basic(in_chan, out_chan, bnum, stride=1):
42 | layers = [BasicBlock(in_chan, out_chan, stride=stride)]
43 | for i in range(bnum - 1):
44 | layers.append(BasicBlock(out_chan, out_chan, stride=1))
45 | return nn.Sequential(*layers)
46 |
47 |
48 | class ResNet18(nn.Module):
49 |
50 | def __init__(self):
51 | super(ResNet18, self).__init__()
52 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
53 | self.bn1 = nn.BatchNorm2d(64)
54 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
55 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
56 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
57 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
58 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
59 |
60 | def forward(self, x):
61 | x = self.conv1(x)
62 | x = F.relu(self.bn1(x))
63 | x = self.maxpool(x)
64 |
65 | x = self.layer1(x)
66 | feat8 = self.layer2(x) # 1/8
67 | feat16 = self.layer3(feat8) # 1/16
68 | feat32 = self.layer4(feat16) # 1/32
69 | return feat8, feat16, feat32
70 |
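The backbone returns three feature maps at 1/8, 1/16 and 1/32 of the input resolution (128, 256 and 512 channels), which ContextPath in bisenet.py consumes. A quick sketch:

import torch

backbone = ResNet18().eval()
with torch.no_grad():
    feat8, feat16, feat32 = backbone(torch.randn(1, 3, 512, 512))
print(feat8.shape, feat16.shape, feat32.shape)
# [1, 128, 64, 64], [1, 256, 32, 32], [1, 512, 16, 16]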
--------------------------------------------------------------------------------
/facexlib/recognition/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from facexlib.utils import load_file_from_url
4 | from .arcface_arch import Backbone
5 |
6 |
7 | def init_recognition_model(model_name, half=False, device='cuda', model_rootpath=None):
8 | if model_name == 'arcface':
9 |         model = Backbone(num_layers=50, drop_ratio=0.6, mode='ir_se')
10 | model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/recognition_arcface_ir_se50.pth'
11 | else:
12 | raise NotImplementedError(f'{model_name} is not implemented.')
13 |
14 | model_path = load_file_from_url(
15 | url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
16 |     model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage), strict=True)
17 | model.eval()
18 | model = model.to(device)
19 | return model
20 |
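A usage sketch for the loader above. The preprocessing shown (112x112 crops, values scaled to [-1, 1]) is an assumption typical of ArcFace pipelines, not something defined in this file; because the embeddings are l2-normalized, the dot product of two embeddings is their cosine similarity:

import cv2
import torch
from facexlib.recognition import init_recognition_model

net = init_recognition_model('arcface', device='cuda')

def embed(path):
    img = cv2.resize(cv2.imread(path), (112, 112)).astype('float32')
    img = (img / 127.5 - 1.0).transpose(2, 0, 1)  # assumed normalization
    with torch.no_grad():
        return net(torch.from_numpy(img).unsqueeze(0).cuda())

similarity = (embed('assets/test.jpg') * embed('assets/test2.jpg')).sum().item()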
--------------------------------------------------------------------------------
/facexlib/recognition/arcface_arch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import namedtuple
3 | from torch.nn import (AdaptiveAvgPool2d, BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear, MaxPool2d, Module, PReLU,
4 | ReLU, Sequential, Sigmoid)
5 |
6 | # Original Arcface Model
7 |
8 |
9 | class Flatten(Module):
10 |
11 | def forward(self, input):
12 | return input.view(input.size(0), -1)
13 |
14 |
15 | def l2_norm(input, axis=1):
16 | norm = torch.norm(input, 2, axis, True)
17 | output = torch.div(input, norm)
18 | return output
19 |
20 |
21 | class SEModule(Module):
22 |
23 | def __init__(self, channels, reduction):
24 | super(SEModule, self).__init__()
25 | self.avg_pool = AdaptiveAvgPool2d(1)
26 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
27 | self.relu = ReLU(inplace=True)
28 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
29 | self.sigmoid = Sigmoid()
30 |
31 | def forward(self, x):
32 | module_input = x
33 | x = self.avg_pool(x)
34 | x = self.fc1(x)
35 | x = self.relu(x)
36 | x = self.fc2(x)
37 | x = self.sigmoid(x)
38 | return module_input * x
39 |
40 |
41 | class bottleneck_IR(Module):
42 |
43 | def __init__(self, in_channel, depth, stride):
44 | super(bottleneck_IR, self).__init__()
45 | if in_channel == depth:
46 | self.shortcut_layer = MaxPool2d(1, stride)
47 | else:
48 | self.shortcut_layer = Sequential(Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth))
49 | self.res_layer = Sequential(
50 | BatchNorm2d(in_channel), Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
51 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth))
52 |
53 | def forward(self, x):
54 | shortcut = self.shortcut_layer(x)
55 | res = self.res_layer(x)
56 | return res + shortcut
57 |
58 |
59 | class bottleneck_IR_SE(Module):
60 |
61 | def __init__(self, in_channel, depth, stride):
62 | super(bottleneck_IR_SE, self).__init__()
63 | if in_channel == depth:
64 | self.shortcut_layer = MaxPool2d(1, stride)
65 | else:
66 | self.shortcut_layer = Sequential(Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth))
67 | self.res_layer = Sequential(
68 | BatchNorm2d(in_channel), Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
69 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth), SEModule(depth, 16))
70 |
71 | def forward(self, x):
72 | shortcut = self.shortcut_layer(x)
73 | res = self.res_layer(x)
74 | return res + shortcut
75 |
76 |
77 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
78 | '''A named tuple describing a ResNet block.'''
79 |
80 |
81 | def get_block(in_channel, depth, num_units, stride=2):
82 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
83 |
84 |
85 | def get_blocks(num_layers):
86 | if num_layers == 50:
87 | blocks = [
88 | get_block(in_channel=64, depth=64, num_units=3),
89 | get_block(in_channel=64, depth=128, num_units=4),
90 | get_block(in_channel=128, depth=256, num_units=14),
91 | get_block(in_channel=256, depth=512, num_units=3)
92 | ]
93 | elif num_layers == 100:
94 | blocks = [
95 | get_block(in_channel=64, depth=64, num_units=3),
96 | get_block(in_channel=64, depth=128, num_units=13),
97 | get_block(in_channel=128, depth=256, num_units=30),
98 | get_block(in_channel=256, depth=512, num_units=3)
99 | ]
100 | elif num_layers == 152:
101 | blocks = [
102 | get_block(in_channel=64, depth=64, num_units=3),
103 | get_block(in_channel=64, depth=128, num_units=8),
104 | get_block(in_channel=128, depth=256, num_units=36),
105 | get_block(in_channel=256, depth=512, num_units=3)
106 | ]
107 | return blocks
108 |
109 |
110 | class Backbone(Module):
111 |
112 | def __init__(self, num_layers, drop_ratio, mode='ir'):
113 | super(Backbone, self).__init__()
114 |         assert num_layers in [50, 100, 152], 'num_layers should be 50, 100, or 152'
115 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
116 | blocks = get_blocks(num_layers)
117 | if mode == 'ir':
118 | unit_module = bottleneck_IR
119 | elif mode == 'ir_se':
120 | unit_module = bottleneck_IR_SE
121 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64), PReLU(64))
122 | self.output_layer = Sequential(
123 | BatchNorm2d(512), Dropout(drop_ratio), Flatten(), Linear(512 * 7 * 7, 512), BatchNorm1d(512))
124 | modules = []
125 | for block in blocks:
126 | for bottleneck in block:
127 | modules.append(unit_module(bottleneck.in_channel, bottleneck.depth, bottleneck.stride))
128 | self.body = Sequential(*modules)
129 |
130 | def forward(self, x):
131 | x = self.input_layer(x)
132 | x = self.body(x)
133 | x = self.output_layer(x)
134 | return l2_norm(x)
135 |
136 |
137 | # MobileFaceNet
138 |
139 |
140 | class Conv_block(Module):
141 |
142 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
143 | super(Conv_block, self).__init__()
144 | self.conv = Conv2d(
145 | in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False)
146 | self.bn = BatchNorm2d(out_c)
147 | self.prelu = PReLU(out_c)
148 |
149 | def forward(self, x):
150 | x = self.conv(x)
151 | x = self.bn(x)
152 | x = self.prelu(x)
153 | return x
154 |
155 |
156 | class Linear_block(Module):
157 |
158 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
159 | super(Linear_block, self).__init__()
160 | self.conv = Conv2d(
161 | in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False)
162 | self.bn = BatchNorm2d(out_c)
163 |
164 | def forward(self, x):
165 | x = self.conv(x)
166 | x = self.bn(x)
167 | return x
168 |
169 |
170 | class Depth_Wise(Module):
171 |
172 | def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
173 | super(Depth_Wise, self).__init__()
174 | self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
175 | self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
176 | self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
177 | self.residual = residual
178 |
179 | def forward(self, x):
180 | if self.residual:
181 | short_cut = x
182 | x = self.conv(x)
183 | x = self.conv_dw(x)
184 | x = self.project(x)
185 | if self.residual:
186 | output = short_cut + x
187 | else:
188 | output = x
189 | return output
190 |
191 |
192 | class Residual(Module):
193 |
194 | def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
195 | super(Residual, self).__init__()
196 | modules = []
197 | for _ in range(num_block):
198 | modules.append(
199 | Depth_Wise(c, c, residual=True, kernel=kernel, padding=padding, stride=stride, groups=groups))
200 | self.model = Sequential(*modules)
201 |
202 | def forward(self, x):
203 | return self.model(x)
204 |
205 |
206 | class MobileFaceNet(Module):
207 |
208 | def __init__(self, embedding_size):
209 | super(MobileFaceNet, self).__init__()
210 | self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
211 | self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
212 | self.conv_23 = Depth_Wise(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
213 | self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
214 | self.conv_34 = Depth_Wise(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
215 | self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
216 | self.conv_45 = Depth_Wise(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
217 | self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
218 | self.conv_6_sep = Conv_block(128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
219 | self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0))
220 | self.conv_6_flatten = Flatten()
221 | self.linear = Linear(512, embedding_size, bias=False)
222 | self.bn = BatchNorm1d(embedding_size)
223 |
224 | def forward(self, x):
225 | out = self.conv1(x)
226 | out = self.conv2_dw(out)
227 | out = self.conv_23(out)
228 | out = self.conv_3(out)
229 | out = self.conv_34(out)
230 | out = self.conv_4(out)
231 | out = self.conv_45(out)
232 | out = self.conv_5(out)
233 | out = self.conv_6_sep(out)
234 | out = self.conv_6_dw(out)
235 | out = self.conv_6_flatten(out)
236 | out = self.linear(out)
237 | out = self.bn(out)
238 | return l2_norm(out)
239 |
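The output layer's `Linear(512 * 7 * 7, 512)` implies a 112x112 input (four stride-2 stages: 112 / 16 = 7). A quick sketch confirming the embedding shape of the IR-SE backbone above:

import torch

net = Backbone(num_layers=50, drop_ratio=0.6, mode='ir_se').eval()
with torch.no_grad():
    emb = net(torch.randn(1, 3, 112, 112))
print(emb.shape, emb.norm(dim=1))  # torch.Size([1, 512]), norm ~= 1 (l2-normalized)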
--------------------------------------------------------------------------------
/facexlib/tracking/README.md:
--------------------------------------------------------------------------------
1 | https://github.com/abewley/sort
2 |
--------------------------------------------------------------------------------
/facexlib/tracking/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xinntao/facexlib/260620ae93990a300f4b16448df9bb459f1caba9/facexlib/tracking/__init__.py
--------------------------------------------------------------------------------
/facexlib/tracking/data_association.py:
--------------------------------------------------------------------------------
1 | """
2 | For each detected item, it computes the intersection over union (IOU) w.r.t.
3 | each tracked object. (IOU matrix)
4 | Then, it applies the Hungarian algorithm (via linear_assignment) to assign each
5 | det. item to the best possible tracked item (i.e. to the one with max IOU)
6 | """
7 |
8 | import numpy as np
9 | from numba import jit
10 | from scipy.optimize import linear_sum_assignment as linear_assignment
11 |
12 |
13 | @jit
14 | def iou(bb_test, bb_gt):
15 | """Computes IOU between two bboxes in the form [x1,y1,x2,y2]
16 | """
17 | xx1 = np.maximum(bb_test[0], bb_gt[0])
18 | yy1 = np.maximum(bb_test[1], bb_gt[1])
19 | xx2 = np.minimum(bb_test[2], bb_gt[2])
20 | yy2 = np.minimum(bb_test[3], bb_gt[3])
21 | w = np.maximum(0., xx2 - xx1)
22 | h = np.maximum(0., yy2 - yy1)
23 | wh = w * h
24 | o = wh / ((bb_test[2] - bb_test[0]) * (bb_test[3] - bb_test[1]) + (bb_gt[2] - bb_gt[0]) *
25 | (bb_gt[3] - bb_gt[1]) - wh)
26 | return (o)
27 |
28 |
29 | def associate_detections_to_trackers(detections, trackers, iou_threshold=0.25):
30 | """Assigns detections to tracked object (both represented as bounding boxes)
31 |
32 | Returns:
33 | 3 lists of matches, unmatched_detections and unmatched_trackers.
34 | """
35 | if len(trackers) == 0:
36 | return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 5), dtype=int)
37 |
38 | iou_matrix = np.zeros((len(detections), len(trackers)), dtype=np.float32)
39 |
40 | for d, det in enumerate(detections):
41 | for t, trk in enumerate(trackers):
42 | iou_matrix[d, t] = iou(det, trk)
43 | # The linear assignment module tries to minimize the total assignment cost.
44 | # In our case we pass -iou_matrix as we want to maximise the total IOU
45 | # between track predictions and the frame detection.
46 | row_ind, col_ind = linear_assignment(-iou_matrix)
47 |
48 | unmatched_detections = []
49 | for d, det in enumerate(detections):
50 | if d not in row_ind:
51 | unmatched_detections.append(d)
52 | unmatched_trackers = []
53 | for t, trk in enumerate(trackers):
54 | if t not in col_ind:
55 | unmatched_trackers.append(t)
56 |
57 | # filter out matched with low IOU
58 | matches = []
59 | for row, col in zip(row_ind, col_ind):
60 | if iou_matrix[row, col] < iou_threshold:
61 | unmatched_detections.append(row)
62 | unmatched_trackers.append(col)
63 | else:
64 | matches.append(np.array([[row, col]]))
65 |
66 | if len(matches) == 0:
67 | matches = np.empty((0, 2), dtype=int)
68 | else:
69 | matches = np.concatenate(matches, axis=0)
70 |
71 | return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
72 |
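A small worked example of the matching above: two detections, one existing tracker. The overlapping pair is matched by the Hungarian assignment; the far-away detection is reported as unmatched so SORT can start a new tracker for it:

import numpy as np

dets = np.array([[10., 10., 50., 50., 0.9],
                 [200., 200., 240., 240., 0.8]])
trks = np.array([[12., 12., 52., 52., 0.]])
matches, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets, trks)
print(matches)         # [[0 0]] -> detection 0 assigned to tracker 0 (IOU ~ 0.82)
print(unmatched_dets)  # [1]
print(unmatched_trks)  # []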
--------------------------------------------------------------------------------
/facexlib/tracking/kalman_tracker.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from filterpy.kalman import KalmanFilter
3 |
4 |
5 | def convert_bbox_to_z(bbox):
6 | """Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
7 | [x,y,s,r] where x,y is the centre of the box and s is the scale/area and
8 | r is the aspect ratio
9 | """
10 | w = bbox[2] - bbox[0]
11 | h = bbox[3] - bbox[1]
12 | x = bbox[0] + w / 2.
13 | y = bbox[1] + h / 2.
14 | s = w * h # scale is just area
15 | r = w / float(h)
16 | return np.array([x, y, s, r]).reshape((4, 1))
17 |
18 |
19 | def convert_x_to_bbox(x, score=None):
20 | """Takes a bounding box in the centre form [x,y,s,r] and returns it in
21 | the form [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom
22 | right
23 | """
24 | w = np.sqrt(x[2] * x[3])
25 | h = x[2] / w
26 | if score is None:
27 | return np.array([x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2.]).reshape((1, 4))
28 | else:
29 | return np.array([x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2., score]).reshape((1, 5))
30 |
31 |
32 | class KalmanBoxTracker(object):
33 | """This class represents the internal state of individual tracked objects
34 | observed as bbox.
35 | doc: https://filterpy.readthedocs.io/en/latest/kalman/KalmanFilter.html
36 | """
37 | count = 0
38 |
39 | def __init__(self, bbox):
40 | """Initialize a tracker using initial bounding box.
41 | """
42 | # define constant velocity model
43 |         # TODO: x: what is the meaning of x[4:7]? velocity?
44 | self.kf = KalmanFilter(dim_x=7, dim_z=4)
45 | # F (dim_x, dim_x): state transition matrix
46 | self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1, 0], [0, 0, 1, 0, 0, 0,
47 | 1], [0, 0, 0, 1, 0, 0, 0],
48 | [0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 1, 0], [0, 0, 0, 0, 0, 0, 1]])
49 | # H (dim_z, dim_x): measurement function
50 | self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0],
51 | [0, 0, 0, 1, 0, 0, 0]])
52 | # R (dim_z, dim_z): measurement uncertainty/noise
53 | self.kf.R[2:, 2:] *= 10.
54 | # P (dim_x, dim_x): covariance matrix
55 | # give high uncertainty to the unobservable initial velocities
56 | self.kf.P[4:, 4:] *= 1000.
57 | self.kf.P *= 10.
58 | # Q (dim_x, dim_x): Process uncertainty/noise
59 | self.kf.Q[-1, -1] *= 0.01
60 | self.kf.Q[4:, 4:] *= 0.01
61 | # x (dim_x, 1): filter state estimate
62 | self.kf.x[:4] = convert_bbox_to_z(bbox)
63 |
64 | self.time_since_update = 0
65 | self.id = KalmanBoxTracker.count
66 | KalmanBoxTracker.count += 1
67 | self.history = []
68 | self.hits = 0
69 | self.hit_streak = 0
70 | self.age = 0
71 |
72 |         # Fix the drift of an existing tracker's prediction when no face is detected in the frame
73 |         self.predict_num = 0  # number of consecutive predictions
74 |
75 | # additional fields
76 | self.face_attributes = []
77 |
78 | def update(self, bbox):
79 | """Updates the state vector with observed bbox.
80 | """
81 | self.time_since_update = 0
82 | self.history = []
83 | self.hits += 1
84 |         self.hit_streak += 1  # consecutive hits
85 | if bbox != []:
86 | self.kf.update(convert_bbox_to_z(bbox))
87 | self.predict_num = 0
88 | else:
89 | self.predict_num += 1
90 |
91 | def predict(self):
92 | """Advances the state vector and returns the predicted bounding box
93 | estimate.
94 | """
95 |
96 | if (self.kf.x[6] + self.kf.x[2]) <= 0:
97 | self.kf.x[6] *= 0.0
98 | self.kf.predict()
99 | self.age += 1
100 | if self.time_since_update > 0:
101 | self.hit_streak = 0
102 | self.time_since_update += 1
103 | self.history.append(convert_x_to_bbox(self.kf.x))
104 | return self.history[-1][0]
105 |
106 | def get_state(self):
107 | """Returns the current bounding box estimate."""
108 | return convert_x_to_bbox(self.kf.x)[0]
109 |
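The per-frame cycle is predict first, then update with the matched detection (or an empty list when the detector missed). A minimal sketch using the class above; the box values are made up:

trk = KalmanBoxTracker([10., 20., 50., 80.])  # initial detection [x1, y1, x2, y2]
pred = trk.predict()                          # predicted box for the next frame
trk.update([12., 22., 52., 82.])              # correct the state with the matched detection
print(trk.get_state())                        # current estimate [x1, y1, x2, y2]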
--------------------------------------------------------------------------------
/facexlib/tracking/sort.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from facexlib.tracking.data_association import associate_detections_to_trackers
4 | from facexlib.tracking.kalman_tracker import KalmanBoxTracker
5 |
6 |
7 | class SORT(object):
8 | """SORT: A Simple, Online and Realtime Tracker.
9 |
10 | Ref: https://github.com/abewley/sort
11 | """
12 |
13 | def __init__(self, max_age=1, min_hits=3, iou_threshold=0.3):
14 | self.max_age = max_age
15 |         self.min_hits = min_hits  # minimum consecutive hits; only trackers meeting this are returned
16 | self.iou_threshold = iou_threshold
17 | self.trackers = []
18 | self.frame_count = 0
19 |
20 | def update(self, dets, img_size, additional_attr, detect_interval):
21 | """This method must be called once for each frame even with
22 | empty detections.
23 |         NOTE: as in practical realtime MOT, the detector doesn't run on every
24 | single frame.
25 |
26 | Args:
27 | dets (Numpy array): detections in the format
28 | [[x0,y0,x1,y1,score], [x0,y0,x1,y1,score], ...]
29 |
30 | Returns:
31 | a similar array, where the last column is the object ID.
32 | """
33 | self.frame_count += 1
34 |
35 | # get predicted locations from existing trackers
36 | trks = np.zeros((len(self.trackers), 5))
37 | to_del = [] # To be deleted
38 | ret = []
39 | # predict tracker position using Kalman filter
40 | for t, trk in enumerate(trks):
41 |             pos = self.trackers[t].predict()  # Kalman predict, very fast, <1 ms
42 | trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
43 | if np.any(np.isnan(pos)):
44 | to_del.append(t)
45 | trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
46 | for t in reversed(to_del):
47 | self.trackers.pop(t)
48 |
49 | if dets != []:
50 | matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers( # noqa: E501
51 | dets, trks)
52 |
53 | # update matched trackers with assigned detections
54 | for t, trk in enumerate(self.trackers):
55 | if t not in unmatched_trks:
56 | d = matched[np.where(matched[:, 1] == t)[0], 0]
57 | trk.update(dets[d, :][0])
58 | trk.face_attributes.append(additional_attr[d[0]])
59 |
60 | # create and initialize new trackers for unmatched detections
61 | for i in unmatched_dets:
62 | trk = KalmanBoxTracker(dets[i, :])
63 | trk.face_attributes.append(additional_attr[i])
64 | print(f'New tracker: {trk.id + 1}.')
65 | self.trackers.append(trk)
66 |
67 | i = len(self.trackers)
68 | for trk in reversed(self.trackers):
69 | if dets == []:
70 | trk.update([])
71 |
72 | d = trk.get_state()
73 | # get return tracklet
74 | # 1) time_since_update < 1: detected
75 |             # 2) i) hit_streak >= min_hits: minimum consecutive hits
76 |             # ii) frame_count <= min_hits: the first few frames
77 | if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
78 | ret.append(np.concatenate((d, [trk.id + 1])).reshape(1, -1)) # +1 as MOT benchmark requires positive
79 | i -= 1
80 |
81 | # remove dead tracklet
82 |             # 1) time_since_update >= max_age: how long since the last update
83 |             # 2) predict_num: number of consecutive predicted frames
84 | # 3) out of image size
85 | if (trk.time_since_update >= self.max_age) or (trk.predict_num >= detect_interval) or (
86 | d[2] < 0 or d[3] < 0 or d[0] > img_size[1] or d[1] > img_size[0]):
87 | print(f'Remove tracker: {trk.id + 1}')
88 | self.trackers.pop(i)
89 | if len(ret) > 0:
90 | return np.concatenate(ret)
91 | else:
92 | return np.empty((0, 5))
93 |
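A per-frame driving loop looks roughly like the sketch below; `frames` and `detect_faces` are placeholders for your image sequence and detector (in this repo the RetinaFace detector from facexlib.detection would fill that role):

import numpy as np

tracker = SORT(max_age=2, min_hits=1)
for frame in frames:                     # placeholder image sequence
    dets = detect_faces(frame)           # placeholder: Nx5 array [x0, y0, x1, y1, score]
    attrs = [None] * len(dets)           # one additional attribute per detection
    tracks = tracker.update(dets, frame.shape[:2], attrs, detect_interval=1)
    # each row of `tracks` is [x0, y0, x1, y1, track_id]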
--------------------------------------------------------------------------------
/facexlib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back
2 | from .misc import img2tensor, load_file_from_url, scandir
3 |
4 | __all__ = [
5 | 'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url', 'paste_face_back',
6 | 'img2tensor', 'scandir'
7 | ]
8 |
--------------------------------------------------------------------------------
/facexlib/utils/face_utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 |
5 |
6 | def compute_increased_bbox(bbox, increase_area, preserve_aspect=True):
7 | left, top, right, bot = bbox
8 | width = right - left
9 | height = bot - top
10 |
11 | if preserve_aspect:
12 | width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
13 | height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
14 | else:
15 | width_increase = height_increase = increase_area
16 | left = int(left - width_increase * width)
17 | top = int(top - height_increase * height)
18 | right = int(right + width_increase * width)
19 | bot = int(bot + height_increase * height)
20 | return (left, top, right, bot)
21 |
22 |
23 | def get_valid_bboxes(bboxes, h, w):
24 | left = max(bboxes[0], 0)
25 | top = max(bboxes[1], 0)
26 | right = min(bboxes[2], w)
27 | bottom = min(bboxes[3], h)
28 | return (left, top, right, bottom)
29 |
30 |
31 | def align_crop_face_landmarks(img,
32 | landmarks,
33 | output_size,
34 | transform_size=None,
35 | enable_padding=True,
36 | return_inverse_affine=False,
37 | shrink_ratio=(1, 1)):
38 | """Align and crop face with landmarks.
39 |
40 | The output_size and transform_size are based on width. The height is
41 |     adjusted based on shrink_ratio_h / shrink_ratio_w.
42 |
43 | Modified from:
44 | https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
45 |
46 | Args:
47 | img (Numpy array): Input image.
48 | landmarks (Numpy array): 5 or 68 or 98 landmarks.
49 | output_size (int): Output face size.
50 |         transform_size (int): Transform size. Usually four times the
51 |             output_size.
52 |         enable_padding (bool): Whether to pad the image. Default: True.
53 |         shrink_ratio (float | tuple[float] | list[float]): Shrink the whole
54 | face for height and width (crop larger area). Default: (1, 1).
55 |
56 | Returns:
57 | (Numpy array): Cropped face.
58 | """
59 | lm_type = 'retinaface_5' # Options: dlib_5, retinaface_5
60 |
61 | if isinstance(shrink_ratio, (float, int)):
62 | shrink_ratio = (shrink_ratio, shrink_ratio)
63 | if transform_size is None:
64 | transform_size = output_size * 4
65 |
66 | # Parse landmarks
67 | lm = np.array(landmarks)
68 | if lm.shape[0] == 5 and lm_type == 'retinaface_5':
69 | eye_left = lm[0]
70 | eye_right = lm[1]
71 | mouth_avg = (lm[3] + lm[4]) * 0.5
72 | elif lm.shape[0] == 5 and lm_type == 'dlib_5':
73 | lm_eye_left = lm[2:4]
74 | lm_eye_right = lm[0:2]
75 | eye_left = np.mean(lm_eye_left, axis=0)
76 | eye_right = np.mean(lm_eye_right, axis=0)
77 | mouth_avg = lm[4]
78 | elif lm.shape[0] == 68:
79 | lm_eye_left = lm[36:42]
80 | lm_eye_right = lm[42:48]
81 | eye_left = np.mean(lm_eye_left, axis=0)
82 | eye_right = np.mean(lm_eye_right, axis=0)
83 | mouth_avg = (lm[48] + lm[54]) * 0.5
84 | elif lm.shape[0] == 98:
85 | lm_eye_left = lm[60:68]
86 | lm_eye_right = lm[68:76]
87 | eye_left = np.mean(lm_eye_left, axis=0)
88 | eye_right = np.mean(lm_eye_right, axis=0)
89 | mouth_avg = (lm[76] + lm[82]) * 0.5
90 |
91 | eye_avg = (eye_left + eye_right) * 0.5
92 | eye_to_eye = eye_right - eye_left
93 | eye_to_mouth = mouth_avg - eye_avg
94 |
95 | # Get the oriented crop rectangle
96 | # x: half width of the oriented crop rectangle
97 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
98 | # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
99 | # norm with the hypotenuse: get the direction
100 | x /= np.hypot(*x) # get the hypotenuse of a right triangle
101 | rect_scale = 1 # TODO: you can edit it to get larger rect
102 | x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
103 | # y: half height of the oriented crop rectangle
104 | y = np.flipud(x) * [-1, 1]
105 |
106 | x *= shrink_ratio[1] # width
107 | y *= shrink_ratio[0] # height
108 |
109 | # c: center
110 | c = eye_avg + eye_to_mouth * 0.1
111 | # quad: (left_top, left_bottom, right_bottom, right_top)
112 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
113 | # qsize: side length of the square
114 | qsize = np.hypot(*x) * 2
115 |
116 | quad_ori = np.copy(quad)
117 | # Shrink, for large face
118 | # TODO: do we really need shrink
119 | shrink = int(np.floor(qsize / output_size * 0.5))
120 | if shrink > 1:
121 | h, w = img.shape[0:2]
122 | rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink)))
123 | img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA)
124 | quad /= shrink
125 | qsize /= shrink
126 |
127 | # Crop
128 | h, w = img.shape[0:2]
129 | border = max(int(np.rint(qsize * 0.1)), 3)
130 | crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
131 | int(np.ceil(max(quad[:, 1]))))
132 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h))
133 | if crop[2] - crop[0] < w or crop[3] - crop[1] < h:
134 | img = img[crop[1]:crop[3], crop[0]:crop[2], :]
135 | quad -= crop[0:2]
136 |
137 | # Pad
138 | # pad: (width_left, height_top, width_right, height_bottom)
139 | h, w = img.shape[0:2]
140 | pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
141 | int(np.ceil(max(quad[:, 1]))))
142 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - w + border, 0), max(pad[3] - h + border, 0))
143 | if enable_padding and max(pad) > border - 4:
144 | pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
145 | img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
146 | h, w = img.shape[0:2]
147 | y, x, _ = np.ogrid[:h, :w, :1]
148 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
149 | np.float32(w - 1 - x) / pad[2]),
150 | 1.0 - np.minimum(np.float32(y) / pad[1],
151 | np.float32(h - 1 - y) / pad[3]))
152 | blur = int(qsize * 0.02)
153 | if blur % 2 == 0:
154 | blur += 1
155 | blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur))
156 |
157 | img = img.astype('float32')
158 | img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
159 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
160 | img = np.clip(img, 0, 255) # float32, [0, 255]
161 | quad += pad[:2]
162 |
163 | # Transform use cv2
164 | h_ratio = shrink_ratio[0] / shrink_ratio[1]
165 | dst_h, dst_w = int(transform_size * h_ratio), transform_size
166 | template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
167 | # use cv2.LMEDS method for the equivalence to skimage transform
168 | # ref: https://blog.csdn.net/yichxi/article/details/115827338
169 | affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0]
170 | cropped_face = cv2.warpAffine(
171 | img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132)) # gray
172 |
173 | if output_size < transform_size:
174 | cropped_face = cv2.resize(
175 | cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR)
176 |
177 | if return_inverse_affine:
178 | dst_h, dst_w = int(output_size * h_ratio), output_size
179 | template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
180 | # use cv2.LMEDS method for the equivalence to skimage transform
181 | # ref: https://blog.csdn.net/yichxi/article/details/115827338
182 | affine_matrix = cv2.estimateAffinePartial2D(
183 | quad_ori, np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]), method=cv2.LMEDS)[0]
184 | inverse_affine = cv2.invertAffineTransform(affine_matrix)
185 | else:
186 | inverse_affine = None
187 | return cropped_face, inverse_affine
188 |
189 |
190 | def paste_face_back(img, face, inverse_affine):
191 | h, w = img.shape[0:2]
192 | face_h, face_w = face.shape[0:2]
193 | inv_restored = cv2.warpAffine(face, inverse_affine, (w, h))
194 | mask = np.ones((face_h, face_w, 3), dtype=np.float32)
195 | inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h))
196 | # remove the black borders
197 | inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8))
198 | inv_restored_remove_border = inv_mask_erosion * inv_restored
199 | total_face_area = np.sum(inv_mask_erosion) // 3
200 | # compute the fusion edge based on the area of face
201 | w_edge = int(total_face_area**0.5) // 20
202 | erosion_radius = w_edge * 2
203 | inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
204 | blur_size = w_edge * 2
205 | inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
206 | img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img
207 | # float32, [0, 255]
208 | return img
209 |
210 |
211 | if __name__ == '__main__':
212 | import os
213 |
214 | from facexlib.detection import init_detection_model
215 | from facexlib.utils.face_restoration_helper import get_largest_face
216 | from facexlib.visualization import visualize_detection
217 |
218 | img_path = '/home/wxt/datasets/ffhq/ffhq_wild/00009.png'
219 |     img_name = os.path.splitext(os.path.basename(img_path))[0]
220 |
221 | # initialize model
222 | det_net = init_detection_model('retinaface_resnet50', half=False)
223 | img_ori = cv2.imread(img_path)
224 | h, w = img_ori.shape[0:2]
225 | # if larger than 800, scale it
226 | scale = max(h / 800, w / 800)
227 | if scale > 1:
228 | img = cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR)
229 |
230 | with torch.no_grad():
231 | bboxes = det_net.detect_faces(img, 0.97)
232 | if scale > 1:
233 | bboxes *= scale # the score is incorrect
234 | bboxes = get_largest_face(bboxes, h, w)[0]
235 | visualize_detection(img_ori, [bboxes], f'tmp/{img_name}_det.png')
236 |
237 | landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i in range(5, 15, 2)])
238 |
239 | cropped_face, inverse_affine = align_crop_face_landmarks(
240 | img_ori,
241 | landmarks,
242 | output_size=512,
243 | transform_size=None,
244 | enable_padding=True,
245 | return_inverse_affine=True,
246 | shrink_ratio=(1, 1))
247 |
248 |     cv2.imwrite(f'tmp/{img_name}_cropped_face.png', cropped_face)
249 | img = paste_face_back(img_ori, cropped_face, inverse_affine)
250 | cv2.imwrite(f'tmp/{img_name}_back.png', img)
251 |
--------------------------------------------------------------------------------
/facexlib/utils/misc.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import os.path as osp
4 | import torch
5 | from torch.hub import download_url_to_file, get_dir
6 | from urllib.parse import urlparse
7 |
8 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
9 |
10 |
11 | def imwrite(img, file_path, params=None, auto_mkdir=True):
12 | """Write image to file.
13 |
14 | Args:
15 | img (ndarray): Image array to be written.
16 | file_path (str): Image file path.
17 | params (None or list): Same as opencv's :func:`imwrite` interface.
18 | auto_mkdir (bool): If the parent folder of `file_path` does not exist,
19 | whether to create it automatically.
20 |
21 | Returns:
22 | bool: Successful or not.
23 | """
24 | if auto_mkdir:
25 | dir_name = os.path.abspath(os.path.dirname(file_path))
26 | os.makedirs(dir_name, exist_ok=True)
27 | return cv2.imwrite(file_path, img, params)
28 |
29 |
30 | def img2tensor(imgs, bgr2rgb=True, float32=True):
31 | """Numpy array to tensor.
32 |
33 | Args:
34 | imgs (list[ndarray] | ndarray): Input images.
35 | bgr2rgb (bool): Whether to change bgr to rgb.
36 | float32 (bool): Whether to change to float32.
37 |
38 | Returns:
39 | list[tensor] | tensor: Tensor images. If returned results only have
40 | one element, just return tensor.
41 | """
42 |
43 | def _totensor(img, bgr2rgb, float32):
44 | if img.shape[2] == 3 and bgr2rgb:
45 | if img.dtype == 'float64':
46 | img = img.astype('float32')
47 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
48 | img = torch.from_numpy(img.transpose(2, 0, 1))
49 | if float32:
50 | img = img.float()
51 | return img
52 |
53 | if isinstance(imgs, list):
54 | return [_totensor(img, bgr2rgb, float32) for img in imgs]
55 | else:
56 | return _totensor(imgs, bgr2rgb, float32)
57 |
58 |
59 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None, save_dir=None):
60 |     """Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
61 | """
62 | if model_dir is None:
63 | hub_dir = get_dir()
64 | model_dir = os.path.join(hub_dir, 'checkpoints')
65 |
66 | if save_dir is None:
67 | save_dir = os.path.join(ROOT_DIR, model_dir)
68 | os.makedirs(save_dir, exist_ok=True)
69 |
70 | parts = urlparse(url)
71 | filename = os.path.basename(parts.path)
72 | if file_name is not None:
73 | filename = file_name
74 | cached_file = os.path.abspath(os.path.join(save_dir, filename))
75 | if not os.path.exists(cached_file):
76 | print(f'Downloading: "{url}" to {cached_file}\n')
77 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
78 | return cached_file
79 |
80 |
81 | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
82 | """Scan a directory to find the interested files.
83 | Args:
84 | dir_path (str): Path of the directory.
85 | suffix (str | tuple(str), optional): File suffix that we are
86 | interested in. Default: None.
87 | recursive (bool, optional): If set to True, recursively scan the
88 | directory. Default: False.
89 | full_path (bool, optional): If set to True, include the dir_path.
90 | Default: False.
91 | Returns:
92 | A generator for all the interested files with relative paths.
93 | """
94 |
95 | if (suffix is not None) and not isinstance(suffix, (str, tuple)):
96 | raise TypeError('"suffix" must be a string or tuple of strings')
97 |
98 | root = dir_path
99 |
100 | def _scandir(dir_path, suffix, recursive):
101 | for entry in os.scandir(dir_path):
102 | if not entry.name.startswith('.') and entry.is_file():
103 | if full_path:
104 | return_path = entry.path
105 | else:
106 | return_path = osp.relpath(entry.path, root)
107 |
108 | if suffix is None:
109 | yield return_path
110 | elif return_path.endswith(suffix):
111 | yield return_path
112 | else:
113 | if recursive:
114 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
115 | else:
116 | continue
117 |
118 | return _scandir(dir_path, suffix=suffix, recursive=recursive)
119 |
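Two small usage sketches for the helpers above; the checkpoint URL below is a placeholder, not a real release asset:

# Collect all png files under `assets/` (paths relative to that folder).
pngs = list(scandir('assets', suffix='.png', recursive=True))

# Download a checkpoint into facexlib/weights (or reuse a cached copy).
# ckpt_path = load_file_from_url('https://example.com/model.pth', model_dir='facexlib/weights')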
--------------------------------------------------------------------------------
/facexlib/visualization/__init__.py:
--------------------------------------------------------------------------------
1 | from .vis_alignment import visualize_alignment
2 | from .vis_detection import visualize_detection
3 | from .vis_headpose import visualize_headpose
4 |
5 | __all__ = ['visualize_detection', 'visualize_alignment', 'visualize_headpose']
6 |
--------------------------------------------------------------------------------
/facexlib/visualization/vis_alignment.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | def visualize_alignment(img, landmarks, save_path=None, to_bgr=False):
6 | img = np.copy(img)
7 | h, w = img.shape[0:2]
8 | circle_size = int(max(h, w) / 150)
9 | if to_bgr:
10 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
11 |
12 | for landmarks_face in landmarks:
13 | for lm in landmarks_face:
14 | cv2.circle(img, (int(lm[0]), int(lm[1])), 1, (0, 150, 0), circle_size)
15 |
16 | # save img
17 | if save_path is not None:
18 | cv2.imwrite(save_path, img)
19 |
--------------------------------------------------------------------------------
/facexlib/visualization/vis_detection.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | def visualize_detection(img, bboxes_and_landmarks, save_path=None, to_bgr=False):
6 | """Visualize detection results.
7 |
8 | Args:
9 |         img (Numpy array): Input image. HWC, BGR, [0, 255], uint8.
10 | """
11 | img = np.copy(img)
12 | if to_bgr:
13 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
14 |
15 | for b in bboxes_and_landmarks:
16 | # confidence
17 | cv2.putText(img, f'{b[4]:.4f}', (int(b[0]), int(b[1] + 12)), cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255))
18 | # bounding boxes
19 | b = list(map(int, b))
20 | cv2.rectangle(img, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
21 | # landmarks (for retinaface)
22 | cv2.circle(img, (b[5], b[6]), 1, (0, 0, 255), 4)
23 | cv2.circle(img, (b[7], b[8]), 1, (0, 255, 255), 4)
24 | cv2.circle(img, (b[9], b[10]), 1, (255, 0, 255), 4)
25 | cv2.circle(img, (b[11], b[12]), 1, (0, 255, 0), 4)
26 | cv2.circle(img, (b[13], b[14]), 1, (255, 0, 0), 4)
27 | # save img
28 | if save_path is not None:
29 | cv2.imwrite(save_path, img)
30 |
--------------------------------------------------------------------------------
/facexlib/visualization/vis_headpose.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from math import cos, sin
4 |
5 |
6 | def draw_axis(img, yaw, pitch, roll, tdx=None, tdy=None, size=100):
7 | """draw head pose axis."""
8 |
9 | pitch = pitch * np.pi / 180
10 | yaw = -yaw * np.pi / 180
11 | roll = roll * np.pi / 180
12 |
13 | if tdx is None or tdy is None:
14 | height, width = img.shape[:2]
15 | tdx = width / 2
16 | tdy = height / 2
17 |
18 | # X axis pointing to right, drawn in red
19 | x1 = size * (cos(yaw) * cos(roll)) + tdx
20 | y1 = size * (cos(pitch) * sin(roll) + cos(roll) * sin(pitch) * sin(yaw)) + tdy
21 | # Y axis pointing downside, drawn in green
22 | x2 = size * (-cos(yaw) * sin(roll)) + tdx
23 | y2 = size * (cos(pitch) * cos(roll) - sin(pitch) * sin(yaw) * sin(roll)) + tdy
24 | # Z axis, out of the screen, drawn in blue
25 | x3 = size * (sin(yaw)) + tdx
26 | y3 = size * (-cos(yaw) * sin(pitch)) + tdy
27 |
28 | cv2.line(img, (int(tdx), int(tdy)), (int(x1), int(y1)), (0, 0, 255), 3)
29 | cv2.line(img, (int(tdx), int(tdy)), (int(x2), int(y2)), (0, 255, 0), 3)
30 | cv2.line(img, (int(tdx), int(tdy)), (int(x3), int(y3)), (255, 0, 0), 2)
31 |
32 | return img
33 |
34 |
35 | def draw_pose_cube(img, yaw, pitch, roll, tdx=None, tdy=None, size=150.):
36 | """draw head pose cube.
37 | Where (tdx, tdy) is the translation of the face.
38 | For pose we have [pitch yaw roll tdx tdy tdz scale_factor]
39 | """
40 |
41 | p = pitch * np.pi / 180
42 | y = -yaw * np.pi / 180
43 | r = roll * np.pi / 180
44 | if tdx is not None and tdy is not None:
45 | face_x = tdx - 0.50 * size
46 | face_y = tdy - 0.50 * size
47 | else:
48 | height, width = img.shape[:2]
49 | face_x = width / 2 - 0.5 * size
50 | face_y = height / 2 - 0.5 * size
51 |
52 | x1 = size * (cos(y) * cos(r)) + face_x
53 | y1 = size * (cos(p) * sin(r) + cos(r) * sin(p) * sin(y)) + face_y
54 | x2 = size * (-cos(y) * sin(r)) + face_x
55 | y2 = size * (cos(p) * cos(r) - sin(p) * sin(y) * sin(r)) + face_y
56 | x3 = size * (sin(y)) + face_x
57 | y3 = size * (-cos(y) * sin(p)) + face_y
58 |
59 | # Draw base in red
60 | cv2.line(img, (int(face_x), int(face_y)), (int(x1), int(y1)), (0, 0, 255), 3)
61 | cv2.line(img, (int(face_x), int(face_y)), (int(x2), int(y2)), (0, 0, 255), 3)
62 | cv2.line(img, (int(x2), int(y2)), (int(x2 + x1 - face_x), int(y2 + y1 - face_y)), (0, 0, 255), 3)
63 | cv2.line(img, (int(x1), int(y1)), (int(x1 + x2 - face_x), int(y1 + y2 - face_y)), (0, 0, 255), 3)
64 | # Draw pillars in blue
65 | cv2.line(img, (int(face_x), int(face_y)), (int(x3), int(y3)), (255, 0, 0), 2)
66 | cv2.line(img, (int(x1), int(y1)), (int(x1 + x3 - face_x), int(y1 + y3 - face_y)), (255, 0, 0), 2)
67 | cv2.line(img, (int(x2), int(y2)), (int(x2 + x3 - face_x), int(y2 + y3 - face_y)), (255, 0, 0), 2)
68 | cv2.line(img, (int(x2 + x1 - face_x), int(y2 + y1 - face_y)),
69 | (int(x3 + x1 + x2 - 2 * face_x), int(y3 + y2 + y1 - 2 * face_y)), (255, 0, 0), 2)
70 | # Draw top in green
71 | cv2.line(img, (int(x3 + x1 - face_x), int(y3 + y1 - face_y)),
72 | (int(x3 + x1 + x2 - 2 * face_x), int(y3 + y2 + y1 - 2 * face_y)), (0, 255, 0), 2)
73 | cv2.line(img, (int(x2 + x3 - face_x), int(y2 + y3 - face_y)),
74 | (int(x3 + x1 + x2 - 2 * face_x), int(y3 + y2 + y1 - 2 * face_y)), (0, 255, 0), 2)
75 | cv2.line(img, (int(x3), int(y3)), (int(x3 + x1 - face_x), int(y3 + y1 - face_y)), (0, 255, 0), 2)
76 | cv2.line(img, (int(x3), int(y3)), (int(x3 + x2 - face_x), int(y3 + y2 - face_y)), (0, 255, 0), 2)
77 |
78 | return img
79 |
80 |
81 | def visualize_headpose(img, yaw, pitch, roll, save_path=None, to_bgr=False):
82 | img = np.copy(img)
83 | if to_bgr:
84 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
85 | show_string = (f'y {yaw[0].item():.2f}, p {pitch[0].item():.2f}, ' + f'r {roll[0].item():.2f}')
86 | cv2.putText(img, show_string, (30, img.shape[0] - 30), fontFace=1, fontScale=1, color=(0, 0, 255), thickness=2)
87 | draw_pose_cube(img, yaw[0], pitch[0], roll[0], size=100)
88 | draw_axis(img, yaw[0], pitch[0], roll[0], tdx=50, tdy=50, size=100)
89 | # save img
90 | if save_path is not None:
91 | cv2.imwrite(save_path, img)
92 |
--------------------------------------------------------------------------------
/facexlib/weights/README.md:
--------------------------------------------------------------------------------
1 | # Weights
2 |
3 | Put the downloaded weights to this folder.
4 |
--------------------------------------------------------------------------------
/inference/inference_alignment.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import torch
4 |
5 | from facexlib.alignment import init_alignment_model, landmark_98_to_68
6 | from facexlib.visualization import visualize_alignment
7 |
8 |
9 | def main(args):
10 | # initialize model
11 | align_net = init_alignment_model(args.model_name, device=args.device)
12 |
13 | img = cv2.imread(args.img_path)
14 | with torch.no_grad():
15 | landmarks = align_net.get_landmarks(img)
16 | if args.to68:
17 | landmarks = landmark_98_to_68(landmarks)
18 | visualize_alignment(img, [landmarks], args.save_path)
19 |
20 |
21 | if __name__ == '__main__':
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--img_path', type=str, default='assets/test2.jpg')
24 | parser.add_argument('--save_path', type=str, default='test_alignment.png')
25 | parser.add_argument('--model_name', type=str, default='awing_fan')
26 | parser.add_argument('--device', type=str, default='cuda')
27 | parser.add_argument('--to68', action='store_true')
28 | args = parser.parse_args()
29 |
30 | main(args)
31 |
--------------------------------------------------------------------------------
/inference/inference_crop_standard_faces.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 |
4 | from facexlib.detection import init_detection_model
5 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper
6 |
7 | input_img = '/home/wxt/datasets/ffhq/ffhq_wild/00028.png'
8 | # initialize face helper
9 | face_helper = FaceRestoreHelper(
10 | upscale_factor=1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png')
11 |
12 | face_helper.clean_all()
13 |
14 | det_net = init_detection_model('retinaface_resnet50', half=False)
15 | img = cv2.imread(input_img)
16 | with torch.no_grad():
17 | bboxes = det_net.detect_faces(img, 0.97)
18 | # x0, y0, x1, y1, confidence_score, five points (x, y)
19 | print(bboxes.shape)
20 | bboxes = bboxes[3]
21 |
22 | bboxes[0] -= 100
23 | bboxes[1] -= 100
24 | bboxes[2] += 100
25 | bboxes[3] += 100
26 | img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]
27 |
28 | face_helper.read_image(img)
29 | # get face landmarks for each face
30 | face_helper.get_face_landmarks_5(only_center_face=True, pad_blur=False)
31 | # align and warp each face
32 | # save_crop_path = os.path.join(save_root, 'cropped_faces', img_name)
33 | save_crop_path = '00028_cvwarp.png'
34 | face_helper.align_warp_face(save_crop_path)
35 |
36 | # for i in range(50):
37 | # img = cv2.imread(f'inputs/ffhq_512/{i:08d}.png')
38 | # cv2.circle(img, (193, 240), 1, (0, 0, 255), 4)
39 | # cv2.circle(img, (319, 240), 1, (0, 255, 255), 4)
40 | # cv2.circle(img, (257, 314), 1, (255, 0, 255), 4)
41 | # cv2.circle(img, (201, 371), 1, (0, 255, 0), 4)
42 | # cv2.circle(img, (313, 371), 1, (255, 0, 0), 4)
43 |
44 | # cv2.imwrite(f'ffhq_lm/{i:08d}_lm.png', img)
45 |
46 | # [875.5 719.83333333] [1192.5 715.66666667] [1060. 997.]
47 |
--------------------------------------------------------------------------------
/inference/inference_detection.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import torch
4 |
5 | from facexlib.detection import init_detection_model
6 | from facexlib.visualization import visualize_detection
7 |
8 |
9 | def main(args):
10 | # initialize model
11 | det_net = init_detection_model(args.model_name, half=args.half)
12 |
13 | img = cv2.imread(args.img_path)
14 | with torch.no_grad():
15 | bboxes = det_net.detect_faces(img, 0.97)
16 | # x0, y0, x1, y1, confidence_score, five points (x, y)
17 | print(bboxes)
18 | visualize_detection(img, bboxes, args.save_path)
19 |
20 |
21 | if __name__ == '__main__':
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--img_path', type=str, default='assets/test.jpg')
24 | parser.add_argument('--save_path', type=str, default='test_detection.png')
25 | parser.add_argument(
26 | '--model_name', type=str, default='retinaface_resnet50', help='retinaface_resnet50 | retinaface_mobile0.25')
27 | parser.add_argument('--half', action='store_true')
28 | args = parser.parse_args()
29 |
30 | main(args)
31 |
--------------------------------------------------------------------------------
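Each row returned by detect_faces packs the box, the confidence score and five landmark points, as the inline comment notes. A small sketch of unpacking one row under that assumption:

import cv2
import numpy as np
import torch

from facexlib.detection import init_detection_model

det_net = init_detection_model('retinaface_resnet50', half=False)
img = cv2.imread('assets/test.jpg')  # illustrative path
with torch.no_grad():
    bboxes = det_net.detect_faces(img, 0.97)

for bbox in bboxes:
    x0, y0, x1, y1 = bbox[0:4]                      # face rectangle
    score = bbox[4]                                 # confidence score
    landmarks = np.array(bbox[5:15]).reshape(5, 2)  # five (x, y) points
    print(score, (x0, y0, x1, y1), landmarks.shape)
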
/inference/inference_headpose.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import numpy as np
4 | import torch
5 | from torchvision.transforms.functional import normalize
6 |
7 | from facexlib.detection import init_detection_model
8 | from facexlib.headpose import init_headpose_model
9 | from facexlib.utils.misc import img2tensor
10 | from facexlib.visualization import visualize_headpose
11 |
12 |
13 | def main(args):
14 | # initialize model
15 | det_net = init_detection_model(args.detection_model_name, half=args.half)
16 | headpose_net = init_headpose_model(args.headpose_model_name, half=args.half)
17 |
18 | img = cv2.imread(args.img_path)
19 | with torch.no_grad():
20 | bboxes = det_net.detect_faces(img, 0.97)
21 | # x0, y0, x1, y1, confidence_score, five points (x, y)
22 | bbox = list(map(int, bboxes[0]))
23 | # crop face region
24 | thld = 10
25 | h, w, _ = img.shape
26 | top = max(bbox[1] - thld, 0)
27 | bottom = min(bbox[3] + thld, h)
28 | left = max(bbox[0] - thld, 0)
29 | right = min(bbox[2] + thld, w)
30 |
31 | det_face = img[top:bottom, left:right, :].astype(np.float32) / 255.
32 |
33 | # resize
34 | det_face = cv2.resize(det_face, (224, 224), interpolation=cv2.INTER_LINEAR)
35 | det_face = img2tensor(np.copy(det_face), bgr2rgb=False)
36 |
37 | # normalize
38 | normalize(det_face, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225], inplace=True)
39 | det_face = det_face.unsqueeze(0).cuda()
40 |
41 | yaw, pitch, roll = headpose_net(det_face)
42 | visualize_headpose(img, yaw, pitch, roll, args.save_path)
43 |
44 |
45 | if __name__ == '__main__':
46 | parser = argparse.ArgumentParser(description='Head pose estimation using the Hopenet network.')
47 | parser.add_argument('--img_path', type=str, default='assets/test.jpg')
48 | parser.add_argument('--save_path', type=str, default='assets/test_headpose.png')
49 | parser.add_argument('--detection_model_name', type=str, default='retinaface_resnet50')
50 | parser.add_argument('--headpose_model_name', type=str, default='hopenet')
51 | parser.add_argument('--half', action='store_true')
52 | args = parser.parse_args()
53 |
54 | main(args)
55 |
--------------------------------------------------------------------------------
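The padded crop in the script clamps to the image borders before resizing to 224 x 224. A worked example of that arithmetic with illustrative numbers:

# Suppose the detector returns bbox = [100, 150, 300, 400] on a 512 x 512 image.
thld = 10
h, w = 512, 512
bbox = [100, 150, 300, 400]
top = max(bbox[1] - thld, 0)     # 140
bottom = min(bbox[3] + thld, h)  # 410
left = max(bbox[0] - thld, 0)    # 90
right = min(bbox[2] + thld, w)   # 310
# The face crop is img[140:410, 90:310, :], then resized to 224 x 224.
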
/inference/inference_hyperiqa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import numpy as np
4 | import os
5 | import torch
6 | import torchvision
7 | from PIL import Image
8 |
9 | from facexlib.assessment import init_assessment_model
10 | from facexlib.detection import init_detection_model
11 |
12 |
13 | def main(args):
14 |     """Script for evaluating face quality.
15 | Two steps:
16 | 1) detect the face region and crop the face
17 | 2) evaluate the face quality by hyperIQA
18 | """
19 | # initialize model
20 | det_net = init_detection_model(args.detection_model_name, half=False)
21 | assess_net = init_assessment_model(args.assess_model_name, half=False)
22 |
23 |     # the image transformation specified in the original hyperIQA
24 | transforms = torchvision.transforms.Compose([
25 | torchvision.transforms.Resize((512, 384)),
26 | torchvision.transforms.RandomCrop(size=224),
27 | torchvision.transforms.ToTensor(),
28 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
29 | ])
30 |
31 | img = cv2.imread(args.img_path)
32 | img_name = os.path.basename(args.img_path)
33 | basename, _ = os.path.splitext(img_name)
34 | with torch.no_grad():
35 | bboxes = det_net.detect_faces(img, 0.97)
36 | box = list(map(int, bboxes[0]))
37 | pred_scores = []
38 |         # BGR -> RGB
39 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
40 |
41 |         for i in range(10):  # average over 10 random 224x224 crops to stabilize the score
42 | detect_face = img[box[1]:box[3], box[0]:box[2], :]
43 | detect_face = Image.fromarray(detect_face)
44 |
45 | detect_face = transforms(detect_face)
46 |             detect_face = detect_face.unsqueeze(0).cuda()
47 |
48 | pred = assess_net(detect_face)
49 | pred_scores.append(float(pred.item()))
50 | score = np.mean(pred_scores)
51 |         # the quality score ranges from 0 to 100; a higher score indicates better quality
52 | print(f'{basename} {score:.4f}')
53 |
54 |
55 | if __name__ == '__main__':
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument('--img_path', type=str, default='assets/test2.jpg')
58 | parser.add_argument('--detection_model_name', type=str, default='retinaface_resnet50')
59 | parser.add_argument('--assess_model_name', type=str, default='hypernet')
60 | parser.add_argument('--half', action='store_true')
61 | args = parser.parse_args()
62 |
63 | main(args)
64 |
--------------------------------------------------------------------------------
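Because RandomCrop samples a different 224 x 224 patch on every pass, the averaged score varies slightly between runs. A hedged sketch of a deterministic variant (not what the script above does; CenterCrop simply replaces the random crop):

import torchvision

deterministic_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((512, 384)),
    torchvision.transforms.CenterCrop(size=224),  # fixed central patch instead of a random one
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])
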
/inference/inference_matting.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from torchvision.transforms.functional import normalize
6 |
7 | from facexlib.matting import init_matting_model
8 | from facexlib.utils import img2tensor
9 |
10 |
11 | def main(args):
12 | modnet = init_matting_model()
13 |
14 | # read image
15 | img = cv2.imread(args.img_path) / 255.
16 | # unify image channels to 3
17 | if len(img.shape) == 2:
18 | img = img[:, :, None]
19 | if img.shape[2] == 1:
20 | img = np.repeat(img, 3, axis=2)
21 | elif img.shape[2] == 4:
22 | img = img[:, :, 0:3]
23 |
24 | img_t = img2tensor(img, bgr2rgb=True, float32=True)
25 | normalize(img_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
26 | img_t = img_t.unsqueeze(0).cuda()
27 |
28 | # resize image for input
29 | _, _, im_h, im_w = img_t.shape
30 | ref_size = 512
31 | if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
32 | if im_w >= im_h:
33 | im_rh = ref_size
34 | im_rw = int(im_w / im_h * ref_size)
35 | elif im_w < im_h:
36 | im_rw = ref_size
37 | im_rh = int(im_h / im_w * ref_size)
38 | else:
39 | im_rh = im_h
40 | im_rw = im_w
41 | im_rw = im_rw - im_rw % 32
42 | im_rh = im_rh - im_rh % 32
43 | img_t = F.interpolate(img_t, size=(im_rh, im_rw), mode='area')
44 |
45 | # inference
46 | _, _, matte = modnet(img_t, True)
47 |
48 | # resize and save matte
49 | matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
50 | matte = matte[0][0].data.cpu().numpy()
51 | cv2.imwrite(args.save_path, (matte * 255).astype('uint8'))
52 |
53 | # get foreground
54 | matte = matte[:, :, None]
55 | foreground = img * matte + np.full(img.shape, 1) * (1 - matte)
56 | cv2.imwrite(args.save_path.replace('.png', '_fg.png'), foreground * 255)
57 |
58 |
59 | if __name__ == '__main__':
60 | parser = argparse.ArgumentParser()
61 | parser.add_argument('--img_path', type=str, default='assets/test.jpg')
62 | parser.add_argument('--save_path', type=str, default='test_matting.png')
63 | args = parser.parse_args()
64 |
65 | main(args)
66 |
--------------------------------------------------------------------------------
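The resize block scales the shorter side to ref_size and then snaps both sides down to a multiple of 32 (presumably the stride the network expects). A worked example for an illustrative 1920 x 1080 input:

ref_size = 512
im_h, im_w = 1920, 1080              # illustrative input size (h, w)
# min(im_h, im_w) > ref_size and im_w < im_h, so the width becomes ref_size:
im_rw = ref_size                     # 512
im_rh = int(im_h / im_w * ref_size)  # int(1920 / 1080 * 512) = 910
im_rw = im_rw - im_rw % 32           # 512
im_rh = im_rh - im_rh % 32           # 896
# The tensor is interpolated to 896 x 512 for inference; the matte is later
# resized back to the original 1920 x 1080.
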
/inference/inference_parsing.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import numpy as np
4 | import os
5 | import torch
6 | from torchvision.transforms.functional import normalize
7 |
8 | from facexlib.parsing import init_parsing_model
9 | from facexlib.utils.misc import img2tensor
10 |
11 |
12 | def vis_parsing_maps(img, parsing_anno, stride, save_anno_path=None, save_vis_path=None):
13 |     # Colors for the parsing classes (indexed by class id)
14 | part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 0, 85], [255, 0, 170], [0, 255, 0], [85, 255, 0],
15 | [170, 255, 0], [0, 255, 85], [0, 255, 170], [0, 0, 255], [85, 0, 255], [170, 0, 255], [0, 85, 255],
16 | [0, 170, 255], [255, 255, 0], [255, 255, 85], [255, 255, 170], [255, 0, 255], [255, 85, 255],
17 | [255, 170, 255], [0, 255, 255], [85, 255, 255], [170, 255, 255]]
18 | # 0: 'background'
19 | # attributions = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye',
20 | # 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r', 10 'nose',
21 | # 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l',
22 | # 16 'cloth', 17 'hair', 18 'hat']
23 | vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
24 | vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
25 | if save_anno_path is not None:
26 | cv2.imwrite(save_anno_path, vis_parsing_anno)
27 |
28 | if save_vis_path is not None:
29 | vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
30 | num_of_class = np.max(vis_parsing_anno)
31 | for pi in range(1, num_of_class + 1):
32 | index = np.where(vis_parsing_anno == pi)
33 | vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
34 |
35 | vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
36 | vis_im = cv2.addWeighted(img, 0.4, vis_parsing_anno_color, 0.6, 0)
37 |
38 | cv2.imwrite(save_vis_path, vis_im)
39 |
40 |
41 | def main(img_path, output):
42 | net = init_parsing_model(model_name='bisenet')
43 |
44 | img_name = os.path.basename(img_path)
45 | img_basename = os.path.splitext(img_name)[0]
46 |
47 | img_input = cv2.imread(img_path)
48 | img_input = cv2.resize(img_input, (512, 512), interpolation=cv2.INTER_LINEAR)
49 | img = img2tensor(img_input.astype('float32') / 255., bgr2rgb=True, float32=True)
50 | normalize(img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True)
51 | img = torch.unsqueeze(img, 0).cuda()
52 |
53 | with torch.no_grad():
54 | out = net(img)[0]
55 | out = out.squeeze(0).cpu().numpy().argmax(0)
56 |
57 | vis_parsing_maps(
58 | img_input,
59 | out,
60 | stride=1,
61 | save_anno_path=os.path.join(output, f'{img_basename}.png'),
62 | save_vis_path=os.path.join(output, f'{img_basename}_vis.png'))
63 |
64 |
65 | if __name__ == '__main__':
66 | parser = argparse.ArgumentParser()
67 |
68 | parser.add_argument('--input', type=str, default='datasets/ffhq/ffhq_512/00000000.png')
69 | parser.add_argument('--output', type=str, default='results', help='output folder')
70 | args = parser.parse_args()
71 |
72 | os.makedirs(args.output, exist_ok=True)
73 |
74 | main(args.input, args.output)
75 |
--------------------------------------------------------------------------------
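The argmax output `out` is a 512 x 512 label map, with class indices as listed in the comment inside vis_parsing_maps. A small sketch of pulling out a single-class mask (index 1 is 'skin' per that comment); the zero array is only a placeholder for the real label map:

import cv2
import numpy as np

out = np.zeros((512, 512), dtype=np.int64)     # placeholder; use the label map from main()
skin_mask = (out == 1).astype(np.uint8) * 255  # binary mask for class 1 ('skin')
cv2.imwrite('skin_mask.png', skin_mask)
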
/inference/inference_parsing_parsenet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import numpy as np
4 | import os
5 | import torch
6 | from torchvision.transforms.functional import normalize
7 |
8 | from facexlib.parsing import init_parsing_model
9 | from facexlib.utils.misc import img2tensor
10 |
11 |
12 | def vis_parsing_maps(img, parsing_anno, stride, save_anno_path=None, save_vis_path=None):
13 | # Colors for all parts
14 | part_colors = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255],
15 | [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204],
16 | [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]]
17 | # 0: 'background' 1: 'skin' 2: 'nose'
18 | # 3: 'eye_g' 4: 'l_eye' 5: 'r_eye'
19 | # 6: 'l_brow' 7: 'r_brow' 8: 'l_ear'
20 | # 9: 'r_ear' 10: 'mouth' 11: 'u_lip'
21 | # 12: 'l_lip' 13: 'hair' 14: 'hat'
22 | # 15: 'ear_r' 16: 'neck_l' 17: 'neck'
23 | # 18: 'cloth'
24 | vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
25 | vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
26 | if save_anno_path is not None:
27 | cv2.imwrite(save_anno_path, vis_parsing_anno)
28 |
29 | if save_vis_path is not None:
30 | vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
31 | num_of_class = np.max(vis_parsing_anno)
32 | for pi in range(1, num_of_class + 1):
33 | index = np.where(vis_parsing_anno == pi)
34 | vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
35 |
36 | vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
37 | vis_im = cv2.addWeighted(img, 0.4, vis_parsing_anno_color, 0.6, 0)
38 |
39 | cv2.imwrite(save_vis_path, vis_im)
40 |
41 |
42 | def main(img_path, output):
43 | net = init_parsing_model(model_name='parsenet')
44 |
45 | img_name = os.path.basename(img_path)
46 | img_basename = os.path.splitext(img_name)[0]
47 |
48 | img_input = cv2.imread(img_path)
49 | # resize to 512 x 512 for better performance
50 | img_input = cv2.resize(img_input, (512, 512), interpolation=cv2.INTER_LINEAR)
51 | img = img2tensor(img_input.astype('float32') / 255., bgr2rgb=True, float32=True)
52 | normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
53 | img = torch.unsqueeze(img, 0).cuda()
54 |
55 | with torch.no_grad():
56 | out = net(img)[0]
57 | out = out.squeeze(0).cpu().numpy().argmax(0)
58 |
59 | vis_parsing_maps(
60 | img_input,
61 | out,
62 | stride=1,
63 | save_anno_path=os.path.join(output, f'{img_basename}.png'),
64 | save_vis_path=os.path.join(output, f'{img_basename}_vis.png'))
65 |
66 |
67 | if __name__ == '__main__':
68 | parser = argparse.ArgumentParser()
69 |
70 | parser.add_argument('--input', type=str, default='datasets/ffhq/ffhq_512/00000000.png')
71 | parser.add_argument('--output', type=str, default='results', help='output folder')
72 | args = parser.parse_args()
73 |
74 | os.makedirs(args.output, exist_ok=True)
75 |
76 | main(args.input, args.output)
77 |
--------------------------------------------------------------------------------
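One pitfall: the two parsing scripts use different label orderings, so masks are not interchangeable by index. The values below are copied from the comments in the two scripts above.

# e.g. 'nose' is index 10 for bisenet but 2 for parsenet,
# and 'hair' is index 17 for bisenet but 13 for parsenet.
BISENET_NOSE, PARSENET_NOSE = 10, 2
BISENET_HAIR, PARSENET_HAIR = 17, 13
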
/inference/inference_recognition.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import math
4 | import numpy as np
5 | import os
6 | import torch
7 |
8 | from facexlib.recognition import ResNetArcFace, cosin_metric, load_image
9 |
10 | if __name__ == '__main__':
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--folder1', type=str)
13 | parser.add_argument('--folder2', type=str)
14 | parser.add_argument('--model_path', type=str, default='facexlib/recognition/weights/arcface_resnet18.pth')
15 |
16 | args = parser.parse_args()
17 |
18 | img_list1 = sorted(glob.glob(os.path.join(args.folder1, '*')))
19 | img_list2 = sorted(glob.glob(os.path.join(args.folder2, '*')))
20 | print(img_list1, img_list2)
21 | model = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=False)
22 | model.load_state_dict(torch.load(args.model_path))
23 | model.to(torch.device('cuda'))
24 | model.eval()
25 |
26 | dist_list = []
27 | identical_count = 0
28 | for idx, (img_path1, img_path2) in enumerate(zip(img_list1, img_list2)):
29 | basename = os.path.splitext(os.path.basename(img_path1))[0]
30 | img1 = load_image(img_path1)
31 | img2 = load_image(img_path2)
32 |
33 | data = torch.stack([img1, img2], dim=0)
34 | data = data.to(torch.device('cuda'))
35 | output = model(data)
36 | print(output.size())
37 | output = output.data.cpu().numpy()
38 | dist = cosin_metric(output[0], output[1])
39 | dist = np.arccos(dist) / math.pi * 180
40 |         print(f'{idx} - {dist:.2f} degrees : {basename}')
41 | if dist < 1:
42 | print(f'{basename} is almost identical to original.')
43 | identical_count += 1
44 | else:
45 | dist_list.append(dist)
46 |
47 | print(f'Result dist: {sum(dist_list) / len(dist_list):.6f}')
48 | print(f'identical count: {identical_count}')
49 |
--------------------------------------------------------------------------------
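The script turns the cosine similarity of the two ArcFace embeddings into an angular distance in degrees. A worked example of that conversion with an illustrative similarity value:

import math
import numpy as np

cos_sim = 0.9                               # illustrative cosine similarity
angle = np.arccos(cos_sim) / math.pi * 180  # ~25.84 degrees
print(f'{angle:.2f}')
# A distance below 1 degree is treated as "almost identical" by the script above.
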
/inference/inference_tracking.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import glob
4 | import numpy as np
5 | import os
6 | import torch
7 | from tqdm import tqdm
8 |
9 | from facexlib.detection import init_detection_model
10 | from facexlib.tracking.sort import SORT
11 |
12 |
13 | def main(args):
14 | detect_interval = args.detect_interval
15 | margin = args.margin
16 | face_score_threshold = args.face_score_threshold
17 |
18 |     save_frame = args.save_folder is not None
19 | if save_frame:
20 | colors = np.random.rand(32, 3)
21 |
22 | # init detection model and tracker
23 | det_net = init_detection_model('retinaface_resnet50', half=False)
24 | tracker = SORT(max_age=1, min_hits=2, iou_threshold=0.2)
25 |     print('Start tracking...')
26 |
27 | # track over all frames
28 | frame_paths = sorted(glob.glob(os.path.join(args.input_folder, '*.jpg')))
29 | pbar = tqdm(total=len(frame_paths), unit='frames', desc='Extract')
30 | for idx, path in enumerate(frame_paths):
31 | img_basename = os.path.basename(path)
32 | frame = cv2.imread(path)
33 | img_size = frame.shape[0:2]
34 |
35 |         # detect face bboxes
36 | with torch.no_grad():
37 | bboxes = det_net.detect_faces(frame, 0.97)
38 |
39 | additional_attr = []
40 | face_list = []
41 |
42 | for idx_bb, bbox in enumerate(bboxes):
43 | score = bbox[4]
44 | if score > face_score_threshold:
45 | bbox = bbox[0:5]
46 | det = bbox[0:4]
47 |
48 | # face rectangle
49 | det[0] = np.maximum(det[0] - margin, 0)
50 | det[1] = np.maximum(det[1] - margin, 0)
51 | det[2] = np.minimum(det[2] + margin, img_size[1])
52 | det[3] = np.minimum(det[3] + margin, img_size[0])
53 | face_list.append(bbox)
54 | additional_attr.append([score])
55 | trackers = tracker.update(np.array(face_list), img_size, additional_attr, detect_interval)
56 |
57 | pbar.update(1)
58 | pbar.set_description(f'{idx}: detect {len(bboxes)} faces in {img_basename}')
59 |
60 | # save frame
61 | if save_frame:
62 | for d in trackers:
63 | d = d.astype(np.int32)
64 | cv2.rectangle(frame, (d[0], d[1]), (d[2], d[3]), colors[d[4] % 32, :] * 255, 3)
65 | if len(face_list) != 0:
66 | cv2.putText(frame, 'ID : %d DETECT' % (d[4]), (d[0] - 10, d[1] - 10), cv2.FONT_HERSHEY_SIMPLEX,
67 | 0.75, colors[d[4] % 32, :] * 255, 2)
68 | cv2.putText(frame, 'DETECTOR', (5, 45), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (1, 1, 1), 2)
69 | else:
70 | cv2.putText(frame, 'ID : %d' % (d[4]), (d[0] - 10, d[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.75,
71 | colors[d[4] % 32, :] * 255, 2)
72 | save_path = os.path.join(args.save_folder, img_basename)
73 | cv2.imwrite(save_path, frame)
74 |
75 |
76 | if __name__ == '__main__':
77 | parser = argparse.ArgumentParser()
78 | parser.add_argument('--input_folder', help='Path to the input folder', type=str)
79 | parser.add_argument('--save_folder', help='Path to save visualized frames', type=str, default=None)
80 |
81 | parser.add_argument(
82 | '--detect_interval',
83 |         help=('run the detector every N frames, a trade-off '
84 |               'between speed and tracking smoothness'),
85 | type=int,
86 | default=1)
87 |     # if faces are large in your video, you can use a larger margin for easier tracking
88 | parser.add_argument('--margin', help='add margin for face', type=int, default=20)
89 | parser.add_argument(
90 |         '--face_score_threshold', help='Confidence threshold for keeping detected faces, range 0 < x <= 1', type=float, default=0.85)
91 |
92 | args = parser.parse_args()
93 |     if args.save_folder is not None:
94 |         os.makedirs(args.save_folder, exist_ok=True)
95 |     main(args)
96 |
97 | # add verification
98 | # remove last few frames
99 |
--------------------------------------------------------------------------------
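Each row returned by tracker.update() is consumed above as four box coordinates followed by the track id. A small sketch of unpacking one row, assuming that layout; the numbers are illustrative:

import numpy as np

d = np.array([102.0, 88.0, 230.0, 241.0, 7.0])  # illustrative tracker row
x0, y0, x1, y1 = d[0:4].astype(np.int32)        # tracked face rectangle
track_id = int(d[4])                            # persistent track identifier
print(track_id, (x0, y0, x1, y1))
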
/requirements.txt:
--------------------------------------------------------------------------------
1 | filterpy
2 | numba
3 | numpy
4 | opencv-python
5 | Pillow
6 | scipy
7 | torch
8 | torchvision
9 | tqdm
10 |
--------------------------------------------------------------------------------
/scripts/crop_faces_5landmarks.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | import facexlib.utils.face_restoration_helper as face_restoration_helper
5 |
6 |
7 | def crop_one_img(img, save_cropped_path=None):
8 |     face_helper.clean_all()
9 |     face_helper.read_image(img)
10 |     # get face landmarks
11 |     face_helper.get_face_landmarks_5()
12 |     face_helper.align_warp_face(save_cropped_path)
13 |
14 |
15 | if __name__ == '__main__':
16 | # initialize face helper
17 |     face_helper = face_restoration_helper.FaceRestoreHelper(upscale_factor=1)
18 |
19 | img_paths = glob.glob('/home/wxt/Projects/test/*')
20 | save_path = 'test'
21 | for idx, path in enumerate(img_paths):
22 | print(idx, path)
23 | file_name = os.path.basename(path)
24 | save_cropped_path = os.path.join(save_path, file_name)
25 | crop_one_img(path, save_cropped_path=save_cropped_path)
26 |
--------------------------------------------------------------------------------
/scripts/extract_detection_info_ffhq.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import glob
3 | import numpy as np
4 | import os
5 | import torch
6 | from PIL import Image
7 | from tqdm import tqdm
8 |
9 | from facexlib.detection import init_detection_model
10 |
11 |
12 | def draw_and_save(image, bboxes_and_landmarks, save_path, order_type=1):
13 | """Visualize results
14 | """
15 | if isinstance(image, Image.Image):
16 | image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
17 | image = image.astype(np.float32)
18 | for b in bboxes_and_landmarks:
19 | # confidence
20 | cv2.putText(image, '{:.4f}'.format(b[4]), (int(b[0]), int(b[1] + 12)), cv2.FONT_HERSHEY_DUPLEX, 0.5,
21 | (255, 255, 255))
22 | # bounding boxes
23 | b = list(map(int, b))
24 | cv2.rectangle(image, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
25 | # landmarks
26 | if order_type == 0: # mtcnn
27 | cv2.circle(image, (b[5], b[10]), 1, (0, 0, 255), 4)
28 | cv2.circle(image, (b[6], b[11]), 1, (0, 255, 255), 4)
29 | cv2.circle(image, (b[7], b[12]), 1, (255, 0, 255), 4)
30 | cv2.circle(image, (b[8], b[13]), 1, (0, 255, 0), 4)
31 | cv2.circle(image, (b[9], b[14]), 1, (255, 0, 0), 4)
32 | else: # retinaface, centerface
33 | cv2.circle(image, (b[5], b[6]), 1, (0, 0, 255), 4)
34 | cv2.circle(image, (b[7], b[8]), 1, (0, 255, 255), 4)
35 | cv2.circle(image, (b[9], b[10]), 1, (255, 0, 255), 4)
36 | cv2.circle(image, (b[11], b[12]), 1, (0, 255, 0), 4)
37 | cv2.circle(image, (b[13], b[14]), 1, (255, 0, 0), 4)
38 | # save image
39 | cv2.imwrite(save_path, image)
40 |
41 |
42 | det_net = init_detection_model('retinaface_resnet50')
43 | half = False
44 |
45 | det_net.cuda().eval()
46 | if half:
47 | det_net = det_net.half()
48 |
49 | img_list = sorted(glob.glob('../../BasicSR-private/datasets/ffhq/ffhq_512/*'))
50 |
51 |
52 | def get_center_landmark(landmarks, center):
53 | center = np.array(center)
54 | center_dist = []
55 | for landmark in landmarks:
56 | landmark_center = np.array([(landmark[0] + landmark[2]) / 2, (landmark[1] + landmark[3]) / 2])
57 | dist = np.linalg.norm(landmark_center - center)
58 | center_dist.append(dist)
59 | center_idx = center_dist.index(min(center_dist))
60 | return landmarks[center_idx]
61 |
62 |
63 | pbar = tqdm(total=len(img_list), unit='image')
64 | save_np = []
65 | for idx, path in enumerate(img_list):
66 | img_name = os.path.basename(path)
67 | pbar.update(1)
68 | pbar.set_description(path)
69 | img = Image.open(path)
70 | with torch.no_grad():
71 | bboxes, warped_face_list = det_net.align_multi(img, 0.97, half=half)
72 | if len(bboxes) > 1:
73 | bboxes = [get_center_landmark(bboxes, (256, 256))]
74 | save_np.append(bboxes)
75 | # draw_and_save(img, bboxes, os.path.join('tmp', img_name), 1)
76 | np.save('ffhq_det_info.npy', save_np)
77 |
--------------------------------------------------------------------------------
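get_center_landmark keeps the detection whose box center lies closest to the given point. A small worked example with two illustrative boxes:

import numpy as np

landmarks = [
    np.array([100, 100, 200, 200, 0.99]),  # box center (150, 150)
    np.array([220, 210, 340, 330, 0.98]),  # box center (280, 270)
]
center = np.array([256, 256])
dists = [np.linalg.norm(np.array([(b[0] + b[2]) / 2, (b[1] + b[3]) / 2]) - center) for b in landmarks]
print(int(np.argmin(dists)))  # 1 -- the second box is closer to (256, 256)
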
/scripts/get_ffhq_template.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from PIL import Image
4 |
5 | bboxes = np.load('ffhq_det_info.npy', allow_pickle=True)
6 |
7 | bboxes = np.array(bboxes).squeeze(1)
8 |
9 | bboxes = np.mean(bboxes, axis=0)
10 |
11 | print(bboxes)
12 |
13 |
14 | def draw_and_save(image, bboxes_and_landmarks, save_path, order_type=1):
15 | """Visualize results
16 | """
17 | if isinstance(image, Image.Image):
18 | image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
19 | image = image.astype(np.float32)
20 | for b in bboxes_and_landmarks:
21 | # confidence
22 | cv2.putText(image, '{:.4f}'.format(b[4]), (int(b[0]), int(b[1] + 12)), cv2.FONT_HERSHEY_DUPLEX, 0.5,
23 | (255, 255, 255))
24 | # bounding boxes
25 | b = list(map(int, b))
26 | cv2.rectangle(image, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
27 | # landmarks
28 | if order_type == 0: # mtcnn
29 | cv2.circle(image, (b[5], b[10]), 1, (0, 0, 255), 4)
30 | cv2.circle(image, (b[6], b[11]), 1, (0, 255, 255), 4)
31 | cv2.circle(image, (b[7], b[12]), 1, (255, 0, 255), 4)
32 | cv2.circle(image, (b[8], b[13]), 1, (0, 255, 0), 4)
33 | cv2.circle(image, (b[9], b[14]), 1, (255, 0, 0), 4)
34 | else: # retinaface, centerface
35 | cv2.circle(image, (b[5], b[6]), 1, (0, 0, 255), 4)
36 | cv2.circle(image, (b[7], b[8]), 1, (0, 255, 255), 4)
37 | cv2.circle(image, (b[9], b[10]), 1, (255, 0, 255), 4)
38 | cv2.circle(image, (b[11], b[12]), 1, (0, 255, 0), 4)
39 | cv2.circle(image, (b[13], b[14]), 1, (255, 0, 0), 4)
40 | # save image
41 | cv2.imwrite(save_path, image)
42 |
43 |
44 | img = Image.open('inputs/00000000.png')
45 | # bboxes = np.array([
46 | # 118.177826 * 2, 92.759514 * 2, 394.95926 * 2, 472.53278 * 2, 0.9995705 * 2, # noqa: E501
47 | # 686.77227723, 488.62376238, 586.77227723, 493.59405941, 337.91089109,
48 | # 488.38613861, 437.95049505, 493.51485149, 513.58415842, 678.5049505
49 | # ])
50 | # bboxes = bboxes / 2
51 | draw_and_save(img, [bboxes], 'template_detall.png', 1)
52 |
--------------------------------------------------------------------------------
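The averaged row follows the retinaface ordering used in draw_and_save: box, confidence, then five (x, y) landmark pairs. A sketch of pulling out the five-point template under that assumption; the zero array only stands in for the averaged row printed above:

import numpy as np

mean_row = np.zeros(15)                       # placeholder; use the averaged detection row
template_5pts = mean_row[5:15].reshape(5, 2)  # eyes, nose and mouth-corner template points
print(template_5pts.shape)                    # (5, 2)
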
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore =
3 | # line break before binary operator (W503)
4 | W503,
5 | # line break after binary operator (W504)
6 | W504,
7 | max-line-length=120
8 |
9 | [yapf]
10 | based_on_style = pep8
11 | column_limit = 120
12 | blank_line_before_nested_class_or_def = true
13 | split_before_expression_after_opening_paren = true
14 |
15 | [isort]
16 | line_length = 120
17 | multi_line_output = 0
18 | known_standard_library = pkg_resources,setuptools
19 | known_first_party = facexlib
20 | known_third_party = PIL,cv2,filterpy,numba,numpy,scipy,torch,torchvision,tqdm
21 | no_lines_before = STDLIB,LOCALFOLDER
22 | default_section = THIRDPARTY
23 |
24 | [codespell]
25 | skip = .git,./docs/build,*.cfg
26 | count =
27 | quiet-level = 3
28 | ignore-words-list = mot,ans
29 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import find_packages, setup
4 |
5 | import os
6 | import subprocess
7 | import time
8 |
9 | version_file = 'facexlib/version.py'
10 |
11 |
12 | def readme():
13 | with open('README.md', encoding='utf-8') as f:
14 | content = f.read()
15 | return content
16 |
17 |
18 | def get_git_hash():
19 |
20 | def _minimal_ext_cmd(cmd):
21 | # construct minimal environment
22 | env = {}
23 | for k in ['SYSTEMROOT', 'PATH', 'HOME']:
24 | v = os.environ.get(k)
25 | if v is not None:
26 | env[k] = v
27 | # LANGUAGE is used on win32
28 | env['LANGUAGE'] = 'C'
29 | env['LANG'] = 'C'
30 | env['LC_ALL'] = 'C'
31 | out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
32 | return out
33 |
34 | try:
35 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
36 | sha = out.strip().decode('ascii')
37 | except OSError:
38 | sha = 'unknown'
39 |
40 | return sha
41 |
42 |
43 | def get_hash():
44 | if os.path.exists('.git'):
45 | sha = get_git_hash()[:7]
46 | else:
47 | sha = 'unknown'
48 |
49 | return sha
50 |
51 |
52 | def write_version_py():
53 | content = """# GENERATED VERSION FILE
54 | # TIME: {}
55 | __version__ = '{}'
56 | __gitsha__ = '{}'
57 | version_info = ({})
58 | """
59 | sha = get_hash()
60 | with open('VERSION', 'r') as f:
61 | SHORT_VERSION = f.read().strip()
62 | VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
63 |
64 | version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
65 | with open(version_file, 'w') as f:
66 | f.write(version_file_str)
67 |
68 |
69 | def get_version():
70 | with open(version_file, 'r') as f:
71 | exec(compile(f.read(), version_file, 'exec'))
72 | return locals()['__version__']
73 |
74 |
75 | def get_requirements(filename='requirements.txt'):
76 | here = os.path.dirname(os.path.realpath(__file__))
77 | with open(os.path.join(here, filename), 'r') as f:
78 | requires = [line.replace('\n', '') for line in f.readlines()]
79 | return requires
80 |
81 |
82 | if __name__ == '__main__':
83 | write_version_py()
84 | setup(
85 | name='facexlib',
86 | version=get_version(),
87 | description='Basic face library',
88 | long_description=readme(),
89 | long_description_content_type='text/markdown',
90 | author='Xintao Wang',
91 | author_email='xintao.wang@outlook.com',
92 | keywords='computer vision, face, detection, landmark, alignment',
93 | url='https://github.com/xinntao/facexlib',
94 | include_package_data=True,
95 | packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
96 | classifiers=[
97 | 'Development Status :: 4 - Beta',
98 | 'License :: OSI Approved :: Apache Software License',
99 | 'Operating System :: OS Independent',
100 | 'Programming Language :: Python :: 3',
101 | 'Programming Language :: Python :: 3.7',
102 | 'Programming Language :: Python :: 3.8',
103 | ],
104 | license='Apache License 2.0',
105 | setup_requires=['cython', 'numpy'],
106 | install_requires=get_requirements(),
107 | zip_safe=False)
108 |
--------------------------------------------------------------------------------
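write_version_py renders the template above with the short version from the VERSION file, the git hash and a timestamp. For a VERSION file containing, say, 0.2.5, the generated facexlib/version.py would look roughly like this (version, hash and time are illustrative):

# GENERATED VERSION FILE
# TIME: Mon Oct  4 12:00:00 2021
__version__ = '0.2.5'
__gitsha__ = '1a2b3c4'
version_info = (0, 2, 5)
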