├── .gitignore ├── LICENSE ├── README.md ├── cal_recall ├── __init__.py ├── rrc_evaluation_funcs.py └── script.py ├── config.py ├── dataset ├── __init__.py ├── augment.py ├── augment_img.py └── data_utils.py ├── eval.py ├── imgs ├── img_125.jpg ├── img_31.jpg ├── img_73.jpg ├── img_83.jpg └── img_98.jpg ├── models ├── ShuffleNetV2.py ├── __init__.py ├── loss.py ├── mobilenetv3.py ├── model.py └── resnet.py ├── predict.py ├── pse ├── Makefile ├── __init__.py ├── include │ └── pybind11 │ │ ├── attr.h │ │ ├── buffer_info.h │ │ ├── cast.h │ │ ├── chrono.h │ │ ├── class_support.h │ │ ├── common.h │ │ ├── complex.h │ │ ├── descr.h │ │ ├── detail │ │ ├── class.h │ │ ├── common.h │ │ ├── descr.h │ │ ├── init.h │ │ ├── internals.h │ │ └── typeid.h │ │ ├── eigen.h │ │ ├── embed.h │ │ ├── eval.h │ │ ├── functional.h │ │ ├── iostream.h │ │ ├── numpy.h │ │ ├── operators.h │ │ ├── options.h │ │ ├── pybind11.h │ │ ├── pytypes.h │ │ ├── stl.h │ │ ├── stl_bind.h │ │ └── typeid.h ├── ncnn │ └── examples │ │ ├── CMakeLists.txt │ │ ├── psenet.cpp │ │ └── run.sh ├── pse.cpp └── pse.so ├── train.py └── utils ├── __init__.py ├── lr_scheduler.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | __pycache__/ 3 | **/.pytest_cache/ 4 | output/ 5 | result*/ 6 | log/ 7 | *.npy 8 | *.pyc 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Shape Robust Text Detection with Progressive Scale Expansion Network 2 | 3 | ## Requirements 4 | * pytorch 1.1 5 | * torchvision 0.3 6 | * pyclipper 7 | * opencv3 8 | * gcc 4.9+ 9 | 10 | ## Update 11 | ### 20190401 12 | 13 | 1. add author loss, the results are compared in [Performance](#Performance) 14 | 15 | 16 | ### Download 17 | resnet50 and resnet152 model on icdar 2015: 18 | 19 | 1. 
~~[bauduyun](https://pan.baidu.com/s/1rN0oGBRsdUYmcQUayMZUOA) extract code: rxjf~~ 20 | 21 | 2. ~~[google drive](https://drive.google.com/drive/folders/1r3Q1GJ5990WYrwXKT29aHvfQNW92QkXv?usp=sharing)~~ 22 | 23 | ## Data Preparation 24 | follow icdar15 dataset format 25 | ``` 26 | img 27 | │ 1.jpg 28 | │ 2.jpg 29 | │ ... 30 | gt 31 | │ gt_1.txt 32 | │ gt_2.txt 33 | | ... 34 | ``` 35 | 36 | ## Train 37 | 1. config the `trainroot`,`testroot`in [config.py](config.py) 38 | 2. use following script to run 39 | ```sh 40 | python3 train.py 41 | ``` 42 | 43 | ## Test 44 | [eval.py](eval.py) is used to test model on test dataset 45 | 46 | 1. config `model_path`, `data_path`, `gt_path`, `save_path` in [eval.py](eval.py) 47 | 2. use following script to test 48 | ```sh 49 | python3 eval.py 50 | ``` 51 | 52 | ## Predict 53 | [predict.py](predict.py) is used to inference on single image 54 | 55 | 1. config `model_path`, `img_path`, `gt_path`, `save_path` in [predict.py](predict.py) 56 | 2. use following script to predict 57 | ```sh 58 | python3 predict.py 59 | ``` 60 | 61 | 62 |

Performance

63 | 64 | ### [ICDAR 2015](http://rrc.cvc.uab.es/?ch=4) 65 | only train on ICDAR2015 dataset with single NVIDIA 1080Ti 66 | 67 | my implementation with my loss use adam and warm_up 68 | 69 | | Method | Precision (%) | Recall (%) | F-measure (%) | FPS(1080Ti) | 70 | |--------------------------|---------------|------------|---------------|-----| 71 | | PSENet-1s with resnet50 batch 8 | 81.13 | 77.03 | 79.03 | 1.76 | 72 | | PSENet-2s with resnet50 batch 8 | 81.36 | 77.13 | 79.18 | 3.55 | 73 | | PSENet-4s with resnet50 batch 8 | 81.00 | 76.55 | 78.71 | 4.43 | 74 | | PSENet-1s with resnet152 batch 4 | 85.45 | 80.06 | 82.67 | 1.48 | 75 | | PSENet-2s with resnet152 batch 4 | 85.42 | 80.11 | 82.68 | 2.56 | 76 | | PSENet-4s with resnet152 batch 4 | 83.93 | 79.00 | 81.39 | 2.99 | 77 | 78 | my implementation with my loss use adam and MultiStepLR 79 | 80 | | Method | Precision (%) | Recall (%) | F-measure (%) | FPS(1080Ti) | 81 | |--------------------------|---------------|------------|---------------|-----| 82 | | PSENet-1s with resnet50 batch 8 | 83.39 | 79.29 | 81.29 | 1.76 | 83 | | PSENet-2s with resnet50 batch 8 | 83.22 | 79.05 | 81.08 | 3.55 | 84 | | PSENet-4s with resnet50 batch 8 | 82.57 | 78.23 | 80.34 | 4.43 | 85 | | PSENet-1s with resnet152 batch 4 | 85.33 | 79.87 | 82.51 | 1.48 | 86 | | PSENet-2s with resnet152 batch 4 | 85.36 | 79.73 | 82.45 | 2.56 | 87 | | PSENet-4s with resnet152 batch 4 | 83.95 | 78.86 | 81.33 | 2.99 | 88 | 89 | my implementation with author loss use adam and warm_up 90 | 91 | | Method | Precision (%) | Recall (%) | F-measure (%) | FPS(1080Ti) | 92 | |--------------------------|---------------|------------|---------------|-----| 93 | | PSENet-1s with resnet50 batch 8 | 83.33 | 77.75 | 80.44 | 1.76 | 94 | | PSENet-2s with resnet50 batch 8 | 83.01 | 77.66 | 80.24 | 3.55 | 95 | | PSENet-4s with resnet50 batch 8 | 82.38 | 76.98 | 79.59 | 4.43 | 96 | | PSENet-1s with resnet152 batch 4 | 85.16 | 79.87 | 82.43 | 1.48 | 97 | | PSENet-2s with resnet152 
batch 4 | 85.03 | 79.63 | 82.24 | 2.56 | 98 | | PSENet-4s with resnet152 batch 4 | 84.53 | 79.20 | 81.77 | 2.99 | 99 | 100 | my implementation with author loss use adam and MultiStepLR 101 | 102 | | Method | Precision (%) | Recall (%) | F-measure (%) | FPS(1080Ti) | 103 | |--------------------------|---------------|------------|---------------|-----| 104 | | PSENet-1s with resnet50 batch 8 | 83.93 | 79.48 | 81.65 | 1.76 | 105 | | PSENet-2s with resnet50 batch 8 | 84.17 | 79.63 | 81.84 | 3.55 | 106 | | PSENet-4s with resnet50 batch 8 | 83.50 | 78.71 | 81.04 | 4.43 | 107 | | PSENet-1s with resnet152 batch 4 | 85.16 | 79.58 | 82.28 | 1.48 | 108 | | PSENet-2s with resnet152 batch 4 | 85.13 | 79.15 | 82.03 | 2.56 | 109 | | PSENet-4s with resnet152 batch 4 | 84.40 | 78.71 | 81.46 | 2.99 | 110 | 111 | official implementation use SGD and StepLR 112 | 113 | | Method | Precision (%) | Recall (%) | F-measure (%) | FPS(1080Ti) | 114 | |--------------------------|---------------|------------|---------------|-----| 115 | | PSENet-1s with resnet50 batch 8 | 84.15 | 80.26 | 82.16 | 1.76 | 116 | | PSENet-2s with resnet50 batch 8 | 83.61 | 79.82 | 81.67 | 3.72 | 117 | | PSENet-4s with resnet50 batch 8 | 81.90 | 78.23 | 80.03 | 4.51 | 118 | | PSENet-1s with resnet152 batch 4 | 82.87 | 78.76 | 80.77 | 1.53 | 119 | | PSENet-2s with resnet152 batch 4 | 82.33 | 78.33 | 80.28 | 2.61 | 120 | | PSENet-4s with resnet152 batch 4 | 81.19 | 77.13 | 79.11 | 3.00 | 121 | 122 | ### examples 123 | ![](imgs/img_31.jpg) 124 | 125 | ![](imgs/img_73.jpg) 126 | 127 | ![](imgs/img_83.jpg) 128 | 129 | ![](imgs/img_98.jpg) 130 | 131 | ![](imgs/img_125.jpg) 132 | 133 | ### reference 134 | 1. https://github.com/liuheng92/tensorflow_PSENet 135 | 2. 
https://github.com/whai362/PSENet 136 | -------------------------------------------------------------------------------- /cal_recall/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 1/16/19 6:40 AM 3 | # @Author : zhoujun 4 | from .script import cal_recall_precison_f1 5 | __all__ = ['cal_recall_precison_f1'] -------------------------------------------------------------------------------- /cal_recall/script.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from collections import namedtuple 4 | from . import rrc_evaluation_funcs 5 | import Polygon as plg 6 | import numpy as np 7 | 8 | 9 | def default_evaluation_params(): 10 | """ 11 | default_evaluation_params: Default parameters to use for the validation and evaluation. 12 | """ 13 | return { 14 | 'IOU_CONSTRAINT': 0.5, 15 | 'AREA_PRECISION_CONSTRAINT': 0.5, 16 | 'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt', 17 | 'DET_SAMPLE_NAME_2_ID': 'res_img_([0-9]+).txt', 18 | 'LTRB': False, # LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4) 19 | 'CRLF': False, # Lines are delimited by Windows CRLF format 20 | 'CONFIDENCES': False, # Detections must include confidence value. AP will be calculated 21 | 'PER_SAMPLE_RESULTS': True # Generate per sample results and produce data for visualization 22 | } 23 | 24 | 25 | def validate_data(gtFilePath, submFilePath, evaluationParams): 26 | """ 27 | Method validate_data: validates that all files in the results folder are correct (have the correct name contents). 28 | Validates also that there are no missing files in the folder. 
29 | If some error detected, the method raises the error 30 | """ 31 | gt = rrc_evaluation_funcs.load_folder_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID']) 32 | 33 | subm = rrc_evaluation_funcs.load_folder_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True) 34 | 35 | # Validate format of GroundTruth 36 | for k in gt: 37 | rrc_evaluation_funcs.validate_lines_in_file(k, gt[k], evaluationParams['CRLF'], evaluationParams['LTRB'], True) 38 | 39 | # Validate format of results 40 | for k in subm: 41 | if (k in gt) == False: 42 | raise Exception("The sample %s not present in GT" % k) 43 | 44 | rrc_evaluation_funcs.validate_lines_in_file(k, subm[k], evaluationParams['CRLF'], evaluationParams['LTRB'], 45 | False, evaluationParams['CONFIDENCES']) 46 | 47 | 48 | def evaluate_method(gtFilePath, submFilePath, evaluationParams): 49 | """ 50 | Method evaluate_method: evaluate method and returns the results 51 | Results. Dictionary with the following values: 52 | - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 } 53 | - samples (optional) Per sample metrics. 
Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } 54 | """ 55 | 56 | def polygon_from_points(points): 57 | """ 58 | Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4 59 | """ 60 | resBoxes = np.empty([1, 8], dtype='int32') 61 | resBoxes[0, 0] = int(points[0]) 62 | resBoxes[0, 4] = int(points[1]) 63 | resBoxes[0, 1] = int(points[2]) 64 | resBoxes[0, 5] = int(points[3]) 65 | resBoxes[0, 2] = int(points[4]) 66 | resBoxes[0, 6] = int(points[5]) 67 | resBoxes[0, 3] = int(points[6]) 68 | resBoxes[0, 7] = int(points[7]) 69 | pointMat = resBoxes[0].reshape([2, 4]).T 70 | return plg.Polygon(pointMat) 71 | 72 | def rectangle_to_polygon(rect): 73 | resBoxes = np.empty([1, 8], dtype='int32') 74 | resBoxes[0, 0] = int(rect.xmin) 75 | resBoxes[0, 4] = int(rect.ymax) 76 | resBoxes[0, 1] = int(rect.xmin) 77 | resBoxes[0, 5] = int(rect.ymin) 78 | resBoxes[0, 2] = int(rect.xmax) 79 | resBoxes[0, 6] = int(rect.ymin) 80 | resBoxes[0, 3] = int(rect.xmax) 81 | resBoxes[0, 7] = int(rect.ymax) 82 | 83 | pointMat = resBoxes[0].reshape([2, 4]).T 84 | 85 | return plg.Polygon(pointMat) 86 | 87 | def rectangle_to_points(rect): 88 | points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), 89 | int(rect.xmin), int(rect.ymin)] 90 | return points 91 | 92 | def get_union(pD, pG): 93 | areaA = pD.area(); 94 | areaB = pG.area(); 95 | return areaA + areaB - get_intersection(pD, pG); 96 | 97 | def get_intersection_over_union(pD, pG): 98 | try: 99 | return get_intersection(pD, pG) / get_union(pD, pG); 100 | except: 101 | return 0 102 | 103 | def get_intersection(pD, pG): 104 | pInt = pD & pG 105 | if len(pInt) == 0: 106 | return 0 107 | return pInt.area() 108 | 109 | def compute_ap(confList, matchList, numGtCare): 110 | correct = 0 111 | AP = 0 112 | if len(confList) > 0: 113 | confList = np.array(confList) 114 | matchList = np.array(matchList) 115 | 
sorted_ind = np.argsort(-confList) 116 | confList = confList[sorted_ind] 117 | matchList = matchList[sorted_ind] 118 | for n in range(len(confList)): 119 | match = matchList[n] 120 | if match: 121 | correct += 1 122 | AP += float(correct) / (n + 1) 123 | 124 | if numGtCare > 0: 125 | AP /= numGtCare 126 | 127 | return AP 128 | 129 | perSampleMetrics = {} 130 | 131 | matchedSum = 0 132 | 133 | Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') 134 | 135 | gt = rrc_evaluation_funcs.load_folder_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID']) 136 | subm = rrc_evaluation_funcs.load_folder_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True) 137 | 138 | numGlobalCareGt = 0; 139 | numGlobalCareDet = 0; 140 | 141 | arrGlobalConfidences = []; 142 | arrGlobalMatches = []; 143 | 144 | for resFile in gt: 145 | 146 | gtFile = gt[resFile] # rrc_evaluation_funcs.decode_utf8(gt[resFile]) 147 | recall = 0 148 | precision = 0 149 | hmean = 0 150 | 151 | detMatched = 0 152 | 153 | iouMat = np.empty([1, 1]) 154 | 155 | gtPols = [] 156 | detPols = [] 157 | 158 | gtPolPoints = [] 159 | detPolPoints = [] 160 | 161 | # Array of Ground Truth Polygons' keys marked as don't Care 162 | gtDontCarePolsNum = [] 163 | # Array of Detected Polygons' matched with a don't Care GT 164 | detDontCarePolsNum = [] 165 | 166 | pairs = [] 167 | detMatchedNums = [] 168 | 169 | arrSampleConfidences = []; 170 | arrSampleMatch = []; 171 | sampleAP = 0; 172 | 173 | evaluationLog = "" 174 | 175 | pointsList, _, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile, 176 | evaluationParams[ 177 | 'CRLF'], 178 | evaluationParams[ 179 | 'LTRB'], 180 | True, False) 181 | for n in range(len(pointsList)): 182 | points = pointsList[n] 183 | transcription = transcriptionsList[n] 184 | dontCare = transcription == "###" 185 | if evaluationParams['LTRB']: 186 | gtRect = Rectangle(*points) 187 | gtPol = rectangle_to_polygon(gtRect) 188 | else: 189 | gtPol = 
polygon_from_points(points) 190 | gtPols.append(gtPol) 191 | gtPolPoints.append(points) 192 | if dontCare: 193 | gtDontCarePolsNum.append(len(gtPols) - 1) 194 | 195 | evaluationLog += "GT polygons: " + str(len(gtPols)) + ( 196 | " (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n") 197 | 198 | if resFile in subm: 199 | 200 | detFile = subm[resFile] # rrc_evaluation_funcs.decode_utf8(subm[resFile]) 201 | 202 | pointsList, confidencesList, _ = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile, 203 | evaluationParams[ 204 | 'CRLF'], 205 | evaluationParams[ 206 | 'LTRB'], 207 | False, 208 | evaluationParams[ 209 | 'CONFIDENCES']) 210 | for n in range(len(pointsList)): 211 | points = pointsList[n] 212 | 213 | if evaluationParams['LTRB']: 214 | detRect = Rectangle(*points) 215 | detPol = rectangle_to_polygon(detRect) 216 | else: 217 | detPol = polygon_from_points(points) 218 | detPols.append(detPol) 219 | detPolPoints.append(points) 220 | if len(gtDontCarePolsNum) > 0: 221 | for dontCarePol in gtDontCarePolsNum: 222 | dontCarePol = gtPols[dontCarePol] 223 | intersected_area = get_intersection(dontCarePol, detPol) 224 | pdDimensions = detPol.area() 225 | precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions 226 | if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT']): 227 | detDontCarePolsNum.append(len(detPols) - 1) 228 | break 229 | 230 | evaluationLog += "DET polygons: " + str(len(detPols)) + ( 231 | " (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n") 232 | 233 | if len(gtPols) > 0 and len(detPols) > 0: 234 | # Calculate IoU and precision matrixs 235 | outputShape = [len(gtPols), len(detPols)] 236 | iouMat = np.empty(outputShape) 237 | gtRectMat = np.zeros(len(gtPols), np.int8) 238 | detRectMat = np.zeros(len(detPols), np.int8) 239 | for gtNum in range(len(gtPols)): 240 | for detNum in range(len(detPols)): 241 | pG = gtPols[gtNum] 242 | pD = 
detPols[detNum] 243 | iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG) 244 | 245 | for gtNum in range(len(gtPols)): 246 | for detNum in range(len(detPols)): 247 | if gtRectMat[gtNum] == 0 and detRectMat[ 248 | detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum: 249 | if iouMat[gtNum, detNum] > evaluationParams['IOU_CONSTRAINT']: 250 | gtRectMat[gtNum] = 1 251 | detRectMat[detNum] = 1 252 | detMatched += 1 253 | pairs.append({'gt': gtNum, 'det': detNum}) 254 | detMatchedNums.append(detNum) 255 | evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + "\n" 256 | 257 | if evaluationParams['CONFIDENCES']: 258 | for detNum in range(len(detPols)): 259 | if detNum not in detDontCarePolsNum: 260 | # we exclude the don't care detections 261 | match = detNum in detMatchedNums 262 | 263 | arrSampleConfidences.append(confidencesList[detNum]) 264 | arrSampleMatch.append(match) 265 | 266 | arrGlobalConfidences.append(confidencesList[detNum]); 267 | arrGlobalMatches.append(match); 268 | 269 | numGtCare = (len(gtPols) - len(gtDontCarePolsNum)) 270 | numDetCare = (len(detPols) - len(detDontCarePolsNum)) 271 | if numGtCare == 0: 272 | recall = float(1) 273 | precision = float(0) if numDetCare > 0 else float(1) 274 | sampleAP = precision 275 | else: 276 | recall = float(detMatched) / numGtCare 277 | precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare 278 | if evaluationParams['CONFIDENCES'] and evaluationParams['PER_SAMPLE_RESULTS']: 279 | sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare) 280 | 281 | hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall) 282 | 283 | matchedSum += detMatched 284 | numGlobalCareGt += numGtCare 285 | numGlobalCareDet += numDetCare 286 | 287 | if evaluationParams['PER_SAMPLE_RESULTS']: 288 | perSampleMetrics[resFile] = { 289 | 'precision': precision, 290 | 'recall': recall, 291 | 'hmean': hmean, 292 | 'pairs': pairs, 
293 | 'AP': sampleAP, 294 | 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(), 295 | 'gtPolPoints': gtPolPoints, 296 | 'detPolPoints': detPolPoints, 297 | 'gtDontCare': gtDontCarePolsNum, 298 | 'detDontCare': detDontCarePolsNum, 299 | 'evaluationParams': evaluationParams, 300 | 'evaluationLog': evaluationLog 301 | } 302 | 303 | # Compute MAP and MAR 304 | AP = 0 305 | if evaluationParams['CONFIDENCES']: 306 | AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt) 307 | 308 | methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum) / numGlobalCareGt 309 | methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum) / numGlobalCareDet 310 | methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / ( 311 | methodRecall + methodPrecision) 312 | 313 | methodMetrics = {'precision': methodPrecision, 'recall': methodRecall, 'hmean': methodHmean, 'AP': AP} 314 | 315 | resDict = {'calculated': True, 'Message': '', 'method': methodMetrics, 'per_sample': perSampleMetrics} 316 | 317 | return resDict; 318 | 319 | 320 | def cal_recall_precison_f1(gt_path, result_path, show_result=False): 321 | p = {'g': gt_path, 's': result_path} 322 | result = rrc_evaluation_funcs.main_evaluation(p, default_evaluation_params, validate_data, evaluate_method, 323 | show_result) 324 | return result['method'] 325 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/1/3 17:40 3 | # @Author : zhoujun 4 | 5 | # data config 6 | trainroot = '/data2/dataset/ICD15/train' 7 | testroot = '/data2/dataset/ICD15/test' 8 | output_dir = 'output/psenet_icd2015_resnet152_4gpu_author_crop_adam_MultiStepLR_authorloss' 9 | data_shape = 640 10 | 11 | # train config 12 | gpu_id = '2' 13 | workers = 12 14 | start_epoch = 0 15 | epochs = 600 16 | 17 | train_batch_size = 4 18 | 19 
| lr = 1e-4 20 | end_lr = 1e-7 21 | lr_gamma = 0.1 22 | lr_decay_step = [200,400] 23 | weight_decay = 5e-4 24 | warm_up_epoch = 6 25 | warm_up_lr = lr * lr_gamma 26 | 27 | display_input_images = False 28 | display_output_images = False 29 | display_interval = 10 30 | show_images_interval = 50 31 | 32 | pretrained = True 33 | restart_training = True 34 | checkpoint = '' 35 | 36 | # net config 37 | backbone = 'resnet152' 38 | Lambda = 0.7 39 | n = 6 40 | m = 0.5 41 | OHEM_ratio = 3 42 | scale = 1 43 | # random seed 44 | seed = 2 45 | 46 | 47 | def print(): 48 | from pprint import pformat 49 | tem_d = {} 50 | for k, v in globals().items(): 51 | if not k.startswith('_') and not callable(v): 52 | tem_d[k] = v 53 | return pformat(tem_d) 54 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 1/17/19 2:09 AM 3 | # @Author : zhoujun -------------------------------------------------------------------------------- /dataset/augment.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/1/12 13:06 3 | 4 | import cv2 5 | import numbers 6 | import math 7 | import random 8 | import numpy as np 9 | from skimage.util import random_noise 10 | 11 | 12 | def show_pic(img, bboxes=None, name='pic'): 13 | ''' 14 | 输入: 15 | img:图像array 16 | bboxes:图像的所有boudning box list, 格式为[[x_min, y_min, x_max, y_max]....] 
17 | names:每个box对应的名称 18 | ''' 19 | show_img = img.copy() 20 | if not isinstance(bboxes, np.ndarray): 21 | bboxes = np.array(bboxes) 22 | for point in bboxes.astype(np.int): 23 | cv2.line(show_img, tuple(point[0]), tuple(point[1]), (255, 0, 0), 2) 24 | cv2.line(show_img, tuple(point[1]), tuple(point[2]), (255, 0, 0), 2) 25 | cv2.line(show_img, tuple(point[2]), tuple(point[3]), (255, 0, 0), 2) 26 | cv2.line(show_img, tuple(point[3]), tuple(point[0]), (255, 0, 0), 2) 27 | # cv2.namedWindow(name, 0) # 1表示原图 28 | # cv2.moveWindow(name, 0, 0) 29 | # cv2.resizeWindow(name, 1200, 800) # 可视化的图片大小 30 | cv2.imshow(name, show_img) 31 | 32 | 33 | # 图像均为cv2读取 34 | class DataAugment(): 35 | def __init__(self): 36 | pass 37 | 38 | def add_noise(self, im: np.ndarray): 39 | """ 40 | 对图片加噪声 41 | :param img: 图像array 42 | :return: 加噪声后的图像array,由于输出的像素是在[0,1]之间,所以得乘以255 43 | """ 44 | return (random_noise(im, mode='gaussian', clip=True) * 255).astype(im.dtype) 45 | 46 | def random_scale(self, im: np.ndarray, text_polys: np.ndarray, scales: np.ndarray or list) -> tuple: 47 | """ 48 | 从scales中随机选择一个尺度,对图片和文本框进行缩放 49 | :param im: 原图 50 | :param text_polys: 文本框 51 | :param scales: 尺度 52 | :return: 经过缩放的图片和文本 53 | """ 54 | tmp_text_polys = text_polys.copy() 55 | rd_scale = float(np.random.choice(scales)) 56 | im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) 57 | tmp_text_polys *= rd_scale 58 | return im, tmp_text_polys 59 | 60 | def random_rotate_img_bbox(self, img, text_polys, degrees: numbers.Number or list or tuple or np.ndarray, 61 | same_size=False): 62 | """ 63 | 从给定的角度中选择一个角度,对图片和文本框进行旋转 64 | :param img: 图片 65 | :param text_polys: 文本框 66 | :param degrees: 角度,可以是一个数值或者list 67 | :param same_size: 是否保持和原图一样大 68 | :return: 旋转后的图片和角度 69 | """ 70 | if isinstance(degrees, numbers.Number): 71 | if degrees < 0: 72 | raise ValueError("If degrees is a single number, it must be positive.") 73 | degrees = (-degrees, degrees) 74 | elif isinstance(degrees, list) or isinstance(degrees, 
tuple) or isinstance(degrees, np.ndarray): 75 | if len(degrees) != 2: 76 | raise ValueError("If degrees is a sequence, it must be of len 2.") 77 | degrees = degrees 78 | else: 79 | raise Exception('degrees must in Number or list or tuple or np.ndarray') 80 | # ---------------------- 旋转图像 ---------------------- 81 | w = img.shape[1] 82 | h = img.shape[0] 83 | angle = np.random.uniform(degrees[0], degrees[1]) 84 | 85 | if same_size: 86 | nw = w 87 | nh = h 88 | else: 89 | # 角度变弧度 90 | rangle = np.deg2rad(angle) 91 | # 计算旋转之后图像的w, h 92 | nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)) 93 | nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)) 94 | # 构造仿射矩阵 95 | rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, 1) 96 | # 计算原图中心点到新图中心点的偏移量 97 | rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0])) 98 | # 更新仿射矩阵 99 | rot_mat[0, 2] += rot_move[0] 100 | rot_mat[1, 2] += rot_move[1] 101 | # 仿射变换 102 | rot_img = cv2.warpAffine(img, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4) 103 | 104 | # ---------------------- 矫正bbox坐标 ---------------------- 105 | # rot_mat是最终的旋转矩阵 106 | # 获取原始bbox的四个中点,然后将这四个点转换到旋转后的坐标系下 107 | rot_text_polys = list() 108 | for bbox in text_polys: 109 | point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1])) 110 | point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1])) 111 | point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1])) 112 | point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1])) 113 | rot_text_polys.append([point1, point2, point3, point4]) 114 | return rot_img, np.array(rot_text_polys, dtype=np.float32) 115 | 116 | def random_crop_img_bboxes(self, im: np.ndarray, text_polys: np.ndarray, max_tries=50) -> tuple: 117 | """ 118 | 从图片中裁剪出 cropsize大小的图片和对应区域的文本框 119 | :param im: 图片 120 | :param text_polys: 文本框 121 | :param max_tries: 最大尝试次数 122 | :return: 裁剪后的图片和文本框 123 | """ 124 | h, w, _ = im.shape 125 | pad_h = h // 10 126 | pad_w = w // 10 
127 | h_array = np.zeros((h + pad_h * 2), dtype=np.int32) 128 | w_array = np.zeros((w + pad_w * 2), dtype=np.int32) 129 | for poly in text_polys: 130 | poly = np.round(poly, decimals=0).astype(np.int32) # 四舍五入取整 131 | minx = np.min(poly[:, 0]) 132 | maxx = np.max(poly[:, 0]) 133 | w_array[minx + pad_w:maxx + pad_w] = 1 # 将文本区域的在w_array上设为1,表示x轴方向上这部分位置有文本 134 | miny = np.min(poly[:, 1]) 135 | maxy = np.max(poly[:, 1]) 136 | h_array[miny + pad_h:maxy + pad_h] = 1 # 将文本区域的在h_array上设为1,表示y轴方向上这部分位置有文本 137 | # 在两个轴上 拿出背景位置去进行随机的位置选择,避免选择的区域穿过文本 138 | h_axis = np.where(h_array == 0)[0] 139 | w_axis = np.where(w_array == 0)[0] 140 | if len(h_axis) == 0 or len(w_axis) == 0: 141 | # 整张图全是文本的情况下,直接返回 142 | return im, text_polys 143 | for i in range(max_tries): 144 | xx = np.random.choice(w_axis, size=2) 145 | # 对选择区域进行边界控制 146 | xmin = np.min(xx) - pad_w 147 | xmax = np.max(xx) - pad_w 148 | xmin = np.clip(xmin, 0, w - 1) 149 | xmax = np.clip(xmax, 0, w - 1) 150 | yy = np.random.choice(h_axis, size=2) 151 | ymin = np.min(yy) - pad_h 152 | ymax = np.max(yy) - pad_h 153 | ymin = np.clip(ymin, 0, h - 1) 154 | ymax = np.clip(ymax, 0, h - 1) 155 | if xmax - xmin < 0.1 * w or ymax - ymin < 0.1 * h: 156 | # 选择的区域过小 157 | # area too small 158 | continue 159 | if text_polys.shape[0] != 0: # 这个判断不知道干啥的 160 | poly_axis_in_area = (text_polys[:, :, 0] >= xmin) & (text_polys[:, :, 0] <= xmax) \ 161 | & (text_polys[:, :, 1] >= ymin) & (text_polys[:, :, 1] <= ymax) 162 | selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0] 163 | else: 164 | selected_polys = [] 165 | if len(selected_polys) == 0: 166 | # 区域内没有文本 167 | continue 168 | im = im[ymin:ymax + 1, xmin:xmax + 1, :] 169 | polys = text_polys[selected_polys] 170 | # 坐标调整到裁剪图片上 171 | polys[:, :, 0] -= xmin 172 | polys[:, :, 1] -= ymin 173 | return im, polys 174 | return im, text_polys 175 | 176 | def random_crop_image_pse(self, im: np.ndarray, text_polys: np.ndarray, input_size) -> tuple: 177 | """ 178 | 从图片中裁剪出 
cropsize大小的图片和对应区域的文本框 179 | :param im: 图片 180 | :param text_polys: 文本框 181 | :param input_size: 输出图像大小 182 | :return: 裁剪后的图片和文本框 183 | """ 184 | h, w, _ = im.shape 185 | short_edge = min(h, w) 186 | if short_edge < input_size: 187 | # 保证短边 >= inputsize 188 | scale = input_size / short_edge 189 | im = cv2.resize(im, dsize=None, fx=scale, fy=scale) 190 | text_polys *= scale 191 | h, w, _ = im.shape 192 | # 计算随机范围 193 | w_range = w - input_size 194 | h_range = h - input_size 195 | for _ in range(50): 196 | xmin = random.randint(0, w_range) 197 | ymin = random.randint(0, h_range) 198 | xmax = xmin + input_size 199 | ymax = ymin + input_size 200 | if text_polys.shape[0] != 0: 201 | selected_polys = [] 202 | for poly in text_polys: 203 | if poly[:, 0].max() < xmin or poly[:, 0].min() > xmax or \ 204 | poly[:, 1].max() < ymin or poly[:, 1].min() > ymax: 205 | continue 206 | # area_p = cv2.contourArea(poly) 207 | poly[:, 0] -= xmin 208 | poly[:, 1] -= ymin 209 | poly[:, 0] = np.clip(poly[:, 0], 0, input_size) 210 | poly[:, 1] = np.clip(poly[:, 1], 0, input_size) 211 | # rect = cv2.minAreaRect(poly) 212 | # area_n = cv2.contourArea(poly) 213 | # h1, w1 = rect[1] 214 | # if w1 < 10 or h1 < 10 or area_n / area_p < 0.5: 215 | # continue 216 | selected_polys.append(poly) 217 | else: 218 | selected_polys = [] 219 | # if len(selected_polys) == 0: 220 | # 区域内没有文本 221 | # continue 222 | im = im[ymin:ymax, xmin:xmax, :] 223 | polys = np.array(selected_polys) 224 | return im, polys 225 | return im, text_polys 226 | 227 | def random_crop_author(self,imgs, img_size): 228 | h, w = imgs[0].shape[0:2] 229 | th, tw = img_size 230 | if w == tw and h == th: 231 | return imgs 232 | 233 | # label中存在文本实例,并且按照概率进行裁剪 234 | if np.max(imgs[1][:,:,-1]) > 0 and random.random() > 3.0 / 8.0: 235 | # 文本实例的top left点 236 | tl = np.min(np.where(imgs[1][:,:,-1] > 0), axis=1) - img_size 237 | tl[tl < 0] = 0 238 | # 文本实例的 bottom right 点 239 | br = np.max(np.where(imgs[1][:,:,-1] > 0), axis=1) - img_size 240 
| br[br < 0] = 0 241 | # 保证选到右下角点是,有足够的距离进行crop 242 | br[0] = min(br[0], h - th) 243 | br[1] = min(br[1], w - tw) 244 | for _ in range(50000): 245 | i = random.randint(tl[0], br[0]) 246 | j = random.randint(tl[1], br[1]) 247 | # 保证最小的图有文本 248 | if imgs[1][:,:,0][i:i + th, j:j + tw].sum() <= 0: 249 | continue 250 | else: 251 | break 252 | else: 253 | i = random.randint(0, h - th) 254 | j = random.randint(0, w - tw) 255 | 256 | # return i, j, th, tw 257 | for idx in range(len(imgs)): 258 | if len(imgs[idx].shape) == 3: 259 | imgs[idx] = imgs[idx][i:i + th, j:j + tw, :] 260 | else: 261 | imgs[idx] = imgs[idx][i:i + th, j:j + tw] 262 | return imgs 263 | 264 | def resize(self, im: np.ndarray, text_polys: np.ndarray, 265 | input_size: numbers.Number or list or tuple or np.ndarray, keep_ratio: bool = False) -> tuple: 266 | """ 267 | 对图片和文本框进行resize 268 | :param im: 图片 269 | :param text_polys: 文本框 270 | :param input_size: resize尺寸,数字或者list的形式,如果为list形式,就是[w,h] 271 | :param keep_ratio: 是否保持长宽比 272 | :return: resize后的图片和文本框 273 | """ 274 | if isinstance(input_size, numbers.Number): 275 | if input_size < 0: 276 | raise ValueError("If input_size is a single number, it must be positive.") 277 | input_size = (input_size, input_size) 278 | elif isinstance(input_size, list) or isinstance(input_size, tuple) or isinstance(input_size, np.ndarray): 279 | if len(input_size) != 2: 280 | raise ValueError("If input_size is a sequence, it must be of len 2.") 281 | input_size = (input_size[0], input_size[1]) 282 | else: 283 | raise Exception('input_size must in Number or list or tuple or np.ndarray') 284 | if keep_ratio: 285 | # 将图片短边pad到和长边一样 286 | h, w, c = im.shape 287 | max_h = max(h, input_size[0]) 288 | max_w = max(w, input_size[1]) 289 | im_padded = np.zeros((max_h, max_w, c), dtype=np.uint8) 290 | im_padded[:h, :w] = im.copy() 291 | im = im_padded 292 | text_polys = text_polys.astype(np.float32) 293 | h, w, _ = im.shape 294 | im = cv2.resize(im, input_size) 295 | w_scale = 
input_size[0] / float(w) 296 | h_scale = input_size[1] / float(h) 297 | text_polys[:, :, 0] *= w_scale 298 | text_polys[:, :, 1] *= h_scale 299 | return im, text_polys 300 | 301 | def horizontal_flip(self, im: np.ndarray, text_polys: np.ndarray) -> tuple: 302 | """ 303 | 对图片和文本框进行水平翻转 304 | :param im: 图片 305 | :param text_polys: 文本框 306 | :return: 水平翻转之后的图片和文本框 307 | """ 308 | flip_text_polys = text_polys.copy() 309 | flip_im = cv2.flip(im, 1) 310 | h, w, _ = flip_im.shape 311 | flip_text_polys[:, :, 0] = w - flip_text_polys[:, :, 0] 312 | return flip_im, flip_text_polys 313 | 314 | def vertical_flip(self, im: np.ndarray, text_polys: np.ndarray) -> tuple: 315 | """ 316 | 对图片和文本框进行竖直翻转 317 | :param im: 图片 318 | :param text_polys: 文本框 319 | :return: 竖直翻转之后的图片和文本框 320 | """ 321 | flip_text_polys = text_polys.copy() 322 | flip_im = cv2.flip(im, 0) 323 | h, w, _ = flip_im.shape 324 | flip_text_polys[:, :, 1] = h - flip_text_polys[:, :, 1] 325 | return flip_im, flip_text_polys 326 | 327 | def test(self, im: np.ndarray, text_polys: np.ndarray): 328 | print('随机尺度缩放') 329 | t_im, t_text_polys = self.random_scale(im, text_polys, [0.5, 1, 2, 3]) 330 | print(t_im.shape, t_text_polys.dtype) 331 | show_pic(t_im, t_text_polys, 'random_scale') 332 | 333 | print('随机旋转') 334 | t_im, t_text_polys = self.random_rotate_img_bbox(im, text_polys, 10) 335 | print(t_im.shape, t_text_polys.dtype) 336 | show_pic(t_im, t_text_polys, 'random_rotate_img_bbox') 337 | 338 | print('随机裁剪') 339 | t_im, t_text_polys = self.random_crop_img_bboxes(im, text_polys) 340 | print(t_im.shape, t_text_polys.dtype) 341 | show_pic(t_im, t_text_polys, 'random_crop_img_bboxes') 342 | 343 | print('水平翻转') 344 | t_im, t_text_polys = self.horizontal_flip(im, text_polys) 345 | print(t_im.shape, t_text_polys.dtype) 346 | show_pic(t_im, t_text_polys, 'horizontal_flip') 347 | 348 | print('竖直翻转') 349 | t_im, t_text_polys = self.vertical_flip(im, text_polys) 350 | print(t_im.shape, t_text_polys.dtype) 351 | show_pic(t_im, 
def show_pic(img, bboxes=None, name='pic'):
    '''
    Draw quadrilateral boxes on a copy of the image and display it.

    :param img: image array (as read by cv2)
    :param bboxes: boxes as an (N, 4, 2) array/list of four (x, y) points;
                   None or empty draws nothing
    :param name: cv2 window title
    '''
    show_img = img.copy()
    if bboxes is None:
        # fix: np.array(None) is a 0-d object array and crashes .astype below
        bboxes = []
    if not isinstance(bboxes, np.ndarray):
        bboxes = np.array(bboxes)
    # np.int was removed in numpy >= 1.24; the builtin int is the replacement
    for point in bboxes.astype(int):
        cv2.line(show_img, tuple(point[0]), tuple(point[1]), (255, 0, 0), 2)
        cv2.line(show_img, tuple(point[1]), tuple(point[2]), (255, 0, 0), 2)
        cv2.line(show_img, tuple(point[2]), tuple(point[3]), (255, 0, 0), 2)
        cv2.line(show_img, tuple(point[3]), tuple(point[0]), (255, 0, 0), 2)
    cv2.imshow(name, show_img)
:return: 经过缩放的图片和文本 53 | """ 54 | rd_scale = float(np.random.choice(scales)) 55 | for idx in range(len(imgs)): 56 | imgs[idx] = cv2.resize(imgs[idx], dsize=None, fx=rd_scale, fy=rd_scale) 57 | imgs[idx], _ = self.rescale(imgs[idx], min_side=input_size) 58 | return imgs 59 | 60 | def rescale(self, img, min_side): 61 | h, w = img.shape[:2] 62 | scale = 1.0 63 | if min(h, w) < min_side: 64 | if h <= w: 65 | scale = 1.0 * min_side / h 66 | else: 67 | scale = 1.0 * min_side / w 68 | img = cv2.resize(img, dsize=None, fx=scale, fy=scale) 69 | return img 70 | 71 | def random_rotate_img_bbox(self, imgs, degrees: numbers.Number or list or tuple or np.ndarray, 72 | same_size=False): 73 | """ 74 | 从给定的角度中选择一个角度,对图片和文本框进行旋转 75 | :param imgs: 原图 和 label 76 | :param degrees: 角度,可以是一个数值或者list 77 | :param same_size: 是否保持和原图一样大 78 | :return: 旋转后的图片和角度 79 | """ 80 | if isinstance(degrees, numbers.Number): 81 | if degrees < 0: 82 | raise ValueError("If degrees is a single number, it must be positive.") 83 | degrees = (-degrees, degrees) 84 | elif isinstance(degrees, list) or isinstance(degrees, tuple) or isinstance(degrees, np.ndarray): 85 | if len(degrees) != 2: 86 | raise ValueError("If degrees is a sequence, it must be of len 2.") 87 | degrees = degrees 88 | else: 89 | raise Exception('degrees must in Number or list or tuple or np.ndarray') 90 | # ---------------------- 旋转图像 ---------------------- 91 | w = imgs[0].shape[1] 92 | h = imgs[0].shape[0] 93 | angle = np.random.uniform(degrees[0], degrees[1]) 94 | 95 | if same_size: 96 | nw = w 97 | nh = h 98 | else: 99 | # 角度变弧度 100 | rangle = np.deg2rad(angle) 101 | # 计算旋转之后图像的w, h 102 | nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)) 103 | nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)) 104 | # 构造仿射矩阵 105 | rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, 1) 106 | # 计算原图中心点到新图中心点的偏移量 107 | rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0])) 108 | # 更新仿射矩阵 109 | rot_mat[0, 2] += 
    def random_crop(self, imgs, img_size):
        """
        Randomly crop a (th, tw) patch from every array in imgs.

        imgs[0] is the image and imgs[1] the stacked kernel maps (H, W, n)
        (per image_label the kernels are ordered small -> large, so channel 0
        is the smallest kernel and channel -1 the full text map).  With
        probability 5/8 the crop is biased to contain text.

        :param imgs: list of arrays sharing the same H, W
        :param img_size: (th, tw) target crop size
        :return: the list with every array cropped to (th, tw)
        """
        h, w = imgs[0].shape[0:2]
        th, tw = img_size
        if w == tw and h == th:
            return imgs

        # only if the full text map contains text, and with probability 5/8,
        # restrict the crop origin so the crop covers some text
        if np.max(imgs[1][:, :, -1]) > 0 and random.random() > 3.0 / 8.0:
            # top-left-most text pixel, shifted up-left by the crop size so a
            # crop starting at tl can still contain it
            tl = np.min(np.where(imgs[1][:, :, -1] > 0), axis=1) - img_size
            tl[tl < 0] = 0
            # bottom-right-most text pixel, shifted the same way
            br = np.max(np.where(imgs[1][:, :, -1] > 0), axis=1) - img_size
            br[br < 0] = 0
            # clamp so the crop window stays inside the image
            br[0] = min(br[0], h - th)
            br[1] = min(br[1], w - tw)
            for _ in range(50000):
                i = random.randint(tl[0], br[0])
                j = random.randint(tl[1], br[1])
                # resample until the smallest kernel map has text in the crop;
                # after 50000 failures the last sampled (i, j) is used as-is
                if imgs[1][:, :, 0][i:i + th, j:j + tw].sum() <= 0:
                    continue
                else:
                    break
        else:
            # unconstrained crop anywhere in the image
            i = random.randint(0, h - th)
            j = random.randint(0, w - tw)

        # return i, j, th, tw
        for idx in range(len(imgs)):
            if len(imgs[idx].shape) == 3:
                imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
            else:
                imgs[idx] = imgs[idx][i:i + th, j:j + tw]
        return imgs
