# ├── .gitattributes
# ├── Data.py
# ├── Nets.py
# ├── README.md
# ├── crop_images.py
# ├── generate_frames_and_bbox.py
# ├── statistics.py
# └── train.py
#
# /.gitattributes:
# ------------------------------------------------------------------------------
# # Auto detect text files and perform LF normalization
# * text=auto
#
# ------------------------------------------------------------------------------
# /Data.py:
# ------------------------------------------------------------------------------
"""
Created on 19.10.8 20:13
@File:Data.py
@author: coderwangson
"""
"#codeing=utf-8"
from torch.utils import data
from torchvision import transforms
from PIL import Image
import numpy as np


class Data(data.Dataset):
    """Image dataset indexed by ``<db_dir>/file_list.txt``.

    Each non-empty line of the index is ``<image_path> <label>`` separated by a
    single space (the format written by ``crop_images.py``). A line belongs to
    the training split when the 4th path component from the end starts with
    ``"train"``, and to the test split when it starts with ``"test"``.

    Args:
        db_dir: directory that contains ``file_list.txt``.
        is_train: True to load the training split, False for the test split.
    """

    def __init__(self, db_dir, is_train):
        self.is_train = is_train
        self.file_list, self.label = self.get_file_list(db_dir)
        # Both splits currently use the same pipeline; the branches are kept
        # separate so train-only augmentation can be added later.
        if self.is_train:
            self.transforms = transforms.Compose([transforms.ToTensor()])
        else:
            self.transforms = transforms.Compose([transforms.ToTensor()])

    def __getitem__(self, index):
        """Return ``(image_tensor, label)`` for sample *index*."""
        # Context manager releases the underlying file handle right away,
        # which matters when many DataLoader workers open images concurrently.
        with Image.open(self.file_list[index]) as image:
            img = self.transforms(image.convert('RGB'))

        label = self.label[index]
        return img, label

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.file_list)

    def get_file_list(self, db_dir):
        """Read the index file and keep the entries of the selected split.

        Args:
            db_dir: directory that contains ``file_list.txt``.

        Returns:
            ``(file_list, label_list)``: image paths and their integer labels.
        """
        split_prefix = "train" if self.is_train else "test"
        file_list = []
        label_list = []
        with open(db_dir + "/file_list.txt", "r") as index_file:
            for line in index_file:
                file_info = line.rstrip("\r\n").split(" ")
                if not file_info[0]:
                    continue  # blank line, e.g. a trailing newline at EOF
                file_name = file_info[0]
                label = file_info[1]
                if file_name.split("/")[-4].startswith(split_prefix):
                    file_list.append(file_name)
                    label_list.append(int(label))
        return file_list, label_list


if __name__ == '__main__':
    # Named `dataset` so the `torch.utils.data` module imported above is not shadowed.
    dataset = Data("/home/userwyh/code/dataset/CASIA_scale/scale_1.0/", True)
# ------------------------------------------------------------------------------
# /Nets.py:
# ------------------------------------------------------------------------------
"""
Created on 19.10.8 19:00
@File:Nets.py
@author: coderwangson
"""
"#codeing=utf-8"
import torch
from torch import nn
from torch.nn import functional as F


