├── .gitignore
├── LICENSE
├── README.md
├── compute_hmean.py
├── config.py
├── data_loader.py
├── dataset
│   ├── test
│   │   └── README.txt
│   ├── test_compute_hmean
│   │   ├── __pycache__
│   │   │   ├── rrc_evaluation_funcs.cpython-35.pyc
│   │   │   └── script.cpython-35.pyc
│   │   ├── readme.txt
│   │   ├── rrc_evaluation_funcs.py
│   │   └── script.py
│   ├── test_result
│   │   └── README
│   └── train
│       └── README.txt
├── demo
│   ├── result_img
│   │   ├── img_1.jpg
│   │   ├── img_16.jpg
│   │   ├── img_16.txt
│   │   ├── img_2.jpg
│   │   ├── res_img_1.txt
│   │   └── res_img_2.txt
│   └── test_img
│       ├── img_1.jpg
│       ├── img_16.jpg
│       └── img_2.jpg
├── eval.py
├── locality_aware_nms.py
├── loss.py
├── model.py
├── preprossing.py
├── requirements.txt
├── tensorboards
│   └── README.txt
├── tmp
│   ├── README.txt
│   └── backbone_net
│       └── README.txt
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Mingliang Zhang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# EAST: An Efficient and Accurate Scene Text Detector

### Introduction
This is a PyTorch re-implementation of [EAST: An Efficient and Accurate Scene Text Detector](https://arxiv.org/abs/1704.03155v2).
The features are summarized below:
+ Only the **RBOX** geometry is implemented.
+ Trained for the Incidental Scene Text Detection Challenge using only training images from ICDAR 2015 and 2013.
+ Differences from the original paper:
  + Uses MobileNet-v2 / ResNet-50 as the backbone
  + Uses dice loss (optimizes the IoU of the segmentation) rather than balanced cross entropy
  + Uses linear learning rate decay rather than staged learning rate decay
+ Every parameter lives in `config.py`; adjust it before running this project.
+ The pre-trained models provided achieve an F1-score of **75.05** (MobileNet-v2) / **81.32** (ResNet-50) on ICDAR 2015.
+ Speed on 720p (resolution 1280x720) images:
  + Graphics card: GTX 1080 Ti
  + Network fprop: **~15 ms** / **~50 ms**
  + NMS (C++): **~6 ms** / **~6 ms**
  + Overall: **~43 fps** / **~16 fps**

Thanks to the authors ([@argman](https://github.com/argman)) and ([@songdejia](https://github.com/songdejia)) for their code, and thanks for the help of my partners Wei Baole, Yang Yirong, Liu Hao, Wang Wei and Ma Yuting.

### Contents
1. [Installation](#installation)
2. [Download](#download)
3. [Train](#train)
4. [Test](#test)
5. [Demo](#demo)
6. [Compute-hmean](#compute-hmean)
7. [Examples](#examples)

### Installation
1. Any PyTorch version > 0.4.1 should work.
2. Other required libraries are listed in `requirements.txt`.

### Download
This project provides pre-trained models and datasets via the [BaiduYun link](https://pan.baidu.com/s/19yXbaWp0TvEdtsrAERo1Qg) (keyword: yu00) or the [Google drive link](https://drive.google.com/file/d/1oApGwPfuzuAfwNu2jSnWPIw88XBigHgc/view?usp=sharing):

1. **backbone_net**:
   The folder contains the pretrained backbone nets (MobileNet-v2 / ResNet-50), which should be put into `.\tmp\backbone_net`.
2. **pretrain-model-EAST**:
   The folder contains the pretrained EAST models, which should be put into `.\tmp`; you should also adjust `model.py` and `train.py` accordingly.
3. **train-dataset-ICDAR15,13**:
   The folder contains the train datasets of ICDAR15,13, which should be put into `.\dataset\train`.
4. **test-dataset-ICDAR15**:
   The folder contains the test dataset of ICDAR15, which should be put into `.\dataset\test`.
5. **test-groundtruth-ICDAR15**:
   The folder contains the ground-truth labels of the ICDAR15 test dataset, which should be put into `.\dataset\test_compute_hmean`.
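After downloading, the layout under `tmp/` should match the paths referenced in `config.py` (the file names below are the ones used by the default settings):

```
tmp/
├── epoch_1100_checkpoint.pth.tar   # pretrained EAST checkpoint (`checkpoint`)
└── backbone_net/
    └── mobilenet_v2.pth.tar        # pretrained backbone (`pretrained_basemodel_path`)
```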
### Train
If you want to train the model, adjust the parameters in `config.py` first.
1. `train_data_path` is the path of the train dataset; put the train dataset in this folder.
2. Depending on your hardware configuration, set the parameters `train_batch_size_per_gpu`, `num_workers`, `gpu_ids` and `gpu`.
3. Specify the pretrained backbone model via `pretrained_basemodel_path` and `pretrained`.
4. If you want to resume a saved EAST model, specify the path in `checkpoint` and set `resume`.
5. You can also tune the other hyperparameters, such as the learning rate, weight decay, decay steps and so on.
6. Then run
```
python train.py
```
*Note: the train and test datasets must follow the same format as provided in this project, in which each ground-truth text file has the same name as its image file (see the example layout below). Only `.jpg` image files are accepted; of course, you can change the code of the project.*
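For instance, a valid training folder might look like this (file names are illustrative):

```
dataset/train/
├── img_1.jpg
├── img_1.txt
├── img_2.jpg
└── img_2.txt
```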
### Test
If you want to test the model, you should also adjust the parameters in `config.py`.
1. `test_img_path` is the path of the test dataset; put the test dataset in this folder. `res_img_path` is the path where the result images and txt files will be saved.
2. You should also specify the pretrained model in `checkpoint`.
3. Then run
```
python eval.py
```

### Demo
If you only want to run some demos, download the pre-trained model provided in this project and change `config.py`.
1. Put demo images in `.\demo\test_img` and specify `test_img_path` and `res_img_path`; you will find the results in `.\demo\result_img`.
2. You should also specify the pretrained model in `checkpoint`.
3. Then run
```
python eval.py
```

### Compute-hmean
1. Put the ground truth `gt.zip` in `.\dataset\test_compute_hmean`.
2. In `config.py`, specify the path `compute_hmean_path`.
3. Then run
```
python compute_hmean.py
```
*Note: the result is shown on screen and is also recorded in `.\dataset\test_compute_hmean\log_epoch_hmean.txt`.*

### Examples
Here are some test examples on ICDAR 2015; enjoy the beautiful text boxes produced by the MobileNet-v2 EAST!

![image_1](demo/result_img/img_1.jpg)
![image_2](demo/result_img/img_2.jpg)
![image_16](demo/result_img/img_16.jpg)

Please let me know if you encounter any issues (my email: zhangmingliang2018@ia.ac.cn).

--------------------------------------------------------------------------------
/compute_hmean.py:
--------------------------------------------------------------------------------
from dataset.test_compute_hmean import rrc_evaluation_funcs, script
import config as cfg
import os
import glob
import zipfile
import time

gt_file_path = os.path.join(cfg.compute_hmean_path, 'gt.zip')
submit_file_path = os.path.join(cfg.compute_hmean_path, 'submit.zip')
log_file_path = os.path.join(cfg.compute_hmean_path, 'log_epoch_hmean.txt')
result_dir_path = cfg.compute_hmean_path

print('EAST <==> TEST <==> Compute Hmean <==> Begin')

with zipfile.ZipFile(submit_file_path, 'w') as azip:
    # The result path must exist; the result txt files are added to the archive
    # with LZMA compression. Mode 'w' creates a new archive and overwrites an
    # existing one; mode 'a' could be used instead to append.
    txt_files = []
    txt_files.extend(glob.glob(
        os.path.join(cfg.res_img_path, '*.{}'.format('txt'))))
    for txt_name in txt_files:
        azip.write(txt_name, os.path.basename(txt_name), compress_type=zipfile.ZIP_LZMA)

resDict = rrc_evaluation_funcs.main_evaluation({'g': gt_file_path, 's': submit_file_path, 'o': result_dir_path},
                                               script.default_evaluation_params, script.validate_data,
                                               script.evaluate_method)

# print(resDict)
recall = resDict['method']['recall']
precision = resDict['method']['precision']
hmean = resDict['method']['hmean']

# print('EAST <==> Evaluation <==> Precision:%.4f Recall:%.4f Hmean %.4f <==> Done' % (precision, recall, hmean))
with open(log_file_path, 'a') as f:

    f.write(time.strftime("%a %b %d %H:%M:%S %Y", time.localtime()))
    f.write('\nEAST <==> Evaluation <==> Precision:{:.4f} Recall:{:.4f} Hmean:{:.4f} <==> Done\n'
            .format(precision, recall, hmean))

print('\nEAST <==> TEST <==> Compute Hmean <==> End')

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
# data-config
import numpy as np

train_data_path = './dataset/train/'
train_batch_size_per_gpu = 14  # 14
num_workers = 24  # 24
gpu_ids = [0]  # [0,1,2,3]
gpu = 1  # 4
input_size = 512  # image size after preprocessing and normalization
background_ratio = 3. / 8  # ratio of pure-background samples
random_scale = np.array([0.5, 1, 2.0, 3.0])  # scales for multi-scale sampling
geometry = 'RBOX'  # type of geometry map to use
max_image_large_side = 1280
max_text_size = 800
min_text_size = 10
min_crop_side_ratio = 0.1
means = [100, 100, 100]
pretrained = True  # whether to load the pretrained backbone model
pretrained_basemodel_path = './tmp/backbone_net/mobilenet_v2.pth.tar'
pre_lr = 1e-4  # initial learning rate of the backbone
lr = 1e-3  # initial learning rate of the remaining layers
decay_steps = 50  # decayed_learning_rate = learning_rate * decay_rate ^ (global_epoch / decay_steps)
decay_rate = 0.97
init_type = 'xavier'  # weight initialization method
resume = True  # whether to resume the whole network from a saved checkpoint
checkpoint = './tmp/epoch_1100_checkpoint.pth.tar'  # full path and file name of the checkpoint
max_epochs = 1000  # maximum number of training epochs
l2_weight_decay = 1e-6  # weight of the L2 regularization penalty
print_freq = 10  # print the loss every 10 batches
save_eval_iteration = 50  # save the model and run an evaluation every save_eval_iteration epochs
save_model_path = './tmp/'  # model save path
test_img_path = './demo/test_img/'  # demo test images: './demo/test_img/'; dataset testing: './dataset/test/'
res_img_path = './demo/result_img/'  # demo results: './demo/result_img/'; dataset testing: './dataset/test_result/'
write_images = True  # whether to write image results
score_map_thresh = 0.8  # confidence threshold on the score map
box_thresh = 0.1  # threshold on the average score inside a text box
nms_thres = 0.2  # IoU threshold of locality-aware NMS
compute_hmean_path = './dataset/test_compute_hmean/'
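A quick, self-contained illustration (not code from the repo) of the decay formula noted next to `decay_steps` above, using the default values:

```python
lr, decay_rate, decay_steps = 1e-3, 0.97, 50
for epoch in (0, 50, 500, 1000):
    # decayed_learning_rate = learning_rate * decay_rate ** (global_epoch / decay_steps)
    print(epoch, lr * decay_rate ** (epoch / decay_steps))
# 0 -> 1.000e-03, 50 -> 9.700e-04, 500 -> ~7.374e-04, 1000 -> ~5.438e-04
```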
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
from torch.utils import data
import numpy as np
import preprossing
import config as cfg


class custom_dset(data.Dataset):
    def __init__(self, transform=None):

        # collect the image file paths
        self.img_path_list = preprossing.get_images(cfg.train_data_path)
        self.transform = transform
        # print(self.img_path_list)

    def __getitem__(self, index):

        status = True
        while status:
            # img            the preprocessed image
            # img_path       path of the preprocessed image file
            # score_map      confidence (score) map
            # geo_map        geometry map
            # training_mask  training mask
            # print(self.img_path_list)
            img, img_path, score_map, geo_map, training_mask = preprossing.generator(
                index=index,
                input_size=cfg.input_size,
                background_ratio=cfg.background_ratio,
                random_scale=cfg.random_scale,
                image_list=self.img_path_list)

            if not (img is None):

                status = False
                if self.transform is not None:
                    # apply the transform: a 512,512,3 ndarray should be transformed to 3,512,512
                    img = self.transform(img)
                    score_map = self.transform(score_map)
                    geo_map = self.transform(geo_map)
                    training_mask = self.transform(training_mask)
                # whatever is returned here is exactly what each batch yields during training
                return img, img_path, score_map, geo_map, training_mask

            else:
                index = np.random.randint(0, self.__len__())
                # print('Exception in getitem, and choose another index:{}'.format(index))

    def __len__(self):
        return len(self.img_path_list)

# img = bs * 512 * 512 * 3
# score_map = bs * 128 * 128 * 1
# geo_map = bs * 128 * 128 * 5
# training_mask = bs * 128 * 128 * 1
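A minimal consumption sketch (not part of the repo), assuming `torchvision.transforms.ToTensor` as the transform that performs the HWC-to-CHW conversion mentioned above:

```python
import torchvision.transforms as transforms
from torch.utils import data
from data_loader import custom_dset

dset = custom_dset(transform=transforms.ToTensor())
loader = data.DataLoader(dset, batch_size=14, shuffle=True, num_workers=4)
for img, img_path, score_map, geo_map, training_mask in loader:
    # img: (bs, 3, 512, 512); score_map: (bs, 1, 128, 128)
    # geo_map: (bs, 5, 128, 128); training_mask: (bs, 1, 128, 128)
    break
```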
--------------------------------------------------------------------------------
/dataset/test/README.txt:
--------------------------------------------------------------------------------
Test dataset is saved in this folder

--------------------------------------------------------------------------------
/dataset/test_compute_hmean/__pycache__/rrc_evaluation_funcs.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingliangzhang2018/EAST-Pytorch/082113e373f9815b62f449c77756f2a73a57c36e/dataset/test_compute_hmean/__pycache__/rrc_evaluation_funcs.cpython-35.pyc

--------------------------------------------------------------------------------
/dataset/test_compute_hmean/__pycache__/script.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingliangzhang2018/EAST-Pytorch/082113e373f9815b62f449c77756f2a73a57c36e/dataset/test_compute_hmean/__pycache__/script.cpython-35.pyc

--------------------------------------------------------------------------------
/dataset/test_compute_hmean/readme.txt:
--------------------------------------------------------------------------------
INSTRUCTIONS FOR THE STANDALONE SCRIPTS
Requirements:
- Python version 2.7.
- Each Task requires different Python modules. When running the script, if some module is not installed you will see a notification and installation instructions.

Procedure:
Download the ZIP file for the requested script and unzip it to a directory.

Open a terminal in the directory and run the command:
python script.py -g=gt.zip -s=submit.zip

If you have already installed all the required modules, then you will see the method's results or an error message if the submitted file is not correct.

parameters:
-g: Path of the Ground Truth file. In most cases, the Ground Truth will be included in the same Zip file named 'gt.zip', 'gt.txt' or 'gt.json'. If not, you will be able to get it on the Downloads page of the Task.
-s: Path of your method's results file.

Optional parameters:
-o: Path to a directory where to copy the file 'results.zip' that contains per-sample results.
-p: JSON string parameters to override the script default parameters. The parameters that can be overridden are inside the function 'default_evaluation_params' located at the beginning of the evaluation Script.

Example: python script.py -g=gt.zip -s=submit.zip -o=./ -p='{"IOU_CONSTRAINT": 0.8}'

--------------------------------------------------------------------------------
/dataset/test_compute_hmean/rrc_evaluation_funcs.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python2
# encoding: UTF-8
import json
import sys

sys.path.append('./')
import zipfile
import re
import os
import codecs
import importlib

"""
#from StringIO import StringIO
try:
    from StringIO import StringIO
except ImportError:
    from io import StringIO
"""


def print_help():
    sys.stdout.write(
        'Usage: python %s.py -g=<gt_file> -s=<subm_file> -o=<output_folder> [-i=<ignore> -p=<params>]' %
        sys.argv[0])
    sys.exit(2)
def load_zip_file_keys(file, fileNameRegExp=''):
    """
    Returns an array with the entries of the ZIP file that match with the regular expression.
    The keys are the names of the files or the capturing group defined in the fileNameRegExp.
    """
    try:
        archive = zipfile.ZipFile(file, mode='r', allowZip64=True)
    except:
        raise Exception('Error loading the ZIP archive.')

    pairs = []

    for name in archive.namelist():
        addFile = True
        keyName = name
        if fileNameRegExp != "":
            m = re.match(fileNameRegExp, name)
            if m == None:
                addFile = False
            else:
                if len(m.groups()) > 0:
                    keyName = m.group(1)

        if addFile:
            pairs.append(keyName)

    return pairs


def load_zip_file(file, fileNameRegExp='', allEntries=False):
    """
    Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file.
    The keys are the names of the files or the capturing group defined in the fileNameRegExp.
    allEntries validates that all entries in the ZIP file pass the fileNameRegExp.
    """
    try:
        archive = zipfile.ZipFile(file, mode='r', allowZip64=True)
    except:
        raise Exception('Error loading the ZIP archive')

    pairs = []

    for name in archive.namelist():
        addFile = True
        keyName = name
        if fileNameRegExp != "":
            m = re.match(fileNameRegExp, name)
            if m == None:
                addFile = False
            else:
                if len(m.groups()) > 0:
                    keyName = m.group(1)

        if addFile:
            pairs.append([keyName, archive.read(name)])
        else:
            if allEntries:
                raise Exception('ZIP entry not valid: %s' % name)

    return dict(pairs)
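# Editor's illustration (not part of the original file): with the ICDAR naming
# pattern, the captured image number becomes the dictionary key and the raw
# file bytes the value; the zip name and contents here are hypothetical.
#   >>> subm = load_zip_file('submit.zip', 'res_img_([0-9]+).txt')
#   >>> sorted(subm.keys())
#   ['1', '2']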
def decode_utf8(raw):
    """
    Returns a Unicode object on success, or None on failure
    """
    try:
        raw = codecs.decode(raw, 'utf-8', 'replace')
        # extracts BOM if exists
        raw = raw.encode('utf8')
        if raw.startswith(codecs.BOM_UTF8):
            raw = raw.replace(codecs.BOM_UTF8, b'', 1)
        return raw.decode('utf-8')
    except:
        return None


def validate_lines_in_file(fileName, file_contents, CRLF=True, LTRB=True, withTranscription=False, withConfidence=False,
                           imWidth=0, imHeight=0):
    """
    This function validates all lines of the file by calling the line validation function on each line
    """
    utf8File = decode_utf8(file_contents)
    if (utf8File is None):
        raise Exception("The file %s is not UTF-8" % fileName)

    lines = utf8File.split("\r\n" if CRLF else "\n")
    for line in lines:
        line = line.replace("\r", "").replace("\n", "")
        if (line != ""):
            try:
                validate_tl_line(line, LTRB, withTranscription, withConfidence, imWidth, imHeight)
            except Exception as e:
                raise Exception(
                    ("Line in sample not valid. Sample: %s Line: %s Error: %s" % (fileName, line, str(e))).encode(
                        'utf-8', 'replace'))


def validate_tl_line(line, LTRB=True, withTranscription=True, withConfidence=True, imWidth=0, imHeight=0):
    """
    Validate the format of the line. If the line is not valid an exception will be raised.
    If maxWidth and maxHeight are specified, all points must be inside the image bounds.
    Possible values are:
    LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
    LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
    """
    get_tl_line_values(line, LTRB, withTranscription, withConfidence, imWidth, imHeight)


def get_tl_line_values(line, LTRB=True, withTranscription=False, withConfidence=False, imWidth=0, imHeight=0):
    """
    Validate the format of the line. If the line is not valid an exception will be raised.
    If maxWidth and maxHeight are specified, all points must be inside the image bounds.
    Possible values are:
    LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription]
    LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription]
    Returns values from a textline. Points, [Confidences], [Transcriptions]
    """
    confidence = 0.0
    transcription = ""
    points = []

    numPoints = 4

    if LTRB:

        numPoints = 4

        if withTranscription and withConfidence:
            m = re.match(
                r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$', line)
            if m == None:
                raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription")
        elif withConfidence:
            m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',
                         line)
            if m == None:
                raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence")
        elif withTranscription:
            m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$', line)
            if m == None:
                raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription")
        else:
            m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$', line)
            if m == None:
                raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax")

        xmin = int(m.group(1))
        ymin = int(m.group(2))
        xmax = int(m.group(3))
        ymax = int(m.group(4))
        if (xmax < xmin):
            raise Exception("Xmax value (%s) not valid (Xmax < Xmin)." % (xmax))
        if (ymax < ymin):
            raise Exception("Ymax value (%s) not valid (Ymax < Ymin)." % (ymax))

        points = [float(m.group(i)) for i in range(1, (numPoints + 1))]

        if (imWidth > 0 and imHeight > 0):
            validate_point_inside_bounds(xmin, ymin, imWidth, imHeight)
            validate_point_inside_bounds(xmax, ymax, imWidth, imHeight)

    else:

        numPoints = 8

        if withTranscription and withConfidence:
            m = re.match(
                r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',
                line)
            if m == None:
                raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription")
        elif withConfidence:
            m = re.match(
                r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',
                line)
            if m == None:
                raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence")
        elif withTranscription:
            m = re.match(
                r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',
                line)
            if m == None:
                raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription")
        else:
            m = re.match(
                r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',
                line)
            if m == None:
                raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4")

        points = [float(m.group(i)) for i in range(1, (numPoints + 1))]

        validate_clockwise_points(points)

        if (imWidth > 0 and imHeight > 0):
            validate_point_inside_bounds(points[0], points[1], imWidth, imHeight)
            validate_point_inside_bounds(points[2], points[3], imWidth, imHeight)
            validate_point_inside_bounds(points[4], points[5], imWidth, imHeight)
            validate_point_inside_bounds(points[6], points[7], imWidth, imHeight)

    if withConfidence:
        try:
            confidence = float(m.group(numPoints + 1))
        except ValueError:
            raise Exception("Confidence value must be a float")

    if withTranscription:
        posTranscription = numPoints + (2 if withConfidence else 1)
        transcription = m.group(posTranscription)
        m2 = re.match(r'^\s*\"(.*)\"\s*$', transcription)
        if m2 != None:  # Transcription with double quotes; we extract the value and replace escaped characters
            transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"")

    return points, confidence, transcription
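# Editor's illustration (not part of the original file): parsing one 8-point
# detection line that carries a confidence value; the coordinates are made up.
#   >>> pts, conf, _ = get_tl_line_values("38,43,920,43,920,215,38,215,0.87",
#   ...                                   LTRB=False, withTranscription=False,
#   ...                                   withConfidence=True)
#   >>> pts
#   [38.0, 43.0, 920.0, 43.0, 920.0, 215.0, 38.0, 215.0]
#   >>> conf
#   0.87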
def validate_point_inside_bounds(x, y, imWidth, imHeight):
    if (x < 0 or x > imWidth):
        raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" % (x, imWidth, imHeight))
    if (y < 0 or y > imHeight):
        raise Exception(
            "Y value (%s) not valid. Image dimensions: (%s,%s)" % (y, imWidth, imHeight))


def validate_clockwise_points(points):
    """
    Validates that the 4 points that delimit a polygon are in clockwise order.
    """

    if len(points) != 8:
        raise Exception("Points list not valid." + str(len(points)))

    point = [
        [int(points[0]), int(points[1])],
        [int(points[2]), int(points[3])],
        [int(points[4]), int(points[5])],
        [int(points[6]), int(points[7])]
    ]
    edge = [
        (point[1][0] - point[0][0]) * (point[1][1] + point[0][1]),
        (point[2][0] - point[1][0]) * (point[2][1] + point[1][1]),
        (point[3][0] - point[2][0]) * (point[3][1] + point[2][1]),
        (point[0][0] - point[3][0]) * (point[0][1] + point[3][1])
    ]

    summation = edge[0] + edge[1] + edge[2] + edge[3]
    if summation > 0:
        raise Exception(
            "Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.")
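# Editor's illustration (not part of the original file): the edge sum above is
# twice the signed area of the quadrilateral; in image coordinates (Y grows
# downwards) a clockwise ring yields a non-positive sum and is accepted.
#   >>> pts = [0, 0, 10, 0, 10, 10, 0, 10]   # clockwise when Y points down
#   >>> sum((pts[(2 * i + 2) % 8] - pts[2 * i]) * (pts[(2 * i + 3) % 8] + pts[2 * i + 1])
#   ...     for i in range(4))
#   -200
#   >>> validate_clockwise_points(pts)   # raises nothing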
def get_tl_line_values_from_file_contents(content, CRLF=True, LTRB=True, withTranscription=False, withConfidence=False,
                                          imWidth=0, imHeight=0, sort_by_confidences=True):
    """
    Returns all points, confidences and transcriptions of a file in lists. Valid line formats:
    xmin,ymin,xmax,ymax,[confidence],[transcription]
    x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription]
    """
    pointsList = []
    transcriptionsList = []
    confidencesList = []

    lines = content.split("\r\n" if CRLF else "\n")
    for line in lines:
        line = line.replace("\r", "").replace("\n", "")
        if (line != ""):
            points, confidence, transcription = get_tl_line_values(line, LTRB, withTranscription, withConfidence,
                                                                   imWidth, imHeight)
            pointsList.append(points)
            transcriptionsList.append(transcription)
            confidencesList.append(confidence)

    if withConfidence and len(confidencesList) > 0 and sort_by_confidences:
        confidencesList, pointsList, transcriptionsList = (list(t) for t in zip(
            *sorted(zip(confidencesList, pointsList, transcriptionsList), reverse=True)))

    return pointsList, confidencesList, transcriptionsList


def main_evaluation(p, default_evaluation_params_fn, validate_data_fn, evaluate_method_fn, show_result=True,
                    per_sample=True):
    """
    This process validates a method, evaluates it and, if it succeeds, generates a ZIP file with a JSON entry for each sample.
    Params:
    p: Dictionary of parameters with the GT/submission locations. If None is passed, the parameters sent by the system are used.
    default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
    validate_data_fn: points to a method that validates the correct format of the submission
    evaluate_method_fn: points to a function that evaluates the submission and returns a Dictionary with the results
    """
    # fall back to the command-line parameters before any path is read
    if (p == None):
        p = dict([s[1:].split('=') for s in sys.argv[1:]])
        if (len(sys.argv) < 2):
            print_help()

    # check path
    gt_path = p['g']
    submit_path = p['s']
    output = p['o']
    # print('gt', gt_path)
    # print('submit', submit_path)
    # print('output', output)

    evalParams = default_evaluation_params_fn()
    if 'p' in p.keys():
        evalParams.update(p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]))

    resDict = {'calculated': True, 'Message': '', 'method': '{}', 'per_sample': '{}'}
    try:
        validate_data_fn(p['g'], p['s'], evalParams)
        evalData = evaluate_method_fn(p['g'], p['s'], evalParams)
        resDict.update(evalData)
    except Exception as e:
        resDict['Message'] = str(e)
        resDict['calculated'] = False
    # print('gt', p['g'])  # address of gt
    # print('sub', p['s'])

    if not os.path.exists(p['o']):
        os.makedirs(p['o'])

    resultsOutputname = p['o'] + '/results.zip'
    outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True)

    del resDict['per_sample']
    if 'output_items' in resDict.keys():
        del resDict['output_items']

    outZip.writestr('method.json', json.dumps(resDict))

    if not resDict['calculated']:
        if show_result:
            sys.stderr.write('Error!\n' + resDict['Message'] + '\n\n')
        outZip.close()
        return resDict

    if per_sample == True:
        for k, v in evalData['per_sample'].items():
            outZip.writestr(k + '.json', json.dumps(v))

        if 'output_items' in evalData.keys():
            for k, v in evalData['output_items'].items():
                outZip.writestr(k, v)

    outZip.close()

    if show_result:
        sys.stdout.write("Calculated!")
        sys.stdout.write(json.dumps(resDict['method']))

    return resDict


def main_validation(default_evaluation_params_fn, validate_data_fn):
    """
    This process validates a method
    Params:
    default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation
    validate_data_fn: points to a method that validates the correct format of the submission
    """
    try:
        p = dict([s[1:].split('=') for s in sys.argv[1:]])
        evalParams = default_evaluation_params_fn()
        if 'p' in p.keys():
            evalParams.update(p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]))

        validate_data_fn(p['g'], p['s'], evalParams)
        print('SUCCESS')
        sys.exit(0)
    except Exception as e:
        print(str(e))
        sys.exit(101)

--------------------------------------------------------------------------------
/dataset/test_compute_hmean/script.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys

sys.path.append('.')
from collections import namedtuple
from dataset.test_compute_hmean import rrc_evaluation_funcs
import importlib
from shapely.geometry import Polygon as plg
import numpy as np
def evaluation_imports():
    """
    evaluation_imports: Dictionary ( key = module name , value = alias ) with python modules used in the evaluation.
    """
    return {
        'Polygon': 'plg',
        'numpy': 'np'
    }


def default_evaluation_params():
    """
    default_evaluation_params: Default parameters to use for the validation and evaluation.
    """
    return {
        'IOU_CONSTRAINT': 0.5,
        'AREA_PRECISION_CONSTRAINT': 0.5,
        'GT_SAMPLE_NAME_2_ID': 'gt_img_([0-9]+).txt',
        'DET_SAMPLE_NAME_2_ID': 'res_img_([0-9]+).txt',
        'LTRB': False,  # LTRB: 2 points (left,top,right,bottom) or 4 points (x1,y1,x2,y2,x3,y3,x4,y4)
        'CRLF': False,  # Lines are delimited by Windows CRLF format
        'CONFIDENCES': False,  # Detections must include confidence value. AP will be calculated
        'PER_SAMPLE_RESULTS': True  # Generate per sample results and produce data for visualization
    }
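# Editor's illustration (not part of the original file): overriding a default
# parameter, which is what the -p command-line flag ends up doing.
#   >>> params = default_evaluation_params()
#   >>> params.update({'IOU_CONSTRAINT': 0.8})   # require stricter box matching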
def validate_data(gtFilePath, submFilePath, evaluationParams):
    """
    Method validate_data: validates that all files in the results folder are correct (have the correct names and contents).
    Validates also that there are no missing files in the folder.
    If an error is detected, the method raises it.
    """
    gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])

    subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)

    # Validate format of GroundTruth
    for k in gt:
        rrc_evaluation_funcs.validate_lines_in_file(k, gt[k], evaluationParams['CRLF'], evaluationParams['LTRB'], True)

    # Validate format of results
    for k in subm:
        if (k in gt) == False:
            raise Exception("The sample %s not present in GT" % k)

        rrc_evaluation_funcs.validate_lines_in_file(k, subm[k], evaluationParams['CRLF'], evaluationParams['LTRB'],
                                                    False, evaluationParams['CONFIDENCES'])


def evaluate_method(gtFilePath, submFilePath, evaluationParams):
    """
    Method evaluate_method: evaluate method and returns the results
    Results. Dictionary with the following values:
    - method (required)  Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 }
    - samples (optional) Per sample metrics. Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } }
    """
    """
    for module,alias in evaluation_imports().items():
        globals()[alias] = importlib.import_module(module)
    """

    def polygon_from_points(points):
        """
        Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4
        """
        resBoxes = np.empty([1, 8], dtype='int32')
        resBoxes[0, 0] = int(points[0])
        resBoxes[0, 4] = int(points[1])
        resBoxes[0, 1] = int(points[2])
        resBoxes[0, 5] = int(points[3])
        resBoxes[0, 2] = int(points[4])
        resBoxes[0, 6] = int(points[5])
        resBoxes[0, 3] = int(points[6])
        resBoxes[0, 7] = int(points[7])
        pointMat = resBoxes[0].reshape([2, 4]).T
        return plg(pointMat)

    def rectangle_to_polygon(rect):
        resBoxes = np.empty([1, 8], dtype='int32')
        resBoxes[0, 0] = int(rect.xmin)
        resBoxes[0, 4] = int(rect.ymax)
        resBoxes[0, 1] = int(rect.xmin)
        resBoxes[0, 5] = int(rect.ymin)
        resBoxes[0, 2] = int(rect.xmax)
        resBoxes[0, 6] = int(rect.ymin)
        resBoxes[0, 3] = int(rect.xmax)
        resBoxes[0, 7] = int(rect.ymax)

        pointMat = resBoxes[0].reshape([2, 4]).T

        return plg(pointMat)

    def rectangle_to_points(rect):
        points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin),
                  int(rect.xmin), int(rect.ymin)]
        return points

    def get_union(pD, pG):
        areaA = pD.area
        areaB = pG.area
        return areaA + areaB - get_intersection(pD, pG)

    def get_intersection_over_union(pD, pG):
        try:
            return get_intersection(pD, pG) / get_union(pD, pG)
        except:
            return 0

    def get_intersection(pD, pG):
        # return the area of the intersection, or 0 when the polygons do not overlap
        pInt = pD & pG
        if pInt.is_empty:
            return 0
        return pInt.area

    def compute_ap(confList, matchList, numGtCare):
        correct = 0
        AP = 0
        if len(confList) > 0:
            confList = np.array(confList)
            matchList = np.array(matchList)
            sorted_ind = np.argsort(-confList)
            confList = confList[sorted_ind]
            matchList = matchList[sorted_ind]
            for n in range(len(confList)):
                match = matchList[n]
                if match:
                    correct += 1
                    AP += float(correct) / (n + 1)

            if numGtCare > 0:
                AP /= numGtCare

        return AP

    perSampleMetrics = {}

    matchedSum = 0

    Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax')

    gt = rrc_evaluation_funcs.load_zip_file(gtFilePath, evaluationParams['GT_SAMPLE_NAME_2_ID'])
    subm = rrc_evaluation_funcs.load_zip_file(submFilePath, evaluationParams['DET_SAMPLE_NAME_2_ID'], True)

    numGlobalCareGt = 0
    numGlobalCareDet = 0

    arrGlobalConfidences = []
    arrGlobalMatches = []

    for ids, resFile in enumerate(gt):

        gtFile = rrc_evaluation_funcs.decode_utf8(gt[resFile])
        recall = 0
        precision = 0
        hmean = 0

        detMatched = 0

        iouMat = np.empty([1, 1])

        gtPols = []
        detPols = []

        gtPolPoints = []
        detPolPoints = []

        # Array of Ground Truth Polygons' keys marked as don't Care
        gtDontCarePolsNum = []
        # Array of Detected Polygons' matched with a don't Care GT
        detDontCarePolsNum = []

        pairs = []
        detMatchedNums = []

        arrSampleConfidences = []
        arrSampleMatch = []
        sampleAP = 0

        evaluationLog = ""
        pointsList, _, transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(
            gtFile, evaluationParams['CRLF'], evaluationParams['LTRB'], True, False)
        for n in range(len(pointsList)):
            points = pointsList[n]
            transcription = transcriptionsList[n]
            dontCare = transcription == "###"
            if evaluationParams['LTRB']:
                gtRect = Rectangle(*points)
                gtPol = rectangle_to_polygon(gtRect)
            else:
                gtPol = polygon_from_points(points)
            gtPols.append(gtPol)
            gtPolPoints.append(points)
            if dontCare:
                gtDontCarePolsNum.append(len(gtPols) - 1)

        evaluationLog += "GT polygons: " + str(len(gtPols)) + (
            " (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum) > 0 else "\n")

        if resFile in subm:

            detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile])

            pointsList, confidencesList, _ = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(
                detFile, evaluationParams['CRLF'], evaluationParams['LTRB'], False, evaluationParams['CONFIDENCES'])
            for n in range(len(pointsList)):
                points = pointsList[n]

                if evaluationParams['LTRB']:
                    detRect = Rectangle(*points)
                    detPol = rectangle_to_polygon(detRect)
                else:
                    detPol = polygon_from_points(points)
                detPols.append(detPol)
                detPolPoints.append(points)
                if len(gtDontCarePolsNum) > 0:
                    for dontCarePol in gtDontCarePolsNum:
                        dontCarePol = gtPols[dontCarePol]
                        intersected_area = get_intersection(dontCarePol, detPol)
                        pdDimensions = detPol.area
                        precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions
                        if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT']):
                            detDontCarePolsNum.append(len(detPols) - 1)
                            break

            evaluationLog += "DET polygons: " + str(len(detPols)) + (
                " (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum) > 0 else "\n")

            if len(gtPols) > 0 and len(detPols) > 0:
                # Calculate IoU and precision matrices
                outputShape = [len(gtPols), len(detPols)]
                iouMat = np.empty(outputShape)
                gtRectMat = np.zeros(len(gtPols), np.int8)
                detRectMat = np.zeros(len(detPols), np.int8)
                for gtNum in range(len(gtPols)):
                    for detNum in range(len(detPols)):
                        pG = gtPols[gtNum]
                        pD = detPols[detNum]
                        iouMat[gtNum, detNum] = get_intersection_over_union(pD, pG)

                for gtNum in range(len(gtPols)):
                    match = False
                    for detNum in range(len(detPols)):
                        if gtRectMat[gtNum] == 0 and detRectMat[
                            detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum:
                            if iouMat[gtNum, detNum] > evaluationParams['IOU_CONSTRAINT']:
                                gtRectMat[gtNum] = 1
                                detRectMat[detNum] = 1
                                detMatched += 1
                                pairs.append({'gt': gtNum, 'det': detNum})
                                detMatchedNums.append(detNum)
                                evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + "\n"
                                match = True

            if evaluationParams['CONFIDENCES']:
                for detNum in range(len(detPols)):
                    if detNum not in detDontCarePolsNum:
                        # we exclude the don't care detections
                        match = detNum in detMatchedNums

                        arrSampleConfidences.append(confidencesList[detNum])
                        arrSampleMatch.append(match)
                        arrGlobalConfidences.append(confidencesList[detNum])
                        arrGlobalMatches.append(match)

        numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
        numDetCare = (len(detPols) - len(detDontCarePolsNum))
        if numGtCare == 0:
            recall = float(1)
            precision = float(0) if numDetCare > 0 else float(1)
            sampleAP = precision
        else:
            recall = float(detMatched) / numGtCare
            precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
            if evaluationParams['CONFIDENCES'] and evaluationParams['PER_SAMPLE_RESULTS']:
                sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare)

        hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall)
        print('==' * 28)
        print('ID:{:3d} P {:3d}% R {:3d}% Hmean {:3d}% Matched:{:2d} GT:{:2d} Det:{:2d}'.format(
            ids + 1, int(precision * 100), int(recall * 100), int(hmean * 100),
            detMatched, numGtCare, numDetCare))
        matchedSum += detMatched
        numGlobalCareGt += numGtCare
        numGlobalCareDet += numDetCare

        if evaluationParams['PER_SAMPLE_RESULTS']:
            perSampleMetrics[resFile] = {
                'precision': precision,
                'recall': recall,
                'hmean': hmean,
                'pairs': pairs,
                'AP': sampleAP,
                'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
                'gtPolPoints': gtPolPoints,
                'detPolPoints': detPolPoints,
                'gtDontCare': gtDontCarePolsNum,
                'detDontCare': detDontCarePolsNum,
                'evaluationParams': evaluationParams,
                'evaluationLog': evaluationLog
            }

    # Compute the global AP and the method-level metrics
    AP = 0
    if evaluationParams['CONFIDENCES']:
        AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt)

    methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum) / numGlobalCareGt
    methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum) / numGlobalCareDet
    methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * methodRecall * methodPrecision / (
            methodRecall + methodPrecision)

    methodMetrics = {'precision': methodPrecision, 'recall': methodRecall, 'hmean': methodHmean, 'AP': AP}

    resDict = {'calculated': True, 'Message': '', 'method': methodMetrics, 'per_sample': perSampleMetrics}

    return resDict


if __name__ == '__main__':
    rrc_evaluation_funcs.main_evaluation(None, default_evaluation_params, validate_data, evaluate_method)

--------------------------------------------------------------------------------
/dataset/test_result/README:
--------------------------------------------------------------------------------
Test result (txt files) are saved in this folder.
--------------------------------------------------------------------------------
/dataset/train/README.txt:
--------------------------------------------------------------------------------
Train dataset is saved in this folder

--------------------------------------------------------------------------------
/demo/result_img/img_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingliangzhang2018/EAST-Pytorch/082113e373f9815b62f449c77756f2a73a57c36e/demo/result_img/img_1.jpg

--------------------------------------------------------------------------------
/demo/result_img/img_16.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingliangzhang2018/EAST-Pytorch/082113e373f9815b62f449c77756f2a73a57c36e/demo/result_img/img_16.jpg

--------------------------------------------------------------------------------
/demo/result_img/img_16.txt:
--------------------------------------------------------------------------------
515,457,662,455,663,484,515,486
673,451,795,452,795,487,673,486
1124,190,1203,191,1202,226,1124,225
1039,188,1116,188,1116,223,1039,223
1035,225,1088,224,1088,250,1036,251
1112,227,1187,227,1187,251,1112,251
741,262,774,261,774,273,742,274
632,268,664,270,663,284,632,282
665,270,689,271,688,284,665,283
748,271,775,271,775,284,748,284

--------------------------------------------------------------------------------
/demo/result_img/img_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingliangzhang2018/EAST-Pytorch/082113e373f9815b62f449c77756f2a73a57c36e/demo/result_img/img_2.jpg

--------------------------------------------------------------------------------
/demo/result_img/res_img_1.txt:
--------------------------------------------------------------------------------
375,-8,500,-2,497,81,372,76
490,153,550,152,550,169,491,170
375,196,421,197,421,210,374,210
393,115,450,116,450,130,393,129
495,191,526,190,527,203,495,204
373,155,406,155,406,170,373,170
487,116,520,116,520,129,487,128

--------------------------------------------------------------------------------
/demo/result_img/res_img_2.txt:
--------------------------------------------------------------------------------
744,303,801,318,789,372,732,357
601,174,633,176,632,196,601,194

--------------------------------------------------------------------------------
/demo/test_img/img_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingliangzhang2018/EAST-Pytorch/082113e373f9815b62f449c77756f2a73a57c36e/demo/test_img/img_1.jpg

--------------------------------------------------------------------------------
/demo/test_img/img_16.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingliangzhang2018/EAST-Pytorch/082113e373f9815b62f449c77756f2a73a57c36e/demo/test_img/img_16.jpg

--------------------------------------------------------------------------------
/demo/test_img/img_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingliangzhang2018/EAST-Pytorch/082113e373f9815b62f449c77756f2a73a57c36e/demo/test_img/img_2.jpg

--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
import os
import config as cfg
from model import East
import torch
import utils
import preprossing
import cv2
import numpy as np
import time


def predict(model, epoch):

    model.eval()
    img_path_list = preprossing.get_images(cfg.test_img_path)

    for index in range(len(img_path_list)):

        im_fn = img_path_list[index]
        im = cv2.imread(im_fn)
        if im is None:
            print("can not find image of %s" % (im_fn))
            continue
        im = im[:, :, ::-1]  # BGR -> RGB; check for a missing image before indexing

        print('EAST <==> TEST <==> epoch:{}, idx:{} <==> Begin'.format(epoch, index))
        # resize the image
        im_resized, (ratio_h, ratio_w) = utils.resize_image(im)
        im_resized = im_resized.astype(np.float32)
        # convert the image into a tensor
        im_resized = im_resized.transpose(2, 0, 1)
        im_tensor = torch.from_numpy(im_resized)
        im_tensor = im_tensor.cuda()
        # add a batch dimension
        im_tensor = im_tensor.unsqueeze(0)

        timer = {'net': 0, 'restore': 0, 'nms': 0}
        start = time.time()

        # run the network for inference
        score, geometry = model(im_tensor)

        timer['net'] = time.time() - start
        # convert score and geometry into numpy format
        score = score.permute(0, 2, 3, 1)
        geometry = geometry.permute(0, 2, 3, 1)
        score = score.data.cpu().numpy()
        geometry = geometry.data.cpu().numpy()
        # detect text boxes
        boxes, timer = utils.detect(score_map=score, geo_map=geometry, timer=timer,
                                    score_map_thresh=cfg.score_map_thresh, box_thresh=cfg.box_thresh,
                                    nms_thres=cfg.nms_thres)
        print('EAST <==> TEST <==> idx:{} <==> model:{:.2f}ms, restore:{:.2f}ms, nms:{:.2f}ms'
              .format(index, timer['net'] * 1000, timer['restore'] * 1000, timer['nms'] * 1000))
        if boxes is not None:
            boxes = boxes[:, :8].reshape((-1, 4, 2))
            boxes[:, :, 0] /= ratio_w
            boxes[:, :, 1] /= ratio_h

        # save to txt file
        if boxes is not None:
            res_file = os.path.join(
                cfg.res_img_path,
                'res_{}.txt'.format(
                    os.path.basename(im_fn).split('.')[0]))

            with open(res_file, 'w') as f:
                for box in boxes:
                    # to avoid submitting errors
                    box = utils.sort_poly(box.astype(np.int32))
                    if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3] - box[0]) < 5:
                        continue
                    f.write('{},{},{},{},{},{},{},{}\r\n'.format(
                        box[0, 0], box[0, 1], box[1, 0], box[1, 1], box[2, 0], box[2, 1], box[3, 0], box[3, 1],
                    ))
                    cv2.polylines(im[:, :, ::-1], [box.astype(np.int32).reshape((-1, 1, 2))], True,
                                  color=(255, 255, 0), thickness=1)
            print('EAST <==> TEST <==> Save txt at:{} <==> Done'.format(res_file))

        # write the image result
        if cfg.write_images:
            img_path = os.path.join(cfg.res_img_path, os.path.basename(im_fn))
            cv2.imwrite(img_path, im[:, :, ::-1])
            print('EAST <==> TEST <==> Save image at:{} <==> Done'.format(img_path))

        print('EAST <==> TEST <==> Record and Save <==> epoch:{}, ids:{} <==> Done'.format(epoch, index))


def main():
    # prepare output directory
    # global epoch
    print('EAST <==> TEST <==> Create Res_file and Img_with_box <==> Begin')
    result_root = os.path.abspath(cfg.res_img_path)
    if not os.path.exists(result_root):
        os.mkdir(result_root)

    print('EAST <==> Prepare <==> Network <==> Begin')
    model = East()
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu_ids)
    model.cuda()
    # load the model
    if os.path.isfile(cfg.checkpoint):
        print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(cfg.checkpoint))
        checkpoint = torch.load(cfg.checkpoint)
        epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(cfg.checkpoint))
    else:
        print('Cannot find checkpoint!!!')
        exit(1)

    predict(model, epoch)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/locality_aware_nms.py:
--------------------------------------------------------------------------------
import numpy as np
from shapely.geometry import Polygon


def intersection(g, p):
    g = Polygon(g[:8].reshape((4, 2)))
    p = Polygon(p[:8].reshape((4, 2)))
    if not g.is_valid or not p.is_valid:
        return 0
    inter = g.intersection(p).area
    union = g.area + p.area - inter
    if union == 0:
        return 0
    else:
        return inter / union


def weighted_merge(g, p):
    # g[0] = min(g[0], p[0])
    # g[1] = min(g[1], p[1])
    # g[4] = max(g[4], p[4])
    # g[5] = max(g[5], p[5])
    #
    # g[2] = max(g[2], p[2])
    # g[3] = min(g[3], p[3])
    # g[6] = min(g[6], p[6])
    # g[7] = max(g[7], p[7])

    g[:8] = (g[8] * g[:8] + p[8] * p[:8]) / (g[8] + p[8])
    g[8] = (g[8] + p[8])
    return g


def standard_nms(S, thres):
    order = np.argsort(S[:, 8])[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])

        inds = np.where(ovr <= thres)[0]
        order = order[inds + 1]

    return S[keep]


def nms_locality(polys, thres=0.3):
    '''
    locality aware nms of EAST
    :param polys: a N*9 numpy array. first 8 coordinates, then prob
    :return: boxes after nms
    '''
    S = []
    p = None
    for g in polys:
        if p is not None and intersection(g, p) > thres:
            p = weighted_merge(g, p)
        else:
            if p is not None:
                S.append(p)
            p = g
    if p is not None:
        S.append(p)

    if len(S) == 0:
        return np.array([])
    return standard_nms(np.array(S), thres)


if __name__ == '__main__':
    # 343,350,448,135,474,143,369,359
    print(Polygon(np.array([[343, 350], [448, 135],
                            [474, 143], [369, 359]])).area)
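A small usage sketch (not part of the repo): two heavily overlapping boxes, each given as 8 coordinates plus a score, are merged by the locality-aware pass before standard NMS:

```python
import numpy as np
from locality_aware_nms import nms_locality

polys = np.array([
    [10, 10, 110, 10, 110, 50, 10, 50, 0.9],
    [12, 11, 112, 11, 112, 51, 12, 51, 0.8],
], dtype=np.float64)
merged = nms_locality(polys, thres=0.3)
# one weighted-average box remains; its 9th column holds the summed scores
```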
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


# Both the ground truth and the prediction are assumed to be in the format
# bs * channels * W * H, i.e. the torch tensor layout.

def dice_coefficient(y_true_cls, y_pred_cls,
                     training_mask):
    '''
    dice loss
    :param y_true_cls:
    :param y_pred_cls:
    :param training_mask:
    :return:
    '''
    eps = 1e-5
    intersection = torch.sum(y_true_cls * y_pred_cls * training_mask)
    union = torch.sum(y_true_cls * training_mask) + torch.sum(y_pred_cls * training_mask) + eps
    loss = 1. - (2 * intersection / union)
    return loss


class LossFunc(nn.Module):
    def __init__(self):
        super(LossFunc, self).__init__()
        return

    def forward(self, y_true_cls, y_pred_cls, y_true_geo, y_pred_geo, training_mask):
        '''
        define the loss used for training, containing two parts:
        the first part uses dice loss instead of weighted logloss,
        the second part is the IoU loss defined in the paper
        :param y_true_cls: ground truth of text
        :param y_pred_cls: prediction of text
        :param y_true_geo: ground truth of geometry
        :param y_pred_geo: prediction of geometry
        :param training_mask: mask used in training to ignore some text annotated by ###
        :return:
        '''

        # classification (score map) loss
        classification_loss = dice_coefficient(y_true_cls, y_pred_cls, training_mask)
        # scale classification loss to match the iou loss part
        classification_loss *= 0.01

        # IoU loss
        # d1 -> top, d2 -> right, d3 -> bottom, d4 -> left
        d1_gt, d2_gt, d3_gt, d4_gt, theta_gt = torch.split(y_true_geo, 1, 1)
        d1_pred, d2_pred, d3_pred, d4_pred, theta_pred = torch.split(y_pred_geo, 1, 1)
        area_gt = (d1_gt + d3_gt) * (d2_gt + d4_gt)
        area_pred = (d1_pred + d3_pred) * (d2_pred + d4_pred)
        w_union = torch.min(d2_gt, d2_pred) + torch.min(d4_gt, d4_pred)
        h_union = torch.min(d1_gt, d1_pred) + torch.min(d3_gt, d3_pred)
        area_intersect = w_union * h_union
        area_union = area_gt + area_pred - area_intersect
        L_AABB = -torch.log((area_intersect + 1.0) / (area_union + 1.0))
        L_theta = 1 - torch.cos(theta_pred - theta_gt)
        L_g = L_AABB + 20 * L_theta

        return torch.mean(L_g * y_true_cls * training_mask) + classification_loss
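A minimal sketch (not part of the repo) of calling `LossFunc` with dummy tensors in the bs * channels * W * H layout assumed above:

```python
import torch
from loss import LossFunc

criterion = LossFunc()
y_true_cls = torch.rand(2, 1, 128, 128)
y_pred_cls = torch.rand(2, 1, 128, 128)
y_true_geo = torch.rand(2, 5, 128, 128)
y_pred_geo = torch.rand(2, 5, 128, 128)
training_mask = torch.ones(2, 1, 128, 128)
loss = criterion(y_true_cls, y_pred_cls, y_true_geo, y_pred_geo, training_mask)
print(loss.item())
```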
60 |         block = InvertedResidual
61 |         input_channel = 32
62 |         last_channel = 1280
63 |         inverted_residual_setting = [
64 |             # t, c, n, s
65 |             [1, 16, 1, 1],
66 |             [6, 24, 2, 2],
67 |             [6, 32, 3, 2],
68 |             [6, 64, 4, 2],
69 |             [6, 96, 3, 1],
70 |             [6, 160, 3, 2],
71 |             # [6, 320, 1, 1],
72 |         ]
73 | 
74 |         # building first layer
75 |         # assert input_size % 32 == 0
76 |         input_channel = int(input_channel * width_mult)
77 |         self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
78 |         self.features = [conv_bn(3, input_channel, 2)]
79 |         # building inverted residual blocks
80 |         for t, c, n, s in inverted_residual_setting:
81 |             output_channel = int(c * width_mult)
82 |             for i in range(n):
83 |                 if i == 0:
84 |                     self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
85 |                 else:
86 |                     self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
87 |                 input_channel = output_channel
88 | 
89 |         # make it nn.Sequential
90 |         self.features = nn.Sequential(*self.features)
91 | 
92 |         self._initialize_weights()
93 | 
94 |     def forward(self, x):
95 |         x = self.features(x)
96 |         # x = x.mean(3).mean(2)
97 |         # x = self.classifier(x)
98 |         return x
99 | 
100 |     def _initialize_weights(self):
101 |         for m in self.modules():
102 |             if isinstance(m, nn.Conv2d):
103 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
104 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
105 |                 if m.bias is not None:
106 |                     m.bias.data.zero_()
107 |             elif isinstance(m, nn.BatchNorm2d):
108 |                 m.weight.data.fill_(1)
109 |                 m.bias.data.zero_()
110 |             elif isinstance(m, nn.Linear):
111 |                 n = m.weight.size(1)
112 |                 m.weight.data.normal_(0, 0.01)
113 |                 m.bias.data.zero_()
114 | 
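As an aside (illustrative, not part of `model.py`), the stage boundaries that `East` taps below can be verified with a quick shape check; the indices 3/6/13/16 correspond to the `features[0:4]`, `[4:7]`, `[7:14]` and `[14:17]` slices used in `East.__init__`:

```python
import torch
from model import MobileNetV2

net = MobileNetV2()
x = torch.randn(1, 3, 512, 512)
taps = {3: 'f0', 6: 'f1', 13: 'f2', 16: 'f3'}
for i, layer in enumerate(net.features):
    x = layer(x)
    if i in taps:
        print(taps[i], tuple(x.shape))
# f0 (1, 24, 128, 128), f1 (1, 32, 64, 64),
# f2 (1, 96, 32, 32),   f3 (1, 160, 16, 16)
```

These channel counts (24/32/96/160) are exactly what the 1x1 merge convolutions in `East` expect: 160+96, 128+32 and 64+24 input channels.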
115 | 
116 | def mobilenet(pretrained=True, **kwargs):
117 |     """
118 |     Constructs a MobileNetV2 backbone.
119 |     Args:
120 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
121 |     """
122 |     model = MobileNetV2()
123 |     if pretrained:
124 |         model_dict = model.state_dict()
125 |         pretrained_dict = torch.load(cfg.pretrained_basemodel_path)
126 |         pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
127 |         model_dict.update(pretrained_dict)
128 |         model.load_state_dict(model_dict)
129 |         # state_dict = torch.load(cfg.pretrained_basemodel_path)  # add map_location='cpu' if no gpu
130 |         # model.load_state_dict(state_dict)
131 | 
132 |     return model
133 | 
134 | 
135 | class East(nn.Module):
136 |     def __init__(self):
137 |         super(East, self).__init__()
138 |         self.mobilenet = mobilenet(True)
139 |         # self.si for stage i
140 |         self.s1 = nn.Sequential(*list(self.mobilenet.children())[0][0:4])
141 |         self.s2 = nn.Sequential(*list(self.mobilenet.children())[0][4:7])
142 |         self.s3 = nn.Sequential(*list(self.mobilenet.children())[0][7:14])
143 |         self.s4 = nn.Sequential(*list(self.mobilenet.children())[0][14:17])
144 | 
145 |         self.conv1 = nn.Conv2d(160+96, 128, 1)
146 |         self.bn1 = nn.BatchNorm2d(128)
147 |         self.relu1 = nn.ReLU()
148 | 
149 |         self.conv2 = nn.Conv2d(128, 128, 3, padding=1)
150 |         self.bn2 = nn.BatchNorm2d(128)
151 |         self.relu2 = nn.ReLU()
152 | 
153 |         self.conv3 = nn.Conv2d(128+32, 64, 1)
154 |         self.bn3 = nn.BatchNorm2d(64)
155 |         self.relu3 = nn.ReLU()
156 | 
157 |         self.conv4 = nn.Conv2d(64, 64, 3, padding=1)
158 |         self.bn4 = nn.BatchNorm2d(64)
159 |         self.relu4 = nn.ReLU()
160 | 
161 |         self.conv5 = nn.Conv2d(64+24, 64, 1)
162 |         self.bn5 = nn.BatchNorm2d(64)
163 |         self.relu5 = nn.ReLU()
164 | 
165 |         self.conv6 = nn.Conv2d(64, 32, 3, padding=1)
166 |         self.bn6 = nn.BatchNorm2d(32)
167 |         self.relu6 = nn.ReLU()
168 | 
169 |         self.conv7 = nn.Conv2d(32, 32, 3, padding=1)
170 |         self.bn7 = nn.BatchNorm2d(32)
171 |         self.relu7 = nn.ReLU()
172 | 
173 |         self.conv8 = nn.Conv2d(32, 1, 1)
174 |         self.sigmoid1 = nn.Sigmoid()
175 |         self.conv9 = nn.Conv2d(32, 4, 1)
176 |         self.sigmoid2 = nn.Sigmoid()
177 |         self.conv10 = nn.Conv2d(32, 1, 1)
178 |         self.sigmoid3 = nn.Sigmoid()
179 |         self.unpool1 = nn.Upsample(scale_factor=2, mode='bilinear')
180 |         self.unpool2 = nn.Upsample(scale_factor=2, mode='bilinear')
181 |         self.unpool3 = nn.Upsample(scale_factor=2, mode='bilinear')
182 | 
183 |         # utils.init_weights([self.conv1,self.conv2,self.conv3,self.conv4,
184 |         #                     self.conv5,self.conv6,self.conv7,self.conv8,
185 |         #                     self.conv9,self.conv10,self.bn1,self.bn2,
186 |         #                     self.bn3,self.bn4,self.bn5,self.bn6,self.bn7])
187 | 
188 |     def forward(self, images):
189 |         images = utils.mean_image_subtraction(images)
190 | 
191 |         f0 = self.s1(images)
192 |         f1 = self.s2(f0)
193 |         f2 = self.s3(f1)
194 |         f3 = self.s4(f2)
195 | 
196 |         # _, f = self.mobilenet(images)
197 |         h = f3                  # bs 160 w/32 h/32
198 |         g = (self.unpool1(h))   # bs 160 w/16 h/16
199 |         c = self.conv1(torch.cat((g, f2), 1))
200 |         c = self.bn1(c)
201 |         c = self.relu1(c)
202 | 
203 |         h = self.conv2(c)       # bs 128 w/16 h/16
204 |         h = self.bn2(h)
205 |         h = self.relu2(h)
206 |         g = self.unpool2(h)     # bs 128 w/8 h/8
207 |         c = self.conv3(torch.cat((g, f1), 1))
208 |         c = self.bn3(c)
209 |         c = self.relu3(c)
210 | 
211 |         h = self.conv4(c)       # bs 64 w/8 h/8
212 |         h = self.bn4(h)
213 |         h = self.relu4(h)
214 |         g = self.unpool3(h)     # bs 64 w/4 h/4
215 |         c = self.conv5(torch.cat((g, f0), 1))
216 |         c = self.bn5(c)
217 |         c = self.relu5(c)
218 | 
219 |         h = self.conv6(c)       # bs 32 w/4 h/4
220 |         h = self.bn6(h)
221 |         h = self.relu6(h)
222 |         g = self.conv7(h)       # bs 32 w/4 h/4
223 |         g = self.bn7(g)
224 |         g = self.relu7(g)
225 | 
226 |         F_score = self.conv8(g)  # bs 1 w/4 h/4
227 |         F_score = self.sigmoid1(F_score)
228 |         geo_map = self.conv9(g)
229 |         geo_map = self.sigmoid2(geo_map) * 512
230 |         angle_map = self.conv10(g)
231 |         angle_map = self.sigmoid3(angle_map)
232 |         angle_map = (angle_map - 0.5) * math.pi / 2
233 | 
234 |         F_geometry = torch.cat((geo_map, angle_map), 1)  # bs 5 w/4 h/4
235 | 
236 |         return F_score, F_geometry
237 | 
238 | 
239 | # guard the quick model dump so that importing model.py has no side effects
240 | if __name__ == "__main__":
241 |     model = East()
242 |     print(model)
--------------------------------------------------------------------------------
/preprossing.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 | import glob
3 | import csv
4 | import cv2
5 | import os
6 | import numpy as np
7 | from shapely.geometry import Polygon
8 | import config as cfg
9 | import utils
10 | 
11 | 
12 | def get_images(img_root):
13 |     files = []
14 |     for ext in ['jpg']:
15 |         files.extend(glob.glob(
16 |             os.path.join(img_root, '*.{}'.format(ext))))
17 |     # print(glob.glob(
18 |     #     os.path.join(FLAGS.training_data_path, '*.{}'.format(ext))))
19 |     return files
20 | 
21 | 
22 | def load_annoataion(p):
23 |     '''
24 |     load annotation from the text file
25 |     :param p:
26 |     :return:
27 |     '''
28 |     text_polys = []
29 |     text_tags = []
30 |     if not os.path.exists(p):
31 |         return np.array(text_polys, dtype=np.float32), np.array(text_tags, dtype=np.bool)  # return both so callers can always unpack two values
32 |     with open(p, 'r', encoding='UTF-8') as f:
33 |         reader = csv.reader(f)
34 |         for line in reader:
35 |             label = line[-1]
36 |             # strip BOM. \ufeff for python3, \xef\xbb\xbf for python2
37 |             line = [i.strip('\ufeff').strip('\xef\xbb\xbf') for i in line]
38 | 
39 |             x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, line[:8]))
40 |             text_polys.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
41 |             # print(text_polys)
42 |             if label == '*' or label == '###':
43 |                 text_tags.append(True)
44 |             else:
45 |                 text_tags.append(False)
46 |         return np.array(text_polys, dtype=np.float32), np.array(text_tags, dtype=np.bool)
47 | 
48 | 
49 | def polygon_area(poly):
50 |     '''
51 |     compute the signed area of a polygon; the sign encodes the winding order
52 |     :param poly:
53 |     :return:
54 |     '''
55 |     edge = [
56 |         (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
57 |         (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
58 |         (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
59 |         (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])
60 |     ]
61 |     return np.sum(edge) / 2.
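A quick sign-convention check for `polygon_area` (illustrative, not part of `preprossing.py`): with this edge formula, points annotated clockwise in image coordinates give a negative area, which is why `check_and_validate_polys` below flips any polygon whose area comes out positive.

```python
import numpy as np
from preprossing import polygon_area

cw = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=np.float32)
print(polygon_area(cw))        # -1.0 -> valid clockwise winding
print(polygon_area(cw[::-1]))  #  1.0 -> 'poly in wrong direction'
```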
62 | 
63 | 
64 | def check_and_validate_polys(polys, tags, xxx_todo_changeme):
65 |     '''
66 |     check so that the text poly is in the same direction,
67 |     and also filter some invalid polygons
68 |     :param polys:
69 |     :param tags:
70 |     :return:
71 |     '''
72 |     (h, w) = xxx_todo_changeme
73 |     if polys.shape[0] == 0:
74 |         return polys, tags  # return both so callers can always unpack two values
75 |     polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
76 |     polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
77 | 
78 |     validated_polys = []
79 |     validated_tags = []
80 | 
81 |     # check each quadrilateral's winding direction and whether it is a valid quad
82 |     for poly, tag in zip(polys, tags):
83 |         p_area = polygon_area(poly)
84 |         if abs(p_area) < 1:
85 |             # print poly
86 |             print('invalid poly')
87 |             continue
88 |         if p_area > 0:
89 |             print('poly in wrong direction')
90 |             poly = poly[(0, 3, 2, 1), :]
91 |         validated_polys.append(poly)
92 |         validated_tags.append(tag)
93 |     return np.array(validated_polys), np.array(validated_tags)
94 | 
95 | 
96 | def crop_area(im, polys, tags, crop_background=False, max_tries=100):
97 |     '''
98 |     make random crop from the input image
99 |     :param im:
100 |     :param polys:
101 |     :param tags:
102 |     :param crop_background:
103 |     :param max_tries:
104 |     :return:
105 |     '''
106 |     h, w, _ = im.shape
107 |     pad_h = h // 10
108 |     pad_w = w // 10
109 |     h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
110 |     w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
111 |     for poly in polys:
112 |         poly = np.round(poly, decimals=0).astype(np.int32)
113 |         minx = np.min(poly[:, 0])
114 |         maxx = np.max(poly[:, 0])
115 |         w_array[minx + pad_w:maxx + pad_w] = 1
116 |         miny = np.min(poly[:, 1])
117 |         maxy = np.max(poly[:, 1])
118 |         h_array[miny + pad_h:maxy + pad_h] = 1
119 |     # ensure the cropped area does not cut across any text region
120 |     h_axis = np.where(h_array == 0)[0]
121 |     w_axis = np.where(w_array == 0)[0]
122 |     if len(h_axis) == 0 or len(w_axis) == 0:
123 |         return im, polys, tags
124 |     for i in range(max_tries):  # try up to max_tries random crops
125 |         xx = np.random.choice(w_axis, size=2)
126 |         xmin = np.min(xx) - pad_w
127 |         xmax = np.max(xx) - pad_w
128 |         xmin = np.clip(xmin, 0, w - 1)
129 |         xmax = np.clip(xmax, 0, w - 1)
130 |         yy = np.random.choice(h_axis, size=2)
131 |         ymin = np.min(yy) - pad_h
132 |         ymax = np.max(yy) - pad_h
133 |         ymin = np.clip(ymin, 0, h - 1)
134 |         ymax = np.clip(ymax, 0, h - 1)
135 |         if xmax - xmin < cfg.min_crop_side_ratio * w or ymax - ymin < cfg.min_crop_side_ratio * h:
136 |             # area too small
137 |             continue
138 |         if polys.shape[0] != 0:
139 |             poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
140 |                                 & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
141 |             selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
142 |         else:
143 |             selected_polys = []
144 |         if len(selected_polys) == 0:
145 |             # no text in this area
146 |             if crop_background:
147 |                 return im[ymin:ymax + 1, xmin:xmax + 1, :], polys[selected_polys], tags[selected_polys]
148 |             else:
149 |                 continue
150 |         im = im[ymin:ymax + 1, xmin:xmax + 1, :]
151 |         polys = polys[selected_polys]
152 |         tags = tags[selected_polys]
153 |         polys[:, :, 0] -= xmin
154 |         polys[:, :, 1] -= ymin
155 |         return im, polys, tags
156 | 
157 |     return im, polys, tags
158 | 
159 | 
160 | def shrink_poly(poly, r):
161 |     '''
162 |     fit a poly inside the origin poly, maybe bugs here...
163 |     used to generate the score map
164 |     :param poly: the text poly
165 |     :param r: r in the paper
166 |     :return: the shrunk poly
167 |     '''
168 |     # shrink ratio
169 |     R = 0.3
170 |     # find the longer pair
171 |     if np.linalg.norm(poly[0] - poly[1]) + np.linalg.norm(poly[2] - poly[3]) > \
172 |             np.linalg.norm(poly[0] - poly[3]) + np.linalg.norm(poly[1] - poly[2]):
173 |         # first move (p0, p1), (p2, p3), then (p0, p3), (p1, p2)
174 |         ## p0, p1
175 |         theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
176 |         poly[0][0] += R * r[0] * np.cos(theta)
177 |         poly[0][1] += R * r[0] * np.sin(theta)
178 |         poly[1][0] -= R * r[1] * np.cos(theta)
179 |         poly[1][1] -= R * r[1] * np.sin(theta)
180 |         ## p2, p3
181 |         theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
182 |         poly[3][0] += R * r[3] * np.cos(theta)
183 |         poly[3][1] += R * r[3] * np.sin(theta)
184 |         poly[2][0] -= R * r[2] * np.cos(theta)
185 |         poly[2][1] -= R * r[2] * np.sin(theta)
186 |         ## p0, p3
187 |         theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
188 |         poly[0][0] += R * r[0] * np.sin(theta)
189 |         poly[0][1] += R * r[0] * np.cos(theta)
190 |         poly[3][0] -= R * r[3] * np.sin(theta)
191 |         poly[3][1] -= R * r[3] * np.cos(theta)
192 |         ## p1, p2
193 |         theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
194 |         poly[1][0] += R * r[1] * np.sin(theta)
195 |         poly[1][1] += R * r[1] * np.cos(theta)
196 |         poly[2][0] -= R * r[2] * np.sin(theta)
197 |         poly[2][1] -= R * r[2] * np.cos(theta)
198 |     else:
199 |         ## p0, p3
200 |         # print poly
201 |         theta = np.arctan2((poly[3][0] - poly[0][0]), (poly[3][1] - poly[0][1]))
202 |         poly[0][0] += R * r[0] * np.sin(theta)
203 |         poly[0][1] += R * r[0] * np.cos(theta)
204 |         poly[3][0] -= R * r[3] * np.sin(theta)
205 |         poly[3][1] -= R * r[3] * np.cos(theta)
206 |         ## p1, p2
207 |         theta = np.arctan2((poly[2][0] - poly[1][0]), (poly[2][1] - poly[1][1]))
208 |         poly[1][0] += R * r[1] * np.sin(theta)
209 |         poly[1][1] += R * r[1] * np.cos(theta)
210 |         poly[2][0] -= R * r[2] * np.sin(theta)
211 |         poly[2][1] -= R * r[2] * np.cos(theta)
212 |         ## p0, p1
213 |         theta = np.arctan2((poly[1][1] - poly[0][1]), (poly[1][0] - poly[0][0]))
214 |         poly[0][0] += R * r[0] * np.cos(theta)
215 |         poly[0][1] += R * r[0] * np.sin(theta)
216 |         poly[1][0] -= R * r[1] * np.cos(theta)
217 |         poly[1][1] -= R * r[1] * np.sin(theta)
218 |         ## p2, p3
219 |         theta = np.arctan2((poly[2][1] - poly[3][1]), (poly[2][0] - poly[3][0]))
220 |         poly[3][0] += R * r[3] * np.cos(theta)
221 |         poly[3][1] += R * r[3] * np.sin(theta)
222 |         poly[2][0] -= R * r[2] * np.cos(theta)
223 |         poly[2][1] -= R * r[2] * np.sin(theta)
224 |     return poly
225 | 
226 | 
227 | # def point_dist_to_line(p1, p2, p3):
228 | #     # compute the distance from p3 to p1-p2
229 | #     return np.linalg.norm(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1)
230 | 
231 | 
232 | # distance from point p3 to the line through p1 and p2
233 | def point_dist_to_line(p1, p2, p3):
234 |     # compute the distance from p3 to p1-p2
235 |     # return np.linalg.norm(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1)
236 |     a = np.linalg.norm(p1 - p2)
237 |     b = np.linalg.norm(p2 - p3)
238 |     c = np.linalg.norm(p3 - p1)
239 |     s = (a + b + c) / 2.0
240 |     area = np.abs((s * (s - a) * (s - b) * (s - c))) ** 0.5
241 |     if a < 1.0:
242 |         return (b + c) / 2.0
243 |     return 2 * area / a
244 | 
245 | 
246 | def fit_line(p1, p2):
247 |     # fit a line ax+by+c = 0; p1 holds the two x coordinates, p2 the two y coordinates
248 |     if p1[0] == p1[1]:
249 |         return [1., 0., -p1[0]]
250 |     else:
251 |         [k, b] = np.polyfit(p1, p2, deg=1)
252 |         return [k, -1., b]
253 | 
254 
255 | def line_cross_point(line1, line2): 256 | # line1 0= ax+by+c, compute the cross point of line1 and line2 257 | if line1[0] != 0 and line1[0] == line2[0]: 258 | print('Cross point does not exist') 259 | return None 260 | if line1[0] == 0 and line2[0] == 0: 261 | print('Cross point does not exist') 262 | return None 263 | if line1[1] == 0: 264 | x = -line1[2] 265 | y = line2[0] * x + line2[2] 266 | elif line2[1] == 0: 267 | x = -line2[2] 268 | y = line1[0] * x + line1[2] 269 | else: 270 | k1, _, b1 = line1 271 | k2, _, b2 = line2 272 | x = -(b1 - b2) / (k1 - k2) 273 | y = k1 * x + b1 274 | return np.array([x, y], dtype=np.float32) 275 | 276 | 277 | def line_verticle(line, point): 278 | # get the verticle line from line across point 279 | if line[1] == 0: 280 | verticle = [0, -1, point[1]] 281 | else: 282 | if line[0] == 0: 283 | verticle = [1, 0, -point[0]] 284 | else: 285 | verticle = [-1. / line[0], -1, point[1] - (-1 / line[0] * point[0])] 286 | return verticle 287 | 288 | 289 | def rectangle_from_parallelogram(poly): 290 | ''' 291 | fit a rectangle from a parallelogram 292 | :param poly: 293 | :return: 294 | ''' 295 | p0, p1, p2, p3 = poly 296 | angle_p0 = np.arccos(np.dot(p1 - p0, p3 - p0) / (np.linalg.norm(p0 - p1) * np.linalg.norm(p3 - p0))) 297 | if angle_p0 < 0.5 * np.pi: 298 | if np.linalg.norm(p0 - p1) > np.linalg.norm(p0 - p3): 299 | # p0 and p2 300 | ## p0 301 | p2p3 = fit_line([p2[0], p3[0]], [p2[1], p3[1]]) 302 | p2p3_verticle = line_verticle(p2p3, p0) 303 | 304 | new_p3 = line_cross_point(p2p3, p2p3_verticle) 305 | ## p2 306 | p0p1 = fit_line([p0[0], p1[0]], [p0[1], p1[1]]) 307 | p0p1_verticle = line_verticle(p0p1, p2) 308 | 309 | new_p1 = line_cross_point(p0p1, p0p1_verticle) 310 | return np.array([p0, new_p1, p2, new_p3], dtype=np.float32) 311 | else: 312 | p1p2 = fit_line([p1[0], p2[0]], [p1[1], p2[1]]) 313 | p1p2_verticle = line_verticle(p1p2, p0) 314 | 315 | new_p1 = line_cross_point(p1p2, p1p2_verticle) 316 | p0p3 = fit_line([p0[0], p3[0]], [p0[1], p3[1]]) 317 | p0p3_verticle = line_verticle(p0p3, p2) 318 | 319 | new_p3 = line_cross_point(p0p3, p0p3_verticle) 320 | return np.array([p0, new_p1, p2, new_p3], dtype=np.float32) 321 | else: 322 | if np.linalg.norm(p0 - p1) > np.linalg.norm(p0 - p3): 323 | # p1 and p3 324 | ## p1 325 | p2p3 = fit_line([p2[0], p3[0]], [p2[1], p3[1]]) 326 | p2p3_verticle = line_verticle(p2p3, p1) 327 | 328 | new_p2 = line_cross_point(p2p3, p2p3_verticle) 329 | ## p3 330 | p0p1 = fit_line([p0[0], p1[0]], [p0[1], p1[1]]) 331 | p0p1_verticle = line_verticle(p0p1, p3) 332 | 333 | new_p0 = line_cross_point(p0p1, p0p1_verticle) 334 | return np.array([new_p0, p1, new_p2, p3], dtype=np.float32) 335 | else: 336 | p0p3 = fit_line([p0[0], p3[0]], [p0[1], p3[1]]) 337 | p0p3_verticle = line_verticle(p0p3, p1) 338 | 339 | new_p0 = line_cross_point(p0p3, p0p3_verticle) 340 | p1p2 = fit_line([p1[0], p2[0]], [p1[1], p2[1]]) 341 | p1p2_verticle = line_verticle(p1p2, p3) 342 | 343 | new_p2 = line_cross_point(p1p2, p1p2_verticle) 344 | return np.array([new_p0, p1, new_p2, p3], dtype=np.float32) 345 | 346 | 347 | def sort_rectangle(poly): 348 | # sort the four coordinates of the polygon, points in poly should be sorted clockwise 349 | # First find the lowest point 350 | p_lowest = np.argmax(poly[:, 1]) 351 | if np.count_nonzero(poly[:, 1] == poly[p_lowest, 1]) == 2: 352 | # 底边平行于X轴, 那么p0为左上角 - if the bottom line is parallel to x-axis, then p0 must be the upper-left corner 353 | p0_index = np.argmin(np.sum(poly, axis=1)) 354 | p1_index = (p0_index + 1) % 4 355 
| p2_index = (p0_index + 2) % 4 356 | p3_index = (p0_index + 3) % 4 357 | return poly[[p0_index, p1_index, p2_index, p3_index]], 0. 358 | else: 359 | # 找到最低点右边的点 - find the point that sits right to the lowest point 360 | p_lowest_right = (p_lowest - 1) % 4 361 | p_lowest_left = (p_lowest + 1) % 4 362 | angle = np.arctan( 363 | -(poly[p_lowest][1] - poly[p_lowest_right][1]) / (poly[p_lowest][0] - poly[p_lowest_right][0])) 364 | # assert angle > 0 365 | if angle <= 0: 366 | print(angle, poly[p_lowest], poly[p_lowest_right]) 367 | if angle / np.pi * 180 > 45: 368 | # 这个点为p2 - this point is p2 369 | p2_index = p_lowest 370 | p1_index = (p2_index - 1) % 4 371 | p0_index = (p2_index - 2) % 4 372 | p3_index = (p2_index + 1) % 4 373 | return poly[[p0_index, p1_index, p2_index, p3_index]], -(np.pi / 2 - angle) 374 | else: 375 | # 这个点为p3 - this point is p3 376 | p3_index = p_lowest 377 | p0_index = (p3_index + 1) % 4 378 | p1_index = (p3_index + 2) % 4 379 | p2_index = (p3_index + 3) % 4 380 | return poly[[p0_index, p1_index, p2_index, p3_index]], angle 381 | 382 | 383 | def restore_rectangle_rbox(origin, geometry): 384 | d = geometry[:, :4] 385 | angle = geometry[:, 4] 386 | # for angle > 0 387 | origin_0 = origin[angle >= 0] 388 | d_0 = d[angle >= 0] 389 | angle_0 = angle[angle >= 0] 390 | if origin_0.shape[0] > 0: 391 | p = np.array([np.zeros(d_0.shape[0]), -d_0[:, 0] - d_0[:, 2], 392 | d_0[:, 1] + d_0[:, 3], -d_0[:, 0] - d_0[:, 2], 393 | d_0[:, 1] + d_0[:, 3], np.zeros(d_0.shape[0]), 394 | np.zeros(d_0.shape[0]), np.zeros(d_0.shape[0]), 395 | d_0[:, 3], -d_0[:, 2]]) 396 | p = p.transpose((1, 0)).reshape((-1, 5, 2)) # N*5*2 397 | 398 | rotate_matrix_x = np.array([np.cos(angle_0), np.sin(angle_0)]).transpose((1, 0)) 399 | rotate_matrix_x = np.repeat(rotate_matrix_x, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1)) # N*5*2 400 | 401 | rotate_matrix_y = np.array([-np.sin(angle_0), np.cos(angle_0)]).transpose((1, 0)) 402 | rotate_matrix_y = np.repeat(rotate_matrix_y, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1)) 403 | 404 | p_rotate_x = np.sum(rotate_matrix_x * p, axis=2)[:, :, np.newaxis] # N*5*1 405 | p_rotate_y = np.sum(rotate_matrix_y * p, axis=2)[:, :, np.newaxis] # N*5*1 406 | 407 | p_rotate = np.concatenate([p_rotate_x, p_rotate_y], axis=2) # N*5*2 408 | 409 | p3_in_origin = origin_0 - p_rotate[:, 4, :] 410 | new_p0 = p_rotate[:, 0, :] + p3_in_origin # N*2 411 | new_p1 = p_rotate[:, 1, :] + p3_in_origin 412 | new_p2 = p_rotate[:, 2, :] + p3_in_origin 413 | new_p3 = p_rotate[:, 3, :] + p3_in_origin 414 | 415 | new_p_0 = np.concatenate([new_p0[:, np.newaxis, :], new_p1[:, np.newaxis, :], 416 | new_p2[:, np.newaxis, :], new_p3[:, np.newaxis, :]], axis=1) # N*4*2 417 | else: 418 | new_p_0 = np.zeros((0, 4, 2)) 419 | # for angle < 0 420 | origin_1 = origin[angle < 0] 421 | d_1 = d[angle < 0] 422 | angle_1 = angle[angle < 0] 423 | if origin_1.shape[0] > 0: 424 | p = np.array([-d_1[:, 1] - d_1[:, 3], -d_1[:, 0] - d_1[:, 2], 425 | np.zeros(d_1.shape[0]), -d_1[:, 0] - d_1[:, 2], 426 | np.zeros(d_1.shape[0]), np.zeros(d_1.shape[0]), 427 | -d_1[:, 1] - d_1[:, 3], np.zeros(d_1.shape[0]), 428 | -d_1[:, 1], -d_1[:, 2]]) 429 | p = p.transpose((1, 0)).reshape((-1, 5, 2)) # N*5*2 430 | 431 | rotate_matrix_x = np.array([np.cos(-angle_1), -np.sin(-angle_1)]).transpose((1, 0)) 432 | rotate_matrix_x = np.repeat(rotate_matrix_x, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1)) # N*5*2 433 | 434 | rotate_matrix_y = np.array([np.sin(-angle_1), np.cos(-angle_1)]).transpose((1, 0)) 435 | rotate_matrix_y = 
np.repeat(rotate_matrix_y, 5, axis=1).reshape(-1, 2, 5).transpose((0, 2, 1))
436 | 
437 |         p_rotate_x = np.sum(rotate_matrix_x * p, axis=2)[:, :, np.newaxis]  # N*5*1
438 |         p_rotate_y = np.sum(rotate_matrix_y * p, axis=2)[:, :, np.newaxis]  # N*5*1
439 | 
440 |         p_rotate = np.concatenate([p_rotate_x, p_rotate_y], axis=2)  # N*5*2
441 | 
442 |         p3_in_origin = origin_1 - p_rotate[:, 4, :]
443 |         new_p0 = p_rotate[:, 0, :] + p3_in_origin  # N*2
444 |         new_p1 = p_rotate[:, 1, :] + p3_in_origin
445 |         new_p2 = p_rotate[:, 2, :] + p3_in_origin
446 |         new_p3 = p_rotate[:, 3, :] + p3_in_origin
447 | 
448 |         new_p_1 = np.concatenate([new_p0[:, np.newaxis, :], new_p1[:, np.newaxis, :],
449 |                                   new_p2[:, np.newaxis, :], new_p3[:, np.newaxis, :]], axis=1)  # N*4*2
450 |     else:
451 |         new_p_1 = np.zeros((0, 4, 2))
452 |     return np.concatenate([new_p_0, new_p_1])
453 | 
454 | 
455 | def restore_rectangle(origin, geometry):
456 |     return restore_rectangle_rbox(origin, geometry)
457 | 
458 | 
459 | def generate_rbox(im_size, polys, tags):
460 |     h, w = im_size
461 |     poly_mask = np.zeros((h, w), dtype=np.uint8)
462 |     score_map = np.zeros((h, w), dtype=np.uint8)
463 |     geo_map = np.zeros((h, w, 5), dtype=np.float32)
464 |     # mask used during training, to ignore hard areas and overly small text
465 |     training_mask = np.ones((h, w), dtype=np.uint8)
466 |     for poly_idx, poly_tag in enumerate(zip(polys, tags)):
467 |         poly = poly_tag[0]
468 |         tag = poly_tag[1]
469 | 
470 |         # for each vertex, take the shorter of its two incident edges
471 |         r = [None, None, None, None]
472 |         for i in range(4):
473 |             r[i] = min(np.linalg.norm(poly[i] - poly[(i + 1) % 4]),
474 |                        np.linalg.norm(poly[i] - poly[(i - 1) % 4]))
475 |         # score map
476 |         # shrink the box to 0.3 of its size and fill the corresponding area of the score map
477 |         shrinked_poly = shrink_poly(poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
478 |         cv2.fillPoly(score_map, shrinked_poly, 1)
479 |         cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
480 |         # if the poly is too small, then ignore it during training
481 |         # boxes that are too small, or whose label is '*' or '###' (content not annotated), are masked out and ignored during training
482 |         poly_h = min(np.linalg.norm(poly[0] - poly[3]), np.linalg.norm(poly[1] - poly[2]))
483 |         poly_w = min(np.linalg.norm(poly[0] - poly[1]), np.linalg.norm(poly[2] - poly[3]))
484 |         if min(poly_h, poly_w) < cfg.min_text_size:
485 |             cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
486 |         if tag:
487 |             cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0)
488 | 
489 |         # pixels belonging to the newly added text region
490 |         xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
491 |         # if geometry == 'RBOX':
492 |         # generate a parallelogram for any combination of two vertices
493 |         fitted_parallelograms = []
494 |         for i in range(4):
495 |             # for the edge p0-p1, generate two candidate parallelograms
496 |             p0 = poly[i]
497 |             p1 = poly[(i + 1) % 4]
498 |             p2 = poly[(i + 2) % 4]
499 |             p3 = poly[(i + 3) % 4]
500 |             # fit ax+by+c=0
501 |             edge = fit_line([p0[0], p1[0]], [p0[1], p1[1]])
502 |             backward_edge = fit_line([p0[0], p3[0]], [p0[1], p3[1]])
503 |             forward_edge = fit_line([p1[0], p2[0]], [p1[1], p2[1]])
504 |             # use the distances of p2 and p3 to edge to decide whether the parallel edge passes through p2 or p3
505 |             if point_dist_to_line(p0, p1, p2) > point_dist_to_line(p0, p1, p3):
506 |                 # parallel line through p2
507 |                 if edge[1] == 0:
508 |                     edge_opposite = [1, 0, -p2[0]]
509 |                 else:
510 |                     edge_opposite = [edge[0], -1, p2[1] - edge[0] * p2[0]]
511 |             else:
512 |                 # parallel line through p3
513 |                 if edge[1] == 0:
514 |                     edge_opposite = [1, 0, -p3[0]]
515 |                 else:
516 |                     edge_opposite = [edge[0], -1, p3[1] - edge[0] * p3[0]]
517 |             # move forward edge
518 |             new_p0 = p0
519 |             new_p1 = p1
520 |             new_p2 = p2
521 |             new_p3 = p3
522 |             new_p2 = line_cross_point(forward_edge, edge_opposite)
523 |             if point_dist_to_line(p1, new_p2, p0) > point_dist_to_line(p1, new_p2, p3):
524 |                 # across p0
525 |                 if forward_edge[1] == 0:
526 |                     forward_opposite = [1, 0, -p0[0]]
527 |                 else:
528 |                     forward_opposite = [forward_edge[0], -1, p0[1] - forward_edge[0] * p0[0]]
529 |             else:
530 |                 # across p3
531 |                 if forward_edge[1] == 0:
532 |                     forward_opposite = [1, 0, -p3[0]]
533 |                 else:
534 |                     forward_opposite = [forward_edge[0], -1, p3[1] - forward_edge[0] * p3[0]]
535 |             new_p0 = line_cross_point(forward_opposite, edge)
536 |             new_p3 = line_cross_point(forward_opposite, edge_opposite)
537 |             fitted_parallelograms.append([new_p0, new_p1, new_p2, new_p3, new_p0])
538 |             # or move backward edge
539 |             new_p0 = p0
540 |             new_p1 = p1
541 |             new_p2 = p2
542 |             new_p3 = p3
543 |             new_p3 = line_cross_point(backward_edge, edge_opposite)
544 |             if point_dist_to_line(p0, p3, p1) > point_dist_to_line(p0, p3, p2):
545 |                 # across p1
546 |                 if backward_edge[1] == 0:
547 |                     backward_opposite = [1, 0, -p1[0]]
548 |                 else:
549 |                     backward_opposite = [backward_edge[0], -1, p1[1] - backward_edge[0] * p1[0]]
550 |             else:
551 |                 # across p2
552 |                 if backward_edge[1] == 0:
553 |                     backward_opposite = [1, 0, -p2[0]]
554 |                 else:
555 |                     backward_opposite = [backward_edge[0], -1, p2[1] - backward_edge[0] * p2[0]]
556 |             new_p1 = line_cross_point(backward_opposite, edge)
557 |             new_p2 = line_cross_point(backward_opposite, edge_opposite)
558 |             fitted_parallelograms.append([new_p0, new_p1, new_p2, new_p3, new_p0])
559 | 
560 |         # pick the parallelogram with the smallest area
561 |         areas = [Polygon(t).area for t in fitted_parallelograms]
562 |         parallelogram = np.array(fitted_parallelograms[np.argmin(areas)][:-1], dtype=np.float32)
563 |         # sort this polygon
564 |         parallelogram_coord_sum = np.sum(parallelogram, axis=1)
565 |         min_coord_idx = np.argmin(parallelogram_coord_sum)
566 |         parallelogram = parallelogram[
567 |             [min_coord_idx, (min_coord_idx + 1) % 4, (min_coord_idx + 2) % 4, (min_coord_idx + 3) % 4]]
568 | 
569 |         # get the bounding rectangle and its rotation angle
570 |         rectangle = rectangle_from_parallelogram(parallelogram)
571 |         rectangle, rotate_angle = sort_rectangle(rectangle)
572 | 
573 |         p0_rect, p1_rect, p2_rect, p3_rect = rectangle
574 |         # for each pixel in the new text region, write its distances to the four rectangle edges into geo_map
575 |         for y, x in xy_in_poly:
576 |             point = np.array([x, y], dtype=np.float32)
577 |             # top
578 |             geo_map[y, x, 0] = point_dist_to_line(p0_rect, p1_rect, point)
579 |             # right
580 |             geo_map[y, x, 1] = point_dist_to_line(p1_rect, p2_rect, point)
581 |             # bottom
582 |             geo_map[y, x, 2] = point_dist_to_line(p2_rect, p3_rect, point)
583 |             # left
584 |             geo_map[y, x, 3] = point_dist_to_line(p3_rect, p0_rect, point)
585 |             # angle
586 |             geo_map[y, x, 4] = rotate_angle
587 |     return score_map, geo_map, training_mask
588 | 
589 | 
590 | def generator(index,
591 |               input_size=512,
592 |               background_ratio=3. / 8,  # ratio of pure-background samples
593 |               random_scale=np.array([0.5, 1, 2.0, 3.0]),  # multi-scale image augmentation
594 |               image_list=None):
595 |     try:
596 |         im_fn = image_list[index]
597 |         im = cv2.imread(im_fn)
598 |         if im is None:
599 |             print("can't find image")
600 |             return None, None, None, None, None
601 |         h, w, _ = im.shape
602 |         # build the annotation filename by swapping the image extension for txt
603 |         txt_fn = im_fn.replace(os.path.basename(im_fn).split('.')[1], 'txt')
604 |         if not os.path.exists(txt_fn):
605 |             print('text file {} does not exist'.format(txt_fn))
606 |             return None, None, None, None, None
607 |         # load the annotated boxes
608 |         text_polys, text_tags = load_annoataion(txt_fn)
609 | 
610 |         text_polys, text_tags = check_and_validate_polys(text_polys, text_tags, (h, w))
611 | 
612 |         # random scale this image: pick one scale at random
613 |         rd_scale = np.random.choice(random_scale)
614 |         im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
615 |         text_polys *= rd_scale
616 | 
617 |         # with probability background_ratio (3/8), crop a pure-background patch
618 |         if np.random.rand() < background_ratio:
619 |             # crop background
620 |             im, text_polys, text_tags = crop_area(im, text_polys, text_tags, crop_background=True)
621 |             if text_polys.shape[0] > 0:
622 |                 # print("cannot find background")
623 |                 return None, None, None, None, None
624 |             # pad and resize image
625 |             new_h, new_w, _ = im.shape
626 |             max_h_w_i = np.max([new_h, new_w, input_size])
627 |             im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8)
628 |             im_padded[:new_h, :new_w, :] = im.copy()
629 |             # expand the cropped patch into an input_size x input_size image
630 |             im = cv2.resize(im_padded, dsize=(input_size, input_size))
631 |             score_map = np.zeros((input_size, input_size), dtype=np.uint8)
632 |             geo_map_channels = 5 if cfg.geometry == 'RBOX' else 8
633 |             geo_map = np.zeros((input_size, input_size, geo_map_channels), dtype=np.float32)
634 |             training_mask = np.ones((input_size, input_size), dtype=np.uint8)
635 |         else:
636 |             # otherwise (probability 5/8), crop a patch that contains text
637 |             im, text_polys, text_tags = crop_area(im, text_polys, text_tags, crop_background=False)
638 |             if text_polys.shape[0] == 0:
639 |                 # print("cannot find txt ground")
640 |                 return None, None, None, None, None
641 |             h, w, _ = im.shape
642 |             # pad the image to the training input size or the longer side of image
643 |             new_h, new_w, _ = im.shape
644 |             max_h_w_i = np.max([new_h, new_w, input_size])
645 |             im_padded = np.zeros((max_h_w_i, max_h_w_i, 3), dtype=np.uint8)
646 |             im_padded[:new_h, :new_w, :] = im.copy()
647 |             im = im_padded
648 |             # resize the image to input size
649 |             # pad, then resize the image to the configured size
650 |             new_h, new_w, _ = im.shape
651 |             resize_h = input_size
652 |             resize_w = input_size
653 |             im = cv2.resize(im, dsize=(resize_w, resize_h))
654 |             # rescale the box coordinates by the same ratios
655 |             resize_ratio_3_x = resize_w / float(new_w)
656 |             resize_ratio_3_y = resize_h / float(new_h)
657 |             text_polys[:, :, 0] *= resize_ratio_3_x
658 |             text_polys[:, :, 1] *= resize_ratio_3_y
659 |             new_h, new_w, _ = im.shape
660 |             score_map, geo_map, training_mask = generate_rbox((new_h, new_w), text_polys, text_tags)
661 | 
662 |         # assemble the sample's image and label tensors
663 |         images = im[:, :, ::-1].astype(np.float32)
664 |         # record the filename
665 |         image_fns = im_fn
666 |         # take every 4th row/column: the output maps are 1/4 of the 512*512 input resolution
667 |         score_maps = score_map[::4, ::4, np.newaxis].astype(np.float32)
668 |         geo_maps = geo_map[::4, ::4, :].astype(np.float32)
669 |         training_masks = training_mask[::4, ::4, np.newaxis].astype(np.float32)
670 |         # emit one complete sample
671 |         return images, image_fns, score_maps, geo_maps, training_masks
672 | 
673 |     except Exception as e:
674 |         import traceback
675 |         traceback.print_exc()
676 | 
677 |         # print("Exception is exist!")
678 |         return None, None, None, None, None
679 | 
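Because `generator()` returns a tuple of `None`s whenever a sample is unusable, the `Dataset` that feeds `DataLoader` has to skip or retry such samples. A minimal sketch of how `generator()` is typically driven (the project's real wrapper lives in `data_loader.py`, not shown here; `EastDataset` and the retry policy are illustrative assumptions):

```python
import numpy as np
import torch.utils.data as data

import config as cfg
import preprossing

class EastDataset(data.Dataset):  # hypothetical wrapper; see data_loader.py
    def __init__(self):
        self.image_list = preprossing.get_images(cfg.train_data_path)

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        sample = preprossing.generator(index, image_list=self.image_list)
        # retry with random indices until generator() yields a usable sample
        while sample[0] is None:
            index = np.random.randint(len(self.image_list))
            sample = preprossing.generator(index, image_list=self.image_list)
        return sample
```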
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==0.4.1
2 | torchvision==0.2.1
3 | shapely==1.6.4.post1
4 | opencv-python==3.4.2.17
5 | tensorboardX
--------------------------------------------------------------------------------
/tensorboards/README.txt:
--------------------------------------------------------------------------------
1 | TensorBoard logs are saved in this folder.
--------------------------------------------------------------------------------
/tmp/README.txt:
--------------------------------------------------------------------------------
1 | EAST model checkpoints are saved in this folder.
--------------------------------------------------------------------------------
/tmp/backbone_net/README.txt:
--------------------------------------------------------------------------------
1 | Put the backbone net (e.g. ResNet or MobileNet) in this folder.
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | from torch.utils.data import DataLoader
3 | from data_loader import custom_dset
4 | import config as cfg
5 | from torchvision import transforms
6 | from model import East
7 | import torch
8 | from torch.optim import lr_scheduler
9 | import loss
10 | import os
11 | import utils
12 | import tensorboardX
13 | 
14 | 
15 | def init_tensorboard_writer(store_dir):
16 |     assert os.path.exists(os.path.dirname(store_dir))
17 |     return tensorboardX.SummaryWriter(store_dir)
18 | 
19 | 
20 | def fit(train_loader, model, criterion, optimizer, epoch, weight_loss, writer):
21 | 
22 |     model.train()
23 |     start = time.time()
24 | 
25 |     for i, (img, img_path, score_map, geo_map, training_mask) in enumerate(train_loader):
26 | 
27 |         img, score_map, geo_map, training_mask = img.cuda(), score_map.cuda(), geo_map.cuda(), training_mask.cuda()
28 |         f_score, f_geometry = model(img)
29 |         model_loss = criterion(score_map, f_score, geo_map, f_geometry, training_mask)
30 |         total_loss = model_loss + weight_loss(model)
31 | 
32 |         # backward
33 |         optimizer.zero_grad()
34 |         total_loss.backward()
35 |         optimizer.step()
36 | 
37 |         # measure elapsed time
38 |         end = time.time()
39 |         batch_sum_time = end - start
40 |         per_img_time = 1.0 * batch_sum_time / img.size(0)
41 |         start = end
42 | 
43 |         steps = epoch * len(train_loader) + i
44 |         writer.add_scalar('model_loss', model_loss.item(), steps)
45 |         writer.add_scalar('total_loss', total_loss.item(), steps)
46 |         writer.add_scalar('per_img_time', per_img_time, steps)
47 | 
48 |         if i % cfg.print_freq == 0:
49 |             print('EAST <==> TRAIN <==> Epoch: [%d][%d/%d], Model Loss %.5f, Total Loss %.5f, Per Img Time %.2f second'
50 |                   % (epoch, i, len(train_loader), model_loss.item(), total_loss.item(), per_img_time))
51 | 
52 | 
53 | def main():
54 | 
55 |     # Prepare for dataset
56 |     print('EAST <==> Prepare <==> DataLoader <==> Begin')
57 |     trainset = custom_dset(transform=transforms.ToTensor())
58 |     train_loader = DataLoader(trainset, batch_size=cfg.train_batch_size_per_gpu * cfg.gpu,
59 |                               shuffle=True, num_workers=cfg.num_workers)
60 |     print('EAST <==> Prepare <==> Batch_size:{} <==> Begin'.format(cfg.train_batch_size_per_gpu * cfg.gpu))
61 |     print('EAST <==> Prepare <==> DataLoader <==> Done')
62 | 
63 |     # test dataloader
64 |     # import numpy as np
65 |     # import matplotlib.pyplot as plt
66 |     # for batch_idx, (img, img_path, score_map, geo_map, training_mask) in enumerate(train_loader):
67 |     #     print("batch index:", batch_idx, ",img batch shape", np.shape(geo_map.numpy()))
68 |     #     h1 = img.numpy()[0].transpose(1, 2, 0).astype(np.int64)
69 |     #     h2 = score_map.numpy()[0].transpose(1, 2, 0).astype(np.float32)[:, :, 0]
70 |     #     plt.figure()
71 |     #     plt.subplot(1, 2, 1)
72 |     #     plt.imshow(h1)
73 |     #     plt.subplot(1, 2, 2)
74 |     #     plt.imshow(h2, cmap='gray')
75 |     #     plt.show()
76 | 
77 |     # Model
78 |     print('EAST <==> Prepare <==> Network <==> Begin')
79 |     model = East()
80 |     model = torch.nn.DataParallel(model, device_ids=cfg.gpu_ids)
81 |     criterion = loss.LossFunc().cuda()
82 |     weight_loss = utils.Regularization(model, cfg.l2_weight_decay, p=2).cuda()
83 | 
84 |     pre_params = list(map(id, model.module.mobilenet.parameters()))
85 |     post_params = filter(lambda p: id(p) not in pre_params, model.module.parameters())
86 |     optimizer = torch.optim.Adam([{'params': model.module.mobilenet.parameters(), 'lr': cfg.pre_lr},
87 |                                   {'params': post_params, 'lr': cfg.lr}])
88 |     # decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
89 |     scheduler = lr_scheduler.StepLR(optimizer, step_size=cfg.decay_steps, gamma=cfg.decay_rate)
90 |     model.cuda()
91 | 
92 |     # init, or resume from a checkpoint
93 |     if cfg.resume and os.path.isfile(cfg.checkpoint):
94 |         start_epoch = utils.Loading_checkpoint(model, optimizer, scheduler)
95 |     else:
96 |         start_epoch = 0
97 | 
98 |     print('EAST <==> Prepare <==> Network <==> Done')
99 | 
100 |     tensorboard_writer = init_tensorboard_writer('tensorboards/{}'.format(str(int(time.time()))))
101 | 
102 |     # train Model
103 |     for epoch in range(start_epoch, cfg.max_epochs):
104 | 
105 |         scheduler.step()
106 |         fit(train_loader, model, criterion, optimizer, epoch, weight_loss, tensorboard_writer)
107 | 
108 |         # save the model
109 |         if epoch % cfg.save_eval_iteration == 0:
110 |             utils.save_checkpoint(epoch, model, optimizer, scheduler)
111 | 
112 | 
113 | if __name__ == "__main__":
114 |     main()
115 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import config as cfg
4 | from torch.nn import init
5 | import cv2
6 | import numpy as np
7 | import time
8 | import preprossing
9 | import locality_aware_nms
10 | 
11 | 
12 | def init_weights(m_list, init_type=cfg.init_type, gain=0.02):
13 |     print("EAST <==> Prepare <==> Init Network '{}' <==> Begin".format(cfg.init_type))
14 |     # this will apply to each layer
15 |     for m in m_list:
16 |         classname = m.__class__.__name__
17 |         if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
18 |             if init_type == 'normal':
19 |                 init.normal_(m.weight.data, 0.0, gain)
20 |             elif init_type == 'xavier':
21 |                 init.xavier_normal_(m.weight.data, gain=gain)
22 |             elif init_type == 'kaiming':
23 |                 init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')  # good for relu
24 |             elif init_type == 'orthogonal':
25 |                 init.orthogonal_(m.weight.data, gain=gain)
26 |             else:
27 |                 raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
28 | 
29 |             if hasattr(m, 'bias') and m.bias is not None:
30 |                 init.constant_(m.bias.data, 0.0)
31 |         elif classname.find('BatchNorm2d') != -1:
32 |             init.normal_(m.weight.data, 1.0, gain)
33 |             init.constant_(m.bias.data, 0.0)
34 | 
35 |     print("EAST <==> Prepare <==> Init Network '{}' <==> Done".format(cfg.init_type))
36 | 
37 | 
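Both checkpoint helpers below share one on-disk schema. A tiny, self-contained illustration of that schema (toy objects only; the real files land under `cfg.save_model_path` as `epoch_<n>_checkpoint.pth.tar`):

```python
import torch
import torch.nn as nn
from torch.optim import lr_scheduler

# toy model/optimizer/scheduler, only to illustrate the dict layout used by
# save_checkpoint() and Loading_checkpoint() below
model = nn.Linear(2, 2)
optimizer = torch.optim.Adam(model.parameters())
scheduler = lr_scheduler.StepLR(optimizer, step_size=10)

state = {
    'epoch': 0,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
}
torch.save(state, 'epoch_0_checkpoint.pth.tar')
print(torch.load('epoch_0_checkpoint.pth.tar')['epoch'])  # -> 0
```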
38 | def Loading_checkpoint(model, optimizer, scheduler, filename='checkpoint.pth.tar'):
39 |     """Load a training checkpoint.
40 |     restores model, optimizer and scheduler state in place from cfg.checkpoint
41 |     Arguments:
42 |         model, optimizer, scheduler -- objects whose state is restored
43 |     Keyword Arguments:
44 |         filename {str} -- [unused] (default: {'checkpoint.pth.tar'})
45 |     """
46 |     weightpath = os.path.abspath(cfg.checkpoint)
47 |     print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Begin".format(weightpath))
48 |     checkpoint = torch.load(weightpath)
49 |     start_epoch = checkpoint['epoch'] + 1
50 |     model.load_state_dict(checkpoint['state_dict'])
51 |     optimizer.load_state_dict(checkpoint['optimizer'])
52 |     scheduler.load_state_dict(checkpoint['scheduler'])
53 |     print("EAST <==> Prepare <==> Loading checkpoint '{}' <==> Done".format(weightpath))
54 | 
55 |     return start_epoch
56 | 
57 | 
58 | def save_checkpoint(epoch, model, optimizer, scheduler, filename='checkpoint.pth.tar'):
59 |     """Save a training checkpoint.
60 |     stores epoch, model, optimizer and scheduler state under cfg.save_model_path
61 |     Arguments:
62 |         epoch, model, optimizer, scheduler -- state to serialize
63 |     Keyword Arguments:
64 |         filename {str} -- overwritten below with 'epoch_<n>_checkpoint.pth.tar'
65 |     """
66 |     print('EAST <==> Save weight - epoch {} <==> Begin'.format(epoch))
67 |     state = {
68 |         'epoch': epoch,
69 |         'state_dict': model.state_dict(),
70 |         'optimizer': optimizer.state_dict(),
71 |         'scheduler': scheduler.state_dict()
72 |     }
73 |     weight_dir = cfg.save_model_path
74 |     if not os.path.exists(weight_dir):
75 |         os.mkdir(weight_dir)
76 |     filename = 'epoch_' + str(epoch) + '_checkpoint.pth.tar'
77 |     file_path = os.path.join(weight_dir, filename)
78 |     torch.save(state, file_path)
79 |     print('EAST <==> Save weight - epoch {} <==> Done'.format(epoch))
80 | 
81 | 
82 | class Regularization(torch.nn.Module):
83 |     def __init__(self, model, weight_decay, p=2):
84 |         super(Regularization, self).__init__()
85 |         if weight_decay < 0:
86 |             print("param weight_decay must not be < 0")
87 |             exit(1)
88 |         self.model = model
89 |         self.weight_decay = weight_decay
90 |         self.p = p
91 |         self.weight_list = self.get_weight(model)
92 |         # self.weight_info(self.weight_list)
93 | 
94 |     def to(self, device):
95 |         self.device = device
96 |         super().to(device)
97 |         return self
98 | 
99 |     def forward(self, model):
100 |         self.weight_list = self.get_weight(model)  # refresh to the latest weights
101 |         reg_loss = self.regularization_loss(self.weight_list, self.weight_decay, p=self.p)
102 |         return reg_loss
103 | 
104 |     def get_weight(self, model):
105 |         weight_list = []
106 |         for name, param in model.named_parameters():
107 |             if 'weight' in name:
108 |                 weight = (name, param)
109 |                 weight_list.append(weight)
110 |         return weight_list
111 | 
112 |     def regularization_loss(self, weight_list, weight_decay, p=2):
113 |         reg_loss = 0
114 |         for name, w in weight_list:
115 |             l2_reg = torch.norm(w, p=p)
116 |             reg_loss = reg_loss + l2_reg
117 | 
118 |         reg_loss = weight_decay * reg_loss
119 |         return reg_loss
120 | 
121 |     def weight_info(self, weight_list):
122 |         print("---------------regularization weight---------------")
123 |         for name, w in weight_list:
124 |             print(name)
125 |         print("---------------------------------------------------")
126 | 
127 | 
128 | def resize_image(im, max_side_len=2400):
129 |     '''
130 |     resize image to a size multiple of 32 which is required by the network
131 |     :param im: the image to resize
132 |     :param max_side_len: limit of max image size to avoid out of memory in gpu
133 |     :return: the resized image and the resize ratio
134 |     '''
135 |     h, w, _ = im.shape
136 | 
137 |     resize_w = w
138 |     resize_h = h
139 | 
140 |     # limit the max side (currently disabled)
141 |     """
142 |     if max(resize_h, resize_w) > max_side_len:
143 |         ratio = float(max_side_len) / resize_h if resize_h > resize_w else float(max_side_len) / resize_w
144 |     else:
145 |         ratio = 1.
146 | 
147 |     resize_h = int(resize_h * ratio)
148 |     resize_w = int(resize_w * ratio)
149 |     """
150 | 
151 |     resize_h = resize_h if resize_h % 32 == 0 else (resize_h // 32 - 1) * 32
152 |     resize_w = resize_w if resize_w % 32 == 0 else (resize_w // 32 - 1) * 32
153 |     # resize_h, resize_w = 512, 512
154 |     im = cv2.resize(im, (int(resize_w), int(resize_h)))
155 | 
156 |     ratio_h = resize_h / float(h)
157 |     ratio_w = resize_w / float(w)
158 | 
159 |     return im, (ratio_h, ratio_w)
160 | 
161 | 
162 | def detect(score_map, geo_map, timer, score_map_thresh=0.8, box_thresh=0.1, nms_thres=0.2):
163 |     '''
164 |     restore text boxes from score map and geo map
165 |     :param score_map:
166 |     :param geo_map:
167 |     :param timer:
168 |     :param score_map_thresh: threshold for score map
169 |     :param box_thresh: threshold for boxes
170 |     :param nms_thres: threshold for nms
171 |     :return:
172 |     '''
173 | 
174 |     # squeeze the batch/channel dimensions of score_map and geo_map
175 |     if len(score_map.shape) == 4:
176 |         score_map = score_map[0, :, :, 0]
177 |         geo_map = geo_map[0, :, :, :]
178 |     # filter the score map
179 |     xy_text = np.argwhere(score_map > score_map_thresh)
180 |     # sort the text boxes via the y axis
181 |     xy_text = xy_text[np.argsort(xy_text[:, 0])]
182 |     # restore
183 |     start = time.time()
184 |     text_box_restored = preprossing.restore_rectangle(xy_text[:, ::-1] * 4,
185 |                                                       geo_map[xy_text[:, 0], xy_text[:, 1], :])  # N*4*2
186 |     print('{} text boxes before nms'.format(text_box_restored.shape[0]))
187 |     boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
188 |     boxes[:, :8] = text_box_restored.reshape((-1, 8))
189 |     boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
190 |     timer['restore'] = time.time() - start
191 |     # nms part
192 |     start = time.time()
193 |     boxes = locality_aware_nms.nms_locality(boxes.astype(np.float64), nms_thres)
194 |     timer['nms'] = time.time() - start
195 |     print(timer['nms'])
196 |     if boxes.shape[0] == 0:
197 |         return None, timer
198 | 
199 |     # here we filter some low score boxes by the average score map, this is different from the original paper
200 |     for i, box in enumerate(boxes):
201 |         mask = np.zeros_like(score_map, dtype=np.uint8)
202 |         cv2.fillPoly(mask, box[:8].reshape((-1, 4, 2)).astype(np.int32) // 4, 1)
203 |         boxes[i, 8] = cv2.mean(score_map, mask)[0]
204 |     boxes = boxes[boxes[:, 8] > box_thresh]
205 |     return boxes, timer
206 | 
207 | 
208 | def sort_poly(p):
209 |     min_axis = np.argmin(np.sum(p, axis=1))
210 |     p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]]
211 |     if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
212 |         return p
213 |     else:
214 |         return p[[0, 3, 2, 1]]
215 | 
216 | 
217 | def mean_image_subtraction(images, means=cfg.means):
218 |     '''
219 |     image normalization
220 |     :param images: bs * channel * w * h
221 |     :param means:
222 |     :return:
223 |     '''
224 |     num_channels = images.data.shape[1]
225 |     if len(means) != num_channels:
226 |         raise ValueError('len(means) must match the number of channels')
227 |     for i in range(num_channels):
228 |         images.data[:, i, :, :] -= means[i]
229 | 
230 |     return images
--------------------------------------------------------------------------------
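To tie these helpers together, a minimal single-image inference sketch (illustrative only; `eval.py`'s `predict()` is the authoritative pipeline, and this assumes a CUDA device plus the pretrained backbone configured in `config.py`):

```python
import cv2
import numpy as np
import torch

import utils
from model import East

model = East().cuda().eval()  # load an EAST checkpoint here for real results
im = cv2.imread('demo/test_img/img_1.jpg')[:, :, ::-1]  # BGR -> RGB
im_resized, (ratio_h, ratio_w) = utils.resize_image(im)
x = torch.from_numpy(im_resized.transpose(2, 0, 1)[np.newaxis].astype(np.float32)).cuda()
with torch.no_grad():
    score, geometry = model(x)  # mean subtraction happens inside forward()
# detect() expects bs*h*w*c numpy maps
score = score.permute(0, 2, 3, 1).cpu().numpy()
geometry = geometry.permute(0, 2, 3, 1).cpu().numpy()
boxes, _ = utils.detect(score, geometry, timer={'restore': 0, 'nms': 0})
if boxes is not None:
    boxes = boxes[:, :8].reshape((-1, 4, 2))
    boxes[:, :, 0] /= ratio_w  # map coordinates back to the original image
    boxes[:, :, 1] /= ratio_h
```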