def check_and_validate_polys(polys, xxx_todo_changeme):
    '''
    Clip text polygons to the image bounds and drop degenerate ones.

    :param polys: polygon array of shape (N, points, 2), (x, y) coordinates
    :param xxx_todo_changeme: the image size as an (h, w) tuple
    :return: array of the surviving polygons
    '''
    (h, w) = xxx_todo_changeme
    if polys.shape[0] == 0:
        return polys
    # clamp x into [0, w-1] and y into [0, h-1]
    polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
    polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)

    # keep only polygons with at least one pixel of contour area
    kept = [poly for poly in polys if abs(cv2.contourArea(poly)) >= 1]
    return np.array(kept)
def augmentation(im: np.ndarray, text_polys: np.ndarray, scales: np.ndarray, degrees: int, input_size: int) -> tuple:
    """
    PSENet training-time augmentation: random rescale, then horizontal flip
    and rotation in [-degrees, degrees], each applied with probability 0.5.

    :param im: image array
    :param text_polys: text boxes, transformed together with the image
    :param scales: candidate rescale ratios (paper uses {0.5, 1.0, 2.0, 3.0})
    :param degrees: rotation range bound
    :param input_size: crop size (cropping itself happens later, in image_label)
    :return: (augmented image, augmented polygons)
    """
    im, text_polys = data_aug.random_scale(im, text_polys, scales)
    if random.random() < 0.5:
        im, text_polys = data_aug.horizontal_flip(im, text_polys)
    if random.random() < 0.5:
        im, text_polys = data_aug.random_rotate_img_bbox(im, text_polys, degrees)
    return im, text_polys
