├── __init__.py ├── assets ├── test └── logo.png ├── Zone ├── zone.png ├── class_large.label └── evaluation_class.py ├── reader ├── __pycache__ │ ├── reader.cpython-36.pyc │ ├── reader.cpython-38.pyc │ ├── reader_adap.cpython-36.pyc │ ├── reader_diap.cpython-36.pyc │ ├── reader_gc.cpython-36.pyc │ ├── reader_iv.cpython-38.pyc │ ├── reader_mpii.cpython-36.pyc │ ├── reader_pred.cpython-36.pyc │ └── reader_prediction.cpython-36.pyc └── reader.py ├── config ├── test │ └── config_iv_us.yaml └── train │ └── config_iv.yaml ├── gtools.py ├── evaluation.py ├── ctools.py ├── GazePTR.py ├── DATASET.md ├── GazeDPTR.py ├── tester ├── total.py └── leave.py ├── trainer ├── total.py └── leave.py ├── resnet.py ├── README.md ├── GazeDPTR_V2.py ├── LICENSE └── IVModule.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/test: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Zone/zone.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/Zone/zone.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/assets/logo.png -------------------------------------------------------------------------------- /reader/__pycache__/reader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/reader/__pycache__/reader.cpython-36.pyc -------------------------------------------------------------------------------- /reader/__pycache__/reader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/reader/__pycache__/reader.cpython-38.pyc -------------------------------------------------------------------------------- /reader/__pycache__/reader_adap.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/reader/__pycache__/reader_adap.cpython-36.pyc -------------------------------------------------------------------------------- /reader/__pycache__/reader_diap.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/reader/__pycache__/reader_diap.cpython-36.pyc -------------------------------------------------------------------------------- /reader/__pycache__/reader_gc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/reader/__pycache__/reader_gc.cpython-36.pyc -------------------------------------------------------------------------------- /reader/__pycache__/reader_iv.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/reader/__pycache__/reader_iv.cpython-38.pyc -------------------------------------------------------------------------------- /reader/__pycache__/reader_mpii.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/reader/__pycache__/reader_mpii.cpython-36.pyc -------------------------------------------------------------------------------- /reader/__pycache__/reader_pred.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/reader/__pycache__/reader_pred.cpython-36.pyc -------------------------------------------------------------------------------- /reader/__pycache__/reader_prediction.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/reader/__pycache__/reader_prediction.cpython-36.pyc -------------------------------------------------------------------------------- /config/test/config_iv_us.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | load: 3 | begin_step: 10 4 | end_step: 80 5 | steps: 10 6 | 7 | data: 8 | origin: 9 | image: "/home/$YourPath$/Origin" 10 | label: "/home/$YourPath$/Origin/label_class" 11 | header: True 12 | name: ivorigin 13 | isFolder: True 14 | norm: 15 | image: "/home/$YourPath$/Norm" 16 | label: "/home/$YourPath$/Norm/label_class" 17 | header: True 18 | name: ivnorm 19 | isFolder: True 20 | 21 | savename: "evaluation" 22 | device: 0 23 | reader: reader 24 | -------------------------------------------------------------------------------- /gtools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def gazeto3d(gaze): 4 | assert gaze.size == 2, "The size of gaze must be 2" 5 | gaze_gt = np.zeros([3]) 6 | gaze_gt[0] = -np.cos(gaze[1]) * np.sin(gaze[0]) 7 | gaze_gt[1] = -np.sin(gaze[1]) 8 | gaze_gt[2] = -np.cos(gaze[1]) * np.cos(gaze[0]) 9 | return gaze_gt 10 | 11 | def angular(gaze, label): 12 | assert gaze.size == 3, "The size of gaze must be 3" 13 | assert label.size == 3, "The size of label must be 3" 14 | 15 | total = np.sum(gaze * label) 16 | return np.arccos(min(total/(np.linalg.norm(gaze)* np.linalg.norm(label)), 0.9999999))*180/np.pi 17 | 18 | 19 | -------------------------------------------------------------------------------- /config/train/config_iv.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | train: 3 | 4 | params: 5 | batch_size: 64 6 | epoch: 80 7 | lr: 0.001 8 | decay: 0.5 9 | decay_step: 60 10 | warmup: 5 11 | 12 | save: 13 | metapath: "/home/$YourSavePath$/exp/GazeDPTR" 14 | folder: iv 15 | model_name: trans6 16 | step: 10 17 | 18 | data: 19 | origin: 20 | image: "/home/$YourPath$/Origin" 21 | label: "/home/$YourPath$/Origin/label_class" 22 | header: True 23 | name: ivorigin 24 | isFolder: True 25 | norm: 26 | image: "/home/$YourPath$/Norm" 27 | label: "/home/$YourPath$/Norm/label_class" 28 | header: True 29 | name: ivnorm 30 | isFolder: True 31 | 32 | 33 | pretrain: 34 | enable: False 35 | path: None 36 | device: 0 37 | 38 | device: 0 39 | 40 | reader: reader 41 | 42 | # dropout = 0 43 | # dim_feed = 512 44 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | folder = sys.argv[1] 5 | begin = 10 6 | step = 10 7 | end = 80 8 | 9 | # pre_add = 'log_zone' 10 | pre_add = 'log_dir' 11 | 12 | 13 | def 
parse(line):
14 |     line = line.strip().split(" ")
15 |     number = int(line[3][:-1])
16 |     avg = float(line[-1])
17 |     return number, avg * number, avg
18 | 
19 | def printResult(it, n_num, n_error, n_avg):
20 |     print(f'Iter {it}: Total: {n_num} Error: {n_error/n_num:.2f}\tSub_Error: ', end='')
21 |     for avg in n_avg:
22 |         print(f'{avg} ', end='')
23 |     print('')
24 | 
25 | 
26 | sub_folders = os.listdir(folder)
27 | 
28 | for i in range(begin, end+1, step):
29 |     try:
30 |         name = f"{i}.{pre_add}"
31 |         n_num = 0
32 |         n_error = 0
33 |         n_avg = []
34 |         for sub_path in sub_folders:
35 |             filename = os.path.join(folder, sub_path, name)
36 |             with open(filename) as infile:
37 |                 lines = infile.readlines()
38 |                 number, error, avg = parse(lines[-1])
39 |             n_num += number
40 |             n_error += error
41 |             n_avg.append(avg)
42 | 
43 |         printResult(i, n_num, n_error, n_avg)
44 |     except:
45 |         # Skip steps whose log files are missing or incomplete.
46 |         pass
47 | 
48 | 
49 | 
50 | 
51 | 
--------------------------------------------------------------------------------
/ctools.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import sys
 3 | import time
 4 | import os
 5 | import json
 6 | from easydict import EasyDict as edict
 7 | 
 8 | class TimeCounter:
 9 |     # A timer that estimates the remaining running time.
10 |     # Call step() once per iteration.
11 | 
12 |     # `total` is the total number of steps.
13 |     def __init__(self, total):
14 |         self.total = total
15 |         self.cur = 0
16 |         self.begin = time.time()
17 | 
18 |     def step(self):
19 |         end = time.time()
20 |         self.cur += 1
21 |         used = (end - self.begin)/self.cur
22 |         rest = self.total - self.cur
23 | 
24 |         return max(rest * used, 0)
25 | 
26 | 
27 | def readfolder(data, specific=None, reverse=False):
28 | 
29 |     """
30 |     Traverse the folder 'data.label' and read data from all files in the folder.
31 | 
32 |     `specific` is a list of file indices to read.
33 | 
34 |     When `reverse` is True, read the files whose index is NOT in `specific`.
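
    e.g., readfolder(data, specific=[0], reverse=True) keeps every label file
    under data.label except the first one; trainer/leave.py uses this call to
    build the training split for each fold.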
35 | """ 36 | 37 | 38 | folders = os.listdir(data.label) 39 | folders.sort() 40 | 41 | folder = folders 42 | if specific is not None: 43 | if reverse: 44 | num = np.arange(len(folders)) 45 | specific = list(filter(lambda x: x not in specific, num)) 46 | 47 | folder = [folders[i] for i in specific] 48 | 49 | data.label = [os.path.join(data.label, j) for j in folder] 50 | 51 | return data, folders 52 | 53 | 54 | def DictDumps(content): 55 | return json.dumps(content, ensure_ascii=False, indent=4) 56 | 57 | 58 | def GetLR(optimizer): 59 | LR = optimizer.state_dict()['param_groups'][0]['lr'] 60 | return LR 61 | 62 | -------------------------------------------------------------------------------- /Zone/class_large.label: -------------------------------------------------------------------------------- 1 | 0 37 0 2 | 1 57 0 3 | 2 77 0 4 | 3 97 0 5 | 4 62 9 6 | 5 63 9 7 | 6 64 9 8 | 7 65 9 9 | 8 72 9 10 | 9 73 9 11 | 10 80 0 12 | 11 81 0 13 | 12 84 9 14 | 13 85 9 15 | 14 90 0 16 | 15 91 0 17 | 16 92 9 18 | 17 93 9 19 | 18 94 9 20 | 19 95 9 21 | 20 98 0 22 | 21 88 0 23 | 22 78 0 24 | 23 68 0 25 | 24 48 0 26 | 25 38 0 27 | 26 89 0 28 | 27 69 0 29 | 28 59 0 30 | 29 39 0 31 | 30 d2 4 32 | 31 33 8 33 | 32 34 8 34 | 33 35 8 35 | 34 36 8 36 | 35 42 8 37 | 36 43 8 38 | 37 44 8 39 | 38 45 8 40 | 39 51 0 41 | 40 52 0 42 | 41 53 0 43 | 42 56 0 44 | 43 07 0 45 | 44 27 0 46 | 45 47 0 47 | 46 67 0 48 | 47 17 0 49 | 48 01 0 50 | 49 02 0 51 | 50 04 0 52 | 51 05 0 53 | 52 06 0 54 | 53 10 0 55 | 54 11 0 56 | 55 12 8 57 | 56 13 8 58 | 57 14 8 59 | 58 15 8 60 | 59 16 8 61 | 60 21 0 62 | 61 23 8 63 | 62 24 8 64 | 63 26 8 65 | 64 30 0 66 | 65 32 8 67 | 66 d3 4 68 | 67 d4 4 69 | 68 d0 4 70 | 69 c1 3 71 | 70 c0 3 72 | 71 b0 2 73 | 72 b1 2 74 | 73 b2 2 75 | 74 a1 1 76 | 75 a2 1 77 | 76 a4 1 78 | 77 a0 1 79 | 78 00 0 80 | 79 03 0 81 | 80 20 0 82 | 81 22 8 83 | 82 25 8 84 | 83 31 0 85 | 84 41 0 86 | 85 46 8 87 | 86 54 0 88 | 87 55 0 89 | 88 87 0 90 | 89 66 9 91 | 90 74 9 92 | 91 75 9 93 | 92 76 9 94 | 93 82 9 95 | 94 83 9 96 | 95 86 9 97 | 96 96 9 98 | 97 58 0 99 | 98 79 0 100 | 99 d1 4 101 | 100 f6 6 102 | 101 f7 6 103 | 102 f8 6 104 | 103 c2 3 105 | 104 c3 3 106 | 105 a3 1 107 | 106 e0 5 108 | 107 e1 5 109 | 108 e2 5 110 | 109 f0 6 111 | 110 f1 6 112 | 111 g4 7 113 | 112 g1 7 114 | 113 g2 7 115 | 114 g3 7 116 | 115 99 0 117 | 116 49 0 118 | 117 f2 6 119 | 118 71 0 120 | 119 f9 6 121 | 120 f5 6 122 | 121 f4 6 123 | -------------------------------------------------------------------------------- /GazePTR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import math 5 | import copy 6 | from IVModule import Backbone, Transformer, PoseTransformer, TripleDifferentialProj, PositionalEncoder 7 | 8 | 9 | def ep0(x): 10 | return x.unsqueeze(0) 11 | 12 | class Model(nn.Module): 13 | 14 | def __init__(self): 15 | 16 | super(Model, self).__init__() 17 | 18 | # used for origin. 19 | transIn = 128 20 | convDims = [64, 128, 256, 512] 21 | 22 | # norm only used for gaze estimation. 
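        # Note (inferred from the forward() comments below): Backbone comes from
        # IVModule.py, and Backbone(1, transIn, convDims) returns one fused feature
        # of shape [1, Batch, transIn] plus one intermediate feature per conv stage.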
23 |         self.bnorm = Backbone(1, transIn, convDims)
24 | 
25 |         # MLP for gaze estimation
26 |         self.MLP_n_dir = nn.Linear(transIn, 2)
27 | 
28 |         module_list = []
29 |         for i in range(len(convDims)):
30 |             module_list.append(nn.Linear(transIn, 2))
31 |         self.MLPList_n = nn.ModuleList(module_list)
32 | 
33 |         # Loss function
34 |         self.loss_op_re = nn.L1Loss()
35 | 
36 | 
37 |     def forward(self, x_in, train=True):
38 | 
39 |         # feature: [outFeatureNum, Batch, transIn]; multi-level features: list [x1, x2, ...]
40 |         feature_n, feature_list_n = self.bnorm(x_in['norm_face'])
41 | 
42 |         # Get feature for different task
43 |         # [5, 128] [1, 5, 128]
44 |         feature_n_dir = feature_n.squeeze()
45 | 
46 |         # estimate gaze from the fused feature
47 |         gaze = self.MLP_n_dir(feature_n_dir)
48 |         # zone = self.MLP_o_zone(feature_o_zone)
49 | 
50 |         # for loss calculation
51 |         loss_gaze_n = []
52 |         if train:
53 |             for i, feature in enumerate(feature_list_n):
54 |                 loss_gaze_n.append(self.MLPList_n[i](feature))
55 | 
56 |         return gaze, None, None, loss_gaze_n
57 | 
58 |     def loss(self, x_in, label):
59 | 
60 |         gaze, _, _, loss_gaze_n = self.forward(x_in)
61 | 
62 |         loss1 = 2 * self.loss_op_re(gaze, label.normGaze)
63 | 
64 |         loss2 = 0
65 |         # for zone in zones:
66 |         #     loss2 += (0.2/3) * self.loss_op_cls(zone, label.zone.view(-1))
67 | 
68 |         loss3 = 0
69 |         # for pred in loss_gaze_o:
70 |         #     loss3 += self.loss_op_re(pred, label.originGaze)
71 | 
72 |         loss4 = 0
73 |         for pred in loss_gaze_n:
74 |             loss4 += self.loss_op_re(pred, label.normGaze)
75 |         loss = loss1 + loss2 + loss3 + loss4
76 | 
77 |         return loss, [loss1, loss2, loss3, loss4]
78 | 
79 | 
80 | if __name__ == '__main__':
81 |     # The input keys must match what forward() expects (reader.py uses the same names).
82 |     x_in = {'origin_face': torch.zeros([5, 3, 224, 224]).cuda(),
83 |             'norm_face': torch.zeros([5, 3, 224, 224]).cuda(),
84 |             'pos': torch.zeros(5, 2, 6).cuda()
85 |            }
86 | 
87 |     model = Model()
88 |     model = model.to('cuda')
89 |     print(model)
90 |     a = model(x_in)
91 |     print(a)
--------------------------------------------------------------------------------
/Zone/evaluation_class.py:
--------------------------------------------------------------------------------
 1 | import sys
 2 | import os
 3 | import numpy as np
 4 | 
 5 | folder = sys.argv[1]
 6 | begin = 20
 7 | step = 20
 8 | end = 100
 9 | 
10 | pre_add = 'log_zone'
11 | 
12 | class classAdapt:
13 |     def __init__(self):
14 |         with open(os.path.join(os.path.dirname(__file__), 'class_large.label')) as infile:
15 |             lines = infile.readlines()
16 |         self.class_dict = {}
17 |         for line in lines:
18 |             line = line.strip().split(' ')
19 |             number = line[0]
20 | 
21 |             classname = line[2]
22 |             self.class_dict[number] = classname
23 | 
24 |     def get(self, name):
25 |         return self.class_dict.get(name, '-1')
26 | 
27 | 
28 | adapt = classAdapt()
29 | 
30 | 
31 | def printMat(matrix):
32 |     size = matrix.shape
33 | 
34 |     print('pr\\gt\t', end='')
35 |     for i in range(size[1]):
36 |         print(f'{i}', end='\t')
37 |     print('')
38 | 
39 |     for i in range(size[0]):
40 |         print(f'{i}', end='\t')
41 |         for j in range(size[1]):
42 |             print(f'{int(matrix[i, j])}', end='\t')
43 |         print('')
44 |     return 0
45 | 
46 | 
47 | def getAnalysis(results):
48 |     # results: a list of [pred, gt] pairs, e.g.
[[pred, gt], [pred, gt]] 49 | num = 10 50 | all_class = range(0, num) 51 | 52 | matrix = np.zeros((num, num)) 53 | 54 | total = 0 55 | for result in results: 56 | matrix[int(result[1]), int(result[0])] += 1 57 | total += 1 58 | 59 | printMat(matrix) 60 | 61 | for i in range(num): 62 | sum_i = np.sum(matrix[i, :]) 63 | acc = matrix[i, i] / sum_i 64 | print(f'Class {i}: {matrix[i, i]}/{sum_i} = {acc:.3f}') 65 | 66 | return matrix 67 | 68 | 69 | def parse(line): 70 | line = line.strip().split(" ") 71 | pred = line[1] 72 | gt = line[2] 73 | 74 | new_pred = adapt.get(pred) 75 | new_gt = adapt.get(gt) 76 | 77 | return line[0], new_pred, new_gt, int(new_pred == new_gt) 78 | 79 | sub_folders = os.listdir(folder) 80 | 81 | for i in range(begin, end+1, step): 82 | 83 | try: 84 | name = f"{i}.{pre_add}" 85 | 86 | total = 0 87 | count = 0 88 | w_total = 0 89 | w_count = 0 90 | n_result = [] 91 | for sub_path in sub_folders: 92 | 93 | 94 | filename = os.path.join(folder, sub_path, name) 95 | 96 | with open(filename) as infile: 97 | lines = infile.readlines() 98 | lines.pop(0) 99 | 100 | for line in lines: 101 | if len(line.strip().split(' ')) != 3: 102 | continue 103 | _, pred, gt, result = parse(line) 104 | if gt != str(0): 105 | total += result 106 | count += 1 107 | w_total += result 108 | w_count += 1 109 | n_result.append([pred, gt]) 110 | print(f'***************************Result---{i}*****************************') 111 | getAnalysis(n_result) 112 | print(f'Avg: {total}/{count} = {total/count:.3f} ') 113 | print(f'With 0: {w_total}/{w_count} = {w_total/w_count:.3f} ') 114 | except: 115 | pass 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /DATASET.md: -------------------------------------------------------------------------------- 1 | ![](./assets/logo.png) 2 | 3 | ## In-Vehicle Gaze Estimation DataSet 4 | We provide an in-vehicle gaze estimation dataset IVGaze. 5 | 6 | - IVGaze contains 44,705 images of 125 subjects. We divide the dataset into three subsets based on subjects. The image numbers of the three subsets are 15,165, 14,674, and 14,866. 7 | Three-fold cross-validation should be performed on the dataset. 8 | 9 | - The dataset was collected between 9 am and 7 pm in outdoor environments, covering a wide range of lighting conditions. 10 | 11 | - We consider two face accessories during the collection: glasses and masks. We also required a few subjects to wear sunglasses to facilitate future research. 12 | 13 | ## Dataset Structure 14 | ``` 15 | IVGazeDataset 16 | ├── class.label 17 | ├── Norm 18 | │   ├── 20220811 19 | │   │   ├── subject0000_out_eye_mask 20 | │   │   │ ├── 1.jpg 21 | │   │   │ ├── ... 22 | │   │   │ ├── ... 23 | │   │   │ ├── 81.jpg 24 | │   │   ├── ... 25 | │   │   ├── ... 26 | │   │   ├── subject0000_out_eye_nomask 27 | │   ├── 20221009 28 | │   ├── 20221010 29 | │   ├── 20221011 30 | │   ├── 20221012 31 | │   ├── 20221013 32 | │   ├── 20221014 33 | │   ├── 20221017 34 | │   ├── 20221018 35 | │   ├── 20221019 36 | │   ├── 20221020 37 | │   └── label_class 38 | │   ├── train1.txt 39 | │   ├── train2.txt 40 | │   └── train3.txt 41 | └── Origin 42 | ├── 20220811 43 | ├── ... 44 | ├── ... 45 | ├── 20221020 46 | └── label_class 47 | ├── train1.txt 48 | ├── train2.txt 49 | └── train3.txt 50 | ``` 51 | 52 | - `class.label`: This section offers gaze zone classification details. The first row denotes the class number according to `label_class`. 
The second row represents the original numbers assigned during the data collection phase. The third row indicates coarse region numbers.
53 | - `Norm`: This section contains normalized images and their corresponding labels.
54 | - `Norm/label_class`: Here, you'll find label files for three-fold cross-validation.
55 | - `Origin`: This section provides the original images directly cropped from the captured frames, along with their label files.
56 | 
57 | ## Usage
58 | 
59 | To retrieve data from the IVGaze Dataset, begin by reading a label file, such as `Norm/label_class/train1.txt`. Each line of the label file contains space-separated fields; process the file one line at a time.
60 | 
61 | ```
62 | import os
63 | import cv2
64 | import numpy as np
65 | 
66 | root = 'IVGazeDataset/Norm'
67 | with open(os.path.join(root, 'label_class/train1.txt')) as infile:
68 |     lines = infile.readlines()
69 | 
70 | for line in lines:
71 |     line = line.strip().split(' ')
72 | 
73 |     # Read the image
74 |     image_name = line[0]
75 |     image = cv2.imread(os.path.join(root, image_name))
76 | 
77 |     # GT for gaze and zone
78 |     gaze = np.fromstring(line[1], sep=',')
79 |     zone = int(line[3])
80 | ```
81 | 
82 | ## Download
83 | To obtain access to the dataset, please send an email to `y.cheng.2@bham.ac.uk`.
84 | You will receive a Google Drive link within three days for downloading the dataset.
85 | 
86 | Here is the email template for requesting access to the IVGaze Dataset. Please do not change the email subject.
87 | 
88 | ```
89 | Subject: Request for Access to IVGaze Dataset
90 | 
91 | Dear Yihua,
92 | 
93 | I hope this email finds you well.
94 | 
95 | I am writing to request access to the IVGaze Dataset. My name is [Your Name], and I am a [student/researcher] from [Your Affiliation].
96 | 
97 | I assure you that I will only utilize the dataset for academic and research purposes and will not use it for commercial activities.
98 | 
99 | Thank you for considering my request. I look forward to receiving access to the dataset.
100 | 
101 | Best regards,
102 | [Your Name]
103 | ```
104 | 
105 | 
--------------------------------------------------------------------------------
/GazeDPTR.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import numpy as np
 4 | import math
 5 | import copy
 6 | from IVModule import Backbone, Transformer, PoseTransformer, TripleDifferentialProj, PositionalEncoder
 7 | 
 8 | 
 9 | def ep0(x):
10 |     return x.unsqueeze(0)
11 | 
12 | class Model(nn.Module):
13 | 
14 |     def __init__(self):
15 | 
16 |         super(Model, self).__init__()
17 | 
18 |         # Backbone hyper-parameters, shared by both streams.
19 |         transIn = 128
20 |         convDims = [64, 128, 256, 512]
21 | 
22 |         # The origin branch extracts one feature, used for gaze direction.
23 |         self.borigin = Backbone(1, transIn, convDims)
24 | 
25 |         # norm only used for gaze estimation.
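        # Together, borigin and bnorm form the dual stream: borigin encodes the
        # original image, bnorm the normalized image, and the PoseTransformer in
        # forward() fuses their direction features using the camera poses passed
        # in x_in['pos'].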
26 | self.bnorm = Backbone(1, transIn, convDims) 27 | 28 | self.ptrans = PoseTransformer(transIn, 3) 29 | 30 | # MLP for gaze estimation 31 | self.MLP_o_dir = nn.Linear(transIn, 2) 32 | self.MLP_n_dir = nn.Linear(transIn, 2) 33 | 34 | 35 | module_list = [] 36 | for i in range(len(convDims)): 37 | module_list.append(nn.Linear(transIn, 2)) 38 | self.MLPList_o = nn.ModuleList(module_list) 39 | 40 | module_list = [] 41 | for i in range(len(convDims)): 42 | module_list.append(nn.Linear(transIn, 2)) 43 | self.MLPList_n = nn.ModuleList(module_list) 44 | 45 | self.MLP_o_dir2 = nn.Linear(transIn, 2) 46 | self.MLP_n_dir2 = nn.Linear(transIn, 2) 47 | 48 | # Loss function 49 | self.loss_op_re = nn.L1Loss() 50 | self.loss_op_cls = nn.CrossEntropyLoss() 51 | 52 | 53 | def forward(self, x_in, train=True): 54 | 55 | # feature [outFeatureNum, Batch, transIn], MLfeatgure: list[x1, x2...] 56 | 57 | # Extract feature from both two images 58 | feature_o, feature_list_o= self.borigin(x_in['origin_face']) 59 | 60 | feature_n, feature_list_n = self.bnorm(x_in['norm_face']) 61 | 62 | # Get feature for different task 63 | # [5, 128] [1. 5, 128] 64 | feature_o_dir = feature_o.squeeze() 65 | feature_n_dir = feature_n.squeeze() 66 | 67 | # Fuse two direction feature and input it into transformer 68 | features_dir = torch.cat([ep0(feature_o_dir), ep0(feature_n_dir)], 0) 69 | features = self.ptrans(features_dir, x_in['pos']) 70 | 71 | # Get fused feature 72 | # feature_o_dir2 = features[0, :] 73 | feature_n_dir2 = features[1, :] 74 | 75 | # estimate gaze from fused feature 76 | gaze = self.MLP_n_dir2(feature_n_dir2) 77 | # zone = self.MLP_o_zone(feature_o_zone) 78 | 79 | # for loss caculation 80 | loss_gaze_o = [] 81 | loss_gaze_n = [] 82 | if train: 83 | loss_gaze_n.append(self.MLP_n_dir(feature_n_dir)) 84 | loss_gaze_o.append(self.MLP_o_dir(feature_o_dir)) 85 | 86 | for i, feature in enumerate(feature_list_o): 87 | loss_gaze_o.append(self.MLPList_o[i](feature)) 88 | 89 | for i, feature in enumerate(feature_list_n): 90 | loss_gaze_n.append(self.MLPList_n[i](feature)) 91 | 92 | return gaze, None, loss_gaze_o, loss_gaze_n 93 | 94 | def loss(self, x_in, label): 95 | 96 | gaze, _, loss_gaze_o, loss_gaze_n = self.forward(x_in) 97 | 98 | loss1 = 2 * self.loss_op_re(gaze, label.normGaze) 99 | 100 | loss2 = 0 101 | # for zone in zones: 102 | # loss2 += (0.2/3) * self.loss_op_cls(zone, label.zone.view(-1)) 103 | 104 | loss3 = 0 105 | for pred in loss_gaze_o: 106 | loss3 += self.loss_op_re(pred, label.originGaze) 107 | 108 | loss4 = 0 109 | for pred in loss_gaze_n: 110 | loss4 += self.loss_op_re(pred, label.normGaze) 111 | loss = loss1 + loss2 + loss3 + loss4 112 | 113 | return loss, [loss1, loss2, loss3, loss4] 114 | 115 | 116 | if __name__ == '__main__': 117 | x_in = {'origin': torch.zeros([5, 3, 224, 224]).cuda(), 118 | 'norm': torch.zeros([5, 3, 224, 224]).cuda(), 119 | 'pos': torch.zeros(5, 2, 6).cuda() 120 | } 121 | 122 | model = Model() 123 | model = model.to('cuda') 124 | print(model) 125 | a = model(x_in) 126 | print(a) 127 | -------------------------------------------------------------------------------- /tester/total.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | base_dir = os.getcwd() 3 | sys.path.insert(0, base_dir) 4 | import model 5 | import importlib 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import cv2, yaml, copy 11 | from easydict import EasyDict as edict 12 | import ctools, gtools 13 | 
import argparse
14 | 
15 | def main(train, test):
16 | 
17 |     # =================================> Setup <=========================
18 |     reader = importlib.import_module("reader." + test.reader)
19 |     torch.cuda.set_device(test.device)
20 | 
21 |     data = test.data
22 |     load = test.load
23 | 
24 | 
25 |     # ===============================> Read Data <=========================
26 |     if data.isFolder:
27 |         data, _ = ctools.readfolder(data)
28 | 
29 |     print(f"==> Test: {data.label} <==")
30 |     dataset = reader.loader(data, 32, num_workers=4, shuffle=False)
31 | 
32 |     modelpath = os.path.join(train.save.metapath,
33 |                              train.save.folder, f"checkpoint/")
34 | 
35 |     logpath = os.path.join(train.save.metapath,
36 |                            train.save.folder, f"{test.savename}")
37 | 
38 | 
39 |     if not os.path.exists(logpath):
40 |         os.makedirs(logpath)
41 | 
42 |     # =============================> Test <=============================
43 | 
44 |     begin = load.begin_step; end = load.end_step; step = load.steps
45 | 
46 |     for saveiter in range(begin, end+step, step):
47 | 
48 |         print(f"Test {saveiter}")
49 | 
50 |         net = model.Model()
51 | 
52 |         statedict = torch.load(
53 |             os.path.join(modelpath,
54 |                          f"Iter_{saveiter}_{train.save.model_name}.pt"),
55 |             map_location={f"cuda:{train.device}": f"cuda:{test.device}"}
56 |         )
57 | 
58 |         net.cuda(); net.load_state_dict(statedict); net.eval()
59 | 
60 |         length = len(dataset); accs = 0; count = 0
61 | 
62 |         logname = f"{saveiter}.log"
63 | 
64 |         outfile = open(os.path.join(logpath, logname), 'w')
65 |         outfile.write("name results gts\n")
66 | 
67 | 
68 |         with torch.no_grad():
69 |             for j, (data, label) in enumerate(dataset):
70 | 
71 |                 for key in data:
72 |                     if key != 'name': data[key] = data[key].cuda()
73 | 
74 |                 names = data["name"]
75 |                 gts = label.normGaze.cuda()
76 | 
77 |                 gazes, _, _, _ = net(data, train=False)
78 | 
79 |                 for k, gaze in enumerate(gazes):
80 | 
81 |                     gaze = gaze.cpu().detach().numpy()
82 |                     gt = gts.cpu().numpy()[k]
83 | 
84 |                     count += 1
85 |                     accs += gtools.angular(
86 |                         gtools.gazeto3d(gaze),
87 |                         gtools.gazeto3d(gt)
88 |                     )
89 | 
90 |                     name = [names[k]]
91 |                     gaze = [str(u) for u in gaze]
92 |                     gt = [str(u) for u in gt]
93 |                     log = name + [",".join(gaze)] + [",".join(gt)]
94 |                     outfile.write(" ".join(log) + "\n")
95 | 
96 |         loger = f"[{saveiter}] Total Num: {count}, avg: {accs/count}"
97 |         outfile.write(loger)
98 |         print(loger)
99 |         outfile.close()
100 | 
101 | if __name__ == "__main__":
102 | 
103 |     parser = argparse.ArgumentParser(description='Pytorch Basic Model Training')
104 | 
105 |     parser.add_argument('-s', '--source', type=str,
106 |                         help='path to the training config')
107 | 
108 |     parser.add_argument('-t', '--target', type=str,
109 |                         help='path to the test config')
110 | 
111 |     args = parser.parse_args()
112 | 
113 |     # Read the model from the train config and the test data from the test config.
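    # Example invocation (illustrative paths; point -s/-t at your own configs):
    #   python tester/total.py -s config/train/config_iv.yaml -t config/test/config_iv_us.yaml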
114 | train_conf = edict(yaml.load(open(args.source), Loader=yaml.FullLoader)) 115 | 116 | test_conf = edict(yaml.load(open(args.target), Loader=yaml.FullLoader)) 117 | 118 | print("=======================>(Begin) Config of training<======================") 119 | print(ctools.DictDumps(train_conf)) 120 | print("=======================>(End) Config of training<======================") 121 | print("") 122 | print("=======================>(Begin) Config for test<======================") 123 | print(ctools.DictDumps(test_conf)) 124 | print("=======================>(End) Config for test<======================") 125 | 126 | main(train_conf.train, test_conf.test) 127 | 128 | 129 | -------------------------------------------------------------------------------- /trainer/total.py: -------------------------------------------------------------------------------- 1 | import sys,os 2 | base_dir = os.getcwd() 3 | sys.path.insert(0, base_dir) 4 | import model 5 | import importlib 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import copy 11 | import yaml 12 | import cv2 13 | import ctools 14 | from easydict import EasyDict as edict 15 | import torch.backends.cudnn as cudnn 16 | from warmup_scheduler import GradualWarmupScheduler 17 | import argparse 18 | 19 | def main(config): 20 | 21 | # ===================>> Setup <<================================= 22 | 23 | dataloader = importlib.import_module("reader." + config.reader) 24 | 25 | torch.cuda.set_device(config.device) 26 | cudnn.benchmark = True 27 | 28 | data = config.data 29 | save = config.save 30 | params = config.params 31 | 32 | print("===> Read data <===") 33 | 34 | if data.isFolder: 35 | data, _ = ctools.readfolder(data) 36 | 37 | dataset = dataloader.loader( 38 | data, 39 | params.batch_size, 40 | shuffle=True, 41 | num_workers=8 42 | ) 43 | 44 | 45 | print("===> Model building <===") 46 | net = model.Model() 47 | net.train(); net.cuda() 48 | 49 | 50 | # Pretrain 51 | pretrain = config.pretrain 52 | 53 | if pretrain.enable and pretrain.device: 54 | net.load_state_dict( 55 | torch.load( 56 | pretrain.path, 57 | map_location={f"cuda:{pretrain.device}": f"cuda:{config.device}"} 58 | ) 59 | ) 60 | elif pretrain.enable and not pretrain.device: 61 | net.load_state_dict( 62 | torch.load(pretrain.path) 63 | ) 64 | 65 | 66 | print("===> optimizer building <===") 67 | optimizer = optim.Adam( 68 | net.parameters(), 69 | lr=params.lr, 70 | betas=(0.9,0.999) 71 | ) 72 | 73 | scheduler = optim.lr_scheduler.StepLR( 74 | optimizer, 75 | step_size=params.decay_step, 76 | gamma=params.decay 77 | ) 78 | 79 | if params.warmup: 80 | scheduler = GradualWarmupScheduler( 81 | optimizer, 82 | multiplier=1, 83 | total_epoch=params.warmup, 84 | after_scheduler=scheduler 85 | ) 86 | 87 | savepath = os.path.join(save.metapath, save.folder, f"checkpoint") 88 | 89 | if not os.path.exists(savepath): 90 | os.makedirs(savepath) 91 | 92 | # =====================================>> Training << ==================================== 93 | print("===> Training <===") 94 | 95 | length = len(dataset); total = length * params.epoch 96 | timer = ctools.TimeCounter(total) 97 | 98 | 99 | optimizer.zero_grad() 100 | optimizer.step() 101 | scheduler.step() 102 | 103 | 104 | with open(os.path.join(savepath, "train_log"), 'w') as outfile: 105 | outfile.write(ctools.DictDumps(config) + '\n') 106 | 107 | for epoch in range(1, params.epoch+1): 108 | for i, (data, anno) in enumerate(dataset): 109 | 110 | # -------------- forward ------------- 
for key in data:
112 |                     if key != 'name': data[key] = data[key].cuda()
113 | 
114 |                 for key in anno.keys(): anno[key] = anno[key].cuda()
115 |                 loss, _ = net.loss(data, anno)
116 | 
117 |                 # -------------- Backward ------------
118 |                 optimizer.zero_grad()
119 |                 loss.backward()
120 |                 optimizer.step()
121 |                 rest = timer.step()/3600
122 | 
123 | 
124 |                 if i % 20 == 0:
125 |                     log = f"[{epoch}/{params.epoch}]: " + \
126 |                           f"[{i}/{length}] " +\
127 |                           f"loss:{loss} " +\
128 |                           f"lr:{ctools.GetLR(optimizer)} " +\
129 |                           f"rest time:{rest:.2f}h"
130 | 
131 |                     print(log); outfile.write(log + "\n")
132 |                     sys.stdout.flush(); outfile.flush()
133 | 
134 |             scheduler.step()
135 | 
136 |             if epoch % save.step == 0:
137 |                 torch.save(
138 |                     net.state_dict(),
139 |                     os.path.join(
140 |                         savepath,
141 |                         f"Iter_{epoch}_{save.model_name}.pt"
142 |                     )
143 |                 )
144 | 
145 | 
146 | if __name__ == "__main__":
147 | 
148 |     parser = argparse.ArgumentParser(description='Pytorch Basic Model Training')
149 | 
150 |     parser.add_argument('-s', '--train', type=str,
151 |                         help='The source config for training.')
152 | 
153 |     args = parser.parse_args()
154 | 
155 |     config = edict(yaml.load(open(args.train), Loader=yaml.FullLoader))
156 | 
157 |     print("=====================>> (Begin) Training params << =======================")
158 |     print(ctools.DictDumps(config))
159 |     print("=====================>> (End) Training params << =======================")
160 | 
161 |     main(config.train)
162 | 
--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
 1 | import torch.nn as nn
 2 | import torch.utils.model_zoo as model_zoo
 3 | 
 4 | 
 5 | 
 6 | model_urls = {
 7 |     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
 8 | }
 9 | 
10 | 
11 | def conv3x3(in_planes, out_planes, stride=1):
12 |     """3x3 convolution with padding"""
13 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
14 |                      padding=1, bias=False)
15 | 
16 | 
17 | class BasicBlock(nn.Module):
18 |     expansion = 1
19 | 
20 |     def __init__(self, inplanes, planes, stride=1, downsample=None):
21 |         super(BasicBlock, self).__init__()
22 |         self.conv1 = conv3x3(inplanes, planes, stride)
23 |         self.bn1 = nn.BatchNorm2d(planes)
24 |         self.relu = nn.ReLU(inplace=True)
25 |         self.conv2 = conv3x3(planes, planes)
26 |         self.bn2 = nn.BatchNorm2d(planes)
27 |         self.downsample = downsample
28 |         self.stride = stride
29 | 
30 |     def forward(self, x):
31 |         residual = x
32 | 
33 |         out = self.conv1(x)
34 |         out = self.bn1(out)
35 |         out = self.relu(out)
36 | 
37 |         out = self.conv2(out)
38 |         out = self.bn2(out)
39 | 
40 |         if self.downsample is not None:
41 |             residual = self.downsample(x)
42 | 
43 |         out += residual
44 |         out = self.relu(out)
45 | 
46 |         return out
47 | 
48 | 
49 | class Bottleneck(nn.Module):
50 |     expansion = 4
51 | 
52 |     def __init__(self, inplanes, planes, stride=1, downsample=None):
53 |         super(Bottleneck, self).__init__()
54 |         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
55 |         self.bn1 = nn.BatchNorm2d(planes)
56 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
57 |                                padding=1, bias=False)
58 |         self.bn2 = nn.BatchNorm2d(planes)
59 |         self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
60 |         self.bn3 = nn.BatchNorm2d(planes * self.expansion)
61 |         self.relu = nn.ReLU(inplace=True)
62 |         self.downsample = downsample
63 |         self.stride = stride
64 | 
65 |     def forward(self, x):
66 |         residual = x
67 | 
68 |         out = self.conv1(x)
69 |         out = self.bn1(out)
70 | 
out = self.relu(out) 71 | 72 | out = self.conv2(out) 73 | out = self.bn2(out) 74 | out = self.relu(out) 75 | 76 | out = self.conv3(out) 77 | out = self.bn3(out) 78 | 79 | if self.downsample is not None: 80 | residual = self.downsample(x) 81 | 82 | out += residual 83 | out = self.relu(out) 84 | 85 | return out 86 | 87 | 88 | class ResNet(nn.Module): 89 | 90 | def __init__(self, block, layers, input_dim=[64, 128, 256, 512]): 91 | super(ResNet, self).__init__() 92 | 93 | self.inplanes = input_dim[0] 94 | 95 | #assert len(maps) == 5, f'The length of input_dim should be 5' 96 | 97 | self.conv1 = nn.Conv2d(3, input_dim[0], kernel_size=7, stride=2, padding=3, 98 | bias=False) 99 | self.bn1 = nn.BatchNorm2d(input_dim[0]) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 102 | 103 | self.layer1 = self._make_layer(block, input_dim[0], layers[0]) 104 | self.layer2 = self._make_layer(block, input_dim[1], layers[1], stride=2) 105 | self.layer3 = self._make_layer(block, input_dim[2], layers[2], stride=2) 106 | self.layer4 = self._make_layer(block, input_dim[3], layers[3], stride=2) 107 | 108 | for m in self.modules(): 109 | if isinstance(m, nn.Conv2d): 110 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 111 | elif isinstance(m, nn.BatchNorm2d): 112 | nn.init.constant_(m.weight, 1) 113 | nn.init.constant_(m.bias, 0) 114 | 115 | def _make_layer(self, block, planes, blocks, stride=1): 116 | downsample = None 117 | if stride != 1 or self.inplanes != planes * block.expansion: 118 | downsample = nn.Sequential( 119 | nn.Conv2d(self.inplanes, planes * block.expansion, 120 | kernel_size=1, stride=stride, bias=False), 121 | nn.BatchNorm2d(planes * block.expansion), 122 | ) 123 | 124 | layers = [] 125 | layers.append(block(self.inplanes, planes, stride, downsample)) 126 | self.inplanes = planes * block.expansion 127 | for i in range(1, blocks): 128 | layers.append(block(self.inplanes, planes)) 129 | 130 | return nn.Sequential(*layers) 131 | 132 | def forward(self, x): 133 | x = self.conv1(x) 134 | x = self.bn1(x) 135 | x = self.relu(x) 136 | x = self.maxpool(x) 137 | 138 | x1 = self.layer1(x) 139 | 140 | x2 = self.layer2(x1) 141 | 142 | x3 = self.layer3(x2) 143 | 144 | x4 = self.layer4(x3) 145 | 146 | return x1, x2, x3, x4 147 | 148 | 149 | def resnet18(pretrained=False, **kwargs): 150 | """Constructs a ResNet-18 model. 151 | Args: 152 | pretrained (bool): If True, returns a model pre-trained on ImageNet 153 | """ 154 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 155 | if pretrained: 156 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']),strict=False) 157 | return model 158 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![](./assets/logo.png) 2 |

# What Do You See in Vehicle? Comprehensive Vision Solution for In-Vehicle Gaze Estimation

3 | 
4 | Yihua Cheng, Yaning Zhu, Zongji Wang, Hongquan Hao, Yongwei Liu, Shiqing Cheng, Xi Wang, Hyung Jin Chang, CVPR 2024
5 | 
6 | 
7 | 
8 | 
9 | 
10 | ## Description
11 | This repository provides the official code of the paper titled *What Do You See in Vehicle? Comprehensive Vision Solution for In-Vehicle Gaze Estimation*, accepted at CVPR 2024.
12 | Our contributions include:
13 | - We provide a dataset **IVGaze** collected in vehicles, containing 44k images of 125 subjects.
14 | - We propose a gaze pyramid transformer (GazePTR) that leverages transformer-based multi-level feature integration.
15 | - We introduce the dual-stream gaze pyramid transformer (GazeDPTR). Employing perspective transformation, we rotate virtual cameras to normalize images, and we utilize the camera pose to merge normalized and original images for accurate gaze estimation.
16 | 
17 | Please visit our project page for details. The dataset is available on this page.
18 | 
19 | [![Gaze](https://res.cloudinary.com/marcomontalbano/image/upload/v1720447174/video_to_markdown/images/youtube--050M4CK5EwI-c05b58ac6eb4c4700831b2b3070cd403.jpg)](https://www.youtube.com/watch?v=050M4CK5EwI&t=3s "Gaze")
20 | 
21 | ## Requirement
22 | 
23 | 1. Install PyTorch and torchvision. This code is written in `Python 3.8` and utilizes `PyTorch 1.13.1` with `CUDA 11.6` on an Nvidia GeForce RTX 3090. While this environment is recommended, it is not mandatory; feel free to run the code in your preferred environment.
24 | 
25 | ```
26 | pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
27 | ```
28 | 
29 | 2. Install other packages.
30 | ```
31 | pip install opencv-python PyYAML easydict warmup_scheduler
32 | ```
33 | 
34 | If you have any issues due to missing packages, please report them. I will update the requirements. Thank you for your cooperation.
35 | 
36 | ## Training
37 | 
38 | **Step 1: Choose the model file.**
39 | 
40 | We provide three models: `GazePTR.py`, `GazeDPTR.py`, and `GazeDPTR_V2.py`. The pre-trained weights for all three models are hosted at the same link; load the weights that match your chosen model.
41 | 
42 | | | Name | Description | Input | Output | Accuracy | Pretrained Weights |
43 | |:----|:----|:----|:----:|:----:|:----:|:----:|
44 | |1|GazePTR| This method leverages multi-level features.|Normalized Images|Gaze Directions|7.04°| Link |
45 | |2|GazeDPTR| This method integrates features from two images.|Normalized Images, Original Images|Gaze Directions|6.71°| Link |
46 | |3|GazeDPTR_V2| This method contains a differential projection for gaze zone prediction. |Normalized Images, Original Images|Gaze Directions, Gaze Zone|6.71°, 81.8%| Link |
47 | 
48 | Please choose one model and copy it to `model.py`, *e.g.*,
49 | ```
50 | cp GazeDPTR.py model.py
51 | ```
52 | 
53 | **Step 2: Modify the config file.**
54 | 
55 | 
56 | Please modify `config/train/config_iv.yaml` according to your environment settings.
57 | 
58 | - The `save` attribute specifies the save path: the model will be stored at `os.path.join({save.metapath}, {save.folder})`, and each saved model will be named `Iter_{epoch}_{save.model_name}.pt`.
59 | - The `data` attribute indicates the dataset path. Update the `image` and `label` entries to match your dataset location.
60 | 
61 | **Step 3: Train the models.**
62 | 
63 | Run the following command to initiate training.
The argument `3` indicates that it will automatically perform three-fold cross-validation:
64 | 
65 | ```
66 | python trainer/leave.py config/train/config_iv.yaml 3
67 | ```
68 | 
69 | Once the training is complete, you will find the weights saved at `os.path.join({save.metapath}, {save.folder})`.
70 | Within the `checkpoint` directory, you will find three folders named `train1.txt`, `train2.txt`, and `train3.txt`, corresponding to the three folds of the cross-validation. Each folder contains the respective trained model.
71 | 
72 | ## Testing
73 | Run the following command for testing.
74 | ```
75 | python tester/leave.py config/train/config_iv.yaml config/test/config_iv.yaml 3
76 | ```
77 | Similarly:
78 | - Update the `image` and `label` entries in `config/test/config_iv.yaml` based on your dataset location.
79 | - The `savename` attribute specifies the folder for saving prediction results, which will be stored at `os.path.join({save.metapath}, {save.folder})` as defined in `config/train/config_iv.yaml`.
80 | - The script `tester/leave.py` also reports gaze zone prediction results; remove that part of the code if you do not need gaze zone prediction.
81 | 
82 | ## Evaluation
83 | 
84 | We provide the `evaluation.py` script to assess the accuracy of gaze direction estimation. Run the following command:
85 | ```
86 | python evaluation.py {PATH}
87 | ```
88 | Replace `{PATH}` with the path to the `{savename}` folder configured in your settings.
89 | 
90 | Please find the visualization code in the issues.
91 | 
92 | ## Contact
93 | Please send an email to `y.cheng.2@bham.ac.uk` if you have any questions.
--------------------------------------------------------------------------------
/tester/leave.py:
--------------------------------------------------------------------------------
 1 | import os, sys
 2 | base_dir = os.getcwd()
 3 | sys.path.insert(0, base_dir)
 4 | import model
 5 | import importlib
 6 | import numpy as np
 7 | import torch
 8 | import torch.nn as nn
 9 | import torch.optim as optim
10 | import cv2, yaml, copy
11 | from easydict import EasyDict as edict
12 | import ctools, gtools
13 | import argparse
14 | 
15 | def main(train, test):
16 | 
17 |     # ===============================> Setup <============================
18 |     reader = importlib.import_module("reader."
+ test.reader) 19 | 20 | data = test.data 21 | load = test.load 22 | torch.cuda.set_device(test.device) 23 | 24 | 25 | # ==============================> Read Data <======================== 26 | data.origin, folder = ctools.readfolder(data.origin, [test.person]) 27 | data.norm, folder = ctools.readfolder(data.norm, [test.person]) 28 | 29 | testname = folder[test.person] 30 | 31 | dataset = reader.loader(data, 500, num_workers=4, shuffle=True) 32 | 33 | modelpath = os.path.join(train.save.metapath, 34 | train.save.folder, f'checkpoint/{testname}') 35 | logpath = os.path.join(train.save.metapath, 36 | train.save.folder, f'{test.savename}/{testname}') 37 | 38 | if not os.path.exists(logpath): 39 | os.makedirs(logpath) 40 | 41 | # =============================> Test <============================== 42 | 43 | begin = load.begin_step; end = load.end_step; step = load.steps 44 | 45 | for saveiter in range(begin, end+step, step): 46 | print(f"Test {saveiter}") 47 | 48 | # ----------------------Load Model------------------------------ 49 | net = model.Model() 50 | 51 | 52 | statedict = torch.load( 53 | os.path.join(modelpath, f"Iter_{saveiter}_{train.save.model_name}.pt"), 54 | map_location={f"cuda:{train.device}":f"cuda:{test.device}"} 55 | ) 56 | 57 | 58 | net.cuda(); net.load_state_dict(statedict); net.eval() 59 | 60 | length = len(dataset); accs = 0; count = 0 61 | 62 | # -----------------------Open log file-------------------------------- 63 | logname = f"{saveiter}.log" 64 | 65 | outfile1 = open(os.path.join(logpath, logname + '_zone'), 'w') 66 | outfile1.write("name results gts\n") 67 | 68 | outfile2 = open(os.path.join(logpath, logname + '_dir'), 'w') 69 | outfile2.write("name results gts\n") 70 | 71 | 72 | 73 | # -------------------------Testing--------------------------------- 74 | with torch.no_grad(): 75 | n_true = {} 76 | n_total = {} 77 | 78 | for j, (data, label) in enumerate(dataset): 79 | 80 | for key in data: 81 | if key != 'name': data[key] = data[key].cuda() 82 | 83 | names = data["name"] 84 | 85 | gt_zones = label.zone.view(-1) 86 | gt_dirs = label.originGaze 87 | gt_dirs = label.normGaze 88 | gazes, zones, _, _ = net(data, train=False) 89 | 90 | 91 | for k, cls in enumerate(zones): 92 | 93 | gt = str(int(gt_zones[k])) 94 | name = [names[k]] 95 | if gt == str(int(cls)): 96 | n_true[gt] = 1 + n_true.get(gt, 0) 97 | 98 | n_total[gt] = 1 + n_total.get(gt, 0) 99 | 100 | log = name + [f"{cls}"] + [f"{gt}"] 101 | outfile1.write(" ".join(log) + "\n") 102 | 103 | for k, gaze in enumerate(gazes): 104 | 105 | gaze = gaze.cpu().detach().numpy() 106 | gt = gt_dirs.numpy()[k] 107 | 108 | count += 1 109 | accs += gtools.angular( 110 | gtools.gazeto3d(gaze), 111 | gtools.gazeto3d(gt) 112 | ) 113 | 114 | name = [names[k]] 115 | gaze = [str(u) for u in gaze] 116 | gt = [str(u) for u in gt] 117 | log = name + [",".join(gaze)] + [",".join(gt)] 118 | outfile2.write(" ".join(log) + "\n") 119 | 120 | keys = sorted(list(n_true.keys()), key = lambda x:int(x)) 121 | true_num = 0 122 | total_num = 0 123 | for key in keys: 124 | true_num += n_true[key] 125 | total_num += n_total[key] 126 | loger = f'Class {key} {n_true[key]} {n_total[key]} {n_true[key]/n_total[key]:.3f}\n' 127 | outfile1.write(loger) 128 | loger = f"[{saveiter}] Total Num: {total_num}, True: {true_num}, AP:{true_num/total_num:.3f}" 129 | outfile1.write(loger) 130 | print(loger) 131 | 132 | loger = f"[{saveiter}] Total Num: {count}, avg: {accs/count}" 133 | outfile2.write(loger) 134 | print(loger) 135 | outfile1.close() 136 | 
outfile2.close() 137 | 138 | if __name__ == "__main__": 139 | 140 | 141 | # Read model from train config and Test data in test config. 142 | train_conf = edict(yaml.load(open(sys.argv[1]), Loader=yaml.FullLoader)) 143 | 144 | test_conf = edict(yaml.load(open(sys.argv[2]), Loader=yaml.FullLoader)) 145 | 146 | for i in range(int(sys.argv[3])): 147 | test_conf_cur = copy.deepcopy(test_conf) 148 | 149 | test_conf_cur.person = i 150 | 151 | print("=======================>(Begin) Config of training<======================") 152 | 153 | print(ctools.DictDumps(train_conf)) 154 | 155 | print("=======================>(End) Config of training<======================") 156 | 157 | print("") 158 | 159 | print("=======================>(Begin) Config for test<======================") 160 | 161 | print(ctools.DictDumps(test_conf_cur)) 162 | 163 | print("=======================>(End) Config for test<======================") 164 | 165 | main(train_conf.train, test_conf_cur) 166 | 167 | -------------------------------------------------------------------------------- /trainer/leave.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | base_dir = os.getcwd() 3 | sys.path.insert(0, base_dir) 4 | import model 5 | import importlib 6 | import numpy as np 7 | import torch 8 | import cv2 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | import copy 12 | import yaml 13 | import ctools 14 | from easydict import EasyDict as edict 15 | import torch.backends.cudnn as cudnn 16 | from warmup_scheduler import GradualWarmupScheduler 17 | import random 18 | import argparse 19 | 20 | def setup_seed(seed=0): 21 | torch.manual_seed(seed) 22 | np.random.seed(seed) 23 | random.seed(seed) 24 | torch.cuda.manual_seed(seed) 25 | torch.cuda.manual_seed_all(seed) 26 | 27 | 28 | def main(config): 29 | # ===============================> Setup <================================ 30 | 31 | setup_seed(123) 32 | 33 | dataloader = importlib.import_module("reader." 
+ config.reader) 34 | torch.cuda.set_device(config.device) 35 | cudnn.benchmark = True 36 | 37 | data = config.data 38 | save = config.save 39 | params = config.params 40 | 41 | 42 | print("===> Read data <===") 43 | data.origin, folder = ctools.readfolder( 44 | data.origin, 45 | [config.person], 46 | reverse=True 47 | ) 48 | 49 | data.norm, folder = ctools.readfolder( 50 | data.norm, 51 | [config.person], 52 | reverse=True 53 | ) 54 | 55 | 56 | savename = folder[config.person] 57 | 58 | dataset = dataloader.loader( 59 | data, 60 | params.batch_size, 61 | shuffle=True, 62 | num_workers=6 63 | ) 64 | 65 | print("===> Model building <===") 66 | net = model.Model(); net.train(); net.cuda() 67 | 68 | 69 | # Pretrain 70 | pretrain = config.pretrain 71 | if pretrain.enable and pretrain.device: 72 | net.load_state_dict( 73 | torch.load( 74 | pretrain.path, 75 | map_location={f"cuda:{pretrain.device}": f"cuda:{config.device}"} 76 | ) 77 | ) 78 | elif pretrain.enable and not pretrain.device: 79 | net.load_state_dict( 80 | torch.load(pretrain.path) 81 | ) 82 | 83 | 84 | print("===> optimizer building <===") 85 | optimizer = optim.Adam( 86 | net.parameters(), 87 | lr=params.lr, 88 | betas=(0.9,0.95) 89 | ) 90 | 91 | scheduler = optim.lr_scheduler.StepLR( 92 | optimizer, 93 | step_size=params.decay_step, 94 | gamma=params.decay 95 | ) 96 | 97 | if params.warmup: 98 | scheduler = GradualWarmupScheduler( 99 | optimizer, 100 | multiplier=1, 101 | total_epoch=params.warmup, 102 | after_scheduler=scheduler 103 | ) 104 | 105 | savepath = os.path.join(save.metapath, save.folder, f"checkpoint/{savename}") 106 | 107 | if not os.path.exists(savepath): 108 | os.makedirs(savepath) 109 | 110 | # =======================================> Training < ========================== 111 | print("===> Training <===") 112 | length = len(dataset); total = length * params.epoch 113 | timer = ctools.TimeCounter(total) 114 | 115 | 116 | optimizer.zero_grad() 117 | optimizer.step() 118 | scheduler.step() 119 | 120 | with open(os.path.join(savepath, "train_log"), 'w') as outfile: 121 | outfile.write(ctools.DictDumps(config) + '\n') 122 | 123 | for epoch in range(1, params.epoch+1): 124 | for i, (data, anno) in enumerate(dataset): 125 | 126 | # ------------------forward-------------------- 127 | for key in data.keys(): 128 | if key != 'name': data[key] = data[key].cuda() 129 | 130 | for key in anno.keys(): anno[key] = anno[key].cuda() 131 | 132 | loss, losslist = net.loss(data, anno) 133 | 134 | # -----------------backward-------------------- 135 | optimizer.zero_grad() 136 | 137 | loss.backward() 138 | 139 | optimizer.step() 140 | 141 | rest = timer.step()/3600 142 | 143 | # -----------------loger---------------------- 144 | if i % 20 == 0: 145 | log = f"[{epoch}/{params.epoch}]: " +\ 146 | f"[{i}/{length}] " +\ 147 | f"loss:{loss:.3f} " +\ 148 | f"loss_re:{losslist[0]:.3f} " +\ 149 | f"loss_cls:{losslist[1]:.3f} " +\ 150 | f"loss_o:{losslist[2]:.3f} " +\ 151 | f"loss_n:{losslist[3]:.3f} " +\ 152 | f"lr:{ctools.GetLR(optimizer)} "+\ 153 | f"rest time:{rest:.2f}h" 154 | 155 | print(log); outfile.write(log + "\n") 156 | sys.stdout.flush(); outfile.flush() 157 | 158 | scheduler.step() 159 | 160 | if epoch % save.step == 0: 161 | torch.save( 162 | net.state_dict(), 163 | os.path.join(savepath, f"Iter_{epoch}_{save.model_name}.pt") 164 | ) 165 | 166 | if __name__ == "__main__": 167 | 168 | config = edict(yaml.load(open(sys.argv[1]), Loader=yaml.FullLoader)) 169 | 170 | config = config.train 171 | person = int(sys.argv[2]) 172 | 173 | 
for i in range(person): 174 | config_i = copy.deepcopy(config) 175 | config_i.person = i 176 | 177 | print("=====================>> (Begin) Training params << =======================") 178 | 179 | print(ctools.DictDumps(config_i)) 180 | 181 | print("=====================>> (End) Traning params << =======================") 182 | 183 | main(config_i) 184 | 185 | -------------------------------------------------------------------------------- /reader/reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import random 5 | import numpy as np 6 | from easydict import EasyDict as edict 7 | from torch.utils.data import Dataset, DataLoader 8 | from torchvision import transforms 9 | from PIL import Image 10 | import copy 11 | 12 | def gazeto2d(gaze): 13 | yaw = -np.arctan2(-gaze[0], -gaze[2]) 14 | pitch = -np.arcsin(-gaze[1]) 15 | return np.array([yaw, pitch]) 16 | 17 | def Decode_MPII(line): 18 | anno = edict() 19 | anno.face, anno.lefteye, anno.righteye = line[0], line[1], line[2] 20 | anno.name = line[3] 21 | 22 | anno.gaze3d, anno.head3d = line[5], line[6] 23 | anno.gaze2d, anno.head2d = line[7], line[8] 24 | return anno 25 | 26 | def Decode_IVOrigin(line): 27 | anno = edict() 28 | anno.face = line[0] 29 | anno.name = line[0] 30 | anno.gaze = line[1] 31 | anno.placeholder = line[2] 32 | anno.zone = line[3] 33 | # anno.target = line[4] 34 | anno.origin = line[5] 35 | return anno 36 | 37 | def Decode_IVNorm(line): 38 | anno = edict() 39 | anno.face = line[0] 40 | anno.name = line[0] 41 | anno.gaze = line[1] 42 | anno.head = line[2] 43 | anno.zone = line[3] 44 | anno.origin = line[4] 45 | anno.norm = line[6] 46 | return anno 47 | 48 | 49 | def Decode_Dict(): 50 | mapping = edict() 51 | mapping.ivorigin = Decode_IVOrigin 52 | mapping.ivnorm = Decode_IVNorm 53 | return mapping 54 | 55 | def long_substr(str1, str2): 56 | substr = '' 57 | for i in range(len(str1)): 58 | for j in range(len(str1)-i+1): 59 | if j > len(substr) and (str1[i:i+j] in str2): 60 | substr = str1[i:i+j] 61 | return len(substr) 62 | 63 | def Get_Decode(name): 64 | mapping = Decode_Dict() 65 | keys = list(mapping.keys()) 66 | name = name.lower() 67 | score = [long_substr(name, i) for i in keys] 68 | key = keys[score.index(max(score))] 69 | return mapping[key] 70 | 71 | 72 | class commonloader(Dataset): 73 | 74 | def __init__(self, dataset): 75 | 76 | # Read source data 77 | self.source = edict() 78 | self.source.origin = edict() 79 | self.source.norm = edict() 80 | 81 | # Read origin data 82 | origin = dataset.origin 83 | 84 | self.source.origin.root = origin.image 85 | self.source.origin.line = self.__readlines(origin.label, origin.header) 86 | self.source.origin.decode = Get_Decode(origin.name) 87 | 88 | # Read norm data 89 | norm = dataset.norm 90 | 91 | # self.source.norm = copy.deepcopy(dataset.norm) 92 | self.source.norm.root = norm.image 93 | self.source.norm.line = self.__readlines(norm.label, norm.header) 94 | self.source.norm.decode = Get_Decode(norm.name) 95 | 96 | # build transforms 97 | self.transforms = transforms.Compose([ 98 | transforms.ToTensor() 99 | ]) 100 | 101 | 102 | def __readlines(self, filename, header=True): 103 | 104 | data = [] 105 | if isinstance(filename, list): 106 | for i in filename: 107 | with open(i) as f: line = f.readlines() 108 | if header: line.pop(0) 109 | data.extend(line) 110 | 111 | else: 112 | with open(filename) as f: data = f.readlines() 113 | if header: data.pop(0) 114 | return data 115 | 116 | 117 
| def __len__(self): 118 | assert len(self.source.origin.line) == len(self.source.norm.line), 'Two files are not aligned.' 119 | return len(self.source.origin.line) 120 | 121 | 122 | def __getitem__(self, idx): 123 | 124 | # ------------------Read origin----------------------- 125 | line = self.source.origin.line[idx] 126 | line = line.strip().split(" ") 127 | 128 | # decode the info 129 | anno = self.source.origin.decode(line) 130 | 131 | # read image 132 | origin_img = cv2.imread(os.path.join(self.source.origin.root, anno.face)) 133 | origin_img = self.transforms(origin_img) 134 | 135 | origin_cam_mat = np.diag((1, 1, 1)) 136 | origin_cam_mat = torch.from_numpy(origin_cam_mat).type(torch.FloatTensor) 137 | 138 | origin_z_axis = torch.Tensor([0, 0, 1]).type(torch.FloatTensor) 139 | 140 | zone = int(anno.zone) 141 | zone = torch.Tensor([zone]).type(torch.long) 142 | 143 | # read label 144 | origin_label = gazeto2d(np.array(anno.gaze.split(",")).astype("float")) 145 | origin_label = torch.from_numpy(origin_label).type(torch.FloatTensor) 146 | 147 | gaze_origin = np.array(anno.origin.split(",")).astype("float") 148 | gaze_origin = torch.from_numpy(gaze_origin).type(torch.FloatTensor) 149 | 150 | name = anno.name 151 | 152 | # --------------------read norm------------------------ 153 | line = self.source.norm.line[idx] 154 | line = line.strip().split(" ") 155 | 156 | # decode the info 157 | anno = self.source.norm.decode(line) 158 | 159 | # read image 160 | norm_img = cv2.imread(os.path.join(self.source.norm.root, anno.face)) 161 | norm_img = self.transforms(norm_img) 162 | 163 | # camera position. 164 | norm_mat = np.fromstring(anno.norm, sep=',') 165 | norm_mat = cv2.Rodrigues(norm_mat)[0] 166 | 167 | # Camera rotation. Label = R * prediction 168 | inv_mat = np.linalg.inv(norm_mat) 169 | z_axis = inv_mat[:, 2].flatten() 170 | 171 | norm_cam_mat = torch.from_numpy(inv_mat).type(torch.FloatTensor) 172 | z_axis = torch.from_numpy(z_axis).type(torch.FloatTensor) 173 | 174 | # read label 175 | norm_label = gazeto2d(np.array(anno.gaze.split(",")).astype("float")) 176 | norm_label = torch.from_numpy(norm_label).type(torch.FloatTensor) 177 | 178 | assert name == anno.name, 'Data is not aligned' 179 | 180 | pos = torch.concat([torch.unsqueeze(origin_z_axis, 0), torch.unsqueeze(z_axis, 0)] ,0) 181 | 182 | # --------------------------------------------------- 183 | data = edict() 184 | data.origin_face = origin_img 185 | data.origin_cam = origin_cam_mat 186 | data.norm_face = norm_img 187 | data.norm_cam = norm_cam_mat 188 | data.pos = pos 189 | data.name = anno.name 190 | data.gaze_origin = gaze_origin 191 | 192 | label = edict() 193 | label.originGaze = origin_label 194 | label.normGaze = norm_label 195 | label.zone = zone 196 | 197 | return data, label 198 | 199 | 200 | def loader(source, batch_size, shuffle=False, num_workers=0): 201 | 202 | dataset = commonloader(source) 203 | 204 | print(f"-- [Read Data]: Total num: {len(dataset)}") 205 | 206 | print(f"-- [Read Data]: Source: {source.norm.label}") 207 | 208 | load = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) 209 | return load 210 | 211 | if __name__ == "__main__": 212 | 213 | path = './p00.label' 214 | # d = loader(path) 215 | # print(len(d)) 216 | # (data, label) = d.__getitem__(0) 217 | 218 | -------------------------------------------------------------------------------- /GazeDPTR_V2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 

--------------------------------------------------------------------------------
/GazeDPTR_V2.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import math
import copy
from IVModule import Backbone, Transformer, PoseTransformer, TripleDifferentialProj, PositionalEncoder


def ep0(x):
    return x.unsqueeze(0)

class Model(nn.Module):

    def __init__(self):

        super(Model, self).__init__()

        # used for origin.
        transIn = 128
        convDims = [64, 128, 256, 512]

        # origin produces two features: one for gaze zone, one for gaze direction.
        self.borigin = Backbone(2, transIn, convDims)

        # norm is only used for gaze estimation.
        self.bnorm = Backbone(1, transIn, convDims)

        self.ptrans = PoseTransformer(transIn, 3)

        # Proj
        self.proj = TripleDifferentialProj()

        # encoded dim: 3 * (30 * 2 + 1) = 183
        self.PoseEncoder = PositionalEncoder(30, True)

        self.MLP_Pos_x = nn.Sequential(
            nn.Linear(183, 128),
            nn.ReLU(inplace=True),
        )

        self.MLP_Pos_y = nn.Sequential(
            nn.Linear(183, 128),
            nn.ReLU(inplace=True),
        )

        self.MLP_Pos_z = nn.Sequential(
            nn.Linear(183, 128),
            nn.ReLU(inplace=True),
        )

        self.MLP_Pos = [self.MLP_Pos_x, self.MLP_Pos_y, self.MLP_Pos_z]


        # Transformer to combine gaze point and feature
        self.tripleTrans = Transformer(
            input_dim = 128,
            length = 3,
            layer_num=2,
            nhead=4
        )

        self.zoneTrans = Transformer(
            input_dim = 128,
            length = 2,
            layer_num=2,
            nhead=4
        )

        # MLP for gaze estimation
        class_num = 122

        self.MLP_o_dir = nn.Linear(transIn, 2)
        self.MLP_n_dir = nn.Linear(transIn, 2)


        module_list = []
        for i in range(len(convDims)):
            module_list.append(nn.Linear(transIn, 2))
        self.MLPList_o = nn.ModuleList(module_list)

        module_list = []
        for i in range(len(convDims)):
            module_list.append(nn.Linear(transIn, 2))
        self.MLPList_n = nn.ModuleList(module_list)

        self.MLP_o_dir2 = nn.Linear(transIn, 2)
        self.MLP_n_dir2 = nn.Linear(transIn, 2)

        self.MLP_o_zone = nn.Linear(transIn, class_num)
        self.MLP_o_zone2 = nn.Linear(transIn, class_num)
        self.MLP_o_zone3 = nn.Linear(transIn, class_num)

        # Loss function
        self.loss_op_re = nn.L1Loss()
        self.loss_op_cls = nn.CrossEntropyLoss()
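    # NOTE: illustrative sketch, not used at runtime. With PositionalEncoder(30,
    # True), each coordinate of a 3-D projected gaze point is expanded with
    # sin/cos at 30 frequencies plus the identity, i.e. 3 * (30 * 2 + 1) = 183
    # values, which is exactly the input width of MLP_Pos_x/y/z above:
    #
    #   enc = PositionalEncoder(30, True)
    #   enc.encode(torch.zeros(8, 3)).shape   # -> torch.Size([8, 183])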

    def forward(self, x_in, train=True):

        # feature: [outFeatureNum, Batch, transIn]; feature_list: [x1, x2, ...]

        # Extract features from both images
        feature_o, feature_list_o = self.borigin(x_in['origin_face'])

        feature_n, feature_list_n = self.bnorm(x_in['norm_face'])

        # Get the features for the different tasks
        # feature_o: [2, Batch, 128], feature_n: [1, Batch, 128]
        feature_o_zone = feature_o[0,:]
        feature_o_dir = feature_o[1,:]

        feature_n_dir = feature_n.squeeze()

        zone1 = self.MLP_o_zone(feature_o_zone)

        # Fuse the two direction features and feed them into the transformer
        features_dir = torch.cat([ep0(feature_o_dir), ep0(feature_n_dir)], 0)
        features = self.ptrans(features_dir, x_in['pos'])

        # Get fused feature
        # feature_o_dir2 = features[0, :]
        feature_n_dir2 = features[1, :]

        # estimate gaze from the fused feature
        gaze = self.MLP_n_dir2(feature_n_dir2)
        # zone = self.MLP_o_zone(feature_o_zone)

        # Proj
        gaze_proj = self.proj(gaze.detach(), x_in['gaze_origin'], x_in['norm_cam'])
        gaze_proj_list = []
        # gaze_proj_list.append(ep0(feature_o_zone))

        for i, gaze2d in enumerate(gaze_proj):
            gaze_proj_list.append(ep0(self.MLP_Pos[i](self.PoseEncoder.encode(gaze2d))))

        triple_feature = torch.cat(gaze_proj_list, 0)
        triple_feature = self.tripleTrans(triple_feature)
        zone2 = self.MLP_o_zone2(triple_feature)

        feature_o_zone2 = torch.cat([ep0(feature_o_zone), ep0(triple_feature)], 0)
        feature_o_zone2 = self.zoneTrans(feature_o_zone2)
        zone3 = self.MLP_o_zone3(feature_o_zone2)

        zone = [zone1, zone2, zone3]

        # for loss calculation
        loss_gaze_o = []
        loss_gaze_n = []
        if train:
            loss_gaze_n.append(self.MLP_n_dir(feature_n_dir))
            loss_gaze_o.append(self.MLP_o_dir(feature_o_dir))

            for i, feature in enumerate(feature_list_o):
                loss_gaze_o.append(self.MLPList_o[i](feature))

            for i, feature in enumerate(feature_list_n):
                loss_gaze_n.append(self.MLPList_n[i](feature))

        else:
            zone = zone2.max(1)[1]

        return gaze, zone, loss_gaze_o, loss_gaze_n

    def loss(self, x_in, label):

        gaze, zones, loss_gaze_o, loss_gaze_n = self.forward(x_in)

        loss1 = 2 * self.loss_op_re(gaze, label.normGaze)

        loss2 = 0
        for zone in zones:
            loss2 += (0.2/3) * self.loss_op_cls(zone, label.zone.view(-1))

        loss3 = 0
        for pred in loss_gaze_o:
            loss3 += self.loss_op_re(pred, label.originGaze)

        loss4 = 0
        for pred in loss_gaze_n:
            loss4 += self.loss_op_re(pred, label.normGaze)
        loss = loss1 + loss2 + loss3 + loss4

        return loss, [loss1, loss2, loss3, loss4]


if __name__ == '__main__':
    # Smoke test on dummy inputs (batch of 5, CUDA required). The keys mirror
    # the dict produced by reader/reader.py: 'origin_face', 'norm_face',
    # 'pos' ([B, 2, 3]), 'gaze_origin' ([B, 3]) and 'norm_cam' ([B, 3, 3]).
    x_in = {'origin_face': torch.zeros([5, 3, 224, 224]).cuda(),
            'norm_face': torch.zeros([5, 3, 224, 224]).cuda(),
            'pos': torch.zeros(5, 2, 3).cuda(),
            'gaze_origin': torch.randn(5, 3).cuda(),
            'norm_cam': torch.eye(3).unsqueeze(0).repeat(5, 1, 1).cuda()
            }

    model = Model()
    model = model.to('cuda')
    print(model)
    a = model(x_in)
    print(a)

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.
      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

--------------------------------------------------------------------------------
/IVModule.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import math
import copy
from resnet import resnet18

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

def ep0(x):
    return x.unsqueeze(0)

class TripleDifferentialProj(nn.Module):

    def __init__(self):
        super(TripleDifferentialProj, self).__init__()

        # One 3D point in the screen plane and the plane's normal vector.
        # (Kept for reference; the projection below works directly on the
        # coordinate planes and does not read these locals.)
        point = torch.Tensor([0, 0, 0])
        normal = torch.Tensor([0, 0, 1])

    def forward(self, gaze, origin, norm_mat=None):
        # input gaze is [yaw, pitch]

        gaze3d = self.gazeto3d(gaze)
        if norm_mat is not None:
            gaze3d = torch.einsum('acd,ad->ac', norm_mat, gaze3d)

        gazex = self.gazeto2dPlus(gaze3d, origin, 0)
        gazey = self.gazeto2dPlus(gaze3d, origin, 1)
        gazez = self.gazeto2dPlus(gaze3d, origin, 2)
        gaze = [gazex, gazey, gazez]
        return gaze

    def gazeto3d(self, point):
        # point is [yaw, pitch]
        x = -torch.cos(point[:, 1]) * torch.sin(point[:, 0])
        y = -torch.sin(point[:, 1])
        z = -torch.cos(point[:, 1]) * torch.cos(point[:, 0])
        gaze = torch.cat([x.unsqueeze(1), y.unsqueeze(1), z.unsqueeze(1)], 1)
        return gaze

    def gazeto2dPlus(self, gaze, origin, plane: int):

        assert plane < 3, 'plane should be 0(x), 1(y) or 2(z)'
        # Intersect the ray origin + t * gaze with the plane {coord[plane] = 0}.
        length = origin[:, plane]
        g_len = gaze[:, plane]
        scale = -length / g_len
        gaze = torch.einsum('ik, i->ik', gaze, scale)
        point = origin + gaze
        return point
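
# NOTE: illustrative sketch, not used by the model. Each gazeto2dPlus call
# intersects the gaze ray with one coordinate plane: solving
# origin[p] + t * gaze[p] = 0 gives t = -origin[p] / gaze[p], hence
# point = origin + t * gaze. E.g. a ray from (0, 0, 1) along (0, 0, -1)
# hits the z = 0 plane at the world origin:
def _demo_triple_proj():
    proj = TripleDifferentialProj()
    gaze3d = torch.tensor([[0.0, 0.0, -1.0]])
    origin = torch.tensor([[0.0, 0.0, 1.0]])
    point = proj.gazeto2dPlus(gaze3d, origin, 2)
    assert torch.allclose(point, torch.tensor([[0.0, 0.0, 0.0]]))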

class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, pos):
        output = src
        for layer in self.layers:
            output = layer(output, pos)

        if self.norm is not None:
            output = self.norm(output)

        return output

class PoseTransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=512, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = nn.ReLU(inplace=True)

    def pos_embed(self, src, pos):
        # pos is already [Length, Batch, Dim]; add it directly.
        return src + pos


    def forward(self, src, pos):
        # src_mask: Optional[Tensor] = None,
        # src_key_padding_mask: Optional[Tensor] = None):
        # pos: Optional[Tensor] = None):

        q = k = self.pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src



class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=512, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = nn.ReLU(inplace=True)

    def pos_embed(self, src, pos):
        # pos is [Length, Dim]; broadcast it across the batch dimension.
        batch_pos = pos.unsqueeze(1).repeat(1, src.size(1), 1)
        return src + batch_pos


    def forward(self, src, pos):
        # src_mask: Optional[Tensor] = None,
        # src_key_padding_mask: Optional[Tensor] = None):
        # pos: Optional[Tensor] = None):

        q = k = self.pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src


class conv1x1(nn.Module):

    def __init__(self, in_planes, out_planes, stride=1):
        super(conv1x1, self).__init__()

        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                              padding=0, bias=False)

        self.bn = nn.BatchNorm2d(out_planes)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, feature):
        output = self.conv(feature)
        output = self.bn(output)
        output = self.avgpool(output)
        output = output.squeeze()  # [B, C, 1, 1] -> [B, C]

        return output



class MLEnhance(nn.Module):

    def __init__(self, input_nums, hidden_num):

        super(MLEnhance, self).__init__()

        # input_nums: [64, 128, 256, 512]
        length = len(input_nums)

        self.input_nums = input_nums
        self.hidden_num = hidden_num

        self.length = length

        layerList = []

        for i in range(length):
            layerList.append(self.__build_layer(input_nums[i], hidden_num))

        self.layerList = nn.ModuleList(layerList)


    def __build_layer(self, input_num, hidden_num):
        layer = conv1x1(input_num, hidden_num)

        return layer

    def forward(self, feature_list):

        out_feature_list = []

        out_feature_gather = []

        for i, feature in enumerate(feature_list):
            result = self.layerList[i](feature)

            # Dim [B, C] -> [1, B, C]
            out_feature_list.append(result)

            out_feature_gather.append(ep0(result))


        # [L, B, C]
        feature = torch.cat(out_feature_gather, 0)
        return feature, out_feature_list
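
# NOTE: minimal shape sketch, not part of the original file. MLEnhance maps
# the four ResNet stages to a common width so they can be stacked as a
# transformer sequence:
def _demo_mlenhance():
    mle = MLEnhance(input_nums=[64, 128, 256, 512], hidden_num=128)
    feats = [torch.zeros(2, c, 8, 8) for c in (64, 128, 256, 512)]
    stacked, per_stage = mle(feats)
    assert stacked.shape == (4, 2, 128)      # [num_stages, batch, hidden]
    assert per_stage[0].shape == (2, 128)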

class PositionalEncoder():
    # Encode a low-dimensional vector into a high-dimensional one with
    # sin/cos features at exponentially spaced frequencies.

    def __init__(self, number_freqs, include_identity=False):
        freq_bands = torch.pow(2, torch.linspace(0., number_freqs - 1, number_freqs))
        self.embed_fns = []
        self.output_dim = 0

        if include_identity:
            self.embed_fns.append(lambda x: x)
            self.output_dim += 1

        for freq in freq_bands:
            for transform_fns in [torch.sin, torch.cos]:
                self.embed_fns.append(lambda x, fns=transform_fns, freq=freq: fns(x*freq))
                self.output_dim += 1

    def encode(self, vecs):
        # inputs:  [B, N]
        # outputs: [B, N * output_dim], where output_dim = 2 * number_freqs (+1 with identity)
        return torch.cat([fn(vecs) for fn in self.embed_fns], -1)

    def getDims(self):
        return self.output_dim
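
# NOTE: illustrative sketch, not part of the original file. With 30
# frequencies plus the identity, each coordinate expands to 2 * 30 + 1 = 61
# features, so a 3-D point becomes 183-D -- the input width assumed by
# MLP_Pos_x/y/z in GazeDPTR_V2.py:
def _demo_positional_encoder():
    enc = PositionalEncoder(30, include_identity=True)
    out = enc.encode(torch.zeros(4, 3))
    assert enc.getDims() == 61
    assert out.shape == (4, 3 * enc.getDims())   # (4, 183)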

class PoseTransformer(nn.Module):

    def __init__(self, input_dim, pos_length, nhead=8, hidden_dim=512, layer_num=6,
                 pos_freq=30, pos_ident=True, pos_hidden=128, dropout=0.1):

        super(PoseTransformer, self).__init__()
        # input feature + added token

        # The input feature should be [L, Batch, Input_dim]
        encoder_layer = PoseTransformerEncoderLayer(
            input_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim,
            dropout=dropout)

        encoder_norm = nn.LayerNorm(input_dim)
        self.encoder = TransformerEncoder(encoder_layer, num_layers=layer_num, norm=encoder_norm)

        self.pos_embedding = PositionalEncoder(pos_freq, pos_ident)

        out_dim = pos_length * (pos_freq * 2 + pos_ident)

        self.pos_encode = nn.Sequential(
            nn.Linear(out_dim, pos_hidden),
            nn.LeakyReLU(0.1),
            nn.Linear(pos_hidden, pos_hidden),
            nn.LeakyReLU(0.1),
            nn.Linear(pos_hidden, pos_hidden),
            nn.LeakyReLU(0.1),
            nn.Linear(pos_hidden, input_dim)
        )

    def forward(self, feature, pos_feature):

        """
        Inputs:
            feature: [length, batch, dim1]
            pos_feature: [batch, length, dim2]

        Outputs:
            feature: [length, batch, dim1]
        """

        # pos_feature: [batch, length, dim2] -> [length, batch, input_dim]
        pos_feature = self.pos_embedding.encode(pos_feature)
        pos_feature = self.pos_encode(pos_feature)
        pos_feature = pos_feature.permute(1, 0, 2)

        # feature [Length, batch, dim]
        feature = self.encoder(feature, pos_feature)
        # feature = feature.permute(1, 0, 2)

        return feature

class Transformer(nn.Module):

    def __init__(self, input_dim, nhead=8, hidden_dim=512, layer_num=6, pred_num=1, length=4, dropout=0.1):

        super(Transformer, self).__init__()

        self.pnum = pred_num
        # input feature + added token
        # self.length = length + 1
        self.length = length

        # The input feature should be [L, Batch, Input_dim]
        encoder_layer = TransformerEncoderLayer(
            input_dim,
            nhead=nhead,
            dim_feedforward=hidden_dim,
            dropout=dropout)

        encoder_norm = nn.LayerNorm(input_dim)

        self.encoder = TransformerEncoder(encoder_layer, num_layers=layer_num, norm=encoder_norm)

        self.cls_token = nn.Parameter(torch.randn(pred_num, 1, input_dim))

        self.token_pos_embedding = nn.Embedding(pred_num, input_dim)
        self.pos_embedding = nn.Embedding(length, input_dim)


    def forward(self, feature, num=1):

        # feature: [Length, Batch, Dim]
        batch_size = feature.size(1)

        # cls_num, 1, Dim -> cls_num, Batch_size, Dim

        feature_list = []

        for i in range(num):

            cls = self.cls_token[i, :, :].repeat((1, batch_size, 1))

            feature_in = torch.cat([cls, feature], 0)

            # position embeddings; note the sequence embeddings come first and
            # the cls-token embedding is appended last, while the cls token
            # itself is prepended to the sequence. Both are learned, so the
            # assignment only needs to be consistent.
            position = torch.arange(self.length, device=feature.device)
            pos_feature = self.pos_embedding(position)

            token_position = torch.tensor([i], device=feature.device)
            token_pos_feature = self.token_pos_embedding(token_position)

            pos_feature = torch.cat([pos_feature, token_pos_feature], 0)

            # feature [Length, batch, dim]
            feature_out = self.encoder(feature_in, pos_feature)

            # [batch, dim, length]
            # feature = feature.permute(1, 2, 0)

            # take the cls position, [1, batch, dim]
            feature_out = feature_out[0:1, :, :]
            feature_list.append(feature_out)

        return torch.cat(feature_list, 0).squeeze()

class Backbone(nn.Module):

    def __init__(self, outFeatureNum=1, transIn=128, convDims=[64, 128, 256, 512]):

        super(Backbone, self).__init__()

        self.base_model = resnet18(pretrained=True, input_dim=convDims)

        self.transformer = Transformer(input_dim=transIn, nhead=8, hidden_dim=512,
                                       layer_num=6, pred_num=outFeatureNum, length=len(convDims))

        # convert multi-scale feature
        self.mle = MLEnhance(input_nums=convDims, hidden_num=transIn)

        self.feed = nn.Linear(transIn, 2)

        self.loss_op = nn.L1Loss()

        self.loss_layerList = []

        self.outFeatureNum = outFeatureNum

        # for i in range(len(convDims)):
        #     self.loss_layerList.append(nn.Linear(transIn, 2))

    def forward(self, image):

        x1, x2, x3, x4 = self.base_model(image)

        feature, feature_list = self.mle([x1, x2, x3, x4])

        # t_gaze = []
        # for t_feature in feature_list:
        #     t_gaze.append(self.loss_layerList(t_feature))

        feature = self.transformer(feature, self.outFeatureNum)

        # gaze = self.feed(feature)

        return feature, feature_list

--------------------------------------------------------------------------------
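
# NOTE: minimal end-to-end sketch, not part of the repository, of the
# cls-token readout used by Backbone in IVModule.py: pred_num learnable cls
# tokens are prepended one at a time, and the encoder output at the cls
# position is collected, giving [pred_num, batch, dim] (squeezed).
#
#   trans = Transformer(input_dim=128, pred_num=2, length=4)
#   seq = torch.zeros(4, 5, 128)              # [length, batch, dim]
#   out = trans(seq, num=2)
#   print(out.shape)                          # torch.Size([2, 5, 128])
#
# This matches Backbone(2, ...) in GazeDPTR_V2.py, whose origin branch reads
# feature_o[0] for the gaze-zone head and feature_o[1] for gaze direction.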