├── docs ├── mpiifacegaze_analysis.png ├── compare_points_on_screen_positions.svg └── compare_model_results.svg ├── requirements.txt ├── LICENSE.md ├── eval.py ├── dataset ├── mpii_face_gaze_errors.py ├── mpii_face_gaze_dataset.py └── mpii_face_gaze_preprocessing.py ├── .gitignore ├── model.py ├── utils.py ├── README.md └── train.py /docs/mpiifacegaze_analysis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pperle/gaze-tracking/HEAD/docs/mpiifacegaze_analysis.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.4.1 2 | tqdm==4.61.2 3 | numpy==1.18.5 4 | h5py==2.10.0 5 | torch==1.9.0 6 | opencv_python==4.5.1.48 7 | torchvision==0.10.0 8 | albumentations==1.1.0 9 | matplotlib==3.4.3 10 | pandas==1.3.3 11 | Pillow==8.3.2 12 | pytorch_lightning==1.4.9 13 | scikit-image==0.18.3 14 | torchinfo==1.5.3 15 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 pperle 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from pytorch_lightning import seed_everything, Trainer 4 | 5 | from dataset.mpii_face_gaze_dataset import get_dataloaders 6 | from train import Model 7 | 8 | if __name__ == '__main__': 9 | parser = ArgumentParser() 10 | parser.add_argument("--path_to_checkpoints", type=str, default='./pretrained_models') 11 | parser.add_argument("--path_to_data", type=str, default='./data') 12 | parser.add_argument("--batch_size", type=int, default=64) 13 | parser.add_argument("--k", type=int, default=[9, 128], nargs='+') 14 | parser.add_argument("--adjust_slope", type=bool, default=False) 15 | parser.add_argument("--grid_calibration_samples", type=bool, default=False) 16 | args = parser.parse_args() 17 | 18 | for person_idx in range(15): 19 | person = f'p{person_idx:02d}' 20 | 21 | seed_everything(42) 22 | print('grid_calibration_samples', args.grid_calibration_samples) 23 | model = Model.load_from_checkpoint(f'{args.path_to_checkpoints}/{person}.ckpt', k=args.k, adjust_slope=args.adjust_slope, grid_calibration_samples=args.grid_calibration_samples) 24 | 25 | trainer = Trainer( 26 | gpus=1, 27 | benchmark=True, 28 | ) 29 | 30 | _, _, test_dataloader = get_dataloaders(args.path_to_data, 0, person_idx, args.batch_size) 31 | trainer.test(model, test_dataloader) 32 | -------------------------------------------------------------------------------- /dataset/mpii_face_gaze_errors.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from argparse import ArgumentParser 3 | 4 | import pandas as pd 5 | import scipy.io 6 | 7 | 8 | def check_mpii_gaze_not_on_screen(input_path: str, output_path: str) -> None: 9 | """ 10 | Create CSV file with the filename of the images where the gaze is not on the screen. 11 | 12 | :param input_path: path to the original MPIIGaze dataset 13 | :param output_path: output dataset 14 | :return: 15 | """ 16 | 17 | data = {'file_name': [], 'on_screen_gaze_position': [], 'monitor_pixels': []} 18 | 19 | for person_file_path in sorted(glob.glob(f'{input_path}/Data/Original/p*'), reverse=True): 20 | person = person_file_path.split('/')[-1] 21 | 22 | screen_size = scipy.io.loadmat(f'{input_path}/Data/Original/{person}/Calibration/screenSize.mat') 23 | screen_width_pixel = screen_size["width_pixel"].item() 24 | screen_height_pixel = screen_size["height_pixel"].item() 25 | 26 | for day_file_path in sorted(glob.glob(f'{person_file_path}/d*')): 27 | day = day_file_path.split('/')[-1] 28 | 29 | df = pd.read_csv(f'{day_file_path}/annotation.txt', sep=' ', header=None) 30 | for row_idx in range(len(df)): 31 | row = df.iloc[row_idx] 32 | on_screen_gaze_target = row[24:26].to_numpy().reshape(-1).astype(int) 33 | 34 | if not (0 <= on_screen_gaze_target[0] <= screen_width_pixel and 0 <= on_screen_gaze_target[1] <= screen_height_pixel): 35 | file_name = f'{person}/{day}/{row_idx + 1:04d}.jpg' 36 | 37 | data['file_name'].append(file_name) 38 | data['on_screen_gaze_position'].append(list(on_screen_gaze_target)) 39 | data['monitor_pixels'].append([screen_width_pixel, screen_height_pixel]) 40 | 41 | pd.DataFrame(data).to_csv(f'{output_path}/not_on_screen.csv', index=False) 42 | 43 | 44 | def check_mpii_face_gaze_not_on_screen(input_path: str, output_path: str) -> None: 45 | """ 46 | Create CSV file with the filename of the images where the gaze is not on the screen. 47 | 48 | :param input_path: path to the original MPIIFaceGaze dataset 49 | :param output_path: output dataset 50 | :return: 51 | """ 52 | 53 | data = {'file_name': [], 'on_screen_gaze_position': [], 'monitor_pixels': []} 54 | 55 | for person_file_path in sorted(glob.glob(f'{input_path}/p*')): 56 | person = person_file_path.split('/')[-1] 57 | 58 | screen_size = scipy.io.loadmat(f'{input_path}/{person}/Calibration/screenSize.mat') 59 | screen_width_pixel = screen_size["width_pixel"].item() 60 | screen_height_pixel = screen_size["height_pixel"].item() 61 | 62 | df = pd.read_csv(f'{person_file_path}/{person}.txt', sep=' ', header=None) 63 | df_idx = 0 64 | 65 | for day_file_path in sorted(glob.glob(f'{person_file_path}/d*')): 66 | day = day_file_path.split('/')[-1] 67 | 68 | for image_file_path in sorted(glob.glob(f'{day_file_path}/*.jpg')): 69 | row = df.iloc[df_idx] 70 | on_screen_gaze_target = row[1:3].to_numpy().reshape(-1).astype(int) 71 | 72 | if not (0 <= on_screen_gaze_target[0] <= screen_width_pixel and 0 <= on_screen_gaze_target[1] <= screen_height_pixel): 73 | file_name = f'{person}/{day}/{image_file_path.split("/")[-1]}' 74 | 75 | data['file_name'].append(file_name) 76 | data['on_screen_gaze_position'].append(list(on_screen_gaze_target)) 77 | data['monitor_pixels'].append([screen_width_pixel, screen_height_pixel]) 78 | 79 | df_idx += 1 80 | 81 | pd.DataFrame(data).to_csv(f'{output_path}/not_on_screen.csv', index=False) 82 | 83 | 84 | if __name__ == '__main__': 85 | parser = ArgumentParser() 86 | parser.add_argument("--input_path", type=str, default='./MPIIFaceGaze') 87 | parser.add_argument("--output_path", type=str, default='./data') 88 | args = parser.parse_args() 89 | 90 | # check_mpiigaze('args.input_path, args.output_path) 91 | check_mpii_face_gaze_not_on_screen(args.input_path, args.output_path) 92 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python,pycharm+all,visualstudiocode 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,pycharm+all,visualstudiocode 4 | 5 | ### PyCharm+all ### 6 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 7 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 8 | 9 | # User-specific stuff 10 | .idea/**/workspace.xml 11 | .idea/**/tasks.xml 12 | .idea/**/usage.statistics.xml 13 | .idea/**/dictionaries 14 | .idea/**/shelf 15 | 16 | # AWS User-specific 17 | .idea/**/aws.xml 18 | 19 | # Generated files 20 | .idea/**/contentModel.xml 21 | 22 | # Sensitive or high-churn files 23 | .idea/**/dataSources/ 24 | .idea/**/dataSources.ids 25 | .idea/**/dataSources.local.xml 26 | .idea/**/sqlDataSources.xml 27 | .idea/**/dynamic.xml 28 | .idea/**/uiDesigner.xml 29 | .idea/**/dbnavigator.xml 30 | 31 | # Gradle 32 | .idea/**/gradle.xml 33 | .idea/**/libraries 34 | 35 | # Gradle and Maven with auto-import 36 | # When using Gradle or Maven with auto-import, you should exclude module files, 37 | # since they will be recreated, and may cause churn. Uncomment if using 38 | # auto-import. 39 | # .idea/artifacts 40 | # .idea/compiler.xml 41 | # .idea/jarRepositories.xml 42 | # .idea/modules.xml 43 | # .idea/*.iml 44 | # .idea/modules 45 | # *.iml 46 | # *.ipr 47 | 48 | # CMake 49 | cmake-build-*/ 50 | 51 | # Mongo Explorer plugin 52 | .idea/**/mongoSettings.xml 53 | 54 | # File-based project format 55 | *.iws 56 | 57 | # IntelliJ 58 | out/ 59 | 60 | # mpeltonen/sbt-idea plugin 61 | .idea_modules/ 62 | 63 | # JIRA plugin 64 | atlassian-ide-plugin.xml 65 | 66 | # Cursive Clojure plugin 67 | .idea/replstate.xml 68 | 69 | # Crashlytics plugin (for Android Studio and IntelliJ) 70 | com_crashlytics_export_strings.xml 71 | crashlytics.properties 72 | crashlytics-build.properties 73 | fabric.properties 74 | 75 | # Editor-based Rest Client 76 | .idea/httpRequests 77 | 78 | # Android studio 3.1+ serialized cache file 79 | .idea/caches/build_file_checksums.ser 80 | 81 | ### PyCharm+all Patch ### 82 | # Ignores the whole .idea folder and all .iml files 83 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 84 | 85 | .idea/ 86 | 87 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 88 | 89 | *.iml 90 | modules.xml 91 | .idea/misc.xml 92 | *.ipr 93 | 94 | # Sonarlint plugin 95 | .idea/sonarlint 96 | 97 | ### Python ### 98 | # Byte-compiled / optimized / DLL files 99 | __pycache__/ 100 | *.py[cod] 101 | *$py.class 102 | 103 | # C extensions 104 | *.so 105 | 106 | # Distribution / packaging 107 | .Python 108 | build/ 109 | develop-eggs/ 110 | dist/ 111 | downloads/ 112 | eggs/ 113 | .eggs/ 114 | lib/ 115 | lib64/ 116 | parts/ 117 | sdist/ 118 | var/ 119 | wheels/ 120 | share/python-wheels/ 121 | *.egg-info/ 122 | .installed.cfg 123 | *.egg 124 | MANIFEST 125 | 126 | # PyInstaller 127 | # Usually these files are written by a python script from a template 128 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 129 | *.manifest 130 | *.spec 131 | 132 | # Installer logs 133 | pip-log.txt 134 | pip-delete-this-directory.txt 135 | 136 | # Unit test / coverage reports 137 | htmlcov/ 138 | .tox/ 139 | .nox/ 140 | .coverage 141 | .coverage.* 142 | .cache 143 | nosetests.xml 144 | coverage.xml 145 | *.cover 146 | *.py,cover 147 | .hypothesis/ 148 | .pytest_cache/ 149 | cover/ 150 | 151 | # Translations 152 | *.mo 153 | *.pot 154 | 155 | # Django stuff: 156 | *.log 157 | local_settings.py 158 | db.sqlite3 159 | db.sqlite3-journal 160 | 161 | # Flask stuff: 162 | instance/ 163 | .webassets-cache 164 | 165 | # Scrapy stuff: 166 | .scrapy 167 | 168 | # Sphinx documentation 169 | docs/_build/ 170 | 171 | # PyBuilder 172 | .pybuilder/ 173 | target/ 174 | 175 | # Jupyter Notebook 176 | .ipynb_checkpoints 177 | 178 | # IPython 179 | profile_default/ 180 | ipython_config.py 181 | 182 | # pyenv 183 | # For a library or package, you might want to ignore these files since the code is 184 | # intended to run in multiple environments; otherwise, check them in: 185 | # .python-version 186 | 187 | # pipenv 188 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 189 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 190 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 191 | # install all needed dependencies. 192 | #Pipfile.lock 193 | 194 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 195 | __pypackages__/ 196 | 197 | # Celery stuff 198 | celerybeat-schedule 199 | celerybeat.pid 200 | 201 | # SageMath parsed files 202 | *.sage.py 203 | 204 | # Environments 205 | .env 206 | .venv 207 | env/ 208 | venv/ 209 | ENV/ 210 | env.bak/ 211 | venv.bak/ 212 | 213 | # Spyder project settings 214 | .spyderproject 215 | .spyproject 216 | 217 | # Rope project settings 218 | .ropeproject 219 | 220 | # mkdocs documentation 221 | /site 222 | 223 | # mypy 224 | .mypy_cache/ 225 | .dmypy.json 226 | dmypy.json 227 | 228 | # Pyre type checker 229 | .pyre/ 230 | 231 | # pytype static type analyzer 232 | .pytype/ 233 | 234 | # Cython debug symbols 235 | cython_debug/ 236 | 237 | ### VisualStudioCode ### 238 | .vscode/* 239 | !.vscode/settings.json 240 | !.vscode/tasks.json 241 | !.vscode/launch.json 242 | !.vscode/extensions.json 243 | *.code-workspace 244 | 245 | # Local History for Visual Studio Code 246 | .history/ 247 | 248 | ### VisualStudioCode Patch ### 249 | # Ignore all local history of files 250 | .history 251 | .ionide 252 | 253 | # End of https://www.toptal.com/developers/gitignore/api/python,pycharm+all,visualstudiocode 254 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_lightning import LightningModule 3 | from torch import nn 4 | from torchinfo import summary 5 | from torchvision import models 6 | 7 | 8 | class SELayer(nn.Module): 9 | """ 10 | Squeeze-and-Excitation layer 11 | 12 | https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py 13 | """ 14 | 15 | def __init__(self, channel, reduction=16): 16 | super(SELayer, self).__init__() 17 | self.avg_pool = nn.AdaptiveAvgPool2d(1) # Squeeze 18 | self.fc = nn.Sequential( # Excitation (similar to attention) 19 | nn.Linear(channel, channel // reduction, bias=False), 20 | nn.ReLU(inplace=True), 21 | nn.Linear(channel // reduction, channel, bias=False), 22 | nn.Sigmoid() 23 | ) 24 | 25 | def forward(self, x): 26 | b, c, _, _ = x.size() 27 | y = self.avg_pool(x).view(b, c) 28 | y = self.fc(y).view(b, c, 1, 1) 29 | return x * y.expand_as(x) 30 | 31 | 32 | class FinalModel(LightningModule): 33 | def __init__(self, *args, **kwargs): 34 | super().__init__(*args, **kwargs) 35 | 36 | self.subject_biases = nn.Parameter(torch.zeros(15 * 2, 2)) # pitch and yaw offset for the original and mirrored participant 37 | 38 | self.cnn_face = nn.Sequential( 39 | models.vgg16(pretrained=True).features[:9], # first four convolutional layers of VGG16 pretrained on ImageNet 40 | nn.Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), padding='same'), 41 | nn.ReLU(inplace=True), 42 | nn.BatchNorm2d(64), 43 | nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding='valid', dilation=(2, 2)), 44 | nn.ReLU(inplace=True), 45 | nn.BatchNorm2d(64), 46 | nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding='valid', dilation=(3, 3)), 47 | nn.ReLU(inplace=True), 48 | nn.BatchNorm2d(64), 49 | nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding='valid', dilation=(5, 5)), 50 | nn.ReLU(inplace=True), 51 | nn.BatchNorm2d(128), 52 | nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding='valid', dilation=(11, 11)), 53 | nn.ReLU(inplace=True), 54 | nn.BatchNorm2d(128), 55 | ) 56 | 57 | self.cnn_eye = nn.Sequential( 58 | models.vgg16(pretrained=True).features[:9], # first four convolutional layers of VGG16 pretrained on ImageNet 59 | nn.Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), padding='same'), 60 | nn.ReLU(inplace=True), 61 | nn.BatchNorm2d(64), 62 | nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding='valid', dilation=(2, 2)), 63 | nn.ReLU(inplace=True), 64 | nn.BatchNorm2d(64), 65 | nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding='valid', dilation=(3, 3)), 66 | nn.ReLU(inplace=True), 67 | nn.BatchNorm2d(64), 68 | nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding='valid', dilation=(4, 5)), 69 | nn.ReLU(inplace=True), 70 | nn.BatchNorm2d(128), 71 | nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding='valid', dilation=(5, 11)), 72 | nn.ReLU(inplace=True), 73 | nn.BatchNorm2d(128), 74 | ) 75 | 76 | self.fc_face = nn.Sequential( 77 | nn.Flatten(), 78 | nn.Linear(6 * 6 * 128, 256), 79 | nn.ReLU(inplace=True), 80 | nn.BatchNorm1d(256), 81 | nn.Linear(256, 64), 82 | nn.ReLU(inplace=True), 83 | nn.BatchNorm1d(64), 84 | ) 85 | 86 | self.cnn_eye2fc = nn.Sequential( 87 | SELayer(256), 88 | 89 | nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding='same'), 90 | nn.ReLU(inplace=True), 91 | nn.BatchNorm2d(256), 92 | 93 | SELayer(256), 94 | 95 | nn.Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding='same'), 96 | nn.ReLU(inplace=True), 97 | nn.BatchNorm2d(128), 98 | 99 | SELayer(128), 100 | ) 101 | 102 | self.fc_eye = nn.Sequential( 103 | nn.Flatten(), 104 | nn.Linear(4 * 6 * 128, 512), 105 | nn.ReLU(inplace=True), 106 | nn.BatchNorm1d(512), 107 | ) 108 | 109 | self.fc_eyes_face = nn.Sequential( 110 | nn.Dropout(p=0.5), 111 | nn.Linear(576, 256), 112 | nn.ReLU(inplace=True), 113 | nn.BatchNorm1d(256), 114 | nn.Dropout(p=0.5), 115 | nn.Linear(256, 2), 116 | ) 117 | 118 | def forward(self, person_idx: torch.Tensor, full_face: torch.Tensor, right_eye: torch.Tensor, left_eye: torch.Tensor): 119 | out_cnn_face = self.cnn_face(full_face) 120 | out_fc_face = self.fc_face(out_cnn_face) 121 | 122 | out_cnn_right_eye = self.cnn_eye(right_eye) 123 | out_cnn_left_eye = self.cnn_eye(left_eye) 124 | out_cnn_eye = torch.cat((out_cnn_right_eye, out_cnn_left_eye), dim=1) 125 | 126 | cnn_eye2fc_out = self.cnn_eye2fc(out_cnn_eye) # feature fusion 127 | out_fc_eye = self.fc_eye(cnn_eye2fc_out) 128 | 129 | fc_concatenated = torch.cat((out_fc_face, out_fc_eye), dim=1) 130 | t_hat = self.fc_eyes_face(fc_concatenated) # subject-independent term 131 | 132 | return t_hat + self.subject_biases[person_idx].squeeze(1) # t_hat + subject-dependent bias term 133 | 134 | 135 | if __name__ == '__main__': 136 | model = FinalModel() 137 | model.summarize(max_depth=1) 138 | 139 | print(model.cnn_face) 140 | 141 | batch_size = 16 142 | summary(model, [ 143 | (batch_size, 1), 144 | (batch_size, 3, 96, 96), # full face 145 | (batch_size, 3, 64, 96), # right eye 146 | (batch_size, 3, 64, 96) # left eye 147 | ], dtypes=[torch.long, torch.float, torch.float, torch.float]) 148 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import List 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | import torchvision 8 | from matplotlib import pyplot as plt 9 | from pytorch_lightning.loggers import TensorBoardLogger 10 | from PIL import Image 11 | import io 12 | 13 | 14 | class PitchYaw(Enum): 15 | PITCH = 'pitch' 16 | YAW = 'yaw' 17 | 18 | 19 | def pitchyaw_to_3d_vector(pitchyaw: torch.Tensor) -> torch.Tensor: 20 | """ 21 | 2D pitch and yaw value to a 3D vector 22 | 23 | :param pitchyaw: 2D gaze value in pitch and yaw 24 | :return: 3D vector 25 | """ 26 | return torch.stack([ 27 | -torch.cos(pitchyaw[:, 0]) * torch.sin(pitchyaw[:, 1]), 28 | -torch.sin(pitchyaw[:, 0]), 29 | -torch.cos(pitchyaw[:, 0]) * torch.cos(pitchyaw[:, 1]) 30 | ], dim=1) 31 | 32 | 33 | def calc_angle_error(labels: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor: 34 | """ 35 | Calculate the angle between `labels` and `outputs` in degrees. 36 | 37 | :param labels: ground truth gaze vectors 38 | :param outputs: predicted gaze vectors 39 | :return: Mean angle in degrees. 40 | """ 41 | labels = pitchyaw_to_3d_vector(labels) 42 | labels_norm = labels / torch.linalg.norm(labels, axis=1).reshape((-1, 1)) 43 | 44 | outputs = pitchyaw_to_3d_vector(outputs) 45 | outputs_norm = outputs / torch.linalg.norm(outputs, axis=1).reshape((-1, 1)) 46 | 47 | angles = F.cosine_similarity(outputs_norm, labels_norm, dim=1) 48 | angles = torch.clip(angles, -1.0, 1.0) # fix NaN values for 1.0 < angles < -1.0 49 | 50 | rad = torch.arccos(angles) 51 | return torch.rad2deg(rad).mean() 52 | 53 | 54 | def plot_prediction_vs_ground_truth(labels, outputs, axis: PitchYaw): 55 | """ 56 | Create a plot between the predictions and the ground truth values. 57 | 58 | :param labels: ground truth values 59 | :param outputs: predicted values 60 | :param axis: weather pitch or yaw 61 | :return: scatter plot of predictions and the ground truth values 62 | """ 63 | 64 | labels = torch.rad2deg(labels) 65 | outputs = torch.rad2deg(outputs) 66 | 67 | if axis == PitchYaw.PITCH: 68 | plt.scatter(labels[:, :1].cpu().detach().numpy().reshape(-1), outputs[:, :1].cpu().detach().numpy().reshape(-1)) 69 | else: 70 | plt.scatter(labels[:, 1:].cpu().detach().numpy().reshape(-1), outputs[:, 1:].cpu().detach().numpy().reshape(-1)) 71 | plt.plot([-30, 30], [-30, 30], color='#ff7f0e') 72 | plt.xlabel('ground truth (degrees)') 73 | plt.ylabel('prediction (degrees') 74 | plt.title(axis.value) 75 | if axis == PitchYaw.PITCH: 76 | plt.xlim((-30, 5)) 77 | plt.ylim((-30, 5)) 78 | else: 79 | plt.xlim((-30, 30)) 80 | plt.ylim((-30, 30)) 81 | 82 | return plt.gcf() 83 | 84 | 85 | def plot_to_image(fig) -> torch.Tensor: 86 | """ 87 | Converts the matplotlib plot specified by 'figure' to a PNG image and 88 | returns it. The supplied figure is closed and inaccessible after this call. 89 | 90 | :param fig: matplotlib figure 91 | :return: plot for torchvision 92 | """ 93 | 94 | # Save the plot to a PNG in memory. 95 | buf = io.BytesIO() 96 | plt.savefig(buf, format="png") 97 | plt.close(fig) 98 | buf.seek(0) 99 | 100 | image = Image.open(buf).convert("RGB") 101 | image = torchvision.transforms.ToTensor()(image) 102 | return image 103 | 104 | 105 | def log_figure(loggers: List, tag: str, figure, global_step: int) -> None: 106 | """ 107 | Log figure as image. Only works for `TensorBoardLogger`. 108 | 109 | :param loggers: 110 | :param tag: 111 | :param figure: 112 | :param global_step: 113 | :return: 114 | """ 115 | 116 | if isinstance(loggers, list): 117 | for logger in loggers: 118 | if isinstance(logger, TensorBoardLogger): 119 | logger.experiment.add_image(tag, plot_to_image(figure), global_step, dataformats="CHW") 120 | elif isinstance(loggers, TensorBoardLogger): 121 | loggers.experiment.add_image(tag, plot_to_image(figure), global_step, dataformats="CHW") 122 | 123 | 124 | def get_random_idx(k: int, size: int) -> np.ndarray: 125 | """ 126 | Get `k` random values of a list of size `size`. 127 | 128 | :param k: number or random values 129 | :param size: total number of values 130 | :return: list of `k` random values 131 | """ 132 | return (np.random.rand(k) * size).astype(int) 133 | 134 | 135 | def get_each_of_one_grid_idx(k: int, gaze_locations: np.ndarray, screen_sizes: np.ndarray) -> np.ndarray: 136 | """ 137 | Get `k` random values of each of the $\sqrt{k}\times\sqrt{k}$ grid. 138 | 139 | :param k: number or random values 140 | :param gaze_locations: list of the position on the screen in pixels for each gaze value 141 | :param screen_sizes: list of the screen sizes in pixels for each gaze value 142 | :return: list of `k` random values 143 | """ 144 | grids = int(np.sqrt(k)) # get grid size from k 145 | 146 | grid_width = screen_sizes[0][0] / grids 147 | height_width = screen_sizes[0][1] / grids 148 | 149 | gaze_locations = np.asarray(gaze_locations) 150 | 151 | valid_random_idx = [] 152 | 153 | for width_range in range(grids): 154 | filter_width = (grid_width * width_range < gaze_locations[:, :1]) & (gaze_locations[:, :1] < grid_width * (width_range + 1)) 155 | for height_range in range(grids): 156 | filter_height = (height_width * height_range < gaze_locations[:, 1:]) & (gaze_locations[:, 1:] < height_width * (height_range + 1)) 157 | complete_filter = filter_width & filter_height 158 | complete_filter = complete_filter.reshape(-1) 159 | if sum(complete_filter) > 0: 160 | true_idxs = np.argwhere(complete_filter) 161 | random_idx = (np.random.rand(1) * len(true_idxs)).astype(int).item() 162 | valid_random_idx.append(true_idxs[random_idx].item()) 163 | 164 | if len(valid_random_idx) != k: 165 | # fill missing calibration samples 166 | missing_k = k - len(valid_random_idx) 167 | missing_idxs = (np.random.rand(missing_k) * len(gaze_locations)).astype(int) 168 | for missing_idx in missing_idxs: 169 | valid_random_idx.append(missing_idx.item()) 170 | 171 | return valid_random_idx 172 | -------------------------------------------------------------------------------- /dataset/mpii_face_gaze_dataset.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import List, Tuple 3 | 4 | import albumentations as A 5 | import h5py 6 | import numpy as np 7 | import pandas as pd 8 | import skimage.io 9 | import torch 10 | from albumentations.pytorch import ToTensorV2 11 | from torch.utils.data import DataLoader 12 | from torch.utils.data import Dataset 13 | 14 | 15 | def filter_persons_by_idx(file_names: List[str], keep_person_idxs: List[int]) -> List[int]: 16 | """ 17 | Only keep idx that match the person ids in `keep_person_idxs`. 18 | 19 | :param file_names: list of all file names 20 | :param keep_person_idxs: list of person ids to keep 21 | :return: list of valid idxs that match `keep_person_idxs` 22 | """ 23 | idx_per_person = [[] for _ in range(15)] 24 | if keep_person_idxs is not None: 25 | keep_person_idxs = [f'p{person_idx:02d}/' for person_idx in set(keep_person_idxs)] 26 | for idx, file_name in enumerate(file_names): 27 | if any(keep_person_idx in file_name for keep_person_idx in keep_person_idxs): # is a valid person_idx ? 28 | person_idx = int(file_name.split('/')[-3][1:]) 29 | idx_per_person[person_idx].append(idx) 30 | else: 31 | for idx, file_name in enumerate(file_names): 32 | person_idx = int(file_name.split('/')[-3][1:]) 33 | idx_per_person[person_idx].append(idx) 34 | 35 | return list(itertools.chain(*idx_per_person)) # flatten list 36 | 37 | 38 | def remove_error_data(data_path: str, file_names: List[str]) -> List[int]: 39 | """ 40 | Remove erroneous data, where the gaze point is not in the screen. 41 | 42 | :param data_path: path to the dataset including the `not_on_screen.csv` file 43 | :param file_names: list of all file names 44 | :return: list of idxs of valid data 45 | """ 46 | valid_idxs = [] 47 | 48 | df = pd.read_csv(f'{data_path}/not_on_screen.csv') 49 | error_file_names = set([error_file_name[:-8] for error_file_name in df['file_name'].tolist()]) 50 | file_names = [file_name[:-4] for file_name in file_names] 51 | for idx, file_name in enumerate(file_names): 52 | if file_name not in error_file_names: 53 | valid_idxs.append(idx) 54 | 55 | return valid_idxs 56 | 57 | 58 | class MPIIFaceGaze(Dataset): 59 | """ 60 | MPIIFaceGaze dataset with offline preprocessing (= already preprocessed) 61 | """ 62 | 63 | def __init__(self, data_path: str, file_name: str, keep_person_idxs: List[int], transform=None, train: bool = False, force_flip: bool = False, use_erroneous_data: bool = False): 64 | if keep_person_idxs is not None: 65 | assert len(keep_person_idxs) > 0 66 | assert max(keep_person_idxs) <= 14 # last person id = 14 67 | assert min(keep_person_idxs) >= 0 # first person id = 0 68 | 69 | self.data_path = data_path 70 | self.hdf5_file_name = f'{data_path}/{file_name}' 71 | self.h5_file = None 72 | 73 | self.transform = transform 74 | self.train = train 75 | self.force_flip = force_flip 76 | 77 | with h5py.File(self.hdf5_file_name, 'r') as f: 78 | file_names = [file_name.decode('utf-8') for file_name in f['file_name_base']] 79 | 80 | if not train: 81 | by_person_idx = filter_persons_by_idx(file_names, keep_person_idxs) 82 | self.idx2ValidIdx = by_person_idx 83 | else: 84 | by_person_idx = filter_persons_by_idx(file_names, keep_person_idxs) 85 | non_error_idx = file_names if use_erroneous_data else remove_error_data(data_path, file_names) 86 | self.idx2ValidIdx = list(set(by_person_idx) & set(non_error_idx)) 87 | 88 | def __len__(self) -> int: 89 | return len(self.idx2ValidIdx) * 2 if self.train else len(self.idx2ValidIdx) 90 | 91 | def __del__(self): 92 | if self.h5_file is not None: 93 | self.h5_file.close() 94 | 95 | def open_hdf5(self): 96 | self.h5_file = h5py.File(self.hdf5_file_name, 'r') 97 | 98 | def __getitem__(self, idx): 99 | if torch.is_tensor(idx): 100 | idx = idx.tolist() 101 | 102 | if self.h5_file is None: 103 | self.open_hdf5() 104 | 105 | augmented_person = idx >= len(self.idx2ValidIdx) 106 | if augmented_person: 107 | idx -= len(self.idx2ValidIdx) # fix idx 108 | 109 | idx = self.idx2ValidIdx[idx] 110 | 111 | file_name = self.h5_file['file_name_base'][idx].decode('utf-8') 112 | gaze_location = self.h5_file['gaze_location'][idx] 113 | screen_size = self.h5_file['screen_size'][idx] 114 | 115 | person_idx = int(file_name.split('/')[-3][1:]) 116 | 117 | left_eye_image = skimage.io.imread(f"{self.data_path}/{file_name}-left_eye.png") 118 | left_eye_image = np.flip(left_eye_image, axis=1) 119 | right_eye_image = skimage.io.imread(f"{self.data_path}/{file_name}-right_eye.png") 120 | full_face_image = skimage.io.imread(f"{self.data_path}/{file_name}-full_face.png") 121 | gaze_pitch = np.array(self.h5_file['gaze_pitch'][idx]) 122 | gaze_yaw = np.array(self.h5_file['gaze_yaw'][idx]) 123 | 124 | if augmented_person or self.force_flip: 125 | person_idx += 15 # fix person_idx 126 | left_eye_image = np.flip(left_eye_image, axis=1) 127 | right_eye_image = np.flip(right_eye_image, axis=1) 128 | full_face_image = np.flip(full_face_image, axis=1) 129 | gaze_yaw *= -1 130 | 131 | if self.transform: 132 | left_eye_image = self.transform(image=left_eye_image)["image"] 133 | right_eye_image = self.transform(image=right_eye_image)["image"] 134 | full_face_image = self.transform(image=full_face_image)["image"] 135 | 136 | return { 137 | 'file_name': file_name, 138 | 'gaze_location': gaze_location, 139 | 'screen_size': screen_size, 140 | 141 | 'person_idx': person_idx, 142 | 143 | 'left_eye_image': left_eye_image, 144 | 'right_eye_image': right_eye_image, 145 | 'full_face_image': full_face_image, 146 | 147 | 'gaze_pitch': gaze_pitch, 148 | 'gaze_yaw': gaze_yaw, 149 | } 150 | 151 | 152 | def get_dataloaders(path_to_data: str, validate_on_person: int, test_on_person: int, batch_size: int) -> Tuple[DataLoader, DataLoader, DataLoader]: 153 | """ 154 | Create train, valid and test dataset. 155 | The train dataset includes all persons except `validate_on_person` and `test_on_person`. 156 | 157 | :param path_to_data: path to dataset 158 | :param validate_on_person: person id to validate on during training 159 | :param test_on_person: person id to test on after training 160 | :param batch_size: batch size 161 | :return: train, valid and test dataset 162 | """ 163 | transform = { 164 | 'train': A.Compose([ 165 | A.ShiftScaleRotate(p=1.0, shift_limit=0.2, scale_limit=0.1, rotate_limit=10), 166 | A.Normalize(), 167 | ToTensorV2() 168 | ]), 169 | 'valid': A.Compose([ 170 | A.Normalize(), 171 | ToTensorV2() 172 | ]) 173 | } 174 | 175 | train_on_persons = list(range(0, 15)) 176 | if validate_on_person in train_on_persons: 177 | train_on_persons.remove(validate_on_person) 178 | if test_on_person in train_on_persons: 179 | train_on_persons.remove(test_on_person) 180 | print('train on persons', train_on_persons) 181 | print('valid on person', validate_on_person) 182 | print('test on person', test_on_person) 183 | 184 | dataset_train = MPIIFaceGaze(path_to_data, 'data.h5', keep_person_idxs=train_on_persons, transform=transform['train'], train=True) 185 | print('len(dataset_train)', len(dataset_train)) 186 | train_dataloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=4) 187 | 188 | dataset_valid = MPIIFaceGaze(path_to_data, 'data.h5', keep_person_idxs=[validate_on_person], transform=transform['valid']) 189 | print('len(dataset_train)', len(dataset_valid)) 190 | valid_dataloader = DataLoader(dataset_valid, batch_size=batch_size, shuffle=False, num_workers=4) 191 | 192 | dataset_test = MPIIFaceGaze(path_to_data, 'data.h5', keep_person_idxs=[test_on_person], transform=transform['valid'], use_erroneous_data=True) 193 | print('len(dataset_train)', len(dataset_test)) 194 | test_dataloader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=4) 195 | 196 | return train_dataloader, valid_dataloader, test_dataloader 197 | -------------------------------------------------------------------------------- /dataset/mpii_face_gaze_preprocessing.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import pathlib 3 | from argparse import ArgumentParser 4 | from collections import defaultdict 5 | from typing import Tuple 6 | 7 | import cv2 8 | import h5py 9 | import numpy as np 10 | import pandas as pd 11 | import scipy.io 12 | import skimage.io 13 | from tqdm import tqdm 14 | 15 | from dataset.mpii_face_gaze_errors import check_mpii_face_gaze_not_on_screen 16 | 17 | 18 | def get_matrices(camera_matrix: np.ndarray, distance_norm: int, center_point: np.ndarray, focal_norm: int, head_rotation_matrix: np.ndarray, image_output_size: Tuple[int, int]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 19 | """ 20 | Calculate rotation, scaling and transformation matrix. 21 | 22 | :param camera_matrix: intrinsic camera matrix 23 | :param distance_norm: normalized distance of the camera 24 | :param center_point: position of the center in the image 25 | :param focal_norm: normalized focal length 26 | :param head_rotation_matrix: rotation of the head 27 | :param image_output_size: output size of the output image 28 | :return: rotation, scaling and transformation matrix 29 | """ 30 | # normalize image 31 | distance = np.linalg.norm(center_point) # actual distance between center point and original camera 32 | z_scale = distance_norm / distance 33 | 34 | cam_norm = np.array([ 35 | [focal_norm, 0, image_output_size[0] / 2], 36 | [0, focal_norm, image_output_size[1] / 2], 37 | [0, 0, 1.0], 38 | ]) 39 | 40 | scaling_matrix = np.array([ 41 | [1.0, 0.0, 0.0], 42 | [0.0, 1.0, 0.0], 43 | [0.0, 0.0, z_scale], 44 | ]) 45 | 46 | forward = (center_point / distance).reshape(3) 47 | down = np.cross(forward, head_rotation_matrix[:, 0]) 48 | down /= np.linalg.norm(down) 49 | right = np.cross(down, forward) 50 | right /= np.linalg.norm(right) 51 | 52 | rotation_matrix = np.asarray([right, down, forward]) 53 | transformation_matrix = np.dot(np.dot(cam_norm, scaling_matrix), np.dot(rotation_matrix, np.linalg.inv(camera_matrix))) 54 | 55 | return rotation_matrix, scaling_matrix, transformation_matrix 56 | 57 | 58 | def equalize_hist_rgb(rgb_img: np.ndarray) -> np.ndarray: 59 | """ 60 | Equalize the histogram of a RGB image. 61 | 62 | :param rgb_img: RGB image 63 | :return: equalized RGB image 64 | """ 65 | ycrcb_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2YCrCb) # convert from RGB color-space to YCrCb 66 | ycrcb_img[:, :, 0] = cv2.equalizeHist(ycrcb_img[:, :, 0]) # equalize the histogram of the Y channel 67 | equalized_img = cv2.cvtColor(ycrcb_img, cv2.COLOR_YCrCb2RGB) # convert back to RGB color-space from YCrCb 68 | return equalized_img 69 | 70 | 71 | def normalize_single_image(image: np.ndarray, head_rotation, gaze_target: np.ndarray, center_point: np.ndarray, camera_matrix: np.ndarray, is_eye: bool = True) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 72 | """ 73 | The normalization process of a single image, creates a normalized eye image or a face image, depending on `is_eye`. 74 | 75 | :param image: original image 76 | :param head_rotation: rotation of the head 77 | :param gaze_target: 3D target of the gaze 78 | :param center_point: 3D point on the face to focus on 79 | :param camera_matrix: intrinsic camera matrix 80 | :param is_eye: if true the `distance_norm` and `image_output_size` values for the eye are used 81 | :return: normalized image, normalized gaze and rotation matrix 82 | """ 83 | # normalized camera parameters 84 | focal_norm = 960 # focal length of normalized camera 85 | distance_norm = 500 if is_eye else 1600 # normalized distance between eye and camera 86 | image_output_size = (96, 64) if is_eye else (96, 96) # size of cropped eye image 87 | 88 | # compute estimated 3D positions of the landmarks 89 | if gaze_target is not None: 90 | gaze_target = gaze_target.reshape((3, 1)) 91 | 92 | head_rotation_matrix, _ = cv2.Rodrigues(head_rotation) 93 | rotation_matrix, scaling_matrix, transformation_matrix = get_matrices(camera_matrix, distance_norm, center_point, focal_norm, head_rotation_matrix, image_output_size) 94 | 95 | img_warped = cv2.warpPerspective(image, transformation_matrix, image_output_size) # image normalization 96 | img_warped = equalize_hist_rgb(img_warped) # equalizes the histogram (normalization) 97 | 98 | if gaze_target is not None: 99 | # normalize gaze vector 100 | gaze_normalized = gaze_target - center_point # gaze vector 101 | # For modified data normalization, scaling is not applied to gaze direction, so here is only R applied. 102 | gaze_normalized = np.dot(rotation_matrix, gaze_normalized) 103 | gaze_normalized = gaze_normalized / np.linalg.norm(gaze_normalized) 104 | else: 105 | gaze_normalized = np.zeros(3) 106 | 107 | return img_warped, gaze_normalized.reshape(3), rotation_matrix 108 | 109 | 110 | def main(input_path: str, output_path: str): 111 | data = defaultdict(list) 112 | 113 | face_model = scipy.io.loadmat(f'{input_path}/6 points-based face model.mat')['model'] 114 | 115 | for person_idx, person_path in enumerate(tqdm(sorted(glob.glob(f'{input_path}/p*')), desc='person')): 116 | person = person_path.split('/')[-1] 117 | 118 | camera_matrix = scipy.io.loadmat(f'{person_path}/Calibration/Camera.mat')['cameraMatrix'] 119 | screen_size = scipy.io.loadmat(f'{person_path}/Calibration/screenSize.mat') 120 | screen_width_pixel = screen_size["width_pixel"].item() 121 | screen_height_pixel = screen_size["height_pixel"].item() 122 | annotations = pd.read_csv(f'{person_path}/{person}.txt', sep=' ', header=None, index_col=0) 123 | 124 | for day_path in tqdm(sorted(glob.glob(f'{person_path}/day*')), desc='day'): 125 | day = day_path.split('/')[-1] 126 | for image_path in sorted(glob.glob(f'{day_path}/*.jpg')): 127 | annotation = annotations.loc['/'.join(image_path.split('/')[-2:])] 128 | 129 | img = skimage.io.imread(image_path) 130 | height, width, _ = img.shape 131 | 132 | head_rotation = annotation[14:17].to_numpy().reshape(-1).astype(float) # 3D head rotation based on 6 points-based 3D face model 133 | head_translation = annotation[17:20].to_numpy().reshape(-1).astype(float) # 3D head translation based on 6 points-based 3D face model 134 | gaze_target_3d = annotation[23:26].to_numpy().reshape(-1).astype(float) # 3D gaze target position related to camera (on the screen) 135 | 136 | head_rotation_matrix, _ = cv2.Rodrigues(head_rotation) 137 | face_landmarks = np.dot(head_rotation_matrix, face_model) + head_translation.reshape((3, 1)) # 3D positions of facial landmarks 138 | left_eye_center = 0.5 * (face_landmarks[:, 2] + face_landmarks[:, 3]).reshape((3, 1)) # center eye 139 | right_eye_center = 0.5 * (face_landmarks[:, 0] + face_landmarks[:, 1]).reshape((3, 1)) # center eye 140 | face_center = face_landmarks.mean(axis=1).reshape((3, 1)) 141 | 142 | img_warped_left_eye, _, _ = normalize_single_image(img, head_rotation, None, left_eye_center, camera_matrix) 143 | img_warped_right_eye, _, _ = normalize_single_image(img, head_rotation, None, right_eye_center, camera_matrix) 144 | img_warped_face, gaze_normalized, rotation_matrix = normalize_single_image(img, head_rotation, gaze_target_3d, face_center, camera_matrix, is_eye=False) 145 | 146 | # Q&A 2 https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/research/gaze-based-human-computer-interaction/appearance-based-gaze-estimation-in-the-wild 147 | gaze_pitch = np.arcsin(-gaze_normalized[1]) 148 | gaze_yaw = np.arctan2(-gaze_normalized[0], -gaze_normalized[2]) 149 | 150 | base_file_name = f'{person}/{day}/' 151 | pathlib.Path(f"{output_path}/{base_file_name}").mkdir(parents=True, exist_ok=True) 152 | base_file_name += f'{image_path.split("/")[-1][:-4]}' 153 | 154 | skimage.io.imsave(f"{output_path}/{base_file_name}-left_eye.png", img_warped_left_eye.astype(np.uint8), check_contrast=False) 155 | skimage.io.imsave(f"{output_path}/{base_file_name}-right_eye.png", img_warped_right_eye.astype(np.uint8), check_contrast=False) 156 | skimage.io.imsave(f"{output_path}/{base_file_name}-full_face.png", img_warped_face.astype(np.uint8), check_contrast=False) 157 | 158 | data['file_name_base'].append(base_file_name) 159 | data['gaze_pitch'].append(gaze_pitch) 160 | data['gaze_yaw'].append(gaze_yaw) 161 | data['gaze_location'].append(list(annotation[:2])) 162 | data['screen_size'].append([screen_width_pixel, screen_height_pixel]) 163 | 164 | with h5py.File(f'{output_path}/data.h5', 'w') as file: 165 | for key, value in data.items(): 166 | if key == 'file_name_base': # only str 167 | file.create_dataset(key, data=value, compression='gzip', chunks=True) 168 | else: 169 | value = np.asarray(value) 170 | file.create_dataset(key, data=value, shape=value.shape, compression='gzip', chunks=True) 171 | 172 | check_mpii_face_gaze_not_on_screen(args.input_path, args.output_path) 173 | 174 | 175 | if __name__ == '__main__': 176 | parser = ArgumentParser() 177 | parser.add_argument("--input_path", type=str, default='./MPIIFaceGaze') 178 | parser.add_argument("--output_path", type=str, default='./data') 179 | args = parser.parse_args() 180 | 181 | main(args.input_path, args.output_path) 182 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Evaluation of a Monocular Eye Tracking Set-Up 2 | 3 | As part of my master thesis, I implemented a new state-of-the-art model that is based on the work of [Chen et al.](https://doi.org/10.1109/WACV45572.2020.9093419). \ 4 | For 9 calibration samples, the previous state-of-the-art performance can be improved by up to 5.44% (2.553 degrees compared to 2.7 degrees) and for 128 calibration samples, by 7% (2.418 degrees compared to 2.6 degrees). 5 | This is accomplished by (a) improving the extraction of eye features, (b) refining the fusion process of these features, (c) removing erroneous data from the MPIIFaceGaze dataset during training, and (d) optimizing the calibration method. 6 | 7 | A software to [collect own gaze data](https://github.com/pperle/gaze-data-collection) and the [full gaze tracking pipeline](https://github.com/pperle/gaze-tracking-pipeline) is also available. 8 | 9 | ![Results of the different models.](./docs/compare_model_results.svg) 10 | 11 | For the citaitions [1] - [10] please see below. "own model 1" represents the model described in the section below. 12 | "own model 2" uses the same model architecture as "own model 1" but is trained without the erroneous data, see MPIIFaceGaze section below. 13 | "own model 3" is the same as "own model 2" but with the calibrations points organized in a $\sqrt{k}\times\sqrt{k}$ grid instead of randomly on the screen. 14 | 15 | 16 | ## Model 17 | Since the feature extractors share the same weights for both eyes, it has been shown experimentally that the feature extraction process can be improved by flipping one of the eye images so that the noses of all eye images are on the same side. 18 | The main reason for this is that the images of the two eyes are more similar this way and the feature extractor can focus more on the relevant features, rather than the unimportant features, of either the left or the right eye. 19 | 20 | The architectural improvement that has had the most impact is the improved feature fusion process of left and right eye features. 21 | Instead of simply combining the two features, they are combined using Squeeze-and-Excitation (SE) blocks. 22 | This introduces a control mechanism for the channel relationships of the extracted feature maps that the model can learn serially. 23 | 24 | Start training by running `python train.py --path_to_data=./data --validate_on_person=1 --test_on_person=0`. 25 | For pretrained models, please see evaluation section. 26 | 27 | ## Data 28 | While examining and analyzing the most commonly used gaze prediction dataset, [MPIIFaceGaze](https://www.perceptualui.org/research/datasets/MPIIFaceGaze/) a subset of [MPIIGaze](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/research/gaze-based-human-computer-interaction/appearance-based-gaze-estimation-in-the-wild/), in detail. 29 | It was realized that some recorded data does not match the provided screen sizes. 30 | For participant 2, 7, and 10, 0.043%, 8.79%, and 0.39% of the gazes directed at the screen did not match the screen provided, respectively. 31 | The left figure below shows recorded points in the datasets that do not match the provided screen size. 32 | These false target gaze positions are also visible in the right figure below, where the gaze point that are not on the screen have a different yaw offset to the ground truth. 33 | 34 | ![Results of the MPIIFaceGaze analysis](./docs/mpiifacegaze_analysis.png) 35 | 36 | To the best of our knowledge, we are the first to address this problem of this widespread dataset, and we propose to remove all days with any errors for people 2, 7, and 10, resulting in a new dataset we call MPIIFaceGaze-. 37 | This would only reduce the dataset by about 3.2%. As shown in the first figure, see "own model 2", removing these erroneous data improves the model's overall performance. 38 | 39 | For preprocessing MPIIFaceGaze, [download](https://www.perceptualui.org/research/datasets/MPIIFaceGaze/) the original dataset and then 40 | run `python dataset/mpii_face_gaze_preprocessing.py --input_path=./MPIIFaceGaze --output_path=./data`. 41 | Or [download the preprocessed dataset](https://drive.google.com/uc?export=download&id=1eCdULbgtJKmZPRrLoBtIS1mw_IQwH6Zi). 42 | 43 | To only generate the CSV files with all filenames which gaze is not on the screen, run `python dataset/mpii_face_gaze_errors.py --input_path=./MPIIFaceGaze --output_path=./data`. 44 | This can be run on MPIIGaze and MPIIFaceGaze, or the CSV files can be directly downloaded for [MPIIGaze](https://drive.google.com/file/d/1buUCPO0xluVxYxN4FwnupNDP3mpZ3SiN/view?usp=sharing) and [MPIIFaceGaze](https://drive.google.com/file/d/1Cq25df9124q8vkdsJO1BiqjuLSWlXz1r/view?usp=sharing). 45 | 46 | ## Calibration 47 | Nine calibration samples has become the norm for the comparison of different model architectures using MPIIFaceGaze. 48 | When the calibration points are organized in a $\sqrt{k}\times\sqrt{k}$ grid instead of randomly on the screen, or all in one position, the resulting person-specific calibration is more accurate. 49 | The three different ways to distribute the calibration point are compared in the figure below, also see "own model 3" in the first figure. 50 | Nine calibration samples aligned in a grid result in a lower angular error than 9 randomly positioned calibration samples. 51 | 52 | To collect your own calibration data or dataset, please refer to [gaze data collection](https://github.com/pperle/gaze-data-collection). 53 | 54 | ![Comparison of the position of the calibration samples.](./docs/compare_points_on_screen_positions.svg) 55 | 56 | 57 | ## Evaluation 58 | For evaluation, the trained models are evaluated on the full MPIIFaceGaze, including the erroneous data, for a fair comparison to other approaches. 59 | Download the [pretrained "own model 2" models](https://drive.google.com/drive/folders/1-_bOyMgAQmnwRGfQ4QIQk7hrin0Mexch?usp=sharing) and 60 | run `python eval.py --path_to_checkpoints=./pretrained_models --path_to_data=./data` to reproduce the results shown in the figure above and the table below. 61 | `--grid_calibration_samples=True` takes a long time to evaluate, for the ease of use the number of calibration runs is reduced to 500. 62 | 63 | | | random calibration
k=9 | random calibration
k=128 | grid calibration
k=9 | grid calibration
k=128 |
k=all | 64 | |---|---:|---:|---:|---:|---:| 65 | | **p00** | 1.780 | 1.676 | 1.760 | 1.674 | 1.668 | 66 | | **p01** | 1.899 | 1.777 | 1.893 | 1.769 | 1.767 | 67 | | **p02** | 1.910 | 1.790 | 1.875 | 1.787 | 1.780 | 68 | | **p03** | 2.924 | 2.729 | 2.929 | 2.712 | 2.714 | 69 | | **p04** | 2.355 | 2.239 | 2.346 | 2.229 | 2.229 | 70 | | **p05** | 1.836 | 1.720 | 1.826 | 1.721 | 1.711 | 71 | | **p06** | 2.569 | 2.464 | 2.596 | 2.460 | 2.455 | 72 | | **p07** | 3.823 | 3.599 | 3.737 | 3.562 | 3.582 | 73 | | **p08** | 3.778 | 3.508 | 3.637 | 3.501 | 3.484 | 74 | | **p09** | 2.695 | 2.528 | 2.667 | 2.526 | 2.515 | 75 | | **p10** | 3.241 | 3.126 | 3.199 | 3.105 | 3.118 | 76 | | **p11** | 2.668 | 2.535 | 2.667 | 2.536 | 2.524 | 77 | | **p12** | 2.204 | 1.877 | 2.131 | 1.882 | 1.848 | 78 | | **p13** | 2.914 | 2.753 | 2.859 | 2.754 | 2.741 | 79 | | **p14** | 2.161 | 2.010 | 2.172 | 2.052 | 1.998 | 80 | | **mean** | **2.584** | **2.422** | **2.553** | **2.418** | **2.409** | 81 | 82 | 83 | ## Bibliography 84 | [1] Zhaokang Chen and Bertram E. Shi, “Appearance-based gaze estimation using dilated-convolutions”, Lecture Notes in Computer Science, vol. 11366, C. V. Jawahar, Hongdong Li, Greg Mori, and Konrad Schindler, Eds., pp. 309–324, 2018. DOI: 10.1007/978-3-030-20876-9_20. [Online]. Available: https://doi.org/10.1007/978-3-030-20876-9_20. \ 85 | [2] ——, “Offset calibration for appearance-based gaze estimation via gaze decomposition”, in IEEE Winter Conference on Applications of Computer Vision, WACV 2020, Snowmass Village, CO, USA, March 1-5, 2020, IEEE, 2020, pp. 259–268. DOI: 10.1109/WACV45572.2020.9093419. [Online]. Available: https://doi.org/10.1109/WACV45572.2020.9093419. \ 86 | [3] Tobias Fischer, Hyung Jin Chang, and Yiannis Demiris, “RT-GENE: real-time eye gaze estimation in natural environments”, in Computer Vision - ECCV 2018 - 15th European Conference, Munich, Germany, September 8-14, 2018, Proceedings, Part X, Vittorio Ferrari, Martial Hebert, Cristian Sminchisescu, and Yair Weiss, Eds., ser. Lecture Notes in Computer Science, vol. 11214, Springer, 2018, pp. 339–357. DOI: 10.1007/978-3-030-01249-6_21. [Online]. Available: https://doi.org/10.1007/978-3-030-01249-6_21. \ 87 | [4] Erik Lindén, Jonas Sjöstrand, and Alexandre Proutière, “Learning to personalize in appearance-based gaze tracking”, pp. 1140–1148, 2019. DOI: 10.1109/ICCVW.2019.00145. [Online]. Available: https://doi.org/10.1109/ICCVW.2019.00145. \ 88 | [5] Gang Liu, Yu Yu, Kenneth Alberto Funes Mora, and Jean-Marc Odobez, “A differential approach for gaze estimation with calibration”, in British Machine Vision Conference 2018, BMVC 2018, Newcastle, UK, September 3-6, 2018, BMVA Press, 2018, p. 235. [Online]. Available: http://bmvc2018.org/contents/papers/0792.pdf. \ 89 | [6] Seonwook Park, Shalini De Mello, Pavlo Molchanov, Umar Iqbal, Otmar Hilliges, and Jan Kautz, “Few-shot adaptive gaze estimation”, pp. 9367–9376, 2019. DOI: 10.1109/ICCV.2019.00946. [Online]. Available: https://doi.org/10.1109/ICCV.2019.00946. \ 90 | [7] Seonwook Park, Xucong Zhang, Andreas Bulling, and Otmar Hilliges, “Learning to find eye region landmarks for remote gaze estimation in unconstrained settings”, Bonita Sharif and Krzysztof Krejtz, Eds., 21:1–21:10, 2018. DOI: 10.1145/3204493.3204545. [Online]. Available: https://doi.org/10.1145/3204493.3204545. \ 91 | [8] Yu Yu, Gang Liu, and Jean-Marc Odobez, “Improving few-shot user-specific gaze adaptation via gaze redirection synthesis”, pp. 11 937–11 946, 2019. DOI: 10.1109/CVPR.2019.01221. [Online]. Available: http://openaccess.thecvf.com/content_CVPR_2019/html/Yu_Improving_Few-Shot_User-Specific_Gaze_Adaptation_via_Gaze_Redirection_Synthesis_CVPR_2019_paper.html. \ 92 | [9] Xucong Zhang, Yusuke Sugano, Mario Fritz, and Andreas Bulling, “It’s written all over your face: Full-face appearance-based gaze estimation”, pp. 2299–2308, 2017. DOI: 10.1109/CVPRW.2017.284. [Online]. Available: https://doi.org/10.1109/CVPRW.2017.284 \ 93 | [10] ——, “Mpiigaze: Real-world dataset and deep appearance-based gaze estimation”, IEEE Trans. Pattern Anal. Mach. Intell., vol. 41, no. 1, pp. 162–175, 2019. DOI: 10.1109/TPAMI.2017.2778103. [Online]. Available: https://doi.org/10.1109/TPAMI.2017.2778103. \ 94 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from typing import Tuple 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from pytorch_lightning import seed_everything, Trainer 8 | from pytorch_lightning.loggers import TensorBoardLogger 9 | from pytorch_lightning.utilities.types import STEP_OUTPUT, EPOCH_OUTPUT 10 | 11 | from dataset.mpii_face_gaze_dataset import get_dataloaders 12 | from model import FinalModel 13 | from utils import calc_angle_error, PitchYaw, plot_prediction_vs_ground_truth, log_figure, get_random_idx, get_each_of_one_grid_idx 14 | 15 | 16 | class Model(FinalModel): 17 | def __init__(self, learning_rate: float = 0.001, weight_decay: float = 0., k=None, adjust_slope: bool = False, grid_calibration_samples: bool = False, *args, **kwargs): 18 | super().__init__(*args, **kwargs) 19 | self.learning_rate = learning_rate 20 | self.weight_decay = weight_decay 21 | self.k = [9, 128] if k is None else k 22 | self.adjust_slope = adjust_slope 23 | self.grid_calibration_samples = grid_calibration_samples 24 | 25 | self.save_hyperparameters() # log hyperparameters 26 | 27 | def configure_optimizers(self): 28 | return torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 29 | 30 | def __step(self, batch: dict) -> Tuple: 31 | """ 32 | Operates on a single batch of data. 33 | 34 | :param batch: The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. 35 | :return: calculated loss, given values and predicted outputs 36 | """ 37 | person_idx = batch['person_idx'].long() 38 | left_eye_image = batch['left_eye_image'].float() 39 | right_eye_image = batch['right_eye_image'].float() 40 | full_face_image = batch['full_face_image'].float() 41 | 42 | gaze_pitch = batch['gaze_pitch'].float() 43 | gaze_yaw = batch['gaze_yaw'].float() 44 | labels = torch.stack([gaze_pitch, gaze_yaw]).T 45 | 46 | outputs = self(person_idx, full_face_image, right_eye_image, left_eye_image) # prediction on the base model 47 | loss = F.mse_loss(outputs, labels) 48 | 49 | return loss, labels, outputs 50 | 51 | def training_step(self, train_batch: dict, batch_idx: int) -> STEP_OUTPUT: 52 | loss, labels, outputs = self.__step(train_batch) 53 | 54 | self.log('train/loss', loss) 55 | self.log('train/angular_error', calc_angle_error(labels, outputs)) 56 | 57 | return loss 58 | 59 | def validation_step(self, valid_batch: dict, batch_idx: int) -> STEP_OUTPUT: 60 | loss, labels, outputs = self.__step(valid_batch) 61 | 62 | self.log('valid/offset(k=0)/loss', loss) 63 | self.log('valid/offset(k=0)/angular_error', calc_angle_error(labels, outputs)) 64 | 65 | return {'loss': loss, 'labels': labels, 'outputs': outputs, 'gaze_locations': valid_batch['gaze_location'], 'screen_sizes': valid_batch['screen_size']} 66 | 67 | def validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: 68 | self.__log_and_plot_details(outputs, 'valid') 69 | 70 | def test_step(self, test_batch: dict, batch_idx: int) -> STEP_OUTPUT: 71 | loss, labels, outputs = self.__step(test_batch) 72 | 73 | self.log('test/offset(k=0)/loss', loss) 74 | self.log('test/offset(k=0)/angular_error', calc_angle_error(labels, outputs)) 75 | 76 | return {'loss': loss, 'labels': labels, 'outputs': outputs, 'gaze_locations': test_batch['gaze_location'], 'screen_sizes': test_batch['screen_size']} 77 | 78 | def test_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: 79 | self.__log_and_plot_details(outputs, 'test') 80 | 81 | def __log_and_plot_details(self, outputs, tag: str): 82 | test_labels = torch.cat([output['labels'] for output in outputs]) 83 | test_outputs = torch.cat([output['outputs'] for output in outputs]) 84 | test_gaze_locations = torch.cat([output['gaze_locations'] for output in outputs]) 85 | test_screen_sizes = torch.cat([output['screen_sizes'] for output in outputs]) 86 | 87 | figure = plot_prediction_vs_ground_truth(test_labels, test_outputs, PitchYaw.PITCH) 88 | log_figure(self.logger, f'{tag}/offset(k=0)/pitch', figure, self.global_step) 89 | 90 | figure = plot_prediction_vs_ground_truth(test_labels, test_outputs, PitchYaw.YAW) 91 | log_figure(self.logger, f'{tag}/offset(k=0)/yaw', figure, self.global_step) 92 | 93 | # find calibration params 94 | last_x = 500 95 | calibration_train = test_outputs[:-last_x].cpu().detach().numpy() 96 | calibration_test = test_outputs[-last_x:].cpu().detach().numpy() 97 | 98 | calibration_train_labels = test_labels[:-last_x].cpu().detach().numpy() 99 | calibration_test_labels = test_labels[-last_x:].cpu().detach().numpy() 100 | 101 | gaze_locations_train = test_gaze_locations[:-last_x].cpu().detach().numpy() 102 | screen_sizes_train = test_screen_sizes[:-last_x].cpu().detach().numpy() 103 | 104 | if len(calibration_train) > 0: 105 | for k in self.k: 106 | if k <= 0: 107 | continue 108 | calibrated_solutions = [] 109 | 110 | num_calibration_runs = 500 if self.grid_calibration_samples else 10_000 # original results are both evaluated with 10,000 runs 111 | for calibration_run_idx in range(num_calibration_runs): # get_each_of_one_grid_idx is slower than get_random_idx 112 | np.random.seed(42 + calibration_run_idx) 113 | calibration_sample_idxs = get_each_of_one_grid_idx(k, gaze_locations_train, screen_sizes_train) if self.grid_calibration_samples else get_random_idx(k, len(calibration_train)) 114 | calibration_points_x = np.asarray([calibration_train[idx] for idx in calibration_sample_idxs]) 115 | calibration_points_y = np.asarray([calibration_train_labels[idx] for idx in calibration_sample_idxs]) 116 | 117 | if self.adjust_slope: 118 | m, b = np.polyfit(calibration_points_y[:, :1].reshape(-1), calibration_points_x[:, :1].reshape(-1), deg=1) 119 | pitch_fixed = (calibration_test[:, :1] - b) * (1 / m) 120 | m, b = np.polyfit(calibration_points_y[:, 1:].reshape(-1), calibration_points_x[:, 1:].reshape(-1), deg=1) 121 | yaw_fixed = (calibration_test[:, 1:] - b) * (1 / m) 122 | else: 123 | mean_diff_pitch = (calibration_points_y[:, :1] - calibration_points_x[:, :1]).mean() # mean offset 124 | pitch_fixed = calibration_test[:, :1] + mean_diff_pitch 125 | mean_diff_yaw = (calibration_points_y[:, 1:] - calibration_points_x[:, 1:]).mean() # mean offset 126 | yaw_fixed = calibration_test[:, 1:] + mean_diff_yaw 127 | 128 | pitch_fixed, yaw_fixed = torch.Tensor(pitch_fixed), torch.Tensor(yaw_fixed) 129 | outputs_fixed = torch.stack([pitch_fixed, yaw_fixed], dim=1).squeeze(-1) 130 | calibrated_solutions.append(calc_angle_error(torch.Tensor(calibration_test_labels), outputs_fixed).item()) 131 | 132 | self.log(f'{tag}/offset(k={k})/mean_angular_error', np.asarray(calibrated_solutions).mean()) 133 | self.log(f'{tag}/offset(k={k})/std_angular_error', np.asarray(calibrated_solutions).std()) 134 | 135 | # best case, with all calibration samples, all values except the last `last_x` values 136 | if self.adjust_slope: 137 | m, b = np.polyfit(calibration_train_labels[:, :1].reshape(-1), calibration_train[:, :1].reshape(-1), deg=1) 138 | pitch_fixed = torch.Tensor((calibration_test[:, :1] - b) * (1 / m)) 139 | m, b = np.polyfit(calibration_train_labels[:, 1:].reshape(-1), calibration_train[:, 1:].reshape(-1), deg=1) 140 | yaw_fixed = torch.Tensor((calibration_test[:, 1:] - b) * (1 / m)) 141 | else: 142 | mean_diff_pitch = (calibration_train_labels[:, :1] - calibration_train[:, :1]).mean() # mean offset 143 | pitch_fixed = calibration_test[:, :1] + mean_diff_pitch 144 | mean_diff_yaw = (calibration_train_labels[:, 1:] - calibration_train[:, 1:]).mean() # mean offset 145 | yaw_fixed = calibration_test[:, 1:] + mean_diff_yaw 146 | 147 | pitch_fixed, yaw_fixed = torch.Tensor(pitch_fixed), torch.Tensor(yaw_fixed) 148 | outputs_fixed = torch.stack([pitch_fixed, yaw_fixed], dim=1).squeeze(-1) 149 | calibration_test_labels = torch.Tensor(calibration_test_labels) 150 | self.log(f'{tag}/offset(k=all)/angular_error', calc_angle_error(calibration_test_labels, outputs_fixed)) 151 | 152 | figure = plot_prediction_vs_ground_truth(calibration_test_labels, outputs_fixed, PitchYaw.PITCH) 153 | log_figure(self.logger, f'{tag}/offset(k=all)/pitch', figure, self.global_step) 154 | 155 | figure = plot_prediction_vs_ground_truth(calibration_test_labels, outputs_fixed, PitchYaw.YAW) 156 | log_figure(self.logger, f'{tag}/offset(k=all)/yaw', figure, self.global_step) 157 | 158 | 159 | def main(path_to_data: str, validate_on_person: int, test_on_person: int, learning_rate: float, weight_decay: float, batch_size: int, k: int, adjust_slope: bool, grid_calibration_samples: bool): 160 | seed_everything(42) 161 | 162 | model = Model(learning_rate, weight_decay, k, adjust_slope, grid_calibration_samples) 163 | 164 | trainer = Trainer( 165 | gpus=1, 166 | max_epochs=50, 167 | default_root_dir='./saved_models/', 168 | logger=[ 169 | TensorBoardLogger(save_dir="tb_logs"), 170 | ], 171 | benchmark=True, 172 | ) 173 | 174 | train_dataloader, valid_dataloader, test_dataloader = get_dataloaders(path_to_data, validate_on_person, test_on_person, batch_size) 175 | trainer.fit(model, train_dataloader, valid_dataloader) 176 | trainer.test(model, test_dataloader) 177 | 178 | 179 | if __name__ == '__main__': 180 | parser = ArgumentParser() 181 | parser.add_argument("--path_to_data", type=str, default='./data') 182 | parser.add_argument("--validate_on_person", type=int, default=1) 183 | parser.add_argument("--test_on_person", type=int, default=0) 184 | parser.add_argument("--learning_rate", type=float, default=0.001) 185 | parser.add_argument("--weight_decay", type=float, default=0.) 186 | parser.add_argument("--batch_size", type=int, default=64) 187 | parser.add_argument("--k", type=int, default=[9, 128], nargs='+') 188 | parser.add_argument("--adjust_slope", type=bool, default=False) 189 | parser.add_argument("--grid_calibration_samples", type=bool, default=False) 190 | args = parser.parse_args() 191 | 192 | main(args.path_to_data, args.validate_on_person, args.test_on_person, args.learning_rate, args.weight_decay, args.batch_size, args.k, args.adjust_slope, args.grid_calibration_samples) 193 | -------------------------------------------------------------------------------- /docs/compare_points_on_screen_positions.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/compare_model_results.svg: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------