├── AERI
│   ├── loader_aeri.py
│   └── train_aeri.py
├── LICENSE
├── README.md
├── configs
│   ├── eyediap_config.yaml
│   ├── mpii_config.yaml
│   └── utm_config.yaml
├── figures
│   └── msgazenet.png
├── gaze_estimation
│   ├── eyediap_5fold.py
│   ├── mpii_loso.py
│   ├── reader.py
│   └── utm_3fold.py
├── models
│   ├── aeri_unet.py
│   ├── msgazenet.py
│   └── unet_parts.py
├── requirements.txt
└── utils.py
/AERI/loader_aeri.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Tue Mar 1 17:35:03 2022
4 | 
5 | @author: zunayed mahmud
6 | """
7 | 
8 | """
9 | A portion of this code is borrowed from: https://github.com/swook/GazeML
10 | Please also refer to this source if you directly use this code!
11 | """
12 | 
13 | from torch.utils.data import Dataset
14 | import numpy as np
15 | from torchvision import transforms
16 | import torch
17 | import json
18 | import cv2
19 | import random
20 | from utils import *
21 | 
22 | data_transforms=transforms.Compose([
23 |         transforms.ToTensor(),
24 |         ])
25 | 
26 | 
27 | 
28 | class trainset(Dataset):
29 |     def __init__(self):
30 |         self.image_path="/Data/EyeMask/" # path to the EyeMask dataset
31 | 
32 | 
33 |     def __getitem__(self, idx):
34 |         self.full_image = cv2.imread(self.image_path+str(idx+1)+'.jpg', 0)
35 |         self.mask1 = cv2.imread(self.image_path+str(idx+1)+'_eyeshape.jpg', 0)
36 |         self.mask2 = cv2.imread(self.image_path+str(idx+1)+'_iris.jpg', 0)
37 | 
38 |         ih, iw= self.full_image.shape
39 |         iw_2, ih_2 = 0.5 * iw, 0.5 * ih
40 |         oh = 36
41 |         ow = 60
42 | 
43 |         with open(self.image_path + str(idx+1) + '.json') as f:
44 |             data = json.load(f)
45 | 
46 |         def process_coords(coords_list):
47 |             coords = [eval(l) for l in coords_list]
48 |             return np.array([(x, ih-y, z) for (x, y, z) in coords])
49 |         interior_landmarks = process_coords(data['interior_margin_2d'])
50 |         caruncle_landmarks = process_coords(data['caruncle_2d'])
51 |         iris_landmarks = process_coords(data['iris_2d'])
52 | 
53 |         left_corner = np.mean(caruncle_landmarks[:, :2], axis=0)
54 |         right_corner = interior_landmarks[8, :2]
55 |         eye_width = 1.5 * abs(left_corner[0] - right_corner[0])
56 |         eye_middle = np.mean([np.amin(interior_landmarks[:, :2], axis=0),
57 |                               np.amax(interior_landmarks[:, :2], axis=0)], axis=0)
58 | 
59 |         random_multipliers = []
60 |         difficulty = 1.0
61 | 
62 |         augmentation_ranges = {  # (easy, hard)
63 |             'translation': (2.0, 10.0),
64 |             'rotation': (0.1, 2.0),
65 |             'intensity': (0.5, 20.0),
66 |             'blur': (0.1, 1.0),
67 |             'scale': (0.01, 0.1),
68 |             'rescale': (1.0, 0.2),
69 |             'num_line': (0.0, 2.0),
70 |             'heatmap_sigma': (5.0, 2.5),
71 |         }
72 | 
73 |         def value_from_type(augmentation_type):
74 |             # Scale to be in range
75 |             easy_value, hard_value = augmentation_ranges[augmentation_type]
76 |             value = (hard_value - easy_value) * difficulty + easy_value
77 |             value = (np.clip(value, easy_value, hard_value)
78 |                      if easy_value < hard_value
79 |                      else np.clip(value, hard_value, easy_value))
80 |             return value
81 | 
82 |         def noisy_value_from_type(augmentation_type):
83 |             # Get normal distributed random value
84 |             if len(random_multipliers) == 0:
85 |                 random_multipliers.extend(
86 |                         list(np.random.normal(size=(len(augmentation_ranges),))))
87 |             return random_multipliers.pop() * value_from_type(augmentation_type)
88 | 
89 |         translate_mat = np.asmatrix(np.eye(3))
90 |         translate_mat[:2, 2] = [[-iw_2], [-ih_2]]
91 | 
92 |         rotate_mat = np.asmatrix(np.eye(3))
93 |         rotation_noise = noisy_value_from_type('rotation')
94 | 
95 |         if rotation_noise > 0:
96 |             rotate_angle = 
np.radians(rotation_noise) 97 | cos_rotate = np.cos(rotate_angle) 98 | sin_rotate = np.sin(rotate_angle) 99 | rotate_mat[0, 0] = cos_rotate 100 | rotate_mat[0, 1] = -sin_rotate 101 | rotate_mat[1, 0] = sin_rotate 102 | rotate_mat[1, 1] = cos_rotate 103 | 104 | scale_mat = np.asmatrix(np.eye(3)) 105 | scale = 1. + noisy_value_from_type('scale') 106 | scale_inv = 1. / scale 107 | np.fill_diagonal(scale_mat, ow / eye_width * scale) 108 | original_eyeball_radius = 71.7593 109 | eyeball_radius = original_eyeball_radius * scale_mat[0, 0] # See: https://goo.gl/ZnXgDE 110 | eyeball_radius = np.float32(eyeball_radius) 111 | 112 | recentre_mat = np.asmatrix(np.eye(3)) 113 | recentre_mat[0, 2] = iw/2 - eye_middle[0] + 0.5 * eye_width * scale_inv 114 | recentre_mat[1, 2] = ih/2 - eye_middle[1] + 0.5 * oh / ow * eye_width * scale_inv 115 | recentre_mat[0, 2] += noisy_value_from_type('translation') # x 116 | recentre_mat[1, 2] += noisy_value_from_type('translation') # y 117 | 118 | transform_mat = recentre_mat * scale_mat * rotate_mat * translate_mat 119 | 120 | self.eye = cv2.warpAffine(self.full_image, transform_mat[:2, :3], (ow, oh)) 121 | self.eyeshape = cv2.warpAffine(self.mask1, transform_mat[:2, :3], (ow, oh)) 122 | self.iris = cv2.warpAffine(self.mask2, transform_mat[:2, :3], (ow, oh)) 123 | eye = self.eye 124 | eye = cv2.equalizeHist(eye) 125 | eye = resize_img(eye) 126 | eye = blur_img(eye) 127 | eye = change_contrast(eye) 128 | eye = blur_region(eye) 129 | eye = remove_region(eye) 130 | eye = add_line(eye) 131 | eye = noise(eye) 132 | eye = np.clip(eye,0,255) 133 | eye = eye.astype(np.uint8) 134 | 135 | mask1 = self.eyeshape 136 | mask1 = mask1.astype(np.uint8) 137 | mask1 = data_transforms(mask1[:,:,np.newaxis]) 138 | 139 | mask2 = self.iris 140 | mask2 = mask2.astype(np.uint8) 141 | mask2 = data_transforms(mask2[:,:,np.newaxis]) 142 | 143 | label=torch.cat((mask1,mask2),0) 144 | return data_transforms(eye[:,:,np.newaxis]), label 145 | 146 | def __len__(self): 147 | return 60000 148 | -------------------------------------------------------------------------------- /AERI/train_aeri.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Mar 1 17:35:03 2022 4 | 5 | @author: zunayed mahmud 6 | """ 7 | 8 | import os 9 | from loader_aeri import trainset 10 | import torch 11 | from torch.utils.data import DataLoader 12 | from tqdm import tqdm 13 | from models.aeri_unet import AERI_UNet 14 | import torch.optim as optim 15 | import torch.nn as nn 16 | from torch.utils.tensorboard import SummaryWriter 17 | 18 | train = trainset() 19 | train_loader = DataLoader(train, batch_size=32, shuffle=True, drop_last=True, num_workers=4) 20 | 21 | writer = SummaryWriter('/path/to/summary') 22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 23 | 24 | print("Model building") 25 | net=AERI_UNet().to(device) 26 | 27 | print("optimizer building") 28 | optimizer = optim.Adam(net.parameters(), lr=0.00001) 29 | criterion = nn.MSELoss() 30 | 31 | train_loss_values = [] 32 | 33 | print("Training") 34 | for epoch in range(30): 35 | train_running_loss = 0.0 36 | net.train() 37 | 38 | for i, data in enumerate(tqdm(train_loader, 0)): 39 | # Acquire data 40 | eye, label = data 41 | eye = eye.to(device) 42 | label = label.to(device) 43 | optimizer.zero_grad() 44 | 45 | # Forward 46 | out = net(eye.float()) 47 | 48 | # loss calculation 49 | loss=criterion(out, label.float()) 50 | train_running_loss += 
(loss.item()*eye.size(0))
51 | 
52 |         # backward
53 |         loss.backward()
54 |         optimizer.step()
55 | 
56 | 
57 | 
58 |     train_epoch_loss = train_running_loss/len(train)
59 |     train_loss_values.append(train_epoch_loss)
60 | 
61 | 
62 |     print('[%d] loss: %.3f' % # %.3f prints the loss to 3 decimal places
63 |           (epoch + 1, train_epoch_loss))
64 |     writer.add_scalar('Loss/MSE', train_epoch_loss, epoch)
65 | 
66 |     print('Saving Weights for MSE loss:', train_epoch_loss)
67 |     state = {
68 |         'epoch': epoch,
69 |         'state_dict': net.state_dict(),
70 |         'optimizer': optimizer.state_dict(),
71 |     }
72 |     savepath = 'weights/aeri_weights/'
73 |     if not os.path.exists(savepath):
74 |         os.makedirs(savepath)
75 | 
76 |     model_name=savepath+'AERI_E'+str(epoch+1)+'_L'+str(train_epoch_loss)+'.t7' # ':' is avoided in the filename since it is not a valid filename character on Windows
77 |     torch.save(state,model_name)
78 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 z-mahmud22
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MSGazeNet
2 | This is the official implementation of our work entitled "[Multistream Gaze Estimation with Anatomical Eye Region Isolation by Synthetic to Real Transfer Learning](https://arxiv.org/abs/2206.09256)" in PyTorch (Version 1.9.0).
3 | 
4 | ![MSGazeNet architecture](/figures/msgazenet.png?raw=true "MSGazeNet architecture")
5 | 
6 | The repository contains the source code of our paper, where we use the following datasets:
7 | 
8 | * [MPIIGaze](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/research/gaze-based-human-computer-interaction/appearance-based-gaze-estimation-in-the-wild): The dataset provides eye images and their corresponding head pose and gaze annotations from 15 subjects. It was collected in an unconstrained manner over the course of several months. The standard evaluation protocol of this dataset is leave-one-subject-out.
9 | * [Eyediap](https://www.idiap.ch/en/dataset/eyediap): The dataset was collected in a laboratory environment from 16 subjects. Each subject participated in three different sessions, and the standard evaluation protocol of this dataset is 5-fold cross-validation. In this work, we have used the VGA videos from the continuous/discrete screen target sessions (CS/DS). 
10 | * [UTMultiview](https://www.ut-vision.org/datasets/): The dataset was collected from 50 subjects in a laboratory setup. The collection procedure involved 8 cameras, which generated multiview eye image samples, and the corresponding gaze labels were also recorded.
11 | 
12 | # Prerequisites
13 | Please follow the steps below to train MSGazeNet:
14 | 1. Create a virtual environment with the required libraries
15 | 
16 | To create a virtual environment via Python:
17 | ```
18 | python -m venv <env_name>
19 | ```
20 | 
21 | To create a virtual environment via Anaconda:
22 | ```
23 | conda create -n <env_name>
24 | ```
25 | 2. Install the requirements
26 | ```
27 | pip install -r requirements.txt
28 | ```
29 | 3. Download all the datasets and preprocess them following Zhang et al. [1]
30 | 4. Place all the datasets into the 'Data' directory according to the following structure
31 | ```
32 | Data
33 | ├───eyediap
34 | │   ├───Image
35 | │   └───Label
36 | ├───mpiigaze
37 | │   ├───Image
38 | │   └───Label
39 | └───utmultiview
40 |     ├───Image
41 |     └───Label
42 | ```
43 | 5. Train the Anatomical Eye Region Isolation (AERI) network
44 | ```
45 | python AERI/train_aeri.py
46 | ```
47 | This trains the AERI network, which is later used in the framework for gaze estimation. The trained weights are stored in the 'weights/aeri_weights/' folder, which is created upon execution of this script.
48 | 6. Train the gaze estimation network using the pretrained weights of the AERI network
49 | 
50 | For the LOSO experiment on MPIIGaze:
51 | ```
52 | python gaze_estimation/mpii_loso.py
53 | ```
54 | For the 5-fold experiment on Eyediap:
55 | ```
56 | python gaze_estimation/eyediap_5fold.py
57 | ```
58 | For the 3-fold experiment on UTMultiview:
59 | ```
60 | python gaze_estimation/utm_3fold.py
61 | ```
62 | # Citation
63 | ```
64 | @ARTICLE{10438413,
65 | author={Mahmud, Zunayed and Hungler, Paul and Etemad, Ali},
66 | journal={IEEE Transactions on Artificial Intelligence},
67 | title={Multistream Gaze Estimation with Anatomical Eye Region Isolation by Synthetic to Real Transfer Learning},
68 | year={2024},
69 | volume={},
70 | number={},
71 | pages={1-15},
72 | keywords={Estimation;Synthetic data;Head;Iris;Feature extraction;Training;Lighting;Gaze estimation;eye region segmentation;multistream network;deep neural network;domain randomization;transfer learning},
73 | doi={10.1109/TAI.2024.3366174}}
74 | ```
75 | # Contact
76 | Please email your questions or concerns to zunayed.mahmud@queensu.ca
77 | # References
78 | [1] X. Zhang, Y. Sugano, and A. Bulling, “Revisiting data normalization
79 | for appearance-based gaze estimation,” in Proceedings of the 2018 ACM
80 | Symposium on Eye Tracking Research & Applications, 2018, pp. 1–9.
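# Using the trained AERI network (illustrative)
The gaze estimation dataloaders in `gaze_estimation/reader.py` pass every eye image through the trained AERI network and binarize its two output maps (eye shape and iris) at a threshold of 0.5 before feeding them to MSGazeNet. Below is a minimal sketch of that isolation step for a single image; the checkpoint filename and input image path are placeholders (any `.t7` file produced by `AERI/train_aeri.py` will do), and the 0.5 binarization mirrors `reader.py`:
```
import cv2
import torch
from models.aeri_unet import AERI_UNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
isolator = AERI_UNet().to(device)
# Placeholder checkpoint name; use any .t7 file saved by AERI/train_aeri.py
checkpoint = torch.load('weights/aeri_weights/AERI_E30_L0.001.t7', map_location=device)
isolator.load_state_dict(checkpoint['state_dict'])
isolator.eval()

img = cv2.imread('eye.jpg', 0) / 255.0         # 36x60 grayscale eye patch, placeholder path
x = torch.from_numpy(img).float()[None, None]  # shape (1, 1, 36, 60)
with torch.no_grad():
    out = isolator(x.to(device))               # (1, 2, 36, 60): eye-shape and iris maps
eye_mask = (out[0, 0] > 0.5).float()           # same 0.5 binarization as in reader.py
iris_mask = (out[0, 1] > 0.5).float()
```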
81 | 
--------------------------------------------------------------------------------
/configs/eyediap_config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | train:
3 |   params:
4 |     batch_size: 32
5 |     epoch: 30
6 |     lr: 0.0001
7 |     decay: 0.5
8 |     decay_step: 5000
9 |     loss: MSELoss
10 |   save:
11 |     save_path: "/weights/ge_weights/eyediap"
12 |   data:
13 |     image: "/Data/eyediap/Image"
14 |     label: "/Data/eyediap/Label"
15 | 
16 | 
17 | 
--------------------------------------------------------------------------------
/configs/mpii_config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | train:
3 |   params:
4 |     batch_size: 32
5 |     epoch: 30
6 |     lr: 0.0001
7 |     decay: 0.5
8 |     decay_step: 5000
9 |     loss: MSELoss
10 |   save:
11 |     save_path: "/weights/ge_weights/mpii"
12 |   data:
13 |     image: "/Data/mpiigaze/Image"
14 |     label: "/Data/mpiigaze/Label"
15 | 
--------------------------------------------------------------------------------
/configs/utm_config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | train:
3 |   params:
4 |     batch_size: 32
5 |     epoch: 30
6 |     lr: 0.0001
7 |     decay: 0.5
8 |     decay_step: 5000
9 |     loss: MSELoss
10 |   save:
11 |     save_path: "/weights/ge_weights/utm"
12 |   data:
13 |     image: "/Data/utmultiview/Image"
14 |     label: "/Data/utmultiview/Label"
15 | 
--------------------------------------------------------------------------------
/figures/msgazenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/z-mahmud22/MSGazeNet/de88f100105e400a7485ded267fe14ef95b3d9a4/figures/msgazenet.png
--------------------------------------------------------------------------------
/gaze_estimation/eyediap_5fold.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Tue Mar 1 17:35:03 2022
4 | 
5 | @author: zunayed mahmud
6 | """
7 | 
8 | import models.msgazenet as msgazenet
9 | import gaze_estimation.reader as reader
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 | import torch.optim as optim
14 | import time
15 | import sys
16 | import os
17 | import copy
18 | import yaml
19 | import torch.backends.cudnn as cudnn
20 | import torch.multiprocessing as mp
21 | import warnings
22 | from utils import create_trainset, create_testset, gazeto3d, angular, mean_angular_error
23 | 
24 | if __name__ == "__main__":
25 |     config = yaml.load(open('/configs/eyediap_config.yaml'), Loader=yaml.FullLoader)
26 |     config = config["train"]
27 |     cudnn.benchmark=True
28 | 
29 |     imagepath = config["data"]["image"]
30 |     labelpath = config["data"]["label"]
31 |     modelname = config["save"].get("model_name", "msgazenet") # optional config key; falls back to the results folder name
32 | 
33 |     folder = os.listdir(labelpath)
34 |     folder.sort()
35 | 
36 |     if not os.path.exists(os.path.join(config["save"]["save_path"], "msgazenet")):
37 |         os.makedirs(os.path.join(config["save"]["save_path"], "msgazenet"))
38 |     with open(os.path.join(config["save"]["save_path"], "msgazenet", "model_params"), 'w') as outfile:
39 |         params = f"Model:wideresnet num_block:3 depth:16 widen_factor:4 loss_fn:MSE base_lr:{config['params']['lr']} BS:{config['params']['batch_size']} dr:0.5"
40 |         outfile.write(params)
41 | 
42 |     P_list={}
43 |     P_list['0']=['p1.label','p2.label','p3.label']
44 |     P_list['1']=['p4.label','p5.label','p6.label']
45 |     P_list['2']=['p7.label','p8.label','p9.label']
46 |     P_list['3']=['p10.label','p11.label','p14.label'] # p12 and p13 are not part of the folds
47 |     
P_list['4']=['p15.label','p16.label'] 48 | 49 | for i in range(5): 50 | trains = copy.deepcopy(folder) 51 | trains = create_trainset(trains, P_list[str(i)]) 52 | tests = create_testset(P_list[str(i)]) 53 | print(f"Train Set:{trains}") 54 | print(f"Test Set:{tests}") 55 | 56 | trainlabelpath = [os.path.join(labelpath, j) for j in trains] 57 | testlabelpath = [os.path.join(labelpath, j) for j in tests] 58 | 59 | savepath = os.path.join(config["save"]["save_path"], f"msgazenet/{i}") 60 | if not os.path.exists(savepath): 61 | os.makedirs(savepath) 62 | 63 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 64 | 65 | print("Read Traindata") 66 | train_dataset = reader.txtload(trainlabelpath, imagepath, config['params']['batch_size'], device, mode='train', shuffle=True, num_workers=0, header=True) 67 | print("Read Test Data") 68 | test_dataset = reader.txtload(testlabelpath, imagepath, config['params']['batch_size'], device, mode='test', shuffle=True, num_workers=0, header=True) 69 | 70 | print("Model building") 71 | model_builder = msgazenet.build_msgazenet(1, 16, 4, 0.01, 0.1, 0.5) 72 | net = model_builder.build(2) 73 | net.train() 74 | net.to(device) 75 | 76 | print("Optimizer building") 77 | lossfunc = config["params"]["loss"] 78 | loss_op = getattr(nn, lossfunc)().cuda() 79 | base_lr = config["params"]["lr"] 80 | 81 | decaysteps = config["params"]["decay_step"] 82 | decayratio = config["params"]["decay"] 83 | 84 | optimizer = optim.Adam(net.parameters(),lr=base_lr) 85 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=decayratio, patience=3) 86 | 87 | print("Training") 88 | train_length = len(train_dataset) 89 | total_train = train_length * config["params"]["epoch"] 90 | cur = 0 91 | min_error = 100 92 | timebegin = time.time() 93 | with open(os.path.join(savepath, "train_log"), 'w') as outfile: 94 | for epoch in range(1, config["params"]["epoch"]+1): 95 | train_accs = 0 96 | train_loss = 0 97 | train_count = 0 98 | 99 | for i, (data, label) in enumerate(train_dataset): 100 | 101 | # Acquire data 102 | data["eye"] = data["eye"].to(device) 103 | data["iris"] = data["iris"].to(device) 104 | data["eye_mask"] = data["eye_mask"].to(device) 105 | label = label.to(device) 106 | 107 | # forward 108 | gaze = net(data) 109 | 110 | for k, g_pred in enumerate(gaze): 111 | g_pred = g_pred.cpu().detach().numpy() 112 | train_count += 1 113 | train_accs += angular(gazeto3d(g_pred), gazeto3d(label.cpu().numpy()[k])) 114 | 115 | # loss calculation 116 | loss = loss_op(gaze, label) 117 | train_loss+=loss.item()*data["eye"].size(0) 118 | optimizer.zero_grad() 119 | 120 | # backward 121 | loss.backward() 122 | optimizer.step() 123 | cur += 1 124 | 125 | 126 | # print logs 127 | if i % 50 == 0: 128 | timeend = time.time() 129 | resttime = (timeend - timebegin)/cur * (total_train-cur)/3600 130 | log = f"[{epoch}/{config['params']['epoch']}]: [{i}/{train_length}] Train_loss:{loss} Train_AE:{mean_angular_error(gaze.cpu().detach().numpy(),label.cpu().detach().numpy(), gaze.shape[0])} lr:{optimizer.param_groups[0]['lr']}, remaining time:{resttime:.2f}h" 131 | print(log) 132 | 133 | 134 | train_epoch_loss = train_loss/train_count 135 | train_epoch_acc = train_accs/train_count 136 | 137 | logger = f"[{epoch}]: train_epoch_loss:{train_epoch_loss} train_epoch_AE:{train_epoch_acc} lr:{optimizer.param_groups[0]['lr']}, remaining time:{resttime:.2f}h" 138 | print(logger) 139 | outfile.write(logger + "\n") 140 | sys.stdout.flush() 141 | outfile.flush() 142 | 143 | 
#print("Testing") 144 | net.eval() 145 | 146 | test_length = len(test_dataset) 147 | total_test = test_length * config["params"]["epoch"] 148 | with torch.no_grad(): 149 | test_accs = 0 150 | test_loss = 0 151 | test_count = 0 152 | for i, (data, label) in enumerate(test_dataset): 153 | # Acquire test data 154 | data["eye"] = data["eye"].to(device) 155 | data["iris"] = data["iris"].to(device) 156 | data["eye_mask"] = data["eye_mask"].to(device) 157 | label = label.to(device) 158 | 159 | gaze = net(data) 160 | 161 | for k, g_pred in enumerate(gaze): 162 | g_pred = g_pred.cpu().detach().numpy() 163 | test_count += 1 164 | test_accs += angular(gazeto3d(g_pred), gazeto3d(label.cpu().numpy()[k])) 165 | 166 | loss = loss_op(gaze, label) 167 | test_loss+=loss.item()*data["eye"].size(0) 168 | cur += 1 169 | 170 | # print logs 171 | if i % 50 == 0: 172 | timeend = time.time() 173 | resttime = (timeend - timebegin)/cur * (total_test-cur)/3600 174 | log = f"[{epoch}/{config['params']['epoch']}]: [{i}/{test_length}] Test_loss:{loss} Test_AE:{mean_angular_error(gaze.cpu().detach().numpy(),label.cpu().detach().numpy(), gaze.shape[0])} lr:{optimizer.param_groups[0]['lr']}" 175 | print(log) 176 | 177 | test_epoch_loss = test_loss/test_count 178 | test_epoch_acc = test_accs/test_count 179 | scheduler.step(test_epoch_acc) 180 | 181 | logger = f"[{epoch}]: test_epoch_loss:{test_epoch_loss} test_epoch_AE:{test_epoch_acc} lr:{optimizer.param_groups[0]['lr']}" 182 | print(logger) 183 | outfile.write(logger + "\n") 184 | sys.stdout.flush() 185 | outfile.flush() 186 | 187 | if test_epoch_acc0.5]=1 74 | iris_mask[iris_mask<=0.5]=0 75 | 76 | eyemask=out.detach().cpu().permute(0,2,3,1)[:,:,:,0] 77 | eyemask[eyemask>0.5]=1 78 | eyemask[eyemask<=0.5]=0 79 | 80 | 81 | info = {"eye":img, 82 | "iris": iris_mask, 83 | "eye_mask":eyemask} 84 | 85 | return info, label 86 | 87 | class test_loader(Dataset): 88 | def __init__(self, path, root, device, header=True): 89 | 90 | self.lines = [] 91 | self.device = device 92 | self.isolator = AERI_UNet().to(self.device) 93 | self.checkpoint = torch.load("/weights/aeri_weights/*.t7") # load the weights of the AERI network 94 | self.isolator.load_state_dict(self.checkpoint['state_dict']) 95 | 96 | if isinstance(path, list): 97 | for i in path: 98 | with open(i) as f: 99 | line = f.readlines() 100 | if header: line.pop(0) 101 | self.lines.extend(line) 102 | else: 103 | with open(path) as f: 104 | self.lines = f.readlines() 105 | if header: self.lines.pop(0) 106 | 107 | self.root = root 108 | 109 | def __len__(self): 110 | return len(self.lines) 111 | 112 | def __getitem__(self, idx): 113 | line = self.lines[idx] 114 | line = line.strip().split(" ") 115 | 116 | name = line[1] 117 | gaze2d = line[5] 118 | eye = line[0] 119 | 120 | label = np.array(gaze2d.split(",")).astype("float") 121 | label = torch.from_numpy(label).type(torch.FloatTensor) 122 | 123 | img = cv2.imread(os.path.join(self.root, eye).replace('\\','/'),0)/255.0 124 | img = img[:,:,np.newaxis] 125 | img = img.transpose(2, 0, 1) 126 | img = torch.from_numpy(img).type(torch.FloatTensor) 127 | 128 | out = self.isolator(torch.unsqueeze(img,0).to(self.device)) 129 | 130 | iris_mask=out.detach().cpu().permute(0,2,3,1)[:,:,:,1] 131 | iris_mask[iris_mask>0.5]=1 132 | iris_mask[iris_mask<=0.5]=0 133 | 134 | eyemask=out.detach().cpu().permute(0,2,3,1)[:,:,:,0] 135 | eyemask[eyemask>0.5]=1 136 | eyemask[eyemask<=0.5]=0 137 | 138 | info = {"eye":img, 139 | "iris": iris_mask, 140 | "eye_mask":eyemask} 141 | 142 | return info, label 143 
|
144 | def txtload(labelpath, imagepath, batch_size, device, mode, shuffle=True, num_workers=0, header=True):
145 |     if mode == 'train':
146 |         dataset = train_loader(labelpath, imagepath, device, header)
147 |         print(f"[Read Data]: Total num: {len(dataset)}")
148 |         print(f"[Read Data]: Label path: {labelpath}")
149 |         load = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
150 |     else:
151 |         dataset = test_loader(labelpath, imagepath, device, header)
152 |         print(f"[Read Data]: Total num: {len(dataset)}")
153 |         print(f"[Read Data]: Label path: {labelpath}")
154 |         load = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
155 | 
156 |     return load
157 | 
--------------------------------------------------------------------------------
/gaze_estimation/utm_3fold.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Tue Mar 1 17:35:03 2022
4 | 
5 | @author: zunayed mahmud
6 | """
7 | 
8 | import models.msgazenet as msgazenet
9 | import gaze_estimation.reader as reader
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 | import torch.optim as optim
14 | import time
15 | import sys
16 | import os
17 | import copy
18 | import yaml
19 | import torch.backends.cudnn as cudnn
20 | import torch.multiprocessing as mp
21 | import warnings
22 | from utils import create_trainset, create_testset, gazeto3d, angular, mean_angular_error
23 | 
24 | if __name__ == "__main__":
25 |     config = yaml.load(open('/configs/utm_config.yaml'), Loader=yaml.FullLoader)
26 |     config = config["train"]
27 |     cudnn.benchmark=True
28 | 
29 |     imagepath = config["data"]["image"]
30 |     train_labelpath = config["data"]["label"]
31 |     test_labelpath = config["data"]["label"]
32 |     modelname = config["save"].get("model_name", "msgazenet") # optional config key; falls back to the results folder name
33 | 
34 |     train_folder = os.listdir(train_labelpath)
35 |     train_folder.sort()
36 |     test_folder = os.listdir(test_labelpath)
37 |     test_folder.sort()
38 | 
39 |     if not os.path.exists(os.path.join(config["save"]["save_path"], "msgazenet")):
40 |         os.makedirs(os.path.join(config["save"]["save_path"], "msgazenet"))
41 |     with open(os.path.join(config["save"]["save_path"], "msgazenet", "model_params"), 'w') as outfile:
42 |         params = f"Model:wideresnet num_block:3 depth:16 widen_factor:4 loss_fn:MSE base_lr:{config['params']['lr']} BS:{config['params']['batch_size']} dr:0.5"
43 |         outfile.write(params)
44 | 
45 |     P_list={}
46 |     P_list['0']=['s'+str(i).zfill(2)+'.label' for i in range(17)]
47 |     P_list['1']=['s'+str(i).zfill(2)+'.label' for i in range(17,34)]
48 |     P_list['2']=['s'+str(i).zfill(2)+'.label' for i in range(34,50)]
49 | 
50 |     for i in range(3):
51 |         trains = copy.deepcopy(train_folder)
52 |         trains = create_trainset(trains, P_list[str(i)])
53 |         tests = create_testset(P_list[str(i)])
54 |         print(f"Train Set:{trains}")
55 |         print(f"Test Set:{tests}")
56 | 
57 |         trainlabelpath = [os.path.join(train_labelpath, j) for j in trains]
58 |         testlabelpath = [os.path.join(test_labelpath, j) for j in tests]
59 | 
60 |         savepath = os.path.join(config["save"]["save_path"], f"msgazenet/{i}")
61 |         if not os.path.exists(savepath):
62 |             os.makedirs(savepath)
63 | 
64 |         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
65 | 
66 |         print("Read Train Data")
67 |         train_dataset = reader.txtload(trainlabelpath, imagepath, config['params']['batch_size'], device, mode='train', shuffle=True, num_workers=0, header=True)
68 |         print("Read Test Data")
69 |         test_dataset = reader.txtload(testlabelpath, imagepath, config['params']['batch_size'], device, mode='test', 
shuffle=True, num_workers=0, header=True) 70 | 71 | print("Model building") 72 | model_builder = msgazenet.build_msgazenet(1, 16, 4, 0.01, 0.1, 0.5) 73 | net = model_builder.build(2) 74 | net.train() 75 | net.to(device) 76 | 77 | print("Optimizer building") 78 | lossfunc = config["params"]["loss"] 79 | loss_op = getattr(nn, lossfunc)().cuda() 80 | base_lr = config["params"]["lr"] 81 | 82 | decaysteps = config["params"]["decay_step"] 83 | decayratio = config["params"]["decay"] 84 | 85 | optimizer = optim.Adam(net.parameters(),lr=base_lr) 86 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=decayratio, patience=3) 87 | 88 | print("Training") 89 | train_length = len(train_dataset) 90 | total_train = train_length * config["params"]["epoch"] 91 | cur = 0 92 | min_error = 100 93 | timebegin = time.time() 94 | with open(os.path.join(savepath, "train_log"), 'w') as outfile: 95 | for epoch in range(1, config["params"]["epoch"]+1): 96 | train_accs = 0 97 | train_loss = 0 98 | train_count = 0 99 | 100 | for i, (data, label) in enumerate(train_dataset): 101 | 102 | # Acquire data 103 | data["eye"] = data["eye"].to(device) 104 | data["iris"] = data["iris"].to(device) 105 | data["eye_mask"] = data["eye_mask"].to(device) 106 | label = label.to(device) 107 | 108 | # forward 109 | gaze = net(data) 110 | 111 | for k, g_pred in enumerate(gaze): 112 | g_pred = g_pred.cpu().detach().numpy() 113 | train_count += 1 114 | train_accs += angular(gazeto3d(g_pred), gazeto3d(label.cpu().numpy()[k])) 115 | 116 | # loss calculation 117 | loss = loss_op(gaze, label) 118 | train_loss+=loss.item()*data["eye"].size(0) 119 | optimizer.zero_grad() 120 | 121 | # backward 122 | loss.backward() 123 | optimizer.step() 124 | cur += 1 125 | 126 | 127 | # print logs 128 | if i % 50 == 0: 129 | timeend = time.time() 130 | resttime = (timeend - timebegin)/cur * (total_train-cur)/3600 131 | log = f"[{epoch}/{config['params']['epoch']}]: [{i}/{train_length}] Train_loss:{loss} Train_AE:{mean_angular_error(gaze.cpu().detach().numpy(),label.cpu().detach().numpy(), gaze.shape[0])} lr:{optimizer.param_groups[0]['lr']}, remaining time:{resttime:.2f}h" 132 | print(log) 133 | 134 | train_epoch_loss = train_loss/train_count 135 | train_epoch_acc = train_accs/train_count 136 | 137 | logger = f"[{epoch}]: train_epoch_loss:{train_epoch_loss} train_epoch_AE:{train_epoch_acc} lr:{optimizer.param_groups[0]['lr']}, remaining time:{resttime:.2f}h" 138 | print(logger) 139 | outfile.write(logger + "\n") 140 | sys.stdout.flush() 141 | outfile.flush() 142 | 143 | #print("Testing") 144 | net.eval() 145 | 146 | test_length = len(test_dataset) 147 | total_test = test_length * config["params"]["epoch"] 148 | with torch.no_grad(): 149 | test_accs = 0 150 | test_loss = 0 151 | test_count = 0 152 | for i, (data, label) in enumerate(test_dataset): 153 | # Acquire test data 154 | data["eye"] = data["eye"].to(device) 155 | data["iris"] = data["iris"].to(device) 156 | data["eye_mask"] = data["eye_mask"].to(device) 157 | label = label.to(device) 158 | 159 | gaze = net(data) 160 | 161 | for k, g_pred in enumerate(gaze): 162 | g_pred = g_pred.cpu().detach().numpy() 163 | test_count += 1 164 | test_accs += angular(gazeto3d(g_pred), gazeto3d(label.cpu().numpy()[k])) 165 | 166 | loss = loss_op(gaze, label) 167 | test_loss+=loss.item()*data["eye"].size(0) 168 | cur += 1 169 | 170 | # print logs 171 | if i % 50 == 0: 172 | timeend = time.time() 173 | resttime = (timeend - timebegin)/cur * (total_test-cur)/3600 174 | log = 
f"[{epoch}/{config['params']['epoch']}]: [{i}/{test_length}] Test_loss:{loss} Test_AE:{mean_angular_error(gaze.cpu().detach().numpy(),label.cpu().detach().numpy(), gaze.shape[0])} lr:{optimizer.param_groups[0]['lr']}" 175 | print(log) 176 | 177 | test_epoch_loss = test_loss/test_count 178 | test_epoch_acc = test_accs/test_count 179 | scheduler.step(test_epoch_acc) 180 | 181 | logger = f"[{epoch}]: test_epoch_loss:{test_epoch_loss} test_epoch_AE:{test_epoch_acc} lr:{optimizer.param_groups[0]['lr']}" 182 | print(logger) 183 | outfile.write(logger + "\n") 184 | sys.stdout.flush() 185 | outfile.flush() 186 | 187 | if test_epoch_acc 0: 43 | out = F.dropout(out, p=self.drop_rate, training=self.training) 44 | out = self.conv2(out) 45 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 46 | 47 | 48 | class NetworkBlock(nn.Module): 49 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, drop_rate=0.0, activate_before_residual=False): 50 | super(NetworkBlock, self).__init__() 51 | self.layer = self._make_layer( 52 | block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual) 53 | 54 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual): 55 | layers = [] 56 | for i in range(int(nb_layers)): 57 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, 58 | i == 0 and stride or 1, drop_rate, activate_before_residual)) 59 | return nn.Sequential(*layers) 60 | 61 | def forward(self, x): 62 | return self.layer(x) 63 | 64 | class MSGazeNet(nn.Module): 65 | def __init__(self, first_stride, num_classes, depth=28, widen_factor=2, drop_rate=0.0, is_remix=False): 66 | super(MSGazeNet, self).__init__() 67 | channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 68 | assert ((depth - 4) % 6 == 0) 69 | n = (depth - 4) / 6 70 | block = BasicBlock 71 | # 1st conv before any network block 72 | self.conv1_eye = nn.Conv2d(1, channels[0], kernel_size=3, stride=1, 73 | padding=1, bias=True) 74 | self.conv1_iris = nn.Conv2d(1, channels[0], kernel_size=3, stride=1, 75 | padding=1, bias=True) 76 | self.conv1_eyemask = nn.Conv2d(1, channels[0], kernel_size=3, stride=1, 77 | padding=1, bias=True) 78 | # 1st block 79 | self.block1_eye = NetworkBlock( 80 | n, channels[0], channels[1], block, first_stride, drop_rate, activate_before_residual=True) 81 | self.block1_iris = NetworkBlock( 82 | n, channels[0], channels[1], block, first_stride, drop_rate, activate_before_residual=True) 83 | self.block1_eyemask = NetworkBlock( 84 | n, channels[0], channels[1], block, first_stride, drop_rate, activate_before_residual=True) 85 | # 2nd block 86 | self.block2_eye = NetworkBlock( 87 | n, channels[1], channels[2], block, 2, drop_rate) 88 | self.block2_iris = NetworkBlock( 89 | n, channels[1], channels[2], block, 2, drop_rate) 90 | self.block2_eyemask = NetworkBlock( 91 | n, channels[1], channels[2], block, 2, drop_rate) 92 | # 3rd block 93 | self.block3 = NetworkBlock( 94 | n, channels[2]*3, channels[3], block, 2, drop_rate) 95 | # global average pooling and classifier 96 | self.bn1 = nn.BatchNorm2d(channels[3], momentum=0.001, eps=0.001) 97 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=False) 98 | self.fc = nn.Sequential( 99 | nn.Linear(channels[3], channels[3]*2), 100 | nn.ReLU(inplace=True), 101 | nn.Dropout(0.25), 102 | nn.Linear(channels[3]*2, channels[3]), 103 | nn.ReLU(inplace=True), 104 | nn.Dropout(0.25), 105 | nn.Linear(channels[3], num_classes) 106 | ) 107 | self.channels = 
channels[3] 108 | 109 | # rot_classifier for Remix Match 110 | self.is_remix = is_remix 111 | if is_remix: 112 | self.rot_classifier = nn.Linear(self.channels, 4) 113 | 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') 117 | elif isinstance(m, nn.BatchNorm2d): 118 | m.weight.data.fill_(1) 119 | m.bias.data.zero_() 120 | elif isinstance(m, nn.Linear): 121 | nn.init.xavier_normal_(m.weight.data) 122 | m.bias.data.zero_() 123 | 124 | def forward(self, x, ood_test=False): 125 | out_x = self.conv1_eye(x['eye']) 126 | out_y = self.conv1_iris(x['iris']) 127 | out_z = self.conv1_eyemask(x['eye_mask']) 128 | out_x = self.block1_eye(out_x) 129 | out_y = self.block1_iris(out_y) 130 | out_z = self.block1_eyemask(out_z) 131 | out_x = self.block2_eye(out_x) 132 | out_y = self.block2_iris(out_y) 133 | out_z = self.block2_eyemask(out_z) 134 | out = torch.cat((out_x, out_y, out_z),1) 135 | out = self.block3(out) 136 | out = self.relu(self.bn1(out)) 137 | out = F.adaptive_avg_pool2d(out, 1) 138 | out = out.view(-1, self.channels) 139 | output = self.fc(out) 140 | 141 | return output 142 | 143 | 144 | class build_msgazenet: 145 | def __init__(self, first_stride=1, depth=28, widen_factor=2, bn_momentum=0.01, leaky_slope=0.0, dropRate=0.0, 146 | use_embed=False, is_remix=False): 147 | self.first_stride = first_stride 148 | self.depth = depth 149 | self.widen_factor = widen_factor 150 | self.bn_momentum = bn_momentum 151 | self.dropRate = dropRate 152 | self.leaky_slope = leaky_slope 153 | self.use_embed = use_embed 154 | self.is_remix = is_remix 155 | 156 | def build(self, num_classes): 157 | return MSGazeNet( 158 | first_stride=self.first_stride, 159 | depth=self.depth, 160 | num_classes=num_classes, 161 | widen_factor=self.widen_factor, 162 | drop_rate=self.dropRate, 163 | is_remix=self.is_remix, 164 | ) 165 | 166 | 167 | -------------------------------------------------------------------------------- /models/unet_parts.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | This code is borrowed from: https://github.com/milesial/Pytorch-UNet 4 | Please refer to this source if you directly use this code! 
5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class DoubleConv(nn.Module): 13 | """(convolution => [BN] => ReLU) * 2""" 14 | 15 | def __init__(self, in_channels, out_channels, mid_channels=None): 16 | super().__init__() 17 | if not mid_channels: 18 | mid_channels = out_channels 19 | self.double_conv = nn.Sequential( 20 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 21 | nn.BatchNorm2d(mid_channels), 22 | nn.ReLU(inplace=True), 23 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 24 | nn.BatchNorm2d(out_channels), 25 | nn.ReLU(inplace=True) 26 | ) 27 | 28 | def forward(self, x): 29 | return self.double_conv(x) 30 | 31 | 32 | class Down(nn.Module): 33 | """Downscaling with maxpool then double conv""" 34 | 35 | def __init__(self, in_channels, out_channels): 36 | super().__init__() 37 | self.maxpool_conv = nn.Sequential( 38 | nn.MaxPool2d(2), 39 | DoubleConv(in_channels, out_channels) 40 | ) 41 | 42 | def forward(self, x): 43 | return self.maxpool_conv(x) 44 | 45 | 46 | class Up(nn.Module): 47 | """Upscaling then double conv""" 48 | 49 | def __init__(self, in_channels, out_channels, bilinear=True): 50 | super().__init__() 51 | 52 | # if bilinear, use the normal convolutions to reduce the number of channels 53 | if bilinear: 54 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 55 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 56 | else: 57 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) 58 | self.conv = DoubleConv(in_channels, out_channels) 59 | 60 | 61 | def forward(self, x1, x2): 62 | x1 = self.up(x1) 63 | # input is CHW 64 | diffY = x2.size()[2] - x1.size()[2] 65 | diffX = x2.size()[3] - x1.size()[3] 66 | 67 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 68 | diffY // 2, diffY - diffY // 2]) 69 | # if you have padding issues, see 70 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 71 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 72 | x = torch.cat([x2, x1], dim=1) 73 | return self.conv(x) 74 | 75 | 76 | class OutConv(nn.Module): 77 | def __init__(self, in_channels, out_channels): 78 | super(OutConv, self).__init__() 79 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 80 | 81 | def forward(self, x): 82 | return self.conv(x) 83 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.20.2 2 | opencv-python==4.5.2.54 3 | pandas==1.0.5 4 | PyYAML==5.4.1 5 | scikit-learn==0.21.3 6 | scipy==1.4.1 7 | torch==1.9.0 8 | torchfile==0.1.0 9 | torchsummary==1.5.1 10 | torchvision==0.10.0 11 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Mar 1 17:35:03 2022 4 | 5 | @author: zunayed mahmud 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | import random 11 | import cv2 12 | 13 | def create_trainset(folder, input_list): 14 | for j in input_list: 15 | folder.remove(j) 16 | return folder 17 | 18 | def create_testset(input_list): 19 | new_folder = [] 20 | for j in input_list: 21 | new_folder.append(j) 22 | return new_folder 23 | 24 | # 
=============================================================================
25 | # GAZE FUNCTIONS
26 | # =============================================================================
27 | 
28 | def gazeto3d(gaze):
29 |   gaze_gt = np.zeros([3])
30 |   gaze_gt[0] = -np.cos(gaze[1]) * np.sin(gaze[0])
31 |   gaze_gt[1] = -np.sin(gaze[1])
32 |   gaze_gt[2] = -np.cos(gaze[1]) * np.cos(gaze[0])
33 | 
34 |   return gaze_gt
35 | 
36 | def angular(gaze, label):
37 |   total = np.sum(gaze * label)
38 | 
39 |   return np.arccos(np.clip(total/(np.linalg.norm(gaze)* np.linalg.norm(label)), -0.9999999, 0.9999999))*180/np.pi # clip into the arccos domain to avoid NaN
40 | 
41 | def mean_angular_error(a, b, batch_size):
42 |     error=0
43 |     for k, g_pred in enumerate(a):
44 |         error+=angular(gazeto3d(g_pred), gazeto3d(b[k]))
45 |     return error/batch_size
46 | 
47 | 
48 | # =============================================================================
49 | # AUGMENTATIONS
50 | # =============================================================================
51 | 
52 | def add_line(img):
53 |     num_of_line=random.randint(0,4)
54 | 
55 |     for line_no in range(num_of_line):
56 |         x1=random.randint(0,img.shape[1]) # cv2.line expects (x, y) points, so x spans the width
57 |         x2=random.randint(0,img.shape[1])
58 |         y1=random.randint(0,img.shape[0])
59 |         y2=random.randint(0,img.shape[0])
60 | 
61 |         color=random.randint(0,255)
62 |         thickness=random.randint(1,3)
63 |         # thickness=random.random()
64 |         img=cv2.line(img, (x1,y1), (x2,y2), color, thickness)
65 | 
66 |     return img
67 | 
68 | 
69 | def change_contrast(img):
70 |     op_no=random.randint(0,3)
71 | 
72 |     if op_no==0:
73 |         return img
74 |     if op_no==1:
75 |         black_pixel=random.randint(0,100)
76 |         img[img<black_pixel]=0
77 |         return img
78 |     if op_no==2:
79 |         white_pixel=random.randint(155,255)
80 |         img[img>white_pixel]=255
81 |         return img
82 |     if op_no==3:
83 |         black_pixel=random.randint(0,100)
84 |         black_pixel_new=random.randint(0,black_pixel)
85 |         white_pixel=random.randint(155,255)
86 |         white_pixel_new=random.randint(white_pixel,255)
87 | 
88 |         xp = [0, black_pixel, 128, white_pixel, 255]
89 |         fp = [0, black_pixel_new, 128, white_pixel_new, 255]
90 |         x = np.arange(256)
91 |         table = np.interp(x, xp, fp).astype('uint8')
92 |         img = cv2.LUT(img, table)
93 |         return img
94 | 
95 | 
96 | def resize_img(img):
97 |     randimage=random.randint(0,1)
98 |     if randimage==0:
99 | 
100 |         return img
101 |     else:
102 |         compress_range=random.uniform(1,2)
103 |         img= cv2.resize(img, (int(img.shape[1]/compress_range), int(img.shape[0]/compress_range)), interpolation = cv2.INTER_AREA)
104 |         img= cv2.resize(img, (60,36), interpolation = cv2.INTER_AREA)
105 | 
106 |     return img
107 | 
108 | 
109 | def remove_region(img):
110 |     num_of_line=random.randint(0,4)
111 | 
112 |     for line_no in range(num_of_line):
113 |         w=random.randint(0,10)
114 |         h=random.randint(0,10)
115 | 
116 |         x1=random.randint(0,img.shape[0]-w)
117 |         y1=random.randint(0,img.shape[1]-h)
118 | 
119 | 
120 |         color=random.randint(img.min(),img.max())
121 | 
122 |         img[x1:x1+w,y1:y1+h]= color
123 | 
124 |     return img
125 | 
126 | 
127 | def blur_region(img):
128 |     num_of_line=random.randint(0,4)
129 | 
130 |     for line_no in range(num_of_line):
131 |         w=random.randint(0,10)
132 |         h=random.randint(0,10)
133 | 
134 |         x1=random.randint(0,img.shape[0]-w)
135 |         y1=random.randint(0,img.shape[1]-h)
136 | 
137 | 
138 |         ksize=random.randrange(1,10,2)
139 |         BlurImage=cv2.GaussianBlur(img,(ksize,ksize),0)
140 |         img[x1:x1+w,y1:y1+h]= BlurImage[x1:x1+w,y1:y1+h]
141 | 
142 |     return img
143 | 
144 | 
145 | def blur_img(img):
146 |     ksize=random.randrange(1,5,2)
147 |     sigma=random.uniform(0,2)
148 |     BlurImage=cv2.GaussianBlur(img,(ksize,ksize),sigma)
149 |     return BlurImage
150 | 
151 | 
152 | def 
adjust_contrast(img1,contrast_factor):
153 |     mean=np.mean(img1)
154 |     bound=255
155 |     ratio=1-random.uniform(0,contrast_factor)
156 |     img=np.clip((ratio * img1 + (1.0 - ratio) * mean),0,bound)
157 |     return img.astype(img1.dtype)
158 | 
159 | 
160 | def noise(img):
161 |     mean = 0 # mean of the additive Gaussian noise
162 |     std = random.uniform(0,20.0) # randomly sampled standard deviation
163 |     noisy_img = img + np.random.normal(mean, std, img.shape) # noise field matches the 36x60 eye patch
164 |     return noisy_img
165 | 
--------------------------------------------------------------------------------
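As a quick sanity check of the gaze utilities above — `gazeto3d` converts a 2D (yaw, pitch) gaze label in radians into a 3D unit vector, and `angular` measures the angle between two such vectors in degrees — the following illustrative snippet can be run from the repository root (the sample gaze values are made up):
```
import numpy as np
from utils import gazeto3d, angular

# Illustrative (yaw, pitch) gaze angles in radians, matching the 2D label format
pred = np.array([0.10, -0.05])
gt = np.array([0.12, -0.02])

err = angular(gazeto3d(pred), gazeto3d(gt))
print(f"angular error: {err:.2f} deg")  # roughly 2.1 degrees for these sample values
```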