├── .github
│   ├── FUNDING.yml
│   └── workflows
│       └── python.yml
├── .gitignore
├── .gitmodules
├── LICENSE.md
├── README.md
├── README_cn.md
├── __init__.py
├── codecov.yml
├── data
│   ├── .gitignore
│   ├── facenet-pytorch-banner.png
│   ├── multiface.jpg
│   ├── multiface_detected.png
│   ├── onet.pt
│   ├── pnet.pt
│   ├── rnet.pt
│   ├── test_images
│   │   ├── angelina_jolie
│   │   │   └── 1.jpg
│   │   ├── bradley_cooper
│   │   │   └── 1.jpg
│   │   ├── kate_siegel
│   │   │   └── 1.jpg
│   │   ├── paul_rudd
│   │   │   └── 1.jpg
│   │   └── shea_whigham
│   │       └── 1.jpg
│   └── test_images_aligned
│       ├── angelina_jolie
│       │   └── 1.png
│       ├── bradley_cooper
│       │   └── 1.png
│       ├── kate_siegel
│       │   └── 1.png
│       ├── paul_rudd
│       │   └── 1.png
│       └── shea_whigham
│           └── 1.png
├── examples
│   ├── face_tracking.ipynb
│   ├── face_tracking_cn.ipynb
│   ├── finetune.ipynb
│   ├── finetune_cn.ipynb
│   ├── infer.ipynb
│   ├── infer_cn.ipynb
│   ├── lfw_evaluate.ipynb
│   ├── lfw_evaluate_cn.ipynb
│   ├── performance-comparison.png
│   ├── tracked.gif
│   ├── video.mp4
│   └── video_tracked.mp4
├── models
│   ├── inception_resnet_v1.py
│   ├── mtcnn.py
│   └── utils
│       ├── detect_face.py
│       ├── download.py
│       ├── tensorflow2pytorch.py
│       └── training.py
├── setup.py
└── tests
    ├── actions_requirements.txt
    ├── actions_test.py
    └── perf_test.py

/.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | custom: ["https://xscode.com/timesler/facenet-pytorch"] 2 | -------------------------------------------------------------------------------- /.github/workflows/python.yml: -------------------------------------------------------------------------------- 1 | name: Python 2 | 3 | on: [pull_request, push, workflow_dispatch] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | fail-fast: false 11 | matrix: 12 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | - name: Install
dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install -r tests/actions_requirements.txt 24 | - name: Test with pytest 25 | run: | 26 | python --version 27 | echo "import tests.actions_test" > test.py 28 | coverage run --source models,examples test.py 29 | coverage report 30 | coverage xml 31 | - name: Upload coverage to Codecov 32 | uses: codecov/codecov-action@v3 33 | if: ${{ matrix.python-version == '3.12' }} 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .vscode 3 | .ipynb_checkpoints 4 | runs 5 | build 6 | dist 7 | *.egg-info 8 | *tmp* 9 | .coverage 10 | htmlcov 11 | test.py 12 | .cache 13 | .ipython 14 | .local 15 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "dependencies/facenet"] 2 | path = dependencies/facenet 3 | url = https://github.com/davidsandberg/facenet.git 4 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Timothy Esler 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Face Recognition Using Pytorch 2 | 3 | *You can also read a translated version of this file [in Chinese 简体中文版](README_cn.md).* 4 | 5 | [![Downloads](https://pepy.tech/badge/facenet-pytorch)](https://pepy.tech/project/facenet-pytorch) 6 | 7 | [![Code Coverage](https://img.shields.io/codecov/c/github/timesler/facenet-pytorch.svg)](https://codecov.io/gh/timesler/facenet-pytorch) 8 | 9 | This is a repository for Inception Resnet (V1) models in pytorch, pretrained on VGGFace2 and CASIA-Webface. 10 | 11 | Pytorch model weights were initialized using parameters ported from David Sandberg's [tensorflow facenet repo](https://github.com/davidsandberg/facenet). 12 | 13 | Also included in this repo is an efficient pytorch implementation of MTCNN for face detection prior to inference. These models are also pretrained. To our knowledge, this is the fastest MTCNN implementation available. 
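To make the recognition half of that pipeline concrete: after MTCNN crops a face and the Inception Resnet maps it to an embedding, identifying the person usually comes down to a nearest-neighbour search over embeddings of known people. The following is a minimal, hypothetical sketch in plain Python, with lists standing in for the model's output tensors. The helper names (`l2_normalize`, `euclidean`, `identify`) and the default threshold of 1.0 are illustrative assumptions, not part of facenet-pytorch, and a real threshold should be tuned on your own validation data.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit Euclidean length (vec must be non-zero)."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]


def euclidean(a, b):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def identify(query, gallery, threshold=1.0):
    """Match a query embedding against {name: embedding} references.

    Returns (name, distance) for the nearest reference, or (None, distance)
    when even the nearest one is farther than `threshold` (an illustrative
    value, not one taken from facenet-pytorch).
    """
    q = l2_normalize(query)
    best_name, best_dist = None, float("inf")
    for name, emb in gallery.items():
        dist = euclidean(q, l2_normalize(emb))
        if dist < best_dist:
            best_name, best_dist = name, dist
    if best_dist > threshold:
        return None, best_dist
    return best_name, best_dist


if __name__ == "__main__":
    # Toy 3-D "embeddings"; the pretrained models produce 512-D vectors.
    gallery = {"alice": [1.0, 0.0, 0.0], "bob": [0.0, 1.0, 0.0]}
    print(identify([0.9, 0.1, 0.0], gallery))  # nearest reference is "alice"
    print(identify([0.0, 0.0, 1.0], gallery))  # too far from both, so None
```

Normalizing both sides first makes the Euclidean distance a monotone function of cosine similarity, so the threshold behaves the same regardless of embedding scale.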
14 | 15 | ## Table of contents 16 | 17 | * [Table of contents](#table-of-contents) 18 | * [Quick start](#quick-start) 19 | * [Pretrained models](#pretrained-models) 20 | * [Example notebooks](#example-notebooks) 21 | + [*Complete detection and recognition pipeline*](#complete-detection-and-recognition-pipeline) 22 | + [*Face tracking in video streams*](#face-tracking-in-video-streams) 23 | + [*Finetuning pretrained models with new data*](#finetuning-pretrained-models-with-new-data) 24 | + [*Guide to MTCNN in facenet-pytorch*](#guide-to-mtcnn-in-facenet-pytorch) 25 | + [*Performance comparison of face detection packages*](#performance-comparison-of-face-detection-packages) 26 | + [*The FastMTCNN algorithm*](#the-fastmtcnn-algorithm) 27 | * [Running with docker](#running-with-docker) 28 | * [Use this repo in your own git project](#use-this-repo-in-your-own-git-project) 29 | * [Conversion of parameters from Tensorflow to Pytorch](#conversion-of-parameters-from-tensorflow-to-pytorch) 30 | * [References](#references) 31 | 32 | ## Quick start 33 | 34 | 1. Install: 35 | 36 | ```bash 37 | # With pip: 38 | pip install facenet-pytorch 39 | 40 | # or clone this repo, removing the '-' to allow python imports: 41 | git clone https://github.com/timesler/facenet-pytorch.git facenet_pytorch 42 | 43 | # or use a docker container (see https://github.com/timesler/docker-jupyter-dl-gpu): 44 | docker run -it --rm timesler/jupyter-dl-gpu bash -c "pip install facenet-pytorch && ipython" 45 | ``` 46 | 47 | 1. In Python, import facenet-pytorch and instantiate the models: 48 | 49 | ```python 50 | from facenet_pytorch import MTCNN, InceptionResnetV1 51 | 52 | # If required, create a face detection pipeline using MTCNN (image_size=160 and margin=0 are the defaults): 53 | mtcnn = MTCNN(image_size=160, margin=0) 54 | 55 | # Create an inception resnet (in eval mode): 56 | resnet = InceptionResnetV1(pretrained='vggface2').eval() 57 | ``` 58 | 59 | 1.
Process an image: 60 | 61 | ```python 62 | from PIL import Image 63 | 64 | img = Image.open('path/to/image.jpg') # placeholder: path to your image 65 | 66 | # Get cropped and normalized face tensor 67 | img_cropped = mtcnn(img, save_path=None) # optionally pass a file path to also save the crop 68 | 69 | # Calculate embedding (unsqueeze to add batch dimension) 70 | img_embedding = resnet(img_cropped.unsqueeze(0)) 71 | 72 | # Or, if using for VGGFace2 classification (softmax converts the output logits to probabilities) 73 | resnet.classify = True 74 | img_probs = resnet(img_cropped.unsqueeze(0)).softmax(dim=1) 75 | ``` 76 | 77 | See `help(MTCNN)` and `help(InceptionResnetV1)` for usage and implementation details. 78 | 79 | ## Pretrained models 80 | 81 | See: [models/inception_resnet_v1.py](models/inception_resnet_v1.py) 82 | 83 | The following models have been ported to pytorch (with links to download the pytorch state_dicts): 84 | 85 | |Model name|LFW accuracy (as listed [here](https://github.com/davidsandberg/facenet))|Training dataset| 86 | | :- | :-: | -: | 87 | |[20180408-102900](https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180408-102900-casia-webface.pt) (111MB)|0.9905|CASIA-Webface| 88 | |[20180402-114759](https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180402-114759-vggface2.pt) (107MB)|0.9965|VGGFace2| 89 | 90 | There is no need to manually download the pretrained state_dicts; they are downloaded automatically on model instantiation and cached for future use in the torch cache.
To use an Inception Resnet (V1) model for facial recognition/identification in pytorch, use: 91 | 92 | ```python 93 | from facenet_pytorch import InceptionResnetV1 94 | 95 | # For a model pretrained on VGGFace2 96 | model = InceptionResnetV1(pretrained='vggface2').eval() 97 | 98 | # For a model pretrained on CASIA-Webface 99 | model = InceptionResnetV1(pretrained='casia-webface').eval() 100 | 101 | # For an untrained model with 100 classes 102 | model = InceptionResnetV1(num_classes=100).eval() 103 | 104 | # For an untrained 1001-class classifier 105 | model = InceptionResnetV1(classify=True, num_classes=1001).eval() 106 | ``` 107 | 108 | Both pretrained models were trained on 160x160 px images, so they will perform best when applied to images resized to this shape. For best results, images should also be cropped to the face using MTCNN (see below). 109 | 110 | By default, the above models return 512-dimensional embeddings of images. To enable classification instead, either pass `classify=True` to the model constructor or set the attribute afterwards with `model.classify = True`. For VGGFace2, the pretrained model will output logit vectors of length 8631, and for CASIA-Webface logit vectors of length 10575. 111 | 112 | ## Example notebooks 113 | 114 | ### *Complete detection and recognition pipeline* 115 | 116 | Face recognition can be easily applied to raw images by first detecting faces using MTCNN and then calculating embeddings or class probabilities using an Inception Resnet model. The example code at [examples/infer.ipynb](examples/infer.ipynb) provides a complete example pipeline utilizing datasets, dataloaders, and optional GPU processing. 117 | 118 | ### *Face tracking in video streams* 119 | 120 | MTCNN can be used to build a face tracking system (using the `MTCNN.detect()` method). A full face tracking example can be found at [examples/face_tracking.ipynb](examples/face_tracking.ipynb).
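`MTCNN.detect()` returns per-frame bounding boxes in `[x1, y1, x2, y2]` pixel coordinates (or `None` when no face is found); linking those boxes into tracks across frames is up to the caller. One simple approach, sketched below under that assumption, is greedy matching by intersection-over-union (IoU). This is not necessarily what the notebook does, and `iou`, `IouTracker`, and the 0.3 threshold are hypothetical names and values chosen for illustration.

```python
def iou(a, b):
    """Intersection-over-union of two [x1, y1, x2, y2] boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


class IouTracker:
    """Greedy frame-to-frame association of face boxes (hypothetical helper)."""

    def __init__(self, iou_threshold=0.3):
        self.iou_threshold = iou_threshold
        self.tracks = {}  # track id -> box seen in the previous frame
        self._next_id = 0

    def update(self, boxes):
        """Return one track id per box in the current frame.

        `boxes` may be None, as when no face is detected. Tracks that are not
        matched in this frame are dropped (no re-identification).
        """
        boxes = [] if boxes is None else list(boxes)
        # Score every (track, box) pair, then match greedily, best IoU first.
        pairs = sorted(
            ((iou(prev, box), tid, i)
             for tid, prev in self.tracks.items()
             for i, box in enumerate(boxes)),
            reverse=True,
        )
        assigned, used_tracks = {}, set()
        for score, tid, i in pairs:
            if score < self.iou_threshold:
                break
            if tid in used_tracks or i in assigned:
                continue
            assigned[i] = tid
            used_tracks.add(tid)
        # Boxes left unmatched start new tracks.
        for i in range(len(boxes)):
            if i not in assigned:
                assigned[i] = self._next_id
                self._next_id += 1
        self.tracks = {tid: boxes[i] for i, tid in assigned.items()}
        return [assigned[i] for i in range(len(boxes))]


if __name__ == "__main__":
    tracker = IouTracker()
    print(tracker.update([[0, 0, 10, 10], [50, 50, 60, 60]]))  # two new tracks
    print(tracker.update([[51, 51, 61, 61], [1, 1, 11, 11]]))  # same faces, reordered
```

Greedy IoU matching is cheap and works well when faces move little between frames. For fast motion or occlusion, a tracker would typically also compare face embeddings across frames.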
121 | 122 | ![](examples/tracked.gif) 123 | 124 | ### *Finetuning pretrained models with new data* 125 | 126 | In most situations, the best way to implement face recognition is to use the pretrained models directly, with either a clustering algorithm or a simple distance metric to determine the identity of a face. However, if finetuning is required (i.e., if you want to select identity based on the model's output logits), an example can be found at [examples/finetune.ipynb](examples/finetune.ipynb). 127 | 128 | ### *Guide to MTCNN in facenet-pytorch* 129 | 130 | This guide demonstrates the functionality of the MTCNN module. Topics covered are: 131 | 132 | * Basic usage 133 | * Image normalization 134 | * Face margins 135 | * Multiple faces in a single image 136 | * Batched detection 137 | * Bounding boxes and facial landmarks 138 | * Saving face datasets 139 | 140 | See the [notebook on kaggle](https://www.kaggle.com/timesler/guide-to-mtcnn-in-facenet-pytorch). 141 | 142 | ### *Performance comparison of face detection packages* 143 | 144 | This notebook demonstrates the use of three face detection packages: 145 | 146 | 1. facenet-pytorch 147 | 1. mtcnn 148 | 1. dlib 149 | 150 | Each package is timed while detecting the faces in a set of 300 images (all frames from one video), with GPU support enabled. Timings were measured on Kaggle's P100 notebook kernel. Results are summarized below. 151 | 152 | |Package|FPS (1080x1920)|FPS (720x1280)|FPS (540x960)| 153 | |---|---|---|---| 154 | |facenet-pytorch|12.97|20.32|25.50| 155 | |facenet-pytorch (non-batched)|9.75|14.81|19.68| 156 | |dlib|3.80|8.39|14.53| 157 | |mtcnn|3.04|5.70|8.23| 158 | 159 | ![](examples/performance-comparison.png) 160 | 161 | See the [notebook on kaggle](https://www.kaggle.com/timesler/comparison-of-face-detection-packages).
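The FPS figures above are frames processed divided by wall-clock time. To reproduce this kind of measurement on your own hardware, a minimal timing harness could look like the sketch below. It is not the code from the Kaggle notebook: `measure_fps`, its `batch_size` and `warmup` parameters, and the dummy detector are hypothetical. The `batch_size` argument lets you compare batched and non-batched calls, the distinction behind the two facenet-pytorch rows.

```python
import time


def measure_fps(detect, frames, batch_size=1, warmup=1):
    """Return frames per second for `detect` applied to `frames`.

    `detect` is any callable that takes a list of frames (for example a
    hypothetical wrapper around MTCNN.detect). The first `warmup` batches are
    run once beforehand and excluded from the timing.
    """
    if not frames:
        raise ValueError("need at least one frame")
    batches = [frames[i:i + batch_size] for i in range(0, len(frames), batch_size)]
    for batch in batches[:warmup]:
        detect(batch)  # warm-up pass: lazy initialization, caches, GPU kernels
    start = time.perf_counter()
    for batch in batches:
        detect(batch)
    elapsed = time.perf_counter() - start
    return len(frames) / elapsed


if __name__ == "__main__":
    def dummy_detector(batch):
        time.sleep(0.002 * len(batch))  # stand-in for real detection work

    print(f"{measure_fps(dummy_detector, list(range(30)), batch_size=10):.1f} FPS")
```

If the timed callable returns GPU tensors rather than NumPy arrays, call `torch.cuda.synchronize()` before reading the clock, since CUDA work runs asynchronously.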
162 | 163 | ### *The FastMTCNN algorithm* 164 | 165 | This notebook demonstrates how to achieve highly efficient face detection in videos specifically, by taking advantage of similarities between adjacent frames. 166 | 167 | See the [notebook on kaggle](https://www.kaggle.com/timesler/fast-mtcnn-detector-55-fps-at-full-resolution). 168 | 169 | ## Running with docker 170 | 171 | The package and any of the example notebooks can be run with docker (or nvidia-docker) using: 172 | 173 | ```bash 174 | docker run --rm -p 8888:8888 \ 175 | -v "$(pwd)/facenet-pytorch":/home/jovyan \ 176 | -v /path/to/data:/home/jovyan/data \ 177 | timesler/jupyter-dl-gpu bash -c "pip install facenet-pytorch && jupyter lab" 178 | ``` 179 | 180 | Replace `/path/to/data` with a local data directory (optional), then navigate to the `examples/` directory in Jupyter Lab and run any of the notebooks. 181 | 182 | See [timesler/jupyter-dl-gpu](https://github.com/timesler/docker-jupyter-dl-gpu) for docker container details. 183 | 184 | ## Use this repo in your own git project 185 | 186 | To use this code in your own git repo, I recommend first adding this repo as a submodule. Note that the dash ('-') in the repo name should be removed when cloning as a submodule, because Python cannot import a package whose name contains a dash: 187 | 188 | `git submodule add https://github.com/timesler/facenet-pytorch.git facenet_pytorch` 189 | 190 | Alternatively, the code can be installed as a package using pip: 191 | 192 | `pip install facenet-pytorch` 193 | 194 | ## Conversion of parameters from Tensorflow to Pytorch 195 | 196 | See: [models/utils/tensorflow2pytorch.py](models/utils/tensorflow2pytorch.py) 197 | 198 | Note that this functionality is not needed to use the models in this repo, which depend only on the saved pytorch `state_dict`s. 199 | 200 | Following instantiation of the pytorch model, each layer's weights were loaded from equivalent layers in the pretrained tensorflow models from [davidsandberg/facenet](https://github.com/davidsandberg/facenet).
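Porting weights layer by layer is largely a matter of reordering tensor axes, because the two frameworks use different layouts. TensorFlow stores conv kernels as (height, width, in_channels, out_channels) and images as NHWC. PyTorch's `nn.Conv2d` weight is (out_channels, in_channels, height, width) and expects NCHW input. Dense kernels are (in, out) in TensorFlow, while `nn.Linear.weight` is (out, in). The NumPy sketch below shows only these transposes and a numerical check of them. It is not the repo's `tensorflow2pytorch.py`, the function names are hypothetical, and a real conversion must also map batch-norm statistics and padding conventions, which this ignores.

```python
import numpy as np


def tf_conv_to_torch(kernel_hwio):
    """TF conv kernel (H, W, in, out) -> PyTorch Conv2d weight (out, in, H, W)."""
    return np.ascontiguousarray(np.transpose(kernel_hwio, (3, 2, 0, 1)))


def tf_dense_to_torch(kernel_io):
    """TF dense kernel (in, out) -> PyTorch Linear weight (out, in)."""
    return np.ascontiguousarray(kernel_io.T)


def nhwc_to_nchw(images):
    """Image batch (N, H, W, C) -> (N, C, H, W)."""
    return np.ascontiguousarray(np.transpose(images, (0, 3, 1, 2)))


if __name__ == "__main__":
    rng = np.random.default_rng(0)

    # Dense layer: TF computes x @ K, PyTorch's Linear computes x @ W.T.
    x = rng.standard_normal((2, 8))
    k_dense = rng.standard_normal((8, 4))
    assert np.allclose(x @ k_dense, x @ tf_dense_to_torch(k_dense).T)

    # A 1x1 convolution is a per-pixel matrix product, so einsum can check it.
    imgs = rng.standard_normal((2, 5, 5, 3))    # NHWC
    k_conv = rng.standard_normal((1, 1, 3, 6))  # HWIO
    out_tf = np.einsum("nhwi,io->nhwo", imgs, k_conv[0, 0])
    w_pt = tf_conv_to_torch(k_conv)             # shape (6, 3, 1, 1)
    out_pt = np.einsum("oi,nihw->nohw", w_pt[:, :, 0, 0], nhwc_to_nchw(imgs))
    assert np.allclose(nhwc_to_nchw(out_tf), out_pt)
    print("TF and PyTorch layouts agree")
```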
201 | 202 | The outputs of the original tensorflow models and the pytorch-ported models have been compared and agree to within floating-point precision: 203 | 204 | --- 205 | 206 | `>>> compare_model_outputs(mdl, sess, torch.randn(5, 160, 160, 3).detach())` 207 | 208 | ``` 209 | Passing test data through TF model 210 | 211 | tensor([[-0.0142, 0.0615, 0.0057, ..., 0.0497, 0.0375, -0.0838], 212 | [-0.0139, 0.0611, 0.0054, ..., 0.0472, 0.0343, -0.0850], 213 | [-0.0238, 0.0619, 0.0124, ..., 0.0598, 0.0334, -0.0852], 214 | [-0.0089, 0.0548, 0.0032, ..., 0.0506, 0.0337, -0.0881], 215 | [-0.0173, 0.0630, -0.0042, ..., 0.0487, 0.0295, -0.0791]]) 216 | 217 | Passing test data through PT model 218 | 219 | tensor([[-0.0142, 0.0615, 0.0057, ..., 0.0497, 0.0375, -0.0838], 220 | [-0.0139, 0.0611, 0.0054, ..., 0.0472, 0.0343, -0.0850], 221 | [-0.0238, 0.0619, 0.0124, ..., 0.0598, 0.0334, -0.0852], 222 | [-0.0089, 0.0548, 0.0032, ..., 0.0506, 0.0337, -0.0881], 223 | [-0.0173, 0.0630, -0.0042, ..., 0.0487, 0.0295, -0.0791]], 224 | grad_fn=) 225 | 226 | Distance 1.2874517096861382e-06 227 | ``` 228 | 229 | --- 230 | 231 | In order to re-run the conversion of tensorflow parameters into the pytorch model, ensure you clone this repo _with submodules_, as the davidsandberg/facenet repo is included as a submodule and parts of it are required for the conversion. 232 | 233 | ## References 234 | 235 | 1. David Sandberg's facenet repo: [https://github.com/davidsandberg/facenet](https://github.com/davidsandberg/facenet) 236 | 237 | 1. F. Schroff, D. Kalenichenko, J. Philbin. _FaceNet: A Unified Embedding for Face Recognition and Clustering_, arXiv:1503.03832, 2015. [PDF](https://arxiv.org/pdf/1503.03832) 238 | 239 | 1. Q. Cao, L. Shen, W. Xie, O. M. Parkhi, A. Zisserman. _VGGFace2: A dataset for recognising face across pose and age_, International Conference on Automatic Face and Gesture Recognition, 2018. [PDF](http://www.robots.ox.ac.uk/~vgg/publications/2018/Cao18/cao18.pdf) 240 | 241 | 1. D.
Yi, Z. Lei, S. Liao and S. Z. Li. _CASIAWebface: Learning Face Representation from Scratch_, arXiv:1411.7923, 2014. [PDF](https://arxiv.org/pdf/1411.7923) 242 | 243 | 1. K. Zhang, Z. Zhang, Z. Li and Y. Qiao. _Joint Face Detection and Alignment Using Multitask Cascaded Convolutional Networks_, IEEE Signal Processing Letters, 2016. [PDF](https://kpzhang93.github.io/MTCNN_face_detection_alignment/paper/spl.pdf) 244 | -------------------------------------------------------------------------------- /README_cn.md: -------------------------------------------------------------------------------- 1 | # 使用 Pytorch 进行人脸识别 2 | 3 | *Click [here](README.md) to return to the English document* 4 | 5 | > 译者注: 6 | > 7 | > 本项目 [facenet-pytorch](https://github.com/timesler/facenet-pytorch) 是一个十分方便的人脸识别库,可以通过 [pip](https://pypi.org/project/facenet-pytorch/) 直接安装。 8 | > 9 | > 库中包含了两个重要功能 10 | > 11 | > - 人脸检测:使用MTCNN算法 12 | > - 人脸识别:使用FaceNet算法 13 | > 14 | > 利用这个库,可以轻松实现人脸检测和人脸向量映射操作。 15 | > 16 | > 为了方便中文开发者研究学习人脸识别相关任务、贡献代码,我将本项目的README文件以及位于 `examples` 里面的几个示例脚本中必要的部分翻译成了中文,以供参考。 17 | > 18 | > 向本项目的所有贡献者致敬。 19 | > 20 | > 英译汉:[远哥挺乐](https://github.com/yuan2001425) 21 | > 22 | > Translator's Note: 23 | > 24 | > This project [facenet-pytorch](https://github.com/timesler/facenet-pytorch) is a very convenient face recognition library that can be installed directly via [pip](https://pypi.org/project/facenet-pytorch/). 25 | > 26 | > The library contains two important features: 27 | > 28 | > - Face detection: using the MTCNN algorithm 29 | > - Face recognition: using the FaceNet algorithm 30 | > 31 | > With this library, one can easily carry out face detection and face vector mapping operations. 32 | > 33 | > In order to facilitate Chinese developers in studying face recognition and contributing code, I have translated the README file of this project and some necessary parts of several example scripts located in the `examples` directory into Chinese. 
34 | > 35 | > Salute to all contributors to this project. 36 | > 37 | > Translated from English to Chinese by [远哥挺乐](https://github.com/yuan2001425). 38 | 39 | [![下载](https://pepy.tech/badge/facenet-pytorch)](https://pepy.tech/project/facenet-pytorch) 40 | 41 | [![代码覆盖率](https://img.shields.io/codecov/c/github/timesler/facenet-pytorch.svg)](https://codecov.io/gh/timesler/facenet-pytorch) 42 | 43 | | Python | 3.10 | 3.9 | 3.8 | 44 | | :---: | :---: | :---: | :---: | 45 | | 测试结果 | [![测试状态](https://github.com/timesler/facenet-pytorch/actions/workflows/python-3.10.yml/badge.svg?branch=master)](https://github.com/timesler/facenet-pytorch/actions?query=workflow%3A%22Python+3.10%22+branch%3Amaster) | [![测试状态](https://github.com/timesler/facenet-pytorch/actions/workflows/python-3.9.yml/badge.svg?branch=master)](https://github.com/timesler/facenet-pytorch/actions?query=workflow%3A%22Python+3.9%22+branch%3Amaster) | [![测试状态](https://github.com/timesler/facenet-pytorch/actions/workflows/python-3.8.yml/badge.svg?branch=master)](https://github.com/timesler/facenet-pytorch/actions?query=workflow%3A%22Python+3.8%22+branch%3Amaster) | 46 | 47 | [![xscode](https://img.shields.io/badge/Available%20on-xs%3Acode-blue?style=?style=plastic&logo=appveyor&logo=)](https://xscode.com/timesler/facenet-pytorch) 48 | 49 | 这是 pytorch 中 Inception Resnet (V1) 模型的存储库,在 VGGFace2 和 CASIA-Webface 上进行了预训练。 50 | 51 | Pytorch 模型权重使用从 David Sandberg 的 [tensorflow Facenet repo](https://github.com/davidsandberg/facenet) 移植的参数进行初始化。 52 | 53 | 该存储库中还包含 MTCNN 的高效 pytorch 实现,用于推理之前的人脸检测。这些模型也是经过预训练的。据我们所知,这是最快的 MTCNN 实现。 54 | 55 | ## 目录 56 | 57 | * [目录](#table-of-contents) 58 | * [快速启动](#quick-start) 59 | * [预训练模型](#pretrained-models) 60 | * [示例笔记本](#example-notebooks) 61 | + [*完整的检测和识别流程*](#complete-detection-and-recognition-pipeline) 62 | + [*视频流中的人脸跟踪*](#face-tracking-in-video-streams) 63 | + [*使用新数据微调预训练模型*](#finetuning-pretrained-models-with-new-data) 64 | + [*facenet-pytorch 中的 MTCNN
指南*](#guide-to-mtcnn-in-facenet-pytorch) 65 | + [*人脸检测包的性能比较*](#performance-comparison-of-face-detection-packages) 66 | + [*FastMTCNN 算法*](#the-fastmtcnn-algorithm) 67 | * [使用 docker 运行](#running-with-docker) 68 | * [在您自己的 git 项目中使用此存储库](#use-this-repo-in-your-own-git-project) 69 | * [Tensorflow 到 Pytorch 的参数转换](#conversion-of-parameters-from-tensorflow-to-pytorch) 70 | * [参考资料](#references) 71 | 72 | ## 快速启动 73 | 74 | 1. 安装: 75 | 76 | ````bash 77 | # 使用pip安装: 78 | pip install facenet-pytorch 79 | 80 | # 或克隆此存储库,删除“-”以允许 python 导入: 81 | git clone https://github.com/timesler/facenet-pytorch.git facenet_pytorch 82 | 83 | # 或使用 docker 容器(参见 https://github.com/timesler/docker-jupyter-dl-gpu): 84 | docker run -it --rm timesler/jupyter-dl-gpu bash -c "pip install facenet-pytorch && ipython" 85 | ```` 86 | 87 | 2. 在python中,导入 facenet-pytorch 并实例化模型: 88 | 89 | ````python 90 | from facenet_pytorch import MTCNN, InceptionResnetV1 91 | 92 | # 如果需要,使用 MTCNN 创建人脸检测模型(image_size=160 和 margin=0 为默认值): 93 | mtcnn = MTCNN(image_size=160, margin=0) 94 | 95 | # 创建一个 inception resnet(在 eval 模式下): 96 | resnet = InceptionResnetV1(pretrained='vggface2').eval() 97 | ```` 98 | 99 | 3.
处理图像: 100 | 101 | ````python 102 | from PIL import Image 103 | 104 | img = Image.open('path/to/image.jpg') # 占位符:替换为您的图像路径 105 | 106 | # 获取裁剪并归一化后的人脸图像张量 107 | img_cropped = mtcnn(img, save_path=None) # 可传入文件路径以同时保存裁剪结果 108 | 109 | # 计算嵌入(使用 unsqueeze 添加批量维度) 110 | img_embedding = resnet(img_cropped.unsqueeze(0)) 111 | 112 | # 或者,如果用于 VGGFace2 分类(softmax 将输出的 logits 转换为概率) 113 | resnet.classify = True 114 | img_probs = resnet(img_cropped.unsqueeze(0)).softmax(dim=1) 115 | ```` 116 | 117 | 有关使用和实现详细信息,请参阅 `help(MTCNN)` 和 `help(InceptionResnetV1)` 。 118 | 119 | ## 预训练模型 120 | 121 | 请参阅:[models/inception_resnet_v1.py](models/inception_resnet_v1.py) 122 | 123 | 以下模型已移植到 pytorch(包含下载 pytorch state_dict 的链接): 124 | 125 | |模型名称|LFW 准确度(如[此处](https://github.com/davidsandberg/facenet)列出)|训练数据集| 126 | | :- | :-: | -: | 127 | |[20180408-102900](https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180408-102900-casia-webface.pt) (111MB)|0.9905|CASIA-Webface| 128 | |[20180402-114759](https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180402-114759-vggface2.pt) (107MB)|0.9965|VGGFace2| 129 | 130 | 无需手动下载预训练的state_dict;它们会在模型实例化时自动下载,并缓存在 torch 缓存中以供将来使用。要在 pytorch 中使用 Inception Resnet (V1) 模型进行人脸识别/身份识别,请使用: 131 | 132 | ````python 133 | from facenet_pytorch import InceptionResnetV1 134 | 135 | # 对于在 VGGFace2 上预训练的模型 136 | model = InceptionResnetV1(pretrained='vggface2').eval() 137 | 138 | # 对于在 CASIA-Webface 上预训练的模型 139 | model = InceptionResnetV1(pretrained='casia-webface').eval() 140 | 141 | # 对于具有 100 个类的未经训练的模型 142 | model = InceptionResnetV1(num_classes=100).eval() 143 | 144 | # 对于未经训练的 1001 类分类器 145 | model = InceptionResnetV1(classify=True, num_classes=1001).eval() 146 | ```` 147 | 148 | 两个预训练模型均在 160x160 像素图像上进行训练,因此如果应用于调整为该形状的图像,则效果最佳。为了获得最佳结果,还应该使用 MTCNN 将图像裁剪到脸部(见下文)。 149 | 150 | 默认情况下,上述模型将返回 512 维图像嵌入。要启用分类,请将 `classify=True` 传递给模型构造函数,或者随后使用 `model.classify = True` 设置对象属性。对于 VGGFace2,预训练模型将输出长度为 8631 的 logit 向量,对于 CASIA-Webface 则输出长度为 10575 的 logit 向量。 151 | 152 | ## 示例笔记本 153 | 154 | ### *完整的检测和识别流程* 155 | 156 | 通过首先使用 MTCNN 检测人脸,然后使用
Inception Resnet 模型计算嵌入或概率,可以轻松地将人脸识别应用于原始图像。 [examples/infer_cn.ipynb](examples/infer_cn.ipynb) 中的示例代码提供了一个利用数据集、数据加载器和可选 GPU 处理的完整示例流程。 157 | 158 | ### *视频流中的人脸跟踪* 159 | 160 | MTCNN 可用于构建人脸跟踪系统(使用 `MTCNN.detect()` 方法)。完整的人脸跟踪示例可以在 [examples/face_tracking_cn.ipynb](examples/face_tracking_cn.ipynb) 中找到。 161 | 162 | ![](examples/tracked.gif) 163 | 164 | ### *使用新数据微调预训练模型* 165 | 166 | 在大多数情况下,实现人脸识别的最佳方法是直接使用预训练模型,通过聚类算法或简单的距离度量来确定人脸的身份。但是,如果需要微调(即,如果您想根据模型的输出 logits 确定身份),可以在 [examples/finetune_cn.ipynb](examples/finetune_cn.ipynb) 中找到示例。 167 | 168 | ### *facenet-pytorch 中的 MTCNN 指南* 169 | 170 | 本指南演示了 MTCNN 模块的功能。涵盖的主题有: 171 | 172 | * 基本用法 173 | * 图像标准化 174 | * 人脸边距 175 | * 单张图像中的多张人脸 176 | * 批量检测 177 | * 边界框和面部关键点 178 | * 保存人脸数据集 179 | 180 | 请参阅[kaggle 笔记本](https://www.kaggle.com/timesler/guide-to-mtcnn-in-facenet-pytorch)。 181 | 182 | ### *人脸检测包的性能比较* 183 | 184 | 本笔记本演示了三个人脸检测包的使用: 185 | 186 | 1. facenet-pytorch 187 | 2. mtcnn 188 | 3. dlib 189 | 190 | 每个包都经过测试,测试其在启用 GPU 支持的情况下检测一组 300 张图像(来自一个视频的所有帧)中的人脸的速度。性能基于 Kaggle 的 P100 笔记本内核。结果总结如下。 191 | 192 | |软件包|FPS (1080x1920)|FPS (720x1280)|FPS (540x960)| 193 | |---|---|---|---| 194 | |facenet-pytorch|12.97|20.32|25.50| 195 | |facenet-pytorch(非批处理)|9.75|14.81|19.68| 196 | |dlib|3.80|8.39|14.53| 197 | |mtcnn|3.04|5.70|8.23| 198 | 199 | ![](examples/performance-comparison.png) 200 | 201 | 请参阅[kaggle 笔记本](https://www.kaggle.com/timesler/comparison-of-face-detection-packages)。 202 | 203 | ### *FastMTCNN 算法* 204 | 205 | 该笔记本演示了如何利用相邻帧之间的相似性,在视频中实现极其高效的人脸检测。 206 | 207 | 请参阅[kaggle 笔记本](https://www.kaggle.com/timesler/fast-mtcnn-detector-55-fps-at-full-resolution)。 208 | 209 | ## 使用 docker 运行 210 | 211 | 该包和任何示例笔记本都可以使用 docker(或 nvidia-docker)运行: 212 | 213 | ````bash 214 | docker run --rm -p 8888:8888 \ 215 | -v "$(pwd)/facenet-pytorch":/home/jovyan \ 216 | -v /path/to/data:/home/jovyan/data \ 217 | timesler/jupyter-dl-gpu bash -c "pip install facenet-pytorch && jupyter lab" 218 | ```` 219 | 220 | 将 `/path/to/data` 替换为本地数据目录(可选),然后在 Jupyter Lab 中导航到 `examples/` 目录并运行任意笔记本。 221 | 222 | 有关 docker
容器的详细信息,请参阅 [timesler/jupyter-dl-gpu](https://github.com/timesler/docker-jupyter-dl-gpu)。 223 | 224 | ## 在您自己的 git 项目中使用此存储库 225 | 226 | 要在您自己的 git 存储库中使用此代码,我建议首先将此存储库添加为子模块。请注意,当克隆为子模块时,应删除存储库名称中的破折号(“-”),因为 Python 无法导入名称中含有破折号的包: 227 | 228 | `git submodule add https://github.com/timesler/facenet-pytorch.git facenet_pytorch` 229 | 230 | 或者,可以使用 pip 将代码安装为包: 231 | 232 | `pip install facenet-pytorch` 233 | 234 | ## Tensorflow 到 Pytorch 的参数转换 235 | 236 | 请参阅:[models/utils/tensorflow2pytorch.py](models/utils/tensorflow2pytorch.py) 237 | 238 | 请注意,使用此存储库中的模型不需要此功能,这些模型仅依赖于已保存的 pytorch `state_dict`。 239 | 240 | 实例化 pytorch 模型后,每层的权重均从 [davidsandberg/facenet](https://github.com/davidsandberg/facenet) 的预训练Tensorflow模型中的等效层加载。 241 | 242 | 原始 Tensorflow 模型与 pytorch 移植模型的输出已经过比较,二者在浮点精度范围内一致: 243 | 244 | --- 245 | 246 | `>>> compare_model_outputs(mdl, sess, torch.randn(5, 160, 160, 3).detach())` 247 | 248 | ```` 249 | Passing test data through TF model (通过TF模型传递测试数据) 250 | 251 | tensor([[-0.0142, 0.0615, 0.0057, ..., 0.0497, 0.0375, -0.0838], 252 | [-0.0139, 0.0611, 0.0054, ..., 0.0472, 0.0343, -0.0850], 253 | [-0.0238, 0.0619, 0.0124, ..., 0.0598, 0.0334, -0.0852], 254 | [-0.0089, 0.0548, 0.0032, ..., 0.0506, 0.0337, -0.0881], 255 | [-0.0173, 0.0630, -0.0042, ..., 0.0487, 0.0295, -0.0791]]) 256 | 257 | Passing test data through PT model (通过PT模型传递测试数据) 258 | 259 | tensor([[-0.0142, 0.0615, 0.0057, ..., 0.0497, 0.0375, -0.0838], 260 | [-0.0139, 0.0611, 0.0054, ..., 0.0472, 0.0343, -0.0850], 261 | [-0.0238, 0.0619, 0.0124, ..., 0.0598, 0.0334, -0.0852], 262 | [-0.0089, 0.0548, 0.0032, ..., 0.0506, 0.0337, -0.0881], 263 | [-0.0173, 0.0630, -0.0042, ..., 0.0487, 0.0295, -0.0791]], 264 | grad_fn=) 265 | 266 | Distance 1.2874517096861382e-06 (距离1.2874517096861382e-06) 267 | ```` 268 | 269 | --- 270 | 271 | 为了重新运行Tensorflow参数到 pytorch 模型的转换,请确保使用子模块克隆此存储库,因为 davidsandberg/facenet 存储库作为子模块包含在内,并且转换需要其中的一部分。 272 | 273 | ## 参考资料 274 | 275 | 1.
David Sandberg's facenet repo: [https://github.com/davidsandberg/facenet](https://github.com/davidsandberg/facenet) 276 | 2. F. Schroff, D. Kalenichenko, J. Philbin. _FaceNet: A Unified Embedding for Face Recognition and Clustering_, arXiv:1503.03832, 2015. [PDF](https://arxiv.org/pdf/1503.03832) 277 | 278 | 3. Q. Cao, L. Shen, W. Xie, O. M. Parkhi, A. Zisserman. _VGGFace2: A dataset for recognising face across pose and age_, International Conference on Automatic Face and Gesture Recognition, 2018. [PDF](http://www.robots.ox.ac.uk/~vgg/publications/2018/Cao18/cao18.pdf) 279 | 280 | 4. D. Yi, Z. Lei, S. Liao and S. Z. Li. _CASIAWebface: Learning Face Representation from Scratch_, arXiv:1411.7923, 2014. [PDF](https://arxiv.org/pdf/1411.7923) 281 | 282 | 5. K. Zhang, Z. Zhang, Z. Li and Y. Qiao. _Joint Face Detection and Alignment Using Multitask Cascaded Convolutional Networks_, IEEE Signal Processing Letters, 2016. [PDF](https://kpzhang93.github.io/MTCNN_face_detection_alignment/paper/spl.pdf) 283 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .models.inception_resnet_v1 import InceptionResnetV1 2 | from .models.mtcnn import MTCNN, PNet, RNet, ONet, prewhiten, fixed_image_standardization 3 | from .models.utils.detect_face import extract_face 4 | from .models.utils import training 5 | 6 | import warnings 7 | warnings.filterwarnings( 8 | action="ignore", 9 | message="This overload of nonzero is deprecated:\n\tnonzero()", 10 | category=UserWarning 11 | ) -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: off 4 | patch: off 5 | codecov: 6 | token: 1e4f3aaa-9c74-4888-9408-71cc7fcfb64c 7 | -------------------------------------------------------------------------------- 
/data/.gitignore: -------------------------------------------------------------------------------- 1 | 2018* 2 | *.json 3 | profile.txt 4 | -------------------------------------------------------------------------------- /data/facenet-pytorch-banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/data/facenet-pytorch-banner.png -------------------------------------------------------------------------------- /data/multiface.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/data/multiface.jpg -------------------------------------------------------------------------------- /data/multiface_detected.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/data/multiface_detected.png -------------------------------------------------------------------------------- /data/onet.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/data/onet.pt -------------------------------------------------------------------------------- /data/pnet.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/data/pnet.pt -------------------------------------------------------------------------------- /data/rnet.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/data/rnet.pt 
-------------------------------------------------------------------------------- /data/test_images/angelina_jolie/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/data/test_images/angelina_jolie/1.jpg -------------------------------------------------------------------------------- /data/test_images/bradley_cooper/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/data/test_images/bradley_cooper/1.jpg -------------------------------------------------------------------------------- /data/test_images/kate_siegel/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/data/test_images/kate_siegel/1.jpg -------------------------------------------------------------------------------- /data/test_images/paul_rudd/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/data/test_images/paul_rudd/1.jpg -------------------------------------------------------------------------------- /data/test_images/shea_whigham/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/data/test_images/shea_whigham/1.jpg -------------------------------------------------------------------------------- /data/test_images_aligned/angelina_jolie/1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/data/test_images_aligned/angelina_jolie/1.png -------------------------------------------------------------------------------- /data/test_images_aligned/bradley_cooper/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/data/test_images_aligned/bradley_cooper/1.png -------------------------------------------------------------------------------- /data/test_images_aligned/kate_siegel/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/data/test_images_aligned/kate_siegel/1.png -------------------------------------------------------------------------------- /data/test_images_aligned/paul_rudd/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/data/test_images_aligned/paul_rudd/1.png -------------------------------------------------------------------------------- /data/test_images_aligned/shea_whigham/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/data/test_images_aligned/shea_whigham/1.png -------------------------------------------------------------------------------- /examples/finetune.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Face detection and recognition training pipeline\n", 8 | "\n", 9 | "The following example illustrates how to fine-tune an InceptionResnetV1 model on your own 
dataset. This will mostly follow standard PyTorch training patterns." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization, training\n", 19 | "import torch\n", 20 | "from torch.utils.data import DataLoader, SubsetRandomSampler\n", 21 | "from torch import optim\n", 22 | "from torch.optim.lr_scheduler import MultiStepLR\n", 23 | "from torch.utils.tensorboard import SummaryWriter\n", 24 | "from torchvision import datasets, transforms\n", 25 | "import numpy as np\n", 26 | "import os" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "#### Define run parameters\n", 34 | "\n", 35 | "The dataset should follow the VGGFace2/ImageNet-style directory layout, with one subdirectory of images per identity. Set `data_dir` to the location of the dataset you wish to fine-tune on." 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "data_dir = '../data/test_images'\n", 45 | "\n", 46 | "batch_size = 32\n", 47 | "epochs = 8\n", 48 | "workers = 0 if os.name == 'nt' else 8" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "#### Determine if an NVIDIA GPU is available" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 65 | "print('Running on device: {}'.format(device))" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "#### Define MTCNN module\n", 73 | "\n", 74 | "See `help(MTCNN)` for more details."
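The face-detection cell in this notebook saves each crop by rewriting every sample path from `data_dir` to `data_dir + '_cropped'`, which yields a mirrored folder tree with the same class subdirectories. The sketch below reproduces that mapping with the standard library only. `cropped_path` is a hypothetical helper, not part of `facenet_pytorch`. Unlike the notebook's `str.replace` call, it rewrites only the leading `data_dir` prefix, because `str.replace` would also rewrite any later occurrence of the same string inside a path.

```python
import os


def cropped_path(sample_path: str, data_dir: str, suffix: str = "_cropped") -> str:
    """Map <data_dir>/<class>/<file> to <data_dir><suffix>/<class>/<file>.

    Hypothetical helper: only the leading data_dir prefix is rewritten,
    whereas str.replace would rewrite every occurrence of data_dir in the path.
    """
    rel = os.path.relpath(sample_path, data_dir)  # e.g. "angelina_jolie/1.jpg"
    if rel == os.pardir or rel.startswith(os.pardir + os.sep):
        raise ValueError(f"{sample_path!r} is not inside {data_dir!r}")
    return os.path.join(data_dir + suffix, rel)


if __name__ == "__main__":
    data_dir = os.path.join("..", "data", "test_images")
    src = os.path.join(data_dir, "angelina_jolie", "1.jpg")
    print(cropped_path(src, data_dir))
```

Because the `_cropped` tree keeps one subdirectory per identity, the later `datasets.ImageFolder(data_dir + '_cropped', ...)` call can reuse the folder names as class labels.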
75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "mtcnn = MTCNN(\n", 84 | " image_size=160, margin=0, min_face_size=20,\n", 85 | " thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,\n", 86 | " device=device\n", 87 | ")" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "#### Perform MTCNN face detection\n", 95 | "\n", 96 | "Iterate through the DataLoader object, detect a face in each image, and save the cropped face to a mirrored `_cropped` directory." 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": { 103 | "scrolled": true 104 | }, 105 | "outputs": [], 106 | "source": [ 107 | "dataset = datasets.ImageFolder(data_dir, transform=transforms.Resize((512, 512)))\n", 108 | "dataset.samples = [\n", 109 | " (p, p.replace(data_dir, data_dir + '_cropped'))\n", 110 | " for p, _ in dataset.samples\n", 111 | "]\n", 112 | " \n", 113 | "loader = DataLoader(\n", 114 | " dataset,\n", 115 | " num_workers=workers,\n", 116 | " batch_size=batch_size,\n", 117 | " collate_fn=training.collate_pil\n", 118 | ")\n", 119 | "\n", 120 | "for i, (x, y) in enumerate(loader):\n", 121 | " mtcnn(x, save_path=y)\n", 122 | " print('\\rBatch {} of {}'.format(i + 1, len(loader)), end='')\n", 123 | " \n", 124 | "# Remove mtcnn to reduce GPU memory usage\n", 125 | "del mtcnn" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": {}, 131 | "source": [ 132 | "#### Define Inception Resnet V1 module\n", 133 | "\n", 134 | "See `help(InceptionResnetV1)` for more details."
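The optimizer cell in this notebook pairs `optim.Adam(resnet.parameters(), lr=0.001)` with `MultiStepLR(optimizer, [5, 10])`, while the run parameters set `epochs = 8`. `MultiStepLR` multiplies the learning rate by `gamma` (0.1 by default) once the scheduler's step count reaches each milestone. The plain-Python sketch below computes the rate seen in each epoch. It assumes `training.pass_epoch` steps the scheduler once per training epoch. `multistep_lr` is a hypothetical helper that mirrors the closed-form schedule; it is not a PyTorch API.

```python
from bisect import bisect_right


def multistep_lr(base_lr, milestones, epoch, gamma=0.1):
    """Learning rate in effect during 0-based `epoch` for a MultiStepLR-style
    schedule: the base rate is multiplied by `gamma` once for every milestone
    that is <= epoch (i.e. once that many scheduler steps have happened)."""
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


if __name__ == "__main__":
    base_lr, milestones, epochs = 0.001, [5, 10], 8  # values from the notebook
    for epoch in range(epochs):
        print(f"epoch {epoch + 1}: lr = {multistep_lr(base_lr, milestones, epoch):.0e}")
```

Under that assumption, epochs 1 to 5 train at 1e-3 and epochs 6 to 8 at 1e-4. The milestone at 10 would only take effect from an 11th epoch onward.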
135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "resnet = InceptionResnetV1(\n", 144 | " classify=True,\n", 145 | " pretrained='vggface2',\n", 146 | " num_classes=len(dataset.class_to_idx)\n", 147 | ").to(device)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "#### Define optimizer, scheduler, dataset, and dataloader" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "optimizer = optim.Adam(resnet.parameters(), lr=0.001)\n", 164 | "scheduler = MultiStepLR(optimizer, [5, 10])\n", 165 | "\n", 166 | "trans = transforms.Compose([\n", 167 | " np.float32,\n", 168 | " transforms.ToTensor(),\n", 169 | " fixed_image_standardization\n", 170 | "])\n", 171 | "dataset = datasets.ImageFolder(data_dir + '_cropped', transform=trans)\n", 172 | "img_inds = np.arange(len(dataset))\n", 173 | "np.random.shuffle(img_inds)\n", 174 | "train_inds = img_inds[:int(0.8 * len(img_inds))]\n", 175 | "val_inds = img_inds[int(0.8 * len(img_inds)):]\n", 176 | "\n", 177 | "train_loader = DataLoader(\n", 178 | " dataset,\n", 179 | " num_workers=workers,\n", 180 | " batch_size=batch_size,\n", 181 | " sampler=SubsetRandomSampler(train_inds)\n", 182 | ")\n", 183 | "val_loader = DataLoader(\n", 184 | " dataset,\n", 185 | " num_workers=workers,\n", 186 | " batch_size=batch_size,\n", 187 | " sampler=SubsetRandomSampler(val_inds)\n", 188 | ")" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "#### Define loss and evaluation functions" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "loss_fn = torch.nn.CrossEntropyLoss()\n", 205 | "metrics = {\n", 206 | " 'fps': training.BatchTimer(),\n", 207 | " 'acc': 
training.accuracy\n", 208 | "}" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": {}, 214 | "source": [ 215 | "#### Train model" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "writer = SummaryWriter()\n", 225 | "writer.iteration, writer.interval = 0, 10\n", 226 | "\n", 227 | "print('\\n\\nInitial')\n", 228 | "print('-' * 10)\n", 229 | "resnet.eval()\n", 230 | "training.pass_epoch(\n", 231 | " resnet, loss_fn, val_loader,\n", 232 | " batch_metrics=metrics, show_running=True, device=device,\n", 233 | " writer=writer\n", 234 | ")\n", 235 | "\n", 236 | "for epoch in range(epochs):\n", 237 | " print('\\nEpoch {}/{}'.format(epoch + 1, epochs))\n", 238 | " print('-' * 10)\n", 239 | "\n", 240 | " resnet.train()\n", 241 | " training.pass_epoch(\n", 242 | " resnet, loss_fn, train_loader, optimizer, scheduler,\n", 243 | " batch_metrics=metrics, show_running=True, device=device,\n", 244 | " writer=writer\n", 245 | " )\n", 246 | "\n", 247 | " resnet.eval()\n", 248 | " training.pass_epoch(\n", 249 | " resnet, loss_fn, val_loader,\n", 250 | " batch_metrics=metrics, show_running=True, device=device,\n", 251 | " writer=writer\n", 252 | " )\n", 253 | "\n", 254 | "writer.close()" 255 | ] 256 | } 257 | ], 258 | "metadata": { 259 | "kernelspec": { 260 | "display_name": "Python 3", 261 | "language": "python", 262 | "name": "python3" 263 | }, 264 | "language_info": { 265 | "codemirror_mode": { 266 | "name": "ipython", 267 | "version": 3 268 | }, 269 | "file_extension": ".py", 270 | "mimetype": "text/x-python", 271 | "name": "python", 272 | "nbconvert_exporter": "python", 273 | "pygments_lexer": "ipython3", 274 | "version": "3.7.3" 275 | } 276 | }, 277 | "nbformat": 4, 278 | "nbformat_minor": 2 279 | } 280 | -------------------------------------------------------------------------------- /examples/finetune_cn.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Face detection and recognition training pipeline\n", 8 | "\n", 9 | "The following example illustrates how to fine-tune an InceptionResnetV1 model on your own dataset. This will mostly follow standard PyTorch training patterns." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization, training\n", 19 | "import torch\n", 20 | "from torch.utils.data import DataLoader, SubsetRandomSampler\n", 21 | "from torch import optim\n", 22 | "from torch.optim.lr_scheduler import MultiStepLR\n", 23 | "from torch.utils.tensorboard import SummaryWriter\n", 24 | "from torchvision import datasets, transforms\n", 25 | "import numpy as np\n", 26 | "import os" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "#### Define run parameters\n", 34 | "\n", 35 | "The dataset should follow the VGGFace2/ImageNet-style directory layout. Set `data_dir` to the location of the dataset you wish to fine-tune on." 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "data_dir = '../data/test_images'\n", 45 | "\n", 46 | "batch_size = 32\n", 47 | "epochs = 8\n", 48 | "workers = 0 if os.name == 'nt' else 8" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "#### Determine if an NVIDIA GPU is available" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 65 | "print('Running on device: {}'.format(device))" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "#### Define MTCNN module\n", 73 | "\n", 74 | "See `help(MTCNN)` for more details." 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 |
"source": [ 83 | "mtcnn = MTCNN(\n", 84 | " image_size=160, margin=0, min_face_size=20,\n", 85 | " thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,\n", 86 | " device=device\n", 87 | ")" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "#### 执行MTCNN人脸检测\n", 95 | "\n", 96 | "迭代DataLoader对象并获取裁剪后的人脸。" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": { 103 | "scrolled": true 104 | }, 105 | "outputs": [], 106 | "source": [ 107 | "dataset = datasets.ImageFolder(data_dir, transform=transforms.Resize((512, 512)))\n", 108 | "dataset.samples = [\n", 109 | " (p, p.replace(data_dir, data_dir + '_cropped'))\n", 110 | " for p, _ in dataset.samples\n", 111 | "]\n", 112 | " \n", 113 | "loader = DataLoader(\n", 114 | " dataset,\n", 115 | " num_workers=workers,\n", 116 | " batch_size=batch_size,\n", 117 | " collate_fn=training.collate_pil\n", 118 | ")\n", 119 | "\n", 120 | "for i, (x, y) in enumerate(loader):\n", 121 | " mtcnn(x, save_path=y)\n", 122 | " print('\\r第 {} 批,共 {} 批'.format(i + 1, len(loader)), end='')\n", 123 | " \n", 124 | "# Remove mtcnn to reduce GPU memory usage\n", 125 | "del mtcnn" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": {}, 131 | "source": [ 132 | "#### 定义Inception Resnet V1模块\n", 133 | "\n", 134 | "查看`help(InceptionResnetV1)`获取更多细节。" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": {}, 141 | "outputs": [], 142 | "source": [ 143 | "resnet = InceptionResnetV1(\n", 144 | " classify=True,\n", 145 | " pretrained='vggface2',\n", 146 | " num_classes=len(dataset.class_to_idx)\n", 147 | ").to(device)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "metadata": {}, 153 | "source": [ 154 | "#### 定义优化器、调度器、数据集和数据加载器" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | 
"optimizer = optim.Adam(resnet.parameters(), lr=0.001)\n", 164 | "scheduler = MultiStepLR(optimizer, [5, 10])\n", 165 | "\n", 166 | "trans = transforms.Compose([\n", 167 | " np.float32,\n", 168 | " transforms.ToTensor(),\n", 169 | " fixed_image_standardization\n", 170 | "])\n", 171 | "dataset = datasets.ImageFolder(data_dir + '_cropped', transform=trans)\n", 172 | "img_inds = np.arange(len(dataset))\n", 173 | "np.random.shuffle(img_inds)\n", 174 | "train_inds = img_inds[:int(0.8 * len(img_inds))]\n", 175 | "val_inds = img_inds[int(0.8 * len(img_inds)):]\n", 176 | "\n", 177 | "train_loader = DataLoader(\n", 178 | " dataset,\n", 179 | " num_workers=workers,\n", 180 | " batch_size=batch_size,\n", 181 | " sampler=SubsetRandomSampler(train_inds)\n", 182 | ")\n", 183 | "val_loader = DataLoader(\n", 184 | " dataset,\n", 185 | " num_workers=workers,\n", 186 | " batch_size=batch_size,\n", 187 | " sampler=SubsetRandomSampler(val_inds)\n", 188 | ")" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "#### 定义损失和评估函数" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "loss_fn = torch.nn.CrossEntropyLoss()\n", 205 | "metrics = {\n", 206 | " 'fps': training.BatchTimer(),\n", 207 | " 'acc': training.accuracy\n", 208 | "}" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": {}, 214 | "source": [ 215 | "#### 训练模型" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "writer = SummaryWriter()\n", 225 | "writer.iteration, writer.interval = 0, 10\n", 226 | "\n", 227 | "print('\\n\\n初始化')\n", 228 | "print('-' * 10)\n", 229 | "resnet.eval()\n", 230 | "training.pass_epoch(\n", 231 | " resnet, loss_fn, val_loader,\n", 232 | " batch_metrics=metrics, show_running=True, device=device,\n", 233 | " writer=writer\n", 234 | ")\n", 
235 | "\n", 236 | "for epoch in range(epochs):\n", 237 | " print('\\n循环 {}/{}'.format(epoch + 1, epochs))\n", 238 | " print('-' * 10)\n", 239 | "\n", 240 | " resnet.train()\n", 241 | " training.pass_epoch(\n", 242 | " resnet, loss_fn, train_loader, optimizer, scheduler,\n", 243 | " batch_metrics=metrics, show_running=True, device=device,\n", 244 | " writer=writer\n", 245 | " )\n", 246 | "\n", 247 | " resnet.eval()\n", 248 | " training.pass_epoch(\n", 249 | " resnet, loss_fn, val_loader,\n", 250 | " batch_metrics=metrics, show_running=True, device=device,\n", 251 | " writer=writer\n", 252 | " )\n", 253 | "\n", 254 | "writer.close()" 255 | ] 256 | } 257 | ], 258 | "metadata": { 259 | "kernelspec": { 260 | "display_name": "Python 3 (ipykernel)", 261 | "language": "python", 262 | "name": "python3" 263 | }, 264 | "language_info": { 265 | "codemirror_mode": { 266 | "name": "ipython", 267 | "version": 3 268 | }, 269 | "file_extension": ".py", 270 | "mimetype": "text/x-python", 271 | "name": "python", 272 | "nbconvert_exporter": "python", 273 | "pygments_lexer": "ipython3", 274 | "version": "3.11.3" 275 | } 276 | }, 277 | "nbformat": 4, 278 | "nbformat_minor": 2 279 | } 280 | -------------------------------------------------------------------------------- /examples/infer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Face detection and recognition inference pipeline\n", 8 | "\n", 9 | "The following example illustrates how to use the `facenet_pytorch` python package to perform face detection and recogition on an image dataset using an Inception Resnet V1 pretrained on the VGGFace2 dataset.\n", 10 | "\n", 11 | "The following Pytorch methods are included:\n", 12 | "* Datasets\n", 13 | "* Dataloaders\n", 14 | "* GPU/CPU processing" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | 
"outputs": [], 22 | "source": [ 23 | "from facenet_pytorch import MTCNN, InceptionResnetV1\n", 24 | "import torch\n", 25 | "from torch.utils.data import DataLoader\n", 26 | "from torchvision import datasets\n", 27 | "import numpy as np\n", 28 | "import pandas as pd\n", 29 | "import os\n", 30 | "\n", 31 | "workers = 0 if os.name == 'nt' else 4" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "#### Determine if an nvidia GPU is available" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "name": "stdout", 48 | "output_type": "stream", 49 | "text": [ 50 | "Running on device: cuda:0\n" 51 | ] 52 | } 53 | ], 54 | "source": [ 55 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 56 | "print('Running on device: {}'.format(device))" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "#### Define MTCNN module\n", 64 | "\n", 65 | "Default params shown for illustration, but not needed. Note that, since MTCNN is a collection of neural nets and other code, the device must be passed in the following way to enable copying of objects when needed internally.\n", 66 | "\n", 67 | "See `help(MTCNN)` for more details." 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 3, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "mtcnn = MTCNN(\n", 77 | " image_size=160, margin=0, min_face_size=20,\n", 78 | " thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,\n", 79 | " device=device\n", 80 | ")" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "#### Define Inception Resnet V1 module\n", 88 | "\n", 89 | "Set classify=True for pretrained classifier. For this example, we will use the model to output embeddings/CNN features. 
Note that for inference, it is important to set the model to `eval` mode.\n", 90 | "\n", 91 | "See `help(InceptionResnetV1)` for more details." 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 4, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": {}, 106 | "source": [ 107 | "#### Define a dataset and data loader\n", 108 | "\n", 109 | "We add the `idx_to_class` attribute to the dataset so that label indices can easily be mapped back to identity names later on. The custom `collate_fn` unwraps each single-item batch, so the loader yields one PIL image and its label at a time." 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 5, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "def collate_fn(x):\n", 119 | " return x[0]\n", 120 | "\n", 121 | "dataset = datasets.ImageFolder('../data/test_images')\n", 122 | "dataset.idx_to_class = {i:c for c, i in dataset.class_to_idx.items()}\n", 123 | "loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=workers)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "#### Perform MTCNN face detection\n", 131 | "\n", 132 | "Iterate through the DataLoader object, detecting faces and their associated detection probabilities. If a face is detected, the `MTCNN` forward method returns the image cropped to that face; otherwise the returned crop is `None`, which the loop below checks for. By default, only a single detected face is returned; to have `MTCNN` return all detected faces, set `keep_all=True` when creating the MTCNN object above.\n", 133 | "\n", 134 | "To obtain bounding boxes rather than cropped face images, call the lower-level `mtcnn.detect()` method instead. See `help(mtcnn.detect)` for details."
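MTCNN's first stage scans an image pyramid, and the `min_face_size` and `factor` arguments in the MTCNN cell above control how many pyramid levels are built. The sketch below follows the scale loop of the original MTCNN detector, which we assume `facenet_pytorch` mirrors: the image is first rescaled so that a `min_face_size`-pixel face fills the 12-pixel P-Net window, then repeatedly shrunk by `factor` until its shorter side drops below 12 pixels. `pyramid_scales` is a hypothetical helper for illustration only.

```python
def pyramid_scales(height, width, min_face_size=20, factor=0.709, window=12):
    """Scales at which an MTCNN-style first stage (P-Net) would be run.

    At scale s, a window x window P-Net input covers a face of about
    window / s pixels in the original image, so the first scale makes a
    min_face_size face fill one window. Each later level shrinks by `factor`,
    stopping once the shorter side of the rescaled image drops below `window`.
    """
    scale = window / min_face_size
    min_side = min(height, width) * scale
    scales = []
    while min_side >= window:
        scales.append(scale)
        scale *= factor
        min_side *= factor
    return scales


if __name__ == "__main__":
    # e.g. a 512x512 image with the notebook's min_face_size=20 and factor=0.709
    scales = pyramid_scales(512, 512)
    print(len(scales), [round(s, 3) for s in scales])
```

Smaller `min_face_size` values or a `factor` closer to 1 add pyramid levels, which helps find small faces at the cost of more first-stage passes.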
135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 6, 140 | "metadata": {}, 141 | "outputs": [ 142 | { 143 | "name": "stdout", 144 | "output_type": "stream", 145 | "text": [ 146 | "Face detected with probability: 0.999957\n", 147 | "Face detected with probability: 0.999927\n", 148 | "Face detected with probability: 0.999662\n", 149 | "Face detected with probability: 0.999873\n", 150 | "Face detected with probability: 0.999991\n" 151 | ] 152 | } 153 | ], 154 | "source": [ 155 | "aligned = []\n", 156 | "names = []\n", 157 | "for x, y in loader:\n", 158 | " x_aligned, prob = mtcnn(x, return_prob=True)\n", 159 | " if x_aligned is not None:\n", 160 | " print('Face detected with probability: {:8f}'.format(prob))\n", 161 | " aligned.append(x_aligned)\n", 162 | " names.append(dataset.idx_to_class[y])" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "#### Calculate image embeddings\n", 170 | "\n", 171 | "MTCNN returns face images that are all the same size, enabling easy batch processing with the Resnet recognition module. Here, since we only have a few images, we build a single batch and perform inference on it.\n", 172 | "\n", 173 | "For real datasets, the code should be modified to limit the batch size passed to the Resnet, particularly when running on a GPU. For repeated testing, it is best to separate face detection (using MTCNN) from embedding or classification (using InceptionResnetV1), since the cropped faces or bounding boxes can then be computed once and saved for future use."
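On real datasets, the note above suggests limiting how many faces are sent to the Resnet at once. A minimal batching helper is sketched below in plain Python so that it runs without PyTorch. `iter_batches`, `embed_all`, and `fake_embed` are hypothetical names, and `fake_embed` stands in for a model call such as `resnet(torch.stack(batch).to(device))`.

```python
from typing import Callable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")
E = TypeVar("E")


def iter_batches(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of at most `batch_size` items, in order."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def embed_all(items: Sequence[T],
              embed_batch: Callable[[Sequence[T]], List[E]],
              batch_size: int = 32) -> List[E]:
    """Apply `embed_batch` chunk by chunk and concatenate the results in input order."""
    out: List[E] = []
    for batch in iter_batches(items, batch_size):
        out.extend(embed_batch(batch))
    return out


if __name__ == "__main__":
    faces = list(range(70))  # placeholders for 70 aligned face tensors
    seen_sizes = []

    def fake_embed(batch):
        # Stand-in for the Resnet call: record the batch size, return one value per face.
        seen_sizes.append(len(batch))
        return [2 * x for x in batch]

    embeddings = embed_all(faces, fake_embed, batch_size=32)
    print(seen_sizes, len(embeddings))  # batch sizes 32, 32 and 6
```

With PyTorch, the per-batch call would typically run under `torch.no_grad()` so that no autograd graph is kept, and the per-batch outputs could be joined with `torch.cat` instead of `list.extend`.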
174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 7, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "aligned = torch.stack(aligned).to(device)\n", 183 | "embeddings = resnet(aligned).detach().cpu()" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "#### Print distance matrix for classes" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 8, 196 | "metadata": {}, 197 | "outputs": [ 198 | { 199 | "name": "stdout", 200 | "output_type": "stream", 201 | "text": [ 202 | " angelina_jolie bradley_cooper kate_siegel paul_rudd \\\n", 203 | "angelina_jolie 0.000000 1.344806 0.781201 1.425579 \n", 204 | "bradley_cooper 1.344806 0.000000 1.256238 0.922126 \n", 205 | "kate_siegel 0.781201 1.256238 0.000000 1.366423 \n", 206 | "paul_rudd 1.425579 0.922126 1.366423 0.000000 \n", 207 | "shea_whigham 1.448495 0.891145 1.416447 0.985438 \n", 208 | "\n", 209 | " shea_whigham \n", 210 | "angelina_jolie 1.448495 \n", 211 | "bradley_cooper 0.891145 \n", 212 | "kate_siegel 1.416447 \n", 213 | "paul_rudd 0.985438 \n", 214 | "shea_whigham 0.000000 \n" 215 | ] 216 | } 217 | ], 218 | "source": [ 219 | "dists = [[(e1 - e2).norm().item() for e2 in embeddings] for e1 in embeddings]\n", 220 | "print(pd.DataFrame(dists, columns=names, index=names))" 221 | ] 222 | } 223 | ], 224 | "metadata": { 225 | "kernelspec": { 226 | "display_name": "Python 3", 227 | "language": "python", 228 | "name": "python3" 229 | }, 230 | "language_info": { 231 | "codemirror_mode": { 232 | "name": "ipython", 233 | "version": 3 234 | }, 235 | "file_extension": ".py", 236 | "mimetype": "text/x-python", 237 | "name": "python", 238 | "nbconvert_exporter": "python", 239 | "pygments_lexer": "ipython3", 240 | "version": "3.7.3" 241 | } 242 | }, 243 | "nbformat": 4, 244 | "nbformat_minor": 2 245 | } -------------------------------------------------------------------------------- /examples/infer_cn.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Face detection and recognition inference pipeline\n", 8 | "\n", 9 | "The following example illustrates how to use the facenet_pytorch Python package to perform face detection and recognition on an image dataset using an Inception Resnet V1 model pretrained on the VGGFace2 dataset.\n", 10 | "\n", 11 | "The following PyTorch features are included:\n", 12 | "\n", 13 | "* Datasets\n", 14 | "* Dataloaders\n", 15 | "* GPU/CPU processing" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "metadata": { 22 | "ExecuteTime": { 23 | "end_time": "2023-07-20T11:23:56.432302Z", 24 | "start_time": "2023-07-20T11:23:54.333487Z" 25 | } 26 | }, 27 | "outputs": [], 28 | "source": [ 29 | "from facenet_pytorch import MTCNN, InceptionResnetV1\n", 30 | "import torch\n", 31 | "from torch.utils.data import DataLoader\n", 32 | "from torchvision import datasets\n", 33 | "import numpy as np\n", 34 | "import pandas as pd\n", 35 | "import os\n", 36 | "\n", 37 | "workers = 0 if os.name == 'nt' else 4" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "metadata": {}, 43 | "source": [ 44 | "#### Determine if an NVIDIA GPU is available" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "metadata": { 51 | "ExecuteTime": { 52 | "end_time": "2023-07-20T11:23:56.470217Z", 53 | "start_time": "2023-07-20T11:23:56.436290Z" 54 | } 55 | }, 56 | "outputs": [ 57 | { 58 | "name": "stdout", 59 | "output_type": "stream", 60 | "text": [ 61 | "Running on device: cuda:0\n" 62 | ] 63 | } 64 | ], 65 | "source": [ 66 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 67 | "print('Running on device: {}'.format(device))" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "#### Define MTCNN module\n", 75 | "\n", 76 | "Default parameters are shown for illustration but are not required. Because MTCNN is a collection of neural networks and other code, the device must be passed as shown below so that objects can be copied to it internally when needed.\n", 77 | "\n", 78 | "See `help(MTCNN)` for more details." 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 3, 84 | "metadata": { 85 | "ExecuteTime": { 86 |
"end_time": "2023-07-20T11:23:56.587926Z", 87 | "start_time": "2023-07-20T11:23:56.472212Z" 88 | } 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "mtcnn = MTCNN(\n", 93 | " image_size=160, margin=0, min_face_size=20,\n", 94 | " thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,\n", 95 | " device=device\n", 96 | ")" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "#### 定义Inception Resnet V1模块\n", 104 | "\n", 105 | "设置classify=True以使用预训练分类器。对于本示例,我们将使用该模型输出嵌入/卷积特征。请注意,在推理过程中,将模型设置为`eval`模式非常重要。\n", 106 | "\n", 107 | "查看`help(InceptionResnetV1)`获取更多细节。" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 4, 113 | "metadata": { 114 | "ExecuteTime": { 115 | "end_time": "2023-07-20T11:23:56.988662Z", 116 | "start_time": "2023-07-20T11:23:56.588910Z" 117 | } 118 | }, 119 | "outputs": [], 120 | "source": [ 121 | "resnet = InceptionResnetV1(pretrained='vggface2').eval().to(device)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "#### 定义数据集和数据加载器\n", 129 | "\n", 130 | "我们向数据集添加了`idx_to_class`属性,以便稍后轻松重编标签索引为身份名称。" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 5, 136 | "metadata": { 137 | "ExecuteTime": { 138 | "end_time": "2023-07-20T11:23:56.995647Z", 139 | "start_time": "2023-07-20T11:23:56.989657Z" 140 | } 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "def collate_fn(x):\n", 145 | " return x[0]\n", 146 | "\n", 147 | "dataset = datasets.ImageFolder('../data/test_images')\n", 148 | "dataset.idx_to_class = {i:c for c, i in dataset.class_to_idx.items()}\n", 149 | "loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=workers)" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": {}, 155 | "source": [ 156 | "#### 执行MTCNN人脸检测\n", 157 | "\n", 158 | 
"迭代DataLoader对象并检测每个人脸及其关联的检测概率。如果检测到脸部,`MTCNN`的前向方法将返回裁剪到检测到的脸部的图像。默认情况下,仅返回检测到的单个面部-要使`MTCNN`返回所有检测到的面部,请在上面创建MTCNN对象时设置`keep_all=True`。\n", 159 | "\n", 160 | "要获取边界框而不是裁剪的人脸图像,可以调用较低级别的`mtcnn.detect()`函数。查看`help(mtcnn.detect)`获取详细信息。" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 6, 166 | "metadata": { 167 | "ExecuteTime": { 168 | "end_time": "2023-07-20T11:24:00.971832Z", 169 | "start_time": "2023-07-20T11:23:56.996640Z" 170 | } 171 | }, 172 | "outputs": [ 173 | { 174 | "name": "stdout", 175 | "output_type": "stream", 176 | "text": [ 177 | "检测到的人脸及其概率: 0.999983\n", 178 | "检测到的人脸及其概率: 0.999934\n", 179 | "检测到的人脸及其概率: 0.999733\n", 180 | "检测到的人脸及其概率: 0.999880\n", 181 | "检测到的人脸及其概率: 0.999992\n" 182 | ] 183 | } 184 | ], 185 | "source": [ 186 | "aligned = []\n", 187 | "names = []\n", 188 | "for x, y in loader:\n", 189 | " x_aligned, prob = mtcnn(x, return_prob=True)\n", 190 | " if x_aligned is not None:\n", 191 | " print('检测到的人脸及其概率: {:8f}'.format(prob))\n", 192 | " aligned.append(x_aligned)\n", 193 | " names.append(dataset.idx_to_class[y])" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": {}, 199 | "source": [ 200 | "#### 计算图像嵌入\n", 201 | "\n", 202 | "MTCNN将返回所有面部图像的相同大小,从而可以使用Resnet识别模块轻松进行批处理。在这里,由于我们只有一些图像,因此我们构建一个单个批次并对其执行推理。\n", 203 | "\n", 204 | "对于实际数据集,代码应修改为控制传递给Resnet的批处理大小,特别是如果在GPU上处理。对于重复测试,最好将人脸检测(使用MTCNN)与嵌入或分类(使用InceptionResnetV1)分开,因为剪切面或边界框的计算可以一次执行,检测到的面部可以保存供将来使用。" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 7, 210 | "metadata": { 211 | "ExecuteTime": { 212 | "end_time": "2023-07-20T11:24:01.037605Z", 213 | "start_time": "2023-07-20T11:24:00.973820Z" 214 | } 215 | }, 216 | "outputs": [], 217 | "source": [ 218 | "aligned = torch.stack(aligned).to(device)\n", 219 | "embeddings = resnet(aligned).detach().cpu()" 220 | ] 221 | }, 222 | { 223 | "cell_type": "markdown", 224 | "metadata": {}, 225 | "source": [ 226 | "#### 打印各类别的距离矩阵" 227 | ] 228 | }, 229 | { 230 | 
"cell_type": "code", 231 | "execution_count": 8, 232 | "metadata": { 233 | "ExecuteTime": { 234 | "end_time": "2023-07-20T11:24:01.050571Z", 235 | "start_time": "2023-07-20T11:24:01.038602Z" 236 | } 237 | }, 238 | "outputs": [ 239 | { 240 | "data": { 241 | "text/html": [ 242 | "
\n", 243 | "\n", 256 | "\n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | "
angelina_joliebradley_cooperkate_siegelpaul_ruddshea_whigham
angelina_jolie0.0000001.4474800.8877281.4298471.399073
bradley_cooper1.4474800.0000001.3137491.0134471.038684
kate_siegel0.8877281.3137490.0000001.3883771.379655
paul_rudd1.4298471.0134471.3883770.0000001.100502
shea_whigham1.3990731.0386841.3796551.1005020.000000
\n", 310 | "
" 311 | ], 312 | "text/plain": [ 313 | " angelina_jolie bradley_cooper kate_siegel paul_rudd \\\n", 314 | "angelina_jolie 0.000000 1.447480 0.887728 1.429847 \n", 315 | "bradley_cooper 1.447480 0.000000 1.313749 1.013447 \n", 316 | "kate_siegel 0.887728 1.313749 0.000000 1.388377 \n", 317 | "paul_rudd 1.429847 1.013447 1.388377 0.000000 \n", 318 | "shea_whigham 1.399073 1.038684 1.379655 1.100502 \n", 319 | "\n", 320 | " shea_whigham \n", 321 | "angelina_jolie 1.399073 \n", 322 | "bradley_cooper 1.038684 \n", 323 | "kate_siegel 1.379655 \n", 324 | "paul_rudd 1.100502 \n", 325 | "shea_whigham 0.000000 " 326 | ] 327 | }, 328 | "execution_count": 8, 329 | "metadata": {}, 330 | "output_type": "execute_result" 331 | } 332 | ], 333 | "source": [ 334 | "dists = [[(e1 - e2).norm().item() for e2 in embeddings] for e1 in embeddings]\n", 335 | "pd.DataFrame(dists, columns=names, index=names)" 336 | ] 337 | } 338 | ], 339 | "metadata": { 340 | "kernelspec": { 341 | "display_name": "Python 3 (ipykernel)", 342 | "language": "python", 343 | "name": "python3" 344 | }, 345 | "language_info": { 346 | "codemirror_mode": { 347 | "name": "ipython", 348 | "version": 3 349 | }, 350 | "file_extension": ".py", 351 | "mimetype": "text/x-python", 352 | "name": "python", 353 | "nbconvert_exporter": "python", 354 | "pygments_lexer": "ipython3", 355 | "version": "3.11.3" 356 | } 357 | }, 358 | "nbformat": 4, 359 | "nbformat_minor": 2 360 | } 361 | -------------------------------------------------------------------------------- /examples/lfw_evaluate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "### facenet-pytorch LFW evaluation\n", 7 | "This notebook demonstrates how to evaluate performance against the LFW dataset." 
8 | ], 9 | "metadata": { 10 | "collapsed": false 11 | } 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "outputs": [], 17 | "source": [ 18 | "from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization, training, extract_face\n", 19 | "import torch\n", 20 | "from torch.utils.data import DataLoader, SubsetRandomSampler, SequentialSampler\n", 21 | "from torchvision import datasets, transforms\n", 22 | "import numpy as np\n", 23 | "import os" 24 | ], 25 | "metadata": { 26 | "collapsed": false, 27 | "pycharm": { 28 | "name": "#%%\n" 29 | } 30 | } 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "outputs": [], 36 | "source": [ 37 | "data_dir = 'data/lfw/lfw'\n", 38 | "pairs_path = 'data/lfw/pairs.txt'\n", 39 | "\n", 40 | "batch_size = 16\n", 41 | "epochs = 15\n", 42 | "workers = 0 if os.name == 'nt' else 8" 43 | ], 44 | "metadata": { 45 | "collapsed": false, 46 | "pycharm": { 47 | "name": "#%%\n" 48 | } 49 | } 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 3, 54 | "outputs": [ 55 | { 56 | "name": "stdout", 57 | "output_type": "stream", 58 | "text": [ 59 | "Running on device: cuda:0\n" 60 | ] 61 | } 62 | ], 63 | "source": [ 64 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 65 | "print('Running on device: {}'.format(device))" 66 | ], 67 | "metadata": { 68 | "collapsed": false, 69 | "pycharm": { 70 | "name": "#%%\n" 71 | } 72 | } 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 4, 77 | "outputs": [], 78 | "source": [ 79 | "mtcnn = MTCNN(\n", 80 | " image_size=160,\n", 81 | " margin=14,\n", 82 | " device=device,\n", 83 | " selection_method='center_weighted_size'\n", 84 | ")" 85 | ], 86 | "metadata": { 87 | "collapsed": false, 88 | "pycharm": { 89 | "name": "#%%\n" 90 | } 91 | } 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 5, 96 | "outputs": [], 97 | "source": [ 98 | "# Define the data loader for the input set of 
images\n", 99 | "orig_img_ds = datasets.ImageFolder(data_dir, transform=None)" 100 | ], 101 | "metadata": { 102 | "collapsed": false, 103 | "pycharm": { 104 | "name": "#%%\n" 105 | } 106 | } 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 6, 111 | "outputs": [], 112 | "source": [ 113 | "\n", 114 | "# overwrites class labels in dataset with path so path can be used for saving output in mtcnn batches\n", 115 | "orig_img_ds.samples = [\n", 116 | " (p, p)\n", 117 | " for p, _ in orig_img_ds.samples\n", 118 | "]\n", 119 | "\n", 120 | "loader = DataLoader(\n", 121 | " orig_img_ds,\n", 122 | " num_workers=workers,\n", 123 | " batch_size=batch_size,\n", 124 | " collate_fn=training.collate_pil\n", 125 | ")\n" 126 | ], 127 | "metadata": { 128 | "collapsed": false, 129 | "pycharm": { 130 | "name": "#%%\n" 131 | } 132 | } 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "crop_paths = []\n", 141 | "box_probs = []\n", 142 | "\n", 143 | "for i, (x, b_paths) in enumerate(loader):\n", 144 | " crops = [p.replace(data_dir, data_dir + '_cropped') for p in b_paths]\n", 145 | " mtcnn(x, save_path=crops)\n", 146 | " crop_paths.extend(crops)\n", 147 | " print('\\rBatch {} of {}'.format(i + 1, len(loader)), end='')" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 8, 153 | "outputs": [], 154 | "source": [ 155 | "# Remove mtcnn to reduce GPU memory usage\n", 156 | "del mtcnn\n", 157 | "torch.cuda.empty_cache()" 158 | ], 159 | "metadata": { 160 | "collapsed": false, 161 | "pycharm": { 162 | "name": "#%%\n" 163 | } 164 | } 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 9, 169 | "outputs": [], 170 | "source": [ 171 | "# create dataset and data loaders from cropped images output from MTCNN\n", 172 | "\n", 173 | "trans = transforms.Compose([\n", 174 | " np.float32,\n", 175 | " transforms.ToTensor(),\n", 176 | " 
fixed_image_standardization\n", 177 | "])\n", 178 | "\n", 179 | "dataset = datasets.ImageFolder(data_dir + '_cropped', transform=trans)\n", 180 | "\n", 181 | "embed_loader = DataLoader(\n", 182 | " dataset,\n", 183 | " num_workers=workers,\n", 184 | " batch_size=batch_size,\n", 185 | " sampler=SequentialSampler(dataset)\n", 186 | ")" 187 | ], 188 | "metadata": { 189 | "collapsed": false, 190 | "pycharm": { 191 | "name": "#%%\n" 192 | } 193 | } 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 10, 198 | "outputs": [], 199 | "source": [ 200 | "# Load pretrained resnet model\n", 201 | "resnet = InceptionResnetV1(\n", 202 | " classify=False,\n", 203 | " pretrained='vggface2'\n", 204 | ").to(device)" 205 | ], 206 | "metadata": { 207 | "collapsed": false, 208 | "pycharm": { 209 | "name": "#%%\n" 210 | } 211 | } 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 11, 216 | "outputs": [], 217 | "source": [ 218 | "classes = []\n", 219 | "embeddings = []\n", 220 | "resnet.eval()\n", 221 | "with torch.no_grad():\n", 222 | " for xb, yb in embed_loader:\n", 223 | " xb = xb.to(device)\n", 224 | " b_embeddings = resnet(xb)\n", 225 | " b_embeddings = b_embeddings.to('cpu').numpy()\n", 226 | " classes.extend(yb.numpy())\n", 227 | " embeddings.extend(b_embeddings)" 228 | ], 229 | "metadata": { 230 | "collapsed": false, 231 | "pycharm": { 232 | "name": "#%%\n" 233 | } 234 | } 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 12, 239 | "outputs": [], 240 | "source": [ 241 | "embeddings_dict = dict(zip(crop_paths,embeddings))\n", 242 | "\n" 243 | ], 244 | "metadata": { 245 | "collapsed": false, 246 | "pycharm": { 247 | "name": "#%%\n" 248 | } 249 | } 250 | }, 251 | { 252 | "cell_type": "markdown", 253 | "source": [ 254 | "#### Evaluate embeddings by using distance metrics to perform verification on the official LFW test set.\n", 255 | "\n", 256 | "The functions in the next block are copy pasted from `facenet.src.lfw`. 
Unfortunately, that module uses an absolute import from `facenet`, so it can't be imported from the submodule.\n", 257 | "\n", 258 | "Added functionality to return false positives and false negatives." 259 | ], 260 | "metadata": { 261 | "collapsed": false 262 | } 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 13, 267 | "outputs": [], 268 | "source": [ 269 | "from sklearn.model_selection import KFold\n", 270 | "from scipy import interpolate\n", 271 | "\n", 272 | "# LFW functions taken from David Sandberg's FaceNet implementation\n", 273 | "def distance(embeddings1, embeddings2, distance_metric=0):\n", 274 | " if distance_metric==0:\n", 275 | " # Squared Euclidean distance\n", 276 | " diff = np.subtract(embeddings1, embeddings2)\n", 277 | " dist = np.sum(np.square(diff),1)\n", 278 | " elif distance_metric==1:\n", 279 | " # Distance based on cosine similarity\n", 280 | " dot = np.sum(np.multiply(embeddings1, embeddings2), axis=1)\n", 281 | " norm = np.linalg.norm(embeddings1, axis=1) * np.linalg.norm(embeddings2, axis=1)\n", 282 | " similarity = dot / norm\n", 283 | " dist = np.arccos(similarity) / np.pi\n", 284 | " else:\n", 285 | " raise ValueError('Undefined distance metric %d' % distance_metric)\n", 286 | "\n", 287 | " return dist\n", 288 | "\n", 289 | "def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame, nrof_folds=10, distance_metric=0, subtract_mean=False):\n", 290 | " assert(embeddings1.shape[0] == embeddings2.shape[0])\n", 291 | " assert(embeddings1.shape[1] == embeddings2.shape[1])\n", 292 | " nrof_pairs = min(len(actual_issame), embeddings1.shape[0])\n", 293 | " nrof_thresholds = len(thresholds)\n", 294 | " k_fold = KFold(n_splits=nrof_folds, shuffle=False)\n", 295 | "\n", 296 | " tprs = np.zeros((nrof_folds,nrof_thresholds))\n", 297 | " fprs = np.zeros((nrof_folds,nrof_thresholds))\n", 298 | " accuracy = np.zeros((nrof_folds))\n", 299 | "\n", 300 | " is_false_positive = []\n", 301 | " is_false_negative = []\n", 302 | "\n", 303 | " indices = 
np.arange(nrof_pairs)\n", 304 | "\n", 305 | " for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):\n", 306 | " if subtract_mean:\n", 307 | " mean = np.mean(np.concatenate([embeddings1[train_set], embeddings2[train_set]]), axis=0)\n", 308 | " else:\n", 309 | " mean = 0.0\n", 310 | " dist = distance(embeddings1-mean, embeddings2-mean, distance_metric)\n", 311 | "\n", 312 | " # Find the best threshold for the fold\n", 313 | " acc_train = np.zeros((nrof_thresholds))\n", 314 | " for threshold_idx, threshold in enumerate(thresholds):\n", 315 | " _, _, acc_train[threshold_idx], _ ,_ = calculate_accuracy(threshold, dist[train_set], actual_issame[train_set])\n", 316 | " best_threshold_index = np.argmax(acc_train)\n", 317 | " for threshold_idx, threshold in enumerate(thresholds):\n", 318 | " tprs[fold_idx,threshold_idx], fprs[fold_idx,threshold_idx], _, _, _ = calculate_accuracy(threshold, dist[test_set], actual_issame[test_set])\n", 319 | " _, _, accuracy[fold_idx], is_fp, is_fn = calculate_accuracy(thresholds[best_threshold_index], dist[test_set], actual_issame[test_set])\n", 320 | " is_false_positive.extend(is_fp)\n", 321 | " is_false_negative.extend(is_fn)\n", 322 | "\n", 323 | " tpr = np.mean(tprs,0)\n", 324 | " fpr = np.mean(fprs,0)\n", 325 | "\n", 326 | " return tpr, fpr, accuracy, is_false_positive, is_false_negative\n", 327 | "\n", 328 | "def calculate_accuracy(threshold, dist, actual_issame):\n", 329 | " predict_issame = np.less(dist, threshold)\n", 330 | " tp = np.sum(np.logical_and(predict_issame, actual_issame))\n", 331 | " fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))\n", 332 | " tn = np.sum(np.logical_and(np.logical_not(predict_issame), np.logical_not(actual_issame)))\n", 333 | " fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))\n", 334 | "\n", 335 | " is_fp = np.logical_and(predict_issame, np.logical_not(actual_issame))\n", 336 | " is_fn = 
np.logical_and(np.logical_not(predict_issame), 
actual_issame)\n", 337 | "\n", 338 | " tpr = 0 if (tp+fn==0) else float(tp) / float(tp+fn)\n", 339 | " fpr = 0 if (fp+tn==0) else float(fp) / float(fp+tn)\n", 340 | " acc = float(tp+tn)/dist.size\n", 341 | " return tpr, fpr, acc, is_fp, is_fn\n", 342 | "\n", 343 | "def calculate_val(thresholds, embeddings1, embeddings2, actual_issame, far_target, nrof_folds=10, distance_metric=0, subtract_mean=False):\n", 344 | " assert(embeddings1.shape[0] == embeddings2.shape[0])\n", 345 | " assert(embeddings1.shape[1] == embeddings2.shape[1])\n", 346 | " nrof_pairs = min(len(actual_issame), embeddings1.shape[0])\n", 347 | " nrof_thresholds = len(thresholds)\n", 348 | " k_fold = KFold(n_splits=nrof_folds, shuffle=False)\n", 349 | "\n", 350 | " val = np.zeros(nrof_folds)\n", 351 | " far = np.zeros(nrof_folds)\n", 352 | "\n", 353 | " indices = np.arange(nrof_pairs)\n", 354 | "\n", 355 | " for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):\n", 356 | " if subtract_mean:\n", 357 | " mean = np.mean(np.concatenate([embeddings1[train_set], embeddings2[train_set]]), axis=0)\n", 358 | " else:\n", 359 | " mean = 0.0\n", 360 | " dist = distance(embeddings1-mean, embeddings2-mean, distance_metric)\n", 361 | "\n", 362 | " # Find the threshold that gives FAR = far_target\n", 363 | " far_train = np.zeros(nrof_thresholds)\n", 364 | " for threshold_idx, threshold in enumerate(thresholds):\n", 365 | " _, far_train[threshold_idx] = calculate_val_far(threshold, dist[train_set], actual_issame[train_set])\n", 366 | " if np.max(far_train)>=far_target:\n", 367 | " f = interpolate.interp1d(far_train, thresholds, kind='slinear')\n", 368 | " threshold = f(far_target)\n", 369 | " else:\n", 370 | " threshold = 0.0\n", 371 | "\n", 372 | " val[fold_idx], far[fold_idx] = calculate_val_far(threshold, dist[test_set], actual_issame[test_set])\n", 373 | "\n", 374 | " val_mean = np.mean(val)\n", 375 | " far_mean = np.mean(far)\n", 376 | " val_std = np.std(val)\n", 377 | " return val_mean, 
val_std, far_mean\n", 378 | "\n", 379 | "def calculate_val_far(threshold, dist, actual_issame):\n", 380 | " predict_issame = np.less(dist, threshold)\n", 381 | " true_accept = np.sum(np.logical_and(predict_issame, actual_issame))\n", 382 | " false_accept = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))\n", 383 | " n_same = np.sum(actual_issame)\n", 384 | " n_diff = np.sum(np.logical_not(actual_issame))\n", 385 | " val = float(true_accept) / float(n_same)\n", 386 | " far = float(false_accept) / float(n_diff)\n", 387 | " return val, far\n", 388 | "\n", 389 | "\n", 390 | "\n", 391 | "def evaluate(embeddings, actual_issame, nrof_folds=10, distance_metric=0, subtract_mean=False):\n", 392 | " # Calculate evaluation metrics\n", 393 | " thresholds = np.arange(0, 4, 0.01)\n", 394 | " embeddings1 = embeddings[0::2]\n", 395 | " embeddings2 = embeddings[1::2]\n", 396 | " tpr, fpr, accuracy, fp, fn = calculate_roc(thresholds, embeddings1, embeddings2,\n", 397 | " np.asarray(actual_issame), nrof_folds=nrof_folds, distance_metric=distance_metric, subtract_mean=subtract_mean)\n", 398 | " thresholds = np.arange(0, 4, 0.001)\n", 399 | " val, val_std, far = calculate_val(thresholds, embeddings1, embeddings2,\n", 400 | " np.asarray(actual_issame), 1e-3, nrof_folds=nrof_folds, distance_metric=distance_metric, subtract_mean=subtract_mean)\n", 401 | " return tpr, fpr, accuracy, val, val_std, far, fp, fn\n", 402 | "\n", 403 | "def add_extension(path):\n", 404 | " if os.path.exists(path+'.jpg'):\n", 405 | " return path+'.jpg'\n", 406 | " elif os.path.exists(path+'.png'):\n", 407 | " return path+'.png'\n", 408 | " else:\n", 409 | " raise RuntimeError('No file \"%s\" with extension png or jpg.' 
% path)\n", 410 | "\n", 411 | "def get_paths(lfw_dir, pairs):\n", 412 | " nrof_skipped_pairs = 0\n", 413 | " path_list = []\n", 414 | " issame_list = []\n", 415 | " for pair in pairs:\n", 416 | " if len(pair) == 3:\n", 417 | " path0 = add_extension(os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1])))\n", 418 | " path1 = add_extension(os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[2])))\n", 419 | " issame = True\n", 420 | " elif len(pair) == 4:\n", 421 | " path0 = add_extension(os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1])))\n", 422 | " path1 = add_extension(os.path.join(lfw_dir, pair[2], pair[2] + '_' + '%04d' % int(pair[3])))\n", 423 | " issame = False\n", 424 | " if os.path.exists(path0) and os.path.exists(path1): # Only add the pair if both paths exist\n", 425 | " path_list += (path0,path1)\n", 426 | " issame_list.append(issame)\n", 427 | " else:\n", 428 | " nrof_skipped_pairs += 1\n", 429 | " if nrof_skipped_pairs>0:\n", 430 | " print('Skipped %d image pairs' % nrof_skipped_pairs)\n", 431 | "\n", 432 | " return path_list, issame_list\n", 433 | "\n", 434 | "def read_pairs(pairs_filename):\n", 435 | " pairs = []\n", 436 | " with open(pairs_filename, 'r') as f:\n", 437 | " for line in f.readlines()[1:]:\n", 438 | " pair = line.strip().split()\n", 439 | " pairs.append(pair)\n", 440 | " return np.array(pairs, dtype=object)" 441 | ], 442 | "metadata": { 443 | "collapsed": false, 444 | "pycharm": { 445 | "name": "#%%\n" 446 | } 447 | } 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 14, 452 | "outputs": [], 453 | "source": [ 454 | "pairs = read_pairs(pairs_path)\n", 455 | "path_list, issame_list = get_paths(data_dir+'_cropped', pairs)\n", 456 | "embeddings = np.array([embeddings_dict[path] for path in path_list])\n", 457 | "\n", 458 | "tpr, fpr, accuracy, val, val_std, far, fp, fn = evaluate(embeddings, issame_list)" 459 | ], 460 | "metadata": { 461 | "collapsed": false, 462 | "pycharm": { 
463 | "name": "#%%\n" 464 | } 465 | } 466 | }, 467 | { 468 | "cell_type": "code", 469 | "execution_count": 15, 470 | "outputs": [ 471 | { 472 | "name": "stdout", 473 | "output_type": "stream", 474 | "text": [ 475 | "[0.995 0.995 0.99166667 0.99 0.99 0.99666667\n", 476 | " 0.99 0.995 0.99666667 0.99666667]\n" 477 | ] 478 | }, 479 | { 480 | "data": { 481 | "text/plain": "0.9936666666666666" 482 | }, 483 | "execution_count": 15, 484 | "metadata": {}, 485 | "output_type": "execute_result" 486 | } 487 | ], 488 | "source": [ 489 | "print(accuracy)\n", 490 | "np.mean(accuracy)\n", 491 | "\n" 492 | ], 493 | "metadata": { 494 | "collapsed": false, 495 | "pycharm": { 496 | "name": "#%%\n" 497 | } 498 | } 499 | } 500 | ], 501 | "metadata": { 502 | "kernelspec": { 503 | "display_name": "Python 3", 504 | "language": "python", 505 | "name": "python3" 506 | }, 507 | "language_info": { 508 | "codemirror_mode": { 509 | "name": "ipython", 510 | "version": 2 511 | }, 512 | "file_extension": ".py", 513 | "mimetype": "text/x-python", 514 | "name": "python", 515 | "nbconvert_exporter": "python", 516 | "pygments_lexer": "ipython2", 517 | "version": "2.7.6" 518 | } 519 | }, 520 | "nbformat": 4, 521 | "nbformat_minor": 0 522 | } -------------------------------------------------------------------------------- /examples/lfw_evaluate_cn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### facenet-pytorch LFW评估\n", 8 | "\n", 9 | "本笔记本演示了如何针对LFW数据集评估性能。" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": { 16 | "ExecuteTime": { 17 | "end_time": "2023-07-20T11:38:51.769786Z", 18 | "start_time": "2023-07-20T11:38:50.053469Z" 19 | }, 20 | "pycharm": { 21 | "name": "#%%\n" 22 | } 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization, training, 
extract_face\n", 27 | "import torch\n", 28 | "from torch.utils.data import DataLoader, SubsetRandomSampler, SequentialSampler\n", 29 | "from torchvision import datasets, transforms\n", 30 | "import numpy as np\n", 31 | "import os" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "metadata": { 38 | "ExecuteTime": { 39 | "end_time": "2023-07-20T11:38:51.775776Z", 40 | "start_time": "2023-07-20T11:38:51.770780Z" 41 | }, 42 | "pycharm": { 43 | "name": "#%%\n" 44 | } 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "data_dir = 'data/lfw/lfw'\n", 49 | "pairs_path = 'data/lfw/pairs.txt'\n", 50 | "\n", 51 | "batch_size = 16\n", 52 | "epochs = 15\n", 53 | "workers = 0 if os.name == 'nt' else 8" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": { 60 | "ExecuteTime": { 61 | "end_time": "2023-07-20T11:38:51.847516Z", 62 | "start_time": "2023-07-20T11:38:51.776769Z" 63 | }, 64 | "pycharm": { 65 | "name": "#%%\n" 66 | } 67 | }, 68 | "outputs": [ 69 | { 70 | "name": "stdout", 71 | "output_type": "stream", 72 | "text": [ 73 | "在该设备上运行: cuda:0\n" 74 | ] 75 | } 76 | ], 77 | "source": [ 78 | "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", 79 | "print('在该设备上运行: {}'.format(device))" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 4, 85 | "metadata": { 86 | "pycharm": { 87 | "name": "#%%\n" 88 | } 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "mtcnn = MTCNN(\n", 93 | " image_size=160,\n", 94 | " margin=14,\n", 95 | " device=device,\n", 96 | " selection_method='center_weighted_size'\n", 97 | ")" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 5, 103 | "metadata": { 104 | "pycharm": { 105 | "name": "#%%\n" 106 | } 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "# 定义输入图像的数据加载器\n", 111 | "orig_img_ds = datasets.ImageFolder(data_dir, transform=None)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 6, 
117 | "metadata": { 118 | "pycharm": { 119 | "name": "#%%\n" 120 | } 121 | }, 122 | "outputs": [], 123 | "source": [ 124 | "\n", 125 | "# 覆盖数据集中的类标签以使用路径,以便在mtcnn批处理中保存输出\n", 126 | "orig_img_ds.samples = [\n", 127 | " (p, p)\n", 128 | " for p, _ in orig_img_ds.samples\n", 129 | "]\n", 130 | "\n", 131 | "loader = DataLoader(\n", 132 | " orig_img_ds,\n", 133 | " num_workers=workers,\n", 134 | " batch_size=batch_size,\n", 135 | " collate_fn=training.collate_pil\n", 136 | ")" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "crop_paths = []\n", 146 | "box_probs = []\n", 147 | "\n", 148 | "for i, (x, b_paths) in enumerate(loader):\n", 149 | " crops = [p.replace(data_dir, data_dir + '_cropped') for p in b_paths]\n", 150 | " mtcnn(x, save_path=crops)\n", 151 | " crop_paths.extend(crops)\n", 152 | " print('\\r第 {} 批,共 {} 批'.format(i + 1, len(loader)), end='')" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 8, 158 | "metadata": { 159 | "pycharm": { 160 | "name": "#%%\n" 161 | } 162 | }, 163 | "outputs": [], 164 | "source": [ 165 | "# 为减少GPU内存使用,删除mtcnn\n", 166 | "del mtcnn\n", 167 | "torch.cuda.empty_cache()" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 9, 173 | "metadata": { 174 | "pycharm": { 175 | "name": "#%%\n" 176 | } 177 | }, 178 | "outputs": [], 179 | "source": [ 180 | "# 从MTCNN裁剪的图像输出创建数据集和数据加载器\n", 181 | "\n", 182 | "trans = transforms.Compose([\n", 183 | " np.float32,\n", 184 | " transforms.ToTensor(),\n", 185 | " fixed_image_standardization\n", 186 | "])\n", 187 | "\n", 188 | "dataset = datasets.ImageFolder(data_dir + '_cropped', transform=trans)\n", 189 | "\n", 190 | "embed_loader = DataLoader(\n", 191 | " dataset,\n", 192 | " num_workers=workers,\n", 193 | " batch_size=batch_size,\n", 194 | " sampler=SequentialSampler(dataset)\n", 195 | ")" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | 
"execution_count": 10, 201 | "metadata": { 202 | "pycharm": { 203 | "name": "#%%\n" 204 | } 205 | }, 206 | "outputs": [], 207 | "source": [ 208 | "# 加载预训练的Resnet模型\n", 209 | "resnet = InceptionResnetV1(\n", 210 | " classify=False,\n", 211 | " pretrained='vggface2'\n", 212 | ").to(device)" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 11, 218 | "metadata": { 219 | "pycharm": { 220 | "name": "#%%\n" 221 | } 222 | }, 223 | "outputs": [], 224 | "source": [ 225 | "classes = []\n", 226 | "embeddings = []\n", 227 | "resnet.eval()\n", 228 | "with torch.no_grad():\n", 229 | " for xb, yb in embed_loader:\n", 230 | " xb = xb.to(device)\n", 231 | " b_embeddings = resnet(xb)\n", 232 | " b_embeddings = b_embeddings.to('cpu').numpy()\n", 233 | " classes.extend(yb.numpy())\n", 234 | " embeddings.extend(b_embeddings)" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": 12, 240 | "metadata": { 241 | "pycharm": { 242 | "name": "#%%\n" 243 | } 244 | }, 245 | "outputs": [], 246 | "source": [ 247 | "embeddings_dict = dict(zip(crop_paths,embeddings))\n", 248 | "\n" 249 | ] 250 | }, 251 | { 252 | "cell_type": "markdown", 253 | "metadata": {}, 254 | "source": [ 255 | "#### 使用距离度量评估嵌入,在官方LFW测试集上执行验证。\n", 256 | "\n", 257 | "下一个块中的函数是从`facenet.src.lfw`复制粘贴的。不幸的是,该模块包含对`facenet`的绝对导入,因此无法从子模块中导入。\n", 258 | "\n", 259 | "添加了返回假阳性和假阴性的功能。" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": 13, 265 | "metadata": { 266 | "pycharm": { 267 | "name": "#%%\n" 268 | } 269 | }, 270 | "outputs": [], 271 | "source": [ 272 | "from sklearn.model_selection import KFold\n", 273 | "from scipy import interpolate\n", 274 | "\n", 275 | "# 以下是从David Sandberg的FaceNet实现中提取的LFW函数\n", 276 | "def distance(embeddings1, embeddings2, distance_metric=0):\n", 277 | " if distance_metric==0:\n", 278 | " # Squared Euclidean distance\n", 279 | " diff = np.subtract(embeddings1, embeddings2)\n", 280 | " dist = 
np.sum(np.square(diff),1)\n", 281 | " elif 
distance_metric==1:\n", 282 | " # 基于余弦相似度的距离\n", 283 | " dot = np.sum(np.multiply(embeddings1, embeddings2), axis=1)\n", 284 | " norm = np.linalg.norm(embeddings1, axis=1) * np.linalg.norm(embeddings2, axis=1)\n", 285 | " similarity = dot / norm\n", 286 | " dist = np.arccos(similarity) / np.pi\n", 287 | " else:\n", 288 | " raise ValueError('Undefined distance metric %d' % distance_metric)\n", 289 | "\n", 290 | " return dist\n", 291 | "\n", 292 | "def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame, nrof_folds=10, distance_metric=0, subtract_mean=False):\n", 293 | " assert(embeddings1.shape[0] == embeddings2.shape[0])\n", 294 | " assert(embeddings1.shape[1] == embeddings2.shape[1])\n", 295 | " nrof_pairs = min(len(actual_issame), embeddings1.shape[0])\n", 296 | " nrof_thresholds = len(thresholds)\n", 297 | " k_fold = KFold(n_splits=nrof_folds, shuffle=False)\n", 298 | "\n", 299 | " tprs = np.zeros((nrof_folds,nrof_thresholds))\n", 300 | " fprs = np.zeros((nrof_folds,nrof_thresholds))\n", 301 | " accuracy = np.zeros((nrof_folds))\n", 302 | "\n", 303 | " is_false_positive = []\n", 304 | " is_false_negative = []\n", 305 | "\n", 306 | " indices = np.arange(nrof_pairs)\n", 307 | "\n", 308 | " for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):\n", 309 | " if subtract_mean:\n", 310 | " mean = np.mean(np.concatenate([embeddings1[train_set], embeddings2[train_set]]), axis=0)\n", 311 | " else:\n", 312 | " mean = 0.0\n", 313 | " dist = distance(embeddings1-mean, embeddings2-mean, distance_metric)\n", 314 | "\n", 315 | " # 寻找折叠的最佳阈值\n", 316 | " acc_train = np.zeros((nrof_thresholds))\n", 317 | " for threshold_idx, threshold in enumerate(thresholds):\n", 318 | " _, _, acc_train[threshold_idx], _ ,_ = calculate_accuracy(threshold, dist[train_set], actual_issame[train_set])\n", 319 | " best_threshold_index = np.argmax(acc_train)\n", 320 | " for threshold_idx, threshold in enumerate(thresholds):\n", 321 | " 
tprs[fold_idx,threshold_idx], 
fprs[fold_idx,threshold_idx], _, _, _ = calculate_accuracy(threshold, dist[test_set], actual_issame[test_set])\n", 322 | " _, _, accuracy[fold_idx], is_fp, is_fn = calculate_accuracy(thresholds[best_threshold_index], dist[test_set], actual_issame[test_set])\n", 323 | " is_false_positive.extend(is_fp)\n", 324 | " is_false_negative.extend(is_fn)\n", 325 | "\n", 326 | " tpr = np.mean(tprs,0)\n", 327 | " fpr = np.mean(fprs,0)\n", 328 | "\n", 329 | " return tpr, fpr, accuracy, is_false_positive, is_false_negative\n", 330 | "\n", 331 | "def calculate_accuracy(threshold, dist, actual_issame):\n", 332 | " predict_issame = np.less(dist, threshold)\n", 333 | " tp = np.sum(np.logical_and(predict_issame, actual_issame))\n", 334 | " fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))\n", 335 | " tn = np.sum(np.logical_and(np.logical_not(predict_issame), np.logical_not(actual_issame)))\n", 336 | " fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))\n", 337 | "\n", 338 | " is_fp = np.logical_and(predict_issame, np.logical_not(actual_issame))\n", 339 | " is_fn = np.logical_and(np.logical_not(predict_issame), actual_issame)\n", 340 | "\n", 341 | " tpr = 0 if (tp+fn==0) else float(tp) / float(tp+fn)\n", 342 | " fpr = 0 if (fp+tn==0) else float(fp) / float(fp+tn)\n", 343 | " acc = float(tp+tn)/dist.size\n", 344 | " return tpr, fpr, acc, is_fp, is_fn\n", 345 | "\n", 346 | "def calculate_val(thresholds, embeddings1, embeddings2, actual_issame, far_target, nrof_folds=10, distance_metric=0, subtract_mean=False):\n", 347 | " assert(embeddings1.shape[0] == embeddings2.shape[0])\n", 348 | " assert(embeddings1.shape[1] == embeddings2.shape[1])\n", 349 | " nrof_pairs = min(len(actual_issame), embeddings1.shape[0])\n", 350 | " nrof_thresholds = len(thresholds)\n", 351 | " k_fold = KFold(n_splits=nrof_folds, shuffle=False)\n", 352 | "\n", 353 | " val = np.zeros(nrof_folds)\n", 354 | " far = np.zeros(nrof_folds)\n", 355 | "\n", 356 | " indices = 
np.arange(nrof_pairs)\n", 357 | "\n", 358 | " for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):\n", 359 | " if subtract_mean:\n", 360 | " mean = np.mean(np.concatenate([embeddings1[train_set], embeddings2[train_set]]), axis=0)\n", 361 | " else:\n", 362 | " mean = 0.0\n", 363 | " dist = distance(embeddings1-mean, embeddings2-mean, distance_metric)\n", 364 | "\n", 365 | " # 找到使FAR = far_target的阈值\n", 366 | " far_train = np.zeros(nrof_thresholds)\n", 367 | " for threshold_idx, threshold in enumerate(thresholds):\n", 368 | " _, far_train[threshold_idx] = calculate_val_far(threshold, dist[train_set], actual_issame[train_set])\n", 369 | " if np.max(far_train)>=far_target:\n", 370 | " f = interpolate.interp1d(far_train, thresholds, kind='slinear')\n", 371 | " threshold = f(far_target)\n", 372 | " else:\n", 373 | " threshold = 0.0\n", 374 | "\n", 375 | " val[fold_idx], far[fold_idx] = calculate_val_far(threshold, dist[test_set], actual_issame[test_set])\n", 376 | "\n", 377 | " val_mean = np.mean(val)\n", 378 | " far_mean = np.mean(far)\n", 379 | " val_std = np.std(val)\n", 380 | " return val_mean, val_std, far_mean\n", 381 | "\n", 382 | "def calculate_val_far(threshold, dist, actual_issame):\n", 383 | " predict_issame = np.less(dist, threshold)\n", 384 | " true_accept = np.sum(np.logical_and(predict_issame, actual_issame))\n", 385 | " false_accept = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))\n", 386 | " n_same = np.sum(actual_issame)\n", 387 | " n_diff = np.sum(np.logical_not(actual_issame))\n", 388 | " val = float(true_accept) / float(n_same)\n", 389 | " far = float(false_accept) / float(n_diff)\n", 390 | " return val, far\n", 391 | "\n", 392 | "\n", 393 | "\n", 394 | "def evaluate(embeddings, actual_issame, nrof_folds=10, distance_metric=0, subtract_mean=False):\n", 395 | " # 计算评估指标\n", 396 | " thresholds = np.arange(0, 4, 0.01)\n", 397 | " embeddings1 = embeddings[0::2]\n", 398 | " embeddings2 = embeddings[1::2]\n", 399 
| " tpr, fpr, accuracy, fp, fn = calculate_roc(thresholds, embeddings1, embeddings2,\n", 400 | " np.asarray(actual_issame), nrof_folds=nrof_folds, distance_metric=distance_metric, subtract_mean=subtract_mean)\n", 401 | " thresholds = np.arange(0, 4, 0.001)\n", 402 | " val, val_std, far = calculate_val(thresholds, embeddings1, embeddings2,\n", 403 | " np.asarray(actual_issame), 1e-3, nrof_folds=nrof_folds, distance_metric=distance_metric, subtract_mean=subtract_mean)\n", 404 | " return tpr, fpr, accuracy, val, val_std, far, fp, fn\n", 405 | "\n", 406 | "def add_extension(path):\n", 407 | " if os.path.exists(path+'.jpg'):\n", 408 | " return path+'.jpg'\n", 409 | " elif os.path.exists(path+'.png'):\n", 410 | " return path+'.png'\n", 411 | " else:\n", 412 | " raise RuntimeError('No file \"%s\" with extension png or jpg.' % path)\n", 413 | "\n", 414 | "def get_paths(lfw_dir, pairs):\n", 415 | " nrof_skipped_pairs = 0\n", 416 | " path_list = []\n", 417 | " issame_list = []\n", 418 | " for pair in pairs:\n", 419 | " if len(pair) == 3:\n", 420 | " path0 = add_extension(os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1])))\n", 421 | " path1 = add_extension(os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[2])))\n", 422 | " issame = True\n", 423 | " elif len(pair) == 4:\n", 424 | " path0 = add_extension(os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1])))\n", 425 | " path1 = add_extension(os.path.join(lfw_dir, pair[2], pair[2] + '_' + '%04d' % int(pair[3])))\n", 426 | " issame = False\n", 427 | " if os.path.exists(path0) and os.path.exists(path1): # 仅在两个路径都存在时添加配对\n", 428 | " path_list += (path0,path1)\n", 429 | " issame_list.append(issame)\n", 430 | " else:\n", 431 | " nrof_skipped_pairs += 1\n", 432 | " if nrof_skipped_pairs>0:\n", 433 | " print('跳过 %d 个图像对' % nrof_skipped_pairs)\n", 434 | "\n", 435 | " return path_list, issame_list\n", 436 | "\n", 437 | "def read_pairs(pairs_filename):\n", 438 | " pairs = []\n", 439 | " 
with open(pairs_filename, 'r') as f:\n", 440 | " for line in f.readlines()[1:]:\n", 441 | " pair = line.strip().split()\n", 442 | " pairs.append(pair)\n", 443 | " return np.array(pairs, dtype=object)" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": 14, 449 | "metadata": { 450 | "pycharm": { 451 | "name": "#%%\n" 452 | } 453 | }, 454 | "outputs": [], 455 | "source": [ 456 | "pairs = read_pairs(pairs_path)\n", 457 | "path_list, issame_list = get_paths(data_dir+'_cropped', pairs)\n", 458 | "embeddings = np.array([embeddings_dict[path] for path in path_list])\n", 459 | "\n", 460 | "tpr, fpr, accuracy, val, val_std, far, fp, fn = evaluate(embeddings, issame_list)" 461 | ] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": 15, 466 | "metadata": { 467 | "pycharm": { 468 | "name": "#%%\n" 469 | } 470 | }, 471 | "outputs": [ 472 | { 473 | "name": "stdout", 474 | "output_type": "stream", 475 | "text": [ 476 | "[0.995 0.995 0.99166667 0.99 0.99 0.99666667\n", 477 | " 0.99 0.995 0.99666667 0.99666667]\n" 478 | ] 479 | }, 480 | { 481 | "data": { 482 | "text/plain": [ 483 | "0.9936666666666666" 484 | ] 485 | }, 486 | "execution_count": 15, 487 | "metadata": {}, 488 | "output_type": "execute_result" 489 | } 490 | ], 491 | "source": [ 492 | "print(accuracy)\n", 493 | "np.mean(accuracy)\n", 494 | "\n" 495 | ] 496 | } 497 | ], 498 | "metadata": { 499 | "kernelspec": { 500 | "display_name": "Python 3 (ipykernel)", 501 | "language": "python", 502 | "name": "python3" 503 | }, 504 | "language_info": { 505 | "codemirror_mode": { 506 | "name": "ipython", 507 | "version": 3 508 | }, 509 | "file_extension": ".py", 510 | "mimetype": "text/x-python", 511 | "name": "python", 512 | "nbconvert_exporter": "python", 513 | "pygments_lexer": "ipython3", 514 | "version": "3.11.3" 515 | } 516 | }, 517 | "nbformat": 4, 518 | "nbformat_minor": 1 519 | } 520 | -------------------------------------------------------------------------------- 
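The LFW notebooks above (`examples/lfw_evaluate.ipynb` and its Chinese counterpart) pick a distance threshold on nine folds and score it on the held-out tenth (`calculate_roc` together with `calculate_accuracy`). The sketch below reproduces that k-fold procedure in isolation, assuming only NumPy and scikit-learn. The helper names `squared_euclidean` and `kfold_threshold_accuracy` and the random toy embeddings are hypothetical illustrations, not part of facenet-pytorch, and the printed numbers say nothing about LFW performance. The false-positive and false-negative flags are collected inside the fold loop so that every pair receives one.

```python
import numpy as np
from sklearn.model_selection import KFold


def squared_euclidean(emb1, emb2):
    """Row-wise squared L2 distance (what the notebooks compute for distance_metric=0)."""
    diff = np.asarray(emb1, dtype=float) - np.asarray(emb2, dtype=float)
    return np.sum(np.square(diff), axis=1)


def kfold_threshold_accuracy(dist, issame, thresholds, n_folds=10):
    """Hypothetical helper: choose a threshold per training split, score it on the test split.

    A pair is predicted "same identity" when dist < threshold.
    Returns per-fold accuracy plus per-pair false-positive / false-negative flags.
    """
    dist = np.asarray(dist, dtype=float)
    issame = np.asarray(issame, dtype=bool)
    thresholds = np.asarray(thresholds, dtype=float)
    accuracy = np.zeros(n_folds)
    is_fp, is_fn = [], []

    folds = KFold(n_splits=n_folds, shuffle=False).split(dist)
    for fold_idx, (train_idx, test_idx) in enumerate(folds):
        # Training-split accuracy for every candidate threshold
        train_acc = [np.mean((dist[train_idx] < t) == issame[train_idx]) for t in thresholds]
        # np.argmax returns the first (smallest) threshold reaching the best accuracy
        best = thresholds[int(np.argmax(train_acc))]

        # Score the selected threshold on the held-out split only
        predict = dist[test_idx] < best
        accuracy[fold_idx] = np.mean(predict == issame[test_idx])

        # Extend inside the loop so every fold contributes its error flags
        is_fp.extend(predict & ~issame[test_idx])
        is_fn.extend(~predict & issame[test_idx])

    return accuracy, np.array(is_fp, dtype=bool), np.array(is_fn, dtype=bool)


# Toy data (hypothetical, not LFW): 600 pairs of 128-d unit-norm embeddings
rng = np.random.default_rng(0)
n_pairs, dim = 600, 128
issame = np.arange(n_pairs) % 2 == 0
emb1 = rng.normal(size=(n_pairs, dim))
noise = rng.normal(size=(n_pairs, dim))
# Same-identity pairs are small perturbations; different pairs are independent draws
emb2 = np.where(issame[:, None], emb1 + 0.3 * noise, noise)
emb1 /= np.linalg.norm(emb1, axis=1, keepdims=True)
emb2 /= np.linalg.norm(emb2, axis=1, keepdims=True)

dist = squared_euclidean(emb1, emb2)
acc, fp, fn = kfold_threshold_accuracy(dist, issame, np.arange(0, 4, 0.01))
print(acc.mean(), int(fp.sum()), int(fn.sum()))
```

Because `KFold(shuffle=False)` yields contiguous test folds in order, the concatenated flags line up with the pair order, so `np.flatnonzero(fn)` gives the indices of same-identity pairs the chosen thresholds rejected. Taking the first threshold that maximizes training accuracy mirrors the notebooks' use of `np.argmax`; other tie-breaking rules are equally defensible.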
/examples/performance-comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/examples/performance-comparison.png -------------------------------------------------------------------------------- /examples/tracked.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/examples/tracked.gif -------------------------------------------------------------------------------- /examples/video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/examples/video.mp4 -------------------------------------------------------------------------------- /examples/video_tracked.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/timesler/facenet-pytorch/787da06156087cd6b616fe6608213722bddc30cd/examples/video_tracked.mp4 -------------------------------------------------------------------------------- /models/inception_resnet_v1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | from requests.adapters import HTTPAdapter 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from .utils.download import download_url_to_file 10 | 11 | 12 | class BasicConv2d(nn.Module): 13 | 14 | def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): 15 | super().__init__() 16 | self.conv = nn.Conv2d( 17 | in_planes, out_planes, 18 | kernel_size=kernel_size, stride=stride, 19 | padding=padding, bias=False 20 | ) # verify bias false 21 | self.bn = nn.BatchNorm2d( 22 | out_planes, 23 | eps=0.001, # value found in 
tensorflow 24 | momentum=0.1, # default pytorch value 25 | affine=True 26 | ) 27 | self.relu = nn.ReLU(inplace=False) 28 | 29 | def forward(self, x): 30 | x = self.conv(x) 31 | x = self.bn(x) 32 | x = self.relu(x) 33 | return x 34 | 35 | 36 | class Block35(nn.Module): 37 | 38 | def __init__(self, scale=1.0): 39 | super().__init__() 40 | 41 | self.scale = scale 42 | 43 | self.branch0 = BasicConv2d(256, 32, kernel_size=1, stride=1) 44 | 45 | self.branch1 = nn.Sequential( 46 | BasicConv2d(256, 32, kernel_size=1, stride=1), 47 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) 48 | ) 49 | 50 | self.branch2 = nn.Sequential( 51 | BasicConv2d(256, 32, kernel_size=1, stride=1), 52 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1), 53 | BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1) 54 | ) 55 | 56 | self.conv2d = nn.Conv2d(96, 256, kernel_size=1, stride=1) 57 | self.relu = nn.ReLU(inplace=False) 58 | 59 | def forward(self, x): 60 | x0 = self.branch0(x) 61 | x1 = self.branch1(x) 62 | x2 = self.branch2(x) 63 | out = torch.cat((x0, x1, x2), 1) 64 | out = self.conv2d(out) 65 | out = out * self.scale + x 66 | out = self.relu(out) 67 | return out 68 | 69 | 70 | class Block17(nn.Module): 71 | 72 | def __init__(self, scale=1.0): 73 | super().__init__() 74 | 75 | self.scale = scale 76 | 77 | self.branch0 = BasicConv2d(896, 128, kernel_size=1, stride=1) 78 | 79 | self.branch1 = nn.Sequential( 80 | BasicConv2d(896, 128, kernel_size=1, stride=1), 81 | BasicConv2d(128, 128, kernel_size=(1,7), stride=1, padding=(0,3)), 82 | BasicConv2d(128, 128, kernel_size=(7,1), stride=1, padding=(3,0)) 83 | ) 84 | 85 | self.conv2d = nn.Conv2d(256, 896, kernel_size=1, stride=1) 86 | self.relu = nn.ReLU(inplace=False) 87 | 88 | def forward(self, x): 89 | x0 = self.branch0(x) 90 | x1 = self.branch1(x) 91 | out = torch.cat((x0, x1), 1) 92 | out = self.conv2d(out) 93 | out = out * self.scale + x 94 | out = self.relu(out) 95 | return out 96 | 97 | 98 | class Block8(nn.Module): 99 | 
100 | def __init__(self, scale=1.0, noReLU=False): 101 | super().__init__() 102 | 103 | self.scale = scale 104 | self.noReLU = noReLU 105 | 106 | self.branch0 = BasicConv2d(1792, 192, kernel_size=1, stride=1) 107 | 108 | self.branch1 = nn.Sequential( 109 | BasicConv2d(1792, 192, kernel_size=1, stride=1), 110 | BasicConv2d(192, 192, kernel_size=(1,3), stride=1, padding=(0,1)), 111 | BasicConv2d(192, 192, kernel_size=(3,1), stride=1, padding=(1,0)) 112 | ) 113 | 114 | self.conv2d = nn.Conv2d(384, 1792, kernel_size=1, stride=1) 115 | if not self.noReLU: 116 | self.relu = nn.ReLU(inplace=False) 117 | 118 | def forward(self, x): 119 | x0 = self.branch0(x) 120 | x1 = self.branch1(x) 121 | out = torch.cat((x0, x1), 1) 122 | out = self.conv2d(out) 123 | out = out * self.scale + x 124 | if not self.noReLU: 125 | out = self.relu(out) 126 | return out 127 | 128 | 129 | class Mixed_6a(nn.Module): 130 | 131 | def __init__(self): 132 | super().__init__() 133 | 134 | self.branch0 = BasicConv2d(256, 384, kernel_size=3, stride=2) 135 | 136 | self.branch1 = nn.Sequential( 137 | BasicConv2d(256, 192, kernel_size=1, stride=1), 138 | BasicConv2d(192, 192, kernel_size=3, stride=1, padding=1), 139 | BasicConv2d(192, 256, kernel_size=3, stride=2) 140 | ) 141 | 142 | self.branch2 = nn.MaxPool2d(3, stride=2) 143 | 144 | def forward(self, x): 145 | x0 = self.branch0(x) 146 | x1 = self.branch1(x) 147 | x2 = self.branch2(x) 148 | out = torch.cat((x0, x1, x2), 1) 149 | return out 150 | 151 | 152 | class Mixed_7a(nn.Module): 153 | 154 | def __init__(self): 155 | super().__init__() 156 | 157 | self.branch0 = nn.Sequential( 158 | BasicConv2d(896, 256, kernel_size=1, stride=1), 159 | BasicConv2d(256, 384, kernel_size=3, stride=2) 160 | ) 161 | 162 | self.branch1 = nn.Sequential( 163 | BasicConv2d(896, 256, kernel_size=1, stride=1), 164 | BasicConv2d(256, 256, kernel_size=3, stride=2) 165 | ) 166 | 167 | self.branch2 = nn.Sequential( 168 | BasicConv2d(896, 256, kernel_size=1, stride=1), 169 | 
BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1), 170 | BasicConv2d(256, 256, kernel_size=3, stride=2) 171 | ) 172 | 173 | self.branch3 = nn.MaxPool2d(3, stride=2) 174 | 175 | def forward(self, x): 176 | x0 = self.branch0(x) 177 | x1 = self.branch1(x) 178 | x2 = self.branch2(x) 179 | x3 = self.branch3(x) 180 | out = torch.cat((x0, x1, x2, x3), 1) 181 | return out 182 | 183 | 184 | class InceptionResnetV1(nn.Module): 185 | """Inception Resnet V1 model with optional loading of pretrained weights. 186 | 187 | Model parameters can be loaded based on pretraining on the VGGFace2 or CASIA-Webface 188 | datasets. Pretrained state_dicts are automatically downloaded on model instantiation if 189 | requested and cached in the torch cache. Subsequent instantiations use the cache rather than 190 | redownloading. 191 | 192 | Keyword Arguments: 193 | pretrained {str} -- Optional pretraining dataset. Either 'vggface2' or 'casia-webface'. 194 | (default: {None}) 195 | classify {bool} -- Whether the model should output classification probabilities or feature 196 | embeddings. (default: {False}) 197 | num_classes {int} -- Number of output classes. If 'pretrained' is set and num_classes not 198 | equal to that used for the pretrained model, the final linear layer will be randomly 199 | initialized. (default: {None}) 200 | dropout_prob {float} -- Dropout probability. 
(default: {0.6}) 201 | """ 202 | def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_prob=0.6, device=None): 203 | super().__init__() 204 | 205 | # Set simple attributes 206 | self.pretrained = pretrained 207 | self.classify = classify 208 | self.num_classes = num_classes 209 | 210 | if pretrained == 'vggface2': 211 | tmp_classes = 8631 212 | elif pretrained == 'casia-webface': 213 | tmp_classes = 10575 214 | elif pretrained is None and self.classify and self.num_classes is None: 215 | raise Exception('If "pretrained" is not specified and "classify" is True, "num_classes" must be specified') 216 | 217 | 218 | # Define layers 219 | self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2) 220 | self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1) 221 | self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1) 222 | self.maxpool_3a = nn.MaxPool2d(3, stride=2) 223 | self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1) 224 | self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1) 225 | self.conv2d_4b = BasicConv2d(192, 256, kernel_size=3, stride=2) 226 | self.repeat_1 = nn.Sequential( 227 | Block35(scale=0.17), 228 | Block35(scale=0.17), 229 | Block35(scale=0.17), 230 | Block35(scale=0.17), 231 | Block35(scale=0.17), 232 | ) 233 | self.mixed_6a = Mixed_6a() 234 | self.repeat_2 = nn.Sequential( 235 | Block17(scale=0.10), 236 | Block17(scale=0.10), 237 | Block17(scale=0.10), 238 | Block17(scale=0.10), 239 | Block17(scale=0.10), 240 | Block17(scale=0.10), 241 | Block17(scale=0.10), 242 | Block17(scale=0.10), 243 | Block17(scale=0.10), 244 | Block17(scale=0.10), 245 | ) 246 | self.mixed_7a = Mixed_7a() 247 | self.repeat_3 = nn.Sequential( 248 | Block8(scale=0.20), 249 | Block8(scale=0.20), 250 | Block8(scale=0.20), 251 | Block8(scale=0.20), 252 | Block8(scale=0.20), 253 | ) 254 | self.block8 = Block8(noReLU=True) 255 | self.avgpool_1a = nn.AdaptiveAvgPool2d(1) 256 | self.dropout = nn.Dropout(dropout_prob) 
257 | self.last_linear = nn.Linear(1792, 512, bias=False) 258 | self.last_bn = nn.BatchNorm1d(512, eps=0.001, momentum=0.1, affine=True) 259 | 260 | if pretrained is not None: 261 | self.logits = nn.Linear(512, tmp_classes) 262 | load_weights(self, pretrained) 263 | 264 | if self.classify and self.num_classes is not None: 265 | self.logits = nn.Linear(512, self.num_classes) 266 | 267 | self.device = torch.device('cpu') 268 | if device is not None: 269 | self.device = device 270 | self.to(device) 271 | 272 | def forward(self, x): 273 | """Calculate embeddings or logits given a batch of input image tensors. 274 | 275 | Arguments: 276 | x {torch.tensor} -- Batch of image tensors representing faces. 277 | 278 | Returns: 279 | torch.tensor -- Batch of embedding vectors or multinomial logits. 280 | """ 281 | x = self.conv2d_1a(x) 282 | x = self.conv2d_2a(x) 283 | x = self.conv2d_2b(x) 284 | x = self.maxpool_3a(x) 285 | x = self.conv2d_3b(x) 286 | x = self.conv2d_4a(x) 287 | x = self.conv2d_4b(x) 288 | x = self.repeat_1(x) 289 | x = self.mixed_6a(x) 290 | x = self.repeat_2(x) 291 | x = self.mixed_7a(x) 292 | x = self.repeat_3(x) 293 | x = self.block8(x) 294 | x = self.avgpool_1a(x) 295 | x = self.dropout(x) 296 | x = self.last_linear(x.view(x.shape[0], -1)) 297 | x = self.last_bn(x) 298 | if self.classify: 299 | x = self.logits(x) 300 | else: 301 | x = F.normalize(x, p=2, dim=1) 302 | return x 303 | 304 | 305 | def load_weights(mdl, name): 306 | """Download pretrained state_dict and load into model. 307 | 308 | Arguments: 309 | mdl {torch.nn.Module} -- Pytorch model. 310 | name {str} -- Name of dataset that was used to generate pretrained state_dict. 311 | 312 | Raises: 313 | ValueError: If 'pretrained' not equal to 'vggface2' or 'casia-webface'. 
314 | """ 315 | if name == 'vggface2': 316 | path = 'https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180402-114759-vggface2.pt' 317 | elif name == 'casia-webface': 318 | path = 'https://github.com/timesler/facenet-pytorch/releases/download/v2.2.9/20180408-102900-casia-webface.pt' 319 | else: 320 | raise ValueError('Pretrained models only exist for "vggface2" and "casia-webface"') 321 | 322 | model_dir = os.path.join(get_torch_home(), 'checkpoints') 323 | os.makedirs(model_dir, exist_ok=True) 324 | 325 | cached_file = os.path.join(model_dir, os.path.basename(path)) 326 | if not os.path.exists(cached_file): 327 | download_url_to_file(path, cached_file) 328 | 329 | state_dict = torch.load(cached_file) 330 | mdl.load_state_dict(state_dict) 331 | 332 | 333 | def get_torch_home(): 334 | torch_home = os.path.expanduser( 335 | os.getenv( 336 | 'TORCH_HOME', 337 | os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch') 338 | ) 339 | ) 340 | return torch_home 341 | -------------------------------------------------------------------------------- /models/mtcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import os 5 | 6 | from .utils.detect_face import detect_face, extract_face 7 | 8 | 9 | class PNet(nn.Module): 10 | """MTCNN PNet. 
11 | 12 | Keyword Arguments: 13 | pretrained {bool} -- Whether or not to load saved pretrained weights (default: {True}) 14 | """ 15 | 16 | def __init__(self, pretrained=True): 17 | super().__init__() 18 | 19 | self.conv1 = nn.Conv2d(3, 10, kernel_size=3) 20 | self.prelu1 = nn.PReLU(10) 21 | self.pool1 = nn.MaxPool2d(2, 2, ceil_mode=True) 22 | self.conv2 = nn.Conv2d(10, 16, kernel_size=3) 23 | self.prelu2 = nn.PReLU(16) 24 | self.conv3 = nn.Conv2d(16, 32, kernel_size=3) 25 | self.prelu3 = nn.PReLU(32) 26 | self.conv4_1 = nn.Conv2d(32, 2, kernel_size=1) 27 | self.softmax4_1 = nn.Softmax(dim=1) 28 | self.conv4_2 = nn.Conv2d(32, 4, kernel_size=1) 29 | 30 | self.training = False 31 | 32 | if pretrained: 33 | state_dict_path = os.path.join(os.path.dirname(__file__), '../data/pnet.pt') 34 | state_dict = torch.load(state_dict_path) 35 | self.load_state_dict(state_dict) 36 | 37 | def forward(self, x): 38 | x = self.conv1(x) 39 | x = self.prelu1(x) 40 | x = self.pool1(x) 41 | x = self.conv2(x) 42 | x = self.prelu2(x) 43 | x = self.conv3(x) 44 | x = self.prelu3(x) 45 | a = self.conv4_1(x) 46 | a = self.softmax4_1(a) 47 | b = self.conv4_2(x) 48 | return b, a 49 | 50 | 51 | class RNet(nn.Module): 52 | """MTCNN RNet. 
53 | 54 | Keyword Arguments: 55 | pretrained {bool} -- Whether or not to load saved pretrained weights (default: {True}) 56 | """ 57 | 58 | def __init__(self, pretrained=True): 59 | super().__init__() 60 | 61 | self.conv1 = nn.Conv2d(3, 28, kernel_size=3) 62 | self.prelu1 = nn.PReLU(28) 63 | self.pool1 = nn.MaxPool2d(3, 2, ceil_mode=True) 64 | self.conv2 = nn.Conv2d(28, 48, kernel_size=3) 65 | self.prelu2 = nn.PReLU(48) 66 | self.pool2 = nn.MaxPool2d(3, 2, ceil_mode=True) 67 | self.conv3 = nn.Conv2d(48, 64, kernel_size=2) 68 | self.prelu3 = nn.PReLU(64) 69 | self.dense4 = nn.Linear(576, 128) 70 | self.prelu4 = nn.PReLU(128) 71 | self.dense5_1 = nn.Linear(128, 2) 72 | self.softmax5_1 = nn.Softmax(dim=1) 73 | self.dense5_2 = nn.Linear(128, 4) 74 | 75 | self.training = False 76 | 77 | if pretrained: 78 | state_dict_path = os.path.join(os.path.dirname(__file__), '../data/rnet.pt') 79 | state_dict = torch.load(state_dict_path) 80 | self.load_state_dict(state_dict) 81 | 82 | def forward(self, x): 83 | x = self.conv1(x) 84 | x = self.prelu1(x) 85 | x = self.pool1(x) 86 | x = self.conv2(x) 87 | x = self.prelu2(x) 88 | x = self.pool2(x) 89 | x = self.conv3(x) 90 | x = self.prelu3(x) 91 | x = x.permute(0, 3, 2, 1).contiguous() 92 | x = self.dense4(x.view(x.shape[0], -1)) 93 | x = self.prelu4(x) 94 | a = self.dense5_1(x) 95 | a = self.softmax5_1(a) 96 | b = self.dense5_2(x) 97 | return b, a 98 | 99 | 100 | class ONet(nn.Module): 101 | """MTCNN ONet. 
102 | 103 | Keyword Arguments: 104 | pretrained {bool} -- Whether or not to load saved pretrained weights (default: {True}) 105 | """ 106 | 107 | def __init__(self, pretrained=True): 108 | super().__init__() 109 | 110 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3) 111 | self.prelu1 = nn.PReLU(32) 112 | self.pool1 = nn.MaxPool2d(3, 2, ceil_mode=True) 113 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3) 114 | self.prelu2 = nn.PReLU(64) 115 | self.pool2 = nn.MaxPool2d(3, 2, ceil_mode=True) 116 | self.conv3 = nn.Conv2d(64, 64, kernel_size=3) 117 | self.prelu3 = nn.PReLU(64) 118 | self.pool3 = nn.MaxPool2d(2, 2, ceil_mode=True) 119 | self.conv4 = nn.Conv2d(64, 128, kernel_size=2) 120 | self.prelu4 = nn.PReLU(128) 121 | self.dense5 = nn.Linear(1152, 256) 122 | self.prelu5 = nn.PReLU(256) 123 | self.dense6_1 = nn.Linear(256, 2) 124 | self.softmax6_1 = nn.Softmax(dim=1) 125 | self.dense6_2 = nn.Linear(256, 4) 126 | self.dense6_3 = nn.Linear(256, 10) 127 | 128 | self.training = False 129 | 130 | if pretrained: 131 | state_dict_path = os.path.join(os.path.dirname(__file__), '../data/onet.pt') 132 | state_dict = torch.load(state_dict_path) 133 | self.load_state_dict(state_dict) 134 | 135 | def forward(self, x): 136 | x = self.conv1(x) 137 | x = self.prelu1(x) 138 | x = self.pool1(x) 139 | x = self.conv2(x) 140 | x = self.prelu2(x) 141 | x = self.pool2(x) 142 | x = self.conv3(x) 143 | x = self.prelu3(x) 144 | x = self.pool3(x) 145 | x = self.conv4(x) 146 | x = self.prelu4(x) 147 | x = x.permute(0, 3, 2, 1).contiguous() 148 | x = self.dense5(x.view(x.shape[0], -1)) 149 | x = self.prelu5(x) 150 | a = self.dense6_1(x) 151 | a = self.softmax6_1(a) 152 | b = self.dense6_2(x) 153 | c = self.dense6_3(x) 154 | return b, c, a 155 | 156 | 157 | class MTCNN(nn.Module): 158 | """MTCNN face detection module. 
159 | 160 | This class loads pretrained P-, R-, and O-nets and returns images cropped to include the face 161 | only, given raw input images of one of the following types: 162 | - PIL image or list of PIL images 163 | - numpy.ndarray (uint8) representing either a single image (3D) or a batch of images (4D). 164 | Cropped faces can optionally be saved to file 165 | also. 166 | 167 | Keyword Arguments: 168 | image_size {int} -- Output image size in pixels. The image will be square. (default: {160}) 169 | margin {int} -- Margin to add to bounding box, in terms of pixels in the final image. 170 | Note that the application of the margin differs slightly from the davidsandberg/facenet 171 | repo, which applies the margin to the original image before resizing, making the margin 172 | dependent on the original image size (this is a bug in davidsandberg/facenet). 173 | (default: {0}) 174 | min_face_size {int} -- Minimum face size to search for. (default: {20}) 175 | thresholds {list} -- MTCNN face detection thresholds (default: {[0.6, 0.7, 0.7]}) 176 | factor {float} -- Factor used to create a scaling pyramid of face sizes. (default: {0.709}) 177 | post_process {bool} -- Whether or not to post process images tensors before returning. 178 | (default: {True}) 179 | select_largest {bool} -- If True, if multiple faces are detected, the largest is returned. 180 | If False, the face with the highest detection probability is returned. 181 | (default: {True}) 182 | selection_method {string} -- Which heuristic to use for selection. Default None. 
If 183 | specified, will override select_largest: 184 | "probability": highest probability selected 185 | "largest": largest box selected 186 | "largest_over_threshold": largest box over a certain probability selected 187 | "center_weighted_size": box size minus weighted squared offset from image center 188 | (default: {None}) 189 | keep_all {bool} -- If True, all detected faces are returned, in the order dictated by the 190 | select_largest parameter. If a save_path is specified, the first face is saved to that 191 | path and the remaining faces are saved with the suffixes _2, _3, etc. before the extension. 192 | (default: {False}) 193 | device {torch.device} -- The device on which to run neural net passes. Image tensors and 194 | models are copied to this device before running forward passes. (default: {None}) 195 | """ 196 | 197 | def __init__( 198 | self, image_size=160, margin=0, min_face_size=20, 199 | thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True, 200 | select_largest=True, selection_method=None, keep_all=False, device=None 201 | ): 202 | super().__init__() 203 | 204 | self.image_size = image_size 205 | self.margin = margin 206 | self.min_face_size = min_face_size 207 | self.thresholds = thresholds 208 | self.factor = factor 209 | self.post_process = post_process 210 | self.select_largest = select_largest 211 | self.keep_all = keep_all 212 | self.selection_method = selection_method 213 | 214 | self.pnet = PNet() 215 | self.rnet = RNet() 216 | self.onet = ONet() 217 | 218 | self.device = torch.device('cpu') 219 | if device is not None: 220 | self.device = device 221 | self.to(device) 222 | 223 | if not self.selection_method: 224 | self.selection_method = 'largest' if self.select_largest else 'probability' 225 | 226 | def forward(self, img, save_path=None, return_prob=False): 227 | """Run MTCNN face detection on a PIL image or numpy array.
This method performs both 228 | detection and extraction of faces, returning tensors representing detected faces rather 229 | than the bounding boxes. To access bounding boxes, see the MTCNN.detect() method below. 230 | 231 | Arguments: 232 | img {PIL.Image, np.ndarray, or list} -- A PIL image, np.ndarray, torch.Tensor, or list. 233 | 234 | Keyword Arguments: 235 | save_path {str} -- An optional save path for the cropped image. Note that when 236 | self.post_process=True, although the returned tensor is post processed, the saved 237 | face image is not, so it is a true representation of the face in the input image. 238 | If `img` is a list of images, `save_path` should be a list of equal length. 239 | (default: {None}) 240 | return_prob {bool} -- Whether or not to return the detection probability. 241 | (default: {False}) 242 | 243 | Returns: 244 | Union[torch.Tensor, tuple(torch.tensor, float)] -- If detected, cropped image of a face 245 | with dimensions 3 x image_size x image_size. Optionally, the probability that a 246 | face was detected. If self.keep_all is True, n detected faces are returned in an 247 | n x 3 x image_size x image_size tensor with an optional list of detection 248 | probabilities. If `img` is a list of images, the item(s) returned have an extra 249 | dimension (batch) as the first dimension. 
250 | 251 | Example: 252 | >>> from facenet_pytorch import MTCNN 253 | >>> mtcnn = MTCNN() 254 | >>> face_tensor, prob = mtcnn(img, save_path='face.png', return_prob=True) 255 | """ 256 | 257 | # Detect faces 258 | batch_boxes, batch_probs, batch_points = self.detect(img, landmarks=True) 259 | # Select faces 260 | if not self.keep_all: 261 | batch_boxes, batch_probs, batch_points = self.select_boxes( 262 | batch_boxes, batch_probs, batch_points, img, method=self.selection_method 263 | ) 264 | # Extract faces 265 | faces = self.extract(img, batch_boxes, save_path) 266 | 267 | if return_prob: 268 | return faces, batch_probs 269 | else: 270 | return faces 271 | 272 | def detect(self, img, landmarks=False): 273 | """Detect all faces in PIL image and return bounding boxes and optional facial landmarks. 274 | 275 | This method is used by the forward method and is also useful for face detection tasks 276 | that require lower-level handling of bounding boxes and facial landmarks (e.g., face 277 | tracking). The functionality of the forward function can be emulated by using this method 278 | followed by the extract_face() function. 279 | 280 | Arguments: 281 | img {PIL.Image, np.ndarray, or list} -- A PIL image, np.ndarray, torch.Tensor, or list. 282 | 283 | Keyword Arguments: 284 | landmarks {bool} -- Whether to return facial landmarks in addition to bounding boxes. 285 | (default: {False}) 286 | 287 | Returns: 288 | tuple(numpy.ndarray, list) -- For N detected faces, a tuple containing an 289 | Nx4 array of bounding boxes and a length N list of detection probabilities. 290 | Returned boxes will be sorted in descending order by detection probability if 291 | self.select_largest=False, otherwise the largest face will be returned first. 292 | If `img` is a list of images, the items returned have an extra dimension 293 | (batch) as the first dimension. Optionally, a third item, the facial landmarks, 294 | are returned if `landmarks=True`. 
295 | 296 | Example: 297 | >>> from PIL import Image, ImageDraw 298 | >>> from facenet_pytorch import MTCNN, extract_face 299 | >>> mtcnn = MTCNN(keep_all=True) 300 | >>> boxes, probs, points = mtcnn.detect(img, landmarks=True) 301 | >>> # Draw boxes and save faces 302 | >>> img_draw = img.copy() 303 | >>> draw = ImageDraw.Draw(img_draw) 304 | >>> for i, (box, point) in enumerate(zip(boxes, points)): 305 | ... draw.rectangle(box.tolist(), width=5) 306 | ... for p in point: 307 | ... draw.rectangle((p - 10).tolist() + (p + 10).tolist(), width=10) 308 | ... extract_face(img, box, save_path='detected_face_{}.png'.format(i)) 309 | >>> img_draw.save('annotated_faces.png') 310 | """ 311 | 312 | with torch.no_grad(): 313 | batch_boxes, batch_points = detect_face( 314 | img, self.min_face_size, 315 | self.pnet, self.rnet, self.onet, 316 | self.thresholds, self.factor, 317 | self.device 318 | ) 319 | 320 | boxes, probs, points = [], [], [] 321 | for box, point in zip(batch_boxes, batch_points): 322 | box = np.array(box) 323 | point = np.array(point) 324 | if len(box) == 0: 325 | boxes.append(None) 326 | probs.append([None]) 327 | points.append(None) 328 | elif self.select_largest: 329 | box_order = np.argsort((box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1]))[::-1] 330 | box = box[box_order] 331 | point = point[box_order] 332 | boxes.append(box[:, :4]) 333 | probs.append(box[:, 4]) 334 | points.append(point) 335 | else: 336 | boxes.append(box[:, :4]) 337 | probs.append(box[:, 4]) 338 | points.append(point) 339 | boxes = np.array(boxes, dtype=object) 340 | probs = np.array(probs, dtype=object) 341 | points = np.array(points, dtype=object) 342 | 343 | if ( 344 | not isinstance(img, (list, tuple)) and 345 | not (isinstance(img, np.ndarray) and len(img.shape) == 4) and 346 | not (isinstance(img, torch.Tensor) and len(img.shape) == 4) 347 | ): 348 | boxes = boxes[0] 349 | probs = probs[0] 350 | points = points[0] 351 | 352 | if landmarks: 353 | return boxes, probs, points 354 | 
355 | return boxes, probs 356 | 357 | def select_boxes( 358 | self, all_boxes, all_probs, all_points, imgs, method='probability', threshold=0.9, 359 | center_weight=2.0 360 | ): 361 | """Selects a single box from multiple for a given image using one of multiple heuristics. 362 | 363 | Arguments: 364 | all_boxes {np.ndarray} -- Ix0 ndarray where each element is an Nx4 ndarray of 365 | bounding boxes for N detected faces in I images (output from self.detect). 366 | all_probs {np.ndarray} -- Ix0 ndarray where each element is an Nx0 ndarray of 367 | probabilities for N detected faces in I images (output from self.detect). 368 | all_points {np.ndarray} -- Ix0 ndarray where each element is an Nx5x2 array of 369 | landmark points for N detected faces in I images (output from self.detect). 370 | imgs {PIL.Image, np.ndarray, or list} -- A PIL image, np.ndarray, torch.Tensor, or list. 371 | 372 | Keyword Arguments: 373 | method {str} -- Which heuristic to use for selection: 374 | "probability": highest probability selected 375 | "largest": largest box selected 376 | "largest_over_threshold": largest box over a certain probability selected 377 | "center_weighted_size": box size minus weighted squared offset from image center 378 | (default: {'probability'}) 379 | threshold {float} -- Threshold for the "largest_over_threshold" method. (default: {0.9}) 380 | center_weight {float} -- Weight for the squared offset in the "center_weighted_size" method. 381 | (default: {2.0}) 382 | 383 | Returns: 384 | tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray) -- nx4 ndarray of bounding boxes 385 | for n images, Ix0 array of probabilities for each box, and array of landmark points. 386 | """ 387 | 388 | # Batch detection copied from extract(); it would be cleaner if detect() returned consistently shaped output.
389 | batch_mode = True 390 | if ( 391 | not isinstance(imgs, (list, tuple)) and 392 | not (isinstance(imgs, np.ndarray) and len(imgs.shape) == 4) and 393 | not (isinstance(imgs, torch.Tensor) and len(imgs.shape) == 4) 394 | ): 395 | imgs = [imgs] 396 | all_boxes = [all_boxes] 397 | all_probs = [all_probs] 398 | all_points = [all_points] 399 | batch_mode = False 400 | 401 | selected_boxes, selected_probs, selected_points = [], [], [] 402 | for boxes, points, probs, img in zip(all_boxes, all_points, all_probs, imgs): 403 | 404 | if boxes is None: 405 | selected_boxes.append(None) 406 | selected_probs.append([None]) 407 | selected_points.append(None) 408 | continue 409 | 410 | # If at least 1 box found 411 | boxes = np.array(boxes) 412 | probs = np.array(probs) 413 | points = np.array(points) 414 | 415 | if method == 'largest': 416 | box_order = np.argsort((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]))[::-1] 417 | elif method == 'probability': 418 | box_order = np.argsort(probs)[::-1] 419 | elif method == 'center_weighted_size': 420 | box_sizes = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 421 | img_center = (img.width / 2, img.height / 2) if hasattr(img, 'width') else (img.shape[1] / 2, img.shape[0] / 2) # PIL image or HxWxC array/tensor 422 | box_centers = np.array(list(zip((boxes[:, 0] + boxes[:, 2]) / 2, (boxes[:, 1] + boxes[:, 3]) / 2))) 423 | offsets = box_centers - img_center 424 | offset_dist_squared = np.sum(np.power(offsets, 2.0), 1) 425 | box_order = np.argsort(box_sizes - offset_dist_squared * center_weight)[::-1] 426 | elif method == 'largest_over_threshold': 427 | box_mask = probs > threshold 428 | boxes, probs, points = boxes[box_mask], probs[box_mask], points[box_mask] # keep probs/points aligned with the filtered boxes 429 | box_order = np.argsort((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]))[::-1] 430 | if sum(box_mask) == 0: 431 | selected_boxes.append(None) 432 | selected_probs.append([None]) 433 | 
selected_points.append(None) 434 | continue 435 | 436 | box = boxes[box_order][[0]] 437 | prob = probs[box_order][[0]] 438 | point = points[box_order][[0]] 439 | selected_boxes.append(box) 440 | 
selected_probs.append(prob) 441 | selected_points.append(point) 442 | 443 | if batch_mode: 444 | selected_boxes = np.array(selected_boxes) 445 | selected_probs = np.array(selected_probs) 446 | selected_points = np.array(selected_points) 447 | else: 448 | selected_boxes = selected_boxes[0] 449 | selected_probs = selected_probs[0][0] 450 | selected_points = selected_points[0] 451 | 452 | return selected_boxes, selected_probs, selected_points 453 | 454 | def extract(self, img, batch_boxes, save_path): 455 | # Determine if a batch or single image was passed 456 | batch_mode = True 457 | if ( 458 | not isinstance(img, (list, tuple)) and 459 | not (isinstance(img, np.ndarray) and len(img.shape) == 4) and 460 | not (isinstance(img, torch.Tensor) and len(img.shape) == 4) 461 | ): 462 | img = [img] 463 | batch_boxes = [batch_boxes] 464 | batch_mode = False 465 | 466 | # Parse save path(s) 467 | if save_path is not None: 468 | if isinstance(save_path, str): 469 | save_path = [save_path] 470 | else: 471 | save_path = [None for _ in range(len(img))] 472 | 473 | # Process all bounding boxes 474 | faces = [] 475 | for im, box_im, path_im in zip(img, batch_boxes, save_path): 476 | if box_im is None: 477 | faces.append(None) 478 | continue 479 | 480 | if not self.keep_all: 481 | box_im = box_im[[0]] 482 | 483 | faces_im = [] 484 | for i, box in enumerate(box_im): 485 | face_path = path_im 486 | if path_im is not None and i > 0: 487 | save_name, ext = os.path.splitext(path_im) 488 | face_path = save_name + '_' + str(i + 1) + ext 489 | 490 | face = extract_face(im, box, self.image_size, self.margin, face_path) 491 | if self.post_process: 492 | face = fixed_image_standardization(face) 493 | faces_im.append(face) 494 | 495 | if self.keep_all: 496 | faces_im = torch.stack(faces_im) 497 | else: 498 | faces_im = faces_im[0] 499 | 500 | faces.append(faces_im) 501 | 502 | if not batch_mode: 503 | faces = faces[0] 504 | 505 | return faces 506 | 507 | 508 | def 
fixed_image_standardization(image_tensor): 509 | processed_tensor = (image_tensor - 127.5) / 128.0 510 | return processed_tensor 511 | 512 | 513 | def prewhiten(x): 514 | mean = x.mean() 515 | std = x.std() 516 | std_adj = std.clamp(min=1.0/(float(x.numel())**0.5)) 517 | y = (x - mean) / std_adj 518 | return y 519 | 520 | -------------------------------------------------------------------------------- /models/utils/detect_face.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.functional import interpolate 3 | from torchvision.transforms import functional as F 4 | from torchvision.ops.boxes import batched_nms 5 | from PIL import Image 6 | import numpy as np 7 | import os 8 | import math 9 | 10 | # OpenCV is optional, but required if using numpy arrays instead of PIL 11 | try: 12 | import cv2 13 | except ImportError: 14 | pass 15 | 16 | def fixed_batch_process(im_data, model): 17 | batch_size = 512 18 | out = [] 19 | for i in range(0, len(im_data), batch_size): 20 | batch = im_data[i:(i+batch_size)] 21 | out.append(model(batch)) 22 | 23 | return tuple(torch.cat(v, dim=0) for v in zip(*out)) 24 | 25 | def detect_face(imgs, minsize, pnet, rnet, onet, threshold, factor, device): 26 | if isinstance(imgs, (np.ndarray, torch.Tensor)): 27 | if isinstance(imgs,np.ndarray): 28 | imgs = torch.as_tensor(imgs.copy(), device=device) 29 | 30 | if isinstance(imgs,torch.Tensor): 31 | imgs = torch.as_tensor(imgs, device=device) 32 | 33 | if len(imgs.shape) == 3: 34 | imgs = imgs.unsqueeze(0) 35 | else: 36 | if not isinstance(imgs, (list, tuple)): 37 | imgs = [imgs] 38 | if any(img.size != imgs[0].size for img in imgs): 39 | raise Exception("MTCNN batch processing only compatible with equal-dimension images.") 40 | imgs = np.stack([np.uint8(img) for img in imgs]) 41 | imgs = torch.as_tensor(imgs.copy(), device=device) 42 | 43 | 44 | 45 | model_dtype = 
next(pnet.parameters()).dtype 46 | imgs = imgs.permute(0, 3, 1, 
2).type(model_dtype) 47 | 48 | batch_size = len(imgs) 49 | h, w = imgs.shape[2:4] 50 | m = 12.0 / minsize 51 | minl = min(h, w) 52 | minl = minl * m 53 | 54 | # Create scale pyramid 55 | scale_i = m 56 | scales = [] 57 | while minl >= 12: 58 | scales.append(scale_i) 59 | scale_i = scale_i * factor 60 | minl = minl * factor 61 | 62 | # First stage 63 | boxes = [] 64 | image_inds = [] 65 | 66 | scale_picks = [] 67 | 68 | all_i = 0 69 | offset = 0 70 | for scale in scales: 71 | im_data = imresample(imgs, (int(h * scale + 1), int(w * scale + 1))) 72 | im_data = (im_data - 127.5) * 0.0078125 73 | reg, probs = pnet(im_data) 74 | 75 | boxes_scale, image_inds_scale = generateBoundingBox(reg, probs[:, 1], scale, threshold[0]) 76 | boxes.append(boxes_scale) 77 | image_inds.append(image_inds_scale) 78 | 79 | pick = batched_nms(boxes_scale[:, :4], boxes_scale[:, 4], image_inds_scale, 0.5) 80 | scale_picks.append(pick + offset) 81 | offset += boxes_scale.shape[0] 82 | 83 | boxes = torch.cat(boxes, dim=0) 84 | image_inds = torch.cat(image_inds, dim=0) 85 | 86 | scale_picks = torch.cat(scale_picks, dim=0) 87 | 88 | # NMS within each scale + image 89 | boxes, image_inds = boxes[scale_picks], image_inds[scale_picks] 90 | 91 | 92 | # NMS within each image 93 | pick = batched_nms(boxes[:, :4], boxes[:, 4], image_inds, 0.7) 94 | boxes, image_inds = boxes[pick], image_inds[pick] 95 | 96 | regw = boxes[:, 2] - boxes[:, 0] 97 | regh = boxes[:, 3] - boxes[:, 1] 98 | qq1 = boxes[:, 0] + boxes[:, 5] * regw 99 | qq2 = boxes[:, 1] + boxes[:, 6] * regh 100 | qq3 = boxes[:, 2] + boxes[:, 7] * regw 101 | qq4 = boxes[:, 3] + boxes[:, 8] * regh 102 | boxes = torch.stack([qq1, qq2, qq3, qq4, boxes[:, 4]]).permute(1, 0) 103 | boxes = rerec(boxes) 104 | y, ey, x, ex = pad(boxes, w, h) 105 | 106 | # Second stage 107 | if len(boxes) > 0: 108 | im_data = [] 109 | for k in range(len(y)): 110 | if ey[k] > (y[k] - 1) and ex[k] > (x[k] - 1): 111 | img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 
1):ex[k]].unsqueeze(0) 112 | im_data.append(imresample(img_k, (24, 24))) 113 | im_data = torch.cat(im_data, dim=0) 114 | im_data = (im_data - 127.5) * 0.0078125 115 | 116 | # This is equivalent to out = rnet(im_data) to avoid GPU out of memory. 117 | out = fixed_batch_process(im_data, rnet) 118 | 119 | out0 = out[0].permute(1, 0) 120 | out1 = out[1].permute(1, 0) 121 | score = out1[1, :] 122 | ipass = score > threshold[1] 123 | boxes = torch.cat((boxes[ipass, :4], score[ipass].unsqueeze(1)), dim=1) 124 | image_inds = image_inds[ipass] 125 | mv = out0[:, ipass].permute(1, 0) 126 | 127 | # NMS within each image 128 | pick = batched_nms(boxes[:, :4], boxes[:, 4], image_inds, 0.7) 129 | boxes, image_inds, mv = boxes[pick], image_inds[pick], mv[pick] 130 | boxes = bbreg(boxes, mv) 131 | boxes = rerec(boxes) 132 | 133 | # Third stage 134 | points = torch.zeros(0, 5, 2, device=device) 135 | if len(boxes) > 0: 136 | y, ey, x, ex = pad(boxes, w, h) 137 | im_data = [] 138 | for k in range(len(y)): 139 | if ey[k] > (y[k] - 1) and ex[k] > (x[k] - 1): 140 | img_k = imgs[image_inds[k], :, (y[k] - 1):ey[k], (x[k] - 1):ex[k]].unsqueeze(0) 141 | im_data.append(imresample(img_k, (48, 48))) 142 | im_data = torch.cat(im_data, dim=0) 143 | im_data = (im_data - 127.5) * 0.0078125 144 | 145 | # This is equivalent to out = onet(im_data) to avoid GPU out of memory. 
146 | out = fixed_batch_process(im_data, onet) 147 | 148 | out0 = out[0].permute(1, 0) 149 | out1 = out[1].permute(1, 0) 150 | out2 = out[2].permute(1, 0) 151 | score = out2[1, :] 152 | points = out1 153 | ipass = score > threshold[2] 154 | points = points[:, ipass] 155 | boxes = torch.cat((boxes[ipass, :4], score[ipass].unsqueeze(1)), dim=1) 156 | image_inds = image_inds[ipass] 157 | mv = out0[:, ipass].permute(1, 0) 158 | 159 | w_i = boxes[:, 2] - boxes[:, 0] + 1 160 | h_i = boxes[:, 3] - boxes[:, 1] + 1 161 | points_x = w_i.repeat(5, 1) * points[:5, :] + boxes[:, 0].repeat(5, 1) - 1 162 | points_y = h_i.repeat(5, 1) * points[5:10, :] + boxes[:, 1].repeat(5, 1) - 1 163 | points = torch.stack((points_x, points_y)).permute(2, 1, 0) 164 | boxes = bbreg(boxes, mv) 165 | 166 | # NMS within each image using "Min" strategy 167 | # pick = batched_nms(boxes[:, :4], boxes[:, 4], image_inds, 0.7) 168 | pick = batched_nms_numpy(boxes[:, :4], boxes[:, 4], image_inds, 0.7, 'Min') 169 | boxes, image_inds, points = boxes[pick], image_inds[pick], points[pick] 170 | 171 | boxes = boxes.cpu().numpy() 172 | points = points.cpu().numpy() 173 | 174 | image_inds = image_inds.cpu() 175 | 176 | batch_boxes = [] 177 | batch_points = [] 178 | for b_i in range(batch_size): 179 | b_i_inds = np.where(image_inds == b_i) 180 | batch_boxes.append(boxes[b_i_inds].copy()) 181 | batch_points.append(points[b_i_inds].copy()) 182 | 183 | batch_boxes, batch_points = np.array(batch_boxes, dtype=object), np.array(batch_points, dtype=object) 184 | 185 | return batch_boxes, batch_points 186 | 187 | 188 | def bbreg(boundingbox, reg): 189 | if reg.shape[1] == 1: 190 | reg = torch.reshape(reg, (reg.shape[2], reg.shape[3])) 191 | 192 | w = boundingbox[:, 2] - boundingbox[:, 0] + 1 193 | h = boundingbox[:, 3] - boundingbox[:, 1] + 1 194 | b1 = boundingbox[:, 0] + reg[:, 0] * w 195 | b2 = boundingbox[:, 1] + reg[:, 1] * h 196 | b3 = boundingbox[:, 2] + reg[:, 2] * w 197 | b4 = boundingbox[:, 3] + reg[:, 3] * h 
198 | boundingbox[:, :4] = torch.stack([b1, b2, b3, b4]).permute(1, 0) 199 | 200 | return boundingbox 201 | 202 | 203 | def generateBoundingBox(reg, probs, scale, thresh): 204 | stride = 2 205 | cellsize = 12 206 | 207 | reg = reg.permute(1, 0, 2, 3) 208 | 209 | mask = probs >= thresh 210 | mask_inds = mask.nonzero() 211 | image_inds = mask_inds[:, 0] 212 | score = probs[mask] 213 | reg = reg[:, mask].permute(1, 0) 214 | bb = mask_inds[:, 1:].type(reg.dtype).flip(1) 215 | q1 = ((stride * bb + 1) / scale).floor() 216 | q2 = ((stride * bb + cellsize - 1 + 1) / scale).floor() 217 | boundingbox = torch.cat([q1, q2, score.unsqueeze(1), reg], dim=1) 218 | return boundingbox, image_inds 219 | 220 | 221 | def nms_numpy(boxes, scores, threshold, method): 222 | if boxes.size == 0: 223 | return np.empty((0, 3)) 224 | 225 | x1 = boxes[:, 0].copy() 226 | y1 = boxes[:, 1].copy() 227 | x2 = boxes[:, 2].copy() 228 | y2 = boxes[:, 3].copy() 229 | s = scores 230 | area = (x2 - x1 + 1) * (y2 - y1 + 1) 231 | 232 | I = np.argsort(s) 233 | pick = np.zeros_like(s, dtype=np.int16) 234 | counter = 0 235 | while I.size > 0: 236 | i = I[-1] 237 | pick[counter] = i 238 | counter += 1 239 | idx = I[0:-1] 240 | 241 | xx1 = np.maximum(x1[i], x1[idx]).copy() 242 | yy1 = np.maximum(y1[i], y1[idx]).copy() 243 | xx2 = np.minimum(x2[i], x2[idx]).copy() 244 | yy2 = np.minimum(y2[i], y2[idx]).copy() 245 | 246 | w = np.maximum(0.0, xx2 - xx1 + 1).copy() 247 | h = np.maximum(0.0, yy2 - yy1 + 1).copy() 248 | 249 | inter = w * h 250 | if method == 'Min': 251 | o = inter / np.minimum(area[i], area[idx]) 252 | else: 253 | o = inter / (area[i] + area[idx] - inter) 254 | I = I[np.where(o <= threshold)] 255 | 256 | pick = pick[:counter].copy() 257 | return pick 258 | 259 | 260 | def batched_nms_numpy(boxes, scores, idxs, threshold, method): 261 | device = boxes.device 262 | if boxes.numel() == 0: 263 | return torch.empty((0,), dtype=torch.int64, device=device) 264 | # strategy: in order to perform NMS 
independently per class. 265 | # we add an offset to all the boxes. The offset is dependent 266 | # only on the class idx, and is large enough so that boxes 267 | # from different classes do not overlap 268 | max_coordinate = boxes.max() 269 | offsets = idxs.to(boxes) * (max_coordinate + 1) 270 | boxes_for_nms = boxes + offsets[:, None] 271 | boxes_for_nms = boxes_for_nms.cpu().numpy() 272 | scores = scores.cpu().numpy() 273 | keep = nms_numpy(boxes_for_nms, scores, threshold, method) 274 | return torch.as_tensor(keep, dtype=torch.long, device=device) 275 | 276 | 277 | def pad(boxes, w, h): 278 | boxes = boxes.trunc().int().cpu().numpy() 279 | x = boxes[:, 0] 280 | y = boxes[:, 1] 281 | ex = boxes[:, 2] 282 | ey = boxes[:, 3] 283 | 284 | x[x < 1] = 1 285 | y[y < 1] = 1 286 | ex[ex > w] = w 287 | ey[ey > h] = h 288 | 289 | return y, ey, x, ex 290 | 291 | 292 | def rerec(bboxA): 293 | h = bboxA[:, 3] - bboxA[:, 1] 294 | w = bboxA[:, 2] - bboxA[:, 0] 295 | 296 | l = torch.max(w, h) 297 | bboxA[:, 0] = bboxA[:, 0] + w * 0.5 - l * 0.5 298 | bboxA[:, 1] = bboxA[:, 1] + h * 0.5 - l * 0.5 299 | bboxA[:, 2:4] = bboxA[:, :2] + l.repeat(2, 1).permute(1, 0) 300 | 301 | return bboxA 302 | 303 | 304 | def imresample(img, sz): 305 | im_data = interpolate(img, size=sz, mode="area") 306 | return im_data 307 | 308 | 309 | def crop_resize(img, box, image_size): 310 | if isinstance(img, np.ndarray): 311 | img = img[box[1]:box[3], box[0]:box[2]] 312 | out = cv2.resize( 313 | img, 314 | (image_size, image_size), 315 | interpolation=cv2.INTER_AREA 316 | ).copy() 317 | elif isinstance(img, torch.Tensor): 318 | img = img[box[1]:box[3], box[0]:box[2]] 319 | out = imresample( 320 | img.permute(2, 0, 1).unsqueeze(0).float(), 321 | (image_size, image_size) 322 | ).byte().squeeze(0).permute(1, 2, 0) 323 | else: 324 | out = img.crop(box).copy().resize((image_size, image_size), Image.BILINEAR) 325 | return out 326 | 327 | 328 | def save_img(img, path): 329 | if isinstance(img, np.ndarray): 330 | 
cv2.imwrite(path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 331 | else: 332 | (Image.fromarray(img.cpu().numpy()) if isinstance(img, torch.Tensor) else img).save(path) 333 | 334 | 335 | def get_size(img): 336 | if isinstance(img, (np.ndarray, torch.Tensor)): 337 | return img.shape[1::-1] 338 | else: 339 | return img.size 340 | 341 | 342 | def extract_face(img, box, image_size=160, margin=0, save_path=None): 343 | """Extract face + margin from an image given a bounding box. 344 | 345 | Arguments: 346 | img {PIL.Image, numpy.ndarray or torch.Tensor} -- An RGB image; arrays and tensors are expected in HWC layout. 347 | box {numpy.ndarray} -- Four-element bounding box (x1, y1, x2, y2) in pixels. 348 | image_size {int} -- Output image size in pixels. The image will be square. 349 | margin {int} -- Margin to add to bounding box, in terms of pixels in the final image. 350 | Note that the application of the margin differs slightly from the davidsandberg/facenet 351 | repo, which applies the margin to the original image before resizing, making the margin 352 | dependent on the original image size. 353 | save_path {str} -- Save path for extracted face image. (default: {None}) 354 | 355 | Returns: 356 | torch.Tensor -- Float tensor representing the extracted face.
357 | """ 358 | margin = [ 359 | margin * (box[2] - box[0]) / (image_size - margin), 360 | margin * (box[3] - box[1]) / (image_size - margin), 361 | ] 362 | raw_image_size = get_size(img) 363 | box = [ 364 | int(max(box[0] - margin[0] / 2, 0)), 365 | int(max(box[1] - margin[1] / 2, 0)), 366 | int(min(box[2] + margin[0] / 2, raw_image_size[0])), 367 | int(min(box[3] + margin[1] / 2, raw_image_size[1])), 368 | ] 369 | 370 | face = crop_resize(img, box, image_size) 371 | 372 | if save_path is not None: 373 | os.makedirs(os.path.dirname(save_path) + "/", exist_ok=True) 374 | save_img(face, save_path) 375 | 376 | if isinstance(face, np.ndarray) or isinstance(face, Image.Image): 377 | face = F.to_tensor(np.float32(face)) 378 | elif isinstance(face, torch.Tensor): 379 | face = face.float() 380 | else: 381 | raise NotImplementedError 382 | 383 | return face 384 | -------------------------------------------------------------------------------- /models/utils/download.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import shutil 4 | import sys 5 | import tempfile 6 | 7 | from urllib.request import urlopen, Request 8 | 9 | try: 10 | from tqdm.auto import tqdm # automatically select proper tqdm submodule if available 11 | except ImportError: 12 | from tqdm import tqdm 13 | 14 | 15 | def download_url_to_file(url, dst, hash_prefix=None, progress=True): 16 | r"""Download object at the given URL to a local path. 17 | Args: 18 | url (string): URL of the object to download 19 | dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file` 20 | hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with `hash_prefix`. 
21 | Default: None 22 | progress (bool, optional): whether or not to display a progress bar to stderr 23 | Default: True 24 | Example: 25 | >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file') 26 | """ 27 | file_size = None 28 | # We use a different API for python2 since urllib(2) doesn't recognize the CA 29 | # certificates in older Python 30 | req = Request(url, headers={"User-Agent": "torch.hub"}) 31 | u = urlopen(req) 32 | meta = u.info() 33 | if hasattr(meta, 'getheaders'): 34 | content_length = meta.getheaders("Content-Length") 35 | else: 36 | content_length = meta.get_all("Content-Length") 37 | if content_length is not None and len(content_length) > 0: 38 | file_size = int(content_length[0]) 39 | 40 | # We deliberately save it in a temp file and move it after 41 | # download is complete. This prevents a local working checkpoint 42 | # being overridden by a broken download. 43 | dst = os.path.expanduser(dst) 44 | dst_dir = os.path.dirname(dst) 45 | f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) 46 | 47 | try: 48 | if hash_prefix is not None: 49 | sha256 = hashlib.sha256() 50 | with tqdm(total=file_size, disable=not progress, 51 | unit='B', unit_scale=True, unit_divisor=1024) as pbar: 52 | while True: 53 | buffer = u.read(8192) 54 | if len(buffer) == 0: 55 | break 56 | f.write(buffer) 57 | if hash_prefix is not None: 58 | sha256.update(buffer) 59 | pbar.update(len(buffer)) 60 | 61 | f.close() 62 | if hash_prefix is not None: 63 | digest = sha256.hexdigest() 64 | if digest[:len(hash_prefix)] != hash_prefix: 65 | raise RuntimeError('invalid hash value (expected "{}", got "{}")' 66 | .format(hash_prefix, digest)) 67 | shutil.move(f.name, dst) 68 | finally: 69 | f.close() 70 | if os.path.exists(f.name): 71 | os.remove(f.name) 72 | -------------------------------------------------------------------------------- /models/utils/tensorflow2pytorch.py: 
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import torch 3 | import json 4 | import os, sys 5 | 6 | from dependencies.facenet.src import facenet 7 | from dependencies.facenet.src.models import inception_resnet_v1 as tf_mdl 8 | from dependencies.facenet.src.align import detect_face 9 | 10 | from models.inception_resnet_v1 import InceptionResnetV1 11 | from models.mtcnn import PNet, RNet, ONet 12 | 13 | 14 | def import_tf_params(tf_mdl_dir, sess): 15 | """Import tensorflow model from save directory. 16 | 17 | Arguments: 18 | tf_mdl_dir {str or callable} -- Location of protobuf, checkpoint, meta files, or a function that builds the model in `sess`. 19 | sess {tensorflow.Session} -- Tensorflow session object. 20 | 21 | Returns: 22 | (list, list, list) -- Tuple of lists containing the layer names, 23 | parameter arrays as numpy ndarrays, parameter shapes. 24 | """ 25 | print('\nLoading tensorflow model\n') 26 | if callable(tf_mdl_dir): 27 | tf_mdl_dir(sess) 28 | else: 29 | facenet.load_model(tf_mdl_dir) 30 | 31 | print('\nGetting model weights\n') 32 | tf_layers = tf.trainable_variables() 33 | tf_params = sess.run(tf_layers) 34 | 35 | tf_shapes = [p.shape for p in tf_params] 36 | tf_layers = [l.name for l in tf_layers] 37 | 38 | if not callable(tf_mdl_dir): 39 | path = os.path.join(tf_mdl_dir, 'layer_description.json') 40 | else: 41 | path = 'data/layer_description.json' 42 | with open(path, 'w') as f: 43 | json.dump({l: s for l, s in zip(tf_layers, tf_shapes)}, f) 44 | 45 | return tf_layers, tf_params, tf_shapes 46 | 47 | 48 | def get_layer_indices(layer_lookup, tf_layers): 49 | """Given a lookup of model layer attribute names to tensorflow variable names, 50 | find the matching parameters. 51 | 52 | Arguments: 53 | layer_lookup {dict} -- Dictionary mapping pytorch attribute names to (partial) 54 | tensorflow variable names. Expects dict of the form {'attr': ['tf_name', ...]} 55 | where the '...'s are ignored.
56 | tf_layers {list} -- List of tensorflow variable names. 57 | 58 | Returns: 59 | dict -- A copy of the input dictionary with the list of matching indices appended to each item. 60 | """ 61 | layer_inds = {} 62 | for name, value in layer_lookup.items(): 63 | layer_inds[name] = value + [[i for i, n in enumerate(tf_layers) if value[0] in n]] 64 | return layer_inds 65 | 66 | 67 | def load_tf_batchNorm(weights, layer): 68 | """Load tensorflow weights into nn.BatchNorm object. 69 | 70 | Arguments: 71 | weights {list} -- Tensorflow parameters. 72 | layer {torch.nn.Module} -- nn.BatchNorm. 73 | """ 74 | layer.bias.data = torch.tensor(weights[0]).view(layer.bias.data.shape) 75 | layer.weight.data = torch.ones_like(layer.weight.data) 76 | layer.running_mean = torch.tensor(weights[1]).view(layer.running_mean.shape) 77 | layer.running_var = torch.tensor(weights[2]).view(layer.running_var.shape) 78 | 79 | 80 | def load_tf_conv2d(weights, layer, transpose=False): 81 | """Load tensorflow weights into nn.Conv2d object. 82 | 83 | Arguments: 84 | weights {list} -- Tensorflow parameters. 85 | layer {torch.nn.Module} -- nn.Conv2d. 86 | """ 87 | if isinstance(weights, list): 88 | if len(weights) == 2: 89 | layer.bias.data = ( 90 | torch.tensor(weights[1]) 91 | .view(layer.bias.data.shape) 92 | ) 93 | weights = weights[0] 94 | 95 | if transpose: 96 | dim_order = (3, 2, 1, 0) 97 | else: 98 | dim_order = (3, 2, 0, 1) 99 | 100 | layer.weight.data = ( 101 | torch.tensor(weights) 102 | .permute(dim_order) 103 | .view(layer.weight.data.shape) 104 | ) 105 | 106 | 107 | def load_tf_conv2d_trans(weights, layer): 108 | return load_tf_conv2d(weights, layer, transpose=True) 109 | 110 | 111 | def load_tf_basicConv2d(weights, layer): 112 | """Load tensorflow weights into grouped Conv2d+BatchNorm object. 113 | 114 | Arguments: 115 | weights {list} -- Tensorflow parameters. 116 | layer {torch.nn.Module} -- Object containing Conv2d+BatchNorm.
117 | """ 118 | load_tf_conv2d(weights[0], layer.conv) 119 | load_tf_batchNorm(weights[1:], layer.bn) 120 | 121 | 122 | def load_tf_linear(weights, layer): 123 | """Load tensorflow weights into nn.Linear object. 124 | 125 | Arguments: 126 | weights {list} -- Tensorflow parameters. 127 | layer {torch.nn.Module} -- nn.Linear. 128 | """ 129 | if isinstance(weights, list): 130 | if len(weights) == 2: 131 | layer.bias.data = ( 132 | torch.tensor(weights[1]) 133 | .view(layer.bias.data.shape) 134 | ) 135 | weights = weights[0] 136 | layer.weight.data = ( 137 | torch.tensor(weights) 138 | .transpose(-1, 0) 139 | .view(layer.weight.data.shape) 140 | ) 141 | 142 | 143 | # High-level parameter-loading functions: 144 | 145 | def load_tf_block35(weights, layer): 146 | load_tf_basicConv2d(weights[:4], layer.branch0) 147 | load_tf_basicConv2d(weights[4:8], layer.branch1[0]) 148 | load_tf_basicConv2d(weights[8:12], layer.branch1[1]) 149 | load_tf_basicConv2d(weights[12:16], layer.branch2[0]) 150 | load_tf_basicConv2d(weights[16:20], layer.branch2[1]) 151 | load_tf_basicConv2d(weights[20:24], layer.branch2[2]) 152 | load_tf_conv2d(weights[24:26], layer.conv2d) 153 | 154 | 155 | def load_tf_block17_8(weights, layer): 156 | load_tf_basicConv2d(weights[:4], layer.branch0) 157 | load_tf_basicConv2d(weights[4:8], layer.branch1[0]) 158 | load_tf_basicConv2d(weights[8:12], layer.branch1[1]) 159 | load_tf_basicConv2d(weights[12:16], layer.branch1[2]) 160 | load_tf_conv2d(weights[16:18], layer.conv2d) 161 | 162 | 163 | def load_tf_mixed6a(weights, layer): 164 | if len(weights) != 16: 165 | raise ValueError(f'Number of weight arrays ({len(weights)}) not equal to 16') 166 | load_tf_basicConv2d(weights[:4], layer.branch0) 167 | load_tf_basicConv2d(weights[4:8], layer.branch1[0]) 168 | load_tf_basicConv2d(weights[8:12], layer.branch1[1]) 169 | load_tf_basicConv2d(weights[12:16], layer.branch1[2]) 170 | 171 | 172 | def load_tf_mixed7a(weights, layer): 173 | if len(weights) != 28: 174 | raise 
ValueError(f'Number of weight arrays ({len(weights)}) not equal to 28') 175 | load_tf_basicConv2d(weights[:4], layer.branch0[0]) 176 | load_tf_basicConv2d(weights[4:8], layer.branch0[1]) 177 | load_tf_basicConv2d(weights[8:12], layer.branch1[0]) 178 | load_tf_basicConv2d(weights[12:16], layer.branch1[1]) 179 | load_tf_basicConv2d(weights[16:20], layer.branch2[0]) 180 | load_tf_basicConv2d(weights[20:24], layer.branch2[1]) 181 | load_tf_basicConv2d(weights[24:28], layer.branch2[2]) 182 | 183 | 184 | def load_tf_repeats(weights, layer, rptlen, subfun): 185 | if len(weights) % rptlen != 0: 186 | raise ValueError(f'Number of weight arrays ({len(weights)}) not divisible by {rptlen}') 187 | weights_split = [weights[i:i+rptlen] for i in range(0, len(weights), rptlen)] 188 | for i, w in enumerate(weights_split): 189 | subfun(w, getattr(layer, str(i))) 190 | 191 | 192 | def load_tf_repeat_1(weights, layer): 193 | load_tf_repeats(weights, layer, 26, load_tf_block35) 194 | 195 | 196 | def load_tf_repeat_2(weights, layer): 197 | load_tf_repeats(weights, layer, 18, load_tf_block17_8) 198 | 199 | 200 | def load_tf_repeat_3(weights, layer): 201 | load_tf_repeats(weights, layer, 18, load_tf_block17_8) 202 | 203 | 204 | def test_loaded_params(mdl, tf_params, tf_layers): 205 | """Check each parameter in a pytorch model for an equivalent parameter 206 | in a list of tensorflow variables. 207 | 208 | Arguments: 209 | mdl {torch.nn.Module} -- Pytorch model. 210 | tf_params {list} -- List of ndarrays representing tensorflow variables. 211 | tf_layers {list} -- Corresponding list of tensorflow variable names. 
212 | """ 213 | tf_means = torch.stack([torch.tensor(p).mean() for p in tf_params]) 214 | for name, param in mdl.named_parameters(): 215 | pt_mean = param.data.mean() 216 | matching_inds = ((tf_means - pt_mean).abs() < 1e-8).nonzero() 217 | print(f'{name} equivalent to {[tf_layers[i] for i in matching_inds]}') 218 | 219 | 220 | def compare_model_outputs(pt_mdl, sess, test_data): 221 | """Given some testing data, compare the output of pytorch and tensorflow models. 222 | 223 | Arguments: 224 | pt_mdl {torch.nn.Module} -- Pytorch model. 225 | sess {tensorflow.Session} -- Tensorflow session object. 226 | test_data {torch.Tensor} -- Pytorch tensor. 227 | """ 228 | print('\nPassing test data through TF model\n') 229 | if isinstance(sess, tf.Session): 230 | images_placeholder = tf.get_default_graph().get_tensor_by_name("input:0") 231 | phase_train_placeholder = tf.get_default_graph().get_tensor_by_name("phase_train:0") 232 | embeddings = tf.get_default_graph().get_tensor_by_name("embeddings:0") 233 | feed_dict = {images_placeholder: test_data.numpy(), phase_train_placeholder: False} 234 | tf_output = torch.tensor(sess.run(embeddings, feed_dict=feed_dict)) 235 | else: 236 | tf_output = sess(test_data) 237 | 238 | print(tf_output) 239 | 240 | print('\nPassing test data through PT model\n') 241 | pt_output = pt_mdl(test_data.permute(0, 3, 1, 2)) 242 | print(pt_output) 243 | 244 | distance = (tf_output - pt_output).norm() 245 | print(f'\nDistance {distance}\n') 246 | 247 | 248 | def compare_mtcnn(pt_mdl, tf_fun, sess, ind, test_data): 249 | tf_mdls = tf_fun(sess) 250 | tf_mdl = tf_mdls[ind] 251 | 252 | print('\nPassing test data through TF model\n') 253 | tf_output = tf_mdl(test_data.numpy()) 254 | tf_output = [torch.tensor(out) for out in tf_output] 255 | print('\n'.join([str(o.view(-1)[:10]) for o in tf_output])) 256 | 257 | print('\nPassing test data through PT model\n') 258 | with torch.no_grad(): 259 | pt_output = pt_mdl(test_data.permute(0, 3, 2, 1)) 260 | pt_output = 
[torch.tensor(out) for out in pt_output] 261 | for i in range(len(pt_output)): 262 | if len(pt_output[i].shape) == 4: 263 | pt_output[i] = pt_output[i].permute(0, 3, 2, 1).contiguous() 264 | print('\n'.join([str(o.view(-1)[:10]) for o in pt_output])) 265 | 266 | distance = [(tf_o - pt_o).norm() for tf_o, pt_o in zip(tf_output, pt_output)] 267 | print(f'\nDistance {distance}\n') 268 | 269 | 270 | def load_tf_model_weights(mdl, layer_lookup, tf_mdl_dir, is_resnet=True, arg_num=None): 271 | """Load tensorflow parameters into a pytorch model. 272 | 273 | Arguments: 274 | mdl {torch.nn.Module} -- Pytorch model. 275 | layer_lookup {[type]} -- Dictionary mapping pytorch attribute names to (partial) 276 | tensorflow variable names, and a function suitable for loading weights. 277 | Expects dict of the form {'attr': ['tf_name', function]}. 278 | tf_mdl_dir {str} -- Location of protobuf, checkpoint, meta files. 279 | """ 280 | tf.reset_default_graph() 281 | with tf.Session() as sess: 282 | tf_layers, tf_params, tf_shapes = import_tf_params(tf_mdl_dir, sess) 283 | layer_info = get_layer_indices(layer_lookup, tf_layers) 284 | 285 | for layer_name, info in layer_info.items(): 286 | print(f'Loading {info[0]}/* into {layer_name}') 287 | weights = [tf_params[i] for i in info[2]] 288 | layer = getattr(mdl, layer_name) 289 | info[1](weights, layer) 290 | 291 | test_loaded_params(mdl, tf_params, tf_layers) 292 | 293 | if is_resnet: 294 | compare_model_outputs(mdl, sess, torch.randn(5, 160, 160, 3).detach()) 295 | 296 | 297 | def tensorflow2pytorch(): 298 | lookup_inception_resnet_v1 = { 299 | 'conv2d_1a': ['InceptionResnetV1/Conv2d_1a_3x3', load_tf_basicConv2d], 300 | 'conv2d_2a': ['InceptionResnetV1/Conv2d_2a_3x3', load_tf_basicConv2d], 301 | 'conv2d_2b': ['InceptionResnetV1/Conv2d_2b_3x3', load_tf_basicConv2d], 302 | 'conv2d_3b': ['InceptionResnetV1/Conv2d_3b_1x1', load_tf_basicConv2d], 303 | 'conv2d_4a': ['InceptionResnetV1/Conv2d_4a_3x3', load_tf_basicConv2d], 304 | 'conv2d_4b': 
['InceptionResnetV1/Conv2d_4b_3x3', load_tf_basicConv2d], 305 | 'repeat_1': ['InceptionResnetV1/Repeat/block35', load_tf_repeat_1], 306 | 'mixed_6a': ['InceptionResnetV1/Mixed_6a', load_tf_mixed6a], 307 | 'repeat_2': ['InceptionResnetV1/Repeat_1/block17', load_tf_repeat_2], 308 | 'mixed_7a': ['InceptionResnetV1/Mixed_7a', load_tf_mixed7a], 309 | 'repeat_3': ['InceptionResnetV1/Repeat_2/block8', load_tf_repeat_3], 310 | 'block8': ['InceptionResnetV1/Block8', load_tf_block17_8], 311 | 'last_linear': ['InceptionResnetV1/Bottleneck/weights', load_tf_linear], 312 | 'last_bn': ['InceptionResnetV1/Bottleneck/BatchNorm', load_tf_batchNorm], 313 | 'logits': ['Logits', load_tf_linear], 314 | } 315 | 316 | print('\nLoad VGGFace2-trained weights and save\n') 317 | mdl = InceptionResnetV1(num_classes=8631).eval() 318 | tf_mdl_dir = 'data/20180402-114759' 319 | data_name = 'vggface2' 320 | load_tf_model_weights(mdl, lookup_inception_resnet_v1, tf_mdl_dir) 321 | state_dict = mdl.state_dict() 322 | torch.save(state_dict, f'{tf_mdl_dir}-{data_name}.pt') 323 | torch.save( 324 | { 325 | 'logits.weight': state_dict['logits.weight'], 326 | 'logits.bias': state_dict['logits.bias'], 327 | }, 328 | f'{tf_mdl_dir}-{data_name}-logits.pt' 329 | ) 330 | state_dict.pop('logits.weight') 331 | state_dict.pop('logits.bias') 332 | torch.save(state_dict, f'{tf_mdl_dir}-{data_name}-features.pt') 333 | 334 | print('\nLoad CASIA-Webface-trained weights and save\n') 335 | mdl = InceptionResnetV1(num_classes=10575).eval() 336 | tf_mdl_dir = 'data/20180408-102900' 337 | data_name = 'casia-webface' 338 | load_tf_model_weights(mdl, lookup_inception_resnet_v1, tf_mdl_dir) 339 | state_dict = mdl.state_dict() 340 | torch.save(state_dict, f'{tf_mdl_dir}-{data_name}.pt') 341 | torch.save( 342 | { 343 | 'logits.weight': state_dict['logits.weight'], 344 | 'logits.bias': state_dict['logits.bias'], 345 | }, 346 | f'{tf_mdl_dir}-{data_name}-logits.pt' 347 | ) 348 | state_dict.pop('logits.weight') 349 | 
state_dict.pop('logits.bias') 350 | torch.save(state_dict, f'{tf_mdl_dir}-{data_name}-features.pt') 351 | 352 | lookup_pnet = { 353 | 'conv1': ['pnet/conv1', load_tf_conv2d_trans], 354 | 'prelu1': ['pnet/PReLU1', load_tf_linear], 355 | 'conv2': ['pnet/conv2', load_tf_conv2d_trans], 356 | 'prelu2': ['pnet/PReLU2', load_tf_linear], 357 | 'conv3': ['pnet/conv3', load_tf_conv2d_trans], 358 | 'prelu3': ['pnet/PReLU3', load_tf_linear], 359 | 'conv4_1': ['pnet/conv4-1', load_tf_conv2d_trans], 360 | 'conv4_2': ['pnet/conv4-2', load_tf_conv2d_trans], 361 | } 362 | lookup_rnet = { 363 | 'conv1': ['rnet/conv1', load_tf_conv2d_trans], 364 | 'prelu1': ['rnet/prelu1', load_tf_linear], 365 | 'conv2': ['rnet/conv2', load_tf_conv2d_trans], 366 | 'prelu2': ['rnet/prelu2', load_tf_linear], 367 | 'conv3': ['rnet/conv3', load_tf_conv2d_trans], 368 | 'prelu3': ['rnet/prelu3', load_tf_linear], 369 | 'dense4': ['rnet/conv4', load_tf_linear], 370 | 'prelu4': ['rnet/prelu4', load_tf_linear], 371 | 'dense5_1': ['rnet/conv5-1', load_tf_linear], 372 | 'dense5_2': ['rnet/conv5-2', load_tf_linear], 373 | } 374 | lookup_onet = { 375 | 'conv1': ['onet/conv1', load_tf_conv2d_trans], 376 | 'prelu1': ['onet/prelu1', load_tf_linear], 377 | 'conv2': ['onet/conv2', load_tf_conv2d_trans], 378 | 'prelu2': ['onet/prelu2', load_tf_linear], 379 | 'conv3': ['onet/conv3', load_tf_conv2d_trans], 380 | 'prelu3': ['onet/prelu3', load_tf_linear], 381 | 'conv4': ['onet/conv4', load_tf_conv2d_trans], 382 | 'prelu4': ['onet/prelu4', load_tf_linear], 383 | 'dense5': ['onet/conv5', load_tf_linear], 384 | 'prelu5': ['onet/prelu5', load_tf_linear], 385 | 'dense6_1': ['onet/conv6-1', load_tf_linear], 386 | 'dense6_2': ['onet/conv6-2', load_tf_linear], 387 | 'dense6_3': ['onet/conv6-3', load_tf_linear], 388 | } 389 | 390 | print('\nLoad PNet weights and save\n') 391 | tf_mdl_dir = lambda sess: detect_face.create_mtcnn(sess, None) 392 | mdl = PNet() 393 | data_name = 'pnet' 394 | load_tf_model_weights(mdl, lookup_pnet, 
tf_mdl_dir, is_resnet=False, arg_num=0) 395 | torch.save(mdl.state_dict(), f'data/{data_name}.pt') 396 | tf.reset_default_graph() 397 | with tf.Session() as sess: 398 | compare_mtcnn(mdl, tf_mdl_dir, sess, 0, torch.randn(1, 256, 256, 3).detach()) 399 | 400 | print('\nLoad RNet weights and save\n') 401 | mdl = RNet() 402 | data_name = 'rnet' 403 | load_tf_model_weights(mdl, lookup_rnet, tf_mdl_dir, is_resnet=False, arg_num=1) 404 | torch.save(mdl.state_dict(), f'data/{data_name}.pt') 405 | tf.reset_default_graph() 406 | with tf.Session() as sess: 407 | compare_mtcnn(mdl, tf_mdl_dir, sess, 1, torch.randn(1, 24, 24, 3).detach()) 408 | 409 | print('\nLoad ONet weights and save\n') 410 | mdl = ONet() 411 | data_name = 'onet' 412 | load_tf_model_weights(mdl, lookup_onet, tf_mdl_dir, is_resnet=False, arg_num=2) 413 | torch.save(mdl.state_dict(), f'data/{data_name}.pt') 414 | tf.reset_default_graph() 415 | with tf.Session() as sess: 416 | compare_mtcnn(mdl, tf_mdl_dir, sess, 2, torch.randn(1, 48, 48, 3).detach()) 417 | -------------------------------------------------------------------------------- /models/utils/training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import time 4 | 5 | 6 | class Logger(object): 7 | 8 | def __init__(self, mode, length, calculate_mean=False): 9 | self.mode = mode 10 | self.length = length 11 | self.calculate_mean = calculate_mean 12 | if self.calculate_mean: 13 | self.fn = lambda x, i: x / (i + 1) 14 | else: 15 | self.fn = lambda x, i: x 16 | 17 | def __call__(self, loss, metrics, i): 18 | track_str = '\r{} | {:5d}/{:<5d}| '.format(self.mode, i + 1, self.length) 19 | loss_str = 'loss: {:9.4f} | '.format(self.fn(loss, i)) 20 | metric_str = ' | '.join('{}: {:9.4f}'.format(k, self.fn(v, i)) for k, v in metrics.items()) 21 | print(track_str + loss_str + metric_str + ' ', end='') 22 | if i + 1 == self.length: 23 | print('') 24 | 25 | 26 | class BatchTimer(object): 
27 | """Batch timing class. 28 | Use this class for tracking training and testing time/rate per batch or per sample. 29 | 30 | Keyword Arguments: 31 | rate {bool} -- Whether to report a rate (batches or samples per second) or a time (seconds 32 | per batch or sample). (default: {True}) 33 | per_sample {bool} -- Whether to report times or rates per sample or per batch. 34 | (default: {True}) 35 | """ 36 | 37 | def __init__(self, rate=True, per_sample=True): 38 | self.start = time.time() 39 | self.end = None 40 | self.rate = rate 41 | self.per_sample = per_sample 42 | 43 | def __call__(self, y_pred, y): 44 | self.end = time.time() 45 | elapsed = self.end - self.start 46 | self.start = self.end 47 | self.end = None 48 | 49 | if self.per_sample: 50 | elapsed /= len(y_pred) 51 | if self.rate: 52 | elapsed = 1 / elapsed 53 | 54 | return torch.tensor(elapsed) 55 | 56 | 57 | def accuracy(logits, y): 58 | _, preds = torch.max(logits, 1) 59 | return (preds == y).float().mean() 60 | 61 | 62 | def pass_epoch( 63 | model, loss_fn, loader, optimizer=None, scheduler=None, 64 | batch_metrics={'time': BatchTimer()}, show_running=True, 65 | device='cpu', writer=None 66 | ): 67 | """Train or evaluate over a data epoch. 68 | 69 | Arguments: 70 | model {torch.nn.Module} -- Pytorch model. 71 | loss_fn {callable} -- A function to compute (scalar) loss. 72 | loader {torch.utils.data.DataLoader} -- A pytorch data loader. 73 | 74 | Keyword Arguments: 75 | optimizer {torch.optim.Optimizer} -- A pytorch optimizer; required when the model is in training mode. (default: {None}) 76 | scheduler {torch.optim.lr_scheduler._LRScheduler} -- LR scheduler (default: {None}) 77 | batch_metrics {dict} -- Dictionary of metric functions to call on each batch. The default 78 | is a simple timer. A progressive average of these metrics, along with the average 79 | loss, is printed every batch. (default: {{'time': BatchTimer()}}) 80 | show_running {bool} -- Whether or not to print losses and metrics for the current batch 81 | or rolling averages.
(default: {False}) 82 | device {str or torch.device} -- Device for pytorch to use. (default: {'cpu'}) 83 | writer {torch.utils.tensorboard.SummaryWriter} -- Tensorboard SummaryWriter. (default: {None}) 84 | 85 | Returns: 86 | tuple(torch.Tensor, dict) -- A tuple of the average loss and a dictionary of average 87 | metric values across the epoch. 88 | """ 89 | 90 | mode = 'Train' if model.training else 'Valid' 91 | logger = Logger(mode, length=len(loader), calculate_mean=show_running) 92 | loss = 0 93 | metrics = {} 94 | 95 | for i_batch, (x, y) in enumerate(loader): 96 | x = x.to(device) 97 | y = y.to(device) 98 | y_pred = model(x) 99 | loss_batch = loss_fn(y_pred, y) 100 | 101 | if model.training: 102 | loss_batch.backward() 103 | optimizer.step() 104 | optimizer.zero_grad() 105 | 106 | metrics_batch = {} 107 | for metric_name, metric_fn in batch_metrics.items(): 108 | metrics_batch[metric_name] = metric_fn(y_pred, y).detach().cpu() 109 | metrics[metric_name] = metrics.get(metric_name, 0) + metrics_batch[metric_name] 110 | 111 | if writer is not None and model.training: 112 | if writer.iteration % writer.interval == 0: 113 | writer.add_scalars('loss', {mode: loss_batch.detach().cpu()}, writer.iteration) 114 | for metric_name, metric_batch in metrics_batch.items(): 115 | writer.add_scalars(metric_name, {mode: metric_batch}, writer.iteration) 116 | writer.iteration += 1 117 | 118 | loss_batch = loss_batch.detach().cpu() 119 | loss += loss_batch 120 | if show_running: 121 | logger(loss, metrics, i_batch) 122 | else: 123 | logger(loss_batch, metrics_batch, i_batch) 124 | 125 | if model.training and scheduler is not None: 126 | scheduler.step() 127 | 128 | loss = loss / (i_batch + 1) 129 | metrics = {k: v / (i_batch + 1) for k, v in metrics.items()} 130 | 131 | if writer is not None and not model.training: 132 | writer.add_scalars('loss', {mode: loss.detach()}, writer.iteration) 133 | for metric_name, metric in metrics.items(): 134 | writer.add_scalars(metric_name, 
{mode: metric}) 135 | 136 | return loss, metrics 137 | 138 | 139 | def collate_pil(x): 140 | out_x, out_y = [], [] 141 | for xx, yy in x: 142 | out_x.append(xx) 143 | out_y.append(yy) 144 | return out_x, out_y 145 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools, os 2 | 3 | PACKAGE_NAME = 'facenet-pytorch' 4 | VERSION = '2.5.2' 5 | AUTHOR = 'Tim Esler' 6 | EMAIL = 'tim.esler@gmail.com' 7 | DESCRIPTION = 'Pretrained Pytorch face detection and recognition models' 8 | GITHUB_URL = 'https://github.com/timesler/facenet-pytorch' 9 | 10 | parent_dir = os.path.dirname(os.path.realpath(__file__)) 11 | import_name = os.path.basename(parent_dir) 12 | 13 | with open('{}/README.md'.format(parent_dir), 'r') as f: 14 | long_description = f.read() 15 | 16 | setuptools.setup( 17 | name=PACKAGE_NAME, 18 | version=VERSION, 19 | author=AUTHOR, 20 | author_email=EMAIL, 21 | description=DESCRIPTION, 22 | long_description=long_description, 23 | long_description_content_type='text/markdown', 24 | url=GITHUB_URL, 25 | packages=[ 26 | 'facenet_pytorch', 27 | 'facenet_pytorch.models', 28 | 'facenet_pytorch.models.utils', 29 | 'facenet_pytorch.data', 30 | ], 31 | package_dir={'facenet_pytorch':'.'}, 32 | package_data={'': ['*net.pt']}, 33 | classifiers=[ 34 | "Programming Language :: Python :: 3", 35 | "License :: OSI Approved :: MIT License", 36 | "Operating System :: OS Independent", 37 | ], 38 | install_requires=[ 39 | 'numpy>=1.24.0,<2.0.0', 40 | 'Pillow>=10.2.0,<10.3.0', 41 | 'requests>=2.0.0,<3.0.0', 42 | 'torch>=2.2.0,<=2.3.0', 43 | 'torchvision>=0.17.0,<=0.18.0', 44 | 'tqdm>=4.0.0,<5.0.0', 45 | ], 46 | ) 47 | -------------------------------------------------------------------------------- /tests/actions_requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.24.0,<2.0.0 2 | 
requests>=2.0.0,<3.0.0
torch>=2.2.0,<2.3.0
torchvision>=0.17.0,<0.18.0
Pillow>=10.2.0,<10.3.0
opencv-python>=4.9.0
scipy>=1.10.0,<2.0.0
tqdm>=4.0.0,<5.0.0
pandas>=2.0.0,<3.0.0
coverage>=7.0.0,<8.0.0
codecov>=2.0.0,<3.0.0
jupyter>=1.0.0
tensorboard>=2.0.0,<3.0.0
./
--------------------------------------------------------------------------------
/tests/actions_test.py:
--------------------------------------------------------------------------------
"""
The following code is intended to be run only by GitHub Actions for continuous integration and
testing purposes. For implementation examples see notebooks in the examples folder.
"""

from PIL import Image, ImageDraw
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import numpy as np
import pandas as pd
from time import time
import sys, os
import glob

from models.mtcnn import MTCNN, fixed_image_standardization
from models.inception_resnet_v1 import InceptionResnetV1, get_torch_home


#### CLEAR ALL OUTPUT FILES ####

checkpoints = glob.glob(os.path.join(get_torch_home(), 'checkpoints/*'))
for c in checkpoints:
    print('Removing {}'.format(c))
    os.remove(c)

crop_files = glob.glob('data/test_images_aligned/**/*.png')
for c in crop_files:
    print('Removing {}'.format(c))
    os.remove(c)


#### TEST EXAMPLE IPYNB'S ####

os.system('jupyter nbconvert --to script --stdout examples/infer.ipynb examples/finetune.ipynb > examples/tmptest.py')
os.chdir('examples')
try:
    import examples.tmptest
except:
    import tmptest
os.chdir('..')


#### TEST MTCNN ####

def get_image(path, trans):
    img = Image.open(path)
    img = trans(img)
    return img

trans = transforms.Compose([
    transforms.Resize(512)
])

trans_cropped = transforms.Compose([
    np.float32,
    transforms.ToTensor(),
    fixed_image_standardization
])

dataset = datasets.ImageFolder('data/test_images', transform=trans)
dataset.idx_to_class = {k: v for v, k in dataset.class_to_idx.items()}

mtcnn_pt = MTCNN(device=torch.device('cpu'))

names = []
aligned = []
aligned_fromfile = []
for img, idx in dataset:
    name = dataset.idx_to_class[idx]
    start = time()
    img_align = mtcnn_pt(img, save_path='data/test_images_aligned/{}/1.png'.format(name))
    print('MTCNN time: {:6f} seconds'.format(time() - start))

    # Comparison between types
    img_box = mtcnn_pt.detect(img)[0]
    assert (img_box - mtcnn_pt.detect(np.array(img))[0]).sum() < 1e-2
    assert (img_box - mtcnn_pt.detect(torch.as_tensor(np.array(img)))[0]).sum() < 1e-2

    # Batching test
    assert (img_box - mtcnn_pt.detect([img, img])[0]).sum() < 1e-2
    assert (img_box - mtcnn_pt.detect(np.array([np.array(img), np.array(img)]))[0]).sum() < 1e-2
    assert (img_box - mtcnn_pt.detect(torch.as_tensor([np.array(img), np.array(img)]))[0]).sum() < 1e-2

    # Box selection
    mtcnn_pt.selection_method = 'probability'
    print('\nprobability - ', mtcnn_pt.detect(img))
    mtcnn_pt.selection_method = 'largest'
    print('largest - ', mtcnn_pt.detect(img))
    mtcnn_pt.selection_method = 'largest_over_theshold'
    print('largest_over_theshold - ', mtcnn_pt.detect(img))
    mtcnn_pt.selection_method = 'center_weighted_size'
    print('center_weighted_size - ', mtcnn_pt.detect(img))

    if img_align is not None:
        names.append(name)
        aligned.append(img_align)
        aligned_fromfile.append(get_image('data/test_images_aligned/{}/1.png'.format(name), trans_cropped))

aligned = torch.stack(aligned)
aligned_fromfile = torch.stack(aligned_fromfile)


#### TEST EMBEDDINGS ####

expected = [
    [
        [0.000000, 1.482895, 0.886342, 1.438450, 1.437583],
        [1.482895, 0.000000, 1.345686, 1.029880, 1.061939],
        [0.886342, 1.345686, 0.000000, 1.363125, 1.338803],
        [1.438450, 1.029880, 1.363125, 0.000000, 1.066040],
        [1.437583, 1.061939, 1.338803, 1.066040, 0.000000]
    ],
    [
        [0.000000, 1.430769, 0.992931, 1.414197, 1.329544],
        [1.430769, 0.000000, 1.253911, 1.144899, 1.079755],
        [0.992931, 1.253911, 0.000000, 1.358875, 1.337322],
        [1.414197, 1.144899, 1.358875, 0.000000, 1.204118],
        [1.329544, 1.079755, 1.337322, 1.204118, 0.000000]
    ]
]

for i, ds in enumerate(['vggface2', 'casia-webface']):
    resnet_pt = InceptionResnetV1(pretrained=ds).eval()

    start = time()
    embs = resnet_pt(aligned)
    print('\nResnet time: {:6f} seconds\n'.format(time() - start))

    embs_fromfile = resnet_pt(aligned_fromfile)

    dists = [[(emb - e).norm().item() for e in embs] for emb in embs]
    dists_fromfile = [[(emb - e).norm().item() for e in embs_fromfile] for emb in embs_fromfile]

    print('\nOutput:')
    print(pd.DataFrame(dists, columns=names, index=names))
    print('\nOutput (from file):')
    print(pd.DataFrame(dists_fromfile, columns=names, index=names))
    print('\nExpected:')
    print(pd.DataFrame(expected[i], columns=names, index=names))

    total_error = (torch.tensor(dists) - torch.tensor(expected[i])).norm()
    total_error_fromfile = (torch.tensor(dists_fromfile) - torch.tensor(expected[i])).norm()

    print('\nTotal error: {}, {}'.format(total_error, total_error_fromfile))

    if sys.platform != 'win32':
        assert total_error < 1e-2
        assert total_error_fromfile < 1e-2


#### TEST CLASSIFICATION ####

resnet_pt = InceptionResnetV1(pretrained=ds, classify=True).eval()
prob = resnet_pt(aligned)


#### MULTI-FACE TEST ####

mtcnn = MTCNN(keep_all=True)
img = Image.open('data/multiface.jpg')
boxes, probs = mtcnn.detect(img)

draw = ImageDraw.Draw(img)
for i, box in enumerate(boxes):
    draw.rectangle(box.tolist())

mtcnn(img, save_path='data/tmp.png')


#### MTCNN TYPES TEST ####

img = Image.open('data/multiface.jpg')

mtcnn = MTCNN(keep_all=True)
boxes_ref, _ = mtcnn.detect(img)
_ = mtcnn(img)

mtcnn = MTCNN(keep_all=True).double()
boxes_test, _ = mtcnn.detect(img)
_ = mtcnn(img)

box_diff = boxes_ref[np.argsort(boxes_ref[:,1])] - boxes_test[np.argsort(boxes_test[:,1])]
total_error = np.sum(np.abs(box_diff))
print('\nfp64 Total box error: {}'.format(total_error))

assert total_error < 1e-2


# half is not supported on CPUs, only GPUs
if torch.cuda.is_available():

    mtcnn = MTCNN(keep_all=True, device='cuda').half()
    boxes_test, _ = mtcnn.detect(img)
    _ = mtcnn(img)

    box_diff = boxes_ref[np.argsort(boxes_ref[:,1])] - boxes_test[np.argsort(boxes_test[:,1])]
    print('fp16 Total box error: {}'.format(np.sum(np.abs(box_diff))))

    # test new automatic mixed precision to compare
    if hasattr(torch.cuda, 'amp'):
        with torch.cuda.amp.autocast():
            mtcnn = MTCNN(keep_all=True, device='cuda')
            boxes_test, _ = mtcnn.detect(img)
            _ = mtcnn(img)

        box_diff = boxes_ref[np.argsort(boxes_ref[:,1])] - boxes_test[np.argsort(boxes_test[:,1])]
        print('AMP total box error: {}'.format(np.sum(np.abs(box_diff))))


#### MULTI-IMAGE TEST ####

mtcnn = MTCNN(keep_all=True)
img = [
    Image.open('data/multiface.jpg'),
    Image.open('data/multiface.jpg')
]
batch_boxes, batch_probs = mtcnn.detect(img)

# Use a distinct save path per image so the second image does not overwrite the first
mtcnn(img, save_path=['data/tmp1.png', 'data/tmp2.png'])
tmp_files = glob.glob('data/tmp*')
for f in tmp_files:
    os.remove(f)


#### NO-FACE TEST ####

img = Image.new('RGB', (512, 512))
mtcnn(img)
mtcnn(img, return_prob=True)
--------------------------------------------------------------------------------
/tests/perf_test.py:
--------------------------------------------------------------------------------
from facenet_pytorch import MTCNN, training
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, RandomSampler
from tqdm import tqdm
import time


def main():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Running on device "{device}"')

    mtcnn = MTCNN(device=device)

    batch_size = 32

    # Generate data loader
    ds = datasets.ImageFolder(
        root='data/test_images/',
        transform=transforms.Resize((512, 512))
    )
    dl = DataLoader(
        dataset=ds,
        num_workers=4,
        collate_fn=training.collate_pil,
        batch_size=batch_size,
        sampler=RandomSampler(ds, replacement=True, num_samples=960),
    )

    start = time.time()
    faces = []
    for x, _ in tqdm(dl):
        faces.extend(mtcnn(x))
    elapsed = time.time() - start
    print(f'Elapsed: {elapsed} | EPS: {len(dl) * batch_size / elapsed}')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
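Both the running averages that `pass_epoch` prints through `Logger` (when `show_running=True`) and the values returned by `BatchTimer` in `models/utils/training.py` come down to simple arithmetic. The sketch below reproduces that arithmetic in plain Python so it can be checked without PyTorch. `running_means` and `batch_timer_value` are hypothetical helper names for illustration only, not part of facenet-pytorch.

```python
def running_means(batch_losses):
    """Running mean after each batch, as Logger prints it when calculate_mean=True.

    Hypothetical helper: pass_epoch accumulates the summed loss, and Logger divides
    that cumulative sum by the number of batches seen so far (i + 1).
    """
    means = []
    cumulative = 0.0
    for i, loss in enumerate(batch_losses):
        cumulative += loss                  # mirrors pass_epoch: loss += loss_batch
        means.append(cumulative / (i + 1))  # mirrors Logger.fn: x / (i + 1)
    return means


def batch_timer_value(elapsed, batch_size, rate=True, per_sample=True):
    """Value BatchTimer would report for one batch (hypothetical helper).

    elapsed is the wall-clock time in seconds since the previous call.
    """
    if per_sample:
        elapsed /= batch_size  # seconds per sample instead of per batch
    if rate:
        elapsed = 1 / elapsed  # invert a time into samples (or batches) per second
    return elapsed


if __name__ == '__main__':
    print(running_means([1.0, 3.0, 2.0]))
    print(batch_timer_value(2.0, 4))              # samples per second
    print(batch_timer_value(2.0, 4, rate=False))  # seconds per sample
```

One consequence worth noting: `pass_epoch` divides its epoch totals by the number of batches rather than the number of samples, so the returned averages weight every batch equally even when the final batch is smaller.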