├── .gitignore
├── README.md
├── datasets
│   ├── UnityEyes
│   │   └── README.md
│   ├── mpii_gaze.py
│   └── unity_eyes.py
├── env-linux.yml
├── env-mac.yml
├── eval_mpiigaze.py
├── lbpcascade_frontalface_improved.xml
├── models
│   ├── eyenet.py
│   ├── layers.py
│   └── losses.py
├── notebooks
│   ├── check_preprocessing.ipynb
│   ├── explore_data.ipynb
│   └── explore_mpiigaze.ipynb
├── run_with_webcam.py
├── scripts
│   └── fetch_models.sh
├── static
│   ├── fig1.png
│   └── ge_screenshot.png
├── test.py
├── test
│   ├── __init__.py
│   ├── data
│   │   └── imgs
│   │       └── 1.json
│   └── test_unity_eyes.py
├── train.py
└── util
    ├── eye_prediction.py
    ├── eye_sample.py
    ├── gaze.py
    ├── preprocess.py
    └── softargmax.py
/.gitignore:
--------------------------------------------------------------------------------
1 | datasets/UnityEyes/imgs/
2 | datasets/MPIIGaze/
3 | datasets/UTMultiview/
4 | __pycache__/
5 | runs/
6 | .idea/
7 | .ipynb_checkpoints
8 | *.zip
9 | *.jpg
10 | *.dat
11 | checkpoint*
12 | *.bz2
13 | *.pt
14 | *.tar.gz
15 | *.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Gaze Estimation with Deep Learning
2 |
3 | This project implements a deep learning model to predict eye region landmarks and gaze direction.
4 | The model is trained on a set of computer-generated eye images synthesized with UnityEyes [1]. This work is heavily based on [2], with some key modifications.
5 | This model achieves ~14% mean angular error on the MPIIGaze evaluation set after training on UnityEyes alone.
6 |
7 | ### Setup
8 |
9 | NOTE: This repo has been tested only on Ubuntu 16.04 and macOS.
10 |
11 | First, create a conda env for your system and activate it:
12 | ```bash
13 | conda env create -f env-linux.yml
14 | conda activate ge-linux
15 | ```
16 |
17 | Then download the pretrained model files. One is for detecting face landmarks; the other is the main PyTorch model.
18 |
19 | ```bash
20 | ./scripts/fetch_models.sh
21 | ```
22 |
23 | Finally, run the webcam demo. You will likely need a GPU with CUDA 10.1 installed to get acceptable performance.
24 |
25 | ```bash
26 | python run_with_webcam.py
27 | ```
28 |
29 | If you'd like to train the model yourself, please see the readme under `datasets/UnityEyes`.
30 |
31 | ### Materials and Methods
32 |
33 | Over 100k training images were generated using UnityEyes [1]. These images are each labeled
34 | with a json metadata file. The labels provide eye region landmark positions in screenspace,
35 | the direction the eye is looking in camera space, and other pieces of information. A rectangular region around the eye was extracted from each raw training image and normalized to have a width equal to the eye width (1.5 times the distance between eye corners).
36 | For each preprocessed image, a set of heatmaps corresponding
37 | to 34 eye region landmarks was created. The model was trained to regress directly on the landmark locations and gaze direction in (pitch, yaw) form. The model was implemented in PyTorch. The overall method is summarized in the following figure.
38 | 
39 |
40 | The model architecture is based on the stacked hourglass model [3]. The main modification was to add a separate pre-hourglass layer for predicting the gaze direction. The output of the additional layer is concatenated with the predicted eye-region landmarks before being passed to two fully connected layers. This way, the model can make use of the high-level landmark features for predicting the gaze direction.
41 |
42 | ### Demo Video
43 |
44 | [](https://drive.google.com/open?id=1WUUmd4quXq_YA5ANWDoUxqFGgguE_QJi)
45 |
46 |
47 | ### References
48 |
49 | 1. https://www.cl.cam.ac.uk/research/rainbow/projects/unityeyes/
50 | 2. https://github.com/swook/GazeML
51 | 3. https://github.com/princeton-vl/pytorch_stacked_hourglass
52 |
--------------------------------------------------------------------------------
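The heatmap targets mentioned in the README's Materials and Methods section can be illustrated with a short NumPy sketch. This is not the repository's implementation (the repo's own helper, `gaussian_2d`, lives in `util/preprocess.py`, which is not shown here). The function names `gaussian_heatmap` and `landmarks_to_heatmaps`, the 48x80 heatmap size, and the `sigma` value are all assumptions made for illustration.

```python
import numpy as np


def gaussian_heatmap(shape, center, sigma=1.0):
    """Return a (height, width) map with a Gaussian bump at center=(x, y).

    Hypothetical helper: the peak value is 1.0 at the landmark position.
    """
    h, w = shape
    ys, xs = np.mgrid[0:h, 0:w]          # pixel row and column coordinates
    cx, cy = center
    sq_dist = (xs - cx) ** 2 + (ys - cy) ** 2
    return np.exp(-sq_dist / (2.0 * sigma ** 2)).astype(np.float32)


def landmarks_to_heatmaps(landmarks, shape, sigma=1.0):
    """Stack one heatmap per (x, y) landmark into an (N, height, width) array."""
    return np.stack([gaussian_heatmap(shape, (x, y), sigma) for x, y in landmarks])


# Two made-up landmark positions on an assumed 48x80 heatmap grid.
heatmaps = landmarks_to_heatmaps([(10, 5), (20, 30)], shape=(48, 80), sigma=2.0)
```

With 34 landmarks, as in the README, the same call would produce an array of shape `(34, height, width)`.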
/datasets/UnityEyes/README.md:
--------------------------------------------------------------------------------
1 | # UnityEyes
2 |
3 | UnityEyes is an application that generates synthetic but realistic eye images that are
4 | annotated with various pieces of information such as eye region locations and look vector.
5 |
6 | You can generate your own training images using the software which can be found at https://www.cl.cam.ac.uk/research/rainbow/projects/unityeyes/
7 |
8 | Please add your images to the `datasets/UnityEyes/imgs/` folder.
9 |
10 |
11 | ### Alternatively
12 |
13 | You can download the exact dataset that we used here https://drive.google.com/open?id=1wqTA4gutC-L4h8TcMMQO_3jJYBL15-h3
14 |
15 | Download this to ./datasets/UnityEyes/imgs.zip and unzip it.
16 |
17 | This dataset contains 100938 images together with their json metadata files.
--------------------------------------------------------------------------------
/datasets/mpii_gaze.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import torch
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 | import glob
6 | import os
7 | import cv2
8 | import scipy.io as sio
9 | import util.gaze
10 |
11 |
12 | class MPIIGaze(Dataset):
13 |
14 |     def __init__(self, mpii_dir: str = 'datasets/MPIIGaze'):
15 |
16 |         self.mpii_dir = mpii_dir
17 |
18 |         eval_files = glob.glob(f'{mpii_dir}/Evaluation Subset/sample list for eye image/*.txt')
19 |
20 |         self.eval_entries = []
21 |         for ef in eval_files:
22 |             person = os.path.splitext(os.path.basename(ef))[0]
23 |             with open(ef) as f:
24 |                 lines = f.readlines()
25 |             for line in lines:
26 |                 line = line.strip()
27 |                 if line != '':
28 |                     img_path, side = [x.strip() for x in line.split()]
29 |                     day, img = img_path.split('/')
30 |                     self.eval_entries.append({
31 |                         'day': day,
32 |                         'img_name': img,
33 |                         'person': person,
34 |                         'side': side
35 |                     })
36 |
37 |     def __len__(self):
38 |         return len(self.eval_entries)
39 |
40 |     def __getitem__(self, idx):
41 |         if torch.is_tensor(idx):
42 |             idx = idx.tolist()
43 |
44 |         return self._load_sample(idx)
45 |
46 |     def _load_sample(self, i):
47 |         entry = self.eval_entries[i]
48 |         mat_path = os.path.join(self.mpii_dir, 'Data/Normalized', entry['person'], entry['day'] + '.mat')
49 |         mat = sio.loadmat(mat_path)
50 |
51 |         filenames = mat['filenames']
52 |         row = np.argwhere(filenames == entry['img_name'])[0][0]
53 |         side = entry['side']
54 |
55 |         img = mat['data'][side][0, 0]['image'][0, 0][row]
56 |         img = cv2.resize(img, (160, 96))
57 |         img = cv2.equalizeHist(img)
58 |         img = img / 255.
59 |         img = img.astype(np.float32)
60 |         if side == 'right':
61 |             img = np.fliplr(img)
62 |
63 |         (x, y, z) = mat['data'][side][0, 0]['gaze'][0, 0][row]
64 |
65 |         theta = np.arcsin(-y)
66 |         phi = np.arctan2(-x, -z)
67 |         gaze = np.array([-theta, phi])
68 |
69 |         return {
70 |             'img': img,
71 |             'gaze': gaze,
72 |             'side': side
73 |         }
74 |
--------------------------------------------------------------------------------
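The gaze conversion in `_load_sample` above turns a 3D gaze vector into two angles. A minimal sketch of that conversion, its inverse, and an angular error between two (pitch, yaw) pairs is given below. The function names are hypothetical, and `angular_error_deg` is only an illustration of the idea; it is not the repository's `util.gaze.angular_error`, whose source is not shown here. Note also that the dataset class stores `[-theta, phi]`, a sign flip on pitch that this sketch leaves out.

```python
import numpy as np


def vector_to_pitchyaw(v):
    """Convert a 3D gaze vector (x, y, z) to (pitch, yaw) in radians."""
    x, y, z = np.asarray(v, dtype=np.float64) / np.linalg.norm(v)
    pitch = np.arcsin(-y)       # same expression as `theta` in _load_sample
    yaw = np.arctan2(-x, -z)    # same expression as `phi` in _load_sample
    return np.array([pitch, yaw])


def pitchyaw_to_vector(py):
    """Inverse of vector_to_pitchyaw: returns a unit 3D gaze vector."""
    pitch, yaw = py
    return np.array([
        -np.cos(pitch) * np.sin(yaw),
        -np.sin(pitch),
        -np.cos(pitch) * np.cos(yaw),
    ])


def angular_error_deg(a, b):
    """Angle in degrees between two gaze directions given as (pitch, yaw)."""
    va, vb = pitchyaw_to_vector(a), pitchyaw_to_vector(b)
    cos_sim = np.clip(np.dot(va, vb), -1.0, 1.0)  # guard against rounding
    return np.degrees(np.arccos(cos_sim))
```

Because both arguments go through the same conversion, the error is unchanged if pitch is negated consistently in both the label and the prediction.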
/datasets/unity_eyes.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | from typing import Optional
4 |
5 | import torch
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from torch.utils.data import Dataset, DataLoader
9 | import glob
10 | import os
11 | import cv2
12 | import json
13 | from util.preprocess import preprocess_unityeyes_image
14 |
15 |
16 | class UnityEyesDataset(Dataset):
17 |
18 |     def __init__(self, img_dir: Optional[str] = None):
19 |
20 |         if img_dir is None:
21 |             img_dir = os.path.join(os.path.dirname(__file__), 'UnityEyes/imgs')
22 |
23 |         self.img_paths = glob.glob(os.path.join(img_dir, '*.jpg'))
24 |         self.img_paths = sorted(self.img_paths, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
25 |         self.json_paths = []
26 |         for img_path in self.img_paths:
27 |             idx = os.path.splitext(os.path.basename(img_path))[0]
28 |             self.json_paths.append(os.path.join(img_dir, f'{idx}.json'))
29 |
30 |     def __len__(self):
31 |         return len(self.img_paths)
32 |
33 |     def __getitem__(self, idx):
34 |         if torch.is_tensor(idx):
35 |             idx = idx.tolist()
36 |
37 |         full_img = cv2.imread(self.img_paths[idx])
38 |         with open(self.json_paths[idx]) as f:
39 |             json_data = json.load(f)
40 |
41 |         eye_sample = preprocess_unityeyes_image(full_img, json_data)
42 |         sample = {'full_img': full_img, 'json_data': json_data }
43 |         sample.update(eye_sample)
44 |         return sample
--------------------------------------------------------------------------------
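The constructor above sorts image paths by the integer value of the file stem, so that `10.jpg` follows `2.jpg` rather than preceding it as it would in a plain string sort. The pairing logic can be sketched with the standard library alone. `paired_paths` is a hypothetical helper written for this illustration and is not part of the repository.

```python
import glob
import os


def paired_paths(img_dir):
    """Return (image_path, json_path) pairs ordered by integer file stem."""
    img_paths = glob.glob(os.path.join(img_dir, '*.jpg'))
    # Sort numerically: '2.jpg' must come before '10.jpg'.
    img_paths.sort(key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
    pairs = []
    for img_path in img_paths:
        stem = os.path.splitext(os.path.basename(img_path))[0]
        pairs.append((img_path, os.path.join(img_dir, stem + '.json')))
    return pairs
```

As in the dataset class, this assumes every `.jpg` has a purely numeric name; a non-numeric stem would raise `ValueError` in the sort key.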
/env-linux.yml:
--------------------------------------------------------------------------------
1 | name: ge-linux
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1
8 | - absl-py=0.9.0
9 | - attrs=19.3.0
10 | - backcall=0.1.0
11 | - blas=1.0
12 | - bleach=3.1.0
13 | - bzip2=1.0.8
14 | - c-ares=1.15.0
15 | - ca-certificates=2020.1.1
16 | - cairo=1.16.0
17 | - certifi=2019.11.28
18 | - cudatoolkit=10.1.243
19 | - cycler=0.10.0
20 | - dbus=1.13.12
21 | - decorator=4.4.1
22 | - defusedxml=0.6.0
23 | - entrypoints=0.3
24 | - expat=2.2.6
25 | - ffmpeg=4.1.3
26 | - fontconfig=2.13.1
27 | - freetype=2.9.1
28 | - giflib=5.1.4
29 | - glib=2.63.1
30 | - gmp=6.1.2
31 | - gnutls=3.6.5
32 | - graphite2=1.3.13
33 | - grpcio=1.16.1
34 | - gst-plugins-base=1.14.0
35 | - gstreamer=1.14.0
36 | - harfbuzz=2.4.0
37 | - hdf5=1.10.5
38 | - icu=58.2
39 | - importlib_metadata=1.5.0
40 | - intel-openmp=2020.0
41 | - ipykernel=5.1.4
42 | - ipython=7.12.0
43 | - ipython_genutils=0.2.0
44 | - ipywidgets=7.5.1
45 | - jasper=1.900.1
46 | - jedi=0.16.0
47 | - jinja2=2.11.1
48 | - jpeg=9c
49 | - jsonschema=3.2.0
50 | - jupyter=1.0.0
51 | - jupyter_client=5.3.4
52 | - jupyter_console=6.1.0
53 | - jupyter_core=4.6.1
54 | - kiwisolver=1.1.0
55 | - lame=3.100
56 | - ld_impl_linux-64=2.33.1
57 | - libblas=3.8.0
58 | - libcblas=3.8.0
59 | - libedit=3.1.20181209
60 | - libffi=3.2.1
61 | - libgcc-ng=9.1.0
62 | - libgfortran-ng=7.3.0
63 | - libiconv=1.15
64 | - liblapack=3.8.0
65 | - liblapacke=3.8.0
66 | - libpng=1.6.37
67 | - libprotobuf=3.11.3
68 | - libsodium=1.0.16
69 | - libstdcxx-ng=9.1.0
70 | - libtiff=4.1.0
71 | - libuuid=2.32.1
72 | - libwebp=1.0.1
73 | - libxcb=1.13
74 | - libxml2=2.9.9
75 | - markdown=3.1.1
76 | - markupsafe=1.1.1
77 | - matplotlib=3.1.3
78 | - matplotlib-base=3.1.3
79 | - mistune=0.8.4
80 | - mkl=2020.0
81 | - mkl-service=2.3.0
82 | - mkl_fft=1.0.15
83 | - mkl_random=1.1.0
84 | - more-itertools=8.2.0
85 | - nbconvert=5.6.1
86 | - nbformat=5.0.4
87 | - ncurses=6.1
88 | - nettle=3.4.1
89 | - ninja=1.9.0
90 | - notebook=6.0.3
91 | - numpy=1.18.1
92 | - numpy-base=1.18.1
93 | - olefile=0.46
94 | - opencv=4.1.0
95 | - openh264=1.8.0
96 | - openssl=1.1.1d
97 | - packaging=20.1
98 | - pandoc=2.2.3.2
99 | - pandocfilters=1.4.2
100 | - parso=0.6.1
101 | - pcre=8.43
102 | - pexpect=4.8.0
103 | - pickleshare=0.7.5
104 | - pillow=7.0.0
105 | - pip=20.0.2
106 | - pixman=0.38.0
107 | - pluggy=0.13.1
108 | - prometheus_client=0.7.1
109 | - prompt_toolkit=3.0.3
110 | - protobuf=3.11.3
111 | - ptyprocess=0.6.0
112 | - py=1.8.1
113 | - pygments=2.5.2
114 | - pyparsing=2.4.6
115 | - pyqt=5.9.2
116 | - pyrsistent=0.15.7
117 | - pytest=5.3.5
118 | - python=3.6.10
119 | - python-dateutil=2.8.1
120 | - pytorch=1.4.0
121 | - pyzmq=18.1.1
122 | - qt=5.9.7
123 | - qtconsole=4.6.0
124 | - readline=7.0
125 | - send2trash=1.5.0
126 | - setuptools=45.2.0
127 | - sip=4.19.8
128 | - six=1.14.0
129 | - sqlite=3.31.1
130 | - tensorboard=2.0.0
131 | - terminado=0.8.3
132 | - testpath=0.4.4
133 | - tk=8.6.8
134 | - torchvision=0.5.0
135 | - tornado=6.0.3
136 | - traitlets=4.3.3
137 | - wcwidth=0.1.8
138 | - webencodings=0.5.1
139 | - werkzeug=1.0.0
140 | - wheel=0.34.2
141 | - widgetsnbextension=3.5.1
142 | - x264=1!152.20180806
143 | - xorg-kbproto=1.0.7
144 | - xorg-libice=1.0.10
145 | - xorg-libsm=1.2.3
146 | - xorg-libx11=1.6.9
147 | - xorg-libxext=1.3.4
148 | - xorg-libxrender=0.9.10
149 | - xorg-renderproto=0.11.1
150 | - xorg-xextproto=7.3.0
151 | - xorg-xproto=7.0.31
152 | - xz=5.2.4
153 | - zeromq=4.3.1
154 | - zipp=2.2.0
155 | - zlib=1.2.11
156 | - zstd=1.3.7
157 | - pip:
158 |     - dlib==19.19.0
159 |     - imutils==0.5.3
160 |
--------------------------------------------------------------------------------
/env-mac.yml:
--------------------------------------------------------------------------------
1 | name: ge-mac
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - absl-py=0.9.0
7 | - appnope=0.1.0
8 | - attrs=19.3.0
9 | - backcall=0.1.0
10 | - blas=1.0
11 | - bleach=3.1.0
12 | - bzip2=1.0.8
13 | - c-ares=1.15.0
14 | - ca-certificates=2020.1.1
15 | - cairo=1.14.12
16 | - certifi=2019.11.28
17 | - cycler=0.10.0
18 | - dbus=1.13.12
19 | - decorator=4.4.1
20 | - defusedxml=0.6.0
21 | - entrypoints=0.3
22 | - expat=2.2.6
23 | - ffmpeg=4.0
24 | - fontconfig=2.13.0
25 | - freetype=2.9.1
26 | - gettext=0.19.8.1
27 | - glib=2.63.1
28 | - graphite2=1.3.13
29 | - grpcio=1.16.1
30 | - harfbuzz=1.8.8
31 | - hdf5=1.10.2
32 | - icu=58.2
33 | - importlib_metadata=1.5.0
34 | - intel-openmp=2019.4
35 | - ipykernel=5.1.4
36 | - ipython=7.12.0
37 | - ipython_genutils=0.2.0
38 | - ipywidgets=7.5.1
39 | - jasper=2.0.14
40 | - jedi=0.16.0
41 | - jinja2=2.11.1
42 | - jpeg=9b
43 | - jsonschema=3.2.0
44 | - jupyter=1.0.0
45 | - jupyter_client=5.3.4
46 | - jupyter_console=6.1.0
47 | - jupyter_core=4.6.1
48 | - kiwisolver=1.1.0
49 | - libcxx=4.0.1
50 | - libcxxabi=4.0.1
51 | - libedit=3.1.20181209
52 | - libffi=3.2.1
53 | - libgfortran=3.0.1
54 | - libiconv=1.15
55 | - libopencv=3.4.2
56 | - libopus=1.3
57 | - libpng=1.6.37
58 | - libprotobuf=3.11.4
59 | - libsodium=1.0.16
60 | - libtiff=4.1.0
61 | - libvpx=1.7.0
62 | - libxml2=2.9.9
63 | - markdown=3.1.1
64 | - markupsafe=1.1.1
65 | - matplotlib=3.1.3
66 | - matplotlib-base=3.1.3
67 | - mistune=0.8.4
68 | - mkl=2019.4
69 | - mkl-service=2.3.0
70 | - mkl_fft=1.0.15
71 | - mkl_random=1.1.0
72 | - nbconvert=5.6.1
73 | - nbformat=5.0.4
74 | - ncurses=6.1
75 | - ninja=1.9.0
76 | - notebook=6.0.3
77 | - numpy=1.18.1
78 | - numpy-base=1.18.1
79 | - olefile=0.46
80 | - opencv=3.4.2
81 | - openssl=1.1.1d
82 | - pandoc=2.2.3.2
83 | - pandocfilters=1.4.2
84 | - parso=0.6.1
85 | - pcre=8.43
86 | - pexpect=4.8.0
87 | - pickleshare=0.7.5
88 | - pillow=7.0.0
89 | - pip=20.0.2
90 | - pixman=0.38.0
91 | - prometheus_client=0.7.1
92 | - prompt_toolkit=3.0.3
93 | - protobuf=3.11.4
94 | - ptyprocess=0.6.0
95 | - py-opencv=3.4.2
96 | - pygments=2.5.2
97 | - pyparsing=2.4.6
98 | - pyqt=5.9.2
99 | - pyrsistent=0.15.7
100 | - python=3.6.10
101 | - python-dateutil=2.8.1
102 | - pytorch=1.4.0
103 | - pyzmq=18.1.1
104 | - qt=5.9.7
105 | - qtconsole=4.6.0
106 | - readline=7.0
107 | - send2trash=1.5.0
108 | - setuptools=45.2.0
109 | - sip=4.19.8
110 | - six=1.14.0
111 | - sqlite=3.31.1
112 | - tensorboard=2.0.0
113 | - terminado=0.8.3
114 | - testpath=0.4.4
115 | - tk=8.6.8
116 | - torchvision=0.5.0
117 | - tornado=6.0.3
118 | - traitlets=4.3.3
119 | - wcwidth=0.1.8
120 | - webencodings=0.5.1
121 | - werkzeug=1.0.0
122 | - wheel=0.34.2
123 | - widgetsnbextension=3.5.1
124 | - xz=5.2.4
125 | - zeromq=4.3.1
126 | - zipp=2.2.0
127 | - zlib=1.2.11
128 | - zstd=1.3.7
129 | - pip:
130 |     - dlib==19.19.0
131 |     - imutils==0.5.3
132 |
133 |
--------------------------------------------------------------------------------
/eval_mpiigaze.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datasets.mpii_gaze import MPIIGaze
3 | from models.eyenet import EyeNet
4 | import os
5 | import numpy as np
6 | import cv2
7 | from util.preprocess import gaussian_2d
8 | from matplotlib import pyplot as plt
9 | import util.gaze
10 |
11 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
12 | dataset = MPIIGaze()
13 | checkpoint = torch.load('checkpoint.pt', map_location=device)
14 | nstack = checkpoint['nstack']
15 | nfeatures = checkpoint['nfeatures']
16 | nlandmarks = checkpoint['nlandmarks']
17 | eyenet = EyeNet(nstack=nstack, nfeatures=nfeatures, nlandmarks=nlandmarks).to(device)
18 | eyenet.load_state_dict(checkpoint['model_state_dict'])
19 |
20 | with torch.no_grad():
21 |     errors = []
22 |
23 |     print('N', len(dataset))
24 |     for i, sample in enumerate(dataset):
25 |         print(i)
26 |         x = torch.tensor([sample['img']]).float().to(device)
27 |
28 |         heatmaps_pred, landmarks_pred, gaze_pred = eyenet(x)
29 |
30 |         gaze = sample['gaze'].reshape((1, 2))
31 |         gaze_pred = np.asarray(gaze_pred.cpu().numpy())
32 |
33 |         if sample['side'] == 'right':
34 |             gaze_pred[0, 1] = -gaze_pred[0, 1]
35 |
36 |         angular_error = util.gaze.angular_error(gaze, gaze_pred)
37 |         errors.append(angular_error)
38 |         print('---')
39 |         print('error', angular_error)
40 |         print('mean error', np.mean(errors))
41 |         print('side', sample['side'])
42 |         print('gaze', gaze)
43 |         print('gaze pred', gaze_pred)
44 |
45 |         # landmarks_pred = np.asarray(landmarks_pred.cpu().numpy())[0, :]
46 |         #
47 |         # plt.figure(figsize=(8, 9))
48 |         #
49 |         # iris_center = landmarks_pred[-2][::-1]
50 |         # iris_center *= 2
51 |         # img = sample['img']
52 |         #
53 |         # img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
54 |         #
55 |         # img_gaze_pred = img.copy()
56 |         # util.gaze.draw_gaze(img_gaze_pred, iris_center, gaze_pred[0, :], length=60, color=(255, 0, 0))
57 |         #
58 |         # img_gaze = img.copy()
59 |         # util.gaze.draw_gaze(img_gaze, iris_center, sample['gaze'], length=60, color=(0, 255, 0))
60 |         #
61 |         # plt.subplot(121)
62 |         # plt.imshow(cv2.cvtColor(img_gaze, cv2.COLOR_BGR2RGB))
63 |         # plt.title('True Gaze')
64 |         #
65 |         # plt.subplot(122)
66 |         # plt.imshow(cv2.cvtColor(img_gaze_pred, cv2.COLOR_BGR2RGB))
67 |         # plt.title('Predicted Gaze')
68 |         #
69 |         # plt.show()
70 |
71 |
72 |
73 |
74 |
--------------------------------------------------------------------------------
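The evaluation script relies on a mirroring convention: `MPIIGaze._load_sample` flips right-eye images horizontally before they reach the model, and the script then negates the predicted yaw for right eyes to map the prediction back. A small sketch of that pair of steps follows. Both function names are hypothetical and exist only for this illustration.

```python
import numpy as np


def to_left_eye_frame(img, side):
    """Mirror right-eye images horizontally so every input looks like a left eye."""
    return np.fliplr(img) if side == 'right' else img


def from_left_eye_frame(gaze, side):
    """Undo the mirroring on a (pitch, yaw) prediction by negating yaw for right eyes."""
    pitch, yaw = gaze
    return np.array([pitch, -yaw if side == 'right' else yaw])
```

Pitch is left untouched because a horizontal mirror only reverses the left-right component of the gaze direction.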
/lbpcascade_frontalface_improved.xml:
--------------------------------------------------------------------------------
1 |
2 |
66 |
67 |
68 |
69 | BOOST
70 | LBP
71 | 45
72 | 45
73 |
74 | GAB
75 | 9.9500000476837158e-001
76 | 5.0000000000000000e-001
77 | 9.4999999999999996e-001
78 | 1
79 | 100
80 |
81 | 256
82 | 1
83 | 19
84 |
85 |
86 | <_>
87 | 6
88 | -4.1617846488952637e+000
89 |
90 | <_>
91 |
92 | 0 -1 26 -1 -1 -17409 -1 -1 -1 -1 -1
93 |
94 | -9.9726462364196777e-001 -3.8938775658607483e-001
95 | <_>
96 |
97 | 0 -1 18 -1 -1 -21569 -20545 -1 -1 -20545 -1
98 |
99 | -9.8648911714553833e-001 -2.5386649370193481e-001
100 | <_>
101 |
102 | 0 -1 30 -21569 -16449 1006578219 -20801 -16449 -1 -21585 -1
103 |
104 | -9.6436238288879395e-001 -1.4039695262908936e-001
105 | <_>
106 |
107 | 0 -1 54 -1 -1 -16402 -4370 -1 -1 -1053010 -4456466
108 |
109 | -8.4081345796585083e-001 3.8321062922477722e-001
110 | <_>
111 |
112 | 0 -1 29 -184747280 -705314819 1326353 1364574079 -131073 -5
113 | 2147481147 -1
114 |
115 | -8.1084597110748291e-001 4.3495711684226990e-001
116 | <_>
117 |
118 | 0 -1 89 -142618625 -4097 -37269 -20933 872350430 -268476417
119 | 1207894255 2139032115
120 |
121 | -7.3140043020248413e-001 4.3799084424972534e-001
122 |
123 | <_>
124 | 6
125 | -4.0652265548706055e+000
126 |
127 | <_>
128 |
129 | 0 -1 19 -1 -1 -17409 -1 -1 -1 -1 -1
130 |
131 | -9.9727255105972290e-001 -7.2050148248672485e-001
132 | <_>
133 |
134 | 0 -1 38 -1 1073741823 -1 -1 -1 -1 -1 -1
135 |
136 | -9.8717331886291504e-001 -5.3031939268112183e-001
137 | <_>
138 |
139 | 0 -1 28 -16385 -1 -21569 -20545 -1 -1 -21569 -1
140 |
141 | -9.3442338705062866e-001 6.5213099122047424e-002
142 | <_>
143 |
144 | 0 -1 112 -2097153 -1 -1 -1 -1 -8193 -1 -35467
145 |
146 | -7.9567342996597290e-001 4.2883640527725220e-001
147 | <_>
148 |
149 | 0 -1 48 -134239573 -16465 58663467 -1079022929 -1073758273
150 | -81937 -8412501 -404766817
151 |
152 | -7.1264797449111938e-001 4.1050794720649719e-001
153 | <_>
154 |
155 | 0 -1 66 -17047555 -1099008003 2147479551 -1090584581 -69633
156 | -1342177281 -1090650121 -1472692240
157 |
158 | -7.6119172573089600e-001 4.2042696475982666e-001
159 |
160 | <_>
161 | 7
162 | -4.6904473304748535e+000
163 |
164 | <_>
165 |
166 | 0 -1 12 -1 -1 -17409 -1 -1 -1 -1 -1
167 |
168 | -9.9725550413131714e-001 -8.3142280578613281e-001
169 | <_>
170 |
171 | 0 -1 31 -1 -168429569 -1 -1 -1 -1 -1 -1
172 |
173 | -9.8183268308639526e-001 -3.6373397707939148e-001
174 | <_>
175 |
176 | 0 -1 38 -1 1073741759 -1 -1 -1 -1 -1 -1
177 |
178 | -9.1890293359756470e-001 7.8322596848011017e-002
179 | <_>
180 |
181 | 0 -1 27 -17409 -2097153 -134372726 -21873 -65 -536870913
182 | -161109 -4215889
183 |
184 | -8.0752444267272949e-001 1.9565649330615997e-001
185 | <_>
186 |
187 | 0 -1 46 -469779457 -286371842 -33619971 -212993 -1 -41943049
188 | -134217731 -1346863620
189 |
190 | -6.9232726097106934e-001 3.8141927123069763e-001
191 | <_>
192 |
193 | 0 -1 125 -1896950780 -1964839052 -9 707723004 -34078727
194 | -1074266122 -536872969 -262145
195 |
196 | -8.1760478019714355e-001 3.4172961115837097e-001
197 | <_>
198 |
199 | 0 -1 80 -402657501 654311423 -419533278 -452984853
200 | 1979676215 -1208090625 -167772569 -524289
201 |
202 | -6.3433408737182617e-001 4.3154156208038330e-001
203 |
204 | <_>
205 | 8
206 | -4.2590322494506836e+000
207 |
208 | <_>
209 |
210 | 0 -1 42 -1 -655361 -1 -1 -1 -1 -1 -1
211 |
212 | -9.9715477228164673e-001 -8.6178696155548096e-001
213 | <_>
214 |
215 | 0 -1 40 -1 -705300491 -1 -1 -1 -1 -1 -1
216 |
217 | -9.8356908559799194e-001 -5.7423096895217896e-001
218 | <_>
219 |
220 | 0 -1 43 -65 872413111 -2049 -1 -1 -1 -1 -1
221 |
222 | -9.2525935173034668e-001 -1.3835857808589935e-001
223 | <_>
224 |
225 | 0 -1 111 -1 -5242881 -1 -524289 -4194305 -1 -1 -43148
226 |
227 | -7.8076487779617310e-001 1.8362471461296082e-001
228 | <_>
229 |
230 | 0 -1 25 -145227841 868203194 -1627394049 935050171
231 | 2147483647 1006600191 -268439637 1002437615
232 |
233 | -7.2554033994674683e-001 3.3393219113349915e-001
234 | <_>
235 |
236 | 0 -1 116 -214961408 50592514 -2128 1072162674 -1077940293
237 | -1084489966 -134219854 -1074790401
238 |
239 | -6.1547595262527466e-001 3.9214438199996948e-001
240 | <_>
241 |
242 | 0 -1 3 -294987948 -1124421633 -73729 -268435841 -33654928
243 | 2122317823 -268599297 -33554945
244 |
245 | -6.4863425493240356e-001 3.8784855604171753e-001
246 | <_>
247 |
248 | 0 -1 22 -525585 -26738821 -17895690 1123482236 1996455758
249 | -8519849 -252182980 -461898753
250 |
251 | -5.5464369058609009e-001 4.4275921583175659e-001
252 |
253 | <_>
254 | 8
255 | -4.0009465217590332e+000
256 |
257 | <_>
258 |
259 | 0 -1 82 -1 -1 -1 -1 -33685505 -1 -1 -1
260 |
261 | -9.9707120656967163e-001 -8.9196771383285522e-001
262 | <_>
263 |
264 | 0 -1 84 -1 -1 -1 -1 2147446783 -1 -1 -1
265 |
266 | -9.8670446872711182e-001 -7.5064390897750854e-001
267 | <_>
268 |
269 | 0 -1 79 -1 -1 -262145 -1 -252379137 -1 -1 -1
270 |
271 | -8.9446705579757690e-001 7.0268943905830383e-002
272 | <_>
273 |
274 | 0 -1 61 -1 -8201 -1 -2097153 -16777217 -513 -16777217
275 | -1162149889
276 |
277 | -7.2166109085083008e-001 2.9786801338195801e-001
278 | <_>
279 |
280 | 0 -1 30 -21569 -1069121 1006578211 -134238545 -16450
281 | -268599297 -21617 -14680097
282 |
283 | -6.2449234724044800e-001 3.8551881909370422e-001
284 | <_>
285 |
286 | 0 -1 75 -268701913 -1999962377 1995165474 -453316822
287 | 1744684853 -2063597697 -134226057 -50336769
288 |
289 | -5.5207914113998413e-001 4.2211884260177612e-001
290 | <_>
291 |
292 | 0 -1 21 -352321825 -526489 -420020626 -486605074 1155483470
293 | -110104705 -587840772 -25428801
294 |
295 | -5.3324747085571289e-001 4.4535955786705017e-001
296 | <_>
297 |
298 | 0 -1 103 70270772 2012790229 -16810020 -245764 -1208090635
299 | -753667 -1073741828 -1363662420
300 |
301 | -6.4402890205383301e-001 3.8995954394340515e-001
302 |
303 | <_>
304 | 8
305 | -4.6897511482238770e+000
306 |
307 | <_>
308 |
309 | 0 -1 97 -1 -1 -1 -1 -524289 -524289 -1 -1
310 |
311 | -9.9684870243072510e-001 -8.8232177495956421e-001
312 | <_>
313 |
314 | 0 -1 84 -1 -1 -1 -1 2147438591 -1 -1 -1
315 |
316 | -9.8677414655685425e-001 -7.8965580463409424e-001
317 | <_>
318 |
319 | 0 -1 113 -1 -1 -1 -1 -1048577 -262149 -1048577 -35339
320 |
321 | -9.2621946334838867e-001 -2.9984828829765320e-001
322 | <_>
323 |
324 | 0 -1 33 -2249 867434291 -32769 -33562753 -1 -1073758209
325 | -4165 -1
326 |
327 | -7.2429555654525757e-001 2.2348840534687042e-001
328 | <_>
329 |
330 | 0 -1 98 1659068671 -142606337 587132538 -67108993 577718271
331 | -294921 -134479873 -129
332 |
333 | -5.5495566129684448e-001 3.5419258475303650e-001
334 | <_>
335 |
336 | 0 -1 100 -268441813 788267007 -286265494 -486576145 -8920251
337 | 2138505075 -151652570 -2050
338 |
339 | -5.3362584114074707e-001 3.9479774236679077e-001
340 | <_>
341 |
342 | 0 -1 51 -1368387212 -537102978 -98305 -163843 1065109500
343 | -16777217 -67321939 -1141359619
344 |
345 | -5.6162708997726440e-001 3.8008108735084534e-001
346 | <_>
347 |
348 | 0 -1 127 -268435550 1781120906 -251658720 -143130698
349 | -1048605 -1887436825 1979700688 -1008730125
350 |
351 | -5.1167154312133789e-001 4.0678605437278748e-001
352 |
353 | <_>
354 | 10
355 | -4.2179841995239258e+000
356 |
357 | <_>
358 |
359 | 0 -1 97 -1 -1 -1 -1 -524289 -524289 -1 -1
360 |
361 | -9.9685418605804443e-001 -8.8037383556365967e-001
362 | <_>
363 |
364 | 0 -1 90 -1 -1 -1 -1 -8912897 -524297 -8912897 -1
365 |
366 | -9.7972750663757324e-001 -5.7626229524612427e-001
367 | <_>
368 |
369 | 0 -1 96 -1 -1 -1 -1 -1 -65 -1 -2249
370 |
371 | -9.0239793062210083e-001 -1.7454113066196442e-001
372 | <_>
373 |
374 | 0 -1 71 -1 -4097 -1 -513 -16777217 -268468483 -16797697
375 | -1430589697
376 |
377 | -7.4346423149108887e-001 9.4165161252021790e-002
378 | <_>
379 |
380 | 0 -1 37 1364588304 -581845274 -536936460 -3 -308936705
381 | -1074331649 -4196865 -134225953
382 |
383 | -6.8877440690994263e-001 2.7647304534912109e-001
384 | <_>
385 |
386 | 0 -1 117 -37765187 -540675 -3 -327753 -1082458115 -65537
387 | 1071611901 536827253
388 |
389 | -5.7555085420608521e-001 3.4339720010757446e-001
390 | <_>
391 |
392 | 0 -1 85 -269490650 -1561395522 -1343312090 -857083986
393 | -1073750223 -369098755 -50856110 -2065
394 |
395 | -5.4036927223205566e-001 4.0065473318099976e-001
396 | <_>
397 |
398 | 0 -1 4 -425668880 -34427164 1879048177 -269570140 790740912
399 | -196740 2138535839 -536918145
400 |
401 | -4.8439365625381470e-001 4.4630467891693115e-001
402 | <_>
403 |
404 | 0 -1 92 74726960 -1246482434 -1 -246017 -1078607916
405 | -1073947163 -1644231687 -1359211496
406 |
407 | -5.6686979532241821e-001 3.6671569943428040e-001
408 | <_>
409 |
410 | 0 -1 11 -135274809 -1158173459 -353176850 540195262
411 | 2139086600 2071977814 -546898600 -96272673
412 |
413 | -5.1499199867248535e-001 4.0788397192955017e-001
414 |
415 | <_>
416 | 9
417 | -4.0345416069030762e+000
418 |
419 | <_>
420 |
421 | 0 -1 78 -1 -1 -1 -1 -8912897 -1 -8912897 -1
422 |
423 | -9.9573624134063721e-001 -8.5452395677566528e-001
424 | <_>
425 |
426 | 0 -1 93 -1 -1 -1 -1 -148635649 -524297 -8912897 -1
427 |
428 | -9.7307401895523071e-001 -5.2884924411773682e-001
429 | <_>
430 |
431 | 0 -1 77 -1 -8209 -1 -257 -772734977 -1 -201850881 -1
432 |
433 | -8.6225658655166626e-001 4.3712578713893890e-002
434 | <_>
435 |
436 | 0 -1 68 -570427393 -16649 -69633 -131073 -536944677 -1 -8737
437 | -1435828225
438 |
439 | -6.8078064918518066e-001 2.5120577216148376e-001
440 | <_>
441 |
442 | 0 -1 50 -1179697 -34082849 -3278356 -37429266 -1048578
443 | -555753474 -1015551096 -37489685
444 |
445 | -6.1699724197387695e-001 3.0963841080665588e-001
446 | <_>
447 |
448 | 0 -1 129 -1931606992 -17548804 -16842753 -1075021827
449 | 1073667572 -81921 -1611073620 -1415047752
450 |
451 | -6.0499197244644165e-001 3.0735063552856445e-001
452 | <_>
453 |
454 | 0 -1 136 -269754813 1761591286 -1073811523 2130378623 -17580
455 | -1082294665 -159514800 -1026883840
456 |
457 | -5.6772041320800781e-001 3.5023149847984314e-001
458 | <_>
459 |
460 | 0 -1 65 2016561683 1528827871 -10258447 960184191 125476830
461 | -8511618 -1078239365 187648611
462 |
463 | -5.5894804000854492e-001 3.4856522083282471e-001
464 | <_>
465 |
466 | 0 -1 13 -207423502 -333902 2013200231 -202348848 1042454451
467 | -16393 1073117139 2004162321
468 |
469 | -5.7197356224060059e-001 3.2818377017974854e-001
470 |
471 | <_>
472 | 9
473 | -3.4892759323120117e+000
474 |
475 | <_>
476 |
477 | 0 -1 78 -1 -1 -1 -1 -8912897 -1 -8912897 -1
478 |
479 | -9.8917990922927856e-001 -7.3812037706375122e-001
480 | <_>
481 |
482 | 0 -1 93 -1 -1 -1 -1 -148635649 -524297 -8912897 -1
483 |
484 | -9.3414896726608276e-001 -2.6945295929908752e-001
485 | <_>
486 |
487 | 0 -1 83 -1 -524289 -1 -1048577 1879011071 -32769 -524289
488 | -3178753
489 |
490 | -7.6891708374023438e-001 5.2568886429071426e-002
491 | <_>
492 |
493 | 0 -1 9 -352329729 -17891329 -16810117 -486871042 -688128841
494 | -1358954675 -16777218 -219217968
495 |
496 | -6.2337344884872437e-001 2.5143685936927795e-001
497 | <_>
498 |
499 | 0 -1 130 -2157 -1548812374 -1343233440 -418381854 -953155613
500 | -836960513 -713571200 -709888014
501 |
502 | -4.7277018427848816e-001 3.9616456627845764e-001
503 | <_>
504 |
505 | 0 -1 121 -1094717701 -67240065 -65857 -32899 -5783756
506 | -136446081 -134285352 -2003298884
507 |
508 | -5.1766264438629150e-001 3.5814732313156128e-001
509 | <_>
510 |
511 | 0 -1 23 -218830160 -119671186 5505075 1241491391 -1594469
512 | -2097185 2004828075 -67649541
513 |
514 | -6.5394639968872070e-001 3.0377501249313354e-001
515 | <_>
516 |
517 | 0 -1 115 -551814749 2099511088 -1090732551 -2045546512
518 | -1086341441 1059848178 800042912 252705994
519 |
520 | -5.2584588527679443e-001 3.3847147226333618e-001
521 | <_>
522 |
523 | 0 -1 99 -272651477 578776766 -285233490 -889225217
524 | 2147448656 377454463 2012701952 -68157761
525 |
526 | -6.1836904287338257e-001 2.8922611474990845e-001
527 |
528 | <_>
529 | 9
530 | -3.0220029354095459e+000
531 |
532 | <_>
533 |
534 | 0 -1 36 -1 -570425345 -1 -570425345 -1 -50331649 -6291457 -1
535 |
536 | -9.7703826427459717e-001 -6.2527233362197876e-001
537 | <_>
538 |
539 | 0 -1 124 -1430602241 -33619969 -1 -3 -1074003969 -1073758209
540 | -1073741825 -1073768705
541 |
542 | -8.9538317918777466e-001 -3.1887885928153992e-001
543 | <_>
544 |
545 | 0 -1 88 -1 -268439625 -65601 -268439569 -393809 -270532609
546 | -42076889 -288361721
547 |
548 | -6.8733429908752441e-001 1.2978810071945190e-001
549 | <_>
550 |
551 | 0 -1 132 -755049252 2042563807 1795096575 465121071
552 | -1090585188 -20609 -1459691784 539672495
553 |
554 | -5.7038843631744385e-001 3.0220884084701538e-001
555 | <_>
556 |
557 | 0 -1 20 -94377762 -25702678 1694167798 -231224662 1079955016
558 | -346144140 2029995743 -536918961
559 |
560 | -5.3204691410064697e-001 3.4054222702980042e-001
561 | <_>
562 |
563 | 0 -1 47 2143026943 -285278225 -3 -612438281 -16403 -131074
564 | -1 -1430749256
565 |
566 | -4.6176829934120178e-001 4.1114711761474609e-001
567 | <_>
568 |
569 | 0 -1 74 203424336 -25378820 -35667973 1073360894 -1912815660
570 | -573444 -356583491 -1365235056
571 |
572 | -4.9911966919898987e-001 3.5335537791252136e-001
573 | <_>
574 |
575 | 0 -1 6 -1056773 -1508430 -558153 -102747408 2133997491
576 | -269043865 2004842231 -8947721
577 |
578 | -4.0219521522521973e-001 4.3947893381118774e-001
579 | <_>
580 |
581 | 0 -1 70 -880809694 -1070282769 -1363162108 -838881281
582 | -680395161 -2064124929 -34244753 1173880701
583 |
584 | -5.3891533613204956e-001 3.2062566280364990e-001
585 |
586 | <_>
587 | 8
588 | -2.5489892959594727e+000
589 |
590 | <_>
591 |
592 | 0 -1 39 -1 -572522497 -8519681 -570425345 -4195329 -50333249
593 | -1 -1
594 |
595 | -9.4647216796875000e-001 -3.3662387728691101e-001
596 | <_>
597 |
598 | 0 -1 124 -1430735362 -33619971 -8201 -3 -1677983745
599 | -1073762817 -1074003969 -1142979329
600 |
601 | -8.0300611257553101e-001 -3.8466516882181168e-002
602 | <_>
603 |
604 | 0 -1 91 -67113217 -524289 -671482265 -786461 1677132031
605 | -268473345 -68005889 -70291765
606 |
607 | -5.8367580175399780e-001 2.6507318019866943e-001
608 | <_>
609 |
610 | 0 -1 17 -277872641 -553910292 -268435458 -16843010
611 | 1542420439 -1342178311 -143132940 -2834
612 |
613 | -4.6897178888320923e-001 3.7864661216735840e-001
614 | <_>
615 |
616 | 0 -1 137 -1312789 -290527285 -286326862 -5505280 -1712335966
617 | -2045979188 1165423617 -709363723
618 |
619 | -4.6382644772529602e-001 3.6114525794982910e-001
620 | <_>
621 |
622 | 0 -1 106 1355856590 -109445156 -96665606 2066939898
623 | 1356084692 1549031917 -30146561 -16581701
624 |
625 | -6.3095021247863770e-001 2.9294869303703308e-001
626 | <_>
627 |
628 | 0 -1 104 -335555328 118529 1860167712 -810680357 -33558656
629 | -1368391795 -402663552 -1343225921
630 |
631 | -5.9658926725387573e-001 2.7228885889053345e-001
632 | <_>
633 |
634 | 0 -1 76 217581168 -538349634 1062631419 1039868926
635 | -1090707460 -2228359 -1078042693 -1147128518
636 |
637 | -4.5812287926673889e-001 3.7063929438591003e-001
638 |
639 | <_>
640 | 9
641 | -2.5802578926086426e+000
642 |
643 | <_>
644 |
645 | 0 -1 35 -513 -706873891 -270541825 1564475391 -120602625
646 | -118490145 -3162113 -1025
647 |
648 | -8.9068460464477539e-001 -1.6470588743686676e-001
649 | <_>
650 |
651 | 0 -1 41 -1025 872144563 -2105361 -1078076417 -1048577
652 | -1145061461 -87557413 -1375993973
653 |
654 | -7.1808964014053345e-001 2.2022204473614693e-002
655 | <_>
656 |
657 | 0 -1 95 -42467849 967946223 -811601986 1030598351
658 | -1212430676 270856533 -1392539508 147705039
659 |
660 | -4.9424821138381958e-001 3.0048963427543640e-001
661 | <_>
662 |
663 | 0 -1 10 -218116370 -637284625 -87373174 -521998782
664 | -805355450 -615023745 -814267322 -12069282
665 |
666 | -5.5306458473205566e-001 2.9137542843818665e-001
667 | <_>
668 |
669 | 0 -1 105 -275849241 -527897 -11052049 -69756067 -15794193
670 | -1141376839 -564771 -287095455
671 |
672 | -4.6759819984436035e-001 3.6638516187667847e-001
673 | <_>
674 |
675 | 0 -1 24 -1900898096 -18985228 -44056577 -24675 -1074880639
676 | -283998 796335613 -1079041957
677 |
678 | -4.2737138271331787e-001 3.9243003726005554e-001
679 | <_>
680 |
681 | 0 -1 139 -555790844 410735094 -32106513 406822863 -897632192
682 | -912830145 -117771560 -1204027649
683 |
684 | -4.1896930336952209e-001 3.6744937300682068e-001
685 | <_>
686 |
687 | 0 -1 0 -1884822366 -1406613148 1135342180 -1979127580
688 | -68174862 246469804 1001386992 -708885872
689 |
690 | -5.7093089818954468e-001 2.9880744218826294e-001
691 | <_>
692 |
693 | 0 -1 45 -469053950 1439068142 2117758841 2004671078
694 | 207931006 1265321675 970353931 1541343047
695 |
696 | -6.0491901636123657e-001 2.4652053415775299e-001
697 |
698 | <_>
699 | 9
700 | -2.2425732612609863e+000
701 |
702 | <_>
703 |
704 | 0 -1 58 1481987157 282547485 -14952129 421131223 -391065352
705 | -24212488 -100094241 -1157907473
706 |
707 | -8.2822084426879883e-001 -2.1619293093681335e-001
708 | <_>
709 |
710 | 0 -1 126 -134217889 -543174305 -75497474 -16851650 -6685738
711 | -75834693 -2097200 -262146
712 |
713 | -5.4628932476043701e-001 2.7662658691406250e-001
714 | <_>
715 |
716 | 0 -1 133 -220728227 -604288517 -661662214 413104863
717 | -627323700 -251915415 -626200872 -1157958657
718 |
719 | -4.1643124818801880e-001 4.1700571775436401e-001
720 | <_>
721 |
722 | 0 -1 2 -186664033 -44236961 -1630262774 -65163606 -103237330
723 | -3083265 -1003729 2053105955
724 |
725 | -5.4847818613052368e-001 2.9710745811462402e-001
726 | <_>
727 |
728 | 0 -1 62 -256115886 -237611873 -620250696 387061799
729 | 1437882671 274878849 -8684449 1494294023
730 |
731 | -4.6202757954597473e-001 3.3915829658508301e-001
732 | <_>
733 |
734 | 0 -1 1 -309400577 -275864640 -1056864869 1737132756
735 | -272385089 1609671419 1740601343 1261376789
736 |
737 | -4.6158722043037415e-001 3.3939516544342041e-001
738 | <_>
739 |
740 | 0 -1 102 818197248 -196324552 286970589 -573270699
741 | -1174099579 -662077381 -1165157895 -1626859296
742 |
743 | -4.6193107962608337e-001 3.2456985116004944e-001
744 | <_>
745 |
746 | 0 -1 69 -1042550357 14675409 1367955200 -841482753
747 | 1642443255 8774277 1941304147 1099949563
748 |
749 | -4.9091196060180664e-001 3.3870378136634827e-001
750 | <_>
751 |
752 | 0 -1 72 -639654997 1375720439 -2129542805 1614801090
753 | -626787937 -5779294 1488699183 -525406458
754 |
755 | -4.9073097109794617e-001 3.0637946724891663e-001
756 |
757 | <_>
758 | 9
759 | -1.2258235216140747e+000
760 |
761 | <_>
762 |
763 | 0 -1 118 302046707 -16744240 1360106207 -543735387
764 | 1025700851 -1079408512 1796961263 -6334981
765 |
766 | -6.1358314752578735e-001 2.3539231717586517e-001
767 | <_>
768 |
769 | 0 -1 5 -144765953 -116448726 -653851877 1934829856 722021887
770 | 856564834 1933919231 -540838029
771 |
772 | -5.1209545135498047e-001 3.2506987452507019e-001
773 | <_>
774 |
775 | 0 -1 140 -170132825 -1438923874 1879300370 -1689337194
776 | -695606496 285911565 -1044188928 -154210028
777 |
778 | -5.1769560575485229e-001 3.2290914654731750e-001
779 | <_>
780 |
781 | 0 -1 131 -140776261 -355516414 822178224 -1039743806
782 | -1012208926 134887424 1438876097 -908591660
783 |
784 | -5.0321841239929199e-001 3.0263835191726685e-001
785 | <_>
786 |
787 | 0 -1 64 -2137211696 -1634281249 1464325973 498569935
788 | -1580152080 -2001687927 721783561 265096035
789 |
790 | -4.6532225608825684e-001 3.4638473391532898e-001
791 | <_>
792 |
793 | 0 -1 101 -255073589 -211824417 -972195129 -1063415417
794 | 1937994261 1363165220 -754733105 1967602541
795 |
796 | -4.9611270427703857e-001 3.3260712027549744e-001
797 | <_>
798 |
799 | 0 -1 81 -548146862 -655567194 -2062466596 1164562721
800 | 416408236 -1591631712 -83637777 975344427
801 |
802 | -4.9862930178642273e-001 3.2003280520439148e-001
803 | <_>
804 |
805 | 0 -1 55 -731904652 2147179896 2147442687 2112830847 -65604
806 | -131073 -42139667 -1074907393
807 |
808 | -3.6636069416999817e-001 4.5651626586914063e-001
809 | <_>
810 |
811 | 0 -1 67 1885036886 571985932 -1784930633 724431327
812 | 1940422257 -1085746880 964888398 731867951
813 |
814 | -5.2619713544845581e-001 3.2635414600372314e-001
815 |
816 | <_>
817 | 9
818 | -1.3604533672332764e+000
819 |
820 | <_>
821 |
822 | 0 -1 8 -287609985 -965585953 -2146397793 -492129894
823 | -729029645 -544619901 -645693256 -6565484
824 |
825 | -4.5212322473526001e-001 3.8910505175590515e-001
826 | <_>
827 |
828 | 0 -1 122 -102903523 -145031013 536899675 688195859
829 | -645291520 -1165359094 -905565928 171608223
830 |
831 | -4.9594074487686157e-001 3.4109055995941162e-001
832 | <_>
833 |
834 | 0 -1 134 -790640459 487931983 1778450522 1036604041
835 | -904752984 -954040118 -2134707506 304866043
836 |
837 | -4.1148442029953003e-001 3.9666590094566345e-001
838 | <_>
839 |
840 | 0 -1 141 -303829117 1726939070 922189815 -827983123
841 | 1567883042 1324809852 292710260 -942678754
842 |
843 | -3.5154473781585693e-001 4.8011952638626099e-001
844 | <_>
845 |
846 | 0 -1 59 -161295376 -159215460 -1858041315 2140644499
847 | -2009065472 -133804007 -2003265301 1263206851
848 |
849 | -4.2808216810226440e-001 3.9841541647911072e-001
850 | <_>
851 |
852 | 0 -1 34 -264248081 -667846464 1342624856 1381160835
853 | -2104716852 1342865409 -266612310 -165954877
854 |
855 | -4.3293288350105286e-001 4.0339657664299011e-001
856 | <_>
857 |
858 | 0 -1 32 -1600388464 -40369901 285344639 1394344275
859 | -255680312 -100532214 -1031663944 -7471079
860 |
861 | -4.1385015845298767e-001 4.5087572932243347e-001
862 | <_>
863 |
864 | 0 -1 15 1368521651 280207469 35779199 -105983261 1208124819
865 | -565870452 -1144024288 -591535344
866 |
867 | -4.2956474423408508e-001 4.2176279425621033e-001
868 | <_>
869 |
870 | 0 -1 109 1623607527 -661513115 -1073217263 -2142994420
871 | -1339883309 -89816956 436308899 1426178059
872 |
873 | -4.7764992713928223e-001 3.7551075220108032e-001
874 |
875 | <_>
876 | 9
877 | -4.2518746852874756e-001
878 |
879 | <_>
880 |
881 | 0 -1 135 -116728032 -1154420809 -1350582273 746061691
882 | -1073758277 2138570623 2113797566 -138674182
883 |
884 | -1.7125381529331207e-001 6.5421247482299805e-001
885 | <_>
886 |
887 | 0 -1 63 -453112432 -1795354691 -1342242964 494112553
888 | 209458404 -2114697500 1316830362 259213855
889 |
890 | -3.9870172739028931e-001 4.5807033777236938e-001
891 | <_>
892 |
893 | 0 -1 52 -268172036 294715533 268575185 486785157 -1065303920
894 | -360185856 -2147476808 134777113
895 |
896 | -5.3581339120864868e-001 3.5815808176994324e-001
897 | <_>
898 |
899 | 0 -1 86 -301996882 -345718921 1877946252 -940720129
900 | -58737369 -721944585 -92954835 -530449
901 |
902 | -3.9938014745712280e-001 4.9603295326232910e-001
903 | <_>
904 |
905 | 0 -1 14 -853281886 -756895766 2130706352 -9519120
906 | -1921059862 394133373 2138453959 -538200841
907 |
908 | -4.0230083465576172e-001 4.9537116289138794e-001
909 | <_>
910 |
911 | 0 -1 128 -2133448688 -641138493 1078022185 294060066
912 | -327122776 -2130640896 -2147466247 -1910634326
913 |
914 | -5.8290809392929077e-001 3.4102553129196167e-001
915 | <_>
916 |
917 | 0 -1 53 587265978 -2071658479 1108361221 -578448765
918 | -1811905899 -2008965119 33900729 762301595
919 |
920 | -4.5518967509269714e-001 4.7242793440818787e-001
921 | <_>
922 |
923 | 0 -1 138 -1022189373 -2139094976 16658 -1069445120
924 | -1073555454 -1073577856 1096068 -978351488
925 |
926 | -4.7530207037925720e-001 4.3885371088981628e-001
927 | <_>
928 |
929 | 0 -1 7 -395352441 -1073541103 -1056964605 1053186 269111298
930 | -2012184576 1611208714 -360415095
931 |
932 | -5.0448113679885864e-001 4.1588482260704041e-001
933 |
934 | <_>
935 | 7
936 | 2.7163455262780190e-002
937 |
938 | <_>
939 |
940 | 0 -1 49 783189748 -137429026 -257 709557994 2130460236
941 | -196611 -9580 585428708
942 |
943 | -2.0454545319080353e-001 7.9608374834060669e-001
944 | <_>
945 |
946 | 0 -1 108 1284360448 1057423155 1592696573 -852672655
947 | 1547382714 -1642594369 125705358 797134398
948 |
949 | -3.6474677920341492e-001 6.0925579071044922e-001
950 | <_>
951 |
952 | 0 -1 94 1347680270 -527720448 1091567712 1073745933
953 | -1073180671 0 285745154 -511192438
954 |
955 | -4.6406838297843933e-001 5.5626088380813599e-001
956 | <_>
957 |
958 | 0 -1 73 1705780944 -145486260 -115909 -281793505 -418072663
959 | -1681064068 1877454127 -1912330993
960 |
961 | -4.7043186426162720e-001 5.8430361747741699e-001
962 | <_>
963 |
964 | 0 -1 110 -2118142016 339509033 -285260567 1417764573
965 | 68144392 -468879483 -2033291636 231451911
966 |
967 | -4.8700931668281555e-001 5.4639810323715210e-001
968 | <_>
969 |
970 | 0 -1 119 -1888051818 489996135 -65539 849536890 2146716845
971 | -1107542088 -1275615746 -1119617586
972 |
973 | -4.3356490135192871e-001 6.5175366401672363e-001
974 | <_>
975 |
976 | 0 -1 44 -1879021438 336830528 1073766659 1477541961 8560696
977 | -1207369568 8462472 1493893448
978 |
979 | -5.4343086481094360e-001 5.2777874469757080e-001
980 |
981 | <_>
982 | 7
983 | 4.9174150824546814e-001
984 |
985 | <_>
986 |
987 | 0 -1 57 644098 15758324 1995964260 -463011882 893285175
988 | 83156983 2004317989 16021237
989 |
990 | -1.7073170840740204e-001 9.0782123804092407e-001
991 | <_>
992 |
993 | 0 -1 123 268632845 -2147450864 -2143240192 -2147401728
994 | 8523937 -1878523840 16777416 616824984
995 |
996 | -4.8744434118270874e-001 7.3311311006546021e-001
997 | <_>
998 |
999 | 0 -1 120 -2110735872 803880886 989739810 1673281312 91564930
1000 | -277454958 997709514 -581366443
1001 |
1002 | -4.0291741490364075e-001 8.2450771331787109e-001
1003 | <_>
1004 |
1005 | 0 -1 87 941753434 -1067128905 788512753 -1074450460
1006 | 779101657 -1346552460 938805167 -2050424642
1007 |
1008 | -3.6246949434280396e-001 8.7103593349456787e-001
1009 | <_>
1010 |
1011 | 0 -1 60 208 1645217920 130 538263552 33595552 -1475870592
1012 | 16783361 1375993867
1013 |
1014 | -6.1472141742706299e-001 5.9707164764404297e-001
1015 | <_>
1016 |
1017 | 0 -1 114 1860423179 1034692624 -285213187 -986681712
1018 | 1576755092 -1408205463 -127714 -1246035687
1019 |
1020 | -4.5621752738952637e-001 8.9482426643371582e-001
1021 | <_>
1022 |
1023 | 0 -1 107 33555004 -1861746688 1073807361 -754909184
1024 | 645922856 8388608 134250648 419635458
1025 |
1026 | -5.2466005086898804e-001 7.1834069490432739e-001
1027 |
1028 | <_>
1029 | 2
1030 | 1.9084988832473755e+000
1031 |
1032 | <_>
1033 |
1034 | 0 -1 16 536064 131072 -20971516 524288 576 1048577 0 40960
1035 |
1036 | -8.0000001192092896e-001 9.8018401861190796e-001
1037 | <_>
1038 |
1039 | 0 -1 56 67108864 0 4096 1074003968 8192 536870912 4 262144
1040 |
1041 | -9.6610915660858154e-001 9.2831486463546753e-001
1042 |
1043 | <_>
1044 |
1045 | 0 0 1 1
1046 | <_>
1047 |
1048 | 0 0 3 2
1049 | <_>
1050 |
1051 | 0 1 13 6
1052 | <_>
1053 |
1054 | 0 2 3 14
1055 | <_>
1056 |
1057 | 0 2 4 2
1058 | <_>
1059 |
1060 | 0 6 2 3
1061 | <_>
1062 |
1063 | 0 6 3 2
1064 | <_>
1065 |
1066 | 0 16 1 3
1067 | <_>
1068 |
1069 | 0 20 3 3
1070 | <_>
1071 |
1072 | 0 22 2 3
1073 | <_>
1074 |
1075 | 0 28 4 4
1076 | <_>
1077 |
1078 | 0 35 2 3
1079 | <_>
1080 |
1081 | 1 0 14 7
1082 | <_>
1083 |
1084 | 1 5 3 2
1085 | <_>
1086 |
1087 | 1 6 2 1
1088 | <_>
1089 |
1090 | 1 14 10 9
1091 | <_>
1092 |
1093 | 1 21 4 4
1094 | <_>
1095 |
1096 | 1 23 4 2
1097 | <_>
1098 |
1099 | 2 0 13 7
1100 | <_>
1101 |
1102 | 2 0 14 7
1103 | <_>
1104 |
1105 | 2 33 5 4
1106 | <_>
1107 |
1108 | 2 36 4 3
1109 | <_>
1110 |
1111 | 2 39 3 2
1112 | <_>
1113 |
1114 | 3 1 13 11
1115 | <_>
1116 |
1117 | 3 2 3 2
1118 | <_>
1119 |
1120 | 4 0 7 8
1121 | <_>
1122 |
1123 | 4 0 13 7
1124 | <_>
1125 |
1126 | 5 0 12 6
1127 | <_>
1128 |
1129 | 5 0 13 7
1130 | <_>
1131 |
1132 | 5 1 10 13
1133 | <_>
1134 |
1135 | 5 1 12 7
1136 | <_>
1137 |
1138 | 5 2 7 13
1139 | <_>
1140 |
1141 | 5 4 2 1
1142 | <_>
1143 |
1144 | 5 8 7 4
1145 | <_>
1146 |
1147 | 5 39 3 2
1148 | <_>
1149 |
1150 | 6 3 5 2
1151 | <_>
1152 |
1153 | 6 3 6 2
1154 | <_>
1155 |
1156 | 6 5 4 12
1157 | <_>
1158 |
1159 | 6 9 6 3
1160 | <_>
1161 |
1162 | 7 3 5 2
1163 | <_>
1164 |
1165 | 7 3 6 13
1166 | <_>
1167 |
1168 | 7 5 6 4
1169 | <_>
1170 |
1171 | 7 7 6 10
1172 | <_>
1173 |
1174 | 7 8 6 4
1175 | <_>
1176 |
1177 | 7 32 5 4
1178 | <_>
1179 |
1180 | 7 33 5 4
1181 | <_>
1182 |
1183 | 8 0 1 1
1184 | <_>
1185 |
1186 | 8 0 2 1
1187 | <_>
1188 |
1189 | 8 2 10 7
1190 | <_>
1191 |
1192 | 9 0 6 2
1193 | <_>
1194 |
1195 | 9 2 9 3
1196 | <_>
1197 |
1198 | 9 4 1 1
1199 | <_>
1200 |
1201 | 9 6 2 1
1202 | <_>
1203 |
1204 | 9 28 6 4
1205 | <_>
1206 |
1207 | 10 0 9 3
1208 | <_>
1209 |
1210 | 10 3 1 1
1211 | <_>
1212 |
1213 | 10 10 11 11
1214 | <_>
1215 |
1216 | 10 15 4 3
1217 | <_>
1218 |
1219 | 11 4 2 1
1220 | <_>
1221 |
1222 | 11 27 4 3
1223 | <_>
1224 |
1225 | 11 36 8 2
1226 | <_>
1227 |
1228 | 12 0 2 2
1229 | <_>
1230 |
1231 | 12 23 4 3
1232 | <_>
1233 |
1234 | 12 25 4 3
1235 | <_>
1236 |
1237 | 12 29 5 3
1238 | <_>
1239 |
1240 | 12 33 3 4
1241 | <_>
1242 |
1243 | 13 0 2 2
1244 | <_>
1245 |
1246 | 13 36 8 3
1247 | <_>
1248 |
1249 | 14 0 2 2
1250 | <_>
1251 |
1252 | 15 15 2 2
1253 | <_>
1254 |
1255 | 16 13 3 4
1256 | <_>
1257 |
1258 | 17 0 1 3
1259 | <_>
1260 |
1261 | 17 1 3 3
1262 | <_>
1263 |
1264 | 17 31 5 3
1265 | <_>
1266 |
1267 | 17 35 3 1
1268 | <_>
1269 |
1270 | 18 13 2 3
1271 | <_>
1272 |
1273 | 18 39 2 1
1274 | <_>
1275 |
1276 | 19 0 7 15
1277 | <_>
1278 |
1279 | 19 2 7 2
1280 | <_>
1281 |
1282 | 19 3 7 13
1283 | <_>
1284 |
1285 | 19 14 2 2
1286 | <_>
1287 |
1288 | 19 24 7 4
1289 | <_>
1290 |
1291 | 20 1 6 13
1292 | <_>
1293 |
1294 | 20 8 7 3
1295 | <_>
1296 |
1297 | 20 9 7 3
1298 | <_>
1299 |
1300 | 20 13 1 1
1301 | <_>
1302 |
1303 | 20 14 2 3
1304 | <_>
1305 |
1306 | 20 30 3 2
1307 | <_>
1308 |
1309 | 21 0 3 4
1310 | <_>
1311 |
1312 | 21 0 6 8
1313 | <_>
1314 |
1315 | 21 3 6 2
1316 | <_>
1317 |
1318 | 21 6 6 4
1319 | <_>
1320 |
1321 | 21 37 2 1
1322 | <_>
1323 |
1324 | 22 3 6 2
1325 | <_>
1326 |
1327 | 22 13 1 2
1328 | <_>
1329 |
1330 | 22 22 4 3
1331 | <_>
1332 |
1333 | 23 0 2 3
1334 | <_>
1335 |
1336 | 23 3 6 2
1337 | <_>
1338 |
1339 | 23 9 5 4
1340 | <_>
1341 |
1342 | 23 11 1 1
1343 | <_>
1344 |
1345 | 23 15 1 1
1346 | <_>
1347 |
1348 | 23 16 3 2
1349 | <_>
1350 |
1351 | 23 35 2 1
1352 | <_>
1353 |
1354 | 23 36 1 1
1355 | <_>
1356 |
1357 | 23 39 6 2
1358 | <_>
1359 |
1360 | 24 0 2 3
1361 | <_>
1362 |
1363 | 24 8 6 11
1364 | <_>
1365 |
1366 | 24 28 2 2
1367 | <_>
1368 |
1369 | 24 33 4 4
1370 | <_>
1371 |
1372 | 25 16 4 3
1373 | <_>
1374 |
1375 | 25 31 5 3
1376 | <_>
1377 |
1378 | 26 0 1 2
1379 | <_>
1380 |
1381 | 26 0 2 2
1382 | <_>
1383 |
1384 | 26 0 3 2
1385 | <_>
1386 |
1387 | 26 24 4 4
1388 | <_>
1389 |
1390 | 27 30 4 5
1391 | <_>
1392 |
1393 | 27 36 5 3
1394 | <_>
1395 |
1396 | 28 0 2 2
1397 | <_>
1398 |
1399 | 28 4 2 1
1400 | <_>
1401 |
1402 | 28 21 2 5
1403 | <_>
1404 |
1405 | 29 8 2 1
1406 | <_>
1407 |
1408 | 33 0 2 1
1409 | <_>
1410 |
1411 | 33 0 4 2
1412 | <_>
1413 |
1414 | 33 0 4 6
1415 | <_>
1416 |
1417 | 33 3 1 1
1418 | <_>
1419 |
1420 | 33 6 4 12
1421 | <_>
1422 |
1423 | 33 21 4 2
1424 | <_>
1425 |
1426 | 33 36 4 3
1427 | <_>
1428 |
1429 | 35 1 2 2
1430 | <_>
1431 |
1432 | 36 5 1 1
1433 | <_>
1434 |
1435 | 36 29 3 4
1436 | <_>
1437 |
1438 | 36 39 2 2
1439 | <_>
1440 |
1441 | 37 5 2 2
1442 | <_>
1443 |
1444 | 38 6 2 1
1445 | <_>
1446 |
1447 | 38 6 2 2
1448 | <_>
1449 |
1450 | 39 1 2 12
1451 | <_>
1452 |
1453 | 39 24 1 2
1454 | <_>
1455 |
1456 | 39 36 2 2
1457 | <_>
1458 |
1459 | 40 39 1 2
1460 | <_>
1461 |
1462 | 42 4 1 1
1463 | <_>
1464 |
1465 | 42 20 1 2
1466 | <_>
1467 |
1468 | 42 29 1 2
1469 |
1470 |
--------------------------------------------------------------------------------
/models/eyenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from models.layers import Conv, Hourglass, Pool, Residual
4 | from models.losses import HeatmapLoss
5 | from util.softargmax import softargmax2d
6 |
7 |
8 | class Merge(nn.Module):
9 | def __init__(self, x_dim, y_dim):
10 | super(Merge, self).__init__()
11 | self.conv = Conv(x_dim, y_dim, 1, relu=False, bn=False)
12 |
13 | def forward(self, x):
14 | return self.conv(x)
15 |
16 |
17 | class EyeNet(nn.Module):
18 | def __init__(self, nstack, nfeatures, nlandmarks, bn=False, increase=0, **kwargs):
19 | super(EyeNet, self).__init__()
20 |
21 | self.img_w = 160
22 | self.img_h = 96
23 | self.nstack = nstack
24 | self.nfeatures = nfeatures
25 | self.nlandmarks = nlandmarks
26 |
27 | self.heatmap_w = self.img_w / 2
28 | self.heatmap_h = self.img_h / 2
29 |
30 | self.nstack = nstack
31 | self.pre = nn.Sequential(
32 | Conv(1, 64, 7, 1, bn=True, relu=True),
33 | Residual(64, 128),
34 | Pool(2, 2),
35 | Residual(128, 128),
36 | Residual(128, nfeatures)
37 | )
38 |
39 | self.pre2 = nn.Sequential(
40 | Conv(nfeatures, 64, 7, 2, bn=True, relu=True),
41 | Residual(64, 128),
42 | Pool(2, 2),
43 | Residual(128, 128),
44 | Residual(128, nfeatures)
45 | )
46 |
47 | self.hgs = nn.ModuleList([
48 | nn.Sequential(
49 | Hourglass(4, nfeatures, bn, increase),
50 | ) for i in range(nstack)])
51 |
52 | self.features = nn.ModuleList([
53 | nn.Sequential(
54 | Residual(nfeatures, nfeatures),
55 | Conv(nfeatures, nfeatures, 1, bn=True, relu=True)
56 | ) for i in range(nstack)])
57 |
58 | self.outs = nn.ModuleList([Conv(nfeatures, nlandmarks, 1, relu=False, bn=False) for i in range(nstack)])
59 | self.merge_features = nn.ModuleList([Merge(nfeatures, nfeatures) for i in range(nstack - 1)])
60 | self.merge_preds = nn.ModuleList([Merge(nlandmarks, nfeatures) for i in range(nstack - 1)])
61 |
62 | self.gaze_fc1 = nn.Linear(in_features=int(nfeatures * self.img_w * self.img_h / 64 + nlandmarks*2), out_features=256)
63 | self.gaze_fc2 = nn.Linear(in_features=256, out_features=2)
64 |
65 | self.nstack = nstack
66 | self.heatmapLoss = HeatmapLoss()
67 | self.landmarks_loss = nn.MSELoss()
68 | self.gaze_loss = nn.MSELoss()
69 |
70 | def forward(self, imgs):
71 | # imgs: N x ih x iw (grayscale); unsqueeze adds the channel dim -> N x 1 x ih x iw
72 | x = imgs.unsqueeze(1)
73 | x = self.pre(x)
74 |
75 | gaze_x = self.pre2(x)
76 | gaze_x = gaze_x.flatten(start_dim=1)
77 |
78 | combined_hm_preds = []
79 | for i in range(self.nstack):
80 | hg = self.hgs[i](x)
81 | feature = self.features[i](hg)
82 | preds = self.outs[i](feature)
83 | combined_hm_preds.append(preds)
84 | if i < self.nstack - 1:
85 | x = x + self.merge_preds[i](preds) + self.merge_features[i](feature)
86 |
87 | heatmaps_out = torch.stack(combined_hm_preds, 1)
88 |
89 | # preds: N x nlandmarks x heatmap_h x heatmap_w (output of the last hourglass stack)
90 | landmarks_out = softargmax2d(preds) # N x nlandmarks x 2
91 |
92 | # Gaze
93 | gaze = torch.cat((gaze_x, landmarks_out.flatten(start_dim=1)), dim=1)
94 | gaze = self.gaze_fc1(gaze)
95 | gaze = nn.functional.relu(gaze)
96 | gaze = self.gaze_fc2(gaze)
97 |
98 | return heatmaps_out, landmarks_out, gaze
99 |
100 | def calc_loss(self, combined_hm_preds, heatmaps, landmarks_pred, landmarks, gaze_pred, gaze):
101 | combined_loss = []
102 | for i in range(self.nstack):
103 | combined_loss.append(self.heatmapLoss(combined_hm_preds[:, i, :], heatmaps))
104 |
105 | heatmap_loss = torch.stack(combined_loss, dim=1)
106 | landmarks_loss = self.landmarks_loss(landmarks_pred, landmarks)
107 | gaze_loss = self.gaze_loss(gaze_pred, gaze)
108 |
109 | return torch.sum(heatmap_loss), landmarks_loss, 1000 * gaze_loss
110 |
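
The `in_features` of `gaze_fc1` in `EyeNet` depends on how far `pre` and `pre2` shrink the 160x96 input. The sketch below redoes that arithmetic in plain Python, without torch. The helper names `conv_out` and `gaze_fc1_in_features` are hypothetical, and `nfeatures=32` and `nlandmarks=34` are example values only. It assumes the `(kernel_size - 1) // 2` padding from `Conv` in models/layers.py and that `Residual` blocks keep the spatial size.

```python
def conv_out(size, kernel, stride):
    """Output length of a conv padded with (kernel - 1) // 2, as in models/layers.py."""
    pad = (kernel - 1) // 2
    return (size + 2 * pad - kernel) // stride + 1


def gaze_fc1_in_features(img_w=160, img_h=96, nfeatures=32, nlandmarks=34):
    # self.pre: 7x7 conv with stride 1, then a 2x2 max pool.
    w, h = conv_out(img_w, 7, 1) // 2, conv_out(img_h, 7, 1) // 2
    # self.pre2: 7x7 conv with stride 2, then a 2x2 max pool.
    w, h = conv_out(w, 7, 2) // 2, conv_out(h, 7, 2) // 2
    # Flattened pre2 features plus one (x, y) pair per landmark.
    return nfeatures * w * h + nlandmarks * 2, (w, h)


n, (w, h) = gaze_fc1_in_features()
print(w, h, n)
```

Under these assumptions the gaze branch sees a 20x12 feature map, which agrees with the `img_w * img_h / 64` term in the constructor (160 * 96 / 64 = 240 = 20 * 12).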
--------------------------------------------------------------------------------
/models/layers.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | Pool = nn.MaxPool2d
4 |
5 |
6 | def batchnorm(x):
7 | return nn.BatchNorm2d(x.size()[1])(x)
8 |
9 |
10 | class Conv(nn.Module):
11 | def __init__(self, inp_dim, out_dim, kernel_size=3, stride = 1, bn = False, relu = True):
12 | super(Conv, self).__init__()
13 | self.inp_dim = inp_dim
14 | self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=(kernel_size-1)//2, bias=True)
15 | self.relu = None
16 | self.bn = None
17 | if relu:
18 | self.relu = nn.ReLU()
19 | if bn:
20 | self.bn = nn.BatchNorm2d(out_dim)
21 |
22 | def forward(self, x):
23 | assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim)
24 | x = self.conv(x)
25 | if self.bn is not None:
26 | x = self.bn(x)
27 | if self.relu is not None:
28 | x = self.relu(x)
29 | return x
30 |
31 |
32 | class Residual(nn.Module):
33 | def __init__(self, inp_dim, out_dim):
34 | super(Residual, self).__init__()
35 | self.relu = nn.ReLU()
36 | self.bn1 = nn.BatchNorm2d(inp_dim)
37 | self.conv1 = Conv(inp_dim, int(out_dim/2), 1, relu=False)
38 | self.bn2 = nn.BatchNorm2d(int(out_dim/2))
39 | self.conv2 = Conv(int(out_dim/2), int(out_dim/2), 3, relu=False)
40 | self.bn3 = nn.BatchNorm2d(int(out_dim/2))
41 | self.conv3 = Conv(int(out_dim/2), out_dim, 1, relu=False)
42 | self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False)
43 | if inp_dim == out_dim:
44 | self.need_skip = False
45 | else:
46 | self.need_skip = True
47 |
48 | def forward(self, x):
49 | if self.need_skip:
50 | residual = self.skip_layer(x)
51 | else:
52 | residual = x
53 | out = self.bn1(x)
54 | out = self.relu(out)
55 | out = self.conv1(out)
56 | out = self.bn2(out)
57 | out = self.relu(out)
58 | out = self.conv2(out)
59 | out = self.bn3(out)
60 | out = self.relu(out)
61 | out = self.conv3(out)
62 | out += residual
63 | return out
64 |
65 |
66 | class Hourglass(nn.Module):
67 | def __init__(self, n, f, bn=None, increase=0):
68 | super(Hourglass, self).__init__()
69 | nf = f + increase
70 | self.up1 = Residual(f, f)
71 | # Lower branch
72 | self.pool1 = Pool(2, 2)
73 | self.low1 = Residual(f, nf)
74 | self.n = n
75 | # Recursive hourglass
76 | if self.n > 1:
77 | self.low2 = Hourglass(n-1, nf, bn=bn)
78 | else:
79 | self.low2 = Residual(nf, nf)
80 | self.low3 = Residual(nf, f)
81 |
82 | def forward(self, x):
83 | up1 = self.up1(x)
84 | pool1 = self.pool1(x)
85 | low1 = self.low1(pool1)
86 | low2 = self.low2(low1)
87 | low3 = self.low3(low2)
88 | up2 = nn.functional.interpolate(low3, x.shape[2:], mode='bilinear')
89 | return up1 + up2
90 |
--------------------------------------------------------------------------------
/models/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class HeatmapLoss(torch.nn.Module):
5 | def __init__(self):
6 | super(HeatmapLoss, self).__init__()
7 |
8 | def forward(self, pred, gt):
9 | loss = ((pred - gt)**2)
10 | loss = torch.mean(loss, dim=(1, 2, 3))
11 | return loss
12 |
13 |
14 | class AngularError(torch.nn.Module):
15 | def __init__(self):
16 | super(AngularError, self).__init__()
17 |
18 | def forward(self, gaze_pred, gaze):
19 | loss = ((gaze_pred - gaze)**2)
20 | loss = torch.mean(loss, dim=1)  # gaze tensors are N x 2, so dim 1 is the only non-batch dim
21 | return loss
22 |
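
`AngularError` in models/losses.py computes a squared difference, not an angle. For comparison, here is a minimal sketch of an angular error between two gaze directions, in plain Python. It assumes gaze is given as (pitch, yaw) in radians. The axis convention in `pitchyaw_to_vector` is an assumption and may differ from the one used in util/gaze.py, and both function names are hypothetical.

```python
import math


def pitchyaw_to_vector(pitch, yaw):
    """Unit gaze vector from (pitch, yaw) in radians (axis convention is an assumption)."""
    return (
        math.cos(pitch) * math.sin(yaw),
        math.sin(pitch),
        math.cos(pitch) * math.cos(yaw),
    )


def angular_error_deg(gaze_pred, gaze_true):
    """Angle in degrees between two (pitch, yaw) gaze directions."""
    a = pitchyaw_to_vector(*gaze_pred)
    b = pitchyaw_to_vector(*gaze_true)
    dot = sum(x * y for x, y in zip(a, b))
    # Clamp so rounding cannot push the dot product outside [-1, 1].
    dot = max(-1.0, min(1.0, dot))
    return math.degrees(math.acos(dot))


# Two directions that differ by 10 degrees of yaw.
print(angular_error_deg((0.0, 0.0), (0.0, math.radians(10.0))))
```

Because both vectors have unit length, their dot product is the cosine of the angle between them, so no extra normalisation is needed.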
--------------------------------------------------------------------------------
/notebooks/check_preprocessing.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import sys\n",
10 | "import os\n",
11 | "module_path = os.path.abspath(os.path.join('..'))\n",
12 | "if module_path not in sys.path:\n",
13 | " sys.path.append(module_path)\n",
14 | "\n",
15 | "import cv2\n",
16 | "import numpy as np\n",
17 | "from matplotlib import pyplot as plt\n",
18 | "from datasets.unity_eyes import UnityEyesDataset\n"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | "## Let's take a look at the first image\n",
26 | "These images are generated with UnityEyes. Each image is a .jpg file that comes with a .json metadata file. The JSON file contains the positions of all the eye landmarks within the image."
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 3,
32 | "metadata": {},
33 | "outputs": [],
34 | "source": [
35 | "dataset = UnityEyesDataset()\n",
36 | "sample = dataset[0]\n"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 4,
42 | "metadata": {},
43 | "outputs": [
44 | {
45 | "data": {
46 | "text/plain": [
47 | ""
48 | ]
49 | },
50 | "execution_count": 4,
51 | "metadata": {},
52 | "output_type": "execute_result"
53 | },
54 | {
55 | "data": {
56 | "image/png": "\n",
57 | "text/plain": [
58 | ""
59 | ]
60 | },
61 | "metadata": {
62 | "needs_background": "light"
63 | },
64 | "output_type": "display_data"
65 | }
66 | ],
67 | "source": [
68 | "plt.imshow(sample['full_img'])"
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "metadata": {},
74 | "source": [
75 | "## Ground Truth labels"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": 5,
81 | "metadata": {},
82 | "outputs": [
83 | {
84 | "data": {
85 | "text/plain": [
86 | ""
87 | ]
88 | },
89 | "execution_count": 5,
90 | "metadata": {},
91 | "output_type": "execute_result"
92 | },
93 | {
94 | "data": {
95 | "image/png": "\n",
96 | "text/plain": [
97 | ""
98 | ]
99 | },
100 | "metadata": {
101 | "needs_background": "light"
102 | },
103 | "output_type": "display_data"
104 | }
105 | ],
106 | "source": [
107 | "heatmaps = sample['heatmaps']\n",
108 | "merged_heatmaps = np.mean(heatmaps[16:33], axis=0)\n",
109 | "plt.imshow(merged_heatmaps)"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": 6,
115 | "metadata": {},
116 | "outputs": [
117 | {
118 | "data": {
119 | "text/plain": [
120 | ""
121 | ]
122 | },
123 | "execution_count": 6,
124 | "metadata": {},
125 | "output_type": "execute_result"
126 | },
127 | {
128 | "data": {
129 | "image/png": "\n",
130 | "text/plain": [
131 | ""
132 | ]
133 | },
134 | "metadata": {
135 | "needs_background": "light"
136 | },
137 | "output_type": "display_data"
138 | }
139 | ],
140 | "source": [
141 | "plt.imshow(sample['img'])"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": 7,
147 | "metadata": {},
148 | "outputs": [
149 | {
150 | "data": {
151 | "text/plain": [
152 | "array([[33.516087 , 17.515186 ],\n",
153 | " [29.25981 , 14.860628 ],\n",
154 | " [20.977678 , 19.982723 ],\n",
155 | " [11.359354 , 31.219086 ],\n",
156 | " [ 9.150034 , 45.183544 ],\n",
157 | " [15.333946 , 55.666534 ],\n",
158 | " [22.347504 , 61.888092 ],\n",
159 | " [29.66824 , 64.39625 ],\n",
160 | " [35.430656 , 64.49766 ],\n",
161 | " [37.92424 , 63.95883 ],\n",
162 | " [40.799217 , 60.91427 ],\n",
163 | " [41.691284 , 54.81529 ],\n",
164 | " [41.837326 , 47.023544 ],\n",
165 | " [41.11672 , 36.859245 ],\n",
166 | " [39.741966 , 25.973618 ],\n",
167 | " [37.847527 , 19.782518 ],\n",
168 | " [19.173851 , 24.480421 ],\n",
169 | " [13.450941 , 24.940432 ],\n",
170 | " [ 8.675809 , 27.653122 ],\n",
171 | " [ 5.5754147, 32.205536 ],\n",
172 | " [ 4.621802 , 37.90466 ],\n",
173 | " [ 5.9601398, 43.88277 ],\n",
174 | " [ 9.38669 , 49.229847 ],\n",
175 | " [14.379737 , 53.131775 ],\n",
176 | " [20.179201 , 54.99453 ],\n",
177 | " [25.902151 , 54.534557 ],\n",
178 | " [30.677284 , 51.821865 ],\n",
179 | " [33.777637 , 47.26945 ],\n",
180 | " [34.73129 , 41.570328 ],\n",
181 | " [33.392914 , 35.592182 ],\n",
182 | " [29.966402 , 30.245142 ],\n",
183 | " [24.973314 , 26.343214 ],\n",
184 | " [19.67654 , 39.737488 ],\n",
185 | " [32.095753 , 28.837555 ]], dtype=float32)"
186 | ]
187 | },
188 | "execution_count": 7,
189 | "metadata": {},
190 | "output_type": "execute_result"
191 | }
192 | ],
193 | "source": [
194 | "sample['landmarks']\n"
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": null,
200 | "metadata": {},
201 | "outputs": [],
202 | "source": []
203 | }
204 | ],
205 | "metadata": {
206 | "kernelspec": {
207 | "display_name": "Python 3",
208 | "language": "python",
209 | "name": "python3"
210 | },
211 | "language_info": {
212 | "codemirror_mode": {
213 | "name": "ipython",
214 | "version": 3
215 | },
216 | "file_extension": ".py",
217 | "mimetype": "text/x-python",
218 | "name": "python",
219 | "nbconvert_exporter": "python",
220 | "pygments_lexer": "ipython3",
221 | "version": "3.6.10"
222 | }
223 | },
224 | "nbformat": 4,
225 | "nbformat_minor": 4
226 | }
227 |
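
The notebook above plots per-landmark heatmaps next to the landmark coordinates. The sketch below shows one way the two can relate: a Gaussian bump is drawn around a landmark, and a soft-argmax recovers an approximate sub-pixel position from it. This is not the repository's `softargmax2d` from util/softargmax.py, which is not shown here. The function names, `sigma`, `beta`, and the (x, y) ordering are all assumptions made for illustration.

```python
import math


def gaussian_heatmap(width, height, cx, cy, sigma=2.0):
    """height x width grid with a Gaussian bump centred on (cx, cy)."""
    return [
        [math.exp(-((x - cx) ** 2 + (y - cy) ** 2) / (2.0 * sigma ** 2)) for x in range(width)]
        for y in range(height)
    ]


def soft_argmax(heatmap, beta=25.0):
    """Softmax-weighted mean of pixel coordinates; returns (x, y)."""
    peak = max(max(row) for row in heatmap)
    total = sx = sy = 0.0
    for y, row in enumerate(heatmap):
        for x, value in enumerate(row):
            # Subtract the peak before exponentiating for numerical stability.
            weight = math.exp(beta * (value - peak))
            total += weight
            sx += weight * x
            sy += weight * y
    return sx / total, sy / total


# An 80x48 heatmap (half of the 160x96 input) with one example landmark.
hm = gaussian_heatmap(80, 48, cx=19.7, cy=39.7)
print(soft_argmax(hm))
```

Unlike a hard argmax, the weighted mean is differentiable, so a landmark loss can be backpropagated through it. A larger `beta` moves the result toward the hard argmax, while a smaller one lets background pixels pull the estimate toward the image centre.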
--------------------------------------------------------------------------------
/run_with_webcam.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import torch
4 | from torch.nn import DataParallel
5 |
6 | from models.eyenet import EyeNet
7 | import os
8 | import numpy as np
9 | import cv2
10 | import dlib
11 | import imutils
12 | import util.gaze
13 | from imutils import face_utils
14 |
15 | from util.eye_prediction import EyePrediction
16 | from util.eye_sample import EyeSample
17 |
18 | torch.backends.cudnn.enabled = True
19 |
20 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21 | print(device)
22 |
23 | webcam = cv2.VideoCapture(0)
24 | webcam.set(cv2.CAP_PROP_FRAME_WIDTH, 960)
25 | webcam.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
26 | webcam.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'MJPG'))
27 | webcam.set(cv2.CAP_PROP_FPS, 60)
28 |
29 | dirname = os.path.dirname(__file__)
30 | face_cascade = cv2.CascadeClassifier(os.path.join(dirname, 'lbpcascade_frontalface_improved.xml'))
31 | landmarks_detector = dlib.shape_predictor(os.path.join(dirname, 'shape_predictor_5_face_landmarks.dat'))
32 |
33 | checkpoint = torch.load('checkpoint.pt', map_location=device)
34 | nstack = checkpoint['nstack']
35 | nfeatures = checkpoint['nfeatures']
36 | nlandmarks = checkpoint['nlandmarks']
37 | eyenet = EyeNet(nstack=nstack, nfeatures=nfeatures, nlandmarks=nlandmarks).to(device)
38 | eyenet.load_state_dict(checkpoint['model_state_dict'])
39 |
40 | def main():
41 | current_face = None
42 | landmarks = None
43 | alpha = 0.95
44 | left_eye = None
45 | right_eye = None
46 |
47 | while True:
48 | _, frame_bgr = webcam.read()
49 | orig_frame = frame_bgr.copy()
50 | frame = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
51 | gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)  # convert from the BGR frame; `frame` is already RGB
52 | faces = face_cascade.detectMultiScale(gray)
53 |
54 | if len(faces):
55 | next_face = faces[0]
56 | if current_face is not None:
57 | current_face = alpha * next_face + (1 - alpha) * current_face
58 | else:
59 | current_face = next_face
60 |
61 | if current_face is not None:
62 | #draw_cascade_face(current_face, orig_frame)
63 | next_landmarks = detect_landmarks(current_face, gray)
64 |
65 | if landmarks is not None:
66 | landmarks = next_landmarks * alpha + (1 - alpha) * landmarks
67 | else:
68 | landmarks = next_landmarks
69 |
70 | #draw_landmarks(landmarks, orig_frame)
71 |
72 |
73 | if landmarks is not None:
74 | eye_samples = segment_eyes(gray, landmarks)
75 |
76 | eye_preds = run_eyenet(eye_samples)
77 | left_eyes = list(filter(lambda x: x.eye_sample.is_left, eye_preds))
78 | right_eyes = list(filter(lambda x: not x.eye_sample.is_left, eye_preds))
79 |
80 | if left_eyes:
81 | left_eye = smooth_eye_landmarks(left_eyes[0], left_eye, smoothing=0.1)
82 | if right_eyes:
83 | right_eye = smooth_eye_landmarks(right_eyes[0], right_eye, smoothing=0.1)
84 |
85 | for ep in [left_eye, right_eye]:
86 | for (x, y) in ep.landmarks[16:33]:
87 | color = (0, 255, 0)
88 | if ep.eye_sample.is_left:
89 | color = (255, 0, 0)
90 | cv2.circle(orig_frame,
91 | (int(round(x)), int(round(y))), 1, color, -1, lineType=cv2.LINE_AA)
92 |
93 | gaze = ep.gaze.copy()
94 | if ep.eye_sample.is_left:
95 | gaze[1] = -gaze[1]
96 | util.gaze.draw_gaze(orig_frame, ep.landmarks[-2], gaze, length=60.0, thickness=2)
97 |
98 | cv2.imshow("Webcam", orig_frame)
99 | cv2.waitKey(1)
100 |
101 |
102 | def detect_landmarks(face, frame, scale_x=0, scale_y=0):
103 | (x, y, w, h) = (int(e) for e in face)
104 | rectangle = dlib.rectangle(x, y, x + w, y + h)
105 | face_landmarks = landmarks_detector(frame, rectangle)
106 | return face_utils.shape_to_np(face_landmarks)
107 |
108 |
109 | def draw_cascade_face(face, frame):
110 | (x, y, w, h) = (int(e) for e in face)
111 | cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 0, 255), 2)
112 |
113 |
114 | def draw_landmarks(landmarks, frame):
115 | for (x, y) in landmarks:
116 | cv2.circle(frame, (int(x), int(y)), 2, (0, 255, 0), -1, lineType=cv2.LINE_AA)
117 |
118 |
119 | def segment_eyes(frame, landmarks, ow=160, oh=96):
120 | eyes = []
121 |
122 | # Segment eyes
123 | for corner1, corner2, is_left in [(2, 3, True), (0, 1, False)]:
124 | x1, y1 = landmarks[corner1, :]
125 | x2, y2 = landmarks[corner2, :]
126 | eye_width = 1.5 * np.linalg.norm(landmarks[corner1, :] - landmarks[corner2, :])
127 | if eye_width == 0.0:
128 | return eyes
129 |
130 | cx, cy = 0.5 * (x1 + x2), 0.5 * (y1 + y2)
131 |
132 | # center image on middle of eye
133 | translate_mat = np.asmatrix(np.eye(3))
134 | translate_mat[:2, 2] = [[-cx], [-cy]]
135 | inv_translate_mat = np.asmatrix(np.eye(3))
136 | inv_translate_mat[:2, 2] = -translate_mat[:2, 2]
137 |
138 | # Scale
139 | scale = ow / eye_width
140 | scale_mat = np.asmatrix(np.eye(3))
141 | scale_mat[0, 0] = scale_mat[1, 1] = scale
142 | inv_scale = 1.0 / scale
143 | inv_scale_mat = np.asmatrix(np.eye(3))
144 | inv_scale_mat[0, 0] = inv_scale_mat[1, 1] = inv_scale
145 |
146 | estimated_radius = 0.5 * eye_width * scale
147 |
148 | # center image
149 | center_mat = np.asmatrix(np.eye(3))
150 | center_mat[:2, 2] = [[0.5 * ow], [0.5 * oh]]
151 | inv_center_mat = np.asmatrix(np.eye(3))
152 | inv_center_mat[:2, 2] = -center_mat[:2, 2]
153 |
154 | # Get rotated and scaled, and segmented image
155 | transform_mat = center_mat * scale_mat * translate_mat
156 | inv_transform_mat = (inv_translate_mat * inv_scale_mat * inv_center_mat)
157 |
158 | eye_image = cv2.warpAffine(frame, transform_mat[:2, :], (ow, oh))
159 | eye_image = cv2.equalizeHist(eye_image)
160 |
161 | if is_left:
162 | eye_image = np.fliplr(eye_image)
163 | cv2.imshow('left eye image', eye_image)
164 | else:
165 | cv2.imshow('right eye image', eye_image)
166 | eyes.append(EyeSample(orig_img=frame.copy(),
167 | img=eye_image,
168 | transform_inv=inv_transform_mat,
169 | is_left=is_left,
170 | estimated_radius=estimated_radius))
171 | return eyes
172 |
173 |
174 | def smooth_eye_landmarks(eye: EyePrediction, prev_eye: Optional[EyePrediction], smoothing=0.2, gaze_smoothing=0.4):
175 | if prev_eye is None:
176 | return eye
177 | return EyePrediction(
178 | eye_sample=eye.eye_sample,
179 | landmarks=smoothing * prev_eye.landmarks + (1 - smoothing) * eye.landmarks,
180 | gaze=gaze_smoothing * prev_eye.gaze + (1 - gaze_smoothing) * eye.gaze)
181 |
182 |
183 | def run_eyenet(eyes: List[EyeSample], ow=160, oh=96) -> List[EyePrediction]:
184 | result = []
185 | for eye in eyes:
186 | with torch.no_grad():
187 | x = torch.tensor([eye.img], dtype=torch.float32).to(device)
188 | _, landmarks, gaze = eyenet.forward(x)
189 | landmarks = np.asarray(landmarks.cpu().numpy()[0])
190 | gaze = np.asarray(gaze.cpu().numpy()[0])
191 | assert gaze.shape == (2,)
192 | assert landmarks.shape == (34, 2)
193 |
194 | landmarks = landmarks * np.array([oh/48, ow/80])
195 |
196 | temp = np.zeros((34, 3))
197 | if eye.is_left:
198 | temp[:, 0] = ow - landmarks[:, 1]
199 | else:
200 | temp[:, 0] = landmarks[:, 1]
201 | temp[:, 1] = landmarks[:, 0]
202 | temp[:, 2] = 1.0
203 | landmarks = temp
204 | assert landmarks.shape == (34, 3)
205 | landmarks = np.asarray(np.matmul(landmarks, eye.transform_inv.T))[:, :2]
206 | assert landmarks.shape == (34, 2)
207 | result.append(EyePrediction(eye_sample=eye, landmarks=landmarks, gaze=gaze))
208 | return result
209 |
210 |
211 | if __name__ == '__main__':
212 | main()
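
Note on the crop transform used by segment_eyes above: the eye crop is produced by composing a translation to the eye midpoint, a uniform scale, and a translation to the crop centre, and landmarks predicted in the crop are mapped back to the frame with the inverse. The following is a minimal NumPy-only sketch of that composition, not part of the repository. The function name `eye_crop_transform` and the example corner coordinates are hypothetical, and it uses `np.linalg.inv` where the script builds the inverse from explicit inverse matrices.

```python
import numpy as np


def eye_crop_transform(corner1, corner2, ow=160, oh=96):
    """Return (transform, inverse) as 3x3 homogeneous matrices.

    `transform` maps frame pixels to crop pixels so that the midpoint of the
    two eye corners lands at the crop centre and 1.5x the corner distance
    spans the crop width, mirroring the arithmetic in segment_eyes.
    """
    corner1 = np.asarray(corner1, dtype=np.float64)
    corner2 = np.asarray(corner2, dtype=np.float64)
    eye_width = 1.5 * np.linalg.norm(corner1 - corner2)
    if eye_width == 0.0:
        raise ValueError("eye corners coincide; cannot scale the crop")

    cx, cy = 0.5 * (corner1 + corner2)
    scale = ow / eye_width

    translate = np.eye(3)
    translate[:2, 2] = [-cx, -cy]          # move the eye midpoint to the origin
    scale_mat = np.diag([scale, scale, 1.0])
    center = np.eye(3)
    center[:2, 2] = [0.5 * ow, 0.5 * oh]   # move the origin to the crop centre

    transform = center @ scale_mat @ translate
    return transform, np.linalg.inv(transform)


# Hypothetical eye corners in frame coordinates (x, y).
transform, transform_inv = eye_crop_transform((300, 200), (360, 200))
midpoint = np.array([330.0, 200.0, 1.0])
in_crop = transform @ midpoint        # lands at the crop centre
back = transform_inv @ in_crop        # round trip to frame coordinates
```

The first two rows of `transform` are what would be passed to `cv2.warpAffine`; the inverse plays the role of `transform_inv` stored on each `EyeSample`.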
--------------------------------------------------------------------------------
/scripts/fetch_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
4 |
5 | cd "${DIR}/.."
6 |
7 | # Download face landmark predictor model
8 | wget http://dlib.net/files/shape_predictor_5_face_landmarks.dat.bz2
9 | bzip2 -d shape_predictor_5_face_landmarks.dat.bz2
10 | rm shape_predictor_5_face_landmarks.dat.bz2
11 |
12 |
13 | # Download trained pytorch model
14 | wget "https://drive.google.com/uc?export=download&id=17aJAUAIl-1VPvJcPeahH8MQrcLRpy9Li" -O checkpoint.pt
15 |
--------------------------------------------------------------------------------
/static/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david-wb/gaze-estimation/249691893a37944a03e4ad4a3448083b6f63af10/static/fig1.png
--------------------------------------------------------------------------------
/static/ge_screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david-wb/gaze-estimation/249691893a37944a03e4ad4a3448083b6f63af10/static/ge_screenshot.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datasets.unity_eyes import UnityEyesDataset
3 | from models.eyenet import EyeNet
4 | import os
5 | import numpy as np
6 | import cv2
7 | from util.preprocess import gaussian_2d
8 | from matplotlib import pyplot as plt
9 | from util.gaze import draw_gaze
10 |
11 | device = torch.device('cpu')
12 | dataset = UnityEyesDataset()
13 | checkpoint = torch.load('checkpoint.pt', map_location=device)
14 | nstack = checkpoint['nstack']
15 | nfeatures = checkpoint['nfeatures']
16 | nlandmarks = checkpoint['nlandmarks']
17 | eyenet = EyeNet(nstack=nstack, nfeatures=nfeatures, nlandmarks=nlandmarks).to(device)
18 | eyenet.load_state_dict(checkpoint['model_state_dict'])
19 |
20 | with torch.no_grad():
21 | sample = dataset[2]
22 | x = torch.tensor([sample['img']]).float().to(device)
23 | heatmaps = sample['heatmaps']
24 | heatmaps_pred, landmarks_pred, gaze_pred = eyenet.forward(x)
25 |
26 | landmarks_pred = landmarks_pred.cpu().numpy()[0, :]
27 |
28 | result = [gaussian_2d(w=80, h=48, cx=c[1], cy=c[0], sigma=3) for c in landmarks_pred]
29 |
30 | plt.figure(figsize=(8, 9))
31 |
32 | iris_center = sample['landmarks'][-2][::-1].copy()  # copy: the in-place scaling below must not modify sample['landmarks']
33 | iris_center *= 2
34 | img = sample['img']
35 |
36 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
37 |
38 | img_gaze_pred = img.copy()
39 | for (y, x) in landmarks_pred[-2:-1]:
40 |     cv2.circle(img_gaze_pred, (int(x*2), int(y*2)), 2, (255, 0, 0), -1)
41 | draw_gaze(img_gaze_pred, iris_center, gaze_pred.cpu().numpy()[0, :], length=60, color=(255, 0, 0))
42 |
43 | img_gaze = img.copy()
44 | for (y, x) in sample['landmarks'][-2:-1]:  # landmarks are stored as (y, x)
45 |     cv2.circle(img_gaze, (int(x*2), int(y*2)), 2, (0, 255, 0), -1)
46 | draw_gaze(img_gaze, iris_center, sample['gaze'], length=60, color=(0, 255, 0))
47 |
48 | plt.subplot(321)
49 | plt.imshow(cv2.cvtColor(sample['full_img'], cv2.COLOR_BGR2RGB))
50 | plt.title('Raw training image')
51 |
52 | plt.subplot(322)
53 | plt.imshow(img, cmap='gray')
54 | plt.title('Preprocessed training image')
55 |
56 | plt.subplot(323)
57 | plt.imshow(np.mean(heatmaps[16:32], axis=0), cmap='gray')
58 | plt.title('Ground truth heatmaps')
59 |
60 | plt.subplot(324)
61 | plt.imshow(np.mean(result[16:32], axis=0), cmap='gray')
62 | plt.title('Predicted heatmaps')
63 |
64 | plt.subplot(325)
65 | plt.imshow(img_gaze, cmap='gray')
66 | plt.title('Ground truth landmarks and gaze vector')
67 |
68 | plt.subplot(326)
69 | plt.imshow(img_gaze_pred, cmap='gray')
70 | plt.title('Predicted landmarks and gaze vector')
71 | plt.show()
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/david-wb/gaze-estimation/249691893a37944a03e4ad4a3448083b6f63af10/test/__init__.py
--------------------------------------------------------------------------------
/test/data/imgs/1.json:
--------------------------------------------------------------------------------
1 | {
2 | "interior_margin_2d": [
3 | "(371.4855, 296.4230, 8.8595)",
4 | "(364.8002, 307.1421, 8.8963)",
5 | "(377.6998, 328.0000, 8.8956)",
6 | "(405.9977, 352.2230, 8.9568)",
7 | "(441.1661, 357.7870, 9.1683)",
8 | "(467.5667, 342.2133, 9.4351)",
9 | "(483.2352, 324.5502, 9.6703)",
10 | "(489.5518, 306.1135, 9.8769)",
11 | "(489.8072, 291.6013, 10.0333)",
12 | "(488.4502, 285.3214, 9.8645)",
13 | "(480.7827, 278.0810, 9.5506)",
14 | "(465.4229, 275.8344, 9.2366)",
15 | "(445.8000, 275.4666, 9.0163)",
16 | "(420.2020, 277.2814, 8.8672)",
17 | "(392.7874, 280.7436, 8.8285)",
18 | "(377.1956, 285.5146, 8.8521)"
19 | ],
20 | "caruncle_2d": [
21 | "(364.5717, 284.7197, 8.8818)",
22 | "(364.8584, 288.3494, 8.8758)",
23 | "(357.4482, 296.9649, 8.8906)",
24 | "(349.9624, 289.0576, 8.8963)",
25 | "(353.7258, 283.3381, 8.8824)",
26 | "(351.5754, 280.2595, 8.8752)",
27 | "(346.2976, 282.7485, 8.8394)"
28 | ],
29 | "iris_2d": [
30 | "(389.0269, 332.5428, 8.8831)",
31 | "(388.8647, 339.9151, 8.9267)",
32 | "(390.1854, 346.9555, 8.9762)",
33 | "(392.9381, 353.3934, 9.0297)",
34 | "(397.0171, 358.9813, 9.0850)",
35 | "(402.2656, 363.5046, 9.1402)",
36 | "(408.4820, 366.7894, 9.1929)",
37 | "(415.4274, 368.7094, 9.2413)",
38 | "(422.8348, 369.1910, 9.2835)",
39 | "(430.4196, 368.2155, 9.3178)",
40 | "(437.8902, 365.8205, 9.3429)",
41 | "(444.9597, 362.0980, 9.3579)",
42 | "(451.3564, 357.1910, 9.3622)",
43 | "(456.8343, 351.2882, 9.3555)",
44 | "(461.1831, 344.6164, 9.3383)",
45 | "(464.2355, 337.4319, 9.3111)",
46 | "(465.8743, 330.0109, 9.2750)",
47 | "(466.0365, 322.6385, 9.2313)",
48 | "(464.7159, 315.5981, 9.1818)",
49 | "(461.9632, 309.1602, 9.1284)",
50 | "(457.8842, 303.5723, 9.0730)",
51 | "(452.6357, 299.0490, 9.0179)",
52 | "(446.4193, 295.7643, 8.9651)",
53 | "(439.4739, 293.8442, 8.9167)",
54 | "(432.0665, 293.3626, 8.8746)",
55 | "(424.4817, 294.3381, 8.8403)",
56 | "(417.0110, 296.7332, 8.8152)",
57 | "(409.9415, 300.4557, 8.8002)",
58 | "(403.5449, 305.3626, 8.7959)",
59 | "(398.0670, 311.2654, 8.8025)",
60 | "(393.7182, 317.9373, 8.8197)",
61 | "(390.6658, 325.1217, 8.8470)"
62 | ],
63 | "eye_details": {
64 | "look_vec": "(0.3404, 0.3879, -0.8566, 0.0000)",
65 | "pupil_size": "0.1020631",
66 | "iris_size": "0.9349335",
67 | "iris_texture": "eyeball_brown"
68 | },
69 | "lighting_details": {
70 | "skybox_texture": "bergen_2k",
71 | "skybox_exposure": "1.191213",
72 | "skybox_rotation": "231",
73 | "ambient_intensity": "1.15163",
74 | "light_rotation": "(30.0, 248.5, 0.0)",
75 | "light_intensity": "0.8451564"
76 | },
77 | "eye_region_details": {
78 | "pca_shape_coeffs": [
79 | "-0.01061361",
80 | "0.009601755",
81 | "0.01964062",
82 | "0.0175013",
83 | "-0.01174545",
84 | "0.01343334",
85 | "0.005124273",
86 | "0.002106255",
87 | "0.002141998",
88 | "0.004727394",
89 | "-0.008374349",
90 | "-0.002742675",
91 | "0.006649553",
92 | "-0.002045201",
93 | "-0.003022511",
94 | "-0.005640308",
95 | "-0.004294305",
96 | "0.0002127625",
97 | "0.002193088",
98 | "1.924006E-17"
99 | ],
100 | "primary_skin_texture": "f02_color"
101 | },
102 | "head_pose": "(355.0770, 203.8451, 0.0000)"
103 | }
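
Note on reading this file: UnityEyes stores every coordinate as a string such as "(x, y, z)", with y measured from the bottom of the image. The sketch below shows one way to parse those strings and flip y into image coordinates, as process_coords in util/preprocess.py does. It is an illustration rather than repository code: `parse_coords` is a hypothetical helper, it uses `ast.literal_eval` where the repository calls `eval`, and the image height of 600 is taken from the shape asserted in test/test_unity_eyes.py.

```python
import ast


def parse_coords(coord_strings, image_height):
    """Parse '(x, y, z)' strings and flip y so the origin is the top-left corner."""
    coords = [ast.literal_eval(s) for s in coord_strings]
    return [(x, image_height - y, z) for (x, y, z) in coords]


# First two caruncle_2d entries from the JSON above.
caruncle_2d = [
    "(364.5717, 284.7197, 8.8818)",
    "(364.8584, 288.3494, 8.8758)",
]
points = parse_coords(caruncle_2d, image_height=600)

# look_vec carries a trailing homogeneous component that is dropped.
look_vec = ast.literal_eval("(0.3404, 0.3879, -0.8566, 0.0000)")[:3]
```

`ast.literal_eval` accepts only Python literals, so a malformed or malicious string raises an error instead of executing code.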
--------------------------------------------------------------------------------
/test/test_unity_eyes.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from datasets.unity_eyes import UnityEyesDataset
4 | import os
5 |
6 | def test_unity_eyes():
7 |     ds = UnityEyesDataset(img_dir=os.path.join(os.path.dirname(__file__), 'data/imgs'))
8 |     sample = ds[0]
9 |     assert sample['full_img'].shape == (600, 800, 3)
10 |     assert sample['img'].shape == (96, 160)  # preprocess_unityeyes_image returns a 160x96 grayscale crop
11 |     assert float(sample['json_data']['eye_details']['iris_size']) == 0.9349335
12 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from datasets.unity_eyes import UnityEyesDataset
5 | from torch.utils.data import DataLoader
6 | from models.eyenet import EyeNet
7 | from torch.utils.tensorboard import SummaryWriter
8 | from datetime import datetime
9 | import numpy as np
10 | import cv2
11 | import argparse
12 |
13 | # Set up pytorch
14 | torch.backends.cudnn.enabled = False
15 | torch.manual_seed(0)
16 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
17 | print('Device', device)
18 |
19 | # Set up cmdline args
20 | parser = argparse.ArgumentParser(description='Trains an EyeNet model')
21 | parser.add_argument('--nstack', type=int, default=3, help='Number of hourglass layers.')
22 | parser.add_argument('--nfeatures', type=int, default=32, help='Number of feature maps to use.')
23 | parser.add_argument('--nlandmarks', type=int, default=34, help='Number of landmarks to be predicted.')
24 | parser.add_argument('--nepochs', type=int, default=10, help='Number of epochs to iterate over all training examples.')
25 | parser.add_argument('--start_from', help='A model checkpoint file to begin training from. This overrides all other arguments.')
26 | parser.add_argument('--out', default='checkpoint.pt', help='The output checkpoint filename')
27 | args = parser.parse_args()
28 |
29 |
30 | def validate(eyenet: EyeNet, val_loader: DataLoader) -> float:
31 | with torch.no_grad():
32 | val_losses = []
33 | for val_batch in val_loader:
34 | val_imgs = val_batch['img'].float().to(device)
35 | heatmaps = val_batch['heatmaps'].to(device)
36 | landmarks = val_batch['landmarks'].to(device)
37 | gaze = val_batch['gaze'].float().to(device)
38 | heatmaps_pred, landmarks_pred, gaze_pred = eyenet.forward(val_imgs)
39 | heatmaps_loss, landmarks_loss, gaze_loss = eyenet.calc_loss(
40 | heatmaps_pred, heatmaps, landmarks_pred, landmarks, gaze_pred, gaze)
41 | loss = 1000 * heatmaps_loss + landmarks_loss + gaze_loss
42 | val_losses.append(loss.item())
43 | val_loss = np.mean(val_losses)
44 | return val_loss
45 |
46 |
47 | def train_epoch(epoch: int,
48 | eyenet: EyeNet,
49 | optimizer,
50 | train_loader : DataLoader,
51 | val_loader: DataLoader,
52 | best_val_loss: float,
53 | checkpoint_fn: str,
54 | writer: SummaryWriter):
55 |
56 | N = len(train_loader)
57 | for i_batch, sample_batched in enumerate(train_loader):
58 | i_batch += N * epoch
59 | imgs = sample_batched['img'].float().to(device)
60 | heatmaps_pred, landmarks_pred, gaze_pred = eyenet.forward(imgs)
61 |
62 | heatmaps = sample_batched['heatmaps'].to(device)
63 | landmarks = sample_batched['landmarks'].float().to(device)
64 | gaze = sample_batched['gaze'].float().to(device)
65 |
66 | heatmaps_loss, landmarks_loss, gaze_loss = eyenet.calc_loss(
67 | heatmaps_pred, heatmaps, landmarks_pred, landmarks, gaze_pred, gaze)
68 |
69 | loss = 1000 * heatmaps_loss + landmarks_loss + gaze_loss
70 |
71 | optimizer.zero_grad()
72 | loss.backward()
73 | optimizer.step()
74 |
75 | hm = np.mean(heatmaps[-1, 8:16].cpu().detach().numpy(), axis=0)
76 | hm_pred = np.mean(heatmaps_pred[-1, -1, 8:16].cpu().detach().numpy(), axis=0)
77 | norm_hm = cv2.normalize(hm, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
78 | norm_hm_pred = cv2.normalize(hm_pred, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
79 |
80 | if i_batch % 20 == 0:
81 | cv2.imwrite('true.jpg', norm_hm * 255)
82 | cv2.imwrite('pred.jpg', norm_hm_pred * 255)
83 | cv2.imwrite('eye.jpg', sample_batched['img'].numpy()[-1] * 255)
84 |
85 | writer.add_scalar("Training heatmaps loss", heatmaps_loss.item(), i_batch)
86 | writer.add_scalar("Training landmarks loss", landmarks_loss.item(), i_batch)
87 | writer.add_scalar("Training gaze loss", gaze_loss.item(), i_batch)
88 | writer.add_scalar("Training loss", loss.item(), i_batch)
89 |
90 | if i_batch > 0 and i_batch % 20 == 0:
91 | val_loss = validate(eyenet=eyenet, val_loader=val_loader)
92 | writer.add_scalar("validation loss", val_loss, i_batch)
93 | print('Epoch', epoch, 'Validation loss', val_loss)
94 | if val_loss < best_val_loss:
95 | best_val_loss = val_loss
96 | torch.save({
97 | 'nstack': eyenet.nstack,
98 | 'nfeatures': eyenet.nfeatures,
99 | 'nlandmarks': eyenet.nlandmarks,
100 | 'best_val_loss': best_val_loss,
101 | 'model_state_dict': eyenet.state_dict(),
102 | 'optimizer_state_dict': optimizer.state_dict(),
103 | }, checkpoint_fn)
104 |
105 | return best_val_loss
106 |
107 |
108 | def train(eyenet: EyeNet, optimizer, nepochs: int, best_val_loss: float, checkpoint_fn: str):
109 | timestr = datetime.now().strftime("%m%d%Y-%H%M%S")
110 | writer = SummaryWriter(f'runs/eyenet-{timestr}')
111 | dataset = UnityEyesDataset()
112 | N = len(dataset)
113 | VN = 160
114 | TN = N - VN
115 | train_set, val_set = torch.utils.data.random_split(dataset, (TN, VN))
116 |
117 | train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
118 | val_loader = DataLoader(val_set, batch_size=16, shuffle=True)
119 |
120 | for i in range(nepochs):
121 | best_val_loss = train_epoch(epoch=i,
122 | eyenet=eyenet,
123 | optimizer=optimizer,
124 | train_loader=train_loader,
125 | val_loader=val_loader,
126 | best_val_loss=best_val_loss,
127 | checkpoint_fn=checkpoint_fn,
128 | writer=writer)
129 |
130 |
131 | def main():
132 | learning_rate = 4 * 1e-4
133 |
134 | if args.start_from:
135 | start_from = torch.load(args.start_from, map_location=device)
136 | nstack = start_from['nstack']
137 | nfeatures = start_from['nfeatures']
138 | nlandmarks = start_from['nlandmarks']
139 | best_val_loss = start_from['best_val_loss']
140 | eyenet = EyeNet(nstack=nstack, nfeatures=nfeatures, nlandmarks=nlandmarks).to(device)
141 | optimizer = torch.optim.Adam(eyenet.parameters(), lr=learning_rate)
142 | eyenet.load_state_dict(start_from['model_state_dict'])
143 | optimizer.load_state_dict(start_from['optimizer_state_dict'])
144 | elif os.path.exists(args.out):
145 | raise Exception(f'Out file {args.out} already exists.')
146 | else:
147 | nstack = args.nstack
148 | nfeatures = args.nfeatures
149 | nlandmarks = args.nlandmarks
150 | best_val_loss = float('inf')
151 | eyenet = EyeNet(nstack=nstack, nfeatures=nfeatures, nlandmarks=nlandmarks).to(device)
152 | optimizer = torch.optim.Adam(eyenet.parameters(), lr=learning_rate)
153 |
154 | train(
155 | eyenet=eyenet,
156 | optimizer=optimizer,
157 | nepochs=args.nepochs,
158 | best_val_loss=best_val_loss,
159 | checkpoint_fn=args.out
160 | )
161 |
162 |
163 | if __name__ == '__main__':
164 | main()
--------------------------------------------------------------------------------
/util/eye_prediction.py:
--------------------------------------------------------------------------------
1 | from util.eye_sample import EyeSample
2 |
3 |
4 | class EyePrediction():
5 | def __init__(self, eye_sample: EyeSample, landmarks, gaze):
6 | self._eye_sample = eye_sample
7 | self._landmarks = landmarks
8 | self._gaze = gaze
9 |
10 | @property
11 | def eye_sample(self):
12 | return self._eye_sample
13 |
14 | @property
15 | def landmarks(self):
16 | return self._landmarks
17 |
18 | @property
19 | def gaze(self):
20 | return self._gaze
21 |
--------------------------------------------------------------------------------
/util/eye_sample.py:
--------------------------------------------------------------------------------
1 |
2 | class EyeSample:
3 | def __init__(self, orig_img, img, is_left, transform_inv, estimated_radius):
4 | self._orig_img = orig_img.copy()
5 | self._img = img.copy()
6 | self._is_left = is_left
7 | self._transform_inv = transform_inv
8 | self._estimated_radius = estimated_radius
9 | @property
10 | def orig_img(self):
11 | return self._orig_img
12 |
13 | @property
14 | def img(self):
15 | return self._img
16 |
17 | @property
18 | def is_left(self):
19 | return self._is_left
20 |
21 | @property
22 | def transform_inv(self):
23 | return self._transform_inv
24 |
25 | @property
26 | def estimated_radius(self):
27 | return self._estimated_radius
--------------------------------------------------------------------------------
/util/gaze.py:
--------------------------------------------------------------------------------
1 | """Utility methods for gaze angle and error calculations."""
2 | import cv2 as cv
3 | import numpy as np
4 |
5 | def pitchyaw_to_vector(pitchyaws):
6 |     r"""Convert given pitch (:math:`\theta`) and yaw (:math:`\phi`) angles to unit gaze vectors.
7 |
8 |     Args:
9 |         pitchyaws (:obj:`numpy.array`): pitch and yaw angles :math:`(n\times 2)` in radians.
10 |
11 |     Returns:
12 |         :obj:`numpy.array` of shape :math:`(n\times 3)` with 3D vectors per row.
13 |     """
14 |     n = pitchyaws.shape[0]
15 |     sin = np.sin(pitchyaws)
16 |     cos = np.cos(pitchyaws)
17 |     out = np.empty((n, 3))
18 |     out[:, 0] = np.multiply(cos[:, 0], sin[:, 1])
19 |     out[:, 1] = sin[:, 0]
20 |     out[:, 2] = np.multiply(cos[:, 0], cos[:, 1])
21 |     return out
22 |
23 |
24 | def vector_to_pitchyaw(vectors):
25 |     r"""Convert given gaze vectors to pitch (:math:`\theta`) and yaw (:math:`\phi`) angles.
26 |
27 |     Args:
28 |         vectors (:obj:`numpy.array`): gaze vectors in 3D :math:`(n\times 3)`.
29 |
30 |     Returns:
31 |         :obj:`numpy.array` of shape :math:`(n\times 2)` with values in radians.
32 |     """
33 |     n = vectors.shape[0]
34 |     out = np.empty((n, 2))
35 |     vectors = np.divide(vectors, np.linalg.norm(vectors, axis=1).reshape(n, 1))
36 |     out[:, 0] = np.arcsin(vectors[:, 1])  # pitch (theta)
37 |     out[:, 1] = np.arctan2(vectors[:, 0], vectors[:, 2])  # yaw (phi)
38 |     return out
39 |
40 | radians_to_degrees = 180.0 / np.pi
41 |
42 |
43 | def angular_error(a, b):
44 |     """Calculate angular error in degrees (via cosine similarity)."""
45 |     a = pitchyaw_to_vector(a) if a.shape[1] == 2 else a
46 |     b = pitchyaw_to_vector(b) if b.shape[1] == 2 else b
47 |
48 |     ab = np.sum(np.multiply(a, b), axis=1)
49 |     a_norm = np.linalg.norm(a, axis=1)
50 |     b_norm = np.linalg.norm(b, axis=1)
51 |
52 |     # Avoid zero-values (to avoid NaNs)
53 |     a_norm = np.clip(a_norm, a_min=1e-7, a_max=None)
54 |     b_norm = np.clip(b_norm, a_min=1e-7, a_max=None)
55 |
56 |     similarity = np.clip(np.divide(ab, np.multiply(a_norm, b_norm)), -1.0, 1.0)  # rounding can push |cos| past 1
57 |
58 |     return np.arccos(similarity) * radians_to_degrees
59 |
60 |
61 | def mean_angular_error(a, b):
62 | """Calculate mean angular error (via cosine similarity)."""
63 | return np.mean(angular_error(a, b))
64 |
65 |
66 | def draw_gaze(image_in, eye_pos, pitchyaw, length=40.0, thickness=2, color=(0, 0, 255)):
67 |     """Draw the gaze direction on the given image as an arrow starting at the given eye position."""
68 |     image_out = image_in
69 |     if len(image_out.shape) == 2 or image_out.shape[2] == 1:
70 |         image_out = cv.cvtColor(image_out, cv.COLOR_GRAY2BGR)
71 |     dx = -length * np.sin(pitchyaw[1])
72 |     dy = length * np.sin(pitchyaw[0])
73 |     cv.arrowedLine(image_out, tuple(np.round(eye_pos).astype(np.int32)),
74 |                    tuple(np.round([eye_pos[0] + dx, eye_pos[1] + dy]).astype(int)), color,
75 |                    thickness, cv.LINE_AA, tipLength=0.2)
76 |     return image_out
77 |
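
Note on the angle convention above: column 0 holds pitch and column 1 holds yaw, and the two conversions are inverses of each other for pitch in (-90°, 90°). The snippet below is a self-contained sketch that re-implements the same formulas in compact form to show the round trip and a simple angular-error check. It is illustrative only; `angular_error_deg` is a hypothetical name, and the example angles are made up.

```python
import numpy as np


def pitchyaw_to_vector(pitchyaws):
    """Column 0 is pitch, column 1 is yaw (radians); returns unit vectors."""
    pitch, yaw = pitchyaws[:, 0], pitchyaws[:, 1]
    return np.stack([np.cos(pitch) * np.sin(yaw),
                     np.sin(pitch),
                     np.cos(pitch) * np.cos(yaw)], axis=1)


def vector_to_pitchyaw(vectors):
    """Inverse of pitchyaw_to_vector for arbitrary non-zero 3D vectors."""
    v = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
    return np.stack([np.arcsin(v[:, 1]), np.arctan2(v[:, 0], v[:, 2])], axis=1)


def angular_error_deg(a, b):
    """Angle in degrees between corresponding rows of two (n, 3) arrays."""
    cos_sim = np.sum(a * b, axis=1) / (
        np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1))
    # Clip so that rounding error cannot push the cosine outside [-1, 1].
    return np.degrees(np.arccos(np.clip(cos_sim, -1.0, 1.0)))


angles = np.radians([[10.0, -20.0], [0.0, 0.0]])
vectors = pitchyaw_to_vector(angles)
recovered = vector_to_pitchyaw(vectors)

# Two gaze directions that differ by 5 degrees of yaw at zero pitch.
straight = pitchyaw_to_vector(np.radians([[0.0, 0.0]]))
offset = pitchyaw_to_vector(np.radians([[0.0, 5.0]]))
err = angular_error_deg(straight, offset)
```

At zero pitch a pure yaw offset equals the angular error, which makes this a convenient sanity check for the convention.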
--------------------------------------------------------------------------------
/util/preprocess.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import util.gaze
4 | from scipy.spatial.transform import Rotation as R
5 |
6 | def preprocess_unityeyes_image(img, json_data):
7 | ow = 160
8 | oh = 96
9 | # Prepare to segment eye image
10 | ih, iw = img.shape[:2]
11 | ih_2, iw_2 = ih/2.0, iw/2.0
12 |
13 | heatmap_w = int(ow/2)
14 | heatmap_h = int(oh/2)
15 |
16 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
17 |
18 | def process_coords(coords_list):
19 | coords = [eval(l) for l in coords_list]
20 | return np.array([(x, ih-y, z) for (x, y, z) in coords])
21 |
22 | interior_landmarks = process_coords(json_data['interior_margin_2d'])
23 | caruncle_landmarks = process_coords(json_data['caruncle_2d'])
24 | iris_landmarks = process_coords(json_data['iris_2d'])
25 |
26 | left_corner = np.mean(caruncle_landmarks[:, :2], axis=0)
27 | right_corner = interior_landmarks[8, :2]
28 | eye_width = 1.5 * abs(left_corner[0] - right_corner[0])
29 | eye_middle = np.mean([np.amin(interior_landmarks[:, :2], axis=0),
30 | np.amax(interior_landmarks[:, :2], axis=0)], axis=0)
31 |
32 | # Normalize to eye width.
33 | scale = ow/eye_width
34 |
35 | translate = np.asmatrix(np.eye(3))
36 | translate[0, 2] = -eye_middle[0] * scale
37 | translate[1, 2] = -eye_middle[1] * scale
38 |
39 | rand_x = np.random.uniform(low=-10, high=10)
40 | rand_y = np.random.uniform(low=-10, high=10)
41 | recenter = np.asmatrix(np.eye(3))
42 | recenter[0, 2] = ow/2 + rand_x
43 | recenter[1, 2] = oh/2 + rand_y
44 |
45 | scale_mat = np.asmatrix(np.eye(3))
46 | scale_mat[0, 0] = scale
47 | scale_mat[1, 1] = scale
48 |
49 | angle = 0 #np.random.normal(0, 1) * 20 * np.pi/180
50 | rotation = R.from_rotvec([0, 0, angle]).as_matrix()
51 |
52 | transform = recenter * rotation * translate * scale_mat
53 | transform_inv = np.linalg.inv(transform)
54 |
55 | # Apply transforms
56 | eye = cv2.warpAffine(img, transform[:2], (ow, oh))
57 |
58 | rand_blur = np.random.uniform(low=0, high=20)
59 | eye = cv2.GaussianBlur(eye, (5, 5), rand_blur)
60 |
61 | # Normalize eye image
62 | eye = cv2.equalizeHist(eye)
63 | eye = eye.astype(np.float32)
64 | eye = eye / 255.0
65 |
66 | # Gaze
67 | # Convert look vector to gaze direction in polar angles
68 | look_vec = np.array(eval(json_data['eye_details']['look_vec']))[:3].reshape((1, 3))
69 | #look_vec = np.matmul(look_vec, rotation.T)
70 |
71 | gaze = util.gaze.vector_to_pitchyaw(-look_vec).flatten()
72 | gaze = gaze.astype(np.float32)
73 |
74 | iris_center = np.mean(iris_landmarks[:, :2], axis=0)
75 |
76 |     landmarks = np.concatenate([interior_landmarks[:, :2],  # 16
77 |                                 iris_landmarks[::2, :2],  # 16
78 |                                 iris_center.reshape((1, 2)),
79 |                                 [[iw_2, ih_2]],  # Eyeball center
80 |                                 ])  # 34 in total
81 |
82 | landmarks = np.asmatrix(np.pad(landmarks, ((0, 0), (0, 1)), 'constant', constant_values=1))
83 | landmarks = np.asarray(landmarks * transform[:2].T) * np.array([heatmap_w/ow, heatmap_h/oh])
84 | landmarks = landmarks.astype(np.float32)
85 |
86 | # Swap columns so that landmarks are in (y, x), not (x, y)
87 | # This is because the network outputs landmarks as (y, x) values.
88 | temp = np.zeros((34, 2), dtype=np.float32)
89 | temp[:, 0] = landmarks[:, 1]
90 | temp[:, 1] = landmarks[:, 0]
91 | landmarks = temp
92 |
93 | heatmaps = get_heatmaps(w=heatmap_w, h=heatmap_h, landmarks=landmarks)
94 |
95 | assert heatmaps.shape == (34, heatmap_h, heatmap_w)
96 |
97 | return {
98 | 'img': eye,
99 | 'transform': np.asarray(transform),
100 | 'transform_inv': np.asarray(transform_inv),
101 | 'eye_middle': np.asarray(eye_middle),
102 | 'heatmaps': np.asarray(heatmaps),
103 | 'landmarks': np.asarray(landmarks),
104 | 'gaze': np.asarray(gaze)
105 | }
106 |
107 |
108 | def gaussian_2d(w, h, cx, cy, sigma=1.0):
109 | """Generate heatmap with single 2D gaussian."""
110 | xs, ys = np.meshgrid(
111 | np.linspace(0, w - 1, w, dtype=np.float32),
112 | np.linspace(0, h - 1, h, dtype=np.float32)
113 | )
114 |
115 | assert xs.shape == (h, w)
116 | alpha = -0.5 / (sigma ** 2)
117 | heatmap = np.exp(alpha * ((xs - cx) ** 2 + (ys - cy) ** 2))
118 | return heatmap
119 |
120 |
121 | def get_heatmaps(w, h, landmarks):
122 | heatmaps = []
123 | for (y, x) in landmarks:
124 | heatmaps.append(gaussian_2d(w, h, cx=x, cy=y, sigma=2.0))
125 | return np.array(heatmaps)
126 |
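
Note on the heatmap targets above: landmarks are scaled from the 160x96 eye crop to the 80x48 heatmap grid, and each one becomes a Gaussian whose peak sits at the landmark. The sketch below walks one landmark through that scaling and confirms where the peak ends up. It is a simplified illustration, not repository code; the landmark coordinates are made up, and `gaussian_2d` is re-implemented here so the snippet runs on its own.

```python
import numpy as np


def gaussian_2d(w, h, cx, cy, sigma=1.0):
    """Heatmap of shape (h, w) with a single Gaussian centred at (cx, cy)."""
    xs, ys = np.meshgrid(np.arange(w, dtype=np.float32),
                         np.arange(h, dtype=np.float32))
    return np.exp(-0.5 * ((xs - cx) ** 2 + (ys - cy) ** 2) / sigma ** 2)


ow, oh = 160, 96                         # eye crop size
heatmap_w, heatmap_h = ow // 2, oh // 2  # heatmaps are half resolution

# Hypothetical landmark in crop pixels, as (x, y).
landmark_xy = np.array([100.0, 40.0])
x_hm, y_hm = landmark_xy * np.array([heatmap_w / ow, heatmap_h / oh])

heatmap = gaussian_2d(heatmap_w, heatmap_h, cx=x_hm, cy=y_hm, sigma=2.0)

# The brightest cell is indexed (row, col), i.e. (y, x): the same order the
# network uses for its landmark outputs.
peak_row, peak_col = np.unravel_index(np.argmax(heatmap), heatmap.shape)
```

The (row, col) ordering of the peak is why the landmark columns are swapped to (y, x) before the heatmaps are generated.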
--------------------------------------------------------------------------------
/util/softargmax.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | def softargmax2d(input, beta=100, dtype=torch.float32):
7 | *_, h, w = input.shape
8 |
9 | input = input.reshape(*_, h * w)
10 | input = nn.functional.softmax(beta * input, dim=-1)
11 |
12 | indices_c, indices_r = np.meshgrid(
13 | np.linspace(0, 1, w),
14 | np.linspace(0, 1, h),
15 | indexing='xy'
16 | )
17 |
18 | indices_r = torch.tensor(np.reshape(indices_r, (-1, h * w)))
19 | indices_c = torch.tensor(np.reshape(indices_c, (-1, h * w)))
20 |
21 | device = input.get_device()
22 | if device >= 0:
23 | indices_r = indices_r.to(device)
24 | indices_c = indices_c.to(device)
25 |
26 | result_r = torch.sum((h - 1) * input * indices_r, dim=-1)
27 | result_c = torch.sum((w - 1) * input * indices_c, dim=-1)
28 |
29 | result = torch.stack([result_r, result_c], dim=-1)
30 |
31 | return result.type(dtype)
32 |
33 |
34 | def softargmax1d(input, beta=100):
35 |     *_, n = input.shape
36 |     input = nn.functional.softmax(beta * input, dim=-1)
37 |     indices = torch.linspace(0, 1, n, device=input.device)  # keep indices on the same device as input
38 |     result = torch.sum((n - 1) * input * indices, dim=-1)
39 |     return result
40 |
41 |
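
Note on the soft-argmax above: the heatmap is turned into a probability distribution with a temperature-scaled softmax, and the result is the expected (row, col) index under that distribution, which stays differentiable where a hard argmax would not. The following NumPy sketch shows the same idea on a single 2D heatmap. It is an illustration under stated assumptions, not the repository's PyTorch implementation: `softargmax2d_np` is a hypothetical name, and it uses integer index grids directly, which is equivalent to the `linspace(0, 1, n)` grid multiplied by `n - 1`.

```python
import numpy as np


def softargmax2d_np(heatmap, beta=100.0):
    """Expected (row, col) of a 2D heatmap under softmax(beta * heatmap)."""
    h, w = heatmap.shape
    logits = beta * heatmap.reshape(-1)
    weights = np.exp(logits - logits.max())  # subtract the max for stability
    weights /= weights.sum()

    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
    row = np.sum(weights * rows.reshape(-1))
    col = np.sum(weights * cols.reshape(-1))
    return np.array([row, col])


# Synthetic 48x80 heatmap with a Gaussian peak at row 10, column 25.
rows, cols = np.meshgrid(np.arange(48), np.arange(80), indexing='ij')
heatmap = np.exp(-0.5 * ((rows - 10) ** 2 + (cols - 25) ** 2) / 2.0 ** 2)

estimate = softargmax2d_np(heatmap)  # close to (10, 25)
```

A larger `beta` makes the softmax sharper, so the estimate approaches the hard argmax; a smaller `beta` pulls it toward the centre of mass of the whole heatmap.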
--------------------------------------------------------------------------------