├── LICENSE.md
├── OpenGaze__ICCV_2019_Sup_.pdf
├── README.md
├── code
│   ├── README.md
│   ├── data_loader.py
│   ├── model.py
│   ├── resnet.py
│   ├── run.py
│   ├── test.txt
│   ├── train.txt
│   └── validation.txt
└── dataset
    ├── README.md
    └── assets
        └── coords.png

/LICENSE.md:
--------------------------------------------------------------------------------
### Copyright (c) 2019 - Petr Kellnhofer, Adrià Recasens, Simon Stent, Wojciech Matusik, and Antonio Torralba.

## LICENSE AGREEMENT FOR USE OF GAZE360 DATABASE AND MODELS

By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this License Agreement for Use of Gaze360 Database and Models ("Research License"). To the extent this Research License may be interpreted as a contract, You are granted the rights mentioned below in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.

### Section 1 – Definitions

a. __Licensor__ means the individual(s) or entity(ies) granting rights under this Research License.

b. __You__ means the individual or entity exercising the Licensed Rights under this Research License. Your has a corresponding meaning.

c. __Licensed Material__ refers to the Gaze360 database, models, and any related source. These contain eye-tracking data captured using our acquisition setup and machine learning models to predict where individuals are looking.

### Section 2 – Scope

1. Licensor desires to grant a license to You for the use of the Licensed Material. This license will in no case be considered a transfer of the Licensed Material.

2. You shall have no rights with respect to the Licensed Material or any portion thereof and shall not use the Licensed Material except as expressly set forth in this Agreement.

3. Subject to the terms and conditions of this Agreement, Licensor hereby grants to You for research use only, a royalty-free, non-exclusive, non-transferable, license subject to the following conditions:

* The Licensed Material is only for Your research use and, on a need-to-know basis, of those direct research colleagues who belong to the same research institution as You and have adhered to the terms of this license.

* The Licensed Material will not be copied nor distributed in any form other than for Your backup.
* The Licensed Material will only be used for research purposes and will not be used nor included in commercial applications in any form (such as original files, encrypted files, files containing extracted features, models trained on dataset, other derivative works, etc).
* Any work made public, whatever the form, based directly or indirectly on any part of the Licensed Material must include the following reference:

> Petr Kellnhofer, Adrià Recasens, Simon Stent, Wojciech Matusik, and Antonio Torralba. “Gaze360: Physically Unconstrained Gaze Estimation in the Wild”. IEEE International Conference on Computer Vision (ICCV), 2019.

4. Licensor complies with the State of Massachusetts legislation in force. It is Your responsibility, and only Yours, to comply with all the data protection laws that may affect You.

### Section 3 – Disclaimer of Warranties and Limitation of Liability
a. Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.

b. To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Research License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.

c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.

### Section 4 – Term and Termination
a. If You fail to comply with this Research License, then Your rights under this Research License terminate automatically.

b. Where Your right to use the Licensed Material has terminated under Section 4(a), it reinstates:

1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or

2. upon express reinstatement by the Licensor.

For the avoidance of doubt, this Section 4(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Research License.

c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Research License.

d. Sections 1, 3, 4, 5 and 6 survive termination of this Research License.

### Section 5 – Other Terms and Conditions

a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.

b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Research License.

### Section 6 – Interpretation

a. For the avoidance of doubt, this Research License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Research License.

b. To the extent possible, if any provision of this Research License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Research License without affecting the enforceability of the remaining terms and conditions.

c. No term or condition of this Research License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.

d. Nothing in this Research License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.

--------------------------------------------------------------------------------
/OpenGaze__ICCV_2019_Sup_.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/erkil1452/gaze360/cbeb8e4a241ecbda8ce62731243758fad4dda60a/OpenGaze__ICCV_2019_Sup_.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Gaze360: Physically Unconstrained Gaze Estimation in the Wild Dataset

## Updated online demo: https://colab.research.google.com/drive/1SJbzd-gFTbiYjfZynIfrG044fWi6svbV?usp=sharing

## About

This repository contains the code and data for Gaze360. The dataset and the code are for non-commercial research use only. By using this code, you agree to the terms of the [LICENSE](https://github.com/Erkil1452/gaze360/blob/master/LICENSE.md). If you use our dataset or code, please cite our [paper](http://gaze360.csail.mit.edu/iccv2019_gaze360.pdf) as:

> Petr Kellnhofer*, Adrià Recasens*, Simon Stent, Wojciech Matusik, and Antonio Torralba. “Gaze360: Physically Unconstrained Gaze Estimation in the Wild”. IEEE International Conference on Computer Vision (ICCV), 2019.

```
@inproceedings{gaze360_2019,
  author = {Petr Kellnhofer and Adria Recasens and Simon Stent and Wojciech Matusik and Antonio Torralba},
  title = {Gaze360: Physically Unconstrained Gaze Estimation in the Wild},
  booktitle = {IEEE International Conference on Computer Vision (ICCV)},
  month = {October},
  year = {2019}
}
```

## Gaze360 dataset
You can obtain the Gaze360 dataset at [http://gaze360.csail.mit.edu](http://gaze360.csail.mit.edu). A detailed description of its usage is given in the [dataset](https://github.com/erkil1452/gaze360/tree/master/dataset) section of this repository.

## Code
You can find the code for the Gaze360 model in the [code](https://github.com/erkil1452/gaze360/tree/master/code) section of this repository. Code and instructions for running Gaze360 on YouTube videos are available in:
- [Google Colab notebook v2](https://colab.research.google.com/drive/1SJbzd-gFTbiYjfZynIfrG044fWi6svbV?usp=sharing) (new version using detectron2).
- [Google Colab notebook (beta)](https://colab.research.google.com/drive/1AUvmhpHklM9BNt0Mn5DjSo3JRuqKkU4y) (original version using the original DensePose release).

--------------------------------------------------------------------------------
/code/README.md:
--------------------------------------------------------------------------------
# Gaze360: Physically Unconstrained Gaze Estimation in the Wild Dataset

## About

This is the code for training and running our Gaze360 model. The code is for non-commercial research use only. By using this code, you agree to the terms of the [LICENSE](https://github.com/Erkil1452/gaze360/blob/master/LICENSE.md). If you use our dataset or code, please cite our [paper](http://gaze360.csail.mit.edu/iccv2019_gaze360.pdf) as:

> Petr Kellnhofer*, Adrià Recasens*, Simon Stent, Wojciech Matusik, and Antonio Torralba. “Gaze360: Physically Unconstrained Gaze Estimation in the Wild”. IEEE International Conference on Computer Vision (ICCV), 2019.

```
@inproceedings{gaze360_2019,
  author = {Petr Kellnhofer and Adria Recasens and Simon Stent and Wojciech Matusik and Antonio Torralba},
  title = {Gaze360: Physically Unconstrained Gaze Estimation in the Wild},
  booktitle = {IEEE International Conference on Computer Vision (ICCV)},
  month = {October},
  year = {2019}
}
```

## Data
You can obtain the Gaze360 dataset and more information at [http://gaze360.csail.mit.edu](http://gaze360.csail.mit.edu).

This repository provides preprocessed txt files with the training/validation/test splits for the Gaze360 model. Each line of a txt file describes one sample:
* Column 1: Image path
* Columns 2-4: Gaze vector (x, y, z)

Note that these splits contain only the samples for which a full one-second window of surrounding frames is available in the dataset; the model consumes 7 consecutive frames centered on the annotated one.

## Requirements
The implementation has been tested with PyTorch 1.1.0, but it is likely to work with earlier versions of PyTorch as well.

## Structure

The code consists of:
- This readme.
- The training/val/test splits used to train the Gaze360 model, as described in the Data section.
- The model and loss definition (model.py).
- A script for training and evaluating the Gaze360 model (run.py).
- A data loader specific to the Gaze360 dataset (data_loader.py).

## Trained models

The model weights can be downloaded from this [link](http://gaze360.csail.mit.edu/files/gaze360_model.pth.tar).

## Gaze360 in videos
A beta version of the notebook describing how to run Gaze360 on YouTube videos is now [online](https://colab.research.google.com/drive/1AUvmhpHklM9BNt0Mn5DjSo3JRuqKkU4y)!
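
To make the split format from the Data section and the angle convention concrete, here is a minimal, stdlib-only sketch. It is not part of the repository: the function names and the sample split line are hypothetical. It mirrors `make_dataset()` and `ImagerLoader.__getitem__` in `data_loader.py` and `spherical2cartesial()`/`compute_angular_error()` in `run.py`, under the assumption that each split path ends in a six-digit frame number such as `000131.jpg`.

```python
import math


def parse_split_line(line, source_path, half_window=3):
    """Return (frame_paths, gaze_xyz) for one split-file line, or None.

    Hypothetical helper mirroring make_dataset(): field 1 is a frame path whose
    file name is a zero-padded frame number plus a 4-character extension, and
    fields 2-4 are the gaze vector.
    """
    fields = line.split()
    if len(fields) < 4:
        return None  # make_dataset() also skips lines with fewer than 4 fields
    parts = fields[0].split('/')
    frame = int(parts[-1][:-4])  # e.g. '000131.jpg' -> 131
    frame_paths = [
        '{0}/{1}'.format(source_path, '/'.join(parts[:-1] + ['%06d.jpg' % (frame + j)]))
        for j in range(-half_window, half_window + 1)
    ]
    gaze = tuple(float(v) for v in fields[1:4])
    return frame_paths, gaze


def gaze_to_yaw_pitch(gaze):
    """Unit-normalize (x, y, z), then yaw = atan2(x, -z) and pitch = asin(y)."""
    norm = math.sqrt(sum(c * c for c in gaze))
    x, y, z = (c / norm for c in gaze)
    return math.atan2(x, -z), math.asin(y)


def yaw_pitch_to_gaze(yaw, pitch):
    """Inverse mapping, following spherical2cartesial() in run.py."""
    return (math.cos(pitch) * math.sin(yaw),
            math.sin(pitch),
            -math.cos(pitch) * math.cos(yaw))


def angular_error_deg(a, b):
    """Angle in degrees between two (yaw, pitch) directions."""
    dot = sum(p * q for p, q in zip(yaw_pitch_to_gaze(*a), yaw_pitch_to_gaze(*b)))
    # clamp to acos's domain, since rounding can push the dot product past +-1
    return math.degrees(math.acos(max(-1.0, min(1.0, dot))))


# Hypothetical split line; real paths and values come from train.txt and friends.
frames, gaze = parse_split_line("rec_000/head/000000/000131.jpg -0.2 0.1 -0.97", "../imgs")
yaw, pitch = gaze_to_yaw_pitch(gaze)
print(frames[0], frames[-1])
print(math.degrees(yaw), math.degrees(pitch))
```

Keeping this conversion in plain Python makes it easy to sanity-check a split file without a GPU; the repository itself performs the same math on PyTorch tensors.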

--------------------------------------------------------------------------------
/code/data_loader.py:
--------------------------------------------------------------------------------
import math

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
from PIL import Image

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(source_path, file_name):
    """Parse a split file into a list of (frame paths, gaze vector) samples."""
    images = []
    print(file_name)
    with open(file_name, 'r') as f:
        for line in f:
            # split on any whitespace (tabs, repeated spaces); also drops the newline
            split_lines = line.split()
            if len(split_lines) > 3:
                frame_number = int(split_lines[0].split('/')[-1][:-4])
                lists_sources = []
                # 7-frame temporal window centered on the annotated frame
                for j in range(-3, 4):
                    name_frame = '/'.join(split_lines[0].split('/')[:-1] + ['%0.6d.jpg' % (frame_number + j)])
                    name = '{0}/{1}'.format(source_path, name_frame)
                    lists_sources.append(name)

                gaze = np.zeros((3))

                gaze[0] = float(split_lines[1])
                gaze[1] = float(split_lines[2])
                gaze[2] = float(split_lines[3])
                item = (lists_sources, gaze)
                images.append(item)
    return images


def default_loader(path):
    try:
        im = Image.open(path).convert('RGB')
        return im
    except OSError:
        # report the unreadable file and substitute a blank frame
        print(path)
        return Image.new("RGB", (512, 512), "white")


class ImagerLoader(data.Dataset):
    def __init__(self, source_path, file_name,
                 transform=None, target_transform=None, loader=default_loader):

        imgs = make_dataset(source_path, file_name)

        self.source_path = source_path
        self.file_name = file_name

        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        path_source, gaze = self.imgs[index]

        gaze_float = torch.tensor(gaze, dtype=torch.float32)
        normalized_gaze = nn.functional.normalize(gaze_float.view(1, 3)).view(3)

        # stack the 7 transformed frames, then fold them into the channel axis (7 * 3 = 21)
        source_video = torch.FloatTensor(7, 3, 224, 224)
        for i, frame_path in enumerate(path_source):
            source_video[i, ...] = self.transform(self.loader(frame_path))

        source_video = source_video.view(21, 224, 224)

        # gaze direction as (yaw, pitch) in radians
        spherical_vector = torch.FloatTensor(2)
        spherical_vector[0] = math.atan2(normalized_gaze[0], -normalized_gaze[2])
        spherical_vector[1] = math.asin(normalized_gaze[1])
        return source_video, spherical_vector

    def __len__(self):
        return len(self.imgs)

--------------------------------------------------------------------------------
/code/model.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn

from resnet import resnet18


class GazeLSTM(nn.Module):
    def __init__(self):
        super(GazeLSTM, self).__init__()
        self.img_feature_dim = 256  # the dimension of the CNN feature to represent each frame

        self.base_model = resnet18(pretrained=True)

        self.base_model.fc2 = nn.Linear(1000, self.img_feature_dim)

        self.lstm = nn.LSTM(self.img_feature_dim, self.img_feature_dim, bidirectional=True, num_layers=2, batch_first=True)

        # The linear layer that maps the LSTM output to the 3 outputs
        self.last_layer = nn.Linear(2 * self.img_feature_dim, 3)

    def forward(self, input):
        # input: (batch, 7 * 3, H, W); the ResNet embeds each of the 7 frames separately
        base_out = self.base_model(input.view((-1, 3) + input.size()[-2:]))

        base_out = base_out.view(input.size(0), 7, self.img_feature_dim)

        lstm_out, _ = self.lstm(base_out)
        # keep the bidirectional LSTM output at the central (4th) frame
        lstm_out = lstm_out[:, 3, :]
        output = self.last_layer(lstm_out).view(-1, 3)

        # yaw in (-pi, pi), pitch in (-pi/2, pi/2); built without in-place writes
        yaw = math.pi * torch.tanh(output[:, 0:1])
        pitch = (math.pi / 2) * torch.tanh(output[:, 1:2])
        angular_output = torch.cat([yaw, pitch], dim=1)

        # a single angular uncertainty in (0, pi), shared by yaw and pitch
        var = math.pi * torch.sigmoid(output[:, 2:3])
        var = var.expand(-1, 2)

        return angular_output, var


class PinBallLoss(nn.Module):
    """Quantile (pinball) loss on the 10% and 90% quantiles output -/+ var."""

    def __init__(self):
        super(PinBallLoss, self).__init__()
        self.q1 = 0.1
        self.q9 = 1 - self.q1

    def forward(self, output_o, target_o, var_o):
        q_10 = target_o - (output_o - var_o)
        q_90 = target_o - (output_o + var_o)

        loss_10 = torch.max(self.q1 * q_10, (self.q1 - 1) * q_10)
        loss_90 = torch.max(self.q9 * q_90, (self.q9 - 1) * q_90)

        loss_10 = torch.mean(loss_10)
        loss_90 = torch.mean(loss_90)

        return loss_10 + loss_90

--------------------------------------------------------------------------------
/code/resnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.utils.model_zoo as model_zoo


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512 * block.expansion, 1000)
        self.fc2 = nn.Linear(1000, 3)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        feat_D = self.layer2(x)
        x = self.layer3(feat_D)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = nn.ReLU()(self.fc1(x))
        x = self.fc2(x)

        return x


class ResNetCAM(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNetCAM, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc1 = nn.Linear(512 * block.expansion, 1000)
        self.fc2 = nn.Linear(1000, 3)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x2 = self.layer2(x)
        x2 = self.layer3(x2)
        x2 = self.layer4(x2)
        return x, x2


def resnetCAM(pretrained=False, **kwargs):
    """Constructs a ResNet-18-style CAM model.
    Args:
        pretrained (bool): Currently ignored; no pre-trained weights are loaded
    """
    model = ResNetCAM(BasicBlock, [2, 2, 2, 2], **kwargs)
    return model


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
    return model


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        # strict=False: this ResNet uses fc1/fc2 heads instead of torchvision's fc
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False)
    return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
254 | Args: 255 | pretrained (bool): If True, returns a model pre-trained on ImageNet 256 | """ 257 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 258 | if pretrained: 259 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 260 | return model 261 | 262 | 263 | def resnet152(pretrained=False, **kwargs): 264 | """Constructs a ResNet-152 model. 265 | Args: 266 | pretrained (bool): If True, returns a model pre-trained on ImageNet 267 | """ 268 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 269 | if pretrained: 270 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 271 | return model 272 | 273 | -------------------------------------------------------------------------------- /code/run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import torch.utils.data 11 | import torchvision.transforms as transforms 12 | import torchvision.datasets as datasets 13 | import torchvision.models as models 14 | from data_loader import ImagerLoader 15 | import numpy as np 16 | import random 17 | import math 18 | import torchvision.utils as vutils 19 | from model import GazeLSTM,PinBallLoss 20 | 21 | source_path = "../imgs/" 22 | val_file = "validation.txt" 23 | train_file = "train.txt" 24 | 25 | test_file = "test.txt" 26 | 27 | 28 | workers = 30; 29 | epochs = 80 30 | batch_size = 80 31 | best_error = 100 # init with a large value 32 | lr = 1e-4 33 | 34 | test = False 35 | checkpoint_test = 'gaze360_model.pth.tar' 36 | network_name = 'Gaze360' 37 | 38 | from tensorboardX import SummaryWriter 39 | foo = SummaryWriter(comment=network_name) 40 | 41 | 42 | count_test = 0 43 | count = 0 44 | 45 | 46 | 47 | 48 | def main(): 49 | global args, best_error 50 | 51 | model_v = GazeLSTM() 52 | model = 
torch.nn.DataParallel(model_v).cuda() 53 | model.cuda() 54 | 55 | 56 | cudnn.benchmark = True 57 | 58 | image_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 59 | 60 | 61 | train_loader = torch.utils.data.DataLoader( 62 | ImagerLoader(source_path,train_file,transforms.Compose([ 63 | transforms.RandomResizedCrop(size=224,scale=(0.8,1)),transforms.ToTensor(),image_normalize, 64 | ])), 65 | batch_size=batch_size, shuffle=True, 66 | num_workers=workers, pin_memory=True) 67 | 68 | val_loader = torch.utils.data.DataLoader( 69 | ImagerLoader(source_path,val_file,transforms.Compose([ 70 | transforms.Resize((224,224)),transforms.ToTensor(),image_normalize, 71 | ])), 72 | batch_size=batch_size, shuffle=True, 73 | num_workers=workers, pin_memory=True) 74 | 75 | 76 | 77 | criterion = PinBallLoss().cuda() 78 | 79 | optimizer = torch.optim.Adam(model.parameters(), lr) 80 | 81 | if test==True: 82 | 83 | test_loader = torch.utils.data.DataLoader( 84 | ImagerLoader(source_path,test_file,transforms.Compose([ 85 | transforms.Resize((224,224)),transforms.ToTensor(),image_normalize, 86 | ])), 87 | batch_size=batch_size, shuffle=True, 88 | num_workers=workers, pin_memory=True) 89 | checkpoint = torch.load(checkpoint_test) 90 | model.load_state_dict(checkpoint['state_dict']) 91 | angular_error = validate(test_loader, model, criterion) 92 | print('Angular Error is',angular_error) 93 | 94 | 95 | for epoch in range(0, epochs): 96 | 97 | 98 | # train for one epoch 99 | train(train_loader, model, criterion, optimizer, epoch) 100 | 101 | # evaluate on validation set 102 | angular_error = validate(val_loader, model, criterion) 103 | 104 | # remember best angular error in validation and save checkpoint 105 | is_best = angular_error < best_error 106 | best_error = min(angular_error, best_error) 107 | save_checkpoint({ 108 | 'epoch': epoch + 1, 109 | 'state_dict': model.state_dict(), 110 | 'best_prec1': best_error, 111 | }, is_best) 112 | 113 | 114 | def 
train(train_loader, model, criterion,optimizer, epoch): 115 | global count 116 | batch_time = AverageMeter() 117 | data_time = AverageMeter() 118 | losses = AverageMeter() 119 | prediction_error = AverageMeter() 120 | angular = AverageMeter() 121 | 122 | # switch to train mode 123 | model.train() 124 | 125 | end = time.time() 126 | 127 | for i, (source_frame,target) in enumerate(train_loader): 128 | 129 | # measure data loading time 130 | data_time.update(time.time() - end) 131 | source_frame = source_frame.cuda(async=True) 132 | target = target.cuda(async=True) 133 | 134 | 135 | source_frame_var = torch.autograd.Variable(source_frame) 136 | target_var = torch.autograd.Variable(target) 137 | 138 | # compute output 139 | output,ang_error = model(source_frame_var) 140 | 141 | 142 | loss = criterion(output, target_var,ang_error) 143 | 144 | angular_error = compute_angular_error(output,target_var) 145 | pred_error = ang_error[:,0]*180/math.pi 146 | pred_error = torch.mean(pred_error,0) 147 | 148 | angular.update(angular_error, source_frame.size(0)) 149 | 150 | losses.update(loss.item(), source_frame.size(0)) 151 | 152 | prediction_error.update(pred_error, source_frame.size(0)) 153 | 154 | 155 | foo.add_scalar("loss", losses.val, count) 156 | foo.add_scalar("angular", angular.val, count) 157 | # compute gradient and do SGD step 158 | optimizer.zero_grad() 159 | loss.backward() 160 | optimizer.step() 161 | 162 | # measure elapsed time 163 | batch_time.update(time.time() - end) 164 | end = time.time() 165 | count = count +1 166 | 167 | print('Epoch: [{0}][{1}/{2}]\t' 168 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 169 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 170 | 'Angular {angular.val:.3f} ({angular.avg:.3f})\t' 171 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 172 | 'Prediction Error {prediction_error.val:.4f} ({prediction_error.avg:.4f})\t'.format( 173 | epoch, i, len(train_loader), batch_time=batch_time, 174 | data_time=data_time, 
loss=losses,angular=angular,prediction_error=prediction_error)) 175 | 176 | def validate(val_loader, model, criterion): 177 | global count_test 178 | batch_time = AverageMeter() 179 | losses = AverageMeter() 180 | prediction_error = AverageMeter() 181 | model.eval() 182 | end = time.time() 183 | angular = AverageMeter() 184 | 185 | for i, (source_frame,target) in enumerate(val_loader): 186 | 187 | source_frame = source_frame.cuda(non_blocking=True) 188 | target = target.cuda(non_blocking=True) 189 | 190 | source_frame_var = torch.autograd.Variable(source_frame) 191 | target_var = torch.autograd.Variable(target) 192 | with torch.no_grad(): 193 | # compute output 194 | output,ang_error = model(source_frame_var) 195 | 196 | loss = criterion(output, target_var,ang_error) 197 | angular_error = compute_angular_error(output,target_var) 198 | pred_error = ang_error[:,0]*180/math.pi 199 | pred_error = torch.mean(pred_error,0).item() 200 | 201 | angular.update(angular_error, source_frame.size(0)) 202 | prediction_error.update(pred_error, source_frame.size(0)) 203 | 204 | losses.update(loss.item(), source_frame.size(0)) 205 | 206 | batch_time.update(time.time() - end) 207 | end = time.time() 208 | 209 | 210 | print('Eval: [{0}/{1}]\t' 211 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 212 | 'Angular {angular.val:.4f} ({angular.avg:.4f})\t' 213 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( 214 | i, len(val_loader), batch_time=batch_time, 215 | loss=losses,angular=angular)) 216 | 217 | foo.add_scalar("predicted error", prediction_error.avg, count) 218 | foo.add_scalar("angular-test", angular.avg, count) 219 | foo.add_scalar("loss-test", losses.avg, count) 220 | return angular.avg 221 | 222 | 223 | 224 | 225 | def save_checkpoint(state, is_best, filename='checkpoint_'+network_name+'.pth.tar'): 226 | torch.save(state, filename) 227 | if is_best: 228 | shutil.copyfile(filename, 'model_best_'+network_name+'.pth.tar') 229 | 230 | 231 | def
spherical2cartesial(x): 232 | 233 | output = torch.zeros(x.size(0),3) 234 | output[:,2] = -torch.cos(x[:,1])*torch.cos(x[:,0]) 235 | output[:,0] = torch.cos(x[:,1])*torch.sin(x[:,0]) 236 | output[:,1] = torch.sin(x[:,1]) 237 | 238 | return output 239 | 240 | 241 | def compute_angular_error(input,target): 242 | 243 | input = spherical2cartesial(input) 244 | target = spherical2cartesial(target) 245 | 246 | input = input.view(-1,3,1) 247 | target = target.view(-1,1,3) 248 | output_dot = torch.bmm(target,input) 249 | output_dot = output_dot.view(-1) 250 | output_dot = torch.acos(torch.clamp(output_dot, -1.0, 1.0)) # clamp: rounding can push the dot product of unit vectors past +/-1, where acos returns NaN 251 | output_dot = output_dot.data 252 | output_dot = 180*torch.mean(output_dot)/math.pi 253 | return output_dot 254 | 255 | 256 | 257 | 258 | class AverageMeter(object): 259 | """Computes and stores the average and current value""" 260 | def __init__(self): 261 | self.reset() 262 | 263 | def reset(self): 264 | self.val = 0 265 | self.avg = 0 266 | self.sum = 0 267 | self.count = 0 268 | 269 | def update(self, val, n=1): 270 | self.val = val 271 | self.sum += val * n 272 | self.count += n 273 | self.avg = self.sum / self.count 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | if __name__ == '__main__': 283 | main() 284 | -------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | # Gaze360: Physically Unconstrained Gaze Estimation in the Wild Dataset 2 | 3 | ## About 4 | 5 | This is a dataset of 197588 frames from 238 subjects with 3D gaze annotations as captured in our Gaze360 dataset. The dataset is for non-commercial research use only. By using this dataset you agree to the terms of the [LICENSE](https://github.com/Erkil1452/gaze360/blob/master/LICENSE.md). If you use our dataset or code, please cite our [paper](http://gaze360.csail.mit.edu/iccv2019_gaze360.pdf) as: 6 | 7 | > Petr Kellnhofer*, Adrià Recasens*, Simon Stent, Wojciech Matusik, and Antonio Torralba.
“Gaze360: Physically Unconstrained Gaze Estimation in the Wild”. IEEE International Conference on Computer Vision (ICCV), 2019. 8 | 9 | ``` 10 | @inproceedings{gaze360_2019, 11 | author = {Petr Kellnhofer and Adria Recasens and Simon Stent and Wojciech Matusik and Antonio Torralba}, 12 | title = {Gaze360: Physically Unconstrained Gaze Estimation in the Wild}, 13 | booktitle = {IEEE International Conference on Computer Vision (ICCV)}, 14 | month = {October}, 15 | year = {2019} 16 | } 17 | ``` 18 | 19 | You can obtain this dataset and more information at [http://gaze360.csail.mit.edu](http://gaze360.csail.mit.edu). 20 | 21 | ## Content 22 | 23 | The dataset contains body and head crops in their original resolution, as captured by our acquisition setup based on a Ladybug5 360-degree camera. The camera is equipped with 5+1 vertical sensors with `2048 x 2448` pixel resolution each. After rectification, the images were stored as `3382 x 4096` pixel JPEG files. The provided crops are made from these original frames. As the fields of view of the camera sensors overlap, the same person may be captured by two cameras. Each such capture is treated as an independent sample (but all are counted as a single person in our 238-subject dataset). 24 | 25 | The [AlphaPose](https://github.com/MVIG-SJTU/AlphaPose) algorithm was used to detect subjects in our images. Refer to our [paper](http://gaze360.csail.mit.edu/iccv2019_gaze360.pdf) for more details. 26 | 27 | Additionally, we provide detections of eyes and faces from the [dlib](http://dlib.net/) library wherever such detection is possible. Frames where such detection failed (occlusion, rear view, etc.) are marked by `-1` in the respective fields. 28 | 29 | ## Structure 30 | 31 | The dataset consists of: 32 | - This readme.
33 | - Our [license](https://github.com/Erkil1452/gaze360/blob/master/LICENSE.md) 34 | - `metadata.mat`, a Matlab file with annotations 35 | - JPEG images with head and body crops (only head crops are used in our paper) 36 | 37 | ## Metadata 38 | 39 | The `metadata.mat` is a Matlab file which can be loaded using Matlab or using the [scipy](https://www.scipy.org/) library for Python. Note that it uses **0-based indexing** (C-style) and not the 1-based indexing of Matlab. Therefore, if used in Matlab, +1 has to be added to every array access, e.g. `recName = recordings(recording(i) + 1);`. No such adjustment is needed in Python, which we use for our PyTorch model. 40 | 41 | All entries have as many items as there are frames in the dataset. The only exceptions are `recordings`, which only defines the list of independent recording session names, and `splits`, which defines our split names. Each recording corresponds to an acquisition session with a group of unique subjects. Refer to our [paper](http://gaze360.csail.mit.edu/iccv2019_gaze360.pdf) for more details about the data collection procedure. 42 | 43 | ### 3D Coordinates 44 | Two 3D coordinate systems are used. Refer to our [supplementary file](https://github.com/erkil1452/gaze360/blob/master/OpenGaze__ICCV_2019_Sup_.pdf) (most notably Figure 2) for more details. Briefly: 45 | - **Ladybug camera coordinate system** - a right-handed coordinate system with the z-axis pointing up and the origin in the center of the Ladybug camera system. Units are meters. Used for object positions. 46 | - **Eye coordinate system** - a right-handed coordinate system with the y-axis pointing up and the origin in the center of the Ladybug camera system. Units are meters. Origin and units are irrelevant in practice, as it is only used to express the normalized gaze direction. The practical interpretation is that the positive x-axis points to the left, the positive y-axis points up, and the positive z-axis points away from the camera, i.e.
`[-1,0,0]` is a gaze looking to the right and `[0,0,-1]` a gaze looking straight into the camera, from the camera's point of view, irrespective of the subject's position in the world. 47 | 48 | ![3D Coordinate Systems](assets/coords.png) 49 | **3D Coordinate Systems**: (a) Estimating the subject's eye distance from the camera using a ground plane assumption, and (b) the gaze transform between the subject's **Eye coordinate system** (`E`) and the **Ladybug camera coordinate system** (`L`). Positive `E_z` is pointing away. 50 | 51 | Please note that although the 3D gaze (`gaze_dir`) is derived from the difference between the target's and the subject's positions (`target_pos3d - person_eyes3d`), the gaze and the positions are expressed in different coordinate systems, i.e. `gaze_dir = M * (target_pos3d - person_eyes3d)`, where `M` depends on the normal direction between the eyes and the camera. Refer to our [supplementary file](https://github.com/erkil1452/gaze360/blob/master/OpenGaze__ICCV_2019_Sup_.pdf) for details. 52 | 53 | ### 2D Coordinates 54 | All 2D coordinates relate to a position inside the original rectified full frame with a resolution of `3382 x 4096` pixels. Positions are stored as `[x, y]`. Bounding boxes are represented as `[x, y, width, height]`. All values are normalized, i.e. divided by the original frame resolution (x and width by 3382, y and height by 4096). Hence, `[0.0, 0.1, 0.1, 0.2]` is a `338 x 819` px box with an offset of `[0, 410]` pixels. 55 | 56 | We do not provide full frames as they may contain personal information of non-participating parties. We only provide crops of heads and full bodies, which are defined by the bounding boxes in the fields `person_head_bbox` and `person_body_bbox`. 57 | 58 | When working with these image crops, one must reproject the 2D coordinates in the `metadata.mat` from the full frame into the actual crop. 59 | 60 | **Example (in Python):** 61 | 62 | `person_eye_left_bbox` defines the bounding box of the left eye (if detected by dlib) inside the original full frame.
To find its position inside the head crop: 63 | ``` 64 | headBBInFull = person_head_bbox[i,:] 65 | eyeBBInFull = person_eye_left_bbox[i,:] 66 | eyeBBInCrop = np.array([ 67 | (eyeBBInFull[0] - headBBInFull[0]) / headBBInFull[2], # subtract offset of the crop 68 | (eyeBBInFull[1] - headBBInFull[1]) / headBBInFull[3], 69 | eyeBBInFull[2] / headBBInFull[2], # scale to smaller space of the crop 70 | eyeBBInFull[3] / headBBInFull[3], 71 | ]) 72 | ``` 73 | The resulting `eyeBBInCrop` is still in normalized coordinates. To convert it to pixels: 74 | ``` 75 | imHead = cv2.imread(/**/) 76 | cropSizePx = np.array([imHead.shape[1], imHead.shape[0]]) # (width, height); should be equal to -> (headBBInFull[2:] * [3382, 4096]).astype(int) 77 | eyeBBInCropPx = np.concatenate([eyeBBInCrop[:2] * cropSizePx, eyeBBInCrop[2:] * cropSizePx]).astype(int) 78 | ``` 79 | The eye can then be cropped from the head image as: 80 | ``` 81 | imEye = imHead[ 82 | eyeBBInCropPx[1]:(eyeBBInCropPx[1]+eyeBBInCropPx[3]), 83 | eyeBBInCropPx[0]:(eyeBBInCropPx[0]+eyeBBInCropPx[2]), 84 | :] 85 | ``` 86 | 87 | ### Splits ### 88 | 89 | We use the standard `train`/`val`/`test` splits of the Gaze360 dataset for training and evaluation of our model. The list of splits can be found in the `splits` entry, and the assignment of each frame in `split`. 90 | 91 | Our model is temporal and uses a symmetric radius of 3 frames around the *central frame* (i.e., 7 frames in total) to predict the gaze for the central frame. The split is always defined for the *central frame*. The first 3 and last 3 frames of each sequence cannot be used (due to missing boundary frames) and are marked as split `unused` = `3`. Note that these frames are still fed to the network, but we do not evaluate their gaze (they are never the *central frame*). 92 | 93 | We also detect frames where `person_head_bbox` is (at least partially) occluded by the target marker and mark them as `unused`.
These frames can still be used as part of the 3-frame-radius input, but they are not used as the *central frame* in any phase of training or evaluation. 94 | 95 | ### Fields 96 | - `recordings` - list of recording names, strings used to refer to the file structure. 97 | - `recording` - an index (**0-based**) into `recordings`. 98 | - `frame` - 0-based frame index since the start of the recording. 99 | - `ts` - relative time (in seconds) since the start of the recording. 100 | - `target_cam` - index of the camera (0-4) by which the target was captured. 101 | - `target_pos3d` - 3D position of the target the subjects look at, in the **Ladybug camera coordinate system**. 102 | - `target_pos2d` - 2D position of the target inside the original full frame image (normalized coordinates). 103 | - `person_identity` - a person identifier used in the file structure. Unique only within a single recording. Note that a physical subject can have multiple identities inside the same recording due to interruptions in tracking or capture by multiple cameras. 104 | - `person_cam` - index of the camera (0-4) by which the subject was captured. 105 | - `person_eyes3d` - 3D position of the subject's mid-eye point in the **Ladybug camera coordinate system**. 106 | - `person_eyes2d` - 2D position of the subject's mid-eye point inside the original full frame image (normalized coordinates). 107 | - `person_body_bbox` - bounding box of the subject's body in the original full frame image (normalized coordinates). 108 | - `person_head_bbox` - bounding box of the subject's head in the original full frame image (normalized coordinates). 109 | - `person_face_bbox` - bounding box of the subject's face in the original full frame image (normalized coordinates) if detected by dlib, otherwise `[-1,-1,-1,-1]`. 110 | - `person_eye_left_bbox` - bounding box of the subject's left eye in the original full frame image (normalized coordinates) if detected by dlib, otherwise `[-1,-1,-1,-1]`.
111 | - `person_eye_right_bbox` - bounding box of the subject's right eye in the original full frame image (normalized coordinates) if detected by dlib, otherwise `[-1,-1,-1,-1]`. 112 | - `gaze_dir` - 3D gaze direction in the **Eye coordinate system**. 113 | - `splits` - list of split names - `0 = train, 1 = val, 2 = test, 3 = unused` 114 | - `split` - ID of the split where this frame was used, i.e. `0` for `train`, `1` for `val`, `2` for `test` and `3` for `unused`. 115 | 116 | ### Images 117 | 118 | The image corresponding to entry `i` can be accessed from this dataset's base folder using the information in `metadata.mat`. 119 | 120 | In Python: 121 | ``` 122 | im = cv2.imread(os.path.join( 123 | 'imgs', 124 | recordings[recording[i]], 125 | cropType, 126 | '%06d' % person_identity[i], 127 | '%06d.jpg' % frame[i] 128 | )) 129 | ``` 130 | 131 | In Matlab: 132 | ``` 133 | im = imread(fullfile(... 134 | 'imgs',... 135 | recordings{recording(i) + 1},... 136 | cropType,... 137 | sprintf('%06d', person_identity(i)),... 138 | sprintf('%06d.jpg', frame(i))... 139 | )); 140 | ``` 141 | 142 | `cropType` is either `body` or `head`. Note that we only use `head` in our models. 143 | 144 | -------------------------------------------------------------------------------- /dataset/assets/coords.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/erkil1452/gaze360/cbeb8e4a241ecbda8ce62731243758fad4dda60a/dataset/assets/coords.png --------------------------------------------------------------------------------
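The coordinate conventions described in `dataset/README.md` can be sketched in plain Python. This is a minimal illustration, not code from the Gaze360 repository: the helper names (`bbox_to_pixels`, `spherical_to_gaze`, `gaze_to_spherical`, `angular_error_deg`) are hypothetical, and the (yaw, pitch) convention is an assumption obtained by inverting `spherical2cartesial()` in `code/run.py`. Check `code/data_loader.py` for the conversion the released model actually applies to `gaze_dir`.

```python
import math

# Rectified full-frame size from the dataset README (width x height in pixels).
FULL_FRAME_W, FULL_FRAME_H = 3382, 4096


def bbox_to_pixels(bbox, width=FULL_FRAME_W, height=FULL_FRAME_H):
    """Convert a normalized [x, y, width, height] box to pixels.

    Values are rounded to the nearest pixel (an assumption; the README's own
    crop snippet truncates with astype(int) instead).
    """
    x, y, w, h = bbox
    return [round(x * width), round(y * height), round(w * width), round(h * height)]


def spherical_to_gaze(yaw, pitch):
    """Single-sample version of spherical2cartesial() from code/run.py."""
    return [
        math.cos(pitch) * math.sin(yaw),
        math.sin(pitch),
        -math.cos(pitch) * math.cos(yaw),
    ]


def gaze_to_spherical(g):
    """Hypothetical inverse of spherical_to_gaze() for a gaze_dir vector.

    Returns (yaw, pitch) in radians; the input vector is normalized first.
    """
    gx, gy, gz = g
    norm = math.sqrt(gx * gx + gy * gy + gz * gz)
    gx, gy, gz = gx / norm, gy / norm, gz / norm
    yaw = math.atan2(gx, -gz)
    pitch = math.asin(max(-1.0, min(1.0, gy)))  # clamp against rounding error
    return yaw, pitch


def angular_error_deg(a, b):
    """Angle in degrees between two (yaw, pitch) pairs.

    A single-pair analogue of compute_angular_error(), with the dot product
    clamped to [-1, 1] before acos.
    """
    va, vb = spherical_to_gaze(*a), spherical_to_gaze(*b)
    dot = sum(p * q for p, q in zip(va, vb))
    return math.degrees(math.acos(max(-1.0, min(1.0, dot))))


print(bbox_to_pixels([0.0, 0.1, 0.1, 0.2]))  # [0, 410, 338, 819]
print(gaze_to_spherical([0, 0, -1]))         # (0.0, 0.0): looking into the camera
```

The first call reproduces the README's worked example (a `338 x 819` px box at offset `[0, 410]`), and `[0, 0, -1]` maps to zero yaw and pitch, i.e. a gaze straight into the camera. Clamping before `acos` matters because the dot product of two nearly identical unit vectors can round to slightly above 1, where `acos` is undefined.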