class Net(nn.Module):
    """AlexNet-style classifier: five convolutions followed by three FC layers.

    ``fc1`` expects 256 * 2 * 2 = 1024 flattened features, which is what the
    convolutional stack produces for a 3 x 128 x 128 input (the size used by
    the ``__main__`` smoke test below).

    Args:
        num_classes: number of logits produced by ``fc3``.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

        # Parameter-free layers, reused by every stage that needs them.
        self.relu = nn.ReLU()
        self.lrn = nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2)

        # Convolutional feature extractor. Layers are created in this order so
        # that parameter initialisation and state_dict keys stay the same.
        self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4)
        self.conv2 = nn.Conv2d(96, 256, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(256, 384, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(384, 384, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(384, 256, kernel_size=3, padding=1)

        # Classifier head.
        self.fc1 = nn.Linear(1024, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, self.num_classes)

    def forward(self, x):
        """Return unnormalised class scores of shape ``(batch, num_classes)``."""
        # Stages 1-2: conv -> ReLU -> LRN -> max-pool.
        for conv in (self.conv1, self.conv2):
            x = self.pool(self.lrn(self.relu(conv(x))))

        # Stages 3-5: conv -> ReLU, with one max-pool after the last conv.
        for conv in (self.conv3, self.conv4, self.conv5):
            x = self.relu(conv(x))
        x = self.pool(x)

        x = x.view(x.size(0), -1)  # flatten to (batch, features)

        # Hidden FC layers; dropout is only active in training mode.
        for fc in (self.fc1, self.fc2):
            x = F.dropout(self.relu(fc(x)), p=0.5, training=self.training)

        return self.fc3(x)


if __name__ == '__main__':
    net = Net(2)
    print(net(torch.ones(3, 3, 128, 128)))
# ------------------------------------------------------------------------------
# /README.md:
# ------------------------------------------------------------------------------
# # Learn Convolutional Neural Network for Face Anti-Spoofing using PyTorch
#
# ## Requirements
#
# * pytorch and torchvision
# * opencv-python (`cv2`)
# * tensorflow
# * [mtcnn][1]
# * numpy, Pillow, scikit-learn and matplotlib
#
# ## Step 1
#
# Run `generate_frames_and_bbox.py`. Every `skip_num`-th frame of each video is
# saved if MTCNN detects a face in it, and a `file_list.txt` is written that
# lists each saved frame with the bounding box of the first detected face and
# its label.
#
# **For example:**
#
# file_name x y w h label
#
# /home/CASIA_frames/test_release/27/1/frame_42.jpg 233 122 170 215 1
#
# ## Step 2
#
# Run `crop_images.py` to generate face crops at several scales, like this:
#
# ![image](https://tva2.sinaimg.cn/large/005Dd0fOly1g7s9se2dv3j30aq037wg2.jpg)
#
# To simplify training, a separate `file_list.txt` is generated for each scale.
#
# ## Step 3
#
# Run `train.py`: the network is trained and evaluated on the test split every
# 5 epochs.
#
#
# [1]: https://github.com/ipazc/mtcnn
# ------------------------------------------------------------------------------
# /crop_images.py:
# ------------------------------------------------------------------------------
#!/usr/bin/env python

import cv2, sys, os
from math import ceil
from glob import glob
import os
interpolation = cv2.INTER_CUBIC
borderMode = cv2.BORDER_REPLICATE


def crop_face(img, bbox, crop_sz, bbox_ext, extra_pad=0):
    """Crop an enlarged region around a face box and resize it to a square.

    The box ``(x, y, w, h)`` is grown by ``bbox_ext * w`` horizontally and
    ``bbox_ext * h`` vertically on each side, plus ``jitt_pad`` pixels. If that
    region extends past the image, the image is first padded by replicating
    its border pixels. The crop is resized to a square, so the aspect ratio
    changes when ``w != h``.

    Args:
        img: image array whose ``shape[0]`` is the height and ``shape[1]``
            the width.
        bbox: iterable of four numbers ``(x, y, w, h)``.
        crop_sz: output side length, before ``extra_pad`` is added.
        bbox_ext: fractional enlargement per side.
        extra_pad: extra output pixels on each side; the source region is
            enlarged proportionally (``jitt_pad``).

    Returns:
        The crop resized to ``crop_sz + 2 * extra_pad`` pixels per side with
        the module-level ``interpolation``.
    """
    shape = img.shape  # [height, width, channels]
    x, y, w, h = bbox

    jitt_pad = int(ceil(float(extra_pad) * min(w, h) / crop_sz))

    # Largest overshoot of the enlarged box past any image edge.
    pad = 0
    if x < w * bbox_ext + jitt_pad:
        pad = max(pad, w * bbox_ext + jitt_pad - x)
    if x + w * (1 + bbox_ext) + jitt_pad > shape[1]:
        pad = max(pad, x + w * (1 + bbox_ext) + jitt_pad - shape[1])
    if y < h * bbox_ext + jitt_pad:
        pad = max(pad, h * bbox_ext + jitt_pad - y)
    if y + h * (1 + bbox_ext) + jitt_pad > shape[0]:
        pad = max(pad, y + h * (1 + bbox_ext) + jitt_pad - shape[0])
    pad = int(pad)

    if pad > 0:
        # Extra 3-pixel margin; reason undocumented (presumably slack for the
        # int() truncations below).
        pad = pad + 3
        replicate = cv2.copyMakeBorder(img, pad, pad, pad, pad, borderMode)
    else:
        replicate = img
    # Coordinates are shifted by `pad` because the padded image grew on every side.
    cropped = replicate[int(pad + y - h * bbox_ext - jitt_pad) : int(pad + y + h * (1 + bbox_ext) + jitt_pad),
                        int(pad + x - w * bbox_ext - jitt_pad) : int(pad + x + w * (1 + bbox_ext) + jitt_pad)]
    resized = cv2.resize(cropped, (crop_sz + 2*extra_pad, crop_sz + 2*extra_pad), interpolation=interpolation)
    return resized


def process_db_casia(db_dir, save_dir, scale, crop_sz):
    """Crop every frame listed in ``<db_dir>/file_list.txt`` at one scale.

    Input lines are ``file_name x y w h label``. Each crop is saved under
    *save_dir*, mirroring the last three directories of the source path, and
    ``<save_dir>/file_list.txt`` is written with ``<crop_path> <label>`` lines.
    Frames that OpenCV cannot read are reported and skipped.

    Args:
        db_dir: directory holding the frame index from step 1.
        save_dir: output directory for this scale (created if missing).
        scale: total box size relative to the detected face box
            (1.0 = exact box).
        crop_sz: output side length in pixels.
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    bbox_ext = (scale - 1.0) / 2  # per-side enlargement; same for every frame
    with open(save_dir + "/file_list.txt", "w") as file_list, \
            open(db_dir + "/file_list.txt", "r") as src_list:
        for file in src_list:
            print("processing(scale %f): %s" % (scale, file))
            file_info = file.rstrip("\r\n").split(" ")
            if not file_info[0]:
                continue  # blank line
            file_name = file_info[0]
            bbox = tuple(map(int, file_info[1:5]))
            label = file_info[5]

            cur_save_dir = os.path.join(save_dir, *file_name.split("/")[-4:-1])
            if not os.path.exists(cur_save_dir):
                os.makedirs(cur_save_dir)
            frame = cv2.imread(file_name)
            if frame is None:
                # imread signals failure by returning None instead of raising.
                print("skipping unreadable image: %s" % file_name)
                continue
            cropped = crop_face(frame, bbox, crop_sz, bbox_ext)

            save_fname = os.path.join(cur_save_dir, file_name.split("/")[-1])
            file_list.write("%s %s\n" % (save_fname, label))
            cv2.imwrite(save_fname, cropped, [cv2.IMWRITE_JPEG_QUALITY, 100])


