├── __init__.py
├── assets
│   ├── test
│   └── logo.png
├── Zone
│   ├── zone.png
│   ├── class_large.label
│   └── evaluation_class.py
├── reader
│   ├── __pycache__
│   │   ├── reader.cpython-36.pyc
│   │   ├── reader.cpython-38.pyc
│   │   ├── reader_adap.cpython-36.pyc
│   │   ├── reader_diap.cpython-36.pyc
│   │   ├── reader_gc.cpython-36.pyc
│   │   ├── reader_iv.cpython-38.pyc
│   │   ├── reader_mpii.cpython-36.pyc
│   │   ├── reader_pred.cpython-36.pyc
│   │   └── reader_prediction.cpython-36.pyc
│   └── reader.py
├── config
│   ├── test
│   │   └── config_iv_us.yaml
│   └── train
│       └── config_iv.yaml
├── gtools.py
├── evaluation.py
├── ctools.py
├── GazePTR.py
├── DATASET.md
├── GazeDPTR.py
├── tester
│   ├── total.py
│   └── leave.py
├── trainer
│   ├── total.py
│   └── leave.py
├── resnet.py
├── README.md
├── GazeDPTR_V2.py
├── LICENSE
└── IVModule.py
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/test:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/Zone/zone.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/Zone/zone.png
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/assets/logo.png
--------------------------------------------------------------------------------
/reader/__pycache__/reader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/reader/__pycache__/reader.cpython-36.pyc
--------------------------------------------------------------------------------
/reader/__pycache__/reader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/reader/__pycache__/reader.cpython-38.pyc
--------------------------------------------------------------------------------
/reader/__pycache__/reader_adap.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/reader/__pycache__/reader_adap.cpython-36.pyc
--------------------------------------------------------------------------------
/reader/__pycache__/reader_diap.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/reader/__pycache__/reader_diap.cpython-36.pyc
--------------------------------------------------------------------------------
/reader/__pycache__/reader_gc.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/reader/__pycache__/reader_gc.cpython-36.pyc
--------------------------------------------------------------------------------
/reader/__pycache__/reader_iv.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/reader/__pycache__/reader_iv.cpython-38.pyc
--------------------------------------------------------------------------------
/reader/__pycache__/reader_mpii.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/reader/__pycache__/reader_mpii.cpython-36.pyc
--------------------------------------------------------------------------------
/reader/__pycache__/reader_pred.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/reader/__pycache__/reader_pred.cpython-36.pyc
--------------------------------------------------------------------------------
/reader/__pycache__/reader_prediction.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yihuacheng/IVGaze/HEAD/reader/__pycache__/reader_prediction.cpython-36.pyc
--------------------------------------------------------------------------------
/config/test/config_iv_us.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | load:
3 | begin_step: 10
4 | end_step: 80
5 | steps: 10
6 |
7 | data:
8 | origin:
9 | image: "/home/$YourPath$/Origin"
10 | label: "/home/$YourPath$/Origin/label_class"
11 | header: True
12 | name: ivorigin
13 | isFolder: True
14 | norm:
15 | image: "/home/$YourPath$/Norm"
16 | label: "/home/$YourPath$/Norm/label_class"
17 | header: True
18 | name: ivnorm
19 | isFolder: True
20 |
21 | savename: "evaluation"
22 | device: 0
23 | reader: reader
24 |
--------------------------------------------------------------------------------
/gtools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def gazeto3d(gaze):
4 | assert gaze.size == 2, "The size of gaze must be 2"
5 | gaze_gt = np.zeros([3])
6 | gaze_gt[0] = -np.cos(gaze[1]) * np.sin(gaze[0])
7 | gaze_gt[1] = -np.sin(gaze[1])
8 | gaze_gt[2] = -np.cos(gaze[1]) * np.cos(gaze[0])
9 | return gaze_gt
10 |
11 | def angular(gaze, label):
12 | assert gaze.size == 3, "The size of gaze must be 3"
13 | assert label.size == 3, "The size of label must be 3"
14 |
15 | total = np.sum(gaze * label)
16 | return np.arccos(min(total/(np.linalg.norm(gaze)* np.linalg.norm(label)), 0.9999999))*180/np.pi
17 |
18 |
19 |
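# Usage sketch: both arguments to angular() are 3D gaze vectors from gazeto3d(), e.g.
#   error_deg = angular(gazeto3d(pred), gazeto3d(gt))
# where pred and gt are (yaw, pitch) arrays in radians; the returned error is in degrees.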
--------------------------------------------------------------------------------
/config/train/config_iv.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | train:
3 |
4 | params:
5 | batch_size: 64
6 | epoch: 80
7 | lr: 0.001
8 | decay: 0.5
9 | decay_step: 60
10 | warmup: 5
11 |
12 | save:
13 | metapath: "/home/$YourSavePath$/exp/GazeDPTR"
14 | folder: iv
15 | model_name: trans6
16 | step: 10
17 |
18 | data:
19 | origin:
20 | image: "/home/$YourPath$/Origin"
21 | label: "/home/$YourPath$/Origin/label_class"
22 | header: True
23 | name: ivorigin
24 | isFolder: True
25 | norm:
26 | image: "/home/$YourPath$/Norm"
27 | label: "/home/$YourPath$/Norm/label_class"
28 | header: True
29 | name: ivnorm
30 | isFolder: True
31 |
32 |
33 | pretrain:
34 | enable: False
35 | path: None
36 | device: 0
37 |
38 | device: 0
39 |
40 | reader: reader
41 |
42 | # dropout = 0
43 | # dim_feed = 512
44 |
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | folder = sys.argv[1]
5 | begin = 10
6 | step = 10
7 | end = 80
8 |
9 | # pre_add = 'log_zone'
10 | pre_add = 'log_dir'
11 |
12 |
13 | def parse(line):
14 | line = line.strip().split(" ")
15 | number = int((line[3][:-1]))
16 | avg = float(line[-1])
17 | return number, avg * number, avg
18 |
19 | def printResult(n_num, n_error, n_avg):
20 | print(f'Iter {i}: Total: {n_num} Error: {n_error/n_num:.2f}\tSub_Error: ', end = '')
21 | for avg in n_avg:
22 | print(f'{avg} ', end='')
23 | print('')
24 |
25 |
26 | sub_folders = os.listdir(folder)
27 |
28 | for i in range(begin, end+1, step):
29 | try:
30 | name = f"{i}.{pre_add}"
31 | n_num = 0
32 | n_error = 0
33 | n_avg = []
34 | for sub_path in sub_folders:
35 | filename = os.path.join(folder, sub_path, name)
36 | with open(filename) as infile:
37 | lines = infile.readlines()
38 | number, error, avg = parse(lines[-1])
39 | n_num += number
40 | n_error += error
41 | n_avg.append(avg)
42 |
43 | printResult(n_num, n_error, n_avg)
44 | except:
45 | pass
46 |
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
/ctools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sys
3 | import time
4 | import os
5 | import json
6 | from easydict import EasyDict as edict
7 |
8 | class TimeCounter:
9 |     # A simple time counter.
10 |     # Estimates the remaining running time.
11 |
12 |     # Input the total number of steps.
13 | def __init__(self, total):
14 | self.total = total
15 | self.cur = 0
16 | self.begin = time.time()
17 |
18 | def step(self):
19 | end = time.time()
20 | self.cur += 1
21 | used = (end - self.begin)/self.cur
22 | rest = self.total - self.cur
23 |
24 |         return max(rest * used, 0)
25 |
26 |
27 | def readfolder(data, specific=None, reverse=False):
28 |
29 |     """
30 |     Traverse the folder 'data.label' and read data from all files in the folder.
31 |
32 |     'specific' is a list specifying the indices of the files to extract.
33 |
34 |     When 'reverse' is True, read the files whose indices are NOT in 'specific'.
35 |     """
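    # Illustrative example (assuming the three-fold label files described in DATASET.md):
    # if data.label points to ".../label_class" containing train1.txt, train2.txt, train3.txt, then
    #   readfolder(data, [0])               -> data.label lists only train1.txt (e.g. the test fold)
    #   readfolder(data, [0], reverse=True) -> data.label lists train2.txt and train3.txt (the training folds)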
36 |
37 |
38 | folders = os.listdir(data.label)
39 | folders.sort()
40 |
41 | folder = folders
42 | if specific is not None:
43 | if reverse:
44 | num = np.arange(len(folders))
45 | specific = list(filter(lambda x: x not in specific, num))
46 |
47 | folder = [folders[i] for i in specific]
48 |
49 | data.label = [os.path.join(data.label, j) for j in folder]
50 |
51 | return data, folders
52 |
53 |
54 | def DictDumps(content):
55 | return json.dumps(content, ensure_ascii=False, indent=4)
56 |
57 |
58 | def GetLR(optimizer):
59 | LR = optimizer.state_dict()['param_groups'][0]['lr']
60 | return LR
61 |
62 |
--------------------------------------------------------------------------------
/Zone/class_large.label:
--------------------------------------------------------------------------------
1 | 0 37 0
2 | 1 57 0
3 | 2 77 0
4 | 3 97 0
5 | 4 62 9
6 | 5 63 9
7 | 6 64 9
8 | 7 65 9
9 | 8 72 9
10 | 9 73 9
11 | 10 80 0
12 | 11 81 0
13 | 12 84 9
14 | 13 85 9
15 | 14 90 0
16 | 15 91 0
17 | 16 92 9
18 | 17 93 9
19 | 18 94 9
20 | 19 95 9
21 | 20 98 0
22 | 21 88 0
23 | 22 78 0
24 | 23 68 0
25 | 24 48 0
26 | 25 38 0
27 | 26 89 0
28 | 27 69 0
29 | 28 59 0
30 | 29 39 0
31 | 30 d2 4
32 | 31 33 8
33 | 32 34 8
34 | 33 35 8
35 | 34 36 8
36 | 35 42 8
37 | 36 43 8
38 | 37 44 8
39 | 38 45 8
40 | 39 51 0
41 | 40 52 0
42 | 41 53 0
43 | 42 56 0
44 | 43 07 0
45 | 44 27 0
46 | 45 47 0
47 | 46 67 0
48 | 47 17 0
49 | 48 01 0
50 | 49 02 0
51 | 50 04 0
52 | 51 05 0
53 | 52 06 0
54 | 53 10 0
55 | 54 11 0
56 | 55 12 8
57 | 56 13 8
58 | 57 14 8
59 | 58 15 8
60 | 59 16 8
61 | 60 21 0
62 | 61 23 8
63 | 62 24 8
64 | 63 26 8
65 | 64 30 0
66 | 65 32 8
67 | 66 d3 4
68 | 67 d4 4
69 | 68 d0 4
70 | 69 c1 3
71 | 70 c0 3
72 | 71 b0 2
73 | 72 b1 2
74 | 73 b2 2
75 | 74 a1 1
76 | 75 a2 1
77 | 76 a4 1
78 | 77 a0 1
79 | 78 00 0
80 | 79 03 0
81 | 80 20 0
82 | 81 22 8
83 | 82 25 8
84 | 83 31 0
85 | 84 41 0
86 | 85 46 8
87 | 86 54 0
88 | 87 55 0
89 | 88 87 0
90 | 89 66 9
91 | 90 74 9
92 | 91 75 9
93 | 92 76 9
94 | 93 82 9
95 | 94 83 9
96 | 95 86 9
97 | 96 96 9
98 | 97 58 0
99 | 98 79 0
100 | 99 d1 4
101 | 100 f6 6
102 | 101 f7 6
103 | 102 f8 6
104 | 103 c2 3
105 | 104 c3 3
106 | 105 a3 1
107 | 106 e0 5
108 | 107 e1 5
109 | 108 e2 5
110 | 109 f0 6
111 | 110 f1 6
112 | 111 g4 7
113 | 112 g1 7
114 | 113 g2 7
115 | 114 g3 7
116 | 115 99 0
117 | 116 49 0
118 | 117 f2 6
119 | 118 71 0
120 | 119 f9 6
121 | 120 f5 6
122 | 121 f4 6
123 |
--------------------------------------------------------------------------------
/GazePTR.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import math
5 | import copy
6 | from IVModule import Backbone, Transformer, PoseTransformer, TripleDifferentialProj, PositionalEncoder
7 |
8 |
9 | def ep0(x):
10 | return x.unsqueeze(0)
11 |
12 | class Model(nn.Module):
13 |
14 | def __init__(self):
15 |
16 | super(Model, self).__init__()
17 |
18 | # used for origin.
19 | transIn = 128
20 | convDims = [64, 128, 256, 512]
21 |
22 | # norm only used for gaze estimation.
23 | self.bnorm = Backbone(1, transIn, convDims)
24 |
25 | # MLP for gaze estimation
26 | self.MLP_n_dir = nn.Linear(transIn, 2)
27 |
28 | module_list = []
29 | for i in range(len(convDims)):
30 | module_list.append(nn.Linear(transIn, 2))
31 | self.MLPList_n = nn.ModuleList(module_list)
32 |
33 | # Loss function
34 | self.loss_op_re = nn.L1Loss()
35 |
36 |
37 | def forward(self, x_in, train=True):
38 |
39 |         # feature [outFeatureNum, Batch, transIn], MLfeature: list[x1, x2, ...]
40 | feature_n, feature_list_n = self.bnorm(x_in['norm_face'])
41 |
42 | # Get feature for different task
43 | # [5, 128] [1. 5, 128]
44 | feature_n_dir = feature_n.squeeze()
45 |
46 | # estimate gaze from fused feature
47 | gaze = self.MLP_n_dir(feature_n_dir)
48 | # zone = self.MLP_o_zone(feature_o_zone)
49 |
50 |         # for loss calculation
51 | loss_gaze_n = []
52 | if train:
53 | for i, feature in enumerate(feature_list_n):
54 | loss_gaze_n.append(self.MLPList_n[i](feature))
55 |
56 | return gaze, None, None, loss_gaze_n
57 |
58 | def loss(self, x_in, label):
59 |
60 | gaze, _, _, loss_gaze_n = self.forward(x_in)
61 |
62 | loss1 = 2 * self.loss_op_re(gaze, label.normGaze)
63 |
64 | loss2 = 0
65 | # for zone in zones:
66 | # loss2 += (0.2/3) * self.loss_op_cls(zone, label.zone.view(-1))
67 |
68 | loss3 = 0
69 | # for pred in loss_gaze_o:
70 | # loss3 += self.loss_op_re(pred, label.originGaze)
71 |
72 | loss4 = 0
73 | for pred in loss_gaze_n:
74 | loss4 += self.loss_op_re(pred, label.normGaze)
75 | loss = loss1 + loss2 + loss3 + loss4
76 |
77 | return loss, [loss1, loss2, loss3, loss4]
78 |
79 |
80 | if __name__ == '__main__':
81 |     x_in = {'origin_face': torch.zeros([5, 3, 224, 224]).cuda(),
82 |             'norm_face': torch.zeros([5, 3, 224, 224]).cuda(),
83 | 'pos': torch.zeros(5, 2, 6).cuda()
84 | }
85 |
86 | model = Model()
87 | model = model.to('cuda')
88 | print(model)
89 | a = model(x_in)
90 | print(a)
91 |
--------------------------------------------------------------------------------
/Zone/evaluation_class.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import numpy as np
4 |
5 | folder = sys.argv[1]
6 | begin = 20
7 | step = 20
8 | end = 100
9 |
10 | pre_add = 'log_zone'
11 |
12 | class classAdapt:
13 | def __init__(self):
14 | with open ('/home/mercury01/ssd/dataset/Gaze/FaceBased/IVGaze/CVPR_process/class_large.label') as infile:
15 | lines = infile.readlines()
16 | self.class_dict = {}
17 | for line in lines:
18 | line = line.strip().split(' ')
19 | number = line[0]
20 |
21 | classname = line[2]
22 | self.class_dict[number] = classname
23 |
24 | def get(self, name):
25 | return self.class_dict.get(name, '-1')
26 |
27 |
28 | adapt = classAdapt()
29 |
30 |
31 | def printMat(matrix):
32 | size = matrix.shape
33 |
34 |     print('pr\\gt\t', end = '')
35 | for i in range(size[1]):
36 | print(f'{i}', end = '\t')
37 | print('')
38 |
39 | for i in range(size[0]):
40 | print(f'{i}', end = '\t')
41 | for j in range(size[1]):
42 | print(f'{int(matrix[i, j])}', end = '\t')
43 | print('')
44 | return 0
45 |
46 |
47 | def getAnalysis(results):
48 | # result: a list of prediction. [[pred, gt], [pred, gt]]
49 | num = 10
50 | all_class = range(0, num)
51 |
52 | matrix = np.zeros((num, num))
53 |
54 | total = 0
55 | for result in results:
56 | matrix[int(result[1]), int(result[0])] += 1
57 | total += 1
58 |
59 | printMat(matrix)
60 |
61 | for i in range(num):
62 | sum_i = np.sum(matrix[i, :])
63 | acc = matrix[i, i] / sum_i
64 | print(f'Class {i}: {matrix[i, i]}/{sum_i} = {acc:.3f}')
65 |
66 | return matrix
67 |
68 |
69 | def parse(line):
70 | line = line.strip().split(" ")
71 | pred = line[1]
72 | gt = line[2]
73 |
74 | new_pred = adapt.get(pred)
75 | new_gt = adapt.get(gt)
76 |
77 | return line[0], new_pred, new_gt, int(new_pred == new_gt)
78 |
79 | sub_folders = os.listdir(folder)
80 |
81 | for i in range(begin, end+1, step):
82 |
83 | try:
84 | name = f"{i}.{pre_add}"
85 |
86 | total = 0
87 | count = 0
88 | w_total = 0
89 | w_count = 0
90 | n_result = []
91 | for sub_path in sub_folders:
92 |
93 |
94 | filename = os.path.join(folder, sub_path, name)
95 |
96 | with open(filename) as infile:
97 | lines = infile.readlines()
98 | lines.pop(0)
99 |
100 | for line in lines:
101 | if len(line.strip().split(' ')) != 3:
102 | continue
103 | _, pred, gt, result = parse(line)
104 | if gt != str(0):
105 | total += result
106 | count += 1
107 | w_total += result
108 | w_count += 1
109 | n_result.append([pred, gt])
110 | print(f'***************************Result---{i}*****************************')
111 | getAnalysis(n_result)
112 | print(f'Avg: {total}/{count} = {total/count:.3f} ')
113 | print(f'With 0: {w_total}/{w_count} = {w_total/w_count:.3f} ')
114 | except:
115 | pass
116 |
117 |
118 |
119 |
120 |
--------------------------------------------------------------------------------
/DATASET.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | ## In-Vehicle Gaze Estimation Dataset
4 | We provide an in-vehicle gaze estimation dataset IVGaze.
5 |
6 | - IVGaze contains 44,705 images of 125 subjects. We divide the dataset into three subsets based on subjects, containing 15,165, 14,674, and 14,866 images, respectively.
7 | Three-fold cross-validation should be performed on the dataset.
8 |
9 | - The dataset was collected between 9 am and 7 pm in outdoor environments, covering a wide range of lighting conditions.
10 |
11 | - We consider two face accessories during the collection: glasses and masks. We also required a few subjects to wear sunglasses to facilitate future research.
12 |
13 | ## Dataset Structure
14 | ```
15 | IVGazeDataset
16 | ├── class.label
17 | ├── Norm
18 | │ ├── 20220811
19 | │ │ ├── subject0000_out_eye_mask
20 | │ │ │ ├── 1.jpg
21 | │ │ │ ├── ...
22 | │ │ │ ├── ...
23 | │ │ │ ├── 81.jpg
24 | │ │ ├── ...
25 | │ │ ├── ...
26 | │ │ ├── subject0000_out_eye_nomask
27 | │ ├── 20221009
28 | │ ├── 20221010
29 | │ ├── 20221011
30 | │ ├── 20221012
31 | │ ├── 20221013
32 | │ ├── 20221014
33 | │ ├── 20221017
34 | │ ├── 20221018
35 | │ ├── 20221019
36 | │ ├── 20221020
37 | │ └── label_class
38 | │ ├── train1.txt
39 | │ ├── train2.txt
40 | │ └── train3.txt
41 | └── Origin
42 | ├── 20220811
43 | ├── ...
44 | ├── ...
45 | ├── 20221020
46 | └── label_class
47 | ├── train1.txt
48 | ├── train2.txt
49 | └── train3.txt
50 | ```
51 |
52 | - `class.label`: This section offers gaze zone classification details. The first column denotes the class number used in `label_class`. The second column represents the original number assigned during the data collection phase. The third column indicates the coarse region number (see the parsing sketch below).
53 | - `Norm`: This section contains normalized images and their corresponding labels.
54 | - `Norm/label_class`: Here, you'll find label files for three-fold validation.
55 | - `Origin`: This section provides the original face images, directly cropped without normalization, along with their label files.
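
As a minimal sketch of how to parse `class.label` (the repository's `Zone/class_large.label` follows the same three-column format), each line can be read as follows:
```
with open('IVGazeDataset/class.label') as infile:
    for line in infile:
        parts = line.strip().split(' ')
        if len(parts) != 3:
            continue
        class_id = int(parts[0])       # class number used in label_class
        original_id = parts[1]         # original number from data collection (e.g. '37' or 'd2')
        coarse_region = int(parts[2])  # coarse region number
```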
56 |
57 | ## Usage
58 |
59 | To retrieve data from the IVGaze dataset, begin by reading a label file, such as `Norm/label_class/train1.txt`. Each line in the label file contains space-separated values and can be processed one line at a time. The snippet below assumes `os`, `cv2` (OpenCV), and `numpy` (imported as `np`) are available.
60 |
61 | ```
62 | root = 'IVGazeDataset/Norm'
63 | with open(os.path.join(root, 'label_class/train1.txt')) as infile:
64 | lines = infile.readlines()
65 |
66 | for line in lines:
67 | line = line.strip().split(' ')
68 |
69 | # Read the image
70 | image_name = line[0]
71 | image = cv2.imread(os.path.join(root, image_name))
72 |
73 | # GT for gaze and zone
74 | gaze = np.fromstring(line[1], sep=',')
75 | zone = int(line[3])
76 | ```
77 |
78 | ## Download
79 | To obtain access to the dataset, please send an email to `y.cheng.2@bham.ac.uk`.
80 | You will receive a Google Drive link within three days for downloading the dataset.
81 |
82 | Here is an email template for requesting access to the IVGaze dataset. Please do not change the email subject.
83 |
84 | ```
85 | Subject: Request for Access to IVGaze Dataset
86 |
87 | Dear Yihua,
88 |
89 | I hope this email finds you well.
90 |
91 | I am writing to request access to the IVGaze Dataset. My name is [Your Name], and I am a [student/researcher] from [Your Affiliation].
92 |
93 | I assure you that I will only utilize the dataset for academic and research purposes and will not use it for commercial activities.
94 |
95 | Thank you for considering my request. I look forward to receiving access to the dataset.
96 |
97 | Best regards,
98 | [Your Name]
99 | ```
100 |
101 |
--------------------------------------------------------------------------------
/GazeDPTR.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import math
5 | import copy
6 | from IVModule import Backbone, Transformer, PoseTransformer, TripleDifferentialProj, PositionalEncoder
7 |
8 |
9 | def ep0(x):
10 | return x.unsqueeze(0)
11 |
12 | class Model(nn.Module):
13 |
14 | def __init__(self):
15 |
16 | super(Model, self).__init__()
17 |
18 | # used for origin.
19 | transIn = 128
20 | convDims = [64, 128, 256, 512]
21 |
22 | # origin produces two features, one for gaze zone, one for gaze direction.
23 | self.borigin = Backbone(1, transIn, convDims)
24 |
25 | # norm only used for gaze estimation.
26 | self.bnorm = Backbone(1, transIn, convDims)
27 |
28 | self.ptrans = PoseTransformer(transIn, 3)
29 |
30 | # MLP for gaze estimation
31 | self.MLP_o_dir = nn.Linear(transIn, 2)
32 | self.MLP_n_dir = nn.Linear(transIn, 2)
33 |
34 |
35 | module_list = []
36 | for i in range(len(convDims)):
37 | module_list.append(nn.Linear(transIn, 2))
38 | self.MLPList_o = nn.ModuleList(module_list)
39 |
40 | module_list = []
41 | for i in range(len(convDims)):
42 | module_list.append(nn.Linear(transIn, 2))
43 | self.MLPList_n = nn.ModuleList(module_list)
44 |
45 | self.MLP_o_dir2 = nn.Linear(transIn, 2)
46 | self.MLP_n_dir2 = nn.Linear(transIn, 2)
47 |
48 | # Loss function
49 | self.loss_op_re = nn.L1Loss()
50 | self.loss_op_cls = nn.CrossEntropyLoss()
51 |
52 |
53 | def forward(self, x_in, train=True):
54 |
55 |         # feature [outFeatureNum, Batch, transIn], MLfeature: list[x1, x2, ...]
56 |
57 | # Extract feature from both two images
58 | feature_o, feature_list_o= self.borigin(x_in['origin_face'])
59 |
60 | feature_n, feature_list_n = self.bnorm(x_in['norm_face'])
61 |
62 | # Get feature for different task
63 |         # feature_o / feature_n: [1, 5, 128] -> squeeze -> [5, 128]
64 | feature_o_dir = feature_o.squeeze()
65 | feature_n_dir = feature_n.squeeze()
66 |
67 | # Fuse two direction feature and input it into transformer
68 | features_dir = torch.cat([ep0(feature_o_dir), ep0(feature_n_dir)], 0)
69 | features = self.ptrans(features_dir, x_in['pos'])
70 |
71 | # Get fused feature
72 | # feature_o_dir2 = features[0, :]
73 | feature_n_dir2 = features[1, :]
74 |
75 | # estimate gaze from fused feature
76 | gaze = self.MLP_n_dir2(feature_n_dir2)
77 | # zone = self.MLP_o_zone(feature_o_zone)
78 |
79 |         # for loss calculation
80 | loss_gaze_o = []
81 | loss_gaze_n = []
82 | if train:
83 | loss_gaze_n.append(self.MLP_n_dir(feature_n_dir))
84 | loss_gaze_o.append(self.MLP_o_dir(feature_o_dir))
85 |
86 | for i, feature in enumerate(feature_list_o):
87 | loss_gaze_o.append(self.MLPList_o[i](feature))
88 |
89 | for i, feature in enumerate(feature_list_n):
90 | loss_gaze_n.append(self.MLPList_n[i](feature))
91 |
92 | return gaze, None, loss_gaze_o, loss_gaze_n
93 |
94 | def loss(self, x_in, label):
95 |
96 | gaze, _, loss_gaze_o, loss_gaze_n = self.forward(x_in)
97 |
98 | loss1 = 2 * self.loss_op_re(gaze, label.normGaze)
99 |
100 | loss2 = 0
101 | # for zone in zones:
102 | # loss2 += (0.2/3) * self.loss_op_cls(zone, label.zone.view(-1))
103 |
104 | loss3 = 0
105 | for pred in loss_gaze_o:
106 | loss3 += self.loss_op_re(pred, label.originGaze)
107 |
108 | loss4 = 0
109 | for pred in loss_gaze_n:
110 | loss4 += self.loss_op_re(pred, label.normGaze)
111 | loss = loss1 + loss2 + loss3 + loss4
112 |
113 | return loss, [loss1, loss2, loss3, loss4]
114 |
115 |
116 | if __name__ == '__main__':
117 |     x_in = {'origin_face': torch.zeros([5, 3, 224, 224]).cuda(),
118 |             'norm_face': torch.zeros([5, 3, 224, 224]).cuda(),
119 | 'pos': torch.zeros(5, 2, 6).cuda()
120 | }
121 |
122 | model = Model()
123 | model = model.to('cuda')
124 | print(model)
125 | a = model(x_in)
126 | print(a)
127 |
--------------------------------------------------------------------------------
/tester/total.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | base_dir = os.getcwd()
3 | sys.path.insert(0, base_dir)
4 | import model
5 | import importlib
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import cv2, yaml, copy
11 | from easydict import EasyDict as edict
12 | import ctools, gtools
13 | import argparse
14 |
15 | def main(train, test):
16 |
17 | # =================================> Setup <=========================
18 | reader = importlib.import_module("reader." + test.reader)
19 | torch.cuda.set_device(test.device)
20 |
21 | data = test.data
22 | load = test.load
23 |
24 |
25 | # ===============================> Read Data <=========================
26 | if data.isFolder:
27 | data, _ = ctools.readfolder(data)
28 |
29 | print(f"==> Test: {data.label} <==")
30 | dataset = reader.loader(data, 32, num_workers=4, shuffle=False)
31 |
32 | modelpath = os.path.join(train.save.metapath,
33 | train.save.folder, f"checkpoint/")
34 |
35 | logpath = os.path.join(train.save.metapath,
36 | train.save.folder, f"{test.savename}")
37 |
38 |
39 | if not os.path.exists(logpath):
40 | os.makedirs(logpath)
41 |
42 | # =============================> Test <=============================
43 |
44 | begin = load.begin_step; end = load.end_step; step = load.steps
45 |
46 | for saveiter in range(begin, end+step, step):
47 |
48 | print(f"Test {saveiter}")
49 |
50 | net = model.Model()
51 |
52 | statedict = torch.load(
53 | os.path.join(modelpath,
54 | f"Iter_{saveiter}_{train.save.model_name}.pt"),
55 | map_location={f"cuda:{train.device}": f"cuda:{test.device}"}
56 | )
57 |
58 | net.cuda(); net.load_state_dict(statedict); net.eval()
59 |
60 | length = len(dataset); accs = 0; count = 0
61 |
62 | logname = f"{saveiter}.log"
63 |
64 | outfile = open(os.path.join(logpath, logname), 'w')
65 | outfile.write("name results gts\n")
66 |
67 |
68 | with torch.no_grad():
69 | for j, (data, label) in enumerate(dataset):
70 |
71 | for key in data:
72 | if key != 'name': data[key] = data[key].cuda()
73 |
74 | names = data["name"]
75 | gts = label.cuda()
76 |
77 | gazes = net(data)
78 |
79 | for k, gaze in enumerate(gazes):
80 |
81 | gaze = gaze.cpu().detach().numpy()
82 | gt = gts.cpu().numpy()[k]
83 |
84 | count += 1
85 | accs += gtools.angular(
86 | gtools.gazeto3d(gaze),
87 | gtools.gazeto3d(gt)
88 | )
89 |
90 | name = [names[k]]
91 | gaze = [str(u) for u in gaze]
92 | gt = [str(u) for u in gt]
93 | log = name + [",".join(gaze)] + [",".join(gt)]
94 | outfile.write(" ".join(log) + "\n")
95 |
96 | loger = f"[{saveiter}] Total Num: {count}, avg: {accs/count}"
97 | outfile.write(loger)
98 | print(loger)
99 | outfile.close()
100 |
101 | if __name__ == "__main__":
102 |
103 | parser = argparse.ArgumentParser(description='Pytorch Basic Model Training')
104 |
105 | parser.add_argument('-s', '--source', type=str,
106 | help = 'config path about training')
107 |
108 | parser.add_argument('-t', '--target', type=str,
109 | help = 'config path about test')
110 |
111 | args = parser.parse_args()
112 |
113 | # Read model from train config and Test data in test config.
114 | train_conf = edict(yaml.load(open(args.source), Loader=yaml.FullLoader))
115 |
116 | test_conf = edict(yaml.load(open(args.target), Loader=yaml.FullLoader))
117 |
118 | print("=======================>(Begin) Config of training<======================")
119 | print(ctools.DictDumps(train_conf))
120 | print("=======================>(End) Config of training<======================")
121 | print("")
122 | print("=======================>(Begin) Config for test<======================")
123 | print(ctools.DictDumps(test_conf))
124 | print("=======================>(End) Config for test<======================")
125 |
126 | main(train_conf.train, test_conf.test)
127 |
128 |
129 |
--------------------------------------------------------------------------------
/trainer/total.py:
--------------------------------------------------------------------------------
1 | import sys,os
2 | base_dir = os.getcwd()
3 | sys.path.insert(0, base_dir)
4 | import model
5 | import importlib
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import copy
11 | import yaml
12 | import cv2
13 | import ctools
14 | from easydict import EasyDict as edict
15 | import torch.backends.cudnn as cudnn
16 | from warmup_scheduler import GradualWarmupScheduler
17 | import argparse
18 |
19 | def main(config):
20 |
21 | # ===================>> Setup <<=================================
22 |
23 | dataloader = importlib.import_module("reader." + config.reader)
24 |
25 | torch.cuda.set_device(config.device)
26 | cudnn.benchmark = True
27 |
28 | data = config.data
29 | save = config.save
30 | params = config.params
31 |
32 | print("===> Read data <===")
33 |
34 | if data.isFolder:
35 | data, _ = ctools.readfolder(data)
36 |
37 | dataset = dataloader.loader(
38 | data,
39 | params.batch_size,
40 | shuffle=True,
41 | num_workers=8
42 | )
43 |
44 |
45 | print("===> Model building <===")
46 | net = model.Model()
47 | net.train(); net.cuda()
48 |
49 |
50 | # Pretrain
51 | pretrain = config.pretrain
52 |
53 | if pretrain.enable and pretrain.device:
54 | net.load_state_dict(
55 | torch.load(
56 | pretrain.path,
57 | map_location={f"cuda:{pretrain.device}": f"cuda:{config.device}"}
58 | )
59 | )
60 | elif pretrain.enable and not pretrain.device:
61 | net.load_state_dict(
62 | torch.load(pretrain.path)
63 | )
64 |
65 |
66 | print("===> optimizer building <===")
67 | optimizer = optim.Adam(
68 | net.parameters(),
69 | lr=params.lr,
70 | betas=(0.9,0.999)
71 | )
72 |
73 | scheduler = optim.lr_scheduler.StepLR(
74 | optimizer,
75 | step_size=params.decay_step,
76 | gamma=params.decay
77 | )
78 |
79 | if params.warmup:
80 | scheduler = GradualWarmupScheduler(
81 | optimizer,
82 | multiplier=1,
83 | total_epoch=params.warmup,
84 | after_scheduler=scheduler
85 | )
86 |
87 | savepath = os.path.join(save.metapath, save.folder, f"checkpoint")
88 |
89 | if not os.path.exists(savepath):
90 | os.makedirs(savepath)
91 |
92 | # =====================================>> Training << ====================================
93 | print("===> Training <===")
94 |
95 | length = len(dataset); total = length * params.epoch
96 | timer = ctools.TimeCounter(total)
97 |
98 |
99 | optimizer.zero_grad()
100 | optimizer.step()
101 | scheduler.step()
102 |
103 |
104 | with open(os.path.join(savepath, "train_log"), 'w') as outfile:
105 | outfile.write(ctools.DictDumps(config) + '\n')
106 |
107 | for epoch in range(1, params.epoch+1):
108 | for i, (data, anno) in enumerate(dataset):
109 |
110 | # -------------- forward -------------
111 | for key in data:
112 | if key != 'name': data[key] = data[key].cuda()
113 |
114 | anno = anno.cuda()
115 | loss = net.loss(data, anno)
116 |
117 | # -------------- Backward ------------
118 | optimizer.zero_grad()
119 | loss.backward()
120 | optimizer.step()
121 | rest = timer.step()/3600
122 |
123 |
124 | if i % 20 == 0:
125 | log = f"[{epoch}/{params.epoch}]: " + \
126 | f"[{i}/{length}] " +\
127 | f"loss:{loss} " +\
128 | f"lr:{ctools.GetLR(optimizer)} " +\
129 | f"rest time:{rest:.2f}h"
130 |
131 | print(log); outfile.write(log + "\n")
132 | sys.stdout.flush(); outfile.flush()
133 |
134 | scheduler.step()
135 |
136 | if epoch % save.step == 0:
137 | torch.save(
138 | net.state_dict(),
139 | os.path.join(
140 | savepath,
141 | f"Iter_{epoch}_{save.model_name}.pt"
142 | )
143 | )
144 |
145 |
146 | if __name__ == "__main__":
147 |
148 | parser = argparse.ArgumentParser(description='Pytorch Basic Model Training')
149 |
150 | parser.add_argument('-s', '--train', type=str,
151 | help='The source config for training.')
152 |
153 | args = parser.parse_args()
154 |
155 | config = edict(yaml.load(open(args.train), Loader=yaml.FullLoader))
156 |
157 | print("=====================>> (Begin) Training params << =======================")
158 | print(ctools.DictDumps(config))
159 |     print("=====================>> (End) Training params << =======================")
160 |
161 | main(config.train)
162 |
163 |
--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 |
4 |
5 |
6 | model_urls = {
7 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
8 | }
9 |
10 |
11 | def conv3x3(in_planes, out_planes, stride=1):
12 | """3x3 convolution with padding"""
13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
14 | padding=1, bias=False)
15 |
16 |
17 | class BasicBlock(nn.Module):
18 | expansion = 1
19 |
20 | def __init__(self, inplanes, planes, stride=1, downsample=None):
21 | super(BasicBlock, self).__init__()
22 | self.conv1 = conv3x3(inplanes, planes, stride)
23 | self.bn1 = nn.BatchNorm2d(planes)
24 | self.relu = nn.ReLU(inplace=True)
25 | self.conv2 = conv3x3(planes, planes)
26 | self.bn2 = nn.BatchNorm2d(planes)
27 | self.downsample = downsample
28 | self.stride = stride
29 |
30 | def forward(self, x):
31 | residual = x
32 |
33 | out = self.conv1(x)
34 | out = self.bn1(out)
35 | out = self.relu(out)
36 |
37 | out = self.conv2(out)
38 | out = self.bn2(out)
39 |
40 | if self.downsample is not None:
41 | residual = self.downsample(x)
42 |
43 | out += residual
44 | out = self.relu(out)
45 |
46 | return out
47 |
48 |
49 | class Bottleneck(nn.Module):
50 | expansion = 4
51 |
52 | def __init__(self, inplanes, planes, stride=1, downsample=None):
53 | super(Bottleneck, self).__init__()
54 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
55 | self.bn1 = nn.BatchNorm2d(planes)
56 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
57 | padding=1, bias=False)
58 | self.bn2 = nn.BatchNorm2d(planes)
59 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
60 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
61 | self.relu = nn.ReLU(inplace=True)
62 | self.downsample = downsample
63 | self.stride = stride
64 |
65 | def forward(self, x):
66 | residual = x
67 |
68 | out = self.conv1(x)
69 | out = self.bn1(out)
70 | out = self.relu(out)
71 |
72 | out = self.conv2(out)
73 | out = self.bn2(out)
74 | out = self.relu(out)
75 |
76 | out = self.conv3(out)
77 | out = self.bn3(out)
78 |
79 | if self.downsample is not None:
80 | residual = self.downsample(x)
81 |
82 | out += residual
83 | out = self.relu(out)
84 |
85 | return out
86 |
87 |
88 | class ResNet(nn.Module):
89 |
90 | def __init__(self, block, layers, input_dim=[64, 128, 256, 512]):
91 | super(ResNet, self).__init__()
92 |
93 | self.inplanes = input_dim[0]
94 |
95 | #assert len(maps) == 5, f'The length of input_dim should be 5'
96 |
97 | self.conv1 = nn.Conv2d(3, input_dim[0], kernel_size=7, stride=2, padding=3,
98 | bias=False)
99 | self.bn1 = nn.BatchNorm2d(input_dim[0])
100 | self.relu = nn.ReLU(inplace=True)
101 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
102 |
103 | self.layer1 = self._make_layer(block, input_dim[0], layers[0])
104 | self.layer2 = self._make_layer(block, input_dim[1], layers[1], stride=2)
105 | self.layer3 = self._make_layer(block, input_dim[2], layers[2], stride=2)
106 | self.layer4 = self._make_layer(block, input_dim[3], layers[3], stride=2)
107 |
108 | for m in self.modules():
109 | if isinstance(m, nn.Conv2d):
110 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
111 | elif isinstance(m, nn.BatchNorm2d):
112 | nn.init.constant_(m.weight, 1)
113 | nn.init.constant_(m.bias, 0)
114 |
115 | def _make_layer(self, block, planes, blocks, stride=1):
116 | downsample = None
117 | if stride != 1 or self.inplanes != planes * block.expansion:
118 | downsample = nn.Sequential(
119 | nn.Conv2d(self.inplanes, planes * block.expansion,
120 | kernel_size=1, stride=stride, bias=False),
121 | nn.BatchNorm2d(planes * block.expansion),
122 | )
123 |
124 | layers = []
125 | layers.append(block(self.inplanes, planes, stride, downsample))
126 | self.inplanes = planes * block.expansion
127 | for i in range(1, blocks):
128 | layers.append(block(self.inplanes, planes))
129 |
130 | return nn.Sequential(*layers)
131 |
132 | def forward(self, x):
133 | x = self.conv1(x)
134 | x = self.bn1(x)
135 | x = self.relu(x)
136 | x = self.maxpool(x)
137 |
138 | x1 = self.layer1(x)
139 |
140 | x2 = self.layer2(x1)
141 |
142 | x3 = self.layer3(x2)
143 |
144 | x4 = self.layer4(x3)
145 |
146 | return x1, x2, x3, x4
147 |
148 |
149 | def resnet18(pretrained=False, **kwargs):
150 | """Constructs a ResNet-18 model.
151 | Args:
152 | pretrained (bool): If True, returns a model pre-trained on ImageNet
153 | """
154 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
155 | if pretrained:
156 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']),strict=False)
157 | return model
158 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | # What Do You See in Vehicle? Comprehensive Vision Solution for In-Vehicle Gaze Estimation
4 | Yihua Cheng, Yaning Zhu, Zongji Wang, Hongquan Hao, Yongwei Liu, Shiqing Cheng, Xi Wang, Hyung Jin Chang, CVPR 2024
5 |
6 |
7 |
8 |
9 |
10 | ## Description
11 | This repository provides the official code of the paper *What Do You See in Vehicle? Comprehensive Vision Solution for In-Vehicle Gaze Estimation*, accepted at CVPR 2024.
12 | Our contributions include:
13 | - We provide a dataset **IVGaze** collected on vehicles containing 44k images of 125 subjects.
14 | - We propose a gaze pyramid transformer (GazePTR) that leverages transformer-based multi-level feature integration.
15 | - We introduce the dual-stream gaze pyramid transformer (GazeDPTR). Employing perspective transformation, we rotate virtual cameras to normalize images, utilizing camera pose to merge normalized and original images for accurate gaze estimation.
16 |
17 | Please visit our project page for details. The dataset is available on this page.
18 |
19 | [Video](https://www.youtube.com/watch?v=050M4CK5EwI&t=3s "Gaze")
20 |
21 | ## Requirement
22 |
23 | 1. Install PyTorch and torchvision. This code is written in `Python 3.8` and utilizes `PyTorch 1.13.1` with `CUDA 11.6` on an Nvidia GeForce RTX 3090. While this environment is recommended, it is not mandatory. Feel free to run the code in your preferred environment.
24 |
25 | ```
26 | pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
27 | ```
28 |
29 | 2. Install other packages.
30 | ```
31 | pip install opencv-python PyYAML easydict warmup_scheduler
32 | ```
33 |
34 | If you have any issues due to missing packages, please report them. I will update the requirements. Thank you for your cooperation.
35 |
36 | ## Training
37 |
38 | **Step 1: Choose the model file.**
39 |
40 | We provide three models: `GazePTR.py`, `GazeDPTR.py`, and `GazeDPTR_V2.py`. The pre-trained weight links are the same for all three models; please load the weights corresponding to the model you choose.
41 |
42 | | | Name | Description | Input | Output | Accuracy | Pretrained Weights |
43 | |:----|:----|:----|:----:|:----:|:----:|:----:|
44 | |1|GazePTR| This method leverages multi-level features. |Normalized Images|Gaze Directions|7.04°| Link |
45 | |2|GazeDPTR| This method integrates features from two images. |Normalized Images & Original Images|Gaze Directions|6.71°| Link |
46 | |3|GazeDPTR_V2| This method contains a differential projection for gaze zone prediction. |Normalized Images & Original Images|Gaze Directions & Gaze Zone|6.71° / 81.8%| Link |
47 |
48 | Please choose one model and copy it to `model.py`, *e.g.*,
49 | ```
50 | cp GazeDPTR.py model.py
51 | ```
52 |
53 | **Step 2: Modify the config file**
54 |
55 |
56 | Please modify `config/train/config_iv.yaml` according to your environment settings.
57 |
58 | - The `save` attribute specifies the save path: the model will be stored at `os.path.join({save.metapath}, {save.folder})`. Each saved model is named `Iter_{epoch}_{save.model_name}.pt`.
59 | - The `data` attribute indicates the dataset path. Update the `image` and `label` to match your dataset location.
60 |
61 | **Step 3: Training models**
62 |
63 | Run the following command to initiate training. The argument `3` indicates that it will automatically perform three-fold cross-validation:
64 |
65 | ```
66 | python trainer/leave.py config/train/config_iv.yaml 3
67 | ```
68 |
69 | Once the training is complete, you will find the weights saved at `os.path.join({save.metapath}, {save.folder})`.
70 | Within the `checkpoint` directory, you will find three folders named `train1.txt`, `train2.txt`, and `train3.txt`, corresponding to the three folds of cross-validation. Each folder contains the respective trained models.
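
For reference, assuming the default values in `config/train/config_iv.yaml` (`metapath: "/home/$YourSavePath$/exp/GazeDPTR"`, `folder: iv`, `model_name: trans6`, `step: 10`, `epoch: 80`), the resulting layout looks roughly like this:
```
/home/$YourSavePath$/exp/GazeDPTR/iv
└── checkpoint
    ├── train1.txt
    │   ├── train_log
    │   ├── Iter_10_trans6.pt
    │   ├── ...
    │   └── Iter_80_trans6.pt
    ├── train2.txt
    └── train3.txt
```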
71 |
72 | ## Testing
73 | Run the following command for testing.
74 | ```
75 | python tester/leave.py config/train/config_iv.yaml config/test/config_iv.yaml 3
76 | ```
77 | Similarly,
78 | - Update the `image` and `label` in `config/test/config_iv.yaml` based on your dataset location.
79 | - The `savename` attribute specifies the folder to save prediction results, which will be stored at `os.path.join({save.metapath}, {save.folder})` as defined in `config/train/config_iv.yaml`.
80 | - The script `tester/leave.py` also writes gaze zone prediction results; remove that part of the code if you do not need gaze zone prediction.
81 |
82 | ## Evaluation
83 |
84 | We provide the `evaluation.py` script to assess the accuracy of gaze direction estimation. Run the following command:
85 | ```
86 | python evaluation.py {PATH}
87 | ```
88 | Replace `{PATH}` with the path of `{savename}` as configured in your settings.
89 |
90 | Please find the visualization code in the issues.
91 |
92 | ## Contact
93 | Please send email to `y.cheng.2@bham.ac.uk` if you have any questions.
94 |
--------------------------------------------------------------------------------
/tester/leave.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | base_dir = os.getcwd()
3 | sys.path.insert(0, base_dir)
4 | import model
5 | import importlib
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import cv2, yaml, copy
11 | from easydict import EasyDict as edict
12 | import ctools, gtools
13 | import argparse
14 |
15 | def main(train, test):
16 |
17 | # ===============================> Setup <============================
18 | reader = importlib.import_module("reader." + test.reader)
19 |
20 | data = test.data
21 | load = test.load
22 | torch.cuda.set_device(test.device)
23 |
24 |
25 | # ==============================> Read Data <========================
26 | data.origin, folder = ctools.readfolder(data.origin, [test.person])
27 | data.norm, folder = ctools.readfolder(data.norm, [test.person])
28 |
29 | testname = folder[test.person]
30 |
31 | dataset = reader.loader(data, 500, num_workers=4, shuffle=True)
32 |
33 | modelpath = os.path.join(train.save.metapath,
34 | train.save.folder, f'checkpoint/{testname}')
35 | logpath = os.path.join(train.save.metapath,
36 | train.save.folder, f'{test.savename}/{testname}')
37 |
38 | if not os.path.exists(logpath):
39 | os.makedirs(logpath)
40 |
41 | # =============================> Test <==============================
42 |
43 | begin = load.begin_step; end = load.end_step; step = load.steps
44 |
45 | for saveiter in range(begin, end+step, step):
46 | print(f"Test {saveiter}")
47 |
48 | # ----------------------Load Model------------------------------
49 | net = model.Model()
50 |
51 |
52 | statedict = torch.load(
53 | os.path.join(modelpath, f"Iter_{saveiter}_{train.save.model_name}.pt"),
54 | map_location={f"cuda:{train.device}":f"cuda:{test.device}"}
55 | )
56 |
57 |
58 | net.cuda(); net.load_state_dict(statedict); net.eval()
59 |
60 | length = len(dataset); accs = 0; count = 0
61 |
62 | # -----------------------Open log file--------------------------------
63 | logname = f"{saveiter}.log"
64 |
65 | outfile1 = open(os.path.join(logpath, logname + '_zone'), 'w')
66 | outfile1.write("name results gts\n")
67 |
68 | outfile2 = open(os.path.join(logpath, logname + '_dir'), 'w')
69 | outfile2.write("name results gts\n")
70 |
71 |
72 |
73 | # -------------------------Testing---------------------------------
74 | with torch.no_grad():
75 | n_true = {}
76 | n_total = {}
77 |
78 | for j, (data, label) in enumerate(dataset):
79 |
80 | for key in data:
81 | if key != 'name': data[key] = data[key].cuda()
82 |
83 | names = data["name"]
84 |
85 | gt_zones = label.zone.view(-1)
86 |                 # gt_dirs = label.originGaze
87 | gt_dirs = label.normGaze
88 | gazes, zones, _, _ = net(data, train=False)
89 |
90 |
91 | for k, cls in enumerate(zones):
92 |
93 | gt = str(int(gt_zones[k]))
94 | name = [names[k]]
95 | if gt == str(int(cls)):
96 | n_true[gt] = 1 + n_true.get(gt, 0)
97 |
98 | n_total[gt] = 1 + n_total.get(gt, 0)
99 |
100 | log = name + [f"{cls}"] + [f"{gt}"]
101 | outfile1.write(" ".join(log) + "\n")
102 |
103 | for k, gaze in enumerate(gazes):
104 |
105 | gaze = gaze.cpu().detach().numpy()
106 | gt = gt_dirs.numpy()[k]
107 |
108 | count += 1
109 | accs += gtools.angular(
110 | gtools.gazeto3d(gaze),
111 | gtools.gazeto3d(gt)
112 | )
113 |
114 | name = [names[k]]
115 | gaze = [str(u) for u in gaze]
116 | gt = [str(u) for u in gt]
117 | log = name + [",".join(gaze)] + [",".join(gt)]
118 | outfile2.write(" ".join(log) + "\n")
119 |
120 | keys = sorted(list(n_true.keys()), key = lambda x:int(x))
121 | true_num = 0
122 | total_num = 0
123 | for key in keys:
124 | true_num += n_true[key]
125 | total_num += n_total[key]
126 | loger = f'Class {key} {n_true[key]} {n_total[key]} {n_true[key]/n_total[key]:.3f}\n'
127 | outfile1.write(loger)
128 | loger = f"[{saveiter}] Total Num: {total_num}, True: {true_num}, AP:{true_num/total_num:.3f}"
129 | outfile1.write(loger)
130 | print(loger)
131 |
132 | loger = f"[{saveiter}] Total Num: {count}, avg: {accs/count}"
133 | outfile2.write(loger)
134 | print(loger)
135 | outfile1.close()
136 | outfile2.close()
137 |
138 | if __name__ == "__main__":
139 |
140 |
141 | # Read model from train config and Test data in test config.
142 | train_conf = edict(yaml.load(open(sys.argv[1]), Loader=yaml.FullLoader))
143 |
144 | test_conf = edict(yaml.load(open(sys.argv[2]), Loader=yaml.FullLoader))
145 |
146 | for i in range(int(sys.argv[3])):
147 | test_conf_cur = copy.deepcopy(test_conf)
148 |
149 | test_conf_cur.person = i
150 |
151 | print("=======================>(Begin) Config of training<======================")
152 |
153 | print(ctools.DictDumps(train_conf))
154 |
155 | print("=======================>(End) Config of training<======================")
156 |
157 | print("")
158 |
159 | print("=======================>(Begin) Config for test<======================")
160 |
161 | print(ctools.DictDumps(test_conf_cur))
162 |
163 | print("=======================>(End) Config for test<======================")
164 |
165 | main(train_conf.train, test_conf_cur)
166 |
167 |
--------------------------------------------------------------------------------
/trainer/leave.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 | base_dir = os.getcwd()
3 | sys.path.insert(0, base_dir)
4 | import model
5 | import importlib
6 | import numpy as np
7 | import torch
8 | import cv2
9 | import torch.nn as nn
10 | import torch.optim as optim
11 | import copy
12 | import yaml
13 | import ctools
14 | from easydict import EasyDict as edict
15 | import torch.backends.cudnn as cudnn
16 | from warmup_scheduler import GradualWarmupScheduler
17 | import random
18 | import argparse
19 |
20 | def setup_seed(seed=0):
21 | torch.manual_seed(seed)
22 | np.random.seed(seed)
23 | random.seed(seed)
24 | torch.cuda.manual_seed(seed)
25 | torch.cuda.manual_seed_all(seed)
26 |
27 |
28 | def main(config):
29 | # ===============================> Setup <================================
30 |
31 | setup_seed(123)
32 |
33 | dataloader = importlib.import_module("reader." + config.reader)
34 | torch.cuda.set_device(config.device)
35 | cudnn.benchmark = True
36 |
37 | data = config.data
38 | save = config.save
39 | params = config.params
40 |
41 |
42 | print("===> Read data <===")
43 | data.origin, folder = ctools.readfolder(
44 | data.origin,
45 | [config.person],
46 | reverse=True
47 | )
48 |
49 | data.norm, folder = ctools.readfolder(
50 | data.norm,
51 | [config.person],
52 | reverse=True
53 | )
54 |
55 |
56 | savename = folder[config.person]
57 |
58 | dataset = dataloader.loader(
59 | data,
60 | params.batch_size,
61 | shuffle=True,
62 | num_workers=6
63 | )
64 |
65 | print("===> Model building <===")
66 | net = model.Model(); net.train(); net.cuda()
67 |
68 |
69 | # Pretrain
70 | pretrain = config.pretrain
71 | if pretrain.enable and pretrain.device:
72 | net.load_state_dict(
73 | torch.load(
74 | pretrain.path,
75 | map_location={f"cuda:{pretrain.device}": f"cuda:{config.device}"}
76 | )
77 | )
78 | elif pretrain.enable and not pretrain.device:
79 | net.load_state_dict(
80 | torch.load(pretrain.path)
81 | )
82 |
83 |
84 | print("===> optimizer building <===")
85 | optimizer = optim.Adam(
86 | net.parameters(),
87 | lr=params.lr,
88 | betas=(0.9,0.95)
89 | )
90 |
91 | scheduler = optim.lr_scheduler.StepLR(
92 | optimizer,
93 | step_size=params.decay_step,
94 | gamma=params.decay
95 | )
96 |
97 | if params.warmup:
98 | scheduler = GradualWarmupScheduler(
99 | optimizer,
100 | multiplier=1,
101 | total_epoch=params.warmup,
102 | after_scheduler=scheduler
103 | )
104 |
105 | savepath = os.path.join(save.metapath, save.folder, f"checkpoint/{savename}")
106 |
107 | if not os.path.exists(savepath):
108 | os.makedirs(savepath)
109 |
110 | # =======================================> Training < ==========================
111 | print("===> Training <===")
112 | length = len(dataset); total = length * params.epoch
113 | timer = ctools.TimeCounter(total)
114 |
115 |
116 | optimizer.zero_grad()
117 | optimizer.step()
118 | scheduler.step()
119 |
120 | with open(os.path.join(savepath, "train_log"), 'w') as outfile:
121 | outfile.write(ctools.DictDumps(config) + '\n')
122 |
123 | for epoch in range(1, params.epoch+1):
124 | for i, (data, anno) in enumerate(dataset):
125 |
126 | # ------------------forward--------------------
127 | for key in data.keys():
128 | if key != 'name': data[key] = data[key].cuda()
129 |
130 | for key in anno.keys(): anno[key] = anno[key].cuda()
131 |
132 | loss, losslist = net.loss(data, anno)
133 |
134 | # -----------------backward--------------------
135 | optimizer.zero_grad()
136 |
137 | loss.backward()
138 |
139 | optimizer.step()
140 |
141 | rest = timer.step()/3600
142 |
143 | # -----------------loger----------------------
144 | if i % 20 == 0:
145 | log = f"[{epoch}/{params.epoch}]: " +\
146 | f"[{i}/{length}] " +\
147 | f"loss:{loss:.3f} " +\
148 | f"loss_re:{losslist[0]:.3f} " +\
149 | f"loss_cls:{losslist[1]:.3f} " +\
150 | f"loss_o:{losslist[2]:.3f} " +\
151 | f"loss_n:{losslist[3]:.3f} " +\
152 | f"lr:{ctools.GetLR(optimizer)} "+\
153 | f"rest time:{rest:.2f}h"
154 |
155 | print(log); outfile.write(log + "\n")
156 | sys.stdout.flush(); outfile.flush()
157 |
158 | scheduler.step()
159 |
160 | if epoch % save.step == 0:
161 | torch.save(
162 | net.state_dict(),
163 | os.path.join(savepath, f"Iter_{epoch}_{save.model_name}.pt")
164 | )
165 |
166 | if __name__ == "__main__":
167 |
168 | config = edict(yaml.load(open(sys.argv[1]), Loader=yaml.FullLoader))
169 |
170 | config = config.train
171 | person = int(sys.argv[2])
172 |
173 | for i in range(person):
174 | config_i = copy.deepcopy(config)
175 | config_i.person = i
176 |
177 | print("=====================>> (Begin) Training params << =======================")
178 |
179 | print(ctools.DictDumps(config_i))
180 |
181 |         print("=====================>> (End) Training params << =======================")
182 |
183 | main(config_i)
184 |
185 |
--------------------------------------------------------------------------------
/reader/reader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import random
5 | import numpy as np
6 | from easydict import EasyDict as edict
7 | from torch.utils.data import Dataset, DataLoader
8 | from torchvision import transforms
9 | from PIL import Image
10 | import copy
11 |
12 | def gazeto2d(gaze):
13 | yaw = -np.arctan2(-gaze[0], -gaze[2])
14 | pitch = -np.arcsin(-gaze[1])
15 | return np.array([yaw, pitch])
16 |
17 | def Decode_MPII(line):
18 | anno = edict()
19 | anno.face, anno.lefteye, anno.righteye = line[0], line[1], line[2]
20 | anno.name = line[3]
21 |
22 | anno.gaze3d, anno.head3d = line[5], line[6]
23 | anno.gaze2d, anno.head2d = line[7], line[8]
24 | return anno
25 |
26 | def Decode_IVOrigin(line):
27 | anno = edict()
28 | anno.face = line[0]
29 | anno.name = line[0]
30 | anno.gaze = line[1]
31 | anno.placeholder = line[2]
32 | anno.zone = line[3]
33 | # anno.target = line[4]
34 | anno.origin = line[5]
35 | return anno
36 |
37 | def Decode_IVNorm(line):
38 | anno = edict()
39 | anno.face = line[0]
40 | anno.name = line[0]
41 | anno.gaze = line[1]
42 | anno.head = line[2]
43 | anno.zone = line[3]
44 | anno.origin = line[4]
45 | anno.norm = line[6]
46 | return anno
47 |
48 |
49 | def Decode_Dict():
50 | mapping = edict()
51 | mapping.ivorigin = Decode_IVOrigin
52 | mapping.ivnorm = Decode_IVNorm
53 | return mapping
54 |
55 | def long_substr(str1, str2):
56 | substr = ''
57 | for i in range(len(str1)):
58 | for j in range(len(str1)-i+1):
59 | if j > len(substr) and (str1[i:i+j] in str2):
60 | substr = str1[i:i+j]
61 | return len(substr)
62 |
63 | def Get_Decode(name):
64 | mapping = Decode_Dict()
65 | keys = list(mapping.keys())
66 | name = name.lower()
67 | score = [long_substr(name, i) for i in keys]
68 | key = keys[score.index(max(score))]
69 | return mapping[key]
70 |
71 |
72 | class commonloader(Dataset):
73 |
74 | def __init__(self, dataset):
75 |
76 | # Read source data
77 | self.source = edict()
78 | self.source.origin = edict()
79 | self.source.norm = edict()
80 |
81 | # Read origin data
82 | origin = dataset.origin
83 |
84 | self.source.origin.root = origin.image
85 | self.source.origin.line = self.__readlines(origin.label, origin.header)
86 | self.source.origin.decode = Get_Decode(origin.name)
87 |
88 | # Read norm data
89 | norm = dataset.norm
90 |
91 | # self.source.norm = copy.deepcopy(dataset.norm)
92 | self.source.norm.root = norm.image
93 | self.source.norm.line = self.__readlines(norm.label, norm.header)
94 | self.source.norm.decode = Get_Decode(norm.name)
95 |
96 | # build transforms
97 | self.transforms = transforms.Compose([
98 | transforms.ToTensor()
99 | ])
100 |
101 |
102 | def __readlines(self, filename, header=True):
103 |
104 | data = []
105 | if isinstance(filename, list):
106 | for i in filename:
107 | with open(i) as f: line = f.readlines()
108 | if header: line.pop(0)
109 | data.extend(line)
110 |
111 | else:
112 | with open(filename) as f: data = f.readlines()
113 | if header: data.pop(0)
114 | return data
115 |
116 |
117 | def __len__(self):
118 | assert len(self.source.origin.line) == len(self.source.norm.line), 'Two files are not aligned.'
119 | return len(self.source.origin.line)
120 |
121 |
122 | def __getitem__(self, idx):
123 |
124 | # ------------------Read origin-----------------------
125 | line = self.source.origin.line[idx]
126 | line = line.strip().split(" ")
127 |
128 | # decode the info
129 | anno = self.source.origin.decode(line)
130 |
131 | # read image
132 | origin_img = cv2.imread(os.path.join(self.source.origin.root, anno.face))
133 | origin_img = self.transforms(origin_img)
134 |
135 | origin_cam_mat = np.diag((1, 1, 1))
136 | origin_cam_mat = torch.from_numpy(origin_cam_mat).type(torch.FloatTensor)
137 |
138 | origin_z_axis = torch.Tensor([0, 0, 1]).type(torch.FloatTensor)
139 |
140 | zone = int(anno.zone)
141 | zone = torch.Tensor([zone]).type(torch.long)
142 |
143 | # read label
144 | origin_label = gazeto2d(np.array(anno.gaze.split(",")).astype("float"))
145 | origin_label = torch.from_numpy(origin_label).type(torch.FloatTensor)
146 |
147 | gaze_origin = np.array(anno.origin.split(",")).astype("float")
148 | gaze_origin = torch.from_numpy(gaze_origin).type(torch.FloatTensor)
149 |
150 | name = anno.name
151 |
152 | # --------------------read norm------------------------
153 | line = self.source.norm.line[idx]
154 | line = line.strip().split(" ")
155 |
156 | # decode the info
157 | anno = self.source.norm.decode(line)
158 |
159 | # read image
160 | norm_img = cv2.imread(os.path.join(self.source.norm.root, anno.face))
161 | norm_img = self.transforms(norm_img)
162 |
163 | # normalization rotation (Rodrigues vector -> rotation matrix)
164 | norm_mat = np.fromstring(anno.norm, sep=',')
165 | norm_mat = cv2.Rodrigues(norm_mat)[0]
166 |
167 | # Camera rotation. Label = R * prediction
168 | inv_mat = np.linalg.inv(norm_mat)
169 | z_axis = inv_mat[:, 2].flatten()
170 |
171 | norm_cam_mat = torch.from_numpy(inv_mat).type(torch.FloatTensor)
172 | z_axis = torch.from_numpy(z_axis).type(torch.FloatTensor)
173 |
174 | # read label
175 | norm_label = gazeto2d(np.array(anno.gaze.split(",")).astype("float"))
176 | norm_label = torch.from_numpy(norm_label).type(torch.FloatTensor)
177 |
178 | assert name == anno.name, 'Data is not aligned'
179 |
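# pos stacks the original camera z-axis ([0, 0, 1]) and the z-axis taken from
# the inverse normalization rotation into a [2, 3] tensor; it is consumed as
# the position cue by the PoseTransformer in GazeDPTR_V2.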
180 | pos = torch.cat([torch.unsqueeze(origin_z_axis, 0), torch.unsqueeze(z_axis, 0)], 0)
181 |
182 | # ---------------------------------------------------
183 | data = edict()
184 | data.origin_face = origin_img
185 | data.origin_cam = origin_cam_mat
186 | data.norm_face = norm_img
187 | data.norm_cam = norm_cam_mat
188 | data.pos = pos
189 | data.name = anno.name
190 | data.gaze_origin = gaze_origin
191 |
192 | label = edict()
193 | label.originGaze = origin_label
194 | label.normGaze = norm_label
195 | label.zone = zone
196 |
197 | return data, label
198 |
199 |
200 | def loader(source, batch_size, shuffle=False, num_workers=0):
201 |
202 | dataset = commonloader(source)
203 |
204 | print(f"-- [Read Data]: Total num: {len(dataset)}")
205 |
206 | print(f"-- [Read Data]: Source: {source.norm.label}")
207 |
208 | load = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
209 | return load
210 |
211 | if __name__ == "__main__":
212 |
213 | path = './p00.label'
214 | # d = loader(path)
215 | # print(len(d))
216 | # (data, label) = d.__getitem__(0)
217 |
218 |
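# A minimal usage sketch (hypothetical paths and field values; the real
# dataset object carries origin/norm entries with image, label, header and
# name fields, as read in commonloader.__init__):
#
# source = edict()
# source.origin = edict({'image': './origin', 'label': './origin.label',
#                        'header': True, 'name': 'ivorigin'})
# source.norm = edict({'image': './norm', 'label': './norm.label',
#                      'header': True, 'name': 'ivnorm'})
# d = loader(source, batch_size=16, shuffle=True)
# data, label = next(iter(d))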
--------------------------------------------------------------------------------
/GazeDPTR_V2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import math
5 | import copy
6 | from IVModule import Backbone, Transformer, PoseTransformer, TripleDifferentialProj, PositionalEncoder
7 |
8 |
9 | def ep0(x):
10 | return x.unsqueeze(0)
11 |
12 | class Model(nn.Module):
13 |
14 | def __init__(self):
15 |
16 | super(Model, self).__init__()
17 |
18 | # used for origin.
19 | transIn = 128
20 | convDims = [64, 128, 256, 512]
21 |
22 | # The origin branch produces two features: one for gaze zone, one for gaze direction.
23 | self.borigin = Backbone(2, transIn, convDims)
24 |
25 | # The norm branch is used only for gaze estimation.
26 | self.bnorm = Backbone(1, transIn, convDims)
27 |
28 | self.ptrans = PoseTransformer(transIn, 3)
29 |
30 | # Proj
31 | self.proj = TripleDifferentialProj()
32 |
33 | # 3 * (30 * 2 + 1)
34 | self.PoseEncoder = PositionalEncoder(30, True)
35 |
36 | self.MLP_Pos_x = nn.Sequential(
37 | nn.Linear(183, 128),
38 | nn.ReLU(inplace=True),
39 | )
40 |
41 | self.MLP_Pos_y = nn.Sequential(
42 | nn.Linear(183, 128),
43 | nn.ReLU(inplace=True),
44 | )
45 |
46 | self.MLP_Pos_z = nn.Sequential(
47 | nn.Linear(183, 128),
48 | nn.ReLU(inplace=True),
49 | )
50 |
51 | self.MLP_Pos = [self.MLP_Pos_x, self.MLP_Pos_y, self.MLP_Pos_z]
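# Each MLP_Pos_* maps one projected 3D gaze point, positionally encoded by
# PositionalEncoder(30, True) into 3 * (30 * 2 + 1) = 183 dims, to a 128-d
# token that is later fed to tripleTrans.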
52 |
53 |
54 | # Transformer to combine gaze point and feature
55 | self.tripleTrans = Transformer(
56 | input_dim = 128,
57 | length = 3,
58 | layer_num=2,
59 | nhead=4
60 | )
61 |
62 | self.zoneTrans = Transformer(
63 | input_dim = 128,
64 | length = 2,
65 | layer_num=2,
66 | nhead=4
67 | )
68 |
69 | # MLP for gaze estimation
70 | class_num = 122
71 |
72 | self.MLP_o_dir = nn.Linear(transIn, 2)
73 | self.MLP_n_dir = nn.Linear(transIn, 2)
74 |
75 |
76 | module_list = []
77 | for i in range(len(convDims)):
78 | module_list.append(nn.Linear(transIn, 2))
79 | self.MLPList_o = nn.ModuleList(module_list)
80 |
81 | module_list = []
82 | for i in range(len(convDims)):
83 | module_list.append(nn.Linear(transIn, 2))
84 | self.MLPList_n = nn.ModuleList(module_list)
85 |
86 | self.MLP_o_dir2 = nn.Linear(transIn, 2)
87 | self.MLP_n_dir2 = nn.Linear(transIn, 2)
88 |
89 | self.MLP_o_zone = nn.Linear(transIn, class_num)
90 | self.MLP_o_zone2 = nn.Linear(transIn, class_num)
91 | self.MLP_o_zone3 = nn.Linear(transIn, class_num)
92 |
93 | # Loss function
94 | self.loss_op_re = nn.L1Loss()
95 | self.loss_op_cls = nn.CrossEntropyLoss()
96 |
97 |
98 | def forward(self, x_in, train=True):
99 |
100 | # feature: [outFeatureNum, Batch, transIn]; MLfeature: list [x1, x2, ...]
101 |
102 | # Extract features from both images
103 | feature_o, feature_list_o = self.borigin(x_in['origin_face'])
104 |
105 | feature_n, feature_list_n = self.bnorm(x_in['norm_face'])
106 |
107 | # Get features for the different tasks
108 | # feature_o: [2, B, 128]; feature_n: [B, 128]
109 | feature_o_zone = feature_o[0,:]
110 | feature_o_dir = feature_o[1,:]
111 |
112 | feature_n_dir = feature_n.squeeze()
113 |
114 | zone1 = self.MLP_o_zone(feature_o_zone)
115 |
116 | # Fuse two direction feature and input it into transformer
117 | features_dir = torch.cat([ep0(feature_o_dir), ep0(feature_n_dir)], 0)
118 | features = self.ptrans(features_dir, x_in['pos'])
119 |
120 | # Get fused feature
121 | # feature_o_dir2 = features[0, :]
122 | feature_n_dir2 = features[1, :]
123 |
124 | # estimate gaze from fused feature
125 | gaze = self.MLP_n_dir2(feature_n_dir2)
126 | # zone = self.MLP_o_zone(feature_o_zone)
127 |
128 | # Proj
129 | gaze_proj = self.proj(gaze.detach(), x_in['gaze_origin'], x_in['norm_cam'])
130 | gaze_proj_list = []
131 | # gaze_proj_list.append(ep0(feature_o_zone))
132 |
133 | for i, gaze2d in enumerate(gaze_proj):
134 | gaze_proj_list.append(ep0(self.MLP_Pos[i](self.PoseEncoder.encode(gaze2d))))
135 |
136 | triple_feature = torch.cat(gaze_proj_list, 0)
137 | triple_feature = self.tripleTrans(triple_feature)
138 | zone2 = self.MLP_o_zone2(triple_feature)
139 |
140 | feature_o_zone2 = torch.cat([ep0(feature_o_zone), ep0(triple_feature)], 0)
141 | feature_o_zone2 = self.zoneTrans(feature_o_zone2)
142 | zone3 = self.MLP_o_zone3(feature_o_zone2)
143 |
144 | zone = [zone1, zone2, zone3]
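# Three zone predictions are supervised jointly:
#   zone1 - from the appearance feature of the origin image,
#   zone2 - from the fused projected-gaze-point feature,
#   zone3 - from the transformer fusion of both features.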
145 |
146 | # for loss calculation
147 | loss_gaze_o = []
148 | loss_gaze_n = []
149 | if train:
150 | loss_gaze_n.append(self.MLP_n_dir(feature_n_dir))
151 | loss_gaze_o.append(self.MLP_o_dir(feature_o_dir))
152 |
153 | for i, feature in enumerate(feature_list_o):
154 | loss_gaze_o.append(self.MLPList_o[i](feature))
155 |
156 | for i, feature in enumerate(feature_list_n):
157 | loss_gaze_n.append(self.MLPList_n[i](feature))
158 |
159 | else:
160 | zone = zone2.max(1)[1]
161 |
162 | return gaze, zone, loss_gaze_o, loss_gaze_n
163 |
164 | def loss(self, x_in, label):
165 |
166 | gaze, zones, loss_gaze_o, loss_gaze_n = self.forward(x_in)
167 |
168 | loss1 = 2 * self.loss_op_re(gaze, label.normGaze)
169 |
170 | loss2 = 0
171 | for zone in zones:
172 | loss2 += (0.2/3) * self.loss_op_cls(zone, label.zone.view(-1))
173 |
174 | loss3 = 0
175 | for pred in loss_gaze_o:
176 | loss3 += self.loss_op_re(pred, label.originGaze)
177 |
178 | loss4 = 0
179 | for pred in loss_gaze_n:
180 | loss4 += self.loss_op_re(pred, label.normGaze)
181 | loss = loss1 + loss2 + loss3 + loss4
182 |
183 | return loss, [loss1, loss2, loss3, loss4]
184 |
185 |
186 | if __name__ == '__main__':
187 | # dummy inputs with the keys and shapes expected by forward()
188 | x_in = {'origin_face': torch.zeros([5, 3, 224, 224]).cuda(),
189 | 'norm_face': torch.zeros([5, 3, 224, 224]).cuda(),
190 | 'pos': torch.zeros(5, 2, 3).cuda(), 'gaze_origin': torch.ones(5, 3).cuda(),
191 | 'norm_cam': torch.eye(3).repeat(5, 1, 1).cuda()}
192 | model = Model()
193 | model = model.to('cuda')
194 | print(model)
195 | a = model(x_in)
196 | print(a)
197 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/IVModule.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import math
5 | import copy
6 | from resnet import resnet18
7 |
8 | def _get_clones(module, N):
9 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
10 |
11 | def ep0(x):
12 | return x.unsqueeze(0)
13 |
14 | class TripleDifferentialProj(nn.Module):
15 |
16 | def __init__(self):
17 | super(TripleDifferentialProj, self).__init__()
18 |
19 | # one 3D point in screen plane
20 | point = torch.Tensor([0, 0, 0])
21 |
22 | # normal vector of screen plane
23 | normal = torch.Tensor([0, 0, 1])
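# Note: the two locals above are not stored on the module and are not used
# below; the projection plane is selected by coordinate index in gazeto2dPlus.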
24 |
25 | def forward(self, gaze, origin, norm_mat = None):
26 | # input gaze is [yaw, pitch], matching gazeto2d in the reader
27 |
28 | gaze3d = self.gazeto3d(gaze)
29 | if norm_mat is not None:
30 | gaze3d = torch.einsum('acd,ad->ac', norm_mat, gaze3d)
31 |
32 | gazex = self.gazeto2dPlus(gaze3d, origin, 0)
33 | gazey = self.gazeto2dPlus(gaze3d, origin, 1)
34 | gazez = self.gazeto2dPlus(gaze3d, origin, 2)
35 | gaze = [gazex, gazey, gazez]
36 | return gaze
37 |
38 | def gazeto3d(self, point):
39 | # point is [yaw, pitch]
40 | x = -torch.cos(point[:, 1]) * torch.sin(point[:, 0])
41 | y = -torch.sin(point[:, 1])
42 | z = -torch.cos(point[:, 1]) * torch.cos(point[:, 0])
43 | gaze = torch.cat([x.unsqueeze(1), y.unsqueeze(1), z.unsqueeze(1)], 1)
44 | return gaze
45 |
46 | def gazeto2dPlus(self, gaze, origin, plane: int):
47 |
48 | assert plane < 3, 'plane should be 0(x), 1(y) or 2(z)'
49 | length = origin[:, plane]
50 | g_len = gaze[:, plane]
51 | scale = -length / g_len
52 | gaze = torch.einsum('ik, i->ik', gaze, scale)
53 | point = origin + gaze
54 | return point
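# gazeto2dPlus intersects the ray origin + t * gaze with the plane
# {coordinate `plane` = 0}: solving origin[plane] + t * gaze[plane] = 0
# gives t = -origin[plane] / gaze[plane], i.e. `scale` above.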
55 |
56 | class TransformerEncoder(nn.Module):
57 |
58 | def __init__(self, encoder_layer, num_layers, norm=None):
59 | super().__init__()
60 | self.layers = _get_clones(encoder_layer, num_layers)
61 | self.num_layers = num_layers
62 | self.norm = norm
63 |
64 | def forward(self, src, pos):
65 | output = src
66 | for layer in self.layers:
67 | output = layer(output, pos)
68 |
69 | if self.norm is not None:
70 | output = self.norm(output)
71 |
72 | return output
73 |
74 | class PoseTransformerEncoderLayer(nn.Module):
75 |
76 | def __init__(self, d_model, nhead, dim_feedforward=512, dropout=0.1):
77 | super().__init__()
78 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
79 | # Implementation of Feedforward model
80 | self.linear1 = nn.Linear(d_model, dim_feedforward)
81 | self.dropout = nn.Dropout(dropout)
82 | self.linear2 = nn.Linear(dim_feedforward, d_model)
83 |
84 | self.norm1 = nn.LayerNorm(d_model)
85 | self.norm2 = nn.LayerNorm(d_model)
86 |
87 | self.dropout1 = nn.Dropout(dropout)
88 | self.dropout2 = nn.Dropout(dropout)
89 |
90 | self.activation = nn.ReLU(inplace=True)
91 |
92 | def pos_embed(self, src, pos):
93 | return src + pos
94 |
95 |
96 | def forward(self, src, pos):
97 | # src_mask: Optional[Tensor] = None,
98 | # src_key_padding_mask: Optional[Tensor] = None):
99 | # pos: Optional[Tensor] = None):
100 |
101 | q = k = self.pos_embed(src, pos)
102 | src2 = self.self_attn(q, k, value=src)[0]
103 | src = src + self.dropout1(src2)
104 | src = self.norm1(src)
105 |
106 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
107 | src = src + self.dropout2(src2)
108 | src = self.norm2(src)
109 | return src
110 |
111 |
112 |
113 | class TransformerEncoderLayer(nn.Module):
114 |
115 | def __init__(self, d_model, nhead, dim_feedforward=512, dropout=0.1):
116 | super().__init__()
117 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
118 | # Implementation of Feedforward model
119 | self.linear1 = nn.Linear(d_model, dim_feedforward)
120 | self.dropout = nn.Dropout(dropout)
121 | self.linear2 = nn.Linear(dim_feedforward, d_model)
122 |
123 | self.norm1 = nn.LayerNorm(d_model)
124 | self.norm2 = nn.LayerNorm(d_model)
125 |
126 | self.dropout1 = nn.Dropout(dropout)
127 | self.dropout2 = nn.Dropout(dropout)
128 |
129 | self.activation = nn.ReLU(inplace=True)
130 |
131 | def pos_embed(self, src, pos):
132 | batch_pos = pos.unsqueeze(1).repeat(1, src.size(1), 1)
133 | return src + batch_pos
134 |
135 |
136 | def forward(self, src, pos):
137 | # src_mask: Optional[Tensor] = None,
138 | # src_key_padding_mask: Optional[Tensor] = None):
139 | # pos: Optional[Tensor] = None):
140 |
141 | q = k = self.pos_embed(src, pos)
142 | src2 = self.self_attn(q, k, value=src)[0]
143 | src = src + self.dropout1(src2)
144 | src = self.norm1(src)
145 |
146 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
147 | src = src + self.dropout2(src2)
148 | src = self.norm2(src)
149 | return src
150 |
151 |
152 | class conv1x1(nn.Module):
153 |
154 | def __init__(self, in_planes, out_planes, stride=1):
155 | super(conv1x1, self).__init__()
156 |
157 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
158 | padding=0, bias=False)
159 |
160 | self.bn = nn.BatchNorm2d(out_planes)
161 |
162 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
163 |
164 | def forward(self, feature):
165 | output = self.conv(feature)
166 | output = self.bn(output)
167 | output = self.avgpool(output)
168 | output = output.squeeze()
169 |
170 | return output
171 |
172 |
173 |
174 | class MLEnhance(nn.Module):
175 |
176 | def __init__(self, input_nums, hidden_num):
177 |
178 | super(MLEnhance, self).__init__()
179 |
180 | # input_nums: [64, 128, 256, 512]
181 | length = len(input_nums)
182 |
183 | self.input_nums = input_nums
184 | self.hidden_num = hidden_num
185 |
186 | self.length = length
187 |
188 | layerList = []
189 |
190 | for i in range(length):
191 | layerList.append(self.__build_layer(input_nums[i], hidden_num))
192 |
193 | self.layerList = nn.ModuleList(layerList)
194 |
195 |
196 | def __build_layer(self, input_num, hidden_num):
197 | layer = conv1x1(input_num, hidden_num)
198 |
199 | return layer
200 |
201 | def forward(self, feature_list):
202 |
203 | out_feature_list = []
204 |
205 | out_feature_gather = []
206 |
207 | for i, feature in enumerate(feature_list):
208 | result = self.layerList[i](feature)
209 |
210 | # result: [B, C]; gathered below as [1, B, C] for concatenation
211 | out_feature_list.append(result)
212 |
213 | out_feature_gather.append(ep0(result))
214 |
215 |
216 | # [L, B, C]
217 | feature = torch.cat(out_feature_gather, 0)
218 | return feature, out_feature_list
219 |
220 | class PositionalEncoder():
221 | # encode a low-dimensional vector into high-dimensional Fourier features
222 |
223 | def __init__(self, number_freqs, include_identity=False):
224 | freq_bands = torch.pow(2, torch.linspace(0., number_freqs - 1, number_freqs))
225 | self.embed_fns = []
226 | self.output_dim = 0
227 |
228 | if include_identity:
229 | self.embed_fns.append(lambda x:x)
230 | self.output_dim += 1
231 |
232 | for freq in freq_bands:
233 | for transform_fns in [torch.sin, torch.cos]:
234 | self.embed_fns.append(lambda x, fns=transform_fns, freq=freq: fns(x*freq))
235 | self.output_dim += 1
236 |
237 | def encode(self, vecs):
238 | # inputs: [B, N]
239 | # outputs: [B, N * (2*number_freqs + include_identity)]
240 | return torch.cat([fn(vecs) for fn in self.embed_fns], -1)
241 |
242 | def getDims(self):
243 | return self.output_dim
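# Example: PositionalEncoder(30, True).encode(x) with x of shape [B, 3]
# returns [B, 3 * 61] = [B, 183], since each scalar is expanded to
# identity + sin/cos at 30 frequencies (2 * 30 + 1 = 61 values).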
244 |
245 | class PoseTransformer(nn.Module):
246 |
247 | def __init__(self, input_dim, pos_length, nhead=8, hidden_dim=512, layer_num = 6,
248 | pos_freq=30, pos_ident=True, pos_hidden=128, dropout=0.1):
249 |
250 | super(PoseTransformer, self).__init__()
251 | # input feature + added token
252 |
253 | # The input feature should be [L, Batch, Input_dim]
254 | encoder_layer = PoseTransformerEncoderLayer(
255 | input_dim,
256 | nhead = nhead,
257 | dim_feedforward = hidden_dim,
258 | dropout=dropout)
259 |
260 | encoder_norm = nn.LayerNorm(input_dim)
261 | self.encoder = TransformerEncoder(encoder_layer, num_layers = layer_num, norm = encoder_norm)
262 |
263 | self.pos_embedding = PositionalEncoder(pos_freq, pos_ident)
264 |
265 | out_dim = pos_length * (pos_freq * 2 + pos_ident)
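# e.g. with pos_length=3, pos_freq=30, pos_ident=True (the values used in
# GazeDPTR_V2), out_dim = 3 * (2 * 30 + 1) = 183.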
266 |
267 | self.pos_encode = nn.Sequential(
268 | nn.Linear(out_dim, pos_hidden),
269 | nn.LeakyReLU(0.1),
270 | nn.Linear(pos_hidden, pos_hidden),
271 | nn.LeakyReLU(0.1),
272 | nn.Linear(pos_hidden, pos_hidden),
273 | nn.LeakyReLU(0.1),
274 | nn.Linear(pos_hidden, input_dim)
275 | )
276 |
277 | def forward(self, feature, pos_feature):
278 |
279 | """
280 | Inputs:
281 | feature: [length, batch, dim1]
282 | pos_feature: [batch, length, dim2]
283 |
284 | Outputs:
285 | feature: [length, batch, dim1]
286 | (the permute back to batch-first is left commented out below)
287 |
288 | """
289 |
290 | # feature: [Length, Batch, Dim]
291 | pos_feature = self.pos_embedding.encode(pos_feature)
292 | pos_feature = self.pos_encode(pos_feature)
293 | pos_feature = pos_feature.permute(1, 0, 2)
294 |
295 | # feature [Length, batch, dim]
296 | feature = self.encoder(feature, pos_feature)
297 | # feature = feature.permute(1, 0, 2)
298 |
299 | return feature
300 |
301 | class Transformer(nn.Module):
302 |
303 | def __init__(self, input_dim, nhead=8, hidden_dim=512, layer_num = 6, pred_num=1, length=4, dropout=0.1):
304 |
305 | super(Transformer, self).__init__()
306 |
307 | self.pnum = pred_num
308 | # input feature + added token
309 | # self.length = length + 1
310 | self.length = length
311 |
312 | # The input feature should be [L, Batch, Input_dim]
313 | encoder_layer = TransformerEncoderLayer(
314 | input_dim,
315 | nhead = nhead,
316 | dim_feedforward = hidden_dim,
317 | dropout=dropout)
318 |
319 | encoder_norm = nn.LayerNorm(input_dim)
320 |
321 | self.encoder = TransformerEncoder(encoder_layer, num_layers = layer_num, norm = encoder_norm)
322 |
323 | self.cls_token = nn.Parameter(torch.randn(pred_num, 1, input_dim))
324 |
325 | self.token_pos_embedding = nn.Embedding(pred_num, input_dim)
326 | self.pos_embedding = nn.Embedding(length, input_dim)
327 |
328 |
329 | def forward(self, feature, num = 1):
330 |
331 | # feature: [Length, Batch, Dim]
332 | batch_size = feature.size(1)
333 |
334 | # cls_num, 1, Dim -> cls_num, Batch_size, Dim
335 |
336 | feature_list = []
337 |
338 | for i in range(num):
339 |
340 | cls = self.cls_token[i, :, :].repeat((1, batch_size, 1))
341 |
342 | feature_in = torch.cat([cls, feature], 0)
343 |
344 | # position
345 | position = torch.from_numpy(np.arange(self.length)).cuda()
346 | pos_feature = self.pos_embedding(position)
347 |
348 | token_position = torch.Tensor([i]).long().cuda()
349 | token_pos_feature = self.token_pos_embedding(token_position)
350 |
351 | pos_feature = torch.cat([pos_feature, token_pos_feature], 0)
352 |
353 | # feature [Length, batch, dim]
354 | feature_out = self.encoder(feature_in, pos_feature)
355 |
356 | # [batch, dim, length]
357 | # feature = feature.permute(1, 2, 0)
358 |
359 | # get the first dimension, [pnum, batch, dim]
360 | feature_out = feature_out[0:1, :, :]
361 | feature_list.append(feature_out)
362 |
363 | return torch.cat(feature_list, 0).squeeze()
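# The concatenation over prediction tokens gives [num, batch, dim]; .squeeze()
# drops the leading dim when num == 1 (assuming batch > 1), so Backbone receives
# [batch, dim] for a single token and [num, batch, dim] otherwise.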
364 |
365 | class Backbone(nn.Module):
366 |
367 | def __init__(self, outFeatureNum=1, transIn=128, convDims=[64, 128, 256, 512]):
368 |
369 | super(Backbone, self).__init__()
370 |
371 | self.base_model = resnet18(pretrained=True, input_dim = convDims)
372 |
373 | self.transformer = Transformer(input_dim = transIn, nhead = 8, hidden_dim=512,
374 | layer_num=6, pred_num = outFeatureNum, length = len(convDims))
375 |
376 | # convert multi-scale feature
377 | self.mle = MLEnhance(input_nums = convDims, hidden_num = transIn)
378 |
379 | self.feed = nn.Linear(transIn, 2)
380 |
381 | self.loss_op = nn.L1Loss()
382 |
383 | self.loss_layerList = []
384 |
385 | self.outFeatureNum = outFeatureNum
386 |
387 | # for i in range(len(convDims)):
388 | # self.loss_layerList.append(nn.Linear(tranIn, 2))
389 |
390 | def forward(self, image):
391 |
392 | x1, x2, x3, x4 = self.base_model(image)
393 |
394 | feature, feature_list = self.mle([x1, x2, x3, x4])
395 |
396 | #t_gaze = []
397 | # for t_feature in feature_list:
398 | # t_gaze.append(self.loss_layerList(t_feature))
399 |
400 | feature = self.transformer(feature, self.outFeatureNum)
401 |
402 | # gaze = self.feed(feature)
403 |
404 | return feature, feature_list
405 |
406 |
--------------------------------------------------------------------------------