├── pse ├── log.txt ├── include │ ├── clipper │ │ └── clipper.cpp │ └── pybind11 │ │ ├── common.h │ │ ├── detail │ │ ├── typeid.h │ │ └── descr.h │ │ ├── complex.h │ │ ├── options.h │ │ ├── functional.h │ │ ├── eval.h │ │ ├── buffer_info.h │ │ ├── iostream.h │ │ └── chrono.h ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── __init__.cpython-37.pyc ├── __main__.py ├── Makefile ├── __init__.py ├── adaptor.cpp ├── .ycm_extra_conf.py ├── adaptor_2.cpp ├── adaptor_1.cpp ├── adaptor_3.cpp └── lanms.h ├── util ├── demo_nn │ ├── tf2.py │ └── pytorch_demo.py ├── test_open_read.py ├── adaptor.so ├── __pycache__ │ ├── cmd.cpython-36.pyc │ ├── cmd.cpython-37.pyc │ ├── dec.cpython-36.pyc │ ├── dec.cpython-37.pyc │ ├── img.cpython-36.pyc │ ├── img.cpython-37.pyc │ ├── io_.cpython-36.pyc │ ├── io_.cpython-37.pyc │ ├── log.cpython-36.pyc │ ├── log.cpython-37.pyc │ ├── ml.cpython-36.pyc │ ├── ml.cpython-37.pyc │ ├── mod.cpython-36.pyc │ ├── mod.cpython-37.pyc │ ├── np.cpython-36.pyc │ ├── np.cpython-37.pyc │ ├── url.cpython-36.pyc │ ├── url.cpython-37.pyc │ ├── caffe_.cpython-36.pyc │ ├── caffe_.cpython-37.pyc │ ├── dtype.cpython-36.pyc │ ├── dtype.cpython-37.pyc │ ├── event.cpython-36.pyc │ ├── event.cpython-37.pyc │ ├── logger.cpython-36.pyc │ ├── logger.cpython-37.pyc │ ├── misc.cpython-36.pyc │ ├── misc.cpython-37.pyc │ ├── proc.cpython-36.pyc │ ├── proc.cpython-37.pyc │ ├── rand.cpython-36.pyc │ ├── rand.cpython-37.pyc │ ├── str_.cpython-36.pyc │ ├── str_.cpython-37.pyc │ ├── test.cpython-36.pyc │ ├── test.cpython-37.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── feature.cpython-36.pyc │ ├── feature.cpython-37.pyc │ ├── thread_.cpython-36.pyc │ ├── thread_.cpython-37.pyc │ ├── neighbour.cpython-36.pyc │ └── neighbour.cpython-37.pyc ├── cmd.py ├── test.py ├── statistic.py ├── event.py ├── url.py ├── feature.py ├── t.py ├── dtype.py ├── ml.py ├── rand.py ├── opencv_demo.py ├── proc.py ├── log.py ├── mod.py ├── __init__.py ├── thread_.py ├── mask.py ├── 
misc.py ├── str_.py ├── neighbour.py ├── caffe_.py ├── dec.py ├── np.py ├── logger.py ├── io_.py └── plt.py ├── MobileNetV3 ├── .gitignore ├── README.md ├── LICENSE ├── mobilenet_v3_small.py └── mobilenet_v3_large.py ├── eval ├── eval_ctw1500.sh └── ctw1500 │ ├── file_util.pyc │ ├── __pycache__ │ ├── file_util.cpython-36.pyc │ └── file_util.cpython-37.pyc │ ├── file_util.py │ └── eval_ctw1500.py ├── figure ├── pse.png ├── res0.png └── pipeline.png ├── MobileNetV2 ├── images │ ├── net.jpg │ ├── stru.jpg │ └── MobileNetv2.png ├── data │ └── convert.py ├── LICENSE ├── .gitignore ├── README.md ├── train.py └── mobilenet_v2.py ├── __pycache__ ├── pypse.cpython-36.pyc ├── pypse.cpython-37.pyc ├── metrics.cpython-36.pyc └── metrics.cpython-37.pyc ├── dataset ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── ctw1500_loader.cpython-36.pyc │ ├── ctw1500_loader.cpython-37.pyc │ ├── ctw1500_test_loader.cpython-36.pyc │ └── ctw1500_test_loader.cpython-37.pyc ├── __init__.py └── ctw1500_test_loader.py ├── models ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── fpn_resnet.cpython-36.pyc │ ├── fpn_resnet.cpython-37.pyc │ └── mobilenet_v3_block.cpython-37.pyc ├── __init__.py └── mobilenet_v3_block.py ├── .github └── FUNDING.yml ├── metrics.py ├── pypse.py ├── README.md ├── test_ctw1500.py └── test_id41k.py /pse/log.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /util/demo_nn/tf2.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MobileNetV3/.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | -------------------------------------------------------------------------------- /util/test_open_read.py: 
-------------------------------------------------------------------------------- 1 | test_path = '' 2 | for pa in -------------------------------------------------------------------------------- /eval/eval_ctw1500.sh: -------------------------------------------------------------------------------- 1 | cd ctw1500 2 | python eval_ctw1500.py 3 | cd .. -------------------------------------------------------------------------------- /figure/pse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/figure/pse.png -------------------------------------------------------------------------------- /figure/res0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/figure/res0.png -------------------------------------------------------------------------------- /util/adaptor.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/adaptor.so -------------------------------------------------------------------------------- /figure/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/figure/pipeline.png -------------------------------------------------------------------------------- /MobileNetV2/images/net.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/MobileNetV2/images/net.jpg -------------------------------------------------------------------------------- /MobileNetV2/images/stru.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/MobileNetV2/images/stru.jpg 
-------------------------------------------------------------------------------- /eval/ctw1500/file_util.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/eval/ctw1500/file_util.pyc -------------------------------------------------------------------------------- /__pycache__/pypse.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/__pycache__/pypse.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/pypse.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/__pycache__/pypse.cpython-37.pyc -------------------------------------------------------------------------------- /pse/include/clipper/clipper.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/pse/include/clipper/clipper.cpp -------------------------------------------------------------------------------- /MobileNetV2/images/MobileNetv2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/MobileNetV2/images/MobileNetv2.png -------------------------------------------------------------------------------- /__pycache__/metrics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/__pycache__/metrics.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/metrics.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/__pycache__/metrics.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/cmd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/cmd.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/cmd.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/cmd.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/dec.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/dec.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/dec.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/dec.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/img.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/img.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/img.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/img.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/io_.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/io_.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/io_.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/io_.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/log.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/log.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/log.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/log.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/ml.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/ml.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/ml.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/ml.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/mod.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/mod.cpython-36.pyc 
-------------------------------------------------------------------------------- /util/__pycache__/mod.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/mod.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/np.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/np.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/np.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/np.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/url.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/url.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/url.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/url.cpython-37.pyc -------------------------------------------------------------------------------- /util/cmd.py: -------------------------------------------------------------------------------- 1 | #encoding = utf-8 2 | 3 | def cmd(cmd): 4 | import commands 5 | return commands.getoutput(cmd) 6 | 7 | -------------------------------------------------------------------------------- /util/__pycache__/caffe_.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/caffe_.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/caffe_.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/caffe_.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/dtype.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/dtype.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/dtype.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/dtype.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/event.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/event.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/event.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/event.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- 
/util/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/logger.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/misc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/misc.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/proc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/proc.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/proc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/proc.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/rand.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/rand.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/rand.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/rand.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/str_.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/str_.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/str_.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/str_.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/test.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/test.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/test.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/test.cpython-37.pyc -------------------------------------------------------------------------------- /pse/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/pse/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /pse/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/pse/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- 
/util/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/feature.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/feature.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/feature.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/feature.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/thread_.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/thread_.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/thread_.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/thread_.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/dataset/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/dataset/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /util/__pycache__/neighbour.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/neighbour.cpython-36.pyc -------------------------------------------------------------------------------- /util/__pycache__/neighbour.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/util/__pycache__/neighbour.cpython-37.pyc -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .fpn_resnet import resnet50, resnet101, resnet152, resnet34, resnet18, mobilenetv2, mobilenetv3_large, mobilenetv3_small 
-------------------------------------------------------------------------------- /models/__pycache__/fpn_resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/models/__pycache__/fpn_resnet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/fpn_resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/models/__pycache__/fpn_resnet.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/ctw1500_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/dataset/__pycache__/ctw1500_loader.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/ctw1500_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/dataset/__pycache__/ctw1500_loader.cpython-37.pyc -------------------------------------------------------------------------------- /eval/ctw1500/__pycache__/file_util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/eval/ctw1500/__pycache__/file_util.cpython-36.pyc -------------------------------------------------------------------------------- /eval/ctw1500/__pycache__/file_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/eval/ctw1500/__pycache__/file_util.cpython-37.pyc 
-------------------------------------------------------------------------------- /pse/include/pybind11/common.h: -------------------------------------------------------------------------------- 1 | #include "detail/common.h" 2 | #warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'." 3 | -------------------------------------------------------------------------------- /models/__pycache__/mobilenet_v3_block.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/models/__pycache__/mobilenet_v3_block.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/ctw1500_test_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/dataset/__pycache__/ctw1500_test_loader.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/ctw1500_test_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/li10141110/PSENet-tf2/HEAD/dataset/__pycache__/ctw1500_test_loader.cpython-37.pyc -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from dataset.ctw1500_loader import CTW1500Loader, ctw_train_loader 2 | from dataset.ctw1500_test_loader import CTW1500TestLoader, ctw_test_loader 3 | -------------------------------------------------------------------------------- /util/test.py: -------------------------------------------------------------------------------- 1 | #encoding = utf-8 2 | import numpy as np 3 | 4 | assert_true = np.testing.assert_ 5 | assert_equal = np.testing.assert_equal 6 | assert_array_equal = 
np.testing.assert_array_equal 7 | assert_almost_equal = np.testing.assert_almost_equal 8 | -------------------------------------------------------------------------------- /MobileNetV3/README.md: -------------------------------------------------------------------------------- 1 | # MobileNetV3_TensorFlow2 2 | A tensorflow2 implementation of MobileNet_V3. 3 | 4 | See https://github.com/calmisential/Basic_CNNs_TensorFlow2 for training details. 5 | 6 | ## References: 7 | 1. The original paper: [Searching for MobileNetV3](https://arxiv.org/abs/1905.02244) 8 | -------------------------------------------------------------------------------- /pse/__main__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | from . import merge_quadrangle_n9 5 | 6 | if __name__ == '__main__': 7 | # unit square with confidence 1 8 | q = np.array([0, 0, 0, 1, 1, 1, 1, 0, 1], dtype='float32') 9 | 10 | print(merge_quadrangle_n9(np.array([q, q + 0.1, q + 2]))) 11 | -------------------------------------------------------------------------------- /util/statistic.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | ''' 3 | Created on 2016年10月8日 4 | 5 | @author: dengdan 6 | ''' 7 | import numpy as np 8 | import util.np 9 | 10 | def D(x): 11 | x = util.np.flatten(x) 12 | return np.var(x) 13 | 14 | def E(x): 15 | x = util.np.flatten(x) 16 | return np.average(x) 17 | -------------------------------------------------------------------------------- /util/event.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import logging 3 | def wait_key(target = None): 4 | key = cv2.waitKey()& 0xFF 5 | if target == None: 6 | return key 7 | if type(target) == str: 8 | target = ord(target) 9 | while key != target: 10 | key = cv2.waitKey()& 0xFF 11 | 12 | logging.debug('Key Pression caught:%s'%(target)) 13 | 
-------------------------------------------------------------------------------- /pse/Makefile: -------------------------------------------------------------------------------- 1 | CXXFLAGS = -I include -std=c++11 -O3 2 | 3 | DEPS = lanms.h $(shell find include -xtype f) 4 | CXX_SOURCES = adaptor.cpp include/clipper/clipper.cpp 5 | OPENCV = `pkg-config --cflags --libs opencv` 6 | #OPENCV=/usr/local/ 7 | 8 | LIB_SO = adaptor.so 9 | 10 | $(LIB_SO): $(CXX_SOURCES) $(DEPS) 11 | $(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC $(OPENCV) 12 | 13 | clean: 14 | rm -rf $(LIB_SO) 15 | -------------------------------------------------------------------------------- /pse/__init__.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | import numpy as np 4 | import time 5 | 6 | BASE_DIR = os.path.dirname(os.path.realpath(__file__)) 7 | 8 | if subprocess.call(['make', '-C', BASE_DIR]) != 0: 9 | raise RuntimeError('Cannot compile pse: {}'.format(BASE_DIR)) 10 | 11 | from .adaptor import pse as cpse 12 | def pse(polys, min_area): 13 | # start = time.time() 14 | ret = np.array(cpse(polys, min_area), dtype='int32') 15 | # end = time.time() 16 | # print (end - start), 's' 17 | return ret 18 | 19 | -------------------------------------------------------------------------------- /util/url.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from six.moves import urllib 4 | 5 | import util 6 | def download(url, path): 7 | filename = path.split('/')[-1] 8 | if not util.io.exists(path): 9 | def _progress(count, block_size, total_size): 10 | sys.stdout.write('\r-----Downloading %s %.1f%%' % (filename, 11 | float(count * block_size) / float(total_size) * 100.0)) 12 | sys.stdout.flush() 13 | path, _ = urllib.request.urlretrieve(url, path, _progress) 14 | print() 15 | statinfo = os.stat(path) 16 | print('Successfully downloaded', filename, 
statinfo.st_size, 'bytes.') 17 | -------------------------------------------------------------------------------- /util/feature.py: -------------------------------------------------------------------------------- 1 | # encoding utf-8 2 | def hog(img, bins =9, pixels_per_cell=(8, 8), cells_per_block=(2, 2), transform_sqrt=False, feature_vector=True): 3 | """ 4 | Extract hog feature from image. 5 | See detail at https://github.com/scikit-image/scikit-image/blob/master/skimage/feature/_hog.py 6 | """ 7 | from skimage.feature import hog 8 | return hog(img, 9 | orientations = bins, 10 | pixels_per_cell = pixels_per_cell, 11 | cells_per_block = cells_per_block, 12 | visualise = False, 13 | transform_sqrt=False, 14 | feature_vector=True) 15 | -------------------------------------------------------------------------------- /util/t.py: -------------------------------------------------------------------------------- 1 | #encoding=utf-8 2 | """ 3 | for theano shortcuts 4 | """ 5 | import theano 6 | import theano.tensor as T 7 | import util.rand 8 | 9 | trng = T.shared_randomstreams.RandomStreams(util.rand.randint()) 10 | scan_until = theano.scan_module.until 11 | 12 | def add_noise(input, noise_level): 13 | noise = trng.binomial(size = input.shape, n = 1, p = 1 - noise_level) 14 | return noise * input 15 | 16 | def crop_into(large, small): 17 | """ 18 | center crop large image into small. 
def cast(obj, dtype):
    """Convert ``obj`` to a NumPy array of the requested ``dtype``.

    Params:
        obj: a list, scalar or array-like accepted by ``np.asarray``.
        dtype: any dtype specifier NumPy understands (e.g. ``'int32'``).

    Returns:
        A new ndarray holding the values of ``obj`` converted to ``dtype``
        (a 0-d array when ``obj`` is a scalar).
    """
    # Lists used to be converted to float32 regardless of `dtype`, so e.g.
    # cast([1, 2], 'int32') came back as floats. All inputs now take the same
    # path. `np.cast[...]` is avoided because it no longer exists in NumPy 2;
    # asarray(...).astype(...) is what it did internally.
    return np.asarray(obj).astype(dtype)
@util.dec.print_calling
def kmeans(samples, k, criteria = None, attempts = 3, flags = cv2.KMEANS_RANDOM_CENTERS):
    """Cluster ``samples`` into ``k`` groups with ``cv2.kmeans``.

    Params:
        samples: array-like of samples; converted to a float32 ndarray.
        k: number of clusters.
        criteria: OpenCV termination criteria; when None, stop after 10
            iterations or once epsilon 1.0 is reached.
        attempts: forwarded to ``cv2.kmeans``.
        flags: forwarded to ``cv2.kmeans``.

    Returns:
        (labels, clusters, centers): ``labels`` is the label output of
        ``cv2.kmeans`` passed through ``util.np.flatten``; ``clusters`` is a
        list of ``k`` lists holding the sample indices assigned to each label;
        ``centers`` is returned by OpenCV unchanged.
    """
    if criteria is None:
        criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    samples = np.asarray(samples, dtype = np.float32)

    # Keyword arguments are used because the positional order differs between
    # OpenCV generations (3.x/4.x expect `bestLabels` in third position).
    # NOTE(review): keyword names taken from the OpenCV Python docs - confirm
    # against the OpenCV version this project pins.
    _, labels, centers = cv2.kmeans(data = samples, K = k, bestLabels = None,
                                    criteria = criteria, attempts = attempts,
                                    flags = flags)
    labels = util.np.flatten(labels)

    # One independent list per cluster; empty clusters stay as [].
    clusters = [[] for _ in range(k)]
    for idx, label in enumerate(labels):
        clusters[label].append(idx)

    for cluster in clusters:
        if not cluster:
            logging.warning('Empty cluster appeared.')

    return labels, clusters, centers
def write_file_not_cover(file_path, file_content):
    """Append ``file_content`` to ``file_path`` without overwriting it.

    Missing parent directories are created first. A path without any
    directory component (e.g. ``'out.txt'``) is written relative to the
    current working directory.

    Params:
        file_path: destination file path.
        file_content: text to append.
    """
    father_dir = os.path.dirname(file_path)
    # dirname is '' for a bare file name; os.makedirs('') would raise, so only
    # create the directory when there actually is one to create.
    if father_dir and not os.path.exists(father_dir):
        os.makedirs(father_dir)
    # `with` closes the handle even when write() raises.
    with open(file_path, 'a') as file_object:
        file_object.write(file_content)
def _export_split(images, labels, root):
    """Resize every image to 224x224 and save it as <root>/<label>/<index>.jpg."""
    for index, (image, label) in enumerate(zip(images, labels)):
        class_dir = os.path.join(root, str(label[0]))
        if not os.path.exists(class_dir):
            os.makedirs(class_dir)
        resized = cv2.resize(image, (224, 224), interpolation=cv2.INTER_CUBIC)
        # NOTE(review): cv2.imwrite interprets channels as BGR; confirm the
        # channel order of the arrays returned by cifar100.load_data before
        # relying on the colours of the saved files.
        cv2.imwrite(os.path.join(class_dir, str(index) + '.jpg'), resized)


def convert():
    """Export CIFAR-100 as 224x224 JPEGs under ./train and ./validation.

    The dataset is loaded with fine labels; each image is written into a
    sub-directory named after its label, using its index within the split as
    the file name.
    """
    (X_train, y_train), (X_test, y_test) = cifar100.load_data(label_mode='fine')

    _export_split(X_train, y_train, 'train')
    _export_split(X_test, y_test, 'validation')
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MobileNetV3/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 calmisential 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
import cv2
import numpy as np

# ---------------------------------------------------------------------------
# Demo 1: Canny edges -> contours drawn on a white canvas.
# ---------------------------------------------------------------------------

# Read the image.
img = cv2.imread('city.jpg')
if img is None:
    # cv2.imread signals failure by returning None instead of raising.
    raise IOError("cannot read image 'city.jpg'")
# Binarise with Canny edge detection.
binaryImg = cv2.Canny(img, 50, 200)

