├── .gitignore
├── README.md
├── config.py
├── core
│   ├── __pycache__
│   │   ├── model.cpython-35.pyc
│   │   └── utils.cpython-35.pyc
│   ├── model.py
│   └── utils.py
├── dataloader
│   ├── CASIA_Face_loader.py
│   └── LFW_loader.py
├── lfw_eval.py
├── model
│   ├── CASIA_ShuffleFaceNet_20200204_195134
│   │   ├── 010.ckpt
│   │   ├── 020.ckpt
│   │   ├── 030.ckpt
│   │   ├── 040.ckpt
│   │   ├── 050.ckpt
│   │   ├── 060.ckpt
│   │   ├── 070.ckpt
│   │   └── log.log
│   └── best
│       └── 060.ckpt
├── result
│   ├── best_result.mat
│   └── tmp_result.mat
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__
**/__pycache__

.swp
.giosavewh5R2D

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ShuffleFaceNet Pytorch

A PyTorch implementation of ShuffleFaceNet at the 1.5x complexity setting, trained with the CosFace loss. The code trains on CASIA-WebFace and evaluates on LFW.

[ShuffleFaceNet: A Lightweight Face Architecture for Efficient and Highly-Accurate Face Recognition](http://openaccess.thecvf.com/content_ICCVW_2019/papers/LSR/Martindez-Diaz_ShuffleFaceNet_A_Lightweight_Face_Architecture_for_Efficient_and_Highly-Accurate_Face_ICCVW_2019_paper.pdf)
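# Usage

A minimal invocation sketch, assuming the dataset paths in `config.py` point at local copies of CASIA-WebFace and LFW (the defaults are machine-specific and need editing), and that a CUDA GPU is available, since both scripts call `.cuda()`:

```bash
# Train on CASIA-WebFace; checkpoints and log.log are written under ./model/
python train.py

# Evaluate a checkpoint on LFW (the paths below are the argparse defaults)
python lfw_eval.py --resume ./model/best/060.ckpt --feature_save_dir ./result/best_result.mat
```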
# References

[ShuffleNet](https://github.com/kuangliu/pytorch-cifar/blob/master/models/shufflenet.py)

[CosFace](https://github.com/YirongMao/softmax_variants/blob/master/model_utils.py)

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
BATCH_SIZE = 256
SAVE_FREQ = 10
TEST_FREQ = 5
TOTAL_EPOCH = 70

RESUME = ''
SAVE_DIR = './model'
MODEL_PRE = 'CASIA_ShuffleFaceNet_'


CASIA_DATA_DIR = '/home/users/matheusb/recfaces/datasets/CASIA-WebFace/'
LFW_DATA_DIR = '/home/users/matheusb/recfaces/datasets/LFW/'

# an int selects a single GPU; a tuple selects several (see define_gpu in train.py)
GPU = 0, 1, 2

--------------------------------------------------------------------------------
/core/__pycache__/model.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbfaria/ShuffleFaceNet_Pytorch/39818fbff657dbf4db308c44cd5bd0c70b56fd4f/core/__pycache__/model.cpython-35.pyc

--------------------------------------------------------------------------------
/core/__pycache__/utils.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbfaria/ShuffleFaceNet_Pytorch/39818fbff657dbf4db308c44cd5bd0c70b56fd4f/core/__pycache__/utils.cpython-35.pyc

--------------------------------------------------------------------------------
/core/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import math
import numpy as np

import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import Parameter

# http://openaccess.thecvf.com/content_ICCVW_2019/papers/LSR/Martindez-Diaz_ShuffleFaceNet_A_Lightweight_Face_Architecture_for_Efficient_and_Highly-Accurate_Face_ICCVW_2019_paper.pdf


def channel_shuffle(x, groups):
    # type: (torch.Tensor, int) -> torch.Tensor
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups

    # reshape
    x = x.view(batchsize, groups,
               channels_per_group, height, width)

    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride):
        super(InvertedResidual, self).__init__()

        if not (1 <= stride <= 3):
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = oup // 2
        assert (self.stride != 1) or (inp == branch_features << 1)

        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.PReLU(),
            )
        else:
            self.branch1 = nn.Sequential()

        self.branch2 = nn.Sequential(
            nn.Conv2d(inp if (self.stride > 1) else branch_features,
                      branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.PReLU(),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.PReLU(),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)

        return out


class ShuffleFaceNet(nn.Module):
    def __init__(self, stages_repeats=[4, 8, 4], stages_out_channels=[24, 176, 352, 704, 1024], inverted_residual=InvertedResidual):
        super(ShuffleFaceNet, self).__init__()

        if len(stages_repeats) != 3:
            raise ValueError('expected stages_repeats as list of 3 positive ints')
        if len(stages_out_channels) != 5:
            raise ValueError('expected stages_out_channels as list of 5 positive ints')
        self._stage_out_channels = stages_out_channels

        input_channels = 3
        output_channels = self._stage_out_channels[0]

        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.PReLU(),
        )
        input_channels = output_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(
                stage_names, stages_repeats, self._stage_out_channels[1:]):
            seq = [inverted_residual(input_channels, output_channels, 2)]
            for i in range(repeats - 1):
                seq.append(inverted_residual(output_channels, output_channels, 1))
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels

        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.PReLU(),
        )
        input_channels = output_channels

        # global depthwise convolution: one 7x7 filter per channel collapses
        # the 7x7 feature map to 1x1
        self.gdc = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=7, stride=1, padding=0, bias=False, groups=input_channels),
            nn.BatchNorm2d(output_channels),
            nn.PReLU(),
        )

        input_channels = output_channels
        output_channels = 128

        # 1x1 convolution acting as the final linear embedding layer
        self.linearconv = nn.Conv1d(input_channels, output_channels, kernel_size=1, stride=1, padding=0)

        self.bn = nn.BatchNorm2d(output_channels)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = nn.functional.interpolate(x, size=[112, 112])
        x = self.conv1(x)
        # x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        # x = x.mean([2, 3])  # globalpool
        x = self.gdc(x)
        x = x.view(x.size(0), 1024, 1)
        x = self.linearconv(x)
        x = x.view(x.size(0), 128, 1, 1)
        x = self.bn(x)
        x = x.view(x.size(0), -1)

        return x

    def forward(self, x):
        return self._forward_impl(x)


class ArcMarginProduct(nn.Module):
    def __init__(self, in_features=128, out_features=200, s=32.0, m=0.50, easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        # init.kaiming_uniform_()
        # self.weight.data.normal_(std=0.001)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # make the function cos(theta + m) monotonically decreasing for theta in [0, 180] degrees
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, x, label):
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)

        one_hot = torch.zeros(cosine.size(), device='cuda')
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output


class CosFace_loss(nn.Module):
    """
    Refer to the paper:
    Hao Wang, Yitong Wang, Zheng Zhou, Xing Ji, Dihong Gong, Jingchao Zhou, Zhifeng Li, and Wei Liu.
    CosFace: Large Margin Cosine Loss for Deep Face Recognition. CVPR 2018.
    Re-implemented by Yirong Mao, 2018-07-02.
    """

    def __init__(self, num_classes, feat_dim, s=32.0, m=0.5):
        super(CosFace_loss, self).__init__()
        self.feat_dim = feat_dim
        self.num_classes = num_classes
        self.s = s
        self.m = m
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, feat, label):
        batch_size = feat.shape[0]
        norms = torch.norm(feat, p=2, dim=-1, keepdim=True)
        nfeat = torch.div(feat, norms)

        norms_c = torch.norm(self.centers, p=2, dim=-1, keepdim=True)
        ncenters = torch.div(self.centers, norms_c).cuda()

        # cosine similarities between the normalized features and class centers
        logits = torch.matmul(nfeat, torch.transpose(ncenters, 0, 1))

        # subtract the margin m from the target-class cosine only
        y_onehot = torch.FloatTensor(batch_size, self.num_classes)
        y_onehot.zero_()
        y_onehot = Variable(y_onehot).cuda()
        y_onehot.scatter_(1, torch.unsqueeze(label, dim=-1), self.m)
        margin_logits = self.s * (logits - y_onehot)

        return logits, margin_logits


if __name__ == "__main__":
    # input = Variable(torch.FloatTensor(2, 3, 112, 96))
    net = ShuffleFaceNet()
    print(net)
    # x = net(input)
    # print(x.shape)
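    # Added smoke-test sketch (not part of the original repo): run a dummy
    # batch through the embedding net and the CosFace head. Batch size 4,
    # 10 classes, and the 112x112 input below are illustrative assumptions.
    # CosFace_loss moves tensors to CUDA internally, so the check only runs
    # when a GPU is present.
    if torch.cuda.is_available():
        net = net.cuda()
        head = CosFace_loss(num_classes=10, feat_dim=128).cuda()
        imgs = torch.randn(4, 3, 112, 112).cuda()
        labels = torch.randint(0, 10, (4,)).cuda()
        feats = net(imgs)  # -> (4, 128) embeddings
        logits, margin_logits = head(feats, labels)
        loss = nn.CrossEntropyLoss()(margin_logits, labels)
        print(feats.shape, logits.shape, loss.item())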
--------------------------------------------------------------------------------
/core/utils.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import os
import logging


def init_log(output_dir):
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(message)s',
                        datefmt='%Y%m%d-%H:%M:%S',
                        filename=os.path.join(output_dir, 'log.log'),
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)
    return logging


if __name__ == '__main__':
    pass

--------------------------------------------------------------------------------
/dataloader/CASIA_Face_loader.py:
--------------------------------------------------------------------------------
import numpy as np
import imageio
import os
from sklearn import preprocessing
import torch
# ImageFile.LOAD_TRUNCATED_IMAGES = True

import sys
sys.path.append("..")


class CASIA_Face(object):
    def __init__(self, root):
        self.image_list = []
        self.label_list = []

        # one sub-directory per identity; the directory name is the label
        for r, _, files in os.walk(root):
            for f in files:
                self.image_list.append(os.path.join(r, f))
                self.label_list.append(os.path.basename(r))

        le = preprocessing.LabelEncoder()
        self.label_list = le.fit_transform(self.label_list)
        self.class_nums = len(np.unique(self.label_list))

    def __getitem__(self, index):
        img_path = self.image_list[index]
        target = self.label_list[index]
        img = imageio.imread(img_path)
        # img = np.resize(img, (112, 112))

        if len(img.shape) == 2:
            img = np.stack([img] * 3, 2)

        # random horizontal flip: flip is either +1 (keep) or -1 (mirror)
        flip = np.random.choice(2) * 2 - 1
        img = img[:, ::flip, :]
        img = (img - 127.5) / 128.0
        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img).float()

        return img, target

    def __len__(self):
        return len(self.image_list)


if __name__ == '__main__':
    data_dir = '/home/users/matheusb/recfaces/datasets/CASIA-WebFace/'
    dataset = CASIA_Face(root=data_dir)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=8, drop_last=False)
    print(len(dataset))
    for data in trainloader:
        print(data[0].shape)

--------------------------------------------------------------------------------
/dataloader/LFW_loader.py:
--------------------------------------------------------------------------------
import numpy as np
import imageio
import torch

import sys
sys.path.append("..")
# from retrieval.dataloaders.preprocessing import preprocess


class LFW(object):
    def __init__(self, imgl, imgr):
        self.imgl_list = imgl
        self.imgr_list = imgr

    def __getitem__(self, index):
        imgl = imageio.imread(self.imgl_list[index])
        if len(imgl.shape) == 2:
            imgl = np.stack([imgl] * 3, 2)
        imgr = imageio.imread(self.imgr_list[index])
        if len(imgr.shape) == 2:
            imgr = np.stack([imgr] * 3, 2)

        # imgl = imgl[:, :, ::-1]
        # imgr = imgr[:, :, ::-1]
        # each pair yields four tensors: left, flipped left, right, flipped right
        imglist = [imgl, imgl[:, ::-1, :], imgr, imgr[:, ::-1, :]]
        for i in range(len(imglist)):
            imglist[i] = (imglist[i] - 127.5) / 128.0
            imglist[i] = imglist[i].transpose(2, 0, 1)
        imgs = [torch.from_numpy(i).float() for i in imglist]
        return imgs

    def __len__(self):
        return len(self.imgl_list)


if __name__ == '__main__':
    data_dir = '/home/users/keiller/recfaces/datasets/LFW/'
    from lfw_eval import parseList
    nl, nr, folds, flags = parseList(root=data_dir)
    dataset = LFW(nl, nr)
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=8, drop_last=False)
    print(len(dataset))
    for data in trainloader:
        print(data[0].shape)

--------------------------------------------------------------------------------
/lfw_eval.py:
--------------------------------------------------------------------------------
import sys
# import caffe
import os
import numpy as np
import cv2
import scipy.io
import copy
import torch.utils.data
from core import model
from dataloader.LFW_loader import LFW
from config import LFW_DATA_DIR
import argparse


def parseList(root):
    with open(os.path.join(root, 'pairs.txt')) as f:
        pairs = f.read().splitlines()[1:]

    folder_name = 'lfw'
    nameLs = []
    nameRs = []
    folds = []
    flags = []
    for i, p in enumerate(pairs):
        p = p.split('\t')
        if len(p) == 3:
            # matched pair: one identity, two image numbers
            nameL = os.path.join(root, folder_name, p[0], p[0] + '_' + '{:04}.jpg'.format(int(p[1])))
            nameR = os.path.join(root, folder_name, p[0], p[0] + '_' + '{:04}.jpg'.format(int(p[2])))
            fold = i // 600
            flag = 1
        elif len(p) == 4:
            # mismatched pair: two identities, one image number each
            nameL = os.path.join(root, folder_name, p[0], p[0] + '_' + '{:04}.jpg'.format(int(p[1])))
            nameR = os.path.join(root, folder_name, p[2], p[2] + '_' + '{:04}.jpg'.format(int(p[3])))
            fold = i // 600
            flag = -1
        nameLs.append(nameL)
        nameRs.append(nameR)
        folds.append(fold)
        flags.append(flag)
    # print(nameLs)
    return [nameLs, nameRs, folds, flags]


def getAccuracy(scores, flags, threshold):
    p = np.sum(scores[flags == 1] > threshold)
    n = np.sum(scores[flags == -1] < threshold)
    accuracy = (p + n) * 1.0 / len(scores)
    return accuracy


def getThreshold(scores, flags, thrNum):
    accuracys = np.zeros((2 * thrNum + 1, 1))
    thresholds = np.arange(-thrNum, thrNum + 1) * 1.0 / thrNum
    for i in range(2 * thrNum + 1):
        accuracys[i] = getAccuracy(scores, flags, thresholds[i])

    max_index = np.squeeze(accuracys == np.max(accuracys))
    bestThreshold = np.mean(thresholds[max_index])
    return bestThreshold


def evaluation_10_fold(root='./result/pytorch_result.mat'):
    ACCs = np.zeros(10)
    result = scipy.io.loadmat(root)
    for i in range(10):
        fold = result['fold']
        flags = result['flag']
        featureLs = result['fl']
        featureRs = result['fr']

        valFold = fold != i
        testFold = fold == i
        flags = np.squeeze(flags)

        if featureLs[valFold[0], :].shape[0] == 0:
            continue

        if featureRs[valFold[0], :].shape[0] == 0:
            continue

        mu = np.mean(np.concatenate((featureLs[valFold[0], :], featureRs[valFold[0], :]), 0), 0)
        mu = np.expand_dims(mu, 0)
        featureLs = featureLs - mu
        featureRs = featureRs - mu
        featureLs = featureLs / np.expand_dims(np.sqrt(np.sum(np.power(featureLs, 2), 1)), 1)
        featureRs = featureRs / np.expand_dims(np.sqrt(np.sum(np.power(featureRs, 2), 1)), 1)

        scores = np.sum(np.multiply(featureLs, featureRs), 1)

        threshold = getThreshold(scores[valFold[0]], flags[valFold[0]], 10000)
        # print('Fold', i, 'Threshold', threshold)

        ACCs[i] = getAccuracy(scores[testFold[0]], flags[testFold[0]], threshold)
        # print('Fold', i, 'Accuracy', ACCs[i])

    return ACCs
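# Added summary of the 10-fold protocol implemented above (commentary, not
# original code): LFW's 6000 pairs form 10 folds of 600 (fold = i // 600 in
# parseList). For each held-out fold i, the other nine folds supply the mean
# feature used for centering and the decision threshold, which getThreshold
# picks by grid search over [-1, 1]. Features are then L2-normalized and each
# pair is scored by the inner product of its two features (cosine similarity);
# getAccuracy applies the selected threshold to the held-out fold.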
def getFeatureFromTorch(lfw_dir, feature_save_dir, resume=None, gpu=True):
    net = model.ShuffleFaceNet()
    if gpu:
        net = net.cuda()
    if resume:
        ckpt = torch.load(resume, map_location='cpu')
        net.load_state_dict(ckpt['net_state_dict'])
    net.eval()
    nl, nr, folds, flags = parseList(lfw_dir)
    lfw_dataset = LFW(nl, nr)
    # lfw_loader = torch.utils.data.DataLoader(lfw_dataset, batch_size=32,
    #                                          shuffle=False, num_workers=8, drop_last=False)

    lfw_loader = torch.utils.data.DataLoader(lfw_dataset, batch_size=32,
                                             shuffle=False, num_workers=2, drop_last=False)

    featureLs = None
    featureRs = None
    count = 0

    print(lfw_loader.dataset.__len__())
    for data in lfw_loader:
        if gpu:
            for i in range(len(data)):
                data[i] = data[i].cuda()

        count += data[0].size(0)
        print('extracting deep features from the face pair {}...'.format(count))
        res = [net(d).data.cpu().numpy() for d in data]
        featureL = np.concatenate((res[0], res[1]), 1)
        featureR = np.concatenate((res[2], res[3]), 1)

        if featureLs is None:
            featureLs = featureL
        else:
            featureLs = np.concatenate((featureLs, featureL), 0)

        if featureRs is None:
            featureRs = featureR
        else:
            featureRs = np.concatenate((featureRs, featureR), 0)

    result = {'fl': featureLs, 'fr': featureRs, 'fold': folds, 'flag': flags}
    scipy.io.savemat(feature_save_dir, result)


# def getFeatureFromCaffe(gpu=True):
#     if gpu:
#         caffe.set_mode_gpu()
#         caffe.set_device(0)
#     else:
#         caffe.set_mode_cpu()
#     # caffe.reset_all()
#     model = '/home/xiaocc/Documents/caffe_project/sphereface/train/code/sphereface_deploy.prototxt'
#     weights = '/home/xiaocc/Documents/caffe_project/sphereface/train/result/sphereface_model.caffemodel'
#     net = caffe.Net(model, weights, caffe.TEST)
#
#     nl, nr, folds, flags = parseList()
#
#     featureLs = []
#     featureRs = []
#     for i in range(len(nl)):
#         print('extracting deep features from the {}th face pair ...'.format(i))
#         featureL = extractDeepFeature(nl[i], net)[0]
#         featureR = extractDeepFeature(nr[i], net)[0]
#         featureLs.append(featureL)
#         featureRs.append(featureR)
#     result = {'fl': featureLs, 'fr': featureRs, 'fold': folds, 'flag': flags}
#     scipy.io.savemat('caffe_result.mat', result)
#
# def extractDeepFeature(f, net, h=112, w=96):
#     img = cv2.imread(f)
#     img = (img - 127.5) / 128
#     img = img.transpose((2, 0, 1))
#     net.blobs['data'].reshape(1, 3, h, w)
#     net.blobs['data'].data[0, ...] = img
#     res = copy.deepcopy(net.forward()['fc5'])
#     net.blobs['data'].data[0, ...] = img[:, :, ::-1]
#     res_ = copy.deepcopy(net.forward()['fc5'])
#     r = np.concatenate((res, res_), 1)
#     return r


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Testing')
    parser.add_argument('--lfw_dir', type=str, default=LFW_DATA_DIR, help='The path of the LFW data')
    parser.add_argument('--resume', type=str, default='./model/best/060.ckpt',
                        help='The path of the saved model')
    parser.add_argument('--feature_save_dir', type=str, default='./result/best_result.mat',
                        help='The path where extracted features are saved; must be a .mat file')
    args = parser.parse_args()

    # getFeatureFromCaffe()
    getFeatureFromTorch(args.lfw_dir, args.feature_save_dir, args.resume, False)
    ACCs = evaluation_10_fold(args.feature_save_dir)

    for i in range(len(ACCs)):
        print('{} {:.2f}'.format(i + 1, ACCs[i] * 100))
    print('--------')
    print('AVE {:.2f}'.format(np.mean(ACCs) * 100))

--------------------------------------------------------------------------------
/model/CASIA_ShuffleFaceNet_20200204_195134/010.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbfaria/ShuffleFaceNet_Pytorch/39818fbff657dbf4db308c44cd5bd0c70b56fd4f/model/CASIA_ShuffleFaceNet_20200204_195134/010.ckpt

--------------------------------------------------------------------------------
/model/CASIA_ShuffleFaceNet_20200204_195134/020.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbfaria/ShuffleFaceNet_Pytorch/39818fbff657dbf4db308c44cd5bd0c70b56fd4f/model/CASIA_ShuffleFaceNet_20200204_195134/020.ckpt

--------------------------------------------------------------------------------
/model/CASIA_ShuffleFaceNet_20200204_195134/030.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbfaria/ShuffleFaceNet_Pytorch/39818fbff657dbf4db308c44cd5bd0c70b56fd4f/model/CASIA_ShuffleFaceNet_20200204_195134/030.ckpt

--------------------------------------------------------------------------------
/model/CASIA_ShuffleFaceNet_20200204_195134/040.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbfaria/ShuffleFaceNet_Pytorch/39818fbff657dbf4db308c44cd5bd0c70b56fd4f/model/CASIA_ShuffleFaceNet_20200204_195134/040.ckpt

--------------------------------------------------------------------------------
/model/CASIA_ShuffleFaceNet_20200204_195134/050.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbfaria/ShuffleFaceNet_Pytorch/39818fbff657dbf4db308c44cd5bd0c70b56fd4f/model/CASIA_ShuffleFaceNet_20200204_195134/050.ckpt

--------------------------------------------------------------------------------
/model/CASIA_ShuffleFaceNet_20200204_195134/060.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbfaria/ShuffleFaceNet_Pytorch/39818fbff657dbf4db308c44cd5bd0c70b56fd4f/model/CASIA_ShuffleFaceNet_20200204_195134/060.ckpt

--------------------------------------------------------------------------------
/model/CASIA_ShuffleFaceNet_20200204_195134/070.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbfaria/ShuffleFaceNet_Pytorch/39818fbff657dbf4db308c44cd5bd0c70b56fd4f/model/CASIA_ShuffleFaceNet_20200204_195134/070.ckpt

--------------------------------------------------------------------------------
/model/CASIA_ShuffleFaceNet_20200204_195134/log.log:
--------------------------------------------------------------------------------
20200204-19:54:00 Train Epoch: 1/70 ...
20200204-20:09:01 total_loss: 24.7161 time: 15m 1s
20200204-20:09:01 Train Epoch: 2/70 ...
20200204-20:19:09 total_loss: 22.4334 time: 10m 8s
20200204-20:19:09 Train Epoch: 3/70 ...
20200204-20:30:28 total_loss: 20.0291 time: 11m 20s
20200204-20:30:28 Train Epoch: 4/70 ...
20200204-20:42:53 total_loss: 17.6823 time: 12m 25s
20200204-20:42:53 Train Epoch: 5/70 ...
20200204-20:53:10 total_loss: 15.9964 time: 10m 17s
20200204-20:53:10 Test Epoch: 5 ...
20200204-20:54:19 ave: 87.3500
20200204-20:54:19 Train Epoch: 6/70 ...
20200204-21:04:27 total_loss: 14.8365 time: 10m 8s
20200204-21:04:27 Train Epoch: 7/70 ...
20200204-21:14:43 total_loss: 14.0145 time: 10m 16s
20200204-21:14:43 Train Epoch: 8/70 ...
20200204-21:24:57 total_loss: 13.3993 time: 10m 14s
20200204-21:24:57 Train Epoch: 9/70 ...
20200204-21:35:13 total_loss: 12.9140 time: 10m 15s
20200204-21:35:13 Train Epoch: 10/70 ...
20200204-21:45:28 total_loss: 12.5254 time: 10m 15s
20200204-21:45:28 Test Epoch: 10 ...
20200204-21:46:35 ave: 93.2333
20200204-21:46:35 Saving checkpoint: 10
20200204-21:46:35 Train Epoch: 11/70 ...
20200204-21:56:44 total_loss: 12.2015 time: 10m 9s
20200204-21:56:44 Train Epoch: 12/70 ...
20200204-22:07:40 total_loss: 11.9267 time: 10m 56s
20200204-22:07:40 Train Epoch: 13/70 ...
20200204-22:24:12 total_loss: 11.6846 time: 16m 32s
20200204-22:24:12 Train Epoch: 14/70 ...
20200204-22:34:26 total_loss: 11.4831 time: 10m 14s
20200204-22:34:26 Train Epoch: 15/70 ...
20200204-22:44:46 total_loss: 11.2979 time: 10m 19s
20200204-22:44:46 Test Epoch: 15 ...
20200204-22:45:52 ave: 94.8000
20200204-22:45:52 Train Epoch: 16/70 ...
20200204-22:56:13 total_loss: 11.1343 time: 10m 21s
20200204-22:56:13 Train Epoch: 17/70 ...
20200204-23:06:29 total_loss: 10.9896 time: 10m 16s
20200204-23:06:29 Train Epoch: 18/70 ...
20200204-23:16:42 total_loss: 10.8613 time: 10m 13s
20200204-23:16:42 Train Epoch: 19/70 ...
20200204-23:27:03 total_loss: 10.7441 time: 10m 21s
20200204-23:27:03 Train Epoch: 20/70 ...
20200204-23:37:22 total_loss: 10.6429 time: 10m 19s
20200204-23:37:22 Test Epoch: 20 ...
20200204-23:38:28 ave: 94.9167
20200204-23:38:28 Saving checkpoint: 20
20200204-23:38:28 Train Epoch: 21/70 ...
20200204-23:48:39 total_loss: 10.5420 time: 10m 11s
20200204-23:48:39 Train Epoch: 22/70 ...
20200204-23:58:45 total_loss: 10.4538 time: 10m 6s
20200204-23:58:45 Train Epoch: 23/70 ...
20200205-00:08:51 total_loss: 10.3765 time: 10m 6s
20200205-00:08:51 Train Epoch: 24/70 ...
20200205-00:18:58 total_loss: 10.3034 time: 10m 7s
20200205-00:18:58 Train Epoch: 25/70 ...
20200205-00:29:06 total_loss: 10.2350 time: 10m 8s
20200205-00:29:06 Test Epoch: 25 ...
20200205-00:30:13 ave: 95.5000
20200205-00:30:13 Train Epoch: 26/70 ...
20200205-00:40:20 total_loss: 10.1717 time: 10m 7s
20200205-00:40:20 Train Epoch: 27/70 ...
20200205-00:50:26 total_loss: 10.1039 time: 10m 7s
20200205-00:50:26 Train Epoch: 28/70 ...
20200205-01:00:30 total_loss: 10.0510 time: 10m 4s
20200205-01:00:30 Train Epoch: 29/70 ...
20200205-01:10:36 total_loss: 9.9984 time: 10m 5s
20200205-01:10:36 Train Epoch: 30/70 ...
20200205-01:20:42 total_loss: 9.9468 time: 10m 6s
20200205-01:20:42 Test Epoch: 30 ...
20200205-01:21:48 ave: 95.4167
20200205-01:21:48 Saving checkpoint: 30
20200205-01:21:48 Train Epoch: 31/70 ...
20200205-01:31:56 total_loss: 9.8942 time: 10m 8s
20200205-01:31:56 Train Epoch: 32/70 ...
20200205-01:42:10 total_loss: 9.8512 time: 10m 14s
20200205-01:42:10 Train Epoch: 33/70 ...
20200205-01:52:28 total_loss: 9.8034 time: 10m 17s
20200205-01:52:28 Train Epoch: 34/70 ...
20200205-02:02:46 total_loss: 9.7698 time: 10m 19s
20200205-02:02:46 Train Epoch: 35/70 ...
20200205-02:13:04 total_loss: 9.7216 time: 10m 18s
20200205-02:13:04 Test Epoch: 35 ...
20200205-02:14:10 ave: 95.3000
20200205-02:14:10 Train Epoch: 36/70 ...
20200205-02:24:27 total_loss: 9.6846 time: 10m 17s
20200205-02:24:27 Train Epoch: 37/70 ...
20200205-02:34:42 total_loss: 9.6543 time: 10m 15s
20200205-02:34:42 Train Epoch: 38/70 ...
20200205-02:44:57 total_loss: 9.6207 time: 10m 15s
20200205-02:44:57 Train Epoch: 39/70 ...
20200205-02:55:12 total_loss: 9.5894 time: 10m 15s
20200205-02:55:12 Train Epoch: 40/70 ...
20200205-03:05:28 total_loss: 9.5571 time: 10m 16s
20200205-03:05:28 Test Epoch: 40 ...
20200205-03:06:35 ave: 95.0833
20200205-03:06:35 Saving checkpoint: 40
20200205-03:06:35 Train Epoch: 41/70 ...
20200205-03:16:51 total_loss: 9.5275 time: 10m 16s
20200205-03:16:51 Train Epoch: 42/70 ...
20200205-03:27:07 total_loss: 9.5008 time: 10m 16s
20200205-03:27:07 Train Epoch: 43/70 ...
20200205-03:37:23 total_loss: 9.4724 time: 10m 15s
20200205-03:37:23 Train Epoch: 44/70 ...
20200205-03:47:39 total_loss: 9.4431 time: 10m 17s
20200205-03:47:40 Train Epoch: 45/70 ...
20200205-03:57:57 total_loss: 9.4210 time: 10m 17s
20200205-03:57:57 Test Epoch: 45 ...
20200205-03:59:03 ave: 95.4333
20200205-03:59:03 Train Epoch: 46/70 ...
20200205-04:09:18 total_loss: 9.4028 time: 10m 15s
20200205-04:09:18 Train Epoch: 47/70 ...
20200205-04:19:36 total_loss: 9.3764 time: 10m 17s
20200205-04:19:36 Train Epoch: 48/70 ...
20200205-04:29:53 total_loss: 9.3489 time: 10m 18s
20200205-04:29:53 Train Epoch: 49/70 ...
20200205-04:40:11 total_loss: 9.3333 time: 10m 18s
20200205-04:40:11 Train Epoch: 50/70 ...
20200205-04:50:28 total_loss: 9.3033 time: 10m 17s
20200205-04:50:28 Test Epoch: 50 ...
20200205-04:51:35 ave: 95.3667
20200205-04:51:35 Saving checkpoint: 50
20200205-04:51:35 Train Epoch: 51/70 ...
20200205-05:01:53 total_loss: 9.2883 time: 10m 18s
20200205-05:01:53 Train Epoch: 52/70 ...
20200205-05:12:09 total_loss: 9.2673 time: 10m 16s
20200205-05:12:09 Train Epoch: 53/70 ...
20200205-05:22:23 total_loss: 9.2407 time: 10m 14s
20200205-05:22:23 Train Epoch: 54/70 ...
20200205-05:32:41 total_loss: 9.2304 time: 10m 17s
20200205-05:32:41 Train Epoch: 55/70 ...
20200205-05:42:59 total_loss: 9.2198 time: 10m 18s
20200205-05:42:59 Test Epoch: 55 ...
20200205-05:44:06 ave: 95.0833
20200205-05:44:06 Train Epoch: 56/70 ...
20200205-05:54:19 total_loss: 9.1958 time: 10m 13s
20200205-05:54:19 Train Epoch: 57/70 ...
20200205-06:04:27 total_loss: 9.1820 time: 10m 8s
20200205-06:04:27 Train Epoch: 58/70 ...
20200205-06:14:34 total_loss: 9.1670 time: 10m 8s
20200205-06:14:34 Train Epoch: 59/70 ...
20200205-06:24:41 total_loss: 9.1511 time: 10m 6s
20200205-06:24:41 Train Epoch: 60/70 ...
20200205-06:34:48 total_loss: 9.1362 time: 10m 7s
20200205-06:34:48 Test Epoch: 60 ...
20200205-06:35:54 ave: 95.7667
20200205-06:35:54 Saving checkpoint: 60
20200205-06:35:55 Train Epoch: 61/70 ...
20200205-06:46:02 total_loss: 9.1182 time: 10m 8s
20200205-06:46:02 Train Epoch: 62/70 ...
20200205-06:56:09 total_loss: 9.1110 time: 10m 7s
20200205-06:56:09 Train Epoch: 63/70 ...
20200205-07:06:17 total_loss: 9.0954 time: 10m 8s
20200205-07:06:17 Train Epoch: 64/70 ...
20200205-07:16:24 total_loss: 9.0784 time: 10m 7s
20200205-07:16:24 Train Epoch: 65/70 ...
20200205-07:26:31 total_loss: 9.0612 time: 10m 7s
20200205-07:26:31 Test Epoch: 65 ...
20200205-07:27:38 ave: 95.7833
20200205-07:27:38 Train Epoch: 66/70 ...
20200205-07:37:47 total_loss: 9.0534 time: 10m 9s
20200205-07:37:47 Train Epoch: 67/70 ...
20200205-07:47:55 total_loss: 9.0333 time: 10m 8s
20200205-07:47:55 Train Epoch: 68/70 ...
20200205-07:58:03 total_loss: 9.0304 time: 10m 8s
20200205-07:58:03 Train Epoch: 69/70 ...
20200205-08:08:11 total_loss: 9.0127 time: 10m 8s
20200205-08:08:11 Train Epoch: 70/70 ...
20200205-08:18:19 total_loss: 8.9992 time: 10m 8s
20200205-08:18:19 Test Epoch: 70 ...
20200205-08:19:25 ave: 95.5333
20200205-08:19:25 Saving checkpoint: 70

--------------------------------------------------------------------------------
/model/best/060.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbfaria/ShuffleFaceNet_Pytorch/39818fbff657dbf4db308c44cd5bd0c70b56fd4f/model/best/060.ckpt

--------------------------------------------------------------------------------
/result/best_result.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbfaria/ShuffleFaceNet_Pytorch/39818fbff657dbf4db308c44cd5bd0c70b56fd4f/result/best_result.mat

--------------------------------------------------------------------------------
/result/tmp_result.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbfaria/ShuffleFaceNet_Pytorch/39818fbff657dbf4db308c44cd5bd0c70b56fd4f/result/tmp_result.mat

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import torch.utils.data
from torch import nn
from torch.nn import DataParallel
from datetime import datetime
from config import BATCH_SIZE, SAVE_FREQ, RESUME, SAVE_DIR, TEST_FREQ, TOTAL_EPOCH, MODEL_PRE, GPU
from config import CASIA_DATA_DIR, LFW_DATA_DIR
from core import model
from core.utils import init_log
from dataloader.CASIA_Face_loader import CASIA_Face
from dataloader.LFW_loader import LFW
from torch.optim import lr_scheduler
import torch.optim as optim
import time
from lfw_eval import parseList, evaluation_10_fold
import numpy as np
import scipy.io


def define_gpu():
    # gpu init
    gpu_list = ''
    multi_gpus = False
    if isinstance(GPU, int):
        gpu_list = str(GPU)
    else:
        multi_gpus = True
        for i, gpu_id in enumerate(GPU):
            gpu_list += str(gpu_id)
            if i != len(GPU) - 1:
                gpu_list += ','
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list

    return multi_gpus


if __name__ == '__main__':
    multi_gpus = define_gpu()
    print('multi_gpus', multi_gpus)

    # other init
    start_epoch = 1
    save_dir = os.path.join(SAVE_DIR, MODEL_PRE + datetime.now().strftime('%Y%m%d_%H%M%S'))
    if os.path.exists(save_dir):
        raise NameError('model dir exists!')
    os.makedirs(save_dir)
    logging = init_log(save_dir)
    _print = logging.info

    # define trainloader and testloader
    print('defining casia dataloader...')
    trainset = CASIA_Face(root=CASIA_DATA_DIR)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                              shuffle=True, num_workers=8, drop_last=False)

    # nl: left_image_path
    # nr: right_image_path
    print('defining lfw dataloader...')
    nl, nr, folds, flags = parseList(root=LFW_DATA_DIR)
    testdataset = LFW(nl, nr)
    testloader = torch.utils.data.DataLoader(testdataset, batch_size=32,
                                             shuffle=False, num_workers=8, drop_last=False)

    # define model
    print('defining shufflefacenet model...')
    net = model.ShuffleFaceNet()

    if RESUME:
        ckpt = torch.load(RESUME)
        net.load_state_dict(ckpt['net_state_dict'])
        start_epoch = ckpt['epoch'] + 1

    net = net.cuda()

    # classification loss: cross-entropy over the CosFace margin logits
    nllloss = nn.CrossEntropyLoss().cuda()
    # CosFace margin head with learnable class centers
    lmcl_loss = model.CosFace_loss(num_classes=trainset.class_nums, feat_dim=128).cuda()

    if multi_gpus:
        net = DataParallel(net)
        lmcl_loss = DataParallel(lmcl_loss)

    criterion = [nllloss, lmcl_loss]

    # optimizer for the backbone
    optimizer4nn = optim.Adam(net.parameters(), lr=0.001, weight_decay=0.0005)
    scheduler_4nn = lr_scheduler.StepLR(optimizer4nn, 20, gamma=0.5)

    # optimizer for the CosFace class centers
    optimizer4center = optim.Adam(lmcl_loss.parameters(), lr=0.01)
    scheduler_4center = lr_scheduler.StepLR(optimizer4center, 20, gamma=0.5)

    best_acc = 0.0
    best_epoch = 0
    for epoch in range(start_epoch, TOTAL_EPOCH + 1):
        # step the LR schedulers once per epoch (halves both LRs every 20 epochs)
        scheduler_4nn.step()
        scheduler_4center.step()
        # train model
        _print('Train Epoch: {}/{} ...'.format(epoch, TOTAL_EPOCH))
        net.train()

        train_total_loss = 0.0
        total = 0
        since = time.time()
        for data in trainloader:
            img, label = data[0].cuda(), data[1].cuda()
            batch_size = img.size(0)

            raw_logits = net(img)

            logits, mlogits = criterion[1](raw_logits, label)
            total_loss = criterion[0](mlogits, label)

            optimizer4nn.zero_grad()
            optimizer4center.zero_grad()

            total_loss.backward()

            optimizer4nn.step()
            optimizer4center.step()

            train_total_loss += total_loss.item() * batch_size
            total += batch_size

        train_total_loss = train_total_loss / total
        time_elapsed = time.time() - since
        loss_msg = '    total_loss: {:.4f} time: {:.0f}m {:.0f}s'\
            .format(train_total_loss, time_elapsed // 60, time_elapsed % 60)
        _print(loss_msg)

        # test model on lfw
        if epoch % TEST_FREQ == 0:
            net.eval()
            featureLs = None
            featureRs = None
            _print('Test Epoch: {} ...'.format(epoch))
            for data in testloader:
                for i in range(len(data)):
                    data[i] = data[i].cuda()
                res = [net(d).data.cpu().numpy() for d in data]
                featureL = np.concatenate((res[0], res[1]), 1)
                featureR = np.concatenate((res[2], res[3]), 1)
                if featureLs is None:
                    featureLs = featureL
                else:
                    featureLs = np.concatenate((featureLs, featureL), 0)
                if featureRs is None:
                    featureRs = featureR
                else:
                    featureRs = np.concatenate((featureRs, featureR), 0)

            result = {'fl': featureLs, 'fr': featureRs, 'fold': folds, 'flag': flags}
            # save tmp_result
            scipy.io.savemat('./result/tmp_result.mat', result)
            accs = evaluation_10_fold('./result/tmp_result.mat')
            _print('    ave: {:.4f}'.format(np.mean(accs) * 100))

        # save model
        if epoch % SAVE_FREQ == 0:
            msg = 'Saving checkpoint: {}'.format(epoch)
            _print(msg)
            if multi_gpus:
                net_state_dict = net.module.state_dict()
            else:
                net_state_dict = net.state_dict()
            if not os.path.exists(save_dir):
                os.mkdir(save_dir)
            torch.save({
                'epoch': epoch,
                'net_state_dict': net_state_dict},
                os.path.join(save_dir, '%03d.ckpt' % epoch))
    print('finished training')
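# Added note on the objective (commentary on model.CosFace_loss, not original
# code): for an embedding x with label y, the head computes cosine similarities
# cos(theta_j) between the L2-normalized x and the L2-normalized class centers,
# then scales
#     margin_logits_j = s * (cos(theta_j) - m * 1[j == y]),  with s = 32.0, m = 0.5,
# and the cross-entropy above is applied to these margin logits. This is the
# large-margin cosine loss of Wang et al., "CosFace", CVPR 2018.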