if __name__ == "__main__":
    db_dir = '/home/userwyh/code/dataset/CASIA_frames'
    save_dir = '/home/userwyh/code/dataset/CASIA_scale'

    crop_sz = 128
    scales = [1.0, 1.4, 1.8, 2.2, 2.6]
    for scale in scales:
        cur_save_dir = save_dir + '/scale_' + str(scale)
        process_db_casia(db_dir, cur_save_dir, scale, crop_sz)
# ------------------------------------------------------------------------------
# /generate_frames_and_bbox.py:
# ------------------------------------------------------------------------------
"""
Created on 19.10.8 16:35
@File:generate_frames_and_bbox.py
@author: coderwangson
"""
"#codeing=utf-8"
# mtcnn is installed with pip (see README)
from mtcnn.mtcnn import MTCNN
import cv2
import os
from glob import glob
detector = MTCNN()
# Video base names (without ".avi") that get label 1; all other videos get label 0.
true_img_start = ('1', '2', 'HR_1')


def generate_frames_and_bbox(db_dir, save_dir, skip_num):
    """Save face-containing frames of every video and index their face boxes.

    Videos are found with ``<db_dir>/*/*/*.avi``. Every ``skip_num``-th frame
    (counting from the first) is passed to MTCNN. Frames with at least one
    detection are saved as ``frame_<k>.jpg`` under a directory that mirrors the
    last three components of the video path, where k counts saved frames. Each
    saved frame is listed in ``<save_dir>/file_list.txt`` as
    ``file_name x y w h label``, using the first detected box.

    Args:
        db_dir: root directory of the videos.
        save_dir: output directory (created if missing).
        skip_num: keep at most one frame out of every ``skip_num``.

    Raises:
        ValueError: if ``skip_num`` is smaller than 1.
    """
    if skip_num < 1:
        raise ValueError("skip_num must be >= 1, got %r" % (skip_num,))
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    with open(save_dir + "/file_list.txt", "w") as file_list:
        for file in glob("%s/*/*/*.avi" % db_dir):
            print("Processing video %s" % file)
            video_parts = file.replace(".avi", "").split("/")
            dir_name = os.path.join(save_dir, *video_parts[-3:])
            if not os.path.exists(dir_name):
                os.makedirs(dir_name)
            # The label depends only on the video name, so compute it once per video.
            label = 1 if video_parts[-1] in true_img_start else 0
            frame_num = 0
            count = 0
            vidcap = cv2.VideoCapture(file)
            try:
                success, frame = vidcap.read()
                while success:
                    # Test the cheap sampling condition first so the face
                    # detector only runs on frames that could be kept.
                    if count % skip_num == 0:
                        detect_res = detector.detect_faces(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                        # Only keep frames that contain a detected face.
                        if len(detect_res) > 0:
                            file_name = os.path.join(dir_name, "frame_%d.jpg" % frame_num)
                            # bbox = (x, y, w, h) of the first detection
                            x, y, w, h = detect_res[0]['box'][:4]
                            # file_name x y w h label
                            file_list.write("%s %d %d %d %d %d\n" % (file_name, x, y, w, h, label))
                            cv2.imwrite(file_name, frame)
                            frame_num += 1
                    count += 1
                    success, frame = vidcap.read()  # fetch the next frame
            finally:
                vidcap.release()


def read():
    """Print every parsed line of the generated frame index (debug helper)."""
    with open("/home/userwyh/code/dataset/CASIA_frames/file_list.txt") as file:
        for line in file:
            print(line.strip("\n").split(" "))


if __name__ == '__main__':
    db_dir = "/home/userwyh/code/dataset/CASIA"
    save_dir = "/home/userwyh/code/dataset/CASIA_frames"
    generate_frames_and_bbox(db_dir, save_dir, 3)

    # read()
# ------------------------------------------------------------------------------
# /statistics.py:
# ------------------------------------------------------------------------------
from sklearn.metrics import roc_curve
from sklearn.metrics import auc
import numpy as np
import matplotlib.pyplot as plt
import sys


def _sweep_thresholds(label, score):
    """Sort samples by score and compute FAR/FRR at every sweep position.

    Samples are ordered by ascending score, with ties broken by ascending
    label (the same order as ``sorted`` on ``[score, label]`` pairs). Position
    ``i`` rejects samples ``0..i`` of that order and accepts the rest. Label
    ``1`` is the positive class.

    Args:
        label: sequence of ground-truth labels (1 = positive).
        score: one score per label; each entry may be a scalar or a size-1
            array, such as a row of a ``[:, 1:]`` slice.

    Returns:
        ``(sorted_scores, far, frr)`` as equal-length 1-D float arrays.

    Raises:
        ValueError: if the lengths differ or either class is missing (the
            rates would divide by zero).
    """
    labels = np.asarray(label).reshape(-1)
    scores = np.asarray(score, dtype=float).reshape(-1)
    if labels.shape[0] != scores.shape[0]:
        raise ValueError("label and score must have the same length")
    order = np.lexsort((labels, scores))  # primary key: score, then label
    sorted_scores = scores[order]
    is_true = labels[order] == 1
    n_true = int(is_true.sum())
    n_false = is_true.size - n_true
    if n_true == 0 or n_false == 0:
        raise ValueError("both positive (label 1) and negative samples are required")
    miss = np.cumsum(is_true)                  # positives rejected so far
    fa = n_false - np.cumsum(~is_true)         # negatives still accepted
    return sorted_scores, fa / n_false, miss / n_true


def HTER(label, score, thred=0.5, EERtype="hter"):
    """Print and return the HTER at the sweep position nearest to *thred*.

    The first sweep position whose score is closest to ``thred`` is chosen,
    and HTER = (FAR + FRR) / 2 is computed there. The HTER, FAR and FRR are
    printed, each prefixed with ``EERtype``.

    Args:
        label: ground-truth labels (1 = positive).
        score: per-sample scores; higher means more likely positive.
        thred: decision threshold on the score.
        EERtype: tag printed in front of each line.

    Returns:
        The HTER as a float in [0, 1].
    """
    sorted_scores, far, frr = _sweep_thresholds(label, score)
    eer_index = int(np.argmin(np.abs(sorted_scores - thred)))
    eer_fpr = float(far[eer_index])
    eer_fnr = float(frr[eer_index])
    hter = (eer_fpr + eer_fnr) / 2
    print(EERtype + " " + 'HTER is :%f %%' % (hter * 100))
    print(EERtype + " " + 'FAR is :%f' % eer_fpr)
    print(EERtype + " " + 'FRR is :%f' % eer_fnr)
    return hter


def EER(label, score, EERtype="eer"):
    """Print and return the ROC AUC, the EER and the score at which it occurs.

    The EER is ``min(FAR, FRR)`` at the first sweep position where
    ``|FAR - FRR|`` is smallest. The AUC is the area under the
    ``(FAR, 1 - FRR)`` curve.

    Args:
        label: ground-truth labels (1 = positive).
        score: per-sample scores; higher means more likely positive.
        EERtype: tag printed in front of each line.

    Returns:
        ``(roc_auc, eer, threshold)``.
    """
    sorted_scores, far, frr = _sweep_thresholds(label, score)
    roc_auc = auc(far, 1 - frr)
    best = int(np.argmin(np.abs(far - frr)))  # first position of the minimum
    eer = float(min(far[best], frr[best]))
    minTh = float(sorted_scores[best])
    print(EERtype + " " + 'EER is :%f %%' % (eer * 100))
    print(EERtype + " " + 'AUC is :%f' % roc_auc)
    print(EERtype + " " + 'thd is :%f' % minTh)

    return roc_auc, eer, minTh
# ------------------------------------------------------------------------------
# /train.py:
# ------------------------------------------------------------------------------
"""
Created on 19.10.8 20:47
@File:train.py
@author: coderwangson
"""
"#codeing=utf-8"
import os
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data
from Data import Data
from Nets import Net
from sklearn.metrics import accuracy_score
# NOTE(review): this resolves to the local statistics.py only because the
# script directory precedes the stdlib on sys.path; consider renaming it.
import statistics

# Must be set before the first CUDA query below.
os.environ['CUDA_VISIBLE_DEVICES']='2'
# Detect devices
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")  # use CPU or GPU
learning_rate = 0.01

# Worker processes and pinned memory only pay off when feeding a GPU.
params = {'batch_size': 32, 'shuffle': True}
if use_cuda:
    params.update({'num_workers': 8, 'pin_memory': True})

# Build the model unconditionally so the script also runs without CUDA.
net = Net(2).to(device)
# With a single visible device (see CUDA_VISIBLE_DEVICES) this branch is skipped.
if use_cuda and torch.cuda.device_count() > 1:
    # If training uses DataParallel, testing must use DataParallel too.
    print("Using", torch.cuda.device_count(), "GPUs!")
    net = nn.DataParallel(net)  # rebind so training actually goes through it

train_data_set = Data("/home/userwyh/code/dataset/CASIA_scale/scale_2.2/", True)
test_data_set = Data("/home/userwyh/code/dataset/CASIA_scale/scale_2.2/", False)
train_data_loader = data.DataLoader(train_data_set, **params)
test_data_loader = data.DataLoader(test_data_set, **params)

optimizer = torch.optim.SGD(list(net.parameters()), lr=learning_rate)


def train(net, dataloader, optimizer, epoch):
    """Train *net* for *epoch* epochs, evaluating on the test split every 5.

    Args:
        net: model to optimise.
        dataloader: yields ``(images, labels)`` batches.
        optimizer: optimizer over ``net``'s parameters.
        epoch: number of epochs.
    """
    for e in range(epoch):
        # val() switches the model to eval mode, so restore training mode
        # (dropout on) at the start of every epoch.
        net.train()
        loss = None
        for X, y in dataloader:
            X, y = X.to(device), y.to(device).view(-1)
            optimizer.zero_grad()
            output = net(X)
            loss = F.cross_entropy(output, y)
            loss.backward()
            optimizer.step()

        if e % 5 == 0:
            if loss is not None:  # None only when the loader is empty
                print("the loss is", loss.item())
            val(net, test_data_loader, e)


def val(net, dataloader, e):
    """Score every test sample and print the EER/AUC and HTER statistics.

    Args:
        net: model to evaluate (left in eval mode).
        dataloader: yields ``(images, labels)`` batches.
        e: epoch number, used only in the printed header.
    """
    net.eval()
    scores = []
    all_y = []
    with torch.no_grad():  # inference only: no autograd graph needed
        for X, y in dataloader:
            output = net(X.to(device))
            # Probability of class 1 as a flat 1-D array (one score per sample).
            scores.extend(F.softmax(output, dim=1)[:, 1].cpu().numpy())
            all_y.extend(y.view(-1).cpu().numpy())

    print("epoch %d " % (e))
    statistics.EER(all_y, scores)
    statistics.HTER(all_y, scores, 0.5)


if __name__ == '__main__':

    train(net, train_data_loader, optimizer, 50)
# ------------------------------------------------------------------------------