├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── in_video.mp4
│   ├── out_gif.gif
│   └── out_video.mp4
├── config.py
├── data
│   └── .gitkeep
├── download.sh
├── evaluate.py
├── inference.py
├── main.py
├── models
│   ├── __init__.py
│   ├── mobilenet.py
│   ├── mobileone.py
│   └── resnet.py
├── mpii_train.py
├── onnx_export.py
├── onnx_inference.py
├── reparameterize.py
├── requirements.txt
├── utils
│   ├── datasets.py
│   └── helpers.py
└── weights
    └── .gitkeep

/.gitignore:
--------------------------------------------------------------------------------
1 | data/Gaze360
2 | data/MPIIFaceGaze
3 | output/
4 | *.pt
5 | *.onnx
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 | cover/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | .pybuilder/
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | # For a library or package, you might want to ignore these files since the code is
93 | # intended to run in multiple environments; otherwise, check them in:
94 | # .python-version
95 |
96 | # pipenv
97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
100 | # install all needed dependencies.
101 | #Pipfile.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yakhyokhuja Valikhujaev 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MobileGaze: Pre-trained mobile nets for Gaze-Estimation 2 | 3 | ![Downloads](https://img.shields.io/github/downloads/yakhyo/gaze-estimation/total) 4 | [![GitHub Repo stars](https://img.shields.io/github/stars/yakhyo/gaze-estimation)](https://github.com/yakhyo/gaze-estimation/stargazers) 5 | [![GitHub Repository](https://img.shields.io/badge/GitHub-Repository-blue?logo=github)](https://github.com/yakhyo/gaze-estimation) 6 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.14257640.svg)](https://doi.org/10.5281/zenodo.14257640) 7 | 8 | 11 | 12 | 17 | 18 | 19 |
20 | [demo video]
21 |
22 | Video by Yan Krukau: https://www.pexels.com/video/male-teacher-with-his-students-8617126/
23 |
24 |
25 |
26 | This project aims to perform gaze estimation using several deep learning models such as ResNet, MobileNet v2, and MobileOne. It supports both classification and regression for predicting gaze direction. Built on top of [L2CS-Net](https://github.com/Ahmednull/L2CS-Net), the project includes additional pre-trained models and refined code for better performance and flexibility.
27 |
28 | ## Features
29 |
30 | - [x] **ONNX Inference**: Export PyTorch weights to ONNX and run inference with ONNX Runtime.
31 | - [x] **ResNet**: [Deep Residual Networks](https://arxiv.org/abs/1512.03385) - Enables deeper networks with better accuracy through residual learning.
32 | - [x] **MobileNet v2**: [Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) - Efficient model for mobile applications, balancing performance and computational cost.
33 | - [x] **MobileOne (s0-s4)**: [An Improved One millisecond Mobile Backbone](https://arxiv.org/abs/2206.04040) - Achieves near-instant inference times, ideal for real-time mobile applications.
34 | - [x] **Face Detection**: [uniface](https://github.com/yakhyo/uniface) - **Uniface** is a face detection library built on the RetinaFace model.
35 |
36 | > [!NOTE]
37 | > All models are trained only on the **Gaze360** dataset.
38 |
39 | ## Installation
40 |
41 | 1. Clone the repository:
42 |
43 | ```bash
44 | git clone https://github.com/yakhyo/gaze-estimation.git
45 | cd gaze-estimation
46 | ```
47 |
48 | 2. Install the required dependencies:
49 |
50 | ```bash
51 | pip install -r requirements.txt
52 | ```
53 |
54 | 3. Download weight files:
55 |
56 | a) Download weights from the following links:
57 |
58 | | Model        | PyTorch Weights | ONNX Weights | Size    | Epochs | MAE   |
59 | | ------------ | --------------- | ------------ | ------- | ------ | ----- |
60 | | ResNet-18    | [resnet18.pt](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/resnet18.pt) | [resnet18_gaze.onnx](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/resnet18_gaze.onnx) | 43 MB   | 200    | 12.84 |
61 | | ResNet-34    | [resnet34.pt](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/resnet34.pt) | [resnet34_gaze.onnx](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/resnet34_gaze.onnx) | 81.6 MB | 200    | 11.33 |
62 | | ResNet-50    | [resnet50.pt](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/resnet50.pt) | [resnet50_gaze.onnx](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/resnet50_gaze.onnx) | 91.3 MB | 200    | 11.34 |
63 | | MobileNet V2 | [mobilenetv2.pt](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/mobilenetv2.pt) | [mobilenetv2_gaze.onnx](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/mobilenetv2_gaze.onnx) | 9.59 MB | 200    | 13.07 |
64 | | MobileOne S0 | [mobileone_s0.pt](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/mobileone_s0.pt) | [mobileone_s0_gaze.onnx](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/mobileone_s0_gaze.onnx) | 4.8 MB  | 200    | 12.58 |
65 | | MobileOne S1 | [not available](#) | [not available](#) | xx MB   | 200    | \*    |
66 | | MobileOne S2 | [not available](#) | [not available](#) | xx MB   | 200    | \*    |
67 | | MobileOne S3 | [not available](#) | [not available](#) | xx MB   | 200    | \*    |
68 | | MobileOne S4 | [not available](#) | [not available](#) | xx MB   | 200    | \*    |
69 |
70 | '\*' - will be uploaded soon (due to limited computing resources, the remaining weights have not been published yet, but you can still train them with the provided code).
71 |
72 | b) Run the command below to download weights to the `weights` directory (Linux):
73 |
74 | ```bash
75 | sh download.sh [model_name]
76 | resnet18
77 | resnet34
78 | resnet50
79 | mobilenetv2
80 | mobileone_s0
81 | mobileone_s1
82 | mobileone_s2
83 | mobileone_s3
84 | mobileone_s4
85 | ```
86 |
87 | ## Usage
88 |
89 | ### Datasets
90 |
91 | Dataset folder structure:
92 |
93 | ```
94 | data/
95 | ├── Gaze360/
96 | │   ├── Image/
97 | │   └── Label/
98 | └── MPIIFaceGaze/
99 |     ├── Image/
100 |     └── Label/
101 | ```
102 |
103 | **Gaze360**
104 |
105 | - Link to download dataset: https://gaze360.csail.mit.edu/download.php
106 | - Data pre-processing code: https://phi-ai.buaa.edu.cn/Gazehub/3D-dataset/#gaze360
107 |
108 | **MPIIFaceGaze**
109 |
110 | - Link to download dataset: [download page](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/research/gaze-based-human-computer-interaction/its-written-all-over-your-face-full-face-appearance-based-gaze-estimation)
111 | - Data pre-processing code: https://phi-ai.buaa.edu.cn/Gazehub/3D-dataset/#mpiifacegaze
112 |
113 | ### Training
114 |
115 | ```bash
116 | python main.py --data [dataset_path] --dataset [dataset_name] --arch [architecture_name]
117 | ```
118 |
119 | `main.py` arguments:
120 |
121 | ```
122 | usage: main.py [-h] [--data DATA] [--dataset DATASET] [--output OUTPUT] [--checkpoint CHECKPOINT] [--num-epochs NUM_EPOCHS] [--batch-size BATCH_SIZE] [--arch ARCH] [--alpha ALPHA] [--lr LR] [--num-workers NUM_WORKERS]
123 |
124 | Gaze estimation training.
125 |
126 | options:
127 |   -h, --help            show this help message and exit
128 |   --data DATA           Directory path for gaze images.
129 |   --dataset DATASET     Dataset name, available `gaze360`, `mpiigaze`.
130 |   --output OUTPUT       Path of output models.
131 |   --checkpoint CHECKPOINT
132 |                         Path to checkpoint for resuming training.
133 |   --num-epochs NUM_EPOCHS
134 |                         Maximum number of training epochs.
135 |   --batch-size BATCH_SIZE
136 |                         Batch size.
137 |   --arch ARCH           Network architecture, currently available: resnet18/34/50, mobilenetv2, mobileone_s0-s4.
138 |   --alpha ALPHA         Regression loss coefficient.
139 |   --lr LR               Base learning rate.
140 |   --num-workers NUM_WORKERS
141 |                         Number of workers for data loading.
142 | ```
143 |
144 | ### Evaluation
145 |
146 | ```bash
147 | python evaluate.py --data [dataset_path] --dataset [dataset_name] --weight [weight_path] --arch [architecture_name]
148 | ```
149 |
150 | `evaluate.py` arguments:
151 |
152 | ```
153 | usage: evaluate.py [-h] [--data DATA] [--dataset DATASET] [--weight WEIGHT] [--batch-size BATCH_SIZE] [--arch ARCH] [--num-workers NUM_WORKERS]
154 |
155 | Gaze estimation evaluation.
156 |
157 | options:
158 |   -h, --help            show this help message and exit
159 |   --data DATA           Directory path for gaze images.
160 |   --dataset DATASET     Dataset name, available `gaze360`, `mpiigaze`
161 |   --weight WEIGHT       Path to model weight for evaluation.
162 |   --batch-size BATCH_SIZE
163 |                         Batch size.
164 |   --arch ARCH           Network architecture, currently available: resnet18/34/50, mobilenetv2, mobileone_s0-s4.
165 |   --num-workers NUM_WORKERS
166 |                         Number of workers for data loading.
167 | ```
168 |
169 | ### Inference
170 |
171 | ```bash
172 | python inference.py --model [model_name] --weight [model_weight_path] --view --source [source_video / cam_index] --output [output_file] --dataset [dataset_name]
173 | ```
174 |
175 | `inference.py` arguments:
176 |
177 | ```
178 | usage: inference.py [-h] [--model MODEL] [--weight WEIGHT] [--view] [--source SOURCE] [--output OUTPUT] [--dataset DATASET]
179 |
180 | Gaze estimation inference
181 |
182 | options:
183 |   -h, --help         show this help message and exit
184 |   --model MODEL      Model name, default `resnet34`
185 |   --weight WEIGHT    Path to gaze estimation model weights
186 |   --view             Display the inference results
187 |   --source SOURCE    Path to source video file or camera index
188 |   --output OUTPUT    Path to save output file
189 |   --dataset DATASET  Dataset name to get dataset related configs
190 | ```
191 |
192 | ### ONNX Export and Inference
193 |
194 | **Export to ONNX**
195 |
196 | ```bash
197 | python onnx_export.py --weight [model_path] --model [model_name] --dynamic
198 | ```
199 |
200 | `onnx_export.py` arguments:
201 |
202 | ```
203 | usage: onnx_export.py [-h] [-w WEIGHT] [-n {resnet18,resnet34,resnet50,mobilenetv2,mobileone_s0}] [-d {gaze360,mpiigaze}] [--dynamic]
204 |
205 | Gaze Estimation Model ONNX Export
206 |
207 | options:
208 |   -h, --help            show this help message and exit
209 |   -w WEIGHT, --weight WEIGHT
210 |                         Trained state_dict file path to open
211 |   -n {resnet18,resnet34,resnet50,mobilenetv2,mobileone_s0}, --model {resnet18,resnet34,resnet50,mobilenetv2,mobileone_s0}
212 |                         Backbone network architecture to use
213 |   -d {gaze360,mpiigaze}, --dataset {gaze360,mpiigaze}
214 |                         Dataset name for bin configuration
215 |   --dynamic             Enable dynamic batch size and input dimensions for ONNX export
216 | ```
217 |
218 | **ONNX Inference**
219 |
220 | ```bash
221 | python onnx_inference.py --source [source video / webcam index] --model [onnx model path] --output [path to save video]
222 | ```
223 |
224 | `onnx_inference.py` arguments (a minimal hand-rolled decoding sketch appears in the appendix at the end of this README):
225 |
226 | ```
227 | usage: onnx_inference.py [-h] --source SOURCE --model MODEL [--output OUTPUT]
228 |
229 | Gaze Estimation ONNX Inference
230 |
231 | options:
232 |   -h, --help       show this help message and exit
233 |   --source SOURCE  Video path or camera index (e.g., 0 for webcam)
234 |   --model MODEL    Path to ONNX model
235 |   --output OUTPUT  Path to save output video (optional)
236 | ```
237 |
238 | ## Citation
239 |
240 | If you use this work in your research, please cite it as:
241 |
242 | Valikhujaev, Y. (2024). MobileGaze: Pre-trained mobile nets for Gaze-Estimation. Zenodo. [https://doi.org/10.5281/zenodo.14257640](https://doi.org/10.5281/zenodo.14257640)
243 |
244 | Alternatively, in BibTeX format:
245 |
246 | ```bibtex
247 | @misc{valikhujaev2024mobilegaze,
248 |   author    = {Valikhujaev, Y.},
249 |   title     = {MobileGaze: Pre-trained mobile nets for Gaze-Estimation},
250 |   year      = {2024},
251 |   publisher = {Zenodo},
252 |   doi       = {10.5281/zenodo.14257640},
253 |   url       = {https://doi.org/10.5281/zenodo.14257640}
254 | }
255 | ```
256 |
257 | ## Reference
258 |
259 | 1. This project is built on top of [L2CS-Net](https://github.com/Ahmednull/L2CS-Net). Most parts of the code have been rewritten for reproducibility and adaptability. Several additional backbones are provided with pre-trained weights.
260 | 2. https://github.com/apple/ml-mobileone
261 | 3. [uniface](https://github.com/yakhyo/uniface) - face detection library used for inference in `inference.py`.
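
**Appendix: decoding ONNX outputs by hand**

The sketch below shows roughly what the ONNX inference path has to do; it is not the repository's actual `onnx_inference.py`. It assumes a Gaze360 export (90 bins, bin width 4, 180-degree offset, per `config.py`), the same ImageNet normalization used in `inference.py`, that the exported graph keeps the model's `(pitch, yaw)` output order, and an already-cropped face image; the file names are placeholders.

```python
import cv2
import numpy as np
import onnxruntime as ort

BINS, BINWIDTH, ANGLE = 90, 4, 180  # "gaze360" entry in config.py
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


session = ort.InferenceSession("weights/resnet18_gaze.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name

# Preprocess a cropped face: RGB, 448x448, ImageNet mean/std, NCHW float32.
# (The PyTorch pipeline resizes the shorter side via torchvision; a plain
# square resize is a close-enough approximation for this sketch.)
face = cv2.cvtColor(cv2.imread("face_crop.jpg"), cv2.COLOR_BGR2RGB)
face = cv2.resize(face, (448, 448)).astype(np.float32) / 255.0
blob = ((face - MEAN) / STD).transpose(2, 0, 1)[None]

pitch_logits, yaw_logits = session.run(None, {input_name: blob})

# Expected value over bin indices, mapped to degrees and centered on zero,
# the same decoding used in evaluate.py and inference.py.
idx = np.arange(BINS, dtype=np.float32)
pitch_deg = (softmax(pitch_logits) * idx).sum(axis=1) * BINWIDTH - ANGLE
yaw_deg = (softmax(yaw_logits) * idx).sum(axis=1) * BINWIDTH - ANGLE
print(np.radians(pitch_deg), np.radians(yaw_deg))  # radians, as consumed by draw_bbox_gaze
```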
262 |
263 |
268 |
--------------------------------------------------------------------------------
/assets/in_video.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yakhyo/gaze-estimation/2ccde1d08d007727c2df1ce704c32e2683b2d0b9/assets/in_video.mp4
--------------------------------------------------------------------------------
/assets/out_gif.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yakhyo/gaze-estimation/2ccde1d08d007727c2df1ce704c32e2683b2d0b9/assets/out_gif.gif
--------------------------------------------------------------------------------
/assets/out_video.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yakhyo/gaze-estimation/2ccde1d08d007727c2df1ce704c32e2683b2d0b9/assets/out_video.mp4
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | data_config = {
2 |     "gaze360":
3 |     {
4 |         "bins": 90,
5 |         "binwidth": 4,
6 |         "angle": 180  # half of the total angle range, in degrees
7 |     },
8 |     "mpiigaze":
9 |     {
10 |         "bins": 28,
11 |         "binwidth": 3,
12 |         "angle": 42  # half of the total angle range, in degrees
13 |     }
14 |
15 | }
--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yakhyo/gaze-estimation/2ccde1d08d007727c2df1ce704c32e2683b2d0b9/data/.gitkeep
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Define the base URL for the downloads
4 | BASE_URL="https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1"
5 |
6 | # Create the weights directory if it does not exist
7 | mkdir -p weights
8 |
9 | # Check if a model name was provided
10 | if [ -z "$1" ]; then
11 |     echo "Usage: $0 <model_name>"
12 |     echo "Example: $0 resnet18"
13 |     exit 1
14 | fi
15 |
16 | # Determine the model name
17 | MODEL_NAME=$1
18 | MODEL_FILE="${MODEL_NAME}.pt"
19 | MODEL_FILE_ONNX="${MODEL_NAME}_gaze.onnx"
20 |
21 | # Download the PyTorch and ONNX weights
22 | wget -O weights/$MODEL_FILE $BASE_URL/$MODEL_FILE
23 | wget -O weights/$MODEL_FILE_ONNX $BASE_URL/$MODEL_FILE_ONNX
24 |
25 | # Check if the download was successful
26 | if [ $?
-eq 0 ]; then 27 | echo "Downloaded $MODEL_FILE and $MODEL_FILE_ONNX to weights/" 28 | else 29 | echo "Failed to download $MODEL_FILE" 30 | fi -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import logging 4 | 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | 9 | from tqdm import tqdm 10 | import torch.nn.functional as F 11 | 12 | from config import data_config 13 | from utils.helpers import angular_error, gaze_to_3d, get_dataloader, get_model 14 | 15 | import warnings 16 | warnings.filterwarnings("ignore") 17 | # Setup logging 18 | logging.basicConfig(level=logging.INFO, format='%(message)s') 19 | 20 | 21 | def parse_args(): 22 | """Parse input arguments.""" 23 | parser = argparse.ArgumentParser(description="Gaze estimation evaluation") 24 | parser.add_argument("--data", type=str, default="data/Gaze360", help="Directory path for gaze images.") 25 | parser.add_argument("--dataset", type=str, default="gaze360", help="Dataset name, available `gaze360`, `mpiigaze`") 26 | parser.add_argument("--weight", type=str, default="", help="Path to model weight for evaluation.") 27 | parser.add_argument("--batch-size", type=int, default=64, help="Batch size.") 28 | parser.add_argument( 29 | "--arch", 30 | type=str, 31 | default="resnet18", 32 | help="Network architecture, currently available: resnet18/34/50, mobilenetv2, mobileone_s0-s4." 33 | ) 34 | parser.add_argument("--num-workers", type=int, default=8, help="Number of workers for data loading.") 35 | 36 | args = parser.parse_args() 37 | 38 | # Override default values based on selected dataset 39 | if args.dataset in data_config: 40 | dataset_config = data_config[args.dataset] 41 | args.bins = dataset_config["bins"] 42 | args.binwidth = dataset_config["binwidth"] 43 | args.angle = dataset_config["angle"] 44 | else: 45 | raise ValueError(f"Unknown dataset: {args.dataset}. Available options: {list(data_config.keys())}") 46 | 47 | return args 48 | 49 | 50 | @torch.no_grad() 51 | def evaluate(params, model, data_loader, idx_tensor, device): 52 | """ 53 | Evaluate the model on the test dataset. 54 | 55 | Args: 56 | params (argparse.Namespace): Parsed command-line arguments. 57 | model (nn.Module): The gaze estimation model. 58 | data_loader (torch.utils.data.DataLoader): DataLoader for the test dataset. 59 | idx_tensor (torch.Tensor): Tensor representing bin indices. 60 | device (torch.device): Device to perform evaluation on. 
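
    Computes a 3D angular error for each sample and logs the mean over the whole test set.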
61 | """ 62 | model.eval() 63 | average_error = 0 64 | total_samples = 0 65 | 66 | for images, labels_gaze, regression_labels_gaze, _ in tqdm(data_loader, total=len(data_loader)): 67 | total_samples += regression_labels_gaze.size(0) 68 | images = images.to(device) 69 | 70 | # Regression labels 71 | label_pitch = np.radians(regression_labels_gaze[:, 0], dtype=np.float32) 72 | label_yaw = np.radians(regression_labels_gaze[:, 1], dtype=np.float32) 73 | 74 | # Inference 75 | pitch, yaw = model(images) 76 | 77 | # Regression predictions 78 | pitch_predicted = F.softmax(pitch, dim=1) 79 | yaw_predicted = F.softmax(yaw, dim=1) 80 | 81 | # Mapping from binned (0 to 90) to angles (-180 to 180) or (0 to 28) to angles (-42, 42) 82 | pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) * params.binwidth - params.angle 83 | yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) * params.binwidth - params.angle 84 | 85 | pitch_predicted = np.radians(pitch_predicted.cpu()) 86 | yaw_predicted = np.radians(yaw_predicted.cpu()) 87 | 88 | for p, y, pl, yl in zip(pitch_predicted, yaw_predicted, label_pitch, label_yaw): 89 | average_error += angular_error(gaze_to_3d([p, y]), gaze_to_3d([pl, yl])) 90 | 91 | logging.info( 92 | f"Dataset: {params.dataset} | " 93 | f"Total Number of Samples: {total_samples} | " 94 | f"Mean Angular Error: {average_error/total_samples}" 95 | ) 96 | 97 | 98 | def main(): 99 | params = parse_args() 100 | 101 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 102 | torch.backends.cudnn.benchmark = True 103 | 104 | model = get_model(params.arch, params.bins, inference_mode=True) 105 | 106 | if os.path.exists(params.weight): 107 | model.load_state_dict(torch.load(params.weight, map_location=device, weights_only=True)) 108 | else: 109 | raise ValueError(f"Model weight not found at {params.weight}") 110 | 111 | model.to(device) 112 | test_loader = get_dataloader(params, mode="test") 113 | 114 | idx_tensor = torch.arange(params.bins, device=device, dtype=torch.float32) 115 | 116 | logging.info("Start Evaluation") 117 | evaluate(params, model, test_loader, idx_tensor, device) 118 | 119 | 120 | if __name__ == '__main__': 121 | main() 122 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import logging 3 | import argparse 4 | import warnings 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torchvision import transforms 10 | 11 | from config import data_config 12 | from utils.helpers import get_model, draw_bbox_gaze 13 | 14 | import uniface 15 | 16 | warnings.filterwarnings("ignore") 17 | logging.basicConfig(level=logging.INFO, format='%(message)s') 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser(description="Gaze estimation inference") 22 | parser.add_argument("--model", type=str, default="resnet34", help="Model name, default `resnet18`") 23 | parser.add_argument( 24 | "--weight", 25 | type=str, 26 | default="resnet34.pt", 27 | help="Path to gaze esimation model weights" 28 | ) 29 | parser.add_argument("--view", action="store_true", default=True, help="Display the inference results") 30 | parser.add_argument("--source", type=str, default="assets/in_video.mp4", 31 | help="Path to source video file or camera index") 32 | parser.add_argument("--output", type=str, default="output.mp4", help="Path to save output file") 33 | parser.add_argument("--dataset", 
type=str, default="gaze360", help="Dataset name to get dataset related configs") 34 | args = parser.parse_args() 35 | 36 | # Override default values based on selected dataset 37 | if args.dataset in data_config: 38 | dataset_config = data_config[args.dataset] 39 | args.bins = dataset_config["bins"] 40 | args.binwidth = dataset_config["binwidth"] 41 | args.angle = dataset_config["angle"] 42 | else: 43 | raise ValueError(f"Unknown dataset: {args.dataset}. Available options: {list(data_config.keys())}") 44 | 45 | return args 46 | 47 | 48 | def pre_process(image): 49 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 50 | transform = transforms.Compose([ 51 | transforms.ToPILImage(), 52 | transforms.Resize(448), 53 | transforms.ToTensor(), 54 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 55 | ]) 56 | 57 | image = transform(image) 58 | image_batch = image.unsqueeze(0) 59 | return image_batch 60 | 61 | 62 | def main(params): 63 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 64 | 65 | idx_tensor = torch.arange(params.bins, device=device, dtype=torch.float32) 66 | 67 | face_detector = uniface.RetinaFace() # third-party face detection library 68 | 69 | try: 70 | gaze_detector = get_model(params.model, params.bins, inference_mode=True) 71 | state_dict = torch.load(params.weight, map_location=device) 72 | gaze_detector.load_state_dict(state_dict) 73 | logging.info("Gaze Estimation model weights loaded.") 74 | except Exception as e: 75 | logging.info(f"Exception occured while loading pre-trained weights of gaze estimation model. Exception: {e}") 76 | 77 | gaze_detector.to(device) 78 | gaze_detector.eval() 79 | 80 | video_source = params.source 81 | if video_source.isdigit() or video_source == '0': 82 | cap = cv2.VideoCapture(int(video_source)) 83 | else: 84 | cap = cv2.VideoCapture(video_source) 85 | 86 | if params.output: 87 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 88 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 89 | fourcc = cv2.VideoWriter_fourcc(*"mp4v") 90 | out = cv2.VideoWriter(params.output, fourcc, cap.get(cv2.CAP_PROP_FPS), (width, height)) 91 | 92 | if not cap.isOpened(): 93 | raise IOError("Cannot open webcam") 94 | 95 | with torch.no_grad(): 96 | while True: 97 | success, frame = cap.read() 98 | 99 | if not success: 100 | logging.info("Failed to obtain frame or EOF") 101 | break 102 | 103 | bboxes, keypoints = face_detector.detect(frame) 104 | for bbox, keypoint in zip(bboxes, keypoints): 105 | x_min, y_min, x_max, y_max = map(int, bbox[:4]) 106 | 107 | image = frame[y_min:y_max, x_min:x_max] 108 | image = pre_process(image) 109 | image = image.to(device) 110 | 111 | pitch, yaw = gaze_detector(image) 112 | 113 | pitch_predicted, yaw_predicted = F.softmax(pitch, dim=1), F.softmax(yaw, dim=1) 114 | 115 | # Mapping from binned (0 to 90) to angles (-180 to 180) or (0 to 28) to angles (-42, 42) 116 | pitch_predicted = torch.sum(pitch_predicted * idx_tensor, dim=1) * params.binwidth - params.angle 117 | yaw_predicted = torch.sum(yaw_predicted * idx_tensor, dim=1) * params.binwidth - params.angle 118 | 119 | # Degrees to Radians 120 | pitch_predicted = np.radians(pitch_predicted.cpu()) 121 | yaw_predicted = np.radians(yaw_predicted.cpu()) 122 | 123 | # draw box and gaze direction 124 | draw_bbox_gaze(frame, bbox, pitch_predicted, yaw_predicted) 125 | 126 | if params.output: 127 | out.write(frame) 128 | 129 | if params.view: 130 | cv2.imshow('Demo', frame) 131 | if cv2.waitKey(1) & 0xFF == ord('q'): 132 | break 133 | 134 | 
cap.release() 135 | if params.output: 136 | out.release() 137 | cv2.destroyAllWindows() 138 | 139 | 140 | if __name__ == "__main__": 141 | args = parse_args() 142 | 143 | if not args.view and not args.output: 144 | raise Exception("At least one of --view or --ouput must be provided.") 145 | 146 | main(args) 147 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | import argparse 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.utils.data import DataLoader 10 | 11 | from config import data_config 12 | from utils.helpers import get_model, get_dataloader 13 | 14 | # Setup logging 15 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s') 16 | 17 | 18 | def parse_args(): 19 | """Parse input arguments.""" 20 | parser = argparse.ArgumentParser(description="Gaze estimation training") 21 | parser.add_argument("--data", type=str, default="data", help="Directory path for gaze images.") 22 | parser.add_argument("--dataset", type=str, default="gaze360", help="Dataset name, available `gaze360`, `mpiigaze`.") 23 | parser.add_argument("--output", type=str, default="output/", help="Path of output models.") 24 | parser.add_argument("--checkpoint", type=str, default="", help="Path to checkpoint for resuming training.") 25 | parser.add_argument("--num-epochs", type=int, default=100, help="Maximum number of training epochs.") 26 | parser.add_argument("--batch-size", type=int, default=64, help="Batch size.") 27 | parser.add_argument( 28 | "--arch", 29 | type=str, 30 | default="resnet18", 31 | help="Network architecture, currently available: resnet18/34/50, mobilenetv2, mobileone_s0-s4." 32 | ) 33 | parser.add_argument("--alpha", type=float, default=1, help="Regression loss coefficient.") 34 | parser.add_argument("--lr", type=float, default=0.00001, help="Base learning rate.") 35 | parser.add_argument("--num-workers", type=int, default=8, help="Number of workers for data loading.") 36 | 37 | args = parser.parse_args() 38 | 39 | # Override default values based on selected dataset 40 | if args.dataset in data_config: 41 | dataset_config = data_config[args.dataset] 42 | args.bins = dataset_config["bins"] 43 | args.binwidth = dataset_config["binwidth"] 44 | args.angle = dataset_config["angle"] 45 | else: 46 | raise ValueError(f"Unknown dataset: {args.dataset}. Available options: {list(data_config.keys())}") 47 | 48 | return args 49 | 50 | 51 | def initialize_model(params, device): 52 | """ 53 | Initialize the gaze estimation model, optimizer, and optionally load a checkpoint. 54 | 55 | Args: 56 | params (argparse.Namespace): Parsed command-line arguments. 57 | device (torch.device): Device to load the model and optimizer onto. 58 | 59 | Returns: 60 | Tuple[nn.Module, torch.optim.Optimizer, int]: Initialized model, optimizer, and the starting epoch. 
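
    Note:
        If `params.checkpoint` is empty, the model starts from pretrained backbone
        weights at epoch 0.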
61 | """ 62 | model = get_model(params.arch, params.bins, pretrained=True) 63 | optimizer = torch.optim.Adam(model.parameters(), lr=params.lr) 64 | start_epoch = 0 65 | 66 | if params.checkpoint: 67 | checkpoint = torch.load(params.checkpoint, map_location=device) 68 | model.load_state_dict(checkpoint['model_state_dict']) 69 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 70 | 71 | # Move optimizer states to device 72 | for state in optimizer.state.values(): 73 | for k, v in state.items(): 74 | if isinstance(v, torch.Tensor): 75 | state[k] = v.to(device) 76 | 77 | start_epoch = checkpoint['epoch'] 78 | logging.info(f'Resumed training from {params.checkpoint}, starting at epoch {start_epoch + 1}') 79 | 80 | return model.to(device), optimizer, start_epoch 81 | 82 | 83 | def train_one_epoch( 84 | params, 85 | model, 86 | cls_criterion, 87 | reg_criterion, 88 | optimizer, 89 | data_loader, 90 | idx_tensor, 91 | device, 92 | epoch 93 | ): 94 | """ 95 | Train the model for one epoch. 96 | 97 | Args: 98 | params (argparse.Namespace): Parsed command-line arguments. 99 | model (nn.Module): The gaze estimation model. 100 | cls_criterion (nn.Module): Loss function for classification. 101 | reg_criterion (nn.Module): Loss function for regression. 102 | optimizer (torch.optim.Optimizer): Optimizer for the model. 103 | data_loader (DataLoader): DataLoader for the training dataset. 104 | idx_tensor (torch.Tensor): Tensor representing bin indices. 105 | device (torch.device): Device to perform training on. 106 | epoch (int): The current epoch number. 107 | 108 | Returns: 109 | Tuple[float, float]: Average losses for pitch and yaw. 110 | """ 111 | 112 | model.train() 113 | sum_loss_pitch, sum_loss_yaw = 0, 0 114 | 115 | for idx, (images, labels_gaze, regression_labels_gaze, _) in enumerate(data_loader): 116 | images = images.to(device) 117 | 118 | # Binned labels 119 | label_pitch = labels_gaze[:, 0].to(device) 120 | label_yaw = labels_gaze[:, 1].to(device) 121 | 122 | # Regression labels 123 | label_pitch_regression = regression_labels_gaze[:, 0].to(device) 124 | label_yaw_regression = regression_labels_gaze[:, 1].to(device) 125 | 126 | # Inference 127 | pitch, yaw = model(images) 128 | 129 | # Cross Entropy Loss 130 | loss_pitch = cls_criterion(pitch, label_pitch) 131 | loss_yaw = cls_criterion(yaw, label_yaw) 132 | 133 | # Softmax 134 | pitch, yaw = F.softmax(pitch, dim=1), F.softmax(yaw, dim=1) 135 | 136 | # Mapping from binned (0 to 90) to angels (-180 to 180) 137 | pitch_predicted = torch.sum(pitch * idx_tensor, 1) * params.binwidth - params.angle 138 | yaw_predicted = torch.sum(yaw * idx_tensor, 1) * params.binwidth - params.angle 139 | 140 | # Mean Squared Error Loss 141 | loss_regression_pitch = reg_criterion(pitch_predicted, label_pitch_regression) 142 | loss_regression_yaw = reg_criterion(yaw_predicted, label_yaw_regression) 143 | 144 | # Calculate loss with regression alpha 145 | loss_pitch += params.alpha * loss_regression_pitch 146 | loss_yaw += params.alpha * loss_regression_yaw 147 | 148 | # Total loss for pitch and yaw 149 | loss = loss_pitch + loss_yaw 150 | 151 | optimizer.zero_grad() 152 | loss.backward() 153 | optimizer.step() 154 | 155 | sum_loss_pitch += loss_pitch.item() 156 | sum_loss_yaw += loss_yaw.item() 157 | 158 | if (idx + 1) % 100 == 0: 159 | logging.info( 160 | f'Epoch [{epoch + 1}/{params.num_epochs}], Iter [{idx + 1}/{len(data_loader)}] ' 161 | f'Losses: Gaze Yaw {sum_loss_yaw / (idx + 1):.4f}, Gaze Pitch {sum_loss_pitch / (idx + 1):.4f}' 162 | ) 163 | 
avg_loss_pitch, avg_loss_yaw = sum_loss_pitch / len(data_loader), sum_loss_yaw / len(data_loader) 164 | 165 | return avg_loss_pitch, avg_loss_yaw 166 | 167 | 168 | def main(): 169 | params = parse_args() 170 | 171 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 172 | summary_name = f'{params.dataset}_{params.arch}_{int(time.time())}' 173 | output = os.path.join(params.output, summary_name) 174 | if not os.path.exists(output): 175 | os.makedirs(output) 176 | torch.backends.cudnn.benchmark = True 177 | 178 | model, optimizer, start_epoch = initialize_model(params, device) 179 | train_loader = get_dataloader(params, mode="train") 180 | 181 | cls_criterion = nn.CrossEntropyLoss() 182 | reg_criterion = nn.MSELoss() 183 | idx_tensor = torch.arange(params.bins, device=device, dtype=torch.float32) 184 | 185 | best_loss = float('inf') 186 | print(f"Started training from epoch: {start_epoch + 1}") 187 | 188 | for epoch in range(start_epoch, params.num_epochs): 189 | avg_loss_pitch, avg_loss_yaw = train_one_epoch( 190 | params, 191 | model, 192 | cls_criterion, 193 | reg_criterion, 194 | optimizer, 195 | train_loader, 196 | idx_tensor, 197 | device, 198 | epoch 199 | ) 200 | 201 | logging.info( 202 | f'Epoch [{epoch + 1}/{params.num_epochs}] ' 203 | f'Losses: Gaze Yaw {avg_loss_yaw:.4f}, Gaze Pitch {avg_loss_pitch:.4f}' 204 | ) 205 | 206 | checkpoint_path = os.path.join(output, "checkpoint.ckpt") 207 | torch.save({ 208 | 'epoch': epoch + 1, 209 | 'model_state_dict': model.state_dict(), 210 | 'optimizer_state_dict': optimizer.state_dict(), 211 | 'loss': avg_loss_pitch + avg_loss_yaw, 212 | }, checkpoint_path) 213 | logging.info(f'Checkpoint saved at {checkpoint_path}') 214 | 215 | current_loss = (avg_loss_pitch + avg_loss_yaw) / len(train_loader) 216 | if current_loss < best_loss: 217 | best_loss = current_loss 218 | best_model_path = os.path.join(output, 'best_model.pt') 219 | torch.save(model.state_dict(), best_model_path) 220 | logging.info(f'Best model saved at {best_model_path}') 221 | 222 | 223 | if __name__ == '__main__': 224 | main() 225 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import resnet18, resnet34, resnet50 2 | from .mobilenet import mobilenet_v2 3 | from .mobileone import mobileone_s0, mobileone_s1, mobileone_s2, mobileone_s3, mobileone_s4, reparameterize_model 4 | -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torchvision.models import MobileNet_V2_Weights 4 | 5 | from typing import Any, Callable, List, Optional 6 | 7 | __all__ = ["mobilenet_v2"] 8 | 9 | 10 | def _make_divisible(v: float, divisor: int = 8) -> int: 11 | """This function ensures that all layers have a channel number divisible by 8""" 12 | new_v = max(divisor, int(v + divisor / 2) // divisor * divisor) 13 | # Make sure that round down does not go down by more than 10%. 
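    # e.g. v = 10 rounds to 8, which is more than a 10% drop, so it is bumped up to 16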
14 |     if new_v < 0.9 * v:
15 |         new_v += divisor
16 |     return new_v
17 |
18 |
19 | class Conv2dNormActivation(torch.nn.Sequential):
20 |     """Convolutional block, consists of nn.Conv2d, nn.BatchNorm2d, nn.ReLU"""
21 |
22 |     def __init__(
23 |         self,
24 |         in_channels: int,
25 |         out_channels: int,
26 |         kernel_size: int = 3,
27 |         stride: int = 1,
28 |         padding: Optional[int] = None,
29 |         groups: int = 1,
30 |         activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
31 |         dilation: int = 1,
32 |         inplace: Optional[bool] = True,
33 |         bias: bool = False,
34 |     ) -> None:
35 |
36 |         if padding is None:
37 |             padding = (kernel_size - 1) // 2 * dilation
38 |
39 |         layers: List[nn.Module] = [
40 |             nn.Conv2d(
41 |                 in_channels=in_channels,
42 |                 out_channels=out_channels,
43 |                 kernel_size=kernel_size,
44 |                 stride=stride,
45 |                 padding=padding,
46 |                 dilation=dilation,
47 |                 groups=groups,
48 |                 bias=bias,
49 |             ),
50 |             nn.BatchNorm2d(num_features=out_channels, eps=0.001, momentum=0.01)
51 |         ]
52 |
53 |         if activation_layer is not None:
54 |             params = {} if inplace is None else {"inplace": inplace}
55 |             layers.append(activation_layer(**params))
56 |         super().__init__(*layers)
57 |
58 |
59 | class InvertedResidual(nn.Module):
60 |     def __init__(self, in_planes: int, out_planes: int, stride: int, expand_ratio: int) -> None:
61 |         super().__init__()
62 |         self.stride = stride
63 |         if stride not in [1, 2]:
64 |             raise ValueError(f"stride should be 1 or 2 instead of {stride}")
65 |
66 |         hidden_dim = int(round(in_planes * expand_ratio))
67 |         self.use_res_connect = self.stride == 1 and in_planes == out_planes
68 |
69 |         layers: List[nn.Module] = []
70 |         if expand_ratio != 1:
71 |             # pw
72 |             layers.append(
73 |                 Conv2dNormActivation(
74 |                     in_planes,
75 |                     hidden_dim,
76 |                     kernel_size=1,
77 |                     activation_layer=nn.ReLU6
78 |                 )
79 |             )
80 |         layers.extend(
81 |             [
82 |                 # dw
83 |                 Conv2dNormActivation(
84 |                     hidden_dim,
85 |                     hidden_dim,
86 |                     stride=stride,
87 |                     groups=hidden_dim,
88 |                     activation_layer=nn.ReLU6,
89 |                 ),
90 |                 # pw-linear
91 |                 nn.Conv2d(hidden_dim, out_planes, 1, 1, 0, bias=False),
92 |                 nn.BatchNorm2d(out_planes),
93 |             ]
94 |         )
95 |         self.conv = nn.Sequential(*layers)
96 |         self.out_channels = out_planes
97 |         self._is_cn = stride > 1
98 |
99 |     def forward(self, x: Tensor) -> Tensor:
100 |         if self.use_res_connect:
101 |             return x + self.conv(x)
102 |         else:
103 |             return self.conv(x)
104 |
105 |
106 | class MobileNetV2(nn.Module):
107 |     def __init__(
108 |         self,
109 |         num_classes: int = 1000,
110 |         width_mult: float = 1.0,
111 |         inverted_residual_setting: Optional[List[List[int]]] = None,
112 |         round_nearest: int = 8,
113 |         dropout: float = 0.2,
114 |     ) -> None:
115 |         """
116 |         MobileNet V2 main class
117 |
118 |         Args:
119 |             num_classes (int): Number of classes
120 |             width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
121 |             inverted_residual_setting: Network structure
122 |             round_nearest (int): Round the number of channels in each layer to be a multiple of this number
123 |                 Set to 1 to turn off rounding
124 |             dropout (float): The dropout probability
125 |
126 |         """
127 |
128 |         super().__init__()
129 |
130 |         input_channel = 32
131 |         last_channel = 1280
132 |
133 |         if inverted_residual_setting is None:
134 |             inverted_residual_setting = [
135 |                 # t, c, n, s
136 |                 [1, 16, 1, 1],
137 |                 [6, 24, 2, 2],
138 |                 [6, 32, 3, 2],
139 |                 [6, 64, 4, 2],
140 |                 [6, 96, 3, 1],
141 |                 [6, 160, 3, 2],
142 |                 [6, 320, 1, 1],
143 |             ]
144 |
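        # t: expansion factor, c: output channels, n: number of repeats, s: stride of the first block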
145 | # only check the first element, assuming user knows t,c,n,s are required 146 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 147 | raise ValueError( 148 | f"inverted_residual_setting should be non-empty or a 4-element list, got {inverted_residual_setting}" 149 | ) 150 | 151 | # building first layer 152 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 153 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 154 | features: List[nn.Module] = [ 155 | Conv2dNormActivation(3, input_channel, stride=2, activation_layer=nn.ReLU6) 156 | ] 157 | # building inverted residual blocks 158 | for t, c, n, s in inverted_residual_setting: 159 | output_channel = _make_divisible(c * width_mult, round_nearest) 160 | for i in range(n): 161 | stride = s if i == 0 else 1 162 | features.append(InvertedResidual(input_channel, output_channel, stride, expand_ratio=t)) 163 | input_channel = output_channel 164 | # building last several layers 165 | features.append( 166 | Conv2dNormActivation( 167 | input_channel, self.last_channel, kernel_size=1, activation_layer=nn.ReLU6 168 | ) 169 | ) 170 | # make it nn.Sequential 171 | self.features = nn.Sequential(*features) 172 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 173 | 174 | # building classifier 175 | # self.classifier = nn.Sequential( 176 | # nn.Dropout(p=dropout), 177 | # nn.Linear(self.last_channel, num_classes), 178 | # ) 179 | 180 | self.fc_yaw = nn.Linear(self.last_channel, num_classes) 181 | self.fc_pitch = nn.Linear(self.last_channel, num_classes) 182 | 183 | # weight initialization 184 | for m in self.modules(): 185 | if isinstance(m, nn.Conv2d): 186 | nn.init.kaiming_normal_(m.weight, mode="fan_out") 187 | if m.bias is not None: 188 | nn.init.zeros_(m.bias) 189 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 190 | nn.init.ones_(m.weight) 191 | nn.init.zeros_(m.bias) 192 | elif isinstance(m, nn.Linear): 193 | nn.init.normal_(m.weight, 0, 0.01) 194 | nn.init.zeros_(m.bias) 195 | 196 | def forward(self, x: Tensor) -> Tensor: 197 | x = self.features(x) 198 | # Cannot use "squeeze" as batch-size can be 1 199 | x = self.avgpool(x) 200 | x = torch.flatten(x, 1) 201 | 202 | # Original FC layer from MobileNet V2 203 | # x = self.classifier(x) 204 | 205 | yaw = self.fc_yaw(x) 206 | pitch = self.fc_pitch(x) 207 | 208 | return pitch, yaw 209 | 210 | 211 | def load_filtered_state_dict(model, state_dict): 212 | """Update the model's state dictionary with filtered parameters. 213 | 214 | Args: 215 | model: The model instance to update (must have `state_dict` and `load_state_dict` methods). 216 | state_dict: A dictionary of parameters to load into the model. 
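
    Note:
        Keys that do not exist in the current model (e.g. the ImageNet classifier head)
        are dropped, while model-only parameters such as `fc_yaw`/`fc_pitch` keep their
        freshly initialized values.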
217 | """ 218 | current_model_dict = model.state_dict() 219 | filtered_state_dict = {key: value for key, value in state_dict.items() if key in current_model_dict} 220 | current_model_dict.update(filtered_state_dict) 221 | model.load_state_dict(current_model_dict) 222 | 223 | 224 | def mobilenet_v2(*, pretrained: bool = True, progress: bool = True, **kwargs: Any) -> MobileNetV2: 225 | 226 | if pretrained: 227 | weights = MobileNet_V2_Weights.IMAGENET1K_V1 228 | else: 229 | weights = None 230 | 231 | model = MobileNetV2(**kwargs) 232 | 233 | if weights is not None: 234 | state_dict = weights.get_state_dict(progress=progress, check_hash=True) 235 | load_filtered_state_dict(model, state_dict) 236 | 237 | return model 238 | -------------------------------------------------------------------------------- /models/mobileone.py: -------------------------------------------------------------------------------- 1 | # Modified by Yakhyokhuja Valikhujaev 2 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 3 | 4 | import copy 5 | import logging 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from typing import Optional, List, Tuple 12 | 13 | __all__ = [ 14 | "mobileone_s0", 15 | "mobileone_s1", 16 | "mobileone_s2", 17 | "mobileone_s3", 18 | "mobileone_s4", 19 | "mobileone_s5", 20 | "reparameterize_model" 21 | ] 22 | 23 | 24 | logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO) 25 | logger = logging.getLogger() 26 | 27 | 28 | class SqueezeExcitationBlock(nn.Module): 29 | """ 30 | Squeeze and Excite module. 31 | 32 | Pytorch implementation of `Squeeze-and-Excitation Networks` - 33 | https://arxiv.org/pdf/1709.01507.pdf 34 | """ 35 | 36 | def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None: 37 | """ 38 | Construct a Squeeze and Excite Module. 39 | 40 | Args: 41 | in_channels (int): Number of input channels. 42 | rd_ratio (float): Input channel reduction ratio. 43 | """ 44 | 45 | super().__init__() 46 | self.reduce = nn.Conv2d( 47 | in_channels=in_channels, 48 | out_channels=int(in_channels * rd_ratio), 49 | kernel_size=1, 50 | stride=1, 51 | bias=True 52 | ) 53 | self.expand = nn.Conv2d( 54 | in_channels=int(in_channels * rd_ratio), 55 | out_channels=in_channels, 56 | kernel_size=1, 57 | stride=1, 58 | bias=True 59 | ) 60 | 61 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 62 | """ Apply forward pass. """ 63 | b, c, h, w = inputs.size() 64 | x = F.avg_pool2d(inputs, kernel_size=[h, w]) 65 | x = self.reduce(x) 66 | x = F.relu(x) 67 | x = self.expand(x) 68 | x = torch.sigmoid(x) 69 | x = x.view(-1, c, 1, 1) 70 | return inputs * x 71 | 72 | 73 | class MobileOneBlock(nn.Module): 74 | """ MobileOne building block. 75 | 76 | This block has a multi-branched architecture at train-time 77 | and plain-CNN style architecture at inference time 78 | For more details, please refer to our paper: 79 | `An Improved One millisecond Mobile Backbone` - 80 | https://arxiv.org/pdf/2206.04040.pdf 81 | """ 82 | 83 | def __init__( 84 | self, 85 | in_channels: int, 86 | out_channels: int, 87 | kernel_size: int, 88 | stride: int = 1, 89 | padding: int = 0, 90 | dilation: int = 1, 91 | groups: int = 1, 92 | inference_mode: bool = False, 93 | use_se: bool = False, 94 | num_conv_branches: int = 1 95 | ) -> None: 96 | """ 97 | Construct a MobileOneBlock module. 98 | 99 | Args: 100 | in_channels (int): Number of channels in the input. 101 | out_channels (int): Number of channels produced by the block. 
102 | kernel_size (int or tuple): Size of the convolution kernel. 103 | stride (int or tuple): Stride size. 104 | padding (int or tuple): Zero-padding size. 105 | dilation (int or tuple): Kernel dilation factor. 106 | groups (int): Group number. 107 | inference_mode (bool): If True, instantiates model in inference mode. 108 | use_se (bool): Whether to use SE-ReLU activations. 109 | num_conv_branches (int): Number of linear conv branches. 110 | """ 111 | 112 | super().__init__() 113 | self.inference_mode = inference_mode 114 | self.groups = groups 115 | self.stride = stride 116 | self.kernel_size = kernel_size 117 | self.in_channels = in_channels 118 | self.out_channels = out_channels 119 | self.num_conv_branches = num_conv_branches 120 | 121 | # Check if SE-ReLU is requested 122 | if use_se: 123 | self.se = SqueezeExcitationBlock(out_channels) 124 | else: 125 | self.se = nn.Identity() 126 | self.activation = nn.ReLU() 127 | 128 | if inference_mode: 129 | self.reparam_conv = nn.Conv2d( 130 | in_channels=in_channels, 131 | out_channels=out_channels, 132 | kernel_size=kernel_size, 133 | stride=stride, 134 | padding=padding, 135 | dilation=dilation, 136 | groups=groups, 137 | bias=True 138 | ) 139 | else: 140 | # Re-parameterizable skip connection 141 | self.rbr_skip = nn.BatchNorm2d(num_features=in_channels) \ 142 | if out_channels == in_channels and stride == 1 else None 143 | 144 | # Re-parameterizable conv branches 145 | rbr_conv = list() 146 | for _ in range(self.num_conv_branches): 147 | rbr_conv.append(self._conv_bn(kernel_size=kernel_size, padding=padding)) 148 | self.rbr_conv = nn.ModuleList(rbr_conv) 149 | 150 | # Re-parameterizable scale branch 151 | self.rbr_scale = None 152 | if kernel_size > 1: 153 | self.rbr_scale = self._conv_bn(kernel_size=1, padding=0) 154 | 155 | def forward(self, x: torch.Tensor) -> torch.Tensor: 156 | """ Apply forward pass. """ 157 | # Inference mode forward pass. 158 | if self.inference_mode: 159 | return self.activation(self.se(self.reparam_conv(x))) 160 | 161 | # Multi-branched train-time forward pass. 162 | # Skip branch output 163 | identity_out = 0 164 | if self.rbr_skip is not None: 165 | identity_out = self.rbr_skip(x) 166 | 167 | # Scale branch output 168 | scale_out = 0 169 | if self.rbr_scale is not None: 170 | scale_out = self.rbr_scale(x) 171 | 172 | # Other branches 173 | out = scale_out + identity_out 174 | for ix in range(self.num_conv_branches): 175 | out += self.rbr_conv[ix](x) 176 | 177 | return self.activation(self.se(out)) 178 | 179 | def reparameterize(self): 180 | """ 181 | Following works like `RepVGG: Making VGG-style ConvNets Great Again` - https://arxiv.org/pdf/2101.03697.pdf. 182 | We re-parameterize multi-branched architecture used at training time to obtain a plain CNN-like structure 183 | for inference. 
184 | """ 185 | if self.inference_mode: 186 | return 187 | kernel, bias = self._get_kernel_bias() 188 | self.reparam_conv = nn.Conv2d( 189 | in_channels=self.rbr_conv[0].conv.in_channels, 190 | out_channels=self.rbr_conv[0].conv.out_channels, 191 | kernel_size=self.rbr_conv[0].conv.kernel_size, 192 | stride=self.rbr_conv[0].conv.stride, 193 | padding=self.rbr_conv[0].conv.padding, 194 | dilation=self.rbr_conv[0].conv.dilation, 195 | groups=self.rbr_conv[0].conv.groups, 196 | bias=True 197 | ) 198 | self.reparam_conv.weight.data = kernel 199 | self.reparam_conv.bias.data = bias 200 | 201 | # Delete un-used branches 202 | for para in self.parameters(): 203 | para.detach_() 204 | self.__delattr__('rbr_conv') 205 | self.__delattr__('rbr_scale') 206 | if hasattr(self, 'rbr_skip'): 207 | self.__delattr__('rbr_skip') 208 | 209 | self.inference_mode = True 210 | 211 | def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]: 212 | """ 213 | Method to fuse batchnorm layer with preceding conv layer. 214 | Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95 215 | 216 | Args: 217 | branch: The branch containing the convolutional and batchnorm layers to be fused. 218 | 219 | Returns: 220 | tuple: A tuple containing the kernel and bias after fusing batchnorm. 221 | """ 222 | # get weights and bias of scale branch 223 | kernel_scale = 0 224 | bias_scale = 0 225 | if self.rbr_scale is not None: 226 | kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale) 227 | # Pad scale branch kernel to match conv branch kernel size. 228 | pad = self.kernel_size // 2 229 | kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad]) 230 | 231 | # get weights and bias of skip branch 232 | kernel_identity = 0 233 | bias_identity = 0 234 | if self.rbr_skip is not None: 235 | kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip) 236 | 237 | # get weights and bias of conv branches 238 | kernel_conv = 0 239 | bias_conv = 0 240 | for ix in range(self.num_conv_branches): 241 | _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix]) 242 | kernel_conv += _kernel 243 | bias_conv += _bias 244 | 245 | kernel_final = kernel_conv + kernel_scale + kernel_identity 246 | bias_final = bias_conv + bias_scale + bias_identity 247 | return kernel_final, bias_final 248 | 249 | def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]: 250 | """ 251 | Method to fuse batchnorm layer with preceding conv layer. 252 | Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95 253 | 254 | Args: 255 | branch: The branch containing the convolutional and batchnorm layers to be fused. 256 | 257 | Returns: 258 | tuple: A tuple containing the kernel and bias after fusing batchnorm. 
259 | """ 260 | if isinstance(branch, nn.Sequential): 261 | kernel = branch.conv.weight 262 | running_mean = branch.bn.running_mean 263 | running_var = branch.bn.running_var 264 | gamma = branch.bn.weight 265 | beta = branch.bn.bias 266 | eps = branch.bn.eps 267 | else: 268 | assert isinstance(branch, nn.BatchNorm2d) 269 | if not hasattr(self, 'id_tensor'): 270 | input_dim = self.in_channels // self.groups 271 | kernel_value = torch.zeros( 272 | (self.in_channels, input_dim, self.kernel_size, self.kernel_size), 273 | dtype=branch.weight.dtype, 274 | device=branch.weight.device 275 | ) 276 | for i in range(self.in_channels): 277 | kernel_value[i, i % input_dim, 278 | self.kernel_size // 2, 279 | self.kernel_size // 2] = 1 280 | self.id_tensor = kernel_value 281 | kernel = self.id_tensor 282 | running_mean = branch.running_mean 283 | running_var = branch.running_var 284 | gamma = branch.weight 285 | beta = branch.bias 286 | eps = branch.eps 287 | std = (running_var + eps).sqrt() 288 | t = (gamma / std).reshape(-1, 1, 1, 1) 289 | return kernel * t, beta - running_mean * gamma / std 290 | 291 | def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential: 292 | """ 293 | Helper method to construct conv-batchnorm layers. 294 | 295 | Args: 296 | kernel_size (int): Size of the convolution kernel. 297 | padding (int): Zero-padding size. 298 | 299 | Returns: 300 | nn.Sequential: A Conv-BN module. 301 | """ 302 | mod_list = nn.Sequential() 303 | mod_list.add_module( 304 | 'conv', 305 | nn.Conv2d( 306 | in_channels=self.in_channels, 307 | out_channels=self.out_channels, 308 | kernel_size=kernel_size, 309 | stride=self.stride, 310 | padding=padding, 311 | groups=self.groups, 312 | bias=False 313 | ) 314 | ) 315 | mod_list.add_module('bn', nn.BatchNorm2d(num_features=self.out_channels)) 316 | return mod_list 317 | 318 | 319 | class MobileOne(nn.Module): 320 | """ 321 | MobileOne Model 322 | 323 | Pytorch implementation of `An Improved One millisecond Mobile Backbone` - 324 | https://arxiv.org/pdf/2206.04040.pdf 325 | """ 326 | 327 | def __init__( 328 | self, 329 | num_blocks_per_stage: List[int] = [2, 8, 10, 1], 330 | num_classes: int = 1000, 331 | width_multipliers: Optional[List[float]] = None, 332 | inference_mode: bool = False, 333 | use_se: bool = False, 334 | num_conv_branches: int = 1 335 | ) -> None: 336 | """ 337 | Construct MobileOne model. 338 | 339 | Args: 340 | num_blocks_per_stage (list): List of number of blocks per stage. 341 | num_classes (int): Number of classes in the dataset. 342 | width_multipliers (list): List of width multipliers for blocks in a stage. 343 | inference_mode (bool): If True, instantiates model in inference mode. 344 | use_se (bool): Whether to use SE-ReLU activations. 345 | num_conv_branches (int): Number of linear conv branches. 
346 | """ 347 | super().__init__() 348 | 349 | assert len(width_multipliers) == 4 350 | self.inference_mode = inference_mode 351 | self.in_planes = min(64, int(64 * width_multipliers[0])) 352 | self.use_se = use_se 353 | self.num_conv_branches = num_conv_branches 354 | 355 | # Build stages 356 | self.stage0 = MobileOneBlock( 357 | in_channels=3, 358 | out_channels=self.in_planes, 359 | kernel_size=3, 360 | stride=2, 361 | padding=1, 362 | inference_mode=self.inference_mode 363 | ) 364 | self.cur_layer_idx = 1 365 | self.stage1 = self._make_stage(int(64 * width_multipliers[0]), num_blocks_per_stage[0], num_se_blocks=0) 366 | self.stage2 = self._make_stage(int(128 * width_multipliers[1]), num_blocks_per_stage[1], num_se_blocks=0) 367 | self.stage3 = self._make_stage( 368 | int(256 * width_multipliers[2]), 369 | num_blocks_per_stage[2], 370 | num_se_blocks=int(num_blocks_per_stage[2] // 2) if use_se else 0 371 | ) 372 | self.stage4 = self._make_stage( 373 | int(512 * width_multipliers[3]), 374 | num_blocks_per_stage[3], 375 | num_se_blocks=num_blocks_per_stage[3] if use_se else 0 376 | ) 377 | self.gap = nn.AdaptiveAvgPool2d(output_size=1) 378 | 379 | # yaw and pitch 380 | self.fc_yaw = nn.Linear(int(512 * width_multipliers[3]), num_classes) 381 | self.fc_pitch = nn.Linear(int(512 * width_multipliers[3]), num_classes) 382 | 383 | # self.linear = nn.Linear(int(512 * width_multipliers[3]), num_classes) 384 | 385 | def _make_stage(self, planes: int, num_blocks: int, num_se_blocks: int) -> nn.Sequential: 386 | """ 387 | Build a stage of the MobileOne model. 388 | 389 | Args: 390 | planes (int): Number of output channels. 391 | num_blocks (int): Number of blocks in this stage. 392 | num_se_blocks (int): Number of SE blocks in this stage. 393 | 394 | Returns: 395 | nn.Sequential: A stage of the MobileOne model. 396 | """ 397 | 398 | # Get strides for all layers 399 | strides = [2] + [1]*(num_blocks-1) 400 | blocks = [] 401 | for ix, stride in enumerate(strides): 402 | use_se = False 403 | if num_se_blocks > num_blocks: 404 | raise ValueError("Number of SE blocks cannot exceed number of layers.") 405 | if ix >= (num_blocks - num_se_blocks): 406 | use_se = True 407 | 408 | # Depthwise conv 409 | blocks.append( 410 | MobileOneBlock( 411 | in_channels=self.in_planes, 412 | out_channels=self.in_planes, 413 | kernel_size=3, 414 | stride=stride, 415 | padding=1, 416 | groups=self.in_planes, 417 | inference_mode=self.inference_mode, 418 | use_se=use_se, 419 | num_conv_branches=self.num_conv_branches 420 | ) 421 | ) 422 | # Pointwise conv 423 | blocks.append( 424 | MobileOneBlock( 425 | in_channels=self.in_planes, 426 | out_channels=planes, 427 | kernel_size=1, 428 | stride=1, 429 | padding=0, 430 | groups=1, 431 | inference_mode=self.inference_mode, 432 | use_se=use_se, 433 | num_conv_branches=self.num_conv_branches 434 | ) 435 | ) 436 | self.in_planes = planes 437 | self.cur_layer_idx += 1 438 | return nn.Sequential(*blocks) 439 | 440 | def forward(self, x: torch.Tensor) -> torch.Tensor: 441 | """ Apply forward pass . 
""" 442 | x = self.stage0(x) 443 | x = self.stage1(x) 444 | x = self.stage2(x) 445 | x = self.stage3(x) 446 | x = self.stage4(x) 447 | x = self.gap(x) 448 | x = x.view(x.size(0), -1) 449 | # x = self.linear(x) 450 | 451 | yaw = self.fc_yaw(x) 452 | pitch = self.fc_pitch(x) 453 | 454 | return pitch, yaw 455 | 456 | 457 | def reparameterize_model(model: torch.nn.Module) -> nn.Module: 458 | """ 459 | Re-parameterize the MobileOne model from a multi-branched structure (used in training) 460 | into a single branch for inference. 461 | 462 | Args: 463 | model (nn.Module): MobileOne model in training mode. 464 | 465 | Returns: 466 | nn.Module: MobileOne model re-parameterized for inference mode. 467 | """ 468 | 469 | # Avoid editing original graph 470 | model = copy.deepcopy(model) 471 | for module in model.modules(): 472 | if hasattr(module, 'reparameterize'): 473 | module.reparameterize() 474 | return model 475 | 476 | 477 | MOBILEONE_CONFIGS = { 478 | "mobileone_s0": 479 | { 480 | "weights": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s0_unfused.pth.tar", 481 | "params": 482 | { 483 | "width_multipliers": (0.75, 1.0, 1.0, 2.0), 484 | "num_conv_branches": 4 485 | } 486 | }, 487 | "mobileone_s1": 488 | { 489 | "weights": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s1_unfused.pth.tar", 490 | "params": 491 | { 492 | "width_multipliers": (1.5, 1.5, 2.0, 2.5), 493 | } 494 | 495 | }, 496 | "mobileone_s2": 497 | { 498 | "weights": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s2_unfused.pth.tar", 499 | "params": 500 | { 501 | "width_multipliers": (1.5, 2.0, 2.5, 4.0), 502 | } 503 | }, 504 | "mobileone_s3": 505 | { 506 | "weights": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s3_unfused.pth.tar", 507 | "params": 508 | { 509 | "width_multipliers": (2.0, 2.5, 3.0, 4.0), 510 | } 511 | }, 512 | "mobileone_s4": 513 | { 514 | "weights": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s4_unfused.pth.tar", 515 | "params": 516 | { 517 | "width_multipliers": (3.0, 3.5, 3.5, 4.0), 518 | "use_se": True 519 | } 520 | } 521 | } 522 | 523 | 524 | def load_filtered_state_dict(model, state_dict): 525 | """Update the model's state dictionary with filtered parameters. 526 | 527 | Args: 528 | model: The model instance to update (must have `state_dict` and `load_state_dict` methods). 529 | state_dict: A dictionary of parameters to load into the model. 530 | """ 531 | current_model_dict = model.state_dict() 532 | filtered_state_dict = {key: value for key, value in state_dict.items() if key in current_model_dict} 533 | current_model_dict.update(filtered_state_dict) 534 | model.load_state_dict(current_model_dict) 535 | 536 | 537 | def create_mobileone_model(config, pretrained: bool = True, num_classes: int = 1000, inference_mode: bool = False) -> nn.Module: 538 | """ 539 | Create a MobileOne model based on the specified architecture name. 540 | 541 | Args: 542 | config (dict): The configuration dictionary for the MobileOne model. 543 | pretrained (bool): If True, loads pre-trained weights for the specified architecture. Defaults to True. 544 | num_classes (int): Number of output classes for the model. Defaults to 1000. 545 | inference_mode (bool): If True, instantiates the model in inference mode. Defaults to False. 546 | 547 | Returns: 548 | nn.Module: The constructed MobileOne model. 
549 | """ 550 | weights = config["weights"] 551 | params = config["params"] 552 | 553 | model = MobileOne(num_classes=num_classes, inference_mode=inference_mode, **params) 554 | 555 | if pretrained: 556 | try: 557 | state_dict = torch.hub.load_state_dict_from_url(weights) 558 | load_filtered_state_dict(model, state_dict) 559 | logger.info("Pre-trained weights successfully loaded.") 560 | except Exception as e: 561 | logger.warning(f"Could not load pre-trained weights. Exception: {e}") 562 | logger.info("Creating model without pre-trained weights.") 563 | else: 564 | logger.info("Creating model without pre-trained weights.") 565 | 566 | return model 567 | 568 | 569 | def mobileone_s0(pretrained=True, num_classes=1000, inference_mode=False): 570 | return create_mobileone_model(MOBILEONE_CONFIGS['mobileone_s0'], pretrained, num_classes, inference_mode) 571 | 572 | 573 | def mobileone_s1(pretrained=True, num_classes=1000, inference_mode=False): 574 | return create_mobileone_model(MOBILEONE_CONFIGS['mobileone_s1'], pretrained, num_classes, inference_mode) 575 | 576 | 577 | def mobileone_s2(pretrained=True, num_classes=1000, inference_mode=False): 578 | return create_mobileone_model(MOBILEONE_CONFIGS['mobileone_s2'], pretrained, num_classes, inference_mode) 579 | 580 | 581 | def mobileone_s3(pretrained=True, num_classes=1000, inference_mode=False): 582 | return create_mobileone_model(MOBILEONE_CONFIGS['mobileone_s3'], pretrained, num_classes, inference_mode) 583 | 584 | 585 | def mobileone_s4(pretrained=True, num_classes=1000, inference_mode=False): 586 | return create_mobileone_model(MOBILEONE_CONFIGS['mobileone_s4'], pretrained, num_classes, inference_mode) 587 | 588 | 589 | if __name__ == "__main__": 590 | model = mobileone_s2() 591 | print(sum(p.numel() for p in model.parameters() if p.requires_grad)) 592 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torchvision.models import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights 4 | 5 | from typing import Any, Callable, List, Optional, Type, Tuple 6 | 7 | 8 | __all__ = ["resnet18", "resnet34", "resnet50"] 9 | 10 | 11 | def conv3x3(in_channels: int, out_channels: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 12 | """3x3 convolution with padding""" 13 | return nn.Conv2d( 14 | in_channels, 15 | out_channels, 16 | kernel_size=3, 17 | stride=stride, 18 | padding=dilation, 19 | groups=groups, 20 | bias=False, 21 | dilation=dilation, 22 | ) 23 | 24 | 25 | def conv1x1(in_channels: int, out_channels: int, stride: int = 1) -> nn.Conv2d: 26 | """1x1 convolution""" 27 | return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False) 28 | 29 | 30 | class BasicBlock(nn.Module): 31 | expansion: int = 1 32 | 33 | def __init__( 34 | self, 35 | in_channels: int, 36 | out_channels: int, 37 | stride: int = 1, 38 | downsample: Optional[nn.Module] = None, 39 | groups: int = 1, 40 | base_width: int = 64, 41 | dilation: int = 1, 42 | norm_layer: Optional[Callable[..., nn.Module]] = None, 43 | ) -> None: 44 | super().__init__() 45 | if norm_layer is None: 46 | norm_layer = nn.BatchNorm2d 47 | if groups != 1 or base_width != 64: 48 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 49 | if dilation > 1: 50 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 51 | # Both self.conv1 and 
self.downsample layers downsample the input when stride != 1 52 | self.conv1 = conv3x3(in_channels, out_channels, stride) 53 | self.bn1 = norm_layer(out_channels) 54 | self.relu = nn.ReLU(inplace=True) 55 | self.conv2 = conv3x3(out_channels, out_channels) 56 | self.bn2 = norm_layer(out_channels) 57 | self.downsample = downsample 58 | self.stride = stride 59 | 60 | def forward(self, x: Tensor) -> Tensor: 61 | identity = x 62 | 63 | out = self.conv1(x) 64 | out = self.bn1(out) 65 | out = self.relu(out) 66 | 67 | out = self.conv2(out) 68 | out = self.bn2(out) 69 | 70 | if self.downsample is not None: 71 | identity = self.downsample(x) 72 | 73 | out += identity 74 | out = self.relu(out) 75 | 76 | return out 77 | 78 | 79 | class Bottleneck(nn.Module): 80 | expansion: int = 4 81 | 82 | def __init__( 83 | self, 84 | inplanes: int, 85 | planes: int, 86 | stride: int = 1, 87 | downsample: Optional[nn.Module] = None, 88 | groups: int = 1, 89 | base_width: int = 64, 90 | dilation: int = 1, 91 | norm_layer: Optional[Callable[..., nn.Module]] = None, 92 | ) -> None: 93 | super().__init__() 94 | if norm_layer is None: 95 | norm_layer = nn.BatchNorm2d 96 | width = int(planes * (base_width / 64.0)) * groups 97 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 98 | self.conv1 = conv1x1(inplanes, width) 99 | self.bn1 = norm_layer(width) 100 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 101 | self.bn2 = norm_layer(width) 102 | self.conv3 = conv1x1(width, planes * self.expansion) 103 | self.bn3 = norm_layer(planes * self.expansion) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.downsample = downsample 106 | self.stride = stride 107 | 108 | def forward(self, x: Tensor) -> Tensor: 109 | identity = x 110 | 111 | out = self.conv1(x) 112 | out = self.bn1(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv2(out) 116 | out = self.bn2(out) 117 | out = self.relu(out) 118 | 119 | out = self.conv3(out) 120 | out = self.bn3(out) 121 | 122 | if self.downsample is not None: 123 | identity = self.downsample(x) 124 | 125 | out += identity 126 | out = self.relu(out) 127 | 128 | return out 129 | 130 | 131 | class ResNet(nn.Module): 132 | def __init__( 133 | self, 134 | block: Type[BasicBlock | Bottleneck], 135 | layers: List[int], 136 | num_classes: int = 1000, 137 | groups: int = 1, 138 | width_per_group: int = 64, 139 | replace_stride_with_dilation: Optional[List[bool]] = None, 140 | norm_layer: Optional[Callable[..., nn.Module]] = None, 141 | ) -> None: 142 | super().__init__() 143 | if norm_layer is None: 144 | norm_layer = nn.BatchNorm2d 145 | self._norm_layer = norm_layer 146 | 147 | self.in_channels = 64 148 | self.dilation = 1 149 | if replace_stride_with_dilation is None: 150 | # each element in the tuple indicates if we should replace 151 | # the 2x2 stride with a dilated convolution instead 152 | replace_stride_with_dilation = [False, False, False] 153 | if len(replace_stride_with_dilation) != 3: 154 | raise ValueError( 155 | "replace_stride_with_dilation should be None " 156 | f"or a 3-element tuple, got {replace_stride_with_dilation}" 157 | ) 158 | self.groups = groups 159 | self.base_width = width_per_group 160 | self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False) 161 | self.bn1 = norm_layer(self.in_channels) 162 | self.relu = nn.ReLU(inplace=True) 163 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 164 | self.layer1 = self._make_layer(block, 64, layers[0]) 165 | self.layer2 = 
self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
166 |         self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
167 |         self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
168 |         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
169 | 
170 |         # yaw and pitch
171 |         self.fc_yaw = nn.Linear(512 * block.expansion, num_classes)
172 |         self.fc_pitch = nn.Linear(512 * block.expansion, num_classes)
173 | 
174 |         # Original FC Layer for ResNet
175 |         # self.fc = nn.Linear(512 * block.expansion, num_classes)
176 | 
177 |         for m in self.modules():
178 |             if isinstance(m, nn.Conv2d):
179 |                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
180 |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
181 |                 nn.init.constant_(m.weight, 1)
182 |                 nn.init.constant_(m.bias, 0)
183 | 
184 |     def _make_layer(
185 |         self,
186 |         block: Type[BasicBlock | Bottleneck],
187 |         planes: int,
188 |         blocks: int,
189 |         stride: int = 1,
190 |         dilate: bool = False,
191 |     ) -> nn.Sequential:
192 |         norm_layer = self._norm_layer
193 |         downsample = None
194 |         previous_dilation = self.dilation
195 |         if dilate:
196 |             self.dilation *= stride
197 |             stride = 1
198 |         if stride != 1 or self.in_channels != planes * block.expansion:
199 |             downsample = nn.Sequential(
200 |                 conv1x1(self.in_channels, planes * block.expansion, stride),
201 |                 norm_layer(planes * block.expansion),
202 |             )
203 | 
204 |         layers = []
205 |         layers.append(
206 |             block(
207 |                 self.in_channels,
208 |                 planes,
209 |                 stride,
210 |                 downsample,
211 |                 self.groups,
212 |                 self.base_width,
213 |                 previous_dilation,
214 |                 norm_layer
215 |             )
216 |         )
217 |         self.in_channels = planes * block.expansion
218 |         for _ in range(1, blocks):
219 |             layers.append(
220 |                 block(
221 |                     self.in_channels,
222 |                     planes,
223 |                     groups=self.groups,
224 |                     base_width=self.base_width,
225 |                     dilation=self.dilation,
226 |                     norm_layer=norm_layer,
227 |                 )
228 |             )
229 | 
230 |         return nn.Sequential(*layers)
231 | 
232 |     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
233 |         x = self.conv1(x)
234 |         x = self.bn1(x)
235 |         x = self.relu(x)
236 |         x = self.maxpool(x)
237 | 
238 |         x = self.layer1(x)  # 1/4
239 |         x = self.layer2(x)  # 1/8
240 |         x = self.layer3(x)  # 1/16
241 |         x = self.layer4(x)  # 1/32
242 | 
243 |         x = self.avgpool(x)
244 |         x = torch.flatten(x, 1)
245 | 
246 |         # Original FC Layer for ResNet
247 |         # x = self.fc(x)
248 |         yaw = self.fc_yaw(x)
249 |         pitch = self.fc_pitch(x)
250 | 
251 |         return pitch, yaw
252 | 
253 | 
254 | def load_filtered_state_dict(model, state_dict):
255 |     """Update the model's state dictionary with filtered parameters.
256 | 
257 |     Args:
258 |         model: The model instance to update (must have `state_dict` and `load_state_dict` methods).
259 |         state_dict: A dictionary of parameters to load into the model.
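    Note:
        Keys present in `state_dict` but missing from the model are silently dropped,
        and model keys absent from `state_dict` keep their current values. This is
        what lets ImageNet checkpoints (with a single `fc` head) load into this
        two-headed gaze variant.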
260 | """ 261 | current_model_dict = model.state_dict() 262 | filtered_state_dict = {key: value for key, value in state_dict.items() if key in current_model_dict} 263 | current_model_dict.update(filtered_state_dict) 264 | model.load_state_dict(current_model_dict) 265 | 266 | 267 | def _resnet(block: Type[BasicBlock], layers: List[int], weights: Optional[ResNet34_Weights], progress: bool, **kwargs: Any) -> ResNet: 268 | model = ResNet(block, layers, **kwargs) 269 | 270 | if weights is not None: 271 | state_dict = weights.get_state_dict(progress=progress, check_hash=True) 272 | load_filtered_state_dict(model, state_dict) 273 | 274 | return model 275 | 276 | 277 | def resnet18(*, pretrained: bool = True, progress: bool = True, **kwargs: Any) -> ResNet: 278 | if pretrained: 279 | weights = ResNet18_Weights.DEFAULT 280 | else: 281 | weights = None 282 | return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) 283 | 284 | 285 | def resnet34(*, pretrained: bool = True, progress: bool = True, **kwargs: Any) -> ResNet: 286 | if pretrained: 287 | weights = ResNet34_Weights.DEFAULT 288 | else: 289 | weights = None 290 | return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) 291 | 292 | 293 | def resnet50(*, pretrained: bool = True, progress: bool = True, **kwargs: Any) -> ResNet: 294 | if pretrained: 295 | weights = ResNet50_Weights.DEFAULT 296 | else: 297 | weights = None 298 | 299 | return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) 300 | -------------------------------------------------------------------------------- /mpii_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import logging 5 | import argparse 6 | import numpy as np 7 | from tqdm import tqdm 8 | from sklearn.model_selection import KFold 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.utils.data import DataLoader 14 | 15 | from config import data_config 16 | from utils.helpers import angular_error, gaze_to_3d, get_model, get_dataloader 17 | 18 | # Setup logging 19 | logging.basicConfig( 20 | level=logging.INFO, 21 | format='%(asctime)s - %(message)s', 22 | # handlers=[ 23 | # logging.FileHandler("training.log"), 24 | # logging.StreamHandler(sys.stdout) # Display logs in terminal 25 | # ] 26 | ) 27 | 28 | 29 | def parse_args(): 30 | """Parse input arguments.""" 31 | parser = argparse.ArgumentParser(description="Gaze estimation training") 32 | parser.add_argument("--data", type=str, default="data", help="Directory path for gaze images.") 33 | parser.add_argument("--dataset", type=str, default="gaze360", help="Dataset name, available `gaze360`, `mpiigaze`.") 34 | parser.add_argument("--output", type=str, default="output/", help="Path of output models.") 35 | parser.add_argument("--checkpoint", type=str, default="", help="Path to checkpoint for resuming training.") 36 | parser.add_argument("--num-epochs", type=int, default=100, help="Maximum number of training epochs.") 37 | parser.add_argument("--batch-size", type=int, default=64, help="Batch size.") 38 | parser.add_argument( 39 | "--arch", 40 | type=str, 41 | default="resnet18", 42 | help="Network architecture, currently available: resnet18/34/50, mobilenetv2, mobileone_s0-s4." 
43 | ) 44 | parser.add_argument("--alpha", type=float, default=1, help="Regression loss coefficient.") 45 | parser.add_argument("--lr", type=float, default=0.00001, help="Base learning rate.") 46 | parser.add_argument("--num-workers", type=int, default=8, help="Number of workers for data loading.") 47 | 48 | args = parser.parse_args() 49 | 50 | # Override default values based on selected dataset 51 | if args.dataset in data_config: 52 | dataset_config = data_config[args.dataset] 53 | args.bins = dataset_config["bins"] 54 | args.binwidth = dataset_config["binwidth"] 55 | args.angle = dataset_config["angle"] 56 | else: 57 | raise ValueError(f"Unknown dataset: {args.dataset}. Available options: {list(data_config.keys())}") 58 | 59 | return args 60 | 61 | 62 | def initialize_model(params, device): 63 | """ 64 | Initialize the gaze estimation model, optimizer, and optionally load a checkpoint. 65 | 66 | Args: 67 | params (argparse.Namespace): Parsed command-line arguments. 68 | device (torch.device): Device to load the model and optimizer onto. 69 | 70 | Returns: 71 | Tuple[nn.Module, torch.optim.Optimizer, int]: Initialized model, optimizer, and the starting epoch. 72 | """ 73 | model = get_model(params.arch, params.bins) 74 | optimizer = torch.optim.Adam(model.parameters(), lr=params.lr) 75 | start_epoch = 0 76 | 77 | if params.checkpoint: 78 | checkpoint = torch.load(params.checkpoint, map_location=device) 79 | model.load_state_dict(checkpoint['model_state_dict']) 80 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 81 | 82 | # Move optimizer states to device 83 | for state in optimizer.state.values(): 84 | for k, v in state.items(): 85 | if isinstance(v, torch.Tensor): 86 | state[k] = v.to(device) 87 | 88 | start_epoch = checkpoint['epoch'] 89 | logging.info(f'Resumed training from {params.checkpoint}, starting at epoch {start_epoch + 1}') 90 | 91 | return model.to(device), optimizer, start_epoch 92 | 93 | 94 | def train_one_epoch( 95 | params, 96 | model, 97 | cls_criterion, 98 | reg_criterion, 99 | optimizer, 100 | data_loader, 101 | idx_tensor, 102 | device, 103 | epoch 104 | ): 105 | """ 106 | Train the model for one epoch. 107 | 108 | Args: 109 | params (argparse.Namespace): Parsed command-line arguments. 110 | model (nn.Module): The gaze estimation model. 111 | cls_criterion (nn.Module): Loss function for classification. 112 | reg_criterion (nn.Module): Loss function for regression. 113 | optimizer (torch.optim.Optimizer): Optimizer for the model. 114 | data_loader (DataLoader): DataLoader for the training dataset. 115 | idx_tensor (torch.Tensor): Tensor representing bin indices. 116 | device (torch.device): Device to perform training on. 117 | epoch (int): The current epoch number. 118 | 119 | Returns: 120 | Tuple[float, float]: Average losses for pitch and yaw. 
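    Note:
        Each head optimizes a combined objective: cross-entropy on the binned labels
        plus `alpha` times MSE on the continuous angle recovered as the softmax
        expectation over bins (decoded below via
        sum(softmax(logits) * idx_tensor) * binwidth - angle).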
121 | """ 122 | 123 | model.train() 124 | sum_loss_pitch, sum_loss_yaw = 0, 0 125 | 126 | for idx, (images, labels_gaze, regression_labels_gaze, _) in enumerate(data_loader): 127 | images = images.to(device) 128 | 129 | # Binned labels 130 | label_pitch = labels_gaze[:, 0].to(device) 131 | label_yaw = labels_gaze[:, 1].to(device) 132 | 133 | # Regression labels 134 | label_pitch_regression = regression_labels_gaze[:, 0].to(device) 135 | label_yaw_regression = regression_labels_gaze[:, 1].to(device) 136 | 137 | # Inference 138 | pitch, yaw = model(images) 139 | 140 | # Cross Entropy Loss 141 | loss_pitch = cls_criterion(pitch, label_pitch) 142 | loss_yaw = cls_criterion(yaw, label_yaw) 143 | 144 | # Mapping from binned (0 to 90) to angels (-180 to 180) 145 | pitch_predicted = torch.sum(F.softmax(pitch, dim=1) * idx_tensor, 1) * params.binwidth - params.angle 146 | yaw_predicted = torch.sum(F.softmax(yaw, dim=1) * idx_tensor, 1) * params.binwidth - params.angle 147 | 148 | # Mean Squared Error Loss 149 | loss_regression_pitch = reg_criterion(pitch_predicted, label_pitch_regression) 150 | loss_regression_yaw = reg_criterion(yaw_predicted, label_yaw_regression) 151 | 152 | # Calculate loss with regression alpha 153 | loss_pitch += params.alpha * loss_regression_pitch 154 | loss_yaw += params.alpha * loss_regression_yaw 155 | 156 | # Total loss for pitch and yaw 157 | loss = loss_pitch + loss_yaw 158 | 159 | optimizer.zero_grad() 160 | loss.backward() 161 | optimizer.step() 162 | 163 | sum_loss_pitch += loss_pitch.item() 164 | sum_loss_yaw += loss_yaw.item() 165 | 166 | if (idx + 1) % 100 == 0: 167 | logging.info( 168 | f'Epoch [{epoch + 1}/{params.num_epochs}], Iter [{idx + 1}/{len(data_loader)}] ' 169 | f'Losses: Gaze Yaw {sum_loss_yaw / (idx + 1):.4f}, Gaze Pitch {sum_loss_pitch / (idx + 1):.4f}' 170 | ) 171 | avg_loss_pitch, avg_loss_yaw = sum_loss_pitch / len(data_loader), sum_loss_yaw / len(data_loader) 172 | 173 | return avg_loss_pitch, avg_loss_yaw 174 | 175 | 176 | @torch.no_grad() 177 | def evaluate(params, model, data_loader, idx_tensor, device): 178 | """ 179 | Evaluate the model on the test dataset. 180 | 181 | Args: 182 | params (argparse.Namespace): Parsed command-line arguments. 183 | model (nn.Module): The gaze estimation model. 184 | data_loader (torch.utils.data.DataLoader): DataLoader for the test dataset. 185 | idx_tensor (torch.Tensor): Tensor representing bin indices. 186 | device (torch.device): Device to perform evaluation on. 
187 | """ 188 | model.eval() 189 | average_error = 0 190 | total_samples = 0 191 | 192 | for images, labels_gaze, regression_labels_gaze, _ in tqdm(data_loader, total=len(data_loader)): 193 | total_samples += regression_labels_gaze.size(0) 194 | images = images.to(device) 195 | 196 | # Regression labels 197 | label_pitch = np.radians(regression_labels_gaze[:, 0], dtype=np.float32) 198 | label_yaw = np.radians(regression_labels_gaze[:, 1], dtype=np.float32) 199 | 200 | # Inference 201 | pitch, yaw = model(images) 202 | 203 | # Regression predictions 204 | pitch_predicted = F.softmax(pitch, dim=1) 205 | yaw_predicted = F.softmax(yaw, dim=1) 206 | 207 | # Mapping from binned (0 to 90) to angles (-180 to 180) or (0 to 28) to angles (-42, 42) 208 | pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) * params.binwidth - params.angle 209 | yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) * params.binwidth - params.angle 210 | 211 | pitch_predicted = np.radians(pitch_predicted.cpu()) 212 | yaw_predicted = np.radians(yaw_predicted.cpu()) 213 | 214 | for p, y, pl, yl in zip(pitch_predicted, yaw_predicted, label_pitch, label_yaw): 215 | average_error += angular_error(gaze_to_3d([p, y]), gaze_to_3d([pl, yl])) 216 | 217 | logging.info( 218 | f"Dataset: {params.dataset} | " 219 | f"Total Number of Samples: {total_samples} | " 220 | f"Mean Angular Error: {average_error/total_samples}" 221 | ) 222 | return average_error/total_samples 223 | 224 | 225 | def main(): 226 | params = parse_args() 227 | 228 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 229 | summary_name = f'{params.dataset}_{params.arch}_{int(time.time())}' 230 | output = os.path.join(params.output, summary_name) 231 | if not os.path.exists(output): 232 | os.makedirs(output) 233 | torch.backends.cudnn.benchmark = True 234 | 235 | model, optimizer, start_epoch = initialize_model(params, device) 236 | data_loader = get_dataloader(params, mode="train") 237 | dataset = data_loader.dataset 238 | 239 | cls_criterion = nn.CrossEntropyLoss() 240 | reg_criterion = nn.MSELoss() 241 | idx_tensor = torch.arange(params.bins, device=device, dtype=torch.float32) 242 | 243 | best_avg_error = float('inf') 244 | k = 5 # number of folds 245 | kfold = KFold(n_splits=k, shuffle=True, random_state=42) 246 | 247 | fold_errors = [] 248 | # K-Fold Cross Validation 249 | for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)): 250 | print(f"Fold {fold+1}/{k}") 251 | 252 | # Split data into training and validation sets for this fold 253 | train_subset = torch.utils.data.Subset(dataset, train_idx) 254 | val_subset = torch.utils.data.Subset(dataset, val_idx) 255 | 256 | # Create data loaders for the subsets 257 | train_loader = torch.utils.data.DataLoader(train_subset, batch_size=params.batch_size, shuffle=True) 258 | val_loader = torch.utils.data.DataLoader(val_subset, batch_size=params.batch_size, shuffle=False) 259 | 260 | # Reset model and optimizer for each fold 261 | model, optimizer, start_epoch = initialize_model(params, device) 262 | 263 | for epoch in range(start_epoch, params.num_epochs): 264 | avg_loss_pitch, avg_loss_yaw = train_one_epoch( 265 | params, 266 | model, 267 | cls_criterion, 268 | reg_criterion, 269 | optimizer, 270 | train_loader, 271 | idx_tensor, 272 | device, 273 | epoch 274 | ) 275 | 276 | logging.info( 277 | f'Epoch [{epoch + 1}/{params.num_epochs}] ' 278 | f'Losses: Gaze Yaw {avg_loss_yaw:.4f}, Gaze Pitch {avg_loss_pitch:.4f}' 279 | ) 280 | 281 | # checkpoint_path = os.path.join(output, 
f"checkpoint_fold_{fold+1}.ckpt") 282 | # torch.save({ 283 | # 'epoch': epoch + 1, 284 | # 'model_state_dict': model.state_dict(), 285 | # 'optimizer_state_dict': optimizer.state_dict(), 286 | # 'loss': avg_loss_pitch + avg_loss_yaw, 287 | # }, checkpoint_path) 288 | # logging.info(f'Checkpoint saved at {checkpoint_path}') 289 | 290 | # Evaluate on validation set for the current fold 291 | avg_error = evaluate(params, model, val_loader, idx_tensor, device) # Returns average error 292 | fold_errors.append(avg_error) 293 | 294 | logging.info(f'Fold {fold+1} average error: {avg_error:.4f}') 295 | 296 | # Save the best model for the fold 297 | if avg_error < best_avg_error: 298 | best_avg_error = avg_error 299 | best_model_path = os.path.join(output, f'best_model.pt') 300 | torch.save(model.state_dict(), best_model_path) 301 | logging.info(f'Best model saved for fold {fold+1} at {best_model_path}') 302 | 303 | # Calculate average error across all folds 304 | avg_error_overall = np.mean(fold_errors) 305 | logging.info(f'Average error across {k} folds: {avg_error_overall:.4f}') 306 | 307 | 308 | if __name__ == '__main__': 309 | main() 310 | -------------------------------------------------------------------------------- /onnx_export.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from config import data_config 5 | from utils.helpers import get_model 6 | 7 | 8 | def parse_arguments(): 9 | parser = argparse.ArgumentParser(description='Gaze Estimation Model ONNX Export') 10 | 11 | parser.add_argument( 12 | '-w', '--weight', 13 | default='resnet34.pt', 14 | type=str, 15 | help='Trained state_dict file path to open' 16 | ) 17 | parser.add_argument( 18 | '-n', '--model', 19 | type=str, 20 | default='resnet34', 21 | choices=['resnet18', 'resnet34', 'resnet50', 'mobilenetv2', 'mobileone_s0'], 22 | help='Backbone network architecture to use' 23 | ) 24 | parser.add_argument( 25 | '-d', '--dataset', 26 | type=str, 27 | default='gaze360', 28 | choices=list(data_config.keys()), 29 | help='Dataset name for bin configuration' 30 | ) 31 | parser.add_argument( 32 | '--dynamic', 33 | action='store_true', 34 | help='Enable dynamic batch size and input dimensions for ONNX export' 35 | ) 36 | 37 | return parser.parse_args() 38 | 39 | 40 | @torch.no_grad() 41 | def onnx_export(params): 42 | # Get dataset config for bins 43 | if params.dataset not in data_config: 44 | raise KeyError(f"Unknown dataset: {params.dataset}. 
Available options: {list(data_config.keys())}")
45 |     bins = data_config[params.dataset]['bins']
46 | 
47 |     # Set device
48 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49 | 
50 |     # Initialize model
51 |     model = get_model(params.model, bins, inference_mode=True)
52 |     model.to(device)
53 | 
54 |     # Load weights
55 |     state_dict = torch.load(params.weight, map_location=device)
56 |     model.load_state_dict(state_dict)
57 |     print("Gaze model loaded successfully!")
58 | 
59 |     # Eval mode
60 |     model.eval()
61 | 
62 |     # Generate ONNX output filename
63 |     fname = os.path.splitext(os.path.basename(params.weight))[0]
64 |     onnx_model = f'{fname}_gaze.onnx'
65 |     print(f"==> Exporting model to ONNX format at '{onnx_model}'")
66 | 
67 |     # Dummy input: RGB image, 448x448
68 |     dummy_input = torch.randn(1, 3, 448, 448).to(device)
69 | 
70 |     # Handle dynamic axes
71 |     dynamic_axes = None
72 |     if params.dynamic:
73 |         dynamic_axes = {
74 |             'input': {0: 'batch_size'},
75 |             'pitch': {0: 'batch_size'},
76 |             'yaw': {0: 'batch_size'}
77 |         }
78 |         print("Exporting model with dynamic input shapes.")
79 |     else:
80 |         print("Exporting model with fixed input size: (1, 3, 448, 448)")
81 | 
82 |     # Export model
83 |     torch.onnx.export(
84 |         model,
85 |         dummy_input,
86 |         onnx_model,
87 |         export_params=True,
88 |         opset_version=20,
89 |         do_constant_folding=True,
90 |         input_names=['input'],
91 |         output_names=['pitch', 'yaw'],
92 |         dynamic_axes=dynamic_axes
93 |     )
94 | 
95 |     print(f"Model exported successfully to {onnx_model}")
96 | 
97 | 
98 | if __name__ == '__main__':
99 |     args = parse_arguments()
100 |     onnx_export(args)
101 | 
--------------------------------------------------------------------------------
/onnx_inference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Yakhyokhuja Valikhujaev
2 | # Author: Yakhyokhuja Valikhujaev
3 | # GitHub: https://github.com/yakhyo
4 | 
5 | import cv2
6 | import uniface
7 | import argparse
8 | import numpy as np
9 | import onnxruntime as ort
10 | 
11 | from typing import Tuple
12 | 
13 | from utils.helpers import draw_bbox_gaze
14 | 
15 | 
16 | class GazeEstimationONNX:
17 |     """
18 |     Gaze estimation using ONNXRuntime (decodes bin logits into pitch and yaw in radians).
19 |     """
20 | 
21 |     def __init__(self, model_path: str, session: ort.InferenceSession = None) -> None:
22 |         """Initializes the GazeEstimationONNX class.
23 | 
24 |         Args:
25 |             model_path (str): Path to the ONNX model file.
26 |             session (ort.InferenceSession, optional): ONNX Session. Defaults to None.
27 | 
28 |         Raises:
29 |             AssertionError: If model_path is None and session is not provided.
30 |         """
31 |         self.session = session
32 |         if self.session is None:
33 |             assert model_path is not None, "Model path is required when no session is provided."
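            # Note: ONNX Runtime tries execution providers in the order listed, so with
            # CPUExecutionProvider first this session runs on the CPU even when CUDA is
            # available; swapping the two entries would prefer the GPU instead.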
34 | self.session = ort.InferenceSession( 35 | model_path, 36 | providers=["CPUExecutionProvider", "CUDAExecutionProvider"] 37 | ) 38 | 39 | self._bins = 90 40 | self._binwidth = 4 41 | self._angle_offset = 180 42 | self.idx_tensor = np.arange(self._bins, dtype=np.float32) 43 | 44 | self.input_shape = (448, 448) 45 | self.input_mean = [0.485, 0.456, 0.406] 46 | self.input_std = [0.229, 0.224, 0.225] 47 | 48 | input_cfg = self.session.get_inputs()[0] 49 | input_shape = input_cfg.shape 50 | 51 | self.input_name = input_cfg.name 52 | self.input_size = tuple(input_shape[2:][::-1]) 53 | 54 | outputs = self.session.get_outputs() 55 | output_names = [output.name for output in outputs] 56 | 57 | self.output_names = output_names 58 | assert len(output_names) == 2, "Expected 2 output nodes, got {}".format(len(output_names)) 59 | 60 | def preprocess(self, image: np.ndarray) -> np.ndarray: 61 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 62 | image = cv2.resize(image, self.input_size) # Resize to 448x448 63 | 64 | image = image.astype(np.float32) / 255.0 65 | 66 | mean = np.array(self.input_mean, dtype=np.float32) 67 | std = np.array(self.input_std, dtype=np.float32) 68 | image = (image - mean) / std 69 | 70 | image = np.transpose(image, (2, 0, 1)) # HWC → CHW 71 | image_batch = np.expand_dims(image, axis=0).astype(np.float32) # CHW → BCHW 72 | 73 | return image_batch 74 | 75 | def softmax(self, x: np.ndarray) -> np.ndarray: 76 | e_x = np.exp(x - np.max(x, axis=1, keepdims=True)) 77 | return e_x / e_x.sum(axis=1, keepdims=True) 78 | 79 | def decode(self, pitch_logits: np.ndarray, yaw_logits: np.ndarray) -> Tuple[float, float]: 80 | pitch_probs = self.softmax(pitch_logits) 81 | yaw_probs = self.softmax(yaw_logits) 82 | 83 | pitch = np.sum(pitch_probs * self.idx_tensor, axis=1) * self._binwidth - self._angle_offset 84 | yaw = np.sum(yaw_probs * self.idx_tensor, axis=1) * self._binwidth - self._angle_offset 85 | 86 | return np.radians(pitch[0]), np.radians(yaw[0]) 87 | 88 | def estimate(self, face_image: np.ndarray) -> Tuple[float, float]: 89 | input_tensor = self.preprocess(face_image) 90 | outputs = self.session.run(self.output_names, {"input": input_tensor}) 91 | 92 | return self.decode(outputs[0], outputs[1]) 93 | 94 | 95 | def parse_args(): 96 | parser = argparse.ArgumentParser(description="Gaze Estimation ONNX Inference") 97 | parser.add_argument( 98 | "--source", 99 | type=str, 100 | required=True, 101 | help="Video path or camera index (e.g., 0 for webcam)" 102 | ) 103 | parser.add_argument( 104 | "--model", 105 | type=str, 106 | required=True, 107 | help="Path to ONNX model" 108 | ) 109 | parser.add_argument( 110 | "--output", 111 | type=str, 112 | default=None, 113 | help="Path to save output video (optional)" 114 | ) 115 | return parser.parse_args() 116 | 117 | 118 | if __name__ == "__main__": 119 | args = parse_args() 120 | 121 | # Handle numeric webcam index 122 | try: 123 | source = int(args.source) 124 | except ValueError: 125 | source = args.source 126 | 127 | cap = cv2.VideoCapture(source) 128 | if not cap.isOpened(): 129 | raise IOError(f"Failed to open video source: {args.source}") 130 | 131 | # Initialize Gaze Estimation model 132 | engine = GazeEstimationONNX(model_path=args.model) 133 | detector = uniface.RetinaFace() 134 | 135 | # Optional output writer 136 | writer = None 137 | if args.output: 138 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 139 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 140 | fps = cap.get(cv2.CAP_PROP_FPS) or 30.0 141 | fourcc = 
cv2.VideoWriter_fourcc(*"mp4v") 142 | writer = cv2.VideoWriter(args.output, fourcc, fps, (width, height)) 143 | 144 | while cap.isOpened(): 145 | ret, frame = cap.read() 146 | if not ret: 147 | break 148 | 149 | bboxes, _ = detector.detect(frame) 150 | 151 | for bbox in bboxes: 152 | x_min, y_min, x_max, y_max = map(int, bbox[:4]) 153 | face_crop = frame[y_min:y_max, x_min:x_max] 154 | if face_crop.size == 0: 155 | continue 156 | 157 | pitch, yaw = engine.estimate(face_crop) 158 | draw_bbox_gaze(frame, bbox, pitch, yaw) 159 | 160 | if writer: 161 | writer.write(frame) 162 | 163 | cv2.imshow("Gaze Estimation", frame) 164 | if cv2.waitKey(1) & 0xFF == ord("q"): 165 | break 166 | 167 | cap.release() 168 | if writer: 169 | writer.release() 170 | cv2.destroyAllWindows() 171 | -------------------------------------------------------------------------------- /reparameterize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models import mobileone_s0, mobileone_s1, mobileone_s2, mobileone_s3, mobileone_s4, reparameterize_model 3 | 4 | state_dict = torch.load('mobileone_s0.pt') 5 | model = mobileone_s0(pretrained=False, num_classes=90) # 90 bins 6 | 7 | model.load_state_dict(state_dict) 8 | 9 | 10 | model.eval() 11 | model_eval = reparameterize_model(model) 12 | 13 | torch.save(model_eval.state_dict(), 's0_fused.pt') 14 | 15 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==2.0.1 2 | onnxruntime==1.19.0 3 | opencv-python==4.10.0.84 4 | pillow==10.2.0 5 | torch==2.4.0 6 | torchvision==0.19.0 7 | tqdm==4.66.5 8 | uniface==0.1.7 9 | -------------------------------------------------------------------------------- /utils/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | from PIL import Image 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class Gaze360(Dataset): 11 | def __init__(self, root: str, transform=None, angle: int = 180, binwidth: int = 4, mode: str = 'train'): 12 | self.labels_dir = os.path.join(root, "Label") 13 | self.images_dir = os.path.join(root, "Image") 14 | 15 | if mode in ['train', 'test', 'val']: 16 | labels_file = os.path.join(self.labels_dir, f"{mode}.label") 17 | else: 18 | raise ValueError(f"{mode} must be in ['train','test', 'val']") 19 | 20 | self.transform = transform 21 | self.angle = angle if mode == "train" else 90 22 | self.binwidth = binwidth 23 | 24 | self.lines = [] 25 | 26 | with open(labels_file) as f: 27 | lines = f.readlines()[1:] # Skip the header line 28 | self.orig_list_len = len(lines) 29 | 30 | for line in tqdm(lines, desc="Loading Labels"): 31 | gaze2d = line.strip().split(" ")[5] 32 | label = np.array(gaze2d.split(",")).astype(float) 33 | pitch, yaw = label * 180 / np.pi 34 | 35 | if abs(pitch) <= self.angle and abs(yaw) <= self.angle: 36 | self.lines.append(line) 37 | 38 | removed_items = self.orig_list_len - len(self.lines) 39 | print(f"{removed_items} items removed from dataset that have an angle > {self.angle}") 40 | 41 | def __len__(self): 42 | return len(self.lines) 43 | 44 | def __getitem__(self, idx): 45 | line = self.lines[idx].strip().split(" ") 46 | 47 | image_path = line[0] 48 | filename = line[3] 49 | gaze2d = line[5] 50 | 51 | label = np.array(gaze2d.split(",")).astype(float) 52 | pitch, yaw = label * 180 / 
np.pi 53 | 54 | image = Image.open(os.path.join(self.images_dir, image_path)) 55 | if self.transform is not None: 56 | image = self.transform(image) 57 | 58 | # bin values 59 | bins = np.arange(-self.angle, self.angle, self.binwidth) 60 | binned_pose = np.digitize([pitch, yaw], bins) - 1 61 | 62 | # binned and regression labels 63 | binned_labels = torch.tensor(binned_pose, dtype=torch.long) 64 | regression_labels = torch.tensor([pitch, yaw], dtype=torch.float32) 65 | 66 | return image, binned_labels, regression_labels, filename 67 | 68 | 69 | class MPIIGaze(Dataset): 70 | def __init__(self, root: str, transform=None, angle: int = 42, binwidth: int = 3): 71 | self.labels_dir = os.path.join(root, "Label") 72 | self.images_dir = os.path.join(root, "Image") 73 | 74 | label_files = [os.path.join(self.labels_dir, label) for label in os.listdir(self.labels_dir)] 75 | 76 | self.transform = transform 77 | self.orig_list_len = 0 78 | self.binwidth = binwidth 79 | self.angle = angle 80 | self.lines = [] 81 | 82 | for label_file in label_files: 83 | with open(label_file) as f: 84 | lines = f.readlines()[1:] # Skip the header line 85 | self.orig_list_len += len(lines) 86 | for line in lines: 87 | gaze2d = line.strip().split(" ")[7] 88 | label = np.array(gaze2d.split(",")).astype("float") 89 | pitch, yaw = label * 180 / np.pi 90 | 91 | if abs(pitch) <= self.angle and abs(yaw) <= self.angle: 92 | self.lines.append(line) 93 | 94 | removed_items = self.orig_list_len - len(self.lines) 95 | print(f"{removed_items} items removed from dataset that have an angle > {self.angle}") 96 | 97 | def __len__(self): 98 | return len(self.lines) 99 | 100 | def __getitem__(self, idx): 101 | line = self.lines[idx].strip().split(" ") 102 | 103 | image_path = line[0] 104 | filename = line[3] 105 | gaze2d = line[7] 106 | 107 | label = np.array(gaze2d.split(",")).astype("float") 108 | pitch, yaw = label * 180 / np.pi 109 | 110 | image = Image.open(os.path.join(self.images_dir, image_path)) 111 | if self.transform is not None: 112 | image = self.transform(image) 113 | 114 | # bin values 115 | bins = np.arange(-self.angle, self.angle, self.binwidth) 116 | binned_pose = np.digitize([pitch, yaw], bins) - 1 117 | 118 | # binned and regression labels 119 | binned_labels = torch.tensor(binned_pose, dtype=torch.long) 120 | regression_labels = torch.tensor([pitch, yaw], dtype=torch.float32) 121 | 122 | return image, binned_labels, regression_labels, filename 123 | -------------------------------------------------------------------------------- /utils/helpers.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.utils.data import DataLoader 7 | from torchvision import transforms 8 | 9 | from utils.datasets import Gaze360, MPIIGaze 10 | 11 | from models import ( 12 | resnet18, 13 | resnet34, 14 | resnet50, 15 | mobilenet_v2, 16 | mobileone_s0, 17 | mobileone_s1, 18 | mobileone_s2, 19 | mobileone_s3, 20 | mobileone_s4 21 | ) 22 | 23 | 24 | def get_model(arch, bins, pretrained=False, inference_mode=False): 25 | """Return the model based on the specified architecture.""" 26 | if arch == 'resnet18': 27 | model = resnet18(pretrained=pretrained, num_classes=bins) 28 | elif arch == 'resnet34': 29 | model = resnet34(pretrained=pretrained, num_classes=bins) 30 | elif arch == 'resnet50': 31 | model = resnet50(pretrained=pretrained, num_classes=bins) 32 | elif arch == "mobilenetv2": 33 | model = 
mobilenet_v2(pretrained=pretrained, num_classes=bins)
34 |     elif arch == "mobileone_s0":
35 |         model = mobileone_s0(pretrained=pretrained, num_classes=bins, inference_mode=inference_mode)
36 |     elif arch == "mobileone_s1":
37 |         model = mobileone_s1(pretrained=pretrained, num_classes=bins, inference_mode=inference_mode)
38 |     elif arch == "mobileone_s2":
39 |         model = mobileone_s2(pretrained=pretrained, num_classes=bins, inference_mode=inference_mode)
40 |     elif arch == "mobileone_s3":
41 |         model = mobileone_s3(pretrained=pretrained, num_classes=bins, inference_mode=inference_mode)
42 |     elif arch == "mobileone_s4":
43 |         model = mobileone_s4(pretrained=pretrained, num_classes=bins, inference_mode=inference_mode)
44 |     else:
45 |         raise ValueError(f"Please choose an available model architecture, currently chosen: {arch}")
46 |     return model
47 | 
48 | 
49 | def angular_error(gaze_vector, label_vector):
50 |     dot_product = np.dot(gaze_vector, label_vector)
51 |     norm_product = np.linalg.norm(gaze_vector) * np.linalg.norm(label_vector)
52 |     cosine_similarity = np.clip(dot_product / norm_product, -0.9999999, 0.9999999)  # keep arccos in-domain on both ends
53 | 
54 |     return np.degrees(np.arccos(cosine_similarity))
55 | 
56 | 
57 | def gaze_to_3d(gaze):
58 |     yaw = gaze[0]  # Horizontal angle
59 |     pitch = gaze[1]  # Vertical angle
60 | 
61 |     gaze_vector = np.zeros(3)
62 |     gaze_vector[0] = -np.cos(pitch) * np.sin(yaw)
63 |     gaze_vector[1] = -np.sin(pitch)
64 |     gaze_vector[2] = -np.cos(pitch) * np.cos(yaw)
65 | 
66 |     return gaze_vector
67 | 
68 | 
69 | def get_dataloader(params, mode="train"):
70 |     """Load dataset and return DataLoader."""
71 | 
72 |     transform = transforms.Compose([
73 |         transforms.Resize(448),
74 |         transforms.ToTensor(),
75 |         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
76 |     ])
77 | 
78 |     if params.dataset == "gaze360":
79 |         dataset = Gaze360(params.data, transform, angle=params.angle, binwidth=params.binwidth, mode=mode)
80 |     elif params.dataset == "mpiigaze":
81 |         dataset = MPIIGaze(params.data, transform, angle=params.angle, binwidth=params.binwidth)
82 |     else:
83 |         raise ValueError("Supported datasets are `gaze360` and `mpiigaze`")
84 | 
85 |     data_loader = DataLoader(
86 |         dataset=dataset,
87 |         batch_size=params.batch_size,
88 |         shuffle=(mode == "train"),
89 |         num_workers=params.num_workers,
90 |         pin_memory=True
91 |     )
92 |     return data_loader
93 | 
94 | def draw_gaze(frame, bbox, pitch, yaw, thickness=2, color=(0, 0, 255)):
95 |     """Draws gaze direction on a frame given bounding box and gaze angles."""
96 |     # Unpack bounding box coordinates
97 |     x_min, y_min, x_max, y_max = map(int, bbox[:4])
98 | 
99 |     # Calculate center of the bounding box
100 |     x_center = (x_min + x_max) // 2
101 |     y_center = (y_min + y_max) // 2
102 | 
103 |     # Handle grayscale frames by converting them to BGR
104 |     if len(frame.shape) == 2 or frame.shape[2] == 1:
105 |         frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
106 | 
107 |     # Calculate the direction of the gaze
108 |     length = x_max - x_min
109 |     dx = int(-length * np.sin(pitch) * np.cos(yaw))
110 |     dy = int(-length * np.sin(yaw))
111 | 
112 |     point1 = (x_center, y_center)
113 |     point2 = (x_center + dx, y_center + dy)
114 | 
115 |     # Draw gaze direction
116 |     cv2.circle(frame, (x_center, y_center), radius=4, color=color, thickness=-1)
117 |     cv2.arrowedLine(
118 |         frame,
119 |         point1,
120 |         point2,
121 |         color=color,
122 |         thickness=thickness,
123 |         line_type=cv2.LINE_AA,
124 |         tipLength=0.25
125 |     )
126 | 
127 | 
128 | 
129 | def draw_bbox(image, bbox, color=(0, 255, 0), thickness=2,
proportion=0.2): 130 | x_min, y_min, x_max, y_max = map(int, bbox[:4]) 131 | 132 | width = x_max - x_min 133 | height = y_max - y_min 134 | 135 | corner_length = int(proportion * min(width, height)) 136 | 137 | # Draw the rectangle 138 | cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 1) 139 | 140 | # Top-left corner 141 | cv2.line(image, (x_min, y_min), (x_min + corner_length, y_min), color, thickness) 142 | cv2.line(image, (x_min, y_min), (x_min, y_min + corner_length), color, thickness) 143 | 144 | # Top-right corner 145 | cv2.line(image, (x_max, y_min), (x_max - corner_length, y_min), color, thickness) 146 | cv2.line(image, (x_max, y_min), (x_max, y_min + corner_length), color, thickness) 147 | 148 | # Bottom-left corner 149 | cv2.line(image, (x_min, y_max), (x_min, y_max - corner_length), color, thickness) 150 | cv2.line(image, (x_min, y_max), (x_min + corner_length, y_max), color, thickness) 151 | 152 | # Bottom-right corner 153 | cv2.line(image, (x_max, y_max), (x_max, y_max - corner_length), color, thickness) 154 | cv2.line(image, (x_max, y_max), (x_max - corner_length, y_max), color, thickness) 155 | 156 | 157 | def draw_bbox_gaze(frame: np.ndarray, bbox, pitch, yaw): 158 | draw_bbox(frame, bbox) 159 | draw_gaze(frame, bbox, pitch, yaw) 160 | -------------------------------------------------------------------------------- /weights/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yakhyo/gaze-estimation/2ccde1d08d007727c2df1ce704c32e2683b2d0b9/weights/.gitkeep --------------------------------------------------------------------------------