├── .env.example ├── .gitignore ├── Dockerfile ├── GazeEstimation ├── __init__.py ├── config │ ├── rgb.yml │ └── rgbd.yml ├── datasets │ ├── __init__.py │ ├── rgbddataset.py │ └── utils.py ├── lopo.py ├── models │ ├── __init__.py │ ├── spatial_weights_cnn.py │ ├── two_stream.py │ └── utils.py ├── preprocess │ └── format.py └── utils.py ├── LICENSE ├── README.md ├── docker-compose.yml ├── poetry.lock ├── poetry.toml ├── pyproject.toml └── setup.cfg /.env.example: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 2 | DOCKER_IMAGE=rgbdgaze 3 | DOCKER_SHM_SIZE=4g 4 | DOCKER_RUNTIME=nvidia 5 | HOST_DATADRIVE= 6 | DOCKER_SSH_AUTH_SOCK=/ssh-agent 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CUDA=10.2 2 | ARG CUDNN=7 3 | ARG UBUNTU=18.04 4 | 5 | FROM nvidia/cuda:${CUDA}-cudnn${CUDNN}-devel-ubuntu${UBUNTU} 6 | ENV DEBIAN_FRONTEND noninteractive 7 | 8 | ARG CUDA 9 | ARG CUDNN 10 | ARG UBUNTU 11 | ARG PYTHON=3.8.7 12 | 13 | ENV PYTHON_ROOT /root/local/python-$PYTHON 14 | ENV PATH $PYTHON_ROOT/bin:$PATH 15 | ENV PYENV_ROOT /root/.pyenv 16 | ENV POETRY=1.2.1 17 | 18 | RUN rm -f /etc/apt/sources.list.d/nvidia-ml.list && \ 19 | apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub && \ 20 | apt update && \ 21 | apt install -y --no-install-recommends \ 22 | build-essential \ 23 | ca-certificates \ 24 | cmake \ 25 | curl \ 26 | git \ 27 | less \ 28 | libbz2-dev \ 29 | libffi-dev \ 30 | libgl1 \ 31 | liblzma-dev \ 32 | libncurses5-dev \ 33 | libncursesw5-dev \ 34 | libreadline-dev \ 35 | libsqlite3-dev \ 36 | libssl-dev \ 37 | llvm \ 38 | make \ 39 | openssh-client \ 40 | python-openssl \ 41 | tk-dev \ 42 | tmux \ 43 | unzip \ 44 | vim \ 45 | wget \ 46 | xz-utils \ 47 | zip \ 48 | zlib1g-dev \ 49 | && apt-get clean \ 50 | && rm -rf /var/lib/apt/lists/* 51 | 52 | # python build 53 | RUN git clone https://github.com/pyenv/pyenv.git $PYENV_ROOT && \ 54 | $PYENV_ROOT/plugins/python-build/install.sh && \ 55 | /usr/local/bin/python-build -v $PYTHON $PYTHON_ROOT && \ 56 | rm -rf $PYENV_ROOT 57 | 58 | ENV HOME /root 59 | WORKDIR $HOME 60 | 61 | # install poetry 62 | RUN curl -sSL https://install.python-poetry.org | POETRY_VERSION=$POETRY python3 - 63 | ENV PATH $HOME/.local/bin:$PATH 64 | 65 | COPY pyproject.toml poetry.lock poetry.toml ./ 66 | 67 | 68 | RUN mkdir -m 700 $HOME/.ssh && ssh-keyscan github.com > $HOME/.ssh/known_hosts 69 | RUN --mount=type=ssh poetry install --no-root 70 | RUN pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html 71 | --------------------------------------------------------------------------------
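A quick sanity check for the built image — a minimal sketch, assuming you are inside the `experiment` container; the expected version string comes from the `pip install` pinned at the end of the Dockerfile above:

```python
# Verify the GPU stack the Dockerfile pins (torch 1.9.1+cu111).
import torch

print(torch.__version__)          # expected: '1.9.1+cu111'
print(torch.cuda.is_available())  # expected: True with DOCKER_RUNTIME=nvidia active
```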
/GazeEstimation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FIGLAB/RGBDGaze/bab5ce066d75f9011162e99013f7a511ef0aa6ea/GazeEstimation/__init__.py -------------------------------------------------------------------------------- /GazeEstimation/config/rgb.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: SpatialWeightsCNN 3 | NAME: paper_rgb 4 | IM_SIZE: 448 5 | FEATURE_TYPE: rgb 6 | 7 | DATA: 8 | NORMALIZE: TRUE 9 | NAME: RGBDGaze_dataset 10 | FEATURE_TYPE: rgb 11 | 12 | TRAIN: 13 | NUM_EPOCH: 10 14 | NUM_GPU: 1 15 | OPTIMIZER: 16 | NAME: SGD 17 | PARAM: 18 | LR: 0.0001 19 | STEP_SIZE: 3 20 | MOMENTUM: 0.9 21 | WEIGHT_DECAY: 0.0001 22 | DATALOADER: 23 | BATCH_SIZE: 128 24 | NUM_WORKERS: 0 25 | SHUFFLE: True 26 | PIN_MEMORY: True 27 | VALIDATION: 28 | VAL_INTERVAL: 1 29 | DATALOADER: 30 | BATCH_SIZE: 32 31 | NUM_WORKERS: 0 32 | SHUFFLE: False 33 | PIN_MEMORY: True 34 | EARLY_STOP: 35 | PATIENCE: 5 36 | 37 | TEST: 38 | DATALOADER: 39 | BATCH_SIZE: 32 40 | NUM_WORKERS: 0 41 | SHUFFLE: False 42 | PIN_MEMORY: True 43 | -------------------------------------------------------------------------------- /GazeEstimation/config/rgbd.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: TwoStream 3 | NAME: paper_rgbd 4 | IM_SIZE: 448 5 | FEATURE_TYPE: rgbd 6 | 7 | DATA: 8 | NORMALIZE: TRUE 9 | NAME: RGBDGaze_dataset 10 | FEATURE_TYPE: rgbd 11 | 12 | TRAIN: 13 | NUM_EPOCH: 20 14 | NUM_GPU: 1 15 | OPTIMIZER: 16 | NAME: SGD 17 | PARAM: 18 | LR: 0.0005 19 | STEP_SIZE: 5 20 | MOMENTUM: 0.9 21 | WEIGHT_DECAY: 0.0001 22 | DATALOADER: 23 | BATCH_SIZE: 8 24 | NUM_WORKERS: 0 25 | SHUFFLE: True 26 | PIN_MEMORY: True 27 | VALIDATION: 28 | VAL_INTERVAL: 1 29 | DATALOADER: 30 | BATCH_SIZE: 8 31 | NUM_WORKERS: 0 32 | SHUFFLE: False 33 | PIN_MEMORY: True 34 | EARLY_STOP: 35 | PATIENCE: 5 36 | 37 | TEST: 38 | DATALOADER: 39 | BATCH_SIZE: 8 40 | NUM_WORKERS: 0 41 | SHUFFLE: False 42 | PIN_MEMORY: True 43 | --------------------------------------------------------------------------------
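Both configs above follow one schema and are read with plain `yaml.safe_load` (see `load_config` in `GazeEstimation/utils.py`). A minimal sketch of how they are consumed — the participant IDs below are placeholders; `lopo.py` fills the split keys at runtime:

```python
# Load a config and inject a participant split the way lopo.py does.
import yaml

with open('GazeEstimation/config/rgbd.yml') as f:
    config = yaml.safe_load(f)

print(config['MODEL']['TYPE'])                      # 'TwoStream'
print(config['TRAIN']['OPTIMIZER']['PARAM']['LR'])  # 0.0005

# lopo.py sets these three keys before building the datasets; placeholder IDs here.
config['DATA']['TEST_PID'] = ['P1']
config['DATA']['VAL_PID'] = ['P2', 'P3']
config['DATA']['TRAIN_PID'] = ['P4', 'P5']
```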
/GazeEstimation/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import get_dataset 2 | -------------------------------------------------------------------------------- /GazeEstimation/datasets/rgbddataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from logging import getLogger 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | import torch.utils.data as data 9 | from tqdm import tqdm 10 | 11 | logger = getLogger('logger') 12 | 13 | ACTIVITIES = ['standing', 'walking', 'sitting', 'lying'] 14 | 15 | 16 | def load_dataset(dataset_dir, config): 17 | with open(os.path.join(dataset_dir, 'metadata.pkl'), 'rb') as f: 18 | metadata = pickle.load(f) 19 | data_config = config['DATA'] 20 | if 'TEST_PID' in data_config.keys(): 21 | test_pids = data_config['TEST_PID'] 22 | else: 23 | raise KeyError('No key of `TEST_PID` in config["DATA"]') 24 | 25 | if 'VAL_PID' in data_config.keys(): 26 | val_pids = data_config['VAL_PID'] 27 | else: 28 | val_pids = None 29 | 30 | if 'TRAIN_PID' in data_config.keys(): 31 | train_pids = data_config['TRAIN_PID'] 32 | else: 33 | train_pids = None 34 | 35 | logger.info(f'train: {train_pids}, val: {val_pids}, test: {test_pids}') 36 | train, val, test = _load_indices(metadata, train_pids, val_pids, test_pids) 37 | feature_type = data_config['FEATURE_TYPE'] 38 | train = Dataset(dataset_dir, metadata, train, feature_type) 39 | val = Dataset(dataset_dir, metadata, val, feature_type) 40 | test = Dataset(dataset_dir, metadata, test, feature_type) 41 | return train, val, test 42 | 43 | def _load_indices(meta_data, train_pids, val_pids, test_pids): 44 | train, val, test = [], [], [] 45 | for i, pid in enumerate(meta_data['pid']): 46 | if pid in test_pids: 47 | test.append(i) 48 | elif val_pids is not None and pid in val_pids: 49 | val.append(i) 50 | elif train_pids is None or pid in train_pids: 51 | train.append(i) 52 | 53 | logger.info(f'{len(train)=}, {len(val)=}, {len(test)=}') 54 | return train, val, test 55 | 56 | 57 | class Dataset(data.Dataset): 58 | 59 | def __init__(self, dataset_dir, metadata, indices, feature_type): 60 | 61 | self.dataset_dir = dataset_dir 62 | self.metadata = metadata 63 | self.indices = indices 64 | self.feature_type = feature_type 65 | 66 | @staticmethod 67 | def normalize(tensor): 68 | tensor = tensor.div(255) 69 | dtype = tensor.dtype 70 | mean = torch.as_tensor([0.5 for _ in range(tensor.shape[0])], 71 | dtype=dtype, 72 | device=tensor.device).view(-1, 1, 1) 73 | std = torch.as_tensor([0.5 for _ in range(tensor.shape[0])], 74 | dtype=dtype, 75 | device=tensor.device).view(-1, 1, 1) 76 | return tensor.sub_(mean).div_(std) 77 | 78 | @staticmethod 79 | def load_tensor(path, dtype): 80 | return torch.load(path) 81 | 82 | def __getitem__(self, index): 83 | index = self.indices[index] 84 | tensor_p = os.path.join( 85 | self.dataset_dir, 86 | self.metadata['pid'][index], 87 | 'tensor', 88 | self.metadata['activity'][index], 89 | f'{self.metadata["frameIndex"][index]}.pt', 90 | ) 91 | 92 | tensor = torch.load(tensor_p) 93 | if self.feature_type == 'rgbd': 94 | pass 95 | elif self.feature_type == 'rgb': 96 | tensor = tensor[:3, :, :] 97 | else: 98 | raise TypeError(f'Unexpected feature type: {self.feature_type=}') 99 | tensor = self.normalize(tensor) 100 | gaze = np.array([self.metadata['labelDotX'][index], 101 | self.metadata['labelDotY'][index]], np.float32) 102 | gaze = torch.FloatTensor(gaze) 103 | return tensor, gaze 104 | 105 | def __len__(self): 106 | return len(self.indices) 107 | --------------------------------------------------------------------------------
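For orientation, a minimal sketch of what this dataset yields per sample — it assumes `metadata.pkl` and the per-frame tensors produced by `preprocess/format.py` already exist under the dataset directory, and uses placeholder participant IDs:

```python
# Each sample is a normalized 448x448 face crop plus a 2D gaze label.
from GazeEstimation.datasets.rgbddataset import load_dataset

config = {'DATA': {'FEATURE_TYPE': 'rgbd',
                   'TEST_PID': ['P1'], 'VAL_PID': ['P2'], 'TRAIN_PID': ['P3']}}
train, val, test = load_dataset('/root/datadrive/RGBDGaze/dataset/RGBDGaze_dataset', config)

tensor, gaze = train[0]
print(tensor.shape)  # torch.Size([4, 448, 448]): RGB channels first, depth last
print(gaze.shape)    # torch.Size([2]): on-screen gaze target in cm
```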
/GazeEstimation/datasets/utils.py: -------------------------------------------------------------------------------- 1 | from .rgbddataset import load_dataset as load_rgbddataset 2 | 3 | 4 | def get_dataset(data_dir, config): 5 | if config['DATA']['NAME'] == 'RGBDGaze_dataset': 6 | return load_rgbddataset(data_dir, config) 7 | else: 8 | raise ValueError(f'Unexpected dataset type: {config["DATA"]["NAME"]=}') 9 | -------------------------------------------------------------------------------- /GazeEstimation/lopo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import pickle 5 | import random 6 | import time 7 | from pathlib import Path 8 | 9 | import torch 10 | import torch.nn as nn 11 | from datasets import get_dataset 12 | from models import get_model 13 | from utils import AverageMeter, load_config 14 | 15 | random.seed(2022) 16 | 17 | 18 | logger = logging.getLogger('logger') 19 | logger.setLevel(logging.DEBUG) 20 | # logging.basicConfig(level=logging.DEBUG) 21 | 22 | stream_handler = logging.StreamHandler() 23 | logger.addHandler(stream_handler) 24 | 25 | 26 | def get_all_participants(): 27 | ret = [] 28 | for p in Path('/root/datadrive/RGBDGaze/dataset/RGBDGaze_dataset').iterdir(): 29 | if not p.is_dir(): 30 | continue 31 | if not (p / 'tensor').is_dir(): 32 | continue 33 | ret.append(str(p).split('/')[-1]) 34 | return ret 35 | 36 | 37 | PARTICIPANTS = get_all_participants() 38 | 39 | 40 | def get_args(): 41 | parser = argparse.ArgumentParser(description='RGBD Gaze Estimation (pytorch).') 42 | parser.add_argument( 43 | '--project_dir', 44 | default='/root/datadrive/RGBDGaze', 45 | help='Path to project directory.', 46 | ) 47 | parser.add_argument( 48 | '--config', 49 | required=True, 50 | default='./config/rgbd.yml', 51 | help='Path to config yaml.', 52 | ) 53 | parser.add_argument( 54 |
'--checkpoint', 55 | default='/root/datadrive/RGBDGaze/models/SpatialWeightsCNN_gazecapture/pretrained_rgb.pth', 56 | help='Path to checkpoint.', 57 | ) 58 | args = parser.parse_args() 59 | return args 60 | 61 | 62 | def get_dataloaders(data_dir, config): 63 | train, val, test = get_dataset(data_dir, config) 64 | 65 | train_loader = torch.utils.data.DataLoader( 66 | train, 67 | batch_size=config['TRAIN']['DATALOADER']['BATCH_SIZE'], 68 | shuffle=config['TRAIN']['DATALOADER']['SHUFFLE'], 69 | num_workers=config['TRAIN']['DATALOADER']['NUM_WORKERS'], 70 | pin_memory=config['TRAIN']['DATALOADER']['PIN_MEMORY'], 71 | ) 72 | 73 | val_loader = torch.utils.data.DataLoader( 74 | val, 75 | batch_size=config['TRAIN']['VALIDATION']['DATALOADER']['BATCH_SIZE'], 76 | shuffle=config['TRAIN']['VALIDATION']['DATALOADER']['SHUFFLE'], 77 | num_workers=config['TRAIN']['VALIDATION']['DATALOADER']['NUM_WORKERS'], 78 | pin_memory=config['TRAIN']['VALIDATION']['DATALOADER']['PIN_MEMORY'], 79 | ) 80 | 81 | test_loader = torch.utils.data.DataLoader( 82 | test, 83 | batch_size=config['TEST']['DATALOADER']['BATCH_SIZE'], 84 | shuffle=config['TEST']['DATALOADER']['SHUFFLE'], 85 | num_workers=config['TEST']['DATALOADER']['NUM_WORKERS'], 86 | pin_memory=config['TEST']['DATALOADER']['PIN_MEMORY'], 87 | ) 88 | return train_loader, val_loader, test_loader 89 | 90 | 91 | def get_criterion(): 92 | return nn.MSELoss().cuda() 93 | 94 | 95 | def get_optimizer(model, config): 96 | if config['TRAIN']['OPTIMIZER']['NAME'] == 'SGD': 97 | optimizer = torch.optim.SGD( 98 | model.parameters(), 99 | lr=config['TRAIN']['OPTIMIZER']['PARAM']['LR'], 100 | momentum=config['TRAIN']['OPTIMIZER']['PARAM']['MOMENTUM'], 101 | weight_decay=config['TRAIN']['OPTIMIZER']['PARAM']['WEIGHT_DECAY'], 102 | ) 103 | else: 104 | raise ValueError(f'Unexpected optimizer: {config["TRAIN"]["OPTIMIZER"]}') 105 | return optimizer 106 | 107 | 108 | def get_device(config): 109 | assert config['TRAIN']['NUM_GPU'] == torch.cuda.device_count(), \ 110 | f'Expected number of GPUs is {config["TRAIN"]["NUM_GPU"]}, ' \ 111 | + f'but {torch.cuda.device_count()} are going to be used.' 112 | return torch.device('cuda' if torch.cuda.is_available() else 'cpu') 113 | 114 | 115 | def get_scheduler(optimizer, config): 116 | return torch.optim.lr_scheduler.StepLR(optimizer, config['TRAIN']['OPTIMIZER']['PARAM']['STEP_SIZE'], gamma=0.1) 117 | 118 | 119 | def load_checkpoint(fp): 120 | data = torch.load(fp) 121 | return data 122 | 123 | 124 | def save_checkpoint(data, fp): 125 | os.makedirs(os.path.dirname(fp), exist_ok=True) 126 | torch.save(data, fp) 127 | 128 | 129 | def main(): 130 | args = get_args() 131 | config = load_config(args.config) 132 | print(config) 133 | data_dir = os.path.join(args.project_dir, 'dataset', config['DATA']['NAME']) 134 | model_dir = os.path.join(args.project_dir, 'models', 'LOPO-'+config['MODEL']['NAME']) 135 | logger.info(f'Using data from {data_dir} and output result to {model_dir}') 136 | 137 | torch.backends.cudnn.benchmark = True # input image size is fixed 138 | criterion = get_criterion() 139 | device = get_device(config) 140 | participants = PARTICIPANTS 141 | logger.info(f'Using {len(participants)} participants. 
LEAVE-ONE-PARTICIPANT-OUT') 142 | 143 | for pid in participants: 144 | logger.info('='*15) 145 | save_dir = os.path.join(model_dir, pid) 146 | if os.path.exists(save_dir): 147 | logger.info(f'already exists: skipping {pid}') 148 | continue 149 | else: 150 | os.makedirs(save_dir, exist_ok=False) 151 | 152 | handler = logging.FileHandler(filename=f'{save_dir}/log.log') 153 | logger.addHandler(handler) 154 | 155 | logger.info(f'Set {pid} as TEST and output result to {save_dir}') 156 | config['DATA']['TEST_PID'] = [pid] 157 | config['DATA']['VAL_PID'] = random.sample([p for p in participants if p != pid], 5) 158 | config['DATA']['TRAIN_PID'] = [p for p in participants if p != pid and p not in config['DATA']['VAL_PID']] 159 | model = get_model(config) 160 | train_loader, val_loader, test_loader = get_dataloaders(data_dir, config) 161 | model = model.to(device) 162 | optimizer = get_optimizer(model, config) 163 | scheduler = get_scheduler(optimizer, config) 164 | 165 | if args.checkpoint is not None: 166 | saved = load_checkpoint(args.checkpoint) 167 | model.load_pretrained_data(saved['state_dict']) 168 | best_prec1 = saved['best_prec1'] 169 | del saved 170 | model.to(device) 171 | logger.info(f'Loaded checkpoint from {args.checkpoint}. {best_prec1=}') 172 | 173 | max_epoch = config['TRAIN']['NUM_EPOCH'] 174 | best_model = model 175 | best_prec1 = 1e20 176 | best_epoch = 0 177 | 178 | for ep in range(1, max_epoch+1): 179 | train(train_loader, model, criterion, optimizer, ep, device) 180 | scheduler.step() 181 | if ep % config['TRAIN']['VALIDATION']['VAL_INTERVAL'] == 0: 182 | prec1, _, _ = evaluate(val_loader, 'val', model, criterion, ep, device) 183 | if prec1 < best_prec1: 184 | best_epoch = ep 185 | best_prec1 = prec1 186 | best_model = model 187 | if ep == 1: 188 | os.system(f'cp {args.config} {model_dir}') 189 | 190 | test_prec1, _, _ = evaluate(test_loader, 'test', model, criterion, ep, device) 191 | with open(os.path.join(save_dir, f'val_loss_{ep}.txt'), 'w') as f: 192 | f.write(str(prec1)) 193 | with open(os.path.join(save_dir, f'test_loss_{ep}.txt'), 'w') as f: 194 | f.write(str(test_prec1)) 195 | 196 | save_checkpoint( 197 | { 198 | 'epoch': best_epoch, 199 | 'state_dict': best_model.state_dict(), 200 | 'best_val_prec1': best_prec1, 201 | }, 202 | os.path.join(save_dir, 'best.pth'), 203 | ) 204 | 205 | loss, gt_list, pred_list = evaluate(test_loader, 'test', best_model, criterion, best_epoch, device) 206 | with open(os.path.join(save_dir, 'gt_list.pkl'), 'wb') as f: 207 | pickle.dump(gt_list, f) 208 | with open(os.path.join(save_dir, 'pred_list.pkl'), 'wb') as f: 209 | pickle.dump(pred_list, f) 210 | with open(os.path.join(save_dir, 'test_loss.txt'), 'w') as f: 211 | f.write(str(loss)) 212 | logger.removeHandler(handler) 213 | 214 | 215 | def train(loader, model, criterion, optimizer, epoch, device): 216 | batch_time = AverageMeter() 217 | losses = AverageMeter() 218 | 219 | # switch to train mode 220 | model.train() 221 | 222 | end = time.time() 223 | 224 | for i, data in enumerate(loader): 225 | 226 | assert isinstance(data, list) 227 | assert len(data) == 2 228 | face, gaze = data 229 | face = face.to(device) 230 | output = model(face) 231 | 232 | n_batch = face.size(0) 233 | gaze = gaze.to(device) 234 | loss = criterion(output, gaze) 235 | losses.update(loss.data.item(), n_batch) 236 | optimizer.zero_grad() 237 | loss.backward() 238 | optimizer.step() 239 | 240 | # measure elapsed time 241 | batch_time.update(time.time() - end) 242 | end = time.time() 243 | 244 | logger.info( 245 |
'Epoch (train): [{0}][{1}/{2}]\t' 246 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 247 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format( 248 | epoch, i, len(loader), batch_time=batch_time, loss=losses, 249 | ), 250 | ) 251 | 252 | 253 | @torch.no_grad() 254 | def evaluate(loader, val_type, model, criterion, epoch, device): 255 | batch_time = AverageMeter() 256 | losses = AverageMeter() 257 | l2_errors = AverageMeter() 258 | 259 | # switch to evaluate mode 260 | model.eval() 261 | 262 | gt_list, pred_list = [], [] 263 | end = time.time() 264 | for i, data in enumerate(loader): 265 | 266 | assert isinstance(data, list) 267 | assert len(data) == 2 268 | face, gaze = data 269 | face = face.to(device) 270 | output = model(face) 271 | 272 | gt_list.append(gaze) 273 | pred_list.append(output.to(torch.device('cpu'))) 274 | 275 | n_batch = face.size(0) 276 | gaze = gaze.to(device) 277 | loss = criterion(output, gaze) 278 | l2_error = output - gaze 279 | l2_error = torch.mul(l2_error, l2_error) 280 | l2_error = torch.sum(l2_error, 1) 281 | l2_error = torch.mean(torch.sqrt(l2_error)) 282 | losses.update(loss.data.item(), n_batch) 283 | l2_errors.update(l2_error.item(), n_batch) 284 | 285 | # measure elapsed time 286 | batch_time.update(time.time() - end) 287 | end = time.time() 288 | 289 | logger.info( 290 | 'Epoch ({0}): [{1}][{2}/{3}]\t' 291 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 292 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 293 | 'Error L2 {l2_error.val:.4f} ({l2_error.avg:.4f})\t'.format( 294 | val_type, epoch, i, len(loader), batch_time=batch_time, loss=losses, l2_error=l2_errors, 295 | ), 296 | ) 297 | 298 | return l2_errors.avg, gt_list, pred_list 299 | 300 | 301 | if __name__ == '__main__': 302 | main() 303 | print('DONE') 304 | -------------------------------------------------------------------------------- /GazeEstimation/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import get_model 2 | -------------------------------------------------------------------------------- /GazeEstimation/models/spatial_weights_cnn.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchvision import models 8 | 9 | logger = logging.getLogger('logger') 10 | 11 | 12 | def init_conv(m, mean, variance, bias): 13 | nn.init.normal_(m.weight.data, mean, math.sqrt(variance)) 14 | nn.init.constant_(m.bias.data, bias) 15 | return m 16 | 17 | 18 | class SpatialWeightsCNN(nn.Module): 19 | 20 | def __init__(self, feature_type): 21 | super(SpatialWeightsCNN, self).__init__() 22 | self.feature_layer = nn.Sequential(models.alexnet(pretrained=True).features) 23 | if feature_type == 'rgbd': 24 | new_layer = nn.Conv2d(4, 64, kernel_size=11, stride=4, padding=2) 25 | with torch.no_grad(): 26 | new_layer.weight[:, :3, :, :] = self.feature_layer[0][0].weight 27 | logger.info(f'Reshape feature layer: {self.feature_layer[0][0].weight.shape} => {new_layer.weight.shape}') 28 | self.feature_layer[0][0] = new_layer 29 | elif feature_type == 'rgb': 30 | pass 31 | elif feature_type == 'd': 32 | new_layer = nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2) 33 | logger.info(f'Reshape feature layer: {self.feature_layer[0][0].weight.shape} => {new_layer.weight.shape}') 34 | self.feature_layer[0][0] = new_layer 35 | else: 36 | raise TypeError(f'Unexpected feature type: {feature_type=}') 
37 | 38 | self.conv1 = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0) 39 | self.conv2 = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0) 40 | self.conv3 = nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0) 41 | 42 | self.fc1 = nn.Linear(256 * 13**2, 4096) 43 | self.fc2 = nn.Linear(4096, 4096) 44 | self.fc3 = nn.Linear(4096, 2) 45 | 46 | self._register_hook() 47 | self._initialize_weight() 48 | 49 | def _initialize_weight(self): 50 | nn.init.normal_(self.conv1.weight, mean=0, std=0.01) 51 | nn.init.normal_(self.conv2.weight, mean=0, std=0.01) 52 | nn.init.normal_(self.conv3.weight, mean=0, std=0.001) 53 | nn.init.constant_(self.conv1.bias, val=0.1) 54 | nn.init.constant_(self.conv2.bias, val=0.1) 55 | nn.init.constant_(self.conv3.bias, val=1) 56 | nn.init.normal_(self.fc1.weight, mean=0, std=0.005) 57 | nn.init.normal_(self.fc2.weight, mean=0, std=0.0001) 58 | nn.init.normal_(self.fc3.weight, mean=0, std=0.0001) 59 | nn.init.constant_(self.fc1.bias, val=1) 60 | nn.init.zeros_(self.fc2.bias) 61 | nn.init.zeros_(self.fc3.bias) 62 | 63 | def _register_hook(self): 64 | n_channels = self.conv1.in_channels 65 | 66 | def hook(module, grad_in, grad_out): 67 | return tuple(grad / n_channels for grad in grad_in) 68 | 69 | self.handles = [] 70 | self.handles.append(self.conv3.register_backward_hook(hook)) 71 | 72 | def remove_hook(self): 73 | for handle in self.handles: 74 | handle.remove() 75 | 76 | def load_pretrained_data(self, state_dict): 77 | if state_dict['feature_layer.0.0.weight'].shape == self.feature_layer[0][0].weight.shape: 78 | self.load_state_dict(state_dict) 79 | else: 80 | first_layer_weight = state_dict['feature_layer.0.0.weight'] 81 | tmp_layer = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2) 82 | self.feature_layer[0][0] = tmp_layer 83 | self.load_state_dict(state_dict) 84 | new_layer = nn.Conv2d(4, 64, kernel_size=11, stride=4, padding=2) 85 | with torch.no_grad(): 86 | new_layer.weight[:, :3, :, :] = first_layer_weight 87 | self.feature_layer[0][0] = new_layer 88 | 89 | def forward(self, inputs): 90 | x = inputs 91 | x = self.feature_layer(x) 92 | y = F.relu(self.conv1(x)) 93 | y = F.relu(self.conv2(y)) 94 | y = F.relu(self.conv3(y)) 95 | x = x * y 96 | x = x.view(x.size(0), -1) 97 | x = F.dropout(F.relu(self.fc1(x)), p=0.5, training=self.training) 98 | x = F.dropout(F.relu(self.fc2(x)), p=0.5, training=self.training) 99 | x = self.fc3(x) 100 | return x 101 | -------------------------------------------------------------------------------- /GazeEstimation/models/two_stream.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .spatial_weights_cnn import SpatialWeightsCNN 9 | 10 | logger = logging.getLogger('logger') 11 | 12 | 13 | def init_conv(m, mean, variance, bias): 14 | nn.init.normal_(m.weight.data, mean, math.sqrt(variance)) 15 | nn.init.constant_(m.bias.data, bias) 16 | return m 17 | 18 | 19 | class TwoStream(nn.Module): 20 | 21 | def __init__(self): 22 | super(TwoStream, self).__init__() 23 | self.rgb_spatial = SpatialWeightsCNN(feature_type='rgb') 24 | self.d_spatial = SpatialWeightsCNN(feature_type='d') 25 | 26 | self.fc1 = nn.Linear(256 * 13**2 * 2, 4096) 27 | self.fc2 = nn.Linear(4096, 1024) 28 | self.fc3 = nn.Linear(1024, 2) 29 | 30 | self._initialize_weight() 31 | 32 | def _initialize_weight(self): 33 | nn.init.normal_(self.fc1.weight, mean=0, std=0.005) 34 | 
nn.init.normal_(self.fc2.weight, mean=0, std=0.0001) 35 | nn.init.normal_(self.fc3.weight, mean=0, std=0.0001) 36 | nn.init.constant_(self.fc1.bias, val=1) 37 | nn.init.zeros_(self.fc2.bias) 38 | nn.init.zeros_(self.fc3.bias) 39 | 40 | def load_pretrained_data(self, state_dict): 41 | 42 | try: 43 | self.load_state_dict(state_dict) 44 | logger.info('Load TwoStream from checkpoint') 45 | except RuntimeError: 46 | self.rgb_spatial.load_pretrained_data(state_dict) 47 | logger.info('Load RGB part of TwoStream from checkpoint') 48 | 49 | def forward(self, inputs): 50 | rgb = inputs[:, :3, :, :] 51 | d = inputs[:, 3, :, :].unsqueeze(1) 52 | 53 | rgb = self.rgb_spatial.feature_layer(rgb) 54 | rgb_y = F.relu(self.rgb_spatial.conv1(rgb)) 55 | rgb_y = F.relu(self.rgb_spatial.conv2(rgb_y)) 56 | rgb_y = F.relu(self.rgb_spatial.conv3(rgb_y)) 57 | 58 | rgb = rgb * rgb_y 59 | rgb = rgb.view(rgb.size(0), -1) 60 | 61 | d = self.d_spatial.feature_layer(d) 62 | d_y = F.relu(self.d_spatial.conv1(d)) 63 | d_y = F.relu(self.d_spatial.conv2(d_y)) 64 | d_y = F.relu(self.d_spatial.conv3(d_y)) 65 | d = d * d_y 66 | d = d.view(d.size(0), -1) 67 | 68 | x = torch.cat((rgb, d), 1) 69 | x = F.dropout(F.relu(self.fc1(x)), p=0.5, training=self.training) 70 | x = F.dropout(F.relu(self.fc2(x)), p=0.5, training=self.training) 71 | x = self.fc3(x) 72 | return x 73 | 74 | @torch.no_grad() 75 | def get_middle_tensors(self, inputs): 76 | rgb = inputs[:, :3, :, :] 77 | d = inputs[:, 3, :, :].unsqueeze(1) 78 | ret = {} 79 | 80 | rgb = self.rgb_spatial.feature_layer(rgb) 81 | ret['rgb'] = rgb 82 | rgb_y = F.relu(self.rgb_spatial.conv1(rgb)) 83 | rgb_y = F.relu(self.rgb_spatial.conv2(rgb_y)) 84 | rgb_y = F.relu(self.rgb_spatial.conv3(rgb_y)) 85 | ret['rgb_y'] = rgb_y 86 | 87 | d = self.d_spatial.feature_layer(d) 88 | ret['d'] = d 89 | d_y = F.relu(self.d_spatial.conv1(d)) 90 | d_y = F.relu(self.d_spatial.conv2(d_y)) 91 | d_y = F.relu(self.d_spatial.conv3(d_y)) 92 | ret['d_y'] = d_y 93 | return ret 94 | 95 | 96 | if __name__ == '__main__': 97 | model = TwoStream() 98 | x = torch.rand([1, 4, 448, 448]) 99 | print(model(x)) 100 | -------------------------------------------------------------------------------- /GazeEstimation/models/utils.py: -------------------------------------------------------------------------------- 1 | from .spatial_weights_cnn import SpatialWeightsCNN 2 | from .two_stream import TwoStream 3 | 4 | 5 | def get_model(config): 6 | if config['MODEL']['TYPE'] == 'SpatialWeightsCNN': 7 | model = SpatialWeightsCNN(feature_type=config['MODEL']['FEATURE_TYPE']) 8 | elif config['MODEL']['TYPE'] == 'TwoStream': 9 | model = TwoStream() 10 | else: 11 | raise ValueError(f'Unexpected model type: {config["MODEL"]["TYPE"]=}') 12 | 13 | return model 14 | --------------------------------------------------------------------------------
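A minimal sketch of building a model through `get_model`; the dummy input mirrors the self-test at the bottom of `two_stream.py`:

```python
# Instantiate the RGB+D model the way lopo.py does and run a dummy batch.
import torch
from GazeEstimation.models import get_model

config = {'MODEL': {'TYPE': 'TwoStream', 'FEATURE_TYPE': 'rgbd'}}
model = get_model(config).eval()

x = torch.rand(1, 4, 448, 448)  # 3 RGB channels + 1 depth channel at IM_SIZE=448
with torch.no_grad():
    print(model(x).shape)       # torch.Size([1, 2]): predicted (x, y) in cm
```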
/GazeEstimation/preprocess/format.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from pathlib import Path 4 | 5 | import cv2 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | from tqdm import tqdm 10 | 11 | DATA_DIR = '/root/datadrive/RGBDGaze/dataset/RGBDGaze_dataset' 12 | METADATA_PATH = f'{DATA_DIR}/metadata.pkl' 13 | ACTIVITIES = ['standing', 'walking', 'sitting', 'lying'] 14 | INPUT_SIZE = (448, 448) 15 | 16 | 17 | def get_all_participants(): 18 | ret = [] 19 | for p in Path(DATA_DIR).iterdir(): 20 | if not p.is_dir(): 21 | continue 22 | if not (p / 'decoded').is_dir(): 23 | continue 24 | ret.append(str(p).split('/')[-1]) 25 | return ret 26 | 27 | 28 | def get_device_info(device, spec_df): 29 | info = {} 30 | # merge `iPhone 12` -> `iPhone12` 31 | if device.split('iPhone')[1][0] == ' ': 32 | device = 'iPhone' + device.split('iPhone')[1][1:] 33 | 34 | if device in spec_df.Name.tolist(): 35 | for key in ['w_pt', 'w_cm', 'h_pt', 'h_cm']: 36 | info[key] = spec_df[spec_df['Name'] == device][key].tolist()[0] 37 | else: 38 | raise NotImplementedError(f'{device=} cannot be used') 39 | return info 40 | 41 | 42 | def pt_to_cm(x_pt, y_pt, info): 43 | x_cm = x_pt / info['w_pt'] * info['w_cm'] 44 | y_cm = y_pt / info['h_pt'] * info['h_cm'] 45 | 46 | x_cm -= (info['w_cm']) / 2 47 | y_cm *= -1 48 | return x_cm, y_cm 49 | 50 | 51 | def make_tensor(rgb, d, x, y, w, h): 52 | assert rgb.size > d.size 53 | if w != h: 54 | print(f'{w=} != {h=}') 55 | w = h = min(w, h) 56 | rgb = cv2.flip(rgb, 1) 57 | d = cv2.flip(d, 1) 58 | scale = rgb.shape[0] / d.shape[0] 59 | scaled_bbox_length = int(w/scale) 60 | scaled_rgb = cv2.resize(rgb, d.shape[::-1]) 61 | concat = np.concatenate((scaled_rgb, d[:, :, np.newaxis]), axis=2) 62 | assert np.array_equal(concat[:, :, :3], scaled_rgb) 63 | concat = concat[int(x/scale):int(x/scale)+scaled_bbox_length, 64 | int(y/scale):int(y/scale)+scaled_bbox_length, :] 65 | concat = cv2.resize(concat, INPUT_SIZE) 66 | # concat = cv2.flip(concat, 1) 67 | concat = concat.transpose((2, 0, 1)) 68 | return torch.from_numpy(concat) 69 | 70 | 71 | def process(pid, activity): 72 | print(f'==== {pid=} {activity=} ====') 73 | csv_file = f'{DATA_DIR}/{pid}/decoded/{activity}/label.csv' 74 | df = pd.read_csv(csv_file) 75 | spec_df = pd.read_csv(f'{DATA_DIR}/iphone_spec.csv') 76 | ret = { 77 | 'pid': [], 78 | 'device': [], 79 | 'activity': [], 80 | 'frameIndex': [], 81 | 'labelDotX': [], 82 | 'labelDotY': [], 83 | 'imuX': [], 84 | 'imuY': [], 85 | 'imuZ': [], 86 | } 87 | pbar = tqdm(total=len(df)) 88 | device_info = None 89 | if os.path.isdir(f'{DATA_DIR}/{pid}/tensor'): 90 | print('remove existing tensors') 91 | os.system(f'rm -rf {DATA_DIR}/{pid}/tensor') 92 | for (i, row) in df.iterrows(): 93 | pbar.update(1) 94 | uid = row['uid'] 95 | x, y, w, h = int(row['bbox_x']), int(row['bbox_y']), int(row['bbox_w']), int(row['bbox_h']) 96 | rgb_path = f'{DATA_DIR}/{pid}/decoded/{activity}/rgb/{uid}.jpg' 97 | depth_path = f'{DATA_DIR}/{pid}/decoded/{activity}/depth/{uid}.jpg' 98 | try: 99 | rgb_img = cv2.cvtColor(cv2.imread(rgb_path), cv2.COLOR_BGR2RGB) 100 | depth_img = cv2.imread(depth_path, cv2.IMREAD_GRAYSCALE) 101 | tensor = make_tensor(rgb_img, depth_img, x, y, w, h) 102 | except Exception as e: 103 | print(e) 104 | continue 105 | save_tensor_path = f'{DATA_DIR}/{pid}/tensor/{activity}/{uid}.pt' 106 | if not os.path.isdir(os.path.dirname(save_tensor_path)): 107 | os.system(f'mkdir -p {os.path.dirname(save_tensor_path)}') 108 | torch.save(tensor, save_tensor_path) 109 | 110 | if device_info is None: 111 | device = row['device'] 112 | device_info = get_device_info(device, spec_df) 113 | gt_x_cm, gt_y_cm = pt_to_cm(row['gt_x_pt'], row['gt_y_pt'], device_info) 114 | 115 | ret['pid'].append(pid) 116 | ret['device'].append(device) 117 | ret['activity'].append(activity) 118 | ret['frameIndex'].append(uid) 119 | ret['labelDotX'].append(gt_x_cm) 120 | ret['labelDotY'].append(gt_y_cm) 121 | ret['imuX'].append(float(row['imu_x'])) 122 | ret['imuY'].append(float(row['imu_y'])) 123 | ret['imuZ'].append(float(row['imu_z'])) 124 | 125 | return ret 126 | 127 | 128 | def update_metadata(org, new): 129 | for
k in new.keys(): 130 | if k in org.keys(): 131 | org[k].extend(new[k]) 132 | else: 133 | org[k] = new[k] 134 | return org 135 | 136 | 137 | def main(): 138 | participants = get_all_participants() 139 | print('participants: ', participants) 140 | 141 | if os.path.exists(METADATA_PATH): 142 | with open(METADATA_PATH, 'rb') as f: 143 | metadata = pickle.load(f) 144 | else: 145 | metadata = {} 146 | 147 | for pid in participants: 148 | for activity in ACTIVITIES: 149 | sub_metadata = process(pid, activity) 150 | update_metadata(metadata, sub_metadata) 151 | with open(METADATA_PATH, 'wb') as f: 152 | pickle.dump(metadata, f) 153 | 154 | 155 | if __name__ == '__main__': 156 | main() 157 | -------------------------------------------------------------------------------- /GazeEstimation/utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.reset() 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | 22 | 23 | def load_config(fp): 24 | with open(fp) as f: 25 | config = yaml.safe_load(f) 26 | return config 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 2, June 1991 3 | 4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc., 5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 6 | Everyone is permitted to copy and distribute verbatim copies 7 | of this license document, but changing it is not allowed. 8 | 9 | Preamble 10 | 11 | The licenses for most software are designed to take away your 12 | freedom to share and change it. By contrast, the GNU General Public 13 | License is intended to guarantee your freedom to share and change free 14 | software--to make sure the software is free for all its users. This 15 | General Public License applies to most of the Free Software 16 | Foundation's software and to any other program whose authors commit to 17 | using it. (Some other Free Software Foundation software is covered by 18 | the GNU Lesser General Public License instead.) You can apply it to 19 | your programs, too. 20 | 21 | When we speak of free software, we are referring to freedom, not 22 | price. Our General Public Licenses are designed to make sure that you 23 | have the freedom to distribute copies of free software (and charge for 24 | this service if you wish), that you receive source code or can get it 25 | if you want it, that you can change the software or use pieces of it 26 | in new free programs; and that you know you can do these things. 27 | 28 | To protect your rights, we need to make restrictions that forbid 29 | anyone to deny you these rights or to ask you to surrender the rights. 30 | These restrictions translate to certain responsibilities for you if you 31 | distribute copies of the software, or if you modify it. 32 | 33 | For example, if you distribute copies of such a program, whether 34 | gratis or for a fee, you must give the recipients all the rights that 35 | you have. You must make sure that they, too, receive or can get the 36 | source code. And you must show them these terms so they know their 37 | rights. 
38 | 39 | We protect your rights with two steps: (1) copyright the software, and 40 | (2) offer you this license which gives you legal permission to copy, 41 | distribute and/or modify the software. 42 | 43 | Also, for each author's protection and ours, we want to make certain 44 | that everyone understands that there is no warranty for this free 45 | software. If the software is modified by someone else and passed on, we 46 | want its recipients to know that what they have is not the original, so 47 | that any problems introduced by others will not reflect on the original 48 | authors' reputations. 49 | 50 | Finally, any free program is threatened constantly by software 51 | patents. We wish to avoid the danger that redistributors of a free 52 | program will individually obtain patent licenses, in effect making the 53 | program proprietary. To prevent this, we have made it clear that any 54 | patent must be licensed for everyone's free use or not licensed at all. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | GNU GENERAL PUBLIC LICENSE 60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 61 | 62 | 0. This License applies to any program or other work which contains 63 | a notice placed by the copyright holder saying it may be distributed 64 | under the terms of this General Public License. The "Program", below, 65 | refers to any such program or work, and a "work based on the Program" 66 | means either the Program or any derivative work under copyright law: 67 | that is to say, a work containing the Program or a portion of it, 68 | either verbatim or with modifications and/or translated into another 69 | language. (Hereinafter, translation is included without limitation in 70 | the term "modification".) Each licensee is addressed as "you". 71 | 72 | Activities other than copying, distribution and modification are not 73 | covered by this License; they are outside its scope. The act of 74 | running the Program is not restricted, and the output from the Program 75 | is covered only if its contents constitute a work based on the 76 | Program (independent of having been made by running the Program). 77 | Whether that is true depends on what the Program does. 78 | 79 | 1. You may copy and distribute verbatim copies of the Program's 80 | source code as you receive it, in any medium, provided that you 81 | conspicuously and appropriately publish on each copy an appropriate 82 | copyright notice and disclaimer of warranty; keep intact all the 83 | notices that refer to this License and to the absence of any warranty; 84 | and give any other recipients of the Program a copy of this License 85 | along with the Program. 86 | 87 | You may charge a fee for the physical act of transferring a copy, and 88 | you may at your option offer warranty protection in exchange for a fee. 89 | 90 | 2. You may modify your copy or copies of the Program or any portion 91 | of it, thus forming a work based on the Program, and copy and 92 | distribute such modifications or work under the terms of Section 1 93 | above, provided that you also meet all of these conditions: 94 | 95 | a) You must cause the modified files to carry prominent notices 96 | stating that you changed the files and the date of any change. 
97 | 98 | b) You must cause any work that you distribute or publish, that in 99 | whole or in part contains or is derived from the Program or any 100 | part thereof, to be licensed as a whole at no charge to all third 101 | parties under the terms of this License. 102 | 103 | c) If the modified program normally reads commands interactively 104 | when run, you must cause it, when started running for such 105 | interactive use in the most ordinary way, to print or display an 106 | announcement including an appropriate copyright notice and a 107 | notice that there is no warranty (or else, saying that you provide 108 | a warranty) and that users may redistribute the program under 109 | these conditions, and telling the user how to view a copy of this 110 | License. (Exception: if the Program itself is interactive but 111 | does not normally print such an announcement, your work based on 112 | the Program is not required to print an announcement.) 113 | 114 | These requirements apply to the modified work as a whole. If 115 | identifiable sections of that work are not derived from the Program, 116 | and can be reasonably considered independent and separate works in 117 | themselves, then this License, and its terms, do not apply to those 118 | sections when you distribute them as separate works. But when you 119 | distribute the same sections as part of a whole which is a work based 120 | on the Program, the distribution of the whole must be on the terms of 121 | this License, whose permissions for other licensees extend to the 122 | entire whole, and thus to each and every part regardless of who wrote it. 123 | 124 | Thus, it is not the intent of this section to claim rights or contest 125 | your rights to work written entirely by you; rather, the intent is to 126 | exercise the right to control the distribution of derivative or 127 | collective works based on the Program. 128 | 129 | In addition, mere aggregation of another work not based on the Program 130 | with the Program (or with a work based on the Program) on a volume of 131 | a storage or distribution medium does not bring the other work under 132 | the scope of this License. 133 | 134 | 3. You may copy and distribute the Program (or a work based on it, 135 | under Section 2) in object code or executable form under the terms of 136 | Sections 1 and 2 above provided that you also do one of the following: 137 | 138 | a) Accompany it with the complete corresponding machine-readable 139 | source code, which must be distributed under the terms of Sections 140 | 1 and 2 above on a medium customarily used for software interchange; or, 141 | 142 | b) Accompany it with a written offer, valid for at least three 143 | years, to give any third party, for a charge no more than your 144 | cost of physically performing source distribution, a complete 145 | machine-readable copy of the corresponding source code, to be 146 | distributed under the terms of Sections 1 and 2 above on a medium 147 | customarily used for software interchange; or, 148 | 149 | c) Accompany it with the information you received as to the offer 150 | to distribute corresponding source code. (This alternative is 151 | allowed only for noncommercial distribution and only if you 152 | received the program in object code or executable form with such 153 | an offer, in accord with Subsection b above.) 154 | 155 | The source code for a work means the preferred form of the work for 156 | making modifications to it. 
For an executable work, complete source 157 | code means all the source code for all modules it contains, plus any 158 | associated interface definition files, plus the scripts used to 159 | control compilation and installation of the executable. However, as a 160 | special exception, the source code distributed need not include 161 | anything that is normally distributed (in either source or binary 162 | form) with the major components (compiler, kernel, and so on) of the 163 | operating system on which the executable runs, unless that component 164 | itself accompanies the executable. 165 | 166 | If distribution of executable or object code is made by offering 167 | access to copy from a designated place, then offering equivalent 168 | access to copy the source code from the same place counts as 169 | distribution of the source code, even though third parties are not 170 | compelled to copy the source along with the object code. 171 | 172 | 4. You may not copy, modify, sublicense, or distribute the Program 173 | except as expressly provided under this License. Any attempt 174 | otherwise to copy, modify, sublicense or distribute the Program is 175 | void, and will automatically terminate your rights under this License. 176 | However, parties who have received copies, or rights, from you under 177 | this License will not have their licenses terminated so long as such 178 | parties remain in full compliance. 179 | 180 | 5. You are not required to accept this License, since you have not 181 | signed it. However, nothing else grants you permission to modify or 182 | distribute the Program or its derivative works. These actions are 183 | prohibited by law if you do not accept this License. Therefore, by 184 | modifying or distributing the Program (or any work based on the 185 | Program), you indicate your acceptance of this License to do so, and 186 | all its terms and conditions for copying, distributing or modifying 187 | the Program or works based on it. 188 | 189 | 6. Each time you redistribute the Program (or any work based on the 190 | Program), the recipient automatically receives a license from the 191 | original licensor to copy, distribute or modify the Program subject to 192 | these terms and conditions. You may not impose any further 193 | restrictions on the recipients' exercise of the rights granted herein. 194 | You are not responsible for enforcing compliance by third parties to 195 | this License. 196 | 197 | 7. If, as a consequence of a court judgment or allegation of patent 198 | infringement or for any other reason (not limited to patent issues), 199 | conditions are imposed on you (whether by court order, agreement or 200 | otherwise) that contradict the conditions of this License, they do not 201 | excuse you from the conditions of this License. If you cannot 202 | distribute so as to satisfy simultaneously your obligations under this 203 | License and any other pertinent obligations, then as a consequence you 204 | may not distribute the Program at all. For example, if a patent 205 | license would not permit royalty-free redistribution of the Program by 206 | all those who receive copies directly or indirectly through you, then 207 | the only way you could satisfy both it and this License would be to 208 | refrain entirely from distribution of the Program. 
209 | 210 | If any portion of this section is held invalid or unenforceable under 211 | any particular circumstance, the balance of the section is intended to 212 | apply and the section as a whole is intended to apply in other 213 | circumstances. 214 | 215 | It is not the purpose of this section to induce you to infringe any 216 | patents or other property right claims or to contest validity of any 217 | such claims; this section has the sole purpose of protecting the 218 | integrity of the free software distribution system, which is 219 | implemented by public license practices. Many people have made 220 | generous contributions to the wide range of software distributed 221 | through that system in reliance on consistent application of that 222 | system; it is up to the author/donor to decide if he or she is willing 223 | to distribute software through any other system and a licensee cannot 224 | impose that choice. 225 | 226 | This section is intended to make thoroughly clear what is believed to 227 | be a consequence of the rest of this License. 228 | 229 | 8. If the distribution and/or use of the Program is restricted in 230 | certain countries either by patents or by copyrighted interfaces, the 231 | original copyright holder who places the Program under this License 232 | may add an explicit geographical distribution limitation excluding 233 | those countries, so that distribution is permitted only in or among 234 | countries not thus excluded. In such case, this License incorporates 235 | the limitation as if written in the body of this License. 236 | 237 | 9. The Free Software Foundation may publish revised and/or new versions 238 | of the General Public License from time to time. Such new versions will 239 | be similar in spirit to the present version, but may differ in detail to 240 | address new problems or concerns. 241 | 242 | Each version is given a distinguishing version number. If the Program 243 | specifies a version number of this License which applies to it and "any 244 | later version", you have the option of following the terms and conditions 245 | either of that version or of any later version published by the Free 246 | Software Foundation. If the Program does not specify a version number of 247 | this License, you may choose any version ever published by the Free Software 248 | Foundation. 249 | 250 | 10. If you wish to incorporate parts of the Program into other free 251 | programs whose distribution conditions are different, write to the author 252 | to ask for permission. For software which is copyrighted by the Free 253 | Software Foundation, write to the Free Software Foundation; we sometimes 254 | make exceptions for this. Our decision will be guided by the two goals 255 | of preserving the free status of all derivatives of our free software and 256 | of promoting the sharing and reuse of software generally. 257 | 258 | NO WARRANTY 259 | 260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY 261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN 262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES 263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED 264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS 266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. 
SHOULD THE 267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, 268 | REPAIR OR CORRECTION. 269 | 270 | 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR 272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED 275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY 276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER 277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE 278 | POSSIBILITY OF SUCH DAMAGES. 279 | 280 | END OF TERMS AND CONDITIONS 281 | 282 | How to Apply These Terms to Your New Programs 283 | 284 | If you develop a new program, and you want it to be of the greatest 285 | possible use to the public, the best way to achieve this is to make it 286 | free software which everyone can redistribute and change under these terms. 287 | 288 | To do so, attach the following notices to the program. It is safest 289 | to attach them to the start of each source file to most effectively 290 | convey the exclusion of warranty; and each file should have at least 291 | the "copyright" line and a pointer to where the full notice is found. 292 | 293 | <one line to give the program's name and a brief idea of what it does.> 294 | Copyright (C) <year> <name of author> 295 | 296 | This program is free software; you can redistribute it and/or modify 297 | it under the terms of the GNU General Public License as published by 298 | the Free Software Foundation; either version 2 of the License, or 299 | (at your option) any later version. 300 | 301 | This program is distributed in the hope that it will be useful, 302 | but WITHOUT ANY WARRANTY; without even the implied warranty of 303 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 304 | GNU General Public License for more details. 305 | 306 | You should have received a copy of the GNU General Public License along 307 | with this program; if not, write to the Free Software Foundation, Inc., 308 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 309 | 310 | Also add information on how to contact you by electronic and paper mail. 311 | 312 | If the program is interactive, make it output a short notice like this 313 | when it starts in an interactive mode: 314 | 315 | Gnomovision version 69, Copyright (C) year name of author 316 | Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 317 | This is free software, and you are welcome to redistribute it 318 | under certain conditions; type `show c' for details. 319 | 320 | The hypothetical commands `show w' and `show c' should show the appropriate 321 | parts of the General Public License. Of course, the commands you use may 322 | be called something other than `show w' and `show c'; they could even be 323 | mouse-clicks or menu items--whatever suits your program. 324 | 325 | You should also get your employer (if you work as a programmer) or your 326 | school, if any, to sign a "copyright disclaimer" for the program, if 327 | necessary. Here is a sample; alter the names: 328 | 329 | Yoyodyne, Inc., hereby disclaims all copyright interest in the program 330 | `Gnomovision' (which makes passes at compilers) written by James Hacker. 331 | 332 | <signature of Ty Coon>, 1 April 1989 333 | Ty Coon, President of Vice 334 | 335 | This General Public License does not permit incorporating your program into 336 | proprietary programs. If your program is a subroutine library, you may 337 | consider it more useful to permit linking proprietary applications with the 338 | library. If this is what you want to do, use the GNU Lesser General 339 | Public License instead of this License. 340 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Research Code for RGBDGaze 2 | 3 | ![image](https://user-images.githubusercontent.com/12772049/200086408-2d85ff1b-9858-480c-b972-0b4b48239906.png) 4 | 5 | 6 | This is the research repository for [RGBDGaze: Gaze Tracking on Smartphones with RGB and Depth Data](https://dl.acm.org/doi/10.1145/3536221.3556568), presented at ACM ICMI 2022. 7 | 8 | It contains the training code and a link to the dataset. 9 | 10 | # Environment 11 | 12 | - docker 13 | - docker-compose 14 | - nvidia-docker 15 | - nvidia-driver 16 | 17 | # How to use 18 | 19 | ## 1. Download dataset and pretrained RGB model 20 | 21 | - Dataset: https://www.dropbox.com/s/iixjrxzxx7nbupl/RGBDGaze_dataset.zip?dl=0 22 | - RGB part of the Spatial CNN model pretrained with [GazeCapture](https://gazecapture.csail.mit.edu/) dataset: https://www.dropbox.com/s/dgpn0j212l260q1/pretrained_rgb.pth?dl=0 23 | 24 | ## 2. Clone 25 | 26 | ``` 27 | $ git clone https://github.com/FIGLAB/RGBDGaze 28 | ``` 29 | 30 | ## 3. Setup 31 | 32 | ``` 33 | $ cp .env{.example,} 34 | ``` 35 | 36 | In `.env`, you can set a path to your data directory. 37 | 38 | ## 4. Docker build & run 39 | 40 | ``` 41 | $ DOCKER_BUILDKIT=1 docker build -t rgbdgaze --ssh default . 42 | $ docker-compose run --rm experiment 43 | ``` 44 | 45 | ## 5. Run 46 | 47 | Prepare the following files in the docker container: 48 | - /root/datadrive/RGBDGaze/dataset/RGBDGaze_dataset 49 | - /root/datadrive/RGBDGaze/models/SpatialWeightsCNN_gazecapture/pretrained_rgb.pth 50 | 51 | Generate the tensors used for training: 52 | ``` 53 | $ cd preprocess 54 | $ python format.py 55 | ``` 56 | 57 | For training the RGB+D model, run 58 | ``` 59 | $ python lopo.py --config ./config/rgbd.yml 60 | ``` 61 | 62 | For training the RGB model, run 63 | ``` 64 | $ python lopo.py --config ./config/rgb.yml 65 | ``` 66 |
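Each leave-one-participant-out fold writes `best.pth`, `gt_list.pkl`, `pred_list.pkl`, and per-epoch loss files under `models/LOPO-<MODEL.NAME>/<pid>/`. A minimal sketch for aggregating one fold's test error (the fold path uses a placeholder participant ID):

```python
import pickle

import torch

fold_dir = '/root/datadrive/RGBDGaze/models/LOPO-paper_rgbd/P1'  # placeholder pid
with open(f'{fold_dir}/gt_list.pkl', 'rb') as f:
    gt = torch.cat(pickle.load(f))     # (N, 2) ground-truth gaze in cm
with open(f'{fold_dir}/pred_list.pkl', 'rb') as f:
    pred = torch.cat(pickle.load(f))   # (N, 2) predictions in cm

err = (gt - pred).pow(2).sum(dim=1).sqrt()
print(f'mean L2 error: {err.mean().item():.2f} cm')
```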
67 | # Dataset description 68 | 69 | ## Overview 70 | 71 | The data is organized in the following manner: 72 | 73 | - 45 participants (*1) 74 | - synchronized RGB + depth images for four different contexts 75 | - standing, sitting, walking, and lying 76 | - metadata 77 | - corresponding gaze target on the screen 78 | - detected face bounding box 79 | - acceleration data 80 | - device id 81 | - intrinsic camera parameters of the device 82 | 83 | - *1: We used 50 participants' data in the paper. However, five of them did not agree to be included in the public dataset. 84 | 85 | 86 | ## Structure 87 | 88 | The folder structure is organized like this: 89 | 90 | ``` 91 | RGBDGaze_dataset 92 | │ README.txt 93 | │ iphone_spec.csv 94 | │ 95 | └───P1 96 | │ │ intrinsic.json 97 | │ │ 98 | │ └───decoded 99 | │ │ 100 | │ └───standing 101 | │ │ │ label.csv 102 | │ │ │ 103 | │ │ └───rgb 104 | │ │ │ 1.jpg 105 | │ │ │ 2.jpg ... 106 | │ │ │ 107 | │ │ └───depth 108 | │ │ 109 | │ └───sitting 110 | │ └───walking 111 | │ └───lying 112 | │ 113 | └───P2 ... 114 | ``` 115 | 116 |
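A minimal sketch for reading one decoded frame with its label row — the column names follow what `preprocess/format.py` reads, and the path assumes the layout above after unzipping (placeholder participant and activity):

```python
import cv2
import pandas as pd

root = 'RGBDGaze_dataset/P1/decoded/standing'  # placeholder participant/activity
df = pd.read_csv(f'{root}/label.csv')
row = df.iloc[0]

rgb = cv2.imread(f"{root}/rgb/{row['uid']}.jpg")
depth = cv2.imread(f"{root}/depth/{row['uid']}.jpg", cv2.IMREAD_GRAYSCALE)
print(rgb.shape, depth.shape)
print(row[['gt_x_pt', 'gt_y_pt', 'bbox_x', 'bbox_y', 'bbox_w', 'bbox_h']])
```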
117 | # Reference 118 | 119 | [Download the paper here.](https://rikky0611.github.io/resource/paper/rgbdgaze_icmi2022_paper.pdf) 120 | 121 | ``` 122 | Riku Arakawa, Mayank Goel, Chris Harrison, Karan Ahuja. 2022. RGBDGaze: Gaze Tracking on Smartphones with RGB and Depth Data. In Proceedings of the 2022 International Conference on Multimodal Interaction (ICMI '22). Association for Computing Machinery, New York, NY, USA. 123 | ``` 124 | 125 | ``` 126 | @inproceedings{DBLP:conf/icmi/ArakawaG0A22, 127 | author = {Riku Arakawa and 128 | Mayank Goel and 129 | Chris Harrison and 130 | Karan Ahuja}, 131 | title = {RGBDGaze: Gaze Tracking on Smartphones with {RGB} and Depth Data}, 132 | booktitle = {International Conference on Multimodal Interaction, {ICMI} 2022, Bengaluru, 133 | India, November 7-11, 2022}, 134 | pages = {329--336}, 135 | publisher = {{ACM}}, 136 | year = {2022}, 137 | doi = {10.1145/3536221.3556568}, 138 | address = {New York}, 139 | } 140 | ``` 141 | 142 | # License 143 | 144 | GPL v2.0. The license file is present in this repo. Please contact innovation@cmu.edu if you would like another license for your use. 145 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "2.3" 2 | services: 3 | experiment: 4 | build: 5 | context: . 6 | environment: 7 | - SSH_AUTH_SOCK=$DOCKER_SSH_AUTH_SOCK 8 | env_file: 9 | - .env 10 | image: $DOCKER_IMAGE 11 | runtime: $DOCKER_RUNTIME 12 | shm_size: $DOCKER_SHM_SIZE 13 | volumes: 14 | - .:/root/workspace 15 | - $HOST_DATADRIVE:/root/datadrive 16 | - $SSH_AUTH_SOCK:/ssh-agent 17 | - /run/host-services/ssh-auth.sock:/run/host-services/ssh-auth.sock 18 | command: /bin/bash 19 | -------------------------------------------------------------------------------- /poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | create = false 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "RGBDGaze" 3 | version = "1.0.0" 4 | description = "Algorithm package for RGBDGaze." 5 | authors = ["Riku Arakawa"] 6 | repository = "https://github.com/FIGLAB/RGBDGaze" 7 | packages = [{ include = "GazeEstimation" }] 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.8" 11 | numpy = "*" 12 | pandas = "*" 13 | PyYAML = "*" 14 | pycuda = "*" 15 | scipy = "*" 16 | scikit-learn = "==0.23.1" 17 | tensorboard = "*" 18 | torch = "*" 19 | tqdm = "*" 20 | torchvision = "*" 21 | coremltools = "*" 22 | opencv-python = "==4.5.5.64" 23 | seaborn = "^0.11.2" 24 | 25 | [tool.poetry.dev-dependencies] 26 | autopep8 = "*" 27 | flake8 = "*" 28 | flake8-commas = "*" 29 | flake8-isort = "*" 30 | flake8-quotes = "*" 31 | jupyter = "*" 32 | jupyterlab = "*" 33 | matplotlib = "*" 34 | pep8-naming = "*" 35 | 36 | [build-system] 37 | requires = ["poetry>=0.12"] 38 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | include_trailing_comma = True 3 | line_length = 120 4 | multi_line_output = 5 5 | 6 | [flake8] 7 | max-line-length = 120 8 | exclude = */__init__.py 9 | ignore = E123,E126,E226,F541 10 | --------------------------------------------------------------------------------