# Find the contours.
# cv2.findContours returns (image, contours, hierarchy) on OpenCV 3.x but
# (contours, hierarchy) on 2.x / 4.x, so the contours are always the
# second-to-last element of the returned tuple.
h = cv2.findContours(binaryImg, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
# Extract the contours.
contours = h[-2]
# Print the type of the return value; it is a tuple.
print(type(h))
# Print the type of the contour container.
print(type(contours))
# Show how many contours were found.
print(len(contours))

# Create a white canvas.
temp = np.ones(binaryImg.shape, np.uint8) * 255
# Draw the contours: temp is the canvas, -1 means "draw all of them",
# followed by the colour and the line thickness.
cv2.drawContours(temp, contours, -1, (0, 255, 0), 3)

cv2.imshow("contours", temp)
cv2.waitKey(0)
cv2.destroyAllWindows()

# ---------------------------------------------------------------------------
# Demo 2: threshold -> contours drawn on the original colour image.
# ---------------------------------------------------------------------------

img = cv2.imread('test.jpg')
if img is None:
    raise IOError("cannot read image 'test.jpg'")
imgray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
ret, thresh = cv2.threshold(imgray, 127, 255, 0)
# Take the last two return values so this works on every OpenCV version.
contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
# The original displayed the image returned by findContours (OpenCV 3.x only)
# and noted it is essentially the same as the input `thresh`, so show that.
cv2.imshow('imageshow', thresh)
cv2.waitKey(0)

# img must have three channels for the coloured contours to be visible.
img = cv2.drawContours(img, contours, -1, (0, 255, 0), 5)
cv2.imshow('drawimg', img)
cv2.waitKey(0)
cv2.destroyAllWindows()
def kill(pid):
    """Force-kill (``kill -9``) one or more processes.

    Params:
        pid: one of
            - an integer process id,
            - a list/tuple whose items are themselves valid ``pid`` values,
            - a string, which is resolved to process ids via ``get_pid``.

    Raises:
        ValueError: if ``pid`` (or a nested item) has an unsupported type.
    """
    import numbers

    if isinstance(pid, (list, tuple)):
        for p in pid:
            kill(p)
    # bool is an Integral too; it is rejected, as it was before, so that
    # kill(True) can never turn into `kill -9 1`.
    elif isinstance(pid, numbers.Integral) and not isinstance(pid, bool):
        # Imported only where needed so the other branches do not depend on
        # the project package being importable.
        import util
        cmd = 'kill -9 %d'%(pid)
        print (cmd)
        print (util.cmd.cmd(cmd))
    elif isinstance(pid, str):
        pids = get_pid(pid)
        kill(pids)
    else:
        # The explanation used to be printed while a bare ValueError was
        # raised; it now travels with the exception.
        raise ValueError('Not supported parameter type: %s' % (type(pid),))
def add_to_path(path):
    '''
    add path to the front of sys.path.

    Calling this repeatedly with the same path keeps a single entry: any
    existing occurrence is removed before the path is inserted at index 0,
    so the path still takes priority but sys.path does not grow.
    '''
    import sys
    while path in sys.path:
        sys.path.remove(path)
    sys.path.insert(0, path)
def load_mod_from_path(path, keep_name = True):
    """
    Load a Python source file as a module and return it.

    Params:
        path: path of the source file; resolved with
            util.io.get_absolute_path before loading.
        keep_name: if True, the filename (up to its first '.') will be used
            as module name; otherwise a counter from util.get_count() is
            appended to it.

    Returns:
        The loaded module object, which is also registered in sys.modules
        under the chosen name.
    """
    import importlib.machinery
    import importlib.util
    import sys
    import util
    path = util.io.get_absolute_path(path)
    file_name = util.io.get_filename(path)
    module_name = file_name.split('.')[0]
    if not keep_name:
        module_name = '%s_%d'%(module_name, util.get_count())

    # `imp.load_source` was used here, but the `imp` module is removed in
    # Python 3.12. An explicit SourceFileLoader keeps the ability to load a
    # file regardless of its extension.
    loader = importlib.machinery.SourceFileLoader(module_name, path)
    spec = importlib.util.spec_from_loader(module_name, loader)
    module = importlib.util.module_from_spec(spec)
    # Register before executing, and roll back if the module body raises.
    sys.modules[module_name] = module
    try:
        loader.exec_module(module)
    except BaseException:
        sys.modules.pop(module_name, None)
        raise
    return module
8 | */ 9 | 10 | #pragma once 11 | 12 | #include 13 | #include 14 | 15 | #if defined(__GNUG__) 16 | #include 17 | #endif 18 | 19 | #include "common.h" 20 | 21 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 22 | NAMESPACE_BEGIN(detail) 23 | /// Erase all occurrences of a substring 24 | inline void erase_all(std::string &string, const std::string &search) { 25 | for (size_t pos = 0;;) { 26 | pos = string.find(search, pos); 27 | if (pos == std::string::npos) break; 28 | string.erase(pos, search.length()); 29 | } 30 | } 31 | 32 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) { 33 | #if defined(__GNUG__) 34 | int status = 0; 35 | std::unique_ptr res { 36 | abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free }; 37 | if (status == 0) 38 | name = res.get(); 39 | #else 40 | detail::erase_all(name, "class "); 41 | detail::erase_all(name, "struct "); 42 | detail::erase_all(name, "enum "); 43 | #endif 44 | detail::erase_all(name, "pybind11::"); 45 | } 46 | NAMESPACE_END(detail) 47 | 48 | /// Return a string representation of a C++ type 49 | template static std::string type_id() { 50 | std::string name(typeid(T).name()); 51 | detail::clean_type_id(name); 52 | return name; 53 | } 54 | 55 | NAMESPACE_END(PYBIND11_NAMESPACE) 56 | -------------------------------------------------------------------------------- /MobileNetV2/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos 
into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | from util import log 2 | from util import dtype 3 | # import plt 4 | from util import np 5 | from util import img 6 | _img = img 7 | from util import dec 8 | from util import rand 9 | from util import mod 10 | from util import proc 11 | from util import test 12 | from util import neighbour as nb 13 | #import mask 14 | from util import str_ as str 15 | #from util import io as sys_io 16 | from util import io_ as io 17 | from util import feature 18 | from util import thread_ as thread 19 | from util import caffe_ as caffe 20 | # import tf 21 | from util import cmd 22 | from util import ml 23 | #from util import sys 24 | from util import url 25 | #from util.misc import * 26 | from .logger import * 27 | # 
def exit(code = 0):
    """Terminate the interpreter with exit status ``code``.

    Raises:
        SystemExit: always, carrying ``code``.
    """
    # Imported locally: no top-level `import sys` is visible in this module
    # (`#from util import sys` is commented out).
    import sys
    # The status used to be hard-coded to 0, ignoring the argument.
    sys.exit(code)
class ProcessPool(object):
    """
    Small wrapper around multiprocessing.Pool.

    Remember that function in function is not supported by multiprocessing.
    """
    def __init__(self, capacity = 8):
        """Create a pool with ``capacity`` worker processes."""
        from multiprocessing import Pool

        self.capacity = capacity
        self.pool = Pool(capacity)

    def add(self, fn, args):
        """Schedule ``fn(*args)`` on the pool.

        Returns:
            The AsyncResult from ``Pool.apply_async``. It used to be
            discarded, which made exceptions raised in the worker impossible
            to observe; call ``.get()`` on it to retrieve the result or
            re-raise the worker's exception.
        """
        return self.pool.apply_async(fn, args)

    def join(self):
        """Stop accepting work and wait for all scheduled tasks to finish."""
        self.pool.close()
        self.pool.join()
def pse(kernals, min_area):
    """Progressively expand labelled seed regions through a list of kernels.

    The last entry of ``kernals`` is labelled with 4-connected components;
    components with fewer than ``min_area`` pixels are dropped. The remaining
    labels are then grown breadth-first, one kernel at a time from the
    second-to-last down to the first: a pixel is claimed by the first label
    that reaches it, provided the current kernel is non-zero there.

    Params:
        kernals: sequence of 2-D arrays of identical shape; the last one must
            be accepted by ``cv2.connectedComponents``.
        min_area: minimum pixel count for a seed component to be kept.

    Returns:
        int32 array with the shape of ``kernals[0]``; 0 means unlabelled.
    """
    # A plain deque replaces queue.Queue: the search is single-threaded, so
    # the per-operation locking of Queue only added overhead per pixel.
    from collections import deque

    kernal_num = len(kernals)
    pred = np.zeros(kernals[0].shape, dtype='int32')

    label_num, label = cv2.connectedComponents(kernals[kernal_num - 1], connectivity=4)

    # Discard seed components that are too small.
    for label_idx in range(1, label_num):
        if np.sum(label == label_idx) < min_area:
            label[label == label_idx] = 0

    queue = deque()
    next_queue = deque()
    points = np.array(np.where(label > 0)).transpose((1, 0))

    # Seed the frontier with every surviving labelled pixel.
    for point_idx in range(points.shape[0]):
        x, y = points[point_idx, 0], points[point_idx, 1]
        lbl = label[x, y]
        queue.append((x, y, lbl))
        pred[x, y] = lbl

    # 4-neighbourhood offsets along axis 0 (x) and axis 1 (y).
    offsets = ((-1, 0), (1, 0), (0, -1), (0, 1))
    for kernal_idx in range(kernal_num - 2, -1, -1):
        # The kernel is only read, so no copy is needed.
        kernal = kernals[kernal_idx]
        rows, cols = kernal.shape[0], kernal.shape[1]
        while queue:
            x, y, lbl = queue.popleft()

            is_edge = True
            for dx, dy in offsets:
                tmpx = x + dx
                tmpy = y + dy
                if tmpx < 0 or tmpx >= rows or tmpy < 0 or tmpy >= cols:
                    continue
                if kernal[tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
                    continue

                queue.append((tmpx, tmpy, lbl))
                pred[tmpx, tmpy] = lbl
                is_edge = False
            # Pixels that could not grow in this kernel are the starting
            # frontier for the next (larger) kernel.
            if is_edge:
                next_queue.append((x, y, lbl))

        queue, next_queue = next_queue, queue

    return pred
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | /// glibc defines I as a macro which breaks things, e.g., boost template names 16 | #ifdef I 17 | # undef I 18 | #endif 19 | 20 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 21 | 22 | template struct format_descriptor, detail::enable_if_t::value>> { 23 | static constexpr const char c = format_descriptor::c; 24 | static constexpr const char value[3] = { 'Z', c, '\0' }; 25 | static std::string format() { return std::string(value); } 26 | }; 27 | 28 | #ifndef PYBIND11_CPP17 29 | 30 | template constexpr const char format_descriptor< 31 | std::complex, detail::enable_if_t::value>>::value[3]; 32 | 33 | #endif 34 | 35 | NAMESPACE_BEGIN(detail) 36 | 37 | template struct is_fmt_numeric, detail::enable_if_t::value>> { 38 | static constexpr bool value = true; 39 | static constexpr int index = is_fmt_numeric::index + 3; 40 | }; 41 | 42 | template class type_caster> { 43 | public: 44 | bool load(handle src, bool convert) { 45 | if (!src) 46 | return false; 47 | if (!convert && !PyComplex_Check(src.ptr())) 48 | return false; 49 | Py_complex result = PyComplex_AsCComplex(src.ptr()); 50 | if (result.real == -1.0 && PyErr_Occurred()) { 51 | PyErr_Clear(); 52 | return false; 53 | } 54 | value = std::complex((T) result.real, (T) result.imag); 55 | return true; 56 | } 57 | 58 | static handle cast(const std::complex &src, return_value_policy /* policy */, handle /* parent */) { 59 | return PyComplex_FromDoubles((double) src.real(), (double) src.imag()); 60 | } 61 | 62 | PYBIND11_TYPE_CASTER(std::complex, _("complex")); 63 | }; 64 | NAMESPACE_END(detail) 65 | NAMESPACE_END(PYBIND11_NAMESPACE) 66 | -------------------------------------------------------------------------------- /pse/include/pybind11/options.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/options.h: global settings that are configurable at runtime. 
3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "detail/common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | 16 | class options { 17 | public: 18 | 19 | // Default RAII constructor, which leaves settings as they currently are. 20 | options() : previous_state(global_state()) {} 21 | 22 | // Class is non-copyable. 23 | options(const options&) = delete; 24 | options& operator=(const options&) = delete; 25 | 26 | // Destructor, which restores settings that were in effect before. 27 | ~options() { 28 | global_state() = previous_state; 29 | } 30 | 31 | // Setter methods (affect the global state): 32 | 33 | options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; } 34 | 35 | options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; } 36 | 37 | options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; } 38 | 39 | options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; } 40 | 41 | // Getter methods (return the global state): 42 | 43 | static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; } 44 | 45 | static bool show_function_signatures() { return global_state().show_function_signatures; } 46 | 47 | // This type is not meant to be allocated on the heap. 48 | void* operator new(size_t) = delete; 49 | 50 | private: 51 | 52 | struct state { 53 | bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings. 54 | bool show_function_signatures = true; //< Include auto-generated function signatures in docstrings. 
55 | }; 56 | 57 | static state &global_state() { 58 | static state instance; 59 | return instance; 60 | } 61 | 62 | state previous_state; 63 | }; 64 | 65 | NAMESPACE_END(PYBIND11_NAMESPACE) 66 | -------------------------------------------------------------------------------- /util/mask.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | import util 5 | from util import nb as neighbour 6 | 7 | 8 | def find_white_components(mask, min_area = 0): 9 | mask = (mask == 0) * 1 10 | return find_black_components(mask, min_area); 11 | 12 | def find_black_components(mask, min_area = 0): 13 | """ 14 | find components of zeros. 15 | mask is a 0-1 matrix, ndarray. 16 | """ 17 | neighbour_type = neighbour.N4 18 | visited = mask.copy() 19 | c_mask = util.img.black(mask) 20 | 21 | root_idx = [1] 22 | def get_new_root(): 23 | root_idx[0] += 1 24 | return root_idx[0] 25 | 26 | def is_visited(xy): 27 | x, y = xy 28 | return visited[y][x] 29 | 30 | def set_visited(xy): 31 | x, y = xy 32 | visited[y][x] = 255 33 | 34 | def set_root(xy, root): 35 | x, y = xy 36 | c_mask[y][x] = root 37 | 38 | def get_root(xy): 39 | x, y = xy 40 | return c_mask[y][x] 41 | 42 | rows, cols = np.shape(mask) 43 | q = [] 44 | for y in xrange(rows): 45 | for x in xrange(cols): 46 | xy = (x, y) 47 | if is_visited(xy): 48 | continue 49 | 50 | q.append(xy) 51 | new_root = get_new_root() 52 | while len(q) > 0: 53 | cp = q.pop() 54 | set_root(cp, new_root) 55 | set_visited(cp) 56 | nbs = neighbour.get_neighbours(cp[0], cp[1], cols, rows, neighbour_type) 57 | for nb in nbs: 58 | if not is_visited(nb) and nb not in q: 59 | # q.append(nb) 60 | q.insert(0, nb) 61 | 62 | components = {} 63 | for y in xrange(rows): 64 | for x in xrange(cols): 65 | root = get_root((x, y)) 66 | if root == 0: 67 | continue 68 | 69 | if root not in components: 70 | components[root] = [] 71 | 72 | components[root].append((x,y)) 73 | 74 | ret = [] 75 | 76 | for root 
def init_params(net):
    '''Init layer parameters in place.

    Walks every sub-module of ``net`` and re-initializes:
      - nn.Conv2d:      weight ~ Kaiming normal (mode='fan_out'), bias = 0
      - nn.BatchNorm2d: weight = 1, bias = 0
      - nn.Linear:      weight ~ normal(std=1e-3), bias = 0

    Modules of any other type are left untouched. Returns None.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # Compare against None: truth-testing a multi-element tensor raises.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # With affine=False both parameters are None.
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)
def join(arr, splitter=','):
    """Concatenate the strings in arr, placing splitter between neighbours.

    arr: iterable of str.
    splitter: str inserted between consecutive elements.
    Returns the joined str; an empty arr yields ''.
    """
    return splitter.join(arr)
def replace_all(s, old, new, reg = False):
    """Replace every occurrence of old in s with new.

    s: the str to operate on.
    old: a literal substring, or a regular expression when reg is True.
    new: the replacement text; always inserted verbatim (backslashes and
         group references in it are not interpreted).
    reg: treat old as a regular expression.
    Returns the resulting str.
    """
    if not reg:
        return s.replace(old, new)
    import re
    # Substitute match-by-match. A callable replacement keeps `new` literal.
    return re.sub(old, lambda _match: new, s)
def get_neighbours(x, y, w, h, neighbour_type):
    """
    Look up the neighbour function registered for neighbour_type in _dict1
    and return its result for pixel (x, y) in a w-by-h image.

    Raises NotImplementedError when neighbour_type is not registered.
    """
    finder = _dict1.get(neighbour_type)
    if finder is None:
        raise NotImplementedError("unknown neighbour type '%s'" % (neighbour_type))
    return finder(x, y, w, h)
def draw_log(log_path, output_names, show = False, save_path = None, from_to = None, smooth = False):
    """Plot selected "Train net output" values found in a training log.

    log_path: path of the log file; resolved with util.io.get_absolute_path.
    output_names: names to look for; a line counts for a name when it
        contains ' <name> ='.
    show: call util.plt.show() after plotting.
    save_path: when not None, the figure is saved there via util.plt.save_image.
    from_to: optional (start, end) pair used to slice both the iteration
        list and each output series before plotting.
    smooth: pass each series through util.np.smooth before plotting.

    Returns None. If none of output_names occurs in the log, a message is
    printed and nothing is plotted.
    """
    log_path = util.io.get_absolute_path(log_path)
    iterations = []
    outputs = {}
    plt = util.plt.plt
    # Stays None until the first 'Iteration ... loss = ' line has been seen.
    iter_num = None
    # The context manager closes the log even if parsing fails.
    with open(log_path, 'r') as f:
        for line in f:
            if util.str.contains(line, 'Iteration') and util.str.contains(line, 'loss = '):
                print (line)
                s = line.split('Iteration')[-1]
                iter_num = int(util.str.find_all(s, r'\d+')[0])
                iterations.append(iter_num)

            if util.str.contains(line, "Train net output #"):
                s = util.str.split(line, r'Train net output #\d+\:')[-1]
                s = s.split('(')[0]
                # Last number before the first '(' is the value. Parse it with
                # float() rather than eval(): the text comes from a file.
                output = float(util.str.find_all(s, r'\d*\.*\d+e*\-*\d*\.*\d*')[-1])
                for name in output_names:
                    ptr = ' '+ name + ' ='
                    if util.str.contains(line, ptr):
                        print (line)
                        print ('\t', iter_num, name, output)
                        outputs.setdefault(name, []).append(output)
    if not outputs:
        print ('No output named:', output_names)
        return
    for name, output in outputs.items():
        if smooth:
            output = util.np.smooth(output)
        start = 0
        end = len(output)

        if from_to is not None:
            start = from_to[0]
            end = from_to[1]
        line_style = util.plt.get_random_line_style()
        plt.plot(iterations[start: end], output[start: end], line_style, label = name)

    plt.legend()

    if save_path is not None:
        util.plt.save_image(save_path)
    if show:
        util.plt.show()
class DynamicNet(torch.nn.Module):
    """Three-layer perceptron whose middle layer is applied a random number of times."""

    def __init__(self, D_in, H, D_out):
        """
        Build the three nn.Linear layers used by forward():
        D_in -> H (input), H -> H (middle, reusable) and H -> D_out (output).
        """
        super().__init__()
        make_linear = torch.nn.Linear
        self.input_linear = make_linear(D_in, H)
        self.middle_linear = make_linear(H, H)
        self.output_linear = make_linear(H, D_out)

    def forward(self, x):
        """
        Project x to the hidden size and clamp at zero, then draw an integer
        between 0 and 3 (inclusive) and push the hidden activations through
        middle_linear + clamp that many times before the output projection.

        The same middle_linear module (and hence the same weights) is reused
        on every repetition; the depth is re-drawn on each call.
        """
        hidden = self.input_linear(x).clamp(min=0)
        repeats = random.randint(0, 3)
        for _ in range(repeats):
            hidden = self.middle_linear(hidden).clamp(min=0)
        return self.output_linear(hidden)
40 | N, D_in, H, D_out = 64, 1000, 100, 10 41 | 42 | # Create random Tensors to hold inputs and outputs 43 | x = torch.randn(N, D_in) 44 | y = torch.randn(N, D_out) 45 | 46 | # Construct our model by instantiating the class defined above 47 | model = DynamicNet(D_in, H, D_out) 48 | 49 | # Construct our loss function and an Optimizer. Training this strange model with 50 | # vanilla stochastic gradient descent is tough, so we use momentum 51 | criterion = torch.nn.MSELoss(reduction='sum') 52 | optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9) 53 | for t in range(500): 54 | # Forward pass: Compute predicted y by passing x to the model 55 | y_pred = model(x) 56 | 57 | # Compute and print loss 58 | loss = criterion(y_pred, y) 59 | if t % 100 == 99: 60 | print(t, loss.item()) 61 | 62 | # Zero gradients, perform a backward pass, and update the weights. 63 | optimizer.zero_grad() 64 | loss.backward() 65 | optimizer.step() -------------------------------------------------------------------------------- /MobileNetV2/README.md: -------------------------------------------------------------------------------- 1 | # MobileNet v2 2 | A Python 3 and Keras 2 implementation of MobileNet V2 and provide train method. 3 | 4 | According to the paper: [Inverted Residuals and Linear Bottlenecks Mobile Networks for Classification, Detection and Segmentation](https://arxiv.org/abs/1801.04381). 5 | 6 | 7 | ## Requirement 8 | - OpenCV 3.4 9 | - Python 3.5 10 | - Tensorflow-gpu 1.5.0 11 | - Keras 2.2 12 | 13 | 14 | ## MobileNet v2 and inverted residual block architectures 15 | 16 | **MobileNet v2:** 17 | 18 | Each line describes a sequence of 1 or more identical (modulo stride) layers, repeated n times. All layers in the same sequence have the same number c of output channels. The first layer of each sequence has a stride s and all others use stride 1. All spatial convolutions use 3 X 3 kernels. The expansion factor t is always applied to the input size. 
19 | 20 | ![MobileNetV2](/images/net.jpg) 21 | 22 | **Bottleneck Architectures:** 23 | 24 | ![residual block architectures](/images/stru.jpg) 25 | 26 | 27 | ## Train the model 28 | 29 | The recommended size of the image in the paper is 224 * 224. The ```data/convert.py``` file provides a demo of resizing the CIFAR-100 dataset to this size. 30 | 31 | **The dataset folder structure is as follows:** 32 | 33 | | - data/ 34 | | - train/ 35 | | - class 0/ 36 | | - image.jpg 37 | .... 38 | | - class 1/ 39 | .... 40 | | - class n/ 41 | | - validation/ 42 | | - class 0/ 43 | | - class 1/ 44 | .... 45 | | - class n/ 46 | 47 | **Run the command below to train the model:** 48 | 49 | ``` 50 | python train.py --classes num_classes --batch batch_size --epochs epochs --size image_size 51 | ``` 52 | 53 | The ```.h5``` weight file is saved in the model folder. If you want to fine-tune the trained model, you can run the following command. Note, however, that the size of the input image must be consistent with the original model. 54 | 55 | ``` 56 | python train.py --classes num_classes --batch batch_size --epochs epochs --size image_size --weights weights_path --tclasses pre_classes 57 | ``` 58 | 59 | **Parameter explanation** 60 | 61 | - --classes, The number of classes in the dataset. 62 | - --size, The image size of the training samples. 63 | - --batch, The number of training samples per batch. 64 | - --epochs, The number of training epochs. 65 | - --weights, Path to the weights to fine-tune from. 66 | - --tclasses, The number of classes of the pre-trained model. 67 | 68 | 69 | ## Reference 70 | 71 | @article{MobileNetv2, 72 | title={Inverted Residuals and Linear Bottlenecks Mobile Networks for Classification, Detection and Segmentation}, 73 | author={Mark Sandler and Andrew Howard and Menglong Zhu and Andrey Zhmoginov and Liang-Chieh Chen}, 74 | journal={arXiv preprint arXiv:1801.04381}, 75 | year={2018} 76 | } 77 | 78 | 79 | ## Copyright 80 | See [LICENSE](LICENSE) for details.
def get_img(img_path):
    """Read the image at img_path and return it with its channel axis reversed.

    The array returned by cv2.imread is re-indexed with [2, 1, 0] on the last
    axis, i.e. the first and third channels are swapped.

    Raises IOError when cv2.imread cannot read the file; any other exception
    raised while loading is propagated unchanged. In both cases the offending
    path is printed first.
    """
    try:
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports an unreadable file by returning None.
            raise IOError("cannot read image '%s'" % img_path)
        img = img[:, :, [2, 1, 0]]
    except Exception:
        print (img_path)
        raise
    return img
def get_gt(path):
    """Parse a ground-truth annotation file into a list of polygons.

    Each non-blank line is split on ','. Fields 0 and 1 are read as the
    integers x1 and y1; fields 4..31 are read as 28 integers (14 x/y pairs)
    to which x1 is added at even positions and y1 at odd positions.

    path: annotation file, read through file_util.read_file.
    Returns a list with one length-28 integer ndarray per annotated line.
    """
    lines = file_util.read_file(path).split('\n')
    bboxes = []
    for line in lines:
        # Skip blank lines, including ones holding only '\r' or spaces.
        if line.strip() == '':
            continue
        gt = line.split(',')

        # Builtin int replaces np.int, an alias newer NumPy no longer has.
        x1 = int(gt[0])
        y1 = int(gt[1])

        bbox = [int(gt[i]) for i in range(4, 32)]
        # Shift the stored coordinates by the (x1, y1) origin.
        bbox = np.asarray(bbox) + ([x1, y1] * 14)

        bboxes.append(bbox)
    return bboxes
if __name__ == '__main__':
    # Minimum intersection-over-union for a prediction to match a ground truth.
    th = 0.5
    pred_list = file_util.read_dir(pred_root)

    tp, fp, npos = 0, 0, 0

    for pred_path in pred_list:
        preds = get_pred(pred_path)
        # The ground-truth file has the same base name as the prediction file.
        gt_path = gt_root + pred_path.split('/')[-1]
        gts = get_gt(gt_path)
        npos += len(gts)

        # Build each ground-truth polygon once per image, not once per prediction.
        gt_polys = []
        for gt in gts:
            gt = np.array(gt)
            gt_polys.append(plg.Polygon(gt.reshape(gt.shape[0] // 2, 2)))

        # Ground truths that have already been claimed by some prediction.
        cover = set()
        for pred in preds:
            pred = np.array(pred)
            pred = pred.reshape(pred.shape[0] // 2, 2)
            pred_p = plg.Polygon(pred)

            flag = False
            for gt_id, gt_p in enumerate(gt_polys):
                union = get_union(pred_p, gt_p)
                inter = get_intersection(pred_p, gt_p)

                # A zero-area union (degenerate polygons) can never match.
                if union > 0 and inter * 1.0 / union >= th:
                    if gt_id not in cover:
                        flag = True
                        cover.add(gt_id)
            if flag:
                tp += 1.0
            else:
                fp += 1.0

    print(tp, fp, npos)
    # With no predictions or no ground truth, report 0 instead of dividing by zero.
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / npos if npos > 0 else 0
    hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall)

    print('p: %.4f, r: %.4f, f: %.4f'%(precision, recall, hmean))
class BottleNeck(tf.keras.layers.Layer):
    """Bottleneck layer: 1x1 conv -> depthwise conv -> optional SEBlock -> 1x1 conv.

    in_size / out_size: input and output channel counts; a residual add is
        applied only when they are equal and the stride is 1.
    exp_size: number of filters of the first 1x1 convolution (also the
        channel count handed to the SEBlock).
    s: stride of the depthwise convolution.
    is_se_existing: apply the SEBlock after the depthwise stage when true.
    NL: "HS" selects h_swish, "RE" selects relu6; any other value applies
        no activation.
    k: kernel size (k x k) of the depthwise convolution.
    """

    def __init__(self, in_size, exp_size, out_size, s, is_se_existing, NL, k):
        super().__init__()
        layers = tf.keras.layers
        self.stride = s
        self.in_size = in_size
        self.out_size = out_size
        self.is_se_existing = is_se_existing
        self.NL = NL
        # Sub-layers are created in a fixed order: expand, depthwise, SE, project.
        self.conv1 = layers.Conv2D(filters=exp_size, kernel_size=(1, 1), strides=1, padding="same")
        self.bn1 = layers.BatchNormalization()
        self.dwconv = layers.DepthwiseConv2D(kernel_size=(k, k), strides=s, padding="same")
        self.bn2 = layers.BatchNormalization()
        self.se = SEBlock(input_channels=exp_size)
        self.conv2 = layers.Conv2D(filters=out_size, kernel_size=(1, 1), strides=1, padding="same")
        self.bn3 = layers.BatchNormalization()
        self.linear = layers.Activation(tf.keras.activations.linear)

    def _nonlinearity(self, x):
        # Activation selected by self.NL; unknown values leave x unchanged.
        if self.NL == "HS":
            return h_swish(x)
        if self.NL == "RE":
            return tf.nn.relu6(x)
        return x

    def call(self, inputs, training=None, **kwargs):
        # Expansion stage.
        out = self.bn1(self.conv1(inputs), training=training)
        out = self._nonlinearity(out)
        # Depthwise stage.
        out = self.bn2(self.dwconv(out), training=training)
        out = self._nonlinearity(out)
        if self.is_se_existing:
            out = self.se(out)
        # Projection stage with linear activation.
        out = self.bn3(self.conv2(out), training=training)
        out = self.linear(out)

        # Residual connection only when input and output tensors line up.
        if self.stride == 1 and self.in_size == self.out_size:
            return tf.keras.layers.add([out, inputs])
        return out
def print_calling(fn):
    """Decorator that logs each call of fn and how long it took.

    An info record is emitted before the call and a debug record with the
    elapsed wall-clock seconds after it. The wrapped function's return value
    is passed through unchanged, and its name and docstring are preserved on
    the returned wrapper.
    """
    import functools

    @functools.wraps(fn)
    def wrapper(*args1, ** args2):
        s = "calling function %s"%(fn.__name__)
        logging.info(s)
        start = time.time()
        ret = fn(*args1, **args2)
        end = time.time()
        s = "function [%s] has been called, taking %f seconds"%(fn.__name__, (end - start))
        logging.debug(s)
        return ret
    return wrapper
def timeit(fn):
    """Decorator that accumulates per-function timing statistics and logs them.

    Each call adds its elapsed wall-clock time to the module-level `counter`
    and increments `count_times`, both keyed by fn.__name__. After every call
    two info records are logged for each function recorded in `counter`: its
    share of the total recorded time with its cumulative seconds, and its
    call count with the average seconds per call.

    The wrapped function's return value is passed through unchanged, and its
    name and docstring are preserved on the returned wrapper.
    """
    import functools

    @functools.wraps(fn)
    def wrapper(*args1, ** args2):
        start = time.time()
        ret = fn(*args1, **args2)
        end = time.time()
        counter[fn.__name__] = counter[fn.__name__] + (end - start)
        count_times[fn.__name__] += 1
        all_time = sum([counter[name] for name in counter]) * 1.0
        for name in counter:
            # Coarse clocks can report a total of 0.0; avoid dividing by it.
            share = counter[name] / all_time if all_time > 0 else 0.0
            logging.info('\t %s: %f, %f seconds'%(name, share, counter[name]))
            logging.info('\t %s: %d callings, %f seconds per calling'%(name, count_times[name], counter[name] * 1.0 / count_times[name]))
        return ret
    return wrapper
exp_size=240, out_size=40, s=1, is_se_existing=True, NL="HS", k=5) 20 | self.bneck6 = BottleNeck(in_size=40, exp_size=240, out_size=40, s=1, is_se_existing=True, NL="HS", k=5) 21 | self.bneck7 = BottleNeck(in_size=40, exp_size=120, out_size=48, s=1, is_se_existing=True, NL="HS", k=5) 22 | self.bneck8 = BottleNeck(in_size=48, exp_size=144, out_size=48, s=1, is_se_existing=True, NL="HS", k=5) 23 | self.bneck9 = BottleNeck(in_size=48, exp_size=288, out_size=96, s=2, is_se_existing=True, NL="HS", k=5) 24 | self.bneck10 = BottleNeck(in_size=96, exp_size=576, out_size=96, s=1, is_se_existing=True, NL="HS", k=5) 25 | self.bneck11 = BottleNeck(in_size=96, exp_size=576, out_size=96, s=1, is_se_existing=True, NL="HS", k=5) 26 | 27 | self.conv2 = tf.keras.layers.Conv2D(filters=576, 28 | kernel_size=(1, 1), 29 | strides=1, 30 | padding="same") 31 | self.bn2 = tf.keras.layers.BatchNormalization() 32 | self.avgpool = tf.keras.layers.AveragePooling2D(pool_size=(7, 7), 33 | strides=1) 34 | self.conv3 = tf.keras.layers.Conv2D(filters=1280, 35 | kernel_size=(1, 1), 36 | strides=1, 37 | padding="same") 38 | self.conv4 = tf.keras.layers.Conv2D(filters=NUM_CLASSES, 39 | kernel_size=(1, 1), 40 | strides=1, 41 | padding="same", 42 | activation=tf.keras.activations.softmax) 43 | 44 | def call(self, inputs, training=None, mask=None): 45 | x = self.conv1(inputs) 46 | x = self.bn1(x, training=training) 47 | x = h_swish(x) 48 | 49 | x = self.bneck1(x, training=training) 50 | x = self.bneck2(x, training=training) 51 | x = self.bneck3(x, training=training) 52 | x = self.bneck4(x, training=training) 53 | x = self.bneck5(x, training=training) 54 | x = self.bneck6(x, training=training) 55 | x = self.bneck7(x, training=training) 56 | x = self.bneck8(x, training=training) 57 | x = self.bneck9(x, training=training) 58 | x = self.bneck10(x, training=training) 59 | x = self.bneck11(x, training=training) 60 | 61 | x = self.conv2(x) 62 | x = self.bn2(x, training=training) 63 | x = h_swish(x) 64 | x = 
self.avgpool(x) 65 | x = self.conv3(x) 66 | x = h_swish(x) 67 | x = self.conv4(x) 68 | 69 | return x 70 | 71 | 72 | if __name__ == '__main__': 73 | model = MobileNetV3Small() 74 | model.build(input_shape=(None, 224, 224, 3)) 75 | model.summary() 76 | 77 | -------------------------------------------------------------------------------- /pse/include/pybind11/detail/descr.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | NAMESPACE_BEGIN(detail) 16 | 17 | #if !defined(_MSC_VER) 18 | # define PYBIND11_DESCR_CONSTEXPR static constexpr 19 | #else 20 | # define PYBIND11_DESCR_CONSTEXPR const 21 | #endif 22 | 23 | /* Concatenate type signatures at compile time */ 24 | template 25 | struct descr { 26 | char text[N + 1]; 27 | 28 | constexpr descr() : text{'\0'} { } 29 | constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence()) { } 30 | 31 | template 32 | constexpr descr(char const (&s)[N+1], index_sequence) : text{s[Is]..., '\0'} { } 33 | 34 | template 35 | constexpr descr(char c, Chars... 
cs) : text{c, static_cast(cs)..., '\0'} { } 36 | 37 | static constexpr std::array types() { 38 | return {{&typeid(Ts)..., nullptr}}; 39 | } 40 | }; 41 | 42 | template 43 | constexpr descr plus_impl(const descr &a, const descr &b, 44 | index_sequence, index_sequence) { 45 | return {a.text[Is1]..., b.text[Is2]...}; 46 | } 47 | 48 | template 49 | constexpr descr operator+(const descr &a, const descr &b) { 50 | return plus_impl(a, b, make_index_sequence(), make_index_sequence()); 51 | } 52 | 53 | template 54 | constexpr descr _(char const(&text)[N]) { return descr(text); } 55 | constexpr descr<0> _(char const(&)[1]) { return {}; } 56 | 57 | template struct int_to_str : int_to_str { }; 58 | template struct int_to_str<0, Digits...> { 59 | static constexpr auto digits = descr(('0' + Digits)...); 60 | }; 61 | 62 | // Ternary description (like std::conditional) 63 | template 64 | constexpr enable_if_t> _(char const(&text1)[N1], char const(&)[N2]) { 65 | return _(text1); 66 | } 67 | template 68 | constexpr enable_if_t> _(char const(&)[N1], char const(&text2)[N2]) { 69 | return _(text2); 70 | } 71 | 72 | template 73 | constexpr enable_if_t _(const T1 &d, const T2 &) { return d; } 74 | template 75 | constexpr enable_if_t _(const T1 &, const T2 &d) { return d; } 76 | 77 | template auto constexpr _() -> decltype(int_to_str::digits) { 78 | return int_to_str::digits; 79 | } 80 | 81 | template constexpr descr<1, Type> _() { return {'%'}; } 82 | 83 | constexpr descr<0> concat() { return {}; } 84 | 85 | template 86 | constexpr descr concat(const descr &descr) { return descr; } 87 | 88 | template 89 | constexpr auto concat(const descr &d, const Args &...args) 90 | -> decltype(std::declval>() + concat(args...)) { 91 | return d + _(", ") + concat(args...); 92 | } 93 | 94 | template 95 | constexpr descr type_descr(const descr &descr) { 96 | return _("{") + descr + _("}"); 97 | } 98 | 99 | NAMESPACE_END(detail) 100 | NAMESPACE_END(PYBIND11_NAMESPACE) 101 | 
-------------------------------------------------------------------------------- /pse/include/pybind11/functional.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/functional.h: std::function<> support 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | #include 14 | 15 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 16 | NAMESPACE_BEGIN(detail) 17 | 18 | template 19 | struct type_caster> { 20 | using type = std::function; 21 | using retval_type = conditional_t::value, void_type, Return>; 22 | using function_type = Return (*) (Args...); 23 | 24 | public: 25 | bool load(handle src, bool convert) { 26 | if (src.is_none()) { 27 | // Defer accepting None to other overloads (if we aren't in convert mode): 28 | if (!convert) return false; 29 | return true; 30 | } 31 | 32 | if (!isinstance(src)) 33 | return false; 34 | 35 | auto func = reinterpret_borrow(src); 36 | 37 | /* 38 | When passing a C++ function as an argument to another C++ 39 | function via Python, every function call would normally involve 40 | a full C++ -> Python -> C++ roundtrip, which can be prohibitive. 41 | Here, we try to at least detect the case where the function is 42 | stateless (i.e. function pointer or lambda function without 43 | captured variables), in which case the roundtrip can be avoided. 
44 | */ 45 | if (auto cfunc = func.cpp_function()) { 46 | auto c = reinterpret_borrow(PyCFunction_GET_SELF(cfunc.ptr())); 47 | auto rec = (function_record *) c; 48 | 49 | if (rec && rec->is_stateless && 50 | same_type(typeid(function_type), *reinterpret_cast(rec->data[1]))) { 51 | struct capture { function_type f; }; 52 | value = ((capture *) &rec->data)->f; 53 | return true; 54 | } 55 | } 56 | 57 | // ensure GIL is held during functor destruction 58 | struct func_handle { 59 | function f; 60 | func_handle(function&& f_) : f(std::move(f_)) {} 61 | func_handle(const func_handle&) = default; 62 | ~func_handle() { 63 | gil_scoped_acquire acq; 64 | function kill_f(std::move(f)); 65 | } 66 | }; 67 | 68 | // to emulate 'move initialization capture' in C++11 69 | struct func_wrapper { 70 | func_handle hfunc; 71 | func_wrapper(func_handle&& hf): hfunc(std::move(hf)) {} 72 | Return operator()(Args... args) const { 73 | gil_scoped_acquire acq; 74 | object retval(hfunc.f(std::forward(args)...)); 75 | /* Visual studio 2015 parser issue: need parentheses around this expression */ 76 | return (retval.template cast()); 77 | } 78 | }; 79 | 80 | value = func_wrapper(func_handle(std::move(func))); 81 | return true; 82 | } 83 | 84 | template 85 | static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) { 86 | if (!f_) 87 | return none().inc_ref(); 88 | 89 | auto result = f_.template target(); 90 | if (result) 91 | return cpp_function(*result, policy).release(); 92 | else 93 | return cpp_function(std::forward(f_), policy).release(); 94 | } 95 | 96 | PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster::name...) 
+ _("], ") 97 | + make_caster::name + _("]")); 98 | }; 99 | 100 | NAMESPACE_END(detail) 101 | NAMESPACE_END(PYBIND11_NAMESPACE) 102 | -------------------------------------------------------------------------------- /pse/include/pybind11/eval.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/exec.h: Support for evaluating Python expressions and statements 3 | from strings and files 4 | 5 | Copyright (c) 2016 Klemens Morgenstern and 6 | Wenzel Jakob 7 | 8 | All rights reserved. Use of this source code is governed by a 9 | BSD-style license that can be found in the LICENSE file. 10 | */ 11 | 12 | #pragma once 13 | 14 | #include "pybind11.h" 15 | 16 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 17 | 18 | enum eval_mode { 19 | /// Evaluate a string containing an isolated expression 20 | eval_expr, 21 | 22 | /// Evaluate a string containing a single statement. Returns \c none 23 | eval_single_statement, 24 | 25 | /// Evaluate a string containing a sequence of statement. 
Returns \c none 26 | eval_statements 27 | }; 28 | 29 | template 30 | object eval(str expr, object global = globals(), object local = object()) { 31 | if (!local) 32 | local = global; 33 | 34 | /* PyRun_String does not accept a PyObject / encoding specifier, 35 | this seems to be the only alternative */ 36 | std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr; 37 | 38 | int start; 39 | switch (mode) { 40 | case eval_expr: start = Py_eval_input; break; 41 | case eval_single_statement: start = Py_single_input; break; 42 | case eval_statements: start = Py_file_input; break; 43 | default: pybind11_fail("invalid evaluation mode"); 44 | } 45 | 46 | PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr()); 47 | if (!result) 48 | throw error_already_set(); 49 | return reinterpret_steal(result); 50 | } 51 | 52 | template 53 | object eval(const char (&s)[N], object global = globals(), object local = object()) { 54 | /* Support raw string literals by removing common leading whitespace */ 55 | auto expr = (s[0] == '\n') ? 
str(module::import("textwrap").attr("dedent")(s)) 56 | : str(s); 57 | return eval(expr, global, local); 58 | } 59 | 60 | inline void exec(str expr, object global = globals(), object local = object()) { 61 | eval(expr, global, local); 62 | } 63 | 64 | template 65 | void exec(const char (&s)[N], object global = globals(), object local = object()) { 66 | eval(s, global, local); 67 | } 68 | 69 | template 70 | object eval_file(str fname, object global = globals(), object local = object()) { 71 | if (!local) 72 | local = global; 73 | 74 | int start; 75 | switch (mode) { 76 | case eval_expr: start = Py_eval_input; break; 77 | case eval_single_statement: start = Py_single_input; break; 78 | case eval_statements: start = Py_file_input; break; 79 | default: pybind11_fail("invalid evaluation mode"); 80 | } 81 | 82 | int closeFile = 1; 83 | std::string fname_str = (std::string) fname; 84 | #if PY_VERSION_HEX >= 0x03040000 85 | FILE *f = _Py_fopen_obj(fname.ptr(), "r"); 86 | #elif PY_VERSION_HEX >= 0x03000000 87 | FILE *f = _Py_fopen(fname.ptr(), "r"); 88 | #else 89 | /* No unicode support in open() :( */ 90 | auto fobj = reinterpret_steal(PyFile_FromString( 91 | const_cast(fname_str.c_str()), 92 | const_cast("r"))); 93 | FILE *f = nullptr; 94 | if (fobj) 95 | f = PyFile_AsFile(fobj.ptr()); 96 | closeFile = 0; 97 | #endif 98 | if (!f) { 99 | PyErr_Clear(); 100 | pybind11_fail("File \"" + fname_str + "\" could not be opened!"); 101 | } 102 | 103 | #if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION) 104 | PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(), 105 | local.ptr()); 106 | (void) closeFile; 107 | #else 108 | PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(), 109 | local.ptr(), closeFile); 110 | #endif 111 | 112 | if (!result) 113 | throw error_already_set(); 114 | return reinterpret_steal(result); 115 | } 116 | 117 | NAMESPACE_END(PYBIND11_NAMESPACE) 118 | 
# -------------------------------------------------------------------------------- /MobileNetV3/mobilenet_v3_large.py: --------------------------------------------------------------------------------
import tensorflow as tf
from mobilenet_v3_block import BottleNeck, h_swish

NUM_CLASSES = 10

# One row per bottleneck stage, in execution order. Each row is forwarded
# unchanged to BottleNeck as
#   (in_size, exp_size, out_size, s, is_se_existing, NL, k).
# NOTE(review): "RE" / "HS" presumably select ReLU / h-swish and
# `is_se_existing` a squeeze-and-excite branch -- confirm in
# mobilenet_v3_block.py.
_BNECK_SPECS = (
    (16, 16, 16, 1, False, "RE", 3),
    (16, 64, 24, 2, False, "RE", 3),
    (24, 72, 24, 1, False, "RE", 3),
    (24, 72, 40, 2, True, "RE", 5),
    (40, 120, 40, 1, True, "RE", 5),
    (40, 120, 40, 1, True, "RE", 5),
    (40, 240, 80, 2, False, "HS", 3),
    (80, 200, 80, 1, False, "HS", 3),
    (80, 184, 80, 1, False, "HS", 3),
    (80, 184, 80, 1, False, "HS", 3),
    (80, 480, 112, 1, True, "HS", 3),
    (112, 672, 112, 1, True, "HS", 3),
    (112, 672, 160, 2, True, "HS", 5),
    (160, 960, 160, 1, True, "HS", 5),
    (160, 960, 160, 1, True, "HS", 5),
)


class MobileNetV3Large(tf.keras.Model):
    """MobileNetV3-Large style classifier built from ``BottleNeck`` stages.

    Layer attributes keep their historical names (``conv1``, ``bn1``,
    ``bneck1`` ... ``bneck15``, ``conv2``, ``bn2``, ``avgpool``, ``conv3``,
    ``conv4``) and are created in the same order as before.

    Args:
        num_classes: number of filters of the final 1x1 softmax convolution.
            Defaults to the module-level ``NUM_CLASSES`` so that
            ``MobileNetV3Large()`` behaves exactly as it always did.
    """

    def __init__(self, num_classes=NUM_CLASSES):
        super(MobileNetV3Large, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(filters=16,
                                            kernel_size=(3, 3),
                                            strides=2,
                                            padding="same")
        self.bn1 = tf.keras.layers.BatchNormalization()

        # Assign through setattr so every stage is still a plain attribute
        # named bneck<i>, exactly like the former hand-written assignments.
        for idx, (in_size, exp_size, out_size, s, use_se, nl, k) in enumerate(_BNECK_SPECS, start=1):
            setattr(self, "bneck%d" % idx,
                    BottleNeck(in_size=in_size, exp_size=exp_size, out_size=out_size,
                               s=s, is_se_existing=use_se, NL=nl, k=k))

        self.conv2 = tf.keras.layers.Conv2D(filters=960,
                                            kernel_size=(1, 1),
                                            strides=1,
                                            padding="same")
        self.bn2 = tf.keras.layers.BatchNormalization()
        self.avgpool = tf.keras.layers.AveragePooling2D(pool_size=(7, 7),
                                                        strides=1)
        self.conv3 = tf.keras.layers.Conv2D(filters=1280,
                                            kernel_size=(1, 1),
                                            strides=1,
                                            padding="same")
        self.conv4 = tf.keras.layers.Conv2D(filters=num_classes,
                                            kernel_size=(1, 1),
                                            strides=1,
                                            padding="same",
                                            activation=tf.keras.activations.softmax)

    def call(self, inputs, training=None, mask=None):
        """Run the forward pass and return the output of ``conv4``.

        ``training`` is forwarded to the batch-norm layers and to every
        bottleneck stage. ``mask`` is accepted for Keras API compatibility
        and is not used.
        """
        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = h_swish(x)

        for idx in range(1, len(_BNECK_SPECS) + 1):
            x = getattr(self, "bneck%d" % idx)(x, training=training)

        x = self.conv2(x)
        x = self.bn2(x, training=training)
        x = h_swish(x)
        x = self.avgpool(x)
        x = self.conv3(x)
        x = h_swish(x)
        x = self.conv4(x)

        return x


if __name__ == '__main__':
    model = MobileNetV3Large()
    model.build(input_shape=(None, 224, 224, 3))
    model.summary()
-------------------------------------------------------------------------------- /pse/adaptor.cpp: -------------------------------------------------------------------------------- 1 | #include "pybind11/pybind11.h" 2 | #include "pybind11/numpy.h" 3 | #include "pybind11/stl.h" 4 | #include "pybind11/stl_bind.h" 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | using namespace std; 15 | using namespace cv; 16 | 17 | namespace py = pybind11; 18 | 19 | namespace pse_adaptor { 20 | void get_kernals(const int *data, vector data_shape, vector &kernals) { 21 | for (int i = 0; i < data_shape[0]; ++i) { 22 | Mat kernal = Mat::zeros(data_shape[1], data_shape[2], CV_8UC1); 23 | for (int x = 0; x < kernal.rows; ++x) { 24 | for (int y = 0; y < kernal.cols; ++y) { 25 | kernal.at(x, y) = data[i * data_shape[1] * data_shape[2] + x * data_shape[2] + y]; 26 | } 27 | } 28 | kernals.emplace_back(kernal); 29 | } 30 | } 31 | 32 | void growing_text_line(vector &kernals, vector> &text_line, float min_area) { 33 | 34 | Mat label_mat; 35 | int label_num = connectedComponents(kernals[kernals.size() - 1], label_mat, 4); 36 | 37 | // cout << "label num: " << label_num << endl; 38 | 39 | int area[label_num + 1]; 40 | memset(area, 0, sizeof(area)); 41 | for (int x = 0; x < label_mat.rows; ++x) { 42 | for (int y = 0; y < label_mat.cols; ++y) { 43 | int label = label_mat.at(x, y); 44 | if (label == 0) continue; 45 | area[label] += 1; 46 | } 47 | } 48 | 49 | queue queue, next_queue; 50 | for (int x = 0; x < label_mat.rows; ++x) { 51 | vector row(label_mat.cols); 52 | for (int y = 0; y < label_mat.cols; ++y) { 53 | int label = label_mat.at(x, y); 54 | 55 | if (label == 0) continue; 56 | if (area[label] < min_area) continue; 57 | 58 | Point point(x, y); 59 | queue.push(point); 60 | row[y] = label; 61 | } 62 | text_line.emplace_back(row); 63 | } 64 | 65 | // cout << "ok" << endl; 66 | 67 | int dx[] = {-1, 1, 0, 0}; 68 | int dy[] = {0, 0, -1, 1}; 69 | 70 
| for (int kernal_id = kernals.size() - 2; kernal_id >= 0; --kernal_id) { 71 | while (!queue.empty()) { 72 | Point point = queue.front(); queue.pop(); 73 | int x = point.x; 74 | int y = point.y; 75 | int label = text_line[x][y]; 76 | // cout << text_line.size() << ' ' << text_line[0].size() << ' ' << x << ' ' << y << endl; 77 | 78 | bool is_edge = true; 79 | for (int d = 0; d < 4; ++d) { 80 | int tmp_x = x + dx[d]; 81 | int tmp_y = y + dy[d]; 82 | 83 | if (tmp_x < 0 || tmp_x >= (int)text_line.size()) continue; 84 | if (tmp_y < 0 || tmp_y >= (int)text_line[1].size()) continue; 85 | if (kernals[kernal_id].at(tmp_x, tmp_y) == 0) continue; 86 | if (text_line[tmp_x][tmp_y] > 0) continue; 87 | 88 | Point point(tmp_x, tmp_y); 89 | queue.push(point); 90 | text_line[tmp_x][tmp_y] = label; 91 | is_edge = false; 92 | } 93 | 94 | if (is_edge) { 95 | next_queue.push(point); 96 | } 97 | } 98 | swap(queue, next_queue); 99 | } 100 | } 101 | 102 | vector> pse(py::array_t quad_n9, float min_area) { 103 | auto buf = quad_n9.request(); 104 | auto data = static_cast(buf.ptr); 105 | vector kernals; 106 | get_kernals(data, buf.shape, kernals); 107 | 108 | // cout << "min_area: " << min_area << endl; 109 | // for (int i = 0; i < kernals.size(); ++i) { 110 | // cout << "kernal" << i <<" shape: " << kernals[i].rows << ' ' << kernals[i].cols << endl; 111 | // } 112 | 113 | vector> text_line; 114 | growing_text_line(kernals, text_line, min_area); 115 | 116 | return text_line; 117 | } 118 | } 119 | 120 | PYBIND11_MODULE(adaptor,m) { 121 | 122 | m.def("pse", &pse_adaptor::pse, "pse"); 123 | 124 | } 125 | -------------------------------------------------------------------------------- /util/np.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import copy 3 | 4 | TINY = np.exp(-100) 5 | concat = np.concatenate 6 | def is_2D(m): 7 | ''' 8 | judge if a matrix is 2-D or not 9 | ''' 10 | return len(np.shape(m)) == 2 11 | 12 | def norm1(v): 
# NOTE: the `def norm1(v):` header sits at the end of the preceding dump line;
# the definition is restated here so that it is complete.
def norm1(v):
    """Return the L1 norm of ``v`` (sum of the absolute values of all entries)."""
    return np.sum(np.abs(v))


def norm2(v):
    """Return the Euclidean (L2) norm of ``v``."""
    return np.sqrt(np.sum(v ** 2))


def norm2_squared(v):
    """Return the squared Euclidean norm of ``v``."""
    return np.sum(v ** 2)


def cos_dist(v1, v2):
    """Return ``dot(v1, v2) / (|v1| * |v2|)``, i.e. the cosine of the angle between them."""
    length1 = norm2(v1)
    length2 = norm2(v2)
    return np.dot(v1, v2) / (length1 * length2)


def eu_dist(v1, v2):
    """Return the Euclidean distance between ``v1`` and ``v2``."""
    return norm2(v1 - v2)


def chi_squared_dist(f1, f2):
    """Return ``sqrt(sum((a - b)**2 / (a + b)))`` over paired entries of ``f1`` and ``f2``.

    Pairs whose sum is 0 are skipped: feature values are supposed to be
    non-negative, so a zero sum means numerator and denominator are both 0.
    """
    dist = 0
    for ff1, ff2 in zip(f1, f2):
        if ff1 + ff2 == 0:
            continue
        dist += (ff1 - ff2) ** 2 * 1.0 / (ff1 + ff2)
    return np.sqrt(dist)


def flatten(arr, ndim=1):
    """Collapse the leading axes of ``arr`` so that the result has ``ndim`` axes.

    The trailing ``ndim - 1`` axes are kept and all axes before them are
    merged into one.
    """
    arr = np.asarray(arr)
    dims = len(arr.shape)
    # np.prod(()) is the float 1.0; reshape needs integer dimensions.
    shape = [int(np.prod(arr.shape[0: dims + 1 - ndim]))]
    shape.extend(arr.shape[dims + 1 - ndim: dims])
    return np.reshape(arr, shape)


def arcsin(sins, xs=None):
    """Return the arcsine of ``sins``.

    Without ``xs`` the result is ``np.arcsin(sins)``, in ``[-pi/2, pi/2]``.
    With ``xs`` (the matching x coordinates) the angle is resolved to
    ``[0, 2*pi)`` using the sign of x.
    """
    arcs = np.arcsin(sins)
    # Identity test: `xs != None` is elementwise for ndarrays and raised
    # "truth value of an array is ambiguous".
    if xs is not None:
        xs = np.asarray(xs)
        sins = np.asarray(sins)
        # x < 0: v -> pi - v ; x >= 0: v unchanged (0 - (-v)).
        add_pi = xs < 0
        pi_mask = add_pi * np.pi
        # False -> -1, True -> 1
        arc_mask = 2 * add_pi - 1
        arcs = pi_mask - arcs * arc_mask

        # x >= 0 and sin < 0: v -> 2*pi + v
        add_2_pi = (xs >= 0) * (sins < 0)
        pi_mask = add_2_pi * 2 * np.pi
        arcs = pi_mask + arcs
    return arcs


def sin(ys=None, lengths=None, xs=None, angles=None):
    """Compute sines from whichever inputs are supplied.

    * ``angles`` given: return ``np.sin(angles)``.
    * otherwise ``ys`` is required and the sine is ``ys / lengths``;
      ``lengths`` defaults to ``sqrt(xs**2 + ys**2)``.

    Entries whose length is not positive yield 0.

    Raises:
        ValueError: if ``ys`` is missing (and no ``angles``), or if both
            ``lengths`` and ``xs`` are missing.
    """
    if angles is not None:
        return np.sin(angles)

    if ys is None:
        raise ValueError('ys must be provided when "angles" is None ')

    if lengths is None:
        if xs is None:
            raise ValueError('xs must be provided when "lengths" is None ')
        lengths = np.sqrt(xs ** 2 + ys ** 2)

    if not np.iterable(lengths):
        sins = ys / lengths if lengths > 0 else 0
    else:
        lengths = np.asarray(lengths)
        shape = lengths.shape
        ys = flatten(ys)
        lengths = flatten(lengths)
        sins = [y / length if length > 0 else 0 for (y, length) in zip(ys, lengths)]
        sins = np.reshape(sins, shape)
    return sins


def sum_all(m):
    """Return the sum of all elements of a (possibly multi-dimensional) array."""
    return np.sum(m)


def clone(obj, deep=False):
    """Return a shallow copy of ``obj``, or a deep copy when ``deep`` is true."""
    if not deep:
        return copy.copy(obj)
    return copy.deepcopy(obj)


def empty_list(length, etype):
    """Return a list of ``length`` independent empty lists.

    Only ``etype == list`` is implemented.

    Raises:
        NotImplementedError: for any other ``etype`` when ``length > 0``.
    """
    if etype != list:
        if length > 0:
            raise NotImplementedError
        return []
    # A comprehension replaces the former Python-2-only `xrange` loop.
    return [[] for _ in range(length)]


def shuffle(arr):
    """Shuffle ``arr`` in place with ``random.shuffle``."""
    import random
    random.shuffle(arr)


def is_empty(a):
    """Return True if ``a`` is None or its shape contains a zero-length axis."""
    if a is None:
        return True

    shape = np.shape(a)
    if np.prod(shape) == 0:
        return True

    return False


def angle_with_x(x, y):
    """Return ``np.arctan2(y, x)``: the angle of the point (x, y) in ``[-pi, pi]``."""
    return np.arctan2(y, x)


def has_infty(x):
    """Return True if ``x`` contains positive infinity.

    Negative infinity is not reported, matching the original comparison.
    """
    # np.inf instead of the alias np.infty (removed in NumPy 2.0); asarray so
    # that plain lists are compared elementwise as well.
    return bool(np.any(np.asarray(x) == np.inf))


def has_nan(x):
    """Return True if ``x`` contains a NaN (detected through ``x != x``)."""
    x = np.asarray(x)
    test = x != x
    return np.sum(test) > 0


def has_nan_or_infty(x):
    """Return True if ``x`` contains NaN or positive infinity, else False."""
    if has_nan(x):
        return True

    if has_infty(x):
        return True

    # Previously fell off the end and returned None.
    return False


def iterable(x):
    """Return ``np.iterable(x)``."""
    return np.iterable(x)


def smooth(arr):
    """Return the running mean of ``arr``: entry i is the mean of ``arr[0..i]``."""
    result = [0] * len(arr)
    s = 0
    for idx, n in enumerate(arr):
        s += n
        result[idx] = s * 1.0 / (idx + 1)
    return result


# --- dump boundary: opening of the next file in the dump, carried over unchanged ---
# -------------------------------------------------------------------------------- /pse/.ycm_extra_conf.py: --------------------------------------------------------------------------------
#!/usr/bin/env python
#
# Copyright (C) 2014 Google Inc.
#
# This file is part of YouCompleteMe.
#
# YouCompleteMe is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# YouCompleteMe is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with YouCompleteMe. If not, see .
19 | 20 | import os 21 | import sys 22 | import glob 23 | import ycm_core 24 | 25 | # These are the compilation flags that will be used in case there's no 26 | # compilation database set (by default, one is not set). 27 | # CHANGE THIS LIST OF FLAGS. YES, THIS IS THE DROID YOU HAVE BEEN LOOKING FOR. 28 | sys.path.append(os.path.dirname(__file__)) 29 | 30 | 31 | BASE_DIR = os.path.dirname(os.path.realpath(__file__)) 32 | 33 | from plumbum.cmd import python_config 34 | 35 | 36 | flags = [ 37 | '-Wall', 38 | '-Wextra', 39 | '-Wnon-virtual-dtor', 40 | '-Winvalid-pch', 41 | '-Wno-unused-local-typedefs', 42 | '-std=c++11', 43 | '-x', 'c++', 44 | '-Iinclude', 45 | ] + python_config('--cflags').split() 46 | 47 | 48 | # Set this to the absolute path to the folder (NOT the file!) containing the 49 | # compile_commands.json file to use that instead of 'flags'. See here for 50 | # more details: http://clang.llvm.org/docs/JSONCompilationDatabase.html 51 | # 52 | # Most projects will NOT need to set this to anything; you can just change the 53 | # 'flags' list of compilation flags. 
54 | compilation_database_folder = '' 55 | 56 | if os.path.exists( compilation_database_folder ): 57 | database = ycm_core.CompilationDatabase( compilation_database_folder ) 58 | else: 59 | database = None 60 | 61 | SOURCE_EXTENSIONS = [ '.cpp', '.cxx', '.cc', '.c', '.m', '.mm' ] 62 | 63 | def DirectoryOfThisScript(): 64 | return os.path.dirname( os.path.abspath( __file__ ) ) 65 | 66 | 67 | def MakeRelativePathsInFlagsAbsolute( flags, working_directory ): 68 | if not working_directory: 69 | return list( flags ) 70 | new_flags = [] 71 | make_next_absolute = False 72 | path_flags = [ '-isystem', '-I', '-iquote', '--sysroot=' ] 73 | for flag in flags: 74 | new_flag = flag 75 | 76 | if make_next_absolute: 77 | make_next_absolute = False 78 | if not flag.startswith( '/' ): 79 | new_flag = os.path.join( working_directory, flag ) 80 | 81 | for path_flag in path_flags: 82 | if flag == path_flag: 83 | make_next_absolute = True 84 | break 85 | 86 | if flag.startswith( path_flag ): 87 | path = flag[ len( path_flag ): ] 88 | new_flag = path_flag + os.path.join( working_directory, path ) 89 | break 90 | 91 | if new_flag: 92 | new_flags.append( new_flag ) 93 | return new_flags 94 | 95 | 96 | def IsHeaderFile( filename ): 97 | extension = os.path.splitext( filename )[ 1 ] 98 | return extension in [ '.h', '.hxx', '.hpp', '.hh' ] 99 | 100 | 101 | def GetCompilationInfoForFile( filename ): 102 | # The compilation_commands.json file generated by CMake does not have entries 103 | # for header files. So we do our best by asking the db for flags for a 104 | # corresponding source file, if any. If one exists, the flags for that file 105 | # should be good enough. 
106 | if IsHeaderFile( filename ): 107 | basename = os.path.splitext( filename )[ 0 ] 108 | for extension in SOURCE_EXTENSIONS: 109 | replacement_file = basename + extension 110 | if os.path.exists( replacement_file ): 111 | compilation_info = database.GetCompilationInfoForFile( 112 | replacement_file ) 113 | if compilation_info.compiler_flags_: 114 | return compilation_info 115 | return None 116 | return database.GetCompilationInfoForFile( filename ) 117 | 118 | 119 | # This is the entry point; this function is called by ycmd to produce flags for 120 | # a file. 121 | def FlagsForFile( filename, **kwargs ): 122 | if database: 123 | # Bear in mind that compilation_info.compiler_flags_ does NOT return a 124 | # python list, but a "list-like" StringVec object 125 | compilation_info = GetCompilationInfoForFile( filename ) 126 | if not compilation_info: 127 | return None 128 | 129 | final_flags = MakeRelativePathsInFlagsAbsolute( 130 | compilation_info.compiler_flags_, 131 | compilation_info.compiler_working_dir_ ) 132 | else: 133 | relative_to = DirectoryOfThisScript() 134 | final_flags = MakeRelativePathsInFlagsAbsolute( flags, relative_to ) 135 | 136 | return { 137 | 'flags': final_flags, 138 | 'do_cache': True 139 | } 140 | 141 | -------------------------------------------------------------------------------- /util/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | # import matplotlib.pyplot as plt 5 | import matplotlib 6 | matplotlib.use('pdf') 7 | import matplotlib.pyplot as plt 8 | import os 9 | import sys 10 | import numpy as np 11 | 12 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 13 | 14 | def savefig(fname, dpi=None): 15 | dpi = 150 if dpi == None else dpi 16 | plt.savefig(fname, dpi=dpi) 17 | 18 | def plot_overlap(logger, names=None): 19 | names = logger.names if names == None else names 20 | numbers = 
# NOTE: the first tokens of plot_overlap sit at the end of the preceding dump
# line; the function is restated here so that it is complete.
def plot_overlap(logger, names=None):
    """Plot the selected series of ``logger`` with ``plt.plot``.

    Args:
        logger: object exposing ``names``, ``numbers`` and ``title``.
        names: series to plot; defaults to ``logger.names``.

    Returns:
        One legend label per plotted series, formatted ``title(name)``.
    """
    names = logger.names if names is None else names
    numbers = logger.numbers
    for name in names:
        x = np.arange(len(numbers[name]))
        plt.plot(x, np.asarray(numbers[name]))
    return [logger.title + '(' + name + ')' for name in names]


def _parse_logged_value(text):
    """Convert one logged field back to float; keep the raw text if it is not numeric."""
    try:
        return float(text)
    except ValueError:
        return text


class Logger(object):
    '''Save training process to log file with simple plot function.

    File format: a header row of tab-terminated column names followed by one
    row per ``append`` call, each value written as ``{0:.6f}`` and
    tab-terminated.
    '''

    def __init__(self, fpath, title=None, resume=False):
        '''Open ``fpath`` for logging.

        With ``resume`` the existing file is read back first (``names`` from
        the header row, ``numbers`` from the data rows) and then reopened in
        append mode; otherwise the file is truncated. With ``fpath=None`` no
        file is opened.
        '''
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        if fpath is None:
            return

        if not resume:
            self.file = open(fpath, 'w')
            return

        # `with` guarantees the read handle is released even if parsing fails.
        with open(fpath, 'r') as existing:
            header = existing.readline()
            self.names = header.rstrip().split('\t')
            self.numbers = {name: [] for name in self.names}
            for row in existing:
                fields = row.rstrip().split('\t')
                if fields == ['']:
                    # blank line: nothing was logged on it
                    continue
                # Values were written as floats; load them back as floats so
                # resumed series are numeric (they used to stay strings).
                # zip ignores surplus fields instead of raising IndexError.
                for name, field in zip(self.names, fields):
                    self.numbers[name].append(_parse_logged_value(field))
        self.file = open(fpath, 'a')

    def set_names(self, names):
        '''Set the column names, reset ``numbers`` and write the header row.'''
        # NOTE(review): the original carried a no-op `if self.resume: pass`
        # here. Calling this on a resumed logger still discards the loaded
        # history and appends a second header row; behaviour left unchanged.
        # initialize numbers as empty list
        self.numbers = {}
        self.names = names
        for name in self.names:
            self.file.write(name)
            self.file.write('\t')
            self.numbers[name] = []
        self.file.write('\n')
        self.file.flush()

    def append(self, numbers):
        '''Record one row: one value per column, in ``names`` order.'''
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        for index, num in enumerate(numbers):
            self.file.write("{0:.6f}".format(num))
            self.file.write('\t')
            self.numbers[self.names[index]].append(num)
        self.file.write('\n')
        self.file.flush()

    def plot(self, names=None):
        '''Print 'plot'; the plotting code itself is disabled.'''
        print('plot')
        # Disabled in the original (it was wrapped in a string literal):
        #   names = self.names if names == None else names
        #   numbers = self.numbers
        #   for _, name in enumerate(names):
        #       x = np.arange(len(numbers[name]))
        #       plt.plot(x, np.asarray(numbers[name]))
        #   plt.legend([self.title + '(' + name + ')' for name in names])
        #   plt.grid(True)

    def close(self):
        '''Close the log file if one is open.'''
        if self.file is not None:
            self.file.close()


class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''

    def __init__(self, paths):
        '''paths is a dictionary of {name: filepath} pairs; each log is opened with resume=True.'''
        self.loggers = []
        for title, path in paths.items():
            logger = Logger(path, title=title, resume=True)
            self.loggers.append(logger)

    def plot(self, names=None):
        '''Draw the selected series of every loaded log into one subplot with a shared legend.'''
        plt.figure()
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text += plot_overlap(logger, names)
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.grid(True)


if __name__ == '__main__':
    # # Example
    # logger = Logger('test.txt')
    # logger.set_names(['Train loss', 'Valid loss','Test loss'])

    # length = 100
    # t = np.arange(length)
    # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1

    # for i in range(0, length):
    #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
    # logger.plot()

    # Example: logger monitor
    paths = {
        'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
        'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
        'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
    }

    field = ['Valid Acc.']

    monitor = LoggerMonitor(paths)
    monitor.plot(names=field)
    savefig('test.eps')

# --- dump boundary: opening of the next file in the dump, carried over unchanged ---
# -------------------------------------------------------------------------------- /pse/include/pybind11/buffer_info.h: --------------------------------------------------------------------------------
# /*
#     pybind11/buffer_info.h: Python
buffer object interface 3 | 4 | Copyright (c) 2016 Wenzel Jakob 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 8 | */ 9 | 10 | #pragma once 11 | 12 | #include "detail/common.h" 13 | 14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 15 | 16 | /// Information record describing a Python buffer object 17 | struct buffer_info { 18 | void *ptr = nullptr; // Pointer to the underlying storage 19 | ssize_t itemsize = 0; // Size of individual items in bytes 20 | ssize_t size = 0; // Total number of entries 21 | std::string format; // For homogeneous buffers, this should be set to format_descriptor::format() 22 | ssize_t ndim = 0; // Number of dimensions 23 | std::vector shape; // Shape of the tensor (1 entry per dimension) 24 | std::vector strides; // Number of bytes between adjacent entries (for each per dimension) 25 | bool readonly = false; // flag to indicate if the underlying storage may be written to 26 | 27 | buffer_info() { } 28 | 29 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, 30 | detail::any_container shape_in, detail::any_container strides_in, bool readonly=false) 31 | : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim), 32 | shape(std::move(shape_in)), strides(std::move(strides_in)), readonly(readonly) { 33 | if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size()) 34 | pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length"); 35 | for (size_t i = 0; i < (size_t) ndim; ++i) 36 | size *= shape[i]; 37 | } 38 | 39 | template 40 | buffer_info(T *ptr, detail::any_container shape_in, detail::any_container strides_in, bool readonly=false) 41 | : buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor::format(), static_cast(shape_in->size()), std::move(shape_in), std::move(strides_in), readonly) { } 42 | 43 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size, bool 
readonly=false) 44 | : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}, readonly) { } 45 | 46 | template 47 | buffer_info(T *ptr, ssize_t size, bool readonly=false) 48 | : buffer_info(ptr, sizeof(T), format_descriptor::format(), size, readonly) { } 49 | 50 | template 51 | buffer_info(const T *ptr, ssize_t size, bool readonly=true) 52 | : buffer_info(const_cast(ptr), sizeof(T), format_descriptor::format(), size, readonly) { } 53 | 54 | explicit buffer_info(Py_buffer *view, bool ownview = true) 55 | : buffer_info(view->buf, view->itemsize, view->format, view->ndim, 56 | {view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}, view->readonly) { 57 | this->view = view; 58 | this->ownview = ownview; 59 | } 60 | 61 | buffer_info(const buffer_info &) = delete; 62 | buffer_info& operator=(const buffer_info &) = delete; 63 | 64 | buffer_info(buffer_info &&other) { 65 | (*this) = std::move(other); 66 | } 67 | 68 | buffer_info& operator=(buffer_info &&rhs) { 69 | ptr = rhs.ptr; 70 | itemsize = rhs.itemsize; 71 | size = rhs.size; 72 | format = std::move(rhs.format); 73 | ndim = rhs.ndim; 74 | shape = std::move(rhs.shape); 75 | strides = std::move(rhs.strides); 76 | std::swap(view, rhs.view); 77 | std::swap(ownview, rhs.ownview); 78 | readonly = rhs.readonly; 79 | return *this; 80 | } 81 | 82 | ~buffer_info() { 83 | if (view && ownview) { PyBuffer_Release(view); delete view; } 84 | } 85 | 86 | private: 87 | struct private_ctr_tag { }; 88 | 89 | buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, 90 | detail::any_container &&shape_in, detail::any_container &&strides_in, bool readonly) 91 | : buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in), readonly) { } 92 | 93 | Py_buffer *view = nullptr; 94 | bool ownview = false; 95 | }; 96 | 97 | NAMESPACE_BEGIN(detail) 98 | 99 | template struct compare_buffer_info { 100 | static bool compare(const buffer_info& b) { 101 | 
return b.format == format_descriptor::format() && b.itemsize == (ssize_t) sizeof(T); 102 | } 103 | }; 104 | 105 | template struct compare_buffer_info::value>> { 106 | static bool compare(const buffer_info& b) { 107 | return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor::value || 108 | ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned::value ? "L" : "l")) || 109 | ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned::value ? "N" : "n"))); 110 | } 111 | }; 112 | 113 | NAMESPACE_END(detail) 114 | NAMESPACE_END(PYBIND11_NAMESPACE) 115 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Shape Robust Text Detection with Progressive Scale Expansion Network 2 | 3 | ## Requirements 4 | * Python3 5 | * pyclipper 6 | * Polygon2 7 | * OpenCV 8 | * TensorFlow 2.0+ 9 | 10 | 11 | ## Introduction 12 | (PSENet-tf2.0) Progressive Scale Expansion Network (PSENet) is a text detector that is able to detect arbitrary-shaped text in natural scenes well.
13 | Besides, based on this text segmentation model, we ranked in the top 6 of the MTWI 2018 Text Detection Challenge. 14 | 15 | ## Training (polygon) 16 | ``` 17 | CUDA_VISIBLE_DEVICES=0 python train_ic15.py 18 | ``` 19 | 20 | ## Testing (polygon) 21 | ``` 22 | CUDA_VISIBLE_DEVICES=0 python test_ic15.py --scale 1 --resume [path of model] 23 | ``` 24 | 25 | ## Training (quadrilateral) 26 | ``` 27 | CUDA_VISIBLE_DEVICES=0 python train_id41k.py 28 | ``` 29 | 30 | ## Testing (quadrilateral) 31 | ``` 32 | CUDA_VISIBLE_DEVICES=0 python test_id41k.py --scale 1 --resume [path of model] 33 | ``` 34 | 35 | ## Eval script for ICDAR 2015 and SCUT-CTW1500 36 | ``` 37 | cd eval 38 | sh eval_ic15.sh 39 | sh eval_ctw1500.sh 40 | ``` 41 | 42 | 43 | ## Performance (new version paper) 44 | ### [ICDAR 2015](http://rrc.cvc.uab.es/?ch=4&com=evaluation&task=1) 45 | | Method | Extra Data | Precision (%) | Recall (%) | F-measure (%) | FPS (1080Ti) | Model | 46 | | - | - | - | - | - | - | - | 47 | | PSENet-1s (ResNet50) | - | 81.49 | 79.68 | 80.57 | 1.6 | [baiduyun](https://pan.baidu.com/s/17FssfXd-hjsU5i2GGrKD-g)(extract code: rxti); [OneDrive](https://1drv.ms/u/s!Ai5Ldd26Lrzkkx3E1OTZzcNlMz5T) | 48 | | PSENet-1s (ResNet50) | pretrain on IC17 MLT | 86.92 | 84.5 | 85.69 | 1.6 | [baiduyun](https://pan.baidu.com/s/1oKVxHKuT3hdzDUmksbcgAQ)(extract code: aieo); [OneDrive](https://1drv.ms/u/s!Ai5Ldd26Lrzkkx44xvpay4rbV4nW) | 49 | | PSENet-4s (ResNet50) | pretrain on IC17 MLT | 86.1 | 83.77 | 84.92 | 3.8 | [baiduyun](https://pan.baidu.com/s/1oKVxHKuT3hdzDUmksbcgAQ)(extract code: aieo); [OneDrive](https://1drv.ms/u/s!Ai5Ldd26Lrzkkx44xvpay4rbV4nW) | 50 | 51 | ### [SCUT-CTW1500](https://github.com/Yuliang-Liu/Curve-Text-Detector) 52 | | Method | Extra Data | Precision (%) | Recall (%) | F-measure (%) | FPS (1080Ti) | Model | 53 | | - | - | - | - | - | - | - | 54 | | PSENet-1s (ResNet50) | - | 80.57 | 75.55 | 78.0 | 3.9 | [baiduyun](https://pan.baidu.com/s/1BqJspFwBmHjoqlE0jOrJQg)(extract code: ksv7);
[OneDrive](https://1drv.ms/u/s!Ai5Ldd26LrzkkxtlTb-yqBPd1PCn) | 55 | | PSENet-1s (ResNet50) | pretrain on IC17 MLT | 84.84| 79.73 | 82.2 | 3.9 | [baiduyun](https://pan.baidu.com/s/1zonNEABLk4ifseeJtQeS4w)(extract code: z7ac); [OneDrive](https://1drv.ms/u/s!Ai5Ldd26LrzkkxxJcfU1a__6nJTT) | 56 | | PSENet-4s (ResNet50) | pretrain on IC17 MLT | 82.09 | 77.84 | 79.9 | 8.4 | [baiduyun](https://pan.baidu.com/s/1zonNEABLk4ifseeJtQeS4w)(extract code: z7ac); [OneDrive](https://1drv.ms/u/s!Ai5Ldd26LrzkkxxJcfU1a__6nJTT) | 57 | 58 | ## Performance (old version paper) 59 | ### [ICDAR 2015](http://rrc.cvc.uab.es/?ch=4&com=evaluation&task=1) (training with ICDAR 2017 MLT) 60 | | Method | Precision (%) | Recall (%) | F-measure (%) | 61 | | - | - | - | - | 62 | | PSENet-4s (ResNet152) | 87.98 | 83.87 | 85.88 | 63 | | PSENet-2s (ResNet152) | 89.30 | 85.22 | 87.21 | 64 | | PSENet-1s (ResNet152) | 88.71 | 85.51 | 87.08 | 65 | 66 | ### [ICDAR 2017 MLT](http://rrc.cvc.uab.es/?ch=8&com=evaluation&task=1) 67 | | Method | Precision (%) | Recall (%) | F-measure (%) | 68 | | - | - | - | - | 69 | | PSENet-4s (ResNet152) | 75.98 | 67.56 | 71.52 | 70 | | PSENet-2s (ResNet152) | 76.97 | 68.35 | 72.40 | 71 | | PSENet-1s (ResNet152) | 77.01 | 68.40 | 72.45 | 72 | 73 | ### [SCUT-CTW1500](https://github.com/Yuliang-Liu/Curve-Text-Detector) 74 | | Method | Precision (%) | Recall (%) | F-measure (%) | 75 | | - | - | - | - | 76 | | PSENet-4s (ResNet152) | 80.49 | 78.13 | 79.29 | 77 | | PSENet-2s (ResNet152) | 81.95 | 79.30 | 80.60 | 78 | | PSENet-1s (ResNet152) | 82.50 | 79.89 | 81.17 | 79 | 80 | ### [ICPR MTWI 2018 Challenge 2](https://tianchi.aliyun.com/competition/rankingList.htm?spm=5176.100067.5678.4.65166a80jnPm5W&raceId=231651) 81 | | Method | Precision (%) | Recall (%) | F-measure (%) | 82 | | - | - | - | - | 83 | | PSENet-1s (ResNet152) | 8.28 | 70.0 | 76 | 84 | 85 | ## Results 86 |
87 | 88 |
89 |

90 | Figure 3: The results on ICDAR 2015, ICDAR 2017 MLT and SCUT-CTW1500 91 |

92 | 93 | ## Paper Link 94 | [new version paper] [https://arxiv.org/abs/1903.12473](https://arxiv.org/abs/1903.12473) 95 | 96 | [old version paper] [https://arxiv.org/abs/1806.02559](https://arxiv.org/abs/1806.02559) 97 | 98 | ## Other Implementations 99 | [pytorch version (thanks @[WenmuZhou](https://github.com/WenmuZhou))] [https://github.com/WenmuZhou/PSENet.pytorch](https://github.com/WenmuZhou/PSENet.pytorch) 100 | 101 | [tensorflow1.x version (thanks @[liuheng92](https://github.com/liuheng92))] [https://github.com/liuheng92/tensorflow_PSENet](https://github.com/liuheng92/tensorflow_PSENet) 102 | 103 | ## Thanks and collaborators 104 | laizhihui @ [lzh](https://github.com/lzh37) 105 | 106 | 107 | ## Citation 108 | ``` 109 | @inproceedings{wang2019shape, 110 | title={Shape Robust Text Detection With Progressive Scale Expansion Network}, 111 | author={Wang, Wenhai and Xie, Enze and Li, Xiang and Hou, Wenbo and Lu, Tong and Yu, Gang and Shao, Shuai}, 112 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 113 | pages={9336--9345}, 114 | year={2019} 115 | } 116 | ``` 117 | -------------------------------------------------------------------------------- /MobileNetV2/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train the MobileNet V2 model 3 | """ 4 | import os 5 | import sys 6 | import argparse 7 | import pandas as pd 8 | 9 | from mobilenet_v2 import MobileNetv2 10 | 11 | from keras.optimizers import Adam 12 | from keras.preprocessing.image import ImageDataGenerator 13 | from keras.callbacks import EarlyStopping 14 | from keras.layers import Conv2D, Reshape, Activation 15 | from keras.models import Model 16 | 17 | 18 | def main(argv): 19 | parser = argparse.ArgumentParser() 20 | # Required arguments. 21 | parser.add_argument( 22 | "--classes", 23 | help="The number of classes of dataset.") 24 | # Optional arguments.
25 | parser.add_argument( 26 | "--size", 27 | default=224, 28 | help="The image size of train sample.") 29 | parser.add_argument( 30 | "--batch", 31 | default=32, 32 | help="The number of train samples per batch.") 33 | parser.add_argument( 34 | "--epochs", 35 | default=300, 36 | help="The number of train iterations.") 37 | parser.add_argument( 38 | "--weights", 39 | default=False, 40 | help="Fine tune with other weights.") 41 | parser.add_argument( 42 | "--tclasses", 43 | default=0, 44 | help="The number of classes of pre-trained model.") 45 | 46 | args = parser.parse_args() 47 | 48 | train(int(args.batch), int(args.epochs), int(args.classes), int(args.size), args.weights, int(args.tclasses)) 49 | 50 | 51 | def generate(batch, size): 52 | """Data generation and augmentation 53 | 54 | # Arguments 55 | batch: Integer, batch size. 56 | size: Integer, image size. 57 | 58 | # Returns 59 | train_generator: train set generator 60 | validation_generator: validation set generator 61 | count1: Integer, number of train set. 62 | count2: Integer, number of test set. 63 | """ 64 | 65 | # Using the data Augmentation in traning data 66 | ptrain = 'data/train' 67 | pval = 'data/validation' 68 | 69 | datagen1 = ImageDataGenerator( 70 | rescale=1. / 255, 71 | shear_range=0.2, 72 | zoom_range=0.2, 73 | rotation_range=90, 74 | width_shift_range=0.2, 75 | height_shift_range=0.2, 76 | horizontal_flip=True) 77 | 78 | datagen2 = ImageDataGenerator(rescale=1. 
/ 255) 79 | 80 | train_generator = datagen1.flow_from_directory( 81 | ptrain, 82 | target_size=(size, size), 83 | batch_size=batch, 84 | class_mode='categorical') 85 | 86 | validation_generator = datagen2.flow_from_directory( 87 | pval, 88 | target_size=(size, size), 89 | batch_size=batch, 90 | class_mode='categorical') 91 | 92 | count1 = 0 93 | for root, dirs, files in os.walk(ptrain): 94 | for each in files: 95 | count1 += 1 96 | 97 | count2 = 0 98 | for root, dirs, files in os.walk(pval): 99 | for each in files: 100 | count2 += 1 101 | 102 | return train_generator, validation_generator, count1, count2 103 | 104 | 105 | def fine_tune(num_classes, weights, model): 106 | """Re-build model with current num_classes. 107 | 108 | # Arguments 109 | num_classes, Integer, The number of classes of dataset. 110 | tune, String, The pre_trained model weights. 111 | model, Model, The model structure. 112 | """ 113 | model.load_weights(weights) 114 | 115 | x = model.get_layer('Dropout').output 116 | x = Conv2D(num_classes, (1, 1), padding='same')(x) 117 | x = Activation('softmax', name='softmax')(x) 118 | output = Reshape((num_classes,))(x) 119 | 120 | model = Model(inputs=model.input, outputs=output) 121 | 122 | return model 123 | 124 | 125 | def train(batch, epochs, num_classes, size, weights, tclasses): 126 | """Train the model. 127 | 128 | # Arguments 129 | batch: Integer, The number of train samples per batch. 130 | epochs: Integer, The number of train iterations. 131 | num_classes, Integer, The number of classes of dataset. 132 | size: Integer, image size. 133 | weights, String, The pre_trained model weights. 134 | tclasses, Integer, The number of classes of pre-trained model. 
135 | """ 136 | 137 | train_generator, validation_generator, count1, count2 = generate(batch, size) 138 | 139 | if weights: 140 | model = MobileNetv2((size, size, 3), tclasses) 141 | model = fine_tune(num_classes, weights, model) 142 | else: 143 | model = MobileNetv2((size, size, 3), num_classes) 144 | 145 | opt = Adam() 146 | earlystop = EarlyStopping(monitor='val_acc', patience=30, verbose=0, mode='auto') 147 | model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy']) 148 | 149 | hist = model.fit_generator( 150 | train_generator, 151 | validation_data=validation_generator, 152 | steps_per_epoch=count1 // batch, 153 | validation_steps=count2 // batch, 154 | epochs=epochs, 155 | callbacks=[earlystop]) 156 | 157 | if not os.path.exists('model'): 158 | os.makedirs('model') 159 | 160 | df = pd.DataFrame.from_dict(hist.history) 161 | df.to_csv('model/hist.csv', encoding='utf-8', index=False) 162 | model.save_weights('model/weights.h5') 163 | 164 | 165 | if __name__ == '__main__': 166 | main(sys.argv) 167 | -------------------------------------------------------------------------------- /pse/adaptor_2.cpp: -------------------------------------------------------------------------------- 1 | #include "pybind11/pybind11.h" 2 | #include "pybind11/numpy.h" 3 | #include "pybind11/stl.h" 4 | #include "pybind11/stl_bind.h" 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | using namespace std; 15 | using namespace cv; 16 | 17 | namespace py = pybind11; 18 | 19 | namespace lanms_adaptor { 20 | vector get_kernals(const int *data, vector data_shape) { 21 | vector kernals; 22 | for (int i = 0; i < data_shape[0]; ++i) { 23 | Mat kernal = Mat::zeros(data_shape[1], data_shape[2], CV_8UC1); 24 | for (int x = 0; x < kernal.rows; ++x) { 25 | for (int y = 0; y < kernal.cols; ++y) { 26 | kernal.at(x, y) = data[i * data_shape[1] * data_shape[2] + x * data_shape[2] + y]; 27 | } 28 | } 29 | kernals.emplace_back(kernal); 
30 | } 31 | return kernals; 32 | } 33 | 34 | Mat growing_text_line(vector kernals) { 35 | int th1 = 10; 36 | // int th1 = 0; 37 | Mat text_line = Mat::zeros(kernals[0].size(), CV_32SC1); 38 | 39 | Mat label_mat; 40 | int label_num = connectedComponents(kernals[kernals.size() - 1], label_mat, 4); 41 | 42 | int area[label_num + 1]; 43 | memset(area, 0, sizeof(area)); 44 | for (int x = 0; x < label_mat.rows; ++x) { 45 | for (int y = 0; y < label_mat.cols; ++y) { 46 | int label = label_mat.at(x, y); 47 | if (label == 0) continue; 48 | area[label] += 1; 49 | } 50 | } 51 | queue queue, next_queue; 52 | for (int x = 0; x < label_mat.rows; ++x) { 53 | for (int y = 0; y < label_mat.cols; ++y) { 54 | int label = label_mat.at(x, y); 55 | if (label == 0) continue; 56 | if (area[label] < th1) continue; 57 | Point point(x, y); 58 | queue.push(point); 59 | text_line.at(x, y) = label; 60 | } 61 | } 62 | 63 | // cout << text_line << endl; 64 | 65 | int dx[] = {-1, 1, 0, 0}; 66 | int dy[] = {0, 0, -1, 1}; 67 | 68 | for (int kernal_id = kernals.size() - 2; kernal_id >= 0; --kernal_id) { 69 | while (!queue.empty()) { 70 | Point point = queue.front(); queue.pop(); 71 | int x = point.x; 72 | int y = point.y; 73 | int label = text_line.at(x, y); 74 | 75 | bool is_edge = true; 76 | for (int d = 0; d < 4; ++d) { 77 | int tmp_x = x + dx[d]; 78 | int tmp_y = y + dy[d]; 79 | 80 | if (tmp_x < 0 || tmp_x >= text_line.rows) continue; 81 | if (tmp_y < 0 || tmp_y >= text_line.cols) continue; 82 | if (kernals[kernal_id].at(tmp_x, tmp_y) == 0) continue; 83 | if (text_line.at(tmp_x, tmp_y) > 0) continue; 84 | 85 | Point point(tmp_x, tmp_y); 86 | queue.push(point); 87 | text_line.at(tmp_x, tmp_y) = label; 88 | is_edge = false; 89 | } 90 | 91 | if (is_edge) { 92 | next_queue.push(point); 93 | } 94 | } 95 | 96 | /* 97 | label_num = connectedComponents(kernals[kernal_id], label_mat, 4); 98 | 99 | int area[label_num + 1]; 100 | memset(area, 0, sizeof(area)); 101 | for (int x = 0; x < label_mat.rows; ++x) 
{ 102 | for (int y = 0; y < label_mat.cols; ++y) { 103 | int label = label_mat.at(x, y); 104 | if (label == 0) continue; 105 | area[label] += 1; 106 | } 107 | } 108 | 109 | for (int x = 0; x < label_mat.rows; ++x) { 110 | for (int y = 0; y < label_mat.cols; ++y) { 111 | int label = label_mat.at(x, y); 112 | if (label == 0) continue; 113 | if (area[label] < th1) continue; 114 | if (text_line.at(x, y) > 0) continue; 115 | text_line.at(x, y) = label + bias; 116 | } 117 | } 118 | bias += label_num; 119 | */ 120 | 121 | /* 122 | for (int x = 0; x < text_line.rows; ++x) { 123 | for (int y = 0; y < text_line.cols; ++y) { 124 | if (text_line.at(x, y) == 0) continue; 125 | Point point(x, y); 126 | queue.push(point); 127 | } 128 | } 129 | */ 130 | 131 | swap(queue, next_queue); 132 | } 133 | 134 | // cout << text_line << endl; 135 | 136 | return text_line; 137 | } 138 | 139 | vector> merge_quadrangle_n9(py::array_t quad_n9) { 140 | auto buf = quad_n9.request(); 141 | auto data = static_cast(buf.ptr); 142 | vector kernals = get_kernals(data, buf.shape); 143 | 144 | Mat _text_line = growing_text_line(kernals); 145 | 146 | // cout << _text_line << endl; 147 | vector> text_line; 148 | for (int x = 0; x < _text_line.rows; ++x) { 149 | vector row; 150 | for (int y = 0; y < _text_line.cols; ++y) { 151 | row.emplace_back(_text_line.at(x, y)); 152 | } 153 | text_line.emplace_back(row); 154 | } 155 | 156 | return text_line; 157 | } 158 | } 159 | 160 | PYBIND11_PLUGIN(adaptor) { 161 | py::module m("adaptor", "NMS"); 162 | 163 | m.def("merge_quadrangle_n9", &lanms_adaptor::merge_quadrangle_n9, "merge quadrangels"); 164 | 165 | return m.ptr(); 166 | } 167 | 168 | -------------------------------------------------------------------------------- /util/io_.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | ''' 3 | Created on 2016年9月27日 4 | 5 | @author: dengdan 6 | 7 | Tool functions for file system operation and I/O. 
8 | In the style of linux shell commands 9 | ''' 10 | import os 11 | #import cPickle as pkl 12 | import _pickle as pkl 13 | #import commands 14 | import subprocess 15 | import logging 16 | 17 | import util 18 | 19 | def mkdir(path): 20 | """ 21 | If the target directory does not exists, it and its parent directories will created. 22 | """ 23 | path = get_absolute_path(path) 24 | if not exists(path): 25 | os.makedirs(path) 26 | return path 27 | 28 | def make_parent_dir(path): 29 | """make the parent directories for a file.""" 30 | parent_dir = get_dir(path) 31 | mkdir(parent_dir) 32 | 33 | 34 | def pwd(): 35 | return os.getcwd() 36 | 37 | def dump(path, obj): 38 | path = get_absolute_path(path) 39 | parent_path = get_dir(path) 40 | mkdir(parent_path) 41 | with open(path, 'w') as f: 42 | logging.info('dumping file:' + path); 43 | pkl.dump(obj, f) 44 | 45 | def load(path): 46 | path = get_absolute_path(path) 47 | with open(path, 'r') as f: 48 | data = pkl.load(f) 49 | return data 50 | 51 | def join_path(a, *p): 52 | return os.path.join(a, *p) 53 | 54 | def is_dir(path): 55 | path = get_absolute_path(path) 56 | return os.path.isdir(path) 57 | 58 | 59 | def is_path(path): 60 | path = get_absolute_path(path) 61 | return os.path.ispath(path) 62 | 63 | def get_dir(path): 64 | ''' 65 | return the directory it belongs to. 66 | if path is a directory itself, itself will be return 67 | ''' 68 | path = get_absolute_path(path) 69 | if is_dir(path): 70 | return path; 71 | return os.path.split(path)[0] 72 | 73 | def get_filename(path): 74 | return os.path.split(path)[1] 75 | 76 | def get_absolute_path(p): 77 | if p.startswith('~'): 78 | p = os.path.expanduser(p) 79 | return os.path.abspath(p) 80 | 81 | def cd(p): 82 | p = get_absolute_path(p) 83 | os.chdir(p) 84 | 85 | def ls(path = '.', suffix = None): 86 | """ 87 | list files in a directory. 
88 | return file names in a list 89 | """ 90 | path = get_absolute_path(path) 91 | files = os.listdir(path) 92 | 93 | if suffix is None: 94 | return files 95 | 96 | filtered = [] 97 | for f in files: 98 | if util.str.ends_with(f, suffix, ignore_case = True): 99 | filtered.append(f) 100 | 101 | return filtered 102 | 103 | def find_files(pattern): 104 | import glob 105 | return glob.glob(pattern) 106 | 107 | def read_lines(p): 108 | """return the text in a file in lines as a list """ 109 | p = get_absolute_path(p) 110 | f = open(p,'r') 111 | return f.readlines() 112 | 113 | def write_lines(p, lines): 114 | p = get_absolute_path(p) 115 | make_parent_dir(p) 116 | with open(p, 'w') as f: 117 | for line in lines: 118 | f.write(line) 119 | 120 | 121 | def cat(p): 122 | """return the text in a file as a whole""" 123 | cmd = 'cat ' + p 124 | #return commands.getoutput(cmd) 125 | return subprocess.getoutput(cmd) 126 | 127 | def exists(path): 128 | path = get_absolute_path(path) 129 | return os.path.exists(path) 130 | 131 | def load_mat(path): 132 | import scipy.io as sio 133 | path = get_absolute_path(path) 134 | return sio.loadmat(path) 135 | 136 | def dump_mat(path, dict_obj, append = True): 137 | import scipy.io as sio 138 | path = get_absolute_path(path) 139 | make_parent_dir(path) 140 | sio.savemat(file_name = path, mdict = dict_obj, appendmat = append) 141 | 142 | def dir_mat(path): 143 | ''' 144 | list the variables in mat file. 145 | return a list: [(name, shape, dtype), ...] 
146 | ''' 147 | import scipy.io as sio 148 | path = get_absolute_path(path) 149 | return sio.whosmat(path) 150 | 151 | SIZE_UNIT_K = 1024 152 | SIZE_UNIT_M = SIZE_UNIT_K ** 2 153 | SIZE_UNIT_G = SIZE_UNIT_K ** 3 154 | def get_file_size(path, unit = SIZE_UNIT_K): 155 | size = os.path.getsize(get_absolute_path(path)) 156 | return size * 1.0 / unit 157 | 158 | 159 | def create_h5(path): 160 | import h5py 161 | path = get_absolute_path(path) 162 | make_parent_dir(path) 163 | return h5py.File(path, 'w'); 164 | 165 | def open_h5(path, mode = 'r'): 166 | import h5py 167 | path = get_absolute_path(path) 168 | return h5py.File(path, mode); 169 | 170 | def read_h5(h5, key): 171 | return h5[key][:] 172 | def read_h5_attrs(h5, key, attrs): 173 | return h5[key].attrs[attrs] 174 | 175 | def copy(src, dest): 176 | import shutil 177 | shutil.copy(get_absolute_path(src), get_absolute_path(dest)) 178 | 179 | cp = copy 180 | 181 | def remove(p): 182 | import os 183 | os.remove(get_absolute_path(p)) 184 | rm = remove 185 | 186 | def search(pattern, path, file_only = True): 187 | """ 188 | Search files whose name matches the give pattern. The search scope 189 | is the directory and sub-directories of 'path'. 
190 | """ 191 | path = get_absolute_path(path) 192 | pattern_here = util.io.join_path(path, pattern) 193 | targets = [] 194 | 195 | # find matchings in current directory 196 | candidates = find_files(pattern_here) 197 | for can in candidates: 198 | if util.io.is_dir(can) and file_only: 199 | continue 200 | else: 201 | targets.append(can) 202 | 203 | # find matching in sub-dirs 204 | files = ls(path) 205 | for f in files: 206 | fpath = util.io.join_path(path, f) 207 | if is_dir(fpath): 208 | targets_in_sub_dir = search(pattern, fpath, file_only) 209 | targets.extend(targets_in_sub_dir) 210 | return targets 211 | 212 | def dump_json(path, data): 213 | import json 214 | path = get_absolute_path(path) 215 | make_parent_dir(path) 216 | 217 | with open(path, 'w') as f: 218 | json.dump(data, f) 219 | return path -------------------------------------------------------------------------------- /util/plt.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | ''' 3 | Created on 2016-9-27 4 | 5 | @author: dengdan 6 | ''' 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import util 10 | 11 | def hist(x, title = None, normed = False, show = True, save = False, save_path = None, bin_count = 100, bins = None): 12 | x = np.asarray(x) 13 | if len(np.shape(x)) > 1: 14 | # x = np.reshape(x, np.prod(x.shape)) 15 | x = util.np.flatten(x) 16 | if bins == None: 17 | bins = np.linspace(start = min(x), stop = max(x), num = bin_count, endpoint = True, retstep = False) 18 | plt.figure(num = title) 19 | plt.hist(x, bins, normed = normed) 20 | if save: 21 | if save_path is None: 22 | raise ValueError 23 | path = util.io.join_path(save_path, title + '.png') 24 | save_image(path) 25 | if show: 26 | plt.show() 27 | #util.img.imshow(title, path, block = block) 28 | 29 | def plot_solver_data(solver_path): 30 | data = util.io.load(solver_path) 31 | training_losses = data.training_losses 32 | training_accuracies = 
data.training_accuracies 33 | val_losses = data.val_losses 34 | val_accuracies = data.val_accuracies 35 | plt.figure(solver_path) 36 | 37 | n = len(training_losses) 38 | x = range(n) 39 | 40 | plt.plot(x, training_losses, 'r-', label = 'Training Loss') 41 | 42 | if len(training_accuracies) > 0: 43 | plt.plot(x, training_accuracies, 'r--', label = 'Training Accuracy') 44 | 45 | if len(val_losses) > 0: 46 | n = len(val_losses) 47 | x = range(n) 48 | plt.plot(x, val_losses, 'g-', label = 'Validation Loss') 49 | 50 | if len(val_accuracies) > 0: 51 | plt.plot(x, val_accuracies, 'g--', label = 'Validation Accuracy') 52 | plt.legend() 53 | plt.show() 54 | 55 | 56 | def rectangle(xy, width, height, color = 'red', linewidth = 1, fill = False, alpha = None, axis = None): 57 | """ 58 | draw a rectangle on plt axis 59 | """ 60 | import matplotlib.patches as patches 61 | rect = patches.Rectangle( 62 | xy = xy, 63 | width = width, 64 | height = height, 65 | alpha = alpha, 66 | color = color, 67 | fill = fill, 68 | linewidth = linewidth 69 | ) 70 | if axis is not None: 71 | axis.add_patch(rect) 72 | return rect 73 | 74 | rect = rectangle 75 | 76 | def maximize_figure(): 77 | mng = plt.get_current_fig_manager() 78 | mng.full_screen_toggle() 79 | 80 | def line(xy_start, xy_end, color = 'red', linewidth = 1, alpha = None, axis = None): 81 | """ 82 | draw a line on plt axis 83 | """ 84 | from matplotlib.lines import Line2D 85 | num = 100 86 | xdata = np.linspace(xy_start[0], xy_end[0], num = num) 87 | ydata = np.linspace(xy_start[1], xy_end[1], num = num) 88 | line = Line2D( 89 | alpha = alpha, 90 | color = color, 91 | linewidth = linewidth, 92 | xdata = xdata, 93 | ydata = ydata 94 | ) 95 | if axis is not None: 96 | axis.add_line(line) 97 | return line 98 | 99 | def imshow(title = None, img = None, gray = False): 100 | show_images([img], [title], gray = gray) 101 | 102 | def show_images(images, titles = None, shape = None, share_axis = False, 103 | bgr2rgb = False, maximized = 
False, 104 | show = True, gray = False, save = False, colorbar = False, 105 | path = None, axis_off = False, vertical = False, subtitle = None): 106 | 107 | if shape == None: 108 | if vertical: 109 | shape = (len(images), 1) 110 | else: 111 | shape = (1, len(images)) 112 | 113 | ret_axes = [] 114 | ax0 = None 115 | for idx, img in enumerate(images): 116 | if bgr2rgb: 117 | img = util.img.bgr2rgb(img) 118 | loc = (idx / shape[1], idx % shape[1]) 119 | if idx == 0: 120 | ax = plt.subplot2grid(shape, loc) 121 | ax0 = ax 122 | else: 123 | if share_axis: 124 | ax = plt.subplot2grid(shape, loc, sharex = ax0, sharey = ax0) 125 | else: 126 | ax = plt.subplot2grid(shape, loc) 127 | if len(np.shape(img)) == 2 and gray: 128 | img_ax = ax.imshow(img, cmap = 'gray') 129 | else: 130 | img_ax = ax.imshow(img) 131 | 132 | if len(np.shape(img)) == 2 and colorbar: 133 | plt.colorbar(img_ax, ax = ax) 134 | if titles != None: 135 | ax.set_title(titles[idx]) 136 | 137 | if axis_off: 138 | plt.axis('off') 139 | # plt.xticks([]), plt.yticks([]) 140 | ret_axes.append(ax) 141 | 142 | if subtitle is not None: 143 | set_subtitle(subtitle) 144 | if maximized: 145 | maximize_figure() 146 | 147 | if save: 148 | if path is None: 149 | raise ValueError('path can not be None when save is True') 150 | save_image(path) 151 | if show: 152 | plt.show() 153 | return ret_axes 154 | 155 | def save_image(path, img = None, dpi = 150): 156 | path = util.io.get_absolute_path(path) 157 | util.io.make_parent_dir(path) 158 | if img is None: 159 | plt.gcf().savefig(path, dpi = dpi) 160 | else: 161 | plt.imsave(path, img) 162 | 163 | imwrite = save_image 164 | 165 | def to_ROI(ax, ROI): 166 | xy1, xy2 = ROI 167 | xmin, ymin = xy1 168 | xmax, ymax = xy2 169 | ax.set_xlim(xmin, xmax) 170 | #ax.extent 171 | ax.set_ylim(ymax, ymin) 172 | 173 | def set_subtitle(title, fontsize = 12): 174 | plt.gcf().suptitle(title, fontsize=fontsize) 175 | 176 | def show(maximized = False): 177 | if maximized: 178 | maximize_figure() 
179 | plt.show() 180 | 181 | def draw(): 182 | plt.gcf().canvas.draw() 183 | 184 | def get_random_line_style(): 185 | colors = ['r', 'g', 'b'] 186 | line_types = ['-']#, '--', '-.', ':'] 187 | idx = util.rand.randint(len(colors)) 188 | color = colors[idx] 189 | idx = util.rand.randint(len(line_types)) 190 | line_type = line_types[idx] 191 | return color + line_type 192 | -------------------------------------------------------------------------------- /pse/adaptor_1.cpp: -------------------------------------------------------------------------------- 1 | #include "pybind11/pybind11.h" 2 | #include "pybind11/numpy.h" 3 | #include "pybind11/stl.h" 4 | #include "pybind11/stl_bind.h" 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | using namespace std; 15 | using namespace cv; 16 | 17 | namespace py = pybind11; 18 | 19 | namespace lanms_adaptor { 20 | vector get_kernals(const int *data, vector data_shape) { 21 | vector kernals; 22 | for (int i = 0; i < data_shape[0]; ++i) { 23 | Mat kernal = Mat::zeros(data_shape[1], data_shape[2], CV_8UC1); 24 | for (int x = 0; x < kernal.rows; ++x) { 25 | for (int y = 0; y < kernal.cols; ++y) { 26 | kernal.at(x, y) = data[i * data_shape[1] * data_shape[2] + x * data_shape[2] + y]; 27 | } 28 | } 29 | kernals.emplace_back(kernal); 30 | } 31 | return kernals; 32 | } 33 | 34 | vector> growing_text_line(vector kernals) { 35 | int th1 = 10; 36 | // int th1 = 0; 37 | // Mat text_line = Mat::zeros(kernals[0].size(), CV_32SC1); 38 | 39 | Mat label_mat; 40 | int label_num = connectedComponents(kernals[kernals.size() - 1], label_mat, 4); 41 | 42 | int area[label_num + 1]; 43 | memset(area, 0, sizeof(area)); 44 | for (int x = 0; x < label_mat.rows; ++x) { 45 | for (int y = 0; y < label_mat.cols; ++y) { 46 | int label = label_mat.at(x, y); 47 | if (label == 0) continue; 48 | area[label] += 1; 49 | } 50 | } 51 | 52 | vector> text_line; 53 | queue queue, next_queue; 54 | for (int x = 0; x < 
label_mat.rows; ++x) { 55 | vector row; 56 | for (int y = 0; y < label_mat.cols; ++y) { 57 | int label = label_mat.at(x, y); 58 | if (label == 0) { 59 | row.emplace_back(0); 60 | continue; 61 | } 62 | if (area[label] < th1) { 63 | row.emplace_back(0); 64 | continue; 65 | } 66 | Point point(x, y); 67 | queue.push(point); 68 | // text_line.at(x, y) = label; 69 | row.emplace_back(label); 70 | } 71 | text_line.emplace_back(row); 72 | } 73 | 74 | // cout << text_line << endl; 75 | 76 | int dx[] = {-1, 1, 0, 0}; 77 | int dy[] = {0, 0, -1, 1}; 78 | 79 | for (int kernal_id = kernals.size() - 2; kernal_id >= 0; --kernal_id) { 80 | while (!queue.empty()) { 81 | Point point = queue.front(); queue.pop(); 82 | int x = point.x; 83 | int y = point.y; 84 | // int label = text_line.at(x, y); 85 | int label = text_line[x][y]; 86 | 87 | bool is_edge = true; 88 | for (int d = 0; d < 4; ++d) { 89 | int tmp_x = x + dx[d]; 90 | int tmp_y = y + dy[d]; 91 | 92 | if (tmp_x < 0 || tmp_x >= (int)text_line.size()) continue; 93 | if (tmp_y < 0 || tmp_y >= (int)text_line[1].size()) continue; 94 | if (kernals[kernal_id].at(tmp_x, tmp_y) == 0) continue; 95 | if (text_line[tmp_x][tmp_y] > 0) continue; 96 | 97 | Point point(tmp_x, tmp_y); 98 | queue.push(point); 99 | text_line[tmp_x][tmp_y] = label; 100 | is_edge = false; 101 | } 102 | 103 | if (is_edge) { 104 | next_queue.push(point); 105 | } 106 | } 107 | 108 | /* 109 | label_num = connectedComponents(kernals[kernal_id], label_mat, 4); 110 | 111 | int area[label_num + 1]; 112 | memset(area, 0, sizeof(area)); 113 | for (int x = 0; x < label_mat.rows; ++x) { 114 | for (int y = 0; y < label_mat.cols; ++y) { 115 | int label = label_mat.at(x, y); 116 | if (label == 0) continue; 117 | area[label] += 1; 118 | } 119 | } 120 | 121 | for (int x = 0; x < label_mat.rows; ++x) { 122 | for (int y = 0; y < label_mat.cols; ++y) { 123 | int label = label_mat.at(x, y); 124 | if (label == 0) continue; 125 | if (area[label] < th1) continue; 126 | if (text_line.at(x, 
y) > 0) continue; 127 | text_line.at(x, y) = label + bias; 128 | } 129 | } 130 | bias += label_num; 131 | */ 132 | 133 | /* 134 | for (int x = 0; x < text_line.rows; ++x) { 135 | for (int y = 0; y < text_line.cols; ++y) { 136 | if (text_line.at(x, y) == 0) continue; 137 | Point point(x, y); 138 | queue.push(point); 139 | } 140 | } 141 | */ 142 | 143 | swap(queue, next_queue); 144 | } 145 | 146 | // cout << text_line << endl; 147 | 148 | return text_line; 149 | } 150 | 151 | vector> merge_quadrangle_n9(py::array_t quad_n9) { 152 | auto buf = quad_n9.request(); 153 | auto data = static_cast(buf.ptr); 154 | vector kernals = get_kernals(data, buf.shape); 155 | 156 | vector> text_line = growing_text_line(kernals); 157 | 158 | // cout << _text_line << endl; 159 | // vector> text_line; 160 | // for (int x = 0; x < _text_line.rows; ++x) { 161 | // vector row; 162 | // for (int y = 0; y < _text_line.cols; ++y) { 163 | // row.emplace_back(_text_line.at(x, y)); 164 | // } 165 | // text_line.emplace_back(row); 166 | // } 167 | 168 | return text_line; 169 | } 170 | } 171 | 172 | PYBIND11_PLUGIN(adaptor) { 173 | py::module m("adaptor", "NMS"); 174 | 175 | m.def("merge_quadrangle_n9", &lanms_adaptor::merge_quadrangle_n9, "merge quadrangels"); 176 | 177 | return m.ptr(); 178 | } 179 | 180 | -------------------------------------------------------------------------------- /pse/include/pybind11/iostream.h: -------------------------------------------------------------------------------- 1 | /* 2 | pybind11/iostream.h -- Tools to assist with redirecting cout and cerr to Python 3 | 4 | Copyright (c) 2017 Henry F. Schreiner 5 | 6 | All rights reserved. Use of this source code is governed by a 7 | BSD-style license that can be found in the LICENSE file. 
8 | */ 9 | 10 | #pragma once 11 | 12 | #include "pybind11.h" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 21 | NAMESPACE_BEGIN(detail) 22 | 23 | // Buffer that writes to Python instead of C++ 24 | class pythonbuf : public std::streambuf { 25 | private: 26 | using traits_type = std::streambuf::traits_type; 27 | 28 | const size_t buf_size; 29 | std::unique_ptr d_buffer; 30 | object pywrite; 31 | object pyflush; 32 | 33 | int overflow(int c) { 34 | if (!traits_type::eq_int_type(c, traits_type::eof())) { 35 | *pptr() = traits_type::to_char_type(c); 36 | pbump(1); 37 | } 38 | return sync() == 0 ? traits_type::not_eof(c) : traits_type::eof(); 39 | } 40 | 41 | int sync() { 42 | if (pbase() != pptr()) { 43 | // This subtraction cannot be negative, so dropping the sign 44 | str line(pbase(), static_cast(pptr() - pbase())); 45 | 46 | { 47 | gil_scoped_acquire tmp; 48 | pywrite(line); 49 | pyflush(); 50 | } 51 | 52 | setp(pbase(), epptr()); 53 | } 54 | return 0; 55 | } 56 | 57 | public: 58 | 59 | pythonbuf(object pyostream, size_t buffer_size = 1024) 60 | : buf_size(buffer_size), 61 | d_buffer(new char[buf_size]), 62 | pywrite(pyostream.attr("write")), 63 | pyflush(pyostream.attr("flush")) { 64 | setp(d_buffer.get(), d_buffer.get() + buf_size - 1); 65 | } 66 | 67 | pythonbuf(pythonbuf&&) = default; 68 | 69 | /// Sync before destroy 70 | ~pythonbuf() { 71 | sync(); 72 | } 73 | }; 74 | 75 | NAMESPACE_END(detail) 76 | 77 | 78 | /** \rst 79 | This a move-only guard that redirects output. 80 | 81 | .. code-block:: cpp 82 | 83 | #include 84 | 85 | ... 86 | 87 | { 88 | py::scoped_ostream_redirect output; 89 | std::cout << "Hello, World!"; // Python stdout 90 | } // <-- return std::cout to normal 91 | 92 | You can explicitly pass the c++ stream and the python object, 93 | for example to guard stderr instead. 94 | 95 | .. 
code-block:: cpp 96 | 97 | { 98 | py::scoped_ostream_redirect output{std::cerr, py::module::import("sys").attr("stderr")}; 99 | std::cerr << "Hello, World!"; 100 | } 101 | \endrst */ 102 | class scoped_ostream_redirect { 103 | protected: 104 | std::streambuf *old; 105 | std::ostream &costream; 106 | detail::pythonbuf buffer; 107 | 108 | public: 109 | scoped_ostream_redirect( 110 | std::ostream &costream = std::cout, 111 | object pyostream = module::import("sys").attr("stdout")) 112 | : costream(costream), buffer(pyostream) { 113 | old = costream.rdbuf(&buffer); 114 | } 115 | 116 | ~scoped_ostream_redirect() { 117 | costream.rdbuf(old); 118 | } 119 | 120 | scoped_ostream_redirect(const scoped_ostream_redirect &) = delete; 121 | scoped_ostream_redirect(scoped_ostream_redirect &&other) = default; 122 | scoped_ostream_redirect &operator=(const scoped_ostream_redirect &) = delete; 123 | scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete; 124 | }; 125 | 126 | 127 | /** \rst 128 | Like `scoped_ostream_redirect`, but redirects cerr by default. This class 129 | is provided primary to make ``py::call_guard`` easier to make. 130 | 131 | .. code-block:: cpp 132 | 133 | m.def("noisy_func", &noisy_func, 134 | py::call_guard()); 136 | 137 | \endrst */ 138 | class scoped_estream_redirect : public scoped_ostream_redirect { 139 | public: 140 | scoped_estream_redirect( 141 | std::ostream &costream = std::cerr, 142 | object pyostream = module::import("sys").attr("stderr")) 143 | : scoped_ostream_redirect(costream,pyostream) {} 144 | }; 145 | 146 | 147 | NAMESPACE_BEGIN(detail) 148 | 149 | // Class to redirect output as a context manager. C++ backend. 
150 | class OstreamRedirect { 151 | bool do_stdout_; 152 | bool do_stderr_; 153 | std::unique_ptr redirect_stdout; 154 | std::unique_ptr redirect_stderr; 155 | 156 | public: 157 | OstreamRedirect(bool do_stdout = true, bool do_stderr = true) 158 | : do_stdout_(do_stdout), do_stderr_(do_stderr) {} 159 | 160 | void enter() { 161 | if (do_stdout_) 162 | redirect_stdout.reset(new scoped_ostream_redirect()); 163 | if (do_stderr_) 164 | redirect_stderr.reset(new scoped_estream_redirect()); 165 | } 166 | 167 | void exit() { 168 | redirect_stdout.reset(); 169 | redirect_stderr.reset(); 170 | } 171 | }; 172 | 173 | NAMESPACE_END(detail) 174 | 175 | /** \rst 176 | This is a helper function to add a C++ redirect context manager to Python 177 | instead of using a C++ guard. To use it, add the following to your binding code: 178 | 179 | .. code-block:: cpp 180 | 181 | #include 182 | 183 | ... 184 | 185 | py::add_ostream_redirect(m, "ostream_redirect"); 186 | 187 | You now have a Python context manager that redirects your output: 188 | 189 | .. code-block:: python 190 | 191 | with m.ostream_redirect(): 192 | m.print_to_cout_function() 193 | 194 | This manager can optionally be told which streams to operate on: 195 | 196 | .. 
code-block:: python 197 | 198 | with m.ostream_redirect(stdout=true, stderr=true): 199 | m.noisy_function_with_error_printing() 200 | 201 | \endrst */ 202 | inline class_ add_ostream_redirect(module m, std::string name = "ostream_redirect") { 203 | return class_(m, name.c_str(), module_local()) 204 | .def(init(), arg("stdout")=true, arg("stderr")=true) 205 | .def("__enter__", &detail::OstreamRedirect::enter) 206 | .def("__exit__", [](detail::OstreamRedirect &self_, args) { self_.exit(); }); 207 | } 208 | 209 | NAMESPACE_END(PYBIND11_NAMESPACE) 210 | -------------------------------------------------------------------------------- /pse/adaptor_3.cpp: -------------------------------------------------------------------------------- 1 | #include "pybind11/pybind11.h" 2 | #include "pybind11/numpy.h" 3 | #include "pybind11/stl.h" 4 | #include "pybind11/stl_bind.h" 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | using namespace std; 15 | using namespace cv; 16 | 17 | namespace py = pybind11; 18 | 19 | namespace lanms_adaptor { 20 | void get_kernals(const int *data, vector data_shape, vector &kernals) { 21 | // vector kernals; 22 | for (int i = 0; i < data_shape[0]; ++i) { 23 | Mat kernal = Mat::zeros(data_shape[1], data_shape[2], CV_8UC1); 24 | for (int x = 0; x < kernal.rows; ++x) { 25 | for (int y = 0; y < kernal.cols; ++y) { 26 | kernal.at(x, y) = data[i * data_shape[1] * data_shape[2] + x * data_shape[2] + y]; 27 | } 28 | } 29 | kernals.emplace_back(kernal); 30 | } 31 | // return kernals; 32 | } 33 | 34 | void growing_text_line(vector &kernals, vector> &text_line) { 35 | int th1 = 10; 36 | // int th1 = 0; 37 | // Mat text_line = Mat::zeros(kernals[0].size(), CV_32SC1); 38 | 39 | Mat label_mat; 40 | int label_num = connectedComponents(kernals[kernals.size() - 1], label_mat, 4); 41 | 42 | int area[label_num + 1]; 43 | memset(area, 0, sizeof(area)); 44 | for (int x = 0; x < label_mat.rows; ++x) { 45 | for (int y = 0; y < 
label_mat.cols; ++y) { 46 | int label = label_mat.at(x, y); 47 | if (label == 0) continue; 48 | area[label] += 1; 49 | } 50 | } 51 | 52 | // vector> text_line; 53 | queue queue, next_queue; 54 | for (int x = 0; x < label_mat.rows; ++x) { 55 | vector row(label_mat.cols); 56 | for (int y = 0; y < label_mat.cols; ++y) { 57 | int label = label_mat.at(x, y); 58 | if (label == 0) { 59 | // row.emplace_back(0); 60 | continue; 61 | } 62 | if (area[label] < th1) { 63 | // row.emplace_back(0); 64 | continue; 65 | } 66 | Point point(x, y); 67 | queue.push(point); 68 | // text_line.at(x, y) = label; 69 | // row.emplace_back(label); 70 | row[y] = label; 71 | } 72 | text_line.emplace_back(row); 73 | } 74 | 75 | // cout << text_line << endl; 76 | 77 | int dx[] = {-1, 1, 0, 0}; 78 | int dy[] = {0, 0, -1, 1}; 79 | 80 | for (int kernal_id = kernals.size() - 2; kernal_id >= 0; --kernal_id) { 81 | while (!queue.empty()) { 82 | Point point = queue.front(); queue.pop(); 83 | int x = point.x; 84 | int y = point.y; 85 | // int label = text_line.at(x, y); 86 | int label = text_line[x][y]; 87 | 88 | bool is_edge = true; 89 | for (int d = 0; d < 4; ++d) { 90 | int tmp_x = x + dx[d]; 91 | int tmp_y = y + dy[d]; 92 | 93 | if (tmp_x < 0 || tmp_x >= (int)text_line.size()) continue; 94 | if (tmp_y < 0 || tmp_y >= (int)text_line[1].size()) continue; 95 | if (kernals[kernal_id].at(tmp_x, tmp_y) == 0) continue; 96 | if (text_line[tmp_x][tmp_y] > 0) continue; 97 | 98 | Point point(tmp_x, tmp_y); 99 | queue.push(point); 100 | text_line[tmp_x][tmp_y] = label; 101 | is_edge = false; 102 | } 103 | 104 | if (is_edge) { 105 | next_queue.push(point); 106 | } 107 | } 108 | 109 | /* 110 | label_num = connectedComponents(kernals[kernal_id], label_mat, 4); 111 | 112 | int area[label_num + 1]; 113 | memset(area, 0, sizeof(area)); 114 | for (int x = 0; x < label_mat.rows; ++x) { 115 | for (int y = 0; y < label_mat.cols; ++y) { 116 | int label = label_mat.at(x, y); 117 | if (label == 0) continue; 118 | area[label] 
+= 1; 119 | } 120 | } 121 | 122 | for (int x = 0; x < label_mat.rows; ++x) { 123 | for (int y = 0; y < label_mat.cols; ++y) { 124 | int label = label_mat.at(x, y); 125 | if (label == 0) continue; 126 | if (area[label] < th1) continue; 127 | if (text_line.at(x, y) > 0) continue; 128 | text_line.at(x, y) = label + bias; 129 | } 130 | } 131 | bias += label_num; 132 | */ 133 | 134 | /* 135 | for (int x = 0; x < text_line.rows; ++x) { 136 | for (int y = 0; y < text_line.cols; ++y) { 137 | if (text_line.at(x, y) == 0) continue; 138 | Point point(x, y); 139 | queue.push(point); 140 | } 141 | } 142 | */ 143 | 144 | swap(queue, next_queue); 145 | } 146 | 147 | // cout << text_line << endl; 148 | 149 | // return text_line; 150 | } 151 | 152 | vector> merge_quadrangle_n9(py::array_t quad_n9) { 153 | auto buf = quad_n9.request(); 154 | auto data = static_cast(buf.ptr); 155 | vector kernals; 156 | get_kernals(data, buf.shape, kernals); 157 | 158 | vector> text_line; 159 | growing_text_line(kernals, text_line); 160 | 161 | // cout << _text_line << endl; 162 | // vector> text_line; 163 | // for (int x = 0; x < _text_line.rows; ++x) { 164 | // vector row; 165 | // for (int y = 0; y < _text_line.cols; ++y) { 166 | // row.emplace_back(_text_line.at(x, y)); 167 | // } 168 | // text_line.emplace_back(row); 169 | // } 170 | 171 | return text_line; 172 | } 173 | } 174 | 175 | PYBIND11_PLUGIN(adaptor) { 176 | py::module m("adaptor", "NMS"); 177 | 178 | m.def("merge_quadrangle_n9", &lanms_adaptor::merge_quadrangle_n9, "merge quadrangels"); 179 | 180 | return m.ptr(); 181 | } 182 | 183 | -------------------------------------------------------------------------------- /MobileNetV2/mobilenet_v2.py: -------------------------------------------------------------------------------- 1 | """MobileNet v2 models for Keras. 
2 | 3 | # Reference 4 | - [Inverted Residuals and Linear Bottlenecks Mobile Networks for 5 | Classification, Detection and Segmentation] 6 | (https://arxiv.org/abs/1801.04381) 7 | """ 8 | 9 | 10 | from keras.models import Model 11 | from keras.layers import Input, Conv2D, GlobalAveragePooling2D, Dropout 12 | from keras.layers import Activation, BatchNormalization, Add, Reshape, DepthwiseConv2D 13 | from keras.utils.vis_utils import plot_model 14 | 15 | from keras import backend as K 16 | 17 | 18 | def _make_divisible(v, divisor, min_value=None): 19 | if min_value is None: 20 | min_value = divisor 21 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 22 | # Make sure that round down does not go down by more than 10%. 23 | if new_v < 0.9 * v: 24 | new_v += divisor 25 | return new_v 26 | 27 | 28 | def relu6(x): 29 | """Relu 6 30 | """ 31 | return K.relu(x, max_value=6.0) 32 | 33 | 34 | def _conv_block(inputs, filters, kernel, strides): 35 | """Convolution Block 36 | This function defines a 2D convolution operation with BN and relu6. 37 | 38 | # Arguments 39 | inputs: Tensor, input tensor of conv layer. 40 | filters: Integer, the dimensionality of the output space. 41 | kernel: An integer or tuple/list of 2 integers, specifying the 42 | width and height of the 2D convolution window. 43 | strides: An integer or tuple/list of 2 integers, 44 | specifying the strides of the convolution along the width and height. 45 | Can be a single integer to specify the same value for 46 | all spatial dimensions. 47 | 48 | # Returns 49 | Output tensor. 50 | """ 51 | 52 | channel_axis = 1 if K.image_data_format() == 'channels_first' else -1 53 | 54 | x = Conv2D(filters, kernel, padding='same', strides=strides)(inputs) 55 | x = BatchNormalization(axis=channel_axis)(x) 56 | return Activation(relu6)(x) 57 | 58 | 59 | def _bottleneck(inputs, filters, kernel, t, alpha, s, r=False): 60 | """Bottleneck 61 | This function defines a basic bottleneck structure. 
def _inverted_residual_block(inputs, filters, kernel, t, alpha, strides, n):
    """Inverted Residual Block
    Stacks ``n`` bottleneck layers: the first one uses ``strides`` without a
    residual connection, every following one uses stride 1 with a residual
    connection.

    # Arguments
        inputs: Tensor, input tensor of conv layer.
        filters: Integer, the dimensionality of the output space.
        kernel: An integer or tuple/list of 2 integers, specifying the
            width and height of the 2D convolution window.
        t: Integer, expansion factor.
            t is always applied to the input size.
        alpha: Number, width multiplier.
        strides: Stride handed to the first bottleneck layer.
        n: Integer, layer repeat times.

    # Returns
        Output tensor.
    """

    out = _bottleneck(inputs, filters, kernel, t, alpha, s=strides)

    repeats_left = n - 1
    while repeats_left > 0:
        out = _bottleneck(out, filters, kernel, t, alpha, s=1, r=True)
        repeats_left -= 1

    return out
x = Conv2D(k, (1, 1), padding='same')(x) 166 | 167 | x = Activation('softmax', name='softmax')(x) 168 | output = Reshape((k,))(x) 169 | 170 | model = Model(inputs, output) 171 | # plot_model(model, to_file='images/MobileNetv2.png', show_shapes=True) 172 | 173 | return model 174 | 175 | 176 | if __name__ == '__main__': 177 | model = MobileNetv2((224, 224, 3), 100, 1.0) 178 | print(model.summary()) 179 | -------------------------------------------------------------------------------- /pse/lanms.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "clipper/clipper.hpp" 4 | 5 | // locality-aware NMS 6 | namespace lanms { 7 | 8 | namespace cl = ClipperLib; 9 | 10 | struct Polygon { 11 | cl::Path poly; 12 | float score; 13 | }; 14 | 15 | float paths_area(const ClipperLib::Paths &ps) { 16 | float area = 0; 17 | for (auto &&p: ps) 18 | area += cl::Area(p); 19 | return area; 20 | } 21 | 22 | float poly_iou(const Polygon &a, const Polygon &b) { 23 | cl::Clipper clpr; 24 | clpr.AddPath(a.poly, cl::ptSubject, true); 25 | clpr.AddPath(b.poly, cl::ptClip, true); 26 | 27 | cl::Paths inter, uni; 28 | clpr.Execute(cl::ctIntersection, inter, cl::pftEvenOdd); 29 | clpr.Execute(cl::ctUnion, uni, cl::pftEvenOdd); 30 | 31 | auto inter_area = paths_area(inter), 32 | uni_area = paths_area(uni); 33 | return std::abs(inter_area) / std::max(std::abs(uni_area), 1.0f); 34 | } 35 | 36 | bool should_merge(const Polygon &a, const Polygon &b, float iou_threshold) { 37 | return poly_iou(a, b) > iou_threshold; 38 | } 39 | 40 | /** 41 | * Incrementally merge polygons 42 | */ 43 | class PolyMerger { 44 | public: 45 | PolyMerger(): score(0), nr_polys(0) { 46 | memset(data, 0, sizeof(data)); 47 | } 48 | 49 | /** 50 | * Add a new polygon to be merged. 
51 | */ 52 | void add(const Polygon &p_given) { 53 | Polygon p; 54 | if (nr_polys > 0) { 55 | // vertices of two polygons to merge may not in the same order; 56 | // we match their vertices by choosing the ordering that 57 | // minimizes the total squared distance. 58 | // see function normalize_poly for details. 59 | p = normalize_poly(get(), p_given); 60 | } else { 61 | p = p_given; 62 | } 63 | assert(p.poly.size() == 4); 64 | auto &poly = p.poly; 65 | auto s = p.score; 66 | data[0] += poly[0].X * s; 67 | data[1] += poly[0].Y * s; 68 | 69 | data[2] += poly[1].X * s; 70 | data[3] += poly[1].Y * s; 71 | 72 | data[4] += poly[2].X * s; 73 | data[5] += poly[2].Y * s; 74 | 75 | data[6] += poly[3].X * s; 76 | data[7] += poly[3].Y * s; 77 | 78 | score += p.score; 79 | 80 | nr_polys += 1; 81 | } 82 | 83 | inline std::int64_t sqr(std::int64_t x) { return x * x; } 84 | 85 | Polygon normalize_poly( 86 | const Polygon &ref, 87 | const Polygon &p) { 88 | 89 | std::int64_t min_d = std::numeric_limits::max(); 90 | size_t best_start = 0, best_order = 0; 91 | 92 | for (size_t start = 0; start < 4; start ++) { 93 | size_t j = start; 94 | std::int64_t d = ( 95 | sqr(ref.poly[(j + 0) % 4].X - p.poly[(j + 0) % 4].X) 96 | + sqr(ref.poly[(j + 0) % 4].Y - p.poly[(j + 0) % 4].Y) 97 | + sqr(ref.poly[(j + 1) % 4].X - p.poly[(j + 1) % 4].X) 98 | + sqr(ref.poly[(j + 1) % 4].Y - p.poly[(j + 1) % 4].Y) 99 | + sqr(ref.poly[(j + 2) % 4].X - p.poly[(j + 2) % 4].X) 100 | + sqr(ref.poly[(j + 2) % 4].Y - p.poly[(j + 2) % 4].Y) 101 | + sqr(ref.poly[(j + 3) % 4].X - p.poly[(j + 3) % 4].X) 102 | + sqr(ref.poly[(j + 3) % 4].Y - p.poly[(j + 3) % 4].Y) 103 | ); 104 | if (d < min_d) { 105 | min_d = d; 106 | best_start = start; 107 | best_order = 0; 108 | } 109 | 110 | d = ( 111 | sqr(ref.poly[(j + 0) % 4].X - p.poly[(j + 3) % 4].X) 112 | + sqr(ref.poly[(j + 0) % 4].Y - p.poly[(j + 3) % 4].Y) 113 | + sqr(ref.poly[(j + 1) % 4].X - p.poly[(j + 2) % 4].X) 114 | + sqr(ref.poly[(j + 1) % 4].Y - p.poly[(j + 2) % 
4].Y) 115 | + sqr(ref.poly[(j + 2) % 4].X - p.poly[(j + 1) % 4].X) 116 | + sqr(ref.poly[(j + 2) % 4].Y - p.poly[(j + 1) % 4].Y) 117 | + sqr(ref.poly[(j + 3) % 4].X - p.poly[(j + 0) % 4].X) 118 | + sqr(ref.poly[(j + 3) % 4].Y - p.poly[(j + 0) % 4].Y) 119 | ); 120 | if (d < min_d) { 121 | min_d = d; 122 | best_start = start; 123 | best_order = 1; 124 | } 125 | } 126 | 127 | Polygon r; 128 | r.poly.resize(4); 129 | auto j = best_start; 130 | if (best_order == 0) { 131 | for (size_t i = 0; i < 4; i ++) 132 | r.poly[i] = p.poly[(j + i) % 4]; 133 | } else { 134 | for (size_t i = 0; i < 4; i ++) 135 | r.poly[i] = p.poly[(j + 4 - i - 1) % 4]; 136 | } 137 | r.score = p.score; 138 | return r; 139 | } 140 | 141 | Polygon get() const { 142 | Polygon p; 143 | 144 | auto &poly = p.poly; 145 | poly.resize(4); 146 | auto score_inv = 1.0f / std::max(1e-8f, score); 147 | poly[0].X = data[0] * score_inv; 148 | poly[0].Y = data[1] * score_inv; 149 | poly[1].X = data[2] * score_inv; 150 | poly[1].Y = data[3] * score_inv; 151 | poly[2].X = data[4] * score_inv; 152 | poly[2].Y = data[5] * score_inv; 153 | poly[3].X = data[6] * score_inv; 154 | poly[3].Y = data[7] * score_inv; 155 | 156 | assert(score > 0); 157 | p.score = score; 158 | 159 | return p; 160 | } 161 | 162 | private: 163 | std::int64_t data[8]; 164 | float score; 165 | std::int32_t nr_polys; 166 | }; 167 | 168 | 169 | /** 170 | * The standard NMS algorithm. 
171 | */ 172 | std::vector standard_nms(std::vector &polys, float iou_threshold) { 173 | size_t n = polys.size(); 174 | if (n == 0) 175 | return {}; 176 | std::vector indices(n); 177 | std::iota(std::begin(indices), std::end(indices), 0); 178 | std::sort(std::begin(indices), std::end(indices), [&](size_t i, size_t j) { return polys[i].score > polys[j].score; }); 179 | 180 | std::vector keep; 181 | while (indices.size()) { 182 | size_t p = 0, cur = indices[0]; 183 | keep.emplace_back(cur); 184 | for (size_t i = 1; i < indices.size(); i ++) { 185 | if (!should_merge(polys[cur], polys[indices[i]], iou_threshold)) { 186 | indices[p ++] = indices[i]; 187 | } 188 | } 189 | indices.resize(p); 190 | } 191 | 192 | std::vector ret; 193 | for (auto &&i: keep) { 194 | ret.emplace_back(polys[i]); 195 | } 196 | return ret; 197 | } 198 | 199 | std::vector 200 | merge_quadrangle_n9(const float *data, size_t n, float iou_threshold) { 201 | using cInt = cl::cInt; 202 | 203 | // first pass 204 | std::vector polys; 205 | for (size_t i = 0; i < n; i ++) { 206 | auto p = data + i * 9; 207 | Polygon poly{ 208 | { 209 | {cInt(p[0]), cInt(p[1])}, 210 | {cInt(p[2]), cInt(p[3])}, 211 | {cInt(p[4]), cInt(p[5])}, 212 | {cInt(p[6]), cInt(p[7])}, 213 | }, 214 | p[8], 215 | }; 216 | 217 | if (polys.size()) { 218 | // merge with the last one 219 | auto &bpoly = polys.back(); 220 | if (should_merge(poly, bpoly, iou_threshold)) { 221 | PolyMerger merger; 222 | merger.add(bpoly); 223 | merger.add(poly); 224 | bpoly = merger.get(); 225 | } else { 226 | polys.emplace_back(poly); 227 | } 228 | } else { 229 | polys.emplace_back(poly); 230 | } 231 | } 232 | return standard_nms(polys, iou_threshold); 233 | } 234 | } 235 | -------------------------------------------------------------------------------- /test_ctw1500.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import sys 4 | import time 5 | import collections 6 | import argparse 7 | 
def write_result_as_txt(image_name, bboxes, path):
    """Write one line of comma-separated integer coordinates per box.

    The output file is ``<path>/<image_name>.txt``; ``path`` is created when
    it does not exist yet.

    Args:
        image_name: Base name of the result file, without extension.
        bboxes: Iterable of flat coordinate sequences; every value is
            truncated with ``int()``.
        path: Output directory.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(path, exist_ok=True)

    filename = util.io.join_path(path, '%s.txt'%(image_name))
    lines = []
    for bbox in bboxes:
        values = [int(v) for v in bbox]
        if not values:
            # An empty box used to raise IndexError on values[0]; skip it
            # instead of emitting a blank result line.
            continue
        lines.append(", ".join("%d" % v for v in values) + '\n')
    util.io.write_lines(filename, lines)
def test(args):
    """Run text detection over the CTW1500 test loader and save the results.

    For every image the model output is binarised, expanded with the Python
    PSE implementation and turned into one contour per text instance. The
    contours are written to ``outputs/submit_ctw1500/`` and a visualisation
    to ``outputs/vis_ctw1500/``. Progress and a running FPS figure are
    printed to stdout.

    Args:
        args: Parsed command-line namespace. Reads ``arch``, ``resume``,
            ``long_size``, ``scale``, ``binary_th``, ``kernel_num``,
            ``min_kernel_area``, ``min_area`` and ``min_score``.

    Raises:
        ValueError: If ``args.arch`` is not one of the supported backbones.
    """
    data_loader = CTW1500TestLoader(long_size=args.long_size)
    test_loader = ctw_test_loader(data_loader, 1)

    # Setup Model
    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True, num_classes=7, scale=args.scale)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=7, scale=args.scale)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=7, scale=args.scale)
    elif args.arch == "resnet18":
        model = models.resnet18(pretrained=True, num_classes=7, scale=args.scale)
    else:
        # An unknown name used to leave `model` unbound and fail later with
        # an unrelated NameError.
        raise ValueError("Unsupported arch '{}'".format(args.arch))

    if args.resume is not None:
        print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
        model.load_weights(args.resume)
        print("Loaded checkpoint '{}' ".format(args.resume,))
        sys.stdout.flush()
    else:
        print("No checkpoint found at '{}'".format(args.resume))
        sys.stdout.flush()

    # Area thresholds are given at full resolution; the prediction map is
    # `scale` times smaller in each dimension.
    scale_sq = args.scale * args.scale
    min_kernel_area = args.min_kernel_area / scale_sq
    min_area = args.min_area / scale_sq

    total_frame = 0.0
    total_time = 0.0
    for idx, (org_img, img, data_length) in enumerate(test_loader):
        print('progress: %d / %d'%(idx, data_length))
        sys.stdout.flush()

        org_img = org_img.numpy().astype('uint8')[0]
        text_box = org_img.copy()

        start = time.time()
        outputs = model(img)
        # Move the channel axis to position 1 so maps are indexed [:, c, :, :].
        outputs = tf.transpose(outputs,(0,3,1,2))

        score = tf.sigmoid(outputs[:, 0, :, :])
        # Binarise: 1 where the raw output exceeds binary_th, 0 below it.
        outputs = (tf.sign(outputs - args.binary_th) + 1) / 2

        text = outputs[:, 0, :, :]
        # Restrict every kernel map to the text region (channel 0).
        kernels = outputs[:, 0:args.kernel_num, :, :] * text

        score = score.numpy()[0].astype(np.float32)
        text = text.numpy()[0].astype(np.uint8)
        kernels = kernels.numpy()[0].astype(np.uint8)

        # c++ version pse (build problems, unused for now)
        #pred = pse(kernels, args.min_kernel_area / (args.scale * args.scale))
        # python version pse
        pred = pypse(kernels, min_kernel_area)

        # (x, y) factors mapping prediction coordinates back to the original image.
        scale = (org_img.shape[1] * 1.0 / pred.shape[1], org_img.shape[0] * 1.0 / pred.shape[0])
        label = pred
        label_num = np.max(label) + 1
        bboxes = []
        for i in range(1, label_num):
            # Compute the instance mask once; it is needed three times below.
            mask = label == i

            if np.count_nonzero(mask) < min_area:
                continue

            score_i = np.mean(score[mask])
            if score_i < args.min_score:
                continue

            binary = np.zeros(label.shape, dtype='uint8')
            binary[mask] = 1

            # findContours returns (image, contours, hierarchy) on OpenCV 3.x
            # and (contours, hierarchy) on 2.x/4.x; [-2] is the contour list
            # in both layouts.
            contours = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
            if len(contours) == 0:
                continue
            bbox = contours[0]

            if bbox.shape[0] <= 2:
                continue

            bbox = bbox * scale
            bbox = bbox.astype('int32')
            bboxes.append(bbox.reshape(-1))

        end = time.time()
        total_frame += 1
        total_time += (end - start)
        print('fps: %.2f'%(total_frame / total_time))
        sys.stdout.flush()

        for bbox in bboxes:
            cv2.drawContours(text_box, [bbox.reshape(bbox.shape[0] // 2, 2)], -1, (0, 255, 0), 2)

        image_name = data_loader.img_paths[idx].split('/')[-1].split('.')[0]
        write_result_as_txt(image_name, bboxes, 'outputs/submit_ctw1500/')

        text_box = cv2.resize(text_box, (text.shape[1], text.shape[0]))
        debug(idx, data_loader.img_paths, [[text_box]], 'outputs/vis_ctw1500/')
9 | */ 10 | 11 | #pragma once 12 | 13 | #include "pybind11.h" 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | // Backport the PyDateTime_DELTA functions from Python3.3 if required 20 | #ifndef PyDateTime_DELTA_GET_DAYS 21 | #define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days) 22 | #endif 23 | #ifndef PyDateTime_DELTA_GET_SECONDS 24 | #define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds) 25 | #endif 26 | #ifndef PyDateTime_DELTA_GET_MICROSECONDS 27 | #define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds) 28 | #endif 29 | 30 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE) 31 | NAMESPACE_BEGIN(detail) 32 | 33 | template class duration_caster { 34 | public: 35 | typedef typename type::rep rep; 36 | typedef typename type::period period; 37 | 38 | typedef std::chrono::duration> days; 39 | 40 | bool load(handle src, bool) { 41 | using namespace std::chrono; 42 | 43 | // Lazy initialise the PyDateTime import 44 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 45 | 46 | if (!src) return false; 47 | // If invoked with datetime.delta object 48 | if (PyDelta_Check(src.ptr())) { 49 | value = type(duration_cast>( 50 | days(PyDateTime_DELTA_GET_DAYS(src.ptr())) 51 | + seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr())) 52 | + microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr())))); 53 | return true; 54 | } 55 | // If invoked with a float we assume it is seconds and convert 56 | else if (PyFloat_Check(src.ptr())) { 57 | value = type(duration_cast>(duration(PyFloat_AsDouble(src.ptr())))); 58 | return true; 59 | } 60 | else return false; 61 | } 62 | 63 | // If this is a duration just return it back 64 | static const std::chrono::duration& get_duration(const std::chrono::duration &src) { 65 | return src; 66 | } 67 | 68 | // If this is a time_point get the time_since_epoch 69 | template static std::chrono::duration get_duration(const std::chrono::time_point> &src) { 70 | return src.time_since_epoch(); 71 | } 72 | 73 
| static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) { 74 | using namespace std::chrono; 75 | 76 | // Use overloaded function to get our duration from our source 77 | // Works out if it is a duration or time_point and get the duration 78 | auto d = get_duration(src); 79 | 80 | // Lazy initialise the PyDateTime import 81 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 82 | 83 | // Declare these special duration types so the conversions happen with the correct primitive types (int) 84 | using dd_t = duration>; 85 | using ss_t = duration>; 86 | using us_t = duration; 87 | 88 | auto dd = duration_cast(d); 89 | auto subd = d - dd; 90 | auto ss = duration_cast(subd); 91 | auto us = duration_cast(subd - ss); 92 | return PyDelta_FromDSU(dd.count(), ss.count(), us.count()); 93 | } 94 | 95 | PYBIND11_TYPE_CASTER(type, _("datetime.timedelta")); 96 | }; 97 | 98 | // This is for casting times on the system clock into datetime.datetime instances 99 | template class type_caster> { 100 | public: 101 | typedef std::chrono::time_point type; 102 | bool load(handle src, bool) { 103 | using namespace std::chrono; 104 | 105 | // Lazy initialise the PyDateTime import 106 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 107 | 108 | if (!src) return false; 109 | 110 | std::tm cal; 111 | microseconds msecs; 112 | 113 | if (PyDateTime_Check(src.ptr())) { 114 | cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr()); 115 | cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr()); 116 | cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr()); 117 | cal.tm_mday = PyDateTime_GET_DAY(src.ptr()); 118 | cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1; 119 | cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900; 120 | cal.tm_isdst = -1; 121 | msecs = microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr())); 122 | } else if (PyDate_Check(src.ptr())) { 123 | cal.tm_sec = 0; 124 | cal.tm_min = 0; 125 | cal.tm_hour = 0; 126 | cal.tm_mday = PyDateTime_GET_DAY(src.ptr()); 127 | cal.tm_mon = 
PyDateTime_GET_MONTH(src.ptr()) - 1; 128 | cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900; 129 | cal.tm_isdst = -1; 130 | msecs = microseconds(0); 131 | } else if (PyTime_Check(src.ptr())) { 132 | cal.tm_sec = PyDateTime_TIME_GET_SECOND(src.ptr()); 133 | cal.tm_min = PyDateTime_TIME_GET_MINUTE(src.ptr()); 134 | cal.tm_hour = PyDateTime_TIME_GET_HOUR(src.ptr()); 135 | cal.tm_mday = 1; // This date (day, month, year) = (1, 0, 70) 136 | cal.tm_mon = 0; // represents 1-Jan-1970, which is the first 137 | cal.tm_year = 70; // earliest available date for Python's datetime 138 | cal.tm_isdst = -1; 139 | msecs = microseconds(PyDateTime_TIME_GET_MICROSECOND(src.ptr())); 140 | } 141 | else return false; 142 | 143 | value = system_clock::from_time_t(std::mktime(&cal)) + msecs; 144 | return true; 145 | } 146 | 147 | static handle cast(const std::chrono::time_point &src, return_value_policy /* policy */, handle /* parent */) { 148 | using namespace std::chrono; 149 | 150 | // Lazy initialise the PyDateTime import 151 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; } 152 | 153 | std::time_t tt = system_clock::to_time_t(time_point_cast(src)); 154 | // this function uses static memory so it's best to copy it out asap just in case 155 | // otherwise other code that is using localtime may break this (not just python code) 156 | std::tm localtime = *std::localtime(&tt); 157 | 158 | // Declare these special duration types so the conversions happen with the correct primitive types (int) 159 | using us_t = duration; 160 | 161 | return PyDateTime_FromDateAndTime(localtime.tm_year + 1900, 162 | localtime.tm_mon + 1, 163 | localtime.tm_mday, 164 | localtime.tm_hour, 165 | localtime.tm_min, 166 | localtime.tm_sec, 167 | (duration_cast(src.time_since_epoch() % seconds(1))).count()); 168 | } 169 | PYBIND11_TYPE_CASTER(type, _("datetime.datetime")); 170 | }; 171 | 172 | // Other clocks that are not the system clock are not measured as datetime.datetime objects 173 | // since they are not 
measured on calendar time. So instead we just make them timedeltas 174 | // Or if they have passed us a time as a float we convert that 175 | template class type_caster> 176 | : public duration_caster> { 177 | }; 178 | 179 | template class type_caster> 180 | : public duration_caster> { 181 | }; 182 | 183 | NAMESPACE_END(detail) 184 | NAMESPACE_END(PYBIND11_NAMESPACE) 185 | -------------------------------------------------------------------------------- /test_id41k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import sys 4 | import time 5 | import collections 6 | import argparse 7 | import numpy as np 8 | import tensorflow as tf 9 | from dataset import CTW1500TestLoader, ctw_test_loader 10 | import models 11 | import util 12 | # c++ version pse based on opencv 3+ 13 | #from pse import pse 14 | # python pse 15 | from pypse import pse as pypse 16 | 17 | def extend_3c(img): 18 | img = img.reshape(img.shape[0], img.shape[1], 1) 19 | img = np.concatenate((img, img, img), axis=2) 20 | return img 21 | 22 | def debug(idx, img_paths, imgs, output_root): 23 | if not os.path.exists(output_root): 24 | os.makedirs(output_root) 25 | 26 | col = [] 27 | for i in range(len(imgs)): 28 | row = [] 29 | for j in range(len(imgs[i])): 30 | # img = cv2.copyMakeBorder(imgs[i][j], 3, 3, 3, 3, cv2.BORDER_CONSTANT, value=[255, 0, 0]) 31 | row.append(imgs[i][j]) 32 | res = np.concatenate(row, axis=1) 33 | col.append(res) 34 | res = np.concatenate(col, axis=0) 35 | img_name = img_paths[idx].split('/')[-1] 36 | print(idx, '/', len(img_paths), img_name) 37 | cv2.imwrite(output_root + img_name, res) 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | def write_result_as_txt(image_name, bboxes, path): 46 | if not os.path.exists(path): 47 | os.makedirs(path) 48 | 49 | filename = util.io.join_path(path, '%s.txt'%(image_name)) 50 | lines = [] 51 | for b_idx, bbox in enumerate(bboxes): 52 | values = [int(v) for v in bbox] 53 | # line = "%d, %d, 
%d, %d, %d, %d, %d, %d\n"%tuple(values) 54 | line = "%d"%values[0] 55 | for v_id in range(1, len(values)): 56 | line += ", %d"%values[v_id] 57 | line += '\n' 58 | lines.append(line) 59 | util.io.write_lines(filename, lines) 60 | 61 | def polygon_from_points(points): 62 | """ 63 | Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4 64 | """ 65 | resBoxes=np.empty([1, 8],dtype='int32') 66 | resBoxes[0, 0] = int(points[0]) 67 | resBoxes[0, 4] = int(points[1]) 68 | resBoxes[0, 1] = int(points[2]) 69 | resBoxes[0, 5] = int(points[3]) 70 | resBoxes[0, 2] = int(points[4]) 71 | resBoxes[0, 6] = int(points[5]) 72 | resBoxes[0, 3] = int(points[6]) 73 | resBoxes[0, 7] = int(points[7]) 74 | pointMat = resBoxes[0].reshape([2, 4]).T 75 | return plg.Polygon(pointMat) 76 | 77 | 78 | 79 | 80 | def test(args): 81 | data_loader = CTW1500TestLoader(long_size=args.long_size) 82 | test_loader = ctw_test_loader(data_loader, 1) 83 | 84 | # Setup Model 85 | if args.arch == "resnet50": 86 | model = models.resnet50(pretrained=True, num_classes=7, scale=args.scale) 87 | elif args.arch == "resnet101": 88 | model = models.resnet101(pretrained=True, num_classes=7, scale=args.scale) 89 | elif args.arch == "resnet152": 90 | model = models.resnet152(pretrained=True, num_classes=7, scale=args.scale) 91 | elif args.arch == "resnet18": 92 | model = models.resnet18(pretrained=True, num_classes=7, scale=args.scale) 93 | if args.resume is not None: 94 | print("Loading model and optimizer from checkpoint '{}'".format(args.resume)) 95 | model.load_weights(args.resume) 96 | print("Loaded checkpoint '{}' ".format(args.resume,)) 97 | sys.stdout.flush() 98 | else: 99 | print("No checkpoint found at '{}'".format(args.resume)) 100 | sys.stdout.flush() 101 | 102 | total_frame = 0.0 103 | total_time = 0.0 104 | for idx, (org_img, img, data_length) in enumerate(test_loader): 105 | print('progress: %d / %d'%(idx, data_length)) 106 | sys.stdout.flush() 107 | 108 | 
org_img = org_img.numpy().astype('uint8')[0] 109 | text_box = org_img.copy() 110 | 111 | start = time.time() 112 | outputs = model(img) 113 | outputs = tf.transpose(outputs,(0,3,1,2)) 114 | 115 | score = tf.sigmoid(outputs[:, 0, :, :]) 116 | outputs = (tf.sign(outputs - args.binary_th) + 1) / 2 117 | 118 | text = outputs[:, 0, :, :] 119 | kernels = outputs[:, 0:args.kernel_num, :, :] * text 120 | 121 | score = score.numpy()[0].astype(np.float32) 122 | text = text.numpy()[0].astype(np.uint8) 123 | kernels = kernels.numpy()[0].astype(np.uint8) 124 | 125 | # c++ version pse #编译问题 暂时不用 126 | #pred = pse(kernels, args.min_kernel_area / (args.scale * args.scale)) 127 | # python version pse 128 | pred = pypse(kernels, args.min_kernel_area / (args.scale * args.scale)) 129 | 130 | # scale = (org_img.shape[0] * 1.0 / pred.shape[0], org_img.shape[1] * 1.0 / pred.shape[1]) 131 | scale = (org_img.shape[1] * 1.0 / pred.shape[1], org_img.shape[0] * 1.0 / pred.shape[0]) 132 | label = pred 133 | label_num = np.max(label) + 1 134 | bboxes = [] 135 | for i in range(1, label_num): 136 | points = np.array(np.where(label == i)).transpose((1, 0))[:, ::-1] 137 | 138 | if points.shape[0] < args.min_area / (args.scale * args.scale): 139 | continue 140 | 141 | score_i = np.mean(score[label == i]) 142 | if score_i < args.min_score: 143 | continue 144 | 145 | # rect = cv2.minAreaRect(points) 146 | binary = np.zeros(label.shape, dtype='uint8') 147 | binary[label == i] = 1 148 | 149 | contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 150 | contour = contours[0] 151 | # epsilon = 0.01 * cv2.arcLength(contour, True) 152 | # bbox = cv2.approxPolyDP(contour, epsilon, True) 153 | bbox = contour 154 | 155 | if bbox.shape[0] <= 2: 156 | continue 157 | 158 | # bbox = bbox * scale 159 | # bbox = bbox.astype('int32') 160 | # bboxes.append(bbox.reshape(-1)) 161 | bbox = bbox * scale 162 | bbox = cv2.minAreaRect(bbox.reshape((-1, 2)).astype(np.float32)) # 163 | bbox = 
cv2.boxPoints(bbox) # 164 | bbox = bbox.astype('int32') 165 | bboxes.append(bbox.reshape(-1)) 166 | 167 | end = time.time() 168 | total_frame += 1 169 | total_time += (end - start) 170 | print('fps: %.2f'%(total_frame / total_time)) 171 | sys.stdout.flush() 172 | 173 | for bbox in bboxes: 174 | cv2.drawContours(text_box, [bbox.reshape(bbox.shape[0] // 2, 2)], -1, (0, 255, 0), 2) 175 | 176 | image_name = data_loader.img_paths[idx].split('/')[-1].split('.jpg')[0] 177 | write_result_as_txt(image_name, bboxes, 'outputs/submit_ctw1500/') 178 | 179 | text_box = cv2.resize(text_box, (text.shape[1], text.shape[0])) 180 | debug(idx, data_loader.img_paths, [[text_box]], 'outputs/vis_ctw1500/') 181 | 182 | 183 | if __name__ == '__main__': 184 | parser = argparse.ArgumentParser(description='Hyperparams') 185 | parser.add_argument('--arch', nargs='?', type=str, default='resnet18') 186 | parser.add_argument('--resume', nargs='?', type=str, default='checkpoints/', 187 | help='Path to previous saved model to restart from') 188 | parser.add_argument('--binary_th', nargs='?', type=float, default=1.0, 189 | help='Path to previous saved model to restart from') 190 | parser.add_argument('--kernel_num', nargs='?', type=int, default=3, 191 | help='Path to previous saved model to restart from') 192 | parser.add_argument('--scale', nargs='?', type=int, default=1, 193 | help='Path to previous saved model to restart from') 194 | parser.add_argument('--long_size', nargs='?', type=int, default=1280, 195 | help='') 196 | parser.add_argument('--min_kernel_area', nargs='?', type=float, default=10.0, 197 | help='min kernel area') 198 | parser.add_argument('--min_area', nargs='?', type=float, default=300.0, 199 | help='min area') 200 | parser.add_argument('--min_score'test_ctw1500.py, nargs='?', type=float, default=0.93, 201 | help='min score') 202 | 203 | args = parser.parse_args() 204 | test(args) --------------------------------------------------------------------------------