├── README.md ├── include ├── cudnn.h ├── detector.h ├── polygon.h ├── pybind11 │ ├── attr.h │ ├── buffer_info.h │ ├── cast.h │ ├── chrono.h │ ├── class_support.h │ ├── common.h │ ├── complex.h │ ├── descr.h │ ├── detail │ │ ├── class.h │ │ ├── common.h │ │ ├── descr.h │ │ ├── init.h │ │ ├── internals.h │ │ └── typeid.h │ ├── eigen.h │ ├── embed.h │ ├── eval.h │ ├── functional.h │ ├── iostream.h │ ├── numpy.h │ ├── operators.h │ ├── options.h │ ├── pybind11.h │ ├── pytypes.h │ ├── stl.h │ ├── stl_bind.h │ └── typeid.h ├── recognizer.h └── tensorflow_graph.h └── src ├── detector.cpp ├── main.cpp ├── polygon.cpp └── recognizer.cpp /README.md: -------------------------------------------------------------------------------- 1 | # psenet_cpp 2 | 3 | ## The model is trained by liuheng92/tensorflow_PSENet. 4 | ## Here is my code to convert ckpt to pb: 5 | 6 | ``` 7 | def freeze(ckpt_path=None, save_path=None): 8 | 9 | from tensorflow.python.tools import freeze_graph # , optimize_for_inference_lib 10 | 11 | with tf.get_default_graph().as_default(): 12 | input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images') 13 | global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False) 14 | seg_maps_pred = model.model(input_images, is_training=False) 15 | 16 | tf.identity(seg_maps_pred, name='seg_maps') 17 | 18 | variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step) 19 | saver = tf.train.Saver(variable_averages.variables_to_restore()) 20 | 21 | with tf.Session() as sess: 22 | saver.restore(sess, ckpt_path) 23 | 24 | print('Freeze Model Will Saved at ', save_path) 25 | fdir, name = os.path.split(save_path) 26 | tf.train.write_graph(sess.graph_def, fdir, name, as_text=True) 27 | 28 | freeze_graph.freeze_graph( 29 | input_graph=save_path, 30 | input_saver='', 31 | input_binary=False, 32 | input_checkpoint=ckpt_path, 33 | output_node_names='seg_maps', 34 | restore_op_name='', 35 | 
filename_tensor_name='', 36 | output_graph=save_path, 37 | clear_devices=True, 38 | initializer_nodes='', 39 | ) 40 | ``` 41 | -------------------------------------------------------------------------------- /include/detector.h: --------------------------------------------------------------------------------
//
// Created by xpc on 19-4-28.
//

#pragma once

// NOTE(review): in this copy of the header the target of the first #include and the
// template argument lists of the std::vector / std::map declarations below are
// missing (everything between angle brackets appears to have been stripped).
// Restore them from the original header before compiling.
#include
#include "tensorflow_graph.h"
#include "polygon.h"

namespace tf = tensorflow;

namespace SeetaOCR {

    // Detector specialises TFGraph for a graph whose fetched output tensor is
    // named "seg_maps". It stores the polygons of the most recent Predict() call
    // in `polygons`.
    // NOTE(review): outputTensorNames and Init() are not declared in this class;
    // presumably they come from TFGraph (tensorflow_graph.h) - confirm there.
    class Detector: public TFGraph {
    public:
        // Forwards graph_file to TFGraph, sets longestSide to 1024, selects
        // "seg_maps" as the only output tensor name and then calls Init().
        Detector(const std::string& graph_file)
                : TFGraph(graph_file){
            longestSide = 1024;
            outputTensorNames = {"seg_maps"};
            Init();
        };

        // Declared here, defined elsewhere (not visible in this header).
        void Predict(cv::Mat& inp);

        // Returns a copy of the stored polygons.
        std::vector Polygons() { return polygons; }

        // Turns the DEBUG flag on; there is no way to turn it off again.
        void Debug() {DEBUG=true;}

        // Convenience overload: runs Predict(inp) and then replaces the contents
        // of _polygons with a copy of the stored polygons.
        void Predict(cv::Mat& inp, std::vector& _polygons) {
            Predict(inp);
            _polygons.assign(polygons.begin(), polygons.end());
        }

    protected:
        bool DEBUG=false;       // set by Debug(); its consumers are not visible here
        int longestSide;        // initialised to 1024 by the constructor
        cv::Mat resized;
        std::vector outputs;
        std::vector polygons;   // returned by Polygons() and copied out by Predict(inp, _polygons)
        std::vector> inputs;

        // The three helpers below are declared only; their definitions are not
        // part of this header.
        void FeedImageToTensor(cv::Mat& inp);

        void PseAdaptor(tf::Tensor& features,
                        std::map>& contours_map,
                        const float thresh,
                        const float min_area,
                        const float ratio);

        void ResizeImage(cv::Mat& inp, cv::Mat& out, int longest_side);
    };
}

/* 56 | +--------+----+----+----+----+------+------+------+------+ 57 | | | C1 | C2 | C3 | C4 | C(5) | C(6) | C(7) | C(8) | 58 | +--------+----+----+----+----+------+------+------+------+ 59 | | CV_8U | 0 | 8 | 16 | 24 | 32 | 40 | 48 | 56 | 60 | | CV_8S | 1 | 9 | 17 | 25 | 33 | 41 | 49 | 57 | 61 | | CV_16U | 2 | 10 | 18 | 26 | 34 | 42 | 50 | 58 | 62 | | CV_16S | 3 | 11 | 19 | 27 | 35 | 43 | 51 | 59 | 63 | | CV_32S | 4 | 12 | 20 | 28 | 36 | 44 | 52 | 60 | 64 | | CV_32F | 5 | 13 | 21 | 29 | 37 | 45 | 53 | 61 | 65 | 
| CV_64F | 6 | 14 | 22 | 30 | 38 | 46 | 54 | 62 | 66 | +--------+----+----+----+----+------+------+------+------+ 67 | */ -------------------------------------------------------------------------------- /include/polygon.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by seeta on 19-5-5. 3 | // 4 | 5 | #pragma once 6 | 7 | #include 8 | 9 | template 10 | double Distance(T& a, T& b) { 11 | return sqrt(pow(a.x - b.x, 2) + pow(a.y - b.y, 2)); 12 | } 13 | 14 | class Polygon { 15 | public: 16 | Polygon(cv::Mat& boxPts, cv::Size side, float scaleX=1, float scaleY=1) { 17 | for (int row=0; row < boxPts.rows; ++row) { 18 | auto x = boxPts.at(row, 0) * scaleX; 19 | auto y = boxPts.at(row, 1) * scaleY; 20 | if (x < 0) x = 0; 21 | if (y < 0) y = 0; 22 | if (x >= side.width) x = side.width - 1; 23 | if (y >= side.height) y = side.height - 1; 24 | cv::Point2f pt(x, y); 25 | if (row == 0) lb = pt; 26 | else if (row == 1) lt = pt; 27 | else if (row == 2) rt = pt; 28 | else if (row == 3) rb = pt; 29 | else throw std::range_error("check boxPts mast be (4, 2)"); 30 | } 31 | vec2f = {lt, rt, rb, lb}; 32 | SortVertex(); 33 | } 34 | 35 | std::vector ToVec2f() { return vec2f; } 36 | 37 | std::vector ToVec2i() { 38 | std::vector vec2i(vec2f.size()); 39 | for (int i=0; i < vec2f.size(); ++i) { 40 | auto x = (int) round(vec2f[i].x); 41 | auto y = (int) round(vec2f[i].y); 42 | vec2i[i] = cv::Point2i(x, y); 43 | } 44 | return vec2i; 45 | } 46 | 47 | float Area() { 48 | float area = 0.0; 49 | auto num = (int) vec2f.size(); 50 | for (int i=num-1, j=0; j < num; i=j++) { 51 | area += vec2f[i].x * vec2f[j].y; 52 | area -= vec2f[i].y * vec2f[j].x; 53 | } 54 | return area; 55 | } 56 | 57 | /** 58 | * @brief resort poly vertex like (lt, rt, rb, lb) 59 | */ 60 | void SortVertex() { 61 | int minAxis = 0; 62 | float minSum = -1; 63 | 64 | for (int i=0; i < vec2f.size(); ++i) { 65 | float sum = vec2f[i].x + vec2f[i].y; 66 | if (minSum < 0 || minSum 
> sum) { 67 | minSum = sum; 68 | minAxis = i; 69 | } 70 | } 71 | 72 | std::vector vertex({vec2f[(minAxis + 0) % 4], vec2f[(minAxis + 1) % 4], 73 | vec2f[(minAxis + 2) % 4], vec2f[(minAxis + 3) % 4]}); 74 | 75 | if (fabs(vertex[0].x - vertex[1].x) > fabs(vertex[0].y - vertex[1].y)) { 76 | vertex.swap(vec2f); 77 | } else { 78 | vec2f[0] = vertex[0]; 79 | vec2f[1] = vertex[3]; 80 | vec2f[2] = vertex[2]; 81 | vec2f[3] = vertex[1]; 82 | } 83 | } 84 | 85 | std::vector ToQuadROI() { 86 | std::vector quad; 87 | auto height = (int) round(fmax(Distance(vec2f[0], vec2f[3]), Distance(vec2f[1], vec2f[2]))); 88 | auto width = (int) round(fmax(Distance(vec2f[0], vec2f[1]), Distance(vec2f[2], vec2f[3]))); 89 | quad.emplace_back(cv::Point2f(0, 0)); 90 | quad.emplace_back(cv::Point2f(width, 0)); 91 | quad.emplace_back(cv::Point2f(width, height)); 92 | quad.emplace_back(cv::Point2f(0, height)); 93 | return quad; 94 | } 95 | 96 | protected: 97 | cv::Point2f lb; 98 | cv::Point2f lt; 99 | cv::Point2f rt; 100 | cv::Point2f rb; 101 | std::vector vec2f; 102 | }; 103 | 104 | 105 | -------------------------------------------------------------------------------- /include/pybind11/attr.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/attr.h: Infrastructure for processing custom 3 | type and function attributes 4 | 5 | Copyright (c) 2016 Wenzel Jakob 6 | 7 | All rights reserved. Use of this source code is governed by a 8 | BSD-style license that can be found in the LICENSE file. 
9 | */ 10 | 11 | #pragma once 12 | 13 | #include "cast.h" 14 | 15 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 16 | 17 | /// \addtogroup annotations 18 | /// @{ 19 | 20 | /// Annotation for methods 21 | struct is_method { handle class_; is_method(const handle &c) : class_(c) { } }; 22 | 23 | /// Annotation for operators 24 | struct is_operator { }; 25 | 26 | /// Annotation for parent scope 27 | struct scope { handle value; scope(const handle &s) : value(s) { } }; 28 | 29 | /// Annotation for documentation 30 | struct doc { const char *value; doc(const char *value) : value(value) { } }; 31 | 32 | /// Annotation for function names 33 | struct name { const char *value; name(const char *value) : value(value) { } }; 34 | 35 | /// Annotation indicating that a function is an overload associated with a given "sibling" 36 | struct sibling { handle value; sibling(const handle &value) : value(value.ptr()) { } }; 37 | 38 | /// Annotation indicating that a class derives from another given type 39 | template struct base { 40 | PYBIND11_DEPRECATED("base() was deprecated in favor of specifying 'T' as a template argument to class_") 41 | base() { } 42 | }; 43 | 44 | /// Keep patient alive while nurse lives 45 | template struct keep_alive { }; 46 | 47 | /// Annotation indicating that a class is involved in a multiple inheritance relationship 48 | struct multiple_inheritance { }; 49 | 50 | /// Annotation which enables dynamic attributes, i.e. adds `__dict__` to a class 51 | struct dynamic_attr { }; 52 | 53 | /// Annotation which enables the buffer protocol for a type 54 | struct buffer_protocol { }; 55 | 56 | /// Annotation which requests that a special metaclass is created for a type 57 | struct metaclass { 58 | handle value; 59 | 60 | PYBIND11_DEPRECATED("py::metaclass() is no longer required. 
It's turned on by default now.") 61 | metaclass() {} 62 | 63 | /// Override pybind11's default metaclass 64 | explicit metaclass(handle value) : value(value) { } 65 | }; 66 | 67 | /// Annotation that marks a class as local to the module: 68 | struct module_local { const bool value; constexpr module_local(bool v = true) : value(v) { } }; 69 | 70 | /// Annotation to mark enums as an arithmetic type 71 | struct arithmetic { }; 72 | 73 | /** \rst 74 | A call policy which places one or more guard variables (``Ts...``) around the function call. 75 | 76 | For example, this definition: 77 | 78 | .. code-block:: cpp 79 | 80 | m.def("foo", foo, py::call_guard()); 81 | 82 | is equivalent to the following pseudocode: 83 | 84 | .. code-block:: cpp 85 | 86 | m.def("foo", [](args...) { 87 | T scope_guard; 88 | return foo(args...); // forwarded arguments 89 | }); 90 | \endrst */ 91 | template struct call_guard; 92 | 93 | template <> struct call_guard<> { using type = detail::void_type; }; 94 | 95 | template 96 | struct call_guard { 97 | static_assert(std::is_default_constructible::value, 98 | "The guard type must be default constructible"); 99 | 100 | using type = T; 101 | }; 102 | 103 | template 104 | struct call_guard { 105 | struct type { 106 | T guard{}; // Compose multiple guard types with left-to-right default-constructor order 107 | typename call_guard::type next{}; 108 | }; 109 | }; 110 | 111 | /// @} annotations 112 | 113 | NAMESPACE_BEGIN(detail) 114 | /* Forward declarations */ 115 | enum op_id : int; 116 | enum op_type : int; 117 | struct undefined_t; 118 | template struct op_; 119 | inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret); 120 | 121 | /// Internal data structure which holds metadata about a keyword argument 122 | struct argument_record { 123 | const char *name; ///< Argument name 124 | const char *descr; ///< Human-readable version of the argument value 125 | handle value; ///< Associated Python object 126 | bool 
convert : 1; ///< True if the argument is allowed to convert when loading 127 | bool none : 1; ///< True if None is allowed when loading 128 | 129 | argument_record(const char *name, const char *descr, handle value, bool convert, bool none) 130 | : name(name), descr(descr), value(value), convert(convert), none(none) { } 131 | }; 132 | 133 | /// Internal data structure which holds metadata about a bound function (signature, overloads, etc.) 134 | struct function_record { 135 | function_record() 136 | : is_constructor(false), is_new_style_constructor(false), is_stateless(false), 137 | is_operator(false), has_args(false), has_kwargs(false), is_method(false) { } 138 | 139 | /// Function name 140 | char *name = nullptr; /* why no C++ strings? They generate heavier code.. */ 141 | 142 | // User-specified documentation string 143 | char *doc = nullptr; 144 | 145 | /// Human-readable version of the function signature 146 | char *signature = nullptr; 147 | 148 | /// List of registered keyword arguments 149 | std::vector args; 150 | 151 | /// Pointer to lambda function which converts arguments and performs the actual call 152 | handle (*impl) (function_call &) = nullptr; 153 | 154 | /// Storage for the wrapped function pointer and captured data, if any 155 | void *data[3] = { }; 156 | 157 | /// Pointer to custom destructor for 'data' (if needed) 158 | void (*free_data) (function_record *ptr) = nullptr; 159 | 160 | /// Return value policy associated with this function 161 | return_value_policy policy = return_value_policy::automatic; 162 | 163 | /// True if name == '__init__' 164 | bool is_constructor : 1; 165 | 166 | /// True if this is a new-style `__init__` defined in `detail/init.h` 167 | bool is_new_style_constructor : 1; 168 | 169 | /// True if this is a stateless function pointer 170 | bool is_stateless : 1; 171 | 172 | /// True if this is an operator (__add__), etc. 
173 | bool is_operator : 1; 174 | 175 | /// True if the function has a '*args' argument 176 | bool has_args : 1; 177 | 178 | /// True if the function has a '**kwargs' argument 179 | bool has_kwargs : 1; 180 | 181 | /// True if this is a method 182 | bool is_method : 1; 183 | 184 | /// Number of arguments (including py::args and/or py::kwargs, if present) 185 | std::uint16_t nargs; 186 | 187 | /// Python method object 188 | PyMethodDef *def = nullptr; 189 | 190 | /// Python handle to the parent scope (a class or a module) 191 | handle scope; 192 | 193 | /// Python handle to the sibling function representing an overload chain 194 | handle sibling; 195 | 196 | /// Pointer to next overload 197 | function_record *next = nullptr; 198 | }; 199 | 200 | /// Special data structure which (temporarily) holds metadata about a bound class 201 | struct type_record { 202 | PYBIND11_NOINLINE type_record() 203 | : multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false), module_local(false) { } 204 | 205 | /// Handle to the parent scope 206 | handle scope; 207 | 208 | /// Name of the class 209 | const char *name = nullptr; 210 | 211 | // Pointer to RTTI type_info data structure 212 | const std::type_info *type = nullptr; 213 | 214 | /// How large is the underlying C++ type? 215 | size_t type_size = 0; 216 | 217 | /// What is the alignment of the underlying C++ type? 218 | size_t type_align = 0; 219 | 220 | /// How large is the type's holder? 
221 | size_t holder_size = 0; 222 | 223 | /// The global operator new can be overridden with a class-specific variant 224 | void *(*operator_new)(size_t) = nullptr; 225 | 226 | /// Function pointer to class_<..>::init_instance 227 | void (*init_instance)(instance *, const void *) = nullptr; 228 | 229 | /// Function pointer to class_<..>::dealloc 230 | void (*dealloc)(detail::value_and_holder &) = nullptr; 231 | 232 | /// List of base classes of the newly created type 233 | list bases; 234 | 235 | /// Optional docstring 236 | const char *doc = nullptr; 237 | 238 | /// Custom metaclass (optional) 239 | handle metaclass; 240 | 241 | /// Multiple inheritance marker 242 | bool multiple_inheritance : 1; 243 | 244 | /// Does the class manage a __dict__? 245 | bool dynamic_attr : 1; 246 | 247 | /// Does the class implement the buffer protocol? 248 | bool buffer_protocol : 1; 249 | 250 | /// Is the default (unique_ptr) holder type used? 251 | bool default_holder : 1; 252 | 253 | /// Is the class definition local to the module shared object? 254 | bool module_local : 1; 255 | 256 | PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) { 257 | auto base_info = detail::get_type_info(base, false); 258 | if (!base_info) { 259 | std::string tname(base.name()); 260 | detail::clean_type_id(tname); 261 | pybind11_fail("generic_type: type \"" + std::string(name) + 262 | "\" referenced unknown base type \"" + tname + "\""); 263 | } 264 | 265 | if (default_holder != base_info->default_holder) { 266 | std::string tname(base.name()); 267 | detail::clean_type_id(tname); 268 | pybind11_fail("generic_type: type \"" + std::string(name) + "\" " + 269 | (default_holder ? "does not have" : "has") + 270 | " a non-default holder type while its base \"" + tname + "\" " + 271 | (base_info->default_holder ? 
"does not" : "does")); 272 | } 273 | 274 | bases.append((PyObject *) base_info->type); 275 | 276 | if (base_info->type->tp_dictoffset != 0) 277 | dynamic_attr = true; 278 | 279 | if (caster) 280 | base_info->implicit_casts.emplace_back(type, caster); 281 | } 282 | }; 283 | 284 | inline function_call::function_call(const function_record &f, handle p) : 285 | func(f), parent(p) { 286 | args.reserve(f.nargs); 287 | args_convert.reserve(f.nargs); 288 | } 289 | 290 | /// Tag for a new-style `__init__` defined in `detail/init.h` 291 | struct is_new_style_constructor { }; 292 | 293 | /** 294 | * Partial template specializations to process custom attributes provided to 295 | * cpp_function_ and class_. These are either used to initialize the respective 296 | * fields in the type_record and function_record data structures or executed at 297 | * runtime to deal with custom call policies (e.g. keep_alive). 298 | */ 299 | template struct process_attribute; 300 | 301 | template struct process_attribute_default { 302 | /// Default implementation: do nothing 303 | static void init(const T &, function_record *) { } 304 | static void init(const T &, type_record *) { } 305 | static void precall(function_call &) { } 306 | static void postcall(function_call &, handle) { } 307 | }; 308 | 309 | /// Process an attribute specifying the function's name 310 | template <> struct process_attribute : process_attribute_default { 311 | static void init(const name &n, function_record *r) { r->name = const_cast(n.value); } 312 | }; 313 | 314 | /// Process an attribute specifying the function's docstring 315 | template <> struct process_attribute : process_attribute_default { 316 | static void init(const doc &n, function_record *r) { r->doc = const_cast(n.value); } 317 | }; 318 | 319 | /// Process an attribute specifying the function's docstring (provided as a C-style string) 320 | template <> struct process_attribute : process_attribute_default { 321 | static void init(const char *d, 
function_record *r) { r->doc = const_cast(d); } 322 | static void init(const char *d, type_record *r) { r->doc = const_cast(d); } 323 | }; 324 | template <> struct process_attribute : process_attribute { }; 325 | 326 | /// Process an attribute indicating the function's return value policy 327 | template <> struct process_attribute : process_attribute_default { 328 | static void init(const return_value_policy &p, function_record *r) { r->policy = p; } 329 | }; 330 | 331 | /// Process an attribute which indicates that this is an overloaded function associated with a given sibling 332 | template <> struct process_attribute : process_attribute_default { 333 | static void init(const sibling &s, function_record *r) { r->sibling = s.value; } 334 | }; 335 | 336 | /// Process an attribute which indicates that this function is a method 337 | template <> struct process_attribute : process_attribute_default { 338 | static void init(const is_method &s, function_record *r) { r->is_method = true; r->scope = s.class_; } 339 | }; 340 | 341 | /// Process an attribute which indicates the parent scope of a method 342 | template <> struct process_attribute : process_attribute_default { 343 | static void init(const scope &s, function_record *r) { r->scope = s.value; } 344 | }; 345 | 346 | /// Process an attribute which indicates that this function is an operator 347 | template <> struct process_attribute : process_attribute_default { 348 | static void init(const is_operator &, function_record *r) { r->is_operator = true; } 349 | }; 350 | 351 | template <> struct process_attribute : process_attribute_default { 352 | static void init(const is_new_style_constructor &, function_record *r) { r->is_new_style_constructor = true; } 353 | }; 354 | 355 | /// Process a keyword argument attribute (*without* a default value) 356 | template <> struct process_attribute : process_attribute_default { 357 | static void init(const arg &a, function_record *r) { 358 | if (r->is_method && r->args.empty()) 
359 | r->args.emplace_back("self", nullptr, handle(), true /*convert*/, false /*none not allowed*/); 360 | r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none); 361 | } 362 | }; 363 | 364 | /// Process a keyword argument attribute (*with* a default value) 365 | template <> struct process_attribute : process_attribute_default { 366 | static void init(const arg_v &a, function_record *r) { 367 | if (r->is_method && r->args.empty()) 368 | r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/, false /*none not allowed*/); 369 | 370 | if (!a.value) { 371 | #if !defined(NDEBUG) 372 | std::string descr("'"); 373 | if (a.name) descr += std::string(a.name) + ": "; 374 | descr += a.type + "'"; 375 | if (r->is_method) { 376 | if (r->name) 377 | descr += " in method '" + (std::string) str(r->scope) + "." + (std::string) r->name + "'"; 378 | else 379 | descr += " in method of '" + (std::string) str(r->scope) + "'"; 380 | } else if (r->name) { 381 | descr += " in function '" + (std::string) r->name + "'"; 382 | } 383 | pybind11_fail("arg(): could not convert default argument " 384 | + descr + " into a Python object (type not registered yet?)"); 385 | #else 386 | pybind11_fail("arg(): could not convert default argument " 387 | "into a Python object (type not registered yet?). " 388 | "Compile in debug mode for more information."); 389 | #endif 390 | } 391 | r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none); 392 | } 393 | }; 394 | 395 | /// Process a parent class attribute. 
Single inheritance only (class_ itself already guarantees that) 396 | template 397 | struct process_attribute::value>> : process_attribute_default { 398 | static void init(const handle &h, type_record *r) { r->bases.append(h); } 399 | }; 400 | 401 | /// Process a parent class attribute (deprecated, does not support multiple inheritance) 402 | template 403 | struct process_attribute> : process_attribute_default> { 404 | static void init(const base &, type_record *r) { r->add_base(typeid(T), nullptr); } 405 | }; 406 | 407 | /// Process a multiple inheritance attribute 408 | template <> 409 | struct process_attribute : process_attribute_default { 410 | static void init(const multiple_inheritance &, type_record *r) { r->multiple_inheritance = true; } 411 | }; 412 | 413 | template <> 414 | struct process_attribute : process_attribute_default { 415 | static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; } 416 | }; 417 | 418 | template <> 419 | struct process_attribute : process_attribute_default { 420 | static void init(const buffer_protocol &, type_record *r) { r->buffer_protocol = true; } 421 | }; 422 | 423 | template <> 424 | struct process_attribute : process_attribute_default { 425 | static void init(const metaclass &m, type_record *r) { r->metaclass = m.value; } 426 | }; 427 | 428 | template <> 429 | struct process_attribute : process_attribute_default { 430 | static void init(const module_local &l, type_record *r) { r->module_local = l.value; } 431 | }; 432 | 433 | /// Process an 'arithmetic' attribute for enums (does nothing here) 434 | template <> 435 | struct process_attribute : process_attribute_default {}; 436 | 437 | template 438 | struct process_attribute> : process_attribute_default> { }; 439 | 440 | /** 441 | * Process a keep_alive call policy -- invokes keep_alive_impl during the 442 | * pre-call handler if both Nurse, Patient != 0 and use the post-call handler 443 | * otherwise 444 | */ 445 | template struct process_attribute> 
: public process_attribute_default> { 446 | template = 0> 447 | static void precall(function_call &call) { keep_alive_impl(Nurse, Patient, call, handle()); } 448 | template = 0> 449 | static void postcall(function_call &, handle) { } 450 | template = 0> 451 | static void precall(function_call &) { } 452 | template = 0> 453 | static void postcall(function_call &call, handle ret) { keep_alive_impl(Nurse, Patient, call, ret); } 454 | }; 455 | 456 | /// Recursively iterate over variadic template arguments 457 | template struct process_attributes { 458 | static void init(const Args&... args, function_record *r) { 459 | int unused[] = { 0, (process_attribute::type>::init(args, r), 0) ... }; 460 | ignore_unused(unused); 461 | } 462 | static void init(const Args&... args, type_record *r) { 463 | int unused[] = { 0, (process_attribute::type>::init(args, r), 0) ... }; 464 | ignore_unused(unused); 465 | } 466 | static void precall(function_call &call) { 467 | int unused[] = { 0, (process_attribute::type>::precall(call), 0) ... }; 468 | ignore_unused(unused); 469 | } 470 | static void postcall(function_call &call, handle fn_ret) { 471 | int unused[] = { 0, (process_attribute::type>::postcall(call, fn_ret), 0) ... 
}; 472 | ignore_unused(unused); 473 | } 474 | }; 475 | 476 | template 477 | using is_call_guard = is_instantiation; 478 | 479 | /// Extract the ``type`` from the first `call_guard` in `Extras...` (or `void_type` if none found) 480 | template 481 | using extract_guard_t = typename exactly_one_t, Extra...>::type; 482 | 483 | /// Check the number of named arguments at compile time 484 | template ::value...), 486 | size_t self = constexpr_sum(std::is_same::value...)> 487 | constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) { 488 | return named == 0 || (self + named + has_args + has_kwargs) == nargs; 489 | } 490 | 491 | NAMESPACE_END(detail) 492 | NAMESPACE_END(PYBIND11_NAMESPACE) 493 | -------------------------------------------------------------------------------- /include/pybind11/buffer_info.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/buffer_info.h: Python buffer object interface 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "detail/common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | 16 | /// Information record describing a Python buffer object 17 | struct buffer_info { 18 | void *ptr = nullptr; // Pointer to the underlying storage 19 | ssize_t itemsize = 0; // Size of individual items in bytes 20 | ssize_t size = 0; // Total number of entries 21 | std::string format; // For homogeneous buffers, this should be set to format_descriptor::format() 22 | ssize_t ndim = 0; // Number of dimensions 23 | std::vector shape; // Shape of the tensor (1 entry per dimension) 24 | std::vector strides; // Number of entries between adjacent entries (for each per dimension) 25 | 26 | buffer_info() { } 27 | 28 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, 29 | detail::any_container shape_in, detail::any_container strides_in) 30 | : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim), 31 | shape(std::move(shape_in)), strides(std::move(strides_in)) { 32 | if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size()) 33 | pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length"); 34 | for (size_t i = 0; i < (size_t) ndim; ++i) 35 | size *= shape[i]; 36 | } 37 | 38 | template 39 | buffer_info(T *ptr, detail::any_container shape_in, detail::any_container strides_in) 40 | : buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor::format(), static_cast(shape_in->size()), std::move(shape_in), std::move(strides_in)) { } 41 | 42 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size) 43 | : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { } 44 | 45 | template 46 | buffer_info(T *ptr, ssize_t size) 47 | : buffer_info(ptr, sizeof(T), format_descriptor::format(), size) { } 48 | 49 | explicit buffer_info(Py_buffer *view, bool ownview = true) 50 | : buffer_info(view->buf, view->itemsize, view->format, view->ndim, 51 | {view->shape, 
view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) { 52 | this->view = view; 53 | this->ownview = ownview; 54 | } 55 | 56 | buffer_info(const buffer_info &) = delete; 57 | buffer_info& operator=(const buffer_info &) = delete; 58 | 59 | buffer_info(buffer_info &&other) { 60 | (*this) = std::move(other); 61 | } 62 | 63 | buffer_info& operator=(buffer_info &&rhs) { 64 | ptr = rhs.ptr; 65 | itemsize = rhs.itemsize; 66 | size = rhs.size; 67 | format = std::move(rhs.format); 68 | ndim = rhs.ndim; 69 | shape = std::move(rhs.shape); 70 | strides = std::move(rhs.strides); 71 | std::swap(view, rhs.view); 72 | std::swap(ownview, rhs.ownview); 73 | return *this; 74 | } 75 | 76 | ~buffer_info() { 77 | if (view && ownview) { PyBuffer_Release(view); delete view; } 78 | } 79 | 80 | private: 81 | struct private_ctr_tag { }; 82 | 83 | buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, 84 | detail::any_container &&shape_in, detail::any_container &&strides_in) 85 | : buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { } 86 | 87 | Py_buffer *view = nullptr; 88 | bool ownview = false; 89 | }; 90 | 91 | NAMESPACE_BEGIN(detail) 92 | 93 | template struct compare_buffer_info { 94 | static bool compare(const buffer_info& b) { 95 | return b.format == format_descriptor::format() && b.itemsize == (ssize_t) sizeof(T); 96 | } 97 | }; 98 | 99 | template struct compare_buffer_info::value>> { 100 | static bool compare(const buffer_info& b) { 101 | return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor::value || 102 | ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned::value ? "L" : "l")) || 103 | ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned::value ? 
"N" : "n"))); 104 | } 105 | }; 106 | 107 | NAMESPACE_END(detail) 108 | NAMESPACE_END(PYBIND11_NAMESPACE) 109 | -------------------------------------------------------------------------------- /include/pybind11/chrono.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/chrono.h: Transparent conversion between std::chrono and python's datetime 3 | 4 | Copyright (c) 2016 Trent Houliston and 5 | Wenzel Jakob 6 | 7 | All rights reserved. Use of this source code is governed by a 8 | BSD-style license that can be found in the LICENSE file. 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "pybind11.h" 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | // Backport the PyDateTime_DELTA functions from Python3.3 if required 20 | #ifndef PyDateTime_DELTA_GET_DAYS 21 | #define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days) 22 | #endif 23 | #ifndef PyDateTime_DELTA_GET_SECONDS 24 | #define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds) 25 | #endif 26 | #ifndef PyDateTime_DELTA_GET_MICROSECONDS 27 | #define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds) 28 | #endif 29 | 30 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 31 | NAMESPACE_BEGIN(detail) 32 | 33 | template class duration_caster { 34 | public: 35 | typedef typename type::rep rep; 36 | typedef typename type::period period; 37 | 38 | typedef std::chrono::duration> days; 39 | 40 | bool load(handle src, bool) { 41 | using namespace std::chrono; 42 | 43 | // Lazy initialise the PyDateTime import 44 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 45 | 46 | if (!src) return false; 47 | // If invoked with datetime.delta object 48 | if (PyDelta_Check(src.ptr())) { 49 | value = type(duration_cast>( 50 | days(PyDateTime_DELTA_GET_DAYS(src.ptr())) 51 | + seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr())) 52 | + microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr())))); 53 | return true; 54 | } 55 | // If invoked with a 
float we assume it is seconds and convert 56 | else if (PyFloat_Check(src.ptr())) { 57 | value = type(duration_cast>(duration(PyFloat_AsDouble(src.ptr())))); 58 | return true; 59 | } 60 | else return false; 61 | } 62 | 63 | // If this is a duration just return it back 64 | static const std::chrono::duration& get_duration(const std::chrono::duration &src) { 65 | return src; 66 | } 67 | 68 | // If this is a time_point get the time_since_epoch 69 | template static std::chrono::duration get_duration(const std::chrono::time_point> &src) { 70 | return src.time_since_epoch(); 71 | } 72 | 73 | static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) { 74 | using namespace std::chrono; 75 | 76 | // Use overloaded function to get our duration from our source 77 | // Works out if it is a duration or time_point and get the duration 78 | auto d = get_duration(src); 79 | 80 | // Lazy initialise the PyDateTime import 81 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 82 | 83 | // Declare these special duration types so the conversions happen with the correct primitive types (int) 84 | using dd_t = duration>; 85 | using ss_t = duration>; 86 | using us_t = duration; 87 | 88 | auto dd = duration_cast(d); 89 | auto subd = d - dd; 90 | auto ss = duration_cast(subd); 91 | auto us = duration_cast(subd - ss); 92 | return PyDelta_FromDSU(dd.count(), ss.count(), us.count()); 93 | } 94 | 95 | PYBIND11_TYPE_CASTER(type, _("datetime.timedelta")); 96 | }; 97 | 98 | // This is for casting times on the system clock into datetime.datetime instances 99 | template class type_caster> { 100 | public: 101 | typedef std::chrono::time_point type; 102 | bool load(handle src, bool) { 103 | using namespace std::chrono; 104 | 105 | // Lazy initialise the PyDateTime import 106 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 107 | 108 | if (!src) return false; 109 | if (PyDateTime_Check(src.ptr())) { 110 | std::tm cal; 111 | cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr()); 112 | 
cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr()); 113 | cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr()); 114 | cal.tm_mday = PyDateTime_GET_DAY(src.ptr()); 115 | cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1; 116 | cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900; 117 | cal.tm_isdst = -1; 118 | 119 | value = system_clock::from_time_t(std::mktime(&cal)) + microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr())); 120 | return true; 121 | } 122 | else return false; 123 | } 124 | 125 | static handle cast(const std::chrono::time_point &src, return_value_policy /* policy */, handle /* parent */) { 126 | using namespace std::chrono; 127 | 128 | // Lazy initialise the PyDateTime import 129 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 130 | 131 | std::time_t tt = system_clock::to_time_t(src); 132 | // this function uses static memory so it's best to copy it out asap just in case 133 | // otherwise other code that is using localtime may break this (not just python code) 134 | std::tm localtime = *std::localtime(&tt); 135 | 136 | // Declare these special duration types so the conversions happen with the correct primitive types (int) 137 | using us_t = duration; 138 | 139 | return PyDateTime_FromDateAndTime(localtime.tm_year + 1900, 140 | localtime.tm_mon + 1, 141 | localtime.tm_mday, 142 | localtime.tm_hour, 143 | localtime.tm_min, 144 | localtime.tm_sec, 145 | (duration_cast(src.time_since_epoch() % seconds(1))).count()); 146 | } 147 | PYBIND11_TYPE_CASTER(type, _("datetime.datetime")); 148 | }; 149 | 150 | // Other clocks that are not the system clock are not measured as datetime.datetime objects 151 | // since they are not measured on calendar time. 
So instead we just make them timedeltas 152 | // Or if they have passed us a time as a float we convert that 153 | template class type_caster> 154 | : public duration_caster> { 155 | }; 156 | 157 | template class type_caster> 158 | : public duration_caster> { 159 | }; 160 | 161 | NAMESPACE_END(detail) 162 | NAMESPACE_END(PYBIND11_NAMESPACE) 163 | -------------------------------------------------------------------------------- /include/pybind11/class_support.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/class_support.h: Python C API implementation details for py::class_ 3 | 4 | Copyright (c) 2017 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "attr.h" 13 | 14 | NAMESPACE_BEGIN(pybind11) 15 | NAMESPACE_BEGIN(detail) 16 | 17 | inline PyTypeObject *type_incref(PyTypeObject *type) { 18 | Py_INCREF(type); 19 | return type; 20 | } 21 | 22 | #if !defined(PYPY_VERSION) 23 | 24 | /// `pybind11_static_property.__get__()`: Always pass the class instead of the instance. 25 | extern "C" inline PyObject *pybind11_static_get(PyObject *self, PyObject * /*ob*/, PyObject *cls) { 26 | return PyProperty_Type.tp_descr_get(self, cls, cls); 27 | } 28 | 29 | /// `pybind11_static_property.__set__()`: Just like the above `__get__()`. 30 | extern "C" inline int pybind11_static_set(PyObject *self, PyObject *obj, PyObject *value) { 31 | PyObject *cls = PyType_Check(obj) ? obj : (PyObject *) Py_TYPE(obj); 32 | return PyProperty_Type.tp_descr_set(self, cls, value); 33 | } 34 | 35 | /** A `static_property` is the same as a `property` but the `__get__()` and `__set__()` 36 | methods are modified to always use the object type instead of a concrete instance. 37 | Return value: New reference. 
*/ 38 | inline PyTypeObject *make_static_property_type() { 39 | constexpr auto *name = "pybind11_static_property"; 40 | auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); 41 | 42 | /* Danger zone: from now (and until PyType_Ready), make sure to 43 | issue no Python C API calls which could potentially invoke the 44 | garbage collector (the GC will call type_traverse(), which will in 45 | turn find the newly constructed type in an invalid state) */ 46 | auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0); 47 | if (!heap_type) 48 | pybind11_fail("make_static_property_type(): error allocating type!"); 49 | 50 | heap_type->ht_name = name_obj.inc_ref().ptr(); 51 | #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3 52 | heap_type->ht_qualname = name_obj.inc_ref().ptr(); 53 | #endif 54 | 55 | auto type = &heap_type->ht_type; 56 | type->tp_name = name; 57 | type->tp_base = type_incref(&PyProperty_Type); 58 | type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; 59 | type->tp_descr_get = pybind11_static_get; 60 | type->tp_descr_set = pybind11_static_set; 61 | 62 | if (PyType_Ready(type) < 0) 63 | pybind11_fail("make_static_property_type(): failure in PyType_Ready()!"); 64 | 65 | setattr((PyObject *) type, "__module__", str("pybind11_builtins")); 66 | 67 | return type; 68 | } 69 | 70 | #else // PYPY 71 | 72 | /** PyPy has some issues with the above C API, so we evaluate Python code instead. 73 | This function will only be called once so performance isn't really a concern. 74 | Return value: New reference. 
*/ 75 | inline PyTypeObject *make_static_property_type() { 76 | auto d = dict(); 77 | PyObject *result = PyRun_String(R"(\ 78 | class pybind11_static_property(property): 79 | def __get__(self, obj, cls): 80 | return property.__get__(self, cls, cls) 81 | 82 | def __set__(self, obj, value): 83 | cls = obj if isinstance(obj, type) else type(obj) 84 | property.__set__(self, cls, value) 85 | )", Py_file_input, d.ptr(), d.ptr() 86 | ); 87 | if (result == nullptr) 88 | throw error_already_set(); 89 | Py_DECREF(result); 90 | return (PyTypeObject *) d["pybind11_static_property"].cast().release().ptr(); 91 | } 92 | 93 | #endif // PYPY 94 | 95 | /** Types with static properties need to handle `Type.static_prop = x` in a specific way. 96 | By default, Python replaces the `static_property` itself, but for wrapped C++ types 97 | we need to call `static_property.__set__()` in order to propagate the new value to 98 | the underlying C++ data structure. */ 99 | extern "C" inline int pybind11_meta_setattro(PyObject* obj, PyObject* name, PyObject* value) { 100 | // Use `_PyType_Lookup()` instead of `PyObject_GetAttr()` in order to get the raw 101 | // descriptor (`property`) instead of calling `tp_descr_get` (`property.__get__()`). 102 | PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name); 103 | 104 | // The following assignment combinations are possible: 105 | // 1. `Type.static_prop = value` --> descr_set: `Type.static_prop.__set__(value)` 106 | // 2. `Type.static_prop = other_static_prop` --> setattro: replace existing `static_prop` 107 | // 3. `Type.regular_attribute = value` --> setattro: regular attribute assignment 108 | const auto static_prop = (PyObject *) get_internals().static_property_type; 109 | const auto call_descr_set = descr && PyObject_IsInstance(descr, static_prop) 110 | && !PyObject_IsInstance(value, static_prop); 111 | if (call_descr_set) { 112 | // Call `static_property.__set__()` instead of replacing the `static_property`. 
113 | #if !defined(PYPY_VERSION) 114 | return Py_TYPE(descr)->tp_descr_set(descr, obj, value); 115 | #else 116 | if (PyObject *result = PyObject_CallMethod(descr, "__set__", "OO", obj, value)) { 117 | Py_DECREF(result); 118 | return 0; 119 | } else { 120 | return -1; 121 | } 122 | #endif 123 | } else { 124 | // Replace existing attribute. 125 | return PyType_Type.tp_setattro(obj, name, value); 126 | } 127 | } 128 | 129 | #if PY_MAJOR_VERSION >= 3 130 | /** 131 | * Python 3's PyInstanceMethod_Type hides itself via its tp_descr_get, which prevents aliasing 132 | * methods via cls.attr("m2") = cls.attr("m1"): instead the tp_descr_get returns a plain function, 133 | * when called on a class, or a PyMethod, when called on an instance. Override that behaviour here 134 | * to do a special case bypass for PyInstanceMethod_Types. 135 | */ 136 | extern "C" inline PyObject *pybind11_meta_getattro(PyObject *obj, PyObject *name) { 137 | PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name); 138 | if (descr && PyInstanceMethod_Check(descr)) { 139 | Py_INCREF(descr); 140 | return descr; 141 | } 142 | else { 143 | return PyType_Type.tp_getattro(obj, name); 144 | } 145 | } 146 | #endif 147 | 148 | /** This metaclass is assigned by default to all pybind11 types and is required in order 149 | for static properties to function correctly. Users may override this using `py::metaclass`. 150 | Return value: New reference. 
*/ 151 | inline PyTypeObject* make_default_metaclass() { 152 | constexpr auto *name = "pybind11_type"; 153 | auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); 154 | 155 | /* Danger zone: from now (and until PyType_Ready), make sure to 156 | issue no Python C API calls which could potentially invoke the 157 | garbage collector (the GC will call type_traverse(), which will in 158 | turn find the newly constructed type in an invalid state) */ 159 | auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0); 160 | if (!heap_type) 161 | pybind11_fail("make_default_metaclass(): error allocating metaclass!"); 162 | 163 | heap_type->ht_name = name_obj.inc_ref().ptr(); 164 | #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3 165 | heap_type->ht_qualname = name_obj.inc_ref().ptr(); 166 | #endif 167 | 168 | auto type = &heap_type->ht_type; 169 | type->tp_name = name; 170 | type->tp_base = type_incref(&PyType_Type); 171 | type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; 172 | 173 | type->tp_setattro = pybind11_meta_setattro; 174 | #if PY_MAJOR_VERSION >= 3 175 | type->tp_getattro = pybind11_meta_getattro; 176 | #endif 177 | 178 | if (PyType_Ready(type) < 0) 179 | pybind11_fail("make_default_metaclass(): failure in PyType_Ready()!"); 180 | 181 | setattr((PyObject *) type, "__module__", str("pybind11_builtins")); 182 | 183 | return type; 184 | } 185 | 186 | /// For multiple inheritance types we need to recursively register/deregister base pointers for any 187 | /// base classes with pointers that are difference from the instance value pointer so that we can 188 | /// correctly recognize an offset base class pointer. This calls a function with any offset base ptrs. 
189 | inline void traverse_offset_bases(void *valueptr, const detail::type_info *tinfo, instance *self, 190 | bool (*f)(void * /*parentptr*/, instance * /*self*/)) { 191 | for (handle h : reinterpret_borrow(tinfo->type->tp_bases)) { 192 | if (auto parent_tinfo = get_type_info((PyTypeObject *) h.ptr())) { 193 | for (auto &c : parent_tinfo->implicit_casts) { 194 | if (c.first == tinfo->cpptype) { 195 | auto *parentptr = c.second(valueptr); 196 | if (parentptr != valueptr) 197 | f(parentptr, self); 198 | traverse_offset_bases(parentptr, parent_tinfo, self, f); 199 | break; 200 | } 201 | } 202 | } 203 | } 204 | } 205 | 206 | inline bool register_instance_impl(void *ptr, instance *self) { 207 | get_internals().registered_instances.emplace(ptr, self); 208 | return true; // unused, but gives the same signature as the deregister func 209 | } 210 | inline bool deregister_instance_impl(void *ptr, instance *self) { 211 | auto ®istered_instances = get_internals().registered_instances; 212 | auto range = registered_instances.equal_range(ptr); 213 | for (auto it = range.first; it != range.second; ++it) { 214 | if (Py_TYPE(self) == Py_TYPE(it->second)) { 215 | registered_instances.erase(it); 216 | return true; 217 | } 218 | } 219 | return false; 220 | } 221 | 222 | inline void register_instance(instance *self, void *valptr, const type_info *tinfo) { 223 | register_instance_impl(valptr, self); 224 | if (!tinfo->simple_ancestors) 225 | traverse_offset_bases(valptr, tinfo, self, register_instance_impl); 226 | } 227 | 228 | inline bool deregister_instance(instance *self, void *valptr, const type_info *tinfo) { 229 | bool ret = deregister_instance_impl(valptr, self); 230 | if (!tinfo->simple_ancestors) 231 | traverse_offset_bases(valptr, tinfo, self, deregister_instance_impl); 232 | return ret; 233 | } 234 | 235 | /// Instance creation function for all pybind11 types. 
It only allocates space for the C++ object 236 | /// (or multiple objects, for Python-side inheritance from multiple pybind11 types), but doesn't 237 | /// call the constructor -- an `__init__` function must do that (followed by an `init_instance` 238 | /// to set up the holder and register the instance). 239 | inline PyObject *make_new_instance(PyTypeObject *type, bool allocate_value /*= true (in cast.h)*/) { 240 | #if defined(PYPY_VERSION) 241 | // PyPy gets tp_basicsize wrong (issue 2482) under multiple inheritance when the first inherited 242 | // object is a a plain Python type (i.e. not derived from an extension type). Fix it. 243 | ssize_t instance_size = static_cast(sizeof(instance)); 244 | if (type->tp_basicsize < instance_size) { 245 | type->tp_basicsize = instance_size; 246 | } 247 | #endif 248 | PyObject *self = type->tp_alloc(type, 0); 249 | auto inst = reinterpret_cast(self); 250 | // Allocate the value/holder internals: 251 | inst->allocate_layout(); 252 | 253 | inst->owned = true; 254 | // Allocate (if requested) the value pointers; otherwise leave them as nullptr 255 | if (allocate_value) { 256 | for (auto &v_h : values_and_holders(inst)) { 257 | void *&vptr = v_h.value_ptr(); 258 | vptr = v_h.type->operator_new(v_h.type->type_size); 259 | } 260 | } 261 | 262 | return self; 263 | } 264 | 265 | /// Instance creation function for all pybind11 types. It only allocates space for the 266 | /// C++ object, but doesn't call the constructor -- an `__init__` function must do that. 267 | extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) { 268 | return make_new_instance(type); 269 | } 270 | 271 | /// An `__init__` function constructs the C++ object. Users should provide at least one 272 | /// of these using `py::init` or directly with `.def(__init__, ...)`. Otherwise, the 273 | /// following default function will be used which simply throws an exception. 
274 | extern "C" inline int pybind11_object_init(PyObject *self, PyObject *, PyObject *) { 275 | PyTypeObject *type = Py_TYPE(self); 276 | std::string msg; 277 | #if defined(PYPY_VERSION) 278 | msg += handle((PyObject *) type).attr("__module__").cast() + "."; 279 | #endif 280 | msg += type->tp_name; 281 | msg += ": No constructor defined!"; 282 | PyErr_SetString(PyExc_TypeError, msg.c_str()); 283 | return -1; 284 | } 285 | 286 | inline void add_patient(PyObject *nurse, PyObject *patient) { 287 | auto &internals = get_internals(); 288 | auto instance = reinterpret_cast(nurse); 289 | instance->has_patients = true; 290 | Py_INCREF(patient); 291 | internals.patients[nurse].push_back(patient); 292 | } 293 | 294 | inline void clear_patients(PyObject *self) { 295 | auto instance = reinterpret_cast(self); 296 | auto &internals = get_internals(); 297 | auto pos = internals.patients.find(self); 298 | assert(pos != internals.patients.end()); 299 | // Clearing the patients can cause more Python code to run, which 300 | // can invalidate the iterator. Extract the vector of patients 301 | // from the unordered_map first. 302 | auto patients = std::move(pos->second); 303 | internals.patients.erase(pos); 304 | instance->has_patients = false; 305 | for (PyObject *&patient : patients) 306 | Py_CLEAR(patient); 307 | } 308 | 309 | /// Clears all internal data from the instance and removes it from registered instances in 310 | /// preparation for deallocation. 311 | inline void clear_instance(PyObject *self) { 312 | auto instance = reinterpret_cast(self); 313 | 314 | // Deallocate any values/holders, if present: 315 | for (auto &v_h : values_and_holders(instance)) { 316 | if (v_h) { 317 | 318 | // We have to deregister before we call dealloc because, for virtual MI types, we still 319 | // need to be able to get the parent pointers. 
320 | if (v_h.instance_registered() && !deregister_instance(instance, v_h.value_ptr(), v_h.type)) 321 | pybind11_fail("pybind11_object_dealloc(): Tried to deallocate unregistered instance!"); 322 | 323 | if (instance->owned || v_h.holder_constructed()) 324 | v_h.type->dealloc(v_h); 325 | } 326 | } 327 | // Deallocate the value/holder layout internals: 328 | instance->deallocate_layout(); 329 | 330 | if (instance->weakrefs) 331 | PyObject_ClearWeakRefs(self); 332 | 333 | PyObject **dict_ptr = _PyObject_GetDictPtr(self); 334 | if (dict_ptr) 335 | Py_CLEAR(*dict_ptr); 336 | 337 | if (instance->has_patients) 338 | clear_patients(self); 339 | } 340 | 341 | /// Instance destructor function for all pybind11 types. It calls `type_info.dealloc` 342 | /// to destroy the C++ object itself, while the rest is Python bookkeeping. 343 | extern "C" inline void pybind11_object_dealloc(PyObject *self) { 344 | clear_instance(self); 345 | Py_TYPE(self)->tp_free(self); 346 | } 347 | 348 | /** Create the type which can be used as a common base for all classes. This is 349 | needed in order to satisfy Python's requirements for multiple inheritance. 350 | Return value: New reference. 
*/
inline PyObject *make_object_base_type(PyTypeObject *metaclass) {
    constexpr auto *name = "pybind11_object";
    auto name_obj = reinterpret_steal<object>(PYBIND11_FROM_STRING(name));

    /* Danger zone: from now (and until PyType_Ready), make sure to
       issue no Python C API calls which could potentially invoke the
       garbage collector (the GC will call type_traverse(), which will in
       turn find the newly constructed type in an invalid state) */
    auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0);
    if (!heap_type)
        pybind11_fail("make_object_base_type(): error allocating type!");

    // Each slot below owns its own reference to the name object.
    heap_type->ht_name = name_obj.inc_ref().ptr();
    // `ht_qualname` exists from Python 3.3 on. Compare the packed version number
    // rather than major/minor separately, which would be false for x.0 - x.2
    // releases of any later major version.
#if PY_VERSION_HEX >= 0x03030000
    heap_type->ht_qualname = name_obj.inc_ref().ptr();
#endif

    auto type = &heap_type->ht_type;
    type->tp_name = name;
    type->tp_base = type_incref(&PyBaseObject_Type);
    type->tp_basicsize = static_cast<ssize_t>(sizeof(instance));
    type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;

    type->tp_new = pybind11_object_new;
    type->tp_init = pybind11_object_init;
    type->tp_dealloc = pybind11_object_dealloc;

    /* Support weak references (needed for the keep_alive feature) */
    type->tp_weaklistoffset = offsetof(instance, weakrefs);

    if (PyType_Ready(type) < 0)
        pybind11_fail("PyType_Ready failed in make_object_base_type():" + error_string());

    setattr((PyObject *) type, "__module__", str("pybind11_builtins"));

    assert(!PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC));
    return (PyObject *) heap_type;
}

/// dynamic_attr: Support for `d = instance.__dict__`.
391 | extern "C" inline PyObject *pybind11_get_dict(PyObject *self, void *) { 392 | PyObject *&dict = *_PyObject_GetDictPtr(self); 393 | if (!dict) 394 | dict = PyDict_New(); 395 | Py_XINCREF(dict); 396 | return dict; 397 | } 398 | 399 | /// dynamic_attr: Support for `instance.__dict__ = dict()`. 400 | extern "C" inline int pybind11_set_dict(PyObject *self, PyObject *new_dict, void *) { 401 | if (!PyDict_Check(new_dict)) { 402 | PyErr_Format(PyExc_TypeError, "__dict__ must be set to a dictionary, not a '%.200s'", 403 | Py_TYPE(new_dict)->tp_name); 404 | return -1; 405 | } 406 | PyObject *&dict = *_PyObject_GetDictPtr(self); 407 | Py_INCREF(new_dict); 408 | Py_CLEAR(dict); 409 | dict = new_dict; 410 | return 0; 411 | } 412 | 413 | /// dynamic_attr: Allow the garbage collector to traverse the internal instance `__dict__`. 414 | extern "C" inline int pybind11_traverse(PyObject *self, visitproc visit, void *arg) { 415 | PyObject *&dict = *_PyObject_GetDictPtr(self); 416 | Py_VISIT(dict); 417 | return 0; 418 | } 419 | 420 | /// dynamic_attr: Allow the GC to clear the dictionary. 421 | extern "C" inline int pybind11_clear(PyObject *self) { 422 | PyObject *&dict = *_PyObject_GetDictPtr(self); 423 | Py_CLEAR(dict); 424 | return 0; 425 | } 426 | 427 | /// Give instances of this type a `__dict__` and opt into garbage collection. 
428 | inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type) { 429 | auto type = &heap_type->ht_type; 430 | #if defined(PYPY_VERSION) 431 | pybind11_fail(std::string(type->tp_name) + ": dynamic attributes are " 432 | "currently not supported in " 433 | "conjunction with PyPy!"); 434 | #endif 435 | type->tp_flags |= Py_TPFLAGS_HAVE_GC; 436 | type->tp_dictoffset = type->tp_basicsize; // place dict at the end 437 | type->tp_basicsize += (ssize_t)sizeof(PyObject *); // and allocate enough space for it 438 | type->tp_traverse = pybind11_traverse; 439 | type->tp_clear = pybind11_clear; 440 | 441 | static PyGetSetDef getset[] = { 442 | {const_cast("__dict__"), pybind11_get_dict, pybind11_set_dict, nullptr, nullptr}, 443 | {nullptr, nullptr, nullptr, nullptr, nullptr} 444 | }; 445 | type->tp_getset = getset; 446 | } 447 | 448 | /// buffer_protocol: Fill in the view as specified by flags. 449 | extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int flags) { 450 | // Look for a `get_buffer` implementation in this type's info or any bases (following MRO). 
451 | type_info *tinfo = nullptr; 452 | for (auto type : reinterpret_borrow(Py_TYPE(obj)->tp_mro)) { 453 | tinfo = get_type_info((PyTypeObject *) type.ptr()); 454 | if (tinfo && tinfo->get_buffer) 455 | break; 456 | } 457 | if (view == nullptr || obj == nullptr || !tinfo || !tinfo->get_buffer) { 458 | if (view) 459 | view->obj = nullptr; 460 | PyErr_SetString(PyExc_BufferError, "pybind11_getbuffer(): Internal error"); 461 | return -1; 462 | } 463 | std::memset(view, 0, sizeof(Py_buffer)); 464 | buffer_info *info = tinfo->get_buffer(obj, tinfo->get_buffer_data); 465 | view->obj = obj; 466 | view->ndim = 1; 467 | view->internal = info; 468 | view->buf = info->ptr; 469 | view->itemsize = info->itemsize; 470 | view->len = view->itemsize; 471 | for (auto s : info->shape) 472 | view->len *= s; 473 | if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) 474 | view->format = const_cast(info->format.c_str()); 475 | if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { 476 | view->ndim = (int) info->ndim; 477 | view->strides = &info->strides[0]; 478 | view->shape = &info->shape[0]; 479 | } 480 | Py_INCREF(view->obj); 481 | return 0; 482 | } 483 | 484 | /// buffer_protocol: Release the resources of the buffer. 485 | extern "C" inline void pybind11_releasebuffer(PyObject *, Py_buffer *view) { 486 | delete (buffer_info *) view->internal; 487 | } 488 | 489 | /// Give this type a buffer interface. 490 | inline void enable_buffer_protocol(PyHeapTypeObject *heap_type) { 491 | heap_type->ht_type.tp_as_buffer = &heap_type->as_buffer; 492 | #if PY_MAJOR_VERSION < 3 493 | heap_type->ht_type.tp_flags |= Py_TPFLAGS_HAVE_NEWBUFFER; 494 | #endif 495 | 496 | heap_type->as_buffer.bf_getbuffer = pybind11_getbuffer; 497 | heap_type->as_buffer.bf_releasebuffer = pybind11_releasebuffer; 498 | } 499 | 500 | /** Create a brand new Python type according to the `type_record` specification. 501 | Return value: New reference. 
*/ 502 | inline PyObject* make_new_python_type(const type_record &rec) { 503 | auto name = reinterpret_steal(PYBIND11_FROM_STRING(rec.name)); 504 | 505 | #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3 506 | auto ht_qualname = name; 507 | if (rec.scope && hasattr(rec.scope, "__qualname__")) { 508 | ht_qualname = reinterpret_steal( 509 | PyUnicode_FromFormat("%U.%U", rec.scope.attr("__qualname__").ptr(), name.ptr())); 510 | } 511 | #endif 512 | 513 | object module; 514 | if (rec.scope) { 515 | if (hasattr(rec.scope, "__module__")) 516 | module = rec.scope.attr("__module__"); 517 | else if (hasattr(rec.scope, "__name__")) 518 | module = rec.scope.attr("__name__"); 519 | } 520 | 521 | #if !defined(PYPY_VERSION) 522 | const auto full_name = module ? str(module).cast() + "." + rec.name 523 | : std::string(rec.name); 524 | #else 525 | const auto full_name = std::string(rec.name); 526 | #endif 527 | 528 | char *tp_doc = nullptr; 529 | if (rec.doc && options::show_user_defined_docstrings()) { 530 | /* Allocate memory for docstring (using PyObject_MALLOC, since 531 | Python will free this later on) */ 532 | size_t size = strlen(rec.doc) + 1; 533 | tp_doc = (char *) PyObject_MALLOC(size); 534 | memcpy((void *) tp_doc, rec.doc, size); 535 | } 536 | 537 | auto &internals = get_internals(); 538 | auto bases = tuple(rec.bases); 539 | auto base = (bases.size() == 0) ? internals.instance_base 540 | : bases[0].ptr(); 541 | 542 | /* Danger zone: from now (and until PyType_Ready), make sure to 543 | issue no Python C API calls which could potentially invoke the 544 | garbage collector (the GC will call type_traverse(), which will in 545 | turn find the newly constructed type in an invalid state) */ 546 | auto metaclass = rec.metaclass.ptr() ? 
(PyTypeObject *) rec.metaclass.ptr() 547 | : internals.default_metaclass; 548 | 549 | auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0); 550 | if (!heap_type) 551 | pybind11_fail(std::string(rec.name) + ": Unable to create type object!"); 552 | 553 | heap_type->ht_name = name.release().ptr(); 554 | #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3 555 | heap_type->ht_qualname = ht_qualname.release().ptr(); 556 | #endif 557 | 558 | auto type = &heap_type->ht_type; 559 | type->tp_name = strdup(full_name.c_str()); 560 | type->tp_doc = tp_doc; 561 | type->tp_base = type_incref((PyTypeObject *)base); 562 | type->tp_basicsize = static_cast(sizeof(instance)); 563 | if (bases.size() > 0) 564 | type->tp_bases = bases.release().ptr(); 565 | 566 | /* Don't inherit base __init__ */ 567 | type->tp_init = pybind11_object_init; 568 | 569 | /* Supported protocols */ 570 | type->tp_as_number = &heap_type->as_number; 571 | type->tp_as_sequence = &heap_type->as_sequence; 572 | type->tp_as_mapping = &heap_type->as_mapping; 573 | 574 | /* Flags */ 575 | type->tp_flags |= Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; 576 | #if PY_MAJOR_VERSION < 3 577 | type->tp_flags |= Py_TPFLAGS_CHECKTYPES; 578 | #endif 579 | 580 | if (rec.dynamic_attr) 581 | enable_dynamic_attributes(heap_type); 582 | 583 | if (rec.buffer_protocol) 584 | enable_buffer_protocol(heap_type); 585 | 586 | if (PyType_Ready(type) < 0) 587 | pybind11_fail(std::string(rec.name) + ": PyType_Ready failed (" + error_string() + ")!"); 588 | 589 | assert(rec.dynamic_attr ? 
PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) 590 | : !PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)); 591 | 592 | /* Register type with the parent scope */ 593 | if (rec.scope) 594 | setattr(rec.scope, rec.name, (PyObject *) type); 595 | 596 | if (module) // Needed by pydoc 597 | setattr((PyObject *) type, "__module__", module); 598 | 599 | return (PyObject *) type; 600 | } 601 | 602 | NAMESPACE_END(detail) 603 | NAMESPACE_END(pybind11) 604 | -------------------------------------------------------------------------------- /include/pybind11/common.h: -------------------------------------------------------------------------------- 1 | #include "detail/common.h" 2 | #warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'." 3 | -------------------------------------------------------------------------------- /include/pybind11/complex.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/complex.h: Complex number support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | /// glibc defines I as a macro which breaks things, e.g., boost template names 16 | #ifdef I 17 | # undef I 18 | #endif 19 | 20 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 21 | 22 | template struct format_descriptor, detail::enable_if_t::value>> { 23 | static constexpr const char c = format_descriptor::c; 24 | static constexpr const char value[3] = { 'Z', c, '\0' }; 25 | static std::string format() { return std::string(value); } 26 | }; 27 | 28 | #ifndef PYBIND11_CPP17 29 | 30 | template constexpr const char format_descriptor< 31 | std::complex, detail::enable_if_t::value>>::value[3]; 32 | 33 | #endif 34 | 35 | NAMESPACE_BEGIN(detail) 36 | 37 | template struct is_fmt_numeric, detail::enable_if_t::value>> { 38 | static constexpr bool value = true; 39 | static constexpr int index = is_fmt_numeric::index + 3; 40 | }; 41 | 42 | template class type_caster> { 43 | public: 44 | bool load(handle src, bool convert) { 45 | if (!src) 46 | return false; 47 | if (!convert && !PyComplex_Check(src.ptr())) 48 | return false; 49 | Py_complex result = PyComplex_AsCComplex(src.ptr()); 50 | if (result.real == -1.0 && PyErr_Occurred()) { 51 | PyErr_Clear(); 52 | return false; 53 | } 54 | value = std::complex((T) result.real, (T) result.imag); 55 | return true; 56 | } 57 | 58 | static handle cast(const std::complex &src, return_value_policy /* policy */, handle /* parent */) { 59 | return PyComplex_FromDoubles((double) src.real(), (double) src.imag()); 60 | } 61 | 62 | PYBIND11_TYPE_CASTER(std::complex, _("complex")); 63 | }; 64 | NAMESPACE_END(detail) 65 | NAMESPACE_END(PYBIND11_NAMESPACE) 66 | -------------------------------------------------------------------------------- /include/pybind11/descr.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/descr.h: Helper type for concatenating type signatures 3 | either at runtime (C++11) or compile time (C++14) 4 
| 5 | Copyright (c) 2016 Wenzel Jakob 6 | 7 | All rights reserved. Use of this source code is governed by a 8 | BSD-style license that can be found in the LICENSE file. 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "common.h" 14 | 15 | NAMESPACE_BEGIN(pybind11) 16 | NAMESPACE_BEGIN(detail) 17 | 18 | /* Concatenate type signatures at compile time using C++14 */ 19 | #if defined(PYBIND11_CPP14) && !defined(_MSC_VER) 20 | #define PYBIND11_CONSTEXPR_DESCR 21 | 22 | template class descr { 23 | template friend class descr; 24 | public: 25 | constexpr descr(char const (&text) [Size1+1], const std::type_info * const (&types)[Size2+1]) 26 | : descr(text, types, 27 | make_index_sequence(), 28 | make_index_sequence()) { } 29 | 30 | constexpr const char *text() const { return m_text; } 31 | constexpr const std::type_info * const * types() const { return m_types; } 32 | 33 | template 34 | constexpr descr operator+(const descr &other) const { 35 | return concat(other, 36 | make_index_sequence(), 37 | make_index_sequence(), 38 | make_index_sequence(), 39 | make_index_sequence()); 40 | } 41 | 42 | protected: 43 | template 44 | constexpr descr( 45 | char const (&text) [Size1+1], 46 | const std::type_info * const (&types) [Size2+1], 47 | index_sequence, index_sequence) 48 | : m_text{text[Indices1]..., '\0'}, 49 | m_types{types[Indices2]..., nullptr } {} 50 | 51 | template 53 | constexpr descr 54 | concat(const descr &other, 55 | index_sequence, index_sequence, 56 | index_sequence, index_sequence) const { 57 | return descr( 58 | { m_text[Indices1]..., other.m_text[OtherIndices1]..., '\0' }, 59 | { m_types[Indices2]..., other.m_types[OtherIndices2]..., nullptr } 60 | ); 61 | } 62 | 63 | protected: 64 | char m_text[Size1 + 1]; 65 | const std::type_info * m_types[Size2 + 1]; 66 | }; 67 | 68 | template constexpr descr _(char const(&text)[Size]) { 69 | return descr(text, { nullptr }); 70 | } 71 | 72 | template struct int_to_str : int_to_str { }; 73 | template struct int_to_str<0, 
Digits...> { 74 | static constexpr auto digits = descr({ ('0' + Digits)..., '\0' }, { nullptr }); 75 | }; 76 | 77 | // Ternary description (like std::conditional) 78 | template 79 | constexpr enable_if_t> _(char const(&text1)[Size1], char const(&)[Size2]) { 80 | return _(text1); 81 | } 82 | template 83 | constexpr enable_if_t> _(char const(&)[Size1], char const(&text2)[Size2]) { 84 | return _(text2); 85 | } 86 | template 87 | constexpr enable_if_t> _(descr d, descr) { return d; } 88 | template 89 | constexpr enable_if_t> _(descr, descr d) { return d; } 90 | 91 | template auto constexpr _() -> decltype(int_to_str::digits) { 92 | return int_to_str::digits; 93 | } 94 | 95 | template constexpr descr<1, 1> _() { 96 | return descr<1, 1>({ '%', '\0' }, { &typeid(Type), nullptr }); 97 | } 98 | 99 | inline constexpr descr<0, 0> concat() { return _(""); } 100 | template auto constexpr concat(descr descr) { return descr; } 101 | template auto constexpr concat(descr descr, Args&&... args) { return descr + _(", ") + concat(args...); } 102 | template auto constexpr type_descr(descr descr) { return _("{") + descr + _("}"); } 103 | 104 | #define PYBIND11_DESCR constexpr auto 105 | 106 | #else /* Simpler C++11 implementation based on run-time memory allocation and copying */ 107 | 108 | class descr { 109 | public: 110 | PYBIND11_NOINLINE descr(const char *text, const std::type_info * const * types) { 111 | size_t nChars = len(text), nTypes = len(types); 112 | m_text = new char[nChars]; 113 | m_types = new const std::type_info *[nTypes]; 114 | memcpy(m_text, text, nChars * sizeof(char)); 115 | memcpy(m_types, types, nTypes * sizeof(const std::type_info *)); 116 | } 117 | 118 | PYBIND11_NOINLINE descr operator+(descr &&d2) && { 119 | descr r; 120 | 121 | size_t nChars1 = len(m_text), nTypes1 = len(m_types); 122 | size_t nChars2 = len(d2.m_text), nTypes2 = len(d2.m_types); 123 | 124 | r.m_text = new char[nChars1 + nChars2 - 1]; 125 | r.m_types = new const std::type_info *[nTypes1 + 
nTypes2 - 1]; 126 | memcpy(r.m_text, m_text, (nChars1-1) * sizeof(char)); 127 | memcpy(r.m_text + nChars1 - 1, d2.m_text, nChars2 * sizeof(char)); 128 | memcpy(r.m_types, m_types, (nTypes1-1) * sizeof(std::type_info *)); 129 | memcpy(r.m_types + nTypes1 - 1, d2.m_types, nTypes2 * sizeof(std::type_info *)); 130 | 131 | delete[] m_text; delete[] m_types; 132 | delete[] d2.m_text; delete[] d2.m_types; 133 | 134 | return r; 135 | } 136 | 137 | char *text() { return m_text; } 138 | const std::type_info * * types() { return m_types; } 139 | 140 | protected: 141 | PYBIND11_NOINLINE descr() { } 142 | 143 | template static size_t len(const T *ptr) { // return length including null termination 144 | const T *it = ptr; 145 | while (*it++ != (T) 0) 146 | ; 147 | return static_cast(it - ptr); 148 | } 149 | 150 | const std::type_info **m_types = nullptr; 151 | char *m_text = nullptr; 152 | }; 153 | 154 | /* The 'PYBIND11_NOINLINE inline' combinations below are intentional to get the desired linkage while producing as little object code as possible */ 155 | 156 | PYBIND11_NOINLINE inline descr _(const char *text) { 157 | const std::type_info *types[1] = { nullptr }; 158 | return descr(text, types); 159 | } 160 | 161 | template PYBIND11_NOINLINE enable_if_t _(const char *text1, const char *) { return _(text1); } 162 | template PYBIND11_NOINLINE enable_if_t _(char const *, const char *text2) { return _(text2); } 163 | template PYBIND11_NOINLINE enable_if_t _(descr d, descr) { return d; } 164 | template PYBIND11_NOINLINE enable_if_t _(descr, descr d) { return d; } 165 | 166 | template PYBIND11_NOINLINE descr _() { 167 | const std::type_info *types[2] = { &typeid(Type), nullptr }; 168 | return descr("%", types); 169 | } 170 | 171 | template PYBIND11_NOINLINE descr _() { 172 | const std::type_info *types[1] = { nullptr }; 173 | return descr(std::to_string(Size).c_str(), types); 174 | } 175 | 176 | PYBIND11_NOINLINE inline descr concat() { return _(""); } 177 | PYBIND11_NOINLINE inline 
descr concat(descr &&d) { return d; } 178 | template PYBIND11_NOINLINE descr concat(descr &&d, Args&&... args) { return std::move(d) + _(", ") + concat(std::forward(args)...); } 179 | PYBIND11_NOINLINE inline descr type_descr(descr&& d) { return _("{") + std::move(d) + _("}"); } 180 | 181 | #define PYBIND11_DESCR ::pybind11::detail::descr 182 | #endif 183 | 184 | NAMESPACE_END(detail) 185 | NAMESPACE_END(pybind11) 186 | -------------------------------------------------------------------------------- /include/pybind11/detail/descr.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | NAMESPACE_BEGIN(detail) 16 | 17 | #if !defined(_MSC_VER) 18 | # define PYBIND11_DESCR_CONSTEXPR static constexpr 19 | #else 20 | # define PYBIND11_DESCR_CONSTEXPR const 21 | #endif 22 | 23 | /* Concatenate type signatures at compile time */ 24 | template 25 | struct descr { 26 | char text[N + 1]; 27 | 28 | constexpr descr() : text{'\0'} { } 29 | constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence()) { } 30 | 31 | template 32 | constexpr descr(char const (&s)[N+1], index_sequence) : text{s[Is]..., '\0'} { } 33 | 34 | template 35 | constexpr descr(char c, Chars... 
cs) : text{c, static_cast(cs)..., '\0'} { } 36 | 37 | static constexpr std::array types() { 38 | return {{&typeid(Ts)..., nullptr}}; 39 | } 40 | }; 41 | 42 | template 43 | constexpr descr plus_impl(const descr &a, const descr &b, 44 | index_sequence, index_sequence) { 45 | return {a.text[Is1]..., b.text[Is2]...}; 46 | } 47 | 48 | template 49 | constexpr descr operator+(const descr &a, const descr &b) { 50 | return plus_impl(a, b, make_index_sequence(), make_index_sequence()); 51 | } 52 | 53 | template 54 | constexpr descr _(char const(&text)[N]) { return descr(text); } 55 | constexpr descr<0> _(char const(&)[1]) { return {}; } 56 | 57 | template struct int_to_str : int_to_str { }; 58 | template struct int_to_str<0, Digits...> { 59 | static constexpr auto digits = descr(('0' + Digits)...); 60 | }; 61 | 62 | // Ternary description (like std::conditional) 63 | template 64 | constexpr enable_if_t> _(char const(&text1)[N1], char const(&)[N2]) { 65 | return _(text1); 66 | } 67 | template 68 | constexpr enable_if_t> _(char const(&)[N1], char const(&text2)[N2]) { 69 | return _(text2); 70 | } 71 | 72 | template 73 | constexpr enable_if_t _(const T1 &d, const T2 &) { return d; } 74 | template 75 | constexpr enable_if_t _(const T1 &, const T2 &d) { return d; } 76 | 77 | template auto constexpr _() -> decltype(int_to_str::digits) { 78 | return int_to_str::digits; 79 | } 80 | 81 | template constexpr descr<1, Type> _() { return {'%'}; } 82 | 83 | constexpr descr<0> concat() { return {}; } 84 | 85 | template 86 | constexpr descr concat(const descr &descr) { return descr; } 87 | 88 | template 89 | constexpr auto concat(const descr &d, const Args &...args) 90 | -> decltype(std::declval>() + concat(args...)) { 91 | return d + _(", ") + concat(args...); 92 | } 93 | 94 | template 95 | constexpr descr type_descr(const descr &descr) { 96 | return _("{") + descr + _("}"); 97 | } 98 | 99 | NAMESPACE_END(detail) 100 | NAMESPACE_END(PYBIND11_NAMESPACE) 101 | 
-------------------------------------------------------------------------------- /include/pybind11/detail/init.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/detail/init.h: init factory function implementation and support code. 3 | 4 | Copyright (c) 2017 Jason Rhinelander 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "class.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | NAMESPACE_BEGIN(detail) 16 | 17 | template <> 18 | class type_caster { 19 | public: 20 | bool load(handle h, bool) { 21 | value = reinterpret_cast(h.ptr()); 22 | return true; 23 | } 24 | 25 | template using cast_op_type = value_and_holder &; 26 | operator value_and_holder &() { return *value; } 27 | static constexpr auto name = _(); 28 | 29 | private: 30 | value_and_holder *value = nullptr; 31 | }; 32 | 33 | NAMESPACE_BEGIN(initimpl) 34 | 35 | inline void no_nullptr(void *ptr) { 36 | if (!ptr) throw type_error("pybind11::init(): factory function returned nullptr"); 37 | } 38 | 39 | // Implementing functions for all forms of py::init<...> and py::init(...) 40 | template using Cpp = typename Class::type; 41 | template using Alias = typename Class::type_alias; 42 | template using Holder = typename Class::holder_type; 43 | 44 | template using is_alias_constructible = std::is_constructible, Cpp &&>; 45 | 46 | // Takes a Cpp pointer and returns true if it actually is a polymorphic Alias instance. 
47 | template = 0> 48 | bool is_alias(Cpp *ptr) { 49 | return dynamic_cast *>(ptr) != nullptr; 50 | } 51 | // Failing fallback version of the above for a no-alias class (always returns false) 52 | template 53 | constexpr bool is_alias(void *) { return false; } 54 | 55 | // Constructs and returns a new object; if the given arguments don't map to a constructor, we fall 56 | // back to brace aggregate initiailization so that for aggregate initialization can be used with 57 | // py::init, e.g. `py::init` to initialize a `struct T { int a; int b; }`. For 58 | // non-aggregate types, we need to use an ordinary T(...) constructor (invoking as `T{...}` usually 59 | // works, but will not do the expected thing when `T` has an `initializer_list` constructor). 60 | template ::value, int> = 0> 61 | inline Class *construct_or_initialize(Args &&...args) { return new Class(std::forward(args)...); } 62 | template ::value, int> = 0> 63 | inline Class *construct_or_initialize(Args &&...args) { return new Class{std::forward(args)...}; } 64 | 65 | // Attempts to constructs an alias using a `Alias(Cpp &&)` constructor. This allows types with 66 | // an alias to provide only a single Cpp factory function as long as the Alias can be 67 | // constructed from an rvalue reference of the base Cpp type. This means that Alias classes 68 | // can, when appropriate, simply define a `Alias(Cpp &&)` constructor rather than needing to 69 | // inherit all the base class constructors. 
70 | template 71 | void construct_alias_from_cpp(std::true_type /*is_alias_constructible*/, 72 | value_and_holder &v_h, Cpp &&base) { 73 | v_h.value_ptr() = new Alias(std::move(base)); 74 | } 75 | template 76 | [[noreturn]] void construct_alias_from_cpp(std::false_type /*!is_alias_constructible*/, 77 | value_and_holder &, Cpp &&) { 78 | throw type_error("pybind11::init(): unable to convert returned instance to required " 79 | "alias class: no `Alias(Class &&)` constructor available"); 80 | } 81 | 82 | // Error-generating fallback for factories that don't match one of the below construction 83 | // mechanisms. 84 | template 85 | void construct(...) { 86 | static_assert(!std::is_same::value /* always false */, 87 | "pybind11::init(): init function must return a compatible pointer, " 88 | "holder, or value"); 89 | } 90 | 91 | // Pointer return v1: the factory function returns a class pointer for a registered class. 92 | // If we don't need an alias (because this class doesn't have one, or because the final type is 93 | // inherited on the Python side) we can simply take over ownership. Otherwise we need to try to 94 | // construct an Alias from the returned base instance. 95 | template 96 | void construct(value_and_holder &v_h, Cpp *ptr, bool need_alias) { 97 | no_nullptr(ptr); 98 | if (Class::has_alias && need_alias && !is_alias(ptr)) { 99 | // We're going to try to construct an alias by moving the cpp type. Whether or not 100 | // that succeeds, we still need to destroy the original cpp pointer (either the 101 | // moved away leftover, if the alias construction works, or the value itself if we 102 | // throw an error), but we can't just call `delete ptr`: it might have a special 103 | // deleter, or might be shared_from_this. 
So we construct a holder around it as if 104 | // it was a normal instance, then steal the holder away into a local variable; thus 105 | // the holder and destruction happens when we leave the C++ scope, and the holder 106 | // class gets to handle the destruction however it likes. 107 | v_h.value_ptr() = ptr; 108 | v_h.set_instance_registered(true); // To prevent init_instance from registering it 109 | v_h.type->init_instance(v_h.inst, nullptr); // Set up the holder 110 | Holder temp_holder(std::move(v_h.holder>())); // Steal the holder 111 | v_h.type->dealloc(v_h); // Destroys the moved-out holder remains, resets value ptr to null 112 | v_h.set_instance_registered(false); 113 | 114 | construct_alias_from_cpp(is_alias_constructible{}, v_h, std::move(*ptr)); 115 | } else { 116 | // Otherwise the type isn't inherited, so we don't need an Alias 117 | v_h.value_ptr() = ptr; 118 | } 119 | } 120 | 121 | // Pointer return v2: a factory that always returns an alias instance ptr. We simply take over 122 | // ownership of the pointer. 123 | template = 0> 124 | void construct(value_and_holder &v_h, Alias *alias_ptr, bool) { 125 | no_nullptr(alias_ptr); 126 | v_h.value_ptr() = static_cast *>(alias_ptr); 127 | } 128 | 129 | // Holder return: copy its pointer, and move or copy the returned holder into the new instance's 130 | // holder. This also handles types like std::shared_ptr and std::unique_ptr where T is a 131 | // derived type (through those holder's implicit conversion from derived class holder constructors). 
132 | template 133 | void construct(value_and_holder &v_h, Holder holder, bool need_alias) { 134 | auto *ptr = holder_helper>::get(holder); 135 | // If we need an alias, check that the held pointer is actually an alias instance 136 | if (Class::has_alias && need_alias && !is_alias(ptr)) 137 | throw type_error("pybind11::init(): construction failed: returned holder-wrapped instance " 138 | "is not an alias instance"); 139 | 140 | v_h.value_ptr() = ptr; 141 | v_h.type->init_instance(v_h.inst, &holder); 142 | } 143 | 144 | // return-by-value version 1: returning a cpp class by value. If the class has an alias and an 145 | // alias is required the alias must have an `Alias(Cpp &&)` constructor so that we can construct 146 | // the alias from the base when needed (i.e. because of Python-side inheritance). When we don't 147 | // need it, we simply move-construct the cpp value into a new instance. 148 | template 149 | void construct(value_and_holder &v_h, Cpp &&result, bool need_alias) { 150 | static_assert(std::is_move_constructible>::value, 151 | "pybind11::init() return-by-value factory function requires a movable class"); 152 | if (Class::has_alias && need_alias) 153 | construct_alias_from_cpp(is_alias_constructible{}, v_h, std::move(result)); 154 | else 155 | v_h.value_ptr() = new Cpp(std::move(result)); 156 | } 157 | 158 | // return-by-value version 2: returning a value of the alias type itself. We move-construct an 159 | // Alias instance (even if no the python-side inheritance is involved). The is intended for 160 | // cases where Alias initialization is always desired. 
161 | template 162 | void construct(value_and_holder &v_h, Alias &&result, bool) { 163 | static_assert(std::is_move_constructible>::value, 164 | "pybind11::init() return-by-alias-value factory function requires a movable alias class"); 165 | v_h.value_ptr() = new Alias(std::move(result)); 166 | } 167 | 168 | // Implementing class for py::init<...>() 169 | template 170 | struct constructor { 171 | template = 0> 172 | static void execute(Class &cl, const Extra&... extra) { 173 | cl.def("__init__", [](value_and_holder &v_h, Args... args) { 174 | v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); 175 | }, is_new_style_constructor(), extra...); 176 | } 177 | 178 | template , Args...>::value, int> = 0> 181 | static void execute(Class &cl, const Extra&... extra) { 182 | cl.def("__init__", [](value_and_holder &v_h, Args... args) { 183 | if (Py_TYPE(v_h.inst) == v_h.type->type) 184 | v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); 185 | else 186 | v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); 187 | }, is_new_style_constructor(), extra...); 188 | } 189 | 190 | template , Args...>::value, int> = 0> 193 | static void execute(Class &cl, const Extra&... extra) { 194 | cl.def("__init__", [](value_and_holder &v_h, Args... args) { 195 | v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); 196 | }, is_new_style_constructor(), extra...); 197 | } 198 | }; 199 | 200 | // Implementing class for py::init_alias<...>() 201 | template struct alias_constructor { 202 | template , Args...>::value, int> = 0> 204 | static void execute(Class &cl, const Extra&... extra) { 205 | cl.def("__init__", [](value_and_holder &v_h, Args... 
args) { 206 | v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); 207 | }, is_new_style_constructor(), extra...); 208 | } 209 | }; 210 | 211 | // Implementation class for py::init(Func) and py::init(Func, AliasFunc) 212 | template , typename = function_signature_t> 214 | struct factory; 215 | 216 | // Specialization for py::init(Func) 217 | template 218 | struct factory { 219 | remove_reference_t class_factory; 220 | 221 | factory(Func &&f) : class_factory(std::forward(f)) { } 222 | 223 | // The given class either has no alias or has no separate alias factory; 224 | // this always constructs the class itself. If the class is registered with an alias 225 | // type and an alias instance is needed (i.e. because the final type is a Python class 226 | // inheriting from the C++ type) the returned value needs to either already be an alias 227 | // instance, or the alias needs to be constructible from a `Class &&` argument. 228 | template 229 | void execute(Class &cl, const Extra &...extra) && { 230 | #if defined(PYBIND11_CPP14) 231 | cl.def("__init__", [func = std::move(class_factory)] 232 | #else 233 | auto &func = class_factory; 234 | cl.def("__init__", [func] 235 | #endif 236 | (value_and_holder &v_h, Args... 
args) { 237 | construct(v_h, func(std::forward(args)...), 238 | Py_TYPE(v_h.inst) != v_h.type->type); 239 | }, is_new_style_constructor(), extra...); 240 | } 241 | }; 242 | 243 | // Specialization for py::init(Func, AliasFunc) 244 | template 246 | struct factory { 247 | static_assert(sizeof...(CArgs) == sizeof...(AArgs), 248 | "pybind11::init(class_factory, alias_factory): class and alias factories " 249 | "must have identical argument signatures"); 250 | static_assert(all_of...>::value, 251 | "pybind11::init(class_factory, alias_factory): class and alias factories " 252 | "must have identical argument signatures"); 253 | 254 | remove_reference_t class_factory; 255 | remove_reference_t alias_factory; 256 | 257 | factory(CFunc &&c, AFunc &&a) 258 | : class_factory(std::forward(c)), alias_factory(std::forward(a)) { } 259 | 260 | // The class factory is called when the `self` type passed to `__init__` is the direct 261 | // class (i.e. not inherited), the alias factory when `self` is a Python-side subtype. 262 | template 263 | void execute(Class &cl, const Extra&... extra) && { 264 | static_assert(Class::has_alias, "The two-argument version of `py::init()` can " 265 | "only be used if the class has an alias"); 266 | #if defined(PYBIND11_CPP14) 267 | cl.def("__init__", [class_func = std::move(class_factory), alias_func = std::move(alias_factory)] 268 | #else 269 | auto &class_func = class_factory; 270 | auto &alias_func = alias_factory; 271 | cl.def("__init__", [class_func, alias_func] 272 | #endif 273 | (value_and_holder &v_h, CArgs... 
args) { 274 | if (Py_TYPE(v_h.inst) == v_h.type->type) 275 | // If the instance type equals the registered type we don't have inheritance, so 276 | // don't need the alias and can construct using the class function: 277 | construct(v_h, class_func(std::forward(args)...), false); 278 | else 279 | construct(v_h, alias_func(std::forward(args)...), true); 280 | }, is_new_style_constructor(), extra...); 281 | } 282 | }; 283 | 284 | /// Set just the C++ state. Same as `__init__`. 285 | template 286 | void setstate(value_and_holder &v_h, T &&result, bool need_alias) { 287 | construct(v_h, std::forward(result), need_alias); 288 | } 289 | 290 | /// Set both the C++ and Python states 291 | template ::value, int> = 0> 293 | void setstate(value_and_holder &v_h, std::pair &&result, bool need_alias) { 294 | construct(v_h, std::move(result.first), need_alias); 295 | setattr((PyObject *) v_h.inst, "__dict__", result.second); 296 | } 297 | 298 | /// Implementation for py::pickle(GetState, SetState) 299 | template , typename = function_signature_t> 301 | struct pickle_factory; 302 | 303 | template 305 | struct pickle_factory { 306 | static_assert(std::is_same, intrinsic_t>::value, 307 | "The type returned by `__getstate__` must be the same " 308 | "as the argument accepted by `__setstate__`"); 309 | 310 | remove_reference_t get; 311 | remove_reference_t set; 312 | 313 | pickle_factory(Get get, Set set) 314 | : get(std::forward(get)), set(std::forward(set)) { } 315 | 316 | template 317 | void execute(Class &cl, const Extra &...extra) && { 318 | cl.def("__getstate__", std::move(get)); 319 | 320 | #if defined(PYBIND11_CPP14) 321 | cl.def("__setstate__", [func = std::move(set)] 322 | #else 323 | auto &func = set; 324 | cl.def("__setstate__", [func] 325 | #endif 326 | (value_and_holder &v_h, ArgState state) { 327 | setstate(v_h, func(std::forward(state)), 328 | Py_TYPE(v_h.inst) != v_h.type->type); 329 | }, is_new_style_constructor(), extra...); 330 | } 331 | }; 332 | 333 | 
NAMESPACE_END(initimpl) 334 | NAMESPACE_END(detail) 335 | NAMESPACE_END(pybind11) 336 | -------------------------------------------------------------------------------- /include/pybind11/detail/internals.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/detail/internals.h: Internal data structure and related functions 3 | 4 | Copyright (c) 2017 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "../pytypes.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | NAMESPACE_BEGIN(detail) 16 | // Forward declarations 17 | inline PyTypeObject *make_static_property_type(); 18 | inline PyTypeObject *make_default_metaclass(); 19 | inline PyObject *make_object_base_type(PyTypeObject *metaclass); 20 | 21 | // The old Python Thread Local Storage (TLS) API is deprecated in Python 3.7 in favor of the new 22 | // Thread Specific Storage (TSS) API. 
// Thin wrappers over the thread-local-storage API of the running Python version:
//   PYBIND11_TLS_KEY_INIT(var)             declares and zero-initializes a key variable
//   PYBIND11_TLS_GET_VALUE(key)            reads the calling thread's value for `key`
//   PYBIND11_TLS_REPLACE_VALUE(key, value) stores `value` for `key`, replacing any old value
//   PYBIND11_TLS_DELETE_VALUE(key)         clears the calling thread's value for `key`
#if PY_VERSION_HEX >= 0x03070000
#    define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr
#    define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key))
     // Must expand to the macro's own `value` argument; the previous `(tstate)` only
     // compiled (and only did the right thing) when the caller's variable was named tstate.
#    define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (value))
#    define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr)
#else
    // Usually an int but a long on Cygwin64 with Python 3.x
#    define PYBIND11_TLS_KEY_INIT(var) decltype(PyThread_create_key()) var = 0
#    define PYBIND11_TLS_GET_VALUE(key) PyThread_get_key_value((key))
#    if PY_MAJOR_VERSION < 3
#        define PYBIND11_TLS_DELETE_VALUE(key)                               \
             PyThread_delete_key_value(key)
         // Here the existing value is deleted before the new one is set.
#        define PYBIND11_TLS_REPLACE_VALUE(key, value)                       \
             do {                                                            \
                 PyThread_delete_key_value((key));                           \
                 PyThread_set_key_value((key), (value));                     \
             } while (false)
#    else
#        define PYBIND11_TLS_DELETE_VALUE(key)                               \
             PyThread_set_key_value((key), nullptr)
#        define PYBIND11_TLS_REPLACE_VALUE(key, value)                       \
             PyThread_set_key_value((key), (value))
#    endif
#endif

// Python loads modules by default with dlopen with the RTLD_LOCAL flag; under libc++ and possibly
// other STLs, this means `typeid(A)` from one module won't equal `typeid(A)` from another module
// even when `A` is the same, non-hidden-visibility type (e.g. from a common include).  Under
// libstdc++, this doesn't happen: equality and the type_index hash are based on the type name,
// which works.  If not under a known-good stl, provide our own name-based hash and equality
// functions that use the type name.
54 | #if defined(__GLIBCXX__) 55 | inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { return lhs == rhs; } 56 | using type_hash = std::hash; 57 | using type_equal_to = std::equal_to; 58 | #else 59 | inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { 60 | return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0; 61 | } 62 | 63 | struct type_hash { 64 | size_t operator()(const std::type_index &t) const { 65 | size_t hash = 5381; 66 | const char *ptr = t.name(); 67 | while (auto c = static_cast(*ptr++)) 68 | hash = (hash * 33) ^ c; 69 | return hash; 70 | } 71 | }; 72 | 73 | struct type_equal_to { 74 | bool operator()(const std::type_index &lhs, const std::type_index &rhs) const { 75 | return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0; 76 | } 77 | }; 78 | #endif 79 | 80 | template 81 | using type_map = std::unordered_map; 82 | 83 | struct overload_hash { 84 | inline size_t operator()(const std::pair& v) const { 85 | size_t value = std::hash()(v.first); 86 | value ^= std::hash()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2); 87 | return value; 88 | } 89 | }; 90 | 91 | /// Internal data structure used to track registered instances and types. 92 | /// Whenever binary incompatible changes are made to this structure, 93 | /// `PYBIND11_INTERNALS_VERSION` must be incremented. 
94 | struct internals { 95 | type_map registered_types_cpp; // std::type_index -> pybind11's type information 96 | std::unordered_map> registered_types_py; // PyTypeObject* -> base type_info(s) 97 | std::unordered_multimap registered_instances; // void * -> instance* 98 | std::unordered_set, overload_hash> inactive_overload_cache; 99 | type_map> direct_conversions; 100 | std::unordered_map> patients; 101 | std::forward_list registered_exception_translators; 102 | std::unordered_map shared_data; // Custom data to be shared across extensions 103 | std::vector loader_patient_stack; // Used by `loader_life_support` 104 | std::forward_list static_strings; // Stores the std::strings backing detail::c_str() 105 | PyTypeObject *static_property_type; 106 | PyTypeObject *default_metaclass; 107 | PyObject *instance_base; 108 | #if defined(WITH_THREAD) 109 | PYBIND11_TLS_KEY_INIT(tstate); 110 | PyInterpreterState *istate = nullptr; 111 | #endif 112 | }; 113 | 114 | /// Additional type information which does not fit into the PyTypeObject. 115 | /// Changes to this struct also require bumping `PYBIND11_INTERNALS_VERSION`. 
116 | struct type_info { 117 | PyTypeObject *type; 118 | const std::type_info *cpptype; 119 | size_t type_size, type_align, holder_size_in_ptrs; 120 | void *(*operator_new)(size_t); 121 | void (*init_instance)(instance *, const void *); 122 | void (*dealloc)(value_and_holder &v_h); 123 | std::vector implicit_conversions; 124 | std::vector> implicit_casts; 125 | std::vector *direct_conversions; 126 | buffer_info *(*get_buffer)(PyObject *, void *) = nullptr; 127 | void *get_buffer_data = nullptr; 128 | void *(*module_local_load)(PyObject *, const type_info *) = nullptr; 129 | /* A simple type never occurs as a (direct or indirect) parent 130 | * of a class that makes use of multiple inheritance */ 131 | bool simple_type : 1; 132 | /* True if there is no multiple inheritance in this type's inheritance tree */ 133 | bool simple_ancestors : 1; 134 | /* for base vs derived holder_type checks */ 135 | bool default_holder : 1; 136 | /* true if this is a type registered with py::module_local */ 137 | bool module_local : 1; 138 | }; 139 | 140 | /// Tracks the `internals` and `type_info` ABI version independent of the main library version 141 | #define PYBIND11_INTERNALS_VERSION 3 142 | 143 | #if defined(_DEBUG) 144 | # define PYBIND11_BUILD_TYPE "_debug" 145 | #else 146 | # define PYBIND11_BUILD_TYPE "" 147 | #endif 148 | 149 | #if defined(WITH_THREAD) 150 | # define PYBIND11_INTERNALS_KIND "" 151 | #else 152 | # define PYBIND11_INTERNALS_KIND "_without_thread" 153 | #endif 154 | 155 | #define PYBIND11_INTERNALS_ID "__pybind11_internals_v" \ 156 | PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__" 157 | 158 | #define PYBIND11_MODULE_LOCAL_ID "__pybind11_module_local_v" \ 159 | PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_BUILD_TYPE "__" 160 | 161 | /// Each module locally stores a pointer to the `internals` data. 
The data 162 | /// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`. 163 | inline internals **&get_internals_pp() { 164 | static internals **internals_pp = nullptr; 165 | return internals_pp; 166 | } 167 | 168 | /// Return a reference to the current `internals` data 169 | PYBIND11_NOINLINE inline internals &get_internals() { 170 | auto **&internals_pp = get_internals_pp(); 171 | if (internals_pp && *internals_pp) 172 | return **internals_pp; 173 | 174 | constexpr auto *id = PYBIND11_INTERNALS_ID; 175 | auto builtins = handle(PyEval_GetBuiltins()); 176 | if (builtins.contains(id) && isinstance(builtins[id])) { 177 | internals_pp = static_cast(capsule(builtins[id])); 178 | 179 | // We loaded builtins through python's builtins, which means that our `error_already_set` 180 | // and `builtin_exception` may be different local classes than the ones set up in the 181 | // initial exception translator, below, so add another for our local exception classes. 182 | // 183 | // libstdc++ doesn't require this (types there are identified only by name) 184 | #if !defined(__GLIBCXX__) 185 | (*internals_pp)->registered_exception_translators.push_front( 186 | [](std::exception_ptr p) -> void { 187 | try { 188 | if (p) std::rethrow_exception(p); 189 | } catch (error_already_set &e) { e.restore(); return; 190 | } catch (const builtin_exception &e) { e.set_error(); return; 191 | } 192 | } 193 | ); 194 | #endif 195 | } else { 196 | if (!internals_pp) internals_pp = new internals*(); 197 | auto *&internals_ptr = *internals_pp; 198 | internals_ptr = new internals(); 199 | #if defined(WITH_THREAD) 200 | PyEval_InitThreads(); 201 | PyThreadState *tstate = PyThreadState_Get(); 202 | #if PY_VERSION_HEX >= 0x03070000 203 | internals_ptr->tstate = PyThread_tss_alloc(); 204 | if (!internals_ptr->tstate || PyThread_tss_create(internals_ptr->tstate)) 205 | pybind11_fail("get_internals: could not successfully initialize the TSS key!"); 206 | 
PyThread_tss_set(internals_ptr->tstate, tstate); 207 | #else 208 | internals_ptr->tstate = PyThread_create_key(); 209 | if (internals_ptr->tstate == -1) 210 | pybind11_fail("get_internals: could not successfully initialize the TLS key!"); 211 | PyThread_set_key_value(internals_ptr->tstate, tstate); 212 | #endif 213 | internals_ptr->istate = tstate->interp; 214 | #endif 215 | builtins[id] = capsule(internals_pp); 216 | internals_ptr->registered_exception_translators.push_front( 217 | [](std::exception_ptr p) -> void { 218 | try { 219 | if (p) std::rethrow_exception(p); 220 | } catch (error_already_set &e) { e.restore(); return; 221 | } catch (const builtin_exception &e) { e.set_error(); return; 222 | } catch (const std::bad_alloc &e) { PyErr_SetString(PyExc_MemoryError, e.what()); return; 223 | } catch (const std::domain_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; 224 | } catch (const std::invalid_argument &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; 225 | } catch (const std::length_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; 226 | } catch (const std::out_of_range &e) { PyErr_SetString(PyExc_IndexError, e.what()); return; 227 | } catch (const std::range_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; 228 | } catch (const std::exception &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); return; 229 | } catch (...) 
{ 230 | PyErr_SetString(PyExc_RuntimeError, "Caught an unknown exception!"); 231 | return; 232 | } 233 | } 234 | ); 235 | internals_ptr->static_property_type = make_static_property_type(); 236 | internals_ptr->default_metaclass = make_default_metaclass(); 237 | internals_ptr->instance_base = make_object_base_type(internals_ptr->default_metaclass); 238 | } 239 | return **internals_pp; 240 | } 241 | 242 | /// Works like `internals.registered_types_cpp`, but for module-local registered types: 243 | inline type_map ®istered_local_types_cpp() { 244 | static type_map locals{}; 245 | return locals; 246 | } 247 | 248 | /// Constructs a std::string with the given arguments, stores it in `internals`, and returns its 249 | /// `c_str()`. Such strings objects have a long storage duration -- the internal strings are only 250 | /// cleared when the program exits or after interpreter shutdown (when embedding), and so are 251 | /// suitable for c-style strings needed by Python internals (such as PyTypeObject's tp_name). 252 | template 253 | const char *c_str(Args &&...args) { 254 | auto &strings = get_internals().static_strings; 255 | strings.emplace_front(std::forward(args)...); 256 | return strings.front().c_str(); 257 | } 258 | 259 | NAMESPACE_END(detail) 260 | 261 | /// Returns a named pointer that is shared among all extension modules (using the same 262 | /// pybind11 version) running in the current interpreter. Names starting with underscores 263 | /// are reserved for internal usage. Returns `nullptr` if no matching entry was found. 264 | inline PYBIND11_NOINLINE void *get_shared_data(const std::string &name) { 265 | auto &internals = detail::get_internals(); 266 | auto it = internals.shared_data.find(name); 267 | return it != internals.shared_data.end() ? it->second : nullptr; 268 | } 269 | 270 | /// Set the shared data that can be later recovered by `get_shared_data()`. 
271 | inline PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) { 272 | detail::get_internals().shared_data[name] = data; 273 | return data; 274 | } 275 | 276 | /// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if 277 | /// such entry exists. Otherwise, a new object of default-constructible type `T` is 278 | /// added to the shared data under the given name and a reference to it is returned. 279 | template 280 | T &get_or_create_shared_data(const std::string &name) { 281 | auto &internals = detail::get_internals(); 282 | auto it = internals.shared_data.find(name); 283 | T *ptr = (T *) (it != internals.shared_data.end() ? it->second : nullptr); 284 | if (!ptr) { 285 | ptr = new T(); 286 | internals.shared_data[name] = ptr; 287 | } 288 | return *ptr; 289 | } 290 | 291 | NAMESPACE_END(PYBIND11_NAMESPACE) 292 | -------------------------------------------------------------------------------- /include/pybind11/detail/typeid.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/detail/typeid.h: Compiler-independent access to type identifiers 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include 13 | #include 14 | 15 | #if defined(__GNUG__) 16 | #include 17 | #endif 18 | 19 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 20 | NAMESPACE_BEGIN(detail) 21 | /// Erase all occurrences of a substring 22 | inline void erase_all(std::string &string, const std::string &search) { 23 | for (size_t pos = 0;;) { 24 | pos = string.find(search, pos); 25 | if (pos == std::string::npos) break; 26 | string.erase(pos, search.length()); 27 | } 28 | } 29 | 30 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) { 31 | #if defined(__GNUG__) 32 | int status = 0; 33 | std::unique_ptr res { 34 | abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free }; 35 | if (status == 0) 36 | name = res.get(); 37 | #else 38 | detail::erase_all(name, "class "); 39 | detail::erase_all(name, "struct "); 40 | detail::erase_all(name, "enum "); 41 | #endif 42 | detail::erase_all(name, "pybind11::"); 43 | } 44 | NAMESPACE_END(detail) 45 | 46 | /// Return a string representation of a C++ type 47 | template static std::string type_id() { 48 | std::string name(typeid(T).name()); 49 | detail::clean_type_id(name); 50 | return name; 51 | } 52 | 53 | NAMESPACE_END(PYBIND11_NAMESPACE) 54 | -------------------------------------------------------------------------------- /include/pybind11/embed.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/embed.h: Support for embedding the interpreter 3 | 4 | Copyright (c) 2017 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include "eval.h" 14 | 15 | #if defined(PYPY_VERSION) 16 | # error Embedding the interpreter is not supported with PyPy 17 | #endif 18 | 19 | #if PY_MAJOR_VERSION >= 3 20 | # define PYBIND11_EMBEDDED_MODULE_IMPL(name) \ 21 | extern "C" PyObject *pybind11_init_impl_##name() { \ 22 | return pybind11_init_wrapper_##name(); \ 23 | } 24 | #else 25 | # define PYBIND11_EMBEDDED_MODULE_IMPL(name) \ 26 | extern "C" void pybind11_init_impl_##name() { \ 27 | pybind11_init_wrapper_##name(); \ 28 | } 29 | #endif 30 | 31 | /** \rst 32 | Add a new module to the table of builtins for the interpreter. Must be 33 | defined in global scope. The first macro parameter is the name of the 34 | module (without quotes). The second parameter is the variable which will 35 | be used as the interface to add functions and classes to the module. 36 | 37 | .. code-block:: cpp 38 | 39 | PYBIND11_EMBEDDED_MODULE(example, m) { 40 | // ... initialize functions and classes here 41 | m.def("foo", []() { 42 | return "Hello, World!"; 43 | }); 44 | } 45 | \endrst */ 46 | #define PYBIND11_EMBEDDED_MODULE(name, variable) \ 47 | static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \ 48 | static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \ 49 | auto m = pybind11::module(PYBIND11_TOSTRING(name)); \ 50 | try { \ 51 | PYBIND11_CONCAT(pybind11_init_, name)(m); \ 52 | return m.ptr(); \ 53 | } catch (pybind11::error_already_set &e) { \ 54 | PyErr_SetString(PyExc_ImportError, e.what()); \ 55 | return nullptr; \ 56 | } catch (const std::exception &e) { \ 57 | PyErr_SetString(PyExc_ImportError, e.what()); \ 58 | return nullptr; \ 59 | } \ 60 | } \ 61 | PYBIND11_EMBEDDED_MODULE_IMPL(name) \ 62 | pybind11::detail::embedded_module name(PYBIND11_TOSTRING(name), \ 63 | PYBIND11_CONCAT(pybind11_init_impl_, name)); \ 64 | void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable) 65 | 66 | 67 | 
NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 68 | NAMESPACE_BEGIN(detail) 69 | 70 | /// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks. 71 | struct embedded_module { 72 | #if PY_MAJOR_VERSION >= 3 73 | using init_t = PyObject *(*)(); 74 | #else 75 | using init_t = void (*)(); 76 | #endif 77 | embedded_module(const char *name, init_t init) { 78 | if (Py_IsInitialized()) 79 | pybind11_fail("Can't add new modules after the interpreter has been initialized"); 80 | 81 | auto result = PyImport_AppendInittab(name, init); 82 | if (result == -1) 83 | pybind11_fail("Insufficient memory to add a new module"); 84 | } 85 | }; 86 | 87 | NAMESPACE_END(detail) 88 | 89 | /** \rst 90 | Initialize the Python interpreter. No other pybind11 or CPython API functions can be 91 | called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. The 92 | optional parameter can be used to skip the registration of signal handlers (see the 93 | `Python documentation`_ for details). Calling this function again after the interpreter 94 | has already been initialized is a fatal error. 95 | 96 | If initializing the Python interpreter fails, then the program is terminated. (This 97 | is controlled by the CPython runtime and is an exception to pybind11's normal behavior 98 | of throwing exceptions on errors.) 99 | 100 | .. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx 101 | \endrst */ 102 | inline void initialize_interpreter(bool init_signal_handlers = true) { 103 | if (Py_IsInitialized()) 104 | pybind11_fail("The interpreter is already running"); 105 | 106 | Py_InitializeEx(init_signal_handlers ? 1 : 0); 107 | 108 | // Make .py files in the working directory available by default 109 | module::import("sys").attr("path").cast().append("."); 110 | } 111 | 112 | /** \rst 113 | Shut down the Python interpreter. No pybind11 or CPython API functions can be called 114 | after this. 
In addition, pybind11 objects must not outlive the interpreter: 115 | 116 | .. code-block:: cpp 117 | 118 | { // BAD 119 | py::initialize_interpreter(); 120 | auto hello = py::str("Hello, World!"); 121 | py::finalize_interpreter(); 122 | } // <-- BOOM, hello's destructor is called after interpreter shutdown 123 | 124 | { // GOOD 125 | py::initialize_interpreter(); 126 | { // scoped 127 | auto hello = py::str("Hello, World!"); 128 | } // <-- OK, hello is cleaned up properly 129 | py::finalize_interpreter(); 130 | } 131 | 132 | { // BETTER 133 | py::scoped_interpreter guard{}; 134 | auto hello = py::str("Hello, World!"); 135 | } 136 | 137 | .. warning:: 138 | 139 | The interpreter can be restarted by calling `initialize_interpreter` again. 140 | Modules created using pybind11 can be safely re-initialized. However, Python 141 | itself cannot completely unload binary extension modules and there are several 142 | caveats with regard to interpreter restarting. All the details can be found 143 | in the CPython documentation. In short, not all interpreter memory may be 144 | freed, either due to reference cycles or user-created global data. 145 | 146 | \endrst */ 147 | inline void finalize_interpreter() { 148 | handle builtins(PyEval_GetBuiltins()); 149 | const char *id = PYBIND11_INTERNALS_ID; 150 | 151 | // Get the internals pointer (without creating it if it doesn't exist). It's possible for the 152 | // internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()` 153 | // during destruction), so we get the pointer-pointer here and check it after Py_Finalize(). 
154 | detail::internals **internals_ptr_ptr = detail::get_internals_pp(); 155 | // It could also be stashed in builtins, so look there too: 156 | if (builtins.contains(id) && isinstance(builtins[id])) 157 | internals_ptr_ptr = capsule(builtins[id]); 158 | 159 | Py_Finalize(); 160 | 161 | if (internals_ptr_ptr) { 162 | delete *internals_ptr_ptr; 163 | *internals_ptr_ptr = nullptr; 164 | } 165 | } 166 | 167 | /** \rst 168 | Scope guard version of `initialize_interpreter` and `finalize_interpreter`. 169 | This a move-only guard and only a single instance can exist. 170 | 171 | .. code-block:: cpp 172 | 173 | #include 174 | 175 | int main() { 176 | py::scoped_interpreter guard{}; 177 | py::print(Hello, World!); 178 | } // <-- interpreter shutdown 179 | \endrst */ 180 | class scoped_interpreter { 181 | public: 182 | scoped_interpreter(bool init_signal_handlers = true) { 183 | initialize_interpreter(init_signal_handlers); 184 | } 185 | 186 | scoped_interpreter(const scoped_interpreter &) = delete; 187 | scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; } 188 | scoped_interpreter &operator=(const scoped_interpreter &) = delete; 189 | scoped_interpreter &operator=(scoped_interpreter &&) = delete; 190 | 191 | ~scoped_interpreter() { 192 | if (is_valid) 193 | finalize_interpreter(); 194 | } 195 | 196 | private: 197 | bool is_valid = true; 198 | }; 199 | 200 | NAMESPACE_END(PYBIND11_NAMESPACE) 201 | -------------------------------------------------------------------------------- /include/pybind11/eval.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/exec.h: Support for evaluating Python expressions and statements 3 | from strings and files 4 | 5 | Copyright (c) 2016 Klemens Morgenstern and 6 | Wenzel Jakob 7 | 8 | All rights reserved. Use of this source code is governed by a 9 | BSD-style license that can be found in the LICENSE file. 
10 | */ 11 | 12 | #pragma once 13 | 14 | #include "pybind11.h" 15 | 16 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 17 | 18 | enum eval_mode { 19 | /// Evaluate a string containing an isolated expression 20 | eval_expr, 21 | 22 | /// Evaluate a string containing a single statement. Returns \c none 23 | eval_single_statement, 24 | 25 | /// Evaluate a string containing a sequence of statement. Returns \c none 26 | eval_statements 27 | }; 28 | 29 | template 30 | object eval(str expr, object global = globals(), object local = object()) { 31 | if (!local) 32 | local = global; 33 | 34 | /* PyRun_String does not accept a PyObject / encoding specifier, 35 | this seems to be the only alternative */ 36 | std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr; 37 | 38 | int start; 39 | switch (mode) { 40 | case eval_expr: start = Py_eval_input; break; 41 | case eval_single_statement: start = Py_single_input; break; 42 | case eval_statements: start = Py_file_input; break; 43 | default: pybind11_fail("invalid evaluation mode"); 44 | } 45 | 46 | PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr()); 47 | if (!result) 48 | throw error_already_set(); 49 | return reinterpret_steal(result); 50 | } 51 | 52 | template 53 | object eval(const char (&s)[N], object global = globals(), object local = object()) { 54 | /* Support raw string literals by removing common leading whitespace */ 55 | auto expr = (s[0] == '\n') ? 
str(module::import("textwrap").attr("dedent")(s)) 56 | : str(s); 57 | return eval(expr, global, local); 58 | } 59 | 60 | inline void exec(str expr, object global = globals(), object local = object()) { 61 | eval(expr, global, local); 62 | } 63 | 64 | template 65 | void exec(const char (&s)[N], object global = globals(), object local = object()) { 66 | eval(s, global, local); 67 | } 68 | 69 | template 70 | object eval_file(str fname, object global = globals(), object local = object()) { 71 | if (!local) 72 | local = global; 73 | 74 | int start; 75 | switch (mode) { 76 | case eval_expr: start = Py_eval_input; break; 77 | case eval_single_statement: start = Py_single_input; break; 78 | case eval_statements: start = Py_file_input; break; 79 | default: pybind11_fail("invalid evaluation mode"); 80 | } 81 | 82 | int closeFile = 1; 83 | std::string fname_str = (std::string) fname; 84 | #if PY_VERSION_HEX >= 0x03040000 85 | FILE *f = _Py_fopen_obj(fname.ptr(), "r"); 86 | #elif PY_VERSION_HEX >= 0x03000000 87 | FILE *f = _Py_fopen(fname.ptr(), "r"); 88 | #else 89 | /* No unicode support in open() :( */ 90 | auto fobj = reinterpret_steal(PyFile_FromString( 91 | const_cast(fname_str.c_str()), 92 | const_cast("r"))); 93 | FILE *f = nullptr; 94 | if (fobj) 95 | f = PyFile_AsFile(fobj.ptr()); 96 | closeFile = 0; 97 | #endif 98 | if (!f) { 99 | PyErr_Clear(); 100 | pybind11_fail("File \"" + fname_str + "\" could not be opened!"); 101 | } 102 | 103 | #if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION) 104 | PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(), 105 | local.ptr()); 106 | (void) closeFile; 107 | #else 108 | PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(), 109 | local.ptr(), closeFile); 110 | #endif 111 | 112 | if (!result) 113 | throw error_already_set(); 114 | return reinterpret_steal(result); 115 | } 116 | 117 | NAMESPACE_END(PYBIND11_NAMESPACE) 118 | 
-------------------------------------------------------------------------------- /include/pybind11/functional.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/functional.h: std::function<> support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 16 | NAMESPACE_BEGIN(detail) 17 | 18 | template 19 | struct type_caster> { 20 | using type = std::function; 21 | using retval_type = conditional_t::value, void_type, Return>; 22 | using function_type = Return (*) (Args...); 23 | 24 | public: 25 | bool load(handle src, bool convert) { 26 | if (src.is_none()) { 27 | // Defer accepting None to other overloads (if we aren't in convert mode): 28 | if (!convert) return false; 29 | return true; 30 | } 31 | 32 | if (!isinstance(src)) 33 | return false; 34 | 35 | auto func = reinterpret_borrow(src); 36 | 37 | /* 38 | When passing a C++ function as an argument to another C++ 39 | function via Python, every function call would normally involve 40 | a full C++ -> Python -> C++ roundtrip, which can be prohibitive. 41 | Here, we try to at least detect the case where the function is 42 | stateless (i.e. function pointer or lambda function without 43 | captured variables), in which case the roundtrip can be avoided. 44 | */ 45 | if (auto cfunc = func.cpp_function()) { 46 | auto c = reinterpret_borrow(PyCFunction_GET_SELF(cfunc.ptr())); 47 | auto rec = (function_record *) c; 48 | 49 | if (rec && rec->is_stateless && 50 | same_type(typeid(function_type), *reinterpret_cast(rec->data[1]))) { 51 | struct capture { function_type f; }; 52 | value = ((capture *) &rec->data)->f; 53 | return true; 54 | } 55 | } 56 | 57 | value = [func](Args... 
args) -> Return { 58 | gil_scoped_acquire acq; 59 | object retval(func(std::forward(args)...)); 60 | /* Visual studio 2015 parser issue: need parentheses around this expression */ 61 | return (retval.template cast()); 62 | }; 63 | return true; 64 | } 65 | 66 | template 67 | static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) { 68 | if (!f_) 69 | return none().inc_ref(); 70 | 71 | auto result = f_.template target(); 72 | if (result) 73 | return cpp_function(*result, policy).release(); 74 | else 75 | return cpp_function(std::forward(f_), policy).release(); 76 | } 77 | 78 | PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster::name...) + _("], ") 79 | + make_caster::name + _("]")); 80 | }; 81 | 82 | NAMESPACE_END(detail) 83 | NAMESPACE_END(PYBIND11_NAMESPACE) 84 | -------------------------------------------------------------------------------- /include/pybind11/iostream.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/iostream.h -- Tools to assist with redirecting cout and cerr to Python 3 | 4 | Copyright (c) 2017 Henry F. Schreiner 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 21 | NAMESPACE_BEGIN(detail) 22 | 23 | // Buffer that writes to Python instead of C++ 24 | class pythonbuf : public std::streambuf { 25 | private: 26 | using traits_type = std::streambuf::traits_type; 27 | 28 | char d_buffer[1024]; 29 | object pywrite; 30 | object pyflush; 31 | 32 | int overflow(int c) { 33 | if (!traits_type::eq_int_type(c, traits_type::eof())) { 34 | *pptr() = traits_type::to_char_type(c); 35 | pbump(1); 36 | } 37 | return sync() == 0 ? 
traits_type::not_eof(c) : traits_type::eof(); 38 | } 39 | 40 | int sync() { 41 | if (pbase() != pptr()) { 42 | // This subtraction cannot be negative, so dropping the sign 43 | str line(pbase(), static_cast(pptr() - pbase())); 44 | 45 | pywrite(line); 46 | pyflush(); 47 | 48 | setp(pbase(), epptr()); 49 | } 50 | return 0; 51 | } 52 | 53 | public: 54 | pythonbuf(object pyostream) 55 | : pywrite(pyostream.attr("write")), 56 | pyflush(pyostream.attr("flush")) { 57 | setp(d_buffer, d_buffer + sizeof(d_buffer) - 1); 58 | } 59 | 60 | /// Sync before destroy 61 | ~pythonbuf() { 62 | sync(); 63 | } 64 | }; 65 | 66 | NAMESPACE_END(detail) 67 | 68 | 69 | /** \rst 70 | This a move-only guard that redirects output. 71 | 72 | .. code-block:: cpp 73 | 74 | #include 75 | 76 | ... 77 | 78 | { 79 | py::scoped_ostream_redirect output; 80 | std::cout << "Hello, World!"; // Python stdout 81 | } // <-- return std::cout to normal 82 | 83 | You can explicitly pass the c++ stream and the python object, 84 | for example to guard stderr instead. 85 | 86 | .. 
code-block:: cpp 87 | 88 | { 89 | py::scoped_ostream_redirect output{std::cerr, py::module::import("sys").attr("stderr")}; 90 | std::cerr << "Hello, World!"; 91 | } 92 | \endrst */ 93 | class scoped_ostream_redirect { 94 | protected: 95 | std::streambuf *old; 96 | std::ostream &costream; 97 | detail::pythonbuf buffer; 98 | 99 | public: 100 | scoped_ostream_redirect( 101 | std::ostream &costream = std::cout, 102 | object pyostream = module::import("sys").attr("stdout")) 103 | : costream(costream), buffer(pyostream) { 104 | old = costream.rdbuf(&buffer); 105 | } 106 | 107 | ~scoped_ostream_redirect() { 108 | costream.rdbuf(old); 109 | } 110 | 111 | scoped_ostream_redirect(const scoped_ostream_redirect &) = delete; 112 | scoped_ostream_redirect(scoped_ostream_redirect &&other) = default; 113 | scoped_ostream_redirect &operator=(const scoped_ostream_redirect &) = delete; 114 | scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete; 115 | }; 116 | 117 | 118 | /** \rst 119 | Like `scoped_ostream_redirect`, but redirects cerr by default. This class 120 | is provided primary to make ``py::call_guard`` easier to make. 121 | 122 | .. code-block:: cpp 123 | 124 | m.def("noisy_func", &noisy_func, 125 | py::call_guard()); 127 | 128 | \endrst */ 129 | class scoped_estream_redirect : public scoped_ostream_redirect { 130 | public: 131 | scoped_estream_redirect( 132 | std::ostream &costream = std::cerr, 133 | object pyostream = module::import("sys").attr("stderr")) 134 | : scoped_ostream_redirect(costream,pyostream) {} 135 | }; 136 | 137 | 138 | NAMESPACE_BEGIN(detail) 139 | 140 | // Class to redirect output as a context manager. C++ backend. 
141 | class OstreamRedirect { 142 | bool do_stdout_; 143 | bool do_stderr_; 144 | std::unique_ptr redirect_stdout; 145 | std::unique_ptr redirect_stderr; 146 | 147 | public: 148 | OstreamRedirect(bool do_stdout = true, bool do_stderr = true) 149 | : do_stdout_(do_stdout), do_stderr_(do_stderr) {} 150 | 151 | void enter() { 152 | if (do_stdout_) 153 | redirect_stdout.reset(new scoped_ostream_redirect()); 154 | if (do_stderr_) 155 | redirect_stderr.reset(new scoped_estream_redirect()); 156 | } 157 | 158 | void exit() { 159 | redirect_stdout.reset(); 160 | redirect_stderr.reset(); 161 | } 162 | }; 163 | 164 | NAMESPACE_END(detail) 165 | 166 | /** \rst 167 | This is a helper function to add a C++ redirect context manager to Python 168 | instead of using a C++ guard. To use it, add the following to your binding code: 169 | 170 | .. code-block:: cpp 171 | 172 | #include 173 | 174 | ... 175 | 176 | py::add_ostream_redirect(m, "ostream_redirect"); 177 | 178 | You now have a Python context manager that redirects your output: 179 | 180 | .. code-block:: python 181 | 182 | with m.ostream_redirect(): 183 | m.print_to_cout_function() 184 | 185 | This manager can optionally be told which streams to operate on: 186 | 187 | .. 
code-block:: python 188 | 189 | with m.ostream_redirect(stdout=true, stderr=true): 190 | m.noisy_function_with_error_printing() 191 | 192 | \endrst */ 193 | inline class_ add_ostream_redirect(module m, std::string name = "ostream_redirect") { 194 | return class_(m, name.c_str(), module_local()) 195 | .def(init(), arg("stdout")=true, arg("stderr")=true) 196 | .def("__enter__", &detail::OstreamRedirect::enter) 197 | .def("__exit__", [](detail::OstreamRedirect &self_, args) { self_.exit(); }); 198 | } 199 | 200 | NAMESPACE_END(PYBIND11_NAMESPACE) 201 | -------------------------------------------------------------------------------- /include/pybind11/operators.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/operator.h: Metatemplates for operator overloading 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | 14 | #if defined(__clang__) && !defined(__INTEL_COMPILER) 15 | # pragma clang diagnostic ignored "-Wunsequenced" // multiple unsequenced modifications to 'self' (when using def(py::self OP Type())) 16 | #elif defined(_MSC_VER) 17 | # pragma warning(push) 18 | # pragma warning(disable: 4127) // warning C4127: Conditional expression is constant 19 | #endif 20 | 21 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 22 | NAMESPACE_BEGIN(detail) 23 | 24 | /// Enumeration with all supported operator types 25 | enum op_id : int { 26 | op_add, op_sub, op_mul, op_div, op_mod, op_divmod, op_pow, op_lshift, 27 | op_rshift, op_and, op_xor, op_or, op_neg, op_pos, op_abs, op_invert, 28 | op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le, 29 | op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift, 30 | op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero, 31 | op_repr, op_truediv, op_itruediv, op_hash 32 
| }; 33 | 34 | enum op_type : int { 35 | op_l, /* base type on left */ 36 | op_r, /* base type on right */ 37 | op_u /* unary operator */ 38 | }; 39 | 40 | struct self_t { }; 41 | static const self_t self = self_t(); 42 | 43 | /// Type for an unused type slot 44 | struct undefined_t { }; 45 | 46 | /// Don't warn about an unused variable 47 | inline self_t __self() { return self; } 48 | 49 | /// base template of operator implementations 50 | template struct op_impl { }; 51 | 52 | /// Operator implementation generator 53 | template struct op_ { 54 | template void execute(Class &cl, const Extra&... extra) const { 55 | using Base = typename Class::type; 56 | using L_type = conditional_t::value, Base, L>; 57 | using R_type = conditional_t::value, Base, R>; 58 | using op = op_impl; 59 | cl.def(op::name(), &op::execute, is_operator(), extra...); 60 | #if PY_MAJOR_VERSION < 3 61 | if (id == op_truediv || id == op_itruediv) 62 | cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__", 63 | &op::execute, is_operator(), extra...); 64 | #endif 65 | } 66 | template void execute_cast(Class &cl, const Extra&... extra) const { 67 | using Base = typename Class::type; 68 | using L_type = conditional_t::value, Base, L>; 69 | using R_type = conditional_t::value, Base, R>; 70 | using op = op_impl; 71 | cl.def(op::name(), &op::execute_cast, is_operator(), extra...); 72 | #if PY_MAJOR_VERSION < 3 73 | if (id == op_truediv || id == op_itruediv) 74 | cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? 
"__div__" : "__rdiv__", 75 | &op::execute, is_operator(), extra...); 76 | #endif 77 | } 78 | }; 79 | 80 | #define PYBIND11_BINARY_OPERATOR(id, rid, op, expr) \ 81 | template struct op_impl { \ 82 | static char const* name() { return "__" #id "__"; } \ 83 | static auto execute(const L &l, const R &r) -> decltype(expr) { return (expr); } \ 84 | static B execute_cast(const L &l, const R &r) { return B(expr); } \ 85 | }; \ 86 | template struct op_impl { \ 87 | static char const* name() { return "__" #rid "__"; } \ 88 | static auto execute(const R &r, const L &l) -> decltype(expr) { return (expr); } \ 89 | static B execute_cast(const R &r, const L &l) { return B(expr); } \ 90 | }; \ 91 | inline op_ op(const self_t &, const self_t &) { \ 92 | return op_(); \ 93 | } \ 94 | template op_ op(const self_t &, const T &) { \ 95 | return op_(); \ 96 | } \ 97 | template op_ op(const T &, const self_t &) { \ 98 | return op_(); \ 99 | } 100 | 101 | #define PYBIND11_INPLACE_OPERATOR(id, op, expr) \ 102 | template struct op_impl { \ 103 | static char const* name() { return "__" #id "__"; } \ 104 | static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \ 105 | static B execute_cast(L &l, const R &r) { return B(expr); } \ 106 | }; \ 107 | template op_ op(const self_t &, const T &) { \ 108 | return op_(); \ 109 | } 110 | 111 | #define PYBIND11_UNARY_OPERATOR(id, op, expr) \ 112 | template struct op_impl { \ 113 | static char const* name() { return "__" #id "__"; } \ 114 | static auto execute(const L &l) -> decltype(expr) { return expr; } \ 115 | static B execute_cast(const L &l) { return B(expr); } \ 116 | }; \ 117 | inline op_ op(const self_t &) { \ 118 | return op_(); \ 119 | } 120 | 121 | PYBIND11_BINARY_OPERATOR(sub, rsub, operator-, l - r) 122 | PYBIND11_BINARY_OPERATOR(add, radd, operator+, l + r) 123 | PYBIND11_BINARY_OPERATOR(mul, rmul, operator*, l * r) 124 | PYBIND11_BINARY_OPERATOR(truediv, rtruediv, operator/, l / r) 125 | PYBIND11_BINARY_OPERATOR(mod, 
rmod, operator%, l % r) 126 | PYBIND11_BINARY_OPERATOR(lshift, rlshift, operator<<, l << r) 127 | PYBIND11_BINARY_OPERATOR(rshift, rrshift, operator>>, l >> r) 128 | PYBIND11_BINARY_OPERATOR(and, rand, operator&, l & r) 129 | PYBIND11_BINARY_OPERATOR(xor, rxor, operator^, l ^ r) 130 | PYBIND11_BINARY_OPERATOR(eq, eq, operator==, l == r) 131 | PYBIND11_BINARY_OPERATOR(ne, ne, operator!=, l != r) 132 | PYBIND11_BINARY_OPERATOR(or, ror, operator|, l | r) 133 | PYBIND11_BINARY_OPERATOR(gt, lt, operator>, l > r) 134 | PYBIND11_BINARY_OPERATOR(ge, le, operator>=, l >= r) 135 | PYBIND11_BINARY_OPERATOR(lt, gt, operator<, l < r) 136 | PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r) 137 | //PYBIND11_BINARY_OPERATOR(pow, rpow, pow, std::pow(l, r)) 138 | PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r) 139 | PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r) 140 | PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r) 141 | PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r) 142 | PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r) 143 | PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r) 144 | PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r) 145 | PYBIND11_INPLACE_OPERATOR(iand, operator&=, l &= r) 146 | PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r) 147 | PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r) 148 | PYBIND11_UNARY_OPERATOR(neg, operator-, -l) 149 | PYBIND11_UNARY_OPERATOR(pos, operator+, +l) 150 | PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l)) 151 | PYBIND11_UNARY_OPERATOR(hash, hash, std::hash()(l)) 152 | PYBIND11_UNARY_OPERATOR(invert, operator~, (~l)) 153 | PYBIND11_UNARY_OPERATOR(bool, operator!, !!l) 154 | PYBIND11_UNARY_OPERATOR(int, int_, (int) l) 155 | PYBIND11_UNARY_OPERATOR(float, float_, (double) l) 156 | 157 | #undef PYBIND11_BINARY_OPERATOR 158 | #undef PYBIND11_INPLACE_OPERATOR 159 | #undef PYBIND11_UNARY_OPERATOR 160 | NAMESPACE_END(detail) 161 | 162 | using detail::self; 163 | 164 | NAMESPACE_END(PYBIND11_NAMESPACE) 
165 | 166 | #if defined(_MSC_VER) 167 | # pragma warning(pop) 168 | #endif 169 | -------------------------------------------------------------------------------- /include/pybind11/options.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/options.h: global settings that are configurable at runtime. 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "detail/common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | 16 | class options { 17 | public: 18 | 19 | // Default RAII constructor, which leaves settings as they currently are. 20 | options() : previous_state(global_state()) {} 21 | 22 | // Class is non-copyable. 23 | options(const options&) = delete; 24 | options& operator=(const options&) = delete; 25 | 26 | // Destructor, which restores settings that were in effect before. 27 | ~options() { 28 | global_state() = previous_state; 29 | } 30 | 31 | // Setter methods (affect the global state): 32 | 33 | options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; } 34 | 35 | options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; } 36 | 37 | options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; } 38 | 39 | options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; } 40 | 41 | // Getter methods (return the global state): 42 | 43 | static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; } 44 | 45 | static bool show_function_signatures() { return global_state().show_function_signatures; } 46 | 47 | // This type is not meant to be allocated on the heap. 
48 | void* operator new(size_t) = delete; 49 | 50 | private: 51 | 52 | struct state { 53 | bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings. 54 | bool show_function_signatures = true; //< Include auto-generated function signatures in docstrings. 55 | }; 56 | 57 | static state &global_state() { 58 | static state instance; 59 | return instance; 60 | } 61 | 62 | state previous_state; 63 | }; 64 | 65 | NAMESPACE_END(PYBIND11_NAMESPACE) 66 | -------------------------------------------------------------------------------- /include/pybind11/stl.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/stl.h: Transparent conversion for STL data types 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #if defined(_MSC_VER) 23 | #pragma warning(push) 24 | #pragma warning(disable: 4127) // warning C4127: Conditional expression is constant 25 | #endif 26 | 27 | #ifdef __has_include 28 | // std::optional (but including it in c++14 mode isn't allowed) 29 | # if defined(PYBIND11_CPP17) && __has_include() 30 | # include 31 | # define PYBIND11_HAS_OPTIONAL 1 32 | # endif 33 | // std::experimental::optional (but not allowed in c++11 mode) 34 | # if defined(PYBIND11_CPP14) && (__has_include() && \ 35 | !__has_include()) 36 | # include 37 | # define PYBIND11_HAS_EXP_OPTIONAL 1 38 | # endif 39 | // std::variant 40 | # if defined(PYBIND11_CPP17) && __has_include() 41 | # include 42 | # define PYBIND11_HAS_VARIANT 1 43 | # endif 44 | #elif defined(_MSC_VER) && defined(PYBIND11_CPP17) 45 | # include 46 | # include 47 | # define PYBIND11_HAS_OPTIONAL 1 48 | # define PYBIND11_HAS_VARIANT 1 49 | #endif 50 | 51 
| NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 52 | NAMESPACE_BEGIN(detail) 53 | 54 | /// Extracts an const lvalue reference or rvalue reference for U based on the type of T (e.g. for 55 | /// forwarding a container element). Typically used indirect via forwarded_type(), below. 56 | template 57 | using forwarded_type = conditional_t< 58 | std::is_lvalue_reference::value, remove_reference_t &, remove_reference_t &&>; 59 | 60 | /// Forwards a value U as rvalue or lvalue according to whether T is rvalue or lvalue; typically 61 | /// used for forwarding a container's elements. 62 | template 63 | forwarded_type forward_like(U &&u) { 64 | return std::forward>(std::forward(u)); 65 | } 66 | 67 | template struct set_caster { 68 | using type = Type; 69 | using key_conv = make_caster; 70 | 71 | bool load(handle src, bool convert) { 72 | if (!isinstance(src)) 73 | return false; 74 | auto s = reinterpret_borrow(src); 75 | value.clear(); 76 | for (auto entry : s) { 77 | key_conv conv; 78 | if (!conv.load(entry, convert)) 79 | return false; 80 | value.insert(cast_op(std::move(conv))); 81 | } 82 | return true; 83 | } 84 | 85 | template 86 | static handle cast(T &&src, return_value_policy policy, handle parent) { 87 | if (!std::is_lvalue_reference::value) 88 | policy = return_value_policy_override::policy(policy); 89 | pybind11::set s; 90 | for (auto &&value : src) { 91 | auto value_ = reinterpret_steal(key_conv::cast(forward_like(value), policy, parent)); 92 | if (!value_ || !s.add(value_)) 93 | return handle(); 94 | } 95 | return s.release(); 96 | } 97 | 98 | PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name + _("]")); 99 | }; 100 | 101 | template struct map_caster { 102 | using key_conv = make_caster; 103 | using value_conv = make_caster; 104 | 105 | bool load(handle src, bool convert) { 106 | if (!isinstance(src)) 107 | return false; 108 | auto d = reinterpret_borrow(src); 109 | value.clear(); 110 | for (auto it : d) { 111 | key_conv kconv; 112 | value_conv vconv; 113 | if 
(!kconv.load(it.first.ptr(), convert) || 114 | !vconv.load(it.second.ptr(), convert)) 115 | return false; 116 | value.emplace(cast_op(std::move(kconv)), cast_op(std::move(vconv))); 117 | } 118 | return true; 119 | } 120 | 121 | template 122 | static handle cast(T &&src, return_value_policy policy, handle parent) { 123 | dict d; 124 | return_value_policy policy_key = policy; 125 | return_value_policy policy_value = policy; 126 | if (!std::is_lvalue_reference::value) { 127 | policy_key = return_value_policy_override::policy(policy_key); 128 | policy_value = return_value_policy_override::policy(policy_value); 129 | } 130 | for (auto &&kv : src) { 131 | auto key = reinterpret_steal(key_conv::cast(forward_like(kv.first), policy_key, parent)); 132 | auto value = reinterpret_steal(value_conv::cast(forward_like(kv.second), policy_value, parent)); 133 | if (!key || !value) 134 | return handle(); 135 | d[key] = value; 136 | } 137 | return d.release(); 138 | } 139 | 140 | PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name + _(", ") + value_conv::name + _("]")); 141 | }; 142 | 143 | template struct list_caster { 144 | using value_conv = make_caster; 145 | 146 | bool load(handle src, bool convert) { 147 | if (!isinstance(src) || isinstance(src)) 148 | return false; 149 | auto s = reinterpret_borrow(src); 150 | value.clear(); 151 | reserve_maybe(s, &value); 152 | for (auto it : s) { 153 | value_conv conv; 154 | if (!conv.load(it, convert)) 155 | return false; 156 | value.push_back(cast_op(std::move(conv))); 157 | } 158 | return true; 159 | } 160 | 161 | private: 162 | template ().reserve(0)), void>::value, int> = 0> 164 | void reserve_maybe(sequence s, Type *) { value.reserve(s.size()); } 165 | void reserve_maybe(sequence, void *) { } 166 | 167 | public: 168 | template 169 | static handle cast(T &&src, return_value_policy policy, handle parent) { 170 | if (!std::is_lvalue_reference::value) 171 | policy = return_value_policy_override::policy(policy); 172 | list l(src.size()); 
173 | size_t index = 0; 174 | for (auto &&value : src) { 175 | auto value_ = reinterpret_steal(value_conv::cast(forward_like(value), policy, parent)); 176 | if (!value_) 177 | return handle(); 178 | PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference 179 | } 180 | return l.release(); 181 | } 182 | 183 | PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name + _("]")); 184 | }; 185 | 186 | template struct type_caster> 187 | : list_caster, Type> { }; 188 | 189 | template struct type_caster> 190 | : list_caster, Type> { }; 191 | 192 | template struct type_caster> 193 | : list_caster, Type> { }; 194 | 195 | template struct array_caster { 196 | using value_conv = make_caster; 197 | 198 | private: 199 | template 200 | bool require_size(enable_if_t size) { 201 | if (value.size() != size) 202 | value.resize(size); 203 | return true; 204 | } 205 | template 206 | bool require_size(enable_if_t size) { 207 | return size == Size; 208 | } 209 | 210 | public: 211 | bool load(handle src, bool convert) { 212 | if (!isinstance(src)) 213 | return false; 214 | auto l = reinterpret_borrow(src); 215 | if (!require_size(l.size())) 216 | return false; 217 | size_t ctr = 0; 218 | for (auto it : l) { 219 | value_conv conv; 220 | if (!conv.load(it, convert)) 221 | return false; 222 | value[ctr++] = cast_op(std::move(conv)); 223 | } 224 | return true; 225 | } 226 | 227 | template 228 | static handle cast(T &&src, return_value_policy policy, handle parent) { 229 | list l(src.size()); 230 | size_t index = 0; 231 | for (auto &&value : src) { 232 | auto value_ = reinterpret_steal(value_conv::cast(forward_like(value), policy, parent)); 233 | if (!value_) 234 | return handle(); 235 | PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference 236 | } 237 | return l.release(); 238 | } 239 | 240 | PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name + _(_(""), _("[") + _() + _("]")) + _("]")); 241 | }; 242 | 243 | 
template struct type_caster> 244 | : array_caster, Type, false, Size> { }; 245 | 246 | template struct type_caster> 247 | : array_caster, Type, true> { }; 248 | 249 | template struct type_caster> 250 | : set_caster, Key> { }; 251 | 252 | template struct type_caster> 253 | : set_caster, Key> { }; 254 | 255 | template struct type_caster> 256 | : map_caster, Key, Value> { }; 257 | 258 | template struct type_caster> 259 | : map_caster, Key, Value> { }; 260 | 261 | // This type caster is intended to be used for std::optional and std::experimental::optional 262 | template struct optional_caster { 263 | using value_conv = make_caster; 264 | 265 | template 266 | static handle cast(T_ &&src, return_value_policy policy, handle parent) { 267 | if (!src) 268 | return none().inc_ref(); 269 | policy = return_value_policy_override::policy(policy); 270 | return value_conv::cast(*std::forward(src), policy, parent); 271 | } 272 | 273 | bool load(handle src, bool convert) { 274 | if (!src) { 275 | return false; 276 | } else if (src.is_none()) { 277 | return true; // default-constructed value is already empty 278 | } 279 | value_conv inner_caster; 280 | if (!inner_caster.load(src, convert)) 281 | return false; 282 | 283 | value.emplace(cast_op(std::move(inner_caster))); 284 | return true; 285 | } 286 | 287 | PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name + _("]")); 288 | }; 289 | 290 | #if PYBIND11_HAS_OPTIONAL 291 | template struct type_caster> 292 | : public optional_caster> {}; 293 | 294 | template<> struct type_caster 295 | : public void_caster {}; 296 | #endif 297 | 298 | #if PYBIND11_HAS_EXP_OPTIONAL 299 | template struct type_caster> 300 | : public optional_caster> {}; 301 | 302 | template<> struct type_caster 303 | : public void_caster {}; 304 | #endif 305 | 306 | /// Visit a variant and cast any found type to Python 307 | struct variant_caster_visitor { 308 | return_value_policy policy; 309 | handle parent; 310 | 311 | using result_type = handle; // required by 
boost::variant in C++11 312 | 313 | template 314 | result_type operator()(T &&src) const { 315 | return make_caster::cast(std::forward(src), policy, parent); 316 | } 317 | }; 318 | 319 | /// Helper class which abstracts away variant's `visit` function. `std::variant` and similar 320 | /// `namespace::variant` types which provide a `namespace::visit()` function are handled here 321 | /// automatically using argument-dependent lookup. Users can provide specializations for other 322 | /// variant-like classes, e.g. `boost::variant` and `boost::apply_visitor`. 323 | template class Variant> 324 | struct visit_helper { 325 | template 326 | static auto call(Args &&...args) -> decltype(visit(std::forward(args)...)) { 327 | return visit(std::forward(args)...); 328 | } 329 | }; 330 | 331 | /// Generic variant caster 332 | template struct variant_caster; 333 | 334 | template class V, typename... Ts> 335 | struct variant_caster> { 336 | static_assert(sizeof...(Ts) > 0, "Variant must consist of at least one alternative."); 337 | 338 | template 339 | bool load_alternative(handle src, bool convert, type_list) { 340 | auto caster = make_caster(); 341 | if (caster.load(src, convert)) { 342 | value = cast_op(caster); 343 | return true; 344 | } 345 | return load_alternative(src, convert, type_list{}); 346 | } 347 | 348 | bool load_alternative(handle, bool, type_list<>) { return false; } 349 | 350 | bool load(handle src, bool convert) { 351 | // Do a first pass without conversions to improve constructor resolution. 352 | // E.g. `py::int_(1).cast>()` needs to fill the `int` 353 | // slot of the variant. Without two-pass loading `double` would be filled 354 | // because it appears first and a conversion is possible. 
355 | if (convert && load_alternative(src, false, type_list{})) 356 | return true; 357 | return load_alternative(src, convert, type_list{}); 358 | } 359 | 360 | template 361 | static handle cast(Variant &&src, return_value_policy policy, handle parent) { 362 | return visit_helper::call(variant_caster_visitor{policy, parent}, 363 | std::forward(src)); 364 | } 365 | 366 | using Type = V; 367 | PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster::name...) + _("]")); 368 | }; 369 | 370 | #if PYBIND11_HAS_VARIANT 371 | template 372 | struct type_caster> : variant_caster> { }; 373 | #endif 374 | 375 | NAMESPACE_END(detail) 376 | 377 | inline std::ostream &operator<<(std::ostream &os, const handle &obj) { 378 | os << (std::string) str(obj); 379 | return os; 380 | } 381 | 382 | NAMESPACE_END(PYBIND11_NAMESPACE) 383 | 384 | #if defined(_MSC_VER) 385 | #pragma warning(pop) 386 | #endif 387 | -------------------------------------------------------------------------------- /include/pybind11/stl_bind.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/std_bind.h: Binding generators for STL data types 3 | 4 | Copyright (c) 2016 Sergey Lyskov and Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "detail/common.h" 13 | #include "operators.h" 14 | 15 | #include 16 | #include 17 | 18 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 19 | NAMESPACE_BEGIN(detail) 20 | 21 | /* SFINAE helper class used by 'is_comparable */ 22 | template struct container_traits { 23 | template static std::true_type test_comparable(decltype(std::declval() == std::declval())*); 24 | template static std::false_type test_comparable(...); 25 | template static std::true_type test_value(typename T2::value_type *); 26 | template static std::false_type test_value(...); 27 | template static std::true_type test_pair(typename T2::first_type *, typename T2::second_type *); 28 | template static std::false_type test_pair(...); 29 | 30 | static constexpr const bool is_comparable = std::is_same(nullptr))>::value; 31 | static constexpr const bool is_pair = std::is_same(nullptr, nullptr))>::value; 32 | static constexpr const bool is_vector = std::is_same(nullptr))>::value; 33 | static constexpr const bool is_element = !is_pair && !is_vector; 34 | }; 35 | 36 | /* Default: is_comparable -> std::false_type */ 37 | template 38 | struct is_comparable : std::false_type { }; 39 | 40 | /* For non-map data structures, check whether operator== can be instantiated */ 41 | template 42 | struct is_comparable< 43 | T, enable_if_t::is_element && 44 | container_traits::is_comparable>> 45 | : std::true_type { }; 46 | 47 | /* For a vector/map data structure, recursively check the value type (which is std::pair for maps) */ 48 | template 49 | struct is_comparable::is_vector>> { 50 | static constexpr const bool value = 51 | is_comparable::value; 52 | }; 53 | 54 | /* For pairs, recursively check the two data types */ 55 | template 56 | struct is_comparable::is_pair>> { 57 | static constexpr const bool value = 58 | is_comparable::value && 59 | is_comparable::value; 60 | }; 61 | 62 | /* Fallback functions */ 63 | template void vector_if_copy_constructible(const Args &...) 
{ } 64 | template void vector_if_equal_operator(const Args &...) { } 65 | template void vector_if_insertion_operator(const Args &...) { } 66 | template void vector_modifiers(const Args &...) { } 67 | 68 | template 69 | void vector_if_copy_constructible(enable_if_t::value, Class_> &cl) { 70 | cl.def(init(), "Copy constructor"); 71 | } 72 | 73 | template 74 | void vector_if_equal_operator(enable_if_t::value, Class_> &cl) { 75 | using T = typename Vector::value_type; 76 | 77 | cl.def(self == self); 78 | cl.def(self != self); 79 | 80 | cl.def("count", 81 | [](const Vector &v, const T &x) { 82 | return std::count(v.begin(), v.end(), x); 83 | }, 84 | arg("x"), 85 | "Return the number of times ``x`` appears in the list" 86 | ); 87 | 88 | cl.def("remove", [](Vector &v, const T &x) { 89 | auto p = std::find(v.begin(), v.end(), x); 90 | if (p != v.end()) 91 | v.erase(p); 92 | else 93 | throw value_error(); 94 | }, 95 | arg("x"), 96 | "Remove the first item from the list whose value is x. " 97 | "It is an error if there is no such item." 98 | ); 99 | 100 | cl.def("__contains__", 101 | [](const Vector &v, const T &x) { 102 | return std::find(v.begin(), v.end(), x) != v.end(); 103 | }, 104 | arg("x"), 105 | "Return true the container contains ``x``" 106 | ); 107 | } 108 | 109 | // Vector modifiers -- requires a copyable vector_type: 110 | // (Technically, some of these (pop and __delitem__) don't actually require copyability, but it seems 111 | // silly to allow deletion but not insertion, so include them here too.) 
112 | template 113 | void vector_modifiers(enable_if_t::value, Class_> &cl) { 114 | using T = typename Vector::value_type; 115 | using SizeType = typename Vector::size_type; 116 | using DiffType = typename Vector::difference_type; 117 | 118 | cl.def("append", 119 | [](Vector &v, const T &value) { v.push_back(value); }, 120 | arg("x"), 121 | "Add an item to the end of the list"); 122 | 123 | cl.def(init([](iterable it) { 124 | auto v = std::unique_ptr(new Vector()); 125 | v->reserve(len(it)); 126 | for (handle h : it) 127 | v->push_back(h.cast()); 128 | return v.release(); 129 | })); 130 | 131 | cl.def("extend", 132 | [](Vector &v, const Vector &src) { 133 | v.insert(v.end(), src.begin(), src.end()); 134 | }, 135 | arg("L"), 136 | "Extend the list by appending all the items in the given list" 137 | ); 138 | 139 | cl.def("insert", 140 | [](Vector &v, SizeType i, const T &x) { 141 | if (i > v.size()) 142 | throw index_error(); 143 | v.insert(v.begin() + (DiffType) i, x); 144 | }, 145 | arg("i") , arg("x"), 146 | "Insert an item at a given position." 
147 | ); 148 | 149 | cl.def("pop", 150 | [](Vector &v) { 151 | if (v.empty()) 152 | throw index_error(); 153 | T t = v.back(); 154 | v.pop_back(); 155 | return t; 156 | }, 157 | "Remove and return the last item" 158 | ); 159 | 160 | cl.def("pop", 161 | [](Vector &v, SizeType i) { 162 | if (i >= v.size()) 163 | throw index_error(); 164 | T t = v[i]; 165 | v.erase(v.begin() + (DiffType) i); 166 | return t; 167 | }, 168 | arg("i"), 169 | "Remove and return the item at index ``i``" 170 | ); 171 | 172 | cl.def("__setitem__", 173 | [](Vector &v, SizeType i, const T &t) { 174 | if (i >= v.size()) 175 | throw index_error(); 176 | v[i] = t; 177 | } 178 | ); 179 | 180 | /// Slicing protocol 181 | cl.def("__getitem__", 182 | [](const Vector &v, slice slice) -> Vector * { 183 | size_t start, stop, step, slicelength; 184 | 185 | if (!slice.compute(v.size(), &start, &stop, &step, &slicelength)) 186 | throw error_already_set(); 187 | 188 | Vector *seq = new Vector(); 189 | seq->reserve((size_t) slicelength); 190 | 191 | for (size_t i=0; ipush_back(v[start]); 193 | start += step; 194 | } 195 | return seq; 196 | }, 197 | arg("s"), 198 | "Retrieve list elements using a slice object" 199 | ); 200 | 201 | cl.def("__setitem__", 202 | [](Vector &v, slice slice, const Vector &value) { 203 | size_t start, stop, step, slicelength; 204 | if (!slice.compute(v.size(), &start, &stop, &step, &slicelength)) 205 | throw error_already_set(); 206 | 207 | if (slicelength != value.size()) 208 | throw std::runtime_error("Left and right hand size of slice assignment have different sizes!"); 209 | 210 | for (size_t i=0; i= v.size()) 221 | throw index_error(); 222 | v.erase(v.begin() + DiffType(i)); 223 | }, 224 | "Delete the list elements at index ``i``" 225 | ); 226 | 227 | cl.def("__delitem__", 228 | [](Vector &v, slice slice) { 229 | size_t start, stop, step, slicelength; 230 | 231 | if (!slice.compute(v.size(), &start, &stop, &step, &slicelength)) 232 | throw error_already_set(); 233 | 234 | if 
(step == 1 && false) { 235 | v.erase(v.begin() + (DiffType) start, v.begin() + DiffType(start + slicelength)); 236 | } else { 237 | for (size_t i = 0; i < slicelength; ++i) { 238 | v.erase(v.begin() + DiffType(start)); 239 | start += step - 1; 240 | } 241 | } 242 | }, 243 | "Delete list elements using a slice object" 244 | ); 245 | 246 | } 247 | 248 | // If the type has an operator[] that doesn't return a reference (most notably std::vector), 249 | // we have to access by copying; otherwise we return by reference. 250 | template using vector_needs_copy = negation< 251 | std::is_same()[typename Vector::size_type()]), typename Vector::value_type &>>; 252 | 253 | // The usual case: access and iterate by reference 254 | template 255 | void vector_accessor(enable_if_t::value, Class_> &cl) { 256 | using T = typename Vector::value_type; 257 | using SizeType = typename Vector::size_type; 258 | using ItType = typename Vector::iterator; 259 | 260 | cl.def("__getitem__", 261 | [](Vector &v, SizeType i) -> T & { 262 | if (i >= v.size()) 263 | throw index_error(); 264 | return v[i]; 265 | }, 266 | return_value_policy::reference_internal // ref + keepalive 267 | ); 268 | 269 | cl.def("__iter__", 270 | [](Vector &v) { 271 | return make_iterator< 272 | return_value_policy::reference_internal, ItType, ItType, T&>( 273 | v.begin(), v.end()); 274 | }, 275 | keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ 276 | ); 277 | } 278 | 279 | // The case for special objects, like std::vector, that have to be returned-by-copy: 280 | template 281 | void vector_accessor(enable_if_t::value, Class_> &cl) { 282 | using T = typename Vector::value_type; 283 | using SizeType = typename Vector::size_type; 284 | using ItType = typename Vector::iterator; 285 | cl.def("__getitem__", 286 | [](const Vector &v, SizeType i) -> T { 287 | if (i >= v.size()) 288 | throw index_error(); 289 | return v[i]; 290 | } 291 | ); 292 | 293 | cl.def("__iter__", 294 | [](Vector &v) { 295 | return 
make_iterator< 296 | return_value_policy::copy, ItType, ItType, T>( 297 | v.begin(), v.end()); 298 | }, 299 | keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ 300 | ); 301 | } 302 | 303 | template auto vector_if_insertion_operator(Class_ &cl, std::string const &name) 304 | -> decltype(std::declval() << std::declval(), void()) { 305 | using size_type = typename Vector::size_type; 306 | 307 | cl.def("__repr__", 308 | [name](Vector &v) { 309 | std::ostringstream s; 310 | s << name << '['; 311 | for (size_type i=0; i < v.size(); ++i) { 312 | s << v[i]; 313 | if (i != v.size() - 1) 314 | s << ", "; 315 | } 316 | s << ']'; 317 | return s.str(); 318 | }, 319 | "Return the canonical string representation of this list." 320 | ); 321 | } 322 | 323 | // Provide the buffer interface for vectors if we have data() and we have a format for it 324 | // GCC seems to have "void std::vector::data()" - doing SFINAE on the existence of data() is insufficient, we need to check it returns an appropriate pointer 325 | template 326 | struct vector_has_data_and_format : std::false_type {}; 327 | template 328 | struct vector_has_data_and_format::format(), std::declval().data()), typename Vector::value_type*>::value>> : std::true_type {}; 329 | 330 | // Add the buffer interface to a vector 331 | template 332 | enable_if_t...>::value> 333 | vector_buffer(Class_& cl) { 334 | using T = typename Vector::value_type; 335 | 336 | static_assert(vector_has_data_and_format::value, "There is not an appropriate format descriptor for this vector"); 337 | 338 | // numpy.h declares this for arbitrary types, but it may raise an exception and crash hard at runtime if PYBIND11_NUMPY_DTYPE hasn't been called, so check here 339 | format_descriptor::format(); 340 | 341 | cl.def_buffer([](Vector& v) -> buffer_info { 342 | return buffer_info(v.data(), static_cast(sizeof(T)), format_descriptor::format(), 1, {v.size()}, {sizeof(T)}); 343 | }); 344 | 345 | cl.def(init([](buffer buf) { 346 | 
auto info = buf.request(); 347 | if (info.ndim != 1 || info.strides[0] % static_cast(sizeof(T))) 348 | throw type_error("Only valid 1D buffers can be copied to a vector"); 349 | if (!detail::compare_buffer_info::compare(info) || (ssize_t) sizeof(T) != info.itemsize) 350 | throw type_error("Format mismatch (Python: " + info.format + " C++: " + format_descriptor::format() + ")"); 351 | 352 | auto vec = std::unique_ptr(new Vector()); 353 | vec->reserve((size_t) info.shape[0]); 354 | T *p = static_cast(info.ptr); 355 | ssize_t step = info.strides[0] / static_cast(sizeof(T)); 356 | T *end = p + info.shape[0] * step; 357 | for (; p != end; p += step) 358 | vec->push_back(*p); 359 | return vec.release(); 360 | })); 361 | 362 | return; 363 | } 364 | 365 | template 366 | enable_if_t...>::value> vector_buffer(Class_&) {} 367 | 368 | NAMESPACE_END(detail) 369 | 370 | // 371 | // std::vector 372 | // 373 | template , typename... Args> 374 | class_ bind_vector(handle scope, std::string const &name, Args&&... args) { 375 | using Class_ = class_; 376 | 377 | // If the value_type is unregistered (e.g. 
a converting type) or is itself registered 378 | // module-local then make the vector binding module-local as well: 379 | using vtype = typename Vector::value_type; 380 | auto vtype_info = detail::get_type_info(typeid(vtype)); 381 | bool local = !vtype_info || vtype_info->module_local; 382 | 383 | Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward(args)...); 384 | 385 | // Declare the buffer interface if a buffer_protocol() is passed in 386 | detail::vector_buffer(cl); 387 | 388 | cl.def(init<>()); 389 | 390 | // Register copy constructor (if possible) 391 | detail::vector_if_copy_constructible(cl); 392 | 393 | // Register comparison-related operators and functions (if possible) 394 | detail::vector_if_equal_operator(cl); 395 | 396 | // Register stream insertion operator (if possible) 397 | detail::vector_if_insertion_operator(cl, name); 398 | 399 | // Modifiers require copyable vector value type 400 | detail::vector_modifiers(cl); 401 | 402 | // Accessor and iterator; return by value if copyable, otherwise we return by ref + keep-alive 403 | detail::vector_accessor(cl); 404 | 405 | cl.def("__bool__", 406 | [](const Vector &v) -> bool { 407 | return !v.empty(); 408 | }, 409 | "Check whether the list is nonempty" 410 | ); 411 | 412 | cl.def("__len__", &Vector::size); 413 | 414 | 415 | 416 | 417 | #if 0 418 | // C++ style functions deprecated, leaving it here as an example 419 | cl.def(init()); 420 | 421 | cl.def("resize", 422 | (void (Vector::*) (size_type count)) & Vector::resize, 423 | "changes the number of elements stored"); 424 | 425 | cl.def("erase", 426 | [](Vector &v, SizeType i) { 427 | if (i >= v.size()) 428 | throw index_error(); 429 | v.erase(v.begin() + i); 430 | }, "erases element at index ``i``"); 431 | 432 | cl.def("empty", &Vector::empty, "checks whether the container is empty"); 433 | cl.def("size", &Vector::size, "returns the number of elements"); 434 | cl.def("push_back", (void (Vector::*)(const T&)) &Vector::push_back, 
"adds an element to the end"); 435 | cl.def("pop_back", &Vector::pop_back, "removes the last element"); 436 | 437 | cl.def("max_size", &Vector::max_size, "returns the maximum possible number of elements"); 438 | cl.def("reserve", &Vector::reserve, "reserves storage"); 439 | cl.def("capacity", &Vector::capacity, "returns the number of elements that can be held in currently allocated storage"); 440 | cl.def("shrink_to_fit", &Vector::shrink_to_fit, "reduces memory usage by freeing unused memory"); 441 | 442 | cl.def("clear", &Vector::clear, "clears the contents"); 443 | cl.def("swap", &Vector::swap, "swaps the contents"); 444 | 445 | cl.def("front", [](Vector &v) { 446 | if (v.size()) return v.front(); 447 | else throw index_error(); 448 | }, "access the first element"); 449 | 450 | cl.def("back", [](Vector &v) { 451 | if (v.size()) return v.back(); 452 | else throw index_error(); 453 | }, "access the last element "); 454 | 455 | #endif 456 | 457 | return cl; 458 | } 459 | 460 | 461 | 462 | // 463 | // std::map, std::unordered_map 464 | // 465 | 466 | NAMESPACE_BEGIN(detail) 467 | 468 | /* Fallback functions */ 469 | template void map_if_insertion_operator(const Args &...) { } 470 | template void map_assignment(const Args &...) 
{ } 471 | 472 | // Map assignment when copy-assignable: just copy the value 473 | template 474 | void map_assignment(enable_if_t::value, Class_> &cl) { 475 | using KeyType = typename Map::key_type; 476 | using MappedType = typename Map::mapped_type; 477 | 478 | cl.def("__setitem__", 479 | [](Map &m, const KeyType &k, const MappedType &v) { 480 | auto it = m.find(k); 481 | if (it != m.end()) it->second = v; 482 | else m.emplace(k, v); 483 | } 484 | ); 485 | } 486 | 487 | // Not copy-assignable, but still copy-constructible: we can update the value by erasing and reinserting 488 | template 489 | void map_assignment(enable_if_t< 490 | !std::is_copy_assignable::value && 491 | is_copy_constructible::value, 492 | Class_> &cl) { 493 | using KeyType = typename Map::key_type; 494 | using MappedType = typename Map::mapped_type; 495 | 496 | cl.def("__setitem__", 497 | [](Map &m, const KeyType &k, const MappedType &v) { 498 | // We can't use m[k] = v; because value type might not be default constructable 499 | auto r = m.emplace(k, v); 500 | if (!r.second) { 501 | // value type is not copy assignable so the only way to insert it is to erase it first... 502 | m.erase(r.first); 503 | m.emplace(k, v); 504 | } 505 | } 506 | ); 507 | } 508 | 509 | 510 | template auto map_if_insertion_operator(Class_ &cl, std::string const &name) 511 | -> decltype(std::declval() << std::declval() << std::declval(), void()) { 512 | 513 | cl.def("__repr__", 514 | [name](Map &m) { 515 | std::ostringstream s; 516 | s << name << '{'; 517 | bool f = false; 518 | for (auto const &kv : m) { 519 | if (f) 520 | s << ", "; 521 | s << kv.first << ": " << kv.second; 522 | f = true; 523 | } 524 | s << '}'; 525 | return s.str(); 526 | }, 527 | "Return the canonical string representation of this map." 528 | ); 529 | } 530 | 531 | 532 | NAMESPACE_END(detail) 533 | 534 | template , typename... Args> 535 | class_ bind_map(handle scope, const std::string &name, Args&&... 
args) { 536 | using KeyType = typename Map::key_type; 537 | using MappedType = typename Map::mapped_type; 538 | using Class_ = class_; 539 | 540 | // If either type is a non-module-local bound type then make the map binding non-local as well; 541 | // otherwise (e.g. both types are either module-local or converting) the map will be 542 | // module-local. 543 | auto tinfo = detail::get_type_info(typeid(MappedType)); 544 | bool local = !tinfo || tinfo->module_local; 545 | if (local) { 546 | tinfo = detail::get_type_info(typeid(KeyType)); 547 | local = !tinfo || tinfo->module_local; 548 | } 549 | 550 | Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward(args)...); 551 | 552 | cl.def(init<>()); 553 | 554 | // Register stream insertion operator (if possible) 555 | detail::map_if_insertion_operator(cl, name); 556 | 557 | cl.def("__bool__", 558 | [](const Map &m) -> bool { return !m.empty(); }, 559 | "Check whether the map is nonempty" 560 | ); 561 | 562 | cl.def("__iter__", 563 | [](Map &m) { return make_key_iterator(m.begin(), m.end()); }, 564 | keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ 565 | ); 566 | 567 | cl.def("items", 568 | [](Map &m) { return make_iterator(m.begin(), m.end()); }, 569 | keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ 570 | ); 571 | 572 | cl.def("__getitem__", 573 | [](Map &m, const KeyType &k) -> MappedType & { 574 | auto it = m.find(k); 575 | if (it == m.end()) 576 | throw key_error(); 577 | return it->second; 578 | }, 579 | return_value_policy::reference_internal // ref + keepalive 580 | ); 581 | 582 | // Assignment provided only if the type is copyable 583 | detail::map_assignment(cl); 584 | 585 | cl.def("__delitem__", 586 | [](Map &m, const KeyType &k) { 587 | auto it = m.find(k); 588 | if (it == m.end()) 589 | throw key_error(); 590 | m.erase(it); 591 | } 592 | ); 593 | 594 | cl.def("__len__", &Map::size); 595 | 596 | return cl; 597 | } 598 | 599 | 
NAMESPACE_END(PYBIND11_NAMESPACE) 600 | -------------------------------------------------------------------------------- /include/pybind11/typeid.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/typeid.h: Compiler-independent access to type identifiers 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include 13 | #include 14 | 15 | #if defined(__GNUG__) 16 | #include 17 | #endif 18 | 19 | NAMESPACE_BEGIN(pybind11) 20 | NAMESPACE_BEGIN(detail) 21 | /// Erase all occurrences of a substring 22 | inline void erase_all(std::string &string, const std::string &search) { 23 | for (size_t pos = 0;;) { 24 | pos = string.find(search, pos); 25 | if (pos == std::string::npos) break; 26 | string.erase(pos, search.length()); 27 | } 28 | } 29 | 30 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) { 31 | #if defined(__GNUG__) 32 | int status = 0; 33 | std::unique_ptr res { 34 | abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free }; 35 | if (status == 0) 36 | name = res.get(); 37 | #else 38 | detail::erase_all(name, "class "); 39 | detail::erase_all(name, "struct "); 40 | detail::erase_all(name, "enum "); 41 | #endif 42 | detail::erase_all(name, "pybind11::"); 43 | } 44 | NAMESPACE_END(detail) 45 | 46 | /// Return a string representation of a C++ type 47 | template static std::string type_id() { 48 | std::string name(typeid(T).name()); 49 | detail::clean_type_id(name); 50 | return name; 51 | } 52 | 53 | NAMESPACE_END(pybind11) 54 | -------------------------------------------------------------------------------- /include/recognizer.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by xpc on 19-4-29. 
3 | // 4 | 5 | #pragma once 6 | 7 | #include 8 | #include "tensorflow_graph.h" 9 | 10 | namespace tf = tensorflow; 11 | 12 | namespace SeetaOCR { 13 | class Recognizer : public TFGraph { 14 | public: 15 | Recognizer(const std::string &graph_file, 16 | const std::string &label_file) 17 | : TFGraph(graph_file) { 18 | imageHeight = 32; 19 | outputTensorNames = {"indices", "values", "prob"}; 20 | Init(); 21 | LoadLabelFile(label_file); 22 | }; 23 | 24 | void LoadLabelFile(const std::string &label_file); 25 | 26 | void FeedImageToTensor(cv::Mat &inp); 27 | 28 | void FeedImagesToTensor(std::vector &inp); 29 | 30 | void Predict(); 31 | 32 | void Predict(std::vector &inp, std::map>& result) { 33 | FeedImagesToTensor(inp); 34 | Predict(); 35 | for (auto &d: decoded) { result[d.first] = d.second; } 36 | } 37 | 38 | void Debug() {DEBUG=true;} 39 | 40 | protected: 41 | bool DEBUG=false; 42 | int imageHeight; 43 | 44 | std::map label; 45 | std::vector> inputs; 46 | std::vector outputs; 47 | std::map> decoded; 48 | 49 | void ResizeImage(cv::Mat &inp, cv::Mat &out); 50 | 51 | void ResizeImages(std::vector &inp, std::vector &out); 52 | 53 | void Decode(tf::Tensor& indices, tf::Tensor& values, tf::Tensor& probs); 54 | }; 55 | } -------------------------------------------------------------------------------- /include/tensorflow_graph.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by seeta on 19-4-28. 
3 | // 4 | 5 | #pragma once 6 | 7 | #include 8 | #include "tensorflow/core/public/session.h" 9 | 10 | namespace tf = tensorflow; 11 | 12 | class TFGraph { 13 | public: 14 | TFGraph(const std::string& graph_file){ 15 | graphFile = graph_file; 16 | } 17 | TFGraph(const std::string& graph_file, 18 | const std::vector& output_tensor_names) { 19 | graphFile = graph_file; 20 | outputTensorNames = output_tensor_names; 21 | } 22 | 23 | protected: 24 | void Init(){ 25 | tf::NewSession(tf::SessionOptions(), &session); 26 | 27 | statusLoad = tf::ReadBinaryProto(tf::Env::Default(), graphFile, &graphdef); //从pb文件中读取图模型; 28 | if (!statusLoad.ok()) { 29 | throw std::runtime_error("Loading model failed...\n" + statusLoad.ToString()); 30 | } 31 | 32 | statusCreate = session->Create(graphdef); 33 | if (!statusCreate.ok()) { 34 | throw std::runtime_error("Creating graph in session failed...\n" + statusCreate.ToString()); 35 | } 36 | } 37 | 38 | void FetchTensor(std::vector>& inputs, 39 | std::vector& outputs) { 40 | 41 | statusRun = session->Run(inputs, outputTensorNames, targetNodeNames, &outputs); 42 | 43 | if(!statusRun.ok()){ 44 | throw std::runtime_error("Session Run failed...\n" + statusRun.ToString()); 45 | } 46 | } 47 | 48 | tf::Session* session; 49 | tf::GraphDef graphdef; 50 | tf::Status statusRun; 51 | tf::Status statusLoad; 52 | tf::Status statusCreate; 53 | std::string graphFile; 54 | std::vector outputTensorNames; 55 | std::vector targetNodeNames = {}; 56 | }; 57 | 58 | -------------------------------------------------------------------------------- /src/detector.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by xpc on 19-4-28. 
3 | // 4 | #include 5 | #include 6 | #include "detector.h" 7 | 8 | 9 | namespace SeetaOCR { 10 | 11 | void Detector::ResizeImage(cv::Mat& inp, cv::Mat& out, int longest_side) { 12 | float ratio = 1.0; 13 | auto width = (float) inp.cols; 14 | auto height = (float) inp.rows; 15 | 16 | if (fmax(height, width) > longest_side) { 17 | ratio = (height > width) ? (longest_side / height): (longest_side / width); 18 | } 19 | 20 | auto resizedH = (int) (height * ratio); 21 | auto resizedW = (int) (width * ratio); 22 | 23 | if (resizedH % 32 != 0) { 24 | resizedH = 32 * ((int)floor(resizedH / 32.0) + 1); 25 | } 26 | 27 | if (resizedW % 32 != 0) { 28 | resizedW = 32 * ((int)floor(resizedW / 32.0) + 1); 29 | } 30 | cv::resize(inp, out, cv::Size(resizedW, resizedH)); 31 | } 32 | 33 | void Detector::FeedImageToTensor(cv::Mat& inp){ 34 | ResizeImage(inp, resized, longestSide); 35 | 36 | tf::Tensor input_tensor(tf::DT_FLOAT, tf::TensorShape({1, resized.rows, resized.cols, 3})); 37 | auto input_tensor_ptr = input_tensor.tensor(); 38 | 39 | for (int n=0;n<1;++n) 40 | for(int h = 0 ; h < resized.rows; ++h) 41 | for(int w = 0; w < resized.cols; ++w) 42 | for(int c = 0; c < 3; ++c){ 43 | input_tensor_ptr(n, h, w, c) = resized.at(h, w)[2 - c]; 44 | } 45 | 46 | inputs = { 47 | {"input_images:0", input_tensor} 48 | }; 49 | } 50 | 51 | void Detector::Predict(cv::Mat& inp) { 52 | 53 | FeedImageToTensor(inp); 54 | FetchTensor(inputs, outputs); 55 | 56 | std::map> contoursMap; 57 | PseAdaptor(outputs[0], contoursMap, 0.9, 10, 1); 58 | 59 | float scaleX = (float) inp.cols / resized.cols; 60 | float scaleY = (float) inp.rows / resized.rows; 61 | 62 | std::vector().swap(polygons); 63 | 64 | for (auto &cnt: contoursMap) { 65 | cv::Mat boxPts; 66 | cv::RotatedRect minRect = cv::minAreaRect(cnt.second); 67 | cv::boxPoints(minRect, boxPts); 68 | 69 | Polygon polygon(boxPts, cv::Size(inp.rows, inp.cols), scaleX, scaleY); 70 | polygons.emplace_back(polygon); 71 | } 72 | 73 | if (DEBUG) { 74 | cv::Mat 
tmp(inp); 75 | for (int i=0; i < polygons.size(); ++i) { 76 | cv::Mat quad; 77 | std::vector quad_pts = polygons[i].ToQuadROI(); 78 | cv::Mat transmtx = cv::getPerspectiveTransform(polygons[i].ToVec2f(), quad_pts); 79 | cv::warpPerspective(tmp, quad, transmtx, cv::Size((int)quad_pts[2].x, (int)quad_pts[2].y)); 80 | cv::resize(quad, quad, cv::Size(0, 0), 0.3, 0.3); 81 | cv::imshow(std::to_string(i), quad); 82 | cv::polylines(tmp, polygons[i].ToVec2i(), true, cv::Scalar(0, 0, 255), 2); 83 | } 84 | cv::resize(tmp, tmp, cv::Size(0, 0), 0.3, 0.3); 85 | cv::imshow("debug", tmp); 86 | cv::waitKey(0); 87 | } 88 | } 89 | 90 | void Detector::PseAdaptor(tf::Tensor& features, 91 | std::map>& contours_map, 92 | const float thresh, 93 | const float min_area, 94 | const float ratio) { 95 | 96 | /// get kernels 97 | auto features_ptr = features.tensor(); 98 | 99 | auto N = (int) features.dim_size(0); 100 | auto H = (int) features.dim_size(1); 101 | auto W = (int) features.dim_size(2); 102 | auto C = (int) features.dim_size(3); 103 | 104 | std::vector kernels; 105 | 106 | float _thresh = thresh; 107 | for (int n = 0; n < N; ++n) { 108 | for (int c = C - 1; c >= 0; --c) { 109 | cv::Mat kernel(H, W, CV_8UC1); 110 | for (int h = 0; h < H; ++h) { 111 | for (int w = 0; w < W; ++w) { 112 | if (features_ptr(n, h, w, c) > _thresh) { 113 | kernel.at(h, w) = 1; 114 | } else { 115 | kernel.at(h, w) = 0; 116 | } 117 | } 118 | } 119 | kernels.push_back(kernel); 120 | _thresh = thresh * ratio; 121 | } 122 | } 123 | 124 | /// make label 125 | cv::Mat label; 126 | std::map areas; 127 | cv::Mat mask(H, W, CV_32S, cv::Scalar(0)); 128 | cv::connectedComponents(kernels[C - 1], label, 4); 129 | 130 | for (int y = 0; y < label.rows; ++y) { 131 | for (int x = 0; x < label.cols; ++x) { 132 | int value = label.at(y, x); 133 | if (value == 0) continue; 134 | areas[value] += 1; 135 | } 136 | } 137 | 138 | std::queue queue, next_queue; 139 | 140 | for (int y = 0; y < label.rows; ++y) { 141 | for (int x = 0; x 
< label.cols; ++x) { 142 | int value = label.at(y, x); 143 | if (value == 0) continue; 144 | if (areas[value] < min_area) { 145 | areas.erase(value); 146 | continue; 147 | } 148 | cv::Point point(x, y); 149 | queue.push(point); 150 | mask.at(y, x) = value; 151 | } 152 | } 153 | 154 | /// growing text line 155 | int dx[] = {-1, 1, 0, 0}; 156 | int dy[] = {0, 0, -1, 1}; 157 | 158 | for (int idx = C - 2; idx >= 0; --idx) { 159 | while (!queue.empty()) { 160 | cv::Point point = queue.front(); queue.pop(); 161 | int x = point.x; 162 | int y = point.y; 163 | int value = mask.at(y, x); 164 | 165 | bool is_edge = true; 166 | for (int d = 0; d < 4; ++d) { 167 | int _x = x + dx[d]; 168 | int _y = y + dy[d]; 169 | 170 | if (_y < 0 || _y >= mask.rows) continue; 171 | if (_x < 0 || _x >= mask.cols) continue; 172 | if (kernels[idx].at(_y, _x) == 0) continue; 173 | if (mask.at(_y, _x) > 0) continue; 174 | 175 | cv::Point point_dxy(_x, _y); 176 | queue.push(point_dxy); 177 | 178 | mask.at(_y, _x) = value; 179 | is_edge = false; 180 | } 181 | 182 | if (is_edge) next_queue.push(point); 183 | } 184 | std::swap(queue, next_queue); 185 | } 186 | 187 | /// make contoursMap 188 | for (int y=0; y < mask.rows; ++y) 189 | for (int x=0; x < mask.cols; ++x) { 190 | int idx = mask.at(y, x); 191 | if (idx == 0) continue; 192 | contours_map[idx].emplace_back(cv::Point(x, y)); 193 | } 194 | } 195 | } 196 | -------------------------------------------------------------------------------- /src/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "detector.h" 5 | #include "recognizer.h" 6 | 7 | 8 | 9 | void DemoAll() { 10 | std::vector demo_images({"/media/seeta/新加卷2/zzsfp/发票驾驶证/JPEGImages/101.jpg", 11 | "/media/seeta/新加卷2/zzsfp/3.png", 12 | "/media/seeta/新加卷2/zzsfp/4.jpg", 13 | "../data/demo/1234.jpg", 14 | "../data/demo/321.jpg", 15 | "../data/demo/ktp1.jpg", 16 | 17 | "../data/demo/npwp1.jpg", 18 | 
"../data/demo/sim1.jpg", 19 | "../data/demo/img_911.jpg", 20 | "../data/demo/12345.jpg"}); 21 | clock_t start1, end1, start2, end2; 22 | 23 | SeetaOCR::Detector detector = SeetaOCR::Detector("../data/models/psenet.pb"); 24 | SeetaOCR::Recognizer recognizer = SeetaOCR::Recognizer("../data/models/crnn_mobilnet.pb", 25 | "../data/models/crnn.txt"); 26 | // detector.Debug(); 27 | // recognizer.Debug(); 28 | 29 | for (auto &image_path: demo_images) { 30 | 31 | cv::Mat image = cv::imread(image_path); 32 | 33 | start1 = clock(); 34 | std::vector polygons; 35 | detector.Predict(image, polygons); 36 | end1 = clock(); 37 | 38 | auto dur1 = (double)(end1 - start1); 39 | printf("Detector with Batchsize: %d, Use Time: %f s\n", 1, (dur1 / CLOCKS_PER_SEC)); 40 | 41 | std::vector ROImages; 42 | 43 | for (auto &p: polygons) { 44 | cv::Mat roi; 45 | std::vector quad_pts = p.ToQuadROI(); 46 | cv::Mat transmtx = cv::getPerspectiveTransform(p.ToVec2f(), quad_pts); 47 | cv::warpPerspective(image, roi, transmtx, cv::Size((int)quad_pts[2].x, (int)quad_pts[2].y)); 48 | ROImages.emplace_back(roi); 49 | cv::polylines(image, p.ToVec2i(), true, cv::Scalar(0, 0, 255), 2); 50 | } 51 | 52 | start2 = clock(); 53 | std::map> decoded; 54 | recognizer.Predict(ROImages, decoded); 55 | end2 = clock(); 56 | 57 | auto dur2 = (double)(end2 - start2); 58 | printf("Recognizer with Batchsize: %d, Use Time: %f s\n", (int) ROImages.size(), (dur2 / CLOCKS_PER_SEC)); 59 | 60 | for (auto &d: decoded) { 61 | if (d.second.second < 0.92) continue; 62 | int fontFace = CV_FONT_HERSHEY_SIMPLEX; 63 | double fontScale = 1; 64 | int thickness = 1; 65 | int baseline = 0; 66 | int lineType = 12; 67 | cv::Point2i textOrg = polygons[d.first].ToVec2i()[0]; 68 | std::string text = d.second.first; //+ " " + std::to_string(d.second.second); 69 | cv::Size textSize = cv::getTextSize(text, fontFace, fontScale, thickness, &baseline); 70 | //cv::rectangle(image, textOrg + cv::Point(0, baseline), textOrg + cv::Point(textSize.width, 
-textSize.height), 71 | // cv::Scalar(255, 255, 255), -1, lineType); 72 | //cv::putText(image, text, polygons[d.first].ToVec2i()[0], fontFace, fontScale, cv::Scalar(0, 0, 0), thickness, lineType); 73 | std::cout << d.first << " " << d.second.first << " " << d.second.second << std::endl; 74 | } 75 | cv::namedWindow("demo", CV_WINDOW_NORMAL); 76 | cv::imshow("demo", image); 77 | cv::waitKey(0); 78 | } 79 | } 80 | 81 | 82 | int main(){ 83 | DemoAll(); 84 | } -------------------------------------------------------------------------------- /src/polygon.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by seeta on 19-5-5. 3 | // 4 | 5 | #include "polygon.h" 6 | -------------------------------------------------------------------------------- /src/recognizer.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by xpc on 19-4-29. 3 | // 4 | 5 | #include 6 | #include "recognizer.h" 7 | 8 | namespace SeetaOCR { 9 | 10 | void Recognizer::LoadLabelFile(const std::string &label_file) { 11 | std::ifstream infile(label_file, std::ios::in); 12 | std::string line; 13 | int i = 0; 14 | 15 | std::map().swap(label); 16 | while (std::getline(infile, line)) { 17 | label[i] = line; 18 | i++; 19 | } 20 | infile.close(); 21 | } 22 | 23 | void Recognizer::FeedImageToTensor(cv::Mat &inp) { 24 | 25 | cv::Mat image; 26 | ResizeImage(inp, image); 27 | 28 | tf::Tensor input_image_tensor(tf::DT_FLOAT, tf::TensorShape({1, image.rows, image.cols, 3})); 29 | auto input_image_tensor_ptr = input_image_tensor.tensor(); 30 | 31 | for (int n = 0; n < 1; ++n) 32 | for (int h = 0; h < image.rows; ++h) 33 | for (int w = 0; w < image.cols; ++w) 34 | for (int c = 0; c < 3; ++c) { 35 | input_image_tensor_ptr(n, h, w, c) = image.at(h, w)[2 - c]; 36 | } 37 | 38 | tf::Tensor input_width_tensor(tf::DT_INT32, tf::TensorShape({1})); 39 | auto input_input_tensor_ptr = input_width_tensor.tensor(); 40 | 
input_input_tensor_ptr(0) = image.cols; 41 | 42 | inputs = { 43 | {"input_images:0", input_image_tensor}, 44 | {"input_widths:0", input_width_tensor} 45 | }; 46 | } 47 | 48 | void Recognizer::FeedImagesToTensor(std::vector &inp) { 49 | std::vector images; 50 | ResizeImages(inp, images); 51 | 52 | int N = (int)images.size(); 53 | int H = images[0].rows; 54 | int W = images[0].cols; 55 | int C = 3; 56 | 57 | tf::Tensor input_images_tensor(tf::DT_FLOAT, tf::TensorShape({N, H, W, C})); 58 | auto input_images_tensor_ptr = input_images_tensor.tensor(); 59 | 60 | tf::Tensor input_widths_tensor(tf::DT_INT32, tf::TensorShape({N})); 61 | auto input_widths_tensor_ptr = input_widths_tensor.tensor(); 62 | 63 | for (int n = 0; n < N; ++n) { 64 | input_widths_tensor_ptr(n) = W; 65 | for (int h = 0; h < H; ++h) { 66 | for (int w = 0; w < W; ++w) { 67 | for (int c = 0; c < 3; ++c) { 68 | input_images_tensor_ptr(n, h, w, c) = images[n].at(h, w)[2 - c]; // BGR -> RGB 69 | } 70 | } 71 | } 72 | } 73 | 74 | inputs = { 75 | {"input_images:0", input_images_tensor}, 76 | {"input_widths:0", input_widths_tensor} 77 | }; 78 | 79 | } 80 | 81 | void Recognizer::ResizeImage(cv::Mat &inp, cv::Mat &out) { 82 | int widthNew = (int)ceil(32.0 * inp.cols / inp.rows); 83 | 84 | cv::Mat resized; 85 | cv::resize(inp, resized, cv::Size(widthNew, 32)); 86 | 87 | int widthAlign = widthNew; 88 | 89 | if (widthAlign % 32 != 0) { 90 | widthAlign = 32 * ((int)floor(widthAlign / 32.0) + 1); 91 | } 92 | 93 | cv::Mat data(32, widthAlign, CV_8UC3, cv::Scalar(0, 0, 0)); 94 | 95 | for(int h = 0 ; h < resized.rows; ++h) 96 | for(int w = 0; w < resized.cols; ++w) 97 | for(int c = 0; c < 3; ++c){ 98 | data.at(h, w)[c] = resized.at(h, w)[c]; 99 | } 100 | out = data; 101 | } 102 | 103 | void Recognizer::ResizeImages(std::vector &inp, std::vector &out) { 104 | 105 | int widthMax = 0; 106 | std::vector widths; 107 | for (auto &img: inp) { 108 | int widthNew = (int)ceil(32.0 * img.cols / img.rows); 109 | 
widths.push_back(widthNew); 110 | widthMax = widthNew > widthMax ? widthNew : widthMax; 111 | } 112 | 113 | if (widthMax % 32 != 0) { 114 | widthMax = 32 * ((int)floor(widthMax / 32.0) + 1); 115 | } 116 | 117 | cv::Mat resized; 118 | for (int i = 0; i < widths.size(); ++i) { 119 | cv::resize(inp[i], resized, cv::Size(widths[i], 32)); 120 | cv::Mat tmp(32, widthMax, CV_8UC3, cv::Scalar(0, 0, 0)); 121 | for(int h = 0 ; h < resized.rows; ++h) 122 | for(int w = 0; w < resized.cols; ++w) 123 | for(int c = 0; c < 3; ++c){ 124 | tmp.at(h, w)[c] = resized.at(h, w)[c]; 125 | } 126 | out.push_back(tmp); 127 | } 128 | } 129 | 130 | void Recognizer::Predict() { 131 | FetchTensor(inputs, outputs); 132 | Decode(outputs[0], outputs[1], outputs[2]); 133 | } 134 | 135 | void Recognizer::Decode(tf::Tensor& indices, tf::Tensor& values, tf::Tensor& probs) { 136 | 137 | std::map>().swap(decoded); 138 | 139 | auto indices_ptr = indices.tensor(); 140 | auto values_ptr = values.tensor(); 141 | auto probs_ptr = probs.tensor(); 142 | 143 | int idx = 0; 144 | std::string str; 145 | 146 | for (int i = 0; i < indices.dim_size(0) ; ++i){ 147 | 148 | if (i == indices.dim_size(0) - 1) { 149 | str += label[values_ptr(i)]; 150 | decoded[idx] = std::make_pair(str, probs_ptr(idx)); 151 | break; 152 | } 153 | 154 | if (idx != indices_ptr(i, 0)){ 155 | decoded[idx] = std::make_pair(str, probs_ptr(idx)); 156 | str = ""; 157 | idx = indices_ptr(i, 0); 158 | } 159 | 160 | str += label[values_ptr(i)]; 161 | } 162 | 163 | if (DEBUG) { 164 | for (auto &d: decoded) { 165 | std::cout << d.first << " " << d.second.first << " " << d.second.second << std::endl; 166 | } 167 | } 168 | } 169 | } --------------------------------------------------------------------------------