check_and_validate_polys(text_polys, (h, w)) 110 | im, text_polys, = augmentation(im, text_polys, scales, defrees, input_size) 111 | 112 | h, w, _ = im.shape 113 | short_edge = min(h, w) 114 | if short_edge < input_size: 115 | # 保证短边 >= inputsize 116 | scale = input_size / short_edge 117 | im = cv2.resize(im, dsize=None, fx=scale, fy=scale) 118 | text_polys *= scale 119 | 120 | # # normal images 121 | # im = im.astype(np.float32) 122 | # im /= 255.0 123 | # im -= np.array((0.485, 0.456, 0.406)) 124 | # im /= np.array((0.229, 0.224, 0.225)) 125 | 126 | h, w, _ = im.shape 127 | training_mask = np.ones((h, w), dtype=np.uint8) 128 | score_maps = [] 129 | for i in range(1, n + 1): 130 | # s1->sn,由小到大 131 | score_map, training_mask = generate_rbox((h, w), text_polys, text_tags, training_mask, i, n, m) 132 | score_maps.append(score_map) 133 | score_maps = np.array(score_maps, dtype=np.float32) 134 | imgs = data_aug.random_crop_author([im, score_maps.transpose((1, 2, 0)),training_mask], (input_size, input_size)) 135 | return imgs[0], imgs[1].transpose((2, 0, 1)), imgs[2]#im,score_maps,training_mask# 136 | 137 | 138 | class MyDataset(data.Dataset): 139 | def __init__(self, data_dir, data_shape: int = 640, n=6, m=0.5, transform=None, target_transform=None): 140 | self.data_list = self.load_data(data_dir) 141 | self.data_shape = data_shape 142 | self.transform = transform 143 | self.target_transform = target_transform 144 | self.n = n 145 | self.m = m 146 | 147 | def __getitem__(self, index): 148 | # print(self.image_list[index]) 149 | img_path, text_polys, text_tags = self.data_list[index] 150 | img, score_maps, training_mask = image_label(img_path, text_polys, text_tags, input_size=self.data_shape, 151 | n=self.n, 152 | m=self.m) 153 | # img = draw_bbox(img,text_polys) 154 | if self.transform: 155 | img = self.transform(img) 156 | if self.target_transform: 157 | score_maps = self.target_transform(score_maps) 158 | training_mask = self.target_transform(training_mask) 159 | 
return img, score_maps, training_mask 160 | 161 | def load_data(self, data_dir: str) -> list: 162 | data_list = [] 163 | for x in glob.glob(data_dir + '/img/*.jpg', recursive=True): 164 | d = pathlib.Path(x) 165 | label_path = os.path.join(data_dir, 'gt', ('gt_' + str(d.stem) + '.txt')) 166 | bboxs, text = self._get_annotation(label_path) 167 | if len(bboxs) > 0: 168 | data_list.append((x, bboxs, text)) 169 | else: 170 | print('there is no suit bbox on {}'.format(label_path)) 171 | return data_list 172 | 173 | def _get_annotation(self, label_path: str) -> tuple: 174 | boxes = [] 175 | text_tags = [] 176 | with open(label_path, encoding='utf-8', mode='r') as f: 177 | for line in f.readlines(): 178 | params = line.strip().strip('\ufeff').strip('\xef\xbb\xbf').split(',') 179 | try: 180 | label = params[8] 181 | if label == '*' or label == '###': 182 | text_tags.append(True) 183 | else: 184 | text_tags.append(False) 185 | # if label == '*' or label == '###': 186 | x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, params[:8])) 187 | boxes.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]]) 188 | except: 189 | print('load label failed on {}'.format(label_path)) 190 | return np.array(boxes, dtype=np.float32), np.array(text_tags, dtype=np.bool) 191 | 192 | def __len__(self): 193 | return len(self.data_list) 194 | 195 | def save_label(self, img_path, label): 196 | save_path = img_path.replace('img', 'save') 197 | if not os.path.exists(os.path.split(save_path)[0]): 198 | os.makedirs(os.path.split(save_path)[0]) 199 | img = draw_bbox(img_path, label) 200 | cv2.imwrite(save_path, img) 201 | return img 202 | 203 | 204 | if __name__ == '__main__': 205 | import torch 206 | import config 207 | from utils.utils import show_img 208 | from tqdm import tqdm 209 | from torch.utils.data import DataLoader 210 | import matplotlib.pyplot as plt 211 | from torchvision import transforms 212 | 213 | train_data = MyDataset(config.trainroot, data_shape=config.data_shape, n=config.n, m=config.m, 214 
def main(model_path, backbone, scale, path, save_path, gpu_id):
    """
    Run PSENet inference over every image in `path` and write the detected
    boxes as ICDAR-style ``res_<name>.txt`` files.

    :param model_path: checkpoint to load
    :param backbone: backbone name passed to PSENet
    :param scale: inference downscale factor
    :param path: directory of test images
    :param save_path: output root; recreated from scratch on every run
    :param gpu_id: GPU to run inference on
    :return: the directory containing the result txt files
    """
    # start from a clean output tree; rmtree+makedirs replaces the original's
    # redundant exists/rmtree/exists/makedirs sequence
    shutil.rmtree(save_path, ignore_errors=True)
    os.makedirs(save_path, exist_ok=True)
    save_img_folder = os.path.join(save_path, 'img')
    os.makedirs(save_img_folder, exist_ok=True)
    save_txt_folder = os.path.join(save_path, 'result')
    os.makedirs(save_txt_folder, exist_ok=True)
    img_paths = [os.path.join(path, x) for x in os.listdir(path)]
    net = PSENet(backbone=backbone, pretrained=False, result_num=config.n)
    model = Pytorch_model(model_path, net=net, scale=scale, gpu_id=gpu_id)
    total_frame = 0.0
    total_time = 0.0
    for img_path in tqdm(img_paths):
        img_name = os.path.basename(img_path).split('.')[0]
        save_name = os.path.join(save_txt_folder, 'res_' + img_name + '.txt')
        _, boxes_list, t = model.predict(img_path)
        total_frame += 1
        total_time += t
        np.savetxt(save_name, boxes_list.reshape(-1, 8), delimiter=',', fmt='%d')
    # guard: an empty image folder used to raise ZeroDivisionError here
    if total_time > 0:
        print('fps:{}'.format(total_frame / total_time))
    return save_txt_folder
def channel_shuffle(x, groups):
    """
    ShuffleNet channel shuffle: interleave channels across `groups`.

    The channel layout [g0_0..g0_k, g1_0..g1_k, ...] becomes
    [g0_0, g1_0, ..., g0_1, g1_1, ...] so information mixes between the
    grouped-convolution branches.
    """
    n, c, h, w = x.size()
    per_group = c // groups
    # (n, groups, per_group, h, w) -> swap the group/channel axes -> flatten
    shuffled = x.view(n, groups, per_group, h, w).transpose(1, 2).contiguous()
    return shuffled.view(n, -1, h, w)
