├── .gitignore ├── LICENSE ├── eval.py ├── figure ├── result0.jpg ├── result1.jpg ├── result2.jpg ├── result3.jpg ├── result4.jpg └── result5.jpg ├── nets ├── __init__.py ├── model.py └── resnet │ ├── __init__.py │ ├── resnet_utils.py │ └── resnet_v1.py ├── pse ├── Makefile ├── __init__.py ├── include │ └── pybind11 │ │ ├── attr.h │ │ ├── buffer_info.h │ │ ├── cast.h │ │ ├── chrono.h │ │ ├── class_support.h │ │ ├── common.h │ │ ├── complex.h │ │ ├── descr.h │ │ ├── detail │ │ ├── class.h │ │ ├── common.h │ │ ├── descr.h │ │ ├── init.h │ │ ├── internals.h │ │ └── typeid.h │ │ ├── eigen.h │ │ ├── embed.h │ │ ├── eval.h │ │ ├── functional.h │ │ ├── iostream.h │ │ ├── numpy.h │ │ ├── operators.h │ │ ├── options.h │ │ ├── pybind11.h │ │ ├── pytypes.h │ │ ├── stl.h │ │ ├── stl_bind.h │ │ └── typeid.h └── pse.cpp ├── readme.md ├── train.py └── utils ├── __init__.py ├── data_provider ├── __init__.py ├── data_provider.py └── data_util.py └── utils_tool.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | *.pyc 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Michael liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import cv2 3 | import time 4 | import os 5 | import numpy as np 6 | import tensorflow as tf 7 | from tensorflow.python.client import timeline 8 | from utils.utils_tool import logger, cfg 9 | import matplotlib.pyplot as plt 10 | 11 | tf.app.flags.DEFINE_string('test_data_path', None, '') 12 | tf.app.flags.DEFINE_string('gpu_list', '0', '') 13 | tf.app.flags.DEFINE_string('checkpoint_path', './', '') 14 | tf.app.flags.DEFINE_string('output_dir', './results/', '') 15 | tf.app.flags.DEFINE_bool('no_write_images', False, 'do not write images') 16 | 17 | from nets import model 18 | from pse import pse 19 | 20 | FLAGS = tf.app.flags.FLAGS 21 | 22 | logger.setLevel(cfg.debug) 23 | 24 | def get_images(): 25 | ''' 26 | find image files in test data path 27 | :return: list of files found 28 | ''' 29 | files = [] 30 | exts = ['jpg', 'png', 'jpeg', 'JPG'] 31 | for parent, dirnames, filenames in os.walk(FLAGS.test_data_path): 32 | for filename in filenames: 33 | for ext in exts: 34 | if filename.endswith(ext): 35 | files.append(os.path.join(parent, filename)) 36 | break 37 | logger.info('Find {} images'.format(len(files))) 38 | return files 39 | 40 | 41 | def resize_image(im, max_side_len=1200): 42 | ''' 43 | resize image to a size multiple of 32 which is required by the network 44 | :param im: the resized image 45 | :param max_side_len: limit of max image size to avoid out of memory in gpu 46 | :return: the resized image and the resize ratio 47 | ''' 48 | h, w, _ = im.shape 49 | 50 | resize_w = w 51 | resize_h = h 52 | 53 | # limit the max side 54 | if max(resize_h, resize_w) > max_side_len: 55 | ratio = float(max_side_len) / resize_h if resize_h > resize_w else float(max_side_len) / resize_w 56 | else: 57 | ratio = 1. 
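# Illustrative example (not from the original code): a 900x1600 input with
# max_side_len=1200 takes the first branch, ratio = 1200/1600 = 0.75, giving
# 675x1200; the rounding below then pads this up to 704x1216, the next
# multiples of 32, so the stride-32 backbone divides the feature maps evenly.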
58 | 59 | #ratio = float(max_side_len) / resize_h if resize_h > resize_w else float(max_side_len) / resize_w 60 | 61 | 62 | resize_h = int(resize_h * ratio) 63 | resize_w = int(resize_w * ratio) 64 | 65 | resize_h = resize_h if resize_h % 32 == 0 else (resize_h // 32 + 1) * 32 66 | resize_w = resize_w if resize_w % 32 == 0 else (resize_w // 32 + 1) * 32 67 | logger.info('resize_w:{}, resize_h:{}'.format(resize_w, resize_h)) 68 | im = cv2.resize(im, (int(resize_w), int(resize_h))) 69 | 70 | ratio_h = resize_h / float(h) 71 | ratio_w = resize_w / float(w) 72 | 73 | return im, (ratio_h, ratio_w) 74 | 75 | 76 | def detect(seg_maps, timer, image_w, image_h, min_area_thresh=10, seg_map_thresh=0.9, ratio = 1): 77 | ''' 78 | restore text boxes from score map and geo map 79 | :param seg_maps: 80 | :param timer: 81 | :param min_area_thresh: 82 | :param seg_map_thresh: threshhold for seg map 83 | :param ratio: compute each seg map thresh 84 | :return: 85 | ''' 86 | if len(seg_maps.shape) == 4: 87 | seg_maps = seg_maps[0, :, :, ] 88 | #get kernals, sequence: 0->n, max -> min 89 | kernals = [] 90 | one = np.ones_like(seg_maps[..., 0], dtype=np.uint8) 91 | zero = np.zeros_like(seg_maps[..., 0], dtype=np.uint8) 92 | thresh = seg_map_thresh 93 | for i in range(seg_maps.shape[-1]-1, -1, -1): 94 | kernal = np.where(seg_maps[..., i]>thresh, one, zero) 95 | kernals.append(kernal) 96 | thresh = seg_map_thresh*ratio 97 | start = time.time() 98 | mask_res, label_values = pse(kernals, min_area_thresh) 99 | timer['pse'] = time.time()-start 100 | mask_res = np.array(mask_res) 101 | mask_res_resized = cv2.resize(mask_res, (image_w, image_h), interpolation=cv2.INTER_NEAREST) 102 | boxes = [] 103 | for label_value in label_values: 104 | #(y,x) 105 | points = np.argwhere(mask_res_resized==label_value) 106 | points = points[:, (1,0)] 107 | rect = cv2.minAreaRect(points) 108 | box = cv2.boxPoints(rect) 109 | boxes.append(box) 110 | 111 | return np.array(boxes), kernals, timer 112 | 113 | def show_score_geo(color_im, kernels, im_res): 114 | fig = plt.figure() 115 | cmap = plt.cm.hot 116 | # 117 | ax = fig.add_subplot(241) 118 | im = kernels[0]*255 119 | ax.imshow(im) 120 | 121 | ax = fig.add_subplot(242) 122 | im = kernels[1]*255 123 | ax.imshow(im, cmap) 124 | 125 | ax = fig.add_subplot(243) 126 | im = kernels[2]*255 127 | ax.imshow(im, cmap) 128 | 129 | ax = fig.add_subplot(244) 130 | im = kernels[3]*255 131 | ax.imshow(im, cmap) 132 | 133 | ax = fig.add_subplot(245) 134 | im = kernels[4]*255 135 | ax.imshow(im, cmap) 136 | 137 | ax = fig.add_subplot(246) 138 | im = kernels[5]*255 139 | ax.imshow(im, cmap) 140 | 141 | ax = fig.add_subplot(247) 142 | im = color_im 143 | ax.imshow(im) 144 | 145 | ax = fig.add_subplot(248) 146 | im = im_res 147 | ax.imshow(im) 148 | 149 | fig.show() 150 | 151 | 152 | def main(argv=None): 153 | import os 154 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list 155 | 156 | try: 157 | os.makedirs(FLAGS.output_dir) 158 | except OSError as e: 159 | if e.errno != 17: 160 | raise 161 | 162 | with tf.get_default_graph().as_default(): 163 | input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images') 164 | global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False) 165 | seg_maps_pred = model.model(input_images, is_training=False) 166 | 167 | variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step) 168 | saver = tf.train.Saver(variable_averages.variables_to_restore()) 169 | with 
tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 170 | ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path) 171 | model_path = os.path.join(FLAGS.checkpoint_path, os.path.basename(ckpt_state.model_checkpoint_path)) 172 | logger.info('Restore from {}'.format(model_path)) 173 | saver.restore(sess, model_path) 174 | 175 | im_fn_list = get_images() 176 | for im_fn in im_fn_list: 177 | im = cv2.imread(im_fn)[:, :, ::-1] 178 | logger.debug('image file:{}'.format(im_fn)) 179 | 180 | start_time = time.time() 181 | im_resized, (ratio_h, ratio_w) = resize_image(im) 182 | h, w, _ = im_resized.shape 183 | # options = tf.RunOptions(trace_level = tf.RunOptions.FULL_TRACE) 184 | # run_metadata = tf.RunMetadata() 185 | timer = {'net': 0, 'pse': 0} 186 | start = time.time() 187 | seg_maps = sess.run(seg_maps_pred, feed_dict={input_images: [im_resized]}) 188 | timer['net'] = time.time() - start 189 | # fetched_timeline = timeline.Timeline(run_metadata.step_stats) 190 | # chrome_trace = fetched_timeline.generate_chrome_trace_format() 191 | # with open(os.path.join(FLAGS.output_dir, os.path.basename(im_fn).split('.')[0]+'.json'), 'w') as f: 192 | # f.write(chrome_trace) 193 | 194 | boxes, kernels, timer = detect(seg_maps=seg_maps, timer=timer, image_w=w, image_h=h) 195 | logger.info('{} : net {:.0f}ms, pse {:.0f}ms'.format( 196 | im_fn, timer['net']*1000, timer['pse']*1000)) 197 | 198 | if boxes is not None: 199 | boxes = boxes.reshape((-1, 4, 2)) 200 | boxes[:, :, 0] /= ratio_w 201 | boxes[:, :, 1] /= ratio_h 202 | h, w, _ = im.shape 203 | boxes[:, :, 0] = np.clip(boxes[:, :, 0], 0, w) 204 | boxes[:, :, 1] = np.clip(boxes[:, :, 1], 0, h) 205 | 206 | duration = time.time() - start_time 207 | logger.info('[timing] {}'.format(duration)) 208 | 209 | # save to file 210 | if boxes is not None: 211 | res_file = os.path.join( 212 | FLAGS.output_dir, 213 | '{}.txt'.format(os.path.splitext( 214 | os.path.basename(im_fn))[0])) 215 | 216 | 217 | with open(res_file, 'w') as f: 218 | num =0 219 | for i in xrange(len(boxes)): 220 | # to avoid submitting errors 221 | box = boxes[i] 222 | if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3]-box[0]) < 5: 223 | continue 224 | 225 | num += 1 226 | 227 | f.write('{},{},{},{},{},{},{},{}\r\n'.format( 228 | box[0, 0], box[0, 1], box[1, 0], box[1, 1], box[2, 0], box[2, 1], box[3, 0], box[3, 1])) 229 | cv2.polylines(im[:, :, ::-1], [box.astype(np.int32).reshape((-1, 1, 2))], True, color=(255, 255, 0), thickness=2) 230 | if not FLAGS.no_write_images: 231 | img_path = os.path.join(FLAGS.output_dir, os.path.basename(im_fn)) 232 | cv2.imwrite(img_path, im[:, :, ::-1]) 233 | # show_score_geo(im_resized, kernels, im) 234 | if __name__ == '__main__': 235 | tf.app.run() 236 | -------------------------------------------------------------------------------- /figure/result0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichaelHL-ai/tensorflow_PSENet/e2cd908f301b762150aa36893677c1c51c98ff9e/figure/result0.jpg -------------------------------------------------------------------------------- /figure/result1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichaelHL-ai/tensorflow_PSENet/e2cd908f301b762150aa36893677c1c51c98ff9e/figure/result1.jpg -------------------------------------------------------------------------------- /figure/result2.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichaelHL-ai/tensorflow_PSENet/e2cd908f301b762150aa36893677c1c51c98ff9e/figure/result2.jpg -------------------------------------------------------------------------------- /figure/result3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichaelHL-ai/tensorflow_PSENet/e2cd908f301b762150aa36893677c1c51c98ff9e/figure/result3.jpg -------------------------------------------------------------------------------- /figure/result4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichaelHL-ai/tensorflow_PSENet/e2cd908f301b762150aa36893677c1c51c98ff9e/figure/result4.jpg -------------------------------------------------------------------------------- /figure/result5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichaelHL-ai/tensorflow_PSENet/e2cd908f301b762150aa36893677c1c51c98ff9e/figure/result5.jpg -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichaelHL-ai/tensorflow_PSENet/e2cd908f301b762150aa36893677c1c51c98ff9e/nets/__init__.py -------------------------------------------------------------------------------- /nets/model.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | import tensorflow as tf 3 | from utils.utils_tool import logger 4 | 5 | from tensorflow.contrib import slim 6 | 7 | tf.app.flags.DEFINE_integer('text_scale', 512, '') 8 | 9 | from nets.resnet import resnet_v1 10 | 11 | FLAGS = tf.app.flags.FLAGS 12 | 13 | #TODO:bilinear or nearest_neighbor? 
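# Note on the TODO above (an observation, not part of the original code):
# unpool() enlarges P5/P4/P3 by 8x/4x/2x so they match P2's spatial size before
# the channel-wise concat F = C(P2, P3, P4, P5) in model(). Bilinear resizing
# gives smoother upsampled maps; tf.image.resize_nearest_neighbor would also
# work and is cheaper, which is presumably what the TODO is weighing.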
14 | def unpool(inputs, rate): 15 | return tf.image.resize_bilinear(inputs, size=[tf.shape(inputs)[1]*rate, tf.shape(inputs)[2]*rate]) 16 | 17 | def mean_image_subtraction(images, means=[123.68, 116.78, 103.94]): 18 | ''' 19 | image normalization 20 | :param images: 21 | :param means: 22 | :return: 23 | ''' 24 | num_channels = images.get_shape().as_list()[-1] 25 | if len(means) != num_channels: 26 | raise ValueError('len(means) must match the number of channels') 27 | channels = tf.split(axis=3, num_or_size_splits=num_channels, value=images) 28 | for i in range(num_channels): 29 | channels[i] -= means[i] 30 | return tf.concat(axis=3, values=channels) 31 | 32 | def build_feature_pyramid(C, weight_decay): 33 | 34 | ''' 35 | reference: https://github.com/CharlesShang/FastMaskRCNN 36 | build P2, P3, P4, P5 37 | :return: multi-scale feature map 38 | ''' 39 | 40 | feature_pyramid = {} 41 | with tf.variable_scope('build_feature_pyramid'): 42 | with slim.arg_scope([slim.conv2d], weights_regularizer=slim.l2_regularizer(weight_decay)): 43 | feature_pyramid['P5'] = slim.conv2d(C['C5'], 44 | num_outputs=256, 45 | kernel_size=[1, 1], 46 | stride=1, 47 | scope='build_P5') 48 | 49 | # feature_pyramid['P6'] = slim.max_pool2d(feature_pyramid['P5'], 50 | # kernel_size=[2, 2], stride=2, scope='build_P6') 51 | # P6 is down sample of P5 52 | 53 | for layer in range(4, 1, -1): 54 | p, c = feature_pyramid['P' + str(layer + 1)], C['C' + str(layer)] 55 | up_sample_shape = tf.shape(c) 56 | up_sample = tf.image.resize_nearest_neighbor(p, [up_sample_shape[1], up_sample_shape[2]], 57 | name='build_P%d/up_sample_nearest_neighbor' % layer) 58 | 59 | c = slim.conv2d(c, num_outputs=256, kernel_size=[1, 1], stride=1, 60 | scope='build_P%d/reduce_dimension' % layer) 61 | p = up_sample + c 62 | p = slim.conv2d(p, 256, kernel_size=[3, 3], stride=1, 63 | padding='SAME', scope='build_P%d/avoid_aliasing' % layer) 64 | feature_pyramid['P' + str(layer)] = p 65 | return feature_pyramid 66 | 67 | def model(images, outputs = 6, weight_decay=1e-5, is_training=True): 68 | ''' 69 | define the model, we use slim's implemention of resnet 70 | ''' 71 | images = mean_image_subtraction(images) 72 | 73 | with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=weight_decay)): 74 | logits, end_points = resnet_v1.resnet_v1_50(images, is_training=is_training, scope='resnet_v1_50') 75 | 76 | #no non-linearities in FPN article 77 | feature_pyramid = build_feature_pyramid(end_points, weight_decay=weight_decay) 78 | #unpool sample P 79 | P_concat = [] 80 | for i in range(3, 0, -1): 81 | P_concat.append(unpool(feature_pyramid['P'+str(i+2)], 2**i)) 82 | P_concat.append(feature_pyramid['P2']) 83 | #F = C(P2,P3,P4,P5) 84 | F = tf.concat(P_concat, axis=-1) 85 | 86 | #reduce to 256 channels 87 | with tf.variable_scope('feature_results'): 88 | batch_norm_params = { 89 | 'decay': 0.997, 90 | 'epsilon': 1e-5, 91 | 'scale': True, 92 | 'is_training': is_training 93 | } 94 | with slim.arg_scope([slim.conv2d], 95 | activation_fn=tf.nn.relu, 96 | normalizer_fn=slim.batch_norm, 97 | normalizer_params=batch_norm_params, 98 | weights_regularizer=slim.l2_regularizer(weight_decay)): 99 | F = slim.conv2d(F, 256, 3) 100 | with slim.arg_scope([slim.conv2d], 101 | weights_regularizer=slim.l2_regularizer(weight_decay), 102 | activation_fn=None): 103 | S = slim.conv2d(F, outputs, 1) 104 | 105 | seg_S_pred = tf.nn.sigmoid(S) 106 | 107 | return seg_S_pred 108 | 109 | def dice_coefficient(y_true_cls, y_pred_cls, 110 | training_mask): 111 | ''' 112 | dice loss 113 | :param 
y_true_cls: ground truth 114 | :param y_pred_cls: predict 115 | :param training_mask: 116 | :return: 117 | ''' 118 | eps = 1e-5 119 | intersection = tf.reduce_sum(y_true_cls * y_pred_cls * training_mask) 120 | union = tf.reduce_sum(y_true_cls * training_mask) + tf.reduce_sum(y_pred_cls * training_mask) + eps 121 | dice = 2 * intersection / union 122 | loss = 1. - dice 123 | # tf.summary.scalar('classification_dice_loss', loss) 124 | return dice, loss 125 | 126 | def loss(y_true_cls, y_pred_cls, 127 | training_mask): 128 | g1, g2, g3, g4, g5, g6 = tf.split(value=y_true_cls, num_or_size_splits=6, axis=3) 129 | s1, s2, s3, s4, s5, s6 = tf.split(value=y_pred_cls, num_or_size_splits=6, axis=3) 130 | Gn = [g1, g2, g3, g4, g5, g6] 131 | Sn = [s1, s2, s3, s4, s5, s6] 132 | _, Lc = dice_coefficient(Gn[5], Sn[5], training_mask=training_mask) 133 | tf.summary.scalar('Lc_loss', Lc) 134 | 135 | one = tf.ones_like(Sn[5]) 136 | zero = tf.zeros_like(Sn[5]) 137 | W = tf.where(Sn[5] >= 0.5, x=one, y=zero) 138 | D = 0 139 | for i in range(5): 140 | di, _ = dice_coefficient(Gn[i]*W, Sn[i]*W, training_mask=training_mask) 141 | D += di 142 | Ls = 1-D/5. 143 | tf.summary.scalar('Ls_loss', Ls) 144 | lambda_ = 0.7 145 | L = lambda_*Lc + (1-lambda_)*Ls 146 | return L 147 | 148 | 149 | 150 | 151 | -------------------------------------------------------------------------------- /nets/resnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichaelHL-ai/tensorflow_PSENet/e2cd908f301b762150aa36893677c1c51c98ff9e/nets/resnet/__init__.py -------------------------------------------------------------------------------- /nets/resnet/resnet_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains building blocks for various versions of Residual Networks. 16 | 17 | Residual networks (ResNets) were proposed in: 18 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015 20 | 21 | More variants were introduced in: 22 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 23 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016 24 | 25 | We can obtain different ResNet variants by changing the network depth, width, 26 | and form of residual unit. This module implements the infrastructure for 27 | building them. Concrete ResNet units and full ResNet networks are implemented in 28 | the accompanying resnet_v1.py and resnet_v2.py modules. 
29 | 30 | Compared to https://github.com/KaimingHe/deep-residual-networks, in the current 31 | implementation we subsample the output activations in the last residual unit of 32 | each block, instead of subsampling the input activations in the first residual 33 | unit of each block. The two implementations give identical results but our 34 | implementation is more memory efficient. 35 | """ 36 | 37 | 38 | 39 | 40 | import collections 41 | import tensorflow as tf 42 | 43 | slim = tf.contrib.slim 44 | 45 | 46 | class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])): 47 | """A named tuple describing a ResNet block. 48 | 49 | Its parts are: 50 | scope: The scope of the `Block`. 51 | unit_fn: The ResNet unit function which takes as input a `Tensor` and 52 | returns another `Tensor` with the output of the ResNet unit. 53 | args: A list of length equal to the number of units in the `Block`. The list 54 | contains one (depth, depth_bottleneck, stride) tuple for each unit in the 55 | block to serve as argument to unit_fn. 56 | """ 57 | 58 | 59 | def subsample(inputs, factor, scope=None): 60 | """Subsamples the input along the spatial dimensions. 61 | 62 | Args: 63 | inputs: A `Tensor` of size [batch, height_in, width_in, channels]. 64 | factor: The subsampling factor. 65 | scope: Optional variable_scope. 66 | 67 | Returns: 68 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the 69 | input, either intact (if factor == 1) or subsampled (if factor > 1). 70 | """ 71 | if factor == 1: 72 | return inputs 73 | else: 74 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope) 75 | 76 | 77 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): 78 | """Strided 2-D convolution with 'SAME' padding. 79 | 80 | When stride > 1, then we do explicit zero-padding, followed by conv2d with 81 | 'VALID' padding. 82 | 83 | Note that 84 | 85 | net = conv2d_same(inputs, num_outputs, 3, stride=stride) 86 | 87 | is equivalent to 88 | 89 | net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME') 90 | net = subsample(net, factor=stride) 91 | 92 | whereas 93 | 94 | net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME') 95 | 96 | is different when the input's height or width is even, which is why we add the 97 | current function. For more details, see ResnetUtilsTest.testConv2DSameEven(). 98 | 99 | Args: 100 | inputs: A 4-D tensor of size [batch, height_in, width_in, channels]. 101 | num_outputs: An integer, the number of output filters. 102 | kernel_size: An int with the kernel_size of the filters. 103 | stride: An integer, the output stride. 104 | rate: An integer, rate for atrous convolution. 105 | scope: Scope. 106 | 107 | Returns: 108 | output: A 4-D tensor of size [batch, height_out, width_out, channels] with 109 | the convolution output. 
110 | """ 111 | if stride == 1: 112 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate, 113 | padding='SAME', scope=scope) 114 | else: 115 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) 116 | pad_total = kernel_size_effective - 1 117 | pad_beg = pad_total // 2 118 | pad_end = pad_total - pad_beg 119 | inputs = tf.pad(inputs, 120 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) 121 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride, 122 | rate=rate, padding='VALID', scope=scope) 123 | 124 | 125 | @slim.add_arg_scope 126 | def stack_blocks_dense(net, blocks, output_stride=None, 127 | outputs_collections=None): 128 | """Stacks ResNet `Blocks` and controls output feature density. 129 | 130 | First, this function creates scopes for the ResNet in the form of 131 | 'block_name/unit_1', 'block_name/unit_2', etc. 132 | 133 | Second, this function allows the user to explicitly control the ResNet 134 | output_stride, which is the ratio of the input to output spatial resolution. 135 | This is useful for dense prediction tasks such as semantic segmentation or 136 | object detection. 137 | 138 | Most ResNets consist of 4 ResNet blocks and subsample the activations by a 139 | factor of 2 when transitioning between consecutive ResNet blocks. This results 140 | to a nominal ResNet output_stride equal to 8. If we set the output_stride to 141 | half the nominal network stride (e.g., output_stride=4), then we compute 142 | responses twice. 143 | 144 | Control of the output feature density is implemented by atrous convolution. 145 | 146 | Args: 147 | net: A `Tensor` of size [batch, height, width, channels]. 148 | blocks: A list of length equal to the number of ResNet `Blocks`. Each 149 | element is a ResNet `Block` object describing the units in the `Block`. 150 | output_stride: If `None`, then the output will be computed at the nominal 151 | network stride. If output_stride is not `None`, it specifies the requested 152 | ratio of input to output spatial resolution, which needs to be equal to 153 | the product of unit strides from the start up to some level of the ResNet. 154 | For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1, 155 | then valid values for the output_stride are 1, 2, 6, 24 or None (which 156 | is equivalent to output_stride=24). 157 | outputs_collections: Collection to add the ResNet block outputs. 158 | 159 | Returns: 160 | net: Output tensor with stride equal to the specified output_stride. 161 | 162 | Raises: 163 | ValueError: If the target output_stride is not valid. 164 | """ 165 | # The current_stride variable keeps track of the effective stride of the 166 | # activations. This allows us to invoke atrous convolution whenever applying 167 | # the next residual unit would result in the activations having stride larger 168 | # than the target output_stride. 169 | current_stride = 1 170 | 171 | # The atrous convolution rate parameter. 
172 | rate = 1 173 | 174 | for block in blocks: 175 | with tf.variable_scope(block.scope, 'block', [net]) as sc: 176 | for i, unit in enumerate(block.args): 177 | if output_stride is not None and current_stride > output_stride: 178 | raise ValueError('The target output_stride cannot be reached.') 179 | 180 | with tf.variable_scope('unit_%d' % (i + 1), values=[net]): 181 | unit_depth, unit_depth_bottleneck, unit_stride = unit 182 | # If we have reached the target output_stride, then we need to employ 183 | # atrous convolution with stride=1 and multiply the atrous rate by the 184 | # current unit's stride for use in subsequent layers. 185 | if output_stride is not None and current_stride == output_stride: 186 | net = block.unit_fn(net, 187 | depth=unit_depth, 188 | depth_bottleneck=unit_depth_bottleneck, 189 | stride=1, 190 | rate=rate) 191 | rate *= unit_stride 192 | 193 | else: 194 | net = block.unit_fn(net, 195 | depth=unit_depth, 196 | depth_bottleneck=unit_depth_bottleneck, 197 | stride=unit_stride, 198 | rate=1) 199 | current_stride *= unit_stride 200 | print(sc.name, net.shape) 201 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) 202 | 203 | if output_stride is not None and current_stride != output_stride: 204 | raise ValueError('The target output_stride cannot be reached.') 205 | 206 | return net 207 | 208 | 209 | def resnet_arg_scope(weight_decay=0.0001, 210 | batch_norm_decay=0.997, 211 | batch_norm_epsilon=1e-5, 212 | batch_norm_scale=True): 213 | """Defines the default ResNet arg scope. 214 | 215 | TODO(gpapan): The batch-normalization related default values above are 216 | appropriate for use in conjunction with the reference ResNet models 217 | released at https://github.com/KaimingHe/deep-residual-networks. When 218 | training ResNets from scratch, they might need to be tuned. 219 | 220 | Args: 221 | weight_decay: The weight decay to use for regularizing the model. 222 | batch_norm_decay: The moving average decay when estimating layer activation 223 | statistics in batch normalization. 224 | batch_norm_epsilon: Small constant to prevent division by zero when 225 | normalizing activations by their variance in batch normalization. 226 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 227 | activations in the batch normalization layer. 228 | 229 | Returns: 230 | An `arg_scope` to use for the resnet models. 231 | """ 232 | batch_norm_params = { 233 | 'decay': batch_norm_decay, 234 | 'epsilon': batch_norm_epsilon, 235 | 'scale': batch_norm_scale, 236 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 237 | } 238 | 239 | with slim.arg_scope( 240 | [slim.conv2d], 241 | weights_regularizer=slim.l2_regularizer(weight_decay), 242 | weights_initializer=slim.variance_scaling_initializer(), 243 | activation_fn=tf.nn.relu, 244 | normalizer_fn=slim.batch_norm, 245 | normalizer_params=batch_norm_params): 246 | with slim.arg_scope([slim.batch_norm], **batch_norm_params): 247 | # The following implies padding='SAME' for pool1, which makes feature 248 | # alignment easier for dense prediction tasks. This is also used in 249 | # https://github.com/facebook/fb.resnet.torch. However the accompanying 250 | # code of 'Deep Residual Learning for Image Recognition' uses 251 | # padding='VALID' for pool1. You can switch to that choice by setting 252 | # slim.arg_scope([slim.max_pool2d], padding='VALID'). 
253 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc: 254 | return arg_sc 255 | -------------------------------------------------------------------------------- /nets/resnet/resnet_v1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains definitions for the original form of Residual Networks. 16 | 17 | The 'v1' residual networks (ResNets) implemented in this module were proposed 18 | by: 19 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 20 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 21 | 22 | Other variants were introduced in: 23 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 24 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027 25 | 26 | The networks defined in this module utilize the bottleneck building block of 27 | [1] with projection shortcuts only for increasing depths. They employ batch 28 | normalization *after* every weight layer. This is the architecture used by 29 | MSRA in the Imagenet and MSCOCO 2016 competition models ResNet-101 and 30 | ResNet-152. See [2; Fig. 1a] for a comparison between the current 'v1' 31 | architecture and the alternative 'v2' architecture of [2] which uses batch 32 | normalization *before* every weight layer in the so-called full pre-activation 33 | units. 34 | 35 | Typical use: 36 | 37 | from tensorflow.contrib.slim.nets import resnet_v1 38 | 39 | ResNet-101 for image classification into 1000 classes: 40 | 41 | # inputs has shape [batch, 224, 224, 3] 42 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 43 | net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False) 44 | 45 | ResNet-101 for semantic segmentation into 21 classes: 46 | 47 | # inputs has shape [batch, 513, 513, 3] 48 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 49 | net, end_points = resnet_v1.resnet_v1_101(inputs, 50 | 21, 51 | is_training=False, 52 | global_pool=False, 53 | output_stride=16) 54 | """ 55 | # from __future__ import absolute_import 56 | # from __future__ import division 57 | # from __future__ import print_function 58 | 59 | import tensorflow as tf 60 | from tensorflow.contrib import slim 61 | 62 | from . import resnet_utils 63 | 64 | resnet_arg_scope = resnet_utils.resnet_arg_scope 65 | 66 | 67 | @slim.add_arg_scope 68 | def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1, 69 | outputs_collections=None, scope=None): 70 | """Bottleneck residual unit variant with BN after convolutions. 71 | 72 | This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for 73 | its definition. Note that we use here the bottleneck variant which has an 74 | extra bottleneck layer. 
75 | 76 | When putting together two consecutive ResNet blocks that use this unit, one 77 | should use stride = 2 in the last unit of the first block. 78 | 79 | Args: 80 | inputs: A tensor of size [batch, height, width, channels]. 81 | depth: The depth of the ResNet unit output. 82 | depth_bottleneck: The depth of the bottleneck layers. 83 | stride: The ResNet unit's stride. Determines the amount of downsampling of 84 | the units output compared to its input. 85 | rate: An integer, rate for atrous convolution. 86 | outputs_collections: Collection to add the ResNet unit output. 87 | scope: Optional variable_scope. 88 | 89 | Returns: 90 | The ResNet unit's output. 91 | """ 92 | with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc: 93 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) 94 | if depth == depth_in: 95 | shortcut = resnet_utils.subsample(inputs, stride, 'shortcut') 96 | else: 97 | shortcut = slim.conv2d(inputs, depth, [1, 1], stride=stride, 98 | activation_fn=None, scope='shortcut') 99 | 100 | residual = slim.conv2d(inputs, depth_bottleneck, [1, 1], stride=1, 101 | scope='conv1') 102 | residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride, 103 | rate=rate, scope='conv2') 104 | residual = slim.conv2d(residual, depth, [1, 1], stride=1, 105 | activation_fn=None, scope='conv3') 106 | 107 | output = tf.nn.relu(shortcut + residual) 108 | 109 | return slim.utils.collect_named_outputs(outputs_collections, 110 | sc.original_name_scope, 111 | output) 112 | 113 | 114 | def resnet_v1(inputs, 115 | blocks, 116 | num_classes=None, 117 | is_training=True, 118 | global_pool=True, 119 | output_stride=None, 120 | include_root_block=True, 121 | spatial_squeeze=True, 122 | reuse=None, 123 | scope=None): 124 | """Generator for v1 ResNet models. 125 | 126 | This function generates a family of ResNet v1 models. See the resnet_v1_*() 127 | methods for specific model instantiations, obtained by selecting different 128 | block instantiations that produce ResNets of various depths. 129 | 130 | Training for image classification on Imagenet is usually done with [224, 224] 131 | inputs, resulting in [7, 7] feature maps at the output of the last ResNet 132 | block for the ResNets defined in [1] that have nominal stride equal to 32. 133 | However, for dense prediction tasks we advise that one uses inputs with 134 | spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In 135 | this case the feature maps at the ResNet output will have spatial shape 136 | [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1] 137 | and corners exactly aligned with the input image corners, which greatly 138 | facilitates alignment of the features to the image. Using as input [225, 225] 139 | images results in [8, 8] feature maps at the output of the last ResNet block. 140 | 141 | For dense prediction tasks, the ResNet needs to run in fully-convolutional 142 | (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all 143 | have nominal stride equal to 32 and a good choice in FCN mode is to use 144 | output_stride=16 in order to increase the density of the computed features at 145 | small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915. 146 | 147 | Args: 148 | inputs: A tensor of size [batch, height_in, width_in, channels]. 149 | blocks: A list of length equal to the number of ResNet blocks. Each element 150 | is a resnet_utils.Block object describing the units in the block. 
151 | num_classes: Number of predicted classes for classification tasks. If None 152 | we return the features before the logit layer. 153 | is_training: whether is training or not. 154 | global_pool: If True, we perform global average pooling before computing the 155 | logits. Set to True for image classification, False for dense prediction. 156 | output_stride: If None, then the output will be computed at the nominal 157 | network stride. If output_stride is not None, it specifies the requested 158 | ratio of input to output spatial resolution. 159 | include_root_block: If True, include the initial convolution followed by 160 | max-pooling, if False excludes it. 161 | spatial_squeeze: if True, logits is of shape [B, C], if false logits is 162 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes. 163 | reuse: whether or not the network and its variables should be reused. To be 164 | able to reuse 'scope' must be given. 165 | scope: Optional variable_scope. 166 | 167 | Returns: 168 | net: A rank-4 tensor of size [batch, height_out, width_out, channels_out]. 169 | If global_pool is False, then height_out and width_out are reduced by a 170 | factor of output_stride compared to the respective height_in and width_in, 171 | else both height_out and width_out equal one. If num_classes is None, then 172 | net is the output of the last ResNet block, potentially after global 173 | average pooling. If num_classes is not None, net contains the pre-softmax 174 | activations. 175 | end_points: A dictionary from components of the network to the corresponding 176 | activation. 177 | 178 | Raises: 179 | ValueError: If the target output_stride is not valid. 180 | """ 181 | with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc: 182 | end_points_collection = sc.name + '_end_points' 183 | with slim.arg_scope([slim.conv2d, bottleneck, 184 | resnet_utils.stack_blocks_dense], 185 | outputs_collections=end_points_collection): 186 | with slim.arg_scope([slim.batch_norm], is_training=is_training): 187 | net = inputs 188 | if include_root_block: 189 | if output_stride is not None: 190 | if output_stride % 4 != 0: 191 | raise ValueError('The output_stride needs to be a multiple of 4.') 192 | output_stride /= 4 193 | net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1') 194 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1') 195 | 196 | net = slim.utils.collect_named_outputs(end_points_collection, 'C2', net) 197 | 198 | net = resnet_utils.stack_blocks_dense(net, blocks, output_stride) 199 | 200 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 201 | 202 | # end_points['pool2'] = end_points['resnet_v1_50/pool1/MaxPool:0'] 203 | try: 204 | end_points['C3'] = end_points['resnet_v1_50/block1'] 205 | end_points['C4'] = end_points['resnet_v1_50/block2'] 206 | except: 207 | end_points['C3'] = end_points['Detection/resnet_v1_50/block1'] 208 | end_points['C4'] = end_points['Detection/resnet_v1_50/block2'] 209 | end_points['C5'] = net 210 | # if global_pool: 211 | # # Global average pooling. 212 | # net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True) 213 | # if num_classes is not None: 214 | # net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, 215 | # normalizer_fn=None, scope='logits') 216 | # if spatial_squeeze: 217 | # logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze') 218 | # else: 219 | # logits = net 220 | # # Convert end_points_collection into a dictionary of end_points. 
221 | # end_points = slim.utils.convert_collection_to_dict(end_points_collection) 222 | # if num_classes is not None: 223 | # end_points['predictions'] = slim.softmax(logits, scope='predictions') 224 | return net, end_points 225 | 226 | 227 | resnet_v1.default_image_size = 224 228 | 229 | 230 | def resnet_v1_50(inputs, 231 | num_classes=None, 232 | is_training=True, 233 | global_pool=True, 234 | output_stride=None, 235 | spatial_squeeze=True, 236 | reuse=None, 237 | scope='resnet_v1_50'): 238 | """ResNet-50 model of [1]. See resnet_v1() for arg and return description.""" 239 | blocks = [ 240 | resnet_utils.Block( 241 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), 242 | resnet_utils.Block( 243 | 'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]), 244 | resnet_utils.Block( 245 | 'block3', bottleneck, [(1024, 256, 1)] * 5 + [(1024, 256, 2)]), 246 | resnet_utils.Block( 247 | 'block4', bottleneck, [(2048, 512, 1)] * 3) 248 | ] 249 | return resnet_v1(inputs, blocks, num_classes, is_training, 250 | global_pool=global_pool, output_stride=output_stride, 251 | include_root_block=True, spatial_squeeze=spatial_squeeze, 252 | reuse=reuse, scope=scope) 253 | 254 | 255 | resnet_v1_50.default_image_size = resnet_v1.default_image_size 256 | 257 | 258 | def resnet_v1_101(inputs, 259 | num_classes=None, 260 | is_training=True, 261 | global_pool=True, 262 | output_stride=None, 263 | spatial_squeeze=True, 264 | reuse=None, 265 | scope='resnet_v1_101'): 266 | """ResNet-101 model of [1]. See resnet_v1() for arg and return description.""" 267 | blocks = [ 268 | resnet_utils.Block( 269 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), 270 | resnet_utils.Block( 271 | 'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]), 272 | resnet_utils.Block( 273 | 'block3', bottleneck, [(1024, 256, 1)] * 22 + [(1024, 256, 2)]), 274 | resnet_utils.Block( 275 | 'block4', bottleneck, [(2048, 512, 1)] * 3) 276 | ] 277 | return resnet_v1(inputs, blocks, num_classes, is_training, 278 | global_pool=global_pool, output_stride=output_stride, 279 | include_root_block=True, spatial_squeeze=spatial_squeeze, 280 | reuse=reuse, scope=scope) 281 | 282 | 283 | resnet_v1_101.default_image_size = resnet_v1.default_image_size 284 | 285 | 286 | def resnet_v1_152(inputs, 287 | num_classes=None, 288 | is_training=True, 289 | global_pool=True, 290 | output_stride=None, 291 | spatial_squeeze=True, 292 | reuse=None, 293 | scope='resnet_v1_152'): 294 | """ResNet-152 model of [1]. See resnet_v1() for arg and return description.""" 295 | blocks = [ 296 | resnet_utils.Block( 297 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), 298 | resnet_utils.Block( 299 | 'block2', bottleneck, [(512, 128, 1)] * 7 + [(512, 128, 2)]), 300 | resnet_utils.Block( 301 | 'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]), 302 | resnet_utils.Block( 303 | 'block4', bottleneck, [(2048, 512, 1)] * 3)] 304 | return resnet_v1(inputs, blocks, num_classes, is_training, 305 | global_pool=global_pool, output_stride=output_stride, 306 | include_root_block=True, spatial_squeeze=spatial_squeeze, 307 | reuse=reuse, scope=scope) 308 | 309 | 310 | resnet_v1_152.default_image_size = resnet_v1.default_image_size 311 | 312 | 313 | def resnet_v1_200(inputs, 314 | num_classes=None, 315 | is_training=True, 316 | global_pool=True, 317 | output_stride=None, 318 | spatial_squeeze=True, 319 | reuse=None, 320 | scope='resnet_v1_200'): 321 | """ResNet-200 model of [2]. 
See resnet_v1() for arg and return description.""" 322 | blocks = [ 323 | resnet_utils.Block( 324 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), 325 | resnet_utils.Block( 326 | 'block2', bottleneck, [(512, 128, 1)] * 23 + [(512, 128, 2)]), 327 | resnet_utils.Block( 328 | 'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]), 329 | resnet_utils.Block( 330 | 'block4', bottleneck, [(2048, 512, 1)] * 3)] 331 | return resnet_v1(inputs, blocks, num_classes, is_training, 332 | global_pool=global_pool, output_stride=output_stride, 333 | include_root_block=True, spatial_squeeze=spatial_squeeze, 334 | reuse=reuse, scope=scope) 335 | 336 | 337 | resnet_v1_200.default_image_size = resnet_v1.default_image_size 338 | 339 | 340 | if __name__ == '__main__': 341 | input = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name='input') 342 | with slim.arg_scope(resnet_arg_scope()) as sc: 343 | logits = resnet_v1_50(input) -------------------------------------------------------------------------------- /pse/Makefile: -------------------------------------------------------------------------------- 1 | CXXFLAGS = -I include -std=c++11 -O3 $(shell python-config --cflags) 2 | LDFLAGS = $(shell python-config --ldflags) 3 | 4 | DEPS = $(shell find include -xtype f) 5 | CXX_SOURCES = pse.cpp 6 | 7 | LIB_SO = pse.so 8 | 9 | $(LIB_SO): $(CXX_SOURCES) $(DEPS) 10 | $(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC 11 | 12 | clean: 13 | rm -rf $(LIB_SO) 14 | -------------------------------------------------------------------------------- /pse/__init__.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | import numpy as np 4 | import cv2 5 | 6 | BASE_DIR = os.path.dirname(os.path.realpath(__file__)) 7 | 8 | if subprocess.call(['make', '-C', BASE_DIR]) != 0: # return value 9 | raise RuntimeError('Cannot compile pse: {}'.format(BASE_DIR)) 10 | 11 | def pse(kernals, min_area=5): 12 | ''' 13 | :param kernals: 14 | :param min_area: 15 | :return: 16 | ''' 17 | from .pse import pse_cpp 18 | kernal_num = len(kernals) 19 | if not kernal_num: 20 | return np.array([]), [] 21 | kernals = np.array(kernals) 22 | label_num, label = cv2.connectedComponents(kernals[kernal_num - 1].astype(np.uint8), connectivity=4) 23 | label_values = [] 24 | for label_idx in range(1, label_num): 25 | if np.sum(label == label_idx) < min_area: 26 | label[label == label_idx] = 0 27 | continue 28 | label_values.append(label_idx) 29 | 30 | pred = pse_cpp(label, kernals, c=6) 31 | 32 | return pred, label_values 33 | 34 | 35 | -------------------------------------------------------------------------------- /pse/include/pybind11/attr.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/attr.h: Infrastructure for processing custom 3 | type and function attributes 4 | 5 | Copyright (c) 2016 Wenzel Jakob 6 | 7 | All rights reserved. Use of this source code is governed by a 8 | BSD-style license that can be found in the LICENSE file. 
9 | */ 10 | 11 | #pragma once 12 | 13 | #include "cast.h" 14 | 15 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 16 | 17 | /// \addtogroup annotations 18 | /// @{ 19 | 20 | /// Annotation for methods 21 | struct is_method { handle class_; is_method(const handle &c) : class_(c) { } }; 22 | 23 | /// Annotation for operators 24 | struct is_operator { }; 25 | 26 | /// Annotation for parent scope 27 | struct scope { handle value; scope(const handle &s) : value(s) { } }; 28 | 29 | /// Annotation for documentation 30 | struct doc { const char *value; doc(const char *value) : value(value) { } }; 31 | 32 | /// Annotation for function names 33 | struct name { const char *value; name(const char *value) : value(value) { } }; 34 | 35 | /// Annotation indicating that a function is an overload associated with a given "sibling" 36 | struct sibling { handle value; sibling(const handle &value) : value(value.ptr()) { } }; 37 | 38 | /// Annotation indicating that a class derives from another given type 39 | template struct base { 40 | PYBIND11_DEPRECATED("base() was deprecated in favor of specifying 'T' as a template argument to class_") 41 | base() { } 42 | }; 43 | 44 | /// Keep patient alive while nurse lives 45 | template struct keep_alive { }; 46 | 47 | /// Annotation indicating that a class is involved in a multiple inheritance relationship 48 | struct multiple_inheritance { }; 49 | 50 | /// Annotation which enables dynamic attributes, i.e. adds `__dict__` to a class 51 | struct dynamic_attr { }; 52 | 53 | /// Annotation which enables the buffer protocol for a type 54 | struct buffer_protocol { }; 55 | 56 | /// Annotation which requests that a special metaclass is created for a type 57 | struct metaclass { 58 | handle value; 59 | 60 | PYBIND11_DEPRECATED("py::metaclass() is no longer required. It's turned on by default now.") 61 | metaclass() {} 62 | 63 | /// Override pybind11's default metaclass 64 | explicit metaclass(handle value) : value(value) { } 65 | }; 66 | 67 | /// Annotation that marks a class as local to the module: 68 | struct module_local { const bool value; constexpr module_local(bool v = true) : value(v) { } }; 69 | 70 | /// Annotation to mark enums as an arithmetic type 71 | struct arithmetic { }; 72 | 73 | /** \rst 74 | A call policy which places one or more guard variables (``Ts...``) around the function call. 75 | 76 | For example, this definition: 77 | 78 | .. code-block:: cpp 79 | 80 | m.def("foo", foo, py::call_guard()); 81 | 82 | is equivalent to the following pseudocode: 83 | 84 | .. code-block:: cpp 85 | 86 | m.def("foo", [](args...) 
{ 87 | T scope_guard; 88 | return foo(args...); // forwarded arguments 89 | }); 90 | \endrst */ 91 | template struct call_guard; 92 | 93 | template <> struct call_guard<> { using type = detail::void_type; }; 94 | 95 | template 96 | struct call_guard { 97 | static_assert(std::is_default_constructible::value, 98 | "The guard type must be default constructible"); 99 | 100 | using type = T; 101 | }; 102 | 103 | template 104 | struct call_guard { 105 | struct type { 106 | T guard{}; // Compose multiple guard types with left-to-right default-constructor order 107 | typename call_guard::type next{}; 108 | }; 109 | }; 110 | 111 | /// @} annotations 112 | 113 | NAMESPACE_BEGIN(detail) 114 | /* Forward declarations */ 115 | enum op_id : int; 116 | enum op_type : int; 117 | struct undefined_t; 118 | template struct op_; 119 | inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret); 120 | 121 | /// Internal data structure which holds metadata about a keyword argument 122 | struct argument_record { 123 | const char *name; ///< Argument name 124 | const char *descr; ///< Human-readable version of the argument value 125 | handle value; ///< Associated Python object 126 | bool convert : 1; ///< True if the argument is allowed to convert when loading 127 | bool none : 1; ///< True if None is allowed when loading 128 | 129 | argument_record(const char *name, const char *descr, handle value, bool convert, bool none) 130 | : name(name), descr(descr), value(value), convert(convert), none(none) { } 131 | }; 132 | 133 | /// Internal data structure which holds metadata about a bound function (signature, overloads, etc.) 134 | struct function_record { 135 | function_record() 136 | : is_constructor(false), is_new_style_constructor(false), is_stateless(false), 137 | is_operator(false), has_args(false), has_kwargs(false), is_method(false) { } 138 | 139 | /// Function name 140 | char *name = nullptr; /* why no C++ strings? They generate heavier code.. */ 141 | 142 | // User-specified documentation string 143 | char *doc = nullptr; 144 | 145 | /// Human-readable version of the function signature 146 | char *signature = nullptr; 147 | 148 | /// List of registered keyword arguments 149 | std::vector args; 150 | 151 | /// Pointer to lambda function which converts arguments and performs the actual call 152 | handle (*impl) (function_call &) = nullptr; 153 | 154 | /// Storage for the wrapped function pointer and captured data, if any 155 | void *data[3] = { }; 156 | 157 | /// Pointer to custom destructor for 'data' (if needed) 158 | void (*free_data) (function_record *ptr) = nullptr; 159 | 160 | /// Return value policy associated with this function 161 | return_value_policy policy = return_value_policy::automatic; 162 | 163 | /// True if name == '__init__' 164 | bool is_constructor : 1; 165 | 166 | /// True if this is a new-style `__init__` defined in `detail/init.h` 167 | bool is_new_style_constructor : 1; 168 | 169 | /// True if this is a stateless function pointer 170 | bool is_stateless : 1; 171 | 172 | /// True if this is an operator (__add__), etc. 
173 | bool is_operator : 1; 174 | 175 | /// True if the function has a '*args' argument 176 | bool has_args : 1; 177 | 178 | /// True if the function has a '**kwargs' argument 179 | bool has_kwargs : 1; 180 | 181 | /// True if this is a method 182 | bool is_method : 1; 183 | 184 | /// Number of arguments (including py::args and/or py::kwargs, if present) 185 | std::uint16_t nargs; 186 | 187 | /// Python method object 188 | PyMethodDef *def = nullptr; 189 | 190 | /// Python handle to the parent scope (a class or a module) 191 | handle scope; 192 | 193 | /// Python handle to the sibling function representing an overload chain 194 | handle sibling; 195 | 196 | /// Pointer to next overload 197 | function_record *next = nullptr; 198 | }; 199 | 200 | /// Special data structure which (temporarily) holds metadata about a bound class 201 | struct type_record { 202 | PYBIND11_NOINLINE type_record() 203 | : multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false), module_local(false) { } 204 | 205 | /// Handle to the parent scope 206 | handle scope; 207 | 208 | /// Name of the class 209 | const char *name = nullptr; 210 | 211 | // Pointer to RTTI type_info data structure 212 | const std::type_info *type = nullptr; 213 | 214 | /// How large is the underlying C++ type? 215 | size_t type_size = 0; 216 | 217 | /// What is the alignment of the underlying C++ type? 218 | size_t type_align = 0; 219 | 220 | /// How large is the type's holder? 221 | size_t holder_size = 0; 222 | 223 | /// The global operator new can be overridden with a class-specific variant 224 | void *(*operator_new)(size_t) = nullptr; 225 | 226 | /// Function pointer to class_<..>::init_instance 227 | void (*init_instance)(instance *, const void *) = nullptr; 228 | 229 | /// Function pointer to class_<..>::dealloc 230 | void (*dealloc)(detail::value_and_holder &) = nullptr; 231 | 232 | /// List of base classes of the newly created type 233 | list bases; 234 | 235 | /// Optional docstring 236 | const char *doc = nullptr; 237 | 238 | /// Custom metaclass (optional) 239 | handle metaclass; 240 | 241 | /// Multiple inheritance marker 242 | bool multiple_inheritance : 1; 243 | 244 | /// Does the class manage a __dict__? 245 | bool dynamic_attr : 1; 246 | 247 | /// Does the class implement the buffer protocol? 248 | bool buffer_protocol : 1; 249 | 250 | /// Is the default (unique_ptr) holder type used? 251 | bool default_holder : 1; 252 | 253 | /// Is the class definition local to the module shared object? 254 | bool module_local : 1; 255 | 256 | PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) { 257 | auto base_info = detail::get_type_info(base, false); 258 | if (!base_info) { 259 | std::string tname(base.name()); 260 | detail::clean_type_id(tname); 261 | pybind11_fail("generic_type: type \"" + std::string(name) + 262 | "\" referenced unknown base type \"" + tname + "\""); 263 | } 264 | 265 | if (default_holder != base_info->default_holder) { 266 | std::string tname(base.name()); 267 | detail::clean_type_id(tname); 268 | pybind11_fail("generic_type: type \"" + std::string(name) + "\" " + 269 | (default_holder ? "does not have" : "has") + 270 | " a non-default holder type while its base \"" + tname + "\" " + 271 | (base_info->default_holder ? 
"does not" : "does")); 272 | } 273 | 274 | bases.append((PyObject *) base_info->type); 275 | 276 | if (base_info->type->tp_dictoffset != 0) 277 | dynamic_attr = true; 278 | 279 | if (caster) 280 | base_info->implicit_casts.emplace_back(type, caster); 281 | } 282 | }; 283 | 284 | inline function_call::function_call(const function_record &f, handle p) : 285 | func(f), parent(p) { 286 | args.reserve(f.nargs); 287 | args_convert.reserve(f.nargs); 288 | } 289 | 290 | /// Tag for a new-style `__init__` defined in `detail/init.h` 291 | struct is_new_style_constructor { }; 292 | 293 | /** 294 | * Partial template specializations to process custom attributes provided to 295 | * cpp_function_ and class_. These are either used to initialize the respective 296 | * fields in the type_record and function_record data structures or executed at 297 | * runtime to deal with custom call policies (e.g. keep_alive). 298 | */ 299 | template struct process_attribute; 300 | 301 | template struct process_attribute_default { 302 | /// Default implementation: do nothing 303 | static void init(const T &, function_record *) { } 304 | static void init(const T &, type_record *) { } 305 | static void precall(function_call &) { } 306 | static void postcall(function_call &, handle) { } 307 | }; 308 | 309 | /// Process an attribute specifying the function's name 310 | template <> struct process_attribute : process_attribute_default { 311 | static void init(const name &n, function_record *r) { r->name = const_cast(n.value); } 312 | }; 313 | 314 | /// Process an attribute specifying the function's docstring 315 | template <> struct process_attribute : process_attribute_default { 316 | static void init(const doc &n, function_record *r) { r->doc = const_cast(n.value); } 317 | }; 318 | 319 | /// Process an attribute specifying the function's docstring (provided as a C-style string) 320 | template <> struct process_attribute : process_attribute_default { 321 | static void init(const char *d, function_record *r) { r->doc = const_cast(d); } 322 | static void init(const char *d, type_record *r) { r->doc = const_cast(d); } 323 | }; 324 | template <> struct process_attribute : process_attribute { }; 325 | 326 | /// Process an attribute indicating the function's return value policy 327 | template <> struct process_attribute : process_attribute_default { 328 | static void init(const return_value_policy &p, function_record *r) { r->policy = p; } 329 | }; 330 | 331 | /// Process an attribute which indicates that this is an overloaded function associated with a given sibling 332 | template <> struct process_attribute : process_attribute_default { 333 | static void init(const sibling &s, function_record *r) { r->sibling = s.value; } 334 | }; 335 | 336 | /// Process an attribute which indicates that this function is a method 337 | template <> struct process_attribute : process_attribute_default { 338 | static void init(const is_method &s, function_record *r) { r->is_method = true; r->scope = s.class_; } 339 | }; 340 | 341 | /// Process an attribute which indicates the parent scope of a method 342 | template <> struct process_attribute : process_attribute_default { 343 | static void init(const scope &s, function_record *r) { r->scope = s.value; } 344 | }; 345 | 346 | /// Process an attribute which indicates that this function is an operator 347 | template <> struct process_attribute : process_attribute_default { 348 | static void init(const is_operator &, function_record *r) { r->is_operator = true; } 349 | }; 350 | 351 | template <> struct 
process_attribute : process_attribute_default { 352 | static void init(const is_new_style_constructor &, function_record *r) { r->is_new_style_constructor = true; } 353 | }; 354 | 355 | /// Process a keyword argument attribute (*without* a default value) 356 | template <> struct process_attribute : process_attribute_default { 357 | static void init(const arg &a, function_record *r) { 358 | if (r->is_method && r->args.empty()) 359 | r->args.emplace_back("self", nullptr, handle(), true /*convert*/, false /*none not allowed*/); 360 | r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none); 361 | } 362 | }; 363 | 364 | /// Process a keyword argument attribute (*with* a default value) 365 | template <> struct process_attribute : process_attribute_default { 366 | static void init(const arg_v &a, function_record *r) { 367 | if (r->is_method && r->args.empty()) 368 | r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/, false /*none not allowed*/); 369 | 370 | if (!a.value) { 371 | #if !defined(NDEBUG) 372 | std::string descr("'"); 373 | if (a.name) descr += std::string(a.name) + ": "; 374 | descr += a.type + "'"; 375 | if (r->is_method) { 376 | if (r->name) 377 | descr += " in method '" + (std::string) str(r->scope) + "." + (std::string) r->name + "'"; 378 | else 379 | descr += " in method of '" + (std::string) str(r->scope) + "'"; 380 | } else if (r->name) { 381 | descr += " in function '" + (std::string) r->name + "'"; 382 | } 383 | pybind11_fail("arg(): could not convert default argument " 384 | + descr + " into a Python object (type not registered yet?)"); 385 | #else 386 | pybind11_fail("arg(): could not convert default argument " 387 | "into a Python object (type not registered yet?). " 388 | "Compile in debug mode for more information."); 389 | #endif 390 | } 391 | r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none); 392 | } 393 | }; 394 | 395 | /// Process a parent class attribute. 
Single inheritance only (class_ itself already guarantees that) 396 | template 397 | struct process_attribute::value>> : process_attribute_default { 398 | static void init(const handle &h, type_record *r) { r->bases.append(h); } 399 | }; 400 | 401 | /// Process a parent class attribute (deprecated, does not support multiple inheritance) 402 | template 403 | struct process_attribute> : process_attribute_default> { 404 | static void init(const base &, type_record *r) { r->add_base(typeid(T), nullptr); } 405 | }; 406 | 407 | /// Process a multiple inheritance attribute 408 | template <> 409 | struct process_attribute : process_attribute_default { 410 | static void init(const multiple_inheritance &, type_record *r) { r->multiple_inheritance = true; } 411 | }; 412 | 413 | template <> 414 | struct process_attribute : process_attribute_default { 415 | static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; } 416 | }; 417 | 418 | template <> 419 | struct process_attribute : process_attribute_default { 420 | static void init(const buffer_protocol &, type_record *r) { r->buffer_protocol = true; } 421 | }; 422 | 423 | template <> 424 | struct process_attribute : process_attribute_default { 425 | static void init(const metaclass &m, type_record *r) { r->metaclass = m.value; } 426 | }; 427 | 428 | template <> 429 | struct process_attribute : process_attribute_default { 430 | static void init(const module_local &l, type_record *r) { r->module_local = l.value; } 431 | }; 432 | 433 | /// Process an 'arithmetic' attribute for enums (does nothing here) 434 | template <> 435 | struct process_attribute : process_attribute_default {}; 436 | 437 | template 438 | struct process_attribute> : process_attribute_default> { }; 439 | 440 | /** 441 | * Process a keep_alive call policy -- invokes keep_alive_impl during the 442 | * pre-call handler if both Nurse, Patient != 0 and use the post-call handler 443 | * otherwise 444 | */ 445 | template struct process_attribute> : public process_attribute_default> { 446 | template = 0> 447 | static void precall(function_call &call) { keep_alive_impl(Nurse, Patient, call, handle()); } 448 | template = 0> 449 | static void postcall(function_call &, handle) { } 450 | template = 0> 451 | static void precall(function_call &) { } 452 | template = 0> 453 | static void postcall(function_call &call, handle ret) { keep_alive_impl(Nurse, Patient, call, ret); } 454 | }; 455 | 456 | /// Recursively iterate over variadic template arguments 457 | template struct process_attributes { 458 | static void init(const Args&... args, function_record *r) { 459 | int unused[] = { 0, (process_attribute::type>::init(args, r), 0) ... }; 460 | ignore_unused(unused); 461 | } 462 | static void init(const Args&... args, type_record *r) { 463 | int unused[] = { 0, (process_attribute::type>::init(args, r), 0) ... }; 464 | ignore_unused(unused); 465 | } 466 | static void precall(function_call &call) { 467 | int unused[] = { 0, (process_attribute::type>::precall(call), 0) ... }; 468 | ignore_unused(unused); 469 | } 470 | static void postcall(function_call &call, handle fn_ret) { 471 | int unused[] = { 0, (process_attribute::type>::postcall(call, fn_ret), 0) ... 
}; 472 | ignore_unused(unused); 473 | } 474 | }; 475 | 476 | template 477 | using is_call_guard = is_instantiation; 478 | 479 | /// Extract the ``type`` from the first `call_guard` in `Extras...` (or `void_type` if none found) 480 | template 481 | using extract_guard_t = typename exactly_one_t, Extra...>::type; 482 | 483 | /// Check the number of named arguments at compile time 484 | template ::value...), 486 | size_t self = constexpr_sum(std::is_same::value...)> 487 | constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) { 488 | return named == 0 || (self + named + has_args + has_kwargs) == nargs; 489 | } 490 | 491 | NAMESPACE_END(detail) 492 | NAMESPACE_END(PYBIND11_NAMESPACE) 493 | -------------------------------------------------------------------------------- /pse/include/pybind11/buffer_info.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/buffer_info.h: Python buffer object interface 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "detail/common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | 16 | /// Information record describing a Python buffer object 17 | struct buffer_info { 18 | void *ptr = nullptr; // Pointer to the underlying storage 19 | ssize_t itemsize = 0; // Size of individual items in bytes 20 | ssize_t size = 0; // Total number of entries 21 | std::string format; // For homogeneous buffers, this should be set to format_descriptor::format() 22 | ssize_t ndim = 0; // Number of dimensions 23 | std::vector shape; // Shape of the tensor (1 entry per dimension) 24 | std::vector strides; // Number of entries between adjacent entries (for each per dimension) 25 | 26 | buffer_info() { } 27 | 28 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, 29 | detail::any_container shape_in, detail::any_container strides_in) 30 | : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim), 31 | shape(std::move(shape_in)), strides(std::move(strides_in)) { 32 | if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size()) 33 | pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length"); 34 | for (size_t i = 0; i < (size_t) ndim; ++i) 35 | size *= shape[i]; 36 | } 37 | 38 | template 39 | buffer_info(T *ptr, detail::any_container shape_in, detail::any_container strides_in) 40 | : buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor::format(), static_cast(shape_in->size()), std::move(shape_in), std::move(strides_in)) { } 41 | 42 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size) 43 | : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { } 44 | 45 | template 46 | buffer_info(T *ptr, ssize_t size) 47 | : buffer_info(ptr, sizeof(T), format_descriptor::format(), size) { } 48 | 49 | explicit buffer_info(Py_buffer *view, bool ownview = true) 50 | : buffer_info(view->buf, view->itemsize, view->format, view->ndim, 51 | {view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) { 52 | this->view = view; 53 | this->ownview = ownview; 54 | } 55 | 56 | buffer_info(const buffer_info &) = delete; 57 | buffer_info& operator=(const buffer_info &) = delete; 58 | 59 | buffer_info(buffer_info &&other) { 60 | (*this) = std::move(other); 61 | } 62 | 63 | buffer_info& operator=(buffer_info &&rhs) { 64 | ptr = rhs.ptr; 
65 | itemsize = rhs.itemsize; 66 | size = rhs.size; 67 | format = std::move(rhs.format); 68 | ndim = rhs.ndim; 69 | shape = std::move(rhs.shape); 70 | strides = std::move(rhs.strides); 71 | std::swap(view, rhs.view); 72 | std::swap(ownview, rhs.ownview); 73 | return *this; 74 | } 75 | 76 | ~buffer_info() { 77 | if (view && ownview) { PyBuffer_Release(view); delete view; } 78 | } 79 | 80 | private: 81 | struct private_ctr_tag { }; 82 | 83 | buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, 84 | detail::any_container &&shape_in, detail::any_container &&strides_in) 85 | : buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { } 86 | 87 | Py_buffer *view = nullptr; 88 | bool ownview = false; 89 | }; 90 | 91 | NAMESPACE_BEGIN(detail) 92 | 93 | template struct compare_buffer_info { 94 | static bool compare(const buffer_info& b) { 95 | return b.format == format_descriptor::format() && b.itemsize == (ssize_t) sizeof(T); 96 | } 97 | }; 98 | 99 | template struct compare_buffer_info::value>> { 100 | static bool compare(const buffer_info& b) { 101 | return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor::value || 102 | ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned::value ? "L" : "l")) || 103 | ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned::value ? "N" : "n"))); 104 | } 105 | }; 106 | 107 | NAMESPACE_END(detail) 108 | NAMESPACE_END(PYBIND11_NAMESPACE) 109 | -------------------------------------------------------------------------------- /pse/include/pybind11/chrono.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/chrono.h: Transparent conversion between std::chrono and python's datetime 3 | 4 | Copyright (c) 2016 Trent Houliston and 5 | Wenzel Jakob 6 | 7 | All rights reserved. Use of this source code is governed by a 8 | BSD-style license that can be found in the LICENSE file. 
9 | */ 10 | 11 | #pragma once 12 | 13 | #include "pybind11.h" 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | // Backport the PyDateTime_DELTA functions from Python3.3 if required 20 | #ifndef PyDateTime_DELTA_GET_DAYS 21 | #define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days) 22 | #endif 23 | #ifndef PyDateTime_DELTA_GET_SECONDS 24 | #define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds) 25 | #endif 26 | #ifndef PyDateTime_DELTA_GET_MICROSECONDS 27 | #define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds) 28 | #endif 29 | 30 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 31 | NAMESPACE_BEGIN(detail) 32 | 33 | template class duration_caster { 34 | public: 35 | typedef typename type::rep rep; 36 | typedef typename type::period period; 37 | 38 | typedef std::chrono::duration> days; 39 | 40 | bool load(handle src, bool) { 41 | using namespace std::chrono; 42 | 43 | // Lazy initialise the PyDateTime import 44 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 45 | 46 | if (!src) return false; 47 | // If invoked with datetime.delta object 48 | if (PyDelta_Check(src.ptr())) { 49 | value = type(duration_cast>( 50 | days(PyDateTime_DELTA_GET_DAYS(src.ptr())) 51 | + seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr())) 52 | + microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr())))); 53 | return true; 54 | } 55 | // If invoked with a float we assume it is seconds and convert 56 | else if (PyFloat_Check(src.ptr())) { 57 | value = type(duration_cast>(duration(PyFloat_AsDouble(src.ptr())))); 58 | return true; 59 | } 60 | else return false; 61 | } 62 | 63 | // If this is a duration just return it back 64 | static const std::chrono::duration& get_duration(const std::chrono::duration &src) { 65 | return src; 66 | } 67 | 68 | // If this is a time_point get the time_since_epoch 69 | template static std::chrono::duration get_duration(const std::chrono::time_point> &src) { 70 | return src.time_since_epoch(); 71 | } 72 | 73 | static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) { 74 | using namespace std::chrono; 75 | 76 | // Use overloaded function to get our duration from our source 77 | // Works out if it is a duration or time_point and get the duration 78 | auto d = get_duration(src); 79 | 80 | // Lazy initialise the PyDateTime import 81 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 82 | 83 | // Declare these special duration types so the conversions happen with the correct primitive types (int) 84 | using dd_t = duration>; 85 | using ss_t = duration>; 86 | using us_t = duration; 87 | 88 | auto dd = duration_cast(d); 89 | auto subd = d - dd; 90 | auto ss = duration_cast(subd); 91 | auto us = duration_cast(subd - ss); 92 | return PyDelta_FromDSU(dd.count(), ss.count(), us.count()); 93 | } 94 | 95 | PYBIND11_TYPE_CASTER(type, _("datetime.timedelta")); 96 | }; 97 | 98 | // This is for casting times on the system clock into datetime.datetime instances 99 | template class type_caster> { 100 | public: 101 | typedef std::chrono::time_point type; 102 | bool load(handle src, bool) { 103 | using namespace std::chrono; 104 | 105 | // Lazy initialise the PyDateTime import 106 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 107 | 108 | if (!src) return false; 109 | if (PyDateTime_Check(src.ptr())) { 110 | std::tm cal; 111 | cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr()); 112 | cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr()); 113 | cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr()); 114 | cal.tm_mday = 
PyDateTime_GET_DAY(src.ptr()); 115 | cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1; 116 | cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900; 117 | cal.tm_isdst = -1; 118 | 119 | value = system_clock::from_time_t(std::mktime(&cal)) + microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr())); 120 | return true; 121 | } 122 | else return false; 123 | } 124 | 125 | static handle cast(const std::chrono::time_point &src, return_value_policy /* policy */, handle /* parent */) { 126 | using namespace std::chrono; 127 | 128 | // Lazy initialise the PyDateTime import 129 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 130 | 131 | std::time_t tt = system_clock::to_time_t(src); 132 | // this function uses static memory so it's best to copy it out asap just in case 133 | // otherwise other code that is using localtime may break this (not just python code) 134 | std::tm localtime = *std::localtime(&tt); 135 | 136 | // Declare these special duration types so the conversions happen with the correct primitive types (int) 137 | using us_t = duration; 138 | 139 | return PyDateTime_FromDateAndTime(localtime.tm_year + 1900, 140 | localtime.tm_mon + 1, 141 | localtime.tm_mday, 142 | localtime.tm_hour, 143 | localtime.tm_min, 144 | localtime.tm_sec, 145 | (duration_cast(src.time_since_epoch() % seconds(1))).count()); 146 | } 147 | PYBIND11_TYPE_CASTER(type, _("datetime.datetime")); 148 | }; 149 | 150 | // Other clocks that are not the system clock are not measured as datetime.datetime objects 151 | // since they are not measured on calendar time. So instead we just make them timedeltas 152 | // Or if they have passed us a time as a float we convert that 153 | template class type_caster> 154 | : public duration_caster> { 155 | }; 156 | 157 | template class type_caster> 158 | : public duration_caster> { 159 | }; 160 | 161 | NAMESPACE_END(detail) 162 | NAMESPACE_END(PYBIND11_NAMESPACE) 163 | -------------------------------------------------------------------------------- /pse/include/pybind11/common.h: -------------------------------------------------------------------------------- 1 | #include "detail/common.h" 2 | #warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'." 3 | -------------------------------------------------------------------------------- /pse/include/pybind11/complex.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/complex.h: Complex number support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | /// glibc defines I as a macro which breaks things, e.g., boost template names 16 | #ifdef I 17 | # undef I 18 | #endif 19 | 20 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 21 | 22 | template struct format_descriptor, detail::enable_if_t::value>> { 23 | static constexpr const char c = format_descriptor::c; 24 | static constexpr const char value[3] = { 'Z', c, '\0' }; 25 | static std::string format() { return std::string(value); } 26 | }; 27 | 28 | #ifndef PYBIND11_CPP17 29 | 30 | template constexpr const char format_descriptor< 31 | std::complex, detail::enable_if_t::value>>::value[3]; 32 | 33 | #endif 34 | 35 | NAMESPACE_BEGIN(detail) 36 | 37 | template struct is_fmt_numeric, detail::enable_if_t::value>> { 38 | static constexpr bool value = true; 39 | static constexpr int index = is_fmt_numeric::index + 3; 40 | }; 41 | 42 | template class type_caster> { 43 | public: 44 | bool load(handle src, bool convert) { 45 | if (!src) 46 | return false; 47 | if (!convert && !PyComplex_Check(src.ptr())) 48 | return false; 49 | Py_complex result = PyComplex_AsCComplex(src.ptr()); 50 | if (result.real == -1.0 && PyErr_Occurred()) { 51 | PyErr_Clear(); 52 | return false; 53 | } 54 | value = std::complex((T) result.real, (T) result.imag); 55 | return true; 56 | } 57 | 58 | static handle cast(const std::complex &src, return_value_policy /* policy */, handle /* parent */) { 59 | return PyComplex_FromDoubles((double) src.real(), (double) src.imag()); 60 | } 61 | 62 | PYBIND11_TYPE_CASTER(std::complex, _("complex")); 63 | }; 64 | NAMESPACE_END(detail) 65 | NAMESPACE_END(PYBIND11_NAMESPACE) 66 | -------------------------------------------------------------------------------- /pse/include/pybind11/descr.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/descr.h: Helper type for concatenating type signatures 3 | either at runtime (C++11) or compile time (C++14) 4 | 5 | Copyright (c) 2016 Wenzel Jakob 6 | 7 | All rights reserved. Use of this source code is governed by a 8 | BSD-style license that can be found in the LICENSE file. 
9 | */ 10 | 11 | #pragma once 12 | 13 | #include "common.h" 14 | 15 | NAMESPACE_BEGIN(pybind11) 16 | NAMESPACE_BEGIN(detail) 17 | 18 | /* Concatenate type signatures at compile time using C++14 */ 19 | #if defined(PYBIND11_CPP14) && !defined(_MSC_VER) 20 | #define PYBIND11_CONSTEXPR_DESCR 21 | 22 | template class descr { 23 | template friend class descr; 24 | public: 25 | constexpr descr(char const (&text) [Size1+1], const std::type_info * const (&types)[Size2+1]) 26 | : descr(text, types, 27 | make_index_sequence(), 28 | make_index_sequence()) { } 29 | 30 | constexpr const char *text() const { return m_text; } 31 | constexpr const std::type_info * const * types() const { return m_types; } 32 | 33 | template 34 | constexpr descr operator+(const descr &other) const { 35 | return concat(other, 36 | make_index_sequence(), 37 | make_index_sequence(), 38 | make_index_sequence(), 39 | make_index_sequence()); 40 | } 41 | 42 | protected: 43 | template 44 | constexpr descr( 45 | char const (&text) [Size1+1], 46 | const std::type_info * const (&types) [Size2+1], 47 | index_sequence, index_sequence) 48 | : m_text{text[Indices1]..., '\0'}, 49 | m_types{types[Indices2]..., nullptr } {} 50 | 51 | template 53 | constexpr descr 54 | concat(const descr &other, 55 | index_sequence, index_sequence, 56 | index_sequence, index_sequence) const { 57 | return descr( 58 | { m_text[Indices1]..., other.m_text[OtherIndices1]..., '\0' }, 59 | { m_types[Indices2]..., other.m_types[OtherIndices2]..., nullptr } 60 | ); 61 | } 62 | 63 | protected: 64 | char m_text[Size1 + 1]; 65 | const std::type_info * m_types[Size2 + 1]; 66 | }; 67 | 68 | template constexpr descr _(char const(&text)[Size]) { 69 | return descr(text, { nullptr }); 70 | } 71 | 72 | template struct int_to_str : int_to_str { }; 73 | template struct int_to_str<0, Digits...> { 74 | static constexpr auto digits = descr({ ('0' + Digits)..., '\0' }, { nullptr }); 75 | }; 76 | 77 | // Ternary description (like std::conditional) 78 | template 79 | constexpr enable_if_t> _(char const(&text1)[Size1], char const(&)[Size2]) { 80 | return _(text1); 81 | } 82 | template 83 | constexpr enable_if_t> _(char const(&)[Size1], char const(&text2)[Size2]) { 84 | return _(text2); 85 | } 86 | template 87 | constexpr enable_if_t> _(descr d, descr) { return d; } 88 | template 89 | constexpr enable_if_t> _(descr, descr d) { return d; } 90 | 91 | template auto constexpr _() -> decltype(int_to_str::digits) { 92 | return int_to_str::digits; 93 | } 94 | 95 | template constexpr descr<1, 1> _() { 96 | return descr<1, 1>({ '%', '\0' }, { &typeid(Type), nullptr }); 97 | } 98 | 99 | inline constexpr descr<0, 0> concat() { return _(""); } 100 | template auto constexpr concat(descr descr) { return descr; } 101 | template auto constexpr concat(descr descr, Args&&... 
args) { return descr + _(", ") + concat(args...); } 102 | template auto constexpr type_descr(descr descr) { return _("{") + descr + _("}"); } 103 | 104 | #define PYBIND11_DESCR constexpr auto 105 | 106 | #else /* Simpler C++11 implementation based on run-time memory allocation and copying */ 107 | 108 | class descr { 109 | public: 110 | PYBIND11_NOINLINE descr(const char *text, const std::type_info * const * types) { 111 | size_t nChars = len(text), nTypes = len(types); 112 | m_text = new char[nChars]; 113 | m_types = new const std::type_info *[nTypes]; 114 | memcpy(m_text, text, nChars * sizeof(char)); 115 | memcpy(m_types, types, nTypes * sizeof(const std::type_info *)); 116 | } 117 | 118 | PYBIND11_NOINLINE descr operator+(descr &&d2) && { 119 | descr r; 120 | 121 | size_t nChars1 = len(m_text), nTypes1 = len(m_types); 122 | size_t nChars2 = len(d2.m_text), nTypes2 = len(d2.m_types); 123 | 124 | r.m_text = new char[nChars1 + nChars2 - 1]; 125 | r.m_types = new const std::type_info *[nTypes1 + nTypes2 - 1]; 126 | memcpy(r.m_text, m_text, (nChars1-1) * sizeof(char)); 127 | memcpy(r.m_text + nChars1 - 1, d2.m_text, nChars2 * sizeof(char)); 128 | memcpy(r.m_types, m_types, (nTypes1-1) * sizeof(std::type_info *)); 129 | memcpy(r.m_types + nTypes1 - 1, d2.m_types, nTypes2 * sizeof(std::type_info *)); 130 | 131 | delete[] m_text; delete[] m_types; 132 | delete[] d2.m_text; delete[] d2.m_types; 133 | 134 | return r; 135 | } 136 | 137 | char *text() { return m_text; } 138 | const std::type_info * * types() { return m_types; } 139 | 140 | protected: 141 | PYBIND11_NOINLINE descr() { } 142 | 143 | template static size_t len(const T *ptr) { // return length including null termination 144 | const T *it = ptr; 145 | while (*it++ != (T) 0) 146 | ; 147 | return static_cast(it - ptr); 148 | } 149 | 150 | const std::type_info **m_types = nullptr; 151 | char *m_text = nullptr; 152 | }; 153 | 154 | /* The 'PYBIND11_NOINLINE inline' combinations below are intentional to get the desired linkage while producing as little object code as possible */ 155 | 156 | PYBIND11_NOINLINE inline descr _(const char *text) { 157 | const std::type_info *types[1] = { nullptr }; 158 | return descr(text, types); 159 | } 160 | 161 | template PYBIND11_NOINLINE enable_if_t _(const char *text1, const char *) { return _(text1); } 162 | template PYBIND11_NOINLINE enable_if_t _(char const *, const char *text2) { return _(text2); } 163 | template PYBIND11_NOINLINE enable_if_t _(descr d, descr) { return d; } 164 | template PYBIND11_NOINLINE enable_if_t _(descr, descr d) { return d; } 165 | 166 | template PYBIND11_NOINLINE descr _() { 167 | const std::type_info *types[2] = { &typeid(Type), nullptr }; 168 | return descr("%", types); 169 | } 170 | 171 | template PYBIND11_NOINLINE descr _() { 172 | const std::type_info *types[1] = { nullptr }; 173 | return descr(std::to_string(Size).c_str(), types); 174 | } 175 | 176 | PYBIND11_NOINLINE inline descr concat() { return _(""); } 177 | PYBIND11_NOINLINE inline descr concat(descr &&d) { return d; } 178 | template PYBIND11_NOINLINE descr concat(descr &&d, Args&&... 
args) { return std::move(d) + _(", ") + concat(std::forward(args)...); } 179 | PYBIND11_NOINLINE inline descr type_descr(descr&& d) { return _("{") + std::move(d) + _("}"); } 180 | 181 | #define PYBIND11_DESCR ::pybind11::detail::descr 182 | #endif 183 | 184 | NAMESPACE_END(detail) 185 | NAMESPACE_END(pybind11) 186 | -------------------------------------------------------------------------------- /pse/include/pybind11/detail/descr.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | NAMESPACE_BEGIN(detail) 16 | 17 | #if !defined(_MSC_VER) 18 | # define PYBIND11_DESCR_CONSTEXPR static constexpr 19 | #else 20 | # define PYBIND11_DESCR_CONSTEXPR const 21 | #endif 22 | 23 | /* Concatenate type signatures at compile time */ 24 | template 25 | struct descr { 26 | char text[N + 1]; 27 | 28 | constexpr descr() : text{'\0'} { } 29 | constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence()) { } 30 | 31 | template 32 | constexpr descr(char const (&s)[N+1], index_sequence) : text{s[Is]..., '\0'} { } 33 | 34 | template 35 | constexpr descr(char c, Chars... cs) : text{c, static_cast(cs)..., '\0'} { } 36 | 37 | static constexpr std::array types() { 38 | return {{&typeid(Ts)..., nullptr}}; 39 | } 40 | }; 41 | 42 | template 43 | constexpr descr plus_impl(const descr &a, const descr &b, 44 | index_sequence, index_sequence) { 45 | return {a.text[Is1]..., b.text[Is2]...}; 46 | } 47 | 48 | template 49 | constexpr descr operator+(const descr &a, const descr &b) { 50 | return plus_impl(a, b, make_index_sequence(), make_index_sequence()); 51 | } 52 | 53 | template 54 | constexpr descr _(char const(&text)[N]) { return descr(text); } 55 | constexpr descr<0> _(char const(&)[1]) { return {}; } 56 | 57 | template struct int_to_str : int_to_str { }; 58 | template struct int_to_str<0, Digits...> { 59 | static constexpr auto digits = descr(('0' + Digits)...); 60 | }; 61 | 62 | // Ternary description (like std::conditional) 63 | template 64 | constexpr enable_if_t> _(char const(&text1)[N1], char const(&)[N2]) { 65 | return _(text1); 66 | } 67 | template 68 | constexpr enable_if_t> _(char const(&)[N1], char const(&text2)[N2]) { 69 | return _(text2); 70 | } 71 | 72 | template 73 | constexpr enable_if_t _(const T1 &d, const T2 &) { return d; } 74 | template 75 | constexpr enable_if_t _(const T1 &, const T2 &d) { return d; } 76 | 77 | template auto constexpr _() -> decltype(int_to_str::digits) { 78 | return int_to_str::digits; 79 | } 80 | 81 | template constexpr descr<1, Type> _() { return {'%'}; } 82 | 83 | constexpr descr<0> concat() { return {}; } 84 | 85 | template 86 | constexpr descr concat(const descr &descr) { return descr; } 87 | 88 | template 89 | constexpr auto concat(const descr &d, const Args &...args) 90 | -> decltype(std::declval>() + concat(args...)) { 91 | return d + _(", ") + concat(args...); 92 | } 93 | 94 | template 95 | constexpr descr type_descr(const descr &descr) { 96 | return _("{") + descr + _("}"); 97 | } 98 | 99 | NAMESPACE_END(detail) 100 | NAMESPACE_END(PYBIND11_NAMESPACE) 101 | -------------------------------------------------------------------------------- 
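Both copies of descr.h above serve the same purpose: assembling the Python signature text of a bound function out of small pieces, either at run time (the C++11 fallback) or entirely at compile time. As a quick illustration only (not part of the vendored headers; it assumes the constexpr implementation from detail/descr.h is the one in use and that the pybind11 and Python headers are on the include path), the snippet below shows how `_()`, `concat()` and `operator+` compose into a single fixed-size character array:

// Illustrative sketch only; not part of the PSENet repository or of pybind11 itself.
#include <pybind11/pybind11.h>          // pulls in detail/descr.h
namespace pd = pybind11::detail;

// Every concatenation below is evaluated at compile time.
constexpr auto sig = pd::_("foo(") + pd::concat(pd::_("int"), pd::_("float")) + pd::_(") -> None");

// descr<N> stores the assembled text as a NUL-terminated char array of exactly the right
// size, so sig.text holds "foo(int, float) -> None".
static_assert(sizeof(sig.text) == sizeof("foo(int, float) -> None"),
              "compile-time signature assembly");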
/pse/include/pybind11/detail/init.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/detail/init.h: init factory function implementation and support code. 3 | 4 | Copyright (c) 2017 Jason Rhinelander 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "class.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | NAMESPACE_BEGIN(detail) 16 | 17 | template <> 18 | class type_caster { 19 | public: 20 | bool load(handle h, bool) { 21 | value = reinterpret_cast(h.ptr()); 22 | return true; 23 | } 24 | 25 | template using cast_op_type = value_and_holder &; 26 | operator value_and_holder &() { return *value; } 27 | static constexpr auto name = _(); 28 | 29 | private: 30 | value_and_holder *value = nullptr; 31 | }; 32 | 33 | NAMESPACE_BEGIN(initimpl) 34 | 35 | inline void no_nullptr(void *ptr) { 36 | if (!ptr) throw type_error("pybind11::init(): factory function returned nullptr"); 37 | } 38 | 39 | // Implementing functions for all forms of py::init<...> and py::init(...) 40 | template using Cpp = typename Class::type; 41 | template using Alias = typename Class::type_alias; 42 | template using Holder = typename Class::holder_type; 43 | 44 | template using is_alias_constructible = std::is_constructible, Cpp &&>; 45 | 46 | // Takes a Cpp pointer and returns true if it actually is a polymorphic Alias instance. 47 | template = 0> 48 | bool is_alias(Cpp *ptr) { 49 | return dynamic_cast *>(ptr) != nullptr; 50 | } 51 | // Failing fallback version of the above for a no-alias class (always returns false) 52 | template 53 | constexpr bool is_alias(void *) { return false; } 54 | 55 | // Constructs and returns a new object; if the given arguments don't map to a constructor, we fall 56 | // back to brace aggregate initiailization so that for aggregate initialization can be used with 57 | // py::init, e.g. `py::init` to initialize a `struct T { int a; int b; }`. For 58 | // non-aggregate types, we need to use an ordinary T(...) constructor (invoking as `T{...}` usually 59 | // works, but will not do the expected thing when `T` has an `initializer_list` constructor). 60 | template ::value, int> = 0> 61 | inline Class *construct_or_initialize(Args &&...args) { return new Class(std::forward(args)...); } 62 | template ::value, int> = 0> 63 | inline Class *construct_or_initialize(Args &&...args) { return new Class{std::forward(args)...}; } 64 | 65 | // Attempts to constructs an alias using a `Alias(Cpp &&)` constructor. This allows types with 66 | // an alias to provide only a single Cpp factory function as long as the Alias can be 67 | // constructed from an rvalue reference of the base Cpp type. This means that Alias classes 68 | // can, when appropriate, simply define a `Alias(Cpp &&)` constructor rather than needing to 69 | // inherit all the base class constructors. 
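// Illustrative sketch (the Animal/PyAnimal names are hypothetical, not from this header):
// in user code the pattern described above looks roughly like
//
//     struct PyAnimal : Animal {
//         using Animal::Animal;
//         PyAnimal(Animal &&base) : Animal(std::move(base)) { }   // enables alias construction
//     };
//     py::class_<Animal, PyAnimal>(m, "Animal")
//         .def(py::init([](const std::string &name) { return Animal(name); }));
//
// The single factory returns a plain Animal; when the instance is actually a Python-side
// subclass, construct_alias_from_cpp() below move-constructs a PyAnimal from that Animal.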
70 | template 71 | void construct_alias_from_cpp(std::true_type /*is_alias_constructible*/, 72 | value_and_holder &v_h, Cpp &&base) { 73 | v_h.value_ptr() = new Alias(std::move(base)); 74 | } 75 | template 76 | [[noreturn]] void construct_alias_from_cpp(std::false_type /*!is_alias_constructible*/, 77 | value_and_holder &, Cpp &&) { 78 | throw type_error("pybind11::init(): unable to convert returned instance to required " 79 | "alias class: no `Alias(Class &&)` constructor available"); 80 | } 81 | 82 | // Error-generating fallback for factories that don't match one of the below construction 83 | // mechanisms. 84 | template 85 | void construct(...) { 86 | static_assert(!std::is_same::value /* always false */, 87 | "pybind11::init(): init function must return a compatible pointer, " 88 | "holder, or value"); 89 | } 90 | 91 | // Pointer return v1: the factory function returns a class pointer for a registered class. 92 | // If we don't need an alias (because this class doesn't have one, or because the final type is 93 | // inherited on the Python side) we can simply take over ownership. Otherwise we need to try to 94 | // construct an Alias from the returned base instance. 95 | template 96 | void construct(value_and_holder &v_h, Cpp *ptr, bool need_alias) { 97 | no_nullptr(ptr); 98 | if (Class::has_alias && need_alias && !is_alias(ptr)) { 99 | // We're going to try to construct an alias by moving the cpp type. Whether or not 100 | // that succeeds, we still need to destroy the original cpp pointer (either the 101 | // moved away leftover, if the alias construction works, or the value itself if we 102 | // throw an error), but we can't just call `delete ptr`: it might have a special 103 | // deleter, or might be shared_from_this. So we construct a holder around it as if 104 | // it was a normal instance, then steal the holder away into a local variable; thus 105 | // the holder and destruction happens when we leave the C++ scope, and the holder 106 | // class gets to handle the destruction however it likes. 107 | v_h.value_ptr() = ptr; 108 | v_h.set_instance_registered(true); // To prevent init_instance from registering it 109 | v_h.type->init_instance(v_h.inst, nullptr); // Set up the holder 110 | Holder temp_holder(std::move(v_h.holder>())); // Steal the holder 111 | v_h.type->dealloc(v_h); // Destroys the moved-out holder remains, resets value ptr to null 112 | v_h.set_instance_registered(false); 113 | 114 | construct_alias_from_cpp(is_alias_constructible{}, v_h, std::move(*ptr)); 115 | } else { 116 | // Otherwise the type isn't inherited, so we don't need an Alias 117 | v_h.value_ptr() = ptr; 118 | } 119 | } 120 | 121 | // Pointer return v2: a factory that always returns an alias instance ptr. We simply take over 122 | // ownership of the pointer. 123 | template = 0> 124 | void construct(value_and_holder &v_h, Alias *alias_ptr, bool) { 125 | no_nullptr(alias_ptr); 126 | v_h.value_ptr() = static_cast *>(alias_ptr); 127 | } 128 | 129 | // Holder return: copy its pointer, and move or copy the returned holder into the new instance's 130 | // holder. This also handles types like std::shared_ptr and std::unique_ptr where T is a 131 | // derived type (through those holder's implicit conversion from derived class holder constructors). 
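// Illustrative sketch (the Widget name is hypothetical, not from this header): a factory
// may also return the holder itself, e.g.
//
//     py::class_<Widget, std::shared_ptr<Widget>>(m, "Widget")
//         .def(py::init([]() { return std::make_shared<Widget>(); }));
//
// The holder's raw pointer becomes the instance value, and the holder is then moved or
// copied into the new instance by init_instance(), as implemented in construct() below.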
132 | template 133 | void construct(value_and_holder &v_h, Holder holder, bool need_alias) { 134 | auto *ptr = holder_helper>::get(holder); 135 | // If we need an alias, check that the held pointer is actually an alias instance 136 | if (Class::has_alias && need_alias && !is_alias(ptr)) 137 | throw type_error("pybind11::init(): construction failed: returned holder-wrapped instance " 138 | "is not an alias instance"); 139 | 140 | v_h.value_ptr() = ptr; 141 | v_h.type->init_instance(v_h.inst, &holder); 142 | } 143 | 144 | // return-by-value version 1: returning a cpp class by value. If the class has an alias and an 145 | // alias is required the alias must have an `Alias(Cpp &&)` constructor so that we can construct 146 | // the alias from the base when needed (i.e. because of Python-side inheritance). When we don't 147 | // need it, we simply move-construct the cpp value into a new instance. 148 | template 149 | void construct(value_and_holder &v_h, Cpp &&result, bool need_alias) { 150 | static_assert(std::is_move_constructible>::value, 151 | "pybind11::init() return-by-value factory function requires a movable class"); 152 | if (Class::has_alias && need_alias) 153 | construct_alias_from_cpp(is_alias_constructible{}, v_h, std::move(result)); 154 | else 155 | v_h.value_ptr() = new Cpp(std::move(result)); 156 | } 157 | 158 | // return-by-value version 2: returning a value of the alias type itself. We move-construct an 159 | // Alias instance (even if no the python-side inheritance is involved). The is intended for 160 | // cases where Alias initialization is always desired. 161 | template 162 | void construct(value_and_holder &v_h, Alias &&result, bool) { 163 | static_assert(std::is_move_constructible>::value, 164 | "pybind11::init() return-by-alias-value factory function requires a movable alias class"); 165 | v_h.value_ptr() = new Alias(std::move(result)); 166 | } 167 | 168 | // Implementing class for py::init<...>() 169 | template 170 | struct constructor { 171 | template = 0> 172 | static void execute(Class &cl, const Extra&... extra) { 173 | cl.def("__init__", [](value_and_holder &v_h, Args... args) { 174 | v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); 175 | }, is_new_style_constructor(), extra...); 176 | } 177 | 178 | template , Args...>::value, int> = 0> 181 | static void execute(Class &cl, const Extra&... extra) { 182 | cl.def("__init__", [](value_and_holder &v_h, Args... args) { 183 | if (Py_TYPE(v_h.inst) == v_h.type->type) 184 | v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); 185 | else 186 | v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); 187 | }, is_new_style_constructor(), extra...); 188 | } 189 | 190 | template , Args...>::value, int> = 0> 193 | static void execute(Class &cl, const Extra&... extra) { 194 | cl.def("__init__", [](value_and_holder &v_h, Args... args) { 195 | v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); 196 | }, is_new_style_constructor(), extra...); 197 | } 198 | }; 199 | 200 | // Implementing class for py::init_alias<...>() 201 | template struct alias_constructor { 202 | template , Args...>::value, int> = 0> 204 | static void execute(Class &cl, const Extra&... extra) { 205 | cl.def("__init__", [](value_and_holder &v_h, Args... 
args) { 206 | v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); 207 | }, is_new_style_constructor(), extra...); 208 | } 209 | }; 210 | 211 | // Implementation class for py::init(Func) and py::init(Func, AliasFunc) 212 | template , typename = function_signature_t> 214 | struct factory; 215 | 216 | // Specialization for py::init(Func) 217 | template 218 | struct factory { 219 | remove_reference_t class_factory; 220 | 221 | factory(Func &&f) : class_factory(std::forward(f)) { } 222 | 223 | // The given class either has no alias or has no separate alias factory; 224 | // this always constructs the class itself. If the class is registered with an alias 225 | // type and an alias instance is needed (i.e. because the final type is a Python class 226 | // inheriting from the C++ type) the returned value needs to either already be an alias 227 | // instance, or the alias needs to be constructible from a `Class &&` argument. 228 | template 229 | void execute(Class &cl, const Extra &...extra) && { 230 | #if defined(PYBIND11_CPP14) 231 | cl.def("__init__", [func = std::move(class_factory)] 232 | #else 233 | auto &func = class_factory; 234 | cl.def("__init__", [func] 235 | #endif 236 | (value_and_holder &v_h, Args... args) { 237 | construct(v_h, func(std::forward(args)...), 238 | Py_TYPE(v_h.inst) != v_h.type->type); 239 | }, is_new_style_constructor(), extra...); 240 | } 241 | }; 242 | 243 | // Specialization for py::init(Func, AliasFunc) 244 | template 246 | struct factory { 247 | static_assert(sizeof...(CArgs) == sizeof...(AArgs), 248 | "pybind11::init(class_factory, alias_factory): class and alias factories " 249 | "must have identical argument signatures"); 250 | static_assert(all_of...>::value, 251 | "pybind11::init(class_factory, alias_factory): class and alias factories " 252 | "must have identical argument signatures"); 253 | 254 | remove_reference_t class_factory; 255 | remove_reference_t alias_factory; 256 | 257 | factory(CFunc &&c, AFunc &&a) 258 | : class_factory(std::forward(c)), alias_factory(std::forward(a)) { } 259 | 260 | // The class factory is called when the `self` type passed to `__init__` is the direct 261 | // class (i.e. not inherited), the alias factory when `self` is a Python-side subtype. 262 | template 263 | void execute(Class &cl, const Extra&... extra) && { 264 | static_assert(Class::has_alias, "The two-argument version of `py::init()` can " 265 | "only be used if the class has an alias"); 266 | #if defined(PYBIND11_CPP14) 267 | cl.def("__init__", [class_func = std::move(class_factory), alias_func = std::move(alias_factory)] 268 | #else 269 | auto &class_func = class_factory; 270 | auto &alias_func = alias_factory; 271 | cl.def("__init__", [class_func, alias_func] 272 | #endif 273 | (value_and_holder &v_h, CArgs... args) { 274 | if (Py_TYPE(v_h.inst) == v_h.type->type) 275 | // If the instance type equals the registered type we don't have inheritance, so 276 | // don't need the alias and can construct using the class function: 277 | construct(v_h, class_func(std::forward(args)...), false); 278 | else 279 | construct(v_h, alias_func(std::forward(args)...), true); 280 | }, is_new_style_constructor(), extra...); 281 | } 282 | }; 283 | 284 | /// Set just the C++ state. Same as `__init__`. 
285 | template 286 | void setstate(value_and_holder &v_h, T &&result, bool need_alias) { 287 | construct(v_h, std::forward(result), need_alias); 288 | } 289 | 290 | /// Set both the C++ and Python states 291 | template ::value, int> = 0> 293 | void setstate(value_and_holder &v_h, std::pair &&result, bool need_alias) { 294 | construct(v_h, std::move(result.first), need_alias); 295 | setattr((PyObject *) v_h.inst, "__dict__", result.second); 296 | } 297 | 298 | /// Implementation for py::pickle(GetState, SetState) 299 | template , typename = function_signature_t> 301 | struct pickle_factory; 302 | 303 | template 305 | struct pickle_factory { 306 | static_assert(std::is_same, intrinsic_t>::value, 307 | "The type returned by `__getstate__` must be the same " 308 | "as the argument accepted by `__setstate__`"); 309 | 310 | remove_reference_t get; 311 | remove_reference_t set; 312 | 313 | pickle_factory(Get get, Set set) 314 | : get(std::forward(get)), set(std::forward(set)) { } 315 | 316 | template 317 | void execute(Class &cl, const Extra &...extra) && { 318 | cl.def("__getstate__", std::move(get)); 319 | 320 | #if defined(PYBIND11_CPP14) 321 | cl.def("__setstate__", [func = std::move(set)] 322 | #else 323 | auto &func = set; 324 | cl.def("__setstate__", [func] 325 | #endif 326 | (value_and_holder &v_h, ArgState state) { 327 | setstate(v_h, func(std::forward(state)), 328 | Py_TYPE(v_h.inst) != v_h.type->type); 329 | }, is_new_style_constructor(), extra...); 330 | } 331 | }; 332 | 333 | NAMESPACE_END(initimpl) 334 | NAMESPACE_END(detail) 335 | NAMESPACE_END(pybind11) 336 | -------------------------------------------------------------------------------- /pse/include/pybind11/detail/internals.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/detail/internals.h: Internal data structure and related functions 3 | 4 | Copyright (c) 2017 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "../pytypes.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | NAMESPACE_BEGIN(detail) 16 | // Forward declarations 17 | inline PyTypeObject *make_static_property_type(); 18 | inline PyTypeObject *make_default_metaclass(); 19 | inline PyObject *make_object_base_type(PyTypeObject *metaclass); 20 | 21 | // The old Python Thread Local Storage (TLS) API is deprecated in Python 3.7 in favor of the new 22 | // Thread Specific Storage (TSS) API. 
23 | #if PY_VERSION_HEX >= 0x03070000 24 | # define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr 25 | # define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key)) 26 | # define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (tstate)) 27 | # define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr) 28 | #else 29 | // Usually an int but a long on Cygwin64 with Python 3.x 30 | # define PYBIND11_TLS_KEY_INIT(var) decltype(PyThread_create_key()) var = 0 31 | # define PYBIND11_TLS_GET_VALUE(key) PyThread_get_key_value((key)) 32 | # if PY_MAJOR_VERSION < 3 33 | # define PYBIND11_TLS_DELETE_VALUE(key) \ 34 | PyThread_delete_key_value(key) 35 | # define PYBIND11_TLS_REPLACE_VALUE(key, value) \ 36 | do { \ 37 | PyThread_delete_key_value((key)); \ 38 | PyThread_set_key_value((key), (value)); \ 39 | } while (false) 40 | # else 41 | # define PYBIND11_TLS_DELETE_VALUE(key) \ 42 | PyThread_set_key_value((key), nullptr) 43 | # define PYBIND11_TLS_REPLACE_VALUE(key, value) \ 44 | PyThread_set_key_value((key), (value)) 45 | # endif 46 | #endif 47 | 48 | // Python loads modules by default with dlopen with the RTLD_LOCAL flag; under libc++ and possibly 49 | // other STLs, this means `typeid(A)` from one module won't equal `typeid(A)` from another module 50 | // even when `A` is the same, non-hidden-visibility type (e.g. from a common include). Under 51 | // libstdc++, this doesn't happen: equality and the type_index hash are based on the type name, 52 | // which works. If not under a known-good stl, provide our own name-based hash and equality 53 | // functions that use the type name. 54 | #if defined(__GLIBCXX__) 55 | inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { return lhs == rhs; } 56 | using type_hash = std::hash; 57 | using type_equal_to = std::equal_to; 58 | #else 59 | inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { 60 | return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0; 61 | } 62 | 63 | struct type_hash { 64 | size_t operator()(const std::type_index &t) const { 65 | size_t hash = 5381; 66 | const char *ptr = t.name(); 67 | while (auto c = static_cast(*ptr++)) 68 | hash = (hash * 33) ^ c; 69 | return hash; 70 | } 71 | }; 72 | 73 | struct type_equal_to { 74 | bool operator()(const std::type_index &lhs, const std::type_index &rhs) const { 75 | return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0; 76 | } 77 | }; 78 | #endif 79 | 80 | template 81 | using type_map = std::unordered_map; 82 | 83 | struct overload_hash { 84 | inline size_t operator()(const std::pair& v) const { 85 | size_t value = std::hash()(v.first); 86 | value ^= std::hash()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2); 87 | return value; 88 | } 89 | }; 90 | 91 | /// Internal data structure used to track registered instances and types. 92 | /// Whenever binary incompatible changes are made to this structure, 93 | /// `PYBIND11_INTERNALS_VERSION` must be incremented. 
94 | struct internals { 95 | type_map registered_types_cpp; // std::type_index -> pybind11's type information 96 | std::unordered_map> registered_types_py; // PyTypeObject* -> base type_info(s) 97 | std::unordered_multimap registered_instances; // void * -> instance* 98 | std::unordered_set, overload_hash> inactive_overload_cache; 99 | type_map> direct_conversions; 100 | std::unordered_map> patients; 101 | std::forward_list registered_exception_translators; 102 | std::unordered_map shared_data; // Custom data to be shared across extensions 103 | std::vector loader_patient_stack; // Used by `loader_life_support` 104 | std::forward_list static_strings; // Stores the std::strings backing detail::c_str() 105 | PyTypeObject *static_property_type; 106 | PyTypeObject *default_metaclass; 107 | PyObject *instance_base; 108 | #if defined(WITH_THREAD) 109 | PYBIND11_TLS_KEY_INIT(tstate); 110 | PyInterpreterState *istate = nullptr; 111 | #endif 112 | }; 113 | 114 | /// Additional type information which does not fit into the PyTypeObject. 115 | /// Changes to this struct also require bumping `PYBIND11_INTERNALS_VERSION`. 116 | struct type_info { 117 | PyTypeObject *type; 118 | const std::type_info *cpptype; 119 | size_t type_size, type_align, holder_size_in_ptrs; 120 | void *(*operator_new)(size_t); 121 | void (*init_instance)(instance *, const void *); 122 | void (*dealloc)(value_and_holder &v_h); 123 | std::vector implicit_conversions; 124 | std::vector> implicit_casts; 125 | std::vector *direct_conversions; 126 | buffer_info *(*get_buffer)(PyObject *, void *) = nullptr; 127 | void *get_buffer_data = nullptr; 128 | void *(*module_local_load)(PyObject *, const type_info *) = nullptr; 129 | /* A simple type never occurs as a (direct or indirect) parent 130 | * of a class that makes use of multiple inheritance */ 131 | bool simple_type : 1; 132 | /* True if there is no multiple inheritance in this type's inheritance tree */ 133 | bool simple_ancestors : 1; 134 | /* for base vs derived holder_type checks */ 135 | bool default_holder : 1; 136 | /* true if this is a type registered with py::module_local */ 137 | bool module_local : 1; 138 | }; 139 | 140 | /// Tracks the `internals` and `type_info` ABI version independent of the main library version 141 | #define PYBIND11_INTERNALS_VERSION 3 142 | 143 | #if defined(_DEBUG) 144 | # define PYBIND11_BUILD_TYPE "_debug" 145 | #else 146 | # define PYBIND11_BUILD_TYPE "" 147 | #endif 148 | 149 | #if defined(WITH_THREAD) 150 | # define PYBIND11_INTERNALS_KIND "" 151 | #else 152 | # define PYBIND11_INTERNALS_KIND "_without_thread" 153 | #endif 154 | 155 | #define PYBIND11_INTERNALS_ID "__pybind11_internals_v" \ 156 | PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__" 157 | 158 | #define PYBIND11_MODULE_LOCAL_ID "__pybind11_module_local_v" \ 159 | PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__" 160 | 161 | /// Each module locally stores a pointer to the `internals` data. The data 162 | /// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`. 
163 | inline internals **&get_internals_pp() { 164 | static internals **internals_pp = nullptr; 165 | return internals_pp; 166 | } 167 | 168 | /// Return a reference to the current `internals` data 169 | PYBIND11_NOINLINE inline internals &get_internals() { 170 | auto **&internals_pp = get_internals_pp(); 171 | if (internals_pp && *internals_pp) 172 | return **internals_pp; 173 | 174 | constexpr auto *id = PYBIND11_INTERNALS_ID; 175 | auto builtins = handle(PyEval_GetBuiltins()); 176 | if (builtins.contains(id) && isinstance(builtins[id])) { 177 | internals_pp = static_cast(capsule(builtins[id])); 178 | 179 | // We loaded builtins through python's builtins, which means that our `error_already_set` 180 | // and `builtin_exception` may be different local classes than the ones set up in the 181 | // initial exception translator, below, so add another for our local exception classes. 182 | // 183 | // libstdc++ doesn't require this (types there are identified only by name) 184 | #if !defined(__GLIBCXX__) 185 | (*internals_pp)->registered_exception_translators.push_front( 186 | [](std::exception_ptr p) -> void { 187 | try { 188 | if (p) std::rethrow_exception(p); 189 | } catch (error_already_set &e) { e.restore(); return; 190 | } catch (const builtin_exception &e) { e.set_error(); return; 191 | } 192 | } 193 | ); 194 | #endif 195 | } else { 196 | if (!internals_pp) internals_pp = new internals*(); 197 | auto *&internals_ptr = *internals_pp; 198 | internals_ptr = new internals(); 199 | #if defined(WITH_THREAD) 200 | PyEval_InitThreads(); 201 | PyThreadState *tstate = PyThreadState_Get(); 202 | #if PY_VERSION_HEX >= 0x03070000 203 | internals_ptr->tstate = PyThread_tss_alloc(); 204 | if (!internals_ptr->tstate || PyThread_tss_create(internals_ptr->tstate)) 205 | pybind11_fail("get_internals: could not successfully initialize the TSS key!"); 206 | PyThread_tss_set(internals_ptr->tstate, tstate); 207 | #else 208 | internals_ptr->tstate = PyThread_create_key(); 209 | if (internals_ptr->tstate == -1) 210 | pybind11_fail("get_internals: could not successfully initialize the TLS key!"); 211 | PyThread_set_key_value(internals_ptr->tstate, tstate); 212 | #endif 213 | internals_ptr->istate = tstate->interp; 214 | #endif 215 | builtins[id] = capsule(internals_pp); 216 | internals_ptr->registered_exception_translators.push_front( 217 | [](std::exception_ptr p) -> void { 218 | try { 219 | if (p) std::rethrow_exception(p); 220 | } catch (error_already_set &e) { e.restore(); return; 221 | } catch (const builtin_exception &e) { e.set_error(); return; 222 | } catch (const std::bad_alloc &e) { PyErr_SetString(PyExc_MemoryError, e.what()); return; 223 | } catch (const std::domain_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; 224 | } catch (const std::invalid_argument &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; 225 | } catch (const std::length_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; 226 | } catch (const std::out_of_range &e) { PyErr_SetString(PyExc_IndexError, e.what()); return; 227 | } catch (const std::range_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; 228 | } catch (const std::exception &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); return; 229 | } catch (...) 
{ 230 | PyErr_SetString(PyExc_RuntimeError, "Caught an unknown exception!"); 231 | return; 232 | } 233 | } 234 | ); 235 | internals_ptr->static_property_type = make_static_property_type(); 236 | internals_ptr->default_metaclass = make_default_metaclass(); 237 | internals_ptr->instance_base = make_object_base_type(internals_ptr->default_metaclass); 238 | } 239 | return **internals_pp; 240 | } 241 | 242 | /// Works like `internals.registered_types_cpp`, but for module-local registered types: 243 | inline type_map ®istered_local_types_cpp() { 244 | static type_map locals{}; 245 | return locals; 246 | } 247 | 248 | /// Constructs a std::string with the given arguments, stores it in `internals`, and returns its 249 | /// `c_str()`. Such strings objects have a long storage duration -- the internal strings are only 250 | /// cleared when the program exits or after interpreter shutdown (when embedding), and so are 251 | /// suitable for c-style strings needed by Python internals (such as PyTypeObject's tp_name). 252 | template 253 | const char *c_str(Args &&...args) { 254 | auto &strings = get_internals().static_strings; 255 | strings.emplace_front(std::forward(args)...); 256 | return strings.front().c_str(); 257 | } 258 | 259 | NAMESPACE_END(detail) 260 | 261 | /// Returns a named pointer that is shared among all extension modules (using the same 262 | /// pybind11 version) running in the current interpreter. Names starting with underscores 263 | /// are reserved for internal usage. Returns `nullptr` if no matching entry was found. 264 | inline PYBIND11_NOINLINE void *get_shared_data(const std::string &name) { 265 | auto &internals = detail::get_internals(); 266 | auto it = internals.shared_data.find(name); 267 | return it != internals.shared_data.end() ? it->second : nullptr; 268 | } 269 | 270 | /// Set the shared data that can be later recovered by `get_shared_data()`. 271 | inline PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) { 272 | detail::get_internals().shared_data[name] = data; 273 | return data; 274 | } 275 | 276 | /// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if 277 | /// such entry exists. Otherwise, a new object of default-constructible type `T` is 278 | /// added to the shared data under the given name and a reference to it is returned. 279 | template 280 | T &get_or_create_shared_data(const std::string &name) { 281 | auto &internals = detail::get_internals(); 282 | auto it = internals.shared_data.find(name); 283 | T *ptr = (T *) (it != internals.shared_data.end() ? it->second : nullptr); 284 | if (!ptr) { 285 | ptr = new T(); 286 | internals.shared_data[name] = ptr; 287 | } 288 | return *ptr; 289 | } 290 | 291 | NAMESPACE_END(PYBIND11_NAMESPACE) 292 | -------------------------------------------------------------------------------- /pse/include/pybind11/detail/typeid.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/detail/typeid.h: Compiler-independent access to type identifiers 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include 13 | #include 14 | 15 | #if defined(__GNUG__) 16 | #include 17 | #endif 18 | 19 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 20 | NAMESPACE_BEGIN(detail) 21 | /// Erase all occurrences of a substring 22 | inline void erase_all(std::string &string, const std::string &search) { 23 | for (size_t pos = 0;;) { 24 | pos = string.find(search, pos); 25 | if (pos == std::string::npos) break; 26 | string.erase(pos, search.length()); 27 | } 28 | } 29 | 30 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) { 31 | #if defined(__GNUG__) 32 | int status = 0; 33 | std::unique_ptr res { 34 | abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free }; 35 | if (status == 0) 36 | name = res.get(); 37 | #else 38 | detail::erase_all(name, "class "); 39 | detail::erase_all(name, "struct "); 40 | detail::erase_all(name, "enum "); 41 | #endif 42 | detail::erase_all(name, "pybind11::"); 43 | } 44 | NAMESPACE_END(detail) 45 | 46 | /// Return a string representation of a C++ type 47 | template static std::string type_id() { 48 | std::string name(typeid(T).name()); 49 | detail::clean_type_id(name); 50 | return name; 51 | } 52 | 53 | NAMESPACE_END(PYBIND11_NAMESPACE) 54 | -------------------------------------------------------------------------------- /pse/include/pybind11/embed.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/embed.h: Support for embedding the interpreter 3 | 4 | Copyright (c) 2017 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include "eval.h" 14 | 15 | #if defined(PYPY_VERSION) 16 | # error Embedding the interpreter is not supported with PyPy 17 | #endif 18 | 19 | #if PY_MAJOR_VERSION >= 3 20 | # define PYBIND11_EMBEDDED_MODULE_IMPL(name) \ 21 | extern "C" PyObject *pybind11_init_impl_##name() { \ 22 | return pybind11_init_wrapper_##name(); \ 23 | } 24 | #else 25 | # define PYBIND11_EMBEDDED_MODULE_IMPL(name) \ 26 | extern "C" void pybind11_init_impl_##name() { \ 27 | pybind11_init_wrapper_##name(); \ 28 | } 29 | #endif 30 | 31 | /** \rst 32 | Add a new module to the table of builtins for the interpreter. Must be 33 | defined in global scope. The first macro parameter is the name of the 34 | module (without quotes). The second parameter is the variable which will 35 | be used as the interface to add functions and classes to the module. 36 | 37 | .. code-block:: cpp 38 | 39 | PYBIND11_EMBEDDED_MODULE(example, m) { 40 | // ... 
initialize functions and classes here 41 | m.def("foo", []() { 42 | return "Hello, World!"; 43 | }); 44 | } 45 | \endrst */ 46 | #define PYBIND11_EMBEDDED_MODULE(name, variable) \ 47 | static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \ 48 | static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \ 49 | auto m = pybind11::module(PYBIND11_TOSTRING(name)); \ 50 | try { \ 51 | PYBIND11_CONCAT(pybind11_init_, name)(m); \ 52 | return m.ptr(); \ 53 | } catch (pybind11::error_already_set &e) { \ 54 | PyErr_SetString(PyExc_ImportError, e.what()); \ 55 | return nullptr; \ 56 | } catch (const std::exception &e) { \ 57 | PyErr_SetString(PyExc_ImportError, e.what()); \ 58 | return nullptr; \ 59 | } \ 60 | } \ 61 | PYBIND11_EMBEDDED_MODULE_IMPL(name) \ 62 | pybind11::detail::embedded_module name(PYBIND11_TOSTRING(name), \ 63 | PYBIND11_CONCAT(pybind11_init_impl_, name)); \ 64 | void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable) 65 | 66 | 67 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 68 | NAMESPACE_BEGIN(detail) 69 | 70 | /// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks. 71 | struct embedded_module { 72 | #if PY_MAJOR_VERSION >= 3 73 | using init_t = PyObject *(*)(); 74 | #else 75 | using init_t = void (*)(); 76 | #endif 77 | embedded_module(const char *name, init_t init) { 78 | if (Py_IsInitialized()) 79 | pybind11_fail("Can't add new modules after the interpreter has been initialized"); 80 | 81 | auto result = PyImport_AppendInittab(name, init); 82 | if (result == -1) 83 | pybind11_fail("Insufficient memory to add a new module"); 84 | } 85 | }; 86 | 87 | NAMESPACE_END(detail) 88 | 89 | /** \rst 90 | Initialize the Python interpreter. No other pybind11 or CPython API functions can be 91 | called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. The 92 | optional parameter can be used to skip the registration of signal handlers (see the 93 | `Python documentation`_ for details). Calling this function again after the interpreter 94 | has already been initialized is a fatal error. 95 | 96 | If initializing the Python interpreter fails, then the program is terminated. (This 97 | is controlled by the CPython runtime and is an exception to pybind11's normal behavior 98 | of throwing exceptions on errors.) 99 | 100 | .. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx 101 | \endrst */ 102 | inline void initialize_interpreter(bool init_signal_handlers = true) { 103 | if (Py_IsInitialized()) 104 | pybind11_fail("The interpreter is already running"); 105 | 106 | Py_InitializeEx(init_signal_handlers ? 1 : 0); 107 | 108 | // Make .py files in the working directory available by default 109 | module::import("sys").attr("path").cast().append("."); 110 | } 111 | 112 | /** \rst 113 | Shut down the Python interpreter. No pybind11 or CPython API functions can be called 114 | after this. In addition, pybind11 objects must not outlive the interpreter: 115 | 116 | .. 
code-block:: cpp 117 | 118 | { // BAD 119 | py::initialize_interpreter(); 120 | auto hello = py::str("Hello, World!"); 121 | py::finalize_interpreter(); 122 | } // <-- BOOM, hello's destructor is called after interpreter shutdown 123 | 124 | { // GOOD 125 | py::initialize_interpreter(); 126 | { // scoped 127 | auto hello = py::str("Hello, World!"); 128 | } // <-- OK, hello is cleaned up properly 129 | py::finalize_interpreter(); 130 | } 131 | 132 | { // BETTER 133 | py::scoped_interpreter guard{}; 134 | auto hello = py::str("Hello, World!"); 135 | } 136 | 137 | .. warning:: 138 | 139 | The interpreter can be restarted by calling `initialize_interpreter` again. 140 | Modules created using pybind11 can be safely re-initialized. However, Python 141 | itself cannot completely unload binary extension modules and there are several 142 | caveats with regard to interpreter restarting. All the details can be found 143 | in the CPython documentation. In short, not all interpreter memory may be 144 | freed, either due to reference cycles or user-created global data. 145 | 146 | \endrst */ 147 | inline void finalize_interpreter() { 148 | handle builtins(PyEval_GetBuiltins()); 149 | const char *id = PYBIND11_INTERNALS_ID; 150 | 151 | // Get the internals pointer (without creating it if it doesn't exist). It's possible for the 152 | // internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()` 153 | // during destruction), so we get the pointer-pointer here and check it after Py_Finalize(). 154 | detail::internals **internals_ptr_ptr = detail::get_internals_pp(); 155 | // It could also be stashed in builtins, so look there too: 156 | if (builtins.contains(id) && isinstance(builtins[id])) 157 | internals_ptr_ptr = capsule(builtins[id]); 158 | 159 | Py_Finalize(); 160 | 161 | if (internals_ptr_ptr) { 162 | delete *internals_ptr_ptr; 163 | *internals_ptr_ptr = nullptr; 164 | } 165 | } 166 | 167 | /** \rst 168 | Scope guard version of `initialize_interpreter` and `finalize_interpreter`. 169 | This a move-only guard and only a single instance can exist. 170 | 171 | .. code-block:: cpp 172 | 173 | #include 174 | 175 | int main() { 176 | py::scoped_interpreter guard{}; 177 | py::print(Hello, World!); 178 | } // <-- interpreter shutdown 179 | \endrst */ 180 | class scoped_interpreter { 181 | public: 182 | scoped_interpreter(bool init_signal_handlers = true) { 183 | initialize_interpreter(init_signal_handlers); 184 | } 185 | 186 | scoped_interpreter(const scoped_interpreter &) = delete; 187 | scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; } 188 | scoped_interpreter &operator=(const scoped_interpreter &) = delete; 189 | scoped_interpreter &operator=(scoped_interpreter &&) = delete; 190 | 191 | ~scoped_interpreter() { 192 | if (is_valid) 193 | finalize_interpreter(); 194 | } 195 | 196 | private: 197 | bool is_valid = true; 198 | }; 199 | 200 | NAMESPACE_END(PYBIND11_NAMESPACE) 201 | -------------------------------------------------------------------------------- /pse/include/pybind11/eval.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/exec.h: Support for evaluating Python expressions and statements 3 | from strings and files 4 | 5 | Copyright (c) 2016 Klemens Morgenstern and 6 | Wenzel Jakob 7 | 8 | All rights reserved. Use of this source code is governed by a 9 | BSD-style license that can be found in the LICENSE file. 
10 | */ 11 | 12 | #pragma once 13 | 14 | #include "pybind11.h" 15 | 16 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 17 | 18 | enum eval_mode { 19 | /// Evaluate a string containing an isolated expression 20 | eval_expr, 21 | 22 | /// Evaluate a string containing a single statement. Returns \c none 23 | eval_single_statement, 24 | 25 | /// Evaluate a string containing a sequence of statement. Returns \c none 26 | eval_statements 27 | }; 28 | 29 | template 30 | object eval(str expr, object global = globals(), object local = object()) { 31 | if (!local) 32 | local = global; 33 | 34 | /* PyRun_String does not accept a PyObject / encoding specifier, 35 | this seems to be the only alternative */ 36 | std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr; 37 | 38 | int start; 39 | switch (mode) { 40 | case eval_expr: start = Py_eval_input; break; 41 | case eval_single_statement: start = Py_single_input; break; 42 | case eval_statements: start = Py_file_input; break; 43 | default: pybind11_fail("invalid evaluation mode"); 44 | } 45 | 46 | PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr()); 47 | if (!result) 48 | throw error_already_set(); 49 | return reinterpret_steal(result); 50 | } 51 | 52 | template 53 | object eval(const char (&s)[N], object global = globals(), object local = object()) { 54 | /* Support raw string literals by removing common leading whitespace */ 55 | auto expr = (s[0] == '\n') ? str(module::import("textwrap").attr("dedent")(s)) 56 | : str(s); 57 | return eval(expr, global, local); 58 | } 59 | 60 | inline void exec(str expr, object global = globals(), object local = object()) { 61 | eval(expr, global, local); 62 | } 63 | 64 | template 65 | void exec(const char (&s)[N], object global = globals(), object local = object()) { 66 | eval(s, global, local); 67 | } 68 | 69 | template 70 | object eval_file(str fname, object global = globals(), object local = object()) { 71 | if (!local) 72 | local = global; 73 | 74 | int start; 75 | switch (mode) { 76 | case eval_expr: start = Py_eval_input; break; 77 | case eval_single_statement: start = Py_single_input; break; 78 | case eval_statements: start = Py_file_input; break; 79 | default: pybind11_fail("invalid evaluation mode"); 80 | } 81 | 82 | int closeFile = 1; 83 | std::string fname_str = (std::string) fname; 84 | #if PY_VERSION_HEX >= 0x03040000 85 | FILE *f = _Py_fopen_obj(fname.ptr(), "r"); 86 | #elif PY_VERSION_HEX >= 0x03000000 87 | FILE *f = _Py_fopen(fname.ptr(), "r"); 88 | #else 89 | /* No unicode support in open() :( */ 90 | auto fobj = reinterpret_steal(PyFile_FromString( 91 | const_cast(fname_str.c_str()), 92 | const_cast("r"))); 93 | FILE *f = nullptr; 94 | if (fobj) 95 | f = PyFile_AsFile(fobj.ptr()); 96 | closeFile = 0; 97 | #endif 98 | if (!f) { 99 | PyErr_Clear(); 100 | pybind11_fail("File \"" + fname_str + "\" could not be opened!"); 101 | } 102 | 103 | #if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION) 104 | PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(), 105 | local.ptr()); 106 | (void) closeFile; 107 | #else 108 | PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(), 109 | local.ptr(), closeFile); 110 | #endif 111 | 112 | if (!result) 113 | throw error_already_set(); 114 | return reinterpret_steal(result); 115 | } 116 | 117 | NAMESPACE_END(PYBIND11_NAMESPACE) 118 | -------------------------------------------------------------------------------- /pse/include/pybind11/functional.h: 
-------------------------------------------------------------------------------- 1 | /* 2 | pybind11/functional.h: std::function<> support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 16 | NAMESPACE_BEGIN(detail) 17 | 18 | template 19 | struct type_caster> { 20 | using type = std::function; 21 | using retval_type = conditional_t::value, void_type, Return>; 22 | using function_type = Return (*) (Args...); 23 | 24 | public: 25 | bool load(handle src, bool convert) { 26 | if (src.is_none()) { 27 | // Defer accepting None to other overloads (if we aren't in convert mode): 28 | if (!convert) return false; 29 | return true; 30 | } 31 | 32 | if (!isinstance(src)) 33 | return false; 34 | 35 | auto func = reinterpret_borrow(src); 36 | 37 | /* 38 | When passing a C++ function as an argument to another C++ 39 | function via Python, every function call would normally involve 40 | a full C++ -> Python -> C++ roundtrip, which can be prohibitive. 41 | Here, we try to at least detect the case where the function is 42 | stateless (i.e. function pointer or lambda function without 43 | captured variables), in which case the roundtrip can be avoided. 44 | */ 45 | if (auto cfunc = func.cpp_function()) { 46 | auto c = reinterpret_borrow(PyCFunction_GET_SELF(cfunc.ptr())); 47 | auto rec = (function_record *) c; 48 | 49 | if (rec && rec->is_stateless && 50 | same_type(typeid(function_type), *reinterpret_cast(rec->data[1]))) { 51 | struct capture { function_type f; }; 52 | value = ((capture *) &rec->data)->f; 53 | return true; 54 | } 55 | } 56 | 57 | value = [func](Args... args) -> Return { 58 | gil_scoped_acquire acq; 59 | object retval(func(std::forward(args)...)); 60 | /* Visual studio 2015 parser issue: need parentheses around this expression */ 61 | return (retval.template cast()); 62 | }; 63 | return true; 64 | } 65 | 66 | template 67 | static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) { 68 | if (!f_) 69 | return none().inc_ref(); 70 | 71 | auto result = f_.template target(); 72 | if (result) 73 | return cpp_function(*result, policy).release(); 74 | else 75 | return cpp_function(std::forward(f_), policy).release(); 76 | } 77 | 78 | PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster::name...) + _("], ") 79 | + make_caster::name + _("]")); 80 | }; 81 | 82 | NAMESPACE_END(detail) 83 | NAMESPACE_END(PYBIND11_NAMESPACE) 84 | -------------------------------------------------------------------------------- /pse/include/pybind11/iostream.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/iostream.h -- Tools to assist with redirecting cout and cerr to Python 3 | 4 | Copyright (c) 2017 Henry F. Schreiner 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 21 | NAMESPACE_BEGIN(detail) 22 | 23 | // Buffer that writes to Python instead of C++ 24 | class pythonbuf : public std::streambuf { 25 | private: 26 | using traits_type = std::streambuf::traits_type; 27 | 28 | char d_buffer[1024]; 29 | object pywrite; 30 | object pyflush; 31 | 32 | int overflow(int c) { 33 | if (!traits_type::eq_int_type(c, traits_type::eof())) { 34 | *pptr() = traits_type::to_char_type(c); 35 | pbump(1); 36 | } 37 | return sync() == 0 ? traits_type::not_eof(c) : traits_type::eof(); 38 | } 39 | 40 | int sync() { 41 | if (pbase() != pptr()) { 42 | // This subtraction cannot be negative, so dropping the sign 43 | str line(pbase(), static_cast(pptr() - pbase())); 44 | 45 | pywrite(line); 46 | pyflush(); 47 | 48 | setp(pbase(), epptr()); 49 | } 50 | return 0; 51 | } 52 | 53 | public: 54 | pythonbuf(object pyostream) 55 | : pywrite(pyostream.attr("write")), 56 | pyflush(pyostream.attr("flush")) { 57 | setp(d_buffer, d_buffer + sizeof(d_buffer) - 1); 58 | } 59 | 60 | /// Sync before destroy 61 | ~pythonbuf() { 62 | sync(); 63 | } 64 | }; 65 | 66 | NAMESPACE_END(detail) 67 | 68 | 69 | /** \rst 70 | This a move-only guard that redirects output. 71 | 72 | .. code-block:: cpp 73 | 74 | #include 75 | 76 | ... 77 | 78 | { 79 | py::scoped_ostream_redirect output; 80 | std::cout << "Hello, World!"; // Python stdout 81 | } // <-- return std::cout to normal 82 | 83 | You can explicitly pass the c++ stream and the python object, 84 | for example to guard stderr instead. 85 | 86 | .. code-block:: cpp 87 | 88 | { 89 | py::scoped_ostream_redirect output{std::cerr, py::module::import("sys").attr("stderr")}; 90 | std::cerr << "Hello, World!"; 91 | } 92 | \endrst */ 93 | class scoped_ostream_redirect { 94 | protected: 95 | std::streambuf *old; 96 | std::ostream &costream; 97 | detail::pythonbuf buffer; 98 | 99 | public: 100 | scoped_ostream_redirect( 101 | std::ostream &costream = std::cout, 102 | object pyostream = module::import("sys").attr("stdout")) 103 | : costream(costream), buffer(pyostream) { 104 | old = costream.rdbuf(&buffer); 105 | } 106 | 107 | ~scoped_ostream_redirect() { 108 | costream.rdbuf(old); 109 | } 110 | 111 | scoped_ostream_redirect(const scoped_ostream_redirect &) = delete; 112 | scoped_ostream_redirect(scoped_ostream_redirect &&other) = default; 113 | scoped_ostream_redirect &operator=(const scoped_ostream_redirect &) = delete; 114 | scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete; 115 | }; 116 | 117 | 118 | /** \rst 119 | Like `scoped_ostream_redirect`, but redirects cerr by default. This class 120 | is provided primary to make ``py::call_guard`` easier to make. 121 | 122 | .. code-block:: cpp 123 | 124 | m.def("noisy_func", &noisy_func, 125 | py::call_guard()); 127 | 128 | \endrst */ 129 | class scoped_estream_redirect : public scoped_ostream_redirect { 130 | public: 131 | scoped_estream_redirect( 132 | std::ostream &costream = std::cerr, 133 | object pyostream = module::import("sys").attr("stderr")) 134 | : scoped_ostream_redirect(costream,pyostream) {} 135 | }; 136 | 137 | 138 | NAMESPACE_BEGIN(detail) 139 | 140 | // Class to redirect output as a context manager. C++ backend. 
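// (Added note, not part of the original header: `enter()` / `exit()` below are the methods that
// `add_ostream_redirect()` later in this file binds to Python's `__enter__` / `__exit__`, so this
// class is the C++ backend behind the `with m.ostream_redirect(stdout=True, stderr=True): ...`
// context manager shown in that function's docstring.)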
141 | class OstreamRedirect { 142 | bool do_stdout_; 143 | bool do_stderr_; 144 | std::unique_ptr redirect_stdout; 145 | std::unique_ptr redirect_stderr; 146 | 147 | public: 148 | OstreamRedirect(bool do_stdout = true, bool do_stderr = true) 149 | : do_stdout_(do_stdout), do_stderr_(do_stderr) {} 150 | 151 | void enter() { 152 | if (do_stdout_) 153 | redirect_stdout.reset(new scoped_ostream_redirect()); 154 | if (do_stderr_) 155 | redirect_stderr.reset(new scoped_estream_redirect()); 156 | } 157 | 158 | void exit() { 159 | redirect_stdout.reset(); 160 | redirect_stderr.reset(); 161 | } 162 | }; 163 | 164 | NAMESPACE_END(detail) 165 | 166 | /** \rst 167 | This is a helper function to add a C++ redirect context manager to Python 168 | instead of using a C++ guard. To use it, add the following to your binding code: 169 | 170 | .. code-block:: cpp 171 | 172 | #include 173 | 174 | ... 175 | 176 | py::add_ostream_redirect(m, "ostream_redirect"); 177 | 178 | You now have a Python context manager that redirects your output: 179 | 180 | .. code-block:: python 181 | 182 | with m.ostream_redirect(): 183 | m.print_to_cout_function() 184 | 185 | This manager can optionally be told which streams to operate on: 186 | 187 | .. code-block:: python 188 | 189 | with m.ostream_redirect(stdout=true, stderr=true): 190 | m.noisy_function_with_error_printing() 191 | 192 | \endrst */ 193 | inline class_ add_ostream_redirect(module m, std::string name = "ostream_redirect") { 194 | return class_(m, name.c_str(), module_local()) 195 | .def(init(), arg("stdout")=true, arg("stderr")=true) 196 | .def("__enter__", &detail::OstreamRedirect::enter) 197 | .def("__exit__", [](detail::OstreamRedirect &self_, args) { self_.exit(); }); 198 | } 199 | 200 | NAMESPACE_END(PYBIND11_NAMESPACE) 201 | -------------------------------------------------------------------------------- /pse/include/pybind11/operators.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/operator.h: Metatemplates for operator overloading 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | 14 | #if defined(__clang__) && !defined(__INTEL_COMPILER) 15 | # pragma clang diagnostic ignored "-Wunsequenced" // multiple unsequenced modifications to 'self' (when using def(py::self OP Type())) 16 | #elif defined(_MSC_VER) 17 | # pragma warning(push) 18 | # pragma warning(disable: 4127) // warning C4127: Conditional expression is constant 19 | #endif 20 | 21 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 22 | NAMESPACE_BEGIN(detail) 23 | 24 | /// Enumeration with all supported operator types 25 | enum op_id : int { 26 | op_add, op_sub, op_mul, op_div, op_mod, op_divmod, op_pow, op_lshift, 27 | op_rshift, op_and, op_xor, op_or, op_neg, op_pos, op_abs, op_invert, 28 | op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le, 29 | op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift, 30 | op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero, 31 | op_repr, op_truediv, op_itruediv, op_hash 32 | }; 33 | 34 | enum op_type : int { 35 | op_l, /* base type on left */ 36 | op_r, /* base type on right */ 37 | op_u /* unary operator */ 38 | }; 39 | 40 | struct self_t { }; 41 | static const self_t self = self_t(); 42 | 43 | /// Type for an unused type slot 44 | struct undefined_t { }; 45 | 46 | /// Don't warn about an unused variable 47 | inline self_t __self() { return self; } 48 | 49 | /// base template of operator implementations 50 | template struct op_impl { }; 51 | 52 | /// Operator implementation generator 53 | template struct op_ { 54 | template void execute(Class &cl, const Extra&... extra) const { 55 | using Base = typename Class::type; 56 | using L_type = conditional_t::value, Base, L>; 57 | using R_type = conditional_t::value, Base, R>; 58 | using op = op_impl; 59 | cl.def(op::name(), &op::execute, is_operator(), extra...); 60 | #if PY_MAJOR_VERSION < 3 61 | if (id == op_truediv || id == op_itruediv) 62 | cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__", 63 | &op::execute, is_operator(), extra...); 64 | #endif 65 | } 66 | template void execute_cast(Class &cl, const Extra&... extra) const { 67 | using Base = typename Class::type; 68 | using L_type = conditional_t::value, Base, L>; 69 | using R_type = conditional_t::value, Base, R>; 70 | using op = op_impl; 71 | cl.def(op::name(), &op::execute_cast, is_operator(), extra...); 72 | #if PY_MAJOR_VERSION < 3 73 | if (id == op_truediv || id == op_itruediv) 74 | cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? 
"__div__" : "__rdiv__", 75 | &op::execute, is_operator(), extra...); 76 | #endif 77 | } 78 | }; 79 | 80 | #define PYBIND11_BINARY_OPERATOR(id, rid, op, expr) \ 81 | template struct op_impl { \ 82 | static char const* name() { return "__" #id "__"; } \ 83 | static auto execute(const L &l, const R &r) -> decltype(expr) { return (expr); } \ 84 | static B execute_cast(const L &l, const R &r) { return B(expr); } \ 85 | }; \ 86 | template struct op_impl { \ 87 | static char const* name() { return "__" #rid "__"; } \ 88 | static auto execute(const R &r, const L &l) -> decltype(expr) { return (expr); } \ 89 | static B execute_cast(const R &r, const L &l) { return B(expr); } \ 90 | }; \ 91 | inline op_ op(const self_t &, const self_t &) { \ 92 | return op_(); \ 93 | } \ 94 | template op_ op(const self_t &, const T &) { \ 95 | return op_(); \ 96 | } \ 97 | template op_ op(const T &, const self_t &) { \ 98 | return op_(); \ 99 | } 100 | 101 | #define PYBIND11_INPLACE_OPERATOR(id, op, expr) \ 102 | template struct op_impl { \ 103 | static char const* name() { return "__" #id "__"; } \ 104 | static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \ 105 | static B execute_cast(L &l, const R &r) { return B(expr); } \ 106 | }; \ 107 | template op_ op(const self_t &, const T &) { \ 108 | return op_(); \ 109 | } 110 | 111 | #define PYBIND11_UNARY_OPERATOR(id, op, expr) \ 112 | template struct op_impl { \ 113 | static char const* name() { return "__" #id "__"; } \ 114 | static auto execute(const L &l) -> decltype(expr) { return expr; } \ 115 | static B execute_cast(const L &l) { return B(expr); } \ 116 | }; \ 117 | inline op_ op(const self_t &) { \ 118 | return op_(); \ 119 | } 120 | 121 | PYBIND11_BINARY_OPERATOR(sub, rsub, operator-, l - r) 122 | PYBIND11_BINARY_OPERATOR(add, radd, operator+, l + r) 123 | PYBIND11_BINARY_OPERATOR(mul, rmul, operator*, l * r) 124 | PYBIND11_BINARY_OPERATOR(truediv, rtruediv, operator/, l / r) 125 | PYBIND11_BINARY_OPERATOR(mod, rmod, operator%, l % r) 126 | PYBIND11_BINARY_OPERATOR(lshift, rlshift, operator<<, l << r) 127 | PYBIND11_BINARY_OPERATOR(rshift, rrshift, operator>>, l >> r) 128 | PYBIND11_BINARY_OPERATOR(and, rand, operator&, l & r) 129 | PYBIND11_BINARY_OPERATOR(xor, rxor, operator^, l ^ r) 130 | PYBIND11_BINARY_OPERATOR(eq, eq, operator==, l == r) 131 | PYBIND11_BINARY_OPERATOR(ne, ne, operator!=, l != r) 132 | PYBIND11_BINARY_OPERATOR(or, ror, operator|, l | r) 133 | PYBIND11_BINARY_OPERATOR(gt, lt, operator>, l > r) 134 | PYBIND11_BINARY_OPERATOR(ge, le, operator>=, l >= r) 135 | PYBIND11_BINARY_OPERATOR(lt, gt, operator<, l < r) 136 | PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r) 137 | //PYBIND11_BINARY_OPERATOR(pow, rpow, pow, std::pow(l, r)) 138 | PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r) 139 | PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r) 140 | PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r) 141 | PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r) 142 | PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r) 143 | PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r) 144 | PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r) 145 | PYBIND11_INPLACE_OPERATOR(iand, operator&=, l &= r) 146 | PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r) 147 | PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r) 148 | PYBIND11_UNARY_OPERATOR(neg, operator-, -l) 149 | PYBIND11_UNARY_OPERATOR(pos, operator+, +l) 150 | PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l)) 151 | PYBIND11_UNARY_OPERATOR(hash, hash, std::hash()(l)) 
152 | PYBIND11_UNARY_OPERATOR(invert, operator~, (~l)) 153 | PYBIND11_UNARY_OPERATOR(bool, operator!, !!l) 154 | PYBIND11_UNARY_OPERATOR(int, int_, (int) l) 155 | PYBIND11_UNARY_OPERATOR(float, float_, (double) l) 156 | 157 | #undef PYBIND11_BINARY_OPERATOR 158 | #undef PYBIND11_INPLACE_OPERATOR 159 | #undef PYBIND11_UNARY_OPERATOR 160 | NAMESPACE_END(detail) 161 | 162 | using detail::self; 163 | 164 | NAMESPACE_END(PYBIND11_NAMESPACE) 165 | 166 | #if defined(_MSC_VER) 167 | # pragma warning(pop) 168 | #endif 169 | -------------------------------------------------------------------------------- /pse/include/pybind11/options.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/options.h: global settings that are configurable at runtime. 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "detail/common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | 16 | class options { 17 | public: 18 | 19 | // Default RAII constructor, which leaves settings as they currently are. 20 | options() : previous_state(global_state()) {} 21 | 22 | // Class is non-copyable. 23 | options(const options&) = delete; 24 | options& operator=(const options&) = delete; 25 | 26 | // Destructor, which restores settings that were in effect before. 27 | ~options() { 28 | global_state() = previous_state; 29 | } 30 | 31 | // Setter methods (affect the global state): 32 | 33 | options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; } 34 | 35 | options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; } 36 | 37 | options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; } 38 | 39 | options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; } 40 | 41 | // Getter methods (return the global state): 42 | 43 | static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; } 44 | 45 | static bool show_function_signatures() { return global_state().show_function_signatures; } 46 | 47 | // This type is not meant to be allocated on the heap. 48 | void* operator new(size_t) = delete; 49 | 50 | private: 51 | 52 | struct state { 53 | bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings. 54 | bool show_function_signatures = true; //< Include auto-generated function signatures in docstrings. 55 | }; 56 | 57 | static state &global_state() { 58 | static state instance; 59 | return instance; 60 | } 61 | 62 | state previous_state; 63 | }; 64 | 65 | NAMESPACE_END(PYBIND11_NAMESPACE) 66 | -------------------------------------------------------------------------------- /pse/include/pybind11/stl.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/stl.h: Transparent conversion for STL data types 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #if defined(_MSC_VER) 23 | #pragma warning(push) 24 | #pragma warning(disable: 4127) // warning C4127: Conditional expression is constant 25 | #endif 26 | 27 | #ifdef __has_include 28 | // std::optional (but including it in c++14 mode isn't allowed) 29 | # if defined(PYBIND11_CPP17) && __has_include() 30 | # include 31 | # define PYBIND11_HAS_OPTIONAL 1 32 | # endif 33 | // std::experimental::optional (but not allowed in c++11 mode) 34 | # if defined(PYBIND11_CPP14) && (__has_include() && \ 35 | !__has_include()) 36 | # include 37 | # define PYBIND11_HAS_EXP_OPTIONAL 1 38 | # endif 39 | // std::variant 40 | # if defined(PYBIND11_CPP17) && __has_include() 41 | # include 42 | # define PYBIND11_HAS_VARIANT 1 43 | # endif 44 | #elif defined(_MSC_VER) && defined(PYBIND11_CPP17) 45 | # include 46 | # include 47 | # define PYBIND11_HAS_OPTIONAL 1 48 | # define PYBIND11_HAS_VARIANT 1 49 | #endif 50 | 51 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 52 | NAMESPACE_BEGIN(detail) 53 | 54 | /// Extracts an const lvalue reference or rvalue reference for U based on the type of T (e.g. for 55 | /// forwarding a container element). Typically used indirect via forwarded_type(), below. 56 | template 57 | using forwarded_type = conditional_t< 58 | std::is_lvalue_reference::value, remove_reference_t &, remove_reference_t &&>; 59 | 60 | /// Forwards a value U as rvalue or lvalue according to whether T is rvalue or lvalue; typically 61 | /// used for forwarding a container's elements. 62 | template 63 | forwarded_type forward_like(U &&u) { 64 | return std::forward>(std::forward(u)); 65 | } 66 | 67 | template struct set_caster { 68 | using type = Type; 69 | using key_conv = make_caster; 70 | 71 | bool load(handle src, bool convert) { 72 | if (!isinstance(src)) 73 | return false; 74 | auto s = reinterpret_borrow(src); 75 | value.clear(); 76 | for (auto entry : s) { 77 | key_conv conv; 78 | if (!conv.load(entry, convert)) 79 | return false; 80 | value.insert(cast_op(std::move(conv))); 81 | } 82 | return true; 83 | } 84 | 85 | template 86 | static handle cast(T &&src, return_value_policy policy, handle parent) { 87 | if (!std::is_lvalue_reference::value) 88 | policy = return_value_policy_override::policy(policy); 89 | pybind11::set s; 90 | for (auto &&value : src) { 91 | auto value_ = reinterpret_steal(key_conv::cast(forward_like(value), policy, parent)); 92 | if (!value_ || !s.add(value_)) 93 | return handle(); 94 | } 95 | return s.release(); 96 | } 97 | 98 | PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name + _("]")); 99 | }; 100 | 101 | template struct map_caster { 102 | using key_conv = make_caster; 103 | using value_conv = make_caster; 104 | 105 | bool load(handle src, bool convert) { 106 | if (!isinstance(src)) 107 | return false; 108 | auto d = reinterpret_borrow(src); 109 | value.clear(); 110 | for (auto it : d) { 111 | key_conv kconv; 112 | value_conv vconv; 113 | if (!kconv.load(it.first.ptr(), convert) || 114 | !vconv.load(it.second.ptr(), convert)) 115 | return false; 116 | value.emplace(cast_op(std::move(kconv)), cast_op(std::move(vconv))); 117 | } 118 | return true; 119 | } 120 | 121 | template 122 | static handle cast(T &&src, return_value_policy policy, handle parent) { 123 | dict d; 124 | return_value_policy policy_key = policy; 125 | return_value_policy policy_value = policy; 126 | if (!std::is_lvalue_reference::value) { 
127 | policy_key = return_value_policy_override::policy(policy_key); 128 | policy_value = return_value_policy_override::policy(policy_value); 129 | } 130 | for (auto &&kv : src) { 131 | auto key = reinterpret_steal(key_conv::cast(forward_like(kv.first), policy_key, parent)); 132 | auto value = reinterpret_steal(value_conv::cast(forward_like(kv.second), policy_value, parent)); 133 | if (!key || !value) 134 | return handle(); 135 | d[key] = value; 136 | } 137 | return d.release(); 138 | } 139 | 140 | PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name + _(", ") + value_conv::name + _("]")); 141 | }; 142 | 143 | template struct list_caster { 144 | using value_conv = make_caster; 145 | 146 | bool load(handle src, bool convert) { 147 | if (!isinstance(src) || isinstance(src)) 148 | return false; 149 | auto s = reinterpret_borrow(src); 150 | value.clear(); 151 | reserve_maybe(s, &value); 152 | for (auto it : s) { 153 | value_conv conv; 154 | if (!conv.load(it, convert)) 155 | return false; 156 | value.push_back(cast_op(std::move(conv))); 157 | } 158 | return true; 159 | } 160 | 161 | private: 162 | template ().reserve(0)), void>::value, int> = 0> 164 | void reserve_maybe(sequence s, Type *) { value.reserve(s.size()); } 165 | void reserve_maybe(sequence, void *) { } 166 | 167 | public: 168 | template 169 | static handle cast(T &&src, return_value_policy policy, handle parent) { 170 | if (!std::is_lvalue_reference::value) 171 | policy = return_value_policy_override::policy(policy); 172 | list l(src.size()); 173 | size_t index = 0; 174 | for (auto &&value : src) { 175 | auto value_ = reinterpret_steal(value_conv::cast(forward_like(value), policy, parent)); 176 | if (!value_) 177 | return handle(); 178 | PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference 179 | } 180 | return l.release(); 181 | } 182 | 183 | PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name + _("]")); 184 | }; 185 | 186 | template struct type_caster> 187 | : list_caster, Type> { }; 188 | 189 | template struct type_caster> 190 | : list_caster, Type> { }; 191 | 192 | template struct type_caster> 193 | : list_caster, Type> { }; 194 | 195 | template struct array_caster { 196 | using value_conv = make_caster; 197 | 198 | private: 199 | template 200 | bool require_size(enable_if_t size) { 201 | if (value.size() != size) 202 | value.resize(size); 203 | return true; 204 | } 205 | template 206 | bool require_size(enable_if_t size) { 207 | return size == Size; 208 | } 209 | 210 | public: 211 | bool load(handle src, bool convert) { 212 | if (!isinstance(src)) 213 | return false; 214 | auto l = reinterpret_borrow(src); 215 | if (!require_size(l.size())) 216 | return false; 217 | size_t ctr = 0; 218 | for (auto it : l) { 219 | value_conv conv; 220 | if (!conv.load(it, convert)) 221 | return false; 222 | value[ctr++] = cast_op(std::move(conv)); 223 | } 224 | return true; 225 | } 226 | 227 | template 228 | static handle cast(T &&src, return_value_policy policy, handle parent) { 229 | list l(src.size()); 230 | size_t index = 0; 231 | for (auto &&value : src) { 232 | auto value_ = reinterpret_steal(value_conv::cast(forward_like(value), policy, parent)); 233 | if (!value_) 234 | return handle(); 235 | PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference 236 | } 237 | return l.release(); 238 | } 239 | 240 | PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name + _(_(""), _("[") + _() + _("]")) + _("]")); 241 | }; 242 | 243 | template struct type_caster> 244 
| : array_caster, Type, false, Size> { }; 245 | 246 | template struct type_caster> 247 | : array_caster, Type, true> { }; 248 | 249 | template struct type_caster> 250 | : set_caster, Key> { }; 251 | 252 | template struct type_caster> 253 | : set_caster, Key> { }; 254 | 255 | template struct type_caster> 256 | : map_caster, Key, Value> { }; 257 | 258 | template struct type_caster> 259 | : map_caster, Key, Value> { }; 260 | 261 | // This type caster is intended to be used for std::optional and std::experimental::optional 262 | template struct optional_caster { 263 | using value_conv = make_caster; 264 | 265 | template 266 | static handle cast(T_ &&src, return_value_policy policy, handle parent) { 267 | if (!src) 268 | return none().inc_ref(); 269 | policy = return_value_policy_override::policy(policy); 270 | return value_conv::cast(*std::forward(src), policy, parent); 271 | } 272 | 273 | bool load(handle src, bool convert) { 274 | if (!src) { 275 | return false; 276 | } else if (src.is_none()) { 277 | return true; // default-constructed value is already empty 278 | } 279 | value_conv inner_caster; 280 | if (!inner_caster.load(src, convert)) 281 | return false; 282 | 283 | value.emplace(cast_op(std::move(inner_caster))); 284 | return true; 285 | } 286 | 287 | PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name + _("]")); 288 | }; 289 | 290 | #if PYBIND11_HAS_OPTIONAL 291 | template struct type_caster> 292 | : public optional_caster> {}; 293 | 294 | template<> struct type_caster 295 | : public void_caster {}; 296 | #endif 297 | 298 | #if PYBIND11_HAS_EXP_OPTIONAL 299 | template struct type_caster> 300 | : public optional_caster> {}; 301 | 302 | template<> struct type_caster 303 | : public void_caster {}; 304 | #endif 305 | 306 | /// Visit a variant and cast any found type to Python 307 | struct variant_caster_visitor { 308 | return_value_policy policy; 309 | handle parent; 310 | 311 | using result_type = handle; // required by boost::variant in C++11 312 | 313 | template 314 | result_type operator()(T &&src) const { 315 | return make_caster::cast(std::forward(src), policy, parent); 316 | } 317 | }; 318 | 319 | /// Helper class which abstracts away variant's `visit` function. `std::variant` and similar 320 | /// `namespace::variant` types which provide a `namespace::visit()` function are handled here 321 | /// automatically using argument-dependent lookup. Users can provide specializations for other 322 | /// variant-like classes, e.g. `boost::variant` and `boost::apply_visitor`. 323 | template class Variant> 324 | struct visit_helper { 325 | template 326 | static auto call(Args &&...args) -> decltype(visit(std::forward(args)...)) { 327 | return visit(std::forward(args)...); 328 | } 329 | }; 330 | 331 | /// Generic variant caster 332 | template struct variant_caster; 333 | 334 | template class V, typename... Ts> 335 | struct variant_caster> { 336 | static_assert(sizeof...(Ts) > 0, "Variant must consist of at least one alternative."); 337 | 338 | template 339 | bool load_alternative(handle src, bool convert, type_list) { 340 | auto caster = make_caster(); 341 | if (caster.load(src, convert)) { 342 | value = cast_op(caster); 343 | return true; 344 | } 345 | return load_alternative(src, convert, type_list{}); 346 | } 347 | 348 | bool load_alternative(handle, bool, type_list<>) { return false; } 349 | 350 | bool load(handle src, bool convert) { 351 | // Do a first pass without conversions to improve constructor resolution. 352 | // E.g. 
`py::int_(1).cast>()` needs to fill the `int` 353 | // slot of the variant. Without two-pass loading `double` would be filled 354 | // because it appears first and a conversion is possible. 355 | if (convert && load_alternative(src, false, type_list{})) 356 | return true; 357 | return load_alternative(src, convert, type_list{}); 358 | } 359 | 360 | template 361 | static handle cast(Variant &&src, return_value_policy policy, handle parent) { 362 | return visit_helper::call(variant_caster_visitor{policy, parent}, 363 | std::forward(src)); 364 | } 365 | 366 | using Type = V; 367 | PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster::name...) + _("]")); 368 | }; 369 | 370 | #if PYBIND11_HAS_VARIANT 371 | template 372 | struct type_caster> : variant_caster> { }; 373 | #endif 374 | 375 | NAMESPACE_END(detail) 376 | 377 | inline std::ostream &operator<<(std::ostream &os, const handle &obj) { 378 | os << (std::string) str(obj); 379 | return os; 380 | } 381 | 382 | NAMESPACE_END(PYBIND11_NAMESPACE) 383 | 384 | #if defined(_MSC_VER) 385 | #pragma warning(pop) 386 | #endif 387 | -------------------------------------------------------------------------------- /pse/include/pybind11/typeid.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/typeid.h: Compiler-independent access to type identifiers 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include 13 | #include 14 | 15 | #if defined(__GNUG__) 16 | #include 17 | #endif 18 | 19 | NAMESPACE_BEGIN(pybind11) 20 | NAMESPACE_BEGIN(detail) 21 | /// Erase all occurrences of a substring 22 | inline void erase_all(std::string &string, const std::string &search) { 23 | for (size_t pos = 0;;) { 24 | pos = string.find(search, pos); 25 | if (pos == std::string::npos) break; 26 | string.erase(pos, search.length()); 27 | } 28 | } 29 | 30 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) { 31 | #if defined(__GNUG__) 32 | int status = 0; 33 | std::unique_ptr res { 34 | abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free }; 35 | if (status == 0) 36 | name = res.get(); 37 | #else 38 | detail::erase_all(name, "class "); 39 | detail::erase_all(name, "struct "); 40 | detail::erase_all(name, "enum "); 41 | #endif 42 | detail::erase_all(name, "pybind11::"); 43 | } 44 | NAMESPACE_END(detail) 45 | 46 | /// Return a string representation of a C++ type 47 | template static std::string type_id() { 48 | std::string name(typeid(T).name()); 49 | detail::clean_type_id(name); 50 | return name; 51 | } 52 | 53 | NAMESPACE_END(pybind11) 54 | -------------------------------------------------------------------------------- /pse/pse.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // pse 3 | // reference https://github.com/whai362/PSENet/issues/15 4 | // Created by liuheng on 11/3/19. 5 | // Copyright © 2019年 liuheng. All rights reserved. 
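// (Added note on the algorithm implemented below: pse() takes the connected-component labels of the
// smallest kernel (label_map) together with the stack of kernel maps (Sn) and grows each label
// outwards through the progressively larger kernels with a breadth-first search, so that
// neighbouring text instances are expanded to their full shape without merging into each other.)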
6 | // 7 | #include 8 | #include "include/pybind11/pybind11.h" 9 | #include "include/pybind11/numpy.h" 10 | #include "include/pybind11/stl.h" 11 | #include "include/pybind11/stl_bind.h" 12 | 13 | namespace py = pybind11; 14 | 15 | namespace pse{ 16 | //S5->S0, small->big 17 | std::vector> pse( 18 | py::array_t label_map, 19 | py::array_t Sn, 20 | int c = 6) 21 | { 22 | auto pbuf_label_map = label_map.request(); 23 | auto pbuf_Sn = Sn.request(); 24 | if (pbuf_label_map.ndim != 2 || pbuf_label_map.shape[0]==0 || pbuf_label_map.shape[1]==0) 25 | throw std::runtime_error("label map must have a shape of (h>0, w>0)"); 26 | int h = pbuf_label_map.shape[0]; 27 | int w = pbuf_label_map.shape[1]; 28 | if (pbuf_Sn.ndim != 3 || pbuf_Sn.shape[0] != c || pbuf_Sn.shape[1]!=h || pbuf_Sn.shape[2]!=w) 29 | throw std::runtime_error("Sn must have a shape of (c>0, h>0, w>0)"); 30 | 31 | std::vector> res; 32 | for (size_t i = 0; i(w, 0)); 34 | auto ptr_label_map = static_cast(pbuf_label_map.ptr); 35 | auto ptr_Sn = static_cast(pbuf_Sn.ptr); 36 | 37 | std::queue> q, next_q; 38 | 39 | for (size_t i = 0; i0) 46 | { 47 | q.push(std::make_tuple(i, j, label)); 48 | res[i][j] = label; 49 | } 50 | } 51 | } 52 | 53 | int dx[4] = {-1, 1, 0, 0}; 54 | int dy[4] = {0, 0, -1, 1}; 55 | for (int i = c-2; i>=0; i--) 56 | { 57 | //get each kernels 58 | auto p_Sn = ptr_Sn + i*h*w; 59 | while(!q.empty()){ 60 | //get each queue menber in q 61 | auto q_n = q.front(); 62 | q.pop(); 63 | int y = std::get<0>(q_n); 64 | int x = std::get<1>(q_n); 65 | int32_t l = std::get<2>(q_n); 66 | //store the edge pixel after one expansion 67 | bool is_edge = true; 68 | for (int idx=0; idx<4; idx++) 69 | { 70 | int index_y = y + dy[idx]; 71 | int index_x = x + dx[idx]; 72 | if (index_y<0 || index_y>=h || index_x<0 || index_x>=w) 73 | continue; 74 | if (!p_Sn[index_y*w+index_x] || res[index_y][index_x]>0) 75 | continue; 76 | q.push(std::make_tuple(index_y, index_x, l)); 77 | res[index_y][index_x]=l; 78 | is_edge = false; 79 | } 80 | if (is_edge){ 81 | next_q.push(std::make_tuple(y, x, l)); 82 | } 83 | } 84 | std::swap(q, next_q); 85 | } 86 | return res; 87 | } 88 | } 89 | 90 | PYBIND11_MODULE(pse, m){ 91 | m.def("pse_cpp", &pse::pse, " re-implementation pse algorithm(cpp)", py::arg("label_map"), py::arg("Sn"), py::arg("c")=6); 92 | } 93 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # PSENet: Shape Robust Text Detection with Progressive Scale Expansion Network 2 | 3 | ### Introduction 4 | This is a tensorflow re-implementation of [PSENet: Shape Robust Text Detection with Progressive Scale Expansion Network](https://arxiv.org/abs/1806.02559). 5 | 6 | Thanks for the author's ([@whai362](https://github.com/whai362)) awesome work! 7 | 8 | ### Installation 9 | 1. Any version of tensorflow version > 1.0 should be ok. 10 | 2. python 2 or 3 will be ok. 11 | 12 | ### Download 13 | trained on ICDAR 2015 (training set) + ICDAR2017 MLT (training set): 14 | 15 | [baiduyun](https://pan.baidu.com/s/14tQHf9MjuD0lSmwkoZhnCg) extract code: pffd 16 | 17 | [google drive](https://drive.google.com/file/d/1TjJvtwMp8hJXQhn6Yz2lbPdvBGH-ZQ8u/view?usp=sharing) 18 | 19 | This model is not as good as article's, it's just a reference. 20 | You can finetune on it or you can do a lot of optimization based on this code. 
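For example, a fine-tuning run from the downloaded checkpoint might look roughly like this (a sketch only — it assumes the checkpoint archive is extracted to ./resnet_train/; the flags are the ones defined in train.py):

```
python train.py --gpu_list=0 --restore=True --checkpoint_path=./resnet_train/ \
--training_data_path=./data/ocr/icdar2015/
```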
21 | 22 | | Database | Precision (%) | Recall (%) | F-measure (%) | 23 | | - | - | - | - | 24 | | ICDAR 2015(val) | 74.61 | 80.93 | 77.64 | 25 | 26 | 27 | ### Train 28 | If you want to train the model, you should provide the dataset path; in the dataset path, a separate gt text file should be provided for each image, and **make sure that the gt text and image files have the same names**. 29 | 30 | Then run train.py like: 31 | 32 | ``` 33 | python train.py --gpu_list=0 --input_size=512 --batch_size_per_gpu=8 --checkpoint_path=./resnet_v1_50/ \ 34 | --training_data_path=./data/ocr/icdar2015/ 35 | ``` 36 | 37 | If you have more than one gpu, you can pass the gpu ids to gpu_list (like --gpu_list=0,1,2,3). 38 | 39 | **Note:** 40 | 1. Right now, only the icdar2017 data format is supported as input, like (116,1179,206,1179,206,1207,116,1207,"###"), 41 | but you can modify data_provider.py to support polygon format input. 42 | 2. Polygon shrinking is already supported via the pyclipper module. 43 | 3. This re-implementation is just for fun, but I'll continue to improve this code. 44 | 4. The pse algorithm is re-implemented in C++ 45 | ***(if you use python2, just run it; if you use python3, please replace python-config with python3-config in the makefile)*** 46 | 47 | ### Test 48 | Run eval.py like: 49 | ``` 50 | python eval.py --test_data_path=./tmp/images/ --gpu_list=0 --checkpoint_path=./resnet_v1_50/ \ 51 | --output_dir=./tmp/ 52 | ``` 53 | 54 | A text file and a result image will then be written to the output path. 55 | 56 | ### Examples 57 | ![result0](figure/result0.jpg) 58 | ![result1](figure/result1.jpg) 59 | ![result2](figure/result2.jpg) 60 | ![result3](figure/result3.jpg) 61 | ![result4](figure/result4.jpg) 62 | ![result5](figure/result5.jpg) 63 | 64 | ### About issues 65 | If you encounter any issue, check the existing issues first, or open a new one. 66 | 67 | ### Reference 68 | 1. http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz 69 | 2. https://github.com/CharlesShang/FastMaskRCNN 70 | 3. https://github.com/whai362/PSENet/issues/15 71 | 4. https://github.com/argman/EAST 72 | 73 | ### Acknowledgement 74 | [@rkshuai](https://github.com/rkshuai) found a bug about concat features in model.py. 75 | 76 | **If this repository helps you, please star it.
Thanks.** 77 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import tensorflow as tf 4 | from tensorflow.contrib import slim 5 | from utils.utils_tool import logger, cfg 6 | 7 | tf.app.flags.DEFINE_integer('input_size', 512, '') 8 | tf.app.flags.DEFINE_integer('batch_size_per_gpu', 8, '') 9 | tf.app.flags.DEFINE_integer('num_readers', 32, '') 10 | tf.app.flags.DEFINE_float('learning_rate', 0.00001, '') 11 | tf.app.flags.DEFINE_integer('max_steps', 100000, '') 12 | tf.app.flags.DEFINE_float('moving_average_decay', 0.997, '') 13 | tf.app.flags.DEFINE_string('gpu_list', '0', '') 14 | tf.app.flags.DEFINE_string('checkpoint_path', './resnet_train/', '') 15 | tf.app.flags.DEFINE_boolean('restore', False, 'whether to resotre from checkpoint') 16 | tf.app.flags.DEFINE_integer('save_checkpoint_steps', 1000, '') 17 | tf.app.flags.DEFINE_integer('save_summary_steps', 100, '') 18 | tf.app.flags.DEFINE_string('pretrained_model_path', None, '') 19 | 20 | from nets import model 21 | from utils.data_provider import data_provider 22 | 23 | FLAGS = tf.app.flags.FLAGS 24 | 25 | gpus = list(range(len(FLAGS.gpu_list.split(',')))) 26 | 27 | logger.setLevel(cfg.debug) 28 | 29 | def tower_loss(images, seg_maps_gt, training_masks, reuse_variables=None): 30 | # Build inference graph 31 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables): 32 | seg_maps_pred = model.model(images, is_training=True) 33 | 34 | model_loss = model.loss(seg_maps_gt, seg_maps_pred, training_masks) 35 | total_loss = tf.add_n([model_loss] + tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) 36 | 37 | # add summary 38 | if reuse_variables is None: 39 | tf.summary.image('input', images) 40 | tf.summary.image('seg_map_0_gt', seg_maps_gt[:, :, :, 0:1] * 255) 41 | tf.summary.image('seg_map_0_pred', seg_maps_pred[:, :, :, 0:1] * 255) 42 | tf.summary.image('training_masks', training_masks) 43 | tf.summary.scalar('model_loss', model_loss) 44 | tf.summary.scalar('total_loss', total_loss) 45 | 46 | return total_loss, model_loss 47 | 48 | 49 | def average_gradients(tower_grads): 50 | average_grads = [] 51 | for grad_and_vars in zip(*tower_grads): 52 | grads = [] 53 | for g, _ in grad_and_vars: 54 | expanded_g = tf.expand_dims(g, 0) 55 | grads.append(expanded_g) 56 | 57 | grad = tf.concat(grads, 0) 58 | grad = tf.reduce_mean(grad, 0) 59 | 60 | v = grad_and_vars[0][1] 61 | grad_and_var = (grad, v) 62 | average_grads.append(grad_and_var) 63 | 64 | return average_grads 65 | 66 | 67 | def main(argv=None): 68 | import os 69 | os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list 70 | if not tf.gfile.Exists(FLAGS.checkpoint_path): 71 | tf.gfile.MkDir(FLAGS.checkpoint_path) 72 | else: 73 | if not FLAGS.restore: 74 | tf.gfile.DeleteRecursively(FLAGS.checkpoint_path) 75 | tf.gfile.MkDir(FLAGS.checkpoint_path) 76 | 77 | input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images') 78 | input_seg_maps = tf.placeholder(tf.float32, shape=[None, None, None, 6], name='input_score_maps') 79 | input_training_masks = tf.placeholder(tf.float32, shape=[None, None, None, 1], name='input_training_masks') 80 | 81 | global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False) 82 | learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, decay_steps=10000, decay_rate=0.94, staircase=True) 83 | # 
add summary 84 | tf.summary.scalar('learning_rate', learning_rate) 85 | # opt = tf.train.RMSPropOptimizer(learning_rate, decay=0.9, momentum=0.9) 86 | opt = tf.train.AdamOptimizer(learning_rate) 87 | # opt = tf.train.MomentumOptimizer(learning_rate, 0.9) 88 | 89 | 90 | # split 91 | input_images_split = tf.split(input_images, len(gpus)) 92 | input_seg_maps_split = tf.split(input_seg_maps, len(gpus)) 93 | input_training_masks_split = tf.split(input_training_masks, len(gpus)) 94 | 95 | tower_grads = [] 96 | reuse_variables = None 97 | for i, gpu_id in enumerate(gpus): 98 | with tf.device('/gpu:%d' % gpu_id): 99 | with tf.name_scope('model_%d' % gpu_id) as scope: 100 | iis = input_images_split[i] 101 | isegs = input_seg_maps_split[i] 102 | itms = input_training_masks_split[i] 103 | total_loss, model_loss = tower_loss(iis, isegs, itms, reuse_variables) 104 | batch_norm_updates_op = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)) 105 | reuse_variables = True 106 | 107 | grads = opt.compute_gradients(total_loss) 108 | tower_grads.append(grads) 109 | 110 | grads = average_gradients(tower_grads) 111 | apply_gradient_op = opt.apply_gradients(grads, global_step=global_step) 112 | 113 | summary_op = tf.summary.merge_all() 114 | # save moving average 115 | variable_averages = tf.train.ExponentialMovingAverage( 116 | FLAGS.moving_average_decay, global_step) 117 | variables_averages_op = variable_averages.apply(tf.trainable_variables()) 118 | # batch norm updates 119 | with tf.control_dependencies([variables_averages_op, apply_gradient_op, batch_norm_updates_op]): 120 | train_op = tf.no_op(name='train_op') 121 | 122 | saver = tf.train.Saver(tf.global_variables()) 123 | summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_path, tf.get_default_graph()) 124 | 125 | init = tf.global_variables_initializer() 126 | 127 | if FLAGS.pretrained_model_path is not None: 128 | variable_restore_op = slim.assign_from_checkpoint_fn(FLAGS.pretrained_model_path, slim.get_trainable_variables(), 129 | ignore_missing_vars=True) 130 | gpu_options=tf.GPUOptions(allow_growth=True) 131 | #gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.75) 132 | with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)) as sess: 133 | if FLAGS.restore: 134 | logger.info('continue training from previous checkpoint') 135 | ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_path) 136 | logger.debug(ckpt) 137 | saver.restore(sess, ckpt) 138 | else: 139 | sess.run(init) 140 | if FLAGS.pretrained_model_path is not None: 141 | variable_restore_op(sess) 142 | 143 | data_generator = data_provider.get_batch(num_workers=FLAGS.num_readers, 144 | input_size=FLAGS.input_size, 145 | batch_size=FLAGS.batch_size_per_gpu * len(gpus)) 146 | 147 | start = time.time() 148 | for step in range(FLAGS.max_steps): 149 | data = next(data_generator) 150 | ml, tl, _ = sess.run([model_loss, total_loss, train_op], feed_dict={input_images: data[0], 151 | input_seg_maps: data[2], 152 | input_training_masks: data[3]}) 153 | if np.isnan(tl): 154 | logger.error('Loss diverged, stop training') 155 | break 156 | 157 | if step % 10 == 0: 158 | avg_time_per_step = (time.time() - start)/10 159 | avg_examples_per_second = (10 * FLAGS.batch_size_per_gpu * len(gpus))/(time.time() - start) 160 | start = time.time() 161 | logger.info('Step {:06d}, model loss {:.4f}, total loss {:.4f}, {:.2f} seconds/step, {:.2f} examples/second'.format( 162 | step, ml, tl, avg_time_per_step, avg_examples_per_second)) 163 | 164 | if step % 
FLAGS.save_checkpoint_steps == 0: 165 | saver.save(sess, os.path.join(FLAGS.checkpoint_path, 'model.ckpt'), global_step=global_step) 166 | 167 | if step % FLAGS.save_summary_steps == 0: 168 | _, tl, summary_str = sess.run([train_op, total_loss, summary_op], feed_dict={input_images: data[0], 169 | input_seg_maps: data[2], 170 | input_training_masks: data[3]}) 171 | summary_writer.add_summary(summary_str, global_step=step) 172 | 173 | if __name__ == '__main__': 174 | tf.app.run() 175 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichaelHL-ai/tensorflow_PSENet/e2cd908f301b762150aa36893677c1c51c98ff9e/utils/__init__.py -------------------------------------------------------------------------------- /utils/data_provider/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MichaelHL-ai/tensorflow_PSENet/e2cd908f301b762150aa36893677c1c51c98ff9e/utils/data_provider/__init__.py -------------------------------------------------------------------------------- /utils/data_provider/data_provider.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | import os 3 | import glob 4 | import time 5 | import json 6 | import csv 7 | import traceback 8 | import cv2 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from utils.utils_tool import logger 12 | from utils.data_provider.data_util import GeneratorEnqueuer 13 | import tensorflow as tf 14 | import pyclipper 15 | 16 | tf.app.flags.DEFINE_string('training_data_path', None, 17 | 'training dataset to use') 18 | tf.app.flags.DEFINE_integer('max_image_large_side', 1280, 19 | 'max image size of training') 20 | tf.app.flags.DEFINE_integer('max_text_size', 800, 21 | 'if the text in the input image is bigger than this, then we resize' 22 | 'the image according to this') 23 | tf.app.flags.DEFINE_integer('min_text_area_size', 10, 24 | 'if the text area size is smaller than this, we ignore it during training') 25 | tf.app.flags.DEFINE_float('min_crop_side_ratio', 0.1, 26 | 'when doing random crop from input image, the' 27 | 'min length of min(H, W') 28 | 29 | FLAGS = tf.app.flags.FLAGS 30 | 31 | 32 | def get_files(exts): 33 | files = [] 34 | for ext in exts: 35 | files.extend(glob.glob( 36 | os.path.join(FLAGS.training_data_path, '*.{}'.format(ext)))) 37 | return files 38 | 39 | def get_json_label(): 40 | label_file_list = get_files(['json']) 41 | label = {} 42 | for label_file in label_file_list: 43 | with open(label_file, 'r') as f: 44 | json_label = json.load(f) 45 | 46 | for k, v in json_label.items(): 47 | if not label.has_key(k): 48 | label[k] = v 49 | return label 50 | 51 | def load_annoataion(p): 52 | ''' 53 | load annotation from the text file 54 | :param p: 55 | :return: 56 | ''' 57 | text_polys = [] 58 | text_tags = [] 59 | if not os.path.exists(p): 60 | return np.array(text_polys, dtype=np.float32) 61 | with open(p, 'r') as f: 62 | reader = csv.reader(f) 63 | for line in reader: 64 | label = line[-1] 65 | # strip BOM. \ufeff for python3, \xef\xbb\bf for python2 66 | line = [i.strip('\ufeff').strip('\xef\xbb\xbf') for i in line] 67 | 68 | x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, line[:8])) 69 | text_polys.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]]) 70 | #TODO:maybe add '?' 
for icpr2018 (michael) 71 | if label == '*' or label == '###' or label == '?': 72 | text_tags.append(True) 73 | else: 74 | text_tags.append(False) 75 | return np.array(text_polys, dtype=np.float32), np.array(text_tags, dtype=np.bool) 76 | 77 | def check_and_validate_polys(polys, tags, xxx_todo_changeme): 78 | ''' 79 | check so that the text poly is in the same direction, 80 | and also filter some invalid polygons 81 | :param polys: 82 | :param tags: 83 | :return: 84 | ''' 85 | (h, w) = xxx_todo_changeme 86 | if polys.shape[0] == 0: 87 | return [], [] 88 | polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w-1) 89 | polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h-1) 90 | 91 | validated_polys = [] 92 | validated_tags = [] 93 | for poly, tag in zip(polys, tags): 94 | if abs(pyclipper.Area(poly))<1: 95 | continue 96 | #clockwise 97 | if pyclipper.Orientation(poly): 98 | poly = poly[::-1] 99 | 100 | validated_polys.append(poly) 101 | validated_tags.append(tag) 102 | return np.array(validated_polys), np.array(validated_tags) 103 | 104 | def crop_area(im, polys, tags, crop_background=False, max_tries=50): 105 | ''' 106 | make random crop from the input image 107 | :param im: 108 | :param polys: 109 | :param tags: 110 | :param crop_background: 111 | :param max_tries: 112 | :return: 113 | ''' 114 | h, w, _ = im.shape 115 | pad_h = h//10 116 | pad_w = w//10 117 | h_array = np.zeros((h + pad_h*2), dtype=np.int32) 118 | w_array = np.zeros((w + pad_w*2), dtype=np.int32) 119 | for poly in polys: 120 | poly = np.round(poly, decimals=0).astype(np.int32) 121 | minx = np.min(poly[:, 0]) 122 | maxx = np.max(poly[:, 0]) 123 | w_array[minx+pad_w:maxx+pad_w] = 1 124 | miny = np.min(poly[:, 1]) 125 | maxy = np.max(poly[:, 1]) 126 | h_array[miny+pad_h:maxy+pad_h] = 1 127 | # ensure the cropped area not across a text 128 | h_axis = np.where(h_array == 0)[0] 129 | w_axis = np.where(w_array == 0)[0] 130 | if len(h_axis) == 0 or len(w_axis) == 0: 131 | return im, polys, tags 132 | for i in range(max_tries): 133 | xx = np.random.choice(w_axis, size=2) 134 | xmin = np.min(xx) - pad_w 135 | xmax = np.max(xx) - pad_w 136 | xmin = np.clip(xmin, 0, w-1) 137 | xmax = np.clip(xmax, 0, w-1) 138 | yy = np.random.choice(h_axis, size=2) 139 | ymin = np.min(yy) - pad_h 140 | ymax = np.max(yy) - pad_h 141 | ymin = np.clip(ymin, 0, h-1) 142 | ymax = np.clip(ymax, 0, h-1) 143 | if xmax - xmin < FLAGS.min_crop_side_ratio*w or ymax - ymin < FLAGS.min_crop_side_ratio*h: 144 | # area too small 145 | continue 146 | if polys.shape[0] != 0: 147 | poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \ 148 | & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax) 149 | selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0] 150 | else: 151 | selected_polys = [] 152 | if len(selected_polys) == 0: 153 | # no text in this area 154 | if crop_background: 155 | return im[ymin:ymax+1, xmin:xmax+1, :], polys[selected_polys], tags[selected_polys] 156 | else: 157 | continue 158 | im = im[ymin:ymax+1, xmin:xmax+1, :] 159 | polys = polys[selected_polys] 160 | tags = tags[selected_polys] 161 | polys[:, :, 0] -= xmin 162 | polys[:, :, 1] -= ymin 163 | return im, polys, tags 164 | 165 | return im, polys, tags 166 | 167 | def perimeter(poly): 168 | try: 169 | p=0 170 | nums = poly.shape[0] 171 | for i in range(nums): 172 | p += abs(np.linalg.norm(poly[i%nums]-poly[(i+1)%nums])) 173 | # logger.debug('perimeter:{}'.format(p)) 174 | return p 175 | except Exception as e: 176 | traceback.print_exc() 177 | raise e 178 | 179 | def 
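crop_area above relies on an occupancy trick: every image row and column spanned by a text polygon is marked in h_array/w_array (padded by 10% of the image size), and crop boundaries are sampled only from unmarked positions, so a random crop never slices through a box. A minimal standalone sketch of that idea, with hypothetical image size and boxes:

import numpy as np

h, w = 100, 200                                  # hypothetical image size
boxes = [(20, 10, 60, 30), (120, 50, 170, 80)]   # hypothetical (xmin, ymin, xmax, ymax)

w_occupied = np.zeros(w, dtype=np.int32)
h_occupied = np.zeros(h, dtype=np.int32)
for x0, y0, x1, y1 in boxes:
    w_occupied[x0:x1] = 1                        # columns covered by a box
    h_occupied[y0:y1] = 1                        # rows covered by a box

free_cols = np.where(w_occupied == 0)[0]
free_rows = np.where(h_occupied == 0)[0]
x_cut = np.sort(np.random.choice(free_cols, size=2))   # crop x-range endpoints
y_cut = np.sort(np.random.choice(free_rows, size=2))   # crop y-range endpoints
# Both endpoints sit in free space, so no box is cut in half; crop_area
# additionally retries up to max_tries and rejects crops smaller than
# min_crop_side_ratio of the original side.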
shrink_poly(poly, r): 180 | try: 181 | area_poly = abs(pyclipper.Area(poly)) 182 | perimeter_poly = perimeter(poly) 183 | poly_s = [] 184 | pco = pyclipper.PyclipperOffset() 185 | if perimeter_poly: 186 | d=area_poly*(1-r*r)/perimeter_poly 187 | pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) 188 | poly_s = pco.Execute(-d) 189 | return poly_s 190 | except Exception as e: 191 | traceback.print_exc() 192 | raise e 193 | 194 | #TODO:filter small text(when shrincked region shape is 0 no matter what scale ratio is) 195 | def generate_seg(im_size, polys, tags, image_name, scale_ratio): 196 | ''' 197 | :param im_size: input image size 198 | :param polys: input text regions 199 | :param tags: ignore text regions tags 200 | :param image_index: for log 201 | :param scale_ratio:ground truth scale ratio, default[0.5, 0.6, 0.7, 0.8, 0.9, 1.0] 202 | :return: 203 | seg_maps: segmentation results with different scale ratio, save in different channel 204 | training_mask: ignore text regions 205 | ''' 206 | h, w = im_size 207 | #mark different text poly 208 | seg_maps = np.zeros((h,w,6), dtype=np.uint8) 209 | # mask used during traning, to ignore some hard areas 210 | training_mask = np.ones((h, w), dtype=np.uint8) 211 | ignore_poly_mark = [] 212 | for i in range(len(scale_ratio)): 213 | seg_map = np.zeros((h,w), dtype=np.uint8) 214 | for poly_idx, poly_tag in enumerate(zip(polys, tags)): 215 | poly = poly_tag[0] 216 | tag = poly_tag[1] 217 | 218 | # ignore ### 219 | if i == 0 and tag: 220 | cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0) 221 | ignore_poly_mark.append(poly_idx) 222 | 223 | # seg map 224 | shrinked_polys = [] 225 | if poly_idx not in ignore_poly_mark: 226 | shrinked_polys = shrink_poly(poly.copy(), scale_ratio[i]) 227 | 228 | if not len(shrinked_polys) and poly_idx not in ignore_poly_mark: 229 | logger.info("before shrink poly area:{} len(shrinked_poly) is 0,image {}".format( 230 | abs(pyclipper.Area(poly)),image_name)) 231 | # if the poly is too small, then ignore it during training 232 | cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0) 233 | ignore_poly_mark.append(poly_idx) 234 | continue 235 | for shrinked_poly in shrinked_polys: 236 | seg_map = cv2.fillPoly(seg_map, [np.array(shrinked_poly).astype(np.int32)], 1) 237 | 238 | seg_maps[..., i] = seg_map 239 | return seg_maps, training_mask 240 | 241 | 242 | def generator(input_size=512, batch_size=32, 243 | background_ratio=3./8, 244 | random_scale=np.array([0.125, 0.25,0.5, 1, 2.0, 3.0]), 245 | vis=False, 246 | scale_ratio=np.array([0.5, 0.6, 0.7, 0.8, 0.9, 1.0])): 247 | ''' 248 | reference from https://github.com/argman/EAST 249 | :param input_size: 250 | :param batch_size: 251 | :param background_ratio: 252 | :param random_scale: 253 | :param vis: 254 | :param scale_ratio:ground truth scale ratio 255 | :return: 256 | ''' 257 | image_list = np.array(get_files(['jpg', 'png', 'jpeg', 'JPG'])) 258 | 259 | logger.info('{} training images in {}'.format( 260 | image_list.shape[0], FLAGS.training_data_path)) 261 | index = np.arange(0, image_list.shape[0]) 262 | 263 | while True: 264 | np.random.shuffle(index) 265 | images = [] 266 | image_fns = [] 267 | seg_maps = [] 268 | training_masks = [] 269 | for i in index: 270 | try: 271 | im_fn = image_list[i] 272 | im = cv2.imread(im_fn) 273 | if im is None: 274 | logger.info(im_fn) 275 | h, w, _ = im.shape 276 | txt_fn = im_fn.replace(os.path.basename(im_fn).split('.')[1], 'txt') 277 | if not os.path.exists(txt_fn): 278 | continue 279 | 
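shrink_poly above produces the PSENet kernels: each ground-truth polygon is offset inwards by d = Area * (1 - r**2) / Perimeter via pyclipper, so a smaller scale ratio r yields a smaller kernel. A worked example on a hypothetical 100 x 40 box with r = 0.5:

import pyclipper

poly = [(0, 0), (100, 0), (100, 40), (0, 40)]    # hypothetical axis-aligned box
r = 0.5                                          # scale ratio for this kernel
area, perimeter = 100 * 40, 2 * (100 + 40)       # 4000 and 280
d = area * (1 - r ** 2) / perimeter              # 4000 * 0.75 / 280 ≈ 10.7 px

pco = pyclipper.PyclipperOffset()
pco.AddPath(poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
shrunk = pco.Execute(-d)                         # negative delta offsets inwards
# shrunk is a list of polygons; it can be empty when the box is too small,
# which generate_seg treats as "ignore this instance during training".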
280 | text_polys, text_tags = load_annoataion(txt_fn) 281 | if text_polys.shape[0] == 0: 282 | continue 283 | text_polys, text_tags = check_and_validate_polys(text_polys, text_tags, (h, w)) 284 | 285 | # random scale this image 286 | rd_scale = np.random.choice(random_scale) 287 | im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) 288 | text_polys *= rd_scale 289 | # random crop a area from image 290 | if np.random.rand() < background_ratio: 291 | # crop background 292 | im, text_polys, text_tags = crop_area(im, text_polys, text_tags, crop_background=True) 293 | if text_polys.shape[0] > 0: 294 | # cannot find background 295 | continue 296 | # pad and resize image 297 | new_h, new_w, _ = im.shape 298 | #max_h_w_i = np.max([new_h, new_w, input_size]) 299 | im_padded = np.zeros((new_h, new_w, 3), dtype=np.uint8) 300 | im_padded[:new_h, :new_w, :] = im.copy() 301 | im = cv2.resize(im_padded, dsize=(input_size, input_size)) 302 | seg_map_per_image = np.zeros((input_size, input_size, scale_ratio.shape[0]), dtype=np.uint8) 303 | training_mask = np.ones((input_size, input_size), dtype=np.uint8) 304 | else: 305 | im, text_polys, text_tags = crop_area(im, text_polys, text_tags, crop_background=False) 306 | if text_polys.shape[0] == 0: 307 | continue 308 | # h, w, _ = im.shape 309 | 310 | # pad the image to the training input size or the longer side of image 311 | new_h, new_w, _ = im.shape 312 | #max_h_w_i = np.max([new_h, new_w, input_size]) 313 | im_padded = np.zeros((new_h, new_w, 3), dtype=np.uint8) 314 | im_padded[:new_h, :new_w, :] = im.copy() 315 | im = im_padded 316 | # resize the image to input size 317 | new_h, new_w, _ = im.shape 318 | resize_h = input_size 319 | resize_w = input_size 320 | im = cv2.resize(im, dsize=(resize_w, resize_h)) 321 | resize_ratio_3_x = resize_w/float(new_w) 322 | resize_ratio_3_y = resize_h/float(new_h) 323 | text_polys[:, :, 0] *= resize_ratio_3_x 324 | text_polys[:, :, 1] *= resize_ratio_3_y 325 | new_h, new_w, _ = im.shape 326 | seg_map_per_image, training_mask = generate_seg((new_h, new_w), text_polys, text_tags, 327 | image_list[i], scale_ratio) 328 | if not len(seg_map_per_image): 329 | logger.info("len(seg_map)==0 image: %d " % i) 330 | continue 331 | 332 | if vis: 333 | fig, axs = plt.subplots(3, 3, figsize=(20, 30)) 334 | axs[0, 0].imshow(im[..., ::-1]) 335 | axs[0, 0].set_xticks([]) 336 | axs[0, 0].set_yticks([]) 337 | axs[0, 1].imshow(seg_map_per_image[..., 0]) 338 | axs[0, 1].set_xticks([]) 339 | axs[0, 1].set_yticks([]) 340 | axs[0, 2].imshow(seg_map_per_image[..., 1]) 341 | axs[0, 2].set_xticks([]) 342 | axs[0, 2].set_yticks([]) 343 | axs[1, 0].imshow(seg_map_per_image[..., 2]) 344 | axs[1, 0].set_xticks([]) 345 | axs[1, 0].set_yticks([]) 346 | axs[1, 1].imshow(seg_map_per_image[..., 3]) 347 | axs[1, 1].set_xticks([]) 348 | axs[1, 1].set_yticks([]) 349 | axs[1, 2].imshow(seg_map_per_image[..., 4]) 350 | axs[1, 2].set_xticks([]) 351 | axs[1, 2].set_yticks([]) 352 | axs[2, 0].imshow(seg_map_per_image[..., 5]) 353 | axs[2, 0].set_xticks([]) 354 | axs[2, 0].set_yticks([]) 355 | axs[2, 1].imshow(training_mask) 356 | axs[2, 1].set_xticks([]) 357 | axs[2, 1].set_yticks([]) 358 | plt.tight_layout() 359 | plt.show() 360 | plt.close() 361 | 362 | images.append(im[..., ::-1].astype(np.float32)) 363 | image_fns.append(im_fn) 364 | seg_maps.append(seg_map_per_image[::4, ::4, :].astype(np.float32)) 365 | training_masks.append(training_mask[::4, ::4, np.newaxis].astype(np.float32)) 366 | 367 | if len(images) == batch_size: 368 | yield images, image_fns, 
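One detail in the generator that is easy to miss: images are yielded at full input resolution, while the segmentation maps and training masks are strided by 4 ([::4, ::4]), presumably to match a network output at 1/4 of the input resolution. A quick shape check assuming the default input_size of 512:

import numpy as np

input_size = 512
image = np.zeros((input_size, input_size, 3), dtype=np.float32)
seg_map = np.zeros((input_size, input_size, 6), dtype=np.uint8)   # one channel per scale ratio
mask = np.ones((input_size, input_size), dtype=np.uint8)

print(image.shape)                        # (512, 512, 3)  fed to the network as-is
print(seg_map[::4, ::4, :].shape)         # (128, 128, 6)  kernel labels at 1/4 resolution
print(mask[::4, ::4, np.newaxis].shape)   # (128, 128, 1)  ignore mask at the same resolution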
seg_maps, training_masks 369 | images = [] 370 | image_fns = [] 371 | seg_maps = [] 372 | training_masks = [] 373 | except Exception as e: 374 | traceback.print_exc() 375 | continue 376 | 377 | 378 | def get_batch(num_workers, **kwargs): 379 | try: 380 | enqueuer = GeneratorEnqueuer(generator(**kwargs), use_multiprocessing=True) 381 | enqueuer.start(max_queue_size=24, workers=num_workers) 382 | generator_output = None 383 | while True: 384 | while enqueuer.is_running(): 385 | if not enqueuer.queue.empty(): 386 | generator_output = enqueuer.queue.get() 387 | break 388 | else: 389 | time.sleep(0.01) 390 | yield generator_output 391 | generator_output = None 392 | finally: 393 | if enqueuer is not None: 394 | enqueuer.stop() 395 | 396 | 397 | if __name__ == '__main__': 398 | gen = get_batch(num_workers=2, vis=True) 399 | while True: 400 | image, bbox, im_info = next(gen) 401 | logger.debug('done') 402 | -------------------------------------------------------------------------------- /utils/data_provider/data_util.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import threading 3 | import time 4 | 5 | import numpy as np 6 | 7 | try: 8 | import queue 9 | except ImportError: 10 | import Queue as queue 11 | 12 | 13 | class GeneratorEnqueuer(): 14 | def __init__(self, generator, 15 | use_multiprocessing=False, 16 | wait_time=0.05, 17 | random_seed=None): 18 | self.wait_time = wait_time 19 | self._generator = generator 20 | self._use_multiprocessing = use_multiprocessing 21 | self._threads = [] 22 | self._stop_event = None 23 | self.queue = None 24 | self.random_seed = random_seed 25 | 26 | def start(self, workers=1, max_queue_size=10): 27 | def data_generator_task(): 28 | while not self._stop_event.is_set(): 29 | try: 30 | if self._use_multiprocessing or self.queue.qsize() < max_queue_size: 31 | generator_output = next(self._generator) 32 | self.queue.put(generator_output) 33 | else: 34 | time.sleep(self.wait_time) 35 | except Exception: 36 | self._stop_event.set() 37 | raise 38 | 39 | try: 40 | if self._use_multiprocessing: 41 | self.queue = multiprocessing.Queue(maxsize=max_queue_size) 42 | self._stop_event = multiprocessing.Event() 43 | else: 44 | self.queue = queue.Queue() 45 | self._stop_event = threading.Event() 46 | 47 | for _ in range(workers): 48 | if self._use_multiprocessing: 49 | # Reset random seed else all children processes 50 | # share the same seed 51 | np.random.seed(self.random_seed) 52 | thread = multiprocessing.Process(target=data_generator_task) 53 | thread.daemon = True 54 | if self.random_seed is not None: 55 | self.random_seed += 1 56 | else: 57 | thread = threading.Thread(target=data_generator_task) 58 | self._threads.append(thread) 59 | thread.start() 60 | except: 61 | self.stop() 62 | raise 63 | 64 | def is_running(self): 65 | return self._stop_event is not None and not self._stop_event.is_set() 66 | 67 | def stop(self, timeout=None): 68 | if self.is_running(): 69 | self._stop_event.set() 70 | 71 | for thread in self._threads: 72 | if thread.is_alive(): 73 | if self._use_multiprocessing: 74 | thread.terminate() 75 | else: 76 | thread.join(timeout) 77 | 78 | if self._use_multiprocessing: 79 | if self.queue is not None: 80 | self.queue.close() 81 | 82 | self._threads = [] 83 | self._stop_event = None 84 | self.queue = None 85 | 86 | def get(self): 87 | while self.is_running(): 88 | if not self.queue.empty(): 89 | inputs = self.queue.get() 90 | if inputs is not None: 91 | yield inputs 92 | else: 93 | 
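get_batch wraps the generator in a GeneratorEnqueuer so that worker processes fill a queue in the background while the main process polls it every 10 ms. The usage below mirrors how train.py consumes it; the argument values are illustrative, and FLAGS.training_data_path must point at a directory of images with matching .txt annotations:

from utils.data_provider import data_provider

gen = data_provider.get_batch(num_workers=4, input_size=512, batch_size=8)
images, image_fns, seg_maps, training_masks = next(gen)
# images:          list of 8 float32 RGB arrays, each (512, 512, 3)
# image_fns:       the corresponding file names
# seg_maps:        list of 8 float32 arrays, each (128, 128, 6)
# training_masks:  list of 8 float32 arrays, each (128, 128, 1)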
time.sleep(self.wait_time) 94 | -------------------------------------------------------------------------------- /utils/utils_tool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from easydict import EasyDict as edict 3 | import Queue 4 | import numpy as np 5 | import cv2 6 | 7 | logging.basicConfig() 8 | logger = logging.getLogger() 9 | logger.setLevel(logging.DEBUG) 10 | 11 | __C = edict() 12 | cfg = __C 13 | 14 | #log level 15 | __C.error = logging.ERROR 16 | __C.warning = logging.WARNING 17 | __C.info = logging.INFO 18 | __C.debug = logging.DEBUG 19 | 20 | 21 | def pse(kernals, min_area=5): 22 | ''' 23 | reference https://github.com/whai362/PSENet/issues/15 24 | :param kernals: 25 | :param min_area: 26 | :return: 27 | ''' 28 | kernal_num = len(kernals) 29 | if not kernal_num: 30 | logger.error('not kernals!') 31 | return np.array([]), [] 32 | pred = np.zeros(kernals[0].shape, dtype='int32') 33 | 34 | label_num, label = cv2.connectedComponents(kernals[kernal_num - 1].astype(np.uint8), connectivity=4) 35 | label_values = [] 36 | for label_idx in range(1, label_num): 37 | if np.sum(label == label_idx) < min_area: 38 | label[label == label_idx] = 0 39 | continue 40 | label_values.append(label_idx) 41 | 42 | queue = Queue.Queue(maxsize=0) 43 | next_queue = Queue.Queue(maxsize=0) 44 | points = np.array(np.where(label > 0)).transpose((1, 0)) 45 | 46 | 47 | for point_idx in range(points.shape[0]): 48 | x, y = points[point_idx, 0], points[point_idx, 1] 49 | l = label[x, y] 50 | queue.put((x, y, l)) 51 | pred[x, y] = l 52 | 53 | dx = [-1, 1, 0, 0] 54 | dy = [0, 0, -1, 1] 55 | for kernal_idx in range(kernal_num - 2, -1, -1): 56 | kernal = kernals[kernal_idx].copy() 57 | while not queue.empty(): 58 | (x, y, l) = queue.get() 59 | 60 | is_edge = True 61 | for j in range(4): 62 | tmpx = x + dx[j] 63 | tmpy = y + dy[j] 64 | if tmpx < 0 or tmpx >= kernal.shape[0] or tmpy < 0 or tmpy >= kernal.shape[1]: 65 | continue 66 | if kernal[tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0: 67 | continue 68 | 69 | queue.put((tmpx, tmpy, l)) 70 | pred[tmpx, tmpy] = l 71 | is_edge = False 72 | if is_edge: 73 | next_queue.put((x, y, l)) 74 | 75 | # kernal[pred > 0] = 0 76 | queue, next_queue = next_queue, queue 77 | 78 | # points = np.array(np.where(pred > 0)).transpose((1, 0)) 79 | # for point_idx in range(points.shape[0]): 80 | # x, y = points[point_idx, 0], points[point_idx, 1] 81 | # l = pred[x, y] 82 | # queue.put((x, y, l)) 83 | 84 | return pred, label_values --------------------------------------------------------------------------------
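The pse routine above implements progressive scale expansion: connected components are labelled on the smallest kernel, then grown breadth-first into each larger kernel in turn, a pixel keeping whichever label reaches it first. The module imports the Python 2 Queue; the sketch below is an equivalent Python 3 formulation of the single-kernel expansion step using collections.deque (the function name and signature are illustrative, not part of the repository):

import collections
import numpy as np

def expand_one_kernel(pred, kernel):
    # pred:   int32 label map built so far (0 = unlabelled)
    # kernel: binary map of the next, larger kernel to grow into
    frontier = collections.deque(
        (x, y, pred[x, y]) for x, y in zip(*np.where(pred > 0)))
    while frontier:
        x, y, lab = frontier.popleft()
        for dx, dy in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            nx, ny = x + dx, y + dy
            if (0 <= nx < kernel.shape[0] and 0 <= ny < kernel.shape[1]
                    and kernel[nx, ny] and pred[nx, ny] == 0):
                pred[nx, ny] = lab                 # first label to arrive wins
                frontier.append((nx, ny, lab))
    return pred

# Applying expand_one_kernel from the second-smallest kernel out to the largest
# reproduces the final pred that pse above returns.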