├── .gitignore ├── .mailmap ├── LICENSE ├── README.md ├── assets ├── accuracy_mpii.jpg ├── accuracy_prl.jpg ├── dataset_collection_setup.jpg ├── dataset_figure.jpg ├── dataset_video.gif ├── inpaint_example.jpg ├── inpaint_overview.jpg ├── paper_abstract.jpg ├── rt_bene_best_poster_award.png ├── rt_bene_inference.gif ├── rt_bene_labels.png ├── rt_bene_overview.png ├── rt_bene_precision_recall.png ├── rtbene_pytorch_resnet18.png └── system_overview.jpg ├── rt_bene_model_training ├── README.md ├── pytorch │ ├── __init__.py │ ├── evaluate_model.ipynb │ ├── post_process_ckpt.py │ ├── rtbene_dataset.py │ ├── train_model.py │ └── util │ │ ├── GenerateRTBENEH5Dataset.py │ │ └── __init__.py └── tensorflow │ ├── __init__.py │ ├── dataset_manager.py │ ├── evaluate_blink_model.py │ ├── train_and_evaluate.py │ └── train_blink_model.py ├── rt_bene_standalone ├── README.md ├── estimate_blink_standalone.py └── samples_blink │ ├── left │ ├── left_blink.png │ └── left_open.png │ └── right │ ├── right_blink.png │ └── right_open.png ├── rt_gene ├── CMakeLists.txt ├── README.md ├── cfg │ └── ModelSize.cfg ├── launch │ ├── estimate_blink.launch │ ├── estimate_gaze.launch │ ├── start_kinect.launch │ ├── start_rosbag.launch │ ├── start_video.launch │ └── start_webcam.launch ├── model_nets │ ├── SFD │ │ └── README.md │ ├── ThreeDDFA │ │ ├── keypoints_sim.npy │ │ ├── param_whitening.pkl │ │ ├── param_whitening_py2.pkl │ │ ├── u_exp.npy │ │ └── u_shp.npy │ └── face_model_68.txt ├── msg │ ├── MSG_Blink.msg │ ├── MSG_BlinkList.msg │ ├── MSG_Gaze.msg │ ├── MSG_GazeList.msg │ ├── MSG_Headpose.msg │ ├── MSG_HeadposeList.msg │ ├── MSG_Landmarks.msg │ ├── MSG_LandmarksList.msg │ ├── MSG_SubjectImages.msg │ └── MSG_SubjectImagesList.msg ├── package.xml ├── rviz_cfg │ └── gaze_following.rviz ├── scripts │ ├── download_models.py │ ├── estimate_blink.py │ ├── estimate_gaze.py │ └── extract_landmarks_node.py ├── setup.py ├── src │ ├── __init__.py │ ├── rt_bene │ │ ├── __init__.py │ │ ├── blink_estimation_models_pytorch.py │ │ ├── estimate_blink_base.py │ │ ├── estimate_blink_pytorch.py │ │ └── estimate_blink_tensorflow.py │ └── rt_gene │ │ ├── SFD │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── net_s3fd.py │ │ └── sfd_detector.py │ │ ├── ThreeDDFA │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── ddfa.py │ │ ├── inference.py │ │ ├── io.py │ │ ├── mobilenet_v1.py │ │ └── params.py │ │ ├── __init__.py │ │ ├── download_tools.py │ │ ├── estimate_gaze_base.py │ │ ├── estimate_gaze_pytorch.py │ │ ├── estimate_gaze_tensorflow.py │ │ ├── extract_landmarks_method_base.py │ │ ├── gaze_estimation_models_pytorch.py │ │ ├── gaze_tools.py │ │ ├── gaze_tools_standalone.py │ │ ├── kalman_stabilizer.py │ │ ├── ros_tools.py │ │ ├── subject_ros_bridge.py │ │ ├── tracker_face_encoding.py │ │ ├── tracker_generic.py │ │ └── tracker_sequential.py └── webcam_configs │ ├── kinect2_calibration.yaml │ └── webcam_blue_26010230.yaml ├── rt_gene_inpainting ├── GAN_train.py ├── GAN_train_run.ipynb ├── GlassesCompletion.py ├── GlassesCompletion_run.py ├── README.md ├── external │ ├── LICENSE │ ├── __init__.py │ └── poissonblending.py ├── models.py └── utils.py ├── rt_gene_model_training ├── README.md ├── __init__.py ├── pytorch │ ├── __init__.py │ ├── evaluate_model.py │ ├── post_process_ckpt.py │ ├── rtgene_dataset.py │ ├── train_model.py │ └── utils │ │ ├── CombineGazeH5Datasets.py │ │ ├── GazeAngleAccuracy.py │ │ ├── GenerateEyePatchesRTGENEDataset.py │ │ ├── GenerateMPIIH5Dataset.py │ │ ├── GenerateRTGENEH5Dataset.py │ │ ├── LearningRateFinder.py │ │ ├── 
PinballLoss.py │ │ └── __init__.py └── tensorflow │ ├── __init__.py │ ├── evaluate_model.py │ ├── evaluate_models.sh │ ├── prepare_dataset.m │ ├── train_model.py │ ├── train_models_run.sh │ └── train_tools.py └── rt_gene_standalone ├── README.md ├── estimate_gaze_standalone.py └── samples_gaze ├── gaze_center.jpg ├── gaze_down.jpg ├── gaze_left.jpg ├── gaze_right.jpg └── gaze_up.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom stuff 2 | .DS_Store 3 | CMakeLists.txt.user 4 | .idea* 5 | scripts/.idea* 6 | .vscode 7 | 8 | # Model files 9 | rt_gene/model_nets/Model_allsubjects1.h5 10 | rt_gene/model_nets/dnn_deploy.prototxt 11 | rt_gene/model_nets/res10_300x300_ssd_iter_140000.caffemodel 12 | rt_gene/model_nets/dlib_face_recognition_resnet_model_v1.dat 13 | rt_gene/model_nets/ThreeDDFA/w_exp_sim.npy 14 | rt_gene/model_nets/ThreeDDFA/w_shp_sim.npy 15 | rt_gene/model_nets/phase1_wpdc_vdc.pth.tar 16 | rt_gene/model_nets/SFD/s3fd_facedetector.pth 17 | rt_gene/model_nets/all_subjects_mpii_prl_utmv_0_02.h5 18 | rt_gene/model_nets/all_subjects_mpii_prl_utmv_1_02.h5 19 | rt_gene/model_nets/all_subjects_mpii_prl_utmv_2_02.h5 20 | rt_gene/model_nets/all_subjects_mpii_prl_utmv_3_02.h5 21 | rt_gene/model_nets/blink_model_1.h5 22 | rt_gene/model_nets/blink_model_2.h5 23 | rt_gene/model_nets/rt-bene_mobilenetv2_fold1_best.h5 24 | rt_gene/model_nets/rt-bene_mobilenetv2_fold2_best.h5 25 | rt_gene/model_nets/rt-bene_mobilenetv2_fold3_best.h5 26 | rt_gene/model_nets/Model_allsubjects1_pytorch.model 27 | rt_gene/model_nets/Model_allsubjects2_pytorch.model 28 | rt_gene/model_nets/Model_allsubjects3_pytorch.model 29 | rt_gene/model_nets/Model_allsubjects4_pytorch.model 30 | rt_gene/model_nets/Model_prl_mpii_allsubjects1_pytorch.model 31 | rt_gene/model_nets/Model_prl_mpii_allsubjects2_pytorch.model 32 | rt_gene/model_nets/blink_model_pytorch_resnet18_allsubjects1.model 33 | rt_gene/model_nets/blink_model_pytorch_resnet18_allsubjects2.model 34 | rt_gene/model_nets/gaze_model_pytorch_vgg16_prl_mpii_allsubjects1.model 35 | rt_gene/model_nets/gaze_model_pytorch_vgg16_prl_mpii_allsubjects2.model 36 | rt_gene/model_nets/gaze_model_pytorch_vgg16_prl_mpii_allsubjects3.model 37 | rt_gene/model_nets/gaze_model_pytorch_vgg16_prl_mpii_allsubjects4.model 38 | rt_gene/model_nets/blink_model_pytorch_vgg16_allsubjects1.model 39 | rt_gene/model_nets/blink_model_pytorch_vgg16_allsubjects2.model 40 | rt_gene/model_nets/blink_model_pytorch_vgg16_allsubjects3.model 41 | rt_gene_standalone/samples_gaze/out 42 | 43 | *~ 44 | # Byte-compiled / optimized / DLL files 45 | __pycache__/ 46 | *.py[cod] 47 | *$py.class 48 | 49 | # C extensions 50 | *.so 51 | 52 | # Distribution / packaging 53 | .Python 54 | env/ 55 | build/ 56 | develop-eggs/ 57 | dist/ 58 | downloads/ 59 | eggs/ 60 | .eggs/ 61 | lib/ 62 | lib64/ 63 | parts/ 64 | sdist/ 65 | var/ 66 | wheels/ 67 | *.egg-info/ 68 | .installed.cfg 69 | *.egg 70 | 71 | # PyInstaller 72 | # Usually these files are written by a python script from a template 73 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
74 | *.manifest 75 | *.spec 76 | 77 | # Installer logs 78 | pip-log.txt 79 | pip-delete-this-directory.txt 80 | 81 | # Unit test / coverage reports 82 | htmlcov/ 83 | .tox/ 84 | .coverage 85 | .coverage.* 86 | .cache 87 | nosetests.xml 88 | coverage.xml 89 | *,cover 90 | .hypothesis/ 91 | 92 | # Translations 93 | *.mo 94 | *.pot 95 | 96 | # Django stuff: 97 | *.log 98 | local_settings.py 99 | 100 | # Flask stuff: 101 | instance/ 102 | .webassets-cache 103 | 104 | # Scrapy stuff: 105 | .scrapy 106 | 107 | # Sphinx documentation 108 | docs/_build/ 109 | 110 | # PyBuilder 111 | target/ 112 | 113 | # Jupyter Notebook 114 | .ipynb_checkpoints 115 | 116 | # pyenv 117 | .python-version 118 | 119 | # celery beat schedule file 120 | celerybeat-schedule 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # dotenv 126 | .env 127 | 128 | # virtualenv 129 | .venv 130 | venv/ 131 | ENV/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | -------------------------------------------------------------------------------- /.mailmap: -------------------------------------------------------------------------------- 1 | Tobias Fischer 2 | Tobias Fischer 3 | Tobias Fischer 4 | -------------------------------------------------------------------------------- /assets/accuracy_mpii.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/assets/accuracy_mpii.jpg -------------------------------------------------------------------------------- /assets/accuracy_prl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/assets/accuracy_prl.jpg -------------------------------------------------------------------------------- /assets/dataset_collection_setup.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/assets/dataset_collection_setup.jpg -------------------------------------------------------------------------------- /assets/dataset_figure.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/assets/dataset_figure.jpg -------------------------------------------------------------------------------- /assets/dataset_video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/assets/dataset_video.gif -------------------------------------------------------------------------------- /assets/inpaint_example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/assets/inpaint_example.jpg -------------------------------------------------------------------------------- /assets/inpaint_overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/assets/inpaint_overview.jpg -------------------------------------------------------------------------------- 
/assets/paper_abstract.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/assets/paper_abstract.jpg -------------------------------------------------------------------------------- /assets/rt_bene_best_poster_award.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/assets/rt_bene_best_poster_award.png -------------------------------------------------------------------------------- /assets/rt_bene_inference.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/assets/rt_bene_inference.gif -------------------------------------------------------------------------------- /assets/rt_bene_labels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/assets/rt_bene_labels.png -------------------------------------------------------------------------------- /assets/rt_bene_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/assets/rt_bene_overview.png -------------------------------------------------------------------------------- /assets/rt_bene_precision_recall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/assets/rt_bene_precision_recall.png -------------------------------------------------------------------------------- /assets/rtbene_pytorch_resnet18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/assets/rtbene_pytorch_resnet18.png -------------------------------------------------------------------------------- /assets/system_overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/assets/system_overview.jpg -------------------------------------------------------------------------------- /rt_bene_model_training/README.md: -------------------------------------------------------------------------------- 1 | # RT-BENE: A Dataset and Baselines for Real-Time Blink Estimation in Natural Environments 2 | [![License: CC BY-NC-SA 4.0](https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg?style=flat-square)](https://creativecommons.org/licenses/by-nc-sa/4.0/) 3 | ![stars](https://img.shields.io/github/stars/Tobias-Fischer/rt_gene.svg?style=flat-square) 4 | ![GitHub issues](https://img.shields.io/github/issues/Tobias-Fischer/rt_gene.svg?style=flat-square) 5 | ![GitHub repo size](https://img.shields.io/github/repo-size/Tobias-Fischer/rt_gene.svg?style=flat-square) 6 | 7 | ![Best Poster Award](../assets/rt_bene_best_poster_award.png) 8 | 9 | 10 | ## License + Attribution 11 | This code is licensed under [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/). 
Commercial usage is not permitted; please contact the authors regarding commercial licensing. If you use this dataset or the code in a scientific publication, please cite the following [paper](http://openaccess.thecvf.com/content_ICCVW_2019/html/GAZE/Cortacero_RT-BENE_A_Dataset_and_Baselines_for_Real-Time_Blink_Estimation_in_ICCVW_2019_paper.html): 12 | 13 | ``` 14 | @inproceedings{CortaceroICCV2019W, 15 | author={Kevin Cortacero and Tobias Fischer and Yiannis Demiris}, 16 | booktitle = {Proceedings of the IEEE International Conference on Computer Vision Workshops}, 17 | title = {RT-BENE: A Dataset and Baselines for Real-Time Blink Estimation in Natural Environments}, 18 | year = {2019}, 19 | } 20 | ``` 21 | 22 | RT-BENE was supported by the EU Horizon 2020 Project PAL (643783-RIA) and a Royal Academy of Engineering Chair in Emerging Technologies to Yiannis Demiris. 23 | 24 | More information can be found on the Personal Robotics Lab's website. 25 | 26 | ## Requirements 27 | ### Tensorflow 28 | For pip users: `pip install tensorflow-gpu numpy tqdm opencv-python scikit-learn` or for conda users: `conda install tensorflow-gpu numpy tqdm opencv scikit-learn` 29 | ### Pytorch 30 | For conda users: `conda install -c conda-forge numpy scipy tqdm pillow rospkg opencv scikit-learn h5py matplotlib pytorch-lightning pytorch torchvision` 31 | 32 | ## Model training code 33 | ### Tensorflow (as per paper) 34 | This code was used to train the blink estimator for RT-BENE. The labels for the RT-BENE blink dataset are contained in the [rt_bene_dataset](../rt_bene_dataset) directory. The images corresponding to the labels can be downloaded from the RT-GENE dataset (labels are only available for the "noglasses" part): [download](https://zenodo.org/record/2529036) [(alternative link)](https://goo.gl/tfUaDm). Please run `python train_blink_model.py --help` to see the required arguments to train the model. 35 | ### Pytorch (experimental) 36 | This code attempts to duplicate the Tensorflow version using Pytorch and Pytorch-Lightning. It uses the same dataset. An HDF5 file is required; to generate it, run [GenerateRTBENEH5Dataset.py](pytorch/util/GenerateRTBENEH5Dataset.py) with the `--rt_bene_root` argument pointing to the root directory of the RT-BENE dataset. 37 | [train_model](pytorch/train_model.py) contains the code required to train the model in PyTorch. 38 | ### Eyepatch generation 39 | See [this script](https://github.com/Tobias-Fischer/rt_gene/blob/master/rt_gene_model_training/pytorch/utils/GenerateEyePatchesRTGENEDataset.py) to generate eye patches for model training (only needed for custom datasets). 40 | 41 | 42 | ## Model testing code 43 | ### Tensorflow 44 | Evaluation code for the 3-fold protocol is provided in [evaluate_blink_model.py](tensorflow/evaluate_blink_model.py). An example that trains and evaluates an ensemble of models can be found in [train_and_evaluate.py](tensorflow/train_and_evaluate.py). Please run `python train_and_evaluate.py --help` to see the required arguments.
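As a rough illustration of the ensemble scheme used in `threefold_evaluation`, the sketch below averages the sigmoid outputs of several trained fold models. It is not part of the training scripts; the `.h5` file names and the `(96, 96, 3)` input size are placeholders that depend on how `train_blink_model.py` was invoked.

```python
# Minimal ensemble sketch: average the blink probabilities of several fold models.
import tensorflow as tf
from tensorflow.keras.models import load_model

model_paths = ["rt-bene_densenet121_fold1_best.h5", "rt-bene_densenet121_fold2_best.h5"]  # placeholder paths
models = [load_model(path, compile=False) for path in model_paths]

img_input_l = tf.keras.Input(shape=(96, 96, 3), name="img_input_L")
img_input_r = tf.keras.Input(shape=(96, 96, 3), name="img_input_R")
# each trained model takes [right_eye, left_eye] batches and outputs a blink probability
averaged = tf.keras.layers.average([model([img_input_r, img_input_l]) for model in models])
ensemble = tf.keras.Model(inputs=[img_input_r, img_input_l], outputs=averaged)

# probabilities = ensemble.predict([right_eye_batch, left_eye_batch])
# blinks = probabilities >= 0.5  # threshold used in evaluate_blink_model.py
```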
45 | 46 | ![Results](../assets/rt_bene_precision_recall.png) 47 | ### Pytorch 48 | Evaluation code for 3-fold validation is in [evaluate_model.py](pytorch/evaluate_model.py) 49 | ![Results](../assets/rtbene_pytorch_resnet18.png) 50 | -------------------------------------------------------------------------------- /rt_bene_model_training/pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_bene_model_training/pytorch/__init__.py -------------------------------------------------------------------------------- /rt_bene_model_training/pytorch/post_process_ckpt.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | import os 4 | from argparse import ArgumentParser 5 | from functools import partial 6 | from glob import glob 7 | from pathlib import Path 8 | import torch 9 | from tqdm import tqdm 10 | from rt_bene.blink_estimation_models_pytorch import BlinkEstimationModelResnet18, BlinkEstimationModelResnet50, \ 11 | BlinkEstimationModelVGG16, BlinkEstimationModelVGG19, BlinkEstimationModelDenseNet121 12 | 13 | if __name__ == "__main__": 14 | _root_parser = ArgumentParser(add_help=False) 15 | root_dir = os.path.dirname(os.path.realpath(__file__)) 16 | _root_parser.add_argument('--ckpt_dir', type=str, default=os.path.abspath( 17 | os.path.join(root_dir, '../../rt_bene_model_training/pytorch/checkpoints/'))) 18 | _root_parser.add_argument('--save_dir', type=str, default=os.path.abspath( 19 | os.path.join(root_dir, '../../rt_bene_model_training/pytorch/model_nets/'))) 20 | _root_parser.add_argument('--model_base', choices=["vgg16", "vgg19", "resnet18", "resnet50", "densenet121"], 21 | default="densenet121") 22 | _params = _root_parser.parse_args() 23 | 24 | _models = { 25 | "resnet18": BlinkEstimationModelResnet18, 26 | "resnet50": BlinkEstimationModelResnet50, 27 | "vgg16": BlinkEstimationModelVGG16, 28 | "vgg19": BlinkEstimationModelVGG19, 29 | "densenet121": BlinkEstimationModelDenseNet121 30 | } 31 | 32 | # create save dir 33 | Path(_params.save_dir).mkdir(parents=True, exist_ok=True) 34 | 35 | _model = _models.get(_params.model_base)() 36 | for ckpt in tqdm(glob(os.path.join(_params.ckpt_dir, "*.ckpt")), desc="Processing..."): 37 | filename, file_extension = os.path.splitext(ckpt) 38 | filename = os.path.basename(filename) 39 | _torch_load = torch.load(ckpt)['state_dict'] 40 | 41 | # the ckpt file saves the pytorch_lightning module which includes it's child members. The only child member we're interested in is the "_model". 42 | # Loading the state_dict with _model creates an error as the model tries to find a child called _model within it that doesn't 43 | # exist. Thus remove _model from the dictionary and all is well. 
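        # e.g. a Lightning checkpoint key "_model.features.0.weight" becomes "features.0.weight";
        # keys that do not start with "_model." are dropped by the comprehension below.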
44 | _state_dict = dict(_torch_load.items()) 45 | _state_dict = {k[7:]: v for k, v in _state_dict.items() if k.startswith("_model.")} 46 | _model.load_state_dict(_state_dict) 47 | torch.save(_model.state_dict(), os.path.join(_params.save_dir, f"{filename}.model")) 48 | -------------------------------------------------------------------------------- /rt_bene_model_training/pytorch/rtbene_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from PIL import Image 5 | from torch.utils import data 6 | from torchvision import transforms 7 | from tqdm import tqdm 8 | import h5py 9 | 10 | 11 | class RTBENEH5Dataset(data.Dataset): 12 | 13 | def __init__(self, h5_pth, subject_list=None, transform=None, loader_desc="train"): 14 | self._h5_file = h5_pth 15 | self._transform = transform 16 | self._subject_labels = [] 17 | 18 | assert subject_list is not None, "Must pass a list of subjects to load the data for" 19 | 20 | if self._transform is None: 21 | self._transform = transforms.Compose([transforms.Resize((36, 60), transforms.InterpolationMode.BICUBIC), 22 | transforms.ToTensor(), 23 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 24 | std=[0.229, 0.224, 0.225])]) 25 | 26 | _wanted_subjects = ["s{:03d}".format(_i) for _i in subject_list] 27 | 28 | with h5py.File(self._h5_file, mode="r") as h5_file: 29 | for grp_s_n in tqdm(_wanted_subjects, desc="Loading ({}) subject metadata...".format(loader_desc), position=0): # subjects 30 | for grp_i_n, grp_i in h5_file[grp_s_n].items(): # images 31 | if "left" in grp_i.keys() and "right" in grp_i.keys() and "label" in grp_i.keys(): 32 | left_dataset = grp_i["left"] 33 | right_datset = grp_i['right'] 34 | 35 | assert len(left_dataset) == len( 36 | right_datset), "Weird: Dataset left/right images aren't equal length" 37 | for _i in range(len(left_dataset)): 38 | self._subject_labels.append(["/" + grp_s_n + "/" + grp_i_n, _i]) 39 | 40 | @staticmethod 41 | def get_class_weights(h5_file, subject_list): 42 | positive = 0 43 | total = 0 44 | _wanted_subjects = ["s{:03d}".format(_i) for _i in subject_list] 45 | 46 | for grp_s_n in tqdm(_wanted_subjects, desc="Loading class weights...", position=0): 47 | for grp_i_n, grp_i in h5_file[grp_s_n].items(): # images 48 | if "left" in grp_i.keys() and "right" in grp_i.keys() and "label" in grp_i.keys(): 49 | label = grp_i["label"][()][0] 50 | if label == 1.0: 51 | positive = positive + 1 52 | total = total + 1 53 | 54 | negative = total - positive 55 | weight_for_0 = (negative + positive) / negative 56 | weight_for_1 = (negative + positive) / positive 57 | return {0: weight_for_0, 1: weight_for_1} 58 | 59 | def __len__(self): 60 | return len(self._subject_labels) 61 | 62 | def __getitem__(self, index): 63 | sample = self._subject_labels[index] 64 | 65 | with h5py.File(self._h5_file, mode="r") as h5_file: 66 | left_img = h5_file[sample[0] + "/left"][sample[1]][()][0] 67 | right_img = h5_file[sample[0] + "/right"][sample[1]][()][0] 68 | label = h5_file[sample[0] + "/label"][()].astype(float) 69 | 70 | # Load data and get label 71 | transformed_left_img = self._transform(Image.fromarray(left_img, 'RGB')) 72 | transformed_right_img = self._transform(Image.fromarray(right_img, 'RGB')) 73 | 74 | return transformed_left_img, transformed_right_img, label 75 | -------------------------------------------------------------------------------- /rt_bene_model_training/pytorch/util/GenerateRTBENEH5Dataset.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import h5py 5 | import numpy as np 6 | from PIL import Image 7 | from tqdm import tqdm 8 | 9 | script_path = os.path.dirname(os.path.realpath(__file__)) 10 | 11 | 12 | _required_size = (224, 224) 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser(description='Generate an HDF5 dataset from the RT-BENE blink images and labels') 16 | parser.add_argument('--rt_bene_root', type=str, required=True, nargs='?', 17 | help='Path to the base directory of RT-BENE') 18 | parser.add_argument('--compress', action='store_true', dest="compress", help="Whether to use LZF compression or not") 19 | parser.add_argument('--no-compress', action='store_false', dest="compress") 20 | parser.set_defaults(compress=False) 21 | args = parser.parse_args() 22 | 23 | _compression = "lzf" if args.compress is True else None 24 | 25 | subject_path = [os.path.join(args.rt_bene_root, "s{:03d}_noglasses/".format(_i)) for _i in range(0, 17)] 26 | 27 | with h5py.File(os.path.abspath(os.path.join(args.rt_bene_root, "rtbene_dataset.hdf5")), mode='w') as hdf_file: 28 | for subject_id, subject_data in enumerate(subject_path): 29 | subject_id = str("s{:03d}".format(subject_id)) 30 | subject_grp = hdf_file.create_group(subject_id) 31 | with open(os.path.join(args.rt_bene_root, "{}_blink_labels.csv".format(subject_id)), "r") as f: 32 | _lines = f.readlines() 33 | 34 | for line in tqdm(_lines, desc="Subject {}".format(subject_id)): 35 | split = line.split(",") 36 | image_name = split[0][5:] 37 | image_grp = subject_grp.create_group(image_name.split("_")[0]) 38 | left_image_path = os.path.join(subject_data, "natural/left/", "left_{}".format(image_name)) 39 | right_image_path = os.path.join(subject_data, "natural/right/", "right_{}".format(image_name)) 40 | if os.path.exists(left_image_path) and os.path.exists(right_image_path): 41 | label = float(split[1].strip("\n")) 42 | if label != 0.5: # samples where the annotators disagreed (label 0.5) are discarded, as in the paper 43 | left_image_data = np.array([np.array(Image.open(left_image_path).resize(_required_size))]) 44 | right_image_data = np.array([np.array(Image.open(right_image_path).resize(_required_size))]) 45 | image_grp.create_dataset("left", data=left_image_data, compression=_compression) 46 | image_grp.create_dataset("right", data=right_image_data, compression=_compression) 47 | image_grp.create_dataset("label", data=[label]) 48 | -------------------------------------------------------------------------------- /rt_bene_model_training/pytorch/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_bene_model_training/pytorch/util/__init__.py -------------------------------------------------------------------------------- /rt_bene_model_training/tensorflow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_bene_model_training/tensorflow/__init__.py -------------------------------------------------------------------------------- /rt_bene_model_training/tensorflow/dataset_manager.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import csv 4 | import itertools 5 | 6 | import cv2 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | 11 | def read_rgb_image(img_path, size, flip): 
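    # Read the image in BGR order with OpenCV, optionally mirror it horizontally, and resize it to `size`.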
assert type(size) is tuple, "size parameter must be a tuple, (96, 96) for instance" 13 | img = cv2.imread(img_path, cv2.IMREAD_COLOR) 14 | if img is None: 15 | print("ERROR: can't read " + img_path) 16 | if flip: 17 | img = cv2.flip(img, 1) 18 | img = cv2.resize(img, size, cv2.INTER_CUBIC) 19 | return img 20 | 21 | 22 | def load_one_flipped_pair(l_path, r_path, size): 23 | l_img = read_rgb_image(l_path, size, flip=False) 24 | r_img = read_rgb_image(r_path, size, flip=True) 25 | return l_img, r_img 26 | 27 | 28 | class RTBeneDataset(object): 29 | def __init__(self, csv_subject_list, input_size): 30 | self.csv_subject_list = csv_subject_list 31 | self.input_size = input_size 32 | self.subjects = {} 33 | self.training_set = {} 34 | self.validation_set = {} 35 | self.folds = {} 36 | 37 | self.load() 38 | 39 | def load_one_subject(self, csv_labels, left_folder, right_folder): 40 | subject = {'y': []} 41 | 42 | left_inputs = [] 43 | right_inputs = [] 44 | 45 | with open(csv_labels) as csvfile: 46 | csv_rows = csv.reader(csvfile) 47 | for row in tqdm(csv_rows): 48 | img_name = row[0] 49 | img_lbl = float(row[1]) 50 | if img_lbl == 0.5: # annotators did not agree whether eye is open or not, so discard this sample 51 | continue 52 | left_img_path = left_folder + img_name 53 | right_img_path = right_folder + img_name.replace("left", "right") 54 | try: 55 | left_img, right_img = load_one_flipped_pair(left_img_path, right_img_path, self.input_size) 56 | left_inputs.append(left_img) 57 | right_inputs.append(right_img) 58 | subject['y'].append(img_lbl) 59 | except: 60 | print('Failure loading pair ' + left_img_path + ' ' + right_img_path) 61 | subject['x'] = [np.array(left_inputs), np.array(right_inputs)] 62 | return subject 63 | 64 | def load(self): 65 | with open(self.csv_subject_list) as csvfile: 66 | csv_rows = csv.reader(csvfile) 67 | for row in csv_rows: 68 | subject_id = int(row[0]) 69 | csv_labels = row[1] 70 | left_folder = row[2] 71 | right_folder = row[3] 72 | fold_type = row[4] 73 | fold_id = int(row[5]) 74 | 75 | if fold_type == 'discarded': 76 | print('\nsubject ' + str(subject_id) + ' is discarded.') 77 | else: 78 | print('\nsubject ' + str(subject_id) + ' is loading...') 79 | csv_filename = self.csv_subject_list.split('/')[-1] 80 | csv_labels = self.csv_subject_list.replace(csv_filename, csv_labels) 81 | left_folder = self.csv_subject_list.replace(csv_filename, left_folder) 82 | right_folder = self.csv_subject_list.replace(csv_filename, right_folder) 83 | 84 | if fold_type == 'training': 85 | self.training_set[subject_id] = self.load_one_subject(csv_labels, left_folder, right_folder) 86 | if fold_id not in self.folds.keys(): 87 | self.folds[fold_id] = [] 88 | self.folds[fold_id].append(subject_id) 89 | elif fold_type == 'validation': 90 | self.validation_set[subject_id] = self.load_one_subject(csv_labels, left_folder, right_folder) 91 | 92 | @staticmethod 93 | def get_data(dataset, subject_list): 94 | all_x_left = [dataset[subject_id]['x'][0] for subject_id in subject_list] 95 | all_x_right = [dataset[subject_id]['x'][1] for subject_id in subject_list] 96 | all_y = [np.array(dataset[subject_id]['y']) for subject_id in subject_list] 97 | fold = {'x': [np.concatenate(all_x_right), np.concatenate(all_x_left)], 'y': np.concatenate(all_y)} 98 | fold['positive'] = np.count_nonzero(fold['y'] == 1.) 99 | fold['negative'] = np.count_nonzero(fold['y'] == 0.) 
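        # the positive/negative counts computed above are used later to derive class weights for training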
100 | fold['y'] = fold['y'].tolist() 101 | return fold 102 | 103 | def get_training_data(self, fold_ids): 104 | subject_list = list(itertools.chain(*[self.folds[fold_id] for fold_id in fold_ids])) 105 | return self.get_data(self.training_set, subject_list) 106 | 107 | def get_validation_data(self): 108 | subject_list = self.validation_set.keys() 109 | return self.get_data(self.validation_set, subject_list) 110 | -------------------------------------------------------------------------------- /rt_bene_model_training/tensorflow/evaluate_blink_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import gc 4 | 5 | import tensorflow as tf 6 | from tensorflow.keras.models import load_model 7 | 8 | from sklearn.metrics import confusion_matrix, roc_curve, auc, average_precision_score 9 | 10 | import numpy as np 11 | 12 | tf.compat.v1.disable_eager_execution() 13 | 14 | config = tf.compat.v1.ConfigProto() 15 | config.gpu_options.allow_growth = True 16 | tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config)) 17 | 18 | 19 | fold_infos = { 20 | 'fold1': [2], 21 | 'fold2': [1], 22 | 'fold3': [0], 23 | 'all': [2, 1, 0] 24 | } 25 | 26 | model_metrics = [tf.keras.metrics.BinaryAccuracy()] 27 | 28 | 29 | def estimate_metrics(testing_fold, model_instance): 30 | threshold = 0.5 31 | p = model_instance.predict(x=testing_fold['x'], verbose=0) 32 | p = p >= threshold 33 | matrix = confusion_matrix(testing_fold['y'], p) 34 | ap = average_precision_score(testing_fold['y'], p) 35 | fpr, tpr, thresholds = roc_curve(testing_fold['y'], p) 36 | roc = auc(fpr, tpr) 37 | return matrix, ap, roc 38 | 39 | 40 | def get_metrics_from_matrix(matrix): 41 | tp, tn, fp, fn = matrix[1, 1], matrix[0, 0], matrix[0, 1], matrix[1, 0] 42 | precision = tp / (tp + fp) 43 | recall = tp / (tp + fn) 44 | f1score = 2. 
* (precision * recall) / (precision + recall) 45 | return precision, recall, f1score 46 | 47 | 48 | def threefold_evaluation(dataset, model_paths_fold1, model_paths_fold2, model_paths_fold3, input_size): 49 | folds = ['fold1', 'fold2', 'fold3'] 50 | aps = [] 51 | rocs = [] 52 | recalls = [] 53 | precisions = [] 54 | f1scores = [] 55 | models = [] 56 | 57 | for fold_to_eval_on, model_paths in zip(folds, [model_paths_fold1, model_paths_fold2, model_paths_fold3]): 58 | if len(model_paths_fold1) > 1: 59 | models = [load_model(model_path, compile=False) for model_path in model_paths] 60 | img_input_l = tf.keras.Input(shape=input_size, name='img_input_L') 61 | img_input_r = tf.keras.Input(shape=input_size, name='img_input_R') 62 | tensors = [model([img_input_r, img_input_l]) for model in models] 63 | output_layer = tf.keras.layers.average(tensors) 64 | model_instance = tf.keras.Model(inputs=[img_input_r, img_input_l], outputs=output_layer) 65 | else: 66 | model_instance = load_model(model_paths[0]) 67 | model_instance.compile() 68 | 69 | testing_fold = dataset.get_training_data(fold_infos[fold_to_eval_on]) # get the testing fold subjects 70 | 71 | matrix, ap, roc = estimate_metrics(testing_fold, model_instance) 72 | aps.append(ap) 73 | rocs.append(roc) 74 | precision, recall, f1score = get_metrics_from_matrix(matrix) 75 | recalls.append(recall) 76 | precisions.append(precision) 77 | f1scores.append(f1score) 78 | 79 | del model_instance, testing_fold 80 | # noinspection PyUnusedLocal 81 | for model in models: 82 | del model 83 | gc.collect() 84 | 85 | evaluation = {'AP': {}, 'ROC': {}, 'precision': {}, 'recall': {}, 'f1score': {}} 86 | evaluation['AP']['avg'] = np.mean(np.array(aps)) 87 | evaluation['AP']['std'] = np.std(np.array(aps)) 88 | evaluation['ROC']['avg'] = np.mean(np.array(rocs)) 89 | evaluation['ROC']['std'] = np.std(np.array(rocs)) 90 | evaluation['precision']['avg'] = np.mean(np.array(precisions)) 91 | evaluation['precision']['std'] = np.std(np.array(precisions)) 92 | evaluation['recall']['avg'] = np.mean(np.array(recalls)) 93 | evaluation['recall']['std'] = np.std(np.array(recalls)) 94 | evaluation['f1score']['avg'] = np.mean(np.array(f1scores)) 95 | evaluation['f1score']['std'] = np.std(np.array(f1scores)) 96 | return evaluation 97 | -------------------------------------------------------------------------------- /rt_bene_model_training/tensorflow/train_and_evaluate.py: -------------------------------------------------------------------------------- 1 | from evaluate_blink_model import threefold_evaluation 2 | from train_blink_model import ThreefoldTraining 3 | from dataset_manager import RTBeneDataset 4 | from pathlib import Path 5 | import argparse 6 | import pprint 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("model_save_root", help="target folder to save the models (auto-saved)") 11 | parser.add_argument("csv_subject_list", help="path to the dataset csv file") 12 | parser.add_argument("--ensemble_size", type=int, default=1, help="number of models to train for the ensemble") 13 | parser.add_argument("--batch_size", type=int, default=64) 14 | parser.add_argument("--epochs", type=int, default=15) 15 | parser.add_argument("--input_size", type=tuple, help="input size of images", default=(96, 96)) 16 | 17 | args = parser.parse_args() 18 | 19 | fold_list = ['fold1', 'fold2', 'fold3'] 20 | ensemble_size = args.ensemble_size # 1 is considered as single model 21 | epochs = args.epochs 22 | batch_size = args.batch_size 23 | input_size 
= args.input_size 24 | csv_subject_list = args.csv_subject_list 25 | model_save_root = args.model_save_root 26 | 27 | dataset = RTBeneDataset(csv_subject_list, input_size) 28 | 29 | threefold_training = ThreefoldTraining(dataset, epochs, batch_size, input_size) 30 | 31 | all_evaluations = {} 32 | 33 | for backbone in ['densenet121', 'resnet50', 'mobilenetv2']: 34 | models_fold1 = [] 35 | models_fold2 = [] 36 | models_fold3 = [] 37 | 38 | for i in range(1, ensemble_size + 1): 39 | model_save_path = Path(model_save_root + backbone + '/' + str(i)) 40 | model_save_path.mkdir(parents=True, exist_ok=True) 41 | threefold_training.train(backbone, str(model_save_path) + '/') 42 | 43 | models_fold1.append(str(model_save_path) + '/rt-bene_' + backbone + '_fold1_best.h5') 44 | models_fold2.append(str(model_save_path) + '/rt-bene_' + backbone + '_fold2_best.h5') 45 | models_fold3.append(str(model_save_path) + '/rt-bene_' + backbone + '_fold3_best.h5') 46 | 47 | evaluation = threefold_evaluation(dataset, models_fold1, models_fold2, models_fold3, input_size) 48 | all_evaluations[backbone] = evaluation 49 | 50 | threefold_training.free() 51 | 52 | pprint.pprint(all_evaluations) 53 | -------------------------------------------------------------------------------- /rt_bene_model_training/tensorflow/train_blink_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import gc 4 | import argparse 5 | 6 | import tensorflow as tf 7 | from tensorflow.keras.callbacks import ModelCheckpoint 8 | from tensorflow.keras.models import Model 9 | from tensorflow.keras.layers import Dense, Input, Dropout, BatchNormalization, Average, ReLU 10 | from tensorflow.keras.optimizers import Adam 11 | from tensorflow.keras.regularizers import l2 12 | 13 | from dataset_manager import RTBeneDataset 14 | 15 | tf.compat.v1.disable_eager_execution() 16 | 17 | config = tf.compat.v1.ConfigProto() 18 | config.gpu_options.allow_growth = True 19 | tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config)) 20 | 21 | 22 | def create_model_base(backbone, input_shape): 23 | if backbone == 'mobilenetv2': 24 | base = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_tensor=None, 25 | input_shape=input_shape, pooling='avg') 26 | 27 | elif backbone == 'densenet121': 28 | base = tf.keras.applications.DenseNet121(include_top=False, weights='imagenet', input_tensor=None, 29 | input_shape=input_shape, pooling='avg') 30 | 31 | elif backbone == 'resnet50': 32 | base = tf.keras.applications.ResNet50(include_top=False, weights='imagenet', input_tensor=None, 33 | input_shape=input_shape, pooling='avg') 34 | else: 35 | raise Exception('Wrong backbone') 36 | 37 | for layer in base.layers: 38 | layer.trainable = True 39 | 40 | main_input = Input(shape=input_shape) 41 | main_output = base(main_input) 42 | 43 | final_fc_layer = Dense(512, kernel_regularizer=l2(0.01), bias_regularizer=l2(0.01))(main_output) 44 | final_fc_layer = BatchNormalization(epsilon=1e-3, momentum=0.999)(final_fc_layer) 45 | final_fc_layer = ReLU(6.)(final_fc_layer) 46 | final_fc_layer = Dropout(0.6)(final_fc_layer) 47 | output_tensor = Dense(1, activation='sigmoid')(final_fc_layer) # probability 48 | 49 | model = Model(inputs=main_input, outputs=output_tensor) 50 | 51 | return model 52 | 53 | 54 | def create_model(backbone, input_shape, lr, metrics): 55 | base = create_model_base(backbone, input_shape) 56 | 57 | # define the 2 inputs (left and right eyes) 58 | left_input = 
Input(shape=input_shape) 59 | right_input = Input(shape=input_shape) 60 | 61 | # get the 2 outputs using shared layers 62 | out_left = base(left_input) 63 | out_right = base(right_input) 64 | 65 | # average the predictions 66 | merged = Average()([out_left, out_right]) 67 | model = Model(inputs=[right_input, left_input], outputs=merged) 68 | 69 | model.compile(loss='binary_crossentropy', optimizer=Adam(lr=lr), metrics=metrics) 70 | model.summary() 71 | 72 | return model 73 | 74 | 75 | class ThreefoldTraining(object): 76 | def __init__(self, dataset, epochs, batch_size, input_size): 77 | self.fold_map = {'fold1': [0, 1], 'fold2': [0, 2], 'fold3': [1, 2]} 78 | self.model_metrics = [tf.keras.metrics.BinaryAccuracy(), tf.keras.metrics.Recall(), tf.keras.metrics.Precision()] 79 | self.dataset = dataset 80 | self.validation_set = dataset.get_validation_data() 81 | self.epochs = epochs 82 | self.batch_size = batch_size 83 | self.input_size = [input_size[0], input_size[1], 3] 84 | self.learning_rate = 1e-4 85 | 86 | def train(self, backbone, model_save_path): 87 | for fold_name, training_subjects_fold in self.fold_map.items(): 88 | training_set = self.dataset.get_training_data(training_subjects_fold) 89 | positive = training_set['positive'] 90 | negative = training_set['negative'] 91 | 92 | print('Number of positive samples in training data: {} ({:.2f}% of total)'. 93 | format(positive, 100 * float(positive) / len(training_set['y']))) 94 | 95 | model_instance = create_model(backbone, self.input_size, self.learning_rate, self.model_metrics) 96 | name = 'rt-bene_' + backbone + '_' + fold_name 97 | 98 | weight_for_0 = 1. / negative * (negative + positive) 99 | weight_for_1 = 1. / positive * (negative + positive) 100 | class_weight = {0: weight_for_0, 1: weight_for_1} 101 | 102 | save_best = ModelCheckpoint(model_save_path + name + '_best.h5', monitor='val_loss', verbose=1, 103 | save_best_only=True, save_weights_only=False, mode='min', period=1) 104 | auto_save = ModelCheckpoint(model_save_path + name + '_auto_{epoch:02d}.h5', verbose=1, 105 | save_best_only=False, save_weights_only=False, period=1) 106 | 107 | # train the model 108 | model_instance.fit(x=training_set['x'], y=training_set['y'], 109 | batch_size=self.batch_size, epochs=self.epochs, 110 | verbose=1, 111 | validation_data=(self.validation_set['x'], self.validation_set['y']), 112 | callbacks=[save_best, auto_save], 113 | class_weight=class_weight) 114 | # noinspection PyUnusedLocal 115 | model_instance, training_set = None, None 116 | del model_instance, training_set 117 | gc.collect() 118 | 119 | def free(self): 120 | self.validation_set = None 121 | del self.validation_set 122 | 123 | 124 | if __name__ == "__main__": 125 | parser = argparse.ArgumentParser() 126 | parser.add_argument("backbone", choices=['densenet121', 'resnet50', 'mobilenetv2']) 127 | parser.add_argument("model_save_path", help="target folder to save the models (auto-saved)") 128 | parser.add_argument("csv_subject_list", help="path to the dataset csv file") 129 | parser.add_argument("--batch_size", type=int, default=64) 130 | parser.add_argument("--epochs", type=int, default=15) 131 | parser.add_argument("--input_size", type=tuple, help="input size of images", default=(96, 96)) 132 | 133 | args = parser.parse_args() 134 | 135 | rtbene_dataset = RTBeneDataset(args.csv_subject_list, args.input_size) 136 | 137 | threefold_training = ThreefoldTraining(rtbene_dataset, args.epochs, args.batch_size, args.input_size) 138 | threefold_training.train(args.backbone, 
args.model_save_path + '/') 139 | threefold_training.free() 140 | -------------------------------------------------------------------------------- /rt_bene_standalone/README.md: -------------------------------------------------------------------------------- 1 | # RT-BENE: A Dataset and Baselines for Real-Time Blink Estimation in Natural Environments 2 | ## License + Attribution 3 | This code is licensed under [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/). Commercial usage is not permitted; please contact the authors regarding commercial licensing. If you use this dataset or the code in a scientific publication, please cite the following [paper](http://openaccess.thecvf.com/content_ICCVW_2019/html/GAZE/Cortacero_RT-BENE_A_Dataset_and_Baselines_for_Real-Time_Blink_Estimation_in_ICCVW_2019_paper.html): 4 | 5 | ``` 6 | @inproceedings{CortaceroICCV2019W, 7 | author={Kevin Cortacero and Tobias Fischer and Yiannis Demiris}, 8 | booktitle = {Proceedings of the IEEE International Conference on Computer Vision Workshops}, 9 | title = {RT-BENE: A Dataset and Baselines for Real-Time Blink Estimation in Natural Environments}, 10 | year = {2019}, 11 | } 12 | ``` 13 | 14 | RT-BENE was supported by the EU Horizon 2020 Project PAL (643783-RIA) and a Royal Academy of Engineering Chair in Emerging Technologies to Yiannis Demiris. 15 | 16 | More information can be found on the Personal Robotics Lab's website. 17 | 18 | ## Requirements 19 | Please follow the steps given in the Requirements section of the [RT-GENE standalone version](../rt_gene_standalone/README.md). 20 | 21 | ## Basic usage 22 | - Run `$HOME/rt_gene/rt_bene_standalone/estimate_blink_standalone.py`. For supported arguments, run `$HOME/rt_gene/rt_bene_standalone/estimate_blink_standalone.py --help` 23 | 24 | ### Optional ensemble model files 25 | - To use an ensemble of multiple models, pass several files to the `--model` argument, e.g. `cd $HOME/rt_gene/ && ./rt_bene_standalone/estimate_blink_standalone.py --model './rt_gene/model_nets/blink_model_1.h5' './rt_gene/model_nets/blink_model_2.h5'` (add `--blink_backend tensorflow` when using the `.h5` TensorFlow models). 26 | 27 | ## List of libraries 28 | See [main README.md](../rt_gene/README.md) 29 | 30 | -------------------------------------------------------------------------------- /rt_bene_standalone/estimate_blink_standalone.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import print_function, division, absolute_import 4 | 5 | import argparse 6 | import os 7 | import time 8 | from os import listdir 9 | 10 | import cv2 11 | import numpy as np 12 | 13 | script_path = os.path.dirname(os.path.realpath(__file__)) 14 | 15 | 16 | def str2bool(v): 17 | if isinstance(v, bool): 18 | return v 19 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 20 | return True 21 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 22 | return False 23 | else: 24 | raise argparse.ArgumentTypeError('Boolean value expected.') 25 | 26 | 27 | class BlinkEstimatorFolderPair(object): 28 | def __init__(self, blink_estimator, viz): 29 | self.blink_estimator = blink_estimator 30 | self.viz = viz 31 | 32 | def estimate(self, left_folder_path, right_folder_path): 33 | left_image_paths, right_image_paths = [], [] 34 | left_images, right_images = [], [] 35 | 36 | for left_image_name in sorted(listdir(left_folder_path)): 37 | left_image_path = left_folder_path + '/' + left_image_name 38 | left_image_paths.append(left_image_path) 39 | left_images.append(cv2.imread(left_image_path, cv2.IMREAD_COLOR)) 40 
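        # gather the corresponding right-eye images; pairing relies on both folders being sorted identically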
| 41 | for right_image_name in sorted(listdir(right_folder_path)): 42 | right_image_path = right_folder_path + '/' + right_image_name 43 | right_image_paths.append(right_image_path) 44 | right_images.append(cv2.imread(right_image_path, cv2.IMREAD_COLOR)) 45 | 46 | l_images_input, r_images_input = [], [] 47 | for l_img, r_img in zip(left_images, right_images): 48 | l_img_input, r_img_input = self.blink_estimator.inputs_from_images(l_img, r_img) 49 | l_images_input.append(l_img_input) 50 | r_images_input.append(r_img_input) 51 | 52 | start_time = time.time() 53 | probs = self.blink_estimator.predict(l_images_input, r_images_input) 54 | blinks = probs >= self.blink_estimator.threshold 55 | print( 56 | "Estimated blink for {} eye-image pairs, Time: {:.5f}s".format(len(left_images), time.time() - start_time)) 57 | if self.viz: 58 | for left_image, right_image, is_blinking in zip(left_images, right_images, blinks): 59 | pair_img = np.concatenate((right_image, left_image), axis=1) 60 | viz_img = self.blink_estimator.overlay_prediction_over_img(pair_img, is_blinking) 61 | cv2.imshow('folder images visualisation', viz_img) 62 | cv2.waitKey(0) 63 | for left_image, right_image, p, is_blinking in zip(left_image_paths, right_image_paths, probs, blinks): 64 | print("Blink: %s (p=%.3f) for image pair: %20s %20s" % ("Yes" if is_blinking else "No ", p, 65 | os.path.basename(left_image), 66 | os.path.basename(right_image))) 67 | 68 | 69 | if __name__ == '__main__': 70 | parser = argparse.ArgumentParser(description='Estimate blink from image or folder pair.') 71 | parser.add_argument('--left', type=str, help='Path to a left eye image or a directory containing left eye images', 72 | default=os.path.join(script_path, './samples_blink/left/')) 73 | parser.add_argument('--right', type=str, 74 | help='Path to a right eye image or a directory containing images right eye images', 75 | default=os.path.join(script_path, './samples_blink/right/')) 76 | parser.add_argument('--model', nargs='+', type=str, 77 | default=[os.path.abspath(os.path.join(script_path, '../rt_gene/model_nets/blink_model_pytorch_vgg16_allsubjects1.model'))], 78 | help='List of blink estimators') 79 | parser.add_argument('--model_type', type=str, default="vgg16") 80 | parser.add_argument('--threshold', type=float, default=0.5, 81 | help='Threshold to determine weither the prediction is positive or negative') 82 | parser.add_argument('--device_id', type=str, default="cuda") 83 | parser.add_argument('--blink_backend', type=str, choices=['tensorflow', 'pytorch'], default='pytorch') 84 | parser.add_argument('--vis_blink', type=str2bool, nargs='?', default=True, 85 | help='Show the overlayed result on original image or not') 86 | 87 | args = parser.parse_args() 88 | left_path = args.left 89 | right_path = args.right 90 | 91 | if args.blink_backend == "tensorflow": 92 | from rt_bene.estimate_blink_tensorflow import BlinkEstimatorTensorflow 93 | blink_estimator = BlinkEstimatorTensorflow(device_id_blink=args.device_id, threshold=0.425, model_files=args.model, model_type=args.model_type) 94 | elif args.blink_backend == "pytorch": 95 | from rt_bene.estimate_blink_pytorch import BlinkEstimatorPytorch 96 | blink_estimator = BlinkEstimatorPytorch(device_id_blink=args.device_id, threshold=0.425, model_files=args.model, model_type=args.model_type) 97 | else: 98 | raise Exception("Unknown backend") 99 | 100 | if os.path.isdir(left_path) and os.path.isdir(right_path): 101 | blink_folder = BlinkEstimatorFolderPair(blink_estimator, viz=args.vis_blink) 102 | 
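        # run blink estimation on every left/right image pair found in the two folders and print per-pair results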
blink_folder.estimate(left_path, right_path) 103 | if args.vis_blink: 104 | cv2.destroyAllWindows() 105 | else: 106 | raise Exception('Folders not found: Check that ' + left_path + ' and ' + right_path + ' exist') 107 | 108 | 109 | -------------------------------------------------------------------------------- /rt_bene_standalone/samples_blink/left/left_blink.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_bene_standalone/samples_blink/left/left_blink.png -------------------------------------------------------------------------------- /rt_bene_standalone/samples_blink/left/left_open.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_bene_standalone/samples_blink/left/left_open.png -------------------------------------------------------------------------------- /rt_bene_standalone/samples_blink/right/right_blink.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_bene_standalone/samples_blink/right/right_blink.png -------------------------------------------------------------------------------- /rt_bene_standalone/samples_blink/right/right_open.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_bene_standalone/samples_blink/right/right_open.png -------------------------------------------------------------------------------- /rt_gene/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.3) 2 | project(rt_gene) 3 | 4 | find_package(catkin REQUIRED COMPONENTS 5 | rospy 6 | message_generation 7 | std_msgs 8 | sensor_msgs 9 | geometry_msgs 10 | image_geometry 11 | cv_bridge 12 | image_transport 13 | tf 14 | uvc_camera 15 | dynamic_reconfigure 16 | ) 17 | 18 | catkin_python_setup() 19 | 20 | ################################################ 21 | ## Declare ROS messages, services, dynamic reconfigure and actions ## 22 | ################################################ 23 | 24 | # Generate messages in the 'msg' folder 25 | add_message_files( 26 | FILES 27 | MSG_SubjectImages.msg 28 | MSG_SubjectImagesList.msg 29 | MSG_Gaze.msg 30 | MSG_Headpose.msg 31 | MSG_Landmarks.msg 32 | MSG_Blink.msg 33 | MSG_GazeList.msg 34 | MSG_HeadposeList.msg 35 | MSG_LandmarksList.msg 36 | MSG_BlinkList.msg 37 | ) 38 | 39 | # Generate added messages and services with any dependencies listed here 40 | generate_messages( 41 | DEPENDENCIES 42 | std_msgs 43 | sensor_msgs 44 | geometry_msgs 45 | ) 46 | 47 | # Generate the dynamic reconfigure options 48 | generate_dynamic_reconfigure_options( 49 | cfg/ModelSize.cfg 50 | ) 51 | 52 | ################################### 53 | ## catkin specific configuration ## 54 | ################################### 55 | ## The catkin_package macro generates cmake config files for your package 56 | ## Declare things to be passed to dependent projects 57 | ## INCLUDE_DIRS: uncomment this if you package contains header files 58 | ## LIBRARIES: libraries you create in this project that dependent projects also need 59 | ## CATKIN_DEPENDS: catkin_packages dependent projects also need 
60 | ## DEPENDS: system dependencies of this project that dependent projects also need 61 | catkin_package( 62 | # INCLUDE_DIRS include 63 | # LIBRARIES rt_gene 64 | CATKIN_DEPENDS rospy std_msgs geometry_msgs sensor_msgs image_geometry cv_bridge image_transport tf tf2_ros message_runtime 65 | # DEPENDS system_lib 66 | ) 67 | 68 | ############# 69 | ## Install ## 70 | ############# 71 | 72 | catkin_install_python(PROGRAMS scripts/estimate_gaze.py 73 | DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}) 74 | catkin_install_python(PROGRAMS scripts/extract_landmarks_node.py 75 | DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}) 76 | catkin_install_python(PROGRAMS scripts/estimate_blink.py 77 | DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}) 78 | 79 | # For launch files. 80 | install(DIRECTORY launch DESTINATION "${CATKIN_PACKAGE_SHARE_DESTINATION}") 81 | 82 | -------------------------------------------------------------------------------- /rt_gene/cfg/ModelSize.cfg: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | PACKAGE = "rt_gene" 3 | 4 | from dynamic_reconfigure.parameter_generator_catkin import * 5 | 6 | gen = ParameterGenerator() 7 | 8 | gen.add("interpupillary_distance", double_t, 0, "Interpupillary distance of participant", 0.058, 0.03, 0.10) 9 | gen.add("model_size", double_t, 0, "Scale of the 3D model of the face used for head pose estimation", 16.0, 10.0, 40.0) 10 | gen.add("head_pitch", double_t, 0, "Alter the head pose by the head_pitch", 0.0, -2.0, 2.0) 11 | 12 | exit(gen.generate(PACKAGE, "rt_gene", "ModelSize")) 13 | -------------------------------------------------------------------------------- /rt_gene/launch/estimate_blink.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | ['model_nets/blink_model_pytorch_vgg16_allsubjects1.model','model_nets/blink_model_pytorch_vgg16_allsubjects2.model'] 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /rt_gene/launch/estimate_gaze.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | ['model_nets/Model_allsubjects1.h5'] 36 | 37 | 38 | / 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /rt_gene/launch/start_kinect.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /rt_gene/launch/start_rosbag.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /rt_gene/launch/start_video.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /rt_gene/launch/start_webcam.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 15 | 16 | 
17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /rt_gene/model_nets/SFD/README.md: -------------------------------------------------------------------------------- 1 | EMPTY README.md to create folder. -------------------------------------------------------------------------------- /rt_gene/model_nets/ThreeDDFA/keypoints_sim.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_gene/model_nets/ThreeDDFA/keypoints_sim.npy -------------------------------------------------------------------------------- /rt_gene/model_nets/ThreeDDFA/param_whitening.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_gene/model_nets/ThreeDDFA/param_whitening.pkl -------------------------------------------------------------------------------- /rt_gene/model_nets/ThreeDDFA/param_whitening_py2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_gene/model_nets/ThreeDDFA/param_whitening_py2.pkl -------------------------------------------------------------------------------- /rt_gene/model_nets/ThreeDDFA/u_exp.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_gene/model_nets/ThreeDDFA/u_exp.npy -------------------------------------------------------------------------------- /rt_gene/model_nets/ThreeDDFA/u_shp.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_gene/model_nets/ThreeDDFA/u_shp.npy -------------------------------------------------------------------------------- /rt_gene/model_nets/face_model_68.txt: -------------------------------------------------------------------------------- 1 | -73.393523 2 | -72.775014 3 | -70.533638 4 | -66.850058 5 | -59.790187 6 | -48.368973 7 | -34.121101 8 | -17.875411 9 | 0.098749 10 | 17.477031 11 | 32.648966 12 | 46.372358 13 | 57.343480 14 | 64.388482 15 | 68.212038 16 | 70.486405 17 | 71.375822 18 | -61.119406 19 | -51.287588 20 | -37.804800 21 | -24.022754 22 | -11.635713 23 | 12.056636 24 | 25.106256 25 | 38.338588 26 | 51.191007 27 | 60.053851 28 | 0.653940 29 | 0.804809 30 | 0.992204 31 | 1.226783 32 | -14.772472 33 | -7.180239 34 | 0.555920 35 | 8.272499 36 | 15.214351 37 | -46.047290 38 | -37.674688 39 | -27.883856 40 | -19.648268 41 | -28.272965 42 | -38.082418 43 | 19.265868 44 | 27.894191 45 | 37.437529 46 | 45.170805 47 | 38.196454 48 | 28.764989 49 | -28.916267 50 | -17.533194 51 | -6.684590 52 | 0.381001 53 | 8.375443 54 | 18.876618 55 | 28.794412 56 | 19.057574 57 | 8.956375 58 | 0.381549 59 | -7.428895 60 | -18.160634 61 | -24.377490 62 | -6.897633 63 | 0.340663 64 | 8.444722 65 | 24.474473 66 | 8.449166 67 | 0.205322 68 | -7.198266 69 | -29.801432 70 | -10.949766 71 | 7.929818 72 | 26.074280 73 | 42.564390 74 | 56.481080 75 | 67.246992 76 | 75.056892 77 | 77.061286 78 | 74.758448 79 | 66.929021 80 | 56.311389 81 | 42.419126 82 | 25.455880 83 | 6.990805 84 | -11.666193 85 
| -30.365191 86 | -49.361602 87 | -58.769795 88 | -61.996155 89 | -61.033399 90 | -56.686759 91 | -57.391033 92 | -61.902186 93 | -62.777713 94 | -59.302347 95 | -50.190255 96 | -42.193790 97 | -30.993721 98 | -19.944596 99 | -8.414541 100 | 2.598255 101 | 4.751589 102 | 6.562900 103 | 4.661005 104 | 2.643046 105 | -37.471411 106 | -42.730510 107 | -42.711517 108 | -36.754742 109 | -35.134493 110 | -34.919043 111 | -37.032306 112 | -43.342445 113 | -43.110822 114 | -38.086515 115 | -35.532024 116 | -35.484289 117 | 28.612716 118 | 22.172187 119 | 19.029051 120 | 20.721118 121 | 19.035460 122 | 22.394109 123 | 28.079924 124 | 36.298248 125 | 39.634575 126 | 40.395647 127 | 39.836405 128 | 36.677899 129 | 28.677771 130 | 25.475976 131 | 26.014269 132 | 25.326198 133 | 28.323008 134 | 30.596216 135 | 31.408738 136 | 30.844876 137 | 47.667532 138 | 45.909403 139 | 44.842580 140 | 43.141114 141 | 38.635298 142 | 30.750622 143 | 18.456453 144 | 3.609035 145 | -0.881698 146 | 5.181201 147 | 19.176563 148 | 30.770570 149 | 37.628629 150 | 40.886309 151 | 42.281449 152 | 44.142567 153 | 47.140426 154 | 14.254422 155 | 7.268147 156 | 0.442051 157 | -6.606501 158 | -11.967398 159 | -12.051204 160 | -7.315098 161 | -1.022953 162 | 5.349435 163 | 11.615746 164 | -13.380835 165 | -21.150853 166 | -29.284036 167 | -36.948060 168 | -20.132003 169 | -23.536684 170 | -25.944448 171 | -23.695741 172 | -20.858157 173 | 7.037989 174 | 3.021217 175 | 1.353629 176 | -0.111088 177 | -0.147273 178 | 1.476612 179 | -0.665746 180 | 0.247660 181 | 1.696435 182 | 4.894163 183 | 0.282961 184 | -1.172675 185 | -2.240310 186 | -15.934335 187 | -22.611355 188 | -23.748437 189 | -22.721995 190 | -15.610679 191 | -3.217393 192 | -14.987997 193 | -22.554245 194 | -23.591626 195 | -22.406106 196 | -15.121907 197 | -4.785684 198 | -20.893742 199 | -22.220479 200 | -21.025520 201 | -5.712776 202 | -20.671489 203 | -21.903670 204 | -20.328022 205 | -------------------------------------------------------------------------------- /rt_gene/msg/MSG_Blink.msg: -------------------------------------------------------------------------------- 1 | string subject_id 2 | bool blink 3 | float64 probability -------------------------------------------------------------------------------- /rt_gene/msg/MSG_BlinkList.msg: -------------------------------------------------------------------------------- 1 | Header header 2 | MSG_Blink[] subjects 3 | -------------------------------------------------------------------------------- /rt_gene/msg/MSG_Gaze.msg: -------------------------------------------------------------------------------- 1 | string subject_id 2 | float64 phi 3 | float64 theta 4 | -------------------------------------------------------------------------------- /rt_gene/msg/MSG_GazeList.msg: -------------------------------------------------------------------------------- 1 | Header header 2 | MSG_Gaze[] subjects 3 | -------------------------------------------------------------------------------- /rt_gene/msg/MSG_Headpose.msg: -------------------------------------------------------------------------------- 1 | string subject_id 2 | float64 roll 3 | float64 pitch 4 | float64 yaw 5 | float64 x 6 | float64 y 7 | float64 z 8 | -------------------------------------------------------------------------------- /rt_gene/msg/MSG_HeadposeList.msg: -------------------------------------------------------------------------------- 1 | Header header 2 | MSG_Headpose[] subjects 3 | 
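The *List messages above all share one pattern: a std_msgs/Header plus an array of the corresponding per-subject message, matched across topics by subject_id. As a hedged illustration only (this listener is not part of the package), the sketch below consumes the blink list that scripts/estimate_blink.py further down publishes on /subjects/blink; the node and callback names are made up, while the topic and field names come from the definitions in this repository.

#!/usr/bin/env python
# Illustrative subscriber for MSG_BlinkList (not shipped with rt_gene).
import rospy
from rt_gene.msg import MSG_BlinkList


def on_blink_list(msg):
    # msg.subjects is an array of MSG_Blink: subject_id, blink, probability
    for subject in msg.subjects:
        rospy.loginfo("subject %s: blink=%s (p=%.2f)",
                      subject.subject_id, subject.blink, subject.probability)


if __name__ == "__main__":
    rospy.init_node("blink_listener_example")
    rospy.Subscriber("/subjects/blink", MSG_BlinkList, on_blink_list, queue_size=1)
    rospy.spin()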
-------------------------------------------------------------------------------- /rt_gene/msg/MSG_Landmarks.msg: -------------------------------------------------------------------------------- 1 | string subject_id 2 | float64[] landmarks 3 | -------------------------------------------------------------------------------- /rt_gene/msg/MSG_LandmarksList.msg: -------------------------------------------------------------------------------- 1 | Header header 2 | MSG_Landmarks[] subjects 3 | -------------------------------------------------------------------------------- /rt_gene/msg/MSG_SubjectImages.msg: -------------------------------------------------------------------------------- 1 | string subject_id 2 | sensor_msgs/Image face_img 3 | sensor_msgs/Image right_eye_img 4 | sensor_msgs/Image left_eye_img 5 | -------------------------------------------------------------------------------- /rt_gene/msg/MSG_SubjectImagesList.msg: -------------------------------------------------------------------------------- 1 | Header header 2 | MSG_SubjectImages[] subjects 3 | -------------------------------------------------------------------------------- /rt_gene/package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | rt_gene 4 | 4.0.1 5 | A package implementing head pose estimation and gaze estimation based on the ECCV2018 paper RT-GENE: Real-Time Eye Gaze Estimation in Natural Environments 6 | 7 | Tobias Fischer 8 | 9 | CC BY-NC-SA 4.0 10 | 11 | https://www.tobiasfischer.info 12 | 13 | Tobias Fischer 14 | Hyung Jin Chang 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | catkin 28 | genmsg 29 | 30 | message_generation 31 | 32 | rospy 33 | std_msgs 34 | geometry_msgs 35 | sensor_msgs 36 | image_geometry 37 | cv_bridge 38 | image_transport 39 | tf 40 | tf2_ros 41 | 42 | uvc_camera 43 | dynamic_reconfigure 44 | 45 | 46 | 47 | 56 | 57 | rviz 58 | message_runtime 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /rt_gene/scripts/download_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 4 | # (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) 5 | 6 | from __future__ import print_function, division, absolute_import 7 | 8 | import rt_gene.download_tools as download_tools 9 | 10 | 11 | if __name__ == '__main__': 12 | download_tools.download_gaze_tensorflow_models() 13 | download_tools.download_gaze_pytorch_models() 14 | download_tools.download_blink_tensorflow_models() 15 | download_tools.download_blink_pytorch_models() 16 | download_tools.download_external_landmark_models() 17 | 18 | -------------------------------------------------------------------------------- /rt_gene/scripts/estimate_blink.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | CNN for blink estimation 5 | @Kevin Cortacero 6 | @Tobias Fischer (t.fischer@imperial.ac.uk) 7 | Licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) 8 | """ 9 | 10 | from __future__ import print_function, division, absolute_import 11 | 12 | import os 13 | import rospy 14 | import rospkg 15 | 16 | from rt_gene.msg import MSG_SubjectImagesList 17 | from rt_gene.msg import MSG_BlinkList, MSG_Blink 
18 | from rt_gene.subject_ros_bridge import SubjectListBridge 19 | 20 | from sensor_msgs.msg import Image 21 | from cv_bridge import CvBridge 22 | import cv2 23 | 24 | import numpy as np 25 | import collections 26 | from tqdm import tqdm 27 | 28 | 29 | class BlinkEstimatorROS(object): 30 | def __init__(self, device_id_blink, model_files, threshold): 31 | self.cv_bridge = CvBridge() 32 | self.bridge = SubjectListBridge() 33 | self.viz = rospy.get_param("~viz", True) 34 | 35 | blink_backend = rospy.get_param("~blink_backend", default="pytorch") 36 | model_type = rospy.get_param("~model_type", default="resnet18") 37 | 38 | if blink_backend == "tensorflow": 39 | from rt_bene.estimate_blink_tensorflow import BlinkEstimatorTensorflow 40 | self._blink_estimator = BlinkEstimatorTensorflow(device_id_blink, model_files, model_type, threshold) 41 | elif blink_backend == "pytorch": 42 | from rt_bene.estimate_blink_pytorch import BlinkEstimatorPytorch 43 | self._blink_estimator = BlinkEstimatorPytorch(device_id_blink, model_files, model_type, threshold) 44 | else: 45 | raise ValueError("Incorrect gaze_base backend, choices are: tensorflow or pytorch") 46 | 47 | self._last_time = rospy.Time().now() 48 | self._freq_deque = collections.deque(maxlen=30) # average frequency statistic over roughly one second 49 | self._latency_deque = collections.deque(maxlen=30) 50 | 51 | self.blink_publisher = rospy.Publisher("/subjects/blink", MSG_BlinkList, queue_size=3) 52 | if self.viz: 53 | self.viz_pub = rospy.Publisher(rospy.get_param("~viz_topic", "/subjects/blink_images"), Image, queue_size=3) 54 | 55 | self.sub = rospy.Subscriber("/subjects/images", MSG_SubjectImagesList, self.callback, queue_size=1, 56 | buff_size=2 ** 24) 57 | 58 | def callback(self, msg): 59 | subjects = self.bridge.msg_to_images(msg) 60 | left_eyes = [] 61 | right_eyes = [] 62 | 63 | for subject in subjects.values(): 64 | _left, _right = self._blink_estimator.inputs_from_images(subject.left, subject.right) 65 | left_eyes.append(_left) 66 | right_eyes.append(_right) 67 | 68 | if len(left_eyes) == 0: 69 | return 70 | 71 | probs = self._blink_estimator.predict(left_eyes, right_eyes) 72 | 73 | self.publish_msg(msg.header, subjects, probs) 74 | 75 | if self.viz: 76 | blink_image_list = [] 77 | for subject, p in zip(subjects.values(), probs): 78 | resized_face = cv2.resize(subject.face, dsize=(224, 224), interpolation=cv2.INTER_CUBIC) 79 | blink_image_list.append(self._blink_estimator.overlay_prediction_over_img(resized_face, p)) 80 | 81 | if len(blink_image_list) > 0: 82 | blink_viz_img = self.cv_bridge.cv2_to_imgmsg(np.hstack(blink_image_list), encoding="bgr8") 83 | blink_viz_img.header.stamp = msg.header.stamp 84 | self.viz_pub.publish(blink_viz_img) 85 | 86 | _now = rospy.Time().now() 87 | timestamp = msg.header.stamp 88 | 89 | _freq = 1.0 / (_now - self._last_time).to_sec() 90 | self._freq_deque.append(_freq) 91 | self._latency_deque.append(_now.to_sec() - timestamp.to_sec()) 92 | self._last_time = _now 93 | tqdm.write( 94 | '\033[2K\033[1;32mTime now: {:.2f} message color: {:.2f} latency: {:.2f}s for {} subject(s) {:.0f}Hz\033[0m'.format( 95 | (_now.to_sec()), timestamp.to_sec(), np.mean(self._latency_deque), len(subjects), 96 | np.mean(self._freq_deque)), end="\r") 97 | 98 | def publish_msg(self, header, subjects, probabilities): 99 | blink_msg_list = MSG_BlinkList() 100 | blink_msg_list.header = header 101 | for subject_id, p in zip(subjects.keys(), probabilities): 102 | blink_msg = MSG_Blink() 103 | blink_msg.subject_id = str(subject_id) 
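            # p is the ensemble-averaged sigmoid output from predict() above, so it
            # lies in [0, 1]; the next line turns it into the boolean blink decision
            # by comparing against the ~threshold parameter (0.5 by default).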
104 | blink_msg.blink = bool(p >= self._blink_estimator.threshold) 105 | blink_msg.probability = p 106 | blink_msg_list.subjects.append(blink_msg) 107 | 108 | self.blink_publisher.publish(blink_msg_list) 109 | 110 | 111 | if __name__ == "__main__": 112 | try: 113 | rospy.init_node("blink_estimator") 114 | blink_detector = BlinkEstimatorROS(device_id_blink=rospy.get_param("~device_id_blinkestimation", "/gpu:0"), 115 | model_files=[os.path.join(rospkg.RosPack().get_path("rt_gene"), model_file) 116 | for model_file in rospy.get_param("~model_files")], 117 | threshold=rospy.get_param("~threshold", 0.5)) 118 | rospy.spin() 119 | except rospy.exceptions.ROSInterruptException: 120 | print("See ya") 121 | except rospy.ROSException as e: 122 | if str(e) == "publish() to a closed topic": 123 | print("See ya") 124 | else: 125 | raise e 126 | except KeyboardInterrupt: 127 | print("Shutting down") 128 | -------------------------------------------------------------------------------- /rt_gene/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from catkin_pkg.python_setup import generate_distutils_setup 3 | import distutils.log 4 | distutils.log.set_verbosity(distutils.log.DEBUG) # Set DEBUG level 5 | 6 | d = generate_distutils_setup( 7 | packages=['rt_gene', 'rt_bene'], 8 | package_dir={'': 'src'} 9 | ) 10 | 11 | setup(**d) 12 | 13 | -------------------------------------------------------------------------------- /rt_gene/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_gene/src/__init__.py -------------------------------------------------------------------------------- /rt_gene/src/rt_bene/__init__.py: -------------------------------------------------------------------------------- 1 | from . import estimate_blink_base 2 | -------------------------------------------------------------------------------- /rt_gene/src/rt_bene/estimate_blink_base.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | 3 | import cv2 4 | 5 | BLINK_COLOR = (0, 0, 255) 6 | NO_BLINK_COLOR = (0, 255, 0) 7 | 8 | 9 | class BlinkEstimatorBase(object): 10 | def __init__(self, device_id, threshold): 11 | self.device_id = device_id 12 | self.threshold = threshold 13 | 14 | def inputs_from_images(self, left, right): 15 | pass 16 | 17 | def predict(self, left_eyes, right_eyes): 18 | pass 19 | 20 | def overlay_prediction_over_img(self, img, p, border_size=5): 21 | img_copy = img.copy() 22 | h, w = img_copy.shape[:2] 23 | if p > self.threshold: 24 | cv2.rectangle(img_copy, (0, 0), (w, h), BLINK_COLOR, border_size) 25 | else: 26 | cv2.rectangle(img_copy, (0, 0), (w, h), NO_BLINK_COLOR, border_size) 27 | return img_copy 28 | -------------------------------------------------------------------------------- /rt_gene/src/rt_bene/estimate_blink_pytorch.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | 3 | from tqdm import tqdm 4 | from rt_gene.download_tools import download_blink_pytorch_models, md5 5 | from rt_bene.estimate_blink_base import BlinkEstimatorBase 6 | from rt_bene.blink_estimation_models_pytorch import BlinkEstimationModelResnet18, BlinkEstimationModelVGG16, BlinkEstimationModelVGG19, BlinkEstimationModelResnet50, BlinkEstimationModelDenseNet121 7 | import os 8 | import cv2 9 | import torch 10 | from torchvision import transforms 11 | 12 | MODELS = { 13 | "resnet18": BlinkEstimationModelResnet18, 14 | "resnet50": BlinkEstimationModelResnet50, 15 | "vgg16": BlinkEstimationModelVGG16, 16 | "vgg19": BlinkEstimationModelVGG19, 17 | "densenet121": BlinkEstimationModelDenseNet121 18 | } 19 | 20 | 21 | class BlinkEstimatorPytorch(BlinkEstimatorBase): 22 | 23 | def __init__(self, device_id_blink, model_files, model_type, threshold, known_hashes=( 24 | "cde99055e3b6dcf9fae6b78191c0fd9b", "67339ceefcfec4b3b8b3d7ccb03fadfa", "e5de548b2a97162c5e655259463e4d23", "7c228fe7b95ce5960c4c5cae8f2d3a09", "0a0d2d066737b333737018d738de386f")): 25 | super(BlinkEstimatorPytorch, self).__init__(device_id=device_id_blink, threshold=threshold) 26 | download_blink_pytorch_models() 27 | 28 | assert model_type in MODELS.keys(), f"PyTorch backend only supports the following backends: [{','.join(MODELS.keys())}]" 29 | 30 | # check md5 hashes 31 | model_hashes = [md5(model) for model in model_files] 32 | correct = [1 for hash in model_hashes if hash not in known_hashes] 33 | if sum(correct) > 0: 34 | raise ImportError( 35 | "MD5 Hashes of supplied model_files do not match the known_hashes argument. You have probably not set " 36 | "the --models argument and therefore you are trying to use TensorFlow models. If you are training your " 37 | "own models, then please supply the md5sum hashes in the known_hashes argument. If you're not, " 38 | "then you're using old models. 
The newer models should have downloaded already so please update the " 39 | "estimate_blink.launch file that you've modified.") 40 | 41 | if "OMP_NUM_THREADS" not in os.environ: 42 | os.environ["OMP_NUM_THREADS"] = "8" 43 | tqdm.write("PyTorch using {} threads.".format(os.environ["OMP_NUM_THREADS"])) 44 | 45 | self._transform = transforms.Compose([lambda x: cv2.resize(x, dsize=(60, 36), interpolation=cv2.INTER_CUBIC), 46 | transforms.ToTensor(), 47 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 48 | std=[0.229, 0.224, 0.225])]) 49 | 50 | self._models = [] 51 | for ckpt in model_files: 52 | _model = MODELS[model_type]() 53 | _model.load_state_dict(torch.load(ckpt)) 54 | _model.to(self.device_id) 55 | _model.eval() 56 | self._models.append(_model) 57 | 58 | tqdm.write('Loaded ' + str(len(self._models)) + ' model(s)') 59 | tqdm.write('Ready') 60 | 61 | def predict(self, left_eyes, right_eyes): 62 | transformed_left = torch.stack(left_eyes).to(self.device_id) 63 | transformed_right = torch.stack(right_eyes).to(self.device_id) 64 | 65 | with torch.no_grad(): 66 | result = [torch.sigmoid(model(transformed_left, transformed_right)).detach().cpu() for model in self._models] 67 | result = torch.stack(result, dim=1) 68 | result = torch.mean(result, dim=1).numpy() 69 | return result 70 | 71 | def inputs_from_images(self, left, right): 72 | return self._transform(left).to(self.device_id), self._transform(right).to(self.device_id) 73 | -------------------------------------------------------------------------------- /rt_gene/src/rt_bene/estimate_blink_tensorflow.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division, absolute_import 2 | 3 | import numpy as np 4 | import cv2 5 | import tensorflow as tf 6 | from tqdm import tqdm 7 | from rt_gene.download_tools import download_blink_tensorflow_models 8 | from rt_bene.estimate_blink_base import BlinkEstimatorBase 9 | 10 | 11 | class BlinkEstimatorTensorflow(BlinkEstimatorBase): 12 | 13 | def __init__(self, device_id_blink, model_files, model_type, threshold): 14 | super(BlinkEstimatorTensorflow, self).__init__(device_id=device_id_blink, threshold=threshold) 15 | download_blink_tensorflow_models() 16 | self.device_id_blink = device_id_blink 17 | self._input_size = (96, 96) 18 | 19 | assert model_type == "densenet121", "Tensorflow backend only supports DenseNet-121" 20 | 21 | tf.compat.v1.disable_eager_execution() 22 | 23 | with tf.device(self.device_id_blink): 24 | config = tf.compat.v1.ConfigProto(inter_op_parallelism_threads=1, 25 | intra_op_parallelism_threads=1) 26 | if "gpu" in self.device_id_blink: 27 | config.gpu_options.allow_growth = True 28 | config.gpu_options.per_process_gpu_memory_fraction = 0.3 29 | config.log_device_placement = False 30 | self.sess = tf.compat.v1.Session(config=config) 31 | tf.compat.v1.keras.backend.set_session(self.sess) 32 | 33 | if not isinstance(model_files, list): 34 | model_files = [model_files] 35 | 36 | models = [] 37 | for model_path in model_files: 38 | tqdm.write('Load model ' + model_path) 39 | models.append(tf.keras.models.load_model(model_path, compile=False)) 40 | # noinspection PyProtectedMember 41 | models[-1]._name = "model_{}".format(len(models)) 42 | 43 | img_input_l = tf.keras.Input(shape=self._input_size + (3,), name='img_input_L') 44 | img_input_r = tf.keras.Input(shape=self._input_size + (3,), name='img_input_R') 45 | 46 | if len(models) == 1: 47 | self.model = models[0] 48 | else: 49 | tensors = [model([img_input_r, 
img_input_l]) for model in models] 50 | output_layer = tf.keras.layers.average(tensors) 51 | self.model = tf.keras.Model(inputs=[img_input_r, img_input_l], outputs=output_layer) 52 | 53 | # noinspection PyProtectedMember 54 | self.model._make_predict_function() 55 | self.graph = tf.compat.v1.get_default_graph() 56 | 57 | self.predict(np.zeros((1,) + self._input_size + (3,)), np.zeros((1,) + self._input_size + (3,))) 58 | 59 | tqdm.write('Loaded ' + str(len(models)) + ' model(s)') 60 | tqdm.write('Ready') 61 | 62 | def predict(self, left_eyes, right_eyes): 63 | with self.graph.as_default(): 64 | tf.compat.v1.keras.backend.set_session(self.sess) 65 | x = [np.array(right_eyes), np.array(left_eyes)] # model expects this order! 66 | p = self.model.predict(x, verbose=0) 67 | return p 68 | 69 | def inputs_from_images(self, cv_image_left, cv_image_right): 70 | _left = cv2.resize(cv_image_left, self._input_size, cv2.INTER_CUBIC) 71 | _right = cv2.flip(cv2.resize(cv_image_right, self._input_size, cv2.INTER_CUBIC), 1) 72 | return _left, _right 73 | -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/SFD/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, Adrian Bulat 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | 31 | -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/SFD/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_gene/src/rt_gene/SFD/__init__.py -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/ThreeDDFA/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | MIT License 3 | 4 | Copyright (c) 2018 Jianzhu Guo 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | 24 | -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/ThreeDDFA/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ddfa 2 | from . import inference 3 | from . import io 4 | from . import mobilenet_v1 5 | # from . import params 6 | -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/ThreeDDFA/ddfa.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2018 Jianzhu Guo 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | 25 | import numpy as np 26 | 27 | import torch 28 | 29 | 30 | def _parse_param(param): 31 | """Work for both numpy and tensor""" 32 | p_ = param[:12].reshape(3, -1) 33 | p = p_[:, :3] 34 | offset = p_[:, -1].reshape(3, 1) 35 | alpha_shp = param[12:52].reshape(-1, 1) 36 | alpha_exp = param[52:].reshape(-1, 1) 37 | return p, offset, alpha_shp, alpha_exp 38 | 39 | 40 | def reconstruct_vertex(param, whitening=True, dense=False, transform=True): 41 | """Whitening param -> 3d vertex, based on the 3dmm param: u_base, w_shp, w_exp 42 | dense: if True, return dense vertex, else return 68 sparse landmarks. All dense or sparse vertex is transformed to 43 | image coordinate space, but without alignment caused by face cropping. 44 | transform: whether transform to image space 45 | """ 46 | from .params import param_mean, param_std, w_shp, w_exp, u, std_size, w_shp_base, w_exp_base, u_base 47 | if len(param) == 12: 48 | param = np.concatenate((param, [0] * 50)) 49 | if whitening: 50 | if len(param) == 62: 51 | param = param * param_std + param_mean 52 | else: 53 | param = np.concatenate((param[:11], [0], param[11:])) 54 | param = param * param_std + param_mean 55 | 56 | p, offset, alpha_shp, alpha_exp = _parse_param(param) 57 | 58 | if dense: 59 | t1 = np.dot(w_shp, alpha_shp) 60 | t2 = np.dot(w_exp, alpha_exp) 61 | vertex = np.matmul(p, (u + t1 + t2).reshape(3, -1, order='F')) + offset 62 | else: 63 | """For 68 pts""" 64 | t1 = np.dot(w_shp_base, alpha_shp) 65 | t2 = np.dot(w_exp_base, alpha_exp) 66 | vertex = np.matmul(p, (u_base + t1 + t2).reshape(3, -1, order='F')) + offset 67 | 68 | if transform: 69 | # transform to image coordinate space 70 | vertex[1, :] = std_size + 1 - vertex[1, :] 71 | 72 | return vertex 73 | 74 | 75 | class ToTensorGjz(object): 76 | def __call__(self, pic): 77 | if isinstance(pic, np.ndarray): 78 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 79 | return img.float() 80 | 81 | def __repr__(self): 82 | return self.__class__.__name__ + '()' 83 | 84 | 85 | class NormalizeGjz(object): 86 | def __init__(self, mean, std): 87 | self.mean = mean 88 | self.std = std 89 | 90 | def __call__(self, tensor): 91 | tensor.sub_(self.mean).div_(self.std) 92 | return tensor 93 | -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/ThreeDDFA/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2018 Jianzhu Guo 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | 26 | import numpy as np 27 | from math import sqrt 28 | from .ddfa import reconstruct_vertex 29 | 30 | 31 | def get_suffix(filename): 32 | """a.jpg -> jpg""" 33 | pos = filename.rfind('.') 34 | if pos == -1: 35 | return '' 36 | return filename[pos:] 37 | 38 | 39 | def crop_img(img, roi_box): 40 | h, w = img.shape[:2] 41 | 42 | sx, sy, ex, ey = [int(round(_)) for _ in roi_box] 43 | dh, dw = ey - sy, ex - sx 44 | if len(img.shape) == 3: 45 | res = np.zeros((dh, dw, 3), dtype=np.uint8) 46 | else: 47 | res = np.zeros((dh, dw), dtype=np.uint8) 48 | if sx < 0: 49 | sx, dsx = 0, -sx 50 | else: 51 | dsx = 0 52 | 53 | if ex > w: 54 | ex, dex = w, dw - (ex - w) 55 | else: 56 | dex = dw 57 | 58 | if sy < 0: 59 | sy, dsy = 0, -sy 60 | else: 61 | dsy = 0 62 | 63 | if ey > h: 64 | ey, dey = h, dh - (ey - h) 65 | else: 66 | dey = dh 67 | 68 | res[dsy:dey, dsx:dex] = img[sy:ey, sx:ex] 69 | return res 70 | 71 | 72 | def calc_hypotenuse(pts): 73 | bbox = [min(pts[0, :]), min(pts[1, :]), max(pts[0, :]), max(pts[1, :])] 74 | center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2] 75 | radius = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 76 | bbox = [center[0] - radius, center[1] - radius, center[0] + radius, center[1] + radius] 77 | llength = sqrt((bbox[2] - bbox[0]) ** 2 + (bbox[3] - bbox[1]) ** 2) 78 | return llength / 3 79 | 80 | 81 | def parse_roi_box_from_landmark(pts): 82 | """calc roi box from landmark""" 83 | bbox = [min(pts[0, :]), min(pts[1, :]), max(pts[0, :]), max(pts[1, :])] 84 | center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2] 85 | radius = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 2 86 | bbox = [center[0] - radius, center[1] - radius, center[0] + radius, center[1] + radius] 87 | 88 | llength = sqrt((bbox[2] - bbox[0]) ** 2 + (bbox[3] - bbox[1]) ** 2) 89 | center_x = (bbox[2] + bbox[0]) / 2 90 | center_y = (bbox[3] + bbox[1]) / 2 91 | 92 | roi_box = [0.0] * 4 93 | roi_box[0] = center_x - llength / 2 94 | roi_box[1] = center_y - llength / 2 95 | roi_box[2] = roi_box[0] + llength 96 | roi_box[3] = roi_box[1] + llength 97 | 98 | return roi_box 99 | 100 | 101 | def parse_roi_box_from_bbox(bbox): 102 | left, top, right, bottom = bbox 103 | old_size = (right - left + bottom - top) / 2 104 | center_x = right - (right - left) / 2.0 105 | center_y = bottom - (bottom - top) / 2.0 + old_size * 0.14 106 | size = int(old_size * 1.58) 107 | roi_box = [0] * 4 108 | roi_box[0] = center_x - size / 2 109 | roi_box[1] = center_y - size / 2 110 | roi_box[2] = roi_box[0] + size 111 | roi_box[3] = roi_box[1] + size 112 | return roi_box 113 | 114 | 115 | def _predict_vertices(param, roi_bbox, dense): 116 | from .params import std_size 117 | vertex = reconstruct_vertex(param, dense=dense) 118 | sx, sy, ex, ey = roi_bbox 119 | scale_x = (ex - sx) / std_size 120 | scale_y = (ey - sy) / std_size 121 | vertex[0, :] = vertex[0, :] * scale_x + sx 122 | vertex[1, :] = vertex[1, :] * scale_y + sy 123 | 124 | s = (scale_x + scale_y) / 2 125 | vertex[2, :] *= s 126 | 127 | return vertex 128 | 129 | 130 | def predict_68pts(param, roi_box): 131 | return _predict_vertices(param, roi_box, dense=False) 132 | 133 | 134 | def predict_dense(param, roi_box): 135 | return _predict_vertices(param, roi_box, dense=True) 136 | 
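Taken together, these helpers form the two-pass 3DDFA landmark step used by extract_landmarks_method_base.py further down: a detector box is expanded to a square ROI, the crop is resized to std_size (120x120) for the parameter regressor, the 62-dimensional parameter vector is decoded back into 68 image-space points, and the ROI is then refitted to those points for a second pass. The condensed sketch below is an illustration under stated assumptions, not a copy of the node code: model is assumed to be a trained regressor such as mobilenet_1 loaded from phase1_wpdc_vdc.pth.tar, transform the ToTensorGjz/NormalizeGjz composition, and the function name is made up.

import cv2
import torch

from rt_gene.ThreeDDFA.inference import (crop_img, parse_roi_box_from_bbox,
                                          parse_roi_box_from_landmark, predict_68pts)


def landmarks_from_facebox(model, transform, color_img, facebox, device="cpu"):
    # Illustrative two-pass 68-point prediction for a single face box.
    def forward(roi_box):
        crop = cv2.resize(crop_img(color_img, roi_box), dsize=(120, 120),
                          interpolation=cv2.INTER_LINEAR)
        with torch.no_grad():
            param = model(transform(crop).unsqueeze(0).to(device)).cpu().numpy().flatten().astype(float)
        return predict_68pts(param, roi_box)

    pts68 = forward(parse_roi_box_from_bbox(facebox))    # pass 1: from the detector box
    return forward(parse_roi_box_from_landmark(pts68))   # pass 2: from a box refitted to the landmarks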
-------------------------------------------------------------------------------- /rt_gene/src/rt_gene/ThreeDDFA/io.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2018 Jianzhu Guo 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | 26 | import os 27 | import numpy as np 28 | import torch 29 | import pickle 30 | import scipy.io as sio 31 | 32 | 33 | def mkdir(d): 34 | """only works on *nix system""" 35 | if not os.path.isdir(d) and not os.path.exists(d): 36 | os.system('mkdir -p {}'.format(d)) 37 | 38 | 39 | def _get_suffix(filename): 40 | """a.jpg -> jpg""" 41 | pos = filename.rfind('.') 42 | if pos == -1: 43 | return '' 44 | return filename[pos + 1:] 45 | 46 | 47 | def _load(fp): 48 | suffix = _get_suffix(fp) 49 | if suffix == 'npy': 50 | return np.load(fp) 51 | elif suffix == 'pkl': 52 | return pickle.load(open(fp, 'rb')) 53 | 54 | 55 | def _dump(wfp, obj): 56 | suffix = _get_suffix(wfp) 57 | if suffix == 'npy': 58 | np.save(wfp, obj) 59 | elif suffix == 'pkl': 60 | pickle.dump(obj, open(wfp, 'wb')) 61 | else: 62 | raise Exception('Unknown Type: {}'.format(suffix)) 63 | 64 | 65 | def _load_tensor(fp, mode='cpu'): 66 | if mode.lower() == 'cpu': 67 | return torch.from_numpy(_load(fp)) 68 | elif mode.lower() == 'gpu': 69 | return torch.from_numpy(_load(fp)).cuda() 70 | 71 | 72 | def _tensor_to_cuda(x): 73 | if x.is_cuda: 74 | return x 75 | else: 76 | return x.cuda() 77 | 78 | 79 | def _load_gpu(fp): 80 | return torch.from_numpy(_load(fp)).cuda() 81 | 82 | 83 | def load_bfm(model_path): 84 | suffix = _get_suffix(model_path) 85 | if suffix == 'mat': 86 | C = sio.loadmat(model_path) 87 | model = C['model_refine'] 88 | model = model[0, 0] 89 | 90 | model_new = {} 91 | w_shp = model['w'].astype(float) 92 | model_new['w_shp_sim'] = w_shp[:, :40] 93 | w_exp = model['w_exp'].astype(float) 94 | model_new['w_exp_sim'] = w_exp[:, :10] 95 | 96 | u_shp = model['mu_shape'] 97 | u_exp = model['mu_exp'] 98 | u = (u_shp + u_exp).astype(float) 99 | model_new['mu'] = u 100 | model_new['tri'] = model['tri'].astype(int) - 1 101 | 102 | # flatten it, pay attention to index value 103 | keypoints = model['keypoints'].astype(int) - 1 104 | keypoints = np.concatenate((3 * keypoints, 3 * keypoints + 1, 3 * keypoints + 2), axis=0) 105 | 106 | model_new['keypoints'] = keypoints.T.flatten() 107 | 108 | # 109 | w = np.concatenate((w_shp, w_exp), axis=1) 110 | w_base = w[keypoints] 111 | 
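        # w concatenates the full shape and expression bases; indexing with the
        # 3*k, 3*k+1, 3*k+2 keypoint rows keeps only the entries belonging to the
        # 68 sparse landmarks, and the norms computed next are stored in the
        # exported .pkl model alongside the reduced bases.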
w_norm = np.linalg.norm(w, axis=0) 112 | w_base_norm = np.linalg.norm(w_base, axis=0) 113 | 114 | dim = w_shp.shape[0] // 3 115 | u_base = u[keypoints].reshape(-1, 1) 116 | w_shp_base = w_shp[keypoints] 117 | w_exp_base = w_exp[keypoints] 118 | 119 | model_new['w_norm'] = w_norm 120 | model_new['w_base_norm'] = w_base_norm 121 | model_new['dim'] = dim 122 | model_new['u_base'] = u_base 123 | model_new['w_shp_base'] = w_shp_base 124 | model_new['w_exp_base'] = w_exp_base 125 | 126 | _dump(model_path.replace('.mat', '.pkl'), model_new) 127 | return model_new 128 | else: 129 | return _load(model_path) 130 | 131 | 132 | _load_cpu = _load 133 | _numpy_to_tensor = lambda x: torch.from_numpy(x) 134 | _tensor_to_numpy = lambda x: x.cpu() 135 | _numpy_to_cuda = lambda x: _tensor_to_cuda(torch.from_numpy(x)) 136 | _cuda_to_tensor = lambda x: x.cpu() 137 | _cuda_to_numpy = lambda x: x.cpu().numpy() 138 | -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/ThreeDDFA/mobilenet_v1.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2018 Jianzhu Guo 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | 24 | Creates a MobileNet Model as defined in: 25 | Andrew G. Howard Menglong Zhu Bo Chen, et.al. (2017). 26 | MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications. 
27 | Copyright (c) Yang Lu, 2017 28 | 29 | Modified By cleardusk 30 | """ 31 | 32 | from __future__ import division 33 | 34 | import math 35 | import torch.nn as nn 36 | 37 | __all__ = ['mobilenet_2', 'mobilenet_1', 'mobilenet_075', 'mobilenet_05', 'mobilenet_025'] 38 | 39 | 40 | class DepthWiseBlock(nn.Module): 41 | def __init__(self, inplanes, planes, stride=1, prelu=False): 42 | super(DepthWiseBlock, self).__init__() 43 | inplanes, planes = int(inplanes), int(planes) 44 | self.conv_dw = nn.Conv2d(inplanes, inplanes, kernel_size=3, padding=1, stride=stride, groups=inplanes, 45 | bias=False) 46 | self.bn_dw = nn.BatchNorm2d(inplanes) 47 | self.conv_sep = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False) 48 | self.bn_sep = nn.BatchNorm2d(planes) 49 | if prelu: 50 | self.relu = nn.PReLU() 51 | else: 52 | self.relu = nn.ReLU(inplace=True) 53 | 54 | def forward(self, x): 55 | out = self.conv_dw(x) 56 | out = self.bn_dw(out) 57 | out = self.relu(out) 58 | 59 | out = self.conv_sep(out) 60 | out = self.bn_sep(out) 61 | out = self.relu(out) 62 | 63 | return out 64 | 65 | 66 | class MobileNet(nn.Module): 67 | def __init__(self, widen_factor=1.0, num_classes=1000, prelu=False, input_channel=3): 68 | """ Constructor 69 | Args: 70 | widen_factor: config of widen_factor 71 | num_classes: number of classes 72 | """ 73 | super(MobileNet, self).__init__() 74 | 75 | block = DepthWiseBlock 76 | self.conv1 = nn.Conv2d(input_channel, int(32 * widen_factor), kernel_size=3, stride=2, padding=1, 77 | bias=False) 78 | 79 | self.bn1 = nn.BatchNorm2d(int(32 * widen_factor)) 80 | if prelu: 81 | self.relu = nn.PReLU() 82 | else: 83 | self.relu = nn.ReLU(inplace=True) 84 | 85 | self.dw2_1 = block(32 * widen_factor, 64 * widen_factor, prelu=prelu) 86 | self.dw2_2 = block(64 * widen_factor, 128 * widen_factor, stride=2, prelu=prelu) 87 | 88 | self.dw3_1 = block(128 * widen_factor, 128 * widen_factor, prelu=prelu) 89 | self.dw3_2 = block(128 * widen_factor, 256 * widen_factor, stride=2, prelu=prelu) 90 | 91 | self.dw4_1 = block(256 * widen_factor, 256 * widen_factor, prelu=prelu) 92 | self.dw4_2 = block(256 * widen_factor, 512 * widen_factor, stride=2, prelu=prelu) 93 | 94 | self.dw5_1 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) 95 | self.dw5_2 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) 96 | self.dw5_3 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) 97 | self.dw5_4 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) 98 | self.dw5_5 = block(512 * widen_factor, 512 * widen_factor, prelu=prelu) 99 | self.dw5_6 = block(512 * widen_factor, 1024 * widen_factor, stride=2, prelu=prelu) 100 | 101 | self.dw6 = block(1024 * widen_factor, 1024 * widen_factor, prelu=prelu) 102 | 103 | self.avgpool = nn.AdaptiveAvgPool2d(1) 104 | self.fc = nn.Linear(int(1024 * widen_factor), num_classes) 105 | 106 | for m in self.modules(): 107 | if isinstance(m, nn.Conv2d): 108 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 109 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 110 | elif isinstance(m, nn.BatchNorm2d): 111 | m.weight.data.fill_(1) 112 | m.bias.data.zero_() 113 | 114 | def forward(self, x): 115 | x = self.conv1(x) 116 | x = self.bn1(x) 117 | x = self.relu(x) 118 | 119 | x = self.dw2_1(x) 120 | x = self.dw2_2(x) 121 | x = self.dw3_1(x) 122 | x = self.dw3_2(x) 123 | x = self.dw4_1(x) 124 | x = self.dw4_2(x) 125 | x = self.dw5_1(x) 126 | x = self.dw5_2(x) 127 | x = self.dw5_3(x) 128 | x = self.dw5_4(x) 129 | x = self.dw5_5(x) 130 | x = self.dw5_6(x) 131 | x = self.dw6(x) 132 | 133 | x = self.avgpool(x) 134 | x = x.view(x.size(0), -1) 135 | x = self.fc(x) 136 | 137 | return x 138 | 139 | 140 | def mobilenet(widen_factor=1.0, num_classes=1000): 141 | """ 142 | Construct MobileNet. 143 | widen_factor=1.0 for mobilenet_1 144 | widen_factor=0.75 for mobilenet_075 145 | widen_factor=0.5 for mobilenet_05 146 | widen_factor=0.25 for mobilenet_025 147 | """ 148 | model = MobileNet(widen_factor=widen_factor, num_classes=num_classes) 149 | return model 150 | 151 | 152 | def mobilenet_2(num_classes=62, input_channel=3): 153 | model = MobileNet(widen_factor=2.0, num_classes=num_classes, input_channel=input_channel) 154 | return model 155 | 156 | 157 | def mobilenet_1(num_classes=62, input_channel=3): 158 | model = MobileNet(widen_factor=1.0, num_classes=num_classes, input_channel=input_channel) 159 | return model 160 | 161 | 162 | def mobilenet_075(num_classes=62, input_channel=3): 163 | model = MobileNet(widen_factor=0.75, num_classes=num_classes, input_channel=input_channel) 164 | return model 165 | 166 | 167 | def mobilenet_05(num_classes=62, input_channel=3): 168 | model = MobileNet(widen_factor=0.5, num_classes=num_classes, input_channel=input_channel) 169 | return model 170 | 171 | 172 | def mobilenet_025(num_classes=62, input_channel=3): 173 | model = MobileNet(widen_factor=0.25, num_classes=num_classes, input_channel=input_channel) 174 | return model 175 | 176 | -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/ThreeDDFA/params.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2018 Jianzhu Guo 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | 25 | import sys 26 | import os.path as osp 27 | from .io import _load 28 | 29 | 30 | d = osp.join(osp.dirname(osp.realpath(__file__)), '../../../model_nets/ThreeDDFA/') 31 | keypoints = _load(osp.join(d, 'keypoints_sim.npy')) 32 | w_shp = _load(osp.join(d, 'w_shp_sim.npy')) 33 | w_exp = _load(osp.join(d, 'w_exp_sim.npy')) # simplified version 34 | if sys.version_info > (3, 0): 35 | meta = _load(osp.join(d, 'param_whitening.pkl')) 36 | else: 37 | meta = _load(osp.join(d, 'param_whitening_py2.pkl')) 38 | # # param_mean and param_std are used for re-whitening 39 | param_mean = meta.get('param_mean') 40 | param_std = meta.get('param_std') 41 | u_shp = _load(osp.join(d, 'u_shp.npy')) 42 | u_exp = _load(osp.join(d, 'u_exp.npy')) 43 | u = u_shp + u_exp 44 | 45 | # for inference 46 | dim = w_shp.shape[0] // 3 47 | u_base = u[keypoints].reshape(-1, 1) 48 | w_shp_base = w_shp[keypoints] 49 | w_exp_base = w_exp[keypoints] 50 | std_size = 120 51 | 52 | -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/__init__.py: -------------------------------------------------------------------------------- 1 | # from . import tracker_face_encoding 2 | 3 | -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/estimate_gaze_base.py: -------------------------------------------------------------------------------- 1 | # Licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) 2 | 3 | import os 4 | import cv2 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from rt_gene.gaze_tools import get_endpoint 9 | 10 | 11 | class GazeEstimatorBase(object): 12 | """This class encapsulates a deep neural network for gaze estimation. 13 | 14 | It retrieves two image streams, one containing the left eye and another containing the right eye. 15 | It synchronizes these two images with the estimated head pose. 16 | The images are then converted in a suitable format, and a forward pass of the deep neural network 17 | results in the estimated gaze for this frame. 
The estimated gaze is then published in the (theta, phi) notation.""" 18 | def __init__(self, device_id_gaze, model_files): 19 | if "OMP_NUM_THREADS" not in os.environ: 20 | os.environ["OMP_NUM_THREADS"] = "8" 21 | tqdm.write("PyTorch using {} threads.".format(os.environ["OMP_NUM_THREADS"])) 22 | self.device_id_gazeestimation = device_id_gaze 23 | self.model_files = model_files 24 | 25 | if not isinstance(model_files, list): 26 | self.model_files = [model_files] 27 | 28 | if len(self.model_files) == 1: 29 | self._gaze_offset = 0.11 30 | else: 31 | self._gaze_offset = 0.0 32 | 33 | def estimate_gaze_twoeyes(self, inference_input_left_list, inference_input_right_list, inference_headpose_list): 34 | pass 35 | 36 | def input_from_image(self, cv_image): 37 | pass 38 | 39 | @staticmethod 40 | def visualize_eye_result(eye_image, est_gaze): 41 | """Here, we take the original eye eye_image and overlay the estimated gaze.""" 42 | output_image = np.copy(eye_image) 43 | 44 | center_x = output_image.shape[1] / 2 45 | center_y = output_image.shape[0] / 2 46 | 47 | endpoint_x, endpoint_y = get_endpoint(est_gaze[0], est_gaze[1], center_x, center_y, 50) 48 | 49 | cv2.line(output_image, (int(center_x), int(center_y)), (int(endpoint_x), int(endpoint_y)), (255, 0, 0)) 50 | return output_image 51 | -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/estimate_gaze_pytorch.py: -------------------------------------------------------------------------------- 1 | # Licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) 2 | 3 | import os 4 | 5 | import cv2 6 | import torch 7 | from torchvision import transforms 8 | from tqdm import tqdm 9 | 10 | from rt_gene.estimate_gaze_base import GazeEstimatorBase 11 | from rt_gene.gaze_estimation_models_pytorch import GazeEstimationModelVGG, GazeEstimationModelResnet18 12 | from rt_gene.download_tools import download_gaze_pytorch_models, md5 13 | 14 | 15 | class GazeEstimator(GazeEstimatorBase): 16 | def __init__(self, device_id_gaze, model_files, known_hashes=( 17 | "ae435739673411940eed18c98c29bfb1", "4afd7ccf5619552ed4a9f14606b7f4dd", "743902e643322c40bd78ca36aacc5b4d", 18 | "06a10f43088651053a65f9b0cd5ac4aa")): 19 | super(GazeEstimator, self).__init__(device_id_gaze, model_files) 20 | download_gaze_pytorch_models() 21 | # check md5 hashes 22 | _model_hashes = [md5(model) for model in model_files] 23 | _correct = [1 for hash in _model_hashes if hash not in known_hashes] 24 | if sum(_correct) > 0: 25 | raise ImportError( 26 | "MD5 Hashes of supplied model_files do not match the known_hashes argument. You have probably not set " 27 | "the --models argument and therefore you are trying to use TensorFlow models. If you are training your " 28 | "own models, then please supply the md5sum hashes in the known_hashes argument. If you're not, " 29 | "then you're using old models. 
The newer models should have downloaded already so please update the " 30 | "estimate_gaze.launch file that you've modified.") 31 | 32 | if "OMP_NUM_THREADS" not in os.environ: 33 | os.environ["OMP_NUM_THREADS"] = "8" 34 | tqdm.write("PyTorch using {} threads.".format(os.environ["OMP_NUM_THREADS"])) 35 | 36 | self._transform = transforms.Compose([lambda x: cv2.resize(x, dsize=(60, 36), interpolation=cv2.INTER_CUBIC), 37 | transforms.ToTensor(), 38 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 39 | std=[0.229, 0.224, 0.225])]) 40 | 41 | self._models = [] 42 | for ckpt in self.model_files: 43 | try: 44 | _model = GazeEstimationModelVGG(num_out=2) 45 | _model.load_state_dict(torch.load(ckpt)) 46 | _model.to(self.device_id_gazeestimation) 47 | _model.eval() 48 | self._models.append(_model) 49 | except Exception as e: 50 | print("Error loading checkpoint", ckpt) 51 | raise e 52 | 53 | tqdm.write('Loaded ' + str(len(self._models)) + ' model(s)') 54 | 55 | def estimate_gaze_twoeyes(self, inference_input_left_list, inference_input_right_list, inference_headpose_list): 56 | transformed_left = torch.stack(inference_input_left_list).to(self.device_id_gazeestimation) 57 | transformed_right = torch.stack(inference_input_right_list).to(self.device_id_gazeestimation) 58 | tranformed_head = torch.as_tensor(inference_headpose_list).to(self.device_id_gazeestimation) 59 | 60 | result = [model(transformed_left, transformed_right, tranformed_head).detach().cpu() for model in self._models] 61 | result = torch.stack(result, dim=1) 62 | result = torch.mean(result, dim=1).numpy() 63 | result[:, 1] += self._gaze_offset 64 | return result 65 | 66 | def input_from_image(self, cv_image): 67 | return self._transform(cv_image) 68 | -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/estimate_gaze_tensorflow.py: -------------------------------------------------------------------------------- 1 | # Licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | from tqdm import tqdm 6 | 7 | from rt_gene.estimate_gaze_base import GazeEstimatorBase 8 | from rt_gene.download_tools import download_gaze_tensorflow_models 9 | 10 | 11 | class GazeEstimator(GazeEstimatorBase): 12 | def __init__(self, device_id_gaze, model_files): 13 | super(GazeEstimator, self).__init__(device_id_gaze, model_files) 14 | download_gaze_tensorflow_models() 15 | 16 | # Configure GPU settings if a GPU is specified 17 | if "gpu" in self.device_id_gazeestimation.lower(): 18 | gpus = tf.config.list_physical_devices('GPU') 19 | if gpus: 20 | try: 21 | for gpu in gpus: 22 | # Enable memory growth to prevent TensorFlow from allocating all GPU memory 23 | tf.config.experimental.set_memory_growth(gpu, True) 24 | except RuntimeError as e: 25 | print(f"GPU configuration error: {e}") 26 | 27 | # Set device context 28 | self.device = tf.device(self.device_id_gazeestimation) 29 | with self.device: 30 | # Define input layers 31 | img_input_l = tf.keras.Input(shape=(36, 60, 3), name='img_input_L') 32 | img_input_r = tf.keras.Input(shape=(36, 60, 3), name='img_input_R') 33 | headpose_input = tf.keras.Input(shape=(2,), name='headpose_input') 34 | 35 | # Load models 36 | models = [] 37 | for idx, model_file in enumerate(self.model_files, start=1): 38 | tqdm.write(f'Loading model {model_file}') 39 | model = tf.keras.models.load_model(model_file, compile=False) 40 | 
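                # Rename each checkpoint as it is loaded: the saved Keras models all
                # carry the same default name, and combining them into the single
                # ensemble Model below requires the names to be unique.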
model._name = f"model_{idx}" 41 | models.append(model) 42 | 43 | # Create ensemble model 44 | if len(models) == 1: 45 | self.ensemble_model = models[0] 46 | elif len(models) > 1: 47 | # Collect outputs from all models 48 | outputs = [model([img_input_l, img_input_r, headpose_input]) for model in models] 49 | # Average the outputs 50 | averaged_output = tf.keras.layers.Average()(outputs) 51 | # Define the ensemble model 52 | self.ensemble_model = tf.keras.Model( 53 | inputs=[img_input_l, img_input_r, headpose_input], 54 | outputs=averaged_output, 55 | name='ensemble_model' 56 | ) 57 | else: 58 | raise ValueError("No models were loaded") 59 | 60 | tqdm.write(f'Loaded {len(models)} model(s)') 61 | 62 | def __del__(self): 63 | # No need to manually close sessions in TensorFlow 2.x 64 | pass 65 | 66 | def estimate_gaze_twoeyes(self, inference_input_left_list, inference_input_right_list, inference_headpose_list): 67 | """ 68 | Estimate gaze using the ensemble model. 69 | 70 | Args: 71 | inference_input_left_list (list or np.ndarray): List of left eye images. 72 | inference_input_right_list (list or np.ndarray): List of right eye images. 73 | inference_headpose_list (list or np.ndarray): List of head pose data. 74 | 75 | Returns: 76 | np.ndarray: Gaze predictions with offset applied. 77 | """ 78 | # Prepare inputs as a list matching the input layer order 79 | inputs = [ 80 | np.array(inference_input_left_list), 81 | np.array(inference_input_right_list), 82 | np.array(inference_headpose_list) 83 | ] 84 | 85 | # Perform prediction 86 | mean_prediction = self.ensemble_model.predict(inputs) 87 | 88 | # Apply gaze offset 89 | mean_prediction[:, 1] += self._gaze_offset 90 | 91 | return mean_prediction # returns [subject : [gaze_pose]] 92 | 93 | def input_from_image(self, cv_image): 94 | """ 95 | Convert an eye image from the landmark estimator to the format suitable for the gaze network. 96 | 97 | Args: 98 | cv_image (np.ndarray): Eye image array. 99 | 100 | Returns: 101 | np.ndarray: Preprocessed image suitable for the model. 
102 | """ 103 | # Reshape and ensure correct data type 104 | currimg = cv_image.reshape(36, 60, 3, order='F').astype(float) 105 | 106 | # Initialize an array for the preprocessed image 107 | testimg = np.zeros((36, 60, 3)) 108 | 109 | # Subtract mean values for each channel (BGR) 110 | testimg[:, :, 0] = currimg[:, :, 0] - 103.939 111 | testimg[:, :, 1] = currimg[:, :, 1] - 116.779 112 | testimg[:, :, 2] = currimg[:, :, 2] - 123.68 113 | 114 | return testimg 115 | -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/extract_landmarks_method_base.py: -------------------------------------------------------------------------------- 1 | # Licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import torchvision.transforms as transforms 7 | from torch.backends import cudnn as cudnn 8 | from tqdm import tqdm 9 | 10 | from rt_gene.download_tools import download_external_landmark_models 11 | 12 | # noinspection PyUnresolvedReferences 13 | from rt_gene import gaze_tools as gaze_tools 14 | from rt_gene.SFD.sfd_detector import SFDDetector 15 | from rt_gene.ThreeDDFA.ddfa import ToTensorGjz, NormalizeGjz 16 | from rt_gene.ThreeDDFA.inference import crop_img, predict_68pts, parse_roi_box_from_bbox, parse_roi_box_from_landmark 17 | from rt_gene.tracker_generic import TrackedSubject 18 | 19 | facial_landmark_transform = transforms.Compose([ToTensorGjz(), NormalizeGjz(mean=127.5, std=128)]) 20 | 21 | 22 | class LandmarkMethodBase(object): 23 | def __init__(self, device_id_facedetection, checkpoint_path_face=None, checkpoint_path_landmark=None, model_points_file=None): 24 | download_external_landmark_models() 25 | self.model_size_rescale = 16.0 26 | self.head_pitch = 0.0 27 | self.interpupillary_distance = 0.058 28 | self.eye_image_size = (60, 36) 29 | 30 | tqdm.write("Using device {} for face detection.".format(device_id_facedetection)) 31 | 32 | self.device = device_id_facedetection 33 | self.face_net = SFDDetector(device=device_id_facedetection, path_to_detector=checkpoint_path_face) 34 | self.facial_landmark_nn = self.load_face_landmark_model(checkpoint_path_landmark) 35 | 36 | self.model_points = self.get_full_model_points(model_points_file) 37 | 38 | def load_face_landmark_model(self, checkpoint_fp=None): 39 | import rt_gene.ThreeDDFA.mobilenet_v1 as mobilenet_v1 40 | if checkpoint_fp is None: 41 | import rospkg 42 | checkpoint_fp = rospkg.RosPack().get_path('rt_gene') + '/model_nets/phase1_wpdc_vdc.pth.tar' 43 | arch = 'mobilenet_1' 44 | 45 | checkpoint = torch.load(checkpoint_fp, map_location=lambda storage, loc: storage)['state_dict'] 46 | model = getattr(mobilenet_v1, arch)(num_classes=62) # 62 = 12(pose) + 40(shape) +10(expression) 47 | 48 | model_dict = model.state_dict() 49 | # because the model is trained by multiple gpus, prefix module should be removed 50 | for k in checkpoint.keys(): 51 | model_dict[k.replace('module.', '')] = checkpoint[k] 52 | model.load_state_dict(model_dict) 53 | cudnn.benchmark = True 54 | model = model.to(self.device) 55 | model.eval() 56 | return model 57 | 58 | def get_full_model_points(self, model_points_file=None): 59 | """Get all 68 3D model points from file""" 60 | raw_value = [] 61 | if model_points_file is None: 62 | import rospkg 63 | model_points_file = rospkg.RosPack().get_path('rt_gene') + '/model_nets/face_model_68.txt' 64 | 65 | with 
open(model_points_file) as f: 66 | for line in f: 67 | raw_value.append(line) 68 | model_points = np.array(raw_value, dtype=float) 69 | model_points = np.reshape(model_points, (3, -1)).T 70 | 71 | # Scale the model points according to the interpupillary distance and the model size rescale factor. 72 | model_points = model_points * (self.interpupillary_distance * self.model_size_rescale) 73 | 74 | return model_points 75 | 76 | def get_face_bb(self, image): 77 | faceboxes = [] 78 | fraction = 4.0 79 | image = cv2.resize(image, (0, 0), fx=1.0 / fraction, fy=1.0 / fraction) 80 | detections = self.face_net.detect_from_image(image) 81 | 82 | for result in detections: 83 | # scale back up to image size 84 | box = result[:4] 85 | confidence = result[4] 86 | 87 | if gaze_tools.box_in_image(box, image) and confidence > 0.6: 88 | box = [x * fraction for x in box] # scale back up 89 | diff_height_width = (box[3] - box[1]) - (box[2] - box[0]) 90 | offset_y = int(abs(diff_height_width / 2)) 91 | box_moved = gaze_tools.move_box(box, [0, offset_y]) 92 | 93 | # Make box square. 94 | facebox = gaze_tools.get_square_box(box_moved) 95 | faceboxes.append(facebox) 96 | 97 | return faceboxes 98 | 99 | @staticmethod 100 | def visualize_headpose_result(face_image, est_headpose): 101 | """Take the original face image and overlay the estimated headpose.""" 102 | output_image = np.copy(face_image) 103 | 104 | center_x = output_image.shape[1] / 2 105 | center_y = output_image.shape[0] / 2 106 | 107 | endpoint_x, endpoint_y = gaze_tools.get_endpoint(est_headpose[1], est_headpose[0], center_x, center_y, 100) 108 | 109 | cv2.line(output_image, (int(center_x), int(center_y)), (int(endpoint_x), int(endpoint_y)), (0, 0, 255), 3) 110 | return output_image 111 | 112 | def ddfa_forward_pass(self, color_img, roi_box_list): 113 | img_step = [crop_img(color_img, roi_box) for roi_box in roi_box_list] 114 | img_step = [cv2.resize(img, dsize=(120, 120), interpolation=cv2.INTER_LINEAR) for img in img_step] 115 | _input = torch.cat([facial_landmark_transform(img).unsqueeze(0) for img in img_step], 0) 116 | with torch.no_grad(): 117 | _input = _input.to(self.device) 118 | param = self.facial_landmark_nn(_input).cpu().numpy().astype(float) 119 | 120 | return [predict_68pts(p.flatten(), roi_box) for p, roi_box in zip(param, roi_box_list)] 121 | 122 | def get_subjects_from_faceboxes(self, color_img, faceboxes): 123 | face_images = [gaze_tools.crop_face_from_image(color_img, b) for b in faceboxes] 124 | subjects = [] 125 | roi_box_list = [parse_roi_box_from_bbox(facebox) for facebox in faceboxes] 126 | initial_pts68_list = self.ddfa_forward_pass(color_img, roi_box_list) 127 | roi_box_refined_list = [parse_roi_box_from_landmark(initial_pts68) for initial_pts68 in initial_pts68_list] 128 | pts68_list = self.ddfa_forward_pass(color_img, roi_box_refined_list) 129 | 130 | for pts68, face_image, facebox in zip(pts68_list, face_images, faceboxes): 131 | np_landmarks = np.array((pts68[0], pts68[1])).T 132 | subjects.append(TrackedSubject(np.array(facebox), face_image, np_landmarks)) 133 | return subjects 134 | -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/gaze_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Tobias Fischer (t.fischer@imperial.ac.uk) 3 | Licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) 4 | """ 5 | 6 | from __future__ import print_function, division, absolute_import 7 | 8
| import math 9 | 10 | import numpy as np 11 | 12 | 13 | def get_phi_theta_from_euler(euler_angles): 14 | return -euler_angles[2], -euler_angles[1] 15 | 16 | 17 | def get_euler_from_phi_theta(phi, theta): 18 | return 0, -theta, -phi 19 | 20 | 21 | def get_endpoint(theta, phi, center_x, center_y, length=300): 22 | endpoint_x = -1.0 * length * math.cos(theta) * math.sin(phi) + center_x 23 | endpoint_y = -1.0 * length * math.sin(theta) + center_y 24 | return endpoint_x, endpoint_y 25 | 26 | 27 | def visualize_landmarks(image, landmarks): 28 | import cv2 29 | 30 | output_image = np.copy(image) 31 | for landmark in landmarks.reshape(-1, 2): 32 | cv2.circle(output_image, (landmark[0], landmark[1]), 2, (0, 0, 255), -1) 33 | return output_image 34 | 35 | 36 | def limit_yaw(euler_angles_head): 37 | # [0]: pos - roll right, neg - roll left 38 | # [1]: pos - look down, neg - look up 39 | # [2]: pos - rotate left, neg - rotate right 40 | euler_angles_head[2] += np.pi 41 | if euler_angles_head[2] > np.pi: 42 | euler_angles_head[2] -= 2 * np.pi 43 | 44 | return euler_angles_head 45 | 46 | 47 | def crop_face_from_image(color_img, box): 48 | _bb = list(map(int, box)) 49 | if _bb[0] < 0: 50 | _bb[0] = 0 51 | if _bb[1] < 0: 52 | _bb[1] = 0 53 | if _bb[2] > color_img.shape[1]: 54 | _bb[2] = color_img.shape[1] 55 | if _bb[3] > color_img.shape[0]: 56 | _bb[3] = color_img.shape[0] 57 | return color_img[_bb[1]: _bb[3], _bb[0]: _bb[2]] 58 | 59 | 60 | def is_rotation_vector_stable(last_rotation_vector, current_rotation_vector): 61 | # check to see if rotation_vector is wild, if so, stop checking head positions 62 | _unit_rotation_vector = current_rotation_vector / np.linalg.norm(current_rotation_vector) 63 | _unit_last_rotation_vector = last_rotation_vector / np.linalg.norm(last_rotation_vector) 64 | _theta = np.arccos(np.dot(_unit_last_rotation_vector.reshape(3, ), _unit_rotation_vector)) 65 | # tqdm.write("Head Rotation from last frame: {:.2f}".format(_theta)) 66 | if _theta > 0.1: 67 | # we have too much rotation here, likely unstable, thus error out 68 | print('Could not estimate head pose due to instability of landmarks') 69 | return False 70 | else: 71 | return True 72 | 73 | 74 | def move_box(box, offset): 75 | """Move the box to direction specified by vector offset""" 76 | left_x = box[0] + offset[0] 77 | top_y = box[1] + offset[1] 78 | right_x = box[2] + offset[0] 79 | bottom_y = box[3] + offset[1] 80 | 81 | return [left_x, top_y, right_x, bottom_y] 82 | 83 | 84 | def box_in_image(box, image): 85 | """Check if the box is in image""" 86 | rows = image.shape[0] 87 | cols = image.shape[1] 88 | 89 | return box[0] >= 0 and box[1] >= 0 and box[2] <= cols and box[3] <= rows 90 | 91 | 92 | def get_square_box(box): 93 | """Get a square box out of the given box, by expanding it.""" 94 | left_x = box[0] 95 | top_y = box[1] 96 | right_x = box[2] 97 | bottom_y = box[3] 98 | 99 | box_width = right_x - left_x 100 | box_height = bottom_y - top_y 101 | 102 | # Check if box is already a square. If not, make it a square. 103 | diff = box_height - box_width 104 | delta = int(abs(diff) / 2) 105 | 106 | if diff == 0: # Already a square. 107 | return box 108 | elif diff > 0: # Height > width, a slim box. 109 | left_x -= delta 110 | right_x += delta 111 | if diff % 2 == 1: 112 | right_x += 1 113 | else: # Width > height, a short box. 
114 | top_y -= delta 115 | bottom_y += delta 116 | if diff % 2 == 1: 117 | bottom_y += 1 118 | 119 | return [left_x, top_y, right_x, bottom_y] 120 | 121 | 122 | def angle_loss(y_true, y_pred): 123 | # noinspection PyUnresolvedReferences 124 | from tensorflow.keras import backend as K 125 | return K.sum(K.square(y_pred - y_true), axis=-1) 126 | 127 | 128 | def accuracy_angle(y_true, y_pred): 129 | import tensorflow as tf 130 | 131 | pred_x = -1 * tf.cos(y_pred[0]) * tf.sin(y_pred[1]) 132 | pred_y = -1 * tf.sin(y_pred[0]) 133 | pred_z = -1 * tf.cos(y_pred[0]) * tf.cos(y_pred[1]) 134 | pred_norm = tf.sqrt(pred_x * pred_x + pred_y * pred_y + pred_z * pred_z) 135 | 136 | true_x = -1 * tf.cos(y_true[0]) * tf.sin(y_true[1]) 137 | true_y = -1 * tf.sin(y_true[0]) 138 | true_z = -1 * tf.cos(y_true[0]) * tf.cos(y_true[1]) 139 | true_norm = tf.sqrt(true_x * true_x + true_y * true_y + true_z * true_z) 140 | 141 | angle_value = (pred_x * true_x + pred_y * true_y + pred_z * true_z) / (true_norm * pred_norm) 142 | angle_value = tf.clip_by_value(angle_value, -0.9999999999, 0.999999999)  # assign the clipped result, otherwise the clip is a no-op and tf.acos can return NaN 143 | return (tf.acos(angle_value) * 180.0) / np.pi 144 | 145 | 146 | def get_normalised_eye_landmarks(landmarks, box): 147 | eye_indices = np.array([36, 39, 42, 45]) 148 | transformed_landmarks = landmarks[eye_indices] 149 | transformed_landmarks[:, 0] -= box[0] 150 | transformed_landmarks[:, 1] -= box[1] 151 | return transformed_landmarks 152 | -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/gaze_tools_standalone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2006, Christoph Gohlke 2 | # Copyright (c) 2006-2009, The Regents of the University of California 3 | # All rights reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are met: 7 | # 8 | # * Redistributions of source code must retain the above copyright 9 | # notice, this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # * Neither the name of the copyright holders nor the names of any 14 | # contributors may be used to endorse or promote products derived 15 | # from this software without specific prior written permission. 16 | # 17 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 21 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | # POSSIBILITY OF SUCH DAMAGE.
28 | 29 | 30 | import math 31 | import numpy 32 | 33 | # map axes strings to/from tuples of inner axis, parity, repetition, frame 34 | _AXES2TUPLE = { 35 | 'sxyz': (0, 0, 0, 0), 'sxyx': (0, 0, 1, 0), 'sxzy': (0, 1, 0, 0), 36 | 'sxzx': (0, 1, 1, 0), 'syzx': (1, 0, 0, 0), 'syzy': (1, 0, 1, 0), 37 | 'syxz': (1, 1, 0, 0), 'syxy': (1, 1, 1, 0), 'szxy': (2, 0, 0, 0), 38 | 'szxz': (2, 0, 1, 0), 'szyx': (2, 1, 0, 0), 'szyz': (2, 1, 1, 0), 39 | 'rzyx': (0, 0, 0, 1), 'rxyx': (0, 0, 1, 1), 'ryzx': (0, 1, 0, 1), 40 | 'rxzx': (0, 1, 1, 1), 'rxzy': (1, 0, 0, 1), 'ryzy': (1, 0, 1, 1), 41 | 'rzxy': (1, 1, 0, 1), 'ryxy': (1, 1, 1, 1), 'ryxz': (2, 0, 0, 1), 42 | 'rzxz': (2, 0, 1, 1), 'rxyz': (2, 1, 0, 1), 'rzyz': (2, 1, 1, 1)} 43 | 44 | _TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items()) 45 | 46 | # axis sequences for Euler angles 47 | _NEXT_AXIS = [1, 2, 0, 1] 48 | 49 | # epsilon for testing whether a number is close to zero 50 | _EPS = numpy.finfo(float).eps * 4.0 51 | 52 | 53 | def euler_from_matrix(matrix, axes='sxyz'): 54 | try: 55 | firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()] 56 | except (AttributeError, KeyError): 57 | _ = _TUPLE2AXES[axes] 58 | firstaxis, parity, repetition, frame = axes 59 | 60 | i = firstaxis 61 | j = _NEXT_AXIS[i+parity] 62 | k = _NEXT_AXIS[i-parity+1] 63 | 64 | M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:3, :3] 65 | if repetition: 66 | sy = math.sqrt(M[i, j]*M[i, j] + M[i, k]*M[i, k]) 67 | if sy > _EPS: 68 | ax = math.atan2( M[i, j], M[i, k]) 69 | ay = math.atan2( sy, M[i, i]) 70 | az = math.atan2( M[j, i], -M[k, i]) 71 | else: 72 | ax = math.atan2(-M[j, k], M[j, j]) 73 | ay = math.atan2( sy, M[i, i]) 74 | az = 0.0 75 | else: 76 | cy = math.sqrt(M[i, i]*M[i, i] + M[j, i]*M[j, i]) 77 | if cy > _EPS: 78 | ax = math.atan2( M[k, j], M[k, k]) 79 | ay = math.atan2(-M[k, i], cy) 80 | az = math.atan2( M[j, i], M[i, i]) 81 | else: 82 | ax = math.atan2(-M[j, k], M[j, j]) 83 | ay = math.atan2(-M[k, i], cy) 84 | az = 0.0 85 | 86 | if parity: 87 | ax, ay, az = -ax, -ay, -az 88 | if frame: 89 | ax, az = az, ax 90 | return ax, ay, az 91 | -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/kalman_stabilizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2017 Yin Guobing 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | """ 24 | 25 | import numpy as np 26 | import cv2 27 | 28 | """ 29 | Code from https://github.com/yinguobing/head-pose-estimation 30 | Using Kalman Filter as a point stabilizer to stabilize a 2D point. 31 | """ 32 | 33 | 34 | class Stabilizer(object): 35 | """Using Kalman filter as a point stabilizer.""" 36 | 37 | def __init__(self, 38 | state_num=4, 39 | measure_num=2, 40 | cov_process=0.0001, 41 | cov_measure=0.1): 42 | """Initialization""" 43 | # Currently we only support scalar and point, so check user input first. 44 | assert state_num == 4 or state_num == 2, "Only scalar and point supported, Check state_num please." 45 | 46 | # Store the parameters. 47 | self.state_num = state_num 48 | self.measure_num = measure_num 49 | 50 | # The filter itself. 51 | self.filter = cv2.KalmanFilter(state_num, measure_num, 0) 52 | 53 | # Store the state. 54 | self.state = np.zeros((state_num, 1), dtype=float) 55 | 56 | # Store the measurement result. 57 | self.measurement = np.array((measure_num, 1), float) 58 | 59 | # Store the prediction. 60 | self.prediction = np.zeros((state_num, 1), float) 61 | 62 | # Kalman parameters setup for scalar. 63 | if self.measure_num == 1: 64 | self.filter.transitionMatrix = np.array([[1, 1], 65 | [0, 1]], float) 66 | 67 | self.filter.measurementMatrix = np.array([[1, 1]], float) 68 | 69 | self.filter.processNoiseCov = np.array([[1, 0], 70 | [0, 1]], float) * cov_process 71 | 72 | self.filter.measurementNoiseCov = np.array( 73 | [[1]], float) * cov_measure 74 | 75 | # Kalman parameters setup for point. 76 | if self.measure_num == 2: 77 | self.filter.transitionMatrix = np.array([[1, 0, 1, 0], 78 | [0, 1, 0, 1], 79 | [0, 0, 1, 0], 80 | [0, 0, 0, 1]], float) 81 | 82 | self.filter.measurementMatrix = np.array([[1, 0, 0, 0], 83 | [0, 1, 0, 0]], float) 84 | 85 | self.filter.processNoiseCov = np.array([[1, 0, 0, 0], 86 | [0, 1, 0, 0], 87 | [0, 0, 1, 0], 88 | [0, 0, 0, 1]], float) * cov_process 89 | 90 | self.filter.measurementNoiseCov = np.array([[1, 0], 91 | [0, 1]], float) * cov_measure 92 | 93 | def update(self, measurement): 94 | """Update the filter""" 95 | # Make kalman prediction 96 | self.prediction = self.filter.predict() 97 | 98 | # Get new measurement 99 | if self.measure_num == 1: 100 | self.measurement = np.array([[float(measurement[0])]]) 101 | else: 102 | self.measurement = np.array([[float(measurement[0])], 103 | [float(measurement[1])]]) 104 | 105 | # Correct according to mesurement 106 | self.filter.correct(self.measurement) 107 | 108 | # Update state value. 
109 | self.state = self.filter.statePost 110 | 111 | def set_q_r(self, cov_process=0.1, cov_measure=0.001): 112 | """Set new value for processNoiseCov and measurementNoiseCov.""" 113 | if self.measure_num == 1: 114 | self.filter.processNoiseCov = np.array([[1, 0], 115 | [0, 1]], float) * cov_process 116 | self.filter.measurementNoiseCov = np.array( 117 | [[1]], float) * cov_measure 118 | else: 119 | self.filter.processNoiseCov = np.array([[1, 0, 0, 0], 120 | [0, 1, 0, 0], 121 | [0, 0, 1, 0], 122 | [0, 0, 0, 1]], float) * cov_process 123 | self.filter.measurementNoiseCov = np.array([[1, 0], 124 | [0, 1]], float) * cov_measure 125 | -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/ros_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | camera_to_ros = [[0.0, 0.0, 1.0, 0.0], 4 | [-1.0, 0.0, 0.0, 0.0], 5 | [0.0, -1.0, 0.0, 0.0], 6 | [0.0, 0.0, 0.0, 1.0]] 7 | 8 | ros_to_camera = [[0.0, -1.0, 0.0, 0.0], 9 | [0.0, 0.0, -1.0, 0.0], 10 | [1.0, 0.0, 0.0, 0.0], 11 | [0.0, 0.0, 0.0, 1.0]] 12 | 13 | 14 | def position_ros_to_tf(ros_position): 15 | return np.array([ros_position.x, ros_position.y, ros_position.z]) 16 | 17 | 18 | def position_tf_to_ros(tf_position): 19 | from geometry_msgs.msg import Point 20 | return Point(tf_position[0], tf_position[1], tf_position[2]) 21 | 22 | 23 | def quaternion_ros_to_tf(ros_quaternion): 24 | return np.array([ros_quaternion.x, ros_quaternion.y, ros_quaternion.z, ros_quaternion.w]) 25 | 26 | 27 | def quaternion_tf_to_ros(tf_quaternion): 28 | from geometry_msgs.msg import Quaternion 29 | return Quaternion(tf_quaternion[0], tf_quaternion[1], tf_quaternion[2], tf_quaternion[3]) 30 | 31 | 32 | def geometry_to_tuple(geometry_msg): 33 | return geometry_msg.x, geometry_msg.y, geometry_msg.z 34 | 35 | 36 | def convert_image(msg, desired_encoding="passthrough", ignore_invalid_depth=False): 37 | from cv_bridge import CvBridge 38 | 39 | type_as_str = str(type(msg)) 40 | if type_as_str.find('sensor_msgs.msg._CompressedImage.CompressedImage') >= 0 \ 41 | or type_as_str.find('_sensor_msgs__CompressedImage') >= 0: 42 | try: 43 | _, compr_type = msg.format.split(';') 44 | except ValueError: 45 | compr_type = '' 46 | if compr_type.strip() == 'tiff compressed': 47 | if ignore_invalid_depth: 48 | bridge = CvBridge() 49 | return bridge.compressed_imgmsg_to_cv2(msg, desired_encoding=desired_encoding) 50 | else: 51 | raise Exception('tiff compressed is not supported') 52 | else: 53 | bridge = CvBridge() 54 | return bridge.compressed_imgmsg_to_cv2(msg, desired_encoding=desired_encoding) 55 | else: 56 | bridge = CvBridge() 57 | return bridge.imgmsg_to_cv2(msg, desired_encoding=desired_encoding) 58 | -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/subject_ros_bridge.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Kevin Cortacero 3 | Licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) 4 | """ 5 | 6 | from rt_gene.msg import MSG_SubjectImagesList, MSG_SubjectImages 7 | 8 | from cv_bridge import CvBridge 9 | 10 | 11 | class SubjectImages(object): 12 | def __init__(self, s_id): 13 | self.id = s_id 14 | self.face = None 15 | self.right = None 16 | self.left = None 17 | 18 | 19 | class SubjectBridge(object): 20 | def __init__(self): 21 | self.__cv_bridge = CvBridge() 22 | 23 | def 
msg_to_images(self, subject_msg): 24 | subject = SubjectImages(subject_msg.subject_id) 25 | subject.face = self.__cv_bridge.imgmsg_to_cv2(subject_msg.face_img, "rgb8") 26 | subject.right = self.__cv_bridge.imgmsg_to_cv2(subject_msg.right_eye_img, "rgb8") 27 | subject.left = self.__cv_bridge.imgmsg_to_cv2(subject_msg.left_eye_img, "rgb8") 28 | return subject 29 | 30 | def images_to_msg(self, subject_id, subject): 31 | msg = MSG_SubjectImages() 32 | msg.subject_id = subject_id 33 | msg.face_img = self.__cv_bridge.cv2_to_imgmsg(subject.face_color, "rgb8") 34 | msg.right_eye_img = self.__cv_bridge.cv2_to_imgmsg(subject.right_eye_color, "rgb8") 35 | msg.left_eye_img = self.__cv_bridge.cv2_to_imgmsg(subject.left_eye_color, "rgb8") 36 | return msg 37 | 38 | 39 | class SubjectListBridge(object): 40 | def __init__(self): 41 | self.__subject_bridge = SubjectBridge() 42 | 43 | def msg_to_images(self, subject_msg): 44 | subject_dict = dict() 45 | for s in subject_msg.subjects: 46 | subject_dict[s.subject_id] = self.__subject_bridge.msg_to_images(s) 47 | return subject_dict 48 | 49 | def images_to_msg(self, subject_dict, timestamp): 50 | msg = MSG_SubjectImagesList() 51 | msg.header.stamp = timestamp 52 | for subject_id, s in subject_dict.items(): 53 | try: 54 | msg.subjects.append(self.__subject_bridge.images_to_msg(subject_id, s)) 55 | except TypeError: 56 | pass 57 | 58 | return msg 59 | -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/tracker_face_encoding.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Kevin Cortacero 3 | @Ahmed Al-Hindawi 4 | Licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) 5 | """ 6 | 7 | from __future__ import print_function 8 | 9 | import cv2 10 | import dlib 11 | import numpy as np 12 | import rospkg 13 | import rospy 14 | import scipy.optimize 15 | 16 | from rt_gene.gaze_tools import get_normalised_eye_landmarks 17 | from .tracker_generic import GenericTracker 18 | 19 | 20 | class FaceEncodingTracker(GenericTracker): 21 | FACE_ENCODER = dlib.face_recognition_model_v1( 22 | rospkg.RosPack().get_path('rt_gene') + '/model_nets/dlib_face_recognition_resnet_model_v1.dat') 23 | 24 | def __init__(self): 25 | super(FaceEncodingTracker, self).__init__() 26 | self.__encoding_list = {} 27 | self.__threshold = float(rospy.get_param("~face_encoding_threshold", default=0.6)) 28 | 29 | def __encode_subject(self, tracked_element): 30 | # get the face_color and face_chip it using the transformed_eye_landmarks 31 | eye_landmarks = get_normalised_eye_landmarks(tracked_element.landmarks, tracked_element.box) 32 | # Get the width of the eye, and compute how big the margin should be according to the width 33 | lefteye_width = eye_landmarks[3][0] - eye_landmarks[2][0] 34 | righteye_width = eye_landmarks[1][0] - eye_landmarks[0][0] 35 | 36 | lefteye_center_x = eye_landmarks[2][0] + lefteye_width / 2 37 | righteye_center_x = eye_landmarks[0][0] + righteye_width / 2 38 | lefteye_center_y = (eye_landmarks[2][1] + eye_landmarks[3][1]) / 2.0 39 | righteye_center_y = (eye_landmarks[1][1] + eye_landmarks[0][1]) / 2.0 40 | aligned_face, rot_matrix = GenericTracker.align_face_to_eyes(tracked_element.face_color, 41 | right_eye_center=(righteye_center_x, righteye_center_y), 42 | left_eye_center=(lefteye_center_x, lefteye_center_y), 43 | face_width=150, 44 | face_height=150) 45 | encoding = 
self.FACE_ENCODER.compute_face_descriptor(aligned_face) 46 | return encoding 47 | 48 | def __add_new_element(self, element): 49 | # encode the new array 50 | found_id = None 51 | 52 | encoding = np.array(self.__encode_subject(element)) 53 | # check to see if we've seen it before 54 | list_to_check = list(set(self.__encoding_list.keys()) - set(self._tracked_elements.keys())) 55 | 56 | for untracked_encoding_id in list_to_check: 57 | previous_encoding = self.__encoding_list[untracked_encoding_id] 58 | previous_encoding = np.fromstring(previous_encoding[1:-1], dtype=float, sep=",") 59 | distance = np.linalg.norm(previous_encoding - encoding, axis=0) 60 | 61 | # the new element and the previous encoding are the same person 62 | if distance < self.__threshold: 63 | self._tracked_elements[untracked_encoding_id] = element 64 | found_id = untracked_encoding_id 65 | break 66 | 67 | if found_id is None: 68 | found_id = self._generate_new_id() 69 | self._tracked_elements[found_id] = element 70 | 71 | self.__encoding_list[found_id] = np.array2string(encoding, formatter={'float_kind': lambda x: "{:.5f}".format(x)}, separator=",") 72 | 73 | return found_id 74 | 75 | def __update_element(self, element_id, element): 76 | self._tracked_elements[element_id] = element 77 | 78 | # (can be overridden if necessary) 79 | def _generate_new_id(self): 80 | self._i += 1 81 | return str(self._i) 82 | 83 | def get_tracked_elements(self): 84 | return self._tracked_elements 85 | 86 | def clear_elements(self): 87 | self._tracked_elements.clear() 88 | 89 | def track(self, new_elements): 90 | # if no elements yet, just add all the new ones 91 | if not self._tracked_elements: 92 | for e in new_elements: 93 | try: 94 | self.__add_new_element(e) 95 | except cv2.error: 96 | pass 97 | return 98 | 99 | current_tracked_element_ids = self._tracked_elements.keys() 100 | updated_tracked_element_ids = [] 101 | distance_matrix, map_index_to_id = self.get_distance_matrix(new_elements) 102 | 103 | # get best matching pairs with Hungarian Algorithm 104 | col, row = scipy.optimize.linear_sum_assignment(distance_matrix) 105 | 106 | # assign each new element to existing one or store it as new 107 | for j, new_element in enumerate(new_elements): 108 | row_list = row.tolist() 109 | if j in row_list: 110 | # find the index of the column matching 111 | row_idx = row_list.index(j) 112 | 113 | match_idx = col[row_idx] 114 | # if the new element matches with existing old one 115 | _new_idx = map_index_to_id[match_idx] 116 | self.__update_element(_new_idx, new_element) 117 | updated_tracked_element_ids.append(_new_idx) 118 | else: 119 | try: 120 | _new_idx = self.__add_new_element(new_element) 121 | updated_tracked_element_ids.append(_new_idx) 122 | except cv2.error: 123 | pass 124 | 125 | # store non-tracked elements in-case they reappear 126 | elements_to_delete = list(set(current_tracked_element_ids) - set(updated_tracked_element_ids)) 127 | for i in elements_to_delete: 128 | # don't track it anymore 129 | del self._tracked_elements[i] 130 | -------------------------------------------------------------------------------- /rt_gene/src/rt_gene/tracker_sequential.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Kevin Cortacero 3 | Licensed under Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) 4 | """ 5 | 6 | import rospy 7 | from scipy import optimize 8 | import scipy 9 | from .tracker_generic import GenericTracker 
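# Note: unlike FaceEncodingTracker (tracker_face_encoding.py), which stores face encodings so that
# subjects who leave and re-enter the frame can be re-identified, SequentialTracker simply assigns
# incremental IDs and deletes subjects that are no longer matched. Frame-to-frame matching uses the
# Hungarian algorithm (scipy.optimize.linear_sum_assignment) on the distance matrix provided by
# GenericTracker.get_distance_matrix().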
10 | 11 | 12 | class SequentialTracker(GenericTracker): 13 | def __init__(self): 14 | super(SequentialTracker, self).__init__() 15 | rospy.logwarn("** SequentialTracker is no longer supported, please use the FaceEncodingTracker instead") 16 | 17 | ''' --------------------------------------------------------------------''' 18 | ''' PRIVATE METHODS ''' 19 | 20 | def __add_new_element(self, element): 21 | new_id = self._generate_unique_id() 22 | self._tracked_elements[new_id] = element 23 | return new_id 24 | 25 | def __update_element(self, element_id, element): 26 | self._tracked_elements[element_id] = element 27 | 28 | ''' --------------------------------------------------------------------''' 29 | ''' PROTECTED METHODS ''' 30 | 31 | # (can be overridden if necessary) 32 | def _generate_unique_id(self): 33 | self._i += 1 34 | return str(self._i) 35 | 36 | ''' --------------------------------------------------------------------''' 37 | ''' PUBLIC METHODS ''' 38 | 39 | def get_tracked_elements(self): 40 | return self._tracked_elements 41 | 42 | def clear_elements(self): 43 | self._tracked_elements.clear() 44 | 45 | def track(self, new_elements): 46 | # if no elements yet, just add all the new ones 47 | if len(self._tracked_elements) == 0: 48 | [self.__add_new_element(e) for e in new_elements] 49 | return 50 | 51 | current_tracked_element_ids = self._tracked_elements.keys() 52 | updated_tracked_element_ids = [] 53 | distance_matrix, map_index_to_id = self.get_distance_matrix(new_elements) 54 | 55 | # get best matching pairs with Hungarian Algorithm 56 | col, row = scipy.optimize.linear_sum_assignment(distance_matrix) 57 | 58 | # assign each new element to existing one or store it as new 59 | for j, new_element in enumerate(new_elements): 60 | row_list = row.tolist() 61 | if j in row_list: 62 | match_idx = col[row_list.index(j)] 63 | # if the new element matches with existing old one 64 | matched_element_id = map_index_to_id[match_idx] 65 | self.__update_element(matched_element_id, new_element) 66 | _new_idx = matched_element_id 67 | 68 | else: 69 | # if the new element is not matching 70 | _new_idx = self.__add_new_element(new_element) 71 | updated_tracked_element_ids.append(_new_idx) 72 | 73 | # delete all the non-updated elements 74 | elements_to_delete = list(set(current_tracked_element_ids) - set(updated_tracked_element_ids)) 75 | for i in elements_to_delete: 76 | del self._tracked_elements[i] 77 | -------------------------------------------------------------------------------- /rt_gene/webcam_configs/kinect2_calibration.yaml: -------------------------------------------------------------------------------- 1 | image_width: 1920 2 | image_height: 1080 3 | camera_name: kinect2 4 | camera_matrix: 5 | rows: 3 6 | cols: 3 7 | data: [1055.2418494342428, 0.0, 959.5498558958496, 0.0, 1055.800634897113, 526.8365066435917, 0.0, 0.0, 1.0] 8 | distortion_model: plumb_bob 9 | distortion_coefficients: 10 | rows: 1 11 | cols: 5 12 | data: [0.06782730293229537, -0.10186826507839701, -0.003340617929219049, 0.00013233339185125118, 0.04009587588544287] 13 | rectification_matrix: 14 | rows: 3 15 | cols: 3 16 | data: [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0] 17 | projection_matrix: 18 | rows: 3 19 | cols: 4 20 | data: [1055.2418494342428, 0.0, 959.5498558958496, 0.0, 0.0, 1055.800634897113, 526.8365066435917, 0.0, 0.0, 0.0, 1.0, 0.0] 21 | -------------------------------------------------------------------------------- /rt_gene/webcam_configs/webcam_blue_26010230.yaml: 
-------------------------------------------------------------------------------- 1 | image_width: 640 2 | image_height: 480 3 | camera_name: camera 4 | camera_matrix: 5 | rows: 3 6 | cols: 3 7 | data: [1130.394061179079, 0, 308.5144470973464, 0, 1119.415912901594, 250.0097790276739, 0, 0, 1] 8 | distortion_model: plumb_bob 9 | distortion_coefficients: 10 | rows: 1 11 | cols: 5 12 | data: [-0.46815344403296, 0.2834124079090991, 0.005041042335607818, 0.004870246512842942, 0] 13 | rectification_matrix: 14 | rows: 3 15 | cols: 3 16 | data: [1, 0, 0, 0, 1, 0, 0, 0, 1] 17 | projection_matrix: 18 | rows: 3 19 | cols: 4 20 | data: [1086.979248046875, 0, 309.1242169653051, 0, 0, 1094.509887695312, 251.3118299474172, 0, 0, 0, 1, 0] -------------------------------------------------------------------------------- /rt_gene_inpainting/GAN_train_run.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from GAN_train import GAN_train\n", 10 | "\n", 11 | "import tensorflow as tf\n", 12 | "\n", 13 | "config = tf.compat.v1.ConfigProto()\n", 14 | "config.gpu_options.visible_device_list = \"0\"\n", 15 | "# config.gpu_options.allow_growth = True\n", 16 | "tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))\n", 17 | "\n", 18 | "if __name__ == '__main__':\n", 19 | " dataset_folder_path = '/recordings_hdd/'\n", 20 | " subject = 's000'\n", 21 | " gan_train = GAN_train(dataset_folder_path, subject)\n", 22 | "\n", 23 | " gan_train.train(num_epoch=100, batch_size=96, save_interval=1)" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [] 32 | } 33 | ], 34 | "metadata": { 35 | "kernelspec": { 36 | "display_name": "Python 2", 37 | "language": "python", 38 | "name": "python2" 39 | }, 40 | "language_info": { 41 | "codemirror_mode": { 42 | "name": "ipython", 43 | "version": 2 44 | }, 45 | "file_extension": ".py", 46 | "mimetype": "text/x-python", 47 | "name": "python", 48 | "nbconvert_exporter": "python", 49 | "pygments_lexer": "ipython2", 50 | "version": "2.7.15+" 51 | } 52 | }, 53 | "nbformat": 4, 54 | "nbformat_minor": 2 55 | } -------------------------------------------------------------------------------- /rt_gene_inpainting/GlassesCompletion_run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from GlassesCompletion import GlassesCompletion 4 | 5 | import tensorflow as tf 6 | 7 | tf.compat.v1.disable_eager_execution() 8 | 9 | config = tf.compat.v1.ConfigProto() 10 | config.gpu_options.visible_device_list = "0" 11 | config.gpu_options.allow_growth = True 12 | # config.gpu_options.per_process_gpu_memory_fraction = 1.0 13 | 14 | tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config)) 15 | 16 | if __name__ == '__main__': 17 | dataset_folder_path = '/recordings_hdd/' 18 | subject = 's000' 19 | completion = GlassesCompletion(dataset_folder_path, subject) 20 | completion.image_completion_random_search(nIter=1000, GPU_ID=config.gpu_options.visible_device_list) 21 | 22 | -------------------------------------------------------------------------------- /rt_gene_inpainting/README.md: -------------------------------------------------------------------------------- 1 | # RT-GENE: Real-Time Eye Gaze Estimation in Natural Environments 2 | [![License: CC BY-NC-SA 
4.0](https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg?style=flat-square)](https://creativecommons.org/licenses/by-nc-sa/4.0/) 3 | ![stars](https://img.shields.io/github/stars/Tobias-Fischer/rt_gene.svg?style=flat-square) 4 | ![GitHub issues](https://img.shields.io/github/issues/Tobias-Fischer/rt_gene.svg?style=flat-square) 5 | ![GitHub repo size](https://img.shields.io/github/repo-size/Tobias-Fischer/rt_gene.svg?style=flat-square) 6 | 7 | 8 | ![Inpainting example](../assets/inpaint_example.jpg) 9 | 10 | ![Inpainting overview](../assets/inpaint_overview.jpg) 11 | 12 | 13 | ## License + Attribution 14 | This code is licensed under [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/). Commercial usage is not permitted; please contact the authors regarding commercial licensing. If you use this dataset or the code in a scientific publication, please cite the following [paper](http://openaccess.thecvf.com/content_ECCV_2018/html/Tobias_Fischer_RT-GENE_Real-Time_Eye_ECCV_2018_paper.html): 15 | 16 | ``` 17 | @inproceedings{FischerECCV2018, 18 | author = {Tobias Fischer and Hyung Jin Chang and Yiannis Demiris}, 19 | title = {{RT-GENE: Real-Time Eye Gaze Estimation in Natural Environments}}, 20 | booktitle = {European Conference on Computer Vision}, 21 | year = {2018}, 22 | month = {September}, 23 | pages = {339--357} 24 | } 25 | ``` 26 | 27 | This work was supported in part by the Samsung Global Research Outreach program, and in part by the EU Horizon 2020 Project PAL (643783-RIA). 28 | 29 | More information can be found on the Personal Robotic Lab's website. 30 | 31 | ## Requirements 32 | - pip: `pip install tensorflow-gpu keras numpy 'scipy<=1.2.1' tqdm matplotlib pyamg` 33 | - conda: `conda install tensorflow-gpu keras numpy 'scipy<=1.2.1' tqdm matplotlib pyamg` 34 | 35 | ## Inpainting source code 36 | This code was used to inpaint the region covered by the eyetracking glasses. There are two parts: 37 | 1) training subject-specific GANs using the images where no eyetracking glasses are worn (`GAN_train.py` and `GAN_train_run.ipynb`), and 38 | 2) the actual inpainting using the trained GANs (`GlassesCompletion.py` and `GlassesCompletion_run.py`). 39 | 40 | In `GAN_train_run.ipynb` and `GlassesCompletion_run.py`, the `dataset_folder_path` needs to be adjusted to point to where the dataset was downloaded; a condensed example is shown below.
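For convenience, the two steps can be combined into a single script. The snippet below is a minimal sketch that mirrors `GAN_train_run.ipynb` and `GlassesCompletion_run.py`; the dataset path, subject ID and GPU ID are placeholders to adapt to your setup, and the TensorFlow session/GPU configuration from those scripts is omitted here.

```
from GAN_train import GAN_train
from GlassesCompletion import GlassesCompletion

dataset_folder_path = '/recordings_hdd/'  # adjust to where the RT-GENE dataset was downloaded
subject = 's000'                          # one subject-specific GAN is trained per subject

# 1) Train the subject-specific GAN on the images recorded without eyetracking glasses
gan_train = GAN_train(dataset_folder_path, subject)
gan_train.train(num_epoch=100, batch_size=96, save_interval=1)

# 2) Inpaint the region covered by the eyetracking glasses using the trained GAN
completion = GlassesCompletion(dataset_folder_path, subject)
completion.image_completion_random_search(nIter=1000, GPU_ID="0")
```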
41 | 42 | ## List of libraries 43 | - [./external/poissonblending.py](./external/poissonblending.py): [MIT License](https://opensource.org/licenses/MIT); [Link to GitHub](https://github.com/parosky/poissonblending) 44 | - Some code taken from [DC-GAN](https://github.com/Newmu/dcgan_code): [MIT License](https://github.com/Newmu/dcgan_code/blob/master/LICENSE); [Link to GitHub](https://github.com/Newmu/dcgan_code) 45 | - Tensorflow; [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0), [Link to website](http://tensorflow.org/) 46 | - Keras; [MIT License](https://opensource.org/licenses/MIT), [Link to website](https://keras.io) 47 | 48 | -------------------------------------------------------------------------------- /rt_gene_inpainting/external/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Parosky 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /rt_gene_inpainting/external/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_gene_inpainting/external/__init__.py -------------------------------------------------------------------------------- /rt_gene_inpainting/external/poissonblending.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # 5 | # Original code from https://github.com/parosky/poissonblending 6 | 7 | import numpy as np 8 | import scipy.sparse 9 | import pyamg 10 | 11 | 12 | # pre-process the mask array so that uint64 types from opencv.imread can be adapted 13 | def prepare_mask(mask): 14 | if type(mask[0][0]) is np.ndarray: 15 | result = np.ndarray((mask.shape[0], mask.shape[1]), dtype=np.uint8) 16 | for i in range(mask.shape[0]): 17 | for j in range(mask.shape[1]): 18 | if sum(mask[i][j]) > 0: 19 | result[i][j] = 1 20 | else: 21 | result[i][j] = 0 22 | mask = result 23 | return mask 24 | 25 | 26 | def blend(img_target, img_source, img_mask, offset=(0, 0)): 27 | # compute regions to be blended 28 | region_source = ( 29 | max(-offset[0], 0), 30 | max(-offset[1], 0), 31 | min(img_target.shape[0] - offset[0], img_source.shape[0]), 32 | min(img_target.shape[1] - offset[1], img_source.shape[1])) 33 | region_target = ( 34 | max(offset[0], 0), 35 | max(offset[1], 0), 36 | min(img_target.shape[0], img_source.shape[0] + offset[0]), 37 | min(img_target.shape[1], img_source.shape[1] + offset[1])) 38 | region_size = (region_source[2] - region_source[0], region_source[3] - region_source[1]) 39 | 40 | # clip and normalize mask image 41 | img_mask = img_mask[region_source[0]:region_source[2], region_source[1]:region_source[3]] 42 | img_mask = prepare_mask(img_mask) 43 | img_mask[img_mask == 0] = False 44 | img_mask[img_mask != False] = True 45 | 46 | # create coefficient matrix 47 | A = scipy.sparse.identity(np.prod(region_size), format='lil') 48 | for y in range(region_size[0]): 49 | for x in range(region_size[1]): 50 | if img_mask[y, x]: 51 | index = x + y * region_size[1] 52 | A[index, index] = 4 53 | if index + 1 < np.prod(region_size): 54 | A[index, index + 1] = -1 55 | if index - 1 >= 0: 56 | A[index, index - 1] = -1 57 | if index + region_size[1] < np.prod(region_size): 58 | A[index, index + region_size[1]] = -1 59 | if index - region_size[1] >= 0: 60 | A[index, index - region_size[1]] = -1 61 | A = A.tocsr() 62 | 63 | # create poisson matrix for b 64 | P = pyamg.gallery.poisson(img_mask.shape) 65 | 66 | # for each layer (ex. 
RGB) 67 | for num_layer in range(img_target.shape[2]): 68 | # get subimages 69 | t = img_target[region_target[0]:region_target[2], region_target[1]:region_target[3], num_layer] 70 | s = img_source[region_source[0]:region_source[2], region_source[1]:region_source[3], num_layer] 71 | t = t.flatten() 72 | s = s.flatten() 73 | 74 | # create b 75 | b = P * s 76 | for y in range(region_size[0]): 77 | for x in range(region_size[1]): 78 | if not img_mask[y, x]: 79 | index = x + y * region_size[1] 80 | b[index] = t[index] 81 | 82 | # solve Ax = b 83 | x = pyamg.solve(A, b, verb=False, tol=1e-10) 84 | 85 | # assign x to target image 86 | x = np.reshape(x, region_size) 87 | x[x > 255] = 255 88 | x[x < 0] = 0 89 | x = np.array(x, img_target.dtype) 90 | img_target[region_target[0]:region_target[2], region_target[1]:region_target[3], num_layer] = x 91 | 92 | return img_target 93 | -------------------------------------------------------------------------------- /rt_gene_inpainting/models.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.models import Model 2 | from tensorflow.keras.layers import Input, Dense, Activation, Flatten, Reshape 3 | from tensorflow.keras.layers import Conv2D, Conv2DTranspose 4 | from tensorflow.keras.layers import LeakyReLU, Dropout 5 | from tensorflow.keras import initializers 6 | from tensorflow.keras import backend as K 7 | from tensorflow.keras.optimizers import Adam 8 | 9 | 10 | def set_trainability(model, trainable=False): 11 | model.trainable = trainable 12 | for layer in model.layers: 13 | layer.trainable = trainable 14 | 15 | 16 | # LSGAN Model 17 | class LSGAN_Model(object): 18 | def __init__(self, img_rows=28, img_cols=28, channel=1, noise_dim=100, dataset='MNIST'): 19 | 20 | self.dataset = dataset 21 | self.img_rows = img_rows 22 | self.img_cols = img_cols 23 | self.channel = channel 24 | self.noise_dim = noise_dim 25 | 26 | self.D = None # discriminator 27 | self.G = None # generator 28 | self.AM = None # adversarial model 29 | self.DM = None 30 | 31 | self.optimizer = Adam(lr=0.00005, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0) 32 | 33 | def discriminator(self): 34 | if self.D: 35 | return self.D 36 | 37 | # kern_init = initializers.RandomNormal(mean=0.0, stddev=0.02, seed=None) 38 | kern_init = initializers.glorot_normal() 39 | 40 | input_shape = (self.img_rows, self.img_cols, self.channel) 41 | input_img = Input(shape=input_shape, name='Input_Image') 42 | 43 | x = Conv2D(16, 5, strides=2, input_shape=input_shape, padding='same', kernel_initializer=kern_init)(input_img) 44 | # x = BatchNormalization()(x) 45 | x = LeakyReLU(alpha=0.2)(x) 46 | 47 | x = Conv2D(32, 5, strides=2, padding='same', kernel_initializer=kern_init)(x) 48 | # x = BatchNormalization()(x) 49 | x = LeakyReLU(alpha=0.2)(x) 50 | 51 | x = Conv2D(64, 5, strides=2, padding='same', kernel_initializer=kern_init)(x) 52 | # x = BatchNormalization()(x) 53 | x = LeakyReLU(alpha=0.2)(x) 54 | 55 | x = Conv2D(128, 5, strides=2, padding='same', kernel_initializer=kern_init)(x) 56 | # x = BatchNormalization()(x) 57 | x = LeakyReLU(alpha=0.2)(x) 58 | 59 | x = Conv2D(256, 5, strides=2, padding='same', kernel_initializer=kern_init)(x) 60 | # x = BatchNormalization()(x) 61 | x = LeakyReLU(alpha=0.2)(x) 62 | 63 | x = Conv2D(512, 5, strides=2, padding='same', kernel_initializer=kern_init)(x) 64 | # x = BatchNormalization()(x) 65 | x = LeakyReLU(alpha=0.2)(x) 66 | 67 | # Out: 1-dim probability 68 | x = Flatten()(x) 69 | x = Dense(1, 
activation='sigmoid')(x) 70 | 71 | self.D = Model(inputs=input_img, outputs=x, name='Discriminator') 72 | 73 | self.D.summary() 74 | return self.D 75 | 76 | def generator(self): 77 | 78 | if self.G: 79 | return self.G 80 | 81 | kern_init = initializers.glorot_normal() 82 | 83 | input_shape = (self.noise_dim,) 84 | input_noise = Input(shape=input_shape, name='noise') 85 | 86 | dim = 7 87 | depth = 512 88 | 89 | x = Dense(dim * dim * depth, kernel_initializer=kern_init)(input_noise) 90 | # x = BatchNormalization()(x) 91 | # x = Activation('relu')(x) 92 | x = Reshape((dim, dim, depth))(x) 93 | 94 | x = Conv2DTranspose(depth / 2, 5, strides=2, padding='same', kernel_initializer=kern_init)(x) 95 | # x = BatchNormalization()(x) 96 | x = Activation('selu')(x) 97 | 98 | x = Conv2DTranspose(depth / 4, 5, strides=2, padding='same', kernel_initializer=kern_init)(x) 99 | # x = BatchNormalization()(x) 100 | x = Activation('selu')(x) 101 | 102 | x = Conv2DTranspose(depth / 8, 5, strides=2, padding='same', kernel_initializer=kern_init)(x) 103 | # x = BatchNormalization()(x) 104 | x = Activation('selu')(x) 105 | 106 | x = Conv2DTranspose(depth / 16, 5, strides=2, padding='same', kernel_initializer=kern_init)(x) 107 | # x = BatchNormalization()(x) 108 | x = Activation('selu')(x) 109 | 110 | x = Conv2DTranspose(self.channel, 5, strides=2, padding='same', kernel_initializer=kern_init)(x) 111 | x = Activation('tanh')(x) 112 | 113 | self.G = Model(inputs=input_noise, outputs=x, name='Generator') 114 | 115 | self.G.summary() 116 | return self.G 117 | 118 | def adversarial_model(self, gen, dis): 119 | if self.AM: 120 | return self.AM 121 | 122 | input_shape = (self.noise_dim,) 123 | input_noise_AM = Input(shape=input_shape, name='noise') 124 | img_fake = gen(input_noise_AM) 125 | x = dis(img_fake) 126 | out_AM = Dropout(1.0, name='out_img_fake')(x) 127 | self.AM = Model(inputs=input_noise_AM, outputs=out_AM) 128 | 129 | set_trainability(dis, False) 130 | 131 | self.AM.compile(loss=self.loss_LSGAN, optimizer=self.optimizer, metrics=['accuracy']) 132 | self.AM.summary() 133 | return self.AM 134 | 135 | def discriminator_model(self, dis): 136 | if self.DM: 137 | return self.DM 138 | 139 | input_shape = (self.img_rows, self.img_cols, self.channel) 140 | input_img = Input(shape=input_shape) 141 | x = dis(input_img) 142 | 143 | self.DM = Model(inputs=input_img, outputs=x) 144 | 145 | self.DM.compile(loss=self.loss_LSGAN, optimizer=self.optimizer, metrics=['accuracy']) 146 | return self.DM 147 | 148 | @staticmethod 149 | def loss_LSGAN(y_true, y_pred): 150 | return K.mean(K.square(y_pred-y_true), axis=-1)/2 151 | 152 | 153 | # Completion Model 154 | class Completion_Model(object): 155 | def __init__(self, noise_dim=100): 156 | self.noise_dim = noise_dim 157 | 158 | self.CL = None 159 | 160 | # complete loss = contextural loss + perceptual loss 161 | def cal_complete_loss(self, gen, dis): 162 | if self.CL: 163 | return self.CL 164 | 165 | input_shape = (self.noise_dim,) 166 | input_noise_CL = Input(shape=input_shape, name='noise') 167 | out_gen_img = gen(input_noise_CL) 168 | out_gen_img = Dropout(1.0, name='name_out_gen_img')(out_gen_img) 169 | out_dis_val = dis(out_gen_img) 170 | 171 | out_dis_val = Dropout(1.0, name='name_out_dis_val')(out_dis_val) 172 | self.CL = Model(inputs=input_noise_CL, outputs=[out_dis_val, out_gen_img]) 173 | return self.CL 174 | -------------------------------------------------------------------------------- /rt_gene_inpainting/utils.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Some codes from https://github.com/Newmu/dcgan_code 3 | # Updated: 21 Feb 2017 4 | """ 5 | from __future__ import print_function, division, absolute_import 6 | import scipy.misc 7 | import numpy as np 8 | import tensorflow as tf 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | def imread_PRL(path, is_grayscale=False): 13 | if is_grayscale: 14 | return scipy.misc.imread(path, flatten=True).astype(float) / 127.5 - 1. 15 | else: 16 | return scipy.misc.imread(path).astype(float) / 127.5 - 1. 17 | 18 | 19 | def PRL_data_image_load(data, sample_idx=0): 20 | data_files = map(lambda i: data[i], sample_idx) 21 | 22 | data = [imread_PRL(data_file, is_grayscale=False) for data_file in data_files] 23 | data_images = np.array(data).astype(float) 24 | 25 | return data_images 26 | 27 | 28 | def write_log(callback, names, logs, batch_no): 29 | for name, value in zip(names, logs): 30 | summary = tf.compat.v1.Summary() 31 | summary_value = summary.value.add() 32 | summary_value.simple_value = value 33 | summary_value.tag = name 34 | callback.writer.add_summary(summary, batch_no) 35 | callback.writer.flush() 36 | 37 | 38 | def GAN_plot_images(generator, x_train, dataset='result', save2file=False, fake=True, samples=16, noise=None, step=0, 39 | folder_path='result'): 40 | img_rows = x_train.shape[1] 41 | img_cols = x_train.shape[2] 42 | channel = x_train.shape[3] 43 | filename = dataset+'.png' 44 | if fake: 45 | if noise is None: 46 | noise = np.random.uniform(-1.0, 1.0, size=[samples, 100]) 47 | else: 48 | filename = dataset+"_%05d.png" % step 49 | images = generator.predict(noise) 50 | else: 51 | i = np.random.randint(0, x_train.shape[0], samples) 52 | images = x_train[i, :, :, :] 53 | 54 | plt.figure(figsize=(10, 10)) 55 | for i in range(images.shape[0]): 56 | plt.subplot(int(np.sqrt(samples)), int(np.sqrt(samples)), i + 1) 57 | image = (images[i, :, :, :] + 1.) / 2. 58 | 59 | if channel == 1: 60 | image = np.reshape(image, [img_rows, img_cols]) 61 | plt.imshow(image, cmap='gray') 62 | plt.axis('off') 63 | elif channel == 3: 64 | image = np.reshape(image, [img_rows, img_cols, channel]) 65 | plt.imshow(image) 66 | plt.axis('off') 67 | 68 | plt.tight_layout() 69 | if save2file: 70 | plt.savefig(folder_path+'/'+filename) 71 | plt.close('all') 72 | else: 73 | plt.show() 74 | -------------------------------------------------------------------------------- /rt_gene_model_training/README.md: -------------------------------------------------------------------------------- 1 | # RT-GENE: Real-Time Eye Gaze Estimation in Natural Environments 2 | [![License: CC BY-NC-SA 4.0](https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg?style=flat-square)](https://creativecommons.org/licenses/by-nc-sa/4.0/) 3 | ![stars](https://img.shields.io/github/stars/Tobias-Fischer/rt_gene.svg?style=flat-square) 4 | ![GitHub issues](https://img.shields.io/github/issues/Tobias-Fischer/rt_gene.svg?style=flat-square) 5 | ![GitHub repo size](https://img.shields.io/github/repo-size/Tobias-Fischer/rt_gene.svg?style=flat-square) 6 | 7 | ![Dataset Collection Setup](../assets/dataset_collection_setup.jpg) 8 | 9 | 10 | ## License + Attribution 11 | This code is licensed under [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/). Commercial usage is not permitted; please contact or regarding commercial licensing. 
If you use this dataset or the code in a scientific publication, please cite the following [paper](http://openaccess.thecvf.com/content_ECCV_2018/html/Tobias_Fischer_RT-GENE_Real-Time_Eye_ECCV_2018_paper.html): 12 | 13 | ``` 14 | @inproceedings{FischerECCV2018, 15 | author = {Tobias Fischer and Hyung Jin Chang and Yiannis Demiris}, 16 | title = {{RT-GENE: Real-Time Eye Gaze Estimation in Natural Environments}}, 17 | booktitle = {European Conference on Computer Vision}, 18 | year = {2018}, 19 | month = {September}, 20 | pages = {339--357} 21 | } 22 | ``` 23 | 24 | This work was supported in part by the Samsung Global Research Outreach program, and in part by the EU Horizon 2020 Project PAL (643783-RIA). 25 | 26 | More information can be found on the Personal Robotic Lab's website. 27 | 28 | ## Requirements 29 | - `pip install tensorflow-gpu numpy scipy tqdm matplotlib h5py scikit-learn pytorch_lightning torch torchvision Pillow` 30 | - Run the [rt_gene/scripts/download_models.py](../rt_gene/scripts/download_models.py) script to download the required model files 31 | 32 | ## Model training code (tensorflow) 33 | This code was used to train the eye gaze estimation CNN for RT-GENE. 34 | - First, the h5 files need to be created from the RAW images. We use the [prepare_dataset.m](./tensorflow/prepare_dataset.m) MATLAB script for this purpose. Please adjust the `load_path` and `save_path` variables. The `augmented` variable can be set to `0` to disable the image augmentations described in the paper. The `with_faces` variable can be set to `1` to also store the face images in the *.h5 files (warning: this requires a lot of memory). 35 | - Then, the [train_model.py](./tensorflow/train_model.py) file can be used to train the models in the 3-Fold setting as described in the paper. An example to call this script is given in the [train_models_run.sh](./tensorflow/train_models_run.sh) file. 36 | - Finally, the [evaluate_model.py](./tensorflow/evaluate_model.py) can be used to get the individual models' performance as well as the ensemble performance. An example to call this script is given in the [evaluate_models.sh](./tensorflow/evaluate_models.sh) file. 37 | 38 | ## Model training code (pytorch) 39 | - First, generate the new patches from the new RT-GENE pipeline using [GenerateEyePatchesRTGENEDataset.py](./pytorch/utils/GenerateEyePatchesRTGENEDataset.py). This will create new directories inside "RT_GENE/subject/inpainted": "left_new" and "right_new". 40 | - Compile the left_new and right_new patches along with the labels into an h5 file using [GenerateRTGENEH5Dataset.py](./pytorch/utils/GenerateRTGENEH5Dataset.py); optionally augment the patches here to make them as similar as possible to [prepare_dataset.m](./tensorflow/prepare_dataset.m) from the tensorflow preparation stage. 41 | - Run [train_model.py](./pytorch/train_model.py) on the h5 dataset generated. This will take a while. Available options can be viewed by running `train_model.py --help`.
42 | - Finally, run [post_process_ckpt.py](./pytorch/post_process_ckpt.py) on the generated `ckpt` files to turn them into models (reduces file size) 43 | 44 | ## List of libraries 45 | - Tensorflow; [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0), [Link to website](http://tensorflow.org/) 46 | -------------------------------------------------------------------------------- /rt_gene_model_training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_gene_model_training/__init__.py -------------------------------------------------------------------------------- /rt_gene_model_training/pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_gene_model_training/pytorch/__init__.py -------------------------------------------------------------------------------- /rt_gene_model_training/pytorch/evaluate_model.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from argparse import ArgumentParser 4 | from functools import partial 5 | 6 | import h5py 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | 12 | from rt_gene.gaze_estimation_models_pytorch import GazeEstimationModelResnet18, \ 13 | GazeEstimationModelVGG, GazeEstimationModelPreactResnet 14 | from rtgene_dataset import RTGENEH5Dataset 15 | from utils.GazeAngleAccuracy import GazeAngleAccuracy 16 | 17 | 18 | def test_fold(d_loader, model_list, fold_idx, model_idx="Ensemble"): 19 | assert type(model_list) is list, "model_list should be a list of models" 20 | angle_criterion_acc = [] 21 | p_bar = tqdm(d_loader) 22 | for left, right, headpose, gaze_labels in p_bar: 23 | p_bar.set_description("Testing Fold {}, Model \"{}\"...".format(fold_idx, model_idx)) 24 | left = left.to("cuda:0") 25 | right = right.to("cuda:0") 26 | headpose = headpose.to("cuda:0") 27 | angle_out = [_m(left, right, headpose).detach().cpu() for _m in model_list] 28 | angle_out = torch.stack(angle_out, dim=1) 29 | angle_out = torch.mean(angle_out, dim=1) 30 | angle_acc = criterion(angle_out[:, :2], gaze_labels) 31 | angle_criterion_acc.append(angle_acc) 32 | 33 | angle_criterion_acc_arr = np.array(angle_criterion_acc) 34 | tqdm.write( 35 | "\r\n\tFold: {}, Model: {}, Mean: {}, STD: {}".format(fold_idx, model_idx, np.mean(angle_criterion_acc_arr), np.std(angle_criterion_acc_arr))) 36 | 37 | 38 | if __name__ == "__main__": 39 | torch.backends.cudnn.benchmark = True 40 | 41 | root_dir = os.path.dirname(os.path.realpath(__file__)) 42 | 43 | root_parser = ArgumentParser(add_help=False) 44 | root_parser.add_argument('--model_loc', type=str, required=False, help='path to the model files to evaluate', action="append") 45 | root_parser.add_argument('--hdf5_file', type=str, default=os.path.abspath(os.path.join(root_dir, "../../RT_GENE/rtgene_dataset.hdf5"))) 46 | root_parser.add_argument('--num_io_workers', default=8, type=int) 47 | root_parser.add_argument('--loss_fn', choices=["mse", "pinball"], default="mse") 48 | root_parser.add_argument('--model_base', choices=["vgg", "resnet18_0", "preactresnet"], default="vgg") 49 | root_parser.add_argument('--batch_size', default=64, type=int) 50 | 51 | hyperparams = root_parser.parse_args() 52 | 53 | _param_num 
= { 54 | "mse": 2, 55 | "pinball": 3 56 | } 57 | _models = { 58 | "vgg": partial(GazeEstimationModelVGG, num_out=_param_num.get(hyperparams.loss_fn)), 59 | "resnet18_0": partial(GazeEstimationModelResnet18, num_out=_param_num.get(hyperparams.loss_fn)), 60 | "preactresnet": partial(GazeEstimationModelPreactResnet, num_out=_param_num.get(hyperparams.loss_fn)) 61 | } 62 | 63 | test_subjects = [[5, 6, 11, 12, 13], [3, 4, 7, 9], [1, 2, 8, 10]] 64 | criterion = GazeAngleAccuracy() 65 | 66 | # definition of an ensemble is a list of FILES, if any are folders, then not an ensemble 67 | ensemble = sum([os.path.isfile(s) for s in hyperparams.model_loc]) == len(hyperparams.model_loc) 68 | 69 | if ensemble: 70 | _models_list = [] 71 | for model_file in tqdm(hyperparams.model_loc, desc="Ensemble Evaluation; Loading models..."): 72 | _model = _models.get(hyperparams.model_base)() 73 | _model.load_state_dict(torch.load(model_file)) 74 | _model.to("cuda:0") 75 | _model.eval() 76 | _models_list.append(_model) 77 | 78 | for fold_idx, test_subject in enumerate(test_subjects): 79 | data_test = RTGENEH5Dataset(h5_file=h5py.File(hyperparams.hdf5_file, mode="r"), subject_list=test_subject) 80 | data_loader = DataLoader(data_test, batch_size=hyperparams.batch_size, shuffle=True, num_workers=hyperparams.num_io_workers, pin_memory=False) 81 | test_fold(data_loader, fold_idx=fold_idx, model_list=_models_list) 82 | else: 83 | folds = [os.path.abspath(os.path.join(hyperparams.model_loc, "fold_{}/".format(i))) for i in range(3)] 84 | tqdm.write("Every model in fold evaluation (i.e single model)") 85 | for fold_idx, (test_subject, fold) in enumerate(zip(test_subjects, folds)): 86 | # get each checkpoint and see which one is best 87 | epoch_ckpt = glob.glob(os.path.abspath(os.path.join(fold, "*.ckpt"))) 88 | for ckpt in tqdm(epoch_ckpt, desc="Checkpoint evaluation.."): 89 | # load data 90 | data_test = RTGENEH5Dataset(h5_file=h5py.File(hyperparams.hdf5_file, mode="r"), subject_list=test_subject) 91 | data_loader = DataLoader(data_test, batch_size=hyperparams.batch_size, shuffle=True, num_workers=hyperparams.num_io_workers, pin_memory=False) 92 | 93 | model = _models.get(hyperparams.model_base)() 94 | model.load_state_dict(torch.load(ckpt)) 95 | model.to("cuda:0") 96 | model.eval() 97 | 98 | test_fold(data_loader, model_list=[model], fold_idx=fold_idx, model_idx=os.path.basename(ckpt)) 99 | -------------------------------------------------------------------------------- /rt_gene_model_training/pytorch/post_process_ckpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | from functools import partial 4 | from glob import glob 5 | from pathlib import Path 6 | 7 | import torch 8 | from tqdm import tqdm 9 | 10 | from rt_gene.gaze_estimation_models_pytorch import GazeEstimationModelResnet18, GazeEstimationModelVGG, GazeEstimationModelPreactResnet 11 | 12 | if __name__ == "__main__": 13 | _root_parser = ArgumentParser(add_help=False) 14 | root_dir = os.path.dirname(os.path.realpath(__file__)) 15 | _root_parser.add_argument('--ckpt_dir', type=str, default=os.path.abspath(os.path.join(root_dir, '../../rt_gene_model_training/pytorch/checkpoints/fold_0/'))) 16 | _root_parser.add_argument('--save_dir', type=str, default=os.path.abspath(os.path.join(root_dir, '../../rt_gene/model_nets/pytorch_models/'))) 17 | _root_parser.add_argument('--model_base', choices=["vgg16", "resnet18", "preactresnet"], default="vgg16") 18 | 
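# NOTE: the loss the checkpoint was trained with fixes the size of the model's output layer (see _param_num below: 2 outputs for mse, 3 for pinball), so --loss_fn must match the training run.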
_root_parser.add_argument('--loss_fn', choices=["mse", "pinball"], default="mse") 19 | _params = _root_parser.parse_args() 20 | 21 | _param_num = { 22 | "mse": 2, 23 | "pinball": 3 24 | } 25 | _models = { 26 | "vgg16": partial(GazeEstimationModelVGG, num_out=_param_num.get(_params.loss_fn)), 27 | "resnet18": partial(GazeEstimationModelResnet18, num_out=_param_num.get(_params.loss_fn)), 28 | "preactresnet": partial(GazeEstimationModelPreactResnet, num_out=_param_num.get(_params.loss_fn)) 29 | } 30 | 31 | # create save dir 32 | Path(_params.save_dir).mkdir(parents=True, exist_ok=True) 33 | 34 | _model = _models.get(_params.model_base)() 35 | for ckpt in tqdm(glob(os.path.join(_params.ckpt_dir, "*.ckpt"))): 36 | filename, file_extension = os.path.splitext(ckpt) 37 | filename = os.path.basename(filename) 38 | _torch_load = torch.load(ckpt)['state_dict'] 39 | 40 | # the ckpt file saves the pytorch_lightning module which includes it's child members. The only child member we're interested in is the "_model". 41 | # Loading the state_dict with _model creates an error as the model tries to find a child called _model within it that doesn't 42 | # exist. Thus remove _model from the dictionary and all is well. 43 | _model_prefix = "_model." 44 | _state_dict = {k[len(_model_prefix):]: v for k, v in _torch_load.items() if k.startswith(_model_prefix)} 45 | _model.load_state_dict(_state_dict) 46 | torch.save(_model.state_dict(), os.path.join(_params.save_dir, "{}.model").format(filename)) 47 | -------------------------------------------------------------------------------- /rt_gene_model_training/pytorch/rtgene_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from PIL import Image 5 | from torch.utils import data 6 | from torchvision import transforms 7 | from tqdm import tqdm 8 | 9 | 10 | class RTGENEH5Dataset(data.Dataset): 11 | 12 | def __init__(self, h5_file, subject_list=None, transform=None): 13 | self._h5_file = h5_file 14 | self._transform = transform 15 | self._subject_labels = [] 16 | 17 | assert subject_list is not None, "Must pass a list of subjects to load the data for" 18 | 19 | if self._transform is None: 20 | self._transform = transforms.Compose([transforms.Resize((36, 60), Image.BICUBIC), 21 | transforms.ToTensor(), 22 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 23 | 24 | _wanted_subjects = ["s{:03d}".format(_i) for _i in subject_list] 25 | 26 | for grp_s_n in tqdm(_wanted_subjects, desc="Loading subject metadata..."): # subjects 27 | for grp_i_n, grp_i in h5_file[grp_s_n].items(): # images 28 | if "left" in grp_i.keys() and "right" in grp_i.keys() and "label" in grp_i.keys(): 29 | left_dataset = grp_i["left"] 30 | right_datset = grp_i['right'] 31 | 32 | assert len(left_dataset) == len(right_datset), "Dataset left/right images aren't equal length" 33 | for _i in range(len(left_dataset)): 34 | self._subject_labels.append(["/" + grp_s_n + "/" + grp_i_n, _i]) 35 | 36 | def __len__(self): 37 | return len(self._subject_labels) 38 | 39 | def __getitem__(self, index): 40 | _sample = self._subject_labels[index] 41 | assert type(_sample[0]) == str, "Sample not found at index {}".format(index) 42 | _left_img = self._h5_file[_sample[0] + "/left"][_sample[1]][()] 43 | _right_img = self._h5_file[_sample[0] + "/right"][_sample[1]][()] 44 | label_data = self._h5_file[_sample[0]+"/label"][()] 45 | _groud_truth_headpose = label_data[0][()].astype(float) 46 | _ground_truth_gaze = 
label_data[1][()].astype(float) 47 | 48 | # Load data and get label 49 | _transformed_left = self._transform(Image.fromarray(_left_img, 'RGB')) 50 | _transformed_right = self._transform(Image.fromarray(_right_img, 'RGB')) 51 | 52 | return _transformed_left, _transformed_right, _groud_truth_headpose, _ground_truth_gaze 53 | 54 | 55 | class RTGENEFileDataset(data.Dataset): 56 | 57 | def __init__(self, root_path, subject_list=None, transform=None): 58 | self._root_path = root_path 59 | self._transform = transform 60 | self._subject_labels = [] 61 | 62 | assert subject_list is not None, "Must pass a list of subjects to load the data for" 63 | 64 | if self._transform is None: 65 | self._transform = transforms.Compose([transforms.Resize((224, 224), Image.BICUBIC), 66 | transforms.ToTensor(), 67 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 68 | 69 | subject_path = [os.path.join(root_path, "s{:03d}_glasses/".format(_i)) for _i in subject_list] 70 | 71 | for subject_data in subject_path: 72 | with open(os.path.join(subject_data, "label_combined.txt"), "r") as f: 73 | _lines = f.readlines() 74 | for line in _lines: 75 | split = line.split(",") 76 | left_img_path = os.path.join(subject_data, "inpainted/left_new/", "left_{:0=6d}_rgb.png".format(int(split[0]))) 77 | right_img_path = os.path.join(subject_data, "inpainted/right_new/", "right_{:0=6d}_rgb.png".format(int(split[0]))) 78 | if os.path.exists(left_img_path) and os.path.exists(right_img_path): 79 | head_phi = float(split[1].strip()[1:]) 80 | head_theta = float(split[2].strip()[:-1]) 81 | gaze_phi = float(split[3].strip()[1:]) 82 | gaze_theta = float(split[4].strip()[:-1]) 83 | self._subject_labels.append([left_img_path, right_img_path, head_phi, head_theta, gaze_phi, gaze_theta]) 84 | 85 | print("=> Loaded metadata for {} images".format(len(self._subject_labels))) 86 | 87 | def __len__(self): 88 | return len(self._subject_labels) 89 | 90 | def __getitem__(self, index): 91 | _sample = self._subject_labels[index] 92 | _groud_truth_headpose = [_sample[2], _sample[3]] 93 | _ground_truth_gaze = [_sample[4], _sample[5]] 94 | 95 | # Load data and get label 96 | _left_img = np.array(Image.open(os.path.join(self._root_path, _sample[0])).convert('RGB')) 97 | _right_img = np.array(Image.open(os.path.join(self._root_path, _sample[1])).convert('RGB')) 98 | 99 | _transformed_left = self._transform(Image.fromarray(_left_img, 'RGB')) 100 | _transformed_right = self._transform(Image.fromarray(_right_img, 'RGB')) 101 | 102 | return _transformed_left, _transformed_right, np.array(_groud_truth_headpose, dtype=float), np.array(_ground_truth_gaze, dtype=float) 103 | -------------------------------------------------------------------------------- /rt_gene_model_training/pytorch/utils/CombineGazeH5Datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | import h5py 5 | 6 | root_dir = os.path.dirname(os.path.realpath(__file__)) 7 | 8 | root_parser = ArgumentParser(add_help=False) 9 | root_parser.add_argument('--hdf5_file', type=str, required=True, help='path to the datasets to combine', action="append") 10 | root_parser.add_argument('--save_file', type=str, default=os.path.abspath(os.path.join(root_dir, "combined_dataset.hdf5"))) 11 | params = root_parser.parse_args() 12 | 13 | assert len(params.hdf5_file) >= 2, "Need at least two datasets to combine" 14 | 15 | files = [h5py.File(path, mode="r") for path in params.hdf5_file] 16 | keys = 
[list(h5.keys()) for h5 in files] 17 | _ = [h5.close() for h5 in files] 18 | 19 | h5_all = h5py.File(params.save_file, mode='w') 20 | 21 | s_idx = 0 22 | for dataset, path in zip(keys, params.hdf5_file): 23 | for key in dataset: 24 | h5_all[str("s{:03d}".format(s_idx))] = h5py.ExternalLink(path, key) 25 | s_idx += 1 26 | 27 | h5_all.flush() 28 | h5_all.close() 29 | -------------------------------------------------------------------------------- /rt_gene_model_training/pytorch/utils/GazeAngleAccuracy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class GazeAngleAccuracy(object): 5 | 6 | def __call__(self, batch_y_pred, batch_y_true): 7 | batch = batch_y_true.size()[0] 8 | batch_y_pred = batch_y_pred.cpu().detach().numpy() 9 | batch_y_true = batch_y_true.cpu().detach().numpy() 10 | acc = [] 11 | for i in range(batch): 12 | y_true, y_pred = batch_y_true[i], batch_y_pred[i] 13 | pred_x = -1 * np.cos(y_pred[0]) * np.sin(y_pred[1]) 14 | pred_y = -1 * np.sin(y_pred[0]) 15 | pred_z = -1 * np.cos(y_pred[0]) * np.cos(y_pred[1]) 16 | pred = np.array([pred_x, pred_y, pred_z]) 17 | pred = pred / np.linalg.norm(pred) 18 | 19 | true_x = -1 * np.cos(y_true[0]) * np.sin(y_true[1]) 20 | true_y = -1 * np.sin(y_true[0]) 21 | true_z = -1 * np.cos(y_true[0]) * np.cos(y_true[1]) 22 | gt = np.array([true_x, true_y, true_z]) 23 | gt = gt / np.linalg.norm(gt) 24 | 25 | acc.append(np.rad2deg(np.arccos(np.dot(pred, gt)))) 26 | 27 | acc = np.mean(np.array(acc)) 28 | return acc 29 | -------------------------------------------------------------------------------- /rt_gene_model_training/pytorch/utils/GenerateEyePatchesRTGENEDataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division, absolute_import 2 | 3 | import argparse 4 | import os 5 | 6 | import cv2 7 | from tqdm import tqdm 8 | 9 | from rt_gene.extract_landmarks_method_base import LandmarkMethodBase 10 | 11 | script_path = os.path.dirname(os.path.realpath(__file__)) 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser(description='Estimate gaze from images') 15 | parser.add_argument('im_path', type=str, default=os.path.join(script_path, '../samples/natural'), 16 | nargs='?', help='Path to an image or a directory containing images') 17 | parser.add_argument('--output_path', type=str, default=os.path.join(script_path, '../samples/'), help='Output directory for left/right eye patches') 18 | 19 | landmark_estimator = LandmarkMethodBase(device_id_facedetection="cuda:0", 20 | checkpoint_path_face=os.path.join(script_path, "../../rt_gene/model_nets/SFD/s3fd_facedetector.pth"), 21 | checkpoint_path_landmark=os.path.join(script_path, "../../rt_gene/model_nets/phase1_wpdc_vdc.pth.tar"), 22 | model_points_file=os.path.join(script_path, "../../rt_gene/model_nets/face_model_68.txt")) 23 | 24 | args = parser.parse_args() 25 | 26 | image_path_list = [] 27 | if os.path.isfile(args.im_path): 28 | image_path_list.append(os.path.split(args.im_path)[1]) 29 | args.im_path = os.path.split(args.im_path)[0] 30 | elif os.path.isdir(args.im_path): 31 | for image_file_name in os.listdir(args.im_path): 32 | if image_file_name.endswith('.jpg') or image_file_name.endswith('.png'): 33 | if '_gaze' not in image_file_name and '_headpose' not in image_file_name: 34 | image_path_list.append(image_file_name) 35 | 36 | left_folder_path = os.path.join(args.output_path, "left_new") 37 | right_folder_path = 
os.path.join(args.output_path, "right_new") 38 | if not os.path.isdir(left_folder_path): 39 | os.makedirs(left_folder_path) 40 | if not os.path.isdir(right_folder_path): 41 | os.makedirs(right_folder_path) 42 | 43 | p_bar = tqdm(image_path_list) 44 | for image_file_name in p_bar: 45 | p_bar.set_description("Processing {}".format(image_file_name)) 46 | image = cv2.imread(os.path.join(args.im_path, image_file_name)) 47 | if image is None: 48 | continue 49 | 50 | faceboxes = landmark_estimator.get_face_bb(image) 51 | if len(faceboxes) == 0: 52 | continue 53 | 54 | subjects = landmark_estimator.get_subjects_from_faceboxes(image, faceboxes) 55 | for subject in subjects: 56 | le_c, re_c, _, _ = subject.get_eye_image_from_landmarks(subject, landmark_estimator.eye_image_size) 57 | 58 | if le_c is not None and re_c is not None: 59 | img_name = image_file_name.split(".")[0] 60 | left_image_path = ["left", img_name, "rgb.png"] 61 | left_image_path = os.path.join(left_folder_path, "_".join(left_image_path)) 62 | 63 | right_image_path = ["right", img_name, "rgb.png"] 64 | right_image_path = os.path.join(right_folder_path, "_".join(right_image_path)) 65 | 66 | cv2.imwrite(left_image_path, le_c) 67 | cv2.imwrite(right_image_path, re_c) 68 | -------------------------------------------------------------------------------- /rt_gene_model_training/pytorch/utils/GenerateRTGENEH5Dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division, absolute_import 2 | 3 | import argparse 4 | import os 5 | 6 | import h5py 7 | import numpy as np 8 | from PIL import Image, ImageFilter, ImageOps 9 | from torchvision import transforms 10 | from tqdm import tqdm 11 | 12 | script_path = os.path.dirname(os.path.realpath(__file__)) 13 | 14 | # Augmentations following `prepare_dataset.m`: randomly crop and resize the image 10 times, 15 | # along side two blurring stages, grayscaling and histogram normalisation 16 | _required_size = (224, 224) 17 | _transforms_list = [transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)), # equivalent to random 5px from each edge 18 | transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)), 19 | transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)), 20 | transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)), 21 | transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)), 22 | transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)), 23 | transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)), 24 | transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)), 25 | transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)), 26 | transforms.RandomResizedCrop(size=_required_size, scale=(0.85, 1.0)), 27 | transforms.Grayscale(num_output_channels=3), 28 | lambda x: x.filter(ImageFilter.GaussianBlur(radius=1)), 29 | lambda x: x.filter(ImageFilter.GaussianBlur(radius=3)), 30 | lambda x: ImageOps.equalize(x)] # histogram equalisation 31 | 32 | 33 | def load_and_augment(file_path, augment=False): 34 | image = Image.open(file_path).resize(_required_size) 35 | augmented_images = [np.array(trans(image)) for trans in _transforms_list if augment is True] 36 | augmented_images.append(np.array(image)) 37 | 38 | return np.array(augmented_images, dtype=np.uint8) 39 | 40 | 41 | if __name__ == "__main__": 42 | parser = argparse.ArgumentParser(description='Estimate gaze from images') 43 | 
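# --rt_gene_root should point at the RT-GENE dataset root, i.e. the directory holding the per-subject s0XX_glasses folders (each containing label_combined.txt and the inpainted left/right eye patches read below).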
parser.add_argument('--rt_gene_root', type=str, required=True, nargs='?', help='Path to the base directory of RT_GENE') 44 | parser.add_argument('--augment_dataset', type=bool, required=False, default=False, help="Whether to augment the dataset with predefined transforms") 45 | parser.add_argument('--compress', action='store_true', dest="compress") 46 | parser.add_argument('--no-compress', action='store_false', dest="compress") 47 | parser.set_defaults(compress=False) 48 | args = parser.parse_args() 49 | 50 | _compression = "lzf" if args.compress is True else None 51 | 52 | subject_path = [os.path.join(args.rt_gene_root, "s{:03d}_glasses/".format(_i)) for _i in range(0, 17)] 53 | 54 | hdf_file = h5py.File(os.path.abspath(os.path.join(args.rt_gene_root, 'rtgene_dataset.hdf5')), mode='w') 55 | for subject_id, subject_data in enumerate(subject_path): 56 | subject_id = str("s{:03d}".format(subject_id)) 57 | subject_grp = hdf_file.create_group(subject_id) 58 | with open(os.path.join(subject_data, "label_combined.txt"), "r") as f: 59 | _lines = f.readlines() 60 | 61 | for line in tqdm(_lines, desc="Subject {}".format(subject_id)): 62 | 63 | split = line.split(",") 64 | image_name = "{:0=6d}".format(int(split[0])) 65 | image_grp = subject_grp.create_group(image_name) 66 | left_img_path = os.path.join(subject_data, "inpainted/left_new/", "left_{:0=6d}_rgb.png".format(int(split[0]))) 67 | right_img_path = os.path.join(subject_data, "inpainted/right_new/", "right_{:0=6d}_rgb.png".format(int(split[0]))) 68 | if os.path.exists(left_img_path) and os.path.exists(right_img_path): 69 | head_phi = float(split[1].strip()[1:]) 70 | head_theta = float(split[2].strip()[:-1]) 71 | gaze_phi = float(split[3].strip()[1:]) 72 | gaze_theta = float(split[4].strip()[:-1]) 73 | labels = [(head_theta, head_phi), (gaze_theta, gaze_phi)] 74 | 75 | left_data = load_and_augment(left_img_path, augment=args.augment_dataset) 76 | right_data = load_and_augment(right_img_path, augment=args.augment_dataset) 77 | image_grp.create_dataset("left", data=left_data, compression=_compression) 78 | image_grp.create_dataset("right", data=right_data, compression=_compression) 79 | image_grp.create_dataset("label", data=labels) 80 | 81 | hdf_file.flush() 82 | hdf_file.close() 83 | -------------------------------------------------------------------------------- /rt_gene_model_training/pytorch/utils/LearningRateFinder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | import h5py 5 | import matplotlib.pyplot as plt 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from torch.utils.tensorboard import SummaryWriter 9 | from tqdm import trange 10 | 11 | from gaze_estimation_models_pytorch import GazeEstimationModelResnet18 12 | from rtgene_dataset import RTGENEH5Dataset 13 | 14 | 15 | class RTGENELearningRateFinder(object): 16 | 17 | def __init__(self, model, optimiser, loss, batch_size=128): 18 | self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 19 | self.writer = SummaryWriter() 20 | 21 | _root_dir = os.path.dirname(os.path.realpath(__file__)) 22 | 23 | data_train = RTGENEH5Dataset(h5_file=h5py.File(os.path.abspath(os.path.join(_root_dir, "../../../RT_GENE/dataset.hdf5")), mode="r"), 24 | subject_list=list(range(16))) 25 | 26 | dataloader = DataLoader(data_train, batch_size=batch_size, shuffle=True, num_workers=4) 27 | 28 | # Train and evaluate 29 | logs, losses = self.find_lr(model=model, dataloader=dataloader, criterion=loss, 
optimiser=optimiser, batch_size=batch_size, 30 | epoch_length=len(data_train)) 31 | plt.plot(logs[10:-5], losses[10:-5]) 32 | plt.xscale('log') 33 | plt.show() 34 | 35 | def find_lr(self, dataloader, model, optimiser, criterion, init_value=1e-6, final_value=1e-3, beta=0.98, epoch_length=100000, batch_size=64): 36 | num = (epoch_length // batch_size) - 1 37 | mult = (final_value / init_value) ** (1 / num) 38 | lr = init_value 39 | optimiser.param_groups[0]['lr'] = lr 40 | avg_loss = 0. 41 | best_loss = 0. 42 | batch_num = 0 43 | losses = [] 44 | log_lrs = [] 45 | 46 | additional_steps = epoch_length // batch_size 47 | _rtgene_model = model.to(self._device) 48 | model.eval() 49 | 50 | data_iter = iter(dataloader) 51 | 52 | # Start training 53 | with trange(0, additional_steps) as pbar: 54 | for step in pbar: 55 | batch_num += 1 56 | # As before, get the loss for this mini-batch of inputs/outputs 57 | try: 58 | batch = next(data_iter) 59 | except StopIteration: 60 | return log_lrs, losses 61 | 62 | _left_patch, _right_patch, _labels, _head_pose = batch 63 | 64 | _left_patch = _left_patch.to(self._device) 65 | _right_patch = _right_patch.to(self._device) 66 | _labels = _labels.to(self._device).float() 67 | _head_pose = _head_pose.to(self._device).float() 68 | 69 | optimiser.zero_grad() 70 | with torch.set_grad_enabled(True): 71 | 72 | # Get model outputs and calculate loss 73 | angular_out = _rtgene_model(_left_patch, _right_patch, _head_pose) 74 | loss = criterion(angular_out, _labels) 75 | 76 | # Compute the smoothed loss 77 | avg_loss = beta * avg_loss + (1 - beta) * loss.item() 78 | smoothed_loss = avg_loss / (1 - beta ** batch_num) 79 | # Stop if the loss is exploding 80 | # if batch_num > 1 and smoothed_loss > 4 * best_loss: 81 | # return log_lrs, losses 82 | # Record the best loss 83 | if smoothed_loss < best_loss or batch_num == 1: 84 | best_loss = smoothed_loss 85 | # Store the values 86 | losses.append(smoothed_loss) 87 | log_lrs.append(lr) 88 | # Do the SGD step 89 | loss.backward() 90 | optimiser.step() 91 | # Update the lr for the next step 92 | lr *= mult 93 | optimiser.param_groups[0]['lr'] = lr 94 | 95 | pbar.set_description("Learning Rate: {:4.8e}, Loss: {:4.8f}".format(lr, smoothed_loss)) 96 | pbar.update() 97 | 98 | self.writer.add_scalar("data/lr", math.log10(lr), global_step=batch_num) 99 | self.writer.add_scalar("data/loss", smoothed_loss, global_step=batch_num) 100 | 101 | return log_lrs, losses 102 | 103 | 104 | if __name__ == "__main__": 105 | rt_gene_fast_model = GazeEstimationModelResnet18(num_out=2) 106 | params_to_update = [] 107 | for name, param in rt_gene_fast_model.named_parameters(): 108 | if param.requires_grad: 109 | params_to_update.append(param) 110 | 111 | learning_rate = 1e-1 112 | optimizer = torch.optim.Adam(params_to_update, lr=learning_rate, betas=(0.9, 0.95)) 113 | criterion = torch.nn.MSELoss(reduction="sum") 114 | 115 | RTGENELearningRateFinder(model=rt_gene_fast_model, optimiser=optimizer, loss=criterion, batch_size=128) 116 | -------------------------------------------------------------------------------- /rt_gene_model_training/pytorch/utils/PinballLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class PinballLoss(object): 5 | 6 | def __init__(self, reduction="mean"): 7 | super(PinballLoss, self).__init__() 8 | self.q1 = 0.45 9 | self.q9 = 1 - self.q1 10 | 11 | _reduction_strategies = { 12 | "mean": torch.mean, 13 | "sum": torch.sum, 14 | "none": lambda x: x 15 | } 16 | 
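# q1/q9 above are the lower/upper quantiles (0.45/0.55) of the pinball (quantile) loss; reject unsupported reduction modes before storing the chosen strategy.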
assert reduction in _reduction_strategies.keys(), "Reduction method unknown, possibilities include 'mean', 'sum' and 'none'" 17 | 18 | self._reduction_strategy = _reduction_strategies.get(reduction) 19 | 20 | def __call__(self, output, target): 21 | angle_o = output[:, :2] 22 | var_o = output[:, 2:3] 23 | var_o = var_o.view(-1, 1).expand(var_o.size(0), 2) 24 | 25 | q_10 = target - (angle_o - var_o) 26 | q_90 = target - (angle_o + var_o) 27 | 28 | loss_10 = torch.max(self.q1 * q_10, (self.q1 - 1) * q_10) 29 | loss_90 = torch.max(self.q9 * q_90, (self.q9 - 1) * q_90) 30 | 31 | loss_10 = self._reduction_strategy(loss_10) 32 | loss_90 = self._reduction_strategy(loss_90) 33 | 34 | return loss_10 + loss_90 35 | -------------------------------------------------------------------------------- /rt_gene_model_training/pytorch/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_gene_model_training/pytorch/utils/__init__.py -------------------------------------------------------------------------------- /rt_gene_model_training/tensorflow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_gene_model_training/tensorflow/__init__.py -------------------------------------------------------------------------------- /rt_gene_model_training/tensorflow/evaluate_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | for epoch in 01 02 03 04 4 | do 5 | # format is: FC1size FC2size FC3size model_type epoch_num GPU_num 6 | python evaluate_model.py 1024 512 256 512 VGG16 ${epoch} 0 7 | done 8 | 9 | -------------------------------------------------------------------------------- /rt_gene_model_training/tensorflow/train_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | """ 4 | VGG-16/VGG-19 architecture applied to RT-GENE Dataset 5 | @ Tobias Fischer (t.fischer@imperial.ac.uk), Hyung Jin Chang (hj.chang@imperial.ac.uk) 6 | """ 7 | 8 | from __future__ import print_function, division, absolute_import 9 | 10 | import argparse 11 | import gc 12 | import os 13 | 14 | import h5py 15 | import tensorflow as tf 16 | from tensorflow.keras import backend as K 17 | from tensorflow.keras.callbacks import ModelCheckpoint 18 | from tensorflow.keras.optimizers import Adam 19 | from train_tools import * 20 | 21 | tf.compat.v1.disable_eager_execution() 22 | 23 | path = '/recordings_hdd/mtcnn_twoeyes_inpainted_eccv/' 24 | 25 | 26 | subjects_test_threefold = [ 27 | ['s001', 's002', 's008', 's010'], 28 | ['s003', 's004', 's007', 's009'], 29 | ['s005', 's006', 's011', 's012', 's013'] 30 | ] 31 | subjects_train_threefold = [ 32 | ['s003', 's004', 's007', 's009', 's005', 's006', 's011', 's012', 's013'], 33 | ['s001', 's002', 's008', 's010', 's005', 's006', 's011', 's012', 's013'], 34 | ['s001', 's002', 's008', 's010', 's003', 's004', 's007', 's009'] 35 | ] 36 | 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument("fc1_size", type=int) 39 | parser.add_argument("fc2_size", type=int) 40 | parser.add_argument("fc3_size", type=int) 41 | parser.add_argument("batch_size", type=int) 42 | parser.add_argument("model_type", choices=['VGG16', 'VGG19']) 43 | parser.add_argument("ensemble_num", type=int) 44 | 
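# gpu_num selects the single GPU exposed to TensorFlow (applied below via config.gpu_options.visible_device_list).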
parser.add_argument("gpu_num", choices=['0', '1', '2', '3']) 45 | 46 | args = parser.parse_args() 47 | 48 | # Parameters 49 | model_type = args.model_type 50 | fc1_size = args.fc1_size 51 | fc2_size = args.fc2_size 52 | fc3_size = args.fc3_size 53 | 54 | batch_size = args.batch_size 55 | num_epochs = 4 56 | validation_split = 0.05 57 | 58 | suffix = 'eccv_'+model_type+'_'+str(fc1_size)+'_'+str(fc2_size)+'_'+str(fc3_size)+'_'+str(batch_size)+'_'+str(args.ensemble_num) 59 | 60 | config = tf.compat.v1.ConfigProto() 61 | config.gpu_options.allow_growth = True 62 | config.gpu_options.visible_device_list = args.gpu_num 63 | 64 | for subjects_train, subjects_test in zip(subjects_train_threefold, subjects_test_threefold): 65 | print('subjects_test:', subjects_test) 66 | 67 | if os.path.isfile(path+"3Fold"+''.join(subjects_test)+suffix+"_01.h5") and \ 68 | os.path.isfile(path+"3Fold"+''.join(subjects_test)+suffix+"_02.h5") and \ 69 | os.path.isfile(path+"3Fold"+''.join(subjects_test)+suffix+"_03.h5") and \ 70 | os.path.isfile(path+"3Fold"+''.join(subjects_test)+suffix+"_04.h5"): 71 | print('Skip training, model already exists: '+path+"3Fold"+''.join(subjects_test)+suffix+"_XX.h5") 72 | continue 73 | 74 | if os.path.isfile(path+'eccv_'+model_type+'_'+str(fc1_size)+'_'+str(fc2_size)+'_'+str(fc3_size)+'_'+str(batch_size)+'_01.txt') and \ 75 | os.path.isfile(path+'eccv_'+model_type+'_'+str(fc1_size)+'_'+str(fc2_size)+'_'+str(fc3_size)+'_'+str(batch_size)+'_02.txt') and \ 76 | os.path.isfile(path+'eccv_'+model_type+'_'+str(fc1_size)+'_'+str(fc2_size)+'_'+str(fc3_size)+'_'+str(batch_size)+'_03.txt') and \ 77 | os.path.isfile(path+'eccv_'+model_type+'_'+str(fc1_size)+'_'+str(fc2_size)+'_'+str(fc3_size)+'_'+str(batch_size)+'_04.txt'): 78 | print('Skip training, model already evaluated: '+path+'eccv_'+model_type+'_'+str(fc1_size)+'_'+str(fc2_size)+'_'+str(fc3_size)+'_'+str(batch_size)+'_XX.txt') 79 | continue 80 | 81 | K.clear_session() 82 | tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config)) 83 | 84 | train_file_names = [path+'/RT_GENE_train_'+subject+'.mat' for subject in subjects_train] 85 | train_files = [h5py.File(train_file_name) for train_file_name in train_file_names] 86 | 87 | train_images_L, train_images_R, train_gazes, train_headposes, train_num = get_train_test_data_twoeyes(train_files, 'train') 88 | 89 | num_steps_epoch, num_steps_validation, size_validation_set = get_train_info(train_num, validation_split, batch_size) 90 | 91 | generator = GeneratorsTwoEyes(train_num, size_validation_set, batch_size, num_steps_epoch, 92 | train_images_L, train_images_R, train_gazes, train_headposes) 93 | 94 | adam = Adam(lr=0.00075, beta_1=0.9, beta_2=0.95) 95 | model = get_vgg_twoeyes(adam, model_type=model_type, fc1_size=fc1_size, fc2_size=fc2_size, fc3_size=fc3_size) 96 | 97 | checkpointer = ModelCheckpoint(filepath=path+"3Fold"+''.join(subjects_test)+suffix+"_{epoch:02d}.h5", 98 | verbose=1, save_best_only=False, period=1) 99 | 100 | history = model.fit_generator(generator.get_train_data(), 101 | steps_per_epoch=int(num_steps_epoch), 102 | epochs=num_epochs, 103 | use_multiprocessing=False, 104 | validation_data=generator.get_validation_data(), 105 | validation_steps=int(num_steps_validation), 106 | callbacks=[checkpointer]) 107 | 108 | # model.save(path+"3Fold"+''.join(subjects_test)+suffix+".h5") 109 | 110 | for train_file in train_files: 111 | train_file.close() 112 | 113 | train_images_L, train_images_R, train_gazes, train_headposes = None, None, None, None 114 | model = None 
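# references to this fold's data and model are dropped above so that gc.collect() can release host/GPU memory before the next fold is trained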
115 | gc.collect() 116 | 117 | -------------------------------------------------------------------------------- /rt_gene_model_training/tensorflow/train_models_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | for ensemble_num in 1 2 3 4 4 | do 5 | # format is: FC1size FC2size FC3size batch_size model_type ensemble_num GPU_num 6 | python train_model.py 1024 512 256 512 VGG16 ${ensemble_num} 0 7 | done 8 | 9 | -------------------------------------------------------------------------------- /rt_gene_standalone/README.md: -------------------------------------------------------------------------------- 1 | # RT-GENE: Real-Time Eye Gaze Estimation in Natural Environments 2 | [![License: CC BY-NC-SA 4.0](https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg?style=flat-square)](https://creativecommons.org/licenses/by-nc-sa/4.0/) 3 | ![stars](https://img.shields.io/github/stars/Tobias-Fischer/rt_gene.svg?style=flat-square) 4 | ![GitHub issues](https://img.shields.io/github/issues/Tobias-Fischer/rt_gene.svg?style=flat-square) 5 | ![GitHub repo size](https://img.shields.io/github/repo-size/Tobias-Fischer/rt_gene.svg?style=flat-square) 6 | 7 | ## License + Attribution 8 | This code is licensed under [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/). Commercial usage is not permitted; please contact or regarding commercial licensing. If you use this dataset or the code in a scientific publication, please cite the following [paper](http://openaccess.thecvf.com/content_ECCV_2018/html/Tobias_Fischer_RT-GENE_Real-Time_Eye_ECCV_2018_paper.html): 9 | 10 | ``` 11 | @inproceedings{FischerECCV2018, 12 | author = {Tobias Fischer and Hyung Jin Chang and Yiannis Demiris}, 13 | title = {{RT-GENE: Real-Time Eye Gaze Estimation in Natural Environments}}, 14 | booktitle = {European Conference on Computer Vision}, 15 | year = {2018}, 16 | month = {September}, 17 | pages = {339--357} 18 | } 19 | ``` 20 | 21 | This work was supported in part by the Samsung Global Research Outreach program, and in part by the EU Horizon 2020 Project PAL (643783-RIA). 22 | 23 | More information can be found on the Personal Robotic Lab's website: . 24 | 25 | ## Requirements 26 | 1. Install required Python packages: 27 | - For `conda` users (recommended): `conda install tensorflow-gpu numpy scipy tqdm pillow opencv matplotlib pytorch torchvision` 28 | - For `pip` users: `pip install tensorflow-gpu numpy scipy tqdm torch torchvision Pillow opencv-python matplotlib` 29 | 1. Download RT-GENE and add the source folder to your `PYTHONPATH` environment variable: 30 | 1. `cd $HOME/ && git clone https://github.com/Tobias-Fischer/rt_gene.git` 31 | 1. `export PYTHONPATH=$HOME/rt_gene/rt_gene/src` 32 | 33 | ## Basic usage 34 | - Run `$HOME/rt_gene/rt_gene_standalone/estimate_gaze_standalone.py`. For supported arguments, run `$HOME/rt_gene/rt_gene_standalone/estimate_gaze_standalone.py --help`. Note that the first time the script is run, various model files are downloaded automatically. An alternative mirror for the model files is [here](https://drive.google.com/drive/folders/1cdOlCoXBIv-KxBGPP88oijd85uc5XVGF?usp=sharing); these files need to be moved into `$HOME/rt_gene/rt_gene/model_nets`.
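- For example, to process the bundled sample images (an illustrative call; check `--help` for the exact arguments): `cd $HOME/rt_gene/rt_gene_standalone && python estimate_gaze_standalone.py ./samples_gaze/`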
35 | 36 | ### Optional ensemble model files 37 | - To use an ensemble of 4 models trained on the MPII, UTMV and RT-GENE datasets, use the `--models` argument, e.g. `cd $HOME/rt_gene/ && ./rt_gene_standalone/estimate_gaze_standalone.py --models './rt_gene/model_nets/all_subjects_mpii_prl_utmv_0_02.h5' './rt_gene/model_nets/all_subjects_mpii_prl_utmv_1_02.h5' './rt_gene/model_nets/all_subjects_mpii_prl_utmv_2_02.h5' './rt_gene/model_nets/all_subjects_mpii_prl_utmv_3_02.h5'` 38 | 39 | ## List of libraries 40 | See [main README.md](../rt_gene/README.md) 41 | 42 | -------------------------------------------------------------------------------- /rt_gene_standalone/samples_gaze/gaze_center.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_gene_standalone/samples_gaze/gaze_center.jpg -------------------------------------------------------------------------------- /rt_gene_standalone/samples_gaze/gaze_down.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_gene_standalone/samples_gaze/gaze_down.jpg -------------------------------------------------------------------------------- /rt_gene_standalone/samples_gaze/gaze_left.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_gene_standalone/samples_gaze/gaze_left.jpg -------------------------------------------------------------------------------- /rt_gene_standalone/samples_gaze/gaze_right.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_gene_standalone/samples_gaze/gaze_right.jpg -------------------------------------------------------------------------------- /rt_gene_standalone/samples_gaze/gaze_up.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tobias-Fischer/rt_gene/450ad3d66ecb06d8845d5d0794a35fd0e429293b/rt_gene_standalone/samples_gaze/gaze_up.jpg --------------------------------------------------------------------------------