class InvertedResidual(nn.Module):
    """ShuffleNetV2 building block.

    stride == 1: the input is split in half; one half passes through
    untouched, the other through a 1x1 -> dw3x3 -> 1x1 stack, and the halves
    are concatenated again.  stride > 1: both a downsampling depthwise branch
    and the conv stack see the full input, doubling the channel count.
    A channel shuffle mixes the halves afterwards.
    """

    def __init__(self, inp, oup, stride):
        super(InvertedResidual, self).__init__()

        if not (1 <= stride <= 3):
            raise ValueError('illegal stride value')
        self.stride = stride

        half = oup // 2
        # at stride 1 the pass-through half must match the branch width
        assert (self.stride != 1) or (inp == half << 1)

        if self.stride > 1:
            # spatial downsampling path applied to the raw input
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, half, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(half),
                nn.ReLU(inplace=True),
            )

        branch2_in = inp if self.stride > 1 else half
        self.branch2 = nn.Sequential(
            nn.Conv2d(branch2_in, half, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(half),
            nn.ReLU(inplace=True),
            self.depthwise_conv(half, half, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(half),
            nn.Conv2d(half, half, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(half),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        # groups == in-channels: one spatial filter per channel
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        if self.stride == 1:
            keep, transform = x.chunk(2, dim=1)
            out = torch.cat((keep, self.branch2(transform)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        return channel_shuffle(out, 2)
def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
    """
    Build a ShuffleNetV2 and optionally load ImageNet weights.

    :param arch: key into model_urls
    :param pretrained: load pretrained weights when True
    :param progress: show a download progress bar
    :raises NotImplementedError: when no pretrained weights exist for `arch`
    """
    model = ShuffleNetV2(*args, **kwargs)
    if not pretrained:
        return model

    model_url = model_urls[arch]
    if model_url is None:
        raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
    state_dict = load_state_dict_from_url(model_url, progress=progress)
    # strict=False: this detection variant drops the classifier head
    model.load_state_dict(state_dict, strict=False)
    logger.info('load pretrained models from imagenet')
    return model
def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
    """
    ShuffleNetV2 with 1.0x output channels, as described in
    "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture
    Design".

    Args:
        pretrained (bool): If True, returns a models pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 116, 232, 464, 1024]
    return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
if __name__ == '__main__':
    # quick CPU smoke test: time one forward pass and print the feature shapes
    import time

    device = torch.device('cpu')
    net = shufflenet_v2_x1_0(pretrained=True)
    net.eval()
    dummy = torch.zeros(1, 3, 640, 640).to(device)
    start = time.time()
    features = net(dummy)
    print(time.time() - start)
    for feat in features:
        print(feat.shape)
    # torch.save(net.state_dict(), f'shufflenet_v2_x2_0.pth')
    def forward(self, outputs, labels, training_masks):
        """
        Compute the PSE loss.

        :param outputs: raw network logits (B, n, H, W); channel -1 is the
                        full-text map, channels :-1 the shrunk kernels
        :param labels: ground-truth maps with the same channel layout
        :param training_masks: (B, H, W), 0 where pixels must be ignored
        :return: (loss_text, loss_kernels, combined loss)
        """
        texts = outputs[:, -1, :, :]
        kernels = outputs[:, :-1, :, :]
        gt_texts = labels[:, -1, :, :]
        gt_kernels = labels[:, :-1, :, :]

        # text loss: dice over an OHEM-selected pixel set (hard negatives)
        selected_masks = self.ohem_batch(texts, gt_texts, training_masks)
        selected_masks = selected_masks.to(outputs.device)

        loss_text = self.dice_loss(texts, gt_texts, selected_masks)

        # kernel losses: only evaluated where the text head is already
        # confident (sigmoid > 0.5) and the pixel is trainable
        loss_kernels = []
        mask0 = torch.sigmoid(texts).data.cpu().numpy()
        mask1 = training_masks.data.cpu().numpy()
        selected_masks = ((mask0 > 0.5) & (mask1 > 0.5)).astype('float32')
        selected_masks = torch.from_numpy(selected_masks).float()
        selected_masks = selected_masks.to(outputs.device)
        kernels_num = gt_kernels.size()[1]
        for i in range(kernels_num):
            kernel_i = kernels[:, i, :, :]
            gt_kernel_i = gt_kernels[:, i, :, :]
            loss_kernel_i = self.dice_loss(kernel_i, gt_kernel_i, selected_masks)
            loss_kernels.append(loss_kernel_i)
        # average over the n-1 kernel maps, keeping the per-sample dimension
        loss_kernels = torch.stack(loss_kernels).mean(0)
        if self.reduction == 'mean':
            loss_text = loss_text.mean()
            loss_kernels = loss_kernels.mean()
        elif self.reduction == 'sum':
            loss_text = loss_text.sum()
            loss_kernels = loss_kernels.sum()

        # weighted combination per the paper: L = λ·L_text + (1-λ)·L_kernels
        loss = self.Lambda * loss_text + (1 - self.Lambda) * loss_kernels
        return loss_text, loss_kernels, loss
0.001 65 | c = torch.sum(target * target, 1) + 0.001 66 | d = (2 * a) / (b + c) 67 | return 1 - d 68 | 69 | def ohem_single(self, score, gt_text, training_mask): 70 | pos_num = (int)(np.sum(gt_text > 0.5)) - (int)(np.sum((gt_text > 0.5) & (training_mask <= 0.5))) 71 | 72 | if pos_num == 0: 73 | # selected_mask = gt_text.copy() * 0 # may be not good 74 | selected_mask = training_mask 75 | selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32') 76 | return selected_mask 77 | 78 | neg_num = (int)(np.sum(gt_text <= 0.5)) 79 | neg_num = (int)(min(pos_num * 3, neg_num)) 80 | 81 | if neg_num == 0: 82 | selected_mask = training_mask 83 | selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32') 84 | return selected_mask 85 | 86 | neg_score = score[gt_text <= 0.5] 87 | # 将负样本得分从高到低排序 88 | neg_score_sorted = np.sort(-neg_score) 89 | threshold = -neg_score_sorted[neg_num - 1] 90 | # 选出 得分高的 负样本 和正样本 的 mask 91 | selected_mask = ((score >= threshold) | (gt_text > 0.5)) & (training_mask > 0.5) 92 | selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32') 93 | return selected_mask 94 | 95 | def ohem_batch(self, scores, gt_texts, training_masks): 96 | scores = scores.data.cpu().numpy() 97 | gt_texts = gt_texts.data.cpu().numpy() 98 | training_masks = training_masks.data.cpu().numpy() 99 | 100 | selected_masks = [] 101 | for i in range(scores.shape[0]): 102 | selected_masks.append(self.ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[i, :, :])) 103 | 104 | selected_masks = np.concatenate(selected_masks, 0) 105 | selected_masks = torch.from_numpy(selected_masks).float() 106 | 107 | return selected_masks 108 | -------------------------------------------------------------------------------- /models/mobilenetv3.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 
class hswish(nn.Module):
    """Hard swish: x * relu6(x + 3) / 6 (MobileNetV3 activation)."""

    def forward(self, x):
        return x * F.relu6(x + 3, inplace=True) / 6


class hsigmoid(nn.Module):
    """Hard sigmoid: relu6(x + 3) / 6, a piecewise-linear gate in [0, 1]."""

    def forward(self, x):
        return F.relu6(x + 3, inplace=True) / 6


class SeModule(nn.Module):
    """Squeeze-and-Excitation gate built from 1x1 convs and a hard sigmoid."""

    def __init__(self, in_size, reduction=4):
        super(SeModule, self).__init__()
        squeezed = in_size // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.se = nn.Sequential(
            nn.Conv2d(in_size, squeezed, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(squeezed),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, in_size, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(in_size),
            hsigmoid(),
        )

    def forward(self, x):
        # NOTE(review): self.avg_pool is registered but never applied, so the
        # gate is computed per-pixel rather than globally squeezed — kept
        # as-is to preserve behavior; confirm intent.
        return x * self.se(x)
stride=1, padding=0, bias=False), 63 | nn.BatchNorm2d(out_size), 64 | ) 65 | 66 | def forward(self, x): 67 | out = self.nolinear1(self.bn1(self.conv1(x))) 68 | out = self.nolinear2(self.bn2(self.conv2(out))) 69 | out = self.bn3(self.conv3(out)) 70 | if self.se != None: 71 | out = self.se(out) 72 | out = out + self.shortcut(x) if self.stride == 1 else out 73 | return out 74 | 75 | 76 | class MobileNetV3_Large(nn.Module): 77 | def __init__(self, pretrained): 78 | super(MobileNetV3_Large, self).__init__() 79 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False) 80 | self.bn1 = nn.BatchNorm2d(16) 81 | self.hs1 = hswish() 82 | 83 | self.layer1 = nn.Sequential( 84 | Block(3, 16, 16, 16, nn.ReLU(inplace=True), None, 1), 85 | Block(3, 16, 64, 24, nn.ReLU(inplace=True), None, 2), 86 | Block(3, 24, 72, 24, nn.ReLU(inplace=True), None, 1), 87 | ) 88 | 89 | self.layer2 = nn.Sequential( 90 | Block(5, 24, 72, 40, nn.ReLU(inplace=True), SeModule(40), 2), 91 | Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1), 92 | Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1), 93 | ) 94 | 95 | self.layer3 = nn.Sequential( 96 | Block(3, 40, 240, 80, hswish(), None, 2), 97 | Block(3, 80, 200, 80, hswish(), None, 1), 98 | Block(3, 80, 184, 80, hswish(), None, 1), 99 | Block(3, 80, 184, 80, hswish(), None, 1), 100 | Block(3, 80, 480, 112, hswish(), SeModule(112), 1), 101 | Block(3, 112, 672, 112, hswish(), SeModule(112), 1), 102 | Block(5, 112, 672, 160, hswish(), SeModule(160), 1), 103 | ) 104 | self.layer4 = nn.Sequential( 105 | Block(5, 160, 672, 160, hswish(), SeModule(160), 2), 106 | Block(5, 160, 960, 160, hswish(), SeModule(160), 1), 107 | ) 108 | self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False) 109 | self.bn2 = nn.BatchNorm2d(960) 110 | self.hs2 = hswish() 111 | self.init_params() 112 | 113 | def init_params(self): 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | 
init.kaiming_normal_(m.weight, mode='fan_out') 117 | if m.bias is not None: 118 | init.constant_(m.bias, 0) 119 | elif isinstance(m, nn.BatchNorm2d): 120 | init.constant_(m.weight, 1) 121 | init.constant_(m.bias, 0) 122 | elif isinstance(m, nn.Linear): 123 | init.normal_(m.weight, std=0.001) 124 | if m.bias is not None: 125 | init.constant_(m.bias, 0) 126 | 127 | def forward(self, x): 128 | c1 = self.hs1(self.bn1(self.conv1(x))) 129 | c2 = self.layer1(c1) 130 | c3 = self.layer2(c2) 131 | c4 = self.layer3(c3) 132 | c5 = self.layer4(c4) 133 | # c5 = self.hs2(self.bn2(self.conv2(c5))) 134 | return c1, c2, c3, c4, c5 135 | 136 | 137 | class MobileNetV3_Small(nn.Module): 138 | def __init__(self, pretrained): 139 | super(MobileNetV3_Small, self).__init__() 140 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False) 141 | self.bn1 = nn.BatchNorm2d(16) 142 | self.hs1 = hswish() 143 | 144 | self.layer1 = Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2) 145 | self.layer2 = nn.Sequential( 146 | Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2), 147 | Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1), 148 | ) 149 | 150 | self.layer3 = nn.Sequential( 151 | Block(5, 24, 96, 40, hswish(), SeModule(40), 2), 152 | Block(5, 40, 240, 40, hswish(), SeModule(40), 1), 153 | Block(5, 40, 240, 40, hswish(), SeModule(40), 1), 154 | Block(5, 40, 120, 48, hswish(), SeModule(48), 1), 155 | Block(5, 48, 144, 48, hswish(), SeModule(48), 1), 156 | ) 157 | self.layer4 = nn.Sequential( 158 | Block(5, 48, 288, 96, hswish(), SeModule(96), 2), 159 | Block(5, 96, 576, 96, hswish(), SeModule(96), 1), 160 | Block(5, 96, 576, 96, hswish(), SeModule(96), 1), 161 | ) 162 | self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False) 163 | self.bn2 = nn.BatchNorm2d(576) 164 | self.hs2 = hswish() 165 | self.init_params() 166 | 167 | def init_params(self): 168 | for m in self.modules(): 169 | if isinstance(m, nn.Conv2d): 170 | 
init.kaiming_normal_(m.weight, mode='fan_out') 171 | if m.bias is not None: 172 | init.constant_(m.bias, 0) 173 | elif isinstance(m, nn.BatchNorm2d): 174 | init.constant_(m.weight, 1) 175 | init.constant_(m.bias, 0) 176 | elif isinstance(m, nn.Linear): 177 | init.normal_(m.weight, std=0.001) 178 | if m.bias is not None: 179 | init.constant_(m.bias, 0) 180 | 181 | def forward(self, x): 182 | c1 = self.hs1(self.bn1(self.conv1(x))) 183 | c2 = self.layer1(c1) 184 | c3 = self.layer2(c2) 185 | c4 = self.layer3(c3) 186 | c5 = self.layer4(c4) 187 | # c5 = self.hs2(self.bn2(self.conv2(c5))) 188 | return c1, c2, c3, c4, c5 189 | 190 | 191 | if __name__ == '__main__': 192 | import time 193 | 194 | device = torch.device('cpu') 195 | net = MobileNetV3_Large(pretrained=False) 196 | net.eval() 197 | x = torch.zeros(1, 3, 608, 800).to(device) 198 | start = time.time() 199 | y = net(x) 200 | print(time.time() - start) 201 | for u in y: 202 | print(u.shape) 203 | torch.save(net.state_dict(), f'MobileNetV3_Large111.pth') 204 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2019/1/2 17:29 3 | # @Author : zhoujun 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152 8 | from models.mobilenetv3 import MobileNetV3_Large, MobileNetV3_Small 9 | from models.ShuffleNetV2 import shufflenet_v2_x1_0 10 | 11 | d = {'resnet18': {'models': resnet18, 'out': [64, 128, 256, 512]}, 12 | 'resnet34': {'models': resnet34, 'out': [64, 128, 256, 512]}, 13 | 'resnet50': {'models': resnet50, 'out': [256, 512, 1024, 2048]}, 14 | 'resnet101': {'models': resnet101, 'out': [256, 512, 1024, 2048]}, 15 | 'resnet152': {'models': resnet152, 'out': [256, 512, 1024, 2048]}, 16 | 'MobileNetV3_Large': {'models': MobileNetV3_Large, 'out': [24, 
40, 160, 160]}, 17 | 'MobileNetV3_Small': {'models': MobileNetV3_Small, 'out': [16, 24, 48, 96]}, 18 | 'shufflenetv2': {'models': shufflenet_v2_x1_0, 'out': [24, 116, 232, 464]}} 19 | inplace = True 20 | 21 | 22 | class PSENet(nn.Module): 23 | def __init__(self, backbone, result_num=6, scale: int = 1, pretrained=False): 24 | super(PSENet, self).__init__() 25 | assert backbone in d, 'backbone must in: {}'.format(d) 26 | self.scale = scale 27 | conv_out = 256 28 | model, out = d[backbone]['models'], d[backbone]['out'] 29 | self.backbone = model(pretrained=pretrained) 30 | # Reduce channels 31 | # Top layer 32 | self.toplayer = nn.Sequential(nn.Conv2d(out[3], conv_out, kernel_size=1, stride=1, padding=0), 33 | nn.BatchNorm2d(conv_out), 34 | nn.ReLU(inplace=inplace) 35 | ) 36 | # Lateral layers 37 | self.latlayer1 = nn.Sequential(nn.Conv2d(out[2], conv_out, kernel_size=1, stride=1, padding=0), 38 | nn.BatchNorm2d(conv_out), 39 | nn.ReLU(inplace=inplace) 40 | ) 41 | self.latlayer2 = nn.Sequential(nn.Conv2d(out[1], conv_out, kernel_size=1, stride=1, padding=0), 42 | nn.BatchNorm2d(conv_out), 43 | nn.ReLU(inplace=inplace) 44 | ) 45 | self.latlayer3 = nn.Sequential(nn.Conv2d(out[0], conv_out, kernel_size=1, stride=1, padding=0), 46 | nn.BatchNorm2d(conv_out), 47 | nn.ReLU(inplace=inplace) 48 | ) 49 | 50 | # Smooth layers 51 | self.smooth1 = nn.Sequential(nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1), 52 | nn.BatchNorm2d(conv_out), 53 | nn.ReLU(inplace=inplace) 54 | ) 55 | self.smooth2 = nn.Sequential(nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1), 56 | nn.BatchNorm2d(conv_out), 57 | nn.ReLU(inplace=inplace) 58 | ) 59 | self.smooth3 = nn.Sequential(nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1), 60 | nn.BatchNorm2d(conv_out), 61 | nn.ReLU(inplace=inplace) 62 | ) 63 | 64 | self.conv = nn.Sequential( 65 | nn.Conv2d(conv_out * 4, conv_out, kernel_size=3, padding=1, stride=1), 66 | nn.BatchNorm2d(conv_out), 67 | 
nn.ReLU(inplace=inplace) 68 | ) 69 | self.out_conv = nn.Conv2d(conv_out, result_num, kernel_size=1, stride=1) 70 | 71 | def forward(self, input: torch.Tensor): 72 | _, _, H, W = input.size() 73 | c2, c3, c4, c5 = self.backbone(input) 74 | # Top-down 75 | p5 = self.toplayer(c5) 76 | p4 = self._upsample_add(p5, self.latlayer1(c4)) 77 | p4 = self.smooth1(p4) 78 | p3 = self._upsample_add(p4, self.latlayer2(c3)) 79 | p3 = self.smooth2(p3) 80 | p2 = self._upsample_add(p3, self.latlayer3(c2)) 81 | p2 = self.smooth3(p2) 82 | 83 | x = self._upsample_cat(p2, p3, p4, p5) 84 | x = self.conv(x) 85 | x = self.out_conv(x) 86 | 87 | if self.train: 88 | x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) 89 | else: 90 | x = F.interpolate(x, size=(H // self.scale, W // self.scale), mode='bilinear', align_corners=True) 91 | return x 92 | 93 | def _upsample_add(self, x, y): 94 | return F.interpolate(x, size=y.size()[2:], mode='bilinear', align_corners=False) + y 95 | 96 | def _upsample_cat(self, p2, p3, p4, p5): 97 | h, w = p2.size()[2:] 98 | p3 = F.interpolate(p3, size=(h, w), mode='bilinear', align_corners=False) 99 | p4 = F.interpolate(p4, size=(h, w), mode='bilinear', align_corners=False) 100 | p5 = F.interpolate(p5, size=(h, w), mode='bilinear', align_corners=False) 101 | return torch.cat([p2, p3, p4, p5], dim=1) 102 | 103 | 104 | if __name__ == '__main__': 105 | import time 106 | 107 | device = torch.device('cpu') 108 | backbone = 'shufflenetv2' 109 | net = PSENet(backbone=backbone, pretrained=False, result_num=6).to(device) 110 | net.eval() 111 | x = torch.zeros(1, 3, 512, 512).to(device) 112 | start = time.time() 113 | y = net(x) 114 | print(time.time() - start) 115 | print(y.shape) 116 | # torch.save(net.state_dict(),f'{backbone}.pth') 117 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 
@Time : 2019/1/2 17:30 3 | # @Author : zhoujun 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | import logging 8 | import torch.utils.model_zoo as model_zoo 9 | import torchvision.models.resnet 10 | 11 | logger = logging.getLogger('project') 12 | 13 | __all__ = ['ResNet', 'resnet50', 'resnet101', 14 | 'resnet152'] 15 | 16 | model_urls = { 17 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 18 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 19 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 20 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 21 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 22 | } 23 | 24 | 25 | def conv3x3(in_planes, out_planes, stride=1): 26 | """3x3 convolution with padding""" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 28 | padding=1, bias=False) 29 | 30 | 31 | class BasicBlock(nn.Module): 32 | expansion = 1 33 | 34 | def __init__(self, inplanes, planes, stride=1, downsample=None): 35 | super(BasicBlock, self).__init__() 36 | self.conv1 = conv3x3(inplanes, planes, stride) 37 | self.bn1 = nn.BatchNorm2d(planes) 38 | self.relu = nn.ReLU(inplace=True) 39 | self.conv2 = conv3x3(planes, planes) 40 | self.bn2 = nn.BatchNorm2d(planes) 41 | self.downsample = downsample 42 | self.stride = stride 43 | 44 | def forward(self, x): 45 | residual = x 46 | 47 | out = self.conv1(x) 48 | out = self.bn1(out) 49 | out = self.relu(out) 50 | 51 | out = self.conv2(out) 52 | out = self.bn2(out) 53 | 54 | if self.downsample is not None: 55 | residual = self.downsample(x) 56 | 57 | out += residual 58 | out = self.relu(out) 59 | 60 | return out 61 | 62 | 63 | class Bottleneck(nn.Module): 64 | expansion = 4 65 | 66 | def __init__(self, inplanes, planes, stride=1, downsample=None): 67 | super(Bottleneck, self).__init__() 68 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 
69 | self.bn1 = nn.BatchNorm2d(planes) 70 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 71 | padding=1, bias=False) 72 | self.bn2 = nn.BatchNorm2d(planes) 73 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 74 | self.bn3 = nn.BatchNorm2d(planes * 4) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.downsample = downsample 77 | self.stride = stride 78 | 79 | def forward(self, x): 80 | residual = x 81 | 82 | out = self.conv1(x) 83 | out = self.bn1(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv2(out) 87 | out = self.bn2(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv3(out) 91 | out = self.bn3(out) 92 | 93 | if self.downsample is not None: 94 | residual = self.downsample(x) 95 | 96 | out += residual 97 | out = self.relu(out) 98 | 99 | return out 100 | 101 | 102 | class ResNet(nn.Module): 103 | 104 | def __init__(self, block, layers): 105 | self.inplanes = 64 106 | super(ResNet, self).__init__() 107 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 108 | bias=False) 109 | self.bn1 = nn.BatchNorm2d(64) 110 | self.relu = nn.ReLU(inplace=True) 111 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 112 | self.layer1 = self._make_layer(block, 64, layers[0]) 113 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 114 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 115 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 116 | 117 | for m in self.modules(): 118 | if isinstance(m, nn.Conv2d): 119 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 120 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 121 | elif isinstance(m, nn.BatchNorm2d): 122 | m.weight.data.fill_(1) 123 | m.bias.data.zero_() 124 | 125 | def _make_layer(self, block, planes, blocks, stride=1): 126 | downsample = None 127 | if stride != 1 or self.inplanes != planes * block.expansion: 128 | downsample = nn.Sequential( 129 | nn.Conv2d(self.inplanes, planes * block.expansion, 130 | kernel_size=1, stride=stride, bias=False), 131 | nn.BatchNorm2d(planes * block.expansion), 132 | ) 133 | 134 | layers = [] 135 | layers.append(block(self.inplanes, planes, stride, downsample)) 136 | self.inplanes = planes * block.expansion 137 | for i in range(1, blocks): 138 | layers.append(block(self.inplanes, planes)) 139 | 140 | return nn.Sequential(*layers) 141 | 142 | def _load_pretrained_model(self, model_url): 143 | pretrain_dict = model_zoo.load_url(model_url) 144 | model_dict = {} 145 | state_dict = self.state_dict() 146 | for k, v in pretrain_dict.items(): 147 | if k in state_dict: 148 | model_dict[k] = v 149 | state_dict.update(model_dict) 150 | self.load_state_dict(state_dict) 151 | logger.info('load pretrained models from imagenet') 152 | 153 | def forward(self, input): 154 | x = self.conv1(input) 155 | x = self.bn1(x) 156 | x = self.relu(x) 157 | x = self.maxpool(x) 158 | c2 = self.layer1(x) 159 | c3 = self.layer2(c2) 160 | c4 = self.layer3(c3) 161 | c5 = self.layer4(c4) 162 | return c2, c3, c4, c5 163 | 164 | def resnet18(pretrained=False, **kwargs): 165 | """Constructs a ResNet-18 models. 166 | 167 | Args: 168 | pretrained (bool): If True, returns a models pre-trained on ImageNet 169 | """ 170 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 171 | if pretrained: 172 | model._load_pretrained_model(model_urls['resnet18']) 173 | return model 174 | 175 | 176 | def resnet34(pretrained=False, **kwargs): 177 | """Constructs a ResNet-34 models. 
178 | 179 | Args: 180 | pretrained (bool): If True, returns a models pre-trained on ImageNet 181 | """ 182 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 183 | if pretrained: 184 | model._load_pretrained_model(model_urls['resnet34']) 185 | return model 186 | 187 | def resnet50(pretrained=False, **kwargs): 188 | """Constructs a ResNet-50 models. 189 | 190 | Args: 191 | pretrained (bool): If True, returns a models pre-trained on ImageNet 192 | """ 193 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 194 | if pretrained: 195 | model._load_pretrained_model(model_urls['resnet50']) 196 | return model 197 | 198 | 199 | def resnet101(pretrained=False, **kwargs): 200 | """Constructs a ResNet-101 models. 201 | 202 | Args: 203 | pretrained (bool): If True, returns a models pre-trained on ImageNet 204 | """ 205 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 206 | if pretrained: 207 | model._load_pretrained_model(model_urls['resnet101']) 208 | return model 209 | 210 | 211 | def resnet152(pretrained=False, **kwargs): 212 | """Constructs a ResNet-152 models. 
213 | 214 | Args: 215 | pretrained (bool): If True, returns a models pre-trained on ImageNet 216 | """ 217 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 218 | if pretrained: 219 | model._load_pretrained_model(model_urls['resnet152']) 220 | return model 221 | 222 | 223 | if __name__ == '__main__': 224 | x = torch.zeros(1, 3, 640, 640) 225 | net = resnet50() 226 | y = net(x) 227 | for u in y: 228 | print(u.shape) 229 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 1/4/19 11:14 AM 3 | # @Author : zhoujun 4 | import torch 5 | from torchvision import transforms 6 | import os 7 | import cv2 8 | import time 9 | import numpy as np 10 | 11 | from pse import decode as pse_decode 12 | 13 | 14 | class Pytorch_model: 15 | def __init__(self, model_path, net, scale, gpu_id=None): 16 | ''' 17 | 初始化pytorch模型 18 | :param model_path: 模型地址(可以是模型的参数或者参数和计算图一起保存的文件) 19 | :param net: 网络计算图,如果在model_path中指定的是参数的保存路径,则需要给出网络的计算图 20 | :param img_channel: 图像的通道数: 1,3 21 | :param gpu_id: 在哪一块gpu上运行 22 | ''' 23 | self.scale = scale 24 | if gpu_id is not None and isinstance(gpu_id, int) and torch.cuda.is_available(): 25 | self.device = torch.device("cuda:{}".format(gpu_id)) 26 | else: 27 | self.device = torch.device("cpu") 28 | self.net = torch.load(model_path, map_location=self.device)['state_dict'] 29 | print('device:', self.device) 30 | 31 | if net is not None: 32 | # 如果网络计算图和参数是分开保存的,就执行参数加载 33 | net = net.to(self.device) 34 | net.scale = scale 35 | try: 36 | sk = {} 37 | for k in self.net: 38 | sk[k[7:]] = self.net[k] 39 | net.load_state_dict(sk) 40 | except: 41 | net.load_state_dict(self.net) 42 | self.net = net 43 | print('load models') 44 | self.net.eval() 45 | 46 | def predict(self, img: str, long_size: int = 2240): 47 | ''' 48 | 对传入的图像进行预测,支持图像地址,opecv 读取图片,偏慢 49 | :param img: 图像地址 50 | :param is_numpy: 51 | 
:return: 52 | ''' 53 | assert os.path.exists(img), 'file is not exists' 54 | img = cv2.imread(img) 55 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 56 | h, w = img.shape[:2] 57 | 58 | scale = long_size / max(h, w) 59 | img = cv2.resize(img, None, fx=scale, fy=scale) 60 | # 将图片由(w,h)变为(1,img_channel,h,w) 61 | tensor = transforms.ToTensor()(img) 62 | tensor = tensor.unsqueeze_(0) 63 | 64 | tensor = tensor.to(self.device) 65 | with torch.no_grad(): 66 | torch.cuda.synchronize() 67 | start = time.time() 68 | preds = self.net(tensor) 69 | preds, boxes_list = pse_decode(preds[0], self.scale) 70 | scale = (preds.shape[1] / w, preds.shape[0] / h) 71 | # print(scale) 72 | # preds, boxes_list = decode(preds,num_pred=-1) 73 | if len(boxes_list): 74 | boxes_list = boxes_list / scale 75 | torch.cuda.synchronize() 76 | t = time.time() - start 77 | return preds, boxes_list, t 78 | 79 | 80 | def _get_annotation(label_path): 81 | boxes = [] 82 | with open(label_path, encoding='utf-8', mode='r') as f: 83 | for line in f.readlines(): 84 | params = line.strip().strip('\ufeff').strip('\xef\xbb\xbf').split(',') 85 | try: 86 | label = params[8] 87 | if label == '*' or label == '###': 88 | continue 89 | x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, params[:8])) 90 | boxes.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]]) 91 | except: 92 | print('load label failed on {}'.format(label_path)) 93 | return np.array(boxes, dtype=np.float32) 94 | 95 | 96 | if __name__ == '__main__': 97 | import config 98 | from models import PSENet 99 | import matplotlib.pyplot as plt 100 | from utils.utils import show_img, draw_bbox 101 | 102 | os.environ['CUDA_VISIBLE_DEVICES'] = str('2') 103 | 104 | model_path = 'output/psenet_icd2015_resnet152_author_crop_adam_warm_up_myloss/best_r0.714011_p0.708214_f10.711100.pth' 105 | 106 | # model_path = 'output/psenet_icd2015_new_loss/final.pth' 107 | img_id = 10 108 | img_path = '/data2/dataset/ICD15/test/img/img_{}.jpg'.format(img_id) 109 | label_path = 
'/data2/dataset/ICD15/test/gt/gt_img_{}.txt'.format(img_id) 110 | label = _get_annotation(label_path) 111 | 112 | # 初始化网络 113 | net = PSENet(backbone='resnet152', pretrained=False, result_num=config.n) 114 | model = Pytorch_model(model_path, net=net, scale=1, gpu_id=0) 115 | # for i in range(100): 116 | # models.predict(img_path) 117 | preds, boxes_list,t = model.predict(img_path) 118 | print(boxes_list) 119 | show_img(preds) 120 | img = draw_bbox(img_path, boxes_list, color=(0, 0, 255)) 121 | cv2.imwrite('result.jpg', img) 122 | # img = draw_bbox(img, label,color=(0,0,255)) 123 | show_img(img, color=True) 124 | 125 | plt.show() 126 | -------------------------------------------------------------------------------- /pse/Makefile: -------------------------------------------------------------------------------- 1 | CXXFLAGS = -I include -std=c++11 -O3 $(shell python3-config --cflags) 2 | LDFLAGS = $(shell python3-config --ldflags) 3 | 4 | DEPS = $(shell find include -xtype f) 5 | CXX_SOURCES = pse.cpp 6 | 7 | LIB_SO = pse.so 8 | 9 | $(LIB_SO): $(CXX_SOURCES) $(DEPS) 10 | $(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC 11 | 12 | clean: 13 | rm -rf $(LIB_SO) 14 | -------------------------------------------------------------------------------- /pse/__init__.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | import numpy as np 4 | import cv2 5 | import torch 6 | 7 | BASE_DIR = os.path.dirname(os.path.realpath(__file__)) 8 | 9 | if subprocess.call(['make', '-C', BASE_DIR]) != 0: # return value 10 | raise RuntimeError('Cannot compile pse: {}'.format(BASE_DIR)) 11 | 12 | def pse_warpper(kernals, min_area=5): 13 | ''' 14 | reference https://github.com/liuheng92/tensorflow_PSENet/blob/feature_dev/pse 15 | :param kernals: 16 | :param min_area: 17 | :return: 18 | ''' 19 | from .pse import pse_cpp 20 | kernal_num = len(kernals) 21 | if not kernal_num: 22 | return np.array([]), [] 23 | 
kernals = np.array(kernals) 24 | label_num, label = cv2.connectedComponents(kernals[0].astype(np.uint8), connectivity=4) 25 | label_values = [] 26 | for label_idx in range(1, label_num): 27 | if np.sum(label == label_idx) < min_area: 28 | label[label == label_idx] = 0 29 | continue 30 | label_values.append(label_idx) 31 | 32 | pred = pse_cpp(label, kernals, c=kernal_num) 33 | 34 | return np.array(pred), label_values 35 | 36 | 37 | def decode(preds, scale, threshold=0.7311): 38 | """ 39 | 在输出上使用sigmoid 将值转换为置信度,并使用阈值来进行文字和背景的区分 40 | :param preds: 网络输出 41 | :param scale: 网络的scale 42 | :param threshold: sigmoid的阈值 43 | :return: 最后的输出图和文本框 44 | """ 45 | preds = torch.sigmoid(preds) 46 | preds = preds.detach().cpu().numpy() 47 | 48 | score = preds[-1].astype(np.float32) 49 | preds = preds > threshold 50 | # preds = preds * preds[-1] # 使用最大的kernel作为其他小图的mask,不使用的话效果更好 51 | pred, label_values = pse_warpper(preds, 5) 52 | bbox_list = [] 53 | for label_value in label_values: 54 | points = np.array(np.where(pred == label_value)).transpose((1, 0))[:, ::-1] 55 | 56 | if points.shape[0] < 800 / (scale * scale): 57 | continue 58 | 59 | score_i = np.mean(score[pred == label_value]) 60 | if score_i < 0.93: 61 | continue 62 | 63 | rect = cv2.minAreaRect(points) 64 | bbox = cv2.boxPoints(rect) 65 | bbox_list.append([bbox[1], bbox[2], bbox[3], bbox[0]]) 66 | return pred, np.array(bbox_list) 67 | -------------------------------------------------------------------------------- /pse/include/pybind11/buffer_info.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/buffer_info.h: Python buffer object interface 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "detail/common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | 16 | /// Information record describing a Python buffer object 17 | struct buffer_info { 18 | void *ptr = nullptr; // Pointer to the underlying storage 19 | ssize_t itemsize = 0; // Size of individual items in bytes 20 | ssize_t size = 0; // Total number of entries 21 | std::string format; // For homogeneous buffers, this should be set to format_descriptor::format() 22 | ssize_t ndim = 0; // Number of dimensions 23 | std::vector shape; // Shape of the tensor (1 entry per dimension) 24 | std::vector strides; // Number of entries between adjacent entries (for each per dimension) 25 | 26 | buffer_info() { } 27 | 28 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, 29 | detail::any_container shape_in, detail::any_container strides_in) 30 | : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim), 31 | shape(std::move(shape_in)), strides(std::move(strides_in)) { 32 | if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size()) 33 | pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length"); 34 | for (size_t i = 0; i < (size_t) ndim; ++i) 35 | size *= shape[i]; 36 | } 37 | 38 | template 39 | buffer_info(T *ptr, detail::any_container shape_in, detail::any_container strides_in) 40 | : buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor::format(), static_cast(shape_in->size()), std::move(shape_in), std::move(strides_in)) { } 41 | 42 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size) 43 | : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { } 44 | 45 | template 46 | buffer_info(T *ptr, ssize_t size) 47 | : buffer_info(ptr, sizeof(T), format_descriptor::format(), size) { } 48 | 49 | explicit buffer_info(Py_buffer *view, bool ownview = true) 50 | : buffer_info(view->buf, view->itemsize, view->format, view->ndim, 51 | {view->shape, 
view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) { 52 | this->view = view; 53 | this->ownview = ownview; 54 | } 55 | 56 | buffer_info(const buffer_info &) = delete; 57 | buffer_info& operator=(const buffer_info &) = delete; 58 | 59 | buffer_info(buffer_info &&other) { 60 | (*this) = std::move(other); 61 | } 62 | 63 | buffer_info& operator=(buffer_info &&rhs) { 64 | ptr = rhs.ptr; 65 | itemsize = rhs.itemsize; 66 | size = rhs.size; 67 | format = std::move(rhs.format); 68 | ndim = rhs.ndim; 69 | shape = std::move(rhs.shape); 70 | strides = std::move(rhs.strides); 71 | std::swap(view, rhs.view); 72 | std::swap(ownview, rhs.ownview); 73 | return *this; 74 | } 75 | 76 | ~buffer_info() { 77 | if (view && ownview) { PyBuffer_Release(view); delete view; } 78 | } 79 | 80 | private: 81 | struct private_ctr_tag { }; 82 | 83 | buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, 84 | detail::any_container &&shape_in, detail::any_container &&strides_in) 85 | : buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { } 86 | 87 | Py_buffer *view = nullptr; 88 | bool ownview = false; 89 | }; 90 | 91 | NAMESPACE_BEGIN(detail) 92 | 93 | template struct compare_buffer_info { 94 | static bool compare(const buffer_info& b) { 95 | return b.format == format_descriptor::format() && b.itemsize == (ssize_t) sizeof(T); 96 | } 97 | }; 98 | 99 | template struct compare_buffer_info::value>> { 100 | static bool compare(const buffer_info& b) { 101 | return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor::value || 102 | ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned::value ? "L" : "l")) || 103 | ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned::value ? 
"N" : "n"))); 104 | } 105 | }; 106 | 107 | NAMESPACE_END(detail) 108 | NAMESPACE_END(PYBIND11_NAMESPACE) 109 | -------------------------------------------------------------------------------- /pse/include/pybind11/chrono.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/chrono.h: Transparent conversion between std::chrono and python's datetime 3 | 4 | Copyright (c) 2016 Trent Houliston and 5 | Wenzel Jakob 6 | 7 | All rights reserved. Use of this source code is governed by a 8 | BSD-style license that can be found in the LICENSE file. 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "pybind11.h" 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | // Backport the PyDateTime_DELTA functions from Python3.3 if required 20 | #ifndef PyDateTime_DELTA_GET_DAYS 21 | #define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days) 22 | #endif 23 | #ifndef PyDateTime_DELTA_GET_SECONDS 24 | #define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds) 25 | #endif 26 | #ifndef PyDateTime_DELTA_GET_MICROSECONDS 27 | #define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds) 28 | #endif 29 | 30 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 31 | NAMESPACE_BEGIN(detail) 32 | 33 | template class duration_caster { 34 | public: 35 | typedef typename type::rep rep; 36 | typedef typename type::period period; 37 | 38 | typedef std::chrono::duration> days; 39 | 40 | bool load(handle src, bool) { 41 | using namespace std::chrono; 42 | 43 | // Lazy initialise the PyDateTime import 44 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 45 | 46 | if (!src) return false; 47 | // If invoked with datetime.delta object 48 | if (PyDelta_Check(src.ptr())) { 49 | value = type(duration_cast>( 50 | days(PyDateTime_DELTA_GET_DAYS(src.ptr())) 51 | + seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr())) 52 | + microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr())))); 53 | return true; 54 | } 55 | // If invoked 
with a float we assume it is seconds and convert 56 | else if (PyFloat_Check(src.ptr())) { 57 | value = type(duration_cast>(duration(PyFloat_AsDouble(src.ptr())))); 58 | return true; 59 | } 60 | else return false; 61 | } 62 | 63 | // If this is a duration just return it back 64 | static const std::chrono::duration& get_duration(const std::chrono::duration &src) { 65 | return src; 66 | } 67 | 68 | // If this is a time_point get the time_since_epoch 69 | template static std::chrono::duration get_duration(const std::chrono::time_point> &src) { 70 | return src.time_since_epoch(); 71 | } 72 | 73 | static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) { 74 | using namespace std::chrono; 75 | 76 | // Use overloaded function to get our duration from our source 77 | // Works out if it is a duration or time_point and get the duration 78 | auto d = get_duration(src); 79 | 80 | // Lazy initialise the PyDateTime import 81 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 82 | 83 | // Declare these special duration types so the conversions happen with the correct primitive types (int) 84 | using dd_t = duration>; 85 | using ss_t = duration>; 86 | using us_t = duration; 87 | 88 | auto dd = duration_cast(d); 89 | auto subd = d - dd; 90 | auto ss = duration_cast(subd); 91 | auto us = duration_cast(subd - ss); 92 | return PyDelta_FromDSU(dd.count(), ss.count(), us.count()); 93 | } 94 | 95 | PYBIND11_TYPE_CASTER(type, _("datetime.timedelta")); 96 | }; 97 | 98 | // This is for casting times on the system clock into datetime.datetime instances 99 | template class type_caster> { 100 | public: 101 | typedef std::chrono::time_point type; 102 | bool load(handle src, bool) { 103 | using namespace std::chrono; 104 | 105 | // Lazy initialise the PyDateTime import 106 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 107 | 108 | if (!src) return false; 109 | if (PyDateTime_Check(src.ptr())) { 110 | std::tm cal; 111 | cal.tm_sec = 
PyDateTime_DATE_GET_SECOND(src.ptr()); 112 | cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr()); 113 | cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr()); 114 | cal.tm_mday = PyDateTime_GET_DAY(src.ptr()); 115 | cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1; 116 | cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900; 117 | cal.tm_isdst = -1; 118 | 119 | value = system_clock::from_time_t(std::mktime(&cal)) + microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr())); 120 | return true; 121 | } 122 | else return false; 123 | } 124 | 125 | static handle cast(const std::chrono::time_point &src, return_value_policy /* policy */, handle /* parent */) { 126 | using namespace std::chrono; 127 | 128 | // Lazy initialise the PyDateTime import 129 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 130 | 131 | std::time_t tt = system_clock::to_time_t(src); 132 | // this function uses static memory so it's best to copy it out asap just in case 133 | // otherwise other code that is using localtime may break this (not just python code) 134 | std::tm localtime = *std::localtime(&tt); 135 | 136 | // Declare these special duration types so the conversions happen with the correct primitive types (int) 137 | using us_t = duration; 138 | 139 | return PyDateTime_FromDateAndTime(localtime.tm_year + 1900, 140 | localtime.tm_mon + 1, 141 | localtime.tm_mday, 142 | localtime.tm_hour, 143 | localtime.tm_min, 144 | localtime.tm_sec, 145 | (duration_cast(src.time_since_epoch() % seconds(1))).count()); 146 | } 147 | PYBIND11_TYPE_CASTER(type, _("datetime.datetime")); 148 | }; 149 | 150 | // Other clocks that are not the system clock are not measured as datetime.datetime objects 151 | // since they are not measured on calendar time. 
So instead we just make them timedeltas 152 | // Or if they have passed us a time as a float we convert that 153 | template class type_caster> 154 | : public duration_caster> { 155 | }; 156 | 157 | template class type_caster> 158 | : public duration_caster> { 159 | }; 160 | 161 | NAMESPACE_END(detail) 162 | NAMESPACE_END(PYBIND11_NAMESPACE) 163 | -------------------------------------------------------------------------------- /pse/include/pybind11/common.h: -------------------------------------------------------------------------------- 1 | #include "detail/common.h" 2 | #warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'." 3 | -------------------------------------------------------------------------------- /pse/include/pybind11/complex.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/complex.h: Complex number support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | /// glibc defines I as a macro which breaks things, e.g., boost template names 16 | #ifdef I 17 | # undef I 18 | #endif 19 | 20 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 21 | 22 | template struct format_descriptor, detail::enable_if_t::value>> { 23 | static constexpr const char c = format_descriptor::c; 24 | static constexpr const char value[3] = { 'Z', c, '\0' }; 25 | static std::string format() { return std::string(value); } 26 | }; 27 | 28 | #ifndef PYBIND11_CPP17 29 | 30 | template constexpr const char format_descriptor< 31 | std::complex, detail::enable_if_t::value>>::value[3]; 32 | 33 | #endif 34 | 35 | NAMESPACE_BEGIN(detail) 36 | 37 | template struct is_fmt_numeric, detail::enable_if_t::value>> { 38 | static constexpr bool value = true; 39 | static constexpr int index = is_fmt_numeric::index + 3; 40 | }; 41 | 42 | template class type_caster> { 43 | public: 44 | bool load(handle src, bool convert) { 45 | if (!src) 46 | return false; 47 | if (!convert && !PyComplex_Check(src.ptr())) 48 | return false; 49 | Py_complex result = PyComplex_AsCComplex(src.ptr()); 50 | if (result.real == -1.0 && PyErr_Occurred()) { 51 | PyErr_Clear(); 52 | return false; 53 | } 54 | value = std::complex((T) result.real, (T) result.imag); 55 | return true; 56 | } 57 | 58 | static handle cast(const std::complex &src, return_value_policy /* policy */, handle /* parent */) { 59 | return PyComplex_FromDoubles((double) src.real(), (double) src.imag()); 60 | } 61 | 62 | PYBIND11_TYPE_CASTER(std::complex, _("complex")); 63 | }; 64 | NAMESPACE_END(detail) 65 | NAMESPACE_END(PYBIND11_NAMESPACE) 66 | -------------------------------------------------------------------------------- /pse/include/pybind11/descr.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/descr.h: Helper type for concatenating type signatures 3 | either at runtime (C++11) or compile time 
(C++14) 4 | 5 | Copyright (c) 2016 Wenzel Jakob 6 | 7 | All rights reserved. Use of this source code is governed by a 8 | BSD-style license that can be found in the LICENSE file. 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "common.h" 14 | 15 | NAMESPACE_BEGIN(pybind11) 16 | NAMESPACE_BEGIN(detail) 17 | 18 | /* Concatenate type signatures at compile time using C++14 */ 19 | #if defined(PYBIND11_CPP14) && !defined(_MSC_VER) 20 | #define PYBIND11_CONSTEXPR_DESCR 21 | 22 | template class descr { 23 | template friend class descr; 24 | public: 25 | constexpr descr(char const (&text) [Size1+1], const std::type_info * const (&types)[Size2+1]) 26 | : descr(text, types, 27 | make_index_sequence(), 28 | make_index_sequence()) { } 29 | 30 | constexpr const char *text() const { return m_text; } 31 | constexpr const std::type_info * const * types() const { return m_types; } 32 | 33 | template 34 | constexpr descr operator+(const descr &other) const { 35 | return concat(other, 36 | make_index_sequence(), 37 | make_index_sequence(), 38 | make_index_sequence(), 39 | make_index_sequence()); 40 | } 41 | 42 | protected: 43 | template 44 | constexpr descr( 45 | char const (&text) [Size1+1], 46 | const std::type_info * const (&types) [Size2+1], 47 | index_sequence, index_sequence) 48 | : m_text{text[Indices1]..., '\0'}, 49 | m_types{types[Indices2]..., nullptr } {} 50 | 51 | template 53 | constexpr descr 54 | concat(const descr &other, 55 | index_sequence, index_sequence, 56 | index_sequence, index_sequence) const { 57 | return descr( 58 | { m_text[Indices1]..., other.m_text[OtherIndices1]..., '\0' }, 59 | { m_types[Indices2]..., other.m_types[OtherIndices2]..., nullptr } 60 | ); 61 | } 62 | 63 | protected: 64 | char m_text[Size1 + 1]; 65 | const std::type_info * m_types[Size2 + 1]; 66 | }; 67 | 68 | template constexpr descr _(char const(&text)[Size]) { 69 | return descr(text, { nullptr }); 70 | } 71 | 72 | template struct int_to_str : int_to_str { }; 73 | template struct 
int_to_str<0, Digits...> { 74 | static constexpr auto digits = descr({ ('0' + Digits)..., '\0' }, { nullptr }); 75 | }; 76 | 77 | // Ternary description (like std::conditional) 78 | template 79 | constexpr enable_if_t> _(char const(&text1)[Size1], char const(&)[Size2]) { 80 | return _(text1); 81 | } 82 | template 83 | constexpr enable_if_t> _(char const(&)[Size1], char const(&text2)[Size2]) { 84 | return _(text2); 85 | } 86 | template 87 | constexpr enable_if_t> _(descr d, descr) { return d; } 88 | template 89 | constexpr enable_if_t> _(descr, descr d) { return d; } 90 | 91 | template auto constexpr _() -> decltype(int_to_str::digits) { 92 | return int_to_str::digits; 93 | } 94 | 95 | template constexpr descr<1, 1> _() { 96 | return descr<1, 1>({ '%', '\0' }, { &typeid(Type), nullptr }); 97 | } 98 | 99 | inline constexpr descr<0, 0> concat() { return _(""); } 100 | template auto constexpr concat(descr descr) { return descr; } 101 | template auto constexpr concat(descr descr, Args&&... args) { return descr + _(", ") + concat(args...); } 102 | template auto constexpr type_descr(descr descr) { return _("{") + descr + _("}"); } 103 | 104 | #define PYBIND11_DESCR constexpr auto 105 | 106 | #else /* Simpler C++11 implementation based on run-time memory allocation and copying */ 107 | 108 | class descr { 109 | public: 110 | PYBIND11_NOINLINE descr(const char *text, const std::type_info * const * types) { 111 | size_t nChars = len(text), nTypes = len(types); 112 | m_text = new char[nChars]; 113 | m_types = new const std::type_info *[nTypes]; 114 | memcpy(m_text, text, nChars * sizeof(char)); 115 | memcpy(m_types, types, nTypes * sizeof(const std::type_info *)); 116 | } 117 | 118 | PYBIND11_NOINLINE descr operator+(descr &&d2) && { 119 | descr r; 120 | 121 | size_t nChars1 = len(m_text), nTypes1 = len(m_types); 122 | size_t nChars2 = len(d2.m_text), nTypes2 = len(d2.m_types); 123 | 124 | r.m_text = new char[nChars1 + nChars2 - 1]; 125 | r.m_types = new const std::type_info 
*[nTypes1 + nTypes2 - 1]; 126 | memcpy(r.m_text, m_text, (nChars1-1) * sizeof(char)); 127 | memcpy(r.m_text + nChars1 - 1, d2.m_text, nChars2 * sizeof(char)); 128 | memcpy(r.m_types, m_types, (nTypes1-1) * sizeof(std::type_info *)); 129 | memcpy(r.m_types + nTypes1 - 1, d2.m_types, nTypes2 * sizeof(std::type_info *)); 130 | 131 | delete[] m_text; delete[] m_types; 132 | delete[] d2.m_text; delete[] d2.m_types; 133 | 134 | return r; 135 | } 136 | 137 | char *text() { return m_text; } 138 | const std::type_info * * types() { return m_types; } 139 | 140 | protected: 141 | PYBIND11_NOINLINE descr() { } 142 | 143 | template static size_t len(const T *ptr) { // return length including null termination 144 | const T *it = ptr; 145 | while (*it++ != (T) 0) 146 | ; 147 | return static_cast(it - ptr); 148 | } 149 | 150 | const std::type_info **m_types = nullptr; 151 | char *m_text = nullptr; 152 | }; 153 | 154 | /* The 'PYBIND11_NOINLINE inline' combinations below are intentional to get the desired linkage while producing as little object code as possible */ 155 | 156 | PYBIND11_NOINLINE inline descr _(const char *text) { 157 | const std::type_info *types[1] = { nullptr }; 158 | return descr(text, types); 159 | } 160 | 161 | template PYBIND11_NOINLINE enable_if_t _(const char *text1, const char *) { return _(text1); } 162 | template PYBIND11_NOINLINE enable_if_t _(char const *, const char *text2) { return _(text2); } 163 | template PYBIND11_NOINLINE enable_if_t _(descr d, descr) { return d; } 164 | template PYBIND11_NOINLINE enable_if_t _(descr, descr d) { return d; } 165 | 166 | template PYBIND11_NOINLINE descr _() { 167 | const std::type_info *types[2] = { &typeid(Type), nullptr }; 168 | return descr("%", types); 169 | } 170 | 171 | template PYBIND11_NOINLINE descr _() { 172 | const std::type_info *types[1] = { nullptr }; 173 | return descr(std::to_string(Size).c_str(), types); 174 | } 175 | 176 | PYBIND11_NOINLINE inline descr concat() { return _(""); } 177 | 
PYBIND11_NOINLINE inline descr concat(descr &&d) { return d; } 178 | template PYBIND11_NOINLINE descr concat(descr &&d, Args&&... args) { return std::move(d) + _(", ") + concat(std::forward(args)...); } 179 | PYBIND11_NOINLINE inline descr type_descr(descr&& d) { return _("{") + std::move(d) + _("}"); } 180 | 181 | #define PYBIND11_DESCR ::pybind11::detail::descr 182 | #endif 183 | 184 | NAMESPACE_END(detail) 185 | NAMESPACE_END(pybind11) 186 | -------------------------------------------------------------------------------- /pse/include/pybind11/detail/descr.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | NAMESPACE_BEGIN(detail) 16 | 17 | #if !defined(_MSC_VER) 18 | # define PYBIND11_DESCR_CONSTEXPR static constexpr 19 | #else 20 | # define PYBIND11_DESCR_CONSTEXPR const 21 | #endif 22 | 23 | /* Concatenate type signatures at compile time */ 24 | template 25 | struct descr { 26 | char text[N + 1]; 27 | 28 | constexpr descr() : text{'\0'} { } 29 | constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence()) { } 30 | 31 | template 32 | constexpr descr(char const (&s)[N+1], index_sequence) : text{s[Is]..., '\0'} { } 33 | 34 | template 35 | constexpr descr(char c, Chars... 
cs) : text{c, static_cast(cs)..., '\0'} { } 36 | 37 | static constexpr std::array types() { 38 | return {{&typeid(Ts)..., nullptr}}; 39 | } 40 | }; 41 | 42 | template 43 | constexpr descr plus_impl(const descr &a, const descr &b, 44 | index_sequence, index_sequence) { 45 | return {a.text[Is1]..., b.text[Is2]...}; 46 | } 47 | 48 | template 49 | constexpr descr operator+(const descr &a, const descr &b) { 50 | return plus_impl(a, b, make_index_sequence(), make_index_sequence()); 51 | } 52 | 53 | template 54 | constexpr descr _(char const(&text)[N]) { return descr(text); } 55 | constexpr descr<0> _(char const(&)[1]) { return {}; } 56 | 57 | template struct int_to_str : int_to_str { }; 58 | template struct int_to_str<0, Digits...> { 59 | static constexpr auto digits = descr(('0' + Digits)...); 60 | }; 61 | 62 | // Ternary description (like std::conditional) 63 | template 64 | constexpr enable_if_t> _(char const(&text1)[N1], char const(&)[N2]) { 65 | return _(text1); 66 | } 67 | template 68 | constexpr enable_if_t> _(char const(&)[N1], char const(&text2)[N2]) { 69 | return _(text2); 70 | } 71 | 72 | template 73 | constexpr enable_if_t _(const T1 &d, const T2 &) { return d; } 74 | template 75 | constexpr enable_if_t _(const T1 &, const T2 &d) { return d; } 76 | 77 | template auto constexpr _() -> decltype(int_to_str::digits) { 78 | return int_to_str::digits; 79 | } 80 | 81 | template constexpr descr<1, Type> _() { return {'%'}; } 82 | 83 | constexpr descr<0> concat() { return {}; } 84 | 85 | template 86 | constexpr descr concat(const descr &descr) { return descr; } 87 | 88 | template 89 | constexpr auto concat(const descr &d, const Args &...args) 90 | -> decltype(std::declval>() + concat(args...)) { 91 | return d + _(", ") + concat(args...); 92 | } 93 | 94 | template 95 | constexpr descr type_descr(const descr &descr) { 96 | return _("{") + descr + _("}"); 97 | } 98 | 99 | NAMESPACE_END(detail) 100 | NAMESPACE_END(PYBIND11_NAMESPACE) 101 | 
-------------------------------------------------------------------------------- /pse/include/pybind11/detail/internals.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/detail/internals.h: Internal data structure and related functions 3 | 4 | Copyright (c) 2017 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "../pytypes.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | NAMESPACE_BEGIN(detail) 16 | // Forward declarations 17 | inline PyTypeObject *make_static_property_type(); 18 | inline PyTypeObject *make_default_metaclass(); 19 | inline PyObject *make_object_base_type(PyTypeObject *metaclass); 20 | 21 | // The old Python Thread Local Storage (TLS) API is deprecated in Python 3.7 in favor of the new 22 | // Thread Specific Storage (TSS) API. 23 | #if PY_VERSION_HEX >= 0x03070000 24 | # define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr 25 | # define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key)) 26 | # define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (tstate)) 27 | # define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr) 28 | #else 29 | // Usually an int but a long on Cygwin64 with Python 3.x 30 | # define PYBIND11_TLS_KEY_INIT(var) decltype(PyThread_create_key()) var = 0 31 | # define PYBIND11_TLS_GET_VALUE(key) PyThread_get_key_value((key)) 32 | # if PY_MAJOR_VERSION < 3 33 | # define PYBIND11_TLS_DELETE_VALUE(key) \ 34 | PyThread_delete_key_value(key) 35 | # define PYBIND11_TLS_REPLACE_VALUE(key, value) \ 36 | do { \ 37 | PyThread_delete_key_value((key)); \ 38 | PyThread_set_key_value((key), (value)); \ 39 | } while (false) 40 | # else 41 | # define PYBIND11_TLS_DELETE_VALUE(key) \ 42 | PyThread_set_key_value((key), nullptr) 43 | # define PYBIND11_TLS_REPLACE_VALUE(key, value) \ 44 | PyThread_set_key_value((key), 
(value)) 45 | # endif 46 | #endif 47 | 48 | // Python loads modules by default with dlopen with the RTLD_LOCAL flag; under libc++ and possibly 49 | // other STLs, this means `typeid(A)` from one module won't equal `typeid(A)` from another module 50 | // even when `A` is the same, non-hidden-visibility type (e.g. from a common include). Under 51 | // libstdc++, this doesn't happen: equality and the type_index hash are based on the type name, 52 | // which works. If not under a known-good stl, provide our own name-based hash and equality 53 | // functions that use the type name. 54 | #if defined(__GLIBCXX__) 55 | inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { return lhs == rhs; } 56 | using type_hash = std::hash; 57 | using type_equal_to = std::equal_to; 58 | #else 59 | inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { 60 | return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0; 61 | } 62 | 63 | struct type_hash { 64 | size_t operator()(const std::type_index &t) const { 65 | size_t hash = 5381; 66 | const char *ptr = t.name(); 67 | while (auto c = static_cast(*ptr++)) 68 | hash = (hash * 33) ^ c; 69 | return hash; 70 | } 71 | }; 72 | 73 | struct type_equal_to { 74 | bool operator()(const std::type_index &lhs, const std::type_index &rhs) const { 75 | return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0; 76 | } 77 | }; 78 | #endif 79 | 80 | template 81 | using type_map = std::unordered_map; 82 | 83 | struct overload_hash { 84 | inline size_t operator()(const std::pair& v) const { 85 | size_t value = std::hash()(v.first); 86 | value ^= std::hash()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2); 87 | return value; 88 | } 89 | }; 90 | 91 | /// Internal data structure used to track registered instances and types. 92 | /// Whenever binary incompatible changes are made to this structure, 93 | /// `PYBIND11_INTERNALS_VERSION` must be incremented. 
94 | struct internals { 95 | type_map registered_types_cpp; // std::type_index -> pybind11's type information 96 | std::unordered_map> registered_types_py; // PyTypeObject* -> base type_info(s) 97 | std::unordered_multimap registered_instances; // void * -> instance* 98 | std::unordered_set, overload_hash> inactive_overload_cache; 99 | type_map> direct_conversions; 100 | std::unordered_map> patients; 101 | std::forward_list registered_exception_translators; 102 | std::unordered_map shared_data; // Custom data to be shared across extensions 103 | std::vector loader_patient_stack; // Used by `loader_life_support` 104 | std::forward_list static_strings; // Stores the std::strings backing detail::c_str() 105 | PyTypeObject *static_property_type; 106 | PyTypeObject *default_metaclass; 107 | PyObject *instance_base; 108 | #if defined(WITH_THREAD) 109 | PYBIND11_TLS_KEY_INIT(tstate); 110 | PyInterpreterState *istate = nullptr; 111 | #endif 112 | }; 113 | 114 | /// Additional type information which does not fit into the PyTypeObject. 115 | /// Changes to this struct also require bumping `PYBIND11_INTERNALS_VERSION`. 
116 | struct type_info { 117 | PyTypeObject *type; 118 | const std::type_info *cpptype; 119 | size_t type_size, type_align, holder_size_in_ptrs; 120 | void *(*operator_new)(size_t); 121 | void (*init_instance)(instance *, const void *); 122 | void (*dealloc)(value_and_holder &v_h); 123 | std::vector implicit_conversions; 124 | std::vector> implicit_casts; 125 | std::vector *direct_conversions; 126 | buffer_info *(*get_buffer)(PyObject *, void *) = nullptr; 127 | void *get_buffer_data = nullptr; 128 | void *(*module_local_load)(PyObject *, const type_info *) = nullptr; 129 | /* A simple type never occurs as a (direct or indirect) parent 130 | * of a class that makes use of multiple inheritance */ 131 | bool simple_type : 1; 132 | /* True if there is no multiple inheritance in this type's inheritance tree */ 133 | bool simple_ancestors : 1; 134 | /* for base vs derived holder_type checks */ 135 | bool default_holder : 1; 136 | /* true if this is a type registered with py::module_local */ 137 | bool module_local : 1; 138 | }; 139 | 140 | /// Tracks the `internals` and `type_info` ABI version independent of the main library version 141 | #define PYBIND11_INTERNALS_VERSION 3 142 | 143 | #if defined(_DEBUG) 144 | # define PYBIND11_BUILD_TYPE "_debug" 145 | #else 146 | # define PYBIND11_BUILD_TYPE "" 147 | #endif 148 | 149 | #if defined(WITH_THREAD) 150 | # define PYBIND11_INTERNALS_KIND "" 151 | #else 152 | # define PYBIND11_INTERNALS_KIND "_without_thread" 153 | #endif 154 | 155 | #define PYBIND11_INTERNALS_ID "__pybind11_internals_v" \ 156 | PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__" 157 | 158 | #define PYBIND11_MODULE_LOCAL_ID "__pybind11_module_local_v" \ 159 | PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__" 160 | 161 | /// Each module locally stores a pointer to the `internals` data. 
The data 162 | /// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`. 163 | inline internals **&get_internals_pp() { 164 | static internals **internals_pp = nullptr; 165 | return internals_pp; 166 | } 167 | 168 | /// Return a reference to the current `internals` data 169 | PYBIND11_NOINLINE inline internals &get_internals() { 170 | auto **&internals_pp = get_internals_pp(); 171 | if (internals_pp && *internals_pp) 172 | return **internals_pp; 173 | 174 | constexpr auto *id = PYBIND11_INTERNALS_ID; 175 | auto builtins = handle(PyEval_GetBuiltins()); 176 | if (builtins.contains(id) && isinstance(builtins[id])) { 177 | internals_pp = static_cast(capsule(builtins[id])); 178 | 179 | // We loaded builtins through python's builtins, which means that our `error_already_set` 180 | // and `builtin_exception` may be different local classes than the ones set up in the 181 | // initial exception translator, below, so add another for our local exception classes. 182 | // 183 | // libstdc++ doesn't require this (types there are identified only by name) 184 | #if !defined(__GLIBCXX__) 185 | (*internals_pp)->registered_exception_translators.push_front( 186 | [](std::exception_ptr p) -> void { 187 | try { 188 | if (p) std::rethrow_exception(p); 189 | } catch (error_already_set &e) { e.restore(); return; 190 | } catch (const builtin_exception &e) { e.set_error(); return; 191 | } 192 | } 193 | ); 194 | #endif 195 | } else { 196 | if (!internals_pp) internals_pp = new internals*(); 197 | auto *&internals_ptr = *internals_pp; 198 | internals_ptr = new internals(); 199 | #if defined(WITH_THREAD) 200 | PyEval_InitThreads(); 201 | PyThreadState *tstate = PyThreadState_Get(); 202 | #if PY_VERSION_HEX >= 0x03070000 203 | internals_ptr->tstate = PyThread_tss_alloc(); 204 | if (!internals_ptr->tstate || PyThread_tss_create(internals_ptr->tstate)) 205 | pybind11_fail("get_internals: could not successfully initialize the TSS key!"); 206 | 
PyThread_tss_set(internals_ptr->tstate, tstate); 207 | #else 208 | internals_ptr->tstate = PyThread_create_key(); 209 | if (internals_ptr->tstate == -1) 210 | pybind11_fail("get_internals: could not successfully initialize the TLS key!"); 211 | PyThread_set_key_value(internals_ptr->tstate, tstate); 212 | #endif 213 | internals_ptr->istate = tstate->interp; 214 | #endif 215 | builtins[id] = capsule(internals_pp); 216 | internals_ptr->registered_exception_translators.push_front( 217 | [](std::exception_ptr p) -> void { 218 | try { 219 | if (p) std::rethrow_exception(p); 220 | } catch (error_already_set &e) { e.restore(); return; 221 | } catch (const builtin_exception &e) { e.set_error(); return; 222 | } catch (const std::bad_alloc &e) { PyErr_SetString(PyExc_MemoryError, e.what()); return; 223 | } catch (const std::domain_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; 224 | } catch (const std::invalid_argument &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; 225 | } catch (const std::length_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; 226 | } catch (const std::out_of_range &e) { PyErr_SetString(PyExc_IndexError, e.what()); return; 227 | } catch (const std::range_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; 228 | } catch (const std::exception &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); return; 229 | } catch (...) 
{ 230 | PyErr_SetString(PyExc_RuntimeError, "Caught an unknown exception!"); 231 | return; 232 | } 233 | } 234 | ); 235 | internals_ptr->static_property_type = make_static_property_type(); 236 | internals_ptr->default_metaclass = make_default_metaclass(); 237 | internals_ptr->instance_base = make_object_base_type(internals_ptr->default_metaclass); 238 | } 239 | return **internals_pp; 240 | } 241 | 242 | /// Works like `internals.registered_types_cpp`, but for module-local registered types: 243 | inline type_map ®istered_local_types_cpp() { 244 | static type_map locals{}; 245 | return locals; 246 | } 247 | 248 | /// Constructs a std::string with the given arguments, stores it in `internals`, and returns its 249 | /// `c_str()`. Such strings objects have a long storage duration -- the internal strings are only 250 | /// cleared when the program exits or after interpreter shutdown (when embedding), and so are 251 | /// suitable for c-style strings needed by Python internals (such as PyTypeObject's tp_name). 252 | template 253 | const char *c_str(Args &&...args) { 254 | auto &strings = get_internals().static_strings; 255 | strings.emplace_front(std::forward(args)...); 256 | return strings.front().c_str(); 257 | } 258 | 259 | NAMESPACE_END(detail) 260 | 261 | /// Returns a named pointer that is shared among all extension modules (using the same 262 | /// pybind11 version) running in the current interpreter. Names starting with underscores 263 | /// are reserved for internal usage. Returns `nullptr` if no matching entry was found. 264 | inline PYBIND11_NOINLINE void *get_shared_data(const std::string &name) { 265 | auto &internals = detail::get_internals(); 266 | auto it = internals.shared_data.find(name); 267 | return it != internals.shared_data.end() ? it->second : nullptr; 268 | } 269 | 270 | /// Set the shared data that can be later recovered by `get_shared_data()`. 
271 | inline PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) { 272 | detail::get_internals().shared_data[name] = data; 273 | return data; 274 | } 275 | 276 | /// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if 277 | /// such entry exists. Otherwise, a new object of default-constructible type `T` is 278 | /// added to the shared data under the given name and a reference to it is returned. 279 | template 280 | T &get_or_create_shared_data(const std::string &name) { 281 | auto &internals = detail::get_internals(); 282 | auto it = internals.shared_data.find(name); 283 | T *ptr = (T *) (it != internals.shared_data.end() ? it->second : nullptr); 284 | if (!ptr) { 285 | ptr = new T(); 286 | internals.shared_data[name] = ptr; 287 | } 288 | return *ptr; 289 | } 290 | 291 | NAMESPACE_END(PYBIND11_NAMESPACE) 292 | -------------------------------------------------------------------------------- /pse/include/pybind11/detail/typeid.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/detail/typeid.h: Compiler-independent access to type identifiers 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include 13 | #include 14 | 15 | #if defined(__GNUG__) 16 | #include 17 | #endif 18 | 19 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 20 | NAMESPACE_BEGIN(detail) 21 | /// Erase all occurrences of a substring 22 | inline void erase_all(std::string &string, const std::string &search) { 23 | for (size_t pos = 0;;) { 24 | pos = string.find(search, pos); 25 | if (pos == std::string::npos) break; 26 | string.erase(pos, search.length()); 27 | } 28 | } 29 | 30 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) { 31 | #if defined(__GNUG__) 32 | int status = 0; 33 | std::unique_ptr res { 34 | abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free }; 35 | if (status == 0) 36 | name = res.get(); 37 | #else 38 | detail::erase_all(name, "class "); 39 | detail::erase_all(name, "struct "); 40 | detail::erase_all(name, "enum "); 41 | #endif 42 | detail::erase_all(name, "pybind11::"); 43 | } 44 | NAMESPACE_END(detail) 45 | 46 | /// Return a string representation of a C++ type 47 | template static std::string type_id() { 48 | std::string name(typeid(T).name()); 49 | detail::clean_type_id(name); 50 | return name; 51 | } 52 | 53 | NAMESPACE_END(PYBIND11_NAMESPACE) 54 | -------------------------------------------------------------------------------- /pse/include/pybind11/embed.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/embed.h: Support for embedding the interpreter 3 | 4 | Copyright (c) 2017 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include "eval.h" 14 | 15 | #if defined(PYPY_VERSION) 16 | # error Embedding the interpreter is not supported with PyPy 17 | #endif 18 | 19 | #if PY_MAJOR_VERSION >= 3 20 | # define PYBIND11_EMBEDDED_MODULE_IMPL(name) \ 21 | extern "C" PyObject *pybind11_init_impl_##name() { \ 22 | return pybind11_init_wrapper_##name(); \ 23 | } 24 | #else 25 | # define PYBIND11_EMBEDDED_MODULE_IMPL(name) \ 26 | extern "C" void pybind11_init_impl_##name() { \ 27 | pybind11_init_wrapper_##name(); \ 28 | } 29 | #endif 30 | 31 | /** \rst 32 | Add a new module to the table of builtins for the interpreter. Must be 33 | defined in global scope. The first macro parameter is the name of the 34 | module (without quotes). The second parameter is the variable which will 35 | be used as the interface to add functions and classes to the module. 36 | 37 | .. code-block:: cpp 38 | 39 | PYBIND11_EMBEDDED_MODULE(example, m) { 40 | // ... initialize functions and classes here 41 | m.def("foo", []() { 42 | return "Hello, World!"; 43 | }); 44 | } 45 | \endrst */ 46 | #define PYBIND11_EMBEDDED_MODULE(name, variable) \ 47 | static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \ 48 | static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \ 49 | auto m = pybind11::module(PYBIND11_TOSTRING(name)); \ 50 | try { \ 51 | PYBIND11_CONCAT(pybind11_init_, name)(m); \ 52 | return m.ptr(); \ 53 | } catch (pybind11::error_already_set &e) { \ 54 | PyErr_SetString(PyExc_ImportError, e.what()); \ 55 | return nullptr; \ 56 | } catch (const std::exception &e) { \ 57 | PyErr_SetString(PyExc_ImportError, e.what()); \ 58 | return nullptr; \ 59 | } \ 60 | } \ 61 | PYBIND11_EMBEDDED_MODULE_IMPL(name) \ 62 | pybind11::detail::embedded_module name(PYBIND11_TOSTRING(name), \ 63 | PYBIND11_CONCAT(pybind11_init_impl_, name)); \ 64 | void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable) 65 | 66 | 67 | 
NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 68 | NAMESPACE_BEGIN(detail) 69 | 70 | /// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks. 71 | struct embedded_module { 72 | #if PY_MAJOR_VERSION >= 3 73 | using init_t = PyObject *(*)(); 74 | #else 75 | using init_t = void (*)(); 76 | #endif 77 | embedded_module(const char *name, init_t init) { 78 | if (Py_IsInitialized()) 79 | pybind11_fail("Can't add new modules after the interpreter has been initialized"); 80 | 81 | auto result = PyImport_AppendInittab(name, init); 82 | if (result == -1) 83 | pybind11_fail("Insufficient memory to add a new module"); 84 | } 85 | }; 86 | 87 | NAMESPACE_END(detail) 88 | 89 | /** \rst 90 | Initialize the Python interpreter. No other pybind11 or CPython API functions can be 91 | called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. The 92 | optional parameter can be used to skip the registration of signal handlers (see the 93 | `Python documentation`_ for details). Calling this function again after the interpreter 94 | has already been initialized is a fatal error. 95 | 96 | If initializing the Python interpreter fails, then the program is terminated. (This 97 | is controlled by the CPython runtime and is an exception to pybind11's normal behavior 98 | of throwing exceptions on errors.) 99 | 100 | .. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx 101 | \endrst */ 102 | inline void initialize_interpreter(bool init_signal_handlers = true) { 103 | if (Py_IsInitialized()) 104 | pybind11_fail("The interpreter is already running"); 105 | 106 | Py_InitializeEx(init_signal_handlers ? 1 : 0); 107 | 108 | // Make .py files in the working directory available by default 109 | module::import("sys").attr("path").cast().append("."); 110 | } 111 | 112 | /** \rst 113 | Shut down the Python interpreter. No pybind11 or CPython API functions can be called 114 | after this. 
In addition, pybind11 objects must not outlive the interpreter: 115 | 116 | .. code-block:: cpp 117 | 118 | { // BAD 119 | py::initialize_interpreter(); 120 | auto hello = py::str("Hello, World!"); 121 | py::finalize_interpreter(); 122 | } // <-- BOOM, hello's destructor is called after interpreter shutdown 123 | 124 | { // GOOD 125 | py::initialize_interpreter(); 126 | { // scoped 127 | auto hello = py::str("Hello, World!"); 128 | } // <-- OK, hello is cleaned up properly 129 | py::finalize_interpreter(); 130 | } 131 | 132 | { // BETTER 133 | py::scoped_interpreter guard{}; 134 | auto hello = py::str("Hello, World!"); 135 | } 136 | 137 | .. warning:: 138 | 139 | The interpreter can be restarted by calling `initialize_interpreter` again. 140 | Modules created using pybind11 can be safely re-initialized. However, Python 141 | itself cannot completely unload binary extension modules and there are several 142 | caveats with regard to interpreter restarting. All the details can be found 143 | in the CPython documentation. In short, not all interpreter memory may be 144 | freed, either due to reference cycles or user-created global data. 145 | 146 | \endrst */ 147 | inline void finalize_interpreter() { 148 | handle builtins(PyEval_GetBuiltins()); 149 | const char *id = PYBIND11_INTERNALS_ID; 150 | 151 | // Get the internals pointer (without creating it if it doesn't exist). It's possible for the 152 | // internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()` 153 | // during destruction), so we get the pointer-pointer here and check it after Py_Finalize(). 
154 | detail::internals **internals_ptr_ptr = detail::get_internals_pp(); 155 | // It could also be stashed in builtins, so look there too: 156 | if (builtins.contains(id) && isinstance(builtins[id])) 157 | internals_ptr_ptr = capsule(builtins[id]); 158 | 159 | Py_Finalize(); 160 | 161 | if (internals_ptr_ptr) { 162 | delete *internals_ptr_ptr; 163 | *internals_ptr_ptr = nullptr; 164 | } 165 | } 166 | 167 | /** \rst 168 | Scope guard version of `initialize_interpreter` and `finalize_interpreter`. 169 | This a move-only guard and only a single instance can exist. 170 | 171 | .. code-block:: cpp 172 | 173 | #include 174 | 175 | int main() { 176 | py::scoped_interpreter guard{}; 177 | py::print(Hello, World!); 178 | } // <-- interpreter shutdown 179 | \endrst */ 180 | class scoped_interpreter { 181 | public: 182 | scoped_interpreter(bool init_signal_handlers = true) { 183 | initialize_interpreter(init_signal_handlers); 184 | } 185 | 186 | scoped_interpreter(const scoped_interpreter &) = delete; 187 | scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; } 188 | scoped_interpreter &operator=(const scoped_interpreter &) = delete; 189 | scoped_interpreter &operator=(scoped_interpreter &&) = delete; 190 | 191 | ~scoped_interpreter() { 192 | if (is_valid) 193 | finalize_interpreter(); 194 | } 195 | 196 | private: 197 | bool is_valid = true; 198 | }; 199 | 200 | NAMESPACE_END(PYBIND11_NAMESPACE) 201 | -------------------------------------------------------------------------------- /pse/include/pybind11/eval.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/exec.h: Support for evaluating Python expressions and statements 3 | from strings and files 4 | 5 | Copyright (c) 2016 Klemens Morgenstern and 6 | Wenzel Jakob 7 | 8 | All rights reserved. Use of this source code is governed by a 9 | BSD-style license that can be found in the LICENSE file. 
10 | */ 11 | 12 | #pragma once 13 | 14 | #include "pybind11.h" 15 | 16 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 17 | 18 | enum eval_mode { 19 | /// Evaluate a string containing an isolated expression 20 | eval_expr, 21 | 22 | /// Evaluate a string containing a single statement. Returns \c none 23 | eval_single_statement, 24 | 25 | /// Evaluate a string containing a sequence of statement. Returns \c none 26 | eval_statements 27 | }; 28 | 29 | template 30 | object eval(str expr, object global = globals(), object local = object()) { 31 | if (!local) 32 | local = global; 33 | 34 | /* PyRun_String does not accept a PyObject / encoding specifier, 35 | this seems to be the only alternative */ 36 | std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr; 37 | 38 | int start; 39 | switch (mode) { 40 | case eval_expr: start = Py_eval_input; break; 41 | case eval_single_statement: start = Py_single_input; break; 42 | case eval_statements: start = Py_file_input; break; 43 | default: pybind11_fail("invalid evaluation mode"); 44 | } 45 | 46 | PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr()); 47 | if (!result) 48 | throw error_already_set(); 49 | return reinterpret_steal(result); 50 | } 51 | 52 | template 53 | object eval(const char (&s)[N], object global = globals(), object local = object()) { 54 | /* Support raw string literals by removing common leading whitespace */ 55 | auto expr = (s[0] == '\n') ? 
str(module::import("textwrap").attr("dedent")(s)) 56 | : str(s); 57 | return eval(expr, global, local); 58 | } 59 | 60 | inline void exec(str expr, object global = globals(), object local = object()) { 61 | eval(expr, global, local); 62 | } 63 | 64 | template 65 | void exec(const char (&s)[N], object global = globals(), object local = object()) { 66 | eval(s, global, local); 67 | } 68 | 69 | template 70 | object eval_file(str fname, object global = globals(), object local = object()) { 71 | if (!local) 72 | local = global; 73 | 74 | int start; 75 | switch (mode) { 76 | case eval_expr: start = Py_eval_input; break; 77 | case eval_single_statement: start = Py_single_input; break; 78 | case eval_statements: start = Py_file_input; break; 79 | default: pybind11_fail("invalid evaluation mode"); 80 | } 81 | 82 | int closeFile = 1; 83 | std::string fname_str = (std::string) fname; 84 | #if PY_VERSION_HEX >= 0x03040000 85 | FILE *f = _Py_fopen_obj(fname.ptr(), "r"); 86 | #elif PY_VERSION_HEX >= 0x03000000 87 | FILE *f = _Py_fopen(fname.ptr(), "r"); 88 | #else 89 | /* No unicode support in open() :( */ 90 | auto fobj = reinterpret_steal(PyFile_FromString( 91 | const_cast(fname_str.c_str()), 92 | const_cast("r"))); 93 | FILE *f = nullptr; 94 | if (fobj) 95 | f = PyFile_AsFile(fobj.ptr()); 96 | closeFile = 0; 97 | #endif 98 | if (!f) { 99 | PyErr_Clear(); 100 | pybind11_fail("File \"" + fname_str + "\" could not be opened!"); 101 | } 102 | 103 | #if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION) 104 | PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(), 105 | local.ptr()); 106 | (void) closeFile; 107 | #else 108 | PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(), 109 | local.ptr(), closeFile); 110 | #endif 111 | 112 | if (!result) 113 | throw error_already_set(); 114 | return reinterpret_steal(result); 115 | } 116 | 117 | NAMESPACE_END(PYBIND11_NAMESPACE) 118 | 
-------------------------------------------------------------------------------- /pse/include/pybind11/functional.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/functional.h: std::function<> support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 16 | NAMESPACE_BEGIN(detail) 17 | 18 | template 19 | struct type_caster> { 20 | using type = std::function; 21 | using retval_type = conditional_t::value, void_type, Return>; 22 | using function_type = Return (*) (Args...); 23 | 24 | public: 25 | bool load(handle src, bool convert) { 26 | if (src.is_none()) { 27 | // Defer accepting None to other overloads (if we aren't in convert mode): 28 | if (!convert) return false; 29 | return true; 30 | } 31 | 32 | if (!isinstance(src)) 33 | return false; 34 | 35 | auto func = reinterpret_borrow(src); 36 | 37 | /* 38 | When passing a C++ function as an argument to another C++ 39 | function via Python, every function call would normally involve 40 | a full C++ -> Python -> C++ roundtrip, which can be prohibitive. 41 | Here, we try to at least detect the case where the function is 42 | stateless (i.e. function pointer or lambda function without 43 | captured variables), in which case the roundtrip can be avoided. 44 | */ 45 | if (auto cfunc = func.cpp_function()) { 46 | auto c = reinterpret_borrow(PyCFunction_GET_SELF(cfunc.ptr())); 47 | auto rec = (function_record *) c; 48 | 49 | if (rec && rec->is_stateless && 50 | same_type(typeid(function_type), *reinterpret_cast(rec->data[1]))) { 51 | struct capture { function_type f; }; 52 | value = ((capture *) &rec->data)->f; 53 | return true; 54 | } 55 | } 56 | 57 | value = [func](Args... 
args) -> Return { 58 | gil_scoped_acquire acq; 59 | object retval(func(std::forward(args)...)); 60 | /* Visual studio 2015 parser issue: need parentheses around this expression */ 61 | return (retval.template cast()); 62 | }; 63 | return true; 64 | } 65 | 66 | template 67 | static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) { 68 | if (!f_) 69 | return none().inc_ref(); 70 | 71 | auto result = f_.template target(); 72 | if (result) 73 | return cpp_function(*result, policy).release(); 74 | else 75 | return cpp_function(std::forward(f_), policy).release(); 76 | } 77 | 78 | PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster::name...) + _("], ") 79 | + make_caster::name + _("]")); 80 | }; 81 | 82 | NAMESPACE_END(detail) 83 | NAMESPACE_END(PYBIND11_NAMESPACE) 84 | -------------------------------------------------------------------------------- /pse/include/pybind11/iostream.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/iostream.h -- Tools to assist with redirecting cout and cerr to Python 3 | 4 | Copyright (c) 2017 Henry F. Schreiner 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 21 | NAMESPACE_BEGIN(detail) 22 | 23 | // Buffer that writes to Python instead of C++ 24 | class pythonbuf : public std::streambuf { 25 | private: 26 | using traits_type = std::streambuf::traits_type; 27 | 28 | char d_buffer[1024]; 29 | object pywrite; 30 | object pyflush; 31 | 32 | int overflow(int c) { 33 | if (!traits_type::eq_int_type(c, traits_type::eof())) { 34 | *pptr() = traits_type::to_char_type(c); 35 | pbump(1); 36 | } 37 | return sync() == 0 ? 
traits_type::not_eof(c) : traits_type::eof(); 38 | } 39 | 40 | int sync() { 41 | if (pbase() != pptr()) { 42 | // This subtraction cannot be negative, so dropping the sign 43 | str line(pbase(), static_cast(pptr() - pbase())); 44 | 45 | pywrite(line); 46 | pyflush(); 47 | 48 | setp(pbase(), epptr()); 49 | } 50 | return 0; 51 | } 52 | 53 | public: 54 | pythonbuf(object pyostream) 55 | : pywrite(pyostream.attr("write")), 56 | pyflush(pyostream.attr("flush")) { 57 | setp(d_buffer, d_buffer + sizeof(d_buffer) - 1); 58 | } 59 | 60 | /// Sync before destroy 61 | ~pythonbuf() { 62 | sync(); 63 | } 64 | }; 65 | 66 | NAMESPACE_END(detail) 67 | 68 | 69 | /** \rst 70 | This a move-only guard that redirects output. 71 | 72 | .. code-block:: cpp 73 | 74 | #include 75 | 76 | ... 77 | 78 | { 79 | py::scoped_ostream_redirect output; 80 | std::cout << "Hello, World!"; // Python stdout 81 | } // <-- return std::cout to normal 82 | 83 | You can explicitly pass the c++ stream and the python object, 84 | for example to guard stderr instead. 85 | 86 | .. 
code-block:: cpp 87 | 88 | { 89 | py::scoped_ostream_redirect output{std::cerr, py::module::import("sys").attr("stderr")}; 90 | std::cerr << "Hello, World!"; 91 | } 92 | \endrst */ 93 | class scoped_ostream_redirect { 94 | protected: 95 | std::streambuf *old; 96 | std::ostream &costream; 97 | detail::pythonbuf buffer; 98 | 99 | public: 100 | scoped_ostream_redirect( 101 | std::ostream &costream = std::cout, 102 | object pyostream = module::import("sys").attr("stdout")) 103 | : costream(costream), buffer(pyostream) { 104 | old = costream.rdbuf(&buffer); 105 | } 106 | 107 | ~scoped_ostream_redirect() { 108 | costream.rdbuf(old); 109 | } 110 | 111 | scoped_ostream_redirect(const scoped_ostream_redirect &) = delete; 112 | scoped_ostream_redirect(scoped_ostream_redirect &&other) = default; 113 | scoped_ostream_redirect &operator=(const scoped_ostream_redirect &) = delete; 114 | scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete; 115 | }; 116 | 117 | 118 | /** \rst 119 | Like `scoped_ostream_redirect`, but redirects cerr by default. This class 120 | is provided primary to make ``py::call_guard`` easier to make. 121 | 122 | .. code-block:: cpp 123 | 124 | m.def("noisy_func", &noisy_func, 125 | py::call_guard()); 127 | 128 | \endrst */ 129 | class scoped_estream_redirect : public scoped_ostream_redirect { 130 | public: 131 | scoped_estream_redirect( 132 | std::ostream &costream = std::cerr, 133 | object pyostream = module::import("sys").attr("stderr")) 134 | : scoped_ostream_redirect(costream,pyostream) {} 135 | }; 136 | 137 | 138 | NAMESPACE_BEGIN(detail) 139 | 140 | // Class to redirect output as a context manager. C++ backend. 
141 | class OstreamRedirect { 142 | bool do_stdout_; 143 | bool do_stderr_; 144 | std::unique_ptr redirect_stdout; 145 | std::unique_ptr redirect_stderr; 146 | 147 | public: 148 | OstreamRedirect(bool do_stdout = true, bool do_stderr = true) 149 | : do_stdout_(do_stdout), do_stderr_(do_stderr) {} 150 | 151 | void enter() { 152 | if (do_stdout_) 153 | redirect_stdout.reset(new scoped_ostream_redirect()); 154 | if (do_stderr_) 155 | redirect_stderr.reset(new scoped_estream_redirect()); 156 | } 157 | 158 | void exit() { 159 | redirect_stdout.reset(); 160 | redirect_stderr.reset(); 161 | } 162 | }; 163 | 164 | NAMESPACE_END(detail) 165 | 166 | /** \rst 167 | This is a helper function to add a C++ redirect context manager to Python 168 | instead of using a C++ guard. To use it, add the following to your binding code: 169 | 170 | .. code-block:: cpp 171 | 172 | #include 173 | 174 | ... 175 | 176 | py::add_ostream_redirect(m, "ostream_redirect"); 177 | 178 | You now have a Python context manager that redirects your output: 179 | 180 | .. code-block:: python 181 | 182 | with m.ostream_redirect(): 183 | m.print_to_cout_function() 184 | 185 | This manager can optionally be told which streams to operate on: 186 | 187 | .. 
code-block:: python 188 | 189 | with m.ostream_redirect(stdout=true, stderr=true): 190 | m.noisy_function_with_error_printing() 191 | 192 | \endrst */ 193 | inline class_ add_ostream_redirect(module m, std::string name = "ostream_redirect") { 194 | return class_(m, name.c_str(), module_local()) 195 | .def(init(), arg("stdout")=true, arg("stderr")=true) 196 | .def("__enter__", &detail::OstreamRedirect::enter) 197 | .def("__exit__", [](detail::OstreamRedirect &self_, args) { self_.exit(); }); 198 | } 199 | 200 | NAMESPACE_END(PYBIND11_NAMESPACE) 201 | -------------------------------------------------------------------------------- /pse/include/pybind11/operators.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/operator.h: Metatemplates for operator overloading 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | 14 | #if defined(__clang__) && !defined(__INTEL_COMPILER) 15 | # pragma clang diagnostic ignored "-Wunsequenced" // multiple unsequenced modifications to 'self' (when using def(py::self OP Type())) 16 | #elif defined(_MSC_VER) 17 | # pragma warning(push) 18 | # pragma warning(disable: 4127) // warning C4127: Conditional expression is constant 19 | #endif 20 | 21 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 22 | NAMESPACE_BEGIN(detail) 23 | 24 | /// Enumeration with all supported operator types 25 | enum op_id : int { 26 | op_add, op_sub, op_mul, op_div, op_mod, op_divmod, op_pow, op_lshift, 27 | op_rshift, op_and, op_xor, op_or, op_neg, op_pos, op_abs, op_invert, 28 | op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le, 29 | op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift, 30 | op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero, 31 | op_repr, op_truediv, op_itruediv, 
op_hash 32 | }; 33 | 34 | enum op_type : int { 35 | op_l, /* base type on left */ 36 | op_r, /* base type on right */ 37 | op_u /* unary operator */ 38 | }; 39 | 40 | struct self_t { }; 41 | static const self_t self = self_t(); 42 | 43 | /// Type for an unused type slot 44 | struct undefined_t { }; 45 | 46 | /// Don't warn about an unused variable 47 | inline self_t __self() { return self; } 48 | 49 | /// base template of operator implementations 50 | template struct op_impl { }; 51 | 52 | /// Operator implementation generator 53 | template struct op_ { 54 | template void execute(Class &cl, const Extra&... extra) const { 55 | using Base = typename Class::type; 56 | using L_type = conditional_t::value, Base, L>; 57 | using R_type = conditional_t::value, Base, R>; 58 | using op = op_impl; 59 | cl.def(op::name(), &op::execute, is_operator(), extra...); 60 | #if PY_MAJOR_VERSION < 3 61 | if (id == op_truediv || id == op_itruediv) 62 | cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__", 63 | &op::execute, is_operator(), extra...); 64 | #endif 65 | } 66 | template void execute_cast(Class &cl, const Extra&... extra) const { 67 | using Base = typename Class::type; 68 | using L_type = conditional_t::value, Base, L>; 69 | using R_type = conditional_t::value, Base, R>; 70 | using op = op_impl; 71 | cl.def(op::name(), &op::execute_cast, is_operator(), extra...); 72 | #if PY_MAJOR_VERSION < 3 73 | if (id == op_truediv || id == op_itruediv) 74 | cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? 
"__div__" : "__rdiv__", 75 | &op::execute, is_operator(), extra...); 76 | #endif 77 | } 78 | }; 79 | 80 | #define PYBIND11_BINARY_OPERATOR(id, rid, op, expr) \ 81 | template struct op_impl { \ 82 | static char const* name() { return "__" #id "__"; } \ 83 | static auto execute(const L &l, const R &r) -> decltype(expr) { return (expr); } \ 84 | static B execute_cast(const L &l, const R &r) { return B(expr); } \ 85 | }; \ 86 | template struct op_impl { \ 87 | static char const* name() { return "__" #rid "__"; } \ 88 | static auto execute(const R &r, const L &l) -> decltype(expr) { return (expr); } \ 89 | static B execute_cast(const R &r, const L &l) { return B(expr); } \ 90 | }; \ 91 | inline op_ op(const self_t &, const self_t &) { \ 92 | return op_(); \ 93 | } \ 94 | template op_ op(const self_t &, const T &) { \ 95 | return op_(); \ 96 | } \ 97 | template op_ op(const T &, const self_t &) { \ 98 | return op_(); \ 99 | } 100 | 101 | #define PYBIND11_INPLACE_OPERATOR(id, op, expr) \ 102 | template struct op_impl { \ 103 | static char const* name() { return "__" #id "__"; } \ 104 | static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \ 105 | static B execute_cast(L &l, const R &r) { return B(expr); } \ 106 | }; \ 107 | template op_ op(const self_t &, const T &) { \ 108 | return op_(); \ 109 | } 110 | 111 | #define PYBIND11_UNARY_OPERATOR(id, op, expr) \ 112 | template struct op_impl { \ 113 | static char const* name() { return "__" #id "__"; } \ 114 | static auto execute(const L &l) -> decltype(expr) { return expr; } \ 115 | static B execute_cast(const L &l) { return B(expr); } \ 116 | }; \ 117 | inline op_ op(const self_t &) { \ 118 | return op_(); \ 119 | } 120 | 121 | PYBIND11_BINARY_OPERATOR(sub, rsub, operator-, l - r) 122 | PYBIND11_BINARY_OPERATOR(add, radd, operator+, l + r) 123 | PYBIND11_BINARY_OPERATOR(mul, rmul, operator*, l * r) 124 | PYBIND11_BINARY_OPERATOR(truediv, rtruediv, operator/, l / r) 125 | PYBIND11_BINARY_OPERATOR(mod, 
rmod, operator%, l % r) 126 | PYBIND11_BINARY_OPERATOR(lshift, rlshift, operator<<, l << r) 127 | PYBIND11_BINARY_OPERATOR(rshift, rrshift, operator>>, l >> r) 128 | PYBIND11_BINARY_OPERATOR(and, rand, operator&, l & r) 129 | PYBIND11_BINARY_OPERATOR(xor, rxor, operator^, l ^ r) 130 | PYBIND11_BINARY_OPERATOR(eq, eq, operator==, l == r) 131 | PYBIND11_BINARY_OPERATOR(ne, ne, operator!=, l != r) 132 | PYBIND11_BINARY_OPERATOR(or, ror, operator|, l | r) 133 | PYBIND11_BINARY_OPERATOR(gt, lt, operator>, l > r) 134 | PYBIND11_BINARY_OPERATOR(ge, le, operator>=, l >= r) 135 | PYBIND11_BINARY_OPERATOR(lt, gt, operator<, l < r) 136 | PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r) 137 | //PYBIND11_BINARY_OPERATOR(pow, rpow, pow, std::pow(l, r)) 138 | PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r) 139 | PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r) 140 | PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r) 141 | PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r) 142 | PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r) 143 | PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r) 144 | PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r) 145 | PYBIND11_INPLACE_OPERATOR(iand, operator&=, l &= r) 146 | PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r) 147 | PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r) 148 | PYBIND11_UNARY_OPERATOR(neg, operator-, -l) 149 | PYBIND11_UNARY_OPERATOR(pos, operator+, +l) 150 | PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l)) 151 | PYBIND11_UNARY_OPERATOR(hash, hash, std::hash()(l)) 152 | PYBIND11_UNARY_OPERATOR(invert, operator~, (~l)) 153 | PYBIND11_UNARY_OPERATOR(bool, operator!, !!l) 154 | PYBIND11_UNARY_OPERATOR(int, int_, (int) l) 155 | PYBIND11_UNARY_OPERATOR(float, float_, (double) l) 156 | 157 | #undef PYBIND11_BINARY_OPERATOR 158 | #undef PYBIND11_INPLACE_OPERATOR 159 | #undef PYBIND11_UNARY_OPERATOR 160 | NAMESPACE_END(detail) 161 | 162 | using detail::self; 163 | 164 | NAMESPACE_END(PYBIND11_NAMESPACE) 
165 | 166 | #if defined(_MSC_VER) 167 | # pragma warning(pop) 168 | #endif 169 | -------------------------------------------------------------------------------- /pse/include/pybind11/options.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/options.h: global settings that are configurable at runtime. 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "detail/common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | 16 | class options { 17 | public: 18 | 19 | // Default RAII constructor, which leaves settings as they currently are. 20 | options() : previous_state(global_state()) {} 21 | 22 | // Class is non-copyable. 23 | options(const options&) = delete; 24 | options& operator=(const options&) = delete; 25 | 26 | // Destructor, which restores settings that were in effect before. 27 | ~options() { 28 | global_state() = previous_state; 29 | } 30 | 31 | // Setter methods (affect the global state): 32 | 33 | options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; } 34 | 35 | options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; } 36 | 37 | options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; } 38 | 39 | options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; } 40 | 41 | // Getter methods (return the global state): 42 | 43 | static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; } 44 | 45 | static bool show_function_signatures() { return global_state().show_function_signatures; } 46 | 47 | // This type is not meant to be allocated on the heap. 
48 | void* operator new(size_t) = delete; 49 | 50 | private: 51 | 52 | struct state { 53 | bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings. 54 | bool show_function_signatures = true; //< Include auto-generated function signatures in docstrings. 55 | }; 56 | 57 | static state &global_state() { 58 | static state instance; 59 | return instance; 60 | } 61 | 62 | state previous_state; 63 | }; 64 | 65 | NAMESPACE_END(PYBIND11_NAMESPACE) 66 | -------------------------------------------------------------------------------- /pse/include/pybind11/stl.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/stl.h: Transparent conversion for STL data types 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #if defined(_MSC_VER) 23 | #pragma warning(push) 24 | #pragma warning(disable: 4127) // warning C4127: Conditional expression is constant 25 | #endif 26 | 27 | #ifdef __has_include 28 | // std::optional (but including it in c++14 mode isn't allowed) 29 | # if defined(PYBIND11_CPP17) && __has_include() 30 | # include 31 | # define PYBIND11_HAS_OPTIONAL 1 32 | # endif 33 | // std::experimental::optional (but not allowed in c++11 mode) 34 | # if defined(PYBIND11_CPP14) && (__has_include() && \ 35 | !__has_include()) 36 | # include 37 | # define PYBIND11_HAS_EXP_OPTIONAL 1 38 | # endif 39 | // std::variant 40 | # if defined(PYBIND11_CPP17) && __has_include() 41 | # include 42 | # define PYBIND11_HAS_VARIANT 1 43 | # endif 44 | #elif defined(_MSC_VER) && defined(PYBIND11_CPP17) 45 | # include 46 | # include 47 | # define PYBIND11_HAS_OPTIONAL 1 48 | # define PYBIND11_HAS_VARIANT 1 49 | #endif 50 | 
51 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 52 | NAMESPACE_BEGIN(detail) 53 | 54 | /// Extracts an const lvalue reference or rvalue reference for U based on the type of T (e.g. for 55 | /// forwarding a container element). Typically used indirect via forwarded_type(), below. 56 | template 57 | using forwarded_type = conditional_t< 58 | std::is_lvalue_reference::value, remove_reference_t &, remove_reference_t &&>; 59 | 60 | /// Forwards a value U as rvalue or lvalue according to whether T is rvalue or lvalue; typically 61 | /// used for forwarding a container's elements. 62 | template 63 | forwarded_type forward_like(U &&u) { 64 | return std::forward>(std::forward(u)); 65 | } 66 | 67 | template struct set_caster { 68 | using type = Type; 69 | using key_conv = make_caster; 70 | 71 | bool load(handle src, bool convert) { 72 | if (!isinstance(src)) 73 | return false; 74 | auto s = reinterpret_borrow(src); 75 | value.clear(); 76 | for (auto entry : s) { 77 | key_conv conv; 78 | if (!conv.load(entry, convert)) 79 | return false; 80 | value.insert(cast_op(std::move(conv))); 81 | } 82 | return true; 83 | } 84 | 85 | template 86 | static handle cast(T &&src, return_value_policy policy, handle parent) { 87 | if (!std::is_lvalue_reference::value) 88 | policy = return_value_policy_override::policy(policy); 89 | pybind11::set s; 90 | for (auto &&value : src) { 91 | auto value_ = reinterpret_steal(key_conv::cast(forward_like(value), policy, parent)); 92 | if (!value_ || !s.add(value_)) 93 | return handle(); 94 | } 95 | return s.release(); 96 | } 97 | 98 | PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name + _("]")); 99 | }; 100 | 101 | template struct map_caster { 102 | using key_conv = make_caster; 103 | using value_conv = make_caster; 104 | 105 | bool load(handle src, bool convert) { 106 | if (!isinstance(src)) 107 | return false; 108 | auto d = reinterpret_borrow(src); 109 | value.clear(); 110 | for (auto it : d) { 111 | key_conv kconv; 112 | value_conv vconv; 113 | if 
(!kconv.load(it.first.ptr(), convert) || 114 | !vconv.load(it.second.ptr(), convert)) 115 | return false; 116 | value.emplace(cast_op(std::move(kconv)), cast_op(std::move(vconv))); 117 | } 118 | return true; 119 | } 120 | 121 | template 122 | static handle cast(T &&src, return_value_policy policy, handle parent) { 123 | dict d; 124 | return_value_policy policy_key = policy; 125 | return_value_policy policy_value = policy; 126 | if (!std::is_lvalue_reference::value) { 127 | policy_key = return_value_policy_override::policy(policy_key); 128 | policy_value = return_value_policy_override::policy(policy_value); 129 | } 130 | for (auto &&kv : src) { 131 | auto key = reinterpret_steal(key_conv::cast(forward_like(kv.first), policy_key, parent)); 132 | auto value = reinterpret_steal(value_conv::cast(forward_like(kv.second), policy_value, parent)); 133 | if (!key || !value) 134 | return handle(); 135 | d[key] = value; 136 | } 137 | return d.release(); 138 | } 139 | 140 | PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name + _(", ") + value_conv::name + _("]")); 141 | }; 142 | 143 | template struct list_caster { 144 | using value_conv = make_caster; 145 | 146 | bool load(handle src, bool convert) { 147 | if (!isinstance(src) || isinstance(src)) 148 | return false; 149 | auto s = reinterpret_borrow(src); 150 | value.clear(); 151 | reserve_maybe(s, &value); 152 | for (auto it : s) { 153 | value_conv conv; 154 | if (!conv.load(it, convert)) 155 | return false; 156 | value.push_back(cast_op(std::move(conv))); 157 | } 158 | return true; 159 | } 160 | 161 | private: 162 | template ().reserve(0)), void>::value, int> = 0> 164 | void reserve_maybe(sequence s, Type *) { value.reserve(s.size()); } 165 | void reserve_maybe(sequence, void *) { } 166 | 167 | public: 168 | template 169 | static handle cast(T &&src, return_value_policy policy, handle parent) { 170 | if (!std::is_lvalue_reference::value) 171 | policy = return_value_policy_override::policy(policy); 172 | list l(src.size()); 
173 | size_t index = 0; 174 | for (auto &&value : src) { 175 | auto value_ = reinterpret_steal(value_conv::cast(forward_like(value), policy, parent)); 176 | if (!value_) 177 | return handle(); 178 | PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference 179 | } 180 | return l.release(); 181 | } 182 | 183 | PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name + _("]")); 184 | }; 185 | 186 | template struct type_caster> 187 | : list_caster, Type> { }; 188 | 189 | template struct type_caster> 190 | : list_caster, Type> { }; 191 | 192 | template struct type_caster> 193 | : list_caster, Type> { }; 194 | 195 | template struct array_caster { 196 | using value_conv = make_caster; 197 | 198 | private: 199 | template 200 | bool require_size(enable_if_t size) { 201 | if (value.size() != size) 202 | value.resize(size); 203 | return true; 204 | } 205 | template 206 | bool require_size(enable_if_t size) { 207 | return size == Size; 208 | } 209 | 210 | public: 211 | bool load(handle src, bool convert) { 212 | if (!isinstance(src)) 213 | return false; 214 | auto l = reinterpret_borrow(src); 215 | if (!require_size(l.size())) 216 | return false; 217 | size_t ctr = 0; 218 | for (auto it : l) { 219 | value_conv conv; 220 | if (!conv.load(it, convert)) 221 | return false; 222 | value[ctr++] = cast_op(std::move(conv)); 223 | } 224 | return true; 225 | } 226 | 227 | template 228 | static handle cast(T &&src, return_value_policy policy, handle parent) { 229 | list l(src.size()); 230 | size_t index = 0; 231 | for (auto &&value : src) { 232 | auto value_ = reinterpret_steal(value_conv::cast(forward_like(value), policy, parent)); 233 | if (!value_) 234 | return handle(); 235 | PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference 236 | } 237 | return l.release(); 238 | } 239 | 240 | PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name + _(_(""), _("[") + _() + _("]")) + _("]")); 241 | }; 242 | 243 | 
template struct type_caster> 244 | : array_caster, Type, false, Size> { }; 245 | 246 | template struct type_caster> 247 | : array_caster, Type, true> { }; 248 | 249 | template struct type_caster> 250 | : set_caster, Key> { }; 251 | 252 | template struct type_caster> 253 | : set_caster, Key> { }; 254 | 255 | template struct type_caster> 256 | : map_caster, Key, Value> { }; 257 | 258 | template struct type_caster> 259 | : map_caster, Key, Value> { }; 260 | 261 | // This type caster is intended to be used for std::optional and std::experimental::optional 262 | template struct optional_caster { 263 | using value_conv = make_caster; 264 | 265 | template 266 | static handle cast(T_ &&src, return_value_policy policy, handle parent) { 267 | if (!src) 268 | return none().inc_ref(); 269 | policy = return_value_policy_override::policy(policy); 270 | return value_conv::cast(*std::forward(src), policy, parent); 271 | } 272 | 273 | bool load(handle src, bool convert) { 274 | if (!src) { 275 | return false; 276 | } else if (src.is_none()) { 277 | return true; // default-constructed value is already empty 278 | } 279 | value_conv inner_caster; 280 | if (!inner_caster.load(src, convert)) 281 | return false; 282 | 283 | value.emplace(cast_op(std::move(inner_caster))); 284 | return true; 285 | } 286 | 287 | PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name + _("]")); 288 | }; 289 | 290 | #if PYBIND11_HAS_OPTIONAL 291 | template struct type_caster> 292 | : public optional_caster> {}; 293 | 294 | template<> struct type_caster 295 | : public void_caster {}; 296 | #endif 297 | 298 | #if PYBIND11_HAS_EXP_OPTIONAL 299 | template struct type_caster> 300 | : public optional_caster> {}; 301 | 302 | template<> struct type_caster 303 | : public void_caster {}; 304 | #endif 305 | 306 | /// Visit a variant and cast any found type to Python 307 | struct variant_caster_visitor { 308 | return_value_policy policy; 309 | handle parent; 310 | 311 | using result_type = handle; // required by 
boost::variant in C++11 312 | 313 | template 314 | result_type operator()(T &&src) const { 315 | return make_caster::cast(std::forward(src), policy, parent); 316 | } 317 | }; 318 | 319 | /// Helper class which abstracts away variant's `visit` function. `std::variant` and similar 320 | /// `namespace::variant` types which provide a `namespace::visit()` function are handled here 321 | /// automatically using argument-dependent lookup. Users can provide specializations for other 322 | /// variant-like classes, e.g. `boost::variant` and `boost::apply_visitor`. 323 | template class Variant> 324 | struct visit_helper { 325 | template 326 | static auto call(Args &&...args) -> decltype(visit(std::forward(args)...)) { 327 | return visit(std::forward(args)...); 328 | } 329 | }; 330 | 331 | /// Generic variant caster 332 | template struct variant_caster; 333 | 334 | template class V, typename... Ts> 335 | struct variant_caster> { 336 | static_assert(sizeof...(Ts) > 0, "Variant must consist of at least one alternative."); 337 | 338 | template 339 | bool load_alternative(handle src, bool convert, type_list) { 340 | auto caster = make_caster(); 341 | if (caster.load(src, convert)) { 342 | value = cast_op(caster); 343 | return true; 344 | } 345 | return load_alternative(src, convert, type_list{}); 346 | } 347 | 348 | bool load_alternative(handle, bool, type_list<>) { return false; } 349 | 350 | bool load(handle src, bool convert) { 351 | // Do a first pass without conversions to improve constructor resolution. 352 | // E.g. `py::int_(1).cast>()` needs to fill the `int` 353 | // slot of the variant. Without two-pass loading `double` would be filled 354 | // because it appears first and a conversion is possible. 
355 | if (convert && load_alternative(src, false, type_list{})) 356 | return true; 357 | return load_alternative(src, convert, type_list{}); 358 | } 359 | 360 | template 361 | static handle cast(Variant &&src, return_value_policy policy, handle parent) { 362 | return visit_helper::call(variant_caster_visitor{policy, parent}, 363 | std::forward(src)); 364 | } 365 | 366 | using Type = V; 367 | PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster::name...) + _("]")); 368 | }; 369 | 370 | #if PYBIND11_HAS_VARIANT 371 | template 372 | struct type_caster> : variant_caster> { }; 373 | #endif 374 | 375 | NAMESPACE_END(detail) 376 | 377 | inline std::ostream &operator<<(std::ostream &os, const handle &obj) { 378 | os << (std::string) str(obj); 379 | return os; 380 | } 381 | 382 | NAMESPACE_END(PYBIND11_NAMESPACE) 383 | 384 | #if defined(_MSC_VER) 385 | #pragma warning(pop) 386 | #endif 387 | -------------------------------------------------------------------------------- /pse/include/pybind11/typeid.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/typeid.h: Compiler-independent access to type identifiers 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include 13 | #include 14 | 15 | #if defined(__GNUG__) 16 | #include 17 | #endif 18 | 19 | NAMESPACE_BEGIN(pybind11) 20 | NAMESPACE_BEGIN(detail) 21 | /// Erase all occurrences of a substring 22 | inline void erase_all(std::string &string, const std::string &search) { 23 | for (size_t pos = 0;;) { 24 | pos = string.find(search, pos); 25 | if (pos == std::string::npos) break; 26 | string.erase(pos, search.length()); 27 | } 28 | } 29 | 30 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) { 31 | #if defined(__GNUG__) 32 | int status = 0; 33 | std::unique_ptr res { 34 | abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free }; 35 | if (status == 0) 36 | name = res.get(); 37 | #else 38 | detail::erase_all(name, "class "); 39 | detail::erase_all(name, "struct "); 40 | detail::erase_all(name, "enum "); 41 | #endif 42 | detail::erase_all(name, "pybind11::"); 43 | } 44 | NAMESPACE_END(detail) 45 | 46 | /// Return a string representation of a C++ type 47 | template static std::string type_id() { 48 | std::string name(typeid(T).name()); 49 | detail::clean_type_id(name); 50 | return name; 51 | } 52 | 53 | NAMESPACE_END(pybind11) 54 | -------------------------------------------------------------------------------- /pse/ncnn/examples/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | find_package(OpenCV QUIET COMPONENTS core highgui imgproc imgcodecs) 2 | if(NOT OpenCV_FOUND) 3 | find_package(OpenCV REQUIRED COMPONENTS core highgui imgproc) 4 | endif() 5 | 6 | include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../src) 7 | include_directories(${CMAKE_CURRENT_BINARY_DIR}/../src) 8 | 9 | set(NCNN_EXAMPLE_LINK_LIBRARIES ncnn ${OpenCV_LIBS}) 10 | if(NCNN_VULKAN) 11 | list(APPEND NCNN_EXAMPLE_LINK_LIBRARIES ${Vulkan_LIBRARY}) 12 | endif() 13 | 14 | add_executable(psenet psenet.cpp) 15 | target_link_libraries(psenet ${NCNN_EXAMPLE_LINK_LIBRARIES}) 16 | 
-------------------------------------------------------------------------------- /pse/ncnn/examples/psenet.cpp: -------------------------------------------------------------------------------- 1 | // Tencent is pleased to support the open source community by making ncnn available. 2 | // 3 | // Copyright (C) 2018 THL A29 Limited, a Tencent company. All rights reserved. 4 | // 5 | // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except 6 | // in compliance with the License. You may obtain a copy of the License at 7 | // 8 | // https://opensource.org/licenses/BSD-3-Clause 9 | // 10 | // Unless required by applicable law or agreed to in writing, software distributed 11 | // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 12 | // CONDITIONS OF ANY KIND, either express or implied. See the License for the 13 | // specific language governing permissions and limitations under the License. 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | #include "platform.h" 24 | #include "net.h" 25 | 26 | void ncnn2cv(ncnn::Mat src, cv::Mat &score, cv::Mat &thre_img, const float thre_val = 0.5) { 27 | float *srcdata = (float *) src.data; 28 | for (int i = 0; i < src.h; i++) { 29 | for (int j = 0; j < src.w; j++) { 30 | score.at(i, j) = srcdata[i * src.w + j]; 31 | if (srcdata[i * src.w + j] >= thre_val) { 32 | thre_img.at(i, j) = 255; 33 | } else { 34 | thre_img.at(i, j) = 0; 35 | } 36 | } 37 | } 38 | } 39 | 40 | cv::Mat resize_img(cv::Mat src,const int long_size) 41 | { 42 | int w = src.cols; 43 | int h = src.rows; 44 | std::cout<<"原图尺寸 (" << w << ", "< h) 47 | { 48 | scale = (float)long_size / w; 49 | w = long_size; 50 | h = h * scale; 51 | } 52 | else 53 | { 54 | scale = (float)long_size / h; 55 | h = long_size; 56 | w = w * scale; 57 | } 58 | if (h % 32 != 0) 59 | { 60 | h = (h / 32 + 1) * 32; 61 | } 62 | if (w % 32 != 0) 63 | { 64 | w = (w / 32 + 
1) * 32; 65 | } 66 | std::cout<<"缩放尺寸 (" << w << ", "<> &bboxs) { 73 | cv::Mat dst; 74 | if (src.channels() == 1) { 75 | cv::cvtColor(src, dst, cv::COLOR_GRAY2BGR); 76 | } else { 77 | dst = src.clone(); 78 | } 79 | auto color = cv::Scalar(0, 0, 255); 80 | for (auto bbox :bboxs) { 81 | cv::line(dst, bbox[0], bbox[1], color, 3); 82 | cv::line(dst, bbox[1], bbox[2], color, 3); 83 | cv::line(dst, bbox[2], bbox[3], color, 3); 84 | cv::line(dst, bbox[3], bbox[0], color, 3); 85 | } 86 | return dst; 87 | } 88 | 89 | std::vector> deocde(const cv::Mat &score, const cv::Mat &thre, const int scale, const float h_scale, const float w_scale) { 90 | int img_rows = score.rows; 91 | int img_cols = score.cols; 92 | auto min_w_h = std::min(img_cols,img_rows); 93 | min_w_h *= min_w_h / 20; 94 | cv::Mat stats, centroids, label_img(thre.size(), CV_32S); 95 | // 二值化 96 | // cv::threshold(cv_img * 255, thre, 0, 255, cv::THRESH_OTSU); 97 | // 计算连通域ss 98 | int nLabels = connectedComponentsWithStats(thre, label_img, stats, centroids); 99 | 100 | std::vector angles; 101 | std::vector> bboxs; 102 | 103 | for (int label = 1; label < nLabels; label++) { 104 | float area = stats.at(label, cv::CC_STAT_AREA); 105 | if (area < min_w_h / (scale * scale)) { 106 | continue; 107 | } 108 | // 计算该label的平均分数 109 | std::vector scores; 110 | std::vector points; 111 | for (int y = 0; y < img_rows; ++y) { 112 | for (int x = 0; x < img_cols; ++x) { 113 | if (label_img.at(y, x) == label) { 114 | scores.emplace_back(score.at(y, x)); 115 | points.emplace_back(cv::Point(x, y)); 116 | } 117 | } 118 | } 119 | 120 | //均值 121 | double sum = std::accumulate(std::begin(scores), std::end(scores), 0.0); 122 | if (sum == 0) { 123 | continue; 124 | } 125 | double mean = sum / scores.size(); 126 | 127 | if (mean < 0.8) { 128 | continue; 129 | } 130 | cv::RotatedRect rect = cv::minAreaRect(points); 131 | float w = rect.size.width; 132 | float h = rect.size.height; 133 | float angle = rect.angle; 134 | 135 | if (w < h) { 136 | 
std::swap(w, h); 137 | angle -= 90; 138 | } 139 | if (45 < std::abs(angle) && std::abs(angle) < 135) { 140 | std::swap(img_rows, img_cols); 141 | } 142 | points.clear(); 143 | // 对卡号进行限制,长宽比,卡号的宽度不能超过图片宽高的95% 144 | if (w > h * 8 && w < img_cols * 0.95) { 145 | cv::Mat bbox; 146 | cv::boxPoints(rect, bbox); 147 | for (int i = 0; i < bbox.rows; ++i) { 148 | points.emplace_back(cv::Point(int(bbox.at(i, 0) * w_scale), int(bbox.at(i, 1) * h_scale))); 149 | } 150 | bboxs.emplace_back(points); 151 | angles.emplace_back(angle); 152 | } 153 | } 154 | return bboxs; 155 | } 156 | 157 | static int detect_rfcn(const char *model, const char *model_param, const char *imagepath, const int long_size = 800) { 158 | cv::Mat im_bgr = cv::imread(imagepath, 1); 159 | 160 | if (im_bgr.empty()) { 161 | fprintf(stderr, "cv::imread %s failed\n", imagepath); 162 | return -1; 163 | } 164 | // 图像缩放 165 | auto im = resize_img(im_bgr, long_size); 166 | float h_scale = im_bgr.rows * 1.0 / im.rows; 167 | float w_scale = im_bgr.cols * 1.0 / im.cols; 168 | 169 | ncnn::Mat in = ncnn::Mat::from_pixels(im.data, ncnn::Mat::PIXEL_BGR, im.cols, im.rows); 170 | const float norm_vals[3] = { 1 / 255.f ,1 / 255.f ,1 / 255.f}; 171 | in.substract_mean_normalize(0,norm_vals); 172 | 173 | std::cout << "输入尺寸 (" << in.w << ", " << in.h << ")" << std::endl; 174 | 175 | ncnn::Net psenet; 176 | psenet.load_param(model_param); 177 | psenet.load_model(model); 178 | ncnn::Extractor ex = psenet.create_extractor(); 179 | // ex.set_num_threads(4);ss 180 | ex.input("0", in); 181 | 182 | ncnn::Mat preds; 183 | double time1 = static_cast( cv::getTickCount()); 184 | ex.extract("636", preds); 185 | std::cout << "前向时间:" << (static_cast( cv::getTickCount()) - time1) / cv::getTickFrequency() << "s" << std::endl; 186 | std::cout << "网络输出尺寸 (" << preds.w << ", " << preds.h << ", " << preds.c << ")" << std::endl; 187 | 188 | time1 = static_cast( cv::getTickCount()); 189 | cv::Mat score = cv::Mat::zeros(preds.h, preds.w, CV_32FC1); 190 
| cv::Mat thre = cv::Mat::zeros(preds.h, preds.w, CV_8UC1); 191 | ncnn2cv(preds, score, thre); 192 | auto bboxs = deocde(score, thre, 1, h_scale, w_scale); 193 | std::cout << "decode 时间:" << (static_cast( cv::getTickCount()) - time1) / cv::getTickFrequency() << "s" << std::endl; 194 | auto result = draw_bbox(im_bgr, bboxs); 195 | cv::imwrite("/home/zj/project/ncnn/examples/imgs/result.jpg", result); 196 | cv::imwrite("/home/zj/project/ncnn/examples/imgs/net_result.jpg", score * 255); 197 | cv::imwrite("/home/zj/project/ncnn/examples/imgs/net_thre.jpg", thre); 198 | return 0; 199 | } 200 | 201 | int main(int argc, char **argv) { 202 | if (argc != 5) { 203 | fprintf(stderr, "Usage: %s [model model path imagepath long_size]\n", argv[0]); 204 | return -1; 205 | } 206 | const char *model = argv[1]; 207 | const char *model_param = argv[2]; 208 | const char *imagepath = argv[3]; 209 | const int long_size = atoi(argv[4]); 210 | std::cout << model << " " << model_param << " " << imagepath << " " << long_size << std::endl; 211 | 212 | detect_rfcn(model, model_param, imagepath, long_size); 213 | return 0; 214 | } 215 | -------------------------------------------------------------------------------- /pse/ncnn/examples/run.sh: -------------------------------------------------------------------------------- 1 | /home/zj/project/ncnn/build/examples/psenet /home/zj/project/ncnn/examples/psenet.bin /home/zj/project/ncnn/examples/psenet.param /home/zj/card/8_6217921001182693.jpg 600 -------------------------------------------------------------------------------- /pse/pse.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // pse 3 | // reference https://github.com/whai362/PSENet/issues/15 4 | // Created by liuheng on 11/3/19. 5 | // Copyright © 2019年 liuheng. All rights reserved. 
6 | // 7 | #include 8 | #include "include/pybind11/pybind11.h" 9 | #include "include/pybind11/numpy.h" 10 | #include "include/pybind11/stl.h" 11 | #include "include/pybind11/stl_bind.h" 12 | 13 | namespace py = pybind11; 14 | 15 | namespace pse{ 16 | //S5->S0, small->big 17 | std::vector> pse( 18 | py::array_t label_map, 19 | py::array_t Sn, 20 | int c = 6) 21 | { 22 | auto pbuf_label_map = label_map.request(); 23 | auto pbuf_Sn = Sn.request(); 24 | if (pbuf_label_map.ndim != 2 || pbuf_label_map.shape[0]==0 || pbuf_label_map.shape[1]==0) 25 | throw std::runtime_error("label map must have a shape of (h>0, w>0)"); 26 | int h = pbuf_label_map.shape[0]; 27 | int w = pbuf_label_map.shape[1]; 28 | if (pbuf_Sn.ndim != 3 || pbuf_Sn.shape[0] != c || pbuf_Sn.shape[1]!=h || pbuf_Sn.shape[2]!=w) 29 | throw std::runtime_error("Sn must have a shape of (c>0, h>0, w>0)"); 30 | 31 | std::vector> res; 32 | for (size_t i = 0; i(w, 0)); 34 | auto ptr_label_map = static_cast(pbuf_label_map.ptr); 35 | auto ptr_Sn = static_cast(pbuf_Sn.ptr); 36 | 37 | std::queue> q, next_q; 38 | 39 | for (size_t i = 0; i0) 46 | { 47 | q.push(std::make_tuple(i, j, label)); 48 | res[i][j] = label; 49 | } 50 | } 51 | } 52 | 53 | int dx[4] = {-1, 1, 0, 0}; 54 | int dy[4] = {0, 0, -1, 1}; 55 | // merge from small to large kernel progressively 56 | for (int i = 1; i(q_n); 65 | int x = std::get<1>(q_n); 66 | int32_t l = std::get<2>(q_n); 67 | //store the edge pixel after one expansion 68 | bool is_edge = true; 69 | for (int idx=0; idx<4; idx++) 70 | { 71 | int index_y = y + dy[idx]; 72 | int index_x = x + dx[idx]; 73 | if (index_y<0 || index_y>=h || index_x<0 || index_x>=w) 74 | continue; 75 | if (!p_Sn[index_y*w+index_x] || res[index_y][index_x]>0) 76 | continue; 77 | q.push(std::make_tuple(index_y, index_x, l)); 78 | res[index_y][index_x]=l; 79 | is_edge = false; 80 | } 81 | if (is_edge){ 82 | next_q.push(std::make_tuple(y, x, l)); 83 | } 84 | } 85 | std::swap(q, next_q); 86 | } 87 | return res; 88 | } 89 | } 
90 | 91 | PYBIND11_MODULE(pse, m){ 92 | m.def("pse_cpp", &pse::pse, " re-implementation pse algorithm(cpp)", py::arg("label_map"), py::arg("Sn"), py::arg("c")=6); 93 | } 94 | 95 | -------------------------------------------------------------------------------- /pse/pse.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WenmuZhou/PSENet.pytorch/f760c2f4938726a2d00efaf5e5b28218323c44ca/pse/pse.so -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2018/6/11 15:54 3 | # @Author : zhoujun 4 | import cv2 5 | import os 6 | import config 7 | 8 | os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu_id 9 | 10 | import shutil 11 | import glob 12 | import time 13 | import numpy as np 14 | import torch 15 | from tqdm import tqdm 16 | from torch import nn 17 | import torch.utils.data as Data 18 | from torchvision import transforms 19 | import torchvision.utils as vutils 20 | from torch.utils.tensorboard import SummaryWriter 21 | 22 | from dataset.data_utils import MyDataset 23 | from models import PSENet 24 | from models.loss import PSELoss 25 | from utils.utils import load_checkpoint, save_checkpoint, setup_logger 26 | from pse import decode as pse_decode 27 | from cal_recall import cal_recall_precison_f1 28 | 29 | 30 | def weights_init(m): 31 | if isinstance(m, nn.Conv2d): 32 | nn.init.kaiming_normal_(m.weight) 33 | if m.bias is not None: 34 | nn.init.constant_(m.bias, 0) 35 | 36 | 37 | # learning rate的warming up操作 38 | def adjust_learning_rate(optimizer, epoch): 39 | """Sets the learning rate 40 | # Adapted from PyTorch Imagenet example: 41 | # https://github.com/pytorch/examples/blob/master/imagenet/main.py 42 | """ 43 | if epoch < config.warm_up_epoch: 44 | lr = 1e-6 + (config.lr - 1e-6) * epoch / (config.warm_up_epoch) 45 | else: 46 | lr = config.lr 
* (config.lr_gamma ** (epoch / config.lr_decay_step[0])) 47 | 48 | for param_group in optimizer.param_groups: 49 | param_group['lr'] = lr 50 | 51 | return lr 52 | 53 | 54 | def train_epoch(net, optimizer, scheduler, train_loader, device, criterion, epoch, all_step, writer, logger): 55 | net.train() 56 | train_loss = 0. 57 | start = time.time() 58 | # lr = adjust_learning_rate(optimizer, epoch) 59 | lr = scheduler.get_last_lr()[0] 60 | for i, (images, labels, training_mask) in enumerate(train_loader): 61 | cur_batch = images.size()[0] 62 | images, labels, training_mask = images.to(device), labels.to(device), training_mask.to(device) 63 | # Forward 64 | y1 = net(images) 65 | loss_c, loss_s, loss = criterion(y1, labels, training_mask) 66 | # Backward 67 | optimizer.zero_grad() 68 | loss.backward() 69 | optimizer.step() 70 | train_loss += loss.item() 71 | 72 | loss_c = loss_c.item() 73 | loss_s = loss_s.item() 74 | loss = loss.item() 75 | cur_step = epoch * all_step + i 76 | writer.add_scalar(tag='Train/loss_c', scalar_value=loss_c, global_step=cur_step) 77 | writer.add_scalar(tag='Train/loss_s', scalar_value=loss_s, global_step=cur_step) 78 | writer.add_scalar(tag='Train/loss', scalar_value=loss, global_step=cur_step) 79 | writer.add_scalar(tag='Train/lr', scalar_value=lr, global_step=cur_step) 80 | 81 | if i % config.display_interval == 0: 82 | batch_time = time.time() - start 83 | logger.info( 84 | '[{}/{}], [{}/{}], step: {}, {:.3f} samples/sec, batch_loss: {:.4f}, batch_loss_c: {:.4f}, batch_loss_s: {:.4f}, time:{:.4f},lr:{:.4f}'.format( 85 | epoch, config.epochs, i, all_step, cur_step, config.display_interval * cur_batch / batch_time, 86 | loss, loss_c, loss_s, batch_time,lr)) 87 | start = time.time() 88 | 89 | if i % config.show_images_interval == 0: 90 | if config.display_input_images: 91 | # show images on tensorboard 92 | x = vutils.make_grid(images.detach().cpu(), nrow=4, normalize=True, scale_each=True, padding=20) 93 | writer.add_image(tag='input/image', 
img_tensor=x, global_step=cur_step) 94 | 95 | show_label = labels.detach().cpu() 96 | b, c, h, w = show_label.size() 97 | show_label = show_label.reshape(b * c, h, w) 98 | show_label = vutils.make_grid(show_label.unsqueeze(1), nrow=config.n, normalize=False, padding=20, 99 | pad_value=1) 100 | writer.add_image(tag='input/label', img_tensor=show_label, global_step=cur_step) 101 | 102 | if config.display_output_images: 103 | y1 = torch.sigmoid(y1) 104 | show_y = y1.detach().cpu() 105 | b, c, h, w = show_y.size() 106 | show_y = show_y.reshape(b * c, h, w) 107 | show_y = vutils.make_grid(show_y.unsqueeze(1), nrow=config.n, normalize=False, padding=20, pad_value=1) 108 | writer.add_image(tag='output/preds', img_tensor=show_y, global_step=cur_step) 109 | scheduler.step() 110 | writer.add_scalar(tag='Train_epoch/loss', scalar_value=train_loss / all_step, global_step=epoch) 111 | return train_loss / all_step, lr 112 | 113 | 114 | def eval(model, save_path, test_path, device): 115 | model.eval() 116 | # torch.cuda.empty_cache() # speed up evaluating after training finished 117 | img_path = os.path.join(test_path, 'img') 118 | gt_path = os.path.join(test_path, 'gt') 119 | if os.path.exists(save_path): 120 | shutil.rmtree(save_path, ignore_errors=True) 121 | if not os.path.exists(save_path): 122 | os.makedirs(save_path) 123 | long_size = 2240 124 | # 预测所有测试图片 125 | img_paths = [os.path.join(img_path, x) for x in os.listdir(img_path)] 126 | for img_path in tqdm(img_paths, desc='test models'): 127 | img_name = os.path.basename(img_path).split('.')[0] 128 | save_name = os.path.join(save_path, 'res_' + img_name + '.txt') 129 | 130 | assert os.path.exists(img_path), 'file is not exists' 131 | img = cv2.imread(img_path) 132 | h, w = img.shape[:2] 133 | #if max(h, w) > long_size: 134 | scale = long_size / max(h, w) 135 | img = cv2.resize(img, None, fx=scale, fy=scale) 136 | # 将图片由(w,h)变为(1,img_channel,h,w) 137 | tensor = transforms.ToTensor()(img) 138 | tensor = tensor.unsqueeze_(0) 
139 | tensor = tensor.to(device) 140 | with torch.no_grad(): 141 | preds = model(tensor) 142 | preds, boxes_list = pse_decode(preds[0], config.scale) 143 | scale = (preds.shape[1] * 1.0 / w, preds.shape[0] * 1.0 / h) 144 | if len(boxes_list): 145 | boxes_list = boxes_list / scale 146 | np.savetxt(save_name, boxes_list.reshape(-1, 8), delimiter=',', fmt='%d') 147 | # 开始计算 recall precision f1 148 | result_dict = cal_recall_precison_f1(gt_path, save_path) 149 | return result_dict['recall'], result_dict['precision'], result_dict['hmean'] 150 | 151 | 152 | def main(): 153 | if config.output_dir is None: 154 | config.output_dir = 'output' 155 | if config.restart_training: 156 | shutil.rmtree(config.output_dir, ignore_errors=True) 157 | if not os.path.exists(config.output_dir): 158 | os.makedirs(config.output_dir) 159 | 160 | logger = setup_logger(os.path.join(config.output_dir, 'train_log')) 161 | logger.info(config.print()) 162 | 163 | torch.manual_seed(config.seed) # 为CPU设置随机种子 164 | if config.gpu_id is not None and torch.cuda.is_available(): 165 | torch.backends.cudnn.benchmark = True 166 | logger.info('train with gpu {} and pytorch {}'.format(config.gpu_id, torch.__version__)) 167 | device = torch.device("cuda:0") 168 | torch.cuda.manual_seed(config.seed) # 为当前GPU设置随机种子 169 | torch.cuda.manual_seed_all(config.seed) # 为所有GPU设置随机种子 170 | else: 171 | logger.info('train with cpu and pytorch {}'.format(torch.__version__)) 172 | device = torch.device("cpu") 173 | 174 | train_data = MyDataset(config.trainroot, data_shape=config.data_shape, n=config.n, m=config.m, 175 | transform=transforms.ToTensor()) 176 | train_loader = Data.DataLoader(dataset=train_data, batch_size=config.train_batch_size, shuffle=True, 177 | num_workers=int(config.workers)) 178 | 179 | writer = SummaryWriter(config.output_dir) 180 | model = PSENet(backbone=config.backbone, pretrained=config.pretrained, result_num=config.n, scale=config.scale) 181 | if not config.pretrained and not 
config.restart_training: 182 | model.apply(weights_init) 183 | 184 | num_gpus = torch.cuda.device_count() 185 | if num_gpus > 1: 186 | model = nn.DataParallel(model) 187 | model = model.to(device) 188 | # dummy_input = torch.autograd.Variable(torch.Tensor(1, 3, 600, 800).to(device)) 189 | # writer.add_graph(models=models, input_to_model=dummy_input) 190 | criterion = PSELoss(Lambda=config.Lambda, ratio=config.OHEM_ratio, reduction='mean') 191 | # optimizer = torch.optim.SGD(models.parameters(), lr=config.lr, momentum=0.99) 192 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 193 | if config.checkpoint != '' and not config.restart_training: 194 | start_epoch = load_checkpoint(config.checkpoint, model, logger, device, optimizer) 195 | start_epoch += 1 196 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, config.lr_decay_step, gamma=config.lr_gamma, 197 | last_epoch=start_epoch) 198 | else: 199 | start_epoch = config.start_epoch 200 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, config.lr_decay_step, gamma=config.lr_gamma) 201 | 202 | all_step = len(train_loader) 203 | logger.info('train dataset has {} samples,{} in dataloader'.format(train_data.__len__(), all_step)) 204 | epoch = 0 205 | best_model = {'recall': 0, 'precision': 0, 'f1': 0, 'models': ''} 206 | try: 207 | for epoch in range(start_epoch, config.epochs): 208 | start = time.time() 209 | train_loss, lr = train_epoch(model, optimizer, scheduler, train_loader, device, criterion, epoch, all_step, 210 | writer, logger) 211 | logger.info('[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}'.format( 212 | epoch, config.epochs, train_loss, time.time() - start, lr)) 213 | # net_save_path = '{}/PSENet_{}_loss{:.6f}.pth'.format(config.output_dir, epoch, 214 | # train_loss) 215 | # save_checkpoint(net_save_path, models, optimizer, epoch, logger) 216 | if (0.3 < train_loss < 0.4 and epoch % 4 == 0) or train_loss < 0.3: 217 | recall, precision, f1 = eval(model, 
os.path.join(config.output_dir, 'output'), config.testroot, device) 218 | logger.info('test: recall: {:.6f}, precision: {:.6f}, f1: {:.6f}'.format(recall, precision, f1)) 219 | 220 | net_save_path = '{}/PSENet_{}_loss{:.6f}_r{:.6f}_p{:.6f}_f1{:.6f}.pth'.format(config.output_dir, epoch, 221 | train_loss, 222 | recall, 223 | precision, 224 | f1) 225 | save_checkpoint(net_save_path, model, optimizer, epoch, logger) 226 | if f1 > best_model['f1']: 227 | best_path = glob.glob(config.output_dir + '/Best_*.pth') 228 | for b_path in best_path: 229 | if os.path.exists(b_path): 230 | os.remove(b_path) 231 | 232 | best_model['recall'] = recall 233 | best_model['precision'] = precision 234 | best_model['f1'] = f1 235 | best_model['models'] = net_save_path 236 | 237 | best_save_path = '{}/Best_{}_r{:.6f}_p{:.6f}_f1{:.6f}.pth'.format(config.output_dir, epoch, 238 | recall, 239 | precision, 240 | f1) 241 | if os.path.exists(net_save_path): 242 | shutil.copyfile(net_save_path, best_save_path) 243 | else: 244 | save_checkpoint(best_save_path, model, optimizer, epoch, logger) 245 | 246 | pse_path = glob.glob(config.output_dir + '/PSENet_*.pth') 247 | for p_path in pse_path: 248 | if os.path.exists(p_path): 249 | os.remove(p_path) 250 | 251 | writer.add_scalar(tag='Test/recall', scalar_value=recall, global_step=epoch) 252 | writer.add_scalar(tag='Test/precision', scalar_value=precision, global_step=epoch) 253 | writer.add_scalar(tag='Test/f1', scalar_value=f1, global_step=epoch) 254 | writer.close() 255 | except KeyboardInterrupt: 256 | save_checkpoint('{}/final.pth'.format(config.output_dir), model, optimizer, epoch, logger) 257 | finally: 258 | if best_model['models']: 259 | logger.info(best_model) 260 | 261 | 262 | if __name__ == '__main__': 263 | main() 264 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 
3/22/19 11:45 AM 3 | # @Author : zhoujun 4 | from .utils import * -------------------------------------------------------------------------------- /utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 1/19/19 3:37 PM 3 | # @Author : zhoujun 4 | from torch.optim.lr_scheduler import MultiStepLR 5 | 6 | 7 | class WarmupMultiStepLR(MultiStepLR): 8 | def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3, 9 | warmup_iters=500, last_epoch=-1): 10 | self.warmup_factor = warmup_factor 11 | self.warmup_iters = warmup_iters 12 | super().__init__(optimizer, milestones, gamma, last_epoch) 13 | 14 | def get_lr(self): 15 | lr = super().get_lr() 16 | if self.last_epoch < self.warmup_iters: 17 | alpha = self.last_epoch / self.warmup_iters 18 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 19 | return [l * warmup_factor for l in lr] 20 | return lr -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 1/4/19 11:18 AM 3 | # @Author : zhoujun 4 | import cv2 5 | import time 6 | import torch 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | def show_img(imgs: np.ndarray, color=False): 12 | if (len(imgs.shape) == 3 and color) or (len(imgs.shape) == 2 and not color): 13 | imgs = np.expand_dims(imgs, axis=0) 14 | for img in imgs: 15 | plt.figure() 16 | plt.imshow(img, cmap=None if color else 'gray') 17 | 18 | 19 | def draw_bbox(img_path, result, color=(255, 0, 0),thickness=2): 20 | if isinstance(img_path, str): 21 | img_path = cv2.imread(img_path) 22 | # img_path = cv2.cvtColor(img_path, cv2.COLOR_BGR2RGB) 23 | img_path = img_path.copy() 24 | for point in result: 25 | point = point.astype(int) 26 | cv2.line(img_path, tuple(point[0]), tuple(point[1]), color, thickness) 27 | 
cv2.line(img_path, tuple(point[1]), tuple(point[2]), color, thickness) 28 | cv2.line(img_path, tuple(point[2]), tuple(point[3]), color, thickness) 29 | cv2.line(img_path, tuple(point[3]), tuple(point[0]), color, thickness) 30 | return img_path 31 | 32 | 33 | def setup_logger(log_file_path: str = None): 34 | import logging 35 | from colorlog import ColoredFormatter 36 | logging.basicConfig(filename=log_file_path, format='%(asctime)s %(levelname)-8s %(filename)s: %(message)s', 37 | # 定义输出log的格式 38 | datefmt='%Y-%m-%d %H:%M:%S', ) 39 | """Return a logger with a default ColoredFormatter.""" 40 | formatter = ColoredFormatter("%(asctime)s %(log_color)s%(levelname)-8s %(reset)s %(filename)s: %(message)s", 41 | datefmt='%Y-%m-%d %H:%M:%S', 42 | reset=True, 43 | log_colors={ 44 | 'DEBUG': 'blue', 45 | 'INFO': 'green', 46 | 'WARNING': 'yellow', 47 | 'ERROR': 'red', 48 | 'CRITICAL': 'red', 49 | }) 50 | 51 | logger = logging.getLogger('project') 52 | handler = logging.StreamHandler() 53 | handler.setFormatter(formatter) 54 | logger.addHandler(handler) 55 | logger.setLevel(logging.DEBUG) 56 | logger.info('logger init finished') 57 | return logger 58 | 59 | 60 | def save_checkpoint(checkpoint_path, model, optimizer, epoch, logger): 61 | state = {'state_dict': model.state_dict(), 62 | 'optimizer': optimizer.state_dict(), 63 | 'epoch': epoch} 64 | torch.save(state, checkpoint_path) 65 | logger.info('models saved to %s' % checkpoint_path) 66 | 67 | 68 | def load_checkpoint(checkpoint_path, model, logger, device, optimizer=None): 69 | state = torch.load(checkpoint_path, map_location=device) 70 | model.load_state_dict(state['state_dict']) 71 | if optimizer is not None: 72 | optimizer.load_state_dict(state['optimizer']) 73 | start_epoch = state['epoch'] 74 | logger.info('models loaded from %s' % checkpoint_path) 75 | return start_epoch 76 | 77 | 78 | # --exeTime 79 | def exe_time(func): 80 | def newFunc(*args, **args2): 81 | t0 = time.time() 82 | back = func(*args, **args2) 83 | 
print("{} cost {:.3f}s".format(func.__name__, time.time() - t0)) 84 | return back 85 | return newFunc 86 | --------------------------------------------------------------------------------