├── .gitignore ├── MTCNN.py ├── MTCNN_debug.py ├── README.md ├── checkpoint.py ├── config.py ├── gen_dataset ├── assemble.py ├── assemble.pyc ├── assemble_Onet_imglist.py ├── assemble_Pnet_imglist.py ├── assemble_Rnet_imglist.py ├── gen_Onet_data.py ├── gen_Onet_landmark.py ├── gen_Pnet_data.py ├── gen_Pnet_data_using_prob.py ├── gen_Rnet_data.py ├── gen_user_Onet_data.py ├── gen_user_Pnet_data.py ├── gen_user_Rnet_data.py ├── transform_mat2txt.py └── wider_loader.py ├── loss.py ├── nets ├── mtcnn.py ├── prunable_layers.py ├── prune_mtcnn.py ├── quant_mtcnn.py └── slim_mtcnn.py ├── onnx2ncnn ├── onet.onnx ├── onnx2ncnn.sh ├── pnet.onnx └── rnet.onnx ├── pretrained_weights ├── mtcnn │ ├── best_onet.pth │ ├── best_onet_landmark.pth │ ├── best_pnet.pth │ └── best_rnet.pth └── quant_mtcnn │ └── best_pnet.pth ├── prunning ├── Onet_Prune.py ├── Pnet_Prune.py ├── Rnet_Prune.py ├── pnet │ ├── prune.py │ └── pruner.py └── utils │ ├── FilterPrunner.py │ ├── prunable_layers.py │ ├── prune_mtcnn.py │ └── util.py ├── quantization-aware_training └── pnet │ ├── train.py │ └── trainer.py ├── test ├── results │ └── fddb.png ├── test_FDDB.py ├── test_image.py └── test_video.py ├── tools ├── average_meter.py ├── dataset.py └── utils.py └── training ├── onet ├── landmark_train.py ├── landmark_trainer.py ├── train.py └── trainer.py ├── pnet ├── train.py └── trainer.py └── rnet ├── train.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | data 3 | annotations 4 | -------------------------------------------------------------------------------- /MTCNN.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | import torchvision.transforms as transforms 8 | 9 | from nets.mtcnn import ONet 10 | from nets.mtcnn import PNet 11 | from nets.mtcnn import RNet 12 | import tools.utils as utils 13 | import config 14 | 15 | class MTCNNDetector(object): 16 | ''' P, R, O net for face detection and alignment''' 17 | def __init__(self, 18 | p_model_path=None, 19 | r_model_path=None, 20 | o_model_path=None, 21 | min_face_size=12, 22 | stride=2, 23 | threshold=[0.6, 0.7, 0.7], 24 | scale_factor=0.709, 25 | use_cuda=True): 26 | 27 | self.pnet_detector, self.rnet_detector, self.onet_detector = self.create_mtcnn_net( 28 | p_model_path, r_model_path, o_model_path, use_cuda) 29 | 30 | self.min_face_size = min_face_size 31 | self.stride = stride 32 | self.thresh = threshold 33 | self.scale_factor = scale_factor 34 | 35 | def create_mtcnn_net(self, p_model_path=None, r_model_path=None, o_model_path=None, use_cuda=True): 36 | '''Create MTCNN Pnet, Rnet, Onet, load weights if there are any.''' 37 | pnet, rnet, onet = None, None, None 38 | self.device = torch.device( 39 | "cuda:1" if use_cuda and torch.cuda.is_available() else "cpu") 40 | 41 | if p_model_path is not None: 42 | pnet = PNet() 43 | pnet.load_state_dict(torch.load(p_model_path)) 44 | if (use_cuda): 45 | pnet.to(self.device) 46 | pnet.eval() 47 | 48 | if r_model_path is not None: 49 | rnet = RNet() 50 | rnet.load_state_dict(torch.load(r_model_path)) 51 | if (use_cuda): 52 | rnet.to(self.device) 53 | rnet.eval() 54 | 55 | if o_model_path is not None: 56 | onet = ONet(train_landmarks=True) 57 | onet.load_state_dict(torch.load(o_model_path)) 58 | if (use_cuda): 59 | onet.to(self.device) 60 | onet.eval() 61 | 62 | return pnet, rnet, onet 63 | 64 | def generate_bounding_box(self, 
cls_map, bbox_map, scale, threshold): 65 | ''' 66 | generate bounding bboxes from feature map 67 | for PNet, there exists no fc layer, only convolution layer, 68 | so feature map n x m x 2/4, 2 for classification, 4 for bboxes 69 | 70 | Parameters: 71 | ----------- 72 | cls_map: numpy array , 1 x n x m x 2, detect score for each position 73 | bbox_map: numpy array , 1 x n x m x 4, detect bbox regression value for each position 74 | scale: float number, scale of this detection 75 | threshold: float number, detect threshold 76 | Returns: 77 | -------- 78 | bbox array 79 | ''' 80 | stride = config.STRIDE 81 | cellsize = config.PNET_SIZE 82 | # softmax layer 1 for face, return a tuple with an array of row idxs and 83 | # an array of col idxs 84 | # locate face above threshold from cls_map 85 | t_index = np.where(cls_map[0, :, :, 1] > threshold) 86 | 87 | # find nothing 88 | if t_index[0].size == 0: 89 | return np.array([]) 90 | 91 | dx1, dy1, dx2, dy2 = [bbox_map[0, t_index[0], t_index[1], i] 92 | for i in range(4)] 93 | bbox_map = np.array([dx1, dy1, dx2, dy2]) 94 | 95 | score = cls_map[0, t_index[0], t_index[1], 1] 96 | boundingbox = np.vstack([np.round((stride * t_index[1]) / scale), 97 | np.round((stride * t_index[0]) / scale), 98 | np.round( 99 | (stride * t_index[1] + cellsize) / scale), 100 | np.round( 101 | (stride * t_index[0] + cellsize) / scale), 102 | score, 103 | bbox_map, 104 | ]) 105 | 106 | return boundingbox.T 107 | 108 | def resize_image(self, img, scale): 109 | """ 110 | Resize image and transform dimention to [batchsize, channel, height, width] 111 | Parameters: 112 | ---------- 113 | img: numpy array , height x width x channel, input image, channels in BGR order here 114 | scale: float number, scale factor of resize operation 115 | Returns: 116 | ------- 117 | transformed image tensor , 1 x channel x height x width 118 | """ 119 | height, width, channels = img.shape 120 | new_height = int(height * scale) # resized new height 121 | new_width = int(width * scale) # resized new width 122 | new_size = (new_width, new_height) 123 | img_resized = cv2.resize( 124 | img, new_size, interpolation=cv2.INTER_LINEAR) # resized image 125 | return img_resized 126 | 127 | def detect_pnet(self, im): 128 | """Get face candidates through pnet 129 | 130 | Parameters: 131 | ----------- 132 | im: numpy array, input image array 133 | 134 | Returns: 135 | -------- 136 | bboxes_align: numpy array 137 | bboxes after calibration 138 | """ 139 | h, w, c = im.shape 140 | net_size = config.PNET_SIZE 141 | current_scale = float(net_size) / self.min_face_size # find initial scale 142 | im_resized = self.resize_image(im, current_scale) 143 | current_height, current_width, _ = im_resized.shape 144 | 145 | # bounding boxes for all the pyramid scales 146 | all_bboxes = list() 147 | # generating bounding boxes for each scale 148 | while min(current_height, current_width) > net_size: 149 | image_tensor = utils.convert_image_to_tensor(im_resized) 150 | feed_imgs = image_tensor.unsqueeze(0) 151 | feed_imgs = feed_imgs.to(self.device) 152 | 153 | cls_map, reg_map = self.pnet_detector(feed_imgs) 154 | cls_map_np = utils.convert_chwTensor_to_hwcNumpy(cls_map.cpu()) 155 | reg_map_np = utils.convert_chwTensor_to_hwcNumpy(reg_map.cpu()) 156 | bboxes = self.generate_bounding_box( 157 | cls_map_np, reg_map_np, current_scale, self.thresh[0]) 158 | 159 | current_scale *= self.scale_factor 160 | im_resized = self.resize_image(im, current_scale) 161 | current_height, current_width, _ = im_resized.shape 162 | 163 | if 
bboxes.size == 0: 164 | continue 165 | 166 | keep = utils.nms(bboxes[:, :5], 0.5, 'Union') 167 | bboxes = bboxes[keep] 168 | all_bboxes.append(bboxes) 169 | 170 | 171 | if len(all_bboxes) == 0: 172 | return None 173 | 174 | all_bboxes = np.vstack(all_bboxes) 175 | 176 | # apply nms to the detections from all the scales 177 | keep = utils.nms(all_bboxes[:, 0:5], 0.7, 'Union') 178 | all_bboxes = all_bboxes[keep] 179 | 180 | # 0-4: original bboxes, 5: score, 5: offsets 181 | bboxes_align = utils.calibrate_box(all_bboxes[:, 0:5], all_bboxes[:, 5:]) 182 | bboxes_align = utils.convert_to_square(bboxes_align) 183 | bboxes_align[:, 0:4] = np.round(bboxes_align[:, 0:4]) 184 | 185 | return bboxes_align 186 | 187 | def detect_rnet(self, im, bboxes): 188 | """Get face candidates using rnet 189 | 190 | Parameters: 191 | ---------- 192 | im: numpy array 193 | input image array 194 | bboxes: numpy array 195 | detection results of pnet 196 | 197 | Returns: 198 | ------- 199 | bboxes_align: numpy array 200 | bboxes after calibration 201 | """ 202 | net_size = config.RNET_SIZE 203 | h, w, c = im.shape 204 | if bboxes is None: 205 | return None 206 | 207 | num_bboxes = bboxes.shape[0] 208 | 209 | [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] = utils.correct_bboxes(bboxes, w, h) 210 | 211 | 212 | # crop face using pnet proposals 213 | cropped_ims_tensors = [] 214 | for i in range(num_bboxes): 215 | try: 216 | if tmph[i] > 0 and tmpw[i] > 0: 217 | tmp = np.zeros((tmph[i], tmpw[i], 3), dtype=np.uint8) 218 | tmp[dy[i]:edy[i], dx[i]:edx[i], :] = im[y[i]:ey[i], x[i]:ex[i], :] 219 | crop_im = cv2.resize(tmp, (net_size, net_size)) 220 | crop_im_tensor = utils.convert_image_to_tensor(crop_im) 221 | cropped_ims_tensors.append(crop_im_tensor) 222 | except ValueError as e: 223 | print('dy: {}, edy: {}, dx: {}, edx: {}'.format(dy[i], edy[i], dx[i], edx[i])) 224 | print('y: {}, ey: {}, x: {}, ex: {}'.format(y[i], ey[i], x[i], ex[i])) 225 | print(e) 226 | 227 | # provide input tensor, if there are too many proposals in PNet 228 | # there might be OOM 229 | feed_imgs = torch.stack(cropped_ims_tensors) 230 | feed_imgs = feed_imgs.to(self.device) 231 | 232 | cls, reg = self.rnet_detector(feed_imgs) 233 | cls = cls.cpu().data.numpy() 234 | reg = reg.cpu().data.numpy() 235 | 236 | keep_inds = np.where(cls[:, 1] > self.thresh[1])[0] 237 | if len(keep_inds) > 0: 238 | keep_bboxes = bboxes[keep_inds] 239 | keep_cls = cls[keep_inds, :] 240 | keep_reg = reg[keep_inds] 241 | # using softmax 1 as cls score 242 | keep_bboxes[:, 4] = keep_cls[:, 1].reshape((-1,)) 243 | else: 244 | return None 245 | 246 | keep = utils.nms(keep_bboxes, 0.7) 247 | if len(keep) == 0: 248 | return None 249 | 250 | keep_cls = keep_cls[keep] 251 | keep_bboxes = keep_bboxes[keep] 252 | keep_reg = keep_reg[keep] 253 | 254 | bboxes_align = utils.calibrate_box(keep_bboxes, keep_reg) 255 | bboxes_align = utils.convert_to_square(bboxes_align) 256 | bboxes_align[:, 0:4] = np.round(bboxes_align[:, 0:4]) 257 | 258 | return bboxes_align 259 | 260 | def detect_onet(self, im, bboxes): 261 | """Get face candidates using onet 262 | 263 | Parameters: 264 | ---------- 265 | im: numpy array 266 | input image array 267 | bboxes: numpy array 268 | detection results of rnet 269 | 270 | Returns: 271 | ------- 272 | bboxes_align: numpy array 273 | bboxes after calibration 274 | """ 275 | net_size = config.ONET_SIZE 276 | h, w, c = im.shape 277 | if bboxes is None: 278 | return None 279 | 280 | [dy, edy, dx, edx, y, ey, x, ex, tmpw, tmph] = utils.correct_bboxes(bboxes, w, h) 281 | 
num_bboxes = bboxes.shape[0] 282 | 283 | # crop face using rnet proposal 284 | cropped_ims_tensors = [] 285 | for i in range(num_bboxes): 286 | try: 287 | if tmph[i] > 0 and tmpw[i] > 0: 288 | tmp = np.zeros((tmph[i], tmpw[i], 3), dtype=np.uint8) 289 | tmp[dy[i]:edy[i], dx[i]:edx[i], :] = im[y[i]:ey[i], x[i]:ex[i], :] 290 | crop_im = cv2.resize(tmp, (net_size, net_size)) 291 | crop_im_tensor = utils.convert_image_to_tensor(crop_im) 292 | cropped_ims_tensors.append(crop_im_tensor) 293 | except ValueError as e: 294 | print(e) 295 | 296 | feed_imgs = torch.stack(cropped_ims_tensors) 297 | feed_imgs = feed_imgs.to(self.device) 298 | 299 | cls, reg, landmarks = self.onet_detector(feed_imgs) 300 | cls = cls.cpu().data.numpy() 301 | reg = reg.cpu().data.numpy() 302 | landmarks = landmarks.cpu().data.numpy() 303 | 304 | 305 | keep_inds = np.where(cls[:, 1] > self.thresh[2])[0] 306 | 307 | if len(keep_inds) > 0: 308 | keep_bboxes = bboxes[keep_inds] 309 | keep_cls = cls[keep_inds, :] 310 | keep_reg = reg[keep_inds] 311 | keep_landmarks = landmarks[keep_inds, :] 312 | keep_bboxes[:, 4] = keep_cls[:, 1].reshape((-1,)) 313 | else: 314 | return None 315 | 316 | # compute landmarks point 317 | # width = keep_bboxes[:, 2] - bboxes[:, 0] + 1.0 318 | # height = keep_bboxes[:, 3] - bboxes[:, 1] + 1.0 319 | # xmin, ymin = bboxes[:, 0], bboxes[:, 1] 320 | 321 | bboxes_align = utils.calibrate_box(keep_bboxes, keep_reg) 322 | keep = utils.nms(bboxes_align, 0.7, mode='Minimum') 323 | 324 | if len(keep) == 0: 325 | return None 326 | 327 | bboxes_align = bboxes_align[keep] 328 | bboxes_align = utils.convert_to_square(bboxes_align) 329 | landmarks = keep_landmarks[keep] 330 | 331 | return bboxes_align, landmarks 332 | 333 | def detect_face(self, img): 334 | ''' Detect face over image ''' 335 | bboxes_align = np.array([]) 336 | 337 | t = time.time() 338 | 339 | # pnet 340 | if self.pnet_detector: 341 | bboxes_align = self.detect_pnet(img) 342 | if bboxes_align is None: 343 | return np.array([]) 344 | 345 | t1 = time.time() - t 346 | t = time.time() 347 | 348 | # rnet 349 | if self.rnet_detector: 350 | bboxes_align = self.detect_rnet(img, bboxes_align) 351 | if bboxes_align is None: 352 | return np.array([]) 353 | 354 | t2 = time.time() - t 355 | t = time.time() 356 | 357 | # onet 358 | if self.onet_detector: 359 | bboxes_align, landmarks = self.detect_onet(img, bboxes_align) 360 | if bboxes_align is None: 361 | return np.array([]) 362 | 363 | t3 = time.time() - t 364 | t = time.time() 365 | print( 366 | "time cost " + '{:.3f}'.format(t1 + t2 + t3) + \ 367 | ' pnet {:.3f} rnet {:.3f} onet {:.3f}'.format(t1, t2, t3)) 368 | 369 | return bboxes_align, landmarks 370 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Repo for training, optimizing and deploying MTCNN for mobile devices 2 | 3 | ## Prepare dataset 4 | 1. Download [WIDER FACE]() dataset and put them under `./data` directory. 5 | 2. Transform matlab format label train and val format file into text format. 6 | RUN `pyhon gen_dataset/transform_mat2txt.py` 7 | Change the mode variable in `transform_mat2txt.py` to `val` to generate val data label. 8 | 9 | ## Train MTCNN 10 | ### Train Pnet 11 | 1. generate pnet training data: 12 | RUN `python gen_dataset/gen_Pnet_data.py` 13 | Change the mode variable in `gen_Pnet_data.py` to `val` to generate val data label. 14 | 2. 
Training Pnet: 15 | RUN `python training/pnet/train.py` to train your model. 16 | 3. Save weights: 17 | We use the validation dataset to help us choose the best Pnet model. The weights are saved in `pretrained_weights/mtcnn/best_pnet.pth` 18 | 19 | ### Train Rnet 20 | After we have trained Pnet, we can use it to generate data for training Rnet. 21 | 1. generate Rnet training data: 22 | RUN `python gen_dataset/gen_Rnet_data.py` 23 | Change the mode variable in `gen_Rnet_data.py` to `val` to generate the val data label. 24 | 2. Training Rnet: 25 | RUN `python training/rnet/train.py` to train your model. 26 | 3. Save weights: 27 | We use the validation dataset to help us choose the best Rnet model. The weights are saved in `pretrained_weights/mtcnn/best_rnet.pth` 28 | 29 | ### Train Onet 30 | After we have trained Pnet and Rnet, we can use them to generate data for training Onet. 31 | 1. generate Onet training data: 32 | RUN `python gen_dataset/gen_Onet_data.py` 33 | Change the mode variable in `gen_Onet_data.py` to `val` to generate the val data label. 34 | 2. Training Onet: 35 | RUN `python training/onet/train.py` to train your model. 36 | 3. Save weights: 37 | We use the validation dataset to help us choose the best Onet model. The weights are saved in `pretrained_weights/mtcnn/best_onet.pth` 38 | 39 | ### Results 40 | 41 | | WIDER FACE | Pnet | Rnet | Onet | 42 | | :---------: |:------:|:-----:|:-----:| 43 | | cls loss | 0.156 | 0.120| 0.129 | 44 | | offset loss | 0.01 | 0.01 | 0.0063| 45 | | cls acc | 0.944 | 0.962| 0.956 | 46 | 47 | | PRIVATE DATA| Pnet | Rnet | Onet | 48 | | :---------: |:------:|:-----:|:-----:| 49 | | cls loss | 0.05 | 0.09 | 0.104 | 50 | | offset loss | 0.0047 | 0.011 | 0.0057| 51 | | cls acc | 0.983 | 0.971 | 0.970 | 52 | 53 | ## Optimize MTCNN 54 | ### Lighter MTCNN 55 | By combining the ShuffleNet and MobileNet structures, we can design lightweight Pnet, Rnet, and Onet models. In this way we can reduce the size of the model and at the same time decrease the inference time. 56 | 57 | ### Larger Pnet 58 | According to my observation, a small Pnet produces many false positives, which become a burden for Rnet and Onet. By increasing the Pnet size, there will be fewer false positives, which improves the overall efficiency. 59 | 60 | ### Prune MTCNN 61 | 62 | Model pruning is a better strategy than designing mobile CNN architectures for such small networks as Pnet, Rnet, and Onet. By iteratively pruning the MTCNN models, we can decrease the model size and improve inference speed at the same time. 63 | 64 | | PRIVATE DATA| Pnet | Rnet | Onet | 65 | | :---------: |:------:|:-----:|:-----:| 66 | | cls loss | 0.091 | 0.1223 | 0.1055 | 67 | | offset loss | 0.0055 | 0.0116 | 0.0062 | 68 | | cls acc | 0.970 | 0.958 | 0.959 | 69 | 70 | Inference speed was benchmarked using the ncnn inference framework; as the chart below shows, the inference speed has been increased by 2-3 times. 71 | 72 | ``` 73 | pnet min = 27.31 max = 28.31 avg = 27.62 74 | rnet min = 0.50 max = 0.62 avg = 0.58 75 | onet min = 3.14 max = 3.82 avg = 3.25 76 | pruned_pnet min = 6.76 max = 7.13 avg = 6.89 77 | pruned_rnet min = 0.21 max = 0.22 avg = 0.21 78 | pruned_onet min = 1.16 max = 1.52 avg = 1.27 79 | ``` 80 | 81 | We could also treat the pruning process as a *NAS (network architecture search)* procedure: after we obtain the pruned architecture, we can train it from scratch, and I achieved better accuracy using this method on the pruned model above; the retrained results are shown in the table below, after the pruning sketch.
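To make the iterative pruning idea above concrete, here is a minimal, self-contained sketch that ranks the filters of a single `nn.Conv2d` by L1 norm and rebuilds the layer without the weakest ones. The L1 criterion and the helper names are illustrative assumptions only; the repository's own pruner (see `prunning/utils/FilterPrunner.py` and `prunning/utils/prune_mtcnn.py`) may rank filters differently, and a full pruning pass also has to rewire the downstream layers that consume the removed channels.

```python
import torch
import torch.nn as nn

def l1_rank_filters(conv: nn.Conv2d) -> torch.Tensor:
    """Return filter indices sorted from weakest to strongest L1 norm."""
    # weight shape: (out_channels, in_channels, kH, kW)
    scores = conv.weight.detach().abs().sum(dim=(1, 2, 3))
    return torch.argsort(scores)

def prune_conv(conv: nn.Conv2d, num_prune: int) -> nn.Conv2d:
    """Build a new Conv2d with the `num_prune` weakest filters removed."""
    keep = sorted(l1_rank_filters(conv)[num_prune:].tolist())
    pruned = nn.Conv2d(conv.in_channels, len(keep), conv.kernel_size,
                       stride=conv.stride, padding=conv.padding,
                       bias=conv.bias is not None)
    pruned.weight.data = conv.weight.data[keep].clone()
    if conv.bias is not None:
        pruned.bias.data = conv.bias.data[keep].clone()
    return pruned

# usage: shrink a hypothetical 3 -> 10 first conv (Pnet-like) down to 8 filters
conv = nn.Conv2d(3, 10, kernel_size=3)
print(prune_conv(conv, num_prune=2).weight.shape)  # torch.Size([8, 3, 3, 3])
```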
82 | 83 | | PRIVATE DATA| Pnet | Rnet | Onet | 84 | | :---------: |:------:|:-----:|:-----:| 85 | | cls loss | 0.0083 | 0.1038 | 0.0923 | 86 | | offset loss | 0.00588 | 0.012 | 0.00588 | 87 | | cls acc | 0.9718 | 0.965 | 0.9706 | 88 | 89 | ### Quantization Aware Training 90 | 91 | By using quantization aware training library [brevitas](https://github.com/Xilinx/brevitas), I managed to achieve 96.2% accuracy on Pnet which is 2% lower than the original version, but the model size if 4x smaller and the inference speed is to be estimated. 92 | 93 | However, when training Rnet and Onet, OOM errors occured. I will figure out why in the future. 94 | 95 | | PRIVATE DATA| Pnet | Rnet | Onet | 96 | | :---------: |:------:|:-----:|:-----:| 97 | | cls loss | 0.107 | - | - | 98 | | offset loss | 0.0080 | - | - | 99 | | cls acc | 0.962 | - | - | 100 | 101 | 102 | ### Knowledge Distillation 103 | 104 | ## Deploy MTCNN 105 | 106 | 107 | ## Todo 108 | - [ ] Data Augmentation to avoid overfitting 109 | - [ ] Use L1 Smooth loss or WingLoss for bbox and landmarks localization 110 | 111 | 112 | ## References 113 | 1. https://github.com/xuexingyu24/MTCNN_Tutorial 114 | 2. https://github.com/xuexingyu24/Pruning_MTCNN_MobileFaceNet_Using_Pytorch 115 | 116 | -------------------------------------------------------------------------------- /checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | import config 7 | from tools import utils 8 | 9 | class CheckPoint(object): 10 | """ 11 | save model state to file 12 | check_point_params: model, optimizer, epoch 13 | """ 14 | 15 | def __init__(self, save_path): 16 | 17 | self.save_path = os.path.join(save_path, "check_points") 18 | self.check_point_params = {'model': None, 19 | 'optimizer': None, 20 | 'epoch': None} 21 | 22 | # make directory 23 | if not os.path.isdir(self.save_path): 24 | os.makedirs(self.save_path) 25 | 26 | def load_state(self, model, state_dict): 27 | """ 28 | load state_dict to model 29 | :params model: 30 | :params state_dict: 31 | :return: model 32 | """ 33 | model.eval() 34 | model_dict = model.state_dict() 35 | 36 | for key, value in list(state_dict.items()): 37 | if key in list(model_dict.keys()): 38 | # print key, value.size() 39 | model_dict[key] = value 40 | else: 41 | pass 42 | # print "key error:", key, value.size() 43 | model.load_state_dict(model_dict) 44 | # model.load_state_dict(state_dict) 45 | # set the model in evaluation mode, otherwise the accuracy will change 46 | # model.eval() 47 | # model.load_state_dict(state_dict) 48 | 49 | return model 50 | 51 | def load_model(self, model_path): 52 | """ 53 | load model 54 | :params model_path: path to the model 55 | :return: model_state_dict 56 | """ 57 | if os.path.isfile(model_path): 58 | print("|===>Load retrain model from:", model_path) 59 | # model_state_dict = torch.load(model_path, map_location={'cuda:1':'cuda:0'}) 60 | model_state_dict = torch.load(model_path, map_location='cpu') 61 | return model_state_dict 62 | else: 63 | assert False, "file not exits, model path: " + model_path 64 | 65 | def load_checkpoint(self, checkpoint_path): 66 | """ 67 | load checkpoint file 68 | :params checkpoint_path: path to the checkpoint file 69 | :return: model_state_dict, optimizer_state_dict, epoch 70 | """ 71 | if os.path.isfile(checkpoint_path): 72 | print("|===>Load resume check-point from:", checkpoint_path) 73 | self.check_point_params = torch.load(checkpoint_path) 74 | 
model_state_dict = self.check_point_params['model'] 75 | optimizer_state_dict = self.check_point_params['optimizer'] 76 | epoch = self.check_point_params['epoch'] 77 | return model_state_dict, optimizer_state_dict, epoch 78 | else: 79 | assert False, "file not exits" + checkpoint_path 80 | 81 | def save_checkpoint(self, model, optimizer, epoch, index=0): 82 | """ 83 | :params model: model 84 | :params optimizer: optimizer 85 | :params epoch: training epoch 86 | :params index: index of saved file, default: 0 87 | Note: if we add hook to the grad by using register_hook(hook), then the hook function 88 | can not be saved so we need to save state_dict() only. Although save state dictionary 89 | is recommended, some times we still need to save the whole model as it can save all 90 | the information of the trained model, and we do not need to create a new network in 91 | next time. However, the GPU information will be saved too, which leads to some issues 92 | when we use the model on different machine 93 | """ 94 | 95 | # get state_dict from model and optimizer 96 | model = self.list2sequential(model) 97 | if isinstance(model, nn.DataParallel): 98 | model = model.module 99 | model = model.state_dict() 100 | optimizer = optimizer.state_dict() 101 | 102 | # save information to a dict 103 | self.check_point_params['model'] = model 104 | self.check_point_params['optimizer'] = optimizer 105 | self.check_point_params['epoch'] = epoch 106 | 107 | # save to file 108 | torch.save(self.check_point_params, os.path.join( 109 | self.save_path, "checkpoint_%03d.pth" % index)) 110 | 111 | def list2sequential(self, model): 112 | if isinstance(model, list): 113 | model = nn.Sequential(*model) 114 | return model 115 | 116 | def save_model(self, model, best_flag=False, index=0, tag=""): 117 | """ 118 | :params model: model to save 119 | :params best_flag: if True, the saved model is the one that gets best performance 120 | """ 121 | # get state dict 122 | model = self.list2sequential(model) 123 | if isinstance(model, nn.DataParallel): 124 | model = model.module 125 | model = model.state_dict() 126 | if best_flag: 127 | if tag != "": 128 | torch.save(model, os.path.join(self.save_path, "%s_best_model.pth"%tag)) 129 | else: 130 | torch.save(model, os.path.join(self.save_path, "best_model.pth")) 131 | else: 132 | if tag != "": 133 | torch.save(model, os.path.join(self.save_path, "%s_model_%03d.pth" % (tag, index))) 134 | else: 135 | torch.save(model, os.path.join(self.save_path, "model_%03d.pth" % index)) 136 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from easydict import EasyDict as edict 4 | ''' 5 | MODEL_STORE_DIR = "./models" 6 | ANNO_STORE_DIR = "./annotations" 7 | TRAIN_DATA_DIR = "/dev/data/MTCNN" 8 | ''' 9 | 10 | ''' 11 | # -------- generate training dataset ---------- # 12 | PROB_THRESH = 0.15 # for ./annotations/wider_prob.txt' threshold 13 | 14 | USE_CELEBA = False 15 | USE_USER = False 16 | 17 | SAVE_PREFIX = 'wider' 18 | 19 | WIDER_DATA_PATH = '/home/idealabs/data/opensource_dataset/WIDER/WIDER_train/images' 20 | WIDER_DATA_ANNO_FILE = 'wider_anno_tmp.txt' 21 | WIDER_DATA_PROB_FILE = 'wider_prob.txt' 22 | WIDER_VAL_DATA_ANNO_FILE = 'wider_anno_val.txt' 23 | 24 | USER_DATA_PATH = '/home/idealabs/data/opensource_dataset/user/imgs' 25 | USER_DATA_ANNO_FILE = 'user_anno.txt' 26 | 27 | CELEBA_DATA_PATH = 
'/home/idealabs/data/opensource_dataset/celeba/img_celeba' 28 | CELEBA_DATA_ANNO_FILE = 'small_celeba_anno.txt' 29 | ''' 30 | 31 | # --------------- training MTCNN config ------------- # 32 | ANNO_PATH = './annotations' 33 | 34 | STRIDE = 2 35 | PNET_SIZE = 12 36 | PNET_POSTIVE_ANNO_FILENAME = "pos_{0}.txt".format(PNET_SIZE) 37 | PNET_NEGATIVE_ANNO_FILENAME = "neg_{0}.txt".format(PNET_SIZE) 38 | PNET_PART_ANNO_FILENAME = "part_{0}.txt".format(PNET_SIZE) 39 | PNET_LANDMARK_ANNO_FILENAME = "landmark_{0}.txt".format(PNET_SIZE) 40 | 41 | RNET_SIZE = 24 42 | RNET_POSTIVE_ANNO_FILENAME = "pos_{0}.txt".format(RNET_SIZE) 43 | RNET_NEGATIVE_ANNO_FILENAME = "neg_{0}.txt".format(RNET_SIZE) 44 | RNET_PART_ANNO_FILENAME = "part_{0}.txt".format(RNET_SIZE) 45 | RNET_LANDMARK_ANNO_FILENAME = "landmark_{0}.txt".format(RNET_SIZE) 46 | 47 | ONET_SIZE = 48 48 | ONET_POSTIVE_ANNO_FILENAME = "pos_{0}.txt".format(ONET_SIZE) 49 | ONET_NEGATIVE_ANNO_FILENAME = "neg_{0}.txt".format(ONET_SIZE) 50 | ONET_PART_ANNO_FILENAME = "part_{0}.txt".format(ONET_SIZE) 51 | ONET_LANDMARK_ANNO_FILENAME = "landmark_{0}.txt".format(ONET_SIZE) 52 | 53 | PNET_TRAIN_IMGLIST_FILENAME = "imglist_anno_{0}_train.txt".format(PNET_SIZE) 54 | RNET_TRAIN_IMGLIST_FILENAME = "imglist_anno_{0}_train.txt".format(RNET_SIZE) 55 | ONET_TRAIN_IMGLIST_FILENAME = "imglist_anno_{0}_train.txt".format(ONET_SIZE) 56 | PNET_VAL_IMGLIST_FILENAME = 'imglist_anno_{0}_val.txt'.format(PNET_SIZE) 57 | RNET_VAL_IMGLIST_FILENAME = 'imglist_anno_{0}_val.txt'.format(RNET_SIZE) 58 | ONET_VAL_IMGLIST_FILENAME = 'imglist_anno_{0}_val.txt'.format(ONET_SIZE) 59 | 60 | 61 | USE_CUDA = True 62 | BATCH_SIZE = 1024 63 | LR = 0.01 64 | EPOCHS = 100 65 | STEPS = [10, 40, 80] 66 | 67 | # --------------------- tracking -----------------------# 68 | ''' 69 | TRACE = edict() 70 | TRACE.ema_or_one_euro='euro' ### post process 71 | TRACE.pixel_thres=1 72 | TRACE.smooth_box=0.3 ## if use euro, this will be disable 73 | TRACE.smooth_landmark=0.95 ## if use euro, this will be disable 74 | TRACE.iou_thres=0.5 75 | ''' 76 | -------------------------------------------------------------------------------- /gen_dataset/assemble.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy.random as npr 4 | import numpy as np 5 | 6 | def assemble_data(output_file, anno_file_list=[]): 7 | 8 | #assemble the pos, neg, part annotations to one file 9 | 10 | if len(anno_file_list)==0: 11 | return 0 12 | 13 | if os.path.exists(output_file): 14 | os.remove(output_file) 15 | 16 | for anno_file in anno_file_list: 17 | with open(anno_file, 'r') as f: 18 | print(anno_file) 19 | anno_lines = f.readlines() 20 | 21 | base_num = 250000 22 | 23 | if len(anno_lines) > base_num * 3: 24 | idx_keep = npr.choice(len(anno_lines), size=base_num * 3, replace=True) 25 | elif len(anno_lines) > 100000: 26 | idx_keep = npr.choice(len(anno_lines), size=len(anno_lines), replace=True) 27 | else: 28 | idx_keep = np.arange(len(anno_lines)) 29 | np.random.shuffle(idx_keep) 30 | chose_count = 0 31 | with open(output_file, 'a+') as f: 32 | for idx in idx_keep: 33 | # write lables of pos, neg, part images 34 | f.write(anno_lines[idx]) 35 | chose_count+=1 36 | 37 | return chose_count 38 | -------------------------------------------------------------------------------- /gen_dataset/assemble.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/digital-nomad-cheng/MTCNN_PyTorch/752b9f0f32f5c25df5647c7c279c24715dc3aa87/gen_dataset/assemble.pyc -------------------------------------------------------------------------------- /gen_dataset/assemble_Onet_imglist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.getcwd()) 4 | import assemble 5 | 6 | mode = 'val' 7 | onet_postive_file = 'annotations/pos_48_{}.txt'.format(mode) 8 | onet_part_file = 'annotations/part_48_{}.txt'.format(mode) 9 | onet_neg_file = 'annotations/neg_48_{}.txt'.format(mode) 10 | onet_landmark_file = 'annotations/landmark_48_{}.txt'.format(mode) 11 | imglist_filename = 'annotations/imglist_anno_48_{}.txt'.format(mode) 12 | 13 | if __name__ == '__main__': 14 | 15 | anno_list = [] 16 | 17 | anno_list.append(onet_postive_file) 18 | anno_list.append(onet_part_file) 19 | anno_list.append(onet_neg_file) 20 | anno_list.append(onet_landmark_file) 21 | 22 | chose_count = assemble.assemble_data(imglist_filename ,anno_list) 23 | print("ONet train annotation result file path:%s" % imglist_filename) 24 | -------------------------------------------------------------------------------- /gen_dataset/assemble_Pnet_imglist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.getcwd()) 4 | import assemble 5 | import config 6 | 7 | mode = 'val' 8 | net_size = config.PNET_SIZE 9 | 10 | pnet_postive_file = 'annotations/pos_{}_{}.txt'.format(net_size, mode) 11 | pnet_part_file = 'annotations/part_{}_{}.txt'.format(net_size, mode) 12 | pnet_neg_file = 'annotations/neg_{}_{}.txt'.format(net_size, mode) 13 | # pnet_landmark_file = './anno_store/landmark_12.txt' 14 | imglist_filename = 'annotations/imglist_anno_{}_{}.txt'.format(net_size, mode) 15 | 16 | if __name__ == '__main__': 17 | 18 | anno_list = [] 19 | 20 | anno_list.append(pnet_postive_file) 21 | anno_list.append(pnet_part_file) 22 | anno_list.append(pnet_neg_file) 23 | # anno_list.append(pnet_landmark_file) 24 | 25 | chose_count = assemble.assemble_data(imglist_filename ,anno_list) 26 | print("PNet train annotation result file path:%s" % imglist_filename) 27 | -------------------------------------------------------------------------------- /gen_dataset/assemble_Rnet_imglist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.getcwd()) 4 | import assemble 5 | 6 | mode = 'val' 7 | 8 | 9 | rnet_postive_file = 'annotations/pos_24_{}.txt'.format(mode) 10 | rnet_part_file = 'annotations/part_24_{}.txt'.format(mode) 11 | rnet_neg_file = 'annotations/neg_24_{}.txt'.format(mode) 12 | # pnet_landmark_file = './annotations/landmark_12.txt' 13 | imglist_filename = 'annotations/imglist_anno_24_{}.txt'.format(mode) 14 | 15 | if __name__ == '__main__': 16 | 17 | anno_list = [] 18 | 19 | anno_list.append(rnet_postive_file) 20 | anno_list.append(rnet_part_file) 21 | anno_list.append(rnet_neg_file) 22 | # anno_list.append(pnet_landmark_file) 23 | 24 | chose_count = assemble.assemble_data(imglist_filename ,anno_list) 25 | print("RNet train annotation result file path:%s" % imglist_filename) 26 | -------------------------------------------------------------------------------- /gen_dataset/gen_Onet_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate positive, negative, positive images whose size are 48*48 
from PNet & RNet 3 | """ 4 | import os, sys 5 | sys.path.append('.') 6 | 7 | import cv2 8 | import numpy as np 9 | import torch 10 | 11 | from tools.utils import* 12 | from MTCNN import MTCNNDetector 13 | 14 | mode = 'val' 15 | prefix = '' 16 | anno_file = "annotations/wider_anno_{}.txt".format(mode) 17 | im_dir = "./data/WIDER_{}/images".format(mode) 18 | pos_save_dir = "./data/{}/48/positive".format(mode) 19 | part_save_dir = "./data/{}/48/part".format(mode) 20 | neg_save_dir = "./data/{}/48/negative".format(mode) 21 | 22 | if not os.path.exists(pos_save_dir): 23 | os.makedirs(pos_save_dir) 24 | if not os.path.exists(part_save_dir): 25 | os.makedirs(part_save_dir) 26 | if not os.path.exists(neg_save_dir): 27 | os.makedirs(neg_save_dir) 28 | 29 | # store labels of positive, negative, part images 30 | f1 = open(os.path.join('annotations', 'pos_48_{}.txt'.format(mode)), 'w') 31 | f2 = open(os.path.join('annotations', 'neg_48_{}.txt'.format(mode)), 'w') 32 | f3 = open(os.path.join('annotations', 'part_48_{}.txt'.format(mode)), 'w') 33 | 34 | # anno_file: store labels of the wider face training data 35 | with open(anno_file, 'r') as f: 36 | annotations = f.readlines() 37 | num = len(annotations) 38 | print("%d pics in total" % num) 39 | 40 | image_size = 48 41 | device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") 42 | 43 | mtcnn_detector = MTCNNDetector(p_model_path='./pretrained_weights/mtcnn/best_pnet.pth', 44 | r_model_path='./pretrained_weights/mtcnn/best_rnet.pth') 45 | p_idx = 0 # positive 46 | n_idx = 0 # negative 47 | d_idx = 0 # dont care 48 | idx = 0 49 | for annotation in annotations: 50 | annotation = annotation.strip().split(' ') 51 | im_path = os.path.join(prefix, annotation[0]) 52 | print(im_path) 53 | bbox = list(map(float, annotation[1:])) 54 | boxes = np.array(bbox, dtype=np.int32).reshape(-1, 4) 55 | # anno form is x1, y1, w, h, convert to x1, y1, x2, y2 56 | boxes[:,2] += boxes[:,0] - 1 57 | boxes[:,3] += boxes[:,1] - 1 58 | 59 | image = cv2.imread(im_path) 60 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 61 | #bboxes, landmarks = create_mtcnn_net(image, 12, device, p_model_path='../train/pnet_Weights', r_model_path='../train/rnet_Weights') 62 | bboxes = mtcnn_detector.detect_face(image) 63 | 64 | if bboxes.shape[0] == 0: 65 | continue 66 | 67 | dets = np.round(bboxes[:, 0:4]) 68 | 69 | 70 | img = cv2.imread(im_path) 71 | idx += 1 72 | 73 | height, width, channel = img.shape 74 | 75 | for box in dets: 76 | x_left, y_top, x_right, y_bottom = box[0:4].astype(int) 77 | width = x_right - x_left + 1 78 | height = y_bottom - y_top + 1 79 | 80 | # ignore box that is too small or beyond image border 81 | if width < 20 or x_left < 0 or y_top < 0 or x_right > img.shape[1] - 1 or y_bottom > img.shape[0] - 1: 82 | continue 83 | 84 | # compute intersection over union(IoU) between current box and all gt boxes 85 | Iou = IoU(box, boxes) 86 | cropped_im = img[y_top:y_bottom + 1, x_left:x_right + 1, :] 87 | resized_im = cv2.resize(cropped_im, (image_size, image_size), 88 | interpolation=cv2.INTER_LINEAR) 89 | 90 | # save negative images and write label 91 | if np.max(Iou) < 0.2 and n_idx < 1.0*p_idx+1: 92 | # Iou with all gts must below 0.3 93 | save_file = os.path.join(neg_save_dir, "%s.jpg" % n_idx) 94 | f2.write(save_file + ' 0\n') 95 | cv2.imwrite(save_file, resized_im) 96 | n_idx += 1 97 | else: 98 | # find gt_box with the highest iou 99 | idx_Iou = np.argmax(Iou) 100 | assigned_gt = boxes[idx_Iou] 101 | x1, y1, x2, y2 = assigned_gt 102 | 103 | # compute bbox 
reg label 104 | offset_x1 = (x1 - x_left) / float(width) 105 | offset_y1 = (y1 - y_top) / float(height) 106 | offset_x2 = (x2 - x_right) / float(width) 107 | offset_y2 = (y2 - y_bottom) / float(height) 108 | 109 | # save positive and part-face images and write labels 110 | if np.max(Iou) >= 0.65: 111 | save_file = os.path.join(pos_save_dir, "%s.jpg" % p_idx) 112 | f1.write(save_file + ' 1 %.2f %.2f %.2f %.2f\n' % ( 113 | offset_x1, offset_y1, offset_x2, offset_y2)) 114 | cv2.imwrite(save_file, resized_im) 115 | p_idx += 1 116 | 117 | elif np.max(Iou) >= 0.4 and d_idx < 1.0*p_idx + 1: 118 | save_file = os.path.join(part_save_dir, "%s.jpg" % d_idx) 119 | f3.write(save_file + ' -1 %.2f %.2f %.2f %.2f\n' % ( 120 | offset_x1, offset_y1, offset_x2, offset_y2)) 121 | cv2.imwrite(save_file, resized_im) 122 | d_idx += 1 123 | 124 | print("%s images done, pos: %s part: %s neg: %s" % (idx, p_idx, d_idx, n_idx)) 125 | 126 | # if idx == 20: 127 | # break 128 | 129 | f1.close() 130 | f2.close() 131 | f3.close() 132 | -------------------------------------------------------------------------------- /gen_dataset/gen_Onet_landmark.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append('.') 3 | import time 4 | import random 5 | 6 | import cv2 7 | import numpy as np 8 | 9 | from tools.utils import * 10 | 11 | '''bbox annotations in FacePoint dataset is x0, x1, y0, y1''' 12 | 13 | mode = 'train' 14 | prefix = '' 15 | data_dir = './data/FacePoint' 16 | anno_file = './data/FacePoint/{}ImageList.txt'.format(mode) 17 | save_dir = './data/{}/'.format(mode) 18 | size = 48 19 | image_id = 0 20 | 21 | landmark_imgs_save_dir = os.path.join(save_dir,"48/landmark") 22 | if not os.path.exists(landmark_imgs_save_dir): 23 | os.makedirs(landmark_imgs_save_dir) 24 | 25 | anno_dir = './annotations' 26 | if not os.path.exists(anno_dir): 27 | os.makedirs(anno_dir) 28 | 29 | landmark_anno_filename = "landmark_48_{}.txt".format(mode) 30 | save_landmark_anno = os.path.join(anno_dir,landmark_anno_filename) 31 | 32 | f = open(save_landmark_anno, 'w') 33 | 34 | with open(anno_file, 'r') as f2: 35 | annotations = f2.readlines() 36 | 37 | num = len(annotations) 38 | print("%d total images" % num) 39 | 40 | l_idx =0 41 | idx = 0 42 | 43 | for annotation in annotations: 44 | # print imgPath 45 | 46 | annotation = annotation.strip().split(' ') 47 | 48 | assert len(annotation)==15,"each line should have 15 element" 49 | 50 | im_path = os.path.join(data_dir, annotation[0]) 51 | 52 | img = cv2.imread(im_path) 53 | 54 | # print(im_path) 55 | assert (img is not None) 56 | 57 | height, width, channel = img.shape 58 | 59 | gt_box = list(map(float, annotation[1:5])) 60 | gt_box = np.array(gt_box, dtype=np.int32) 61 | 62 | landmark = list(map(float, annotation[5:])) 63 | landmark = np.array(landmark, dtype=np.float) 64 | 65 | idx = idx + 1 66 | if idx % 100 == 0: 67 | print("%d images done, landmark images: %d"%(idx,l_idx)) 68 | 69 | x1, x2, y1, y2 = gt_box 70 | gt_box[1] = y1 71 | gt_box[2] = x2 72 | # time.sleep(5) 73 | 74 | # gt's width 75 | w = x2 - x1 + 1 76 | # gt's height 77 | h = y2 - y1 + 1 78 | if max(w, h) < 40 or x1 < 0 or y1 < 0: 79 | continue 80 | # random shift 81 | for i in range(15): 82 | bbox_size = np.random.randint(int(min(w, h) * 0.8), np.ceil(1.25 * max(w, h))) 83 | delta_x = np.random.randint(-w * 0.2, w * 0.2) 84 | delta_y = np.random.randint(-h * 0.2, h * 0.2) 85 | nx1 = max(x1 + w / 2 - bbox_size / 2 + delta_x, 0) 86 | ny1 = max(y1 + h / 2 - bbox_size / 2 
+ delta_y, 0) 87 | 88 | nx2 = nx1 + bbox_size 89 | ny2 = ny1 + bbox_size 90 | if nx2 > width or ny2 > height: 91 | continue 92 | crop_box = np.array([nx1, ny1, nx2, ny2]) 93 | cropped_im = img[int(ny1):int(ny2) + 1, int(nx1):int(nx2) + 1, :] 94 | resized_im = cv2.resize(cropped_im, (size, size),interpolation=cv2.INTER_LINEAR) 95 | 96 | offset_x1 = (x1 - nx1) / float(bbox_size) 97 | offset_y1 = (y1 - ny1) / float(bbox_size) 98 | offset_x2 = (x2 - nx2) / float(bbox_size) 99 | offset_y2 = (y2 - ny2) / float(bbox_size) 100 | 101 | offset_left_eye_x = (landmark[0] - nx1) / float(bbox_size) 102 | offset_left_eye_y = (landmark[1] - ny1) / float(bbox_size) 103 | 104 | offset_right_eye_x = (landmark[2] - nx1) / float(bbox_size) 105 | offset_right_eye_y = (landmark[3] - ny1) / float(bbox_size) 106 | 107 | offset_nose_x = (landmark[4] - nx1) / float(bbox_size) 108 | offset_nose_y = (landmark[5] - ny1) / float(bbox_size) 109 | 110 | offset_left_mouth_x = (landmark[6] - nx1) / float(bbox_size) 111 | offset_left_mouth_y = (landmark[7] - ny1) / float(bbox_size) 112 | 113 | offset_right_mouth_x = (landmark[8] - nx1) / float(bbox_size) 114 | offset_right_mouth_y = (landmark[9] - ny1) / float(bbox_size) 115 | 116 | 117 | # cal iou 118 | iou = IoU(crop_box.astype(np.float), np.expand_dims(gt_box.astype(np.float), 0)) 119 | # print(iou) 120 | if iou > 0.65: 121 | save_file = os.path.join(landmark_imgs_save_dir, "%s.jpg" % l_idx) 122 | cv2.imwrite(save_file, resized_im) 123 | 124 | f.write(save_file + ' -2 %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f \n' % \ 125 | (offset_x1, offset_y1, offset_x2, offset_y2, \ 126 | # offset_left_eye_x,offset_left_eye_y,offset_right_eye_x,offset_right_eye_y,offset_nose_x,offset_nose_y,offset_left_mouth_x,offset_left_mouth_y,offset_right_mouth_x,offset_right_mouth_y)) 127 | offset_left_eye_x, offset_right_eye_x, offset_nose_x, offset_left_mouth_x, offset_right_mouth_x, offset_left_eye_y, offset_right_eye_y, offset_nose_y, offset_left_mouth_y, offset_right_mouth_y)) 128 | l_idx += 1 129 | 130 | # if idx == 20: 131 | # break 132 | 133 | f.close() 134 | 135 | 136 | -------------------------------------------------------------------------------- /gen_dataset/gen_Pnet_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate positive, negative, positive size 12*12 images for PNet 3 | """ 4 | import os, sys 5 | sys.path.append('.') 6 | import cv2 7 | import os 8 | import numpy as np 9 | from tools.utils import* 10 | import config 11 | 12 | mode = 'val' 13 | net_size = config.PNET_SIZE 14 | prefix = '' 15 | anno_file = "./annotations/wider_anno_{}.txt".format(mode) 16 | im_dir = "./data/WIDER_train/images" 17 | 18 | pos_save_dir = "./data/{}/{}/positive".format(mode, net_size) 19 | part_save_dir = "./data/{}/{}/part".format(mode, net_size) 20 | neg_save_dir = "./data/{}/{}/negative".format(mode, net_size) 21 | 22 | if not os.path.exists(pos_save_dir): 23 | os.makedirs(pos_save_dir) 24 | if not os.path.exists(part_save_dir): 25 | os.makedirs(part_save_dir) 26 | if not os.path.exists(neg_save_dir): 27 | os.makedirs(neg_save_dir) 28 | 29 | # store labels of positive, negative, part images 30 | f1 = open(os.path.join('annotations', 'pos_{}_{}.txt'.format(net_size, mode)), 'w') 31 | f2 = open(os.path.join('annotations', 'neg_{}_{}.txt'.format(net_size, mode)), 'w') 32 | f3 = open(os.path.join('annotations', 'part_{}_{}.txt'.format(net_size, mode)), 'w') 33 | 34 | # anno_file: store labels of the wider face 
training data 35 | with open(anno_file, 'r') as f: 36 | annotations = f.readlines() 37 | num = len(annotations) 38 | print("%d pics in total" % num) 39 | 40 | p_idx = 0 # positive 41 | n_idx = 0 # negative 42 | d_idx = 0 # dont care 43 | idx = 0 44 | for annotation in annotations: 45 | annotation = annotation.strip().split(' ') 46 | im_path = os.path.join(prefix, annotation[0]) 47 | print(im_path) 48 | bbox = list(map(float, annotation[1:])) 49 | boxes = np.array(bbox, dtype=np.int32).reshape(-1, 4) 50 | # anno form is x1, y1, w, h, convert to x1, y1, x2, y2 51 | boxes[:,2] += boxes[:,0] - 1 52 | boxes[:,3] += boxes[:,1] - 1 53 | 54 | img = cv2.imread(im_path) 55 | idx += 1 56 | 57 | height, width, channel = img.shape 58 | 59 | neg_num = 0 60 | while neg_num < 35: 61 | size = np.random.randint(12, min(width, height) / 2) 62 | nx = np.random.randint(0, width - size) 63 | ny = np.random.randint(0, height - size) 64 | crop_box = np.array([nx, ny, nx + size, ny + size]) 65 | 66 | Iou = IoU(crop_box, boxes) 67 | 68 | cropped_im = img[ny: ny + size, nx: nx + size, :] 69 | resized_im = cv2.resize(cropped_im, (net_size, net_size), interpolation=cv2.INTER_LINEAR) 70 | 71 | if np.max(Iou) < 0.3: 72 | # Iou with all gts must below 0.3 73 | save_file = os.path.join(neg_save_dir, "%s.jpg" % n_idx) 74 | f2.write(save_file + ' 0\n') 75 | cv2.imwrite(save_file, resized_im) 76 | n_idx += 1 77 | neg_num += 1 78 | 79 | for box in boxes: 80 | # box (x_left, y_top, w, h) 81 | x1, y1, x2, y2 = box 82 | w = x2 - x1 + 1 83 | h = y2 - y1 + 1 84 | 85 | # ignore small faces 86 | # in case the ground truth boxes of small faces are not accurate 87 | if max(w, h) < 40 or x1 < 0 or y1 < 0 or w < 0 or h < 0: 88 | continue 89 | 90 | # generate negative examples that have overlap with gt 91 | for i in range(5): 92 | size = np.random.randint(12, min(width, height) / 2) 93 | # delta_x and delta_y are offsets of (x1, y1) 94 | 95 | delta_x = np.random.randint(max(-size, -x1), w) 96 | delta_y = np.random.randint(max(-size, -y1), h) 97 | nx1 = max(0, x1 + delta_x) 98 | ny1 = max(0, y1 + delta_y) 99 | 100 | if nx1 + size > width or ny1 + size > height: 101 | continue 102 | crop_box = np.array([nx1, ny1, nx1 + size, ny1 + size]) 103 | Iou = IoU(crop_box, boxes) 104 | 105 | cropped_im = img[ny1: ny1 + size, nx1: nx1 + size, :] 106 | resized_im = cv2.resize(cropped_im, (net_size, net_size), interpolation=cv2.INTER_LINEAR) 107 | 108 | if np.max(Iou) < 0.3: 109 | # Iou with all gts must below 0.3 110 | save_file = os.path.join(neg_save_dir, "%s.jpg" % n_idx) 111 | f2.write(save_file + ' 0\n') 112 | cv2.imwrite(save_file, resized_im) 113 | n_idx += 1 114 | 115 | # generate positive examples and part faces 116 | for i in range(20): 117 | size = np.random.randint(int(min(w, h) * 0.8), np.ceil(1.25 * max(w, h))) 118 | 119 | # delta here is the offset of box center 120 | delta_x = np.random.randint(-w * 0.2, w * 0.2) 121 | delta_y = np.random.randint(-h * 0.2, h * 0.2) 122 | 123 | nx1 = max(x1 + w / 2 + delta_x - size / 2, 0) 124 | ny1 = max(y1 + h / 2 + delta_y - size / 2, 0) 125 | nx2 = nx1 + size 126 | ny2 = ny1 + size 127 | 128 | if nx2 > width or ny2 > height: 129 | continue 130 | crop_box = np.array([nx1, ny1, nx2, ny2]) 131 | 132 | offset_x1 = (x1 - nx1) / float(size) 133 | offset_y1 = (y1 - ny1) / float(size) 134 | offset_x2 = (x2 - nx2) / float(size) 135 | offset_y2 = (y2 - ny2) / float(size) 136 | 137 | cropped_im = img[int(ny1): int(ny2), int(nx1): int(nx2), :] 138 | resized_im = cv2.resize(cropped_im, (net_size, net_size), 
interpolation=cv2.INTER_LINEAR) 139 | 140 | box_ = box.reshape(1, -1) 141 | if IoU(crop_box, box_) >= 0.65: 142 | save_file = os.path.join(pos_save_dir, "%s.jpg" % p_idx) 143 | f1.write(save_file + ' 1 %.2f %.2f %.2f %.2f\n' % (offset_x1, offset_y1, offset_x2, offset_y2)) 144 | cv2.imwrite(save_file, resized_im) 145 | p_idx += 1 146 | elif IoU(crop_box, box_) >= 0.4 and d_idx < 1.2*p_idx + 1: 147 | save_file = os.path.join(part_save_dir, "%s.jpg" % d_idx) 148 | f3.write(save_file + ' -1 %.2f %.2f %.2f %.2f\n' % (offset_x1, offset_y1, offset_x2, offset_y2)) 149 | cv2.imwrite(save_file, resized_im) 150 | d_idx += 1 151 | 152 | 153 | 154 | print("%s images done, pos: %s part: %s neg: %s" % (idx, p_idx, d_idx, n_idx)) 155 | 156 | # if idx == 20: 157 | # break 158 | 159 | f1.close() 160 | f2.close() 161 | f3.close() 162 | -------------------------------------------------------------------------------- /gen_dataset/gen_Pnet_data_using_prob.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate positive, negative, positive size 12*12 images for PNet 3 | """ 4 | import os, sys 5 | sys.path.append('.') 6 | import cv2 7 | import os 8 | import numpy as np 9 | from tools.utils import* 10 | 11 | mode = 'train' 12 | 13 | prefix = '' 14 | anno_file = "./annotations/wider_anno_{}.txt".format(mode) 15 | prob_file = "./annotations/wider_anno_train_prob.txt" 16 | im_dir = "./data/WIDER_train/images" 17 | 18 | pos_save_dir = "./data/{}/12/positive".format(mode) 19 | part_save_dir = "./data/{}/12/part".format(mode) 20 | neg_save_dir = "./data/{}/12/negative".format(mode) 21 | 22 | if not os.path.exists(pos_save_dir): 23 | os.makedirs(pos_save_dir) 24 | if not os.path.exists(part_save_dir): 25 | os.makedirs(part_save_dir) 26 | if not os.path.exists(neg_save_dir): 27 | os.makedirs(neg_save_dir) 28 | 29 | # store labels of positive, negative, part images 30 | f1 = open(os.path.join('annotations', 'pos_12_{}.txt'.format(mode)), 'w') 31 | f2 = open(os.path.join('annotations', 'neg_12_{}.txt'.format(mode)), 'w') 32 | f3 = open(os.path.join('annotations', 'part_12_{}.txt'.format(mode)), 'w') 33 | 34 | # anno_file: store labels of the wider face training data 35 | with open(anno_file, 'r') as f: 36 | annotations = f.readlines() 37 | if mode == 'train': 38 | with open(prob_file, 'r') as f: 39 | line_probs = f.readlines() 40 | 41 | num = len(annotations) 42 | print("%d pics in total" % num) 43 | 44 | p_idx = 0 # positive 45 | n_idx = 0 # negative 46 | d_idx = 0 # dont care 47 | idx = 0 48 | for line, annotation in enumerate(annotations): 49 | annotation = annotation.strip().split(' ') 50 | im_path = os.path.join(prefix, annotation[0]) 51 | 52 | if mode == 'train': 53 | probs = line_probs[line] 54 | probs = probs.strip().split(' ') 55 | probs = list(map(float, probs[1:])) 56 | probs = np.array(probs, dtype=np.float32).reshape(-1, 1) 57 | print(im_path) 58 | bbox = list(map(float, annotation[1:])) 59 | boxes = np.array(bbox, dtype=np.int32).reshape(-1, 4) 60 | # anno form is x1, y1, w, h, convert to x1, y1, x2, y2 61 | boxes[:,2] += boxes[:,0] - 1 62 | boxes[:,3] += boxes[:,1] - 1 63 | 64 | img = cv2.imread(im_path) 65 | idx += 1 66 | 67 | height, width, channel = img.shape 68 | 69 | neg_num = 0 70 | while neg_num < 35: 71 | size = np.random.randint(12, min(width, height) / 2) 72 | nx = np.random.randint(0, width - size) 73 | ny = np.random.randint(0, height - size) 74 | crop_box = np.array([nx, ny, nx + size, ny + size]) 75 | 76 | Iou = IoU(crop_box, boxes) 77 | 78 | 
cropped_im = img[ny: ny + size, nx: nx + size, :] 79 | resized_im = cv2.resize(cropped_im, (12, 12), interpolation=cv2.INTER_LINEAR) 80 | 81 | if np.max(Iou) < 0.3: 82 | # Iou with all gts must below 0.3 83 | save_file = os.path.join(neg_save_dir, "%s.jpg" % n_idx) 84 | f2.write(save_file + ' 0\n') 85 | cv2.imwrite(save_file, resized_im) 86 | n_idx += 1 87 | neg_num += 1 88 | 89 | for face_index, box in enumerate(boxes): 90 | # box (x_left, y_top, w, h) 91 | x1, y1, x2, y2 = box 92 | w = x2 - x1 + 1 93 | h = y2 - y1 + 1 94 | 95 | prob = probs[face_index] 96 | # ignore small faces 97 | # in case the ground truth boxes of small faces are not accurate 98 | if mode == 'train': 99 | if max(w, h) < 40 or x1 < 0 or y1 < 0 or min(w, h) < 10 or prob < 0.5: 100 | continue 101 | else: 102 | if max(w, h) < 40 or x1 < 0 or y1 < 0 or min(w, h) < 10 or prob: 103 | continue 104 | # generate negative examples that have overlap with gt 105 | for i in range(5): 106 | size = np.random.randint(12, min(width, height) / 2) 107 | # delta_x and delta_y are offsets of (x1, y1) 108 | 109 | delta_x = np.random.randint(max(-size, -x1), w) 110 | delta_y = np.random.randint(max(-size, -y1), h) 111 | nx1 = max(0, x1 + delta_x) 112 | ny1 = max(0, y1 + delta_y) 113 | 114 | if nx1 + size > width or ny1 + size > height: 115 | continue 116 | crop_box = np.array([nx1, ny1, nx1 + size, ny1 + size]) 117 | Iou = IoU(crop_box, boxes) 118 | 119 | cropped_im = img[ny1: ny1 + size, nx1: nx1 + size, :] 120 | resized_im = cv2.resize(cropped_im, (12, 12), interpolation=cv2.INTER_LINEAR) 121 | 122 | if np.max(Iou) < 0.3: 123 | # Iou with all gts must below 0.3 124 | save_file = os.path.join(neg_save_dir, "%s.jpg" % n_idx) 125 | f2.write(save_file + ' 0\n') 126 | cv2.imwrite(save_file, resized_im) 127 | n_idx += 1 128 | 129 | # generate positive examples and part faces 130 | for i in range(20): 131 | size = np.random.randint(int(min(w, h) * 0.8), np.ceil(1.25 * max(w, h))) 132 | 133 | # delta here is the offset of box center 134 | delta_x = np.random.randint(-w * 0.2, w * 0.2) 135 | delta_y = np.random.randint(-h * 0.2, h * 0.2) 136 | 137 | nx1 = max(x1 + w / 2 + delta_x - size / 2, 0) 138 | ny1 = max(y1 + h / 2 + delta_y - size / 2, 0) 139 | nx2 = nx1 + size 140 | ny2 = ny1 + size 141 | 142 | if nx2 > width or ny2 > height: 143 | continue 144 | crop_box = np.array([nx1, ny1, nx2, ny2]) 145 | 146 | offset_x1 = (x1 - nx1) / float(size) 147 | offset_y1 = (y1 - ny1) / float(size) 148 | offset_x2 = (x2 - nx2) / float(size) 149 | offset_y2 = (y2 - ny2) / float(size) 150 | 151 | cropped_im = img[int(ny1): int(ny2), int(nx1): int(nx2), :] 152 | resized_im = cv2.resize(cropped_im, (12, 12), interpolation=cv2.INTER_LINEAR) 153 | 154 | box_ = box.reshape(1, -1) 155 | if IoU(crop_box, box_) >= 0.65: 156 | save_file = os.path.join(pos_save_dir, "%s.jpg" % p_idx) 157 | f1.write(save_file + ' 1 %.2f %.2f %.2f %.2f\n' % (offset_x1, offset_y1, offset_x2, offset_y2)) 158 | cv2.imwrite(save_file, resized_im) 159 | p_idx += 1 160 | elif IoU(crop_box, box_) >= 0.4 and d_idx < 1.2*p_idx + 1: 161 | save_file = os.path.join(part_save_dir, "%s.jpg" % d_idx) 162 | f3.write(save_file + ' -1 %.2f %.2f %.2f %.2f\n' % (offset_x1, offset_y1, offset_x2, offset_y2)) 163 | cv2.imwrite(save_file, resized_im) 164 | d_idx += 1 165 | 166 | 167 | 168 | print("%s images done, pos: %s part: %s neg: %s" % (idx, p_idx, d_idx, n_idx)) 169 | 170 | # if idx == 20: 171 | # break 172 | 173 | f1.close() 174 | f2.close() 175 | f3.close() 176 | 
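All of the data-generation scripts in this directory label a candidate crop by its overlap with the ground-truth boxes: crops with IoU >= 0.65 are written out as positives, IoU >= 0.4 as part faces, and low-IoU crops (< 0.3 here, < 0.2 for the Rnet/Onet negatives) as negatives, using the `IoU` helper imported from `tools/utils.py`. That file is not shown in this listing, so the snippet below is only a rough sketch of what such a helper computes, assuming `[x1, y1, x2, y2]` boxes and the inclusive `+ 1` width/height convention used in these scripts; the repository's actual implementation may differ in detail.

```python
import numpy as np

def iou(box, boxes):
    """IoU between one candidate box and an array of ground-truth boxes.

    Both use [x1, y1, x2, y2] with the inclusive "+ 1" width/height
    convention seen in the gen_*_data.py scripts.
    """
    box_area = (box[2] - box[0] + 1) * (box[3] - box[1] + 1)
    gt_areas = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1)

    # corners of the intersection rectangles
    xx1 = np.maximum(box[0], boxes[:, 0])
    yy1 = np.maximum(box[1], boxes[:, 1])
    xx2 = np.minimum(box[2], boxes[:, 2])
    yy2 = np.minimum(box[3], boxes[:, 3])

    inter = np.maximum(0, xx2 - xx1 + 1) * np.maximum(0, yy2 - yy1 + 1)
    return inter / (box_area + gt_areas - inter)

# example: one 12x12 crop against two ground-truth faces
crop = np.array([10, 10, 21, 21])
gts = np.array([[12, 12, 23, 23], [100, 100, 150, 150]])
print(iou(crop, gts))  # ~[0.53, 0.0] -> the crop would be saved as a part face
```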
-------------------------------------------------------------------------------- /gen_dataset/gen_Rnet_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate positive, negative, positive images whose size are 24*24 from Pnet. 3 | """ 4 | import sys 5 | sys.path.append('.') 6 | import cv2 7 | import os 8 | import numpy as np 9 | from tools.utils import* 10 | import torch 11 | from MTCNN import MTCNNDetector 12 | 13 | mode = 'train' 14 | prefix = '' 15 | anno_file = "./annotations/wider_anno_{}.txt".format(mode) 16 | im_dir = "/home/idealabs/data/opensource_dataset/WIDER/WIDER_{}/images".format(mode) 17 | pos_save_dir = "./data/{}/24/positive".format(mode) 18 | part_save_dir = "./data/{}/24/part".format(mode) 19 | neg_save_dir = "./data/{}/24/negative".format(mode) 20 | 21 | if not os.path.exists(pos_save_dir): 22 | os.makedirs(pos_save_dir) 23 | if not os.path.exists(part_save_dir): 24 | os.makedirs(part_save_dir) 25 | if not os.path.exists(neg_save_dir): 26 | os.makedirs(neg_save_dir) 27 | 28 | # store labels of positive, negative, part images 29 | f1 = open(os.path.join('annotations', 'pos_24_{}.txt'.format(mode)), 'w') 30 | f2 = open(os.path.join('annotations', 'neg_24_{}.txt'.format(mode)), 'w') 31 | f3 = open(os.path.join('annotations', 'part_24_{}.txt'.format(mode)), 'w') 32 | 33 | # anno_file: store labels of the wider face training data 34 | with open(anno_file, 'r') as f: 35 | annotations = f.readlines() 36 | num = len(annotations) 37 | print("%d pics in total" % num) 38 | 39 | image_size = 24 40 | device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") 41 | print(device) 42 | p_idx = 0 # positive 43 | n_idx = 0 # negative 44 | d_idx = 0 # dont care 45 | idx = 0 46 | 47 | # create MTCNN Detector 48 | mtcnn_detector = MTCNNDetector(p_model_path='./pretrained_weights/mtcnn/best_pnet.pth') 49 | 50 | for annotation in annotations: 51 | annotation = annotation.strip().split(' ') 52 | im_path = os.path.join(prefix, annotation[0]) 53 | print(im_path) 54 | bbox = list(map(float, annotation[1:])) 55 | boxes = np.array(bbox, dtype=np.int32).reshape(-1, 4) 56 | # anno form is x1, y1, w, h, convert to x1, y1, x2, y2 57 | boxes[:,2] += boxes[:,0] - 1 58 | boxes[:,3] += boxes[:,1] - 1 59 | 60 | image = cv2.imread(im_path) 61 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 62 | bboxes = mtcnn_detector.detect_face(image) 63 | 64 | # bboxes, landmarks = create_mtcnn_net(image, 12, device, p_model_path='../train/pnet_Weights') 65 | if bboxes.shape[0] == 0: 66 | continue 67 | 68 | dets = np.round(bboxes[:, 0:4]) 69 | 70 | 71 | img = cv2.imread(im_path) 72 | idx += 1 73 | 74 | height, width, channel = img.shape 75 | 76 | for box in dets: 77 | x_left, y_top, x_right, y_bottom = box[0:4].astype(int) 78 | width = x_right - x_left + 1 79 | height = y_bottom - y_top + 1 80 | 81 | # ignore box that is too small or beyond image border 82 | if width < 20 or x_left < 0 or y_top < 0 or x_right > img.shape[1] - 1 or y_bottom > img.shape[0] - 1: 83 | continue 84 | 85 | # compute intersection over union(IoU) between current box and all gt boxes 86 | Iou = IoU(box, boxes) 87 | cropped_im = img[y_top:y_bottom + 1, x_left:x_right + 1, :] 88 | resized_im = cv2.resize(cropped_im, (image_size, image_size), 89 | interpolation=cv2.INTER_LINEAR) 90 | 91 | # save negative images and write label 92 | if np.max(Iou) < 0.2 and n_idx < 3.0*p_idx+1: 93 | # Iou with all gts must below 0.3 94 | save_file = os.path.join(neg_save_dir, "%s.jpg" % n_idx) 95 | 
f2.write(save_file + ' 0\n') 96 | cv2.imwrite(save_file, resized_im) 97 | n_idx += 1 98 | else: 99 | # find gt_box with the highest iou 100 | idx_Iou = np.argmax(Iou) 101 | assigned_gt = boxes[idx_Iou] 102 | x1, y1, x2, y2 = assigned_gt 103 | 104 | # compute bbox reg label 105 | offset_x1 = (x1 - x_left) / float(width) 106 | offset_y1 = (y1 - y_top) / float(height) 107 | offset_x2 = (x2 - x_right) / float(width) 108 | offset_y2 = (y2 - y_bottom) / float(height) 109 | 110 | # save positive and part-face images and write labels 111 | if np.max(Iou) >= 0.65: 112 | save_file = os.path.join(pos_save_dir, "%s.jpg" % p_idx) 113 | f1.write(save_file + ' 1 %.2f %.2f %.2f %.2f\n' % ( 114 | offset_x1, offset_y1, offset_x2, offset_y2)) 115 | cv2.imwrite(save_file, resized_im) 116 | p_idx += 1 117 | 118 | elif np.max(Iou) >= 0.4 and d_idx < 1.0*p_idx + 1: 119 | save_file = os.path.join(part_save_dir, "%s.jpg" % d_idx) 120 | f3.write(save_file + ' -1 %.2f %.2f %.2f %.2f\n' % ( 121 | offset_x1, offset_y1, offset_x2, offset_y2)) 122 | cv2.imwrite(save_file, resized_im) 123 | d_idx += 1 124 | 125 | print("%s images done, pos: %s part: %s neg: %s" % (idx, p_idx, d_idx, n_idx)) 126 | 127 | #if idx == 20: 128 | # break 129 | 130 | f1.close() 131 | f2.close() 132 | f3.close() 133 | -------------------------------------------------------------------------------- /gen_dataset/gen_user_Onet_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate positive, negative, positive images whose size are 48*48 from PNet & RNet 3 | """ 4 | import os, sys 5 | sys.path.append('.') 6 | 7 | import cv2 8 | import numpy as np 9 | import torch 10 | 11 | from tools.utils import* 12 | from MTCNN import MTCNNDetector 13 | 14 | mode = 'train' 15 | prefix = '' 16 | anno_file = "annotations/user_anno_{}.txt".format(mode) 17 | im_dir = "./data/user/images".format(mode) 18 | pos_save_dir = "./data/{}/48/positive".format(mode) 19 | part_save_dir = "./data/{}/48/part".format(mode) 20 | neg_save_dir = "./data/{}/48/negative".format(mode) 21 | 22 | if not os.path.exists(pos_save_dir): 23 | os.makedirs(pos_save_dir) 24 | if not os.path.exists(part_save_dir): 25 | os.makedirs(part_save_dir) 26 | if not os.path.exists(neg_save_dir): 27 | os.makedirs(neg_save_dir) 28 | 29 | # store labels of positive, negative, part images 30 | f1 = open(os.path.join('annotations', 'pos_48_{}.txt'.format(mode)), 'a') 31 | f2 = open(os.path.join('annotations', 'neg_48_{}.txt'.format(mode)), 'a') 32 | f3 = open(os.path.join('annotations', 'part_48_{}.txt'.format(mode)), 'a') 33 | 34 | # anno_file: store labels of the wider face training data 35 | with open(anno_file, 'r') as f: 36 | annotations = f.readlines() 37 | num = len(annotations) 38 | print("%d pics in total" % num) 39 | 40 | image_size = 48 41 | device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") 42 | 43 | mtcnn_detector = MTCNNDetector(p_model_path='./pretrained_weights/mtcnn/best_pnet.pth', 44 | r_model_path='./pretrained_weights/mtcnn/best_rnet.pth') 45 | p_idx = 0 # positive 46 | n_idx = 0 # negative 47 | d_idx = 0 # dont care 48 | idx = 0 49 | for annotation in annotations: 50 | annotation = annotation.strip().split(' ') 51 | im_path = os.path.join(im_dir, annotation[0]) 52 | print(im_path) 53 | bbox = list(map(float, annotation[1:])) 54 | boxes = np.array(bbox, dtype=np.int32).reshape(-1, 4) 55 | 56 | # anno form is x1, y1, w, h, convert to x1, y1, x2, y2 57 | # boxes[:,2] += boxes[:,0] - 1 58 | # boxes[:,3] += boxes[:,1] - 1 
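# note: the user annotations are already in x1, y1, x2, y2 form (see gen_user_Pnet_data.py), so the x1, y1, w, h conversion above is intentionally left commented out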
59 | 60 | image = cv2.imread(im_path) 61 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 62 | #bboxes, landmarks = create_mtcnn_net(image, 12, device, p_model_path='../train/pnet_Weights', r_model_path='../train/rnet_Weights') 63 | bboxes = mtcnn_detector.detect_face(image) 64 | 65 | if bboxes.shape[0] == 0: 66 | continue 67 | 68 | dets = np.round(bboxes[:, 0:4]) 69 | 70 | img = cv2.imread(im_path) 71 | idx += 1 72 | 73 | height, width, channel = img.shape 74 | 75 | for box in dets: 76 | x_left, y_top, x_right, y_bottom = box[0:4].astype(int) 77 | width = x_right - x_left + 1 78 | height = y_bottom - y_top + 1 79 | 80 | # ignore box that is too small or beyond image border 81 | if width < 20 or x_left < 0 or y_top < 0 or x_right > img.shape[1] - 1 or y_bottom > img.shape[0] - 1: 82 | continue 83 | 84 | # compute intersection over union(IoU) between current box and all gt boxes 85 | Iou = IoU(box, boxes) 86 | cropped_im = img[y_top:y_bottom + 1, x_left:x_right + 1, :] 87 | resized_im = cv2.resize(cropped_im, (image_size, image_size), 88 | interpolation=cv2.INTER_LINEAR) 89 | 90 | # save negative images and write label 91 | if np.max(Iou) < 0.2 and n_idx < 1.0*p_idx+1: 92 | # Iou with all gts must below 0.3 93 | save_file = os.path.join(neg_save_dir, "user_%s.jpg" % n_idx) 94 | f2.write(save_file + ' 0\n') 95 | cv2.imwrite(save_file, resized_im) 96 | n_idx += 1 97 | else: 98 | # find gt_box with the highest iou 99 | idx_Iou = np.argmax(Iou) 100 | assigned_gt = boxes[idx_Iou] 101 | x1, y1, x2, y2 = assigned_gt 102 | 103 | # compute bbox reg label 104 | offset_x1 = (x1 - x_left) / float(width) 105 | offset_y1 = (y1 - y_top) / float(height) 106 | offset_x2 = (x2 - x_right) / float(width) 107 | offset_y2 = (y2 - y_bottom) / float(height) 108 | 109 | # save positive and part-face images and write labels 110 | if np.max(Iou) >= 0.65: 111 | save_file = os.path.join(pos_save_dir, "user_%s.jpg" % p_idx) 112 | f1.write(save_file + ' 1 %.2f %.2f %.2f %.2f\n' % ( 113 | offset_x1, offset_y1, offset_x2, offset_y2)) 114 | cv2.imwrite(save_file, resized_im) 115 | p_idx += 1 116 | 117 | elif np.max(Iou) >= 0.4 and d_idx < 1.0*p_idx + 1: 118 | save_file = os.path.join(part_save_dir, "user_%s.jpg" % d_idx) 119 | f3.write(save_file + ' -1 %.2f %.2f %.2f %.2f\n' % ( 120 | offset_x1, offset_y1, offset_x2, offset_y2)) 121 | cv2.imwrite(save_file, resized_im) 122 | d_idx += 1 123 | 124 | print("%s images done, pos: %s part: %s neg: %s" % (idx, p_idx, d_idx, n_idx)) 125 | 126 | # if idx == 20: 127 | # break 128 | 129 | f1.close() 130 | f2.close() 131 | f3.close() 132 | -------------------------------------------------------------------------------- /gen_dataset/gen_user_Pnet_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate positive, negative, positive size 12*12 images for PNet 3 | """ 4 | import os, sys 5 | sys.path.append('.') 6 | import cv2 7 | import os 8 | import numpy as np 9 | from tools.utils import* 10 | 11 | mode = 'val' 12 | 13 | prefix = '' 14 | anno_file = "./annotations/user_anno_{}.txt".format(mode) 15 | im_dir = "./data/user/images" 16 | 17 | pos_save_dir = "./data/{}/12/positive".format(mode) 18 | part_save_dir = "./data/{}/12/part".format(mode) 19 | neg_save_dir = "./data/{}/12/negative".format(mode) 20 | 21 | if not os.path.exists(pos_save_dir): 22 | os.makedirs(pos_save_dir) 23 | if not os.path.exists(part_save_dir): 24 | os.makedirs(part_save_dir) 25 | if not os.path.exists(neg_save_dir): 26 | os.makedirs(neg_save_dir) 27 | 28 | # 
store labels of positive, negative, part images 29 | f1 = open(os.path.join('annotations', 'pos_12_{}.txt'.format(mode)), 'a') 30 | f2 = open(os.path.join('annotations', 'neg_12_{}.txt'.format(mode)), 'a') 31 | f3 = open(os.path.join('annotations', 'part_12_{}.txt'.format(mode)), 'a') 32 | 33 | # anno_file: store labels of the wider face training data 34 | with open(anno_file, 'r') as f: 35 | annotations = f.readlines() 36 | num = len(annotations) 37 | print("%d pics in total" % num) 38 | 39 | p_idx = 0 # positive 40 | n_idx = 0 # negative 41 | d_idx = 0 # dont care 42 | idx = 0 43 | for annotation in annotations: 44 | annotation = annotation.strip().split(' ') 45 | im_path = os.path.join(im_dir, annotation[0]) 46 | print(im_path) 47 | bbox = list(map(float, annotation[1:])) 48 | boxes = np.array(bbox, dtype=np.int32).reshape(-1, 4) 49 | 50 | # user annotation is x1, y1, x2, y2 51 | # anno form is x1, y1, w, h, convert to x1, y1, x2, y2 52 | # boxes[:,2] += boxes[:,0] - 1 53 | # boxes[:,3] += boxes[:,1] - 1 54 | 55 | img = cv2.imread(im_path) 56 | idx += 1 57 | 58 | height, width, channel = img.shape 59 | 60 | neg_num = 0 61 | while neg_num < 35: 62 | size = np.random.randint(12, min(width, height) / 2) 63 | nx = np.random.randint(0, width - size) 64 | ny = np.random.randint(0, height - size) 65 | crop_box = np.array([nx, ny, nx + size, ny + size]) 66 | 67 | Iou = IoU(crop_box, boxes) 68 | 69 | cropped_im = img[ny: ny + size, nx: nx + size, :] 70 | resized_im = cv2.resize(cropped_im, (12, 12), interpolation=cv2.INTER_LINEAR) 71 | 72 | if np.max(Iou) < 0.3: 73 | # Iou with all gts must below 0.3 74 | save_file = os.path.join(neg_save_dir, "user_%s.jpg" % n_idx) 75 | f2.write(save_file + ' 0\n') 76 | cv2.imwrite(save_file, resized_im) 77 | n_idx += 1 78 | neg_num += 1 79 | 80 | for box in boxes: 81 | # box (x_left, y_top, w, h) 82 | x1, y1, x2, y2 = box 83 | w = x2 - x1 + 1 84 | h = y2 - y1 + 1 85 | 86 | # ignore small faces 87 | # in case the ground truth boxes of small faces are not accurate 88 | if max(w, h) < 40 or x1 < 0 or y1 < 0 or w < 0 or h < 0: 89 | continue 90 | 91 | # generate negative examples that have overlap with gt 92 | for i in range(5): 93 | size = np.random.randint(12, min(width, height) / 2) 94 | # delta_x and delta_y are offsets of (x1, y1) 95 | 96 | delta_x = np.random.randint(max(-size, -x1), w) 97 | delta_y = np.random.randint(max(-size, -y1), h) 98 | nx1 = max(0, x1 + delta_x) 99 | ny1 = max(0, y1 + delta_y) 100 | 101 | if nx1 + size > width or ny1 + size > height: 102 | continue 103 | crop_box = np.array([nx1, ny1, nx1 + size, ny1 + size]) 104 | Iou = IoU(crop_box, boxes) 105 | 106 | cropped_im = img[ny1: ny1 + size, nx1: nx1 + size, :] 107 | resized_im = cv2.resize(cropped_im, (12, 12), interpolation=cv2.INTER_LINEAR) 108 | 109 | if np.max(Iou) < 0.3: 110 | # Iou with all gts must below 0.3 111 | save_file = os.path.join(neg_save_dir, "user_%s.jpg" % n_idx) 112 | f2.write(save_file + ' 0\n') 113 | cv2.imwrite(save_file, resized_im) 114 | n_idx += 1 115 | 116 | # generate positive examples and part faces 117 | for i in range(20): 118 | size = np.random.randint(int(min(w, h) * 0.8), np.ceil(1.25 * max(w, h))) 119 | 120 | # delta here is the offset of box center 121 | delta_x = np.random.randint(-w * 0.2, w * 0.2) 122 | delta_y = np.random.randint(-h * 0.2, h * 0.2) 123 | 124 | nx1 = max(x1 + w / 2 + delta_x - size / 2, 0) 125 | ny1 = max(y1 + h / 2 + delta_y - size / 2, 0) 126 | nx2 = nx1 + size 127 | ny2 = ny1 + size 128 | 129 | if nx2 > width or ny2 > height: 130 
| continue 131 | crop_box = np.array([nx1, ny1, nx2, ny2]) 132 | 133 | offset_x1 = (x1 - nx1) / float(size) 134 | offset_y1 = (y1 - ny1) / float(size) 135 | offset_x2 = (x2 - nx2) / float(size) 136 | offset_y2 = (y2 - ny2) / float(size) 137 | 138 | cropped_im = img[int(ny1): int(ny2), int(nx1): int(nx2), :] 139 | resized_im = cv2.resize(cropped_im, (12, 12), interpolation=cv2.INTER_LINEAR) 140 | 141 | box_ = box.reshape(1, -1) 142 | if IoU(crop_box, box_) >= 0.65: 143 | save_file = os.path.join(pos_save_dir, "user_%s.jpg" % p_idx) 144 | f1.write(save_file + ' 1 %.2f %.2f %.2f %.2f\n' % (offset_x1, offset_y1, offset_x2, offset_y2)) 145 | cv2.imwrite(save_file, resized_im) 146 | p_idx += 1 147 | elif IoU(crop_box, box_) >= 0.4 and d_idx < 1.2*p_idx + 1: 148 | save_file = os.path.join(part_save_dir, "user_%s.jpg" % d_idx) 149 | f3.write(save_file + ' -1 %.2f %.2f %.2f %.2f\n' % (offset_x1, offset_y1, offset_x2, offset_y2)) 150 | cv2.imwrite(save_file, resized_im) 151 | d_idx += 1 152 | 153 | 154 | 155 | print("%s images done, pos: %s part: %s neg: %s" % (idx, p_idx, d_idx, n_idx)) 156 | 157 | # if idx == 20: 158 | # break 159 | 160 | f1.close() 161 | f2.close() 162 | f3.close() 163 | -------------------------------------------------------------------------------- /gen_dataset/gen_user_Rnet_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate positive, negative, positive images whose size are 24*24 from Pnet. 3 | """ 4 | import sys 5 | sys.path.append('.') 6 | import cv2 7 | import os 8 | import numpy as np 9 | from tools.utils import* 10 | import torch 11 | from MTCNN import MTCNNDetector 12 | 13 | mode = 'train' 14 | prefix = '' 15 | anno_file = "./annotations/user_anno_{}.txt".format(mode) 16 | im_dir = "./data/user/images" 17 | pos_save_dir = "./data/{}/24/positive".format(mode) 18 | part_save_dir = "./data/{}/24/part".format(mode) 19 | neg_save_dir = "./data/{}/24/negative".format(mode) 20 | 21 | if not os.path.exists(pos_save_dir): 22 | os.makedirs(pos_save_dir) 23 | if not os.path.exists(part_save_dir): 24 | os.makedirs(part_save_dir) 25 | if not os.path.exists(neg_save_dir): 26 | os.makedirs(neg_save_dir) 27 | 28 | # store labels of positive, negative, part images 29 | f1 = open(os.path.join('annotations', 'pos_24_{}.txt'.format(mode)), 'a') 30 | f2 = open(os.path.join('annotations', 'neg_24_{}.txt'.format(mode)), 'a') 31 | f3 = open(os.path.join('annotations', 'part_24_{}.txt'.format(mode)), 'a') 32 | 33 | # anno_file: store labels of the wider face training data 34 | with open(anno_file, 'r') as f: 35 | annotations = f.readlines() 36 | num = len(annotations) 37 | print("%d pics in total" % num) 38 | 39 | image_size = 24 40 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 41 | print(device) 42 | p_idx = 0 # positive 43 | n_idx = 0 # negative 44 | d_idx = 0 # dont care 45 | idx = 0 46 | 47 | # create MTCNN Detector 48 | mtcnn_detector = MTCNNDetector(p_model_path='./pretrained_weights/mtcnn/best_pnet.pth') 49 | 50 | for annotation in annotations: 51 | annotation = annotation.strip().split(' ') 52 | im_path = os.path.join(im_dir, annotation[0]) 53 | print(im_path) 54 | bbox = list(map(float, annotation[1:])) 55 | boxes = np.array(bbox, dtype=np.int32).reshape(-1, 4) 56 | 57 | # anno form is x1, y1, w, h, convert to x1, y1, x2, y2 58 | # boxes[:,2] += boxes[:,0] - 1 59 | # boxes[:,3] += boxes[:,1] - 1 60 | 61 | image = cv2.imread(im_path) 62 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 63 | bboxes = 
mtcnn_detector.detect_face(image) 64 | # bboxes, landmarks = create_mtcnn_net(image, 12, device, p_model_path='../train/pnet_Weights') 65 | if bboxes.shape[0] == 0: 66 | continue 67 | 68 | dets = np.round(bboxes[:, 0:4]) 69 | 70 | 71 | img = cv2.imread(im_path) 72 | idx += 1 73 | 74 | height, width, channel = img.shape 75 | 76 | for box in dets: 77 | x_left, y_top, x_right, y_bottom = box[0:4].astype(int) 78 | width = x_right - x_left + 1 79 | height = y_bottom - y_top + 1 80 | 81 | # ignore box that is too small or beyond image border 82 | if width < 20 or x_left < 0 or y_top < 0 or x_right > img.shape[1] - 1 or y_bottom > img.shape[0] - 1: 83 | continue 84 | 85 | # compute intersection over union(IoU) between current box and all gt boxes 86 | Iou = IoU(box, boxes) 87 | cropped_im = img[y_top:y_bottom + 1, x_left:x_right + 1, :] 88 | resized_im = cv2.resize(cropped_im, (image_size, image_size), 89 | interpolation=cv2.INTER_LINEAR) 90 | 91 | # save negative images and write label 92 | if np.max(Iou) < 0.2 and n_idx < 3.2*p_idx+1: 93 | # Iou with all gts must below 0.3 94 | save_file = os.path.join(neg_save_dir, "user_%s.jpg" % n_idx) 95 | f2.write(save_file + ' 0\n') 96 | cv2.imwrite(save_file, resized_im) 97 | n_idx += 1 98 | else: 99 | # find gt_box with the highest iou 100 | idx_Iou = np.argmax(Iou) 101 | assigned_gt = boxes[idx_Iou] 102 | x1, y1, x2, y2 = assigned_gt 103 | 104 | # compute bbox reg label 105 | offset_x1 = (x1 - x_left) / float(width) 106 | offset_y1 = (y1 - y_top) / float(height) 107 | offset_x2 = (x2 - x_right) / float(width) 108 | offset_y2 = (y2 - y_bottom) / float(height) 109 | 110 | # save positive and part-face images and write labels 111 | if np.max(Iou) >= 0.65: 112 | save_file = os.path.join(pos_save_dir, "user_%s.jpg" % p_idx) 113 | f1.write(save_file + ' 1 %.2f %.2f %.2f %.2f\n' % ( 114 | offset_x1, offset_y1, offset_x2, offset_y2)) 115 | cv2.imwrite(save_file, resized_im) 116 | p_idx += 1 117 | 118 | elif np.max(Iou) >= 0.4 and d_idx < 1.2*p_idx + 1: 119 | save_file = os.path.join(part_save_dir, "user_%s.jpg" % d_idx) 120 | f3.write(save_file + ' -1 %.2f %.2f %.2f %.2f\n' % ( 121 | offset_x1, offset_y1, offset_x2, offset_y2)) 122 | cv2.imwrite(save_file, resized_im) 123 | d_idx += 1 124 | 125 | print("%s images done, pos: %s part: %s neg: %s" % (idx, p_idx, d_idx, n_idx)) 126 | 127 | #if idx == 20: 128 | # break 129 | 130 | f1.close() 131 | f2.close() 132 | f3.close() 133 | -------------------------------------------------------------------------------- /gen_dataset/transform_mat2txt.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append('.') 3 | import time 4 | 5 | import cv2 6 | 7 | from wider_loader import WIDER 8 | 9 | """ 10 | Transfrom .mat label to .txt label format 11 | """ 12 | mode = 'val' 13 | 14 | # wider face original images path 15 | path_to_image = "./data/WIDER_{}/images".format(mode) 16 | 17 | # matlab label file path 18 | file_to_label = "./data/wider_face_split/wider_face_{}.mat".format(mode) 19 | 20 | # target annotation file path 21 | target_file = './annotations/wider_anno_{}.txt'.format(mode) 22 | 23 | wider = WIDER(file_to_label, path_to_image) 24 | 25 | line_count = 0 26 | box_count = 0 27 | 28 | print('start transforming....') 29 | t = time.time() 30 | 31 | with open(target_file, 'w+') as f: 32 | for data in wider.next(): 33 | line = [] 34 | line.append(str(data.image_name)) 35 | line_count += 1 36 | for i,box in enumerate(data.bboxes): 37 | box_count += 1 38 | for 
j,bvalue in enumerate(box): 39 | line.append(str(bvalue)) 40 | 41 | line.append('\n') 42 | 43 | line_str = ' '.join(line) 44 | f.write(line_str) 45 | 46 | st = time.time()-t 47 | print('end transforming') 48 | print('spend time:%d'%st) 49 | print('total line(images):%d'%line_count) 50 | print('total boxes(faces):%d'%box_count) 51 | -------------------------------------------------------------------------------- /gen_dataset/wider_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.io import loadmat 3 | 4 | class DATA: 5 | def __init__(self, image_name, bboxes): 6 | self.image_name = image_name 7 | self.bboxes = bboxes 8 | 9 | 10 | class WIDER(object): 11 | def __init__(self, file_to_label, path_to_image=None): 12 | self.file_to_label = file_to_label 13 | self.path_to_image = path_to_image 14 | 15 | self.f = loadmat(file_to_label) 16 | self.event_list = self.f['event_list'] 17 | self.file_list = self.f['file_list'] 18 | self.face_bbx_list = self.f['face_bbx_list'] 19 | 20 | def next(self): 21 | for event_idx, event in enumerate(self.event_list): 22 | # fix error of "can't not .. bytes and strings" 23 | e = str(event[0][0].encode('utf-8'))[2:-1] 24 | for file, bbx in zip(self.file_list[event_idx][0], 25 | self.face_bbx_list[event_idx][0]): 26 | f = file[0][0].encode('utf-8') 27 | # print(type(e), type(f)) # bytes, bytes 28 | # fix error of "can't not .. bytes and strings" 29 | f = str(f)[2:-1] 30 | # path_of_image = os.path.join(self.path_to_image, str(e), str(f)) + ".jpg" 31 | path_of_image = self.path_to_image + '/' + e + '/' + f + ".jpg" 32 | # print(path_of_image) 33 | 34 | bboxes = [] 35 | bbx0 = bbx[0] 36 | for i in range(bbx0.shape[0]): 37 | xmin, ymin, xmax, ymax = bbx0[i] 38 | bboxes.append((int(xmin), int(ymin), int(xmax), int(ymax))) 39 | yield DATA(path_of_image, bboxes) 40 | 41 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Loss: 6 | """Losses for classification, face box regression, landmark regression""" 7 | def __init__(self, device): 8 | # loss function 9 | # self.loss_cls = nn.BCELoss().to(device) use this loss for sigmoid score 10 | self.loss_cls = nn.CrossEntropyLoss().to(device) 11 | self.loss_box = nn.MSELoss().to(device) 12 | self.loss_landmark = nn.MSELoss().to(device) 13 | 14 | 15 | def cls_loss(self, gt_label, pred_label): 16 | # get the mask element which >= 0, only 0 and 1 can effect the detection loss 17 | # kind of confused here, maybe its related to cropped data state 18 | pred_label = torch.squeeze(pred_label) 19 | mask = torch.ge(gt_label, 0) # mask is a BoolTensor, select indexes greater or equal than 0 20 | valid_gt_label = torch.masked_select(gt_label, mask)# .float() 21 | #valid_pred_label = torch.masked_select(pred_label, mask) 22 | valid_pred_label = pred_label[mask, :] 23 | return self.loss_cls(valid_pred_label, valid_gt_label) 24 | 25 | 26 | def box_loss(self, gt_label, gt_offset,pred_offset): 27 | # get the mask element which != 0 28 | mask = torch.ne(gt_label, 0) 29 | 30 | # convert mask to dim index 31 | chose_index = torch.nonzero(mask) 32 | chose_index = torch.squeeze(chose_index) 33 | 34 | # only valid element can effect the loss 35 | valid_gt_offset = gt_offset[chose_index,:] 36 | valid_pred_offset = pred_offset[chose_index,:] 37 | valid_pred_offset = 
torch.squeeze(valid_pred_offset) 38 | return self.loss_box(valid_pred_offset,valid_gt_offset) 39 | 40 | 41 | def landmark_loss(self, gt_label, gt_landmark, pred_landmark): 42 | mask = torch.eq(gt_label,-2) 43 | 44 | chose_index = torch.nonzero(mask.data) 45 | chose_index = torch.squeeze(chose_index) 46 | 47 | valid_gt_landmark = gt_landmark[chose_index, :] 48 | valid_pred_landmark = pred_landmark[chose_index, :] 49 | return self.loss_landmark(valid_pred_landmark, valid_gt_landmark) 50 | 51 | class NLL_OHEM(torch.nn.NLLLoss): 52 | """online hard sample mining""" 53 | def __init__(self, ratio): 54 | super(NLL_OHEM).__init__(None, True) 55 | self.ratio = ratio 56 | 57 | def forward(self, x, y, ratio=None): 58 | if ratio is not None: 59 | self.ratio = ratio 60 | 61 | num_inst = x.size(0) 62 | num_hns = int(self.ratio*num_inst) 63 | 64 | x_ = x.clone() 65 | inst_losses = torch.autograd.Variable(torch.zeros(num_inst)).cuda() 66 | 67 | for idx, label in enumerate(y.data): 68 | inst_loss[idx] = -x_.data[idx, label] 69 | _, idxs = inst_losses.topk(num__hns) 70 | x_hn = x.index_select(0, idxs) 71 | y_hn = y.index_select(0, idxs) 72 | 73 | return torch.nn.functional.nll_loss(x_hn, y_hn) 74 | 75 | -------------------------------------------------------------------------------- /nets/mtcnn.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchsummary import summary 7 | 8 | class Flatten(nn.Module): 9 | def __init__(self): 10 | super(Flatten, self).__init__() 11 | 12 | def forward(self, x): 13 | # without this pretrained model won't working 14 | x = x.transpose(3, 2).contiguous() 15 | return x.view(x.size(0), -1) 16 | 17 | class PNet(nn.Module): 18 | '''12*12 stride 2''' 19 | def __init__(self, is_train=False): 20 | super(PNet, self).__init__() 21 | self.is_train = is_train 22 | 23 | ''' 24 | conv1: (H-2)*(W-2)*10 25 | prelu1: (H-2)*(W-2)*10 26 | 27 | pool1: ((H-2)/2)*((W-2)/2)*10 28 | 29 | conv2: ((H-2)/2-2)*((W-2)/2-2)*16 30 | prelu2: ((H-2)/2-2)*((W-2)/2-2)*16 31 | 32 | conv3: ((H-2)/2-4)*((W-2)/2-4)*32 33 | prelu3: ((H-2)/2-4)*((W-2)/2-4)*32 34 | 35 | conv4_1: ((H-2)/2-4)*((W-2)/2-4)*2 36 | conv4_2: ((H-2)/2-4)*((W-2)/2-4)*4 37 | 38 | The last feature map size is: (H - 10)/2 = (H - 12)/2 + 1. 39 | Thus the effect of PNet equals to moving 12*12 convolution window with 40 | kernel size 3, stirde 2. 
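        Worked example for the canonical 12*12 training crop:
        conv1 -> 10*10*10, pool1 -> 5*5*10, conv2 -> 3*3*16, conv3 -> 1*1*32,
        so conv4_1/conv4_2 emit one 1*1 score/offset position per 12*12 window.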
41 | ''' 42 | 43 | self.features = nn.Sequential(OrderedDict([ 44 | ('conv1', nn.Conv2d(3, 10, 3, 1)), 45 | ('prelu1', nn.PReLU(10)), 46 | ('pool1', nn.MaxPool2d(2, 2, ceil_mode=False)), 47 | 48 | ('conv2', nn.Conv2d(10, 16, 3, 1)), 49 | ('prelu2', nn.PReLU(16)), 50 | 51 | ('conv3', nn.Conv2d(16, 32, 3, 1)), 52 | ('prelu3', nn.PReLU(32)), 53 | ])) 54 | 55 | self.conv4_1 = nn.Conv2d(32, 2, 1, 1) 56 | self.conv4_2 = nn.Conv2d(32, 4, 1, 1) 57 | 58 | def forward(self, x): 59 | x = self.features(x) 60 | scores = self.conv4_1(x) 61 | offsets = self.conv4_2(x) 62 | 63 | # append softmax for inference 64 | if not self.is_train: 65 | socres = F.softmax(scores, dim=1) 66 | 67 | return scores, offsets 68 | 69 | class RNet(nn.Module): 70 | '''Input size should be 24*24*3''' 71 | def __init__(self, is_train=False): 72 | super(RNet, self).__init__() 73 | self.is_train = is_train 74 | 75 | self.features = nn.Sequential(OrderedDict([ 76 | ('conv1', nn.Conv2d(3, 28, 3, 1)), # 24 -2 = 22 77 | ('prelu1', nn.PReLU(28)), 78 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=False)), # (22-3)/2 + 1 = 10 79 | 80 | ('conv2', nn.Conv2d(28, 48, 3, 1)), # 10 - 2 = 8 81 | ('prelu2', nn.PReLU(48)), 82 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=False)), # (8-3)/2 + 1 = 3 83 | 84 | ('conv3', nn.Conv2d(48, 64, 2, 1)), # 3 - 1 = 2 85 | ('prelu3', nn.PReLU(64)), 86 | 87 | ('flatten', Flatten()), 88 | ('conv4', nn.Linear(64*2*2, 128)), 89 | ('prelu4', nn.PReLU(128)) 90 | ])) 91 | 92 | self.conv5_1 = nn.Linear(128, 2) 93 | self.conv5_2 = nn.Linear(128, 4) 94 | 95 | def forward(self, x): 96 | x = self.features(x) 97 | scores = self.conv5_1(x) 98 | offsets = self.conv5_2(x) 99 | 100 | if not self.is_train: 101 | scores = F.softmax(scores, dim=1) 102 | return scores, offsets 103 | 104 | class ONet(nn.Module): 105 | '''Input size should be 48*48*3''' 106 | def __init__(self, is_train=False, train_landmarks=False): 107 | super(ONet, self).__init__() 108 | 109 | self.is_train = is_train 110 | self.train_landmarks = train_landmarks 111 | 112 | self.features = nn.Sequential(OrderedDict([ 113 | ('conv1', nn.Conv2d(3, 32, 3, 1)), # 48 - 2 = 46 114 | ('prelu1', nn.PReLU(32)), 115 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=False)), # (46-3)/2 + 1 = 22 116 | 117 | ('conv2', nn.Conv2d(32, 64, 3, 1)), # 22 - 2 = 20 118 | ('prelu2', nn.PReLU(64)), 119 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=False)), # (20-3)/2 + 1 = 9 120 | 121 | ('conv3', nn.Conv2d(64, 64, 3, 1)), # 9 - 2 = 7 122 | ('prelu3', nn.PReLU(64)), 123 | ('pool3', nn.MaxPool2d(2, 2, ceil_mode=False)), # (7-2)/2 + 1 = 3 124 | 125 | ('conv4', nn.Conv2d(64, 128, 2, 1)), # 3 - 1 = 2 126 | ('prelu4', nn.PReLU(128)), 127 | 128 | ('flatten', Flatten()), 129 | ('conv5', nn.Linear(128*2*2, 256)), 130 | ('prelu5', nn.PReLU(256)) 131 | ])) 132 | 133 | self.conv6_1 = nn.Linear(256, 2) 134 | self.conv6_2 = nn.Linear(256, 4) 135 | self.conv6_3 = nn.Linear(256, 10) 136 | 137 | def forward(self, x): 138 | x = self.features(x) 139 | scores = self.conv6_1(x) 140 | offsets = self.conv6_2(x) 141 | 142 | if not self.is_train: 143 | scores = F.softmax(scores, dim=1) 144 | 145 | if self.train_landmarks: 146 | landmarks = self.conv6_3(x) 147 | return scores, offsets, landmarks 148 | 149 | return scores, offsets 150 | 151 | 152 | if __name__ == "__main__": 153 | pnet = PNet(is_train=False) 154 | pnet.load_state_dict(torch.load('./pretrained_weights/best_pnet.pth')) 155 | torch.onnx.export(pnet, torch.randn(1, 3, 12, 12), './onnx2ncnn/pnet.onnx', 156 | input_names=['input'], output_names=['scores', 'offsets']) 157 | 
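    # Optional sanity check on the exported graph. This is a sketch only: it
    # assumes the third-party onnxruntime package is available, which is not a
    # dependency of this repo.
    #   import numpy as np
    #   import onnxruntime as ort
    #   sess = ort.InferenceSession('./onnx2ncnn/pnet.onnx')
    #   dummy = np.random.randn(1, 3, 12, 12).astype(np.float32)
    #   scores, offsets = sess.run(None, {'input': dummy})
    #   assert scores.shape == (1, 2, 1, 1) and offsets.shape == (1, 4, 1, 1)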
#summary(pnet.cuda(), (3, 12, 12)) 158 | 159 | rnet = RNet(is_train=False) 160 | rnet.load_state_dict(torch.load('./pretrained_weights/best_rnet.pth')) 161 | torch.onnx.export(rnet, torch.randn(1, 3, 24, 24), './onnx2ncnn/rnet.onnx', 162 | input_names=['input'], output_names=['scores', 'offsets']) 163 | #summary(rnet.cuda(), (3, 24, 24)) 164 | 165 | 166 | onet = ONet(is_train=False) 167 | #onet.load_state_dict(torch.load('./pretrained_weights/best_onet.pth')) 168 | #summary(onet.cuda(), (3, 48, 48)) 169 | -------------------------------------------------------------------------------- /nets/prunable_layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.autograd import Variable 3 | import torch 4 | 5 | 6 | class PConv2d(nn.Conv2d): 7 | """ 8 | Exactly like a Conv2d, but saves the activation of the last forward pass 9 | This allows calculation of the taylor estimate in https://arxiv.org/abs/1611.06440 10 | Includes convenience functions for feature map pruning 11 | """ 12 | 13 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): 14 | super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 15 | self.__recent_activations = None 16 | self.taylor_estimates = None 17 | self.register_backward_hook(self.__estimate_taylor_importance) 18 | 19 | def forward(self, x): 20 | output = super().forward(x) 21 | self.__recent_activations = output.clone() 22 | return output 23 | 24 | def __estimate_taylor_importance(self, _, grad_input, grad_output): 25 | # skip dim=1, its the dim for depth 26 | n_batch, _, n_x, n_y = self.__recent_activations.size() 27 | n_parameters = n_batch * n_x * n_y 28 | 29 | estimates = self.__recent_activations.mul_(grad_output[0]) \ 30 | .sum(dim=3) \ 31 | .sum(dim=2) \ 32 | .sum(dim=0) \ 33 | .div_(n_parameters) 34 | 35 | # normalization 36 | self.taylor_estimates = torch.abs(estimates) / torch.sqrt(torch.sum(estimates * estimates)) 37 | del estimates, self.__recent_activations 38 | self.__recent_activations = None 39 | 40 | def prune_feature_map(self, map_index, device): 41 | # is_cuda = self.weight.is_cuda 42 | 43 | indices = Variable(torch.LongTensor([i for i in range(self.out_channels) if i != map_index])) 44 | # indices = indices.cuda() if is_cuda else indices 45 | indices = indices.to(device) 46 | 47 | self.weight = nn.Parameter(self.weight.index_select(0, indices).data) 48 | self.bias = nn.Parameter(self.bias.index_select(0, indices).data) 49 | self.out_channels -= 1 50 | 51 | def drop_input_channel(self, index, device): 52 | """ 53 | Use when a convnet earlier in the chain is pruned. 
Reduces input channel count 54 | :param index: 55 | :return: 56 | """ 57 | # is_cuda = self.weight.is_cuda 58 | 59 | indices = Variable(torch.LongTensor([i for i in range(self.in_channels) if i != index])) 60 | # indices = indices.cuda() if is_cuda else indices 61 | indices = indices.to(device) 62 | 63 | self.weight = nn.Parameter(self.weight.index_select(1, indices).data) 64 | self.in_channels -= 1 65 | 66 | 67 | class PLinear(nn.Linear): 68 | 69 | def drop_inputs(self, input_shape, index, dim=0, device='cpu'): 70 | """ 71 | Previous layer is expected to be a convnet which just underwent pruning 72 | Drop cells connected to the pruned layer of the convnet 73 | :param input_shape: shape of inputs before flattening, should exclude batch_size 74 | :param index: index to drop 75 | :param dim: dimension where index is dropped, w.r.t input_shape 76 | :return: 77 | """ 78 | # is_cuda = self.weight.is_cuda 79 | 80 | reshaped = self.weight.view(-1, *input_shape) 81 | dim_length = input_shape[dim] 82 | indices = Variable(torch.LongTensor([i for i in range(dim_length) if i != index])) 83 | # indices = indices.cuda() if is_cuda else indices 84 | indices = indices.to(device) 85 | 86 | self.weight = nn.Parameter( 87 | reshaped.index_select(dim+1, indices) 88 | .data 89 | .view(self.out_features, -1) 90 | ) 91 | self.in_features = self.weight.size()[1] 92 | 93 | class PPReLU(nn.PReLU): 94 | def drop_input_channel(self, index, device): 95 | # is_cuda = self.weight.is_cuda 96 | # choose indexs not equal to index to drop 97 | indices = Variable(torch.LongTensor([i for i in range(self.num_parameters) if i != index])) 98 | # indices = indices.cuda() if is_cuda else indices 99 | indices = indices.to(device) 100 | self.weight = nn.Parameter(self.weight.index_select(0, indices).data) 101 | self.num_parameters -= 1 102 | 103 | class PBatchNorm2d(nn.BatchNorm2d): 104 | 105 | def drop_input_channel(self, index, device): 106 | if self.affine: 107 | # is_cuda = self.weight.is_cuda 108 | indices = Variable(torch.LongTensor([i for i in range(self.num_features) if i != index])) 109 | # indices = indices.cuda() if is_cuda else indices 110 | indices = indices.to(device) 111 | 112 | self.weight = nn.Parameter(self.weight.index_select(0, indices).data) 113 | self.bias = nn.Parameter(self.bias.index_select(0, indices).data) 114 | self.running_mean = self.running_mean.index_select(0, indices.data) 115 | self.running_var = self.running_var.index_select(0, indices.data) 116 | 117 | self.num_features -= 1 118 | 119 | -------------------------------------------------------------------------------- /nets/prune_mtcnn.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append('.') 3 | from collections import OrderedDict 4 | from operator import itemgetter 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torchsummary import summary 10 | 11 | import nets.prunable_layers as pnn 12 | 13 | class PNet(nn.Module): 14 | '''12*12 stride 2''' 15 | def __init__(self, is_train=False): 16 | super(PNet, self).__init__() 17 | self.is_train = is_train 18 | 19 | ''' 20 | conv1: (H-2)*(W-2)*10 21 | prelu1: (H-2)*(W-2)*10 22 | 23 | pool1: ((H-2)/2)*((W-2)/2)*10 24 | 25 | conv2: ((H-2)/2-2)*((W-2)/2-2)*16 26 | prelu2: ((H-2)/2-2)*((W-2)/2-2)*16 27 | 28 | conv3: ((H-2)/2-4)*((W-2)/2-4)*32 29 | prelu3: ((H-2)/2-4)*((W-2)/2-4)*32 30 | 31 | conv4_1: ((H-2)/2-4)*((W-2)/2-4)*2 32 | conv4_2: ((H-2)/2-4)*((W-2)/2-4)*4 33 | 34 | The last feature map size 
is: (H - 10)/2 = (H - 12)/2 + 1. 35 | Thus the effect of PNet equals to moving 12*12 convolution window with 36 | kernel size 3, stirde 2. 37 | ''' 38 | 39 | self.features = nn.Sequential(OrderedDict([ 40 | ('conv1', pnn.PConv2d(3, 10, 3, 1)), 41 | ('prelu1', pnn.PPReLU(10)), 42 | ('pool1', nn.MaxPool2d(2, 2, ceil_mode=False)), 43 | 44 | ('conv2', pnn.PConv2d(10, 16, 3, 1)), 45 | ('prelu2', pnn.PPReLU(16)), 46 | 47 | ('conv3', pnn.PConv2d(16, 32, 3, 1)), 48 | ('prelu3', pnn.PPReLU(32)), 49 | ])) 50 | 51 | self.conv4_1 = pnn.PConv2d(32, 2, 1, 1) 52 | self.conv4_2 = pnn.PConv2d(32, 4, 1, 1) 53 | 54 | def forward(self, x): 55 | x = self.features(x) 56 | scores = self.conv4_1(x) 57 | offsets = self.conv4_2(x) 58 | 59 | # append softmax for inference 60 | if not self.is_train: 61 | socres = F.softmax(scores, dim=1) 62 | 63 | return scores, offsets 64 | 65 | def prune(self, device): 66 | features = list(self.features) 67 | per_layer_taylor_estimates = [(module.taylor_estimates, layer_idx) 68 | for layer_idx, module in enumerate(features) 69 | if issubclass(type(module), pnn.PConv2d) and module.out_channels > 1] 70 | per_filter_taylor_estimates = [(per_filter_estimate, filter_idx, layer_idx) 71 | for per_layer_estimate, layer_idx in per_layer_taylor_estimates 72 | for filter_idx, per_filter_estimate in enumerate(per_layer_estimate)] 73 | 74 | _, min_filter_idx, min_layer_idx = min(per_filter_taylor_estimates, key=itemgetter(0)) 75 | pconv2d = self.features[min_layer_idx] 76 | pconv2d.prune_feature_map(min_filter_idx, device) 77 | 78 | prelu = self.features[min_layer_idx+1] 79 | prelu.drop_input_channel(min_filter_idx, device) 80 | 81 | next_conv = None 82 | next_conv_layer_idx = min_layer_idx+1 83 | 84 | while next_conv_layer_idx < len(self.features._modules.items()): 85 | res = list(self.features._modules.items())[next_conv_layer_idx] 86 | if isinstance(res[1], pnn.PConv2d): 87 | next_name, next_conv = res 88 | break 89 | next_conv_layer_idx += 1 90 | if next_conv is not None: 91 | next_conv.drop_input_channel(min_filter_idx) 92 | 93 | # if it's the last conv in self.features 94 | if min_layer_idx+1 == len(self.features._modules.items())-1: 95 | self.conv4_1.drop_input_channel(min_filter_idx, device) 96 | self.conv4_2.drop_input_channel(min_filter_idx, device) 97 | 98 | if __name__ == "__main__": 99 | pnet = PNet(is_train=False) 100 | summary(pnet.cuda(), (3, 12, 12)) 101 | -------------------------------------------------------------------------------- /nets/quant_mtcnn.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import brevitas.nn as qnn 7 | from brevitas.core.quant import QuantType 8 | from torchsummary import summary 9 | 10 | class Flatten(nn.Module): 11 | def __init__(self): 12 | super(Flatten, self).__init__() 13 | 14 | def forward(self, x): 15 | # without this pretrained model won't working 16 | # x = x.transpose(3, 2).contiguous() 17 | return x.view(x.size(0), -1) 18 | 19 | class PNet(nn.Module): 20 | '''12*12 stride 2''' 21 | def __init__(self, is_train=False): 22 | super(PNet, self).__init__() 23 | self.is_train = is_train 24 | 25 | ''' 26 | conv1: (H-2)*(W-2)*10 27 | prelu1: (H-2)*(W-2)*10 28 | 29 | pool1: ((H-2)/2)*((W-2)/2)*10 30 | 31 | conv2: ((H-2)/2-2)*((W-2)/2-2)*16 32 | prelu2: ((H-2)/2-2)*((W-2)/2-2)*16 33 | 34 | conv3: ((H-2)/2-4)*((W-2)/2-4)*32 35 | prelu3: ((H-2)/2-4)*((W-2)/2-4)*32 36 | 37 | conv4_1: 
((H-2)/2-4)*((W-2)/2-4)*2 38 | conv4_2: ((H-2)/2-4)*((W-2)/2-4)*4 39 | 40 | The last feature map size is: (H - 10)/2 = (H - 12)/2 + 1. 41 | Thus the effect of PNet equals to moving 12*12 convolution window with 42 | kernel size 3, stirde 2. 43 | ''' 44 | 45 | self.features = nn.Sequential(OrderedDict([ 46 | ('conv1', qnn.QuantConv2d(3, 10, 3, 1, 47 | weight_quant_type=QuantType.INT, weight_bit_width=8)), 48 | ('prelu1', qnn.QuantReLU(quant_type=QuantType.INT, bit_width=8, max_val=6)), 49 | ('pool1', nn.MaxPool2d(2, 2, ceil_mode=False)), 50 | 51 | ('conv2', qnn.QuantConv2d(10, 16, 3, 1, 52 | weight_quant_type=QuantType.INT, weight_bit_width=8)), 53 | ('prelu2', qnn.QuantReLU(quant_type=QuantType.INT, bit_width=8, max_val=6)), 54 | 55 | ('conv3', qnn.QuantConv2d(16, 32, 3, 1, 56 | weight_quant_type=QuantType.INT, weight_bit_width=8)), 57 | ('prelu3', qnn.QuantReLU(quant_type=QuantType.INT, bit_width=8, max_val=6)), 58 | ])) 59 | 60 | self.conv4_1 = qnn.QuantConv2d(32, 2, 1, 1, 61 | weight_quant_type=QuantType.INT, weight_bit_width=8) 62 | self.conv4_2 = qnn.QuantConv2d(32, 4, 1, 1, 63 | weight_quant_type=QuantType.INT, weight_bit_width=8) 64 | 65 | def forward(self, x): 66 | x = self.features(x) 67 | scores = self.conv4_1(x) 68 | offsets = self.conv4_2(x) 69 | 70 | # append softmax for inference 71 | if not self.is_train: 72 | socres = F.softmax(scores, dim=1) 73 | 74 | return scores, offsets 75 | 76 | class RNet(nn.Module): 77 | '''Input size should be 24*24*3''' 78 | def __init__(self, is_train=False): 79 | super(RNet, self).__init__() 80 | self.is_train = is_train 81 | 82 | self.features = nn.Sequential(OrderedDict([ 83 | ('conv1', qnn.QuantConv2d(3, 28, 3, 1, 84 | weight_quant_type=QuantType.INT, weight_bit_width=8)), # 24 -2 = 22 85 | ('prelu1', qnn.QuantReLU(quant_type=QuantType.INT, bit_width=8, max_val=6)), 86 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=False)), # (22-3)/2 + 1 = 10 87 | 88 | ('conv2', qnn.QuantConv2d(28, 48, 3, 1, 89 | weight_quant_type=QuantType.INT, weight_bit_width=8)), # 10 - 2 = 8 90 | ('prelu2', qnn.QuantReLU(quant_type=QuantType.INT, bit_width=8, max_val=6)), 91 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=False)), # (8-3)/2 + 1 = 3 92 | 93 | ('conv3', qnn.QuantConv2d(48, 64, 2, 1, 94 | weight_quant_type=QuantType.INT, weight_bit_width=8)), # 3 - 1 = 2 95 | ('prelu3', qnn.QuantReLU(quant_type=QuantType.INT, bit_width=8, max_val=6)), 96 | 97 | ('flatten', Flatten()), 98 | ('conv4', qnn.QuantLinear(64*2*2, 128, 99 | weight_quant_type=QuantType.INT, bias=False, weight_bit_width=8)), 100 | ('prelu4', qnn.QuantReLU(quant_type=QuantType.INT, bit_width=8, max_val=6)), 101 | #('dropout', nn.Dropout(0.2)) 102 | ])) 103 | 104 | self.conv5_1 = qnn.QuantLinear(128, 2, 105 | weight_quant_type=QuantType.INT, bias=False, weight_bit_width=8) 106 | self.conv5_2 = qnn.QuantLinear(128, 4, 107 | weight_quant_type=QuantType.INT, bias=False, weight_bit_width=8) 108 | 109 | def forward(self, x): 110 | x = self.features(x) 111 | scores = self.conv5_1(x) 112 | offsets = self.conv5_2(x) 113 | 114 | if not self.is_train: 115 | scores = F.softmax(scores, dim=1) 116 | return scores, offsets 117 | 118 | class ONet(nn.Module): 119 | '''Input size should be 48*48*3''' 120 | def __init__(self, is_train=False, train_landmarks=False): 121 | super(ONet, self).__init__() 122 | 123 | self.is_train = is_train 124 | self.train_landmarks = train_landmarks 125 | 126 | self.features = nn.Sequential(OrderedDict([ 127 | ('conv1', qnn.QuantConv2d(3, 32, 3, 1, 128 | weight_quant_type=QuantType.INT, 
weight_bit_width=8)), # 48 - 2 = 46 129 | ('prelu1', qnn.QuantReLU(quant_type=QuantType.INT, bit_width=8, max_val=6)), 130 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode=False)), # (46-3)/2 + 1 = 22 131 | 132 | ('conv2', qnn.QuantConv2d(32, 64, 3, 1, 133 | weight_quant_type=QuantType.INT, weight_bit_width=8)), # 22 - 2 = 20 134 | ('prelu2', qnn.QuantReLU(quant_type=QuantType.INT, bit_width=8, max_val=6)), 135 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode=False)), # (20-3)/2 + 1 = 9 136 | 137 | ('conv3', qnn.QuantConv2d(64, 64, 3, 1, 138 | weight_quant_type=QuantType.INT, bit_width=8, max_val=6)), # 9 - 2 = 7 139 | ('prelu3', qnn.QuantReLU(quant_type=QuantType.INT, bit_width=8, max_val=6)), 140 | ('pool3', nn.MaxPool2d(2, 2, ceil_mode=False)), # (7-2)/2 + 1 = 3 141 | 142 | ('conv4', qnn.QuantConv2d(64, 128, 2, 1, 143 | weight_quant_type=QuantType.INT, bit_width=8, max_val=6)), # 3 - 1 = 2 144 | ('prelu4', qnn.QuantReLU(quant_type=QuantType.INT, bit_width=8, max_val=6)), 145 | 146 | ('flatten', Flatten()), 147 | ('conv5', qnn.QuantLinear(128*2*2, 256, 148 | weight_quant_type=QuantType.INT, bias=False, weight_bit_width=8)), 149 | ('prelu5', qnn.QuantReLU(quant_type=QuantType.INT, bit_width=8, max_val=6)), 150 | ('dropout', nn.Dropout(0.2)) 151 | ])) 152 | 153 | self.conv6_1 = qnn.QuantLinear(256, 2, 154 | weight_quant_type=QuantType.INT, bias=False, weight_bit_width=8) 155 | self.conv6_2 = qnn.QuantLinear(256, 4, 156 | weight_quant_type=QuantType.INT, bias=False, weight_bit_width=8) 157 | self.conv6_3 = qnn.QuantLinear(256, 10, 158 | weight_quant_type=QuantType.INT, bias=False, weight_bit_width=8) 159 | 160 | def forward(self, x): 161 | x = self.features(x) 162 | scores = self.conv6_1(x) 163 | offsets = self.conv6_2(x) 164 | 165 | if not self.is_train: 166 | scores = F.softmax(scores, dim=1) 167 | 168 | if self.train_landmarks: 169 | landmarks = self.conv6_3(x) 170 | return scores, offsets, landmarks 171 | 172 | return scores, offsets 173 | 174 | 175 | if __name__ == "__main__": 176 | pnet = PNet(is_train=False) 177 | print(pnet) 178 | pnet(torch.randn(1, 3, 12, 12)) 179 | #pnet.load_state_dict(torch.load('./pretrained_weights/mtcnn/best_pnet.pth')) 180 | #torch.onnx.export(pnet, torch.randn(1, 3, 12, 12), './onnx2ncnn/pnet.onnx', 181 | # input_names=['input'], output_names=['scores', 'offsets']) 182 | #summary(pnet.cuda(), (3, 12, 12)) 183 | 184 | rnet = RNet(is_train=False) 185 | print(rnet) 186 | rnet(torch.randn(1, 3, 24, 24)) 187 | #rnet.load_state_dict(torch.load('./pretrained_weights/mtcnn/best_rnet.pth')) 188 | #torch.onnx.export(rnet, torch.randn(1, 3, 24, 24), './onnx2ncnn/rnet.onnx', 189 | # input_names=['input'], output_names=['scores', 'offsets']) 190 | #summary(rnet.cuda(), (3, 24, 24)) 191 | 192 | #onet = ONet(is_train=False) 193 | #onet.load_state_dict(torch.load('./pretrained_weights/mtcnn/best_onet.pth')) 194 | #torch.onnx.export(onet, torch.randn(1, 3, 48, 48), './onnx2ncnn/onet.onnx', 195 | # input_names=['input'], output_names=['scores', 'offsets']) 196 | #summary(onet.cuda(), (3, 48, 48)) 197 | -------------------------------------------------------------------------------- /nets/slim_mtcnn.py: -------------------------------------------------------------------------------- 1 | '''PNet, RNet, ONet, inspired from shufflenet and mobilenet''' 2 | 3 | import math 4 | from collections import OrderedDict 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torchsummary import summary 10 | 11 | def channel_shuffle(x, groups): 12 | """Shuffle 
channels, from PyTorch official code 13 | 14 | Parameters: 15 | ----------- 16 | x: pytorch tensor have shape N*C*H*C 17 | groups: num of groups to split channels 18 | Returns: channel shuffled pytorch tensor 19 | -------- 20 | """ 21 | batch_size, num_channels, height, width = x.data.size() 22 | channels_per_group = num_channels // groups 23 | 24 | # reshape 25 | x = x.view(batch_size, groups, channels_per_group, height, width) 26 | x = torch.transpose(x, 1, 2).contiguous() 27 | 28 | # flatten 29 | x = x.view(batch_size, -1, height, width) 30 | 31 | return x 32 | 33 | class ShuffleConv2d(nn.Module): 34 | def __init__(self, inp, oup, kernel_size=3, stride=2, groups=2): 35 | super(ShuffleConv2d, self).__init__() 36 | self.groups = groups 37 | inp_per_group = inp // groups 38 | oup_per_group = oup // groups 39 | self.shuffle_conv = nn.Sequential( 40 | nn.Conv2d(inp_per_group, inp_per_group, kernel_size=kernel_size, 41 | stride=stride, groups=inp_per_group), 42 | nn.PReLU(), 43 | nn.Conv2d(inp_per_group, oup_per_group, kernel_size=1, stride=1), 44 | nn.PReLU() 45 | ) 46 | 47 | def forward(self, x): 48 | x_list = list(x.chunk(self.groups, dim=1)) 49 | for i, x in enumerate(x_list): 50 | x_list[i] = self.shuffle_conv(x) 51 | x = torch.cat(x_list, dim=1) 52 | x = channel_shuffle(x, self.groups) 53 | return x 54 | 55 | class DepthwiseConv2d(nn.Module): 56 | def __init__(self, inp, oup, stride): 57 | super(DepthwiseConv2d, self).__init__() 58 | self.use_res = (inp == oup) and (stride == 1) 59 | self.feature = nn.Sequential( 60 | nn.Conv2d(inp, inp, kernel_size=3, stride=stride, groups=inp), 61 | nn.PReLU(), 62 | nn.Conv2d(inp, oup, kernel_size=1, stride=1), 63 | nn.PReLU() 64 | ) 65 | 66 | def forward(self, x): 67 | out = self.feature(x) 68 | if self.use_res: 69 | out = out + x 70 | return out 71 | 72 | class Flatten(nn.Module): 73 | def __init__(self): 74 | super(Flatten, self).__init__() 75 | 76 | def forward(self, x): 77 | # without this pretrained model won't working 78 | x = x.transpose(3, 2).contiguous() 79 | return x.view(x.size(0), -1) 80 | 81 | class PNet(nn.Module): 82 | '''12*12 stride 2''' 83 | def __init__(self, is_train=False): 84 | super(PNet, self).__init__() 85 | 86 | self.is_train = is_train 87 | 88 | self.conv1 = nn.Sequential( 89 | nn.Conv2d(3, 12, kernel_size=3, stride=1), 90 | nn.PReLU(), 91 | nn.MaxPool2d(2, 2, ceil_mode=False) 92 | ) 93 | self.conv2 = ShuffleConv2d(12, 24, 3, 1, 2) 94 | self.conv3 = ShuffleConv2d(24, 32, 3, 1, 2) 95 | 96 | self.conv4_1 = nn.Conv2d(32, 2, kernel_size=1) 97 | self.conv4_2 = nn.Conv2d(32, 4, kernel_size=1) 98 | 99 | def forward(self, x): 100 | x = self.conv1(x) 101 | x = self.conv2(x) 102 | x = self.conv3(x) 103 | 104 | scores = self.conv4_1(x) 105 | offsets = self.conv4_2(x) 106 | 107 | # append softmax for inference 108 | if not self.is_train: 109 | scores = F.softmax(scores, dim=1) 110 | 111 | return scores, offsets 112 | 113 | class RNet(nn.Module): 114 | def __init__(self, is_train=False): 115 | super(RNet, self).__init__() 116 | self.is_train = is_train 117 | 118 | self.conv1 = nn.Sequential( 119 | nn.Conv2d(3, 32, kernel_size=3, stride=1), 120 | nn.PReLU(), 121 | nn.MaxPool2d(3, 2, ceil_mode=False) 122 | ) 123 | self.conv2 = nn.Sequential( 124 | ShuffleConv2d(32, 64, 3, 1, 2), 125 | nn.MaxPool2d(3, 2, ceil_mode=False), 126 | ) 127 | self.conv3 = nn.Sequential( 128 | ShuffleConv2d(64, 128, 2, 1, 2) 129 | ) 130 | self.conv4 = nn.Sequential( 131 | Flatten(), 132 | nn.Linear(128*2*2, 128), 133 | nn.PReLU() 134 | ) 135 | 136 | self.conv5_1 = 
nn.Linear(128, 2) 137 | self.conv5_2 = nn.Linear(128, 4) 138 | 139 | 140 | for m in self.modules(): 141 | if isinstance(m, nn.Conv2d): 142 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 143 | m.weight.data.normal_(0, math.sqrt(2. / n)) 144 | 145 | def forward(self, x): 146 | x = self.conv1(x) 147 | x = self.conv2(x) 148 | x = self.conv3(x) 149 | x = self.conv4(x) 150 | 151 | scores = self.conv5_1(x) 152 | offsets = self.conv5_2(x) 153 | 154 | if not self.is_train: 155 | scores = F.softmax(scores, dim=1) 156 | 157 | return scores, offsets 158 | 159 | class ONet(nn.Module): 160 | def __init__(self, is_train=False, train_landmarks=False): 161 | super(ONet, self).__init__() 162 | self.is_train = is_train 163 | self.train_landmarks = train_landmarks 164 | 165 | self.conv1 = nn.Sequential( 166 | nn.Conv2d(3, 32, kernel_size=3), 167 | nn.PReLU(32), 168 | nn.MaxPool2d(3, 2, ceil_mode=False) 169 | ) 170 | self.conv2 = nn.Sequential( 171 | ShuffleConv2d(32, 64, 3, 1, 2), 172 | nn.MaxPool2d(3, 2, ceil_mode=False) 173 | ) 174 | self.conv3 = nn.Sequential( 175 | ShuffleConv2d(64, 128, 3, 1, 2), 176 | nn.MaxPool2d(2, 2, ceil_mode=False) 177 | ) 178 | self.conv4 = nn.Sequential( 179 | ShuffleConv2d(128, 256, 3, 1, 2) 180 | ) 181 | self.conv5 = nn.Sequential( 182 | Flatten(), 183 | nn.Linear(256, 256), 184 | nn.PReLU(256) 185 | ) 186 | 187 | self.conv6_1 = nn.Linear(256, 2) 188 | self.conv6_2 = nn.Linear(256, 4) 189 | self.conv6_3 = nn.Linear(256, 10) 190 | 191 | def forward(self, x): 192 | x = self.conv1(x) 193 | x = self.conv2(x) 194 | x = self.conv3(x) 195 | x = self.conv4(x) 196 | x = self.conv5(x) 197 | 198 | scores = self.conv6_1(x) 199 | offsets = self.conv6_2(x) 200 | 201 | if not self.is_train: 202 | scores = F.softmax(scores, dim=1) 203 | 204 | if self.train_landmarks: 205 | landmarks = self.conv6_3(x) 206 | 207 | return scores, offsets, landmarks 208 | 209 | return scores, offsets 210 | 211 | if __name__ == "__main__": 212 | #device = torch.device("cuda:1") 213 | #pnet = PNet(is_train=False) 214 | #pnet.load_state_dict(torch.load('./pretrained_weights/best_pnet.pth')) 215 | #torch.onnx.export(pnet, torch.randn(1, 3, 12, 12), './onnx2ncnn/pnet.onnx', 216 | # input_names=['input'], output_names=['scores', 'offsets']) 217 | #pnet.to(device) 218 | #summary(pnet, (3, 12, 12)) 219 | 220 | 221 | #rnet = RNet(is_train=False) 222 | #rnet.load_state_dict(torch.load('./pretrained_weights/best_rnet.pth')) 223 | #torch.onnx.export(rnet, torch.randn(1, 3, 24, 24), './onnx2ncnn/rnet.onnx', 224 | # input_names=['input'], output_names=['scores', 'offsets']) 225 | #summary(rnet.cuda(), (3, 24, 24)) 226 | 227 | onet = ONet(is_train=False) 228 | #onet.load_state_dict(torch.load('./pretrained_weights/best_onet.pth')) 229 | #summary(onet.cuda(), (3, 48, 48)) 230 | torch.onnx.export(onet, torch.randn(1, 3, 48, 48), './onnx2ncnn/onet.onnx', 231 | input_names=['input'], output_names=['scores', 'offsets']) 232 | -------------------------------------------------------------------------------- /onnx2ncnn/onet.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/digital-nomad-cheng/MTCNN_PyTorch/752b9f0f32f5c25df5647c7c279c24715dc3aa87/onnx2ncnn/onet.onnx -------------------------------------------------------------------------------- /onnx2ncnn/onnx2ncnn.sh: -------------------------------------------------------------------------------- 1 | python -m onnxsim pnet.onnx pnet_sim.onnx 2 | python -m onnxsim rnet.onnx rnet_sim.onnx 3 | python -m 
onnxsim onet.onnx onet_sim.onnx 4 | echo "Finished simplified" 5 | ~/Libs/ncnn/build/tools/onnx/onnx2ncnn pnet_sim.onnx mobile_pnet_sim.param mobile_pnet_sim.bin 6 | ~/Libs/ncnn/build/tools/onnx/onnx2ncnn rnet_sim.onnx mobile_rnet_sim.param mobile_rnet_sim.bin 7 | ~/Libs/ncnn/build/tools/onnx/onnx2ncnn onet_sim.onnx mobile_onet_sim.param mobile_onet_sim.bin 8 | -------------------------------------------------------------------------------- /onnx2ncnn/pnet.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/digital-nomad-cheng/MTCNN_PyTorch/752b9f0f32f5c25df5647c7c279c24715dc3aa87/onnx2ncnn/pnet.onnx -------------------------------------------------------------------------------- /onnx2ncnn/rnet.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/digital-nomad-cheng/MTCNN_PyTorch/752b9f0f32f5c25df5647c7c279c24715dc3aa87/onnx2ncnn/rnet.onnx -------------------------------------------------------------------------------- /pretrained_weights/mtcnn/best_onet.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/digital-nomad-cheng/MTCNN_PyTorch/752b9f0f32f5c25df5647c7c279c24715dc3aa87/pretrained_weights/mtcnn/best_onet.pth -------------------------------------------------------------------------------- /pretrained_weights/mtcnn/best_onet_landmark.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/digital-nomad-cheng/MTCNN_PyTorch/752b9f0f32f5c25df5647c7c279c24715dc3aa87/pretrained_weights/mtcnn/best_onet_landmark.pth -------------------------------------------------------------------------------- /pretrained_weights/mtcnn/best_pnet.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/digital-nomad-cheng/MTCNN_PyTorch/752b9f0f32f5c25df5647c7c279c24715dc3aa87/pretrained_weights/mtcnn/best_pnet.pth -------------------------------------------------------------------------------- /pretrained_weights/mtcnn/best_rnet.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/digital-nomad-cheng/MTCNN_PyTorch/752b9f0f32f5c25df5647c7c279c24715dc3aa87/pretrained_weights/mtcnn/best_rnet.pth -------------------------------------------------------------------------------- /pretrained_weights/quant_mtcnn/best_pnet.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/digital-nomad-cheng/MTCNN_PyTorch/752b9f0f32f5c25df5647c7c279c24715dc3aa87/pretrained_weights/quant_mtcnn/best_pnet.pth -------------------------------------------------------------------------------- /prunning/Pnet_Prune.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append('.') 3 | import math 4 | import time 5 | import argparse 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import Dataset 10 | 11 | from nets.mtcnn import PNet 12 | 13 | from prunning.utils.FilterPrunner import FilterPrunner 14 | from prunning.utils.prune_mtcnn import prune_mtcnn 15 | from prunning.utils.util import ListDataset, printProgressBar, total_num_filters 16 | 17 | def test(model, path): 18 | 19 | batch_size = 2048 20 | dataloader = torch.utils.data.DataLoader(ListDataset(path), batch_size=batch_size, shuffle=True) 
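    # ListDataset(path) is built a second time on the next line only to read its
    # length; reusing the loader's dataset would be an equivalent, cheaper
    # alternative, e.g. dataset_sizes = len(dataloader.dataset).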
21 | dataset_sizes = len(ListDataset(path)) 22 | 23 | model.eval() 24 | loss_cls = nn.CrossEntropyLoss() 25 | loss_offset = nn.MSELoss() 26 | 27 | running_correct = 0 28 | running_gt = 0 29 | running_loss, running_loss_cls, running_loss_offset = 0.0, 0.0, 0.0 30 | 31 | for i_batch, sample_batched in enumerate(dataloader): 32 | 33 | printProgressBar(i_batch + 1, dataset_sizes // batch_size + 1, prefix = 'Progress:', suffix = 'Complete', length = 50) 34 | 35 | input_images, gt_label, gt_offset = sample_batched['input_img'], sample_batched['label'], sample_batched['bbox_target'] 36 | input_images = input_images.to(device) 37 | gt_label = gt_label.to(device) 38 | gt_offset = gt_offset.type(torch.FloatTensor).to(device) 39 | 40 | with torch.set_grad_enabled(False): 41 | pred_label, pred_offsets = model(input_images) 42 | pred_offsets = torch.squeeze(pred_offsets) 43 | pred_label = torch.squeeze(pred_label) 44 | 45 | mask_cls = torch.ge(gt_label, 0) 46 | valid_gt_label = gt_label[mask_cls] 47 | valid_pred_label = pred_label[mask_cls] 48 | 49 | unmask = torch.eq(gt_label, 0) 50 | mask_offset = torch.eq(unmask, 0) 51 | valid_gt_offset = gt_offset[mask_offset] 52 | valid_pred_offset = pred_offsets[mask_offset] 53 | 54 | loss = torch.tensor(0.0).to(device) 55 | num_gt = len(valid_gt_label) 56 | 57 | if len(valid_gt_label) != 0: 58 | loss += 0.02*loss_cls(valid_pred_label, valid_gt_label) 59 | cls_loss = loss_cls(valid_pred_label, valid_gt_label).item() 60 | pred = torch.max(valid_pred_label, 1)[1] 61 | eval_correct = (pred == valid_gt_label).sum().item() 62 | 63 | if len(valid_gt_offset) != 0: 64 | loss += 0.6*loss_offset(valid_pred_offset, valid_gt_offset) 65 | offset_loss = loss_offset(valid_pred_offset, valid_gt_offset).item() 66 | 67 | # statistics 68 | running_loss += loss.item()*batch_size 69 | running_loss_cls += cls_loss*batch_size 70 | running_loss_offset += offset_loss*batch_size 71 | running_correct += eval_correct 72 | running_gt += num_gt 73 | 74 | epoch_loss = running_loss / dataset_sizes 75 | epoch_loss_cls = running_loss_cls / dataset_sizes 76 | epoch_loss_offset = running_loss_offset / dataset_sizes 77 | epoch_accuracy = running_correct / (running_gt + 1e-16) 78 | 79 | return epoch_accuracy, epoch_loss, epoch_loss_cls, epoch_loss_offset 80 | 81 | def train(model, path, epoch=10): 82 | 83 | batch_size = 2048 84 | dataloader = torch.utils.data.DataLoader(ListDataset(path), batch_size=batch_size, shuffle=True) 85 | dataset_sizes = len(ListDataset(path)) 86 | 87 | model.train() 88 | loss_cls = nn.CrossEntropyLoss() 89 | loss_offset = nn.MSELoss() 90 | 91 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) 92 | 93 | num_epochs = epoch 94 | for epoch in range(num_epochs): 95 | print('Epoch {}/{}'.format(epoch, num_epochs-1)) 96 | 97 | running_loss, running_loss_cls, running_loss_offset = 0.0, 0.0, 0.0 98 | running_correct = 0.0 99 | running_gt = 0.0 100 | 101 | for i_batch, sample_batched in enumerate(dataloader): 102 | 103 | printProgressBar(i_batch + 1, dataset_sizes // batch_size + 1, prefix = 'Progress:', suffix = 'Complete', length = 50) 104 | 105 | input_images, gt_label, gt_offset = sample_batched['input_img'], sample_batched[ 106 | 'label'], sample_batched['bbox_target'] 107 | input_images = input_images.to(device) 108 | gt_label = gt_label.to(device) 109 | gt_offset = gt_offset.type(torch.FloatTensor).to(device) 110 | 111 | # zero the parameter gradients 112 | optimizer.zero_grad() 113 | 114 | with torch.set_grad_enabled(True): 115 | pred_label, pred_offsets = 
model(input_images) 116 | pred_offsets = torch.squeeze(pred_offsets) 117 | pred_label = torch.squeeze(pred_label) 118 | # calculate the cls loss 119 | # get the mask element which >= 0, only 0 and 1 can effect the detection loss 120 | mask_cls = torch.ge(gt_label, 0) 121 | valid_gt_label = gt_label[mask_cls] 122 | valid_pred_label = pred_label[mask_cls] 123 | 124 | # calculate the box loss 125 | # get the mask element which != 0 126 | unmask = torch.eq(gt_label, 0) 127 | mask_offset = torch.eq(unmask, 0) 128 | valid_gt_offset = gt_offset[mask_offset] 129 | valid_pred_offset = pred_offsets[mask_offset] 130 | 131 | loss = torch.tensor(0.0).to(device) 132 | cls_loss, offset_loss = 0.0, 0.0 133 | eval_correct = 0.0 134 | num_gt = len(valid_gt_label) 135 | 136 | if len(valid_gt_label) != 0: 137 | loss += 0.02*loss_cls(valid_pred_label, valid_gt_label) 138 | cls_loss = loss_cls(valid_pred_label, valid_gt_label).item() 139 | pred = torch.max(valid_pred_label, 1)[1] 140 | eval_correct = (pred == valid_gt_label).sum().item() 141 | 142 | if len(valid_gt_offset) != 0: 143 | loss += 0.6*loss_offset(valid_pred_offset, valid_gt_offset) 144 | offset_loss = loss_offset(valid_pred_offset, valid_gt_offset).item() 145 | 146 | loss.backward() 147 | optimizer.step() 148 | 149 | # statistics 150 | running_loss += loss.item()*batch_size 151 | running_loss_cls += cls_loss*batch_size 152 | running_loss_offset += offset_loss*batch_size 153 | running_correct += eval_correct 154 | running_gt += num_gt 155 | 156 | epoch_loss = running_loss / dataset_sizes 157 | epoch_loss_cls = running_loss_cls / dataset_sizes 158 | epoch_loss_offset = running_loss_offset / dataset_sizes 159 | epoch_accuracy = running_correct / (running_gt + 1e-16) 160 | 161 | print('accuracy: {:.4f} loss: {:.4f} cls Loss: {:.4f} offset Loss: {:.4f}' 162 | .format(epoch_accuracy, epoch_loss, epoch_loss_cls, epoch_loss_offset)) 163 | 164 | def prune_model(model, prunner, path): 165 | 166 | batch_size = 64 167 | dataloader = torch.utils.data.DataLoader(ListDataset(path), batch_size=batch_size, shuffle=True) 168 | dataset_sizes = len(ListDataset(path)) 169 | 170 | model.train() 171 | loss_cls = nn.CrossEntropyLoss() 172 | loss_offset = nn.MSELoss() 173 | 174 | prunner.reset() 175 | 176 | for i_batch, sample_batched in enumerate(dataloader): 177 | 178 | printProgressBar(i_batch + 1, dataset_sizes // batch_size + 1, prefix = 'Progress:', suffix = 'Complete', length = 50) 179 | 180 | input_images, gt_label, gt_offset = sample_batched['input_img'], sample_batched['label'], sample_batched['bbox_target'] 181 | input_images = input_images.to(device) 182 | gt_label = gt_label.to(device) 183 | gt_offset = gt_offset.type(torch.FloatTensor).to(device) 184 | 185 | # zero the parameter gradients 186 | model.zero_grad() 187 | 188 | with torch.set_grad_enabled(True): 189 | _, pred_offsets, pred_label = prunner.forward(input_images) 190 | pred_offsets = torch.squeeze(pred_offsets) 191 | pred_label = torch.squeeze(pred_label) 192 | # calculate the cls loss 193 | # get the mask element which >= 0, only 0 and 1 can effect the detection loss 194 | mask_cls = torch.ge(gt_label, 0) 195 | valid_gt_label = gt_label[mask_cls] 196 | valid_pred_label = pred_label[mask_cls] 197 | 198 | # calculate the box loss 199 | # get the mask element which != 0 200 | unmask = torch.eq(gt_label, 0) 201 | mask_offset = torch.eq(unmask, 0) 202 | valid_gt_offset = gt_offset[mask_offset] 203 | valid_pred_offset = pred_offsets[mask_offset] 204 | 205 | loss = torch.tensor(0.0).to(device) 206 | 207 | 
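            # Same 0.02*cls + 0.6*offset weighting as in train()/test(); the
            # backward pass below provides the gradients from which the prunner
            # ranks filters (normalize_ranks_per_layer / get_prunning_plan are
            # called once this loop finishes).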
if len(valid_gt_label) != 0: 208 | loss += 0.02*loss_cls(valid_pred_label, valid_gt_label) 209 | 210 | if len(valid_gt_offset) != 0: 211 | loss += 0.6*loss_offset(valid_pred_offset, valid_gt_offset) 212 | 213 | loss.backward() 214 | 215 | prunner.normalize_ranks_per_layer() 216 | filters_to_prune = prunner.get_prunning_plan(args.filter_size) 217 | 218 | return filters_to_prune 219 | 220 | def get_args(): 221 | parser = argparse.ArgumentParser() 222 | parser.add_argument("--train_path", type = str, default = "./annotations/imglist_anno_12_train.txt") 223 | parser.add_argument("--test_path", type = str, default = "./annotations/imglist_anno_12_val.txt") 224 | parser.add_argument("--filter_size", type = int, default = 5) 225 | parser.add_argument("--filter_percentage", type = float, default = 0.5) 226 | args = parser.parse_args() 227 | return args 228 | 229 | if __name__ == '__main__': 230 | 231 | args = get_args() 232 | 233 | device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu') 234 | 235 | model = PNet(is_train=True).to(device) 236 | model.load_state_dict(torch.load("./pretrained_weights/mtcnn/best_pnet.pth", map_location=lambda storage, loc: storage)) 237 | 238 | prunner = FilterPrunner(model, use_cuda = True) 239 | 240 | save_dir = './prunning/saving_pnet_prunning_result' 241 | if os.path.exists(save_dir): 242 | raise NameError('model dir exists!') 243 | os.makedirs(save_dir) 244 | 245 | print("Check the initial model accuracy") 246 | since = time.time() 247 | accuracy, loss, loss_cls, loss_offset = test(model, args.test_path) 248 | print('initial test :: accuracy: {:.4f} loss: {:.4f} cls loss: {:.4f} offset loss: {:.4f}'.format(accuracy, loss, loss_cls, loss_offset)) 249 | print("initial test :: time cost is {:.2f} s".format(time.time()-since)) 250 | 251 | #Make sure all the layers are trainable 252 | for param in model.features.parameters(): 253 | param.requires_grad = True 254 | 255 | number_of_filters = total_num_filters(model) 256 | print("total model conv2D filters are: ", number_of_filters) 257 | 258 | num_filters_to_prune_per_iteration = args.filter_size 259 | 260 | iterations = math.ceil((float(number_of_filters) * args.filter_percentage) / num_filters_to_prune_per_iteration) 261 | print("Number of iterations to prune {} % filters:".format(args.filter_percentage*100), iterations) 262 | 263 | for it in range(iterations): 264 | 265 | print("iter{}. Ranking filters ..".format(it)) 266 | filters_to_prune = prune_model(model, prunner, args.test_path) 267 | 268 | layers_prunned = [(k, len(filters_to_prune[k])) for k in sorted(filters_to_prune.keys())] # k: layer index, number of filters 269 | print("iter{}. Layers that will be prunned".format(it), layers_prunned) 270 | 271 | print("iter{}. Prunning filters.. ".format(it)) 272 | for layer_index, filter_index in filters_to_prune.items(): 273 | model = prune_mtcnn(model, layer_index, *filter_index, use_cuda=False) 274 | model = model.to(device) 275 | 276 | print("iter{}. {:.2f}% Filters remaining".format(it, 100*float(total_num_filters(model)) / number_of_filters)) 277 | 278 | accuracy, loss, loss_cls, loss_offset = test(model, args.test_path) 279 | print('iter{}. without retrain :: accuracy: {:.4f} loss: {:.4f} cls loss: {:.4f} offset loss: {:.4f}'.format(it, accuracy, loss, loss_cls, loss_offset)) 280 | 281 | print("iter{}. Fine tuning to recover from prunning iteration.. 
".format(it)) 282 | #torch.cuda.empty_cache() 283 | train(model, path=args.train_path, epoch = 6) 284 | 285 | since = time.time() 286 | accuracy, loss, loss_cls, loss_offset = test(model, args.test_path) 287 | print('iter{}. after retrain :: accuracy: {:.4f} loss: {:.4f} cls loss: {:.4f} offset loss: {:.4f}'.format(it, accuracy, loss, loss_cls, loss_offset)) 288 | print("iter{}. test time cost is {:.2f} s".format(it, time.time()-since)) 289 | 290 | torch.save(model.state_dict(), os.path.join(save_dir, 'pnet_weights_pruned_{}'.format(it))) 291 | torch.save(model, os.path.join(save_dir, 'pnet_prunned_{}'.format(it))) 292 | torch.onnx.export(model, torch.randn(1, 3, 12, 12).to(device), 293 | './onnx2ncnn/pruned_pnet_{}.onnx'.format(it), 294 | input_names=['input'], output_names=['scores', 'offsets']) 295 | print("Finished prunning") 296 | -------------------------------------------------------------------------------- /prunning/Rnet_Prune.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append('.') 3 | import math 4 | import time 5 | import argparse 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import Dataset 10 | 11 | from nets.mtcnn import RNet 12 | from prunning.utils.FilterPrunner import FilterPrunner 13 | from prunning.utils.prune_mtcnn import prune_mtcnn 14 | from prunning.utils.util import ListDataset, printProgressBar, total_num_filters 15 | 16 | def test(model, path): 17 | 18 | batch_size = 1024 19 | dataloader = torch.utils.data.DataLoader(ListDataset(path), batch_size=batch_size, shuffle=True) 20 | dataset_sizes = len(ListDataset(path)) 21 | 22 | model.eval() 23 | loss_cls = nn.CrossEntropyLoss() 24 | loss_offset = nn.MSELoss() 25 | 26 | running_correct = 0 27 | running_gt = 0 28 | running_loss, running_loss_cls, running_loss_offset = 0.0, 0.0, 0.0 29 | 30 | for i_batch, sample_batched in enumerate(dataloader): 31 | 32 | printProgressBar(i_batch + 1, dataset_sizes // batch_size + 1, prefix = 'Progress:', suffix = 'Complete', length = 50) 33 | 34 | input_images, gt_label, gt_offset = sample_batched['input_img'], sample_batched['label'], sample_batched['bbox_target'] 35 | input_images = input_images.to(device) 36 | gt_label = gt_label.to(device) 37 | gt_offset = gt_offset.type(torch.FloatTensor).to(device) 38 | 39 | with torch.set_grad_enabled(False): 40 | pred_label, pred_offsets = model(input_images) 41 | pred_offsets = torch.squeeze(pred_offsets) 42 | pred_label = torch.squeeze(pred_label) 43 | 44 | mask_cls = torch.ge(gt_label, 0) 45 | valid_gt_label = gt_label[mask_cls] 46 | valid_pred_label = pred_label[mask_cls] 47 | 48 | unmask = torch.eq(gt_label, 0) 49 | mask_offset = torch.eq(unmask, 0) 50 | valid_gt_offset = gt_offset[mask_offset] 51 | valid_pred_offset = pred_offsets[mask_offset] 52 | 53 | loss = torch.tensor(0.0).to(device) 54 | num_gt = len(valid_gt_label) 55 | 56 | if len(valid_gt_label) != 0: 57 | loss += 0.02*loss_cls(valid_pred_label, valid_gt_label) 58 | cls_loss = loss_cls(valid_pred_label, valid_gt_label).item() 59 | pred = torch.max(valid_pred_label, 1)[1] 60 | eval_correct = (pred == valid_gt_label).sum().item() 61 | 62 | if len(valid_gt_offset) != 0: 63 | loss += 0.6*loss_offset(valid_pred_offset, valid_gt_offset) 64 | offset_loss = loss_offset(valid_pred_offset, valid_gt_offset).item() 65 | 66 | # statistics 67 | running_loss += loss.item()*batch_size 68 | running_loss_cls += cls_loss*batch_size 69 | running_loss_offset += offset_loss*batch_size 70 | 
running_correct += eval_correct 71 | running_gt += num_gt 72 | 73 | epoch_loss = running_loss / dataset_sizes 74 | epoch_loss_cls = running_loss_cls / dataset_sizes 75 | epoch_loss_offset = running_loss_offset / dataset_sizes 76 | epoch_accuracy = running_correct / (running_gt + 1e-16) 77 | 78 | return epoch_accuracy, epoch_loss, epoch_loss_cls, epoch_loss_offset 79 | 80 | def train(model, path, epoch=10): 81 | 82 | batch_size = 1024 83 | dataloader = torch.utils.data.DataLoader(ListDataset(path), batch_size=batch_size, shuffle=True) 84 | dataset_sizes = len(ListDataset(path)) 85 | 86 | model.train() 87 | loss_cls = nn.CrossEntropyLoss() 88 | loss_offset = nn.MSELoss() 89 | 90 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) 91 | 92 | num_epochs = epoch 93 | for epoch in range(num_epochs): 94 | print('Epoch {}/{}'.format(epoch, num_epochs-1)) 95 | 96 | running_loss, running_loss_cls, running_loss_offset = 0.0, 0.0, 0.0 97 | running_correct = 0.0 98 | running_gt = 0.0 99 | 100 | for i_batch, sample_batched in enumerate(dataloader): 101 | 102 | printProgressBar(i_batch + 1, dataset_sizes // batch_size + 1, prefix = 'Progress:', suffix = 'Complete', length = 50) 103 | 104 | input_images, gt_label, gt_offset = sample_batched['input_img'], sample_batched[ 105 | 'label'], sample_batched['bbox_target'] 106 | input_images = input_images.to(device) 107 | gt_label = gt_label.to(device) 108 | gt_offset = gt_offset.type(torch.FloatTensor).to(device) 109 | 110 | # zero the parameter gradients 111 | optimizer.zero_grad() 112 | 113 | with torch.set_grad_enabled(True): 114 | pred_label, pred_offsets = model(input_images) 115 | pred_offsets = torch.squeeze(pred_offsets) 116 | pred_label = torch.squeeze(pred_label) 117 | # calculate the cls loss 118 | # get the mask element which >= 0, only 0 and 1 can effect the detection loss 119 | mask_cls = torch.ge(gt_label, 0) 120 | valid_gt_label = gt_label[mask_cls] 121 | valid_pred_label = pred_label[mask_cls] 122 | 123 | # calculate the box loss 124 | # get the mask element which != 0 125 | unmask = torch.eq(gt_label, 0) 126 | mask_offset = torch.eq(unmask, 0) 127 | valid_gt_offset = gt_offset[mask_offset] 128 | valid_pred_offset = pred_offsets[mask_offset] 129 | 130 | loss = torch.tensor(0.0).to(device) 131 | cls_loss, offset_loss = 0.0, 0.0 132 | eval_correct = 0.0 133 | num_gt = len(valid_gt_label) 134 | 135 | if len(valid_gt_label) != 0: 136 | loss += 0.02*loss_cls(valid_pred_label, valid_gt_label) 137 | cls_loss = loss_cls(valid_pred_label, valid_gt_label).item() 138 | pred = torch.max(valid_pred_label, 1)[1] 139 | eval_correct = (pred == valid_gt_label).sum().item() 140 | 141 | if len(valid_gt_offset) != 0: 142 | loss += 0.6*loss_offset(valid_pred_offset, valid_gt_offset) 143 | offset_loss = loss_offset(valid_pred_offset, valid_gt_offset).item() 144 | 145 | loss.backward() 146 | optimizer.step() 147 | 148 | # statistics 149 | running_loss += loss.item()*batch_size 150 | running_loss_cls += cls_loss*batch_size 151 | running_loss_offset += offset_loss*batch_size 152 | running_correct += eval_correct 153 | running_gt += num_gt 154 | 155 | epoch_loss = running_loss / dataset_sizes 156 | epoch_loss_cls = running_loss_cls / dataset_sizes 157 | epoch_loss_offset = running_loss_offset / dataset_sizes 158 | epoch_accuracy = running_correct / (running_gt + 1e-16) 159 | 160 | print('accuracy: {:.4f} loss: {:.4f} cls Loss: {:.4f} offset Loss: {:.4f}' 161 | .format(epoch_accuracy, epoch_loss, epoch_loss_cls, epoch_loss_offset)) 162 | 163 | def 
prune_model(model, prunner, path): 164 | 165 | batch_size = 1024 166 | dataloader = torch.utils.data.DataLoader(ListDataset(path), batch_size=batch_size, shuffle=True) 167 | dataset_sizes = len(ListDataset(path)) 168 | 169 | model.train() 170 | loss_cls = nn.CrossEntropyLoss() 171 | loss_offset = nn.MSELoss() 172 | 173 | prunner.reset() 174 | 175 | for i_batch, sample_batched in enumerate(dataloader): 176 | 177 | printProgressBar(i_batch + 1, dataset_sizes // batch_size + 1, prefix = 'Progress:', suffix = 'Complete', length = 50) 178 | 179 | input_images, gt_label, gt_offset = sample_batched['input_img'], sample_batched['label'], sample_batched['bbox_target'] 180 | input_images = input_images.to(device) 181 | gt_label = gt_label.to(device) 182 | gt_offset = gt_offset.type(torch.FloatTensor).to(device) 183 | 184 | # zero the parameter gradients 185 | model.zero_grad() 186 | 187 | with torch.set_grad_enabled(True): 188 | _, pred_offsets, pred_label = prunner.forward(input_images) 189 | pred_offsets = torch.squeeze(pred_offsets) 190 | pred_label = torch.squeeze(pred_label) 191 | # calculate the cls loss 192 | # get the mask element which >= 0, only 0 and 1 can effect the detection loss 193 | mask_cls = torch.ge(gt_label, 0) 194 | valid_gt_label = gt_label[mask_cls] 195 | valid_pred_label = pred_label[mask_cls] 196 | 197 | # calculate the box loss 198 | # get the mask element which != 0 199 | unmask = torch.eq(gt_label, 0) 200 | mask_offset = torch.eq(unmask, 0) 201 | valid_gt_offset = gt_offset[mask_offset] 202 | valid_pred_offset = pred_offsets[mask_offset] 203 | 204 | loss = torch.tensor(0.0).to(device) 205 | 206 | if len(valid_gt_label) != 0: 207 | loss += 0.02*loss_cls(valid_pred_label, valid_gt_label) 208 | 209 | if len(valid_gt_offset) != 0: 210 | loss += 0.6*loss_offset(valid_pred_offset, valid_gt_offset) 211 | 212 | loss.backward() 213 | 214 | prunner.normalize_ranks_per_layer() 215 | filters_to_prune = prunner.get_prunning_plan(args.filter_size) 216 | 217 | return filters_to_prune 218 | 219 | def get_args(): 220 | parser = argparse.ArgumentParser() 221 | parser.add_argument("--train_path", type = str, default = "./annotations/imglist_anno_24_train.txt") 222 | parser.add_argument("--test_path", type = str, default = "./annotations/imglist_anno_24_val.txt") 223 | parser.add_argument("--filter_size", type = int, default = 5) 224 | parser.add_argument("--filter_percentage", type = float, default = 0.5) 225 | args = parser.parse_args() 226 | return args 227 | 228 | if __name__ == '__main__': 229 | 230 | args = get_args() 231 | 232 | device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu') 233 | 234 | model = RNet(is_train=True).to(device) 235 | model.load_state_dict(torch.load("pretrained_weights/mtcnn/best_rnet.pth", map_location=lambda storage, loc: storage)) 236 | 237 | prunner = FilterPrunner(model, use_cuda = True) 238 | 239 | save_dir = './prunning/saving_rnet_prunning_result' 240 | if os.path.exists(save_dir): 241 | raise NameError('model dir exists!') 242 | os.makedirs(save_dir) 243 | 244 | print("Check the initial model accuracy") 245 | since = time.time() 246 | accuracy, loss, loss_cls, loss_offset = test(model, args.test_path) 247 | print('initial test :: accuracy: {:.4f} loss: {:.4f} cls loss: {:.4f} offset loss: {:.4f}'.format(accuracy, loss, loss_cls, loss_offset)) 248 | print("initial test :: time cost is {:.2f} s".format(time.time()-since)) 249 | 250 | #Make sure all the layers are trainable 251 | for param in model.features.parameters(): 252 | 
param.requires_grad = True 253 | 254 | number_of_filters = total_num_filters(model) 255 | print("total model conv2D filters are: ", number_of_filters) 256 | 257 | num_filters_to_prune_per_iteration = args.filter_size 258 | 259 | iterations = math.ceil((float(number_of_filters) * args.filter_percentage) / num_filters_to_prune_per_iteration) 260 | print("Number of iterations to prune {} % filters:".format(args.filter_percentage*100), iterations) 261 | 262 | for it in range(iterations): 263 | 264 | print("iter{}. Ranking filters ..".format(it)) 265 | filters_to_prune = prune_model(model, prunner, args.test_path) 266 | 267 | layers_prunned = [(k, len(filters_to_prune[k])) for k in sorted(filters_to_prune.keys())] # k: layer index, number of filters 268 | print("iter{}. Layers that will be prunned".format(it), layers_prunned) 269 | 270 | print("iter{}. Prunning filters.. ".format(it)) 271 | for layer_index, filter_index in filters_to_prune.items(): 272 | model = prune_mtcnn(model, layer_index, *filter_index, use_cuda=True) 273 | model = model.to(device) 274 | 275 | print("iter{}. {:.2f}% Filters remaining".format(it, 100*float(total_num_filters(model)) / number_of_filters)) 276 | 277 | accuracy, loss, loss_cls, loss_offset = test(model, args.test_path) 278 | print('iter{}. without retrain :: accuracy: {:.4f} loss: {:.4f} cls loss: {:.4f} offset loss: {:.4f}'.format(it, accuracy, loss, loss_cls, loss_offset)) 279 | 280 | print("iter{}. Fine tuning to recover from prunning iteration.. ".format(it)) 281 | torch.cuda.empty_cache() 282 | train(model, path=args.train_path, epoch = 6) 283 | 284 | since = time.time() 285 | accuracy, loss, loss_cls, loss_offset = test(model, args.test_path) 286 | print('iter{}. after retrain :: accuracy: {:.4f} loss: {:.4f} cls loss: {:.4f} offset loss: {:.4f}'.format(it, accuracy, loss, loss_cls, loss_offset)) 287 | print("iter{}. 
test time cost is {:.2f} s".format(it, time.time()-since)) 288 | 289 | torch.save(model.state_dict(), os.path.join(save_dir, 'rnet_weights_pruned_{}'.format(it))) 290 | torch.save(model, os.path.join(save_dir, 'rnet_prunned_{}'.format(it))) 291 | 292 | torch.onnx.export(model, torch.randn(1, 3, 24, 24).to(device), 293 | './onnx2ncnn/pruned_rnet_{}.onnx'.format(it), 294 | input_names=['input'], output_names=['scores', 'offsets']) 295 | 296 | print("Finished prunning") 297 | 298 | 299 | -------------------------------------------------------------------------------- /prunning/pnet/prune.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | import os 4 | import argparse 5 | 6 | import torch 7 | from torchvision import transforms 8 | 9 | from tools.dataset import FaceDataset 10 | from nets.prune_mtcnn import PNet 11 | from prunning.pnet.pruner import PNetPruner 12 | from checkpoint import CheckPoint 13 | import config 14 | 15 | # Set device 16 | use_cuda = config.USE_CUDA and torch.cuda.is_available() 17 | # torch.cuda.manual_seed(train_config.manualSeed) 18 | device = torch.device("cuda:1" if use_cuda else "cpu") 19 | torch.backends.cudnn.benchmark = True 20 | 21 | # Set dataloader 22 | kwargs = {'num_workers': 8, 'pin_memory': False} if use_cuda else {} 23 | train_data = FaceDataset(os.path.join(config.ANNO_PATH, config.PNET_TRAIN_IMGLIST_FILENAME)) 24 | val_data = FaceDataset(os.path.join(config.ANNO_PATH, config.PNET_VAL_IMGLIST_FILENAME)) 25 | dataloaders = {'train': torch.utils.data.DataLoader(train_data, 26 | batch_size=config.BATCH_SIZE, shuffle=True, **kwargs), 27 | 'val': torch.utils.data.DataLoader(val_data, 28 | batch_size=config.BATCH_SIZE, shuffle=True, **kwargs) 29 | } 30 | 31 | # Set model 32 | model = PNet(is_train=True) 33 | model = model.to(device) 34 | model.load_state_dict(torch.load('pretrained_weights/mtcnn/best_pnet.pth'), strict=True) 35 | 36 | # Set checkpoint 37 | # checkpoint = CheckPoint(train_config.save_path) 38 | 39 | # Set optimizer 40 | optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) 41 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=config.STEPS, gamma=0.1) 42 | 43 | # Set trainer 44 | trainer = PNetPruner(config.EPOCHS, dataloaders, model, optimizer, scheduler, device, 0.5, 2) 45 | 46 | trainer.prune() 47 | 48 | # checkpoint.save_model(model, index=epoch, tag=config.SAVE_PREFIX) 49 | 50 | -------------------------------------------------------------------------------- /prunning/pnet/pruner.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append('.') 3 | import time 4 | import datetime 5 | from collections import OrderedDict 6 | import itertools 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from loss import Loss 14 | from tools.average_meter import AverageMeter 15 | 16 | class PNetPruner(object): 17 | 18 | def __init__(self, epochs, dataloaders, model, optimizer, scheduler, device, 19 | prune_ratio, finetune_epochs): 20 | self.epochs = epochs 21 | self.dataloaders = dataloaders 22 | self.model = model 23 | self.optimizer = optimizer 24 | self.scheduler = scheduler 25 | self.device = device 26 | self.lossfn = Loss(self.device) 27 | 28 | self.prune_iters = self._estimate_pruning_iterations(model, prune_ratio) 29 | print("Total prunning iterations:", self.prune_iters) 30 | self.finetune_epochs = finetune_epochs 31 
| 32 | def compute_accuracy(self, prob_cls, gt_cls): 33 | # we only need the detection which >= 0 34 | prob_cls = torch.squeeze(prob_cls) 35 | mask = torch.ge(gt_cls, 0) 36 | 37 | # get valid elements 38 | valid_gt_cls = gt_cls[mask] 39 | valid_prob_cls = prob_cls[mask] 40 | size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0]) 41 | 42 | # get max index with softmax layer 43 | _, valid_pred_cls = torch.max(valid_prob_cls, dim=1) 44 | 45 | right_ones = torch.eq(valid_pred_cls.float(), valid_gt_cls.float()).float() 46 | 47 | return torch.div(torch.mul(torch.sum(right_ones), float(1.0)), float(size)) 48 | 49 | def prune(self): 50 | print("Before Prunning...") 51 | self.train_epoch(0, 'val') 52 | for i in range(self.prune_iters): 53 | self.prune_step() 54 | print("After Prunning Iter ", i) 55 | self.train_epoch(i, 'val') 56 | print("Finetuning...") 57 | for epoch in range(self.finetune_epochs): 58 | self.train_epoch(i, 'train') 59 | self.train_epoch(i, 'val') 60 | torch.save(self.model.state_dict(), './prunning/results/pruned_pnet.pth') 61 | torch.onnx.export(self.model, torch.randn(1, 3, 12, 12).to(self.device), 62 | './onnx2ncnn/pruned_pnet.onnx', 63 | input_names=['input'], output_names=['scores', 'offsets']) 64 | 65 | def prune_step(self): 66 | self.model.train() 67 | 68 | sample_idx = np.random.randint(0, len(self.dataloaders['train'])) 69 | for batch_idx, sample in enumerate(self.dataloaders['train']): 70 | if batch_idx == sample_idx: 71 | data = sample['input_img'] 72 | gt_cls = sample['cls_target'] 73 | gt_bbox = sample['bbox_target'] 74 | 75 | data, gt_cls, gt_bbox = data.to(self.device), gt_cls.to(self.device), gt_bbox.to(self.device).float() 76 | pred_cls, pred_bbox = self.model(data) 77 | cls_loss = self.lossfn.cls_loss(gt_cls, pred_cls) 78 | bbox_loss = self.lossfn.box_loss(gt_cls, gt_bbox, pred_bbox) 79 | total_loss = cls_loss + 5*bbox_loss 80 | total_loss.backward() 81 | self.model.prune(self.device) 82 | 83 | 84 | def train_epoch(self, epoch, phase): 85 | cls_loss_ = AverageMeter() 86 | bbox_loss_ = AverageMeter() 87 | total_loss_ = AverageMeter() 88 | accuracy_ = AverageMeter() 89 | if phase == 'train': 90 | self.model.train() 91 | else: 92 | self.model.eval() 93 | 94 | for batch_idx, sample in enumerate(self.dataloaders[phase]): 95 | data = sample['input_img'] 96 | gt_cls = sample['cls_target'] 97 | gt_bbox = sample['bbox_target'] 98 | data, gt_cls, gt_bbox = data.to(self.device), gt_cls.to( 99 | self.device), gt_bbox.to(self.device).float() 100 | 101 | self.optimizer.zero_grad() 102 | with torch.set_grad_enabled(phase == 'train'): 103 | pred_cls, pred_bbox = self.model(data) 104 | 105 | # compute the cls loss and bbox loss and weighted them together 106 | cls_loss = self.lossfn.cls_loss(gt_cls, pred_cls) 107 | bbox_loss = self.lossfn.box_loss(gt_cls, gt_bbox, pred_bbox) 108 | total_loss = cls_loss + 5*bbox_loss 109 | 110 | # compute clssification accuracy 111 | accuracy = self.compute_accuracy(pred_cls, gt_cls) 112 | 113 | if phase == 'train': 114 | total_loss.backward() 115 | self.optimizer.step() 116 | 117 | cls_loss_.update(cls_loss, data.size(0)) 118 | bbox_loss_.update(bbox_loss, data.size(0)) 119 | total_loss_.update(total_loss, data.size(0)) 120 | accuracy_.update(accuracy, data.size(0)) 121 | 122 | #if batch_idx % 40 == 0: 123 | # print('{} Epoch: {} [{:08d}/{:08d} ({:02.0f}%)]\tLoss: {:.6f} cls Loss: {:.6f} offset Loss:{:.6f}\tAccuracy: {:.6f} LR:{:.7f}'.format( 124 | # phase, epoch, batch_idx * len(data), len(self.dataloaders[phase].dataset), 125 | # 100. 
* batch_idx / len(self.dataloaders[phase]), total_loss.item(), cls_loss.item(), bbox_loss.item(), accuracy.item(), self.optimizer.param_groups[0]['lr'])) 126 | 127 | print("{} epoch Loss: {:.6f} cls Loss: {:.6f} bbox Loss: {:.6f} Accuracy: {:.6f}".format( 128 | phase, total_loss_.avg, cls_loss_.avg, bbox_loss_.avg, accuracy_.avg)) 129 | 130 | # torch.save(self.model.state_dict(), './pretrained_weights/quant_mtcnn/best_pnet.pth') 131 | 132 | return cls_loss_.avg, bbox_loss_.avg, total_loss_.avg, accuracy_.avg 133 | 134 | 135 | def _estimate_pruning_iterations(self, model, prune_ratio): 136 | '''Estimate how many feature maps to prune: divide the total number of params 137 | to prune by the estimated params per feature map. Since we only prune one filter 138 | at a time, the number of iterations equals the number of filters to prune 139 | 140 | Parameters: 141 | ----------- 142 | model: pytorch model 143 | prune_ratio: ratio of total params to prune 144 | 145 | Returns: 146 | ------- 147 | number of pruning iterations 148 | ''' 149 | # we only prune Conv2d layers here, Linear layer will be considered later 150 | conv2ds = [module for module in model.modules() 151 | if issubclass(type(module), nn.Conv2d)] 152 | num_feature_maps = np.sum(conv2d.out_channels for conv2d in conv2ds) 153 | 154 | conv2d_params = (module.parameters() for module in model.modules() 155 | if issubclass(type(module), nn.Conv2d)) 156 | param_objs = itertools.chain(*conv2d_params) 157 | # num_param: in * out * w * h per feature map 158 | num_params = np.sum(np.prod(np.array(p.size())) for p in param_objs) 159 | 160 | params_per_map = num_params // num_feature_maps 161 | 162 | 163 | return int(np.ceil(num_params * prune_ratio / params_per_map)) 164 | 165 | if __name__ == "__main__": 166 | pass 167 | -------------------------------------------------------------------------------- /prunning/utils/FilterPrunner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from operator import itemgetter 4 | from heapq import nsmallest 5 | 6 | class FilterPrunner: 7 | """Class FilterPrunner performs structured pruning on filters based on the first 8 | order Taylor expansion of the network cost function from Nvidia: 9 | "Pruning Convolutional Neural Networks for Resource Efficient Inference" 10 | 11 | Parameters: 12 | ----------- 13 | model: the DNN model, which should be composed of model.features and model.classifier.
14 | """ 15 | def __init__(self, model, use_cuda = False): 16 | self.model = model 17 | self.reset() 18 | self.use_cuda = use_cuda 19 | 20 | def reset(self): 21 | self.filter_ranks = {} 22 | 23 | def forward(self, x): 24 | self.activations = [] 25 | self.gradients = [] 26 | self.grad_index = 0 27 | self.activation_to_layer = {} 28 | 29 | activation_index = 0 30 | for layer, (name, module) in enumerate(self.model.features._modules.items()): 31 | x = module(x) 32 | if isinstance(module, torch.nn.modules.conv.Conv2d): 33 | x.register_hook(self.compute_rank) 34 | self.activations.append(x) 35 | self.activation_to_layer[activation_index] = layer # the ith conv2d layer 36 | activation_index += 1 37 | # Pnet 38 | if hasattr(self.model, 'conv4_1') and hasattr(self.model, 'conv4_2'): 39 | a = self.model.conv4_1(x) 40 | b = self.model.conv4_2(x) 41 | c = None 42 | 43 | # Rnet 44 | if hasattr(self.model, 'conv5_1') and hasattr(self.model, 'conv5_2'): 45 | a = self.model.conv5_1(x) 46 | b = self.model.conv5_2(x) 47 | c = None 48 | 49 | # Onet 50 | if hasattr(self.model, 'conv6_1') and hasattr(self.model, 'conv6_2') \ 51 | and hasattr(self.model, 'conv6_3'): 52 | a = self.model.conv6_1(x) 53 | b = self.model.conv6_2(x) 54 | c = self.model.conv6_3(x) 55 | 56 | return c, b, a 57 | 58 | def compute_rank(self, grad): 59 | activation_index = len(self.activations) - self.grad_index - 1 60 | activation = self.activations[activation_index] 61 | taylor = activation * grad 62 | 63 | # Get the average value for every filter, 64 | # accross all the other dimensions 65 | taylor = taylor.mean(dim=(0, 2, 3)).data 66 | 67 | if activation_index not in self.filter_ranks: 68 | self.filter_ranks[activation_index] = \ 69 | torch.FloatTensor(activation.size(1)).zero_() 70 | 71 | if self.use_cuda: 72 | self.filter_ranks[activation_index] = self.filter_ranks[activation_index].cuda() 73 | 74 | self.filter_ranks[activation_index] += taylor 75 | self.grad_index += 1 76 | 77 | def lowest_ranking_filters(self, num): 78 | data = [] 79 | for i in sorted(self.filter_ranks.keys()): 80 | for j in range(self.filter_ranks[i].size(0)): 81 | data.append((self.activation_to_layer[i], j, self.filter_ranks[i][j])) 82 | 83 | return nsmallest(num, data, itemgetter(2)) 84 | 85 | def normalize_ranks_per_layer(self): 86 | for i in self.filter_ranks: 87 | v = torch.abs(self.filter_ranks[i]).cpu() 88 | v = v / np.sqrt(torch.sum(v * v)) 89 | self.filter_ranks[i] = v 90 | 91 | def get_prunning_plan(self, num_filters_to_prune): 92 | filters_to_prune = self.lowest_ranking_filters(num_filters_to_prune) 93 | 94 | filters_to_prune_per_layer = {} 95 | for (l, f, _) in filters_to_prune: 96 | if l not in filters_to_prune_per_layer: 97 | filters_to_prune_per_layer[l] = [] 98 | filters_to_prune_per_layer[l].append(f) 99 | 100 | 101 | for l in filters_to_prune_per_layer: 102 | filters_to_prune_per_layer[l] = sorted(filters_to_prune_per_layer[l]) 103 | 104 | return filters_to_prune_per_layer 105 | -------------------------------------------------------------------------------- /prunning/utils/prunable_layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.autograd import Variable 3 | import torch 4 | 5 | 6 | class PConv2d(nn.Conv2d): 7 | """ 8 | Exactly like a Conv2d, but saves the activation of the last forward pass 9 | This allows calculation of the taylor estimate in https://arxiv.org/abs/1611.06440 10 | Includes convenience functions for feature map pruning 11 | """ 12 | 13 | def 
__init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): 14 | super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 15 | self.__recent_activations = None 16 | self.taylor_estimates = None 17 | self.__pruning_hook = None 18 | self.__pruning = False 19 | 20 | def forward(self, x): 21 | output = super().forward(x) 22 | if self.__pruning: 23 | self.__recent_activations = output.clone() 24 | return output 25 | 26 | def pruning(self, flag): 27 | self.__pruning = flag 28 | if flag: 29 | self.__pruning_hook = self.register_backward_hook(self.__estimate_taylor_importance) 30 | else: 31 | self.__pruning_hook = None 32 | 33 | def __estimate_taylor_importance(self, _, grad_input, grad_output): 34 | # skip dim=1, its the dim for depth 35 | n_batch, _, n_x, n_y = self.__recent_activations.size() 36 | n_parameters = n_batch * n_x * n_y 37 | 38 | estimates = self.__recent_activations.mul_(grad_output[0]) \ 39 | .sum(dim=3) \ 40 | .sum(dim=2) \ 41 | .sum(dim=0) \ 42 | .div_(n_parameters) 43 | 44 | # normalization 45 | self.taylor_estimates = torch.abs(estimates) / torch.sqrt(torch.sum(estimates * estimates)) 46 | del estimates, self.__recent_activations 47 | self.__recent_activations = None 48 | 49 | def prune_feature_map(self, map_index): 50 | is_cuda = self.weight.is_cuda 51 | 52 | indices = Variable(torch.LongTensor([i for i in range(self.out_channels) if i != map_index])) 53 | indices = indices.cuda() if is_cuda else indices 54 | 55 | self.weight = nn.Parameter(self.weight.index_select(0, indices).data) 56 | self.bias = nn.Parameter(self.bias.index_select(0, indices).data) 57 | self.out_channels -= 1 58 | 59 | def drop_input_channel(self, index): 60 | """ 61 | Use when a convnet earlier in the chain is pruned. 
Reduces input channel count 62 | :param index: 63 | :return: 64 | """ 65 | is_cuda = self.weight.is_cuda 66 | 67 | indices = Variable(torch.LongTensor([i for i in range(self.in_channels) if i != index])) 68 | indices = indices.cuda() if is_cuda else indices 69 | 70 | self.weight = nn.Parameter(self.weight.index_select(1, indices).data) 71 | self.in_channels -= 1 72 | 73 | 74 | class PLinear(nn.Linear): 75 | 76 | def drop_inputs(self, input_shape, index, dim=0): 77 | """ 78 | Previous layer is expected to be a convnet which just underwent pruning 79 | Drop cells connected to the pruned layer of the convnet 80 | :param input_shape: shape of inputs before flattening, should exclude batch_size 81 | :param index: index to drop 82 | :param dim: dimension where index is dropped, w.r.t input_shape 83 | :return: 84 | """ 85 | is_cuda = self.weight.is_cuda 86 | 87 | reshaped = self.weight.view(-1, *input_shape) 88 | dim_length = input_shape[dim] 89 | indices = Variable(torch.LongTensor([i for i in range(dim_length) if i != index])) 90 | indices = indices.cuda() if is_cuda else indices 91 | 92 | self.weight = nn.Parameter( 93 | reshaped.index_select(dim+1, indices) 94 | .data 95 | .view(self.out_features, -1) 96 | ) 97 | self.in_features = self.weight.size()[1] 98 | 99 | class PPReLU(nn.PReLU): 100 | def drop_inputs(self, index): 101 | is_cuda = self.weight.is_cuda 102 | # choose indexs not equal to index to drop 103 | indices = Variable(torch.LongTensor([i for i in range(self.num_parameters) if i != index])) 104 | indices = indices.cuda() if is_cuda else indices 105 | self.weight = nn.Parameter(self.weight.index_select(0, indices).data) 106 | self.num_parameters -= 1 107 | 108 | class PBatchNorm2d(nn.BatchNorm2d): 109 | 110 | def drop_input_channel(self, index): 111 | if self.affine: 112 | is_cuda = self.weight.is_cuda 113 | indices = Variable(torch.LongTensor([i for i in range(self.num_features) if i != index])) 114 | indices = indices.cuda() if is_cuda else indices 115 | 116 | self.weight = nn.Parameter(self.weight.index_select(0, indices).data) 117 | self.bias = nn.Parameter(self.bias.index_select(0, indices).data) 118 | self.running_mean = self.running_mean.index_select(0, indices.data) 119 | self.running_var = self.running_var.index_select(0, indices.data) 120 | 121 | self.num_features -= 1 122 | 123 | -------------------------------------------------------------------------------- /prunning/utils/prune_mtcnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Wed Aug 28 10:42:25 2019 3 | 4 | @author: xingyu 5 | """ 6 | 7 | import torch 8 | import numpy as np 9 | 10 | def replace_layers(model, i, indexes, layers): 11 | 12 | """ 13 | replace conv layers of model.feature 14 | 15 | :param model: 16 | :param i: index of model.feature 17 | :param indexes: array of indexes of layers to be replaced 18 | :param layers: array of new layers to replace 19 | :return: model with replaced layers 20 | """ 21 | if i in indexes: 22 | return layers[indexes.index(i)] 23 | return model[i] 24 | 25 | def prune_Conv2d(conv, filter_index, Next=False, use_cuda = True): 26 | 27 | """ 28 | :param conv: conv layer to be pruned 29 | :param filter_index: filter index to be pruned 30 | :param Next: False: the conv to be pruned by reconstructing the out_channel, 31 | True: represent the next conv to be pruned by reconstructing the input_channel 32 | :param use_cuda: 33 | :return: 34 | """ 35 | 36 | if Next: 37 | new_conv = \ 38 | 
torch.nn.Conv2d(in_channels=conv.in_channels - len(filter_index), \ 39 | out_channels=conv.out_channels, \ 40 | kernel_size=conv.kernel_size, \ 41 | stride=conv.stride, 42 | padding=conv.padding, 43 | dilation=conv.dilation, 44 | groups=conv.groups, 45 | bias=(conv.bias is not None)) 46 | 47 | old_weights = conv.weight.data.cpu().numpy() # i.e. (512, 512, 3, 3) 48 | new_weights = np.delete(old_weights, filter_index, axis=1) # i.e. (512, 511, 3, 3) 49 | new_conv.weight.data = torch.from_numpy(new_weights) 50 | if use_cuda: 51 | new_conv.weight.data = new_conv.weight.data.cuda() 52 | 53 | new_conv.bias.data = conv.bias.data # bias is not changed 54 | else: 55 | new_conv = \ 56 | torch.nn.Conv2d(in_channels=conv.in_channels, \ 57 | out_channels=conv.out_channels - len(filter_index), 58 | kernel_size=conv.kernel_size, \ 59 | stride=conv.stride, 60 | padding=conv.padding, 61 | dilation=conv.dilation, 62 | groups=conv.groups, 63 | bias=(conv.bias is not None)) 64 | 65 | old_weights = conv.weight.data.cpu().numpy() # i.e. (512, 512, 3, 3) 66 | new_weights = np.delete(old_weights, filter_index, axis=0) # i.e. (511, 512, 3, 3) 67 | new_conv.weight.data = torch.from_numpy(new_weights) 68 | if use_cuda: 69 | new_conv.weight.data = new_conv.weight.data.cuda() 70 | 71 | bias_numpy = conv.bias.data.cpu().numpy() # i.e. (512,) 72 | bias = np.delete(bias_numpy, filter_index) # i.e. (511,) 73 | new_conv.bias.data = torch.from_numpy(bias) 74 | if use_cuda: 75 | new_conv.bias.data = new_conv.bias.data.cuda() 76 | 77 | return new_conv 78 | 79 | def prune_PReLu(prelu, filter_index, use_cuda=True): 80 | # prune PReLu 81 | new_prelu = torch.nn.PReLU(num_parameters=prelu.num_parameters - len(filter_index)) 82 | old_weights = prelu.weight.data.cpu().numpy() 83 | new_weights = np.delete(old_weights, filter_index) 84 | new_prelu.weight.data = torch.from_numpy(new_weights) 85 | if use_cuda: 86 | new_prelu.weight.data = new_prelu.weight.data.cuda() 87 | 88 | return new_prelu 89 | 90 | def prune_linear(linear_layer, conv, filter_index, use_cuda=True): 91 | # prune fully connected layer which is the next to the to-be-pruned conv layer 92 | params_per_input_channel = linear_layer.in_features // conv.out_channels 93 | new_linear_layer = torch.nn.Linear(linear_layer.in_features - len(filter_index) * params_per_input_channel, 94 | linear_layer.out_features, bias=(linear_layer.bias is not None)) 95 | 96 | old_weights = linear_layer.weight.data.cpu().numpy() # i.e. (4096, 25088) (out_feature x in_feature) 97 | 98 | delete_array = [] 99 | for filter in filter_index: 100 | delete_array += [filter * params_per_input_channel + x for x in range(params_per_input_channel)] 101 | new_weights = np.delete(old_weights, delete_array, axis=1) # i.e. 
(4096, 25039) 102 | 103 | new_linear_layer.weight.data = torch.from_numpy(new_weights) 104 | 105 | if linear_layer.bias is not None: 106 | new_linear_layer.bias.data = linear_layer.bias.data 107 | if use_cuda: 108 | new_linear_layer.weight.data = new_linear_layer.weight.data.cuda() 109 | 110 | return new_linear_layer 111 | 112 | 113 | def prune_mtcnn(model, layer_index, *filter_index, use_cuda=False): 114 | 115 | _, conv = list(model.features._modules.items())[layer_index] 116 | _, prelu = list(model.features._modules.items())[layer_index+1] 117 | next_conv = None 118 | offset = 1 119 | 120 | if len(filter_index) >= conv.out_channels: 121 | raise BaseException("Cannot prune the whole conv layer") 122 | 123 | while layer_index + offset < len(model.features._modules.items()): 124 | res = list(model.features._modules.items())[layer_index + offset] 125 | if isinstance(res[1], torch.nn.modules.conv.Conv2d): 126 | next_name, next_conv = res 127 | break 128 | offset = offset + 1 129 | 130 | # The new conv layer constructed as follow: 131 | new_conv = prune_Conv2d(conv, filter_index, Next=False, use_cuda = use_cuda) 132 | 133 | # The new PReLU layer constructed as follow: 134 | new_PReLU = prune_PReLu(prelu, filter_index, use_cuda = use_cuda) 135 | 136 | # next conv layer needs to be reconstructed 137 | if not next_conv is None: 138 | next_new_conv = prune_Conv2d(next_conv, filter_index, Next=True, use_cuda = use_cuda) 139 | 140 | features = torch.nn.Sequential( 141 | *(replace_layers(model.features, i, [layer_index, layer_index+1, layer_index + offset], \ 142 | [new_conv, new_PReLU, next_new_conv]) for i, _ in enumerate(model.features))) 143 | del model.features # reset 144 | del conv # reset 145 | 146 | model.features = features 147 | 148 | else: 149 | 150 | linear_layer = None 151 | offset = 1 152 | while layer_index + offset < len(model.features._modules.items()): 153 | res = list(model.features._modules.items())[layer_index + offset] 154 | if isinstance(res[1], torch.nn.Linear): 155 | layer_name, linear_layer = res 156 | break 157 | offset = offset + 1 158 | 159 | if not linear_layer is None: 160 | new_linear_layer = prune_linear(linear_layer, conv, filter_index, use_cuda = use_cuda) 161 | 162 | features = torch.nn.Sequential( 163 | *(replace_layers(model.features, i, [layer_index, layer_index+1, layer_index + offset], \ 164 | [new_conv, new_PReLU, new_linear_layer]) for i, _ in enumerate(model.features))) 165 | 166 | del model.features # reset 167 | del conv # reset 168 | 169 | model.features = features 170 | 171 | else: 172 | model.features = torch.nn.Sequential( 173 | *(replace_layers(model.features, i, [layer_index, layer_index+1], \ 174 | [new_conv, new_PReLU]) for i, _ in enumerate(model.features))) 175 | 176 | if hasattr(model, 'conv4_1') and hasattr(model, 'conv4_2'): 177 | 178 | conv4_1, conv4_2 = model.conv4_1, model.conv4_2 179 | # new conv4_1 and conv4_2 layer need to be reconstructed 180 | new_conv4_1 = prune_Conv2d(conv4_1, filter_index, Next=True, use_cuda = use_cuda) 181 | new_conv4_2 = prune_Conv2d(conv4_2, filter_index, Next=True, use_cuda = use_cuda) 182 | 183 | del model.conv4_1 184 | del model.conv4_2 185 | 186 | model.conv4_1 = new_conv4_1 187 | model.conv4_2 = new_conv4_2 188 | 189 | return model 190 | 191 | if __name__ == '__main__': 192 | 193 | import sys 194 | sys.path.append("../Base_Model") 195 | 196 | from MTCNN_nets import PNet, RNet, ONet 197 | 198 | model = ONet() 199 | model.train() 200 | 201 | layer_index = 9 202 | filter_index = (2,4) 203 | 204 | model = 
prune_mtcnn(model, layer_index, *filter_index, use_cuda=False) 205 | 206 | print(model) 207 | 208 | 209 | 210 | -------------------------------------------------------------------------------- /prunning/utils/util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | class ListDataset(Dataset): 7 | def __init__(self, list_path): 8 | with open(list_path, 'r') as file: 9 | self.img_files = file.readlines() 10 | 11 | def __len__(self): 12 | return len(self.img_files) 13 | 14 | def __getitem__(self, index): 15 | 16 | annotation = self.img_files[index % len(self.img_files)].strip().split(' ') 17 | #--------- 18 | # Image 19 | #--------- 20 | img = cv2.imread(annotation[0]) 21 | img = img[:,:,::-1] 22 | img = np.asarray(img, 'float32') 23 | img = img.transpose((2, 0, 1)) 24 | img = (img - 127.5) * 0.0078125 25 | input_img = torch.FloatTensor(img) 26 | #--------- 27 | # Label 28 | #--------- 29 | label = int(annotation[1]) 30 | bbox_target = np.zeros((4,)) 31 | landmark = np.zeros((10,)) 32 | 33 | if len(annotation[2:]) == 4: 34 | bbox_target = np.array(annotation[2:6]).astype(float) 35 | if len(annotation[2:]) == 14: 36 | bbox_target = np.array(annotation[2:6]).astype(float) 37 | landmark = np.array(annotation[6:]).astype(float) 38 | 39 | sample = {'input_img': input_img, 'label': label, 'bbox_target': bbox_target, 'landmark': landmark} 40 | 41 | return sample 42 | 43 | # Print iterations progress 44 | def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '█'): 45 | """ 46 | Call in a loop to create terminal progress bar 47 | @params: 48 | iteration - Required : current iteration (Int) 49 | total - Required : total iterations (Int) 50 | prefix - Optional : prefix string (Str) 51 | suffix - Optional : suffix string (Str) 52 | decimals - Optional : positive number of decimals in percent complete (Int) 53 | length - Optional : character length of bar (Int) 54 | fill - Optional : bar fill character (Str) 55 | """ 56 | percent = ("{0:." 
+ str(decimals) + "f}").format(100 * (iteration / float(total))) 57 | filledLength = int(length * iteration // total) 58 | bar = fill * filledLength + '-' * (length - filledLength) 59 | print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end = '\r') 60 | # Print New Line on Complete 61 | if iteration == total: 62 | print() 63 | 64 | def total_num_filters(model): 65 | filters = 0 66 | for name, module in model.features._modules.items(): 67 | if isinstance(module, torch.nn.modules.conv.Conv2d): 68 | filters = filters + module.out_channels 69 | return filters -------------------------------------------------------------------------------- /quantization-aware_training/pnet/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | import os 4 | import argparse 5 | 6 | import torch 7 | from torchvision import transforms 8 | 9 | from tools.dataset import FaceDataset 10 | from nets.quant_mtcnn import PNet 11 | from training.pnet.trainer import PNetTrainer 12 | from checkpoint import CheckPoint 13 | import config 14 | 15 | # Set device 16 | use_cuda = config.USE_CUDA and torch.cuda.is_available() 17 | # torch.cuda.manual_seed(train_config.manualSeed) 18 | device = torch.device("cuda:1" if use_cuda else "cpu") 19 | torch.backends.cudnn.benchmark = True 20 | 21 | # Set dataloader 22 | kwargs = {'num_workers': 8, 'pin_memory': False} if use_cuda else {} 23 | train_data = FaceDataset(os.path.join(config.ANNO_PATH, config.PNET_TRAIN_IMGLIST_FILENAME)) 24 | val_data = FaceDataset(os.path.join(config.ANNO_PATH, config.PNET_VAL_IMGLIST_FILENAME)) 25 | dataloaders = {'train': torch.utils.data.DataLoader(train_data, 26 | batch_size=config.BATCH_SIZE, shuffle=True, **kwargs), 27 | 'val': torch.utils.data.DataLoader(val_data, 28 | batch_size=config.BATCH_SIZE, shuffle=True, **kwargs) 29 | } 30 | 31 | # Set model 32 | model = PNet(is_train=True) 33 | model = model.to(device) 34 | model.load_state_dict(torch.load('pretrained_weights/quant_mtcnn/best_pnet.pth'), strict=True) 35 | 36 | # Set checkpoint 37 | #checkpoint = CheckPoint(train_config.save_path) 38 | 39 | # Set optimizer 40 | optimizer = torch.optim.Adam(model.parameters(), lr=config.LR) 41 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=config.STEPS, gamma=0.1) 42 | 43 | # Set trainer 44 | trainer = PNetTrainer(config.EPOCHS, dataloaders, model, optimizer, scheduler, device) 45 | 46 | trainer.train() 47 | 48 | #checkpoint.save_model(model, index=epoch, tag=config.SAVE_PREFIX) 49 | 50 | -------------------------------------------------------------------------------- /quantization-aware_training/pnet/trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | 4 | import torch 5 | 6 | from loss import Loss 7 | from tools.average_meter import AverageMeter 8 | 9 | 10 | class PNetTrainer(object): 11 | 12 | def __init__(self, epochs, dataloaders, model, optimizer, scheduler, device): 13 | self.epochs = epochs 14 | self.dataloaders = dataloaders 15 | self.model = model 16 | self.optimizer = optimizer 17 | self.scheduler = scheduler 18 | self.device = device 19 | self.lossfn = Loss(self.device) 20 | 21 | # save best model 22 | self.best_val_loss = 100 23 | 24 | def compute_accuracy(self, prob_cls, gt_cls): 25 | # we only need the detection which >= 0 26 | prob_cls = torch.squeeze(prob_cls) 27 | mask = torch.ge(gt_cls, 0) 28 | 29 | # get valid elements 30 | valid_gt_cls = 
gt_cls[mask] 31 | valid_prob_cls = prob_cls[mask] 32 | size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0]) 33 | 34 | # get max index with softmax layer 35 | _, valid_pred_cls = torch.max(valid_prob_cls, dim=1) 36 | 37 | right_ones = torch.eq(valid_pred_cls.float(), valid_gt_cls.float()).float() 38 | 39 | return torch.div(torch.mul(torch.sum(right_ones), float(1.0)), float(size)) 40 | 41 | def train(self): 42 | for epoch in range(self.epochs): 43 | self.train_epoch(epoch, 'train') 44 | self.train_epoch(epoch, 'val') 45 | 46 | 47 | def train_epoch(self, epoch, phase): 48 | cls_loss_ = AverageMeter() 49 | bbox_loss_ = AverageMeter() 50 | total_loss_ = AverageMeter() 51 | accuracy_ = AverageMeter() 52 | if phase == 'train': 53 | self.model.train() 54 | else: 55 | self.model.eval() 56 | 57 | for batch_idx, sample in enumerate(self.dataloaders[phase]): 58 | data = sample['input_img'] 59 | gt_cls = sample['cls_target'] 60 | gt_bbox = sample['bbox_target'] 61 | data, gt_cls, gt_bbox = data.to(self.device), gt_cls.to( 62 | self.device), gt_bbox.to(self.device).float() 63 | 64 | self.optimizer.zero_grad() 65 | with torch.set_grad_enabled(phase == 'train'): 66 | pred_cls, pred_bbox = self.model(data) 67 | 68 | # compute the cls loss and bbox loss and weighted them together 69 | cls_loss = self.lossfn.cls_loss(gt_cls, pred_cls) 70 | bbox_loss = self.lossfn.box_loss(gt_cls, gt_bbox, pred_bbox) 71 | total_loss = cls_loss + 5*bbox_loss 72 | 73 | # compute clssification accuracy 74 | accuracy = self.compute_accuracy(pred_cls, gt_cls) 75 | 76 | if phase == 'train': 77 | total_loss.backward() 78 | self.optimizer.step() 79 | 80 | cls_loss_.update(cls_loss, data.size(0)) 81 | bbox_loss_.update(bbox_loss, data.size(0)) 82 | total_loss_.update(total_loss, data.size(0)) 83 | accuracy_.update(accuracy, data.size(0)) 84 | 85 | if batch_idx % 40 == 0: 86 | print('{} Epoch: {} [{:08d}/{:08d} ({:02.0f}%)]\tLoss: {:.6f} cls Loss: {:.6f} offset Loss:{:.6f}\tAccuracy: {:.6f} LR:{:.7f}'.format( 87 | phase, epoch, batch_idx * len(data), len(self.dataloaders[phase].dataset), 88 | 100. 
* batch_idx / len(self.dataloaders[phase]), total_loss.item(), cls_loss.item(), bbox_loss.item(), accuracy.item(), self.optimizer.param_groups[0]['lr'])) 89 | 90 | if phase == 'train': 91 | self.scheduler.step() 92 | 93 | print("{} epoch Loss: {:.6f} cls Loss: {:.6f} bbox Loss: {:.6f} Accuracy: {:.6f}".format( 94 | phase, total_loss_.avg, cls_loss_.avg, bbox_loss_.avg, accuracy_.avg)) 95 | 96 | if phase == 'val' and total_loss_.avg < self.best_val_loss: 97 | self.best_val_loss = total_loss_.avg 98 | torch.save(self.model.state_dict(), './pretrained_weights/quant_mtcnn/best_pnet.pth') 99 | 100 | return cls_loss_.avg, bbox_loss_.avg, total_loss_.avg, accuracy_.avg 101 | 102 | -------------------------------------------------------------------------------- /test/results/fddb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/digital-nomad-cheng/MTCNN_PyTorch/752b9f0f32f5c25df5647c7c279c24715dc3aa87/test/results/fddb.png -------------------------------------------------------------------------------- /test/test_FDDB.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append('.') 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | from tools.utils import * 8 | from MTCNN import MTCNNDetector 9 | 10 | if __name__ == "__main__": 11 | base_model_path = './pretrained_weights/mtcnn' 12 | mtcnn_detector = MTCNNDetector( 13 | p_model_path=os.path.join(base_model_path, 'best_pnet.pth'), 14 | r_model_path=os.path.join(base_model_path, 'best_rnet.pth'), 15 | o_model_path=os.path.join(base_model_path, 'best_onet.pth'), 16 | threshold=[0.7, 0.8, 0.9] 17 | ) 18 | fddb_path = './data/FDDB' 19 | for i in range(1, 11): 20 | with open(os.path.join(fddb_path, 'FDDB-folds/imgpath', 'FDDB-fold-{:02d}.txt'.format(i)), 'r') as f: 21 | lines = f.readlines() 22 | for line in lines: 23 | image_path = line.strip() + '.jpg' 24 | print(image_path) 25 | image = cv2.imread(os.path.join(fddb_path, image_path)) 26 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 27 | bboxes = mtcnn_detector.detect_face(image) 28 | for i in range(bboxes.shape[0]): 29 | x0, y0, x1, y1 = bboxes[i, :4] 30 | cv2.rectangle(image, (int(x0), int(y0)), (int(x1), int(y1)), (0, 255, 255), 1) 31 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 32 | cv2.imshow('image', image) 33 | cv2.waitKey(0) 34 | -------------------------------------------------------------------------------- /test/test_image.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append('.') 3 | 4 | import cv2 5 | import numpy as np 6 | 7 | from tools.utils import * 8 | from MTCNN import MTCNNDetector 9 | #from MTCNN_debug import MTCNNDetector 10 | 11 | if __name__ == "__main__": 12 | base_model_path = './pretrained_weights/mtcnn' 13 | mtcnn_detector = MTCNNDetector( 14 | p_model_path=os.path.join(base_model_path, 'best_pnet.pth'), 15 | r_model_path=os.path.join(base_model_path, 'best_rnet.pth'), 16 | o_model_path=os.path.join(base_model_path, 'best_onet_landmark.pth'), 17 | min_face_size=40, 18 | threshold=[0.7, 0.8, 0.9] 19 | ) 20 | 21 | #image_path = './data/user/images' 22 | image_path='./test' 23 | images = [f for f in os.listdir(image_path) if f.endswith('.jpg')] 24 | 25 | for image in images: 26 | image = cv2.imread(os.path.join(image_path, image), 1) 27 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 28 | bboxes, landmarks = mtcnn_detector.detect_face(image) 29 | for i in 
range(bboxes.shape[0]): 30 | x0, y0, x1, y1 = bboxes[i, :4] 31 | width = int(x1 - x0 + 1) 32 | height = int(y1 - y0 + 1) 33 | cv2.rectangle(image, (int(x0), int(y0)), (int(x1), int(y1)), (0, 255, 255), 1) 34 | for j in range(5): 35 | x, y = int(x0 + landmarks[i, j]*width), int(y0 + landmarks[i, j+5]*height) 36 | cv2.circle(image, (x, y), 2, (255, 0, 255), 2) 37 | 38 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 39 | cv2.imshow('image', image) 40 | cv2.waitKey(0) 41 | -------------------------------------------------------------------------------- /test/test_video.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os, sys 3 | sys.path.append('.') 4 | import pathlib 5 | import logging 6 | 7 | import cv2 8 | import torch 9 | 10 | from MTCNN import MTCNNDetector 11 | import config 12 | 13 | logger = logging.getLogger("app") 14 | formatter = logging.Formatter('%(asctime)s %(levelname)-8s: %(message)s') 15 | console_handler = logging.StreamHandler(sys.stdout) 16 | logger.addHandler(console_handler) 17 | logger.setLevel(logging.INFO) 18 | console_handler.formatter = formatter # the formatter could also be assigned to the handler directly 19 | 20 | 21 | def draw_images(img, bboxs): # draw face bounding boxes on the image 22 | num_face = bboxs.shape[0] 23 | for i in range(num_face): 24 | cv2.rectangle(img, (int(bboxs[i, 0]), int(bboxs[i, 1])), (int( 25 | bboxs[i, 2]), int(bboxs[i, 3])), (0, 255, 0), 3) 26 | return img 27 | 28 | 29 | if __name__ == '__main__': 30 | base_model_path = './pretrained_weights/mtcnn' 31 | mtcnn_detector = MTCNNDetector( 32 | p_model_path=os.path.join(base_model_path, 'best_pnet.pth'), 33 | r_model_path=os.path.join(base_model_path, 'best_rnet.pth'), 34 | o_model_path=os.path.join(base_model_path, 'best_onet.pth'), 35 | min_face_size=24, threshold=[0.7, 0.8, 0.9], use_cuda=False) 36 | logger.info("Init the MtcnnDetector.") 37 | 38 | cap = cv2.VideoCapture('./test/test_video.mov') 39 | if not cap.isOpened(): 40 | print("Failed to open capture from file") 41 | start = time.time() 42 | num = 0 43 | while cap.isOpened(): 44 | ret, img = cap.read() 45 | logger.info("Start to process No.{} image.".format(num)) 46 | RGB_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 47 | #RGB_image = cv2.medianBlur(RGB_image, 5) 48 | bboxs = mtcnn_detector.detect_face(RGB_image) 49 | img = draw_images(img, bboxs) 50 | cv2.imshow('frame', img) 51 | cv2.waitKey(1) 52 | num += 1 53 | logger.info("Finish all the images.") 54 | logger.info("Elapsed time: {:.3f}s".format(time.time() - start)) 55 | -------------------------------------------------------------------------------- /tools/average_meter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision.transforms as transforms 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value""" 7 | 8 | def __init__(self): 9 | self.reset() 10 | 11 | def reset(self): 12 | """ 13 | reset all parameters 14 | """ 15 | self.val = 0 16 | self.avg = 0 17 | self.sum = 0 18 | self.count = 0 19 | 20 | def update(self, val, n=1): 21 | """ 22 | update parameters 23 | """ 24 | self.val = val 25 | self.sum += val * n 26 | self.count += n 27 | self.avg = self.sum / self.count 28 | 29 | 30 | transform = transforms.Compose([ 31 | transforms.ToTensor(), 32 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 33 | ]) 34 | 35 | def convert_image_to_tensor(image): 36 | """convert an image to pytorch tensor 37 | 38 | Parameters: 39 | ---------- 40 | image:
numpy array, h * w * c 41 | 42 | Returns: 43 | ------- 44 | image_tensor: pytorch.FloatTensor, c * h * w 45 | """ 46 | 47 | return transform(image) 48 | 49 | 50 | def convert_chwTensor_to_hwcNumpy(tensor): 51 | """convert a group images pytorch tensor(count * c * h * w) to numpy array images(count * h * w * c) 52 | Parameters: 53 | ---------- 54 | tensor: pytorch FloatTensor, count * c * h * w 55 | 56 | Returns: 57 | ------- 58 | numpy array images: count * h * w * c 59 | """ 60 | 61 | if isinstance(tensor, torch.FloatTensor): 62 | return np.transpose(tensor.detach().numpy(), (0, 2, 3, 1)) 63 | else: 64 | raise Exception( 65 | "convert b*c*h*w tensor to b*h*w*c numpy error. This tensor must have 4 dimensions of float data type.") 66 | -------------------------------------------------------------------------------- /tools/dataset.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append('.') 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | import config 10 | class FaceDataset(Dataset): 11 | '''Dataset class for MTCNN face detector''' 12 | def __init__(self, annotation_path): 13 | with open(annotation_path, 'r') as f: 14 | self.img_files = f.readlines() 15 | 16 | def __len__(self): 17 | return len(self.img_files) 18 | 19 | def __getitem__(self, index): 20 | annotation = self.img_files[index % len(self.img_files)].strip().split(' ') 21 | img = cv2.imread(annotation[0], 1) 22 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 23 | img = np.asarray(img, 'float32') 24 | img = np.transpose(img, (2, 0, 1)) 25 | img = (img - 127.5) / 127.5 # rescale pixel value to between -1 and 1 26 | input_img = torch.FloatTensor(img) 27 | 28 | cls_target = int(annotation[1]) 29 | bbox_target = np.zeros((4,)) 30 | landmark_target = np.zeros((10,)) 31 | 32 | if len(annotation[2:]) == 4: 33 | bbox_target = np.array(annotation[2:6]).astype(float) 34 | if len(annotation[2:]) == 14: 35 | bbox_target = np.array(annotation[2:6]).astype(float) 36 | landmark_target = np.array(annotation[6:]).astype(float) 37 | sample = {'input_img': input_img, 38 | 'cls_target': cls_target, 39 | 'bbox_target': bbox_target, 40 | 'landmark_target': landmark_target} 41 | return sample 42 | 43 | if __name__ == '__main__': 44 | 45 | 46 | 47 | train_data = FaceDataset(os.path.join(config.ANNO_PATH, config.PNET_TRAIN_IMGLIST_FILENAME)) 48 | val_data = FaceDataset(os.path.join(config.ANNO_PATH, config.PNET_VAL_IMGLIST_FILENAME)) 49 | dataloaders = {'train': torch.utils.data.DataLoader(train_data, 50 | batch_size=config.BATCH_SIZE, shuffle=True), 51 | 'val': torch.utils.data.DataLoader(val_data, 52 | batch_size=config.BATCH_SIZE, shuffle=True) 53 | } 54 | 55 | for batch_idx, sample_batched in enumerate(dataloaders['train']): 56 | images_batch, cls_batch, bbox_batch, landmark_batch = \ 57 | sample_batched['input_img'], sample_batched['cls_target'], sample_batched['bbox_target'], sample_batched['landmark_target'] 58 | 59 | print(batch_idx, images_batch.shape, cls_batch.shape, bbox_batch.shape, landmark_batch.shape) 60 | 61 | if batch_idx == 3: 62 | break 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /tools/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision.transforms as transforms 4 | 5 | 6 | def nms(bboxes, overlap_threshold=0.5, mode='Union'): 7 | """Non-maximum suppression.
8 | Parameters: 9 | ----------- 10 | bboxes: a float numpy array of shape [n, 5], 11 | where each row is (xmin, ymin, xmax, ymax, score). 12 | overlap_threshold: a float number. 13 | mode: 'Union' or 'Minimum'. 14 | 'Union': intersection over union of two bboxes 15 | 'Minimum': intersection over minimum of two bboxes 16 | Returns: 17 | -------- 18 | list with indices of the selected bboxes 19 | """ 20 | 21 | # if there are no bboxes, return the empty list 22 | if len(bboxes) == 0: 23 | return [] 24 | 25 | # list of picked indices 26 | pick = [] 27 | 28 | # grab the coordinates of bboxes 29 | x1, y1, x2, y2, score = [bboxes[:, i] for i in range(5)] 30 | 31 | area = (x2 - x1 + 1.0)*(y2 - y1 + 1.0) 32 | ids = np.argsort(score) # index increasing order sorted using score 33 | 34 | while len(ids) > 0: 35 | 36 | # grab index of the largest value 37 | last = len(ids) - 1 38 | i = ids[last] 39 | pick.append(i) 40 | 41 | # compute intersections of the bbox with the largest score 42 | # with the rest of bboxes 43 | 44 | # left top corner of intersection bboxes 45 | ix1 = np.maximum(x1[i], x1[ids[:last]]) 46 | iy1 = np.maximum(y1[i], y1[ids[:last]]) 47 | 48 | # right bottom corner of intersection bboxes 49 | ix2 = np.minimum(x2[i], x2[ids[:last]]) 50 | iy2 = np.minimum(y2[i], y2[ids[:last]]) 51 | 52 | # width and height of intersection bboxes 53 | w = np.maximum(0.0, ix2 - ix1 + 1.0) 54 | h = np.maximum(0.0, iy2 - iy1 + 1.0) 55 | 56 | # intersections' areas 57 | inter = w * h 58 | if mode == 'Minimum': 59 | overlap = inter/np.minimum(area[i], area[ids[:last]]) 60 | elif mode == 'Union': 61 | # intersection over union (IoU) 62 | overlap = inter/(area[i] + area[ids[:last]] - inter) 63 | 64 | # delete all bboxes where overlap is too big 65 | ids = np.delete( 66 | ids, 67 | np.concatenate([[last], np.where(overlap > overlap_threshold)[0]]) 68 | ) 69 | 70 | return pick 71 | 72 | def calibrate_box(bboxes, offsets): 73 | """Transform bounding boxes to be more like true bounding boxes. 74 | 'offsets' is one of the outputs of the nets. 75 | 76 | Parameters: 77 | ----------- 78 | bboxes: a float numpy array of shape [n, 5]. 79 | offsets: a float numpy array of shape [n, 4]. 80 | 81 | Returns: 82 | -------- 83 | a float numpy array of shape [n, 5]. 84 | """ 85 | x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] 86 | w = x2 - x1 + 1.0 87 | h = y2 - y1 + 1.0 88 | w = np.expand_dims(w, 1) 89 | h = np.expand_dims(h, 1) 90 | 91 | # this is what happening here: 92 | # tx1, ty1, tx2, ty2 = [offsets[:, i] for i in range(4)] 93 | # x1_true = x1 + tx1*w 94 | # y1_true = y1 + ty1*h 95 | # x2_true = x2 + tx2*w 96 | # y2_true = y2 + ty2*h 97 | # below is just more compact form of this 98 | 99 | # are offsets always such that 100 | # x1 < x2 and y1 < y2 ? 101 | 102 | translation = np.hstack([w, h, w, h])*offsets 103 | bboxes[:, 0:4] = bboxes[:, 0:4] + translation 104 | return bboxes 105 | 106 | def convert_to_square(bboxes): 107 | """Convert bounding boxes to a square form. 108 | 109 | Parameters: 110 | ----------- 111 | bboxes: a float numpy array of shape [n, 5]. 112 | 113 | Returns: 114 | -------- 115 | a float numpy array of shape [n, 5], squared bounding boxes. 
116 | """ 117 | 118 | square_bboxes = np.zeros_like(bboxes) 119 | x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] 120 | h = y2 - y1 + 1.0 121 | w = x2 - x1 + 1.0 122 | max_side = np.maximum(h, w) 123 | square_bboxes[:, 0] = x1 + w*0.5 - max_side*0.5 124 | square_bboxes[:, 1] = y1 + h*0.5 - max_side*0.5 125 | square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0 126 | square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0 127 | return square_bboxes 128 | 129 | def correct_bboxes(bboxes, width, height): 130 | """Crop boxes that are too big and get coordinates with respect to cutouts. 131 | 132 | Parameters: 133 | ----------- 134 | bboxes: a float numpy array of shape [n, 5], 135 | where each row is (xmin, ymin, xmax, ymax, score). 136 | width: a float number. 137 | height: a float number. 138 | 139 | Returns: 140 | -------- 141 | dy, dx, edy, edx: a int numpy arrays of shape [n], 142 | coordinates of the boxes with respect to the cutouts. 143 | y, x, ey, ex: a int numpy arrays of shape [n], 144 | corrected ymin, xmin, ymax, xmax. 145 | h, w: a int numpy arrays of shape [n], 146 | just heights and widths of boxes. 147 | 148 | in the following order: 149 | [dy, edy, dx, edx, y, ey, x, ex, w, h]. 150 | """ 151 | 152 | x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] 153 | x2,y2 = np.clip(x2, x1, None), np.clip(y2, y1, None) 154 | w, h = x2 - x1 + 1.0, y2 - y1 + 1.0 155 | num_boxes = bboxes.shape[0] 156 | 157 | # 'e' stands for end 158 | # (x, y) -> (ex, ey) 159 | x, y, ex, ey = x1, y1, x2, y2 160 | 161 | # we need to cut out a box from the image. 162 | # (x, y, ex, ey) are corrected coordinates of the box 163 | # in the image. 164 | # (dx, dy, edx, edy) are coordinates of the box in the cutout 165 | # from the image. 166 | dx, dy = np.zeros((num_boxes,)), np.zeros((num_boxes,)) 167 | edx, edy = w.copy() - 1.0, h.copy() - 1.0 168 | 169 | # if box's bottom right corner is too far right 170 | ind = np.where(ex > width - 1.0)[0] 171 | edx[ind] = w[ind] + width - 2.0 - ex[ind] 172 | ex[ind] = width - 1.0 173 | 174 | # if box's bottom right corner is too low 175 | ind = np.where(ey > height - 1.0)[0] 176 | edy[ind] = h[ind] + height - 2.0 - ey[ind] 177 | ey[ind] = height - 1.0 178 | 179 | # if box's top left corner is too far left 180 | ind = np.where(x < 0.0)[0] 181 | dx[ind] = 0.0 - x[ind] 182 | x[ind] = 0.0 183 | 184 | # if box's top left corner is too high 185 | ind = np.where(y < 0.0)[0] 186 | dy[ind] = 0.0 - y[ind] 187 | y[ind] = 0.0 188 | 189 | return_list = [dy, edy, dx, edx, y, ey, x, ex, w, h] 190 | return_list = [i.astype('int32') for i in return_list] 191 | 192 | return return_list 193 | 194 | def IoU(box, boxes): 195 | """Compute IoU between detect box and gt boxes 196 | 197 | Parameters: 198 | ----------- 199 | box: numpy array , shape (5, ): x1, y1, x2, y2, score 200 | input box 201 | boxes: numpy array, shape (n, 4): x1, y1, x2, y2 202 | input ground truth boxes 203 | 204 | Returns: 205 | -------- 206 | ovr: numpy.array, shape (n, ) 207 | IoU 208 | """ 209 | # box = (x1, y1, x2, y2) 210 | box_area = (box[2] - box[0] + 1) * (box[3] - box[1] + 1) 211 | area = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1) 212 | 213 | # abtain the offset of the interception of union between crop_box and gt_box 214 | xx1 = np.maximum(box[0], boxes[:, 0]) 215 | yy1 = np.maximum(box[1], boxes[:, 1]) 216 | xx2 = np.minimum(box[2], boxes[:, 2]) 217 | yy2 = np.minimum(box[3], boxes[:, 3]) 218 | 219 | # compute the width and height of the bounding box 220 | w = np.maximum(0, xx2 - 
xx1 + 1) 221 | h = np.maximum(0, yy2 - yy1 + 1) 222 | 223 | inter = w * h 224 | ovr = inter / (box_area + area - inter) 225 | return ovr 226 | 227 | transform = transforms.Compose([ 228 | transforms.ToTensor(), 229 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 230 | ]) 231 | 232 | def convert_image_to_tensor(image): 233 | """Convert an image to pytorch tensor and do some preprocessing 234 | 235 | Parameters: 236 | ----------- 237 | image: numpy array , h * w * c 238 | 239 | Returns: 240 | -------- 241 | image_tensor: pytorch.FloatTensor, c * h * w 242 | """ 243 | 244 | return transform(image) 245 | 246 | def convert_chwTensor_to_hwcNumpy(tensor): 247 | """Convert a group images pytorch tensor(count * c * h * w) into 248 | numpy array images(count * h * w * c) 249 | 250 | Parameters: 251 | ----------- 252 | tensor: pytorch tensor array , count * c * h * w 253 | 254 | Returns: 255 | -------- 256 | numpy array images: count * h * w * c 257 | """ 258 | if isinstance(tensor, torch.FloatTensor): 259 | return np.transpose(tensor.detach().numpy(), (0, 2, 3, 1)) 260 | else: 261 | raise Exception( 262 | "covert b*c*h*w tensor to b*h*w*c numpy error.This tensor must have \ 263 | 4 dimension of float data type.") 264 | -------------------------------------------------------------------------------- /training/onet/landmark_train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | import os 4 | import argparse 5 | 6 | import torch 7 | from torchvision import transforms 8 | 9 | from tools.dataset import FaceDataset 10 | from nets.mtcnn import ONet 11 | from training.onet.landmark_trainer import ONetTrainer 12 | from checkpoint import CheckPoint 13 | import config 14 | 15 | # Set device 16 | use_cuda = config.USE_CUDA and torch.cuda.is_available() 17 | device = torch.device("cuda:1" if use_cuda else "cpu") 18 | torch.backends.cudnn.benchmark = True 19 | 20 | # Set dataloader 21 | kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {} 22 | train_data = FaceDataset(os.path.join(config.ANNO_PATH, config.ONET_TRAIN_IMGLIST_FILENAME)) 23 | val_data = FaceDataset(os.path.join(config.ANNO_PATH, config.ONET_VAL_IMGLIST_FILENAME)) 24 | dataloaders = {'train': torch.utils.data.DataLoader(train_data, 25 | batch_size=config.BATCH_SIZE, shuffle=True, **kwargs), 26 | 'val': torch.utils.data.DataLoader(val_data, 27 | batch_size=config.BATCH_SIZE, shuffle=True, **kwargs) 28 | } 29 | 30 | # Set model 31 | model = ONet(is_train=True, train_landmarks=True) 32 | model = model.to(device) 33 | model.load_state_dict(torch.load('pretrained_weights/mtcnn/best_onet.pth'), strict=True) 34 | print(model) 35 | 36 | # Set checkpoint 37 | #checkpoint = CheckPoint(train_config.save_path) 38 | 39 | # Set optimizer 40 | optimizer = torch.optim.Adam(model.parameters(), lr=config.LR) 41 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=config.STEPS, gamma=0.1) 42 | 43 | # Set trainer 44 | trainer = ONetTrainer(config.EPOCHS, dataloaders, model, optimizer, scheduler, device) 45 | 46 | trainer.train() 47 | 48 | #checkpoint.save_model(model, index=epoch, tag=config.SAVE_PREFIX) 49 | 50 | -------------------------------------------------------------------------------- /training/onet/landmark_trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | 4 | import torch 5 | 6 | import config 7 | from loss import Loss 8 | from tools.average_meter import 
AverageMeter 9 | 10 | 11 | class ONetTrainer(object): 12 | 13 | def __init__(self, epochs, dataloaders, model, optimizer, scheduler, device): 14 | self.epochs = epochs 15 | self.dataloaders = dataloaders 16 | self.model = model 17 | self.optimizer = optimizer 18 | self.scheduler = scheduler 19 | self.device = device 20 | self.lossfn = Loss(self.device) 21 | 22 | # save best model 23 | self.best_val_loss = 100 24 | 25 | def compute_accuracy(self, prob_cls, gt_cls): 26 | # we only need the detection which >= 0 27 | prob_cls = torch.squeeze(prob_cls) 28 | tmp_gt_cls = gt_cls.detach().clone() 29 | tmp_gt_cls[tmp_gt_cls==-2] = 1 30 | mask = torch.ge(tmp_gt_cls, 0) 31 | 32 | # get valid elements 33 | valid_gt_cls = tmp_gt_cls[mask] 34 | valid_prob_cls = prob_cls[mask] 35 | size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0]) 36 | 37 | # get max index with softmax layer 38 | _, valid_pred_cls = torch.max(valid_prob_cls, dim=1) 39 | 40 | right_ones = torch.eq(valid_pred_cls.float(), valid_gt_cls.float()).float() 41 | 42 | return torch.div(torch.mul(torch.sum(right_ones), float(1.0)), float(size)) 43 | 44 | def train(self): 45 | for epoch in range(self.epochs): 46 | self.train_epoch(epoch, 'train') 47 | self.train_epoch(epoch, 'val') 48 | 49 | 50 | def train_epoch(self, epoch, phase): 51 | cls_loss_ = AverageMeter() 52 | bbox_loss_ = AverageMeter() 53 | landmark_loss_ = AverageMeter() 54 | total_loss_ = AverageMeter() 55 | accuracy_ = AverageMeter() 56 | 57 | if phase == 'train': 58 | self.model.train() 59 | else: 60 | self.model.eval() 61 | 62 | for batch_idx, sample in enumerate(self.dataloaders[phase]): 63 | data = sample['input_img'] 64 | gt_cls = sample['cls_target'] 65 | gt_bbox = sample['bbox_target'] 66 | gt_landmark = sample['landmark_target'] 67 | 68 | data, gt_cls, gt_bbox, gt_landmark = data.to(self.device), \ 69 | gt_cls.to(self.device), gt_bbox.to(self.device).float(), \ 70 | gt_landmark.to(self.device).float() 71 | 72 | self.optimizer.zero_grad() 73 | with torch.set_grad_enabled(phase == 'train'): 74 | pred_cls, pred_bbox, pred_landmark = self.model(data) 75 | # compute the cls loss and bbox loss and weighted them together 76 | cls_loss = self.lossfn.cls_loss(gt_cls, pred_cls) 77 | bbox_loss = self.lossfn.box_loss(gt_cls, gt_bbox, pred_bbox) 78 | landmark_loss = self.lossfn.landmark_loss(gt_cls, gt_landmark, pred_landmark) 79 | total_loss = cls_loss + 20*bbox_loss + 20*landmark_loss 80 | 81 | # compute clssification accuracy 82 | accuracy = self.compute_accuracy(pred_cls, gt_cls) 83 | 84 | if phase == 'train': 85 | total_loss.backward() 86 | self.optimizer.step() 87 | 88 | cls_loss_.update(cls_loss, data.size(0)) 89 | bbox_loss_.update(bbox_loss, data.size(0)) 90 | landmark_loss_.update(landmark_loss, data.size(0)) 91 | total_loss_.update(total_loss, data.size(0)) 92 | accuracy_.update(accuracy, data.size(0)) 93 | 94 | if batch_idx % 40 == 0: 95 | print('{} Epoch: {} [{:08d}/{:08d} ({:02.0f}%)]\tLoss: {:.6f} cls Loss: {:.6f} offset Loss:{:.6f} landmark Loss: {:.6f}\tAccuracy: {:.6f} LR:{:.7f}'.format( 96 | phase, epoch, batch_idx * len(data), len(self.dataloaders[phase].dataset), 97 | 100. 
* batch_idx / len(self.dataloaders[phase]), total_loss.item(), cls_loss.item(), bbox_loss.item(), landmark_loss.item(), accuracy.item(), self.optimizer.param_groups[0]['lr'])) 98 | 99 | if phase == 'train': 100 | self.scheduler.step() 101 | 102 | print("{} epoch Loss: {:.6f} cls Loss: {:.6f} bbox Loss: {:.6f} landmark Loss: {:.6f} Accuracy: {:.6f}".format( 103 | phase, total_loss_.avg, cls_loss_.avg, bbox_loss_.avg, landmark_loss_.avg, accuracy_.avg)) 104 | 105 | if phase == 'val' and total_loss_.avg < self.best_val_loss: 106 | self.best_val_loss = total_loss_.avg 107 | torch.save(self.model.state_dict(), './pretrained_weights/mtcnn/best_onet_landmark_2.pth') 108 | 109 | return cls_loss_.avg, bbox_loss_.avg, total_loss_.avg, landmark_loss_.avg, accuracy_.avg 110 | 111 | -------------------------------------------------------------------------------- /training/onet/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | import os 4 | import argparse 5 | 6 | import torch 7 | from torchvision import transforms 8 | 9 | from tools.dataset import FaceDataset 10 | from nets.mtcnn import ONet 11 | from training.onet.trainer import ONetTrainer 12 | from checkpoint import CheckPoint 13 | import config 14 | 15 | # Set device 16 | use_cuda = config.USE_CUDA and torch.cuda.is_available() 17 | device = torch.device("cuda:1" if use_cuda else "cpu") 18 | torch.backends.cudnn.benchmark = True 19 | 20 | # Set dataloader 21 | kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {} 22 | train_data = FaceDataset(os.path.join(config.ANNO_PATH, config.ONET_TRAIN_IMGLIST_FILENAME)) 23 | val_data = FaceDataset(os.path.join(config.ANNO_PATH, config.ONET_VAL_IMGLIST_FILENAME)) 24 | dataloaders = {'train': torch.utils.data.DataLoader(train_data, 25 | batch_size=config.BATCH_SIZE, shuffle=True, **kwargs), 26 | 'val': torch.utils.data.DataLoader(val_data, 27 | batch_size=config.BATCH_SIZE, shuffle=True, **kwargs) 28 | } 29 | 30 | # Set model 31 | model = ONet(is_train=True) 32 | model = model.to(device) 33 | model.load_state_dict(torch.load('pretrained_weights/best_onet.pth'), strict=True) 34 | 35 | # Set checkpoint 36 | #checkpoint = CheckPoint(train_config.save_path) 37 | 38 | # Set optimizer 39 | optimizer = torch.optim.Adam(model.parameters(), lr=config.LR) 40 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=config.STEPS, gamma=0.1) 41 | 42 | # Set trainer 43 | trainer = ONetTrainer(config.EPOCHS, dataloaders, model, optimizer, scheduler, device) 44 | 45 | trainer.train() 46 | 47 | #checkpoint.save_model(model, index=epoch, tag=config.SAVE_PREFIX) 48 | 49 | -------------------------------------------------------------------------------- /training/onet/trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | 4 | import torch 5 | 6 | from loss import Loss 7 | from tools.average_meter import AverageMeter 8 | 9 | 10 | class ONetTrainer(object): 11 | 12 | def __init__(self, epochs, dataloaders, model, optimizer, scheduler, device): 13 | self.epochs = epochs 14 | self.dataloaders = dataloaders 15 | self.model = model 16 | self.optimizer = optimizer 17 | self.scheduler = scheduler 18 | self.device = device 19 | self.lossfn = Loss(self.device) 20 | 21 | # save best model 22 | self.best_val_loss = 100 23 | 24 | def compute_accuracy(self, prob_cls, gt_cls): 25 | # we only need the detection which >= 0 26 | prob_cls = 
torch.squeeze(prob_cls) 27 | mask = torch.ge(gt_cls, 0) 28 | 29 | # get valid elements 30 | valid_gt_cls = gt_cls[mask] 31 | valid_prob_cls = prob_cls[mask] 32 | size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0]) 33 | 34 | # get max index with softmax layer 35 | _, valid_pred_cls = torch.max(valid_prob_cls, dim=1) 36 | 37 | right_ones = torch.eq(valid_pred_cls.float(), valid_gt_cls.float()).float() 38 | 39 | return torch.div(torch.mul(torch.sum(right_ones), float(1.0)), float(size)) 40 | 41 | def train(self): 42 | for epoch in range(self.epochs): 43 | self.train_epoch(epoch, 'train') 44 | self.train_epoch(epoch, 'val') 45 | 46 | 47 | def train_epoch(self, epoch, phase): 48 | cls_loss_ = AverageMeter() 49 | bbox_loss_ = AverageMeter() 50 | total_loss_ = AverageMeter() 51 | accuracy_ = AverageMeter() 52 | if phase == 'train': 53 | self.model.train() 54 | else: 55 | self.model.eval() 56 | 57 | for batch_idx, sample in enumerate(self.dataloaders[phase]): 58 | data = sample['input_img'] 59 | gt_cls = sample['cls_target'] 60 | gt_bbox = sample['bbox_target'] 61 | data, gt_cls, gt_bbox = data.to(self.device), gt_cls.to( 62 | self.device), gt_bbox.to(self.device).float() 63 | 64 | self.optimizer.zero_grad() 65 | with torch.set_grad_enabled(phase == 'train'): 66 | pred_cls, pred_bbox = self.model(data) 67 | 68 | # compute the cls loss and bbox loss and weighted them together 69 | cls_loss = self.lossfn.cls_loss(gt_cls, pred_cls) 70 | bbox_loss = self.lossfn.box_loss(gt_cls, gt_bbox, pred_bbox) 71 | total_loss = cls_loss + 10*bbox_loss 72 | 73 | # compute clssification accuracy 74 | accuracy = self.compute_accuracy(pred_cls, gt_cls) 75 | 76 | if phase == 'train': 77 | total_loss.backward() 78 | self.optimizer.step() 79 | 80 | cls_loss_.update(cls_loss, data.size(0)) 81 | bbox_loss_.update(bbox_loss, data.size(0)) 82 | total_loss_.update(total_loss, data.size(0)) 83 | accuracy_.update(accuracy, data.size(0)) 84 | 85 | if batch_idx % 40 == 0: 86 | print('{} Epoch: {} [{:08d}/{:08d} ({:02.0f}%)]\tLoss: {:.6f} cls Loss: {:.6f} offset Loss:{:.6f}\tAccuracy: {:.6f} LR:{:.7f}'.format( 87 | phase, epoch, batch_idx * len(data), len(self.dataloaders[phase].dataset), 88 | 100. 
* batch_idx / len(self.dataloaders[phase]), total_loss.item(), cls_loss.item(), bbox_loss.item(), accuracy.item(), self.optimizer.param_groups[0]['lr'])) 89 | 90 | if phase == 'train': 91 | self.scheduler.step() 92 | 93 | print("{} epoch Loss: {:.6f} cls Loss: {:.6f} bbox Loss: {:.6f} Accuracy: {:.6f}".format( 94 | phase, total_loss_.avg, cls_loss_.avg, bbox_loss_.avg, accuracy_.avg)) 95 | 96 | if phase == 'val' and total_loss_.avg < self.best_val_loss: 97 | self.best_val_loss = total_loss_.avg 98 | torch.save(self.model.state_dict(), './pretrained_weights/best_onet.pth') 99 | 100 | return cls_loss_.avg, bbox_loss_.avg, total_loss_.avg, accuracy_.avg 101 | 102 | -------------------------------------------------------------------------------- /training/pnet/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | import os 4 | import argparse 5 | 6 | import torch 7 | from torchvision import transforms 8 | 9 | from tools.dataset import FaceDataset 10 | from nets.mtcnn import PNet 11 | from training.pnet.trainer import PNetTrainer 12 | from checkpoint import CheckPoint 13 | import config 14 | 15 | # Set device 16 | use_cuda = config.USE_CUDA and torch.cuda.is_available() 17 | # torch.cuda.manual_seed(train_config.manualSeed) 18 | device = torch.device("cuda:1" if use_cuda else "cpu") 19 | torch.backends.cudnn.benchmark = True 20 | 21 | # Set dataloader 22 | kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {} 23 | train_data = FaceDataset(os.path.join(config.ANNO_PATH, config.PNET_TRAIN_IMGLIST_FILENAME)) 24 | val_data = FaceDataset(os.path.join(config.ANNO_PATH, config.PNET_VAL_IMGLIST_FILENAME)) 25 | dataloaders = {'train': torch.utils.data.DataLoader(train_data, 26 | batch_size=config.BATCH_SIZE, shuffle=True, **kwargs), 27 | 'val': torch.utils.data.DataLoader(val_data, 28 | batch_size=config.BATCH_SIZE, shuffle=True, **kwargs) 29 | } 30 | 31 | # Set model 32 | model = PNet(is_train=True) 33 | model = model.to(device) 34 | model.load_state_dict(torch.load('pretrained_weights/best_pnet.pth'), strict=True) 35 | 36 | # Set checkpoint 37 | #checkpoint = CheckPoint(train_config.save_path) 38 | 39 | # Set optimizer 40 | optimizer = torch.optim.Adam(model.parameters(), lr=config.LR) 41 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=config.STEPS, gamma=0.1) 42 | 43 | # Set trainer 44 | trainer = PNetTrainer(config.EPOCHS, dataloaders, model, optimizer, scheduler, device) 45 | 46 | trainer.train() 47 | 48 | #checkpoint.save_model(model, index=epoch, tag=config.SAVE_PREFIX) 49 | 50 | -------------------------------------------------------------------------------- /training/pnet/trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | 4 | import torch 5 | 6 | from loss import Loss 7 | from tools.average_meter import AverageMeter 8 | 9 | 10 | class PNetTrainer(object): 11 | 12 | def __init__(self, epochs, dataloaders, model, optimizer, scheduler, device): 13 | self.epochs = epochs 14 | self.dataloaders = dataloaders 15 | self.model = model 16 | self.optimizer = optimizer 17 | self.scheduler = scheduler 18 | self.device = device 19 | self.lossfn = Loss(self.device) 20 | 21 | # save best model 22 | self.best_val_loss = 100 23 | 24 | def compute_accuracy(self, prob_cls, gt_cls): 25 | # we only need the detection which >= 0 26 | prob_cls = torch.squeeze(prob_cls) 27 | mask = torch.ge(gt_cls, 0) 28 | 29 | # get 
valid elements 30 | valid_gt_cls = gt_cls[mask] 31 | valid_prob_cls = prob_cls[mask] 32 | size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0]) 33 | 34 | # get max index with softmax layer 35 | _, valid_pred_cls = torch.max(valid_prob_cls, dim=1) 36 | 37 | right_ones = torch.eq(valid_pred_cls.float(), valid_gt_cls.float()).float() 38 | 39 | return torch.div(torch.mul(torch.sum(right_ones), float(1.0)), float(size)) 40 | 41 | def train(self): 42 | for epoch in range(self.epochs): 43 | self.train_epoch(epoch, 'train') 44 | self.train_epoch(epoch, 'val') 45 | 46 | 47 | def train_epoch(self, epoch, phase): 48 | cls_loss_ = AverageMeter() 49 | bbox_loss_ = AverageMeter() 50 | total_loss_ = AverageMeter() 51 | accuracy_ = AverageMeter() 52 | if phase == 'train': 53 | self.model.train() 54 | else: 55 | self.model.eval() 56 | 57 | for batch_idx, sample in enumerate(self.dataloaders[phase]): 58 | data = sample['input_img'] 59 | gt_cls = sample['cls_target'] 60 | gt_bbox = sample['bbox_target'] 61 | data, gt_cls, gt_bbox = data.to(self.device), gt_cls.to( 62 | self.device), gt_bbox.to(self.device).float() 63 | 64 | self.optimizer.zero_grad() 65 | with torch.set_grad_enabled(phase == 'train'): 66 | pred_cls, pred_bbox = self.model(data) 67 | 68 | # compute the cls loss and bbox loss and weighted them together 69 | cls_loss = self.lossfn.cls_loss(gt_cls, pred_cls) 70 | bbox_loss = self.lossfn.box_loss(gt_cls, gt_bbox, pred_bbox) 71 | total_loss = cls_loss + 5*bbox_loss 72 | 73 | # compute clssification accuracy 74 | accuracy = self.compute_accuracy(pred_cls, gt_cls) 75 | 76 | if phase == 'train': 77 | total_loss.backward() 78 | self.optimizer.step() 79 | 80 | cls_loss_.update(cls_loss, data.size(0)) 81 | bbox_loss_.update(bbox_loss, data.size(0)) 82 | total_loss_.update(total_loss, data.size(0)) 83 | accuracy_.update(accuracy, data.size(0)) 84 | 85 | if batch_idx % 40 == 0: 86 | print('{} Epoch: {} [{:08d}/{:08d} ({:02.0f}%)]\tLoss: {:.6f} cls Loss: {:.6f} offset Loss:{:.6f}\tAccuracy: {:.6f} LR:{:.7f}'.format( 87 | phase, epoch, batch_idx * len(data), len(self.dataloaders[phase].dataset), 88 | 100. 
* batch_idx / len(self.dataloaders[phase]), total_loss.item(), cls_loss.item(), bbox_loss.item(), accuracy.item(), self.optimizer.param_groups[0]['lr'])) 89 | 90 | if phase == 'train': 91 | self.scheduler.step() 92 | 93 | print("{} epoch Loss: {:.6f} cls Loss: {:.6f} bbox Loss: {:.6f} Accuracy: {:.6f}".format( 94 | phase, total_loss_.avg, cls_loss_.avg, bbox_loss_.avg, accuracy_.avg)) 95 | 96 | if phase == 'val' and total_loss_.avg < self.best_val_loss: 97 | self.best_val_loss = total_loss_.avg 98 | torch.save(self.model.state_dict(), './pretrained_weights/best_pnet.pth') 99 | 100 | return cls_loss_.avg, bbox_loss_.avg, total_loss_.avg, accuracy_.avg 101 | 102 | -------------------------------------------------------------------------------- /training/rnet/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | import os 4 | import argparse 5 | 6 | import torch 7 | from torchvision import transforms 8 | 9 | from tools.dataset import FaceDataset 10 | from nets.mtcnn import RNet 11 | from training.rnet.trainer import RNetTrainer 12 | from checkpoint import CheckPoint 13 | import config 14 | 15 | # Set device 16 | use_cuda = config.USE_CUDA and torch.cuda.is_available() 17 | # torch.cuda.manual_seed(train_config.manualSeed) 18 | device = torch.device("cuda:1" if use_cuda else "cpu") 19 | torch.backends.cudnn.benchmark = True 20 | 21 | # Set dataloader 22 | kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {} 23 | train_data = FaceDataset(os.path.join(config.ANNO_PATH, config.RNET_TRAIN_IMGLIST_FILENAME)) 24 | val_data = FaceDataset(os.path.join(config.ANNO_PATH, config.RNET_VAL_IMGLIST_FILENAME)) 25 | dataloaders = {'train': torch.utils.data.DataLoader(train_data, 26 | batch_size=config.BATCH_SIZE, shuffle=True, **kwargs), 27 | 'val': torch.utils.data.DataLoader(val_data, 28 | batch_size=config.BATCH_SIZE, shuffle=True, **kwargs) 29 | } 30 | 31 | # Set model 32 | model = RNet(is_train=True) 33 | model = model.to(device) 34 | model.load_state_dict(torch.load('./pretrained_weights/best_rnet.pth'), strict=True) 35 | 36 | # Set checkpoint 37 | #checkpoint = CheckPoint(train_config.save_path) 38 | 39 | # Set optimizer 40 | optimizer = torch.optim.Adam(model.parameters(), lr=config.LR) 41 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=config.STEPS, gamma=0.1) 42 | 43 | # Set trainer 44 | trainer = RNetTrainer(config.EPOCHS, dataloaders, model, optimizer, scheduler, device) 45 | 46 | trainer.train() 47 | 48 | #checkpoint.save_model(model, index=epoch, tag=config.SAVE_PREFIX) 49 | 50 | -------------------------------------------------------------------------------- /training/rnet/trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | 4 | import torch 5 | 6 | from loss import Loss 7 | from tools.average_meter import AverageMeter 8 | 9 | 10 | class RNetTrainer(object): 11 | 12 | def __init__(self, epochs, dataloaders, model, optimizer, scheduler, device): 13 | self.epochs = epochs 14 | self.dataloaders = dataloaders 15 | self.model = model 16 | self.optimizer = optimizer 17 | self.scheduler = scheduler 18 | self.device = device 19 | self.lossfn = Loss(self.device) 20 | 21 | # save best model 22 | self.best_val_loss = 100 23 | 24 | def compute_accuracy(self, prob_cls, gt_cls): 25 | # we only need the detection which >= 0 26 | prob_cls = torch.squeeze(prob_cls) 27 | mask = torch.ge(gt_cls, 0) 28 | 29 | # get 
valid elements 30 | valid_gt_cls = gt_cls[mask] 31 | valid_prob_cls = prob_cls[mask] 32 | size = min(valid_gt_cls.size()[0], valid_prob_cls.size()[0]) 33 | 34 | # get max index with softmax layer 35 | _, valid_pred_cls = torch.max(valid_prob_cls, dim=1) 36 | 37 | right_ones = torch.eq(valid_pred_cls.float(), valid_gt_cls.float()).float() 38 | 39 | return torch.div(torch.mul(torch.sum(right_ones), float(1.0)), float(size)) 40 | 41 | def train(self): 42 | for epoch in range(self.epochs): 43 | self.train_epoch(epoch, 'train') 44 | self.train_epoch(epoch, 'val') 45 | 46 | 47 | def train_epoch(self, epoch, phase): 48 | cls_loss_ = AverageMeter() 49 | bbox_loss_ = AverageMeter() 50 | total_loss_ = AverageMeter() 51 | accuracy_ = AverageMeter() 52 | if phase == 'train': 53 | self.model.train() 54 | else: 55 | self.model.eval() 56 | 57 | for batch_idx, sample in enumerate(self.dataloaders[phase]): 58 | data = sample['input_img'] 59 | gt_cls = sample['cls_target'] 60 | gt_bbox = sample['bbox_target'] 61 | data, gt_cls, gt_bbox = data.to(self.device), gt_cls.to( 62 | self.device), gt_bbox.to(self.device).float() 63 | 64 | self.optimizer.zero_grad() 65 | with torch.set_grad_enabled(phase == 'train'): 66 | pred_cls, pred_bbox = self.model(data) 67 | 68 | # compute the cls loss and bbox loss and weighted them together 69 | cls_loss = self.lossfn.cls_loss(gt_cls, pred_cls) 70 | bbox_loss = self.lossfn.box_loss(gt_cls, gt_bbox, pred_bbox) 71 | total_loss = cls_loss + 5*bbox_loss 72 | 73 | # compute clssification accuracy 74 | accuracy = self.compute_accuracy(pred_cls, gt_cls) 75 | 76 | if phase == 'train': 77 | total_loss.backward() 78 | self.optimizer.step() 79 | 80 | cls_loss_.update(cls_loss, data.size(0)) 81 | bbox_loss_.update(bbox_loss, data.size(0)) 82 | total_loss_.update(total_loss, data.size(0)) 83 | accuracy_.update(accuracy, data.size(0)) 84 | 85 | if batch_idx % 40 == 0: 86 | print('{} Epoch: {} [{:08d}/{:08d} ({:02.0f}%)]\tLoss: {:.6f} cls Loss: {:.6f} offset Loss:{:.6f}\tAccuracy: {:.6f} LR:{:.7f}'.format( 87 | phase, epoch, batch_idx * len(data), len(self.dataloaders[phase].dataset), 88 | 100. * batch_idx / len(self.dataloaders[phase]), total_loss.item(), cls_loss.item(), bbox_loss.item(), accuracy.item(), self.optimizer.param_groups[0]['lr'])) 89 | 90 | if phase == 'train': 91 | self.scheduler.step() 92 | 93 | print("{} epoch Loss: {:.6f} cls Loss: {:.6f} bbox Loss: {:.6f} Accuracy: {:.6f}".format( 94 | phase, total_loss_.avg, cls_loss_.avg, bbox_loss_.avg, accuracy_.avg)) 95 | 96 | if phase == 'val' and total_loss_.avg < self.best_val_loss: 97 | self.best_val_loss = total_loss_.avg 98 | torch.save(self.model.state_dict(), './pretrained_weights/best_rnet.pth') 99 | 100 | return cls_loss_.avg, bbox_loss_.avg, total_loss_.avg, accuracy_.avg 101 | 102 | --------------------------------------------------------------------------------
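
The helpers in tools/utils.py above (IoU, nms, convert_to_square, calibrate_box) make up the box post-processing applied between the network stages. The snippet below is a minimal, hypothetical sanity check on synthetic boxes, run from the repository root so that tools.utils resolves; the coordinates, scores, and zero offsets are made up for illustration and are not part of the repository.

# Hypothetical sanity check for the box utilities in tools/utils.py.
# The boxes, scores, and offsets below are synthetic values chosen for illustration only.
import numpy as np

from tools.utils import IoU, calibrate_box, convert_to_square, nms

# Candidate boxes in (xmin, ymin, xmax, ymax, score) form; the second box
# heavily overlaps the first and should be suppressed by NMS.
bboxes = np.array([
    [10.0, 10.0, 60.0, 60.0, 0.95],
    [12.0, 12.0, 62.0, 62.0, 0.90],
    [100.0, 100.0, 140.0, 160.0, 0.80],
], dtype=np.float32)

keep = nms(bboxes, overlap_threshold=0.5, mode='Union')
print('kept indices:', keep)                       # expected: [0, 2]

# IoU of the highest-scoring box against all candidates (1.0 with itself).
print('IoU vs. box 0:', IoU(bboxes[0], bboxes[:, :4]))

# Square the surviving boxes (the score column is zeroed by convert_to_square)
# and apply dummy regression offsets, which leave the boxes unchanged here.
squared = convert_to_square(bboxes[keep])
offsets = np.zeros((len(keep), 4), dtype=np.float32)
print('calibrated boxes:\n', calibrate_box(squared, offsets)[:, :4])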