├── .gitignore ├── README.md ├── annotations ├── landmark_imagelist.txt └── wider_origin_anno.txt ├── checkpoint.py ├── config.py ├── data ├── test_images │ ├── img_12883.jpg │ ├── img_12884.jpg │ ├── img_12903.jpg │ ├── img_12934.jpg │ ├── img_12936.jpg │ ├── img_12965.jpg │ ├── img_12993.jpg │ ├── img_13068.jpg │ ├── img_13092.jpg │ ├── img_13094.jpg │ ├── img_13095.jpg │ ├── img_13101.jpg │ ├── img_13104.jpg │ ├── img_13105.jpg │ ├── img_13109.jpg │ ├── img_13116.jpg │ ├── img_13117.jpg │ ├── img_13129.jpg │ ├── img_13141.jpg │ ├── img_13144.jpg │ ├── img_13147.jpg │ ├── img_13152.jpg │ ├── img_13190.jpg │ ├── img_13193.jpg │ ├── img_13199.jpg │ ├── img_13214.jpg │ ├── img_13225.jpg │ ├── img_13235.jpg │ ├── img_13245.jpg │ ├── img_13271.jpg │ ├── img_13293.jpg │ ├── img_13298.jpg │ ├── img_13326.jpg │ ├── img_13331.jpg │ ├── img_13352.jpg │ ├── img_13370.jpg │ ├── img_13378.jpg │ ├── img_13383.jpg │ ├── img_13395.jpg │ ├── img_13418.jpg │ ├── img_13437.jpg │ ├── img_13454.jpg │ ├── img_13459.jpg │ ├── img_13479.jpg │ ├── img_13540.jpg │ ├── img_13561.jpg │ ├── img_13564.jpg │ ├── img_13570.jpg │ ├── img_13573.jpg │ ├── img_13575.jpg │ ├── img_13629.jpg │ ├── img_13643.jpg │ ├── img_13660.jpg │ ├── img_13663.jpg │ ├── img_13692.jpg │ ├── img_13705.jpg │ ├── img_13747.jpg │ ├── img_13752.jpg │ ├── img_13816.jpg │ ├── img_13858.jpg │ ├── img_13869.jpg │ ├── img_13893.jpg │ ├── img_13915.jpg │ └── img_13941.jpg └── you_result │ ├── img_12883.jpg │ ├── img_12884.jpg │ ├── img_12903.jpg │ ├── img_12934.jpg │ ├── img_12936.jpg │ ├── img_12965.jpg │ ├── img_12993.jpg │ ├── img_13068.jpg │ ├── img_13092.jpg │ ├── img_13094.jpg │ ├── img_13095.jpg │ ├── img_13101.jpg │ ├── img_13104.jpg │ ├── img_13105.jpg │ ├── img_13109.jpg │ ├── img_13116.jpg │ ├── img_13117.jpg │ ├── img_13129.jpg │ ├── img_13141.jpg │ ├── img_13144.jpg │ ├── img_13147.jpg │ ├── img_13152.jpg │ ├── img_13190.jpg │ ├── img_13193.jpg │ ├── img_13199.jpg │ ├── img_13214.jpg │ ├── img_13225.jpg │ ├── img_13235.jpg │ ├── img_13245.jpg │ ├── img_13271.jpg │ ├── img_13293.jpg │ ├── img_13298.jpg │ ├── img_13326.jpg │ ├── img_13331.jpg │ ├── img_13352.jpg │ ├── img_13370.jpg │ ├── img_13378.jpg │ ├── img_13383.jpg │ ├── img_13395.jpg │ ├── img_13418.jpg │ ├── img_13437.jpg │ ├── img_13454.jpg │ ├── img_13459.jpg │ ├── img_13479.jpg │ ├── img_13540.jpg │ ├── img_13561.jpg │ ├── img_13564.jpg │ ├── img_13570.jpg │ ├── img_13573.jpg │ ├── img_13575.jpg │ ├── img_13629.jpg │ ├── img_13643.jpg │ ├── img_13660.jpg │ ├── img_13663.jpg │ ├── img_13692.jpg │ ├── img_13705.jpg │ ├── img_13747.jpg │ ├── img_13752.jpg │ ├── img_13816.jpg │ ├── img_13858.jpg │ ├── img_13869.jpg │ ├── img_13893.jpg │ ├── img_13915.jpg │ └── img_13941.jpg ├── models ├── __init__.py ├── lossfn.py ├── onet.py ├── pnet.py └── rnet.py ├── preprocessing ├── __init__.py ├── assemble.py ├── assemble_onet_imglist.py ├── assemble_pnet_imglist.py ├── assemble_rnet_imglist.py ├── gen_landmark_12.py ├── gen_landmark_24.py ├── gen_landmark_48.py ├── gen_onet_data.py ├── gen_pnet_data.py └── gen_rnet_data.py ├── test_image.py ├── test_on_FDDB.py ├── test_youModel_images.py ├── tools ├── __init__.py ├── image_reader.py ├── image_tools.py ├── imagedb.py ├── logger.py ├── test_detect.py ├── train_detect.py ├── utils.py └── vision.py └── training ├── __init__.py ├── onet ├── config.py ├── train.py └── trainer.py ├── pnet ├── config.py ├── train.py └── trainer.py └── rnet ├── config.py ├── train.py └── trainer.py /.gitignore: 
--------------------------------------------------------------------------------
1 | .idea/
2 | results/
3 | *.pyc
4 | .vscode/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # Face Detection
3 | 
4 | 
5 | ---
6 | ## Objectives
7 | 1. Understand and master the theoretical foundations of neural-network-based face detection.
8 | 2. Understand the basic pipeline of [MTCNN](https://kpzhang93.github.io/MTCNN_face_detection_alignment/paper/spl.pdf) face detection and put it into practice.
9 | 
10 | 
11 | 
12 | ## Environment
13 | [anaconda3](https://www.anaconda.com/download/)
14 | [pytorch 0.4.1](https://pytorch.org/)
15 | [torchvision](https://pytorch.org/)
16 | [opencv-python](https://pypi.org/project/opencv-python/)
17 | tensorflow, tensorboard, etc.
18 | 
19 | 
20 | 
21 | ## Procedure
22 | **1. Get the code**
23 | 
24 | The complete code for this experiment is [mtcnn_pytorch](https://github.com/xiezheng-cs/mtcnn_pytorch); download it directly or clone it with git:
25 | ```bash
26 | git clone https://github.com/xiezheng-cs/mtcnn_pytorch.git
27 | ```
28 | 
29 | **2. Set up the environment**
30 | 1. Make sure [anaconda3](https://www.anaconda.com/download/) is installed on your machine or server;
31 | 2. Install [pytorch 0.4.1 and torchvision](https://pytorch.org/) with pip or conda;
32 | 3. Install [opencv-python](https://pypi.org/project/opencv-python/) with pip or conda;
33 | 4. Install tensorflow and tensorboard with pip or conda.
34 | 
35 | ```bash
36 | pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp36-cp36m-win_amd64.whl # Windows
37 | pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp36-cp36m-linux_x86_64.whl # Linux
38 | pip install torchvision
39 | pip install opencv-python
40 | ```
41 | 
42 | **3. Quick test with the provided models**
43 | 
44 | Apply our [pretrained models](https://github.com/xiezheng-cs/mtcnn_pytorch/releases) to the given test set (64 images under mtcnn_pytorch/data/test_images/) by running the commands below; the detection results are written to mtcnn_pytorch/data/you_result/.
45 | ```bash
46 | cd mtcnn_pytorch/
47 | python test_image.py
48 | ```
49 | 
50 | 
51 | -----
52 | 
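The released checkpoints are plain `state_dict` files written by `checkpoint.py` (the next file in this listing), so they can also be loaded by hand. A minimal sketch, assuming the downloads are state dicts; the file names below are placeholders, substitute the actual release asset names:

```python
import torch
from models import PNet, RNet, ONet

# Placeholder file names -- use whatever the release assets are actually called.
pnet, rnet, onet = PNet(), RNet(), ONet()
pnet.load_state_dict(torch.load("pnet_model.pth", map_location="cpu"))
rnet.load_state_dict(torch.load("rnet_model.pth", map_location="cpu"))
onet.load_state_dict(torch.load("onet_model.pth", map_location="cpu"))
pnet.eval(); rnet.eval(); onet.eval()  # evaluation mode, as the comment in checkpoint.py notes
```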
--------------------------------------------------------------------------------
/checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tools.utils as utils
3 | 
4 | import torch
5 | import torch.nn as nn
6 | 
7 | class CheckPoint(object):
8 |     """
9 |     save model state to file
10 |     check_point_params: model, optimizer, epoch
11 |     """
12 | 
13 |     def __init__(self, save_path):
14 | 
15 |         self.save_path = os.path.join(save_path, "check_point")
16 |         self.check_point_params = {'model': None,
17 |                                    'optimizer': None,
18 |                                    'epoch': None}
19 | 
20 |         # make directory
21 |         if not os.path.isdir(self.save_path):
22 |             os.makedirs(self.save_path)
23 | 
24 |     def load_state(self, model, state_dict):
25 |         """
26 |         load state_dict to model
27 |         :params model:
28 |         :params state_dict:
29 |         :return: model
30 |         """
31 |         model.eval()
32 |         model_dict = model.state_dict()
33 | 
34 |         for key, value in list(state_dict.items()):
35 |             if key in list(model_dict.keys()):
36 |                 # print key, value.size()
37 |                 model_dict[key] = value
38 |             else:
39 |                 pass
40 |                 # print "key error:", key, value.size()
41 |         model.load_state_dict(model_dict)
42 |         # model.load_state_dict(state_dict)
43 |         # set the model in evaluation mode, otherwise the accuracy will change
44 |         # model.eval()
45 |         # model.load_state_dict(state_dict)
46 | 
47 |         return model
48 | 
49 |     def load_model(self, model_path):
50 |         """
51 |         load model
52 |         :params model_path: path to the model
53 |         :return: model_state_dict
54 |         """
55 |         if os.path.isfile(model_path):
56 |             print("|===>Load retrain model from:", model_path)
57 |             # model_state_dict = torch.load(model_path, map_location={'cuda:1':'cuda:0'})
58 |             model_state_dict = torch.load(model_path, map_location='cpu')
59 |             return model_state_dict
60 |         else:
61 |             assert False, "file does not exist, model path: " + model_path
62 | 
63 |     def load_checkpoint(self, checkpoint_path):
64 |         """
65 |         load checkpoint file
66 |         :params checkpoint_path: path to the checkpoint file
67 |         :return: model_state_dict, optimizer_state_dict, epoch
68 |         """
69 |         if os.path.isfile(checkpoint_path):
70 |             print("|===>Load resume check-point from:", checkpoint_path)
71 |             self.check_point_params = torch.load(checkpoint_path)
72 |             model_state_dict = self.check_point_params['model']
73 |             optimizer_state_dict = self.check_point_params['optimizer']
74 |             epoch = self.check_point_params['epoch']
75 |             return model_state_dict, optimizer_state_dict, epoch
76 |         else:
77 |             assert False, "file does not exist: " + checkpoint_path
78 | 
79 |     def save_checkpoint(self, model, optimizer, epoch, index=0):
80 |         """
81 |         :params model: model
82 |         :params optimizer: optimizer
83 |         :params epoch: training epoch
84 |         :params index: index of saved file, default: 0
85 |         Note: if we add a hook to the gradients with register_hook(hook), the hook
86 |         function cannot be pickled, so only state_dict() is saved here. Saving the
87 |         state dict is the recommended practice; saving the whole model instead keeps
88 |         all information about the trained network and avoids rebuilding it later,
89 |         but it also stores the GPU placement, which causes problems when the model
90 |         is used on a different machine.
91 |         """
92 | 
93 |         # get state_dict from model and optimizer
94 |         model = self.list2sequential(model)
95 |         if isinstance(model, nn.DataParallel):
96 |             model = model.module
97 |         model = model.state_dict()
98 |         optimizer = optimizer.state_dict()
99 | 
100 |         # save information to a dict
101 |         self.check_point_params['model'] = model
102 |         self.check_point_params['optimizer'] = optimizer
103 |         self.check_point_params['epoch'] = epoch
104 | 
105 |         # save to file
106 |         torch.save(self.check_point_params, os.path.join(
107 |             self.save_path, "checkpoint_%03d.pth" % index))
108 | 
109 |     def list2sequential(self, model):
110 |         if isinstance(model, list):
111 |             model = nn.Sequential(*model)
112 |         return model
113 | 
114 |     def save_model(self, model, best_flag=False, index=0, tag=""):
115 |         """
116 |         :params model: model to save
117 |         :params best_flag: if True, the saved model is the one that gets best performance
118 |         """
119 |         # get state dict
120 |         model = self.list2sequential(model)
121 |         if isinstance(model, nn.DataParallel):
122 |             model = model.module
123 |         model = model.state_dict()
124 |         if best_flag:
125 |             if tag != "":
126 |                 torch.save(model, os.path.join(self.save_path, "%s_best_model.pth"%tag))
127 |             else:
128 |                 torch.save(model, os.path.join(self.save_path, "best_model.pth"))
129 |         else:
130 |             if tag != "":
131 |                 torch.save(model, os.path.join(self.save_path, "%s_model_%03d.pth" % (tag, index)))
132 |             else:
133 |                 torch.save(model, os.path.join(self.save_path, "model_%03d.pth" % index))
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | 
4 | MODEL_STORE_DIR = "./models"
5 | 
6 | ANNO_STORE_DIR = "./annotations"
7 | 
8 | TRAIN_DATA_DIR = "./data"
9 | 
10 | LOG_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))+"/log"
11 | 
12 | USE_CUDA = True
13 | 
14 | 
TRAIN_BATCH_SIZE = 512 15 | 16 | TRAIN_LR = 0.01 17 | 18 | END_EPOCH = 10 19 | 20 | 21 | PNET_POSTIVE_ANNO_FILENAME = "pos_12.txt" 22 | PNET_NEGATIVE_ANNO_FILENAME = "neg_12.txt" 23 | PNET_PART_ANNO_FILENAME = "part_12.txt" 24 | PNET_LANDMARK_ANNO_FILENAME = "landmark_12.txt" 25 | 26 | 27 | RNET_POSTIVE_ANNO_FILENAME = "pos_24.txt" 28 | RNET_NEGATIVE_ANNO_FILENAME = "neg_24.txt" 29 | RNET_PART_ANNO_FILENAME = "part_24.txt" 30 | RNET_LANDMARK_ANNO_FILENAME = "landmark_24.txt" 31 | 32 | 33 | ONET_POSTIVE_ANNO_FILENAME = "pos_48.txt" 34 | ONET_NEGATIVE_ANNO_FILENAME = "neg_48.txt" 35 | ONET_PART_ANNO_FILENAME = "part_48.txt" 36 | ONET_LANDMARK_ANNO_FILENAME = "landmark_48.txt" 37 | 38 | PNET_TRAIN_IMGLIST_FILENAME = "imglist_anno_12.txt" 39 | RNET_TRAIN_IMGLIST_FILENAME = "imglist_anno_24.txt" 40 | ONET_TRAIN_IMGLIST_FILENAME = "imglist_anno_48.txt" -------------------------------------------------------------------------------- /data/test_images/img_12883.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_12883.jpg -------------------------------------------------------------------------------- /data/test_images/img_12884.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_12884.jpg -------------------------------------------------------------------------------- /data/test_images/img_12903.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_12903.jpg -------------------------------------------------------------------------------- /data/test_images/img_12934.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_12934.jpg -------------------------------------------------------------------------------- /data/test_images/img_12936.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_12936.jpg -------------------------------------------------------------------------------- /data/test_images/img_12965.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_12965.jpg -------------------------------------------------------------------------------- /data/test_images/img_12993.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_12993.jpg -------------------------------------------------------------------------------- /data/test_images/img_13068.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13068.jpg 
-------------------------------------------------------------------------------- /data/test_images/img_13092.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13092.jpg -------------------------------------------------------------------------------- /data/test_images/img_13094.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13094.jpg -------------------------------------------------------------------------------- /data/test_images/img_13095.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13095.jpg -------------------------------------------------------------------------------- /data/test_images/img_13101.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13101.jpg -------------------------------------------------------------------------------- /data/test_images/img_13104.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13104.jpg -------------------------------------------------------------------------------- /data/test_images/img_13105.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13105.jpg -------------------------------------------------------------------------------- /data/test_images/img_13109.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13109.jpg -------------------------------------------------------------------------------- /data/test_images/img_13116.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13116.jpg -------------------------------------------------------------------------------- /data/test_images/img_13117.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13117.jpg -------------------------------------------------------------------------------- /data/test_images/img_13129.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13129.jpg -------------------------------------------------------------------------------- /data/test_images/img_13141.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13141.jpg -------------------------------------------------------------------------------- /data/test_images/img_13144.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13144.jpg -------------------------------------------------------------------------------- /data/test_images/img_13147.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13147.jpg -------------------------------------------------------------------------------- /data/test_images/img_13152.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13152.jpg -------------------------------------------------------------------------------- /data/test_images/img_13190.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13190.jpg -------------------------------------------------------------------------------- /data/test_images/img_13193.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13193.jpg -------------------------------------------------------------------------------- /data/test_images/img_13199.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13199.jpg -------------------------------------------------------------------------------- /data/test_images/img_13214.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13214.jpg -------------------------------------------------------------------------------- /data/test_images/img_13225.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13225.jpg -------------------------------------------------------------------------------- /data/test_images/img_13235.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13235.jpg -------------------------------------------------------------------------------- /data/test_images/img_13245.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13245.jpg -------------------------------------------------------------------------------- /data/test_images/img_13271.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13271.jpg -------------------------------------------------------------------------------- /data/test_images/img_13293.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13293.jpg -------------------------------------------------------------------------------- /data/test_images/img_13298.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13298.jpg -------------------------------------------------------------------------------- /data/test_images/img_13326.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13326.jpg -------------------------------------------------------------------------------- /data/test_images/img_13331.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13331.jpg -------------------------------------------------------------------------------- /data/test_images/img_13352.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13352.jpg -------------------------------------------------------------------------------- /data/test_images/img_13370.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13370.jpg -------------------------------------------------------------------------------- /data/test_images/img_13378.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13378.jpg -------------------------------------------------------------------------------- /data/test_images/img_13383.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13383.jpg -------------------------------------------------------------------------------- /data/test_images/img_13395.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13395.jpg -------------------------------------------------------------------------------- /data/test_images/img_13418.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13418.jpg 
-------------------------------------------------------------------------------- /data/test_images/img_13437.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13437.jpg -------------------------------------------------------------------------------- /data/test_images/img_13454.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13454.jpg -------------------------------------------------------------------------------- /data/test_images/img_13459.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13459.jpg -------------------------------------------------------------------------------- /data/test_images/img_13479.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13479.jpg -------------------------------------------------------------------------------- /data/test_images/img_13540.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13540.jpg -------------------------------------------------------------------------------- /data/test_images/img_13561.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13561.jpg -------------------------------------------------------------------------------- /data/test_images/img_13564.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13564.jpg -------------------------------------------------------------------------------- /data/test_images/img_13570.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13570.jpg -------------------------------------------------------------------------------- /data/test_images/img_13573.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13573.jpg -------------------------------------------------------------------------------- /data/test_images/img_13575.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13575.jpg -------------------------------------------------------------------------------- /data/test_images/img_13629.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13629.jpg -------------------------------------------------------------------------------- /data/test_images/img_13643.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13643.jpg -------------------------------------------------------------------------------- /data/test_images/img_13660.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13660.jpg -------------------------------------------------------------------------------- /data/test_images/img_13663.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13663.jpg -------------------------------------------------------------------------------- /data/test_images/img_13692.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13692.jpg -------------------------------------------------------------------------------- /data/test_images/img_13705.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13705.jpg -------------------------------------------------------------------------------- /data/test_images/img_13747.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13747.jpg -------------------------------------------------------------------------------- /data/test_images/img_13752.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13752.jpg -------------------------------------------------------------------------------- /data/test_images/img_13816.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13816.jpg -------------------------------------------------------------------------------- /data/test_images/img_13858.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13858.jpg -------------------------------------------------------------------------------- /data/test_images/img_13869.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13869.jpg -------------------------------------------------------------------------------- /data/test_images/img_13893.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13893.jpg -------------------------------------------------------------------------------- /data/test_images/img_13915.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13915.jpg -------------------------------------------------------------------------------- /data/test_images/img_13941.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/test_images/img_13941.jpg -------------------------------------------------------------------------------- /data/you_result/img_12883.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_12883.jpg -------------------------------------------------------------------------------- /data/you_result/img_12884.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_12884.jpg -------------------------------------------------------------------------------- /data/you_result/img_12903.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_12903.jpg -------------------------------------------------------------------------------- /data/you_result/img_12934.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_12934.jpg -------------------------------------------------------------------------------- /data/you_result/img_12936.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_12936.jpg -------------------------------------------------------------------------------- /data/you_result/img_12965.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_12965.jpg -------------------------------------------------------------------------------- /data/you_result/img_12993.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_12993.jpg -------------------------------------------------------------------------------- /data/you_result/img_13068.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13068.jpg -------------------------------------------------------------------------------- 
/data/you_result/img_13092.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13092.jpg -------------------------------------------------------------------------------- /data/you_result/img_13094.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13094.jpg -------------------------------------------------------------------------------- /data/you_result/img_13095.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13095.jpg -------------------------------------------------------------------------------- /data/you_result/img_13101.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13101.jpg -------------------------------------------------------------------------------- /data/you_result/img_13104.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13104.jpg -------------------------------------------------------------------------------- /data/you_result/img_13105.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13105.jpg -------------------------------------------------------------------------------- /data/you_result/img_13109.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13109.jpg -------------------------------------------------------------------------------- /data/you_result/img_13116.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13116.jpg -------------------------------------------------------------------------------- /data/you_result/img_13117.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13117.jpg -------------------------------------------------------------------------------- /data/you_result/img_13129.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13129.jpg -------------------------------------------------------------------------------- /data/you_result/img_13141.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13141.jpg 
-------------------------------------------------------------------------------- /data/you_result/img_13144.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13144.jpg -------------------------------------------------------------------------------- /data/you_result/img_13147.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13147.jpg -------------------------------------------------------------------------------- /data/you_result/img_13152.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13152.jpg -------------------------------------------------------------------------------- /data/you_result/img_13190.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13190.jpg -------------------------------------------------------------------------------- /data/you_result/img_13193.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13193.jpg -------------------------------------------------------------------------------- /data/you_result/img_13199.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13199.jpg -------------------------------------------------------------------------------- /data/you_result/img_13214.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13214.jpg -------------------------------------------------------------------------------- /data/you_result/img_13225.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13225.jpg -------------------------------------------------------------------------------- /data/you_result/img_13235.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13235.jpg -------------------------------------------------------------------------------- /data/you_result/img_13245.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13245.jpg -------------------------------------------------------------------------------- /data/you_result/img_13271.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13271.jpg -------------------------------------------------------------------------------- /data/you_result/img_13293.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13293.jpg -------------------------------------------------------------------------------- /data/you_result/img_13298.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13298.jpg -------------------------------------------------------------------------------- /data/you_result/img_13326.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13326.jpg -------------------------------------------------------------------------------- /data/you_result/img_13331.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13331.jpg -------------------------------------------------------------------------------- /data/you_result/img_13352.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13352.jpg -------------------------------------------------------------------------------- /data/you_result/img_13370.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13370.jpg -------------------------------------------------------------------------------- /data/you_result/img_13378.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13378.jpg -------------------------------------------------------------------------------- /data/you_result/img_13383.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13383.jpg -------------------------------------------------------------------------------- /data/you_result/img_13395.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13395.jpg -------------------------------------------------------------------------------- /data/you_result/img_13418.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13418.jpg -------------------------------------------------------------------------------- /data/you_result/img_13437.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13437.jpg -------------------------------------------------------------------------------- /data/you_result/img_13454.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13454.jpg -------------------------------------------------------------------------------- /data/you_result/img_13459.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13459.jpg -------------------------------------------------------------------------------- /data/you_result/img_13479.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13479.jpg -------------------------------------------------------------------------------- /data/you_result/img_13540.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13540.jpg -------------------------------------------------------------------------------- /data/you_result/img_13561.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13561.jpg -------------------------------------------------------------------------------- /data/you_result/img_13564.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13564.jpg -------------------------------------------------------------------------------- /data/you_result/img_13570.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13570.jpg -------------------------------------------------------------------------------- /data/you_result/img_13573.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13573.jpg -------------------------------------------------------------------------------- /data/you_result/img_13575.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13575.jpg -------------------------------------------------------------------------------- /data/you_result/img_13629.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13629.jpg -------------------------------------------------------------------------------- 
/data/you_result/img_13643.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13643.jpg -------------------------------------------------------------------------------- /data/you_result/img_13660.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13660.jpg -------------------------------------------------------------------------------- /data/you_result/img_13663.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13663.jpg -------------------------------------------------------------------------------- /data/you_result/img_13692.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13692.jpg -------------------------------------------------------------------------------- /data/you_result/img_13705.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13705.jpg -------------------------------------------------------------------------------- /data/you_result/img_13747.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13747.jpg -------------------------------------------------------------------------------- /data/you_result/img_13752.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13752.jpg -------------------------------------------------------------------------------- /data/you_result/img_13816.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13816.jpg -------------------------------------------------------------------------------- /data/you_result/img_13858.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13858.jpg -------------------------------------------------------------------------------- /data/you_result/img_13869.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13869.jpg -------------------------------------------------------------------------------- /data/you_result/img_13893.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13893.jpg 
--------------------------------------------------------------------------------
/data/you_result/img_13915.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13915.jpg
--------------------------------------------------------------------------------
/data/you_result/img_13941.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/data/you_result/img_13941.jpg
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .onet import ONet
2 | from .rnet import RNet
3 | from .pnet import PNet
4 | 
5 | __all__ = [
6 |     'ONet',
7 |     'PNet',
8 |     'RNet'
9 | ]
--------------------------------------------------------------------------------
/models/lossfn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | class LossFn:
6 |     def __init__(self, device):
7 |         # loss function
8 |         self.loss_cls = nn.BCELoss().to(device)
9 |         self.loss_box = nn.MSELoss().to(device)
10 |         self.loss_landmark = nn.MSELoss().to(device)
11 | 
12 | 
13 |     def cls_loss(self, gt_label, pred_label):
14 |         # keep elements whose label >= 0; only labels 0 and 1 affect the detection loss
15 |         pred_label = torch.squeeze(pred_label)
16 |         mask = torch.ge(gt_label, 0)
17 |         valid_gt_label = torch.masked_select(gt_label, mask).float()
18 |         valid_pred_label = torch.masked_select(pred_label, mask)
19 |         return self.loss_cls(valid_pred_label, valid_gt_label)
20 | 
21 | 
22 |     def box_loss(self, gt_label, gt_offset, pred_offset):
23 |         # keep elements whose label != 0
24 |         mask = torch.ne(gt_label, 0)
25 |         # convert the mask to element indices
26 |         chose_index = torch.nonzero(mask)
27 |         chose_index = torch.squeeze(chose_index)
28 |         # only the selected elements affect the loss
29 |         valid_gt_offset = gt_offset[chose_index, :]
30 |         valid_pred_offset = pred_offset[chose_index, :]
31 |         valid_pred_offset = torch.squeeze(valid_pred_offset)
32 |         return self.loss_box(valid_pred_offset, valid_gt_offset)
33 | 
34 | 
35 |     def landmark_loss(self, gt_label, gt_landmark, pred_landmark):
36 |         mask = torch.eq(gt_label, -2)
37 | 
38 |         chose_index = torch.nonzero(mask.data)
39 |         chose_index = torch.squeeze(chose_index)
40 | 
41 |         valid_gt_landmark = gt_landmark[chose_index, :]
42 |         valid_pred_landmark = pred_landmark[chose_index, :]
43 |         return self.loss_landmark(valid_pred_landmark, valid_gt_landmark)
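The three losses above select samples by their ground-truth label: cls_loss keeps labels >= 0, box_loss keeps labels != 0, and landmark_loss keeps labels == -2. Judging from the preprocessing file names (pos/neg/part/landmark), the convention is presumably 1 = positive, 0 = negative, -1 = part face, -2 = landmark face. A small illustrative sketch with made-up tensors, under that assumed convention:

```python
import torch
from models.lossfn import LossFn

# Four fake samples: positive, negative, part, landmark (assumed label convention).
gt_label = torch.tensor([1., 0., -1., -2.])
pred_label = torch.rand(4, 1)                                # face/non-face probabilities
gt_offset, pred_offset = torch.rand(4, 4), torch.rand(4, 4)
gt_landmark, pred_landmark = torch.rand(4, 10), torch.rand(4, 10)

lossfn = LossFn(torch.device("cpu"))
cls = lossfn.cls_loss(gt_label, pred_label)                       # uses samples 0 and 1
box = lossfn.box_loss(gt_label, gt_offset, pred_offset)           # uses samples 0, 2, 3
lmk = lossfn.landmark_loss(gt_label, gt_landmark, pred_landmark)  # uses sample 3 only
```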
--------------------------------------------------------------------------------
/models/onet.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | 
7 | class ONet(nn.Module):
8 |     ''' ONet '''
9 | 
10 |     def __init__(self):
11 |         super(ONet, self).__init__()
12 | 
13 |         self.pre_layer = nn.Sequential(
14 |             nn.Conv2d(3, 32, kernel_size=3, stride=1),  # conv1
15 |             nn.PReLU(),  # prelu1
16 |             nn.MaxPool2d(kernel_size=3, stride=2),  # pool1
17 |             nn.Conv2d(32, 64, kernel_size=3, stride=1),  # conv2
18 |             nn.PReLU(),  # prelu2
19 |             nn.MaxPool2d(kernel_size=3, stride=2),  # pool2
20 |             nn.Conv2d(64, 64, kernel_size=3, stride=1),  # conv3
21 |             nn.PReLU(),  # prelu3
22 |             nn.MaxPool2d(kernel_size=2, stride=2),  # pool3
23 |             nn.Conv2d(64, 128, kernel_size=2, stride=1),  # conv4
24 |             nn.PReLU()  # prelu4
25 |         )
26 |         self.fc = nn.Linear(128*2*2, 256)
27 |         self.prelu5 = nn.PReLU()  # prelu5
28 |         # detection
29 |         self.conv6_1 = nn.Linear(256, 1)
30 |         # bounding box regression
31 |         self.conv6_2 = nn.Linear(256, 4)
32 |         # landmark localization
33 |         self.conv6_3 = nn.Linear(256, 10)
34 | 
35 |         # He (MSRA) weight initialization
36 |         for m in self.modules():
37 |             if isinstance(m, nn.Conv2d):
38 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
39 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
40 | 
41 |     def forward(self, x):
42 |         # backbone
43 |         x = self.pre_layer(x)
44 |         x = x.view(x.size(0), -1)
45 |         x = self.fc(x)
46 |         x = self.prelu5(x)
47 |         # detection
48 |         det = F.sigmoid(self.conv6_1(x))
49 |         box = self.conv6_2(x)
50 |         landmark = self.conv6_3(x)
51 |         return det, box, landmark
52 | 
--------------------------------------------------------------------------------
/models/pnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import math
4 | 
5 | class PNet(nn.Module):
6 |     ''' PNet '''
7 | 
8 |     def __init__(self):
9 |         super(PNet, self).__init__()
10 | 
11 |         self.pre_layer = nn.Sequential(
12 |             nn.Conv2d(3, 10, kernel_size=3, stride=1),  # conv1
13 |             nn.PReLU(),  # PReLU1
14 |             nn.MaxPool2d(kernel_size=2, stride=2),  # pool1
15 |             nn.Conv2d(10, 16, kernel_size=3, stride=1),  # conv2
16 |             nn.PReLU(),  # PReLU2
17 |             nn.Conv2d(16, 32, kernel_size=3, stride=1),  # conv3
18 |             nn.PReLU()  # PReLU3
19 |         )
20 | 
21 |         # detection
22 |         self.conv4_1 = nn.Conv2d(32, 1, kernel_size=1, stride=1)
23 |         # bounding box regression
24 |         self.conv4_2 = nn.Conv2d(32, 4, kernel_size=1, stride=1)
25 | 
26 |         # He (MSRA) weight initialization
27 |         for m in self.modules():
28 |             if isinstance(m, nn.Conv2d):
29 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
30 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
31 | 
32 |     def forward(self, x):
33 |         x = self.pre_layer(x)
34 |         label = F.sigmoid(self.conv4_1(x))
35 |         offset = self.conv4_2(x)
36 |         return label, offset
--------------------------------------------------------------------------------
/models/rnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import math
4 | 
5 | class RNet(nn.Module):
6 |     ''' RNet '''
7 | 
8 |     def __init__(self):
9 |         super(RNet, self).__init__()
10 | 
11 |         self.pre_layer = nn.Sequential(
12 |             nn.Conv2d(3, 28, kernel_size=3, stride=1),  # conv1
13 |             nn.PReLU(),  # prelu1
14 |             nn.MaxPool2d(kernel_size=3, stride=2),  # pool1
15 |             nn.Conv2d(28, 48, kernel_size=3, stride=1),  # conv2
16 |             nn.PReLU(),  # prelu2
17 |             nn.MaxPool2d(kernel_size=3, stride=2),  # pool2
18 |             nn.Conv2d(48, 64, kernel_size=2, stride=1),  # conv3
19 |             nn.PReLU()  # prelu3
20 | 
21 |         )
22 |         self.fc = nn.Linear(64*2*2, 128)
23 |         self.prelu4 = nn.PReLU()  # prelu4
24 |         # detection
25 |         self.conv5_1 = nn.Linear(128, 1)
26 |         # bounding box regression
27 |         self.conv5_2 = nn.Linear(128, 4)
28 | 
29 |         # He (MSRA) weight initialization
30 |         for m in self.modules():
31 |             if isinstance(m, nn.Conv2d):
32 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
33 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
34 | 
35 |     def forward(self, x):
36 |         # backbone
37 |         x = self.pre_layer(x)
38 |         x = x.view(x.size(0), -1)
39 |         x = self.fc(x)
40 |         x = self.prelu4(x)
41 |         # detection
42 |         det = F.sigmoid(self.conv5_1(x))
43 |         box = self.conv5_2(x)
44 |         return det, box
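The three networks take the canonical MTCNN crop sizes (12, 24 and 48 pixels); PNet has no fully connected layer, so it can also slide over a whole image and return a dense score map. A quick shape check of the definitions above, sketched for the PyTorch 0.4.x environment the README specifies:

```python
import torch
from models import PNet, RNet, ONet

pnet, rnet, onet = PNet(), RNet(), ONet()
det, box = pnet(torch.randn(1, 3, 12, 12))            # det (1, 1, 1, 1), box (1, 4, 1, 1)
det, box = rnet(torch.randn(1, 3, 24, 24))            # det (1, 1), box (1, 4)
det, box, landmark = onet(torch.randn(1, 3, 48, 48))  # det (1, 1), box (1, 4), landmark (1, 10)

# Fully convolutional PNet on a larger image: every position of the output map
# scores one 12x12 window, e.g. a 3x100x120 input yields a 1x1x45x55 score map.
score_map, offsets = pnet(torch.randn(1, 3, 100, 120))
```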
/ n)) 34 | 35 | def forward(self, x): 36 | # backend 37 | x = self.pre_layer(x) 38 | x = x.view(x.size(0), -1) 39 | x = self.fc(x) 40 | x = self.prelu4(x) 41 | # detection 42 | det = F.sigmoid(self.conv5_1(x)) 43 | box = self.conv5_2(x) 44 | return det, box -------------------------------------------------------------------------------- /preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/preprocessing/__init__.py -------------------------------------------------------------------------------- /preprocessing/assemble.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy.random as npr 3 | import numpy as np 4 | 5 | 6 | def assemble_data(output_file, anno_file_list=[]): 7 | # assemble the annotations to one file 8 | size = 12 9 | 10 | if len(anno_file_list) == 0: 11 | return 0 12 | 13 | if os.path.exists(output_file): 14 | os.remove(output_file) 15 | 16 | for anno_file in anno_file_list: 17 | with open(anno_file, 'r') as f: 18 | anno_lines = f.readlines() 19 | 20 | base_num = 250000 21 | 22 | if len(anno_lines) > base_num * 3: 23 | idx_keep = npr.choice(len(anno_lines), size=base_num * 3, replace=True) 24 | elif len(anno_lines) > 100000: 25 | idx_keep = npr.choice(len(anno_lines), size=len(anno_lines), replace=True) 26 | else: 27 | idx_keep = np.arange(len(anno_lines)) 28 | np.random.shuffle(idx_keep) 29 | chose_count = 0 30 | with open(output_file, 'a+') as f: 31 | for idx in idx_keep: 32 | f.write(anno_lines[idx]) 33 | chose_count += 1 34 | 35 | return chose_count 36 | -------------------------------------------------------------------------------- /preprocessing/assemble_onet_imglist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import config 3 | import preprocessing.assemble as assemble 4 | 5 | if __name__ == '__main__': 6 | anno_list = [] 7 | 8 | net_landmark_file = os.path.join(config.ANNO_STORE_DIR, config.ONET_LANDMARK_ANNO_FILENAME) 9 | net_postive_file = os.path.join(config.ANNO_STORE_DIR, config.ONET_POSTIVE_ANNO_FILENAME) 10 | net_part_file = os.path.join(config.ANNO_STORE_DIR, config.ONET_PART_ANNO_FILENAME) 11 | net_neg_file = os.path.join(config.ANNO_STORE_DIR, config.ONET_NEGATIVE_ANNO_FILENAME) 12 | 13 | anno_list.append(net_postive_file) 14 | anno_list.append(net_part_file) 15 | anno_list.append(net_neg_file) 16 | anno_list.append(net_landmark_file) 17 | 18 | imglist_file = os.path.join(config.ANNO_STORE_DIR, config.ONET_TRAIN_IMGLIST_FILENAME) 19 | 20 | chose_count = assemble.assemble_data(imglist_file, anno_list) 21 | print("PNet train annotation result file path:%s" % imglist_file) 22 | -------------------------------------------------------------------------------- /preprocessing/assemble_pnet_imglist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import config 3 | import preprocessing.assemble as assemble 4 | 5 | if __name__ == '__main__': 6 | anno_list = [] 7 | 8 | pnet_postive_file = os.path.join( 9 | config.ANNO_STORE_DIR, config.PNET_POSTIVE_ANNO_FILENAME) 10 | pnet_part_file = os.path.join( 11 | config.ANNO_STORE_DIR, config.PNET_PART_ANNO_FILENAME) 12 | pnet_neg_file = os.path.join( 13 | config.ANNO_STORE_DIR, config.PNET_NEGATIVE_ANNO_FILENAME) 14 | 15 | anno_list.append(pnet_postive_file) 16 | anno_list.append(pnet_part_file) 17 | 
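# NOTE: only the positive, part and negative lists are merged for P-Net training;
# the landmark annotations written by gen_landmark_12.py are not appended here.
# Among the assemble scripts, only assemble_onet_imglist.py above merges a
# landmark list into its training image list.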
anno_list.append(pnet_neg_file) 18 | 19 | imglist_filename = config.PNET_TRAIN_IMGLIST_FILENAME 20 | anno_dir = config.ANNO_STORE_DIR 21 | imglist_file = os.path.join(anno_dir, imglist_filename) 22 | 23 | chose_count = assemble.assemble_data(imglist_file, anno_list) 24 | print("PNet train annotation result file path:%s" % imglist_file) 25 | -------------------------------------------------------------------------------- /preprocessing/assemble_rnet_imglist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import config 3 | import preprocessing.assemble as assemble 4 | 5 | if __name__ == '__main__': 6 | anno_list = [] 7 | 8 | # rnet_landmark_file = os.path.join(config.ANNO_STORE_DIR,config.RNET_LANDMARK_ANNO_FILENAME) 9 | rnet_postive_file = os.path.join(config.ANNO_STORE_DIR, config.RNET_POSTIVE_ANNO_FILENAME) 10 | rnet_part_file = os.path.join(config.ANNO_STORE_DIR, config.RNET_PART_ANNO_FILENAME) 11 | rnet_neg_file = os.path.join(config.ANNO_STORE_DIR, config.RNET_NEGATIVE_ANNO_FILENAME) 12 | 13 | anno_list.append(rnet_postive_file) 14 | anno_list.append(rnet_part_file) 15 | anno_list.append(rnet_neg_file) 16 | # anno_list.append(rnet_landmark_file) 17 | 18 | imglist_file = os.path.join(config.ANNO_STORE_DIR, config.RNET_TRAIN_IMGLIST_FILENAME) 19 | 20 | chose_count = assemble.assemble_data(imglist_file, anno_list) 21 | print("PNet train annotation result file path:%s, total num of imgs: %d" % (imglist_file, chose_count)) 22 | -------------------------------------------------------------------------------- /preprocessing/gen_landmark_12.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | 5 | import numpy.random as npr 6 | import argparse 7 | import config 8 | import tools.utils as utils 9 | 10 | 11 | def gen_data(anno_file, data_dir, prefix): 12 | size = 12 13 | 14 | landmark_imgs_save_dir = os.path.join(data_dir, "12/landmark") 15 | if not os.path.exists(landmark_imgs_save_dir): 16 | os.makedirs(landmark_imgs_save_dir) 17 | 18 | anno_dir = config.ANNO_STORE_DIR 19 | if not os.path.exists(anno_dir): 20 | os.makedirs(anno_dir) 21 | 22 | landmark_anno_filename = config.PNET_LANDMARK_ANNO_FILENAME 23 | save_landmark_anno = os.path.join(anno_dir, landmark_anno_filename) 24 | 25 | f = open(save_landmark_anno, 'w') 26 | 27 | with open(anno_file, 'r') as f2: 28 | annotations = f2.readlines() 29 | 30 | num = len(annotations) 31 | print("%d total images" % num) 32 | 33 | l_idx = 0 34 | idx = 0 35 | # image_path bbox landmark(5*2) 36 | for annotation in annotations: 37 | # print imgPath 38 | 39 | annotation = annotation.strip().split(' ') 40 | assert len(annotation) == 15, "each line should have 15 element" 41 | im_path = os.path.join(prefix, annotation[0].replace("\\", "/")) 42 | 43 | gt_box = list(map(float, annotation[1:5])) 44 | # the bounging box in original anno_file is [left, right, top, bottom] 45 | gt_box = [gt_box[0], gt_box[2], gt_box[1], 46 | gt_box[3]] # [left, top, right, bottom] 47 | gt_box = np.array(gt_box, dtype=np.int32) 48 | 49 | landmark = map(float, annotation[5:]) 50 | landmark = np.array(landmark, dtype=np.float) 51 | 52 | img = cv2.imread(im_path) 53 | assert (img is not None) 54 | 55 | height, width, channel = img.shape 56 | # crop_face = img[gt_box[1]:gt_box[3]+1, gt_box[0]:gt_box[2]+1] 57 | # crop_face = cv2.resize(crop_face,(size,size)) 58 | 59 | idx = idx + 1 60 | if idx % 100 == 0: 61 | print("%d images done, landmark images: 
%d" % (idx, l_idx)) 62 | 63 | x1, y1, x2, y2 = gt_box 64 | 65 | # gt's width 66 | w = x2 - x1 67 | # gt's height 68 | h = y2 - y1 69 | if max(w, h) < 40 or x1 < 0 or y1 < 0: 70 | continue 71 | # random shift 72 | for i in range(20): 73 | bbox_size = npr.randint( 74 | int(min(w, h) * 0.8), np.ceil(1.25 * max(w, h))) 75 | delta_x = npr.randint(-w * 0.2, w * 0.2) 76 | delta_y = npr.randint(-h * 0.2, h * 0.2) 77 | nx1 = max(x1 + w / 2 - bbox_size / 2 + delta_x, 0) 78 | ny1 = max(y1 + h / 2 - bbox_size / 2 + delta_y, 0) 79 | 80 | nx2 = nx1 + bbox_size 81 | ny2 = ny1 + bbox_size 82 | if nx2 > width or ny2 > height: 83 | continue 84 | crop_box = np.array([nx1, ny1, nx2, ny2]) 85 | cropped_im = img[ny1:ny2, nx1:nx2, :] 86 | resized_im = cv2.resize( 87 | cropped_im, (size, size), interpolation=cv2.INTER_LINEAR) 88 | 89 | offset_x1 = (x1 - nx1) / float(bbox_size) 90 | offset_y1 = (y1 - ny1) / float(bbox_size) 91 | offset_x2 = (x2 - nx2) / float(bbox_size) 92 | offset_y2 = (y2 - ny2) / float(bbox_size) 93 | 94 | offset_left_eye_x = (landmark[0] - nx1) / float(bbox_size) 95 | offset_left_eye_y = (landmark[1] - ny1) / float(bbox_size) 96 | 97 | offset_right_eye_x = (landmark[2] - nx1) / float(bbox_size) 98 | offset_right_eye_y = (landmark[3] - ny1) / float(bbox_size) 99 | 100 | offset_nose_x = (landmark[4] - nx1) / float(bbox_size) 101 | offset_nose_y = (landmark[5] - ny1) / float(bbox_size) 102 | 103 | offset_left_mouth_x = (landmark[6] - nx1) / float(bbox_size) 104 | offset_left_mouth_y = (landmark[7] - ny1) / float(bbox_size) 105 | 106 | offset_right_mouth_x = (landmark[8] - nx1) / float(bbox_size) 107 | offset_right_mouth_y = (landmark[9] - ny1) / float(bbox_size) 108 | 109 | # cal iou 110 | iou = utils.IoU(crop_box.astype(np.float), 111 | np.expand_dims(gt_box.astype(np.float), 0)) 112 | if iou > 0.65: 113 | save_file = os.path.join( 114 | landmark_imgs_save_dir, "%s.jpg" % l_idx) 115 | cv2.imwrite(save_file, resized_im) 116 | 117 | f.write(save_file + ' -2 %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f \n' % 118 | (offset_x1, offset_y1, offset_x2, offset_y2, 119 | offset_left_eye_x, offset_left_eye_y, offset_right_eye_x, offset_right_eye_y, offset_nose_x, 120 | offset_nose_y, offset_left_mouth_x, offset_left_mouth_y, offset_right_mouth_x, 121 | offset_right_mouth_y)) 122 | 123 | l_idx += 1 124 | 125 | f.close() 126 | 127 | 128 | def parse_args(): 129 | parser = argparse.ArgumentParser(description='Test mtcnn', 130 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 131 | 132 | parser.add_argument('--traindata_store', dest='traindata_store', help='dface train data temporary folder', 133 | default=config.TRAIN_DATA_DIR, type=str) 134 | parser.add_argument('--anno_file', dest='annotation_file', help='landmark dataset original annotation file', 135 | default=os.path.join(config.ANNO_STORE_DIR, "landmark_imagelist.txt"), type=str) 136 | parser.add_argument('--prefix_path', dest='prefix_path', help='annotation file image prefix root path', 137 | default='/home/wujiyang/FaceProjects/MTCNN_TRAIN/training_data/landmark_train', type=str) 138 | 139 | args = parser.parse_args() 140 | return args 141 | 142 | 143 | if __name__ == '__main__': 144 | args = parse_args() 145 | gen_data(args.annotation_file, args.traindata_store, args.prefix_path) 146 | -------------------------------------------------------------------------------- /preprocessing/gen_landmark_24.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 
| import cv2 5 | import numpy as np 6 | import numpy.random as npr 7 | 8 | import config 9 | import tools.utils as utils 10 | 11 | 12 | def gen_data(anno_file, data_dir, prefix): 13 | size = 24 14 | 15 | landmark_imgs_save_dir = os.path.join(data_dir, "24/landmark") 16 | if not os.path.exists(landmark_imgs_save_dir): 17 | os.makedirs(landmark_imgs_save_dir) 18 | 19 | anno_dir = config.ANNO_STORE_DIR 20 | if not os.path.exists(anno_dir): 21 | os.makedirs(anno_dir) 22 | 23 | landmark_anno_filename = config.RNET_LANDMARK_ANNO_FILENAME 24 | save_landmark_anno = os.path.join(anno_dir, landmark_anno_filename) 25 | 26 | f = open(save_landmark_anno, 'w') 27 | 28 | with open(anno_file, 'r') as f2: 29 | annotations = f2.readlines() 30 | 31 | num = len(annotations) 32 | print("%d total images" % num) 33 | 34 | l_idx = 0 35 | idx = 0 36 | # image_path bbox landmark(5*2) 37 | for annotation in annotations: 38 | # print imgPath 39 | 40 | annotation = annotation.strip().split(' ') 41 | assert len(annotation) == 15, "each line should have 15 element" 42 | im_path = os.path.join(prefix, annotation[0].replace("\\", "/")) 43 | 44 | gt_box = map(float, annotation[1:5]) 45 | # the bounging box in original anno_file is [left, right, top, bottom] 46 | gt_box = [gt_box[0], gt_box[2], gt_box[1], gt_box[3]] # [left, top, right, bottom] 47 | gt_box = np.array(gt_box, dtype=np.int32) 48 | 49 | landmark = map(float, annotation[5:]) 50 | landmark = np.array(landmark, dtype=np.float) 51 | 52 | img = cv2.imread(im_path) 53 | assert (img is not None) 54 | 55 | height, width, channel = img.shape 56 | # crop_face = img[gt_box[1]:gt_box[3]+1, gt_box[0]:gt_box[2]+1] 57 | # crop_face = cv2.resize(crop_face,(size,size)) 58 | 59 | idx = idx + 1 60 | if idx % 100 == 0: 61 | print("%d images done, landmark images: %d" % (idx, l_idx)) 62 | 63 | x1, y1, x2, y2 = gt_box 64 | 65 | # gt's width 66 | w = x2 - x1 67 | # gt's height 68 | h = y2 - y1 69 | if max(w, h) < 40 or x1 < 0 or y1 < 0: 70 | continue 71 | # random shift 72 | for i in range(20): 73 | bbox_size = npr.randint(int(min(w, h) * 0.8), np.ceil(1.25 * max(w, h))) 74 | delta_x = npr.randint(-w * 0.2, w * 0.2) 75 | delta_y = npr.randint(-h * 0.2, h * 0.2) 76 | nx1 = max(x1 + w / 2 - bbox_size / 2 + delta_x, 0) 77 | ny1 = max(y1 + h / 2 - bbox_size / 2 + delta_y, 0) 78 | 79 | nx2 = nx1 + bbox_size 80 | ny2 = ny1 + bbox_size 81 | if nx2 > width or ny2 > height: 82 | continue 83 | crop_box = np.array([nx1, ny1, nx2, ny2]) 84 | cropped_im = img[ny1:ny2 + 1, nx1:nx2 + 1, :] 85 | resized_im = cv2.resize(cropped_im, (size, size), interpolation=cv2.INTER_LINEAR) 86 | 87 | offset_x1 = (x1 - nx1) / float(bbox_size) 88 | offset_y1 = (y1 - ny1) / float(bbox_size) 89 | offset_x2 = (x2 - nx2) / float(bbox_size) 90 | offset_y2 = (y2 - ny2) / float(bbox_size) 91 | 92 | offset_left_eye_x = (landmark[0] - nx1) / float(bbox_size) 93 | offset_left_eye_y = (landmark[1] - ny1) / float(bbox_size) 94 | 95 | offset_right_eye_x = (landmark[2] - nx1) / float(bbox_size) 96 | offset_right_eye_y = (landmark[3] - ny1) / float(bbox_size) 97 | 98 | offset_nose_x = (landmark[4] - nx1) / float(bbox_size) 99 | offset_nose_y = (landmark[5] - ny1) / float(bbox_size) 100 | 101 | offset_left_mouth_x = (landmark[6] - nx1) / float(bbox_size) 102 | offset_left_mouth_y = (landmark[7] - ny1) / float(bbox_size) 103 | 104 | offset_right_mouth_x = (landmark[8] - nx1) / float(bbox_size) 105 | offset_right_mouth_y = (landmark[9] - ny1) / float(bbox_size) 106 | 107 | # cal iou 108 | iou = utils.IoU(crop_box.astype(np.float), 
np.expand_dims(gt_box.astype(np.float), 0)) 109 | if iou > 0.65: 110 | save_file = os.path.join(landmark_imgs_save_dir, "%s.jpg" % l_idx) 111 | cv2.imwrite(save_file, resized_im) 112 | 113 | f.write(save_file + ' -2 %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f \n' % \ 114 | (offset_x1, offset_y1, offset_x2, offset_y2, offset_left_eye_x, offset_left_eye_y, 115 | offset_right_eye_x, offset_right_eye_y, offset_nose_x, 116 | offset_nose_y, offset_left_mouth_x, offset_left_mouth_y, offset_right_mouth_x, 117 | offset_right_mouth_y)) 118 | 119 | l_idx += 1 120 | 121 | f.close() 122 | 123 | 124 | def parse_args(): 125 | parser = argparse.ArgumentParser(description='Test mtcnn', 126 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 127 | 128 | parser.add_argument('--traindata_store', dest='traindata_store', help='dface train data temporary folder', 129 | default=config.TRAIN_DATA_DIR, type=str) 130 | parser.add_argument('--anno_file', dest='annotation_file', help='landmark dataset original annotation file', 131 | default=os.path.join(config.ANNO_STORE_DIR, "landmark_imagelist.txt"), type=str) 132 | parser.add_argument('--prefix_path', dest='prefix_path', help='annotation file image prefix root path', 133 | default='/home/wujiyang/FaceProjects/MTCNN_TRAIN/training_data/landmark_train', type=str) 134 | 135 | args = parser.parse_args() 136 | return args 137 | 138 | 139 | if __name__ == '__main__': 140 | args = parse_args() 141 | 142 | gen_data(args.annotation_file, args.traindata_store, args.prefix_path) 143 | -------------------------------------------------------------------------------- /preprocessing/gen_landmark_48.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | 4 | import os 5 | import cv2 6 | import numpy as np 7 | import numpy.random as npr 8 | import argparse 9 | import config 10 | import tools.utils as utils 11 | 12 | 13 | def gen_data(anno_file, data_dir, prefix): 14 | size = 48 15 | 16 | landmark_imgs_save_dir = os.path.join(data_dir, "48/landmark") 17 | if not os.path.exists(landmark_imgs_save_dir): 18 | os.makedirs(landmark_imgs_save_dir) 19 | 20 | anno_dir = config.ANNO_STORE_DIR 21 | if not os.path.exists(anno_dir): 22 | os.makedirs(anno_dir) 23 | 24 | landmark_anno_filename = config.ONET_LANDMARK_ANNO_FILENAME 25 | save_landmark_anno = os.path.join(anno_dir, landmark_anno_filename) 26 | 27 | f = open(save_landmark_anno, 'w') 28 | 29 | with open(anno_file, 'r') as f2: 30 | annotations = f2.readlines() 31 | 32 | num = len(annotations) 33 | print("%d total images" % num) 34 | 35 | l_idx = 0 36 | idx = 0 37 | # image_path bbox landmark(5*2) 38 | for annotation in annotations: 39 | # print imgPath 40 | 41 | annotation = annotation.strip().split(' ') 42 | assert len(annotation) == 15, "each line should have 15 element" 43 | im_path = os.path.join(prefix, annotation[0].replace("\\", "/")) 44 | 45 | gt_box = list(map(float, annotation[1:5])) 46 | # the bounging box in original anno_file is [left, right, top, bottom] 47 | gt_box = [gt_box[0], gt_box[2], gt_box[1], 48 | gt_box[3]] # [left, top, right, bottom] 49 | gt_box = np.array(gt_box, dtype=np.int32) 50 | 51 | landmark = list(map(float, annotation[5:])) 52 | landmark = np.array(landmark, dtype=np.float) 53 | 54 | img = cv2.imread(im_path) 55 | assert (img is not None) 56 | 57 | height, width, channel = img.shape 58 | # crop_face = img[gt_box[1]:gt_box[3]+1, gt_box[0]:gt_box[2]+1] 59 | # crop_face = 
cv2.resize(crop_face,(size,size)) 60 | 61 | idx = idx + 1 62 | if idx % 100 == 0: 63 | print("%d images done, landmark images: %d" % (idx, l_idx)) 64 | 65 | x1, y1, x2, y2 = gt_box 66 | 67 | # gt's width 68 | w = x2 - x1 69 | # gt's height 70 | h = y2 - y1 71 | if max(w, h) < 40 or x1 < 0 or y1 < 0: 72 | continue 73 | # random shift 74 | for i in range(20): 75 | bbox_size = npr.randint( 76 | int(min(w, h) * 0.8), np.ceil(1.25 * max(w, h))) 77 | delta_x = npr.randint(-w * 0.2, w * 0.2) 78 | delta_y = npr.randint(-h * 0.2, h * 0.2) 79 | nx1 = int(max(x1 + w / 2 - bbox_size / 2 + delta_x, 0)) 80 | ny1 = int(max(y1 + h / 2 - bbox_size / 2 + delta_y, 0)) 81 | 82 | nx2 = int(nx1 + bbox_size) 83 | ny2 = int(ny1 + bbox_size) 84 | if nx2 > width or ny2 > height: 85 | continue 86 | crop_box = np.array([nx1, ny1, nx2, ny2]) 87 | cropped_im = img[ny1:ny2, nx1:nx2, :] 88 | resized_im = cv2.resize( 89 | cropped_im, (size, size), interpolation=cv2.INTER_LINEAR) 90 | 91 | offset_x1 = (x1 - nx1) / float(bbox_size) 92 | offset_y1 = (y1 - ny1) / float(bbox_size) 93 | offset_x2 = (x2 - nx2) / float(bbox_size) 94 | offset_y2 = (y2 - ny2) / float(bbox_size) 95 | 96 | offset_left_eye_x = (landmark[0] - nx1) / float(bbox_size) 97 | offset_left_eye_y = (landmark[1] - ny1) / float(bbox_size) 98 | 99 | offset_right_eye_x = (landmark[2] - nx1) / float(bbox_size) 100 | offset_right_eye_y = (landmark[3] - ny1) / float(bbox_size) 101 | 102 | offset_nose_x = (landmark[4] - nx1) / float(bbox_size) 103 | offset_nose_y = (landmark[5] - ny1) / float(bbox_size) 104 | 105 | offset_left_mouth_x = (landmark[6] - nx1) / float(bbox_size) 106 | offset_left_mouth_y = (landmark[7] - ny1) / float(bbox_size) 107 | 108 | offset_right_mouth_x = (landmark[8] - nx1) / float(bbox_size) 109 | offset_right_mouth_y = (landmark[9] - ny1) / float(bbox_size) 110 | 111 | # cal iou 112 | iou = utils.IoU(crop_box.astype(np.float), 113 | np.expand_dims(gt_box.astype(np.float), 0)) 114 | if iou > 0.65: 115 | save_file = os.path.join( 116 | landmark_imgs_save_dir, "%s.jpg" % l_idx) 117 | cv2.imwrite(save_file, resized_im) 118 | 119 | f.write(save_file + ' -2 %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f \n' % 120 | (offset_x1, offset_y1, offset_x2, offset_y2, 121 | offset_left_eye_x, offset_left_eye_y, offset_right_eye_x, offset_right_eye_y, offset_nose_x, 122 | offset_nose_y, offset_left_mouth_x, offset_left_mouth_y, offset_right_mouth_x, 123 | offset_right_mouth_y)) 124 | 125 | l_idx += 1 126 | 127 | f.close() 128 | 129 | 130 | def parse_args(): 131 | parser = argparse.ArgumentParser(description='Test mtcnn', 132 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 133 | 134 | parser.add_argument('--traindata_store', dest='traindata_store', help='dface train data temporary folder', 135 | default=config.TRAIN_DATA_DIR, type=str) 136 | parser.add_argument('--anno_file', dest='annotation_file', help='landmark dataset original annotation file', 137 | default=os.path.join(config.ANNO_STORE_DIR, "landmark_imagelist.txt"), type=str) 138 | parser.add_argument('--prefix_path', dest='prefix_path', help='annotation file image prefix root path', 139 | default='/home/xiezheng/datasets/Face/FaceAlignment/train', type=str) 140 | 141 | args = parser.parse_args() 142 | return args 143 | 144 | 145 | if __name__ == '__main__': 146 | args = parse_args() 147 | 148 | gen_data(args.annotation_file, args.traindata_store, args.prefix_path) 149 | 150 | -------------------------------------------------------------------------------- 
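All three gen_landmark_* scripts above, and the gen_pnet/rnet/onet_data scripts that follow, encode their regression targets the same way: bounding-box offsets are (ground-truth corner - crop corner) / crop_size, landmark offsets are (landmark point - crop top-left corner) / crop_size, and the label written in front of them is 1 for positives, 0 for negatives, -1 for part faces and -2 for landmark samples, which are exactly the values the masks in models/lossfn.py key on. A minimal sketch of the box encode/decode pair, with illustrative helper names that are not part of this repository:

```python
import numpy as np

def encode_offsets(crop_box, gt_box):
    # crop_box, gt_box: [x1, y1, x2, y2]; the generated crops are square, so width == height
    size = float(crop_box[2] - crop_box[0])
    return (np.asarray(gt_box, dtype=np.float64) - np.asarray(crop_box, dtype=np.float64)) / size

def decode_offsets(crop_box, offsets):
    # inverse mapping: gt = crop + size * offset (the derivation noted in gen_pnet_data.py)
    size = float(crop_box[2] - crop_box[0])
    return np.asarray(crop_box, dtype=np.float64) + size * np.asarray(offsets, dtype=np.float64)

crop = [10, 20, 110, 120]                  # a 100-pixel square crop
gt = [15, 25, 105, 115]                    # ground truth shrunk by 5 pixels on each side
off = encode_offsets(crop, gt)             # [0.05, 0.05, -0.05, -0.05]
print(np.allclose(decode_offsets(crop, off), gt))   # True
```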
/preprocessing/gen_onet_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('./') 4 | 5 | import cv2 6 | import argparse 7 | import numpy as np 8 | import os 9 | import pickle 10 | import time 11 | 12 | import config 13 | import tools.vision as vision 14 | from tools.train_detect import MtcnnDetector 15 | from tools.imagedb import ImageDB 16 | from tools.image_reader import TestImageLoader 17 | from tools.utils import IoU, convert_to_square 18 | 19 | 20 | def gen_onet_data(data_dir, anno_file, pnet_model_file, rnet_model_file, prefix_path='', use_cuda=True, vis=False): 21 | mtcnn_detector = MtcnnDetector(p_model_path=pnet_model_file, 22 | r_model_path=rnet_model_file, 23 | o_model_path=None, 24 | min_face_size=12, 25 | use_cuda=True) 26 | 27 | imagedb = ImageDB(anno_file, mode="test", prefix_path=prefix_path) 28 | imdb = imagedb.load_imdb() 29 | image_reader = TestImageLoader(imdb, 1, False) 30 | 31 | all_boxes = list() 32 | batch_idx = 0 33 | 34 | for databatch in image_reader: 35 | if batch_idx % 100 == 0: 36 | print("%d images done" % batch_idx) 37 | im = databatch 38 | t = time.time() 39 | # detect an image by pnet and rnet 40 | p_boxes, p_boxes_align = mtcnn_detector.detect_pnet(im=im) 41 | boxes, boxes_align = mtcnn_detector.detect_rnet(im=im, dets=p_boxes_align) 42 | if boxes_align is None: 43 | all_boxes.append(np.array([])) 44 | batch_idx += 1 45 | continue 46 | if vis: 47 | vision.vis_face(im, boxes_align) 48 | 49 | t1 = time.time() - t 50 | print('time cost for image ', batch_idx, '/', image_reader.size, ': ', t1) 51 | all_boxes.append(boxes_align) 52 | batch_idx += 1 53 | 54 | save_path = config.TRAIN_DATA_DIR 55 | if not os.path.exists(save_path): 56 | os.mkdir(save_path) 57 | 58 | save_file = os.path.join( 59 | save_path, "pnet_rnet_detections_%d.pkl" % int(time.time())) 60 | 61 | with open(save_file, 'wb') as f: 62 | pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL) 63 | 64 | # save_file = '/home/liujing/Codes/MTCNN/data/pnet_detections_1532582821.pkl' 65 | get_onet_sample_data(data_dir, anno_file, save_file, prefix_path) 66 | 67 | 68 | def get_onet_sample_data(data_dir, anno_file, det_boxs_file, prefix): 69 | neg_save_dir = os.path.join(data_dir, "48/negative") 70 | pos_save_dir = os.path.join(data_dir, "48/positive") 71 | part_save_dir = os.path.join(data_dir, "48/part") 72 | 73 | for dir_path in [neg_save_dir, pos_save_dir, part_save_dir]: 74 | if not os.path.exists(dir_path): 75 | os.makedirs(dir_path) 76 | 77 | # load ground truth from annotation file 78 | # format of each line: image/path [x1,y1,x2,y2] for each gt_box in this image 79 | 80 | with open(anno_file, 'r') as f: 81 | annotations = f.readlines() 82 | 83 | image_size = 48 84 | im_idx_list = list() 85 | gt_boxes_list = list() 86 | num_of_images = len(annotations) 87 | print("processing %d images in total" % num_of_images) 88 | 89 | for annotation in annotations: 90 | annotation = annotation.strip().split(' ') 91 | im_idx = os.path.join(prefix, annotation[0]) 92 | 93 | boxes = list(map(float, annotation[1:])) 94 | boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4) 95 | im_idx_list.append(im_idx) 96 | gt_boxes_list.append(boxes) 97 | 98 | save_path = config.ANNO_STORE_DIR 99 | if not os.path.exists(save_path): 100 | os.makedirs(save_path) 101 | 102 | f1 = open(os.path.join(save_path, 'pos_%d.txt' % image_size), 'w') 103 | f2 = open(os.path.join(save_path, 'neg_%d.txt' % image_size), 'w') 104 | f3 = open(os.path.join(save_path, 
'part_%d.txt' % image_size), 'w') 105 | 106 | det_handle = open(det_boxs_file, 'rb') 107 | det_boxes = pickle.load(det_handle) 108 | print(len(det_boxes), num_of_images) 109 | assert len(det_boxes) == num_of_images, "incorrect detections or ground truths" 110 | 111 | # index of neg, pos and part face, used as their image names 112 | n_idx = 0 113 | p_idx = 0 114 | d_idx = 0 115 | image_done = 0 116 | for im_idx, dets, gts in zip(im_idx_list, det_boxes, gt_boxes_list): 117 | image_done += 1 118 | if image_done % 100 == 0: 119 | print("%d images done" % image_done) 120 | if dets.shape[0] == 0: 121 | continue 122 | img = cv2.imread(im_idx) 123 | dets = convert_to_square(dets) 124 | dets[:, 0:4] = np.round(dets[:, 0:4]) 125 | 126 | # each image have at most 50 neg_samples 127 | cur_n_idx = 0 128 | for box in dets: 129 | x_left, y_top, x_right, y_bottom = box[0:4].astype(int) 130 | width = x_right - x_left 131 | height = y_bottom - y_top 132 | # ignore box that is too small or beyond image border 133 | if width < 20 or x_left < 0 or y_top < 0 or x_right > img.shape[1] - 1 or y_bottom > img.shape[0] - 1: 134 | continue 135 | # compute intersection over union(IoU) between current box and all gt boxes 136 | Iou = IoU(box, gts) 137 | cropped_im = img[y_top:y_bottom, x_left:x_right, :] 138 | resized_im = cv2.resize(cropped_im, (image_size, image_size), 139 | interpolation=cv2.INTER_LINEAR) 140 | 141 | # save negative images and write label 142 | if np.max(Iou) < 0.3: 143 | # Iou with all gts must below 0.3 144 | cur_n_idx += 1 145 | if cur_n_idx <= 50: 146 | save_file = os.path.join(neg_save_dir, "%s.jpg" % n_idx) 147 | f2.write(save_file + ' 0\n') 148 | cv2.imwrite(save_file, resized_im) 149 | n_idx += 1 150 | else: 151 | # find gt_box with the highest iou 152 | idx = np.argmax(Iou) 153 | assigned_gt = gts[idx] 154 | x1, y1, x2, y2 = assigned_gt 155 | 156 | # compute bbox reg label 157 | offset_x1 = (x1 - x_left) / float(width) 158 | offset_y1 = (y1 - y_top) / float(height) 159 | offset_x2 = (x2 - x_right) / float(width) 160 | offset_y2 = (y2 - y_bottom) / float(height) 161 | 162 | # save positive and part-face images and write labels 163 | if np.max(Iou) >= 0.65: 164 | save_file = os.path.join(pos_save_dir, "%s.jpg" % p_idx) 165 | f1.write(save_file + ' 1 %.2f %.2f %.2f %.2f\n' % ( 166 | offset_x1, offset_y1, offset_x2, offset_y2)) 167 | cv2.imwrite(save_file, resized_im) 168 | p_idx += 1 169 | 170 | elif np.max(Iou) >= 0.4: 171 | save_file = os.path.join(part_save_dir, "%s.jpg" % d_idx) 172 | f3.write(save_file + ' -1 %.2f %.2f %.2f %.2f\n' % ( 173 | offset_x1, offset_y1, offset_x2, offset_y2)) 174 | cv2.imwrite(save_file, resized_im) 175 | d_idx += 1 176 | f1.close() 177 | f2.close() 178 | f3.close() 179 | 180 | 181 | def model_store_path(): 182 | return os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) + "/model_store" 183 | 184 | 185 | def parse_args(): 186 | parser = argparse.ArgumentParser(description='Test mtcnn', 187 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 188 | 189 | parser.add_argument('--face_traindata_store', dest='traindata_store', help='face train data temporary folder', 190 | default=config.TRAIN_DATA_DIR, type=str) 191 | parser.add_argument('--anno_file', dest='annotation_file', help='wider face original annotation file', 192 | default=os.path.join(config.ANNO_STORE_DIR, "wider_origin_anno.txt"), type=str) 193 | parser.add_argument('--pmodel_file', dest='pnet_model_file', help='PNet model file path', 194 | 
default='./results/pnet/log_bs512_lr0.010_072402/check_point/model_050.pth', type=str) 195 | parser.add_argument('--rmodel_file', dest='rnet_model_file', help='RNet model file path', 196 | default='./results/rnet/log_bs512_lr0.001_072502/check_point/model_050.pth', type=str) 197 | parser.add_argument('--gpu', dest='use_cuda', help='with gpu', 198 | default=config.USE_CUDA, type=bool) 199 | parser.add_argument('--prefix_path', dest='prefix_path', help='annotation file image prefix root path', 200 | default='/home/dataset/WIDER/WIDER_train/images', type=str) 201 | 202 | args = parser.parse_args() 203 | return args 204 | 205 | 206 | if __name__ == '__main__': 207 | args = parse_args() 208 | gen_onet_data(args.traindata_store, args.annotation_file, args.pnet_model_file, args.rnet_model_file, 209 | args.prefix_path, args.use_cuda) 210 | -------------------------------------------------------------------------------- /preprocessing/gen_pnet_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | 4 | import os 5 | import argparse 6 | import numpy as np 7 | import cv2 8 | import numpy.random as npr 9 | from tools.utils import IoU 10 | import config 11 | 12 | 13 | def gen_pnet_data(data_dir, anno_file, prefix): 14 | neg_save_dir = os.path.join(data_dir, "12/negative") 15 | pos_save_dir = os.path.join(data_dir, "12/positive") 16 | part_save_dir = os.path.join(data_dir, "12/part") 17 | 18 | for dir_path in [neg_save_dir, pos_save_dir, part_save_dir]: # make 19 | if not os.path.exists(dir_path): 20 | os.makedirs(dir_path) 21 | 22 | # save_dir = os.path.join(data_dir, "pnet") 23 | # if not os.path.exists(save_dir): 24 | # os.mkdir(save_dir) 25 | 26 | post_save_file = os.path.join( 27 | config.ANNO_STORE_DIR, config.PNET_POSTIVE_ANNO_FILENAME) 28 | neg_save_file = os.path.join( 29 | config.ANNO_STORE_DIR, config.PNET_NEGATIVE_ANNO_FILENAME) 30 | part_save_file = os.path.join( 31 | config.ANNO_STORE_DIR, config.PNET_PART_ANNO_FILENAME) 32 | 33 | f1 = open(post_save_file, 'w') 34 | f2 = open(neg_save_file, 'w') 35 | f3 = open(part_save_file, 'w') 36 | 37 | with open(anno_file, 'r') as f: 38 | annotations = f.readlines() 39 | 40 | num = len(annotations) 41 | print("%d pics in total" % num) 42 | 43 | p_idx = 0 # positive examples index 44 | n_idx = 0 # negative examples index 45 | d_idx = 0 # partface examples index 46 | idx = 0 # pics index 47 | box_idx = 0 # boxes index 48 | 49 | for annotation in annotations: 50 | annotation = annotation.strip().split(' ') 51 | im_path = os.path.join(prefix, annotation[0]) # image_path 52 | # print(im_path) 53 | bbox = list(map(float, annotation[1:])) # map()函数是将func作用于seq中的每一个元素,并将所有的调用的结果作为一个list返回 54 | boxes = np.array(bbox, dtype=np.int32).reshape(-1, 4) # N*4 dim array 55 | img = cv2.imread(im_path) 56 | idx += 1 57 | 58 | if idx % 100 == 0: 59 | print(idx, "images done") 60 | 61 | height, width, channel = img.shape 62 | 63 | neg_num = 0 64 | while neg_num < 50: 65 | size = npr.randint(12, min(width, height) / 2) 66 | nx = npr.randint(0, width - size) 67 | ny = npr.randint(0, height - size) 68 | crop_box = np.array([nx, ny, nx + size, ny + size]) 69 | 70 | Iou = IoU(crop_box, boxes) 71 | 72 | if np.max(Iou) < 0.3: 73 | # Iou with all gts must below 0.3 74 | save_file = os.path.join(neg_save_dir, "%s.jpg" % n_idx) # save neg image 75 | f2.write(save_file + ' 0\n') 76 | cropped_im = img[ny: ny + size, nx: nx + size, :] 77 | resized_im = cv2.resize(cropped_im, (12, 12), 
interpolation=cv2.INTER_LINEAR) 78 | cv2.imwrite(save_file, resized_im) 79 | n_idx += 1 80 | neg_num += 1 81 | 82 | for box in boxes: 83 | # box (x_left, y_top, x_right, y_bottom) 84 | x1, y1, x2, y2 = box 85 | w = x2 - x1 86 | h = y2 - y1 87 | 88 | # ignore small faces 89 | # in case the ground truth boxes of small faces are not accurate 90 | if max(w, h) < 40 or x1 < 0 or y1 < 0: 91 | continue 92 | 93 | # generate negative examples that have overlap with gt 94 | for i in range(5): 95 | size = npr.randint(12, min(width, height) / 2) 96 | # delta_x and delta_y are offsets of (x1, y1) 97 | delta_x = npr.randint(max(-size, -x1), w) 98 | delta_y = npr.randint(max(-size, -y1), h) 99 | nx1 = max(0, x1 + delta_x) 100 | ny1 = max(0, y1 + delta_y) 101 | 102 | if nx1 + size > width or ny1 + size > height: 103 | continue 104 | crop_box = np.array([nx1, ny1, nx1 + size, ny1 + size]) 105 | Iou = IoU(crop_box, boxes) 106 | 107 | if np.max(Iou) < 0.3: 108 | # Iou with all gts must below 0.3 109 | save_file = os.path.join(neg_save_dir, "%s.jpg" % n_idx) 110 | cropped_im = img[ny1: ny1 + size, nx1: nx1 + size, :] 111 | resized_im = cv2.resize(cropped_im, (12, 12), interpolation=cv2.INTER_LINEAR) 112 | f2.write(save_file + ' 0\n') # neg samples with label 0 113 | cv2.imwrite(save_file, resized_im) 114 | n_idx += 1 115 | 116 | # generate positive examples and part faces 117 | # 每个box随机生成50个box,Iou>=0.65的作为positive examples,0.4<=Iou<0.65的作为part faces,其他忽略 118 | for i in range(20): 119 | size = npr.randint(int(min(w, h) * 0.8), 120 | np.ceil(1.25 * max(w, h))) 121 | 122 | # delta here is the offset of box center 123 | delta_x = npr.randint(-w * 0.2, w * 0.2) 124 | delta_y = npr.randint(-h * 0.2, h * 0.2) 125 | 126 | nx1 = int(max(x1 + w / 2 + delta_x - size / 2, 0)) 127 | ny1 = int(max(y1 + h / 2 + delta_y - size / 2, 0)) 128 | nx2 = int(nx1 + size) 129 | ny2 = int(ny1 + size) 130 | 131 | if nx2 > width or ny2 > height: 132 | continue 133 | crop_box = np.array([nx1, ny1, nx2, ny2]) 134 | 135 | # bbox偏移量的计算,由 x1 = nx1 + float(size)*offset_x1 推导而来 136 | offset_x1 = (x1 - nx1) / float(size) 137 | offset_y1 = (y1 - ny1) / float(size) 138 | offset_x2 = (x2 - nx2) / float(size) 139 | offset_y2 = (y2 - ny2) / float(size) 140 | 141 | cropped_im = img[ny1: ny2, nx1: nx2, :] 142 | resized_im = cv2.resize( 143 | cropped_im, (12, 12), interpolation=cv2.INTER_LINEAR) 144 | 145 | box_ = box.reshape(1, -1) 146 | if IoU(crop_box, box_) >= 0.65: 147 | save_file = os.path.join(pos_save_dir, "%s.jpg" % p_idx) 148 | f1.write(save_file + ' 1 %.2f %.2f %.2f %.2f\n' % 149 | (offset_x1, offset_y1, offset_x2, offset_y2)) 150 | cv2.imwrite(save_file, resized_im) 151 | p_idx += 1 152 | elif IoU(crop_box, box_) >= 0.4: 153 | save_file = os.path.join(part_save_dir, "%s.jpg" % d_idx) 154 | f3.write(save_file + ' -1 %.2f %.2f %.2f %.2f\n' % 155 | (offset_x1, offset_y1, offset_x2, offset_y2)) 156 | cv2.imwrite(save_file, resized_im) 157 | d_idx += 1 158 | box_idx += 1 159 | print("%s images done, pos: %s part: %s neg: %s" % 160 | (idx, p_idx, d_idx, n_idx)) 161 | 162 | f1.close() 163 | f2.close() 164 | f3.close() 165 | 166 | 167 | def parse_args(): 168 | parser = argparse.ArgumentParser(description='generate pnet training data', 169 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 170 | 171 | parser.add_argument('--face_traindata_store', dest='traindata_store', help='face train data temporary folder', 172 | default=config.TRAIN_DATA_DIR, type=str) 173 | parser.add_argument('--anno_file', dest='annotation_file', help='wider face 
original annotation file', 174 | default=os.path.join(config.ANNO_STORE_DIR, "wider_origin_anno.txt"), type=str) 175 | parser.add_argument('--prefix_path', dest='prefix_path', help='annotation file image prefix root path', 176 | default='/home/dataset/WIDER/WIDER_train/images', type=str) 177 | 178 | args = parser.parse_args() 179 | return args 180 | 181 | 182 | if __name__ == '__main__': 183 | args = parse_args() 184 | gen_pnet_data(args.traindata_store, args.annotation_file, args.prefix_path) 185 | -------------------------------------------------------------------------------- /preprocessing/gen_rnet_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | 4 | import argparse 5 | import os 6 | import pickle 7 | 8 | import time 9 | 10 | import numpy as np 11 | 12 | import tools.vision as vision 13 | import config 14 | import cv2 15 | from tools.train_detect import MtcnnDetector 16 | from tools.imagedb import ImageDB 17 | from tools.image_reader import TestImageLoader 18 | from tools.utils import IoU, convert_to_square 19 | 20 | os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID" 21 | os.environ['CUDA_VISIBLE_DEVICES'] = "0" 22 | 23 | 24 | def gen_rnet_data(data_dir, anno_file, pnet_model_file, prefix_path='', use_cuda=True, vis=False): 25 | # load the pnet and pnet_detector 26 | 27 | mtcnn_detector = MtcnnDetector(p_model_path=pnet_model_file, 28 | r_model_path=None, 29 | o_model_path=None, 30 | min_face_size=12, 31 | use_cuda=True) 32 | device = mtcnn_detector.device 33 | 34 | imagedb = ImageDB(anno_file, mode="test", prefix_path=prefix_path) 35 | imdb = imagedb.load_imdb() 36 | image_reader = TestImageLoader(imdb, 1, False) 37 | 38 | all_boxes = [] 39 | batch_idx = 0 40 | 41 | for databatch in image_reader: 42 | if batch_idx % 100 == 0: 43 | print("%d images done" % batch_idx) 44 | im = databatch 45 | t = time.time() 46 | boxes, boxes_align = mtcnn_detector.detect_pnet(im) 47 | if boxes_align is None: 48 | all_boxes.append(np.array([])) 49 | continue 50 | if vis: 51 | vision.vis_face(im, boxes_align) 52 | 53 | t1 = time.time() - t 54 | print('time cost for image {} / {} : {:.4f}'.format(batch_idx, image_reader.size, t1)) 55 | all_boxes.append(boxes_align) 56 | batch_idx += 1 57 | 58 | save_path = config.TRAIN_DATA_DIR 59 | if not os.path.exists(save_path): 60 | os.mkdir(save_path) 61 | 62 | save_file = os.path.join( 63 | save_path, "pnet_detections_%d.pkl" % int(time.time())) 64 | 65 | with open(save_file, 'wb') as f: 66 | pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL) 67 | 68 | # save_file = '/home/liujing/Codes/MTCNN/data/pnet_detections_1532530263.pkl' 69 | get_rnet_sample_data(data_dir, anno_file, save_file, prefix_path) 70 | 71 | 72 | def get_rnet_sample_data(data_dir, anno_file, det_boxes_file, prefix_path): 73 | neg_save_dir = os.path.join(data_dir, "24/negative") 74 | pos_save_dir = os.path.join(data_dir, "24/positive") 75 | part_save_dir = os.path.join(data_dir, "24/part") 76 | 77 | for dir_path in [neg_save_dir, pos_save_dir, part_save_dir]: 78 | if not os.path.exists(dir_path): 79 | os.makedirs(dir_path) 80 | 81 | # load ground truth from annotation file 82 | # format of each line: image/path [x1, y1, x2, y2] for each gt_box in this image 83 | with open(anno_file, 'r') as f: 84 | annotations = f.readlines() 85 | 86 | image_size = 24 87 | im_idx_list = list() 88 | gt_boxes_list = list() 89 | num_of_images = len(annotations) 90 | print("processing %d images in total" % num_of_images) 91 | for 
annotation in annotations: 92 | # for i in range(10): 93 | annotation = annotation.strip().split(' ') 94 | # annotation = annotations[i].strip().split(' ') 95 | im_idx = os.path.join(prefix_path, annotation[0]) 96 | boxes = list(map(float, annotation[1:])) 97 | boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4) 98 | im_idx_list.append(im_idx) 99 | gt_boxes_list.append(boxes) 100 | 101 | save_path = config.ANNO_STORE_DIR 102 | if not os.path.exists(save_path): 103 | os.makedirs(save_path) 104 | 105 | f1 = open(os.path.join(save_path, 'pos_%d.txt' % image_size), 'w') 106 | f2 = open(os.path.join(save_path, 'neg_%d.txt' % image_size), 'w') 107 | f3 = open(os.path.join(save_path, 'part_%d.txt' % image_size), 'w') 108 | 109 | det_handle = open(det_boxes_file, 'rb') 110 | det_boxes = pickle.load(det_handle) 111 | print(len(det_boxes), num_of_images) 112 | assert len(det_boxes) == num_of_images, "incorrect detections or ground truths" 113 | 114 | # index of neg, pos and part face, used as their image names 115 | n_idx = 0 116 | p_idx = 0 117 | d_idx = 0 118 | image_done = 0 119 | for im_idx, dets, gts in zip(im_idx_list, det_boxes, gt_boxes_list): 120 | image_done += 1 121 | if image_done % 100 == 0: 122 | print("%d images done" % image_done) 123 | if dets.shape[0] == 0: 124 | continue 125 | img = cv2.imread(im_idx) 126 | dets = convert_to_square(dets) 127 | dets[:, 0:4] = np.round(dets[:, 0:4]) 128 | 129 | # each image have at most 50 neg_samples 130 | cur_n_idx = 0 131 | for box in dets: 132 | x_left, y_top, x_right, y_bottom = box[0:4].astype(int) 133 | width = x_right - x_left 134 | height = y_bottom - y_top 135 | # ignore box that is too small or beyond image border 136 | if width < 20 or x_left <= 0 or y_top <= 0 or x_right >= img.shape[1] or y_bottom >= img.shape[0]: 137 | continue 138 | # compute intersection over union(IoU) between current box and all gt boxes 139 | Iou = IoU(box, gts) 140 | cropped_im = img[y_top:y_bottom, x_left:x_right, :] 141 | resized_im = cv2.resize(cropped_im, (image_size, image_size), 142 | interpolation=cv2.INTER_LINEAR) 143 | # save negative images and write label 144 | 145 | if np.max(Iou) < 0.3: 146 | # Iou with all gts must below 0.3 147 | cur_n_idx += 1 148 | if cur_n_idx <= 50: 149 | save_file = os.path.join(neg_save_dir, "%s.jpg" % n_idx) 150 | f2.write(save_file + ' 0\n') 151 | cv2.imwrite(save_file, resized_im) 152 | n_idx += 1 153 | else: 154 | # find gt_box with the highest iou 155 | idx = np.argmax(Iou) 156 | assigned_gt = gts[idx] 157 | x1, y1, x2, y2 = assigned_gt 158 | 159 | # compute bbox reg label 160 | offset_x1 = (x1 - x_left) / float(width) 161 | offset_y1 = (y1 - y_top) / float(height) 162 | offset_x2 = (x2 - x_right) / float(width) 163 | offset_y2 = (y2 - y_bottom) / float(height) 164 | 165 | # save positive and part-face images and write labels 166 | if np.max(Iou) >= 0.65: 167 | save_file = os.path.join(pos_save_dir, "%s.jpg" % p_idx) 168 | f1.write(save_file + ' 1 %.2f %.2f %.2f %.2f\n' % ( 169 | offset_x1, offset_y1, offset_x2, offset_y2)) 170 | cv2.imwrite(save_file, resized_im) 171 | p_idx += 1 172 | 173 | elif np.max(Iou) >= 0.4: 174 | save_file = os.path.join(part_save_dir, "%s.jpg" % d_idx) 175 | f3.write(save_file + ' -1 %.2f %.2f %.2f %.2f\n' % ( 176 | offset_x1, offset_y1, offset_x2, offset_y2)) 177 | cv2.imwrite(save_file, resized_im) 178 | d_idx += 1 179 | 180 | f1.close() 181 | f2.close() 182 | f3.close() 183 | 184 | 185 | def parse_args(): 186 | parser = argparse.ArgumentParser(description='Test mtcnn', 187 | 
formatter_class=argparse.ArgumentDefaultsHelpFormatter) 188 | 189 | parser.add_argument('--face_traindata_store', dest='traindata_store', help='dface train data temporary folder', 190 | default=config.TRAIN_DATA_DIR, type=str) 191 | parser.add_argument('--anno_file', dest='annotation_file', help='wider face original annotation file', 192 | default=os.path.join(config.ANNO_STORE_DIR, "wider_origin_anno.txt"), type=str) 193 | parser.add_argument('--pmodel_file', dest='pnet_model_file', help='PNet model file path', 194 | default='./results/pnet/log_bs512_lr0.010_072402/check_point/model_050.pth', type=str) 195 | parser.add_argument('--gpu', dest='use_cuda', help='with gpu', 196 | default=config.USE_CUDA, type=bool) 197 | parser.add_argument('--prefix_path', dest='prefix_path', help='annotation file image prefix root path', 198 | default='/home/dataset/WIDER/WIDER_train/images', type=str) 199 | 200 | args = parser.parse_args() 201 | return args 202 | 203 | 204 | if __name__ == '__main__': 205 | args = parse_args() 206 | gen_rnet_data(args.traindata_store, args.annotation_file, 207 | args.pnet_model_file, args.prefix_path, args.use_cuda) 208 | -------------------------------------------------------------------------------- /test_image.py: -------------------------------------------------------------------------------- 1 | import time 2 | import sys 3 | import pathlib 4 | import logging 5 | import cv2 6 | from tools.test_detect import MtcnnDetector 7 | 8 | logger = logging.getLogger("app") 9 | formatter = logging.Formatter('%(asctime)s %(levelname)-8s: %(message)s') 10 | console_handler = logging.StreamHandler(sys.stdout) 11 | logger.addHandler(console_handler) 12 | logger.setLevel(logging.INFO) 13 | console_handler.formatter = formatter # 也可以直接给formatter赋值 14 | 15 | 16 | def draw_images(img, bboxs, landmarks): # 在图片上绘制人脸框及特征点 17 | num_face = bboxs.shape[0] 18 | for i in range(num_face): 19 | cv2.rectangle(img, (int(bboxs[i, 0]), int(bboxs[i, 1])), (int( 20 | bboxs[i, 2]), int(bboxs[i, 3])), (0, 255, 0), 3) 21 | for p in landmarks: 22 | for i in range(5): 23 | cv2.circle(img, (int(p[2 * i]), int(p[2 * i + 1])), 6, (0, 0, 255), -1) 24 | return img 25 | 26 | 27 | if __name__ == '__main__': 28 | mtcnn_detector = MtcnnDetector(min_face_size=24, use_cuda=False) # 加载模型参数,构造检测器 29 | logger.info("Init the MtcnnDetector.") 30 | project_root = pathlib.Path() 31 | inputPath = project_root / "data" / "test_images" 32 | outputPath = project_root / "data" / "you_result" 33 | outputPath.mkdir(exist_ok=True) 34 | 35 | start = time.time() 36 | for num, input_img_filename in enumerate(inputPath.iterdir()): 37 | logger.info("Start to process No.{} image.".format(num)) 38 | img_name = input_img_filename.name 39 | logger.info("The name of the image is {}.".format(img_name)) 40 | 41 | img = cv2.imread(str(input_img_filename)) 42 | RGB_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 43 | bboxs, landmarks = mtcnn_detector.detect_face(RGB_image) # 检测得到bboxs以及特征点 44 | img = draw_images(img, bboxs, landmarks) # 得到绘制人脸框及特征点的图片 45 | savePath = outputPath / img_name # 图片保存路径 46 | logger.info("Process complete. 
Save image to {}.".format(str(savePath))) 47 | 48 | cv2.imwrite(str(savePath), img) # 保存图片 49 | 50 | logger.info("Finish all the images.") 51 | logger.info("Elapsed time: {:.3f}s".format(time.time() - start)) 52 | -------------------------------------------------------------------------------- /test_on_FDDB.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 18-8-8 下午2:54 4 | # @Author : xiezheng 5 | # @Site : 6 | # @File : test_on_FDDB.py 7 | 8 | import logging 9 | import os 10 | import sys 11 | 12 | import cv2 13 | 14 | from tools.train_detect import MtcnnDetector 15 | 16 | data_dir = '/home/datasets/FDDB' 17 | out_dir = 'results/FDDB-results' 18 | 19 | # log part 20 | logger = logging.getLogger("Test-FDDB") 21 | logger.setLevel(logging.INFO) 22 | formatter = logging.Formatter('%(asctime)s %(levelname)-8s: %(message)s') 23 | # %(asctime)s: 打印日志的时间 %(levelname)s: 打印日志级别名称 %(message)s: 打印日志信息 24 | 25 | # StreamHandler: print log 26 | stream_handler = logging.StreamHandler(sys.stdout) 27 | stream_handler.setLevel(level=logging.INFO) 28 | stream_handler.formatter = formatter # 也可以直接给formatter赋值 29 | logger.addHandler(stream_handler) 30 | 31 | 32 | # FileHandler: save log-file 33 | # filename = os.path.join(out_dir, 'output_%s.log' % (time.strftime("%Y%m%d%H%M%S", time.localtime()))) 34 | # file_handler = logging.FileHandler(filename) 35 | # file_handler.setLevel(level=logging.INFO) 36 | # file_handler.setFormatter(formatter) 37 | # logger.addHandler(file_handler) 38 | 39 | 40 | def get_imdb_fddb(data_dir): 41 | imdb = [] 42 | nfold = 10 43 | for n in range(nfold): 44 | file_name = 'FDDB-folds/FDDB-folds-%02d.txt' % (n + 1) 45 | file_name = os.path.join(data_dir, file_name) 46 | fid = open(file_name, 'r') 47 | image_names = [] 48 | for im_name in fid.readlines(): 49 | image_names.append(im_name.strip('\n')) 50 | imdb.append(image_names) 51 | return imdb 52 | 53 | 54 | if __name__ == "__main__": 55 | mtcnn_detector = MtcnnDetector(p_model_path='./results/pnet/log_bs512_lr0.010_072402/check_point/model_050.pth', 56 | r_model_path='./results/rnet/log_bs512_lr0.001_072502/check_point/model_050.pth', 57 | o_model_path='./results/pnet/log_bs512_lr0.001_0726402/check_point/model_050.pth', 58 | min_face_size=12, 59 | use_cuda=False) 60 | # logger.info("Init the MtcnnDetector.") 61 | imdb = get_imdb_fddb(data_dir) 62 | nfold = len(imdb) 63 | 64 | for i in range(nfold): 65 | image_names = imdb[i] 66 | dets_file_name = os.path.join(out_dir, 'FDDB-det-fold-%02d.txt' % (i + 1)) 67 | fid = open(dets_file_name, 'w') 68 | # image_names_abs = [os.path.join(data_dir, 'originalPics', image_name + '.jpg') for image_name in image_names] 69 | 70 | for idx, im_name in enumerate(image_names): 71 | img_path = os.path.join(data_dir, 'originalPics', im_name + '.jpg') 72 | img = cv2.imread(img_path) 73 | boxes, _ = MtcnnDetector.detect_face(img) 74 | 75 | if boxes is None: 76 | fid.write(im_name + '\n') 77 | fid.write(str(1) + '\n') 78 | fid.write('%f %f %f %f %f\n' % (0, 0, 0, 0, 0.99)) 79 | continue 80 | 81 | fid.write(im_name + '\n') 82 | fid.write(str(len(boxes)) + '\n') 83 | 84 | for box in boxes: 85 | fid.write('%f %f %f %f %f\n' % ( 86 | float(box[0]), float(box[1]), float(box[2] - box[0] + 1), float(box[3] - box[1] + 1), box[4])) 87 | 88 | fid.close() 89 | -------------------------------------------------------------------------------- /test_youModel_images.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 18-8-1 上午9:08 4 | # @Author : xiezheng 5 | # @Site : 6 | # @File : test_youModel_images.py 7 | 8 | import time 9 | import sys 10 | import pathlib 11 | import logging 12 | import cv2 13 | from tools.train_detect import MtcnnDetector 14 | 15 | logger = logging.getLogger("app") 16 | formatter = logging.Formatter('%(asctime)s %(levelname)-8s: %(message)s') 17 | console_handler = logging.StreamHandler(sys.stdout) 18 | logger.addHandler(console_handler) 19 | logger.setLevel(logging.INFO) 20 | console_handler.formatter = formatter # 也可以直接给formatter赋值 21 | 22 | 23 | def draw_images(img, bboxs, landmarks): # 在图片上绘制人脸框及特征点 24 | num_face = bboxs.shape[0] 25 | for i in range(num_face): 26 | cv2.rectangle(img, (int(bboxs[i, 0]), int(bboxs[i, 1])), (int( 27 | bboxs[i, 2]), int(bboxs[i, 3])), (0, 255, 0), 3) 28 | for p in landmarks: 29 | for i in range(5): 30 | cv2.circle(img, (int(p[2 * i]), int(p[2 * i + 1])), 3, (0, 0, 255), -1) 31 | return img 32 | 33 | 34 | if __name__ == '__main__': 35 | mtcnn_detector = MtcnnDetector(p_model_path="./results/pnet/log_bs512_lr0.010_072402/check_point/model_050.pth", 36 | r_model_path="./results/rnet/log_bs512_lr0.001_072502/check_point/model_050.pth", 37 | o_model_path="./results/onet/log_bs512_lr0.001_072602/check_point/model_050.pth", 38 | min_face_size=24, 39 | use_cuda=False) # 加载模型参数,构造检测器 40 | logger.info("Init the MtcnnDetector.") 41 | project_root = pathlib.Path() 42 | inputPath = project_root / "data" / "test_images" 43 | outputPath = project_root / "data" / "you_result" 44 | outputPath.mkdir(exist_ok=True) 45 | 46 | start = time.time() 47 | for num, input_img_filename in enumerate(inputPath.iterdir()): 48 | logger.info("Start to process No.{} image.".format(num)) 49 | img_name = input_img_filename.name 50 | logger.info("The name of the image is {}.".format(img_name)) 51 | 52 | img = cv2.imread(str(input_img_filename)) 53 | RGB_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 54 | bboxs, landmarks = mtcnn_detector.detect_face(RGB_image) # 检测得到bboxs以及特征点 55 | img = draw_images(img, bboxs, landmarks) # 得到绘制人脸框及特征点的图片 56 | savePath = outputPath / img_name # 图片保存路径 57 | logger.info("Process complete. 
Save image to {}.".format(str(savePath))) 58 | 59 | cv2.imwrite(str(savePath), img) # 保存图片 60 | 61 | logger.info("Finish all the images.") 62 | logger.info("Elapsed time: {:.3f}s".format(time.time() - start)) -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .test_detect import MtcnnDetector 2 | 3 | __all__ = [ 4 | 'MtcnnDetector' 5 | ] -------------------------------------------------------------------------------- /tools/image_reader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | class TestImageLoader: 5 | def __init__(self, imdb, batch_size=1, shuffle=False): 6 | self.imdb = imdb 7 | self.batch_size = batch_size 8 | self.shuffle = shuffle 9 | self.size = len(imdb) 10 | self.index = np.arange(self.size) 11 | 12 | self.cur = 0 13 | self.data = None 14 | self.label = None 15 | 16 | self.reset() 17 | self.get_batch() 18 | 19 | def reset(self): 20 | self.cur = 0 21 | if self.shuffle: 22 | np.random.shuffle(self.index) 23 | 24 | def iter_next(self): 25 | return self.cur + self.batch_size <= self.size 26 | 27 | def __iter__(self): 28 | return self 29 | 30 | def __next__(self): 31 | return self.next() 32 | 33 | def next(self): 34 | if self.iter_next(): 35 | self.get_batch() 36 | self.cur += self.batch_size 37 | return self.data 38 | else: 39 | raise StopIteration 40 | 41 | def getindex(self): 42 | return self.cur / self.batch_size 43 | 44 | def getpad(self): 45 | if self.cur + self.batch_size > self.size: 46 | return self.cur + self.batch_size - self.size 47 | else: 48 | return 0 49 | 50 | def get_batch(self): 51 | cur_from = self.cur 52 | cur_to = min(cur_from + self.batch_size, self.size) 53 | imdb = [self.imdb[self.index[i]] for i in range(cur_from, cur_to)] 54 | data = get_testbatch(imdb) 55 | self.data = data['data'] 56 | 57 | def get_testbatch(imdb): 58 | assert len(imdb) == 1, "Single batch only" 59 | im = cv2.imread(imdb[0]['image']) 60 | im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) 61 | data = {'data': im} 62 | return data -------------------------------------------------------------------------------- /tools/image_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision.transforms as transforms 4 | 5 | transform = transforms.Compose([ 6 | transforms.ToTensor(), 7 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 8 | ]) 9 | 10 | 11 | def convert_image_to_tensor(image): 12 | """convert an image to pytorch tensor 13 | 14 | Parameters: 15 | ---------- 16 | image: numpy array , h * w * c 17 | 18 | Returns: 19 | ------- 20 | image_tensor: pytorch.FloatTensor, c * h * w 21 | """ 22 | return transform(image) 23 | 24 | 25 | def convert_chwTensor_to_hwcNumpy(tensor): 26 | """convert a group images pytorch tensor(count * c * h * w) to numpy array images(count * h * w * c) 27 | Parameters: 28 | ---------- 29 | tensor: numpy array , count * c * h * w 30 | 31 | Returns: 32 | ------- 33 | numpy array images: count * h * w * c 34 | """ 35 | 36 | if isinstance(tensor, torch.FloatTensor): 37 | return np.transpose(tensor.detach().numpy(), (0, 2, 3, 1)) 38 | else: 39 | raise Exception( 40 | "covert b*c*h*w tensor to b*h*w*c numpy error.This tensor must have 4 dimension of float data type.") 41 | -------------------------------------------------------------------------------- 
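tools/image_tools.py defines the tensor preprocessing used by the detector code: ToTensor converts an HxWxC uint8 image in [0, 255] to a CxHxW float tensor in [0, 1], and Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) then maps it to [-1, 1]; convert_chwTensor_to_hwcNumpy undoes the layout change for a whole batch. A small sanity-check sketch, assuming it is run from the repository root so that tools is importable (the random image is only an illustration):

```python
import numpy as np
from tools.image_tools import convert_image_to_tensor, convert_chwTensor_to_hwcNumpy

img = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)   # fake HxWxC image
t = convert_image_to_tensor(img)                                     # CxHxW, values in [-1, 1]
assert t.shape == (3, 48, 48)
assert float(t.min()) >= -1.0 and float(t.max()) <= 1.0

batch = convert_chwTensor_to_hwcNumpy(t.unsqueeze(0).float())        # back to count x H x W x C
assert batch.shape == (1, 48, 48, 3)
```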
/tools/imagedb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch.utils.data as data 4 | from PIL import Image 5 | 6 | 7 | def pil_loader(path): 8 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 9 | with open(path, 'rb') as f: 10 | img = Image.open(f) 11 | return img.convert('RGB') 12 | 13 | 14 | class ImageDB(object): 15 | def __init__(self, image_annotation_file, prefix_path='', mode='train'): 16 | self.prefix_path = prefix_path 17 | self.image_annotation_file = image_annotation_file 18 | self.classes = ['__background__', 'face'] 19 | self.num_classes = 2 20 | self.image_set_index = self.load_image_set_index() 21 | self.num_images = len(self.image_set_index) 22 | self.mode = mode 23 | 24 | def load_image_set_index(self): 25 | ''' Get image index 26 | 27 | Returns: 28 | image_set_index: str, relative path of image 29 | ''' 30 | assert os.path.exists(self.image_annotation_file), 'Path does not exist: {}'.format( 31 | self.image_annotation_file) 32 | with open(self.image_annotation_file, 'r') as f: 33 | image_set_index = [x.strip().split(' ')[0] for x in f.readlines()] 34 | return image_set_index 35 | 36 | def load_imdb(self): 37 | ''' Get and save ground truth image database 38 | 39 | Returns: 40 | gt_imdb: dict, image database with annotations 41 | ''' 42 | 43 | gt_imdb = self.load_annotations() 44 | 45 | return gt_imdb 46 | 47 | def real_image_path(self, index): 48 | ''' Given image's relative index, return full path of image ''' 49 | 50 | index = index.replace("\\", "/") 51 | 52 | if not os.path.exists(index): 53 | image_file = os.path.join(self.prefix_path, index) 54 | else: 55 | image_file = index 56 | if not image_file.endswith('.jpg'): 57 | image_file = image_file + '.jpg' 58 | assert os.path.exists( 59 | image_file), 'Path does not exist: {}'.format(image_file) 60 | 61 | return image_file 62 | 63 | def load_annotations(self, annotation_type=1): 64 | ''' Load annotations 65 | 66 | what's the meaning of annotation_type ? I don't know ! 
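(annotation_type is not referenced anywhere in the method body below, so as far as this listing shows it is an unused, leftover parameter.)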
67 | Returns: 68 | imdb: dict, image database with annotations 69 | ''' 70 | 71 | assert os.path.exists(self.image_annotation_file), 'annotations not found at {}'.format( 72 | self.image_annotation_file) 73 | with open(self.image_annotation_file, 'r') as f: 74 | annotations = f.readlines() 75 | 76 | imdb = [] 77 | for i in range(self.num_images): 78 | annotation = annotations[i].strip().split(' ') 79 | index = annotation[0] 80 | im_path = self.real_image_path(index) 81 | imdb_ = dict() 82 | imdb_['image'] = im_path 83 | 84 | if self.mode == 'test': 85 | pass 86 | else: 87 | label = annotation[1] 88 | imdb_['label'] = int(label) 89 | imdb_['flipped'] = False 90 | imdb_['bbox_target'] = np.zeros((4,)) 91 | imdb_['landmark_target'] = np.zeros((10,)) 92 | if len(annotation[2:]) == 4: 93 | bbox_target = annotation[2:6] 94 | imdb_['bbox_target'] = np.array(bbox_target).astype(float) 95 | if len(annotation[2:]) == 14: 96 | bbox_target = annotation[2:6] 97 | imdb_['bbox_target'] = np.array(bbox_target).astype(float) 98 | landmark = annotation[6:] 99 | imdb_['landmark_target'] = np.array(landmark).astype(float) 100 | imdb.append(imdb_) 101 | return imdb 102 | 103 | def append_flipped_images(self, imdb): 104 | ''' append flipped images to imdb 105 | 106 | Returns: 107 | imdb: dict, image database with flipped image annotations 108 | ''' 109 | print('append flipped images to imdb ', len(imdb)) 110 | for i in range(len(imdb)): 111 | imdb_ = imdb[i] 112 | m_bbox = imdb_['bbox_target'].copy() 113 | m_bbox[0], m_bbox[2] = -m_bbox[2], -m_bbox[0] 114 | 115 | landmark_ = imdb_['landmark_target'].copy() 116 | landmark_ = landmark_.reshape((5, 2)) 117 | landmark_ = np.asarray([(1 - x, y) for (x, y) in landmark_]) 118 | landmark_[[0, 1]] = landmark_[[1, 0]] 119 | landmark_[[3, 4]] = landmark_[[4, 3]] 120 | 121 | item = {'image': imdb_['image'], 122 | 'label': imdb_['label'], 123 | 'bbox_target': m_bbox, 124 | 'landmark_target': landmark_.reshape((10)), 125 | 'flipped': True} 126 | 127 | imdb.append(item) 128 | self.image_set_index *= 2 129 | print('after flipped images appended to imdb ', len(imdb)) 130 | 131 | return imdb 132 | 133 | 134 | class FaceDataset(data.Dataset): 135 | def __init__(self, image_annotation_file, prefix_path='', transform=None, is_train=False): 136 | self.image_annotation_file = image_annotation_file 137 | self.prefix_path = prefix_path 138 | self.is_train = is_train 139 | self.classes = ['__background__', 'face'] 140 | self.num_classes = 2 141 | self.image_set_index = self.load_image_set_index() 142 | self.num_images = len(self.image_set_index) 143 | self.gt_imdb = self.load_annotations() 144 | if self.is_train: 145 | self.gt_imdb = self.append_flipped_images(self.gt_imdb) 146 | self.transform = transform 147 | self.loader = pil_loader 148 | 149 | def load_image_set_index(self): 150 | """Get image index 151 | 152 | Parameters: 153 | ---------- 154 | Returns: 155 | ------- 156 | image_set_index: str 157 | relative path of image 158 | """ 159 | assert os.path.exists(self.image_annotation_file), 'Path does not exist: {}'.format( 160 | self.image_annotation_file) 161 | with open(self.image_annotation_file, 'r') as f: 162 | image_set_index = [x.strip().split(' ')[0] for x in f.readlines()] 163 | return image_set_index 164 | 165 | def real_image_path(self, index): 166 | """Given image index, return full path 167 | 168 | Parameters: 169 | ---------- 170 | index: str 171 | relative path of image 172 | Returns: 173 | ------- 174 | image_file: str 175 | full path of image 176 | """ 177 | 178 | 
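        # Normalise Windows-style separators, try the index as given, then fall back to
        # prefix_path; a '.jpg' extension is appended when the annotation omits it.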
index = index.replace("\\", "/") 179 | 180 | if not os.path.exists(index): 181 | image_file = os.path.join(self.prefix_path, index) 182 | else: 183 | image_file = index 184 | if not image_file.endswith('.jpg'): 185 | image_file = image_file + '.jpg' 186 | assert os.path.exists( 187 | image_file), 'Path does not exist: {}'.format(image_file) 188 | return image_file 189 | 190 | def load_annotations(self): 191 | """Load annotations 192 | 193 | Returns: 194 | ------- 195 | imdb: dict 196 | image database with annotations 197 | """ 198 | 199 | assert os.path.exists(self.image_annotation_file), 'annotations not found at {}'.format( 200 | self.image_annotation_file) 201 | with open(self.image_annotation_file, 'r') as f: 202 | annotations = f.readlines() 203 | 204 | imdb = [] 205 | for i in range(self.num_images): 206 | annotation = annotations[i].strip().split(' ') 207 | index = annotation[0] 208 | im_path = self.real_image_path(index) 209 | imdb_ = dict() 210 | imdb_['image'] = im_path 211 | 212 | if not self.is_train: 213 | # gt_boxes = map(float, annotation[1:]) 214 | # boxes = np.array(bbox, dtype=np.float32).reshape(-1, 4) 215 | # imdb_['gt_boxes'] = boxes 216 | pass 217 | else: 218 | label = annotation[1] 219 | imdb_['label'] = int(label) 220 | imdb_['flipped'] = False 221 | imdb_['bbox_target'] = np.zeros((4,)) 222 | imdb_['landmark_target'] = np.zeros((10,)) 223 | if len(annotation[2:]) == 4: 224 | bbox_target = annotation[2:6] 225 | imdb_['bbox_target'] = np.array(bbox_target).astype(float) 226 | if len(annotation[2:]) == 14: 227 | bbox_target = annotation[2:6] 228 | imdb_['bbox_target'] = np.array(bbox_target).astype(float) 229 | landmark = annotation[6:] 230 | imdb_['landmark_target'] = np.array(landmark).astype(float) 231 | imdb.append(imdb_) 232 | return imdb 233 | 234 | def append_flipped_images(self, imdb): 235 | """append flipped images to imdb 236 | 237 | Parameters: 238 | ---------- 239 | imdb: imdb 240 | image database 241 | Returns: 242 | ------- 243 | imdb: dict 244 | image database with flipped image annotations added 245 | """ 246 | print('append flipped images to imdb', len(imdb)) 247 | for i in range(len(imdb)): 248 | imdb_ = imdb[i] 249 | m_bbox = imdb_['bbox_target'].copy() 250 | m_bbox[0], m_bbox[2] = -m_bbox[2], -m_bbox[0] 251 | 252 | landmark_ = imdb_['landmark_target'].copy() 253 | landmark_ = landmark_.reshape((5, 2)) 254 | landmark_ = np.asarray([(1 - x, y) for (x, y) in landmark_]) 255 | landmark_[[0, 1]] = landmark_[[1, 0]] 256 | landmark_[[3, 4]] = landmark_[[4, 3]] 257 | 258 | item = {'image': imdb_['image'], 259 | 'label': imdb_['label'], 260 | 'bbox_target': m_bbox, 261 | 'landmark_target': landmark_.reshape((10)), 262 | 'flipped': True} 263 | 264 | imdb.append(item) 265 | self.image_set_index *= 2 266 | return imdb 267 | 268 | def __len__(self): 269 | return self.num_images 270 | 271 | def __getitem__(self, idx): 272 | imdb_ = self.gt_imdb[idx] 273 | image = self.loader(imdb_['image']) 274 | labels = {} 275 | labels['label'] = imdb_['label'] 276 | labels['bbox_target'] = imdb_['bbox_target'] 277 | labels['landmark_target'] = imdb_['landmark_target'] 278 | 279 | if self.transform: 280 | image = self.transform(image) 281 | 282 | return image, labels 283 | -------------------------------------------------------------------------------- /tools/logger.py: -------------------------------------------------------------------------------- 1 | # Code is from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 2 | import tensorflow as tf 3 | import numpy 
as np 4 | import scipy.misc 5 | 6 | try: 7 | from io import StringIO # Python 2.7 8 | except ImportError: 9 | from io import BytesIO # Python 3.x 10 | 11 | __all__ = ["Logger"] 12 | 13 | 14 | class Logger: 15 | 16 | def __init__(self, log_dir): 17 | """Create a summary writer logging to log_dir.""" 18 | self.writer = tf.summary.FileWriter(log_dir) 19 | 20 | def scalar_summary(self, tag, value, step): 21 | """Log a scalar variable.""" 22 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 23 | self.writer.add_summary(summary, step) 24 | 25 | def image_summary(self, tag, images, step): 26 | """Log a list of images.""" 27 | 28 | img_summaries = [] 29 | for i, img in enumerate(images): 30 | # Write the image to a string 31 | try: 32 | s = StringIO() 33 | except: 34 | s = BytesIO() 35 | scipy.misc.toimage(img).save(s, format="png") 36 | 37 | # Create an Image object 38 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 39 | height=img.shape[0], 40 | width=img.shape[1]) 41 | # Create a Summary value 42 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 43 | 44 | # Create and write Summary 45 | summary = tf.Summary(value=img_summaries) 46 | self.writer.add_summary(summary, step) 47 | 48 | def histo_summary(self, tag, values, step, bins=1000): 49 | """Log a histogram of the tensor of values.""" 50 | 51 | # Create a histogram using numpy 52 | counts, bin_edges = np.histogram(values, bins=bins) 53 | 54 | # Fill the fields of the histogram proto 55 | hist = tf.HistogramProto() 56 | hist.min = float(np.min(values)) 57 | hist.max = float(np.max(values)) 58 | hist.num = int(np.prod(values.shape)) 59 | hist.sum = float(np.sum(values)) 60 | hist.sum_squares = float(np.sum(values ** 2)) 61 | 62 | # Drop the start of the first bin 63 | bin_edges = bin_edges[1:] 64 | 65 | # Add bin edges and counts 66 | for edge in bin_edges: 67 | hist.bucket_limit.append(edge) 68 | for c in counts: 69 | hist.bucket.append(c) 70 | 71 | # Create and write Summary 72 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 73 | self.writer.add_summary(summary, step) 74 | self.writer.flush() 75 | -------------------------------------------------------------------------------- /tools/test_detect.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch.utils.model_zoo as model_zoo 4 | import cv2 5 | 6 | from models import PNet,ONet,RNet 7 | 8 | import logging 9 | from tools.image_tools import * 10 | import tools.utils as utils 11 | 12 | model_urls = { 13 | 'pnet': 'https://github.com/xiezheng-cs/mtcnn_pytorch/releases/download/mtcnn/pnet-3da9e965.pt', 14 | 'rnet': 'https://github.com/xiezheng-cs/mtcnn_pytorch/releases/download/mtcnn/rnet-ea379816.pt', 15 | 'onet': 'https://github.com/xiezheng-cs/mtcnn_pytorch/releases/download/mtcnn/onet-4b09b161.pt', 16 | } 17 | 18 | logger = logging.getLogger("app") 19 | class MtcnnDetector(object): 20 | ''' P, R, O net for face detection and landmark alignment''' 21 | 22 | def __init__(self, 23 | min_face_size=12, 24 | stride=2, 25 | threshold=[0.6, 0.7, 0.7], 26 | scale_factor=0.709, 27 | use_cuda=True): 28 | self.pnet_detector, self.rnet_detector, self.onet_detector = self.create_mtcnn_net(use_cuda) 29 | self.min_face_size = min_face_size 30 | self.stride = stride 31 | self.thresh = threshold 32 | self.scale_factor = scale_factor 33 | 34 | def create_mtcnn_net(self, use_cuda=True): 35 | self.device = torch.device( 36 | "cuda" if use_cuda and 
torch.cuda.is_available() else "cpu") 37 | 38 | pnet = PNet() 39 | pnet.load_state_dict(model_zoo.load_url(model_urls['pnet'])) 40 | pnet.to(self.device).eval() 41 | 42 | onet = ONet() 43 | onet.load_state_dict(model_zoo.load_url(model_urls['onet'])) 44 | onet.to(self.device).eval() 45 | 46 | rnet = RNet() 47 | rnet.load_state_dict(model_zoo.load_url(model_urls['rnet'])) 48 | rnet.to(self.device).eval() 49 | 50 | return pnet, rnet, onet 51 | 52 | def generate_bounding_box(self, map, reg, scale, threshold): 53 | ''' 54 | generate bbox from feature map 55 | for PNet, there exists no fc layer, only convolution layer ,so feature map n x m x 1/4 56 | Parameters: 57 | map: numpy array , n x m x 1, detect score for each position 58 | reg: numpy array , n x m x 4, bbox 59 | scale: float number, scale of this detection 60 | threshold: float number, detect threshold 61 | Returns: 62 | bbox array 63 | ''' 64 | stride = 2 65 | cellsize = 12 66 | 67 | t_index = np.where(map > threshold) 68 | # find nothing 69 | if t_index[0].size == 0: 70 | return np.array([]) 71 | 72 | dx1, dy1, dx2, dy2 = [reg[0, t_index[0], t_index[1], i] 73 | for i in range(4)] 74 | reg = np.array([dx1, dy1, dx2, dy2]) 75 | 76 | score = map[t_index[0], t_index[1], 0] 77 | boundingbox = np.vstack([np.round((stride * t_index[1]) / scale), 78 | np.round((stride * t_index[0]) / scale), 79 | np.round( 80 | (stride * t_index[1] + cellsize) / scale), 81 | np.round( 82 | (stride * t_index[0] + cellsize) / scale), 83 | score, 84 | reg, 85 | # landmarks 86 | ]) 87 | 88 | return boundingbox.T 89 | 90 | def resize_image(self, img, scale): 91 | """ 92 | resize image and transform dimention to [batchsize, channel, height, width] 93 | Parameters: 94 | ---------- 95 | img: numpy array , height x width x channel,input image, channels in BGR order here 96 | scale: float number, scale factor of resize operation 97 | Returns: 98 | ------- 99 | transformed image tensor , 1 x channel x height x width 100 | """ 101 | height, width, channels = img.shape 102 | new_height = int(height * scale) # resized new height 103 | new_width = int(width * scale) # resized new width 104 | new_dim = (new_width, new_height) 105 | img_resized = cv2.resize( 106 | img, new_dim, interpolation=cv2.INTER_LINEAR) # resized image 107 | return img_resized 108 | 109 | def pad(self, bboxes, w, h): 110 | """ 111 | pad the the boxes 112 | Parameters: 113 | ---------- 114 | bboxes: numpy array, n x 5, input bboxes 115 | w: float number, width of the input image 116 | h: float number, height of the input image 117 | Returns : 118 | ------ 119 | dy, dx : numpy array, n x 1, start point of the bbox in target image 120 | edy, edx : numpy array, n x 1, end point of the bbox in target image 121 | y, x : numpy array, n x 1, start point of the bbox in original image 122 | ey, ex : numpy array, n x 1, end point of the bbox in original image 123 | tmph, tmpw: numpy array, n x 1, height and width of the bbox 124 | """ 125 | 126 | tmpw = (bboxes[:, 2] - bboxes[:, 0]).astype(np.int32) 127 | tmph = (bboxes[:, 3] - bboxes[:, 1]).astype(np.int32) 128 | numbox = bboxes.shape[0] 129 | 130 | dx = np.zeros((numbox,)) 131 | dy = np.zeros((numbox,)) 132 | edx, edy = tmpw.copy(), tmph.copy() 133 | 134 | x, y, ex, ey = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3] 135 | 136 | tmp_index = np.where(ex > w) 137 | edx[tmp_index] = tmpw[tmp_index] + w - ex[tmp_index] 138 | ex[tmp_index] = w 139 | 140 | tmp_index = np.where(ey > h) 141 | edy[tmp_index] = tmph[tmp_index] + h - ey[tmp_index] 142 | 
ey[tmp_index] = h 143 | 144 | tmp_index = np.where(x < 0) 145 | dx[tmp_index] = 0 - x[tmp_index] 146 | x[tmp_index] = 0 147 | 148 | tmp_index = np.where(y < 0) 149 | dy[tmp_index] = 0 - y[tmp_index] 150 | y[tmp_index] = 0 151 | 152 | return_list = [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] 153 | return_list = [item.astype(np.int32) for item in return_list] 154 | 155 | return return_list 156 | 157 | def detect_pnet(self, im): 158 | """Get face candidates through pnet 159 | 160 | Parameters: 161 | ---------- 162 | im: numpy array, input image array 163 | 164 | Returns: 165 | ------- 166 | boxes: numpy array 167 | detected boxes before calibration 168 | boxes_align: numpy array 169 | boxes after calibration 170 | """ 171 | h, w, c = im.shape 172 | net_size = 12 173 | current_scale = float(net_size) / \ 174 | self.min_face_size # find initial scale 175 | im_resized = self.resize_image(im, current_scale) 176 | current_height, current_width, _ = im_resized.shape 177 | 178 | # fcn for pnet 179 | all_boxes = list() 180 | while min(current_height, current_width) > net_size: 181 | image_tensor = convert_image_to_tensor(im_resized) 182 | feed_imgs = image_tensor.unsqueeze(0) 183 | 184 | feed_imgs = feed_imgs.to(self.device) 185 | 186 | cls_map, reg = self.pnet_detector(feed_imgs) 187 | cls_map_np = convert_chwTensor_to_hwcNumpy(cls_map.cpu()) 188 | reg_np = convert_chwTensor_to_hwcNumpy(reg.cpu()) 189 | 190 | boxes = self.generate_bounding_box( 191 | cls_map_np[0, :, :], reg_np, current_scale, self.thresh[0]) 192 | 193 | current_scale *= self.scale_factor 194 | im_resized = self.resize_image(im, current_scale) 195 | current_height, current_width, _ = im_resized.shape 196 | 197 | if boxes.size == 0: 198 | continue 199 | keep = utils.nms(boxes[:, :5], 0.5, 'Union') 200 | boxes = boxes[keep] 201 | all_boxes.append(boxes) 202 | 203 | if len(all_boxes) == 0: 204 | return None, None 205 | 206 | all_boxes = np.vstack(all_boxes) 207 | 208 | # merge the detection from first stage 209 | keep = utils.nms(all_boxes[:, 0:5], 0.7, 'Union') 210 | all_boxes = all_boxes[keep] 211 | 212 | bw = all_boxes[:, 2] - all_boxes[:, 0] 213 | bh = all_boxes[:, 3] - all_boxes[:, 1] 214 | 215 | boxes = np.vstack([all_boxes[:, 0], 216 | all_boxes[:, 1], 217 | all_boxes[:, 2], 218 | all_boxes[:, 3], 219 | all_boxes[:, 4] 220 | ]) 221 | 222 | boxes = boxes.T 223 | 224 | align_topx = all_boxes[:, 0] + all_boxes[:, 5] * bw 225 | align_topy = all_boxes[:, 1] + all_boxes[:, 6] * bh 226 | align_bottomx = all_boxes[:, 2] + all_boxes[:, 7] * bw 227 | align_bottomy = all_boxes[:, 3] + all_boxes[:, 8] * bh 228 | 229 | # refine the boxes 230 | boxes_align = np.vstack([align_topx, 231 | align_topy, 232 | align_bottomx, 233 | align_bottomy, 234 | all_boxes[:, 4] 235 | ]) 236 | boxes_align = boxes_align.T 237 | 238 | return boxes, boxes_align 239 | 240 | def detect_rnet(self, im, dets): 241 | """Get face candidates using rnet 242 | 243 | Parameters: 244 | ---------- 245 | im: numpy array 246 | input image array 247 | dets: numpy array 248 | detection results of pnet 249 | 250 | Returns: 251 | ------- 252 | boxes: numpy array 253 | detected boxes before calibration 254 | boxes_align: numpy array 255 | boxes after calibration 256 | """ 257 | h, w, c = im.shape 258 | if dets is None: 259 | return None, None 260 | 261 | dets = utils.convert_to_square(dets) 262 | dets[:, 0:4] = np.round(dets[:, 0:4]) 263 | 264 | [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] = self.pad(dets, w, h) 265 | num_boxes = dets.shape[0] 266 | 267 | cropped_ims_tensors = [] 
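            # Crop each padded candidate out of the original image, resize it to the
            # 24 x 24 input RNet expects, and convert it to a normalized tensor;
            # candidates with a non-positive width or height are skipped.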
268 | for i in range(num_boxes): 269 | try: 270 | if tmph[i] > 0 and tmpw[i] > 0: 271 | tmp = np.zeros((tmph[i], tmpw[i], 3), dtype=np.uint8) 272 | tmp[dy[i]:edy[i], dx[i]:edx[i], :] = im[y[i]:ey[i], x[i]:ex[i], :] 273 | crop_im = cv2.resize(tmp, (24, 24)) 274 | crop_im_tensor = convert_image_to_tensor(crop_im) 275 | # cropped_ims_tensors[i, :, :, :] = crop_im_tensor 276 | cropped_ims_tensors.append(crop_im_tensor) 277 | except ValueError as e: 278 | print('dy: {}, edy: {}, dx: {}, edx: {}'.format(dy[i], edy[i], dx[i], edx[i])) 279 | print('y: {}, ey: {}, x: {}, ex: {}'.format(y[i], ey[i], x[i], ex[i])) 280 | print(e) 281 | 282 | feed_imgs = torch.stack(cropped_ims_tensors) 283 | 284 | feed_imgs = feed_imgs.to(self.device) 285 | 286 | cls_map, reg = self.rnet_detector(feed_imgs) 287 | cls_map = cls_map.cpu().data.numpy() 288 | reg = reg.cpu().data.numpy() 289 | 290 | keep_inds = np.where(cls_map > self.thresh[1])[0] 291 | 292 | if len(keep_inds) > 0: 293 | boxes = dets[keep_inds] 294 | cls = cls_map[keep_inds] 295 | reg = reg[keep_inds] 296 | else: 297 | return None, None 298 | 299 | keep = utils.nms(boxes, 0.7) 300 | if len(keep) == 0: 301 | return None, None 302 | 303 | keep_cls = cls[keep] 304 | keep_boxes = boxes[keep] 305 | keep_reg = reg[keep] 306 | bw = keep_boxes[:, 2] - keep_boxes[:, 0] 307 | bh = keep_boxes[:, 3] - keep_boxes[:, 1] 308 | boxes = np.vstack([keep_boxes[:, 0], 309 | keep_boxes[:, 1], 310 | keep_boxes[:, 2], 311 | keep_boxes[:, 3], 312 | keep_cls[:, 0] 313 | ]) 314 | align_topx = keep_boxes[:, 0] + keep_reg[:, 0] * bw 315 | align_topy = keep_boxes[:, 1] + keep_reg[:, 1] * bh 316 | align_bottomx = keep_boxes[:, 2] + keep_reg[:, 2] * bw 317 | align_bottomy = keep_boxes[:, 3] + keep_reg[:, 3] * bh 318 | 319 | boxes_align = np.vstack([align_topx, 320 | align_topy, 321 | align_bottomx, 322 | align_bottomy, 323 | keep_cls[:, 0] 324 | ]) 325 | boxes = boxes.T 326 | boxes_align = boxes_align.T 327 | 328 | return boxes, boxes_align 329 | 330 | def detect_onet(self, im, dets): 331 | """Get face candidates using onet 332 | 333 | Parameters: 334 | ---------- 335 | im: numpy array 336 | input image array 337 | dets: numpy array 338 | detection results of rnet 339 | 340 | Returns: 341 | ------- 342 | boxes_align: numpy array 343 | boxes after calibration 344 | landmarks_align: numpy array 345 | landmarks after calibration 346 | 347 | """ 348 | h, w, c = im.shape 349 | if dets is None: 350 | return None, None 351 | 352 | dets = utils.convert_to_square(dets) 353 | dets[:, 0:4] = np.round(dets[:, 0:4]) 354 | 355 | [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] = self.pad(dets, w, h) 356 | num_boxes = dets.shape[0] 357 | 358 | cropped_ims_tensors = [] 359 | for i in range(num_boxes): 360 | try: 361 | if tmph[i] > 0 and tmpw[i] > 0: 362 | tmp = np.zeros((tmph[i], tmpw[i], 3), dtype=np.uint8) 363 | tmp[dy[i]:edy[i], dx[i]:edx[i], :] = im[y[i]:ey[i], x[i]:ex[i], :] 364 | crop_im = cv2.resize(tmp, (48, 48)) 365 | crop_im_tensor = convert_image_to_tensor(crop_im) 366 | cropped_ims_tensors.append(crop_im_tensor) 367 | except ValueError as e: 368 | print(e) 369 | 370 | feed_imgs = torch.stack(cropped_ims_tensors) 371 | 372 | feed_imgs = feed_imgs.to(self.device) 373 | 374 | cls_map, reg, landmark = self.onet_detector(feed_imgs) 375 | 376 | cls_map = cls_map.cpu().data.numpy() 377 | reg = reg.cpu().data.numpy() 378 | landmark = landmark.cpu().data.numpy() 379 | 380 | keep_inds = np.where(cls_map > self.thresh[2])[0] 381 | 382 | if len(keep_inds) > 0: 383 | boxes = dets[keep_inds] 384 | cls = 
cls_map[keep_inds] 385 | reg = reg[keep_inds] 386 | landmark = landmark[keep_inds] 387 | else: 388 | return None, None 389 | 390 | keep = utils.nms(boxes, 0.7, mode="Minimum") 391 | 392 | if len(keep) == 0: 393 | return None, None 394 | 395 | keep_cls = cls[keep] 396 | keep_boxes = boxes[keep] 397 | keep_reg = reg[keep] 398 | keep_landmark = landmark[keep] 399 | 400 | bw = keep_boxes[:, 2] - keep_boxes[:, 0] 401 | bh = keep_boxes[:, 3] - keep_boxes[:, 1] 402 | 403 | align_topx = keep_boxes[:, 0] + keep_reg[:, 0] * bw 404 | align_topy = keep_boxes[:, 1] + keep_reg[:, 1] * bh 405 | align_bottomx = keep_boxes[:, 2] + keep_reg[:, 2] * bw 406 | align_bottomy = keep_boxes[:, 3] + keep_reg[:, 3] * bh 407 | 408 | align_landmark_topx = keep_boxes[:, 0] 409 | align_landmark_topy = keep_boxes[:, 1] 410 | 411 | boxes_align = np.vstack([align_topx, 412 | align_topy, 413 | align_bottomx, 414 | align_bottomy, 415 | keep_cls[:, 0] 416 | ]) 417 | 418 | boxes_align = boxes_align.T 419 | 420 | landmark = np.vstack([ 421 | align_landmark_topx + keep_landmark[:, 0] * bw, 422 | align_landmark_topy + keep_landmark[:, 1] * bh, 423 | align_landmark_topx + keep_landmark[:, 2] * bw, 424 | align_landmark_topy + keep_landmark[:, 3] * bh, 425 | align_landmark_topx + keep_landmark[:, 4] * bw, 426 | align_landmark_topy + keep_landmark[:, 5] * bh, 427 | align_landmark_topx + keep_landmark[:, 6] * bw, 428 | align_landmark_topy + keep_landmark[:, 7] * bh, 429 | align_landmark_topx + keep_landmark[:, 8] * bw, 430 | align_landmark_topy + keep_landmark[:, 9] * bh, 431 | ]) 432 | 433 | landmark_align = landmark.T 434 | 435 | return boxes_align, landmark_align 436 | 437 | def detect_face(self, img): 438 | ''' Detect face over image ''' 439 | boxes_align = np.array([]) 440 | landmark_align = np.array([]) 441 | 442 | t = time.time() 443 | 444 | # pnet 445 | if self.pnet_detector: 446 | boxes, boxes_align = self.detect_pnet(img) 447 | if boxes_align is None: 448 | return np.array([]), np.array([]) 449 | 450 | t1 = time.time() - t 451 | t = time.time() 452 | 453 | # rnet 454 | if self.rnet_detector: 455 | boxes, boxes_align = self.detect_rnet(img, boxes_align) 456 | if boxes_align is None: 457 | return np.array([]), np.array([]) 458 | 459 | t2 = time.time() - t 460 | t = time.time() 461 | 462 | # onet 463 | if self.onet_detector: 464 | boxes_align, landmark_align = self.detect_onet(img, boxes_align) 465 | if boxes_align is None: 466 | return np.array([]), np.array([]) 467 | 468 | t3 = time.time() - t 469 | t = time.time() 470 | 471 | logger.info(f"Total time cost: {t1+t2+t3:.4f}s, " 472 | f"PNet time cost: {t1:.4f}s, " 473 | f"RNet time cost: {t2:.4f}s, " 474 | f"ONet time cost: {t2:.4f}s. 
") 475 | 476 | 477 | return boxes_align, landmark_align 478 | -------------------------------------------------------------------------------- /tools/train_detect.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 18-7-31 上午11:48 4 | # @Author : xiezheng 5 | # @Site : 6 | # @File : train_detect.py 7 | 8 | import time 9 | 10 | import numpy as np 11 | import torch 12 | import torchvision.transforms as transforms 13 | 14 | import cv2 15 | import os 16 | from models.onet import ONet 17 | from models.pnet import PNet 18 | from models.rnet import RNet 19 | from checkpoint import CheckPoint 20 | from tools.image_tools import * 21 | import tools.utils as utils 22 | 23 | 24 | class MtcnnDetector(object): 25 | ''' P, R, O net for face detection and landmark alignment''' 26 | 27 | def __init__(self, 28 | p_model_path=None, 29 | r_model_path=None, 30 | o_model_path=None, 31 | min_face_size=12, 32 | stride=2, 33 | threshold=[0.6, 0.7, 0.7], 34 | scale_factor=0.709, 35 | use_cuda=True): 36 | self.pnet_detector, self.rnet_detector, self.onet_detector = self.create_mtcnn_net( 37 | p_model_path, r_model_path, o_model_path, use_cuda) 38 | self.min_face_size = min_face_size 39 | self.stride = stride 40 | self.thresh = threshold 41 | self.scale_factor = scale_factor 42 | 43 | def create_mtcnn_net(self, p_model_path=None, r_model_path=None, o_model_path=None, use_cuda=True): 44 | dirname, _ = os.path.split(p_model_path) 45 | checkpoint = CheckPoint(dirname) 46 | 47 | pnet, rnet, onet = None, None, None 48 | self.device = torch.device( 49 | "cuda:0" if use_cuda and torch.cuda.is_available() else "cpu") 50 | 51 | if p_model_path is not None: 52 | pnet = PNet() 53 | pnet_model_state = checkpoint.load_model(p_model_path) 54 | pnet = checkpoint.load_state(pnet, pnet_model_state) 55 | if (use_cuda): 56 | pnet.to(self.device) 57 | pnet.eval() 58 | 59 | if r_model_path is not None: 60 | rnet = RNet() 61 | rnet_model_state = checkpoint.load_model(r_model_path) 62 | rnet = checkpoint.load_state(rnet, rnet_model_state) 63 | if (use_cuda): 64 | rnet.to(self.device) 65 | rnet.eval() 66 | 67 | if o_model_path is not None: 68 | onet = ONet() 69 | onet_model_state = checkpoint.load_model(o_model_path) 70 | onet = checkpoint.load_state(onet, onet_model_state) 71 | if (use_cuda): 72 | onet.to(self.device) 73 | onet.eval() 74 | 75 | return pnet, rnet, onet 76 | 77 | def generate_bounding_box(self, map, reg, scale, threshold): 78 | ''' 79 | generate bbox from feature map 80 | for PNet, there exists no fc layer, only convolution layer ,so feature map n x m x 1/4 81 | Parameters: 82 | map: numpy array , n x m x 1, detect score for each position 83 | reg: numpy array , n x m x 4, bbox 84 | scale: float number, scale of this detection 85 | threshold: float number, detect threshold 86 | Returns: 87 | bbox array 88 | ''' 89 | stride = 2 90 | cellsize = 12 91 | 92 | t_index = np.where(map > threshold) 93 | # find nothing 94 | if t_index[0].size == 0: 95 | return np.array([]) 96 | 97 | dx1, dy1, dx2, dy2 = [reg[0, t_index[0], t_index[1], i] 98 | for i in range(4)] 99 | reg = np.array([dx1, dy1, dx2, dy2]) 100 | 101 | score = map[t_index[0], t_index[1], 0] 102 | boundingbox = np.vstack([np.round((stride * t_index[1]) / scale), 103 | np.round((stride * t_index[0]) / scale), 104 | np.round( 105 | (stride * t_index[1] + cellsize) / scale), 106 | np.round( 107 | (stride * t_index[0] + cellsize) / scale), 108 | score, 109 | reg, 110 | # 
landmarks 111 | ]) 112 | 113 | return boundingbox.T 114 | 115 | def resize_image(self, img, scale): 116 | """ 117 | resize image and transform dimention to [batchsize, channel, height, width] 118 | Parameters: 119 | ---------- 120 | img: numpy array , height x width x channel,input image, channels in BGR order here 121 | scale: float number, scale factor of resize operation 122 | Returns: 123 | ------- 124 | transformed image tensor , 1 x channel x height x width 125 | """ 126 | height, width, channels = img.shape 127 | new_height = int(height * scale) # resized new height 128 | new_width = int(width * scale) # resized new width 129 | new_dim = (new_width, new_height) 130 | img_resized = cv2.resize( 131 | img, new_dim, interpolation=cv2.INTER_LINEAR) # resized image 132 | return img_resized 133 | 134 | def pad(self, bboxes, w, h): 135 | """ 136 | pad the the boxes 137 | Parameters: 138 | ---------- 139 | bboxes: numpy array, n x 5, input bboxes 140 | w: float number, width of the input image 141 | h: float number, height of the input image 142 | Returns : 143 | ------ 144 | dy, dx : numpy array, n x 1, start point of the bbox in target image 145 | edy, edx : numpy array, n x 1, end point of the bbox in target image 146 | y, x : numpy array, n x 1, start point of the bbox in original image 147 | ey, ex : numpy array, n x 1, end point of the bbox in original image 148 | tmph, tmpw: numpy array, n x 1, height and width of the bbox 149 | """ 150 | 151 | tmpw = (bboxes[:, 2] - bboxes[:, 0]).astype(np.int32) 152 | tmph = (bboxes[:, 3] - bboxes[:, 1]).astype(np.int32) 153 | numbox = bboxes.shape[0] 154 | 155 | dx = np.zeros((numbox,)) 156 | dy = np.zeros((numbox,)) 157 | edx, edy = tmpw.copy(), tmph.copy() 158 | 159 | x, y, ex, ey = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3] 160 | 161 | tmp_index = np.where(ex > w) 162 | edx[tmp_index] = tmpw[tmp_index] + w - ex[tmp_index] 163 | ex[tmp_index] = w 164 | 165 | tmp_index = np.where(ey > h) 166 | edy[tmp_index] = tmph[tmp_index] + h - ey[tmp_index] 167 | ey[tmp_index] = h 168 | 169 | tmp_index = np.where(x < 0) 170 | dx[tmp_index] = 0 - x[tmp_index] 171 | x[tmp_index] = 0 172 | 173 | tmp_index = np.where(y < 0) 174 | dy[tmp_index] = 0 - y[tmp_index] 175 | y[tmp_index] = 0 176 | 177 | return_list = [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] 178 | return_list = [item.astype(np.int32) for item in return_list] 179 | 180 | return return_list 181 | 182 | def detect_pnet(self, im): 183 | """Get face candidates through pnet 184 | 185 | Parameters: 186 | ---------- 187 | im: numpy array, input image array 188 | 189 | Returns: 190 | ------- 191 | boxes: numpy array 192 | detected boxes before calibration 193 | boxes_align: numpy array 194 | boxes after calibration 195 | """ 196 | h, w, c = im.shape 197 | net_size = 12 198 | current_scale = float(net_size) / \ 199 | self.min_face_size # find initial scale 200 | im_resized = self.resize_image(im, current_scale) 201 | current_height, current_width, _ = im_resized.shape 202 | 203 | # fcn for pnet 204 | all_boxes = list() 205 | while min(current_height, current_width) > net_size: 206 | image_tensor = convert_image_to_tensor(im_resized) 207 | feed_imgs = image_tensor.unsqueeze(0) 208 | 209 | feed_imgs = feed_imgs.to(self.device) 210 | 211 | cls_map, reg = self.pnet_detector(feed_imgs) 212 | cls_map_np = convert_chwTensor_to_hwcNumpy(cls_map.cpu()) 213 | reg_np = convert_chwTensor_to_hwcNumpy(reg.cpu()) 214 | 215 | boxes = self.generate_bounding_box( 216 | cls_map_np[0, :, :], reg_np, current_scale, 
self.thresh[0]) 217 | 218 | current_scale *= self.scale_factor 219 | im_resized = self.resize_image(im, current_scale) 220 | current_height, current_width, _ = im_resized.shape 221 | 222 | if boxes.size == 0: 223 | continue 224 | keep = utils.nms(boxes[:, :5], 0.5, 'Union') 225 | boxes = boxes[keep] 226 | all_boxes.append(boxes) 227 | 228 | if len(all_boxes) == 0: 229 | return None, None 230 | 231 | all_boxes = np.vstack(all_boxes) 232 | 233 | # merge the detection from first stage 234 | keep = utils.nms(all_boxes[:, 0:5], 0.7, 'Union') 235 | all_boxes = all_boxes[keep] 236 | 237 | bw = all_boxes[:, 2] - all_boxes[:, 0] 238 | bh = all_boxes[:, 3] - all_boxes[:, 1] 239 | 240 | boxes = np.vstack([all_boxes[:, 0], 241 | all_boxes[:, 1], 242 | all_boxes[:, 2], 243 | all_boxes[:, 3], 244 | all_boxes[:, 4] 245 | ]) 246 | 247 | boxes = boxes.T 248 | 249 | align_topx = all_boxes[:, 0] + all_boxes[:, 5] * bw 250 | align_topy = all_boxes[:, 1] + all_boxes[:, 6] * bh 251 | align_bottomx = all_boxes[:, 2] + all_boxes[:, 7] * bw 252 | align_bottomy = all_boxes[:, 3] + all_boxes[:, 8] * bh 253 | 254 | # refine the boxes 255 | boxes_align = np.vstack([align_topx, 256 | align_topy, 257 | align_bottomx, 258 | align_bottomy, 259 | all_boxes[:, 4] 260 | ]) 261 | boxes_align = boxes_align.T 262 | 263 | return boxes, boxes_align 264 | 265 | def detect_rnet(self, im, dets): 266 | """Get face candidates using rnet 267 | 268 | Parameters: 269 | ---------- 270 | im: numpy array 271 | input image array 272 | dets: numpy array 273 | detection results of pnet 274 | 275 | Returns: 276 | ------- 277 | boxes: numpy array 278 | detected boxes before calibration 279 | boxes_align: numpy array 280 | boxes after calibration 281 | """ 282 | h, w, c = im.shape 283 | if dets is None: 284 | return None, None 285 | 286 | dets = utils.convert_to_square(dets) 287 | dets[:, 0:4] = np.round(dets[:, 0:4]) 288 | 289 | [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] = self.pad(dets, w, h) 290 | num_boxes = dets.shape[0] 291 | 292 | cropped_ims_tensors = [] 293 | for i in range(num_boxes): 294 | try: 295 | if tmph[i] > 0 and tmpw[i] > 0: 296 | tmp = np.zeros((tmph[i], tmpw[i], 3), dtype=np.uint8) 297 | tmp[dy[i]:edy[i], dx[i]:edx[i], :] = im[y[i]:ey[i], x[i]:ex[i], :] 298 | crop_im = cv2.resize(tmp, (24, 24)) 299 | crop_im_tensor = convert_image_to_tensor(crop_im) 300 | # cropped_ims_tensors[i, :, :, :] = crop_im_tensor 301 | cropped_ims_tensors.append(crop_im_tensor) 302 | except ValueError as e: 303 | print('dy: {}, edy: {}, dx: {}, edx: {}'.format(dy[i], edy[i], dx[i], edx[i])) 304 | print('y: {}, ey: {}, x: {}, ex: {}'.format(y[i], ey[i], x[i], ex[i])) 305 | print(e) 306 | 307 | feed_imgs = torch.stack(cropped_ims_tensors) 308 | 309 | feed_imgs = feed_imgs.to(self.device) 310 | 311 | cls_map, reg = self.rnet_detector(feed_imgs) 312 | cls_map = cls_map.cpu().data.numpy() 313 | reg = reg.cpu().data.numpy() 314 | 315 | keep_inds = np.where(cls_map > self.thresh[1])[0] 316 | 317 | if len(keep_inds) > 0: 318 | boxes = dets[keep_inds] 319 | cls = cls_map[keep_inds] 320 | reg = reg[keep_inds] 321 | else: 322 | return None, None 323 | 324 | keep = utils.nms(boxes, 0.7) 325 | if len(keep) == 0: 326 | return None, None 327 | 328 | keep_cls = cls[keep] 329 | keep_boxes = boxes[keep] 330 | keep_reg = reg[keep] 331 | bw = keep_boxes[:, 2] - keep_boxes[:, 0] 332 | bh = keep_boxes[:, 3] - keep_boxes[:, 1] 333 | boxes = np.vstack([keep_boxes[:, 0], 334 | keep_boxes[:, 1], 335 | keep_boxes[:, 2], 336 | keep_boxes[:, 3], 337 | keep_cls[:, 0] 338 | ]) 
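        # Calibrate the kept boxes: each regression output is an offset relative to the
        # box width/height, so the corners are shifted by reg * bw and reg * bh.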
339 | align_topx = keep_boxes[:, 0] + keep_reg[:, 0] * bw 340 | align_topy = keep_boxes[:, 1] + keep_reg[:, 1] * bh 341 | align_bottomx = keep_boxes[:, 2] + keep_reg[:, 2] * bw 342 | align_bottomy = keep_boxes[:, 3] + keep_reg[:, 3] * bh 343 | 344 | boxes_align = np.vstack([align_topx, 345 | align_topy, 346 | align_bottomx, 347 | align_bottomy, 348 | keep_cls[:, 0] 349 | ]) 350 | boxes = boxes.T 351 | boxes_align = boxes_align.T 352 | 353 | return boxes, boxes_align 354 | 355 | def detect_onet(self, im, dets): 356 | """Get face candidates using onet 357 | 358 | Parameters: 359 | ---------- 360 | im: numpy array 361 | input image array 362 | dets: numpy array 363 | detection results of rnet 364 | 365 | Returns: 366 | ------- 367 | boxes_align: numpy array 368 | boxes after calibration 369 | landmarks_align: numpy array 370 | landmarks after calibration 371 | 372 | """ 373 | h, w, c = im.shape 374 | if dets is None: 375 | return None, None 376 | 377 | dets = utils.convert_to_square(dets) 378 | dets[:, 0:4] = np.round(dets[:, 0:4]) 379 | 380 | [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] = self.pad(dets, w, h) 381 | num_boxes = dets.shape[0] 382 | 383 | cropped_ims_tensors = [] 384 | for i in range(num_boxes): 385 | try: 386 | if tmph[i] > 0 and tmpw[i] > 0: 387 | tmp = np.zeros((tmph[i], tmpw[i], 3), dtype=np.uint8) 388 | tmp[dy[i]:edy[i], dx[i]:edx[i], :] = im[y[i]:ey[i], x[i]:ex[i], :] 389 | crop_im = cv2.resize(tmp, (48, 48)) 390 | crop_im_tensor = convert_image_to_tensor(crop_im) 391 | cropped_ims_tensors.append(crop_im_tensor) 392 | except ValueError as e: 393 | print(e) 394 | 395 | feed_imgs = torch.stack(cropped_ims_tensors) 396 | 397 | feed_imgs = feed_imgs.to(self.device) 398 | 399 | cls_map, reg, landmark = self.onet_detector(feed_imgs) 400 | 401 | cls_map = cls_map.cpu().data.numpy() 402 | reg = reg.cpu().data.numpy() 403 | landmark = landmark.cpu().data.numpy() 404 | 405 | keep_inds = np.where(cls_map > self.thresh[2])[0] 406 | 407 | if len(keep_inds) > 0: 408 | boxes = dets[keep_inds] 409 | cls = cls_map[keep_inds] 410 | reg = reg[keep_inds] 411 | landmark = landmark[keep_inds] 412 | else: 413 | return None, None 414 | 415 | keep = utils.nms(boxes, 0.7, mode="Minimum") 416 | 417 | if len(keep) == 0: 418 | return None, None 419 | 420 | keep_cls = cls[keep] 421 | keep_boxes = boxes[keep] 422 | keep_reg = reg[keep] 423 | keep_landmark = landmark[keep] 424 | 425 | bw = keep_boxes[:, 2] - keep_boxes[:, 0] 426 | bh = keep_boxes[:, 3] - keep_boxes[:, 1] 427 | 428 | align_topx = keep_boxes[:, 0] + keep_reg[:, 0] * bw 429 | align_topy = keep_boxes[:, 1] + keep_reg[:, 1] * bh 430 | align_bottomx = keep_boxes[:, 2] + keep_reg[:, 2] * bw 431 | align_bottomy = keep_boxes[:, 3] + keep_reg[:, 3] * bh 432 | 433 | align_landmark_topx = keep_boxes[:, 0] 434 | align_landmark_topy = keep_boxes[:, 1] 435 | 436 | boxes_align = np.vstack([align_topx, 437 | align_topy, 438 | align_bottomx, 439 | align_bottomy, 440 | keep_cls[:, 0] 441 | ]) 442 | 443 | boxes_align = boxes_align.T 444 | 445 | landmark = np.vstack([ 446 | align_landmark_topx + keep_landmark[:, 0] * bw, 447 | align_landmark_topy + keep_landmark[:, 1] * bh, 448 | align_landmark_topx + keep_landmark[:, 2] * bw, 449 | align_landmark_topy + keep_landmark[:, 3] * bh, 450 | align_landmark_topx + keep_landmark[:, 4] * bw, 451 | align_landmark_topy + keep_landmark[:, 5] * bh, 452 | align_landmark_topx + keep_landmark[:, 6] * bw, 453 | align_landmark_topy + keep_landmark[:, 7] * bh, 454 | align_landmark_topx + keep_landmark[:, 8] * bw, 455 | 
align_landmark_topy + keep_landmark[:, 9] * bh, 456 | ]) 457 | 458 | landmark_align = landmark.T 459 | 460 | return boxes_align, landmark_align 461 | 462 | def detect_face(self, img): 463 | ''' Detect face over image ''' 464 | boxes_align = np.array([]) 465 | landmark_align = np.array([]) 466 | 467 | t = time.time() 468 | 469 | # pnet 470 | if self.pnet_detector: 471 | boxes, boxes_align = self.detect_pnet(img) 472 | if boxes_align is None: 473 | return np.array([]), np.array([]) 474 | 475 | t1 = time.time() - t 476 | t = time.time() 477 | 478 | # rnet 479 | if self.rnet_detector: 480 | boxes, boxes_align = self.detect_rnet(img, boxes_align) 481 | if boxes_align is None: 482 | return np.array([]), np.array([]) 483 | 484 | t2 = time.time() - t 485 | t = time.time() 486 | 487 | # onet 488 | if self.onet_detector: 489 | boxes_align, landmark_align = self.detect_onet(img, boxes_align) 490 | if boxes_align is None: 491 | return np.array([]), np.array([]) 492 | 493 | t3 = time.time() - t 494 | t = time.time() 495 | print( 496 | "time cost " + '{:.3f}'.format(t1 + t2 + t3) + ' pnet {:.3f} rnet {:.3f} onet {:.3f}'.format(t1, t2, 497 | t3)) 498 | 499 | return boxes_align, landmark_align 500 | -------------------------------------------------------------------------------- /tools/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def IoU(box, boxes): 5 | """Compute IoU between detect box and gt boxes 6 | 7 | Parameters: 8 | ---------- 9 | box: numpy array , shape (5, ): x1, y1, x2, y2, score 10 | input box 11 | boxes: numpy array, shape (n, 4): x1, y1, x2, y2 12 | input ground truth boxes 13 | 14 | Returns: 15 | ------- 16 | ovr: numpy.array, shape (n, ) 17 | IoU 18 | """ 19 | box_area = (box[2] - box[0]) * (box[3] - box[1]) 20 | area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 21 | 22 | xx1 = np.maximum(box[0], boxes[:, 0]) 23 | yy1 = np.maximum(box[1], boxes[:, 1]) 24 | xx2 = np.minimum(box[2], boxes[:, 2]) 25 | yy2 = np.minimum(box[3], boxes[:, 3]) 26 | 27 | # compute the width and height of the inter box 28 | w = np.maximum(0, xx2 - xx1) 29 | h = np.maximum(0, yy2 - yy1) 30 | 31 | inter = w * h 32 | ovr = np.true_divide(inter, (box_area + area - inter)) 33 | 34 | 35 | return ovr 36 | 37 | 38 | def convert_to_square(bbox): 39 | ''' Convert bbox to a square which it can include the bbox 40 | Parameters: 41 | bbox: numpy array, shape n x 5 42 | 43 | returns: 44 | square box 45 | ''' 46 | 47 | square_bbox = bbox.copy() 48 | h = bbox[:, 3] - bbox[:, 1] 49 | w = bbox[:, 2] - bbox[:, 0] 50 | max_side = np.maximum(h, w) 51 | square_bbox[:, 0] = bbox[:, 0] + w*0.5 - max_side*0.5 52 | square_bbox[:, 1] = bbox[:, 1] + h*0.5 - max_side*0.5 53 | square_bbox[:, 2] = square_bbox[:, 0] + max_side 54 | square_bbox[:, 3] = square_bbox[:, 1] + max_side 55 | 56 | return square_bbox 57 | 58 | 59 | def nms(dets, thresh, mode='Union'): 60 | ''' greedily select bboxes with high confidence,if an box overlap with the highest score box > thres, rule it out 61 | 62 | params: 63 | dets: [[x1, y1, x2, y2, score]] 64 | thresh: retain overlap <= thresh 65 | return: 66 | indexes to keep 67 | ''' 68 | x1 = dets[:, 0] 69 | y1 = dets[:, 1] 70 | x2 = dets[:, 2] 71 | y2 = dets[:, 3] 72 | scores = dets[:, 4] 73 | 74 | areas = (x2 - x1) * (y2 - y1) 75 | order = scores.argsort()[::-1] # the index of scores by desc 76 | 77 | keep = [] 78 | while order.size > 0: 79 | i = order[0] 80 | keep.append(i) 81 | xx1 = np.maximum(x1[i], x1[order[1:]]) 82 | 
yy1 = np.maximum(y1[i], y1[order[1:]]) 83 | xx2 = np.minimum(x2[i], x2[order[1:]]) 84 | yy2 = np.minimum(y2[i], y2[order[1:]]) 85 | 86 | w = np.maximum(0.0, xx2 - xx1) 87 | h = np.maximum(0.0, yy2 - yy1) 88 | inter = w * h 89 | inter = w * h 90 | if mode == "Union": 91 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 92 | elif mode == "Minimum": 93 | ovr = inter / np.minimum(areas[i], areas[order[1:]]) 94 | 95 | inds = np.where(ovr <= thresh)[0] 96 | order = order[inds + 1] 97 | 98 | return keep 99 | 100 | 101 | class AverageMeter(object): 102 | """Computes and stores the average and current value""" 103 | 104 | def __init__(self): 105 | self.reset() 106 | 107 | def reset(self): 108 | """ 109 | reset all parameters 110 | """ 111 | self.val = 0 112 | self.avg = 0 113 | self.sum = 0 114 | self.count = 0 115 | 116 | def update(self, val, n=1): 117 | """ 118 | update parameters 119 | """ 120 | self.val = val 121 | self.sum += val * n 122 | self.count += n 123 | self.avg = self.sum / self.count 124 | -------------------------------------------------------------------------------- /tools/vision.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | 4 | def vis_face(im_array, dets, landmarks=None): 5 | """Visualize detection results of an image 6 | 7 | Parameters: 8 | ---------- 9 | im_array: numpy.ndarray, shape(1, c, h, w) 10 | test image in rgb 11 | dets: numpy.ndarray([[x1 y1 x2 y2 score landmarks]]) 12 | detection results before calibration 13 | landmarks: numpy.ndarray([landmarks for five facial landmarks]) 14 | 15 | Returns: 16 | ------- 17 | """ 18 | figure = plt.figure() 19 | plt.imshow(im_array) 20 | figure.suptitle('Face Detector', fontsize=12, color='r') 21 | 22 | for i in range(dets.shape[0]): 23 | bbox = dets[i, 0:4] 24 | rect = plt.Rectangle((bbox[0], bbox[1]), 25 | bbox[2] - bbox[0], 26 | bbox[3] - bbox[1], fill=False, 27 | edgecolor='yellow', linewidth=0.9) 28 | plt.gca().add_patch(rect) 29 | 30 | if landmarks is not None: 31 | for i in range(landmarks.shape[0]): 32 | landmarks_one = landmarks[i, :] 33 | landmarks_one = landmarks_one.reshape((5, 2)) 34 | for j in range(5): 35 | plt.scatter(landmarks_one[j, 0], landmarks_one[j, 1], c='red', linewidths=1, marker='x', s=5) 36 | 37 | plt.show() 38 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezheng-cs/mtcnn_pytorch/3a89881b633be68b0226971b8817573ceb8a3b14/training/__init__.py -------------------------------------------------------------------------------- /training/onet/config.py: -------------------------------------------------------------------------------- 1 | class Config(object): 2 | def __init__(self): 3 | super(Config, self).__init__() 4 | # ------------ General options ---------------------------------------- 5 | self.save_path = "./results/onet/" 6 | self.dataPath = "/home/dataset/WIDER/WIDER_train/images" # path for loading data set 7 | self.annoPath = "./annotations/imglist_anno_48.txt" 8 | self.manualSeed = 1 # manually set RNG seed 9 | self.use_cuda = True 10 | self.GPU = "0" # default gpu to use 11 | 12 | # ------------- Data options ------------------------------------------- 13 | self.nThreads = 8 # number of data loader threads 14 | 15 | # ---------- Optimization options -------------------------------------- 16 | self.nEpochs = 50 # number of total epochs to 
train 400 17 | self.batchSize = 512 # mini-batch size 128 18 | 19 | # lr master for optimizer 1 (mask vector d) 20 | self.lr = 0.001 # initial learning rate 21 | self.step = [10, 25, 40] # step for linear or exp learning rate policy 22 | self.decayRate = 0.1 # lr decay rate 23 | self.endlr = -1 24 | 25 | # ---------- Model options --------------------------------------------- 26 | self.experimentID = "072602" 27 | 28 | # ---------- Resume or Retrain options --------------------------------------------- 29 | self.resume = None # "./checkpoint_064.pth" 30 | self.retrain = None 31 | 32 | self.save_path = self.save_path + "log_bs{:d}_lr{:.3f}_{}/".format(self.batchSize, self.lr, self.experimentID) 33 | -------------------------------------------------------------------------------- /training/onet/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | 4 | import argparse 5 | import torch 6 | from tools.imagedb import FaceDataset 7 | from torchvision import transforms 8 | from models.onet import ONet 9 | from training.onet.trainer import ONetTrainer 10 | from training.onet.config import Config 11 | from tools.logger import Logger 12 | from checkpoint import CheckPoint 13 | import os 14 | import config 15 | 16 | # Get config 17 | config = Config() 18 | if not os.path.exists(config.save_path): 19 | os.makedirs(config.save_path) 20 | 21 | # Set device 22 | os.environ['CUDA_VISIBLE_DEVICES'] = config.GPU 23 | use_cuda = config.use_cuda and torch.cuda.is_available() 24 | torch.manual_seed(config.manualSeed) 25 | torch.cuda.manual_seed(config.manualSeed) 26 | device = torch.device("cuda" if use_cuda else "cpu") 27 | torch.backends.cudnn.benchmark = True 28 | 29 | # Set dataloader 30 | kwargs = {'num_workers': config.nThreads, 'pin_memory': True} if use_cuda else {} 31 | transform = transforms.Compose([ 32 | transforms.ToTensor(), 33 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 34 | ]) 35 | train_loader = torch.utils.data.DataLoader( 36 | FaceDataset(config.annoPath, transform=transform, is_train=True), batch_size=config.batchSize, shuffle=True, **kwargs) 37 | 38 | # Set model 39 | model = ONet() 40 | model = model.to(device) 41 | 42 | # Set checkpoint 43 | checkpoint = CheckPoint(config.save_path) 44 | 45 | # Set optimizer 46 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 47 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=config.step, gamma=0.1) 48 | 49 | # Set trainer 50 | logger = Logger(config.save_path) 51 | trainer = ONetTrainer(config.lr, train_loader, model, optimizer, scheduler, logger, device) 52 | 53 | for epoch in range(1, config.nEpochs + 1): 54 | trainer.train(epoch) 55 | checkpoint.save_model(model, index=epoch) 56 | -------------------------------------------------------------------------------- /training/onet/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import datetime 3 | import time 4 | from models.lossfn import LossFn 5 | from tools.utils import AverageMeter 6 | 7 | 8 | class ONetTrainer(object): 9 | 10 | def __init__(self, lr, train_loader, model, optimizer, scheduler, logger, device): 11 | self.lr = lr 12 | self.train_loader = train_loader 13 | self.model = model 14 | self.optimizer = optimizer 15 | self.scheduler = scheduler 16 | self.device = device 17 | self.lossfn = LossFn(self.device) 18 | self.logger = logger 19 | self.run_count = 0 20 | self.scalar_info = {} 21 | 22 | def 
compute_accuracy(self, prob_cls, gt_cls): 23 | #we only need the detection which >= 0 24 | prob_cls = torch.squeeze(prob_cls) 25 | mask = torch.ge(gt_cls, 0) 26 | #get valid element 27 | valid_gt_cls = torch.masked_select(gt_cls, mask) 28 | valid_prob_cls = torch.masked_select(prob_cls, mask) 29 | size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0]) 30 | prob_ones = torch.ge(valid_prob_cls, 0.6).float() 31 | right_ones = torch.eq(prob_ones, valid_gt_cls.float()).float() 32 | 33 | return torch.div(torch.mul(torch.sum(right_ones), float(1.0)), float(size)) 34 | 35 | def update_lr(self, epoch): 36 | """ 37 | update learning rate of optimizers 38 | :param epoch: current training epoch 39 | """ 40 | # update learning rate of model optimizer 41 | for param_group in self.optimizer.param_groups: 42 | param_group['lr'] = self.lr 43 | 44 | def train(self, epoch): 45 | cls_loss_ = AverageMeter() 46 | box_offset_loss_ = AverageMeter() 47 | landmark_loss_ = AverageMeter() 48 | total_loss_ = AverageMeter() 49 | accuracy_ = AverageMeter() 50 | 51 | self.scheduler.step() 52 | self.model.train() 53 | 54 | for batch_idx, (data, target) in enumerate(self.train_loader): 55 | gt_label = target['label'] 56 | gt_bbox = target['bbox_target'] 57 | gt_landmark = target['landmark_target'] 58 | data, gt_label, gt_bbox, gt_landmark = data.to(self.device), gt_label.to( 59 | self.device), gt_bbox.to(self.device).float(), gt_landmark.to(self.device).float() 60 | 61 | cls_pred, box_offset_pred, landmark_offset_pred = self.model(data) 62 | # compute the loss 63 | cls_loss = self.lossfn.cls_loss(gt_label, cls_pred) 64 | box_offset_loss = self.lossfn.box_loss( 65 | gt_label, gt_bbox, box_offset_pred) 66 | landmark_loss = self.lossfn.landmark_loss(gt_label, gt_landmark, landmark_offset_pred) 67 | 68 | total_loss = cls_loss + box_offset_loss * 0.5 + landmark_loss 69 | accuracy = self.compute_accuracy(cls_pred, gt_label) 70 | 71 | self.optimizer.zero_grad() 72 | total_loss.backward() 73 | self.optimizer.step() 74 | 75 | cls_loss_.update(cls_loss, data.size(0)) 76 | box_offset_loss_.update(box_offset_loss, data.size(0)) 77 | landmark_loss_.update(landmark_loss, data.size(0)) 78 | total_loss_.update(total_loss, data.size(0)) 79 | accuracy_.update(accuracy, data.size(0)) 80 | 81 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.6f}'.format( 82 | epoch, batch_idx * len(data), len(self.train_loader.dataset), 83 | 100. 
* batch_idx / len(self.train_loader), total_loss.item(), accuracy.item())) 84 | 85 | self.scalar_info['cls_loss'] = cls_loss_.avg 86 | self.scalar_info['box_offset_loss'] = box_offset_loss_.avg 87 | self.scalar_info['landmark_loss'] = landmark_loss_.avg 88 | self.scalar_info['total_loss'] = total_loss_.avg 89 | self.scalar_info['accuracy'] = accuracy_.avg 90 | self.scalar_info['lr'] = self.scheduler.get_lr()[0] 91 | 92 | if self.logger is not None: 93 | for tag, value in list(self.scalar_info.items()): 94 | self.logger.scalar_summary(tag, value, self.run_count) 95 | self.scalar_info = {} 96 | self.run_count += 1 97 | 98 | print("|===>Loss: {:.4f}".format(total_loss_.avg)) 99 | return cls_loss_.avg, box_offset_loss_.avg, landmark_loss_.avg, total_loss_.avg, accuracy_.avg -------------------------------------------------------------------------------- /training/pnet/config.py: -------------------------------------------------------------------------------- 1 | class Config(object): 2 | def __init__(self): 3 | super(Config, self).__init__() 4 | # ------------ General options ---------------------------------------- 5 | self.save_path = "./results/pnet/" 6 | self.dataPath = "'/home/dataset/WIDER/WIDER_train/images" # path for loading data set 7 | self.annoPath = "./annotations/imglist_anno_12.txt" 8 | self.manualSeed = 1 # manually set RNG seed 9 | self.use_cuda = True 10 | self.GPU = "0" # default gpu to use 11 | 12 | # ------------- Data options ------------------------------------------- 13 | self.nThreads = 8 # number of data loader threads 14 | 15 | # ---------- Optimization options -------------------------------------- 16 | self.nEpochs = 50 # number of total epochs to train 400 17 | self.batchSize = 512 # mini-batch size 128 18 | 19 | # lr master for optimizer 1 (mask vector d) 20 | self.lr = 0.01 # initial learning rate 21 | self.step = [10, 25, 40] # step for linear or exp learning rate policy 22 | self.decayRate = 0.1 # lr decay rate 23 | self.endlr = -1 24 | 25 | # ---------- Model options --------------------------------------------- 26 | self.experimentID = "072402" 27 | 28 | # ---------- Resume or Retrain options --------------------------------------------- 29 | self.resume = None # "./checkpoint_064.pth" 30 | self.retrain = None 31 | 32 | self.save_path = self.save_path + "log_bs{:d}_lr{:.3f}_{}/".format(self.batchSize, self.lr, self.experimentID) 33 | -------------------------------------------------------------------------------- /training/pnet/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | 4 | import argparse 5 | import torch 6 | from tools.imagedb import FaceDataset 7 | from torchvision import transforms 8 | from models.pnet import PNet 9 | from training.pnet.trainer import PNetTrainer 10 | from training.pnet.config import Config 11 | from tools.logger import Logger 12 | from checkpoint import CheckPoint 13 | import os 14 | import config 15 | 16 | # Get config 17 | config = Config() 18 | if not os.path.exists(config.save_path): 19 | os.makedirs(config.save_path) 20 | 21 | # Set device 22 | os.environ['CUDA_VISIBLE_DEVICES'] = config.GPU 23 | use_cuda = config.use_cuda and torch.cuda.is_available() 24 | torch.manual_seed(config.manualSeed) 25 | torch.cuda.manual_seed(config.manualSeed) 26 | device = torch.device("cuda" if use_cuda else "cpu") 27 | torch.backends.cudnn.benchmark = True 28 | 29 | # Set dataloader 30 | kwargs = {'num_workers': config.nThreads, 'pin_memory': True} if 
use_cuda else {} 31 | transform = transforms.Compose([ 32 | transforms.ToTensor(), 33 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 34 | ]) 35 | train_loader = torch.utils.data.DataLoader( 36 | FaceDataset(config.annoPath, transform=transform, is_train=True), batch_size=config.batchSize, shuffle=True, **kwargs) 37 | 38 | # Set model 39 | model = PNet() 40 | model = model.to(device) 41 | 42 | # Set checkpoint 43 | checkpoint = CheckPoint(config.save_path) 44 | 45 | # Set optimizer 46 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 47 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=config.step, gamma=0.1) 48 | 49 | # Set trainer 50 | logger = Logger(config.save_path) 51 | trainer = PNetTrainer(config.lr, train_loader, model, optimizer, scheduler, logger, device) 52 | 53 | for epoch in range(1, config.nEpochs + 1): 54 | cls_loss_, box_offset_loss, total_loss, accuracy = trainer.train(epoch) 55 | checkpoint.save_model(model, index=epoch) 56 | 57 | -------------------------------------------------------------------------------- /training/pnet/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import datetime 3 | import time 4 | from models.lossfn import LossFn 5 | from tools.utils import AverageMeter 6 | 7 | 8 | class PNetTrainer(object): 9 | 10 | def __init__(self, lr, train_loader, model, optimizer, scheduler, logger, device): 11 | self.lr = lr 12 | self.train_loader = train_loader 13 | self.model = model 14 | self.optimizer = optimizer 15 | self.scheduler = scheduler 16 | self.device = device 17 | self.lossfn = LossFn(self.device) 18 | self.logger = logger 19 | self.run_count = 0 20 | self.scalar_info = {} 21 | 22 | def compute_accuracy(self, prob_cls, gt_cls): 23 | #we only need the detection which >= 0 24 | prob_cls = torch.squeeze(prob_cls) 25 | mask = torch.ge(gt_cls, 0) 26 | #get valid element 27 | valid_gt_cls = torch.masked_select(gt_cls, mask) 28 | valid_prob_cls = torch.masked_select(prob_cls, mask) 29 | size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0]) 30 | prob_ones = torch.ge(valid_prob_cls, 0.6).float() 31 | right_ones = torch.eq(prob_ones, valid_gt_cls.float()).float() 32 | 33 | return torch.div(torch.mul(torch.sum(right_ones), float(1.0)), float(size)) 34 | 35 | def update_lr(self, epoch): 36 | """ 37 | update learning rate of optimizers 38 | :param epoch: current training epoch 39 | """ 40 | # update learning rate of model optimizer 41 | for param_group in self.optimizer.param_groups: 42 | param_group['lr'] = self.lr 43 | 44 | def train(self, epoch): 45 | cls_loss_ = AverageMeter() 46 | box_offset_loss_ = AverageMeter() 47 | total_loss_ = AverageMeter() 48 | accuracy_ = AverageMeter() 49 | 50 | self.scheduler.step() 51 | self.model.train() 52 | 53 | for batch_idx, (data, target) in enumerate(self.train_loader): 54 | gt_label = target['label'] 55 | gt_bbox = target['bbox_target'] 56 | data, gt_label, gt_bbox = data.to(self.device), gt_label.to( 57 | self.device), gt_bbox.to(self.device).float() 58 | 59 | cls_pred, box_offset_pred = self.model(data) 60 | # compute the loss 61 | cls_loss = self.lossfn.cls_loss(gt_label, cls_pred) 62 | box_offset_loss = self.lossfn.box_loss( 63 | gt_label, gt_bbox, box_offset_pred) 64 | 65 | total_loss = cls_loss + box_offset_loss * 0.5 66 | accuracy = self.compute_accuracy(cls_pred, gt_label) 67 | 68 | self.optimizer.zero_grad() 69 | total_loss.backward() 70 | self.optimizer.step() 71 | 72 | cls_loss_.update(cls_loss, 
data.size(0)) 73 | box_offset_loss_.update(box_offset_loss, data.size(0)) 74 | total_loss_.update(total_loss, data.size(0)) 75 | accuracy_.update(accuracy, data.size(0)) 76 | 77 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.6f}'.format( 78 | epoch, batch_idx * len(data), len(self.train_loader.dataset), 79 | 100. * batch_idx / len(self.train_loader), total_loss.item(), accuracy.item())) 80 | 81 | self.scalar_info['cls_loss'] = cls_loss_.avg 82 | self.scalar_info['box_offset_loss'] = box_offset_loss_.avg 83 | self.scalar_info['total_loss'] = total_loss_.avg 84 | self.scalar_info['accuracy'] = accuracy_.avg 85 | self.scalar_info['lr'] = self.scheduler.get_lr()[0] 86 | 87 | if self.logger is not None: 88 | for tag, value in list(self.scalar_info.items()): 89 | self.logger.scalar_summary(tag, value, self.run_count) 90 | self.scalar_info = {} 91 | self.run_count += 1 92 | 93 | print("|===>Loss: {:.4f}".format(total_loss_.avg)) 94 | return cls_loss_.avg, box_offset_loss_.avg, total_loss_.avg, accuracy_.avg -------------------------------------------------------------------------------- /training/rnet/config.py: -------------------------------------------------------------------------------- 1 | class Config(object): 2 | def __init__(self): 3 | super(Config, self).__init__() 4 | # ------------ General options ---------------------------------------- 5 | self.save_path = "./results/rnet/" 6 | self.dataPath = "/home/dataset/WIDER/WIDER_train/images" # path for loading data set 7 | self.annoPath = "./annotations/imglist_anno_24.txt" 8 | self.manualSeed = 1 # manually set RNG seed 9 | self.use_cuda = True 10 | self.GPU = "1" # default gpu to use 11 | 12 | # ------------- Data options ------------------------------------------- 13 | self.nThreads = 8 # number of data loader threads 14 | 15 | # ---------- Optimization options -------------------------------------- 16 | self.nEpochs = 50 # number of total epochs to train 400 17 | self.batchSize = 512 # mini-batch size 128 18 | 19 | # lr master for optimizer 1 (mask vector d) 20 | self.lr = 0.001 # initial learning rate 21 | self.step = [10, 25, 40] # step for linear or exp learning rate policy 22 | self.decayRate = 0.1 # lr decay rate 23 | self.endlr = -1 24 | 25 | # ---------- Model options --------------------------------------------- 26 | self.experimentID = "072502" 27 | 28 | # ---------- Resume or Retrain options --------------------------------------------- 29 | self.resume = None # "./checkpoint_064.pth" 30 | self.retrain = None 31 | 32 | self.save_path = self.save_path + "log_bs{:d}_lr{:.3f}_{}/".format(self.batchSize, self.lr, self.experimentID) -------------------------------------------------------------------------------- /training/rnet/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | 4 | import argparse 5 | import torch 6 | from tools.imagedb import FaceDataset 7 | from torchvision import transforms 8 | from models.rnet import RNet 9 | from training.rnet.trainer import RNetTrainer 10 | from training.rnet.config import Config 11 | from tools.logger import Logger 12 | from checkpoint import CheckPoint 13 | import os 14 | import config 15 | 16 | # Get config 17 | config = Config() 18 | if not os.path.exists(config.save_path): 19 | os.makedirs(config.save_path) 20 | 21 | # Set device 22 | os.environ['CUDA_VISIBLE_DEVICES'] = config.GPU 23 | use_cuda = config.use_cuda and torch.cuda.is_available() 24 | 
torch.manual_seed(config.manualSeed) 25 | torch.cuda.manual_seed(config.manualSeed) 26 | device = torch.device("cuda:0" if use_cuda else "cpu") 27 | torch.backends.cudnn.benchmark = True 28 | 29 | # Set dataloader 30 | kwargs = {'num_workers': config.nThreads, 'pin_memory': True} if use_cuda else {} 31 | transform = transforms.Compose([ 32 | transforms.ToTensor(), 33 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 34 | ]) 35 | train_loader = torch.utils.data.DataLoader( 36 | FaceDataset(config.annoPath, transform=transform, is_train=True), batch_size=config.batchSize, shuffle=True, **kwargs) 37 | 38 | # Set model 39 | model = RNet() 40 | model = model.to(device) 41 | 42 | # Set checkpoint 43 | checkpoint = CheckPoint(config.save_path) 44 | 45 | # Set optimizer 46 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 47 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=config.step, gamma=0.1) 48 | 49 | # Set trainer 50 | logger = Logger(config.save_path) 51 | trainer = RNetTrainer(config.lr, train_loader, model, optimizer, scheduler, logger, device) 52 | 53 | for epoch in range(1, config.nEpochs + 1): 54 | cls_loss_, box_offset_loss, total_loss, accuracy = trainer.train(epoch) 55 | checkpoint.save_model(model, index=epoch) 56 | -------------------------------------------------------------------------------- /training/rnet/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import datetime 3 | import time 4 | from models.lossfn import LossFn 5 | from tools.utils import AverageMeter 6 | 7 | 8 | class RNetTrainer(object): 9 | 10 | def __init__(self, lr, train_loader, model, optimizer, scheduler, logger, device): 11 | self.lr = lr 12 | self.train_loader = train_loader 13 | self.model = model 14 | self.optimizer = optimizer 15 | self.scheduler = scheduler 16 | self.device = device 17 | self.lossfn = LossFn(self.device) 18 | self.logger = logger 19 | self.run_count = 0 20 | self.scalar_info = {} 21 | 22 | def compute_accuracy(self, prob_cls, gt_cls): 23 | #we only need the detection which >= 0 24 | prob_cls = torch.squeeze(prob_cls) 25 | mask = torch.ge(gt_cls, 0) 26 | #get valid element 27 | valid_gt_cls = torch.masked_select(gt_cls, mask) 28 | valid_prob_cls = torch.masked_select(prob_cls, mask) 29 | size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0]) 30 | prob_ones = torch.ge(valid_prob_cls, 0.6).float() 31 | right_ones = torch.eq(prob_ones, valid_gt_cls.float()).float() 32 | 33 | return torch.div(torch.mul(torch.sum(right_ones), float(1.0)), float(size)) 34 | 35 | def update_lr(self, epoch): 36 | """ 37 | update learning rate of optimizers 38 | :param epoch: current training epoch 39 | """ 40 | # update learning rate of model optimizer 41 | for param_group in self.optimizer.param_groups: 42 | param_group['lr'] = self.lr 43 | 44 | def train(self, epoch): 45 | cls_loss_ = AverageMeter() 46 | box_offset_loss_ = AverageMeter() 47 | total_loss_ = AverageMeter() 48 | accuracy_ = AverageMeter() 49 | 50 | self.scheduler.step() 51 | self.model.train() 52 | 53 | for batch_idx, (data, target) in enumerate(self.train_loader): 54 | gt_label = target['label'] 55 | gt_bbox = target['bbox_target'] 56 | data, gt_label, gt_bbox = data.to(self.device), gt_label.to( 57 | self.device), gt_bbox.to(self.device).float() 58 | 59 | cls_pred, box_offset_pred = self.model(data) 60 | # compute the loss 61 | cls_loss = self.lossfn.cls_loss(gt_label, cls_pred) 62 | box_offset_loss = self.lossfn.box_loss( 63 | 
gt_label, gt_bbox, box_offset_pred) 64 | 65 | total_loss = cls_loss + box_offset_loss * 0.5 66 | accuracy = self.compute_accuracy(cls_pred, gt_label) 67 | 68 | self.optimizer.zero_grad() 69 | total_loss.backward() 70 | self.optimizer.step() 71 | 72 | cls_loss_.update(cls_loss, data.size(0)) 73 | box_offset_loss_.update(box_offset_loss, data.size(0)) 74 | total_loss_.update(total_loss, data.size(0)) 75 | accuracy_.update(accuracy, data.size(0)) 76 | 77 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.6f}'.format( 78 | epoch, batch_idx * len(data), len(self.train_loader.dataset), 79 | 100. * batch_idx / len(self.train_loader), total_loss.item(), accuracy.item())) 80 | 81 | self.scalar_info['cls_loss'] = cls_loss_.avg 82 | self.scalar_info['box_offset_loss'] = box_offset_loss_.avg 83 | self.scalar_info['total_loss'] = total_loss_.avg 84 | self.scalar_info['accuracy'] = accuracy_.avg 85 | self.scalar_info['lr'] = self.scheduler.get_lr()[0] 86 | 87 | if self.logger is not None: 88 | for tag, value in list(self.scalar_info.items()): 89 | self.logger.scalar_summary(tag, value, self.run_count) 90 | self.scalar_info = {} 91 | self.run_count += 1 92 | 93 | print("|===>Loss: {:.4f}".format(total_loss_.avg)) 94 | return cls_loss_.avg, box_offset_loss_.avg, total_loss_.avg, accuracy_.avg --------------------------------------------------------------------------------
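A brief usage note (not part of the repository sources above): both train.py entry points call sys.path.append('./') and build their paths relative to the working directory, so they are presumably meant to be launched from the repository root. A minimal sketch, assuming the annotation lists named in the two Config classes (./annotations/imglist_anno_12.txt and ./annotations/imglist_anno_24.txt) have already been generated (the scripts under preprocessing/ appear to be responsible for that step):

```bash
# Run from the repository root; paths and hyper-parameters come from
# training/pnet/config.py and training/rnet/config.py.
python training/pnet/train.py   # PNet: logs and checkpoints under ./results/pnet/log_bs512_lr0.010_072402/
python training/rnet/train.py   # RNet: logs and checkpoints under ./results/rnet/log_bs512_lr0.001_072502/
```

Note that the two configs default to different GPUs (GPU = "0" for PNet, "1" for RNet) via CUDA_VISIBLE_DEVICES; on a single-GPU machine the GPU field would need to be adjusted, or use_cuda set to False.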