├── config.py ├── viz.png ├── README.md ├── train.py ├── loss_fn.py ├── demo.py ├── unet.py └── dataset.py /config.py: -------------------------------------------------------------------------------- 1 | device = 8 2 | -------------------------------------------------------------------------------- /viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aoru45/Learning-Generalized-Spoof-Cues-for-Face-Anti-spoofing/HEAD/viz.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Generalized Spoof Cues for Face Anti-spoofing 2 | 3 | This is an unofficial PyTorch implementation of the paper "Learning Generalized Spoof Cues for Face Anti-spoofing". 4 | 5 | I trained the model on our own dataset. The classification accuracy is good, but the visualization is not as good as the results reported in the paper. 6 | 7 | ## Install 8 | 9 | ```cmd 10 | pip install torch torchvision tqdm albumentations 11 | ``` 12 | 13 | ## Usage 14 | 15 | First, create a data directory containing a `Live_Person` folder for live (positive) images and one folder per spoof type for spoof (negative) images, e.g. `2D_Plane` or `Half_Mask`; images are picked up by the `*_rgb.jpg` and `*_ir.jpg` patterns. 16 | 17 | Second, set the data root (`root`) and the spoof folder names (`sub_dirs`) of `Traindataset` in dataset.py, and the GPU index (`device`) in config.py. 18 | 19 | Then run training: 20 | 21 | ```cmd 22 | python train.py 23 | ``` 24 | 25 | For visualization (the demo loads `./ckpt/149.pth`, the last checkpoint written by train.py), run: 26 | 27 | ```cmd 28 | python demo.py 29 | ``` 30 | 31 | The visualization looks like this: 32 | 33 | ![viz](viz.png) 34 | 35 | The demo also prints the classification scores. 36 | 37 | ## Contributing 38 | 39 | PRs accepted.
40 | 41 | ## License 42 | 43 | MIT © Aoru xue 44 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import torch.nn.init as init 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchvision.utils import save_image 7 | import numpy as np 8 | import glob 9 | from torch.utils.data import DataLoader 10 | import torch.optim as optim 11 | from unet import TripUNet 12 | import random 13 | import os 14 | from PIL import Image 15 | from torchvision.transforms import transforms 16 | from torch.utils.data.sampler import WeightedRandomSampler 17 | import numpy as np 18 | import torch.backends.cudnn as cudnn 19 | from loss_fn import TotalLoss 20 | from dataset import Traindataset 21 | cudnn.benchmark = True 22 | from config import * 23 | def train(model,criterion,optimizer,dataloader,num_epoches = 150,scheduler = None): 24 | model.cuda(device) 25 | step = 0; os.makedirs("./ckpt", exist_ok = True) # checkpoints are saved to ./ckpt/<epoch>.pth and torch.save fails if the directory is missing 26 | for epoch in range(num_epoches): 27 | 28 | for phase in ["train", "save"]: 29 | 30 | if phase == "train": 31 | 32 | running_loss = 0.0 33 | 34 | model.train() 35 | 36 | for anchors, positives, negatives, labels in tqdm(dataloader[phase]): 37 | 38 | anchors = anchors.cuda(device) 39 | 40 | positives = positives.cuda(device) 41 | 42 | negatives = negatives.cuda(device) 43 | 44 | labels = labels.cuda(device) # 0 for positive positive negative and 1 for negative negative positive 45 | 46 | regression, classification, feat = model(anchors, positives, negatives) 47 | 48 | loss = criterion(regression, classification, feat, labels) 49 | 50 | optimizer.zero_grad() 51 | 52 | loss.backward() 53 | 54 | optimizer.step() 55 | 56 | running_loss += loss.item(); step += 1 # the step counter was never advanced, so the log below fired on every batch 57 | if step % 100 == 0: 58 | 59 | print("-step: {} -loss: {} ".format(step, loss.item())) 60 | print("-epoch:{} -phase:{} -loss:{}".format(epoch,phase,running_loss/len(dataloader[phase])))
61 | 62 | if scheduler is not None: 63 | 64 | scheduler.step() 65 | else: 66 | model.eval() 67 | 68 | # correct = 0 69 | 70 | # with torch.no_grad(): 71 | 72 | # for inputs,targets in tqdm(dataloader[phase]): 73 | 74 | # inputs = inputs.cuda(device) 75 | 76 | # targets = targets.cuda(device) 77 | 78 | # out = model(inputs).data.cpu().numpy() # (n,1) 79 | 80 | # pred = np.where(out >0.5, 1, 0) 81 | 82 | # num_correct = (pred == targets.long().data.cpu().numpy()).sum().item() 83 | 84 | # correct += num_correct 85 | 86 | # acc = correct/len(dataloader[phase].dataset) 87 | 88 | # print("-epoch: {} -phase: {} accuracy: {}".format(epoch,phase,acc)) 89 | 90 | # if acc >= best_acc: 91 | 92 | # best_acc = acc 93 | 94 | torch.save(model.state_dict(),"./ckpt/{}.pth".format(epoch)) 95 | 96 | if __name__ == "__main__": 97 | 98 | model = TripUNet() #resnet18(pretrained = False) 99 | 100 | criterion = TotalLoss() 101 | 102 | criterion = criterion.cuda(device) 103 | 104 | train_dataset = Traindataset() 105 | 106 | dataloader = DataLoader(train_dataset, batch_size = 12, shuffle = True , num_workers=2) 107 | 108 | optimizer = optim.Adam(params = model.parameters(),lr = 0.003) 109 | 110 | train(model, criterion, optimizer, {"train": dataloader}, num_epoches=150) 111 | -------------------------------------------------------------------------------- /loss_fn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from config import * 5 | 6 | class TripletLoss(nn.Module): 7 | def __init__(self, margin = 0.2): 8 | super(TripletLoss,self).__init__() 9 | self.margin = margin 10 | def forward(self, f_anchor, f_positive, f_negative): # feature tensors with the batch on dim 0; flattened to (b, -1) below 11 | f_anchor, f_positive, f_negative = renorm(f_anchor), renorm(f_positive), renorm(f_negative) 12 | b = f_anchor.size(0) 13 | f_anchor = f_anchor.view(b,-1) 14 | f_positive = f_positive.view(b,-1) 15 | f_negative = f_negative.view(b, -1) 16
| with torch.no_grad(): 17 | idx = hard_samples_mining(f_anchor, f_positive, f_negative, self.margin) 18 | 19 | d_ap = torch.norm(f_anchor[idx] - f_positive[idx], dim = 1) # (n,) distances of the selected triplets 20 | d_an = torch.norm(f_anchor[idx] - f_negative[idx], dim = 1) 21 | return torch.clamp(d_ap - d_an + self.margin,0).mean() if bool(idx.any()) else d_ap.sum() # mean over an empty selection is NaN; the empty sum is a differentiable 0 22 | 23 | 24 | 25 | def hard_samples_mining(f_anchor,f_positive, f_negative, margin): 26 | d_ap = torch.norm(f_anchor - f_positive, dim = 1) 27 | d_an = torch.norm(f_anchor - f_negative, dim = 1) 28 | idx = (d_ap - d_an) < margin # NOTE(review): keeps triplets with d_ap - d_an < margin (including zero-loss ones) and drops those violating the margin the most; conventional mining selects d_ap - d_an > -margin -- confirm intent 29 | return idx 30 | def renorm(x): # Important for training! 31 | # rescale every sample (slice along dim 0) to an L2 norm of ~1; samples whose norm is already <= 1e-5 end up with norm <= 1 32 | # important ! 33 | return x.renorm(2,0,1e-5).mul(1e5) 34 | class TotalLoss(nn.Module): 35 | def __init__(self,margin = 0.2): 36 | super(TotalLoss, self).__init__() 37 | self.margin = margin 38 | self.trip = TripletLoss(margin) 39 | self.reg = lambda x, y: F.mse_loss(x, y) if x.numel() > 0 else x.sum() # MSE over an empty label-mask selection is NaN; use a differentiable 0 instead 40 | self.cla = lambda x, y: F.cross_entropy(x, y) if x.size(0) > 0 else x.sum() # same guard for cross-entropy when a batch holds only one label 41 | 42 | def forward(self, regression, classification, feat, labels): 43 | regression_anchor, regression_positive, regression_negative = regression 44 | b,c,_,_ = regression_anchor.size() 45 | classification_anchor, classification_positive, classification_negative = classification 46 | 47 | feat_anchor, feat_positive, feat_negative = feat 48 | reg_loss = self.reg(regression_negative[labels == 1], torch.zeros_like(regression_negative[labels == 1])) + self.reg(regression_anchor[labels == 0], torch.zeros_like(regression_anchor[labels == 0])) + self.reg(regression_positive[labels == 0], torch.zeros_like(regression_positive[labels == 0])) # positive images (anchor/positive for label 0, negative for label 1) are pushed towards an all-zero regression map 49 | cla_loss = self.cla(classification_anchor[labels==0], torch.ones(classification_anchor[labels==0].size(0), dtype = torch.long, device = classification_anchor.device)) + \ 50 | self.cla(classification_anchor[labels==1], torch.zeros(classification_anchor[labels==1].size(0), dtype = torch.long, device = classification_anchor.device)) + \ 51 |
self.cla(classification_positive[labels==0], torch.ones(classification_positive[labels==0].size(0), dtype = torch.long, device = classification_positive.device)) + \ 52 | self.cla(classification_positive[labels==1], torch.zeros(classification_positive[labels==1].size(0), dtype = torch.long, device = classification_positive.device)) + \ 53 | self.cla(classification_negative[labels==0], torch.zeros(classification_negative[labels==0].size(0), dtype = torch.long, device = classification_negative.device)) + \ 54 | self.cla(classification_negative[labels==1], torch.ones(classification_negative[labels==1].size(0), dtype = torch.long, device = classification_negative.device)) 55 | trip_loss = sum(self.trip(f_a, f_p, f_n) for f_a, f_p, f_n in zip(feat_anchor, feat_positive, feat_negative)) # one triplet term per feature level 56 | return reg_loss + cla_loss + trip_loss 57 | if __name__ == "__main__": 58 | regression = [torch.randn(1,3,24,24), torch.randn(1,3,24,24), torch.randn(1,3,24,24)] 59 | classification = [torch.randn(1,2), torch.randn(1,2), torch.randn(1,2)] 60 | feat = [[torch.randn(1,16),torch.randn(1,16)],[torch.randn(1,16),torch.randn(1,16)],[torch.randn(1,16),torch.randn(1,16)]] 61 | labels = torch.tensor([0],dtype = torch.long) 62 | loss_fn = TotalLoss() 63 | res = loss_fn(regression, classification, feat, labels); print(res) 64 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import Dataset 4 | import torchvision.transforms as transforms 5 | from PIL import Image 6 | import os 7 | import random 8 | import numpy as np 9 | import cv2 as cv 10 | import math 11 | import sys 12 | import glob 13 | from config import * 14 | from torchvision.utils import save_image 15 | from unet import TripUNet, ModifiedUNet 16 | class MyModel(nn.Module): 17 | def __init__(self): 18 | super(MyModel, self).__init__(); self.net = ModifiedUNet(n_channels = 3, n_classes = 3) # nn.Module must be initialised before a submodule is assigned 19 | self.net.eval() 20 | def forward(self,anchor, positive, negative): 21 |
regression_anchor, classification_anchor, _ = self.net(anchor) # (b,3,h,w) and (b,2) 22 | regression_positive, classification_positive, _ = self.net(positive) 23 | regression_negative, classification_negative, _ = self.net(negative) 24 | return [regression_anchor, regression_positive, regression_negative] , [classification_anchor, classification_positive, classification_negative] 25 | 26 | def glob2(pattern1, pattern2): 27 | files = glob.glob(pattern1) 28 | files.extend(glob.glob(pattern2)) 29 | return files 30 | class Traindataset(Dataset): 31 | def __init__(self,root = "/ssd/xingduan/BCTC_ALL/Data",sub_dirs = ["2D_Plane","2D_Plane_Mask", "3D_Head_Model_Silicone", "3D_Head_Model_Wax", "Half_Mask"]): 32 | self.root = root 33 | self.sub_dirs = sub_dirs 34 | self.pos_filelist = { 35 | "liveness": glob2("{}/{}/*_rgb.jpg".format(root, "Live_Person"), "{}/{}/*_ir.jpg".format(root, "Live_Person")) 36 | } 37 | self.neg_filelist = { 38 | sub_dir: glob2("{}/{}/*_rgb.jpg".format(root, sub_dir), "{}/{}/*_ir.jpg".format(root, sub_dir)) for sub_dir in sub_dirs 39 | } 40 | self.transform = transforms.Compose([ 41 | transforms.Resize((64,64)), 42 | transforms.Grayscale(3), 43 | transforms.ToTensor() 44 | ]) 45 | 46 | 47 | def __getitem__(self,idx): 48 | imgs = [] 49 | for k in range(3): 50 | if k == 0: 51 | t = random.randint(0, len(self.pos_filelist["liveness"]) -1) 52 | l = self.pos_filelist["liveness"][t].split() # pick one positive sample 53 | elif k == 1: 54 | t = random.randint(0, len(self.pos_filelist["liveness"]) -1) 55 | l = self.pos_filelist["liveness"][t].split() 56 | else: 57 | key = random.choice(self.sub_dirs) 58 | t = random.randint(0, len(self.neg_filelist[key]) -1) 59 | l = self.neg_filelist[key][t].split() # pick a random negative sample from a randomly chosen spoof type 60 | img_path = l[0] 61 | 62 | img = Image.open(os.path.join(self.root, img_path)).convert("RGB") 63 | 64 | img_w, img_h = img.size 65 | 66 | ymin,ymax,xmin,xmax = 92, 188, 42, 138 # crop the whole face 67 | 68 | img = img.crop([xmin,ymin,xmax,ymax]) 69 | 70 | img = self.transform(img) 71
| 72 | imgs.append(img) 73 | 74 | return imgs[0], imgs[1], imgs[2] 75 | def __len__(self): 76 | return 20000 77 | 78 | 79 | if __name__ == "__main__": 80 | train_dataset = Traindataset() 81 | model = MyModel() 82 | model.eval() 83 | model.load_state_dict(torch.load("./ckpt/149.pth", map_location = "cpu")) # the demo runs on CPU while train.py saves the checkpoint from a GPU 84 | for i in range(30): 85 | anchor, positive, negative = train_dataset[0] 86 | reg,cla = model(anchor.unsqueeze(0), positive.unsqueeze(0), negative.unsqueeze(0)) 87 | 88 | cla_anchor, cla_positive, cla_negative = cla 89 | save_image(torch.cat(reg, dim= 0),"a.jpg") # reg is a list of three (1,3,h,w) maps; concatenate them into one image grid 90 | 91 | print(cla_anchor.data, cla_positive.data, cla_negative.data) 92 | 93 | img_anchor = (anchor.permute((1,2,0)).data.numpy() * 255).astype(np.uint8) 94 | img_positive = (positive.permute((1,2,0)).data.numpy() * 255).astype(np.uint8) 95 | img_negative = (negative.permute((1,2,0)).data.numpy() * 255).astype(np.uint8) 96 | cv.imshow("img_anchor", cv.cvtColor(img_anchor, cv.COLOR_RGB2BGR)) 97 | cv.imshow("img_positive", cv.cvtColor(img_positive, cv.COLOR_RGB2BGR)) 98 | cv.imshow("img_negative", cv.cvtColor(img_negative, cv.COLOR_RGB2BGR)) 99 | key = cv.waitKey(0) 100 | if key == ord("q"): 101 | break 102 | 103 | cv.destroyAllWindows() 104 | -------------------------------------------------------------------------------- /unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision.models import resnet18 5 | class DoubleConv(nn.Module): 6 | """(convolution => [BN] => ReLU) * 2""" 7 | 8 | def __init__(self, in_channels, out_channels, mid_channels=None): 9 | super().__init__() 10 | if not mid_channels: 11 | mid_channels = out_channels 12 | self.double_conv = nn.Sequential( 13 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 14 | nn.BatchNorm2d(mid_channels), 15 | nn.ReLU(inplace=True), 16 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 17 |
nn.BatchNorm2d(out_channels), 18 | nn.ReLU(inplace=True) 19 | ) 20 | 21 | def forward(self, x): 22 | return self.double_conv(x) 23 | 24 | 25 | class Down(nn.Module): 26 | """Downscaling with maxpool then double conv""" 27 | 28 | def __init__(self, in_channels, out_channels): 29 | super().__init__() 30 | self.maxpool_conv = nn.Sequential( 31 | nn.MaxPool2d(2), 32 | DoubleConv(in_channels, out_channels) 33 | ) 34 | 35 | def forward(self, x): 36 | return self.maxpool_conv(x) 37 | 38 | 39 | class Up(nn.Module): 40 | """Upscaling then double conv""" 41 | 42 | def __init__(self, in_channels, out_channels, bilinear=True): 43 | super().__init__() 44 | 45 | # if bilinear, use the normal convolutions to reduce the number of channels 46 | if bilinear: 47 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 48 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 49 | else: 50 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) 51 | self.conv = DoubleConv(in_channels, out_channels) 52 | 53 | 54 | def forward(self, x1, x2): 55 | x1 = self.up(x1) 56 | # input is NCHW 57 | diffY = x2.size()[2] - x1.size()[2] 58 | diffX = x2.size()[3] - x1.size()[3] 59 | 60 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 61 | diffY // 2, diffY - diffY // 2]) 62 | # if you have padding issues, see 63 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 64 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 65 | x = torch.cat([x2, x1], dim=1) 66 | return self.conv(x) 67 | 68 | 69 | class OutConv(nn.Module): 70 | def __init__(self, in_channels, out_channels): 71 | super(OutConv, self).__init__() 72 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 73 | 74 | def forward(self, x): 75 | return self.conv(x) 76 | 77 | 78 | 79 | class ModifiedUNet(nn.Module): 80 | def __init__(self, n_channels, n_classes,
bilinear=True): 81 | super(ModifiedUNet, self).__init__() 82 | self.n_channels = n_channels 83 | self.n_classes = n_classes 84 | self.bilinear = bilinear 85 | 86 | self.inc = DoubleConv(n_channels, 16) # 48 87 | self.down1 = Down(16, 32) # 24 88 | self.down2 = Down(32, 64) # 12 89 | self.down3 = Down(64, 128) # 6 90 | factor = 2 if bilinear else 1 91 | self.down4 = Down(128, 256 // factor) # 3 92 | self.up1 = Up(256, 128 // factor, bilinear) # 6 93 | self.up2 = Up(128, 64 // factor, bilinear)# 12 94 | self.up3 = Up(64, 32 // factor, bilinear)# 24 95 | self.up4 = Up(32, 16, bilinear) # 48 96 | self.regression = OutConv(16, n_classes)# 48 97 | self.classification = resnet18(pretrained = True) 98 | in_features = self.classification.fc.in_features 99 | self.classification.fc = nn.Linear(in_features, 2) 100 | def forward(self, x): 101 | e1 = self.inc(x) 102 | e2 = self.down1(e1) 103 | e3 = self.down2(e2) 104 | e4 = self.down3(e3) 105 | e5 = self.down4(e4) 106 | d1 = self.up1(e5, e4) 107 | d2 = self.up2(d1, e3) 108 | d3 = self.up3(d2, e2) 109 | d4 = self.up4(d3, e1) 110 | regression = self.regression(d4) 111 | classification = self.classification(x + regression) 112 | feat = [e5, d1, d2, d3, d4] 113 | return regression, classification, feat 114 | class TripUNet(nn.Module): 115 | def __init__(self,): 116 | super(TripUNet, self).__init__() 117 | self.net = ModifiedUNet(n_channels = 3, n_classes = 3) 118 | def forward(self,anchor, positive, negative): 119 | regression_anchor, classification_anchor, feat_anchor = self.net(anchor) 120 | regression_positive, classification_positive, feat_positive = self.net(positive) 121 | regression_negative, classification_negative, feat_negative = self.net(negative) 122 | return [regression_anchor, regression_positive, regression_negative],\ 123 | [classification_anchor, classification_positive, classification_negative],\ 124 | [feat_anchor, feat_positive, feat_negative] 125 | 126 | if __name__ == "__main__": 127 | model = ModifiedUNet(3, 3) # UNet is not defined in this module 128 |
x = torch.randn(1,3,64,64) 129 | regression, classification, _ = model(x); print(regression.size(), classification.size()) # forward returns (regression, classification, feat), not a single tensor 130 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import Dataset 4 | import torchvision.transforms as transforms 5 | from PIL import Image 6 | import os 7 | import random 8 | import numpy as np 9 | from albumentations import * 10 | import cv2 as cv 11 | import math 12 | import sys 13 | import glob 14 | def strong_aug(p=0.5): 15 | return Compose([ 16 | HorizontalFlip(), 17 | OneOf([ 18 | GaussNoise(), 19 | ], p=0.2), 20 | OneOf([ 21 | MotionBlur(p=0.2), 22 | MedianBlur(blur_limit=1, p=0.1), 23 | ], p=0.2), 24 | ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=30, p=0.2), 25 | OneOf([ 26 | #CLAHE(clip_limit=2), 27 | #IAASharpen(), 28 | #IAAEmboss(), 29 | RandomBrightnessContrast(0.1,0.1), 30 | ], p=0.3), 31 | #HueSaturationValue(p=0.3), 32 | ], p=p) 33 | 34 | def glob2(pattern1, pattern2): 35 | files = glob.glob(pattern1) 36 | files.extend(glob.glob(pattern2)) 37 | return files 38 | class Traindataset(Dataset): 39 | def __init__(self,root = "/ssd/xingduan/BCTC_ALL/Data",sub_dirs = ["2D_Plane","2D_Plane_Mask", "3D_Head_Model_Silicone", "3D_Head_Model_Wax", "Half_Mask"]): 40 | self.root = root 41 | self.sub_dirs = sub_dirs 42 | self.pos_filelist = { 43 | "liveness": glob2("{}/{}/*_rgb.jpg".format(root, "Live_Person"), "{}/{}/*_ir.jpg".format(root, "Live_Person")) 44 | } 45 | self.neg_filelist = { 46 | sub_dir: glob2("{}/{}/*_rgb.jpg".format(root, sub_dir), "{}/{}/*_ir.jpg".format(root, sub_dir)) for sub_dir in sub_dirs 47 | } 48 | self.transform = transforms.Compose([ 49 | transforms.Resize((64,64)), 50 | transforms.RandomGrayscale(0.5), 51 | transforms.ToTensor() 52 | ]) 53 | self.aug = strong_aug(0.5) 54 | 55 | 56 | def __getitem__(self,idx): 57 | imgs = [] 58 | labels = None # convention: 0 -> positive, positive, negative; 1 -> negative, negative, positive
59 | if idx % 2 ==0: # case: positive, positive, negative 60 | labels = 0 61 | for k in range(3): 62 | if k == 0: 63 | t = random.randint(0, len(self.pos_filelist["liveness"]) -1) 64 | l = self.pos_filelist["liveness"][t].split() # pick one positive sample 65 | elif k == 1: 66 | t = random.randint(0, len(self.pos_filelist["liveness"]) -1) 67 | l = self.pos_filelist["liveness"][t].split() 68 | else: 69 | key = random.choice(self.sub_dirs) 70 | t = random.randint(0, len(self.neg_filelist[key]) -1) 71 | l = self.neg_filelist[key][t].split() # pick a random negative sample from a randomly chosen spoof type 72 | img_path = l[0] 73 | 74 | img = Image.open(os.path.join(self.root, img_path)).convert("RGB") 75 | 76 | img_w, img_h = img.size 77 | 78 | ymin,ymax,xmin,xmax = 92, 188, 42, 138 # crop the whole face 79 | 80 | img = img.crop([xmin,ymin,xmax,ymax]) 81 | 82 | img = self.aug(image = np.array(img))["image"] #self.transform(img) 83 | 84 | img = self.transform(Image.fromarray(img)) 85 | 86 | imgs.append(img) 87 | else: # case: negative, negative, positive 88 | 89 | labels = 1 90 | 91 | for k in range(3): 92 | 93 | if k == 0: 94 | key = random.choice(self.sub_dirs) 95 | t = random.randint(0, len(self.neg_filelist[key]) -1) 96 | l = self.neg_filelist[key][t].split() 97 | elif k == 1: 98 | key = random.choice(self.sub_dirs) 99 | t = random.randint(0, len(self.neg_filelist[key]) -1) 100 | l = self.neg_filelist[key][t].split() 101 | 102 | else: 103 | t = random.randint(0, len(self.pos_filelist["liveness"]) -1) 104 | l = self.pos_filelist["liveness"][t].split() 105 | img_path = l[0] 106 | 107 | img = Image.open(os.path.join(self.root, img_path)).convert("RGB") 108 | 109 | img_w, img_h = img.size 110 | 111 | ymin,ymax,xmin,xmax = 92, 188, 42, 138 # crop the whole face 112 | 113 | img = img.crop([xmin,ymin,xmax,ymax]) 114 | 115 | img = self.aug(image = np.array(img))["image"] #self.transform(img) 116 | 117 | img = self.transform(Image.fromarray(img)) 118 | 119 | imgs.append(img) 120 | 121 | return imgs[0], imgs[1], imgs[2], torch.tensor(labels, dtype = torch.long) 122 | def __len__(self): 123 | return 20000 124 | 125
| 126 | 127 | if __name__ == "__main__": 128 | train_dataset = Traindataset() 129 | 130 | for i in range(30): 131 | anchor, positive, negative, label = train_dataset[i] # even/odd index alternates label 0 (positive, positive, negative) and label 1 (negative, negative, positive) 132 | 133 | img_anchor = (anchor.permute((1,2,0)).data.numpy() * 255).astype(np.uint8) 134 | img_positive = (positive.permute((1,2,0)).data.numpy() * 255).astype(np.uint8) 135 | img_negative = (negative.permute((1,2,0)).data.numpy() * 255).astype(np.uint8) 136 | cv.imshow("img_anchor", cv.cvtColor(img_anchor, cv.COLOR_RGB2BGR)) 137 | cv.imshow("img_positive", cv.cvtColor(img_positive, cv.COLOR_RGB2BGR)) 138 | cv.imshow("img_negative", cv.cvtColor(img_negative, cv.COLOR_RGB2BGR)) 139 | print(label) 140 | key = cv.waitKey(0) 141 | if key == ord("q"): 142 | break 143 | 144 | cv.destroyAllWindows() 145 | --------------------------------------------------------------------------------