├── .gitignore
├── LICENSE
├── README.md
├── cal_recall
│   ├── __init__.py
│   ├── rrc_evaluation_funcs.py
│   └── script.py
├── config.py
├── dataset
│   ├── __init__.py
│   ├── augment.py
│   ├── augment_img.py
│   └── data_utils.py
├── eval.py
├── imgs
│   ├── img_125.jpg
│   ├── img_31.jpg
│   ├── img_73.jpg
│   ├── img_83.jpg
│   └── img_98.jpg
├── models
│   ├── ShuffleNetV2.py
│   ├── __init__.py
│   ├── loss.py
│   ├── mobilenetv3.py
│   ├── model.py
│   └── resnet.py
├── predict.py
├── pse
│   ├── Makefile
│   ├── __init__.py
│   ├── include
│   │   └── pybind11
│   │       ├── attr.h
│   │       ├── buffer_info.h
│   │       ├── cast.h
│   │       ├── chrono.h
│   │       ├── class_support.h
│   │       ├── common.h
│   │       ├── complex.h
│   │       ├── descr.h
│   │       ├── detail
│   │       │   ├── class.h
│   │       │   ├── common.h
│   │       │   ├── descr.h
│   │       │   ├── init.h
│   │       │   ├── internals.h
│   │       │   └── typeid.h
│   │       ├── eigen.h
│   │       ├── embed.h
│   │       ├── eval.h
│   │       ├── functional.h
│   │       ├── iostream.h
│   │       ├── numpy.h
│   │       ├── operators.h
│   │       ├── options.h
│   │       ├── pybind11.h
│   │       ├── pytypes.h
│   │       ├── stl.h
│   │       ├── stl_bind.h
│   │       └── typeid.h
│   ├── ncnn
│   │   └── examples
│   │       ├── CMakeLists.txt
│   │       ├── psenet.cpp
│   │       └── run.sh
│   ├── pse.cpp
│   └── pse.so
├── train.py
└── utils
    ├── __init__.py
    ├── lr_scheduler.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | __pycache__/
3 | **/.pytest_cache/
4 | output/
5 | result*/
6 | log/
7 | *.npy
8 | *.pyc
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Shape Robust Text Detection with Progressive Scale Expansion Network
2 |
3 | ## Requirements
4 | * pytorch 1.1
5 | * torchvision 0.3
6 | * pyclipper
7 | * opencv3
8 | * gcc 4.9+
9 |
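
One possible environment setup (a sketch; the exact pins are assumptions based on the list above, and `Polygon3` is the PyPI package that provides the `Polygon` module imported by `cal_recall/script.py`):

```sh
pip install torch==1.1.0 torchvision==0.3.0 pyclipper opencv-python Polygon3
```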
10 | ## Update
11 | ### 20190401
12 |
13 | 1. Add the author's loss; the results are compared in the tables under [ICDAR 2015](#icdar-2015)
14 |
15 |
16 | ### Download
17 | resnet50 and resnet152 models trained on ICDAR 2015:
18 |
19 | 1. ~~[baidu yun](https://pan.baidu.com/s/1rN0oGBRsdUYmcQUayMZUOA) extraction code: rxjf~~
20 |
21 | 2. ~~[google drive](https://drive.google.com/drive/folders/1r3Q1GJ5990WYrwXKT29aHvfQNW92QkXv?usp=sharing)~~
22 |
23 | ## Data Preparation
24 | Follow the ICDAR 2015 dataset format:
25 | ```
26 | img
27 | │ 1.jpg
28 | │ 2.jpg
29 | │ ...
30 | gt
31 | │ gt_1.txt
32 | │ gt_2.txt
33 | | ...
34 | ```
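
Each `gt_*.txt` follows the ICDAR 2015 convention: one text instance per line, with eight clockwise corner coordinates followed by the transcription; `###` marks a don't-care region that is masked out during training. For example:

```
377,117,463,117,465,130,378,130,Genaxis Theatre
374,155,409,155,409,170,374,170,###
```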
35 |
36 | ## Train
37 | 1. Set `trainroot` and `testroot` in [config.py](config.py)
38 | 2. Run the following script to start training:
39 | ```sh
40 | python3 train.py
41 | ```
42 |
43 | ## Test
44 | [eval.py](eval.py) is used to evaluate the model on the test dataset
45 |
46 | 1. Set `model_path`, `data_path`, `gt_path` and `save_path` in [eval.py](eval.py)
47 | 2. Run the following script to evaluate:
48 | ```sh
49 | python3 eval.py
50 | ```
51 |
52 | ## Predict
53 | [predict.py](predict.py) is used to run inference on a single image
54 |
55 | 1. Set `model_path`, `img_path`, `gt_path` and `save_path` in [predict.py](predict.py)
56 | 2. Run the following script to predict:
57 | ```sh
58 | python3 predict.py
59 | ```
60 |
61 |
62 |
63 |
64 | ### [ICDAR 2015](http://rrc.cvc.uab.es/?ch=4)
65 | Trained only on the ICDAR 2015 dataset, on a single NVIDIA 1080Ti.
66 |
67 | My implementation with my loss, using Adam and warm-up:
68 |
69 | | Method | Precision (%) | Recall (%) | F-measure (%) | FPS(1080Ti) |
70 | |--------------------------|---------------|------------|---------------|-----|
71 | | PSENet-1s with resnet50 batch 8 | 81.13 | 77.03 | 79.03 | 1.76 |
72 | | PSENet-2s with resnet50 batch 8 | 81.36 | 77.13 | 79.18 | 3.55 |
73 | | PSENet-4s with resnet50 batch 8 | 81.00 | 76.55 | 78.71 | 4.43 |
74 | | PSENet-1s with resnet152 batch 4 | 85.45 | 80.06 | 82.67 | 1.48 |
75 | | PSENet-2s with resnet152 batch 4 | 85.42 | 80.11 | 82.68 | 2.56 |
76 | | PSENet-4s with resnet152 batch 4 | 83.93 | 79.00 | 81.39 | 2.99 |
77 |
78 | My implementation with my loss, using Adam and MultiStepLR:
79 |
80 | | Method | Precision (%) | Recall (%) | F-measure (%) | FPS(1080Ti) |
81 | |--------------------------|---------------|------------|---------------|-----|
82 | | PSENet-1s with resnet50 batch 8 | 83.39 | 79.29 | 81.29 | 1.76 |
83 | | PSENet-2s with resnet50 batch 8 | 83.22 | 79.05 | 81.08 | 3.55 |
84 | | PSENet-4s with resnet50 batch 8 | 82.57 | 78.23 | 80.34 | 4.43 |
85 | | PSENet-1s with resnet152 batch 4 | 85.33 | 79.87 | 82.51 | 1.48 |
86 | | PSENet-2s with resnet152 batch 4 | 85.36 | 79.73 | 82.45 | 2.56 |
87 | | PSENet-4s with resnet152 batch 4 | 83.95 | 78.86 | 81.33 | 2.99 |
88 |
89 | My implementation with the author's loss, using Adam and warm-up:
90 |
91 | | Method | Precision (%) | Recall (%) | F-measure (%) | FPS(1080Ti) |
92 | |--------------------------|---------------|------------|---------------|-----|
93 | | PSENet-1s with resnet50 batch 8 | 83.33 | 77.75 | 80.44 | 1.76 |
94 | | PSENet-2s with resnet50 batch 8 | 83.01 | 77.66 | 80.24 | 3.55 |
95 | | PSENet-4s with resnet50 batch 8 | 82.38 | 76.98 | 79.59 | 4.43 |
96 | | PSENet-1s with resnet152 batch 4 | 85.16 | 79.87 | 82.43 | 1.48 |
97 | | PSENet-2s with resnet152 batch 4 | 85.03 | 79.63 | 82.24 | 2.56 |
98 | | PSENet-4s with resnet152 batch 4 | 84.53 | 79.20 | 81.77 | 2.99 |
99 |
100 | My implementation with the author's loss, using Adam and MultiStepLR:
101 |
102 | | Method | Precision (%) | Recall (%) | F-measure (%) | FPS(1080Ti) |
103 | |--------------------------|---------------|------------|---------------|-----|
104 | | PSENet-1s with resnet50 batch 8 | 83.93 | 79.48 | 81.65 | 1.76 |
105 | | PSENet-2s with resnet50 batch 8 | 84.17 | 79.63 | 81.84 | 3.55 |
106 | | PSENet-4s with resnet50 batch 8 | 83.50 | 78.71 | 81.04 | 4.43 |
107 | | PSENet-1s with resnet152 batch 4 | 85.16 | 79.58 | 82.28 | 1.48 |
108 | | PSENet-2s with resnet152 batch 4 | 85.13 | 79.15 | 82.03 | 2.56 |
109 | | PSENet-4s with resnet152 batch 4 | 84.40 | 78.71 | 81.46 | 2.99 |
110 |
111 | Official implementation, using SGD and StepLR:
112 |
113 | | Method | Precision (%) | Recall (%) | F-measure (%) | FPS(1080Ti) |
114 | |--------------------------|---------------|------------|---------------|-----|
115 | | PSENet-1s with resnet50 batch 8 | 84.15 | 80.26 | 82.16 | 1.76 |
116 | | PSENet-2s with resnet50 batch 8 | 83.61 | 79.82 | 81.67 | 3.72 |
117 | | PSENet-4s with resnet50 batch 8 | 81.90 | 78.23 | 80.03 | 4.51 |
118 | | PSENet-1s with resnet152 batch 4 | 82.87 | 78.76 | 80.77 | 1.53 |
119 | | PSENet-2s with resnet152 batch 4 | 82.33 | 78.33 | 80.28 | 2.61 |
120 | | PSENet-4s with resnet152 batch 4 | 81.19 | 77.13 | 79.11 | 3.00 |
121 |
122 | ### Examples
123 | 
124 |
125 | 
126 |
127 | 
128 |
129 | 
130 |
131 | 
132 |
133 | ### Reference
134 | 1. https://github.com/liuheng92/tensorflow_PSENet
135 | 2. https://github.com/whai362/PSENet
136 |
--------------------------------------------------------------------------------
/cal_recall/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 1/16/19 6:40 AM
3 | # @Author : zhoujun
4 | from .script import cal_recall_precison_f1
5 | __all__ = ['cal_recall_precison_f1']
--------------------------------------------------------------------------------
/cal_recall/script.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | from collections import namedtuple
4 | from . import rrc_evaluation_funcs
5 | import Polygon as plg
6 | import numpy as np
7 |
8 |
9 | def default_evaluation_params():
10 | """
11 | default_evaluation_params: Default parameters to use for the validation and evaluation.
12 | """
13 | return {
14 | 'IOU_CONSTRAINT': 0.5,
15 | 'AREA_PRECISION_CONSTRAINT': 0.5,
16 | 'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt',
17 | 'DET_SAMPLE_NAME_2_ID': 'res_img_([0-9]+).txt',
18 | 'LTRB': False, # LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4)
19 | 'CRLF': False, # Lines are delimited by Windows CRLF format
20 | 'CONFIDENCES': False, # Detections must include confidence value. AP will be calculated
21 | 'PER_SAMPLE_RESULTS': True # Generate per sample results and produce data for visualization
22 | }
23 |
24 |
25 | def validate_data(gtFilePath, submFilePath, evaluationParams):
26 | """
27 |     Method validate_data: validates that all files in the results folder are correct (have the correct name and contents).
28 | Validates also that there are no missing files in the folder.
29 | If some error detected, the method raises the error
30 | """
31 | gt = rrc_evaluation_funcs.load_folder_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
32 |
33 | subm = rrc_evaluation_funcs.load_folder_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
34 |
35 | # Validate format of GroundTruth
36 | for k in gt:
37 | rrc_evaluation_funcs.validate_lines_in_file(k, gt[k], evaluationParams['CRLF'], evaluationParams['LTRB'], True)
38 |
39 | # Validate format of results
40 | for k in subm:
41 |         if k not in gt:
42 | raise Exception("The sample %s not present in GT" % k)
43 |
44 | rrc_evaluation_funcs.validate_lines_in_file(k, subm[k], evaluationParams['CRLF'], evaluationParams['LTRB'],
45 | False, evaluationParams['CONFIDENCES'])
46 |
47 |
48 | def evaluate_method(gtFilePath, submFilePath, evaluationParams):
49 | """
50 | Method evaluate_method: evaluate method and returns the results
51 | Results. Dictionary with the following values:
52 | - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 }
53 |                         - samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } }
54 | """
55 |
56 | def polygon_from_points(points):
57 | """
58 | Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
59 | """
60 | resBoxes = np.empty([1, 8], dtype='int32')
61 | resBoxes[0, 0] = int(points[0])
62 | resBoxes[0, 4] = int(points[1])
63 | resBoxes[0, 1] = int(points[2])
64 | resBoxes[0, 5] = int(points[3])
65 | resBoxes[0, 2] = int(points[4])
66 | resBoxes[0, 6] = int(points[5])
67 | resBoxes[0, 3] = int(points[6])
68 | resBoxes[0, 7] = int(points[7])
69 | pointMat = resBoxes[0].reshape([2, 4]).T
70 | return plg.Polygon(pointMat)
71 |
72 | def rectangle_to_polygon(rect):
73 | resBoxes = np.empty([1, 8], dtype='int32')
74 | resBoxes[0, 0] = int(rect.xmin)
75 | resBoxes[0, 4] = int(rect.ymax)
76 | resBoxes[0, 1] = int(rect.xmin)
77 | resBoxes[0, 5] = int(rect.ymin)
78 | resBoxes[0, 2] = int(rect.xmax)
79 | resBoxes[0, 6] = int(rect.ymin)
80 | resBoxes[0, 3] = int(rect.xmax)
81 | resBoxes[0, 7] = int(rect.ymax)
82 |
83 | pointMat = resBoxes[0].reshape([2, 4]).T
84 |
85 | return plg.Polygon(pointMat)
86 |
87 | def rectangle_to_points(rect):
88 | points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin),
89 | int(rect.xmin), int(rect.ymin)]
90 | return points
91 |
92 | def get_union(pD, pG):
93 |         areaA = pD.area()
94 |         areaB = pG.area()
95 |         return areaA + areaB - get_intersection(pD, pG)
96 |
97 | def get_intersection_over_union(pD, pG):
98 | try:
99 |             return get_intersection(pD, pG) / get_union(pD, pG)
100 | except:
101 | return 0
102 |
103 | def get_intersection(pD, pG):
104 | pInt = pD & pG
105 | if len(pInt) == 0:
106 | return 0
107 | return pInt.area()
108 |
109 | def compute_ap(confList, matchList, numGtCare):
110 | correct = 0
111 | AP = 0
112 | if len(confList) > 0:
113 | confList = np.array(confList)
114 | matchList = np.array(matchList)
115 | sorted_ind = np.argsort(-confList)
116 | confList = confList[sorted_ind]
117 | matchList = matchList[sorted_ind]
118 | for n in range(len(confList)):
119 | match = matchList[n]
120 | if match:
121 | correct += 1
122 | AP += float(correct) / (n + 1)
123 |
124 | if numGtCare > 0:
125 | AP /= numGtCare
126 |
127 | return AP
128 |
129 | perSampleMetrics = {}
130 |
131 | matchedSum = 0
132 |
133 | Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')
134 |
135 | gt = rrc_evaluation_funcs.load_folder_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
136 | subm = rrc_evaluation_funcs.load_folder_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)
137 |
138 |     numGlobalCareGt = 0
139 |     numGlobalCareDet = 0
140 | 
141 |     arrGlobalConfidences = []
142 |     arrGlobalMatches = []
143 |
144 | for resFile in gt:
145 |
146 | gtFile = gt[resFile] # rrc_evaluation_funcs.decode_utf8(gt[resFile])
147 | recall = 0
148 | precision = 0
149 | hmean = 0
150 |
151 | detMatched = 0
152 |
153 | iouMat = np.empty([1, 1])
154 |
155 | gtPols = []
156 | detPols = []
157 |
158 | gtPolPoints = []
159 | detPolPoints = []
160 |
161 | # Array of Ground Truth Polygons' keys marked as don't Care
162 | gtDontCarePolsNum = []
163 | # Array of Detected Polygons' matched with a don't Care GT
164 | detDontCarePolsNum = []
165 |
166 | pairs = []
167 | detMatchedNums = []
168 |
169 |         arrSampleConfidences = []
170 |         arrSampleMatch = []
171 |         sampleAP = 0
172 |
173 | evaluationLog = ""
174 |
175 | pointsList, _, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile,
176 | evaluationParams[
177 | 'CRLF'],
178 | evaluationParams[
179 | 'LTRB'],
180 | True, False)
181 | for n in range(len(pointsList)):
182 | points = pointsList[n]
183 | transcription = transcriptionsList[n]
184 | dontCare = transcription == "###"
185 | if evaluationParams['LTRB']:
186 | gtRect = Rectangle(*points)
187 | gtPol = rectangle_to_polygon(gtRect)
188 | else:
189 | gtPol = polygon_from_points(points)
190 | gtPols.append(gtPol)
191 | gtPolPoints.append(points)
192 | if dontCare:
193 | gtDontCarePolsNum.append(len(gtPols) - 1)
194 |
195 | evaluationLog += "GT polygons: " + str(len(gtPols)) + (
196 | " (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n")
197 |
198 | if resFile in subm:
199 |
200 | detFile = subm[resFile] # rrc_evaluation_funcs.decode_utf8(subm[resFile])
201 |
202 | pointsList, confidencesList, _ = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile,
203 | evaluationParams[
204 | 'CRLF'],
205 | evaluationParams[
206 | 'LTRB'],
207 | False,
208 | evaluationParams[
209 | 'CONFIDENCES'])
210 | for n in range(len(pointsList)):
211 | points = pointsList[n]
212 |
213 | if evaluationParams['LTRB']:
214 | detRect = Rectangle(*points)
215 | detPol = rectangle_to_polygon(detRect)
216 | else:
217 | detPol = polygon_from_points(points)
218 | detPols.append(detPol)
219 | detPolPoints.append(points)
220 | if len(gtDontCarePolsNum) > 0:
221 | for dontCarePol in gtDontCarePolsNum:
222 | dontCarePol = gtPols[dontCarePol]
223 | intersected_area = get_intersection(dontCarePol, detPol)
224 | pdDimensions = detPol.area()
225 | precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
226 | if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT']):
227 | detDontCarePolsNum.append(len(detPols) - 1)
228 | break
229 |
230 | evaluationLog += "DET polygons: " + str(len(detPols)) + (
231 | " (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n")
232 |
233 | if len(gtPols) > 0 and len(detPols) > 0:
234 |             # Calculate IoU and precision matrices
235 | outputShape = [len(gtPols), len(detPols)]
236 | iouMat = np.empty(outputShape)
237 | gtRectMat = np.zeros(len(gtPols), np.int8)
238 | detRectMat = np.zeros(len(detPols), np.int8)
239 | for gtNum in range(len(gtPols)):
240 | for detNum in range(len(detPols)):
241 | pG = gtPols[gtNum]
242 | pD = detPols[detNum]
243 | iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)
244 |
245 | for gtNum in range(len(gtPols)):
246 | for detNum in range(len(detPols)):
247 | if gtRectMat[gtNum] == 0 and detRectMat[
248 | detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
249 | if iouMat[gtNum, detNum] > evaluationParams['IOU_CONSTRAINT']:
250 | gtRectMat[gtNum] = 1
251 | detRectMat[detNum] = 1
252 | detMatched += 1
253 | pairs.append({'gt': gtNum, 'det': detNum})
254 | detMatchedNums.append(detNum)
255 | evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + "\n"
256 |
257 | if evaluationParams['CONFIDENCES']:
258 | for detNum in range(len(detPols)):
259 | if detNum not in detDontCarePolsNum:
260 | # we exclude the don't care detections
261 | match = detNum in detMatchedNums
262 |
263 | arrSampleConfidences.append(confidencesList[detNum])
264 | arrSampleMatch.append(match)
265 |
266 |                         arrGlobalConfidences.append(confidencesList[detNum])
267 |                         arrGlobalMatches.append(match)
268 |
269 | numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
270 | numDetCare = (len(detPols) - len(detDontCarePolsNum))
271 | if numGtCare == 0:
272 | recall = float(1)
273 | precision = float(0) if numDetCare > 0 else float(1)
274 | sampleAP = precision
275 | else:
276 | recall = float(detMatched) / numGtCare
277 | precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
278 | if evaluationParams['CONFIDENCES'] and evaluationParams['PER_SAMPLE_RESULTS']:
279 | sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare)
280 |
281 | hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall)
282 |
283 | matchedSum += detMatched
284 | numGlobalCareGt += numGtCare
285 | numGlobalCareDet += numDetCare
286 |
287 | if evaluationParams['PER_SAMPLE_RESULTS']:
288 | perSampleMetrics[resFile] = {
289 | 'precision': precision,
290 | 'recall': recall,
291 | 'hmean': hmean,
292 | 'pairs': pairs,
293 | 'AP': sampleAP,
294 | 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
295 | 'gtPolPoints': gtPolPoints,
296 | 'detPolPoints': detPolPoints,
297 | 'gtDontCare': gtDontCarePolsNum,
298 | 'detDontCare': detDontCarePolsNum,
299 | 'evaluationParams': evaluationParams,
300 | 'evaluationLog': evaluationLog
301 | }
302 |
303 | # Compute MAP and MAR
304 | AP = 0
305 | if evaluationParams['CONFIDENCES']:
306 | AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt)
307 |
308 | methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum) / numGlobalCareGt
309 | methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum) / numGlobalCareDet
310 | methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / (
311 | methodRecall + methodPrecision)
312 |
313 | methodMetrics = {'precision': methodPrecision, 'recall': methodRecall, 'hmean': methodHmean, 'AP': AP}
314 |
315 | resDict = {'calculated': True, 'Message': '', 'method': methodMetrics, 'per_sample': perSampleMetrics}
316 |
317 |     return resDict
318 |
319 |
320 | def cal_recall_precison_f1(gt_path, result_path, show_result=False):
321 | p = {'g': gt_path, 's': result_path}
322 | result = rrc_evaluation_funcs.main_evaluation(p, default_evaluation_params, validate_data, evaluate_method,
323 | show_result)
324 | return result['method']
325 |
--------------------------------------------------------------------------------
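A minimal usage sketch of the evaluation entry point, mirroring the call in `eval.py` (the folder paths are placeholders; ground-truth files must be named `gt_img_<id>.txt` and detection files `res_img_<id>.txt`, per `default_evaluation_params`):

```python
from cal_recall import cal_recall_precison_f1

# placeholder folders containing gt_img_<id>.txt / res_img_<id>.txt files
metrics = cal_recall_precison_f1(gt_path='/path/to/test/gt', result_path='./result')
print(metrics)  # {'precision': ..., 'recall': ..., 'hmean': ..., 'AP': ...}
```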
/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/1/3 17:40
3 | # @Author : zhoujun
4 |
5 | # data config
6 | trainroot = '/data2/dataset/ICD15/train'
7 | testroot = '/data2/dataset/ICD15/test'
8 | output_dir = 'output/psenet_icd2015_resnet152_4gpu_author_crop_adam_MultiStepLR_authorloss'
9 | data_shape = 640
10 |
11 | # train config
12 | gpu_id = '2'
13 | workers = 12
14 | start_epoch = 0
15 | epochs = 600
16 |
17 | train_batch_size = 4
18 |
19 | lr = 1e-4
20 | end_lr = 1e-7
21 | lr_gamma = 0.1
22 | lr_decay_step = [200,400]
23 | weight_decay = 5e-4
24 | warm_up_epoch = 6
25 | warm_up_lr = lr * lr_gamma
26 |
27 | display_input_images = False
28 | display_output_images = False
29 | display_interval = 10
30 | show_images_interval = 50
31 |
32 | pretrained = True
33 | restart_training = True
34 | checkpoint = ''
35 |
36 | # net config
37 | backbone = 'resnet152'
38 | Lambda = 0.7
39 | n = 6
40 | m = 0.5
41 | OHEM_ratio = 3
42 | scale = 1
43 | # random seed
44 | seed = 2
45 |
46 |
47 | def print():
48 | from pprint import pformat
49 | tem_d = {}
50 | for k, v in globals().items():
51 | if not k.startswith('_') and not callable(v):
52 | tem_d[k] = v
53 | return pformat(tem_d)
54 |
--------------------------------------------------------------------------------
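`config.py` is a flat namespace of hyperparameters; its `print()` helper returns a pretty-printed snapshot of every non-private, non-callable module global. A sketch of how a training script might log it (the logging setup here is an assumption; the `'project'` logger name is taken from `models/ShuffleNetV2.py`):

```python
import logging

import config

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('project')  # same logger name as models/ShuffleNetV2.py
logger.info(config.print())            # logs the full hyperparameter snapshot
```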
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 1/17/19 2:09 AM
3 | # @Author : zhoujun
--------------------------------------------------------------------------------
/dataset/augment.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/1/12 13:06
3 |
4 | import cv2
5 | import numbers
6 | import math
7 | import random
8 | import numpy as np
9 | from skimage.util import random_noise
10 |
11 |
12 | def show_pic(img, bboxes=None, name='pic'):
13 |     '''
14 |     Args:
15 |         img: image array
16 |         bboxes: list of all bounding boxes in the image, formatted as [[x_min, y_min, x_max, y_max], ...]
17 |         name: name of the display window
18 |     '''
19 | show_img = img.copy()
20 | if not isinstance(bboxes, np.ndarray):
21 | bboxes = np.array(bboxes)
22 | for point in bboxes.astype(np.int):
23 | cv2.line(show_img, tuple(point[0]), tuple(point[1]), (255, 0, 0), 2)
24 | cv2.line(show_img, tuple(point[1]), tuple(point[2]), (255, 0, 0), 2)
25 | cv2.line(show_img, tuple(point[2]), tuple(point[3]), (255, 0, 0), 2)
26 | cv2.line(show_img, tuple(point[3]), tuple(point[0]), (255, 0, 0), 2)
27 |     # cv2.namedWindow(name, 0)  # 1 would show the image at its original size
28 |     # cv2.moveWindow(name, 0, 0)
29 |     # cv2.resizeWindow(name, 1200, 800)  # size of the displayed image
30 | cv2.imshow(name, show_img)
31 |
32 |
33 | # all images are read with cv2 (BGR ndarrays)
34 | class DataAugment():
35 | def __init__(self):
36 | pass
37 |
38 | def add_noise(self, im: np.ndarray):
39 |         """
40 |         Add noise to an image
41 |         :param im: image array
42 |         :return: noisy image array; random_noise outputs pixels in [0, 1], so multiply by 255
43 |         """
44 | return (random_noise(im, mode='gaussian', clip=True) * 255).astype(im.dtype)
45 |
46 | def random_scale(self, im: np.ndarray, text_polys: np.ndarray, scales: np.ndarray or list) -> tuple:
47 |         """
48 |         Randomly pick a scale from scales and resize the image and text boxes
49 |         :param im: original image
50 |         :param text_polys: text boxes
51 |         :param scales: candidate scales
52 |         :return: scaled image and text boxes
53 |         """
54 | tmp_text_polys = text_polys.copy()
55 | rd_scale = float(np.random.choice(scales))
56 | im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
57 | tmp_text_polys *= rd_scale
58 | return im, tmp_text_polys
59 |
60 | def random_rotate_img_bbox(self, img, text_polys, degrees: numbers.Number or list or tuple or np.ndarray,
61 | same_size=False):
62 |         """
63 |         Pick a random angle from the given range and rotate the image and text boxes
64 |         :param img: image
65 |         :param text_polys: text boxes
66 |         :param degrees: angle; a single number or a sequence
67 |         :param same_size: whether to keep the output the same size as the input
68 |         :return: rotated image and text boxes
69 |         """
70 | if isinstance(degrees, numbers.Number):
71 | if degrees < 0:
72 | raise ValueError("If degrees is a single number, it must be positive.")
73 | degrees = (-degrees, degrees)
74 | elif isinstance(degrees, list) or isinstance(degrees, tuple) or isinstance(degrees, np.ndarray):
75 | if len(degrees) != 2:
76 | raise ValueError("If degrees is a sequence, it must be of len 2.")
77 | degrees = degrees
78 | else:
79 |             raise Exception('degrees must be a Number, list, tuple or np.ndarray')
80 |         # ---------------------- rotate the image ----------------------
81 | w = img.shape[1]
82 | h = img.shape[0]
83 | angle = np.random.uniform(degrees[0], degrees[1])
84 |
85 | if same_size:
86 | nw = w
87 | nh = h
88 | else:
89 |             # degrees to radians
90 |             rangle = np.deg2rad(angle)
91 |             # compute the w, h of the rotated image
92 |             nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
93 |             nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
94 |         # build the affine matrix
95 |         rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, 1)
96 |         # offset from the original image center to the new image center
97 |         rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
98 |         # update the affine matrix
99 |         rot_mat[0, 2] += rot_move[0]
100 |         rot_mat[1, 2] += rot_move[1]
101 |         # apply the affine transform
102 |         rot_img = cv2.warpAffine(img, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4)
103 |
104 |         # ---------------------- correct the bbox coordinates ----------------------
105 |         # rot_mat is the final rotation matrix
106 |         # map the four corner points of each original bbox into the rotated coordinate frame
107 | rot_text_polys = list()
108 | for bbox in text_polys:
109 | point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
110 | point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
111 | point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
112 | point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
113 | rot_text_polys.append([point1, point2, point3, point4])
114 | return rot_img, np.array(rot_text_polys, dtype=np.float32)
115 |
116 | def random_crop_img_bboxes(self, im: np.ndarray, text_polys: np.ndarray, max_tries=50) -> tuple:
117 |         """
118 |         Crop a region of the image together with the text boxes that fall inside it
119 |         :param im: image
120 |         :param text_polys: text boxes
121 |         :param max_tries: maximum number of attempts
122 |         :return: cropped image and text boxes
123 |         """
124 | h, w, _ = im.shape
125 | pad_h = h // 10
126 | pad_w = w // 10
127 | h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
128 | w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
129 | for poly in text_polys:
130 |             poly = np.round(poly, decimals=0).astype(np.int32)  # round to the nearest integer
131 |             minx = np.min(poly[:, 0])
132 |             maxx = np.max(poly[:, 0])
133 |             w_array[minx + pad_w:maxx + pad_w] = 1  # mark the text span in w_array: these x positions contain text
134 |             miny = np.min(poly[:, 1])
135 |             maxy = np.max(poly[:, 1])
136 |             h_array[miny + pad_h:maxy + pad_h] = 1  # mark the text span in h_array: these y positions contain text
137 |         # sample crop positions from the background locations on both axes so the crop boundary never cuts through text
138 | h_axis = np.where(h_array == 0)[0]
139 | w_axis = np.where(w_array == 0)[0]
140 | if len(h_axis) == 0 or len(w_axis) == 0:
141 |             # the whole image is covered by text; return it unchanged
142 | return im, text_polys
143 | for i in range(max_tries):
144 | xx = np.random.choice(w_axis, size=2)
145 |             # clamp the selected region to the image bounds
146 | xmin = np.min(xx) - pad_w
147 | xmax = np.max(xx) - pad_w
148 | xmin = np.clip(xmin, 0, w - 1)
149 | xmax = np.clip(xmax, 0, w - 1)
150 | yy = np.random.choice(h_axis, size=2)
151 | ymin = np.min(yy) - pad_h
152 | ymax = np.max(yy) - pad_h
153 | ymin = np.clip(ymin, 0, h - 1)
154 | ymax = np.clip(ymax, 0, h - 1)
155 | if xmax - xmin < 0.1 * w or ymax - ymin < 0.1 * h:
156 |                 # the selected region is too small
157 | # area too small
158 | continue
159 |             if text_polys.shape[0] != 0:  # skip polygon selection when there are no text boxes
160 | poly_axis_in_area = (text_polys[:, :, 0] >= xmin) & (text_polys[:, :, 0] <= xmax) \
161 | & (text_polys[:, :, 1] >= ymin) & (text_polys[:, :, 1] <= ymax)
162 | selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
163 | else:
164 | selected_polys = []
165 | if len(selected_polys) == 0:
166 |                 # no text inside the selected region
167 | continue
168 | im = im[ymin:ymax + 1, xmin:xmax + 1, :]
169 | polys = text_polys[selected_polys]
170 |             # shift coordinates into the cropped image
171 | polys[:, :, 0] -= xmin
172 | polys[:, :, 1] -= ymin
173 | return im, polys
174 | return im, text_polys
175 |
176 | def random_crop_image_pse(self, im: np.ndarray, text_polys: np.ndarray, input_size) -> tuple:
177 |         """
178 |         Crop an input_size-sized region of the image together with the text boxes inside it
179 |         :param im: image
180 |         :param text_polys: text boxes
181 |         :param input_size: output image size
182 |         :return: cropped image and text boxes
183 |         """
184 | h, w, _ = im.shape
185 | short_edge = min(h, w)
186 | if short_edge < input_size:
187 |             # make sure the short edge >= input_size
188 | scale = input_size / short_edge
189 | im = cv2.resize(im, dsize=None, fx=scale, fy=scale)
190 | text_polys *= scale
191 | h, w, _ = im.shape
192 |         # compute the valid range for the random crop origin
193 | w_range = w - input_size
194 | h_range = h - input_size
195 | for _ in range(50):
196 | xmin = random.randint(0, w_range)
197 | ymin = random.randint(0, h_range)
198 | xmax = xmin + input_size
199 | ymax = ymin + input_size
200 | if text_polys.shape[0] != 0:
201 | selected_polys = []
202 | for poly in text_polys:
203 | if poly[:, 0].max() < xmin or poly[:, 0].min() > xmax or \
204 | poly[:, 1].max() < ymin or poly[:, 1].min() > ymax:
205 | continue
206 | # area_p = cv2.contourArea(poly)
207 | poly[:, 0] -= xmin
208 | poly[:, 1] -= ymin
209 | poly[:, 0] = np.clip(poly[:, 0], 0, input_size)
210 | poly[:, 1] = np.clip(poly[:, 1], 0, input_size)
211 | # rect = cv2.minAreaRect(poly)
212 | # area_n = cv2.contourArea(poly)
213 | # h1, w1 = rect[1]
214 | # if w1 < 10 or h1 < 10 or area_n / area_p < 0.5:
215 | # continue
216 | selected_polys.append(poly)
217 | else:
218 | selected_polys = []
219 | # if len(selected_polys) == 0:
220 |                 # no text inside the region
221 | # continue
222 | im = im[ymin:ymax, xmin:xmax, :]
223 | polys = np.array(selected_polys)
224 | return im, polys
225 | return im, text_polys
226 |
227 | def random_crop_author(self,imgs, img_size):
228 | h, w = imgs[0].shape[0:2]
229 | th, tw = img_size
230 | if w == tw and h == th:
231 | return imgs
232 |
233 |         # the label contains text instances; crop around text with probability 5/8
234 |         if np.max(imgs[1][:,:,-1]) > 0 and random.random() > 3.0 / 8.0:
235 |             # top-left corner of the region containing text instances
236 |             tl = np.min(np.where(imgs[1][:,:,-1] > 0), axis=1) - img_size
237 |             tl[tl < 0] = 0
238 |             # bottom-right corner of the region containing text instances
239 |             br = np.max(np.where(imgs[1][:,:,-1] > 0), axis=1) - img_size
240 |             br[br < 0] = 0
241 |             # make sure there is enough room left of the bottom-right point to crop
242 |             br[0] = min(br[0], h - th)
243 |             br[1] = min(br[1], w - tw)
244 |             for _ in range(50000):
245 |                 i = random.randint(tl[0], br[0])
246 |                 j = random.randint(tl[1], br[1])
247 |                 # make sure the crop contains text
248 |                 if imgs[1][:,:,0][i:i + th, j:j + tw].sum() <= 0:
249 | continue
250 | else:
251 | break
252 | else:
253 | i = random.randint(0, h - th)
254 | j = random.randint(0, w - tw)
255 |
256 | # return i, j, th, tw
257 | for idx in range(len(imgs)):
258 | if len(imgs[idx].shape) == 3:
259 | imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
260 | else:
261 | imgs[idx] = imgs[idx][i:i + th, j:j + tw]
262 | return imgs
263 |
264 | def resize(self, im: np.ndarray, text_polys: np.ndarray,
265 | input_size: numbers.Number or list or tuple or np.ndarray, keep_ratio: bool = False) -> tuple:
266 |         """
267 |         Resize the image and text boxes
268 |         :param im: image
269 |         :param text_polys: text boxes
270 |         :param input_size: target size; a number, or a sequence in [w, h] form
271 |         :param keep_ratio: whether to keep the aspect ratio
272 |         :return: resized image and text boxes
273 |         """
274 | if isinstance(input_size, numbers.Number):
275 | if input_size < 0:
276 | raise ValueError("If input_size is a single number, it must be positive.")
277 | input_size = (input_size, input_size)
278 | elif isinstance(input_size, list) or isinstance(input_size, tuple) or isinstance(input_size, np.ndarray):
279 | if len(input_size) != 2:
280 | raise ValueError("If input_size is a sequence, it must be of len 2.")
281 | input_size = (input_size[0], input_size[1])
282 | else:
283 |             raise Exception('input_size must be a Number, list, tuple or np.ndarray')
284 | if keep_ratio:
285 |             # pad the short side of the image to match the long side
286 | h, w, c = im.shape
287 | max_h = max(h, input_size[0])
288 | max_w = max(w, input_size[1])
289 | im_padded = np.zeros((max_h, max_w, c), dtype=np.uint8)
290 | im_padded[:h, :w] = im.copy()
291 | im = im_padded
292 | text_polys = text_polys.astype(np.float32)
293 | h, w, _ = im.shape
294 | im = cv2.resize(im, input_size)
295 | w_scale = input_size[0] / float(w)
296 | h_scale = input_size[1] / float(h)
297 | text_polys[:, :, 0] *= w_scale
298 | text_polys[:, :, 1] *= h_scale
299 | return im, text_polys
300 |
301 | def horizontal_flip(self, im: np.ndarray, text_polys: np.ndarray) -> tuple:
302 |         """
303 |         Horizontally flip the image and text boxes
304 |         :param im: image
305 |         :param text_polys: text boxes
306 |         :return: horizontally flipped image and text boxes
307 |         """
308 | flip_text_polys = text_polys.copy()
309 | flip_im = cv2.flip(im, 1)
310 | h, w, _ = flip_im.shape
311 | flip_text_polys[:, :, 0] = w - flip_text_polys[:, :, 0]
312 | return flip_im, flip_text_polys
313 |
314 | def vertical_flip(self, im: np.ndarray, text_polys: np.ndarray) -> tuple:
315 |         """
316 |         Vertically flip the image and text boxes
317 |         :param im: image
318 |         :param text_polys: text boxes
319 |         :return: vertically flipped image and text boxes
320 |         """
321 | flip_text_polys = text_polys.copy()
322 | flip_im = cv2.flip(im, 0)
323 | h, w, _ = flip_im.shape
324 | flip_text_polys[:, :, 1] = h - flip_text_polys[:, :, 1]
325 | return flip_im, flip_text_polys
326 |
327 | def test(self, im: np.ndarray, text_polys: np.ndarray):
328 |         print('random scale')
329 | t_im, t_text_polys = self.random_scale(im, text_polys, [0.5, 1, 2, 3])
330 | print(t_im.shape, t_text_polys.dtype)
331 | show_pic(t_im, t_text_polys, 'random_scale')
332 |
333 |         print('random rotate')
334 | t_im, t_text_polys = self.random_rotate_img_bbox(im, text_polys, 10)
335 | print(t_im.shape, t_text_polys.dtype)
336 | show_pic(t_im, t_text_polys, 'random_rotate_img_bbox')
337 |
338 |         print('random crop')
339 | t_im, t_text_polys = self.random_crop_img_bboxes(im, text_polys)
340 | print(t_im.shape, t_text_polys.dtype)
341 | show_pic(t_im, t_text_polys, 'random_crop_img_bboxes')
342 |
343 |         print('horizontal flip')
344 | t_im, t_text_polys = self.horizontal_flip(im, text_polys)
345 | print(t_im.shape, t_text_polys.dtype)
346 | show_pic(t_im, t_text_polys, 'horizontal_flip')
347 |
348 |         print('vertical flip')
349 | t_im, t_text_polys = self.vertical_flip(im, text_polys)
350 | print(t_im.shape, t_text_polys.dtype)
351 | show_pic(t_im, t_text_polys, 'vertical_flip')
352 | show_pic(im, text_polys, 'vertical_flip_ori')
353 |
354 |         print('add noise')
355 | t_im = self.add_noise(im)
356 | print(t_im.shape)
357 | show_pic(t_im, text_polys, 'add_noise')
358 | show_pic(im, text_polys, 'add_noise_ori')
359 |
--------------------------------------------------------------------------------
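A minimal sketch of driving `DataAugment` on a dummy image (shapes and values below are illustrative only, not from the repo):

```python
import numpy as np

from dataset.augment import DataAugment

aug = DataAugment()
im = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy BGR image
polys = np.array([[[10, 10], [100, 10], [100, 40], [10, 40]]], dtype=np.float32)

im, polys = aug.random_scale(im, polys, [0.5, 1, 2.0, 3.0])  # random rescale
im, polys = aug.horizontal_flip(im, polys)                   # mirror image and boxes
im, polys = aug.random_rotate_img_bbox(im, polys, 10)        # rotate within [-10, 10] degrees
noisy = aug.add_noise(im)                                    # add gaussian noise
```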
/dataset/augment_img.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/1/12 13:06
3 |
4 | import cv2
5 | import numbers
6 | import math
7 | import random
8 | import numpy as np
9 | from skimage.util import random_noise
10 |
11 |
12 | def show_pic(img, bboxes=None, name='pic'):
13 |     '''
14 |     Args:
15 |         img: image array
16 |         bboxes: list of all bounding boxes in the image, formatted as [[x_min, y_min, x_max, y_max], ...]
17 |         name: name of the display window
18 |     '''
19 | show_img = img.copy()
20 | if not isinstance(bboxes, np.ndarray):
21 | bboxes = np.array(bboxes)
22 | for point in bboxes.astype(np.int):
23 | cv2.line(show_img, tuple(point[0]), tuple(point[1]), (255, 0, 0), 2)
24 | cv2.line(show_img, tuple(point[1]), tuple(point[2]), (255, 0, 0), 2)
25 | cv2.line(show_img, tuple(point[2]), tuple(point[3]), (255, 0, 0), 2)
26 | cv2.line(show_img, tuple(point[3]), tuple(point[0]), (255, 0, 0), 2)
27 |     # cv2.namedWindow(name, 0)  # 1 would show the image at its original size
28 |     # cv2.moveWindow(name, 0, 0)
29 |     # cv2.resizeWindow(name, 1200, 800)  # size of the displayed image
30 | cv2.imshow(name, show_img)
31 |
32 |
33 | # all images are read with cv2 (BGR ndarrays)
34 | class DataAugment():
35 | def __init__(self):
36 | pass
37 |
38 | def add_noise(self, im: np.ndarray):
39 |         """
40 |         Add noise to an image
41 |         :param im: image array
42 |         :return: noisy image array; random_noise outputs pixels in [0, 1], so multiply by 255
43 |         """
44 | return (random_noise(im, mode='gaussian', clip=True) * 255).astype(im.dtype)
45 |
46 | def random_scale(self, imgs: list, scales: np.ndarray or list, input_size: int) -> list:
47 |         """
48 |         Randomly pick a scale from scales and resize the images
49 |         :param imgs: the image and its label maps
50 |         :param scales: candidate scales
51 |         :param input_size: minimum length of the image's short side
52 |         :return: the scaled images
53 |         """
54 | rd_scale = float(np.random.choice(scales))
55 | for idx in range(len(imgs)):
56 | imgs[idx] = cv2.resize(imgs[idx], dsize=None, fx=rd_scale, fy=rd_scale)
57 | imgs[idx], _ = self.rescale(imgs[idx], min_side=input_size)
58 | return imgs
59 |
60 | def rescale(self, img, min_side):
61 | h, w = img.shape[:2]
62 | scale = 1.0
63 | if min(h, w) < min_side:
64 | if h <= w:
65 | scale = 1.0 * min_side / h
66 | else:
67 | scale = 1.0 * min_side / w
68 | img = cv2.resize(img, dsize=None, fx=scale, fy=scale)
69 |         return img, scale  # random_scale unpacks (img, scale)
70 |
71 | def random_rotate_img_bbox(self, imgs, degrees: numbers.Number or list or tuple or np.ndarray,
72 | same_size=False):
73 |         """
74 |         Pick a random angle from the given range and rotate the images
75 |         :param imgs: the image and its label maps
76 |         :param degrees: angle; a single number or a sequence
77 |         :param same_size: whether to keep the output the same size as the input
78 |         :return: the rotated images
79 |         """
80 | if isinstance(degrees, numbers.Number):
81 | if degrees < 0:
82 | raise ValueError("If degrees is a single number, it must be positive.")
83 | degrees = (-degrees, degrees)
84 | elif isinstance(degrees, list) or isinstance(degrees, tuple) or isinstance(degrees, np.ndarray):
85 | if len(degrees) != 2:
86 | raise ValueError("If degrees is a sequence, it must be of len 2.")
87 | degrees = degrees
88 | else:
89 |             raise Exception('degrees must be a Number, list, tuple or np.ndarray')
90 |         # ---------------------- rotate the image ----------------------
91 | w = imgs[0].shape[1]
92 | h = imgs[0].shape[0]
93 | angle = np.random.uniform(degrees[0], degrees[1])
94 |
95 | if same_size:
96 | nw = w
97 | nh = h
98 | else:
99 |             # degrees to radians
100 |             rangle = np.deg2rad(angle)
101 |             # compute the w, h of the rotated image
102 |             nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
103 |             nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
104 |         # build the affine matrix
105 |         rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, 1)
106 |         # offset from the original image center to the new image center
107 |         rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
108 |         # update the affine matrix
109 |         rot_mat[0, 2] += rot_move[0]
110 |         rot_mat[1, 2] += rot_move[1]
111 |         for idx in range(len(imgs)):
112 |             # apply the affine transform
113 |             imgs[idx] = cv2.warpAffine(imgs[idx], rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4)
114 | return imgs
115 |
116 | def random_crop(self, imgs, img_size):
117 | h, w = imgs[0].shape[0:2]
118 | th, tw = img_size
119 | if w == tw and h == th:
120 | return imgs
121 |
122 |         # the label contains text instances; crop around text with probability 5/8
123 |         if np.max(imgs[1][:, :, -1]) > 0 and random.random() > 3.0 / 8.0:
124 |             # top-left corner of the region containing text instances
125 |             tl = np.min(np.where(imgs[1][:, :, -1] > 0), axis=1) - img_size
126 |             tl[tl < 0] = 0
127 |             # bottom-right corner of the region containing text instances
128 |             br = np.max(np.where(imgs[1][:, :, -1] > 0), axis=1) - img_size
129 |             br[br < 0] = 0
130 |             # make sure there is enough room left of the bottom-right point to crop
131 |             br[0] = min(br[0], h - th)
132 |             br[1] = min(br[1], w - tw)
133 |             for _ in range(50000):
134 |                 i = random.randint(tl[0], br[0])
135 |                 j = random.randint(tl[1], br[1])
136 |                 # make sure the crop contains text
137 | if imgs[1][:, :, 0][i:i + th, j:j + tw].sum() <= 0:
138 | continue
139 | else:
140 | break
141 | else:
142 | i = random.randint(0, h - th)
143 | j = random.randint(0, w - tw)
144 |
145 | # return i, j, th, tw
146 | for idx in range(len(imgs)):
147 | if len(imgs[idx].shape) == 3:
148 | imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
149 | else:
150 | imgs[idx] = imgs[idx][i:i + th, j:j + tw]
151 | return imgs
152 |
153 |
154 | def horizontal_flip(self, imgs: list) -> list:
155 |         """
156 |         Horizontally flip the image and its label maps
157 |         :param imgs: list containing the image and its label maps
158 |         :return: horizontally flipped images
159 |         :rtype: list
160 |         """
161 | for idx in range(len(imgs)):
162 | imgs[idx] = cv2.flip(imgs[idx], 1)
163 | return imgs
164 |
165 | def vertical_flip(self, imgs: list) -> list:
166 |         """
167 |         Vertically flip the image and its label maps
168 |         :param imgs: list containing the image and its label maps
169 |         :return: vertically flipped images
170 |         :rtype: list
171 |         """
172 | for idx in range(len(imgs)):
173 | imgs[idx] = cv2.flip(imgs[idx], 0)
174 | return imgs
--------------------------------------------------------------------------------
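This file is the list-based variant of `dataset/augment.py`: the image and its label maps travel together, and every transform is applied to all of them with the same parameters. A minimal sketch (dummy shapes, illustrative only):

```python
import numpy as np

from dataset.augment_img import DataAugment

aug = DataAugment()
img = np.zeros((800, 800, 3), dtype=np.uint8)    # dummy image
label = np.zeros((800, 800, 6), dtype=np.uint8)  # one channel per kernel map

# both arrays are cropped with the same 640x640 window (pure numpy slicing)
img, label = aug.random_crop([img, label], (640, 640))
print(img.shape, label.shape)  # (640, 640, 3) (640, 640, 6)
```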
/dataset/data_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2018/6/11 15:54
3 | # @Author : zhoujun
4 |
5 | import os
6 | import random
7 | import pathlib
8 | import pyclipper
9 | from torch.utils import data
10 | import glob
11 | import numpy as np
12 | import cv2
13 | from dataset.augment import DataAugment
14 | from utils.utils import draw_bbox
15 |
16 | data_aug = DataAugment()
17 |
18 |
19 | def check_and_validate_polys(polys, im_size):
20 |     '''
21 |     check that the text polys are in the same direction,
22 |     and also filter out some invalid polygons
23 |     :param polys: text polygons
24 |     :param im_size: (h, w) of the image
25 |     :return: the validated polygons
26 |     '''
27 |     (h, w) = im_size
28 | if polys.shape[0] == 0:
29 | return polys
30 | polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1) # x coord not max w-1, and not min 0
31 | polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1) # y coord not max h-1, and not min 0
32 |
33 | validated_polys = []
34 | for poly in polys:
35 | p_area = cv2.contourArea(poly)
36 | if abs(p_area) < 1:
37 | continue
38 | validated_polys.append(poly)
39 | return np.array(validated_polys)
40 |
41 |
42 | def generate_rbox(im_size, text_polys, text_tags, training_mask, i, n, m):
43 |     """
44 |     Generate the mask map: white pixels are text, black pixels are background
45 |     :param im_size: h, w of the image
46 |     :param text_polys: box coordinates
47 |     :param text_tags: flags marking don't-care boxes that are masked out of training
48 |     :return: the generated mask maps
49 |     """
50 | h, w = im_size
51 | score_map = np.zeros((h, w), dtype=np.uint8)
52 | for poly, tag in zip(text_polys, text_tags):
53 | poly = poly.astype(np.int)
54 | r_i = 1 - (1 - m) * (n - i) / (n - 1)
55 | d_i = cv2.contourArea(poly) * (1 - r_i * r_i) / cv2.arcLength(poly, True)
56 | pco = pyclipper.PyclipperOffset()
57 | # pco.AddPath(pyclipper.scale_to_clipper(poly), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
58 | # shrinked_poly = np.floor(np.array(pyclipper.scale_from_clipper(pco.Execute(-d_i)))).astype(np.int)
59 | pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
60 | shrinked_poly = np.array(pco.Execute(-d_i))
61 | cv2.fillPoly(score_map, shrinked_poly, 1)
62 |         # build the training mask
63 | # rect = cv2.minAreaRect(shrinked_poly)
64 | # poly_h, poly_w = rect[1]
65 |
66 | # if min(poly_h, poly_w) < 10:
67 | # cv2.fillPoly(training_mask, shrinked_poly, 0)
68 | if tag:
69 | cv2.fillPoly(training_mask, shrinked_poly, 0)
70 |     # morphological closing to fill small holes inside boxes
71 | # kernel = np.ones((3, 3), np.uint8)
72 | # score_map = cv2.morphologyEx(score_map, cv2.MORPH_CLOSE, kernel)
73 | return score_map, training_mask
74 |
75 |
76 | def augmentation(im: np.ndarray, text_polys: np.ndarray, scales: np.ndarray, degrees: int, input_size: int) -> tuple:
77 | # the images are rescaled with ratio {0.5, 1.0, 2.0, 3.0} randomly
78 | im, text_polys = data_aug.random_scale(im, text_polys, scales)
79 |     # the images are horizontally flipped and rotated in range [−10◦, 10◦] randomly
80 | if random.random() < 0.5:
81 | im, text_polys = data_aug.horizontal_flip(im, text_polys)
82 | if random.random() < 0.5:
83 | im, text_polys = data_aug.random_rotate_img_bbox(im, text_polys, degrees)
84 | # 640 × 640 random samples are cropped from the transformed images
85 | # im, text_polys = data_aug.random_crop_img_bboxes(im, text_polys)
86 |
87 | # im, text_polys = data_aug.resize(im, text_polys, input_size, keep_ratio=False)
88 | # im, text_polys = data_aug.random_crop_image_pse(im, text_polys, input_size)
89 |
90 | return im, text_polys
91 |
92 |
93 | def image_label(im_fn: str, text_polys: np.ndarray, text_tags: list, n: int, m: float, input_size: int,
94 |                 degrees: int = 10,
95 | scales: np.ndarray = np.array([0.5, 1, 2.0, 3.0])) -> tuple:
96 |     '''
97 |     get the image's corresponding matrices and ground truth
98 |     return
99 |     image [input_size, input_size, 3]
100 |     score_maps [n, input_size, input_size]
101 |     training_mask [input_size, input_size]
102 |     (shapes after the final random crop)
103 |     '''
104 |
105 | im = cv2.imread(im_fn)
106 | im = cv2.cvtColor(im,cv2.COLOR_BGR2RGB)
107 | h, w, _ = im.shape
108 |     # clip out-of-bounds coordinates
109 | text_polys = check_and_validate_polys(text_polys, (h, w))
110 |     im, text_polys = augmentation(im, text_polys, scales, degrees, input_size)
111 |
112 | h, w, _ = im.shape
113 | short_edge = min(h, w)
114 | if short_edge < input_size:
115 |         # make sure the short edge >= input_size
116 | scale = input_size / short_edge
117 | im = cv2.resize(im, dsize=None, fx=scale, fy=scale)
118 | text_polys *= scale
119 |
120 |     # # normalize images
121 | # im = im.astype(np.float32)
122 | # im /= 255.0
123 | # im -= np.array((0.485, 0.456, 0.406))
124 | # im /= np.array((0.229, 0.224, 0.225))
125 |
126 | h, w, _ = im.shape
127 | training_mask = np.ones((h, w), dtype=np.uint8)
128 | score_maps = []
129 | for i in range(1, n + 1):
130 |         # s1 -> sn, kernels from smallest to largest
131 | score_map, training_mask = generate_rbox((h, w), text_polys, text_tags, training_mask, i, n, m)
132 | score_maps.append(score_map)
133 | score_maps = np.array(score_maps, dtype=np.float32)
134 | imgs = data_aug.random_crop_author([im, score_maps.transpose((1, 2, 0)),training_mask], (input_size, input_size))
135 |     return imgs[0], imgs[1].transpose((2, 0, 1)), imgs[2]  # im, score_maps, training_mask
136 |
137 |
138 | class MyDataset(data.Dataset):
139 | def __init__(self, data_dir, data_shape: int = 640, n=6, m=0.5, transform=None, target_transform=None):
140 | self.data_list = self.load_data(data_dir)
141 | self.data_shape = data_shape
142 | self.transform = transform
143 | self.target_transform = target_transform
144 | self.n = n
145 | self.m = m
146 |
147 | def __getitem__(self, index):
148 | # print(self.image_list[index])
149 | img_path, text_polys, text_tags = self.data_list[index]
150 | img, score_maps, training_mask = image_label(img_path, text_polys, text_tags, input_size=self.data_shape,
151 | n=self.n,
152 | m=self.m)
153 | # img = draw_bbox(img,text_polys)
154 | if self.transform:
155 | img = self.transform(img)
156 | if self.target_transform:
157 | score_maps = self.target_transform(score_maps)
158 | training_mask = self.target_transform(training_mask)
159 | return img, score_maps, training_mask
160 |
161 | def load_data(self, data_dir: str) -> list:
162 | data_list = []
163 | for x in glob.glob(data_dir + '/img/*.jpg', recursive=True):
164 | d = pathlib.Path(x)
165 | label_path = os.path.join(data_dir, 'gt', ('gt_' + str(d.stem) + '.txt'))
166 | bboxs, text = self._get_annotation(label_path)
167 | if len(bboxs) > 0:
168 | data_list.append((x, bboxs, text))
169 | else:
170 |                 print('there is no suitable bbox in {}'.format(label_path))
171 | return data_list
172 |
173 | def _get_annotation(self, label_path: str) -> tuple:
174 | boxes = []
175 | text_tags = []
176 | with open(label_path, encoding='utf-8', mode='r') as f:
177 | for line in f.readlines():
178 | params = line.strip().strip('\ufeff').strip('\xef\xbb\xbf').split(',')
179 | try:
180 | label = params[8]
181 | if label == '*' or label == '###':
182 | text_tags.append(True)
183 | else:
184 | text_tags.append(False)
185 | # if label == '*' or label == '###':
186 | x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, params[:8]))
187 | boxes.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
188 | except:
189 | print('load label failed on {}'.format(label_path))
190 | return np.array(boxes, dtype=np.float32), np.array(text_tags, dtype=np.bool)
191 |
192 | def __len__(self):
193 | return len(self.data_list)
194 |
195 | def save_label(self, img_path, label):
196 | save_path = img_path.replace('img', 'save')
197 | if not os.path.exists(os.path.split(save_path)[0]):
198 | os.makedirs(os.path.split(save_path)[0])
199 | img = draw_bbox(img_path, label)
200 | cv2.imwrite(save_path, img)
201 | return img
202 |
203 |
204 | if __name__ == '__main__':
205 | import torch
206 | import config
207 | from utils.utils import show_img
208 | from tqdm import tqdm
209 | from torch.utils.data import DataLoader
210 | import matplotlib.pyplot as plt
211 | from torchvision import transforms
212 |
213 | train_data = MyDataset(config.trainroot, data_shape=config.data_shape, n=config.n, m=config.m,
214 | transform=transforms.ToTensor())
215 | train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=False, num_workers=0)
216 |
217 | pbar = tqdm(total=len(train_loader))
218 | for i, (img, label, mask) in enumerate(train_loader):
219 | print(label.shape)
220 | print(img.shape)
221 | print(label[0][-1].sum())
222 | print(mask[0].shape)
223 | # pbar.update(1)
224 | show_img((img[0] * mask[0].to(torch.float)).numpy().transpose(1, 2, 0), color=True)
225 | show_img(label[0])
226 | show_img(mask[0])
227 | plt.show()
228 |
229 | pbar.close()
230 |
--------------------------------------------------------------------------------
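The kernel shrinking in `generate_rbox` above follows the PSENet paper: kernel `i` of `n` uses the scale ratio `r_i = 1 - (1 - m) * (n - i) / (n - 1)`, and the polygon is offset inwards by `d_i = Area * (1 - r_i^2) / Perimeter`. A quick sanity check with the repository defaults `n = 6`, `m = 0.5` from `config.py`:

```python
def shrink_offset(area, perimeter, i, n=6, m=0.5):
    """Inward offset d_i for kernel i in 1..n, as computed in generate_rbox."""
    r_i = 1 - (1 - m) * (n - i) / (n - 1)
    return area * (1 - r_i * r_i) / perimeter

# a 100x20 axis-aligned box: area 2000, perimeter 240
for i in range(1, 7):
    print(i, round(shrink_offset(2000, 240, i), 2))
# 6.25 px inward for i=1 (smallest kernel), shrinking to 0.0 for i=6 (full mask)
```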
/eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2018/6/11 15:54
3 | # @Author : zhoujun
4 | import torch
5 | import shutil
6 | import numpy as np
7 | import config
8 | import os
9 | import cv2
10 | from tqdm import tqdm
11 | from models import PSENet
12 | from predict import Pytorch_model
13 | from cal_recall.script import cal_recall_precison_f1
14 | from utils import draw_bbox
15 |
16 | torch.backends.cudnn.benchmark = True
17 |
18 |
19 | def main(model_path, backbone, scale, path, save_path, gpu_id):
20 | if os.path.exists(save_path):
21 | shutil.rmtree(save_path, ignore_errors=True)
22 | if not os.path.exists(save_path):
23 | os.makedirs(save_path)
24 | save_img_folder = os.path.join(save_path, 'img')
25 | if not os.path.exists(save_img_folder):
26 | os.makedirs(save_img_folder)
27 | save_txt_folder = os.path.join(save_path, 'result')
28 | if not os.path.exists(save_txt_folder):
29 | os.makedirs(save_txt_folder)
30 | img_paths = [os.path.join(path, x) for x in os.listdir(path)]
31 | net = PSENet(backbone=backbone, pretrained=False, result_num=config.n)
32 | model = Pytorch_model(model_path, net=net, scale=scale, gpu_id=gpu_id)
33 | total_frame = 0.0
34 | total_time = 0.0
35 | for img_path in tqdm(img_paths):
36 | img_name = os.path.basename(img_path).split('.')[0]
37 | save_name = os.path.join(save_txt_folder, 'res_' + img_name + '.txt')
38 | _, boxes_list, t = model.predict(img_path)
39 | total_frame += 1
40 | total_time += t
41 | # img = draw_bbox(img_path, boxes_list, color=(0, 0, 255))
42 | # cv2.imwrite(os.path.join(save_img_folder, '{}.jpg'.format(img_name)), img)
43 | np.savetxt(save_name, boxes_list.reshape(-1, 8), delimiter=',', fmt='%d')
44 | print('fps:{}'.format(total_frame / total_time))
45 | return save_txt_folder
46 |
47 |
48 | if __name__ == '__main__':
49 | os.environ['CUDA_VISIBLE_DEVICES'] = str('2')
50 | backbone = 'resnet152'
51 | scale = 4
52 | model_path = 'output/psenet_icd2015_resnet152_author_crop_adam_warm_up_myloss/best_r0.714011_p0.708214_f10.711100.pth'
53 | data_path = '/data2/dataset/ICD15/test/img'
54 | gt_path = '/data2/dataset/ICD15/test/gt'
55 | save_path = './result/_scale{}'.format(scale)
56 | gpu_id = 0
57 | print('backbone:{},scale:{},model_path:{}'.format(backbone,scale,model_path))
58 | save_path = main(model_path, backbone, scale, data_path, save_path, gpu_id=gpu_id)
59 | result = cal_recall_precison_f1(gt_path=gt_path, result_path=save_path)
60 | print(result)
61 | # print(cal_recall_precison_f1('/data2/dataset/ICD151/test/gt', '/data1/zj/tensorflow_PSENet/tmp/'))
62 |
--------------------------------------------------------------------------------
/imgs/img_125.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenmuZhou/PSENet.pytorch/f760c2f4938726a2d00efaf5e5b28218323c44ca/imgs/img_125.jpg
--------------------------------------------------------------------------------
/imgs/img_31.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenmuZhou/PSENet.pytorch/f760c2f4938726a2d00efaf5e5b28218323c44ca/imgs/img_31.jpg
--------------------------------------------------------------------------------
/imgs/img_73.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenmuZhou/PSENet.pytorch/f760c2f4938726a2d00efaf5e5b28218323c44ca/imgs/img_73.jpg
--------------------------------------------------------------------------------
/imgs/img_83.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenmuZhou/PSENet.pytorch/f760c2f4938726a2d00efaf5e5b28218323c44ca/imgs/img_83.jpg
--------------------------------------------------------------------------------
/imgs/img_98.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WenmuZhou/PSENet.pytorch/f760c2f4938726a2d00efaf5e5b28218323c44ca/imgs/img_98.jpg
--------------------------------------------------------------------------------
/models/ShuffleNetV2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import logging
4 | from torchvision.models.utils import load_state_dict_from_url
5 |
6 | logger = logging.getLogger('project')
7 |
8 | __all__ = [
9 | 'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
10 | 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'
11 | ]
12 |
13 | model_urls = {
14 | 'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
15 | 'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
16 | 'shufflenetv2_x1.5': None,
17 | 'shufflenetv2_x2.0': None,
18 | }
19 |
20 |
21 | def channel_shuffle(x, groups):
22 | batchsize, num_channels, height, width = x.data.size()
23 | channels_per_group = num_channels // groups
24 |
25 | # reshape
26 | x = x.view(batchsize, groups,
27 | channels_per_group, height, width)
28 |
29 | x = torch.transpose(x, 1, 2).contiguous()
30 |
31 | # flatten
32 | x = x.view(batchsize, -1, height, width)
33 |
34 | return x
35 |
36 |
37 | class InvertedResidual(nn.Module):
38 | def __init__(self, inp, oup, stride):
39 | super(InvertedResidual, self).__init__()
40 |
41 | if not (1 <= stride <= 3):
42 | raise ValueError('illegal stride value')
43 | self.stride = stride
44 |
45 | branch_features = oup // 2
46 | assert (self.stride != 1) or (inp == branch_features << 1)
47 |
48 | if self.stride > 1:
49 | self.branch1 = nn.Sequential(
50 | self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
51 | nn.BatchNorm2d(inp),
52 | nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
53 | nn.BatchNorm2d(branch_features),
54 | nn.ReLU(inplace=True),
55 | )
56 |
57 | self.branch2 = nn.Sequential(
58 | nn.Conv2d(inp if (self.stride > 1) else branch_features,
59 | branch_features, kernel_size=1, stride=1, padding=0, bias=False),
60 | nn.BatchNorm2d(branch_features),
61 | nn.ReLU(inplace=True),
62 | self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
63 | nn.BatchNorm2d(branch_features),
64 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
65 | nn.BatchNorm2d(branch_features),
66 | nn.ReLU(inplace=True),
67 | )
68 |
69 | @staticmethod
70 | def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
71 | return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
72 |
73 | def forward(self, x):
74 | if self.stride == 1:
75 | x1, x2 = x.chunk(2, dim=1)
76 | out = torch.cat((x1, self.branch2(x2)), dim=1)
77 | else:
78 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
79 |
80 | out = channel_shuffle(out, 2)
81 |
82 | return out
83 |
84 |
85 | class ShuffleNetV2(nn.Module):
86 | def __init__(self, stages_repeats, stages_out_channels, num_classes=1000):
87 | super(ShuffleNetV2, self).__init__()
88 |
89 | if len(stages_repeats) != 3:
90 | raise ValueError('expected stages_repeats as list of 3 positive ints')
91 | if len(stages_out_channels) != 5:
92 | raise ValueError('expected stages_out_channels as list of 5 positive ints')
93 | self._stage_out_channels = stages_out_channels
94 |
95 | input_channels = 3
96 | output_channels = self._stage_out_channels[0]
97 | self.conv1 = nn.Sequential(
98 | nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
99 | nn.BatchNorm2d(output_channels),
100 | nn.ReLU(inplace=True),
101 | )
102 | input_channels = output_channels
103 |
104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
105 |
106 | stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
107 | for name, repeats, output_channels in zip(
108 | stage_names, stages_repeats, self._stage_out_channels[1:]):
109 | seq = [InvertedResidual(input_channels, output_channels, 2)]
110 | for i in range(repeats - 1):
111 | seq.append(InvertedResidual(output_channels, output_channels, 1))
112 | setattr(self, name, nn.Sequential(*seq))
113 | input_channels = output_channels
114 |
115 | output_channels = self._stage_out_channels[-1]
116 | self.conv5 = nn.Sequential(
117 | nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
118 | nn.BatchNorm2d(output_channels),
119 | nn.ReLU(inplace=True),
120 | )
121 |
122 | def forward(self, x):
123 | x = self.conv1(x)
124 | c2 = self.maxpool(x)
125 | c3 = self.stage2(c2)
126 | c4 = self.stage3(c3)
127 | c5 = self.stage4(c4)
128 | # c5 = self.conv5(c5)
129 | return c2, c3, c4, c5
130 |
131 |
132 | def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
133 | model = ShuffleNetV2(*args, **kwargs)
134 |
135 | if pretrained:
136 | model_url = model_urls[arch]
137 | if model_url is None:
138 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
139 | else:
140 | state_dict = load_state_dict_from_url(model_url, progress=progress)
141 | model.load_state_dict(state_dict,strict=False)
142 | logger.info('load pretrained models from imagenet')
143 | return model
144 |
145 |
146 | def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
147 | """
148 | Constructs a ShuffleNetV2 with 0.5x output channels, as described in
149 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
150 | <https://arxiv.org/abs/1807.11164>`_.
151 | 
152 | Args:
153 | pretrained (bool): If True, returns a model pre-trained on ImageNet
154 | progress (bool): If True, displays a progress bar of the download to stderr
155 | """
156 | return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
157 | [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
158 |
159 |
160 | def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
161 | """
162 | Constructs a ShuffleNetV2 with 1.0x output channels, as described in
163 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
164 | <https://arxiv.org/abs/1807.11164>`_.
165 | 
166 | Args:
167 | pretrained (bool): If True, returns a model pre-trained on ImageNet
168 | progress (bool): If True, displays a progress bar of the download to stderr
169 | """
170 | return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
171 | [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
172 |
173 |
174 | def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
175 | """
176 | Constructs a ShuffleNetV2 with 1.5x output channels, as described in
177 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
178 | <https://arxiv.org/abs/1807.11164>`_.
179 | 
180 | Args:
181 | pretrained (bool): If True, returns a model pre-trained on ImageNet
182 | progress (bool): If True, displays a progress bar of the download to stderr
183 | """
184 | return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
185 | [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
186 |
187 |
188 | def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs):
189 | """
190 | Constructs a ShuffleNetV2 with 2.0x output channels, as described in
191 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
192 | <https://arxiv.org/abs/1807.11164>`_.
193 | 
194 | Args:
195 | pretrained (bool): If True, returns a model pre-trained on ImageNet
196 | progress (bool): If True, displays a progress bar of the download to stderr
197 | """
198 | return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
199 | [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
200 |
201 |
202 | if __name__ == '__main__':
203 | import time
204 |
205 | device = torch.device('cpu')
206 | net = shufflenet_v2_x1_0(pretrained=True)
207 | net.eval()
208 | x = torch.zeros(1, 3, 640, 640).to(device)
209 | start = time.time()
210 | y = net(x)
211 | print(time.time() - start)
212 | for u in y:
213 | print(u.shape)
214 | # torch.save(net.state_dict(), f'shufflenet_v2_x2_0.pth')
215 |
--------------------------------------------------------------------------------
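`InvertedResidual.forward` above splits the channels at stride 1 and shuffles them after concatenation via the `channel_shuffle` helper defined earlier in this file. A minimal sketch of that operation (the usual reshape/transpose trick from the torchvision implementation; illustrative only, not the repository's exact code):

```python
import torch

def channel_shuffle_sketch(x: torch.Tensor, groups: int) -> torch.Tensor:
    n, c, h, w = x.size()
    # split channels into groups, swap the group/channel axes, flatten back:
    # this interleaves the channels coming from the two branches
    x = x.view(n, groups, c // groups, h, w)
    x = x.transpose(1, 2).contiguous()
    return x.view(n, c, h, w)

x = torch.arange(8, dtype=torch.float32).view(1, 8, 1, 1)
print(channel_shuffle_sketch(x, 2).flatten().tolist())
# [0.0, 4.0, 1.0, 5.0, 2.0, 6.0, 3.0, 7.0]
```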
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/1/2 18:18
3 | # @Author : zhoujun
4 | from .model import PSENet
--------------------------------------------------------------------------------
/models/loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 3/29/19 11:03 AM
3 | # @Author : zhoujun
4 | import torch
5 | from torch import nn
6 | import numpy as np
7 |
8 |
9 | class PSELoss(nn.Module):
10 | def __init__(self, Lambda, ratio=3, reduction='mean'):
11 | """Implement PSE Loss.
12 | """
13 | super(PSELoss, self).__init__()
14 | assert reduction in ['mean', 'sum'], "reduction must be in ['mean', 'sum']"
15 | self.Lambda = Lambda
16 | self.ratio = ratio
17 | self.reduction = reduction
18 |
19 | def forward(self, outputs, labels, training_masks):
20 | texts = outputs[:, -1, :, :]
21 | kernels = outputs[:, :-1, :, :]
22 | gt_texts = labels[:, -1, :, :]
23 | gt_kernels = labels[:, :-1, :, :]
24 |
25 | selected_masks = self.ohem_batch(texts, gt_texts, training_masks)
26 | selected_masks = selected_masks.to(outputs.device)
27 |
28 | loss_text = self.dice_loss(texts, gt_texts, selected_masks)
29 |
30 | loss_kernels = []
31 | mask0 = torch.sigmoid(texts).data.cpu().numpy()
32 | mask1 = training_masks.data.cpu().numpy()
33 | selected_masks = ((mask0 > 0.5) & (mask1 > 0.5)).astype('float32')
34 | selected_masks = torch.from_numpy(selected_masks).float()
35 | selected_masks = selected_masks.to(outputs.device)
36 | kernels_num = gt_kernels.size()[1]
37 | for i in range(kernels_num):
38 | kernel_i = kernels[:, i, :, :]
39 | gt_kernel_i = gt_kernels[:, i, :, :]
40 | loss_kernel_i = self.dice_loss(kernel_i, gt_kernel_i, selected_masks)
41 | loss_kernels.append(loss_kernel_i)
42 | loss_kernels = torch.stack(loss_kernels).mean(0)
43 | if self.reduction == 'mean':
44 | loss_text = loss_text.mean()
45 | loss_kernels = loss_kernels.mean()
46 | elif self.reduction == 'sum':
47 | loss_text = loss_text.sum()
48 | loss_kernels = loss_kernels.sum()
49 |
50 | loss = self.Lambda * loss_text + (1 - self.Lambda) * loss_kernels
51 | return loss_text, loss_kernels, loss
52 |
53 | def dice_loss(self, input, target, mask):
54 | input = torch.sigmoid(input)
55 |
56 | input = input.contiguous().view(input.size()[0], -1)
57 | target = target.contiguous().view(target.size()[0], -1)
58 | mask = mask.contiguous().view(mask.size()[0], -1)
59 |
60 | input = input * mask
61 | target = target * mask
62 |
63 | a = torch.sum(input * target, 1)
64 | b = torch.sum(input * input, 1) + 0.001
65 | c = torch.sum(target * target, 1) + 0.001
66 | d = (2 * a) / (b + c)
67 | return 1 - d
68 |
69 | def ohem_single(self, score, gt_text, training_mask):
70 | pos_num = int(np.sum(gt_text > 0.5)) - int(np.sum((gt_text > 0.5) & (training_mask <= 0.5)))
71 |
72 | if pos_num == 0:
73 | # selected_mask = gt_text.copy() * 0 # may be not good
74 | selected_mask = training_mask
75 | selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
76 | return selected_mask
77 |
78 | neg_num = int(np.sum(gt_text <= 0.5))
79 | neg_num = int(min(pos_num * self.ratio, neg_num))
80 |
81 | if neg_num == 0:
82 | selected_mask = training_mask
83 | selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
84 | return selected_mask
85 |
86 | neg_score = score[gt_text <= 0.5]
87 | # sort the negative-sample scores in descending order
88 | neg_score_sorted = np.sort(-neg_score)
89 | threshold = -neg_score_sorted[neg_num - 1]
90 | # select the mask of high-scoring negatives together with all positives
91 | selected_mask = ((score >= threshold) | (gt_text > 0.5)) & (training_mask > 0.5)
92 | selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
93 | return selected_mask
94 |
95 | def ohem_batch(self, scores, gt_texts, training_masks):
96 | scores = scores.data.cpu().numpy()
97 | gt_texts = gt_texts.data.cpu().numpy()
98 | training_masks = training_masks.data.cpu().numpy()
99 |
100 | selected_masks = []
101 | for i in range(scores.shape[0]):
102 | selected_masks.append(self.ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[i, :, :]))
103 |
104 | selected_masks = np.concatenate(selected_masks, 0)
105 | selected_masks = torch.from_numpy(selected_masks).float()
106 |
107 | return selected_masks
108 |
--------------------------------------------------------------------------------
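As a toy walk-through of the negative selection inside `ohem_single` above (scores below are made up): with one positive pixel and the 3:1 ratio, only the three hardest, i.e. highest-scoring, of the four negatives survive the threshold.

```python
import numpy as np

score = np.array([0.9, 0.1, 0.8, 0.2, 0.7])  # predicted text confidences
gt = np.array([1, 0, 0, 0, 0])               # one positive, four negatives
pos_num, ratio = int((gt > 0.5).sum()), 3
neg_num = min(pos_num * ratio, int((gt <= 0.5).sum()))  # keep at most 3 negatives
threshold = -np.sort(-score[gt <= 0.5])[neg_num - 1]    # score of the 3rd-hardest negative
selected = (score >= threshold) | (gt > 0.5)
print(threshold, selected)  # 0.2 [ True False  True  True  True]
```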
/models/mobilenetv3.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/5/23 15:22
3 | # @Author : zhoujun
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn import init
9 |
10 |
11 | class hswish(nn.Module):
12 | def forward(self, x):
13 | out = x * F.relu6(x + 3, inplace=True) / 6
14 | return out
15 |
16 |
17 | class hsigmoid(nn.Module):
18 | def forward(self, x):
19 | out = F.relu6(x + 3, inplace=True) / 6
20 | return out
21 |
22 |
23 | class SeModule(nn.Module):
24 | def __init__(self, in_size, reduction=4):
25 | super(SeModule, self).__init__()
26 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
27 |
28 | self.se = nn.Sequential(
29 | nn.Conv2d(in_size, in_size // reduction, kernel_size=1, stride=1, padding=0, bias=False),
30 | nn.BatchNorm2d(in_size // reduction),
31 | nn.ReLU(inplace=True),
32 | nn.Conv2d(in_size // reduction, in_size, kernel_size=1, stride=1, padding=0, bias=False),
33 | nn.BatchNorm2d(in_size),
34 | hsigmoid()
35 | )
36 |
37 | def forward(self, x):
38 | return x * self.se(x)
39 |
40 |
41 | class Block(nn.Module):
42 | '''expand + depthwise + pointwise'''
43 |
44 | def __init__(self, kernel_size, in_size, expand_size, out_size, nolinear, semodule, stride):
45 | super(Block, self).__init__()
46 | self.stride = stride
47 | self.se = semodule
48 |
49 | self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, stride=1, padding=0, bias=False)
50 | self.bn1 = nn.BatchNorm2d(expand_size)
51 | self.nolinear1 = nolinear
52 | self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride,
53 | padding=kernel_size // 2, groups=expand_size, bias=False)
54 | self.bn2 = nn.BatchNorm2d(expand_size)
55 | self.nolinear2 = nolinear
56 | self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, stride=1, padding=0, bias=False)
57 | self.bn3 = nn.BatchNorm2d(out_size)
58 |
59 | self.shortcut = nn.Sequential()
60 | if stride == 1 and in_size != out_size:
61 | self.shortcut = nn.Sequential(
62 | nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0, bias=False),
63 | nn.BatchNorm2d(out_size),
64 | )
65 |
66 | def forward(self, x):
67 | out = self.nolinear1(self.bn1(self.conv1(x)))
68 | out = self.nolinear2(self.bn2(self.conv2(out)))
69 | out = self.bn3(self.conv3(out))
70 | if self.se is not None:
71 | out = self.se(out)
72 | out = out + self.shortcut(x) if self.stride == 1 else out
73 | return out
74 |
75 |
76 | class MobileNetV3_Large(nn.Module):
77 | def __init__(self, pretrained):  # pretrained is accepted for API parity but no weights are loaded
78 | super(MobileNetV3_Large, self).__init__()
79 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
80 | self.bn1 = nn.BatchNorm2d(16)
81 | self.hs1 = hswish()
82 |
83 | self.layer1 = nn.Sequential(
84 | Block(3, 16, 16, 16, nn.ReLU(inplace=True), None, 1),
85 | Block(3, 16, 64, 24, nn.ReLU(inplace=True), None, 2),
86 | Block(3, 24, 72, 24, nn.ReLU(inplace=True), None, 1),
87 | )
88 |
89 | self.layer2 = nn.Sequential(
90 | Block(5, 24, 72, 40, nn.ReLU(inplace=True), SeModule(40), 2),
91 | Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
92 | Block(5, 40, 120, 40, nn.ReLU(inplace=True), SeModule(40), 1),
93 | )
94 |
95 | self.layer3 = nn.Sequential(
96 | Block(3, 40, 240, 80, hswish(), None, 2),
97 | Block(3, 80, 200, 80, hswish(), None, 1),
98 | Block(3, 80, 184, 80, hswish(), None, 1),
99 | Block(3, 80, 184, 80, hswish(), None, 1),
100 | Block(3, 80, 480, 112, hswish(), SeModule(112), 1),
101 | Block(3, 112, 672, 112, hswish(), SeModule(112), 1),
102 | Block(5, 112, 672, 160, hswish(), SeModule(160), 1),
103 | )
104 | self.layer4 = nn.Sequential(
105 | Block(5, 160, 672, 160, hswish(), SeModule(160), 2),
106 | Block(5, 160, 960, 160, hswish(), SeModule(160), 1),
107 | )
108 | self.conv2 = nn.Conv2d(160, 960, kernel_size=1, stride=1, padding=0, bias=False)
109 | self.bn2 = nn.BatchNorm2d(960)
110 | self.hs2 = hswish()
111 | self.init_params()
112 |
113 | def init_params(self):
114 | for m in self.modules():
115 | if isinstance(m, nn.Conv2d):
116 | init.kaiming_normal_(m.weight, mode='fan_out')
117 | if m.bias is not None:
118 | init.constant_(m.bias, 0)
119 | elif isinstance(m, nn.BatchNorm2d):
120 | init.constant_(m.weight, 1)
121 | init.constant_(m.bias, 0)
122 | elif isinstance(m, nn.Linear):
123 | init.normal_(m.weight, std=0.001)
124 | if m.bias is not None:
125 | init.constant_(m.bias, 0)
126 |
127 | def forward(self, x):
128 | c1 = self.hs1(self.bn1(self.conv1(x)))
129 | c2 = self.layer1(c1)
130 | c3 = self.layer2(c2)
131 | c4 = self.layer3(c3)
132 | c5 = self.layer4(c4)
133 | # c5 = self.hs2(self.bn2(self.conv2(c5)))
134 | return c2, c3, c4, c5  # PSENet in models/model.py unpacks exactly four feature maps
135 |
136 |
137 | class MobileNetV3_Small(nn.Module):
139 | def __init__(self, pretrained):  # pretrained is accepted for API parity but no weights are loaded
139 | super(MobileNetV3_Small, self).__init__()
140 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
141 | self.bn1 = nn.BatchNorm2d(16)
142 | self.hs1 = hswish()
143 |
144 | self.layer1 = Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2)
145 | self.layer2 = nn.Sequential(
146 | Block(3, 16, 72, 24, nn.ReLU(inplace=True), None, 2),
147 | Block(3, 24, 88, 24, nn.ReLU(inplace=True), None, 1),
148 | )
149 |
150 | self.layer3 = nn.Sequential(
151 | Block(5, 24, 96, 40, hswish(), SeModule(40), 2),
152 | Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
153 | Block(5, 40, 240, 40, hswish(), SeModule(40), 1),
154 | Block(5, 40, 120, 48, hswish(), SeModule(48), 1),
155 | Block(5, 48, 144, 48, hswish(), SeModule(48), 1),
156 | )
157 | self.layer4 = nn.Sequential(
158 | Block(5, 48, 288, 96, hswish(), SeModule(96), 2),
159 | Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
160 | Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
161 | )
162 | self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
163 | self.bn2 = nn.BatchNorm2d(576)
164 | self.hs2 = hswish()
165 | self.init_params()
166 |
167 | def init_params(self):
168 | for m in self.modules():
169 | if isinstance(m, nn.Conv2d):
170 | init.kaiming_normal_(m.weight, mode='fan_out')
171 | if m.bias is not None:
172 | init.constant_(m.bias, 0)
173 | elif isinstance(m, nn.BatchNorm2d):
174 | init.constant_(m.weight, 1)
175 | init.constant_(m.bias, 0)
176 | elif isinstance(m, nn.Linear):
177 | init.normal_(m.weight, std=0.001)
178 | if m.bias is not None:
179 | init.constant_(m.bias, 0)
180 |
181 | def forward(self, x):
182 | c1 = self.hs1(self.bn1(self.conv1(x)))
183 | c2 = self.layer1(c1)
184 | c3 = self.layer2(c2)
185 | c4 = self.layer3(c3)
186 | c5 = self.layer4(c4)
187 | # c5 = self.hs2(self.bn2(self.conv2(c5)))
188 | return c2, c3, c4, c5  # PSENet in models/model.py unpacks exactly four feature maps
189 |
190 |
191 | if __name__ == '__main__':
192 | import time
193 |
194 | device = torch.device('cpu')
195 | net = MobileNetV3_Large(pretrained=False)
196 | net.eval()
197 | x = torch.zeros(1, 3, 608, 800).to(device)
198 | start = time.time()
199 | y = net(x)
200 | print(time.time() - start)
201 | for u in y:
202 | print(u.shape)
203 | # torch.save(net.state_dict(), 'MobileNetV3_Large111.pth')
204 |
--------------------------------------------------------------------------------
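A quick shape check for this backbone (assumes the four-map `forward` above; each stage halves the resolution once, so the maps sit at strides 4, 8, 16 and 32 of the input):

```python
import torch
from models.mobilenetv3 import MobileNetV3_Large  # run from the repository root

net = MobileNetV3_Large(pretrained=False).eval()
with torch.no_grad():
    feats = net(torch.zeros(1, 3, 608, 800))
for name, f in zip(['c2', 'c3', 'c4', 'c5'], feats):
    print(name, tuple(f.shape))
# c2 (1, 24, 152, 200)
# c3 (1, 40, 76, 100)
# c4 (1, 160, 38, 50)
# c5 (1, 160, 19, 25)
```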
/models/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/1/2 17:29
3 | # @Author : zhoujun
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
8 | from models.mobilenetv3 import MobileNetV3_Large, MobileNetV3_Small
9 | from models.ShuffleNetV2 import shufflenet_v2_x1_0
10 |
11 | d = {'resnet18': {'models': resnet18, 'out': [64, 128, 256, 512]},
12 | 'resnet34': {'models': resnet34, 'out': [64, 128, 256, 512]},
13 | 'resnet50': {'models': resnet50, 'out': [256, 512, 1024, 2048]},
14 | 'resnet101': {'models': resnet101, 'out': [256, 512, 1024, 2048]},
15 | 'resnet152': {'models': resnet152, 'out': [256, 512, 1024, 2048]},
16 | 'MobileNetV3_Large': {'models': MobileNetV3_Large, 'out': [24, 40, 160, 160]},
17 | 'MobileNetV3_Small': {'models': MobileNetV3_Small, 'out': [16, 24, 48, 96]},
18 | 'shufflenetv2': {'models': shufflenet_v2_x1_0, 'out': [24, 116, 232, 464]}}
19 | inplace = True
20 |
21 |
22 | class PSENet(nn.Module):
23 | def __init__(self, backbone, result_num=6, scale: int = 1, pretrained=False):
24 | super(PSENet, self).__init__()
25 | assert backbone in d, 'backbone must be one of: {}'.format(list(d))
26 | self.scale = scale
27 | conv_out = 256
28 | model, out = d[backbone]['models'], d[backbone]['out']
29 | self.backbone = model(pretrained=pretrained)
30 | # Reduce channels
31 | # Top layer
32 | self.toplayer = nn.Sequential(nn.Conv2d(out[3], conv_out, kernel_size=1, stride=1, padding=0),
33 | nn.BatchNorm2d(conv_out),
34 | nn.ReLU(inplace=inplace)
35 | )
36 | # Lateral layers
37 | self.latlayer1 = nn.Sequential(nn.Conv2d(out[2], conv_out, kernel_size=1, stride=1, padding=0),
38 | nn.BatchNorm2d(conv_out),
39 | nn.ReLU(inplace=inplace)
40 | )
41 | self.latlayer2 = nn.Sequential(nn.Conv2d(out[1], conv_out, kernel_size=1, stride=1, padding=0),
42 | nn.BatchNorm2d(conv_out),
43 | nn.ReLU(inplace=inplace)
44 | )
45 | self.latlayer3 = nn.Sequential(nn.Conv2d(out[0], conv_out, kernel_size=1, stride=1, padding=0),
46 | nn.BatchNorm2d(conv_out),
47 | nn.ReLU(inplace=inplace)
48 | )
49 |
50 | # Smooth layers
51 | self.smooth1 = nn.Sequential(nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1),
52 | nn.BatchNorm2d(conv_out),
53 | nn.ReLU(inplace=inplace)
54 | )
55 | self.smooth2 = nn.Sequential(nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1),
56 | nn.BatchNorm2d(conv_out),
57 | nn.ReLU(inplace=inplace)
58 | )
59 | self.smooth3 = nn.Sequential(nn.Conv2d(conv_out, conv_out, kernel_size=3, stride=1, padding=1),
60 | nn.BatchNorm2d(conv_out),
61 | nn.ReLU(inplace=inplace)
62 | )
63 |
64 | self.conv = nn.Sequential(
65 | nn.Conv2d(conv_out * 4, conv_out, kernel_size=3, padding=1, stride=1),
66 | nn.BatchNorm2d(conv_out),
67 | nn.ReLU(inplace=inplace)
68 | )
69 | self.out_conv = nn.Conv2d(conv_out, result_num, kernel_size=1, stride=1)
70 |
71 | def forward(self, input: torch.Tensor):
72 | _, _, H, W = input.size()
73 | c2, c3, c4, c5 = self.backbone(input)
74 | # Top-down
75 | p5 = self.toplayer(c5)
76 | p4 = self._upsample_add(p5, self.latlayer1(c4))
77 | p4 = self.smooth1(p4)
78 | p3 = self._upsample_add(p4, self.latlayer2(c3))
79 | p3 = self.smooth2(p3)
80 | p2 = self._upsample_add(p3, self.latlayer3(c2))
81 | p2 = self.smooth3(p2)
82 |
83 | x = self._upsample_cat(p2, p3, p4, p5)
84 | x = self.conv(x)
85 | x = self.out_conv(x)
86 |
87 | if self.training:  # self.train is a method and is always truthy; use the training flag
88 | x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)
89 | else:
90 | x = F.interpolate(x, size=(H // self.scale, W // self.scale), mode='bilinear', align_corners=True)
91 | return x
92 |
93 | def _upsample_add(self, x, y):
94 | return F.interpolate(x, size=y.size()[2:], mode='bilinear', align_corners=False) + y
95 |
96 | def _upsample_cat(self, p2, p3, p4, p5):
97 | h, w = p2.size()[2:]
98 | p3 = F.interpolate(p3, size=(h, w), mode='bilinear', align_corners=False)
99 | p4 = F.interpolate(p4, size=(h, w), mode='bilinear', align_corners=False)
100 | p5 = F.interpolate(p5, size=(h, w), mode='bilinear', align_corners=False)
101 | return torch.cat([p2, p3, p4, p5], dim=1)
102 |
103 |
104 | if __name__ == '__main__':
105 | import time
106 |
107 | device = torch.device('cpu')
108 | backbone = 'shufflenetv2'
109 | net = PSENet(backbone=backbone, pretrained=False, result_num=6).to(device)
110 | net.eval()
111 | x = torch.zeros(1, 3, 512, 512).to(device)
112 | start = time.time()
113 | y = net(x)
114 | print(time.time() - start)
115 | print(y.shape)
116 | # torch.save(net.state_dict(),f'{backbone}.pth')
117 |
--------------------------------------------------------------------------------
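`_upsample_add` and `_upsample_cat` above implement the standard FPN merge: bilinearly resize the coarser map to the finer map's grid, then add (or concatenate). A standalone sketch with dummy tensors:

```python
import torch
import torch.nn.functional as F

p5 = torch.randn(1, 256, 16, 16)  # coarse pyramid level
c4 = torch.randn(1, 256, 32, 32)  # finer lateral feature, already reduced to 256 channels
p4 = F.interpolate(p5, size=c4.shape[2:], mode='bilinear', align_corners=False) + c4
print(p4.shape)  # torch.Size([1, 256, 32, 32])
```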
/models/resnet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2019/1/2 17:30
3 | # @Author : zhoujun
4 | import torch
5 | import torch.nn as nn
6 | import math
7 | import logging
8 | import torch.utils.model_zoo as model_zoo
9 | import torchvision.models.resnet
10 |
11 | logger = logging.getLogger('project')
12 |
13 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50',
14 |            'resnet101', 'resnet152']
15 |
16 | model_urls = {
17 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
18 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
19 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
20 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
21 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
22 | }
23 |
24 |
25 | def conv3x3(in_planes, out_planes, stride=1):
26 | """3x3 convolution with padding"""
27 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
28 | padding=1, bias=False)
29 |
30 |
31 | class BasicBlock(nn.Module):
32 | expansion = 1
33 |
34 | def __init__(self, inplanes, planes, stride=1, downsample=None):
35 | super(BasicBlock, self).__init__()
36 | self.conv1 = conv3x3(inplanes, planes, stride)
37 | self.bn1 = nn.BatchNorm2d(planes)
38 | self.relu = nn.ReLU(inplace=True)
39 | self.conv2 = conv3x3(planes, planes)
40 | self.bn2 = nn.BatchNorm2d(planes)
41 | self.downsample = downsample
42 | self.stride = stride
43 |
44 | def forward(self, x):
45 | residual = x
46 |
47 | out = self.conv1(x)
48 | out = self.bn1(out)
49 | out = self.relu(out)
50 |
51 | out = self.conv2(out)
52 | out = self.bn2(out)
53 |
54 | if self.downsample is not None:
55 | residual = self.downsample(x)
56 |
57 | out += residual
58 | out = self.relu(out)
59 |
60 | return out
61 |
62 |
63 | class Bottleneck(nn.Module):
64 | expansion = 4
65 |
66 | def __init__(self, inplanes, planes, stride=1, downsample=None):
67 | super(Bottleneck, self).__init__()
68 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
69 | self.bn1 = nn.BatchNorm2d(planes)
70 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
71 | padding=1, bias=False)
72 | self.bn2 = nn.BatchNorm2d(planes)
73 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
74 | self.bn3 = nn.BatchNorm2d(planes * 4)
75 | self.relu = nn.ReLU(inplace=True)
76 | self.downsample = downsample
77 | self.stride = stride
78 |
79 | def forward(self, x):
80 | residual = x
81 |
82 | out = self.conv1(x)
83 | out = self.bn1(out)
84 | out = self.relu(out)
85 |
86 | out = self.conv2(out)
87 | out = self.bn2(out)
88 | out = self.relu(out)
89 |
90 | out = self.conv3(out)
91 | out = self.bn3(out)
92 |
93 | if self.downsample is not None:
94 | residual = self.downsample(x)
95 |
96 | out += residual
97 | out = self.relu(out)
98 |
99 | return out
100 |
101 |
102 | class ResNet(nn.Module):
103 |
104 | def __init__(self, block, layers):
105 | self.inplanes = 64
106 | super(ResNet, self).__init__()
107 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
108 | bias=False)
109 | self.bn1 = nn.BatchNorm2d(64)
110 | self.relu = nn.ReLU(inplace=True)
111 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
112 | self.layer1 = self._make_layer(block, 64, layers[0])
113 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
114 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
115 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
116 |
117 | for m in self.modules():
118 | if isinstance(m, nn.Conv2d):
119 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
120 | m.weight.data.normal_(0, math.sqrt(2. / n))
121 | elif isinstance(m, nn.BatchNorm2d):
122 | m.weight.data.fill_(1)
123 | m.bias.data.zero_()
124 |
125 | def _make_layer(self, block, planes, blocks, stride=1):
126 | downsample = None
127 | if stride != 1 or self.inplanes != planes * block.expansion:
128 | downsample = nn.Sequential(
129 | nn.Conv2d(self.inplanes, planes * block.expansion,
130 | kernel_size=1, stride=stride, bias=False),
131 | nn.BatchNorm2d(planes * block.expansion),
132 | )
133 |
134 | layers = []
135 | layers.append(block(self.inplanes, planes, stride, downsample))
136 | self.inplanes = planes * block.expansion
137 | for i in range(1, blocks):
138 | layers.append(block(self.inplanes, planes))
139 |
140 | return nn.Sequential(*layers)
141 |
142 | def _load_pretrained_model(self, model_url):
143 | pretrain_dict = model_zoo.load_url(model_url)
144 | model_dict = {}
145 | state_dict = self.state_dict()
146 | for k, v in pretrain_dict.items():
147 | if k in state_dict:
148 | model_dict[k] = v
149 | state_dict.update(model_dict)
150 | self.load_state_dict(state_dict)
151 | logger.info('load pretrained models from imagenet')
152 |
153 | def forward(self, input):
154 | x = self.conv1(input)
155 | x = self.bn1(x)
156 | x = self.relu(x)
157 | x = self.maxpool(x)
158 | c2 = self.layer1(x)
159 | c3 = self.layer2(c2)
160 | c4 = self.layer3(c3)
161 | c5 = self.layer4(c4)
162 | return c2, c3, c4, c5
163 |
164 | def resnet18(pretrained=False, **kwargs):
165 | """Constructs a ResNet-18 models.
166 |
167 | Args:
168 | pretrained (bool): If True, returns a models pre-trained on ImageNet
169 | """
170 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
171 | if pretrained:
172 | model._load_pretrained_model(model_urls['resnet18'])
173 | return model
174 |
175 |
176 | def resnet34(pretrained=False, **kwargs):
177 | """Constructs a ResNet-34 models.
178 |
179 | Args:
180 | pretrained (bool): If True, returns a models pre-trained on ImageNet
181 | """
182 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
183 | if pretrained:
184 | model._load_pretrained_model(model_urls['resnet34'])
185 | return model
186 |
187 | def resnet50(pretrained=False, **kwargs):
188 | """Constructs a ResNet-50 models.
189 |
190 | Args:
191 | pretrained (bool): If True, returns a models pre-trained on ImageNet
192 | """
193 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
194 | if pretrained:
195 | model._load_pretrained_model(model_urls['resnet50'])
196 | return model
197 |
198 |
199 | def resnet101(pretrained=False, **kwargs):
200 | """Constructs a ResNet-101 models.
201 |
202 | Args:
203 | pretrained (bool): If True, returns a models pre-trained on ImageNet
204 | """
205 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
206 | if pretrained:
207 | model._load_pretrained_model(model_urls['resnet101'])
208 | return model
209 |
210 |
211 | def resnet152(pretrained=False, **kwargs):
212 | """Constructs a ResNet-152 models.
213 |
214 | Args:
215 | pretrained (bool): If True, returns a models pre-trained on ImageNet
216 | """
217 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
218 | if pretrained:
219 | model._load_pretrained_model(model_urls['resnet152'])
220 | return model
221 |
222 |
223 | if __name__ == '__main__':
224 | x = torch.zeros(1, 3, 640, 640)
225 | net = resnet50()
226 | y = net(x)
227 | for u in y:
228 | print(u.shape)
229 |
--------------------------------------------------------------------------------
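`_load_pretrained_model` above copies only the keys that also exist in this backbone, so weights of layers the backbone dropped (e.g. the ImageNet `fc` head) are silently skipped. The filtering step reduces to a one-line dict comprehension (toy stand-ins below; the real code pulls `pretrain_dict` from `model_zoo.load_url`):

```python
pretrain_dict = {'conv1.weight': 'pretrained', 'fc.weight': 'head-only'}  # toy values
state_dict = {'conv1.weight': 'random-init'}                              # backbone has no fc
state_dict.update({k: v for k, v in pretrain_dict.items() if k in state_dict})
print(state_dict)  # {'conv1.weight': 'pretrained'}
```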
/predict.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 1/4/19 11:14 AM
3 | # @Author : zhoujun
4 | import torch
5 | from torchvision import transforms
6 | import os
7 | import cv2
8 | import time
9 | import numpy as np
10 |
11 | from pse import decode as pse_decode
12 |
13 |
14 | class Pytorch_model:
15 | def __init__(self, model_path, net, scale, gpu_id=None):
16 | '''
17 | Initialize the PyTorch model.
18 | :param model_path: checkpoint path (parameters only, or parameters saved together with the graph)
19 | :param net: network definition; required when model_path stores parameters only
20 | :param scale: output scale passed through to the network
21 | :param gpu_id: id of the GPU to run on; falls back to CPU when None or unavailable
22 | '''
23 | self.scale = scale
24 | if gpu_id is not None and isinstance(gpu_id, int) and torch.cuda.is_available():
25 | self.device = torch.device("cuda:{}".format(gpu_id))
26 | else:
27 | self.device = torch.device("cpu")
28 | self.net = torch.load(model_path, map_location=self.device)['state_dict']
29 | print('device:', self.device)
30 |
31 | if net is not None:
32 | # if the graph and the parameters were saved separately, load the parameters here
33 | net = net.to(self.device)
34 | net.scale = scale
35 | try:
36 | sk = {}
37 | for k in self.net:
38 | sk[k[7:]] = self.net[k]  # drop the 'module.' prefix added by nn.DataParallel
39 | net.load_state_dict(sk)
40 | except Exception:
41 | net.load_state_dict(self.net)
42 | self.net = net
43 | print('model loaded')
44 | self.net.eval()
45 |
46 | def predict(self, img: str, long_size: int = 2240):
47 | '''
48 | Run prediction on an image; takes an image path and reads it with OpenCV (somewhat slow)
49 | :param img: path to the image file
50 | :param long_size: target length of the longer image side after resizing
51 | :return: label map, detected boxes and the elapsed time
52 | '''
53 | assert os.path.exists(img), 'file does not exist'
54 | img = cv2.imread(img)
55 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
56 | h, w = img.shape[:2]
57 |
58 | scale = long_size / max(h, w)
59 | img = cv2.resize(img, None, fx=scale, fy=scale)
60 | # convert the image from (h, w, c) to (1, c, h, w)
61 | tensor = transforms.ToTensor()(img)
62 | tensor = tensor.unsqueeze_(0)
63 |
64 | tensor = tensor.to(self.device)
65 | with torch.no_grad():
66 | if self.device.type == 'cuda': torch.cuda.synchronize()  # only synchronize when timing on GPU
67 | start = time.time()
68 | preds = self.net(tensor)
69 | preds, boxes_list = pse_decode(preds[0], self.scale)
70 | scale = (preds.shape[1] / w, preds.shape[0] / h)
71 | # print(scale)
72 | # preds, boxes_list = decode(preds,num_pred=-1)
73 | if len(boxes_list):
74 | boxes_list = boxes_list / scale
75 | if self.device.type == 'cuda': torch.cuda.synchronize()
76 | t = time.time() - start
77 | return preds, boxes_list, t
78 |
79 |
80 | def _get_annotation(label_path):
81 | boxes = []
82 | with open(label_path, encoding='utf-8', mode='r') as f:
83 | for line in f.readlines():
84 | params = line.strip().strip('\ufeff').strip('\xef\xbb\xbf').split(',')
85 | try:
86 | label = params[8]
87 | if label == '*' or label == '###':
88 | continue
89 | x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, params[:8]))
90 | boxes.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
91 | except Exception:
92 | print('load label failed on {}'.format(label_path))
93 | return np.array(boxes, dtype=np.float32)
94 |
95 |
96 | if __name__ == '__main__':
97 | import config
98 | from models import PSENet
99 | import matplotlib.pyplot as plt
100 | from utils.utils import show_img, draw_bbox
101 |
102 | os.environ['CUDA_VISIBLE_DEVICES'] = str('2')
103 |
104 | model_path = 'output/psenet_icd2015_resnet152_author_crop_adam_warm_up_myloss/best_r0.714011_p0.708214_f10.711100.pth'
105 |
106 | # model_path = 'output/psenet_icd2015_new_loss/final.pth'
107 | img_id = 10
108 | img_path = '/data2/dataset/ICD15/test/img/img_{}.jpg'.format(img_id)
109 | label_path = '/data2/dataset/ICD15/test/gt/gt_img_{}.txt'.format(img_id)
110 | label = _get_annotation(label_path)
111 |
112 | # build the network
113 | net = PSENet(backbone='resnet152', pretrained=False, result_num=config.n)
114 | model = Pytorch_model(model_path, net=net, scale=1, gpu_id=0)
115 | # for i in range(100):
116 | # models.predict(img_path)
117 | preds, boxes_list, t = model.predict(img_path)
118 | print(boxes_list)
119 | show_img(preds)
120 | img = draw_bbox(img_path, boxes_list, color=(0, 0, 255))
121 | cv2.imwrite('result.jpg', img)
122 | # img = draw_bbox(img, label,color=(0,0,255))
123 | show_img(img, color=True)
124 |
125 | plt.show()
126 |
--------------------------------------------------------------------------------
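The `try`/`except` in `Pytorch_model.__init__` handles checkpoints written from a model wrapped in `nn.DataParallel`, whose keys carry a `module.` prefix; `k[7:]` drops those seven characters. The same idea on hypothetical keys:

```python
ckpt = {'module.conv1.weight': 0, 'module.bn1.weight': 1}  # hypothetical checkpoint keys
stripped = {k[len('module.'):]: v for k, v in ckpt.items()}
print(stripped)  # {'conv1.weight': 0, 'bn1.weight': 1}
```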
/pse/Makefile:
--------------------------------------------------------------------------------
1 | CXXFLAGS = -I include -std=c++11 -O3 $(shell python3-config --cflags)
2 | LDFLAGS = $(shell python3-config --ldflags)
3 |
4 | DEPS = $(shell find include -xtype f)
5 | CXX_SOURCES = pse.cpp
6 |
7 | LIB_SO = pse.so
8 |
9 | $(LIB_SO): $(CXX_SOURCES) $(DEPS)
10 | $(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC
11 |
12 | clean:
13 | rm -rf $(LIB_SO)
14 |
--------------------------------------------------------------------------------
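The Makefile takes its compile and link flags from `python3-config`, so `pse.so` is built against the interpreter that will import it. If the build fails, printing those shell substitutions from Python is a quick diagnostic (assumes `python3-config` is on the PATH):

```python
import subprocess

for flag in ('--cflags', '--ldflags'):
    out = subprocess.check_output(['python3-config', flag], text=True)
    print(flag, '->', out.strip())
```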
/pse/__init__.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import os
3 | import numpy as np
4 | import cv2
5 | import torch
6 |
7 | BASE_DIR = os.path.dirname(os.path.realpath(__file__))
8 |
9 | if subprocess.call(['make', '-C', BASE_DIR]) != 0:  # a non-zero exit status means the build failed
10 | raise RuntimeError('Cannot compile pse: {}'.format(BASE_DIR))
11 |
12 | def pse_warpper(kernals, min_area=5):
13 | '''
14 | reference https://github.com/liuheng92/tensorflow_PSENet/blob/feature_dev/pse
15 | :param kernals: binarized kernel maps; kernals[0] seeds the connected components
16 | :param min_area: connected components smaller than this many pixels are discarded
17 | :return: final label map and the list of kept label values
18 | '''
19 | from .pse import pse_cpp
20 | kernal_num = len(kernals)
21 | if not kernal_num:
22 | return np.array([]), []
23 | kernals = np.array(kernals)
24 | label_num, label = cv2.connectedComponents(kernals[0].astype(np.uint8), connectivity=4)
25 | label_values = []
26 | for label_idx in range(1, label_num):
27 | if np.sum(label == label_idx) < min_area:
28 | label[label == label_idx] = 0
29 | continue
30 | label_values.append(label_idx)
31 |
32 | pred = pse_cpp(label, kernals, c=kernal_num)
33 |
34 | return np.array(pred), label_values
35 |
36 |
37 | def decode(preds, scale, threshold=0.7311):
38 | """
39 | 在输出上使用sigmoid 将值转换为置信度,并使用阈值来进行文字和背景的区分
40 | :param preds: 网络输出
41 | :param scale: 网络的scale
42 | :param threshold: sigmoid的阈值
43 | :return: 最后的输出图和文本框
44 | """
45 | preds = torch.sigmoid(preds)
46 | preds = preds.detach().cpu().numpy()
47 |
48 | score = preds[-1].astype(np.float32)
49 | preds = preds > threshold
50 | # preds = preds * preds[-1]  # mask the smaller maps with the largest kernel; results are better without it
51 | pred, label_values = pse_warpper(preds, 5)
52 | bbox_list = []
53 | for label_value in label_values:
54 | points = np.array(np.where(pred == label_value)).transpose((1, 0))[:, ::-1]
55 |
56 | if points.shape[0] < 800 / (scale * scale):
57 | continue
58 |
59 | score_i = np.mean(score[pred == label_value])
60 | if score_i < 0.93:
61 | continue
62 |
63 | rect = cv2.minAreaRect(points)
64 | bbox = cv2.boxPoints(rect)
65 | bbox_list.append([bbox[1], bbox[2], bbox[3], bbox[0]])
66 | return pred, np.array(bbox_list)
67 |
--------------------------------------------------------------------------------
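One detail worth noting in `decode` above: the default `threshold=0.7311` is just `sigmoid(1.0)`, i.e. a pixel is treated as text/kernel when its raw logit exceeds 1; the later filters (minimum area of `800 / scale**2` pixels and a mean text score of at least 0.93) then prune small or low-confidence components.

```python
import torch

print(torch.sigmoid(torch.tensor(1.0)).item())  # 0.7310585786300049
```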
/pse/include/pybind11/buffer_info.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/buffer_info.h: Python buffer object interface
3 |
4 | Copyright (c) 2016 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include "detail/common.h"
13 |
14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
15 |
16 | /// Information record describing a Python buffer object
17 | struct buffer_info {
18 | void *ptr = nullptr; // Pointer to the underlying storage
19 | ssize_t itemsize = 0; // Size of individual items in bytes
20 | ssize_t size = 0; // Total number of entries
21 | std::string format;         // For homogeneous buffers, this should be set to format_descriptor<T>::format()
22 | ssize_t ndim = 0; // Number of dimensions
23 | std::vector<ssize_t> shape;    // Shape of the tensor (1 entry per dimension)
24 | std::vector<ssize_t> strides;  // Number of entries between adjacent entries (for each per dimension)
25 |
26 | buffer_info() { }
27 |
28 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
29 | detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
30 | : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
31 | shape(std::move(shape_in)), strides(std::move(strides_in)) {
32 | if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size())
33 | pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length");
34 | for (size_t i = 0; i < (size_t) ndim; ++i)
35 | size *= shape[i];
36 | }
37 |
38 | template <typename T>
39 | buffer_info(T *ptr, detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
40 | : buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor<T>::format(), static_cast<ssize_t>(shape_in->size()), std::move(shape_in), std::move(strides_in)) { }
41 |
42 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size)
43 | : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { }
44 |
45 | template <typename T>
46 | buffer_info(T *ptr, ssize_t size)
47 | : buffer_info(ptr, sizeof(T), format_descriptor<T>::format(), size) { }
48 |
49 | explicit buffer_info(Py_buffer *view, bool ownview = true)
50 | : buffer_info(view->buf, view->itemsize, view->format, view->ndim,
51 | {view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) {
52 | this->view = view;
53 | this->ownview = ownview;
54 | }
55 |
56 | buffer_info(const buffer_info &) = delete;
57 | buffer_info& operator=(const buffer_info &) = delete;
58 |
59 | buffer_info(buffer_info &&other) {
60 | (*this) = std::move(other);
61 | }
62 |
63 | buffer_info& operator=(buffer_info &&rhs) {
64 | ptr = rhs.ptr;
65 | itemsize = rhs.itemsize;
66 | size = rhs.size;
67 | format = std::move(rhs.format);
68 | ndim = rhs.ndim;
69 | shape = std::move(rhs.shape);
70 | strides = std::move(rhs.strides);
71 | std::swap(view, rhs.view);
72 | std::swap(ownview, rhs.ownview);
73 | return *this;
74 | }
75 |
76 | ~buffer_info() {
77 | if (view && ownview) { PyBuffer_Release(view); delete view; }
78 | }
79 |
80 | private:
81 | struct private_ctr_tag { };
82 |
83 | buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
84 | detail::any_container<ssize_t> &&shape_in, detail::any_container<ssize_t> &&strides_in)
85 | : buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { }
86 |
87 | Py_buffer *view = nullptr;
88 | bool ownview = false;
89 | };
90 |
91 | NAMESPACE_BEGIN(detail)
92 |
93 | template <typename T, typename SFINAE = void> struct compare_buffer_info {
94 | static bool compare(const buffer_info& b) {
95 | return b.format == format_descriptor<T>::format() && b.itemsize == (ssize_t) sizeof(T);
96 | }
97 | };
98 |
99 | template <typename T> struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
100 | static bool compare(const buffer_info& b) {
101 | return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor<T>::value ||
102 | ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned<T>::value ? "L" : "l")) ||
103 | ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned<T>::value ? "N" : "n")));
104 | }
105 | };
106 |
107 | NAMESPACE_END(detail)
108 | NAMESPACE_END(PYBIND11_NAMESPACE)
109 |
--------------------------------------------------------------------------------
/pse/include/pybind11/chrono.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/chrono.h: Transparent conversion between std::chrono and python's datetime
3 |
4 | Copyright (c) 2016 Trent Houliston and
5 | Wenzel Jakob
6 |
7 | All rights reserved. Use of this source code is governed by a
8 | BSD-style license that can be found in the LICENSE file.
9 | */
10 |
11 | #pragma once
12 |
13 | #include "pybind11.h"
14 | #include <cmath>
15 | #include <ctime>
16 | #include <chrono>
17 | #include <datetime.h>
18 |
19 | // Backport the PyDateTime_DELTA functions from Python3.3 if required
20 | #ifndef PyDateTime_DELTA_GET_DAYS
21 | #define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days)
22 | #endif
23 | #ifndef PyDateTime_DELTA_GET_SECONDS
24 | #define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds)
25 | #endif
26 | #ifndef PyDateTime_DELTA_GET_MICROSECONDS
27 | #define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds)
28 | #endif
29 |
30 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
31 | NAMESPACE_BEGIN(detail)
32 |
33 | template <typename type> class duration_caster {
34 | public:
35 | typedef typename type::rep rep;
36 | typedef typename type::period period;
37 |
38 | typedef std::chrono::duration<uint_fast32_t, std::ratio<86400>> days;
39 |
40 | bool load(handle src, bool) {
41 | using namespace std::chrono;
42 |
43 | // Lazy initialise the PyDateTime import
44 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
45 |
46 | if (!src) return false;
47 | // If invoked with datetime.delta object
48 | if (PyDelta_Check(src.ptr())) {
49 | value = type(duration_cast<duration<rep, period>>(
50 | days(PyDateTime_DELTA_GET_DAYS(src.ptr()))
51 | + seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr()))
52 | + microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr()))));
53 | return true;
54 | }
55 | // If invoked with a float we assume it is seconds and convert
56 | else if (PyFloat_Check(src.ptr())) {
57 | value = type(duration_cast<duration<rep, period>>(duration<double>(PyFloat_AsDouble(src.ptr()))));
58 | return true;
59 | }
60 | else return false;
61 | }
62 |
63 | // If this is a duration just return it back
64 | static const std::chrono::duration<rep, period>& get_duration(const std::chrono::duration<rep, period> &src) {
65 | return src;
66 | }
67 |
68 | // If this is a time_point get the time_since_epoch
69 | template <typename Clock> static std::chrono::duration<rep, period> get_duration(const std::chrono::time_point<Clock, std::chrono::duration<rep, period>> &src) {
70 | return src.time_since_epoch();
71 | }
72 |
73 | static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) {
74 | using namespace std::chrono;
75 |
76 | // Use overloaded function to get our duration from our source
77 | // Works out if it is a duration or time_point and get the duration
78 | auto d = get_duration(src);
79 |
80 | // Lazy initialise the PyDateTime import
81 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
82 |
83 | // Declare these special duration types so the conversions happen with the correct primitive types (int)
84 | using dd_t = duration<int, std::ratio<86400>>;
85 | using ss_t = duration<int, std::ratio<1>>;
86 | using us_t = duration<int, std::micro>;
87 |
88 | auto dd = duration_cast<dd_t>(d);
89 | auto subd = d - dd;
90 | auto ss = duration_cast<ss_t>(subd);
91 | auto us = duration_cast<us_t>(subd - ss);
92 | return PyDelta_FromDSU(dd.count(), ss.count(), us.count());
93 | }
94 |
95 | PYBIND11_TYPE_CASTER(type, _("datetime.timedelta"));
96 | };
97 |
98 | // This is for casting times on the system clock into datetime.datetime instances
99 | template <typename Duration> class type_caster<std::chrono::time_point<std::chrono::system_clock, Duration>> {
100 | public:
101 | typedef std::chrono::time_point<std::chrono::system_clock, Duration> type;
102 | bool load(handle src, bool) {
103 | using namespace std::chrono;
104 |
105 | // Lazy initialise the PyDateTime import
106 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
107 |
108 | if (!src) return false;
109 | if (PyDateTime_Check(src.ptr())) {
110 | std::tm cal;
111 | cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr());
112 | cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr());
113 | cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr());
114 | cal.tm_mday = PyDateTime_GET_DAY(src.ptr());
115 | cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1;
116 | cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900;
117 | cal.tm_isdst = -1;
118 |
119 | value = system_clock::from_time_t(std::mktime(&cal)) + microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr()));
120 | return true;
121 | }
122 | else return false;
123 | }
124 |
125 | static handle cast(const std::chrono::time_point<std::chrono::system_clock, Duration> &src, return_value_policy /* policy */, handle /* parent */) {
126 | using namespace std::chrono;
127 |
128 | // Lazy initialise the PyDateTime import
129 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
130 |
131 | std::time_t tt = system_clock::to_time_t(src);
132 | // this function uses static memory so it's best to copy it out asap just in case
133 | // otherwise other code that is using localtime may break this (not just python code)
134 | std::tm localtime = *std::localtime(&tt);
135 |
136 | // Declare these special duration types so the conversions happen with the correct primitive types (int)
137 | using us_t = duration<int, std::micro>;
138 |
139 | return PyDateTime_FromDateAndTime(localtime.tm_year + 1900,
140 | localtime.tm_mon + 1,
141 | localtime.tm_mday,
142 | localtime.tm_hour,
143 | localtime.tm_min,
144 | localtime.tm_sec,
145 | (duration_cast<us_t>(src.time_since_epoch() % seconds(1))).count());
146 | }
147 | PYBIND11_TYPE_CASTER(type, _("datetime.datetime"));
148 | };
149 |
150 | // Other clocks that are not the system clock are not measured as datetime.datetime objects
151 | // since they are not measured on calendar time. So instead we just make them timedeltas
152 | // Or if they have passed us a time as a float we convert that
153 | template <typename Clock, typename Duration> class type_caster<std::chrono::time_point<Clock, Duration>>
154 | : public duration_caster<std::chrono::time_point<Clock, Duration>> {
155 | };
156 | 
157 | template <typename Rep, typename Period> class type_caster<std::chrono::duration<Rep, Period>>
158 | : public duration_caster<std::chrono::duration<Rep, Period>> {
159 | };
160 |
161 | NAMESPACE_END(detail)
162 | NAMESPACE_END(PYBIND11_NAMESPACE)
163 |
--------------------------------------------------------------------------------
/pse/include/pybind11/common.h:
--------------------------------------------------------------------------------
1 | #include "detail/common.h"
2 | #warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'."
3 |
--------------------------------------------------------------------------------
/pse/include/pybind11/complex.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/complex.h: Complex number support
3 |
4 | Copyright (c) 2016 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include "pybind11.h"
13 | #include <complex>
14 |
15 | /// glibc defines I as a macro which breaks things, e.g., boost template names
16 | #ifdef I
17 | # undef I
18 | #endif
19 |
20 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
21 |
22 | template <typename T> struct format_descriptor<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
23 | static constexpr const char c = format_descriptor<T>::c;
24 | static constexpr const char value[3] = { 'Z', c, '\0' };
25 | static std::string format() { return std::string(value); }
26 | };
27 |
28 | #ifndef PYBIND11_CPP17
29 |
30 | template <typename T> constexpr const char format_descriptor<
31 | std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
32 |
33 | #endif
34 |
35 | NAMESPACE_BEGIN(detail)
36 |
37 | template <typename T> struct is_fmt_numeric<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
38 | static constexpr bool value = true;
39 | static constexpr int index = is_fmt_numeric<T>::index + 3;
40 | };
41 |
42 | template <typename T> class type_caster<std::complex<T>> {
43 | public:
44 | bool load(handle src, bool convert) {
45 | if (!src)
46 | return false;
47 | if (!convert && !PyComplex_Check(src.ptr()))
48 | return false;
49 | Py_complex result = PyComplex_AsCComplex(src.ptr());
50 | if (result.real == -1.0 && PyErr_Occurred()) {
51 | PyErr_Clear();
52 | return false;
53 | }
54 | value = std::complex<T>((T) result.real, (T) result.imag);
55 | return true;
56 | }
57 |
58 | static handle cast(const std::complex<T> &src, return_value_policy /* policy */, handle /* parent */) {
59 | return PyComplex_FromDoubles((double) src.real(), (double) src.imag());
60 | }
61 |
62 | PYBIND11_TYPE_CASTER(std::complex, _("complex"));
63 | };
64 | NAMESPACE_END(detail)
65 | NAMESPACE_END(PYBIND11_NAMESPACE)
66 |
--------------------------------------------------------------------------------
/pse/include/pybind11/descr.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/descr.h: Helper type for concatenating type signatures
3 | either at runtime (C++11) or compile time (C++14)
4 |
5 | Copyright (c) 2016 Wenzel Jakob
6 |
7 | All rights reserved. Use of this source code is governed by a
8 | BSD-style license that can be found in the LICENSE file.
9 | */
10 |
11 | #pragma once
12 |
13 | #include "common.h"
14 |
15 | NAMESPACE_BEGIN(pybind11)
16 | NAMESPACE_BEGIN(detail)
17 |
18 | /* Concatenate type signatures at compile time using C++14 */
19 | #if defined(PYBIND11_CPP14) && !defined(_MSC_VER)
20 | #define PYBIND11_CONSTEXPR_DESCR
21 |
22 | template <size_t Size1, size_t Size2> class descr {
23 | template <size_t Size1_, size_t Size2_> friend class descr;
24 | public:
25 | constexpr descr(char const (&text) [Size1+1], const std::type_info * const (&types)[Size2+1])
26 | : descr(text, types,
27 | make_index_sequence<Size1>(),
28 | make_index_sequence<Size2>()) { }
29 |
30 | constexpr const char *text() const { return m_text; }
31 | constexpr const std::type_info * const * types() const { return m_types; }
32 |
33 | template <size_t OtherSize1, size_t OtherSize2>
34 | constexpr descr<Size1 + OtherSize1, Size2 + OtherSize2> operator+(const descr<OtherSize1, OtherSize2> &other) const {
35 | return concat(other,
36 | make_index_sequence<Size1>(),
37 | make_index_sequence<Size2>(),
38 | make_index_sequence<OtherSize1>(),
39 | make_index_sequence<OtherSize2>());
40 | }
41 |
42 | protected:
43 | template <size_t... Indices1, size_t... Indices2>
44 | constexpr descr(
45 | char const (&text) [Size1+1],
46 | const std::type_info * const (&types) [Size2+1],
47 | index_sequence<Indices1...>, index_sequence<Indices2...>)
48 | : m_text{text[Indices1]..., '\0'},
49 | m_types{types[Indices2]..., nullptr } {}
50 |
51 | template <size_t OtherSize1, size_t OtherSize2, size_t... Indices1,
52 | size_t... Indices2, size_t... OtherIndices1, size_t... OtherIndices2>
53 | constexpr descr<Size1 + OtherSize1, Size2 + OtherSize2>
54 | concat(const descr<OtherSize1, OtherSize2> &other,
55 | index_sequence<Indices1...>, index_sequence<Indices2...>,
56 | index_sequence<OtherIndices1...>, index_sequence<OtherIndices2...>) const {
57 | return descr<Size1 + OtherSize1, Size2 + OtherSize2>(
58 | { m_text[Indices1]..., other.m_text[OtherIndices1]..., '\0' },
59 | { m_types[Indices2]..., other.m_types[OtherIndices2]..., nullptr }
60 | );
61 | }
62 |
63 | protected:
64 | char m_text[Size1 + 1];
65 | const std::type_info * m_types[Size2 + 1];
66 | };
67 |
68 | template constexpr descr _(char const(&text)[Size]) {
69 | return descr(text, { nullptr });
70 | }
71 |
72 | template <size_t Rem, size_t... Digits> struct int_to_str : int_to_str<Rem/10, Rem%10, Digits...> { };
73 | template <size_t... Digits> struct int_to_str<0, Digits...> {
74 | static constexpr auto digits = descr<sizeof...(Digits), 0>({ ('0' + Digits)..., '\0' }, { nullptr });
75 | };
76 |
77 | // Ternary description (like std::conditional)
78 | template <bool B, size_t Size1, size_t Size2>
79 | constexpr enable_if_t<B, descr<Size1 - 1, 0>> _(char const(&text1)[Size1], char const(&)[Size2]) {
80 | return _(text1);
81 | }
82 | template <bool B, size_t Size1, size_t Size2>
83 | constexpr enable_if_t<!B, descr<Size2 - 1, 0>> _(char const(&)[Size1], char const(&text2)[Size2]) {
84 | return _(text2);
85 | }
86 | template <bool B, size_t SizeA1, size_t SizeA2, size_t SizeB1, size_t SizeB2>
87 | constexpr enable_if_t<B, descr<SizeA1, SizeA2>> _(descr<SizeA1, SizeA2> d, descr<SizeB1, SizeB2>) { return d; }
88 | template <bool B, size_t SizeA1, size_t SizeA2, size_t SizeB1, size_t SizeB2>
89 | constexpr enable_if_t<!B, descr<SizeB1, SizeB2>> _(descr<SizeA1, SizeA2>, descr<SizeB1, SizeB2> d) { return d; }
90 |
91 | template <size_t Size> auto constexpr _() -> decltype(int_to_str<Size / 10, Size % 10>::digits) {
92 | return int_to_str<Size / 10, Size % 10>::digits;
93 | }
94 |
95 | template constexpr descr<1, 1> _() {
96 | return descr<1, 1>({ '%', '\0' }, { &typeid(Type), nullptr });
97 | }
98 |
99 | inline constexpr descr<0, 0> concat() { return _(""); }
100 | template <size_t Size1, size_t Size2> auto constexpr concat(descr<Size1, Size2> descr) { return descr; }
101 | template <size_t Size1, size_t Size2, typename... Args> auto constexpr concat(descr<Size1, Size2> descr, Args&&... args) { return descr + _(", ") + concat(args...); }
102 | template <size_t Size1, size_t Size2> auto constexpr type_descr(descr<Size1, Size2> descr) { return _("{") + descr + _("}"); }
103 |
104 | #define PYBIND11_DESCR constexpr auto
105 |
106 | #else /* Simpler C++11 implementation based on run-time memory allocation and copying */
107 |
108 | class descr {
109 | public:
110 | PYBIND11_NOINLINE descr(const char *text, const std::type_info * const * types) {
111 | size_t nChars = len(text), nTypes = len(types);
112 | m_text = new char[nChars];
113 | m_types = new const std::type_info *[nTypes];
114 | memcpy(m_text, text, nChars * sizeof(char));
115 | memcpy(m_types, types, nTypes * sizeof(const std::type_info *));
116 | }
117 |
118 | PYBIND11_NOINLINE descr operator+(descr &&d2) && {
119 | descr r;
120 |
121 | size_t nChars1 = len(m_text), nTypes1 = len(m_types);
122 | size_t nChars2 = len(d2.m_text), nTypes2 = len(d2.m_types);
123 |
124 | r.m_text = new char[nChars1 + nChars2 - 1];
125 | r.m_types = new const std::type_info *[nTypes1 + nTypes2 - 1];
126 | memcpy(r.m_text, m_text, (nChars1-1) * sizeof(char));
127 | memcpy(r.m_text + nChars1 - 1, d2.m_text, nChars2 * sizeof(char));
128 | memcpy(r.m_types, m_types, (nTypes1-1) * sizeof(std::type_info *));
129 | memcpy(r.m_types + nTypes1 - 1, d2.m_types, nTypes2 * sizeof(std::type_info *));
130 |
131 | delete[] m_text; delete[] m_types;
132 | delete[] d2.m_text; delete[] d2.m_types;
133 |
134 | return r;
135 | }
136 |
137 | char *text() { return m_text; }
138 | const std::type_info * * types() { return m_types; }
139 |
140 | protected:
141 | PYBIND11_NOINLINE descr() { }
142 |
143 | template <typename T> static size_t len(const T *ptr) { // return length including null termination
144 | const T *it = ptr;
145 | while (*it++ != (T) 0)
146 | ;
147 | return static_cast<size_t>(it - ptr);
148 | }
149 |
150 | const std::type_info **m_types = nullptr;
151 | char *m_text = nullptr;
152 | };
153 |
154 | /* The 'PYBIND11_NOINLINE inline' combinations below are intentional to get the desired linkage while producing as little object code as possible */
155 |
156 | PYBIND11_NOINLINE inline descr _(const char *text) {
157 | const std::type_info *types[1] = { nullptr };
158 | return descr(text, types);
159 | }
160 |
161 | template <bool B> PYBIND11_NOINLINE enable_if_t<B, descr> _(const char *text1, const char *) { return _(text1); }
162 | template <bool B> PYBIND11_NOINLINE enable_if_t<!B, descr> _(char const *, const char *text2) { return _(text2); }
163 | template <bool B> PYBIND11_NOINLINE enable_if_t<B, descr> _(descr d, descr) { return d; }
164 | template <bool B> PYBIND11_NOINLINE enable_if_t<!B, descr> _(descr, descr d) { return d; }
165 |
166 | template <typename Type> PYBIND11_NOINLINE descr _() {
167 | const std::type_info *types[2] = { &typeid(Type), nullptr };
168 | return descr("%", types);
169 | }
170 |
171 | template <size_t Size> PYBIND11_NOINLINE descr _() {
172 | const std::type_info *types[1] = { nullptr };
173 | return descr(std::to_string(Size).c_str(), types);
174 | }
175 |
176 | PYBIND11_NOINLINE inline descr concat() { return _(""); }
177 | PYBIND11_NOINLINE inline descr concat(descr &&d) { return d; }
178 | template PYBIND11_NOINLINE descr concat(descr &&d, Args&&... args) { return std::move(d) + _(", ") + concat(std::forward(args)...); }
179 | PYBIND11_NOINLINE inline descr type_descr(descr&& d) { return _("{") + std::move(d) + _("}"); }
180 |
181 | #define PYBIND11_DESCR ::pybind11::detail::descr
182 | #endif
183 |
184 | NAMESPACE_END(detail)
185 | NAMESPACE_END(pybind11)
186 |
--------------------------------------------------------------------------------
/pse/include/pybind11/detail/descr.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time
3 |
4 | Copyright (c) 2016 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include "common.h"
13 |
14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
15 | NAMESPACE_BEGIN(detail)
16 |
17 | #if !defined(_MSC_VER)
18 | # define PYBIND11_DESCR_CONSTEXPR static constexpr
19 | #else
20 | # define PYBIND11_DESCR_CONSTEXPR const
21 | #endif
22 |
23 | /* Concatenate type signatures at compile time */
24 | template <size_t N, typename... Ts>
25 | struct descr {
26 | char text[N + 1];
27 |
28 | constexpr descr() : text{'\0'} { }
29 | constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence()) { }
30 |
31 | template <size_t... Is>
32 | constexpr descr(char const (&s)[N+1], index_sequence<Is...>) : text{s[Is]..., '\0'} { }
33 |
34 | template <typename... Chars>
35 | constexpr descr(char c, Chars... cs) : text{c, static_cast<char>(cs)..., '\0'} { }
36 |
37 | static constexpr std::array<const std::type_info *, sizeof...(Ts) + 1> types() {
38 | return {{&typeid(Ts)..., nullptr}};
39 | }
40 | };
41 |
42 | template <size_t N1, size_t N2, typename... Ts1, typename... Ts2, size_t... Is1, size_t... Is2>
43 | constexpr descr<N1 + N2, Ts1..., Ts2...> plus_impl(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b,
44 | index_sequence<Is1...>, index_sequence<Is2...>) {
45 | return {a.text[Is1]..., b.text[Is2]...};
46 | }
47 |
48 | template <size_t N1, size_t N2, typename... Ts1, typename... Ts2>
49 | constexpr descr<N1 + N2, Ts1..., Ts2...> operator+(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b) {
50 | return plus_impl(a, b, make_index_sequence<N1>(), make_index_sequence<N2>());
51 | }
52 |
53 | template <size_t N>
54 | constexpr descr<N - 1> _(char const(&text)[N]) { return descr<N - 1>(text); }
55 | constexpr descr<0> _(char const(&)[1]) { return {}; }
56 |
57 | template <size_t Rem, size_t... Digits> struct int_to_str : int_to_str<Rem/10, Rem%10, Digits...> { };
58 | template <size_t... Digits> struct int_to_str<0, Digits...> {
59 | static constexpr auto digits = descr<sizeof...(Digits)>(('0' + Digits)...);
60 | };
61 |
62 | // Ternary description (like std::conditional)
63 | template <bool B, size_t N1, size_t N2>
64 | constexpr enable_if_t<B, descr<N1 - 1>> _(char const(&text1)[N1], char const(&)[N2]) {
65 | return _(text1);
66 | }
67 | template <bool B, size_t N1, size_t N2>
68 | constexpr enable_if_t<!B, descr<N2 - 1>> _(char const(&)[N1], char const(&text2)[N2]) {
69 | return _(text2);
70 | }
71 |
72 | template <bool B, typename T1, typename T2>
73 | constexpr enable_if_t<B, T1> _(const T1 &d, const T2 &) { return d; }
74 | template <bool B, typename T1, typename T2>
75 | constexpr enable_if_t<!B, T2> _(const T1 &, const T2 &d) { return d; }
76 |
77 | template <size_t Size> auto constexpr _() -> decltype(int_to_str<Size / 10, Size % 10>::digits) {
78 | return int_to_str<Size / 10, Size % 10>::digits;
79 | }
80 |
81 | template <typename Type> constexpr descr<1, Type> _() { return {'%'}; }
82 |
83 | constexpr descr<0> concat() { return {}; }
84 |
85 | template <size_t N, typename... Ts>
86 | constexpr descr<N, Ts...> concat(const descr<N, Ts...> &descr) { return descr; }
87 |
88 | template <size_t N, typename... Ts, typename... Args>
89 | constexpr auto concat(const descr<N, Ts...> &d, const Args &...args)
90 | -> decltype(std::declval<descr<N + 2, Ts...>>() + concat(args...)) {
91 | return d + _(", ") + concat(args...);
92 | }
93 |
94 | template <size_t N, typename... Ts>
95 | constexpr descr<N + 2, Ts...> type_descr(const descr<N, Ts...> &descr) {
96 | return _("{") + descr + _("}");
97 | }
98 |
99 | NAMESPACE_END(detail)
100 | NAMESPACE_END(PYBIND11_NAMESPACE)
101 |
--------------------------------------------------------------------------------
/pse/include/pybind11/detail/internals.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/detail/internals.h: Internal data structure and related functions
3 |
4 | Copyright (c) 2017 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include "../pytypes.h"
13 |
14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
15 | NAMESPACE_BEGIN(detail)
16 | // Forward declarations
17 | inline PyTypeObject *make_static_property_type();
18 | inline PyTypeObject *make_default_metaclass();
19 | inline PyObject *make_object_base_type(PyTypeObject *metaclass);
20 |
21 | // The old Python Thread Local Storage (TLS) API is deprecated in Python 3.7 in favor of the new
22 | // Thread Specific Storage (TSS) API.
23 | #if PY_VERSION_HEX >= 0x03070000
24 | # define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr
25 | # define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key))
26 | # define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (value))
27 | # define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr)
28 | #else
29 | // Usually an int but a long on Cygwin64 with Python 3.x
30 | # define PYBIND11_TLS_KEY_INIT(var) decltype(PyThread_create_key()) var = 0
31 | # define PYBIND11_TLS_GET_VALUE(key) PyThread_get_key_value((key))
32 | # if PY_MAJOR_VERSION < 3
33 | # define PYBIND11_TLS_DELETE_VALUE(key) \
34 | PyThread_delete_key_value(key)
35 | # define PYBIND11_TLS_REPLACE_VALUE(key, value) \
36 | do { \
37 | PyThread_delete_key_value((key)); \
38 | PyThread_set_key_value((key), (value)); \
39 | } while (false)
40 | # else
41 | # define PYBIND11_TLS_DELETE_VALUE(key) \
42 | PyThread_set_key_value((key), nullptr)
43 | # define PYBIND11_TLS_REPLACE_VALUE(key, value) \
44 | PyThread_set_key_value((key), (value))
45 | # endif
46 | #endif
47 |
48 | // Python loads modules by default with dlopen with the RTLD_LOCAL flag; under libc++ and possibly
49 | // other STLs, this means `typeid(A)` from one module won't equal `typeid(A)` from another module
50 | // even when `A` is the same, non-hidden-visibility type (e.g. from a common include). Under
51 | // libstdc++, this doesn't happen: equality and the type_index hash are based on the type name,
52 | // which works. If not under a known-good stl, provide our own name-based hash and equality
53 | // functions that use the type name.
54 | #if defined(__GLIBCXX__)
55 | inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { return lhs == rhs; }
56 | using type_hash = std::hash<std::type_index>;
57 | using type_equal_to = std::equal_to<std::type_index>;
58 | #else
59 | inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) {
60 | return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
61 | }
62 |
63 | struct type_hash {
64 | size_t operator()(const std::type_index &t) const {
65 | size_t hash = 5381;
66 | const char *ptr = t.name();
67 | while (auto c = static_cast<unsigned char>(*ptr++))
68 | hash = (hash * 33) ^ c;
69 | return hash;
70 | }
71 | };
72 |
73 | struct type_equal_to {
74 | bool operator()(const std::type_index &lhs, const std::type_index &rhs) const {
75 | return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0;
76 | }
77 | };
78 | #endif
79 |
80 | template <typename value_type>
81 | using type_map = std::unordered_map<std::type_index, value_type, type_hash, type_equal_to>;
82 |
83 | struct overload_hash {
84 | inline size_t operator()(const std::pair<const PyObject *, const char *>& v) const {
85 | size_t value = std::hash<const void *>()(v.first);
86 | value ^= std::hash<const char *>()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2);
87 | return value;
88 | }
89 | };
90 |
91 | /// Internal data structure used to track registered instances and types.
92 | /// Whenever binary incompatible changes are made to this structure,
93 | /// `PYBIND11_INTERNALS_VERSION` must be incremented.
94 | struct internals {
95 | type_map<type_info *> registered_types_cpp; // std::type_index -> pybind11's type information
96 | std::unordered_map<PyTypeObject *, std::vector<type_info *>> registered_types_py; // PyTypeObject* -> base type_info(s)
97 | std::unordered_multimap<const void *, instance*> registered_instances; // void * -> instance*
98 | std::unordered_set<std::pair<const PyObject *, const char *>, overload_hash> inactive_overload_cache;
99 | type_map<std::vector<bool (*)(PyObject *, void *&)>> direct_conversions;
100 | std::unordered_map<const PyObject *, std::vector<PyObject *>> patients;
101 | std::forward_list<void (*) (std::exception_ptr)> registered_exception_translators;
102 | std::unordered_map<std::string, void *> shared_data; // Custom data to be shared across extensions
103 | std::vector<PyObject *> loader_patient_stack; // Used by `loader_life_support`
104 | std::forward_list<std::string> static_strings; // Stores the std::strings backing detail::c_str()
105 | PyTypeObject *static_property_type;
106 | PyTypeObject *default_metaclass;
107 | PyObject *instance_base;
108 | #if defined(WITH_THREAD)
109 | PYBIND11_TLS_KEY_INIT(tstate);
110 | PyInterpreterState *istate = nullptr;
111 | #endif
112 | };
113 |
114 | /// Additional type information which does not fit into the PyTypeObject.
115 | /// Changes to this struct also require bumping `PYBIND11_INTERNALS_VERSION`.
116 | struct type_info {
117 | PyTypeObject *type;
118 | const std::type_info *cpptype;
119 | size_t type_size, type_align, holder_size_in_ptrs;
120 | void *(*operator_new)(size_t);
121 | void (*init_instance)(instance *, const void *);
122 | void (*dealloc)(value_and_holder &v_h);
123 | std::vector<PyObject *(*)(PyObject *, PyTypeObject *)> implicit_conversions;
124 | std::vector<std::pair<const std::type_info *, void *(*)(void *)>> implicit_casts;
125 | std::vector<bool (*)(PyObject *, void *&)> *direct_conversions;
126 | buffer_info *(*get_buffer)(PyObject *, void *) = nullptr;
127 | void *get_buffer_data = nullptr;
128 | void *(*module_local_load)(PyObject *, const type_info *) = nullptr;
129 | /* A simple type never occurs as a (direct or indirect) parent
130 | * of a class that makes use of multiple inheritance */
131 | bool simple_type : 1;
132 | /* True if there is no multiple inheritance in this type's inheritance tree */
133 | bool simple_ancestors : 1;
134 | /* for base vs derived holder_type checks */
135 | bool default_holder : 1;
136 | /* true if this is a type registered with py::module_local */
137 | bool module_local : 1;
138 | };
139 |
140 | /// Tracks the `internals` and `type_info` ABI version independent of the main library version
141 | #define PYBIND11_INTERNALS_VERSION 3
142 |
143 | #if defined(_DEBUG)
144 | # define PYBIND11_BUILD_TYPE "_debug"
145 | #else
146 | # define PYBIND11_BUILD_TYPE ""
147 | #endif
148 |
149 | #if defined(WITH_THREAD)
150 | # define PYBIND11_INTERNALS_KIND ""
151 | #else
152 | # define PYBIND11_INTERNALS_KIND "_without_thread"
153 | #endif
154 |
155 | #define PYBIND11_INTERNALS_ID "__pybind11_internals_v" \
156 | PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
157 |
158 | #define PYBIND11_MODULE_LOCAL_ID "__pybind11_module_local_v" \
159 | PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__"
160 |
161 | /// Each module locally stores a pointer to the `internals` data. The data
162 | /// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`.
163 | inline internals **&get_internals_pp() {
164 | static internals **internals_pp = nullptr;
165 | return internals_pp;
166 | }
167 |
168 | /// Return a reference to the current `internals` data
169 | PYBIND11_NOINLINE inline internals &get_internals() {
170 | auto **&internals_pp = get_internals_pp();
171 | if (internals_pp && *internals_pp)
172 | return **internals_pp;
173 |
174 | constexpr auto *id = PYBIND11_INTERNALS_ID;
175 | auto builtins = handle(PyEval_GetBuiltins());
176 | if (builtins.contains(id) && isinstance<capsule>(builtins[id])) {
177 | internals_pp = static_cast<internals **>(capsule(builtins[id]));
178 |
179 | // We loaded builtins through python's builtins, which means that our `error_already_set`
180 | // and `builtin_exception` may be different local classes than the ones set up in the
181 | // initial exception translator, below, so add another for our local exception classes.
182 | //
183 | // libstdc++ doesn't require this (types there are identified only by name)
184 | #if !defined(__GLIBCXX__)
185 | (*internals_pp)->registered_exception_translators.push_front(
186 | [](std::exception_ptr p) -> void {
187 | try {
188 | if (p) std::rethrow_exception(p);
189 | } catch (error_already_set &e) { e.restore(); return;
190 | } catch (const builtin_exception &e) { e.set_error(); return;
191 | }
192 | }
193 | );
194 | #endif
195 | } else {
196 | if (!internals_pp) internals_pp = new internals*();
197 | auto *&internals_ptr = *internals_pp;
198 | internals_ptr = new internals();
199 | #if defined(WITH_THREAD)
200 | PyEval_InitThreads();
201 | PyThreadState *tstate = PyThreadState_Get();
202 | #if PY_VERSION_HEX >= 0x03070000
203 | internals_ptr->tstate = PyThread_tss_alloc();
204 | if (!internals_ptr->tstate || PyThread_tss_create(internals_ptr->tstate))
205 | pybind11_fail("get_internals: could not successfully initialize the TSS key!");
206 | PyThread_tss_set(internals_ptr->tstate, tstate);
207 | #else
208 | internals_ptr->tstate = PyThread_create_key();
209 | if (internals_ptr->tstate == -1)
210 | pybind11_fail("get_internals: could not successfully initialize the TLS key!");
211 | PyThread_set_key_value(internals_ptr->tstate, tstate);
212 | #endif
213 | internals_ptr->istate = tstate->interp;
214 | #endif
215 | builtins[id] = capsule(internals_pp);
216 | internals_ptr->registered_exception_translators.push_front(
217 | [](std::exception_ptr p) -> void {
218 | try {
219 | if (p) std::rethrow_exception(p);
220 | } catch (error_already_set &e) { e.restore(); return;
221 | } catch (const builtin_exception &e) { e.set_error(); return;
222 | } catch (const std::bad_alloc &e) { PyErr_SetString(PyExc_MemoryError, e.what()); return;
223 | } catch (const std::domain_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
224 | } catch (const std::invalid_argument &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
225 | } catch (const std::length_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
226 | } catch (const std::out_of_range &e) { PyErr_SetString(PyExc_IndexError, e.what()); return;
227 | } catch (const std::range_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return;
228 | } catch (const std::exception &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); return;
229 | } catch (...) {
230 | PyErr_SetString(PyExc_RuntimeError, "Caught an unknown exception!");
231 | return;
232 | }
233 | }
234 | );
235 | internals_ptr->static_property_type = make_static_property_type();
236 | internals_ptr->default_metaclass = make_default_metaclass();
237 | internals_ptr->instance_base = make_object_base_type(internals_ptr->default_metaclass);
238 | }
239 | return **internals_pp;
240 | }
241 |
242 | /// Works like `internals.registered_types_cpp`, but for module-local registered types:
243 | inline type_map ®istered_local_types_cpp() {
244 | static type_map locals{};
245 | return locals;
246 | }
247 |
248 | /// Constructs a std::string with the given arguments, stores it in `internals`, and returns its
249 | /// `c_str()`. Such string objects have a long storage duration -- the internal strings are only
250 | /// cleared when the program exits or after interpreter shutdown (when embedding), and so are
251 | /// suitable for c-style strings needed by Python internals (such as PyTypeObject's tp_name).
252 | template <typename... Args>
253 | const char *c_str(Args &&...args) {
254 | auto &strings = get_internals().static_strings;
255 | strings.emplace_front(std::forward<Args>(args)...);
256 | return strings.front().c_str();
257 | }
258 |
259 | NAMESPACE_END(detail)
260 |
261 | /// Returns a named pointer that is shared among all extension modules (using the same
262 | /// pybind11 version) running in the current interpreter. Names starting with underscores
263 | /// are reserved for internal usage. Returns `nullptr` if no matching entry was found.
264 | inline PYBIND11_NOINLINE void *get_shared_data(const std::string &name) {
265 | auto &internals = detail::get_internals();
266 | auto it = internals.shared_data.find(name);
267 | return it != internals.shared_data.end() ? it->second : nullptr;
268 | }
269 |
270 | /// Set the shared data that can be later recovered by `get_shared_data()`.
271 | inline PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) {
272 | detail::get_internals().shared_data[name] = data;
273 | return data;
274 | }
275 |
276 | /// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if
277 | /// such an entry exists. Otherwise, a new object of default-constructible type `T` is
278 | /// added to the shared data under the given name and a reference to it is returned.
279 | template <typename T>
280 | T &get_or_create_shared_data(const std::string &name) {
281 | auto &internals = detail::get_internals();
282 | auto it = internals.shared_data.find(name);
283 | T *ptr = (T *) (it != internals.shared_data.end() ? it->second : nullptr);
284 | if (!ptr) {
285 | ptr = new T();
286 | internals.shared_data[name] = ptr;
287 | }
288 | return *ptr;
289 | }
290 |
291 | NAMESPACE_END(PYBIND11_NAMESPACE)
292 |
--------------------------------------------------------------------------------
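The `shared_data` map in `internals` is the public extension point here: `set_shared_data()`, `get_shared_data()`, and `get_or_create_shared_data<T>()` let separate extension modules that share the same `PYBIND11_INTERNALS_ID` exchange a named pointer. A small sketch of typical use (the module name `example`, the `counter` struct, and the key `"my_counter"` are invented for illustration):

```cpp
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical state we want every extension module in the process to share.
struct counter { int value = 0; };

PYBIND11_MODULE(example, m) {
    m.def("bump", []() {
        // The first caller default-constructs the entry; every later call --
        // even from another module built against the same pybind11 internals
        // version -- gets a reference to the same object.
        auto &c = py::get_or_create_shared_data<counter>("my_counter");
        return ++c.value;
    });
}
```

As the comment above `get_shared_data()` notes, names starting with underscores are reserved for pybind11's own use, so pick a distinctive key.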
/pse/include/pybind11/detail/typeid.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/detail/typeid.h: Compiler-independent access to type identifiers
3 |
4 | Copyright (c) 2016 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include <cstdio>
13 | #include <cstdlib>
14 |
15 | #if defined(__GNUG__)
16 | #include <cxxabi.h>
17 | #endif
18 |
19 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
20 | NAMESPACE_BEGIN(detail)
21 | /// Erase all occurrences of a substring
22 | inline void erase_all(std::string &string, const std::string &search) {
23 | for (size_t pos = 0;;) {
24 | pos = string.find(search, pos);
25 | if (pos == std::string::npos) break;
26 | string.erase(pos, search.length());
27 | }
28 | }
29 |
30 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) {
31 | #if defined(__GNUG__)
32 | int status = 0;
33 | std::unique_ptr<char, void (*)(void *)> res {
34 | abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free };
35 | if (status == 0)
36 | name = res.get();
37 | #else
38 | detail::erase_all(name, "class ");
39 | detail::erase_all(name, "struct ");
40 | detail::erase_all(name, "enum ");
41 | #endif
42 | detail::erase_all(name, "pybind11::");
43 | }
44 | NAMESPACE_END(detail)
45 |
46 | /// Return a string representation of a C++ type
47 | template <typename T> static std::string type_id() {
48 | std::string name(typeid(T).name());
49 | detail::clean_type_id(name);
50 | return name;
51 | }
52 |
53 | NAMESPACE_END(PYBIND11_NAMESPACE)
54 |
--------------------------------------------------------------------------------
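`type_id<T>()` is the user-facing piece of this header: it runs `typeid(T).name()` through `abi::__cxa_demangle` on GCC/Clang, strips the `class `/`struct `/`enum ` prefixes on other compilers, and then erases any `pybind11::` qualifiers. A quick sketch of calling it (the exact output spelling is compiler- and ABI-dependent):

```cpp
#include <pybind11/pybind11.h>

#include <iostream>
#include <map>
#include <string>

namespace py = pybind11;

int main() {
    // type_id() only touches C++ RTTI, so no running Python interpreter is required.
    std::cout << py::type_id<std::map<std::string, int>>() << "\n";
    // On GCC/Clang this prints a demangled name such as
    //   std::map<std::__cxx11::basic_string<...>, int, ...>
    // whereas the raw typeid(...).name() would be a mangled string.
}
```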
/pse/include/pybind11/embed.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/embed.h: Support for embedding the interpreter
3 |
4 | Copyright (c) 2017 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include "pybind11.h"
13 | #include "eval.h"
14 |
15 | #if defined(PYPY_VERSION)
16 | # error Embedding the interpreter is not supported with PyPy
17 | #endif
18 |
19 | #if PY_MAJOR_VERSION >= 3
20 | # define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
21 | extern "C" PyObject *pybind11_init_impl_##name() { \
22 | return pybind11_init_wrapper_##name(); \
23 | }
24 | #else
25 | # define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
26 | extern "C" void pybind11_init_impl_##name() { \
27 | pybind11_init_wrapper_##name(); \
28 | }
29 | #endif
30 |
31 | /** \rst
32 | Add a new module to the table of builtins for the interpreter. Must be
33 | defined in global scope. The first macro parameter is the name of the
34 | module (without quotes). The second parameter is the variable which will
35 | be used as the interface to add functions and classes to the module.
36 |
37 | .. code-block:: cpp
38 |
39 | PYBIND11_EMBEDDED_MODULE(example, m) {
40 | // ... initialize functions and classes here
41 | m.def("foo", []() {
42 | return "Hello, World!";
43 | });
44 | }
45 | \endrst */
46 | #define PYBIND11_EMBEDDED_MODULE(name, variable) \
47 | static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
48 | static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \
49 | auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
50 | try { \
51 | PYBIND11_CONCAT(pybind11_init_, name)(m); \
52 | return m.ptr(); \
53 | } catch (pybind11::error_already_set &e) { \
54 | PyErr_SetString(PyExc_ImportError, e.what()); \
55 | return nullptr; \
56 | } catch (const std::exception &e) { \
57 | PyErr_SetString(PyExc_ImportError, e.what()); \
58 | return nullptr; \
59 | } \
60 | } \
61 | PYBIND11_EMBEDDED_MODULE_IMPL(name) \
62 | pybind11::detail::embedded_module name(PYBIND11_TOSTRING(name), \
63 | PYBIND11_CONCAT(pybind11_init_impl_, name)); \
64 | void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable)
65 |
66 |
67 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
68 | NAMESPACE_BEGIN(detail)
69 |
70 | /// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks.
71 | struct embedded_module {
72 | #if PY_MAJOR_VERSION >= 3
73 | using init_t = PyObject *(*)();
74 | #else
75 | using init_t = void (*)();
76 | #endif
77 | embedded_module(const char *name, init_t init) {
78 | if (Py_IsInitialized())
79 | pybind11_fail("Can't add new modules after the interpreter has been initialized");
80 |
81 | auto result = PyImport_AppendInittab(name, init);
82 | if (result == -1)
83 | pybind11_fail("Insufficient memory to add a new module");
84 | }
85 | };
86 |
87 | NAMESPACE_END(detail)
88 |
89 | /** \rst
90 | Initialize the Python interpreter. No other pybind11 or CPython API functions can be
91 | called before this is done, with the exception of `PYBIND11_EMBEDDED_MODULE`. The
92 | optional parameter can be used to skip the registration of signal handlers (see the
93 | `Python documentation`_ for details). Calling this function again after the interpreter
94 | has already been initialized is a fatal error.
95 |
96 | If initializing the Python interpreter fails, then the program is terminated. (This
97 | is controlled by the CPython runtime and is an exception to pybind11's normal behavior
98 | of throwing exceptions on errors.)
99 |
100 | .. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx
101 | \endrst */
102 | inline void initialize_interpreter(bool init_signal_handlers = true) {
103 | if (Py_IsInitialized())
104 | pybind11_fail("The interpreter is already running");
105 |
106 | Py_InitializeEx(init_signal_handlers ? 1 : 0);
107 |
108 | // Make .py files in the working directory available by default
109 | module::import("sys").attr("path").cast<list>().append(".");
110 | }
111 |
112 | /** \rst
113 | Shut down the Python interpreter. No pybind11 or CPython API functions can be called
114 | after this. In addition, pybind11 objects must not outlive the interpreter:
115 |
116 | .. code-block:: cpp
117 |
118 | { // BAD
119 | py::initialize_interpreter();
120 | auto hello = py::str("Hello, World!");
121 | py::finalize_interpreter();
122 | } // <-- BOOM, hello's destructor is called after interpreter shutdown
123 |
124 | { // GOOD
125 | py::initialize_interpreter();
126 | { // scoped
127 | auto hello = py::str("Hello, World!");
128 | } // <-- OK, hello is cleaned up properly
129 | py::finalize_interpreter();
130 | }
131 |
132 | { // BETTER
133 | py::scoped_interpreter guard{};
134 | auto hello = py::str("Hello, World!");
135 | }
136 |
137 | .. warning::
138 |
139 | The interpreter can be restarted by calling `initialize_interpreter` again.
140 | Modules created using pybind11 can be safely re-initialized. However, Python
141 | itself cannot completely unload binary extension modules and there are several
142 | caveats with regard to interpreter restarting. All the details can be found
143 | in the CPython documentation. In short, not all interpreter memory may be
144 | freed, either due to reference cycles or user-created global data.
145 |
146 | \endrst */
147 | inline void finalize_interpreter() {
148 | handle builtins(PyEval_GetBuiltins());
149 | const char *id = PYBIND11_INTERNALS_ID;
150 |
151 | // Get the internals pointer (without creating it if it doesn't exist). It's possible for the
152 | // internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()`
153 | // during destruction), so we get the pointer-pointer here and check it after Py_Finalize().
154 | detail::internals **internals_ptr_ptr = detail::get_internals_pp();
155 | // It could also be stashed in builtins, so look there too:
156 | if (builtins.contains(id) && isinstance<capsule>(builtins[id]))
157 | internals_ptr_ptr = capsule(builtins[id]);
158 |
159 | Py_Finalize();
160 |
161 | if (internals_ptr_ptr) {
162 | delete *internals_ptr_ptr;
163 | *internals_ptr_ptr = nullptr;
164 | }
165 | }
166 |
167 | /** \rst
168 | Scope guard version of `initialize_interpreter` and `finalize_interpreter`.
169 | This is a move-only guard, and only a single instance can exist.
170 |
171 | .. code-block:: cpp
172 |
173 | #include <pybind11/embed.h>
174 |
175 | int main() {
176 | py::scoped_interpreter guard{};
177 | py::print("Hello, World!");
178 | } // <-- interpreter shutdown
179 | \endrst */
180 | class scoped_interpreter {
181 | public:
182 | scoped_interpreter(bool init_signal_handlers = true) {
183 | initialize_interpreter(init_signal_handlers);
184 | }
185 |
186 | scoped_interpreter(const scoped_interpreter &) = delete;
187 | scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; }
188 | scoped_interpreter &operator=(const scoped_interpreter &) = delete;
189 | scoped_interpreter &operator=(scoped_interpreter &&) = delete;
190 |
191 | ~scoped_interpreter() {
192 | if (is_valid)
193 | finalize_interpreter();
194 | }
195 |
196 | private:
197 | bool is_valid = true;
198 | };
199 |
200 | NAMESPACE_END(PYBIND11_NAMESPACE)
201 |
--------------------------------------------------------------------------------
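Tying embed.h together: `PYBIND11_EMBEDDED_MODULE` registers a builtin module before the interpreter starts, and `scoped_interpreter` pairs `initialize_interpreter()` with `finalize_interpreter()` through RAII. A minimal embedding sketch (the module name `fast_calc` and its `add` function are invented for the example; build and link against libpython as usual when embedding):

```cpp
#include <pybind11/embed.h>

namespace py = pybind11;

// Registered with the interpreter's builtin table; must live at global scope.
PYBIND11_EMBEDDED_MODULE(fast_calc, m) {
    m.def("add", [](int a, int b) { return a + b; });
}

int main() {
    py::scoped_interpreter guard{}; // starts the interpreter

    auto calc = py::module::import("fast_calc");
    int result = calc.attr("add")(1, 2).cast<int>();
    py::print("1 + 2 =", result);
} // guard's destructor calls finalize_interpreter()
```

Keeping all pybind11 objects inside the guard's scope avoids the "BAD" pattern shown in the `finalize_interpreter` docstring above.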
/pse/include/pybind11/eval.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/eval.h: Support for evaluating Python expressions and statements
3 | from strings and files
4 |
5 | Copyright (c) 2016 Klemens Morgenstern and
6 | Wenzel Jakob
7 |
8 | All rights reserved. Use of this source code is governed by a
9 | BSD-style license that can be found in the LICENSE file.
10 | */
11 |
12 | #pragma once
13 |
14 | #include "pybind11.h"
15 |
16 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
17 |
18 | enum eval_mode {
19 | /// Evaluate a string containing an isolated expression
20 | eval_expr,
21 |
22 | /// Evaluate a string containing a single statement. Returns \c none
23 | eval_single_statement,
24 |
25 | /// Evaluate a string containing a sequence of statements. Returns \c none
26 | eval_statements
27 | };
28 |
29 | template <eval_mode mode = eval_expr>
30 | object eval(str expr, object global = globals(), object local = object()) {
31 | if (!local)
32 | local = global;
33 |
34 | /* PyRun_String does not accept a PyObject / encoding specifier,
35 | this seems to be the only alternative */
36 | std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr;
37 |
38 | int start;
39 | switch (mode) {
40 | case eval_expr: start = Py_eval_input; break;
41 | case eval_single_statement: start = Py_single_input; break;
42 | case eval_statements: start = Py_file_input; break;
43 | default: pybind11_fail("invalid evaluation mode");
44 | }
45 |
46 | PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr());
47 | if (!result)
48 | throw error_already_set();
49 | return reinterpret_steal