├── code ├── models │ ├── gaze_aaai_refine_headpose.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── gaze_aaai.cpython-36.pyc │ │ ├── gaze_aaai.cpython-37.pyc │ │ ├── gaze_base.cpython-36.pyc │ │ ├── gaze_base.cpython-37.pyc │ │ ├── gaze_depth.cpython-36.pyc │ │ ├── gaze_depth.cpython-37.pyc │ │ ├── gaze_aaai_dv2.cpython-36.pyc │ │ ├── gaze_aaai_dv2.cpython-37.pyc │ │ ├── gaze_aaai_is.cpython-36.pyc │ │ ├── gaze_aaai_is.cpython-37.pyc │ │ ├── gaze_depth_v2.cpython-36.pyc │ │ ├── gaze_depth_v2.cpython-37.pyc │ │ ├── gaze_depth_v3.cpython-37.pyc │ │ ├── gaze_depth_v4.cpython-37.pyc │ │ ├── gaze_base_facepose.cpython-36.pyc │ │ ├── gaze_base_facepose.cpython-37.pyc │ │ ├── gaze_aaai_pose16_relu.cpython-36.pyc │ │ ├── gaze_aaai_pose16_relu.cpython-37.pyc │ │ ├── gaze_aaai_refine_headpose.cpython-36.pyc │ │ └── gaze_aaai_refine_headpose.cpython-37.pyc │ ├── __init__.py │ ├── depth_solver.py │ ├── gaze_base.py │ ├── gaze_base_facepose.py │ ├── gaze_depth.py │ ├── gaze_depth_v2.py │ ├── gaze_aaai_pose16_relu.py │ ├── gaze_depth_v3.py │ ├── gaze_depth_v4.py │ ├── gaze_depth_v5.py │ ├── gaze_aaai_dv2.py │ ├── gaze_aaai.py │ └── gaze_aaai_is.py ├── utils │ ├── __pycache__ │ │ ├── edict.cpython-36.pyc │ │ ├── edict.cpython-37.pyc │ │ ├── trainer.cpython-36.pyc │ │ ├── trainer.cpython-37.pyc │ │ ├── vispyplot.cpython-36.pyc │ │ └── vispyplot.cpython-37.pyc │ ├── edict.py │ ├── vispyplot.py │ ├── trainer.py │ └── gen_filelist.py ├── data │ ├── __pycache__ │ │ ├── gaze_dataset.cpython-36.pyc │ │ ├── gaze_dataset.cpython-37.pyc │ │ ├── gaze_dataset_v2.cpython-36.pyc │ │ └── gaze_dataset_v2.cpython-37.pyc │ ├── gen_landmark.py │ ├── data_check.py │ ├── gaze_dataset.py │ ├── gaze_dataset_v2.py │ └── dataset.py ├── .idea │ ├── markdown-navigator │ │ └── profiles_settings.xml │ ├── modules.xml │ ├── code_release.iml │ ├── deployment.xml │ ├── webServers.xml │ ├── inspectionProfiles │ │ └── Project_Default.xml │ ├── misc.xml │ └── workspace.xml ├── train_aaai_refine_headpose.sh ├── train_aaai_is.sh ├── train_aaai.sh ├── train_aaai_dv2.sh ├── train_aaai_pose16_relu.sh └── gen_optim_samples.py ├── camera ├── color.mat ├── depth.mat └── ex.txt ├── images └── poster.pdf └── README.md /code/models/gaze_aaai_refine_headpose.py: -------------------------------------------------------------------------------- 1 | from .gaze_aaai import * -------------------------------------------------------------------------------- /camera/color.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svip-lab/RGBD-Gaze/HEAD/camera/color.mat -------------------------------------------------------------------------------- /camera/depth.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svip-lab/RGBD-Gaze/HEAD/camera/depth.mat -------------------------------------------------------------------------------- /images/poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svip-lab/RGBD-Gaze/HEAD/images/poster.pdf -------------------------------------------------------------------------------- /code/utils/__pycache__/edict.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svip-lab/RGBD-Gaze/HEAD/code/utils/__pycache__/edict.cpython-36.pyc 
-------------------------------------------------------------------------------- /code/models/__pycache__/gaze_depth_v2.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/svip-lab/RGBD-Gaze/HEAD/code/models/__pycache__/gaze_depth_v2.cpython-36.pyc -------------------------------------------------------------------------------- /code/models/__pycache__/gaze_depth_v2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svip-lab/RGBD-Gaze/HEAD/code/models/__pycache__/gaze_depth_v2.cpython-37.pyc -------------------------------------------------------------------------------- /code/models/__pycache__/gaze_depth_v3.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svip-lab/RGBD-Gaze/HEAD/code/models/__pycache__/gaze_depth_v3.cpython-37.pyc -------------------------------------------------------------------------------- /code/models/__pycache__/gaze_depth_v4.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svip-lab/RGBD-Gaze/HEAD/code/models/__pycache__/gaze_depth_v4.cpython-37.pyc -------------------------------------------------------------------------------- /camera/ex.txt: -------------------------------------------------------------------------------- 1 | R = [0.999746 0.000291 -0.022534; 2 | 0.000156 0.999803 0.019831; 3 | 0.022535 -0.019830 0.999549]'; 4 | t=[23.590678 -0.954391 -4.049306 ]'; -------------------------------------------------------------------------------- /code/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import gaze_aaai 2 | from . import gaze_aaai_pose16_relu 3 | from . import gaze_aaai_dv2 4 | from . import gaze_aaai_refine_headpose -------------------------------------------------------------------------------- /code/models/__pycache__/gaze_base_facepose.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svip-lab/RGBD-Gaze/HEAD/code/models/__pycache__/gaze_base_facepose.cpython-36.pyc -------------------------------------------------------------------------------- /code/models/__pycache__/gaze_base_facepose.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svip-lab/RGBD-Gaze/HEAD/code/models/__pycache__/gaze_base_facepose.cpython-37.pyc -------------------------------------------------------------------------------- /code/models/__pycache__/gaze_aaai_pose16_relu.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svip-lab/RGBD-Gaze/HEAD/code/models/__pycache__/gaze_aaai_pose16_relu.cpython-36.pyc -------------------------------------------------------------------------------- /code/models/__pycache__/gaze_aaai_pose16_relu.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svip-lab/RGBD-Gaze/HEAD/code/models/__pycache__/gaze_aaai_pose16_relu.cpython-37.pyc -------------------------------------------------------------------------------- /code/models/__pycache__/gaze_aaai_refine_headpose.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svip-lab/RGBD-Gaze/HEAD/code/models/__pycache__/gaze_aaai_refine_headpose.cpython-36.pyc -------------------------------------------------------------------------------- 
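The extrinsics in camera/ex.txt above are written in MATLAB syntax (the trailing `'` denotes a transpose). Below is a minimal sketch of how these parameters might be parsed and applied in Python; the mapping direction (depth camera to color camera), the millimetre unit of `t`, the helper name `depth_to_color`, and the variable names inside `color.mat`/`depth.mat` are assumptions, not documented by the repository.

```
# Sketch only: numeric values are copied from camera/ex.txt; the convention
# p_color = R @ p_depth + t is an assumption (the mapping direction is not
# documented in the repository).
import numpy as np
from scipy.io import loadmat

# MATLAB's trailing ' in ex.txt is a transpose, hence the .T here.
R = np.array([[0.999746,  0.000291, -0.022534],
              [0.000156,  0.999803,  0.019831],
              [0.022535, -0.019830,  0.999549]]).T
t = np.array([23.590678, -0.954391, -4.049306])  # presumably millimetres


def depth_to_color(p_depth):
    """Map a 3D point from the depth-camera frame into the color-camera frame."""
    return R @ np.asarray(p_depth, dtype=np.float64) + t


if __name__ == "__main__":
    # A point roughly 750 mm in front of the depth camera.
    print(depth_to_color([0.0, 0.0, 750.0]))
    # The intrinsics live in color.mat / depth.mat; their internal variable
    # names are not documented here, so just list the keys.
    for path in ("camera/color.mat", "camera/depth.mat"):
        print(path, [k for k in loadmat(path) if not k.startswith("__")])
```

The millimetre assumption is at least consistent with the dataset code further below: gaze_dataset.py normalizes face depth as `(d - 500) / 500` clamped to `[0, 1]`, i.e. a working range of roughly 500-1000 mm.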
/code/models/__pycache__/gaze_aaai_refine_headpose.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svip-lab/RGBD-Gaze/HEAD/code/models/__pycache__/gaze_aaai_refine_headpose.cpython-37.pyc -------------------------------------------------------------------------------- /code/.idea/markdown-navigator/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /code/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /code/train_aaai_refine_headpose.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 python trainer_aaai.py --data-root /home/ziheng/datasets/gaze \ 2 | --batch-size-train 64 --batch-size-val 128 --num-workers 0 --exp-name gaze_aaai_refine_headpose \ 3 | - train_headpose --epochs 10 --lr 1e-2\ 4 | - train_headpose --epochs 15 --lr 1e-3\ 5 | - train_headpose --epochs 17 --lr 1e-4\ 6 | - end 7 | 8 | -------------------------------------------------------------------------------- /code/train_aaai_is.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 python trainer_aaai_is.py --data-root /home/ziheng/datasets/gaze \ 2 | --batch-size-train 64 --batch-size-val 128 --num-workers 24 \ 3 | - train_base --epochs 10 --lr 1e-2 --use-refined-depth False --fine-tune-headpose True\ 4 | - train_base --epochs 15 --lr 1e-3 --use-refined-depth False --fine-tune-headpose True\ 5 | - train_base --epochs 17 --lr 1e-4 --use-refined-depth False --fine-tune-headpose True\ 6 | - end 7 | 8 | -------------------------------------------------------------------------------- /code/.idea/code_release.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /code/train_aaai.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CUDA_VISIBLE_DEVICES=1, python trainer_aaai.py --data-root /p300/datasets/Gaze/ShanghaiTechGaze+/data \ 3 | --batch-size-train 64 --batch-size-val 128 --num-workers 24 \ 4 | - train_base --epochs 10 --lr 1e-2 --use-refined-depth False --fine-tune-headpose True\ 5 | - train_base --epochs 15 --lr 1e-3 --use-refined-depth False --fine-tune-headpose True\ 6 | - train_base --epochs 17 --lr 1e-4 --use-refined-depth False --fine-tune-headpose True\ 7 | - end 8 | 9 | -------------------------------------------------------------------------------- /code/train_aaai_dv2.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 python trainer_aaai_dv2.py --data-root /home/ziheng/datasets/gaze \ 2 | --batch-size-train 64 --batch-size-val 128 --num-workers 24 --exp-name gaze_aaai_dv2 \ 3 | - train_base --epochs 10 --lr 1e-3 --use-refined-depth False --fine-tune-headpose True\ 4 | - train_base --epochs 15 --lr 1e-4 --use-refined-depth False --fine-tune-headpose True\ 5 | - train_base --epochs 17 --lr 1e-5 --use-refined-depth False --fine-tune-headpose True\ 6 | - end 7 | 8 | 
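# The chained "- train_base ... - end" arguments used by these train_*.sh
# scripts appear to be python-fire style command chaining: train_base is
# invoked three times with a decreasing learning rate, and the --epochs
# values (10, 15, 17) look like cumulative epoch targets rather than
# per-stage counts (an assumption, not documented here).
# CUDA_VISIBLE_DEVICES and --data-root are machine-specific and must point
# at a local copy of the dataset; the entry points these scripts call
# (trainer_aaai*.py) are referenced here but do not appear in the directory
# listing above.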
-------------------------------------------------------------------------------- /code/train_aaai_pose16_relu.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 python trainer_aaai.py --data-root /home/ziheng/datasets/gaze \ 2 | --batch-size-train 64 --batch-size-val 128 --num-workers 24 --exp-name gaze_aaai_pose16_relu \ 3 | - train_base --epochs 10 --lr 1e-2 --use-refined-depth False --fine-tune-headpose True\ 4 | - train_base --epochs 15 --lr 1e-3 --use-refined-depth False --fine-tune-headpose True\ 5 | - train_base --epochs 17 --lr 1e-4 --use-refined-depth False --fine-tune-headpose True\ 6 | - end 7 | 8 | -------------------------------------------------------------------------------- /code/.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 15 | -------------------------------------------------------------------------------- /code/.idea/webServers.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 14 | 15 | -------------------------------------------------------------------------------- /code/.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 17 | -------------------------------------------------------------------------------- /code/utils/edict.py: -------------------------------------------------------------------------------- 1 | class edict(dict): 2 | def __init__(self, d=None, **kwargs): 3 | super(edict, self).__init__() 4 | if d is None: 5 | d = {} 6 | if kwargs: 7 | d.update(**kwargs) 8 | for k, v in d.items(): 9 | setattr(self, k, v) 10 | # Class attributes 11 | for k in self.__class__.__dict__.keys(): 12 | if not (k.startswith('__') and k.endswith('__')): 13 | setattr(self, k, getattr(self, k)) 14 | 15 | def __setattr__(self, name, value): 16 | # if isinstance(value, (list, tuple)): 17 | # value = [self.__class__(x) 18 | # if isinstance(x, dict) else x for x in value] 19 | # elif isinstance(value, dict) and not isinstance(value, self.__class__): 20 | # value = self.__class__(value) 21 | super(edict, self).__setattr__(name, value) 22 | super(edict, self).__setitem__(name, value) 23 | 24 | __setitem__ = __setattr__ 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RGBD-Gaze 2 | 3 | **'RGBD Based Gaze Estimation via Multi-task CNN'** [[paper](https://www.aaai.org/ojs/index.php/AAAI/article/view/4094)] 4 | [[poster](images/poster.pdf)] 5 | 6 | Dongze Lian*, Ziheng Zhang*, Weixin Luo, Lina Hu, Minye Wu, Zechao Li, Jingyi Yu, Shenghua Gao 7 | 8 | (* Equal Contribution) 9 | This paper is accepted by AAAI 2019. 10 | 11 | 12 | # The ShanghaiTechGaze+ dataset 13 | Download our ShanghaiTechGaze+ dataset: [OneDrive](https://yien01-my.sharepoint.com/:f:/g/personal/doubility_z0_tn/Er4Cs0-o6BtDoHbm8hqLnIcBGoTQcCCh61ZLShdADGAOGg?e=auJdbU). 14 | 15 | 16 | Our previous ShanghaiTechGaze dataset is also released in [Github](https://github.com/dongzelian/multi-view-gaze). 17 | 18 | # The camera parameters 19 | Camera parameters are saved in [camera](https://github.com/svip-lab/RGBD-Gaze/tree/master/camera) folder. 'color.mat' and 'depth.mat' are the intrinsic parameters of color camera and depth camera, respectively. 
'ex.txt' is the extrinsic parameters of them. 20 | 21 | 22 | # Citation 23 | ``` 24 | @article{lian2018tnnls, 25 | Author = {Dongze Lian, Lina Hu, Weixin Luo, Yanyu Xu, Lixin Duan, Jingyi Yu, Shenghua Gao.}, 26 | Title = {Multi-view Multi-task Gaze Prediction with Deep Convolutional Neural Networks.}, 27 | Journal = {TNNLS}, 28 | Year = {2018} 29 | } 30 | ``` 31 | 32 | ``` 33 | @article{lian2019aaai, 34 | Author = {Dongze Lian, Ziheng Zhang, Weixin Luo, lina hu, Minye Wu, Zechao Li, Jingyi Yu, Shenghua Gao}, 35 | Title = {RGBD Based Gaze Estimation via Multi-task CNN}, 36 | Journal = {AAAI}, 37 | Year = {2019}} 38 | ``` -------------------------------------------------------------------------------- /code/utils/vispyplot.py: -------------------------------------------------------------------------------- 1 | import visdom 2 | from io import StringIO 3 | import matplotlib.pyplot as plt 4 | from .edict import edict 5 | from multiprocessing import Process 6 | 7 | plt.switch_backend('agg') 8 | 9 | vis_config = edict( 10 | server = 'http://127.0.0.1', 11 | port = 31540, 12 | env = 'main', 13 | ) 14 | 15 | 16 | class set_draw(object): 17 | def __init__(self, server=None, port=None, env=None, **subplot_kwargs): 18 | self._config = edict(vis_config) 19 | vis_config.server = server if server is not None else vis_config.server 20 | vis_config.port = port if port is not None else vis_config.port 21 | vis_config.env = env if env is not None else vis_config.env 22 | self.viz = visdom.Visdom(**vis_config) 23 | self.fig = None 24 | self.axes = None 25 | self.name = subplot_kwargs.pop('name') if 'name' in subplot_kwargs.keys() else 'default' 26 | if not (server or port): 27 | self.subplots(name=self.name, **subplot_kwargs) 28 | 29 | def __enter__(self): 30 | assert self.fig is not None and self.axes is not None 31 | return self 32 | 33 | def __exit__(self, exc_type, exc_val, exc_tb): 34 | # p = Process(target=self._remote_draw) 35 | # p.start() 36 | self._remote_draw() 37 | 38 | def _remote_draw(self): 39 | strio = StringIO() 40 | self.fig.savefig(strio, format="svg") 41 | plt.close(self.fig) 42 | self.viz.svg(svgstr=strio.getvalue(), win=self.name) 43 | vis_config.update(self._config) 44 | self.viz.close() 45 | self.viz = visdom.Visdom(**vis_config) 46 | 47 | def subplots(self, name, **plt_subplot_kwargs): 48 | self.name = name 49 | self.fig, self.axes = plt.subplots(**plt_subplot_kwargs) 50 | return self 51 | 52 | def close(self): 53 | self.viz.close(self.name, self.env) 54 | 55 | def __getattr__(self, item): 56 | return getattr(self.axes, item) 57 | -------------------------------------------------------------------------------- /code/models/depth_solver.py: -------------------------------------------------------------------------------- 1 | import cvxpy as cp 2 | import numpy as np 3 | from mpl_toolkits.mplot3d import Axes3D 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def scatter_face_landmark(face_landmark, median, median_mask): 9 | fig = plt.figure() 10 | ax = fig.add_subplot(111, projection='3d') 11 | ax.set_xlabel('X') 12 | ax.set_ylabel('Y') 13 | ax.set_zlabel('D') 14 | xs = face_landmark[median_mask, 0] 15 | ys = face_landmark[median_mask, 1] 16 | zs = median[median_mask] 17 | ax.scatter(xs, ys, zs, c="r", marker="x") 18 | ax.view_init(-109, -64) 19 | return fig, ax 20 | 21 | 22 | class DepthSolver(object): 23 | def __init__(self, solver=cp.OSQP): 24 | self.solver = solver 25 | 26 | def solve(self, rel_depth_pred, rel_depth_mask, depth_obs, depth_obs_mask, 
dist_landmark, sigma_dist, lambda_E_D, 27 | lambda_E_rel, lambda_E_s): 28 | E_D, E_rel, E_s = 0, 0, 0 29 | bs, num_landmarks = len(rel_depth_pred), len(depth_obs[0]) 30 | depth_pred = cp.Variable(value=np.ones((bs, num_landmarks)) * 750, shape=(bs, num_landmarks), nonneg=True) 31 | for i in range(bs): 32 | D = depth_pred[i] 33 | D_stack1 = cp.atoms.affine.transpose.transpose(cp.atoms.affine.vstack.vstack([D] * num_landmarks)) 34 | D_stack2 = cp.atoms.affine.vstack.vstack([D] * num_landmarks) 35 | D_err = (D_stack1 - D_stack2) 36 | if lambda_E_D > 0: 37 | E_D += cp.sum_squares(cp.atoms.affine.binary_operators.multiply( 38 | depth_obs[i] - depth_pred[i], depth_obs_mask[i] 39 | )) 40 | if lambda_E_rel > 0: 41 | E_rel += cp.sum_squares(cp.atoms.affine.binary_operators.multiply( 42 | D_err - rel_depth_pred[i] * dist_landmark[i], rel_depth_mask[i] 43 | )) 44 | if lambda_E_s > 0: 45 | E_s += cp.sum_squares(cp.atoms.affine.binary_operators.multiply( 46 | D_err, np.exp(-dist_landmark[i] ** 2 / sigma_dist ** 2) 47 | )) 48 | obj = cp.Minimize(E_D * lambda_E_D + E_rel * lambda_E_rel + E_s * lambda_E_s) 49 | prob = cp.Problem(obj, [D >= 0]) 50 | assert prob.is_dcp() and prob.is_qp() 51 | 52 | prob.solve(solver=self.solver, verbose=False) 53 | 54 | if prob.status != cp.OPTIMAL: 55 | raise Exception("Solver did not converge!") 56 | res = np.stack([d.value for d in depth_pred], axis=0) 57 | return res 58 | -------------------------------------------------------------------------------- /code/data/gen_landmark.py: -------------------------------------------------------------------------------- 1 | from gaze_dataset import GazePointAllDataset 2 | from random import Random 3 | from torchvision import transforms 4 | import numpy as np 5 | import cv2 6 | import dlib 7 | from multiprocessing import Pool 8 | from tqdm import tqdm 9 | import pickle 10 | import os 11 | import time 12 | 13 | 14 | def rect_to_bb(rect): 15 | # take a bounding predicted by dlib and convert it 16 | # to the format (x, y, w, h) as we would normally do 17 | # with OpenCV 18 | x = rect.left() 19 | y = rect.top() 20 | w = rect.right() - x 21 | h = rect.bottom() - y 22 | 23 | # return a tuple of (x, y, w, h) 24 | return (x, y, w, h) 25 | 26 | 27 | def shape_to_np(shape, dtype="int"): 28 | # initialize the list of (x, y)-coordinates 29 | coords = np.zeros((68, 2), dtype=dtype) 30 | 31 | # loop over the 68 facial landmarks and convert them 32 | # to a 2-tuple of (x, y)-coordinates 33 | for i in range(0, 68): 34 | coords[i] = (shape.part(i).x, shape.part(i).y) 35 | 36 | # return the list of (x, y)-coordinates 37 | return coords 38 | 39 | 40 | data_root = r"D:\data\gaze" 41 | 42 | dataset = GazePointAllDataset(root_dir=data_root, 43 | transform=None, 44 | phase='val', 45 | face_image=True, 46 | face_bbox=True) 47 | min_depth = [] 48 | max_depth = [] 49 | mean_depth = [] 50 | median_depth = [] 51 | bad_samples = [] 52 | 53 | cnt = 0 54 | 55 | 56 | def process(i): 57 | sample = dataset[i] 58 | face_image = np.asarray(sample["face_image"]) 59 | face_bbox = sample["face_bbox"].numpy() 60 | w_ori, h_ori = face_bbox[2] - face_bbox[0], face_bbox[3] - face_bbox[1] 61 | h_now, w_now = face_image.shape[:2] 62 | scale = np.array(w_ori / w_now, h_ori / h_now) 63 | pid = f"{sample['pid'].item():05d}" 64 | sid = f"{sample['sid'].item():05d}" 65 | os.makedirs(os.path.join(data_root, "landmark", pid), exist_ok=True) 66 | save_file = os.path.join(data_root, "landmark", pid, sid) 67 | detector = dlib.get_frontal_face_detector() 68 | predictor = 
dlib.shape_predictor("shape_predictor_68_face_landmarks.dat") 69 | rects = detector(face_image, 1) 70 | if len(rects) > 0: 71 | shape = predictor(face_image, rects[0]) 72 | face_lm = shape_to_np(shape) 73 | ori_lm = face_lm * scale + face_bbox[:2] 74 | np.save(save_file + '.npy', ori_lm) 75 | # Get the landmarks/parts for the face in box d. 76 | for (x, y) in face_lm: 77 | cv2.circle(face_image, (x, y), 1, (255, 0, 0), -1) 78 | cv2.imwrite(save_file + '.jpg', face_image[:, :, ::-1]) 79 | # Draw the face landmarks on the screen. 80 | else: 81 | pass 82 | # print(f"warn: landmark detection failed for pid: {pid}, sid: {sid}", flush=True) 83 | # np.save(save_file + '.npy', np.zeros((68, 2), dtype="int")) 84 | 85 | 86 | if __name__ == '__main__': 87 | with Pool(40) as pool: 88 | with tqdm(desc="progress", total=len(dataset)) as pbar: 89 | for i, _ in tqdm(enumerate(pool.imap_unordered(process, range(len(dataset))))): 90 | pbar.update() 91 | -------------------------------------------------------------------------------- /code/models/gaze_base.py: -------------------------------------------------------------------------------- 1 | from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet, model_urls 2 | import torch.utils.model_zoo as model_zoo 3 | from torch import nn 4 | import torch as th 5 | 6 | 7 | class ResNetEncoder(ResNet): 8 | def forward(self, x): 9 | x = self.conv1(x) 10 | x = self.bn1(x) 11 | x = self.relu(x) 12 | # x112_64 = x 13 | x = self.maxpool(x) 14 | x = self.layer1(x) 15 | # x56_64 = x 16 | x = self.layer2(x) 17 | # x28_128 = x 18 | x = self.layer3(x) 19 | # x14_256 = x 20 | x = self.layer4(x) 21 | x = self.avgpool(x) 22 | x = x.view(x.size(0), -1) 23 | x = self.relu(x) 24 | 25 | return x#, x112_64, x56_64, x28_128, x14_256 26 | 27 | 28 | def resnet18(pretrained=False, **kwargs): 29 | """Constructs a ResNet-18 model. 30 | 31 | Args: 32 | pretrained (bool): If True, returns a model pre-trained on ImageNet 33 | """ 34 | model = ResNetEncoder(BasicBlock, [2, 2, 2, 2], **kwargs) 35 | if pretrained: 36 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 37 | return model 38 | 39 | 40 | def resnet34(pretrained=False, **kwargs): 41 | """Constructs a ResNet-34 model. 42 | 43 | Args: 44 | pretrained (bool): If True, returns a model pre-trained on ImageNet 45 | """ 46 | model = ResNetEncoder(BasicBlock, [3, 4, 6, 3], **kwargs) 47 | if pretrained: 48 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 49 | return model 50 | 51 | 52 | def resnet50(pretrained=False, **kwargs): 53 | """Constructs a ResNet-50 model. 54 | 55 | Args: 56 | pretrained (bool): If True, returns a model pre-trained on ImageNet 57 | """ 58 | model = ResNetEncoder(Bottleneck, [3, 4, 6, 3], **kwargs) 59 | if pretrained: 60 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 61 | return model 62 | 63 | 64 | def resnet101(pretrained=False, **kwargs): 65 | """Constructs a ResNet-101 model. 66 | 67 | Args: 68 | pretrained (bool): If True, returns a model pre-trained on ImageNet 69 | """ 70 | model = ResNetEncoder(Bottleneck, [3, 4, 23, 3], **kwargs) 71 | if pretrained: 72 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 73 | return model 74 | 75 | 76 | def resnet152(pretrained=False, **kwargs): 77 | """Constructs a ResNet-152 model. 
78 | 79 | Args: 80 | pretrained (bool): If True, returns a model pre-trained on ImageNet 81 | """ 82 | model = ResNetEncoder(Bottleneck, [3, 8, 36, 3], **kwargs) 83 | if pretrained: 84 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 85 | return model 86 | 87 | 88 | class Decoder(nn.Module): 89 | def __init__(self, feat_dim=512): 90 | super(Decoder, self).__init__() 91 | self.decoder1 = nn.Sequential( 92 | nn.Linear(feat_dim, 256), 93 | nn.ReLU(), 94 | nn.Linear(256, 128), 95 | nn.ReLU(), 96 | ) 97 | self.decoder2 = nn.Sequential( 98 | nn.Linear(128 + 3, 2) 99 | ) 100 | 101 | def forward(self, feat, info): 102 | out = self.decoder1(feat) 103 | out = th.cat([out, info], 1) 104 | out = self.decoder2(out) 105 | return out 106 | -------------------------------------------------------------------------------- /code/data/data_check.py: -------------------------------------------------------------------------------- 1 | from data.gaze_dataset import GazePointAllDataset 2 | from random import Random 3 | from torchvision import transforms 4 | import numpy as np 5 | import cv2 6 | from tqdm import trange 7 | import pickle 8 | 9 | rnd = Random() 10 | 11 | data_transforms = { 12 | 'train': transforms.Compose([ 13 | transforms.Resize(224), 14 | transforms.ToTensor(), 15 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 16 | ]), 17 | 'val': transforms.Compose([ 18 | transforms.Resize(224), 19 | transforms.ToTensor(), 20 | # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 21 | ]) 22 | } 23 | 24 | dataset = GazePointAllDataset(root_dir=r"D:\data\gaze", 25 | transform=data_transforms['train'], 26 | phase='train', 27 | face_image=True, face_depth=True, eye_image=True, 28 | eye_depth=True, 29 | info=True, eye_bbox=True, face_bbox=True, eye_coord=True) 30 | 31 | sample = dataset[rnd.choice(range(len(dataset)))] 32 | 33 | # face image and depth 34 | face_image, face_depth = sample["face_image"].numpy().transpose((1, 2, 0))[:, :, ::-1], sample["face_depth"].numpy().squeeze() 35 | # cv2.imshow("face_image", face_image) 36 | # cv2.imshow("face_depth", face_depth) 37 | # cv2.waitKey() 38 | 39 | # left and right eye image, coord and bbox 40 | left_eye_image, right_eye_image = sample["left_eye_image"].numpy().transpose((1, 2, 0))[:, :, ::-1], sample["right_eye_image"].numpy().transpose((1, 2, 0))[:, :, ::-1], 41 | face_bbox = sample["face_bbox"].numpy() 42 | left_eye_coord, right_eye_coord = sample["left_eye_coord"].numpy(), sample["right_eye_coord"].numpy() 43 | face_scale = sample["face_scale_factor"].item() 44 | left_eye_bbox, right_eye_bbox = sample["left_eye_bbox"].numpy(), sample["right_eye_bbox"].numpy() 45 | left_eye_bbox[:2] -= face_bbox[:2] 46 | left_eye_bbox[2:] -= face_bbox[:2] 47 | right_eye_bbox[:2] -= face_bbox[:2] 48 | right_eye_bbox[2:] -= face_bbox[:2] 49 | left_eye_coord -= face_bbox[:2] 50 | right_eye_coord -= face_bbox[:2] 51 | 52 | face_image = (face_image * 255).astype(np.uint8).copy() 53 | 54 | cv2.rectangle(face_image, tuple(np.int32(left_eye_bbox[:2] * face_scale).tolist()), tuple(np.int32(left_eye_bbox[2:] * face_scale).tolist()), (0, 255, 0)) 55 | cv2.rectangle(face_image, tuple(np.int32(right_eye_bbox[:2] * face_scale).tolist()), tuple(np.int32(right_eye_bbox[2:] * face_scale).tolist()), (0, 255, 0)) 56 | cv2.circle(face_image, tuple(np.int32(left_eye_coord * face_scale).tolist()), 3, [0, 0, 255], 1) 57 | cv2.circle(face_image, tuple(np.int32(right_eye_coord * face_scale).tolist()), 3, [0, 0, 255], 1) 58 | 59 | 
cv2.imshow("left_eye_image", left_eye_image) 60 | cv2.imshow("right_eye_image", right_eye_image) 61 | cv2.imshow("face_image", face_image) 62 | cv2.waitKey() 63 | 64 | # dataset = GazePointAllDataset(root_dir=r"D:\data\gaze", 65 | # transform=None, 66 | # phase='train', 67 | # face_depth=True) 68 | # min_depth = [] 69 | # max_depth = [] 70 | # mean_depth = [] 71 | # median_depth = [] 72 | # bad_samples = [] 73 | # for i in trange(len(dataset)): 74 | # sample = dataset[i] 75 | # face_depth = sample["face_depth"] 76 | # if np.sum((face_depth > 0) * (face_depth < 1024)) == 0: 77 | # bad_samples.append(i) 78 | # continue 79 | # median = np.median(face_depth[(face_depth > 0) * (face_depth < 1024)]) 80 | # if not (600 < median < 900): 81 | # bad_samples.append(i) 82 | # continue 83 | # min_depth.append(face_depth[face_depth > 0].min()) 84 | # max_depth.append(face_depth[face_depth < 1024].max()) 85 | # mean_depth.append(face_depth[(face_depth > 0) * (face_depth < 1024)].mean()) 86 | # median_depth.append(median) 87 | # pass 88 | # 89 | # with open("depth_stat.pkl", "wb+") as fp: 90 | # pickle.dump((bad_samples, min_depth, max_depth, mean_depth, median_depth), fp) 91 | # print(f"min: {min_depth}, max: {max_depth}\n") 92 | -------------------------------------------------------------------------------- /code/models/gaze_base_facepose.py: -------------------------------------------------------------------------------- 1 | from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet, model_urls 2 | import torch.utils.model_zoo as model_zoo 3 | from torch import nn 4 | import torch as th 5 | 6 | 7 | class ResNetEncoder(ResNet): 8 | def forward(self, x): 9 | x = self.conv1(x) 10 | x = self.bn1(x) 11 | x = self.relu(x) 12 | # x112_64 = x 13 | x = self.maxpool(x) 14 | x = self.layer1(x) 15 | # x56_64 = x 16 | x = self.layer2(x) 17 | # x28_128 = x 18 | x = self.layer3(x) 19 | # x14_256 = x 20 | x = self.layer4(x) 21 | x = self.avgpool(x) 22 | x = x.view(x.size(0), -1) 23 | x = self.relu(x) 24 | 25 | return x#, x112_64, x56_64, x28_128, x14_256 26 | 27 | 28 | def resnet18(pretrained=False, **kwargs): 29 | """Constructs a ResNet-18 model. 30 | 31 | Args: 32 | pretrained (bool): If True, returns a model pre-trained on ImageNet 33 | """ 34 | model = ResNetEncoder(BasicBlock, [2, 2, 2, 2], **kwargs) 35 | if pretrained: 36 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 37 | return model 38 | 39 | 40 | def resnet34(pretrained=False, **kwargs): 41 | """Constructs a ResNet-34 model. 42 | 43 | Args: 44 | pretrained (bool): If True, returns a model pre-trained on ImageNet 45 | """ 46 | model = ResNetEncoder(BasicBlock, [3, 4, 6, 3], **kwargs) 47 | if pretrained: 48 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 49 | return model 50 | 51 | 52 | def resnet50(pretrained=False, **kwargs): 53 | """Constructs a ResNet-50 model. 54 | 55 | Args: 56 | pretrained (bool): If True, returns a model pre-trained on ImageNet 57 | """ 58 | model = ResNetEncoder(Bottleneck, [3, 4, 6, 3], **kwargs) 59 | if pretrained: 60 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 61 | return model 62 | 63 | 64 | def resnet101(pretrained=False, **kwargs): 65 | """Constructs a ResNet-101 model. 
66 | 67 | Args: 68 | pretrained (bool): If True, returns a model pre-trained on ImageNet 69 | """ 70 | model = ResNetEncoder(Bottleneck, [3, 4, 23, 3], **kwargs) 71 | if pretrained: 72 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 73 | return model 74 | 75 | 76 | def resnet152(pretrained=False, **kwargs): 77 | """Constructs a ResNet-152 model. 78 | 79 | Args: 80 | pretrained (bool): If True, returns a model pre-trained on ImageNet 81 | """ 82 | model = ResNetEncoder(Bottleneck, [3, 8, 36, 3], **kwargs) 83 | if pretrained: 84 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 85 | return model 86 | 87 | 88 | class Decoder(nn.Module): 89 | def __init__(self, feat_dim=512): 90 | super(Decoder, self).__init__() 91 | self.decoder1 = nn.Sequential( 92 | nn.Linear(feat_dim, 256), 93 | nn.ReLU(), 94 | nn.Linear(256, 128), 95 | nn.ReLU(), 96 | ) 97 | self.decoder2 = nn.Sequential( 98 | nn.Linear(128 + 3, 2) 99 | ) 100 | 101 | def forward(self, feat, info): 102 | out = self.decoder1(feat) 103 | out = th.cat([out, info], 1) 104 | out = self.decoder2(out) 105 | return out 106 | 107 | 108 | class FDecoder(nn.Module): 109 | def __init__(self, feat_dim=512): 110 | super(FDecoder, self).__init__() 111 | self.decoder1 = nn.Sequential( 112 | nn.Linear(feat_dim, 256), 113 | nn.ReLU(), 114 | nn.Linear(256, 128), 115 | nn.ReLU(), 116 | ) 117 | self.decoder2 = nn.Sequential( 118 | nn.Linear(128 + 3, 2) 119 | ) 120 | 121 | def forward(self, feat, finfo): 122 | out = self.decoder1(feat) 123 | out = th.cat([out, finfo], 1) 124 | out = self.decoder2(out) 125 | return out 126 | -------------------------------------------------------------------------------- /code/gen_optim_samples.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch.nn import functional as F 3 | import numpy as np 4 | import cvxpy as cp 5 | import visdom 6 | from mpl_toolkits.mplot3d import Axes3D 7 | import matplotlib 8 | import matplotlib.pyplot as plt 9 | 10 | print(f"matplotlib version: {matplotlib.__version__}") 11 | 12 | # initialize dataset 13 | from torchvision import transforms 14 | from data.gaze_dataset_v2 import GazePointAllDataset 15 | data_transforms = { 16 | 'train': transforms.Compose([ 17 | transforms.Resize(224), 18 | transforms.ToTensor(), 19 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 20 | ]), 21 | 'val': transforms.Compose([ 22 | transforms.Resize(224), 23 | transforms.ToTensor(), 24 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 25 | ]) 26 | } 27 | trainset = GazePointAllDataset(root_dir=r'D:\data\gaze', 28 | transform=data_transforms['train'], 29 | phase='train', 30 | face_image=True, face_depth=True, eye_image=True, 31 | eye_depth=True, 32 | info=True, eye_bbox=True, face_bbox=True, eye_coord=True, 33 | landmark=True) 34 | print('The size of training data is: {}'.format(len(trainset))) 35 | 36 | 37 | def scatter_face_landmark(face_landmark, median, median_mask=None): 38 | if median_mask is None: 39 | median_mask = np.ones_like(median, dtype=np.int) 40 | fig = plt.figure() 41 | ax = fig.add_subplot(111, projection='3d') 42 | ax.set_xlabel('X') 43 | ax.set_ylabel('Y') 44 | ax.set_zlabel('D') 45 | xs = face_landmark[median_mask, 0] 46 | ys = face_landmark[median_mask, 1] 47 | zs = median[median_mask] 48 | ax.scatter(xs, ys, zs, c="r", marker="x") 49 | ax.view_init(-109, -64) 50 | plt.draw() 51 | plt.show() 52 | 53 | 54 | def extract_landmark_depth(depth, landmark, 
scale_factor, bbox, region_size=7): 55 | bs = depth.size(0) 56 | img_size = depth.size(3) 57 | assert depth.size(2) == depth.size(3) 58 | num_landmark = landmark.size(1) 59 | # transform landmarks to face image coordinate system (bs x lm x 2) 60 | face_lm_image = (landmark - bbox[:, :2].unsqueeze(1).to(depth)) * scale_factor.unsqueeze(1) 61 | face_lm = (face_lm_image / img_size * 2) - 1. 62 | 63 | # sample landmark region (bs x lm x lm_size x lm_size) 64 | # gen sample grid (bs x lm x lm_size x lm_size x 2) 65 | x = th.linspace(-region_size / 2, region_size / 2, region_size) / img_size * 2 66 | grid = th.stack(th.meshgrid([x, x])[::-1], dim=2).to(depth) 67 | grid = face_lm.view(bs, num_landmark, 1, 1, 2) + grid 68 | depth_landmark_regions = F.grid_sample( 69 | depth, grid.view(bs, num_landmark, -1, 2), mode="nearest", padding_mode="zeros" 70 | ).squeeze(1) 71 | 72 | # non-zero median 73 | depth_landmark_regions_sorted = th.sort(depth_landmark_regions, dim=2)[0] 74 | depth_landmark_regions_mask = depth_landmark_regions_sorted > 1.e-4 75 | depth_landmark_regions_mask[:, :, 0] = 0 76 | depth_landmark_regions_mask[:, :, 1:] = depth_landmark_regions_mask[:, :, 1:] - \ 77 | depth_landmark_regions_mask[:, :, :-1] 78 | depth_landmark_regions_mask[:, :, -1] = ((depth_landmark_regions_mask.sum(dim=2) == 0) + depth_landmark_regions_mask[:, :, -1]) > 0 79 | assert (depth_landmark_regions_mask.sum(dim=2) == 1).all(), f"{th.sum(depth_landmark_regions_mask, dim=2)}\n{depth_landmark_regions_mask[:, :, depth_landmark_regions_mask.size(2) - 1]}\n{depth_landmark_regions_mask.sum(dim=2) == 0}" 80 | nonzero_st = th.nonzero(depth_landmark_regions_mask) 81 | 82 | assert (nonzero_st[1:, 0] - nonzero_st[:-1, 0] >= 0).all() and \ 83 | ((nonzero_st[1:, 0] * num_landmark + nonzero_st[1:, 1]) - 84 | (nonzero_st[:-1, 0] * num_landmark + nonzero_st[:-1, 1]) >= 0).all() 85 | assert nonzero_st.size(0) == bs * num_landmark 86 | median_ind = ((nonzero_st[:, 2] + region_size * region_size - 1) / 2).long() 87 | depth_landmark_regions_sorted = depth_landmark_regions_sorted.view(bs * num_landmark, region_size * region_size) 88 | median = depth_landmark_regions_sorted[range(len(median_ind)), median_ind].view(bs, num_landmark) 89 | median_mask = median > 1.e-4 90 | 91 | return median, median_mask, face_lm_image 92 | 93 | 94 | def median_to_rel(median, median_mask, landmark): 95 | rel_median = median.view(68, 1).expand(68, 68) - median.view(1, 68).expand(68, 68) 96 | landmark_dist = th.norm(landmark.view(68, 1, 2).expand(68, 68, 2) - landmark.view(1, 68, 2).expand(68, 68, 2), dim=2) 97 | rel_median = rel_median / (landmark_dist + 1e-4) 98 | rel_median_mask = median_mask.view(68, 1).expand(68, 68) * median_mask.view(1, 68).expand(68, 68) 99 | return rel_median, rel_median_mask, landmark_dist 100 | 101 | 102 | sample = trainset[3917] 103 | face_image, face_depth, face_bbox, face_factor, face_landmark = \ 104 | sample['face_image'], \ 105 | sample['face_depth'], \ 106 | sample["face_bbox"], \ 107 | sample["face_scale_factor"], \ 108 | sample["face_landmark"] 109 | 110 | median, median_mask, face_lm_image = extract_landmark_depth(face_depth.unsqueeze(0), face_landmark.unsqueeze(0), 111 | face_factor.unsqueeze(0), face_bbox.unsqueeze(0), 7) 112 | median = median[0].numpy() * 500 + 500 113 | median_mask = np.abs(median - np.median(median)) < 90 114 | rel_median, rel_median_mask, landmark_dist = median_to_rel(th.from_numpy(median), 115 | th.from_numpy(median_mask.astype("float")), face_landmark) 116 | rel_median, rel_median_mask, 
landmark_dist = rel_median.numpy(), rel_median_mask.numpy(), landmark_dist.numpy() 117 | face_lm_image = face_lm_image[0].numpy() 118 | scatter_face_landmark(face_lm_image, median) 119 | # input("Press Enter to continue...") 120 | 121 | -------------------------------------------------------------------------------- /code/utils/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from contextlib import contextmanager 3 | import torch 4 | import torch.nn 5 | import logging 6 | import logging.handlers 7 | import time 8 | from .edict import edict 9 | 10 | 11 | class Trainer(object): 12 | def __init__(self, checkpoint_dir='./', is_cuda=True): 13 | self.checkpoint_dir = checkpoint_dir 14 | self.is_cuda = is_cuda 15 | self.temps = edict() 16 | self.extras = edict() 17 | self.meters = edict() 18 | self.models = edict() 19 | self._logger = logging.getLogger(self.__class__.__name__) 20 | # self.stream_handler = logging.StreamHandler(sys.stdout) 21 | # self.stream_handler.setFormatter(logging.Formatter('%(message)s')) 22 | # self._logger.addHandler(self.stream_handler) 23 | self.logger = self._logger 24 | self.time = 0 25 | self._time = time.time() 26 | 27 | @staticmethod 28 | def weights_init(m): 29 | classname = m.__class__.__name__ 30 | if classname.find('Conv') != -1: 31 | torch.nn.init.kaiming_normal_(m.weight.data) 32 | elif classname.find('BatchNorm') != -1: 33 | m.weight.data.fill_(1.) 34 | m.bias.data.fill_(1e-4) 35 | 36 | @staticmethod 37 | def _group_weight(module, lr): 38 | group_decay = [] 39 | group_no_decay = [] 40 | for m in module.modules(): 41 | if isinstance(m, torch.nn.Linear): 42 | group_decay.append(m.weight) 43 | if m.bias is not None: 44 | group_no_decay.append(m.bias) 45 | elif isinstance(m, torch.nn.modules.conv._ConvNd): 46 | group_decay.append(m.weight) 47 | if m.bias is not None: 48 | group_no_decay.append(m.bias) 49 | elif isinstance(m, torch.nn.modules.batchnorm._BatchNorm) or isinstance(m, torch.nn.GroupNorm): 50 | if m.weight is not None: 51 | group_no_decay.append(m.weight) 52 | if m.bias is not None: 53 | group_no_decay.append(m.bias) 54 | 55 | assert len(list( 56 | module.parameters())) == len(group_decay) + len(group_no_decay) 57 | groups = [ 58 | dict(params=group_decay, lr=lr), 59 | dict(params=group_no_decay, lr=lr, weight_decay=0.) 
60 | ] 61 | return groups 62 | 63 | def save_state_dict(self, filename): 64 | state_dict = edict() 65 | if not os.path.exists(self.checkpoint_dir): 66 | try: 67 | os.makedirs(self.checkpoint_dir) 68 | except OSError as exc: # Guard against race condition 69 | raise exc 70 | # save models 71 | state_dict.models = edict() 72 | for name, model in self.models.items(): 73 | if isinstance(model, torch.nn.DataParallel): 74 | model = model.module 75 | state_dict.models[name] = model.state_dict() 76 | # save meters 77 | state_dict.meters = edict() 78 | for name, meter in self.meters.items(): 79 | state_dict.meters[name] = meter 80 | # save extras 81 | state_dict.extras = edict() 82 | for name, extra in self.extras.items(): 83 | state_dict.extras[name] = extra 84 | 85 | path = os.path.join(self.checkpoint_dir, filename) 86 | torch.save(state_dict, path, pickle_protocol=4) 87 | return self 88 | 89 | def load_state_dict(self, filename): 90 | path = os.path.join(self.checkpoint_dir, filename) 91 | saved_dict = edict(torch.load(path, map_location=lambda storage, loc: storage)) 92 | # load models 93 | for name, model in saved_dict.models.items(): 94 | assert isinstance(self.models.get(name), torch.nn.Module) 95 | self.models[name] = self.models[name].cpu() 96 | self.models[name].load_state_dict(model) 97 | # load meters 98 | for name, meter in saved_dict.meters.items(): 99 | self.meters[name] = meter 100 | # load extras 101 | for name, extra in saved_dict.extras.items(): 102 | self.extras[name] = extra 103 | return self 104 | 105 | def _get_logger(self, name, propagate=False): 106 | child_logger = self._logger.getChild(name) 107 | child_logger.propagate = propagate 108 | return self._logger.getChild(name) 109 | 110 | def _timeit(self): 111 | self.time = time.time() - self._time 112 | self._time = time.time() 113 | return self.time 114 | 115 | @contextmanager 116 | def _freeze(self, model): 117 | cache = [] 118 | for param in model.parameters(): 119 | cache.append(param.requires_grad) 120 | param.requires_grad = False 121 | yield 122 | for param in model.parameters(): 123 | param.requires_grad = cache.pop(0) 124 | 125 | def add_logger_handler(self, handler): 126 | self._logger.addHandler(handler) 127 | return self 128 | 129 | def timeit(self): 130 | self.time = time.time() - self._time 131 | self._time = time.time() 132 | return self 133 | 134 | def print(self, attr: str=None): 135 | if attr is None: 136 | print(self, flush=True) 137 | return self 138 | attrs = attr.split('.') 139 | obj = self 140 | for attr in attrs: 141 | obj = getattr(obj, attr) 142 | print(obj, flush=True) 143 | return self 144 | 145 | def chain_op(self, **kwargs): 146 | for op, args in kwargs: 147 | f = getattr(self, op, default=None) 148 | if f is None: 149 | raise ValueError('unrecognized op "%s"' % op) 150 | f(**args) 151 | return self 152 | 153 | def end(self): 154 | pass 155 | 156 | 157 | class CLSTrainer(Trainer): 158 | def accuracy(self, output, target, topk=(1,)): 159 | """Computes the precision@k for the specified values of k""" 160 | maxk = max(topk) 161 | batch_size = target.size(0) 162 | _, pred = output.topk(maxk, 1, True, True) 163 | pred = pred.t() 164 | correct = pred.eq(target.view(1, -1).expand_as(pred).type_as(pred)) 165 | res = [] 166 | for k in topk: 167 | correct_k = correct[:k].view(-1).float().sum(0).tolist()[0] 168 | res.append(correct_k / batch_size) 169 | return res 170 | -------------------------------------------------------------------------------- /code/.idea/misc.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 39 | 40 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /code/utils/gen_filelist.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import numpy as np 4 | import pandas as pd 5 | 6 | 7 | def gen_filelist(data_root, out_root=None, log_root=None): 8 | rgb_dir = "color" 9 | gt_dir = "coordinate" 10 | eye_coord_dir = "eyecoordinate" 11 | depth_dir = "projected_depth_calibration" 12 | landmark_dir = "landmark" 13 | 14 | size_train_set = 159 15 | 16 | anno_filedict = dict( 17 | face_image=[], 18 | face_depth=[], 19 | face_landmark=[], 20 | face_bbox=[], 21 | left_eye_image=[], 22 | right_eye_image=[], 23 | left_eye_depth=[], 24 | right_eye_depth=[], 25 | left_eye_bbox=[], 26 | right_eye_bbox=[], 27 | left_eye_coord=[], 28 | right_eye_coord=[], 29 | gaze_point=[], 30 | is_train=[], 31 | has_landmark=[] 32 | ) 33 | 34 | # filter person id 35 | rgb_ids = set(os.listdir(os.path.join(data_root, rgb_dir))) 36 | gt_ids = set(os.listdir(os.path.join(data_root, gt_dir))) 37 | eye_coord_ids = set(os.listdir(os.path.join(data_root, eye_coord_dir))) 38 | depth_ids = set(os.listdir(os.path.join(data_root, depth_dir))) 39 | log_root = data_root if log_root is None else log_root 40 | 41 | valid_pids = rgb_ids.intersection(gt_ids).intersection(eye_coord_ids).intersection(depth_ids) 42 | print(f"pids valid/all: {len(valid_pids)}/{len(rgb_ids)} ") 43 | 44 | # filter samples 45 | valid_sample_ids = {} 46 | has_landmark = {} 47 | for pid in tqdm(valid_pids, desc="filtering samples"): 48 | valid_sample_ids[pid] = set() 49 | samples = [sample for sample in os.listdir(os.path.join(data_root, gt_dir, pid)) if sample.endswith(".npy")] 50 | for sample in samples: 51 | is_valid = True 52 | sample_id = f"{int(sample[3:8]) + 1:05d}" 53 | err_msg = f"[pid={pid}, sid={sample_id}]\n" 54 | has_landmark[(pid, sample_id)] = True 55 | # verify face bbox 56 | if not os.path.isfile(os.path.join(data_root, rgb_dir, pid, "color" + sample_id + "_face.txt")): 57 | err_msg += f'missing {os.path.join(rgb_dir, pid, "color" + sample_id + "_face.txt")}\n' 58 | is_valid = False 59 | # verify face landmark 60 | if not os.path.isfile(os.path.join(data_root, landmark_dir, pid, sample_id + ".npy")): 61 | err_msg += f'missing {os.path.join(data_root, landmark_dir, pid, sample_id + ".npy")}\n' 62 | has_landmark[(pid, sample_id)] = False 63 | # verify left eye bbox 64 | if not os.path.isfile(os.path.join(data_root, rgb_dir, pid, "color" + sample_id + "_left_eye.txt")): 65 | err_msg += f'missing {os.path.join(rgb_dir, pid, "color" + sample_id + "_left_eye.txt")}\n' 66 | is_valid = False 67 | # verify left eye image 68 | if not os.path.isfile(os.path.join(data_root, rgb_dir, pid, "color" + sample_id + "_lefteye.jpg")): 69 | err_msg += f'missing {os.path.join(rgb_dir, pid, "color" + sample_id + "_lefteye.jpg")}\n' 70 | is_valid = False 71 | # verify right eye bbox 72 | if not os.path.isfile(os.path.join(data_root, rgb_dir, pid, "color" + sample_id + "_right_eye.txt")): 73 | err_msg += f'missing {os.path.join(rgb_dir, pid, "color" + sample_id + "_right_eye.txt")}\n' 74 | is_valid = False 75 | # verify right eye image 76 | if not os.path.isfile(os.path.join(data_root, rgb_dir, 
pid, "color" + sample_id + "_righteye.jpg")): 77 | err_msg += f'missing {os.path.join(rgb_dir, pid, "color" + sample_id + "_righteye.jpg")}\n' 78 | is_valid = False 79 | # verify gaze gt 80 | if not os.path.isfile(os.path.join(data_root, gt_dir, pid, "xy_" + "{:05d}".format(int(sample_id) - 1) + ".npy")): 81 | err_msg += f'missing {os.path.join(gt_dir, pid, "xy_" + "{:05d}".format(int(sample_id) - 1) + ".npy")}\n' 82 | is_valid = False 83 | # verify left eye coordinate 84 | if not os.path.isfile(os.path.join(data_root, eye_coord_dir, pid, sample_id + "_le.npy")): 85 | err_msg += f'missing {os.path.join(eye_coord_dir, pid, sample_id + "_le.npy")}\n' 86 | is_valid = False 87 | # verify right eye coordinate 88 | if not os.path.isfile(os.path.join(data_root, eye_coord_dir, pid, sample_id + "_re.npy")): 89 | err_msg += f'missing {os.path.join(eye_coord_dir, pid, sample_id + "_re.npy")}\n' 90 | is_valid = False 91 | # verify face depth 92 | if not os.path.isfile(os.path.join(data_root, depth_dir, pid, "projected_depth" + sample_id + "_face.png")): 93 | err_msg += f'missing {os.path.join(depth_dir, pid, "projected_depth" + sample_id + "_face.png")}\n' 94 | is_valid = False 95 | # verify left eye depth 96 | if not os.path.isfile( 97 | os.path.join(data_root, depth_dir, pid, "projected_depth" + sample_id + "_lefteye.png")): 98 | err_msg += f'missing {os.path.join(depth_dir, pid, "projected_depth" + sample_id + "_lefteye.png")}\n' 99 | is_valid = False 100 | # verify right eye depth 101 | if not os.path.isfile( 102 | os.path.join(data_root, depth_dir, pid, "projected_depth" + sample_id + "_righteye.png")): 103 | err_msg += f'missing {os.path.join(depth_dir, pid, "projected_depth" + sample_id + "_righteye.png")}\n' 104 | is_valid = False 105 | if is_valid: 106 | valid_sample_ids[pid].add(sample_id) 107 | else: 108 | with open(os.path.join(log_root, "error.log"), "a+") as fp: 109 | print(err_msg, file=fp) 110 | if len(valid_sample_ids[pid]) == 0: 111 | valid_sample_ids.pop(pid) 112 | print(f"warn: pid={pid} has no valid samples, ignored.") 113 | 114 | id_list = [] 115 | pids = sorted(valid_sample_ids.keys()) 116 | for i, pid in enumerate(tqdm(pids, desc="preparing filelist")): 117 | sample_ids = sorted(valid_sample_ids[pid]) 118 | for sample_id in sample_ids: 119 | # sample_id = f"{int(sample[3:8]) + 1:05d}" 120 | id_list.append(pid + sample_id) 121 | anno_filedict["face_image"].append(os.path.join(rgb_dir, pid, "color" + sample_id + "_face.jpg")) 122 | anno_filedict["face_depth"].append( 123 | os.path.join(depth_dir, pid, "projected_depth" + sample_id + "_face.png")) 124 | anno_filedict["has_landmark"].append(has_landmark[(pid, sample_id)]) 125 | anno_filedict["face_landmark"].append(os.path.join(landmark_dir, pid, sample_id + ".npy")) 126 | anno_filedict["face_bbox"].append(os.path.join(rgb_dir, pid, "color" + sample_id + "_face.txt")) 127 | anno_filedict["left_eye_image"].append(os.path.join(rgb_dir, pid, "color" + sample_id + "_lefteye.jpg")) 128 | anno_filedict["right_eye_image"].append(os.path.join(rgb_dir, pid, "color" + sample_id + "_righteye.jpg")) 129 | anno_filedict["left_eye_depth"].append( 130 | os.path.join(depth_dir, pid, "projected_depth" + sample_id + "_lefteye.png")) 131 | anno_filedict["right_eye_depth"].append( 132 | os.path.join(depth_dir, pid, "projected_depth" + sample_id + "_righteye.png")) 133 | anno_filedict["left_eye_bbox"].append(os.path.join(rgb_dir, pid, "color" + sample_id + "_left_eye.txt")) 134 | anno_filedict["right_eye_bbox"].append(os.path.join(rgb_dir, pid, "color" 
+ sample_id + "_right_eye.txt")) 135 | anno_filedict["left_eye_coord"].append(os.path.join(eye_coord_dir, pid, sample_id + "_le.npy")) 136 | anno_filedict["right_eye_coord"].append(os.path.join(eye_coord_dir, pid, sample_id + "_re.npy")) 137 | anno_filedict["gaze_point"].append(os.path.join(gt_dir, pid, "xy_" + "{:05d}".format(int(sample_id) - 1) + ".npy")) 138 | if i <= size_train_set: 139 | anno_filedict["is_train"].append(1) 140 | else: 141 | anno_filedict["is_train"].append(0) 142 | 143 | df = pd.DataFrame(data=anno_filedict, index=id_list) 144 | if out_root is None: 145 | out_root = data_root 146 | os.makedirs(out_root, exist_ok=True) 147 | # write training file list 148 | df[df["is_train"] == 1].to_csv(os.path.join(out_root, "train_filelist.csv"),) 149 | df[df["is_train"] == 0].to_csv(os.path.join(out_root, "val_filelist.csv")) 150 | 151 | 152 | if __name__ == '__main__': 153 | # gen_filelist(r"/home/ziheng/datasets/gaze") 154 | gen_filelist(r'D:\data\gaze') 155 | -------------------------------------------------------------------------------- /code/data/gaze_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from PIL import Image 4 | import numpy as np 5 | import pandas as pd 6 | import cv2 7 | 8 | import torch as th 9 | from torch.utils import data 10 | from torchvision import transforms as tf 11 | 12 | from time import time 13 | import pickle 14 | 15 | 16 | class GazePointAllDataset(data.Dataset): 17 | def __init__(self, root_dir, w_screen=59.77, h_screen=33.62, transform=None, phase="train", **kwargs): 18 | self.root_dir = root_dir 19 | self.w_screen = w_screen 20 | self.h_screen = h_screen 21 | self.transform = transform 22 | self.kwargs = kwargs 23 | self.anno = pd.read_csv(os.path.join(root_dir, phase + "_meta.csv"), index_col=0) 24 | # if os.path.isfile(os.path.join(root_dir, "depth_stat.pkl")): 25 | # with open(os.path.join(root_dir, "depth_stat.pkl"), "rb") as fp: 26 | # stat = pickle.load(fp) 27 | # anno.drop(anno.iloc[stat[0]]) 28 | root_dir = root_dir.rstrip("/").rstrip("\\") 29 | self.face_image_list = (root_dir + "/" + self.anno["face_image"]).tolist() 30 | self.face_depth_list = (root_dir + "/" + self.anno["face_depth"]).tolist() 31 | self.face_bbox_list = (root_dir + "/" + self.anno["face_bbox"]).tolist() 32 | self.le_image_list = (root_dir + "/" + self.anno["left_eye_image"]).tolist() 33 | self.re_image_list = (root_dir + "/" + self.anno["right_eye_image"]).tolist() 34 | self.le_depth_list = (root_dir + "/" + self.anno["left_eye_depth"]).tolist() 35 | self.re_depth_list = (root_dir + "/" + self.anno["right_eye_depth"]).tolist() 36 | self.le_bbox_list = (root_dir + "/" + self.anno["left_eye_bbox"]).tolist() 37 | self.re_bbox_list = (root_dir + "/" + self.anno["right_eye_bbox"]).tolist() 38 | # self.le_coord_list = (root_dir + "/" + self.anno["left_eye_coord"]).tolist() 39 | # self.re_coord_list = (root_dir + "/" + self.anno["right_eye_coord"]).tolist() 40 | self.gt_name_list = (root_dir + "/" + self.anno["gaze_point"]).tolist() 41 | 42 | for data_item in kwargs.keys(): 43 | if data_item not in ("face_image", "face_depth", "eye_image", "eye_depth", 44 | "face_bbox", "eye_bbox", "gt", "eye_coord", "info"): 45 | raise ValueError(f"unrecognized dataset item: {data_item}") 46 | 47 | def __len__(self): 48 | return len(self.face_image_list) 49 | 50 | def __getitem__(self, idx): 51 | with open(self.le_bbox_list[idx]) as fp: 52 | le_bbox = list(map(float, fp.readline().split())) 53 | with 
open(self.re_bbox_list[idx]) as fp: 54 | re_bbox = list(map(float, fp.readline().split())) 55 | with open(self.face_bbox_list[idx]) as fp: 56 | face_bbox = list(map(float, fp.readline().split())) 57 | 58 | le_coor = np.load(self.le_coord_list[idx]) 59 | re_coor = np.load(self.re_coord_list[idx]) 60 | gt = np.load(self.gt_name_list[idx]) 61 | 62 | gt[0] -= self.w_screen / 2 63 | gt[1] -= self.h_screen / 2 64 | 65 | sample = {} 66 | 67 | sample["index"] = th.LongTensor([idx]) 68 | index = f"{self.anno.index[idx]:010d}" 69 | sample["pid"] = th.LongTensor([int(index[:5])]) 70 | sample["sid"] = th.LongTensor([int(index[5:])]) 71 | 72 | sample['gt'] = th.FloatTensor(gt) 73 | 74 | if self.kwargs.get('face_image'): 75 | face_image = Image.open(self.face_image_list[idx]) 76 | sample['face_image'] = self.transform(face_image) if self.transform is not None else face_image 77 | 78 | if self.kwargs.get('face_depth'): 79 | assert np.abs((face_bbox[3] - face_bbox[1]) - (face_bbox[2] - face_bbox[0])) <= 2, f"invalid face bbox @ {self.face_bbox_list[idx]}" 80 | scale_factor = 224 / (face_bbox[2] - face_bbox[0]) 81 | # scale_factor = min(scale_factor, 1.004484) 82 | # scale_factor = max(scale_factor, 0.581818) 83 | face_depth = cv2.imread(self.face_depth_list[idx], -1) 84 | # face_depth = np.int32(face_depth) 85 | # face_depth[face_depth<500] = 500 86 | # face_depth[face_depth > 1023] = 1023 87 | # face_depth -= 512 88 | if self.transform is not None: 89 | face_depth = face_depth[np.newaxis, :, :]# / scale_factor 90 | # sample['face_depth'] = th.FloatTensor(face_depth / 883) 91 | sample['face_depth'] = th.clamp((th.FloatTensor(face_depth.astype('float')) - 500) / 500, 0., 1.) 92 | sample['face_scale_factor'] = th.FloatTensor([scale_factor]) 93 | else: 94 | sample['face_depth'] = face_depth 95 | # print('max: {}, min:{}'.format((face_depth / 430).max(), (face_depth / 430).min()), flush=True) 96 | 97 | if self.kwargs.get('eye_image'): 98 | le_image = Image.open(self.le_image_list[idx]) 99 | re_image = Image.open(self.re_image_list[idx]) 100 | sample['left_eye_image'] = self.transform(le_image) if self.transform is not None else le_image 101 | sample['right_eye_image'] = self.transform(re_image) if self.transform is not None else re_image 102 | 103 | if self.kwargs.get('eye_depth'): 104 | le_depth = cv2.imread(self.le_depth_list[idx], -1) 105 | re_depth = cv2.imread(self.re_depth_list[idx], -1) 106 | if self.transform is not None: 107 | le_depth = le_depth[np.newaxis, :, :].astype('float') # / le_scale_factor # the new dim is the dim with np.newaxis 108 | re_depth = re_depth[np.newaxis, :, :].astype('float') # / re_scale_factor 109 | # sample['left_depth'] = torch.FloatTensor(le_depth/1000) 110 | # sample['right_depth'] = torch.FloatTensor(re_depth/1000) 111 | sample['left_eye_depth'] = th.FloatTensor(le_depth) 112 | sample['right_eye_depth'] = th.FloatTensor(re_depth) 113 | else: 114 | sample['left_eye_depth'] = le_depth 115 | sample['right_eye_depth'] = re_depth 116 | 117 | if self.kwargs.get('eye_bbox'): 118 | assert le_bbox[3] - le_bbox[1] == le_bbox[2] - le_bbox[0], f"invalid left eye bbox @ {self.le_bbox_list[idx]}" 119 | le_scale_factor = 224 / (le_bbox[2] - le_bbox[0]) 120 | # le_scale_factor = min(le_scale_factor, 1.004484) 121 | # le_scale_factor = max(le_scale_factor, 0.581818) 122 | assert re_bbox[3] - re_bbox[1] == re_bbox[2] - re_bbox[0], f"invalid right eye bbox @ {self.re_bbox_list[idx]}" 123 | re_scale_factor = 224 / (re_bbox[2] - re_bbox[0]) 124 | # re_scale_factor = min(re_scale_factor, 
1.004484) 125 | # re_scale_factor = max(re_scale_factor, 0.581818) 126 | sample["left_eye_scale_factor"] = th.FloatTensor([le_scale_factor]) 127 | sample["right_eye_scale_factor"] = th.FloatTensor([re_scale_factor]) 128 | sample['left_eye_bbox'] = th.FloatTensor(le_bbox) 129 | sample['right_eye_bbox'] = th.FloatTensor(re_bbox) 130 | 131 | if self.kwargs.get('face_bbox'): 132 | sample['face_bbox'] = th.FloatTensor(face_bbox) 133 | 134 | if self.kwargs.get('eye_coord'): 135 | sample['left_eye_coord'] = th.FloatTensor(np.float32(le_coor)) 136 | sample['right_eye_coord'] = th.FloatTensor(np.float32(re_coor)) 137 | 138 | if self.kwargs.get('info'): 139 | le_depth = np.clip((cv2.imread(self.le_depth_list[idx], -1) - 500) / 500, 0, 1.) 140 | re_depth = np.clip((cv2.imread(self.le_depth_list[idx], -1) - 500) / 500, 0, 1.) 141 | # get info 142 | le_depth_ = le_depth[le_depth > 0] 143 | if len(le_depth_) > 0: 144 | le_info = [le_coor[0] / 1920, le_coor[1] / 1080, np.mean(le_depth_)] 145 | else: 146 | le_info = [le_coor[0] / 1920, le_coor[1] / 1080] + [0.] 147 | 148 | re_depth_ = re_depth[re_depth > 0] 149 | if len(re_depth_) > 0: 150 | re_info = [re_coor[0] / 1920, re_coor[1] / 1080, np.mean(re_depth_)] 151 | else: 152 | re_info = [re_coor[0] / 1920, re_coor[1] / 1080] + [0.] 153 | sample['left_eye_info'] = th.FloatTensor(le_info) 154 | sample['right_eye_info'] = th.FloatTensor(re_info) 155 | 156 | return sample 157 | 158 | 159 | if __name__ == '__main__': 160 | from tqdm import tqdm 161 | dataset = GazePointAllDataset( 162 | root_dir=r"D:\\data\\gaze", 163 | phase="train", 164 | face_image=True, 165 | face_depth=True, 166 | face_bbox=True, 167 | eye_image=True, 168 | eye_depth=True, 169 | eye_bbox=True, 170 | eye_coord=True 171 | ) 172 | 173 | for sample in tqdm(dataset, desc="testing"): 174 | pass 175 | -------------------------------------------------------------------------------- /code/data/gaze_dataset_v2.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from PIL import Image 4 | import numpy as np 5 | import pandas as pd 6 | import cv2 7 | 8 | import torch as th 9 | from torch.utils import data 10 | from torchvision import transforms as tf 11 | 12 | from time import time 13 | import pickle 14 | 15 | 16 | class GazePointAllDataset(data.Dataset): 17 | def __init__(self, root_dir, w_screen=59.77, h_screen=33.62, transform=None, phase="train", **kwargs): 18 | self.root_dir = root_dir 19 | self.w_screen = w_screen 20 | self.h_screen = h_screen 21 | self.transform = transform 22 | self.kwargs = kwargs 23 | anno = pd.read_csv(os.path.join(root_dir, phase + "_filelist.csv"), index_col=0) 24 | if "landmark" in kwargs.keys(): 25 | anno = anno[anno["has_landmark"]] 26 | # if os.path.isfile(os.path.join(root_dir, "depth_stat.pkl")): 27 | # with open(os.path.join(root_dir, "depth_stat.pkl"), "rb") as fp: 28 | # stat = pickle.load(fp) 29 | # anno.drop(anno.index[stat[0]]) 30 | root_dir = root_dir.rstrip("/").rstrip("\\") 31 | self.face_image_list = (root_dir + "/" + anno["face_image"]).tolist() 32 | self.face_depth_list = (root_dir + "/" + anno["face_depth"]).tolist() 33 | self.face_bbox_list = (root_dir + "/" + anno["face_bbox"]).tolist() 34 | self.le_image_list = (root_dir + "/" + anno["left_eye_image"]).tolist() 35 | self.re_image_list = (root_dir + "/" + anno["right_eye_image"]).tolist() 36 | self.le_depth_list = (root_dir + "/" + anno["left_eye_depth"]).tolist() 37 | self.re_depth_list = (root_dir + "/" + anno["right_eye_depth"]).tolist() 38 | 
self.le_bbox_list = (root_dir + "/" + anno["left_eye_bbox"]).tolist() 39 | self.re_bbox_list = (root_dir + "/" + anno["right_eye_bbox"]).tolist() 40 | self.le_coord_list = (root_dir + "/" + anno["left_eye_coord"]).tolist() 41 | self.re_coord_list = (root_dir + "/" + anno["right_eye_coord"]).tolist() 42 | self.gt_name_list = (root_dir + "/" + anno["gaze_point"]).tolist() 43 | if "landmark" in kwargs.keys(): 44 | self.face_landmark = (root_dir + "/" + anno["face_landmark"]).tolist() 45 | 46 | for data_item in kwargs.keys(): 47 | if data_item not in ("face_image", "face_depth", "eye_image", "eye_depth", 48 | "face_bbox", "eye_bbox", "gt", "eye_coord", "info", "landmark"): 49 | raise ValueError(f"unrecognized dataset item: {data_item}") 50 | 51 | def __len__(self): 52 | return len(self.face_image_list) 53 | 54 | def __getitem__(self, idx): 55 | with open(self.le_bbox_list[idx]) as fp: 56 | le_bbox = list(map(float, fp.readline().split())) 57 | with open(self.re_bbox_list[idx]) as fp: 58 | re_bbox = list(map(float, fp.readline().split())) 59 | with open(self.face_bbox_list[idx]) as fp: 60 | face_bbox = list(map(float, fp.readline().split())) 61 | 62 | le_coor = np.load(self.le_coord_list[idx]) 63 | re_coor = np.load(self.re_coord_list[idx]) 64 | gt = np.load(self.gt_name_list[idx]) 65 | 66 | gt[0] -= self.w_screen / 2 67 | gt[1] -= self.h_screen / 2 68 | 69 | sample = {} 70 | 71 | sample["index"] = th.FloatTensor([idx]) 72 | 73 | sample['gt'] = th.FloatTensor(gt) 74 | 75 | if self.kwargs.get('face_image'): 76 | face_image = Image.open(self.face_image_list[idx]) 77 | sample['face_image'] = self.transform(face_image) if self.transform is not None else face_image 78 | 79 | if self.kwargs.get('face_depth'): 80 | assert np.abs((face_bbox[3] - face_bbox[1]) - (face_bbox[2] - face_bbox[0])) <= 2, f"invalid face bbox @ {self.face_bbox_list[idx]}" 81 | scale_factor = 224 / (face_bbox[2] - face_bbox[0]) 82 | # scale_factor = min(scale_factor, 1.004484) 83 | # scale_factor = max(scale_factor, 0.581818) 84 | face_depth = cv2.imread(self.face_depth_list[idx], -1) 85 | # face_depth = np.int32(face_depth) 86 | # face_depth[face_depth<500] = 500 87 | # face_depth[face_depth > 1023] = 1023 88 | # face_depth -= 512 89 | if self.transform is not None: 90 | face_depth = face_depth[np.newaxis, :, :] # / scale_factor 91 | # sample['face_depth'] = th.FloatTensor(face_depth / 883) 92 | sample['face_depth'] = th.clamp((th.FloatTensor(face_depth.astype('float')) - 500) / 500, 0., 1.) 
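# The clamp above normalizes the raw depth crop (stored, it seems, in millimetres):
# (d - 500) / 500 clipped to [0, 1] maps the 500-1000 mm working range onto [0, 1], so
# missing readings (value 0) and anything nearer than 500 mm collapse to 0, while anything
# beyond 1000 mm saturates at 1.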
93 | sample['face_scale_factor'] = th.FloatTensor([scale_factor]) 94 | else: 95 | sample['face_depth'] = face_depth 96 | # print('max: {}, min:{}'.format((face_depth / 430).max(), (face_depth / 430).min()), flush=True) 97 | 98 | if self.kwargs.get('eye_image'): 99 | le_image = Image.open(self.le_image_list[idx]) 100 | re_image = Image.open(self.re_image_list[idx]) 101 | sample['left_eye_image'] = self.transform(le_image) if self.transform is not None else le_image 102 | sample['right_eye_image'] = self.transform(re_image) if self.transform is not None else re_image 103 | 104 | if self.kwargs.get('eye_depth'): 105 | le_depth = cv2.imread(self.le_depth_list[idx], -1) 106 | re_depth = cv2.imread(self.re_depth_list[idx], -1) 107 | if self.transform is not None: 108 | le_depth = le_depth[np.newaxis, :, :].astype('float') # / le_scale_factor # the new dim is the dim with np.newaxis 109 | re_depth = re_depth[np.newaxis, :, :].astype('float') # / re_scale_factor 110 | # sample['left_depth'] = torch.FloatTensor(le_depth/1000) 111 | # sample['right_depth'] = torch.FloatTensor(re_depth/1000) 112 | sample['left_eye_depth'] = th.FloatTensor(le_depth) 113 | sample['right_eye_depth'] = th.FloatTensor(re_depth) 114 | else: 115 | sample['left_eye_depth'] = le_depth 116 | sample['right_eye_depth'] = re_depth 117 | 118 | if self.kwargs.get('landmark'): 119 | landmark = np.load(self.face_landmark[idx]) 120 | sample['face_landmark'] = th.FloatTensor(landmark) 121 | 122 | if self.kwargs.get('eye_bbox'): 123 | assert le_bbox[3] - le_bbox[1] == le_bbox[2] - le_bbox[0], f"invalid left eye bbox @ {self.le_bbox_list[idx]}" 124 | le_scale_factor = 224 / (le_bbox[2] - le_bbox[0]) 125 | # le_scale_factor = min(le_scale_factor, 1.004484) 126 | # le_scale_factor = max(le_scale_factor, 0.581818) 127 | assert re_bbox[3] - re_bbox[1] == re_bbox[2] - re_bbox[0], f"invalid right eye bbox @ {self.re_bbox_list[idx]}" 128 | re_scale_factor = 224 / (re_bbox[2] - re_bbox[0]) 129 | # re_scale_factor = min(re_scale_factor, 1.004484) 130 | # re_scale_factor = max(re_scale_factor, 0.581818) 131 | sample["left_eye_scale_factor"] = th.FloatTensor([le_scale_factor]) 132 | sample["right_eye_scale_factor"] = th.FloatTensor([re_scale_factor]) 133 | sample['left_eye_bbox'] = th.FloatTensor(le_bbox) 134 | sample['right_eye_bbox'] = th.FloatTensor(re_bbox) 135 | 136 | if self.kwargs.get('face_bbox'): 137 | sample['face_bbox'] = th.FloatTensor(face_bbox) 138 | 139 | if self.kwargs.get('eye_coord'): 140 | sample['left_eye_coord'] = th.FloatTensor(np.float32(le_coor)) 141 | sample['right_eye_coord'] = th.FloatTensor(np.float32(re_coor)) 142 | 143 | if self.kwargs.get('info'): 144 | le_depth = np.clip((cv2.imread(self.le_depth_list[idx], -1) - 500) / 500, 0, 1.) 145 | re_depth = np.clip((cv2.imread(self.le_depth_list[idx], -1) - 500) / 500, 0, 1.) 146 | face_depth = np.clip((cv2.imread(self.face_depth_list[idx], -1) - 500) / 500, 0, 1.) 
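# The block below builds a 3-D "info" vector per eye: the eye-centre coordinates normalized
# by the full colour frame (presumably 1920x1080), plus the median of the valid (non-zero)
# normalized depths in the eye crop, falling back to the median valid face depth when the
# eye crop contains no valid readings.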
147 | # get info 148 | le_depth_ = le_depth[le_depth > 0] 149 | face_depth = face_depth[face_depth > 0] 150 | if len(le_depth_) > 0: 151 | le_info = [le_coor[0] / 1920, le_coor[1] / 1080, np.median(le_depth_)] 152 | else: 153 | le_info = [le_coor[0] / 1920, le_coor[1] / 1080, np.median(face_depth)] 154 | 155 | re_depth_ = re_depth[re_depth > 0] 156 | if len(re_depth_) > 0: 157 | re_info = [re_coor[0] / 1920, re_coor[1] / 1080, np.median(re_depth_)] 158 | else: 159 | re_info = [re_coor[0] / 1920, re_coor[1] / 1080, np.median(face_depth)] 160 | sample['left_eye_info'] = th.FloatTensor(le_info) 161 | sample['right_eye_info'] = th.FloatTensor(re_info) 162 | 163 | return sample 164 | 165 | 166 | if __name__ == '__main__': 167 | from tqdm import tqdm 168 | dataset = GazePointAllDataset( 169 | root_dir=r"D:\\data\\gaze", 170 | phase="train", 171 | face_image=True, 172 | face_depth=True, 173 | face_bbox=True, 174 | eye_image=True, 175 | eye_depth=True, 176 | eye_bbox=True, 177 | eye_coord=True 178 | ) 179 | 180 | for sample in tqdm(dataset, desc="testing"): 181 | pass 182 | -------------------------------------------------------------------------------- /code/models/gaze_depth.py: -------------------------------------------------------------------------------- 1 | from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet, model_urls 2 | import torch.utils.model_zoo as model_zoo 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import torch as th 6 | import pdb 7 | import cv2 8 | import numpy as np 9 | 10 | 11 | class ResNetEncoder(ResNet): 12 | def forward(self, x): 13 | x = self.conv1(x) 14 | x = self.bn1(x) 15 | x = self.relu(x) 16 | # x112_64 = x 17 | x = self.maxpool(x) 18 | x = self.layer1(x) 19 | # x56_64 = x 20 | x = self.layer2(x) 21 | # x28_128 = x 22 | x = self.layer3(x) 23 | # x14_256 = x 24 | x = self.layer4(x) 25 | x = self.avgpool(x) 26 | x = x.view(x.size(0), -1) 27 | x = self.relu(x) 28 | 29 | return x # , x112_64, x56_64, x28_128, x14_256 30 | 31 | 32 | def resnet18(pretrained=False, **kwargs): 33 | """Constructs a ResNet-18 model. 34 | 35 | Args: 36 | pretrained (bool): If True, returns a model pre-trained on ImageNet 37 | """ 38 | model = ResNetEncoder(BasicBlock, [2, 2, 2, 2], **kwargs) 39 | if pretrained: 40 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 41 | return model 42 | 43 | 44 | def resnet34(pretrained=False, **kwargs): 45 | """Constructs a ResNet-34 model. 46 | 47 | Args: 48 | pretrained (bool): If True, returns a model pre-trained on ImageNet 49 | """ 50 | model = ResNetEncoder(BasicBlock, [3, 4, 6, 3], **kwargs) 51 | if pretrained: 52 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 53 | return model 54 | 55 | 56 | def resnet50(pretrained=False, **kwargs): 57 | """Constructs a ResNet-50 model. 58 | 59 | Args: 60 | pretrained (bool): If True, returns a model pre-trained on ImageNet 61 | """ 62 | model = ResNetEncoder(Bottleneck, [3, 4, 6, 3], **kwargs) 63 | if pretrained: 64 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 65 | return model 66 | 67 | 68 | def resnet101(pretrained=False, **kwargs): 69 | """Constructs a ResNet-101 model. 
70 | 71 | Args: 72 | pretrained (bool): If True, returns a model pre-trained on ImageNet 73 | """ 74 | model = ResNetEncoder(Bottleneck, [3, 4, 23, 3], **kwargs) 75 | if pretrained: 76 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 77 | return model 78 | 79 | 80 | def resnet152(pretrained=False, **kwargs): 81 | """Constructs a ResNet-152 model. 82 | 83 | Args: 84 | pretrained (bool): If True, returns a model pre-trained on ImageNet 85 | """ 86 | model = ResNetEncoder(Bottleneck, [3, 8, 36, 3], **kwargs) 87 | if pretrained: 88 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 89 | return model 90 | 91 | 92 | class Depth2AbsDepth(nn.Module): 93 | pass 94 | 95 | 96 | class RGBD2AbsDepth(nn.Module): 97 | pass 98 | 99 | 100 | class RGB2AbsDepth(nn.Module): 101 | pass 102 | 103 | 104 | class Depth2RelDepth(nn.Module): 105 | pass 106 | 107 | 108 | def extract_landmark_depth(depth, landmark, scale_factor, bbox, region_size=7): 109 | bs = depth.size(0) 110 | img_size = depth.size(3) 111 | assert depth.size(2) == depth.size(3) 112 | num_landmark = landmark.size(1) 113 | # transform landmarks to face image coordinate system (bs x lm x 2) 114 | face_lm = ((landmark - bbox[:, :2].unsqueeze(1).to(depth)) * scale_factor.unsqueeze(1) / img_size * 2) - 1. 115 | 116 | # sample landmark region (bs x lm x lm_size x lm_size) 117 | # gen sample grid (bs x lm x lm_size x lm_size x 2) 118 | x = th.linspace(-region_size / 2, region_size / 2, region_size) / img_size * 2 119 | grid = th.stack(th.meshgrid([x, x])[::-1], dim=2).to(depth) 120 | grid = face_lm.view(bs, num_landmark, 1, 1, 2) + grid 121 | depth_landmark_regions = F.grid_sample( 122 | depth, grid.view(bs, num_landmark, -1, 2), mode="nearest", padding_mode="zeros" 123 | ).squeeze(1) 124 | 125 | # while True: 126 | # # visualize landmark 127 | # for dep, lms, lmbs in zip(depth, face_lm, grid): 128 | # depth_vis = np.uint8(dep.squeeze(0).cpu().numpy() * 255) 129 | # depth_vis = np.stack([depth_vis, depth_vis, depth_vis], axis=2) 130 | # for lm, lmb in zip(lms, lmbs): 131 | # cv2.circle(depth_vis, tuple(((lm + 1) * 112).long().tolist()), 5, (0, 0, 255), 2) 132 | # x1, y1 = int((lmb[0, 0, 0].item() + 1) * 112), int((lmb[0, 0, 1].item() + 1) * 112) 133 | # x2, y2 = int((lmb[region_size-1, region_size-1, 0].item() + 1) * 112), int((lmb[region_size-1, region_size-1, 1].item() + 1) * 112) 134 | # cv2.rectangle(depth_vis, (x1, y1), (x2, y2), (0, 255, 0), 2) 135 | # cv2.imshow("res", depth_vis) 136 | # cv2.waitKey() 137 | # break 138 | 139 | # non-zero median 140 | depth_landmark_regions_sorted = th.sort(depth_landmark_regions, dim=2)[0] 141 | depth_landmark_regions_mask = depth_landmark_regions_sorted > 1.e-4 142 | depth_landmark_regions_mask[:, :, 0] = 0 143 | depth_landmark_regions_mask[:, :, 1:] = depth_landmark_regions_mask[:, :, 1:] - \ 144 | depth_landmark_regions_mask[:, :, :-1] 145 | depth_landmark_regions_mask[:, :, -1] = ((depth_landmark_regions_mask.sum(dim=2) == 0) + depth_landmark_regions_mask[:, :, -1]) > 0 146 | assert (depth_landmark_regions_mask.sum(dim=2) == 1).all(), f"{th.sum(depth_landmark_regions_mask, dim=2)}\n{depth_landmark_regions_mask[:, :, depth_landmark_regions_mask.size(2) - 1]}\n{depth_landmark_regions_mask.sum(dim=2) == 0}" 147 | nonzero_st = th.nonzero(depth_landmark_regions_mask) 148 | 149 | assert (nonzero_st[1:, 0] - nonzero_st[:-1, 0] >= 0).all() and \ 150 | ((nonzero_st[1:, 0] * num_landmark + nonzero_st[1:, 1]) - 151 | (nonzero_st[:-1, 0] * num_landmark + nonzero_st[:-1, 1]) >= 0).all() 
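# nonzero_st holds, for every (sample, landmark) pair, the single index flagged above: the
# position in the ascending-sorted region depths where values start to exceed the 1e-4
# threshold (forced to the last position when the whole region is zero). median_ind then
# picks the midpoint between that start index and the end of the K*K window, i.e. essentially
# the median of the non-zero depths in the region, so "median" is a per-landmark non-zero
# median and median_mask flags landmarks for which a valid reading was found.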
152 | assert nonzero_st.size(0) == bs * num_landmark 153 | median_ind = ((nonzero_st[:, 2] + region_size * region_size - 1) / 2).long() 154 | depth_landmark_regions_sorted = depth_landmark_regions_sorted.view(bs * num_landmark, region_size * region_size) 155 | median = depth_landmark_regions_sorted[range(len(median_ind)), median_ind].view(bs, num_landmark) 156 | median_mask = median > 1.e-4 157 | 158 | return median, median_mask 159 | 160 | 161 | class RGB2RelDepth(nn.Module): 162 | def __init__(self, num_landmark=68): 163 | super(RGB2RelDepth, self).__init__() 164 | self.encoder = resnet18(pretrained=True) 165 | self.num_landmark = num_landmark 166 | self.depth = nn.Sequential( 167 | nn.Linear(512, num_landmark, bias=True), 168 | # nn.Sigmoid() 169 | ) 170 | 171 | def forward(self, face_image, landmarks): 172 | bs = face_image.size(0) 173 | feat = self.encoder(face_image) 174 | depth = self.depth(feat) 175 | d1 = depth.view(bs, self.num_landmark, 1).expand(bs, self.num_landmark, self.num_landmark) 176 | d2 = depth.view(bs, 1, self.num_landmark).expand(bs, self.num_landmark, self.num_landmark) 177 | depthdiff = d1 - d2 178 | assert th.allclose(depthdiff, -depthdiff.transpose(1, 2)) 179 | with th.no_grad(): 180 | lm1 = landmarks.view(bs, self.num_landmark, 1, 2).expand(bs, self.num_landmark, self.num_landmark, 2) 181 | lm2 = landmarks.view(bs, 1, self.num_landmark, 2).expand(bs, self.num_landmark, self.num_landmark, 2) 182 | lmdist = th.norm((lm1 - lm2).to(feat), dim=3) 183 | assert th.allclose(lmdist, lmdist.transpose(1, 2)) 184 | rel_depth = depthdiff / (lmdist + 1e-4) 185 | 186 | return rel_depth 187 | 188 | 189 | class LossRelDepth(nn.Module): 190 | def __init__(self, crit, num_landmark=68, image_size=224, landmark_region_size=7, depth_scale=500): 191 | super(LossRelDepth, self).__init__() 192 | self.crit = crit 193 | self.num_landmark = num_landmark 194 | self.image_size = image_size 195 | self.lm_region_size = landmark_region_size 196 | self.depth_scale = depth_scale 197 | 198 | def forward(self, rel_depth_pred, depth, landmarkds, scale_factor, bbox): 199 | with th.no_grad(): 200 | median, median_mask = extract_landmark_depth(depth, landmarkds, scale_factor, bbox, self.lm_region_size) 201 | median *= self.depth_scale 202 | bs = rel_depth_pred.size(0) 203 | median_rel_mask = median_mask.view(bs, self.num_landmark, 1).expand(bs, self.num_landmark, 204 | self.num_landmark) * \ 205 | median_mask.view(bs, 1, self.num_landmark).expand(bs, self.num_landmark, 206 | self.num_landmark) 207 | diag_mask = 1 - th.eye(self.num_landmark, self.num_landmark).to(rel_depth_pred) 208 | rel_median = median.view(bs, self.num_landmark, 1).expand(bs, self.num_landmark, self.num_landmark) - \ 209 | median.view(bs, 1, self.num_landmark).expand(bs, self.num_landmark, self.num_landmark) 210 | landmark_dist = th.norm( 211 | landmarkds.view(bs, self.num_landmark, 1, 2).expand(bs, self.num_landmark, self.num_landmark, 2) - 212 | landmarkds.view(bs, 1, self.num_landmark, 2).expand(bs, self.num_landmark, self.num_landmark, 2), 213 | dim=3 214 | ) 215 | assert th.allclose(landmark_dist, landmark_dist.transpose(1, 2)) 216 | rel_median = rel_median / (landmark_dist + 1e-4) * diag_mask 217 | loss = th.sum(self.crit(rel_depth_pred, rel_median, reduce=False) * median_rel_mask.to(median)) / \ 218 | (th.sum(median_rel_mask) + 1e-4) 219 | 220 | return loss 221 | 222 | 223 | # class RGB2RelDepth(nn.Module): 224 | # pass 225 | -------------------------------------------------------------------------------- 
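The two modules above are only defined in gaze_depth.py; nothing in this file wires them together. A minimal forward-pass smoke test might look like the sketch below, under the assumptions that it is run from the code/ directory (so the models.gaze_depth import path resolves) and that F.smooth_l1_loss is an acceptable crit (it still accepts the legacy reduce=False keyword the loss relies on). The bbox, scale factor, and landmark tensors are made-up placeholders, not values from the dataset.

import torch as th
from torch.nn import functional as F
from models.gaze_depth import RGB2RelDepth, LossRelDepth   # assumed import path (code/models)

bs, num_lm, img_size = 2, 68, 224
model = RGB2RelDepth(num_landmark=num_lm)
loss_fn = LossRelDepth(crit=F.smooth_l1_loss, num_landmark=num_lm)

face_image = th.rand(bs, 3, img_size, img_size)            # normalized RGB face crop
face_depth = th.rand(bs, 1, img_size, img_size)            # depth already scaled to [0, 1]
face_bbox = th.tensor([300., 200., 524., 424.]).expand(bs, 4)   # 224-px-wide crop, frame coords
scale_factor = th.ones(bs, 1)                               # 224 / bbox width for the bbox above
# landmarks in the original frame coordinate system, somewhere inside the face bbox
landmarks = face_bbox[:, :2].unsqueeze(1) + th.rand(bs, num_lm, 2) * 224.

rel_depth = model(face_image, landmarks)                    # (bs, 68, 68) antisymmetric matrix
loss = loss_fn(rel_depth, face_depth, landmarks, scale_factor, face_bbox)
print(rel_depth.shape, loss.item())

In the actual pipeline these tensors would instead come from GazePointAllDataset samples (e.g. face_image, face_depth, face_landmark, face_scale_factor and face_bbox from the v2 dataset with landmark=True), and crit would be whatever criterion the training script passes in.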
/code/models/gaze_depth_v2.py: -------------------------------------------------------------------------------- 1 | from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet, model_urls 2 | import torch.utils.model_zoo as model_zoo 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import torch as th 6 | import pdb 7 | import cv2 8 | import numpy as np 9 | 10 | 11 | class ResNetEncoder(ResNet): 12 | def forward(self, x): 13 | x = self.conv1(x) 14 | x = self.bn1(x) 15 | x = self.relu(x) 16 | # x112_64 = x 17 | x = self.maxpool(x) 18 | x = self.layer1(x) 19 | # x56_64 = x 20 | x = self.layer2(x) 21 | # x28_128 = x 22 | x = self.layer3(x) 23 | # x14_256 = x 24 | x = self.layer4(x) 25 | x = self.avgpool(x) 26 | x = x.view(x.size(0), -1) 27 | x = self.relu(x) 28 | 29 | return x # , x112_64, x56_64, x28_128, x14_256 30 | 31 | 32 | def resnet18(pretrained=False, **kwargs): 33 | """Constructs a ResNet-18 model. 34 | 35 | Args: 36 | pretrained (bool): If True, returns a model pre-trained on ImageNet 37 | """ 38 | model = ResNetEncoder(BasicBlock, [2, 2, 2, 2], **kwargs) 39 | if pretrained: 40 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 41 | return model 42 | 43 | 44 | def resnet34(pretrained=False, **kwargs): 45 | """Constructs a ResNet-34 model. 46 | 47 | Args: 48 | pretrained (bool): If True, returns a model pre-trained on ImageNet 49 | """ 50 | model = ResNetEncoder(BasicBlock, [3, 4, 6, 3], **kwargs) 51 | if pretrained: 52 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 53 | return model 54 | 55 | 56 | def resnet50(pretrained=False, **kwargs): 57 | """Constructs a ResNet-50 model. 58 | 59 | Args: 60 | pretrained (bool): If True, returns a model pre-trained on ImageNet 61 | """ 62 | model = ResNetEncoder(Bottleneck, [3, 4, 6, 3], **kwargs) 63 | if pretrained: 64 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 65 | return model 66 | 67 | 68 | def resnet101(pretrained=False, **kwargs): 69 | """Constructs a ResNet-101 model. 70 | 71 | Args: 72 | pretrained (bool): If True, returns a model pre-trained on ImageNet 73 | """ 74 | model = ResNetEncoder(Bottleneck, [3, 4, 23, 3], **kwargs) 75 | if pretrained: 76 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 77 | return model 78 | 79 | 80 | def resnet152(pretrained=False, **kwargs): 81 | """Constructs a ResNet-152 model. 82 | 83 | Args: 84 | pretrained (bool): If True, returns a model pre-trained on ImageNet 85 | """ 86 | model = ResNetEncoder(Bottleneck, [3, 8, 36, 3], **kwargs) 87 | if pretrained: 88 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 89 | return model 90 | 91 | 92 | class Depth2AbsDepth(nn.Module): 93 | pass 94 | 95 | 96 | class RGBD2AbsDepth(nn.Module): 97 | pass 98 | 99 | 100 | class RGB2AbsDepth(nn.Module): 101 | pass 102 | 103 | 104 | class Depth2RelDepth(nn.Module): 105 | pass 106 | 107 | 108 | def extract_landmark_depth(depth, landmark, scale_factor, bbox, region_size=7): 109 | bs = depth.size(0) 110 | img_size = depth.size(3) 111 | assert depth.size(2) == depth.size(3) 112 | num_landmark = landmark.size(1) 113 | # transform landmarks to face image coordinate system (bs x lm x 2) 114 | face_lm = ((landmark - bbox[:, :2].unsqueeze(1).to(depth)) * scale_factor.unsqueeze(1) / img_size * 2) - 1. 
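# The line above maps landmarks from full-frame coordinates into the face crop (subtract the
# bbox top-left corner, scale by the crop factor of 224 divided by the bbox width) and then
# rescales pixel positions into grid_sample's normalized [-1, 1] coordinate convention.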
115 | 116 | # sample landmark region (bs x lm x lm_size x lm_size) 117 | # gen sample grid (bs x lm x lm_size x lm_size x 2) 118 | x = th.linspace(-region_size / 2, region_size / 2, region_size) / img_size * 2 119 | grid = th.stack(th.meshgrid([x, x])[::-1], dim=2).to(depth) 120 | grid = face_lm.view(bs, num_landmark, 1, 1, 2) + grid 121 | depth_landmark_regions = F.grid_sample( 122 | depth, grid.view(bs, num_landmark, -1, 2), mode="nearest", padding_mode="zeros" 123 | ).squeeze(1) 124 | 125 | # while True: 126 | # # visualize landmark 127 | # for dep, lms, lmbs in zip(depth, face_lm, grid): 128 | # depth_vis = np.uint8(dep.squeeze(0).cpu().numpy() * 255) 129 | # depth_vis = np.stack([depth_vis, depth_vis, depth_vis], axis=2) 130 | # for lm, lmb in zip(lms, lmbs): 131 | # cv2.circle(depth_vis, tuple(((lm + 1) * 112).long().tolist()), 5, (0, 0, 255), 2) 132 | # x1, y1 = int((lmb[0, 0, 0].item() + 1) * 112), int((lmb[0, 0, 1].item() + 1) * 112) 133 | # x2, y2 = int((lmb[region_size-1, region_size-1, 0].item() + 1) * 112), int((lmb[region_size-1, region_size-1, 1].item() + 1) * 112) 134 | # cv2.rectangle(depth_vis, (x1, y1), (x2, y2), (0, 255, 0), 2) 135 | # cv2.imshow("res", depth_vis) 136 | # cv2.waitKey() 137 | # break 138 | 139 | # non-zero median 140 | depth_landmark_regions_sorted = th.sort(depth_landmark_regions, dim=2)[0] 141 | depth_landmark_regions_mask = depth_landmark_regions_sorted > 1.e-4 142 | depth_landmark_regions_mask[:, :, 0] = 0 143 | depth_landmark_regions_mask[:, :, 1:] = depth_landmark_regions_mask[:, :, 1:] - \ 144 | depth_landmark_regions_mask[:, :, :-1] 145 | depth_landmark_regions_mask[:, :, -1] = ((depth_landmark_regions_mask.sum(dim=2) == 0) + depth_landmark_regions_mask[:, :, -1]) > 0 146 | assert (depth_landmark_regions_mask.sum(dim=2) == 1).all(), f"{th.sum(depth_landmark_regions_mask, dim=2)}\n{depth_landmark_regions_mask[:, :, depth_landmark_regions_mask.size(2) - 1]}\n{depth_landmark_regions_mask.sum(dim=2) == 0}" 147 | nonzero_st = th.nonzero(depth_landmark_regions_mask) 148 | 149 | assert (nonzero_st[1:, 0] - nonzero_st[:-1, 0] >= 0).all() and \ 150 | ((nonzero_st[1:, 0] * num_landmark + nonzero_st[1:, 1]) - 151 | (nonzero_st[:-1, 0] * num_landmark + nonzero_st[:-1, 1]) >= 0).all() 152 | assert nonzero_st.size(0) == bs * num_landmark 153 | median_ind = ((nonzero_st[:, 2] + region_size * region_size - 1) / 2).long() 154 | depth_landmark_regions_sorted = depth_landmark_regions_sorted.view(bs * num_landmark, region_size * region_size) 155 | median = depth_landmark_regions_sorted[range(len(median_ind)), median_ind].view(bs, num_landmark) 156 | median_mask = median > 1.e-4 157 | 158 | return median, median_mask 159 | 160 | 161 | class RGB2RelDepth(nn.Module): 162 | def __init__(self, num_landmark=68, dim_landmark_feat=64): 163 | super(RGB2RelDepth, self).__init__() 164 | self.encoder = resnet18(pretrained=True) 165 | self.num_landmark = num_landmark 166 | self.dim_embedding = dim_landmark_feat 167 | self.depth_embedding = nn.Sequential( 168 | nn.Linear(512, num_landmark * dim_landmark_feat, bias=False), 169 | nn.BatchNorm1d(num_landmark * dim_landmark_feat), 170 | nn.ReLU(inplace=True), 171 | nn.Linear(num_landmark * dim_landmark_feat, num_landmark * dim_landmark_feat, bias=True) 172 | # nn.Sigmoid() 173 | ) 174 | self.depth = nn.Sequential( 175 | nn.Conv2d(dim_landmark_feat, 1, kernel_size=1, stride=1, padding=0, bias=False) 176 | ) 177 | 178 | def forward(self, face_image, landmarks): 179 | bs = face_image.size(0) 180 | feat = self.encoder(face_image) 181 | 
depth_feat = self.depth_embedding(feat) 182 | d1 = depth_feat.view(bs, self.dim_embedding, self.num_landmark, 1).expand(bs, self.dim_embedding, self.num_landmark, self.num_landmark) 183 | d2 = depth_feat.view(bs, self.dim_embedding, 1, self.num_landmark).expand(bs, self.dim_embedding, self.num_landmark, self.num_landmark) 184 | depthdiff = self.depth(d1 - d2).squeeze(1) 185 | assert th.allclose(depthdiff, -depthdiff.transpose(1, 2)), depthdiff 186 | with th.no_grad(): 187 | lm1 = landmarks.view(bs, self.num_landmark, 1, 2).expand(bs, self.num_landmark, self.num_landmark, 2) 188 | lm2 = landmarks.view(bs, 1, self.num_landmark, 2).expand(bs, self.num_landmark, self.num_landmark, 2) 189 | lmdist = th.norm((lm1 - lm2).to(feat), dim=3) 190 | assert th.allclose(lmdist, lmdist.transpose(1, 2)) 191 | rel_depth = depthdiff / (lmdist + 1e-4) 192 | 193 | return rel_depth 194 | 195 | 196 | class LossRelDepth(nn.Module): 197 | def __init__(self, crit, num_landmark=68, image_size=224, landmark_region_size=7, depth_scale=500): 198 | super(LossRelDepth, self).__init__() 199 | self.crit = crit 200 | self.num_landmark = num_landmark 201 | self.image_size = image_size 202 | self.lm_region_size = landmark_region_size 203 | self.depth_scale = depth_scale 204 | 205 | def forward(self, rel_depth_pred, depth, landmarkds, scale_factor, bbox): 206 | with th.no_grad(): 207 | median, median_mask = extract_landmark_depth(depth, landmarkds, scale_factor, bbox, self.lm_region_size) 208 | median *= self.depth_scale 209 | bs = rel_depth_pred.size(0) 210 | median_rel_mask = median_mask.view(bs, self.num_landmark, 1).expand(bs, self.num_landmark, 211 | self.num_landmark) * \ 212 | median_mask.view(bs, 1, self.num_landmark).expand(bs, self.num_landmark, 213 | self.num_landmark) 214 | diag_mask = 1 - th.eye(self.num_landmark, self.num_landmark).to(rel_depth_pred) 215 | rel_median = median.view(bs, self.num_landmark, 1).expand(bs, self.num_landmark, self.num_landmark) - \ 216 | median.view(bs, 1, self.num_landmark).expand(bs, self.num_landmark, self.num_landmark) 217 | landmark_dist = th.norm( 218 | landmarkds.view(bs, self.num_landmark, 1, 2).expand(bs, self.num_landmark, self.num_landmark, 2) - 219 | landmarkds.view(bs, 1, self.num_landmark, 2).expand(bs, self.num_landmark, self.num_landmark, 2), 220 | dim=3 221 | ) 222 | assert th.allclose(landmark_dist, landmark_dist.transpose(1, 2)) 223 | rel_median = rel_median / (landmark_dist + 1e-4) * diag_mask 224 | loss = th.sum(self.crit(rel_depth_pred, rel_median, reduce=False) * median_rel_mask.to(median)) / \ 225 | (th.sum(median_rel_mask) + 1e-4) 226 | 227 | return loss 228 | 229 | 230 | # class RGB2RelDepth(nn.Module): 231 | # pass 232 | -------------------------------------------------------------------------------- /code/models/gaze_aaai_pose16_relu.py: -------------------------------------------------------------------------------- 1 | from torchvision.models.resnet import BasicBlock, ResNet, model_urls 2 | import torch.utils.model_zoo as model_zoo 3 | from torch import nn 4 | import torch as th 5 | import torch.nn.functional as F 6 | 7 | 8 | class ResNetEncoder(ResNet): 9 | def forward(self, x): 10 | x = self.conv1(x) 11 | x = self.bn1(x) 12 | x = self.relu(x) 13 | x = self.maxpool(x) 14 | x = self.layer1(x) 15 | x = self.layer2(x) 16 | x = self.layer3(x) 17 | x = self.layer4(x) 18 | x = self.avgpool(x) 19 | x = x.view(x.size(0), -1) 20 | x = self.relu(x) 21 | return x 22 | 23 | 24 | def resnet34(pretrained=False, **kwargs): 25 | """Constructs a ResNet-34 model. 
26 | 27 | Args: 28 | pretrained (bool): If True, returns a model pre-trained on ImageNet 29 | """ 30 | model = ResNetEncoder(BasicBlock, [3, 4, 6, 3], **kwargs) 31 | if pretrained: 32 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 33 | return model 34 | 35 | 36 | class Decoder(nn.Module): 37 | def __init__(self, feat_dim=512): 38 | super(Decoder, self).__init__() 39 | self.ldecoder = nn.Sequential( 40 | nn.Linear(feat_dim, 128), 41 | nn.ReLU(True), 42 | ) 43 | self.rdecoder = nn.Sequential( 44 | nn.Linear(feat_dim, 128), 45 | nn.ReLU(True), 46 | ) 47 | self.lcoord = nn.Sequential( 48 | nn.Linear(128 + 16 + 3, 64), 49 | nn.ReLU(True), 50 | nn.Linear(64, 2) 51 | ) 52 | self.rcoord = nn.Sequential( 53 | nn.Linear(128 + 16 + 3, 64), 54 | nn.ReLU(True), 55 | nn.Linear(64, 2) 56 | ) 57 | 58 | def forward(self, lfeat, rfeat, head_pose, linfo, rinfo): 59 | l_coord_feat = self.ldecoder(lfeat) 60 | r_coord_feat = self.rdecoder(rfeat) 61 | l_coord = self.lcoord(th.cat([l_coord_feat, head_pose, linfo], 1)) 62 | r_coord = self.rcoord(th.cat([r_coord_feat, head_pose, rinfo], 1)) 63 | coord = (l_coord + r_coord) / 2. 64 | # coord = self.coord(th.cat([l_coord, r_coord], 1)) 65 | return coord 66 | 67 | 68 | class DepthL1(nn.Module): 69 | def __init__(self, th_lower=None, th_upper=None): 70 | super(DepthL1, self).__init__() 71 | self.th_lower = th_lower 72 | self.th_upper = th_upper 73 | 74 | def forward(self, pred, target): 75 | pred = pred.view(pred.size(0), -1) 76 | target = target.view(target.size(0), -1) 77 | if self.th_lower is not None: 78 | with th.no_grad(): 79 | mask_lower = (target > self.th_lower).float() 80 | else: 81 | mask_lower = 1. 82 | if self.th_upper is not None: 83 | with th.no_grad(): 84 | mask_upper = (target < self.th_upper).float() 85 | else: 86 | mask_upper = 1. 
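# The return below averages the absolute depth error only over pixels whose ground-truth
# value lies strictly inside the (th_lower, th_upper) window; the 1e-5 term keeps the
# normalization safe when no pixel survives the masking.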
87 | 88 | return th.sum(th.abs(pred - target) * mask_lower * mask_upper) / (th.sum(mask_lower * mask_upper) + 1e-5) 89 | 90 | 91 | class ResnetBlock(nn.Module): 92 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 93 | super(ResnetBlock, self).__init__() 94 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 95 | 96 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 97 | conv_block = [] 98 | p = 0 99 | if padding_type == 'reflect': 100 | conv_block += [nn.ReflectionPad2d(1)] 101 | elif padding_type == 'replicate': 102 | conv_block += [nn.ReplicationPad2d(1)] 103 | elif padding_type == 'zero': 104 | p = 1 105 | else: 106 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 107 | 108 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 109 | norm_layer(dim), 110 | nn.ReLU(True)] 111 | if use_dropout: 112 | conv_block += [nn.Dropout(0.5)] 113 | 114 | p = 0 115 | if padding_type == 'reflect': 116 | conv_block += [nn.ReflectionPad2d(1)] 117 | elif padding_type == 'replicate': 118 | conv_block += [nn.ReplicationPad2d(1)] 119 | elif padding_type == 'zero': 120 | p = 1 121 | else: 122 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 123 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 124 | norm_layer(dim)] 125 | 126 | return nn.Sequential(*conv_block) 127 | 128 | def forward(self, x): 129 | out = x + self.conv_block(x) 130 | return F.relu(out) 131 | 132 | 133 | class RefineDepth(nn.Module): 134 | def __init__(self): 135 | super(RefineDepth, self).__init__() 136 | use_bias = False 137 | 138 | self.face_block1 = nn.Sequential( 139 | nn.ReflectionPad2d(3), 140 | nn.Conv2d(3, 64, kernel_size=7, padding=0, bias=use_bias), 141 | nn.BatchNorm2d(64), 142 | nn.ReLU(True) 143 | ) 144 | 145 | self.face_block2 = nn.Sequential( 146 | nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=use_bias), 147 | nn.BatchNorm2d(128), 148 | nn.ReLU(True) 149 | ) 150 | 151 | self.face_block3 = nn.Sequential( 152 | nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=use_bias), 153 | nn.BatchNorm2d(256), 154 | nn.ReLU(True) 155 | ) 156 | 157 | self.face_block4 = nn.Sequential( 158 | nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=use_bias), 159 | nn.BatchNorm2d(512), 160 | nn.ReLU(True) 161 | ) 162 | 163 | self.depth_block1 = nn.Sequential( 164 | nn.ReflectionPad2d(3), 165 | nn.Conv2d(1, 64, kernel_size=7, padding=0, bias=use_bias), 166 | nn.BatchNorm2d(64), 167 | nn.ReLU(True) 168 | ) 169 | 170 | self.depth_block2 = nn.Sequential( 171 | nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=use_bias), 172 | nn.BatchNorm2d(128), 173 | nn.ReLU(True) 174 | ) 175 | 176 | self.depth_block3 = nn.Sequential( 177 | nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=use_bias), 178 | nn.BatchNorm2d(256), 179 | nn.ReLU(True) 180 | ) 181 | self.depth_block4 = nn.Sequential( 182 | nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=use_bias), 183 | nn.BatchNorm2d(512), 184 | nn.ReLU(True) 185 | ) 186 | 187 | self.down1 = nn.Sequential( 188 | nn.Conv2d(1024, 512, kernel_size=1, stride=1, padding=0, bias=use_bias), 189 | nn.BatchNorm2d(512), 190 | nn.ReLU(True), 191 | ResnetBlock(512, padding_type='reflect', norm_layer=nn.BatchNorm2d, 192 | use_dropout=False, use_bias=use_bias), 193 | # ResnetBlock(512, padding_type='reflect', norm_layer=nn.BatchNorm2d, 194 | # 
use_dropout=False, use_bias=use_bias), 195 | # ResnetBlock(512, padding_type='reflect', norm_layer=nn.BatchNorm2d, 196 | # use_dropout=False, use_bias=use_bias) 197 | ) 198 | self.down2 = nn.Sequential( 199 | nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0, bias=use_bias), 200 | nn.BatchNorm2d(256), 201 | nn.ReLU(True), 202 | # ResnetBlock(256, padding_type='reflect', norm_layer=nn.BatchNorm2d, 203 | # use_dropout=False, use_bias=use_bias), 204 | # ResnetBlock(256, padding_type='reflect', norm_layer=nn.BatchNorm2d, 205 | # use_dropout=False, use_bias=use_bias) 206 | ) 207 | 208 | self.down3 = nn.Sequential( 209 | nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=use_bias), 210 | nn.BatchNorm2d(128), 211 | nn.ReLU(True), 212 | # ResnetBlock(128, padding_type='reflect', norm_layer=nn.BatchNorm2d, 213 | # use_dropout=False, use_bias=use_bias) 214 | ) 215 | 216 | self.down4 = nn.Sequential( 217 | nn.Conv2d(128, 64, kernel_size=1, stride=1, padding=0, bias=use_bias), 218 | nn.BatchNorm2d(64), 219 | nn.ReLU(True), 220 | # ResnetBlock(64, padding_type='reflect', norm_layer=nn.BatchNorm2d, 221 | # use_dropout=False, use_bias=use_bias) 222 | ) 223 | 224 | self.head_pose = nn.Sequential( 225 | nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1, bias=use_bias), 226 | nn.BatchNorm2d(512), 227 | nn.ReLU(True), 228 | nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=use_bias), 229 | nn.BatchNorm2d(1024), 230 | nn.ReLU(True), 231 | nn.AvgPool2d(7), 232 | nn.Conv2d(1024, 16, kernel_size=1, stride=1, padding=0, bias=True), 233 | # nn.BatchNorm2d(128), 234 | nn.ReLU(True) 235 | ) 236 | 237 | self.gen_block1 = nn.Sequential( 238 | nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), 239 | nn.BatchNorm2d(256), 240 | nn.ReLU(True) 241 | ) 242 | 243 | self.gen_block2 = nn.Sequential( 244 | nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), 245 | nn.BatchNorm2d(128), 246 | nn.ReLU(True) 247 | ) 248 | 249 | self.gen_block3 = nn.Sequential( 250 | nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), 251 | nn.BatchNorm2d(64), 252 | nn.ReLU(True) 253 | ) 254 | self.gen_block4 = nn.Sequential( 255 | nn.ReflectionPad2d(3), 256 | nn.Conv2d(64, 1, kernel_size=7, padding=0), 257 | nn.Sigmoid() 258 | ) 259 | 260 | def forward(self, face, depth): 261 | face_f1 = self.face_block1(face) 262 | face_f2 = self.face_block2(face_f1) 263 | face_f3 = self.face_block3(face_f2) 264 | face_f4 = self.face_block4(face_f3) 265 | depth_f1 = self.depth_block1(depth) 266 | depth_f2 = self.depth_block2(depth_f1) 267 | depth_f3 = self.depth_block3(depth_f2) 268 | depth_f4 = self.depth_block4(depth_f3) 269 | mixed_f4 = self.down1(th.cat([face_f4, depth_f4], dim=1)) 270 | mixed_f3 = self.down2(th.cat([face_f3, depth_f3], dim=1)) 271 | mixed_f2 = self.down3(th.cat([face_f2, depth_f2], dim=1)) 272 | mixed_f1 = self.down4(th.cat([face_f1, depth_f1], dim=1)) 273 | gen_f3 = self.gen_block1(mixed_f4) + mixed_f3 274 | gen_f2 = self.gen_block2(gen_f3) + mixed_f2 275 | gen_f1 = self.gen_block3(gen_f2) + mixed_f1 276 | gen_depth = self.gen_block4(gen_f1) 277 | head_pose = self.head_pose(mixed_f4) 278 | return head_pose.view(head_pose.size(0), -1), gen_depth 279 | -------------------------------------------------------------------------------- /code/models/gaze_depth_v3.py: -------------------------------------------------------------------------------- 1 | from 
torchvision.models.resnet import BasicBlock, Bottleneck, ResNet, model_urls 2 | import torch.utils.model_zoo as model_zoo 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import Parameter 6 | import torch as th 7 | import pdb 8 | import cv2 9 | import numpy as np 10 | 11 | 12 | class ResNetEncoder(ResNet): 13 | def forward(self, x): 14 | x = self.conv1(x) 15 | x = self.bn1(x) 16 | x = self.relu(x) 17 | # x112_64 = x 18 | x = self.maxpool(x) 19 | x = self.layer1(x) 20 | # x56_64 = x 21 | x = self.layer2(x) 22 | # x28_128 = x 23 | x = self.layer3(x) 24 | # x14_256 = x 25 | x = self.layer4(x) 26 | x = self.avgpool(x) 27 | x = x.view(x.size(0), -1) 28 | x = self.relu(x) 29 | 30 | return x # , x112_64, x56_64, x28_128, x14_256 31 | 32 | 33 | def resnet18(pretrained=False, **kwargs): 34 | """Constructs a ResNet-18 model. 35 | 36 | Args: 37 | pretrained (bool): If True, returns a model pre-trained on ImageNet 38 | """ 39 | model = ResNetEncoder(BasicBlock, [2, 2, 2, 2], **kwargs) 40 | if pretrained: 41 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 42 | return model 43 | 44 | 45 | def resnet34(pretrained=False, **kwargs): 46 | """Constructs a ResNet-34 model. 47 | 48 | Args: 49 | pretrained (bool): If True, returns a model pre-trained on ImageNet 50 | """ 51 | model = ResNetEncoder(BasicBlock, [3, 4, 6, 3], **kwargs) 52 | if pretrained: 53 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 54 | return model 55 | 56 | 57 | def resnet50(pretrained=False, **kwargs): 58 | """Constructs a ResNet-50 model. 59 | 60 | Args: 61 | pretrained (bool): If True, returns a model pre-trained on ImageNet 62 | """ 63 | model = ResNetEncoder(Bottleneck, [3, 4, 6, 3], **kwargs) 64 | if pretrained: 65 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 66 | return model 67 | 68 | 69 | def resnet101(pretrained=False, **kwargs): 70 | """Constructs a ResNet-101 model. 71 | 72 | Args: 73 | pretrained (bool): If True, returns a model pre-trained on ImageNet 74 | """ 75 | model = ResNetEncoder(Bottleneck, [3, 4, 23, 3], **kwargs) 76 | if pretrained: 77 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 78 | return model 79 | 80 | 81 | def resnet152(pretrained=False, **kwargs): 82 | """Constructs a ResNet-152 model. 83 | 84 | Args: 85 | pretrained (bool): If True, returns a model pre-trained on ImageNet 86 | """ 87 | model = ResNetEncoder(Bottleneck, [3, 8, 36, 3], **kwargs) 88 | if pretrained: 89 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 90 | return model 91 | 92 | 93 | class Depth2AbsDepth(nn.Module): 94 | pass 95 | 96 | 97 | class RGBD2AbsDepth(nn.Module): 98 | pass 99 | 100 | 101 | class RGB2AbsDepth(nn.Module): 102 | pass 103 | 104 | 105 | class Depth2RelDepth(nn.Module): 106 | pass 107 | 108 | 109 | def extract_landmark_depth(depth, landmark, scale_factor, bbox, region_size=7, depth_scale=500., 110 | valid_depth_range=90, debug=False): 111 | bs = depth.size(0) 112 | img_size = depth.size(3) 113 | assert depth.size(2) == depth.size(3) 114 | num_landmark = landmark.size(1) 115 | # transform landmarks to face image coordinate system (bs x lm x 2) 116 | face_lm = ((landmark - bbox[:, :2].unsqueeze(1).to(depth)) * scale_factor.unsqueeze(1) / img_size * 2) - 1. 
117 | 118 | # sample landmark region (bs x lm x lm_size x lm_size) 119 | # gen sample grid (bs x lm x lm_size x lm_size x 2) 120 | x = th.linspace(-region_size / 2, region_size / 2, region_size) / img_size * 2 121 | grid = th.stack(th.meshgrid([x, x])[::-1], dim=2).to(depth) 122 | grid = face_lm.view(bs, num_landmark, 1, 1, 2) + grid 123 | depth_landmark_regions = F.grid_sample( 124 | depth, grid.view(bs, num_landmark, -1, 2), mode="nearest", padding_mode="zeros" 125 | ).squeeze(1) 126 | 127 | if debug: 128 | # visualize landmark 129 | for dep, lms, lmbs in zip(depth, face_lm, grid): 130 | depth_vis = np.uint8(dep.squeeze(0).cpu().numpy() * 255) 131 | depth_vis = np.stack([depth_vis, depth_vis, depth_vis], axis=2) 132 | for lm, lmb in zip(lms, lmbs): 133 | cv2.circle(depth_vis, tuple(((lm + 1) * 112).long().tolist()), 5, (0, 0, 255), 2) 134 | x1, y1 = int((lmb[0, 0, 0].item() + 1) * 112), int((lmb[0, 0, 1].item() + 1) * 112) 135 | x2, y2 = int((lmb[region_size-1, region_size-1, 0].item() + 1) * 112), int((lmb[region_size-1, region_size-1, 1].item() + 1) * 112) 136 | cv2.rectangle(depth_vis, (x1, y1), (x2, y2), (0, 255, 0), 2) 137 | cv2.imshow("res", depth_vis) 138 | cv2.waitKey() 139 | 140 | # non-zero median if exists, else return zero 141 | depth_landmark_regions_sorted = th.sort(depth_landmark_regions, dim=2)[0] 142 | depth_landmark_regions_mask = depth_landmark_regions_sorted > 1.e-4 143 | depth_landmark_regions_mask[:, :, 0] = 0 144 | depth_landmark_regions_mask[:, :, 1:] = depth_landmark_regions_mask[:, :, 1:] - \ 145 | depth_landmark_regions_mask[:, :, :-1] 146 | depth_landmark_regions_mask[:, :, -1] = ((depth_landmark_regions_mask.sum(dim=2) == 0) + depth_landmark_regions_mask[:, :, -1]) > 0 147 | assert (depth_landmark_regions_mask.sum(dim=2) == 1).all(), f"{th.sum(depth_landmark_regions_mask, dim=2)}\n{depth_landmark_regions_mask[:, :, depth_landmark_regions_mask.size(2) - 1]}\n{depth_landmark_regions_mask.sum(dim=2) == 0}" 148 | nonzero_st = th.nonzero(depth_landmark_regions_mask) 149 | 150 | assert (nonzero_st[1:, 0] - nonzero_st[:-1, 0] >= 0).all() and \ 151 | ((nonzero_st[1:, 0] * num_landmark + nonzero_st[1:, 1]) - 152 | (nonzero_st[:-1, 0] * num_landmark + nonzero_st[:-1, 1]) >= 0).all() 153 | assert nonzero_st.size(0) == bs * num_landmark 154 | median_ind = ((nonzero_st[:, 2] + region_size * region_size - 1) / 2).long() 155 | depth_landmark_regions_sorted = depth_landmark_regions_sorted.view(bs * num_landmark, region_size * region_size) 156 | median = depth_landmark_regions_sorted[range(len(median_ind)), median_ind].view(bs, num_landmark) 157 | median_mask = th.abs(median - th.median(median, dim=1)[0].unsqueeze(1)) < (valid_depth_range / depth_scale) 158 | 159 | return median, median_mask 160 | 161 | 162 | class RelMatrix2Col(nn.Module): 163 | def __init__(self, num_landmark=68): 164 | super(RelMatrix2Col, self).__init__() 165 | xx = th.arange(0, num_landmark).unsqueeze(1).expand(num_landmark, num_landmark) 166 | yy = xx.t() 167 | self.register_buffer("col_mask", xx > yy) 168 | 169 | def forward(self, rel_matrix): 170 | assert th.allclose(rel_matrix, -rel_matrix.transpose(rel_matrix.dim() - 1, rel_matrix.dim() - 2)) 171 | return rel_matrix[..., self.col_mask] 172 | 173 | 174 | class RelCol2Matrix(nn.Module): 175 | def __init__(self, num_landmark=68): 176 | super(RelCol2Matrix, self).__init__() 177 | self.num_landmark = num_landmark 178 | xx = th.arange(0, num_landmark).unsqueeze(1).expand(num_landmark, num_landmark) 179 | yy = xx.t() 180 | self.register_buffer("col_mask", 
xx > yy) 181 | 182 | def forward(self, rel_column): 183 | rel_matrix = rel_column.new_full((rel_column.size(0), self.num_landmark, self.num_landmark), 0) 184 | rel_matrix[..., self.col_mask] = rel_column 185 | rel_matrix[..., 1 - self.col_mask] = - rel_matrix.transpose( 186 | rel_matrix.dim() - 1, rel_matrix.dim() - 2 187 | )[..., 1 - self.col_mask] 188 | assert th.allclose(rel_matrix, -rel_matrix.transpose(rel_matrix.dim() - 1, rel_matrix.dim() - 2)) 189 | return rel_matrix 190 | 191 | 192 | class RGB2RelDepth(nn.Module): 193 | def __init__(self, num_landmark=68, dim_landmark_feat=128, initial_bias=None): 194 | super(RGB2RelDepth, self).__init__() 195 | self.encoder = resnet18(pretrained=True) 196 | self.num_landmark = num_landmark 197 | self.dim_rel_col = int(num_landmark * (num_landmark - 1) // 2) 198 | self.dim_embedding = dim_landmark_feat 199 | self.rel_depth = nn.Sequential( 200 | nn.Linear(512, self.dim_rel_col, bias=True), 201 | RelCol2Matrix(num_landmark=num_landmark) 202 | ) 203 | if initial_bias is not None: 204 | nn.init.constant_(self.rel_depth[0].bias, initial_bias) 205 | 206 | def forward(self, face_image): 207 | feat = self.encoder(face_image) 208 | rel_depth = self.rel_depth(feat) 209 | 210 | return rel_depth 211 | 212 | 213 | class LossRelDepth(nn.Module): 214 | def __init__(self, crit, num_landmark=68, image_size=224, landmark_region_size=7, depth_scale=500., 215 | valid_depth_range=90): 216 | super(LossRelDepth, self).__init__() 217 | self.crit = crit 218 | self.num_landmark = num_landmark 219 | self.image_size = image_size 220 | self.lm_region_size = landmark_region_size 221 | self.depth_scale = depth_scale 222 | self.valid_depth_range = valid_depth_range 223 | xx = th.arange(0, num_landmark).unsqueeze(1).expand(num_landmark, num_landmark) 224 | self.register_buffer("diag_mask", xx == xx.t()) 225 | 226 | def repeat_as_col(self, col): 227 | return col.view( 228 | col.size(0), self.num_landmark, 1, -1 229 | ).expand( 230 | col.size(0), self.num_landmark, self.num_landmark, -1 231 | ).squeeze() 232 | 233 | def repeat_as_row(self, row): 234 | return row.view( 235 | row.size(0), 1, self.num_landmark, -1 236 | ).expand( 237 | row.size(0), self.num_landmark, self.num_landmark, -1 238 | ).squeeze() 239 | 240 | def forward(self, rel_depth_pred, depth, landmarks, scale_factor, bbox): 241 | bs = rel_depth_pred.size(0) 242 | with th.no_grad(): 243 | landmark_dist = th.norm(self.repeat_as_col(landmarks) - self.repeat_as_row(landmarks), dim=3) 244 | assert th.allclose(landmark_dist, landmark_dist.transpose(1, 2)) 245 | median, median_mask = extract_landmark_depth( 246 | depth=depth, 247 | landmark=landmarks, 248 | scale_factor=scale_factor, 249 | bbox=bbox, 250 | region_size=self.lm_region_size, 251 | depth_scale=self.depth_scale, 252 | valid_depth_range=self.valid_depth_range 253 | ) 254 | median *= self.depth_scale 255 | median_diff = (self.repeat_as_col(median) - self.repeat_as_row(median)) / scale_factor.view(bs, 1, 1) 256 | median_rel_mask = (self.repeat_as_col(median_mask) * self.repeat_as_row(median_mask)).to(rel_depth_pred) 257 | # assert (median_diff[..., self.diag_mask] == 0).all() 258 | assert th.allclose(median_diff, -median_diff.transpose(1, 2)) 259 | loss_ele = self.crit(rel_depth_pred, median_diff, reduction='none') 260 | loss = th.sum(loss_ele * median_rel_mask) / (th.sum(median_rel_mask) + 1e-4) 261 | 262 | return loss, median_diff, median_rel_mask 263 | 264 | 265 | # class RGB2RelDepth(nn.Module): 266 | # pass 267 | 
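# Illustrative-only sketch (not part of the training code): with n landmarks the pairwise
# difference matrix D[i, j] = d[i] - d[j] is antisymmetric, so it is fully determined by its
# n * (n - 1) / 2 strictly-lower-triangular entries -- exactly the number of outputs the
# rel_depth head above predicts and that RelCol2Matrix mirrors back into a full matrix.
if __name__ == '__main__':
    import torch as th
    n = 4
    d = th.arange(n, dtype=th.float32)                          # toy per-landmark depths
    D = d.view(n, 1) - d.view(1, n)                             # antisymmetric (n, n) matrix
    low = th.arange(n).view(n, 1) > th.arange(n).view(1, n)     # strictly lower triangle
    packed = D[low]                                             # n * (n - 1) // 2 = 6 values
    restored = th.zeros(n, n)
    restored[low] = packed
    restored = restored - restored.t()                          # antisymmetry gives the rest
    assert th.equal(restored, D)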
-------------------------------------------------------------------------------- /code/models/gaze_depth_v4.py: -------------------------------------------------------------------------------- 1 | from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet, model_urls 2 | import torch.utils.model_zoo as model_zoo 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import Parameter 6 | import torch as th 7 | import pdb 8 | import cv2 9 | import numpy as np 10 | 11 | 12 | class ResNetEncoder(ResNet): 13 | def forward(self, x): 14 | x = self.conv1(x) 15 | x = self.bn1(x) 16 | x = self.relu(x) 17 | # x112_64 = x 18 | x = self.maxpool(x) 19 | x = self.layer1(x) 20 | # x56_64 = x 21 | x = self.layer2(x) 22 | # x28_128 = x 23 | x = self.layer3(x) 24 | # x14_256 = x 25 | x = self.layer4(x) 26 | x = self.avgpool(x) 27 | x = x.view(x.size(0), -1) 28 | x = self.relu(x) 29 | 30 | return x # , x112_64, x56_64, x28_128, x14_256 31 | 32 | 33 | def resnet18(pretrained=False, **kwargs): 34 | """Constructs a ResNet-18 model. 35 | 36 | Args: 37 | pretrained (bool): If True, returns a model pre-trained on ImageNet 38 | """ 39 | model = ResNetEncoder(BasicBlock, [2, 2, 2, 2], **kwargs) 40 | if pretrained: 41 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 42 | return model 43 | 44 | 45 | def resnet34(pretrained=False, **kwargs): 46 | """Constructs a ResNet-34 model. 47 | 48 | Args: 49 | pretrained (bool): If True, returns a model pre-trained on ImageNet 50 | """ 51 | model = ResNetEncoder(BasicBlock, [3, 4, 6, 3], **kwargs) 52 | if pretrained: 53 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 54 | return model 55 | 56 | 57 | def resnet50(pretrained=False, **kwargs): 58 | """Constructs a ResNet-50 model. 59 | 60 | Args: 61 | pretrained (bool): If True, returns a model pre-trained on ImageNet 62 | """ 63 | model = ResNetEncoder(Bottleneck, [3, 4, 6, 3], **kwargs) 64 | if pretrained: 65 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 66 | return model 67 | 68 | 69 | def resnet101(pretrained=False, **kwargs): 70 | """Constructs a ResNet-101 model. 71 | 72 | Args: 73 | pretrained (bool): If True, returns a model pre-trained on ImageNet 74 | """ 75 | model = ResNetEncoder(Bottleneck, [3, 4, 23, 3], **kwargs) 76 | if pretrained: 77 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 78 | return model 79 | 80 | 81 | def resnet152(pretrained=False, **kwargs): 82 | """Constructs a ResNet-152 model. 83 | 84 | Args: 85 | pretrained (bool): If True, returns a model pre-trained on ImageNet 86 | """ 87 | model = ResNetEncoder(Bottleneck, [3, 8, 36, 3], **kwargs) 88 | if pretrained: 89 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 90 | return model 91 | 92 | 93 | class Depth2AbsDepth(nn.Module): 94 | pass 95 | 96 | 97 | class RGBD2AbsDepth(nn.Module): 98 | pass 99 | 100 | 101 | class RGB2AbsDepth(nn.Module): 102 | pass 103 | 104 | 105 | class Depth2RelDepth(nn.Module): 106 | pass 107 | 108 | 109 | def extract_landmark_depth(depth, landmark, scale_factor, bbox, region_size=7, depth_scale=500., 110 | valid_depth_range=90, debug=False): 111 | bs = depth.size(0) 112 | img_size = depth.size(3) 113 | assert depth.size(2) == depth.size(3) 114 | num_landmark = landmark.size(1) 115 | # transform landmarks to face image coordinate system (bs x lm x 2) 116 | face_lm = ((landmark - bbox[:, :2].unsqueeze(1).to(depth)) * scale_factor.unsqueeze(1) / img_size * 2) - 1. 
117 | 118 | # sample landmark region (bs x lm x lm_size x lm_size) 119 | # gen sample grid (bs x lm x lm_size x lm_size x 2) 120 | x = th.linspace(-region_size / 2, region_size / 2, region_size) / img_size * 2 121 | grid = th.stack(th.meshgrid([x, x])[::-1], dim=2).to(depth) 122 | grid = face_lm.view(bs, num_landmark, 1, 1, 2) + grid 123 | depth_landmark_regions = F.grid_sample( 124 | depth, grid.view(bs, num_landmark, -1, 2), mode="nearest", padding_mode="zeros" 125 | ).squeeze(1) 126 | 127 | if debug: 128 | # visualize landmark 129 | for dep, lms, lmbs in zip(depth, face_lm, grid): 130 | depth_vis = np.uint8(dep.squeeze(0).cpu().numpy() * 255) 131 | depth_vis = np.stack([depth_vis, depth_vis, depth_vis], axis=2) 132 | for lm, lmb in zip(lms, lmbs): 133 | cv2.circle(depth_vis, tuple(((lm + 1) * 112).long().tolist()), 5, (0, 0, 255), 2) 134 | x1, y1 = int((lmb[0, 0, 0].item() + 1) * 112), int((lmb[0, 0, 1].item() + 1) * 112) 135 | x2, y2 = int((lmb[region_size-1, region_size-1, 0].item() + 1) * 112), int((lmb[region_size-1, region_size-1, 1].item() + 1) * 112) 136 | cv2.rectangle(depth_vis, (x1, y1), (x2, y2), (0, 255, 0), 2) 137 | cv2.imshow("res", depth_vis) 138 | cv2.waitKey() 139 | 140 | # non-zero median if exists, else return zero 141 | depth_landmark_regions_sorted = th.sort(depth_landmark_regions, dim=2)[0] 142 | depth_landmark_regions_mask = depth_landmark_regions_sorted > 1.e-4 143 | depth_landmark_regions_mask[:, :, 0] = 0 144 | depth_landmark_regions_mask[:, :, 1:] = depth_landmark_regions_mask[:, :, 1:] - \ 145 | depth_landmark_regions_mask[:, :, :-1] 146 | depth_landmark_regions_mask[:, :, -1] = ((depth_landmark_regions_mask.sum(dim=2) == 0) + depth_landmark_regions_mask[:, :, -1]) > 0 147 | assert (depth_landmark_regions_mask.sum(dim=2) == 1).all(), f"{th.sum(depth_landmark_regions_mask, dim=2)}\n{depth_landmark_regions_mask[:, :, depth_landmark_regions_mask.size(2) - 1]}\n{depth_landmark_regions_mask.sum(dim=2) == 0}" 148 | nonzero_st = th.nonzero(depth_landmark_regions_mask) 149 | 150 | assert (nonzero_st[1:, 0] - nonzero_st[:-1, 0] >= 0).all() and \ 151 | ((nonzero_st[1:, 0] * num_landmark + nonzero_st[1:, 1]) - 152 | (nonzero_st[:-1, 0] * num_landmark + nonzero_st[:-1, 1]) >= 0).all() 153 | assert nonzero_st.size(0) == bs * num_landmark 154 | median_ind = ((nonzero_st[:, 2] + region_size * region_size - 1) / 2).long() 155 | depth_landmark_regions_sorted = depth_landmark_regions_sorted.view(bs * num_landmark, region_size * region_size) 156 | median = depth_landmark_regions_sorted[range(len(median_ind)), median_ind].view(bs, num_landmark) 157 | median_mask = th.abs(median - th.median(median, dim=1)[0].unsqueeze(1)) < (valid_depth_range / depth_scale) 158 | 159 | return median, median_mask 160 | 161 | 162 | class RelMatrix2Col(nn.Module): 163 | def __init__(self, num_landmark=68): 164 | super(RelMatrix2Col, self).__init__() 165 | xx = th.arange(0, num_landmark).unsqueeze(1).expand(num_landmark, num_landmark) 166 | yy = xx.t() 167 | self.register_buffer("col_mask", xx > yy) 168 | 169 | def forward(self, rel_matrix): 170 | assert th.allclose(rel_matrix, -rel_matrix.transpose(rel_matrix.dim() - 1, rel_matrix.dim() - 2)) 171 | return rel_matrix[..., self.col_mask] 172 | 173 | 174 | class RelCol2Matrix(nn.Module): 175 | def __init__(self, num_landmark=68): 176 | super(RelCol2Matrix, self).__init__() 177 | self.num_landmark = num_landmark 178 | xx = th.arange(0, num_landmark).unsqueeze(1).expand(num_landmark, num_landmark) 179 | yy = xx.t() 180 | self.register_buffer("col_mask", 
xx > yy) 181 | 182 | def forward(self, rel_column): 183 | rel_matrix = rel_column.new_full((rel_column.size(0), self.num_landmark, self.num_landmark), 0) 184 | rel_matrix[..., self.col_mask] = rel_column 185 | rel_matrix[..., 1 - self.col_mask] = - rel_matrix.transpose( 186 | rel_matrix.dim() - 1, rel_matrix.dim() - 2 187 | )[..., 1 - self.col_mask] 188 | assert th.allclose(rel_matrix, -rel_matrix.transpose(rel_matrix.dim() - 1, rel_matrix.dim() - 2)) 189 | return rel_matrix 190 | 191 | 192 | class RGB2RelDepth(nn.Module): 193 | def __init__(self, num_landmark=68, dim_landmark_feat=128, initial_bias=None): 194 | super(RGB2RelDepth, self).__init__() 195 | self.encoder = resnet18(pretrained=True) 196 | self.num_landmark = num_landmark 197 | self.dim_rel_col = int(num_landmark * (num_landmark - 1) // 2) 198 | self.dim_embedding = dim_landmark_feat 199 | self.rel_depth = nn.Sequential( 200 | nn.Linear(512, self.dim_rel_col, bias=True), 201 | RelCol2Matrix(num_landmark=num_landmark) 202 | ) 203 | if initial_bias is not None: 204 | nn.init.constant_(self.rel_depth[0].bias, initial_bias) 205 | 206 | def forward(self, face_image): 207 | feat = self.encoder(face_image) 208 | rel_depth = self.rel_depth(feat) 209 | 210 | return rel_depth 211 | 212 | 213 | class LossRelDepth(nn.Module): 214 | def __init__(self, crit, num_landmark=68, image_size=224, landmark_region_size=7, depth_scale=500., 215 | valid_depth_range=90): 216 | super(LossRelDepth, self).__init__() 217 | self.crit = crit 218 | self.num_landmark = num_landmark 219 | self.image_size = image_size 220 | self.lm_region_size = landmark_region_size 221 | self.depth_scale = depth_scale 222 | self.valid_depth_range = valid_depth_range 223 | xx = th.arange(0, num_landmark).unsqueeze(1).expand(num_landmark, num_landmark) 224 | self.register_buffer("diag_mask", xx == xx.t()) 225 | 226 | def repeat_as_col(self, col): 227 | return col.view( 228 | col.size(0), self.num_landmark, 1, -1 229 | ).expand( 230 | col.size(0), self.num_landmark, self.num_landmark, -1 231 | ).squeeze() 232 | 233 | def repeat_as_row(self, row): 234 | return row.view( 235 | row.size(0), 1, self.num_landmark, -1 236 | ).expand( 237 | row.size(0), self.num_landmark, self.num_landmark, -1 238 | ).squeeze() 239 | 240 | def forward(self, rel_depth_pred, depth, landmarks, scale_factor, bbox): 241 | bs = rel_depth_pred.size(0) 242 | with th.no_grad(): 243 | landmark_dist = th.norm(self.repeat_as_col(landmarks) - self.repeat_as_row(landmarks), dim=3) 244 | assert th.allclose(landmark_dist, landmark_dist.transpose(1, 2)) 245 | median, median_mask = extract_landmark_depth( 246 | depth=depth, 247 | landmark=landmarks, 248 | scale_factor=scale_factor, 249 | bbox=bbox, 250 | region_size=self.lm_region_size, 251 | depth_scale=self.depth_scale, 252 | valid_depth_range=self.valid_depth_range 253 | ) 254 | median *= self.depth_scale 255 | median_diff = (self.repeat_as_col(median) - self.repeat_as_row(median)) #/ scale_factor.view(bs, 1, 1) 256 | median_rel_mask = (self.repeat_as_col(median_mask) * self.repeat_as_row(median_mask)).to(rel_depth_pred) 257 | # assert (median_diff[..., self.diag_mask] == 0).all() 258 | assert th.allclose(median_diff, -median_diff.transpose(1, 2)) 259 | loss_ele = self.crit(rel_depth_pred, median_diff, reduction='none') 260 | loss = th.sum(loss_ele * median_rel_mask) / (th.sum(median_rel_mask) + 1e-4) 261 | 262 | return loss, median_diff, median_rel_mask 263 | 264 | 265 | # class RGB2RelDepth(nn.Module): 266 | # pass 267 | 
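Aside (not part of the repository): RGB2RelDepth and LossRelDepth in gaze_depth_v4.py above pair a ResNet-18 face encoder with landmark-wise relative-depth supervision extracted from the raw depth crop. The sketch below runs one hypothetical training step; the batch fields (face and depth crops, 68 landmark coordinates, per-sample face box and crop scale factor) and every shape and value in it are assumptions for illustration only, and it presumes the older PyTorch/torchvision this repository targets (uint8 comparison masks, torchvision's model_urls).

import torch as th
from torch.nn import functional as F

model = RGB2RelDepth(num_landmark=68)          # downloads ImageNet ResNet-18 weights on first use
criterion = LossRelDepth(crit=F.smooth_l1_loss, num_landmark=68,
                         image_size=224, landmark_region_size=7)
optimizer = th.optim.Adam(model.parameters(), lr=1e-4)

# fabricated batch of 4 samples
face = th.rand(4, 3, 224, 224)                                 # RGB face crops
depth = th.rand(4, 1, 224, 224) * 0.1 + 0.5                    # depth crops, already divided by depth_scale
landmarks = th.rand(4, 68, 2) * 180. + th.tensor([110., 90.])  # landmarks in original-image coordinates
bbox = th.tensor([100., 80., 324., 304.]).expand(4, 4)         # face boxes (x1, y1, x2, y2)
scale_factor = th.ones(4, 1)                                   # 224-pixel boxes resized to 224, so factor 1

rel_depth_pred = model(face)                                   # (4, 68, 68) antisymmetric prediction
loss, target_diff, valid_mask = criterion(rel_depth_pred, depth, landmarks, scale_factor, bbox)

optimizer.zero_grad()
loss.backward()
optimizer.step()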
-------------------------------------------------------------------------------- /code/models/gaze_depth_v5.py: -------------------------------------------------------------------------------- 1 | from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet, model_urls 2 | import torch.utils.model_zoo as model_zoo 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import Parameter 6 | import torch as th 7 | import pdb 8 | import cv2 9 | import numpy as np 10 | 11 | 12 | class ResNetEncoder(ResNet): 13 | def forward(self, x): 14 | x = self.conv1(x) 15 | x = self.bn1(x) 16 | x = self.relu(x) 17 | # x112_64 = x 18 | x = self.maxpool(x) 19 | x = self.layer1(x) 20 | # x56_64 = x 21 | x = self.layer2(x) 22 | # x28_128 = x 23 | x = self.layer3(x) 24 | # x14_256 = x 25 | x = self.layer4(x) 26 | x = self.avgpool(x) 27 | x = x.view(x.size(0), -1) 28 | x = self.relu(x) 29 | 30 | return x # , x112_64, x56_64, x28_128, x14_256 31 | 32 | 33 | def resnet18(pretrained=False, **kwargs): 34 | """Constructs a ResNet-18 model. 35 | 36 | Args: 37 | pretrained (bool): If True, returns a model pre-trained on ImageNet 38 | """ 39 | model = ResNetEncoder(BasicBlock, [2, 2, 2, 2], **kwargs) 40 | if pretrained: 41 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 42 | return model 43 | 44 | 45 | def resnet34(pretrained=False, **kwargs): 46 | """Constructs a ResNet-34 model. 47 | 48 | Args: 49 | pretrained (bool): If True, returns a model pre-trained on ImageNet 50 | """ 51 | model = ResNetEncoder(BasicBlock, [3, 4, 6, 3], **kwargs) 52 | if pretrained: 53 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 54 | return model 55 | 56 | 57 | def resnet50(pretrained=False, **kwargs): 58 | """Constructs a ResNet-50 model. 59 | 60 | Args: 61 | pretrained (bool): If True, returns a model pre-trained on ImageNet 62 | """ 63 | model = ResNetEncoder(Bottleneck, [3, 4, 6, 3], **kwargs) 64 | if pretrained: 65 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 66 | return model 67 | 68 | 69 | def resnet101(pretrained=False, **kwargs): 70 | """Constructs a ResNet-101 model. 71 | 72 | Args: 73 | pretrained (bool): If True, returns a model pre-trained on ImageNet 74 | """ 75 | model = ResNetEncoder(Bottleneck, [3, 4, 23, 3], **kwargs) 76 | if pretrained: 77 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 78 | return model 79 | 80 | 81 | def resnet152(pretrained=False, **kwargs): 82 | """Constructs a ResNet-152 model. 83 | 84 | Args: 85 | pretrained (bool): If True, returns a model pre-trained on ImageNet 86 | """ 87 | model = ResNetEncoder(Bottleneck, [3, 8, 36, 3], **kwargs) 88 | if pretrained: 89 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 90 | return model 91 | 92 | 93 | class Depth2AbsDepth(nn.Module): 94 | pass 95 | 96 | 97 | class RGBD2AbsDepth(nn.Module): 98 | pass 99 | 100 | 101 | class RGB2AbsDepth(nn.Module): 102 | pass 103 | 104 | 105 | class Depth2RelDepth(nn.Module): 106 | pass 107 | 108 | 109 | def extract_landmark_depth(depth, landmark, scale_factor, bbox, region_size=7, depth_scale=500., 110 | valid_depth_range=90, debug=False): 111 | bs = depth.size(0) 112 | img_size = depth.size(3) 113 | assert depth.size(2) == depth.size(3) 114 | num_landmark = landmark.size(1) 115 | # transform landmarks to face image coordinate system (bs x lm x 2) 116 | face_lm = ((landmark - bbox[:, :2].unsqueeze(1).to(depth)) * scale_factor.unsqueeze(1) / img_size * 2) - 1. 
117 | 118 | # sample landmark region (bs x lm x lm_size x lm_size) 119 | # gen sample grid (bs x lm x lm_size x lm_size x 2) 120 | x = th.linspace(-region_size / 2, region_size / 2, region_size) / img_size * 2 121 | grid = th.stack(th.meshgrid([x, x])[::-1], dim=2).to(depth) 122 | grid = face_lm.view(bs, num_landmark, 1, 1, 2) + grid 123 | depth_landmark_regions = F.grid_sample( 124 | depth, grid.view(bs, num_landmark, -1, 2), mode="nearest", padding_mode="zeros" 125 | ).squeeze(1) 126 | 127 | if debug: 128 | # visualize landmark 129 | for dep, lms, lmbs in zip(depth, face_lm, grid): 130 | depth_vis = np.uint8(dep.squeeze(0).cpu().numpy() * 255) 131 | depth_vis = np.stack([depth_vis, depth_vis, depth_vis], axis=2) 132 | for lm, lmb in zip(lms, lmbs): 133 | cv2.circle(depth_vis, tuple(((lm + 1) * 112).long().tolist()), 5, (0, 0, 255), 2) 134 | x1, y1 = int((lmb[0, 0, 0].item() + 1) * 112), int((lmb[0, 0, 1].item() + 1) * 112) 135 | x2, y2 = int((lmb[region_size-1, region_size-1, 0].item() + 1) * 112), int((lmb[region_size-1, region_size-1, 1].item() + 1) * 112) 136 | cv2.rectangle(depth_vis, (x1, y1), (x2, y2), (0, 255, 0), 2) 137 | cv2.imshow("res", depth_vis) 138 | cv2.waitKey() 139 | 140 | # non-zero median if exists, else return zero 141 | depth_landmark_regions_sorted = th.sort(depth_landmark_regions, dim=2)[0] 142 | depth_landmark_regions_mask = depth_landmark_regions_sorted > 1.e-4 143 | depth_landmark_regions_mask[:, :, 0] = 0 144 | depth_landmark_regions_mask[:, :, 1:] = depth_landmark_regions_mask[:, :, 1:] - \ 145 | depth_landmark_regions_mask[:, :, :-1] 146 | depth_landmark_regions_mask[:, :, -1] = ((depth_landmark_regions_mask.sum(dim=2) == 0) + depth_landmark_regions_mask[:, :, -1]) > 0 147 | assert (depth_landmark_regions_mask.sum(dim=2) == 1).all(), f"{th.sum(depth_landmark_regions_mask, dim=2)}\n{depth_landmark_regions_mask[:, :, depth_landmark_regions_mask.size(2) - 1]}\n{depth_landmark_regions_mask.sum(dim=2) == 0}" 148 | nonzero_st = th.nonzero(depth_landmark_regions_mask) 149 | 150 | assert (nonzero_st[1:, 0] - nonzero_st[:-1, 0] >= 0).all() and \ 151 | ((nonzero_st[1:, 0] * num_landmark + nonzero_st[1:, 1]) - 152 | (nonzero_st[:-1, 0] * num_landmark + nonzero_st[:-1, 1]) >= 0).all() 153 | assert nonzero_st.size(0) == bs * num_landmark 154 | median_ind = ((nonzero_st[:, 2] + region_size * region_size - 1) / 2).long() 155 | depth_landmark_regions_sorted = depth_landmark_regions_sorted.view(bs * num_landmark, region_size * region_size) 156 | median = depth_landmark_regions_sorted[range(len(median_ind)), median_ind].view(bs, num_landmark) 157 | median_mask = th.abs(median - th.median(median, dim=1)[0].unsqueeze(1)) < (valid_depth_range / depth_scale) 158 | 159 | return median, median_mask 160 | 161 | 162 | class RelMatrix2Col(nn.Module): 163 | def __init__(self, num_landmark=68): 164 | super(RelMatrix2Col, self).__init__() 165 | xx = th.arange(0, num_landmark).unsqueeze(1).expand(num_landmark, num_landmark) 166 | yy = xx.t() 167 | self.register_buffer("col_mask", xx > yy) 168 | 169 | def forward(self, rel_matrix): 170 | assert th.allclose(rel_matrix, -rel_matrix.transpose(rel_matrix.dim() - 1, rel_matrix.dim() - 2)) 171 | return rel_matrix[..., self.col_mask] 172 | 173 | 174 | class RelCol2Matrix(nn.Module): 175 | def __init__(self, num_landmark=68): 176 | super(RelCol2Matrix, self).__init__() 177 | self.num_landmark = num_landmark 178 | xx = th.arange(0, num_landmark).unsqueeze(1).expand(num_landmark, num_landmark) 179 | yy = xx.t() 180 | self.register_buffer("col_mask", 
xx > yy) 181 | 182 | def forward(self, rel_column): 183 | rel_matrix = rel_column.new_full((rel_column.size(0), self.num_landmark, self.num_landmark), 0) 184 | rel_matrix[..., self.col_mask] = rel_column 185 | rel_matrix[..., 1 - self.col_mask] = - rel_matrix.transpose( 186 | rel_matrix.dim() - 1, rel_matrix.dim() - 2 187 | )[..., 1 - self.col_mask] 188 | assert th.allclose(rel_matrix, -rel_matrix.transpose(rel_matrix.dim() - 1, rel_matrix.dim() - 2)) 189 | return rel_matrix 190 | 191 | 192 | class RGB2RelDepth(nn.Module): 193 | def __init__(self, num_landmark=68, dim_landmark_feat=128, initial_bias=None): 194 | super(RGB2RelDepth, self).__init__() 195 | self.encoder = resnet18(pretrained=True) 196 | self.num_landmark = num_landmark 197 | self.dim_rel_col = int(num_landmark * (num_landmark - 1) // 2) 198 | self.dim_embedding = dim_landmark_feat 199 | self.rel_depth = nn.Sequential( 200 | nn.Linear(512, self.dim_rel_col, bias=True), 201 | RelCol2Matrix(num_landmark=num_landmark) 202 | ) 203 | if initial_bias is not None: 204 | nn.init.constant_(self.rel_depth[0].bias, initial_bias) 205 | 206 | def forward(self, face_image): 207 | feat = self.encoder(face_image) 208 | rel_depth = self.rel_depth(feat) 209 | 210 | return rel_depth 211 | 212 | 213 | class LossRelDepth(nn.Module): 214 | def __init__(self, crit, num_landmark=68, image_size=224, landmark_region_size=7, depth_scale=500., 215 | valid_depth_range=90): 216 | super(LossRelDepth, self).__init__() 217 | self.crit = crit 218 | self.num_landmark = num_landmark 219 | self.image_size = image_size 220 | self.lm_region_size = landmark_region_size 221 | self.depth_scale = depth_scale 222 | self.valid_depth_range = valid_depth_range 223 | xx = th.arange(0, num_landmark).unsqueeze(1).expand(num_landmark, num_landmark) 224 | self.register_buffer("diag_mask", xx == xx.t()) 225 | 226 | def repeat_as_col(self, col): 227 | return col.view( 228 | col.size(0), self.num_landmark, 1, -1 229 | ).expand( 230 | col.size(0), self.num_landmark, self.num_landmark, -1 231 | ).squeeze() 232 | 233 | def repeat_as_row(self, row): 234 | return row.view( 235 | row.size(0), 1, self.num_landmark, -1 236 | ).expand( 237 | row.size(0), self.num_landmark, self.num_landmark, -1 238 | ).squeeze() 239 | 240 | def forward(self, rel_depth_pred, depth, landmarks, scale_factor, bbox): 241 | bs = rel_depth_pred.size(0) 242 | with th.no_grad(): 243 | landmark_dist = th.norm(self.repeat_as_col(landmarks) - self.repeat_as_row(landmarks), dim=3) 244 | assert th.allclose(landmark_dist, landmark_dist.transpose(1, 2)) 245 | median, median_mask = extract_landmark_depth( 246 | depth=depth, 247 | landmark=landmarks, 248 | scale_factor=scale_factor, 249 | bbox=bbox, 250 | region_size=self.lm_region_size, 251 | depth_scale=self.depth_scale, 252 | valid_depth_range=self.valid_depth_range 253 | ) 254 | median *= self.depth_scale 255 | median_diff = (self.repeat_as_col(median) - self.repeat_as_row(median)) / scale_factor.view(bs, 1, 1) 256 | median_rel_mask = (self.repeat_as_col(median_mask) * self.repeat_as_row(median_mask)).to(rel_depth_pred) 257 | # assert (median_diff[..., self.diag_mask] == 0).all() 258 | assert th.allclose(median_diff, -median_diff.transpose(1, 2)) 259 | loss_ele = self.crit(rel_depth_pred, median_diff, reduction='none') 260 | loss = th.sum(loss_ele * median_rel_mask) / (th.sum(median_rel_mask) + 1e-4) 261 | 262 | return loss, median_diff, median_rel_mask 263 | 264 | 265 | # class RGB2RelDepth(nn.Module): 266 | # pass 267 | 
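Aside (not part of the repository): extract_landmark_depth in gaze_depth_v5.py above (the same helper appears in gaze_depth_v4.py) samples a region_size x region_size patch of the depth crop around every landmark with grid_sample, takes the median of the non-zero depth values in that patch (falling back to zero when the patch has no valid depth), and then flags landmarks whose depth deviates from the face's overall median by more than valid_depth_range / depth_scale in normalised depth units. A shape-level sketch with fabricated inputs, once more assuming the older PyTorch this code targets:

import torch as th

depth = th.rand(2, 1, 224, 224) * 0.1 + 0.5                      # 2 normalised face depth crops
landmarks = th.rand(2, 68, 2) * 180. + th.tensor([110., 90.])    # 68 landmarks, original-image coordinates
bbox = th.tensor([100., 80., 324., 304.]).expand(2, 4)           # face boxes (x1, y1, x2, y2)
scale_factor = th.ones(2, 1)                                     # 224-pixel boxes resized to 224

median, median_mask = extract_landmark_depth(
    depth, landmarks, scale_factor, bbox,
    region_size=7, depth_scale=500., valid_depth_range=90
)
# median:      (2, 68) per-landmark depth in the same normalised units as the input crop
# median_mask: (2, 68) nonzero where the landmark stays within 90 / 500 of the face median,
#              i.e. the landmarks whose pairwise differences LossRelDepth keeps in its target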
-------------------------------------------------------------------------------- /code/models/gaze_aaai_dv2.py: -------------------------------------------------------------------------------- 1 | from torchvision.models.resnet import BasicBlock, ResNet, model_urls 2 | import torch.utils.model_zoo as model_zoo 3 | from torch import nn 4 | import torch as th 5 | import torch.nn.functional as F 6 | 7 | 8 | class ResNetEncoder(ResNet): 9 | def forward(self, x): 10 | x = self.conv1(x) 11 | x = self.bn1(x) 12 | x = self.relu(x) 13 | x = self.maxpool(x) 14 | x = self.layer1(x) 15 | x = self.layer2(x) 16 | x = self.layer3(x) 17 | x = self.layer4(x) 18 | x = self.avgpool(x) 19 | x = x.view(x.size(0), -1) 20 | x = self.relu(x) 21 | return x 22 | 23 | 24 | def resnet34(pretrained=False, **kwargs): 25 | """Constructs a ResNet-34 model. 26 | 27 | Args: 28 | pretrained (bool): If True, returns a model pre-trained on ImageNet 29 | """ 30 | model = ResNetEncoder(BasicBlock, [3, 4, 6, 3], **kwargs) 31 | if pretrained: 32 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 33 | return model 34 | 35 | 36 | class Decoder(nn.Module): 37 | def __init__(self, feat_dim=512): 38 | super(Decoder, self).__init__() 39 | self.ldecoder = nn.Sequential( 40 | nn.Linear(feat_dim, 128), 41 | nn.ReLU(True), 42 | ) 43 | self.rdecoder = nn.Sequential( 44 | nn.Linear(feat_dim, 128), 45 | nn.ReLU(True), 46 | ) 47 | self.lcoord = nn.Sequential( 48 | nn.Linear(128 + 16 + 3, 64), 49 | nn.ReLU(True), 50 | nn.Linear(64, 2) 51 | ) 52 | self.rcoord = nn.Sequential( 53 | nn.Linear(128 + 16 + 3, 64), 54 | nn.ReLU(True), 55 | nn.Linear(64, 2) 56 | ) 57 | 58 | def forward(self, lfeat, rfeat, head_pose, linfo, rinfo): 59 | l_coord_feat = self.ldecoder(lfeat) 60 | r_coord_feat = self.rdecoder(rfeat) 61 | l_coord = self.lcoord(th.cat([l_coord_feat, head_pose, linfo], 1)) 62 | r_coord = self.rcoord(th.cat([r_coord_feat, head_pose, rinfo], 1)) 63 | coord = (l_coord + r_coord) / 2. 64 | # coord = self.coord(th.cat([l_coord, r_coord], 1)) 65 | return coord 66 | 67 | 68 | class DepthL1(nn.Module): 69 | def __init__(self, th_lower=None, th_upper=None): 70 | super(DepthL1, self).__init__() 71 | self.th_lower = th_lower 72 | self.th_upper = th_upper 73 | 74 | def forward(self, pred, target): 75 | pred = pred.view(pred.size(0), -1) 76 | target = target.view(target.size(0), -1) 77 | if self.th_lower is not None: 78 | with th.no_grad(): 79 | mask_lower = (target > self.th_lower).float() 80 | else: 81 | mask_lower = 1. 82 | if self.th_upper is not None: 83 | with th.no_grad(): 84 | mask_upper = (target < self.th_upper).float() 85 | else: 86 | mask_upper = 1. 87 | 88 | return th.sum(th.abs(pred - target) * mask_lower * mask_upper) / (th.sum(mask_lower * mask_upper) + 1e-5) 89 | 90 | 91 | class DepthBCE(nn.Module): 92 | def __init__(self, th_lower=None, th_upper=None): 93 | super(DepthBCE, self).__init__() 94 | self.th_lower = th_lower 95 | self.th_upper = th_upper 96 | 97 | def forward(self, pred, target): 98 | pred = pred.view(pred.size(0), -1) 99 | target = target.view(target.size(0), -1) 100 | if self.th_lower is not None: 101 | with th.no_grad(): 102 | mask_lower = (target > self.th_lower).float() 103 | else: 104 | mask_lower = 1. 105 | if self.th_upper is not None: 106 | with th.no_grad(): 107 | mask_upper = (target < self.th_upper).float() 108 | else: 109 | mask_upper = 1. 
110 | 111 | weight = mask_upper * mask_lower 112 | return F.binary_cross_entropy(pred, target, weight=weight if not isinstance(weight, float) else None) 113 | 114 | 115 | class ResnetBlock(nn.Module): 116 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 117 | super(ResnetBlock, self).__init__() 118 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 119 | 120 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 121 | conv_block = [] 122 | p = 0 123 | if padding_type == 'reflect': 124 | conv_block += [nn.ReflectionPad2d(1)] 125 | elif padding_type == 'replicate': 126 | conv_block += [nn.ReplicationPad2d(1)] 127 | elif padding_type == 'zero': 128 | p = 1 129 | else: 130 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 131 | 132 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 133 | norm_layer(dim), 134 | nn.ReLU(True)] 135 | if use_dropout: 136 | conv_block += [nn.Dropout(0.5)] 137 | 138 | p = 0 139 | if padding_type == 'reflect': 140 | conv_block += [nn.ReflectionPad2d(1)] 141 | elif padding_type == 'replicate': 142 | conv_block += [nn.ReplicationPad2d(1)] 143 | elif padding_type == 'zero': 144 | p = 1 145 | else: 146 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 147 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 148 | norm_layer(dim)] 149 | 150 | return nn.Sequential(*conv_block) 151 | 152 | def forward(self, x): 153 | out = x + self.conv_block(x) 154 | return F.relu(out) 155 | 156 | 157 | class RefineDepth(nn.Module): 158 | def __init__(self): 159 | super(RefineDepth, self).__init__() 160 | use_bias = False 161 | 162 | self.face_block1 = nn.Sequential( 163 | nn.ReflectionPad2d(3), 164 | nn.Conv2d(3, 64, kernel_size=7, padding=0, bias=use_bias), 165 | nn.BatchNorm2d(64), 166 | nn.ReLU(True) 167 | ) 168 | 169 | self.face_block2 = nn.Sequential( 170 | nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=use_bias), 171 | nn.BatchNorm2d(128), 172 | nn.ReLU(True) 173 | ) 174 | 175 | self.face_block3 = nn.Sequential( 176 | nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=use_bias), 177 | nn.BatchNorm2d(256), 178 | nn.ReLU(True) 179 | ) 180 | 181 | self.face_block4 = nn.Sequential( 182 | nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=use_bias), 183 | nn.BatchNorm2d(512), 184 | nn.ReLU(True) 185 | ) 186 | 187 | self.depth_block1 = nn.Sequential( 188 | nn.ReflectionPad2d(3), 189 | nn.Conv2d(1, 64, kernel_size=7, padding=0, bias=use_bias), 190 | nn.BatchNorm2d(64), 191 | nn.ReLU(True) 192 | ) 193 | 194 | self.depth_block2 = nn.Sequential( 195 | nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=use_bias), 196 | nn.BatchNorm2d(128), 197 | nn.ReLU(True) 198 | ) 199 | 200 | self.depth_block3 = nn.Sequential( 201 | nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=use_bias), 202 | nn.BatchNorm2d(256), 203 | nn.ReLU(True) 204 | ) 205 | self.depth_block4 = nn.Sequential( 206 | nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=use_bias), 207 | nn.BatchNorm2d(512), 208 | nn.ReLU(True) 209 | ) 210 | 211 | self.down1 = nn.Sequential( 212 | nn.Conv2d(1024, 512, kernel_size=1, stride=1, padding=0, bias=use_bias), 213 | nn.BatchNorm2d(512), 214 | nn.ReLU(True), 215 | ResnetBlock(512, padding_type='reflect', norm_layer=nn.BatchNorm2d, 216 | use_dropout=False, use_bias=use_bias), 217 | # ResnetBlock(512, padding_type='reflect', 
norm_layer=nn.BatchNorm2d, 218 | # use_dropout=False, use_bias=use_bias), 219 | # ResnetBlock(512, padding_type='reflect', norm_layer=nn.BatchNorm2d, 220 | # use_dropout=False, use_bias=use_bias) 221 | ) 222 | self.down2 = nn.Sequential( 223 | nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0, bias=use_bias), 224 | nn.BatchNorm2d(256), 225 | nn.ReLU(True), 226 | # ResnetBlock(256, padding_type='reflect', norm_layer=nn.BatchNorm2d, 227 | # use_dropout=False, use_bias=use_bias), 228 | # ResnetBlock(256, padding_type='reflect', norm_layer=nn.BatchNorm2d, 229 | # use_dropout=False, use_bias=use_bias) 230 | ) 231 | 232 | self.down3 = nn.Sequential( 233 | nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=use_bias), 234 | nn.BatchNorm2d(128), 235 | nn.ReLU(True), 236 | # ResnetBlock(128, padding_type='reflect', norm_layer=nn.BatchNorm2d, 237 | # use_dropout=False, use_bias=use_bias) 238 | ) 239 | 240 | self.down4 = nn.Sequential( 241 | nn.Conv2d(128, 64, kernel_size=1, stride=1, padding=0, bias=use_bias), 242 | nn.BatchNorm2d(64), 243 | nn.ReLU(True), 244 | # ResnetBlock(64, padding_type='reflect', norm_layer=nn.BatchNorm2d, 245 | # use_dropout=False, use_bias=use_bias) 246 | ) 247 | 248 | self.head_pose = nn.Sequential( 249 | nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1, bias=use_bias), 250 | nn.BatchNorm2d(512), 251 | nn.ReLU(True), 252 | nn.Conv2d(512, 64, kernel_size=3, stride=2, padding=1, bias=use_bias), 253 | nn.BatchNorm2d(64), 254 | nn.ReLU(True), 255 | nn.AvgPool2d(7), 256 | nn.Conv2d(64, 16, kernel_size=1, stride=1, padding=0, bias=True), 257 | # nn.BatchNorm2d(128), 258 | nn.ReLU(True) 259 | ) 260 | 261 | self.gen_block1 = nn.Sequential( 262 | nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), 263 | nn.BatchNorm2d(256), 264 | nn.ReLU(True) 265 | ) 266 | 267 | self.gen_block2 = nn.Sequential( 268 | nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), 269 | nn.BatchNorm2d(128), 270 | nn.ReLU(True) 271 | ) 272 | 273 | self.gen_block3 = nn.Sequential( 274 | nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), 275 | nn.BatchNorm2d(64), 276 | nn.ReLU(True) 277 | ) 278 | self.gen_block4 = nn.Sequential( 279 | nn.ReflectionPad2d(3), 280 | nn.Conv2d(64, 1, kernel_size=7, padding=0), 281 | nn.Sigmoid() 282 | ) 283 | 284 | def forward(self, face, depth): 285 | face_f1 = self.face_block1(face) 286 | face_f2 = self.face_block2(face_f1) 287 | face_f3 = self.face_block3(face_f2) 288 | face_f4 = self.face_block4(face_f3) 289 | depth_f1 = self.depth_block1(depth) 290 | depth_f2 = self.depth_block2(depth_f1) 291 | depth_f3 = self.depth_block3(depth_f2) 292 | depth_f4 = self.depth_block4(depth_f3) 293 | mixed_f4 = self.down1(th.cat([face_f4, depth_f4], dim=1)) 294 | mixed_f3 = self.down2(th.cat([face_f3, depth_f3], dim=1)) 295 | mixed_f2 = self.down3(th.cat([face_f2, depth_f2], dim=1)) 296 | mixed_f1 = self.down4(th.cat([face_f1, depth_f1], dim=1)) 297 | gen_f3 = self.gen_block1(mixed_f4) + mixed_f3 298 | gen_f2 = self.gen_block2(gen_f3) + mixed_f2 299 | gen_f1 = self.gen_block3(gen_f2) + mixed_f1 300 | gen_depth = self.gen_block4(gen_f1) 301 | head_pose = self.head_pose(mixed_f4) 302 | return head_pose.view(head_pose.size(0), -1), gen_depth 303 | -------------------------------------------------------------------------------- /code/models/gaze_aaai.py: -------------------------------------------------------------------------------- 1 | from 
torchvision.models.resnet import BasicBlock, ResNet, model_urls 2 | import torch.utils.model_zoo as model_zoo 3 | from torch import nn 4 | import torch as th 5 | import torch.nn.functional as F 6 | 7 | 8 | class ResNetEncoder(ResNet): 9 | def forward(self, x): 10 | x = self.conv1(x) 11 | x = self.bn1(x) 12 | x = self.relu(x) 13 | x = self.maxpool(x) 14 | x = self.layer1(x) 15 | x = self.layer2(x) 16 | x = self.layer3(x) 17 | x = self.layer4(x) 18 | x = self.avgpool(x) 19 | x = x.view(x.size(0), -1) 20 | x = self.relu(x) 21 | return x 22 | 23 | 24 | def resnet34(pretrained=False, **kwargs): 25 | """Constructs a ResNet-34 model. 26 | 27 | Args: 28 | pretrained (bool): If True, returns a model pre-trained on ImageNet 29 | """ 30 | model = ResNetEncoder(BasicBlock, [3, 4, 6, 3], **kwargs) 31 | if pretrained: 32 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 33 | return model 34 | 35 | 36 | class Decoder(nn.Module): 37 | def __init__(self, feat_dim=512): 38 | super(Decoder, self).__init__() 39 | self.ldecoder = nn.Sequential( 40 | nn.Linear(feat_dim, 128), 41 | nn.ReLU(True), 42 | ) 43 | self.rdecoder = nn.Sequential( 44 | nn.Linear(feat_dim, 128), 45 | nn.ReLU(True), 46 | ) 47 | self.lcoord = nn.Sequential( 48 | nn.Linear(128 + 128 + 3, 64), 49 | nn.ReLU(True), 50 | nn.Linear(64, 2) 51 | ) 52 | self.rcoord = nn.Sequential( 53 | nn.Linear(128 + 128 + 3, 64), 54 | nn.ReLU(True), 55 | nn.Linear(64, 2) 56 | ) 57 | 58 | def forward(self, lfeat, rfeat, head_pose, linfo, rinfo): 59 | l_coord_feat = self.ldecoder(lfeat) 60 | r_coord_feat = self.rdecoder(rfeat) 61 | l_coord = self.lcoord(th.cat([l_coord_feat, head_pose, linfo], 1)) 62 | r_coord = self.rcoord(th.cat([r_coord_feat, head_pose, rinfo], 1)) 63 | coord = (l_coord + r_coord) / 2. 64 | # coord = self.coord(th.cat([l_coord, r_coord], 1)) 65 | return coord 66 | 67 | 68 | class DepthL1(nn.Module): 69 | def __init__(self, th_lower=None, th_upper=None): 70 | super(DepthL1, self).__init__() 71 | self.th_lower = th_lower 72 | self.th_upper = th_upper 73 | 74 | def forward(self, pred, target): 75 | pred = pred.view(pred.size(0), -1) 76 | target = target.view(target.size(0), -1) 77 | if self.th_lower is not None: 78 | with th.no_grad(): 79 | mask_lower = (target > self.th_lower).float() 80 | else: 81 | mask_lower = 1. 82 | if self.th_upper is not None: 83 | with th.no_grad(): 84 | mask_upper = (target < self.th_upper).float() 85 | else: 86 | mask_upper = 1. 87 | 88 | return th.sum(th.abs(pred - target) * mask_lower * mask_upper) / (th.sum(mask_lower * mask_upper) + 1e-5) 89 | 90 | 91 | class DepthBCE(nn.Module): 92 | def __init__(self, th_lower=None, th_upper=None): 93 | super(DepthBCE, self).__init__() 94 | self.th_lower = th_lower 95 | self.th_upper = th_upper 96 | 97 | def forward(self, pred, target): 98 | pred = pred.view(pred.size(0), -1) 99 | target = target.view(target.size(0), -1) 100 | if self.th_lower is not None: 101 | with th.no_grad(): 102 | mask_lower = (target > self.th_lower).float() 103 | else: 104 | mask_lower = 1. 105 | if self.th_upper is not None: 106 | with th.no_grad(): 107 | mask_upper = (target < self.th_upper).float() 108 | else: 109 | mask_upper = 1. 
110 | 111 | weight = mask_upper * mask_lower 112 | return F.binary_cross_entropy(pred, target, weight=weight if not isinstance(weight, float) else None) 113 | 114 | 115 | class ResnetBlock(nn.Module): 116 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 117 | super(ResnetBlock, self).__init__() 118 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 119 | 120 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 121 | conv_block = [] 122 | p = 0 123 | if padding_type == 'reflect': 124 | conv_block += [nn.ReflectionPad2d(1)] 125 | elif padding_type == 'replicate': 126 | conv_block += [nn.ReplicationPad2d(1)] 127 | elif padding_type == 'zero': 128 | p = 1 129 | else: 130 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 131 | 132 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 133 | norm_layer(dim), 134 | nn.ReLU(True)] 135 | if use_dropout: 136 | conv_block += [nn.Dropout(0.5)] 137 | 138 | p = 0 139 | if padding_type == 'reflect': 140 | conv_block += [nn.ReflectionPad2d(1)] 141 | elif padding_type == 'replicate': 142 | conv_block += [nn.ReplicationPad2d(1)] 143 | elif padding_type == 'zero': 144 | p = 1 145 | else: 146 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 147 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 148 | norm_layer(dim)] 149 | 150 | return nn.Sequential(*conv_block) 151 | 152 | def forward(self, x): 153 | out = x + self.conv_block(x) 154 | return F.relu(out) 155 | 156 | 157 | class RefineDepth(nn.Module): 158 | def __init__(self): 159 | super(RefineDepth, self).__init__() 160 | use_bias = False 161 | 162 | self.face_block1 = nn.Sequential( 163 | nn.ReflectionPad2d(3), 164 | nn.Conv2d(3, 64, kernel_size=7, padding=0, bias=use_bias), 165 | nn.BatchNorm2d(64), 166 | nn.ReLU(True) 167 | ) 168 | 169 | self.face_block2 = nn.Sequential( 170 | nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=use_bias), 171 | nn.BatchNorm2d(128), 172 | nn.ReLU(True) 173 | ) 174 | 175 | self.face_block3 = nn.Sequential( 176 | nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=use_bias), 177 | nn.BatchNorm2d(256), 178 | nn.ReLU(True) 179 | ) 180 | 181 | self.face_block4 = nn.Sequential( 182 | nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=use_bias), 183 | nn.BatchNorm2d(512), 184 | nn.ReLU(True) 185 | ) 186 | 187 | self.depth_block1 = nn.Sequential( 188 | nn.ReflectionPad2d(3), 189 | nn.Conv2d(1, 64, kernel_size=7, padding=0, bias=use_bias), 190 | nn.BatchNorm2d(64), 191 | nn.ReLU(True) 192 | ) 193 | 194 | self.depth_block2 = nn.Sequential( 195 | nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=use_bias), 196 | nn.BatchNorm2d(128), 197 | nn.ReLU(True) 198 | ) 199 | 200 | self.depth_block3 = nn.Sequential( 201 | nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=use_bias), 202 | nn.BatchNorm2d(256), 203 | nn.ReLU(True) 204 | ) 205 | self.depth_block4 = nn.Sequential( 206 | nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=use_bias), 207 | nn.BatchNorm2d(512), 208 | nn.ReLU(True) 209 | ) 210 | 211 | self.down1 = nn.Sequential( 212 | nn.Conv2d(1024, 512, kernel_size=1, stride=1, padding=0, bias=use_bias), 213 | nn.BatchNorm2d(512), 214 | nn.ReLU(True), 215 | ResnetBlock(512, padding_type='reflect', norm_layer=nn.BatchNorm2d, 216 | use_dropout=False, use_bias=use_bias), 217 | # ResnetBlock(512, padding_type='reflect', 
norm_layer=nn.BatchNorm2d, 218 | # use_dropout=False, use_bias=use_bias), 219 | # ResnetBlock(512, padding_type='reflect', norm_layer=nn.BatchNorm2d, 220 | # use_dropout=False, use_bias=use_bias) 221 | ) 222 | self.down2 = nn.Sequential( 223 | nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0, bias=use_bias), 224 | nn.BatchNorm2d(256), 225 | nn.ReLU(True), 226 | # ResnetBlock(256, padding_type='reflect', norm_layer=nn.BatchNorm2d, 227 | # use_dropout=False, use_bias=use_bias), 228 | # ResnetBlock(256, padding_type='reflect', norm_layer=nn.BatchNorm2d, 229 | # use_dropout=False, use_bias=use_bias) 230 | ) 231 | 232 | self.down3 = nn.Sequential( 233 | nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=use_bias), 234 | nn.BatchNorm2d(128), 235 | nn.ReLU(True), 236 | # ResnetBlock(128, padding_type='reflect', norm_layer=nn.BatchNorm2d, 237 | # use_dropout=False, use_bias=use_bias) 238 | ) 239 | 240 | self.down4 = nn.Sequential( 241 | nn.Conv2d(128, 64, kernel_size=1, stride=1, padding=0, bias=use_bias), 242 | nn.BatchNorm2d(64), 243 | nn.ReLU(True), 244 | # ResnetBlock(64, padding_type='reflect', norm_layer=nn.BatchNorm2d, 245 | # use_dropout=False, use_bias=use_bias) 246 | ) 247 | 248 | self.head_pose = nn.Sequential( 249 | nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1, bias=use_bias), 250 | nn.BatchNorm2d(512), 251 | nn.ReLU(True), 252 | nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=use_bias), 253 | nn.BatchNorm2d(1024), 254 | nn.ReLU(True), 255 | nn.AvgPool2d(7), 256 | nn.Conv2d(1024, 128, kernel_size=1, stride=1, padding=0, bias=True), 257 | # nn.BatchNorm2d(128), 258 | # nn.ReLU(True) 259 | ) 260 | 261 | self.gen_block1 = nn.Sequential( 262 | nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), 263 | nn.BatchNorm2d(256), 264 | nn.ReLU(True) 265 | ) 266 | 267 | self.gen_block2 = nn.Sequential( 268 | nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), 269 | nn.BatchNorm2d(128), 270 | nn.ReLU(True) 271 | ) 272 | 273 | self.gen_block3 = nn.Sequential( 274 | nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), 275 | nn.BatchNorm2d(64), 276 | nn.ReLU(True) 277 | ) 278 | self.gen_block4 = nn.Sequential( 279 | nn.ReflectionPad2d(3), 280 | nn.Conv2d(64, 1, kernel_size=7, padding=0), 281 | nn.Sigmoid() 282 | ) 283 | 284 | def forward(self, face, depth): 285 | face_f1 = self.face_block1(face) 286 | face_f2 = self.face_block2(face_f1) 287 | face_f3 = self.face_block3(face_f2) 288 | face_f4 = self.face_block4(face_f3) 289 | depth_f1 = self.depth_block1(depth) 290 | depth_f2 = self.depth_block2(depth_f1) 291 | depth_f3 = self.depth_block3(depth_f2) 292 | depth_f4 = self.depth_block4(depth_f3) 293 | mixed_f4 = self.down1(th.cat([face_f4, depth_f4], dim=1)) 294 | mixed_f3 = self.down2(th.cat([face_f3, depth_f3], dim=1)) 295 | mixed_f2 = self.down3(th.cat([face_f2, depth_f2], dim=1)) 296 | mixed_f1 = self.down4(th.cat([face_f1, depth_f1], dim=1)) 297 | gen_f3 = self.gen_block1(mixed_f4) + mixed_f3 298 | gen_f2 = self.gen_block2(gen_f3) + mixed_f2 299 | gen_f1 = self.gen_block3(gen_f2) + mixed_f1 300 | gen_depth = self.gen_block4(gen_f1) 301 | head_pose = self.head_pose(mixed_f4) 302 | return head_pose.view(head_pose.size(0), -1), gen_depth 303 | -------------------------------------------------------------------------------- /code/models/gaze_aaai_is.py: 
-------------------------------------------------------------------------------- 1 | from torchvision.models.resnet import BasicBlock, ResNet, model_urls 2 | import torch.utils.model_zoo as model_zoo 3 | from torch import nn 4 | import torch as th 5 | import torch.nn.functional as F 6 | 7 | 8 | class ResNetEncoder(ResNet): 9 | def forward(self, x): 10 | x = self.conv1(x) 11 | x = self.bn1(x) 12 | x = self.relu(x) 13 | x = self.maxpool(x) 14 | x = self.layer1(x) 15 | x = self.layer2(x) 16 | x = self.layer3(x) 17 | x = self.layer4(x) 18 | x = self.avgpool(x) 19 | x = x.view(x.size(0), -1) 20 | x = self.relu(x) 21 | return x 22 | 23 | 24 | def resnet34(pretrained=False, **kwargs): 25 | """Constructs a ResNet-34 model. 26 | 27 | Args: 28 | pretrained (bool): If True, returns a model pre-trained on ImageNet 29 | """ 30 | model = ResNetEncoder(BasicBlock, [3, 4, 6, 3], **kwargs) 31 | if pretrained: 32 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 33 | return model 34 | 35 | 36 | class Decoder(nn.Module): 37 | def __init__(self, feat_dim=512): 38 | super(Decoder, self).__init__() 39 | self.ldecoder = nn.Sequential( 40 | nn.Linear(feat_dim, 128), 41 | nn.ReLU(True), 42 | ) 43 | self.rdecoder = nn.Sequential( 44 | nn.Linear(feat_dim, 128), 45 | nn.ReLU(True), 46 | ) 47 | self.lcoord = nn.Sequential( 48 | nn.Linear(128 + 128 + 3, 64), 49 | nn.ReLU(True), 50 | nn.Linear(64, 2) 51 | ) 52 | self.rcoord = nn.Sequential( 53 | nn.Linear(128 + 128 + 3, 64), 54 | nn.ReLU(True), 55 | nn.Linear(64, 2) 56 | ) 57 | 58 | def forward(self, lfeat, rfeat, head_pose, linfo, rinfo): 59 | l_coord_feat = self.ldecoder(lfeat) 60 | r_coord_feat = self.rdecoder(rfeat) 61 | l_coord = self.lcoord(th.cat([l_coord_feat, head_pose, linfo], 1)) 62 | r_coord = self.rcoord(th.cat([r_coord_feat, head_pose, rinfo], 1)) 63 | coord = (l_coord + r_coord) / 2. 64 | # coord = self.coord(th.cat([l_coord, r_coord], 1)) 65 | return coord 66 | 67 | 68 | class DepthL1(nn.Module): 69 | def __init__(self, th_lower=None, th_upper=None): 70 | super(DepthL1, self).__init__() 71 | self.th_lower = th_lower 72 | self.th_upper = th_upper 73 | 74 | def forward(self, pred, target): 75 | pred = pred.view(pred.size(0), -1) 76 | target = target.view(target.size(0), -1) 77 | if self.th_lower is not None: 78 | with th.no_grad(): 79 | mask_lower = (target > self.th_lower).float() 80 | else: 81 | mask_lower = 1. 82 | if self.th_upper is not None: 83 | with th.no_grad(): 84 | mask_upper = (target < self.th_upper).float() 85 | else: 86 | mask_upper = 1. 87 | 88 | return th.sum(th.abs(pred - target) * mask_lower * mask_upper) / (th.sum(mask_lower * mask_upper) + 1e-5) 89 | 90 | 91 | class DepthBCE(nn.Module): 92 | def __init__(self, th_lower=None, th_upper=None): 93 | super(DepthBCE, self).__init__() 94 | self.th_lower = th_lower 95 | self.th_upper = th_upper 96 | 97 | def forward(self, pred, target): 98 | pred = pred.view(pred.size(0), -1) 99 | target = target.view(target.size(0), -1) 100 | if self.th_lower is not None: 101 | with th.no_grad(): 102 | mask_lower = (target > self.th_lower).float() 103 | else: 104 | mask_lower = 1. 105 | if self.th_upper is not None: 106 | with th.no_grad(): 107 | mask_upper = (target < self.th_upper).float() 108 | else: 109 | mask_upper = 1. 
110 | 111 | weight = mask_upper * mask_lower 112 | return F.binary_cross_entropy(pred, target, weight=weight if not isinstance(weight, float) else None) 113 | 114 | 115 | class ResnetBlock(nn.Module): 116 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 117 | super(ResnetBlock, self).__init__() 118 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 119 | 120 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 121 | conv_block = [] 122 | p = 0 123 | if padding_type == 'reflect': 124 | conv_block += [nn.ReflectionPad2d(1)] 125 | elif padding_type == 'replicate': 126 | conv_block += [nn.ReplicationPad2d(1)] 127 | elif padding_type == 'zero': 128 | p = 1 129 | else: 130 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 131 | 132 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 133 | norm_layer(dim), 134 | nn.ReLU(True)] 135 | if use_dropout: 136 | conv_block += [nn.Dropout(0.5)] 137 | 138 | p = 0 139 | if padding_type == 'reflect': 140 | conv_block += [nn.ReflectionPad2d(1)] 141 | elif padding_type == 'replicate': 142 | conv_block += [nn.ReplicationPad2d(1)] 143 | elif padding_type == 'zero': 144 | p = 1 145 | else: 146 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 147 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 148 | norm_layer(dim)] 149 | 150 | return nn.Sequential(*conv_block) 151 | 152 | def forward(self, x): 153 | out = x + self.conv_block(x) 154 | return F.relu(out) 155 | 156 | 157 | class RefineDepth(nn.Module): 158 | def __init__(self): 159 | super(RefineDepth, self).__init__() 160 | use_bias = False 161 | 162 | self.face_block1 = nn.Sequential( 163 | nn.ReflectionPad2d(3), 164 | nn.Conv2d(3, 64, kernel_size=7, padding=0, bias=use_bias), 165 | nn.BatchNorm2d(64), 166 | nn.ReLU(True) 167 | ) 168 | 169 | self.face_block2 = nn.Sequential( 170 | nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=use_bias), 171 | nn.BatchNorm2d(128), 172 | nn.ReLU(True) 173 | ) 174 | 175 | self.face_block3 = nn.Sequential( 176 | nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=use_bias), 177 | nn.BatchNorm2d(256), 178 | nn.ReLU(True) 179 | ) 180 | 181 | self.face_block4 = nn.Sequential( 182 | nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=use_bias), 183 | nn.BatchNorm2d(512), 184 | nn.ReLU(True) 185 | ) 186 | 187 | self.depth_block1 = nn.Sequential( 188 | nn.ReflectionPad2d(3), 189 | nn.Conv2d(1, 64, kernel_size=7, padding=0, bias=use_bias), 190 | nn.BatchNorm2d(64), 191 | nn.ReLU(True) 192 | ) 193 | 194 | self.depth_block2 = nn.Sequential( 195 | nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=use_bias), 196 | nn.BatchNorm2d(128), 197 | nn.ReLU(True) 198 | ) 199 | 200 | self.depth_block3 = nn.Sequential( 201 | nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=use_bias), 202 | nn.BatchNorm2d(256), 203 | nn.ReLU(True) 204 | ) 205 | self.depth_block4 = nn.Sequential( 206 | nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=use_bias), 207 | nn.BatchNorm2d(512), 208 | nn.ReLU(True) 209 | ) 210 | 211 | self.down1 = nn.Sequential( 212 | nn.Conv2d(1024, 512, kernel_size=1, stride=1, padding=0, bias=use_bias), 213 | nn.BatchNorm2d(512), 214 | nn.ReLU(True), 215 | ResnetBlock(512, padding_type='reflect', norm_layer=nn.BatchNorm2d, 216 | use_dropout=False, use_bias=use_bias), 217 | # ResnetBlock(512, padding_type='reflect', 
norm_layer=nn.BatchNorm2d, 218 | # use_dropout=False, use_bias=use_bias), 219 | # ResnetBlock(512, padding_type='reflect', norm_layer=nn.BatchNorm2d, 220 | # use_dropout=False, use_bias=use_bias) 221 | ) 222 | self.down2 = nn.Sequential( 223 | nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0, bias=use_bias), 224 | nn.BatchNorm2d(256), 225 | nn.ReLU(True), 226 | # ResnetBlock(256, padding_type='reflect', norm_layer=nn.BatchNorm2d, 227 | # use_dropout=False, use_bias=use_bias), 228 | # ResnetBlock(256, padding_type='reflect', norm_layer=nn.BatchNorm2d, 229 | # use_dropout=False, use_bias=use_bias) 230 | ) 231 | 232 | self.down3 = nn.Sequential( 233 | nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0, bias=use_bias), 234 | nn.BatchNorm2d(128), 235 | nn.ReLU(True), 236 | # ResnetBlock(128, padding_type='reflect', norm_layer=nn.BatchNorm2d, 237 | # use_dropout=False, use_bias=use_bias) 238 | ) 239 | 240 | self.down4 = nn.Sequential( 241 | nn.Conv2d(128, 64, kernel_size=1, stride=1, padding=0, bias=use_bias), 242 | nn.BatchNorm2d(64), 243 | nn.ReLU(True), 244 | # ResnetBlock(64, padding_type='reflect', norm_layer=nn.BatchNorm2d, 245 | # use_dropout=False, use_bias=use_bias) 246 | ) 247 | 248 | self.head_pose1 = nn.Sequential( 249 | nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1, bias=use_bias), 250 | nn.BatchNorm2d(512), 251 | nn.ReLU(True), 252 | nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=use_bias), 253 | nn.BatchNorm2d(1024), 254 | nn.ReLU(True), 255 | nn.AvgPool2d(7), 256 | nn.Conv2d(1024, 128, kernel_size=1, stride=1, padding=0, bias=True), 257 | nn.ReLU(True) 258 | # nn.BatchNorm2d(128), 259 | # nn.ReLU(True) 260 | ) 261 | 262 | self.head_pose2 = nn.Sequential( 263 | nn.Conv2d(128, 256, kernel_size=1, stride=1, padding=0, bias=True), 264 | nn.ReLU(True), 265 | nn.Conv2d(256, 64, kernel_size=1, stride=1, padding=0, bias=True), 266 | nn.ReLU(True), 267 | nn.Conv2d(64, 2, kernel_size=1, stride=1, padding=0, bias=True), 268 | ) 269 | 270 | self.gen_block1 = nn.Sequential( 271 | nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), 272 | nn.BatchNorm2d(256), 273 | nn.ReLU(True) 274 | ) 275 | 276 | self.gen_block2 = nn.Sequential( 277 | nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), 278 | nn.BatchNorm2d(128), 279 | nn.ReLU(True) 280 | ) 281 | 282 | self.gen_block3 = nn.Sequential( 283 | nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), 284 | nn.BatchNorm2d(64), 285 | nn.ReLU(True) 286 | ) 287 | self.gen_block4 = nn.Sequential( 288 | nn.ReflectionPad2d(3), 289 | nn.Conv2d(64, 1, kernel_size=7, padding=0), 290 | nn.Sigmoid() 291 | ) 292 | 293 | def forward(self, face, depth): 294 | face_f1 = self.face_block1(face) 295 | face_f2 = self.face_block2(face_f1) 296 | face_f3 = self.face_block3(face_f2) 297 | face_f4 = self.face_block4(face_f3) 298 | depth_f1 = self.depth_block1(depth) 299 | depth_f2 = self.depth_block2(depth_f1) 300 | depth_f3 = self.depth_block3(depth_f2) 301 | depth_f4 = self.depth_block4(depth_f3) 302 | mixed_f4 = self.down1(th.cat([face_f4, depth_f4], dim=1)) 303 | mixed_f3 = self.down2(th.cat([face_f3, depth_f3], dim=1)) 304 | mixed_f2 = self.down3(th.cat([face_f2, depth_f2], dim=1)) 305 | mixed_f1 = self.down4(th.cat([face_f1, depth_f1], dim=1)) 306 | gen_f3 = self.gen_block1(mixed_f4) + mixed_f3 307 | gen_f2 = self.gen_block2(gen_f3) + mixed_f2 308 | gen_f1 = self.gen_block3(gen_f2) + mixed_f1 309 | 
gen_depth = self.gen_block4(gen_f1) 310 | head_pose_f1 = self.head_pose1(mixed_f4) 311 | head_pose = self.head_pose2(head_pose_f1) 312 | return head_pose.view(head_pose.size(0), -1), head_pose_f1.view(head_pose_f1.size(0), -1), gen_depth 313 | -------------------------------------------------------------------------------- /code/.idea/workspace.xml: --------------------------------------------------------------------------------