├── EAST
├── convert_to_txt.py
├── data_util.py
├── deploy.sh
├── eval.py
├── icdar.py
├── lanms
│ ├── Makefile
│ ├── __init__.py
│ ├── __main__.py
│ ├── adaptor.cpp
│ ├── include
│ │ ├── clipper
│ │ │ ├── clipper.cpp
│ │ │ └── clipper.hpp
│ │ └── pybind11
│ │ │ ├── attr.h
│ │ │ ├── buffer_info.h
│ │ │ ├── cast.h
│ │ │ ├── chrono.h
│ │ │ ├── class_support.h
│ │ │ ├── common.h
│ │ │ ├── complex.h
│ │ │ ├── descr.h
│ │ │ ├── eigen.h
│ │ │ ├── embed.h
│ │ │ ├── eval.h
│ │ │ ├── functional.h
│ │ │ ├── numpy.h
│ │ │ ├── operators.h
│ │ │ ├── options.h
│ │ │ ├── pybind11.h
│ │ │ ├── pytypes.h
│ │ │ ├── stl.h
│ │ │ ├── stl_bind.h
│ │ │ └── typeid.h
│ └── lanms.h
├── locality_aware_nms.py
├── model.py
├── multigpu_train.py
├── nets
│ ├── resnet_utils.py
│ └── resnet_v1.py
├── output.txt
└── run_demo_server.py
├── LICENSE
├── README.md
└── ocr_densenet
├── code
├── ocr
│ ├── dataloader.py
│ ├── densenet.py
│ ├── main.py
│ ├── resnet.py
│ └── tools
│ │ ├── measures.py
│ │ ├── parse.py
│ │ ├── plot.py
│ │ ├── py_op.py
│ │ ├── segmentation.py
│ │ └── utils.py
└── preprocessing
│ ├── analysis_dataset.py
│ ├── map_word_to_index.py
│ └── show_black.py
├── files
├── alphabet_count_dict.json
├── alphabet_index_dict.json
├── black.json
├── image_hw_ratio_dict.json
├── train.csv
├── train_alphabet.json
└── ttf
│ └── simsun.ttf
├── make_test_data.py
└── makedata.py

/EAST/convert_to_txt.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | 
4 | res = {}
5 | 
6 | def get_annotations(path):
7 |     with open(path, "r") as f:
8 |         reader = csv.reader(f)
9 |         for item in reader:
10 |             if not item[0].endswith('jpg'):
11 |                 continue
12 |             if item[0] not in res:
13 |                 res[item[0]] = []
14 |             res[item[0]].append(item[1:])
15 |     return res
16 | 
17 | def write_txt(d, path):
18 |     for name, objects in d.items():
19 |         name = name.split('.')[0] + '.txt'
20 |         with open(os.path.join(path, name), 'w') as f:
21 |             for ob in objects:
22 |                 f.write(','.join(ob) + '\n')
23 | 
24 | 
25 | if __name__ == '__main__':
26 |     path = 'D:/data/chinese/train_lable.csv'
27 |     save_path = 'D:/data/chinese/trian_dataset/'
28 |     d = get_annotations(path)
29 |     write_txt(d, save_path)
30 | 
--------------------------------------------------------------------------------
/EAST/data_util.py:
--------------------------------------------------------------------------------
1 | '''
2 | This file is adapted from the Keras implementation of multi-threaded data processing;
3 | see https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py
4 | '''
5 | import time
6 | import numpy as np
7 | import threading
8 | import multiprocessing
9 | try:
10 |     import queue
11 | except ImportError:
12 |     import Queue as queue
13 | 
14 | 
15 | class GeneratorEnqueuer():
16 |     """Builds a queue out of a data generator.
17 | 
18 |     Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
19 | 
20 |     # Arguments
21 |         generator: a generator function which endlessly yields data
22 |         use_multiprocessing: use multiprocessing if True, otherwise threading
23 |         wait_time: time to sleep in-between calls to `put()`
24 |         random_seed: Initial seed for workers,
25 |             will be incremented by one for each worker.
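    # Example (editor's note)
        A minimal usage sketch, not part of the original file;
        `my_batch_generator()` stands in for any endless batch generator:

            enqueuer = GeneratorEnqueuer(my_batch_generator(),
                                         use_multiprocessing=True)
            enqueuer.start(workers=4, max_queue_size=24)
            try:
                for batch in enqueuer.get():
                    ...  # consume one batch per iteration
            finally:
                enqueuer.stop()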
26 | """ 27 | 28 | def __init__(self, generator, 29 | use_multiprocessing=False, 30 | wait_time=0.05, 31 | random_seed=None): 32 | self.wait_time = wait_time 33 | self._generator = generator 34 | self._use_multiprocessing = use_multiprocessing 35 | self._threads = [] 36 | self._stop_event = None 37 | self.queue = None 38 | self.random_seed = random_seed 39 | 40 | def start(self, workers=1, max_queue_size=10): 41 | """Kicks off threads which add data from the generator into the queue. 42 | 43 | # Arguments 44 | workers: number of worker threads 45 | max_queue_size: queue size 46 | (when full, threads could block on `put()`) 47 | """ 48 | 49 | def data_generator_task(): 50 | while not self._stop_event.is_set(): 51 | try: 52 | if self._use_multiprocessing or self.queue.qsize() < max_queue_size: 53 | generator_output = next(self._generator) 54 | self.queue.put(generator_output) 55 | else: 56 | time.sleep(self.wait_time) 57 | except Exception: 58 | self._stop_event.set() 59 | raise 60 | 61 | try: 62 | if self._use_multiprocessing: 63 | self.queue = multiprocessing.Queue(maxsize=max_queue_size) 64 | self._stop_event = multiprocessing.Event() 65 | else: 66 | self.queue = queue.Queue() 67 | self._stop_event = threading.Event() 68 | 69 | for _ in range(workers): 70 | if self._use_multiprocessing: 71 | # Reset random seed else all children processes 72 | # share the same seed 73 | np.random.seed(self.random_seed) 74 | thread = multiprocessing.Process(target=data_generator_task) 75 | thread.daemon = False 76 | if self.random_seed is not None: 77 | self.random_seed += 1 78 | else: 79 | thread = threading.Thread(target=data_generator_task) 80 | self._threads.append(thread) 81 | thread.start() 82 | except: 83 | self.stop() 84 | raise 85 | 86 | def is_running(self): 87 | return self._stop_event is not None and not self._stop_event.is_set() 88 | 89 | def stop(self, timeout=None): 90 | """Stops running threads and wait for them to exit, if necessary. 91 | 92 | Should be called by the same thread which called `start()`. 93 | 94 | # Arguments 95 | timeout: maximum time to wait on `thread.join()`. 96 | """ 97 | if self.is_running(): 98 | self._stop_event.set() 99 | 100 | for thread in self._threads: 101 | if thread.is_alive(): 102 | if self._use_multiprocessing: 103 | thread.terminate() 104 | else: 105 | thread.join(timeout) 106 | 107 | if self._use_multiprocessing: 108 | if self.queue is not None: 109 | self.queue.close() 110 | 111 | self._threads = [] 112 | self._stop_event = None 113 | self.queue = None 114 | 115 | def get(self): 116 | """Creates a generator to extract data from the queue. 117 | 118 | Skip the data if it is `None`. 
119 | 
120 |         # Returns
121 |             A generator
122 |         """
123 |         while self.is_running():
124 |             if not self.queue.empty():
125 |                 inputs = self.queue.get()
126 |                 if inputs is not None:
127 |                     yield inputs
128 |             else:
129 |                 time.sleep(self.wait_time)
--------------------------------------------------------------------------------
/EAST/deploy.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir -p server_log
3 | gunicorn -w 3 run_demo_server:app -b 0.0.0.0:8769 -t 120 \
4 |     --error-logfile server_log/error.log \
5 |     --access-logfile server_log/access.log
6 | 
--------------------------------------------------------------------------------
/EAST/eval.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import time
3 | import math
4 | import os
5 | import numpy as np
6 | import tensorflow as tf
7 | from moxing.framework import file
8 | import locality_aware_nms as nms_locality
9 | # import lanms
10 | import moxing.tensorflow as mox
11 | tf.app.flags.DEFINE_string('test_data_path', '/cache/test_dataset', '')
12 | tf.app.flags.DEFINE_string('test_data_path_obs', 's3://tcd-public/test_dataset', '')
13 | tf.app.flags.DEFINE_string('gpu_list', '0', '')
14 | tf.app.flags.DEFINE_string('checkpoint_path', '../ckpt_chinese', '')
15 | tf.app.flags.DEFINE_string('output', './output.txt', '')
16 | tf.app.flags.DEFINE_bool('no_write_images', False, 'do not write images')
17 | tf.app.flags.DEFINE_string('data_url', 'q', '')
18 | tf.app.flags.DEFINE_string('train_url', 'q', '')
19 | tf.app.flags.DEFINE_integer('num_gpus', 1, '')
20 | 
21 | import model
22 | from icdar import restore_rectangle
23 | 
24 | FLAGS = tf.app.flags.FLAGS
25 | mox.file.copy_parallel(FLAGS.test_data_path_obs, FLAGS.test_data_path)
26 | 
27 | def get_images():
28 |     '''
29 |     find image files in test data path
30 |     :return: list of files found
31 |     '''
32 |     files = []
33 |     exts = ['jpg', 'png', 'jpeg', 'JPG']
34 |     for parent, dirnames, filenames in os.walk(FLAGS.test_data_path):
35 |         for filename in filenames:
36 |             for ext in exts:
37 |                 if filename.endswith(ext):
38 |                     files.append(os.path.join(parent, filename))
39 |                     break
40 |     print('Found {} images'.format(len(files)))
41 |     return files
42 | 
43 | 
44 | def resize_image(im, max_side_len=2400):
45 |     '''
46 |     resize image to a size multiple of 32 which is required by the network
47 |     :param im: the input image
48 |     :param max_side_len: limit of max image size to avoid out of memory in gpu
49 |     :return: the resized image and the resize ratio
50 |     '''
51 |     h, w, _ = im.shape
52 | 
53 |     resize_w = w
54 |     resize_h = h
55 | 
56 |     # limit the max side
57 |     if max(resize_h, resize_w) > max_side_len:
58 |         ratio = float(max_side_len) / resize_h if resize_h > resize_w else float(max_side_len) / resize_w
59 |     else:
60 |         ratio = 1.
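    # Editor's note: the comments below are not in the original file.
    # The `(x // 32 - 1) * 32` expressions further down snap any side that is
    # not already divisible by 32 down past the floor to the next lower
    # multiple, e.g. resize_h = 1000 -> (1000 // 32 - 1) * 32 = 960 rather
    # than 992; the final size stays a multiple of 32 (required by the
    # network) and safely under `max_side_len`.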
61 |     resize_h = int(resize_h * ratio)
62 |     resize_w = int(resize_w * ratio)
63 | 
64 |     resize_h = resize_h if resize_h % 32 == 0 else (resize_h // 32 - 1) * 32
65 |     resize_w = resize_w if resize_w % 32 == 0 else (resize_w // 32 - 1) * 32
66 |     im = cv2.resize(im, (int(resize_w), int(resize_h)))
67 | 
68 |     ratio_h = resize_h / float(h)
69 |     ratio_w = resize_w / float(w)
70 | 
71 |     return im, (ratio_h, ratio_w)
72 | 
73 | 
74 | def detect(score_map, geo_map, timer, score_map_thresh=0.8, box_thresh=0.1, nms_thres=0.2):
75 |     '''
76 |     restore text boxes from score map and geo map
77 |     :param score_map:
78 |     :param geo_map:
79 |     :param timer:
80 |     :param score_map_thresh: threshold for score map
81 |     :param box_thresh: threshold for boxes
82 |     :param nms_thres: threshold for nms
83 |     :return:
84 |     '''
85 |     if len(score_map.shape) == 4:
86 |         score_map = score_map[0, :, :, 0]
87 |         geo_map = geo_map[0, :, :, ]
88 |     # filter the score map
89 |     xy_text = np.argwhere(score_map > score_map_thresh)
90 |     # sort the text boxes via the y axis
91 |     xy_text = xy_text[np.argsort(xy_text[:, 0])]
92 |     # restore
93 |     start = time.time()
94 |     text_box_restored = restore_rectangle(xy_text[:, ::-1]*4, geo_map[xy_text[:, 0], xy_text[:, 1], :])  # N*4*2
95 |     print('{} text boxes before nms'.format(text_box_restored.shape[0]))
96 |     boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
97 |     boxes[:, :8] = text_box_restored.reshape((-1, 8))
98 |     boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
99 |     timer['restore'] = time.time() - start
100 |     # nms part
101 |     start = time.time()
102 |     boxes = nms_locality.nms_locality(boxes.astype(np.float64), nms_thres)
103 |     # boxes = lanms.merge_quadrangle_n9(boxes.astype('float32'), nms_thres)
104 |     timer['nms'] = time.time() - start
105 | 
106 |     if boxes.shape[0] == 0:
107 |         return None, timer
108 | 
109 |     # here we filter some low score boxes by the average score map, this is different from the original paper
110 |     for i, box in enumerate(boxes):
111 |         mask = np.zeros_like(score_map, dtype=np.uint8)
112 |         cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1)
113 |         boxes[i, 8] = cv2.mean(score_map, mask)[0]
114 |     boxes = boxes[boxes[:, 8] > box_thresh]
115 | 
116 |     return boxes, timer
117 | 
118 | 
119 | def sort_poly(p):
120 |     min_axis = np.argmin(np.sum(p, axis=1))
121 |     p = p[[min_axis, (min_axis+1)%4, (min_axis+2)%4, (min_axis+3)%4]]
122 |     if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
123 |         return p
124 |     else:
125 |         return p[[0, 3, 2, 1]]
126 | 
127 | 
128 | def main(argv=None):
129 |     import os
130 |     os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
131 | 
132 | 
133 |     with tf.get_default_graph().as_default():
134 |         input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images')
135 |         global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
136 | 
137 |         f_score, f_geometry = model.model(input_images, is_training=False)
138 | 
139 |         variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
140 |         saver = tf.train.Saver(variable_averages.variables_to_restore())
141 | 
142 |         with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
143 |             ckpt_state = tf.train.get_checkpoint_state(FLAGS.checkpoint_path)
144 |             model_path = os.path.join(FLAGS.checkpoint_path, os.path.basename(ckpt_state.model_checkpoint_path))
145 |             print('Restore from {}'.format(model_path))
146 |             saver.restore(sess, model_path)
147 | 
148 |             im_fn_list = get_images()
149 |             for im_fn in im_fn_list:
150 |                 im = cv2.imread(im_fn)[:, :, ::-1]
151 |                 start_time = time.time()
152 |                 im_resized, (ratio_h, ratio_w) = resize_image(im)
153 | 
154 |                 timer = {'net': 0, 'restore': 0, 'nms': 0}
155 |                 start = time.time()
156 |                 score, geometry = sess.run([f_score, f_geometry], feed_dict={input_images: [im_resized]})
157 |                 timer['net'] = time.time() - start
158 | 
159 |                 boxes, timer = detect(score_map=score, geo_map=geometry, timer=timer)
160 |                 print('{} : net {:.0f}ms, restore {:.0f}ms, nms {:.0f}ms'.format(
161 |                     im_fn, timer['net']*1000, timer['restore']*1000, timer['nms']*1000))
162 | 
163 |                 if boxes is not None:
164 |                     boxes = boxes[:, :8].reshape((-1, 4, 2))
165 |                     boxes[:, :, 0] /= ratio_w
166 |                     boxes[:, :, 1] /= ratio_h
167 | 
168 |                 duration = time.time() - start_time
169 |                 print('[timing] {}'.format(duration))
170 | 
171 |                 # save to file
172 |                 output_path = FLAGS.output
173 |                 with file.File(output_path, 'a') as f:
174 |                     if boxes is not None:
175 | 
176 |                         for box in boxes:
177 |                             # to avoid submitting errors
178 |                             box = sort_poly(box.astype(np.int32))
179 |                             if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3]-box[0]) < 5:
180 |                                 continue
181 |                             f.write('{},{},{},{},{},{},{},{},{}\r\n'.format(
182 |                                 os.path.basename(im_fn).split('.')[0],
183 |                                 box[0, 0], box[0, 1], box[1, 0], box[1, 1], box[2, 0], box[2, 1], box[3, 0], box[3, 1],
184 |                             ))
185 |                             cv2.polylines(im[:, :, ::-1], [box.astype(np.int32).reshape((-1, 1, 2))], True, color=(255, 255, 0), thickness=1)
186 |                 # if not FLAGS.no_write_images:
187 |                 #     img_path = os.path.join(FLAGS.output_dir, os.path.basename(im_fn))
188 |                 #     cv2.imwrite(img_path, im[:, :, ::-1])
189 | 
190 | if __name__ == '__main__':
191 |     tf.app.run()
192 | 
--------------------------------------------------------------------------------
/EAST/lanms/Makefile:
--------------------------------------------------------------------------------
1 | CXXFLAGS = -I include -std=c++11 -O3 $(shell python3-config --cflags)
2 | LDFLAGS = $(shell python3-config --ldflags)
3 | 
4 | DEPS = lanms.h $(shell find include -xtype f)
5 | CXX_SOURCES = adaptor.cpp include/clipper/clipper.cpp
6 | 
7 | LIB_SO = adaptor.so
8 | 
9 | $(LIB_SO): $(CXX_SOURCES) $(DEPS)
10 | 	$(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC
11 | 
12 | clean:
13 | 	rm -rf $(LIB_SO)
14 | 
--------------------------------------------------------------------------------
/EAST/lanms/__init__.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import os
3 | import numpy as np
4 | 
5 | BASE_DIR = os.path.dirname(os.path.realpath(__file__))
6 | 
7 | if subprocess.call(['make', '-C', BASE_DIR]) != 0:  # a non-zero exit status means the build failed
8 |     raise RuntimeError('Cannot compile lanms: {}'.format(BASE_DIR))
9 | 
10 | 
11 | def merge_quadrangle_n9(polys, thres=0.3, precision=10000):
12 |     from .adaptor import merge_quadrangle_n9 as nms_impl
13 |     if len(polys) == 0:
14 |         return np.array([], dtype='float32')
15 |     p = polys.copy()
16 |     p[:,:8] *= precision
17 |     ret = np.array(nms_impl(p, thres), dtype='float32')
18 |     ret[:,:8] /= precision
19 |     return ret
20 | 
21 | 
--------------------------------------------------------------------------------
/EAST/lanms/__main__.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | from . import merge_quadrangle_n9
5 | 
6 | if __name__ == '__main__':
7 |     # unit square with confidence 1
8 |     q = np.array([0, 0, 0, 1, 1, 1, 1, 0, 1], dtype='float32')
9 | 
10 |     print(merge_quadrangle_n9(np.array([q, q + 0.1, q + 2])))
11 | 
--------------------------------------------------------------------------------
/EAST/lanms/adaptor.cpp:
--------------------------------------------------------------------------------
1 | #include "pybind11/pybind11.h"
2 | #include "pybind11/numpy.h"
3 | #include "pybind11/stl.h"
4 | #include "pybind11/stl_bind.h"
5 | 
6 | #include "lanms.h"
7 | 
8 | namespace py = pybind11;
9 | 
10 | 
11 | namespace lanms_adaptor {
12 | 
13 |     std::vector<std::vector<float>> polys2floats(const std::vector<lanms::Polygon> &polys) {
14 |         std::vector<std::vector<float>> ret;
15 |         for (size_t i = 0; i < polys.size(); i ++) {
16 |             auto &p = polys[i];
17 |             auto &poly = p.poly;
18 |             ret.emplace_back(std::vector<float>{
19 |                 float(poly[0].X), float(poly[0].Y),
20 |                 float(poly[1].X), float(poly[1].Y),
21 |                 float(poly[2].X), float(poly[2].Y),
22 |                 float(poly[3].X), float(poly[3].Y),
23 |                 float(p.score),
24 |             });
25 |         }
26 | 
27 |         return ret;
28 |     }
29 | 
30 | 
31 |     /**
32 |      *
33 |      * \param quad_n9 an n-by-9 numpy array, where the first 8 numbers denote the
34 |      *      quadrangle, and the last one is the score
35 |      * \param iou_threshold two quadrangles with an IoU score above this threshold
36 |      *      will be merged
37 |      *
38 |      * \return an n-by-9 numpy array, the merged quadrangles
39 |      */
40 |     std::vector<std::vector<float>> merge_quadrangle_n9(
41 |             py::array_t<float, py::array::c_style | py::array::forcecast> quad_n9,
42 |             float iou_threshold) {
43 |         auto pbuf = quad_n9.request();
44 |         if (pbuf.ndim != 2 || pbuf.shape[1] != 9)
45 |             throw std::runtime_error("quadrangles must have a shape of (n, 9)");
46 |         auto n = pbuf.shape[0];
47 |         auto ptr = static_cast<float *>(pbuf.ptr);
48 |         return polys2floats(lanms::merge_quadrangle_n9(ptr, n, iou_threshold));
49 |     }
50 | 
51 | }
52 | 
53 | PYBIND11_PLUGIN(adaptor) {
54 |     py::module m("adaptor", "NMS");
55 | 
56 |     m.def("merge_quadrangle_n9", &lanms_adaptor::merge_quadrangle_n9,
57 |             "merge quadrangles");
58 | 
59 |     return m.ptr();
60 | }
61 | 
62 | 
--------------------------------------------------------------------------------
/EAST/lanms/include/clipper/clipper.cpp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataFountainCode/huawei_code_share/f1ef76649ea5c87a7be2d93dfaec1ff9a4d3e4b5/EAST/lanms/include/clipper/clipper.cpp
--------------------------------------------------------------------------------
/EAST/lanms/include/pybind11/buffer_info.h:
--------------------------------------------------------------------------------
1 | /*
2 |     pybind11/buffer_info.h: Python buffer object interface
3 | 
4 |     Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
5 | 
6 |     All rights reserved. Use of this source code is governed by a
7 |     BSD-style license that can be found in the LICENSE file.
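    (Editor's note, not part of the original header: `buffer_info` is the
    structure pybind11 fills in from the Python buffer protocol; it is what
    adaptor.cpp above receives from `quad_n9.request()` when a NumPy array is
    passed to `merge_quadrangle_n9`.)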
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "common.h" 13 | 14 | NAMESPACE_BEGIN(pybind11) 15 | 16 | /// Information record describing a Python buffer object 17 | struct buffer_info { 18 | void *ptr = nullptr; // Pointer to the underlying storage 19 | ssize_t itemsize = 0; // Size of individual items in bytes 20 | ssize_t size = 0; // Total number of entries 21 | std::string format; // For homogeneous buffers, this should be set to format_descriptor::format() 22 | ssize_t ndim = 0; // Number of dimensions 23 | std::vector shape; // Shape of the tensor (1 entry per dimension) 24 | std::vector strides; // Number of entries between adjacent entries (for each per dimension) 25 | 26 | buffer_info() { } 27 | 28 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, 29 | detail::any_container shape_in, detail::any_container strides_in) 30 | : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim), 31 | shape(std::move(shape_in)), strides(std::move(strides_in)) { 32 | if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size()) 33 | pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length"); 34 | for (size_t i = 0; i < (size_t) ndim; ++i) 35 | size *= shape[i]; 36 | } 37 | 38 | template 39 | buffer_info(T *ptr, detail::any_container shape_in, detail::any_container strides_in) 40 | : buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor::format(), static_cast(shape_in->size()), std::move(shape_in), std::move(strides_in)) { } 41 | 42 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size) 43 | : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { } 44 | 45 | template 46 | buffer_info(T *ptr, ssize_t size) 47 | : buffer_info(ptr, sizeof(T), format_descriptor::format(), size) { } 48 | 49 | explicit buffer_info(Py_buffer *view, bool ownview = true) 50 | : buffer_info(view->buf, view->itemsize, view->format, view->ndim, 51 | {view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) { 52 | this->view = view; 53 | this->ownview = ownview; 54 | } 55 | 56 | buffer_info(const buffer_info &) = delete; 57 | buffer_info& operator=(const buffer_info &) = delete; 58 | 59 | buffer_info(buffer_info &&other) { 60 | (*this) = std::move(other); 61 | } 62 | 63 | buffer_info& operator=(buffer_info &&rhs) { 64 | ptr = rhs.ptr; 65 | itemsize = rhs.itemsize; 66 | size = rhs.size; 67 | format = std::move(rhs.format); 68 | ndim = rhs.ndim; 69 | shape = std::move(rhs.shape); 70 | strides = std::move(rhs.strides); 71 | std::swap(view, rhs.view); 72 | std::swap(ownview, rhs.ownview); 73 | return *this; 74 | } 75 | 76 | ~buffer_info() { 77 | if (view && ownview) { PyBuffer_Release(view); delete view; } 78 | } 79 | 80 | private: 81 | struct private_ctr_tag { }; 82 | 83 | buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, 84 | detail::any_container &&shape_in, detail::any_container &&strides_in) 85 | : buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { } 86 | 87 | Py_buffer *view = nullptr; 88 | bool ownview = false; 89 | }; 90 | 91 | NAMESPACE_BEGIN(detail) 92 | 93 | template struct compare_buffer_info { 94 | static bool compare(const buffer_info& b) { 95 | return b.format == format_descriptor::format() && b.itemsize == (ssize_t) sizeof(T); 96 | } 97 | }; 98 | 99 | template struct compare_buffer_info::value>> { 100 | static bool compare(const buffer_info& b) { 101 | return (size_t) b.itemsize == 
sizeof(T) && (b.format == format_descriptor::value || 102 | ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned::value ? "L" : "l")) || 103 | ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned::value ? "N" : "n"))); 104 | } 105 | }; 106 | 107 | NAMESPACE_END(detail) 108 | NAMESPACE_END(pybind11) 109 | -------------------------------------------------------------------------------- /EAST/lanms/include/pybind11/chrono.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/chrono.h: Transparent conversion between std::chrono and python's datetime 3 | 4 | Copyright (c) 2016 Trent Houliston and 5 | Wenzel Jakob 6 | 7 | All rights reserved. Use of this source code is governed by a 8 | BSD-style license that can be found in the LICENSE file. 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "pybind11.h" 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | // Backport the PyDateTime_DELTA functions from Python3.3 if required 20 | #ifndef PyDateTime_DELTA_GET_DAYS 21 | #define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days) 22 | #endif 23 | #ifndef PyDateTime_DELTA_GET_SECONDS 24 | #define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds) 25 | #endif 26 | #ifndef PyDateTime_DELTA_GET_MICROSECONDS 27 | #define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds) 28 | #endif 29 | 30 | NAMESPACE_BEGIN(pybind11) 31 | NAMESPACE_BEGIN(detail) 32 | 33 | template class duration_caster { 34 | public: 35 | typedef typename type::rep rep; 36 | typedef typename type::period period; 37 | 38 | typedef std::chrono::duration> days; 39 | 40 | bool load(handle src, bool) { 41 | using namespace std::chrono; 42 | 43 | // Lazy initialise the PyDateTime import 44 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 45 | 46 | if (!src) return false; 47 | // If invoked with datetime.delta object 48 | if (PyDelta_Check(src.ptr())) { 49 | value = type(duration_cast>( 50 | days(PyDateTime_DELTA_GET_DAYS(src.ptr())) 51 | + seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr())) 52 | + microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr())))); 53 | return true; 54 | } 55 | // If invoked with a float we assume it is seconds and convert 56 | else if (PyFloat_Check(src.ptr())) { 57 | value = type(duration_cast>(duration(PyFloat_AsDouble(src.ptr())))); 58 | return true; 59 | } 60 | else return false; 61 | } 62 | 63 | // If this is a duration just return it back 64 | static const std::chrono::duration& get_duration(const std::chrono::duration &src) { 65 | return src; 66 | } 67 | 68 | // If this is a time_point get the time_since_epoch 69 | template static std::chrono::duration get_duration(const std::chrono::time_point> &src) { 70 | return src.time_since_epoch(); 71 | } 72 | 73 | static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) { 74 | using namespace std::chrono; 75 | 76 | // Use overloaded function to get our duration from our source 77 | // Works out if it is a duration or time_point and get the duration 78 | auto d = get_duration(src); 79 | 80 | // Lazy initialise the PyDateTime import 81 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 82 | 83 | // Declare these special duration types so the conversions happen with the correct primitive types (int) 84 | using dd_t = duration>; 85 | using ss_t = duration>; 86 | using us_t = duration; 87 | 88 | auto dd = duration_cast(d); 89 | auto subd = d - dd; 90 | auto ss = duration_cast(subd); 91 | auto us = duration_cast(subd - ss); 
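// Editor's note (comments not in the original): at this point `d` has been
// split into whole days (dd), leftover whole seconds (ss) and leftover
// microseconds (us). For example, a duration of 90061.5 s becomes dd = 1
// (86400 s), ss = 3661 and us = 500000, which maps directly onto Python's
// datetime.timedelta(days, seconds, microseconds) below.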
92 | return PyDelta_FromDSU(dd.count(), ss.count(), us.count()); 93 | } 94 | 95 | PYBIND11_TYPE_CASTER(type, _("datetime.timedelta")); 96 | }; 97 | 98 | // This is for casting times on the system clock into datetime.datetime instances 99 | template class type_caster> { 100 | public: 101 | typedef std::chrono::time_point type; 102 | bool load(handle src, bool) { 103 | using namespace std::chrono; 104 | 105 | // Lazy initialise the PyDateTime import 106 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 107 | 108 | if (!src) return false; 109 | if (PyDateTime_Check(src.ptr())) { 110 | std::tm cal; 111 | cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr()); 112 | cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr()); 113 | cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr()); 114 | cal.tm_mday = PyDateTime_GET_DAY(src.ptr()); 115 | cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1; 116 | cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900; 117 | cal.tm_isdst = -1; 118 | 119 | value = system_clock::from_time_t(std::mktime(&cal)) + microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr())); 120 | return true; 121 | } 122 | else return false; 123 | } 124 | 125 | static handle cast(const std::chrono::time_point &src, return_value_policy /* policy */, handle /* parent */) { 126 | using namespace std::chrono; 127 | 128 | // Lazy initialise the PyDateTime import 129 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 130 | 131 | std::time_t tt = system_clock::to_time_t(src); 132 | // this function uses static memory so it's best to copy it out asap just in case 133 | // otherwise other code that is using localtime may break this (not just python code) 134 | std::tm localtime = *std::localtime(&tt); 135 | 136 | // Declare these special duration types so the conversions happen with the correct primitive types (int) 137 | using us_t = duration; 138 | 139 | return PyDateTime_FromDateAndTime(localtime.tm_year + 1900, 140 | localtime.tm_mon + 1, 141 | localtime.tm_mday, 142 | localtime.tm_hour, 143 | localtime.tm_min, 144 | localtime.tm_sec, 145 | (duration_cast(src.time_since_epoch() % seconds(1))).count()); 146 | } 147 | PYBIND11_TYPE_CASTER(type, _("datetime.datetime")); 148 | }; 149 | 150 | // Other clocks that are not the system clock are not measured as datetime.datetime objects 151 | // since they are not measured on calendar time. So instead we just make them timedeltas 152 | // Or if they have passed us a time as a float we convert that 153 | template class type_caster> 154 | : public duration_caster> { 155 | }; 156 | 157 | template class type_caster> 158 | : public duration_caster> { 159 | }; 160 | 161 | NAMESPACE_END(detail) 162 | NAMESPACE_END(pybind11) 163 | -------------------------------------------------------------------------------- /EAST/lanms/include/pybind11/complex.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/complex.h: Complex number support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
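    (Editor's note, not part of the original header: this file converts between
    std::complex<T> and Python complex numbers; e.g. std::complex<float>
    advertises the buffer-protocol format string "Zf", where 'Z' marks a
    complex pair of the base format character 'f'.)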
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | /// glibc defines I as a macro which breaks things, e.g., boost template names 16 | #ifdef I 17 | # undef I 18 | #endif 19 | 20 | NAMESPACE_BEGIN(pybind11) 21 | 22 | template struct format_descriptor, detail::enable_if_t::value>> { 23 | static constexpr const char c = format_descriptor::c; 24 | static constexpr const char value[3] = { 'Z', c, '\0' }; 25 | static std::string format() { return std::string(value); } 26 | }; 27 | 28 | template constexpr const char format_descriptor< 29 | std::complex, detail::enable_if_t::value>>::value[3]; 30 | 31 | NAMESPACE_BEGIN(detail) 32 | 33 | template struct is_fmt_numeric, detail::enable_if_t::value>> { 34 | static constexpr bool value = true; 35 | static constexpr int index = is_fmt_numeric::index + 3; 36 | }; 37 | 38 | template class type_caster> { 39 | public: 40 | bool load(handle src, bool convert) { 41 | if (!src) 42 | return false; 43 | if (!convert && !PyComplex_Check(src.ptr())) 44 | return false; 45 | Py_complex result = PyComplex_AsCComplex(src.ptr()); 46 | if (result.real == -1.0 && PyErr_Occurred()) { 47 | PyErr_Clear(); 48 | return false; 49 | } 50 | value = std::complex((T) result.real, (T) result.imag); 51 | return true; 52 | } 53 | 54 | static handle cast(const std::complex &src, return_value_policy /* policy */, handle /* parent */) { 55 | return PyComplex_FromDoubles((double) src.real(), (double) src.imag()); 56 | } 57 | 58 | PYBIND11_TYPE_CASTER(std::complex, _("complex")); 59 | }; 60 | NAMESPACE_END(detail) 61 | NAMESPACE_END(pybind11) 62 | -------------------------------------------------------------------------------- /EAST/lanms/include/pybind11/descr.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/descr.h: Helper type for concatenating type signatures 3 | either at runtime (C++11) or compile time (C++14) 4 | 5 | Copyright (c) 2016 Wenzel Jakob 6 | 7 | All rights reserved. Use of this source code is governed by a 8 | BSD-style license that can be found in the LICENSE file. 
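    (Editor's note, not part of the original header: a `descr` is a small
    compile-time or run-time string builder used to assemble the signature
    text shown in generated docstrings; e.g. stl.h concatenates _("Dict[") +
    key_conv::name() + _(", ") + value_conv::name() + _("]") to describe a
    map argument.)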
9 | */ 10 | 11 | #pragma once 12 | 13 | #include "common.h" 14 | 15 | NAMESPACE_BEGIN(pybind11) 16 | NAMESPACE_BEGIN(detail) 17 | 18 | /* Concatenate type signatures at compile time using C++14 */ 19 | #if defined(PYBIND11_CPP14) && !defined(_MSC_VER) 20 | #define PYBIND11_CONSTEXPR_DESCR 21 | 22 | template class descr { 23 | template friend class descr; 24 | public: 25 | constexpr descr(char const (&text) [Size1+1], const std::type_info * const (&types)[Size2+1]) 26 | : descr(text, types, 27 | make_index_sequence(), 28 | make_index_sequence()) { } 29 | 30 | constexpr const char *text() const { return m_text; } 31 | constexpr const std::type_info * const * types() const { return m_types; } 32 | 33 | template 34 | constexpr descr operator+(const descr &other) const { 35 | return concat(other, 36 | make_index_sequence(), 37 | make_index_sequence(), 38 | make_index_sequence(), 39 | make_index_sequence()); 40 | } 41 | 42 | protected: 43 | template 44 | constexpr descr( 45 | char const (&text) [Size1+1], 46 | const std::type_info * const (&types) [Size2+1], 47 | index_sequence, index_sequence) 48 | : m_text{text[Indices1]..., '\0'}, 49 | m_types{types[Indices2]..., nullptr } {} 50 | 51 | template 53 | constexpr descr 54 | concat(const descr &other, 55 | index_sequence, index_sequence, 56 | index_sequence, index_sequence) const { 57 | return descr( 58 | { m_text[Indices1]..., other.m_text[OtherIndices1]..., '\0' }, 59 | { m_types[Indices2]..., other.m_types[OtherIndices2]..., nullptr } 60 | ); 61 | } 62 | 63 | protected: 64 | char m_text[Size1 + 1]; 65 | const std::type_info * m_types[Size2 + 1]; 66 | }; 67 | 68 | template constexpr descr _(char const(&text)[Size]) { 69 | return descr(text, { nullptr }); 70 | } 71 | 72 | template struct int_to_str : int_to_str { }; 73 | template struct int_to_str<0, Digits...> { 74 | static constexpr auto digits = descr({ ('0' + Digits)..., '\0' }, { nullptr }); 75 | }; 76 | 77 | // Ternary description (like std::conditional) 78 | template 79 | constexpr enable_if_t> _(char const(&text1)[Size1], char const(&)[Size2]) { 80 | return _(text1); 81 | } 82 | template 83 | constexpr enable_if_t> _(char const(&)[Size1], char const(&text2)[Size2]) { 84 | return _(text2); 85 | } 86 | template 87 | constexpr enable_if_t> _(descr d, descr) { return d; } 88 | template 89 | constexpr enable_if_t> _(descr, descr d) { return d; } 90 | 91 | template auto constexpr _() -> decltype(int_to_str::digits) { 92 | return int_to_str::digits; 93 | } 94 | 95 | template constexpr descr<1, 1> _() { 96 | return descr<1, 1>({ '%', '\0' }, { &typeid(Type), nullptr }); 97 | } 98 | 99 | inline constexpr descr<0, 0> concat() { return _(""); } 100 | template auto constexpr concat(descr descr) { return descr; } 101 | template auto constexpr concat(descr descr, Args&&... 
args) { return descr + _(", ") + concat(args...); } 102 | template auto constexpr type_descr(descr descr) { return _("{") + descr + _("}"); } 103 | 104 | #define PYBIND11_DESCR constexpr auto 105 | 106 | #else /* Simpler C++11 implementation based on run-time memory allocation and copying */ 107 | 108 | class descr { 109 | public: 110 | PYBIND11_NOINLINE descr(const char *text, const std::type_info * const * types) { 111 | size_t nChars = len(text), nTypes = len(types); 112 | m_text = new char[nChars]; 113 | m_types = new const std::type_info *[nTypes]; 114 | memcpy(m_text, text, nChars * sizeof(char)); 115 | memcpy(m_types, types, nTypes * sizeof(const std::type_info *)); 116 | } 117 | 118 | PYBIND11_NOINLINE descr operator+(descr &&d2) && { 119 | descr r; 120 | 121 | size_t nChars1 = len(m_text), nTypes1 = len(m_types); 122 | size_t nChars2 = len(d2.m_text), nTypes2 = len(d2.m_types); 123 | 124 | r.m_text = new char[nChars1 + nChars2 - 1]; 125 | r.m_types = new const std::type_info *[nTypes1 + nTypes2 - 1]; 126 | memcpy(r.m_text, m_text, (nChars1-1) * sizeof(char)); 127 | memcpy(r.m_text + nChars1 - 1, d2.m_text, nChars2 * sizeof(char)); 128 | memcpy(r.m_types, m_types, (nTypes1-1) * sizeof(std::type_info *)); 129 | memcpy(r.m_types + nTypes1 - 1, d2.m_types, nTypes2 * sizeof(std::type_info *)); 130 | 131 | delete[] m_text; delete[] m_types; 132 | delete[] d2.m_text; delete[] d2.m_types; 133 | 134 | return r; 135 | } 136 | 137 | char *text() { return m_text; } 138 | const std::type_info * * types() { return m_types; } 139 | 140 | protected: 141 | PYBIND11_NOINLINE descr() { } 142 | 143 | template static size_t len(const T *ptr) { // return length including null termination 144 | const T *it = ptr; 145 | while (*it++ != (T) 0) 146 | ; 147 | return static_cast(it - ptr); 148 | } 149 | 150 | const std::type_info **m_types = nullptr; 151 | char *m_text = nullptr; 152 | }; 153 | 154 | /* The 'PYBIND11_NOINLINE inline' combinations below are intentional to get the desired linkage while producing as little object code as possible */ 155 | 156 | PYBIND11_NOINLINE inline descr _(const char *text) { 157 | const std::type_info *types[1] = { nullptr }; 158 | return descr(text, types); 159 | } 160 | 161 | template PYBIND11_NOINLINE enable_if_t _(const char *text1, const char *) { return _(text1); } 162 | template PYBIND11_NOINLINE enable_if_t _(char const *, const char *text2) { return _(text2); } 163 | template PYBIND11_NOINLINE enable_if_t _(descr d, descr) { return d; } 164 | template PYBIND11_NOINLINE enable_if_t _(descr, descr d) { return d; } 165 | 166 | template PYBIND11_NOINLINE descr _() { 167 | const std::type_info *types[2] = { &typeid(Type), nullptr }; 168 | return descr("%", types); 169 | } 170 | 171 | template PYBIND11_NOINLINE descr _() { 172 | const std::type_info *types[1] = { nullptr }; 173 | return descr(std::to_string(Size).c_str(), types); 174 | } 175 | 176 | PYBIND11_NOINLINE inline descr concat() { return _(""); } 177 | PYBIND11_NOINLINE inline descr concat(descr &&d) { return d; } 178 | template PYBIND11_NOINLINE descr concat(descr &&d, Args&&... 
args) { return std::move(d) + _(", ") + concat(std::forward(args)...); } 179 | PYBIND11_NOINLINE inline descr type_descr(descr&& d) { return _("{") + std::move(d) + _("}"); } 180 | 181 | #define PYBIND11_DESCR ::pybind11::detail::descr 182 | #endif 183 | 184 | NAMESPACE_END(detail) 185 | NAMESPACE_END(pybind11) 186 | -------------------------------------------------------------------------------- /EAST/lanms/include/pybind11/embed.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/embed.h: Support for embedding the interpreter 3 | 4 | Copyright (c) 2017 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include "eval.h" 14 | 15 | #if defined(PYPY_VERSION) 16 | # error Embedding the interpreter is not supported with PyPy 17 | #endif 18 | 19 | #if PY_MAJOR_VERSION >= 3 20 | # define PYBIND11_EMBEDDED_MODULE_IMPL(name) \ 21 | extern "C" PyObject *pybind11_init_impl_##name() { \ 22 | return pybind11_init_wrapper_##name(); \ 23 | } 24 | #else 25 | # define PYBIND11_EMBEDDED_MODULE_IMPL(name) \ 26 | extern "C" void pybind11_init_impl_##name() { \ 27 | pybind11_init_wrapper_##name(); \ 28 | } 29 | #endif 30 | 31 | /** \rst 32 | Add a new module to the table of builtins for the interpreter. Must be 33 | defined in global scope. The first macro parameter is the name of the 34 | module (without quotes). The second parameter is the variable which will 35 | be used as the interface to add functions and classes to the module. 36 | 37 | .. code-block:: cpp 38 | 39 | PYBIND11_EMBEDDED_MODULE(example, m) { 40 | // ... initialize functions and classes here 41 | m.def("foo", []() { 42 | return "Hello, World!"; 43 | }); 44 | } 45 | \endrst */ 46 | #define PYBIND11_EMBEDDED_MODULE(name, variable) \ 47 | static void pybind11_init_##name(pybind11::module &); \ 48 | static PyObject *pybind11_init_wrapper_##name() { \ 49 | auto m = pybind11::module(#name); \ 50 | try { \ 51 | pybind11_init_##name(m); \ 52 | return m.ptr(); \ 53 | } catch (pybind11::error_already_set &e) { \ 54 | PyErr_SetString(PyExc_ImportError, e.what()); \ 55 | return nullptr; \ 56 | } catch (const std::exception &e) { \ 57 | PyErr_SetString(PyExc_ImportError, e.what()); \ 58 | return nullptr; \ 59 | } \ 60 | } \ 61 | PYBIND11_EMBEDDED_MODULE_IMPL(name) \ 62 | pybind11::detail::embedded_module name(#name, pybind11_init_impl_##name); \ 63 | void pybind11_init_##name(pybind11::module &variable) 64 | 65 | 66 | NAMESPACE_BEGIN(pybind11) 67 | NAMESPACE_BEGIN(detail) 68 | 69 | /// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks. 70 | struct embedded_module { 71 | #if PY_MAJOR_VERSION >= 3 72 | using init_t = PyObject *(*)(); 73 | #else 74 | using init_t = void (*)(); 75 | #endif 76 | embedded_module(const char *name, init_t init) { 77 | if (Py_IsInitialized()) 78 | pybind11_fail("Can't add new modules after the interpreter has been initialized"); 79 | 80 | auto result = PyImport_AppendInittab(name, init); 81 | if (result == -1) 82 | pybind11_fail("Insufficient memory to add a new module"); 83 | } 84 | }; 85 | 86 | NAMESPACE_END(detail) 87 | 88 | /** \rst 89 | Initialize the Python interpreter. No other pybind11 or CPython API functions can be 90 | called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. 
The 91 | optional parameter can be used to skip the registration of signal handlers (see the 92 | Python documentation for details). Calling this function again after the interpreter 93 | has already been initialized is a fatal error. 94 | \endrst */ 95 | inline void initialize_interpreter(bool init_signal_handlers = true) { 96 | if (Py_IsInitialized()) 97 | pybind11_fail("The interpreter is already running"); 98 | 99 | Py_InitializeEx(init_signal_handlers ? 1 : 0); 100 | 101 | // Make .py files in the working directory available by default 102 | auto sys_path = reinterpret_borrow(module::import("sys").attr("path")); 103 | sys_path.append("."); 104 | } 105 | 106 | /** \rst 107 | Shut down the Python interpreter. No pybind11 or CPython API functions can be called 108 | after this. In addition, pybind11 objects must not outlive the interpreter: 109 | 110 | .. code-block:: cpp 111 | 112 | { // BAD 113 | py::initialize_interpreter(); 114 | auto hello = py::str("Hello, World!"); 115 | py::finalize_interpreter(); 116 | } // <-- BOOM, hello's destructor is called after interpreter shutdown 117 | 118 | { // GOOD 119 | py::initialize_interpreter(); 120 | { // scoped 121 | auto hello = py::str("Hello, World!"); 122 | } // <-- OK, hello is cleaned up properly 123 | py::finalize_interpreter(); 124 | } 125 | 126 | { // BETTER 127 | py::scoped_interpreter guard{}; 128 | auto hello = py::str("Hello, World!"); 129 | } 130 | 131 | .. warning:: 132 | 133 | The interpreter can be restarted by calling `initialize_interpreter` again. 134 | Modules created using pybind11 can be safely re-initialized. However, Python 135 | itself cannot completely unload binary extension modules and there are several 136 | caveats with regard to interpreter restarting. All the details can be found 137 | in the CPython documentation. In short, not all interpreter memory may be 138 | freed, either due to reference cycles or user-created global data. 139 | 140 | \endrst */ 141 | inline void finalize_interpreter() { 142 | handle builtins(PyEval_GetBuiltins()); 143 | const char *id = PYBIND11_INTERNALS_ID; 144 | 145 | // Get the internals pointer (without creating it if it doesn't exist). It's possible for the 146 | // internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()` 147 | // during destruction), so we get the pointer-pointer here and check it after Py_Finalize(). 148 | detail::internals **internals_ptr_ptr = &detail::get_internals_ptr(); 149 | // It could also be stashed in builtins, so look there too: 150 | if (builtins.contains(id) && isinstance(builtins[id])) 151 | internals_ptr_ptr = capsule(builtins[id]); 152 | 153 | Py_Finalize(); 154 | 155 | if (internals_ptr_ptr) { 156 | delete *internals_ptr_ptr; 157 | *internals_ptr_ptr = nullptr; 158 | } 159 | } 160 | 161 | /** \rst 162 | Scope guard version of `initialize_interpreter` and `finalize_interpreter`. 163 | This a move-only guard and only a single instance can exist. 164 | 165 | .. 
code-block:: cpp 166 | 167 | #include 168 | 169 | int main() { 170 | py::scoped_interpreter guard{}; 171 | py::print(Hello, World!); 172 | } // <-- interpreter shutdown 173 | \endrst */ 174 | class scoped_interpreter { 175 | public: 176 | scoped_interpreter(bool init_signal_handlers = true) { 177 | initialize_interpreter(init_signal_handlers); 178 | } 179 | 180 | scoped_interpreter(const scoped_interpreter &) = delete; 181 | scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; } 182 | scoped_interpreter &operator=(const scoped_interpreter &) = delete; 183 | scoped_interpreter &operator=(scoped_interpreter &&) = delete; 184 | 185 | ~scoped_interpreter() { 186 | if (is_valid) 187 | finalize_interpreter(); 188 | } 189 | 190 | private: 191 | bool is_valid = true; 192 | }; 193 | 194 | NAMESPACE_END(pybind11) 195 | -------------------------------------------------------------------------------- /EAST/lanms/include/pybind11/eval.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/exec.h: Support for evaluating Python expressions and statements 3 | from strings and files 4 | 5 | Copyright (c) 2016 Klemens Morgenstern and 6 | Wenzel Jakob 7 | 8 | All rights reserved. Use of this source code is governed by a 9 | BSD-style license that can be found in the LICENSE file. 10 | */ 11 | 12 | #pragma once 13 | 14 | #include "pybind11.h" 15 | 16 | NAMESPACE_BEGIN(pybind11) 17 | 18 | enum eval_mode { 19 | /// Evaluate a string containing an isolated expression 20 | eval_expr, 21 | 22 | /// Evaluate a string containing a single statement. Returns \c none 23 | eval_single_statement, 24 | 25 | /// Evaluate a string containing a sequence of statement. Returns \c none 26 | eval_statements 27 | }; 28 | 29 | template 30 | object eval(str expr, object global = globals(), object local = object()) { 31 | if (!local) 32 | local = global; 33 | 34 | /* PyRun_String does not accept a PyObject / encoding specifier, 35 | this seems to be the only alternative */ 36 | std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr; 37 | 38 | int start; 39 | switch (mode) { 40 | case eval_expr: start = Py_eval_input; break; 41 | case eval_single_statement: start = Py_single_input; break; 42 | case eval_statements: start = Py_file_input; break; 43 | default: pybind11_fail("invalid evaluation mode"); 44 | } 45 | 46 | PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr()); 47 | if (!result) 48 | throw error_already_set(); 49 | return reinterpret_steal(result); 50 | } 51 | 52 | template 53 | object eval(const char (&s)[N], object global = globals(), object local = object()) { 54 | /* Support raw string literals by removing common leading whitespace */ 55 | auto expr = (s[0] == '\n') ? 
str(module::import("textwrap").attr("dedent")(s)) 56 | : str(s); 57 | return eval(expr, global, local); 58 | } 59 | 60 | inline void exec(str expr, object global = globals(), object local = object()) { 61 | eval(expr, global, local); 62 | } 63 | 64 | template 65 | void exec(const char (&s)[N], object global = globals(), object local = object()) { 66 | eval(s, global, local); 67 | } 68 | 69 | template 70 | object eval_file(str fname, object global = globals(), object local = object()) { 71 | if (!local) 72 | local = global; 73 | 74 | int start; 75 | switch (mode) { 76 | case eval_expr: start = Py_eval_input; break; 77 | case eval_single_statement: start = Py_single_input; break; 78 | case eval_statements: start = Py_file_input; break; 79 | default: pybind11_fail("invalid evaluation mode"); 80 | } 81 | 82 | int closeFile = 1; 83 | std::string fname_str = (std::string) fname; 84 | #if PY_VERSION_HEX >= 0x03040000 85 | FILE *f = _Py_fopen_obj(fname.ptr(), "r"); 86 | #elif PY_VERSION_HEX >= 0x03000000 87 | FILE *f = _Py_fopen(fname.ptr(), "r"); 88 | #else 89 | /* No unicode support in open() :( */ 90 | auto fobj = reinterpret_steal(PyFile_FromString( 91 | const_cast(fname_str.c_str()), 92 | const_cast("r"))); 93 | FILE *f = nullptr; 94 | if (fobj) 95 | f = PyFile_AsFile(fobj.ptr()); 96 | closeFile = 0; 97 | #endif 98 | if (!f) { 99 | PyErr_Clear(); 100 | pybind11_fail("File \"" + fname_str + "\" could not be opened!"); 101 | } 102 | 103 | #if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION) 104 | PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(), 105 | local.ptr()); 106 | (void) closeFile; 107 | #else 108 | PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(), 109 | local.ptr(), closeFile); 110 | #endif 111 | 112 | if (!result) 113 | throw error_already_set(); 114 | return reinterpret_steal(result); 115 | } 116 | 117 | NAMESPACE_END(pybind11) 118 | -------------------------------------------------------------------------------- /EAST/lanms/include/pybind11/functional.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/functional.h: std::function<> support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | NAMESPACE_BEGIN(pybind11) 16 | NAMESPACE_BEGIN(detail) 17 | 18 | template 19 | struct type_caster> { 20 | using type = std::function; 21 | using retval_type = conditional_t::value, void_type, Return>; 22 | using function_type = Return (*) (Args...); 23 | 24 | public: 25 | bool load(handle src, bool convert) { 26 | if (src.is_none()) { 27 | // Defer accepting None to other overloads (if we aren't in convert mode): 28 | if (!convert) return false; 29 | return true; 30 | } 31 | 32 | if (!isinstance(src)) 33 | return false; 34 | 35 | auto func = reinterpret_borrow(src); 36 | 37 | /* 38 | When passing a C++ function as an argument to another C++ 39 | function via Python, every function call would normally involve 40 | a full C++ -> Python -> C++ roundtrip, which can be prohibitive. 41 | Here, we try to at least detect the case where the function is 42 | stateless (i.e. function pointer or lambda function without 43 | captured variables), in which case the roundtrip can be avoided. 
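    (Editor's note, not part of the original comment: "stateless" means a
    plain function pointer or a capture-less lambda. For instance, a callback
    bound with m.def("f", [](int x) { return x + 1; }) and later passed back
    into C++ as a std::function<int(int)> argument is unwrapped to the
    underlying function pointer instead of being wrapped in a Python round
    trip.)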
44 | */ 45 | if (auto cfunc = func.cpp_function()) { 46 | auto c = reinterpret_borrow(PyCFunction_GET_SELF(cfunc.ptr())); 47 | auto rec = (function_record *) c; 48 | 49 | if (rec && rec->is_stateless && 50 | same_type(typeid(function_type), *reinterpret_cast(rec->data[1]))) { 51 | struct capture { function_type f; }; 52 | value = ((capture *) &rec->data)->f; 53 | return true; 54 | } 55 | } 56 | 57 | value = [func](Args... args) -> Return { 58 | gil_scoped_acquire acq; 59 | object retval(func(std::forward(args)...)); 60 | /* Visual studio 2015 parser issue: need parentheses around this expression */ 61 | return (retval.template cast()); 62 | }; 63 | return true; 64 | } 65 | 66 | template 67 | static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) { 68 | if (!f_) 69 | return none().inc_ref(); 70 | 71 | auto result = f_.template target(); 72 | if (result) 73 | return cpp_function(*result, policy).release(); 74 | else 75 | return cpp_function(std::forward(f_), policy).release(); 76 | } 77 | 78 | PYBIND11_TYPE_CASTER(type, _("Callable[[") + 79 | argument_loader::arg_names() + _("], ") + 80 | make_caster::name() + 81 | _("]")); 82 | }; 83 | 84 | NAMESPACE_END(detail) 85 | NAMESPACE_END(pybind11) 86 | -------------------------------------------------------------------------------- /EAST/lanms/include/pybind11/operators.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/operator.h: Metatemplates for operator overloading 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | 14 | #if defined(__clang__) && !defined(__INTEL_COMPILER) 15 | # pragma clang diagnostic ignored "-Wunsequenced" // multiple unsequenced modifications to 'self' (when using def(py::self OP Type())) 16 | #elif defined(_MSC_VER) 17 | # pragma warning(push) 18 | # pragma warning(disable: 4127) // warning C4127: Conditional expression is constant 19 | #endif 20 | 21 | NAMESPACE_BEGIN(pybind11) 22 | NAMESPACE_BEGIN(detail) 23 | 24 | /// Enumeration with all supported operator types 25 | enum op_id : int { 26 | op_add, op_sub, op_mul, op_div, op_mod, op_divmod, op_pow, op_lshift, 27 | op_rshift, op_and, op_xor, op_or, op_neg, op_pos, op_abs, op_invert, 28 | op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le, 29 | op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift, 30 | op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero, 31 | op_repr, op_truediv, op_itruediv 32 | }; 33 | 34 | enum op_type : int { 35 | op_l, /* base type on left */ 36 | op_r, /* base type on right */ 37 | op_u /* unary operator */ 38 | }; 39 | 40 | struct self_t { }; 41 | static const self_t self = self_t(); 42 | 43 | /// Type for an unused type slot 44 | struct undefined_t { }; 45 | 46 | /// Don't warn about an unused variable 47 | inline self_t __self() { return self; } 48 | 49 | /// base template of operator implementations 50 | template struct op_impl { }; 51 | 52 | /// Operator implementation generator 53 | template struct op_ { 54 | template void execute(Class &cl, const Extra&... 
extra) const { 55 | using Base = typename Class::type; 56 | using L_type = conditional_t::value, Base, L>; 57 | using R_type = conditional_t::value, Base, R>; 58 | using op = op_impl; 59 | cl.def(op::name(), &op::execute, is_operator(), extra...); 60 | #if PY_MAJOR_VERSION < 3 61 | if (id == op_truediv || id == op_itruediv) 62 | cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__", 63 | &op::execute, is_operator(), extra...); 64 | #endif 65 | } 66 | template void execute_cast(Class &cl, const Extra&... extra) const { 67 | using Base = typename Class::type; 68 | using L_type = conditional_t::value, Base, L>; 69 | using R_type = conditional_t::value, Base, R>; 70 | using op = op_impl; 71 | cl.def(op::name(), &op::execute_cast, is_operator(), extra...); 72 | #if PY_MAJOR_VERSION < 3 73 | if (id == op_truediv || id == op_itruediv) 74 | cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__", 75 | &op::execute, is_operator(), extra...); 76 | #endif 77 | } 78 | }; 79 | 80 | #define PYBIND11_BINARY_OPERATOR(id, rid, op, expr) \ 81 | template struct op_impl { \ 82 | static char const* name() { return "__" #id "__"; } \ 83 | static auto execute(const L &l, const R &r) -> decltype(expr) { return (expr); } \ 84 | static B execute_cast(const L &l, const R &r) { return B(expr); } \ 85 | }; \ 86 | template struct op_impl { \ 87 | static char const* name() { return "__" #rid "__"; } \ 88 | static auto execute(const R &r, const L &l) -> decltype(expr) { return (expr); } \ 89 | static B execute_cast(const R &r, const L &l) { return B(expr); } \ 90 | }; \ 91 | inline op_ op(const self_t &, const self_t &) { \ 92 | return op_(); \ 93 | } \ 94 | template op_ op(const self_t &, const T &) { \ 95 | return op_(); \ 96 | } \ 97 | template op_ op(const T &, const self_t &) { \ 98 | return op_(); \ 99 | } 100 | 101 | #define PYBIND11_INPLACE_OPERATOR(id, op, expr) \ 102 | template struct op_impl { \ 103 | static char const* name() { return "__" #id "__"; } \ 104 | static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \ 105 | static B execute_cast(L &l, const R &r) { return B(expr); } \ 106 | }; \ 107 | template op_ op(const self_t &, const T &) { \ 108 | return op_(); \ 109 | } 110 | 111 | #define PYBIND11_UNARY_OPERATOR(id, op, expr) \ 112 | template struct op_impl { \ 113 | static char const* name() { return "__" #id "__"; } \ 114 | static auto execute(const L &l) -> decltype(expr) { return expr; } \ 115 | static B execute_cast(const L &l) { return B(expr); } \ 116 | }; \ 117 | inline op_ op(const self_t &) { \ 118 | return op_(); \ 119 | } 120 | 121 | PYBIND11_BINARY_OPERATOR(sub, rsub, operator-, l - r) 122 | PYBIND11_BINARY_OPERATOR(add, radd, operator+, l + r) 123 | PYBIND11_BINARY_OPERATOR(mul, rmul, operator*, l * r) 124 | PYBIND11_BINARY_OPERATOR(truediv, rtruediv, operator/, l / r) 125 | PYBIND11_BINARY_OPERATOR(mod, rmod, operator%, l % r) 126 | PYBIND11_BINARY_OPERATOR(lshift, rlshift, operator<<, l << r) 127 | PYBIND11_BINARY_OPERATOR(rshift, rrshift, operator>>, l >> r) 128 | PYBIND11_BINARY_OPERATOR(and, rand, operator&, l & r) 129 | PYBIND11_BINARY_OPERATOR(xor, rxor, operator^, l ^ r) 130 | PYBIND11_BINARY_OPERATOR(eq, eq, operator==, l == r) 131 | PYBIND11_BINARY_OPERATOR(ne, ne, operator!=, l != r) 132 | PYBIND11_BINARY_OPERATOR(or, ror, operator|, l | r) 133 | PYBIND11_BINARY_OPERATOR(gt, lt, operator>, l > r) 134 | PYBIND11_BINARY_OPERATOR(ge, le, operator>=, l >= r) 135 | PYBIND11_BINARY_OPERATOR(lt, gt, operator<, l < r) 136 | 
PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r) 137 | //PYBIND11_BINARY_OPERATOR(pow, rpow, pow, std::pow(l, r)) 138 | PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r) 139 | PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r) 140 | PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r) 141 | PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r) 142 | PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r) 143 | PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r) 144 | PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r) 145 | PYBIND11_INPLACE_OPERATOR(iand, operator&=, l &= r) 146 | PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r) 147 | PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r) 148 | PYBIND11_UNARY_OPERATOR(neg, operator-, -l) 149 | PYBIND11_UNARY_OPERATOR(pos, operator+, +l) 150 | PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l)) 151 | PYBIND11_UNARY_OPERATOR(invert, operator~, (~l)) 152 | PYBIND11_UNARY_OPERATOR(bool, operator!, !!l) 153 | PYBIND11_UNARY_OPERATOR(int, int_, (int) l) 154 | PYBIND11_UNARY_OPERATOR(float, float_, (double) l) 155 | 156 | #undef PYBIND11_BINARY_OPERATOR 157 | #undef PYBIND11_INPLACE_OPERATOR 158 | #undef PYBIND11_UNARY_OPERATOR 159 | NAMESPACE_END(detail) 160 | 161 | using detail::self; 162 | 163 | NAMESPACE_END(pybind11) 164 | 165 | #if defined(_MSC_VER) 166 | # pragma warning(pop) 167 | #endif 168 | -------------------------------------------------------------------------------- /EAST/lanms/include/pybind11/options.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/options.h: global settings that are configurable at runtime. 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "common.h" 13 | 14 | NAMESPACE_BEGIN(pybind11) 15 | 16 | class options { 17 | public: 18 | 19 | // Default RAII constructor, which leaves settings as they currently are. 20 | options() : previous_state(global_state()) {} 21 | 22 | // Class is non-copyable. 23 | options(const options&) = delete; 24 | options& operator=(const options&) = delete; 25 | 26 | // Destructor, which restores settings that were in effect before. 27 | ~options() { 28 | global_state() = previous_state; 29 | } 30 | 31 | // Setter methods (affect the global state): 32 | 33 | options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; } 34 | 35 | options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; } 36 | 37 | options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; } 38 | 39 | options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; } 40 | 41 | // Getter methods (return the global state): 42 | 43 | static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; } 44 | 45 | static bool show_function_signatures() { return global_state().show_function_signatures; } 46 | 47 | // This type is not meant to be allocated on the heap. 48 | void* operator new(size_t) = delete; 49 | 50 | private: 51 | 52 | struct state { 53 | bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings. 54 | bool show_function_signatures = true; //< Include auto-generated function signatures in docstrings. 
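        // Editor's note (not in the original): `options` is meant to be used
        // as a scoped RAII guard: construct a py::options, call e.g.
        // disable_function_signatures(), register bindings, and the previous
        // settings are restored automatically when the guard is destroyed.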
};

    static state &global_state() {
        static state instance;
        return instance;
    }

    state previous_state;
};

NAMESPACE_END(pybind11)
--------------------------------------------------------------------------------
/EAST/lanms/include/pybind11/stl.h:
--------------------------------------------------------------------------------
/*
    pybind11/stl.h: Transparent conversion for STL data types

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#pragma once

#include "pybind11.h"
#include <set>
#include <unordered_set>
#include <map>
#include <unordered_map>
#include <iostream>
#include <list>
#include <valarray>

#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#endif

#ifdef __has_include
// std::optional (but including <optional> in c++14 mode isn't allowed)
#  if defined(PYBIND11_CPP17) && __has_include(<optional>)
#    include <optional>
#    define PYBIND11_HAS_OPTIONAL 1
#  endif
// std::experimental::optional (but not allowed in c++11 mode)
#  if defined(PYBIND11_CPP14) && __has_include(<experimental/optional>)
#    include <experimental/optional>
#    define PYBIND11_HAS_EXP_OPTIONAL 1
#  endif
// std::variant
#  if defined(PYBIND11_CPP17) && __has_include(<variant>)
#    include <variant>
#    define PYBIND11_HAS_VARIANT 1
#  endif
#elif defined(_MSC_VER) && defined(PYBIND11_CPP17)
#  include <optional>
#  include <variant>
#  define PYBIND11_HAS_OPTIONAL 1
#  define PYBIND11_HAS_VARIANT 1
#endif

NAMESPACE_BEGIN(pybind11)
NAMESPACE_BEGIN(detail)

/// Extracts an const lvalue reference or rvalue reference for U based on the type of T (e.g. for
/// forwarding a container element). Typically used indirect via forwarded_type(), below.
template <typename T, typename U>
using forwarded_type = conditional_t<
    std::is_lvalue_reference<T>::value, remove_reference_t<U> &, remove_reference_t<U> &&>;

/// Forwards a value U as rvalue or lvalue according to whether T is rvalue or lvalue; typically
/// used for forwarding a container's elements.
template <typename T, typename U>
forwarded_type<T, U> forward_like(U &&u) {
    return std::forward<forwarded_type<T, U>>(std::forward<U>(u));
}

template <typename Type, typename Key> struct set_caster {
    using type = Type;
    using key_conv = make_caster<Key>;

    bool load(handle src, bool convert) {
        if (!isinstance<pybind11::set>(src))
            return false;
        auto s = reinterpret_borrow<pybind11::set>(src);
        value.clear();
        for (auto entry : s) {
            key_conv conv;
            if (!conv.load(entry, convert))
                return false;
            value.insert(cast_op<Key &&>(std::move(conv)));
        }
        return true;
    }

    template <typename T>
    static handle cast(T &&src, return_value_policy policy, handle parent) {
        pybind11::set s;
        for (auto &value: src) {
            auto value_ = reinterpret_steal<object>(key_conv::cast(forward_like<T>(value), policy, parent));
            if (!value_ || !s.add(value_))
                return handle();
        }
        return s.release();
    }

    PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name() + _("]"));
};

template <typename Type, typename Key, typename Value> struct map_caster {
    using key_conv = make_caster<Key>;
    using value_conv = make_caster<Value>;

    bool load(handle src, bool convert) {
        if (!isinstance<dict>(src))
            return false;
        auto d = reinterpret_borrow<dict>(src);
        value.clear();
        for (auto it : d) {
            key_conv kconv;
            value_conv vconv;
            if (!kconv.load(it.first.ptr(), convert) ||
                !vconv.load(it.second.ptr(), convert))
                return false;
            value.emplace(cast_op<Key &&>(std::move(kconv)), cast_op<Value &&>(std::move(vconv)));
        }
        return true;
    }

    template <typename T>
    static handle cast(T &&src, return_value_policy policy, handle parent) {
        dict d;
        for (auto &kv: src) {
            auto key = reinterpret_steal<object>(key_conv::cast(forward_like<T>(kv.first), policy, parent));
            auto value = reinterpret_steal<object>(value_conv::cast(forward_like<T>(kv.second), policy, parent));
            if (!key || !value)
                return handle();
            d[key] = value;
        }
        return d.release();
    }

    PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name() + _(", ") + value_conv::name() + _("]"));
};

template <typename Type, typename Value> struct list_caster {
    using value_conv = make_caster<Value>;

    bool load(handle src, bool convert) {
        if (!isinstance<sequence>(src))
            return false;
        auto s = reinterpret_borrow<sequence>(src);
        value.clear();
        reserve_maybe(s, &value);
        for (auto it : s) {
            value_conv conv;
            if (!conv.load(it, convert))
                return false;
            value.push_back(cast_op<Value &&>(std::move(conv)));
        }
        return true;
    }

private:
    template <typename T = Type,
              enable_if_t<std::is_same<decltype(std::declval<T>().reserve(0)), void>::value, int> = 0>
    void reserve_maybe(sequence s, Type *) { value.reserve(s.size()); }
    void reserve_maybe(sequence, void *) { }

public:
    template <typename T>
    static handle cast(T &&src, return_value_policy policy, handle parent) {
        list l(src.size());
        size_t index = 0;
        for (auto &value: src) {
            auto value_ = reinterpret_steal<object>(value_conv::cast(forward_like<T>(value), policy, parent));
            if (!value_)
                return handle();
            PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference
        }
        return l.release();
    }

    PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name() + _("]"));
};

template <typename Type, typename Alloc> struct type_caster<std::vector<Type, Alloc>>
 : list_caster<std::vector<Type, Alloc>, Type> { };

template <typename Type, typename Alloc> struct type_caster<std::list<Type, Alloc>>
 : list_caster<std::list<Type, Alloc>, Type> { };

template <typename ArrayType, typename Value, bool Resizable, size_t Size = 0> struct array_caster {
    using value_conv = make_caster<Value>;

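    // Usage sketch (illustrative): with this header included, the casters
    // above convert STL containers transparently, by copy, at the binding
    // boundary:
    //
    //   m.def("total", [](const std::vector<double> &v) {
    //       double s = 0;
    //       for (double x : v) s += x;
    //       return s;
    //   });
    //   // Python: example.total([1.0, 2.0, 3.5]) -> 6.5; the Python list is
    //   // copied into a std::vector<double> by list_caster on every call.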
private:
    template <bool R = Resizable>
    bool require_size(enable_if_t<R, size_t> size) {
        if (value.size() != size)
            value.resize(size);
        return true;
    }
    template <bool R = Resizable>
    bool require_size(enable_if_t<!R, size_t> size) {
        return size == Size;
    }

public:
    bool load(handle src, bool convert) {
        if (!isinstance<list>(src))
            return false;
        auto l = reinterpret_borrow<list>(src);
        if (!require_size(l.size()))
            return false;
        size_t ctr = 0;
        for (auto it : l) {
            value_conv conv;
            if (!conv.load(it, convert))
                return false;
            value[ctr++] = cast_op<Value &&>(std::move(conv));
        }
        return true;
    }

    template <typename T>
    static handle cast(T &&src, return_value_policy policy, handle parent) {
        list l(src.size());
        size_t index = 0;
        for (auto &value: src) {
            auto value_ = reinterpret_steal<object>(value_conv::cast(forward_like<T>(value), policy, parent));
            if (!value_)
                return handle();
            PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference
        }
        return l.release();
    }

    PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name() + _<Resizable>(_(""), _("[") + _<Size>() + _("]")) + _("]"));
};

template <typename Type, size_t Size> struct type_caster<std::array<Type, Size>>
 : array_caster<std::array<Type, Size>, Type, false, Size> { };

template <typename Type> struct type_caster<std::valarray<Type>>
 : array_caster<std::valarray<Type>, Type, true> { };

template <typename Key, typename Compare, typename Alloc> struct type_caster<std::set<Key, Compare, Alloc>>
 : set_caster<std::set<Key, Compare, Alloc>, Key> { };

template <typename Key, typename Hash, typename Equal, typename Alloc> struct type_caster<std::unordered_set<Key, Hash, Equal, Alloc>>
 : set_caster<std::unordered_set<Key, Hash, Equal, Alloc>, Key> { };

template <typename Key, typename Value, typename Compare, typename Alloc> struct type_caster<std::map<Key, Value, Compare, Alloc>>
 : map_caster<std::map<Key, Value, Compare, Alloc>, Key, Value> { };

template <typename Key, typename Value, typename Hash, typename Equal, typename Alloc> struct type_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>>
 : map_caster<std::unordered_map<Key, Value, Hash, Equal, Alloc>, Key, Value> { };

// This type caster is intended to be used for std::optional and std::experimental::optional
template<typename T> struct optional_caster {
    using value_conv = make_caster<typename T::value_type>;

    template <typename T_>
    static handle cast(T_ &&src, return_value_policy policy, handle parent) {
        if (!src)
            return none().inc_ref();
        return value_conv::cast(*std::forward<T_>(src), policy, parent);
    }

    bool load(handle src, bool convert) {
        if (!src) {
            return false;
        } else if (src.is_none()) {
            return true;  // default-constructed value is already empty
        }
        value_conv inner_caster;
        if (!inner_caster.load(src, convert))
            return false;

        value.emplace(cast_op<typename T::value_type &&>(std::move(inner_caster)));
        return true;
    }

    PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name() + _("]"));
};

#if PYBIND11_HAS_OPTIONAL
template<typename T> struct type_caster<std::optional<T>>
    : public optional_caster<std::optional<T>> {};

template<> struct type_caster<std::nullopt_t>
    : public void_caster<std::nullopt_t> {};
#endif

#if PYBIND11_HAS_EXP_OPTIONAL
template<typename T> struct type_caster<std::experimental::optional<T>>
    : public optional_caster<std::experimental::optional<T>> {};

template<> struct type_caster<std::experimental::nullopt_t>
    : public void_caster<std::experimental::nullopt_t> {};
#endif

/// Visit a variant and cast any found type to Python
struct variant_caster_visitor {
    return_value_policy policy;
    handle parent;

    template <typename T>
    handle operator()(T &&src) const {
        return make_caster<T>::cast(std::forward<T>(src), policy, parent);
    }
};

/// Helper class which abstracts away variant's `visit` function. `std::variant` and similar
/// `namespace::variant` types which provide a `namespace::visit()` function are handled here
/// automatically using argument-dependent lookup.
/// Users can provide specializations for other
/// variant-like classes, e.g. `boost::variant` and `boost::apply_visitor`.
template <template <typename...> class Variant>
struct visit_helper {
    template <typename... Args>
    static auto call(Args &&...args) -> decltype(visit(std::forward<Args>(args)...)) {
        return visit(std::forward<Args>(args)...);
    }
};

/// Generic variant caster
template <typename Variant> struct variant_caster;

template <template <typename...> class V, typename... Ts>
struct variant_caster<V<Ts...>> {
    static_assert(sizeof...(Ts) > 0, "Variant must consist of at least one alternative.");

    template <typename U, typename... Us>
    bool load_alternative(handle src, bool convert, type_list<U, Us...>) {
        auto caster = make_caster<U>();
        if (caster.load(src, convert)) {
            value = cast_op<U>(caster);
            return true;
        }
        return load_alternative(src, convert, type_list<Us...>{});
    }

    bool load_alternative(handle, bool, type_list<>) { return false; }

    bool load(handle src, bool convert) {
        // Do a first pass without conversions to improve constructor resolution.
        // E.g. `py::int_(1).cast<variant<double, int>>()` needs to fill the `int`
        // slot of the variant. Without two-pass loading `double` would be filled
        // because it appears first and a conversion is possible.
        if (convert && load_alternative(src, false, type_list<Ts...>{}))
            return true;
        return load_alternative(src, convert, type_list<Ts...>{});
    }

    template <typename Variant>
    static handle cast(Variant &&src, return_value_policy policy, handle parent) {
        return visit_helper<V>::call(variant_caster_visitor{policy, parent},
                                     std::forward<Variant>(src));
    }

    using Type = V<Ts...>;
    PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster<Ts>::name()...) + _("]"));
};

#if PYBIND11_HAS_VARIANT
template <typename... Ts>
struct type_caster<std::variant<Ts...>> : variant_caster<std::variant<Ts...>> { };
#endif
NAMESPACE_END(detail)

inline std::ostream &operator<<(std::ostream &os, const handle &obj) {
    os << (std::string) str(obj);
    return os;
}

NAMESPACE_END(pybind11)

#if defined(_MSC_VER)
#pragma warning(pop)
#endif
--------------------------------------------------------------------------------
/EAST/lanms/include/pybind11/typeid.h:
--------------------------------------------------------------------------------
/*
    pybind11/typeid.h: Compiler-independent access to type identifiers

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#pragma once

#include <cstdio>
#include <cstdlib>

#if defined(__GNUG__)
#include <cxxabi.h>
#endif

NAMESPACE_BEGIN(pybind11)
NAMESPACE_BEGIN(detail)
/// Erase all occurrences of a substring
inline void erase_all(std::string &string, const std::string &search) {
    for (size_t pos = 0;;) {
        pos = string.find(search, pos);
        if (pos == std::string::npos) break;
        string.erase(pos, search.length());
    }
}

PYBIND11_NOINLINE inline void clean_type_id(std::string &name) {
#if defined(__GNUG__)
    int status = 0;
    std::unique_ptr<char, void (*)(void *)> res {
        abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free };
    if (status == 0)
        name = res.get();
#else
    detail::erase_all(name, "class ");
    detail::erase_all(name, "struct ");
    detail::erase_all(name, "enum ");
#endif
    detail::erase_all(name, "pybind11::");
}
NAMESPACE_END(detail)

/// Return a string representation of a C++ type
template <typename T> static std::string type_id() {
    std::string name(typeid(T).name());
    detail::clean_type_id(name);
    return name;
}

NAMESPACE_END(pybind11)
--------------------------------------------------------------------------------
/EAST/lanms/lanms.h:
--------------------------------------------------------------------------------
#pragma once

#include "clipper/clipper.hpp"

// locality-aware NMS
namespace lanms {

    namespace cl = ClipperLib;

    struct Polygon {
        cl::Path poly;
        float score;
    };

    float paths_area(const ClipperLib::Paths &ps) {
        float area = 0;
        for (auto &&p: ps)
            area += cl::Area(p);
        return area;
    }

    float poly_iou(const Polygon &a, const Polygon &b) {
        cl::Clipper clpr;
        clpr.AddPath(a.poly, cl::ptSubject, true);
        clpr.AddPath(b.poly, cl::ptClip, true);

        cl::Paths inter, uni;
        clpr.Execute(cl::ctIntersection, inter, cl::pftEvenOdd);
        clpr.Execute(cl::ctUnion, uni, cl::pftEvenOdd);

        auto inter_area = paths_area(inter),
             uni_area = paths_area(uni);
        return std::abs(inter_area) / std::max(std::abs(uni_area), 1.0f);
    }

    bool should_merge(const Polygon &a, const Polygon &b, float iou_threshold) {
        return poly_iou(a, b) > iou_threshold;
    }

    /**
     * Incrementally merge polygons
     */
    class PolyMerger {
    public:
        PolyMerger(): score(0), nr_polys(0) {
            memset(data, 0, sizeof(data));
        }

        /**
         * Add a new polygon to be merged.
         */
        void add(const Polygon &p_given) {
            Polygon p;
            if (nr_polys > 0) {
                // vertices of two polygons to merge may not be in the same order;
                // we match their vertices by choosing the ordering that
                // minimizes the total squared distance.
                // see function normalize_poly for details.
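                // Merging below is a score-weighted average of the matched
                // vertices: add() accumulates score * coordinate into data[],
                // and get() divides by the accumulated score. E.g. merging
                // x0 = 10 at score 0.9 with x0 = 20 at score 0.1 gives
                // x0 = (10*0.9 + 20*0.1) / (0.9 + 0.1) = 11.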
                p = normalize_poly(get(), p_given);
            } else {
                p = p_given;
            }
            assert(p.poly.size() == 4);
            auto &poly = p.poly;
            auto s = p.score;
            data[0] += poly[0].X * s;
            data[1] += poly[0].Y * s;

            data[2] += poly[1].X * s;
            data[3] += poly[1].Y * s;

            data[4] += poly[2].X * s;
            data[5] += poly[2].Y * s;

            data[6] += poly[3].X * s;
            data[7] += poly[3].Y * s;

            score += p.score;

            nr_polys += 1;
        }

        inline std::int64_t sqr(std::int64_t x) { return x * x; }

        Polygon normalize_poly(
                const Polygon &ref,
                const Polygon &p) {

            std::int64_t min_d = std::numeric_limits<std::int64_t>::max();
            size_t best_start = 0, best_order = 0;

            for (size_t start = 0; start < 4; start ++) {
                size_t j = start;
                std::int64_t d = (
                        sqr(ref.poly[(j + 0) % 4].X - p.poly[(j + 0) % 4].X)
                        + sqr(ref.poly[(j + 0) % 4].Y - p.poly[(j + 0) % 4].Y)
                        + sqr(ref.poly[(j + 1) % 4].X - p.poly[(j + 1) % 4].X)
                        + sqr(ref.poly[(j + 1) % 4].Y - p.poly[(j + 1) % 4].Y)
                        + sqr(ref.poly[(j + 2) % 4].X - p.poly[(j + 2) % 4].X)
                        + sqr(ref.poly[(j + 2) % 4].Y - p.poly[(j + 2) % 4].Y)
                        + sqr(ref.poly[(j + 3) % 4].X - p.poly[(j + 3) % 4].X)
                        + sqr(ref.poly[(j + 3) % 4].Y - p.poly[(j + 3) % 4].Y)
                );
                if (d < min_d) {
                    min_d = d;
                    best_start = start;
                    best_order = 0;
                }

                d = (
                        sqr(ref.poly[(j + 0) % 4].X - p.poly[(j + 3) % 4].X)
                        + sqr(ref.poly[(j + 0) % 4].Y - p.poly[(j + 3) % 4].Y)
                        + sqr(ref.poly[(j + 1) % 4].X - p.poly[(j + 2) % 4].X)
                        + sqr(ref.poly[(j + 1) % 4].Y - p.poly[(j + 2) % 4].Y)
                        + sqr(ref.poly[(j + 2) % 4].X - p.poly[(j + 1) % 4].X)
                        + sqr(ref.poly[(j + 2) % 4].Y - p.poly[(j + 1) % 4].Y)
                        + sqr(ref.poly[(j + 3) % 4].X - p.poly[(j + 0) % 4].X)
                        + sqr(ref.poly[(j + 3) % 4].Y - p.poly[(j + 0) % 4].Y)
                );
                if (d < min_d) {
                    min_d = d;
                    best_start = start;
                    best_order = 1;
                }
            }

            Polygon r;
            r.poly.resize(4);
            auto j = best_start;
            if (best_order == 0) {
                for (size_t i = 0; i < 4; i ++)
                    r.poly[i] = p.poly[(j + i) % 4];
            } else {
                for (size_t i = 0; i < 4; i ++)
                    r.poly[i] = p.poly[(j + 4 - i - 1) % 4];
            }
            r.score = p.score;
            return r;
        }

        Polygon get() const {
            Polygon p;

            auto &poly = p.poly;
            poly.resize(4);
            auto score_inv = 1.0f / std::max(1e-8f, score);
            poly[0].X = data[0] * score_inv;
            poly[0].Y = data[1] * score_inv;
            poly[1].X = data[2] * score_inv;
            poly[1].Y = data[3] * score_inv;
            poly[2].X = data[4] * score_inv;
            poly[2].Y = data[5] * score_inv;
            poly[3].X = data[6] * score_inv;
            poly[3].Y = data[7] * score_inv;

            assert(score > 0);
            p.score = score;

            return p;
        }

    private:
        std::int64_t data[8];
        float score;
        std::int32_t nr_polys;
    };


    /**
     * The standard NMS algorithm.
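     * Sketch of the loop below: indices are sorted by descending score; the
     * top-scoring polygon is kept, and every remaining polygon whose IoU with
     * it exceeds iou_threshold is dropped, until no candidates remain.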
     */
    std::vector<Polygon> standard_nms(std::vector<Polygon> &polys, float iou_threshold) {
        size_t n = polys.size();
        if (n == 0)
            return {};
        std::vector<size_t> indices(n);
        std::iota(std::begin(indices), std::end(indices), 0);
        std::sort(std::begin(indices), std::end(indices), [&](size_t i, size_t j) { return polys[i].score > polys[j].score; });

        std::vector<size_t> keep;
        while (indices.size()) {
            size_t p = 0, cur = indices[0];
            keep.emplace_back(cur);
            for (size_t i = 1; i < indices.size(); i ++) {
                if (!should_merge(polys[cur], polys[indices[i]], iou_threshold)) {
                    indices[p ++] = indices[i];
                }
            }
            indices.resize(p);
        }

        std::vector<Polygon> ret;
        for (auto &&i: keep) {
            ret.emplace_back(polys[i]);
        }
        return ret;
    }

    std::vector<Polygon>
    merge_quadrangle_n9(const float *data, size_t n, float iou_threshold) {
        using cInt = cl::cInt;

        // first pass
        std::vector<Polygon> polys;
        for (size_t i = 0; i < n; i ++) {
            auto p = data + i * 9;
            Polygon poly{
                {
                    {cInt(p[0]), cInt(p[1])},
                    {cInt(p[2]), cInt(p[3])},
                    {cInt(p[4]), cInt(p[5])},
                    {cInt(p[6]), cInt(p[7])},
                },
                p[8],
            };

            if (polys.size()) {
                // merge with the last one
                auto &bpoly = polys.back();
                if (should_merge(poly, bpoly, iou_threshold)) {
                    PolyMerger merger;
                    merger.add(bpoly);
                    merger.add(poly);
                    bpoly = merger.get();
                } else {
                    polys.emplace_back(poly);
                }
            } else {
                polys.emplace_back(poly);
            }
        }
        return standard_nms(polys, iou_threshold);
    }
}
--------------------------------------------------------------------------------
/EAST/locality_aware_nms.py:
--------------------------------------------------------------------------------
import numpy as np
import os
os.system('pip install shapely')
from shapely.geometry import Polygon


def intersection(g, p):
    g = Polygon(g[:8].reshape((4, 2)))
    p = Polygon(p[:8].reshape((4, 2)))
    if not g.is_valid or not p.is_valid:
        return 0
    inter = g.intersection(p).area
    union = g.area + p.area - inter
    if union == 0:
        return 0
    else:
        return inter/union


def weighted_merge(g, p):
    g[:8] = (g[8] * g[:8] + p[8] * p[:8])/(g[8] + p[8])
    g[8] = (g[8] + p[8])
    return g


def standard_nms(S, thres):
    order = np.argsort(S[:, 8])[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])

        inds = np.where(ovr <= thres)[0]
        order = order[inds+1]

    return S[keep]


def nms_locality(polys, thres=0.3):
    '''
    locality aware nms of EAST
    :param polys: a N*9 numpy array.
        first 8 coordinates, then prob
    :return: boxes after nms
    '''
    S = []
    p = None
    for g in polys:
        if p is not None and intersection(g, p) > thres:
            p = weighted_merge(g, p)
        else:
            if p is not None:
                S.append(p)
            p = g
    if p is not None:
        S.append(p)

    if len(S) == 0:
        return np.array([])
    return standard_nms(np.array(S), thres)


if __name__ == '__main__':
    # 343,350,448,135,474,143,369,359
    print(Polygon(np.array([[343, 350], [448, 135],
                            [474, 143], [369, 359]])).area)
--------------------------------------------------------------------------------
/EAST/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np

from tensorflow.contrib import slim

tf.app.flags.DEFINE_integer('text_scale', 512, '')

from nets import resnet_v1

FLAGS = tf.app.flags.FLAGS


def unpool(inputs):
    return tf.image.resize_bilinear(inputs, size=[tf.shape(inputs)[1]*2, tf.shape(inputs)[2]*2])


def mean_image_subtraction(images, means=[123.68, 116.78, 103.94]):
    '''
    image normalization
    :param images:
    :param means:
    :return:
    '''
    num_channels = images.get_shape().as_list()[-1]
    if len(means) != num_channels:
        raise ValueError('len(means) must match the number of channels')
    channels = tf.split(axis=3, num_or_size_splits=num_channels, value=images)
    for i in range(num_channels):
        channels[i] -= means[i]
    return tf.concat(axis=3, values=channels)


def model(images, weight_decay=1e-5, is_training=True):
    '''
    define the model; we use slim's implementation of resnet
    '''
    images = mean_image_subtraction(images)

    with slim.arg_scope(resnet_v1.resnet_arg_scope(weight_decay=weight_decay)):
        logits, end_points = resnet_v1.resnet_v1_50(images, is_training=is_training, scope='resnet_v1_50')

    with tf.variable_scope('feature_fusion', values=[end_points.values]):
        batch_norm_params = {
            'decay': 0.997,
            'epsilon': 1e-5,
            'scale': True,
            'is_training': is_training
        }
        with slim.arg_scope([slim.conv2d],
                            activation_fn=tf.nn.relu,
                            normalizer_fn=slim.batch_norm,
                            normalizer_params=batch_norm_params,
                            weights_regularizer=slim.l2_regularizer(weight_decay)):
            f = [end_points['pool5'], end_points['pool4'],
                 end_points['pool3'], end_points['pool2']]
            for i in range(4):
                print('Shape of f_{} {}'.format(i, f[i].shape))
            g = [None, None, None, None]
            h = [None, None, None, None]
            num_outputs = [None, 128, 64, 32]
            for i in range(4):
                if i == 0:
                    h[i] = f[i]
                else:
                    c1_1 = slim.conv2d(tf.concat([g[i-1], f[i]], axis=-1), num_outputs[i], 1)
                    h[i] = slim.conv2d(c1_1, num_outputs[i], 3)
                if i <= 2:
                    g[i] = unpool(h[i])
                else:
                    g[i] = slim.conv2d(h[i], num_outputs[i], 3)
                print('Shape of h_{} {}, g_{} {}'.format(i, h[i].shape, i, g[i].shape))

            # here we use a slightly different way for the regression part:
            # we first use a sigmoid to limit the regression range, and the
            # same is done for the angle map
            F_score = slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None)
            # 4 channels of axis-aligned bbox and 1 channel of rotation angle
            geo_map = slim.conv2d(g[3], 4, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None) * FLAGS.text_scale
            angle_map = (slim.conv2d(g[3], 1, 1, activation_fn=tf.nn.sigmoid, normalizer_fn=None) - 0.5) * np.pi/2  # angle is between [-45, 45]
            F_geometry = tf.concat([geo_map, angle_map], axis=-1)

    return F_score, F_geometry


def dice_coefficient(y_true_cls, y_pred_cls,
                     training_mask):
    '''
    dice loss
    :param y_true_cls:
    :param y_pred_cls:
    :param training_mask:
    :return:
    '''
    eps = 1e-5
    intersection = tf.reduce_sum(y_true_cls * y_pred_cls * training_mask)
    union = tf.reduce_sum(y_true_cls * training_mask) + tf.reduce_sum(y_pred_cls * training_mask) + eps
    loss = 1. - (2 * intersection / union)
    tf.summary.scalar('classification_dice_loss', loss)
    return loss


def loss(y_true_cls, y_pred_cls,
         y_true_geo, y_pred_geo,
         training_mask):
    '''
    define the loss used for training, containing two parts:
    for the first part we use dice loss instead of weighted log loss,
    the second part is the iou loss defined in the paper
    :param y_true_cls: ground truth of text
    :param y_pred_cls: prediction of text
    :param y_true_geo: ground truth of geometry
    :param y_pred_geo: prediction of geometry
    :param training_mask: mask used in training, to ignore some text annotated by ###
    :return:
    '''
    classification_loss = dice_coefficient(y_true_cls, y_pred_cls, training_mask)
    # scale classification loss to match the iou loss part
    classification_loss *= 0.01

    # d1 -> top, d2 -> right, d3 -> bottom, d4 -> left
    d1_gt, d2_gt, d3_gt, d4_gt, theta_gt = tf.split(value=y_true_geo, num_or_size_splits=5, axis=3)
    d1_pred, d2_pred, d3_pred, d4_pred, theta_pred = tf.split(value=y_pred_geo, num_or_size_splits=5, axis=3)
    area_gt = (d1_gt + d3_gt) * (d2_gt + d4_gt)
    area_pred = (d1_pred + d3_pred) * (d2_pred + d4_pred)
    w_union = tf.minimum(d2_gt, d2_pred) + tf.minimum(d4_gt, d4_pred)
    h_union = tf.minimum(d1_gt, d1_pred) + tf.minimum(d3_gt, d3_pred)
    area_intersect = w_union * h_union
    area_union = area_gt + area_pred - area_intersect
    L_AABB = -tf.log((area_intersect + 1.0)/(area_union + 1.0))
    L_theta = 1 - tf.cos(theta_pred - theta_gt)
    tf.summary.scalar('geometry_AABB', tf.reduce_mean(L_AABB * y_true_cls * training_mask))
    tf.summary.scalar('geometry_theta', tf.reduce_mean(L_theta * y_true_cls * training_mask))
    L_g = L_AABB + 20 * L_theta

    return tf.reduce_mean(L_g * y_true_cls * training_mask) + classification_loss
--------------------------------------------------------------------------------
/EAST/multigpu_train.py:
--------------------------------------------------------------------------------
import time
import numpy as np
import tensorflow as tf
import moxing.tensorflow as mox
from tensorflow.contrib import slim

tf.app.flags.DEFINE_integer('input_size', 512, '')
tf.app.flags.DEFINE_integer('batch_size_per_gpu', 16, '')
tf.app.flags.DEFINE_integer('num_readers', 16, '')
tf.app.flags.DEFINE_float('learning_rate', 0.001, '')
tf.app.flags.DEFINE_integer('max_steps', 10000, '')
tf.app.flags.DEFINE_float('moving_average_decay', 0.997, '')
tf.app.flags.DEFINE_string('gpu_list', '0,1,2,3', '')
tf.app.flags.DEFINE_string('checkpoint_path', '/cache/east_ckpt/', '')
tf.app.flags.DEFINE_string('checkpoint_path_obs', 's3://tcd-public/ckpt', '')
tf.app.flags.DEFINE_boolean('restore', False, 'whether to restore from checkpoint')
tf.app.flags.DEFINE_integer('save_checkpoint_steps', 1000, '')
tf.app.flags.DEFINE_integer('save_summary_steps', 100, '')
tf.app.flags.DEFINE_string('pretrained_model_path', None, '')
tf.app.flags.DEFINE_string('data_url', 'q', '')
tf.app.flags.DEFINE_string('train_url', 'q', '')
tf.app.flags.DEFINE_integer('num_gpus', 1, '')
tf.app.flags.DEFINE_string('training_data_path_obs', './data/train',
                           'training dataset to use')
import model
import icdar

FLAGS = tf.app.flags.FLAGS

gpus = list(range(len(FLAGS.gpu_list.split(','))))

mox.file.copy_parallel(FLAGS.training_data_path_obs, FLAGS.training_data_path)


def tower_loss(images, score_maps, geo_maps, training_masks, reuse_variables=None):
    # Build inference graph
    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
        f_score, f_geometry = model.model(images, is_training=True)

    model_loss = model.loss(score_maps, f_score,
                            geo_maps, f_geometry,
                            training_masks)
    total_loss = tf.add_n([model_loss] + tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    # add summary
    if reuse_variables is None:
        tf.summary.image('input', images)
        tf.summary.image('score_map', score_maps)
        tf.summary.image('score_map_pred', f_score * 255)
        tf.summary.image('geo_map_0', geo_maps[:, :, :, 0:1])
        tf.summary.image('geo_map_0_pred', f_geometry[:, :, :, 0:1])
        tf.summary.image('training_masks', training_masks)
        tf.summary.scalar('model_loss', model_loss)
        tf.summary.scalar('total_loss', total_loss)

    return total_loss, model_loss


def average_gradients(tower_grads):
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        grads = []
        for g, _ in grad_and_vars:
            expanded_g = tf.expand_dims(g, 0)
            grads.append(expanded_g)

        grad = tf.concat(grads, 0)
        grad = tf.reduce_mean(grad, 0)

        v = grad_and_vars[0][1]
        grad_and_var = (grad, v)
        average_grads.append(grad_and_var)

    return average_grads


def main(argv=None):
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list
    if not tf.gfile.Exists(FLAGS.checkpoint_path):
        tf.gfile.MkDir(FLAGS.checkpoint_path)
    else:
        if not FLAGS.restore:
            tf.gfile.DeleteRecursively(FLAGS.checkpoint_path)
            tf.gfile.MkDir(FLAGS.checkpoint_path)

    input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images')
    input_score_maps = tf.placeholder(tf.float32, shape=[None, None, None, 1],
                                      name='input_score_maps')
    if FLAGS.geometry == 'RBOX':
        input_geo_maps = tf.placeholder(tf.float32, shape=[None, None, None, 5],
                                        name='input_geo_maps')
    else:
        input_geo_maps = tf.placeholder(tf.float32, shape=[None, None, None, 8],
                                        name='input_geo_maps')
    input_training_masks = tf.placeholder(tf.float32, shape=[None, None, None, 1],
                                          name='input_training_masks')

    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0),
                                  trainable=False)
    learning_rate = tf.train.exponential_decay(FLAGS.learning_rate, global_step, decay_steps=500,
                                               decay_rate=0.94, staircase=True)
    # add summary
    tf.summary.scalar('learning_rate', learning_rate)
    opt = tf.train.AdamOptimizer(learning_rate)
    # opt = tf.train.MomentumOptimizer(learning_rate, 0.9)


    # split
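    # Sketch of the tower setup below (shapes illustrative): with
    # len(gpus) == 2 and a feed of batch 32, each tf.split yields two
    # [16, ...] tensors, one per GPU tower. Each tower produces its own
    # gradient list, which average_gradients() above then combines, e.g.
    #
    #   tower_grads = [[(g0_w, w), (g0_b, b)],   # from gpu 0
    #                  [(g1_w, w), (g1_b, b)]]   # from gpu 1
    #   grads = average_gradients(tower_grads)
    #   # -> [(mean(g0_w, g1_w), w), (mean(g0_b, g1_b), b)]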
input_images_split = tf.split(input_images, len(gpus)) 111 | input_score_maps_split = tf.split(input_score_maps, len(gpus)) 112 | input_geo_maps_split = tf.split(input_geo_maps, len(gpus)) 113 | input_training_masks_split = tf.split(input_training_masks, len(gpus)) 114 | 115 | tower_grads = [] 116 | reuse_variables = None 117 | for i, gpu_id in enumerate(gpus): 118 | with tf.device('/gpu:%d' % gpu_id): 119 | with tf.name_scope('model_%d' % gpu_id) as scope: 120 | iis = input_images_split[i] 121 | isms = input_score_maps_split[i] 122 | igms = input_geo_maps_split[i] 123 | itms = input_training_masks_split[i] 124 | total_loss, model_loss = tower_loss(iis, isms, igms, itms, reuse_variables) 125 | batch_norm_updates_op = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope)) 126 | reuse_variables = True 127 | 128 | grads = opt.compute_gradients(total_loss) 129 | tower_grads.append(grads) 130 | 131 | grads = average_gradients(tower_grads) 132 | apply_gradient_op = opt.apply_gradients(grads, global_step=global_step) 133 | 134 | summary_op = tf.summary.merge_all() 135 | # save moving average 136 | variable_averages = tf.train.ExponentialMovingAverage( 137 | FLAGS.moving_average_decay, global_step) 138 | variables_averages_op = variable_averages.apply(tf.trainable_variables()) 139 | # batch norm updates 140 | with tf.control_dependencies([variables_averages_op, apply_gradient_op, batch_norm_updates_op]): 141 | train_op = tf.no_op(name='train_op') 142 | 143 | saver = tf.train.Saver(tf.global_variables()) 144 | summary_writer = tf.summary.FileWriter(FLAGS.checkpoint_path, tf.get_default_graph()) 145 | 146 | init = tf.global_variables_initializer() 147 | 148 | if FLAGS.pretrained_model_path is not None: 149 | variable_restore_op = slim.assign_from_checkpoint_fn(FLAGS.pretrained_model_path, slim.get_trainable_variables(), 150 | ignore_missing_vars=True) 151 | 152 | # run_metadata = tf.RunMetadata() 153 | # run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) 154 | config = tf.ConfigProto(allow_soft_placement=True) 155 | with tf.Session(config=config) as sess: 156 | if FLAGS.restore: 157 | print('continue training from previous checkpoint') 158 | ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_path) 159 | saver.restore(sess, ckpt) 160 | else: 161 | sess.run(init) 162 | if FLAGS.pretrained_model_path is not None: 163 | variable_restore_op(sess) 164 | 165 | data_generator = icdar.get_batch(num_workers=FLAGS.num_readers, 166 | input_size=FLAGS.input_size, 167 | batch_size=FLAGS.batch_size_per_gpu * len(gpus)) 168 | 169 | start = time.time() 170 | for step in range(FLAGS.max_steps): 171 | data = next(data_generator) 172 | ml, tl, _ = sess.run([model_loss, total_loss, train_op], feed_dict={input_images: data[0], 173 | input_score_maps: data[2], 174 | input_geo_maps: data[3], 175 | input_training_masks: data[4]}) 176 | 177 | if np.isnan(tl): 178 | print('Loss diverged, stop training') 179 | break 180 | 181 | if step % 10 == 0: 182 | avg_time_per_step = (time.time() - start)/10 183 | avg_examples_per_second = (10 * FLAGS.batch_size_per_gpu * len(gpus))/(time.time() - start) 184 | start = time.time() 185 | print('Step {:06d}, model loss {:.4f}, total loss {:.4f}, {:.2f} seconds/step, {:.2f} examples/second'.format( 186 | step, ml, tl, avg_time_per_step, avg_examples_per_second)) 187 | 188 | if step % FLAGS.save_checkpoint_steps == 0: 189 | saver.save(sess, FLAGS.checkpoint_path + 'model.ckpt', global_step=global_step) 190 | mox.file.copy_parallel(FLAGS.checkpoint_path, 
FLAGS.checkpoint_path_obs) 191 | if step % FLAGS.save_summary_steps == 0: 192 | _, tl, summary_str = sess.run([train_op, total_loss, summary_op], feed_dict={input_images: data[0], 193 | input_score_maps: data[2], 194 | input_geo_maps: data[3], 195 | input_training_masks: data[4]}) 196 | summary_writer.add_summary(summary_str, global_step=step) 197 | 198 | 199 | if __name__ == '__main__': 200 | tf.app.run() 201 | -------------------------------------------------------------------------------- /EAST/nets/resnet_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains building blocks for various versions of Residual Networks. 16 | 17 | Residual networks (ResNets) were proposed in: 18 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015 20 | 21 | More variants were introduced in: 22 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 23 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016 24 | 25 | We can obtain different ResNet variants by changing the network depth, width, 26 | and form of residual unit. This module implements the infrastructure for 27 | building them. Concrete ResNet units and full ResNet networks are implemented in 28 | the accompanying resnet_v1.py and resnet_v2.py modules. 29 | 30 | Compared to https://github.com/KaimingHe/deep-residual-networks, in the current 31 | implementation we subsample the output activations in the last residual unit of 32 | each block, instead of subsampling the input activations in the first residual 33 | unit of each block. The two implementations give identical results but our 34 | implementation is more memory efficient. 35 | """ 36 | 37 | 38 | 39 | 40 | import collections 41 | import tensorflow as tf 42 | 43 | slim = tf.contrib.slim 44 | 45 | 46 | class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])): 47 | """A named tuple describing a ResNet block. 48 | 49 | Its parts are: 50 | scope: The scope of the `Block`. 51 | unit_fn: The ResNet unit function which takes as input a `Tensor` and 52 | returns another `Tensor` with the output of the ResNet unit. 53 | args: A list of length equal to the number of units in the `Block`. The list 54 | contains one (depth, depth_bottleneck, stride) tuple for each unit in the 55 | block to serve as argument to unit_fn. 56 | """ 57 | 58 | 59 | def subsample(inputs, factor, scope=None): 60 | """Subsamples the input along the spatial dimensions. 61 | 62 | Args: 63 | inputs: A `Tensor` of size [batch, height_in, width_in, channels]. 64 | factor: The subsampling factor. 65 | scope: Optional variable_scope. 
66 | 67 | Returns: 68 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the 69 | input, either intact (if factor == 1) or subsampled (if factor > 1). 70 | """ 71 | if factor == 1: 72 | return inputs 73 | else: 74 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope) 75 | 76 | 77 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): 78 | """Strided 2-D convolution with 'SAME' padding. 79 | 80 | When stride > 1, then we do explicit zero-padding, followed by conv2d with 81 | 'VALID' padding. 82 | 83 | Note that 84 | 85 | net = conv2d_same(inputs, num_outputs, 3, stride=stride) 86 | 87 | is equivalent to 88 | 89 | net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME') 90 | net = subsample(net, factor=stride) 91 | 92 | whereas 93 | 94 | net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME') 95 | 96 | is different when the input's height or width is even, which is why we add the 97 | current function. For more details, see ResnetUtilsTest.testConv2DSameEven(). 98 | 99 | Args: 100 | inputs: A 4-D tensor of size [batch, height_in, width_in, channels]. 101 | num_outputs: An integer, the number of output filters. 102 | kernel_size: An int with the kernel_size of the filters. 103 | stride: An integer, the output stride. 104 | rate: An integer, rate for atrous convolution. 105 | scope: Scope. 106 | 107 | Returns: 108 | output: A 4-D tensor of size [batch, height_out, width_out, channels] with 109 | the convolution output. 110 | """ 111 | if stride == 1: 112 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate, 113 | padding='SAME', scope=scope) 114 | else: 115 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) 116 | pad_total = kernel_size_effective - 1 117 | pad_beg = pad_total // 2 118 | pad_end = pad_total - pad_beg 119 | inputs = tf.pad(inputs, 120 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) 121 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride, 122 | rate=rate, padding='VALID', scope=scope) 123 | 124 | 125 | @slim.add_arg_scope 126 | def stack_blocks_dense(net, blocks, output_stride=None, 127 | outputs_collections=None): 128 | """Stacks ResNet `Blocks` and controls output feature density. 129 | 130 | First, this function creates scopes for the ResNet in the form of 131 | 'block_name/unit_1', 'block_name/unit_2', etc. 132 | 133 | Second, this function allows the user to explicitly control the ResNet 134 | output_stride, which is the ratio of the input to output spatial resolution. 135 | This is useful for dense prediction tasks such as semantic segmentation or 136 | object detection. 137 | 138 | Most ResNets consist of 4 ResNet blocks and subsample the activations by a 139 | factor of 2 when transitioning between consecutive ResNet blocks. This results 140 | to a nominal ResNet output_stride equal to 8. If we set the output_stride to 141 | half the nominal network stride (e.g., output_stride=4), then we compute 142 | responses twice. 143 | 144 | Control of the output feature density is implemented by atrous convolution. 145 | 146 | Args: 147 | net: A `Tensor` of size [batch, height, width, channels]. 148 | blocks: A list of length equal to the number of ResNet `Blocks`. Each 149 | element is a ResNet `Block` object describing the units in the `Block`. 150 | output_stride: If `None`, then the output will be computed at the nominal 151 | network stride. 
If output_stride is not `None`, it specifies the requested 152 | ratio of input to output spatial resolution, which needs to be equal to 153 | the product of unit strides from the start up to some level of the ResNet. 154 | For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1, 155 | then valid values for the output_stride are 1, 2, 6, 24 or None (which 156 | is equivalent to output_stride=24). 157 | outputs_collections: Collection to add the ResNet block outputs. 158 | 159 | Returns: 160 | net: Output tensor with stride equal to the specified output_stride. 161 | 162 | Raises: 163 | ValueError: If the target output_stride is not valid. 164 | """ 165 | # The current_stride variable keeps track of the effective stride of the 166 | # activations. This allows us to invoke atrous convolution whenever applying 167 | # the next residual unit would result in the activations having stride larger 168 | # than the target output_stride. 169 | current_stride = 1 170 | 171 | # The atrous convolution rate parameter. 172 | rate = 1 173 | 174 | for block in blocks: 175 | with tf.variable_scope(block.scope, 'block', [net]) as sc: 176 | for i, unit in enumerate(block.args): 177 | if output_stride is not None and current_stride > output_stride: 178 | raise ValueError('The target output_stride cannot be reached.') 179 | 180 | with tf.variable_scope('unit_%d' % (i + 1), values=[net]): 181 | unit_depth, unit_depth_bottleneck, unit_stride = unit 182 | # If we have reached the target output_stride, then we need to employ 183 | # atrous convolution with stride=1 and multiply the atrous rate by the 184 | # current unit's stride for use in subsequent layers. 185 | if output_stride is not None and current_stride == output_stride: 186 | net = block.unit_fn(net, 187 | depth=unit_depth, 188 | depth_bottleneck=unit_depth_bottleneck, 189 | stride=1, 190 | rate=rate) 191 | rate *= unit_stride 192 | 193 | else: 194 | net = block.unit_fn(net, 195 | depth=unit_depth, 196 | depth_bottleneck=unit_depth_bottleneck, 197 | stride=unit_stride, 198 | rate=1) 199 | current_stride *= unit_stride 200 | print(sc.name, net.shape) 201 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) 202 | 203 | if output_stride is not None and current_stride != output_stride: 204 | raise ValueError('The target output_stride cannot be reached.') 205 | 206 | return net 207 | 208 | 209 | def resnet_arg_scope(weight_decay=0.0001, 210 | batch_norm_decay=0.997, 211 | batch_norm_epsilon=1e-5, 212 | batch_norm_scale=True): 213 | """Defines the default ResNet arg scope. 214 | 215 | TODO(gpapan): The batch-normalization related default values above are 216 | appropriate for use in conjunction with the reference ResNet models 217 | released at https://github.com/KaimingHe/deep-residual-networks. When 218 | training ResNets from scratch, they might need to be tuned. 219 | 220 | Args: 221 | weight_decay: The weight decay to use for regularizing the model. 222 | batch_norm_decay: The moving average decay when estimating layer activation 223 | statistics in batch normalization. 224 | batch_norm_epsilon: Small constant to prevent division by zero when 225 | normalizing activations by their variance in batch normalization. 226 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 227 | activations in the batch normalization layer. 228 | 229 | Returns: 230 | An `arg_scope` to use for the resnet models. 
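  Example (illustrative): layers created under this scope pick up the
  regularization and batch-norm settings automatically:

    with slim.arg_scope(resnet_arg_scope(weight_decay=1e-5)):
      net = slim.conv2d(inputs, 64, [3, 3])  # L2-regularized, batch-normed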
231 | """ 232 | batch_norm_params = { 233 | 'decay': batch_norm_decay, 234 | 'epsilon': batch_norm_epsilon, 235 | 'scale': batch_norm_scale, 236 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 237 | } 238 | 239 | with slim.arg_scope( 240 | [slim.conv2d], 241 | weights_regularizer=slim.l2_regularizer(weight_decay), 242 | weights_initializer=slim.variance_scaling_initializer(), 243 | activation_fn=tf.nn.relu, 244 | normalizer_fn=slim.batch_norm, 245 | normalizer_params=batch_norm_params): 246 | with slim.arg_scope([slim.batch_norm], **batch_norm_params): 247 | # The following implies padding='SAME' for pool1, which makes feature 248 | # alignment easier for dense prediction tasks. This is also used in 249 | # https://github.com/facebook/fb.resnet.torch. However the accompanying 250 | # code of 'Deep Residual Learning for Image Recognition' uses 251 | # padding='VALID' for pool1. You can switch to that choice by setting 252 | # slim.arg_scope([slim.max_pool2d], padding='VALID'). 253 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc: 254 | return arg_sc 255 | -------------------------------------------------------------------------------- /EAST/nets/resnet_v1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains definitions for the original form of Residual Networks. 16 | 17 | The 'v1' residual networks (ResNets) implemented in this module were proposed 18 | by: 19 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 20 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 21 | 22 | Other variants were introduced in: 23 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 24 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027 25 | 26 | The networks defined in this module utilize the bottleneck building block of 27 | [1] with projection shortcuts only for increasing depths. They employ batch 28 | normalization *after* every weight layer. This is the architecture used by 29 | MSRA in the Imagenet and MSCOCO 2016 competition models ResNet-101 and 30 | ResNet-152. See [2; Fig. 1a] for a comparison between the current 'v1' 31 | architecture and the alternative 'v2' architecture of [2] which uses batch 32 | normalization *before* every weight layer in the so-called full pre-activation 33 | units. 
34 | 35 | Typical use: 36 | 37 | from tensorflow.contrib.slim.nets import resnet_v1 38 | 39 | ResNet-101 for image classification into 1000 classes: 40 | 41 | # inputs has shape [batch, 224, 224, 3] 42 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 43 | net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False) 44 | 45 | ResNet-101 for semantic segmentation into 21 classes: 46 | 47 | # inputs has shape [batch, 513, 513, 3] 48 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 49 | net, end_points = resnet_v1.resnet_v1_101(inputs, 50 | 21, 51 | is_training=False, 52 | global_pool=False, 53 | output_stride=16) 54 | """ 55 | # from __future__ import absolute_import 56 | # from __future__ import division 57 | # from __future__ import print_function 58 | 59 | import tensorflow as tf 60 | from tensorflow.contrib import slim 61 | 62 | from . import resnet_utils 63 | 64 | resnet_arg_scope = resnet_utils.resnet_arg_scope 65 | 66 | 67 | @slim.add_arg_scope 68 | def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1, 69 | outputs_collections=None, scope=None): 70 | """Bottleneck residual unit variant with BN after convolutions. 71 | 72 | This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for 73 | its definition. Note that we use here the bottleneck variant which has an 74 | extra bottleneck layer. 75 | 76 | When putting together two consecutive ResNet blocks that use this unit, one 77 | should use stride = 2 in the last unit of the first block. 78 | 79 | Args: 80 | inputs: A tensor of size [batch, height, width, channels]. 81 | depth: The depth of the ResNet unit output. 82 | depth_bottleneck: The depth of the bottleneck layers. 83 | stride: The ResNet unit's stride. Determines the amount of downsampling of 84 | the units output compared to its input. 85 | rate: An integer, rate for atrous convolution. 86 | outputs_collections: Collection to add the ResNet unit output. 87 | scope: Optional variable_scope. 88 | 89 | Returns: 90 | The ResNet unit's output. 91 | """ 92 | with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc: 93 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) 94 | if depth == depth_in: 95 | shortcut = resnet_utils.subsample(inputs, stride, 'shortcut') 96 | else: 97 | shortcut = slim.conv2d(inputs, depth, [1, 1], stride=stride, 98 | activation_fn=None, scope='shortcut') 99 | 100 | residual = slim.conv2d(inputs, depth_bottleneck, [1, 1], stride=1, 101 | scope='conv1') 102 | residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride, 103 | rate=rate, scope='conv2') 104 | residual = slim.conv2d(residual, depth, [1, 1], stride=1, 105 | activation_fn=None, scope='conv3') 106 | 107 | output = tf.nn.relu(shortcut + residual) 108 | 109 | return slim.utils.collect_named_outputs(outputs_collections, 110 | sc.original_name_scope, 111 | output) 112 | 113 | 114 | def resnet_v1(inputs, 115 | blocks, 116 | num_classes=None, 117 | is_training=True, 118 | global_pool=True, 119 | output_stride=None, 120 | include_root_block=True, 121 | spatial_squeeze=True, 122 | reuse=None, 123 | scope=None): 124 | """Generator for v1 ResNet models. 125 | 126 | This function generates a family of ResNet v1 models. See the resnet_v1_*() 127 | methods for specific model instantiations, obtained by selecting different 128 | block instantiations that produce ResNets of various depths. 
129 | 130 | Training for image classification on Imagenet is usually done with [224, 224] 131 | inputs, resulting in [7, 7] feature maps at the output of the last ResNet 132 | block for the ResNets defined in [1] that have nominal stride equal to 32. 133 | However, for dense prediction tasks we advise that one uses inputs with 134 | spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In 135 | this case the feature maps at the ResNet output will have spatial shape 136 | [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1] 137 | and corners exactly aligned with the input image corners, which greatly 138 | facilitates alignment of the features to the image. Using as input [225, 225] 139 | images results in [8, 8] feature maps at the output of the last ResNet block. 140 | 141 | For dense prediction tasks, the ResNet needs to run in fully-convolutional 142 | (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all 143 | have nominal stride equal to 32 and a good choice in FCN mode is to use 144 | output_stride=16 in order to increase the density of the computed features at 145 | small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915. 146 | 147 | Args: 148 | inputs: A tensor of size [batch, height_in, width_in, channels]. 149 | blocks: A list of length equal to the number of ResNet blocks. Each element 150 | is a resnet_utils.Block object describing the units in the block. 151 | num_classes: Number of predicted classes for classification tasks. If None 152 | we return the features before the logit layer. 153 | is_training: whether is training or not. 154 | global_pool: If True, we perform global average pooling before computing the 155 | logits. Set to True for image classification, False for dense prediction. 156 | output_stride: If None, then the output will be computed at the nominal 157 | network stride. If output_stride is not None, it specifies the requested 158 | ratio of input to output spatial resolution. 159 | include_root_block: If True, include the initial convolution followed by 160 | max-pooling, if False excludes it. 161 | spatial_squeeze: if True, logits is of shape [B, C], if false logits is 162 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes. 163 | reuse: whether or not the network and its variables should be reused. To be 164 | able to reuse 'scope' must be given. 165 | scope: Optional variable_scope. 166 | 167 | Returns: 168 | net: A rank-4 tensor of size [batch, height_out, width_out, channels_out]. 169 | If global_pool is False, then height_out and width_out are reduced by a 170 | factor of output_stride compared to the respective height_in and width_in, 171 | else both height_out and width_out equal one. If num_classes is None, then 172 | net is the output of the last ResNet block, potentially after global 173 | average pooling. If num_classes is not None, net contains the pre-softmax 174 | activations. 175 | end_points: A dictionary from components of the network to the corresponding 176 | activation. 177 | 178 | Raises: 179 | ValueError: If the target output_stride is not valid. 
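  Example (illustrative): the EAST model in model.py calls this through
  resnet_v1_50 in fully-convolutional mode, roughly:

    with slim.arg_scope(resnet_arg_scope()):
      net, end_points = resnet_v1_50(images, is_training=True,
                                     scope='resnet_v1_50')
    # end_points['pool2'] ... end_points['pool5'] then feed the feature
    # fusion branch.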
180 | """ 181 | with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc: 182 | end_points_collection = sc.name + '_end_points' 183 | with slim.arg_scope([slim.conv2d, bottleneck, 184 | resnet_utils.stack_blocks_dense], 185 | outputs_collections=end_points_collection): 186 | with slim.arg_scope([slim.batch_norm], is_training=is_training): 187 | net = inputs 188 | if include_root_block: 189 | if output_stride is not None: 190 | if output_stride % 4 != 0: 191 | raise ValueError('The output_stride needs to be a multiple of 4.') 192 | output_stride /= 4 193 | net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1') 194 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1') 195 | 196 | net = slim.utils.collect_named_outputs(end_points_collection, 'pool2', net) 197 | 198 | net = resnet_utils.stack_blocks_dense(net, blocks, output_stride) 199 | 200 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 201 | 202 | # end_points['pool2'] = end_points['resnet_v1_50/pool1/MaxPool:0'] 203 | try: 204 | end_points['pool3'] = end_points['resnet_v1_50/block1'] 205 | end_points['pool4'] = end_points['resnet_v1_50/block2'] 206 | except: 207 | end_points['pool3'] = end_points['Detection/resnet_v1_50/block1'] 208 | end_points['pool4'] = end_points['Detection/resnet_v1_50/block2'] 209 | end_points['pool5'] = net 210 | # if global_pool: 211 | # # Global average pooling. 212 | # net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True) 213 | # if num_classes is not None: 214 | # net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, 215 | # normalizer_fn=None, scope='logits') 216 | # if spatial_squeeze: 217 | # logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze') 218 | # else: 219 | # logits = net 220 | # # Convert end_points_collection into a dictionary of end_points. 221 | # end_points = slim.utils.convert_collection_to_dict(end_points_collection) 222 | # if num_classes is not None: 223 | # end_points['predictions'] = slim.softmax(logits, scope='predictions') 224 | return net, end_points 225 | 226 | 227 | resnet_v1.default_image_size = 224 228 | 229 | 230 | def resnet_v1_50(inputs, 231 | num_classes=None, 232 | is_training=True, 233 | global_pool=True, 234 | output_stride=None, 235 | spatial_squeeze=True, 236 | reuse=None, 237 | scope='resnet_v1_50'): 238 | """ResNet-50 model of [1]. See resnet_v1() for arg and return description.""" 239 | blocks = [ 240 | resnet_utils.Block( 241 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), 242 | resnet_utils.Block( 243 | 'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]), 244 | resnet_utils.Block( 245 | 'block3', bottleneck, [(1024, 256, 1)] * 5 + [(1024, 256, 2)]), 246 | resnet_utils.Block( 247 | 'block4', bottleneck, [(2048, 512, 1)] * 3) 248 | ] 249 | return resnet_v1(inputs, blocks, num_classes, is_training, 250 | global_pool=global_pool, output_stride=output_stride, 251 | include_root_block=True, spatial_squeeze=spatial_squeeze, 252 | reuse=reuse, scope=scope) 253 | 254 | 255 | resnet_v1_50.default_image_size = resnet_v1.default_image_size 256 | 257 | 258 | def resnet_v1_101(inputs, 259 | num_classes=None, 260 | is_training=True, 261 | global_pool=True, 262 | output_stride=None, 263 | spatial_squeeze=True, 264 | reuse=None, 265 | scope='resnet_v1_101'): 266 | """ResNet-101 model of [1]. 
See resnet_v1() for arg and return description.""" 267 | blocks = [ 268 | resnet_utils.Block( 269 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), 270 | resnet_utils.Block( 271 | 'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]), 272 | resnet_utils.Block( 273 | 'block3', bottleneck, [(1024, 256, 1)] * 22 + [(1024, 256, 2)]), 274 | resnet_utils.Block( 275 | 'block4', bottleneck, [(2048, 512, 1)] * 3) 276 | ] 277 | return resnet_v1(inputs, blocks, num_classes, is_training, 278 | global_pool=global_pool, output_stride=output_stride, 279 | include_root_block=True, spatial_squeeze=spatial_squeeze, 280 | reuse=reuse, scope=scope) 281 | 282 | 283 | resnet_v1_101.default_image_size = resnet_v1.default_image_size 284 | 285 | 286 | def resnet_v1_152(inputs, 287 | num_classes=None, 288 | is_training=True, 289 | global_pool=True, 290 | output_stride=None, 291 | spatial_squeeze=True, 292 | reuse=None, 293 | scope='resnet_v1_152'): 294 | """ResNet-152 model of [1]. See resnet_v1() for arg and return description.""" 295 | blocks = [ 296 | resnet_utils.Block( 297 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), 298 | resnet_utils.Block( 299 | 'block2', bottleneck, [(512, 128, 1)] * 7 + [(512, 128, 2)]), 300 | resnet_utils.Block( 301 | 'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]), 302 | resnet_utils.Block( 303 | 'block4', bottleneck, [(2048, 512, 1)] * 3)] 304 | return resnet_v1(inputs, blocks, num_classes, is_training, 305 | global_pool=global_pool, output_stride=output_stride, 306 | include_root_block=True, spatial_squeeze=spatial_squeeze, 307 | reuse=reuse, scope=scope) 308 | 309 | 310 | resnet_v1_152.default_image_size = resnet_v1.default_image_size 311 | 312 | 313 | def resnet_v1_200(inputs, 314 | num_classes=None, 315 | is_training=True, 316 | global_pool=True, 317 | output_stride=None, 318 | spatial_squeeze=True, 319 | reuse=None, 320 | scope='resnet_v1_200'): 321 | """ResNet-200 model of [2]. 
See resnet_v1() for arg and return description.""" 322 | blocks = [ 323 | resnet_utils.Block( 324 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), 325 | resnet_utils.Block( 326 | 'block2', bottleneck, [(512, 128, 1)] * 23 + [(512, 128, 2)]), 327 | resnet_utils.Block( 328 | 'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]), 329 | resnet_utils.Block( 330 | 'block4', bottleneck, [(2048, 512, 1)] * 3)] 331 | return resnet_v1(inputs, blocks, num_classes, is_training, 332 | global_pool=global_pool, output_stride=output_stride, 333 | include_root_block=True, spatial_squeeze=spatial_squeeze, 334 | reuse=reuse, scope=scope) 335 | 336 | 337 | resnet_v1_200.default_image_size = resnet_v1.default_image_size 338 | 339 | 340 | if __name__ == '__main__': 341 | input = tf.placeholder(tf.float32, shape=(None, 224, 224, 3), name='input') 342 | with slim.arg_scope(resnet_arg_scope()) as sc: 343 | logits = resnet_v1_50(input) -------------------------------------------------------------------------------- /EAST/output.txt: -------------------------------------------------------------------------------- 1 | img_calligraphy_81777_bg,437,111,20,606,424,1086,842,591 2 | img_calligraphy_82643_bg,525,112,122,632,513,1137,916,617 3 | -------------------------------------------------------------------------------- /EAST/run_demo_server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | 5 | import time 6 | import datetime 7 | import cv2 8 | import numpy as np 9 | import uuid 10 | import json 11 | 12 | import functools 13 | import logging 14 | import collections 15 | 16 | logger = logging.getLogger(__name__) 17 | logger.setLevel(logging.INFO) 18 | 19 | 20 | @functools.lru_cache(maxsize=1) 21 | def get_host_info(): 22 | ret = {} 23 | with open('/proc/cpuinfo') as f: 24 | ret['cpuinfo'] = f.read() 25 | 26 | with open('/proc/meminfo') as f: 27 | ret['meminfo'] = f.read() 28 | 29 | with open('/proc/loadavg') as f: 30 | ret['loadavg'] = f.read() 31 | 32 | return ret 33 | 34 | 35 | @functools.lru_cache(maxsize=100) 36 | def get_predictor(checkpoint_path): 37 | logger.info('loading model') 38 | import tensorflow as tf 39 | import model 40 | from icdar import restore_rectangle 41 | import lanms 42 | from eval import resize_image, sort_poly, detect 43 | 44 | input_images = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_images') 45 | global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False) 46 | 47 | f_score, f_geometry = model.model(input_images, is_training=False) 48 | 49 | variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step) 50 | saver = tf.train.Saver(variable_averages.variables_to_restore()) 51 | 52 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) 53 | 54 | ckpt_state = tf.train.get_checkpoint_state(checkpoint_path) 55 | model_path = os.path.join(checkpoint_path, os.path.basename(ckpt_state.model_checkpoint_path)) 56 | logger.info('Restore from {}'.format(model_path)) 57 | saver.restore(sess, model_path) 58 | 59 | def predictor(img): 60 | """ 61 | :return: { 62 | 'text_lines': [ 63 | { 64 | 'score': , 65 | 'x0': , 66 | 'y0': , 67 | 'x1': , 68 | ... 
69 | 'y3': , 70 | } 71 | ], 72 | 'rtparams': { # runtime parameters 73 | 'image_size': , 74 | 'working_size': , 75 | }, 76 | 'timing': { 77 | 'net': , 78 | 'restore': , 79 | 'nms': , 80 | 'cpuinfo': , 81 | 'meminfo': , 82 | 'uptime': , 83 | } 84 | } 85 | """ 86 | start_time = time.time() 87 | rtparams = collections.OrderedDict() 88 | rtparams['start_time'] = datetime.datetime.now().isoformat() 89 | rtparams['image_size'] = '{}x{}'.format(img.shape[1], img.shape[0]) 90 | timer = collections.OrderedDict([ 91 | ('net', 0), 92 | ('restore', 0), 93 | ('nms', 0) 94 | ]) 95 | 96 | im_resized, (ratio_h, ratio_w) = resize_image(img) 97 | rtparams['working_size'] = '{}x{}'.format( 98 | im_resized.shape[1], im_resized.shape[0]) 99 | start = time.time() 100 | score, geometry = sess.run( 101 | [f_score, f_geometry], 102 | feed_dict={input_images: [im_resized[:,:,::-1]]}) 103 | timer['net'] = time.time() - start 104 | 105 | boxes, timer = detect(score_map=score, geo_map=geometry, timer=timer) 106 | logger.info('net {:.0f}ms, restore {:.0f}ms, nms {:.0f}ms'.format( 107 | timer['net']*1000, timer['restore']*1000, timer['nms']*1000)) 108 | 109 | if boxes is not None: 110 | scores = boxes[:,8].reshape(-1) 111 | boxes = boxes[:, :8].reshape((-1, 4, 2)) 112 | boxes[:, :, 0] /= ratio_w 113 | boxes[:, :, 1] /= ratio_h 114 | 115 | duration = time.time() - start_time 116 | timer['overall'] = duration 117 | logger.info('[timing] {}'.format(duration)) 118 | 119 | text_lines = [] 120 | if boxes is not None: 121 | text_lines = [] 122 | for box, score in zip(boxes, scores): 123 | box = sort_poly(box.astype(np.int32)) 124 | if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3]-box[0]) < 5: 125 | continue 126 | tl = collections.OrderedDict(zip( 127 | ['x0', 'y0', 'x1', 'y1', 'x2', 'y2', 'x3', 'y3'], 128 | map(float, box.flatten()))) 129 | tl['score'] = float(score) 130 | text_lines.append(tl) 131 | ret = { 132 | 'text_lines': text_lines, 133 | 'rtparams': rtparams, 134 | 'timing': timer, 135 | } 136 | ret.update(get_host_info()) 137 | return ret 138 | 139 | 140 | return predictor 141 | 142 | 143 | ### the webserver 144 | from flask import Flask, request, render_template 145 | import argparse 146 | 147 | 148 | class Config: 149 | SAVE_DIR = 'static/results' 150 | 151 | 152 | config = Config() 153 | 154 | 155 | app = Flask(__name__) 156 | 157 | @app.route('/') 158 | def index(): 159 | return render_template('index.html', session_id='dummy_session_id') 160 | 161 | 162 | def draw_illu(illu, rst): 163 | for t in rst['text_lines']: 164 | d = np.array([t['x0'], t['y0'], t['x1'], t['y1'], t['x2'], 165 | t['y2'], t['x3'], t['y3']], dtype='int32') 166 | d = d.reshape(-1, 2) 167 | cv2.polylines(illu, [d], isClosed=True, color=(255, 255, 0)) 168 | return illu 169 | 170 | 171 | def save_result(img, rst): 172 | session_id = str(uuid.uuid1()) 173 | dirpath = os.path.join(config.SAVE_DIR, session_id) 174 | os.makedirs(dirpath) 175 | 176 | # save input image 177 | output_path = os.path.join(dirpath, 'input.png') 178 | cv2.imwrite(output_path, img) 179 | 180 | # save illustration 181 | output_path = os.path.join(dirpath, 'output.png') 182 | cv2.imwrite(output_path, draw_illu(img.copy(), rst)) 183 | 184 | # save json data 185 | output_path = os.path.join(dirpath, 'result.json') 186 | with open(output_path, 'w') as f: 187 | json.dump(rst, f) 188 | 189 | rst['session_id'] = session_id 190 | return rst 191 | 192 | checkpoint_path = '/home/jqf/tmp/ckpt/east' 193 | 194 | 195 | @app.route('/', methods=['POST']) 196 | def index_post(): 
197 | global predictor 198 | import io 199 | bio = io.BytesIO() 200 | request.files['image'].save(bio) 201 | img = cv2.imdecode(np.frombuffer(bio.getvalue(), dtype='uint8'), 1) 202 | rst = get_predictor(checkpoint_path)(img) 203 | 204 | save_result(img, rst) 205 | return render_template('index.html', session_id=rst['session_id']) 206 | 207 | 208 | def main(): 209 | global checkpoint_path 210 | parser = argparse.ArgumentParser() 211 | parser.add_argument('--port', default=8769, type=int) 212 | parser.add_argument('--checkpoint-path', default=checkpoint_path) 213 | parser.add_argument('--debug', action='store_true') 214 | args = parser.parse_args() 215 | checkpoint_path = args.checkpoint_path 216 | 217 | if not os.path.exists(args.checkpoint_path): 218 | raise RuntimeError( 219 | 'Checkpoint `{}` not found'.format(args.checkpoint_path)) 220 | 221 | app.debug = args.debug 222 | app.run('0.0.0.0', args.port) 223 | 224 | if __name__ == '__main__': 225 | main() 226 | 227 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 DataFountainCode 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # huawei_code_share
2 | An open-source baseline for the challenge at https://www.datafountain.cn/competitions/334/details
3 |
4 |
5 | # Open-Source Baseline for the Huawei Challenge - Competing with ModelArts
6 |
7 |
8 | Tags: 2019 Digital China Innovation Contest
9 |
10 | ---
11 |
12 | #### Preface
13 |
14 | This write-up uses the EAST model (https://github.com/argman/EAST) as the text-box detector,
15 | and, for the OCR recognition model, the first-place solution of the Xi'an Jiaotong University AI Practice Contest by @yinchangchang (https://github.com/yinchangchang/ocr_densenet).
16 |
17 | All of the code is open source and has already been adapted into a format that trains on ModelArts; you can diff it against the original EAST and OCR repositories to see exactly what was changed.
18 | The OCR model was trained for only 10 epochs (about 6 hours) and already scores an F1 of 0.42 on leaderboard A; training longer should comfortably push F1 above 0.80.
19 |
20 | Notes for training on ModelArts:
21 | 1. File saving, modification, and reading calls must be adapted; see section 1.2;
22 | 2. A training job downloads the OBS directory containing the boot script onto the GPU machine and runs it there; the user path on the GPU machine is /home/work/. If you need to download data onto the machine, /cache/ is recommended (all data on the machine is wiped after each training job finishes);
23 | 3. Check the job status regularly so that vouchers are not wasted;
24 |
25 | # EAST
26 |
27 | #### EAST data preparation
28 |
29 | To save vouchers and online time, finish the data preparation before moving to ModelArts, then upload. Unpack all downloaded data archives.
30 | ![image.png-3.8kB][1]
31 | EAST expects one .txt annotation file per image; convert_to_txt.py in this repository converts the training set into the required format. Change the data paths inside convert_to_txt.py to your own.
32 | ![image.png-13.2kB][2]
33 | Upload the resulting data to a path created on OBS, for example
34 | ![image.png-22.7kB][3]
35 |
36 | #### Notes on using ModelArts
37 |
38 |
39 | 1. If a Python library is missing, add "os.system('pip install xxx')" to the training script and the system will install it automatically;
40 |
41 | 2. Files on OBS cannot be read or written directly with open(); use moxing.framework.file.File instead, e.g. open('input.txt','r') -> moxing.framework.file.File('input.txt','r');
42 |
43 | 3. glob likewise goes through moxing.framework.file.glob;
44 |
45 | 4. Normally each ModelArts engine already adapts its checkpoint-saving routines to OBS paths; if saving fails, set the checkpoint path to "./xxx" or "/cache/xxx" and copy it out afterwards with mox.file.copy('./model.ckpt', 's3://ckpt/model.ckpt'). The sketch below condenses notes 1-4 into code.
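A minimal sketch of these adaptations, assuming a ModelArts job where the `moxing` package is available (the bucket paths here are placeholders, not the competition's real paths):

```python
import os

# Note 1: install a missing library at job start-up.
os.system('pip install tqdm')

import moxing as mox

# Note 2: read/write OBS files through moxing instead of open().
with mox.file.File('s3://my-bucket/data/input.txt', 'r') as f:
    text = f.read()

# Note 3: glob on OBS also goes through moxing.
txt_files = mox.file.glob('s3://my-bucket/data/east/*.txt')

# Note 4: save checkpoints locally (e.g. under /cache/), then copy to OBS.
mox.file.copy('/cache/ckpt/model.ckpt', 's3://my-bucket/ckpt/model.ckpt')
```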
46 | Upload the EAST code to OBS:
47 | ![image.png-8.9kB][4]
48 |
49 | Create the training job with ModelArts. Do not create it from a notebook: notebooks have no GPU resources, they only keep data temporarily (everything is cleared once the notebook is closed), and leaving a notebook open burns through vouchers. Data stored on OBS is never cleared, so training through a created job saves vouchers.
50 |
51 | #### Training EAST on ModelArts
52 | ![image.png-25.2kB][5]
53 | Then choose the data storage path, the engine to use, the boot file, and so on,
54 | ![image.png-26.6kB][6]
55 | and enter the parameters the script requires.
56 | ![image.png-39.1kB][7]
57 |
58 | Parameter list:
59 | gpu_list=0
60 | input_size=512
61 | batch_size_per_gpu=14
62 | checkpoint_path_obs=s3://tcd_public/ckpt
63 | text_scale=512
64 | training_data_path_obs=s3://tcd_public/data/east/
65 | geometry=RBOX
66 | learning_rate=0.0001
67 | num_readers=24
68 | Select the compute resources and save the job parameters for reuse, and the run can start (18 yuan per run is genuinely expensive).
69 | ![image.png-26.6kB][8]
70 | Click run,
71 | ![image.png-61.5kB][9]
72 | and in the end several models are generated under the ckpt folder, as shown (trained for only one step here).
73 | ![image.png-29.9kB][10]
74 |
75 | #### Inference testing
76 |
77 | Once training reaches a reasonable accuracy, you can test. Create a job in the same way, select the test dataset, use the eval.py script from EAST, enter the necessary parameters, and start the run.
78 | ![image.png-39.2kB][11]
79 | ![image.png-46.6kB][12]
80 | Afterwards an output.txt file appears under the data directory on OBS; each line contains a test image's name and four x/y points.
81 | ![image.png-13.6kB][13]
82 |
83 | # OCR
84 | #### Generating the OCR data
85 | The data mainly has to satisfy the following requirements:
86 |
87 | - 1. The open-source first-place code needs a train.csv file containing two fields, name and content;
88 | - 2. OCR training needs every column of text cropped out of the original images; here we simply crop with the maximal x/y coordinates;
89 | - 3. Generate the test dataset;
90 | - 4. Save all datasets under data/dataset/train/ and test/, so less code needs changing;
91 |
92 | Only the training set is used here; the validation set was not added to training. For higher accuracy, the validation set should be trained on as well.
93 | Offline, use makedata.py from ocr to generate the data format training needs, replacing the data paths inside makedata.py. The target paths are best set to ocr/data/dataset/train/ and ocr/data/dataset/test/; input_file refers to the detection model's inference output output.txt, and output_file is the sample submission file.
94 | ![image.png-33.6kB][14]
95 | ![image.png-33kB][15]
96 |
97 | Then, under code/preprocessing, run map_word_to_index.py and analysis_dataset.py to analyze the data and extract the characters; this step generates the text- and image-related training files under ocr/file/.
98 | Once the data is processed, all the code and data under ocr can be uploaded to OBS.
99 | If uploading the offline data to OBS feels too slow, you can use a ModelArts notebook instead; in that case, first download the raw data onto the notebook machine, e.g. data_path='/cache/data',
100 | from moxing.framework import file
101 | file.copy_parallel(data_path_obs, data_path)
102 | and after processing the data, upload it back to OBS:
103 | file.copy_parallel(<your processed data path under /cache/>, <the data path the OCR code needs, e.g. /ocr/data/dataset/train/>)
104 | A fuller staging sketch is given at the end of this README.
105 | #### Creating the OCR training job
106 | The training strategy is almost the same as the open-source solution, but the competition data is written vertically, so dataloader.py simply transposes the crops into horizontal ones.
107 | Create the job and enter the parameters:
108 | ![image.png-37.6kB][16]
109 | Then click confirm to start. After a few steps the loss is visibly dropping; about ten epochs took roughly 6 hours. (Only 10 epochs were trained here; more epochs, plus the hard mining in the original code, would give a higher score.)
110 | ![image.png-46.3kB][17]
111 |
112 | A checkpoint is saved once per epoch here; the ckpt files can be seen under the save-dir-obs path.
113 | ![image.png-16kB][18]
114 |
115 | #### Inference and prediction
116 | Use main.py for prediction: set phase to test, set the resume parameter to the ckpt path to use (a /cache/ path on the GPU machine); parameters as shown,
117 | ![image.png-29.7kB][19]
118 |
119 | and finally a predict.csv file appears at the OBS path; download it and submit it on the competition site.
120 | ![image.png-4.8kB][20]
121 | ![image.png-48.4kB][21]
122 |
123 |
124 | [1]: http://static.zybuluo.com/nxzyq123/4ilh4mvf22wqwf4mclt9xpqt/image.png
125 | [2]: http://static.zybuluo.com/nxzyq123/4xtw2z3yv4y94yi98r40b5jo/image.png
126 | [3]: http://static.zybuluo.com/nxzyq123/m374c73514uekqq96tlxofmr/image.png
127 | [4]: http://static.zybuluo.com/nxzyq123/4pade8guxiwg9wnz1c38au6o/image.png
128 | [5]: http://static.zybuluo.com/nxzyq123/ojsmrfyzhj7w3421r47te6j2/image.png
129 | [6]: http://static.zybuluo.com/nxzyq123/tvhfqulfcyopzz2opgf2bcwi/image.png
130 | [7]: http://static.zybuluo.com/nxzyq123/ri27x8qejm4hggj2899chuw0/image.png
131 | [8]: http://static.zybuluo.com/nxzyq123/k51ynmd1msvvtl0n0gmr0ka5/image.png
132 | [9]: http://static.zybuluo.com/nxzyq123/j157rof3bqqyjz1sbrqv1rkt/image.png
133 | [10]: http://static.zybuluo.com/nxzyq123/6krdgviddk0h8vfthdlepljn/image.png
134 | [11]: http://static.zybuluo.com/nxzyq123/n3v6x72644mx3t2cr9eaphj6/image.png
135 | [12]: http://static.zybuluo.com/nxzyq123/dlnfvvm6ci0tlsl3yrqyxllv/image.png
136 | [13]: http://static.zybuluo.com/nxzyq123/vej3kuaexgis20icnz9yryfn/image.png
137 | [14]: http://static.zybuluo.com/nxzyq123/wtnmmx164oifvvnfl6g1cipy/image.png
138 | [15]: http://static.zybuluo.com/nxzyq123/tdlnl7boln1iaubp90mibvel/image.png
139 | [16]: http://static.zybuluo.com/nxzyq123/ptghksirdgq41es7zjxpm8wx/image.png
140 | [17]: http://static.zybuluo.com/nxzyq123/02geisyuxii1vtyfgjia3ubg/image.png
141 | [18]: http://static.zybuluo.com/nxzyq123/kezxhui3t3cmzo5yfcp3fuif/image.png
142 | [19]: http://static.zybuluo.com/nxzyq123/o61krd2ranebey2p70g54l3c/image.png
143 | [20]: http://static.zybuluo.com/nxzyq123/timegwww4hcagxxdpfd7zemv/image.png
144 | [21]: http://static.zybuluo.com/nxzyq123/i2y5xkioxlvfd40anfth4z3z/image.png
145 |
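#### Appendix: notebook data-staging sketch

As promised in the OCR data section above, a minimal sketch of the notebook staging flow. The bucket names and the exact dataset layout are placeholders; adjust them to your own OBS paths.

```python
from moxing.framework import file

data_path_obs = 's3://my-bucket/ocr_raw'   # raw data on OBS (placeholder path)
data_path = '/cache/data'                  # local scratch on the notebook machine

# 1. Pull the raw data down to /cache/.
file.copy_parallel(data_path_obs, data_path)

# 2. Run makedata.py / map_word_to_index.py / analysis_dataset.py against /cache/data.

# 3. Push the processed datasets back to the layout the OCR code expects.
file.copy_parallel('/cache/data/dataset/train', 's3://my-bucket/ocr/data/dataset/train')
file.copy_parallel('/cache/data/dataset/test', 's3://my-bucket/ocr/data/dataset/test')
```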
--------------------------------------------------------------------------------
/ocr_densenet/code/ocr/densenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.utils.model_zoo as model_zoo
5 | from collections import OrderedDict
6 |
7 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
8 |
9 |
10 | model_urls = {
11 |     'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
12 |     'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
13 |     'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
14 |     'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
15 | }
16 |
17 |
18 | def densenet121(pretrained=False, small=0, **kwargs):
19 |     r"""Densenet-121 model from
20 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
21 |
22 |     Args:
23 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
24 |     """
25 |     model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), small=small,
26 |                      **kwargs)
27 |     if pretrained:
28 |         model.load_state_dict(model_zoo.load_url(model_urls['densenet121']))
29 |     return model
30 |
31 |
32 | def densenet169(pretrained=False, **kwargs):
33 |     r"""Densenet-169 model from
34 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
35 |
36 |     Args:
37 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
38 |     """
39 |     model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
40 |                      **kwargs)
41 |     if pretrained:
42 |         model.load_state_dict(model_zoo.load_url(model_urls['densenet169']))
43 |     return model
44 |
45 |
46 | def densenet201(pretrained=False, **kwargs):
47 |     r"""Densenet-201 model from
48 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
49 |
50 |     Args:
51 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
52 |     """
53 |     model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
54 |                      **kwargs)
55 |     if pretrained:
56 |         model.load_state_dict(model_zoo.load_url(model_urls['densenet201']))
57 |     return model
58 |
59 |
60 | def densenet161(pretrained=False, **kwargs):
61 |     r"""Densenet-161 model from
62 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
63 |
64 |     Args:
65 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
66 |     """
67 |     model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
68 |                      **kwargs)
69 |     if pretrained:
70 |         model.load_state_dict(model_zoo.load_url(model_urls['densenet161']))
71 |     return model
72 |
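# --- Illustrative sketch (editorial addition, not in the original file): how
# these factory functions are typically called in this repo. Note that
# DenseNet.forward below returns the feature maps directly (early return), so
# the call yields features rather than logits. Shapes here are made up.
#
#   model = densenet121(pretrained=False, small=1, num_classes=10)
#   feats = model(torch.randn(2, 3, 64, 256))   # 4-D feature-map tensor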
73 |
74 | class _DenseLayer(nn.Sequential):
75 |     def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
76 |         super(_DenseLayer, self).__init__()
77 |         self.add_module('norm_1', nn.BatchNorm2d(num_input_features)),
78 |         self.add_module('relu_1', nn.ReLU(inplace=True)),
79 |         self.add_module('conv_1', nn.Conv2d(num_input_features, bn_size *
80 |                         growth_rate, kernel_size=1, stride=1, bias=False)),
81 |         self.add_module('norm_2', nn.BatchNorm2d(bn_size * growth_rate)),
82 |         self.add_module('relu_2', nn.ReLU(inplace=True)),
83 |         self.add_module('conv_2', nn.Conv2d(bn_size * growth_rate, growth_rate,
84 |                         kernel_size=3, stride=1, padding=1, bias=False)),
85 |         self.drop_rate = drop_rate
86 |
87 |     def forward(self, x):
88 |         new_features = super(_DenseLayer, self).forward(x)
89 |         if self.drop_rate > 0:
90 |             new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
91 |         return torch.cat([x, new_features], 1)
92 |
93 |
94 | class _DenseBlock(nn.Sequential):
95 |     def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
96 |         super(_DenseBlock, self).__init__()
97 |         for i in range(num_layers):
98 |             layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
99 |             self.add_module('denselayer%d' % (i + 1), layer)
100 |
101 |
102 | class _Transition(nn.Sequential):
103 |     def __init__(self, num_input_features, num_output_features, use_pool):
104 |         super(_Transition, self).__init__()
105 |         self.add_module('norm', nn.BatchNorm2d(num_input_features))
106 |         self.add_module('relu', nn.ReLU(inplace=True))
107 |         self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
108 |                         kernel_size=1, stride=1, bias=False))
109 |         if use_pool:
110 |             self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
111 |
112 |
113 | class DenseNet(nn.Module):
114 |     r"""Densenet-BC model class, based on
115 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
116 |
117 |     Args:
118 |         growth_rate (int) - how many filters to add each layer (`k` in paper)
119 |         block_config (list of 4 ints) - how many layers in each pooling block
120 |         num_init_features (int) - the number of filters to learn in the first convolution layer
121 |         bn_size (int) - multiplicative factor for number of bottle neck layers
122 |           (i.e. bn_size * k features in the bottleneck layer)
123 |         drop_rate (float) - dropout rate after each dense layer
124 |         num_classes (int) - number of classification classes
125 |     """
126 |     def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), small=0,
127 |                  num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
128 |
129 |         super(DenseNet, self).__init__()
130 |
131 |         # First convolution
132 |         self.features = nn.Sequential(OrderedDict([
133 |             ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
134 |             ('norm0', nn.BatchNorm2d(num_init_features)),
135 |             ('relu0', nn.ReLU(inplace=True)),
136 |             ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
137 |         ]))
138 |
139 |         # Each denseblock
140 |         num_features = num_init_features
141 |         for i, num_layers in enumerate(block_config):
142 |             block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
143 |                                 bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
144 |             self.features.add_module('denseblock%d' % (i + 1), block)
145 |             num_features = num_features + num_layers * growth_rate
146 |             if i != len(block_config) - 1:
147 |                 if small and i > 0:
148 |                     use_pool = 0
149 |                 else:
150 |                     use_pool = 1
151 |                 trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2, use_pool=use_pool)
152 |                 self.features.add_module('transition%d' % (i + 1), trans)
153 |                 num_features = num_features // 2
154 |
155 |         # Final batch norm
156 |         self.features.add_module('norm5', nn.BatchNorm2d(num_features))
157 |
158 |         # Linear layer
159 |         self.classifier = nn.Linear(num_features, num_classes)
160 |
161 |     def forward(self, x):
162 |         features = self.features(x)
163 |         return features  # early return: this repo consumes the feature maps, so the head below is unreachable
164 |         att_feats = features
165 |         out = F.relu(features, inplace=True)
166 |         out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1)
167 |         # out = F.avg_pool2d(out, kernel_size=3, stride=1).view(features.size(0), -1)
168 |         fc_feats = out
169 |         out = self.classifier(out)
170 |         return att_feats, fc_feats, out
171 |
--------------------------------------------------------------------------------
/ocr_densenet/code/ocr/resnet.py:
--------------------------------------------------------------------------------
1 | # Implementation of https://arxiv.org/pdf/1512.03385.pdf.
2 | # See section 4.2 for model architecture on CIFAR-10.
3 | # Parts of the code were adapted from the source below. 
4 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 5 | import torch 6 | import torch.nn as nn 7 | import torchvision.datasets as dsets 8 | import torchvision.transforms as transforms 9 | from torch.autograd import Variable 10 | import torch.nn.functional as F 11 | 12 | # 3x3 Convolution 13 | def conv3x3(in_channels, out_channels, stride=1): 14 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, 15 | stride=stride, padding=1, bias=False) 16 | 17 | # Residual Block 18 | class ResidualBlock(nn.Module): 19 | def __init__(self, in_channels, out_channels, stride=1, downsample=None): 20 | super(ResidualBlock, self).__init__() 21 | self.conv1 = conv3x3(in_channels, out_channels, stride) 22 | self.bn1 = nn.BatchNorm2d(out_channels) 23 | self.relu = nn.ReLU(inplace=True) 24 | self.conv2 = conv3x3(out_channels, out_channels) 25 | self.bn2 = nn.BatchNorm2d(out_channels) 26 | self.downsample = downsample 27 | 28 | def forward(self, x): 29 | residual = x 30 | out = self.conv1(x) 31 | out = self.bn1(out) 32 | out = self.relu(out) 33 | out = self.conv2(out) 34 | out = self.bn2(out) 35 | if self.downsample: 36 | residual = self.downsample(x) 37 | out += residual 38 | out = self.relu(out) 39 | return out 40 | 41 | # ResNet Module 42 | class ResNet(nn.Module): 43 | def __init__(self, block=ResidualBlock, layers=[2,3], num_classes=10, args=None): 44 | super(ResNet, self).__init__() 45 | self.in_channels = 16 46 | self.conv = conv3x3(3, 16) 47 | self.bn = nn.BatchNorm2d(16) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.layer1 = self.make_layer(block, 32, layers[0], 2) 50 | self.layer2 = self.make_layer(block, 64, layers[0], 2) 51 | self.layer3 = self.make_layer(block, 128, layers[0], 2) 52 | self.layer4 = self.make_layer(block, 128, layers[0], 2) 53 | self.layer5 = self.make_layer(block, 128, layers[0], 2) 54 | self.fc = nn.Linear(128, num_classes) 55 | 56 | # detect 57 | self.convt1 = nn.Sequential( 58 | nn.ConvTranspose2d(128,128,kernel_size=2, stride=2), 59 | nn.BatchNorm2d(128), 60 | nn.ReLU(inplace=True)) 61 | self.convt2 = nn.Sequential( 62 | nn.ConvTranspose2d(128,128,kernel_size=2, stride=2), 63 | nn.BatchNorm2d(128), 64 | nn.ReLU(inplace=True)) 65 | self.convt3 = nn.Sequential( 66 | nn.ConvTranspose2d(128,128,kernel_size=2, stride=2), 67 | nn.BatchNorm2d(128), 68 | nn.ReLU(inplace=True)) 69 | self.convt4 = nn.Sequential( 70 | nn.ConvTranspose2d(128,128,kernel_size=2, stride=2), 71 | nn.BatchNorm2d(128), 72 | nn.ReLU(inplace=True)) 73 | self.in_channels = 256 74 | self.dec1 = self.make_layer(block, 128, layers[0]) 75 | self.in_channels = 256 76 | self.dec2 = self.make_layer(block, 128, layers[0]) 77 | self.in_channels = 192 78 | self.dec3 = self.make_layer(block, 128, layers[0]) 79 | self.in_channels = 160 80 | # self.dec4 = self.make_layer(block, 1, layers[0]) 81 | self.dec4 = nn.Sequential( 82 | nn.Conv2d(160, 256, kernel_size=3, padding=1), 83 | nn.BatchNorm2d(256), 84 | nn.ReLU(inplace=True), 85 | nn.Conv2d(256, 1, kernel_size=1, bias=True) 86 | ) 87 | self.in_channels = 256 88 | # self.dec2 = self.make_layer(block, 256, layers[0]) 89 | # self.output = conv3x3(256, 4 * len(args.anchors)) 90 | self.bbox = nn.Sequential( 91 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 92 | nn.BatchNorm2d(256), 93 | nn.ReLU(inplace=True), 94 | nn.Conv2d(256, 4 * len(args.anchors), kernel_size=1, bias=True) 95 | ) 96 | self.sigmoid = nn.Sigmoid() 97 | 98 | 99 | def make_layer(self, block, out_channels, blocks, stride=1): 100 | downsample = None 101 | if (stride != 1) or 
(self.in_channels != out_channels): 102 | downsample = nn.Sequential( 103 | conv3x3(self.in_channels, out_channels, stride=stride), 104 | nn.BatchNorm2d(out_channels)) 105 | layers = [] 106 | layers.append(block(self.in_channels, out_channels, stride, downsample)) 107 | self.in_channels = out_channels 108 | for i in range(1, blocks): 109 | layers.append(block(out_channels, out_channels)) 110 | return nn.Sequential(*layers) 111 | 112 | def forward(self, x, phase='train'): 113 | out = self.conv(x) 114 | # print out.size() 115 | out = self.bn(out) 116 | # print out.size() 117 | out = self.relu(out) 118 | # print out.size() 119 | out1 = self.layer1(out) # 64 120 | # print out1.size() 121 | out2 = self.layer2(out1) # 32 122 | # print out2.size() 123 | out3 = self.layer3(out2) # 16 124 | # print out3.size() 125 | out4 = self.layer4(out3) # 8 126 | # print out4.size() 127 | out5 = self.layer5(out4) # 4 128 | # print out5.size() 129 | 130 | # out = F.adaptive_max_pool2d(out5, output_size=(1,1)).view(out.size(0), -1) # 128 131 | # out = out.view(out.size(0), -1) 132 | 133 | if phase == 'seg': 134 | out = F.adaptive_max_pool2d(out5, output_size=(1,1)).view(out.size(0), -1) # 128 135 | out = self.fc(out) 136 | out = out.view(out.size(0), -1) 137 | else: 138 | out = F.max_pool2d(out5, 2) 139 | out_size = out.size() 140 | # out = out.view(out_size[0],out_size[1],out_size[3]).transpose(1,2).contiguous().view(-1, out_size[1]) 141 | out = out.view(out_size[0],out_size[1],out_size[2] * out_size[3]).transpose(1,2).contiguous().view(-1, out_size[1]) 142 | out = self.fc(out) 143 | out = out.view(out_size[0], out_size[2] * out_size[3], -1).transpose(1,2).contiguous() 144 | out = F.adaptive_max_pool1d(out, output_size=(1)).view(out_size[0], -1) 145 | 146 | # print out.size() 147 | if phase not in ['seg', 'pretrain', 'pretrain2']: 148 | return out 149 | 150 | # detect 151 | cat1 = torch.cat([self.convt1(out5), out4], 1) 152 | # print cat1.size() 153 | dec1 = self.dec1(cat1) 154 | # print dec1.size() 155 | # print out3.size() 156 | cat2 = torch.cat([self.convt2(dec1), out3], 1) 157 | # print cat2.size() 158 | dec2 = self.dec2(cat2) 159 | cat3 = torch.cat([self.convt3(dec2), out2], 1) 160 | dec3 = self.dec3(cat3) 161 | cat4 = torch.cat([self.convt4(dec3), out1], 1) 162 | seg = self.dec4(cat4) 163 | seg = seg.view((seg.size(0), seg.size(2), seg.size(3))) 164 | seg = self.sigmoid(seg) 165 | 166 | bbox = self.bbox(cat2) 167 | # dec2 = self.output(dec2) 168 | # print dec2.size() 169 | size = bbox.size() 170 | bbox = bbox.view((size[0], size[1], -1)).transpose(1,2).contiguous() 171 | bbox = bbox.view((size[0], size[2],size[3],-1, 4)) 172 | 173 | return out, bbox, seg 174 | 175 | # resnet = ResNet(ResidualBlock, [2, 2, 2, 2]) 176 | -------------------------------------------------------------------------------- /ocr_densenet/code/ocr/tools/measures.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | import os 3 | import numpy as np 4 | from sklearn import metrics 5 | from PIL import Image 6 | import traceback 7 | 8 | def stati_class_number_true_flase(label, pred): 9 | label = np.array(label) 10 | pred = np.array(pred) 11 | 12 | cls_list = set(label) | set(pred) 13 | d = dict() 14 | for cls in cls_list: 15 | d[cls] = dict() 16 | d[cls]['number'] = np.sum(label==cls) 17 | d[cls]['true'] = np.sum(label[label==cls]==pred[label==cls]) 18 | d[cls]['pred'] = np.sum(pred==cls) 19 | return d 20 | 21 | def stati_class_number_true_flase_multi_label_margin(labels, preds): 22 | 
23 |     d = dict()
24 |     for label, pred in zip(labels, preds):
25 |         label = set(label[label>=0])
26 |         for cls in range(len(pred)):
27 |             if cls not in d:
28 |                 d[cls] = dict()
29 |                 d[cls]['number'] = 0
30 |                 d[cls]['true'] = 0
31 |                 d[cls]['pred'] = 0
32 |             if cls in label:
33 |                 d[cls]['number'] += 1
34 |                 if pred[cls] > 0.5:
35 |                     d[cls]['true'] += 1
36 |             if pred[cls] > 0.5:
37 |                 d[cls]['pred'] += 1
38 |     return d
39 |
40 | def stati_class_number_true_flase_bce(labels, preds):
41 |     d = dict()
42 |     labels = labels.astype(np.int64).reshape(-1)
43 |     preds = preds.reshape(-1) > 0
44 |     index = labels >= 0
45 |     labels = labels[index]
46 |     preds = preds[index]
47 |
48 |     preds_num = preds.sum(0)
49 |     true_num = (labels+preds==2).sum(0)
50 |     for cls in range(2):
51 |         d[cls] = dict()
52 |         d[cls]['number'] = (labels==cls).sum()
53 |         d[cls]['true'] = (labels+preds==2*cls).sum()
54 |         d[cls]['pred'] = (preds==cls).sum()  # fixed: count predictions (was counting labels again)
55 |     return d
56 |
57 | def measures(d_list):
58 |     # merge the statistics from every prediction
59 |     d_all = dict()
60 |     for d in d_list:
61 |         for cls in d.keys():
62 |             if cls not in d_all:
63 |                 d_all[cls] = dict()
64 |             for k in d[cls].keys():
65 |                 if k not in d_all[cls]:
66 |                     d_all[cls][k] = 0
67 |                 d_all[cls][k] += d[cls][k]
68 |     m = dict()
69 |     number = sum([d_all[cls]['number'] for cls in d_all.keys()])
70 |     for cls in d_all:
71 |         m[cls] = dict()
72 |         m[cls]['number'] = d_all[cls]['number']
73 |         m[cls]['true'] = d_all[cls]['true']
74 |         m[cls]['pred'] = d_all[cls]['pred']
75 |         m[cls]['ratio'] = d_all[cls]['number'] / (float(number) + 10e-10)
76 |         m[cls]['accuracy'] = d_all[cls]['true'] / (float(d_all[cls]['number']) + 10e-10)
77 |         m[cls]['precision'] = d_all[cls]['true'] /(float(d_all[cls]['pred']) + 10e-10)
78 |     return m
79 |
80 | def print_measures(m, s = 'measures'):
81 |     print s
82 |     accuracy = 0
83 |     for cls in sorted(m.keys()):
84 |         print '\tclass: {:d}\taccuracy:{:.6f}\tprecision:{:.6f}\tratio:{:.6f}\t\tN/T/P:{:d}/{:d}/{:d}\
85 |                 '.format(cls, m[cls]['accuracy'],m[cls]['precision'],m[cls]['ratio'],m[cls]['number'],m[cls]['true'],m[cls]['pred'])
86 |         accuracy += m[cls]['accuracy'] * m[cls]['ratio']
87 |     print '\tacc:{:.6f}'.format(accuracy)
88 |     return accuracy
89 |
90 | def mse(pred_image, image):
91 |     pred_image = pred_image.reshape(-1).astype(np.float32)
92 |     image = image.reshape(-1).astype(np.float32)
93 |     mse_err = metrics.mean_squared_error(pred_image,image)
94 |     return mse_err
95 |
96 | def psnr(pred_image, image):
97 |     return 10 * np.log10(255*255/mse(pred_image,image))
98 |
99 |
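# --- Illustrative note (editorial addition, not in the original file): a
# worked example of the PSNR formula above. For a mean squared error of 25,
#   psnr = 10 * log10(255 * 255 / 25) ~= 34.15 dB.
# Identical images give mse = 0, so callers should guard against division by
# zero before calling psnr().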
100 | def psnr_pred(stain_vis=20, end= 10000):
101 |     clean_dir = '../../data/AI/testB/'
102 |     psnr_list = []
103 |     f = open('../../data/result.csv','w')
104 |     for i,clean in enumerate(os.listdir(clean_dir)):
105 |         clean = os.path.join(clean_dir, clean)
106 |         clean_file = clean
107 |         pred = clean.replace('.jpg','.png').replace('data','data/test_clean')
108 |         stain = clean.replace('trainB','trainA').replace('testB','testA').replace('.jpg','_.jpg')
109 |
110 |         try:
111 |             pred = np.array(Image.open(pred).resize((250,250))).astype(np.float32)
112 |             clean = np.array(Image.open(clean).resize((250,250))).astype(np.float32)
113 |             stain = np.array(Image.open(stain).resize((250,250))).astype(np.float32)
114 |
115 |             # diff = np.abs(stain - pred)
116 |             # vis = 20
117 |             # pred[diff<vis] = stain[diff<vis]
118 |             # gray_vis = 100
119 |             # pred[stain<gray_vis] = stain[stain<gray_vis]
120 |             # pred[stain>gray_vis] = stain[stain>gray_vis]
121 |
122 |             if end < 1000:
123 |                 diff = np.abs(clean - stain)
124 |                 # stain[diff>stain_vis] = pred[diff>stain_vis]
125 |                 stain[diff>stain_vis] = clean[diff>stain_vis]
126 |
127 |             psnr_pred = psnr(clean, pred)
128 |             psnr_stain = psnr(clean, stain)
129 |             psnr_list.append([psnr_stain, psnr_pred])
130 |         except:
131 |             continue
132 |         if i>end:
133 |             break
134 |         print i, min(end, 1000)
135 |
136 |         f.write(clean_file.split('/')[-1].split('.')[0])
137 |         f.write(',')
138 |         f.write(str(psnr_stain))
139 |         f.write(',')
140 |         f.write(str(psnr_pred))
141 |         f.write(',')
142 |         f.write(str(psnr_pred/psnr_stain - 1))
143 |         f.write('\n')
144 |     # print 'prediction', np.mean(psnr_list)
145 |     psnr_list = np.array(psnr_list)
146 |     psnr_mean = ((psnr_list[:,1] - psnr_list[:,0]) / psnr_list[:,0]).mean()
147 |     if end > 1000:
148 |         print 'stained-image PSNR', psnr_list[:,0].mean()
149 |         print 'predicted-image PSNR', psnr_list[:,1].mean()
150 |         print 'gain ratio', psnr_mean
151 |     f.write(str(psnr_mean))
152 |     f.close()
153 |     return psnr_list[:,0].mean()
154 |
155 | def main():
156 |     pmax = [0.,0.]
157 |     for vis in range(1, 30):
158 |         p = psnr_pred(vis, 10)
159 |         print vis, p
160 |         if p > pmax[1]:
161 |             pmax = [vis, p]
162 |         print '...'
163 |     # print 256,psnr_pred(256)
164 |     print pmax
165 |     # print 10 * np.log10(255*255/metrics.mean_squared_error([3],[9]))
166 |
167 |
168 | if __name__ == '__main__':
169 |     psnr_pred(4000)
170 |     # main()
171 |     # for v in range(1,10):
172 |     #     print v, 10 * np.log10(255*255/v/v)
173 |
--------------------------------------------------------------------------------
/ocr_densenet/code/ocr/tools/parse.py:
--------------------------------------------------------------------------------
1 | # coding=utf8
2 |
3 | import argparse
4 |
5 | datadir = '/cache/ocr_densenet'
6 |
7 | parser = argparse.ArgumentParser(description='medical caption GAN')
8 |
9 | parser.add_argument(
10 |     '--model',
11 |     '-m',
12 |     type=str,
13 |     default='densenet',
14 |     help='model'
15 | )
16 | parser.add_argument(
17 |     '--data-dir',
18 |     '-d',
19 |     type=str,
20 |     default=datadir+'/data/dataset/',
21 |     help='data directory'
22 | )
23 | parser.add_argument(
24 |     '--data-dir-obs',
25 |     '-dd',
26 |     type=str,
27 |     default=datadir+'/data/dataset/',
28 |     help='data directory'
29 | )
30 | parser.add_argument(
31 |     '--save-dir-obs',
32 |     type=str,
33 |     default=datadir+'/data',
34 |     help='data directory'
35 | )
36 | parser.add_argument(
37 |     '--bg-dir',
38 |     type=str,
39 |     default=datadir+'/data/images',
40 |     help='background images directory'
41 | )
42 | parser.add_argument(
43 |     '--hard-mining',
44 |     type=int,
45 |     default=0,
46 |     help='use hard mining'
47 | )
48 | parser.add_argument('--phase',
49 |                     default='train',
50 |                     type=str,
51 |                     metavar='S',
52 |                     help='pretrain/train/test phase')
53 | parser.add_argument(
54 |     '--batch-size',
55 |     '-b',
56 |     metavar='BATCH SIZE',
57 |     type=int,
58 |     default=16,
59 |     help='batch size'
60 | )
61 | parser.add_argument('--save-dir',
62 |                     default=datadir+'/data',
63 |                     type=str,
64 |                     metavar='S',
65 |                     help='save dir')
66 | parser.add_argument('--word-index-json',
67 |                     default=datadir+'/files/alphabet_index_dict.json',
68 |                     type=str,
69 |                     metavar='S',
70 |                     help='word index json')
71 | parser.add_argument('--black-json',
72 |                     default=datadir+'/files/black.json',
73 |                     type=str,
74 |                     metavar='S',
75 |                     help='black_list json')
76 | parser.add_argument('--image-hw-ratio-json',
77 |                     default=datadir+'/files/image_hw_ratio_dict.json',
78 |                     type=str,
79 |                     metavar='S',
80 |                     help='image h:w ratio dict')
81 | parser.add_argument('--word-count-json',
82 |                     default=datadir+'/files/alphabet_count_dict.json',
83 |                     type=str,
84 |                     metavar='S',
85 |                     help='word count file')
86 | parser.add_argument('--image-label-json',
87 |                     default=datadir+'/files/train_alphabet.json',
88 |                     type=str,
89 |                     metavar='S',
90 |                     help='image label json')
91 | parser.add_argument('--resume',
92 |                     default='',
93 |                     type=str,
94 |                     metavar='S',
95 |                     help='start from checkpoints')
96 | parser.add_argument('--no-aug',
97 |                     default=0,
98 |                     type=int,
99 |                     metavar='S',
100 |                     help='no augmentation')
101 | parser.add_argument('--small',
102 |                     default=1,
103 |                     type=int,
104 |                     metavar='S',
105 |                     help='small fonts')
106 | parser.add_argument('--difficult',
107 |                     default=0,
108 |                     type=int,
109 |                     metavar='S',
110 |                     help='only evaluate the harder images')
111 | parser.add_argument('--hist',
112 |                     default=0,
113 |                     type=int,
114 |                     metavar='S',
115 |                     help='apply histogram equalization')
116 | parser.add_argument('--feat',
117 |                     default=0,
118 |                     type=int,
119 |                     metavar='S',
120 |                     help='generate features for the LSTM')
121 | #parser.add_argument('--result-dir',
122 | #                    default='/home/tcd/train_dir/ocr_densenet/',
123 | #                    type=int,
124 | #                    metavar='S',
125 | #                    help='generate features for the LSTM')
126 |
127 | #####
128 | parser.add_argument('-j',
129 |                     '--workers',
130 |                     default=8,
131 |                     type=int,
132 |                     metavar='N',
133 |                     help='number of data loading workers (default: 32)')
134 | parser.add_argument('--lr',
135 |                     '--learning-rate',
136 |                     default=0.001,
137 |                     type=float,
138 |                     metavar='LR',
139 |                     help='initial learning rate')
140 | parser.add_argument('--epochs',
141 |                     default=10000,
142 |                     type=int,
143 |                     metavar='N',
144 |                     help='number of total epochs to run')
145 | parser.add_argument('--save-freq',
146 |                     default='5',
147 |                     type=int,
148 |                     metavar='S',
149 |                     help='save frequency')
150 | parser.add_argument('--save-pred-freq',
151 |                     default='10',
152 |                     type=int,
153 |                     metavar='S',
154 |                     help='save pred clean frequency')
155 | parser.add_argument('--val-freq',
156 |                     default='5',
157 |                     type=int,
158 |                     metavar='S',
159 |                     help='val frequency')
160 | parser.add_argument('--debug',
161 |                     default=0,
162 |                     type=int,
163 |                     metavar='S',
164 |                     help='debug')
165 | parser.add_argument('--input-filter',
166 |                     default=7,
167 |                     type=int,
168 |                     metavar='S',
169 |                     help='input filter size')
170 | parser.add_argument('--use-gan',
171 |                     default=0,
172 |                     type=int,
173 |                     metavar='S',
174 |                     help='use GAN')
175 | parser.add_argument('--write-pred',
176 |                     default=0,
177 |                     type=int,
178 |                     metavar='S',
179 |                     help='write predictions')
180 | parser.add_argument(
181 |     '--result-file',
182 |     '-r',
183 |     type=str,
184 |     default=datadir+'/data/result/test_result.csv',
185 |     help='result file'
186 | )
187 | parser.add_argument(
188 |     '--output-file',
189 |     '-o',
190 |     type=str,
191 |     default=datadir+'/data/result/test.csv',
192 |     help='output file'
193 | )
194 | args, _ = parser.parse_known_args()
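# --- Illustrative note (editorial addition, not in the original file): why
# parse_known_args is used above. ModelArts injects extra CLI flags into every
# job (e.g. data_url / train_url), and parse_args would abort on them, while
# parse_known_args returns them separately:
#
#   args, unknown = parser.parse_known_args(
#       ['--phase', 'test', '--train_url', 's3://bucket/out'])
#   # args.phase == 'test'; unknown == ['--train_url', 's3://bucket/out']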
--------------------------------------------------------------------------------
/ocr_densenet/code/ocr/tools/plot.py:
--------------------------------------------------------------------------------
1 | # coding=utf8
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 |
5 | def plot_multi_graph(image_list, name_list, save_path=None, show=False):
6 |     graph_place = int(np.sqrt(len(name_list) - 1)) + 1
7 |     for i, (image, name) in enumerate(zip(image_list, name_list)):
8 |         ax1 = plt.subplot(graph_place,graph_place,i+1)
9 |         ax1.set_title(name)
10 |         # plt.imshow(image,cmap='gray')
11 |         plt.imshow(image)
12 |         plt.axis('off')
13 |     if save_path:
14 |         plt.savefig(save_path)
15 |         pass
16 |     if show:
17 |         plt.show()
18 |
19 | def plot_multi_line(x_list, y_list, name_list, save_path=None, show=False):
20 |     graph_place = int(np.sqrt(len(name_list) - 1)) + 1
21 |     for i, (x, y, name) in enumerate(zip(x_list, y_list, name_list)):
22 |         ax1 = plt.subplot(graph_place,graph_place,i+1)
23 |         ax1.set_title(name)
24 |         plt.plot(x,y)
25 |         # plt.imshow(image,cmap='gray')
26 |     if save_path:
27 |         plt.savefig(save_path)
28 |     if show:
29 |         plt.show()
30 |
31 |
--------------------------------------------------------------------------------
/ocr_densenet/code/ocr/tools/py_op.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Utility wrappers around commonly used Python functions.
4 | """
5 | import os
6 | import json
7 | import traceback
8 | from collections import OrderedDict
9 | import random
10 | from fuzzywuzzy import fuzz
11 |
12 | import sys
13 | reload(sys)
14 | sys.setdefaultencoding('utf-8')
15 |
16 | ################################################################################
17 | ### pre define variables
18 | #:: enumerate
19 | #:: raw_input
20 | #:: listdir
21 | #:: sorted
22 | ### pre define function
23 | def mywritejson(save_path,content):
24 |     content = json.dumps(content,indent=4,ensure_ascii=False)
25 |     with open(save_path,'w') as f:
26 |         f.write(content)
27 |
28 | def myreadjson(load_path):
29 |     with open(load_path,'r') as f:
30 |         return json.loads(f.read())
31 |
32 | def mywritefile(save_path,content):
33 |     with open(save_path,'w') as f:
34 |         f.write(content)
35 |
36 | def myreadfile(load_path):
37 |     with open(load_path,'r') as f:
38 |         return f.read()
39 |
40 | def myprint(content):
41 |     print json.dumps(content,indent=4,ensure_ascii=False)
42 |
43 | def rm(fi):
44 |     os.system('rm ' + fi)
45 |
46 | def mystrip(s):
47 |     return ''.join(s.split())
48 |
49 | def mysorteddict(d,key = lambda s:s, reverse=False):
50 |     dordered = OrderedDict()
51 |     for k in sorted(d.keys(),key = key,reverse=reverse):
52 |         dordered[k] = d[k]
53 |     return dordered
54 |
55 | def mysorteddictfile(src,obj):
56 |     mywritejson(obj,mysorteddict(myreadjson(src)))
57 |
58 | def myfuzzymatch(srcs,objs,grade=80):
59 |     matchDict = OrderedDict()
60 |     for src in srcs:
61 |         for obj in objs:
62 |             value = fuzz.partial_ratio(src,obj)
63 |             if value > grade:
64 |                 try:
65 |                     matchDict[src].append(obj)
66 |                 except KeyError:
67 |                     matchDict[src] = [obj]
68 |     return matchDict
69 |
70 | def mydumps(x):
71 |     return json.dumps(x,indent=4,ensure_ascii=False)  # fixed: previously dumped an undefined name `content`
72 |
73 | def get_random_list(l,num=-1,isunique=0):
74 |     if isunique:
75 |         l = set(l)
76 |     if num < 0:
77 |         num = len(l)
78 |     if isunique and num > len(l):
79 |         return
80 |     lnew = []
81 |     l = list(l)
82 |     while(num>len(lnew)):
83 |         x = l[int(random.random()*len(l))]
84 |         if isunique and x in lnew:
85 |             continue
86 |         lnew.append(x)
87 |     return lnew
88 |
89 | def fuzz_list(node1_list,node2_list,score_baseline=66,proposal_num=10,string_map=None):
90 |     node_dict = { }
91 |     for i,node1 in enumerate(node1_list):
92 |         match_score_dict = { }
93 |         for node2 in node2_list:
94 |             if node1 != node2:
95 |                 if string_map is not None:
96 |                     n1 = string_map(node1)
97 |                     n2 = string_map(node2)
98 |                     score = fuzz.partial_ratio(n1,n2)
99 |                     if n1 == n2:
100 |                         node2_list.remove(node2)
101 |                 else:
102 |                     score = fuzz.partial_ratio(node1,node2)
103 |                 if score > score_baseline:
104 |                     match_score_dict[node2] = score
105 |             else:
106 |                 node2_list.remove(node2)
107 |         node2_sort = sorted(match_score_dict.keys(), key=lambda k:match_score_dict[k],reverse=True)
108 |         node_dict[node1] = [[n,match_score_dict[n]] for n in node2_sort[:proposal_num]]
109 |         print i,len(node1_list)
110 |     return node_dict, node2_list
111 |
112 | def swap(a,b):
113 |     return b, a
114 |
115 | def mkdir(d):
116 |     path = d.split('/')
117 |     for i in range(len(path)):
118 |         d = '/'.join(path[:i+1])
119 |         if not os.path.exists(d):
120 |             os.mkdir(d)
121 |
--------------------------------------------------------------------------------
/ocr_densenet/code/ocr/tools/segmentation.py:
--------------------------------------------------------------------------------
1 | # coding=utf8
2 | import matplotlib.pyplot as plt
3 | from scipy import ndimage as ndi
4 | from skimage import morphology,color,data
5 | from skimage import filters
6 | import numpy as np
7 | import skimage
8 | import os
9 | from skimage import measure
10 |
11 |
12 |
13 | def watershed(image, label=None):
14 |     denoised = filters.rank.median(image, morphology.disk(2))  # filter out noise
15 |     # use pixels whose gradient is below 10 as the initial markers
16 |     markers = filters.rank.gradient(denoised, morphology.disk(5)) < 10
17 |     markers = ndi.label(markers)[0]
18 |
19 |     gradient = filters.rank.gradient(denoised, morphology.disk(2))  # compute the gradient
20 |     labels = morphology.watershed(gradient, markers, mask=image)  # gradient-based watershed
21 |
22 |     fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(6, 6))
23 |     axes = axes.ravel()
24 |     ax0, ax1, ax2, ax3 = axes
25 |
26 |     ax0.imshow(image, cmap=plt.cm.gray, interpolation='nearest')
27 |     ax0.set_title("Original")
28 |     # ax1.imshow(gradient, cmap=plt.cm.spectral, interpolation='nearest')
29 |     ax1.imshow(gradient, cmap=plt.cm.gray, interpolation='nearest')
30 |     ax1.set_title("Gradient")
31 |     if label is not None:
32 |         # ax2.imshow(markers, cmap=plt.cm.spectral, interpolation='nearest')
33 |         ax2.imshow(label, cmap=plt.cm.gray, interpolation='nearest')
34 |     else:
35 |         ax2.imshow(markers, cmap=plt.cm.spectral, interpolation='nearest')
36 |     ax2.set_title("Markers")
37 |     ax3.imshow(labels, cmap=plt.cm.spectral, interpolation='nearest')
38 |     ax3.set_title("Segmented")
39 |
40 |     for ax in axes:
41 |         ax.axis('off')
42 |
43 |     fig.tight_layout()
44 |     plt.show()
45 |
46 | def plot_4(image, gradient,label,segmentation, save_path=None):
47 |     fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(6, 6))
48 |     axes = axes.ravel()
49 |     ax0, ax1, ax2, ax3 = axes
50 |     ax0.imshow(image, cmap=plt.cm.gray, interpolation='nearest')
51 |     ax0.set_title("Original")
52 |     ax1.imshow(gradient, cmap=plt.cm.gray, interpolation='nearest')
53 |     ax1.set_title("Gradient")
54 |     ax2.imshow(label, cmap=plt.cm.gray, interpolation='nearest')
55 |     ax2.set_title("label")
56 |     ax3.imshow(segmentation, cmap=plt.cm.spectral, interpolation='nearest')
57 |     ax3.set_title("Segmented")
58 |
59 |     for ax in axes:
60 |         ax.axis('off')
61 |
62 |     fig.tight_layout()
63 |     if save_path:
64 |         print save_path
65 |         plt.savefig(save_path)
66 |     else:
67 |         plt.show()
68 |
69 | def fill(image):
70 |     '''
71 |     Fill the blank regions inside the image.
72 |     A hastily written temporary function;
73 |     consider replacing it later.
74 |     '''
75 |     label_img = measure.label(image, background=1)
76 |     props = measure.regionprops(label_img)
77 |     max_area = np.array([p.area for p in props]).max()
78 |     for i,prop in enumerate(props):
79 |         if prop.area < max_area:
80 |             image[prop.coords[:,0],prop.coords[:,1]] = 1
81 |     return image
82 |
83 |
84 |
85 | def my_watershed(image, label=None, min_gray=480, max_gray=708, min_gradient=5, show=False, save_path='/tmp/x.jpg'):
86 |     image = image - min_gray
87 |     image[image>max_gray] = 0
88 |     image[image< 10] = 0
89 |     image = image * 5
90 |
91 |     denoised = filters.rank.median(image, morphology.disk(2))  # filter out noise
92 |     # use pixels whose gradient is below 10 as the initial markers
93 |     markers = filters.rank.gradient(denoised, morphology.disk(5)) < 10
94 |     markers = ndi.label(markers)[0]
95 |
96 |     gradient = filters.rank.gradient(denoised, morphology.disk(2))  # compute the gradient
97 |     labels = gradient > min_gradient
98 |
99 |     mask = gradient > min_gradient
100 |     label_img = measure.label(mask, background=0)
101 |     props = measure.regionprops(label_img)
102 |     pred = np.zeros_like(gradient)
103 |     for i,prop in enumerate(props):
104 |         if prop.area > 50:
105 |             region = np.array(prop.coords)
106 |             vx,vy = region.var(0)
107 |             v = vx + vy
108 |             if v < 200:
109 |                 pred[prop.coords[:,0],prop.coords[:,1]] = 1
110 |
111 |     # fill the blank areas enclosed by the detected edges
112 |     pred = fill(pred)
113 |
114 |     if show:
115 |         plot_4(image, gradient, label, pred)
116 |     else:
117 |         plot_4(image, gradient, label, pred, save_path)
118 |
119 |     return pred
120 |
121 | def segmentation(image_npy, label_npy,save_path):
122 |     print image_npy
123 |     image = np.load(image_npy)
124 |     label = np.load(label_npy)
125 |     if np.sum(label) == 0:
126 |         return
127 |     min_gray,max_gray = 480, 708
128 |     my_watershed(image,label,min_gray, max_gray,show=False, save_path=save_path)
129 |
130 | def main():
131 |     data_dir = '/home/yin/all/PVL_DATA/preprocessed/2D/'
132 |     save_dir = '/home/yin/all/PVL_DATA/tool_result/'
133 |     os.system('rm -r ' + save_dir)
134 |     os.system('mkdir ' + save_dir)
135 |     for patient in os.listdir(data_dir):
136 |         patient_dir = os.path.join(data_dir, patient)
137 |         for f in os.listdir(patient_dir):
138 |             if 'roi.npy' in f:
139 |                 label_npy = os.path.join(patient_dir,f)
140 |                 image_npy = label_npy.replace('.roi.npy','.npy')
141 |                 segmentation(image_npy,label_npy, os.path.join(save_dir,label_npy.strip('/').replace('/','.').replace('npy','jpg')))
142 |
143 | if __name__ == '__main__':
144 |     # image =color.rgb2gray(data.camera())
145 |     # watershed(image)
146 |     main()
147 |     image_npy = '/home/yin/all/PVL_DATA/preprocessed/2D/JD_chen_xi/23.npy'
148 |     image_npy = '/home/yin/all/PVL_DATA/preprocessed/2D/JD_chen_xi/14.npy'
149 |     image_npy = '/home/yin/all/PVL_DATA/preprocessed/2D/JD_zhang_yu_chen/23.npy'
150 |     label_npy = image_npy.replace('.npy','.roi.npy')
151 |     segmentation(image_npy,label_npy)
152 |
153 |
--------------------------------------------------------------------------------
/ocr_densenet/code/ocr/tools/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (c) 2017 www.drcubic.com, Inc. All Rights Reserved
5 | #
6 | """
7 | File: utils.py
8 | Author: shileicao(shileicao@stu.xjtu.edu.cn)
9 | Date: 2017-06-20 14:56:54
10 |
11 | **Note.** This code absorbs some code from the following source.
12 | 1. 
[DSB2017](https://github.com/lfz/DSB2017)
13 | """
14 |
15 | import os
16 | import sys
17 |
18 | import numpy as np
19 | import torch
20 |
21 |
22 | def getFreeId():
23 |     import pynvml
24 |
25 |     pynvml.nvmlInit()
26 |
27 |     def getFreeRatio(id):
28 |         handle = pynvml.nvmlDeviceGetHandleByIndex(id)
29 |         use = pynvml.nvmlDeviceGetUtilizationRates(handle)
30 |         ratio = 0.5 * (float(use.gpu + float(use.memory)))
31 |         return ratio
32 |
33 |     deviceCount = pynvml.nvmlDeviceGetCount()
34 |     available = []
35 |     for i in range(deviceCount):
36 |         if getFreeRatio(i) < 70:
37 |             available.append(i)
38 |     gpus = ''
39 |     for g in available:
40 |         gpus = gpus + str(g) + ','
41 |     gpus = gpus[:-1]
42 |     return gpus
43 |
44 |
45 | def setgpu(gpuinput):
46 |     freeids = getFreeId()
47 |     if gpuinput == 'all':
48 |         gpus = freeids
49 |     else:
50 |         gpus = gpuinput
51 |         busy_gpu = [g for g in gpus.split(',') if g not in freeids]  # fixed: keep the ids themselves, not booleans
52 |         if len(busy_gpu) > 0:
53 |             raise ValueError('gpu ' + ','.join(busy_gpu) + ' is being used')
54 |     print('using gpu ' + gpus)
55 |     os.environ['CUDA_VISIBLE_DEVICES'] = gpus
56 |     return len(gpus.split(','))
57 |
58 |
59 | def error_mask_stats(labels, filenames):
60 |     error_f = []
61 |     for i, f in enumerate(filenames):
62 |         # if not np.all(labels[i] > 0):
63 |         #     error_f.append(f)
64 |         for bbox_i in range(labels[i].shape[0]):
65 |             imgs = np.load(f)
66 |             if not np.all(
67 |                     np.array(imgs.shape[1:]) - labels[i][bbox_i][:-1] > 0):
68 |                 error_f.append(f)
69 |     error_f = list(set(error_f))
70 |     fileid_list = [os.path.split(filename)[1].split('_')[0]
71 |                    for filename in error_f]
72 |     print("','".join(fileid_list))
73 |     return error_f
74 |
75 |
76 | class Logger(object):
77 |     def __init__(self, logfile):
78 |         self.terminal = sys.stdout
79 |         self.log = open(logfile, "a")
80 |
81 |     def write(self, message):
82 |         self.terminal.write(message)
83 |         self.log.write(message)
84 |
85 |     def flush(self):
86 |         #this flush method is needed for python 3 compatibility.
87 |         #this handles the flush command by doing nothing.
88 |         #you might want to specify some extra behavior here. 
89 | pass 90 | 91 | 92 | def split4(data, max_stride, margin): 93 | splits = [] 94 | data = torch.Tensor.numpy(data) 95 | _, c, z, h, w = data.shape 96 | 97 | w_width = np.ceil(float(w / 2 + margin) / 98 | max_stride).astype('int') * max_stride 99 | h_width = np.ceil(float(h / 2 + margin) / 100 | max_stride).astype('int') * max_stride 101 | pad = int(np.ceil(float(z) / max_stride) * max_stride) - z 102 | leftpad = pad / 2 103 | pad = [[0, 0], [0, 0], [leftpad, pad - leftpad], [0, 0], [0, 0]] 104 | data = np.pad(data, pad, 'constant', constant_values=-1) 105 | data = torch.from_numpy(data) 106 | splits.append(data[:, :, :, :h_width, :w_width]) 107 | splits.append(data[:, :, :, :h_width, -w_width:]) 108 | splits.append(data[:, :, :, -h_width:, :w_width]) 109 | splits.append(data[:, :, :, -h_width:, -w_width:]) 110 | 111 | return torch.cat(splits, 0) 112 | 113 | 114 | def combine4(output, h, w): 115 | splits = [] 116 | for i in range(len(output)): 117 | splits.append(output[i]) 118 | 119 | output = np.zeros( 120 | (splits[0].shape[0], h, w, splits[0].shape[3], 121 | splits[0].shape[4]), np.float32) 122 | 123 | h0 = output.shape[1] / 2 124 | h1 = output.shape[1] - h0 125 | w0 = output.shape[2] / 2 126 | w1 = output.shape[2] - w0 127 | 128 | splits[0] = splits[0][:, :h0, :w0, :, :] 129 | output[:, :h0, :w0, :, :] = splits[0] 130 | 131 | splits[1] = splits[1][:, :h0, -w1:, :, :] 132 | output[:, :h0, -w1:, :, :] = splits[1] 133 | 134 | splits[2] = splits[2][:, -h1:, :w0, :, :] 135 | output[:, -h1:, :w0, :, :] = splits[2] 136 | 137 | splits[3] = splits[3][:, -h1:, -w1:, :, :] 138 | output[:, -h1:, -w1:, :, :] = splits[3] 139 | 140 | return output 141 | 142 | 143 | def split8(data, max_stride, margin): 144 | splits = [] 145 | if isinstance(data, np.ndarray): 146 | c, z, h, w = data.shape 147 | else: 148 | _, c, z, h, w = data.size() 149 | 150 | z_width = np.ceil(float(z / 2 + margin) / 151 | max_stride).astype('int') * max_stride 152 | w_width = np.ceil(float(w / 2 + margin) / 153 | max_stride).astype('int') * max_stride 154 | h_width = np.ceil(float(h / 2 + margin) / 155 | max_stride).astype('int') * max_stride 156 | for zz in [[0, z_width], [-z_width, None]]: 157 | for hh in [[0, h_width], [-h_width, None]]: 158 | for ww in [[0, w_width], [-w_width, None]]: 159 | if isinstance(data, np.ndarray): 160 | splits.append(data[np.newaxis, :, zz[0]:zz[1], hh[0]:hh[1], 161 | ww[0]:ww[1]]) 162 | else: 163 | splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]: 164 | ww[1]]) 165 | 166 | if isinstance(data, np.ndarray): 167 | return np.concatenate(splits, 0) 168 | else: 169 | return torch.cat(splits, 0) 170 | 171 | 172 | def combine8(output, z, h, w): 173 | splits = [] 174 | for i in range(len(output)): 175 | splits.append(output[i]) 176 | 177 | output = np.zeros( 178 | (z, h, w, splits[0].shape[3], splits[0].shape[4]), np.float32) 179 | 180 | z_width = z / 2 181 | h_width = h / 2 182 | w_width = w / 2 183 | i = 0 184 | for zz in [[0, z_width], [z_width - z, None]]: 185 | for hh in [[0, h_width], [h_width - h, None]]: 186 | for ww in [[0, w_width], [w_width - w, None]]: 187 | output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[ 188 | i][zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] 189 | i = i + 1 190 | 191 | return output 192 | 193 | 194 | def split16(data, max_stride, margin): 195 | splits = [] 196 | _, c, z, h, w = data.size() 197 | 198 | z_width = np.ceil(float(z / 4 + margin) / 199 | max_stride).astype('int') * max_stride 200 | z_pos = [z * 3 / 8 - z_width / 2, z * 5 / 8 - z_width / 2] 201 | 
194 | def split16(data, max_stride, margin):
195 |     splits = []
196 |     _, c, z, h, w = data.size()
197 | 
198 |     z_width = np.ceil(float(z / 4 + margin) /
199 |                       max_stride).astype('int') * max_stride
200 |     z_pos = [z * 3 / 8 - z_width / 2, z * 5 / 8 - z_width / 2]
201 |     h_width = np.ceil(float(h / 2 + margin) /
202 |                       max_stride).astype('int') * max_stride
203 |     w_width = np.ceil(float(w / 2 + margin) /
204 |                       max_stride).astype('int') * max_stride
205 |     for zz in [[0, z_width], [z_pos[0], z_pos[0] + z_width],
206 |                [z_pos[1], z_pos[1] + z_width], [-z_width, None]]:
207 |         for hh in [[0, h_width], [-h_width, None]]:
208 |             for ww in [[0, w_width], [-w_width, None]]:
209 |                 splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[
210 |                     1]])
211 | 
212 |     return torch.cat(splits, 0)
213 | 
214 | 
215 | def combine16(output, z, h, w):
216 |     splits = []
217 |     for i in range(len(output)):
218 |         splits.append(output[i])
219 | 
220 |     output = np.zeros(
221 |         (z, h, w, splits[0].shape[3], splits[0].shape[4]), np.float32)
222 | 
223 |     z_width = z / 4
224 |     h_width = h / 2
225 |     w_width = w / 2
226 |     splitzstart = splits[0].shape[0] / 2 - z_width / 2
227 |     z_pos = [z * 3 / 8 - z_width / 2, z * 5 / 8 - z_width / 2]
228 |     i = 0
229 |     for zz, zz2 in zip(
230 |             [[0, z_width], [z_width, z_width * 2], [z_width * 2, z_width * 3],
231 |              [z_width * 3 - z, None]],
232 |             [[0, z_width], [splitzstart, z_width + splitzstart],
233 |              [splitzstart, z_width + splitzstart], [z_width * 3 - z, None]]):
234 |         for hh in [[0, h_width], [h_width - h, None]]:
235 |             for ww in [[0, w_width], [w_width - w, None]]:
236 |                 output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[
237 |                     i][zz2[0]:zz2[1], hh[0]:hh[1], ww[0]:ww[1], :, :]
238 |                 i = i + 1
239 | 
240 |     return output
241 | 
242 | 
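# --- Editor's note (not part of the original source) ---
# split32/split64 below follow the same pattern as split16: any axis cut into
# quarters gets two extra interior crops starting at its 3/8 and 5/8 positions
# (the *_pos lists), so neighbouring crops overlap; the matching combine
# functions copy only the non-overlapping core of each crop back into place
# (the zz2/hh2/ww2 index pairs).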
243 | def split32(data, max_stride, margin):
244 |     splits = []
245 |     _, c, z, h, w = data.size()
246 | 
247 |     z_width = np.ceil(float(z / 2 + margin) /
248 |                       max_stride).astype('int') * max_stride
249 |     w_width = np.ceil(float(w / 4 + margin) /
250 |                       max_stride).astype('int') * max_stride
251 |     h_width = np.ceil(float(h / 4 + margin) /
252 |                       max_stride).astype('int') * max_stride
253 | 
254 |     w_pos = [w * 3 / 8 - w_width / 2, w * 5 / 8 - w_width / 2]
255 |     h_pos = [h * 3 / 8 - h_width / 2, h * 5 / 8 - h_width / 2]
256 | 
257 |     for zz in [[0, z_width], [-z_width, None]]:
258 |         for hh in [[0, h_width], [h_pos[0], h_pos[0] + h_width],
259 |                    [h_pos[1], h_pos[1] + h_width], [-h_width, None]]:
260 |             for ww in [[0, w_width], [w_pos[0], w_pos[0] + w_width],
261 |                        [w_pos[1], w_pos[1] + w_width], [-w_width, None]]:
262 |                 splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[
263 |                     1]])
264 | 
265 |     return torch.cat(splits, 0)
266 | 
267 | 
268 | def combine32(splits, z, h, w):
269 | 
270 |     output = np.zeros(
271 |         (z, h, w, splits[0].shape[3], splits[0].shape[4]), np.float32)
272 | 
273 |     z_width = int(np.ceil(float(z) / 2))
274 |     h_width = int(np.ceil(float(h) / 4))
275 |     w_width = int(np.ceil(float(w) / 4))
276 |     splithstart = splits[0].shape[1] / 2 - h_width / 2
277 |     splitwstart = splits[0].shape[2] / 2 - w_width / 2
278 | 
279 |     i = 0
280 |     for zz in [[0, z_width], [z_width - z, None]]:
281 | 
282 |         for hh, hh2 in zip(
283 |                 [[0, h_width], [h_width, h_width * 2], [h_width * 2, h_width * 3],
284 |                  [h_width * 3 - h, None]],
285 |                 [[0, h_width], [splithstart, h_width + splithstart],
286 |                  [splithstart, h_width + splithstart], [h_width * 3 - h, None]]):
287 | 
288 |             for ww, ww2 in zip(
289 |                     [[0, w_width], [w_width, w_width * 2],
290 |                      [w_width * 2, w_width * 3], [w_width * 3 - w, None]],
291 |                     [[0, w_width], [splitwstart, w_width + splitwstart],
292 |                      [splitwstart, w_width + splitwstart],
293 |                      [w_width * 3 - w, None]]):
294 | 
295 |                 output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[
296 |                     i][zz[0]:zz[1], hh2[0]:hh2[1], ww2[0]:ww2[1], :, :]
297 |                 i = i + 1
298 | 
299 |     return output
300 | 
301 | 
302 | def split64(data, max_stride, margin):
303 |     splits = []
304 |     _, c, z, h, w = data.size()
305 | 
306 |     z_width = np.ceil(float(z / 4 + margin) /
307 |                       max_stride).astype('int') * max_stride
308 |     w_width = np.ceil(float(w / 4 + margin) /
309 |                       max_stride).astype('int') * max_stride
310 |     h_width = np.ceil(float(h / 4 + margin) /
311 |                       max_stride).astype('int') * max_stride
312 | 
313 |     z_pos = [z * 3 / 8 - z_width / 2, z * 5 / 8 - z_width / 2]
314 |     w_pos = [w * 3 / 8 - w_width / 2, w * 5 / 8 - w_width / 2]
315 |     h_pos = [h * 3 / 8 - h_width / 2, h * 5 / 8 - h_width / 2]
316 | 
317 |     for zz in [[0, z_width], [z_pos[0], z_pos[0] + z_width],
318 |                [z_pos[1], z_pos[1] + z_width], [-z_width, None]]:
319 |         for hh in [[0, h_width], [h_pos[0], h_pos[0] + h_width],
320 |                    [h_pos[1], h_pos[1] + h_width], [-h_width, None]]:
321 |             for ww in [[0, w_width], [w_pos[0], w_pos[0] + w_width],
322 |                        [w_pos[1], w_pos[1] + w_width], [-w_width, None]]:
323 |                 splits.append(data[:, :, zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[
324 |                     1]])
325 | 
326 |     return torch.cat(splits, 0)
327 | 
328 | 
329 | def combine64(output, z, h, w):
330 |     splits = []
331 |     for i in range(len(output)):
332 |         splits.append(output[i])
333 | 
334 |     output = np.zeros(
335 |         (z, h, w, splits[0].shape[3], splits[0].shape[4]), np.float32)
336 | 
337 |     z_width = int(np.ceil(float(z) / 4))
338 |     h_width = int(np.ceil(float(h) / 4))
339 |     w_width = int(np.ceil(float(w) / 4))
340 |     splitzstart = splits[0].shape[0] / 2 - z_width / 2
341 |     splithstart = splits[0].shape[1] / 2 - h_width / 2
342 |     splitwstart = splits[0].shape[2] / 2 - w_width / 2
343 | 
344 |     i = 0
345 |     for zz, zz2 in zip(
346 |             [[0, z_width], [z_width, z_width * 2], [z_width * 2, z_width * 3],
347 |              [z_width * 3 - z, None]],
348 |             [[0, z_width], [splitzstart, z_width + splitzstart],
349 |              [splitzstart, z_width + splitzstart], [z_width * 3 - z, None]]):
350 | 
351 |         for hh, hh2 in zip(
352 |                 [[0, h_width], [h_width, h_width * 2], [h_width * 2, h_width * 3],
353 |                  [h_width * 3 - h, None]],
354 |                 [[0, h_width], [splithstart, h_width + splithstart],
355 |                  [splithstart, h_width + splithstart], [h_width * 3 - h, None]]):
356 | 
357 |             for ww, ww2 in zip(
358 |                     [[0, w_width], [w_width, w_width * 2],
359 |                      [w_width * 2, w_width * 3], [w_width * 3 - w, None]],
360 |                     [[0, w_width], [splitwstart, w_width + splitwstart],
361 |                      [splitwstart, w_width + splitwstart],
362 |                      [w_width * 3 - w, None]]):
363 | 
364 |                 output[zz[0]:zz[1], hh[0]:hh[1], ww[0]:ww[1], :, :] = splits[
365 |                     i][zz2[0]:zz2[1], hh2[0]:hh2[1], ww2[0]:ww2[1], :, :]
366 |                 i = i + 1
367 | 
368 |     return output
369 | 
--------------------------------------------------------------------------------
/ocr_densenet/code/preprocessing/analysis_dataset.py:
--------------------------------------------------------------------------------
1 | # coding=utf8
2 | #########################################################################
3 | # File Name: analysis_dataset.py
4 | # Author: ccyin
5 | # mail: ccyin04@gmail.com
6 | # Created Time: Fri 18 May 2018 04:19:58 PM CST
7 | #########################################################################
8 | '''
9 | Analyzes the original dataset:
10 |     stati_image_size:   image-size statistics
11 |     stati_label_length: label (text) length statistics
12 | '''
13 | 
14 | import os
15 | import json
16 | from PIL import Image
17 | import numpy as np
18 | from tqdm import tqdm
19 | import sys
20 | sys.path.append('../ocr')
21 | from tools import plot
22 | 
23 | def stati_image_size(image_dir, save_dir, big_w_dir):
24 |     if not os.path.exists(big_w_dir):
25 |         os.mkdir(big_w_dir)
26 |     if not os.path.exists(save_dir):
27 |         os.mkdir(save_dir)
28 |     h_count_dict, w_count_dict, r_count_dict = { }, { }, { }
29 |     image_hw_ratio_dict = { }
30 |     for image in os.listdir(image_dir):
31 |         h, w = Image.open(os.path.join(image_dir, image)).size  # NOTE: PIL's .size is (width, height), so `h` here is the horizontal extent
32 |         if w > 80:
33 |             cmd = 'cp ../../data/train/{:s} {:s}'.format(image, big_w_dir)
34 |             # os.system(cmd)
35 | 
36 |         r = int(h / 8. / w)
37 |         h = h / 10
38 |         w = w / 10
39 |         r_count_dict[r] = r_count_dict.get(r, 0) + 1
40 |         h_count_dict[h] = h_count_dict.get(h, 0) + 1
41 |         w_count_dict[w] = w_count_dict.get(w, 0) + 1
42 |         image_hw_ratio_dict[image] = r
43 | 
44 |     with open(os.path.join(save_dir, 'image_hw_ratio_dict.json'), 'w') as f:
45 |         f.write(json.dumps(image_hw_ratio_dict, indent=4))
46 | 
47 |     x = range(max(h_count_dict.keys())+1)
48 |     y = [0 for _ in x]
49 |     for h in sorted(h_count_dict.keys()):
50 |         print 'Image length: {:d}~{:d}, {:d} images'.format(10*h, 10*h+10, h_count_dict[h])
51 |         y[h] = h_count_dict[h]
52 |     plot.plot_multi_line([x], [y], ['Length'], save_path='../../data/length.png', show=True)
53 | 
54 |     x = range(max(w_count_dict.keys())+1)
55 |     y = [0 for _ in x]
56 |     for w in sorted(w_count_dict.keys()):
57 |         print 'Image width: {:d}~{:d}, {:d} images'.format(10*w, 10*w+10, w_count_dict[w])
58 |         y[w] = w_count_dict[w]
59 |     plot.plot_multi_line([x], [y], ['Width'], save_path='../../data/width.png', show=True)
60 | 
61 |     x = range(max(r_count_dict.keys())+1)
62 |     y = [0 for _ in x]
63 |     for r in sorted(r_count_dict.keys()):
64 |         print 'Aspect ratio: {:d}~{:d}, {:d} images'.format(8*r, 8*r+8, r_count_dict[r])
65 |         y[r] = r_count_dict[r]
66 |     x = [8*(_+1) for _ in x]
67 |     plot.plot_multi_line([x], [y], ['L/W'], save_path='../../data/ratio.png', show=True)
68 | 
69 |     print '\nMost common length\n', sorted(h_count_dict.keys(), key=lambda h:h_count_dict[h])[-1] * 10
70 |     print '\nMost common width\n', sorted(w_count_dict.keys(), key=lambda w:w_count_dict[w])[-1] * 10
71 | 
72 |     print 'Suggestion: use 64 * 512 inputs,'
73 |     print '            64 * 1024 for some of the images,'
74 |     print '            and ignore the rest.'
75 |     print 'Suggestion: use an FCN and take the global max to obtain the final result.'
76 | 
77 | def stati_label_length(label_json, long_text_dir):
78 |     if not os.path.exists(long_text_dir):
79 |         os.mkdir(long_text_dir)
80 |     image_label_json = json.load(open(label_json))
81 |     l_count_dict = { }
82 |     for image, label in image_label_json.items():
83 |         l = len(label.split())
84 |         l_count_dict[l] = l_count_dict.get(l, 0) + 1
85 |         if l > 25:
86 |             cmd = 'cp ../../data/train/{:s} {:s}'.format(image, long_text_dir)
87 |             # os.system(cmd)
88 | 
89 |     word_num = 0.
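    # word_num accumulates the total character count over all labels so that
    # the average number of characters per image can be printed after the
    # histogram loop below.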
90 |     x = range(max(l_count_dict.keys())+1)
91 |     y = [0 for _ in x]
92 |     for l in sorted(l_count_dict.keys()):
93 |         word_num += l * l_count_dict[l]
94 |         print 'Text length: {:d}, {:d} images'.format(l, l_count_dict[l])
95 |         y[l] = l_count_dict[l]
96 |     plot.plot_multi_line([x], [y], ['Word Number'], save_path='../../data/word_num.png', show=True)
97 |     print 'On average {:3.4f} characters per image'.format(word_num / sum(l_count_dict.values()))
98 | 
99 | def stati_image_gray(image_dir):
100 |     print 'eval train image gray'
101 |     for image in tqdm(os.listdir(image_dir)):
102 |         image = Image.open(os.path.join(image_dir, image)).convert('RGB')
103 |         image = np.array(image)
104 |         mi,ma = image.min(), image.max()
105 |         assert mi >= 0
106 |         assert ma < 256
107 | 
108 |     print 'eval test image gray'
109 |     image_dir = image_dir.replace('train', 'test')
110 |     for image in tqdm(os.listdir(image_dir)):
111 |         image = Image.open(os.path.join(image_dir, image)).convert('RGB')
112 |         image = np.array(image)
113 |         mi,ma = image.min(), image.max()
114 |         assert mi >= 0
115 |         assert ma < 256
116 | 
117 | 
118 | 
119 | def main():
120 |     image_dir = '../../data/train'
121 |     save_dir = '../../files/'
122 |     big_w_dir = '../../data/big_w_dir'
123 |     stati_image_size(image_dir, save_dir, big_w_dir)
124 | 
125 |     train_label_json = '../../files/train_alphabet.json'
126 |     long_text_dir = '../../data/long_text_dir'
127 |     stati_label_length(train_label_json, long_text_dir)
128 |     # stati_image_gray(image_dir)
129 | 
130 | if __name__ == '__main__':
131 |     main()
132 | 
--------------------------------------------------------------------------------
/ocr_densenet/code/preprocessing/map_word_to_index.py:
--------------------------------------------------------------------------------
1 | # coding=utf8
2 | #########################################################################
3 | # File Name: map_word_to_index.py
4 | # Author: ccyin
5 | # mail: ccyin04@gmail.com
6 | # Created Time: Fri 18 May 2018 03:30:26 PM CST
7 | #########################################################################
8 | '''
9 | Maps every character in the labels to an index, in one of two ways:
10 |     1. map each English word to an index
11 |     2. map each English letter to an index
12 | '''
13 | 
14 | import os
15 | import sys
16 | reload(sys)
17 | sys.setdefaultencoding('utf8')
18 | import json
19 | from collections import OrderedDict
20 | 
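# --- Editor's note (not part of the original source) ---
# Expected inputs/outputs, inferred from main() below:
#     train_word_file  -- files/train.csv, one '<name>.jpg,"<text>"' line per image
#     word_index_json  -- alphabet_index_dict.json, an existing char -> index map
#     word_count_json  -- alphabet_count_dict.json, char -> frequency (written out)
#     index_label_json -- train_alphabet.json, image -> 'i1 i2 ...' index string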
21 | def map_word_to_index(train_word_file, word_index_json, word_count_json, index_label_json, alphabet_to_index=True):
22 |     with open(train_word_file, 'r') as f:
23 |         labels = f.read().strip().decode('utf8')
24 |     word_count_dict = { }
25 |     for line in labels.split('\n')[1:]:
26 |         line = line.strip()
27 |         image, sentence = line.strip().split('.jpg,')
28 |         sentence = sentence.strip('"')
29 |         for w in sentence:
30 |             word_count_dict[w] = word_count_dict.get(w,0) + 1
31 |     print 'There are {:d} distinct characters, {:d} in total'.format(len(word_count_dict), sum(word_count_dict.values()))
32 |     word_sorted = sorted(word_count_dict.keys(), key=lambda k:word_count_dict[k], reverse=True)
33 |     # word_index_dict = { w:i for i,w in enumerate(word_sorted) }
34 |     word_index_dict = json.load(open(word_index_json))  # load the existing char -> index mapping instead of rebuilding it
35 | 
36 |     with open(word_count_json, 'w') as f:
37 |         f.write(json.dumps(word_count_dict, indent=4, ensure_ascii=False))
38 |     with open(word_index_json, 'w') as f:
39 |         f.write(json.dumps(word_index_dict, indent=4, ensure_ascii=False))
40 | 
41 |     image_label_dict = OrderedDict()
42 |     for line in labels.split('\n')[1:]:
43 |         line = line.strip()
44 |         image, sentence = line.strip().split('.jpg,')
45 |         sentence = sentence.strip('"')
46 | 
47 |         # Normalize some similar-looking symbols
48 |         for c in u"  ":  # drop space-like characters (full-width spaces etc.)
49 |             sentence = sentence.replace(c, '')
50 |         replace_words = [  # in each string, every char but the last is replaced by the last
51 |             u'((',
52 |             u'))',
53 |             u',,',
54 |             u"´'′",
55 |             u'″”“"',  # NOTE (editor): reconstructed from a garbled literal; assumed to map ″ ” “ to '"'
56 |             u"..",
57 |             u"—-"
58 |         ]
59 |         for words in replace_words:
60 |             for w in words[:-1]:
61 |                 sentence = sentence.replace(w, words[-1])
62 | 
63 |         index_list = []
64 |         for w in sentence:
65 |             index_list.append(str(word_index_dict[w]))
66 |         image_label_dict[image + '.jpg'] = ' '.join(index_list)
67 |     with open(index_label_json, 'w') as f:
68 |         f.write(json.dumps(image_label_dict, indent=4))
69 | 
70 | 
71 | def main():
72 | 
73 |     # Map each letter to an index
74 |     train_word_file = '../../files/train.csv'
75 |     word_index_json = '../../files/alphabet_index_dict.json'
76 |     word_count_json = '../../files/alphabet_count_dict.json'
77 |     index_label_json = '../../files/train_alphabet.json'
78 |     map_word_to_index(train_word_file, word_index_json, word_count_json, index_label_json, True)
79 | 
80 | if __name__ == '__main__':
81 |     main()
82 | 
--------------------------------------------------------------------------------
/ocr_densenet/code/preprocessing/show_black.py:
--------------------------------------------------------------------------------
1 | # coding=utf8
2 | #########################################################################
3 | # File Name: show_black.py
4 | # Author: ccyin
5 | # mail: ccyin04@gmail.com
6 | # Created Time: Thu 07 Jun 2018 01:06:22
7 | #########################################################################
8 | 
9 | import os
10 | import sys
11 | import json
12 | sys.path.append('../ocr')
13 | from tools import parse, py_op
14 | args = parse.args
15 | 
16 | def cp_black_list(black_json, black_dir):
17 |     word_index_dict = json.load(open(args.word_index_json))
18 |     index_word_dict = { v:k for k,v in word_index_dict.items() }
19 |     train_word_dict = json.load(open(args.image_label_json))
20 |     train_word_dict = { k:''.join([index_word_dict[int(i)] for i in v.split()]) for k,v in train_word_dict.items() }
21 | 
22 |     py_op.mkdir(black_dir)
23 |     black_list = json.load(open(black_json))['black_list']
24 |     for i,name in enumerate(black_list):
25 |         cmd = 'cp {:s} {:s}'.format(os.path.join(args.data_dir, 'train', name), black_dir)
26 |         if train_word_dict[name] in ['Err:501', '#NAME?', '###']:
27 |             continue
28 |         print name
29 |         print train_word_dict[name]
30 |         os.system(cmd)
31 |         if i > 30:
32 |             break
33 | 
34 | if __name__ == '__main__':
35 |     black_dir = os.path.join(args.save_dir, 'black')
36 |     cp_black_list(args.black_json, black_dir)
37 | 
--------------------------------------------------------------------------------
/ocr_densenet/files/ttf/simsun.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataFountainCode/huawei_code_share/f1ef76649ea5c87a7be2d93dfaec1ff9a4d3e4b5/ocr_densenet/files/ttf/simsun.ttf
--------------------------------------------------------------------------------
/ocr_densenet/make_test_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Sun Feb 3 14:33:52 2019
5 | 
6 | @author: tcd
7 | """
8 | import os
9 | import cv2
10 | import pandas as pd
11 | 
12 | # make test data
13 | datapath = '/data/chinese/test_dataset/'
14 | input_file = '/home/tcd/EAST/output.txt'
15 | output_file = '/home/tcd/ocr_densenet/submission.csv'
16 | testpath = '/home/tcd/ocr_densenet/data/dataset/test/'
17 | submit_example = '/home/tcd/submit_example.csv'
18 | ifmax = True
19 | lists = []
20 | 
21 | names = ['filename', 'x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4', 'text']
22 | ex = pd.read_csv(submit_example)
23 | ex = ex.drop(names[1:], axis=1)
24 | f = pd.read_csv(input_file, names=names[:-1], encoding='utf-8')
25 | f['filename'] = f['filename']+'.jpg'
26 | ex.columns = ['filename']
27 | ex = pd.merge(ex, f, how='left', on=['filename'])
28 | ex = ex.fillna('None')
29 | ex['target_file'] = [str(x)+'to' for x in ex.index] + ex['filename']
30 | ex.to_csv(output_file, header=True, index=None, encoding='utf-8')
31 | 
32 | with open(output_file, 'r') as f:
33 |     for ff in os.listdir(testpath):
34 |         os.remove(testpath+ff)
35 |     print('removed')
36 |     for line in f.readlines():
37 |         l = line.strip().split(',')
38 |         if l[-1] == 'target_file' or l[-1] == 'None' or l[1] == 'None':
39 |             continue
40 |         roi = []
41 |         point = []
42 |         for i in range(1, 9):
43 |             l[i] = int(float(l[i]))
44 |             point.append(l[i])
45 |             if i % 2 == 0:
46 |                 roi.append(point)
47 |                 point = []
48 |         if ifmax:
49 |             xmin = min([roi[x][0] for x in range(4)])
50 |             xmax = max([roi[x][0] for x in range(4)])
51 |             ymin = min([roi[x][1] for x in range(4)])
52 |             ymax = max([roi[x][1] for x in range(4)])
53 |         if xmin < 0:
54 |             xmin = 0
55 |         if ymin < 0:
56 |             ymin = 0
57 |         im = cv2.imread(datapath + l[0])
58 |         im = im[ymin:ymax, xmin:xmax]
59 |         cv2.imwrite(testpath + l[-1], im)
60 | print('image cropped in...', testpath)
61 | print('test dataset done...')
62 | 
--------------------------------------------------------------------------------
/ocr_densenet/makedata.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Sun Feb 3 14:33:52 2019
5 | 
6 | @author: tcd
7 | """
8 | import os
9 | import cv2
10 | import pandas as pd
11 | train_csv = '/data/chinese/train_lable.csv'
12 | targetpath = '/home/tcd/ocr_densenet/data/train/'
13 | datapath = '/data/chinese/train_dataset/'
14 | with open(train_csv, 'r') as f:
15 |     for ff in os.listdir(targetpath):
16 |         os.remove(targetpath+ff)
17 |     print('removed')
18 |     ii = 0
19 |     name = []
20 |     filename = []
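    # The loop below crops each labelled quadrilateral to its axis-aligned
    # bounding rectangle, writes it to targetpath as '<ii>to<original name>',
    # and collects the crop filename and its text label for the train.csv
    # written at the end of the script.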
21 |     for line in f.readlines():
22 |         l = line.strip().split(',')
23 |         if l[-1] == 'text':
24 |             continue
25 |         roi = []
26 |         point = []
27 |         for i in range(1, 9):
28 |             l[i] = int(float(l[i]))
29 |             point.append(l[i])
30 |             if i % 2 == 0:
31 |                 roi.append(point)
32 |                 point = []
33 |         xmin = min([roi[x][0] for x in range(4)])
34 |         xmax = max([roi[x][0] for x in range(4)])
35 |         ymin = min([roi[x][1] for x in range(4)])
36 |         ymax = max([roi[x][1] for x in range(4)])
37 |         if xmin < 0:
38 |             xmin = 0
39 |         if ymin < 0:
40 |             ymin = 0
41 |         im = cv2.imread(datapath + l[0])
42 |         im = im[ymin:ymax, xmin:xmax]
43 |         target = str(ii) + 'to' + l[0]
44 |         cv2.imwrite(targetpath + target, im)
45 |         name.append(l[-1].decode('utf-8'))
46 |         filename.append(target.decode('utf-8'))
47 |         ii+=1
48 | print('image cropped in...', targetpath)
49 | 
50 | train = pd.DataFrame(columns=['name', 'content'])
51 | train['name'] = filename
52 | train['content'] = name
53 | train.to_csv('/home/tcd/ocr_densenet/files/train.csv', header=True, index=None, encoding='utf-8')
54 | 
--------------------------------------------------------------------------------
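# --- Editor's note (not part of the original source) ---
# A rough sketch of how the pieces above fit together, inferred from the
# hard-coded paths; the exact run order is an assumption:
#     1. makedata.py crops the labelled training boxes from train_dataset/
#        into ocr_densenet/data/train/ and writes files/train.csv for the
#        DenseNet recognizer.
#     2. The EAST detector is run on the test images and writes its predicted
#        boxes to EAST/output.txt.
#     3. make_test_data.py crops those detected boxes into data/dataset/test/
#        and builds submission.csv, listing one cropped file per detected box
#        for the recognizer to transcribe.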