├── nets ├── __init__.py ├── arcface_training.py ├── mobilenet.py ├── arcface.py ├── mobilefacenet.py └── iresnet.py ├── utils ├── __init__.py ├── utils.py ├── callback.py ├── dataloader.py ├── utils_fit.py └── utils_metrics.py ├── datasets └── README.md ├── lfw └── README.md ├── logs └── README.md ├── img ├── 1_001.jpg ├── 1_002.jpg └── 2_001.jpg ├── model_data ├── roc_test.png └── arcface_mobilefacenet.pth ├── txt_annotation.py ├── LICENSE ├── summary.py ├── predict.py ├── eval_LFW.py ├── .gitignore ├── README.md ├── arcface.py ├── train.py └── 常见问题汇总.md /nets/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | 存放数据集 -------------------------------------------------------------------------------- /lfw/README.md: -------------------------------------------------------------------------------- 1 | 存放lfw数据集 -------------------------------------------------------------------------------- /logs/README.md: -------------------------------------------------------------------------------- 1 | 用于存放训练好的文件 -------------------------------------------------------------------------------- /img/1_001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bubbliiiing/arcface-pytorch/HEAD/img/1_001.jpg -------------------------------------------------------------------------------- /img/1_002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bubbliiiing/arcface-pytorch/HEAD/img/1_002.jpg -------------------------------------------------------------------------------- /img/2_001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bubbliiiing/arcface-pytorch/HEAD/img/2_001.jpg -------------------------------------------------------------------------------- /model_data/roc_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bubbliiiing/arcface-pytorch/HEAD/model_data/roc_test.png -------------------------------------------------------------------------------- /model_data/arcface_mobilefacenet.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bubbliiiing/arcface-pytorch/HEAD/model_data/arcface_mobilefacenet.pth -------------------------------------------------------------------------------- /txt_annotation.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------# 2 | # 进行训练前需要利用这个文件生成cls_train.txt 3 | #------------------------------------------------# 4 | import os 5 | 6 | if __name__ == "__main__": 7 | #---------------------# 8 | # 训练集所在的路径 9 | #---------------------# 10 | datasets_path = "datasets" 11 | 12 | types_name = os.listdir(datasets_path) 13 | types_name = sorted(types_name) 14 | 15 | list_file = open('cls_train.txt', 'w') 16 | for cls_id, type_name in enumerate(types_name): 17 | photos_path = os.path.join(datasets_path, type_name) 18 | if not os.path.isdir(photos_path): 19 | continue 20 | photos_name = os.listdir(photos_path) 21 | 22 | for photo_name in photos_name: 23 | list_file.write(str(cls_id) + ";" + '%s'%(os.path.join(os.path.abspath(datasets_path), type_name, photo_name))) 24 | list_file.write('\n') 25 | list_file.close() 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Bubbliiiing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /summary.py: -------------------------------------------------------------------------------- 1 | #--------------------------------------------# 2 | # 该部分代码只用于看网络结构,并非测试代码 3 | #--------------------------------------------# 4 | import torch 5 | from thop import clever_format, profile 6 | from torchsummary import summary 7 | 8 | from nets.arcface import Arcface 9 | 10 | if __name__ == "__main__": 11 | input_shape = [112, 112] 12 | backbone = 'mobilefacenet' 13 | 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | model = Arcface(num_classes=10575, backbone=backbone, mode="predict").to(device) 16 | summary(model, (3, input_shape[0], input_shape[1])) 17 | 18 | dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device) 19 | flops, params = profile(model.to(device), (dummy_input, ), verbose=False) 20 | #--------------------------------------------------------# 21 | # flops * 2是因为profile没有将卷积作为两个operations 22 | # 有些论文将卷积算乘法、加法两个operations。此时乘2 23 | # 有些论文只考虑乘法的运算次数,忽略加法。此时不乘2 24 | # 本代码选择乘2,参考YOLOX。 25 | #--------------------------------------------------------# 26 | flops = flops * 2 27 | flops, params = clever_format([flops, params], "%.3f") 28 | print('Total GFLOPS: %s' % (flops)) 29 | print('Total params: %s' % (params)) -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from arcface import Arcface 4 | 5 | if __name__ == "__main__": 6 | model = Arcface() 7 | 8 | #----------------------------------------------------------------------------------------------------------# 9 | # mode用于指定测试的模式: 10 | # 'predict'表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释 11 | # 'fps'表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。 12 | #----------------------------------------------------------------------------------------------------------# 13 | mode = "predict" 14 | #-------------------------------------------------------------------------# 15 | # test_interval 用于指定测量fps的时候,图片检测的次数 16 | # 理论上test_interval越大,fps越准确。 17 | # fps_test_image fps测试图片 18 | #-------------------------------------------------------------------------# 19 | test_interval = 100 20 | fps_test_image = 'img/1_001.jpg' 21 | 22 | if mode == "predict": 23 | while True: 24 | image_1 = input('Input image_1 filename:') 25 | try: 26 | image_1 = Image.open(image_1) 27 | except: 28 | print('Image_1 Open Error! Try again!') 29 | continue 30 | 31 | image_2 = input('Input image_2 filename:') 32 | try: 33 | image_2 = Image.open(image_2) 34 | except: 35 | print('Image_2 Open Error! Try again!') 36 | continue 37 | 38 | probability = model.detect_image(image_1,image_2) 39 | print(probability) 40 | 41 | elif mode == "fps": 42 | img = Image.open(fps_test_image) 43 | tact_time = model.get_FPS(img, test_interval) 44 | print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1') -------------------------------------------------------------------------------- /nets/arcface_training.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | 4 | 5 | def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.1, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.3, step_num = 10): 6 | def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters): 7 | if iters <= warmup_total_iters: 8 | # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start 9 | lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2 10 | ) + warmup_lr_start 11 | elif iters >= total_iters - no_aug_iter: 12 | lr = min_lr 13 | else: 14 | lr = min_lr + 0.5 * (lr - min_lr) * ( 15 | 1.0 16 | + math.cos( 17 | math.pi 18 | * (iters - warmup_total_iters) 19 | / (total_iters - warmup_total_iters - no_aug_iter) 20 | ) 21 | ) 22 | return lr 23 | 24 | def step_lr(lr, decay_rate, step_size, iters): 25 | if step_size < 1: 26 | raise ValueError("step_size must above 1.") 27 | n = iters // step_size 28 | out_lr = lr * decay_rate ** n 29 | return out_lr 30 | 31 | if lr_decay_type == "cos": 32 | warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3) 33 | warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6) 34 | no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15) 35 | func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter) 36 | else: 37 | decay_rate = (min_lr / lr) ** (1 / (step_num - 1)) 38 | step_size = total_iters / step_num 39 | func = partial(step_lr, lr, decay_rate, step_size) 40 | 41 | return func 42 | 43 | def set_optimizer_lr(optimizer, lr_scheduler_func, epoch): 44 | lr = lr_scheduler_func(epoch) 45 | for param_group in optimizer.param_groups: 46 | param_group['lr'] = lr 47 | -------------------------------------------------------------------------------- /eval_LFW.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.backends.cudnn as cudnn 3 | 4 | from nets.arcface import Arcface 5 | from utils.dataloader import LFWDataset 6 | from utils.utils_metrics import test 7 | 8 | 9 | if __name__ == "__main__": 10 | #--------------------------------------# 11 | # 是否使用Cuda 12 | # 没有GPU可以设置成False 13 | #--------------------------------------# 14 | cuda = True 15 | #--------------------------------------# 16 | # 主干特征提取网络的选择 17 | # mobilefacenet 18 | # mobilenetv1 19 | # iresnet18 20 | # iresnet34 21 | # iresnet50 22 | # iresnet100 23 | # iresnet200 24 | #--------------------------------------# 25 | backbone = "mobilefacenet" 26 | #--------------------------------------# 27 | # 输入图像大小 28 | #--------------------------------------# 29 | input_shape = [112, 112, 3] 30 | #--------------------------------------# 31 | # 训练好的权值文件 32 | #--------------------------------------# 33 | model_path = "model_data/arcface_mobilefacenet.pth" 34 | #--------------------------------------# 35 | # LFW评估数据集的文件路径 36 | # 以及对应的txt文件 37 | #--------------------------------------# 38 | lfw_dir_path = "lfw" 39 | lfw_pairs_path = "model_data/lfw_pair.txt" 40 | #--------------------------------------# 41 | # 评估的批次大小和记录间隔 42 | #--------------------------------------# 43 | batch_size = 256 44 | log_interval = 1 45 | #--------------------------------------# 46 | # ROC图的保存路径 47 | #--------------------------------------# 48 | png_save_path = "model_data/roc_test.png" 49 | 50 | test_loader = torch.utils.data.DataLoader( 51 | LFWDataset(dir=lfw_dir_path, pairs_path=lfw_pairs_path, image_size=input_shape), batch_size=batch_size, shuffle=False) 52 | 53 | model = Arcface(backbone=backbone, mode="predict") 54 | 55 | print('Loading weights into state dict...') 56 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 57 | model.load_state_dict(torch.load(model_path, map_location=device), strict=False) 58 | model = model.eval() 59 | 60 | if cuda: 61 | model = torch.nn.DataParallel(model) 62 | cudnn.benchmark = True 63 | model = model.cuda() 64 | 65 | test(test_loader, model, png_save_path, log_interval, batch_size, cuda) 66 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # ignore map, miou, datasets 2 | map_out/ 3 | miou_out/ 4 | VOCdevkit/ 5 | datasets/ 6 | Medical_Datasets/ 7 | lfw/ 8 | logs/ 9 | model_data/ 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | pip-wheel-metadata/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | 7 | 8 | #---------------------------------------------------------# 9 | # 将图像转换成RGB图像,防止灰度图在预测时报错。 10 | # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB 11 | #---------------------------------------------------------# 12 | def cvtColor(image): 13 | if len(np.shape(image)) == 3 and np.shape(image)[2] == 3: 14 | return image 15 | else: 16 | image = image.convert('RGB') 17 | return image 18 | 19 | #---------------------------------------------------# 20 | # 对输入图像进行resize 21 | #---------------------------------------------------# 22 | def resize_image(image, size, letterbox_image): 23 | iw, ih = image.size 24 | w, h = size 25 | if letterbox_image: 26 | scale = min(w/iw, h/ih) 27 | nw = int(iw*scale) 28 | nh = int(ih*scale) 29 | 30 | image = image.resize((nw,nh), Image.BICUBIC) 31 | new_image = Image.new('RGB', size, (128,128,128)) 32 | new_image.paste(image, ((w-nw)//2, (h-nh)//2)) 33 | else: 34 | new_image = image.resize((w, h), Image.BICUBIC) 35 | return new_image 36 | 37 | def get_num_classes(annotation_path): 38 | with open(annotation_path) as f: 39 | dataset_path = f.readlines() 40 | 41 | labels = [] 42 | for path in dataset_path: 43 | path_split = path.split(";") 44 | labels.append(int(path_split[0])) 45 | num_classes = np.max(labels) + 1 46 | return num_classes 47 | 48 | #---------------------------------------------------# 49 | # 获得学习率 50 | #---------------------------------------------------# 51 | def get_lr(optimizer): 52 | for param_group in optimizer.param_groups: 53 | return param_group['lr'] 54 | 55 | #---------------------------------------------------# 56 | # 设置种子 57 | #---------------------------------------------------# 58 | def seed_everything(seed=11): 59 | random.seed(seed) 60 | np.random.seed(seed) 61 | torch.manual_seed(seed) 62 | torch.cuda.manual_seed(seed) 63 | torch.cuda.manual_seed_all(seed) 64 | torch.backends.cudnn.deterministic = True 65 | torch.backends.cudnn.benchmark = False 66 | 67 | #---------------------------------------------------# 68 | # 设置Dataloader的种子 69 | #---------------------------------------------------# 70 | def worker_init_fn(worker_id, rank, seed): 71 | worker_seed = rank + seed 72 | random.seed(worker_seed) 73 | np.random.seed(worker_seed) 74 | torch.manual_seed(worker_seed) 75 | 76 | def preprocess_input(image): 77 | image /= 255.0 78 | image -= 0.5 79 | image /= 0.5 80 | return image 81 | 82 | def show_config(**kwargs): 83 | print('Configurations:') 84 | print('-' * 70) 85 | print('|%25s | %40s|' % ('keys', 'values')) 86 | print('-' * 70) 87 | for key, value in kwargs.items(): 88 | print('|%25s | %40s|' % (str(key), str(value))) 89 | print('-' * 70) -------------------------------------------------------------------------------- /nets/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def conv_bn(inp, oup, stride = 1): 6 | return nn.Sequential( 7 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 8 | nn.BatchNorm2d(oup), 9 | nn.ReLU6(inplace=True) 10 | ) 11 | 12 | def conv_dw(inp, oup, stride = 1): 13 | return nn.Sequential( 14 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 15 | nn.BatchNorm2d(inp), 16 | nn.ReLU6(inplace=True), 17 | 18 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 19 | nn.BatchNorm2d(oup), 20 | nn.ReLU6(inplace=True), 21 | ) 22 | 23 | class MobileNetV1(nn.Module): 24 | fc_scale = 7 * 7 25 | def __init__(self, dropout_keep_prob, embedding_size, pretrained): 26 | super(MobileNetV1, self).__init__() 27 | self.stage1 = nn.Sequential( 28 | conv_bn(3, 32, 1), # 3 29 | conv_dw(32, 64, 1), # 7 30 | 31 | conv_dw(64, 128, 2), # 11 32 | conv_dw(128, 128, 1), # 19 33 | 34 | conv_dw(128, 256, 2), # 27 35 | conv_dw(256, 256, 1), # 43 36 | ) 37 | self.stage2 = nn.Sequential( 38 | conv_dw(256, 512, 2), # 43 + 16 = 59 39 | conv_dw(512, 512, 1), # 59 + 32 = 91 40 | conv_dw(512, 512, 1), # 91 + 32 = 123 41 | conv_dw(512, 512, 1), # 123 + 32 = 155 42 | conv_dw(512, 512, 1), # 155 + 32 = 187 43 | conv_dw(512, 512, 1), # 187 + 32 = 219 44 | ) 45 | self.stage3 = nn.Sequential( 46 | conv_dw(512, 1024, 2), # 219 +3 2 = 241 47 | conv_dw(1024, 1024, 1), # 241 + 64 = 301 48 | ) 49 | 50 | self.sep = nn.Conv2d(1024, 512, kernel_size=1, bias=False) 51 | self.sep_bn = nn.BatchNorm2d(512) 52 | self.prelu = nn.PReLU(512) 53 | 54 | self.bn2 = nn.BatchNorm2d(512, eps=1e-05) 55 | self.dropout = nn.Dropout(p=dropout_keep_prob, inplace=True) 56 | self.linear = nn.Linear(512 * self.fc_scale, embedding_size) 57 | self.features = nn.BatchNorm1d(embedding_size, eps=1e-05) 58 | 59 | if pretrained: 60 | self.load_state_dict(torch.load("model_data/mobilenet_v1_backbone_weights.pth"), strict = False) 61 | else: 62 | for m in self.modules(): 63 | if isinstance(m, nn.Conv2d): 64 | nn.init.normal_(m.weight, 0, 0.1) 65 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 66 | nn.init.constant_(m.weight, 1) 67 | nn.init.constant_(m.bias, 0) 68 | 69 | def forward(self, x): 70 | x = self.stage1(x) 71 | x = self.stage2(x) 72 | x = self.stage3(x) 73 | 74 | x = self.sep(x) 75 | x = self.sep_bn(x) 76 | x = self.prelu(x) 77 | 78 | x = self.bn2(x) 79 | x = torch.flatten(x, 1) 80 | x = self.dropout(x) 81 | x = self.linear(x) 82 | x = self.features(x) 83 | return x 84 | 85 | def get_mobilenet(dropout_keep_prob, embedding_size, pretrained): 86 | return MobileNetV1(dropout_keep_prob, embedding_size, pretrained) 87 | -------------------------------------------------------------------------------- /utils/callback.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | 4 | import torch 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import scipy.signal 8 | from matplotlib import pyplot as plt 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | class LossHistory(): 12 | def __init__(self, log_dir, model, input_shape): 13 | time_str = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S') 14 | self.log_dir = os.path.join(log_dir, "loss_" + str(time_str)) 15 | self.acc = [] 16 | self.losses = [] 17 | self.val_loss = [] 18 | 19 | os.makedirs(self.log_dir) 20 | self.writer = SummaryWriter(self.log_dir) 21 | dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1]) 22 | self.writer.add_graph(model, dummy_input) 23 | 24 | def append_loss(self, epoch, acc, loss, val_loss): 25 | if not os.path.exists(self.log_dir): 26 | os.makedirs(self.log_dir) 27 | 28 | self.acc.append(acc) 29 | self.losses.append(loss) 30 | self.val_loss.append(val_loss) 31 | 32 | with open(os.path.join(self.log_dir, "epoch_acc.txt"), 'a') as f: 33 | f.write(str(acc)) 34 | f.write("\n") 35 | with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f: 36 | f.write(str(loss)) 37 | f.write("\n") 38 | with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f: 39 | f.write(str(val_loss)) 40 | f.write("\n") 41 | 42 | self.writer.add_scalar('loss', loss, epoch) 43 | self.writer.add_scalar('val_loss', val_loss, epoch) 44 | self.loss_plot() 45 | 46 | def loss_plot(self): 47 | iters = range(len(self.losses)) 48 | 49 | plt.figure() 50 | plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss') 51 | plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss') 52 | try: 53 | if len(self.losses) < 25: 54 | num = 5 55 | else: 56 | num = 15 57 | plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss') 58 | plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss') 59 | except: 60 | pass 61 | plt.grid(True) 62 | plt.xlabel('Epoch') 63 | plt.ylabel('Loss') 64 | plt.legend(loc="upper right") 65 | plt.savefig(os.path.join(self.log_dir, "epoch_loss.png")) 66 | plt.cla() 67 | plt.close("all") 68 | 69 | plt.figure() 70 | plt.plot(iters, self.acc, 'red', linewidth = 2, label='lfw acc') 71 | try: 72 | if len(self.losses) < 25: 73 | num = 5 74 | else: 75 | num = 15 76 | plt.plot(iters, scipy.signal.savgol_filter(self.acc, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth lfw acc') 77 | except: 78 | pass 79 | plt.grid(True) 80 | plt.xlabel('Epoch') 81 | plt.ylabel('Lfw Acc') 82 | plt.legend(loc="upper right") 83 | plt.savefig(os.path.join(self.log_dir, "epoch_acc.png")) 84 | plt.cla() 85 | plt.close("all") 86 | -------------------------------------------------------------------------------- /nets/arcface.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import Module, Parameter 7 | 8 | from nets.iresnet import (iresnet18, iresnet34, iresnet50, iresnet100, 9 | iresnet200) 10 | from nets.mobilefacenet import get_mbf 11 | from nets.mobilenet import get_mobilenet 12 | 13 | class Arcface_Head(Module): 14 | def __init__(self, embedding_size=128, num_classes=10575, s=64., m=0.5): 15 | super(Arcface_Head, self).__init__() 16 | self.s = s 17 | self.m = m 18 | self.weight = Parameter(torch.FloatTensor(num_classes, embedding_size)) 19 | nn.init.xavier_uniform_(self.weight) 20 | 21 | self.cos_m = math.cos(m) 22 | self.sin_m = math.sin(m) 23 | self.th = math.cos(math.pi - m) 24 | self.mm = math.sin(math.pi - m) * m 25 | 26 | def forward(self, input, label): 27 | cosine = F.linear(input, F.normalize(self.weight)) 28 | sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1)) 29 | phi = cosine * self.cos_m - sine * self.sin_m 30 | phi = torch.where(cosine.float() > self.th, phi.float(), cosine.float() - self.mm) 31 | 32 | one_hot = torch.zeros(cosine.size()).type_as(phi).long() 33 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 34 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) 35 | output *= self.s 36 | return output 37 | 38 | class Arcface(nn.Module): 39 | def __init__(self, num_classes=None, backbone="mobilefacenet", pretrained=False, mode="train"): 40 | super(Arcface, self).__init__() 41 | if backbone=="mobilefacenet": 42 | embedding_size = 128 43 | s = 32 44 | self.arcface = get_mbf(embedding_size=embedding_size, pretrained=pretrained) 45 | 46 | elif backbone=="mobilenetv1": 47 | embedding_size = 512 48 | s = 64 49 | self.arcface = get_mobilenet(dropout_keep_prob=0.5, embedding_size=embedding_size, pretrained=pretrained) 50 | 51 | elif backbone=="iresnet18": 52 | embedding_size = 512 53 | s = 64 54 | self.arcface = iresnet18(dropout_keep_prob=0.5, embedding_size=embedding_size, pretrained=pretrained) 55 | 56 | elif backbone=="iresnet34": 57 | embedding_size = 512 58 | s = 64 59 | self.arcface = iresnet34(dropout_keep_prob=0.5, embedding_size=embedding_size, pretrained=pretrained) 60 | 61 | elif backbone=="iresnet50": 62 | embedding_size = 512 63 | s = 64 64 | self.arcface = iresnet50(dropout_keep_prob=0.5, embedding_size=embedding_size, pretrained=pretrained) 65 | 66 | elif backbone=="iresnet100": 67 | embedding_size = 512 68 | s = 64 69 | self.arcface = iresnet100(dropout_keep_prob=0.5, embedding_size=embedding_size, pretrained=pretrained) 70 | 71 | elif backbone=="iresnet200": 72 | embedding_size = 512 73 | s = 64 74 | self.arcface = iresnet200(dropout_keep_prob=0.5, embedding_size=embedding_size, pretrained=pretrained) 75 | else: 76 | raise ValueError('Unsupported backbone - `{}`, Use mobilefacenet, mobilenetv1.'.format(backbone)) 77 | 78 | self.mode = mode 79 | if mode == "train": 80 | self.head = Arcface_Head(embedding_size=embedding_size, num_classes=num_classes, s=s) 81 | 82 | def forward(self, x, y = None, mode = "predict"): 83 | x = self.arcface(x) 84 | x = x.view(x.size()[0], -1) 85 | x = F.normalize(x) 86 | if mode == "predict": 87 | return x 88 | else: 89 | x = self.head(x, y) 90 | return x 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Arcface:人脸识别模型在Pytorch当中的实现 2 | --- 3 | 4 | ## 目录 5 | 1. [仓库更新 Top News](#仓库更新) 6 | 2. [相关仓库 Related code](#相关仓库) 7 | 3. [性能情况 Performance](#性能情况) 8 | 4. [所需环境 Environment](#所需环境) 9 | 5. [注意事项 Attention](#注意事项) 10 | 6. [文件下载 Download](#文件下载) 11 | 7. [预测步骤 How2predict](#预测步骤) 12 | 8. [训练步骤 How2train](#训练步骤) 13 | 9. [参考资料 Reference](#Reference) 14 | 15 | ## Top News 16 | **`2022-03`**:**创建仓库,支持不同模型训练,支持大量可调整参数,支持step、cos学习率下降法、支持adam、sgd优化器选择、支持学习率根据batch_size自适应调整、新增图片裁剪。** 17 | 18 | ## 相关仓库 19 | | 模型 | 路径 | 20 | | :----- | :----- | 21 | facenet | https://github.com/bubbliiiing/facenet-pytorch 22 | arcface | https://github.com/bubbliiiing/arcface-pytorch 23 | retinaface | https://github.com/bubbliiiing/retinaface-pytorch 24 | facenet + retinaface | https://github.com/bubbliiiing/facenet-retinaface-pytorch 25 | 26 | ## 性能情况 27 | | 训练数据集 | 权值文件名称 | 测试数据集 | 输入图片大小 | accuracy | Validation rate | 28 | | :-----: | :-----: | :------: | :------: | :------: | :------: | 29 | | CASIA-WebFace | [arcface_mobilenet.pth](https://github.com/bubbliiiing/arcface-pytorch/releases/download/v1.0/arcface_mobilenet.pth) | LFW | 112x112 | 99.11% | 0.95033+-0.02152 @ FAR=0.00133 | 30 | | CASIA-WebFace | [arcface_mobilefacenet.pth](https://github.com/bubbliiiing/arcface-pytorch/releases/download/v1.0/arcface_mobilefacenet.pth) | LFW | 112x112 | 98.78% | 0.91100+-0.01745 @ FAR=0.00100 | 31 | | CASIA-WebFace | [arcface_iresnet50.pth](https://github.com/bubbliiiing/arcface-pytorch/releases/download/v1.0/arcface_iresnet50.pth) | LFW | 112x112 | 98.93% | 0.93100+-0.01422 @ FAR=0.00133 | 32 | 33 | (arcface_mobilenet的准确度相比其它较高是因为使用了backbone的预训练权重,正在努力调参中。) 34 | 35 | ## 所需环境 36 | pytorch==1.2.0 37 | 38 | ## 文件下载 39 | 已经训练好的权值可以在百度网盘下载。 40 | 链接: https://pan.baidu.com/s/1ElJlfmMwOGX699MsgLY8qA 提取码: z3rq 41 | 42 | 训练用的CASIA-WebFaces数据集以及评估用的LFW数据集可以在百度网盘下载。 43 | 链接: https://pan.baidu.com/s/1qMxFR8H_ih0xmY-rKgRejw 提取码: bcrq 44 | 45 | ## 预测步骤 46 | ### a、使用预训练权重 47 | 1. 下载完库后解压,可直接运行predict.py输入: 48 | ```python 49 | img\1_001.jpg 50 | img\1_002.jpg 51 | ``` 52 | 2. 也可以在百度网盘下载权值,放入model_data,修改arcface.py文件的model_path后,输入: 53 | ```python 54 | img\1_001.jpg 55 | img\1_002.jpg 56 | ``` 57 | ### b、使用自己训练的权重 58 | 1. 按照训练步骤训练。 59 | 2. 在arcface.py文件里面,在如下部分修改model_path和backbone使其对应训练好的文件;**model_path对应logs文件夹下面的权值文件,backbone对应主干特征提取网络**。 60 | ```python 61 | _defaults = { 62 | #--------------------------------------------------------------------------# 63 | # 使用自己训练好的模型进行预测要修改model_path,指向logs文件夹下的权值文件 64 | # 训练好后logs文件夹下存在多个权值文件,选择验证集损失较低的即可。 65 | # 验证集损失较低不代表准确度较高,仅代表该权值在验证集上泛化性能较好。 66 | #--------------------------------------------------------------------------# 67 | "model_path" : "model_data/arcface_mobilefacenet.pth", 68 | #--------------------------------------------------------------------------# 69 | # 输入图片的大小。 70 | #--------------------------------------------------------------------------# 71 | "input_shape" : [112, 112, 3], 72 | #--------------------------------------------------------------------------# 73 | # 所使用到的主干特征提取网络,与训练的相同 74 | #--------------------------------------------------------------------------# 75 | "backbone" : "arcface_mobilefacenet", 76 | #--------------------------------------# 77 | # 是否进行不失真的resize 78 | #--------------------------------------# 79 | "letterbox_image" : True, 80 | #--------------------------------------# 81 | # 是否使用Cuda 82 | # 没有GPU可以设置成False 83 | #--------------------------------------# 84 | "cuda" : True, 85 | } 86 | ``` 87 | 3. 运行predict.py,输入 88 | ```python 89 | img\1_001.jpg 90 | img\1_002.jpg 91 | ``` 92 | 93 | ## 训练步骤 94 | 1. 本文使用如下格式进行训练。 95 | ``` 96 | |-datasets 97 | |-people0 98 | |-123.jpg 99 | |-234.jpg 100 | |-people1 101 | |-345.jpg 102 | |-456.jpg 103 | |-... 104 | ``` 105 | 2. 下载好数据集,将训练用的CASIA-WebFaces数据集以及评估用的LFW数据集,解压后放在根目录。 106 | 3. 在训练前利用txt_annotation.py文件生成对应的cls_train.txt。 107 | 4. 利用train.py训练模型,训练前,根据自己的需要选择backbone,model_path和backbone一定要对应。 108 | 5. 运行train.py即可开始训练。 109 | 110 | ## 评估步骤 111 | 1. 下载好评估数据集,将评估用的LFW数据集,解压后放在根目录 112 | 2. 在eval_LFW.py设置使用的主干特征提取网络和网络权值。 113 | 3. 运行eval_LFW.py来进行模型准确率评估。 114 | 115 | ## Reference 116 | https://github.com/deepinsight/insightface 117 | https://github.com/timesler/facenet-pytorch 118 | 119 | -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | import torchvision.datasets as datasets 7 | from PIL import Image 8 | 9 | from .utils import cvtColor, preprocess_input, resize_image 10 | 11 | 12 | class FacenetDataset(data.Dataset): 13 | def __init__(self, input_shape, lines, random): 14 | self.input_shape = input_shape 15 | self.lines = lines 16 | self.random = random 17 | 18 | def __len__(self): 19 | return len(self.lines) 20 | 21 | def rand(self, a=0, b=1): 22 | return np.random.rand()*(b-a) + a 23 | 24 | def __getitem__(self, index): 25 | annotation_path = self.lines[index].split(';')[1].split()[0] 26 | y = int(self.lines[index].split(';')[0]) 27 | 28 | image = cvtColor(Image.open(annotation_path)) 29 | #------------------------------------------# 30 | # 翻转图像 31 | #------------------------------------------# 32 | if self.rand()<.5 and self.random: 33 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 34 | image = resize_image(image, [self.input_shape[1], self.input_shape[0]], letterbox_image = True) 35 | 36 | image = np.transpose(preprocess_input(np.array(image, dtype='float32')), (2, 0, 1)) 37 | return image, y 38 | 39 | def dataset_collate(batch): 40 | images = [] 41 | targets = [] 42 | for image, y in batch: 43 | images.append(image) 44 | targets.append(y) 45 | images = torch.from_numpy(np.array(images)).type(torch.FloatTensor) 46 | targets = torch.from_numpy(np.array(targets)).long() 47 | return images, targets 48 | 49 | class LFWDataset(datasets.ImageFolder): 50 | def __init__(self, dir, pairs_path, image_size, transform=None): 51 | super(LFWDataset, self).__init__(dir,transform) 52 | self.image_size = image_size 53 | self.pairs_path = pairs_path 54 | self.validation_images = self.get_lfw_paths(dir) 55 | 56 | def read_lfw_pairs(self,pairs_filename): 57 | pairs = [] 58 | with open(pairs_filename, 'r') as f: 59 | for line in f.readlines()[1:]: 60 | pair = line.strip().split() 61 | pairs.append(pair) 62 | return np.array(pairs) 63 | 64 | def get_lfw_paths(self,lfw_dir,file_ext="jpg"): 65 | 66 | pairs = self.read_lfw_pairs(self.pairs_path) 67 | 68 | nrof_skipped_pairs = 0 69 | path_list = [] 70 | issame_list = [] 71 | 72 | for i in range(len(pairs)): 73 | #for pair in pairs: 74 | pair = pairs[i] 75 | if len(pair) == 3: 76 | path0 = os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1])+'.'+file_ext) 77 | path1 = os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[2])+'.'+file_ext) 78 | issame = True 79 | elif len(pair) == 4: 80 | path0 = os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1])+'.'+file_ext) 81 | path1 = os.path.join(lfw_dir, pair[2], pair[2] + '_' + '%04d' % int(pair[3])+'.'+file_ext) 82 | issame = False 83 | if os.path.exists(path0) and os.path.exists(path1): # Only add the pair if both paths exist 84 | path_list.append((path0,path1,issame)) 85 | issame_list.append(issame) 86 | else: 87 | nrof_skipped_pairs += 1 88 | if nrof_skipped_pairs>0: 89 | print('Skipped %d image pairs' % nrof_skipped_pairs) 90 | 91 | return path_list 92 | 93 | def __getitem__(self, index): 94 | (path_1, path_2, issame) = self.validation_images[index] 95 | image1, image2 = Image.open(path_1), Image.open(path_2) 96 | 97 | image1 = resize_image(image1, [self.image_size[1], self.image_size[0]], letterbox_image = True) 98 | image2 = resize_image(image2, [self.image_size[1], self.image_size[0]], letterbox_image = True) 99 | 100 | image1, image2 = np.transpose(preprocess_input(np.array(image1, np.float32)),[2, 0, 1]), np.transpose(preprocess_input(np.array(image2, np.float32)),[2, 0, 1]) 101 | 102 | return image1, image2, issame 103 | 104 | def __len__(self): 105 | return len(self.validation_images) 106 | -------------------------------------------------------------------------------- /utils/utils_fit.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.distributed as dist 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | 10 | from .utils import get_lr 11 | from .utils_metrics import evaluate 12 | 13 | 14 | def fit_one_epoch(model_train, model, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, test_loader, lfw_eval_flag, fp16, scaler, save_period, save_dir, local_rank=0): 15 | total_loss = 0 16 | total_accuracy = 0 17 | 18 | val_total_loss = 0 19 | val_total_accuracy = 0 20 | 21 | if local_rank == 0: 22 | print('Start Train') 23 | pbar = tqdm(total=epoch_step,desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 24 | model_train.train() 25 | for iteration, batch in enumerate(gen): 26 | if iteration >= epoch_step: 27 | break 28 | images, labels = batch 29 | with torch.no_grad(): 30 | if cuda: 31 | images = images.cuda(local_rank) 32 | labels = labels.cuda(local_rank) 33 | 34 | #----------------------# 35 | # 清零梯度 36 | #----------------------# 37 | optimizer.zero_grad() 38 | if not fp16: 39 | outputs = model_train(images, labels, mode="train") 40 | loss = nn.NLLLoss()(F.log_softmax(outputs, -1), labels) 41 | 42 | loss.backward() 43 | optimizer.step() 44 | else: 45 | from torch.cuda.amp import autocast 46 | with autocast(): 47 | outputs = model_train(images, labels, mode="train") 48 | loss = nn.NLLLoss()(F.log_softmax(outputs, -1), labels) 49 | #----------------------# 50 | # 反向传播 51 | #----------------------# 52 | scaler.scale(loss).backward() 53 | scaler.step(optimizer) 54 | scaler.update() 55 | 56 | with torch.no_grad(): 57 | accuracy = torch.mean((torch.argmax(F.softmax(outputs, dim=-1), dim=-1) == labels).type(torch.FloatTensor)) 58 | 59 | total_loss += loss.item() 60 | total_accuracy += accuracy.item() 61 | 62 | if local_rank == 0: 63 | pbar.set_postfix(**{'total_loss': total_loss / (iteration + 1), 64 | 'accuracy' : total_accuracy / (iteration + 1), 65 | 'lr' : get_lr(optimizer)}) 66 | pbar.update(1) 67 | 68 | if local_rank == 0: 69 | pbar.close() 70 | print('Finish Train') 71 | print('Start Validation') 72 | pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}',postfix=dict,mininterval=0.3) 73 | model_train.eval() 74 | for iteration, batch in enumerate(gen_val): 75 | if iteration >= epoch_step_val: 76 | break 77 | images, labels = batch 78 | with torch.no_grad(): 79 | if cuda: 80 | images = images.cuda(local_rank) 81 | labels = labels.cuda(local_rank) 82 | 83 | optimizer.zero_grad() 84 | outputs = model_train(images, labels, mode="train") 85 | loss = nn.NLLLoss()(F.log_softmax(outputs, -1), labels) 86 | 87 | accuracy = torch.mean((torch.argmax(F.softmax(outputs, dim=-1), dim=-1) == labels).type(torch.FloatTensor)) 88 | 89 | val_total_loss += loss.item() 90 | val_total_accuracy += accuracy.item() 91 | 92 | if local_rank == 0: 93 | pbar.set_postfix(**{'total_loss': val_total_loss / (iteration + 1), 94 | 'accuracy' : val_total_accuracy / (iteration + 1), 95 | 'lr' : get_lr(optimizer)}) 96 | pbar.update(1) 97 | 98 | if lfw_eval_flag: 99 | print("开始进行LFW数据集的验证。") 100 | labels, distances = [], [] 101 | for _, (data_a, data_p, label) in enumerate(test_loader): 102 | with torch.no_grad(): 103 | data_a, data_p = data_a.type(torch.FloatTensor), data_p.type(torch.FloatTensor) 104 | if cuda: 105 | data_a, data_p = data_a.cuda(local_rank), data_p.cuda(local_rank) 106 | 107 | out_a, out_p = model_train(data_a), model_train(data_p) 108 | dists = torch.sqrt(torch.sum((out_a - out_p) ** 2, 1)) 109 | distances.append(dists.data.cpu().numpy()) 110 | labels.append(label.data.cpu().numpy()) 111 | 112 | labels = np.array([sublabel for label in labels for sublabel in label]) 113 | distances = np.array([subdist for dist in distances for subdist in dist]) 114 | _, _, accuracy, _, _, _, _ = evaluate(distances,labels) 115 | 116 | if local_rank == 0: 117 | pbar.close() 118 | print('Finish Validation') 119 | 120 | if lfw_eval_flag: 121 | print('LFW_Accuracy: %2.5f+-%2.5f' % (np.mean(accuracy), np.std(accuracy))) 122 | 123 | loss_history.append_loss(epoch, np.mean(accuracy) if lfw_eval_flag else total_accuracy / epoch_step, total_loss / epoch_step, val_total_loss / epoch_step_val) 124 | print('Total Loss: %.4f' % (total_loss / epoch_step)) 125 | if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch: 126 | torch.save(model.state_dict(), os.path.join(save_dir, 'ep%03d-loss%.3f-val_loss%.3f.pth'%((epoch+1), total_loss / epoch_step, val_total_loss / epoch_step_val))) 127 | -------------------------------------------------------------------------------- /nets/mobilefacenet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import BatchNorm2d, Conv2d, Module, PReLU, Sequential 3 | 4 | class Flatten(Module): 5 | def forward(self, input): 6 | return input.view(input.size(0), -1) 7 | 8 | class Linear_block(Module): 9 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 10 | super(Linear_block, self).__init__() 11 | self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False) 12 | self.bn = BatchNorm2d(out_c) 13 | def forward(self, x): 14 | x = self.conv(x) 15 | x = self.bn(x) 16 | return x 17 | 18 | class Residual_Block(Module): 19 | def __init__(self, in_c, out_c, residual = False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1): 20 | super(Residual_Block, self).__init__() 21 | self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) 22 | self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride) 23 | self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) 24 | self.residual = residual 25 | def forward(self, x): 26 | if self.residual: 27 | short_cut = x 28 | x = self.conv(x) 29 | x = self.conv_dw(x) 30 | x = self.project(x) 31 | if self.residual: 32 | output = short_cut + x 33 | else: 34 | output = x 35 | return output 36 | 37 | class Residual(Module): 38 | def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)): 39 | super(Residual, self).__init__() 40 | modules = [] 41 | for _ in range(num_block): 42 | modules.append(Residual_Block(c, c, residual=True, kernel=kernel, padding=padding, stride=stride, groups=groups)) 43 | self.model = Sequential(*modules) 44 | def forward(self, x): 45 | return self.model(x) 46 | 47 | class Conv_block(Module): 48 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 49 | super(Conv_block, self).__init__() 50 | self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False) 51 | self.bn = BatchNorm2d(out_c) 52 | self.prelu = PReLU(out_c) 53 | def forward(self, x): 54 | x = self.conv(x) 55 | x = self.bn(x) 56 | x = self.prelu(x) 57 | return x 58 | 59 | class MobileFaceNet(Module): 60 | def __init__(self, embedding_size): 61 | super(MobileFaceNet, self).__init__() 62 | # 112,112,3 -> 56,56,64 63 | self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1)) 64 | 65 | # 56,56,64 -> 56,56,64 66 | self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64) 67 | 68 | # 56,56,64 -> 28,28,64 69 | self.conv_23 = Residual_Block(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128) 70 | self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)) 71 | 72 | # 28,28,64 -> 14,14,128 73 | self.conv_34 = Residual_Block(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256) 74 | self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)) 75 | 76 | # 14,14,128 -> 7,7,128 77 | self.conv_45 = Residual_Block(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512) 78 | self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)) 79 | 80 | self.sep = nn.Conv2d(128, 512, kernel_size=1, bias=False) 81 | self.sep_bn = nn.BatchNorm2d(512) 82 | self.prelu = nn.PReLU(512) 83 | 84 | self.GDC_dw = nn.Conv2d(512, 512, kernel_size=7, bias=False, groups=512) 85 | self.GDC_bn = nn.BatchNorm2d(512) 86 | 87 | self.features = nn.Conv2d(512, embedding_size, kernel_size=1, bias=False) 88 | self.last_bn = nn.BatchNorm2d(embedding_size) 89 | 90 | self._initialize_weights() 91 | 92 | def _initialize_weights(self): 93 | for m in self.modules(): 94 | if isinstance(m, nn.Conv2d): 95 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 96 | if m.bias is not None: 97 | m.bias.data.zero_() 98 | elif isinstance(m, nn.BatchNorm2d): 99 | m.weight.data.fill_(1) 100 | m.bias.data.zero_() 101 | elif isinstance(m, nn.Linear): 102 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 103 | if m.bias is not None: 104 | m.bias.data.zero_() 105 | 106 | def forward(self, x): 107 | x = self.conv1(x) 108 | x = self.conv2_dw(x) 109 | x = self.conv_23(x) 110 | x = self.conv_3(x) 111 | x = self.conv_34(x) 112 | x = self.conv_4(x) 113 | x = self.conv_45(x) 114 | x = self.conv_5(x) 115 | 116 | x = self.sep(x) 117 | x = self.sep_bn(x) 118 | x = self.prelu(x) 119 | 120 | x = self.GDC_dw(x) 121 | x = self.GDC_bn(x) 122 | 123 | x = self.features(x) 124 | x = self.last_bn(x) 125 | return x 126 | 127 | 128 | def get_mbf(embedding_size, pretrained): 129 | if pretrained: 130 | raise ValueError("No pretrained model for mobilefacenet") 131 | return MobileFaceNet(embedding_size) 132 | -------------------------------------------------------------------------------- /arcface.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | import torch.backends.cudnn as cudnn 5 | 6 | from nets.arcface import Arcface as arcface 7 | from utils.utils import preprocess_input, resize_image, show_config 8 | 9 | 10 | class Arcface(object): 11 | _defaults = { 12 | #--------------------------------------------------------------------------# 13 | # 使用自己训练好的模型进行预测要修改model_path,指向logs文件夹下的权值文件 14 | # 训练好后logs文件夹下存在多个权值文件,选择验证集损失较低的即可。 15 | # 验证集损失较低不代表准确度较高,仅代表该权值在验证集上泛化性能较好。 16 | #--------------------------------------------------------------------------# 17 | "model_path" : "model_data/arcface_mobilefacenet.pth", 18 | #-------------------------------------------# 19 | # 输入图片的大小。 20 | #-------------------------------------------# 21 | "input_shape" : [112, 112, 3], 22 | #-------------------------------------------# 23 | # 所使用到的主干特征提取网络,与训练的相同 24 | # mobilefacenet 25 | # mobilenetv1 26 | # iresnet18 27 | # iresnet34 28 | # iresnet50 29 | # iresnet100 30 | # iresnet200 31 | #-------------------------------------------# 32 | "backbone" : "mobilefacenet", 33 | #-------------------------------------------# 34 | # 是否进行不失真的resize 35 | #-------------------------------------------# 36 | "letterbox_image" : True, 37 | #-------------------------------------------# 38 | # 是否使用Cuda 39 | # 没有GPU可以设置成False 40 | #-------------------------------------------# 41 | "cuda" : True, 42 | } 43 | 44 | @classmethod 45 | def get_defaults(cls, n): 46 | if n in cls._defaults: 47 | return cls._defaults[n] 48 | else: 49 | return "Unrecognized attribute name '" + n + "'" 50 | 51 | #---------------------------------------------------# 52 | # 初始化Arcface 53 | #---------------------------------------------------# 54 | def __init__(self, **kwargs): 55 | self.__dict__.update(self._defaults) 56 | for name, value in kwargs.items(): 57 | setattr(self, name, value) 58 | 59 | self.generate() 60 | 61 | show_config(**self._defaults) 62 | 63 | def generate(self): 64 | #---------------------------------------------------# 65 | # 载入模型与权值 66 | #---------------------------------------------------# 67 | print('Loading weights into state dict...') 68 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 69 | self.net = arcface(backbone=self.backbone, mode="predict").eval() 70 | self.net.load_state_dict(torch.load(self.model_path, map_location=device), strict=False) 71 | print('{} model loaded.'.format(self.model_path)) 72 | 73 | if self.cuda: 74 | self.net = torch.nn.DataParallel(self.net) 75 | cudnn.benchmark = True 76 | self.net = self.net.cuda() 77 | 78 | #---------------------------------------------------# 79 | # 检测图片 80 | #---------------------------------------------------# 81 | def detect_image(self, image_1, image_2): 82 | #---------------------------------------------------# 83 | # 图片预处理,归一化 84 | #---------------------------------------------------# 85 | with torch.no_grad(): 86 | image_1 = resize_image(image_1, [self.input_shape[1], self.input_shape[0]], letterbox_image=self.letterbox_image) 87 | image_2 = resize_image(image_2, [self.input_shape[1], self.input_shape[0]], letterbox_image=self.letterbox_image) 88 | 89 | photo_1 = torch.from_numpy(np.expand_dims(np.transpose(preprocess_input(np.array(image_1, np.float32)), (2, 0, 1)), 0)) 90 | photo_2 = torch.from_numpy(np.expand_dims(np.transpose(preprocess_input(np.array(image_2, np.float32)), (2, 0, 1)), 0)) 91 | 92 | if self.cuda: 93 | photo_1 = photo_1.cuda() 94 | photo_2 = photo_2.cuda() 95 | 96 | #---------------------------------------------------# 97 | # 图片传入网络进行预测 98 | #---------------------------------------------------# 99 | output1 = self.net(photo_1).cpu().numpy() 100 | output2 = self.net(photo_2).cpu().numpy() 101 | 102 | #---------------------------------------------------# 103 | # 计算二者之间的距离 104 | #---------------------------------------------------# 105 | l1 = np.linalg.norm(output1 - output2, axis=1) 106 | 107 | plt.subplot(1, 2, 1) 108 | plt.imshow(np.array(image_1)) 109 | 110 | plt.subplot(1, 2, 2) 111 | plt.imshow(np.array(image_2)) 112 | plt.text(-12, -12, 'Distance:%.3f' % l1, ha='center', va= 'bottom',fontsize=11) 113 | plt.show() 114 | return l1 115 | 116 | def get_FPS(self, image, test_interval): 117 | #---------------------------------------------------# 118 | # 对图片进行不失真的resize 119 | #---------------------------------------------------# 120 | image_data = resize_image(image, [self.input_shape[1], self.input_shape[0]], self.letterbox_image) 121 | #---------------------------------------------------------# 122 | # 归一化+添加上batch_size维度 123 | #---------------------------------------------------------# 124 | image_data = torch.from_numpy(np.expand_dims(np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)) 125 | with torch.no_grad(): 126 | #---------------------------------------------------# 127 | # 图片传入网络进行预测 128 | #---------------------------------------------------# 129 | preds = self.net(image_data).cpu().numpy() 130 | 131 | import time 132 | t1 = time.time() 133 | for _ in range(test_interval): 134 | with torch.no_grad(): 135 | #---------------------------------------------------# 136 | # 图片传入网络进行预测 137 | #---------------------------------------------------# 138 | preds = self.net(image_data).cpu().numpy() 139 | t2 = time.time() 140 | tact_time = (t2 - t1) / test_interval 141 | return tact_time 142 | -------------------------------------------------------------------------------- /utils/utils_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from scipy import interpolate 4 | from sklearn.model_selection import KFold 5 | from tqdm import tqdm 6 | 7 | def evaluate(distances, labels, nrof_folds=10): 8 | # Calculate evaluation metrics 9 | thresholds = np.arange(0, 4, 0.01) 10 | tpr, fpr, accuracy, best_thresholds = calculate_roc(thresholds, distances, 11 | labels, nrof_folds=nrof_folds) 12 | thresholds = np.arange(0, 4, 0.001) 13 | val, val_std, far = calculate_val(thresholds, distances, 14 | labels, 1e-3, nrof_folds=nrof_folds) 15 | return tpr, fpr, accuracy, val, val_std, far, best_thresholds 16 | 17 | def calculate_roc(thresholds, distances, labels, nrof_folds=10): 18 | 19 | nrof_pairs = min(len(labels), len(distances)) 20 | nrof_thresholds = len(thresholds) 21 | k_fold = KFold(n_splits=nrof_folds, shuffle=False) 22 | 23 | tprs = np.zeros((nrof_folds,nrof_thresholds)) 24 | fprs = np.zeros((nrof_folds,nrof_thresholds)) 25 | accuracy = np.zeros((nrof_folds)) 26 | 27 | indices = np.arange(nrof_pairs) 28 | 29 | for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): 30 | 31 | # Find the best threshold for the fold 32 | acc_train = np.zeros((nrof_thresholds)) 33 | for threshold_idx, threshold in enumerate(thresholds): 34 | _, _, acc_train[threshold_idx] = calculate_accuracy(threshold, distances[train_set], labels[train_set]) 35 | 36 | best_threshold_index = np.argmax(acc_train) 37 | for threshold_idx, threshold in enumerate(thresholds): 38 | tprs[fold_idx,threshold_idx], fprs[fold_idx,threshold_idx], _ = calculate_accuracy(threshold, distances[test_set], labels[test_set]) 39 | _, _, accuracy[fold_idx] = calculate_accuracy(thresholds[best_threshold_index], distances[test_set], labels[test_set]) 40 | tpr = np.mean(tprs,0) 41 | fpr = np.mean(fprs,0) 42 | return tpr, fpr, accuracy, thresholds[best_threshold_index] 43 | 44 | def calculate_accuracy(threshold, dist, actual_issame): 45 | predict_issame = np.less(dist, threshold) 46 | tp = np.sum(np.logical_and(predict_issame, actual_issame)) 47 | fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame))) 48 | tn = np.sum(np.logical_and(np.logical_not(predict_issame), np.logical_not(actual_issame))) 49 | fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame)) 50 | 51 | tpr = 0 if (tp+fn==0) else float(tp) / float(tp+fn) 52 | fpr = 0 if (fp+tn==0) else float(fp) / float(fp+tn) 53 | acc = float(tp+tn)/dist.size 54 | return tpr, fpr, acc 55 | 56 | def calculate_val(thresholds, distances, labels, far_target=1e-3, nrof_folds=10): 57 | nrof_pairs = min(len(labels), len(distances)) 58 | nrof_thresholds = len(thresholds) 59 | k_fold = KFold(n_splits=nrof_folds, shuffle=False) 60 | 61 | val = np.zeros(nrof_folds) 62 | far = np.zeros(nrof_folds) 63 | 64 | indices = np.arange(nrof_pairs) 65 | 66 | for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): 67 | # Find the threshold that gives FAR = far_target 68 | far_train = np.zeros(nrof_thresholds) 69 | for threshold_idx, threshold in enumerate(thresholds): 70 | _, far_train[threshold_idx] = calculate_val_far(threshold, distances[train_set], labels[train_set]) 71 | if np.max(far_train)>=far_target: 72 | f = interpolate.interp1d(far_train, thresholds, kind='slinear') 73 | threshold = f(far_target) 74 | else: 75 | threshold = 0.0 76 | 77 | val[fold_idx], far[fold_idx] = calculate_val_far(threshold, distances[test_set], labels[test_set]) 78 | 79 | val_mean = np.mean(val) 80 | far_mean = np.mean(far) 81 | val_std = np.std(val) 82 | return val_mean, val_std, far_mean 83 | 84 | def calculate_val_far(threshold, dist, actual_issame): 85 | predict_issame = np.less(dist, threshold) 86 | true_accept = np.sum(np.logical_and(predict_issame, actual_issame)) 87 | false_accept = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame))) 88 | n_same = np.sum(actual_issame) 89 | n_diff = np.sum(np.logical_not(actual_issame)) 90 | if n_diff == 0: 91 | n_diff = 1 92 | if n_same == 0: 93 | return 0,0 94 | val = float(true_accept) / float(n_same) 95 | far = float(false_accept) / float(n_diff) 96 | return val, far 97 | 98 | def test(test_loader, model, png_save_path, log_interval, batch_size, cuda): 99 | labels, distances = [], [] 100 | pbar = tqdm(enumerate(test_loader)) 101 | for batch_idx, (data_a, data_p, label) in pbar: 102 | with torch.no_grad(): 103 | #--------------------------------------# 104 | # 加载数据,设置成cuda 105 | #--------------------------------------# 106 | data_a, data_p = data_a.type(torch.FloatTensor), data_p.type(torch.FloatTensor) 107 | if cuda: 108 | data_a, data_p = data_a.cuda(), data_p.cuda() 109 | #--------------------------------------# 110 | # 传入模型预测,获得预测结果 111 | # 获得预测结果的距离 112 | #--------------------------------------# 113 | out_a, out_p = model(data_a), model(data_p) 114 | dists = torch.sqrt(torch.sum((out_a - out_p) ** 2, 1)) 115 | 116 | #--------------------------------------# 117 | # 将结果添加进列表中 118 | #--------------------------------------# 119 | distances.append(dists.data.cpu().numpy()) 120 | labels.append(label.data.cpu().numpy()) 121 | 122 | #--------------------------------------# 123 | # 打印 124 | #--------------------------------------# 125 | if batch_idx % log_interval == 0: 126 | pbar.set_description('Test Epoch: [{}/{} ({:.0f}%)]'.format( 127 | batch_idx * batch_size, len(test_loader.dataset), 128 | 100. * batch_idx / len(test_loader))) 129 | 130 | #--------------------------------------# 131 | # 转换成numpy 132 | #--------------------------------------# 133 | labels = np.array([sublabel for label in labels for sublabel in label]) 134 | distances = np.array([subdist for dist in distances for subdist in dist]) 135 | 136 | tpr, fpr, accuracy, val, val_std, far, best_thresholds = evaluate(distances,labels) 137 | print('Accuracy: %2.5f+-%2.5f' % (np.mean(accuracy), np.std(accuracy))) 138 | print('Best_thresholds: %2.5f' % best_thresholds) 139 | print('Validation rate: %2.5f+-%2.5f @ FAR=%2.5f' % (val, val_std, far)) 140 | plot_roc(fpr, tpr, figure_name = png_save_path) 141 | 142 | def plot_roc(fpr, tpr, figure_name = "roc.png"): 143 | import matplotlib.pyplot as plt 144 | from sklearn.metrics import auc, roc_curve 145 | roc_auc = auc(fpr, tpr) 146 | fig = plt.figure() 147 | lw = 2 148 | plt.plot(fpr, tpr, color='darkorange', 149 | lw=lw, label='ROC curve (area = %0.2f)' % roc_auc) 150 | plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--') 151 | plt.xlim([0.0, 1.0]) 152 | plt.ylim([0.0, 1.05]) 153 | plt.xlabel('False Positive Rate') 154 | plt.ylabel('True Positive Rate') 155 | plt.title('Receiver operating characteristic') 156 | plt.legend(loc="lower right") 157 | fig.savefig(figure_name, dpi=fig.dpi) 158 | -------------------------------------------------------------------------------- /nets/iresnet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | 5 | __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200'] 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 9 | return nn.Conv2d(in_planes, 10 | out_planes, 11 | kernel_size=3, 12 | stride=stride, 13 | padding=dilation, 14 | groups=groups, 15 | bias=False, 16 | dilation=dilation) 17 | 18 | 19 | def conv1x1(in_planes, out_planes, stride=1): 20 | return nn.Conv2d(in_planes, 21 | out_planes, 22 | kernel_size=1, 23 | stride=stride, 24 | bias=False) 25 | 26 | 27 | class IBasicBlock(nn.Module): 28 | expansion = 1 29 | def __init__(self, inplanes, planes, stride=1, downsample=None, 30 | groups=1, base_width=64, dilation=1): 31 | super(IBasicBlock, self).__init__() 32 | if groups != 1 or base_width != 64: 33 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 34 | if dilation > 1: 35 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 36 | self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) 37 | self.conv1 = conv3x3(inplanes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) 39 | self.prelu = nn.PReLU(planes) 40 | self.conv2 = conv3x3(planes, planes, stride) 41 | self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | identity = x 47 | out = self.bn1(x) 48 | out = self.conv1(out) 49 | out = self.bn2(out) 50 | out = self.prelu(out) 51 | out = self.conv2(out) 52 | out = self.bn3(out) 53 | if self.downsample is not None: 54 | identity = self.downsample(x) 55 | out += identity 56 | return out 57 | 58 | 59 | class IResNet(nn.Module): 60 | fc_scale = 7 * 7 61 | def __init__(self, 62 | block, layers, dropout_keep_prob=0, embedding_size=512, zero_init_residual=False, 63 | groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): 64 | super(IResNet, self).__init__() 65 | self.fp16 = fp16 66 | self.inplanes = 64 67 | self.dilation = 1 68 | if replace_stride_with_dilation is None: 69 | replace_stride_with_dilation = [False, False, False] 70 | if len(replace_stride_with_dilation) != 3: 71 | raise ValueError("replace_stride_with_dilation should be None " 72 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 73 | self.groups = groups 74 | self.base_width = width_per_group 75 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 76 | self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) 77 | self.prelu = nn.PReLU(self.inplanes) 78 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 79 | self.layer2 = self._make_layer(block, 80 | 128, 81 | layers[1], 82 | stride=2, 83 | dilate=replace_stride_with_dilation[0]) 84 | self.layer3 = self._make_layer(block, 85 | 256, 86 | layers[2], 87 | stride=2, 88 | dilate=replace_stride_with_dilation[1]) 89 | self.layer4 = self._make_layer(block, 90 | 512, 91 | layers[3], 92 | stride=2, 93 | dilate=replace_stride_with_dilation[2]) 94 | self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) 95 | self.dropout = nn.Dropout(p=dropout_keep_prob, inplace=True) 96 | self.fc = nn.Linear(512 * block.expansion * self.fc_scale, embedding_size) 97 | self.features = nn.BatchNorm1d(embedding_size, eps=1e-05) 98 | nn.init.constant_(self.features.weight, 1.0) 99 | self.features.weight.requires_grad = False 100 | 101 | for m in self.modules(): 102 | if isinstance(m, nn.Conv2d): 103 | nn.init.normal_(m.weight, 0, 0.1) 104 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 105 | nn.init.constant_(m.weight, 1) 106 | nn.init.constant_(m.bias, 0) 107 | 108 | if zero_init_residual: 109 | for m in self.modules(): 110 | if isinstance(m, IBasicBlock): 111 | nn.init.constant_(m.bn2.weight, 0) 112 | 113 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 114 | downsample = None 115 | previous_dilation = self.dilation 116 | if dilate: 117 | self.dilation *= stride 118 | stride = 1 119 | if stride != 1 or self.inplanes != planes * block.expansion: 120 | downsample = nn.Sequential( 121 | conv1x1(self.inplanes, planes * block.expansion, stride), 122 | nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), 123 | ) 124 | layers = [] 125 | layers.append( 126 | block(self.inplanes, planes, stride, downsample, self.groups, 127 | self.base_width, previous_dilation)) 128 | self.inplanes = planes * block.expansion 129 | for _ in range(1, blocks): 130 | layers.append( 131 | block(self.inplanes, 132 | planes, 133 | groups=self.groups, 134 | base_width=self.base_width, 135 | dilation=self.dilation)) 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def forward(self, x): 140 | x = self.conv1(x) 141 | x = self.bn1(x) 142 | x = self.prelu(x) 143 | x = self.layer1(x) 144 | x = self.layer2(x) 145 | x = self.layer3(x) 146 | x = self.layer4(x) 147 | x = self.bn2(x) 148 | x = torch.flatten(x, 1) 149 | x = self.dropout(x) 150 | x = self.fc(x) 151 | x = self.features(x) 152 | return x 153 | 154 | 155 | def _iresnet(arch, block, layers, pretrained, progress, **kwargs): 156 | model = IResNet(block, layers, **kwargs) 157 | if pretrained: 158 | raise ValueError("No pretrained model for iresnet") 159 | return model 160 | 161 | 162 | def iresnet18(pretrained=False, progress=True, **kwargs): 163 | return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, 164 | progress, **kwargs) 165 | 166 | 167 | def iresnet34(pretrained=False, progress=True, **kwargs): 168 | return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, 169 | progress, **kwargs) 170 | 171 | 172 | def iresnet50(pretrained=False, progress=True, **kwargs): 173 | return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, 174 | progress, **kwargs) 175 | 176 | 177 | def iresnet100(pretrained=False, progress=True, **kwargs): 178 | return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, 179 | progress, **kwargs) 180 | 181 | 182 | def iresnet200(pretrained=False, progress=True, **kwargs): 183 | return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, 184 | progress, **kwargs) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | 4 | import numpy as np 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import torch.distributed as dist 8 | import torch.optim as optim 9 | from torch.utils.data import DataLoader 10 | 11 | from nets.arcface import Arcface 12 | from nets.arcface_training import get_lr_scheduler, set_optimizer_lr 13 | from utils.callback import LossHistory 14 | from utils.dataloader import FacenetDataset, LFWDataset, dataset_collate 15 | from utils.utils import (get_num_classes, seed_everything, show_config, 16 | worker_init_fn) 17 | from utils.utils_fit import fit_one_epoch 18 | 19 | if __name__ == "__main__": 20 | #-------------------------------# 21 | # 是否使用Cuda 22 | # 没有GPU可以设置成False 23 | #-------------------------------# 24 | Cuda = True 25 | #----------------------------------------------# 26 | # Seed 用于固定随机种子 27 | # 使得每次独立训练都可以获得一样的结果 28 | #----------------------------------------------# 29 | seed = 11 30 | #---------------------------------------------------------------------# 31 | # distributed 用于指定是否使用单机多卡分布式运行 32 | # 终端指令仅支持Ubuntu。CUDA_VISIBLE_DEVICES用于在Ubuntu下指定显卡。 33 | # Windows系统下默认使用DP模式调用所有显卡,不支持DDP。 34 | # DP模式: 35 | # 设置 distributed = False 36 | # 在终端中输入 CUDA_VISIBLE_DEVICES=0,1 python train.py 37 | # DDP模式: 38 | # 设置 distributed = True 39 | # 在终端中输入 CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py 40 | #---------------------------------------------------------------------# 41 | distributed = False 42 | #---------------------------------------------------------------------# 43 | # sync_bn 是否使用sync_bn,DDP模式多卡可用 44 | #---------------------------------------------------------------------# 45 | sync_bn = False 46 | #---------------------------------------------------------------------# 47 | # fp16 是否使用混合精度训练 48 | # 可减少约一半的显存、需要pytorch1.7.1以上 49 | #---------------------------------------------------------------------# 50 | fp16 = False 51 | #--------------------------------------------------------# 52 | # 指向根目录下的cls_train.txt,读取人脸路径与标签 53 | #--------------------------------------------------------# 54 | annotation_path = "cls_train.txt" 55 | #--------------------------------------------------------# 56 | # 输入图像大小 57 | #--------------------------------------------------------# 58 | input_shape = [112, 112, 3] 59 | #--------------------------------------------------------# 60 | # 主干特征提取网络的选择 61 | # mobilefacenet 62 | # mobilenetv1 63 | # iresnet18 64 | # iresnet34 65 | # iresnet50 66 | # iresnet100 67 | # iresnet200 68 | # 69 | # 除了mobilenetv1外,其它的backbone均可从0开始训练。 70 | # 这是由于mobilenetv1没有残差边,收敛速度慢,因此建议: 71 | # 如果使用mobilenetv1为主干, 则设置pretrain = True 72 | # 如果使用其它网络为主干, 则设置pretrain = False 73 | #--------------------------------------------------------# 74 | backbone = "mobilefacenet" 75 | #----------------------------------------------------------------------------------------------------------------------------# 76 | # 如果训练过程中存在中断训练的操作,可以将model_path设置成logs文件夹下的权值文件,将已经训练了一部分的权值再次载入。 77 | # 同时修改下方的训练的参数,来保证模型epoch的连续性。 78 | # 79 | # 当model_path = ''的时候不加载整个模型的权值。 80 | # 81 | # 此处使用的是整个模型的权重,因此是在train.py进行加载的,pretrain不影响此处的权值加载。 82 | # 如果想要让模型从主干的预训练权值开始训练,则设置model_path = '',pretrain = True,此时仅加载主干。 83 | # 如果想要让模型从0开始训练,则设置model_path = '',pretrain = Fasle,此时从0开始训练。 84 | #----------------------------------------------------------------------------------------------------------------------------# 85 | model_path = "" 86 | #----------------------------------------------------------------------------------------------------------------------------# 87 | # 是否使用主干网络的预训练权重,此处使用的是主干的权重,因此是在模型构建的时候进行加载的。 88 | # 如果设置了model_path,则主干的权值无需加载,pretrained的值无意义。 89 | # 如果不设置model_path,pretrained = True,此时仅加载主干开始训练。 90 | # 如果不设置model_path,pretrained = False,此时从0开始训练。 91 | # 除了mobilenetv1外,其它的backbone均未提供预训练权重。 92 | #----------------------------------------------------------------------------------------------------------------------------# 93 | pretrained = False 94 | 95 | #----------------------------------------------------------------------------------------------------------------------------# 96 | # 显存不足与数据集大小无关,提示显存不足请调小batch_size。 97 | # 受到BatchNorm层影响,不能为1。 98 | # 99 | # 在此提供若干参数设置建议,各位训练者根据自己的需求进行灵活调整: 100 | # (一)从预训练权重开始训练: 101 | # Adam: 102 | # Init_Epoch = 0,Epoch = 100,optimizer_type = 'adam',Init_lr = 1e-3,weight_decay = 0。 103 | # SGD: 104 | # Init_Epoch = 0,Epoch = 100,optimizer_type = 'sgd',Init_lr = 1e-2,weight_decay = 5e-4。 105 | # 其中:UnFreeze_Epoch可以在100-300之间调整。 106 | # (二)batch_size的设置: 107 | # 在显卡能够接受的范围内,以大为好。显存不足与数据集大小无关,提示显存不足(OOM或者CUDA out of memory)请调小batch_size。 108 | # 受到BatchNorm层影响,batch_size最小为2,不能为1。 109 | # 正常情况下Freeze_batch_size建议为Unfreeze_batch_size的1-2倍。不建议设置的差距过大,因为关系到学习率的自动调整。 110 | #----------------------------------------------------------------------------------------------------------------------------# 111 | #------------------------------------------------------# 112 | # 训练参数 113 | # Init_Epoch 模型当前开始的训练世代 114 | # Epoch 模型总共训练的epoch 115 | # batch_size 每次输入的图片数量 116 | #------------------------------------------------------# 117 | Init_Epoch = 0 118 | Epoch = 100 119 | batch_size = 64 120 | 121 | #------------------------------------------------------------------# 122 | # 其它训练参数:学习率、优化器、学习率下降有关 123 | #------------------------------------------------------------------# 124 | #------------------------------------------------------------------# 125 | # Init_lr 模型的最大学习率 126 | # Min_lr 模型的最小学习率,默认为最大学习率的0.01 127 | #------------------------------------------------------------------# 128 | Init_lr = 1e-2 129 | Min_lr = Init_lr * 0.01 130 | #------------------------------------------------------------------# 131 | # optimizer_type 使用到的优化器种类,可选的有adam、sgd 132 | # 当使用Adam优化器时建议设置 Init_lr=1e-3 133 | # 当使用SGD优化器时建议设置 Init_lr=1e-2 134 | # momentum 优化器内部使用到的momentum参数 135 | # weight_decay 权值衰减,可防止过拟合 136 | # adam会导致weight_decay错误,使用adam时建议设置为0。 137 | #------------------------------------------------------------------# 138 | optimizer_type = "sgd" 139 | momentum = 0.9 140 | weight_decay = 5e-4 141 | #------------------------------------------------------------------# 142 | # lr_decay_type 使用到的学习率下降方式,可选的有step、cos 143 | #------------------------------------------------------------------# 144 | lr_decay_type = "cos" 145 | #------------------------------------------------------------------# 146 | # save_period 多少个epoch保存一次权值,默认每个世代都保存 147 | #------------------------------------------------------------------# 148 | save_period = 1 149 | #------------------------------------------------------------------# 150 | # save_dir 权值与日志文件保存的文件夹 151 | #------------------------------------------------------------------# 152 | save_dir = 'logs' 153 | #------------------------------------------------------------------# 154 | # 用于设置是否使用多线程读取数据 155 | # 开启后会加快数据读取速度,但是会占用更多内存 156 | # 内存较小的电脑可以设置为2或者0 157 | #------------------------------------------------------------------# 158 | num_workers = 4 159 | #------------------------------------------------------------------# 160 | # 是否开启LFW评估 161 | #------------------------------------------------------------------# 162 | lfw_eval_flag = True 163 | #------------------------------------------------------------------# 164 | # LFW评估数据集的文件路径和对应的txt文件 165 | #------------------------------------------------------------------# 166 | lfw_dir_path = "lfw" 167 | lfw_pairs_path = "model_data/lfw_pair.txt" 168 | 169 | seed_everything(seed) 170 | #------------------------------------------------------# 171 | # 设置用到的显卡 172 | #------------------------------------------------------# 173 | ngpus_per_node = torch.cuda.device_count() 174 | if distributed: 175 | dist.init_process_group(backend="nccl") 176 | local_rank = int(os.environ["LOCAL_RANK"]) 177 | rank = int(os.environ["RANK"]) 178 | device = torch.device("cuda", local_rank) 179 | if local_rank == 0: 180 | print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) training...") 181 | print("Gpu Device Count : ", ngpus_per_node) 182 | else: 183 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 184 | local_rank = 0 185 | rank = 0 186 | 187 | num_classes = get_num_classes(annotation_path) 188 | #---------------------------------# 189 | # 载入模型并加载预训练权重 190 | #---------------------------------# 191 | model = Arcface(num_classes=num_classes, backbone=backbone, pretrained=pretrained) 192 | 193 | if model_path != '': 194 | #------------------------------------------------------# 195 | # 权值文件请看README,百度网盘下载 196 | #------------------------------------------------------# 197 | if local_rank == 0: 198 | print('Load weights {}.'.format(model_path)) 199 | 200 | #------------------------------------------------------# 201 | # 根据预训练权重的Key和模型的Key进行加载 202 | #------------------------------------------------------# 203 | model_dict = model.state_dict() 204 | pretrained_dict = torch.load(model_path, map_location = device) 205 | load_key, no_load_key, temp_dict = [], [], {} 206 | for k, v in pretrained_dict.items(): 207 | if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v): 208 | temp_dict[k] = v 209 | load_key.append(k) 210 | else: 211 | no_load_key.append(k) 212 | model_dict.update(temp_dict) 213 | model.load_state_dict(model_dict) 214 | #------------------------------------------------------# 215 | # 显示没有匹配上的Key 216 | #------------------------------------------------------# 217 | if local_rank == 0: 218 | print("\nSuccessful Load Key:", str(load_key)[:500], "……\nSuccessful Load Key Num:", len(load_key)) 219 | print("\nFail To Load Key:", str(no_load_key)[:500], "……\nFail To Load Key num:", len(no_load_key)) 220 | print("\n\033[1;33;44m温馨提示,head部分没有载入是正常现象,Backbone部分没有载入是错误的。\033[0m") 221 | 222 | #----------------------# 223 | # 记录Loss 224 | #----------------------# 225 | if local_rank == 0: 226 | loss_history = LossHistory(save_dir, model, input_shape=input_shape) 227 | else: 228 | loss_history = None 229 | 230 | #------------------------------------------------------------------# 231 | # torch 1.2不支持amp,建议使用torch 1.7.1及以上正确使用fp16 232 | # 因此torch1.2这里显示"could not be resolve" 233 | #------------------------------------------------------------------# 234 | if fp16: 235 | from torch.cuda.amp import GradScaler as GradScaler 236 | scaler = GradScaler() 237 | else: 238 | scaler = None 239 | 240 | model_train = model.train() 241 | #----------------------------# 242 | # 多卡同步Bn 243 | #----------------------------# 244 | if sync_bn and ngpus_per_node > 1 and distributed: 245 | model_train = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_train) 246 | elif sync_bn: 247 | print("Sync_bn is not support in one gpu or not distributed.") 248 | 249 | if Cuda: 250 | if distributed: 251 | #----------------------------# 252 | # 多卡平行运行 253 | #----------------------------# 254 | model_train = model_train.cuda(local_rank) 255 | model_train = torch.nn.parallel.DistributedDataParallel(model_train, device_ids=[local_rank], find_unused_parameters=True) 256 | else: 257 | model_train = torch.nn.DataParallel(model) 258 | cudnn.benchmark = True 259 | model_train = model_train.cuda() 260 | 261 | #---------------------------------# 262 | # LFW估计 263 | #---------------------------------# 264 | LFW_loader = torch.utils.data.DataLoader( 265 | LFWDataset(dir=lfw_dir_path, pairs_path=lfw_pairs_path, image_size=input_shape), batch_size=32, shuffle=False) if lfw_eval_flag else None 266 | 267 | #-------------------------------------------------------# 268 | # 0.01用于验证,0.99用于训练 269 | #-------------------------------------------------------# 270 | val_split = 0.01 271 | with open(annotation_path,"r") as f: 272 | lines = f.readlines() 273 | np.random.seed(10101) 274 | np.random.shuffle(lines) 275 | np.random.seed(None) 276 | num_val = int(len(lines)*val_split) 277 | num_train = len(lines) - num_val 278 | 279 | show_config( 280 | num_classes = num_classes, backbone = backbone, model_path = model_path, input_shape = input_shape, \ 281 | Init_Epoch = Init_Epoch, Epoch = Epoch, batch_size = batch_size, \ 282 | Init_lr = Init_lr, Min_lr = Min_lr, optimizer_type = optimizer_type, momentum = momentum, lr_decay_type = lr_decay_type, \ 283 | save_period = save_period, save_dir = save_dir, num_workers = num_workers, num_train = num_train, num_val = num_val 284 | ) 285 | 286 | if True: 287 | #-------------------------------------------------------------------# 288 | # 判断当前batch_size,自适应调整学习率 289 | #-------------------------------------------------------------------# 290 | nbs = 64 291 | lr_limit_max = 1e-3 if optimizer_type == 'adam' else 1e-1 292 | lr_limit_min = 3e-4 if optimizer_type == 'adam' else 5e-4 293 | Init_lr_fit = min(max(batch_size / nbs * Init_lr, lr_limit_min), lr_limit_max) 294 | Min_lr_fit = min(max(batch_size / nbs * Min_lr, lr_limit_min * 1e-2), lr_limit_max * 1e-2) 295 | 296 | #---------------------------------------# 297 | # 根据optimizer_type选择优化器 298 | #---------------------------------------# 299 | optimizer = { 300 | 'adam' : optim.Adam(model.parameters(), Init_lr_fit, betas = (momentum, 0.999), weight_decay = weight_decay), 301 | 'sgd' : optim.SGD(model.parameters(), Init_lr_fit, momentum=momentum, nesterov=True, weight_decay = weight_decay) 302 | }[optimizer_type] 303 | 304 | #---------------------------------------# 305 | # 获得学习率下降的公式 306 | #---------------------------------------# 307 | lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, Epoch) 308 | 309 | #---------------------------------------# 310 | # 判断每一个世代的长度 311 | #---------------------------------------# 312 | epoch_step = num_train // batch_size 313 | epoch_step_val = num_val // batch_size 314 | 315 | if epoch_step == 0 or epoch_step_val == 0: 316 | raise ValueError("数据集过小,无法继续进行训练,请扩充数据集。") 317 | 318 | #---------------------------------------# 319 | # 构建数据集加载器。 320 | #---------------------------------------# 321 | train_dataset = FacenetDataset(input_shape, lines[:num_train], random = True) 322 | val_dataset = FacenetDataset(input_shape, lines[num_train:], random = False) 323 | 324 | if distributed: 325 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True,) 326 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False,) 327 | batch_size = batch_size // ngpus_per_node 328 | shuffle = False 329 | else: 330 | train_sampler = None 331 | val_sampler = None 332 | shuffle = True 333 | 334 | gen = DataLoader(train_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True, 335 | drop_last=True, collate_fn=dataset_collate, sampler=train_sampler, 336 | worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed)) 337 | gen_val = DataLoader(val_dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers, pin_memory=True, 338 | drop_last=True, collate_fn=dataset_collate, sampler=val_sampler, 339 | worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed)) 340 | 341 | for epoch in range(Init_Epoch, Epoch): 342 | if distributed: 343 | train_sampler.set_epoch(epoch) 344 | 345 | set_optimizer_lr(optimizer, lr_scheduler_func, epoch) 346 | 347 | fit_one_epoch(model_train, model, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, Cuda, LFW_loader, lfw_eval_flag, fp16, scaler, save_period, save_dir, local_rank) 348 | 349 | if local_rank == 0: 350 | loss_history.writer.close() 351 | -------------------------------------------------------------------------------- /常见问题汇总.md: -------------------------------------------------------------------------------- 1 | 问题汇总的博客地址为[https://blog.csdn.net/weixin_44791964/article/details/107517428](https://blog.csdn.net/weixin_44791964/article/details/107517428)。 2 | 3 | # 问题汇总 4 | ## 1、下载问题 5 | ### a、代码下载 6 | **问:up主,可以给我发一份代码吗,代码在哪里下载啊? 7 | 答:Github上的地址就在视频简介里。复制一下就能进去下载了。** 8 | 9 | **问:up主,为什么我下载的代码提示压缩包损坏? 10 | 答:重新去Github下载。** 11 | 12 | **问:up主,为什么我下载的代码和你在视频以及博客上的代码不一样? 13 | 答:我常常会对代码进行更新,最终以实际的代码为准。** 14 | 15 | ### b、 权值下载 16 | **问:up主,为什么我下载的代码里面,model_data下面没有.pth或者.h5文件? 17 | 答:我一般会把权值上传到Github和百度网盘,在GITHUB的README里面就能找到。** 18 | 19 | ### c、 数据集下载 20 | **问:up主,XXXX数据集在哪里下载啊? 21 | 答:一般数据集的下载地址我会放在README里面,基本上都有,没有的话请及时联系我添加,直接发github的issue即可**。 22 | 23 | ## 2、环境配置问题 24 | ### a、现在库中所用的环境 25 | **pytorch代码对应的pytorch版本为1.2,博客地址对应**[https://blog.csdn.net/weixin_44791964/article/details/106037141](https://blog.csdn.net/weixin_44791964/article/details/106037141)。 26 | 27 | **keras代码对应的tensorflow版本为1.13.2,keras版本是2.1.5,博客地址对应**[https://blog.csdn.net/weixin_44791964/article/details/104702142](https://blog.csdn.net/weixin_44791964/article/details/104702142)。 28 | 29 | **tf2代码对应的tensorflow版本为2.2.0,无需安装keras,博客地址对应**[https://blog.csdn.net/weixin_44791964/article/details/109161493](https://blog.csdn.net/weixin_44791964/article/details/109161493)。 30 | 31 | **问:你的代码某某某版本的tensorflow和pytorch能用嘛? 32 | 答:最好按照我推荐的配置,配置教程也有!其它版本的我没有试过!可能出现问题但是一般问题不大。仅需要改少量代码即可。** 33 | 34 | ### b、30系列显卡环境配置 35 | 30系显卡由于框架更新不可使用上述环境配置教程。 36 | 当前我已经测试的可以用的30显卡配置如下: 37 | **pytorch代码对应的pytorch版本为1.7.0,cuda为11.0,cudnn为8.0.5**。 38 | 39 | **keras代码无法在win10下配置cuda11,在ubuntu下可以百度查询一下,配置tensorflow版本为1.15.4,keras版本是2.1.5或者2.3.1(少量函数接口不同,代码可能还需要少量调整。)** 40 | 41 | **tf2代码对应的tensorflow版本为2.4.0,cuda为11.0,cudnn为8.0.5**。 42 | 43 | ### c、GPU利用问题与环境使用问题 44 | **问:为什么我安装了tensorflow-gpu但是却没用利用GPU进行训练呢? 45 | 答:确认tensorflow-gpu已经装好,利用pip list查看tensorflow版本,然后查看任务管理器或者利用nvidia命令看看是否使用了gpu进行训练,任务管理器的话要看显存使用情况。** 46 | 47 | **问:up主,我好像没有在用gpu进行训练啊,怎么看是不是用了GPU进行训练? 48 | 答:查看是否使用GPU进行训练一般使用NVIDIA在命令行的查看命令,如果要看任务管理器的话,请看性能部分GPU的显存是否利用,或者查看任务管理器的Cuda,而非Copy。** 49 | ![在这里插入图片描述](https://img-blog.csdnimg.cn/20201013234241524.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDc5MTk2NA==,size_16,color_FFFFFF,t_70#pic_center) 50 | 51 | **问:up主,为什么我按照你的环境配置后还是不能使用? 52 | 答:请把你的GPU、CUDA、CUDNN、TF版本以及PYTORCH版本B站私聊告诉我。** 53 | 54 | **问:出现如下错误** 55 | ```python 56 | Traceback (most recent call last): 57 | File "C:\Users\focus\Anaconda3\ana\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\pywrap_tensorflow.py", line 58, in 58 | from tensorflow.python.pywrap_tensorflow_internal import * 59 | File "C:\Users\focus\Anaconda3\ana\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\pywrap_tensorflow_internal.py", line 28, in 60 | pywrap_tensorflow_internal = swig_import_helper() 61 | File "C:\Users\focus\Anaconda3\ana\envs\tensorflow-gpu\lib\site-packages\tensorflow\python\pywrap_tensorflow_internal.py", line 24, in swig_import_helper 62 | _mod = imp.load_module('_pywrap_tensorflow_internal', fp, pathname, description) 63 | File "C:\Users\focus\Anaconda3\ana\envs\tensorflow-gpu\lib\imp.py", line 243, in load_modulereturn load_dynamic(name, filename, file) 64 | File "C:\Users\focus\Anaconda3\ana\envs\tensorflow-gpu\lib\imp.py", line 343, in load_dynamic 65 | return _load(spec) 66 | ImportError: DLL load failed: 找不到指定的模块。 67 | ``` 68 | **答:如果没重启过就重启一下,否则重新按照步骤安装,还无法解决则把你的GPU、CUDA、CUDNN、TF版本以及PYTORCH版本私聊告诉我。** 69 | 70 | ### d、no module问题 71 | **问:为什么提示说no module name utils.utils(no module name nets.yolo、no module name nets.ssd等一系列问题)啊? 72 | 答:utils并不需要用pip装,它就在我上传的仓库的根目录,出现这个问题的原因是根目录不对,查查相对目录和根目录的概念。查了基本上就明白了。** 73 | 74 | **问:为什么提示说no module name matplotlib(no module name PIL,no module name cv2等等)? 75 | 答:这个库没安装打开命令行安装就好。pip install matplotlib** 76 | 77 | **问:为什么我已经用pip装了opencv(pillow、matplotlib等),还是提示no module name cv2? 78 | 答:没有激活环境装,要激活对应的conda环境进行安装才可以正常使用** 79 | 80 | **问:为什么提示说No module named 'torch' ? 81 | 答:其实我也真的很想知道为什么会有这个问题……这个pytorch没装是什么情况?一般就俩情况,一个是真的没装,还有一个是装到其它环境了,当前激活的环境不是自己装的环境。** 82 | 83 | **问:为什么提示说No module named 'tensorflow' ? 84 | 答:同上。** 85 | 86 | ### e、cuda安装失败问题 87 | 一般cuda安装前需要安装Visual Studio,装个2017版本即可。 88 | 89 | ### f、Ubuntu系统问题 90 | **所有代码在Ubuntu下可以使用,我两个系统都试过。** 91 | 92 | ### g、VSCODE提示错误的问题 93 | **问:为什么在VSCODE里面提示一大堆的错误啊? 94 | 答:我也提示一大堆的错误,但是不影响,是VSCODE的问题,如果不想看错误的话就装Pycharm。** 95 | 96 | ### h、使用cpu进行训练与预测的问题 97 | **对于keras和tf2的代码而言,如果想用cpu进行训练和预测,直接装cpu版本的tensorflow就可以了。** 98 | 99 | **对于pytorch的代码而言,如果想用cpu进行训练和预测,需要将cuda=True修改成cuda=False。** 100 | 101 | ### i、tqdm没有pos参数问题 102 | **问:运行代码提示'tqdm' object has no attribute 'pos'。 103 | 答:重装tqdm,换个版本就可以了。** 104 | 105 | ### j、提示decode(“utf-8”)的问题 106 | **由于h5py库的更新,安装过程中会自动安装h5py=3.0.0以上的版本,会导致decode("utf-8")的错误! 107 | 各位一定要在安装完tensorflow后利用命令装h5py=2.10.0!** 108 | ``` 109 | pip install h5py==2.10.0 110 | ``` 111 | 112 | ### k、提示TypeError: __array__() takes 1 positional argument but 2 were given错误 113 | 可以修改pillow版本解决。 114 | ``` 115 | pip install pillow==8.2.0 116 | ``` 117 | 118 | ### l、其它问题 119 | **问:为什么提示TypeError: cat() got an unexpected keyword argument 'axis',Traceback (most recent call last),AttributeError: 'Tensor' object has no attribute 'bool'? 120 | 答:这是版本问题,建议使用torch1.2以上版本** 121 | **其它有很多稀奇古怪的问题,很多是版本问题,建议按照我的视频教程安装Keras和tensorflow。比如装的是tensorflow2,就不用问我说为什么我没法运行Keras-yolo啥的。那是必然不行的。** 122 | 123 | ## 3、目标检测库问题汇总(人脸检测和分类库也可参考) 124 | ### a、shape不匹配问题 125 | #### 1)、训练时shape不匹配问题 126 | **问:up主,为什么运行train.py会提示shape不匹配啊? 127 | 答:在keras环境中,因为你训练的种类和原始的种类不同,网络结构会变化,所以最尾部的shape会有少量不匹配。** 128 | 129 | #### 2)、预测时shape不匹配问题 130 | **问:为什么我运行predict.py会提示我说shape不匹配呀。 131 | 在Pytorch里面是这样的:** 132 | ![在这里插入图片描述](https://img-blog.csdnimg.cn/20200722171631901.png) 133 | 在Keras里面是这样的: 134 | ![在这里插入图片描述](https://img-blog.csdnimg.cn/20200722171523380.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDc5MTk2NA==,size_16,color_FFFFFF,t_70) 135 | **答:原因主要有仨: 136 | 1、在ssd、FasterRCNN里面,可能是train.py里面的num_classes没改。 137 | 2、model_path没改。 138 | 3、classes_path没改。 139 | 请检查清楚了!确定自己所用的model_path和classes_path是对应的!训练的时候用到的num_classes或者classes_path也需要检查!** 140 | 141 | ### b、显存不足问题 142 | **问:为什么我运行train.py下面的命令行闪的贼快,还提示OOM啥的? 143 | 答:这是在keras中出现的,爆显存了,可以改小batch_size,SSD的显存占用率是最小的,建议用SSD; 144 | 2G显存:SSD、YOLOV4-TINY 145 | 4G显存:YOLOV3 146 | 6G显存:YOLOV4、Retinanet、M2det、Efficientdet、Faster RCNN等 147 | 8G+显存:随便选吧。** 148 | **需要注意的是,受到BatchNorm2d影响,batch_size不可为1,至少为2。** 149 | 150 | **问:为什么提示 RuntimeError: CUDA out of memory. Tried to allocate 52.00 MiB (GPU 0; 15.90 GiB total capacity; 14.85 GiB already allocated; 51.88 MiB free; 15.07 GiB reserved in total by PyTorch)? 151 | 答:这是pytorch中出现的,爆显存了,同上。** 152 | 153 | **问:为什么我显存都没利用,就直接爆显存了? 154 | 答:都爆显存了,自然就不利用了,模型没有开始训练。** 155 | ### c、训练问题(冻结训练,LOSS问题、训练效果问题等) 156 | **问:为什么要冻结训练和解冻训练呀? 157 | 答:这是迁移学习的思想,因为神经网络主干特征提取部分所提取到的特征是通用的,我们冻结起来训练可以加快训练效率,也可以防止权值被破坏。** 158 | 在冻结阶段,模型的主干被冻结了,特征提取网络不发生改变。占用的显存较小,仅对网络进行微调。 159 | 在解冻阶段,模型的主干不被冻结了,特征提取网络会发生改变。占用的显存较大,网络所有的参数都会发生改变。 160 | 161 | **问:为什么我的网络不收敛啊,LOSS是XXXX。 162 | 答:不同网络的LOSS不同,LOSS只是一个参考指标,用于查看网络是否收敛,而非评价网络好坏,我的yolo代码都没有归一化,所以LOSS值看起来比较高,LOSS的值不重要,重要的是是否在变小,预测是否有效果。** 163 | 164 | **问:为什么我的训练效果不好?预测了没有框(框不准)。 165 | 答:** 166 | 167 | 考虑几个问题: 168 | 1、目标信息问题,查看2007_train.txt文件是否有目标信息,没有的话请修改voc_annotation.py。 169 | 2、数据集问题,小于500的自行考虑增加数据集,同时测试不同的模型,确认数据集是好的。 170 | 3、是否解冻训练,如果数据集分布与常规画面差距过大需要进一步解冻训练,调整主干,加强特征提取能力。 171 | 4、网络问题,比如SSD不适合小目标,因为先验框固定了。 172 | 5、训练时长问题,有些同学只训练了几代表示没有效果,按默认参数训练完。 173 | 6、确认自己是否按照步骤去做了,如果比如voc_annotation.py里面的classes是否修改了等。 174 | 7、不同网络的LOSS不同,LOSS只是一个参考指标,用于查看网络是否收敛,而非评价网络好坏,LOSS的值不重要,重要的是是否收敛。 175 | 176 | **问:我怎么出现了gbk什么的编码错误啊:** 177 | ```python 178 | UnicodeDecodeError: 'gbk' codec can't decode byte 0xa6 in position 446: illegal multibyte sequence 179 | ``` 180 | **答:标签和路径不要使用中文,如果一定要使用中文,请注意处理的时候编码的问题,改成打开文件的encoding方式改为utf-8。** 181 | 182 | **问:我的图片是xxx*xxx的分辨率的,可以用吗!** 183 | **答:可以用,代码里面会自动进行resize或者数据增强。** 184 | 185 | **问:怎么进行多GPU训练? 186 | 答:pytorch的大多数代码可以直接使用gpu训练,keras的话直接百度就好了,实现并不复杂,我没有多卡没法详细测试,还需要各位同学自己努力了。** 187 | ### d、灰度图问题 188 | **问:能不能训练灰度图(预测灰度图)啊? 189 | 答:我的大多数库会将灰度图转化成RGB进行训练和预测,如果遇到代码不能训练或者预测灰度图的情况,可以尝试一下在get_random_data里面将Image.open后的结果转换成RGB,预测的时候也这样试试。(仅供参考)** 190 | 191 | ### e、断点续练问题 192 | **问:我已经训练过几个世代了,能不能从这个基础上继续开始训练 193 | 答:可以,你在训练前,和载入预训练权重一样载入训练过的权重就行了。一般训练好的权重会保存在logs文件夹里面,将model_path修改成你要开始的权值的路径即可。** 194 | 195 | ### f、预训练权重的问题 196 | **问:如果我要训练其它的数据集,预训练权重要怎么办啊?** 197 | **答:数据的预训练权重对不同数据集是通用的,因为特征是通用的,预训练权重对于99%的情况都必须要用,不用的话权值太过随机,特征提取效果不明显,网络训练的结果也不会好。** 198 | 199 | **问:up,我修改了网络,预训练权重还能用吗? 200 | 答:修改了主干的话,如果不是用的现有的网络,基本上预训练权重是不能用的,要么就自己判断权值里卷积核的shape然后自己匹配,要么只能自己预训练去了;修改了后半部分的话,前半部分的主干部分的预训练权重还是可以用的,如果是pytorch代码的话,需要自己修改一下载入权值的方式,判断shape后载入,如果是keras代码,直接by_name=True,skip_mismatch=True即可。** 201 | 权值匹配的方式可以参考如下: 202 | ```python 203 | # 加快模型训练的效率 204 | print('Loading weights into state dict...') 205 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 206 | model_dict = model.state_dict() 207 | pretrained_dict = torch.load(model_path, map_location=device) 208 | a = {} 209 | for k, v in pretrained_dict.items(): 210 | try: 211 | if np.shape(model_dict[k]) == np.shape(v): 212 | a[k]=v 213 | except: 214 | pass 215 | model_dict.update(a) 216 | model.load_state_dict(model_dict) 217 | print('Finished!') 218 | ``` 219 | 220 | **问:我要怎么不使用预训练权重啊? 221 | 答:把载入预训练权重的代码注释了就行。** 222 | 223 | **问:为什么我不使用预训练权重效果这么差啊? 224 | 答:因为随机初始化的权值不好,提取的特征不好,也就导致了模型训练的效果不好,voc07+12、coco+voc07+12效果都不一样,预训练权重还是非常重要的。** 225 | 226 | ### g、视频检测问题与摄像头检测问题 227 | **问:怎么用摄像头检测呀? 228 | 答:predict.py修改参数可以进行摄像头检测,也有视频详细解释了摄像头检测的思路。** 229 | 230 | **问:怎么用视频检测呀? 231 | 答:同上** 232 | ### h、从0开始训练问题 233 | **问:怎么在模型上从0开始训练? 234 | 答:在算力不足与调参能力不足的情况下从0开始训练毫无意义。模型特征提取能力在随机初始化参数的情况下非常差。没有好的参数调节能力和算力,无法使得网络正常收敛。** 235 | 如果一定要从0开始,那么训练的时候请注意几点: 236 | - 不载入预训练权重。 237 | - 不要进行冻结训练,注释冻结模型的代码。 238 | 239 | **问:为什么我不使用预训练权重效果这么差啊? 240 | 答:因为随机初始化的权值不好,提取的特征不好,也就导致了模型训练的效果不好,voc07+12、coco+voc07+12效果都不一样,预训练权重还是非常重要的。** 241 | 242 | ### i、保存问题 243 | **问:检测完的图片怎么保存? 244 | 答:一般目标检测用的是Image,所以查询一下PIL库的Image如何进行保存。详细看看predict.py文件的注释。** 245 | 246 | **问:怎么用视频保存呀? 247 | 答:详细看看predict.py文件的注释。** 248 | 249 | ### j、遍历问题 250 | **问:如何对一个文件夹的图片进行遍历? 251 | 答:一般使用os.listdir先找出文件夹里面的所有图片,然后根据predict.py文件里面的执行思路检测图片就行了,详细看看predict.py文件的注释。** 252 | 253 | **问:如何对一个文件夹的图片进行遍历?并且保存。 254 | 答:遍历的话一般使用os.listdir先找出文件夹里面的所有图片,然后根据predict.py文件里面的执行思路检测图片就行了。保存的话一般目标检测用的是Image,所以查询一下PIL库的Image如何进行保存。如果有些库用的是cv2,那就是查一下cv2怎么保存图片。详细看看predict.py文件的注释。** 255 | 256 | ### k、路径问题(No such file or directory) 257 | **问:我怎么出现了这样的错误呀:** 258 | ```python 259 | FileNotFoundError: 【Errno 2】 No such file or directory 260 | …………………………………… 261 | …………………………………… 262 | ``` 263 | **答:去检查一下文件夹路径,查看是否有对应文件;并且检查一下2007_train.txt,其中文件路径是否有错。** 264 | 关于路径有几个重要的点: 265 | **文件夹名称中一定不要有空格。 266 | 注意相对路径和绝对路径。 267 | 多百度路径相关的知识。** 268 | 269 | **所有的路径问题基本上都是根目录问题,好好查一下相对目录的概念!** 270 | ### l、和原版比较问题 271 | **问:你这个代码和原版比怎么样,可以达到原版的效果么? 272 | 答:基本上可以达到,我都用voc数据测过,我没有好显卡,没有能力在coco上测试与训练。** 273 | 274 | **问:你有没有实现yolov4所有的tricks,和原版差距多少? 275 | 答:并没有实现全部的改进部分,由于YOLOV4使用的改进实在太多了,很难完全实现与列出来,这里只列出来了一些我比较感兴趣,而且非常有效的改进。论文中提到的SAM(注意力机制模块),作者自己的源码也没有使用。还有其它很多的tricks,不是所有的tricks都有提升,我也没法实现全部的tricks。至于和原版的比较,我没有能力训练coco数据集,根据使用过的同学反应差距不大。** 276 | 277 | ### m、FPS问题(检测速度问题) 278 | **问:你这个FPS可以到达多少,可以到 XX FPS么? 279 | 答:FPS和机子的配置有关,配置高就快,配置低就慢。** 280 | 281 | **问:为什么我用服务器去测试yolov4(or others)的FPS只有十几? 282 | 答:检查是否正确安装了tensorflow-gpu或者pytorch的gpu版本,如果已经正确安装,可以去利用time.time()的方法查看detect_image里面,哪一段代码耗时更长(不仅只有网络耗时长,其它处理部分也会耗时,如绘图等)。** 283 | 284 | **问:为什么论文中说速度可以达到XX,但是这里却没有? 285 | 答:检查是否正确安装了tensorflow-gpu或者pytorch的gpu版本,如果已经正确安装,可以去利用time.time()的方法查看detect_image里面,哪一段代码耗时更长(不仅只有网络耗时长,其它处理部分也会耗时,如绘图等)。有些论文还会使用多batch进行预测,我并没有去实现这个部分。** 286 | 287 | ### n、预测图片不显示问题 288 | **问:为什么你的代码在预测完成后不显示图片?只是在命令行告诉我有什么目标。 289 | 答:给系统安装一个图片查看器就行了。** 290 | 291 | ### o、算法评价问题(目标检测的map、PR曲线、Recall、Precision等) 292 | **问:怎么计算map? 293 | 答:看map视频,都一个流程。** 294 | 295 | **问:计算map的时候,get_map.py里面有一个MINOVERLAP是什么用的,是iou吗? 296 | 答:是iou,它的作用是判断预测框和真实框的重合成度,如果重合程度大于MINOVERLAP,则预测正确。** 297 | 298 | **问:为什么get_map.py里面的self.confidence(self.score)要设置的那么小? 299 | 答:看一下map的视频的原理部分,要知道所有的结果然后再进行pr曲线的绘制。** 300 | 301 | **问:能不能说说怎么绘制PR曲线啥的呀。 302 | 答:可以看mAP视频,结果里面有PR曲线。** 303 | 304 | **问:怎么计算Recall、Precision指标。 305 | 答:这俩指标应该是相对于特定的置信度的,计算map的时候也会获得。** 306 | 307 | ### p、coco数据集训练问题 308 | **问:目标检测怎么训练COCO数据集啊?。 309 | 答:coco数据训练所需要的txt文件可以参考qqwweee的yolo3的库,格式都是一样的。** 310 | 311 | ### q、模型优化(模型修改)问题 312 | **问:up,YOLO系列使用Focal LOSS的代码你有吗,有提升吗? 313 | 答:很多人试过,提升效果也不大(甚至变的更Low),它自己有自己的正负样本的平衡方式。** 314 | 315 | **问:up,我修改了网络,预训练权重还能用吗? 316 | 答:修改了主干的话,如果不是用的现有的网络,基本上预训练权重是不能用的,要么就自己判断权值里卷积核的shape然后自己匹配,要么只能自己预训练去了;修改了后半部分的话,前半部分的主干部分的预训练权重还是可以用的,如果是pytorch代码的话,需要自己修改一下载入权值的方式,判断shape后载入,如果是keras代码,直接by_name=True,skip_mismatch=True即可。** 317 | 权值匹配的方式可以参考如下: 318 | ```python 319 | # 加快模型训练的效率 320 | print('Loading weights into state dict...') 321 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 322 | model_dict = model.state_dict() 323 | pretrained_dict = torch.load(model_path, map_location=device) 324 | a = {} 325 | for k, v in pretrained_dict.items(): 326 | try: 327 | if np.shape(model_dict[k]) == np.shape(v): 328 | a[k]=v 329 | except: 330 | pass 331 | model_dict.update(a) 332 | model.load_state_dict(model_dict) 333 | print('Finished!') 334 | ``` 335 | 336 | **问:up,怎么修改模型啊,我想发个小论文! 337 | 答:建议看看yolov3和yolov4的区别,然后看看yolov4的论文,作为一个大型调参现场非常有参考意义,使用了很多tricks。我能给的建议就是多看一些经典模型,然后拆解里面的亮点结构并使用。** 338 | 339 | ### r、部署问题 340 | 我没有具体部署到手机等设备上过,所以很多部署问题我并不了解…… 341 | 342 | ## 4、语义分割库问题汇总 343 | ### a、shape不匹配问题 344 | #### 1)、训练时shape不匹配问题 345 | **问:up主,为什么运行train.py会提示shape不匹配啊? 346 | 答:在keras环境中,因为你训练的种类和原始的种类不同,网络结构会变化,所以最尾部的shape会有少量不匹配。** 347 | 348 | #### 2)、预测时shape不匹配问题 349 | **问:为什么我运行predict.py会提示我说shape不匹配呀。 350 | 在Pytorch里面是这样的:** 351 | ![在这里插入图片描述](https://img-blog.csdnimg.cn/20200722171631901.png) 352 | 在Keras里面是这样的: 353 | ![在这里插入图片描述](https://img-blog.csdnimg.cn/20200722171523380.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDc5MTk2NA==,size_16,color_FFFFFF,t_70) 354 | **答:原因主要有二: 355 | 1、train.py里面的num_classes没改。 356 | 2、预测时num_classes没改。 357 | 请检查清楚!训练和预测的时候用到的num_classes都需要检查!** 358 | 359 | ### b、显存不足问题 360 | **问:为什么我运行train.py下面的命令行闪的贼快,还提示OOM啥的? 361 | 答:这是在keras中出现的,爆显存了,可以改小batch_size。** 362 | 363 | **需要注意的是,受到BatchNorm2d影响,batch_size不可为1,至少为2。** 364 | 365 | **问:为什么提示 RuntimeError: CUDA out of memory. Tried to allocate 52.00 MiB (GPU 0; 15.90 GiB total capacity; 14.85 GiB already allocated; 51.88 MiB free; 15.07 GiB reserved in total by PyTorch)? 366 | 答:这是pytorch中出现的,爆显存了,同上。** 367 | 368 | **问:为什么我显存都没利用,就直接爆显存了? 369 | 答:都爆显存了,自然就不利用了,模型没有开始训练。** 370 | 371 | ### c、训练问题(冻结训练,LOSS问题、训练效果问题等) 372 | **问:为什么要冻结训练和解冻训练呀? 373 | 答:这是迁移学习的思想,因为神经网络主干特征提取部分所提取到的特征是通用的,我们冻结起来训练可以加快训练效率,也可以防止权值被破坏。** 374 | **在冻结阶段,模型的主干被冻结了,特征提取网络不发生改变。占用的显存较小,仅对网络进行微调。** 375 | **在解冻阶段,模型的主干不被冻结了,特征提取网络会发生改变。占用的显存较大,网络所有的参数都会发生改变。** 376 | 377 | **问:为什么我的网络不收敛啊,LOSS是XXXX。 378 | 答:不同网络的LOSS不同,LOSS只是一个参考指标,用于查看网络是否收敛,而非评价网络好坏,我的yolo代码都没有归一化,所以LOSS值看起来比较高,LOSS的值不重要,重要的是是否在变小,预测是否有效果。** 379 | 380 | **问:为什么我的训练效果不好?预测了没有目标,结果是一片黑。 381 | 答:** 382 | **考虑几个问题: 383 | 1、数据集问题,这是最重要的问题。小于500的自行考虑增加数据集;一定要检查数据集的标签,视频中详细解析了VOC数据集的格式,但并不是有输入图片有输出标签即可,还需要确认标签的每一个像素值是否为它对应的种类。很多同学的标签格式不对,最常见的错误格式就是标签的背景为黑,目标为白,此时目标的像素点值为255,无法正常训练,目标需要为1才行。 384 | 2、是否解冻训练,如果数据集分布与常规画面差距过大需要进一步解冻训练,调整主干,加强特征提取能力。 385 | 3、网络问题,可以尝试不同的网络。 386 | 4、训练时长问题,有些同学只训练了几代表示没有效果,按默认参数训练完。 387 | 5、确认自己是否按照步骤去做了。 388 | 6、不同网络的LOSS不同,LOSS只是一个参考指标,用于查看网络是否收敛,而非评价网络好坏,LOSS的值不重要,重要的是是否收敛。** 389 | 390 | 391 | 392 | **问:为什么我的训练效果不好?对小目标预测不准确。 393 | 答:对于deeplab和pspnet而言,可以修改一下downsample_factor,当downsample_factor为16的时候下采样倍数过多,效果不太好,可以修改为8。** 394 | 395 | **问:我怎么出现了gbk什么的编码错误啊:** 396 | ```python 397 | UnicodeDecodeError: 'gbk' codec can't decode byte 0xa6 in position 446: illegal multibyte sequence 398 | ``` 399 | **答:标签和路径不要使用中文,如果一定要使用中文,请注意处理的时候编码的问题,改成打开文件的encoding方式改为utf-8。** 400 | 401 | **问:我的图片是xxx*xxx的分辨率的,可以用吗!** 402 | **答:可以用,代码里面会自动进行resize或者数据增强。** 403 | 404 | **问:怎么进行多GPU训练? 405 | 答:pytorch的大多数代码可以直接使用gpu训练,keras的话直接百度就好了,实现并不复杂,我没有多卡没法详细测试,还需要各位同学自己努力了。** 406 | 407 | ### d、灰度图问题 408 | **问:能不能训练灰度图(预测灰度图)啊? 409 | 答:我的大多数库会将灰度图转化成RGB进行训练和预测,如果遇到代码不能训练或者预测灰度图的情况,可以尝试一下在get_random_data里面将Image.open后的结果转换成RGB,预测的时候也这样试试。(仅供参考)** 410 | 411 | ### e、断点续练问题 412 | **问:我已经训练过几个世代了,能不能从这个基础上继续开始训练 413 | 答:可以,你在训练前,和载入预训练权重一样载入训练过的权重就行了。一般训练好的权重会保存在logs文件夹里面,将model_path修改成你要开始的权值的路径即可。** 414 | 415 | ### f、预训练权重的问题 416 | 417 | **问:如果我要训练其它的数据集,预训练权重要怎么办啊?** 418 | **答:数据的预训练权重对不同数据集是通用的,因为特征是通用的,预训练权重对于99%的情况都必须要用,不用的话权值太过随机,特征提取效果不明显,网络训练的结果也不会好。** 419 | 420 | **问:up,我修改了网络,预训练权重还能用吗? 421 | 答:修改了主干的话,如果不是用的现有的网络,基本上预训练权重是不能用的,要么就自己判断权值里卷积核的shape然后自己匹配,要么只能自己预训练去了;修改了后半部分的话,前半部分的主干部分的预训练权重还是可以用的,如果是pytorch代码的话,需要自己修改一下载入权值的方式,判断shape后载入,如果是keras代码,直接by_name=True,skip_mismatch=True即可。** 422 | 权值匹配的方式可以参考如下: 423 | 424 | ```python 425 | # 加快模型训练的效率 426 | print('Loading weights into state dict...') 427 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 428 | model_dict = model.state_dict() 429 | pretrained_dict = torch.load(model_path, map_location=device) 430 | a = {} 431 | for k, v in pretrained_dict.items(): 432 | try: 433 | if np.shape(model_dict[k]) == np.shape(v): 434 | a[k]=v 435 | except: 436 | pass 437 | model_dict.update(a) 438 | model.load_state_dict(model_dict) 439 | print('Finished!') 440 | ``` 441 | 442 | **问:我要怎么不使用预训练权重啊? 443 | 答:把载入预训练权重的代码注释了就行。** 444 | 445 | **问:为什么我不使用预训练权重效果这么差啊? 446 | 答:因为随机初始化的权值不好,提取的特征不好,也就导致了模型训练的效果不好,预训练权重还是非常重要的。** 447 | 448 | ### g、视频检测问题与摄像头检测问题 449 | **问:怎么用摄像头检测呀? 450 | 答:predict.py修改参数可以进行摄像头检测,也有视频详细解释了摄像头检测的思路。** 451 | 452 | **问:怎么用视频检测呀? 453 | 答:同上** 454 | 455 | ### h、从0开始训练问题 456 | **问:怎么在模型上从0开始训练? 457 | 答:在算力不足与调参能力不足的情况下从0开始训练毫无意义。模型特征提取能力在随机初始化参数的情况下非常差。没有好的参数调节能力和算力,无法使得网络正常收敛。** 458 | 如果一定要从0开始,那么训练的时候请注意几点: 459 | - 不载入预训练权重。 460 | - 不要进行冻结训练,注释冻结模型的代码。 461 | 462 | **问:为什么我不使用预训练权重效果这么差啊? 463 | 答:因为随机初始化的权值不好,提取的特征不好,也就导致了模型训练的效果不好,预训练权重还是非常重要的。** 464 | 465 | ### i、保存问题 466 | **问:检测完的图片怎么保存? 467 | 答:一般目标检测用的是Image,所以查询一下PIL库的Image如何进行保存。详细看看predict.py文件的注释。** 468 | 469 | **问:怎么用视频保存呀? 470 | 答:详细看看predict.py文件的注释。** 471 | 472 | ### j、遍历问题 473 | **问:如何对一个文件夹的图片进行遍历? 474 | 答:一般使用os.listdir先找出文件夹里面的所有图片,然后根据predict.py文件里面的执行思路检测图片就行了,详细看看predict.py文件的注释。** 475 | 476 | **问:如何对一个文件夹的图片进行遍历?并且保存。 477 | 答:遍历的话一般使用os.listdir先找出文件夹里面的所有图片,然后根据predict.py文件里面的执行思路检测图片就行了。保存的话一般目标检测用的是Image,所以查询一下PIL库的Image如何进行保存。如果有些库用的是cv2,那就是查一下cv2怎么保存图片。详细看看predict.py文件的注释。** 478 | 479 | ### k、路径问题(No such file or directory) 480 | **问:我怎么出现了这样的错误呀:** 481 | ```python 482 | FileNotFoundError: 【Errno 2】 No such file or directory 483 | …………………………………… 484 | …………………………………… 485 | ``` 486 | 487 | **答:去检查一下文件夹路径,查看是否有对应文件;并且检查一下2007_train.txt,其中文件路径是否有错。** 488 | 关于路径有几个重要的点: 489 | **文件夹名称中一定不要有空格。 490 | 注意相对路径和绝对路径。 491 | 多百度路径相关的知识。** 492 | 493 | **所有的路径问题基本上都是根目录问题,好好查一下相对目录的概念!** 494 | 495 | ### l、FPS问题(检测速度问题) 496 | **问:你这个FPS可以到达多少,可以到 XX FPS么? 497 | 答:FPS和机子的配置有关,配置高就快,配置低就慢。** 498 | 499 | **问:为什么论文中说速度可以达到XX,但是这里却没有? 500 | 答:检查是否正确安装了tensorflow-gpu或者pytorch的gpu版本,如果已经正确安装,可以去利用time.time()的方法查看detect_image里面,哪一段代码耗时更长(不仅只有网络耗时长,其它处理部分也会耗时,如绘图等)。有些论文还会使用多batch进行预测,我并没有去实现这个部分。** 501 | 502 | ### m、预测图片不显示问题 503 | **问:为什么你的代码在预测完成后不显示图片?只是在命令行告诉我有什么目标。 504 | 答:给系统安装一个图片查看器就行了。** 505 | 506 | ### n、算法评价问题(miou) 507 | **问:怎么计算miou? 508 | 答:参考视频里的miou测量部分。** 509 | 510 | **问:怎么计算Recall、Precision指标。 511 | 答:现有的代码还无法获得,需要各位同学理解一下混淆矩阵的概念,然后自行计算一下。** 512 | 513 | ### o、模型优化(模型修改)问题 514 | **问:up,我修改了网络,预训练权重还能用吗? 515 | 答:修改了主干的话,如果不是用的现有的网络,基本上预训练权重是不能用的,要么就自己判断权值里卷积核的shape然后自己匹配,要么只能自己预训练去了;修改了后半部分的话,前半部分的主干部分的预训练权重还是可以用的,如果是pytorch代码的话,需要自己修改一下载入权值的方式,判断shape后载入,如果是keras代码,直接by_name=True,skip_mismatch=True即可。** 516 | 权值匹配的方式可以参考如下: 517 | 518 | ```python 519 | # 加快模型训练的效率 520 | print('Loading weights into state dict...') 521 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 522 | model_dict = model.state_dict() 523 | pretrained_dict = torch.load(model_path, map_location=device) 524 | a = {} 525 | for k, v in pretrained_dict.items(): 526 | try: 527 | if np.shape(model_dict[k]) == np.shape(v): 528 | a[k]=v 529 | except: 530 | pass 531 | model_dict.update(a) 532 | model.load_state_dict(model_dict) 533 | print('Finished!') 534 | ``` 535 | 536 | 537 | 538 | **问:up,怎么修改模型啊,我想发个小论文! 539 | 答:建议看看目标检测中yolov4的论文,作为一个大型调参现场非常有参考意义,使用了很多tricks。我能给的建议就是多看一些经典模型,然后拆解里面的亮点结构并使用。常用的tricks如注意力机制什么的,可以试试。** 540 | 541 | ### p、部署问题 542 | 我没有具体部署到手机等设备上过,所以很多部署问题我并不了解…… 543 | 544 | ## 5、交流群问题 545 | **问:up,有没有QQ群啥的呢? 546 | 答:没有没有,我没有时间管理QQ群……** 547 | 548 | ## 6、怎么学习的问题 549 | **问:up,你的学习路线怎么样的?我是个小白我要怎么学? 550 | 答:这里有几点需要注意哈 551 | 1、我不是高手,很多东西我也不会,我的学习路线也不一定适用所有人。 552 | 2、我实验室不做深度学习,所以我很多东西都是自学,自己摸索,正确与否我也不知道。 553 | 3、我个人觉得学习更靠自学** 554 | 学习路线的话,我是先学习了莫烦的python教程,从tensorflow、keras、pytorch入门,入门完之后学的SSD,YOLO,然后了解了很多经典的卷积网,后面就开始学很多不同的代码了,我的学习方法就是一行一行的看,了解整个代码的执行流程,特征层的shape变化等,花了很多时间也没有什么捷径,就是要花时间吧。 --------------------------------------------------------------------------------