├── lanms ├── .gitignore ├── include │ ├── clipper │ │ ├── clipper.cpp │ │ └── clipper.hpp │ └── pybind11 │ │ ├── typeid.h │ │ ├── complex.h │ │ ├── options.h │ │ ├── functional.h │ │ ├── eval.h │ │ ├── buffer_info.h │ │ ├── chrono.h │ │ ├── embed.h │ │ ├── descr.h │ │ ├── operators.h │ │ ├── stl.h │ │ ├── attr.h │ │ ├── stl_bind.h │ │ └── class_support.h ├── __pycache__ │ └── __init__.cpython-36.pyc ├── __main__.py ├── Makefile ├── __init__.py ├── adaptor.cpp ├── .ycm_extra_conf.py └── lanms.h ├── README.md ├── locality_aware_nms.py ├── loss.py ├── main.py ├── model.py ├── eval.py └── data_utils.py /lanms/.gitignore: -------------------------------------------------------------------------------- 1 | adaptor.so 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EAST 2 | Reappearance 'EAST:An Efficient and Accurate Scene Text Detector' with pytorch 3 | -------------------------------------------------------------------------------- /lanms/include/clipper/clipper.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kathrine94/EAST/HEAD/lanms/include/clipper/clipper.cpp -------------------------------------------------------------------------------- /lanms/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kathrine94/EAST/HEAD/lanms/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /lanms/__main__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | from . 
import merge_quadrangle_n9 5 | 6 | if __name__ == '__main__': 7 | # unit square with confidence 1 8 | q = np.array([0, 0, 0, 1, 1, 1, 1, 0, 1], dtype='float32') 9 | 10 | print(merge_quadrangle_n9(np.array([q, q + 0.1, q + 2]))) 11 | -------------------------------------------------------------------------------- /lanms/Makefile: -------------------------------------------------------------------------------- 1 | CXXFLAGS = -I include -std=c++11 -O3 $(shell python3-config --cflags) 2 | LDFLAGS = $(shell python3-config --ldflags) 3 | 4 | DEPS = lanms.h $(shell find include -xtype f) 5 | CXX_SOURCES = adaptor.cpp include/clipper/clipper.cpp 6 | 7 | LIB_SO = adaptor.so 8 | 9 | $(LIB_SO): $(CXX_SOURCES) $(DEPS) 10 | $(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC 11 | 12 | clean: 13 | rm -rf $(LIB_SO) 14 | -------------------------------------------------------------------------------- /lanms/__init__.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | import numpy as np 4 | 5 | BASE_DIR = os.path.dirname(os.path.realpath(__file__)) 6 | 7 | if subprocess.call(['make', '-C', BASE_DIR]) != 0: # return value 8 | raise RuntimeError('Cannot compile lanms: {}'.format(BASE_DIR)) 9 | 10 | 11 | def merge_quadrangle_n9(polys, thres=0.3, precision=10000): 12 | from .adaptor import merge_quadrangle_n9 as nms_impl 13 | if len(polys) == 0: 14 | return np.array([], dtype='float32') 15 | p = polys.copy() 16 | p[:,:8] *= precision 17 | ret = np.array(nms_impl(p, thres), dtype='float32') 18 | ret[:,:8] /= precision 19 | return ret 20 | 21 | -------------------------------------------------------------------------------- /lanms/include/pybind11/typeid.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/typeid.h: Compiler-independent access to type identifiers 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. 
Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include 13 | #include 14 | 15 | #if defined(__GNUG__) 16 | #include 17 | #endif 18 | 19 | NAMESPACE_BEGIN(pybind11) 20 | NAMESPACE_BEGIN(detail) 21 | /// Erase all occurrences of a substring 22 | inline void erase_all(std::string &string, const std::string &search) { 23 | for (size_t pos = 0;;) { 24 | pos = string.find(search, pos); 25 | if (pos == std::string::npos) break; 26 | string.erase(pos, search.length()); 27 | } 28 | } 29 | 30 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) { 31 | #if defined(__GNUG__) 32 | int status = 0; 33 | std::unique_ptr res { 34 | abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free }; 35 | if (status == 0) 36 | name = res.get(); 37 | #else 38 | detail::erase_all(name, "class "); 39 | detail::erase_all(name, "struct "); 40 | detail::erase_all(name, "enum "); 41 | #endif 42 | detail::erase_all(name, "pybind11::"); 43 | } 44 | NAMESPACE_END(detail) 45 | 46 | /// Return a string representation of a C++ type 47 | template static std::string type_id() { 48 | std::string name(typeid(T).name()); 49 | detail::clean_type_id(name); 50 | return name; 51 | } 52 | 53 | NAMESPACE_END(pybind11) 54 | -------------------------------------------------------------------------------- /locality_aware_nms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from shapely.geometry import Polygon 3 | 4 | 5 | def intersection(g, p): 6 | g = Polygon(g[:8].reshape((4, 2))) 7 | p = Polygon(p[:8].reshape((4, 2))) 8 | if not g.is_valid or not p.is_valid: 9 | return 0 10 | inter = Polygon(g).intersection(Polygon(p)).area 11 | union = g.area + p.area - inter 12 | if union == 0: 13 | return 0 14 | else: 15 | return inter/union 16 | 17 | 18 | def weighted_merge(g, p): 19 | g[:8] = (g[8] * g[:8] + p[8] * p[:8])/(g[8] + p[8]) 20 | g[8] 
= (g[8] + p[8]) 21 | return g 22 | 23 | 24 | def standard_nms(S, thres): 25 | order = np.argsort(S[:, 8])[::-1] 26 | keep = [] 27 | while order.size > 0: 28 | i = order[0] 29 | keep.append(i) 30 | ovr = np.array([intersection(S[i], S[t]) for t in order[1:]]) 31 | 32 | inds = np.where(ovr <= thres)[0] 33 | order = order[inds+1] 34 | 35 | return S[keep] 36 | 37 | 38 | def nms_locality(polys, thres=0.3): 39 | ''' 40 | locality aware nms of EAST 41 | :param polys: a N*9 numpy array. first 8 coordinates, then prob 42 | :return: boxes after nms 43 | ''' 44 | S = [] 45 | p = None 46 | for g in polys: 47 | if p is not None and intersection(g, p) > thres: 48 | p = weighted_merge(g, p) 49 | else: 50 | if p is not None: 51 | S.append(p) 52 | p = g 53 | if p is not None: 54 | S.append(p) 55 | 56 | if len(S) == 0: 57 | return np.array([]) 58 | return standard_nms(np.array(S), thres) 59 | 60 | 61 | if __name__ == '__main__': 62 | # 343,350,448,135,474,143,369,359 63 | print(Polygon(np.array([[343, 350], [448, 135], 64 | [474, 143], [369, 359]])).area) 65 | -------------------------------------------------------------------------------- /lanms/adaptor.cpp: -------------------------------------------------------------------------------- 1 | #include "pybind11/pybind11.h" 2 | #include "pybind11/numpy.h" 3 | #include "pybind11/stl.h" 4 | #include "pybind11/stl_bind.h" 5 | 6 | #include "lanms.h" 7 | 8 | namespace py = pybind11; 9 | 10 | 11 | namespace lanms_adaptor { 12 | 13 | std::vector> polys2floats(const std::vector &polys) { 14 | std::vector> ret; 15 | for (size_t i = 0; i < polys.size(); i ++) { 16 | auto &p = polys[i]; 17 | auto &poly = p.poly; 18 | ret.emplace_back(std::vector{ 19 | float(poly[0].X), float(poly[0].Y), 20 | float(poly[1].X), float(poly[1].Y), 21 | float(poly[2].X), float(poly[2].Y), 22 | float(poly[3].X), float(poly[3].Y), 23 | float(p.score), 24 | }); 25 | } 26 | 27 | return ret; 28 | } 29 | 30 | 31 | /** 32 | * 33 | * \param quad_n9 an n-by-9 numpy array, 
where first 8 numbers denote the 34 | * quadrangle, and the last one is the score 35 | * \param iou_threshold two quadrangles with iou score above this threshold 36 | * will be merged 37 | * 38 | * \return an n-by-9 numpy array, the merged quadrangles 39 | */ 40 | std::vector> merge_quadrangle_n9( 41 | py::array_t quad_n9, 42 | float iou_threshold) { 43 | auto pbuf = quad_n9.request(); 44 | if (pbuf.ndim != 2 || pbuf.shape[1] != 9) 45 | throw std::runtime_error("quadrangles must have a shape of (n, 9)"); 46 | auto n = pbuf.shape[0]; 47 | auto ptr = static_cast(pbuf.ptr); 48 | return polys2floats(lanms::merge_quadrangle_n9(ptr, n, iou_threshold)); 49 | } 50 | 51 | } 52 | 53 | PYBIND11_PLUGIN(adaptor) { 54 | py::module m("adaptor", "NMS"); 55 | 56 | m.def("merge_quadrangle_n9", &lanms_adaptor::merge_quadrangle_n9, 57 | "merge quadrangels"); 58 | 59 | return m.ptr(); 60 | } 61 | 62 | -------------------------------------------------------------------------------- /lanms/include/pybind11/complex.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/complex.h: Complex number support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | /// glibc defines I as a macro which breaks things, e.g., boost template names 16 | #ifdef I 17 | # undef I 18 | #endif 19 | 20 | NAMESPACE_BEGIN(pybind11) 21 | 22 | template struct format_descriptor, detail::enable_if_t::value>> { 23 | static constexpr const char c = format_descriptor::c; 24 | static constexpr const char value[3] = { 'Z', c, '\0' }; 25 | static std::string format() { return std::string(value); } 26 | }; 27 | 28 | template constexpr const char format_descriptor< 29 | std::complex, detail::enable_if_t::value>>::value[3]; 30 | 31 | NAMESPACE_BEGIN(detail) 32 | 33 | template struct is_fmt_numeric, detail::enable_if_t::value>> { 34 | static constexpr bool value = true; 35 | static constexpr int index = is_fmt_numeric::index + 3; 36 | }; 37 | 38 | template class type_caster> { 39 | public: 40 | bool load(handle src, bool convert) { 41 | if (!src) 42 | return false; 43 | if (!convert && !PyComplex_Check(src.ptr())) 44 | return false; 45 | Py_complex result = PyComplex_AsCComplex(src.ptr()); 46 | if (result.real == -1.0 && PyErr_Occurred()) { 47 | PyErr_Clear(); 48 | return false; 49 | } 50 | value = std::complex((T) result.real, (T) result.imag); 51 | return true; 52 | } 53 | 54 | static handle cast(const std::complex &src, return_value_policy /* policy */, handle /* parent */) { 55 | return PyComplex_FromDoubles((double) src.real(), (double) src.imag()); 56 | } 57 | 58 | PYBIND11_TYPE_CASTER(std::complex, _("complex")); 59 | }; 60 | NAMESPACE_END(detail) 61 | NAMESPACE_END(pybind11) 62 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | 4 | 5 | ### 此处默认真实值和预测值的格式均为 bs * W * H * channels 6 | import torch 7 | import torch.nn as nn 8 | 9 | def dice_coefficient(y_true_cls, y_pred_cls, 10 | 
training_mask): 11 | ''' 12 | dice loss 13 | :param y_true_cls: 14 | :param y_pred_cls: 15 | :param training_mask: 16 | :return: 17 | ''' 18 | eps = 1e-5 19 | intersection =torch.sum(y_true_cls * y_pred_cls * training_mask) 20 | union = torch.sum(y_true_cls * training_mask) + torch.sum(y_pred_cls * training_mask) + eps 21 | loss = 1. - (2 * intersection / union) 22 | 23 | return loss 24 | 25 | class LossFunc(nn.Module): 26 | def __init__(self): 27 | super(LossFunc, self).__init__() 28 | return 29 | 30 | def forward(self, y_true_cls, y_pred_cls, 31 | y_true_geo, y_pred_geo, 32 | training_mask): 33 | classification_loss = dice_coefficient(y_true_cls, y_pred_cls, training_mask) 34 | # scale classification loss to match the iou loss part 35 | classification_loss *= 0.01 36 | 37 | # d1 -> top, d2->right, d3->bottom, d4->left 38 | # d1_gt, d2_gt, d3_gt, d4_gt, theta_gt = tf.split(value=y_true_geo, num_or_size_splits=5, axis=3) 39 | d1_gt, d2_gt, d3_gt, d4_gt, theta_gt = torch.split(y_true_geo, 1, 1) 40 | # d1_pred, d2_pred, d3_pred, d4_pred, theta_pred = tf.split(value=y_pred_geo, num_or_size_splits=5, axis=3) 41 | d1_pred, d2_pred, d3_pred, d4_pred, theta_pred = torch.split(y_pred_geo, 1, 1) 42 | area_gt = (d1_gt + d3_gt) * (d2_gt + d4_gt) 43 | area_pred = (d1_pred + d3_pred) * (d2_pred + d4_pred) 44 | w_union = torch.min(d2_gt, d2_pred) + torch.min(d4_gt, d4_pred) 45 | h_union = torch.min(d1_gt, d1_pred) + torch.min(d3_gt, d3_pred) 46 | area_intersect = w_union * h_union 47 | area_union = area_gt + area_pred - area_intersect 48 | L_AABB = -torch.log((area_intersect + 1.0)/(area_union + 1.0)) 49 | L_theta = 1 - torch.cos(theta_pred - theta_gt) 50 | L_g = L_AABB + 20 * L_theta 51 | 52 | return torch.mean(L_g * y_true_cls * training_mask) + classification_loss 53 | 54 | -------------------------------------------------------------------------------- /lanms/include/pybind11/options.h: -------------------------------------------------------------------------------- 1 | /* 2 
| pybind11/options.h: global settings that are configurable at runtime. 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "common.h" 13 | 14 | NAMESPACE_BEGIN(pybind11) 15 | 16 | class options { 17 | public: 18 | 19 | // Default RAII constructor, which leaves settings as they currently are. 20 | options() : previous_state(global_state()) {} 21 | 22 | // Class is non-copyable. 23 | options(const options&) = delete; 24 | options& operator=(const options&) = delete; 25 | 26 | // Destructor, which restores settings that were in effect before. 27 | ~options() { 28 | global_state() = previous_state; 29 | } 30 | 31 | // Setter methods (affect the global state): 32 | 33 | options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; } 34 | 35 | options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; } 36 | 37 | options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; } 38 | 39 | options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; } 40 | 41 | // Getter methods (return the global state): 42 | 43 | static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; } 44 | 45 | static bool show_function_signatures() { return global_state().show_function_signatures; } 46 | 47 | // This type is not meant to be allocated on the heap. 48 | void* operator new(size_t) = delete; 49 | 50 | private: 51 | 52 | struct state { 53 | bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings. 54 | bool show_function_signatures = true; //< Include auto-generated function signatures in docstrings. 
55 | }; 56 | 57 | static state &global_state() { 58 | static state instance; 59 | return instance; 60 | } 61 | 62 | state previous_state; 63 | }; 64 | 65 | NAMESPACE_END(pybind11) 66 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.autograd import Variable 4 | import os 5 | from torch import nn 6 | from torch.optim import lr_scheduler 7 | from torch.nn.utils.rnn import pack_padded_sequence 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | from model import East 11 | from loss import * 12 | from data_utils import custom_dset, collate_fn 13 | import time 14 | from tensorboardX import SummaryWriter 15 | 16 | writer = SummaryWriter() 17 | 18 | 19 | 20 | def train(epochs, model, trainloader, crit, optimizer, 21 | scheduler, save_step, weight_decay): 22 | 23 | for e in range(epochs): 24 | print('*'* 10) 25 | print('Epoch {} / {}'.format(e + 1, epochs)) 26 | model.train() 27 | start = time.time() 28 | loss = 0.0 29 | total = 0.0 30 | for i, (img, score_map, geo_map, training_mask) in enumerate(trainloader): 31 | scheduler.step() 32 | optimizer.zero_grad() 33 | 34 | img = Variable(img.cuda()) 35 | score_map = Variable(score_map.cuda()) 36 | geo_map = Variable(geo_map.cuda()) 37 | training_mask = Variable(training_mask.cuda()) 38 | f_score, f_geometry = model(img) 39 | loss1 = crit(score_map, f_score, geo_map, f_geometry, training_mask) 40 | 41 | loss += loss1.data[0] 42 | 43 | loss1.backward() 44 | optimizer.step() 45 | 46 | during = time.time() - start 47 | print("Loss : {:.6f}, Time:{:.2f} s ".format(loss/len(trainloader), during)) 48 | print() 49 | writer.add_scalar('loss', loss / len(trainloader), e) 50 | 51 | if (e + 1) % save_step == 0: 52 | if not os.path.exists('./checkpoints'): 53 | os.mkdir('./checkpoints') 54 | torch.save(model.state_dict(), 
'./checkpoints/model_{}.pth'.format(e + 1)) 55 | 56 | 57 | def main(): 58 | root_path = '/home/mathu/Documents/express_recognition/data/telephone_txt/result/' 59 | train_img = root_path + 'print_pic' 60 | train_txt = root_path + 'print_txt' 61 | # root_path = '/home/mathu/Documents/express_recognition/data/icdar2015/' 62 | # train_img = root_path + 'train2015' 63 | # train_txt = root_path + 'train_label' 64 | 65 | trainset = custom_dset(train_img, train_txt) 66 | trainloader = DataLoader( 67 | trainset, batch_size=16, shuffle=True, collate_fn=collate_fn, num_workers=4) 68 | model = East() 69 | model = model.cuda() 70 | model.load_state_dict(torch.load('./checkpoints_total/model_1440.pth')) 71 | 72 | crit = LossFunc() 73 | weight_decay = 0 74 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) 75 | # weight_decay=1) 76 | scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, 77 | gamma=0.94) 78 | 79 | train(epochs=1500, model=model, trainloader=trainloader, 80 | crit=crit, optimizer=optimizer,scheduler=scheduler, 81 | save_step=20, weight_decay=weight_decay) 82 | 83 | write.close() 84 | 85 | if __name__ == "__main__": 86 | main() 87 | -------------------------------------------------------------------------------- /lanms/include/pybind11/functional.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/functional.h: std::function<> support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | NAMESPACE_BEGIN(pybind11) 16 | NAMESPACE_BEGIN(detail) 17 | 18 | template 19 | struct type_caster> { 20 | using type = std::function; 21 | using retval_type = conditional_t::value, void_type, Return>; 22 | using function_type = Return (*) (Args...); 23 | 24 | public: 25 | bool load(handle src, bool convert) { 26 | if (src.is_none()) { 27 | // Defer accepting None to other overloads (if we aren't in convert mode): 28 | if (!convert) return false; 29 | return true; 30 | } 31 | 32 | if (!isinstance(src)) 33 | return false; 34 | 35 | auto func = reinterpret_borrow(src); 36 | 37 | /* 38 | When passing a C++ function as an argument to another C++ 39 | function via Python, every function call would normally involve 40 | a full C++ -> Python -> C++ roundtrip, which can be prohibitive. 41 | Here, we try to at least detect the case where the function is 42 | stateless (i.e. function pointer or lambda function without 43 | captured variables), in which case the roundtrip can be avoided. 44 | */ 45 | if (auto cfunc = func.cpp_function()) { 46 | auto c = reinterpret_borrow(PyCFunction_GET_SELF(cfunc.ptr())); 47 | auto rec = (function_record *) c; 48 | 49 | if (rec && rec->is_stateless && 50 | same_type(typeid(function_type), *reinterpret_cast(rec->data[1]))) { 51 | struct capture { function_type f; }; 52 | value = ((capture *) &rec->data)->f; 53 | return true; 54 | } 55 | } 56 | 57 | value = [func](Args... 
args) -> Return { 58 | gil_scoped_acquire acq; 59 | object retval(func(std::forward(args)...)); 60 | /* Visual studio 2015 parser issue: need parentheses around this expression */ 61 | return (retval.template cast()); 62 | }; 63 | return true; 64 | } 65 | 66 | template 67 | static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) { 68 | if (!f_) 69 | return none().inc_ref(); 70 | 71 | auto result = f_.template target(); 72 | if (result) 73 | return cpp_function(*result, policy).release(); 74 | else 75 | return cpp_function(std::forward(f_), policy).release(); 76 | } 77 | 78 | PYBIND11_TYPE_CASTER(type, _("Callable[[") + 79 | argument_loader::arg_names() + _("], ") + 80 | make_caster::name() + 81 | _("]")); 82 | }; 83 | 84 | NAMESPACE_END(detail) 85 | NAMESPACE_END(pybind11) 86 | -------------------------------------------------------------------------------- /lanms/include/pybind11/eval.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/exec.h: Support for evaluating Python expressions and statements 3 | from strings and files 4 | 5 | Copyright (c) 2016 Klemens Morgenstern and 6 | Wenzel Jakob 7 | 8 | All rights reserved. Use of this source code is governed by a 9 | BSD-style license that can be found in the LICENSE file. 10 | */ 11 | 12 | #pragma once 13 | 14 | #include "pybind11.h" 15 | 16 | NAMESPACE_BEGIN(pybind11) 17 | 18 | enum eval_mode { 19 | /// Evaluate a string containing an isolated expression 20 | eval_expr, 21 | 22 | /// Evaluate a string containing a single statement. Returns \c none 23 | eval_single_statement, 24 | 25 | /// Evaluate a string containing a sequence of statement. 
Returns \c none 26 | eval_statements 27 | }; 28 | 29 | template 30 | object eval(str expr, object global = globals(), object local = object()) { 31 | if (!local) 32 | local = global; 33 | 34 | /* PyRun_String does not accept a PyObject / encoding specifier, 35 | this seems to be the only alternative */ 36 | std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr; 37 | 38 | int start; 39 | switch (mode) { 40 | case eval_expr: start = Py_eval_input; break; 41 | case eval_single_statement: start = Py_single_input; break; 42 | case eval_statements: start = Py_file_input; break; 43 | default: pybind11_fail("invalid evaluation mode"); 44 | } 45 | 46 | PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr()); 47 | if (!result) 48 | throw error_already_set(); 49 | return reinterpret_steal(result); 50 | } 51 | 52 | template 53 | object eval(const char (&s)[N], object global = globals(), object local = object()) { 54 | /* Support raw string literals by removing common leading whitespace */ 55 | auto expr = (s[0] == '\n') ? 
str(module::import("textwrap").attr("dedent")(s)) 56 | : str(s); 57 | return eval(expr, global, local); 58 | } 59 | 60 | inline void exec(str expr, object global = globals(), object local = object()) { 61 | eval(expr, global, local); 62 | } 63 | 64 | template 65 | void exec(const char (&s)[N], object global = globals(), object local = object()) { 66 | eval(s, global, local); 67 | } 68 | 69 | template 70 | object eval_file(str fname, object global = globals(), object local = object()) { 71 | if (!local) 72 | local = global; 73 | 74 | int start; 75 | switch (mode) { 76 | case eval_expr: start = Py_eval_input; break; 77 | case eval_single_statement: start = Py_single_input; break; 78 | case eval_statements: start = Py_file_input; break; 79 | default: pybind11_fail("invalid evaluation mode"); 80 | } 81 | 82 | int closeFile = 1; 83 | std::string fname_str = (std::string) fname; 84 | #if PY_VERSION_HEX >= 0x03040000 85 | FILE *f = _Py_fopen_obj(fname.ptr(), "r"); 86 | #elif PY_VERSION_HEX >= 0x03000000 87 | FILE *f = _Py_fopen(fname.ptr(), "r"); 88 | #else 89 | /* No unicode support in open() :( */ 90 | auto fobj = reinterpret_steal(PyFile_FromString( 91 | const_cast(fname_str.c_str()), 92 | const_cast("r"))); 93 | FILE *f = nullptr; 94 | if (fobj) 95 | f = PyFile_AsFile(fobj.ptr()); 96 | closeFile = 0; 97 | #endif 98 | if (!f) { 99 | PyErr_Clear(); 100 | pybind11_fail("File \"" + fname_str + "\" could not be opened!"); 101 | } 102 | 103 | #if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION) 104 | PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(), 105 | local.ptr()); 106 | (void) closeFile; 107 | #else 108 | PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(), 109 | local.ptr(), closeFile); 110 | #endif 111 | 112 | if (!result) 113 | throw error_already_set(); 114 | return reinterpret_steal(result); 115 | } 116 | 117 | NAMESPACE_END(pybind11) 118 | 
-------------------------------------------------------------------------------- /lanms/include/pybind11/buffer_info.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/buffer_info.h: Python buffer object interface 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "common.h" 13 | 14 | NAMESPACE_BEGIN(pybind11) 15 | 16 | /// Information record describing a Python buffer object 17 | struct buffer_info { 18 | void *ptr = nullptr; // Pointer to the underlying storage 19 | ssize_t itemsize = 0; // Size of individual items in bytes 20 | ssize_t size = 0; // Total number of entries 21 | std::string format; // For homogeneous buffers, this should be set to format_descriptor::format() 22 | ssize_t ndim = 0; // Number of dimensions 23 | std::vector shape; // Shape of the tensor (1 entry per dimension) 24 | std::vector strides; // Number of entries between adjacent entries (for each per dimension) 25 | 26 | buffer_info() { } 27 | 28 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, 29 | detail::any_container shape_in, detail::any_container strides_in) 30 | : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim), 31 | shape(std::move(shape_in)), strides(std::move(strides_in)) { 32 | if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size()) 33 | pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length"); 34 | for (size_t i = 0; i < (size_t) ndim; ++i) 35 | size *= shape[i]; 36 | } 37 | 38 | template 39 | buffer_info(T *ptr, detail::any_container shape_in, detail::any_container strides_in) 40 | : buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor::format(), static_cast(shape_in->size()), std::move(shape_in), std::move(strides_in)) { } 41 | 42 | buffer_info(void *ptr, 
ssize_t itemsize, const std::string &format, ssize_t size) 43 | : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { } 44 | 45 | template 46 | buffer_info(T *ptr, ssize_t size) 47 | : buffer_info(ptr, sizeof(T), format_descriptor::format(), size) { } 48 | 49 | explicit buffer_info(Py_buffer *view, bool ownview = true) 50 | : buffer_info(view->buf, view->itemsize, view->format, view->ndim, 51 | {view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) { 52 | this->view = view; 53 | this->ownview = ownview; 54 | } 55 | 56 | buffer_info(const buffer_info &) = delete; 57 | buffer_info& operator=(const buffer_info &) = delete; 58 | 59 | buffer_info(buffer_info &&other) { 60 | (*this) = std::move(other); 61 | } 62 | 63 | buffer_info& operator=(buffer_info &&rhs) { 64 | ptr = rhs.ptr; 65 | itemsize = rhs.itemsize; 66 | size = rhs.size; 67 | format = std::move(rhs.format); 68 | ndim = rhs.ndim; 69 | shape = std::move(rhs.shape); 70 | strides = std::move(rhs.strides); 71 | std::swap(view, rhs.view); 72 | std::swap(ownview, rhs.ownview); 73 | return *this; 74 | } 75 | 76 | ~buffer_info() { 77 | if (view && ownview) { PyBuffer_Release(view); delete view; } 78 | } 79 | 80 | private: 81 | struct private_ctr_tag { }; 82 | 83 | buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, 84 | detail::any_container &&shape_in, detail::any_container &&strides_in) 85 | : buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { } 86 | 87 | Py_buffer *view = nullptr; 88 | bool ownview = false; 89 | }; 90 | 91 | NAMESPACE_BEGIN(detail) 92 | 93 | template struct compare_buffer_info { 94 | static bool compare(const buffer_info& b) { 95 | return b.format == format_descriptor::format() && b.itemsize == (ssize_t) sizeof(T); 96 | } 97 | }; 98 | 99 | template struct compare_buffer_info::value>> { 100 | static bool compare(const buffer_info& b) { 101 | return (size_t) b.itemsize == 
sizeof(T) && (b.format == format_descriptor::value || 102 | ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned::value ? "L" : "l")) || 103 | ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned::value ? "N" : "n"))); 104 | } 105 | }; 106 | 107 | NAMESPACE_END(detail) 108 | NAMESPACE_END(pybind11) 109 | -------------------------------------------------------------------------------- /lanms/.ycm_extra_conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Copyright (C) 2014 Google Inc. 4 | # 5 | # This file is part of YouCompleteMe. 6 | # 7 | # YouCompleteMe is free software: you can redistribute it and/or modify 8 | # it under the terms of the GNU General Public License as published by 9 | # the Free Software Foundation, either version 3 of the License, or 10 | # (at your option) any later version. 11 | # 12 | # YouCompleteMe is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | # GNU General Public License for more details. 16 | # 17 | # You should have received a copy of the GNU General Public License 18 | # along with YouCompleteMe. If not, see . 19 | 20 | import os 21 | import sys 22 | import glob 23 | import ycm_core 24 | 25 | # These are the compilation flags that will be used in case there's no 26 | # compilation database set (by default, one is not set). 27 | # CHANGE THIS LIST OF FLAGS. YES, THIS IS THE DROID YOU HAVE BEEN LOOKING FOR. 
28 | sys.path.append(os.path.dirname(__file__)) 29 | 30 | 31 | BASE_DIR = os.path.dirname(os.path.realpath(__file__)) 32 | 33 | from plumbum.cmd import python_config 34 | 35 | 36 | flags = [ 37 | '-Wall', 38 | '-Wextra', 39 | '-Wnon-virtual-dtor', 40 | '-Winvalid-pch', 41 | '-Wno-unused-local-typedefs', 42 | '-std=c++11', 43 | '-x', 'c++', 44 | '-Iinclude', 45 | ] + python_config('--cflags').split() 46 | 47 | 48 | # Set this to the absolute path to the folder (NOT the file!) containing the 49 | # compile_commands.json file to use that instead of 'flags'. See here for 50 | # more details: http://clang.llvm.org/docs/JSONCompilationDatabase.html 51 | # 52 | # Most projects will NOT need to set this to anything; you can just change the 53 | # 'flags' list of compilation flags. 54 | compilation_database_folder = '' 55 | 56 | if os.path.exists( compilation_database_folder ): 57 | database = ycm_core.CompilationDatabase( compilation_database_folder ) 58 | else: 59 | database = None 60 | 61 | SOURCE_EXTENSIONS = [ '.cpp', '.cxx', '.cc', '.c', '.m', '.mm' ] 62 | 63 | def DirectoryOfThisScript(): 64 | return os.path.dirname( os.path.abspath( __file__ ) ) 65 | 66 | 67 | def MakeRelativePathsInFlagsAbsolute( flags, working_directory ): 68 | if not working_directory: 69 | return list( flags ) 70 | new_flags = [] 71 | make_next_absolute = False 72 | path_flags = [ '-isystem', '-I', '-iquote', '--sysroot=' ] 73 | for flag in flags: 74 | new_flag = flag 75 | 76 | if make_next_absolute: 77 | make_next_absolute = False 78 | if not flag.startswith( '/' ): 79 | new_flag = os.path.join( working_directory, flag ) 80 | 81 | for path_flag in path_flags: 82 | if flag == path_flag: 83 | make_next_absolute = True 84 | break 85 | 86 | if flag.startswith( path_flag ): 87 | path = flag[ len( path_flag ): ] 88 | new_flag = path_flag + os.path.join( working_directory, path ) 89 | break 90 | 91 | if new_flag: 92 | new_flags.append( new_flag ) 93 | return new_flags 94 | 95 | 96 | def IsHeaderFile( 
filename ): 97 | extension = os.path.splitext( filename )[ 1 ] 98 | return extension in [ '.h', '.hxx', '.hpp', '.hh' ] 99 | 100 | 101 | def GetCompilationInfoForFile( filename ): 102 | # The compilation_commands.json file generated by CMake does not have entries 103 | # for header files. So we do our best by asking the db for flags for a 104 | # corresponding source file, if any. If one exists, the flags for that file 105 | # should be good enough. 106 | if IsHeaderFile( filename ): 107 | basename = os.path.splitext( filename )[ 0 ] 108 | for extension in SOURCE_EXTENSIONS: 109 | replacement_file = basename + extension 110 | if os.path.exists( replacement_file ): 111 | compilation_info = database.GetCompilationInfoForFile( 112 | replacement_file ) 113 | if compilation_info.compiler_flags_: 114 | return compilation_info 115 | return None 116 | return database.GetCompilationInfoForFile( filename ) 117 | 118 | 119 | # This is the entry point; this function is called by ycmd to produce flags for 120 | # a file. 
121 | def FlagsForFile( filename, **kwargs ): 122 | if database: 123 | # Bear in mind that compilation_info.compiler_flags_ does NOT return a 124 | # python list, but a "list-like" StringVec object 125 | compilation_info = GetCompilationInfoForFile( filename ) 126 | if not compilation_info: 127 | return None 128 | 129 | final_flags = MakeRelativePathsInFlagsAbsolute( 130 | compilation_info.compiler_flags_, 131 | compilation_info.compiler_working_dir_ ) 132 | else: 133 | relative_to = DirectoryOfThisScript() 134 | final_flags = MakeRelativePathsInFlagsAbsolute( flags, relative_to ) 135 | 136 | return { 137 | 'flags': final_flags, 138 | 'do_cache': True 139 | } 140 | 141 | -------------------------------------------------------------------------------- /lanms/lanms.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "clipper/clipper.hpp" 4 | 5 | // locality-aware NMS 6 | namespace lanms { 7 | 8 | namespace cl = ClipperLib; 9 | 10 | struct Polygon { 11 | cl::Path poly; 12 | float score; 13 | }; 14 | 15 | float paths_area(const ClipperLib::Paths &ps) { 16 | float area = 0; 17 | for (auto &&p: ps) 18 | area += cl::Area(p); 19 | return area; 20 | } 21 | 22 | float poly_iou(const Polygon &a, const Polygon &b) { 23 | cl::Clipper clpr; 24 | clpr.AddPath(a.poly, cl::ptSubject, true); 25 | clpr.AddPath(b.poly, cl::ptClip, true); 26 | 27 | cl::Paths inter, uni; 28 | clpr.Execute(cl::ctIntersection, inter, cl::pftEvenOdd); 29 | clpr.Execute(cl::ctUnion, uni, cl::pftEvenOdd); 30 | 31 | auto inter_area = paths_area(inter), 32 | uni_area = paths_area(uni); 33 | return std::abs(inter_area) / std::max(std::abs(uni_area), 1.0f); 34 | } 35 | 36 | bool should_merge(const Polygon &a, const Polygon &b, float iou_threshold) { 37 | return poly_iou(a, b) > iou_threshold; 38 | } 39 | 40 | /** 41 | * Incrementally merge polygons 42 | */ 43 | class PolyMerger { 44 | public: 45 | PolyMerger(): score(0), nr_polys(0) { 46 | 
memset(data, 0, sizeof(data)); 47 | } 48 | 49 | /** 50 | * Add a new polygon to be merged. 51 | */ 52 | void add(const Polygon &p_given) { 53 | Polygon p; 54 | if (nr_polys > 0) { 55 | // vertices of two polygons to merge may not in the same order; 56 | // we match their vertices by choosing the ordering that 57 | // minimizes the total squared distance. 58 | // see function normalize_poly for details. 59 | p = normalize_poly(get(), p_given); 60 | } else { 61 | p = p_given; 62 | } 63 | assert(p.poly.size() == 4); 64 | auto &poly = p.poly; 65 | auto s = p.score; 66 | data[0] += poly[0].X * s; 67 | data[1] += poly[0].Y * s; 68 | 69 | data[2] += poly[1].X * s; 70 | data[3] += poly[1].Y * s; 71 | 72 | data[4] += poly[2].X * s; 73 | data[5] += poly[2].Y * s; 74 | 75 | data[6] += poly[3].X * s; 76 | data[7] += poly[3].Y * s; 77 | 78 | score += p.score; 79 | 80 | nr_polys += 1; 81 | } 82 | 83 | inline std::int64_t sqr(std::int64_t x) { return x * x; } 84 | 85 | Polygon normalize_poly( 86 | const Polygon &ref, 87 | const Polygon &p) { 88 | 89 | std::int64_t min_d = std::numeric_limits::max(); 90 | size_t best_start = 0, best_order = 0; 91 | 92 | for (size_t start = 0; start < 4; start ++) { 93 | size_t j = start; 94 | std::int64_t d = ( 95 | sqr(ref.poly[(j + 0) % 4].X - p.poly[(j + 0) % 4].X) 96 | + sqr(ref.poly[(j + 0) % 4].Y - p.poly[(j + 0) % 4].Y) 97 | + sqr(ref.poly[(j + 1) % 4].X - p.poly[(j + 1) % 4].X) 98 | + sqr(ref.poly[(j + 1) % 4].Y - p.poly[(j + 1) % 4].Y) 99 | + sqr(ref.poly[(j + 2) % 4].X - p.poly[(j + 2) % 4].X) 100 | + sqr(ref.poly[(j + 2) % 4].Y - p.poly[(j + 2) % 4].Y) 101 | + sqr(ref.poly[(j + 3) % 4].X - p.poly[(j + 3) % 4].X) 102 | + sqr(ref.poly[(j + 3) % 4].Y - p.poly[(j + 3) % 4].Y) 103 | ); 104 | if (d < min_d) { 105 | min_d = d; 106 | best_start = start; 107 | best_order = 0; 108 | } 109 | 110 | d = ( 111 | sqr(ref.poly[(j + 0) % 4].X - p.poly[(j + 3) % 4].X) 112 | + sqr(ref.poly[(j + 0) % 4].Y - p.poly[(j + 3) % 4].Y) 113 | + sqr(ref.poly[(j + 
1) % 4].X - p.poly[(j + 2) % 4].X) 114 | + sqr(ref.poly[(j + 1) % 4].Y - p.poly[(j + 2) % 4].Y) 115 | + sqr(ref.poly[(j + 2) % 4].X - p.poly[(j + 1) % 4].X) 116 | + sqr(ref.poly[(j + 2) % 4].Y - p.poly[(j + 1) % 4].Y) 117 | + sqr(ref.poly[(j + 3) % 4].X - p.poly[(j + 0) % 4].X) 118 | + sqr(ref.poly[(j + 3) % 4].Y - p.poly[(j + 0) % 4].Y) 119 | ); 120 | if (d < min_d) { 121 | min_d = d; 122 | best_start = start; 123 | best_order = 1; 124 | } 125 | } 126 | 127 | Polygon r; 128 | r.poly.resize(4); 129 | auto j = best_start; 130 | if (best_order == 0) { 131 | for (size_t i = 0; i < 4; i ++) 132 | r.poly[i] = p.poly[(j + i) % 4]; 133 | } else { 134 | for (size_t i = 0; i < 4; i ++) 135 | r.poly[i] = p.poly[(j + 4 - i - 1) % 4]; 136 | } 137 | r.score = p.score; 138 | return r; 139 | } 140 | 141 | Polygon get() const { 142 | Polygon p; 143 | 144 | auto &poly = p.poly; 145 | poly.resize(4); 146 | auto score_inv = 1.0f / std::max(1e-8f, score); 147 | poly[0].X = data[0] * score_inv; 148 | poly[0].Y = data[1] * score_inv; 149 | poly[1].X = data[2] * score_inv; 150 | poly[1].Y = data[3] * score_inv; 151 | poly[2].X = data[4] * score_inv; 152 | poly[2].Y = data[5] * score_inv; 153 | poly[3].X = data[6] * score_inv; 154 | poly[3].Y = data[7] * score_inv; 155 | 156 | assert(score > 0); 157 | p.score = score; 158 | 159 | return p; 160 | } 161 | 162 | private: 163 | std::int64_t data[8]; 164 | float score; 165 | std::int32_t nr_polys; 166 | }; 167 | 168 | 169 | /** 170 | * The standard NMS algorithm. 
171 | */ 172 | std::vector standard_nms(std::vector &polys, float iou_threshold) { 173 | size_t n = polys.size(); 174 | if (n == 0) 175 | return {}; 176 | std::vector indices(n); 177 | std::iota(std::begin(indices), std::end(indices), 0); 178 | std::sort(std::begin(indices), std::end(indices), [&](size_t i, size_t j) { return polys[i].score > polys[j].score; }); 179 | 180 | std::vector keep; 181 | while (indices.size()) { 182 | size_t p = 0, cur = indices[0]; 183 | keep.emplace_back(cur); 184 | for (size_t i = 1; i < indices.size(); i ++) { 185 | if (!should_merge(polys[cur], polys[indices[i]], iou_threshold)) { 186 | indices[p ++] = indices[i]; 187 | } 188 | } 189 | indices.resize(p); 190 | } 191 | 192 | std::vector ret; 193 | for (auto &&i: keep) { 194 | ret.emplace_back(polys[i]); 195 | } 196 | return ret; 197 | } 198 | 199 | std::vector 200 | merge_quadrangle_n9(const float *data, size_t n, float iou_threshold) { 201 | using cInt = cl::cInt; 202 | 203 | // first pass 204 | std::vector polys; 205 | for (size_t i = 0; i < n; i ++) { 206 | auto p = data + i * 9; 207 | Polygon poly{ 208 | { 209 | {cInt(p[0]), cInt(p[1])}, 210 | {cInt(p[2]), cInt(p[3])}, 211 | {cInt(p[4]), cInt(p[5])}, 212 | {cInt(p[6]), cInt(p[7])}, 213 | }, 214 | p[8], 215 | }; 216 | 217 | if (polys.size()) { 218 | // merge with the last one 219 | auto &bpoly = polys.back(); 220 | if (should_merge(poly, bpoly, iou_threshold)) { 221 | PolyMerger merger; 222 | merger.add(bpoly); 223 | merger.add(poly); 224 | bpoly = merger.get(); 225 | } else { 226 | polys.emplace_back(poly); 227 | } 228 | } else { 229 | polys.emplace_back(poly); 230 | } 231 | } 232 | return standard_nms(polys, iou_threshold); 233 | } 234 | } 235 | -------------------------------------------------------------------------------- /lanms/include/pybind11/chrono.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/chrono.h: Transparent conversion between std::chrono and python's 
datetime 3 | 4 | Copyright (c) 2016 Trent Houliston and 5 | Wenzel Jakob 6 | 7 | All rights reserved. Use of this source code is governed by a 8 | BSD-style license that can be found in the LICENSE file. 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "pybind11.h" 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | // Backport the PyDateTime_DELTA functions from Python3.3 if required 20 | #ifndef PyDateTime_DELTA_GET_DAYS 21 | #define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days) 22 | #endif 23 | #ifndef PyDateTime_DELTA_GET_SECONDS 24 | #define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds) 25 | #endif 26 | #ifndef PyDateTime_DELTA_GET_MICROSECONDS 27 | #define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds) 28 | #endif 29 | 30 | NAMESPACE_BEGIN(pybind11) 31 | NAMESPACE_BEGIN(detail) 32 | 33 | template class duration_caster { 34 | public: 35 | typedef typename type::rep rep; 36 | typedef typename type::period period; 37 | 38 | typedef std::chrono::duration> days; 39 | 40 | bool load(handle src, bool) { 41 | using namespace std::chrono; 42 | 43 | // Lazy initialise the PyDateTime import 44 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 45 | 46 | if (!src) return false; 47 | // If invoked with datetime.delta object 48 | if (PyDelta_Check(src.ptr())) { 49 | value = type(duration_cast>( 50 | days(PyDateTime_DELTA_GET_DAYS(src.ptr())) 51 | + seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr())) 52 | + microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr())))); 53 | return true; 54 | } 55 | // If invoked with a float we assume it is seconds and convert 56 | else if (PyFloat_Check(src.ptr())) { 57 | value = type(duration_cast>(duration(PyFloat_AsDouble(src.ptr())))); 58 | return true; 59 | } 60 | else return false; 61 | } 62 | 63 | // If this is a duration just return it back 64 | static const std::chrono::duration& get_duration(const std::chrono::duration &src) { 65 | return src; 66 | } 67 | 68 | // 
If this is a time_point get the time_since_epoch 69 | template static std::chrono::duration get_duration(const std::chrono::time_point> &src) { 70 | return src.time_since_epoch(); 71 | } 72 | 73 | static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) { 74 | using namespace std::chrono; 75 | 76 | // Use overloaded function to get our duration from our source 77 | // Works out if it is a duration or time_point and get the duration 78 | auto d = get_duration(src); 79 | 80 | // Lazy initialise the PyDateTime import 81 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 82 | 83 | // Declare these special duration types so the conversions happen with the correct primitive types (int) 84 | using dd_t = duration>; 85 | using ss_t = duration>; 86 | using us_t = duration; 87 | 88 | auto dd = duration_cast(d); 89 | auto subd = d - dd; 90 | auto ss = duration_cast(subd); 91 | auto us = duration_cast(subd - ss); 92 | return PyDelta_FromDSU(dd.count(), ss.count(), us.count()); 93 | } 94 | 95 | PYBIND11_TYPE_CASTER(type, _("datetime.timedelta")); 96 | }; 97 | 98 | // This is for casting times on the system clock into datetime.datetime instances 99 | template class type_caster> { 100 | public: 101 | typedef std::chrono::time_point type; 102 | bool load(handle src, bool) { 103 | using namespace std::chrono; 104 | 105 | // Lazy initialise the PyDateTime import 106 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 107 | 108 | if (!src) return false; 109 | if (PyDateTime_Check(src.ptr())) { 110 | std::tm cal; 111 | cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr()); 112 | cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr()); 113 | cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr()); 114 | cal.tm_mday = PyDateTime_GET_DAY(src.ptr()); 115 | cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1; 116 | cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900; 117 | cal.tm_isdst = -1; 118 | 119 | value = system_clock::from_time_t(std::mktime(&cal)) + 
microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr())); 120 | return true; 121 | } 122 | else return false; 123 | } 124 | 125 | static handle cast(const std::chrono::time_point &src, return_value_policy /* policy */, handle /* parent */) { 126 | using namespace std::chrono; 127 | 128 | // Lazy initialise the PyDateTime import 129 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 130 | 131 | std::time_t tt = system_clock::to_time_t(src); 132 | // this function uses static memory so it's best to copy it out asap just in case 133 | // otherwise other code that is using localtime may break this (not just python code) 134 | std::tm localtime = *std::localtime(&tt); 135 | 136 | // Declare these special duration types so the conversions happen with the correct primitive types (int) 137 | using us_t = duration; 138 | 139 | return PyDateTime_FromDateAndTime(localtime.tm_year + 1900, 140 | localtime.tm_mon + 1, 141 | localtime.tm_mday, 142 | localtime.tm_hour, 143 | localtime.tm_min, 144 | localtime.tm_sec, 145 | (duration_cast(src.time_since_epoch() % seconds(1))).count()); 146 | } 147 | PYBIND11_TYPE_CASTER(type, _("datetime.datetime")); 148 | }; 149 | 150 | // Other clocks that are not the system clock are not measured as datetime.datetime objects 151 | // since they are not measured on calendar time. So instead we just make them timedeltas 152 | // Or if they have passed us a time as a float we convert that 153 | template class type_caster> 154 | : public duration_caster> { 155 | }; 156 | 157 | template class type_caster> 158 | : public duration_caster> { 159 | }; 160 | 161 | NAMESPACE_END(detail) 162 | NAMESPACE_END(pybind11) 163 | -------------------------------------------------------------------------------- /lanms/include/pybind11/embed.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/embed.h: Support for embedding the interpreter 3 | 4 | Copyright (c) 2017 Wenzel Jakob 5 | 6 | All rights reserved. 
Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include "eval.h" 14 | 15 | #if defined(PYPY_VERSION) 16 | # error Embedding the interpreter is not supported with PyPy 17 | #endif 18 | 19 | #if PY_MAJOR_VERSION >= 3 20 | # define PYBIND11_EMBEDDED_MODULE_IMPL(name) \ 21 | extern "C" PyObject *pybind11_init_impl_##name() { \ 22 | return pybind11_init_wrapper_##name(); \ 23 | } 24 | #else 25 | # define PYBIND11_EMBEDDED_MODULE_IMPL(name) \ 26 | extern "C" void pybind11_init_impl_##name() { \ 27 | pybind11_init_wrapper_##name(); \ 28 | } 29 | #endif 30 | 31 | /** \rst 32 | Add a new module to the table of builtins for the interpreter. Must be 33 | defined in global scope. The first macro parameter is the name of the 34 | module (without quotes). The second parameter is the variable which will 35 | be used as the interface to add functions and classes to the module. 36 | 37 | .. code-block:: cpp 38 | 39 | PYBIND11_EMBEDDED_MODULE(example, m) { 40 | // ... 
initialize functions and classes here 41 | m.def("foo", []() { 42 | return "Hello, World!"; 43 | }); 44 | } 45 | \endrst */ 46 | #define PYBIND11_EMBEDDED_MODULE(name, variable) \ 47 | static void pybind11_init_##name(pybind11::module &); \ 48 | static PyObject *pybind11_init_wrapper_##name() { \ 49 | auto m = pybind11::module(#name); \ 50 | try { \ 51 | pybind11_init_##name(m); \ 52 | return m.ptr(); \ 53 | } catch (pybind11::error_already_set &e) { \ 54 | PyErr_SetString(PyExc_ImportError, e.what()); \ 55 | return nullptr; \ 56 | } catch (const std::exception &e) { \ 57 | PyErr_SetString(PyExc_ImportError, e.what()); \ 58 | return nullptr; \ 59 | } \ 60 | } \ 61 | PYBIND11_EMBEDDED_MODULE_IMPL(name) \ 62 | pybind11::detail::embedded_module name(#name, pybind11_init_impl_##name); \ 63 | void pybind11_init_##name(pybind11::module &variable) 64 | 65 | 66 | NAMESPACE_BEGIN(pybind11) 67 | NAMESPACE_BEGIN(detail) 68 | 69 | /// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks. 70 | struct embedded_module { 71 | #if PY_MAJOR_VERSION >= 3 72 | using init_t = PyObject *(*)(); 73 | #else 74 | using init_t = void (*)(); 75 | #endif 76 | embedded_module(const char *name, init_t init) { 77 | if (Py_IsInitialized()) 78 | pybind11_fail("Can't add new modules after the interpreter has been initialized"); 79 | 80 | auto result = PyImport_AppendInittab(name, init); 81 | if (result == -1) 82 | pybind11_fail("Insufficient memory to add a new module"); 83 | } 84 | }; 85 | 86 | NAMESPACE_END(detail) 87 | 88 | /** \rst 89 | Initialize the Python interpreter. No other pybind11 or CPython API functions can be 90 | called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. The 91 | optional parameter can be used to skip the registration of signal handlers (see the 92 | Python documentation for details). Calling this function again after the interpreter 93 | has already been initialized is a fatal error. 
94 | \endrst */ 95 | inline void initialize_interpreter(bool init_signal_handlers = true) { 96 | if (Py_IsInitialized()) 97 | pybind11_fail("The interpreter is already running"); 98 | 99 | Py_InitializeEx(init_signal_handlers ? 1 : 0); 100 | 101 | // Make .py files in the working directory available by default 102 | auto sys_path = reinterpret_borrow(module::import("sys").attr("path")); 103 | sys_path.append("."); 104 | } 105 | 106 | /** \rst 107 | Shut down the Python interpreter. No pybind11 or CPython API functions can be called 108 | after this. In addition, pybind11 objects must not outlive the interpreter: 109 | 110 | .. code-block:: cpp 111 | 112 | { // BAD 113 | py::initialize_interpreter(); 114 | auto hello = py::str("Hello, World!"); 115 | py::finalize_interpreter(); 116 | } // <-- BOOM, hello's destructor is called after interpreter shutdown 117 | 118 | { // GOOD 119 | py::initialize_interpreter(); 120 | { // scoped 121 | auto hello = py::str("Hello, World!"); 122 | } // <-- OK, hello is cleaned up properly 123 | py::finalize_interpreter(); 124 | } 125 | 126 | { // BETTER 127 | py::scoped_interpreter guard{}; 128 | auto hello = py::str("Hello, World!"); 129 | } 130 | 131 | .. warning:: 132 | 133 | The interpreter can be restarted by calling `initialize_interpreter` again. 134 | Modules created using pybind11 can be safely re-initialized. However, Python 135 | itself cannot completely unload binary extension modules and there are several 136 | caveats with regard to interpreter restarting. All the details can be found 137 | in the CPython documentation. In short, not all interpreter memory may be 138 | freed, either due to reference cycles or user-created global data. 139 | 140 | \endrst */ 141 | inline void finalize_interpreter() { 142 | handle builtins(PyEval_GetBuiltins()); 143 | const char *id = PYBIND11_INTERNALS_ID; 144 | 145 | // Get the internals pointer (without creating it if it doesn't exist). 
It's possible for the 146 | // internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()` 147 | // during destruction), so we get the pointer-pointer here and check it after Py_Finalize(). 148 | detail::internals **internals_ptr_ptr = &detail::get_internals_ptr(); 149 | // It could also be stashed in builtins, so look there too: 150 | if (builtins.contains(id) && isinstance(builtins[id])) 151 | internals_ptr_ptr = capsule(builtins[id]); 152 | 153 | Py_Finalize(); 154 | 155 | if (internals_ptr_ptr) { 156 | delete *internals_ptr_ptr; 157 | *internals_ptr_ptr = nullptr; 158 | } 159 | } 160 | 161 | /** \rst 162 | Scope guard version of `initialize_interpreter` and `finalize_interpreter`. 163 | This a move-only guard and only a single instance can exist. 164 | 165 | .. code-block:: cpp 166 | 167 | #include 168 | 169 | int main() { 170 | py::scoped_interpreter guard{}; 171 | py::print(Hello, World!); 172 | } // <-- interpreter shutdown 173 | \endrst */ 174 | class scoped_interpreter { 175 | public: 176 | scoped_interpreter(bool init_signal_handlers = true) { 177 | initialize_interpreter(init_signal_handlers); 178 | } 179 | 180 | scoped_interpreter(const scoped_interpreter &) = delete; 181 | scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; } 182 | scoped_interpreter &operator=(const scoped_interpreter &) = delete; 183 | scoped_interpreter &operator=(scoped_interpreter &&) = delete; 184 | 185 | ~scoped_interpreter() { 186 | if (is_valid) 187 | finalize_interpreter(); 188 | } 189 | 190 | private: 191 | bool is_valid = true; 192 | }; 193 | 194 | NAMESPACE_END(pybind11) 195 | -------------------------------------------------------------------------------- /lanms/include/pybind11/descr.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/descr.h: Helper type for concatenating type signatures 3 | either at runtime (C++11) or compile time (C++14) 4 | 5 | Copyright 
(c) 2016 Wenzel Jakob 6 | 7 | All rights reserved. Use of this source code is governed by a 8 | BSD-style license that can be found in the LICENSE file. 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "common.h" 14 | 15 | NAMESPACE_BEGIN(pybind11) 16 | NAMESPACE_BEGIN(detail) 17 | 18 | /* Concatenate type signatures at compile time using C++14 */ 19 | #if defined(PYBIND11_CPP14) && !defined(_MSC_VER) 20 | #define PYBIND11_CONSTEXPR_DESCR 21 | 22 | template class descr { 23 | template friend class descr; 24 | public: 25 | constexpr descr(char const (&text) [Size1+1], const std::type_info * const (&types)[Size2+1]) 26 | : descr(text, types, 27 | make_index_sequence(), 28 | make_index_sequence()) { } 29 | 30 | constexpr const char *text() const { return m_text; } 31 | constexpr const std::type_info * const * types() const { return m_types; } 32 | 33 | template 34 | constexpr descr operator+(const descr &other) const { 35 | return concat(other, 36 | make_index_sequence(), 37 | make_index_sequence(), 38 | make_index_sequence(), 39 | make_index_sequence()); 40 | } 41 | 42 | protected: 43 | template 44 | constexpr descr( 45 | char const (&text) [Size1+1], 46 | const std::type_info * const (&types) [Size2+1], 47 | index_sequence, index_sequence) 48 | : m_text{text[Indices1]..., '\0'}, 49 | m_types{types[Indices2]..., nullptr } {} 50 | 51 | template 53 | constexpr descr 54 | concat(const descr &other, 55 | index_sequence, index_sequence, 56 | index_sequence, index_sequence) const { 57 | return descr( 58 | { m_text[Indices1]..., other.m_text[OtherIndices1]..., '\0' }, 59 | { m_types[Indices2]..., other.m_types[OtherIndices2]..., nullptr } 60 | ); 61 | } 62 | 63 | protected: 64 | char m_text[Size1 + 1]; 65 | const std::type_info * m_types[Size2 + 1]; 66 | }; 67 | 68 | template constexpr descr _(char const(&text)[Size]) { 69 | return descr(text, { nullptr }); 70 | } 71 | 72 | template struct int_to_str : int_to_str { }; 73 | template struct int_to_str<0, Digits...> { 74 | 
static constexpr auto digits = descr({ ('0' + Digits)..., '\0' }, { nullptr }); 75 | }; 76 | 77 | // Ternary description (like std::conditional) 78 | template 79 | constexpr enable_if_t> _(char const(&text1)[Size1], char const(&)[Size2]) { 80 | return _(text1); 81 | } 82 | template 83 | constexpr enable_if_t> _(char const(&)[Size1], char const(&text2)[Size2]) { 84 | return _(text2); 85 | } 86 | template 87 | constexpr enable_if_t> _(descr d, descr) { return d; } 88 | template 89 | constexpr enable_if_t> _(descr, descr d) { return d; } 90 | 91 | template auto constexpr _() -> decltype(int_to_str::digits) { 92 | return int_to_str::digits; 93 | } 94 | 95 | template constexpr descr<1, 1> _() { 96 | return descr<1, 1>({ '%', '\0' }, { &typeid(Type), nullptr }); 97 | } 98 | 99 | inline constexpr descr<0, 0> concat() { return _(""); } 100 | template auto constexpr concat(descr descr) { return descr; } 101 | template auto constexpr concat(descr descr, Args&&... args) { return descr + _(", ") + concat(args...); } 102 | template auto constexpr type_descr(descr descr) { return _("{") + descr + _("}"); } 103 | 104 | #define PYBIND11_DESCR constexpr auto 105 | 106 | #else /* Simpler C++11 implementation based on run-time memory allocation and copying */ 107 | 108 | class descr { 109 | public: 110 | PYBIND11_NOINLINE descr(const char *text, const std::type_info * const * types) { 111 | size_t nChars = len(text), nTypes = len(types); 112 | m_text = new char[nChars]; 113 | m_types = new const std::type_info *[nTypes]; 114 | memcpy(m_text, text, nChars * sizeof(char)); 115 | memcpy(m_types, types, nTypes * sizeof(const std::type_info *)); 116 | } 117 | 118 | PYBIND11_NOINLINE descr operator+(descr &&d2) && { 119 | descr r; 120 | 121 | size_t nChars1 = len(m_text), nTypes1 = len(m_types); 122 | size_t nChars2 = len(d2.m_text), nTypes2 = len(d2.m_types); 123 | 124 | r.m_text = new char[nChars1 + nChars2 - 1]; 125 | r.m_types = new const std::type_info *[nTypes1 + nTypes2 - 1]; 126 | 
memcpy(r.m_text, m_text, (nChars1-1) * sizeof(char)); 127 | memcpy(r.m_text + nChars1 - 1, d2.m_text, nChars2 * sizeof(char)); 128 | memcpy(r.m_types, m_types, (nTypes1-1) * sizeof(std::type_info *)); 129 | memcpy(r.m_types + nTypes1 - 1, d2.m_types, nTypes2 * sizeof(std::type_info *)); 130 | 131 | delete[] m_text; delete[] m_types; 132 | delete[] d2.m_text; delete[] d2.m_types; 133 | 134 | return r; 135 | } 136 | 137 | char *text() { return m_text; } 138 | const std::type_info * * types() { return m_types; } 139 | 140 | protected: 141 | PYBIND11_NOINLINE descr() { } 142 | 143 | template static size_t len(const T *ptr) { // return length including null termination 144 | const T *it = ptr; 145 | while (*it++ != (T) 0) 146 | ; 147 | return static_cast(it - ptr); 148 | } 149 | 150 | const std::type_info **m_types = nullptr; 151 | char *m_text = nullptr; 152 | }; 153 | 154 | /* The 'PYBIND11_NOINLINE inline' combinations below are intentional to get the desired linkage while producing as little object code as possible */ 155 | 156 | PYBIND11_NOINLINE inline descr _(const char *text) { 157 | const std::type_info *types[1] = { nullptr }; 158 | return descr(text, types); 159 | } 160 | 161 | template PYBIND11_NOINLINE enable_if_t _(const char *text1, const char *) { return _(text1); } 162 | template PYBIND11_NOINLINE enable_if_t _(char const *, const char *text2) { return _(text2); } 163 | template PYBIND11_NOINLINE enable_if_t _(descr d, descr) { return d; } 164 | template PYBIND11_NOINLINE enable_if_t _(descr, descr d) { return d; } 165 | 166 | template PYBIND11_NOINLINE descr _() { 167 | const std::type_info *types[2] = { &typeid(Type), nullptr }; 168 | return descr("%", types); 169 | } 170 | 171 | template PYBIND11_NOINLINE descr _() { 172 | const std::type_info *types[1] = { nullptr }; 173 | return descr(std::to_string(Size).c_str(), types); 174 | } 175 | 176 | PYBIND11_NOINLINE inline descr concat() { return _(""); } 177 | PYBIND11_NOINLINE inline descr concat(descr 
&&d) { return d; } 178 | template PYBIND11_NOINLINE descr concat(descr &&d, Args&&... args) { return std::move(d) + _(", ") + concat(std::forward(args)...); } 179 | PYBIND11_NOINLINE inline descr type_descr(descr&& d) { return _("{") + std::move(d) + _("}"); } 180 | 181 | #define PYBIND11_DESCR ::pybind11::detail::descr 182 | #endif 183 | 184 | NAMESPACE_END(detail) 185 | NAMESPACE_END(pybind11) 186 | -------------------------------------------------------------------------------- /lanms/include/pybind11/operators.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/operator.h: Metatemplates for operator overloading 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | 14 | #if defined(__clang__) && !defined(__INTEL_COMPILER) 15 | # pragma clang diagnostic ignored "-Wunsequenced" // multiple unsequenced modifications to 'self' (when using def(py::self OP Type())) 16 | #elif defined(_MSC_VER) 17 | # pragma warning(push) 18 | # pragma warning(disable: 4127) // warning C4127: Conditional expression is constant 19 | #endif 20 | 21 | NAMESPACE_BEGIN(pybind11) 22 | NAMESPACE_BEGIN(detail) 23 | 24 | /// Enumeration with all supported operator types 25 | enum op_id : int { 26 | op_add, op_sub, op_mul, op_div, op_mod, op_divmod, op_pow, op_lshift, 27 | op_rshift, op_and, op_xor, op_or, op_neg, op_pos, op_abs, op_invert, 28 | op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le, 29 | op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift, 30 | op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero, 31 | op_repr, op_truediv, op_itruediv 32 | }; 33 | 34 | enum op_type : int { 35 | op_l, /* base type on left */ 36 | op_r, /* base type on right */ 37 | op_u /* unary operator */ 38 | }; 39 | 40 | struct 
self_t { }; 41 | static const self_t self = self_t(); 42 | 43 | /// Type for an unused type slot 44 | struct undefined_t { }; 45 | 46 | /// Don't warn about an unused variable 47 | inline self_t __self() { return self; } 48 | 49 | /// base template of operator implementations 50 | template struct op_impl { }; 51 | 52 | /// Operator implementation generator 53 | template struct op_ { 54 | template void execute(Class &cl, const Extra&... extra) const { 55 | using Base = typename Class::type; 56 | using L_type = conditional_t::value, Base, L>; 57 | using R_type = conditional_t::value, Base, R>; 58 | using op = op_impl; 59 | cl.def(op::name(), &op::execute, is_operator(), extra...); 60 | #if PY_MAJOR_VERSION < 3 61 | if (id == op_truediv || id == op_itruediv) 62 | cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__", 63 | &op::execute, is_operator(), extra...); 64 | #endif 65 | } 66 | template void execute_cast(Class &cl, const Extra&... extra) const { 67 | using Base = typename Class::type; 68 | using L_type = conditional_t::value, Base, L>; 69 | using R_type = conditional_t::value, Base, R>; 70 | using op = op_impl; 71 | cl.def(op::name(), &op::execute_cast, is_operator(), extra...); 72 | #if PY_MAJOR_VERSION < 3 73 | if (id == op_truediv || id == op_itruediv) 74 | cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? 
"__div__" : "__rdiv__", 75 | &op::execute, is_operator(), extra...); 76 | #endif 77 | } 78 | }; 79 | 80 | #define PYBIND11_BINARY_OPERATOR(id, rid, op, expr) \ 81 | template struct op_impl { \ 82 | static char const* name() { return "__" #id "__"; } \ 83 | static auto execute(const L &l, const R &r) -> decltype(expr) { return (expr); } \ 84 | static B execute_cast(const L &l, const R &r) { return B(expr); } \ 85 | }; \ 86 | template struct op_impl { \ 87 | static char const* name() { return "__" #rid "__"; } \ 88 | static auto execute(const R &r, const L &l) -> decltype(expr) { return (expr); } \ 89 | static B execute_cast(const R &r, const L &l) { return B(expr); } \ 90 | }; \ 91 | inline op_ op(const self_t &, const self_t &) { \ 92 | return op_(); \ 93 | } \ 94 | template op_ op(const self_t &, const T &) { \ 95 | return op_(); \ 96 | } \ 97 | template op_ op(const T &, const self_t &) { \ 98 | return op_(); \ 99 | } 100 | 101 | #define PYBIND11_INPLACE_OPERATOR(id, op, expr) \ 102 | template struct op_impl { \ 103 | static char const* name() { return "__" #id "__"; } \ 104 | static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \ 105 | static B execute_cast(L &l, const R &r) { return B(expr); } \ 106 | }; \ 107 | template op_ op(const self_t &, const T &) { \ 108 | return op_(); \ 109 | } 110 | 111 | #define PYBIND11_UNARY_OPERATOR(id, op, expr) \ 112 | template struct op_impl { \ 113 | static char const* name() { return "__" #id "__"; } \ 114 | static auto execute(const L &l) -> decltype(expr) { return expr; } \ 115 | static B execute_cast(const L &l) { return B(expr); } \ 116 | }; \ 117 | inline op_ op(const self_t &) { \ 118 | return op_(); \ 119 | } 120 | 121 | PYBIND11_BINARY_OPERATOR(sub, rsub, operator-, l - r) 122 | PYBIND11_BINARY_OPERATOR(add, radd, operator+, l + r) 123 | PYBIND11_BINARY_OPERATOR(mul, rmul, operator*, l * r) 124 | PYBIND11_BINARY_OPERATOR(truediv, rtruediv, operator/, l / r) 125 | PYBIND11_BINARY_OPERATOR(mod, 
rmod, operator%, l % r) 126 | PYBIND11_BINARY_OPERATOR(lshift, rlshift, operator<<, l << r) 127 | PYBIND11_BINARY_OPERATOR(rshift, rrshift, operator>>, l >> r) 128 | PYBIND11_BINARY_OPERATOR(and, rand, operator&, l & r) 129 | PYBIND11_BINARY_OPERATOR(xor, rxor, operator^, l ^ r) 130 | PYBIND11_BINARY_OPERATOR(eq, eq, operator==, l == r) 131 | PYBIND11_BINARY_OPERATOR(ne, ne, operator!=, l != r) 132 | PYBIND11_BINARY_OPERATOR(or, ror, operator|, l | r) 133 | PYBIND11_BINARY_OPERATOR(gt, lt, operator>, l > r) 134 | PYBIND11_BINARY_OPERATOR(ge, le, operator>=, l >= r) 135 | PYBIND11_BINARY_OPERATOR(lt, gt, operator<, l < r) 136 | PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r) 137 | //PYBIND11_BINARY_OPERATOR(pow, rpow, pow, std::pow(l, r)) 138 | PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r) 139 | PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r) 140 | PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r) 141 | PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r) 142 | PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r) 143 | PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r) 144 | PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r) 145 | PYBIND11_INPLACE_OPERATOR(iand, operator&=, l &= r) 146 | PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r) 147 | PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r) 148 | PYBIND11_UNARY_OPERATOR(neg, operator-, -l) 149 | PYBIND11_UNARY_OPERATOR(pos, operator+, +l) 150 | PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l)) 151 | PYBIND11_UNARY_OPERATOR(invert, operator~, (~l)) 152 | PYBIND11_UNARY_OPERATOR(bool, operator!, !!l) 153 | PYBIND11_UNARY_OPERATOR(int, int_, (int) l) 154 | PYBIND11_UNARY_OPERATOR(float, float_, (double) l) 155 | 156 | #undef PYBIND11_BINARY_OPERATOR 157 | #undef PYBIND11_INPLACE_OPERATOR 158 | #undef PYBIND11_UNARY_OPERATOR 159 | NAMESPACE_END(detail) 160 | 161 | using detail::self; 162 | 163 | NAMESPACE_END(pybind11) 164 | 165 | #if defined(_MSC_VER) 166 | # pragma warning(pop) 167 | 
#endif 168 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch 5 | 6 | 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | "3x3 convolution with padding" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False) 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | expansion = 1 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None): 19 | super(BasicBlock, self).__init__() 20 | self.conv1 = conv3x3(inplanes, planes, stride) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.relu = nn.ReLU(inplace=True) 23 | self.conv2 = conv3x3(planes, planes) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | self.downsample = downsample 26 | self.stride = stride 27 | 28 | def forward(self, x): 29 | residual = x 30 | 31 | out = self.conv1(x) 32 | out = self.bn1(out) 33 | out = self.relu(out) 34 | 35 | out = self.conv2(out) 36 | out = self.bn2(out) 37 | 38 | if self.downsample is not None: 39 | residual = self.downsample(x) 40 | 41 | out += residual 42 | out = self.relu(out) 43 | 44 | return out 45 | 46 | 47 | class Bottleneck(nn.Module): 48 | expansion = 4 49 | 50 | def __init__(self, inplanes, planes, stride=1, downsample=None): 51 | super(Bottleneck, self).__init__() 52 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 55 | padding=1, bias=False) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 58 | self.bn3 = nn.BatchNorm2d(planes * 4) 59 | self.relu = nn.ReLU(inplace=True) 60 | self.downsample = downsample 61 | self.stride = stride 62 | 63 | def forward(self, x): 64 | residual = x 65 | 66 | out = 
self.conv1(x) 67 | out = self.bn1(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv2(out) 71 | out = self.bn2(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv3(out) 75 | out = self.bn3(out) 76 | 77 | if self.downsample is not None: 78 | residual = self.downsample(x) 79 | 80 | out += residual 81 | out = self.relu(out) 82 | 83 | return out 84 | 85 | 86 | class ResNet(nn.Module): 87 | 88 | def __init__(self, block, layers, num_classes=1000): 89 | self.inplanes = 64 90 | super(ResNet, self).__init__() 91 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 92 | bias=False) 93 | self.bn1 = nn.BatchNorm2d(64) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 96 | self.layer1 = self._make_layer(block, 64, layers[0]) 97 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 98 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 99 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 100 | self.avgpool = nn.AvgPool2d(7) 101 | self.fc = nn.Linear(512 * block.expansion, num_classes) 102 | 103 | for m in self.modules(): 104 | if isinstance(m, nn.Conv2d): 105 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 106 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 107 | elif isinstance(m, nn.BatchNorm2d): 108 | m.weight.data.fill_(1) 109 | m.bias.data.zero_() 110 | 111 | def _make_layer(self, block, planes, blocks, stride=1): 112 | downsample = None 113 | if stride != 1 or self.inplanes != planes * block.expansion: 114 | downsample = nn.Sequential( 115 | nn.Conv2d(self.inplanes, planes * block.expansion, 116 | kernel_size=1, stride=stride, bias=False), 117 | nn.BatchNorm2d(planes * block.expansion), 118 | ) 119 | 120 | layers = [] 121 | layers.append(block(self.inplanes, planes, stride, downsample)) 122 | self.inplanes = planes * block.expansion 123 | for i in range(1, blocks): 124 | layers.append(block(self.inplanes, planes)) 125 | 126 | return nn.Sequential(*layers) 127 | 128 | def forward(self, x): 129 | f = [] 130 | x = self.conv1(x) 131 | x = self.bn1(x) 132 | x = self.relu(x) 133 | x = self.maxpool(x) 134 | x = self.layer1(x) 135 | f.append(x) 136 | x = self.layer2(x) 137 | f.append(x) 138 | x = self.layer3(x) 139 | f.append(x) 140 | x = self.layer4(x) 141 | f.append(x) 142 | # x = self.avgpool(x) 143 | # x = x.view(x.size(0), -1) 144 | # x = self.fc(x) 145 | ''' 146 | f中的每个元素的size分别是 bs 256 w/4 h/4, bs 512 w/8 h/8, 147 | bs 1024 w/16 h/16, bs 2048 w/32 h/32 148 | ''' 149 | return x, f 150 | 151 | 152 | 153 | 154 | def resnet50(pretrained=False, **kwargs): 155 | """Constructs a ResNet-50 model. 
156 | 157 | Args: 158 | pretrained (bool): If True, returns a model pre-trained on ImageNet 159 | """ 160 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 161 | if pretrained: 162 | model.load_state_dict(torch.load("./resnet50-19c8e357.pth")) 163 | return model 164 | 165 | def mean_image_subtraction(images, means=[123.68, 116.78, 103.94]): 166 | ''' 167 | image normalization 168 | :param images: bs * w * h * channel 169 | :param means: 170 | :return: 171 | ''' 172 | num_channels = images.data.shape[1] 173 | if len(means) != num_channels: 174 | raise ValueError('len(means) must match the number of channels') 175 | for i in range(num_channels): 176 | images.data[:,i,:,:] -= means[i] 177 | 178 | return images 179 | 180 | class East(nn.Module): 181 | def __init__(self): 182 | super(East, self).__init__() 183 | self.resnet = resnet50(True) 184 | self.conv1 = nn.Conv2d(3072, 128, 1) 185 | self.bn1 = nn.BatchNorm2d(128) 186 | self.relu1 = nn.ReLU() 187 | 188 | self.conv2 = nn.Conv2d(128, 128, 3, padding=1) 189 | self.bn2 = nn.BatchNorm2d(128) 190 | self.relu2 = nn.ReLU() 191 | 192 | self.conv3 = nn.Conv2d(640, 64, 1) 193 | self.bn3 = nn.BatchNorm2d(64) 194 | self.relu3 = nn.ReLU() 195 | 196 | self.conv4 = nn.Conv2d(64, 64, 3 ,padding=1) 197 | self.bn4 = nn.BatchNorm2d(64) 198 | self.relu4 = nn.ReLU() 199 | 200 | self.conv5 = nn.Conv2d(320, 64, 1) 201 | self.bn5 = nn.BatchNorm2d(64) 202 | self.relu5 = nn.ReLU() 203 | 204 | self.conv6 = nn.Conv2d(64, 32, 3, padding=1) 205 | self.bn6 = nn.BatchNorm2d(32) 206 | self.relu6 = nn.ReLU() 207 | 208 | self.conv7 = nn.Conv2d(32, 32, 3, padding=1) 209 | self.bn7 = nn.BatchNorm2d(32) 210 | self.relu7 = nn.ReLU() 211 | 212 | self.conv8 = nn.Conv2d(32, 1, 1) 213 | self.sigmoid1 = nn.Sigmoid() 214 | self.conv9 = nn.Conv2d(32, 4, 1) 215 | self.sigmoid2 = nn.Sigmoid() 216 | self.conv10 = nn.Conv2d(32, 1, 1) 217 | self.sigmoid3 = nn.Sigmoid() 218 | self.unpool1 = nn.Upsample(scale_factor=2, mode='bilinear') 219 | self.unpool2 = 
nn.Upsample(scale_factor=2, mode='bilinear') 220 | self.unpool3 = nn.Upsample(scale_factor=2, mode='bilinear') 221 | def forward(self,images): 222 | images = mean_image_subtraction(images) 223 | _, f = self.resnet(images) 224 | h = f[3] # bs 2048 w/32 h/32 225 | g = (self.unpool1(h)) #bs 2048 w/16 h/16 226 | c = self.conv1(torch.cat((g, f[2]), 1)) 227 | c = self.bn1(c) 228 | c = self.relu1(c) 229 | 230 | h = self.conv2(c) # bs 128 w/16 h/16 231 | h = self.bn2(h) 232 | h = self.relu2(h) 233 | g = self.unpool2(h) # bs 128 w/8 h/8 234 | c = self.conv3(torch.cat((g, f[1]), 1)) 235 | c = self.bn3(c) 236 | c = self.relu3(c) 237 | 238 | h = self.conv4(c) # bs 64 w/8 h/8 239 | h = self.bn4(h) 240 | h = self.relu4(h) 241 | g = self.unpool3(h) # bs 64 w/4 h/4 242 | c = self.conv5(torch.cat((g, f[0]), 1)) 243 | c = self.bn5(c) 244 | c = self.relu5(c) 245 | 246 | h = self.conv6(c) # bs 32 w/4 h/4 247 | h = self.bn6(h) 248 | h = self.relu6(h) 249 | g = self.conv7(h) # bs 32 w/4 h/4 250 | g = self.bn7(g) 251 | g = self.relu7(g) 252 | 253 | F_score = self.conv8(g) # bs 1 w/4 h/4 254 | F_score = self.sigmoid1(F_score) 255 | geo_map = self.conv9(g) 256 | geo_map = self.sigmoid2(geo_map) * 512 257 | angle_map = self.conv10(g) 258 | angle_map = self.sigmoid3(angle_map) 259 | angle_map = (angle_map - 0.5) * math.pi / 2 260 | 261 | F_geometry = torch.cat((geo_map, angle_map), 1) # bs 5 w/4 w/4 262 | return F_score, F_geometry 263 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import time 3 | import math 4 | import os 5 | import numpy as np 6 | 7 | import locality_aware_nms as nms_locality 8 | import lanms 9 | 10 | import torch 11 | import model 12 | from data_utils import restore_rectangle 13 | from torch.autograd import Variable 14 | 15 | 16 | import model 17 | # test_data_path = 
'/home/mathu/Documents/express_recognition/data/icdar2015/test2015' 18 | test_data_path = '/home/mathu/Documents/express_recognition/data/telephone_txt/result/print_pic/' 19 | checkpoint_path = './checkpoints/model_500.pth' 20 | output_dir_txt = './result/txt' 21 | output_dir_pic = './result/pic' 22 | 23 | def rotate(box_List,image): 24 | #xuan zhuan tu pian 25 | 26 | n=len(box_List) 27 | c=0; 28 | angle=0 29 | for i in range(n): 30 | box=box_List[i] 31 | y1 = min(box[0][1], box[1][1], box[2][1], box[3][1]) 32 | y2 = max(box[0][1], box[1][1], box[2][1], box[3][1]) 33 | x1 = min(box[0][0], box[1][0], box[2][0], box[3][0]) 34 | x2 = max(box[0][0], box[1][0], box[2][0], box[3][0]) 35 | for j in range(4): 36 | if(box[j][1]==y2): 37 | k1=j 38 | for j in range(4): 39 | if(box[j][0]==x2 and j!=k1): 40 | k2=j 41 | c=(box[k1][0]-box[k2][0])*1.0/(box[k1][1]-box[k2][1]) 42 | if(c<0): 43 | c=-c 44 | if(c>1): 45 | c=1.0/c 46 | angle=math.atan(c)+angle 47 | angle=angle/n 48 | (h, w) = image.shape[:2] 49 | center = (w / 2, h / 2) 50 | scale=1 51 | M = cv2.getRotationMatrix2D(center,angle, scale) 52 | image_new = cv2.warpAffine(image, M, (w, h)) 53 | return image_new 54 | 55 | def get_images(): 56 | ''' 57 | find image files in test data path 58 | :return: list of files found 59 | ''' 60 | files = [] 61 | exts = ['jpg', 'png', 'jpeg', 'JPG'] 62 | for parent, dirnames, filenames in os.walk(test_data_path): 63 | for filename in filenames: 64 | for ext in exts: 65 | if filename.endswith(ext): 66 | files.append(os.path.join(parent, filename)) 67 | break 68 | # print('Find {} images'.format(len(files))) 69 | return files 70 | 71 | def resize_image(im, max_side_len=2400): 72 | ''' 73 | resize image to a size multiple of 32 which is required by the network 74 | :param im: the resized image 75 | :param max_side_len: limit of max image size to avoid out of memory in gpu 76 | :return: the resized image and the resize ratio 77 | ''' 78 | h, w, _ = im.shape 79 | 80 | resize_w = w 81 | 
resize_h = h 82 | 83 | # limit the max side 84 | if max(resize_h, resize_w) > max_side_len: 85 | ratio = float(max_side_len) / resize_h if resize_h > resize_w else float(max_side_len) / resize_w 86 | else: 87 | ratio = 1. 88 | resize_h = int(resize_h * ratio) 89 | resize_w = int(resize_w * ratio) 90 | 91 | resize_h = resize_h if resize_h % 32 == 0 else (resize_h // 32 - 1) * 32 92 | resize_w = resize_w if resize_w % 32 == 0 else (resize_w // 32 - 1) * 32 93 | im = cv2.resize(im, (int(resize_w), int(resize_h))) 94 | 95 | ratio_h = resize_h / float(h) 96 | ratio_w = resize_w / float(w) 97 | 98 | return im, (ratio_h, ratio_w) 99 | 100 | def detect(score_map, geo_map, timer, score_map_thresh=1e-5, box_thresh=1e-8, nms_thres=0.1): 101 | 102 | # def detect(score_map, geo_map, timer, score_map_thresh=0.8, box_thresh=0.1, nms_thres=0.2): 103 | ''' 104 | restore text boxes from score map and geo map 105 | :param score_map: 106 | :param geo_map: 107 | :param timer: 108 | :param score_map_thresh: threshhold for score map 109 | :param box_thresh: threshhold for boxes 110 | :param nms_thres: threshold for nms 111 | :return: 112 | ''' 113 | if len(score_map.shape) == 4: 114 | score_map = score_map[0, :, :, 0] 115 | geo_map = geo_map[0, :, :, ] 116 | # filter the score map 117 | xy_text = np.argwhere(score_map > score_map_thresh) 118 | # sort the text boxes via the y axis 119 | xy_text = xy_text[np.argsort(xy_text[:, 0])] 120 | # restore 121 | start = time.time() 122 | text_box_restored = restore_rectangle(xy_text[:, ::-1]*4, geo_map[xy_text[:, 0], xy_text[:, 1], :]) # N*4*2 123 | # print('{} text boxes before nms'.format(text_box_restored.shape[0])) 124 | boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32) 125 | boxes[:, :8] = text_box_restored.reshape((-1, 8)) 126 | boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]] 127 | print('11') 128 | input(boxes) 129 | timer['restore'] = time.time() - start 130 | # nms part 131 | start = time.time() 132 | # boxes = 
nms_locality.nms_locality(boxes.astype(np.float64), nms_thres) 133 | boxes = lanms.merge_quadrangle_n9(boxes.astype('float32'), nms_thres) 134 | timer['nms'] = time.time() - start 135 | if boxes.shape[0] == 0: 136 | return None, timer 137 | 138 | # here we filter some low score boxes by the average score map, this is different from the orginal paper 139 | for i, box in enumerate(boxes): 140 | mask = np.zeros_like(score_map, dtype=np.uint8) 141 | cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1) 142 | boxes[i, 8] = cv2.mean(score_map, mask)[0] 143 | print(22) 144 | print(boxes) 145 | boxes = boxes[boxes[:, 8] > box_thresh] 146 | print('333') 147 | input(boxes) 148 | return boxes, timer 149 | 150 | def sort_poly(p): 151 | min_axis = np.argmin(np.sum(p, axis=1)) 152 | p = p[[min_axis, (min_axis+1)%4, (min_axis+2)%4, (min_axis+3)%4]] 153 | if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]): 154 | return p 155 | else: 156 | return p[[0, 3, 2, 1]] 157 | 158 | def change_box(box_List): 159 | n=len(box_List) 160 | for i in range(n): 161 | box=box_List[i] 162 | y1 = min(box[0][1], box[1][1], box[2][1], box[3][1]) 163 | y2 = max(box[0][1], box[1][1], box[2][1], box[3][1]) 164 | x1 = min(box[0][0], box[1][0], box[2][0], box[3][0]) 165 | x2 = max(box[0][0], box[1][0], box[2][0], box[3][0]) 166 | box[0][1]=y1 167 | box[0][0]=x1 168 | box[1][1]=y1 169 | box[1][0]=x2 170 | box[3][1]=y2 171 | box[3][0]=x1 172 | box[2][1]=y2 173 | box[2][0]=x2 174 | box_List[i]=box 175 | return box_List 176 | 177 | def save_box(box_List,image,img_path): 178 | n=len(box_List) 179 | box_final = [] 180 | for i in range(n): 181 | box=box_List[i] 182 | y1_0 = int(min(box[0][1], box[1][1], box[2][1], box[3][1])) 183 | y2_0 = int(max(box[0][1], box[1][1], box[2][1], box[3][1])) 184 | x1_0 = int(min(box[0][0], box[1][0], box[2][0], box[3][0])) 185 | x2_0 = int(max(box[0][0], box[1][0], box[2][0], box[3][0])) 186 | y1 = max(int(y1_0 - 0.1 * (y2_0 - y1_0)), 0) 187 | y2 = min(int(y2_0 + 
0.1 * (y2_0 - y1_0)), image.shape[0] - 1) 188 | x1 = max(int(x1_0 - 0.25 * (x2_0 - x1_0)), 0) 189 | x2 = min(int(x2_0 + 0.25 * (x2_0 - x1_0)), image.shape[1] - 1) 190 | image_new=image[y1:y2,x1:x2] 191 | 192 | # # 图像处理 193 | gray_2 = image_new[:,:,0] 194 | gradX = cv2.Sobel(gray_2, ddepth = cv2.CV_32F, dx = 1, dy = 0, ksize = -1) 195 | gradY = cv2.Sobel(gray_2, ddepth = cv2.CV_32F, dx = 0, dy = 1, ksize = -1) 196 | blurred = cv2.blur(gradX, (2, 2)) 197 | (_, thresh) = cv2.threshold(blurred, 160, 255, cv2.THRESH_BINARY) 198 | # closed = cv2.erode(thresh, None, iterations = 1) 199 | # closed = cv2.dilate(closed, None, iterations = 1) 200 | closed = thresh 201 | x_plus = [] 202 | x_left = 1 203 | x_right = closed.shape[1] 204 | for jj in range(0, closed.shape[1]): 205 | plus = 0 206 | for ii in range(0, closed.shape[0]): 207 | plus = plus + closed[ii][jj] 208 | x_plus.append(plus) 209 | 210 | for jj in range(0, int(closed.shape[1] * 0.5 - 1)): 211 | if(x_plus[jj] > 0.4 * max(x_plus)): 212 | x_left = max(jj - 5, 0) 213 | break 214 | for ii in range(closed.shape[1] - 1, int(closed.shape[1] * 0.5 + 1), -1): 215 | if(x_plus[ii] > 0.4 * max(x_plus)): 216 | x_right = min(ii + 5, closed.shape[1] - 1) 217 | break 218 | 219 | image_new = image_new[:, x_left:x_right] 220 | cv2.imwrite("." 
+ img_path.split(".")[1]+'_'+str(i)+".jpg", image_new) 221 | box[0][1]=y1 222 | box[0][0]=x1 + x_left 223 | box[1][1]=y1 224 | box[1][0]=x1 + x_right 225 | box[3][1]=y2 226 | box[3][0]=x1 + x_left 227 | box[2][1]=y2 228 | box[2][0]=x1 + x_right 229 | box_List[i]=box 230 | return box_List 231 | 232 | East_model = model.East() 233 | East_model = East_model.eval() 234 | East_model = East_model.cuda() 235 | 236 | East_model.load_state_dict(torch.load(checkpoint_path)) 237 | 238 | def predict(argv=None): 239 | 240 | try: 241 | os.makedirs(output_dir_txt) 242 | os.makedirs(output_dir_pic) 243 | except OSError as e: 244 | if e.errno != 17: 245 | raise 246 | 247 | im_fn_list = get_images() 248 | start = time.time() 249 | for im_fn in im_fn_list: 250 | # print(im_fn) 251 | im = cv2.imread(im_fn)[:, :, ::-1] 252 | start_time = time.time() 253 | im_resized, (ratio_h, ratio_w) = resize_image(im) 254 | im_resized = im_resized.astype(np.float32) 255 | im_resized = Variable(torch.from_numpy(im_resized)).cuda() 256 | im_resized = im_resized.unsqueeze(0) 257 | im_resized = im_resized.permute(0, 3, 1, 2) 258 | 259 | timer = {'net': 0, 'restore': 0, 'nms': 0} 260 | 261 | score, geometry = East_model(im_resized) 262 | score = score.permute(0, 2, 3, 1) 263 | geometry = geometry.permute(0, 2, 3, 1) 264 | score = score.data.cpu().numpy() 265 | geometry = geometry.data.cpu().numpy() 266 | 267 | 268 | boxes, timer = detect(score_map=score, geo_map=geometry, timer=timer) 269 | 270 | 271 | if boxes is not None: 272 | boxes = boxes[:, :8].reshape((-1, 4, 2)) 273 | boxes[:, :, 0] /= ratio_w 274 | boxes[:, :, 1] /= ratio_h 275 | 276 | 277 | if boxes is not None: 278 | res_file = os.path.join(output_dir_txt, '{}.txt'.format( 279 | os.path.basename(im_fn).split('.')[0])) 280 | 281 | with open(res_file, 'w') as f: 282 | for box in boxes: 283 | 284 | box = sort_poly(box.astype(np.int32)) 285 | if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3] - box[0]) < 5: 286 | continue 287 | 
f.write('{}, {}, {}, {}, {}, {}, {}, {}\r\n'.format( 288 | box[0, 0], box[0, 1], box[1, 0], box[1, 1], box[2, 0], box[2, 1], box[3, 0], box[3, 1])) 289 | cv2.polylines(im[:, :, ::-1], [box.astype(np.int32).reshape((-1, 1, 2))], True, 290 | color=(255, 255, 0), thickness=1) 291 | 292 | img_path = os.path.join(output_dir_pic, os.path.basename(im_fn)) 293 | cv2.imwrite(img_path, im[:, :, ::-1]) 294 | 295 | during = time.time() - start 296 | print('average :{:.6f}'.format(during / len(im_fn_list))) 297 | 298 | if __name__ == "__main__": 299 | predict() 300 | -------------------------------------------------------------------------------- /lanms/include/pybind11/stl.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/stl.h: Transparent conversion for STL data types 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | #if defined(_MSC_VER) 22 | #pragma warning(push) 23 | #pragma warning(disable: 4127) // warning C4127: Conditional expression is constant 24 | #endif 25 | 26 | #ifdef __has_include 27 | // std::optional (but including it in c++14 mode isn't allowed) 28 | # if defined(PYBIND11_CPP17) && __has_include() 29 | # include 30 | # define PYBIND11_HAS_OPTIONAL 1 31 | # endif 32 | // std::experimental::optional (but not allowed in c++11 mode) 33 | # if defined(PYBIND11_CPP14) && __has_include() 34 | # include 35 | # define PYBIND11_HAS_EXP_OPTIONAL 1 36 | # endif 37 | // std::variant 38 | # if defined(PYBIND11_CPP17) && __has_include() 39 | # include 40 | # define PYBIND11_HAS_VARIANT 1 41 | # endif 42 | #elif defined(_MSC_VER) && defined(PYBIND11_CPP17) 43 | # include 44 | # include 45 | # define PYBIND11_HAS_OPTIONAL 1 46 | # define 
PYBIND11_HAS_VARIANT 1 47 | #endif 48 | 49 | NAMESPACE_BEGIN(pybind11) 50 | NAMESPACE_BEGIN(detail) 51 | 52 | /// Extracts an const lvalue reference or rvalue reference for U based on the type of T (e.g. for 53 | /// forwarding a container element). Typically used indirect via forwarded_type(), below. 54 | template 55 | using forwarded_type = conditional_t< 56 | std::is_lvalue_reference::value, remove_reference_t &, remove_reference_t &&>; 57 | 58 | /// Forwards a value U as rvalue or lvalue according to whether T is rvalue or lvalue; typically 59 | /// used for forwarding a container's elements. 60 | template 61 | forwarded_type forward_like(U &&u) { 62 | return std::forward>(std::forward(u)); 63 | } 64 | 65 | template struct set_caster { 66 | using type = Type; 67 | using key_conv = make_caster; 68 | 69 | bool load(handle src, bool convert) { 70 | if (!isinstance(src)) 71 | return false; 72 | auto s = reinterpret_borrow(src); 73 | value.clear(); 74 | for (auto entry : s) { 75 | key_conv conv; 76 | if (!conv.load(entry, convert)) 77 | return false; 78 | value.insert(cast_op(std::move(conv))); 79 | } 80 | return true; 81 | } 82 | 83 | template 84 | static handle cast(T &&src, return_value_policy policy, handle parent) { 85 | pybind11::set s; 86 | for (auto &value: src) { 87 | auto value_ = reinterpret_steal(key_conv::cast(forward_like(value), policy, parent)); 88 | if (!value_ || !s.add(value_)) 89 | return handle(); 90 | } 91 | return s.release(); 92 | } 93 | 94 | PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name() + _("]")); 95 | }; 96 | 97 | template struct map_caster { 98 | using key_conv = make_caster; 99 | using value_conv = make_caster; 100 | 101 | bool load(handle src, bool convert) { 102 | if (!isinstance(src)) 103 | return false; 104 | auto d = reinterpret_borrow(src); 105 | value.clear(); 106 | for (auto it : d) { 107 | key_conv kconv; 108 | value_conv vconv; 109 | if (!kconv.load(it.first.ptr(), convert) || 110 | !vconv.load(it.second.ptr(), 
convert)) 111 | return false; 112 | value.emplace(cast_op(std::move(kconv)), cast_op(std::move(vconv))); 113 | } 114 | return true; 115 | } 116 | 117 | template 118 | static handle cast(T &&src, return_value_policy policy, handle parent) { 119 | dict d; 120 | for (auto &kv: src) { 121 | auto key = reinterpret_steal(key_conv::cast(forward_like(kv.first), policy, parent)); 122 | auto value = reinterpret_steal(value_conv::cast(forward_like(kv.second), policy, parent)); 123 | if (!key || !value) 124 | return handle(); 125 | d[key] = value; 126 | } 127 | return d.release(); 128 | } 129 | 130 | PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name() + _(", ") + value_conv::name() + _("]")); 131 | }; 132 | 133 | template struct list_caster { 134 | using value_conv = make_caster; 135 | 136 | bool load(handle src, bool convert) { 137 | if (!isinstance(src)) 138 | return false; 139 | auto s = reinterpret_borrow(src); 140 | value.clear(); 141 | reserve_maybe(s, &value); 142 | for (auto it : s) { 143 | value_conv conv; 144 | if (!conv.load(it, convert)) 145 | return false; 146 | value.push_back(cast_op(std::move(conv))); 147 | } 148 | return true; 149 | } 150 | 151 | private: 152 | template ().reserve(0)), void>::value, int> = 0> 154 | void reserve_maybe(sequence s, Type *) { value.reserve(s.size()); } 155 | void reserve_maybe(sequence, void *) { } 156 | 157 | public: 158 | template 159 | static handle cast(T &&src, return_value_policy policy, handle parent) { 160 | list l(src.size()); 161 | size_t index = 0; 162 | for (auto &value: src) { 163 | auto value_ = reinterpret_steal(value_conv::cast(forward_like(value), policy, parent)); 164 | if (!value_) 165 | return handle(); 166 | PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference 167 | } 168 | return l.release(); 169 | } 170 | 171 | PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name() + _("]")); 172 | }; 173 | 174 | template struct type_caster> 175 | : list_caster, Type> { }; 176 | 
177 | template struct type_caster> 178 | : list_caster, Type> { }; 179 | 180 | template struct array_caster { 181 | using value_conv = make_caster; 182 | 183 | private: 184 | template 185 | bool require_size(enable_if_t size) { 186 | if (value.size() != size) 187 | value.resize(size); 188 | return true; 189 | } 190 | template 191 | bool require_size(enable_if_t size) { 192 | return size == Size; 193 | } 194 | 195 | public: 196 | bool load(handle src, bool convert) { 197 | if (!isinstance(src)) 198 | return false; 199 | auto l = reinterpret_borrow(src); 200 | if (!require_size(l.size())) 201 | return false; 202 | size_t ctr = 0; 203 | for (auto it : l) { 204 | value_conv conv; 205 | if (!conv.load(it, convert)) 206 | return false; 207 | value[ctr++] = cast_op(std::move(conv)); 208 | } 209 | return true; 210 | } 211 | 212 | template 213 | static handle cast(T &&src, return_value_policy policy, handle parent) { 214 | list l(src.size()); 215 | size_t index = 0; 216 | for (auto &value: src) { 217 | auto value_ = reinterpret_steal(value_conv::cast(forward_like(value), policy, parent)); 218 | if (!value_) 219 | return handle(); 220 | PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference 221 | } 222 | return l.release(); 223 | } 224 | 225 | PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name() + _(_(""), _("[") + _() + _("]")) + _("]")); 226 | }; 227 | 228 | template struct type_caster> 229 | : array_caster, Type, false, Size> { }; 230 | 231 | template struct type_caster> 232 | : array_caster, Type, true> { }; 233 | 234 | template struct type_caster> 235 | : set_caster, Key> { }; 236 | 237 | template struct type_caster> 238 | : set_caster, Key> { }; 239 | 240 | template struct type_caster> 241 | : map_caster, Key, Value> { }; 242 | 243 | template struct type_caster> 244 | : map_caster, Key, Value> { }; 245 | 246 | // This type caster is intended to be used for std::optional and std::experimental::optional 247 | template struct 
optional_caster { 248 | using value_conv = make_caster; 249 | 250 | template 251 | static handle cast(T_ &&src, return_value_policy policy, handle parent) { 252 | if (!src) 253 | return none().inc_ref(); 254 | return value_conv::cast(*std::forward(src), policy, parent); 255 | } 256 | 257 | bool load(handle src, bool convert) { 258 | if (!src) { 259 | return false; 260 | } else if (src.is_none()) { 261 | return true; // default-constructed value is already empty 262 | } 263 | value_conv inner_caster; 264 | if (!inner_caster.load(src, convert)) 265 | return false; 266 | 267 | value.emplace(cast_op(std::move(inner_caster))); 268 | return true; 269 | } 270 | 271 | PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name() + _("]")); 272 | }; 273 | 274 | #if PYBIND11_HAS_OPTIONAL 275 | template struct type_caster> 276 | : public optional_caster> {}; 277 | 278 | template<> struct type_caster 279 | : public void_caster {}; 280 | #endif 281 | 282 | #if PYBIND11_HAS_EXP_OPTIONAL 283 | template struct type_caster> 284 | : public optional_caster> {}; 285 | 286 | template<> struct type_caster 287 | : public void_caster {}; 288 | #endif 289 | 290 | /// Visit a variant and cast any found type to Python 291 | struct variant_caster_visitor { 292 | return_value_policy policy; 293 | handle parent; 294 | 295 | template 296 | handle operator()(T &&src) const { 297 | return make_caster::cast(std::forward(src), policy, parent); 298 | } 299 | }; 300 | 301 | /// Helper class which abstracts away variant's `visit` function. `std::variant` and similar 302 | /// `namespace::variant` types which provide a `namespace::visit()` function are handled here 303 | /// automatically using argument-dependent lookup. Users can provide specializations for other 304 | /// variant-like classes, e.g. `boost::variant` and `boost::apply_visitor`. 
305 | template class Variant> 306 | struct visit_helper { 307 | template 308 | static auto call(Args &&...args) -> decltype(visit(std::forward(args)...)) { 309 | return visit(std::forward(args)...); 310 | } 311 | }; 312 | 313 | /// Generic variant caster 314 | template struct variant_caster; 315 | 316 | template class V, typename... Ts> 317 | struct variant_caster> { 318 | static_assert(sizeof...(Ts) > 0, "Variant must consist of at least one alternative."); 319 | 320 | template 321 | bool load_alternative(handle src, bool convert, type_list) { 322 | auto caster = make_caster(); 323 | if (caster.load(src, convert)) { 324 | value = cast_op(caster); 325 | return true; 326 | } 327 | return load_alternative(src, convert, type_list{}); 328 | } 329 | 330 | bool load_alternative(handle, bool, type_list<>) { return false; } 331 | 332 | bool load(handle src, bool convert) { 333 | // Do a first pass without conversions to improve constructor resolution. 334 | // E.g. `py::int_(1).cast>()` needs to fill the `int` 335 | // slot of the variant. Without two-pass loading `double` would be filled 336 | // because it appears first and a conversion is possible. 337 | if (convert && load_alternative(src, false, type_list{})) 338 | return true; 339 | return load_alternative(src, convert, type_list{}); 340 | } 341 | 342 | template 343 | static handle cast(Variant &&src, return_value_policy policy, handle parent) { 344 | return visit_helper::call(variant_caster_visitor{policy, parent}, 345 | std::forward(src)); 346 | } 347 | 348 | using Type = V; 349 | PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster::name()...) 
+ _("]")); 350 | }; 351 | 352 | #if PYBIND11_HAS_VARIANT 353 | template 354 | struct type_caster> : variant_caster> { }; 355 | #endif 356 | NAMESPACE_END(detail) 357 | 358 | inline std::ostream &operator<<(std::ostream &os, const handle &obj) { 359 | os << (std::string) str(obj); 360 | return os; 361 | } 362 | 363 | NAMESPACE_END(pybind11) 364 | 365 | #if defined(_MSC_VER) 366 | #pragma warning(pop) 367 | #endif 368 | -------------------------------------------------------------------------------- /lanms/include/clipper/clipper.hpp: -------------------------------------------------------------------------------- 1 | /******************************************************************************* 2 | * * 3 | * Author : Angus Johnson * 4 | * Version : 6.4.0 * 5 | * Date : 2 July 2015 * 6 | * Website : http://www.angusj.com * 7 | * Copyright : Angus Johnson 2010-2015 * 8 | * * 9 | * License: * 10 | * Use, modification & distribution is subject to Boost Software License Ver 1. * 11 | * http://www.boost.org/LICENSE_1_0.txt * 12 | * * 13 | * Attributions: * 14 | * The code in this library is an extension of Bala Vatti's clipping algorithm: * 15 | * "A generic solution to polygon clipping" * 16 | * Communications of the ACM, Vol 35, Issue 7 (July 1992) pp 56-63. * 17 | * http://portal.acm.org/citation.cfm?id=129906 * 18 | * * 19 | * Computer graphics and geometric modeling: implementation and algorithms * 20 | * By Max K. Agoston * 21 | * Springer; 1 edition (January 4, 2005) * 22 | * http://books.google.com/books?q=vatti+clipping+agoston * 23 | * * 24 | * See also: * 25 | * "Polygon Offsetting by Computing Winding Numbers" * 26 | * Paper no. DETC2005-85513 pp. 
565-575 * 27 | * ASME 2005 International Design Engineering Technical Conferences * 28 | * and Computers and Information in Engineering Conference (IDETC/CIE2005) * 29 | * September 24-28, 2005 , Long Beach, California, USA * 30 | * http://www.me.berkeley.edu/~mcmains/pubs/DAC05OffsetPolygon.pdf * 31 | * * 32 | *******************************************************************************/ 33 | 34 | #ifndef clipper_hpp 35 | #define clipper_hpp 36 | 37 | #define CLIPPER_VERSION "6.2.6" 38 | 39 | //use_int32: When enabled 32bit ints are used instead of 64bit ints. This 40 | //improves performance but coordinate values are limited to the range +/- 46340 41 | //#define use_int32 42 | 43 | //use_xyz: adds a Z member to IntPoint. Adds a minor cost to performance. 44 | //#define use_xyz 45 | 46 | //use_lines: Enables line clipping. Adds a very minor cost to performance. 47 | #define use_lines 48 | 49 | //use_deprecated: Enables temporary support for the obsolete functions 50 | //#define use_deprecated 51 | /* NOTE(review): in this copy most text inside angle brackets (the header names after each #include below and template argument lists later in the file) appears to have been stripped, so this listing does not compile as-is; compare with lanms/include/clipper/clipper.hpp. */ 52 | #include 53 | #include 54 | #include 55 | #include 56 | #include 57 | #include 58 | #include 59 | #include 60 | #include 61 | 62 | namespace ClipperLib { 63 | /* ClipType selects the boolean operation run by Clipper::Execute; PolyType marks a path passed to AddPath(s) as subject or clip input. */ 64 | enum ClipType { ctIntersection, ctUnion, ctDifference, ctXor }; 65 | enum PolyType { ptSubject, ptClip }; 66 | //By far the most widely used winding rules for polygon filling are 67 | //EvenOdd & NonZero (GDI, GDI+, XLib, OpenGL, Cairo, AGG, Quartz, SVG, Gr32) 68 | //Other rules include Positive, Negative and ABS_GTR_EQ_TWO (only in OpenGL) 69 | //see http://glprogramming.com/red/chapter11.html 70 | enum PolyFillType { pftEvenOdd, pftNonZero, pftPositive, pftNegative }; 71 | /* cInt is the integer coordinate type used throughout: int when use_int32 is defined, otherwise signed long long. loRange/hiRange look like coordinate magnitude limits; NOTE(review): confirm against their use in clipper.cpp. */ 72 | #ifdef use_int32 73 | typedef int cInt; 74 | static cInt const loRange = 0x7FFF; 75 | static cInt const hiRange = 0x7FFF; 76 | #else 77 | typedef signed long long cInt; 78 | static cInt const loRange = 0x3FFFFFFF; 79 | static cInt const hiRange = 0x3FFFFFFFFFFFFFFFLL; 80 | typedef signed long long long64; //used by Int128 class
81 | typedef unsigned long long ulong64; 82 | 83 | #endif 84 | /* IntPoint: integer 2-D point (with an extra Z member when use_xyz is defined). The == and != operators compare only X and Y. */ 85 | struct IntPoint { 86 | cInt X; 87 | cInt Y; 88 | #ifdef use_xyz 89 | cInt Z; 90 | IntPoint(cInt x = 0, cInt y = 0, cInt z = 0): X(x), Y(y), Z(z) {}; 91 | #else 92 | IntPoint(cInt x = 0, cInt y = 0): X(x), Y(y) {}; 93 | #endif 94 | 95 | friend inline bool operator== (const IntPoint& a, const IntPoint& b) 96 | { 97 | return a.X == b.X && a.Y == b.Y; 98 | } 99 | friend inline bool operator!= (const IntPoint& a, const IntPoint& b) 100 | { 101 | return a.X != b.X || a.Y != b.Y; 102 | } 103 | }; 104 | //------------------------------------------------------------------------------ 105 | /* Path: a sequence of IntPoints; Paths: a list of Path objects. The two inline << overloads append and return the container, so appends can be chained. */ 106 | typedef std::vector< IntPoint > Path; 107 | typedef std::vector< Path > Paths; 108 | 109 | inline Path& operator <<(Path& poly, const IntPoint& p) {poly.push_back(p); return poly;} 110 | inline Paths& operator <<(Paths& polys, const Path& p) {polys.push_back(p); return polys;} 111 | 112 | std::ostream& operator <<(std::ostream &s, const IntPoint &p); 113 | std::ostream& operator <<(std::ostream &s, const Path &p); 114 | std::ostream& operator <<(std::ostream &s, const Paths &p); 115 | /* DoublePoint: floating-point point; implicitly constructible from an IntPoint. */ 116 | struct DoublePoint 117 | { 118 | double X; 119 | double Y; 120 | DoublePoint(double x = 0, double y = 0) : X(x), Y(y) {} 121 | DoublePoint(IntPoint ip) : X((double)ip.X), Y((double)ip.Y) {} 122 | }; 123 | //------------------------------------------------------------------------------ 124 | 125 | #ifdef use_xyz 126 | typedef void (*ZFillCallback)(IntPoint& e1bot, IntPoint& e1top, IntPoint& e2bot, IntPoint& e2top, IntPoint& pt); 127 | #endif 128 | /* InitOptions values are bit flags (1, 2, 4), presumably combined into the initOptions argument of the Clipper constructor; JoinType and EndType are the styles passed to ClipperOffset::AddPath(s). */ 129 | enum InitOptions {ioReverseSolution = 1, ioStrictlySimple = 2, ioPreserveCollinear = 4}; 130 | enum JoinType {jtSquare, jtRound, jtMiter}; 131 | enum EndType {etClosedPolygon, etClosedLine, etOpenButt, etOpenSquare, etOpenRound}; 132 | 133 | class PolyNode; 134 | typedef std::vector< PolyNode* > PolyNodes; 135 | /* PolyNode: one node of a PolyTree result. Contour holds its path, Childs its child nodes and Parent its parent node; Clipper and ClipperOffset are friends so they can access the private members. */ 136 | class PolyNode 137 | { 138 | public: 139 | PolyNode(); 140 |
virtual ~PolyNode(){}; 141 | Path Contour; 142 | PolyNodes Childs; 143 | PolyNode* Parent; 144 | PolyNode* GetNext() const; 145 | bool IsHole() const; 146 | bool IsOpen() const; 147 | int ChildCount() const; 148 | private: 149 | unsigned Index; //node index in Parent.Childs 150 | bool m_IsOpen; 151 | JoinType m_jointype; 152 | EndType m_endtype; 153 | PolyNode* GetNextSiblingUp() const; 154 | void AddChild(PolyNode& child); 155 | friend class Clipper; //to access Index 156 | friend class ClipperOffset; 157 | }; 158 | /* PolyTree: root of the tree-shaped result filled in by the PolyTree overloads of Clipper::Execute and ClipperOffset::Execute; its destructor calls Clear(). */ 159 | class PolyTree: public PolyNode 160 | { 161 | public: 162 | ~PolyTree(){Clear();}; 163 | PolyNode* GetFirst() const; 164 | void Clear(); 165 | int Total() const; 166 | private: 167 | PolyNodes AllNodes; 168 | friend class Clipper; //to access AllNodes 169 | }; 170 | /* Free-standing path utilities. SimplifyPolygon(s) default to the even-odd fill rule and CleanPolygon(s) to a distance of 1.415. */ 171 | bool Orientation(const Path &poly); 172 | double Area(const Path &poly); 173 | int PointInPolygon(const IntPoint &pt, const Path &path); 174 | 175 | void SimplifyPolygon(const Path &in_poly, Paths &out_polys, PolyFillType fillType = pftEvenOdd); 176 | void SimplifyPolygons(const Paths &in_polys, Paths &out_polys, PolyFillType fillType = pftEvenOdd); 177 | void SimplifyPolygons(Paths &polys, PolyFillType fillType = pftEvenOdd); 178 | 179 | void CleanPolygon(const Path& in_poly, Path& out_poly, double distance = 1.415); 180 | void CleanPolygon(Path& poly, double distance = 1.415); 181 | void CleanPolygons(const Paths& in_polys, Paths& out_polys, double distance = 1.415); 182 | void CleanPolygons(Paths& polys, double distance = 1.415); 183 | 184 | void MinkowskiSum(const Path& pattern, const Path& path, Paths& solution, bool pathIsClosed); 185 | void MinkowskiSum(const Path& pattern, const Paths& paths, Paths& solution, bool pathIsClosed); 186 | void MinkowskiDiff(const Path& poly1, const Path& poly2, Paths& solution); 187 | 188 | void PolyTreeToPaths(const PolyTree& polytree, Paths& paths); 189 | void ClosedPathsFromPolyTree(const PolyTree& polytree, Paths& paths); 190 | void
OpenPathsFromPolyTree(PolyTree& polytree, Paths& paths); 191 | 192 | void ReversePath(Path& p); 193 | void ReversePaths(Paths& p); 194 | /* IntRect: rectangle given by its left/top/right/bottom edges in cInt coordinates (the return type of ClipperBase::GetBounds). */ 195 | struct IntRect { cInt left; cInt top; cInt right; cInt bottom; }; 196 | 197 | //enums that are used internally ... 198 | enum EdgeSide { esLeft = 1, esRight = 2}; 199 | 200 | //forward declarations (for stuff used internally) ... 201 | struct TEdge; 202 | struct IntersectNode; 203 | struct LocalMinimum; 204 | struct OutPt; 205 | struct OutRec; 206 | struct Join; 207 | 208 | typedef std::vector < OutRec* > PolyOutList; 209 | typedef std::vector < TEdge* > EdgeList; 210 | typedef std::vector < Join* > JoinList; 211 | typedef std::vector < IntersectNode* > IntersectList; 212 | 213 | //------------------------------------------------------------------------------ 214 | 215 | //ClipperBase is the ancestor to the Clipper class. It should not be 216 | //instantiated directly. This class simply abstracts the conversion of sets of 217 | //polygon coordinates into edge objects that are stored in a LocalMinima list.
218 | class ClipperBase 219 | { 220 | public: 221 | ClipperBase(); 222 | virtual ~ClipperBase(); 223 | virtual bool AddPath(const Path &pg, PolyType PolyTyp, bool Closed); 224 | bool AddPaths(const Paths &ppg, PolyType PolyTyp, bool Closed); 225 | virtual void Clear(); 226 | IntRect GetBounds(); 227 | bool PreserveCollinear() {return m_PreserveCollinear;}; 228 | void PreserveCollinear(bool value) {m_PreserveCollinear = value;}; 229 | protected: 230 | void DisposeLocalMinimaList(); 231 | TEdge* AddBoundsToLML(TEdge *e, bool IsClosed); 232 | virtual void Reset(); 233 | TEdge* ProcessBound(TEdge* E, bool IsClockwise); 234 | void InsertScanbeam(const cInt Y); 235 | bool PopScanbeam(cInt &Y); 236 | bool LocalMinimaPending(); 237 | bool PopLocalMinima(cInt Y, const LocalMinimum *&locMin); 238 | OutRec* CreateOutRec(); 239 | void DisposeAllOutRecs(); 240 | void DisposeOutRec(PolyOutList::size_type index); 241 | void SwapPositionsInAEL(TEdge *edge1, TEdge *edge2); 242 | void DeleteFromAEL(TEdge *e); 243 | void UpdateEdgeIntoAEL(TEdge *&e); 244 | 245 | typedef std::vector MinimaList; 246 | MinimaList::iterator m_CurrentLM; 247 | MinimaList m_MinimaList; 248 | 249 | bool m_UseFullRange; 250 | EdgeList m_edges; 251 | bool m_PreserveCollinear; 252 | bool m_HasOpenPaths; 253 | PolyOutList m_PolyOuts; 254 | TEdge *m_ActiveEdges; 255 | 256 | typedef std::priority_queue ScanbeamList; 257 | ScanbeamList m_Scanbeam; 258 | }; 259 | //------------------------------------------------------------------------------ 260 | /* Clipper: runs a boolean ClipType operation on the subject and clip paths added through ClipperBase::AddPath(s) and writes the result to a Paths list or a PolyTree. The Execute overloads that take a single fill type default to pftEvenOdd. */ 261 | class Clipper : public virtual ClipperBase 262 | { 263 | public: 264 | Clipper(int initOptions = 0); 265 | bool Execute(ClipType clipType, 266 | Paths &solution, 267 | PolyFillType fillType = pftEvenOdd); 268 | bool Execute(ClipType clipType, 269 | Paths &solution, 270 | PolyFillType subjFillType, 271 | PolyFillType clipFillType); 272 | bool Execute(ClipType clipType, 273 | PolyTree &polytree, 274 | PolyFillType fillType = pftEvenOdd); 275 | bool Execute(ClipType
clipType, 276 | PolyTree &polytree, 277 | PolyFillType subjFillType, 278 | PolyFillType clipFillType); 279 | bool ReverseSolution() { return m_ReverseOutput; }; 280 | void ReverseSolution(bool value) {m_ReverseOutput = value;}; 281 | bool StrictlySimple() {return m_StrictSimple;}; 282 | void StrictlySimple(bool value) {m_StrictSimple = value;}; 283 | //set the callback function for z value filling on intersections (otherwise Z is 0) 284 | #ifdef use_xyz 285 | void ZFillFunction(ZFillCallback zFillFunc); 286 | #endif 287 | protected: 288 | virtual bool ExecuteInternal(); 289 | private: 290 | JoinList m_Joins; 291 | JoinList m_GhostJoins; 292 | IntersectList m_IntersectList; 293 | ClipType m_ClipType; 294 | typedef std::list MaximaList; 295 | MaximaList m_Maxima; 296 | TEdge *m_SortedEdges; 297 | bool m_ExecuteLocked; 298 | PolyFillType m_ClipFillType; 299 | PolyFillType m_SubjFillType; 300 | bool m_ReverseOutput; 301 | bool m_UsingPolyTree; 302 | bool m_StrictSimple; 303 | #ifdef use_xyz 304 | ZFillCallback m_ZFill; //custom callback 305 | #endif 306 | void SetWindingCount(TEdge& edge); 307 | bool IsEvenOddFillType(const TEdge& edge) const; 308 | bool IsEvenOddAltFillType(const TEdge& edge) const; 309 | void InsertLocalMinimaIntoAEL(const cInt botY); 310 | void InsertEdgeIntoAEL(TEdge *edge, TEdge* startEdge); 311 | void AddEdgeToSEL(TEdge *edge); 312 | bool PopEdgeFromSEL(TEdge *&edge); 313 | void CopyAELToSEL(); 314 | void DeleteFromSEL(TEdge *e); 315 | void SwapPositionsInSEL(TEdge *edge1, TEdge *edge2); 316 | bool IsContributing(const TEdge& edge) const; 317 | bool IsTopHorz(const cInt XPos); 318 | void DoMaxima(TEdge *e); 319 | void ProcessHorizontals(); 320 | void ProcessHorizontal(TEdge *horzEdge); 321 | void AddLocalMaxPoly(TEdge *e1, TEdge *e2, const IntPoint &pt); 322 | OutPt* AddLocalMinPoly(TEdge *e1, TEdge *e2, const IntPoint &pt); 323 | OutRec* GetOutRec(int idx); 324 | void AppendPolygon(TEdge *e1, TEdge *e2); 325 | void IntersectEdges(TEdge *e1,
TEdge *e2, IntPoint &pt); 326 | OutPt* AddOutPt(TEdge *e, const IntPoint &pt); 327 | OutPt* GetLastOutPt(TEdge *e); 328 | bool ProcessIntersections(const cInt topY); 329 | void BuildIntersectList(const cInt topY); 330 | void ProcessIntersectList(); 331 | void ProcessEdgesAtTopOfScanbeam(const cInt topY); 332 | void BuildResult(Paths& polys); 333 | void BuildResult2(PolyTree& polytree); 334 | void SetHoleState(TEdge *e, OutRec *outrec); 335 | void DisposeIntersectNodes(); 336 | bool FixupIntersectionOrder(); 337 | void FixupOutPolygon(OutRec &outrec); 338 | void FixupOutPolyline(OutRec &outrec); 339 | bool IsHole(TEdge *e); 340 | bool FindOwnerFromSplitRecs(OutRec &outRec, OutRec *&currOrfl); 341 | void FixHoleLinkage(OutRec &outrec); 342 | void AddJoin(OutPt *op1, OutPt *op2, const IntPoint offPt); 343 | void ClearJoins(); 344 | void ClearGhostJoins(); 345 | void AddGhostJoin(OutPt *op, const IntPoint offPt); 346 | bool JoinPoints(Join *j, OutRec* outRec1, OutRec* outRec2); 347 | void JoinCommonEdges(); 348 | void DoSimplePolygons(); 349 | void FixupFirstLefts1(OutRec* OldOutRec, OutRec* NewOutRec); 350 | void FixupFirstLefts2(OutRec* InnerOutRec, OutRec* OuterOutRec); 351 | void FixupFirstLefts3(OutRec* OldOutRec, OutRec* NewOutRec); 352 | #ifdef use_xyz 353 | void SetZ(IntPoint& pt, TEdge& e1, TEdge& e2); 354 | #endif 355 | }; 356 | //------------------------------------------------------------------------------ 357 | /* ClipperOffset: paths registered with AddPath(s), each with a JoinType and an EndType, are offset by delta in Execute, which writes to a Paths list or a PolyTree. Constructor defaults: miterLimit 2.0, roundPrecision 0.25. */ 358 | class ClipperOffset 359 | { 360 | public: 361 | ClipperOffset(double miterLimit = 2.0, double roundPrecision = 0.25); 362 | ~ClipperOffset(); 363 | void AddPath(const Path& path, JoinType joinType, EndType endType); 364 | void AddPaths(const Paths& paths, JoinType joinType, EndType endType); 365 | void Execute(Paths& solution, double delta); 366 | void Execute(PolyTree& solution, double delta); 367 | void Clear(); 368 | double MiterLimit; 369 | double ArcTolerance; 370 | private: 371 | Paths m_destPolys; 372 | Path m_srcPoly; 373 | Path
m_destPoly; 374 | std::vector m_normals; 375 | double m_delta, m_sinA, m_sin, m_cos; 376 | double m_miterLim, m_StepsPerRad; 377 | IntPoint m_lowest; 378 | PolyNode m_polyNodes; 379 | 380 | void FixOrientations(); 381 | void DoOffset(double delta); 382 | void OffsetPoint(int j, int& k, JoinType jointype); 383 | void DoSquare(int j, int k); 384 | void DoMiter(int j, int k, double r); 385 | void DoRound(int j, int k); 386 | }; 387 | //------------------------------------------------------------------------------ 388 | /* clipperException: std::exception whose what() returns the description passed to the constructor. */ 389 | class clipperException : public std::exception 390 | { 391 | public: 392 | clipperException(const char* description): m_descr(description) {} 393 | virtual ~clipperException() throw() {} 394 | virtual const char* what() const throw() {return m_descr.c_str();} 395 | private: 396 | std::string m_descr; 397 | }; 398 | //------------------------------------------------------------------------------ 399 | 400 | } //ClipperLib namespace 401 | 402 | #endif //clipper_hpp 403 | 404 | 405 | -------------------------------------------------------------------------------- /lanms/include/pybind11/attr.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/attr.h: Infrastructure for processing custom 3 | type and function attributes 4 | 5 | Copyright (c) 2016 Wenzel Jakob 6 | 7 | All rights reserved. Use of this source code is governed by a 8 | BSD-style license that can be found in the LICENSE file.
9 | */ 10 | /* NOTE(review): in this copy most text inside angle brackets (template parameter lists and template arguments) appears to have been stripped, so this listing does not compile as-is; compare with lanms/include/pybind11/attr.h. */ 11 | #pragma once 12 | 13 | #include "cast.h" 14 | 15 | NAMESPACE_BEGIN(pybind11) 16 | 17 | /// \addtogroup annotations 18 | /// @{ 19 | 20 | /// Annotation for methods 21 | struct is_method { handle class_; is_method(const handle &c) : class_(c) { } }; 22 | 23 | /// Annotation for operators 24 | struct is_operator { }; 25 | 26 | /// Annotation for parent scope 27 | struct scope { handle value; scope(const handle &s) : value(s) { } }; 28 | 29 | /// Annotation for documentation 30 | struct doc { const char *value; doc(const char *value) : value(value) { } }; 31 | 32 | /// Annotation for function names 33 | struct name { const char *value; name(const char *value) : value(value) { } }; 34 | 35 | /// Annotation indicating that a function is an overload associated with a given "sibling" 36 | struct sibling { handle value; sibling(const handle &value) : value(value.ptr()) { } }; 37 | 38 | /// Annotation indicating that a class derives from another given type 39 | template struct base { 40 | PYBIND11_DEPRECATED("base() was deprecated in favor of specifying 'T' as a template argument to class_") 41 | base() { } 42 | }; 43 | 44 | /// Keep patient alive while nurse lives 45 | template struct keep_alive { }; 46 | 47 | /// Annotation indicating that a class is involved in a multiple inheritance relationship 48 | struct multiple_inheritance { }; 49 | 50 | /// Annotation which enables dynamic attributes, i.e. adds `__dict__` to a class 51 | struct dynamic_attr { }; 52 | 53 | /// Annotation which enables the buffer protocol for a type 54 | struct buffer_protocol { }; 55 | 56 | /// Annotation which requests that a special metaclass is created for a type 57 | struct metaclass { 58 | handle value; 59 | 60 | PYBIND11_DEPRECATED("py::metaclass() is no longer required.
It's turned on by default now.") 61 | metaclass() {} 62 | 63 | /// Override pybind11's default metaclass 64 | explicit metaclass(handle value) : value(value) { } 65 | }; 66 | 67 | /// Annotation to mark enums as an arithmetic type 68 | struct arithmetic { }; 69 | 70 | /** \rst 71 | A call policy which places one or more guard variables (``Ts...``) around the function call. 72 | 73 | For example, this definition: 74 | 75 | .. code-block:: cpp 76 | 77 | m.def("foo", foo, py::call_guard()); 78 | 79 | is equivalent to the following pseudocode: 80 | 81 | .. code-block:: cpp 82 | 83 | m.def("foo", [](args...) { 84 | T scope_guard; 85 | return foo(args...); // forwarded arguments 86 | }); 87 | \endrst */ 88 | template struct call_guard; 89 | 90 | template <> struct call_guard<> { using type = detail::void_type; }; 91 | /* An empty call_guard yields detail::void_type; a single guard type is used directly and must be default constructible; several guards are composed through the nested type struct below (first guard, then the guard built from the remaining types). */ 92 | template 93 | struct call_guard { 94 | static_assert(std::is_default_constructible::value, 95 | "The guard type must be default constructible"); 96 | 97 | using type = T; 98 | }; 99 | 100 | template 101 | struct call_guard { 102 | struct type { 103 | T guard{}; // Compose multiple guard types with left-to-right default-constructor order 104 | typename call_guard::type next{}; 105 | }; 106 | }; 107 | 108 | /// @} annotations 109 | 110 | NAMESPACE_BEGIN(detail) 111 | /* Forward declarations */ 112 | enum op_id : int; 113 | enum op_type : int; 114 | struct undefined_t; 115 | template struct op_; 116 | template struct init; 117 | template struct init_alias; 118 | inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret); 119 | 120 | /// Internal data structure which holds metadata about a keyword argument 121 | struct argument_record { 122 | const char *name; ///< Argument name 123 | const char *descr; ///< Human-readable version of the argument value 124 | handle value; ///< Associated Python object 125 | bool convert : 1; ///< True if the argument is allowed to convert when loading 126 | bool none : 1; ///< True if None
is allowed when loading 127 | 128 | argument_record(const char *name, const char *descr, handle value, bool convert, bool none) 129 | : name(name), descr(descr), value(value), convert(convert), none(none) { } 130 | }; 131 | 132 | /// Internal data structure which holds metadata about a bound function (signature, overloads, etc.) 133 | struct function_record { 134 | function_record() 135 | : is_constructor(false), is_stateless(false), is_operator(false), 136 | has_args(false), has_kwargs(false), is_method(false) { } 137 | /* The one-bit flags below are set in this constructor because bit-fields cannot have default member initializers before C++20 (this project builds with -std=c++11). */ 138 | /// Function name 139 | char *name = nullptr; /* why no C++ strings? They generate heavier code.. */ 140 | 141 | /// User-specified documentation string 142 | char *doc = nullptr; 143 | 144 | /// Human-readable version of the function signature 145 | char *signature = nullptr; 146 | 147 | /// List of registered keyword arguments 148 | std::vector args; 149 | 150 | /// Pointer to lambda function which converts arguments and performs the actual call 151 | handle (*impl) (function_call &) = nullptr; 152 | 153 | /// Storage for the wrapped function pointer and captured data, if any 154 | void *data[3] = { }; 155 | 156 | /// Pointer to custom destructor for 'data' (if needed) 157 | void (*free_data) (function_record *ptr) = nullptr; 158 | 159 | /// Return value policy associated with this function 160 | return_value_policy policy = return_value_policy::automatic; 161 | 162 | /// True if name == '__init__' 163 | bool is_constructor : 1; 164 | 165 | /// True if this is a stateless function pointer 166 | bool is_stateless : 1; 167 | 168 | /// True if this is an operator (__add__), etc.
169 | bool is_operator : 1; 170 | 171 | /// True if the function has a '*args' argument 172 | bool has_args : 1; 173 | 174 | /// True if the function has a '**kwargs' argument 175 | bool has_kwargs : 1; 176 | 177 | /// True if this is a method 178 | bool is_method : 1; 179 | 180 | /// Number of arguments (including py::args and/or py::kwargs, if present). NOTE: neither the constructor nor an in-class initializer sets this, so it must be assigned before use (function_call reads it to reserve space). 181 | std::uint16_t nargs; 182 | 183 | /// Python method object 184 | PyMethodDef *def = nullptr; 185 | 186 | /// Python handle to the parent scope (a class or a module) 187 | handle scope; 188 | 189 | /// Python handle to the sibling function representing an overload chain 190 | handle sibling; 191 | 192 | /// Pointer to next overload 193 | function_record *next = nullptr; 194 | }; 195 | 196 | /// Special data structure which (temporarily) holds metadata about a bound class 197 | struct type_record { 198 | PYBIND11_NOINLINE type_record() 199 | : multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false) { } 200 | 201 | /// Handle to the parent scope 202 | handle scope; 203 | 204 | /// Name of the class 205 | const char *name = nullptr; 206 | 207 | /// Pointer to RTTI type_info data structure 208 | const std::type_info *type = nullptr; 209 | 210 | /// How large is the underlying C++ type? 211 | size_t type_size = 0; 212 | 213 | /// How large is the type's holder?
214 | size_t holder_size = 0; 215 | 216 | /// The global operator new can be overridden with a class-specific variant 217 | void *(*operator_new)(size_t) = ::operator new; 218 | 219 | /// Function pointer to class_<..>::init_instance 220 | void (*init_instance)(instance *, const void *) = nullptr; 221 | 222 | /// Function pointer to class_<..>::dealloc 223 | void (*dealloc)(const detail::value_and_holder &) = nullptr; 224 | 225 | /// List of base classes of the newly created type 226 | list bases; 227 | 228 | /// Optional docstring 229 | const char *doc = nullptr; 230 | 231 | /// Custom metaclass (optional) 232 | handle metaclass; 233 | 234 | /// Multiple inheritance marker 235 | bool multiple_inheritance : 1; 236 | 237 | /// Does the class manage a __dict__? 238 | bool dynamic_attr : 1; 239 | 240 | /// Does the class implement the buffer protocol? 241 | bool buffer_protocol : 1; 242 | 243 | /// Is the default (unique_ptr) holder type used? NOTE: not initialized by the constructor above; it must be assigned before add_base() compares it with the base type's holder. 244 | bool default_holder : 1; 245 | /* add_base: fails if the base type is not registered or if exactly one of this type and the base uses the default holder; otherwise appends the base to bases, turns on dynamic_attr when the base type has a nonzero tp_dictoffset, and, when caster is non-null, records it as an implicit cast on the base. */ 246 | PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) { 247 | auto base_info = detail::get_type_info(base, false); 248 | if (!base_info) { 249 | std::string tname(base.name()); 250 | detail::clean_type_id(tname); 251 | pybind11_fail("generic_type: type \"" + std::string(name) + 252 | "\" referenced unknown base type \"" + tname + "\""); 253 | } 254 | 255 | if (default_holder != base_info->default_holder) { 256 | std::string tname(base.name()); 257 | detail::clean_type_id(tname); 258 | pybind11_fail("generic_type: type \"" + std::string(name) + "\" " + 259 | (default_holder ? "does not have" : "has") + 260 | " a non-default holder type while its base \"" + tname + "\" " + 261 | (base_info->default_holder ?
"does not" : "does")); 262 | } 263 | 264 | bases.append((PyObject *) base_info->type); 265 | 266 | if (base_info->type->tp_dictoffset != 0) 267 | dynamic_attr = true; 268 | 269 | if (caster) 270 | base_info->implicit_casts.emplace_back(type, caster); 271 | } 272 | }; 273 | /* function_call constructor: stores the record f and the parent handle p, and reserves room for f.nargs entries in args and args_convert. */ 274 | inline function_call::function_call(function_record &f, handle p) : 275 | func(f), parent(p) { 276 | args.reserve(f.nargs); 277 | args_convert.reserve(f.nargs); 278 | } 279 | 280 | /** 281 | * Partial template specializations to process custom attributes provided to 282 | * cpp_function_ and class_. These are either used to initialize the respective 283 | * fields in the type_record and function_record data structures or executed at 284 | * runtime to deal with custom call policies (e.g. keep_alive). 285 | */ 286 | template struct process_attribute; 287 | 288 | template struct process_attribute_default { 289 | /// Default implementation: do nothing 290 | static void init(const T &, function_record *) { } 291 | static void init(const T &, type_record *) { } 292 | static void precall(function_call &) { } 293 | static void postcall(function_call &, handle) { } 294 | }; 295 | 296 | /// Process an attribute specifying the function's name 297 | template <> struct process_attribute : process_attribute_default { 298 | static void init(const name &n, function_record *r) { r->name = const_cast(n.value); } 299 | }; 300 | 301 | /// Process an attribute specifying the function's docstring 302 | template <> struct process_attribute : process_attribute_default { 303 | static void init(const doc &n, function_record *r) { r->doc = const_cast(n.value); } 304 | }; 305 | 306 | /// Process an attribute specifying the function's docstring (provided as a C-style string) 307 | template <> struct process_attribute : process_attribute_default { 308 | static void init(const char *d, function_record *r) { r->doc = const_cast(d); } 309 | static void init(const char *d, type_record *r) { r->doc = const_cast(d); } 310 |
}; 311 | template <> struct process_attribute : process_attribute { }; 312 | 313 | /// Process an attribute indicating the function's return value policy 314 | template <> struct process_attribute : process_attribute_default { 315 | static void init(const return_value_policy &p, function_record *r) { r->policy = p; } 316 | }; 317 | 318 | /// Process an attribute which indicates that this is an overloaded function associated with a given sibling 319 | template <> struct process_attribute : process_attribute_default { 320 | static void init(const sibling &s, function_record *r) { r->sibling = s.value; } 321 | }; 322 | 323 | /// Process an attribute which indicates that this function is a method 324 | template <> struct process_attribute : process_attribute_default { 325 | static void init(const is_method &s, function_record *r) { r->is_method = true; r->scope = s.class_; } 326 | }; 327 | 328 | /// Process an attribute which indicates the parent scope of a method 329 | template <> struct process_attribute : process_attribute_default { 330 | static void init(const scope &s, function_record *r) { r->scope = s.value; } 331 | }; 332 | 333 | /// Process an attribute which indicates that this function is an operator 334 | template <> struct process_attribute : process_attribute_default { 335 | static void init(const is_operator &, function_record *r) { r->is_operator = true; } 336 | }; 337 | 338 | /// Process a keyword argument attribute (*without* a default value); for methods, an implicit "self" record is inserted before the first named argument 339 | template <> struct process_attribute : process_attribute_default { 340 | static void init(const arg &a, function_record *r) { 341 | if (r->is_method && r->args.empty()) 342 | r->args.emplace_back("self", nullptr, handle(), true /*convert*/, false /*none not allowed*/); 343 | r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none); 344 | } 345 | }; 346 | 347 | /// Process a keyword argument attribute (*with* a default value); fails via pybind11_fail if the default could not be converted to a Python object 348 | template <> struct process_attribute :
process_attribute_default { 349 | static void init(const arg_v &a, function_record *r) { 350 | if (r->is_method && r->args.empty()) 351 | r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/, false /*none not allowed*/); 352 | /* A null a.value means the default argument could not be converted to a Python object; the detailed description is only built when NDEBUG is not defined. */ 353 | if (!a.value) { 354 | #if !defined(NDEBUG) 355 | std::string descr("'"); 356 | if (a.name) descr += std::string(a.name) + ": "; 357 | descr += a.type + "'"; 358 | if (r->is_method) { 359 | if (r->name) 360 | descr += " in method '" + (std::string) str(r->scope) + "." + (std::string) r->name + "'"; 361 | else 362 | descr += " in method of '" + (std::string) str(r->scope) + "'"; 363 | } else if (r->name) { 364 | descr += " in function '" + (std::string) r->name + "'"; 365 | } 366 | pybind11_fail("arg(): could not convert default argument " 367 | + descr + " into a Python object (type not registered yet?)"); 368 | #else 369 | pybind11_fail("arg(): could not convert default argument " 370 | "into a Python object (type not registered yet?). " 371 | "Compile in debug mode for more information."); 372 | #endif 373 | } 374 | r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none); 375 | } 376 | }; 377 | 378 | /// Process a parent class attribute.
Single inheritance only (class_ itself already guarantees that) 379 | template 380 | struct process_attribute::value>> : process_attribute_default { 381 | static void init(const handle &h, type_record *r) { r->bases.append(h); } 382 | }; 383 | 384 | /// Process a parent class attribute (deprecated, does not support multiple inheritance) 385 | template 386 | struct process_attribute> : process_attribute_default> { 387 | static void init(const base &, type_record *r) { r->add_base(typeid(T), nullptr); } 388 | }; 389 | 390 | /// Process a multiple inheritance attribute 391 | template <> 392 | struct process_attribute : process_attribute_default { 393 | static void init(const multiple_inheritance &, type_record *r) { r->multiple_inheritance = true; } 394 | }; 395 | /* The dynamic_attr, buffer_protocol and metaclass processors below each copy their annotation into the matching type_record field. */ 396 | template <> 397 | struct process_attribute : process_attribute_default { 398 | static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; } 399 | }; 400 | 401 | template <> 402 | struct process_attribute : process_attribute_default { 403 | static void init(const buffer_protocol &, type_record *r) { r->buffer_protocol = true; } 404 | }; 405 | 406 | template <> 407 | struct process_attribute : process_attribute_default { 408 | static void init(const metaclass &m, type_record *r) { r->metaclass = m.value; } 409 | }; 410 | 411 | 412 | /// Process an 'arithmetic' attribute for enums (does nothing here) 413 | template <> 414 | struct process_attribute : process_attribute_default {}; 415 | 416 | template 417 | struct process_attribute> : process_attribute_default> { }; 418 | 419 | /** 420 | * Process a keep_alive call policy -- invokes keep_alive_impl during the 421 | * pre-call handler if both Nurse, Patient != 0 and use the post-call handler 422 | * otherwise 423 | */ 424 | template struct process_attribute> : public process_attribute_default> { 425 | template = 0> 426 | static void precall(function_call &call) { keep_alive_impl(Nurse, Patient, call, handle()); } 427 | template = 0>
428 | static void postcall(function_call &, handle) { } 429 | template = 0> 430 | static void precall(function_call &) { } 431 | template = 0> 432 | static void postcall(function_call &call, handle ret) { keep_alive_impl(Nurse, Patient, call, ret); } 433 | }; 434 | 435 | /// Apply process_attribute init/precall/postcall to every type in Args..., in order. This is not recursive: each parameter pack is expanded into a dummy int array (the leading 0 keeps the array non-empty when Args is empty), whose initializers are evaluated left to right 436 | template struct process_attributes { 437 | static void init(const Args&... args, function_record *r) { 438 | int unused[] = { 0, (process_attribute::type>::init(args, r), 0) ... }; 439 | ignore_unused(unused); 440 | } 441 | static void init(const Args&... args, type_record *r) { 442 | int unused[] = { 0, (process_attribute::type>::init(args, r), 0) ... }; 443 | ignore_unused(unused); 444 | } 445 | static void precall(function_call &call) { 446 | int unused[] = { 0, (process_attribute::type>::precall(call), 0) ... }; 447 | ignore_unused(unused); 448 | } 449 | static void postcall(function_call &call, handle fn_ret) { 450 | int unused[] = { 0, (process_attribute::type>::postcall(call, fn_ret), 0) ...
}; 451 | ignore_unused(unused); 452 | } 453 | }; 454 | 455 | template 456 | using is_call_guard = is_instantiation; 457 | 458 | /// Extract the ``type`` from the first `call_guard` in `Extras...` (or `void_type` if none found) 459 | template 460 | using extract_guard_t = typename exactly_one_t, Extra...>::type; 461 | 462 | /// Check the number of named arguments at compile time: true when no named arguments are given, or when self + named + has_args + has_kwargs equals nargs 463 | template ::value...), 465 | size_t self = constexpr_sum(std::is_same::value...)> 466 | constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) { 467 | return named == 0 || (self + named + has_args + has_kwargs) == nargs; 468 | } 469 | 470 | NAMESPACE_END(detail) 471 | NAMESPACE_END(pybind11) 472 | -------------------------------------------------------------------------------- /lanms/include/pybind11/stl_bind.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/stl_bind.h: Binding generators for STL data types 3 | 4 | Copyright (c) 2016 Sergey Lyskov and Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file.
8 | */ 9 | /* NOTE(review): in this copy most text inside angle brackets (include names, template parameter lists and template arguments) appears to have been stripped, so this listing does not compile as-is; compare with lanms/include/pybind11/stl_bind.h. */ 10 | #pragma once 11 | 12 | #include "common.h" 13 | #include "operators.h" 14 | 15 | #include 16 | #include 17 | 18 | NAMESPACE_BEGIN(pybind11) 19 | NAMESPACE_BEGIN(detail) 20 | 21 | /* SFINAE helper class used by 'is_comparable': detects whether a type supports ==, exposes a value_type (vector/map-like) or exposes first_type and second_type (pair-like) */ 22 | template struct container_traits { 23 | template static std::true_type test_comparable(decltype(std::declval() == std::declval())*); 24 | template static std::false_type test_comparable(...); 25 | template static std::true_type test_value(typename T2::value_type *); 26 | template static std::false_type test_value(...); 27 | template static std::true_type test_pair(typename T2::first_type *, typename T2::second_type *); 28 | template static std::false_type test_pair(...); 29 | 30 | static constexpr const bool is_comparable = std::is_same(nullptr))>::value; 31 | static constexpr const bool is_pair = std::is_same(nullptr, nullptr))>::value; 32 | static constexpr const bool is_vector = std::is_same(nullptr))>::value; 33 | static constexpr const bool is_element = !is_pair && !is_vector; 34 | }; 35 | 36 | /* Default: is_comparable -> std::false_type */ 37 | template 38 | struct is_comparable : std::false_type { }; 39 | 40 | /* For non-map data structures, check whether operator== can be instantiated */ 41 | template 42 | struct is_comparable< 43 | T, enable_if_t::is_element && 44 | container_traits::is_comparable>> 45 | : std::true_type { }; 46 | 47 | /* For a vector/map data structure, recursively check the value type (which is std::pair for maps) */ 48 | template 49 | struct is_comparable::is_vector>> { 50 | static constexpr const bool value = 51 | is_comparable::value; 52 | }; 53 | 54 | /* For pairs, recursively check the two data types */ 55 | template 56 | struct is_comparable::is_pair>> { 57 | static constexpr const bool value = 58 | is_comparable::value && 59 | is_comparable::value; 60 | }; 61 | 62 | /* Fallback functions: variadic no-ops selected when the enable_if_t constraint of the matching overload below is not satisfied */ 63 | template void vector_if_copy_constructible(const Args &...)
{ } 64 | template void vector_if_equal_operator(const Args &...) { } 65 | template void vector_if_insertion_operator(const Args &...) { } 66 | template void vector_modifiers(const Args &...) { } 67 | /* Binds a copy constructor on cl. The enable_if_t condition is missing in this copy (presumably copy constructibility of Vector, per the function name); when it does not hold, the no-op fallback above is used instead. */ 68 | template 69 | void vector_if_copy_constructible(enable_if_t::value, Class_> &cl) { 70 | cl.def(init(), "Copy constructor"); 71 | } 72 | /* Binds == and != plus count, remove and __contains__, all of which compare elements with operator== (std::count / std::find); remove raises value_error when x is not present. */ 73 | template 74 | void vector_if_equal_operator(enable_if_t::value, Class_> &cl) { 75 | using T = typename Vector::value_type; 76 | 77 | cl.def(self == self); 78 | cl.def(self != self); 79 | 80 | cl.def("count", 81 | [](const Vector &v, const T &x) { 82 | return std::count(v.begin(), v.end(), x); 83 | }, 84 | arg("x"), 85 | "Return the number of times ``x`` appears in the list" 86 | ); 87 | 88 | cl.def("remove", [](Vector &v, const T &x) { 89 | auto p = std::find(v.begin(), v.end(), x); 90 | if (p != v.end()) 91 | v.erase(p); 92 | else 93 | throw value_error(); 94 | }, 95 | arg("x"), 96 | "Remove the first item from the list whose value is x. " 97 | "It is an error if there is no such item." 98 | ); 99 | 100 | cl.def("__contains__", 101 | [](const Vector &v, const T &x) { 102 | return std::find(v.begin(), v.end(), x) != v.end(); 103 | }, 104 | arg("x"), 105 | "Return true the container contains ``x``" 106 | ); 107 | } 108 | 109 | // Vector modifiers -- requires a copyable vector_type: 110 | // (Technically, some of these (pop and __delitem__) don't actually require copyability, but it seems 111 | // silly to allow deletion but not insertion, so include them here too.)
112 | template 113 | void vector_modifiers(enable_if_t::value, Class_> &cl) { 114 | using T = typename Vector::value_type; 115 | using SizeType = typename Vector::size_type; 116 | using DiffType = typename Vector::difference_type; 117 | 118 | cl.def("append", 119 | [](Vector &v, const T &value) { v.push_back(value); }, 120 | arg("x"), 121 | "Add an item to the end of the list"); 122 | 123 | cl.def("__init__", [](Vector &v, iterable it) { 124 | new (&v) Vector(); 125 | try { 126 | v.reserve(len(it)); 127 | for (handle h : it) 128 | v.push_back(h.cast()); 129 | } catch (...) { 130 | v.~Vector(); 131 | throw; 132 | } 133 | }); 134 | 135 | cl.def("extend", 136 | [](Vector &v, const Vector &src) { 137 | v.insert(v.end(), src.begin(), src.end()); 138 | }, 139 | arg("L"), 140 | "Extend the list by appending all the items in the given list" 141 | ); 142 | 143 | cl.def("insert", 144 | [](Vector &v, SizeType i, const T &x) { 145 | if (i > v.size()) 146 | throw index_error(); 147 | v.insert(v.begin() + (DiffType) i, x); 148 | }, 149 | arg("i") , arg("x"), 150 | "Insert an item at a given position." 
151 | ); 152 | 153 | cl.def("pop", 154 | [](Vector &v) { 155 | if (v.empty()) 156 | throw index_error(); 157 | T t = v.back(); 158 | v.pop_back(); 159 | return t; 160 | }, 161 | "Remove and return the last item" 162 | ); 163 | 164 | cl.def("pop", 165 | [](Vector &v, SizeType i) { 166 | if (i >= v.size()) 167 | throw index_error(); 168 | T t = v[i]; 169 | v.erase(v.begin() + (DiffType) i); 170 | return t; 171 | }, 172 | arg("i"), 173 | "Remove and return the item at index ``i``" 174 | ); 175 | 176 | cl.def("__setitem__", 177 | [](Vector &v, SizeType i, const T &t) { 178 | if (i >= v.size()) 179 | throw index_error(); 180 | v[i] = t; 181 | } 182 | ); 183 | 184 | /// Slicing protocol 185 | cl.def("__getitem__", 186 | [](const Vector &v, slice slice) -> Vector * { 187 | size_t start, stop, step, slicelength; 188 | 189 | if (!slice.compute(v.size(), &start, &stop, &step, &slicelength)) 190 | throw error_already_set(); 191 | 192 | Vector *seq = new Vector(); 193 | seq->reserve((size_t) slicelength); 194 | 195 | for (size_t i=0; ipush_back(v[start]); 197 | start += step; 198 | } 199 | return seq; 200 | }, 201 | arg("s"), 202 | "Retrieve list elements using a slice object" 203 | ); 204 | 205 | cl.def("__setitem__", 206 | [](Vector &v, slice slice, const Vector &value) { 207 | size_t start, stop, step, slicelength; 208 | if (!slice.compute(v.size(), &start, &stop, &step, &slicelength)) 209 | throw error_already_set(); 210 | 211 | if (slicelength != value.size()) 212 | throw std::runtime_error("Left and right hand size of slice assignment have different sizes!"); 213 | 214 | for (size_t i=0; i= v.size()) 225 | throw index_error(); 226 | v.erase(v.begin() + DiffType(i)); 227 | }, 228 | "Delete the list elements at index ``i``" 229 | ); 230 | 231 | cl.def("__delitem__", 232 | [](Vector &v, slice slice) { 233 | size_t start, stop, step, slicelength; 234 | 235 | if (!slice.compute(v.size(), &start, &stop, &step, &slicelength)) 236 | throw error_already_set(); 237 | 238 | if 
(step == 1 && false) { 239 | v.erase(v.begin() + (DiffType) start, v.begin() + DiffType(start + slicelength)); 240 | } else { 241 | for (size_t i = 0; i < slicelength; ++i) { 242 | v.erase(v.begin() + DiffType(start)); 243 | start += step - 1; 244 | } 245 | } 246 | }, 247 | "Delete list elements using a slice object" 248 | ); 249 | 250 | } 251 | 252 | // If the type has an operator[] that doesn't return a reference (most notably std::vector), 253 | // we have to access by copying; otherwise we return by reference. 254 | template using vector_needs_copy = negation< 255 | std::is_same()[typename Vector::size_type()]), typename Vector::value_type &>>; 256 | 257 | // The usual case: access and iterate by reference 258 | template 259 | void vector_accessor(enable_if_t::value, Class_> &cl) { 260 | using T = typename Vector::value_type; 261 | using SizeType = typename Vector::size_type; 262 | using ItType = typename Vector::iterator; 263 | 264 | cl.def("__getitem__", 265 | [](Vector &v, SizeType i) -> T & { 266 | if (i >= v.size()) 267 | throw index_error(); 268 | return v[i]; 269 | }, 270 | return_value_policy::reference_internal // ref + keepalive 271 | ); 272 | 273 | cl.def("__iter__", 274 | [](Vector &v) { 275 | return make_iterator< 276 | return_value_policy::reference_internal, ItType, ItType, T&>( 277 | v.begin(), v.end()); 278 | }, 279 | keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ 280 | ); 281 | } 282 | 283 | // The case for special objects, like std::vector, that have to be returned-by-copy: 284 | template 285 | void vector_accessor(enable_if_t::value, Class_> &cl) { 286 | using T = typename Vector::value_type; 287 | using SizeType = typename Vector::size_type; 288 | using ItType = typename Vector::iterator; 289 | cl.def("__getitem__", 290 | [](const Vector &v, SizeType i) -> T { 291 | if (i >= v.size()) 292 | throw index_error(); 293 | return v[i]; 294 | } 295 | ); 296 | 297 | cl.def("__iter__", 298 | [](Vector &v) { 299 | return 
make_iterator< 300 | return_value_policy::copy, ItType, ItType, T>( 301 | v.begin(), v.end()); 302 | }, 303 | keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ 304 | ); 305 | } 306 | 307 | template auto vector_if_insertion_operator(Class_ &cl, std::string const &name) 308 | -> decltype(std::declval() << std::declval(), void()) { 309 | using size_type = typename Vector::size_type; 310 | 311 | cl.def("__repr__", 312 | [name](Vector &v) { 313 | std::ostringstream s; 314 | s << name << '['; 315 | for (size_type i=0; i < v.size(); ++i) { 316 | s << v[i]; 317 | if (i != v.size() - 1) 318 | s << ", "; 319 | } 320 | s << ']'; 321 | return s.str(); 322 | }, 323 | "Return the canonical string representation of this list." 324 | ); 325 | } 326 | 327 | // Provide the buffer interface for vectors if we have data() and we have a format for it 328 | // GCC seems to have "void std::vector::data()" - doing SFINAE on the existence of data() is insufficient, we need to check it returns an appropriate pointer 329 | template 330 | struct vector_has_data_and_format : std::false_type {}; 331 | template 332 | struct vector_has_data_and_format::format(), std::declval().data()), typename Vector::value_type*>::value>> : std::true_type {}; 333 | 334 | // Add the buffer interface to a vector 335 | template 336 | enable_if_t...>::value> 337 | vector_buffer(Class_& cl) { 338 | using T = typename Vector::value_type; 339 | 340 | static_assert(vector_has_data_and_format::value, "There is not an appropriate format descriptor for this vector"); 341 | 342 | // numpy.h declares this for arbitrary types, but it may raise an exception and crash hard at runtime if PYBIND11_NUMPY_DTYPE hasn't been called, so check here 343 | format_descriptor::format(); 344 | 345 | cl.def_buffer([](Vector& v) -> buffer_info { 346 | return buffer_info(v.data(), static_cast(sizeof(T)), format_descriptor::format(), 1, {v.size()}, {sizeof(T)}); 347 | }); 348 | 349 | cl.def("__init__", [](Vector& vec, 
buffer buf) { 350 | auto info = buf.request(); 351 | if (info.ndim != 1 || info.strides[0] % static_cast(sizeof(T))) 352 | throw type_error("Only valid 1D buffers can be copied to a vector"); 353 | if (!detail::compare_buffer_info::compare(info) || (ssize_t) sizeof(T) != info.itemsize) 354 | throw type_error("Format mismatch (Python: " + info.format + " C++: " + format_descriptor::format() + ")"); 355 | new (&vec) Vector(); 356 | vec.reserve((size_t) info.shape[0]); 357 | T *p = static_cast(info.ptr); 358 | ssize_t step = info.strides[0] / static_cast(sizeof(T)); 359 | T *end = p + info.shape[0] * step; 360 | for (; p != end; p += step) 361 | vec.push_back(*p); 362 | }); 363 | 364 | return; 365 | } 366 | 367 | template 368 | enable_if_t...>::value> vector_buffer(Class_&) {} 369 | 370 | NAMESPACE_END(detail) 371 | 372 | // 373 | // std::vector 374 | // 375 | template , typename... Args> 376 | class_ bind_vector(module &m, std::string const &name, Args&&... args) { 377 | using Class_ = class_; 378 | 379 | Class_ cl(m, name.c_str(), std::forward(args)...); 380 | 381 | // Declare the buffer interface if a buffer_protocol() is passed in 382 | detail::vector_buffer(cl); 383 | 384 | cl.def(init<>()); 385 | 386 | // Register copy constructor (if possible) 387 | detail::vector_if_copy_constructible(cl); 388 | 389 | // Register comparison-related operators and functions (if possible) 390 | detail::vector_if_equal_operator(cl); 391 | 392 | // Register stream insertion operator (if possible) 393 | detail::vector_if_insertion_operator(cl, name); 394 | 395 | // Modifiers require copyable vector value type 396 | detail::vector_modifiers(cl); 397 | 398 | // Accessor and iterator; return by value if copyable, otherwise we return by ref + keep-alive 399 | detail::vector_accessor(cl); 400 | 401 | cl.def("__bool__", 402 | [](const Vector &v) -> bool { 403 | return !v.empty(); 404 | }, 405 | "Check whether the list is nonempty" 406 | ); 407 | 408 | cl.def("__len__", &Vector::size); 409 
| 410 | 411 | 412 | 413 | #if 0 414 | // C++ style functions deprecated, leaving it here as an example 415 | cl.def(init()); 416 | 417 | cl.def("resize", 418 | (void (Vector::*) (size_type count)) & Vector::resize, 419 | "changes the number of elements stored"); 420 | 421 | cl.def("erase", 422 | [](Vector &v, SizeType i) { 423 | if (i >= v.size()) 424 | throw index_error(); 425 | v.erase(v.begin() + i); 426 | }, "erases element at index ``i``"); 427 | 428 | cl.def("empty", &Vector::empty, "checks whether the container is empty"); 429 | cl.def("size", &Vector::size, "returns the number of elements"); 430 | cl.def("push_back", (void (Vector::*)(const T&)) &Vector::push_back, "adds an element to the end"); 431 | cl.def("pop_back", &Vector::pop_back, "removes the last element"); 432 | 433 | cl.def("max_size", &Vector::max_size, "returns the maximum possible number of elements"); 434 | cl.def("reserve", &Vector::reserve, "reserves storage"); 435 | cl.def("capacity", &Vector::capacity, "returns the number of elements that can be held in currently allocated storage"); 436 | cl.def("shrink_to_fit", &Vector::shrink_to_fit, "reduces memory usage by freeing unused memory"); 437 | 438 | cl.def("clear", &Vector::clear, "clears the contents"); 439 | cl.def("swap", &Vector::swap, "swaps the contents"); 440 | 441 | cl.def("front", [](Vector &v) { 442 | if (v.size()) return v.front(); 443 | else throw index_error(); 444 | }, "access the first element"); 445 | 446 | cl.def("back", [](Vector &v) { 447 | if (v.size()) return v.back(); 448 | else throw index_error(); 449 | }, "access the last element "); 450 | 451 | #endif 452 | 453 | return cl; 454 | } 455 | 456 | 457 | 458 | // 459 | // std::map, std::unordered_map 460 | // 461 | 462 | NAMESPACE_BEGIN(detail) 463 | 464 | /* Fallback functions */ 465 | template void map_if_insertion_operator(const Args &...) { } 466 | template void map_assignment(const Args &...) 
{ } 467 | 468 | // Map assignment when copy-assignable: just copy the value 469 | template 470 | void map_assignment(enable_if_t::value, Class_> &cl) { 471 | using KeyType = typename Map::key_type; 472 | using MappedType = typename Map::mapped_type; 473 | 474 | cl.def("__setitem__", 475 | [](Map &m, const KeyType &k, const MappedType &v) { 476 | auto it = m.find(k); 477 | if (it != m.end()) it->second = v; 478 | else m.emplace(k, v); 479 | } 480 | ); 481 | } 482 | 483 | // Not copy-assignable, but still copy-constructible: we can update the value by erasing and reinserting 484 | template 485 | void map_assignment(enable_if_t< 486 | !std::is_copy_assignable::value && 487 | is_copy_constructible::value, 488 | Class_> &cl) { 489 | using KeyType = typename Map::key_type; 490 | using MappedType = typename Map::mapped_type; 491 | 492 | cl.def("__setitem__", 493 | [](Map &m, const KeyType &k, const MappedType &v) { 494 | // We can't use m[k] = v; because value type might not be default constructable 495 | auto r = m.emplace(k, v); 496 | if (!r.second) { 497 | // value type is not copy assignable so the only way to insert it is to erase it first... 498 | m.erase(r.first); 499 | m.emplace(k, v); 500 | } 501 | } 502 | ); 503 | } 504 | 505 | 506 | template auto map_if_insertion_operator(Class_ &cl, std::string const &name) 507 | -> decltype(std::declval() << std::declval() << std::declval(), void()) { 508 | 509 | cl.def("__repr__", 510 | [name](Map &m) { 511 | std::ostringstream s; 512 | s << name << '{'; 513 | bool f = false; 514 | for (auto const &kv : m) { 515 | if (f) 516 | s << ", "; 517 | s << kv.first << ": " << kv.second; 518 | f = true; 519 | } 520 | s << '}'; 521 | return s.str(); 522 | }, 523 | "Return the canonical string representation of this map." 524 | ); 525 | } 526 | 527 | 528 | NAMESPACE_END(detail) 529 | 530 | template , typename... Args> 531 | class_ bind_map(module &m, const std::string &name, Args&&... 
args) { 532 | using KeyType = typename Map::key_type; 533 | using MappedType = typename Map::mapped_type; 534 | using Class_ = class_; 535 | 536 | Class_ cl(m, name.c_str(), std::forward(args)...); 537 | 538 | cl.def(init<>()); 539 | 540 | // Register stream insertion operator (if possible) 541 | detail::map_if_insertion_operator(cl, name); 542 | 543 | cl.def("__bool__", 544 | [](const Map &m) -> bool { return !m.empty(); }, 545 | "Check whether the map is nonempty" 546 | ); 547 | 548 | cl.def("__iter__", 549 | [](Map &m) { return make_key_iterator(m.begin(), m.end()); }, 550 | keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ 551 | ); 552 | 553 | cl.def("items", 554 | [](Map &m) { return make_iterator(m.begin(), m.end()); }, 555 | keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ 556 | ); 557 | 558 | cl.def("__getitem__", 559 | [](Map &m, const KeyType &k) -> MappedType & { 560 | auto it = m.find(k); 561 | if (it == m.end()) 562 | throw key_error(); 563 | return it->second; 564 | }, 565 | return_value_policy::reference_internal // ref + keepalive 566 | ); 567 | 568 | // Assignment provided only if the type is copyable 569 | detail::map_assignment(cl); 570 | 571 | cl.def("__delitem__", 572 | [](Map &m, const KeyType &k) { 573 | auto it = m.find(k); 574 | if (it == m.end()) 575 | throw key_error(); 576 | return m.erase(it); 577 | } 578 | ); 579 | 580 | cl.def("__len__", &Map::size); 581 | 582 | return cl; 583 | } 584 | 585 | NAMESPACE_END(pybind11) 586 | -------------------------------------------------------------------------------- /lanms/include/pybind11/class_support.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/class_support.h: Python C API implementation details for py::class_ 3 | 4 | Copyright (c) 2017 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "attr.h" 13 | 14 | NAMESPACE_BEGIN(pybind11) 15 | NAMESPACE_BEGIN(detail) 16 | 17 | inline PyTypeObject *type_incref(PyTypeObject *type) { 18 | Py_INCREF(type); 19 | return type; 20 | } 21 | 22 | #if !defined(PYPY_VERSION) 23 | 24 | /// `pybind11_static_property.__get__()`: Always pass the class instead of the instance. 25 | extern "C" inline PyObject *pybind11_static_get(PyObject *self, PyObject * /*ob*/, PyObject *cls) { 26 | return PyProperty_Type.tp_descr_get(self, cls, cls); 27 | } 28 | 29 | /// `pybind11_static_property.__set__()`: Just like the above `__get__()`. 30 | extern "C" inline int pybind11_static_set(PyObject *self, PyObject *obj, PyObject *value) { 31 | PyObject *cls = PyType_Check(obj) ? obj : (PyObject *) Py_TYPE(obj); 32 | return PyProperty_Type.tp_descr_set(self, cls, value); 33 | } 34 | 35 | /** A `static_property` is the same as a `property` but the `__get__()` and `__set__()` 36 | methods are modified to always use the object type instead of a concrete instance. 37 | Return value: New reference. 
*/ 38 | inline PyTypeObject *make_static_property_type() { 39 | constexpr auto *name = "pybind11_static_property"; 40 | auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); 41 | 42 | /* Danger zone: from now (and until PyType_Ready), make sure to 43 | issue no Python C API calls which could potentially invoke the 44 | garbage collector (the GC will call type_traverse(), which will in 45 | turn find the newly constructed type in an invalid state) */ 46 | auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0); 47 | if (!heap_type) 48 | pybind11_fail("make_static_property_type(): error allocating type!"); 49 | 50 | heap_type->ht_name = name_obj.inc_ref().ptr(); 51 | #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3 52 | heap_type->ht_qualname = name_obj.inc_ref().ptr(); 53 | #endif 54 | 55 | auto type = &heap_type->ht_type; 56 | type->tp_name = name; 57 | type->tp_base = type_incref(&PyProperty_Type); 58 | type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; 59 | type->tp_descr_get = pybind11_static_get; 60 | type->tp_descr_set = pybind11_static_set; 61 | 62 | if (PyType_Ready(type) < 0) 63 | pybind11_fail("make_static_property_type(): failure in PyType_Ready()!"); 64 | 65 | setattr((PyObject *) type, "__module__", str("pybind11_builtins")); 66 | 67 | return type; 68 | } 69 | 70 | #else // PYPY 71 | 72 | /** PyPy has some issues with the above C API, so we evaluate Python code instead. 73 | This function will only be called once so performance isn't really a concern. 74 | Return value: New reference. 
*/ 75 | inline PyTypeObject *make_static_property_type() { 76 | auto d = dict(); 77 | PyObject *result = PyRun_String(R"(\ 78 | class pybind11_static_property(property): 79 | def __get__(self, obj, cls): 80 | return property.__get__(self, cls, cls) 81 | 82 | def __set__(self, obj, value): 83 | cls = obj if isinstance(obj, type) else type(obj) 84 | property.__set__(self, cls, value) 85 | )", Py_file_input, d.ptr(), d.ptr() 86 | ); 87 | if (result == nullptr) 88 | throw error_already_set(); 89 | Py_DECREF(result); 90 | return (PyTypeObject *) d["pybind11_static_property"].cast().release().ptr(); 91 | } 92 | 93 | #endif // PYPY 94 | 95 | /** Types with static properties need to handle `Type.static_prop = x` in a specific way. 96 | By default, Python replaces the `static_property` itself, but for wrapped C++ types 97 | we need to call `static_property.__set__()` in order to propagate the new value to 98 | the underlying C++ data structure. */ 99 | extern "C" inline int pybind11_meta_setattro(PyObject* obj, PyObject* name, PyObject* value) { 100 | // Use `_PyType_Lookup()` instead of `PyObject_GetAttr()` in order to get the raw 101 | // descriptor (`property`) instead of calling `tp_descr_get` (`property.__get__()`). 102 | PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name); 103 | 104 | // The following assignment combinations are possible: 105 | // 1. `Type.static_prop = value` --> descr_set: `Type.static_prop.__set__(value)` 106 | // 2. `Type.static_prop = other_static_prop` --> setattro: replace existing `static_prop` 107 | // 3. `Type.regular_attribute = value` --> setattro: regular attribute assignment 108 | const auto static_prop = (PyObject *) get_internals().static_property_type; 109 | const auto call_descr_set = descr && PyObject_IsInstance(descr, static_prop) 110 | && !PyObject_IsInstance(value, static_prop); 111 | if (call_descr_set) { 112 | // Call `static_property.__set__()` instead of replacing the `static_property`. 
113 | #if !defined(PYPY_VERSION) 114 | return Py_TYPE(descr)->tp_descr_set(descr, obj, value); 115 | #else 116 | if (PyObject *result = PyObject_CallMethod(descr, "__set__", "OO", obj, value)) { 117 | Py_DECREF(result); 118 | return 0; 119 | } else { 120 | return -1; 121 | } 122 | #endif 123 | } else { 124 | // Replace existing attribute. 125 | return PyType_Type.tp_setattro(obj, name, value); 126 | } 127 | } 128 | 129 | #if PY_MAJOR_VERSION >= 3 130 | /** 131 | * Python 3's PyInstanceMethod_Type hides itself via its tp_descr_get, which prevents aliasing 132 | * methods via cls.attr("m2") = cls.attr("m1"): instead the tp_descr_get returns a plain function, 133 | * when called on a class, or a PyMethod, when called on an instance. Override that behaviour here 134 | * to do a special case bypass for PyInstanceMethod_Types. 135 | */ 136 | extern "C" inline PyObject *pybind11_meta_getattro(PyObject *obj, PyObject *name) { 137 | PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name); 138 | if (descr && PyInstanceMethod_Check(descr)) { 139 | Py_INCREF(descr); 140 | return descr; 141 | } 142 | else { 143 | return PyType_Type.tp_getattro(obj, name); 144 | } 145 | } 146 | #endif 147 | 148 | /** This metaclass is assigned by default to all pybind11 types and is required in order 149 | for static properties to function correctly. Users may override this using `py::metaclass`. 150 | Return value: New reference. 
*/ 151 | inline PyTypeObject* make_default_metaclass() { 152 | constexpr auto *name = "pybind11_type"; 153 | auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); 154 | 155 | /* Danger zone: from now (and until PyType_Ready), make sure to 156 | issue no Python C API calls which could potentially invoke the 157 | garbage collector (the GC will call type_traverse(), which will in 158 | turn find the newly constructed type in an invalid state) */ 159 | auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0); 160 | if (!heap_type) 161 | pybind11_fail("make_default_metaclass(): error allocating metaclass!"); 162 | 163 | heap_type->ht_name = name_obj.inc_ref().ptr(); 164 | #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3 165 | heap_type->ht_qualname = name_obj.inc_ref().ptr(); 166 | #endif 167 | 168 | auto type = &heap_type->ht_type; 169 | type->tp_name = name; 170 | type->tp_base = type_incref(&PyType_Type); 171 | type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; 172 | 173 | type->tp_setattro = pybind11_meta_setattro; 174 | #if PY_MAJOR_VERSION >= 3 175 | type->tp_getattro = pybind11_meta_getattro; 176 | #endif 177 | 178 | if (PyType_Ready(type) < 0) 179 | pybind11_fail("make_default_metaclass(): failure in PyType_Ready()!"); 180 | 181 | setattr((PyObject *) type, "__module__", str("pybind11_builtins")); 182 | 183 | return type; 184 | } 185 | 186 | /// For multiple inheritance types we need to recursively register/deregister base pointers for any 187 | /// base classes with pointers that are difference from the instance value pointer so that we can 188 | /// correctly recognize an offset base class pointer. This calls a function with any offset base ptrs. 
189 | inline void traverse_offset_bases(void *valueptr, const detail::type_info *tinfo, instance *self, 190 | bool (*f)(void * /*parentptr*/, instance * /*self*/)) { 191 | for (handle h : reinterpret_borrow(tinfo->type->tp_bases)) { 192 | if (auto parent_tinfo = get_type_info((PyTypeObject *) h.ptr())) { 193 | for (auto &c : parent_tinfo->implicit_casts) { 194 | if (c.first == tinfo->cpptype) { 195 | auto *parentptr = c.second(valueptr); 196 | if (parentptr != valueptr) 197 | f(parentptr, self); 198 | traverse_offset_bases(parentptr, parent_tinfo, self, f); 199 | break; 200 | } 201 | } 202 | } 203 | } 204 | } 205 | 206 | inline bool register_instance_impl(void *ptr, instance *self) { 207 | get_internals().registered_instances.emplace(ptr, self); 208 | return true; // unused, but gives the same signature as the deregister func 209 | } 210 | inline bool deregister_instance_impl(void *ptr, instance *self) { 211 | auto ®istered_instances = get_internals().registered_instances; 212 | auto range = registered_instances.equal_range(ptr); 213 | for (auto it = range.first; it != range.second; ++it) { 214 | if (Py_TYPE(self) == Py_TYPE(it->second)) { 215 | registered_instances.erase(it); 216 | return true; 217 | } 218 | } 219 | return false; 220 | } 221 | 222 | inline void register_instance(instance *self, void *valptr, const type_info *tinfo) { 223 | register_instance_impl(valptr, self); 224 | if (!tinfo->simple_ancestors) 225 | traverse_offset_bases(valptr, tinfo, self, register_instance_impl); 226 | } 227 | 228 | inline bool deregister_instance(instance *self, void *valptr, const type_info *tinfo) { 229 | bool ret = deregister_instance_impl(valptr, self); 230 | if (!tinfo->simple_ancestors) 231 | traverse_offset_bases(valptr, tinfo, self, deregister_instance_impl); 232 | return ret; 233 | } 234 | 235 | /// Instance creation function for all pybind11 types. 
It only allocates space for the C++ object 236 | /// (or multiple objects, for Python-side inheritance from multiple pybind11 types), but doesn't 237 | /// call the constructor -- an `__init__` function must do that (followed by an `init_instance` 238 | /// to set up the holder and register the instance). 239 | inline PyObject *make_new_instance(PyTypeObject *type, bool allocate_value /*= true (in cast.h)*/) { 240 | #if defined(PYPY_VERSION) 241 | // PyPy gets tp_basicsize wrong (issue 2482) under multiple inheritance when the first inherited 242 | // object is a a plain Python type (i.e. not derived from an extension type). Fix it. 243 | ssize_t instance_size = static_cast(sizeof(instance)); 244 | if (type->tp_basicsize < instance_size) { 245 | type->tp_basicsize = instance_size; 246 | } 247 | #endif 248 | PyObject *self = type->tp_alloc(type, 0); 249 | auto inst = reinterpret_cast(self); 250 | // Allocate the value/holder internals: 251 | inst->allocate_layout(); 252 | 253 | inst->owned = true; 254 | // Allocate (if requested) the value pointers; otherwise leave them as nullptr 255 | if (allocate_value) { 256 | for (auto &v_h : values_and_holders(inst)) { 257 | void *&vptr = v_h.value_ptr(); 258 | vptr = v_h.type->operator_new(v_h.type->type_size); 259 | } 260 | } 261 | 262 | return self; 263 | } 264 | 265 | /// Instance creation function for all pybind11 types. It only allocates space for the 266 | /// C++ object, but doesn't call the constructor -- an `__init__` function must do that. 267 | extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) { 268 | return make_new_instance(type); 269 | } 270 | 271 | /// An `__init__` function constructs the C++ object. Users should provide at least one 272 | /// of these using `py::init` or directly with `.def(__init__, ...)`. Otherwise, the 273 | /// following default function will be used which simply throws an exception. 
274 | extern "C" inline int pybind11_object_init(PyObject *self, PyObject *, PyObject *) { 275 | PyTypeObject *type = Py_TYPE(self); 276 | std::string msg; 277 | #if defined(PYPY_VERSION) 278 | msg += handle((PyObject *) type).attr("__module__").cast() + "."; 279 | #endif 280 | msg += type->tp_name; 281 | msg += ": No constructor defined!"; 282 | PyErr_SetString(PyExc_TypeError, msg.c_str()); 283 | return -1; 284 | } 285 | 286 | inline void add_patient(PyObject *nurse, PyObject *patient) { 287 | auto &internals = get_internals(); 288 | auto instance = reinterpret_cast(nurse); 289 | instance->has_patients = true; 290 | Py_INCREF(patient); 291 | internals.patients[nurse].push_back(patient); 292 | } 293 | 294 | inline void clear_patients(PyObject *self) { 295 | auto instance = reinterpret_cast(self); 296 | auto &internals = get_internals(); 297 | auto pos = internals.patients.find(self); 298 | assert(pos != internals.patients.end()); 299 | // Clearing the patients can cause more Python code to run, which 300 | // can invalidate the iterator. Extract the vector of patients 301 | // from the unordered_map first. 302 | auto patients = std::move(pos->second); 303 | internals.patients.erase(pos); 304 | instance->has_patients = false; 305 | for (PyObject *&patient : patients) 306 | Py_CLEAR(patient); 307 | } 308 | 309 | /// Clears all internal data from the instance and removes it from registered instances in 310 | /// preparation for deallocation. 311 | inline void clear_instance(PyObject *self) { 312 | auto instance = reinterpret_cast(self); 313 | 314 | // Deallocate any values/holders, if present: 315 | for (auto &v_h : values_and_holders(instance)) { 316 | if (v_h) { 317 | 318 | // We have to deregister before we call dealloc because, for virtual MI types, we still 319 | // need to be able to get the parent pointers. 
320 | if (v_h.instance_registered() && !deregister_instance(instance, v_h.value_ptr(), v_h.type)) 321 | pybind11_fail("pybind11_object_dealloc(): Tried to deallocate unregistered instance!"); 322 | 323 | if (instance->owned || v_h.holder_constructed()) 324 | v_h.type->dealloc(v_h); 325 | } 326 | } 327 | // Deallocate the value/holder layout internals: 328 | instance->deallocate_layout(); 329 | 330 | if (instance->weakrefs) 331 | PyObject_ClearWeakRefs(self); 332 | 333 | PyObject **dict_ptr = _PyObject_GetDictPtr(self); 334 | if (dict_ptr) 335 | Py_CLEAR(*dict_ptr); 336 | 337 | if (instance->has_patients) 338 | clear_patients(self); 339 | } 340 | 341 | /// Instance destructor function for all pybind11 types. It calls `type_info.dealloc` 342 | /// to destroy the C++ object itself, while the rest is Python bookkeeping. 343 | extern "C" inline void pybind11_object_dealloc(PyObject *self) { 344 | clear_instance(self); 345 | Py_TYPE(self)->tp_free(self); 346 | } 347 | 348 | /** Create the type which can be used as a common base for all classes. This is 349 | needed in order to satisfy Python's requirements for multiple inheritance. 350 | Return value: New reference. 
*/ 351 | inline PyObject *make_object_base_type(PyTypeObject *metaclass) { 352 | constexpr auto *name = "pybind11_object"; 353 | auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); 354 | 355 | /* Danger zone: from now (and until PyType_Ready), make sure to 356 | issue no Python C API calls which could potentially invoke the 357 | garbage collector (the GC will call type_traverse(), which will in 358 | turn find the newly constructed type in an invalid state) */ 359 | auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0); 360 | if (!heap_type) 361 | pybind11_fail("make_object_base_type(): error allocating type!"); 362 | 363 | heap_type->ht_name = name_obj.inc_ref().ptr(); 364 | #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3 365 | heap_type->ht_qualname = name_obj.inc_ref().ptr(); 366 | #endif 367 | 368 | auto type = &heap_type->ht_type; 369 | type->tp_name = name; 370 | type->tp_base = type_incref(&PyBaseObject_Type); 371 | type->tp_basicsize = static_cast(sizeof(instance)); 372 | type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; 373 | 374 | type->tp_new = pybind11_object_new; 375 | type->tp_init = pybind11_object_init; 376 | type->tp_dealloc = pybind11_object_dealloc; 377 | 378 | /* Support weak references (needed for the keep_alive feature) */ 379 | type->tp_weaklistoffset = offsetof(instance, weakrefs); 380 | 381 | if (PyType_Ready(type) < 0) 382 | pybind11_fail("PyType_Ready failed in make_object_base_type():" + error_string()); 383 | 384 | setattr((PyObject *) type, "__module__", str("pybind11_builtins")); 385 | 386 | assert(!PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)); 387 | return (PyObject *) heap_type; 388 | } 389 | 390 | /// dynamic_attr: Support for `d = instance.__dict__`. 
391 | extern "C" inline PyObject *pybind11_get_dict(PyObject *self, void *) { 392 | PyObject *&dict = *_PyObject_GetDictPtr(self); 393 | if (!dict) 394 | dict = PyDict_New(); 395 | Py_XINCREF(dict); 396 | return dict; 397 | } 398 | 399 | /// dynamic_attr: Support for `instance.__dict__ = dict()`. 400 | extern "C" inline int pybind11_set_dict(PyObject *self, PyObject *new_dict, void *) { 401 | if (!PyDict_Check(new_dict)) { 402 | PyErr_Format(PyExc_TypeError, "__dict__ must be set to a dictionary, not a '%.200s'", 403 | Py_TYPE(new_dict)->tp_name); 404 | return -1; 405 | } 406 | PyObject *&dict = *_PyObject_GetDictPtr(self); 407 | Py_INCREF(new_dict); 408 | Py_CLEAR(dict); 409 | dict = new_dict; 410 | return 0; 411 | } 412 | 413 | /// dynamic_attr: Allow the garbage collector to traverse the internal instance `__dict__`. 414 | extern "C" inline int pybind11_traverse(PyObject *self, visitproc visit, void *arg) { 415 | PyObject *&dict = *_PyObject_GetDictPtr(self); 416 | Py_VISIT(dict); 417 | return 0; 418 | } 419 | 420 | /// dynamic_attr: Allow the GC to clear the dictionary. 421 | extern "C" inline int pybind11_clear(PyObject *self) { 422 | PyObject *&dict = *_PyObject_GetDictPtr(self); 423 | Py_CLEAR(dict); 424 | return 0; 425 | } 426 | 427 | /// Give instances of this type a `__dict__` and opt into garbage collection. 
428 | inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type) { 429 | auto type = &heap_type->ht_type; 430 | #if defined(PYPY_VERSION) 431 | pybind11_fail(std::string(type->tp_name) + ": dynamic attributes are " 432 | "currently not supported in " 433 | "conjunction with PyPy!"); 434 | #endif 435 | type->tp_flags |= Py_TPFLAGS_HAVE_GC; 436 | type->tp_dictoffset = type->tp_basicsize; // place dict at the end 437 | type->tp_basicsize += (ssize_t)sizeof(PyObject *); // and allocate enough space for it 438 | type->tp_traverse = pybind11_traverse; 439 | type->tp_clear = pybind11_clear; 440 | 441 | static PyGetSetDef getset[] = { 442 | {const_cast("__dict__"), pybind11_get_dict, pybind11_set_dict, nullptr, nullptr}, 443 | {nullptr, nullptr, nullptr, nullptr, nullptr} 444 | }; 445 | type->tp_getset = getset; 446 | } 447 | 448 | /// buffer_protocol: Fill in the view as specified by flags. 449 | extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int flags) { 450 | // Look for a `get_buffer` implementation in this type's info or any bases (following MRO). 
451 | type_info *tinfo = nullptr; 452 | for (auto type : reinterpret_borrow(Py_TYPE(obj)->tp_mro)) { 453 | tinfo = get_type_info((PyTypeObject *) type.ptr()); 454 | if (tinfo && tinfo->get_buffer) 455 | break; 456 | } 457 | if (view == nullptr || obj == nullptr || !tinfo || !tinfo->get_buffer) { 458 | if (view) 459 | view->obj = nullptr; 460 | PyErr_SetString(PyExc_BufferError, "pybind11_getbuffer(): Internal error"); 461 | return -1; 462 | } 463 | std::memset(view, 0, sizeof(Py_buffer)); 464 | buffer_info *info = tinfo->get_buffer(obj, tinfo->get_buffer_data); 465 | view->obj = obj; 466 | view->ndim = 1; 467 | view->internal = info; 468 | view->buf = info->ptr; 469 | view->itemsize = info->itemsize; 470 | view->len = view->itemsize; 471 | for (auto s : info->shape) 472 | view->len *= s; 473 | if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) 474 | view->format = const_cast(info->format.c_str()); 475 | if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { 476 | view->ndim = (int) info->ndim; 477 | view->strides = &info->strides[0]; 478 | view->shape = &info->shape[0]; 479 | } 480 | Py_INCREF(view->obj); 481 | return 0; 482 | } 483 | 484 | /// buffer_protocol: Release the resources of the buffer. 485 | extern "C" inline void pybind11_releasebuffer(PyObject *, Py_buffer *view) { 486 | delete (buffer_info *) view->internal; 487 | } 488 | 489 | /// Give this type a buffer interface. 490 | inline void enable_buffer_protocol(PyHeapTypeObject *heap_type) { 491 | heap_type->ht_type.tp_as_buffer = &heap_type->as_buffer; 492 | #if PY_MAJOR_VERSION < 3 493 | heap_type->ht_type.tp_flags |= Py_TPFLAGS_HAVE_NEWBUFFER; 494 | #endif 495 | 496 | heap_type->as_buffer.bf_getbuffer = pybind11_getbuffer; 497 | heap_type->as_buffer.bf_releasebuffer = pybind11_releasebuffer; 498 | } 499 | 500 | /** Create a brand new Python type according to the `type_record` specification. 501 | Return value: New reference. 
*/ 502 | inline PyObject* make_new_python_type(const type_record &rec) { 503 | auto name = reinterpret_steal(PYBIND11_FROM_STRING(rec.name)); 504 | 505 | #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3 506 | auto ht_qualname = name; 507 | if (rec.scope && hasattr(rec.scope, "__qualname__")) { 508 | ht_qualname = reinterpret_steal( 509 | PyUnicode_FromFormat("%U.%U", rec.scope.attr("__qualname__").ptr(), name.ptr())); 510 | } 511 | #endif 512 | 513 | object module; 514 | if (rec.scope) { 515 | if (hasattr(rec.scope, "__module__")) 516 | module = rec.scope.attr("__module__"); 517 | else if (hasattr(rec.scope, "__name__")) 518 | module = rec.scope.attr("__name__"); 519 | } 520 | 521 | #if !defined(PYPY_VERSION) 522 | const auto full_name = module ? str(module).cast() + "." + rec.name 523 | : std::string(rec.name); 524 | #else 525 | const auto full_name = std::string(rec.name); 526 | #endif 527 | 528 | char *tp_doc = nullptr; 529 | if (rec.doc && options::show_user_defined_docstrings()) { 530 | /* Allocate memory for docstring (using PyObject_MALLOC, since 531 | Python will free this later on) */ 532 | size_t size = strlen(rec.doc) + 1; 533 | tp_doc = (char *) PyObject_MALLOC(size); 534 | memcpy((void *) tp_doc, rec.doc, size); 535 | } 536 | 537 | auto &internals = get_internals(); 538 | auto bases = tuple(rec.bases); 539 | auto base = (bases.size() == 0) ? internals.instance_base 540 | : bases[0].ptr(); 541 | 542 | /* Danger zone: from now (and until PyType_Ready), make sure to 543 | issue no Python C API calls which could potentially invoke the 544 | garbage collector (the GC will call type_traverse(), which will in 545 | turn find the newly constructed type in an invalid state) */ 546 | auto metaclass = rec.metaclass.ptr() ? 
(PyTypeObject *) rec.metaclass.ptr() 547 | : internals.default_metaclass; 548 | 549 | auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0); 550 | if (!heap_type) 551 | pybind11_fail(std::string(rec.name) + ": Unable to create type object!"); 552 | 553 | heap_type->ht_name = name.release().ptr(); 554 | #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 3 555 | heap_type->ht_qualname = ht_qualname.release().ptr(); 556 | #endif 557 | 558 | auto type = &heap_type->ht_type; 559 | type->tp_name = strdup(full_name.c_str()); 560 | type->tp_doc = tp_doc; 561 | type->tp_base = type_incref((PyTypeObject *)base); 562 | type->tp_basicsize = static_cast(sizeof(instance)); 563 | if (bases.size() > 0) 564 | type->tp_bases = bases.release().ptr(); 565 | 566 | /* Don't inherit base __init__ */ 567 | type->tp_init = pybind11_object_init; 568 | 569 | /* Supported protocols */ 570 | type->tp_as_number = &heap_type->as_number; 571 | type->tp_as_sequence = &heap_type->as_sequence; 572 | type->tp_as_mapping = &heap_type->as_mapping; 573 | 574 | /* Flags */ 575 | type->tp_flags |= Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; 576 | #if PY_MAJOR_VERSION < 3 577 | type->tp_flags |= Py_TPFLAGS_CHECKTYPES; 578 | #endif 579 | 580 | if (rec.dynamic_attr) 581 | enable_dynamic_attributes(heap_type); 582 | 583 | if (rec.buffer_protocol) 584 | enable_buffer_protocol(heap_type); 585 | 586 | if (PyType_Ready(type) < 0) 587 | pybind11_fail(std::string(rec.name) + ": PyType_Ready failed (" + error_string() + ")!"); 588 | 589 | assert(rec.dynamic_attr ? 
PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) 590 | : !PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)); 591 | 592 | /* Register type with the parent scope */ 593 | if (rec.scope) 594 | setattr(rec.scope, rec.name, (PyObject *) type); 595 | 596 | if (module) // Needed by pydoc 597 | setattr((PyObject *) type, "__module__", module); 598 | 599 | return (PyObject *) type; 600 | } 601 | 602 | NAMESPACE_END(detail) 603 | NAMESPACE_END(pybind11) 604 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import torchvision.transforms as transforms 5 | from torch.utils import data 6 | import glob as gb 7 | import numpy as np 8 | import cv2 9 | import csv 10 | import matplotlib.pyplot as plt 11 | import matplotlib.patches as Patches 12 | from shapely.geometry import Polygon 13 | 14 | 15 | 16 | def get_images(root): 17 | ''' 18 | get images's path and name 19 | ''' 20 | files = [] 21 | for ext in ['jpg', 'png', 'jpeg', 'JPG']: 22 | files.extend(gb.glob(os.path.join(root, '*.{}'.format(ext)))) 23 | name = [] 24 | for i in range(len(files)): 25 | name.append(files[i].split('/')[-1]) 26 | return files, name 27 | 28 | 29 | def load_annoataion(p): 30 | ''' 31 | load annotation from the text file 32 | :param p: 33 | :return: 34 | ''' 35 | text_polys = [] 36 | text_tags = [] 37 | if not os.path.exists(p): 38 | return np.array(text_polys, dtype=np.float32) 39 | with open(p, 'r') as f: 40 | reader = csv.reader(f) 41 | for line in reader: 42 | label = line[-1] 43 | # strip BOM. 
\ufeff for python3, \xef\xbb\bf for python2 44 | line = [i.strip('\ufeff').strip('\xef\xbb\xbf') for i in line] 45 | 46 | x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, line[:8])) 47 | text_polys.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]]) 48 | if label == '*' or label == '###': 49 | text_tags.append(True) 50 | else: 51 | text_tags.append(False) 52 | return np.array(text_polys, dtype=np.float32), np.array(text_tags, dtype=np.bool) 53 | 54 | 55 | def polygon_area(poly): 56 | ''' 57 | compute area of a polygon 58 | :param poly: 59 | :return: 60 | ''' 61 | edge = [ 62 | (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]), 63 | (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]), 64 | (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]), 65 | (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1]) 66 | ] 67 | return np.sum(edge)/2. 68 | 69 | 70 | def check_and_validate_polys(polys, tags, xxx_todo_changeme): 71 | ''' 72 | check so that the text poly is in the same direction, 73 | and also filter some invalid polygons 74 | :param polys: 75 | :param tags: 76 | :return: 77 | ''' 78 | (h, w) = xxx_todo_changeme 79 | if polys.shape[0] == 0: 80 | return polys 81 | polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w-1) 82 | polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h-1) 83 | 84 | validated_polys = [] 85 | validated_tags = [] 86 | for poly, tag in zip(polys, tags): 87 | p_area = polygon_area(poly) 88 | if abs(p_area) < 1: 89 | # print poly 90 | print('invalid poly') 91 | continue 92 | if p_area > 0: 93 | print('poly in wrong direction') 94 | poly = poly[(0, 3, 2, 1), :] 95 | validated_polys.append(poly) 96 | validated_tags.append(tag) 97 | return np.array(validated_polys), np.array(validated_tags) 98 | 99 | 100 | def crop_area(im, polys, tags, crop_background=False, max_tries=50): 101 | ''' 102 | make random crop from the input image 103 | :param im: 104 | :param polys: 105 | :param tags: 106 | :param crop_background: 107 | :param max_tries: 108 | :return: 109 | ''' 110 | 
h, w, _ = im.shape 111 | pad_h = h//10 112 | pad_w = w//10 113 | h_array = np.zeros((h + pad_h*2), dtype=np.int32) 114 | w_array = np.zeros((w + pad_w*2), dtype=np.int32) 115 | for poly in polys: 116 | poly = np.round(poly, decimals=0).astype(np.int32) 117 | minx = np.min(poly[:, 0]) 118 | maxx = np.max(poly[:, 0]) 119 | w_array[minx+pad_w:maxx+pad_w] = 1 120 | miny = np.min(poly[:, 1]) 121 | maxy = np.max(poly[:, 1]) 122 | h_array[miny+pad_h:maxy+pad_h] = 1 123 | # ensure the cropped area not across a text 124 | h_axis = np.where(h_array == 0)[0] 125 | w_axis = np.where(w_array == 0)[0] 126 | if len(h_axis) == 0 or len(w_axis) == 0: 127 | return im, polys, tags 128 | for i in range(max_tries): 129 | xx = np.random.choice(w_axis, size=2) 130 | xmin = np.min(xx) - pad_w 131 | xmax = np.max(xx) - pad_w 132 | xmin = np.clip(xmin, 0, w-1) 133 | xmax = np.clip(xmax, 0, w-1) 134 | yy = np.random.choice(h_axis, size=2) 135 | ymin = np.min(yy) - pad_h 136 | ymax = np.max(yy) - pad_h 137 | ymin = np.clip(ymin, 0, h-1) 138 | ymax = np.clip(ymax, 0, h-1) 139 | # if xmax - xmin < FLAGS.min_crop_side_ratio*w or ymax - ymin < FLAGS.min_crop_side_ratio*h: 140 | if xmax - xmin < 0.1*w or ymax - ymin < 0.1*h: 141 | # area too small 142 | continue 143 | if polys.shape[0] != 0: 144 | poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \ 145 | & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax) 146 | selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0] 147 | else: 148 | selected_polys = [] 149 | if len(selected_polys) == 0: 150 | # no text in this area 151 | if crop_background: 152 | return im[ymin:ymax+1, xmin:xmax+1, :], polys[selected_polys], tags[selected_polys] 153 | else: 154 | continue 155 | im = im[ymin:ymax+1, xmin:xmax+1, :] 156 | polys = polys[selected_polys] 157 | tags = tags[selected_polys] 158 | polys[:, :, 0] -= xmin 159 | polys[:, :, 1] -= ymin 160 | return im, polys, tags 161 | 162 | return im, polys, tags 163 | 164 | 165 | def 
shrink_poly(poly, r): 166 | ''' 167 | fit a poly inside the origin poly, maybe bugs here... 168 | used for generate the score map 169 | :param poly: the text poly 170 | :param r: r in the paper 171 | :return: the shrinked poly 172 | ''' 173 | # shrink ratio 174 | R = 0.3 175 | # find the longer pair 176 | if np.linalg.norm(poly[0] - poly[1]) + np.linalg.norm(poly[2] - poly[3]) > \ 177 | np.linalg.norm(poly[0] - poly[3]) + np.linalg.norm(poly[1] - poly[2]): 178 | # first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2) 179 | ## p0, p1 180 | theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0])) 181 | poly[0][0] += R * r[0] * np.cos(theta) 182 | poly[0][1] += R * r[0] * np.sin(theta) 183 | poly[1][0] -= R * r[1] * np.cos(theta) 184 | poly[1][1] -= R * r[1] * np.sin(theta) 185 | ## p2, p3 186 | theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0])) 187 | poly[3][0] += R * r[3] * np.cos(theta) 188 | poly[3][1] += R * r[3] * np.sin(theta) 189 | poly[2][0] -= R * r[2] * np.cos(theta) 190 | poly[2][1] -= R * r[2] * np.sin(theta) 191 | ## p0, p3 192 | theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1])) 193 | poly[0][0] += R * r[0] * np.sin(theta) 194 | poly[0][1] += R * r[0] * np.cos(theta) 195 | poly[3][0] -= R * r[3] * np.sin(theta) 196 | poly[3][1] -= R * r[3] * np.cos(theta) 197 | ## p1, p2 198 | theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1])) 199 | poly[1][0] += R * r[1] * np.sin(theta) 200 | poly[1][1] += R * r[1] * np.cos(theta) 201 | poly[2][0] -= R * r[2] * np.sin(theta) 202 | poly[2][1] -= R * r[2] * np.cos(theta) 203 | else: 204 | ## p0, p3 205 | # print poly 206 | theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1])) 207 | poly[0][0] += R * r[0] * np.sin(theta) 208 | poly[0][1] += R * r[0] * np.cos(theta) 209 | poly[3][0] -= R * r[3] * np.sin(theta) 210 | poly[3][1] -= R * r[3] * np.cos(theta) 211 | ## p1, p2 212 | theta = np.arctan2((poly[2][0] - poly[1][0]), 
(poly[2][1] - poly[1][1])) 213 | poly[1][0] += R * r[1] * np.sin(theta) 214 | poly[1][1] += R * r[1] * np.cos(theta) 215 | poly[2][0] -= R * r[2] * np.sin(theta) 216 | poly[2][1] -= R * r[2] * np.cos(theta) 217 | ## p0, p1 218 | theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0])) 219 | poly[0][0] += R * r[0] * np.cos(theta) 220 | poly[0][1] += R * r[0] * np.sin(theta) 221 | poly[1][0] -= R * r[1] * np.cos(theta) 222 | poly[1][1] -= R * r[1] * np.sin(theta) 223 | ## p2, p3 224 | theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0])) 225 | poly[3][0] += R * r[3] * np.cos(theta) 226 | poly[3][1] += R * r[3] * np.sin(theta) 227 | poly[2][0] -= R * r[2] * np.cos(theta) 228 | poly[2][1] -= R * r[2] * np.sin(theta) 229 | return poly 230 | 231 | 232 | def point_dist_to_line(p1, p2, p3): 233 | # compute the distance from p3 to p1-p2 234 | return np.linalg.norm(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1) 235 | 236 | 237 | def fit_line(p1, p2): 238 | # fit a line ax+by+c = 0 239 | if p1[0] == p1[1]: 240 | return [1., 0., -p1[0]] 241 | else: 242 | [k, b] = np.polyfit(p1, p2, deg=1) 243 | return [k, -1., b] 244 | 245 | 246 | def line_cross_point(line1, line2): 247 | # line1 0= ax+by+c, compute the cross point of line1 and line2 248 | if line1[0] != 0 and line1[0] == line2[0]: 249 | print('Cross point does not exist') 250 | return None 251 | if line1[0] == 0 and line2[0] == 0: 252 | print('Cross point does not exist') 253 | return None 254 | if line1[1] == 0: 255 | x = -line1[2] 256 | y = line2[0] * x + line2[2] 257 | elif line2[1] == 0: 258 | x = -line2[2] 259 | y = line1[0] * x + line1[2] 260 | else: 261 | k1, _, b1 = line1 262 | k2, _, b2 = line2 263 | x = -(b1-b2)/(k1-k2) 264 | y = k1*x + b1 265 | return np.array([x, y], dtype=np.float32) 266 | 267 | 268 | def line_verticle(line, point): 269 | # get the verticle line from line across point 270 | if line[1] == 0: 271 | verticle = [0, -1, point[1]] 272 | else: 273 | if line[0] 
== 0: 274 | verticle = [1, 0, -point[0]] 275 | else: 276 | verticle = [-1./line[0], -1, point[1] - (-1/line[0] * point[0])] 277 | return verticle 278 | 279 | 280 | def rectangle_from_parallelogram(poly): 281 | ''' 282 | fit a rectangle from a parallelogram 283 | :param poly: 284 | :return: 285 | ''' 286 | p0, p1, p2, p3 = poly 287 | angle_p0 = np.arccos(np.dot(p1-p0, p3-p0)/(np.linalg.norm(p0-p1) * np.linalg.norm(p3-p0))) 288 | if angle_p0 < 0.5 * np.pi: 289 | if np.linalg.norm(p0 - p1) > np.linalg.norm(p0-p3): 290 | # p0 and p2 291 | ## p0 292 | p2p3 = fit_line([p2[0], p3[0]], [p2[1], p3[1]]) 293 | p2p3_verticle = line_verticle(p2p3, p0) 294 | 295 | new_p3 = line_cross_point(p2p3, p2p3_verticle) 296 | ## p2 297 | p0p1 = fit_line([p0[0], p1[0]], [p0[1], p1[1]]) 298 | p0p1_verticle = line_verticle(p0p1, p2) 299 | 300 | new_p1 = line_cross_point(p0p1, p0p1_verticle) 301 | return np.array([p0, new_p1, p2, new_p3], dtype=np.float32) 302 | else: 303 | p1p2 = fit_line([p1[0], p2[0]], [p1[1], p2[1]]) 304 | p1p2_verticle = line_verticle(p1p2, p0) 305 | 306 | new_p1 = line_cross_point(p1p2, p1p2_verticle) 307 | p0p3 = fit_line([p0[0], p3[0]], [p0[1], p3[1]]) 308 | p0p3_verticle = line_verticle(p0p3, p2) 309 | 310 | new_p3 = line_cross_point(p0p3, p0p3_verticle) 311 | return np.array([p0, new_p1, p2, new_p3], dtype=np.float32) 312 | else: 313 | if np.linalg.norm(p0-p1) > np.linalg.norm(p0-p3): 314 | # p1 and p3 315 | ## p1 316 | p2p3 = fit_line([p2[0], p3[0]], [p2[1], p3[1]]) 317 | p2p3_verticle = line_verticle(p2p3, p1) 318 | 319 | new_p2 = line_cross_point(p2p3, p2p3_verticle) 320 | ## p3 321 | p0p1 = fit_line([p0[0], p1[0]], [p0[1], p1[1]]) 322 | p0p1_verticle = line_verticle(p0p1, p3) 323 | 324 | new_p0 = line_cross_point(p0p1, p0p1_verticle) 325 | return np.array([new_p0, p1, new_p2, p3], dtype=np.float32) 326 | else: 327 | p0p3 = fit_line([p0[0], p3[0]], [p0[1], p3[1]]) 328 | p0p3_verticle = line_verticle(p0p3, p1) 329 | 330 | new_p0 = line_cross_point(p0p3, 
p0p3_verticle) 331 | p1p2 = fit_line([p1[0], p2[0]], [p1[1], p2[1]]) 332 | p1p2_verticle = line_verticle(p1p2, p3) 333 | 334 | new_p2 = line_cross_point(p1p2, p1p2_verticle) 335 | return np.array([new_p0, p1, new_p2, p3], dtype=np.float32) 336 | 337 | 338 | def sort_rectangle(poly): 339 | # sort the four coordinates of the polygon, points in poly should be sorted clockwise 340 | # First find the lowest point 341 | p_lowest = np.argmax(poly[:, 1]) 342 | if np.count_nonzero(poly[:, 1] == poly[p_lowest, 1]) == 2: 343 | # 底边平行于X轴, 那么p0为左上角 344 | p0_index = np.argmin(np.sum(poly, axis=1)) 345 | p1_index = (p0_index + 1) % 4 346 | p2_index = (p0_index + 2) % 4 347 | p3_index = (p0_index + 3) % 4 348 | return poly[[p0_index, p1_index, p2_index, p3_index]], 0. 349 | else: 350 | # 找到最低点右边的点 351 | p_lowest_right = (p_lowest - 1) % 4 352 | p_lowest_left = (p_lowest + 1) % 4 353 | angle = np.arctan(-(poly[p_lowest][1] - poly[p_lowest_right][1])/(poly[p_lowest][0] - poly[p_lowest_right][0])) 354 | # assert angle > 0 355 | if angle <= 0: 356 | print(angle, poly[p_lowest], poly[p_lowest_right]) 357 | if angle/np.pi * 180 > 45: 358 | # 这个点为p2 359 | p2_index = p_lowest 360 | p1_index = (p2_index - 1) % 4 361 | p0_index = (p2_index - 2) % 4 362 | p3_index = (p2_index + 1) % 4 363 | return poly[[p0_index, p1_index, p2_index, p3_index]], -(np.pi/2 - angle) 364 | else: 365 | # 这个点为p3 366 | p3_index = p_lowest 367 | p0_index = (p3_index + 1) % 4 368 | p1_index = (p3_index + 2) % 4 369 | p2_index = (p3_index + 3) % 4 370 | return poly[[p0_index, p1_index, p2_index, p3_index]], angle 371 | 372 | 373 | def restore_rectangle_rbox(origin, geometry): 374 | d = geometry[:, :4] 375 | angle = geometry[:, 4] 376 | # for angle > 0 377 | origin_0 = origin[angle >= 0] 378 | d_0 = d[angle >= 0] 379 | angle_0 = angle[angle >= 0] 380 | if origin_0.shape[0] > 0: 381 | p = np.array([np.zeros(d_0.shape[0]), -d_0[:, 0] - d_0[:, 2], 382 | d_0[:, 1] + d_0[:, 3], -d_0[:, 0] - d_0[:, 2], 383 | d_0[:, 1] + 
d_0[:, 3], np.zeros(d_0.shape[0]), 384 | np.zeros(d_0.shape[0]), np.zeros(d_0.shape[0]), 385 | d_0[:, 3], -d_0[:, 2]]) 386 | p = p.transpose((1, 0)).reshape((-1, 5, 2)) # N*5*2 387 | 388 | rotate_matrix_x = np.array([np.cos(angle_0), np.sin(angle_0)]).transpose((1, 0)) 389 | rotate_matrix_x = np.repeat(rotate_matrix_x, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1)) # N*5*2 390 | 391 | rotate_matrix_y = np.array([-np.sin(angle_0), np.cos(angle_0)]).transpose((1, 0)) 392 | rotate_matrix_y = np.repeat(rotate_matrix_y, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1)) 393 | 394 | p_rotate_x = np.sum(rotate_matrix_x * p, axis=2)[:, :, np.newaxis] # N*5*1 395 | p_rotate_y = np.sum(rotate_matrix_y * p, axis=2)[:, :, np.newaxis] # N*5*1 396 | 397 | p_rotate = np.concatenate([p_rotate_x, p_rotate_y], axis=2) # N*5*2 398 | 399 | p3_in_origin = origin_0 - p_rotate[:, 4, :] 400 | new_p0 = p_rotate[:, 0, :] + p3_in_origin # N*2 401 | new_p1 = p_rotate[:, 1, :] + p3_in_origin 402 | new_p2 = p_rotate[:, 2, :] + p3_in_origin 403 | new_p3 = p_rotate[:, 3, :] + p3_in_origin 404 | 405 | new_p_0 = np.concatenate([new_p0[:, np.newaxis, :], new_p1[:, np.newaxis, :], 406 | new_p2[:, np.newaxis, :], new_p3[:, np.newaxis, :]], axis=1) # N*4*2 407 | else: 408 | new_p_0 = np.zeros((0, 4, 2)) 409 | # for angle < 0 410 | origin_1 = origin[angle < 0] 411 | d_1 = d[angle < 0] 412 | angle_1 = angle[angle < 0] 413 | if origin_1.shape[0] > 0: 414 | p = np.array([-d_1[:, 1] - d_1[:, 3], -d_1[:, 0] - d_1[:, 2], 415 | np.zeros(d_1.shape[0]), -d_1[:, 0] - d_1[:, 2], 416 | np.zeros(d_1.shape[0]), np.zeros(d_1.shape[0]), 417 | -d_1[:, 1] - d_1[:, 3], np.zeros(d_1.shape[0]), 418 | -d_1[:, 1], -d_1[:, 2]]) 419 | p = p.transpose((1, 0)).reshape((-1, 5, 2)) # N*5*2 420 | 421 | rotate_matrix_x = np.array([np.cos(-angle_1), -np.sin(-angle_1)]).transpose((1, 0)) 422 | rotate_matrix_x = np.repeat(rotate_matrix_x, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1)) # N*5*2 423 | 424 | rotate_matrix_y = 
np.array([np.sin(-angle_1), np.cos(-angle_1)]).transpose((1, 0)) 425 | rotate_matrix_y = np.repeat(rotate_matrix_y, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1)) 426 | 427 | p_rotate_x = np.sum(rotate_matrix_x * p, axis=2)[:, :, np.newaxis] # N*5*1 428 | p_rotate_y = np.sum(rotate_matrix_y * p, axis=2)[:, :, np.newaxis] # N*5*1 429 | 430 | p_rotate = np.concatenate([p_rotate_x, p_rotate_y], axis=2) # N*5*2 431 | 432 | p3_in_origin = origin_1 - p_rotate[:, 4, :] 433 | new_p0 = p_rotate[:, 0, :] + p3_in_origin # N*2 434 | new_p1 = p_rotate[:, 1, :] + p3_in_origin 435 | new_p2 = p_rotate[:, 2, :] + p3_in_origin 436 | new_p3 = p_rotate[:, 3, :] + p3_in_origin 437 | 438 | new_p_1 = np.concatenate([new_p0[:, np.newaxis, :], new_p1[:, np.newaxis, :], 439 | new_p2[:, np.newaxis, :], new_p3[:, np.newaxis, :]], axis=1) # N*4*2 440 | else: 441 | new_p_1 = np.zeros((0, 4, 2)) 442 | return np.concatenate([new_p_0, new_p_1]) 443 | 444 | 445 | def restore_rectangle(origin, geometry): 446 | return restore_rectangle_rbox(origin, geometry) 447 | 448 | 449 | def generate_rbox(im_size, polys, tags): 450 | h, w = im_size 451 | poly_mask = np.zeros((h, w), dtype=np.uint8) 452 | score_map = np.zeros((h, w), dtype=np.uint8) 453 | geo_map = np.zeros((h, w, 5), dtype=np.float32) 454 | # mask used during traning, to ignore some hard areas 455 | training_mask = np.ones((h, w), dtype=np.uint8) 456 | for poly_idx, poly_tag in enumerate(zip(polys, tags)): 457 | poly = poly_tag[0] 458 | tag = poly_tag[1] 459 | 460 | r = [None, None, None, None] 461 | for i in range(4): 462 | r[i] = min(np.linalg.norm(poly[i] - poly[(i + 1) % 4]), 463 | np.linalg.norm(poly[i] - poly[(i - 1) % 4])) 464 | # score map 465 | shrinked_poly = shrink_poly(poly.copy(), r).astype(np.int32)[np.newaxis, :, :] 466 | cv2.fillPoly(score_map, shrinked_poly, 1) 467 | cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1) 468 | # if the poly is too small, then ignore it during training 469 | poly_h = min(np.linalg.norm(poly[0] - 
poly[3]), np.linalg.norm(poly[1] - poly[2])) 470 | poly_w = min(np.linalg.norm(poly[0] - poly[1]), np.linalg.norm(poly[2] - poly[3])) 471 | # if min(poly_h, poly_w) < FLAGS.min_text_size: 472 | if min(poly_h, poly_w) < 10: 473 | cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0) 474 | if tag: 475 | cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0) 476 | 477 | xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1)) 478 | # if geometry == 'RBOX': 479 | # 对任意两个顶点的组合生成一个平行四边形 480 | fitted_parallelograms = [] 481 | for i in range(4): 482 | p0 = poly[i] 483 | p1 = poly[(i + 1) % 4] 484 | p2 = poly[(i + 2) % 4] 485 | p3 = poly[(i + 3) % 4] 486 | edge = fit_line([p0[0], p1[0]], [p0[1], p1[1]]) 487 | backward_edge = fit_line([p0[0], p3[0]], [p0[1], p3[1]]) 488 | forward_edge = fit_line([p1[0], p2[0]], [p1[1], p2[1]]) 489 | if point_dist_to_line(p0, p1, p2) > point_dist_to_line(p0, p1, p3): 490 | # 平行线经过p2 491 | if edge[1] == 0: 492 | edge_opposite = [1, 0, -p2[0]] 493 | else: 494 | edge_opposite = [edge[0], -1, p2[1] - edge[0] * p2[0]] 495 | else: 496 | # 经过p3 497 | if edge[1] == 0: 498 | edge_opposite = [1, 0, -p3[0]] 499 | else: 500 | edge_opposite = [edge[0], -1, p3[1] - edge[0] * p3[0]] 501 | # move forward edge 502 | new_p0 = p0 503 | new_p1 = p1 504 | new_p2 = p2 505 | new_p3 = p3 506 | new_p2 = line_cross_point(forward_edge, edge_opposite) 507 | if point_dist_to_line(p1, new_p2, p0) > point_dist_to_line(p1, new_p2, p3): 508 | # across p0 509 | if forward_edge[1] == 0: 510 | forward_opposite = [1, 0, -p0[0]] 511 | else: 512 | forward_opposite = [forward_edge[0], -1, p0[1] - forward_edge[0] * p0[0]] 513 | else: 514 | # across p3 515 | if forward_edge[1] == 0: 516 | forward_opposite = [1, 0, -p3[0]] 517 | else: 518 | forward_opposite = [forward_edge[0], -1, p3[1] - forward_edge[0] * p3[0]] 519 | new_p0 = line_cross_point(forward_opposite, edge) 520 | new_p3 = line_cross_point(forward_opposite, edge_opposite) 521 | 
fitted_parallelograms.append([new_p0, new_p1, new_p2, new_p3, new_p0]) 522 | # or move backward edge 523 | new_p0 = p0 524 | new_p1 = p1 525 | new_p2 = p2 526 | new_p3 = p3 527 | new_p3 = line_cross_point(backward_edge, edge_opposite) 528 | if point_dist_to_line(p0, p3, p1) > point_dist_to_line(p0, p3, p2): 529 | # across p1 530 | if backward_edge[1] == 0: 531 | backward_opposite = [1, 0, -p1[0]] 532 | else: 533 | backward_opposite = [backward_edge[0], -1, p1[1] - backward_edge[0] * p1[0]] 534 | else: 535 | # across p2 536 | if backward_edge[1] == 0: 537 | backward_opposite = [1, 0, -p2[0]] 538 | else: 539 | backward_opposite = [backward_edge[0], -1, p2[1] - backward_edge[0] * p2[0]] 540 | new_p1 = line_cross_point(backward_opposite, edge) 541 | new_p2 = line_cross_point(backward_opposite, edge_opposite) 542 | fitted_parallelograms.append([new_p0, new_p1, new_p2, new_p3, new_p0]) 543 | areas = [Polygon(t).area for t in fitted_parallelograms] 544 | parallelogram = np.array(fitted_parallelograms[np.argmin(areas)][:-1], dtype=np.float32) 545 | # sort thie polygon 546 | parallelogram_coord_sum = np.sum(parallelogram, axis=1) 547 | min_coord_idx = np.argmin(parallelogram_coord_sum) 548 | parallelogram = parallelogram[ 549 | [min_coord_idx, (min_coord_idx + 1) % 4, (min_coord_idx + 2) % 4, (min_coord_idx + 3) % 4]] 550 | 551 | rectange = rectangle_from_parallelogram(parallelogram) 552 | rectange, rotate_angle = sort_rectangle(rectange) 553 | 554 | p0_rect, p1_rect, p2_rect, p3_rect = rectange 555 | for y, x in xy_in_poly: 556 | point = np.array([x, y], dtype=np.float32) 557 | # top 558 | geo_map[y, x, 0] = point_dist_to_line(p0_rect, p1_rect, point) 559 | # right 560 | geo_map[y, x, 1] = point_dist_to_line(p1_rect, p2_rect, point) 561 | # down 562 | geo_map[y, x, 2] = point_dist_to_line(p2_rect, p3_rect, point) 563 | # left 564 | geo_map[y, x, 3] = point_dist_to_line(p3_rect, p0_rect, point) 565 | # angle 566 | geo_map[y, x, 4] = rotate_angle 567 | return score_map, 
geo_map, training_mask 568 | 569 | def image_label(txt_root, image_list, img_name, index, 570 | input_size = 512, random_scale = np.array([0.5, 1, 2.0, 3.0]), 571 | background_ratio = 3./8): 572 | ''' 573 | get image's corresponding matrix and ground truth 574 | ''' 575 | 576 | try: 577 | im_fn = image_list[index] 578 | im_name = img_name[index] 579 | im = cv2.imread(im_fn) 580 | # print im_fn 581 | h, w, _ = im.shape 582 | # txt_fn = 'gt_' + im_name.replace(im_name.split('.')[1], 'txt') 583 | txt_fn = im_name.replace(im_name.split('.')[1], 'txt') 584 | txt_fn = os.path.join(txt_root, txt_fn) 585 | # if not os.path.exists(txt_fn): 586 | # pass 587 | 588 | text_polys, text_tags = load_annoataion(txt_fn) 589 | text_polys, text_tags = check_and_validate_polys(text_polys, text_tags, (h, w)) 590 | # if text_polys.shape[0] == 0: 591 | # continue 592 | # random scale this image 593 | rd_scale = np.random.choice(random_scale) 594 | im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) 595 | text_polys *= rd_scale 596 | # print rd_scale 597 | # random crop a area from image 598 | if np.random.rand() < background_ratio: 599 | # crop background 600 | im, text_polys, text_tags = crop_area(im, text_polys, text_tags, crop_background=True) 601 | # if text_polys.shape[0] > 0: 602 | # # cannot find background 603 | # pass 604 | # pad and resize image 605 | new_h, new_w, _ = im.shape 606 | max_h_w_i = np.max([new_h, new_w, input_size]) 607 | im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8) 608 | im_padded[:new_h, :new_w, :] = im.copy() 609 | im = cv2.resize(im_padded, dsize=(input_size, input_size)) 610 | score_map = np.zeros((input_size, input_size), dtype=np.uint8) 611 | geo_map_channels = 5 612 | # geo_map_channels = 5 if FLAGS.geometry == 'RBOX' else 8 613 | geo_map = np.zeros((input_size, input_size, geo_map_channels), dtype=np.float32) 614 | training_mask = np.ones((input_size, input_size), dtype=np.uint8) 615 | else: 616 | im, text_polys, text_tags = 
crop_area(im, text_polys, text_tags, crop_background=False) 617 | # if text_polys.shape[0] == 0: 618 | # pass 619 | h, w, _ = im.shape 620 | 621 | # pad the image to the training input size or the longer side of image 622 | new_h, new_w, _ = im.shape 623 | max_h_w_i = np.max([new_h, new_w, input_size]) 624 | im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8) 625 | im_padded[:new_h, :new_w, :] = im.copy() 626 | im = im_padded 627 | # resize the image to input size 628 | new_h, new_w, _ = im.shape 629 | resize_h = input_size 630 | resize_w = input_size 631 | im = cv2.resize(im, dsize=(resize_w, resize_h)) 632 | resize_ratio_3_x = resize_w/float(new_w) 633 | resize_ratio_3_y = resize_h/float(new_h) 634 | text_polys[:, :, 0] *= resize_ratio_3_x 635 | text_polys[:, :, 1] *= resize_ratio_3_y 636 | new_h, new_w, _ = im.shape 637 | score_map, geo_map, training_mask = generate_rbox((new_h, new_w), text_polys, text_tags) 638 | 639 | images = im[:, :, ::-1].astype(np.float32) 640 | score_maps = score_map[::4, ::4, np.newaxis].astype(np.float32) 641 | geo_maps = geo_map[::4, ::4, :].astype(np.float32) 642 | training_masks = training_mask[::4, ::4, np.newaxis].astype(np.float32) 643 | 644 | except Exception as e: 645 | images, score_maps, geo_maps, training_masks = None, None, None, None 646 | 647 | return images, score_maps, geo_maps, training_masks 648 | 649 | class custom_dset(data.Dataset): 650 | def __init__(self, img_root, txt_root): 651 | self.image_list, self.img_name = get_images(img_root) 652 | self.txt_root = txt_root 653 | def __getitem__(self, index): 654 | img, score_map, geo_map, training_mask = image_label(self.txt_root, 655 | self.image_list, self.img_name, index, input_size = 512, 656 | random_scale = np.array([0.5, 1, 2.0, 3.0]), background_ratio = 3./8) 657 | 658 | 659 | return img, score_map, geo_map, training_mask 660 | 661 | def __len__(self): 662 | return len(self.image_list) 663 | 664 | def collate_fn(batch): 665 | img, score_map, geo_map, 
training_mask = zip(*batch) 666 | bs = len(score_map) 667 | images = [] 668 | score_maps = [] 669 | geo_maps = [] 670 | training_masks = [] 671 | for i in range(bs): 672 | if img[i] is not None: 673 | a = torch.from_numpy(img[i]) 674 | a = a.permute(2, 0, 1) 675 | images.append(a) 676 | b = torch.from_numpy(score_map[i]) 677 | b = b.permute(2, 0, 1) 678 | score_maps.append(b) 679 | c = torch.from_numpy(geo_map[i]) 680 | c = c.permute(2, 0, 1) 681 | geo_maps.append(c) 682 | d = torch.from_numpy(training_mask[i]) 683 | d = d.permute(2, 0, 1) 684 | training_masks.append(d) 685 | images = torch.stack(images, 0) 686 | score_maps = torch.stack(score_maps, 0) 687 | geo_maps = torch.stack(geo_maps, 0) 688 | training_masks = torch.stack(training_masks, 0) 689 | 690 | return images, score_maps, geo_maps, training_masks 691 | ## img = bs * 512 * 512 *3 692 | ## score_map = bs* 128 * 128 * 1 693 | ## geo_map = bs * 128 * 128 * 5 694 | ## training_mask = bs * 128 * 128 * 1 --------------------------------------------------------------------------------