├── docs
│   ├── mpiifacegaze_analysis.png
│   ├── compare_points_on_screen_positions.svg
│   └── compare_model_results.svg
├── requirements.txt
├── LICENSE.md
├── eval.py
├── dataset
│   ├── mpii_face_gaze_errors.py
│   ├── mpii_face_gaze_dataset.py
│   └── mpii_face_gaze_preprocessing.py
├── .gitignore
├── model.py
├── utils.py
├── README.md
└── train.py
/docs/mpiifacegaze_analysis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pperle/gaze-tracking/HEAD/docs/mpiifacegaze_analysis.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy==1.4.1
2 | tqdm==4.61.2
3 | numpy==1.18.5
4 | h5py==2.10.0
5 | torch==1.9.0
6 | opencv_python==4.5.1.48
7 | torchvision==0.10.0
8 | albumentations==1.1.0
9 | matplotlib==3.4.3
10 | pandas==1.3.3
11 | Pillow==8.3.2
12 | pytorch_lightning==1.4.9
13 | scikit-image==0.18.3
14 | torchinfo==1.5.3
15 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 pperle
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | from pytorch_lightning import seed_everything, Trainer
4 |
5 | from dataset.mpii_face_gaze_dataset import get_dataloaders
6 | from train import Model
7 |
8 | if __name__ == '__main__':
9 | parser = ArgumentParser()
10 | parser.add_argument("--path_to_checkpoints", type=str, default='./pretrained_models')
11 | parser.add_argument("--path_to_data", type=str, default='./data')
12 | parser.add_argument("--batch_size", type=int, default=64)
13 | parser.add_argument("--k", type=int, default=[9, 128], nargs='+')
14 | parser.add_argument("--adjust_slope", type=bool, default=False)
15 | parser.add_argument("--grid_calibration_samples", type=bool, default=False)
16 | args = parser.parse_args()
17 |
18 | for person_idx in range(15):
19 | person = f'p{person_idx:02d}'
20 |
21 | seed_everything(42)
22 | print('grid_calibration_samples', args.grid_calibration_samples)
23 | model = Model.load_from_checkpoint(f'{args.path_to_checkpoints}/{person}.ckpt', k=args.k, adjust_slope=args.adjust_slope, grid_calibration_samples=args.grid_calibration_samples)
24 |
25 | trainer = Trainer(
26 | gpus=1,
27 | benchmark=True,
28 | )
29 |
30 | _, _, test_dataloader = get_dataloaders(args.path_to_data, 0, person_idx, args.batch_size)
31 | trainer.test(model, test_dataloader)
32 |
--------------------------------------------------------------------------------
/dataset/mpii_face_gaze_errors.py:
--------------------------------------------------------------------------------
1 | import glob
2 | from argparse import ArgumentParser
3 |
4 | import pandas as pd
5 | import scipy.io
6 |
7 |
8 | def check_mpii_gaze_not_on_screen(input_path: str, output_path: str) -> None:
9 | """
10 | Create CSV file with the filename of the images where the gaze is not on the screen.
11 |
12 | :param input_path: path to the original MPIIGaze dataset
13 |     :param output_path: path where the output CSV file is written
14 | :return:
15 | """
16 |
17 | data = {'file_name': [], 'on_screen_gaze_position': [], 'monitor_pixels': []}
18 |
19 | for person_file_path in sorted(glob.glob(f'{input_path}/Data/Original/p*'), reverse=True):
20 | person = person_file_path.split('/')[-1]
21 |
22 | screen_size = scipy.io.loadmat(f'{input_path}/Data/Original/{person}/Calibration/screenSize.mat')
23 | screen_width_pixel = screen_size["width_pixel"].item()
24 | screen_height_pixel = screen_size["height_pixel"].item()
25 |
26 | for day_file_path in sorted(glob.glob(f'{person_file_path}/d*')):
27 | day = day_file_path.split('/')[-1]
28 |
29 | df = pd.read_csv(f'{day_file_path}/annotation.txt', sep=' ', header=None)
30 | for row_idx in range(len(df)):
31 | row = df.iloc[row_idx]
32 | on_screen_gaze_target = row[24:26].to_numpy().reshape(-1).astype(int)
33 |
34 | if not (0 <= on_screen_gaze_target[0] <= screen_width_pixel and 0 <= on_screen_gaze_target[1] <= screen_height_pixel):
35 | file_name = f'{person}/{day}/{row_idx + 1:04d}.jpg'
36 |
37 | data['file_name'].append(file_name)
38 | data['on_screen_gaze_position'].append(list(on_screen_gaze_target))
39 | data['monitor_pixels'].append([screen_width_pixel, screen_height_pixel])
40 |
41 | pd.DataFrame(data).to_csv(f'{output_path}/not_on_screen.csv', index=False)
42 |
43 |
44 | def check_mpii_face_gaze_not_on_screen(input_path: str, output_path: str) -> None:
45 | """
46 | Create CSV file with the filename of the images where the gaze is not on the screen.
47 |
48 | :param input_path: path to the original MPIIFaceGaze dataset
49 |     :param output_path: path where the output CSV file is written
50 | :return:
51 | """
52 |
53 | data = {'file_name': [], 'on_screen_gaze_position': [], 'monitor_pixels': []}
54 |
55 | for person_file_path in sorted(glob.glob(f'{input_path}/p*')):
56 | person = person_file_path.split('/')[-1]
57 |
58 | screen_size = scipy.io.loadmat(f'{input_path}/{person}/Calibration/screenSize.mat')
59 | screen_width_pixel = screen_size["width_pixel"].item()
60 | screen_height_pixel = screen_size["height_pixel"].item()
61 |
62 | df = pd.read_csv(f'{person_file_path}/{person}.txt', sep=' ', header=None)
63 | df_idx = 0
64 |
65 | for day_file_path in sorted(glob.glob(f'{person_file_path}/d*')):
66 | day = day_file_path.split('/')[-1]
67 |
68 | for image_file_path in sorted(glob.glob(f'{day_file_path}/*.jpg')):
69 | row = df.iloc[df_idx]
70 | on_screen_gaze_target = row[1:3].to_numpy().reshape(-1).astype(int)
71 |
72 | if not (0 <= on_screen_gaze_target[0] <= screen_width_pixel and 0 <= on_screen_gaze_target[1] <= screen_height_pixel):
73 | file_name = f'{person}/{day}/{image_file_path.split("/")[-1]}'
74 |
75 | data['file_name'].append(file_name)
76 | data['on_screen_gaze_position'].append(list(on_screen_gaze_target))
77 | data['monitor_pixels'].append([screen_width_pixel, screen_height_pixel])
78 |
79 | df_idx += 1
80 |
81 | pd.DataFrame(data).to_csv(f'{output_path}/not_on_screen.csv', index=False)
82 |
83 |
84 | if __name__ == '__main__':
85 | parser = ArgumentParser()
86 | parser.add_argument("--input_path", type=str, default='./MPIIFaceGaze')
87 | parser.add_argument("--output_path", type=str, default='./data')
88 | args = parser.parse_args()
89 |
90 |     # check_mpii_gaze_not_on_screen(args.input_path, args.output_path)
91 | check_mpii_face_gaze_not_on_screen(args.input_path, args.output_path)
92 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.toptal.com/developers/gitignore/api/python,pycharm+all,visualstudiocode
3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,pycharm+all,visualstudiocode
4 |
5 | ### PyCharm+all ###
6 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
7 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
8 |
9 | # User-specific stuff
10 | .idea/**/workspace.xml
11 | .idea/**/tasks.xml
12 | .idea/**/usage.statistics.xml
13 | .idea/**/dictionaries
14 | .idea/**/shelf
15 |
16 | # AWS User-specific
17 | .idea/**/aws.xml
18 |
19 | # Generated files
20 | .idea/**/contentModel.xml
21 |
22 | # Sensitive or high-churn files
23 | .idea/**/dataSources/
24 | .idea/**/dataSources.ids
25 | .idea/**/dataSources.local.xml
26 | .idea/**/sqlDataSources.xml
27 | .idea/**/dynamic.xml
28 | .idea/**/uiDesigner.xml
29 | .idea/**/dbnavigator.xml
30 |
31 | # Gradle
32 | .idea/**/gradle.xml
33 | .idea/**/libraries
34 |
35 | # Gradle and Maven with auto-import
36 | # When using Gradle or Maven with auto-import, you should exclude module files,
37 | # since they will be recreated, and may cause churn. Uncomment if using
38 | # auto-import.
39 | # .idea/artifacts
40 | # .idea/compiler.xml
41 | # .idea/jarRepositories.xml
42 | # .idea/modules.xml
43 | # .idea/*.iml
44 | # .idea/modules
45 | # *.iml
46 | # *.ipr
47 |
48 | # CMake
49 | cmake-build-*/
50 |
51 | # Mongo Explorer plugin
52 | .idea/**/mongoSettings.xml
53 |
54 | # File-based project format
55 | *.iws
56 |
57 | # IntelliJ
58 | out/
59 |
60 | # mpeltonen/sbt-idea plugin
61 | .idea_modules/
62 |
63 | # JIRA plugin
64 | atlassian-ide-plugin.xml
65 |
66 | # Cursive Clojure plugin
67 | .idea/replstate.xml
68 |
69 | # Crashlytics plugin (for Android Studio and IntelliJ)
70 | com_crashlytics_export_strings.xml
71 | crashlytics.properties
72 | crashlytics-build.properties
73 | fabric.properties
74 |
75 | # Editor-based Rest Client
76 | .idea/httpRequests
77 |
78 | # Android studio 3.1+ serialized cache file
79 | .idea/caches/build_file_checksums.ser
80 |
81 | ### PyCharm+all Patch ###
82 | # Ignores the whole .idea folder and all .iml files
83 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360
84 |
85 | .idea/
86 |
87 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023
88 |
89 | *.iml
90 | modules.xml
91 | .idea/misc.xml
92 | *.ipr
93 |
94 | # Sonarlint plugin
95 | .idea/sonarlint
96 |
97 | ### Python ###
98 | # Byte-compiled / optimized / DLL files
99 | __pycache__/
100 | *.py[cod]
101 | *$py.class
102 |
103 | # C extensions
104 | *.so
105 |
106 | # Distribution / packaging
107 | .Python
108 | build/
109 | develop-eggs/
110 | dist/
111 | downloads/
112 | eggs/
113 | .eggs/
114 | lib/
115 | lib64/
116 | parts/
117 | sdist/
118 | var/
119 | wheels/
120 | share/python-wheels/
121 | *.egg-info/
122 | .installed.cfg
123 | *.egg
124 | MANIFEST
125 |
126 | # PyInstaller
127 | # Usually these files are written by a python script from a template
128 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
129 | *.manifest
130 | *.spec
131 |
132 | # Installer logs
133 | pip-log.txt
134 | pip-delete-this-directory.txt
135 |
136 | # Unit test / coverage reports
137 | htmlcov/
138 | .tox/
139 | .nox/
140 | .coverage
141 | .coverage.*
142 | .cache
143 | nosetests.xml
144 | coverage.xml
145 | *.cover
146 | *.py,cover
147 | .hypothesis/
148 | .pytest_cache/
149 | cover/
150 |
151 | # Translations
152 | *.mo
153 | *.pot
154 |
155 | # Django stuff:
156 | *.log
157 | local_settings.py
158 | db.sqlite3
159 | db.sqlite3-journal
160 |
161 | # Flask stuff:
162 | instance/
163 | .webassets-cache
164 |
165 | # Scrapy stuff:
166 | .scrapy
167 |
168 | # Sphinx documentation
169 | docs/_build/
170 |
171 | # PyBuilder
172 | .pybuilder/
173 | target/
174 |
175 | # Jupyter Notebook
176 | .ipynb_checkpoints
177 |
178 | # IPython
179 | profile_default/
180 | ipython_config.py
181 |
182 | # pyenv
183 | # For a library or package, you might want to ignore these files since the code is
184 | # intended to run in multiple environments; otherwise, check them in:
185 | # .python-version
186 |
187 | # pipenv
188 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
189 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
190 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
191 | # install all needed dependencies.
192 | #Pipfile.lock
193 |
194 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
195 | __pypackages__/
196 |
197 | # Celery stuff
198 | celerybeat-schedule
199 | celerybeat.pid
200 |
201 | # SageMath parsed files
202 | *.sage.py
203 |
204 | # Environments
205 | .env
206 | .venv
207 | env/
208 | venv/
209 | ENV/
210 | env.bak/
211 | venv.bak/
212 |
213 | # Spyder project settings
214 | .spyderproject
215 | .spyproject
216 |
217 | # Rope project settings
218 | .ropeproject
219 |
220 | # mkdocs documentation
221 | /site
222 |
223 | # mypy
224 | .mypy_cache/
225 | .dmypy.json
226 | dmypy.json
227 |
228 | # Pyre type checker
229 | .pyre/
230 |
231 | # pytype static type analyzer
232 | .pytype/
233 |
234 | # Cython debug symbols
235 | cython_debug/
236 |
237 | ### VisualStudioCode ###
238 | .vscode/*
239 | !.vscode/settings.json
240 | !.vscode/tasks.json
241 | !.vscode/launch.json
242 | !.vscode/extensions.json
243 | *.code-workspace
244 |
245 | # Local History for Visual Studio Code
246 | .history/
247 |
248 | ### VisualStudioCode Patch ###
249 | # Ignore all local history of files
250 | .history
251 | .ionide
252 |
253 | # End of https://www.toptal.com/developers/gitignore/api/python,pycharm+all,visualstudiocode
254 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_lightning import LightningModule
3 | from torch import nn
4 | from torchinfo import summary
5 | from torchvision import models
6 |
7 |
8 | class SELayer(nn.Module):
9 | """
10 | Squeeze-and-Excitation layer
11 |
12 | https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py
13 | """
14 |
15 | def __init__(self, channel, reduction=16):
16 | super(SELayer, self).__init__()
17 | self.avg_pool = nn.AdaptiveAvgPool2d(1) # Squeeze
18 | self.fc = nn.Sequential( # Excitation (similar to attention)
19 | nn.Linear(channel, channel // reduction, bias=False),
20 | nn.ReLU(inplace=True),
21 | nn.Linear(channel // reduction, channel, bias=False),
22 | nn.Sigmoid()
23 | )
24 |
25 | def forward(self, x):
26 | b, c, _, _ = x.size()
27 | y = self.avg_pool(x).view(b, c)
28 | y = self.fc(y).view(b, c, 1, 1)
29 | return x * y.expand_as(x)
30 |
31 |
32 | class FinalModel(LightningModule):
33 | def __init__(self, *args, **kwargs):
34 | super().__init__(*args, **kwargs)
35 |
36 | self.subject_biases = nn.Parameter(torch.zeros(15 * 2, 2)) # pitch and yaw offset for the original and mirrored participant
37 |
38 | self.cnn_face = nn.Sequential(
39 | models.vgg16(pretrained=True).features[:9], # first four convolutional layers of VGG16 pretrained on ImageNet
40 | nn.Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), padding='same'),
41 | nn.ReLU(inplace=True),
42 | nn.BatchNorm2d(64),
43 | nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding='valid', dilation=(2, 2)),
44 | nn.ReLU(inplace=True),
45 | nn.BatchNorm2d(64),
46 | nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding='valid', dilation=(3, 3)),
47 | nn.ReLU(inplace=True),
48 | nn.BatchNorm2d(64),
49 | nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding='valid', dilation=(5, 5)),
50 | nn.ReLU(inplace=True),
51 | nn.BatchNorm2d(128),
52 | nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding='valid', dilation=(11, 11)),
53 | nn.ReLU(inplace=True),
54 | nn.BatchNorm2d(128),
55 | )
56 |
57 | self.cnn_eye = nn.Sequential(
58 | models.vgg16(pretrained=True).features[:9], # first four convolutional layers of VGG16 pretrained on ImageNet
59 | nn.Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), padding='same'),
60 | nn.ReLU(inplace=True),
61 | nn.BatchNorm2d(64),
62 | nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding='valid', dilation=(2, 2)),
63 | nn.ReLU(inplace=True),
64 | nn.BatchNorm2d(64),
65 | nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding='valid', dilation=(3, 3)),
66 | nn.ReLU(inplace=True),
67 | nn.BatchNorm2d(64),
68 | nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding='valid', dilation=(4, 5)),
69 | nn.ReLU(inplace=True),
70 | nn.BatchNorm2d(128),
71 | nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding='valid', dilation=(5, 11)),
72 | nn.ReLU(inplace=True),
73 | nn.BatchNorm2d(128),
74 | )
75 |
76 | self.fc_face = nn.Sequential(
77 | nn.Flatten(),
78 | nn.Linear(6 * 6 * 128, 256),
79 | nn.ReLU(inplace=True),
80 | nn.BatchNorm1d(256),
81 | nn.Linear(256, 64),
82 | nn.ReLU(inplace=True),
83 | nn.BatchNorm1d(64),
84 | )
85 |
86 | self.cnn_eye2fc = nn.Sequential(
87 | SELayer(256),
88 |
89 | nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding='same'),
90 | nn.ReLU(inplace=True),
91 | nn.BatchNorm2d(256),
92 |
93 | SELayer(256),
94 |
95 | nn.Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding='same'),
96 | nn.ReLU(inplace=True),
97 | nn.BatchNorm2d(128),
98 |
99 | SELayer(128),
100 | )
101 |
102 | self.fc_eye = nn.Sequential(
103 | nn.Flatten(),
104 | nn.Linear(4 * 6 * 128, 512),
105 | nn.ReLU(inplace=True),
106 | nn.BatchNorm1d(512),
107 | )
108 |
109 | self.fc_eyes_face = nn.Sequential(
110 | nn.Dropout(p=0.5),
111 | nn.Linear(576, 256),
112 | nn.ReLU(inplace=True),
113 | nn.BatchNorm1d(256),
114 | nn.Dropout(p=0.5),
115 | nn.Linear(256, 2),
116 | )
117 |
118 | def forward(self, person_idx: torch.Tensor, full_face: torch.Tensor, right_eye: torch.Tensor, left_eye: torch.Tensor):
119 | out_cnn_face = self.cnn_face(full_face)
120 | out_fc_face = self.fc_face(out_cnn_face)
121 |
122 | out_cnn_right_eye = self.cnn_eye(right_eye)
123 | out_cnn_left_eye = self.cnn_eye(left_eye)
124 | out_cnn_eye = torch.cat((out_cnn_right_eye, out_cnn_left_eye), dim=1)
125 |
126 | cnn_eye2fc_out = self.cnn_eye2fc(out_cnn_eye) # feature fusion
127 | out_fc_eye = self.fc_eye(cnn_eye2fc_out)
128 |
129 | fc_concatenated = torch.cat((out_fc_face, out_fc_eye), dim=1)
130 | t_hat = self.fc_eyes_face(fc_concatenated) # subject-independent term
131 |
132 | return t_hat + self.subject_biases[person_idx].squeeze(1) # t_hat + subject-dependent bias term
133 |
134 |
135 | if __name__ == '__main__':
136 | model = FinalModel()
137 | model.summarize(max_depth=1)
138 |
139 | print(model.cnn_face)
140 |
141 | batch_size = 16
142 | summary(model, [
143 | (batch_size, 1),
144 | (batch_size, 3, 96, 96), # full face
145 | (batch_size, 3, 64, 96), # right eye
146 | (batch_size, 3, 64, 96) # left eye
147 | ], dtypes=[torch.long, torch.float, torch.float, torch.float])
148 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import List
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | import torchvision
8 | from matplotlib import pyplot as plt
9 | from pytorch_lightning.loggers import TensorBoardLogger
10 | from PIL import Image
11 | import io
12 |
13 |
14 | class PitchYaw(Enum):
15 | PITCH = 'pitch'
16 | YAW = 'yaw'
17 |
18 |
19 | def pitchyaw_to_3d_vector(pitchyaw: torch.Tensor) -> torch.Tensor:
20 | """
21 | 2D pitch and yaw value to a 3D vector
22 |
23 | :param pitchyaw: 2D gaze value in pitch and yaw
24 | :return: 3D vector
25 | """
26 | return torch.stack([
27 | -torch.cos(pitchyaw[:, 0]) * torch.sin(pitchyaw[:, 1]),
28 | -torch.sin(pitchyaw[:, 0]),
29 | -torch.cos(pitchyaw[:, 0]) * torch.cos(pitchyaw[:, 1])
30 | ], dim=1)
31 |
32 |
33 | def calc_angle_error(labels: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
34 | """
35 | Calculate the angle between `labels` and `outputs` in degrees.
36 |
37 | :param labels: ground truth gaze vectors
38 | :param outputs: predicted gaze vectors
39 | :return: Mean angle in degrees.
40 | """
41 | labels = pitchyaw_to_3d_vector(labels)
42 | labels_norm = labels / torch.linalg.norm(labels, axis=1).reshape((-1, 1))
43 |
44 | outputs = pitchyaw_to_3d_vector(outputs)
45 | outputs_norm = outputs / torch.linalg.norm(outputs, axis=1).reshape((-1, 1))
46 |
47 | angles = F.cosine_similarity(outputs_norm, labels_norm, dim=1)
48 |     angles = torch.clip(angles, -1.0, 1.0)  # clip to [-1.0, 1.0] so arccos does not return NaN for values slightly outside this range
49 |
50 | rad = torch.arccos(angles)
51 | return torch.rad2deg(rad).mean()
52 |
53 |
54 | def plot_prediction_vs_ground_truth(labels, outputs, axis: PitchYaw):
55 | """
56 | Create a plot between the predictions and the ground truth values.
57 |
58 | :param labels: ground truth values
59 | :param outputs: predicted values
60 |     :param axis: whether to plot pitch or yaw
61 | :return: scatter plot of predictions and the ground truth values
62 | """
63 |
64 | labels = torch.rad2deg(labels)
65 | outputs = torch.rad2deg(outputs)
66 |
67 | if axis == PitchYaw.PITCH:
68 | plt.scatter(labels[:, :1].cpu().detach().numpy().reshape(-1), outputs[:, :1].cpu().detach().numpy().reshape(-1))
69 | else:
70 | plt.scatter(labels[:, 1:].cpu().detach().numpy().reshape(-1), outputs[:, 1:].cpu().detach().numpy().reshape(-1))
71 | plt.plot([-30, 30], [-30, 30], color='#ff7f0e')
72 | plt.xlabel('ground truth (degrees)')
73 |     plt.ylabel('prediction (degrees)')
74 | plt.title(axis.value)
75 | if axis == PitchYaw.PITCH:
76 | plt.xlim((-30, 5))
77 | plt.ylim((-30, 5))
78 | else:
79 | plt.xlim((-30, 30))
80 | plt.ylim((-30, 30))
81 |
82 | return plt.gcf()
83 |
84 |
85 | def plot_to_image(fig) -> torch.Tensor:
86 | """
87 | Converts the matplotlib plot specified by 'figure' to a PNG image and
88 | returns it. The supplied figure is closed and inaccessible after this call.
89 |
90 | :param fig: matplotlib figure
91 | :return: plot for torchvision
92 | """
93 |
94 | # Save the plot to a PNG in memory.
95 | buf = io.BytesIO()
96 | plt.savefig(buf, format="png")
97 | plt.close(fig)
98 | buf.seek(0)
99 |
100 | image = Image.open(buf).convert("RGB")
101 | image = torchvision.transforms.ToTensor()(image)
102 | return image
103 |
104 |
105 | def log_figure(loggers: List, tag: str, figure, global_step: int) -> None:
106 | """
107 | Log figure as image. Only works for `TensorBoardLogger`.
108 |
109 | :param loggers:
110 | :param tag:
111 | :param figure:
112 | :param global_step:
113 | :return:
114 | """
115 |
116 | if isinstance(loggers, list):
117 | for logger in loggers:
118 | if isinstance(logger, TensorBoardLogger):
119 | logger.experiment.add_image(tag, plot_to_image(figure), global_step, dataformats="CHW")
120 | elif isinstance(loggers, TensorBoardLogger):
121 | loggers.experiment.add_image(tag, plot_to_image(figure), global_step, dataformats="CHW")
122 |
123 |
124 | def get_random_idx(k: int, size: int) -> np.ndarray:
125 | """
126 | Get `k` random values of a list of size `size`.
127 |
128 |     :param k: number of random values
129 | :param size: total number of values
130 | :return: list of `k` random values
131 | """
132 | return (np.random.rand(k) * size).astype(int)
133 |
134 |
135 | def get_each_of_one_grid_idx(k: int, gaze_locations: np.ndarray, screen_sizes: np.ndarray) -> List[int]:
136 |     r"""
137 |     Get one random sample from each cell of a $\sqrt{k}\times\sqrt{k}$ grid laid over the screen.
138 | 
139 |     :param k: number of random values
140 |     :param gaze_locations: list of the position on the screen in pixels for each gaze value
141 |     :param screen_sizes: list of the screen sizes in pixels for each gaze value
142 |     :return: list of `k` random sample indices
143 |     """
144 | grids = int(np.sqrt(k)) # get grid size from k
145 |
146 | grid_width = screen_sizes[0][0] / grids
147 |     grid_height = screen_sizes[0][1] / grids
148 |
149 | gaze_locations = np.asarray(gaze_locations)
150 |
151 | valid_random_idx = []
152 |
153 | for width_range in range(grids):
154 | filter_width = (grid_width * width_range < gaze_locations[:, :1]) & (gaze_locations[:, :1] < grid_width * (width_range + 1))
155 | for height_range in range(grids):
156 |             filter_height = (grid_height * height_range < gaze_locations[:, 1:]) & (gaze_locations[:, 1:] < grid_height * (height_range + 1))
157 | complete_filter = filter_width & filter_height
158 | complete_filter = complete_filter.reshape(-1)
159 | if sum(complete_filter) > 0:
160 | true_idxs = np.argwhere(complete_filter)
161 | random_idx = (np.random.rand(1) * len(true_idxs)).astype(int).item()
162 | valid_random_idx.append(true_idxs[random_idx].item())
163 |
164 | if len(valid_random_idx) != k:
165 | # fill missing calibration samples
166 | missing_k = k - len(valid_random_idx)
167 | missing_idxs = (np.random.rand(missing_k) * len(gaze_locations)).astype(int)
168 | for missing_idx in missing_idxs:
169 | valid_random_idx.append(missing_idx.item())
170 |
171 | return valid_random_idx
172 |
--------------------------------------------------------------------------------
/dataset/mpii_face_gaze_dataset.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from typing import List, Tuple
3 |
4 | import albumentations as A
5 | import h5py
6 | import numpy as np
7 | import pandas as pd
8 | import skimage.io
9 | import torch
10 | from albumentations.pytorch import ToTensorV2
11 | from torch.utils.data import DataLoader
12 | from torch.utils.data import Dataset
13 |
14 |
15 | def filter_persons_by_idx(file_names: List[str], keep_person_idxs: List[int]) -> List[int]:
16 | """
17 | Only keep idx that match the person ids in `keep_person_idxs`.
18 |
19 | :param file_names: list of all file names
20 | :param keep_person_idxs: list of person ids to keep
21 | :return: list of valid idxs that match `keep_person_idxs`
22 | """
23 | idx_per_person = [[] for _ in range(15)]
24 | if keep_person_idxs is not None:
25 | keep_person_idxs = [f'p{person_idx:02d}/' for person_idx in set(keep_person_idxs)]
26 | for idx, file_name in enumerate(file_names):
27 | if any(keep_person_idx in file_name for keep_person_idx in keep_person_idxs): # is a valid person_idx ?
28 | person_idx = int(file_name.split('/')[-3][1:])
29 | idx_per_person[person_idx].append(idx)
30 | else:
31 | for idx, file_name in enumerate(file_names):
32 | person_idx = int(file_name.split('/')[-3][1:])
33 | idx_per_person[person_idx].append(idx)
34 |
35 | return list(itertools.chain(*idx_per_person)) # flatten list
36 |
37 |
38 | def remove_error_data(data_path: str, file_names: List[str]) -> List[int]:
39 | """
40 |     Remove erroneous data, where the gaze point is not on the screen.
41 |
42 | :param data_path: path to the dataset including the `not_on_screen.csv` file
43 | :param file_names: list of all file names
44 | :return: list of idxs of valid data
45 | """
46 | valid_idxs = []
47 |
48 | df = pd.read_csv(f'{data_path}/not_on_screen.csv')
49 | error_file_names = set([error_file_name[:-8] for error_file_name in df['file_name'].tolist()])
50 | file_names = [file_name[:-4] for file_name in file_names]
51 | for idx, file_name in enumerate(file_names):
52 | if file_name not in error_file_names:
53 | valid_idxs.append(idx)
54 |
55 | return valid_idxs
56 |
57 |
58 | class MPIIFaceGaze(Dataset):
59 | """
60 | MPIIFaceGaze dataset with offline preprocessing (= already preprocessed)
61 | """
62 |
63 | def __init__(self, data_path: str, file_name: str, keep_person_idxs: List[int], transform=None, train: bool = False, force_flip: bool = False, use_erroneous_data: bool = False):
64 | if keep_person_idxs is not None:
65 | assert len(keep_person_idxs) > 0
66 | assert max(keep_person_idxs) <= 14 # last person id = 14
67 | assert min(keep_person_idxs) >= 0 # first person id = 0
68 |
69 | self.data_path = data_path
70 | self.hdf5_file_name = f'{data_path}/{file_name}'
71 | self.h5_file = None
72 |
73 | self.transform = transform
74 | self.train = train
75 | self.force_flip = force_flip
76 |
77 | with h5py.File(self.hdf5_file_name, 'r') as f:
78 | file_names = [file_name.decode('utf-8') for file_name in f['file_name_base']]
79 |
80 | if not train:
81 | by_person_idx = filter_persons_by_idx(file_names, keep_person_idxs)
82 | self.idx2ValidIdx = by_person_idx
83 | else:
84 | by_person_idx = filter_persons_by_idx(file_names, keep_person_idxs)
85 |             non_error_idx = list(range(len(file_names))) if use_erroneous_data else remove_error_data(data_path, file_names)  # keep every index when erroneous data is used
86 | self.idx2ValidIdx = list(set(by_person_idx) & set(non_error_idx))
87 |
88 | def __len__(self) -> int:
89 | return len(self.idx2ValidIdx) * 2 if self.train else len(self.idx2ValidIdx)
90 |
91 | def __del__(self):
92 | if self.h5_file is not None:
93 | self.h5_file.close()
94 |
95 | def open_hdf5(self):
96 | self.h5_file = h5py.File(self.hdf5_file_name, 'r')
97 |
98 | def __getitem__(self, idx):
99 | if torch.is_tensor(idx):
100 | idx = idx.tolist()
101 |
102 | if self.h5_file is None:
103 | self.open_hdf5()
104 |
105 | augmented_person = idx >= len(self.idx2ValidIdx)
106 | if augmented_person:
107 | idx -= len(self.idx2ValidIdx) # fix idx
108 |
109 | idx = self.idx2ValidIdx[idx]
110 |
111 | file_name = self.h5_file['file_name_base'][idx].decode('utf-8')
112 | gaze_location = self.h5_file['gaze_location'][idx]
113 | screen_size = self.h5_file['screen_size'][idx]
114 |
115 | person_idx = int(file_name.split('/')[-3][1:])
116 |
117 | left_eye_image = skimage.io.imread(f"{self.data_path}/{file_name}-left_eye.png")
118 |         left_eye_image = np.flip(left_eye_image, axis=1)  # flip the left eye so the nose is on the same side as in the right eye image
119 | right_eye_image = skimage.io.imread(f"{self.data_path}/{file_name}-right_eye.png")
120 | full_face_image = skimage.io.imread(f"{self.data_path}/{file_name}-full_face.png")
121 | gaze_pitch = np.array(self.h5_file['gaze_pitch'][idx])
122 | gaze_yaw = np.array(self.h5_file['gaze_yaw'][idx])
123 |
124 | if augmented_person or self.force_flip:
125 | person_idx += 15 # fix person_idx
126 | left_eye_image = np.flip(left_eye_image, axis=1)
127 | right_eye_image = np.flip(right_eye_image, axis=1)
128 | full_face_image = np.flip(full_face_image, axis=1)
129 | gaze_yaw *= -1
130 |
131 | if self.transform:
132 | left_eye_image = self.transform(image=left_eye_image)["image"]
133 | right_eye_image = self.transform(image=right_eye_image)["image"]
134 | full_face_image = self.transform(image=full_face_image)["image"]
135 |
136 | return {
137 | 'file_name': file_name,
138 | 'gaze_location': gaze_location,
139 | 'screen_size': screen_size,
140 |
141 | 'person_idx': person_idx,
142 |
143 | 'left_eye_image': left_eye_image,
144 | 'right_eye_image': right_eye_image,
145 | 'full_face_image': full_face_image,
146 |
147 | 'gaze_pitch': gaze_pitch,
148 | 'gaze_yaw': gaze_yaw,
149 | }
150 |
151 |
152 | def get_dataloaders(path_to_data: str, validate_on_person: int, test_on_person: int, batch_size: int) -> Tuple[DataLoader, DataLoader, DataLoader]:
153 | """
154 |     Create the train, valid and test dataloaders.
155 | The train dataset includes all persons except `validate_on_person` and `test_on_person`.
156 |
157 | :param path_to_data: path to dataset
158 | :param validate_on_person: person id to validate on during training
159 | :param test_on_person: person id to test on after training
160 | :param batch_size: batch size
161 |     :return: train, valid and test dataloaders
162 | """
163 | transform = {
164 | 'train': A.Compose([
165 | A.ShiftScaleRotate(p=1.0, shift_limit=0.2, scale_limit=0.1, rotate_limit=10),
166 | A.Normalize(),
167 | ToTensorV2()
168 | ]),
169 | 'valid': A.Compose([
170 | A.Normalize(),
171 | ToTensorV2()
172 | ])
173 | }
174 |
175 | train_on_persons = list(range(0, 15))
176 | if validate_on_person in train_on_persons:
177 | train_on_persons.remove(validate_on_person)
178 | if test_on_person in train_on_persons:
179 | train_on_persons.remove(test_on_person)
180 | print('train on persons', train_on_persons)
181 | print('valid on person', validate_on_person)
182 | print('test on person', test_on_person)
183 |
184 | dataset_train = MPIIFaceGaze(path_to_data, 'data.h5', keep_person_idxs=train_on_persons, transform=transform['train'], train=True)
185 | print('len(dataset_train)', len(dataset_train))
186 | train_dataloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=4)
187 |
188 | dataset_valid = MPIIFaceGaze(path_to_data, 'data.h5', keep_person_idxs=[validate_on_person], transform=transform['valid'])
189 |     print('len(dataset_valid)', len(dataset_valid))
190 | valid_dataloader = DataLoader(dataset_valid, batch_size=batch_size, shuffle=False, num_workers=4)
191 |
192 | dataset_test = MPIIFaceGaze(path_to_data, 'data.h5', keep_person_idxs=[test_on_person], transform=transform['valid'], use_erroneous_data=True)
193 |     print('len(dataset_test)', len(dataset_test))
194 | test_dataloader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=4)
195 |
196 | return train_dataloader, valid_dataloader, test_dataloader
197 |
--------------------------------------------------------------------------------
/dataset/mpii_face_gaze_preprocessing.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import pathlib
3 | from argparse import ArgumentParser
4 | from collections import defaultdict
5 | from typing import Tuple
6 |
7 | import cv2
8 | import h5py
9 | import numpy as np
10 | import pandas as pd
11 | import scipy.io
12 | import skimage.io
13 | from tqdm import tqdm
14 |
15 | from dataset.mpii_face_gaze_errors import check_mpii_face_gaze_not_on_screen
16 |
17 |
18 | def get_matrices(camera_matrix: np.ndarray, distance_norm: int, center_point: np.ndarray, focal_norm: int, head_rotation_matrix: np.ndarray, image_output_size: Tuple[int, int]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
19 | """
20 | Calculate rotation, scaling and transformation matrix.
21 |
22 | :param camera_matrix: intrinsic camera matrix
23 | :param distance_norm: normalized distance of the camera
24 | :param center_point: position of the center in the image
25 | :param focal_norm: normalized focal length
26 | :param head_rotation_matrix: rotation of the head
27 |     :param image_output_size: size (width, height) of the output image
28 | :return: rotation, scaling and transformation matrix
29 | """
30 | # normalize image
31 | distance = np.linalg.norm(center_point) # actual distance between center point and original camera
32 | z_scale = distance_norm / distance
33 |
34 | cam_norm = np.array([
35 | [focal_norm, 0, image_output_size[0] / 2],
36 | [0, focal_norm, image_output_size[1] / 2],
37 | [0, 0, 1.0],
38 | ])
39 |
40 | scaling_matrix = np.array([
41 | [1.0, 0.0, 0.0],
42 | [0.0, 1.0, 0.0],
43 | [0.0, 0.0, z_scale],
44 | ])
45 |
46 | forward = (center_point / distance).reshape(3)
47 | down = np.cross(forward, head_rotation_matrix[:, 0])
48 | down /= np.linalg.norm(down)
49 | right = np.cross(down, forward)
50 | right /= np.linalg.norm(right)
51 |
52 | rotation_matrix = np.asarray([right, down, forward])
53 | transformation_matrix = np.dot(np.dot(cam_norm, scaling_matrix), np.dot(rotation_matrix, np.linalg.inv(camera_matrix)))
54 |
55 | return rotation_matrix, scaling_matrix, transformation_matrix
56 |
57 |
58 | def equalize_hist_rgb(rgb_img: np.ndarray) -> np.ndarray:
59 | """
60 | Equalize the histogram of a RGB image.
61 |
62 | :param rgb_img: RGB image
63 | :return: equalized RGB image
64 | """
65 | ycrcb_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2YCrCb) # convert from RGB color-space to YCrCb
66 | ycrcb_img[:, :, 0] = cv2.equalizeHist(ycrcb_img[:, :, 0]) # equalize the histogram of the Y channel
67 | equalized_img = cv2.cvtColor(ycrcb_img, cv2.COLOR_YCrCb2RGB) # convert back to RGB color-space from YCrCb
68 | return equalized_img
69 |
70 |
71 | def normalize_single_image(image: np.ndarray, head_rotation, gaze_target: np.ndarray, center_point: np.ndarray, camera_matrix: np.ndarray, is_eye: bool = True) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
72 | """
73 | The normalization process of a single image, creates a normalized eye image or a face image, depending on `is_eye`.
74 |
75 | :param image: original image
76 | :param head_rotation: rotation of the head
77 | :param gaze_target: 3D target of the gaze
78 | :param center_point: 3D point on the face to focus on
79 | :param camera_matrix: intrinsic camera matrix
80 | :param is_eye: if true the `distance_norm` and `image_output_size` values for the eye are used
81 | :return: normalized image, normalized gaze and rotation matrix
82 | """
83 | # normalized camera parameters
84 | focal_norm = 960 # focal length of normalized camera
85 |     distance_norm = 500 if is_eye else 1600  # normalized distance between the camera and the eye or face center
86 |     image_output_size = (96, 64) if is_eye else (96, 96)  # size (width, height) of the cropped eye or face image
87 |
88 | # compute estimated 3D positions of the landmarks
89 | if gaze_target is not None:
90 | gaze_target = gaze_target.reshape((3, 1))
91 |
92 | head_rotation_matrix, _ = cv2.Rodrigues(head_rotation)
93 | rotation_matrix, scaling_matrix, transformation_matrix = get_matrices(camera_matrix, distance_norm, center_point, focal_norm, head_rotation_matrix, image_output_size)
94 |
95 | img_warped = cv2.warpPerspective(image, transformation_matrix, image_output_size) # image normalization
96 | img_warped = equalize_hist_rgb(img_warped) # equalizes the histogram (normalization)
97 |
98 | if gaze_target is not None:
99 | # normalize gaze vector
100 | gaze_normalized = gaze_target - center_point # gaze vector
101 | # For modified data normalization, scaling is not applied to gaze direction, so here is only R applied.
102 | gaze_normalized = np.dot(rotation_matrix, gaze_normalized)
103 | gaze_normalized = gaze_normalized / np.linalg.norm(gaze_normalized)
104 | else:
105 | gaze_normalized = np.zeros(3)
106 |
107 | return img_warped, gaze_normalized.reshape(3), rotation_matrix
108 |
109 |
110 | def main(input_path: str, output_path: str):
111 | data = defaultdict(list)
112 |
113 | face_model = scipy.io.loadmat(f'{input_path}/6 points-based face model.mat')['model']
114 |
115 | for person_idx, person_path in enumerate(tqdm(sorted(glob.glob(f'{input_path}/p*')), desc='person')):
116 | person = person_path.split('/')[-1]
117 |
118 | camera_matrix = scipy.io.loadmat(f'{person_path}/Calibration/Camera.mat')['cameraMatrix']
119 | screen_size = scipy.io.loadmat(f'{person_path}/Calibration/screenSize.mat')
120 | screen_width_pixel = screen_size["width_pixel"].item()
121 | screen_height_pixel = screen_size["height_pixel"].item()
122 | annotations = pd.read_csv(f'{person_path}/{person}.txt', sep=' ', header=None, index_col=0)
123 |
124 | for day_path in tqdm(sorted(glob.glob(f'{person_path}/day*')), desc='day'):
125 | day = day_path.split('/')[-1]
126 | for image_path in sorted(glob.glob(f'{day_path}/*.jpg')):
127 | annotation = annotations.loc['/'.join(image_path.split('/')[-2:])]
128 |
129 | img = skimage.io.imread(image_path)
130 | height, width, _ = img.shape
131 |
132 | head_rotation = annotation[14:17].to_numpy().reshape(-1).astype(float) # 3D head rotation based on 6 points-based 3D face model
133 | head_translation = annotation[17:20].to_numpy().reshape(-1).astype(float) # 3D head translation based on 6 points-based 3D face model
134 | gaze_target_3d = annotation[23:26].to_numpy().reshape(-1).astype(float) # 3D gaze target position related to camera (on the screen)
135 |
136 | head_rotation_matrix, _ = cv2.Rodrigues(head_rotation)
137 | face_landmarks = np.dot(head_rotation_matrix, face_model) + head_translation.reshape((3, 1)) # 3D positions of facial landmarks
138 |             left_eye_center = 0.5 * (face_landmarks[:, 2] + face_landmarks[:, 3]).reshape((3, 1))  # center of the left eye
139 |             right_eye_center = 0.5 * (face_landmarks[:, 0] + face_landmarks[:, 1]).reshape((3, 1))  # center of the right eye
140 | face_center = face_landmarks.mean(axis=1).reshape((3, 1))
141 |
142 | img_warped_left_eye, _, _ = normalize_single_image(img, head_rotation, None, left_eye_center, camera_matrix)
143 | img_warped_right_eye, _, _ = normalize_single_image(img, head_rotation, None, right_eye_center, camera_matrix)
144 | img_warped_face, gaze_normalized, rotation_matrix = normalize_single_image(img, head_rotation, gaze_target_3d, face_center, camera_matrix, is_eye=False)
145 |
146 | # Q&A 2 https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/research/gaze-based-human-computer-interaction/appearance-based-gaze-estimation-in-the-wild
147 | gaze_pitch = np.arcsin(-gaze_normalized[1])
148 | gaze_yaw = np.arctan2(-gaze_normalized[0], -gaze_normalized[2])
149 |
150 | base_file_name = f'{person}/{day}/'
151 | pathlib.Path(f"{output_path}/{base_file_name}").mkdir(parents=True, exist_ok=True)
152 | base_file_name += f'{image_path.split("/")[-1][:-4]}'
153 |
154 | skimage.io.imsave(f"{output_path}/{base_file_name}-left_eye.png", img_warped_left_eye.astype(np.uint8), check_contrast=False)
155 | skimage.io.imsave(f"{output_path}/{base_file_name}-right_eye.png", img_warped_right_eye.astype(np.uint8), check_contrast=False)
156 | skimage.io.imsave(f"{output_path}/{base_file_name}-full_face.png", img_warped_face.astype(np.uint8), check_contrast=False)
157 |
158 | data['file_name_base'].append(base_file_name)
159 | data['gaze_pitch'].append(gaze_pitch)
160 | data['gaze_yaw'].append(gaze_yaw)
161 | data['gaze_location'].append(list(annotation[:2]))
162 | data['screen_size'].append([screen_width_pixel, screen_height_pixel])
163 |
164 | with h5py.File(f'{output_path}/data.h5', 'w') as file:
165 | for key, value in data.items():
166 | if key == 'file_name_base': # only str
167 | file.create_dataset(key, data=value, compression='gzip', chunks=True)
168 | else:
169 | value = np.asarray(value)
170 | file.create_dataset(key, data=value, shape=value.shape, compression='gzip', chunks=True)
171 |
172 |     check_mpii_face_gaze_not_on_screen(input_path, output_path)
173 |
174 |
175 | if __name__ == '__main__':
176 | parser = ArgumentParser()
177 | parser.add_argument("--input_path", type=str, default='./MPIIFaceGaze')
178 | parser.add_argument("--output_path", type=str, default='./data')
179 | args = parser.parse_args()
180 |
181 | main(args.input_path, args.output_path)
182 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Evaluation of a Monocular Eye Tracking Set-Up
2 |
3 | As part of my master thesis, I implemented a new state-of-the-art model that is based on the work of [Chen et al.](https://doi.org/10.1109/WACV45572.2020.9093419). \
4 | For 9 calibration samples, the previous state-of-the-art performance can be improved by up to 5.44% (2.553 degrees compared to 2.7 degrees) and for 128 calibration samples, by 7% (2.418 degrees compared to 2.6 degrees).
5 | This is accomplished by (a) improving the extraction of eye features, (b) refining the fusion process of these features, (c) removing erroneous data from the MPIIFaceGaze dataset during training, and (d) optimizing the calibration method.
6 |
7 | Software to [collect own gaze data](https://github.com/pperle/gaze-data-collection) is available, as is the [full gaze tracking pipeline](https://github.com/pperle/gaze-tracking-pipeline).
8 |
9 | 
10 |
11 | For the citations [1] - [10], please see the bibliography below. "own model 1" represents the model described in the section below.
12 | "own model 2" uses the same model architecture as "own model 1" but is trained without the erroneous data, see MPIIFaceGaze section below.
13 | "own model 3" is the same as "own model 2" but with the calibrations points organized in a $\sqrt{k}\times\sqrt{k}$ grid instead of randomly on the screen.
14 |
15 |
16 | ## Model
17 | Since the feature extractor shares the same weights for both eyes, it has been shown experimentally that the feature extraction process can be improved by flipping one of the eye images so that the nose is on the same side in every eye image.
18 | The main reason for this is that the two eye images become more similar this way, so the feature extractor can focus on the relevant features instead of on features specific to either the left or the right eye.
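
A minimal sketch of this flipping step (the function name is illustrative; the actual implementation lives in `dataset/mpii_face_gaze_dataset.py`): mirroring an eye image horizontally also flips the sign of the yaw label, while the pitch stays unchanged.

```python
import numpy as np


def flip_eye_sample(eye_image: np.ndarray, gaze_pitch: float, gaze_yaw: float):
    """Mirror an eye image (H x W x 3) left-right and adjust the gaze label accordingly."""
    flipped_image = np.flip(eye_image, axis=1).copy()  # horizontal flip
    return flipped_image, gaze_pitch, -gaze_yaw  # yaw changes sign, pitch is unaffected
```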
19 |
20 | The architectural improvement that has had the most impact is the improved feature fusion process of left and right eye features.
21 | Instead of simply combining the two features, they are combined using Squeeze-and-Excitation (SE) blocks.
22 | This introduces a learnable control mechanism for the channel relationships of the extracted feature maps, applied sequentially during the fusion.
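
A simplified sketch of this fusion step, roughly following `cnn_eye2fc` in `model.py` (one intermediate convolution is omitted, and the dummy feature maps are only for illustration):

```python
import torch
from torch import nn

from model import SELayer  # Squeeze-and-Excitation block defined in model.py

fusion = nn.Sequential(
    SELayer(256),  # reweight the channels of the concatenated eye feature maps
    nn.Conv2d(256, 128, kernel_size=(3, 3), padding='same'),
    nn.ReLU(inplace=True),
    nn.BatchNorm2d(128),
    SELayer(128),  # reweight the fused feature maps again
)

right_eye_features = torch.randn(8, 128, 4, 6)  # dummy feature maps (batch, channels, height, width)
left_eye_features = torch.randn(8, 128, 4, 6)
fused = fusion(torch.cat((right_eye_features, left_eye_features), dim=1))  # shape (8, 128, 4, 6)
```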
23 |
24 | Start training by running `python train.py --path_to_data=./data --validate_on_person=1 --test_on_person=0`.
25 | For pretrained models, please see evaluation section.
26 |
27 | ## Data
28 | While examining and analyzing the most commonly used gaze prediction dataset, [MPIIFaceGaze](https://www.perceptualui.org/research/datasets/MPIIFaceGaze/), a subset of [MPIIGaze](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/research/gaze-based-human-computer-interaction/appearance-based-gaze-estimation-in-the-wild/), in detail,
29 | it was realized that some of the recorded data does not match the provided screen sizes.
30 | For participants 2, 7, and 10, respectively, 0.043%, 8.79%, and 0.39% of the gaze positions recorded as on-screen do not fit within the provided screen size.
31 | The left figure below shows recorded points in the datasets that do not match the provided screen size.
32 | These false target gaze positions are also visible in the right figure below, where the gaze points that are not on the screen have a different yaw offset to the ground truth.
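
A minimal sketch of the consistency check used in `dataset/mpii_face_gaze_errors.py` (the helper name and the example values are only illustrative):

```python
def gaze_target_on_screen(x: int, y: int, screen_width_pixel: int, screen_height_pixel: int) -> bool:
    """True if the annotated on-screen gaze target lies within the reported monitor resolution."""
    return 0 <= x <= screen_width_pixel and 0 <= y <= screen_height_pixel


# a gaze target annotated at (1500, 200) cannot belong to a screen reported as 1280x800 pixels
assert not gaze_target_on_screen(1500, 200, 1280, 800)
```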
33 |
34 | 
35 |
36 | To the best of our knowledge, we are the first to address this problem in this widespread dataset, and we propose to remove all days with any errors for participants 2, 7, and 10, resulting in a new dataset we call MPIIFaceGaze-.
37 | This would only reduce the dataset by about 3.2%. As shown in the first figure, see "own model 2", removing these erroneous data improves the model's overall performance.
38 |
39 | For preprocessing MPIIFaceGaze, [download](https://www.perceptualui.org/research/datasets/MPIIFaceGaze/) the original dataset and then
40 | run `python dataset/mpii_face_gaze_preprocessing.py --input_path=./MPIIFaceGaze --output_path=./data`.
41 | Or [download the preprocessed dataset](https://drive.google.com/uc?export=download&id=1eCdULbgtJKmZPRrLoBtIS1mw_IQwH6Zi).
42 |
43 | To only generate the CSV files listing all filenames whose gaze target is not on the screen, run `python dataset/mpii_face_gaze_errors.py --input_path=./MPIIFaceGaze --output_path=./data`.
44 | This can be run on MPIIGaze and MPIIFaceGaze, or the CSV files can be directly downloaded for [MPIIGaze](https://drive.google.com/file/d/1buUCPO0xluVxYxN4FwnupNDP3mpZ3SiN/view?usp=sharing) and [MPIIFaceGaze](https://drive.google.com/file/d/1Cq25df9124q8vkdsJO1BiqjuLSWlXz1r/view?usp=sharing).
45 |
46 | ## Calibration
47 | Nine calibration samples have become the norm for the comparison of different model architectures using MPIIFaceGaze.
48 | When the calibration points are organized in a $\sqrt{k}\times\sqrt{k}$ grid instead of randomly on the screen, or all in one position, the resulting person-specific calibration is more accurate.
49 | The three different ways to distribute the calibration points are compared in the figure below; also see "own model 3" in the first figure.
50 | Nine calibration samples aligned in a grid result in a lower angular error than 9 randomly positioned calibration samples.
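
A minimal sketch of the offset-only calibration evaluated in `train.py` (slope adjustment is omitted, and all names and values below are only illustrative):

```python
import numpy as np


def mean_offset_calibration(predictions: np.ndarray, labels: np.ndarray, calibration_idx: np.ndarray) -> np.ndarray:
    """Apply a person-specific constant offset estimated from the k calibration samples.

    `predictions` and `labels` are (N, 2) arrays of (pitch, yaw) in radians.
    """
    offset = (labels[calibration_idx] - predictions[calibration_idx]).mean(axis=0)  # per-axis mean offset
    return predictions + offset


rng = np.random.default_rng(42)
predictions = rng.normal(size=(500, 2))
labels = predictions + np.array([0.05, -0.02])  # simulated constant per-person bias
# k=9 calibration samples, drawn either at random or one per cell of a 3x3 grid on the screen
calibrated = mean_offset_calibration(predictions, labels, rng.integers(0, 500, size=9))
```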
51 |
52 | To collect your own calibration data or dataset, please refer to [gaze data collection](https://github.com/pperle/gaze-data-collection).
53 |
54 | 
55 |
56 |
57 | ## Evaluation
58 | For evaluation, the trained models are evaluated on the full MPIIFaceGaze, including the erroneous data, for a fair comparison to other approaches.
59 | Download the [pretrained "own model 2" models](https://drive.google.com/drive/folders/1-_bOyMgAQmnwRGfQ4QIQk7hrin0Mexch?usp=sharing) and
60 | run `python eval.py --path_to_checkpoints=./pretrained_models --path_to_data=./data` to reproduce the results shown in the figure above and the table below.
61 | `--grid_calibration_samples=True` takes a long time to evaluate; for ease of use, the number of calibration runs is reduced to 500.
62 |
63 | |  | random calibration<br>k=9 | random calibration<br>k=128 | grid calibration<br>k=9 | grid calibration<br>k=128 | <br>k=all |
64 | |---|---:|---:|---:|---:|---:|
65 | | **p00** | 1.780 | 1.676 | 1.760 | 1.674 | 1.668 |
66 | | **p01** | 1.899 | 1.777 | 1.893 | 1.769 | 1.767 |
67 | | **p02** | 1.910 | 1.790 | 1.875 | 1.787 | 1.780 |
68 | | **p03** | 2.924 | 2.729 | 2.929 | 2.712 | 2.714 |
69 | | **p04** | 2.355 | 2.239 | 2.346 | 2.229 | 2.229 |
70 | | **p05** | 1.836 | 1.720 | 1.826 | 1.721 | 1.711 |
71 | | **p06** | 2.569 | 2.464 | 2.596 | 2.460 | 2.455 |
72 | | **p07** | 3.823 | 3.599 | 3.737 | 3.562 | 3.582 |
73 | | **p08** | 3.778 | 3.508 | 3.637 | 3.501 | 3.484 |
74 | | **p09** | 2.695 | 2.528 | 2.667 | 2.526 | 2.515 |
75 | | **p10** | 3.241 | 3.126 | 3.199 | 3.105 | 3.118 |
76 | | **p11** | 2.668 | 2.535 | 2.667 | 2.536 | 2.524 |
77 | | **p12** | 2.204 | 1.877 | 2.131 | 1.882 | 1.848 |
78 | | **p13** | 2.914 | 2.753 | 2.859 | 2.754 | 2.741 |
79 | | **p14** | 2.161 | 2.010 | 2.172 | 2.052 | 1.998 |
80 | | **mean** | **2.584** | **2.422** | **2.553** | **2.418** | **2.409** |
81 |
82 |
83 | ## Bibliography
84 | [1] Zhaokang Chen and Bertram E. Shi, “Appearance-based gaze estimation using dilated-convolutions”, Lecture Notes in Computer Science, vol. 11366, C. V. Jawahar, Hongdong Li, Greg Mori, and Konrad Schindler, Eds., pp. 309–324, 2018. DOI: 10.1007/978-3-030-20876-9_20. [Online]. Available: https://doi.org/10.1007/978-3-030-20876-9_20. \
85 | [2] ——, “Offset calibration for appearance-based gaze estimation via gaze decomposition”, in IEEE Winter Conference on Applications of Computer Vision, WACV 2020, Snowmass Village, CO, USA, March 1-5, 2020, IEEE, 2020, pp. 259–268. DOI: 10.1109/WACV45572.2020.9093419. [Online]. Available: https://doi.org/10.1109/WACV45572.2020.9093419. \
86 | [3] Tobias Fischer, Hyung Jin Chang, and Yiannis Demiris, “RT-GENE: real-time eye gaze estimation in natural environments”, in Computer Vision - ECCV 2018 - 15th European Conference, Munich, Germany, September 8-14, 2018, Proceedings, Part X, Vittorio Ferrari, Martial Hebert, Cristian Sminchisescu, and Yair Weiss, Eds., ser. Lecture Notes in Computer Science, vol. 11214, Springer, 2018, pp. 339–357. DOI: 10.1007/978-3-030-01249-6_21. [Online]. Available: https://doi.org/10.1007/978-3-030-01249-6_21. \
87 | [4] Erik Lindén, Jonas Sjöstrand, and Alexandre Proutière, “Learning to personalize in appearance-based gaze tracking”, pp. 1140–1148, 2019. DOI: 10.1109/ICCVW.2019.00145. [Online]. Available: https://doi.org/10.1109/ICCVW.2019.00145. \
88 | [5] Gang Liu, Yu Yu, Kenneth Alberto Funes Mora, and Jean-Marc Odobez, “A differential approach for gaze estimation with calibration”, in British Machine Vision Conference 2018, BMVC 2018, Newcastle, UK, September 3-6, 2018, BMVA Press, 2018, p. 235. [Online]. Available: http://bmvc2018.org/contents/papers/0792.pdf. \
89 | [6] Seonwook Park, Shalini De Mello, Pavlo Molchanov, Umar Iqbal, Otmar Hilliges, and Jan Kautz, “Few-shot adaptive gaze estimation”, pp. 9367–9376, 2019. DOI: 10.1109/ICCV.2019.00946. [Online]. Available: https://doi.org/10.1109/ICCV.2019.00946. \
90 | [7] Seonwook Park, Xucong Zhang, Andreas Bulling, and Otmar Hilliges, “Learning to find eye region landmarks for remote gaze estimation in unconstrained settings”, Bonita Sharif and Krzysztof Krejtz, Eds., 21:1–21:10, 2018. DOI: 10.1145/3204493.3204545. [Online]. Available: https://doi.org/10.1145/3204493.3204545. \
91 | [8] Yu Yu, Gang Liu, and Jean-Marc Odobez, “Improving few-shot user-specific gaze adaptation via gaze redirection synthesis”, pp. 11 937–11 946, 2019. DOI: 10.1109/CVPR.2019.01221. [Online]. Available: http://openaccess.thecvf.com/content_CVPR_2019/html/Yu_Improving_Few-Shot_User-Specific_Gaze_Adaptation_via_Gaze_Redirection_Synthesis_CVPR_2019_paper.html. \
92 | [9] Xucong Zhang, Yusuke Sugano, Mario Fritz, and Andreas Bulling, “It’s written all over your face: Full-face appearance-based gaze estimation”, pp. 2299–2308, 2017. DOI: 10.1109/CVPRW.2017.284. [Online]. Available: https://doi.org/10.1109/CVPRW.2017.284 \
93 | [10] ——, “Mpiigaze: Real-world dataset and deep appearance-based gaze estimation”, IEEE Trans. Pattern Anal. Mach. Intell., vol. 41, no. 1, pp. 162–175, 2019. DOI: 10.1109/TPAMI.2017.2778103. [Online]. Available: https://doi.org/10.1109/TPAMI.2017.2778103. \
94 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from typing import Tuple
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from pytorch_lightning import seed_everything, Trainer
8 | from pytorch_lightning.loggers import TensorBoardLogger
9 | from pytorch_lightning.utilities.types import STEP_OUTPUT, EPOCH_OUTPUT
10 |
11 | from dataset.mpii_face_gaze_dataset import get_dataloaders
12 | from model import FinalModel
13 | from utils import calc_angle_error, PitchYaw, plot_prediction_vs_ground_truth, log_figure, get_random_idx, get_each_of_one_grid_idx
14 |
15 |
16 | class Model(FinalModel):
17 | def __init__(self, learning_rate: float = 0.001, weight_decay: float = 0., k=None, adjust_slope: bool = False, grid_calibration_samples: bool = False, *args, **kwargs):
18 | super().__init__(*args, **kwargs)
19 | self.learning_rate = learning_rate
20 | self.weight_decay = weight_decay
21 | self.k = [9, 128] if k is None else k
22 | self.adjust_slope = adjust_slope
23 | self.grid_calibration_samples = grid_calibration_samples
24 |
25 | self.save_hyperparameters() # log hyperparameters
26 |
27 | def configure_optimizers(self):
28 | return torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
29 |
30 | def __step(self, batch: dict) -> Tuple:
31 | """
32 | Operates on a single batch of data.
33 |
34 | :param batch: The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
35 | :return: calculated loss, given values and predicted outputs
36 | """
37 | person_idx = batch['person_idx'].long()
38 | left_eye_image = batch['left_eye_image'].float()
39 | right_eye_image = batch['right_eye_image'].float()
40 | full_face_image = batch['full_face_image'].float()
41 |
42 | gaze_pitch = batch['gaze_pitch'].float()
43 | gaze_yaw = batch['gaze_yaw'].float()
44 | labels = torch.stack([gaze_pitch, gaze_yaw]).T
45 |
46 | outputs = self(person_idx, full_face_image, right_eye_image, left_eye_image) # prediction on the base model
47 | loss = F.mse_loss(outputs, labels)
48 |
49 | return loss, labels, outputs
50 |
51 | def training_step(self, train_batch: dict, batch_idx: int) -> STEP_OUTPUT:
52 | loss, labels, outputs = self.__step(train_batch)
53 |
54 | self.log('train/loss', loss)
55 | self.log('train/angular_error', calc_angle_error(labels, outputs))
56 |
57 | return loss
58 |
59 | def validation_step(self, valid_batch: dict, batch_idx: int) -> STEP_OUTPUT:
60 | loss, labels, outputs = self.__step(valid_batch)
61 |
62 | self.log('valid/offset(k=0)/loss', loss)
63 | self.log('valid/offset(k=0)/angular_error', calc_angle_error(labels, outputs))
64 |
65 | return {'loss': loss, 'labels': labels, 'outputs': outputs, 'gaze_locations': valid_batch['gaze_location'], 'screen_sizes': valid_batch['screen_size']}
66 |
67 | def validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
68 | self.__log_and_plot_details(outputs, 'valid')
69 |
70 | def test_step(self, test_batch: dict, batch_idx: int) -> STEP_OUTPUT:
71 | loss, labels, outputs = self.__step(test_batch)
72 |
73 | self.log('test/offset(k=0)/loss', loss)
74 | self.log('test/offset(k=0)/angular_error', calc_angle_error(labels, outputs))
75 |
76 | return {'loss': loss, 'labels': labels, 'outputs': outputs, 'gaze_locations': test_batch['gaze_location'], 'screen_sizes': test_batch['screen_size']}
77 |
78 | def test_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
79 | self.__log_and_plot_details(outputs, 'test')
80 |
81 | def __log_and_plot_details(self, outputs, tag: str):
82 | test_labels = torch.cat([output['labels'] for output in outputs])
83 | test_outputs = torch.cat([output['outputs'] for output in outputs])
84 | test_gaze_locations = torch.cat([output['gaze_locations'] for output in outputs])
85 | test_screen_sizes = torch.cat([output['screen_sizes'] for output in outputs])
86 |
87 | figure = plot_prediction_vs_ground_truth(test_labels, test_outputs, PitchYaw.PITCH)
88 | log_figure(self.logger, f'{tag}/offset(k=0)/pitch', figure, self.global_step)
89 |
90 | figure = plot_prediction_vs_ground_truth(test_labels, test_outputs, PitchYaw.YAW)
91 | log_figure(self.logger, f'{tag}/offset(k=0)/yaw', figure, self.global_step)
92 |
93 |         # find calibration parameters: hold out the last `last_x` samples as the calibration test set and estimate the per-person offset (or slope) from the remaining samples
94 | last_x = 500
95 | calibration_train = test_outputs[:-last_x].cpu().detach().numpy()
96 | calibration_test = test_outputs[-last_x:].cpu().detach().numpy()
97 |
98 | calibration_train_labels = test_labels[:-last_x].cpu().detach().numpy()
99 | calibration_test_labels = test_labels[-last_x:].cpu().detach().numpy()
100 |
101 | gaze_locations_train = test_gaze_locations[:-last_x].cpu().detach().numpy()
102 | screen_sizes_train = test_screen_sizes[:-last_x].cpu().detach().numpy()
103 |
104 | if len(calibration_train) > 0:
105 | for k in self.k:
106 | if k <= 0:
107 | continue
108 | calibrated_solutions = []
109 |
110 | num_calibration_runs = 500 if self.grid_calibration_samples else 10_000 # original results are both evaluated with 10,000 runs
111 | for calibration_run_idx in range(num_calibration_runs): # get_each_of_one_grid_idx is slower than get_random_idx
112 | np.random.seed(42 + calibration_run_idx)
113 | calibration_sample_idxs = get_each_of_one_grid_idx(k, gaze_locations_train, screen_sizes_train) if self.grid_calibration_samples else get_random_idx(k, len(calibration_train))
114 | calibration_points_x = np.asarray([calibration_train[idx] for idx in calibration_sample_idxs])
115 | calibration_points_y = np.asarray([calibration_train_labels[idx] for idx in calibration_sample_idxs])
116 |
117 | if self.adjust_slope:
118 | m, b = np.polyfit(calibration_points_y[:, :1].reshape(-1), calibration_points_x[:, :1].reshape(-1), deg=1)
119 | pitch_fixed = (calibration_test[:, :1] - b) * (1 / m)
120 | m, b = np.polyfit(calibration_points_y[:, 1:].reshape(-1), calibration_points_x[:, 1:].reshape(-1), deg=1)
121 | yaw_fixed = (calibration_test[:, 1:] - b) * (1 / m)
122 | else:
123 | mean_diff_pitch = (calibration_points_y[:, :1] - calibration_points_x[:, :1]).mean() # mean offset
124 | pitch_fixed = calibration_test[:, :1] + mean_diff_pitch
125 | mean_diff_yaw = (calibration_points_y[:, 1:] - calibration_points_x[:, 1:]).mean() # mean offset
126 | yaw_fixed = calibration_test[:, 1:] + mean_diff_yaw
127 |
128 | pitch_fixed, yaw_fixed = torch.Tensor(pitch_fixed), torch.Tensor(yaw_fixed)
129 | outputs_fixed = torch.stack([pitch_fixed, yaw_fixed], dim=1).squeeze(-1)
130 | calibrated_solutions.append(calc_angle_error(torch.Tensor(calibration_test_labels), outputs_fixed).item())
131 |
132 | self.log(f'{tag}/offset(k={k})/mean_angular_error', np.asarray(calibrated_solutions).mean())
133 | self.log(f'{tag}/offset(k={k})/std_angular_error', np.asarray(calibrated_solutions).std())
134 |
135 | # best case, with all calibration samples, all values except the last `last_x` values
136 | if self.adjust_slope:
137 | m, b = np.polyfit(calibration_train_labels[:, :1].reshape(-1), calibration_train[:, :1].reshape(-1), deg=1)
138 | pitch_fixed = torch.Tensor((calibration_test[:, :1] - b) * (1 / m))
139 | m, b = np.polyfit(calibration_train_labels[:, 1:].reshape(-1), calibration_train[:, 1:].reshape(-1), deg=1)
140 | yaw_fixed = torch.Tensor((calibration_test[:, 1:] - b) * (1 / m))
141 | else:
142 | mean_diff_pitch = (calibration_train_labels[:, :1] - calibration_train[:, :1]).mean() # mean offset
143 | pitch_fixed = calibration_test[:, :1] + mean_diff_pitch
144 | mean_diff_yaw = (calibration_train_labels[:, 1:] - calibration_train[:, 1:]).mean() # mean offset
145 | yaw_fixed = calibration_test[:, 1:] + mean_diff_yaw
146 |
147 | pitch_fixed, yaw_fixed = torch.Tensor(pitch_fixed), torch.Tensor(yaw_fixed)
148 | outputs_fixed = torch.stack([pitch_fixed, yaw_fixed], dim=1).squeeze(-1)
149 | calibration_test_labels = torch.Tensor(calibration_test_labels)
150 | self.log(f'{tag}/offset(k=all)/angular_error', calc_angle_error(calibration_test_labels, outputs_fixed))
151 |
152 | figure = plot_prediction_vs_ground_truth(calibration_test_labels, outputs_fixed, PitchYaw.PITCH)
153 | log_figure(self.logger, f'{tag}/offset(k=all)/pitch', figure, self.global_step)
154 |
155 | figure = plot_prediction_vs_ground_truth(calibration_test_labels, outputs_fixed, PitchYaw.YAW)
156 | log_figure(self.logger, f'{tag}/offset(k=all)/yaw', figure, self.global_step)
157 |
158 |
159 | def main(path_to_data: str, validate_on_person: int, test_on_person: int, learning_rate: float, weight_decay: float, batch_size: int, k: int, adjust_slope: bool, grid_calibration_samples: bool):
160 | seed_everything(42)
161 |
162 | model = Model(learning_rate, weight_decay, k, adjust_slope, grid_calibration_samples)
163 |
164 | trainer = Trainer(
165 | gpus=1,
166 | max_epochs=50,
167 | default_root_dir='./saved_models/',
168 | logger=[
169 | TensorBoardLogger(save_dir="tb_logs"),
170 | ],
171 | benchmark=True,
172 | )
173 |
174 | train_dataloader, valid_dataloader, test_dataloader = get_dataloaders(path_to_data, validate_on_person, test_on_person, batch_size)
175 | trainer.fit(model, train_dataloader, valid_dataloader)
176 | trainer.test(model, test_dataloader)
177 |
178 |
179 | if __name__ == '__main__':
180 | parser = ArgumentParser()
181 | parser.add_argument("--path_to_data", type=str, default='./data')
182 | parser.add_argument("--validate_on_person", type=int, default=1)
183 | parser.add_argument("--test_on_person", type=int, default=0)
184 | parser.add_argument("--learning_rate", type=float, default=0.001)
185 | parser.add_argument("--weight_decay", type=float, default=0.)
186 | parser.add_argument("--batch_size", type=int, default=64)
187 | parser.add_argument("--k", type=int, default=[9, 128], nargs='+')
188 | parser.add_argument("--adjust_slope", type=bool, default=False)
189 | parser.add_argument("--grid_calibration_samples", type=bool, default=False)
190 | args = parser.parse_args()
191 |
192 | main(args.path_to_data, args.validate_on_person, args.test_on_person, args.learning_rate, args.weight_decay, args.batch_size, args.k, args.adjust_slope, args.grid_calibration_samples)
193 |
--------------------------------------------------------------------------------
/docs/compare_points_on_screen_positions.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/compare_model_results.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------