├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── MANIFEST.in
├── README.md
├── assets
│   ├── inputs
│   │   ├── README.md
│   │   ├── image00.jpg
│   │   ├── image01.jpg
│   │   ├── image02.jpg
│   │   ├── video00.mp4
│   │   ├── video01.mp4
│   │   ├── video02.mp4
│   │   ├── video03.mp4
│   │   └── video04.mp4
│   └── results
│       ├── eth-xgaze_video01.gif
│       ├── eth-xgaze_video02.gif
│       ├── eth-xgaze_video03.gif
│       ├── mpiifacegaze_video00.gif
│       ├── mpiigaze_image00.jpg
│       └── mpiigaze_video00.gif
├── demo.ipynb
├── ptgaze
│   ├── __init__.py
│   ├── __main__.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── camera.py
│   │   ├── eye.py
│   │   ├── face.py
│   │   ├── face_model.py
│   │   ├── face_model_68.py
│   │   ├── face_model_mediapipe.py
│   │   ├── face_parts.py
│   │   └── visualizer.py
│   ├── data
│   │   ├── calib
│   │   │   └── sample_params.yaml
│   │   ├── configs
│   │   │   ├── eth-xgaze.yaml
│   │   │   ├── mpiifacegaze.yaml
│   │   │   └── mpiigaze.yaml
│   │   └── normalized_camera_params
│   │       ├── eth-xgaze.yaml
│   │       ├── mpiifacegaze.yaml
│   │       └── mpiigaze.yaml
│   ├── demo.py
│   ├── gaze_estimator.py
│   ├── head_pose_estimation
│   │   ├── __init__.py
│   │   ├── face_landmark_estimator.py
│   │   └── head_pose_normalizer.py
│   ├── main.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── mpiifacegaze
│   │   │   ├── __init__.py
│   │   │   ├── backbones
│   │   │   │   ├── __init__.py
│   │   │   │   └── resnet_simple.py
│   │   │   └── resnet_simple.py
│   │   └── mpiigaze
│   │       ├── __init__.py
│   │       └── resnet_preact.py
│   ├── transforms.py
│   └── utils.py
├── requirements.txt
└── setup.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
images/
videos/
outputs/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
  rev: v4.0.1
  hooks:
  - id: check-executables-have-shebangs
  - id: check-json
  - id: check-merge-conflict
  - id: check-shebang-scripts-are-executable
  - id: check-toml
  - id: check-yaml
  - id: double-quote-string-fixer
  - id: end-of-file-fixer
  - id: mixed-line-ending
    args: ['--fix=lf']
  - id: requirements-txt-fixer
  - id: trailing-whitespace
- repo: https://github.com/myint/docformatter
  rev: v1.4
  hooks:
  - id: docformatter
    args: ['--in-place']
- repo: https://github.com/pycqa/isort
  rev: 5.8.0
  hooks:
  - id: isort
- repo: https://github.com/pre-commit/mirrors-yapf
  rev: v0.31.0
  hooks:
  - id: yapf
    args: ['--parallel', '--in-place', '--style={spaces_before_comment: 2}']
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 hysts

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
include LICENSE
include README.md
include requirements.txt
recursive-include ptgaze *.yaml
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# A demo program of gaze estimation models (MPIIGaze, MPIIFaceGaze, ETH-XGaze)

[![PyPI version](https://badge.fury.io/py/ptgaze.svg)](https://pypi.org/project/ptgaze/)
[![Downloads](https://pepy.tech/badge/ptgaze)](https://pepy.tech/project/ptgaze)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hysts/pytorch_mpiigaze_demo/blob/master/demo.ipynb)
[![MIT License](https://img.shields.io/badge/license-MIT-green)](https://opensource.org/licenses/MIT)
[![GitHub stars](https://img.shields.io/github/stars/hysts/pytorch_mpiigaze_demo.svg?style=flat-square&logo=github&label=Stars&logoColor=white)](https://github.com/hysts/pytorch_mpiigaze_demo)

With this program, you can run gaze estimation on images and videos.
By default, the video from a webcam is used.

![ETH-XGaze video01 result](https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/master/assets/results/eth-xgaze_video01.gif)
![ETH-XGaze video02 result](https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/master/assets/results/eth-xgaze_video02.gif)
![ETH-XGaze video03 result](https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/master/assets/results/eth-xgaze_video03.gif)

![MPIIGaze video00 result](https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/master/assets/results/mpiigaze_video00.gif)
![MPIIFaceGaze video00 result](https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/master/assets/results/mpiifacegaze_video00.gif)

![MPIIGaze image00 result](https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/master/assets/results/mpiigaze_image00.jpg)

To train a model for MPIIGaze and MPIIFaceGaze,
use [this repository](https://github.com/hysts/pytorch_mpiigaze).
You can also use [this repo](https://github.com/hysts/pl_gaze_estimation)
to train a model with the ETH-XGaze dataset.

## Quick start

This program has been tested only on Ubuntu.

### Installation

```bash
pip install ptgaze
```

### Run demo

```bash
ptgaze --mode eth-xgaze
```

### Usage

```
usage: ptgaze [-h] [--config CONFIG] [--mode {mpiigaze,mpiifacegaze,eth-xgaze}]
              [--face-detector {dlib,face_alignment_dlib,face_alignment_sfd,mediapipe}]
              [--device {cpu,cuda}] [--image IMAGE] [--video VIDEO] [--camera CAMERA]
              [--output-dir OUTPUT_DIR] [--ext {avi,mp4}] [--no-screen] [--debug]

optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       Config file. When using a config file, all the other
                        command-line arguments are ignored. See
                        https://github.com/hysts/pytorch_mpiigaze_demo/ptgaze/data/configs/eth-xgaze.yaml
  --mode {mpiigaze,mpiifacegaze,eth-xgaze}
                        With 'mpiigaze', the MPIIGaze model will be used. With
                        'mpiifacegaze', the MPIIFaceGaze model will be used.
                        With 'eth-xgaze', the ETH-XGaze model will be used.
  --face-detector {dlib,face_alignment_dlib,face_alignment_sfd,mediapipe}
                        The method used to detect faces and find face
                        landmarks (default: 'mediapipe')
  --device {cpu,cuda}   Device used for model inference.
  --image IMAGE         Path to an input image file.
  --video VIDEO         Path to an input video file.
  --camera CAMERA       Camera calibration file. See
                        https://github.com/hysts/pytorch_mpiigaze_demo/ptgaze/data/calib/sample_params.yaml
  --output-dir OUTPUT_DIR, -o OUTPUT_DIR
                        If specified, the overlaid video will be saved to
                        this directory.
  --ext {avi,mp4}, -e {avi,mp4}
                        Output video file extension.
  --no-screen           If specified, the video is not displayed on screen
                        but is saved to the output directory.
  --debug
```

While processing an image or video, press the following keys on the window
to show or hide intermediate results:

- `l`: landmarks
- `h`: head pose
- `t`: projected points of the 3D face model
- `b`: face bounding box

## References

- Zhang, Xucong, Seonwook Park, Thabo Beeler, Derek Bradley, Siyu Tang, and Otmar Hilliges. "ETH-XGaze: A Large Scale Dataset for Gaze Estimation under Extreme Head Pose and Gaze Variation." In European Conference on Computer Vision (ECCV), 2020. [arXiv:2007.15837](https://arxiv.org/abs/2007.15837), [Project Page](https://ait.ethz.ch/projects/2020/ETH-XGaze/), [GitHub](https://github.com/xucong-zhang/ETH-XGaze)
- Zhang, Xucong, Yusuke Sugano, Mario Fritz, and Andreas Bulling. "Appearance-based Gaze Estimation in the Wild." Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2015. [arXiv:1504.02863](https://arxiv.org/abs/1504.02863), [Project Page](https://www.mpi-inf.mpg.de/departments/computer-vision-and-multimodal-computing/research/gaze-based-human-computer-interaction/appearance-based-gaze-estimation-in-the-wild/)
- Zhang, Xucong, Yusuke Sugano, Mario Fritz, and Andreas Bulling. "It's Written All Over Your Face: Full-Face Appearance-Based Gaze Estimation." Proc. of the IEEE Conference on Computer Vision and Pattern Recognition Workshops (CVPRW), 2017. [arXiv:1611.08860](https://arxiv.org/abs/1611.08860), [Project Page](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/research/gaze-based-human-computer-interaction/its-written-all-over-your-face-full-face-appearance-based-gaze-estimation/)
- Zhang, Xucong, Yusuke Sugano, Mario Fritz, and Andreas Bulling. "MPIIGaze: Real-World Dataset and Deep Appearance-Based Gaze Estimation." IEEE Transactions on Pattern Analysis and Machine Intelligence 41 (2017). [arXiv:1711.09017](https://arxiv.org/abs/1711.09017)
- Zhang, Xucong, Yusuke Sugano, and Andreas Bulling. "Evaluation of Appearance-Based Methods and Implications for Gaze-Based Applications." Proc. ACM SIGCHI Conference on Human Factors in Computing Systems (CHI), 2019.
  [arXiv](https://arxiv.org/abs/1901.10906), [code](https://git.hcics.simtech.uni-stuttgart.de/public-projects/opengaze)
--------------------------------------------------------------------------------
/assets/inputs/README.md:
--------------------------------------------------------------------------------
The original images and videos are from the following public-domain sources:

- https://www.pexels.com/photo/photography-of-a-beautiful-woman-smiling-1024311/
- https://www.pexels.com/photo/laughing-man-wearing-gray-v-neck-t-shirt-936119/
- https://www.pexels.com/photo/photo-of-people-doing-handshakes-3184416/

- https://www.pexels.com/video/woman-in-a-group-having-a-drink-while-listening-3201742/
- https://www.pexels.com/video/children-sitting-in-a-classroom-8088529/
- https://www.pexels.com/video/girls-showing-their-artwork-8088556/
- https://www.pexels.com/video/students-eating-their-pizza-while-talking-to-their-teacher-5199621/
- https://www.pexels.com/video/a-father-playing-toys-with-her-daughter-at-home-4820867/
--------------------------------------------------------------------------------
/assets/inputs/image00.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/47cdf68414d20c8281bbb0a03112a298761aaa9b/assets/inputs/image00.jpg
--------------------------------------------------------------------------------
/assets/inputs/image01.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/47cdf68414d20c8281bbb0a03112a298761aaa9b/assets/inputs/image01.jpg
--------------------------------------------------------------------------------
/assets/inputs/image02.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/47cdf68414d20c8281bbb0a03112a298761aaa9b/assets/inputs/image02.jpg
--------------------------------------------------------------------------------
/assets/inputs/video00.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/47cdf68414d20c8281bbb0a03112a298761aaa9b/assets/inputs/video00.mp4
--------------------------------------------------------------------------------
/assets/inputs/video01.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/47cdf68414d20c8281bbb0a03112a298761aaa9b/assets/inputs/video01.mp4
--------------------------------------------------------------------------------
/assets/inputs/video02.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/47cdf68414d20c8281bbb0a03112a298761aaa9b/assets/inputs/video02.mp4
--------------------------------------------------------------------------------
/assets/inputs/video03.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/47cdf68414d20c8281bbb0a03112a298761aaa9b/assets/inputs/video03.mp4
--------------------------------------------------------------------------------
/assets/inputs/video04.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/47cdf68414d20c8281bbb0a03112a298761aaa9b/assets/inputs/video04.mp4 -------------------------------------------------------------------------------- /assets/results/eth-xgaze_video01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/47cdf68414d20c8281bbb0a03112a298761aaa9b/assets/results/eth-xgaze_video01.gif -------------------------------------------------------------------------------- /assets/results/eth-xgaze_video02.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/47cdf68414d20c8281bbb0a03112a298761aaa9b/assets/results/eth-xgaze_video02.gif -------------------------------------------------------------------------------- /assets/results/eth-xgaze_video03.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/47cdf68414d20c8281bbb0a03112a298761aaa9b/assets/results/eth-xgaze_video03.gif -------------------------------------------------------------------------------- /assets/results/mpiifacegaze_video00.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/47cdf68414d20c8281bbb0a03112a298761aaa9b/assets/results/mpiifacegaze_video00.gif -------------------------------------------------------------------------------- /assets/results/mpiigaze_image00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/47cdf68414d20c8281bbb0a03112a298761aaa9b/assets/results/mpiigaze_image00.jpg -------------------------------------------------------------------------------- /assets/results/mpiigaze_video00.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/47cdf68414d20c8281bbb0a03112a298761aaa9b/assets/results/mpiigaze_video00.gif -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "07236d7c", 7 | "metadata": { 8 | "ExecuteTime": { 9 | "start_time": "2021-10-21T09:52:57.162Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "!git clone -q https://github.com/hysts/pytorch_mpiigaze_demo\n", 15 | "!cd pytorch_mpiigaze_demo && python setup.py install\n", 16 | "!pip install -U pyyaml" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "a7f04682", 23 | "metadata": { 24 | "ExecuteTime": { 25 | "start_time": "2021-10-21T09:53:05.033Z" 26 | } 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "!ptgaze --mode eth-xgaze --video pytorch_mpiigaze_demo/assets/inputs/video01.mp4 --o . 
--no-screen" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "af873e50", 37 | "metadata": { 38 | "ExecuteTime": { 39 | "start_time": "2021-10-21T09:53:12.015Z" 40 | } 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "!ffmpeg -i video01.avi -c:v libx264 out.mp4" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "id": "ead94a63", 51 | "metadata": { 52 | "ExecuteTime": { 53 | "start_time": "2021-10-21T09:53:24.304Z" 54 | } 55 | }, 56 | "outputs": [], 57 | "source": [ 58 | "from IPython.display import HTML\n", 59 | "from base64 import b64encode\n", 60 | "\n", 61 | "HTML(f\"\"\"\n", 62 | "\n", 65 | "\"\"\")" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "id": "4cd9b150", 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [] 75 | } 76 | ], 77 | "metadata": { 78 | "kernelspec": { 79 | "display_name": "Python 3", 80 | "language": "python", 81 | "name": "python3" 82 | }, 83 | "language_info": { 84 | "codemirror_mode": { 85 | "name": "ipython", 86 | "version": 3 87 | }, 88 | "file_extension": ".py", 89 | "mimetype": "text/x-python", 90 | "name": "python", 91 | "nbconvert_exporter": "python", 92 | "pygments_lexer": "ipython3", 93 | "version": "3.9.5" 94 | }, 95 | "toc": { 96 | "base_numbering": 1, 97 | "nav_menu": {}, 98 | "number_sections": true, 99 | "sideBar": true, 100 | "skip_h1_title": false, 101 | "title_cell": "Table of Contents", 102 | "title_sidebar": "Contents", 103 | "toc_cell": false, 104 | "toc_position": {}, 105 | "toc_section_display": true, 106 | "toc_window_display": false 107 | } 108 | }, 109 | "nbformat": 4, 110 | "nbformat_minor": 5 111 | } 112 | -------------------------------------------------------------------------------- /ptgaze/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/47cdf68414d20c8281bbb0a03112a298761aaa9b/ptgaze/__init__.py -------------------------------------------------------------------------------- /ptgaze/__main__.py: -------------------------------------------------------------------------------- 1 | import ptgaze.main 2 | 3 | ptgaze.main.main() 4 | -------------------------------------------------------------------------------- /ptgaze/common/__init__.py: -------------------------------------------------------------------------------- 1 | from .camera import Camera 2 | from .eye import Eye 3 | from .face import Face 4 | from .face_parts import FaceParts, FacePartsName 5 | from .visualizer import Visualizer 6 | -------------------------------------------------------------------------------- /ptgaze/common/camera.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Optional 3 | 4 | import cv2 5 | import numpy as np 6 | import yaml 7 | 8 | 9 | @dataclasses.dataclass() 10 | class Camera: 11 | width: int = dataclasses.field(init=False) 12 | height: int = dataclasses.field(init=False) 13 | camera_matrix: np.ndarray = dataclasses.field(init=False) 14 | dist_coefficients: np.ndarray = dataclasses.field(init=False) 15 | 16 | camera_params_path: dataclasses.InitVar[str] = None 17 | 18 | def __post_init__(self, camera_params_path): 19 | with open(camera_params_path) as f: 20 | data = yaml.safe_load(f) 21 | self.width = data['image_width'] 22 | self.height = data['image_height'] 23 | self.camera_matrix = np.array(data['camera_matrix']['data']).reshape( 24 | 3, 3) 25 | 
        self.dist_coefficients = np.array(
            data['distortion_coefficients']['data']).reshape(-1, 1)

    def project_points(self,
                       points3d: np.ndarray,
                       rvec: Optional[np.ndarray] = None,
                       tvec: Optional[np.ndarray] = None) -> np.ndarray:
        assert points3d.shape[1] == 3
        if rvec is None:
            rvec = np.zeros(3, dtype=np.float64)
        if tvec is None:
            tvec = np.zeros(3, dtype=np.float64)
        points2d, _ = cv2.projectPoints(points3d, rvec, tvec,
                                        self.camera_matrix,
                                        self.dist_coefficients)
        return points2d.reshape(-1, 2)
--------------------------------------------------------------------------------
/ptgaze/common/eye.py:
--------------------------------------------------------------------------------
from .face_parts import FaceParts


class Eye(FaceParts):
    pass
--------------------------------------------------------------------------------
/ptgaze/common/face.py:
--------------------------------------------------------------------------------
from typing import Optional

import numpy as np

from .eye import Eye
from .face_parts import FaceParts, FacePartsName


class Face(FaceParts):
    def __init__(self, bbox: np.ndarray, landmarks: np.ndarray):
        super().__init__(FacePartsName.FACE)
        self.bbox = bbox
        self.landmarks = landmarks

        self.reye: Eye = Eye(FacePartsName.REYE)
        self.leye: Eye = Eye(FacePartsName.LEYE)

        self.head_position: Optional[np.ndarray] = None
        self.model3d: Optional[np.ndarray] = None

    @staticmethod
    def change_coordinate_system(euler_angles: np.ndarray) -> np.ndarray:
        return euler_angles * np.array([-1, 1, -1])
--------------------------------------------------------------------------------
/ptgaze/common/face_model.py:
--------------------------------------------------------------------------------
import dataclasses

import cv2
import numpy as np
from scipy.spatial.transform import Rotation

from .camera import Camera
from .face import Face


@dataclasses.dataclass(frozen=True)
class FaceModel:
    LANDMARKS: np.ndarray
    REYE_INDICES: np.ndarray
    LEYE_INDICES: np.ndarray
    MOUTH_INDICES: np.ndarray
    NOSE_INDICES: np.ndarray
    CHIN_INDEX: int
    NOSE_INDEX: int

    def estimate_head_pose(self, face: Face, camera: Camera) -> None:
        """Estimate the head pose by fitting the 3D template model."""
        # If the number of template points is small, cv2.solvePnP
        # becomes unstable, so set default values for rvec and tvec
        # and set useExtrinsicGuess to True.
        # The default values of rvec and tvec below mean that the
        # initial estimate of the head pose is not rotated and the
        # face is in front of the camera.
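        # cv2.solvePnP then refines rvec and tvec from that initial
        # guess so that the projected template landmarks line up with
        # the detected 2D landmarks, i.e. it minimizes the reprojection
        # error.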
        rvec = np.zeros(3, dtype=np.float64)
        tvec = np.array([0, 0, 1], dtype=np.float64)
        _, rvec, tvec = cv2.solvePnP(self.LANDMARKS,
                                     face.landmarks,
                                     camera.camera_matrix,
                                     camera.dist_coefficients,
                                     rvec,
                                     tvec,
                                     useExtrinsicGuess=True,
                                     flags=cv2.SOLVEPNP_ITERATIVE)
        rot = Rotation.from_rotvec(rvec)
        face.head_pose_rot = rot
        face.head_position = tvec
        face.reye.head_pose_rot = rot
        face.leye.head_pose_rot = rot

    def compute_3d_pose(self, face: Face) -> None:
        """Compute the model points transformed by the estimated head pose."""
        rot = face.head_pose_rot.as_matrix()
        face.model3d = self.LANDMARKS @ rot.T + face.head_position

    def compute_face_eye_centers(self, face: Face, mode: str) -> None:
        """Compute the centers of the face and eyes.

        In the case of MPIIFaceGaze, the face center is defined as the
        average coordinates of the six points at the corners of both
        eyes and the mouth. In the case of ETH-XGaze, it's defined as
        the average coordinates of the six points at the corners of both
        eyes and the nose. The eye centers are defined as the average
        coordinates of the corners of each eye.
        """
        if mode == 'ETH-XGaze':
            face.center = face.model3d[np.concatenate(
                [self.REYE_INDICES, self.LEYE_INDICES,
                 self.NOSE_INDICES])].mean(axis=0)
        else:
            face.center = face.model3d[np.concatenate(
                [self.REYE_INDICES, self.LEYE_INDICES,
                 self.MOUTH_INDICES])].mean(axis=0)
        face.reye.center = face.model3d[self.REYE_INDICES].mean(axis=0)
        face.leye.center = face.model3d[self.LEYE_INDICES].mean(axis=0)
--------------------------------------------------------------------------------
/ptgaze/common/face_model_68.py:
--------------------------------------------------------------------------------
import dataclasses

import numpy as np

from .face_model import FaceModel


@dataclasses.dataclass(frozen=True)
class FaceModel68(FaceModel):
    """3D face model for the Multi-PIE 68-point mark-up.

    In the camera coordinate system, the X axis points to the right from
    the camera, the Y axis points down, and the Z axis points forward.

    The face model is facing the camera. Here, the Z axis is
    perpendicular to the plane passing through the three midpoints of
    the eyes and mouth, the X axis is parallel to the line passing
    through the midpoints of both eyes, and the origin is at the tip of
    the nose.

    The units of the coordinate system are meters and the distance
    between the outer eye corners of the model is set to 90 mm.

    The model coordinate system is defined as the camera coordinate
    system rotated 180 degrees around the Y axis.
26 | """ 27 | LANDMARKS: np.ndarray = np.array([ 28 | [-0.07141807, -0.02827123, 0.08114384], 29 | [-0.07067417, -0.00961522, 0.08035654], 30 | [-0.06844646, 0.00895837, 0.08046731], 31 | [-0.06474301, 0.02708319, 0.08045689], 32 | [-0.05778475, 0.04384917, 0.07802191], 33 | [-0.04673809, 0.05812865, 0.07192291], 34 | [-0.03293922, 0.06962711, 0.06106274], 35 | [-0.01744018, 0.07850638, 0.04752971], 36 | [0., 0.08105961, 0.0425195], 37 | [0.01744018, 0.07850638, 0.04752971], 38 | [0.03293922, 0.06962711, 0.06106274], 39 | [0.04673809, 0.05812865, 0.07192291], 40 | [0.05778475, 0.04384917, 0.07802191], 41 | [0.06474301, 0.02708319, 0.08045689], 42 | [0.06844646, 0.00895837, 0.08046731], 43 | [0.07067417, -0.00961522, 0.08035654], 44 | [0.07141807, -0.02827123, 0.08114384], 45 | [-0.05977758, -0.0447858, 0.04562813], 46 | [-0.05055506, -0.05334294, 0.03834846], 47 | [-0.0375633, -0.05609241, 0.03158344], 48 | [-0.02423648, -0.05463779, 0.02510117], 49 | [-0.01168798, -0.04986641, 0.02050337], 50 | [0.01168798, -0.04986641, 0.02050337], 51 | [0.02423648, -0.05463779, 0.02510117], 52 | [0.0375633, -0.05609241, 0.03158344], 53 | [0.05055506, -0.05334294, 0.03834846], 54 | [0.05977758, -0.0447858, 0.04562813], 55 | [0., -0.03515768, 0.02038099], 56 | [0., -0.02350421, 0.01366667], 57 | [0., -0.01196914, 0.00658284], 58 | [0., 0., 0.], 59 | [-0.01479319, 0.00949072, 0.01708772], 60 | [-0.00762319, 0.01179908, 0.01419133], 61 | [0., 0.01381676, 0.01205559], 62 | [0.00762319, 0.01179908, 0.01419133], 63 | [0.01479319, 0.00949072, 0.01708772], 64 | [-0.045, -0.032415, 0.03976718], 65 | [-0.0370546, -0.0371723, 0.03579593], 66 | [-0.0275166, -0.03714814, 0.03425518], 67 | [-0.01919724, -0.03101962, 0.03359268], 68 | [-0.02813814, -0.0294397, 0.03345652], 69 | [-0.03763013, -0.02948442, 0.03497732], 70 | [0.01919724, -0.03101962, 0.03359268], 71 | [0.0275166, -0.03714814, 0.03425518], 72 | [0.0370546, -0.0371723, 0.03579593], 73 | [0.045, -0.032415, 0.03976718], 74 | [0.03763013, -0.02948442, 0.03497732], 75 | [0.02813814, -0.0294397, 0.03345652], 76 | [-0.02847002, 0.03331642, 0.03667993], 77 | [-0.01796181, 0.02843251, 0.02335485], 78 | [-0.00742947, 0.0258057, 0.01630812], 79 | [0., 0.0275555, 0.01538404], 80 | [0.00742947, 0.0258057, 0.01630812], 81 | [0.01796181, 0.02843251, 0.02335485], 82 | [0.02847002, 0.03331642, 0.03667993], 83 | [0.0183606, 0.0423393, 0.02523355], 84 | [0.00808323, 0.04614537, 0.01820142], 85 | [0., 0.04688623, 0.01716318], 86 | [-0.00808323, 0.04614537, 0.01820142], 87 | [-0.0183606, 0.0423393, 0.02523355], 88 | [-0.02409981, 0.03367606, 0.03421466], 89 | [-0.00756874, 0.03192644, 0.01851247], 90 | [0., 0.03263345, 0.01732347], 91 | [0.00756874, 0.03192644, 0.01851247], 92 | [0.02409981, 0.03367606, 0.03421466], 93 | [0.00771924, 0.03711846, 0.01940396], 94 | [0., 0.03791103, 0.0180805], 95 | [-0.00771924, 0.03711846, 0.01940396], 96 | ], 97 | dtype=np.float64) 98 | 99 | REYE_INDICES: np.ndarray = np.array([36, 39]) 100 | LEYE_INDICES: np.ndarray = np.array([42, 45]) 101 | MOUTH_INDICES: np.ndarray = np.array([48, 54]) 102 | NOSE_INDICES: np.ndarray = np.array([31, 35]) 103 | 104 | CHIN_INDEX: int = 8 105 | NOSE_INDEX: int = 30 106 | -------------------------------------------------------------------------------- /ptgaze/common/face_model_mediapipe.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import numpy as np 4 | 5 | from .face_model import FaceModel 6 | 7 | 8 | @dataclasses.dataclass(frozen=True) 9 | 
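# NOTE: The 468 rows of LANDMARKS below follow MediaPipe FaceMesh's
# landmark numbering; the index constants at the end of this file
# (REYE_INDICES etc.) refer to rows of that array.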
class FaceModelMediaPipe(FaceModel): 10 | """3D face model for MediaPipe 468 points mark-up. 11 | 12 | In the camera coordinate system, the X axis points to the right from 13 | camera, the Y axis points down, and the Z axis points forward. 14 | 15 | The face model is facing the camera. Here, the Z axis is 16 | perpendicular to the plane passing through the three midpoints of 17 | the eyes and mouth, the X axis is parallel to the line passing 18 | through the midpoints of both eyes, and the origin is at the tip of 19 | the nose. 20 | 21 | The units of the coordinate system are meters and the distance 22 | between outer eye corners of the model is set to 90mm. 23 | 24 | The model coordinate system is defined as the camera coordinate 25 | system rotated 180 degrees around the Y axis. 26 | """ 27 | LANDMARKS: np.ndarray = np.array([ 28 | [0.0, 0.02279539, 0.01496097], 29 | [0.0, 0.0, 0.0], 30 | [0.0, 0.00962159, 0.01417337], 31 | [-0.00463928, -0.02082222, 0.00842021], 32 | [0.0, -0.00663695, -0.00110976], 33 | [0.0, -0.01492534, 0.00232734], 34 | [0.0, -0.0360012, 0.01686977], 35 | [-0.04253081, -0.03704511, 0.04195902], 36 | [0.0, -0.05145907, 0.0219084], 37 | [0.0, -0.06012844, 0.02090346], 38 | [0.0, -0.09388643, 0.02994069], 39 | [0.0, 0.02579946, 0.0161068], 40 | [0.0, 0.02791436, 0.01906174], 41 | [0.0, 0.02867571, 0.02256122], 42 | [0.0, 0.03415535, 0.0207085], 43 | [0.0, 0.03618712, 0.01946147], 44 | [0.0, 0.03892702, 0.01874156], 45 | [0.0, 0.04238258, 0.01940163], 46 | [0.0, 0.05022759, 0.02404232], 47 | [0.0, 0.0037423, 0.00363408], 48 | [-0.00416106, 0.00339584, 0.01027947], 49 | [-0.0708796, -0.06561666, 0.07375984], 50 | [-0.02628639, -0.03162763, 0.03627483], 51 | [-0.03198363, -0.0311268, 0.03678652], 52 | [-0.03775151, -0.03166267, 0.0382941], 53 | [-0.04465819, -0.03549815, 0.04320436], 54 | [-0.02164289, -0.03316732, 0.03623782], 55 | [-0.03208229, -0.04350791, 0.03359782], 56 | [-0.02673803, -0.04332202, 0.03383401], 57 | [-0.03745193, -0.04292151, 0.03503195], 58 | [-0.04161018, -0.04185934, 0.0375605], 59 | [-0.05062006, -0.03061283, 0.04699511], 60 | [-0.02266659, 0.06298903, 0.03085792], 61 | [-0.04445859, -0.03790856, 0.04302182], 62 | [-0.0721453, -0.03389874, 0.07402454], 63 | [-0.05799793, -0.03476411, 0.05271545], 64 | [-0.02844939, -0.00405997, 0.03042474], 65 | [-0.00711452, 0.0220249, 0.0159856], 66 | [-0.00606033, 0.02797697, 0.02030681], 67 | [-0.01431615, 0.02374088, 0.01979415], 68 | [-0.0191491, 0.02676281, 0.02446674], 69 | [-0.01131043, 0.02847072, 0.02285956], 70 | [-0.01563548, 0.02955898, 0.02633341], 71 | [-0.02650112, 0.03876784, 0.03287121], 72 | [-0.00427049, -0.00032731, 0.00115075], 73 | [-0.00496396, -0.00651206, 0.00035246], 74 | [-0.05253307, -0.05008447, 0.04112445], 75 | [-0.01718698, -0.02101474, 0.02917245], 76 | [-0.01608635, -0.00184349, 0.01661411], 77 | [-0.01651267, -0.00515997, 0.01894285], 78 | [-0.04765501, -0.00425311, 0.03940972], 79 | [-0.00478306, -0.01422631, 0.00374591], 80 | [-0.03734964, -0.05635095, 0.0292515], 81 | [-0.04588603, -0.05428902, 0.0342712], 82 | [-0.06279331, -0.07742292, 0.06049754], 83 | [-0.01220941, -0.0526903, 0.02369569], 84 | [-0.02193489, -0.04227182, 0.03475029], 85 | [-0.03102642, 0.03226119, 0.03379699], 86 | [-0.06719682, 0.0366178, 0.09221005], 87 | [-0.01193824, 0.0017993, 0.01737857], 88 | [-0.00729766, 0.00466847, 0.01642396], 89 | [-0.02456206, 0.03215756, 0.0319172], 90 | [-0.02204823, 0.03177643, 0.03313105], 91 | [-0.04985894, -0.05929326, 0.03723627], 92 | [-0.01592294, 0.00130844, 
0.02018655], 93 | [-0.02644548, -0.05651519, 0.02554045], 94 | [-0.02760292, -0.06227836, 0.02459614], 95 | [-0.03523964, -0.09132841, 0.03746441], 96 | [-0.05599763, -0.06842335, 0.04751345], 97 | [-0.03063932, -0.07693009, 0.02945623], 98 | [-0.05720968, -0.05381449, 0.04644752], 99 | [-0.06374393, -0.05912455, 0.05883913], 100 | [-0.00672728, 0.02561151, 0.017378], 101 | [-0.0126256, 0.02660826, 0.02057825], 102 | [-0.01732553, 0.02825902, 0.02475025], 103 | [-0.01043625, 0.00338108, 0.01813149], 104 | [-0.02321234, 0.03202204, 0.03217448], 105 | [-0.02056846, 0.03350806, 0.02954721], 106 | [-0.02153084, 0.03149457, 0.03437511], 107 | [-0.00946874, -0.00091616, 0.0096333], 108 | [-0.01469132, 0.02909486, 0.02870696], 109 | [-0.0102434, 0.02862986, 0.02548911], 110 | [-0.00533422, 0.02866357, 0.02337402], 111 | [-0.0076972, 0.04968529, 0.02489721], 112 | [-0.00699606, 0.04164985, 0.020273], 113 | [-0.00669687, 0.03822905, 0.01965992], 114 | [-0.00630947, 0.03568236, 0.02026233], 115 | [-0.00583218, 0.03391117, 0.02135735], 116 | [-0.0153717, 0.03296341, 0.02730134], 117 | [-0.016156, 0.03349077, 0.02661972], 118 | [-0.01729053, 0.03491815, 0.02621141], 119 | [-0.01838624, 0.03701881, 0.02651867], 120 | [-0.0236825, 0.01979372, 0.02607508], 121 | [-0.07542244, -0.00077583, 0.09906925], 122 | [0.0, 0.00597138, 0.00874214], 123 | [-0.01826614, 0.03272666, 0.03076583], 124 | [-0.01929558, 0.03284966, 0.02978552], 125 | [-0.00597442, 0.00886821, 0.01609148], 126 | [-0.01405627, 0.00587331, 0.02234517], 127 | [-0.00662449, 0.00692456, 0.01611845], 128 | [-0.0234234, -0.01699087, 0.03181301], 129 | [-0.03327324, -0.01231728, 0.03361744], 130 | [-0.01726175, -0.002077, 0.02202249], 131 | [-0.05133204, -0.08612467, 0.04815162], 132 | [-0.04538641, -0.07446772, 0.0379218], 133 | [-0.03986562, -0.06236352, 0.03009289], 134 | [-0.02169681, 0.04313568, 0.0301973], 135 | [-0.01395634, -0.06138828, 0.02159572], 136 | [-0.016195, -0.07726082, 0.02554498], 137 | [-0.01891399, -0.09363242, 0.03200607], 138 | [-0.04195832, -0.0336207, 0.04100505], 139 | [-0.05733342, -0.02538603, 0.05043878], 140 | [-0.01859887, -0.03482622, 0.03632423], 141 | [-0.04988612, -0.04201519, 0.04391746], 142 | [-0.01303263, -0.02543318, 0.02644513], 143 | [-0.01305757, -0.00454086, 0.01059645], 144 | [-0.0646517, -0.02063984, 0.05785731], 145 | [-0.05258659, -0.02072676, 0.04501292], 146 | [-0.04432338, -0.01848961, 0.03952989], 147 | [-0.03300681, -0.01988506, 0.0360282], 148 | [-0.02430178, -0.02258357, 0.03436569], 149 | [-0.01820731, -0.02594819, 0.0325148], 150 | [-0.00563221, -0.03434558, 0.01908815], 151 | [-0.06338145, -0.00597586, 0.05594429], 152 | [-0.05587698, -0.04334936, 0.04787765], 153 | [-0.00242624, 0.00335992, 0.00404113], 154 | [-0.01611251, -0.01466191, 0.02580183], 155 | [-0.07743095, -0.03491864, 0.09480771], 156 | [-0.01391142, -0.02977913, 0.03026605], 157 | [-0.01785794, -0.00148581, 0.02625134], 158 | [-0.04670959, -0.03791326, 0.04391529], 159 | [-0.0133397, -0.00843104, 0.01378557], 160 | [-0.07270895, 0.01764052, 0.09728059], 161 | [-0.01856432, -0.0371211, 0.037177], 162 | [-0.00923388, -0.01199941, 0.0080366], 163 | [-0.05000589, 0.05008263, 0.05583081], 164 | [-0.05085276, 0.06051725, 0.06760893], 165 | [-0.07159291, -0.00315045, 0.07547648], 166 | [-0.05843051, 0.04121158, 0.06551513], 167 | [-0.06847258, -0.04789781, 0.06750909], 168 | [-0.02412942, 0.07131988, 0.03356391], 169 | [-0.00179909, 0.00562999, 0.00902303], 170 | [-0.02103655, -0.00962919, 0.02909485], 171 | [-0.06407571, 
-0.03362886, 0.05914761], 172 | [-0.03670075, -0.03487018, 0.03840374], 173 | [-0.03177186, -0.0342113, 0.036999], 174 | [-0.02196121, 0.03471457, 0.02995818], 175 | [-0.06234883, 0.00817565, 0.05812062], 176 | [-0.01292924, 0.08169055, 0.03381541], 177 | [-0.03210651, 0.07406413, 0.04673603], 178 | [-0.04068926, 0.06866244, 0.05550485], 179 | [0.0, -0.07672255, 0.02448293], 180 | [0.0, 0.08276513, 0.03211112], 181 | [-0.02724032, -0.03442667, 0.03698453], 182 | [-0.0228846, -0.03525756, 0.03778001], 183 | [-0.01998311, -0.03623412, 0.03786456], 184 | [-0.0613004, -0.04526126, 0.05437088], 185 | [-0.0228846, -0.04013369, 0.03700573], 186 | [-0.02724032, -0.04088675, 0.03603837], 187 | [-0.03177186, -0.04091001, 0.03598631], 188 | [-0.03670075, -0.04054579, 0.03751279], 189 | [-0.04018389, -0.03984222, 0.03992621], 190 | [-0.07555811, -0.05233676, 0.08467521], 191 | [-0.04018389, -0.0361056, 0.04034706], 192 | [0.0, 0.0139508, 0.01543339], 193 | [-0.01776217, 0.01557081, 0.02262488], 194 | [-0.01222237, 0.00055579, 0.01523139], 195 | [-0.00731493, 0.01409818, 0.01660261], 196 | [0.0, -0.04397892, 0.02239589], 197 | [-0.04135272, 0.05869773, 0.04803634], 198 | [-0.03311811, 0.0653395, 0.04092641], 199 | [-0.01313701, 0.0751313, 0.02773148], 200 | [-0.05940524, 0.05096764, 0.08107072], 201 | [-0.01998311, -0.03870703, 0.03731574], 202 | [-0.00901447, -0.02363857, 0.01721348], 203 | [0.0, 0.07638378, 0.02584163], 204 | [-0.02308977, 0.07847331, 0.03866534], 205 | [-0.06954154, 0.01312978, 0.07606767], 206 | [-0.01098819, 0.03331923, 0.02354877], 207 | [-0.01181124, 0.03453131, 0.0228604], 208 | [-0.01255818, 0.03661036, 0.02238553], 209 | [-0.01325085, 0.03979642, 0.02270594], 210 | [-0.01546388, 0.04692527, 0.02717711], 211 | [-0.01953754, 0.03057027, 0.03043891], 212 | [-0.02117802, 0.03010228, 0.02920508], 213 | [-0.02285339, 0.02924331, 0.02893166], 214 | [-0.0285016, 0.02538855, 0.0299061], 215 | [-0.05278538, 0.01112077, 0.0461438], 216 | [-0.00946709, -0.03034493, 0.02278825], 217 | [-0.01314173, -0.04231777, 0.032442], 218 | [-0.0178, -0.03986865, 0.03594049], 219 | [-0.0184511, 0.02972015, 0.0322834], 220 | [-0.05436187, 0.02903617, 0.05365752], 221 | [-0.00766444, -0.04308996, 0.02614151], 222 | [-0.01938616, 0.05487545, 0.02954519], 223 | [0.0, -0.02186278, 0.00700999], 224 | [-0.00516573, -0.02710437, 0.01327241], 225 | [0.0, -0.02855234, 0.01158854], 226 | [-0.01246815, -0.01357162, 0.01794568], 227 | [0.0, 0.06815329, 0.02294431], 228 | [0.0, 0.05864634, 0.02322126], 229 | [-0.00997827, 0.05804056, 0.02496028], 230 | [-0.03288807, 0.04255649, 0.03679852], 231 | [-0.02311631, 0.00439372, 0.02885519], 232 | [-0.0268025, 0.04984702, 0.03379452], 233 | [-0.03832928, 0.00410461, 0.03337873], 234 | [-0.0296186, 0.0114735, 0.03034661], 235 | [-0.04386901, 0.01556421, 0.03831718], 236 | [-0.01217295, 0.067076, 0.02506318], 237 | [-0.01542374, -0.00990022, 0.02274596], 238 | [-0.03878377, 0.04914899, 0.04164525], 239 | [-0.03084037, 0.05682977, 0.03661409], 240 | [-0.03747321, 0.0337668, 0.03749151], 241 | [-0.06094129, 0.02079126, 0.06002122], 242 | [-0.04588995, 0.03601861, 0.04492383], 243 | [-0.06583231, 0.02814404, 0.07405336], 244 | [-0.0349258, 0.02068955, 0.03345406], 245 | [-0.01255543, -0.01929206, 0.02168053], 246 | [-0.01126122, -0.00193263, 0.00936819], 247 | [-0.01443109, 0.00015909, 0.01570477], 248 | [-0.00923043, -0.00597823, 0.00472181], 249 | [-0.01755386, -0.04655982, 0.03147908], 250 | [-0.02632589, -0.04840693, 0.03110975], 251 | [-0.03388062, -0.04848841, 
0.03166576], 252 | [-0.04075766, -0.04802278, 0.03399541], 253 | [-0.0462291, -0.04601556, 0.03829283], 254 | [-0.05171755, -0.03662618, 0.04804737], 255 | [-0.07297331, -0.01890037, 0.07524373], 256 | [-0.04706828, -0.02777865, 0.04366072], 257 | [-0.04071712, -0.02603686, 0.0399866], 258 | [-0.03269817, -0.02597524, 0.03743659], 259 | [-0.02527572, -0.02744176, 0.0361016], 260 | [-0.01970894, -0.0298537, 0.03513822], 261 | [-0.01579543, -0.03224806, 0.03390608], 262 | [-0.07664182, -0.01799997, 0.09911471], 263 | [-0.01397041, 0.00213274, 0.01845226], 264 | [-0.00884838, -0.01785605, 0.01242372], 265 | [-0.00767097, -0.0015883, 0.00397672], 266 | [-0.00460213, 0.00207241, 0.00688157], 267 | [-0.00748618, -0.00058871, 0.00677301], 268 | [-0.01236408, 0.00458703, 0.01995114], 269 | [-0.00387306, 0.00283125, 0.00517899], 270 | [-0.00319925, 0.00481066, 0.00966928], 271 | [-0.01639633, -0.03683163, 0.03611868], 272 | [-0.01255645, -0.03594009, 0.03271804], 273 | [-0.01031362, -0.03509528, 0.02859755], 274 | [-0.04253081, -0.03899161, 0.04160299], 275 | [-0.0453, -0.04036865, 0.04135919], 276 | [0.00463928, -0.02082222, 0.00842021], 277 | [0.04253081, -0.03704511, 0.04195902], 278 | [0.00416106, 0.00339584, 0.01027947], 279 | [0.0708796, -0.06561666, 0.07375984], 280 | [0.02628639, -0.03162763, 0.03627483], 281 | [0.03198363, -0.0311268, 0.03678652], 282 | [0.03775151, -0.03166267, 0.0382941], 283 | [0.04465819, -0.03549815, 0.04320436], 284 | [0.02164289, -0.03316732, 0.03623782], 285 | [0.03208229, -0.04350791, 0.03359782], 286 | [0.02673803, -0.04332202, 0.03383401], 287 | [0.03745193, -0.04292151, 0.03503195], 288 | [0.04161018, -0.04185934, 0.0375605], 289 | [0.05062006, -0.03061283, 0.04699511], 290 | [0.02266659, 0.06298903, 0.03085792], 291 | [0.04445859, -0.03790856, 0.04302182], 292 | [0.0721453, -0.03389874, 0.07402454], 293 | [0.05799793, -0.03476411, 0.05271545], 294 | [0.02844939, -0.00405997, 0.03042474], 295 | [0.00711452, 0.0220249, 0.0159856], 296 | [0.00606033, 0.02797697, 0.02030681], 297 | [0.01431615, 0.02374088, 0.01979415], 298 | [0.0191491, 0.02676281, 0.02446674], 299 | [0.01131043, 0.02847072, 0.02285956], 300 | [0.01563548, 0.02955898, 0.02633341], 301 | [0.02650112, 0.03876784, 0.03287121], 302 | [0.00427049, -0.00032731, 0.00115075], 303 | [0.00496396, -0.00651206, 0.00035246], 304 | [0.05253307, -0.05008447, 0.04112445], 305 | [0.01718698, -0.02101474, 0.02917245], 306 | [0.01608635, -0.00184349, 0.01661411], 307 | [0.01651267, -0.00515997, 0.01894285], 308 | [0.04765501, -0.00425311, 0.03940972], 309 | [0.00478306, -0.01422631, 0.00374591], 310 | [0.03734964, -0.05635095, 0.0292515], 311 | [0.04588603, -0.05428902, 0.0342712], 312 | [0.06279331, -0.07742292, 0.06049754], 313 | [0.01220941, -0.0526903, 0.02369569], 314 | [0.02193489, -0.04227182, 0.03475029], 315 | [0.03102642, 0.03226119, 0.03379699], 316 | [0.06719682, 0.0366178, 0.09221005], 317 | [0.01193824, 0.0017993, 0.01737857], 318 | [0.00729766, 0.00466847, 0.01642396], 319 | [0.02456206, 0.03215756, 0.0319172], 320 | [0.02204823, 0.03177643, 0.03313105], 321 | [0.04985894, -0.05929326, 0.03723627], 322 | [0.01592294, 0.00130844, 0.02018655], 323 | [0.02644548, -0.05651519, 0.02554045], 324 | [0.02760292, -0.06227836, 0.02459614], 325 | [0.03523964, -0.09132841, 0.03746441], 326 | [0.05599763, -0.06842335, 0.04751345], 327 | [0.03063932, -0.07693009, 0.02945623], 328 | [0.05720968, -0.05381449, 0.04644752], 329 | [0.06374393, -0.05912455, 0.05883913], 330 | [0.00672728, 0.02561151, 0.017378], 331 | 
[0.0126256, 0.02660826, 0.02057825], 332 | [0.01732553, 0.02825902, 0.02475025], 333 | [0.01043625, 0.00338108, 0.01813149], 334 | [0.02321234, 0.03202204, 0.03217448], 335 | [0.02056846, 0.03350806, 0.02954721], 336 | [0.02153084, 0.03149457, 0.03437511], 337 | [0.00946874, -0.00091616, 0.0096333], 338 | [0.01469132, 0.02909486, 0.02870696], 339 | [0.0102434, 0.02862986, 0.02548911], 340 | [0.00533422, 0.02866357, 0.02337402], 341 | [0.0076972, 0.04968529, 0.02489721], 342 | [0.00699606, 0.04164985, 0.020273], 343 | [0.00669687, 0.03822905, 0.01965992], 344 | [0.00630947, 0.03568236, 0.02026233], 345 | [0.00583218, 0.03391117, 0.02135735], 346 | [0.0153717, 0.03296341, 0.02730134], 347 | [0.016156, 0.03349077, 0.02661972], 348 | [0.01729053, 0.03491815, 0.02621141], 349 | [0.01838624, 0.03701881, 0.02651867], 350 | [0.0236825, 0.01979372, 0.02607508], 351 | [0.07542244, -0.00077583, 0.09906925], 352 | [0.01826614, 0.03272666, 0.03076583], 353 | [0.01929558, 0.03284966, 0.02978552], 354 | [0.00597442, 0.00886821, 0.01609148], 355 | [0.01405627, 0.00587331, 0.02234517], 356 | [0.00662449, 0.00692456, 0.01611845], 357 | [0.0234234, -0.01699087, 0.03181301], 358 | [0.03327324, -0.01231728, 0.03361744], 359 | [0.01726175, -0.002077, 0.02202249], 360 | [0.05133204, -0.08612467, 0.04815162], 361 | [0.04538641, -0.07446772, 0.0379218], 362 | [0.03986562, -0.06236352, 0.03009289], 363 | [0.02169681, 0.04313568, 0.0301973], 364 | [0.01395634, -0.06138828, 0.02159572], 365 | [0.016195, -0.07726082, 0.02554498], 366 | [0.01891399, -0.09363242, 0.03200607], 367 | [0.04195832, -0.0336207, 0.04100505], 368 | [0.05733342, -0.02538603, 0.05043878], 369 | [0.01859887, -0.03482622, 0.03632423], 370 | [0.04988612, -0.04201519, 0.04391746], 371 | [0.01303263, -0.02543318, 0.02644513], 372 | [0.01305757, -0.00454086, 0.01059645], 373 | [0.0646517, -0.02063984, 0.05785731], 374 | [0.05258659, -0.02072676, 0.04501292], 375 | [0.04432338, -0.01848961, 0.03952989], 376 | [0.03300681, -0.01988506, 0.0360282], 377 | [0.02430178, -0.02258357, 0.03436569], 378 | [0.01820731, -0.02594819, 0.0325148], 379 | [0.00563221, -0.03434558, 0.01908815], 380 | [0.06338145, -0.00597586, 0.05594429], 381 | [0.05587698, -0.04334936, 0.04787765], 382 | [0.00242624, 0.00335992, 0.00404113], 383 | [0.01611251, -0.01466191, 0.02580183], 384 | [0.07743095, -0.03491864, 0.09480771], 385 | [0.01391142, -0.02977913, 0.03026605], 386 | [0.01785794, -0.00148581, 0.02625134], 387 | [0.04670959, -0.03791326, 0.04391529], 388 | [0.0133397, -0.00843104, 0.01378557], 389 | [0.07270895, 0.01764052, 0.09728059], 390 | [0.01856432, -0.0371211, 0.037177], 391 | [0.00923388, -0.01199941, 0.0080366], 392 | [0.05000589, 0.05008263, 0.05583081], 393 | [0.05085276, 0.06051725, 0.06760893], 394 | [0.07159291, -0.00315045, 0.07547648], 395 | [0.05843051, 0.04121158, 0.06551513], 396 | [0.06847258, -0.04789781, 0.06750909], 397 | [0.02412942, 0.07131988, 0.03356391], 398 | [0.00179909, 0.00562999, 0.00902303], 399 | [0.02103655, -0.00962919, 0.02909485], 400 | [0.06407571, -0.03362886, 0.05914761], 401 | [0.03670075, -0.03487018, 0.03840374], 402 | [0.03177186, -0.0342113, 0.036999], 403 | [0.02196121, 0.03471457, 0.02995818], 404 | [0.06234883, 0.00817565, 0.05812062], 405 | [0.01292924, 0.08169055, 0.03381541], 406 | [0.03210651, 0.07406413, 0.04673603], 407 | [0.04068926, 0.06866244, 0.05550485], 408 | [0.02724032, -0.03442667, 0.03698453], 409 | [0.0228846, -0.03525756, 0.03778001], 410 | [0.01998311, -0.03623412, 0.03786456], 411 | [0.0613004, 
-0.04526126, 0.05437088], 412 | [0.0228846, -0.04013369, 0.03700573], 413 | [0.02724032, -0.04088675, 0.03603837], 414 | [0.03177186, -0.04091001, 0.03598631], 415 | [0.03670075, -0.04054579, 0.03751279], 416 | [0.04018389, -0.03984222, 0.03992621], 417 | [0.07555811, -0.05233676, 0.08467521], 418 | [0.04018389, -0.0361056, 0.04034706], 419 | [0.01776217, 0.01557081, 0.02262488], 420 | [0.01222237, 0.00055579, 0.01523139], 421 | [0.00731493, 0.01409818, 0.01660261], 422 | [0.04135272, 0.05869773, 0.04803634], 423 | [0.03311811, 0.0653395, 0.04092641], 424 | [0.01313701, 0.0751313, 0.02773148], 425 | [0.05940524, 0.05096764, 0.08107072], 426 | [0.01998311, -0.03870703, 0.03731574], 427 | [0.00901447, -0.02363857, 0.01721348], 428 | [0.02308977, 0.07847331, 0.03866534], 429 | [0.06954154, 0.01312978, 0.07606767], 430 | [0.01098819, 0.03331923, 0.02354877], 431 | [0.01181124, 0.03453131, 0.0228604], 432 | [0.01255818, 0.03661036, 0.02238553], 433 | [0.01325085, 0.03979642, 0.02270594], 434 | [0.01546388, 0.04692527, 0.02717711], 435 | [0.01953754, 0.03057027, 0.03043891], 436 | [0.02117802, 0.03010228, 0.02920508], 437 | [0.02285339, 0.02924331, 0.02893166], 438 | [0.0285016, 0.02538855, 0.0299061], 439 | [0.05278538, 0.01112077, 0.0461438], 440 | [0.00946709, -0.03034493, 0.02278825], 441 | [0.01314173, -0.04231777, 0.032442], 442 | [0.0178, -0.03986865, 0.03594049], 443 | [0.0184511, 0.02972015, 0.0322834], 444 | [0.05436187, 0.02903617, 0.05365752], 445 | [0.00766444, -0.04308996, 0.02614151], 446 | [0.01938616, 0.05487545, 0.02954519], 447 | [0.00516573, -0.02710437, 0.01327241], 448 | [0.01246815, -0.01357162, 0.01794568], 449 | [0.00997827, 0.05804056, 0.02496028], 450 | [0.03288807, 0.04255649, 0.03679852], 451 | [0.02311631, 0.00439372, 0.02885519], 452 | [0.0268025, 0.04984702, 0.03379452], 453 | [0.03832928, 0.00410461, 0.03337873], 454 | [0.0296186, 0.0114735, 0.03034661], 455 | [0.04386901, 0.01556421, 0.03831718], 456 | [0.01217295, 0.067076, 0.02506318], 457 | [0.01542374, -0.00990022, 0.02274596], 458 | [0.03878377, 0.04914899, 0.04164525], 459 | [0.03084037, 0.05682977, 0.03661409], 460 | [0.03747321, 0.0337668, 0.03749151], 461 | [0.06094129, 0.02079126, 0.06002122], 462 | [0.04588995, 0.03601861, 0.04492383], 463 | [0.06583231, 0.02814404, 0.07405336], 464 | [0.0349258, 0.02068955, 0.03345406], 465 | [0.01255543, -0.01929206, 0.02168053], 466 | [0.01126122, -0.00193263, 0.00936819], 467 | [0.01443109, 0.00015909, 0.01570477], 468 | [0.00923043, -0.00597823, 0.00472181], 469 | [0.01755386, -0.04655982, 0.03147908], 470 | [0.02632589, -0.04840693, 0.03110975], 471 | [0.03388062, -0.04848841, 0.03166576], 472 | [0.04075766, -0.04802278, 0.03399541], 473 | [0.0462291, -0.04601556, 0.03829283], 474 | [0.05171755, -0.03662618, 0.04804737], 475 | [0.07297331, -0.01890037, 0.07524373], 476 | [0.04706828, -0.02777865, 0.04366072], 477 | [0.04071712, -0.02603686, 0.0399866], 478 | [0.03269817, -0.02597524, 0.03743659], 479 | [0.02527572, -0.02744176, 0.0361016], 480 | [0.01970894, -0.0298537, 0.03513822], 481 | [0.01579543, -0.03224806, 0.03390608], 482 | [0.07664182, -0.01799997, 0.09911471], 483 | [0.01397041, 0.00213274, 0.01845226], 484 | [0.00884838, -0.01785605, 0.01242372], 485 | [0.00767097, -0.0015883, 0.00397672], 486 | [0.00460213, 0.00207241, 0.00688157], 487 | [0.00748618, -0.00058871, 0.00677301], 488 | [0.01236408, 0.00458703, 0.01995114], 489 | [0.00387306, 0.00283125, 0.00517899], 490 | [0.00319925, 0.00481066, 0.00966928], 491 | [0.01639633, -0.03683163, 0.03611868], 
        [0.01255645, -0.03594009, 0.03271804],
        [0.01031362, -0.03509528, 0.02859755],
        [0.04253081, -0.03899161, 0.04160299],
        [0.0453, -0.04036865, 0.04135919],
    ],
        dtype=np.float64)

    REYE_INDICES: np.ndarray = np.array([33, 133])
    LEYE_INDICES: np.ndarray = np.array([362, 263])
    MOUTH_INDICES: np.ndarray = np.array([78, 308])
    NOSE_INDICES: np.ndarray = np.array([240, 460])

    CHIN_INDEX: int = 199
    NOSE_INDEX: int = 1
--------------------------------------------------------------------------------
/ptgaze/common/face_parts.py:
--------------------------------------------------------------------------------
import enum
from typing import Optional

import numpy as np
from scipy.spatial.transform import Rotation


class FacePartsName(enum.Enum):
    FACE = enum.auto()
    REYE = enum.auto()
    LEYE = enum.auto()


class FaceParts:
    def __init__(self, name: FacePartsName):
        self.name = name
        self.center: Optional[np.ndarray] = None
        self.head_pose_rot: Optional[Rotation] = None
        self.normalizing_rot: Optional[Rotation] = None
        self.normalized_head_rot2d: Optional[np.ndarray] = None
        self.normalized_image: Optional[np.ndarray] = None

        self.normalized_gaze_angles: Optional[np.ndarray] = None
        self.normalized_gaze_vector: Optional[np.ndarray] = None
        self.gaze_vector: Optional[np.ndarray] = None

    @property
    def distance(self) -> float:
        return np.linalg.norm(self.center)

    def angle_to_vector(self) -> None:
        pitch, yaw = self.normalized_gaze_angles
        self.normalized_gaze_vector = -np.array([
            np.cos(pitch) * np.sin(yaw),
            np.sin(pitch),
            np.cos(pitch) * np.cos(yaw)
        ])

    def denormalize_gaze_vector(self) -> None:
        normalizing_rot = self.normalizing_rot.as_matrix()
        # Here the gaze vector is a row vector, and rotation matrices are
        # orthogonal, so multiplying the rotation matrix from the right is
        # the same as multiplying the inverse of the rotation matrix to the
        # column gaze vector from the left.
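        # In symbols: gaze = gaze_n @ R is the row-vector form of
        # gaze = R.T @ gaze_n (column form), where R is the normalizing
        # rotation, so no explicit inverse is needed.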
        self.gaze_vector = self.normalized_gaze_vector @ normalizing_rot

    @staticmethod
    def vector_to_angle(vector: np.ndarray) -> np.ndarray:
        assert vector.shape == (3, )
        x, y, z = vector
        pitch = np.arcsin(-y)
        yaw = np.arctan2(-x, -z)
        return np.array([pitch, yaw])
--------------------------------------------------------------------------------
/ptgaze/common/visualizer.py:
--------------------------------------------------------------------------------
from typing import Optional, Tuple

import cv2
import numpy as np
from scipy.spatial.transform import Rotation

from .camera import Camera
from .face import Face

AXIS_COLORS = [(0, 0, 255), (0, 255, 0), (255, 0, 0)]


class Visualizer:
    def __init__(self, camera: Camera, center_point_index: int):
        self._camera = camera
        self._center_point_index = center_point_index
        self.image: Optional[np.ndarray] = None

    def set_image(self, image: np.ndarray) -> None:
        self.image = image

    def draw_bbox(self,
                  bbox: np.ndarray,
                  color: Tuple[int, int, int] = (0, 255, 0),
                  lw: int = 1) -> None:
        assert self.image is not None
        assert bbox.shape == (2, 2)
        bbox = np.round(bbox).astype(int).tolist()
        cv2.rectangle(self.image, tuple(bbox[0]), tuple(bbox[1]), color, lw)

    @staticmethod
    def _convert_pt(point: np.ndarray) -> Tuple[int, int]:
        return tuple(np.round(point).astype(int).tolist())

    def draw_points(self,
                    points: np.ndarray,
                    color: Tuple[int, int, int] = (0, 0, 255),
                    size: int = 3) -> None:
        assert self.image is not None
        assert points.shape[1] == 2
        for pt in points:
            pt = self._convert_pt(pt)
            cv2.circle(self.image, pt, size, color, cv2.FILLED)

    def draw_3d_points(self,
                       points3d: np.ndarray,
                       color: Tuple[int, int, int] = (255, 0, 255),
                       size=3) -> None:
        assert self.image is not None
        assert points3d.shape[1] == 3
        points2d = self._camera.project_points(points3d)
        self.draw_points(points2d, color=color, size=size)

    def draw_3d_line(self,
                     point0: np.ndarray,
                     point1: np.ndarray,
                     color: Tuple[int, int, int] = (255, 255, 0),
                     lw=1) -> None:
        assert self.image is not None
        assert point0.shape == point1.shape == (3, )
        points3d = np.vstack([point0, point1])
        points2d = self._camera.project_points(points3d)
        pt0 = self._convert_pt(points2d[0])
        pt1 = self._convert_pt(points2d[1])
        cv2.line(self.image, pt0, pt1, color, lw, cv2.LINE_AA)

    def draw_model_axes(self, face: Face, length: float, lw: int = 2) -> None:
        assert self.image is not None
        assert face is not None
        assert face.head_pose_rot is not None
        assert face.head_position is not None
        assert face.landmarks is not None
        # Get the axes of the model coordinate system
        axes3d = np.eye(3, dtype=np.float64) @ Rotation.from_euler(
            'XYZ', [0, np.pi, 0]).as_matrix()
        axes3d = axes3d * length
        axes2d = self._camera.project_points(axes3d,
                                             face.head_pose_rot.as_rotvec(),
                                             face.head_position)
        center = face.landmarks[self._center_point_index]
        center = self._convert_pt(center)
        for pt, color in zip(axes2d, AXIS_COLORS):
            pt = self._convert_pt(pt)
            cv2.line(self.image, center, pt, color, lw, cv2.LINE_AA)
--------------------------------------------------------------------------------
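The sample calibration file that follows uses the YAML layout read by `Camera.__post_init__` above. As a rough sketch (not part of this repo; the output file name and the pinhole focal-length guess are assumptions), a compatible file for your own camera could be generated like this:

```python
# Hypothetical helper: write a ptgaze-compatible calibration file.
# The focal length below is a crude pinhole guess (focal ~ image width);
# for accurate results, use values from a real calibration, e.g. one
# obtained with cv2.calibrateCamera.
import yaml

width, height = 1280, 720
focal = float(width)
params = {
    'image_width': width,
    'image_height': height,
    'camera_matrix': {
        'rows': 3,
        'cols': 3,
        # Row-major 3x3 intrinsic matrix, principal point at the center.
        'data': [focal, 0., width / 2., 0., focal, height / 2., 0., 0., 1.],
    },
    'distortion_coefficients': {
        'rows': 1,
        'cols': 5,
        'data': [0., 0., 0., 0., 0.],  # assume no lens distortion
    },
}
with open('my_camera_params.yaml', 'w') as f:
    yaml.safe_dump(params, f)
```
--------------------------------------------------------------------------------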
/ptgaze/data/calib/sample_params.yaml: -------------------------------------------------------------------------------- 1 | image_width: 640 2 | image_height: 480 3 | camera_matrix: 4 | rows: 3 5 | cols: 3 6 | data: [640., 0., 320., 7 | 0., 640., 240., 8 | 0., 0., 1.] 9 | distortion_coefficients: 10 | rows: 1 11 | cols: 5 12 | data: [0., 0., 0., 0., 0.] 13 | -------------------------------------------------------------------------------- /ptgaze/data/configs/eth-xgaze.yaml: -------------------------------------------------------------------------------- 1 | mode: ETH-XGaze 2 | device: cpu 3 | model: 4 | name: resnet18 5 | face_detector: 6 | mode: mediapipe 7 | dlib_model_path: ~/.ptgaze/dlib/shape_predictor_68_face_landmarks.dat 8 | mediapipe_max_num_faces: 3 9 | mediapipe_static_image_mode: false 10 | gaze_estimator: 11 | checkpoint: ~/.ptgaze/models/eth-xgaze_resnet18.pth 12 | camera_params: ${PACKAGE_ROOT}/data/calib/sample_params.yaml 13 | use_dummy_camera_params: false 14 | normalized_camera_params: ${PACKAGE_ROOT}/data/normalized_camera_params/eth-xgaze.yaml 15 | normalized_camera_distance: 0.6 16 | image_size: [224, 224] 17 | demo: 18 | use_camera: true 19 | display_on_screen: true 20 | wait_time: 1 21 | image_path: null 22 | video_path: null 23 | output_dir: null 24 | output_file_extension: avi 25 | head_pose_axis_length: 0.05 26 | gaze_visualization_length: 0.05 27 | show_bbox: true 28 | show_head_pose: false 29 | show_landmarks: false 30 | show_normalized_image: false 31 | show_template_model: false 32 | -------------------------------------------------------------------------------- /ptgaze/data/configs/mpiifacegaze.yaml: -------------------------------------------------------------------------------- 1 | mode: MPIIFaceGaze 2 | device: cpu 3 | model: 4 | name: resnet_simple 5 | backbone: 6 | name: resnet_simple 7 | pretrained: resnet18 8 | resnet_block: basic 9 | resnet_layers: [2, 2, 2] 10 | face_detector: 11 | mode: dlib 12 | dlib_model_path: ~/.ptgaze/dlib/shape_predictor_68_face_landmarks.dat 13 | mediapipe_max_num_faces: 3 14 | mediapipe_static_image_mode: false 15 | gaze_estimator: 16 | checkpoint: ~/.ptgaze/models/mpiifacegaze_resnet_simple.pth 17 | camera_params: ${PACKAGE_ROOT}/data/calib/sample_params.yaml 18 | use_dummy_camera_params: false 19 | normalized_camera_params: ${PACKAGE_ROOT}/data/normalized_camera_params/mpiifacegaze.yaml 20 | normalized_camera_distance: 1.0 21 | image_size: [224, 224] 22 | demo: 23 | use_camera: true 24 | display_on_screen: true 25 | wait_time: 1 26 | image_path: null 27 | video_path: null 28 | output_dir: null 29 | output_file_extension: avi 30 | head_pose_axis_length: 0.05 31 | gaze_visualization_length: 0.05 32 | show_bbox: true 33 | show_head_pose: false 34 | show_landmarks: false 35 | show_normalized_image: false 36 | show_template_model: false 37 | -------------------------------------------------------------------------------- /ptgaze/data/configs/mpiigaze.yaml: -------------------------------------------------------------------------------- 1 | mode: MPIIGaze 2 | device: cpu 3 | model: 4 | name: resnet_preact 5 | face_detector: 6 | mode: dlib 7 | dlib_model_path: ~/.ptgaze/dlib/shape_predictor_68_face_landmarks.dat 8 | mediapipe_max_num_faces: 3 9 | mediapipe_static_image_mode: false 10 | gaze_estimator: 11 | checkpoint: ~/.ptgaze/models/mpiigaze_resnet_preact.pth 12 | camera_params: ${PACKAGE_ROOT}/data/calib/sample_params.yaml 13 | use_dummy_camera_params: false 14 | normalized_camera_params: 
${PACKAGE_ROOT}/data/normalized_camera_params/mpiigaze.yaml
15 |   normalized_camera_distance: 0.6
16 | demo:
17 |   use_camera: true
18 |   display_on_screen: true
19 |   wait_time: 1
20 |   image_path: null
21 |   video_path: null
22 |   output_dir: null
23 |   output_file_extension: avi
24 |   head_pose_axis_length: 0.05
25 |   gaze_visualization_length: 0.05
26 |   show_bbox: true
27 |   show_head_pose: false
28 |   show_landmarks: false
29 |   show_normalized_image: false
30 |   show_template_model: false
31 | 
--------------------------------------------------------------------------------
/ptgaze/data/normalized_camera_params/eth-xgaze.yaml:
--------------------------------------------------------------------------------
1 | image_width: 224
2 | image_height: 224
3 | camera_matrix:
4 |   rows: 3
5 |   cols: 3
6 |   data: [960., 0., 112.,
7 |          0., 960., 112.,
8 |          0., 0., 1.]
9 | distortion_coefficients:
10 |   rows: 1
11 |   cols: 5
12 |   data: [0., 0., 0., 0., 0.]
13 | 
--------------------------------------------------------------------------------
/ptgaze/data/normalized_camera_params/mpiifacegaze.yaml:
--------------------------------------------------------------------------------
1 | image_width: 224
2 | image_height: 224
3 | camera_matrix:
4 |   rows: 3
5 |   cols: 3
6 |   data: [1600., 0., 112.,
7 |          0., 1600., 112.,
8 |          0., 0., 1.]
9 | distortion_coefficients:
10 |   rows: 1
11 |   cols: 5
12 |   data: [0., 0., 0., 0., 0.]
13 | 
--------------------------------------------------------------------------------
/ptgaze/data/normalized_camera_params/mpiigaze.yaml:
--------------------------------------------------------------------------------
1 | image_width: 60
2 | image_height: 36
3 | camera_matrix:
4 |   rows: 3
5 |   cols: 3
6 |   data: [960., 0., 30.,
7 |          0., 960., 18.,
8 |          0., 0., 1.]
9 | distortion_coefficients:
10 |   rows: 1
11 |   cols: 5
12 |   data: [0., 0., 0., 0., 0.]
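These normalized-camera files share the exact schema of the calibration file shown earlier, so any of them can be read with plain PyYAML. A minimal loader sketch (the repository's own `Camera` class lives in `ptgaze/common/camera.py` and is not reproduced here; the file path below is just an example):

import numpy as np
import yaml

with open('ptgaze/data/normalized_camera_params/mpiigaze.yaml') as f:
    params = yaml.safe_load(f)

camera_matrix = np.array(params['camera_matrix']['data'],
                         dtype=np.float64).reshape(
                             params['camera_matrix']['rows'],
                             params['camera_matrix']['cols'])
dist_coefficients = np.array(params['distortion_coefficients']['data'],
                             dtype=np.float64)

# For the MPIIGaze eye model: a 60x36 patch with fx = fy = 960 and the
# principal point at the patch center.
print(params['image_width'], params['image_height'])
print(camera_matrix)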
13 | -------------------------------------------------------------------------------- /ptgaze/demo.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import pathlib 4 | from typing import Optional 5 | 6 | import cv2 7 | import numpy as np 8 | from omegaconf import DictConfig 9 | 10 | from .common import Face, FacePartsName, Visualizer 11 | from .gaze_estimator import GazeEstimator 12 | from .utils import get_3d_face_model 13 | 14 | logging.basicConfig(level=logging.INFO) 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class Demo: 19 | QUIT_KEYS = {27, ord('q')} 20 | 21 | def __init__(self, config: DictConfig): 22 | self.config = config 23 | self.gaze_estimator = GazeEstimator(config) 24 | face_model_3d = get_3d_face_model(config) 25 | self.visualizer = Visualizer(self.gaze_estimator.camera, 26 | face_model_3d.NOSE_INDEX) 27 | 28 | self.cap = self._create_capture() 29 | self.output_dir = self._create_output_dir() 30 | self.writer = self._create_video_writer() 31 | 32 | self.stop = False 33 | self.show_bbox = self.config.demo.show_bbox 34 | self.show_head_pose = self.config.demo.show_head_pose 35 | self.show_landmarks = self.config.demo.show_landmarks 36 | self.show_normalized_image = self.config.demo.show_normalized_image 37 | self.show_template_model = self.config.demo.show_template_model 38 | 39 | def run(self) -> None: 40 | if self.config.demo.use_camera or self.config.demo.video_path: 41 | self._run_on_video() 42 | elif self.config.demo.image_path: 43 | self._run_on_image() 44 | else: 45 | raise ValueError 46 | 47 | def _run_on_image(self): 48 | image = cv2.imread(self.config.demo.image_path) 49 | self._process_image(image) 50 | if self.config.demo.display_on_screen: 51 | while True: 52 | key_pressed = self._wait_key() 53 | if self.stop: 54 | break 55 | if key_pressed: 56 | self._process_image(image) 57 | cv2.imshow('image', self.visualizer.image) 58 | if self.config.demo.output_dir: 59 | name = pathlib.Path(self.config.demo.image_path).name 60 | output_path = pathlib.Path(self.config.demo.output_dir) / name 61 | cv2.imwrite(output_path.as_posix(), self.visualizer.image) 62 | 63 | def _run_on_video(self) -> None: 64 | while True: 65 | if self.config.demo.display_on_screen: 66 | self._wait_key() 67 | if self.stop: 68 | break 69 | 70 | ok, frame = self.cap.read() 71 | if not ok: 72 | break 73 | self._process_image(frame) 74 | 75 | if self.config.demo.display_on_screen: 76 | cv2.imshow('frame', self.visualizer.image) 77 | self.cap.release() 78 | if self.writer: 79 | self.writer.release() 80 | 81 | def _process_image(self, image) -> None: 82 | undistorted = cv2.undistort( 83 | image, self.gaze_estimator.camera.camera_matrix, 84 | self.gaze_estimator.camera.dist_coefficients) 85 | 86 | self.visualizer.set_image(image.copy()) 87 | faces = self.gaze_estimator.detect_faces(undistorted) 88 | for face in faces: 89 | self.gaze_estimator.estimate_gaze(undistorted, face) 90 | self._draw_face_bbox(face) 91 | self._draw_head_pose(face) 92 | self._draw_landmarks(face) 93 | self._draw_face_template_model(face) 94 | self._draw_gaze_vector(face) 95 | self._display_normalized_image(face) 96 | 97 | if self.config.demo.use_camera: 98 | self.visualizer.image = self.visualizer.image[:, ::-1] 99 | if self.writer: 100 | self.writer.write(self.visualizer.image) 101 | 102 | def _create_capture(self) -> Optional[cv2.VideoCapture]: 103 | if self.config.demo.image_path: 104 | return None 105 | if self.config.demo.use_camera: 106 | 
cap = cv2.VideoCapture(0)
107 |         elif self.config.demo.video_path:
108 |             cap = cv2.VideoCapture(self.config.demo.video_path)
109 |         else:
110 |             raise ValueError
111 |         cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.gaze_estimator.camera.width)
112 |         cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.gaze_estimator.camera.height)
113 |         return cap
114 | 
115 |     def _create_output_dir(self) -> Optional[pathlib.Path]:
116 |         if not self.config.demo.output_dir:
117 |             return
118 |         output_dir = pathlib.Path(self.config.demo.output_dir)
119 |         output_dir.mkdir(exist_ok=True, parents=True)
120 |         return output_dir
121 | 
122 |     @staticmethod
123 |     def _create_timestamp() -> str:
124 |         dt = datetime.datetime.now()
125 |         return dt.strftime('%Y%m%d_%H%M%S')
126 | 
127 |     def _create_video_writer(self) -> Optional[cv2.VideoWriter]:
128 |         if self.config.demo.image_path:
129 |             return None
130 |         if not self.output_dir:
131 |             return None
132 |         ext = self.config.demo.output_file_extension
133 |         if ext == 'mp4':
134 |             fourcc = cv2.VideoWriter_fourcc(*'H264')
135 |         elif ext == 'avi':
136 |             fourcc = cv2.VideoWriter_fourcc(*'PIM1')
137 |         else:
138 |             raise ValueError
139 |         if self.config.demo.use_camera:
140 |             output_name = f'{self._create_timestamp()}.{ext}'
141 |         elif self.config.demo.video_path:
142 |             name = pathlib.Path(self.config.demo.video_path).stem
143 |             output_name = f'{name}.{ext}'
144 |         else:
145 |             raise ValueError
146 |         output_path = self.output_dir / output_name
147 |         writer = cv2.VideoWriter(output_path.as_posix(), fourcc, 30,
148 |                                  (self.gaze_estimator.camera.width,
149 |                                   self.gaze_estimator.camera.height))
150 |         if writer is None:
151 |             raise RuntimeError
152 |         return writer
153 | 
154 |     def _wait_key(self) -> bool:
155 |         key = cv2.waitKey(self.config.demo.wait_time) & 0xff
156 |         if key in self.QUIT_KEYS:
157 |             self.stop = True
158 |         elif key == ord('b'):
159 |             self.show_bbox = not self.show_bbox
160 |         elif key == ord('l'):
161 |             self.show_landmarks = not self.show_landmarks
162 |         elif key == ord('h'):
163 |             self.show_head_pose = not self.show_head_pose
164 |         elif key == ord('n'):
165 |             self.show_normalized_image = not self.show_normalized_image
166 |         elif key == ord('t'):
167 |             self.show_template_model = not self.show_template_model
168 |         else:
169 |             return False
170 |         return True
171 | 
172 |     def _draw_face_bbox(self, face: Face) -> None:
173 |         if not self.show_bbox:
174 |             return
175 |         self.visualizer.draw_bbox(face.bbox)
176 | 
177 |     def _draw_head_pose(self, face: Face) -> None:
178 |         if not self.show_head_pose:
179 |             return
180 |         # Draw the axes of the model coordinate system
181 |         length = self.config.demo.head_pose_axis_length
182 |         self.visualizer.draw_model_axes(face, length, lw=2)
183 | 
184 |         euler_angles = face.head_pose_rot.as_euler('XYZ', degrees=True)
185 |         pitch, yaw, roll = face.change_coordinate_system(euler_angles)
186 |         logger.info(f'[head] pitch: {pitch:.2f}, yaw: {yaw:.2f}, '
187 |                     f'roll: {roll:.2f}, distance: {face.distance:.2f}')
188 | 
189 |     def _draw_landmarks(self, face: Face) -> None:
190 |         if not self.show_landmarks:
191 |             return
192 |         self.visualizer.draw_points(face.landmarks,
193 |                                     color=(0, 255, 255),
194 |                                     size=1)
195 | 
196 |     def _draw_face_template_model(self, face: Face) -> None:
197 |         if not self.show_template_model:
198 |             return
199 |         self.visualizer.draw_3d_points(face.model3d,
200 |                                        color=(255, 0, 255),
201 |                                        size=1)
202 | 
203 |     def _display_normalized_image(self, face: Face) -> None:
204 |         if not self.config.demo.display_on_screen:
205 |             return
206 |         if not self.show_normalized_image:
207
| return 208 | if self.config.mode == 'MPIIGaze': 209 | reye = face.reye.normalized_image 210 | leye = face.leye.normalized_image 211 | normalized = np.hstack([reye, leye]) 212 | elif self.config.mode in ['MPIIFaceGaze', 'ETH-XGaze']: 213 | normalized = face.normalized_image 214 | else: 215 | raise ValueError 216 | if self.config.demo.use_camera: 217 | normalized = normalized[:, ::-1] 218 | cv2.imshow('normalized', normalized) 219 | 220 | def _draw_gaze_vector(self, face: Face) -> None: 221 | length = self.config.demo.gaze_visualization_length 222 | if self.config.mode == 'MPIIGaze': 223 | for key in [FacePartsName.REYE, FacePartsName.LEYE]: 224 | eye = getattr(face, key.name.lower()) 225 | self.visualizer.draw_3d_line( 226 | eye.center, eye.center + length * eye.gaze_vector) 227 | pitch, yaw = np.rad2deg(eye.vector_to_angle(eye.gaze_vector)) 228 | logger.info( 229 | f'[{key.name.lower()}] pitch: {pitch:.2f}, yaw: {yaw:.2f}') 230 | elif self.config.mode in ['MPIIFaceGaze', 'ETH-XGaze']: 231 | self.visualizer.draw_3d_line( 232 | face.center, face.center + length * face.gaze_vector) 233 | pitch, yaw = np.rad2deg(face.vector_to_angle(face.gaze_vector)) 234 | logger.info(f'[face] pitch: {pitch:.2f}, yaw: {yaw:.2f}') 235 | else: 236 | raise ValueError 237 | -------------------------------------------------------------------------------- /ptgaze/gaze_estimator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | 4 | import numpy as np 5 | import torch 6 | from omegaconf import DictConfig 7 | 8 | from .common import Camera, Face, FacePartsName 9 | from .head_pose_estimation import HeadPoseNormalizer, LandmarkEstimator 10 | from .models import create_model 11 | from .transforms import create_transform 12 | from .utils import get_3d_face_model 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class GazeEstimator: 18 | EYE_KEYS = [FacePartsName.REYE, FacePartsName.LEYE] 19 | 20 | def __init__(self, config: DictConfig): 21 | self._config = config 22 | 23 | self._face_model3d = get_3d_face_model(config) 24 | 25 | self.camera = Camera(config.gaze_estimator.camera_params) 26 | self._normalized_camera = Camera( 27 | config.gaze_estimator.normalized_camera_params) 28 | 29 | self._landmark_estimator = LandmarkEstimator(config) 30 | self._head_pose_normalizer = HeadPoseNormalizer( 31 | self.camera, self._normalized_camera, 32 | self._config.gaze_estimator.normalized_camera_distance) 33 | self._gaze_estimation_model = self._load_model() 34 | self._transform = create_transform(config) 35 | 36 | def _load_model(self) -> torch.nn.Module: 37 | model = create_model(self._config) 38 | checkpoint = torch.load(self._config.gaze_estimator.checkpoint, 39 | map_location='cpu') 40 | model.load_state_dict(checkpoint['model']) 41 | model.to(torch.device(self._config.device)) 42 | model.eval() 43 | return model 44 | 45 | def detect_faces(self, image: np.ndarray) -> List[Face]: 46 | return self._landmark_estimator.detect_faces(image) 47 | 48 | def estimate_gaze(self, image: np.ndarray, face: Face) -> None: 49 | self._face_model3d.estimate_head_pose(face, self.camera) 50 | self._face_model3d.compute_3d_pose(face) 51 | self._face_model3d.compute_face_eye_centers(face, self._config.mode) 52 | 53 | if self._config.mode == 'MPIIGaze': 54 | for key in self.EYE_KEYS: 55 | eye = getattr(face, key.name.lower()) 56 | self._head_pose_normalizer.normalize(image, eye) 57 | self._run_mpiigaze_model(face) 58 | elif self._config.mode == 
'MPIIFaceGaze': 59 | self._head_pose_normalizer.normalize(image, face) 60 | self._run_mpiifacegaze_model(face) 61 | elif self._config.mode == 'ETH-XGaze': 62 | self._head_pose_normalizer.normalize(image, face) 63 | self._run_ethxgaze_model(face) 64 | else: 65 | raise ValueError 66 | 67 | @torch.no_grad() 68 | def _run_mpiigaze_model(self, face: Face) -> None: 69 | images = [] 70 | head_poses = [] 71 | for key in self.EYE_KEYS: 72 | eye = getattr(face, key.name.lower()) 73 | image = eye.normalized_image 74 | normalized_head_pose = eye.normalized_head_rot2d 75 | if key == FacePartsName.REYE: 76 | image = image[:, ::-1].copy() 77 | normalized_head_pose *= np.array([1, -1]) 78 | image = self._transform(image) 79 | images.append(image) 80 | head_poses.append(normalized_head_pose) 81 | images = torch.stack(images) 82 | head_poses = np.array(head_poses).astype(np.float32) 83 | head_poses = torch.from_numpy(head_poses) 84 | 85 | device = torch.device(self._config.device) 86 | images = images.to(device) 87 | head_poses = head_poses.to(device) 88 | predictions = self._gaze_estimation_model(images, head_poses) 89 | predictions = predictions.cpu().numpy() 90 | 91 | for i, key in enumerate(self.EYE_KEYS): 92 | eye = getattr(face, key.name.lower()) 93 | eye.normalized_gaze_angles = predictions[i] 94 | if key == FacePartsName.REYE: 95 | eye.normalized_gaze_angles *= np.array([1, -1]) 96 | eye.angle_to_vector() 97 | eye.denormalize_gaze_vector() 98 | 99 | @torch.no_grad() 100 | def _run_mpiifacegaze_model(self, face: Face) -> None: 101 | image = self._transform(face.normalized_image).unsqueeze(0) 102 | 103 | device = torch.device(self._config.device) 104 | image = image.to(device) 105 | prediction = self._gaze_estimation_model(image) 106 | prediction = prediction.cpu().numpy() 107 | 108 | face.normalized_gaze_angles = prediction[0] 109 | face.angle_to_vector() 110 | face.denormalize_gaze_vector() 111 | 112 | @torch.no_grad() 113 | def _run_ethxgaze_model(self, face: Face) -> None: 114 | image = self._transform(face.normalized_image).unsqueeze(0) 115 | 116 | device = torch.device(self._config.device) 117 | image = image.to(device) 118 | prediction = self._gaze_estimation_model(image) 119 | prediction = prediction.cpu().numpy() 120 | 121 | face.normalized_gaze_angles = prediction[0] 122 | face.angle_to_vector() 123 | face.denormalize_gaze_vector() 124 | -------------------------------------------------------------------------------- /ptgaze/head_pose_estimation/__init__.py: -------------------------------------------------------------------------------- 1 | from .face_landmark_estimator import LandmarkEstimator 2 | from .head_pose_normalizer import HeadPoseNormalizer 3 | -------------------------------------------------------------------------------- /ptgaze/head_pose_estimation/face_landmark_estimator.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import dlib 4 | import face_alignment 5 | import face_alignment.detection.sfd 6 | import mediapipe 7 | import numpy as np 8 | from omegaconf import DictConfig 9 | 10 | from ..common import Face 11 | 12 | 13 | class LandmarkEstimator: 14 | def __init__(self, config: DictConfig): 15 | self.mode = config.face_detector.mode 16 | if self.mode == 'dlib': 17 | self.detector = dlib.get_frontal_face_detector() 18 | self.predictor = dlib.shape_predictor( 19 | config.face_detector.dlib_model_path) 20 | elif self.mode == 'face_alignment_dlib': 21 | self.detector = dlib.get_frontal_face_detector() 22 | 
self.predictor = face_alignment.FaceAlignment(
23 |                 face_alignment.LandmarksType._2D,
24 |                 face_detector='dlib',
25 |                 flip_input=False,
26 |                 device=config.device)
27 |         elif self.mode == 'face_alignment_sfd':
28 |             self.detector = face_alignment.detection.sfd.sfd_detector.SFDDetector(
29 |                 device=config.device)
30 |             self.predictor = face_alignment.FaceAlignment(
31 |                 face_alignment.LandmarksType._2D,
32 |                 flip_input=False,
33 |                 device=config.device)
34 |         elif self.mode == 'mediapipe':
35 |             self.detector = mediapipe.solutions.face_mesh.FaceMesh(
36 |                 max_num_faces=config.face_detector.mediapipe_max_num_faces,
37 |                 static_image_mode=config.face_detector.
38 |                 mediapipe_static_image_mode)
39 |         else:
40 |             raise ValueError
41 | 
42 |     def detect_faces(self, image: np.ndarray) -> List[Face]:
43 |         if self.mode == 'dlib':
44 |             return self._detect_faces_dlib(image)
45 |         elif self.mode == 'face_alignment_dlib':
46 |             return self._detect_faces_face_alignment_dlib(image)
47 |         elif self.mode == 'face_alignment_sfd':
48 |             return self._detect_faces_face_alignment_sfd(image)
49 |         elif self.mode == 'mediapipe':
50 |             return self._detect_faces_mediapipe(image)
51 |         else:
52 |             raise ValueError
53 | 
54 |     def _detect_faces_dlib(self, image: np.ndarray) -> List[Face]:
55 |         bboxes = self.detector(image[:, :, ::-1], 0)
56 |         detected = []
57 |         for bbox in bboxes:
58 |             predictions = self.predictor(image[:, :, ::-1], bbox)
59 |             landmarks = np.array([(pt.x, pt.y) for pt in predictions.parts()],
60 |                                  dtype=np.float64)
61 |             bbox = np.array([[bbox.left(), bbox.top()],
62 |                              [bbox.right(), bbox.bottom()]],
63 |                             dtype=np.float64)
64 |             detected.append(Face(bbox, landmarks))
65 |         return detected
66 | 
67 |     def _detect_faces_face_alignment_dlib(self,
68 |                                           image: np.ndarray) -> List[Face]:
69 |         bboxes = self.detector(image[:, :, ::-1], 0)
70 |         bboxes = [[bbox.left(),
71 |                    bbox.top(),
72 |                    bbox.right(),
73 |                    bbox.bottom()] for bbox in bboxes]
74 |         predictions = self.predictor.get_landmarks(image[:, :, ::-1],
75 |                                                    detected_faces=bboxes)
76 |         if predictions is None:
77 |             predictions = []
78 |         detected = []
79 |         for bbox, landmarks in zip(bboxes, predictions):
80 |             bbox = np.array(bbox, dtype=np.float64).reshape(2, 2)
81 |             detected.append(Face(bbox, landmarks))
82 |         return detected
83 | 
84 |     def _detect_faces_face_alignment_sfd(self,
85 |                                          image: np.ndarray) -> List[Face]:
86 |         bboxes = self.detector.detect_from_image(image[:, :, ::-1].copy())
87 |         bboxes = [bbox[:4] for bbox in bboxes]
88 |         predictions = self.predictor.get_landmarks(image[:, :, ::-1],
89 |                                                    detected_faces=bboxes)
90 |         if predictions is None:
91 |             predictions = []
92 |         detected = []
93 |         for bbox, landmarks in zip(bboxes, predictions):
94 |             bbox = np.array(bbox, dtype=np.float64).reshape(2, 2)
95 |             detected.append(Face(bbox, landmarks))
96 |         return detected
97 | 
98 |     def _detect_faces_mediapipe(self, image: np.ndarray) -> List[Face]:
99 |         h, w = image.shape[:2]
100 |         predictions = self.detector.process(image[:, :, ::-1])
101 |         detected = []
102 |         if predictions.multi_face_landmarks:
103 |             for prediction in predictions.multi_face_landmarks:
104 |                 pts = np.array([(pt.x * w, pt.y * h)
105 |                                 for pt in prediction.landmark],
106 |                                dtype=np.float64)
107 |                 bbox = np.vstack([pts.min(axis=0), pts.max(axis=0)])
108 |                 bbox = np.round(bbox).astype(np.int32)
109 |                 detected.append(Face(bbox, pts))
110 |         return detected
111 | 
--------------------------------------------------------------------------------
/ptgaze/head_pose_estimation/head_pose_normalizer.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from scipy.spatial.transform import Rotation
4 | 
5 | from ..common import Camera, FaceParts, FacePartsName
6 | 
7 | 
8 | def _normalize_vector(vector: np.ndarray) -> np.ndarray:
9 |     return vector / np.linalg.norm(vector)
10 | 
11 | 
12 | class HeadPoseNormalizer:
13 |     def __init__(self, camera: Camera, normalized_camera: Camera,
14 |                  normalized_distance: float):
15 |         self.camera = camera
16 |         self.normalized_camera = normalized_camera
17 |         self.normalized_distance = normalized_distance
18 | 
19 |     def normalize(self, image: np.ndarray, eye_or_face: FaceParts) -> None:
20 |         eye_or_face.normalizing_rot = self._compute_normalizing_rotation(
21 |             eye_or_face.center, eye_or_face.head_pose_rot)
22 |         self._normalize_image(image, eye_or_face)
23 |         self._normalize_head_pose(eye_or_face)
24 | 
25 |     def _normalize_image(self, image: np.ndarray,
26 |                          eye_or_face: FaceParts) -> None:
27 |         camera_matrix_inv = np.linalg.inv(self.camera.camera_matrix)
28 |         normalized_camera_matrix = self.normalized_camera.camera_matrix
29 | 
30 |         scale = self._get_scale_matrix(eye_or_face.distance)
31 |         conversion_matrix = scale @ eye_or_face.normalizing_rot.as_matrix()
32 | 
33 |         projection_matrix = normalized_camera_matrix @ conversion_matrix @ camera_matrix_inv
34 | 
35 |         normalized_image = cv2.warpPerspective(
36 |             image, projection_matrix,
37 |             (self.normalized_camera.width, self.normalized_camera.height))
38 | 
39 |         if eye_or_face.name in {FacePartsName.REYE, FacePartsName.LEYE}:
40 |             normalized_image = cv2.cvtColor(normalized_image,
41 |                                             cv2.COLOR_BGR2GRAY)
42 |             normalized_image = cv2.equalizeHist(normalized_image)
43 |         eye_or_face.normalized_image = normalized_image
44 | 
45 |     @staticmethod
46 |     def _normalize_head_pose(eye_or_face: FaceParts) -> None:
47 |         normalized_head_rot = eye_or_face.head_pose_rot * eye_or_face.normalizing_rot
48 |         euler_angles2d = normalized_head_rot.as_euler('XYZ')[:2]
49 |         eye_or_face.normalized_head_rot2d = euler_angles2d * np.array([1, -1])
50 | 
51 |     @staticmethod
52 |     def _compute_normalizing_rotation(center: np.ndarray,
53 |                                       head_rot: Rotation) -> Rotation:
54 |         # See section 4.2 and Figure 9 of https://arxiv.org/abs/1711.09017
55 |         z_axis = _normalize_vector(center.ravel())
56 |         head_rot = head_rot.as_matrix()
57 |         head_x_axis = head_rot[:, 0]
58 |         y_axis = _normalize_vector(np.cross(z_axis, head_x_axis))
59 |         x_axis = _normalize_vector(np.cross(y_axis, z_axis))
60 |         return Rotation.from_matrix(np.vstack([x_axis, y_axis, z_axis]))
61 | 
62 |     def _get_scale_matrix(self, distance: float) -> np.ndarray:
63 |         return np.array([
64 |             [1, 0, 0],
65 |             [0, 1, 0],
66 |             [0, 0, self.normalized_distance / distance],
67 |         ],
68 |                         dtype=np.float64)
69 | 
--------------------------------------------------------------------------------
/ptgaze/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import pathlib
4 | import warnings
5 | 
6 | import torch
7 | from omegaconf import DictConfig, OmegaConf
8 | 
9 | from .demo import Demo
10 | from .utils import (check_path_all, download_dlib_pretrained_model,
11 |                     download_ethxgaze_model, download_mpiifacegaze_model,
12 |                     download_mpiigaze_model, expanduser_all,
13 |                     generate_dummy_camera_params)
14 | 
15 | logger = logging.getLogger(__name__)
16 | 
17 | 
18 | def parse_args() -> argparse.Namespace:
19 |     parser = argparse.ArgumentParser()
20 |     parser.add_argument(
21 |
'--config', 22 | type=str, 23 | help='Config file. When using a config file, all the other ' 24 | 'commandline arguments are ignored. ' 25 | 'See https://github.com/hysts/pytorch_mpiigaze_demo/ptgaze/data/configs/eth-xgaze.yaml' 26 | ) 27 | parser.add_argument( 28 | '--mode', 29 | type=str, 30 | choices=['mpiigaze', 'mpiifacegaze', 'eth-xgaze'], 31 | help='With \'mpiigaze\', MPIIGaze model will be used. ' 32 | 'With \'mpiifacegaze\', MPIIFaceGaze model will be used. ' 33 | 'With \'eth-xgaze\', ETH-XGaze model will be used.') 34 | parser.add_argument( 35 | '--face-detector', 36 | type=str, 37 | default='mediapipe', 38 | choices=[ 39 | 'dlib', 'face_alignment_dlib', 'face_alignment_sfd', 'mediapipe' 40 | ], 41 | help='The method used to detect faces and find face landmarks ' 42 | '(default: \'mediapipe\')') 43 | parser.add_argument('--device', 44 | type=str, 45 | choices=['cpu', 'cuda'], 46 | help='Device used for model inference.') 47 | parser.add_argument('--image', 48 | type=str, 49 | help='Path to an input image file.') 50 | parser.add_argument('--video', 51 | type=str, 52 | help='Path to an input video file.') 53 | parser.add_argument( 54 | '--camera', 55 | type=str, 56 | help='Camera calibration file. ' 57 | 'See https://github.com/hysts/pytorch_mpiigaze_demo/ptgaze/data/calib/sample_params.yaml' 58 | ) 59 | parser.add_argument( 60 | '--output-dir', 61 | '-o', 62 | type=str, 63 | help='If specified, the overlaid video will be saved to this directory.' 64 | ) 65 | parser.add_argument('--ext', 66 | '-e', 67 | type=str, 68 | choices=['avi', 'mp4'], 69 | help='Output video file extension.') 70 | parser.add_argument( 71 | '--no-screen', 72 | action='store_true', 73 | help='If specified, the video is not displayed on screen, and saved ' 74 | 'to the output directory.') 75 | parser.add_argument('--debug', action='store_true') 76 | return parser.parse_args() 77 | 78 | 79 | def load_mode_config(args: argparse.Namespace) -> DictConfig: 80 | package_root = pathlib.Path(__file__).parent.resolve() 81 | if args.mode == 'mpiigaze': 82 | path = package_root / 'data/configs/mpiigaze.yaml' 83 | elif args.mode == 'mpiifacegaze': 84 | path = package_root / 'data/configs/mpiifacegaze.yaml' 85 | elif args.mode == 'eth-xgaze': 86 | path = package_root / 'data/configs/eth-xgaze.yaml' 87 | else: 88 | raise ValueError 89 | config = OmegaConf.load(path) 90 | config.PACKAGE_ROOT = package_root.as_posix() 91 | 92 | if args.face_detector: 93 | config.face_detector.mode = args.face_detector 94 | if args.device: 95 | config.device = args.device 96 | if config.device == 'cuda' and not torch.cuda.is_available(): 97 | config.device = 'cpu' 98 | warnings.warn('Run on CPU because CUDA is not available.') 99 | if args.image and args.video: 100 | raise ValueError('Only one of --image or --video can be specified.') 101 | if args.image: 102 | config.demo.image_path = args.image 103 | config.demo.use_camera = False 104 | if args.video: 105 | config.demo.video_path = args.video 106 | config.demo.use_camera = False 107 | if args.camera: 108 | config.gaze_estimator.camera_params = args.camera 109 | elif args.image or args.video: 110 | config.gaze_estimator.use_dummy_camera_params = True 111 | if args.output_dir: 112 | config.demo.output_dir = args.output_dir 113 | if args.ext: 114 | config.demo.output_file_extension = args.ext 115 | if args.no_screen: 116 | config.demo.display_on_screen = False 117 | if not config.demo.output_dir: 118 | config.demo.output_dir = 'outputs' 119 | 120 | return config 121 | 122 | 123 | def main(): 124 
| args = parse_args() 125 | if args.debug: 126 | logging.getLogger('ptgaze').setLevel(logging.DEBUG) 127 | 128 | if args.config: 129 | config = OmegaConf.load(args.config) 130 | elif args.mode: 131 | config = load_mode_config(args) 132 | else: 133 | raise ValueError( 134 | 'You need to specify one of \'--mode\' or \'--config\'.') 135 | expanduser_all(config) 136 | if config.gaze_estimator.use_dummy_camera_params: 137 | generate_dummy_camera_params(config) 138 | 139 | OmegaConf.set_readonly(config, True) 140 | logger.info(OmegaConf.to_yaml(config)) 141 | 142 | if config.face_detector.mode == 'dlib': 143 | download_dlib_pretrained_model() 144 | if args.mode: 145 | if config.mode == 'MPIIGaze': 146 | download_mpiigaze_model() 147 | elif config.mode == 'MPIIFaceGaze': 148 | download_mpiifacegaze_model() 149 | elif config.mode == 'ETH-XGaze': 150 | download_ethxgaze_model() 151 | 152 | check_path_all(config) 153 | 154 | demo = Demo(config) 155 | demo.run() 156 | -------------------------------------------------------------------------------- /ptgaze/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import timm 4 | import torch 5 | from omegaconf import DictConfig 6 | 7 | 8 | def create_model(config: DictConfig) -> torch.nn.Module: 9 | mode = config.mode 10 | if mode in ['MPIIGaze', 'MPIIFaceGaze']: 11 | module = importlib.import_module( 12 | f'ptgaze.models.{mode.lower()}.{config.model.name}') 13 | model = module.Model(config) 14 | elif mode == 'ETH-XGaze': 15 | model = timm.create_model(config.model.name, num_classes=2) 16 | else: 17 | raise ValueError 18 | device = torch.device(config.device) 19 | model.to(device) 20 | return model 21 | -------------------------------------------------------------------------------- /ptgaze/models/mpiifacegaze/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/47cdf68414d20c8281bbb0a03112a298761aaa9b/ptgaze/models/mpiifacegaze/__init__.py -------------------------------------------------------------------------------- /ptgaze/models/mpiifacegaze/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch.nn as nn 4 | from omegaconf import DictConfig 5 | 6 | 7 | def create_backbone(config: DictConfig) -> nn.Module: 8 | backbone_name = config.model.backbone.name 9 | module = importlib.import_module( 10 | f'ptgaze.models.mpiifacegaze.backbones.{backbone_name}') 11 | return module.Model(config) 12 | -------------------------------------------------------------------------------- /ptgaze/models/mpiifacegaze/backbones/resnet_simple.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from omegaconf import DictConfig 4 | 5 | 6 | class Model(torchvision.models.ResNet): 7 | def __init__(self, config: DictConfig): 8 | block_name = config.model.backbone.resnet_block 9 | if block_name == 'basic': 10 | block = torchvision.models.resnet.BasicBlock 11 | elif block_name == 'bottleneck': 12 | block = torchvision.models.resnet.Bottleneck 13 | else: 14 | raise ValueError 15 | layers = list(config.model.backbone.resnet_layers) + [1] 16 | super().__init__(block, layers) 17 | del self.layer4 18 | del self.avgpool 19 | del self.fc 20 | 21 | pretrained_name = config.model.backbone.pretrained 22 | if pretrained_name: 23 | state_dict = 
torch.hub.load_state_dict_from_url( 24 | torchvision.models.resnet.model_urls[pretrained_name]) 25 | self.load_state_dict(state_dict, strict=False) 26 | # While the pretrained models of torchvision are trained 27 | # using images with RGB channel order, in this repository 28 | # images are treated as BGR channel order. 29 | # Therefore, reverse the channel order of the first 30 | # convolutional layer. 31 | module = self.conv1 32 | module.weight.data = module.weight.data[:, [2, 1, 0]] 33 | 34 | with torch.no_grad(): 35 | data = torch.zeros((1, 3, 224, 224), dtype=torch.float32) 36 | features = self.forward(data) 37 | self.n_features = features.shape[1] 38 | 39 | def forward(self, x: torch.Tensor) -> torch.Tensor: 40 | x = self.conv1(x) 41 | x = self.bn1(x) 42 | x = self.relu(x) 43 | x = self.maxpool(x) 44 | 45 | x = self.layer1(x) 46 | x = self.layer2(x) 47 | x = self.layer3(x) 48 | return x 49 | -------------------------------------------------------------------------------- /ptgaze/models/mpiifacegaze/resnet_simple.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from omegaconf import DictConfig 7 | 8 | from .backbones import create_backbone 9 | 10 | 11 | class Model(nn.Module): 12 | def __init__(self, config: DictConfig): 13 | super().__init__() 14 | self.feature_extractor = create_backbone(config) 15 | n_channels = self.feature_extractor.n_features 16 | 17 | self.conv = nn.Conv2d(n_channels, 18 | 1, 19 | kernel_size=1, 20 | stride=1, 21 | padding=0) 22 | # This model assumes the input image size is 224x224. 23 | self.fc = nn.Linear(n_channels * 14**2, 2) 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | x = self.feature_extractor(x) 27 | y = F.relu(self.conv(x)) 28 | x = x * y 29 | x = x.view(x.size(0), -1) 30 | x = self.fc(x) 31 | return x 32 | -------------------------------------------------------------------------------- /ptgaze/models/mpiigaze/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hysts/pytorch_mpiigaze_demo/47cdf68414d20c8281bbb0a03112a298761aaa9b/ptgaze/models/mpiigaze/__init__.py -------------------------------------------------------------------------------- /ptgaze/models/mpiigaze/resnet_preact.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from omegaconf import DictConfig 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_channels: int, out_channels: int, stride: int): 9 | super().__init__() 10 | 11 | self.bn1 = nn.BatchNorm2d(in_channels) 12 | self.conv1 = nn.Conv2d(in_channels, 13 | out_channels, 14 | kernel_size=3, 15 | stride=stride, 16 | padding=1, 17 | bias=False) 18 | self.bn2 = nn.BatchNorm2d(out_channels) 19 | self.conv2 = nn.Conv2d(out_channels, 20 | out_channels, 21 | kernel_size=3, 22 | stride=1, 23 | padding=1, 24 | bias=False) 25 | 26 | self.shortcut = nn.Sequential() 27 | if in_channels != out_channels: 28 | self.shortcut.add_module( 29 | 'conv', 30 | nn.Conv2d(in_channels, 31 | out_channels, 32 | kernel_size=1, 33 | stride=stride, 34 | padding=0, 35 | bias=False)) 36 | 37 | def forward(self, x: torch.Tensor) -> torch.Tensor: 38 | x = F.relu(self.bn1(x), inplace=True) 39 | y = self.conv1(x) 40 | y = F.relu(self.bn2(y), inplace=True) 41 | y = self.conv2(y) 
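        # Pre-activation residual: add the (possibly projected) shortcut after
        # the second convolution; no ReLU follows the addition.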
42 | y += self.shortcut(x) 43 | return y 44 | 45 | 46 | class Model(nn.Module): 47 | def __init__(self, config: DictConfig): 48 | super().__init__() 49 | 50 | depth = 8 51 | base_channels = 16 52 | input_shape = (1, 1, 36, 60) 53 | 54 | n_blocks_per_stage = (depth - 2) // 6 55 | assert n_blocks_per_stage * 6 + 2 == depth 56 | 57 | n_channels = [base_channels, base_channels * 2, base_channels * 4] 58 | 59 | self.conv = nn.Conv2d(input_shape[1], 60 | n_channels[0], 61 | kernel_size=(3, 3), 62 | stride=1, 63 | padding=1, 64 | bias=False) 65 | 66 | self.stage1 = self._make_stage(n_channels[0], 67 | n_channels[0], 68 | n_blocks_per_stage, 69 | BasicBlock, 70 | stride=1) 71 | self.stage2 = self._make_stage(n_channels[0], 72 | n_channels[1], 73 | n_blocks_per_stage, 74 | BasicBlock, 75 | stride=2) 76 | self.stage3 = self._make_stage(n_channels[1], 77 | n_channels[2], 78 | n_blocks_per_stage, 79 | BasicBlock, 80 | stride=2) 81 | self.bn = nn.BatchNorm2d(n_channels[2]) 82 | 83 | # compute conv feature size 84 | with torch.no_grad(): 85 | self.feature_size = self._forward_conv( 86 | torch.zeros(*input_shape)).view(-1).size(0) 87 | 88 | self.fc = nn.Linear(self.feature_size + 2, 2) 89 | 90 | @staticmethod 91 | def _make_stage(in_channels: int, out_channels: int, n_blocks: int, 92 | block: torch.nn.Module, stride: int) -> torch.nn.Module: 93 | stage = nn.Sequential() 94 | for index in range(n_blocks): 95 | block_name = f'block{index + 1}' 96 | if index == 0: 97 | stage.add_module( 98 | block_name, block(in_channels, out_channels, 99 | stride=stride)) 100 | else: 101 | stage.add_module(block_name, 102 | block(out_channels, out_channels, stride=1)) 103 | return stage 104 | 105 | def _forward_conv(self, x: torch.Tensor) -> torch.Tensor: 106 | x = self.conv(x) 107 | x = self.stage1(x) 108 | x = self.stage2(x) 109 | x = self.stage3(x) 110 | x = F.relu(self.bn(x), inplace=True) 111 | x = F.adaptive_avg_pool2d(x, output_size=1) 112 | return x 113 | 114 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 115 | x = self._forward_conv(x) 116 | x = x.view(x.size(0), -1) 117 | x = torch.cat([x, y], dim=1) 118 | x = self.fc(x) 119 | return x 120 | -------------------------------------------------------------------------------- /ptgaze/transforms.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import cv2 4 | import torchvision.transforms as T 5 | from omegaconf import DictConfig 6 | 7 | 8 | def create_transform(config: DictConfig) -> Any: 9 | if config.mode == 'MPIIGaze': 10 | return T.ToTensor() 11 | elif config.mode == 'MPIIFaceGaze': 12 | return _create_mpiifacegaze_transform(config) 13 | elif config.mode == 'ETH-XGaze': 14 | return _create_ethxgaze_transform(config) 15 | else: 16 | raise ValueError 17 | 18 | 19 | def _create_mpiifacegaze_transform(config: DictConfig) -> Any: 20 | size = tuple(config.gaze_estimator.image_size) 21 | transform = T.Compose([ 22 | T.Lambda(lambda x: cv2.resize(x, size)), 23 | T.ToTensor(), 24 | T.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 25 | 0.229]), # BGR 26 | ]) 27 | return transform 28 | 29 | 30 | def _create_ethxgaze_transform(config: DictConfig) -> Any: 31 | size = tuple(config.gaze_estimator.image_size) 32 | transform = T.Compose([ 33 | T.Lambda(lambda x: cv2.resize(x, size)), 34 | T.Lambda(lambda x: x[:, :, ::-1].copy()), # BGR -> RGB 35 | T.ToTensor(), 36 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 37 | 0.225]), # RGB 38 | ]) 39 | return transform 40 | 
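The two `Normalize` calls above encode the same statistics in different channel orders: the MPIIFaceGaze transform keeps BGR input and reverses the ImageNet mean/std, while the ETH-XGaze transform converts to RGB first and uses them in the standard order. A short self-contained sketch showing the two strategies agree up to a channel flip (illustrative shapes; not part of the package):

import numpy as np
import torch
import torchvision.transforms as T

rng = np.random.default_rng(0)
bgr = rng.random((224, 224, 3), dtype=np.float32)  # fake BGR image in [0, 1]
rgb = bgr[:, :, ::-1].copy()

mean_rgb = [0.485, 0.456, 0.406]
std_rgb = [0.229, 0.224, 0.225]

bgr_norm = T.Normalize(mean=mean_rgb[::-1], std=std_rgb[::-1])(T.ToTensor()(bgr))
rgb_norm = T.Normalize(mean=mean_rgb, std=std_rgb)(T.ToTensor()(rgb))

# Flipping the channel axis of the BGR result reproduces the RGB result.
assert torch.allclose(bgr_norm.flip(0), rgb_norm)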
-------------------------------------------------------------------------------- /ptgaze/utils.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import logging 3 | import operator 4 | import pathlib 5 | import tempfile 6 | 7 | import cv2 8 | import torch.hub 9 | import yaml 10 | from omegaconf import DictConfig 11 | 12 | from .common.face_model import FaceModel 13 | from .common.face_model_68 import FaceModel68 14 | from .common.face_model_mediapipe import FaceModelMediaPipe 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def get_3d_face_model(config: DictConfig) -> FaceModel: 20 | if config.face_detector.mode == 'mediapipe': 21 | return FaceModelMediaPipe() 22 | else: 23 | return FaceModel68() 24 | 25 | 26 | def download_dlib_pretrained_model() -> None: 27 | logger.debug('Called download_dlib_pretrained_model()') 28 | 29 | dlib_model_dir = pathlib.Path('~/.ptgaze/dlib/').expanduser() 30 | dlib_model_dir.mkdir(exist_ok=True, parents=True) 31 | dlib_model_path = dlib_model_dir / 'shape_predictor_68_face_landmarks.dat' 32 | logger.debug( 33 | f'Update config.face_detector.dlib_model_path to {dlib_model_path.as_posix()}' 34 | ) 35 | 36 | if dlib_model_path.exists(): 37 | logger.debug( 38 | f'dlib pretrained model {dlib_model_path.as_posix()} already exists.' 39 | ) 40 | return 41 | 42 | logger.debug('Download the dlib pretrained model') 43 | bz2_path = dlib_model_path.as_posix() + '.bz2' 44 | torch.hub.download_url_to_file( 45 | 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2', 46 | bz2_path) 47 | with bz2.BZ2File(bz2_path, 'rb') as f_in, open(dlib_model_path, 48 | 'wb') as f_out: 49 | data = f_in.read() 50 | f_out.write(data) 51 | 52 | 53 | def download_mpiigaze_model() -> pathlib.Path: 54 | logger.debug('Called _download_mpiigaze_model()') 55 | output_dir = pathlib.Path('~/.ptgaze/models/').expanduser() 56 | output_dir.mkdir(exist_ok=True, parents=True) 57 | output_path = output_dir / 'mpiigaze_resnet_preact.pth' 58 | if not output_path.exists(): 59 | logger.debug('Download the pretrained model') 60 | torch.hub.download_url_to_file( 61 | 'https://github.com/hysts/pytorch_mpiigaze_demo/releases/download/v0.1.0/mpiigaze_resnet_preact.pth', 62 | output_path.as_posix()) 63 | else: 64 | logger.debug(f'The pretrained model {output_path} already exists.') 65 | return output_path 66 | 67 | 68 | def download_mpiifacegaze_model() -> pathlib.Path: 69 | logger.debug('Called _download_mpiifacegaze_model()') 70 | output_dir = pathlib.Path('~/.ptgaze/models/').expanduser() 71 | output_dir.mkdir(exist_ok=True, parents=True) 72 | output_path = output_dir / 'mpiifacegaze_resnet_simple.pth' 73 | if not output_path.exists(): 74 | logger.debug('Download the pretrained model') 75 | torch.hub.download_url_to_file( 76 | 'https://github.com/hysts/pytorch_mpiigaze_demo/releases/download/v0.1.0/mpiifacegaze_resnet_simple.pth', 77 | output_path.as_posix()) 78 | else: 79 | logger.debug(f'The pretrained model {output_path} already exists.') 80 | return output_path 81 | 82 | 83 | def download_ethxgaze_model() -> pathlib.Path: 84 | logger.debug('Called _download_ethxgaze_model()') 85 | output_dir = pathlib.Path('~/.ptgaze/models/').expanduser() 86 | output_dir.mkdir(exist_ok=True, parents=True) 87 | output_path = output_dir / 'eth-xgaze_resnet18.pth' 88 | if not output_path.exists(): 89 | logger.debug('Download the pretrained model') 90 | torch.hub.download_url_to_file( 91 | 
'https://github.com/hysts/pytorch_mpiigaze_demo/releases/download/v0.2.2/eth-xgaze_resnet18.pth', 92 | output_path.as_posix()) 93 | else: 94 | logger.debug(f'The pretrained model {output_path} already exists.') 95 | return output_path 96 | 97 | 98 | def generate_dummy_camera_params(config: DictConfig) -> None: 99 | logger.debug('Called _generate_dummy_camera_params()') 100 | if config.demo.image_path: 101 | path = pathlib.Path(config.demo.image_path).expanduser() 102 | image = cv2.imread(path.as_posix()) 103 | h, w = image.shape[:2] 104 | elif config.demo.video_path: 105 | logger.debug(f'Open video {config.demo.video_path}') 106 | path = pathlib.Path(config.demo.video_path).expanduser().as_posix() 107 | cap = cv2.VideoCapture(path) 108 | if not cap.isOpened(): 109 | raise RuntimeError(f'{config.demo.video_path} is not opened.') 110 | h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 111 | w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 112 | cap.release() 113 | else: 114 | raise ValueError 115 | logger.debug(f'Frame size is ({w}, {h})') 116 | logger.debug(f'Close video {config.demo.video_path}') 117 | out_file = tempfile.NamedTemporaryFile(suffix='.yaml', delete=False) 118 | logger.debug(f'Create a dummy camera param file {out_file.name}') 119 | dic = { 120 | 'image_width': w, 121 | 'image_height': h, 122 | 'camera_matrix': { 123 | 'rows': 3, 124 | 'cols': 3, 125 | 'data': [w, 0., w // 2, 0., w, h // 2, 0., 0., 1.] 126 | }, 127 | 'distortion_coefficients': { 128 | 'rows': 1, 129 | 'cols': 5, 130 | 'data': [0., 0., 0., 0., 0.] 131 | } 132 | } 133 | with open(out_file.name, 'w') as f: 134 | yaml.safe_dump(dic, f) 135 | config.gaze_estimator.camera_params = out_file.name 136 | logger.debug( 137 | f'Update config.gaze_estimator.camera_params to {out_file.name}') 138 | 139 | 140 | def _expanduser(path: str) -> str: 141 | if not path: 142 | return path 143 | return pathlib.Path(path).expanduser().as_posix() 144 | 145 | 146 | def expanduser_all(config: DictConfig) -> None: 147 | if hasattr(config.face_detector, 'dlib_model_path'): 148 | config.face_detector.dlib_model_path = _expanduser( 149 | config.face_detector.dlib_model_path) 150 | config.gaze_estimator.checkpoint = _expanduser( 151 | config.gaze_estimator.checkpoint) 152 | config.gaze_estimator.camera_params = _expanduser( 153 | config.gaze_estimator.camera_params) 154 | config.gaze_estimator.normalized_camera_params = _expanduser( 155 | config.gaze_estimator.normalized_camera_params) 156 | if hasattr(config.demo, 'image_path'): 157 | config.demo.image_path = _expanduser(config.demo.image_path) 158 | if hasattr(config.demo, 'video_path'): 159 | config.demo.video_path = _expanduser(config.demo.video_path) 160 | if hasattr(config.demo, 'output_dir'): 161 | config.demo.output_dir = _expanduser(config.demo.output_dir) 162 | 163 | 164 | def _check_path(config: DictConfig, key: str) -> None: 165 | path_str = operator.attrgetter(key)(config) 166 | path = pathlib.Path(path_str) 167 | if not path.exists(): 168 | raise FileNotFoundError(f'config.{key}: {path.as_posix()} not found.') 169 | if not path.is_file(): 170 | raise ValueError(f'config.{key}: {path.as_posix()} is not a file.') 171 | 172 | 173 | def check_path_all(config: DictConfig) -> None: 174 | if config.face_detector.mode == 'dlib': 175 | _check_path(config, 'face_detector.dlib_model_path') 176 | _check_path(config, 'gaze_estimator.checkpoint') 177 | _check_path(config, 'gaze_estimator.camera_params') 178 | _check_path(config, 'gaze_estimator.normalized_camera_params') 179 | if 
config.demo.image_path: 180 | _check_path(config, 'demo.image_path') 181 | if config.demo.video_path: 182 | _check_path(config, 'demo.video_path') 183 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dlib 2 | face_alignment 3 | mediapipe 4 | numpy 5 | omegaconf 6 | opencv-python 7 | pyyaml 8 | scipy 9 | timm 10 | torch 11 | torchvision 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | from setuptools import find_packages, setup 4 | 5 | 6 | def _get_long_description(): 7 | path = pathlib.Path(__file__).parent / 'README.md' 8 | with open(path, encoding='utf-8') as f: 9 | long_description = f.read() 10 | return long_description 11 | 12 | 13 | def _get_requirements(path): 14 | with open(path) as f: 15 | data = f.readlines() 16 | return data 17 | 18 | 19 | setup( 20 | name='ptgaze', 21 | version='0.2.8', 22 | author='hysts', 23 | url='https://github.com/hysts/pytorch_mpiigaze_demo', 24 | python_requires='>=3.7', 25 | install_requires=_get_requirements('requirements.txt'), 26 | packages=find_packages(exclude=('tests', )), 27 | include_package_data=True, 28 | entry_points={ 29 | 'console_scripts': [ 30 | 'ptgaze=ptgaze.main:main', 31 | ], 32 | }, 33 | description='Gaze estimation using MPIIGaze and MPIIFaceGaze', 34 | long_description=_get_long_description(), 35 | long_description_content_type='text/markdown', 36 | ) 37 | --------------------------------------------------------------------------------
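Finally, `setup.py` above registers the `ptgaze` console script, whose entry point is `ptgaze.main:main`. The same pipeline can also be driven from Python; the sketch below mirrors what `main()` does for `--mode eth-xgaze` with a video input (the video filename is a placeholder, and dummy camera parameters are generated exactly as `load_mode_config` would arrange for a `--video` run without `--camera`):

import pathlib

import ptgaze
from omegaconf import OmegaConf
from ptgaze.demo import Demo
from ptgaze.utils import (check_path_all, download_ethxgaze_model,
                          expanduser_all, generate_dummy_camera_params)

# Locate the bundled config the same way load_mode_config() does.
package_root = pathlib.Path(ptgaze.__file__).parent.resolve()
config = OmegaConf.load(package_root / 'data/configs/eth-xgaze.yaml')
config.PACKAGE_ROOT = package_root.as_posix()

config.demo.video_path = 'input.mp4'  # placeholder input file
config.demo.use_camera = False
config.demo.display_on_screen = False
config.demo.output_dir = 'outputs'
config.gaze_estimator.use_dummy_camera_params = True

expanduser_all(config)
generate_dummy_camera_params(config)
OmegaConf.set_readonly(config, True)

download_ethxgaze_model()
check_path_all(config)

Demo(config).run()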