├── .gitignore
├── README.md
├── datapath.yaml
├── models.py
├── utils.py
├── datareader.py
├── resnet.py
└── adapt.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
id_*
__pycache__*
__pycache__/
.code*
.coode/*
*.pt
*.log
*_GA
logs
mean
temp
tmp
test
*.txt
data
figures
core*
*.pdf
tools/
*.png
adapt_logs/**
checkpoints/**

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# UnReGA
Official code for the CVPR 2023 paper "Source-free Adaptive Gaze Estimation by Uncertainty Reduction".

The XGaze-trained checkpoints and the enhanced MPIIRes dataset are available at https://drive.google.com/drive/folders/1f4pGXCgxzbMeZmArlDo2d3ysW6K9sHnx?usp=drive_link.
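
## Quick start (illustrative)

A minimal inference sketch, assuming a downloaded checkpoint. The checkpoint filename below is a placeholder, not a shipped file; `adapt.py` sorts checkpoint directories by the number in filenames shaped like `xxx=NUMBER.ext`.

```python
import torch
from models import GazeRes

net = GazeRes(backbone="res18")
ckpt = torch.load("checkpoints/xgaze/Iter=10.pt", map_location="cpu")  # placeholder filename
state = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
net.load_state_dict(state)
net.eval()

face = torch.randn(1, 3, 224, 224)    # stand-in for a normalized 224x224 face crop
with torch.no_grad():
    gaze, feat = net({"face": face})  # gaze: (1, 2) gaze angles, feat: (1, 256) feature
print(gaze)
```

For source-free adaptation, `adapt.py` exposes `--source`, `--target`, `--loss`, `--lr`, `--iteration`, and `--ckpt_path`; see its argument parser for the full list.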
--------------------------------------------------------------------------------
/datapath.yaml:
--------------------------------------------------------------------------------
eth:
  image: "/data1/GazeData/eth/train/Image"
  label: "/data1/GazeData/eth/train/Label"
mpii:
  image: "/data1/GazeData/MPIIFaceGaze/Image"
  label: "/data1/GazeData/MPIIFaceGaze/Label"
mpiires:
  image: "/data1/GazeData/MPIIRes/Image"
  label: "/data1/GazeData/MPIIRes/Label"

--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from resnet import resnet18, resnet50
import torch.nn.functional as F
from scipy.stats import norm


class GazeRes(nn.Module):
    def __init__(self, backbone="res18", drop_p=0.5):
        super(GazeRes, self).__init__()
        self.img_feature_dim = 256  # the dimension of the CNN feature to represent each frame
        if backbone == "res18":
            self.base_model = resnet18(pretrained=True)
        elif backbone == "res50":
            self.base_model = resnet50(pretrained=True)

        # Replace the backbone's second fc head so it emits the 256-d image feature.
        self.base_model.fc2 = nn.Linear(1000, self.img_feature_dim)

        # The linear layer that maps the image feature to the 2 gaze angles.
        self.last_layer = nn.Linear(self.img_feature_dim, 2)
        self.drop = nn.Dropout(drop_p)

    def forward(self, x_in):
        base_out = self.base_model(x_in["face"])
        base_out = torch.flatten(base_out, start_dim=1)
        output = self.drop(base_out)
        output = self.last_layer(output)
        angular_output = output[:, :2]

        return angular_output, base_out


class UncertaintyLoss(nn.Module):
    def __init__(self):
        super(UncertaintyLoss, self).__init__()

    def forward(self, gaze, gaze_ema):
        assert gaze.shape == gaze_ema.shape
        std = torch.std(gaze, dim=2).reshape(-1, 2, 1)
        return torch.mean(std)


class UncertaintyPseudoLabelLoss(nn.Module):
    def __init__(self, lamda_pseudo=0.5):
        super(UncertaintyPseudoLabelLoss, self).__init__()
        self.lamda_pseudo = lamda_pseudo

    def forward(self, gaze, gaze_ema):
        assert gaze.shape == gaze_ema.shape
        std = torch.std(gaze, dim=2).reshape(-1, 2, 1)
        mean = torch.mean(gaze_ema, dim=2).reshape(-1, 2, 1)
        return torch.mean(std) + self.lamda_pseudo * torch.mean(torch.abs(gaze - mean))


class UncertaintyWPseudoLabelLoss(nn.Module):
    def __init__(self, lamda_pseudo=0.5):
        super(UncertaintyWPseudoLabelLoss, self).__init__()
        self.lamda_pseudo = lamda_pseudo

    def forward(self, gaze, gaze_ema):
        assert gaze.shape == gaze_ema.shape
        std = torch.std(gaze, dim=2).reshape(-1, 2, 1)
        mean = torch.mean(gaze_ema, dim=2).reshape(-1, 2, 1)
        return torch.mean(std) + self.lamda_pseudo * torch.mean(torch.abs(gaze - mean) / std.detach())


class WeightedPseudoLabelLoss(nn.Module):
    def __init__(self, lamda_pseudo=0.5):
        super(WeightedPseudoLabelLoss, self).__init__()
        # self.lamda_pseudo = lamda_pseudo

    def forward(self, gaze, gaze_ema):
        assert gaze.shape == gaze_ema.shape
        std = torch.std(gaze, dim=2).reshape(-1, 2, 1)
        mean = torch.mean(gaze_ema, dim=2).reshape(-1, 2, 1)
        return torch.mean(torch.abs(gaze - mean) / std.detach())


class PseudoLabelLoss(nn.Module):
    def __init__(self, lamda_pseudo=0.5):
        super(PseudoLabelLoss, self).__init__()
        # self.lamda_pseudo = lamda_pseudo

    def forward(self, gaze, gaze_ema):
        assert gaze.shape == gaze_ema.shape
        mean = torch.mean(gaze_ema, dim=2).reshape(-1, 2, 1)
        return torch.mean(torch.abs(gaze - mean))
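

# A minimal shape sanity check (illustrative only, not part of the original file):
# adapt.py stacks per-model predictions along dim 2, so both inputs are (B, 2, M).
if __name__ == "__main__":
    loss_fn = UncertaintyWPseudoLabelLoss(lamda_pseudo=0.5)
    gaze = torch.randn(8, 2, 10)      # e.g. student predictions from 10 models
    gaze_ema = torch.randn(8, 2, 10)  # matching EMA (teacher) predictions
    print(loss_fn(gaze, gaze_ema))    # scalar: ensemble std + weighted pseudo-label term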
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import time
import sys
import os
import copy
import yaml
import torch.nn.functional as F
import random
import collections
import models

def seed_everything(seed):
    """
    Function to set random seeds for reproducibility.

    Args:
        seed (int): Random seed value.

    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

class AverageMeter(object):
    """
    Computes and stores the average and current value.

    """
    def __init__(self):
        self.reset()

    def reset(self):
        """
        Reset the meter's values.

        """
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Update the meter with a new value.

        Args:
            val (float): New value to update the meter.
            n (int): Number of elements represented by the value.

        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def update_ema_params(model, ema_model, alpha, global_step):
    """
    Update the Exponential Moving Average (EMA) of model parameters.

    Args:
        model (nn.Module): Model whose parameters are being updated.
        ema_model (nn.Module): EMA model that stores the averaged parameters.
        alpha (float): EMA decay parameter.
        global_step (int): Current global step of the training.

    """
    # Ramp the decay up from 0 towards alpha over the first steps.
    alpha = min(1 - 1 / (global_step + 1), alpha)

    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        # Update EMA parameters with a weighted sum of current and EMA parameters
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

def mean_models_params(models):
    """
    Compute the mean of model parameters from a list of models.

    Args:
        models (list): List of models to average parameters from.

    Returns:
        OrderedDict: Mean state_dict of model parameters.

    """
    worker_state_dict = [x.state_dict() for x in models]
    weight_keys = list(worker_state_dict[0].keys())
    fed_state_dict = collections.OrderedDict()

    for key in weight_keys:
        key_sum = 0
        for i in range(len(models)):
            key_sum = key_sum + worker_state_dict[i][key]
        fed_state_dict[key] = key_sum / len(models)

    return fed_state_dict

def torch_angular_error(a, b, sum=False):
    """
    Calculate the angular error between two sets of pitch-yaw angles.

    Args:
        a (Tensor): Tensor of pitch-yaw angles.
        b (Tensor): Tensor of pitch-yaw angles to compare against.
        sum (bool, optional): Whether to return the sum or mean of angular errors.

    Returns:
        float: Angular error in degrees, summed or averaged over the batch.

    """
    def pitchyaw_to_vector(pitchyaws):
        sin = torch.sin(pitchyaws)
        cos = torch.cos(pitchyaws)
        return torch.stack([cos[:, 0] * sin[:, 1], sin[:, 0], cos[:, 0] * cos[:, 1]], 1)

    def nn_angular_distance(a, b):
        sim = F.cosine_similarity(a, b, eps=1e-6)
        sim = F.hardtanh(sim, -1.0 + 1e-6, 1.0 - 1e-6)
        return torch.acos(sim) * 180.0 / np.pi

    y = pitchyaw_to_vector(a)
    y_hat = b

    if y_hat.shape[1] == 2:
        y_hat = pitchyaw_to_vector(y_hat)
    if sum:
        return torch.sum(nn_angular_distance(y, y_hat))
    return torch.mean(nn_angular_distance(y, y_hat))

def build_adaptation_loss(loss, lamda_pseudo=0.01):
    if loss == "uncertainty":
        adaptation_loss = models.UncertaintyLoss().cuda()
    elif loss == "wpseudo":
        adaptation_loss = models.WeightedPseudoLabelLoss().cuda()
    elif loss == "pseudo":
        adaptation_loss = models.PseudoLabelLoss().cuda()
    elif loss == "uncertain_pseudo":
        adaptation_loss = models.UncertaintyPseudoLabelLoss(lamda_pseudo).cuda()
    elif loss == "uncertain_wpseudo":
        adaptation_loss = models.UncertaintyWPseudoLabelLoss(lamda_pseudo).cuda()
    else:
        raise ValueError(f"Unknown adaptation loss: {loss}")
    return adaptation_loss
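

# Quick self-check of the helpers above (illustrative only, not in the original file):
if __name__ == "__main__":
    # The EMA decay ramps up: step 0 -> 0.0, step 9 -> 0.9, then saturates at alpha.
    for step in [0, 9, 99]:
        print(step, min(1 - 1 / (step + 1), 0.99))
    # Identical pitch-yaw inputs should give a (near) zero angular error in degrees.
    py = torch.zeros(4, 2)
    print(torch_angular_error(py, py))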
--------------------------------------------------------------------------------
/datareader.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2
import os
from torch.utils.data import Dataset, DataLoader
import torch
import pathlib
import random
import torchvision.transforms as transforms

from PIL import Image

def gazeto2d(gaze):
    yaw = np.arctan2(-gaze[0], -gaze[2])
    pitch = np.arcsin(-gaze[1])
    return np.array([yaw, pitch])


def get_transform(grayscale=False, convert=True, crop=False):
    transform_list = []
    transform_list += [transforms.ToPILImage()]
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if crop:
        transform_list += [transforms.CenterCrop(192)]
    transform_list += [transforms.Resize(224)]
    if convert:
        transform_list += [transforms.ToTensor()]
        if grayscale:
            transform_list += [transforms.Normalize((0.5,), (0.5,))]
        else:
            transform_list += [transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
    return transforms.Compose(transform_list)

class loader(Dataset):
    def __init__(self, path, root, pic_num, header=True, target="mpii"):
        self.lines = []
        self.pic_num = pic_num
        self.target = target
        if isinstance(path, list):
            for i in path:
                with open(i) as f:
                    line = f.readlines()
                    if header: line.pop(0)
                    self.lines.extend(line)
        else:
            with open(path) as f:
                self.lines = f.readlines()
                if header: self.lines.pop(0)
        if self.pic_num >= 0:
            self.lines = self.lines[:self.pic_num]
        self.root = pathlib.Path(root)
        self.transform = get_transform()

    def __len__(self):
        # if self.pic_num < 0:
        return len(self.lines)
        # return self.pic_num

    def __getitem__(self, idx):
        # Each label line is space-separated: face image path first, then the
        # comma-separated 2D gaze and 2D head pose.
        line = self.lines[idx]
        line = line.strip().split(" ")
        # print(line)

        # name = line[0].split('/')[0]
        name = line[0]
        # if self.target == "mpii":
        #     gaze2d = line[7]
        #     head2d = line[8]
        # else:
        gaze2d = line[1]
        head2d = line[2]

        # lefteye = line[1]
        # righteye = line[2]
        face = line[0]
        # if self.target == "mpii":
        #     label = np.array(gaze2d.split(",")[::-1]).astype("float")
        # else:
        label = np.array(gaze2d.split(",")).astype("float")
        label = torch.from_numpy(label[:2]).type(torch.FloatTensor)
        # print(label.shape)
        headpose = np.array(head2d.split(",")).astype("float")
        headpose = torch.from_numpy(headpose[:2]).type(torch.FloatTensor)

        # rimg = cv2.imread(os.path.join(self.root, righteye))/255.0
        # rimg = rimg.transpose(2, 0, 1)

        # limg = cv2.imread(os.path.join(self.root, lefteye))/255.0
        # limg = limg.transpose(2, 0, 1)

        # print(self.root/name/face)
        imgpath = str(self.root / face)
        if self.target[:3] == "mix": imgpath = face
        fimg = cv2.imread(imgpath)

        # Histogram-equalize the luma channel to reduce illumination differences.
        ycrcb = cv2.cvtColor(fimg, cv2.COLOR_BGR2YCrCb)
        ycrcb[:, :, 0] = cv2.equalizeHist(ycrcb[:, :, 0])
        fimg = cv2.cvtColor(ycrcb, cv2.COLOR_YCrCb2RGB)
        # fimg = crop(fimg)
        # print(fimg.shape)
        # fimg = cv2.resize(fimg, (448, 448)) / 255.0

        fimg = self.transform(fimg)
        img = {"face": fimg,
               "head_pose": headpose,
               "name": name}

        # img = {"left": torch.from_numpy(limg).type(torch.FloatTensor),
        #        "right": torch.from_numpy(rimg).type(torch.FloatTensor),
        #        "face": torch.from_numpy(fimg).type(torch.FloatTensor),
        #        "head_pose": headpose,
        #        "name": name}
        return img, label


def txtload(labelpath, imagepath, batch_size, pic_num=-1, shuffle=True, num_workers=0, header=True, target="mpii"):
    # print(labelpath, imagepath)
    dataset = loader(labelpath, imagepath, pic_num, header, target=target)
    print(f"[Read Data]: Total num: {len(dataset)}")
    # print(f"[Read Data]: Label path: {labelpath}")
    # print(dataset.lines[:10])
    load = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    return load


def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    # torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

if __name__ == "__main__":
    # seed_everything(1)
    path = '/home/caixin/GazeData/MPIIFaceGaze/Label/p00.label'
    d = txtload(path, '/home/caixin/GazeData/MPIIFaceGaze/Image', batch_size=32, pic_num=5,
                shuffle=False, num_workers=4, header=True)
    print(len(d))
    for i, (img, label) in enumerate(d):
        print(i, label)

--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.utils.model_zoo as model_zoo


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.avgpool = nn.AdaptativeAvgPool((1,1), stride=1)
        self.fc1 = nn.Linear(512 * block.expansion, 1000)
        self.fc2 = nn.Linear(1000, 3)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        feat_D = self.layer2(x)
        x = self.layer3(feat_D)
        x = self.layer4(x)
        # print('Size at output', x.size())
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        # x = nn.Dropout()(x)
        x = nn.ReLU()(self.fc1(x))
        x = self.fc2(x)

        return x
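

# Note (added for clarity): GazeRes in models.py overwrites fc2 with
# nn.Linear(1000, 256), so the 3-way output above is only produced when this
# backbone is used standalone.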


class ResNetCAM(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNetCAM, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc1 = nn.Linear(512 * block.expansion, 1000)
        self.fc2 = nn.Linear(1000, 3)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x2 = self.layer2(x)
        x2 = self.layer3(x2)
        x2 = self.layer4(x2)
        return x, x2

def resnetCAM(pretrained=False, **kwargs):
    """Constructs a ResNet-18-based CAM model that returns intermediate feature maps.
    Args:
        pretrained (bool): Currently unused; kept for API symmetry.
    """
    model = ResNetCAM(BasicBlock, [2, 2, 2, 2], **kwargs)
    return model

def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        # strict=False because the stock fc head is replaced by fc1/fc2 here.
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
    return model


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False)
    return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False)
    return model


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']), strict=False)
    return model

--------------------------------------------------------------------------------
/adapt.py:
--------------------------------------------------------------------------------
import argparse
import yaml
import os
import torch
import torch.optim as optim
import random
import time
import numpy as np
import sys
import copy
from torch import nn
from utils import *
import models
from datareader import txtload

def configure_model(model):
    """Set the model to eval mode so that batch-norm statistics stay frozen."""
    model.eval()
    return model

# test result of statedict on dataset
def test(statedict, dataset, outfile, epoch=0):
    """Test model on dataset.

    Args:
        statedict (dict): Model state_dict.
        dataset (torch.utils.data.DataLoader): Dataset to test on.
        outfile (file): File to write results to.
        epoch (int): Epoch number.
    """
    global best_result_dict
    net = models.GazeRes(args.backbone)
    net.to(device)
    configure_model(net)
    net.load_state_dict(statedict, strict=False)
    accs = 0
    count = 0
    with torch.no_grad():
        for j, (data, label) in enumerate(dataset):
            img = data["face"].to(device)
            names = data["name"]

            img = {"face": img}
            gts = label.to(device)
            gazes, _ = net(img)
            accs += torch_angular_error(gazes, gts) * gts.shape[0]
            count += gts.shape[0]
    avg_acc = accs / count
    loger = f"[{epoch}] Total Num: {count}, avg: {avg_acc:.3f} \n"
    print(loger)
    outfile.write(loger)
    outfile.flush()
    return avg_acc


def test_ensemble(nets, dataset, outfile, epoch=0):
    """
    Test the average performance of an ensemble of models (nets) on dataset.
    Args:
        nets (list): List of models to test.
        dataset (torch.utils.data.DataLoader): Dataset to test on.
        outfile (file): File to write results to.
        epoch (int): Epoch number.
    """

    for net in nets:
        net.eval()
    accs = 0
    count = 0
    with torch.no_grad():
        for j, (data, label) in enumerate(dataset):
            img = data["face"].to(device)
            names = data["name"]

            img = {"face": img}
            gts = label.to(device)
            avg_gazes = 0
            for net in nets:
                gazes, _ = net(img)
                avg_gazes = avg_gazes + gazes
            avg_gazes = avg_gazes / len(nets)
            # Score the averaged prediction, not the last model's output.
            accs += torch_angular_error(avg_gazes, gts) * gts.shape[0]
            count += gts.shape[0]

    avg_acc = accs / count
    loger = f"[{epoch}] Total Num: {count}, avg: {avg_acc:.3f} \n"
    print(loger)
    outfile.write(loger)

    return avg_acc
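

# How the adaptation below works (summary added for clarity): each source
# checkpoint gets a trainable copy in `nets` and a frozen EMA teacher in
# `nets_ema`. On a batch of unlabeled target faces, every copy predicts a
# pitch-yaw pair; the predictions are stacked to (B, 2, M) and the chosen
# loss from models.py shrinks their spread (uncertainty) and/or pulls them
# toward the EMA ensemble mean (pseudo-labels).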
def train_test(train_data, test_data, iteration, adapt_loss_op, outfile):
    """
    Adapt the (global) ensemble on one batch of target data, then test it.
    Args:
        train_data (dict): One batch of target data to adapt on.
        test_data (torch.utils.data.DataLoader): Dataset to test on.
        iteration (int): Number of iterations to adapt.
        adapt_loss_op (nn.Module): Adaptation loss from build_adaptation_loss.
        outfile (file): File to write losses and results to.
    """

    # Initialize models from the source checkpoints
    for i in range(len(nets)):
        nets[i].load_state_dict(nets_init[i].state_dict())
        nets_ema[i].load_state_dict(nets_init[i].state_dict())
        configure_model(nets[i])
        configure_model(nets_ema[i])
    # Optimizer
    optimizer = optim.Adam(params, lr=args.lr, betas=(0.9, 0.95))
    for i in range(iteration):

        gazes = torch.Tensor().to(device)
        gazes_ema = torch.Tensor().to(device)
        # Randomly sample up to 20 indices from the batch
        indices = random.sample(range(train_data["face"].shape[0]), min(20, train_data["face"].shape[0]))
        img = train_data["face"][indices]
        img = {"face": img}
        for k in range(len(nets)):
            gaze, feature = nets[k](img)
            gazes = torch.cat((gazes, gaze.reshape(-1, 2, 1)), 2)
            gaze_ema, feature = nets_ema[k](img)
            gazes_ema = torch.cat((gazes_ema, gaze_ema.reshape(-1, 2, 1)), 2)

        outlier_loss = adapt_loss_op(gazes, gazes_ema)
        optimizer.zero_grad()
        outlier_loss.backward()
        optimizer.step()
        for k in range(len(nets)):
            update_ema_params(nets[k], nets_ema[k], 0.99, i)
        # print(outlier_loss.item())
        outfile.write("Outlier_loss: %.4f \n" % (outlier_loss.item()))
        outfile.flush()
    statedict = mean_models_params(nets)
    error = test(statedict, test_data, outfile, i)

    return error

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Domain Adaptation')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--backbone', type=str, default='res18', help='backbone')
    parser.add_argument('--batch_size', type=int, default=20, help='batch size')
    parser.add_argument('--num_models', type=int, default=10, help='number of pretrained models (>1)')
    parser.add_argument('--iteration', type=int, default=50, help='iteration for adaptation')
    parser.add_argument('--shuffle', type=bool, default=True, help='shuffle adaptation dataset')
    parser.add_argument('--target', type=str, default='mpii', help='target dataset, mpii/edp/capture')
    parser.add_argument('--source', type=str, default='eth', help='source dataset, eth/gaze360')
    parser.add_argument('--savepath', type=str, default="", help='save path for logs and models')
    parser.add_argument('-l', '--loss', default="uncertain_wpseudo", help="the loss type for adaptation")
    parser.add_argument('-lp', '--lamda_pseudo', type=float, default=0.0001, help="the weight for the pseudo loss")
    parser.add_argument('-n', '--num_experiments', type=int, default=100, help="the number of experiments")
    parser.add_argument('--lr', type=float, default=2e-5, help="the learning rate")
    parser.add_argument('--ckpt_path', default="checkpoints/xgaze", help="the path of source model ckpts")
    # use config file
    # ... parse other arguments ...
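    # Example invocation (illustrative; all flags and default values are defined above):
    #   python adapt.py --source eth --target mpii --ckpt_path checkpoints/xgaze \
    #                   --loss uncertain_wpseudo --lamda_pseudo 0.0001 --lr 2e-5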
    args = parser.parse_args()

    # Load configuration
    config = yaml.load(open("datapath.yaml"), Loader=yaml.FullLoader)

    imagepath_target = config[args.target]["image"]
    labelpath_target = config[args.target]["label"]
    # Set random seed
    seed_everything(args.seed)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Read target data
    if os.path.isdir(labelpath_target):
        folder_target = os.listdir(labelpath_target)
        folder_target.sort()
    else:
        folder_target = [os.path.basename(labelpath_target)]
        labelpath_target = os.path.dirname(labelpath_target)

    labelpath_list = [os.path.join(labelpath_target, j) for j in folder_target]
    dataset_target_for_adaptation = txtload(labelpath_list, imagepath_target, args.batch_size,
                                            shuffle=args.shuffle, num_workers=4, header=True, target=args.target)

    dataset_target = txtload(labelpath_list, imagepath_target, 256,
                             shuffle=False, num_workers=8, header=True, target=args.target)
    # mkdir for savepath
    savepath = os.path.join("adapt_logs", args.savepath, f"batch_size={args.batch_size}_iteration={args.iteration}_lr={args.lr}_loss={args.loss}_shuffle={args.shuffle}")
    if args.loss == "uncertain_pseudo" or args.loss == "uncertain_wpseudo":
        savepath = os.path.join("adapt_logs", args.savepath, f"batch_size={args.batch_size}_iteration={args.iteration}_lamda_pseudo={args.lamda_pseudo}_lr={args.lr}_loss={args.loss}_shuffle={args.shuffle}")
    if not os.path.exists(savepath):
        os.makedirs(savepath, exist_ok=True)

    # Model initialization

    params = []
    loc = "cuda:0"
    device = torch.device(loc if torch.cuda.is_available() else "cpu")
    ckpt_path = args.ckpt_path
    if os.path.isdir(ckpt_path):
        ckpt_list = os.listdir(ckpt_path)
        # Sort checkpoints by the number in filenames shaped like "xxx=NUMBER.ext", newest first
        ckpt_list.sort(key=lambda x: int(x.split("=")[1].split(".")[0]), reverse=True)
        pre_models = [os.path.join(ckpt_path, j) for j in ckpt_list]
    elif os.path.isfile(ckpt_path):
        pre_models = [ckpt_path]
    else:
        raise ValueError("No such ckpt path")

    n = len(pre_models)
    n = min(n, args.num_models)

    nets = [models.GazeRes(args.backbone) for _ in range(n)]
    nets_ema = [models.GazeRes(args.backbone) for _ in range(n)]
    nets_init = [models.GazeRes(args.backbone) for _ in range(n)]


    for i in range(n):
        print(pre_models[i])
        pretrain = torch.load(pre_models[i], map_location=loc)
        statedict = pretrain if "state_dict" not in pretrain else pretrain["state_dict"]
        nets[i].to(device)
        nets[i].load_state_dict(statedict)
        nets[i].eval()
        nets_ema[i].to(device)
        nets_ema[i].load_state_dict(statedict)
        nets_ema[i].eval()
        nets_init[i].to(device)
        nets_init[i].load_state_dict(statedict)
        nets_init[i].eval()
        for value in nets[i].parameters():
            if value.requires_grad:
                params += [{'params': [value]}]
        for param in nets_ema[i].parameters():
            param.detach_()



    # Training loop
    errors = AverageMeter()
    std_list = []
    iteration = args.iteration
    adapt_loss_op = build_adaptation_loss(args.loss, args.lamda_pseudo)
    length_target = len(dataset_target_for_adaptation)
    with open(os.path.join(savepath, "train.log"), "w") as outfile:
        with open(os.path.join(savepath, "loss.log"), "w") as lossfile:
            for j, (data, label) in enumerate(dataset_target_for_adaptation):
                if j == 0:
                    statedict = mean_models_params(nets)
                    test(statedict, dataset_target, outfile, 0)
                    outfile.write(" \n")
                if j > args.num_experiments:
                    break
                label = label.to(device)
                for k, v in data.items():
                    if torch.is_tensor(v):
                        data[k] = v.to(device)

                gaze_error = train_test(data, dataset_target, iteration, adapt_loss_op, lossfile)
                errors.update(gaze_error.item(), label.size(0))
                std_list += [gaze_error.item()]
                timeend = time.time()
                log = f"[{j}/{length_target}] " \
                      f"batch_size: {args.batch_size} " \
                      f"iteration: {args.iteration} " \
                      f"avg_loss:{errors.avg:.4f} " \
                      f"gaze_loss:{errors.val:.4f} "
                print(log)
                outfile.write(log + "\n")
                sys.stdout.flush()
                outfile.flush()
            outfile.write("std = %.4f" % (np.std(std_list)) + "\n")
--------------------------------------------------------------------------------