├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── in_video.mp4
│   ├── out_gif.gif
│   └── out_video.mp4
├── config.py
├── data
│   └── .gitkeep
├── download.sh
├── evaluate.py
├── inference.py
├── main.py
├── models
│   ├── __init__.py
│   ├── mobilenet.py
│   ├── mobileone.py
│   └── resnet.py
├── mpii_train.py
├── onnx_export.py
├── onnx_inference.py
├── reparameterize.py
├── requirements.txt
├── utils
│   ├── datasets.py
│   └── helpers.py
└── weights
    └── .gitkeep
/.gitignore:
--------------------------------------------------------------------------------
1 | data/Gaze360
2 | data/MPIIFaceGaze
3 | output/
4 | *.pt
5 | *.onnx
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 | cover/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | .pybuilder/
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | # For a library or package, you might want to ignore these files since the code is
93 | # intended to run in multiple environments; otherwise, check them in:
94 | # .python-version
95 |
96 | # pipenv
97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
100 | # install all needed dependencies.
101 | #Pipfile.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Yakhyokhuja Valikhujaev
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MobileGaze: Pre-trained mobile nets for Gaze-Estimation
2 |
3 | 
4 | [](https://github.com/yakhyo/gaze-estimation/stargazers)
5 | [](https://github.com/yakhyo/gaze-estimation)
6 | [](https://doi.org/10.5281/zenodo.14257640)
7 |
8 |
11 |
12 |
17 |
18 |
19 |
20 | 
21 |
22 | Video by Yan Krukau: https://www.pexels.com/video/male-teacher-with-his-students-8617126/
23 |
24 |
25 |
26 | This project aims to perform gaze estimation using several deep learning models like ResNet, MobileNet v2, and MobileOne. It supports both classification and regression for predicting gaze direction. Built on top of [L2CS-Net](https://github.com/Ahmednull/L2CS-Net), the project includes additional pre-trained models and refined code for better performance and flexibility.
27 |
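The backbones output one set of binned logits for pitch and one for yaw; a continuous angle is recovered as the softmax-weighted expectation over bin indices, scaled by the dataset's `binwidth` and shifted by its `angle` range (see `config.py`). A minimal sketch of that decoding step, assuming raw logits of shape `(batch, bins)` (the helper name is illustrative):

```python
import torch
import torch.nn.functional as F


def decode_angles(logits: torch.Tensor, binwidth: float, angle: float) -> torch.Tensor:
    """Map binned logits to continuous angles in degrees.

    Gaze360 config: bins=90, binwidth=4, angle=180 -> roughly (-180, 180).
    MPIIFaceGaze config: bins=28, binwidth=3, angle=42 -> roughly (-42, 42).
    """
    probs = F.softmax(logits, dim=1)  # (batch, bins)
    idx = torch.arange(logits.shape[1], dtype=torch.float32, device=logits.device)
    return torch.sum(probs * idx, dim=1) * binwidth - angle


# Example with random Gaze360-style pitch logits
pitch_deg = decode_angles(torch.randn(1, 90), binwidth=4, angle=180)
```
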
28 | ## Features
29 |
30 | - [x] **ONNX Inference**: Export PyTorch weights to ONNX and run inference with ONNX Runtime.
31 | - [x] **ResNet**: [Deep Residual Networks](https://arxiv.org/abs/1512.03385) - Enables deeper networks with better accuracy through residual learning.
32 | - [x] **MobileNet v2**: [Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) - Efficient model for mobile applications, balancing performance and computational cost.
33 | - [x] **MobileOne (s0-s4)**: [An Improved One millisecond Mobile Backbone](https://arxiv.org/abs/2206.04040) - Achieves near-instant inference times, ideal for real-time mobile applications.
34 | - [x] **Face Detection**: [uniface](https://github.com/yakhyo/uniface) - **uniface** is a face detection library based on the RetinaFace model.
35 |
36 | > [!NOTE]
37 | > All models are trained only on **Gaze360** dataset.
38 |
39 | ## Installation
40 |
41 | 1. Clone the repository:
42 |
43 | ```bash
44 | git clone https://github.com/yakhyo/gaze-estimation.git
45 | cd gaze-estimation
46 | ```
47 |
48 | 2. Install the required dependencies:
49 |
50 | ```bash
51 | pip install -r requirements.txt
52 | ```
53 |
54 | 3. Download weight files:
55 |
56 | a) Download weights from the following links:
57 |
58 | | Model | PyTorch Weights | ONNX Weights | Size | Epochs | MAE |
59 | | ------------ | ----------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------- | ------- | ------ | ----- |
60 | | ResNet-18 | [resnet18.pt](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/resnet18.pt) | [resnet18_gaze.onnx](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/resnet18_gaze.onnx) | 43 MB | 200 | 12.84 |
61 | | ResNet-34 | [resnet34.pt](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/resnet34.pt) | [resnet34_gaze.onnx](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/resnet34_gaze.onnx) | 81.6 MB | 200 | 11.33 |
62 | | ResNet-50 | [resnet50.pt](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/resnet50.pt) | [resnet50_gaze.onnx](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/resnet50_gaze.onnx) | 91.3 MB | 200 | 11.34 |
63 | | MobileNet V2 | [mobilenetv2.pt](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/mobilenetv2.pt) | [mobilenetv2_gaze.onnx](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/mobilenetv2_gaze.onnx) | 9.59 MB | 200 | 13.07 |
64 | | MobileOne S0 | [mobileone_s0_fused.pt](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/mobileone_s0.pt) | [mobileone_s0_gaze.onnx](https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1/mobileone_s0_gaze.onnx) | 4.8 MB | 200 | 12.58 |
65 | | MobileOne S1 | [not available](#) | [not available](#) | xx MB | 200 | \* |
66 | | MobileOne S2 | [not available](#) | [not available](#) | xx MB | 200 | \* |
67 | | MobileOne S3 | [not available](#) | [not available](#) | xx MB | 200 | \* |
68 | | MobileOne S4 | [not available](#)                                                                                            | [not available](#)                                                                                                    | xx MB   | 200    | \*    |
69 |
70 | '\*' - will be uploaded soon (due to limited computing resources, the remaining weights have not been published yet, but you can still train them with the provided code).
71 |
72 | b) Run the command below to download weights to the `weights` directory (Linux):
73 |
74 | ```bash
75 | sh download.sh [model_name]
76 | resnet18
77 | resnet34
78 | resnet50
79 | mobilenetv2
80 | mobileone_s0
81 | mobileone_s1
82 | mobileone_s2
83 | mobileone_s3
84 | mobileone_s4
85 | ```
86 |
87 | ## Usage
88 |
89 | ### Datasets
90 |
91 | Dataset folder structure:
92 |
93 | ```
94 | data/
95 | ├── Gaze360/
96 | │   ├── Image/
97 | │   └── Label/
98 | └── MPIIFaceGaze/
99 |     ├── Image/
100 |     └── Label/
101 | ```
102 |
103 | **Gaze360**
104 |
105 | - Link to download dataset: https://gaze360.csail.mit.edu/download.php
106 | - Data pre-processing code: https://phi-ai.buaa.edu.cn/Gazehub/3D-dataset/#gaze360
107 |
108 | **MPIIFaceGaze**
109 |
110 | - Link to download dataset: [download page](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/research/gaze-based-human-computer-interaction/its-written-all-over-your-face-full-face-appearance-based-gaze-estimation)
111 | - Data pre-processing code: https://phi-ai.buaa.edu.cn/Gazehub/3D-dataset/#mpiifacegaze
112 |
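Before training, it can help to confirm that the pre-processed data ended up in the layout shown above. A small, optional sanity check (paths assume the `data/` root of this repository):

```python
from pathlib import Path

# Walk the expected dataset layout and report what is present
root = Path("data")
for dataset in ("Gaze360", "MPIIFaceGaze"):
    for subdir in ("Image", "Label"):
        path = root / dataset / subdir
        status = f"{sum(1 for _ in path.iterdir())} entries" if path.is_dir() else "missing"
        print(f"{path}: {status}")
```
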
113 | ### Training
114 |
115 | ```bash
116 | python main.py --data [dataset_path] --dataset [dataset_name] --arch [architecture_name]
117 | ```
118 |
119 | `main.py` arguments:
120 |
121 | ```
122 | usage: main.py [-h] [--data DATA] [--dataset DATASET] [--output OUTPUT] [--checkpoint CHECKPOINT] [--num-epochs NUM_EPOCHS] [--batch-size BATCH_SIZE] [--arch ARCH] [--alpha ALPHA] [--lr LR] [--num-workers NUM_WORKERS]
123 |
124 | Gaze estimation training.
125 |
126 | options:
127 | -h, --help show this help message and exit
128 | --data DATA Directory path for gaze images.
129 | --dataset DATASET Dataset name, available `gaze360`, `mpiigaze`.
130 | --output OUTPUT Path of output models.
131 | --checkpoint CHECKPOINT
132 | Path to checkpoint for resuming training.
133 | --num-epochs NUM_EPOCHS
134 | Maximum number of training epochs.
135 | --batch-size BATCH_SIZE
136 | Batch size.
137 | --arch ARCH Network architecture, currently available: resnet18/34/50, mobilenetv2, mobileone_s0-s4.
138 | --alpha ALPHA Regression loss coefficient.
139 | --lr LR Base learning rate.
140 | --num-workers NUM_WORKERS
141 | Number of workers for data loading.
142 | ```
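Each run writes a `checkpoint.ckpt` (holding `epoch`, `model_state_dict`, `optimizer_state_dict`, and `loss`) plus a `best_model.pt` into a timestamped folder under `--output`; pass the checkpoint path to `--checkpoint` to resume. A short sketch for inspecting such a checkpoint (the folder name below is illustrative):

```python
import torch

# main.py saves to output/<dataset>_<arch>_<timestamp>/checkpoint.ckpt (illustrative path below)
ckpt = torch.load("output/gaze360_resnet18_1700000000/checkpoint.ckpt", map_location="cpu")

print("resume epoch:", ckpt["epoch"])
print("last loss:", ckpt["loss"])
print("parameter tensors:", len(ckpt["model_state_dict"]))
```
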
143 |
144 | ### Evaluation
145 |
146 | ```bash
147 | python evaluate.py --data [dataset_path] --dataset [dataset_name] --weight [weight_path] --arch [architecture_name]
148 | ```
149 |
150 | `evaluate.py` arguments:
151 |
152 | ```
153 | usage: evaluate.py [-h] [--data DATA] [--dataset DATASET] [--weight WEIGHT] [--batch-size BATCH_SIZE] [--arch ARCH] [--num-workers NUM_WORKERS]
154 |
155 | Gaze estimation evaluation.
156 |
157 | options:
158 | -h, --help show this help message and exit
159 | --data DATA Directory path for gaze images.
160 | --dataset DATASET Dataset name, available `gaze360`, `mpiigaze`
161 |   --weight WEIGHT    Path to model weight for evaluation.
162 | --batch-size BATCH_SIZE
163 | Batch size.
164 | --arch ARCH Network architecture, currently available: resnet18/34/50, mobilenetv2, mobileone_s0-s4.
165 | --num-workers NUM_WORKERS
166 | Number of workers for data loading.
167 | ```
168 |
169 | ### Inference
170 |
171 | ```bash
172 | python inference.py --model [model_name] --weight [model_weight_path] --view --source [source_video / cam_index] --output [output_file] --dataset [dataset_name]
173 | ```
174 |
175 | `inference.py` arguments:
176 |
177 | ```
178 | usage: inference.py [-h] [--model MODEL] [--weight WEIGHT] [--view] [--source SOURCE] [--output OUTPUT] [--dataset DATASET]
179 |
180 | Gaze estimation inference
181 |
182 | options:
183 | -h, --help show this help message and exit
184 |   --model MODEL      Model name, default `resnet34`
185 |   --weight WEIGHT    Path to gaze estimation model weights
186 | --view Display the inference results
187 | --source SOURCE Path to source video file or camera index
188 | --output OUTPUT Path to save output file
189 | --dataset DATASET Dataset name to get dataset related configs
190 | ```
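For a single pre-cropped face image, the pipeline in `inference.py` reduces to the sketch below (the image path is an assumption, the weight path follows `download.sh`, and decoding uses the Gaze360 values from `config.py`):

```python
import cv2
import torch
import torch.nn.functional as F
from torchvision import transforms

from utils.helpers import get_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bins, binwidth, angle = 90, 4, 180  # Gaze360 config (see config.py)

# Build the model and load Gaze360-trained weights (downloaded via download.sh)
model = get_model("resnet18", bins, inference_mode=True)
model.load_state_dict(torch.load("weights/resnet18.pt", map_location=device))
model.to(device).eval()

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(448),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# "face.jpg" is assumed to be an already-cropped face region
face = cv2.cvtColor(cv2.imread("face.jpg"), cv2.COLOR_BGR2RGB)
image = transform(face).unsqueeze(0).to(device)

with torch.no_grad():
    pitch_logits, yaw_logits = model(image)

# Softmax expectation over bins -> degrees
idx = torch.arange(bins, dtype=torch.float32, device=device)
pitch = torch.sum(F.softmax(pitch_logits, dim=1) * idx, dim=1) * binwidth - angle
yaw = torch.sum(F.softmax(yaw_logits, dim=1) * idx, dim=1) * binwidth - angle
print(f"pitch: {pitch.item():.2f} deg, yaw: {yaw.item():.2f} deg")
```
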
191 |
192 | ### ONNX Export and Inference
193 |
194 | **Export to ONNX**
195 |
196 | ```bash
197 | python onnx_export.py --weight [model_path] --model [model_name] --dynamic
198 | ```
199 |
200 | `onnx_export.py` arguments:
201 |
202 | ```
203 | usage: onnx_export.py [-h] [-w WEIGHT] [-n {resnet18,resnet34,resnet50,mobilenetv2,mobileone_s0}] [-d {gaze360,mpiigaze}] [--dynamic]
204 |
205 | Gaze Estimation Model ONNX Export
206 |
207 | options:
208 | -h, --help show this help message and exit
209 | -w WEIGHT, --weight WEIGHT
210 | Trained state_dict file path to open
211 | -n {resnet18,resnet34,resnet50,mobilenetv2,mobileone_s0}, --model {resnet18,resnet34,resnet50,mobilenetv2,mobileone_s0}
212 | Backbone network architecture to use
213 | -d {gaze360,mpiigaze}, --dataset {gaze360,mpiigaze}
214 | Dataset name for bin configuration
215 | --dynamic Enable dynamic batch size and input dimensions for ONNX export
216 | ```
217 |
218 | **ONNX Inference**
219 |
220 | ```bash
221 | python onnx_inference.py --source [source video / webcam index] --model [onnx model path] --output [path to save video]
222 | ```
223 |
224 | `onnx_inference.py` arguments:
225 |
226 | ```
227 | usage: onnx_inference.py [-h] --source SOURCE --model MODEL [--output OUTPUT]
228 |
229 | Gaze Estimation ONNX Inference
230 |
231 | options:
232 | -h, --help show this help message and exit
233 | --source SOURCE Video path or camera index (e.g., 0 for webcam)
234 | --model MODEL Path to ONNX model
235 | --output OUTPUT Path to save output video (optional)
236 | ```
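The exported model can also be run directly with `onnxruntime` if you want to bypass `onnx_inference.py`. A minimal sketch on a single face crop, assuming the preprocessing from `inference.py` (ImageNet normalization, ~448 px input) and that the ONNX graph keeps the two raw pitch/yaw logit outputs of the PyTorch models:

```python
import cv2
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("weights/resnet18_gaze.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name

# Pre-process one face crop: RGB, 448x448, ImageNet mean/std, NCHW float32
face = cv2.cvtColor(cv2.imread("face.jpg"), cv2.COLOR_BGR2RGB)
face = cv2.resize(face, (448, 448)).astype(np.float32) / 255.0
face = (face - np.array([0.485, 0.456, 0.406], np.float32)) / np.array([0.229, 0.224, 0.225], np.float32)
face = face.transpose(2, 0, 1)[None]

pitch_logits, yaw_logits = session.run(None, {input_name: face})


def decode(logits, bins=90, binwidth=4, angle=180):
    # Softmax expectation over bins -> degrees (Gaze360 config from config.py)
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)
    return (probs * np.arange(bins)).sum(axis=1) * binwidth - angle


print("pitch:", decode(pitch_logits), "yaw:", decode(yaw_logits))
```
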
237 |
238 | ## Citation
239 |
240 | If you use this work in your research, please cite it as:
241 |
242 | Valikhujaev, Y. (2024). MobileGaze: Pre-trained mobile nets for Gaze-Estimation. Zenodo. [https://doi.org/10.5281/zenodo.14257640](https://doi.org/10.5281/zenodo.14257640)
243 |
244 | Alternatively, in BibTeX format:
245 |
246 | ```bibtex
247 | @misc{valikhujaev2024mobilegaze,
248 | author = {Valikhujaev, Y.},
249 | title = {MobileGaze: Pre-trained mobile nets for Gaze-Estimation},
250 | year = {2024},
251 | publisher = {Zenodo},
252 | doi = {10.5281/zenodo.14257640},
253 | url = {https://doi.org/10.5281/zenodo.14257640}
254 | }
255 | ```
256 |
257 | ## Reference
258 |
259 | 1. This project is built on top of [L2CS-Net](https://github.com/Ahmednull/L2CS-Net). Most of the code parts have been re-written for reproducibility and adaptability. Several additional backbones are provided with pre-trained weights.
260 | 2. https://github.com/apple/ml-mobileone
261 | 3. [uniface](https://github.com/yakhyo/uniface) - face detection library used for inference in `inference.py`.
262 |
263 |
268 |
--------------------------------------------------------------------------------
/assets/in_video.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yakhyo/gaze-estimation/2ccde1d08d007727c2df1ce704c32e2683b2d0b9/assets/in_video.mp4
--------------------------------------------------------------------------------
/assets/out_gif.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yakhyo/gaze-estimation/2ccde1d08d007727c2df1ce704c32e2683b2d0b9/assets/out_gif.gif
--------------------------------------------------------------------------------
/assets/out_video.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yakhyo/gaze-estimation/2ccde1d08d007727c2df1ce704c32e2683b2d0b9/assets/out_video.mp4
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | data_config = {
2 | "gaze360":
3 | {
4 | "bins": 90,
5 | "binwidth": 4,
6 | "angle": 180 # angle range
7 | },
8 | "mpiigaze":
9 | {
10 | "bins": 28,
11 | "binwidth": 3,
12 | "angle": 42 # angle range
13 | }
14 |
15 | }
--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yakhyo/gaze-estimation/2ccde1d08d007727c2df1ce704c32e2683b2d0b9/data/.gitkeep
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Define the base URL for the downloads
4 | BASE_URL="https://github.com/yakhyo/gaze-estimation/releases/download/v0.0.1"
5 |
6 | # Create the weights directory if it does not exist
7 | mkdir -p weights
8 |
9 | # Check if a model name was provided
10 | if [ -z "$1" ]; then
11 |     echo "Usage: $0 <model_name>"
12 | echo "Example: $0 resnet18"
13 | exit 1
14 | fi
15 |
16 | # Determine the model name
17 | MODEL_NAME=$1
18 | MODEL_FILE="${MODEL_NAME}.pt"
19 | MODEL_FILE_ONNX="${MODEL_NAME}_gaze.onnx"
20 |
21 | # Download the model
22 | wget -O weights/$MODEL_FILE $BASE_URL/$MODEL_FILE
23 | wget -O weights/$MODEL_FILE_ONNX $BASE_URL/$MODEL_FILE_ONNX
24 |
25 | # Check if the download was successful
26 | if [ $? -eq 0 ]; then
27 | echo "Downloaded $MODEL_FILE and $MODEL_FILE_ONNX to weights/"
28 | else
29 | echo "Failed to download $MODEL_FILE"
30 | fi
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import logging
4 |
5 | import torch
6 | import torch.nn as nn
7 | import numpy as np
8 |
9 | from tqdm import tqdm
10 | import torch.nn.functional as F
11 |
12 | from config import data_config
13 | from utils.helpers import angular_error, gaze_to_3d, get_dataloader, get_model
14 |
15 | import warnings
16 | warnings.filterwarnings("ignore")
17 | # Setup logging
18 | logging.basicConfig(level=logging.INFO, format='%(message)s')
19 |
20 |
21 | def parse_args():
22 | """Parse input arguments."""
23 | parser = argparse.ArgumentParser(description="Gaze estimation evaluation")
24 | parser.add_argument("--data", type=str, default="data/Gaze360", help="Directory path for gaze images.")
25 | parser.add_argument("--dataset", type=str, default="gaze360", help="Dataset name, available `gaze360`, `mpiigaze`")
26 | parser.add_argument("--weight", type=str, default="", help="Path to model weight for evaluation.")
27 | parser.add_argument("--batch-size", type=int, default=64, help="Batch size.")
28 | parser.add_argument(
29 | "--arch",
30 | type=str,
31 | default="resnet18",
32 | help="Network architecture, currently available: resnet18/34/50, mobilenetv2, mobileone_s0-s4."
33 | )
34 | parser.add_argument("--num-workers", type=int, default=8, help="Number of workers for data loading.")
35 |
36 | args = parser.parse_args()
37 |
38 | # Override default values based on selected dataset
39 | if args.dataset in data_config:
40 | dataset_config = data_config[args.dataset]
41 | args.bins = dataset_config["bins"]
42 | args.binwidth = dataset_config["binwidth"]
43 | args.angle = dataset_config["angle"]
44 | else:
45 | raise ValueError(f"Unknown dataset: {args.dataset}. Available options: {list(data_config.keys())}")
46 |
47 | return args
48 |
49 |
50 | @torch.no_grad()
51 | def evaluate(params, model, data_loader, idx_tensor, device):
52 | """
53 | Evaluate the model on the test dataset.
54 |
55 | Args:
56 | params (argparse.Namespace): Parsed command-line arguments.
57 | model (nn.Module): The gaze estimation model.
58 | data_loader (torch.utils.data.DataLoader): DataLoader for the test dataset.
59 | idx_tensor (torch.Tensor): Tensor representing bin indices.
60 | device (torch.device): Device to perform evaluation on.
61 | """
62 | model.eval()
63 | average_error = 0
64 | total_samples = 0
65 |
66 | for images, labels_gaze, regression_labels_gaze, _ in tqdm(data_loader, total=len(data_loader)):
67 | total_samples += regression_labels_gaze.size(0)
68 | images = images.to(device)
69 |
70 | # Regression labels
71 | label_pitch = np.radians(regression_labels_gaze[:, 0], dtype=np.float32)
72 | label_yaw = np.radians(regression_labels_gaze[:, 1], dtype=np.float32)
73 |
74 | # Inference
75 | pitch, yaw = model(images)
76 |
77 | # Regression predictions
78 | pitch_predicted = F.softmax(pitch, dim=1)
79 | yaw_predicted = F.softmax(yaw, dim=1)
80 |
81 | # Mapping from binned (0 to 90) to angles (-180 to 180) or (0 to 28) to angles (-42, 42)
82 | pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) * params.binwidth - params.angle
83 | yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) * params.binwidth - params.angle
84 |
85 | pitch_predicted = np.radians(pitch_predicted.cpu())
86 | yaw_predicted = np.radians(yaw_predicted.cpu())
87 |
88 | for p, y, pl, yl in zip(pitch_predicted, yaw_predicted, label_pitch, label_yaw):
89 | average_error += angular_error(gaze_to_3d([p, y]), gaze_to_3d([pl, yl]))
90 |
91 | logging.info(
92 | f"Dataset: {params.dataset} | "
93 | f"Total Number of Samples: {total_samples} | "
94 | f"Mean Angular Error: {average_error/total_samples}"
95 | )
96 |
97 |
98 | def main():
99 | params = parse_args()
100 |
101 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
102 | torch.backends.cudnn.benchmark = True
103 |
104 | model = get_model(params.arch, params.bins, inference_mode=True)
105 |
106 | if os.path.exists(params.weight):
107 | model.load_state_dict(torch.load(params.weight, map_location=device, weights_only=True))
108 | else:
109 | raise ValueError(f"Model weight not found at {params.weight}")
110 |
111 | model.to(device)
112 | test_loader = get_dataloader(params, mode="test")
113 |
114 | idx_tensor = torch.arange(params.bins, device=device, dtype=torch.float32)
115 |
116 | logging.info("Start Evaluation")
117 | evaluate(params, model, test_loader, idx_tensor, device)
118 |
119 |
120 | if __name__ == '__main__':
121 | main()
122 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import logging
3 | import argparse
4 | import warnings
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | from torchvision import transforms
10 |
11 | from config import data_config
12 | from utils.helpers import get_model, draw_bbox_gaze
13 |
14 | import uniface
15 |
16 | warnings.filterwarnings("ignore")
17 | logging.basicConfig(level=logging.INFO, format='%(message)s')
18 |
19 |
20 | def parse_args():
21 | parser = argparse.ArgumentParser(description="Gaze estimation inference")
22 | parser.add_argument("--model", type=str, default="resnet34", help="Model name, default `resnet18`")
23 | parser.add_argument(
24 | "--weight",
25 | type=str,
26 | default="resnet34.pt",
27 | help="Path to gaze esimation model weights"
28 | )
29 | parser.add_argument("--view", action="store_true", default=True, help="Display the inference results")
30 | parser.add_argument("--source", type=str, default="assets/in_video.mp4",
31 | help="Path to source video file or camera index")
32 | parser.add_argument("--output", type=str, default="output.mp4", help="Path to save output file")
33 | parser.add_argument("--dataset", type=str, default="gaze360", help="Dataset name to get dataset related configs")
34 | args = parser.parse_args()
35 |
36 | # Override default values based on selected dataset
37 | if args.dataset in data_config:
38 | dataset_config = data_config[args.dataset]
39 | args.bins = dataset_config["bins"]
40 | args.binwidth = dataset_config["binwidth"]
41 | args.angle = dataset_config["angle"]
42 | else:
43 | raise ValueError(f"Unknown dataset: {args.dataset}. Available options: {list(data_config.keys())}")
44 |
45 | return args
46 |
47 |
48 | def pre_process(image):
49 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
50 | transform = transforms.Compose([
51 | transforms.ToPILImage(),
52 | transforms.Resize(448),
53 | transforms.ToTensor(),
54 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
55 | ])
56 |
57 | image = transform(image)
58 | image_batch = image.unsqueeze(0)
59 | return image_batch
60 |
61 |
62 | def main(params):
63 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
64 |
65 | idx_tensor = torch.arange(params.bins, device=device, dtype=torch.float32)
66 |
67 | face_detector = uniface.RetinaFace() # third-party face detection library
68 |
69 | try:
70 | gaze_detector = get_model(params.model, params.bins, inference_mode=True)
71 | state_dict = torch.load(params.weight, map_location=device)
72 | gaze_detector.load_state_dict(state_dict)
73 | logging.info("Gaze Estimation model weights loaded.")
74 | except Exception as e:
75 | logging.info(f"Exception occured while loading pre-trained weights of gaze estimation model. Exception: {e}")
76 |
77 | gaze_detector.to(device)
78 | gaze_detector.eval()
79 |
80 | video_source = params.source
81 | if video_source.isdigit() or video_source == '0':
82 | cap = cv2.VideoCapture(int(video_source))
83 | else:
84 | cap = cv2.VideoCapture(video_source)
85 |
86 | if params.output:
87 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
88 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
89 | fourcc = cv2.VideoWriter_fourcc(*"mp4v")
90 | out = cv2.VideoWriter(params.output, fourcc, cap.get(cv2.CAP_PROP_FPS), (width, height))
91 |
92 | if not cap.isOpened():
93 | raise IOError("Cannot open webcam")
94 |
95 | with torch.no_grad():
96 | while True:
97 | success, frame = cap.read()
98 |
99 | if not success:
100 | logging.info("Failed to obtain frame or EOF")
101 | break
102 |
103 | bboxes, keypoints = face_detector.detect(frame)
104 | for bbox, keypoint in zip(bboxes, keypoints):
105 | x_min, y_min, x_max, y_max = map(int, bbox[:4])
106 |
107 | image = frame[y_min:y_max, x_min:x_max]
108 | image = pre_process(image)
109 | image = image.to(device)
110 |
111 | pitch, yaw = gaze_detector(image)
112 |
113 | pitch_predicted, yaw_predicted = F.softmax(pitch, dim=1), F.softmax(yaw, dim=1)
114 |
115 | # Mapping from binned (0 to 90) to angles (-180 to 180) or (0 to 28) to angles (-42, 42)
116 | pitch_predicted = torch.sum(pitch_predicted * idx_tensor, dim=1) * params.binwidth - params.angle
117 | yaw_predicted = torch.sum(yaw_predicted * idx_tensor, dim=1) * params.binwidth - params.angle
118 |
119 | # Degrees to Radians
120 | pitch_predicted = np.radians(pitch_predicted.cpu())
121 | yaw_predicted = np.radians(yaw_predicted.cpu())
122 |
123 | # draw box and gaze direction
124 | draw_bbox_gaze(frame, bbox, pitch_predicted, yaw_predicted)
125 |
126 | if params.output:
127 | out.write(frame)
128 |
129 | if params.view:
130 | cv2.imshow('Demo', frame)
131 | if cv2.waitKey(1) & 0xFF == ord('q'):
132 | break
133 |
134 | cap.release()
135 | if params.output:
136 | out.release()
137 | cv2.destroyAllWindows()
138 |
139 |
140 | if __name__ == "__main__":
141 | args = parse_args()
142 |
143 | if not args.view and not args.output:
144 | raise Exception("At least one of --view or --ouput must be provided.")
145 |
146 | main(args)
147 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import logging
4 | import argparse
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.utils.data import DataLoader
10 |
11 | from config import data_config
12 | from utils.helpers import get_model, get_dataloader
13 |
14 | # Setup logging
15 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
16 |
17 |
18 | def parse_args():
19 | """Parse input arguments."""
20 | parser = argparse.ArgumentParser(description="Gaze estimation training")
21 | parser.add_argument("--data", type=str, default="data", help="Directory path for gaze images.")
22 | parser.add_argument("--dataset", type=str, default="gaze360", help="Dataset name, available `gaze360`, `mpiigaze`.")
23 | parser.add_argument("--output", type=str, default="output/", help="Path of output models.")
24 | parser.add_argument("--checkpoint", type=str, default="", help="Path to checkpoint for resuming training.")
25 | parser.add_argument("--num-epochs", type=int, default=100, help="Maximum number of training epochs.")
26 | parser.add_argument("--batch-size", type=int, default=64, help="Batch size.")
27 | parser.add_argument(
28 | "--arch",
29 | type=str,
30 | default="resnet18",
31 | help="Network architecture, currently available: resnet18/34/50, mobilenetv2, mobileone_s0-s4."
32 | )
33 | parser.add_argument("--alpha", type=float, default=1, help="Regression loss coefficient.")
34 | parser.add_argument("--lr", type=float, default=0.00001, help="Base learning rate.")
35 | parser.add_argument("--num-workers", type=int, default=8, help="Number of workers for data loading.")
36 |
37 | args = parser.parse_args()
38 |
39 | # Override default values based on selected dataset
40 | if args.dataset in data_config:
41 | dataset_config = data_config[args.dataset]
42 | args.bins = dataset_config["bins"]
43 | args.binwidth = dataset_config["binwidth"]
44 | args.angle = dataset_config["angle"]
45 | else:
46 | raise ValueError(f"Unknown dataset: {args.dataset}. Available options: {list(data_config.keys())}")
47 |
48 | return args
49 |
50 |
51 | def initialize_model(params, device):
52 | """
53 | Initialize the gaze estimation model, optimizer, and optionally load a checkpoint.
54 |
55 | Args:
56 | params (argparse.Namespace): Parsed command-line arguments.
57 | device (torch.device): Device to load the model and optimizer onto.
58 |
59 | Returns:
60 | Tuple[nn.Module, torch.optim.Optimizer, int]: Initialized model, optimizer, and the starting epoch.
61 | """
62 | model = get_model(params.arch, params.bins, pretrained=True)
63 | optimizer = torch.optim.Adam(model.parameters(), lr=params.lr)
64 | start_epoch = 0
65 |
66 | if params.checkpoint:
67 | checkpoint = torch.load(params.checkpoint, map_location=device)
68 | model.load_state_dict(checkpoint['model_state_dict'])
69 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
70 |
71 | # Move optimizer states to device
72 | for state in optimizer.state.values():
73 | for k, v in state.items():
74 | if isinstance(v, torch.Tensor):
75 | state[k] = v.to(device)
76 |
77 | start_epoch = checkpoint['epoch']
78 | logging.info(f'Resumed training from {params.checkpoint}, starting at epoch {start_epoch + 1}')
79 |
80 | return model.to(device), optimizer, start_epoch
81 |
82 |
83 | def train_one_epoch(
84 | params,
85 | model,
86 | cls_criterion,
87 | reg_criterion,
88 | optimizer,
89 | data_loader,
90 | idx_tensor,
91 | device,
92 | epoch
93 | ):
94 | """
95 | Train the model for one epoch.
96 |
97 | Args:
98 | params (argparse.Namespace): Parsed command-line arguments.
99 | model (nn.Module): The gaze estimation model.
100 | cls_criterion (nn.Module): Loss function for classification.
101 | reg_criterion (nn.Module): Loss function for regression.
102 | optimizer (torch.optim.Optimizer): Optimizer for the model.
103 | data_loader (DataLoader): DataLoader for the training dataset.
104 | idx_tensor (torch.Tensor): Tensor representing bin indices.
105 | device (torch.device): Device to perform training on.
106 | epoch (int): The current epoch number.
107 |
108 | Returns:
109 | Tuple[float, float]: Average losses for pitch and yaw.
110 | """
111 |
112 | model.train()
113 | sum_loss_pitch, sum_loss_yaw = 0, 0
114 |
115 | for idx, (images, labels_gaze, regression_labels_gaze, _) in enumerate(data_loader):
116 | images = images.to(device)
117 |
118 | # Binned labels
119 | label_pitch = labels_gaze[:, 0].to(device)
120 | label_yaw = labels_gaze[:, 1].to(device)
121 |
122 | # Regression labels
123 | label_pitch_regression = regression_labels_gaze[:, 0].to(device)
124 | label_yaw_regression = regression_labels_gaze[:, 1].to(device)
125 |
126 | # Inference
127 | pitch, yaw = model(images)
128 |
129 | # Cross Entropy Loss
130 | loss_pitch = cls_criterion(pitch, label_pitch)
131 | loss_yaw = cls_criterion(yaw, label_yaw)
132 |
133 | # Softmax
134 | pitch, yaw = F.softmax(pitch, dim=1), F.softmax(yaw, dim=1)
135 |
136 |         # Mapping from binned (0 to 90) to angles (-180 to 180) or (0 to 28) to angles (-42, 42)
137 | pitch_predicted = torch.sum(pitch * idx_tensor, 1) * params.binwidth - params.angle
138 | yaw_predicted = torch.sum(yaw * idx_tensor, 1) * params.binwidth - params.angle
139 |
140 | # Mean Squared Error Loss
141 | loss_regression_pitch = reg_criterion(pitch_predicted, label_pitch_regression)
142 | loss_regression_yaw = reg_criterion(yaw_predicted, label_yaw_regression)
143 |
144 | # Calculate loss with regression alpha
145 | loss_pitch += params.alpha * loss_regression_pitch
146 | loss_yaw += params.alpha * loss_regression_yaw
147 |
148 | # Total loss for pitch and yaw
149 | loss = loss_pitch + loss_yaw
150 |
151 | optimizer.zero_grad()
152 | loss.backward()
153 | optimizer.step()
154 |
155 | sum_loss_pitch += loss_pitch.item()
156 | sum_loss_yaw += loss_yaw.item()
157 |
158 | if (idx + 1) % 100 == 0:
159 | logging.info(
160 | f'Epoch [{epoch + 1}/{params.num_epochs}], Iter [{idx + 1}/{len(data_loader)}] '
161 | f'Losses: Gaze Yaw {sum_loss_yaw / (idx + 1):.4f}, Gaze Pitch {sum_loss_pitch / (idx + 1):.4f}'
162 | )
163 | avg_loss_pitch, avg_loss_yaw = sum_loss_pitch / len(data_loader), sum_loss_yaw / len(data_loader)
164 |
165 | return avg_loss_pitch, avg_loss_yaw
166 |
167 |
168 | def main():
169 | params = parse_args()
170 |
171 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
172 | summary_name = f'{params.dataset}_{params.arch}_{int(time.time())}'
173 | output = os.path.join(params.output, summary_name)
174 | if not os.path.exists(output):
175 | os.makedirs(output)
176 | torch.backends.cudnn.benchmark = True
177 |
178 | model, optimizer, start_epoch = initialize_model(params, device)
179 | train_loader = get_dataloader(params, mode="train")
180 |
181 | cls_criterion = nn.CrossEntropyLoss()
182 | reg_criterion = nn.MSELoss()
183 | idx_tensor = torch.arange(params.bins, device=device, dtype=torch.float32)
184 |
185 | best_loss = float('inf')
186 | print(f"Started training from epoch: {start_epoch + 1}")
187 |
188 | for epoch in range(start_epoch, params.num_epochs):
189 | avg_loss_pitch, avg_loss_yaw = train_one_epoch(
190 | params,
191 | model,
192 | cls_criterion,
193 | reg_criterion,
194 | optimizer,
195 | train_loader,
196 | idx_tensor,
197 | device,
198 | epoch
199 | )
200 |
201 | logging.info(
202 | f'Epoch [{epoch + 1}/{params.num_epochs}] '
203 | f'Losses: Gaze Yaw {avg_loss_yaw:.4f}, Gaze Pitch {avg_loss_pitch:.4f}'
204 | )
205 |
206 | checkpoint_path = os.path.join(output, "checkpoint.ckpt")
207 | torch.save({
208 | 'epoch': epoch + 1,
209 | 'model_state_dict': model.state_dict(),
210 | 'optimizer_state_dict': optimizer.state_dict(),
211 | 'loss': avg_loss_pitch + avg_loss_yaw,
212 | }, checkpoint_path)
213 | logging.info(f'Checkpoint saved at {checkpoint_path}')
214 |
215 |         current_loss = avg_loss_pitch + avg_loss_yaw  # already averaged over iterations in train_one_epoch
216 | if current_loss < best_loss:
217 | best_loss = current_loss
218 | best_model_path = os.path.join(output, 'best_model.pt')
219 | torch.save(model.state_dict(), best_model_path)
220 | logging.info(f'Best model saved at {best_model_path}')
221 |
222 |
223 | if __name__ == '__main__':
224 | main()
225 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import resnet18, resnet34, resnet50
2 | from .mobilenet import mobilenet_v2
3 | from .mobileone import mobileone_s0, mobileone_s1, mobileone_s2, mobileone_s3, mobileone_s4, reparameterize_model
4 |
--------------------------------------------------------------------------------
/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torchvision.models import MobileNet_V2_Weights
4 |
5 | from typing import Any, Callable, List, Optional, Tuple
6 |
7 | __all__ = ["mobilenet_v2"]
8 |
9 |
10 | def _make_divisible(v: float, divisor: int = 8) -> int:
11 | """This function ensures that all layers have a channel number divisible by 8"""
12 | new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
13 | # Make sure that round down does not go down by more than 10%.
14 | if new_v < 0.9 * v:
15 | new_v += divisor
16 | return new_v
17 |
18 |
19 | class Conv2dNormActivation(torch.nn.Sequential):
20 | """Convolutional block, consists of nn.Conv2d, nn.BatchNorm2d, nn.ReLU"""
21 |
22 | def __init__(
23 | self,
24 | in_channels: int,
25 | out_channels: int,
26 | kernel_size: int = 3,
27 | stride: int = 1,
28 |         padding: Optional[int] = None,
29 | groups: int = 1,
30 | activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
31 | dilation: int = 1,
32 | inplace: Optional[bool] = True,
33 | bias: bool = False,
34 | ) -> None:
35 |
36 | if padding is None:
37 | padding = (kernel_size - 1) // 2 * dilation
38 |
39 | layers: List[nn.Module] = [
40 | nn.Conv2d(
41 | in_channels=in_channels,
42 | out_channels=out_channels,
43 | kernel_size=kernel_size,
44 | stride=stride,
45 | padding=padding,
46 | dilation=dilation,
47 | groups=groups,
48 | bias=bias,
49 | ),
50 | nn.BatchNorm2d(num_features=out_channels, eps=0.001, momentum=0.01)
51 | ]
52 |
53 | if activation_layer is not None:
54 | params = {} if inplace is None else {"inplace": inplace}
55 | layers.append(activation_layer(**params))
56 | super().__init__(*layers)
57 |
58 |
59 | class InvertedResidual(nn.Module):
60 | def __init__(self, in_planes: int, out_planes: int, stride: int, expand_ratio: int) -> None:
61 | super().__init__()
62 | self.stride = stride
63 | if stride not in [1, 2]:
64 | raise ValueError(f"stride should be 1 or 2 instead of {stride}")
65 |
66 | hidden_dim = int(round(in_planes * expand_ratio))
67 | self.use_res_connect = self.stride == 1 and in_planes == out_planes
68 |
69 | layers: List[nn.Module] = []
70 | if expand_ratio != 1:
71 | # pw
72 | layers.append(
73 | Conv2dNormActivation(
74 | in_planes,
75 | hidden_dim,
76 | kernel_size=1,
77 | activation_layer=nn.ReLU6
78 | )
79 | )
80 | layers.extend(
81 | [
82 | # dw
83 | Conv2dNormActivation(
84 | hidden_dim,
85 | hidden_dim,
86 | stride=stride,
87 | groups=hidden_dim,
88 | activation_layer=nn.ReLU6,
89 | ),
90 | # pw-linear
91 | nn.Conv2d(hidden_dim, out_planes, 1, 1, 0, bias=False),
92 | nn.BatchNorm2d(out_planes),
93 | ]
94 | )
95 | self.conv = nn.Sequential(*layers)
96 | self.out_channels = out_planes
97 | self._is_cn = stride > 1
98 |
99 | def forward(self, x: Tensor) -> Tensor:
100 | if self.use_res_connect:
101 | return x + self.conv(x)
102 | else:
103 | return self.conv(x)
104 |
105 |
106 | class MobileNetV2(nn.Module):
107 | def __init__(
108 | self,
109 | num_classes: int = 1000,
110 | width_mult: float = 1.0,
111 | inverted_residual_setting: Optional[List[List[int]]] = None,
112 | round_nearest: int = 8,
113 | dropout: float = 0.2,
114 | ) -> None:
115 | """
116 | MobileNet V2 main class
117 |
118 | Args:
119 | num_classes (int): Number of classes
120 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
121 | inverted_residual_setting: Network structure
122 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number
123 | Set to 1 to turn off rounding
124 | block: Module specifying inverted residual building block for mobilenet
125 |             dropout (float): The dropout probability
126 |
127 | """
128 | super().__init__()
129 |
130 | input_channel = 32
131 | last_channel = 1280
132 |
133 | if inverted_residual_setting is None:
134 | inverted_residual_setting = [
135 | # t, c, n, s
136 | [1, 16, 1, 1],
137 | [6, 24, 2, 2],
138 | [6, 32, 3, 2],
139 | [6, 64, 4, 2],
140 | [6, 96, 3, 1],
141 | [6, 160, 3, 2],
142 | [6, 320, 1, 1],
143 | ]
144 |
145 | # only check the first element, assuming user knows t,c,n,s are required
146 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
147 | raise ValueError(
148 | f"inverted_residual_setting should be non-empty or a 4-element list, got {inverted_residual_setting}"
149 | )
150 |
151 | # building first layer
152 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
153 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
154 | features: List[nn.Module] = [
155 | Conv2dNormActivation(3, input_channel, stride=2, activation_layer=nn.ReLU6)
156 | ]
157 | # building inverted residual blocks
158 | for t, c, n, s in inverted_residual_setting:
159 | output_channel = _make_divisible(c * width_mult, round_nearest)
160 | for i in range(n):
161 | stride = s if i == 0 else 1
162 | features.append(InvertedResidual(input_channel, output_channel, stride, expand_ratio=t))
163 | input_channel = output_channel
164 | # building last several layers
165 | features.append(
166 | Conv2dNormActivation(
167 | input_channel, self.last_channel, kernel_size=1, activation_layer=nn.ReLU6
168 | )
169 | )
170 | # make it nn.Sequential
171 | self.features = nn.Sequential(*features)
172 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
173 |
174 | # building classifier
175 | # self.classifier = nn.Sequential(
176 | # nn.Dropout(p=dropout),
177 | # nn.Linear(self.last_channel, num_classes),
178 | # )
179 |
180 | self.fc_yaw = nn.Linear(self.last_channel, num_classes)
181 | self.fc_pitch = nn.Linear(self.last_channel, num_classes)
182 |
183 | # weight initialization
184 | for m in self.modules():
185 | if isinstance(m, nn.Conv2d):
186 | nn.init.kaiming_normal_(m.weight, mode="fan_out")
187 | if m.bias is not None:
188 | nn.init.zeros_(m.bias)
189 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
190 | nn.init.ones_(m.weight)
191 | nn.init.zeros_(m.bias)
192 | elif isinstance(m, nn.Linear):
193 | nn.init.normal_(m.weight, 0, 0.01)
194 | nn.init.zeros_(m.bias)
195 |
196 |     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
197 | x = self.features(x)
198 | # Cannot use "squeeze" as batch-size can be 1
199 | x = self.avgpool(x)
200 | x = torch.flatten(x, 1)
201 |
202 | # Original FC layer from MobileNet V2
203 | # x = self.classifier(x)
204 |
205 | yaw = self.fc_yaw(x)
206 | pitch = self.fc_pitch(x)
207 |
208 | return pitch, yaw
209 |
210 |
211 | def load_filtered_state_dict(model, state_dict):
212 | """Update the model's state dictionary with filtered parameters.
213 |
214 | Args:
215 | model: The model instance to update (must have `state_dict` and `load_state_dict` methods).
216 | state_dict: A dictionary of parameters to load into the model.
217 | """
218 | current_model_dict = model.state_dict()
219 | filtered_state_dict = {key: value for key, value in state_dict.items() if key in current_model_dict}
220 | current_model_dict.update(filtered_state_dict)
221 | model.load_state_dict(current_model_dict)
222 |
223 |
224 | def mobilenet_v2(*, pretrained: bool = True, progress: bool = True, **kwargs: Any) -> MobileNetV2:
225 |
226 | if pretrained:
227 | weights = MobileNet_V2_Weights.IMAGENET1K_V1
228 | else:
229 | weights = None
230 |
231 | model = MobileNetV2(**kwargs)
232 |
233 | if weights is not None:
234 | state_dict = weights.get_state_dict(progress=progress, check_hash=True)
235 | load_filtered_state_dict(model, state_dict)
236 |
237 | return model
238 |
--------------------------------------------------------------------------------
/models/mobileone.py:
--------------------------------------------------------------------------------
1 | # Modified by Yakhyokhuja Valikhujaev
2 | # Copyright (C) 2022 Apple Inc. All Rights Reserved.
3 |
4 | import copy
5 | import logging
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from typing import Optional, List, Tuple
12 |
13 | __all__ = [
14 | "mobileone_s0",
15 | "mobileone_s1",
16 | "mobileone_s2",
17 | "mobileone_s3",
18 | "mobileone_s4",
20 | "reparameterize_model"
21 | ]
22 |
23 |
24 | logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)
25 | logger = logging.getLogger()
26 |
27 |
28 | class SqueezeExcitationBlock(nn.Module):
29 | """
30 | Squeeze and Excite module.
31 |
32 | Pytorch implementation of `Squeeze-and-Excitation Networks` -
33 | https://arxiv.org/pdf/1709.01507.pdf
34 | """
35 |
36 | def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
37 | """
38 | Construct a Squeeze and Excite Module.
39 |
40 | Args:
41 | in_channels (int): Number of input channels.
42 | rd_ratio (float): Input channel reduction ratio.
43 | """
44 |
45 | super().__init__()
46 | self.reduce = nn.Conv2d(
47 | in_channels=in_channels,
48 | out_channels=int(in_channels * rd_ratio),
49 | kernel_size=1,
50 | stride=1,
51 | bias=True
52 | )
53 | self.expand = nn.Conv2d(
54 | in_channels=int(in_channels * rd_ratio),
55 | out_channels=in_channels,
56 | kernel_size=1,
57 | stride=1,
58 | bias=True
59 | )
60 |
61 | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
62 | """ Apply forward pass. """
63 | b, c, h, w = inputs.size()
64 | x = F.avg_pool2d(inputs, kernel_size=[h, w])
65 | x = self.reduce(x)
66 | x = F.relu(x)
67 | x = self.expand(x)
68 | x = torch.sigmoid(x)
69 | x = x.view(-1, c, 1, 1)
70 | return inputs * x
71 |
72 |
73 | class MobileOneBlock(nn.Module):
74 | """ MobileOne building block.
75 |
76 | This block has a multi-branched architecture at train-time
77 | and plain-CNN style architecture at inference time
78 | For more details, please refer to our paper:
79 | `An Improved One millisecond Mobile Backbone` -
80 | https://arxiv.org/pdf/2206.04040.pdf
81 | """
82 |
83 | def __init__(
84 | self,
85 | in_channels: int,
86 | out_channels: int,
87 | kernel_size: int,
88 | stride: int = 1,
89 | padding: int = 0,
90 | dilation: int = 1,
91 | groups: int = 1,
92 | inference_mode: bool = False,
93 | use_se: bool = False,
94 | num_conv_branches: int = 1
95 | ) -> None:
96 | """
97 | Construct a MobileOneBlock module.
98 |
99 | Args:
100 | in_channels (int): Number of channels in the input.
101 | out_channels (int): Number of channels produced by the block.
102 | kernel_size (int or tuple): Size of the convolution kernel.
103 | stride (int or tuple): Stride size.
104 | padding (int or tuple): Zero-padding size.
105 | dilation (int or tuple): Kernel dilation factor.
106 | groups (int): Group number.
107 | inference_mode (bool): If True, instantiates model in inference mode.
108 | use_se (bool): Whether to use SE-ReLU activations.
109 | num_conv_branches (int): Number of linear conv branches.
110 | """
111 |
112 | super().__init__()
113 | self.inference_mode = inference_mode
114 | self.groups = groups
115 | self.stride = stride
116 | self.kernel_size = kernel_size
117 | self.in_channels = in_channels
118 | self.out_channels = out_channels
119 | self.num_conv_branches = num_conv_branches
120 |
121 | # Check if SE-ReLU is requested
122 | if use_se:
123 | self.se = SqueezeExcitationBlock(out_channels)
124 | else:
125 | self.se = nn.Identity()
126 | self.activation = nn.ReLU()
127 |
128 | if inference_mode:
129 | self.reparam_conv = nn.Conv2d(
130 | in_channels=in_channels,
131 | out_channels=out_channels,
132 | kernel_size=kernel_size,
133 | stride=stride,
134 | padding=padding,
135 | dilation=dilation,
136 | groups=groups,
137 | bias=True
138 | )
139 | else:
140 | # Re-parameterizable skip connection
141 | self.rbr_skip = nn.BatchNorm2d(num_features=in_channels) \
142 | if out_channels == in_channels and stride == 1 else None
143 |
144 | # Re-parameterizable conv branches
145 | rbr_conv = list()
146 | for _ in range(self.num_conv_branches):
147 | rbr_conv.append(self._conv_bn(kernel_size=kernel_size, padding=padding))
148 | self.rbr_conv = nn.ModuleList(rbr_conv)
149 |
150 | # Re-parameterizable scale branch
151 | self.rbr_scale = None
152 | if kernel_size > 1:
153 | self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)
154 |
155 | def forward(self, x: torch.Tensor) -> torch.Tensor:
156 | """ Apply forward pass. """
157 | # Inference mode forward pass.
158 | if self.inference_mode:
159 | return self.activation(self.se(self.reparam_conv(x)))
160 |
161 | # Multi-branched train-time forward pass.
162 | # Skip branch output
163 | identity_out = 0
164 | if self.rbr_skip is not None:
165 | identity_out = self.rbr_skip(x)
166 |
167 | # Scale branch output
168 | scale_out = 0
169 | if self.rbr_scale is not None:
170 | scale_out = self.rbr_scale(x)
171 |
172 | # Other branches
173 | out = scale_out + identity_out
174 | for ix in range(self.num_conv_branches):
175 | out += self.rbr_conv[ix](x)
176 |
177 | return self.activation(self.se(out))
178 |
179 | def reparameterize(self):
180 | """
181 | Following works like `RepVGG: Making VGG-style ConvNets Great Again` - https://arxiv.org/pdf/2101.03697.pdf.
182 | We re-parameterize multi-branched architecture used at training time to obtain a plain CNN-like structure
183 | for inference.
184 | """
185 | if self.inference_mode:
186 | return
187 | kernel, bias = self._get_kernel_bias()
188 | self.reparam_conv = nn.Conv2d(
189 | in_channels=self.rbr_conv[0].conv.in_channels,
190 | out_channels=self.rbr_conv[0].conv.out_channels,
191 | kernel_size=self.rbr_conv[0].conv.kernel_size,
192 | stride=self.rbr_conv[0].conv.stride,
193 | padding=self.rbr_conv[0].conv.padding,
194 | dilation=self.rbr_conv[0].conv.dilation,
195 | groups=self.rbr_conv[0].conv.groups,
196 | bias=True
197 | )
198 | self.reparam_conv.weight.data = kernel
199 | self.reparam_conv.bias.data = bias
200 |
201 | # Delete un-used branches
202 | for para in self.parameters():
203 | para.detach_()
204 | self.__delattr__('rbr_conv')
205 | self.__delattr__('rbr_scale')
206 | if hasattr(self, 'rbr_skip'):
207 | self.__delattr__('rbr_skip')
208 |
209 | self.inference_mode = True
210 |
211 | def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
212 | """
213 |         Method to obtain the re-parameterized kernel and bias by fusing the
214 |         conv, scale, and skip branches into a single convolution.
215 |         Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
216 |
217 |         Returns:
218 |             tuple: A tuple containing the final fused kernel and bias used to
219 |                    build `reparam_conv` for inference.
220 |
221 |         """
222 | # get weights and bias of scale branch
223 | kernel_scale = 0
224 | bias_scale = 0
225 | if self.rbr_scale is not None:
226 | kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
227 | # Pad scale branch kernel to match conv branch kernel size.
228 | pad = self.kernel_size // 2
229 | kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
230 |
231 | # get weights and bias of skip branch
232 | kernel_identity = 0
233 | bias_identity = 0
234 | if self.rbr_skip is not None:
235 | kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
236 |
237 | # get weights and bias of conv branches
238 | kernel_conv = 0
239 | bias_conv = 0
240 | for ix in range(self.num_conv_branches):
241 | _kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
242 | kernel_conv += _kernel
243 | bias_conv += _bias
244 |
245 | kernel_final = kernel_conv + kernel_scale + kernel_identity
246 | bias_final = bias_conv + bias_scale + bias_identity
247 | return kernel_final, bias_final
248 |
249 | def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
250 | """
251 | Method to fuse batchnorm layer with preceding conv layer.
252 | Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
253 |
254 | Args:
255 | branch: The branch containing the convolutional and batchnorm layers to be fused.
256 |
257 | Returns:
258 | tuple: A tuple containing the kernel and bias after fusing batchnorm.
259 | """
260 | if isinstance(branch, nn.Sequential):
261 | kernel = branch.conv.weight
262 | running_mean = branch.bn.running_mean
263 | running_var = branch.bn.running_var
264 | gamma = branch.bn.weight
265 | beta = branch.bn.bias
266 | eps = branch.bn.eps
267 | else:
268 | assert isinstance(branch, nn.BatchNorm2d)
269 | if not hasattr(self, 'id_tensor'):
270 | input_dim = self.in_channels // self.groups
271 | kernel_value = torch.zeros(
272 | (self.in_channels, input_dim, self.kernel_size, self.kernel_size),
273 | dtype=branch.weight.dtype,
274 | device=branch.weight.device
275 | )
276 | for i in range(self.in_channels):
277 | kernel_value[i, i % input_dim,
278 | self.kernel_size // 2,
279 | self.kernel_size // 2] = 1
280 | self.id_tensor = kernel_value
281 | kernel = self.id_tensor
282 | running_mean = branch.running_mean
283 | running_var = branch.running_var
284 | gamma = branch.weight
285 | beta = branch.bias
286 | eps = branch.eps
287 | std = (running_var + eps).sqrt()
288 | t = (gamma / std).reshape(-1, 1, 1, 1)
289 | return kernel * t, beta - running_mean * gamma / std
290 |
291 | def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
292 | """
293 | Helper method to construct conv-batchnorm layers.
294 |
295 | Args:
296 | kernel_size (int): Size of the convolution kernel.
297 | padding (int): Zero-padding size.
298 |
299 | Returns:
300 | nn.Sequential: A Conv-BN module.
301 | """
302 | mod_list = nn.Sequential()
303 | mod_list.add_module(
304 | 'conv',
305 | nn.Conv2d(
306 | in_channels=self.in_channels,
307 | out_channels=self.out_channels,
308 | kernel_size=kernel_size,
309 | stride=self.stride,
310 | padding=padding,
311 | groups=self.groups,
312 | bias=False
313 | )
314 | )
315 | mod_list.add_module('bn', nn.BatchNorm2d(num_features=self.out_channels))
316 | return mod_list
317 |
318 |
319 | class MobileOne(nn.Module):
320 | """
321 | MobileOne Model
322 |
323 | Pytorch implementation of `An Improved One millisecond Mobile Backbone` -
324 | https://arxiv.org/pdf/2206.04040.pdf
325 | """
326 |
327 | def __init__(
328 | self,
329 | num_blocks_per_stage: List[int] = [2, 8, 10, 1],
330 | num_classes: int = 1000,
331 | width_multipliers: Optional[List[float]] = None,
332 | inference_mode: bool = False,
333 | use_se: bool = False,
334 | num_conv_branches: int = 1
335 | ) -> None:
336 | """
337 | Construct MobileOne model.
338 |
339 | Args:
340 | num_blocks_per_stage (list): List of number of blocks per stage.
341 | num_classes (int): Number of classes in the dataset.
342 | width_multipliers (list): List of width multipliers for blocks in a stage.
343 | inference_mode (bool): If True, instantiates model in inference mode.
344 | use_se (bool): Whether to use SE-ReLU activations.
345 | num_conv_branches (int): Number of linear conv branches.
346 | """
347 | super().__init__()
348 |
349 | assert len(width_multipliers) == 4
350 | self.inference_mode = inference_mode
351 | self.in_planes = min(64, int(64 * width_multipliers[0]))
352 | self.use_se = use_se
353 | self.num_conv_branches = num_conv_branches
354 |
355 | # Build stages
356 | self.stage0 = MobileOneBlock(
357 | in_channels=3,
358 | out_channels=self.in_planes,
359 | kernel_size=3,
360 | stride=2,
361 | padding=1,
362 | inference_mode=self.inference_mode
363 | )
364 | self.cur_layer_idx = 1
365 | self.stage1 = self._make_stage(int(64 * width_multipliers[0]), num_blocks_per_stage[0], num_se_blocks=0)
366 | self.stage2 = self._make_stage(int(128 * width_multipliers[1]), num_blocks_per_stage[1], num_se_blocks=0)
367 | self.stage3 = self._make_stage(
368 | int(256 * width_multipliers[2]),
369 | num_blocks_per_stage[2],
370 | num_se_blocks=int(num_blocks_per_stage[2] // 2) if use_se else 0
371 | )
372 | self.stage4 = self._make_stage(
373 | int(512 * width_multipliers[3]),
374 | num_blocks_per_stage[3],
375 | num_se_blocks=num_blocks_per_stage[3] if use_se else 0
376 | )
377 | self.gap = nn.AdaptiveAvgPool2d(output_size=1)
378 |
379 | # yaw and pitch
380 | self.fc_yaw = nn.Linear(int(512 * width_multipliers[3]), num_classes)
381 | self.fc_pitch = nn.Linear(int(512 * width_multipliers[3]), num_classes)
382 |
383 | # self.linear = nn.Linear(int(512 * width_multipliers[3]), num_classes)
384 |
385 | def _make_stage(self, planes: int, num_blocks: int, num_se_blocks: int) -> nn.Sequential:
386 | """
387 | Build a stage of the MobileOne model.
388 |
389 | Args:
390 | planes (int): Number of output channels.
391 | num_blocks (int): Number of blocks in this stage.
392 | num_se_blocks (int): Number of SE blocks in this stage.
393 |
394 | Returns:
395 | nn.Sequential: A stage of the MobileOne model.
396 | """
397 |
398 | # Get strides for all layers
399 | strides = [2] + [1]*(num_blocks-1)
400 | blocks = []
401 | for ix, stride in enumerate(strides):
402 | use_se = False
403 | if num_se_blocks > num_blocks:
404 | raise ValueError("Number of SE blocks cannot exceed number of layers.")
405 | if ix >= (num_blocks - num_se_blocks):
406 | use_se = True
407 |
408 | # Depthwise conv
409 | blocks.append(
410 | MobileOneBlock(
411 | in_channels=self.in_planes,
412 | out_channels=self.in_planes,
413 | kernel_size=3,
414 | stride=stride,
415 | padding=1,
416 | groups=self.in_planes,
417 | inference_mode=self.inference_mode,
418 | use_se=use_se,
419 | num_conv_branches=self.num_conv_branches
420 | )
421 | )
422 | # Pointwise conv
423 | blocks.append(
424 | MobileOneBlock(
425 | in_channels=self.in_planes,
426 | out_channels=planes,
427 | kernel_size=1,
428 | stride=1,
429 | padding=0,
430 | groups=1,
431 | inference_mode=self.inference_mode,
432 | use_se=use_se,
433 | num_conv_branches=self.num_conv_branches
434 | )
435 | )
436 | self.in_planes = planes
437 | self.cur_layer_idx += 1
438 | return nn.Sequential(*blocks)
439 |
440 |     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
441 |         """Apply forward pass and return (pitch, yaw) logits."""
442 | x = self.stage0(x)
443 | x = self.stage1(x)
444 | x = self.stage2(x)
445 | x = self.stage3(x)
446 | x = self.stage4(x)
447 | x = self.gap(x)
448 | x = x.view(x.size(0), -1)
449 | # x = self.linear(x)
450 |
451 | yaw = self.fc_yaw(x)
452 | pitch = self.fc_pitch(x)
453 |
454 | return pitch, yaw
455 |
456 |
457 | def reparameterize_model(model: torch.nn.Module) -> nn.Module:
458 | """
459 | Re-parameterize the MobileOne model from a multi-branched structure (used in training)
460 | into a single branch for inference.
461 |
462 | Args:
463 | model (nn.Module): MobileOne model in training mode.
464 |
465 | Returns:
466 | nn.Module: MobileOne model re-parameterized for inference mode.
467 | """
468 |
469 | # Avoid editing original graph
470 | model = copy.deepcopy(model)
471 | for module in model.modules():
472 | if hasattr(module, 'reparameterize'):
473 | module.reparameterize()
474 | return model
475 |
476 |
477 | MOBILEONE_CONFIGS = {
478 | "mobileone_s0":
479 | {
480 | "weights": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s0_unfused.pth.tar",
481 | "params":
482 | {
483 | "width_multipliers": (0.75, 1.0, 1.0, 2.0),
484 | "num_conv_branches": 4
485 | }
486 | },
487 | "mobileone_s1":
488 | {
489 | "weights": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s1_unfused.pth.tar",
490 | "params":
491 | {
492 | "width_multipliers": (1.5, 1.5, 2.0, 2.5),
493 | }
494 |
495 | },
496 | "mobileone_s2":
497 | {
498 | "weights": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s2_unfused.pth.tar",
499 | "params":
500 | {
501 | "width_multipliers": (1.5, 2.0, 2.5, 4.0),
502 | }
503 | },
504 | "mobileone_s3":
505 | {
506 | "weights": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s3_unfused.pth.tar",
507 | "params":
508 | {
509 | "width_multipliers": (2.0, 2.5, 3.0, 4.0),
510 | }
511 | },
512 | "mobileone_s4":
513 | {
514 | "weights": "https://docs-assets.developer.apple.com/ml-research/datasets/mobileone/mobileone_s4_unfused.pth.tar",
515 | "params":
516 | {
517 | "width_multipliers": (3.0, 3.5, 3.5, 4.0),
518 | "use_se": True
519 | }
520 | }
521 | }
522 |
523 |
524 | def load_filtered_state_dict(model, state_dict):
525 | """Update the model's state dictionary with filtered parameters.
526 |
527 | Args:
528 | model: The model instance to update (must have `state_dict` and `load_state_dict` methods).
529 | state_dict: A dictionary of parameters to load into the model.
530 | """
531 | current_model_dict = model.state_dict()
532 | filtered_state_dict = {key: value for key, value in state_dict.items() if key in current_model_dict}
533 | current_model_dict.update(filtered_state_dict)
534 | model.load_state_dict(current_model_dict)
535 |
536 |
537 | def create_mobileone_model(config, pretrained: bool = True, num_classes: int = 1000, inference_mode: bool = False) -> nn.Module:
538 | """
539 | Create a MobileOne model based on the specified architecture name.
540 |
541 | Args:
542 | config (dict): The configuration dictionary for the MobileOne model.
543 | pretrained (bool): If True, loads pre-trained weights for the specified architecture. Defaults to True.
544 | num_classes (int): Number of output classes for the model. Defaults to 1000.
545 | inference_mode (bool): If True, instantiates the model in inference mode. Defaults to False.
546 |
547 | Returns:
548 | nn.Module: The constructed MobileOne model.
549 | """
550 | weights = config["weights"]
551 | params = config["params"]
552 |
553 | model = MobileOne(num_classes=num_classes, inference_mode=inference_mode, **params)
554 |
555 | if pretrained:
556 | try:
557 | state_dict = torch.hub.load_state_dict_from_url(weights)
558 | load_filtered_state_dict(model, state_dict)
559 | logger.info("Pre-trained weights successfully loaded.")
560 | except Exception as e:
561 | logger.warning(f"Could not load pre-trained weights. Exception: {e}")
562 | logger.info("Creating model without pre-trained weights.")
563 | else:
564 | logger.info("Creating model without pre-trained weights.")
565 |
566 | return model
567 |
568 |
569 | def mobileone_s0(pretrained=True, num_classes=1000, inference_mode=False):
570 | return create_mobileone_model(MOBILEONE_CONFIGS['mobileone_s0'], pretrained, num_classes, inference_mode)
571 |
572 |
573 | def mobileone_s1(pretrained=True, num_classes=1000, inference_mode=False):
574 | return create_mobileone_model(MOBILEONE_CONFIGS['mobileone_s1'], pretrained, num_classes, inference_mode)
575 |
576 |
577 | def mobileone_s2(pretrained=True, num_classes=1000, inference_mode=False):
578 | return create_mobileone_model(MOBILEONE_CONFIGS['mobileone_s2'], pretrained, num_classes, inference_mode)
579 |
580 |
581 | def mobileone_s3(pretrained=True, num_classes=1000, inference_mode=False):
582 | return create_mobileone_model(MOBILEONE_CONFIGS['mobileone_s3'], pretrained, num_classes, inference_mode)
583 |
584 |
585 | def mobileone_s4(pretrained=True, num_classes=1000, inference_mode=False):
586 | return create_mobileone_model(MOBILEONE_CONFIGS['mobileone_s4'], pretrained, num_classes, inference_mode)
587 |
588 |
589 | if __name__ == "__main__":
590 | model = mobileone_s2()
591 | print(sum(p.numel() for p in model.parameters() if p.requires_grad))
592 |
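
The re-parameterization above can be sanity-checked end to end. A minimal sketch (it assumes the package-level exports used in reparameterize.py and an arbitrary 90-bin gaze head; the tolerance is illustrative):

    import torch
    from models import mobileone_s0, reparameterize_model

    # build a training-time (multi-branch) model with a 90-bin gaze head
    model = mobileone_s0(pretrained=False, num_classes=90)
    model.eval()

    # fuse the branches into a single conv per block
    fused = reparameterize_model(model)

    # outputs should match up to floating-point error
    x = torch.randn(1, 3, 448, 448)
    with torch.no_grad():
        pitch, yaw = model(x)
        pitch_fused, yaw_fused = fused(x)
    print(torch.allclose(pitch, pitch_fused, atol=1e-4),
          torch.allclose(yaw, yaw_fused, atol=1e-4))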
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | from torchvision.models import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights
4 |
5 | from typing import Any, Callable, List, Optional, Type, Tuple
6 |
7 |
8 | __all__ = ["resnet18", "resnet34", "resnet50"]
9 |
10 |
11 | def conv3x3(in_channels: int, out_channels: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
12 | """3x3 convolution with padding"""
13 | return nn.Conv2d(
14 | in_channels,
15 | out_channels,
16 | kernel_size=3,
17 | stride=stride,
18 | padding=dilation,
19 | groups=groups,
20 | bias=False,
21 | dilation=dilation,
22 | )
23 |
24 |
25 | def conv1x1(in_channels: int, out_channels: int, stride: int = 1) -> nn.Conv2d:
26 | """1x1 convolution"""
27 | return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
28 |
29 |
30 | class BasicBlock(nn.Module):
31 | expansion: int = 1
32 |
33 | def __init__(
34 | self,
35 | in_channels: int,
36 | out_channels: int,
37 | stride: int = 1,
38 | downsample: Optional[nn.Module] = None,
39 | groups: int = 1,
40 | base_width: int = 64,
41 | dilation: int = 1,
42 | norm_layer: Optional[Callable[..., nn.Module]] = None,
43 | ) -> None:
44 | super().__init__()
45 | if norm_layer is None:
46 | norm_layer = nn.BatchNorm2d
47 | if groups != 1 or base_width != 64:
48 | raise ValueError("BasicBlock only supports groups=1 and base_width=64")
49 | if dilation > 1:
50 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
51 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
52 | self.conv1 = conv3x3(in_channels, out_channels, stride)
53 | self.bn1 = norm_layer(out_channels)
54 | self.relu = nn.ReLU(inplace=True)
55 | self.conv2 = conv3x3(out_channels, out_channels)
56 | self.bn2 = norm_layer(out_channels)
57 | self.downsample = downsample
58 | self.stride = stride
59 |
60 | def forward(self, x: Tensor) -> Tensor:
61 | identity = x
62 |
63 | out = self.conv1(x)
64 | out = self.bn1(out)
65 | out = self.relu(out)
66 |
67 | out = self.conv2(out)
68 | out = self.bn2(out)
69 |
70 | if self.downsample is not None:
71 | identity = self.downsample(x)
72 |
73 | out += identity
74 | out = self.relu(out)
75 |
76 | return out
77 |
78 |
79 | class Bottleneck(nn.Module):
80 | expansion: int = 4
81 |
82 | def __init__(
83 | self,
84 | inplanes: int,
85 | planes: int,
86 | stride: int = 1,
87 | downsample: Optional[nn.Module] = None,
88 | groups: int = 1,
89 | base_width: int = 64,
90 | dilation: int = 1,
91 | norm_layer: Optional[Callable[..., nn.Module]] = None,
92 | ) -> None:
93 | super().__init__()
94 | if norm_layer is None:
95 | norm_layer = nn.BatchNorm2d
96 | width = int(planes * (base_width / 64.0)) * groups
97 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
98 | self.conv1 = conv1x1(inplanes, width)
99 | self.bn1 = norm_layer(width)
100 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
101 | self.bn2 = norm_layer(width)
102 | self.conv3 = conv1x1(width, planes * self.expansion)
103 | self.bn3 = norm_layer(planes * self.expansion)
104 | self.relu = nn.ReLU(inplace=True)
105 | self.downsample = downsample
106 | self.stride = stride
107 |
108 | def forward(self, x: Tensor) -> Tensor:
109 | identity = x
110 |
111 | out = self.conv1(x)
112 | out = self.bn1(out)
113 | out = self.relu(out)
114 |
115 | out = self.conv2(out)
116 | out = self.bn2(out)
117 | out = self.relu(out)
118 |
119 | out = self.conv3(out)
120 | out = self.bn3(out)
121 |
122 | if self.downsample is not None:
123 | identity = self.downsample(x)
124 |
125 | out += identity
126 | out = self.relu(out)
127 |
128 | return out
129 |
130 |
131 | class ResNet(nn.Module):
132 | def __init__(
133 | self,
134 | block: Type[BasicBlock | Bottleneck],
135 | layers: List[int],
136 | num_classes: int = 1000,
137 | groups: int = 1,
138 | width_per_group: int = 64,
139 | replace_stride_with_dilation: Optional[List[bool]] = None,
140 | norm_layer: Optional[Callable[..., nn.Module]] = None,
141 | ) -> None:
142 | super().__init__()
143 | if norm_layer is None:
144 | norm_layer = nn.BatchNorm2d
145 | self._norm_layer = norm_layer
146 |
147 | self.in_channels = 64
148 | self.dilation = 1
149 | if replace_stride_with_dilation is None:
150 | # each element in the tuple indicates if we should replace
151 | # the 2x2 stride with a dilated convolution instead
152 | replace_stride_with_dilation = [False, False, False]
153 | if len(replace_stride_with_dilation) != 3:
154 | raise ValueError(
155 | "replace_stride_with_dilation should be None "
156 | f"or a 3-element tuple, got {replace_stride_with_dilation}"
157 | )
158 | self.groups = groups
159 | self.base_width = width_per_group
160 | self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False)
161 | self.bn1 = norm_layer(self.in_channels)
162 | self.relu = nn.ReLU(inplace=True)
163 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
164 | self.layer1 = self._make_layer(block, 64, layers[0])
165 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
166 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
167 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
168 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
169 |
170 | # yaw and pitch
171 | self.fc_yaw = nn.Linear(512 * block.expansion, num_classes)
172 | self.fc_pitch = nn.Linear(512 * block.expansion, num_classes)
173 |
174 | # Original FC Layer for ResNet
175 | # self.fc = nn.Linear(512 * block.expansion, num_classes)
176 |
177 | for m in self.modules():
178 | if isinstance(m, nn.Conv2d):
179 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
180 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
181 | nn.init.constant_(m.weight, 1)
182 | nn.init.constant_(m.bias, 0)
183 |
184 | def _make_layer(
185 | self,
186 | block: Type[BasicBlock | Bottleneck],
187 | planes: int,
188 | blocks: int,
189 | stride: int = 1,
190 | dilate: bool = False,
191 | ) -> nn.Sequential:
192 | norm_layer = self._norm_layer
193 | downsample = None
194 | previous_dilation = self.dilation
195 | if dilate:
196 | self.dilation *= stride
197 | stride = 1
198 | if stride != 1 or self.in_channels != planes * block.expansion:
199 | downsample = nn.Sequential(
200 | conv1x1(self.in_channels, planes * block.expansion, stride),
201 | norm_layer(planes * block.expansion),
202 | )
203 |
204 | layers = []
205 | layers.append(
206 | block(
207 | self.in_channels,
208 | planes,
209 | stride,
210 | downsample,
211 | self.groups,
212 | self.base_width,
213 | previous_dilation,
214 | norm_layer
215 | )
216 | )
217 | self.in_channels = planes * block.expansion
218 | for _ in range(1, blocks):
219 | layers.append(
220 | block(
221 | self.in_channels,
222 | planes,
223 | groups=self.groups,
224 | base_width=self.base_width,
225 | dilation=self.dilation,
226 | norm_layer=norm_layer,
227 | )
228 | )
229 |
230 | return nn.Sequential(*layers)
231 |
232 |     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
233 | x = self.conv1(x)
234 | x = self.bn1(x)
235 | x = self.relu(x)
236 | x = self.maxpool(x)
237 |
238 | x = self.layer1(x) # 1/4
239 | x = self.layer2(x) # 1/8
240 | x = self.layer3(x) # 1/16
241 | x = self.layer4(x) # 1/32
242 |
243 | x = self.avgpool(x)
244 | x = torch.flatten(x, 1)
245 |
246 | # Original FC Layer for ResNet
247 | # x = self.fc(x)
248 | yaw = self.fc_yaw(x)
249 | pitch = self.fc_pitch(x)
250 |
251 | return pitch, yaw
252 |
253 |
254 | def load_filtered_state_dict(model, state_dict):
255 | """Update the model's state dictionary with filtered parameters.
256 |
257 | Args:
258 | model: The model instance to update (must have `state_dict` and `load_state_dict` methods).
259 | state_dict: A dictionary of parameters to load into the model.
260 | """
261 | current_model_dict = model.state_dict()
262 | filtered_state_dict = {key: value for key, value in state_dict.items() if key in current_model_dict}
263 | current_model_dict.update(filtered_state_dict)
264 | model.load_state_dict(current_model_dict)
265 |
266 |
267 | def _resnet(block: Type[BasicBlock | Bottleneck], layers: List[int], weights: Optional[ResNet18_Weights | ResNet34_Weights | ResNet50_Weights], progress: bool, **kwargs: Any) -> ResNet:
268 | model = ResNet(block, layers, **kwargs)
269 |
270 | if weights is not None:
271 | state_dict = weights.get_state_dict(progress=progress, check_hash=True)
272 | load_filtered_state_dict(model, state_dict)
273 |
274 | return model
275 |
276 |
277 | def resnet18(*, pretrained: bool = True, progress: bool = True, **kwargs: Any) -> ResNet:
278 | if pretrained:
279 | weights = ResNet18_Weights.DEFAULT
280 | else:
281 | weights = None
282 | return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
283 |
284 |
285 | def resnet34(*, pretrained: bool = True, progress: bool = True, **kwargs: Any) -> ResNet:
286 | if pretrained:
287 | weights = ResNet34_Weights.DEFAULT
288 | else:
289 | weights = None
290 | return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)
291 |
292 |
293 | def resnet50(*, pretrained: bool = True, progress: bool = True, **kwargs: Any) -> ResNet:
294 | if pretrained:
295 | weights = ResNet50_Weights.DEFAULT
296 | else:
297 | weights = None
298 |
299 | return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
300 |
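
For reference, a short usage sketch of the gaze-oriented ResNet factories above (90 bins is an illustrative choice; the two heads return per-bin logits for pitch and yaw):

    import torch
    from models import resnet18

    model = resnet18(pretrained=False, num_classes=90)  # 90 angle bins
    model.eval()

    with torch.no_grad():
        pitch_logits, yaw_logits = model(torch.randn(2, 3, 448, 448))
    print(pitch_logits.shape, yaw_logits.shape)  # torch.Size([2, 90]) each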
--------------------------------------------------------------------------------
/mpii_train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import logging
5 | import argparse
6 | import numpy as np
7 | from tqdm import tqdm
8 | from sklearn.model_selection import KFold
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.utils.data import DataLoader
14 |
15 | from config import data_config
16 | from utils.helpers import angular_error, gaze_to_3d, get_model, get_dataloader
17 |
18 | # Setup logging
19 | logging.basicConfig(
20 | level=logging.INFO,
21 | format='%(asctime)s - %(message)s',
22 | # handlers=[
23 | # logging.FileHandler("training.log"),
24 | # logging.StreamHandler(sys.stdout) # Display logs in terminal
25 | # ]
26 | )
27 |
28 |
29 | def parse_args():
30 | """Parse input arguments."""
31 | parser = argparse.ArgumentParser(description="Gaze estimation training")
32 | parser.add_argument("--data", type=str, default="data", help="Directory path for gaze images.")
33 | parser.add_argument("--dataset", type=str, default="gaze360", help="Dataset name, available `gaze360`, `mpiigaze`.")
34 | parser.add_argument("--output", type=str, default="output/", help="Path of output models.")
35 | parser.add_argument("--checkpoint", type=str, default="", help="Path to checkpoint for resuming training.")
36 | parser.add_argument("--num-epochs", type=int, default=100, help="Maximum number of training epochs.")
37 | parser.add_argument("--batch-size", type=int, default=64, help="Batch size.")
38 | parser.add_argument(
39 | "--arch",
40 | type=str,
41 | default="resnet18",
42 | help="Network architecture, currently available: resnet18/34/50, mobilenetv2, mobileone_s0-s4."
43 | )
44 | parser.add_argument("--alpha", type=float, default=1, help="Regression loss coefficient.")
45 | parser.add_argument("--lr", type=float, default=0.00001, help="Base learning rate.")
46 | parser.add_argument("--num-workers", type=int, default=8, help="Number of workers for data loading.")
47 |
48 | args = parser.parse_args()
49 |
50 | # Override default values based on selected dataset
51 | if args.dataset in data_config:
52 | dataset_config = data_config[args.dataset]
53 | args.bins = dataset_config["bins"]
54 | args.binwidth = dataset_config["binwidth"]
55 | args.angle = dataset_config["angle"]
56 | else:
57 | raise ValueError(f"Unknown dataset: {args.dataset}. Available options: {list(data_config.keys())}")
58 |
59 | return args
60 |
61 |
62 | def initialize_model(params, device):
63 | """
64 | Initialize the gaze estimation model, optimizer, and optionally load a checkpoint.
65 |
66 | Args:
67 | params (argparse.Namespace): Parsed command-line arguments.
68 | device (torch.device): Device to load the model and optimizer onto.
69 |
70 | Returns:
71 | Tuple[nn.Module, torch.optim.Optimizer, int]: Initialized model, optimizer, and the starting epoch.
72 | """
73 | model = get_model(params.arch, params.bins)
74 | optimizer = torch.optim.Adam(model.parameters(), lr=params.lr)
75 | start_epoch = 0
76 |
77 | if params.checkpoint:
78 | checkpoint = torch.load(params.checkpoint, map_location=device)
79 | model.load_state_dict(checkpoint['model_state_dict'])
80 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
81 |
82 | # Move optimizer states to device
83 | for state in optimizer.state.values():
84 | for k, v in state.items():
85 | if isinstance(v, torch.Tensor):
86 | state[k] = v.to(device)
87 |
88 | start_epoch = checkpoint['epoch']
89 | logging.info(f'Resumed training from {params.checkpoint}, starting at epoch {start_epoch + 1}')
90 |
91 | return model.to(device), optimizer, start_epoch
92 |
93 |
94 | def train_one_epoch(
95 | params,
96 | model,
97 | cls_criterion,
98 | reg_criterion,
99 | optimizer,
100 | data_loader,
101 | idx_tensor,
102 | device,
103 | epoch
104 | ):
105 | """
106 | Train the model for one epoch.
107 |
108 | Args:
109 | params (argparse.Namespace): Parsed command-line arguments.
110 | model (nn.Module): The gaze estimation model.
111 | cls_criterion (nn.Module): Loss function for classification.
112 | reg_criterion (nn.Module): Loss function for regression.
113 | optimizer (torch.optim.Optimizer): Optimizer for the model.
114 | data_loader (DataLoader): DataLoader for the training dataset.
115 | idx_tensor (torch.Tensor): Tensor representing bin indices.
116 | device (torch.device): Device to perform training on.
117 | epoch (int): The current epoch number.
118 |
119 | Returns:
120 | Tuple[float, float]: Average losses for pitch and yaw.
121 | """
122 |
123 | model.train()
124 | sum_loss_pitch, sum_loss_yaw = 0, 0
125 |
126 | for idx, (images, labels_gaze, regression_labels_gaze, _) in enumerate(data_loader):
127 | images = images.to(device)
128 |
129 | # Binned labels
130 | label_pitch = labels_gaze[:, 0].to(device)
131 | label_yaw = labels_gaze[:, 1].to(device)
132 |
133 | # Regression labels
134 | label_pitch_regression = regression_labels_gaze[:, 0].to(device)
135 | label_yaw_regression = regression_labels_gaze[:, 1].to(device)
136 |
137 | # Inference
138 | pitch, yaw = model(images)
139 |
140 | # Cross Entropy Loss
141 | loss_pitch = cls_criterion(pitch, label_pitch)
142 | loss_yaw = cls_criterion(yaw, label_yaw)
143 |
144 |         # Mapping from binned (0 to 90) to angles (-180 to 180)
145 | pitch_predicted = torch.sum(F.softmax(pitch, dim=1) * idx_tensor, 1) * params.binwidth - params.angle
146 | yaw_predicted = torch.sum(F.softmax(yaw, dim=1) * idx_tensor, 1) * params.binwidth - params.angle
147 |
148 | # Mean Squared Error Loss
149 | loss_regression_pitch = reg_criterion(pitch_predicted, label_pitch_regression)
150 | loss_regression_yaw = reg_criterion(yaw_predicted, label_yaw_regression)
151 |
152 | # Calculate loss with regression alpha
153 | loss_pitch += params.alpha * loss_regression_pitch
154 | loss_yaw += params.alpha * loss_regression_yaw
155 |
156 | # Total loss for pitch and yaw
157 | loss = loss_pitch + loss_yaw
158 |
159 | optimizer.zero_grad()
160 | loss.backward()
161 | optimizer.step()
162 |
163 | sum_loss_pitch += loss_pitch.item()
164 | sum_loss_yaw += loss_yaw.item()
165 |
166 | if (idx + 1) % 100 == 0:
167 | logging.info(
168 | f'Epoch [{epoch + 1}/{params.num_epochs}], Iter [{idx + 1}/{len(data_loader)}] '
169 | f'Losses: Gaze Yaw {sum_loss_yaw / (idx + 1):.4f}, Gaze Pitch {sum_loss_pitch / (idx + 1):.4f}'
170 | )
171 | avg_loss_pitch, avg_loss_yaw = sum_loss_pitch / len(data_loader), sum_loss_yaw / len(data_loader)
172 |
173 | return avg_loss_pitch, avg_loss_yaw
174 |
175 |
176 | @torch.no_grad()
177 | def evaluate(params, model, data_loader, idx_tensor, device):
178 | """
179 | Evaluate the model on the test dataset.
180 |
181 | Args:
182 | params (argparse.Namespace): Parsed command-line arguments.
183 | model (nn.Module): The gaze estimation model.
184 | data_loader (torch.utils.data.DataLoader): DataLoader for the test dataset.
185 | idx_tensor (torch.Tensor): Tensor representing bin indices.
186 | device (torch.device): Device to perform evaluation on.
187 | """
188 | model.eval()
189 | average_error = 0
190 | total_samples = 0
191 |
192 | for images, labels_gaze, regression_labels_gaze, _ in tqdm(data_loader, total=len(data_loader)):
193 | total_samples += regression_labels_gaze.size(0)
194 | images = images.to(device)
195 |
196 | # Regression labels
197 | label_pitch = np.radians(regression_labels_gaze[:, 0], dtype=np.float32)
198 | label_yaw = np.radians(regression_labels_gaze[:, 1], dtype=np.float32)
199 |
200 | # Inference
201 | pitch, yaw = model(images)
202 |
203 | # Regression predictions
204 | pitch_predicted = F.softmax(pitch, dim=1)
205 | yaw_predicted = F.softmax(yaw, dim=1)
206 |
207 | # Mapping from binned (0 to 90) to angles (-180 to 180) or (0 to 28) to angles (-42, 42)
208 | pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) * params.binwidth - params.angle
209 | yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) * params.binwidth - params.angle
210 |
211 | pitch_predicted = np.radians(pitch_predicted.cpu())
212 | yaw_predicted = np.radians(yaw_predicted.cpu())
213 |
214 | for p, y, pl, yl in zip(pitch_predicted, yaw_predicted, label_pitch, label_yaw):
215 | average_error += angular_error(gaze_to_3d([p, y]), gaze_to_3d([pl, yl]))
216 |
217 | logging.info(
218 | f"Dataset: {params.dataset} | "
219 | f"Total Number of Samples: {total_samples} | "
220 | f"Mean Angular Error: {average_error/total_samples}"
221 | )
222 | return average_error/total_samples
223 |
224 |
225 | def main():
226 | params = parse_args()
227 |
228 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
229 | summary_name = f'{params.dataset}_{params.arch}_{int(time.time())}'
230 | output = os.path.join(params.output, summary_name)
231 | if not os.path.exists(output):
232 | os.makedirs(output)
233 | torch.backends.cudnn.benchmark = True
234 |
235 | model, optimizer, start_epoch = initialize_model(params, device)
236 | data_loader = get_dataloader(params, mode="train")
237 | dataset = data_loader.dataset
238 |
239 | cls_criterion = nn.CrossEntropyLoss()
240 | reg_criterion = nn.MSELoss()
241 | idx_tensor = torch.arange(params.bins, device=device, dtype=torch.float32)
242 |
243 | best_avg_error = float('inf')
244 | k = 5 # number of folds
245 | kfold = KFold(n_splits=k, shuffle=True, random_state=42)
246 |
247 | fold_errors = []
248 | # K-Fold Cross Validation
249 | for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
250 | print(f"Fold {fold+1}/{k}")
251 |
252 | # Split data into training and validation sets for this fold
253 | train_subset = torch.utils.data.Subset(dataset, train_idx)
254 | val_subset = torch.utils.data.Subset(dataset, val_idx)
255 |
256 | # Create data loaders for the subsets
257 | train_loader = torch.utils.data.DataLoader(train_subset, batch_size=params.batch_size, shuffle=True)
258 | val_loader = torch.utils.data.DataLoader(val_subset, batch_size=params.batch_size, shuffle=False)
259 |
260 | # Reset model and optimizer for each fold
261 | model, optimizer, start_epoch = initialize_model(params, device)
262 |
263 | for epoch in range(start_epoch, params.num_epochs):
264 | avg_loss_pitch, avg_loss_yaw = train_one_epoch(
265 | params,
266 | model,
267 | cls_criterion,
268 | reg_criterion,
269 | optimizer,
270 | train_loader,
271 | idx_tensor,
272 | device,
273 | epoch
274 | )
275 |
276 | logging.info(
277 | f'Epoch [{epoch + 1}/{params.num_epochs}] '
278 | f'Losses: Gaze Yaw {avg_loss_yaw:.4f}, Gaze Pitch {avg_loss_pitch:.4f}'
279 | )
280 |
281 | # checkpoint_path = os.path.join(output, f"checkpoint_fold_{fold+1}.ckpt")
282 | # torch.save({
283 | # 'epoch': epoch + 1,
284 | # 'model_state_dict': model.state_dict(),
285 | # 'optimizer_state_dict': optimizer.state_dict(),
286 | # 'loss': avg_loss_pitch + avg_loss_yaw,
287 | # }, checkpoint_path)
288 | # logging.info(f'Checkpoint saved at {checkpoint_path}')
289 |
290 | # Evaluate on validation set for the current fold
291 | avg_error = evaluate(params, model, val_loader, idx_tensor, device) # Returns average error
292 | fold_errors.append(avg_error)
293 |
294 | logging.info(f'Fold {fold+1} average error: {avg_error:.4f}')
295 |
296 | # Save the best model for the fold
297 | if avg_error < best_avg_error:
298 | best_avg_error = avg_error
299 |             best_model_path = os.path.join(output, 'best_model.pt')
300 | torch.save(model.state_dict(), best_model_path)
301 | logging.info(f'Best model saved for fold {fold+1} at {best_model_path}')
302 |
303 | # Calculate average error across all folds
304 | avg_error_overall = np.mean(fold_errors)
305 | logging.info(f'Average error across {k} folds: {avg_error_overall:.4f}')
306 |
307 |
308 | if __name__ == '__main__':
309 | main()
310 |
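
The soft-argmax decoding used in train_one_epoch and evaluate can be checked in isolation. A worked sketch with an MPIIGaze-style setup (28 bins of 3 degrees and a 42-degree offset, matching the (-42, 42) range noted in the comments; the exact values live in config.py and are assumed here):

    import torch
    import torch.nn.functional as F

    bins, binwidth, angle = 28, 3, 42
    idx_tensor = torch.arange(bins, dtype=torch.float32)

    logits = torch.randn(1, bins)                 # raw pitch (or yaw) logits
    probs = F.softmax(logits, dim=1)
    pred_deg = torch.sum(probs * idx_tensor, dim=1) * binwidth - angle
    print(pred_deg)                               # continuous angle in degrees, roughly in (-42, 42)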
--------------------------------------------------------------------------------
/onnx_export.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | from config import data_config
5 | from utils.helpers import get_model
6 |
7 |
8 | def parse_arguments():
9 | parser = argparse.ArgumentParser(description='Gaze Estimation Model ONNX Export')
10 |
11 | parser.add_argument(
12 | '-w', '--weight',
13 | default='resnet34.pt',
14 | type=str,
15 | help='Trained state_dict file path to open'
16 | )
17 | parser.add_argument(
18 | '-n', '--model',
19 | type=str,
20 | default='resnet34',
21 | choices=['resnet18', 'resnet34', 'resnet50', 'mobilenetv2', 'mobileone_s0'],
22 | help='Backbone network architecture to use'
23 | )
24 | parser.add_argument(
25 | '-d', '--dataset',
26 | type=str,
27 | default='gaze360',
28 | choices=list(data_config.keys()),
29 | help='Dataset name for bin configuration'
30 | )
31 | parser.add_argument(
32 | '--dynamic',
33 | action='store_true',
34 | help='Enable dynamic batch size and input dimensions for ONNX export'
35 | )
36 |
37 | return parser.parse_args()
38 |
39 |
40 | @torch.no_grad()
41 | def onnx_export(params):
42 | # Get dataset config for bins
43 | if params.dataset not in data_config:
44 | raise KeyError(f"Unknown dataset: {params.dataset}. Available options: {list(data_config.keys())}")
45 | bins = data_config[params.dataset]['bins']
46 |
47 | # Set device
48 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49 |
50 | # Initialize model
51 | model = get_model(params.model, bins, inference_mode=True)
52 | model.to(device)
53 |
54 | # Load weights
55 | state_dict = torch.load(params.weight, map_location=device)
56 | model.load_state_dict(state_dict)
57 | print("Gaze model loaded successfully!")
58 |
59 | # Eval mode
60 | model.eval()
61 |
62 | # Generate ONNX output filename
63 | fname = os.path.splitext(os.path.basename(params.weight))[0]
64 | onnx_model = f'{fname}_gaze.onnx'
65 | print(f"==> Exporting model to ONNX format at '{onnx_model}'")
66 |
67 | # Dummy input: RGB image, 448x448
68 | dummy_input = torch.randn(1, 3, 448, 448).to(device)
69 |
70 | # Handle dynamic axes
71 | dynamic_axes = None
72 | if params.dynamic:
73 | dynamic_axes = {
74 | 'input': {0: 'batch_size'},
75 | 'pitch': {0: 'batch_size'},
76 | 'yaw': {0: 'batch_size'}
77 | }
78 | print("Exporting model with dynamic input shapes.")
79 | else:
80 | print("Exporting model with fixed input size: (1, 3, 448, 448)")
81 |
82 | # Export model
83 | torch.onnx.export(
84 | model,
85 | dummy_input,
86 | onnx_model,
87 | export_params=True,
88 | opset_version=20,
89 | do_constant_folding=True,
90 | input_names=['input'],
91 | output_names=['pitch', 'yaw'],
92 | dynamic_axes=dynamic_axes
93 | )
94 |
95 | print(f"Model exported successfully to {onnx_model}")
96 |
97 |
98 | if __name__ == '__main__':
99 | args = parse_arguments()
100 | onnx_export(args)
101 |
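
A quick post-export check (sketch): load the exported file with onnxruntime and confirm the input/output names used above. The filename follows the naming rule for the default weight resnet34.pt:

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("resnet34_gaze.onnx", providers=["CPUExecutionProvider"])
    print([i.name for i in sess.get_inputs()])    # ['input']
    print([o.name for o in sess.get_outputs()])   # ['pitch', 'yaw']

    dummy = np.random.randn(1, 3, 448, 448).astype(np.float32)
    pitch, yaw = sess.run(None, {"input": dummy})
    print(pitch.shape, yaw.shape)                 # (1, num_bins) each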
--------------------------------------------------------------------------------
/onnx_inference.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Yakhyokhuja Valikhujaev
2 | # Author: Yakhyokhuja Valikhujaev
3 | # GitHub: https://github.com/yakhyo
4 |
5 | import cv2
6 | import uniface
7 | import argparse
8 | import numpy as np
9 | import onnxruntime as ort
10 |
11 | from typing import Tuple
12 |
13 | from utils.helpers import draw_bbox_gaze
14 |
15 |
16 | class GazeEstimationONNX:
17 | """
18 |     Gaze estimation using ONNXRuntime (bin logits decoded to radians).
19 | """
20 |
21 | def __init__(self, model_path: str, session: ort.InferenceSession = None) -> None:
22 | """Initializes the GazeEstimationONNX class.
23 |
24 | Args:
25 | model_path (str): Path to the ONNX model file.
26 | session (ort.InferenceSession, optional): ONNX Session. Defaults to None.
27 |
28 | Raises:
29 | AssertionError: If model_path is None and session is not provided.
30 | """
31 | self.session = session
32 | if self.session is None:
33 |             assert model_path is not None, "Model path is required for first-time initialization."
34 | self.session = ort.InferenceSession(
35 | model_path,
36 |                 providers=["CUDAExecutionProvider", "CPUExecutionProvider"]  # prefer CUDA, fall back to CPU
37 | )
38 |
39 | self._bins = 90
40 | self._binwidth = 4
41 | self._angle_offset = 180
42 | self.idx_tensor = np.arange(self._bins, dtype=np.float32)
43 |
44 | self.input_shape = (448, 448)
45 | self.input_mean = [0.485, 0.456, 0.406]
46 | self.input_std = [0.229, 0.224, 0.225]
47 |
48 | input_cfg = self.session.get_inputs()[0]
49 | input_shape = input_cfg.shape
50 |
51 | self.input_name = input_cfg.name
52 | self.input_size = tuple(input_shape[2:][::-1])
53 |
54 | outputs = self.session.get_outputs()
55 | output_names = [output.name for output in outputs]
56 |
57 | self.output_names = output_names
58 | assert len(output_names) == 2, "Expected 2 output nodes, got {}".format(len(output_names))
59 |
60 | def preprocess(self, image: np.ndarray) -> np.ndarray:
61 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
62 | image = cv2.resize(image, self.input_size) # Resize to 448x448
63 |
64 | image = image.astype(np.float32) / 255.0
65 |
66 | mean = np.array(self.input_mean, dtype=np.float32)
67 | std = np.array(self.input_std, dtype=np.float32)
68 | image = (image - mean) / std
69 |
70 | image = np.transpose(image, (2, 0, 1)) # HWC → CHW
71 | image_batch = np.expand_dims(image, axis=0).astype(np.float32) # CHW → BCHW
72 |
73 | return image_batch
74 |
75 | def softmax(self, x: np.ndarray) -> np.ndarray:
76 | e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
77 | return e_x / e_x.sum(axis=1, keepdims=True)
78 |
79 | def decode(self, pitch_logits: np.ndarray, yaw_logits: np.ndarray) -> Tuple[float, float]:
80 | pitch_probs = self.softmax(pitch_logits)
81 | yaw_probs = self.softmax(yaw_logits)
82 |
83 | pitch = np.sum(pitch_probs * self.idx_tensor, axis=1) * self._binwidth - self._angle_offset
84 | yaw = np.sum(yaw_probs * self.idx_tensor, axis=1) * self._binwidth - self._angle_offset
85 |
86 | return np.radians(pitch[0]), np.radians(yaw[0])
87 |
88 | def estimate(self, face_image: np.ndarray) -> Tuple[float, float]:
89 | input_tensor = self.preprocess(face_image)
90 | outputs = self.session.run(self.output_names, {"input": input_tensor})
91 |
92 | return self.decode(outputs[0], outputs[1])
93 |
94 |
95 | def parse_args():
96 | parser = argparse.ArgumentParser(description="Gaze Estimation ONNX Inference")
97 | parser.add_argument(
98 | "--source",
99 | type=str,
100 | required=True,
101 | help="Video path or camera index (e.g., 0 for webcam)"
102 | )
103 | parser.add_argument(
104 | "--model",
105 | type=str,
106 | required=True,
107 | help="Path to ONNX model"
108 | )
109 | parser.add_argument(
110 | "--output",
111 | type=str,
112 | default=None,
113 | help="Path to save output video (optional)"
114 | )
115 | return parser.parse_args()
116 |
117 |
118 | if __name__ == "__main__":
119 | args = parse_args()
120 |
121 | # Handle numeric webcam index
122 | try:
123 | source = int(args.source)
124 | except ValueError:
125 | source = args.source
126 |
127 | cap = cv2.VideoCapture(source)
128 | if not cap.isOpened():
129 | raise IOError(f"Failed to open video source: {args.source}")
130 |
131 | # Initialize Gaze Estimation model
132 | engine = GazeEstimationONNX(model_path=args.model)
133 | detector = uniface.RetinaFace()
134 |
135 | # Optional output writer
136 | writer = None
137 | if args.output:
138 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
139 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
140 | fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
141 | fourcc = cv2.VideoWriter_fourcc(*"mp4v")
142 | writer = cv2.VideoWriter(args.output, fourcc, fps, (width, height))
143 |
144 | while cap.isOpened():
145 | ret, frame = cap.read()
146 | if not ret:
147 | break
148 |
149 | bboxes, _ = detector.detect(frame)
150 |
151 | for bbox in bboxes:
152 | x_min, y_min, x_max, y_max = map(int, bbox[:4])
153 | face_crop = frame[y_min:y_max, x_min:x_max]
154 | if face_crop.size == 0:
155 | continue
156 |
157 | pitch, yaw = engine.estimate(face_crop)
158 | draw_bbox_gaze(frame, bbox, pitch, yaw)
159 |
160 | if writer:
161 | writer.write(frame)
162 |
163 | cv2.imshow("Gaze Estimation", frame)
164 | if cv2.waitKey(1) & 0xFF == ord("q"):
165 | break
166 |
167 | cap.release()
168 | if writer:
169 | writer.release()
170 | cv2.destroyAllWindows()
171 |
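
The __main__ block handles video streams; a still-image variant is a thin wrapper around the same class (the model and image file names below are illustrative):

    import cv2
    import uniface
    from onnx_inference import GazeEstimationONNX

    engine = GazeEstimationONNX(model_path="resnet34_gaze.onnx")
    detector = uniface.RetinaFace()

    frame = cv2.imread("face.jpg")
    bboxes, _ = detector.detect(frame)
    for bbox in bboxes:
        x_min, y_min, x_max, y_max = map(int, bbox[:4])
        pitch, yaw = engine.estimate(frame[y_min:y_max, x_min:x_max])
        print(f"pitch={pitch:.3f} rad, yaw={yaw:.3f} rad")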
--------------------------------------------------------------------------------
/reparameterize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models import mobileone_s0, mobileone_s1, mobileone_s2, mobileone_s3, mobileone_s4, reparameterize_model
3 |
4 | state_dict = torch.load('mobileone_s0.pt')
5 | model = mobileone_s0(pretrained=False, num_classes=90) # 90 bins
6 |
7 | model.load_state_dict(state_dict)
8 |
9 |
10 | model.eval()
11 | model_eval = reparameterize_model(model)
12 |
13 | torch.save(model_eval.state_dict(), 's0_fused.pt')
14 |
15 |
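
To load the fused weights back later, the model has to be instantiated in inference mode so its single-branch graph matches the saved state_dict. A brief sketch (90 bins as above):

    import torch
    from models import mobileone_s0

    model = mobileone_s0(pretrained=False, num_classes=90, inference_mode=True)
    model.load_state_dict(torch.load('s0_fused.pt'))
    model.eval()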
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==2.0.1
2 | onnxruntime==1.19.0
3 | opencv-python==4.10.0.84
4 | pillow==10.2.0
5 | torch==2.4.0
6 | torchvision==0.19.0
7 | tqdm==4.66.5
8 | uniface==0.1.7
9 |
--------------------------------------------------------------------------------
/utils/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tqdm import tqdm
4 | from PIL import Image
5 |
6 | import torch
7 | from torch.utils.data import Dataset
8 |
9 |
10 | class Gaze360(Dataset):
11 | def __init__(self, root: str, transform=None, angle: int = 180, binwidth: int = 4, mode: str = 'train'):
12 | self.labels_dir = os.path.join(root, "Label")
13 | self.images_dir = os.path.join(root, "Image")
14 |
15 | if mode in ['train', 'test', 'val']:
16 | labels_file = os.path.join(self.labels_dir, f"{mode}.label")
17 | else:
18 | raise ValueError(f"{mode} must be in ['train','test', 'val']")
19 |
20 | self.transform = transform
21 | self.angle = angle if mode == "train" else 90
22 | self.binwidth = binwidth
23 |
24 | self.lines = []
25 |
26 | with open(labels_file) as f:
27 | lines = f.readlines()[1:] # Skip the header line
28 | self.orig_list_len = len(lines)
29 |
30 | for line in tqdm(lines, desc="Loading Labels"):
31 | gaze2d = line.strip().split(" ")[5]
32 | label = np.array(gaze2d.split(",")).astype(float)
33 | pitch, yaw = label * 180 / np.pi
34 |
35 | if abs(pitch) <= self.angle and abs(yaw) <= self.angle:
36 | self.lines.append(line)
37 |
38 | removed_items = self.orig_list_len - len(self.lines)
39 |         print(f"{removed_items} items removed from the dataset because their angle exceeds {self.angle} degrees")
40 |
41 | def __len__(self):
42 | return len(self.lines)
43 |
44 | def __getitem__(self, idx):
45 | line = self.lines[idx].strip().split(" ")
46 |
47 | image_path = line[0]
48 | filename = line[3]
49 | gaze2d = line[5]
50 |
51 | label = np.array(gaze2d.split(",")).astype(float)
52 | pitch, yaw = label * 180 / np.pi
53 |
54 | image = Image.open(os.path.join(self.images_dir, image_path))
55 | if self.transform is not None:
56 | image = self.transform(image)
57 |
58 | # bin values
59 | bins = np.arange(-self.angle, self.angle, self.binwidth)
60 | binned_pose = np.digitize([pitch, yaw], bins) - 1
61 |
62 | # binned and regression labels
63 | binned_labels = torch.tensor(binned_pose, dtype=torch.long)
64 | regression_labels = torch.tensor([pitch, yaw], dtype=torch.float32)
65 |
66 | return image, binned_labels, regression_labels, filename
67 |
68 |
69 | class MPIIGaze(Dataset):
70 | def __init__(self, root: str, transform=None, angle: int = 42, binwidth: int = 3):
71 | self.labels_dir = os.path.join(root, "Label")
72 | self.images_dir = os.path.join(root, "Image")
73 |
74 | label_files = [os.path.join(self.labels_dir, label) for label in os.listdir(self.labels_dir)]
75 |
76 | self.transform = transform
77 | self.orig_list_len = 0
78 | self.binwidth = binwidth
79 | self.angle = angle
80 | self.lines = []
81 |
82 | for label_file in label_files:
83 | with open(label_file) as f:
84 | lines = f.readlines()[1:] # Skip the header line
85 | self.orig_list_len += len(lines)
86 | for line in lines:
87 | gaze2d = line.strip().split(" ")[7]
88 | label = np.array(gaze2d.split(",")).astype("float")
89 | pitch, yaw = label * 180 / np.pi
90 |
91 | if abs(pitch) <= self.angle and abs(yaw) <= self.angle:
92 | self.lines.append(line)
93 |
94 | removed_items = self.orig_list_len - len(self.lines)
95 |         print(f"{removed_items} items removed from the dataset because their angle exceeds {self.angle} degrees")
96 |
97 | def __len__(self):
98 | return len(self.lines)
99 |
100 | def __getitem__(self, idx):
101 | line = self.lines[idx].strip().split(" ")
102 |
103 | image_path = line[0]
104 | filename = line[3]
105 | gaze2d = line[7]
106 |
107 | label = np.array(gaze2d.split(",")).astype("float")
108 | pitch, yaw = label * 180 / np.pi
109 |
110 | image = Image.open(os.path.join(self.images_dir, image_path))
111 | if self.transform is not None:
112 | image = self.transform(image)
113 |
114 | # bin values
115 | bins = np.arange(-self.angle, self.angle, self.binwidth)
116 | binned_pose = np.digitize([pitch, yaw], bins) - 1
117 |
118 | # binned and regression labels
119 | binned_labels = torch.tensor(binned_pose, dtype=torch.long)
120 | regression_labels = torch.tensor([pitch, yaw], dtype=torch.float32)
121 |
122 | return image, binned_labels, regression_labels, filename
123 |
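
A worked example of the binning used in both datasets, with the Gaze360 defaults (angle=180, binwidth=4, i.e. 90 bins):

    import numpy as np

    angle, binwidth = 180, 4
    bins = np.arange(-angle, angle, binwidth)     # 90 edges: -180, -176, ..., 176
    pitch, yaw = 10.0, -35.0                      # degrees
    binned = np.digitize([pitch, yaw], bins) - 1
    print(len(bins), binned)                      # 90 [47 36]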
--------------------------------------------------------------------------------
/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.utils.data import DataLoader
7 | from torchvision import transforms
8 |
9 | from utils.datasets import Gaze360, MPIIGaze
10 |
11 | from models import (
12 | resnet18,
13 | resnet34,
14 | resnet50,
15 | mobilenet_v2,
16 | mobileone_s0,
17 | mobileone_s1,
18 | mobileone_s2,
19 | mobileone_s3,
20 | mobileone_s4
21 | )
22 |
23 |
24 | def get_model(arch, bins, pretrained=False, inference_mode=False):
25 | """Return the model based on the specified architecture."""
26 | if arch == 'resnet18':
27 | model = resnet18(pretrained=pretrained, num_classes=bins)
28 | elif arch == 'resnet34':
29 | model = resnet34(pretrained=pretrained, num_classes=bins)
30 | elif arch == 'resnet50':
31 | model = resnet50(pretrained=pretrained, num_classes=bins)
32 | elif arch == "mobilenetv2":
33 | model = mobilenet_v2(pretrained=pretrained, num_classes=bins)
34 | elif arch == "mobileone_s0":
35 | model = mobileone_s0(pretrained=pretrained, num_classes=bins, inference_mode=inference_mode)
36 | elif arch == "mobileone_s1":
37 | model = mobileone_s1(pretrained=pretrained, num_classes=bins, inference_mode=inference_mode)
38 | elif arch == "mobileone_s2":
39 | model = mobileone_s2(pretrained=pretrained, num_classes=bins, inference_mode=inference_mode)
40 | elif arch == "mobileone_s3":
41 | model = mobileone_s3(pretrained=pretrained, num_classes=bins, inference_mode=inference_mode)
42 | elif arch == "mobileone_s4":
43 | model = mobileone_s4(pretrained=pretrained, num_classes=bins, inference_mode=inference_mode)
44 | else:
45 | raise ValueError(f"Please choose available model architecture, currently chosen: {arch}")
46 | return model
47 |
48 |
49 | def angular_error(gaze_vector, label_vector):
50 | dot_product = np.dot(gaze_vector, label_vector)
51 | norm_product = np.linalg.norm(gaze_vector) * np.linalg.norm(label_vector)
52 | cosine_similarity = min(dot_product / norm_product, 0.9999999)
53 |
54 | return np.degrees(np.arccos(cosine_similarity))
55 |
56 |
57 | def gaze_to_3d(gaze):
58 | yaw = gaze[0] # Horizontal angle
59 | pitch = gaze[1] # Vertical angle
60 |
61 | gaze_vector = np.zeros(3)
62 | gaze_vector[0] = -np.cos(pitch) * np.sin(yaw)
63 | gaze_vector[1] = -np.sin(pitch)
64 | gaze_vector[2] = -np.cos(pitch) * np.cos(yaw)
65 |
66 | return gaze_vector
67 |
68 |
69 | def get_dataloader(params, mode="train"):
70 | """Load dataset and return DataLoader."""
71 |
72 | transform = transforms.Compose([
73 | transforms.Resize(448),
74 | transforms.ToTensor(),
75 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
76 | ])
77 |
78 | if params.dataset == "gaze360":
79 | dataset = Gaze360(params.data, transform, angle=params.angle, binwidth=params.binwidth, mode=mode)
80 | elif params.dataset == "mpiigaze":
81 | dataset = MPIIGaze(params.data, transform, angle=params.angle, binwidth=params.binwidth)
82 | else:
83 |         raise ValueError("Supported datasets are `gaze360` and `mpiigaze`")
84 |
85 | data_loader = DataLoader(
86 | dataset=dataset,
87 | batch_size=params.batch_size,
88 |         shuffle=(mode == "train"),
89 | num_workers=params.num_workers,
90 | pin_memory=True
91 | )
92 | return data_loader
93 |
94 | def draw_gaze(frame, bbox, pitch, yaw, thickness=2, color=(0, 0, 255)):
95 | """Draws gaze direction on a frame given bounding box and gaze angles."""
96 | # Unpack bounding box coordinates
97 | x_min, y_min, x_max, y_max = map(int, bbox[:4])
98 |
99 | # Calculate center of the bounding box
100 | x_center = (x_min + x_max) // 2
101 | y_center = (y_min + y_max) // 2
102 |
103 | # Handle grayscale frames by converting them to BGR
104 | if len(frame.shape) == 2 or frame.shape[2] == 1:
105 | frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
106 |
107 | # Calculate the direction of the gaze
108 | length = x_max - x_min
109 | dx = int(-length * np.sin(pitch) * np.cos(yaw))
110 | dy = int(-length * np.sin(yaw))
111 |
112 | point1 = (x_center, y_center)
113 | point2 = (x_center + dx, y_center + dy)
114 |
115 | # Draw gaze direction
116 | cv2.circle(frame, (x_center, y_center), radius=4, color=color, thickness=-1)
117 | cv2.arrowedLine(
118 | frame,
119 | point1,
120 | point2,
121 | color=color,
122 | thickness=thickness,
123 | line_type=cv2.LINE_AA,
124 | tipLength=0.25
125 | )
126 |
127 |
128 |
129 | def draw_bbox(image, bbox, color=(0, 255, 0), thickness=2, proportion=0.2):
130 | x_min, y_min, x_max, y_max = map(int, bbox[:4])
131 |
132 | width = x_max - x_min
133 | height = y_max - y_min
134 |
135 | corner_length = int(proportion * min(width, height))
136 |
137 | # Draw the rectangle
138 | cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, 1)
139 |
140 | # Top-left corner
141 | cv2.line(image, (x_min, y_min), (x_min + corner_length, y_min), color, thickness)
142 | cv2.line(image, (x_min, y_min), (x_min, y_min + corner_length), color, thickness)
143 |
144 | # Top-right corner
145 | cv2.line(image, (x_max, y_min), (x_max - corner_length, y_min), color, thickness)
146 | cv2.line(image, (x_max, y_min), (x_max, y_min + corner_length), color, thickness)
147 |
148 | # Bottom-left corner
149 | cv2.line(image, (x_min, y_max), (x_min, y_max - corner_length), color, thickness)
150 | cv2.line(image, (x_min, y_max), (x_min + corner_length, y_max), color, thickness)
151 |
152 | # Bottom-right corner
153 | cv2.line(image, (x_max, y_max), (x_max, y_max - corner_length), color, thickness)
154 | cv2.line(image, (x_max, y_max), (x_max - corner_length, y_max), color, thickness)
155 |
156 |
157 | def draw_bbox_gaze(frame: np.ndarray, bbox, pitch, yaw):
158 | draw_bbox(frame, bbox)
159 | draw_gaze(frame, bbox, pitch, yaw)
160 |
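
A small worked example of the error metric (sketch), mirroring how evaluate() in mpii_train.py calls these helpers with (pitch, yaw) pairs in radians:

    import numpy as np
    from utils.helpers import angular_error, gaze_to_3d

    pred = [np.radians(10.0), np.radians(-5.0)]   # predicted (pitch, yaw)
    label = [np.radians(12.0), np.radians(-4.0)]  # ground-truth (pitch, yaw)

    error = angular_error(gaze_to_3d(pred), gaze_to_3d(label))
    print(f"{error:.2f} degrees")                 # about 2.1 degrees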
--------------------------------------------------------------------------------
/weights/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yakhyo/gaze-estimation/2ccde1d08d007727c2df1ce704c32e2683b2d0b9/weights/.gitkeep
--------------------------------------------------------------------------------