├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── Infer_Utils.py ├── README.md ├── TestModel.py ├── codelist.txt ├── compile.sh ├── config └── pan_pp │ ├── R18-AUG.py │ └── R50-AUG.py ├── convertor.py ├── dataset ├── README.md ├── __init__.py ├── builder.py └── pan_pp │ ├── __init__.py │ ├── coco_text.py │ ├── pan_pp_coco.py │ ├── pan_pp_ic15.py │ └── pan_pp_joint_train.py ├── eval ├── README.md ├── ctw │ ├── eval.py │ └── file_util.py ├── eval_ctw.sh ├── eval_ic15.sh ├── eval_ic15_end2end_rec.sh ├── eval_ic15_word_spotting.sh ├── eval_msra.sh ├── eval_tt.sh ├── eval_tt_rec.sh ├── ic15 │ ├── gt.zip │ ├── rrc_evaluation_funcs.py │ ├── rrc_evaluation_funcs_v1.py │ ├── rrc_evaluation_funcs_v2.py │ ├── script.py │ └── script_self_adapt.py ├── ic15_end2end_rec │ ├── gt.zip │ ├── readme.txt │ ├── rrc_evaluation_funcs_1_1.py │ ├── script.py │ └── script_self_adapt.py ├── ic15_word_spotting │ ├── gt.zip │ ├── readme-鈹傗敩袞鈹も暋鈹€MacBook Pro.txt │ ├── readme.txt │ ├── rrc_evaluation_funcs_1_1.py │ └── script.py ├── msra │ ├── eval.py │ └── file_util.py ├── tt │ ├── Deteval.py │ ├── Deteval_rec.py │ └── polygon_wrapper.py └── tt_rec │ ├── gt.zip │ ├── readme.txt │ ├── rrc_evaluation_funcs_1_1.py │ └── script.py ├── font ├── NotoSansCJK-Regular.ttc ├── README.md └── SIMSUN.TTC ├── infer.sh ├── main.py ├── models ├── __init__.py ├── backbone │ ├── __init__.py │ ├── builder.py │ └── resnet.py ├── builder.py ├── head │ ├── __init__.py │ ├── builder.py │ ├── pa_head.py │ ├── pan_pp_det_head.py │ ├── pan_pp_rec_head.py │ └── psenet_head.py ├── loss │ ├── __init__.py │ ├── acc.py │ ├── builder.py │ ├── dice_loss.py │ ├── emb_loss_v1.py │ ├── emb_loss_v2.py │ ├── iou.py │ └── ohem.py ├── neck │ ├── __init__.py │ ├── builder.py │ ├── fpem_v1.py │ ├── fpem_v2.py │ └── fpn.py ├── pan.py ├── pan_pp.py ├── post_processing │ ├── __init__.py │ ├── beam_search │ │ ├── __init__.py │ │ ├── beam_search.py │ │ └── topk.py │ ├── pa │ │ ├── __init__.py │ │ ├── pa.cpp │ │ ├── pa.pyx │ │ ├── readme.txt │ │ └── setup.py │ └── pse │ │ ├── __init__.py │ │ ├── pse.cpp │ │ ├── pse.pyx │ │ ├── readme.txt │ │ └── setup.py ├── psenet.py └── utils │ ├── __init__.py │ ├── conv_bn_relu.py │ ├── coordconv.py │ └── fuse_conv_bn.py ├── requirements.txt ├── train_pan_pp ├── train.py └── train.sh ├── utils ├── UNet_Order_Dataset.py ├── __init__.py ├── average_meter.py ├── build_dataset.py ├── corrector.py ├── logger.py ├── result_format.py └── visualizer.py └── vis ├── 34-V101P0264.jpg ├── aug-vis.py ├── image_373.jpg └── image_553.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.pyc 4 | 5 | # C extensions 6 | *.so 7 | *.o 8 | *.nfs* 9 | 10 | # Distribution / packaging 11 | .Python 12 | *build/ 13 | *out/ 14 | *outputs/ 15 | *data/ 16 | *weights/ 17 | *ckpt/ 18 | *pretrain/ 19 | *.pth 20 | *job.* 21 | *env.sh 22 | *.tar 23 | *checkpoints/ 24 | *dataloader_vis/ 25 | *pretrained/ 26 | pretrained 27 | data 28 | *~ 29 | tmp/ 30 | .idea/ 31 | outputs*/ 32 | ckpt_key_set.py 33 | .DS_Store 34 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | known_third_party = Cython,PIL,Polygon,cv2,editdistance,matplotlib,mmcv,numpy,pyclipper,scipy,torch,torchvision 3 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: 
--------------------------------------------------------------------------------
exclude: ^eval/
repos:
  - repo: https://gitlab.com/pycqa/flake8.git
    rev: 3.8.3
    hooks:
      - id: flake8
        args: ['--exclude=train.py']
  - repo: https://github.com/asottile/seed-isort-config
    rev: v2.2.0
    hooks:
      - id: seed-isort-config
        args: ['--exclude=(eval)/.*\.py']
  - repo: https://github.com/timothycrosley/isort
    rev: 4.3.21
    hooks:
      - id: isort
        exclude: ^eval/
  - repo: https://github.com/pre-commit/mirrors-yapf
    rev: v0.30.0
    hooks:
      - id: yapf
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v3.1.0
    hooks:
      - id: trailing-whitespace
      - id: check-yaml
      - id: end-of-file-fixer
      - id: requirements-txt-fixer
      - id: double-quote-string-fixer
      - id: check-merge-conflict
      - id: fix-encoding-pragma
        args: ["--remove"]
      - id: mixed-line-ending
        args: ["--fix=lf"]
  - repo: https://github.com/jumanjihouse/pre-commit-hooks
    rev: 2.1.4
    hooks:
      - id: markdownlint
        args: ["-r", "~MD002,~MD013,~MD024,~MD029,~MD033,~MD034,~MD036", "-t", "allow_different_nesting"]
  - repo: https://github.com/myint/docformatter
    rev: v1.3.1
    hooks:
      - id: docformatter
        args: ["--in-place", "--wrap-descriptions", "79"]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# AlphX-Code-For-DAR

## Source code of team AlphX for the ancient-document image recognition and analysis competition of the Guangdong-Hong Kong-Macao Greater Bay Area (Huangpu) International Algorithm Competition

[[Documentation](https://docs.qq.com/doc/DWk9IZ2JYVnNyc0hM)] [[Slides](https://docs.qq.com/doc/DWk9IZ2JYVnNyc0hM)]

China's ancient books and documents carry rich historical information and cultural heritage, and in response to the national strategy of protecting this heritage, digitizing them is imperative. Their analysis and recognition remain highly challenging, however, because ancient document images have complex, highly variable layouts, and carving and writing styles differ greatly across dynasties. This solution integrates existing strong models to analyze and recognize images of ancient Chinese documents. A PAN++ network detects arbitrarily shaped text columns, and an encoder-decoder network combined with heuristic algorithms predicts the reading order of complex pages. Guided by the foreground-pixel ratio, a Hough transform combined with a top-edge alignment algorithm efficiently rectifies warped and skewed text-line images of arbitrary shape. For overly long text-line images, an overlapping ("shingled") recognition strategy avoids the information loss that excessive downscaling would cause. Finally, an improved convolutional recurrent neural network (CRNN) performs end-to-end recognition of the text-line images. Local experiments show that the proposed solution is stable, reliable and robust, maintaining high accuracy at a reasonable inference speed.

#### Recognition results
![example](vis/image_553.jpg)

![example](vis/34-V101P0264.jpg)

![example](vis/image_373.jpg)

# Environment Setup
## Modified from [Official Code of PAN++](https://github.com/whai362/pan_pp.pytorch)

First, clone the repository locally:

```shell
git clone https://github.com/ssocean/AlphX-Code-For-DAR.git
```

Then, install PyTorch 1.1.0+, torchvision 0.3.0+, and other requirements:

```shell
conda install pytorch torchvision -c pytorch
pip install -r requirements.txt
```

Finally, compile the post-processing code:

```shell
# build pse and pa algorithms
sh ./compile.sh
```

# Training

### PAN++ training

```shell
sh train_pan_pp/train.sh
```

### CRNN training

You can train the recognizer with any end-to-end recognition codebase; we recommend [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark).

### UNet training

Please follow [UNet](https://github.com/ssocean/UNet-Binarization) to complete the training.

# Inference

```shell
sh infer.sh
```

# File Description

`main.py` is the program entry point required by the competition; we adapted it so that it also runs in a local environment.

`TestModel.py` is the competition-required script for single-image inference. It implements two modes (debug | accelerated); switch between them by changing the `DEBUG` flag.

`Infer_Utils.py` is the helper toolkit we built for `TestModel.py`. It implements almost everything related to single-image inference, including row/column projection, image-property checks, image rectification, overlapping recognition (sketched below), heuristic ordering, global reading-order prediction, and the model definitions.
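For reference, here is a minimal sketch of the overlapping ("shingled") recognition idea: a long line image is cut into fixed-width tiles that share an overlap, each tile is recognized separately, and the per-tile predictions are merged over the shared region. The helper below only performs the tiling; the tile width, the overlap and the `crnn` call are illustrative placeholders, not the exact values or API used in `Infer_Utils.py`.

```python
import numpy as np


def shingle_tiles(line_img: np.ndarray, tile_w: int = 512, overlap: int = 128):
    """Split a long text-line image into overlapping tiles.

    Consecutive tiles share `overlap` columns, so no character is seen only
    at a tile border; characters duplicated in the shared region are
    de-duplicated when the per-tile predictions are merged.
    """
    assert 0 < overlap < tile_w
    width = line_img.shape[1]
    step = tile_w - overlap
    tiles, x0 = [], 0
    while True:
        tiles.append((x0, line_img[:, x0:x0 + tile_w]))  # last tile may be narrower
        if x0 + tile_w >= width:
            break
        x0 += step
    return tiles


# usage sketch: texts = [crnn(tile) for _, tile in shingle_tiles(line_img)]
```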
# Citation

If you find our solution helpful, please consider citing the following works.

### PAN++
```
@article{wang2021pan++,
  title={PAN++: Towards Efficient and Accurate End-to-End Spotting of Arbitrarily-Shaped Text},
  author={Wang, Wenhai and Xie, Enze and Li, Xiang and Liu, Xuebo and Liang, Ding and Yang, Zhibo and Lu, Tong and Shen, Chunhua},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
  year={2021},
  publisher={IEEE}
}
```

```
Zhao, Penghai. A Layout Analysis and Recognition System for Uchen-Script Tibetan Documents [D]. Northwest Minzu University, 2022. DOI: 10.27408/d.cnki.gxmzc.2022.000367.
```

### Standalone repositories

[[Overlapping Recognition](https://github.com/ssocean/Overlapping-Recognition)]

[[Binarization & global reading-order prediction](https://github.com/ssocean/UNet-Binarization)]

### If you have any questions, feel free to contact us via an issue
--------------------------------------------------------------------------------
/compile.sh:
--------------------------------------------------------------------------------
cd ./models/post_processing/pa/
python3 setup.py build_ext --inplace
cd ../pse/
python3 setup.py build_ext --inplace
cd ../../../
--------------------------------------------------------------------------------
/config/pan_pp/R18-AUG.py:
--------------------------------------------------------------------------------
model = dict(
    type='PAN_PP',
    backbone=dict(
        type='resnet18',
        pretrained=True
    ),
    neck=dict(
        type='FPEM_v2',
        in_channels=(64, 128, 256, 512),
        out_channels=128
    ),
    detection_head=dict(
        type='PAN_PP_DetHead',
        in_channels=512,
        hidden_dim=128,
        num_classes=6,
        loss_text=dict(
            type='DiceLoss',
            loss_weight=1.0
        ),
        loss_kernel=dict(
            type='DiceLoss',
            loss_weight=0.5
        ),
        loss_emb=dict(
            type='EmbLoss_v2',
            feature_dim=4,
            loss_weight=0.25
        ),
        use_coordconv=False,
    )
)

data = dict(
    batch_size=4,
    train=dict(
        type='PAN_PP_Joint_Train',
        split='train',
        is_transform=True,
        img_size=(896, 896),
        short_size=896,
        kernel_scale=0.5,
        read_type='pil',
        with_rec=False
    ),
    test=dict(
        type='PAN_PP_COCO',
        split='train',
        is_transform=False,
        img_size=(896, 896),
        short_size=896,
        read_type='pil',
        with_rec=True
    ),
)
train_cfg = dict(
    lr=1e-2,
    schedule='polylr',
    epoch=300,
    optimizer='Adam'
)
test_cfg = dict(
    min_score=0.80,
    min_area=50,
    min_kernel_area=0.5,
    scale=2,
    bbox_type='poly',  # 'rect' or 'poly'
    result_path='outputs/0915-txt',  # e.g. 'outputs/submit_ic15_rec.zip'
)
# report_speed=True
--------------------------------------------------------------------------------
/config/pan_pp/R50-AUG.py:
--------------------------------------------------------------------------------

model = dict(
    type='PAN_PP',
    backbone=dict(
        type='resnet50',
        pretrained=True
    ),
    neck=dict(
        type='FPEM_v2',
        in_channels=(256, 512, 1024, 2048),
        out_channels=128
    ),
    detection_head=dict(
        type='PAN_PP_DetHead',
        in_channels=512,
        hidden_dim=128,
        num_classes=6,
        loss_text=dict(
            type='DiceLoss',
            loss_weight=1.0
        ),
        loss_kernel=dict(
            type='DiceLoss',
            loss_weight=0.5
        ),
        loss_emb=dict(
            type='EmbLoss_v2',
            feature_dim=4,
            loss_weight=0.25
        ),
        use_coordconv=False,
    ),
)
data = dict(
    batch_size=4,
    train=dict(
        type='PAN_PP_Joint_Train',
        split='train',
        is_transform=True,
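        # NOTE (added annotation, not part of the original config): the two
        # fields below follow the standard PAN++ data pipeline; img_size is the
        # training input size and short_size the shorter-side resize target.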
img_size=(896,896), 41 | short_size=896, 42 | kernel_scale=0.5, 43 | read_type='pil', 44 | with_rec=False 45 | ), 46 | test=dict( 47 | type='PAN_PP_COCO', 48 | split='train', 49 | is_transform=True, 50 | img_size=(896,896), 51 | short_size=896, 52 | read_type='pil', 53 | with_rec=True 54 | ), 55 | ) 56 | train_cfg = dict( 57 | lr=1e-2, 58 | schedule='polylr', 59 | epoch=300, 60 | optimizer='Adam' 61 | ) 62 | test_cfg = dict( 63 | min_score=0.80, 64 | min_area=50, 65 | min_kernel_area=0.5, 66 | scale=2, 67 | bbox_type='poly',#rect poly 68 | result_path='outputs/txt',#'outputs/submit_ic15_rec.zip', 69 | ) 70 | -------------------------------------------------------------------------------- /convertor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import mmcv 3 | import argparse 4 | import os.path as osp 5 | 6 | parser = argparse.ArgumentParser(description='Hyperparams') 7 | parser.add_argument('checkpoint', nargs='?', type=str, default=None) 8 | args = parser.parse_args() 9 | 10 | dir_name = args.checkpoint.split("/")[-2] 11 | checkpoint = torch.load(args.checkpoint, map_location='cpu') 12 | state_dict = checkpoint['state_dict'] 13 | for k, v in state_dict.items(): 14 | print(k) 15 | checkpoint = {'state_dict': state_dict} 16 | mmcv.mkdir_or_exist("converted/") 17 | try: 18 | torch.save(checkpoint, osp.join("converted", dir_name+".pth.tar"), _use_new_zipfile_serialization=False) 19 | except: 20 | torch.save(checkpoint, osp.join("converted", dir_name+".pth.tar")) 21 | -------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | # 本方案将原始数据集格式转换为coco数据格式,详情见PAN_PP_Joint_Train.py -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | sys.path.append(BASE_DIR) 5 | from .builder import * 6 | 7 | from .pan_pp import PAN_PP_IC15, PAN_PP_Joint_Train, PAN_PP_COCO 8 | 9 | 10 | __all__ = [ 11 | 'PAN_PP_IC15','PAN_PP_Joint_Train', 'build_data_loader','PAN_PP_COCO' 12 | ] 13 | -------------------------------------------------------------------------------- /dataset/builder.py: -------------------------------------------------------------------------------- 1 | import dataset 2 | 3 | 4 | def build_data_loader(cfg): 5 | param = dict() 6 | for key in cfg: 7 | if key == 'type': 8 | continue 9 | param[key] = cfg[key] 10 | 11 | data_loader = dataset.__dict__[cfg.type](**param) 12 | 13 | return data_loader 14 | -------------------------------------------------------------------------------- /dataset/pan_pp/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | sys.path.append(BASE_DIR) 5 | from .pan_pp_ic15 import PAN_PP_IC15 6 | from .pan_pp_joint_train import PAN_PP_Joint_Train 7 | from dataset.pan_pp.pan_pp_coco import PAN_PP_COCO 8 | __all__ = [ 9 | 'PAN_PP_IC15', 10 | 'PAN_PP_Joint_Train', 11 | 'PAN_PP_COCO' 12 | ] 13 | -------------------------------------------------------------------------------- /dataset/pan_pp/coco_text.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ 
import absolute_import, division, print_function 3 | 4 | import datetime 5 | from importlib import reload 6 | import json 7 | import os 8 | 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from matplotlib.collections import PatchCollection 12 | from matplotlib.patches import PathPatch, Rectangle 13 | from matplotlib.path import Path 14 | 15 | __author__ = 'andreasveit' 16 | __version__ = '1.1' 17 | 18 | # import sys 19 | # reload(sys) 20 | # sys.setdefaultencoding('utf-8') 21 | # Interface for accessing the COCO-Text dataset. 22 | 23 | # COCO-Text is a large dataset designed for text detection and recognition. 24 | # This is a Python API that assists in loading, parsing and visualizing the 25 | # annotations. The format of the COCO-Text annotations is also described on 26 | # the project website http://vision.cornell.edu/se3/coco-text/. 27 | # In addition to this API, please download both 28 | # the COCO images and annotations. 29 | # This dataset is based on Microsoft COCO. Please visit http://mscoco.org/ 30 | # for more information on COCO, including for the image data, object annotatins 31 | # and caption annotations. 32 | 33 | # An alternative to using the API is to load the annotations directly 34 | # into Python dictionary: 35 | # with open(annotation_filename) as json_file: 36 | # coco_text = json.load(json_file) 37 | # Using the API provides additional utility functions. 38 | 39 | # The following API functions are defined: 40 | # COCO_Text - COCO-Text api class that loads COCO annotations 41 | # and prepare data structures. 42 | # getAnnIds - Get ann ids that satisfy given filter conditions. 43 | # getImgIds - Get img ids that satisfy given filter conditions. 44 | # loadAnns - Load anns with the specified ids. 45 | # loadImgs - Load imgs with the specified ids. 46 | # showAnns - Display the specified annotations. 47 | # loadRes - Load algorithm results and create API for accessing them. 48 | # Throughout the API "ann"=annotation, "cat"=category, and "img"=image. 49 | 50 | # COCO-Text Toolbox. Version 1.1 51 | # Data and paper available at: http://vision.cornell.edu/se3/coco-text/ 52 | # Code based on Microsoft COCO Toolbox Version 1.0 by Piotr Dollar and 53 | # Tsung-Yi Lin extended and adapted by Andreas Veit, 2016. 54 | # Licensed under the Simplified BSD License [see bsd.txt] 55 | 56 | 57 | class COCO_Text: 58 | def __init__(self, annotation_file=None): 59 | """Constructor of COCO-Text helper class for reading and visualizing 60 | annotations. 
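        Illustrative usage (added example; the annotation file name follows the
        commented sample at the bottom of this module)::

            ct = COCO_Text('label-CT.json')
            train_ids = ct.getImgIds(imgIds=ct.train)
            anns = ct.loadAnns(ct.getAnnIds(imgIds=train_ids))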
61 | 62 | :param annotation_file (str): location of annotation file 63 | :return: 64 | """ 65 | # load dataset 66 | self.dataset = {} 67 | self.anns = {} 68 | self.imgToAnns = {} 69 | self.catToImgs = {} 70 | self.imgs = {} 71 | self.cats = {} 72 | self.val = [] 73 | self.test = [] 74 | self.train = [] 75 | if annotation_file is not None: 76 | assert os.path.isfile(annotation_file), 'file does not exist' 77 | print('loading annotations into memory...') 78 | time_t = datetime.datetime.utcnow() 79 | dataset = json.load(open(annotation_file, 'r', encoding='utf-8')) 80 | # print(dataset['anns']) 81 | print(datetime.datetime.utcnow() - time_t) 82 | self.dataset = dataset 83 | self.createIndex() 84 | 85 | def createIndex(self): 86 | # create index 87 | print('creating index...') 88 | self.imgToAnns = { 89 | int(cocoid): self.dataset['imgToAnns'][cocoid] 90 | for cocoid in self.dataset['imgToAnns'] 91 | } 92 | self.imgs = { 93 | int(cocoid): self.dataset['imgs'][cocoid] 94 | for cocoid in self.dataset['imgs'] 95 | } 96 | self.anns = { 97 | int(annid): self.dataset['anns'][annid] 98 | for annid in self.dataset['anns'] 99 | } 100 | self.cats = self.dataset['cats'] 101 | self.val = [ 102 | int(cocoid) for cocoid in self.dataset['imgs'] 103 | if self.dataset['imgs'][cocoid]['set'] == 'val' 104 | ] 105 | self.test = [ 106 | int(cocoid) for cocoid in self.dataset['imgs'] 107 | if self.dataset['imgs'][cocoid]['set'] == 'test' 108 | ] 109 | self.train = [ 110 | int(cocoid) for cocoid in self.dataset['imgs'] 111 | if self.dataset['imgs'][cocoid]['set'] == 'train' 112 | ] 113 | print('index created!') 114 | 115 | def info(self): 116 | """Print information about the annotation file. 117 | 118 | :return: 119 | """ 120 | for key, value in self.dataset['info'].items(): 121 | print('%s: %s' % (key, value)) 122 | 123 | def filtering(self, filterDict, criteria): 124 | return [ 125 | key for key in filterDict if all( 126 | criterion(filterDict[key]) for criterion in criteria) 127 | ] 128 | 129 | def getAnnByCat(self, properties): 130 | """Get ann ids that satisfy given properties. 131 | 132 | :param properties (list of tuples of the form 133 | [(category type, category)] e.g., [('readability','readable')] 134 | : get anns for given categories - anns have to satisfy 135 | all given property tuples 136 | :return: ids (int array): integer array of ann ids 137 | """ 138 | return self.filtering( 139 | self.anns, 140 | [lambda d, x=a, y=b: d[x] == y for (a, b) in properties]) 141 | 142 | def getAnnIds(self, imgIds=[], catIds=[], areaRng=[]): 143 | """Get ann ids that satisfy given filter conditions. default skips that 144 | filter. 145 | 146 | :param imgIds (int array) : get anns for given imgs 147 | catIds (list of tuples of the form 148 | [(category type, category)] 149 | e.g., [('readability','readable')] 150 | : get anns for given cats 151 | areaRng (float array): get anns for given area range 152 | (e.g. 
[0 inf]) 153 | :return: ids (int array) : integer array of ann ids 154 | """ 155 | imgIds = imgIds if type(imgIds) == list else [imgIds] 156 | catIds = catIds if type(catIds) == list else [catIds] 157 | 158 | if len(imgIds) == len(catIds) == len(areaRng) == 0: 159 | anns = list(self.anns.keys()) 160 | else: 161 | if not len(imgIds) == 0: 162 | anns = sum([ 163 | self.imgToAnns[imgId] 164 | for imgId in imgIds if imgId in self.imgToAnns 165 | ], []) 166 | else: 167 | anns = list(self.anns.keys()) 168 | anns = anns if len(catIds) == 0 else list( 169 | set(anns).intersection(set(self.getAnnByCat(catIds)))) 170 | anns = anns if len(areaRng) == 0 else [ 171 | ann for ann in anns if self.anns[ann]['area'] > areaRng[0] 172 | and self.anns[ann]['area'] < areaRng[1] 173 | ] 174 | return anns 175 | 176 | def getImgIds(self, imgIds=[], catIds=[]): 177 | """Get img ids that satisfy given filter conditions. 178 | 179 | :param imgIds (int array) : get imgs for given ids 180 | :param catIds (int array) : get imgs with all given cats 181 | :return: ids (int array) : integer array of img ids 182 | """ 183 | imgIds = imgIds if type(imgIds) == list else [imgIds] 184 | catIds = catIds if type(catIds) == list else [catIds] 185 | 186 | if len(imgIds) == len(catIds) == 0: 187 | ids = list(self.imgs.keys()) 188 | else: 189 | ids = set(imgIds) 190 | if not len(catIds) == 0: 191 | ids = ids.intersection( 192 | set([ 193 | self.anns[annid]['image_id'] 194 | for annid in self.getAnnByCat(catIds) 195 | ])) 196 | 197 | return list(ids) 198 | 199 | def loadAnns(self, ids=[]): 200 | """Load anns with the specified ids. 201 | 202 | :param ids (int array) : integer ids specifying anns 203 | :return: anns (object array) : loaded ann objects 204 | """ 205 | if type(ids) == list: 206 | return [self.anns[id] for id in ids] 207 | elif type(ids) == int: 208 | return [self.anns[ids]] 209 | 210 | def loadImgs(self, ids=[]): 211 | """Load anns with the specified ids. 212 | 213 | :param ids (int array) : integer ids specifying img 214 | :return: imgs (object array) : loaded img objects 215 | """ 216 | if type(ids) == list: 217 | return [self.imgs[id] for id in ids] 218 | elif type(ids) == int: 219 | return [self.imgs[ids]] 220 | 221 | def showAnns(self, anns, show_polygon=False): 222 | """Display the specified annotations. 223 | 224 | :param anns (array of object): annotations to display 225 | :return: None 226 | """ 227 | if len(anns) == 0: 228 | return 0 229 | ax = plt.gca() 230 | boxes = [] 231 | color = [] 232 | for ann in anns: 233 | c = np.random.random((1, 3)).tolist()[0] 234 | if show_polygon: 235 | tl_x, tl_y, tr_x, tr_y, br_x, br_y, bl_x, bl_y = ann['polygon'] 236 | verts = [(tl_x, tl_y), (tr_x, tr_y), (br_x, br_y), 237 | (bl_x, bl_y), (0, 0)] 238 | codes = [ 239 | Path.MOVETO, Path.LINETO, Path.LINETO, Path.LINETO, 240 | Path.CLOSEPOLY 241 | ] 242 | path = Path(verts, codes) 243 | patch = PathPatch(path, facecolor='none') 244 | boxes.append(patch) 245 | left, top = tl_x, tl_y 246 | else: 247 | left, top, width, height = ann['bbox'] 248 | boxes.append(Rectangle([left, top], width, height, alpha=0.4)) 249 | color.append(c) 250 | if 'utf8_string' in list(ann.keys()): 251 | ax.annotate(ann['utf8_string'], (left, top - 4), color=c) 252 | p = PatchCollection(boxes, 253 | facecolors=color, 254 | edgecolors=(0, 0, 0, 1), 255 | linewidths=3, 256 | alpha=0.4) 257 | ax.add_collection(p) 258 | 259 | def loadRes(self, resFile): 260 | """Load result file and return a result api object. 
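        Illustrative usage (added): ``res = ct.loadRes('detections.json')``;
        either a JSON file path or an already-loaded list of result dicts is
        accepted.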
261 | 262 | :param resFile (str) : file name of result file 263 | :return: res (obj) : result api object 264 | """ 265 | res = COCO_Text() 266 | res.dataset['imgs'] = [img for img in self.dataset['imgs']] 267 | 268 | print('Loading and preparing results... ') 269 | time_t = datetime.datetime.utcnow() 270 | if type(resFile) == str: 271 | anns = json.load(open(resFile,encoding='utf-8')) 272 | else: 273 | anns = resFile 274 | assert type(anns) == list, 'results in not an array of objects' 275 | annsImgIds = [int(ann['image_id']) for ann in anns] 276 | 277 | if set(annsImgIds) != (set(annsImgIds) & set(self.getImgIds())): 278 | print('Results do not correspond to current coco set') 279 | print( 280 | 'skipping ', 281 | str( 282 | len(set(annsImgIds)) - 283 | len(set(annsImgIds) & set(self.getImgIds()))), ' images') 284 | annsImgIds = list(set(annsImgIds) & set(self.getImgIds())) 285 | 286 | res.imgToAnns = {cocoid: [] for cocoid in annsImgIds} 287 | res.imgs = {cocoid: self.imgs[cocoid] for cocoid in annsImgIds} 288 | 289 | assert anns[0]['bbox'] != [], 'results have incorrect format' 290 | for id, ann in enumerate(anns): 291 | if ann['image_id'] not in annsImgIds: 292 | continue 293 | bb = ann['bbox'] 294 | ann['area'] = bb[2] * bb[3] 295 | ann['id'] = id 296 | res.anns[id] = ann 297 | res.imgToAnns[ann['image_id']].append(id) 298 | print('DONE (t=%0.2fs)' % 299 | ((datetime.datetime.utcnow() - time_t).total_seconds())) 300 | 301 | return res 302 | 303 | # CT = COCO_Text(r'F:\Data\GJJS-dataset\dataset\train\label-CT.json') 304 | # img_paths = {} 305 | # img_paths['ct'] = CT.getImgIds(imgIds=CT.train,catIds=[('legibility', 306 | # 'legible')]) -------------------------------------------------------------------------------- /eval/README.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | The evaluation scripts of ICDAR 2015 (IC15), Total-Text (TT), CTW1500 (CTW) and MSRA-TD500 (MSRA) datasets. 
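All of the detection evaluations below reduce to matching predicted and ground-truth polygons by intersection-over-union (IoU >= 0.5 by default) and reporting precision/recall/F-measure. Here is a self-contained sketch of that core computation, using the same `Polygon` (Polygon3) package as the scripts in this folder; the two quadrilaterals are made-up values:

```python
import Polygon as plg  # pip install Polygon3

def iou(poly_a, poly_b):
    inter = poly_a & poly_b                      # polygon intersection
    inter_area = inter.area() if len(inter) > 0 else 0.0
    union = poly_a.area() + poly_b.area() - inter_area
    return inter_area / union if union > 0 else 0.0

pred = plg.Polygon([(0, 0), (10, 0), (10, 4), (0, 4)])
gt = plg.Polygon([(1, 0), (11, 0), (11, 4), (1, 4)])
matched = iou(pred, gt) >= 0.5                   # 36 / 44 ~= 0.82 -> a match
```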

## [ICDAR 2015](https://rrc.cvc.uab.es/?ch=4)
Text detection
```shell script
./eval_ic15.sh
```
End-to-End Recognition
```shell script
./eval_ic15_end2end_rec.sh
```

Word Spotting
```shell script
./eval_ic15_word_spotting.sh
```

## [Total-Text](https://github.com/cs-chan/Total-Text-Dataset)
Text detection
```shell script
./eval_tt.sh
```
End-to-End Text Spotting
```shell script
./eval_tt_rec.sh
```

## [CTW1500](https://github.com/Yuliang-Liu/Curve-Text-Detector)
Text detection
```shell script
./eval_ctw.sh
```

## [MSRA-TD500](http://www.iapr-tc11.org/dataset/MSRA-TD500/MSRA-TD500.zip)
Text detection
```shell script
./eval_msra.sh
```
--------------------------------------------------------------------------------
/eval/ctw/eval.py:
--------------------------------------------------------------------------------
import file_util
import Polygon as plg
import numpy as np

project_root = '../../'

pred_root = project_root + 'outputs/submit_ctw'
gt_root = project_root + 'data/ctw1500/test/text_label_circum/'


def get_pred(path):
    lines = file_util.read_file(path).split('\n')
    bboxes = []
    for line in lines:
        if line == '':
            continue
        bbox = line.split(',')
        if len(bbox) % 2 == 1:
            print(path)
        bbox = [int(x) for x in bbox]
        bboxes.append(bbox)
    return bboxes


def get_gt(path):
    lines = file_util.read_file(path).split('\n')
    bboxes = []
    for line in lines:
        if line == '':
            continue
        # line = util.str.remove_all(line, '\xef\xbb\xbf')
        # gt = util.str.split(line, ',')
        gt = line.split(',')

        x1 = int(gt[0])
        y1 = int(gt[1])

        # the 28 values in columns 4..31 are offsets relative to (x1, y1)
        bbox = [int(gt[i]) for i in range(4, 32)]
        bbox = np.asarray(bbox) + ([x1, y1] * 14)

        bboxes.append(bbox)
    return bboxes


def get_union(pD, pG):
    areaA = pD.area()
    areaB = pG.area()
    return areaA + areaB - get_intersection(pD, pG)


def get_intersection(pD, pG):
    pInt = pD & pG
    if len(pInt) == 0:
        return 0
    return pInt.area()


if __name__ == '__main__':
    th = 0.5
    pred_list = file_util.read_dir(pred_root)

    tp, fp, npos = 0, 0, 0

    for pred_path in pred_list:
        preds = get_pred(pred_path)
        gt_path = gt_root + pred_path.split('/')[-1]
        gts = get_gt(gt_path)
        npos += len(gts)

        cover = set()
        for pred_id, pred in enumerate(preds):
            pred = np.array(pred)
            pred = pred.reshape(pred.shape[0] // 2, 2)[:, ::-1]

            pred_p = plg.Polygon(pred)

            flag = False
            for gt_id, gt in enumerate(gts):
                gt = np.array(gt)
                gt = gt.reshape(gt.shape[0] // 2, 2)
                gt_p = plg.Polygon(gt)

                union = get_union(pred_p, gt_p)
                inter = get_intersection(pred_p, gt_p)

                if inter * 1.0 / union >= th:
                    if gt_id not in cover:
                        flag = True
                        cover.add(gt_id)
            if flag:
                tp += 1.0
            else:
                fp += 1.0

    # print tp, fp, npos
    precision = tp / (tp + fp)
    recall = tp / npos
    hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall)

    print('p: %.4f, r: %.4f, f: %.4f' % (precision, recall, hmean))
--------------------------------------------------------------------------------
/eval/ctw/file_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def read_dir(root): 4 | file_path_list = [] 5 | for file_path, dirs, files in os.walk(root): 6 | for file in files: 7 | file_path_list.append(os.path.join(file_path, file).replace('\\', '/')) 8 | file_path_list.sort() 9 | return file_path_list 10 | 11 | def read_file(file_path): 12 | file_object = open(file_path, 'r') 13 | file_content = file_object.read() 14 | file_object.close() 15 | return file_content 16 | 17 | def write_file(file_path, file_content): 18 | if file_path.find('/') != -1: 19 | father_dir = '/'.join(file_path.split('/')[0:-1]) 20 | if not os.path.exists(father_dir): 21 | os.makedirs(father_dir) 22 | file_object = open(file_path, 'w') 23 | file_object.write(file_content) 24 | file_object.close() 25 | 26 | 27 | def write_file_not_cover(file_path, file_content): 28 | father_dir = '/'.join(file_path.split('/')[0:-1]) 29 | if not os.path.exists(father_dir): 30 | os.makedirs(father_dir) 31 | file_object = open(file_path, 'a') 32 | file_object.write(file_content) 33 | file_object.close() -------------------------------------------------------------------------------- /eval/eval_ctw.sh: -------------------------------------------------------------------------------- 1 | cd ctw && python2 eval.py && cd .. 2 | -------------------------------------------------------------------------------- /eval/eval_ic15.sh: -------------------------------------------------------------------------------- 1 | cd ic15 && python2 script.py -g=gt.zip -s=../../outputs/submit_ic15.zip && cd .. -------------------------------------------------------------------------------- /eval/eval_ic15_end2end_rec.sh: -------------------------------------------------------------------------------- 1 | cd ic15_end2end_rec && python2 script.py -g=gt.zip -s=../../outputs/submit_ic15_rec.zip && cd .. -------------------------------------------------------------------------------- /eval/eval_ic15_word_spotting.sh: -------------------------------------------------------------------------------- 1 | cd ic15_word_spotting && python3 script.py -g=gt.zip -s=../../outputs/submit_ic15_rec.zip && cd .. -------------------------------------------------------------------------------- /eval/eval_msra.sh: -------------------------------------------------------------------------------- 1 | cd msra && python2 eval.py && cd .. 2 | -------------------------------------------------------------------------------- /eval/eval_tt.sh: -------------------------------------------------------------------------------- 1 | cd tt && python2 Deteval.py && cd .. -------------------------------------------------------------------------------- /eval/eval_tt_rec.sh: -------------------------------------------------------------------------------- 1 | cd tt_rec && python2 script.py -g=gt.zip -s=../../outputs/submit_ic15_rec.zip && cd .. 
-------------------------------------------------------------------------------- /eval/ic15/gt.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ssocean/AlphX-Code-For-DAR/6be8c05c12d373245c97f8434d3e5abe9f8a4550/eval/ic15/gt.zip -------------------------------------------------------------------------------- /eval/ic15/rrc_evaluation_funcs_v1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2 2 | #encoding: UTF-8 3 | import json 4 | import sys;sys.path.append('./') 5 | import zipfile 6 | import re 7 | import sys 8 | import os 9 | import codecs 10 | import importlib 11 | from StringIO import StringIO 12 | 13 | def print_help(): 14 | sys.stdout.write('Usage: python %s.py -g= -s= [-o= -p=]' %sys.argv[0]) 15 | sys.exit(2) 16 | 17 | 18 | def load_zip_file_keys(file,fileNameRegExp=''): 19 | """ 20 | Returns an array with the entries of the ZIP file that match with the regular expression. 21 | The key's are the names or the file or the capturing group definied in the fileNameRegExp 22 | """ 23 | try: 24 | archive=zipfile.ZipFile(file, mode='r', allowZip64=True) 25 | except : 26 | raise Exception('Error loading the ZIP archive.') 27 | 28 | pairs = [] 29 | 30 | for name in archive.namelist(): 31 | addFile = True 32 | keyName = name 33 | if fileNameRegExp!="": 34 | m = re.match(fileNameRegExp,name) 35 | if m == None: 36 | addFile = False 37 | else: 38 | if len(m.groups())>0: 39 | keyName = m.group(1) 40 | 41 | if addFile: 42 | pairs.append( keyName ) 43 | 44 | return pairs 45 | 46 | 47 | def load_zip_file(file,fileNameRegExp='',allEntries=False): 48 | """ 49 | Returns an array with the contents (filtered by fileNameRegExp) of a ZIP file. 
50 | The key's are the names or the file or the capturing group definied in the fileNameRegExp 51 | allEntries validates that all entries in the ZIP file pass the fileNameRegExp 52 | """ 53 | try: 54 | archive=zipfile.ZipFile(file, mode='r', allowZip64=True) 55 | except : 56 | raise Exception('Error loading the ZIP archive') 57 | 58 | pairs = [] 59 | for name in archive.namelist(): 60 | addFile = True 61 | keyName = name 62 | if fileNameRegExp!="": 63 | m = re.match(fileNameRegExp,name) 64 | if m == None: 65 | addFile = False 66 | else: 67 | if len(m.groups())>0: 68 | keyName = m.group(1) 69 | 70 | if addFile: 71 | pairs.append( [ keyName , archive.read(name)] ) 72 | else: 73 | if allEntries: 74 | raise Exception('ZIP entry not valid: %s' %name) 75 | 76 | return dict(pairs) 77 | 78 | def decode_utf8(raw): 79 | """ 80 | Returns a Unicode object on success, or None on failure 81 | """ 82 | try: 83 | raw = codecs.decode(raw,'utf-8', 'replace') 84 | #extracts BOM if exists 85 | raw = raw.encode('utf8') 86 | if raw.startswith(codecs.BOM_UTF8): 87 | raw = raw.replace(codecs.BOM_UTF8, '', 1) 88 | return raw.decode('utf-8') 89 | except: 90 | return None 91 | 92 | def validate_lines_in_file(fileName,file_contents,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): 93 | """ 94 | This function validates that all lines of the file calling the Line validation function for each line 95 | """ 96 | utf8File = decode_utf8(file_contents) 97 | if (utf8File is None) : 98 | raise Exception("The file %s is not UTF-8" %fileName) 99 | 100 | lines = utf8File.split( "\r\n" if CRLF else "\n" ) 101 | for line in lines: 102 | line = line.replace("\r","").replace("\n","") 103 | if(line != ""): 104 | try: 105 | validate_tl_line(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) 106 | except Exception as e: 107 | raise Exception(("Line in sample not valid. Sample: %s Line: %s Error: %s" %(fileName,line,str(e))).encode('utf-8', 'replace')) 108 | 109 | 110 | 111 | def validate_tl_line(line,LTRB=True,withTranscription=True,withConfidence=True,imWidth=0,imHeight=0): 112 | """ 113 | Validate the format of the line. If the line is not valid an exception will be raised. 114 | If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. 115 | Posible values are: 116 | LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] 117 | LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] 118 | """ 119 | get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight) 120 | 121 | 122 | def get_tl_line_values(line,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0): 123 | """ 124 | Validate the format of the line. If the line is not valid an exception will be raised. 125 | If maxWidth and maxHeight are specified, all points must be inside the imgage bounds. 126 | Posible values are: 127 | LTRB=True: xmin,ymin,xmax,ymax[,confidence][,transcription] 128 | LTRB=False: x1,y1,x2,y2,x3,y3,x4,y4[,confidence][,transcription] 129 | Returns values from a textline. 
Points , [Confidences], [Transcriptions] 130 | """ 131 | confidence = 0.0 132 | transcription = ""; 133 | points = [] 134 | 135 | numPoints = 4; 136 | 137 | if LTRB: 138 | 139 | numPoints = 4; 140 | 141 | if withTranscription and withConfidence: 142 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) 143 | if m == None : 144 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) 145 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence,transcription") 146 | elif withConfidence: 147 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) 148 | if m == None : 149 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,confidence") 150 | elif withTranscription: 151 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,(.*)$',line) 152 | if m == None : 153 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax,transcription") 154 | else: 155 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*,?\s*$',line) 156 | if m == None : 157 | raise Exception("Format incorrect. Should be: xmin,ymin,xmax,ymax") 158 | 159 | xmin = int(m.group(1)) 160 | ymin = int(m.group(2)) 161 | xmax = int(m.group(3)) 162 | ymax = int(m.group(4)) 163 | if(xmax0 and imHeight>0): 171 | validate_point_inside_bounds(xmin,ymin,imWidth,imHeight); 172 | validate_point_inside_bounds(xmax,ymax,imWidth,imHeight); 173 | 174 | else: 175 | 176 | numPoints = 8; 177 | 178 | if withTranscription and withConfidence: 179 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*,(.*)$',line) 180 | if m == None : 181 | raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence,transcription") 182 | elif withConfidence: 183 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*([0-1].?[0-9]*)\s*$',line) 184 | if m == None : 185 | raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,confidence") 186 | elif withTranscription: 187 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,(.*)$',line) 188 | if m == None : 189 | raise Exception("Format incorrect. Should be: x1,y1,x2,y2,x3,y3,x4,y4,transcription") 190 | else: 191 | m = re.match(r'^\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*,\s*(-?[0-9]+)\s*$',line) 192 | if m == None : 193 | raise Exception("Format incorrect. 
Should be: x1,y1,x2,y2,x3,y3,x4,y4") 194 | 195 | points = [ float(m.group(i)) for i in range(1, (numPoints+1) ) ] 196 | 197 | validate_clockwise_points(points) 198 | 199 | if (imWidth>0 and imHeight>0): 200 | validate_point_inside_bounds(points[0],points[1],imWidth,imHeight); 201 | validate_point_inside_bounds(points[2],points[3],imWidth,imHeight); 202 | validate_point_inside_bounds(points[4],points[5],imWidth,imHeight); 203 | validate_point_inside_bounds(points[6],points[7],imWidth,imHeight); 204 | 205 | 206 | if withConfidence: 207 | try: 208 | confidence = float(m.group(numPoints+1)) 209 | except ValueError: 210 | raise Exception("Confidence value must be a float") 211 | 212 | if withTranscription: 213 | posTranscription = numPoints + (2 if withConfidence else 1) 214 | transcription = m.group(posTranscription) 215 | m2 = re.match(r'^\s*\"(.*)\"\s*$',transcription) 216 | if m2 != None : #Transcription with double quotes, we extract the value and replace escaped characters 217 | transcription = m2.group(1).replace("\\\\", "\\").replace("\\\"", "\"") 218 | 219 | return points,confidence,transcription 220 | 221 | 222 | def validate_point_inside_bounds(x,y,imWidth,imHeight): 223 | if(x<0 or x>imWidth): 224 | raise Exception("X value (%s) not valid. Image dimensions: (%s,%s)" %(xmin,imWidth,imHeight)) 225 | if(y<0 or y>imHeight): 226 | raise Exception("Y value (%s) not valid. Image dimensions: (%s,%s) Sample: %s Line:%s" %(ymin,imWidth,imHeight)) 227 | 228 | def validate_clockwise_points(points): 229 | """ 230 | Validates that the points that the 4 points that dlimite a polygon are in clockwise order. 231 | """ 232 | 233 | if len(points) != 8: 234 | raise Exception("Points list not valid." + str(len(points))) 235 | 236 | point = [ 237 | [int(points[0]) , int(points[1])], 238 | [int(points[2]) , int(points[3])], 239 | [int(points[4]) , int(points[5])], 240 | [int(points[6]) , int(points[7])] 241 | ] 242 | edge = [ 243 | ( point[1][0] - point[0][0])*( point[1][1] + point[0][1]), 244 | ( point[2][0] - point[1][0])*( point[2][1] + point[1][1]), 245 | ( point[3][0] - point[2][0])*( point[3][1] + point[2][1]), 246 | ( point[0][0] - point[3][0])*( point[0][1] + point[3][1]) 247 | ] 248 | 249 | summatory = edge[0] + edge[1] + edge[2] + edge[3]; 250 | if summatory>0: 251 | raise Exception("Points are not clockwise. The coordinates of bounding quadrilaterals have to be given in clockwise order. Regarding the correct interpretation of 'clockwise' remember that the image coordinate system used is the standard one, with the image origin at the upper left, the X axis extending to the right and Y axis extending downwards.") 252 | 253 | def get_tl_line_values_from_file_contents(content,CRLF=True,LTRB=True,withTranscription=False,withConfidence=False,imWidth=0,imHeight=0,sort_by_confidences=True): 254 | """ 255 | Returns all points, confindences and transcriptions of a file in lists. 
Valid line formats: 256 | xmin,ymin,xmax,ymax,[confidence],[transcription] 257 | x1,y1,x2,y2,x3,y3,x4,y4,[confidence],[transcription] 258 | """ 259 | pointsList = [] 260 | transcriptionsList = [] 261 | confidencesList = [] 262 | 263 | lines = content.split( "\r\n" if CRLF else "\n" ) 264 | for line in lines: 265 | line = line.replace("\r","").replace("\n","") 266 | if(line != "") : 267 | points, confidence, transcription = get_tl_line_values(line,LTRB,withTranscription,withConfidence,imWidth,imHeight); 268 | pointsList.append(points) 269 | transcriptionsList.append(transcription) 270 | confidencesList.append(confidence) 271 | 272 | if withConfidence and len(confidencesList)>0 and sort_by_confidences: 273 | import numpy as np 274 | sorted_ind = np.argsort(-np.array(confidencesList)) 275 | confidencesList = [confidencesList[i] for i in sorted_ind] 276 | pointsList = [pointsList[i] for i in sorted_ind] 277 | transcriptionsList = [transcriptionsList[i] for i in sorted_ind] 278 | 279 | return pointsList,confidencesList,transcriptionsList 280 | 281 | def main_evaluation(p,default_evaluation_params_fn,validate_data_fn,evaluate_method_fn,show_result=True,per_sample=True): 282 | """ 283 | This process validates a method, evaluates it and if it succed generates a ZIP file with a JSON entry for each sample. 284 | Params: 285 | p: Dictionary of parmeters with the GT/submission locations. If None is passed, the parameters send by the system are used. 286 | default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation 287 | validate_data_fn: points to a method that validates the corrct format of the submission 288 | evaluate_method_fn: points to a function that evaluated the submission and return a Dictionary with the results 289 | """ 290 | 291 | if (p == None): 292 | p = dict([s[1:].split('=') for s in sys.argv[1:]]) 293 | if(len(sys.argv)<3): 294 | print_help() 295 | 296 | evalParams = default_evaluation_params_fn() 297 | if 'p' in p.keys(): 298 | evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) 299 | 300 | resDict={'calculated':True,'Message':'','method':'{}','per_sample':'{}'} 301 | try: 302 | validate_data_fn(p['g'], p['s'], evalParams) 303 | evalData = evaluate_method_fn(p['g'], p['s'], evalParams) 304 | resDict.update(evalData) 305 | 306 | except Exception, e: 307 | resDict['Message']= str(e) 308 | resDict['calculated']=False 309 | 310 | if 'o' in p: 311 | if not os.path.exists(p['o']): 312 | os.makedirs(p['o']) 313 | 314 | resultsOutputname = p['o'] + '/results.zip' 315 | outZip = zipfile.ZipFile(resultsOutputname, mode='w', allowZip64=True) 316 | 317 | del resDict['per_sample'] 318 | if 'output_items' in resDict.keys(): 319 | del resDict['output_items'] 320 | 321 | outZip.writestr('method.json',json.dumps(resDict)) 322 | 323 | if not resDict['calculated']: 324 | if show_result: 325 | sys.stderr.write('Error!\n'+ resDict['Message']+'\n\n') 326 | if 'o' in p: 327 | outZip.close() 328 | return resDict 329 | 330 | if 'o' in p: 331 | if per_sample == True: 332 | for k,v in evalData['per_sample'].iteritems(): 333 | outZip.writestr( k + '.json',json.dumps(v)) 334 | 335 | if 'output_items' in evalData.keys(): 336 | for k, v in evalData['output_items'].iteritems(): 337 | outZip.writestr( k,v) 338 | 339 | outZip.close() 340 | 341 | if show_result: 342 | sys.stdout.write("Calculated!") 343 | sys.stdout.write(json.dumps(resDict['method'])) 344 | 345 | return resDict 346 | 347 | 348 | def 
main_validation(default_evaluation_params_fn,validate_data_fn): 349 | """ 350 | This process validates a method 351 | Params: 352 | default_evaluation_params_fn: points to a function that returns a dictionary with the default parameters used for the evaluation 353 | validate_data_fn: points to a method that validates the corrct format of the submission 354 | """ 355 | try: 356 | p = dict([s[1:].split('=') for s in sys.argv[1:]]) 357 | evalParams = default_evaluation_params_fn() 358 | if 'p' in p.keys(): 359 | evalParams.update( p['p'] if isinstance(p['p'], dict) else json.loads(p['p'][1:-1]) ) 360 | 361 | validate_data_fn(p['g'], p['s'], evalParams) 362 | print 'SUCCESS' 363 | sys.exit(0) 364 | except Exception as e: 365 | print str(e) 366 | sys.exit(101) -------------------------------------------------------------------------------- /eval/ic15/script.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from collections import namedtuple 4 | import rrc_evaluation_funcs 5 | import importlib 6 | 7 | def evaluation_imports(): 8 | """ 9 | evaluation_imports: Dictionary ( key = module name , value = alias ) with python modules used in the evaluation. 10 | """ 11 | return { 12 | 'Polygon':'plg', 13 | 'numpy':'np' 14 | } 15 | 16 | def default_evaluation_params(): 17 | """ 18 | default_evaluation_params: Default parameters to use for the validation and evaluation. 19 | """ 20 | return { 21 | 'IOU_CONSTRAINT' :0.5, 22 | 'AREA_PRECISION_CONSTRAINT' :0.5, 23 | 'GT_SAMPLE_NAME_2_ID':'gt_img_([0-9]+).txt', 24 | 'DET_SAMPLE_NAME_2_ID':'res_img_([0-9]+).txt', 25 | 'LTRB':False, #LTRB:2points(left,top,right,bottom) or 4 points(x1,y1,x2,y2,x3,y3,x4,y4) 26 | 'CRLF':False, # Lines are delimited by Windows CRLF format 27 | 'CONFIDENCES':False, #Detections must include confidence value. AP will be calculated 28 | 'PER_SAMPLE_RESULTS':True #Generate per sample results and produce data for visualization 29 | } 30 | 31 | def validate_data(gtFilePath, submFilePath,evaluationParams): 32 | """ 33 | Method validate_data: validates that all files in the results folder are correct (have the correct name contents). 34 | Validates also that there are no missing files in the folder. 35 | If some error detected, the method raises the error 36 | """ 37 | gt = rrc_evaluation_funcs.load_zip_file(gtFilePath,evaluationParams['GT_SAMPLE_NAME_2_ID']) 38 | 39 | subm = rrc_evaluation_funcs.load_zip_file(submFilePath,evaluationParams['DET_SAMPLE_NAME_2_ID'],True) 40 | 41 | #Validate format of GroundTruth 42 | for k in gt: 43 | rrc_evaluation_funcs.validate_lines_in_file(k,gt[k],evaluationParams['CRLF'],evaluationParams['LTRB'],True) 44 | 45 | #Validate format of results 46 | for k in subm: 47 | if (k in gt) == False : 48 | raise Exception("The sample %s not present in GT" %k) 49 | 50 | rrc_evaluation_funcs.validate_lines_in_file(k,subm[k],evaluationParams['CRLF'],evaluationParams['LTRB'],False,evaluationParams['CONFIDENCES']) 51 | 52 | 53 | def evaluate_method(gtFilePath, submFilePath, evaluationParams): 54 | """ 55 | Method evaluate_method: evaluate method and returns the results 56 | Results. Dictionary with the following values: 57 | - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 } 58 | - samples (optional) Per sample metrics. 
Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } 59 | """ 60 | 61 | for module,alias in evaluation_imports().iteritems(): 62 | globals()[alias] = importlib.import_module(module) 63 | 64 | def polygon_from_points(points): 65 | """ 66 | Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4 67 | """ 68 | # resBoxes=np.empty([1,8],dtype='int32') 69 | # resBoxes[0,0]=int(points[0]) 70 | # resBoxes[0,4]=int(points[1]) 71 | # resBoxes[0,1]=int(points[2]) 72 | # resBoxes[0,5]=int(points[3]) 73 | # resBoxes[0,2]=int(points[4]) 74 | # resBoxes[0,6]=int(points[5]) 75 | # resBoxes[0,3]=int(points[6]) 76 | # resBoxes[0,7]=int(points[7]) 77 | # pointMat = resBoxes[0].reshape([2,4]).T 78 | # return plg.Polygon( pointMat) 79 | 80 | p = np.array(points) 81 | p = p.reshape(p.shape[0]//2, 2) 82 | p = plg.Polygon(p) 83 | return p 84 | 85 | def rectangle_to_polygon(rect): 86 | resBoxes=np.empty([1,8],dtype='int32') 87 | resBoxes[0,0]=int(rect.xmin) 88 | resBoxes[0,4]=int(rect.ymax) 89 | resBoxes[0,1]=int(rect.xmin) 90 | resBoxes[0,5]=int(rect.ymin) 91 | resBoxes[0,2]=int(rect.xmax) 92 | resBoxes[0,6]=int(rect.ymin) 93 | resBoxes[0,3]=int(rect.xmax) 94 | resBoxes[0,7]=int(rect.ymax) 95 | 96 | pointMat = resBoxes[0].reshape([2,4]).T 97 | 98 | return plg.Polygon( pointMat) 99 | 100 | def rectangle_to_points(rect): 101 | points = [int(rect.xmin), int(rect.ymax), int(rect.xmax), int(rect.ymax), int(rect.xmax), int(rect.ymin), int(rect.xmin), int(rect.ymin)] 102 | return points 103 | 104 | def get_union(pD,pG): 105 | areaA = pD.area(); 106 | areaB = pG.area(); 107 | return areaA + areaB - get_intersection(pD, pG); 108 | 109 | def get_intersection_over_union(pD,pG): 110 | try: 111 | return get_intersection(pD, pG) / get_union(pD, pG); 112 | except: 113 | return 0 114 | 115 | def get_intersection(pD,pG): 116 | pInt = pD & pG 117 | if len(pInt) == 0: 118 | return 0 119 | return pInt.area() 120 | 121 | def compute_ap(confList, matchList,numGtCare): 122 | correct = 0 123 | AP = 0 124 | if len(confList)>0: 125 | confList = np.array(confList) 126 | matchList = np.array(matchList) 127 | sorted_ind = np.argsort(-confList) 128 | confList = confList[sorted_ind] 129 | matchList = matchList[sorted_ind] 130 | for n in range(len(confList)): 131 | match = matchList[n] 132 | if match: 133 | correct += 1 134 | AP += float(correct)/(n + 1) 135 | 136 | if numGtCare>0: 137 | AP /= numGtCare 138 | 139 | return AP 140 | 141 | perSampleMetrics = {} 142 | 143 | matchedSum = 0 144 | 145 | Rectangle = namedtuple('Rectangle', 'xmin ymin xmax ymax') 146 | 147 | gt = rrc_evaluation_funcs.load_zip_file(gtFilePath,evaluationParams['GT_SAMPLE_NAME_2_ID']) 148 | subm = rrc_evaluation_funcs.load_zip_file(submFilePath,evaluationParams['DET_SAMPLE_NAME_2_ID'],True) 149 | 150 | numGlobalCareGt = 0; 151 | numGlobalCareDet = 0; 152 | 153 | arrGlobalConfidences = []; 154 | arrGlobalMatches = []; 155 | 156 | for resFile in gt: 157 | 158 | gtFile = rrc_evaluation_funcs.decode_utf8(gt[resFile]) 159 | recall = 0 160 | precision = 0 161 | hmean = 0 162 | 163 | detMatched = 0 164 | 165 | iouMat = np.empty([1,1]) 166 | 167 | gtPols = [] 168 | detPols = [] 169 | 170 | gtPolPoints = [] 171 | detPolPoints = [] 172 | 173 | #Array of Ground Truth Polygons' keys marked as don't Care 174 | gtDontCarePolsNum = [] 175 | #Array of Detected Polygons' matched with a don't Care GT 176 | detDontCarePolsNum = [] 177 | 178 | pairs = [] 179 | detMatchedNums = [] 180 | 181 | 
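        # (added note) per-detection confidences and match flags; these are only
        # filled when evaluationParams['CONFIDENCES'] is set and are consumed by
        # compute_ap() below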
arrSampleConfidences = []; 182 | arrSampleMatch = []; 183 | sampleAP = 0; 184 | 185 | evaluationLog = "" 186 | 187 | pointsList,_,transcriptionsList = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(gtFile,evaluationParams['CRLF'],evaluationParams['LTRB'],True,False) 188 | for n in range(len(pointsList)): 189 | points = pointsList[n] 190 | transcription = transcriptionsList[n] 191 | dontCare = transcription == "###" 192 | if evaluationParams['LTRB']: 193 | gtRect = Rectangle(*points) 194 | gtPol = rectangle_to_polygon(gtRect) 195 | else: 196 | gtPol = polygon_from_points(points) 197 | gtPols.append(gtPol) 198 | gtPolPoints.append(points) 199 | if dontCare: 200 | gtDontCarePolsNum.append( len(gtPols)-1 ) 201 | 202 | evaluationLog += "GT polygons: " + str(len(gtPols)) + (" (" + str(len(gtDontCarePolsNum)) + " don't care)\n" if len(gtDontCarePolsNum)>0 else "\n") 203 | 204 | if resFile in subm: 205 | 206 | detFile = rrc_evaluation_funcs.decode_utf8(subm[resFile]) 207 | def get_pred(file): 208 | lines = file.split('\n') 209 | pointsList = [] 210 | for line in lines: 211 | if line == '': 212 | continue 213 | bbox = line.split(',') 214 | if len(bbox) % 2 == 1: 215 | print(path) 216 | bbox = [int(x) for x in bbox] 217 | pointsList.append(bbox) 218 | return pointsList 219 | 220 | # pointsList,confidencesList,_ = rrc_evaluation_funcs.get_tl_line_values_from_file_contents(detFile,evaluationParams['CRLF'],evaluationParams['LTRB'],False,evaluationParams['CONFIDENCES']) 221 | # print(pointsList) 222 | # print(confidencesList) 223 | 224 | pointsList = get_pred(detFile) 225 | confidencesList = [0.0] * len(pointsList) 226 | 227 | for n in range(len(pointsList)): 228 | points = pointsList[n] 229 | 230 | if evaluationParams['LTRB']: 231 | detRect = Rectangle(*points) 232 | detPol = rectangle_to_polygon(detRect) 233 | else: 234 | detPol = polygon_from_points(points) 235 | detPols.append(detPol) 236 | detPolPoints.append(points) 237 | if len(gtDontCarePolsNum)>0 : 238 | for dontCarePol in gtDontCarePolsNum: 239 | dontCarePol = gtPols[dontCarePol] 240 | intersected_area = get_intersection(dontCarePol,detPol) 241 | pdDimensions = detPol.area() 242 | precision = 0 if pdDimensions == 0 else intersected_area / pdDimensions 243 | if (precision > evaluationParams['AREA_PRECISION_CONSTRAINT'] ): 244 | detDontCarePolsNum.append( len(detPols)-1 ) 245 | break 246 | 247 | evaluationLog += "DET polygons: " + str(len(detPols)) + (" (" + str(len(detDontCarePolsNum)) + " don't care)\n" if len(detDontCarePolsNum)>0 else "\n") 248 | 249 | if len(gtPols)>0 and len(detPols)>0: 250 | #Calculate IoU and precision matrixs 251 | outputShape=[len(gtPols),len(detPols)] 252 | iouMat = np.empty(outputShape) 253 | gtRectMat = np.zeros(len(gtPols),np.int8) 254 | detRectMat = np.zeros(len(detPols),np.int8) 255 | for gtNum in range(len(gtPols)): 256 | for detNum in range(len(detPols)): 257 | pG = gtPols[gtNum] 258 | pD = detPols[detNum] 259 | iouMat[gtNum,detNum] = get_intersection_over_union(pD,pG) 260 | 261 | for gtNum in range(len(gtPols)): 262 | for detNum in range(len(detPols)): 263 | if gtRectMat[gtNum] == 0 and detRectMat[detNum] == 0 and gtNum not in gtDontCarePolsNum and detNum not in detDontCarePolsNum : 264 | if iouMat[gtNum,detNum]>evaluationParams['IOU_CONSTRAINT']: 265 | gtRectMat[gtNum] = 1 266 | detRectMat[detNum] = 1 267 | detMatched += 1 268 | pairs.append({'gt':gtNum,'det':detNum}) 269 | detMatchedNums.append(detNum) 270 | evaluationLog += "Match GT #" + str(gtNum) + " with Det #" + str(detNum) + "\n" 271 | 272 
| if evaluationParams['CONFIDENCES']: 273 | for detNum in range(len(detPols)): 274 | if detNum not in detDontCarePolsNum : 275 | #we exclude the don't care detections 276 | match = detNum in detMatchedNums 277 | 278 | arrSampleConfidences.append(confidencesList[detNum]) 279 | arrSampleMatch.append(match) 280 | 281 | arrGlobalConfidences.append(confidencesList[detNum]); 282 | arrGlobalMatches.append(match); 283 | 284 | numGtCare = (len(gtPols) - len(gtDontCarePolsNum)) 285 | numDetCare = (len(detPols) - len(detDontCarePolsNum)) 286 | if numGtCare == 0: 287 | recall = float(1) 288 | precision = float(0) if numDetCare >0 else float(1) 289 | sampleAP = precision 290 | else: 291 | recall = float(detMatched) / numGtCare 292 | precision = 0 if numDetCare==0 else float(detMatched) / numDetCare 293 | if evaluationParams['CONFIDENCES'] and evaluationParams['PER_SAMPLE_RESULTS']: 294 | sampleAP = compute_ap(arrSampleConfidences, arrSampleMatch, numGtCare ) 295 | 296 | hmean = 0 if (precision + recall)==0 else 2.0 * precision * recall / (precision + recall) 297 | 298 | matchedSum += detMatched 299 | numGlobalCareGt += numGtCare 300 | numGlobalCareDet += numDetCare 301 | 302 | if evaluationParams['PER_SAMPLE_RESULTS']: 303 | perSampleMetrics[resFile] = { 304 | 'precision':precision, 305 | 'recall':recall, 306 | 'hmean':hmean, 307 | 'pairs':pairs, 308 | 'AP':sampleAP, 309 | 'iouMat':[] if len(detPols)>100 else iouMat.tolist(), 310 | 'gtPolPoints':gtPolPoints, 311 | 'detPolPoints':detPolPoints, 312 | 'gtDontCare':gtDontCarePolsNum, 313 | 'detDontCare':detDontCarePolsNum, 314 | 'evaluationParams': evaluationParams, 315 | 'evaluationLog': evaluationLog 316 | } 317 | 318 | # Compute MAP and MAR 319 | AP = 0 320 | if evaluationParams['CONFIDENCES']: 321 | AP = compute_ap(arrGlobalConfidences, arrGlobalMatches, numGlobalCareGt) 322 | 323 | methodRecall = 0 if numGlobalCareGt == 0 else float(matchedSum)/numGlobalCareGt 324 | methodPrecision = 0 if numGlobalCareDet == 0 else float(matchedSum)/numGlobalCareDet 325 | methodHmean = 0 if methodRecall + methodPrecision==0 else 2* methodRecall * methodPrecision / (methodRecall + methodPrecision) 326 | 327 | methodMetrics = {'precision':methodPrecision, 'recall':methodRecall,'hmean': methodHmean, 'AP': AP } 328 | 329 | resDict = {'calculated':True,'Message':'','method': methodMetrics,'per_sample': perSampleMetrics} 330 | 331 | return resDict 332 | 333 | 334 | 335 | if __name__=='__main__': 336 | 337 | rrc_evaluation_funcs.main_evaluation(None,default_evaluation_params,validate_data,evaluate_method) 338 | print('') 339 | -------------------------------------------------------------------------------- /eval/ic15/script_self_adapt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import mmcv 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser(description='Hyperparams') 6 | # parser.add_argument('--gt', nargs='?', type=str, default=None) 7 | parser.add_argument('--pred', nargs='?', type=str, default=None) 8 | args = parser.parse_args() 9 | 10 | output_root = '../outputs/tmp_results/' 11 | pred = mmcv.load(args.pred) 12 | 13 | def write_result_as_txt(image_name, bboxes, path, words=None): 14 | if not os.path.exists(path): 15 | os.makedirs(path) 16 | 17 | file_path = path + 'res_%s.txt'%(image_name) 18 | lines = [] 19 | for i, bbox in enumerate(bboxes): 20 | values = [int(v) for v in bbox] 21 | if words is None: 22 | line = "%d,%d,%d,%d,%d,%d,%d,%d\n"%tuple(values) 23 | lines.append(line) 24 | elif words[i] is 
not None: 25 | line = "%d,%d,%d,%d,%d,%d,%d,%d"%tuple(values) + ",%s\n"%words[i] 26 | lines.append(line) 27 | with open(file_path, 'w') as f: 28 | for line in lines: 29 | f.write(line) 30 | 31 | def eval(thr): 32 | for key in pred: 33 | pred_ = pred[key] 34 | line_num = len(pred_['scores']) 35 | bboxes = [] 36 | # words = [] 37 | for i in range(line_num): 38 | if pred_['scores'][i] < thr: 39 | continue 40 | bboxes.append(pred_['bboxes'][i]) 41 | # words.append(pred_['words'][i]) 42 | 43 | write_result_as_txt(key, bboxes, output_root) 44 | 45 | cmd = 'cd %s;zip -j %s %s/*' % ('../outputs/', 'tmp_results.zip', 'tmp_results') 46 | res_cmd = os.popen(cmd) 47 | res_cmd.read() 48 | 49 | cmd = 'cd ic15 && python2 script.py -g=gt.zip -s=../../outputs/tmp_results.zip && cd ..' 50 | res_cmd = os.popen(cmd) 51 | res_cmd = res_cmd.read() 52 | h_mean = float(res_cmd.split(',')[-2].split(':')[-1]) 53 | return res_cmd, h_mean 54 | 55 | max_h_mean = 0 56 | best_thr = 0 57 | best_res = '' 58 | for i in range(85, 100): 59 | thr = float(i) / 100 60 | # print('Testing thr: %f'%thr) 61 | res, h_mean = eval(thr) 62 | # print(thr, h_mean) 63 | if h_mean > max_h_mean: 64 | max_h_mean = h_mean 65 | best_thr = thr 66 | best_res = res 67 | 68 | print('thr: %f | %s'%(best_thr, best_res)) 69 | -------------------------------------------------------------------------------- /eval/ic15_end2end_rec/gt.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ssocean/AlphX-Code-For-DAR/6be8c05c12d373245c97f8434d3e5abe9f8a4550/eval/ic15_end2end_rec/gt.zip -------------------------------------------------------------------------------- /eval/ic15_end2end_rec/readme.txt: -------------------------------------------------------------------------------- 1 | INSTRUCTIONS FOR THE STANDALONE SCRIPTS 2 | Requirements: 3 | - Python version 3. 4 | - Each Task requires different Python modules. When running the script, if some module is not installed you will see a notification and installation instructions. 5 | 6 | Procedure: 7 | Download the ZIP file for the requested script and unzip it to a directory. 8 | 9 | Open a terminal in the directory and run the command: 10 | python script.py -g=gt.zip -s=submit.zip 11 | 12 | If you have already installed all the required modules, then you will see the method’s results or an error message if the submitted file is not correct. 13 | 14 | If a module is not present, you should install it with pip: pip install 'module' 15 | 16 | In case of Polygon module, use: 'pip install Polygon3' 17 | 18 | parameters: 19 | -g: Path of the Ground Truth file. In most cases, the Ground Truth will be included in the same Zip file named 'gt.zip', 'gt.txt' or 'gt.json'. If not, you will be able to get it on the Downloads page of the Task. 20 | -s: Path of your method's results file. 21 | 22 | Optional parameters: 23 | -o: Path to a directory where to copy the file ‘results.zip’ that contains per-sample results. 24 | -p: JSON string parameters to override the script default parameters. The parameters that can be overridden are inside the function 'default_evaluation_params' located at the beginning of the evaluation Script.
25 | 26 | Example: python script.py -g=gt.zip -s=submit.zip -o=./ -p={\"IOU_CONSTRAINT\":0.8} -------------------------------------------------------------------------------- /eval/ic15_end2end_rec/script_self_adapt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import mmcv 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser(description='Hyperparams') 6 | # parser.add_argument('--gt', nargs='?', type=str, default=None) 7 | parser.add_argument('--pred', nargs='?', type=str, default=None) 8 | args = parser.parse_args() 9 | 10 | output_root = '../outputs/tmp_results/' 11 | pred = mmcv.load(args.pred) 12 | 13 | def write_result_as_txt(image_name, bboxes, path, words=None): 14 | if not os.path.exists(path): 15 | os.makedirs(path) 16 | 17 | file_path = path + 'res_%s.txt'%(image_name) 18 | lines = [] 19 | for i, bbox in enumerate(bboxes): 20 | values = [int(v) for v in bbox] 21 | if words is None: 22 | line = "%d,%d,%d,%d,%d,%d,%d,%d\n"%tuple(values) 23 | lines.append(line) 24 | elif words[i] is not None: 25 | line = "%d,%d,%d,%d,%d,%d,%d,%d"%tuple(values) + ",%s\n"%words[i] 26 | lines.append(line) 27 | with open(file_path, 'w') as f: 28 | for line in lines: 29 | f.write(line.encode('utf-8')) 30 | 31 | def eval(thr): 32 | for key in pred: 33 | pred_ = pred[key] 34 | line_num = len(pred_['scores']) 35 | bboxes = [] 36 | words = [] 37 | for i in range(line_num): 38 | if pred_['word_scores'][i] < thr: 39 | continue 40 | bboxes.append(pred_['bboxes'][i]) 41 | words.append(pred_['words'][i]) 42 | 43 | write_result_as_txt(key, bboxes, output_root, words) 44 | 45 | cmd = 'cd %s;zip -j %s %s/*' % ('../outputs/', 'tmp_results.zip', 'tmp_results') 46 | res_cmd = os.popen(cmd) 47 | res_cmd.read() 48 | 49 | cmd = 'cd ic15_end2end_rec && python2 script.py -g=gt.zip -s=../../outputs/tmp_results.zip && cd ..' 50 | res_cmd = os.popen(cmd) 51 | res_cmd = res_cmd.read() 52 | h_mean = float(res_cmd.split(',')[-2].split(':')[-1]) 53 | return res_cmd, h_mean 54 | 55 | max_h_mean = 0 56 | best_thr = 0 57 | best_res = '' 58 | for i in range(80, 100): 59 | thr = float(i) / 100 60 | # print('Testing thr: %f'%thr) 61 | res, h_mean = eval(thr) 62 | if h_mean >= max_h_mean: 63 | max_h_mean = h_mean 64 | best_thr = thr 65 | best_res = res 66 | 67 | print('thr: %f | %s'%(best_thr, best_res)) 68 |
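For local debugging it helps to know the input this self-adaptive threshold search consumes: `mmcv.load` reads a dict keyed by image name whose `scores`, `word_scores`, `bboxes` and `words` fields are exactly the ones accessed above. A minimal sketch of such a file (field names follow the script; every value below is made up for illustration):

```python
# Illustrative prediction dump for script_self_adapt.py; only the field
# names are taken from the script above, the values are invented.
import mmcv

pred = {
    'img_1': {
        'scores': [0.95],        # detection confidence per text line
        'word_scores': [0.90],   # recognition confidence per text line
        'bboxes': [[377, 117, 463, 117, 465, 130, 378, 130]],  # 4 corners
        'words': ['FOOTPATH'],
    },
}
mmcv.dump(pred, 'pred.json')
# then: python script_self_adapt.py --pred pred.json
```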
-------------------------------------------------------------------------------- /eval/ic15_word_spotting/gt.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ssocean/AlphX-Code-For-DAR/6be8c05c12d373245c97f8434d3e5abe9f8a4550/eval/ic15_word_spotting/gt.zip -------------------------------------------------------------------------------- /eval/ic15_word_spotting/readme.txt: -------------------------------------------------------------------------------- 1 | INSTRUCTIONS FOR THE STANDALONE SCRIPTS 2 | Requirements: 3 | - Python version 3. 4 | - Each Task requires different Python modules. When running the script, if some module is not installed you will see a notification and installation instructions. 5 | 6 | Procedure: 7 | Download the ZIP file for the requested script and unzip it to a directory. 8 | 9 | Open a terminal in the directory and run the command: 10 | python script.py -g=gt.zip -s=submit.zip 11 | 12 | If you have already installed all the required modules, then you will see the method’s results or an error message if the submitted file is not correct. 13 | 14 | If a module is not present, you should install it with pip: pip install 'module' 15 | 16 | In case of Polygon module, use: 'pip install Polygon3' 17 | 18 | parameters: 19 | -g: Path of the Ground Truth file. In most cases, the Ground Truth will be included in the same Zip file named 'gt.zip', 'gt.txt' or 'gt.json'. If not, you will be able to get it on the Downloads page of the Task. 20 | -s: Path of your method's results file. 21 | 22 | Optional parameters: 23 | -o: Path to a directory where to copy the file ‘results.zip’ that contains per-sample results. 24 | -p: JSON string parameters to override the script default parameters. The parameters that can be overridden are inside the function 'default_evaluation_params' located at the beginning of the evaluation Script.
25 | 26 | Example: python script.py -g=gt.zip -s=submit.zip -o=./ -p={\"IOU_CONSTRAINT\":0.8} -------------------------------------------------------------------------------- /eval/msra/eval.py: -------------------------------------------------------------------------------- 1 | import file_util 2 | import Polygon as plg 3 | import numpy as np 4 | import math 5 | import cv2 6 | 7 | project_root = '../../' 8 | 9 | pred_root = project_root + 'outputs/submit_msra/' 10 | gt_root = project_root + 'data/MSRA-TD500/test/' 11 | 12 | 13 | def get_pred(path): 14 | lines = file_util.read_file(path).split('\n') 15 | bboxes = [] 16 | for line in lines: 17 | if line == '': 18 | continue 19 | bbox = line.split(',') 20 | if len(bbox) % 2 == 1: 21 | print(path) 22 | bbox = [int(x) for x in bbox] 23 | bboxes.append(bbox) 24 | return bboxes 25 | 26 | 27 | def get_gt(path): 28 | lines = file_util.read_file(path).split('\n') 29 | bboxes = [] 30 | tags = [] 31 | for line in lines: 32 | if line == '': 33 | continue 34 | # line = util.str.remove_all(line, '\xef\xbb\xbf') 35 | # gt = util.str.split(line, ' ') 36 | gt = line.split(' ') 37 | 38 | w_ = float(gt[4]) 39 | h_ = float(gt[5]) 40 | x1 = float(gt[2]) + w_ / 2.0 41 | y1 = float(gt[3]) + h_ / 2.0 42 | theta = float(gt[6]) / math.pi * 180 43 | 44 | bbox = cv2.boxPoints(((x1, y1), (w_, h_), theta)) 45 | bbox = bbox.reshape(-1) 46 | 47 | bboxes.append(bbox) 48 | tags.append(int(gt[1])) 49 | return np.array(bboxes), tags 50 | 51 | 52 | def get_union(pD, pG): 53 | areaA = pD.area() 54 | areaB = pG.area() 55 | return areaA + areaB - get_intersection(pD, pG) 56 | 57 | 58 | def get_intersection(pD, pG): 59 | pInt = pD & pG 60 | if len(pInt) == 0: 61 | return 0 62 | return pInt.area() 63 | 64 | 65 | if __name__ == '__main__': 66 | th = 0.5 67 | pred_list = file_util.read_dir(pred_root) 68 | 69 | count, tp, fp, tn, ta = 0, 0, 0, 0, 0 70 | for pred_path in pred_list: 71 | count = count + 1 72 | preds = get_pred(pred_path) 73 | gt_path = gt_root + pred_path.split('/')[-1].split('.')[0] + '.gt' 74 | gts, tags = get_gt(gt_path) 75 | 76 | ta = ta + len(preds) 77 | for gt, tag in zip(gts, tags): 78 | gt = np.array(gt) 79 | gt = gt.reshape(gt.shape[0] // 2, 2) 80 | gt_p = plg.Polygon(gt) 81 | difficult = tag 82 | flag = 0 83 | for pred in preds: 84 | pred = np.array(pred) 85 | pred = pred.reshape(pred.shape[0] // 2, 2) 86 | pred_p = plg.Polygon(pred) 87 | 88 | union = get_union(pred_p, gt_p) 89 | inter = get_intersection(pred_p, gt_p) 90 | iou = float(inter) / union 91 | if iou >= th: 92 | flag = 1 93 | tp = tp + 1 94 | break 95 | 96 | if flag == 0 and difficult == 0: 97 | fp = fp + 1 98 | 99 | recall = float(tp) / (tp + fp) 100 | precision = float(tp) / ta 101 | hmean = 0 if (precision + recall) == 0 else 2.0 * precision * recall / (precision + recall) 102 | 103 | print('p: %.4f, r: %.4f, f: %.4f' % (precision, recall, hmean)) 104 |
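`get_gt` above decodes the standard MSRA-TD500 annotation format: each line stores `index difficult x y w h theta`, where `(x, y)` is the top-left corner and `theta` is in radians, and `cv2.boxPoints` recovers the four rotated corners from the centre/size/angle triple. A small worked example with made-up values:

```python
import math

import cv2

line = '0 0 100 200 300 50 0.1'       # index difficult x y w h theta(rad)
gt = line.split(' ')
w_, h_ = float(gt[4]), float(gt[5])
cx = float(gt[2]) + w_ / 2.0          # top-left corner -> box centre
cy = float(gt[3]) + h_ / 2.0
theta = float(gt[6]) / math.pi * 180  # radians -> degrees for OpenCV
corners = cv2.boxPoints(((cx, cy), (w_, h_), theta)).reshape(-1)
print(corners)                        # x1,y1,x2,y2,x3,y3,x4,y4
```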
-------------------------------------------------------------------------------- /eval/msra/file_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def read_dir(root): 4 | file_path_list = [] 5 | for file_path, dirs, files in os.walk(root): 6 | for file in files: 7 | file_path_list.append(os.path.join(file_path, file).replace('\\', '/')) 8 | file_path_list.sort() 9 | return file_path_list 10 | 11 | def read_file(file_path): 12 | file_object = open(file_path, 'r') 13 | file_content = file_object.read() 14 | file_object.close() 15 | return file_content 16 | 17 | def write_file(file_path, file_content): 18 | if file_path.find('/') != -1: 19 | father_dir = '/'.join(file_path.split('/')[0:-1]) 20 | if not os.path.exists(father_dir): 21 | os.makedirs(father_dir) 22 | file_object = open(file_path, 'w') 23 | file_object.write(file_content) 24 | file_object.close() 25 | 26 | 27 | def write_file_not_cover(file_path, file_content): 28 | father_dir = '/'.join(file_path.split('/')[0:-1]) 29 | if not os.path.exists(father_dir): 30 | os.makedirs(father_dir) 31 | file_object = open(file_path, 'a') 32 | file_object.write(file_content) 33 | file_object.close() -------------------------------------------------------------------------------- /eval/tt/polygon_wrapper.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from skimage.draw import polygon 4 | 5 | """ 6 | :param det_x: [1, N] Xs of detection's vertices 7 | :param det_y: [1, N] Ys of detection's vertices 8 | :param gt_x: [1, N] Xs of groundtruth's vertices 9 | :param gt_y: [1, N] Ys of groundtruth's vertices 10 | ############## 11 | All the calculation of 'AREA' in this script is handled by: 12 | 1) First generating a binary mask with the polygon area filled up with 1's 13 | 2) Summing up all the 1's 14 | """ 15 | 16 | 17 | def area(x, y): 18 | """ 19 | This helper calculates the area given x and y vertices. 20 | """ 21 | ymax = np.max(y) 22 | xmax = np.max(x) 23 | bin_mask = np.zeros((ymax + 1, xmax + 1))  # +1 keeps boundary pixels from polygon() in range, as in the helpers below 24 | rr, cc = polygon(y, x) 25 | bin_mask[rr, cc] = 1 26 | area = np.sum(bin_mask) 27 | return area 28 | #return np.round(area, 2) 29 | 30 | 31 | def approx_area_of_intersection(det_x, det_y, gt_x, gt_y): 32 | """ 33 | This helper determines whether both polygons intersect, using an approximation method. 34 | Area of intersection represented by the minimum bounding rectangle [xmin, ymin, xmax, ymax] 35 | """ 36 | det_ymax = np.max(det_y) 37 | det_xmax = np.max(det_x) 38 | det_ymin = np.min(det_y) 39 | det_xmin = np.min(det_x) 40 | 41 | gt_ymax = np.max(gt_y) 42 | gt_xmax = np.max(gt_x) 43 | gt_ymin = np.min(gt_y) 44 | gt_xmin = np.min(gt_x) 45 | 46 | all_min_ymax = np.minimum(det_ymax, gt_ymax) 47 | all_max_ymin = np.maximum(det_ymin, gt_ymin) 48 | 49 | intersect_heights = np.maximum(0.0, (all_min_ymax - all_max_ymin)) 50 | 51 | all_min_xmax = np.minimum(det_xmax, gt_xmax) 52 | all_max_xmin = np.maximum(det_xmin, gt_xmin) 53 | intersect_widths = np.maximum(0.0, (all_min_xmax - all_max_xmin)) 54 | 55 | return intersect_heights * intersect_widths 56 | 57 | def area_of_intersection(det_x, det_y, gt_x, gt_y): 58 | """ 59 | This helper calculates the area of intersection. 60 | """ 61 | if approx_area_of_intersection(det_x, det_y, gt_x, gt_y) > 1: #only proceed if it passes the approximation test 62 | ymax = np.maximum(np.max(det_y), np.max(gt_y)) + 1 63 | xmax = np.maximum(np.max(det_x), np.max(gt_x)) + 1 64 | bin_mask = np.zeros((ymax, xmax)) 65 | det_bin_mask = np.zeros_like(bin_mask) 66 | gt_bin_mask = np.zeros_like(bin_mask) 67 | 68 | rr, cc = polygon(det_y, det_x) 69 | det_bin_mask[rr, cc] = 1 70 | 71 | rr, cc = polygon(gt_y, gt_x) 72 | gt_bin_mask[rr, cc] = 1 73 | 74 | final_bin_mask = det_bin_mask + gt_bin_mask 75 | 76 | inter_map = np.where(final_bin_mask == 2, 1, 0) 77 | inter = np.sum(inter_map) 78 | return inter 79 | # return np.round(inter, 2) 80 | else: 81 | return 0 82 | 83 |
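# A tiny worked example of the mask-based scheme the module docstring
# describes: rasterise each polygon into a binary mask, add the masks, and
# count pixels (the square vertices below are illustrative).
import numpy as np
from skimage.draw import polygon

det_x, det_y = [0, 4, 4, 0], [0, 0, 4, 4]   # detection square
gt_x, gt_y = [2, 6, 6, 2], [2, 2, 6, 6]     # overlapping ground-truth square
mask = np.zeros((7, 7))
rr, cc = polygon(det_y, det_x)
mask[rr, cc] += 1
rr, cc = polygon(gt_y, gt_x)
mask[rr, cc] += 1
inter = int(np.sum(mask == 2))   # pixels covered by both polygons
union = int(np.sum(mask > 0))    # pixels covered by at least one
print(inter, union)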
84 | def iou(det_x, det_y, gt_x, gt_y): 85 | """ 86 | This helper determines the intersection over union of two polygons. 87 | """ 88 | 89 | if approx_area_of_intersection(det_x, det_y, gt_x, gt_y) > 1: #only proceed if it passes the approximation test 90 | ymax = np.maximum(np.max(det_y), np.max(gt_y)) + 1 91 | xmax = np.maximum(np.max(det_x), np.max(gt_x)) + 1 92 | bin_mask = np.zeros((ymax, xmax)) 93 | det_bin_mask = np.zeros_like(bin_mask) 94 | gt_bin_mask = np.zeros_like(bin_mask) 95 | 96 | rr, cc = polygon(det_y, det_x) 97 | det_bin_mask[rr, cc] = 1 98 | 99 | rr, cc = polygon(gt_y, gt_x) 100 | gt_bin_mask[rr, cc] = 1 101 | 102 | final_bin_mask = det_bin_mask + gt_bin_mask 103 | 104 | #inter_map = np.zeros_like(final_bin_mask) 105 | inter_map = np.where(final_bin_mask == 2, 1, 0) 106 | inter = np.sum(inter_map) 107 | 108 | #union_map = np.zeros_like(final_bin_mask) 109 | union_map = np.where(final_bin_mask > 0, 1, 0) 110 | union = np.sum(union_map) 111 | return inter / float(union + 1.0) 112 | #return np.round(inter / float(union + 1.0), 2) 113 | else: 114 | return 0 115 | 116 | def iod(det_x, det_y, gt_x, gt_y): 117 | """ 118 | This helper determines the fraction of the intersection area over the detection area 119 | """ 120 | 121 | if approx_area_of_intersection(det_x, det_y, gt_x, gt_y) > 1: #only proceed if it passes the approximation test 122 | ymax = np.maximum(np.max(det_y), np.max(gt_y)) + 1 123 | xmax = np.maximum(np.max(det_x), np.max(gt_x)) + 1 124 | bin_mask = np.zeros((ymax, xmax)) 125 | det_bin_mask = np.zeros_like(bin_mask) 126 | gt_bin_mask = np.zeros_like(bin_mask) 127 | 128 | rr, cc = polygon(det_y, det_x) 129 | det_bin_mask[rr, cc] = 1 130 | 131 | rr, cc = polygon(gt_y, gt_x) 132 | gt_bin_mask[rr, cc] = 1 133 | 134 | final_bin_mask = det_bin_mask + gt_bin_mask 135 | 136 | inter_map = np.where(final_bin_mask == 2, 1, 0) 137 | inter = np.round(np.sum(inter_map), 2) 138 | 139 | det = np.round(np.sum(det_bin_mask), 2) 140 | return inter / float(det + 1.0) 141 | #return np.round(inter / float(det + 1.0), 2) 142 | else: 143 | return 0 144 | -------------------------------------------------------------------------------- /eval/tt_rec/gt.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ssocean/AlphX-Code-For-DAR/6be8c05c12d373245c97f8434d3e5abe9f8a4550/eval/tt_rec/gt.zip -------------------------------------------------------------------------------- /eval/tt_rec/readme.txt: -------------------------------------------------------------------------------- 1 | INSTRUCTIONS FOR THE STANDALONE SCRIPTS 2 | Requirements: 3 | - Python version 3. 4 | - Each Task requires different Python modules. When running the script, if some module is not installed you will see a notification and installation instructions. 5 | 6 | Procedure: 7 | Download the ZIP file for the requested script and unzip it to a directory. 8 | 9 | Open a terminal in the directory and run the command: 10 | python script.py -g=gt.zip -s=submit.zip 11 | 12 | If you have already installed all the required modules, then you will see the method’s results or an error message if the submitted file is not correct. 13 | 14 | If a module is not present, you should install it with pip: pip install 'module' 15 | 16 | In case of Polygon module, use: 'pip install Polygon3' 17 | 18 | parameters: 19 | -g: Path of the Ground Truth file. In most cases, the Ground Truth will be included in the same Zip file named 'gt.zip', 'gt.txt' or 'gt.json'. If not, you will be able to get it on the Downloads page of the Task. 20 | -s: Path of your method's results file. 21 | 22 | Optional parameters: 23 | -o: Path to a directory where to copy the file ‘results.zip’ that contains per-sample results. 24 | -p: JSON string parameters to override the script default parameters. The parameters that can be overridden are inside the function 'default_evaluation_params' located at the beginning of the evaluation Script. 25 | 26 | Example: python script.py -g=gt.zip -s=submit.zip -o=./ -p={\"IOU_CONSTRAINT\":0.8}
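The `-p` option described in these readmes works the same way in all of the standalone scripts: the JSON string is parsed and merged over the defaults returned by `default_evaluation_params` before evaluation starts. A simplified sketch of that merge (the parameter names are an illustrative subset):

```python
import json

def default_evaluation_params():
    # Illustrative subset; each script defines its own full set.
    return {'IOU_CONSTRAINT': 0.5, 'AREA_PRECISION_CONSTRAINT': 0.5}

params = default_evaluation_params()
params.update(json.loads('{"IOU_CONSTRAINT": 0.8}'))  # string passed via -p
assert params['IOU_CONSTRAINT'] == 0.8
```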
-------------------------------------------------------------------------------- /font/NotoSansCJK-Regular.ttc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ssocean/AlphX-Code-For-DAR/6be8c05c12d373245c97f8434d3e5abe9f8a4550/font/NotoSansCJK-Regular.ttc -------------------------------------------------------------------------------- /font/README.md: -------------------------------------------------------------------------------- 1 | # Font files, used for rendering -------------------------------------------------------------------------------- /font/SIMSUN.TTC: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ssocean/AlphX-Code-For-DAR/6be8c05c12d373245c97f8434d3e5abe9f8a4550/font/SIMSUN.TTC -------------------------------------------------------------------------------- /infer.sh: -------------------------------------------------------------------------------- 1 | INPUT_DIR=IMAGE/FOLDER/PATH 2 | python main.py --input_dir ${INPUT_DIR} -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # from get_image import ImageFetcher # script written by the test organizers; contestants do not need to implement it 4 | # from evaluate import Evaluator # script written by the test organizers for computing metrics; the concrete invocation is not given here, and contestants do not need to implement it 5 | from TestModel import Infer # contestants must wrap the inference model as an Infer class in TestModel.py that provides an eval() method 6 | import glob 7 | from tqdm import tqdm 8 | import argparse 9 | 10 | # from torch.utils.tensorboard import SummaryWriter 11 | def get_files_pth(dir_pth: str, suffix: str = '*'): 12 | ''' 13 | Return a list of absolute paths of the files under dir_pth whose names end with suffix 14 | :param dir_pth: folder path 15 | :param suffix: file suffix to filter by 16 | :return: list of absolute file paths 17 | ''' 18 | rst = [] 19 | glob_pth = os.path.join(dir_pth, f'*.{suffix}') 20 | for filename in glob.glob(glob_pth): 21 | rst.append(filename) 22 | return rst 23 | 24 | # To enable tensorboard, uncomment the related lines 25 | ''' 26 | def init_tensorboard(out_dir: str = 'logs'): 27 | if not os.path.exists(out_dir): ## os.path.exists returns True if the directory exists 28 | os.makedirs(out_dir) 29 | 30 | writer = SummaryWriter(log_dir=out_dir) 31 | """ 32 | https://pytorch.org/docs/stable/tensorboard.html 33 | writer. 34 | add_scalar(tag, scalar_value, global_step=None, walltime=None, new_style=False, double_precision=False) 35 | add_scalars(main_tag, tag_scalar_dict, global_step=None, walltime=None) 36 | add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW') 37 | add_images(tag, img_tensor, global_step=None, walltime=None, dataformats='NCHW') 38 | """ 39 | return writer 40 | ''' 41 | 42 | parser = argparse.ArgumentParser(description="Please type the path of the image folder") 43 | parser.add_argument('-dir', '--input_dir', type=str) 44 | args = parser.parse_args() 45 | 46 | if __name__ == '__main__': 47 | # writer = init_tensorboard('outputs/tblogs') 48 | output_dir = 'outputs/' 49 | os.makedirs(output_dir, exist_ok=True) 50 | 51 | inferencer = Infer() # initialize the model 52 | 53 | # fetcher = ImageFetcher() 54 | # ImageFetcher is an iterator that yields image paths. When writing your own test code, you can simply replace it with a list of test image paths. 55 | # e.g. fetcher = ['images/image_1000.jpg', 'images/image_1001.jpg', 'images/image_1002.jpg', ...] 56 | fetcher = get_files_pth(args.input_dir) 57 | 58 | 59 | for img_path in tqdm(fetcher): 60 | 61 | inferencer.eval(img_path) #,writer 62 | # writer.close() 63 | # Model forward pass. eval() works as follows: take the image path, read and transform the data, feed it to the model for prediction, and obtain all results (text boxes arranged in reading order together with the corresponding text content); 64 | # it must also generate the csv for this image in the outputs folder (same format as round A), otherwise the iterator's output path will not be updated
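As the comments above state, the only contract main.py relies on is a `TestModel.Infer` class whose `eval()` accepts an image path and writes that image's CSV into `outputs/`. A minimal stand-in that satisfies the interface (the real class wires up detection, rectification, ordering and recognition; the row layout below is only a placeholder, not the official CSV schema):

```python
import csv
import os

class Infer:
    """Minimal interface sketch of TestModel.Infer, for illustration only."""

    def eval(self, img_path):
        # rows: (box coordinates ..., recognised text), in reading order
        rows = []
        name = os.path.splitext(os.path.basename(img_path))[0]
        with open(os.path.join('outputs', name + '.csv'), 'w', newline='') as f:
            csv.writer(f).writerows(rows)
```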
-------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_model 2 | from .pan_pp import PAN_PP 3 | 4 | 5 | __all__ = ['PAN_PP', 'build_model'] 6 | -------------------------------------------------------------------------------- /models/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_backbone 2 | from .resnet import resnet18, resnet50, resnet101 3 | 4 | __all__ = ['resnet18', 'resnet50', 'resnet101', 'build_backbone'] 5 | -------------------------------------------------------------------------------- /models/backbone/builder.py: -------------------------------------------------------------------------------- 1 | import models 2 | 3 | 4 | def build_backbone(cfg): 5 | param = dict() 6 | for key in cfg: 7 | if key == 'type': 8 | continue 9 | param[key] = cfg[key] 10 | 11 | backbone = models.backbone.__dict__[cfg.type](**param) 12 | 13 | return backbone 14 | -------------------------------------------------------------------------------- /models/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import sys 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | try: 9 | from urllib import urlretrieve 10 | except ImportError: 11 | from urllib.request import urlretrieve 12 | 13 | __all__ = ['resnet18', 'resnet50', 'resnet101'] 14 | 15 | base_url = 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/' 16 | model_urls = { 17 | 'resnet18': base_url + 'resnet18-imagenet.pth', 18 | 'resnet50': base_url + 'resnet50-imagenet.pth', 19 | 'resnet101': base_url + 'resnet101-imagenet.pth' 20 | } 21 | 22 | 23 | def conv3x3(in_planes, out_planes, stride=1): 24 | """3x3 convolution with padding.""" 25 | return nn.Conv2d(in_planes, 26 | out_planes, 27 | kernel_size=3, 28 | stride=stride, 29 | padding=1, 30 | bias=False) 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | expansion = 1 35 | 36 | def __init__(self, inplanes, planes, stride=1, downsample=None): 37 | super(BasicBlock, self).__init__() 38 | self.conv1 =
conv3x3(inplanes, planes, stride) 39 | self.bn1 = nn.BatchNorm2d(planes) 40 | self.relu = nn.ReLU(inplace=True) 41 | self.conv2 = conv3x3(planes, planes) 42 | self.bn2 = nn.BatchNorm2d(planes) 43 | self.downsample = downsample 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | residual = x 48 | 49 | out = self.conv1(x) 50 | out = self.bn1(out) 51 | out = self.relu(out) 52 | 53 | out = self.conv2(out) 54 | out = self.bn2(out) 55 | 56 | if self.downsample is not None: 57 | residual = self.downsample(x) 58 | 59 | out += residual 60 | out = self.relu(out) 61 | 62 | return out 63 | 64 | 65 | class Bottleneck(nn.Module): 66 | expansion = 4 67 | 68 | def __init__(self, inplanes, planes, stride=1, downsample=None): 69 | super(Bottleneck, self).__init__() 70 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 71 | self.bn1 = nn.BatchNorm2d(planes) 72 | self.conv2 = nn.Conv2d(planes, 73 | planes, 74 | kernel_size=3, 75 | stride=stride, 76 | padding=1, 77 | bias=False) 78 | self.bn2 = nn.BatchNorm2d(planes) 79 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 80 | self.bn3 = nn.BatchNorm2d(planes * 4) 81 | self.relu = nn.ReLU(inplace=True) 82 | self.downsample = downsample 83 | self.stride = stride 84 | 85 | def forward(self, x): 86 | residual = x 87 | 88 | out = self.conv1(x) 89 | out = self.bn1(out) 90 | out = self.relu(out) 91 | 92 | out = self.conv2(out) 93 | out = self.bn2(out) 94 | out = self.relu(out) 95 | 96 | out = self.conv3(out) 97 | out = self.bn3(out) 98 | 99 | if self.downsample is not None: 100 | residual = self.downsample(x) 101 | 102 | out += residual 103 | out = self.relu(out) 104 | 105 | return out 106 | 107 | 108 | class Convkxk(nn.Module): 109 | def __init__(self, 110 | in_planes, 111 | out_planes, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0): 115 | super(Convkxk, self).__init__() 116 | self.conv = nn.Conv2d(in_planes, 117 | out_planes, 118 | kernel_size=kernel_size, 119 | stride=stride, 120 | padding=padding, 121 | bias=False) 122 | self.bn = nn.BatchNorm2d(out_planes) 123 | self.relu = nn.ReLU(inplace=True) 124 | 125 | for m in self.modules(): 126 | if isinstance(m, nn.Conv2d): 127 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 128 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 129 | elif isinstance(m, nn.BatchNorm2d): 130 | m.weight.data.fill_(1) 131 | m.bias.data.zero_() 132 | 133 | def forward(self, x): 134 | return self.relu(self.bn(self.conv(x))) 135 | 136 | 137 | class ResNet(nn.Module): 138 | def __init__(self, block, layers, num_classes=1000): 139 | super(ResNet, self).__init__() 140 | self.inplanes = 128 141 | self.conv1 = conv3x3(3, 64, stride=2) 142 | self.bn1 = nn.BatchNorm2d(64) 143 | self.relu1 = nn.ReLU(inplace=True) 144 | self.conv2 = conv3x3(64, 64) 145 | self.bn2 = nn.BatchNorm2d(64) 146 | self.relu2 = nn.ReLU(inplace=True) 147 | self.conv3 = conv3x3(64, 128) 148 | self.bn3 = nn.BatchNorm2d(128) 149 | self.relu3 = nn.ReLU(inplace=True) 150 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 151 | 152 | self.layer1 = self._make_layer(block, 64, layers[0]) 153 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 154 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 155 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 156 | # self.avgpool = nn.AvgPool2d(7, stride=1) 157 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 158 | 159 | for m in self.modules(): 160 | if isinstance(m, nn.Conv2d): 161 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 162 | m.weight.data.normal_(0, math.sqrt(2. / n)) 163 | elif isinstance(m, nn.BatchNorm2d): 164 | m.weight.data.fill_(1) 165 | m.bias.data.zero_() 166 | 167 | def _make_layer(self, block, planes, blocks, stride=1): 168 | downsample = None 169 | if stride != 1 or self.inplanes != planes * block.expansion: 170 | downsample = nn.Sequential( 171 | nn.Conv2d(self.inplanes, 172 | planes * block.expansion, 173 | kernel_size=1, 174 | stride=stride, 175 | bias=False), 176 | nn.BatchNorm2d(planes * block.expansion), 177 | ) 178 | 179 | layers = [] 180 | layers.append(block(self.inplanes, planes, stride, downsample)) 181 | self.inplanes = planes * block.expansion 182 | for i in range(1, blocks): 183 | layers.append(block(self.inplanes, planes)) 184 | 185 | return nn.Sequential(*layers) 186 | 187 | def forward(self, x): 188 | x = self.relu1(self.bn1(self.conv1(x))) 189 | x = self.relu2(self.bn2(self.conv2(x))) 190 | x = self.relu3(self.bn3(self.conv3(x))) 191 | x = self.maxpool(x) 192 | 193 | f = [] 194 | x = self.layer1(x) 195 | f.append(x) 196 | x = self.layer2(x) 197 | f.append(x) 198 | x = self.layer3(x) 199 | f.append(x) 200 | x = self.layer4(x) 201 | f.append(x) 202 | 203 | return tuple(f) 204 | 205 | # x = self.avgpool(x) 206 | # x = x.view(x.size(0), -1) 207 | # x = self.fc(x) 208 | 209 | # return x 210 | 211 | 212 | def resnet18(pretrained=False, **kwargs): 213 | """Constructs a ResNet-18 model. 214 | 215 | Args: 216 | pretrained (bool): If True, returns a model pre-trained on Places 217 | """ 218 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 219 | if pretrained: 220 | model.load_state_dict(load_url(model_urls['resnet18']), strict=False) 221 | return model 222 | 223 | 224 | def resnet50(pretrained=False, **kwargs): 225 | """Constructs a ResNet-50 model. 226 | 227 | Args: 228 | pretrained (bool): If True, returns a model pre-trained on Places 229 | """ 230 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 231 | if pretrained: 232 | model.load_state_dict(load_url(model_urls['resnet50']), strict=False) 233 | return model 234 | 235 | 236 | def resnet101(pretrained=False, **kwargs): 237 | """Constructs a ResNet-101 model. 
238 | 239 | Args: 240 | pretrained (bool): If True, returns a model pre-trained on Places 241 | """ 242 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 243 | if pretrained: 244 | model.load_state_dict(load_url(model_urls['resnet101']), strict=False) 245 | return model 246 | 247 | 248 | def load_url(url, model_dir='./pretrained', map_location=None): 249 | if not os.path.exists(model_dir): 250 | os.makedirs(model_dir) 251 | filename = url.split('/')[-1] 252 | cached_file = os.path.join(model_dir, filename) 253 | if not os.path.exists(cached_file): 254 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 255 | urlretrieve(url, cached_file) 256 | return torch.load(cached_file, map_location=map_location) 257 | -------------------------------------------------------------------------------- /models/builder.py: -------------------------------------------------------------------------------- 1 | import models 2 | 3 | 4 | def build_model(cfg): 5 | param = dict() 6 | for key in cfg: 7 | if key == 'type': 8 | continue 9 | param[key] = cfg[key] 10 | 11 | model = models.__dict__[cfg.type](**param) 12 | 13 | return model 14 | -------------------------------------------------------------------------------- /models/head/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_head 2 | from .pa_head import PA_Head 3 | from .pan_pp_det_head import PAN_PP_DetHead 4 | from .pan_pp_rec_head import PAN_PP_RecHead 5 | from .psenet_head import PSENet_Head 6 | 7 | __all__ = [ 8 | 'PA_Head', 'PSENet_Head', 'PAN_PP_DetHead', 'PAN_PP_RecHead', 'build_head' 9 | ] 10 | -------------------------------------------------------------------------------- /models/head/builder.py: -------------------------------------------------------------------------------- 1 | import models 2 | 3 | 4 | def build_head(cfg): 5 | param = dict() 6 | for key in cfg: 7 | if key == 'type': 8 | continue 9 | param[key] = cfg[key] 10 | 11 | head = models.head.__dict__[cfg.type](**param) 12 | 13 | return head 14 | -------------------------------------------------------------------------------- /models/head/pa_head.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | from ..loss import build_loss, iou, ohem_batch 10 | from ..post_processing import pa 11 | 12 | 13 | class PA_Head(nn.Module): 14 | def __init__(self, in_channels, hidden_dim, num_classes, loss_text, 15 | loss_kernel, loss_emb): 16 | super(PA_Head, self).__init__() 17 | self.conv1 = nn.Conv2d(in_channels, 18 | hidden_dim, 19 | kernel_size=3, 20 | stride=1, 21 | padding=1) 22 | self.bn1 = nn.BatchNorm2d(hidden_dim) 23 | self.relu1 = nn.ReLU(inplace=True) 24 | 25 | self.conv2 = nn.Conv2d(hidden_dim, 26 | num_classes, 27 | kernel_size=1, 28 | stride=1, 29 | padding=0) 30 | 31 | self.text_loss = build_loss(loss_text) 32 | self.kernel_loss = build_loss(loss_kernel) 33 | self.emb_loss = build_loss(loss_emb) 34 | 35 | for m in self.modules(): 36 | if isinstance(m, nn.Conv2d): 37 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 38 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 39 | elif isinstance(m, nn.BatchNorm2d): 40 | m.weight.data.fill_(1) 41 | m.bias.data.zero_() 42 | 43 | def forward(self, f): 44 | out = self.conv1(f) 45 | out = self.relu1(self.bn1(out)) 46 | out = self.conv2(out) 47 | 48 | return out 49 | 50 | def get_results(self, out, img_meta, cfg): 51 | outputs = dict() 52 | 53 | if not self.training and cfg.report_speed: 54 | torch.cuda.synchronize() 55 | start = time.time() 56 | 57 | score = torch.sigmoid(out[:, 0, :, :]) 58 | kernels = out[:, :2, :, :] > 0 59 | text_mask = kernels[:, :1, :, :] 60 | kernels[:, 1:, :, :] = kernels[:, 1:, :, :] * text_mask 61 | emb = out[:, 2:, :, :] 62 | emb = emb * text_mask.float() 63 | 64 | score = score.data.cpu().numpy()[0].astype(np.float32) 65 | kernels = kernels.data.cpu().numpy()[0].astype(np.uint8) 66 | emb = emb.cpu().numpy()[0].astype(np.float32) 67 | 68 | # pa 69 | label = pa(kernels, emb) 70 | 71 | # image size 72 | org_img_size = img_meta['org_img_size'][0] 73 | img_size = img_meta['img_size'][0] 74 | 75 | label_num = np.max(label) + 1 76 | label = cv2.resize(label, (img_size[1], img_size[0]), 77 | interpolation=cv2.INTER_NEAREST) 78 | score = cv2.resize(score, (img_size[1], img_size[0]), 79 | interpolation=cv2.INTER_NEAREST) 80 | 81 | if not self.training and cfg.report_speed: 82 | torch.cuda.synchronize() 83 | outputs.update(dict(det_post_time=time.time() - start)) 84 | 85 | scale = (float(org_img_size[1]) / float(img_size[1]), 86 | float(org_img_size[0]) / float(img_size[0])) 87 | 88 | with_rec = hasattr(cfg.model, 'recognition_head') 89 | 90 | if with_rec: 91 | bboxes_h = np.zeros((1, label_num, 4), dtype=np.int32) 92 | instances = [[]] 93 | 94 | bboxes = [] 95 | scores = [] 96 | for i in range(1, label_num): 97 | ind = label == i 98 | points = np.array(np.where(ind)).transpose((1, 0)) 99 | 100 | if points.shape[0] < cfg.test_cfg.min_area: 101 | label[ind] = 0 102 | continue 103 | 104 | score_i = np.mean(score[ind]) 105 | if score_i < cfg.test_cfg.min_score: 106 | label[ind] = 0 107 | continue 108 | 109 | if with_rec: 110 | tl = np.min(points, axis=0) 111 | br = np.max(points, axis=0) + 1 112 | bboxes_h[0, i] = (tl[0], tl[1], br[0], br[1]) 113 | instances[0].append(i) 114 | # text box type 115 | if cfg.test_cfg.bbox_type == 'rect': 116 | rect = cv2.minAreaRect(points[:, ::-1]) 117 | bbox = cv2.boxPoints(rect) * scale 118 | elif cfg.test_cfg.bbox_type == 'poly': 119 | binary = np.zeros(label.shape, dtype='uint8') 120 | binary[ind] = 1 121 | contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, 122 | cv2.CHAIN_APPROX_SIMPLE) 123 | bbox = contours[0] * scale 124 | 125 | bbox = bbox.astype('int32') 126 | bboxes.append(bbox.reshape(-1)) 127 | scores.append(score_i) 128 | 129 | outputs.update(dict(bboxes=bboxes, scores=scores)) 130 | if with_rec: 131 | outputs.update( 132 | dict(label=label, bboxes_h=bboxes_h, instances=instances)) 133 | 134 | return outputs 135 | 136 | def loss(self, out, gt_texts, gt_kernels, training_masks, gt_instances, 137 | gt_bboxes): 138 | # output 139 | texts = out[:, 0, :, :] 140 | kernels = out[:, 1:2, :, :] 141 | embs = out[:, 2:, :, :] 142 | 143 | # text loss 144 | selected_masks = ohem_batch(texts, gt_texts, training_masks) 145 | loss_text = self.text_loss( 146 | texts, gt_texts, selected_masks, reduce=False) 147 | iou_text = iou( 148 | (texts > 0).long(), gt_texts, training_masks, reduce=False) 149 | losses = dict(loss_text=loss_text, iou_text=iou_text) 150 | 151 | # kernel loss 152 | loss_kernels = [] 153 | selected_masks = gt_texts * training_masks 154 | for i 
in range(kernels.size(1)): 155 | kernel_i = kernels[:, i, :, :] 156 | gt_kernel_i = gt_kernels[:, i, :, :] 157 | loss_kernel_i = self.kernel_loss( 158 | kernel_i, gt_kernel_i, selected_masks, reduce=False) 159 | loss_kernels.append(loss_kernel_i) 160 | loss_kernels = torch.mean(torch.stack(loss_kernels, dim=1), dim=1) 161 | iou_kernel = iou( 162 | (kernels[:, -1, :, :] > 0).long(), gt_kernels[:, -1, :, :], 163 | training_masks * gt_texts, reduce=False) 164 | losses.update(dict(loss_kernels=loss_kernels, iou_kernel=iou_kernel)) 165 | 166 | # embedding loss 167 | loss_emb = self.emb_loss( 168 | embs, gt_instances, gt_kernels[:, -1, :, :], training_masks, 169 | gt_bboxes, reduce=False) 170 | losses.update(dict(loss_emb=loss_emb)) 171 | 172 | return losses 173 | -------------------------------------------------------------------------------- /models/head/pan_pp_det_head.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | from ..loss import build_loss, iou, ohem_batch 10 | from ..post_processing import pa 11 | from ..utils import CoordConv2d 12 | 13 | 14 | class PAN_PP_DetHead(nn.Module): 15 | def __init__(self, 16 | in_channels, 17 | hidden_dim, 18 | num_classes, 19 | loss_text, 20 | loss_kernel, 21 | loss_emb, 22 | use_coordconv=False): 23 | super(PAN_PP_DetHead, self).__init__() 24 | if not use_coordconv: 25 | self.conv1 = nn.Conv2d(in_channels, 26 | hidden_dim, 27 | kernel_size=3, 28 | stride=1, 29 | padding=1) 30 | else: 31 | self.conv1 = CoordConv2d(in_channels, 32 | hidden_dim, 33 | kernel_size=3, 34 | stride=1, 35 | padding=1) 36 | self.bn1 = nn.BatchNorm2d(hidden_dim) 37 | self.relu1 = nn.ReLU(inplace=True) 38 | 39 | self.conv2 = nn.Conv2d(hidden_dim, 40 | num_classes, 41 | kernel_size=1, 42 | stride=1, 43 | padding=0) 44 | 45 | self.text_loss = build_loss(loss_text) 46 | self.kernel_loss = build_loss(loss_kernel) 47 | self.emb_loss = build_loss(loss_emb) 48 | 49 | for m in self.modules(): 50 | if isinstance(m, nn.Conv2d): 51 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 52 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 53 | elif isinstance(m, nn.BatchNorm2d): 54 | m.weight.data.fill_(1) 55 | m.bias.data.zero_() 56 | 57 | def forward(self, f): 58 | out = self.conv1(f) 59 | out = self.relu1(self.bn1(out)) 60 | out = self.conv2(out) 61 | return out 62 | 63 | def get_results(self, out, img_meta, cfg): 64 | results = {} 65 | if cfg.report_speed: 66 | torch.cuda.synchronize() 67 | start = time.time() 68 | 69 | score = torch.sigmoid(out[:, 0, :, :]) 70 | 71 | kernels = out[:, :2, :, :] > 0 72 | text_mask = kernels[:, :1, :, :] 73 | kernels[:, 1:, :, :] = kernels[:, 1:, :, :] * text_mask 74 | 75 | emb = out[:, 2:, :, :] 76 | emb = emb * text_mask.float() 77 | 78 | score = score.data.detach().cpu().numpy()[0].astype(np.float32) 79 | kernels = kernels.data.detach().cpu().numpy()[0].astype(np.uint8) 80 | emb = emb.detach().cpu().numpy()[0].astype(np.float32) 81 | # print(kernels.shape) 82 | label = pa(kernels, emb, 83 | cfg.test_cfg.min_kernel_area / (cfg.test_cfg.scale**2)) 84 | # print(label.shape) 85 | if cfg.report_speed: 86 | torch.cuda.synchronize() 87 | results['det_post_time'] = time.time() - start 88 | 89 | # image size 90 | org_img_size = img_meta['org_img_size'][0] 91 | img_size = img_meta['img_size'][0] 92 | # print(org_img_size.numpy().tolist()) 93 | # print(img_size.numpy().tolist()) 94 | org_img_size = org_img_size.numpy().tolist() 95 | img_size = img_size.numpy().tolist() 96 | label_num = np.max(label) + 1 97 | scale = (float(org_img_size[1]) / float(img_size[1]), 98 | float(org_img_size[0]) / float(img_size[0])) 99 | label = cv2.resize(label, (img_size[1], img_size[0]), 100 | interpolation=cv2.INTER_NEAREST) 101 | score = cv2.resize(score, (img_size[1], img_size[0]), 102 | interpolation=cv2.INTER_NEAREST) 103 | 104 | with_rec = hasattr(cfg.model, 'recognition_head') 105 | if with_rec: 106 | bboxes_h = np.zeros((1, label_num, 4), dtype=np.int32) 107 | instances = [[]] 108 | 109 | bboxes = [] 110 | scores = [] 111 | for i in range(1, label_num): 112 | ind = label == i 113 | points = np.array(np.where(ind)).transpose((1, 0)) 114 | 115 | min_area = cfg.test_cfg.min_area / (cfg.test_cfg.scale**2) 116 | if points.shape[0] < min_area: 117 | label[ind] = 0 118 | continue 119 | 120 | score_i = np.mean(score[ind]) 121 | if score_i < cfg.test_cfg.min_score: 122 | label[ind] = 0 123 | continue 124 | 125 | if with_rec: 126 | tl = np.min(points, axis=0) 127 | br = np.max(points, axis=0) + 1 128 | bboxes_h[0, i] = (tl[0], tl[1], br[0], br[1]) 129 | instances[0].append(i) 130 | 131 | if cfg.test_cfg.bbox_type == 'rect': 132 | rect = cv2.minAreaRect(points[:, ::-1]) 133 | bbox = cv2.boxPoints(rect) * scale 134 | elif cfg.test_cfg.bbox_type == 'poly': 135 | binary = np.zeros(label.shape, dtype='uint8') 136 | binary[ind] = 1 137 | contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, 138 | cv2.CHAIN_APPROX_TC89_L1) 139 | bbox = contours[0] * scale 140 | 141 | bbox = bbox.astype('int32') 142 | bboxes.append(bbox.reshape(-1)) 143 | scores.append(score_i) 144 | 145 | results['bboxes'] = bboxes 146 | results['scores'] = scores 147 | if with_rec: 148 | results['label'] = label 149 | results['bboxes_h'] = bboxes_h 150 | results['instances'] = instances 151 | return results 152 | 153 | def loss(self, out, gt_texts, gt_kernels, training_masks, gt_instances, 154 | gt_bboxes): 155 | texts = out[:, 0, :, :] 156 | kernels = out[:, 1:2, :, :] 157 | embs = out[:, 2:, :, :] 158 | 159 | 160 | # print(texts[0]) 161 | 162 | # print(gt_texts[0]) 163 | selected_masks = ohem_batch(texts, gt_texts, training_masks) 164 | # 
loss_text = dice_loss(texts, gt_texts, selected_masks, reduce=False) 165 | loss_text = self.text_loss(texts, 166 | gt_texts, 167 | selected_masks, 168 | reduce=False) 169 | iou_text = iou((texts > 0).long(), 170 | gt_texts, 171 | training_masks, 172 | reduce=False) 173 | losses = {'loss_text': loss_text, 'iou_text': iou_text} 174 | 175 | loss_kernels = [] 176 | selected_masks = gt_texts * training_masks 177 | for i in range(kernels.size(1)): 178 | kernel_i = kernels[:, i, :, :] 179 | gt_kernel_i = gt_kernels[:, i, :, :] 180 | loss_kernel_i = self.kernel_loss(kernel_i, 181 | gt_kernel_i, 182 | selected_masks, 183 | reduce=False) 184 | loss_kernels.append(loss_kernel_i) 185 | loss_kernels = torch.mean(torch.stack(loss_kernels, dim=1), dim=1) 186 | iou_kernel = iou((kernels[:, -1, :, :] > 0).long(), 187 | gt_kernels[:, -1, :, :], 188 | training_masks * gt_texts, 189 | reduce=False) 190 | losses.update(dict(loss_kernels=loss_kernels, iou_kernel=iou_kernel)) 191 | 192 | loss_emb = self.emb_loss(embs, 193 | gt_instances, 194 | gt_kernels[:, -1, :, :], 195 | training_masks, 196 | gt_bboxes, 197 | reduce=False) 198 | losses.update(dict(loss_emb=loss_emb)) 199 | 200 | return losses 201 | -------------------------------------------------------------------------------- /models/head/psenet_head.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | from ..loss import build_loss, iou, ohem_batch 10 | from ..post_processing import pse 11 | 12 | 13 | class PSENet_Head(nn.Module): 14 | def __init__(self, in_channels, hidden_dim, num_classes, loss_text, 15 | loss_kernel): 16 | super(PSENet_Head, self).__init__() 17 | self.conv1 = nn.Conv2d(in_channels, 18 | hidden_dim, 19 | kernel_size=3, 20 | stride=1, 21 | padding=1) 22 | self.bn1 = nn.BatchNorm2d(hidden_dim) 23 | self.relu1 = nn.ReLU(inplace=True) 24 | 25 | self.conv2 = nn.Conv2d(hidden_dim, 26 | num_classes, 27 | kernel_size=1, 28 | stride=1, 29 | padding=0) 30 | 31 | self.text_loss = build_loss(loss_text) 32 | self.kernel_loss = build_loss(loss_kernel) 33 | 34 | for m in self.modules(): 35 | if isinstance(m, nn.Conv2d): 36 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 37 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 38 | elif isinstance(m, nn.BatchNorm2d): 39 | m.weight.data.fill_(1) 40 | m.bias.data.zero_() 41 | 42 | def forward(self, f): 43 | out = self.conv1(f) 44 | out = self.relu1(self.bn1(out)) 45 | out = self.conv2(out) 46 | 47 | return out 48 | 49 | def get_results(self, out, img_meta, cfg): 50 | outputs = dict() 51 | 52 | if not self.training and cfg.report_speed: 53 | torch.cuda.synchronize() 54 | start = time.time() 55 | 56 | score = torch.sigmoid(out[:, 0, :, :]) 57 | 58 | kernels = out[:, :cfg.test_cfg.kernel_num, :, :] > 0 59 | text_mask = kernels[:, :1, :, :] 60 | kernels[:, 1:, :, :] = kernels[:, 1:, :, :] * text_mask 61 | 62 | score = score.data.cpu().numpy()[0].astype(np.float32) 63 | kernels = kernels.data.cpu().numpy()[0].astype(np.uint8) 64 | 65 | label = pse(kernels, cfg.test_cfg.min_area) 66 | 67 | # image size 68 | org_img_size = img_meta['org_img_size'][0] 69 | img_size = img_meta['img_size'][0] 70 | 71 | label_num = np.max(label) + 1 72 | label = cv2.resize(label, (img_size[1], img_size[0]), 73 | interpolation=cv2.INTER_NEAREST) 74 | score = cv2.resize(score, (img_size[1], img_size[0]), 75 | interpolation=cv2.INTER_NEAREST) 76 | 77 | if not self.training and cfg.report_speed: 78 | torch.cuda.synchronize() 79 | outputs.update(dict(det_post_time=time.time() - start)) 80 | 81 | scale = (float(org_img_size[1]) / float(img_size[1]), 82 | float(org_img_size[0]) / float(img_size[0])) 83 | 84 | bboxes = [] 85 | scores = [] 86 | for i in range(1, label_num): 87 | ind = label == i 88 | points = np.array(np.where(ind)).transpose((1, 0)) 89 | 90 | if points.shape[0] < cfg.test_cfg.min_area: 91 | label[ind] = 0 92 | continue 93 | 94 | score_i = np.mean(score[ind]) 95 | if score_i < cfg.test_cfg.min_score: 96 | label[ind] = 0 97 | continue 98 | 99 | if cfg.test_cfg.bbox_type == 'rect': 100 | rect = cv2.minAreaRect(points[:, ::-1]) 101 | bbox = cv2.boxPoints(rect) * scale 102 | elif cfg.test_cfg.bbox_type == 'poly': 103 | binary = np.zeros(label.shape, dtype='uint8') 104 | binary[ind] = 1 105 | contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, 106 | cv2.CHAIN_APPROX_SIMPLE) 107 | bbox = contours[0] * scale 108 | 109 | bbox = bbox.astype('int32') 110 | bboxes.append(bbox.reshape(-1)) 111 | scores.append(score_i) 112 | 113 | outputs.update(dict(bboxes=bboxes, scores=scores)) 114 | 115 | return outputs 116 | 117 | def loss(self, out, gt_texts, gt_kernels, training_masks): 118 | # output 119 | texts = out[:, 0, :, :] 120 | kernels = out[:, 1:, :, :] 121 | # text loss 122 | selected_masks = ohem_batch(texts, gt_texts, training_masks) 123 | 124 | loss_text = self.text_loss(texts, 125 | gt_texts, 126 | selected_masks, 127 | reduce=False) 128 | iou_text = iou((texts > 0).long(), 129 | gt_texts, 130 | training_masks, 131 | reduce=False) 132 | losses = dict(loss_text=loss_text, iou_text=iou_text) 133 | 134 | # kernel loss 135 | loss_kernels = [] 136 | selected_masks = gt_texts * training_masks 137 | for i in range(kernels.size(1)): 138 | kernel_i = kernels[:, i, :, :] 139 | gt_kernel_i = gt_kernels[:, i, :, :] 140 | loss_kernel_i = self.kernel_loss(kernel_i, 141 | gt_kernel_i, 142 | selected_masks, 143 | reduce=False) 144 | loss_kernels.append(loss_kernel_i) 145 | loss_kernels = torch.mean(torch.stack(loss_kernels, dim=1), dim=1) 146 | iou_kernel = iou((kernels[:, -1, :, :] > 0).long(), 147 | gt_kernels[:, -1, :, :], 148 | training_masks * gt_texts, 149 | reduce=False) 150 | losses.update(dict(loss_kernels=loss_kernels, iou_kernel=iou_kernel)) 151 | 152 | return losses 153 | 
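All of the `build_*` helpers in `models/` (backbone, head, loss, model) share one reflection pattern: look the `type` key up in the module namespace and pass every remaining key as a keyword argument. A minimal sketch of building the head above from an mmcv-style config (the channel, class and loss-weight numbers are illustrative, not the repo's defaults, which live in `config/pan_pp/`):

```python
from mmcv import Config

from models.head import build_head

cfg = Config(
    dict(type='PSENet_Head',
         in_channels=1024,
         hidden_dim=256,
         num_classes=7,
         loss_text=dict(type='DiceLoss', loss_weight=0.7),
         loss_kernel=dict(type='DiceLoss', loss_weight=0.3)))

head = build_head(cfg)  # the nested loss cfgs are built via build_loss inside
```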
-------------------------------------------------------------------------------- /models/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .acc import acc 2 | from .builder import build_loss 3 | from .dice_loss import DiceLoss 4 | from .emb_loss_v1 import EmbLoss_v1 5 | from .emb_loss_v2 import EmbLoss_v2 6 | from .iou import iou 7 | from .ohem import ohem_batch 8 | 9 | __all__ = [ 10 | 'DiceLoss', 'EmbLoss_v1', 'EmbLoss_v2', 'acc', 'iou', 'ohem_batch', 11 | 'build_loss' 12 | ] 13 | -------------------------------------------------------------------------------- /models/loss/acc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | EPS = 1e-6 4 | 5 | 6 | def acc_single(a, b, mask): 7 | ind = mask == 1 8 | if torch.sum(ind) == 0: 9 | return 0 10 | correct = (a[ind] == b[ind]).float() 11 | acc = torch.sum(correct) / correct.size(0) 12 | return acc 13 | 14 | 15 | def acc(a, b, mask, reduce=True): 16 | batch_size = a.size(0) 17 | 18 | a = a.view(batch_size, -1) 19 | b = b.view(batch_size, -1) 20 | mask = mask.view(batch_size, -1) 21 | 22 | acc = a.new_zeros((batch_size, ), dtype=torch.float32) 23 | for i in range(batch_size): 24 | acc[i] = acc_single(a[i], b[i], mask[i]) 25 | 26 | if reduce: 27 | acc = torch.mean(acc) 28 | return acc 29 | -------------------------------------------------------------------------------- /models/loss/builder.py: -------------------------------------------------------------------------------- 1 | import models 2 | 3 | 4 | def build_loss(cfg): 5 | param = dict() 6 | for key in cfg: 7 | if key == 'type': 8 | continue 9 | param[key] = cfg[key] 10 | 11 | loss = models.loss.__dict__[cfg.type](**param) 12 | 13 | return loss 14 | -------------------------------------------------------------------------------- /models/loss/dice_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class DiceLoss(nn.Module): 6 | def __init__(self, loss_weight=1.0): 7 | super(DiceLoss, self).__init__() 8 | self.loss_weight = loss_weight 9 | 10 | def forward(self, input, target, mask, reduce=True): 11 | batch_size = input.size(0) 12 | input = torch.sigmoid(input) 13 | 14 | input = input.contiguous().view(batch_size, -1) 15 | target = target.contiguous().view(batch_size, -1).float() 16 | mask = mask.contiguous().view(batch_size, -1).float() 17 | 18 | input = input * mask 19 | target = target * mask 20 | 21 | a = torch.sum(input * target, dim=1) 22 | b = torch.sum(input * input, dim=1) + 0.001 23 | c = torch.sum(target * target, dim=1) + 0.001 24 | d = (2 * a) / (b + c) 25 | loss = 1 - d 26 | 27 | loss = self.loss_weight * loss 28 | 29 | if reduce: 30 | loss = torch.mean(loss) 31 | 32 | return loss 33 | -------------------------------------------------------------------------------- /models/loss/emb_loss_v1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class EmbLoss_v1(nn.Module): 7 | def __init__(self, feature_dim=4, loss_weight=1.0): 8 | super(EmbLoss_v1, self).__init__() 9 | self.feature_dim = feature_dim 10 | self.loss_weight = loss_weight 11 | self.delta_v = 0.5 12 | self.delta_d = 1.5 13 | self.weights = (1.0, 1.0) 14 | 15 | def forward_single(self, emb, instance, kernel, training_mask, bboxes): 16 | training_mask = (training_mask > 0.5).long() 17 | kernel = 
(kernel > 0.5).long() 18 | instance = instance * training_mask 19 | instance_kernel = (instance * kernel).view(-1) 20 | instance = instance.view(-1) 21 | emb = emb.view(self.feature_dim, -1) 22 | 23 | unique_labels, unique_ids = torch.unique(instance_kernel, 24 | sorted=True, 25 | return_inverse=True) 26 | num_instance = unique_labels.size(0) 27 | if num_instance <= 1: 28 | return 0 29 | 30 | emb_mean = emb.new_zeros((self.feature_dim, num_instance), 31 | dtype=torch.float32) 32 | for i, lb in enumerate(unique_labels): 33 | if lb == 0: 34 | continue 35 | ind_k = instance_kernel == lb 36 | emb_mean[:, i] = torch.mean(emb[:, ind_k], dim=1) 37 | 38 | l_agg = emb.new_zeros(num_instance, dtype=torch.float32) # bug 39 | for i, lb in enumerate(unique_labels): 40 | if lb == 0: 41 | continue 42 | ind = instance == lb 43 | emb_ = emb[:, ind] 44 | dist = (emb_ - emb_mean[:, i:i + 1]).norm(p=2, dim=0) 45 | dist = F.relu(dist - self.delta_v)**2 46 | l_agg[i] = torch.mean(torch.log(dist + 1.0)) 47 | l_agg = torch.mean(l_agg[1:]) 48 | 49 | if num_instance > 2: 50 | emb_interleave = emb_mean.permute(1, 0).repeat(num_instance, 1) 51 | emb_band = emb_mean.permute(1, 0).repeat(1, num_instance).view( 52 | -1, self.feature_dim) 53 | # print(seg_band) 54 | 55 | mask = (1 - torch.eye(num_instance, dtype=torch.int8)).view( 56 | -1, 1).repeat(1, self.feature_dim) 57 | mask = mask.view(num_instance, num_instance, -1) 58 | mask[0, :, :] = 0 59 | mask[:, 0, :] = 0 60 | mask = mask.view(num_instance * num_instance, -1) 61 | # print(mask) 62 | 63 | dist = emb_interleave - emb_band 64 | dist = dist[mask > 0].view(-1, self.feature_dim).norm(p=2, dim=1) 65 | dist = F.relu(2 * self.delta_d - dist)**2 66 | l_dis = torch.mean(torch.log(dist + 1.0)) 67 | else: 68 | l_dis = 0 69 | 70 | l_agg = self.weights[0] * l_agg 71 | l_dis = self.weights[1] * l_dis 72 | l_reg = torch.mean(torch.log(torch.norm(emb_mean, 2, 0) + 1.0)) * 0.001 73 | loss = l_agg + l_dis + l_reg 74 | return loss 75 | 76 | def forward(self, 77 | emb, 78 | instance, 79 | kernel, 80 | training_mask, 81 | bboxes, 82 | reduce=True): 83 | loss_batch = emb.new_zeros((emb.size(0)), dtype=torch.float32) 84 | 85 | for i in range(loss_batch.size(0)): 86 | loss_batch[i] = self.forward_single(emb[i], instance[i], kernel[i], 87 | training_mask[i], bboxes[i]) 88 | 89 | loss_batch = self.loss_weight * loss_batch 90 | 91 | if reduce: 92 | loss_batch = torch.mean(loss_batch) 93 | 94 | return loss_batch 95 | -------------------------------------------------------------------------------- /models/loss/emb_loss_v2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class EmbLoss_v2(nn.Module): 8 | def __init__(self, feature_dim=4, loss_weight=1.0): 9 | super(EmbLoss_v2, self).__init__() 10 | self.feature_dim = feature_dim 11 | self.loss_weight = loss_weight 12 | self.delta_v = 0.5 13 | self.delta_d = 1.5 14 | self.weights = (1.0, 1.0) 15 | 16 | def forward_single(self, emb, instance, kernel, training_mask, bboxes): 17 | training_mask = (training_mask > 0.5).long() 18 | kernel = (kernel > 0.5).long() 19 | instance = instance * training_mask 20 | instance_kernel = (instance * kernel).view(-1) 21 | instance = instance.view(-1) 22 | emb = emb.view(self.feature_dim, -1) 23 | 24 | unique_labels, unique_ids = torch.unique(instance_kernel, 25 | sorted=True, 26 | return_inverse=True) 27 | num_instance = unique_labels.size(0) 28 | if num_instance 
<= 1: 29 | return 0 30 | 31 | emb_mean = emb.new_zeros((self.feature_dim, num_instance), 32 | dtype=torch.float32) 33 | for i, lb in enumerate(unique_labels): 34 | if lb == 0: 35 | continue 36 | ind_k = instance_kernel == lb 37 | emb_mean[:, i] = torch.mean(emb[:, ind_k], dim=1) 38 | 39 | l_agg = emb.new_zeros(num_instance, dtype=torch.float32) # bug 40 | for i, lb in enumerate(unique_labels): 41 | if lb == 0: 42 | continue 43 | ind = instance == lb 44 | emb_ = emb[:, ind] 45 | dist = (emb_ - emb_mean[:, i:i + 1]).norm(p=2, dim=0) 46 | dist = F.relu(dist - self.delta_v)**2 47 | l_agg[i] = torch.mean(torch.log(dist + 1.0)) 48 | l_agg = torch.mean(l_agg[1:]) 49 | 50 | if num_instance > 2: 51 | emb_interleave = emb_mean.permute(1, 0).repeat(num_instance, 1) 52 | emb_band = emb_mean.permute(1, 0).repeat(1, num_instance).view( 53 | -1, self.feature_dim) 54 | # print(seg_band) 55 | 56 | mask = (1 - torch.eye(num_instance, dtype=torch.int8)).view( 57 | -1, 1).repeat(1, self.feature_dim) 58 | mask = mask.view(num_instance, num_instance, -1) 59 | mask[0, :, :] = 0 60 | mask[:, 0, :] = 0 61 | mask = mask.view(num_instance * num_instance, -1) 62 | # print(mask) 63 | 64 | dist = emb_interleave - emb_band 65 | dist = dist[mask > 0].view(-1, self.feature_dim).norm(p=2, dim=1) 66 | dist = F.relu(2 * self.delta_d - dist)**2 67 | # l_dis = torch.mean(torch.log(dist + 1.0)) 68 | 69 | l_dis = [torch.log(dist + 1.0)] 70 | emb_bg = emb[:, instance == 0].view(self.feature_dim, -1) 71 | if emb_bg.size(1) > 100: 72 | rand_ind = np.random.permutation(emb_bg.size(1))[:100] 73 | emb_bg = emb_bg[:, rand_ind] 74 | if emb_bg.size(1) > 0: 75 | for i, lb in enumerate(unique_labels): 76 | if lb == 0: 77 | continue 78 | dist = (emb_bg - emb_mean[:, i:i + 1]).norm(p=2, dim=0) 79 | dist = F.relu(2 * self.delta_d - dist)**2 80 | l_dis_bg = torch.mean(torch.log(dist + 1.0), 81 | 0, 82 | keepdim=True) 83 | l_dis.append(l_dis_bg) 84 | l_dis = torch.mean(torch.cat(l_dis)) 85 | else: 86 | l_dis = 0 87 | 88 | l_agg = self.weights[0] * l_agg 89 | l_dis = self.weights[1] * l_dis 90 | l_reg = torch.mean(torch.log(torch.norm(emb_mean, 2, 0) + 1.0)) * 0.001 91 | loss = l_agg + l_dis + l_reg 92 | return loss 93 | 94 | def forward(self, 95 | emb, 96 | instance, 97 | kernel, 98 | training_mask, 99 | bboxes, 100 | reduce=True): 101 | loss_batch = emb.new_zeros((emb.size(0)), dtype=torch.float32) 102 | 103 | for i in range(loss_batch.size(0)): 104 | loss_batch[i] = self.forward_single(emb[i], instance[i], kernel[i], 105 | training_mask[i], bboxes[i]) 106 | 107 | loss_batch = self.loss_weight * loss_batch 108 | 109 | if reduce: 110 | loss_batch = torch.mean(loss_batch) 111 | 112 | return loss_batch 113 | -------------------------------------------------------------------------------- /models/loss/iou.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | EPS = 1e-6 4 | 5 | 6 | def iou_single(a, b, mask, n_class): 7 | valid = mask == 1 8 | # print(valid) 9 | # print(type(valid)) 10 | a = a[valid] 11 | b = b[valid] 12 | # print(type(a)) 13 | # print(a.shape) 14 | miou = [] 15 | for i in range(n_class): 16 | inter = ((a == i) & (b == i)).float() 17 | union = ((a == i) | (b == i)).float() 18 | 19 | miou.append(torch.sum(inter) / (torch.sum(union) + EPS)) 20 | miou = sum(miou) / len(miou) 21 | return miou 22 | 23 | 24 | def iou(a, b, mask, n_class=2, reduce=True): 25 | batch_size = a.size(0) 26 | 27 | a = a.view(batch_size, -1) 28 | b = b.view(batch_size, -1) 29 | mask = mask.view(batch_size, 
-1) 30 | 31 | iou = a.new_zeros((batch_size, ), dtype=torch.float32) 32 | for i in range(batch_size): 33 | iou[i] = iou_single(a[i], b[i], mask[i], n_class) 34 | 35 | if reduce: 36 | iou = torch.mean(iou) 37 | return iou 38 | -------------------------------------------------------------------------------- /models/loss/ohem.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def ohem_single(score, gt_text, training_mask): 5 | pos_num = int(torch.sum(gt_text > 0.5)) - int( 6 | torch.sum((gt_text > 0.5) & (training_mask <= 0.5))) 7 | 8 | if pos_num == 0: 9 | # selected_mask = gt_text.copy() * 0 # may be not good 10 | selected_mask = training_mask 11 | selected_mask = selected_mask.view(1, selected_mask.shape[0], 12 | selected_mask.shape[1]).float() 13 | return selected_mask 14 | 15 | neg_num = int(torch.sum(gt_text <= 0.5)) 16 | neg_num = int(min(pos_num * 3, neg_num)) 17 | 18 | if neg_num == 0: 19 | selected_mask = training_mask 20 | selected_mask = selected_mask.view(1, selected_mask.shape[0], 21 | selected_mask.shape[1]).float() 22 | return selected_mask 23 | 24 | neg_score = score[gt_text <= 0.5] 25 | neg_score_sorted, _ = torch.sort(-neg_score) 26 | threshold = -neg_score_sorted[neg_num - 1] 27 | 28 | selected_mask = ((score >= threshold) | 29 | (gt_text > 0.5)) & (training_mask > 0.5) 30 | selected_mask = selected_mask.reshape(1, selected_mask.shape[0], 31 | selected_mask.shape[1]).float() 32 | return selected_mask 33 | 34 | 35 | def ohem_batch(scores, gt_texts, training_masks): 36 | selected_masks = [] 37 | for i in range(scores.shape[0]): 38 | selected_masks.append( 39 | ohem_single(scores[i, :, :], gt_texts[i, :, :], 40 | training_masks[i, :, :])) 41 | 42 | selected_masks = torch.cat(selected_masks, 0).float() 43 | return selected_masks 44 | -------------------------------------------------------------------------------- /models/neck/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_neck 2 | from .fpem_v1 import FPEM_v1 3 | from .fpem_v2 import FPEM_v2 # for PAN++ 4 | from .fpn import FPN 5 | 6 | __all__ = ['FPN', 'FPEM_v1', 'FPEM_v2', 'build_neck'] 7 | -------------------------------------------------------------------------------- /models/neck/builder.py: -------------------------------------------------------------------------------- 1 | import models 2 | 3 | 4 | def build_neck(cfg): 5 | param = dict() 6 | for key in cfg: 7 | if key == 'type': 8 | continue 9 | param[key] = cfg[key] 10 | 11 | neck = models.neck.__dict__[cfg.type](**param) 12 | 13 | return neck 14 | -------------------------------------------------------------------------------- /models/neck/fpem_v1.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from ..utils import Conv_BN_ReLU 5 | 6 | 7 | class FPEM_v1(nn.Module): 8 | def __init__(self, in_channels, out_channels): 9 | super(FPEM_v1, self).__init__() 10 | planes = out_channels 11 | self.dwconv3_1 = nn.Conv2d(planes, 12 | planes, 13 | kernel_size=3, 14 | stride=1, 15 | padding=1, 16 | groups=planes, 17 | bias=False) 18 | self.smooth_layer3_1 = Conv_BN_ReLU(planes, planes) 19 | 20 | self.dwconv2_1 = nn.Conv2d(planes, 21 | planes, 22 | kernel_size=3, 23 | stride=1, 24 | padding=1, 25 | groups=planes, 26 | bias=False) 27 | self.smooth_layer2_1 = Conv_BN_ReLU(planes, planes) 28 | 29 | self.dwconv1_1 = nn.Conv2d(planes, 30 | 
planes, 31 | kernel_size=3, 32 | stride=1, 33 | padding=1, 34 | groups=planes, 35 | bias=False) 36 | self.smooth_layer1_1 = Conv_BN_ReLU(planes, planes) 37 | 38 | self.dwconv2_2 = nn.Conv2d(planes, 39 | planes, 40 | kernel_size=3, 41 | stride=2, 42 | padding=1, 43 | groups=planes, 44 | bias=False) 45 | self.smooth_layer2_2 = Conv_BN_ReLU(planes, planes) 46 | 47 | self.dwconv3_2 = nn.Conv2d(planes, 48 | planes, 49 | kernel_size=3, 50 | stride=2, 51 | padding=1, 52 | groups=planes, 53 | bias=False) 54 | self.smooth_layer3_2 = Conv_BN_ReLU(planes, planes) 55 | 56 | self.dwconv4_2 = nn.Conv2d(planes, 57 | planes, 58 | kernel_size=3, 59 | stride=2, 60 | padding=1, 61 | groups=planes, 62 | bias=False) 63 | self.smooth_layer4_2 = Conv_BN_ReLU(planes, planes) 64 | 65 | def _upsample_add(self, x, y): 66 | _, _, H, W = y.size() 67 | return F.interpolate(x, size=(H, W), mode='bilinear') + y 68 | 69 | def forward(self, f1, f2, f3, f4): 70 | f3 = self.smooth_layer3_1(self.dwconv3_1(self._upsample_add(f4, f3))) 71 | f2 = self.smooth_layer2_1(self.dwconv2_1(self._upsample_add(f3, f2))) 72 | f1 = self.smooth_layer1_1(self.dwconv1_1(self._upsample_add(f2, f1))) 73 | 74 | f2 = self.smooth_layer2_2(self.dwconv2_2(self._upsample_add(f2, f1))) 75 | f3 = self.smooth_layer3_2(self.dwconv3_2(self._upsample_add(f3, f2))) 76 | f4 = self.smooth_layer4_2(self.dwconv4_2(self._upsample_add(f4, f3))) 77 | 78 | return f1, f2, f3, f4 79 | -------------------------------------------------------------------------------- /models/neck/fpem_v2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from ..utils import Conv_BN_ReLU 5 | 6 | 7 | class FPEM_v2(nn.Module): 8 | def __init__(self, in_channels, out_channels): 9 | super(FPEM_v2, self).__init__() 10 | planes = out_channels 11 | self.dwconv3_1 = nn.Conv2d(planes, 12 | planes, 13 | kernel_size=3, 14 | stride=1, 15 | padding=1, 16 | groups=planes, 17 | bias=False) 18 | self.smooth_layer3_1 = Conv_BN_ReLU(planes, planes) 19 | 20 | self.dwconv2_1 = nn.Conv2d(planes, 21 | planes, 22 | kernel_size=3, 23 | stride=1, 24 | padding=1, 25 | groups=planes, 26 | bias=False) 27 | self.smooth_layer2_1 = Conv_BN_ReLU(planes, planes) 28 | 29 | self.dwconv1_1 = nn.Conv2d(planes, 30 | planes, 31 | kernel_size=3, 32 | stride=1, 33 | padding=1, 34 | groups=planes, 35 | bias=False) 36 | self.smooth_layer1_1 = Conv_BN_ReLU(planes, planes) 37 | 38 | self.dwconv2_2 = nn.Conv2d(planes, 39 | planes, 40 | kernel_size=3, 41 | stride=2, 42 | padding=1, 43 | groups=planes, 44 | bias=False) 45 | self.smooth_layer2_2 = Conv_BN_ReLU(planes, planes) 46 | 47 | self.dwconv3_2 = nn.Conv2d(planes, 48 | planes, 49 | kernel_size=3, 50 | stride=2, 51 | padding=1, 52 | groups=planes, 53 | bias=False) 54 | self.smooth_layer3_2 = Conv_BN_ReLU(planes, planes) 55 | 56 | self.dwconv4_2 = nn.Conv2d(planes, 57 | planes, 58 | kernel_size=3, 59 | stride=2, 60 | padding=1, 61 | groups=planes, 62 | bias=False) 63 | self.smooth_layer4_2 = Conv_BN_ReLU(planes, planes) 64 | 65 | def _upsample_add(self, x, y): 66 | _, _, H, W = y.size() 67 | return F.interpolate(x, size=(H, W), mode='bilinear') + y 68 | 69 | def forward(self, f1, f2, f3, f4): 70 | f3_ = self.smooth_layer3_1(self.dwconv3_1(self._upsample_add(f4, f3))) 71 | f2_ = self.smooth_layer2_1(self.dwconv2_1(self._upsample_add(f3_, f2))) 72 | f1_ = self.smooth_layer1_1(self.dwconv1_1(self._upsample_add(f2_, f1))) 73 | 74 | f2_ = 
self.smooth_layer2_2(self.dwconv2_2(self._upsample_add(f2_, 75 | f1_))) 76 | f3_ = self.smooth_layer3_2(self.dwconv3_2(self._upsample_add(f3_, 77 | f2_))) 78 | f4_ = self.smooth_layer4_2(self.dwconv4_2(self._upsample_add(f4, f3_))) 79 | 80 | f1 = f1 + f1_ 81 | f2 = f2 + f2_ 82 | f3 = f3 + f3_ 83 | f4 = f4 + f4_ 84 | 85 | return f1, f2, f3, f4 86 | -------------------------------------------------------------------------------- /models/neck/fpn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from ..utils import Conv_BN_ReLU 7 | 8 | 9 | class FPN(nn.Module): 10 | def __init__(self, in_channels, out_channels): 11 | super(FPN, self).__init__() 12 | 13 | # Top layer 14 | self.toplayer_ = Conv_BN_ReLU(2048, 15 | 256, 16 | kernel_size=1, 17 | stride=1, 18 | padding=0) 19 | 20 | # Smooth layers 21 | self.smooth1_ = Conv_BN_ReLU(256, 22 | 256, 23 | kernel_size=3, 24 | stride=1, 25 | padding=1) 26 | 27 | self.smooth2_ = Conv_BN_ReLU(256, 28 | 256, 29 | kernel_size=3, 30 | stride=1, 31 | padding=1) 32 | 33 | self.smooth3_ = Conv_BN_ReLU(256, 34 | 256, 35 | kernel_size=3, 36 | stride=1, 37 | padding=1) 38 | 39 | # Lateral layers 40 | self.latlayer1_ = Conv_BN_ReLU(1024, 41 | 256, 42 | kernel_size=1, 43 | stride=1, 44 | padding=0) 45 | 46 | self.latlayer2_ = Conv_BN_ReLU(512, 47 | 256, 48 | kernel_size=1, 49 | stride=1, 50 | padding=0) 51 | 52 | self.latlayer3_ = Conv_BN_ReLU(256, 53 | 256, 54 | kernel_size=1, 55 | stride=1, 56 | padding=0) 57 | 58 | for m in self.modules(): 59 | if isinstance(m, nn.Conv2d): 60 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 61 | m.weight.data.normal_(0, math.sqrt(2. / n)) 62 | elif isinstance(m, nn.BatchNorm2d): 63 | m.weight.data.fill_(1) 64 | m.bias.data.zero_() 65 | 66 | def _upsample(self, x, y, scale=1): 67 | _, _, H, W = y.size() 68 | return F.interpolate(x, size=(H // scale, W // scale), mode='bilinear') 69 | 70 | def _upsample_add(self, x, y): 71 | _, _, H, W = y.size() 72 | return F.interpolate(x, size=(H, W), mode='bilinear') + y 73 | 74 | def forward(self, f2, f3, f4, f5): 75 | p5 = self.toplayer_(f5) 76 | 77 | f4 = self.latlayer1_(f4) 78 | p4 = self._upsample_add(p5, f4) 79 | p4 = self.smooth1_(p4) 80 | 81 | f3 = self.latlayer2_(f3) 82 | p3 = self._upsample_add(p4, f3) 83 | p3 = self.smooth2_(p3) 84 | 85 | f2 = self.latlayer3_(f2) 86 | p2 = self._upsample_add(p3, f2) 87 | p2 = self.smooth3_(p2) 88 | 89 | p3 = self._upsample(p3, p2) 90 | p4 = self._upsample(p4, p2) 91 | p5 = self._upsample(p5, p2) 92 | 93 | return p2, p3, p4, p5 94 | -------------------------------------------------------------------------------- /models/pan.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .backbone import build_backbone 8 | from .head import build_head 9 | from .neck import build_neck 10 | from .utils import Conv_BN_ReLU 11 | 12 | 13 | class PAN(nn.Module): 14 | def __init__(self, backbone, neck, detection_head): 15 | super(PAN, self).__init__() 16 | self.backbone = build_backbone(backbone) 17 | 18 | in_channels = neck.in_channels 19 | self.reduce_layer1 = Conv_BN_ReLU(in_channels[0], 128) 20 | self.reduce_layer2 = Conv_BN_ReLU(in_channels[1], 128) 21 | self.reduce_layer3 = Conv_BN_ReLU(in_channels[2], 128) 22 | self.reduce_layer4 = Conv_BN_ReLU(in_channels[3], 128) 23 | 24 | self.fpem1 = 
build_neck(neck) 25 | self.fpem2 = build_neck(neck) 26 | 27 | self.det_head = build_head(detection_head) 28 | 29 | def _upsample(self, x, size, scale=1): 30 | _, _, H, W = size 31 | return F.interpolate(x, size=(H // scale, W // scale), mode='bilinear') 32 | 33 | def forward(self, 34 | imgs, 35 | gt_texts=None, 36 | gt_kernels=None, 37 | training_masks=None, 38 | gt_instances=None, 39 | gt_bboxes=None, 40 | img_metas=None, 41 | cfg=None): 42 | outputs = dict() 43 | 44 | if not self.training and cfg.report_speed: 45 | torch.cuda.synchronize() 46 | start = time.time() 47 | 48 | # backbone 49 | f = self.backbone(imgs) 50 | 51 | if not self.training and cfg.report_speed: 52 | torch.cuda.synchronize() 53 | outputs.update(dict(backbone_time=time.time() - start)) 54 | start = time.time() 55 | 56 | # reduce channel 57 | f1 = self.reduce_layer1(f[0]) 58 | f2 = self.reduce_layer2(f[1]) 59 | f3 = self.reduce_layer3(f[2]) 60 | f4 = self.reduce_layer4(f[3]) 61 | 62 | # FPEM 63 | f1_1, f2_1, f3_1, f4_1 = self.fpem1(f1, f2, f3, f4) 64 | f1_2, f2_2, f3_2, f4_2 = self.fpem2(f1_1, f2_1, f3_1, f4_1) 65 | 66 | # FFM 67 | f1 = f1_1 + f1_2 68 | f2 = f2_1 + f2_2 69 | f3 = f3_1 + f3_2 70 | f4 = f4_1 + f4_2 71 | f2 = self._upsample(f2, f1.size()) 72 | f3 = self._upsample(f3, f1.size()) 73 | f4 = self._upsample(f4, f1.size()) 74 | f = torch.cat((f1, f2, f3, f4), 1) 75 | 76 | if not self.training and cfg.report_speed: 77 | torch.cuda.synchronize() 78 | outputs.update(dict(neck_time=time.time() - start)) 79 | start = time.time() 80 | 81 | # detection 82 | det_out = self.det_head(f) 83 | 84 | if not self.training and cfg.report_speed: 85 | torch.cuda.synchronize() 86 | outputs.update(dict(det_head_time=time.time() - start)) 87 | 88 | if self.training: 89 | det_out = self._upsample(det_out, imgs.size()) 90 | det_loss = self.det_head.loss(det_out, gt_texts, gt_kernels, 91 | training_masks, gt_instances, 92 | gt_bboxes) 93 | outputs.update(det_loss) 94 | else: 95 | det_out = self._upsample(det_out, imgs.size(), 4) 96 | det_res = self.det_head.get_results(det_out, img_metas, cfg) 97 | outputs.update(det_res) 98 | 99 | return outputs 100 | -------------------------------------------------------------------------------- /models/pan_pp.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .backbone import build_backbone 8 | from .head import build_head 9 | from .neck import build_neck 10 | from .utils import Conv_BN_ReLU 11 | 12 | 13 | class PAN_PP(nn.Module): 14 | def __init__(self, backbone, neck, detection_head, recognition_head=None): 15 | super(PAN_PP, self).__init__() 16 | self.backbone = build_backbone(backbone) 17 | 18 | in_channels = neck.in_channels 19 | self.reduce_layer4 = Conv_BN_ReLU(in_channels[3], 128) 20 | self.reduce_layer3 = Conv_BN_ReLU(in_channels[2], 128) 21 | self.reduce_layer2 = Conv_BN_ReLU(in_channels[1], 128) 22 | self.reduce_layer1 = Conv_BN_ReLU(in_channels[0], 128) 23 | 24 | self.fpem1 = build_neck(neck) 25 | self.fpem2 = build_neck(neck) 26 | 27 | self.det_head = build_head(detection_head) 28 | self.rec_head = None 29 | if recognition_head: 30 | self.rec_head = build_head(recognition_head) 31 | 32 | def _upsample(self, x, size, scale=1): 33 | _, _, H, W = size 34 | return F.interpolate(x, size=(int(H // scale), int(W // scale)), mode='bilinear') 35 | 36 | def forward(self, 37 | imgs, 38 | gt_texts=None, 39 | gt_kernels=None, 40 | training_masks=None, 41 
| gt_instances=None, 42 | gt_bboxes=None, 43 | gt_words=None, 44 | word_masks=None, 45 | img_metas=None, 46 | cfg=None): 47 | # print(img_metas) 48 | # if cfg.debug: 49 | # from IPython import embed 50 | # embed() 51 | 52 | outputs = dict() 53 | # print(self.training) 54 | # print(not self.training and cfg.report_speed) 55 | if not self.training and cfg.report_speed: 56 | torch.cuda.synchronize() 57 | start = time.time() 58 | 59 | # backbone 60 | f = self.backbone(imgs) 61 | 62 | if not self.training and cfg.report_speed: 63 | torch.cuda.synchronize() 64 | outputs.update(dict(backbone_time=time.time() - start)) 65 | start = time.time() 66 | 67 | # reduce channel 68 | f1 = self.reduce_layer1(f[0]) 69 | f2 = self.reduce_layer2(f[1]) 70 | f3 = self.reduce_layer3(f[2]) 71 | f4 = self.reduce_layer4(f[3]) 72 | 73 | # FPEM 74 | f1, f2, f3, f4 = self.fpem1(f1, f2, f3, f4) 75 | f1, f2, f3, f4 = self.fpem2(f1, f2, f3, f4) 76 | 77 | # FFM 78 | f2 = self._upsample(f2, f1.size()) 79 | f3 = self._upsample(f3, f1.size()) 80 | f4 = self._upsample(f4, f1.size()) 81 | f = torch.cat((f1, f2, f3, f4), 1) 82 | 83 | if not self.training and cfg.report_speed: 84 | torch.cuda.synchronize() 85 | outputs.update(dict(neck_time=time.time() - start)) 86 | start = time.time() 87 | 88 | # detection 89 | out_det = self.det_head(f) 90 | 91 | if not self.training and cfg.report_speed: 92 | torch.cuda.synchronize() 93 | outputs.update(dict(det_head_time=time.time() - start)) 94 | start = time.time() 95 | 96 | if self.training: 97 | out_det = self._upsample(out_det, imgs.size()) 98 | loss_det = self.det_head.loss( 99 | out_det, gt_texts, gt_kernels, training_masks, 100 | gt_instances, gt_bboxes) 101 | outputs.update(loss_det) 102 | else: 103 | out_det = self._upsample(out_det, imgs.size(), cfg.test_cfg.scale) 104 | res_det = self.det_head.get_results(out_det, img_metas, cfg) 105 | outputs.update(res_det) 106 | 107 | if self.rec_head is not None: 108 | # print(gt_words) 109 | if self.training: 110 | x_crops, gt_words = self.rec_head.extract_feature( 111 | f, (imgs.size(2), imgs.size(3)), 112 | gt_instances * training_masks, gt_bboxes, gt_words, 113 | word_masks) 114 | 115 | if x_crops is not None: 116 | out_rec = self.rec_head(x_crops, gt_words) 117 | # ******************************************* 118 | # print(f'gt_words:{gt_words}')#out_rec:{out_rec} 119 | # print(f'out_rec:{out_rec.shape} gt_words:{gt_words.shape}') 120 | loss_rec = self.rec_head.loss(out_rec, gt_words, 121 | reduce=False) 122 | else: 123 | loss_rec = { 124 | 'loss_rec': f.new_full((1,), -1, dtype=torch.float32), 125 | 'acc_rec': f.new_full((1,), -1, dtype=torch.float32) 126 | } 127 | outputs.update(loss_rec) 128 | else: 129 | if len(res_det['bboxes']) > 0: 130 | x_crops, _ = self.rec_head.extract_feature( 131 | f, (imgs.size(2), imgs.size(3)), 132 | f.new_tensor(res_det['label'], 133 | dtype=torch.long).unsqueeze(0), 134 | bboxes=f.new_tensor(res_det['bboxes_h'], 135 | dtype=torch.long), 136 | unique_labels=res_det['instances']) 137 | words, word_scores = self.rec_head.forward(x_crops) 138 | else: 139 | words = [] 140 | word_scores = [] 141 | 142 | if cfg.report_speed: 143 | torch.cuda.synchronize() 144 | outputs.update(dict(rec_time=time.time() - start)) 145 | outputs.update( 146 | dict(words=words, word_scores=word_scores, label='')) 147 | 148 | return outputs 149 | -------------------------------------------------------------------------------- /models/post_processing/__init__.py: 
-------------------------------------------------------------------------------- 1 | # for PAN++ 2 | from .beam_search import BeamSearch 3 | from .pa import pa 4 | from .pse import pse 5 | 6 | __all__ = ['BeamSearch', 'pa', 'pse'] 7 | -------------------------------------------------------------------------------- /models/post_processing/beam_search/__init__.py: -------------------------------------------------------------------------------- 1 | from .beam_search import BeamSearch 2 | 3 | __all__ = ['BeamSearch'] 4 | -------------------------------------------------------------------------------- /models/post_processing/beam_search/beam_search.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .topk import TopK 4 | 5 | 6 | class BeamNode(object): 7 | def __init__(self, seq, state, score): 8 | self.seq = seq 9 | self.state = state 10 | self.score = score 11 | self.avg_score = score / len(seq) 12 | 13 | def __cmp__(self, other): 14 | if self.avg_score == other.avg_score: 15 | return 0 16 | elif self.avg_score < other.avg_score: 17 | return -1 18 | else: 19 | return 1 20 | 21 | def __lt__(self, other): 22 | return self.avg_score < other.avg_score 23 | 24 | def __eq__(self, other): 25 | return self.avg_score == other.avg_score 26 | 27 | 28 | class BeamSearch(object): 29 | """Class to generate sequences from an image-to-text model.""" 30 | def __init__(self, decode_step, eos, beam_size=2, max_seq_len=32): 31 | self.decode_step = decode_step 32 | self.eos = eos 33 | self.beam_size = beam_size 34 | self.max_seq_len = max_seq_len 35 | 36 | def beam_search(self, init_inputs, init_states): 37 | # self.beam_size = 1 38 | batch_size = len(init_inputs) 39 | part_seqs = [TopK(self.beam_size) for _ in range(batch_size)] 40 | comp_seqs = [TopK(self.beam_size) for _ in range(batch_size)] 41 | 42 | # print(init_inputs.shape, init_states.shape) 43 | words, scores, states = self.decode_step(init_inputs, 44 | init_states, 45 | k=self.beam_size) 46 | for batch_id in range(batch_size): 47 | for i in range(self.beam_size): 48 | node = BeamNode([words[batch_id][i]], 49 | states[:, :, batch_id, :], scores[batch_id][i]) 50 | part_seqs[batch_id].push(node) 51 | 52 | for t in range(self.max_seq_len - 1): 53 | part_seq_list = [] 54 | for p in part_seqs: 55 | part_seq_list.append(p.extract()) 56 | p.reset() 57 | 58 | inputs, states = [], [] 59 | for seq_list in part_seq_list: 60 | for node in seq_list: 61 | inputs.append(node.seq[-1]) 62 | states.append(node.state) 63 | if len(inputs) == 0: 64 | break 65 | 66 | inputs = torch.stack(inputs) 67 | states = torch.stack(states, dim=2) 68 | words, scores, states = self.decode_step(inputs, 69 | states, 70 | k=self.beam_size + 1) 71 | 72 | idx = 0 73 | for batch_id in range(batch_size): 74 | for node in part_seq_list[batch_id]: 75 | tmp_state = states[:, :, idx, :] 76 | k = 0 77 | num_hyp = 0 78 | while num_hyp < self.beam_size: 79 | word = words[idx][k] 80 | tmp_seq = node.seq + [word] 81 | tmp_score = node.score + scores[idx][k] 82 | tmp_node = BeamNode(tmp_seq, tmp_state, tmp_score) 83 | k += 1 84 | num_hyp += 1 85 | 86 | if word == self.eos: 87 | comp_seqs[batch_id].push(tmp_node) 88 | num_hyp -= 1 89 | else: 90 | part_seqs[batch_id].push(tmp_node) 91 | idx += 1 92 | 93 | for batch_id in range(batch_size): 94 | if not comp_seqs[batch_id].size(): 95 | comp_seqs[batch_id] = part_seqs[batch_id] 96 | seqs = [seq_list.extract(sort=True)[0].seq for seq_list in comp_seqs] 97 | seq_scores = [ 98 | 
seq_list.extract(sort=True)[0].avg_score for seq_list in comp_seqs 99 | ] 100 | return seqs, seq_scores 101 | -------------------------------------------------------------------------------- /models/post_processing/beam_search/topk.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | 3 | 4 | class TopK(object): 5 | def __init__(self, k): 6 | self.k = k 7 | self.data = [] 8 | 9 | def reset(self): 10 | self.data = [] 11 | 12 | def size(self): 13 | return len(self.data) 14 | 15 | def push(self, x): 16 | if len(self.data) < self.k: 17 | heapq.heappush(self.data, x) 18 | else: 19 | heapq.heappushpop(self.data, x) 20 | 21 | def extract(self, sort=False): 22 | if sort: 23 | self.data.sort(reverse=True) 24 | return self.data 25 | -------------------------------------------------------------------------------- /models/post_processing/pa/__init__.py: -------------------------------------------------------------------------------- 1 | from .pa import pa 2 | 3 | __all__ = ['pa'] 4 | -------------------------------------------------------------------------------- /models/post_processing/pa/pa.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | cimport numpy as np 5 | cimport cython 6 | cimport libcpp 7 | cimport libcpp.pair 8 | cimport libcpp.queue 9 | from libcpp.pair cimport * 10 | from libcpp.queue cimport * 11 | 12 | @cython.boundscheck(False) 13 | @cython.wraparound(False) 14 | cdef np.ndarray[np.int32_t, ndim=2] _pa(np.ndarray[np.uint8_t, ndim=3] kernels, 15 | np.ndarray[np.float32_t, ndim=3] emb, 16 | np.ndarray[np.int32_t, ndim=2] label, 17 | np.ndarray[np.int32_t, ndim=2] cc, 18 | int kernel_num, 19 | int label_num, 20 | float min_area=0): 21 | cdef np.ndarray[np.int32_t, ndim=2] pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32) 22 | cdef np.ndarray[np.float32_t, ndim=2] mean_emb = np.zeros((label_num, 4), dtype=np.float32) 23 | cdef np.ndarray[np.float32_t, ndim=1] area = np.full((label_num,), -1, dtype=np.float32) 24 | cdef np.ndarray[np.int32_t, ndim=1] flag = np.zeros((label_num,), dtype=np.int32) 25 | cdef np.ndarray[np.uint8_t, ndim=3] inds = np.zeros((label_num, label.shape[0], label.shape[1]), dtype=np.uint8) 26 | cdef np.ndarray[np.int32_t, ndim=2] p = np.zeros((label_num, 2), dtype=np.int32) 27 | 28 | cdef np.float32_t max_rate = 1024 29 | for i in range(1, label_num): 30 | ind = label == i 31 | inds[i] = ind 32 | 33 | area[i] = np.sum(ind) 34 | 35 | if area[i] < min_area: 36 | label[ind] = 0 37 | continue 38 | 39 | px, py = np.where(ind) 40 | p[i] = (px[0], py[0]) 41 | 42 | for j in range(1, i): 43 | if area[j] < min_area: 44 | continue 45 | if cc[p[i, 0], p[i, 1]] != cc[p[j, 0], p[j, 1]]: 46 | continue 47 | rate = area[i] / area[j] 48 | if rate < 1 / max_rate or rate > max_rate: # extreme area ratio: restrict the growth of both labels by embedding distance 49 | flag[i] = 1 50 | mean_emb[i] = np.mean(emb[:, ind], axis=1) 51 | 52 | if flag[j] == 0: 53 | flag[j] = 1 54 | mean_emb[j] = np.mean(emb[:, inds[j].astype(bool)], axis=1) # np.bool was removed in NumPy >= 1.24; plain bool is the portable spelling 55 | 56 | cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t, np.int16_t]] que = \ 57 | queue[libcpp.pair.pair[np.int16_t, np.int16_t]]() 58 | cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t, np.int16_t]] nxt_que = \ 59 | queue[libcpp.pair.pair[np.int16_t, np.int16_t]]() 60 | cdef np.int16_t*dx = [-1, 1, 0, 0] 61 | cdef np.int16_t*dy = [0, 0, -1, 1] 62 | cdef np.int16_t tmpx, tmpy 63 | 64 | points = np.array(np.where(label > 0)).transpose((1, 0)) 65 | for point_idx in
range(points.shape[0]): 66 | tmpx, tmpy = points[point_idx, 0], points[point_idx, 1] 67 | que.push(pair[np.int16_t, np.int16_t](tmpx, tmpy)) 68 | pred[tmpx, tmpy] = label[tmpx, tmpy] 69 | 70 | cdef libcpp.pair.pair[np.int16_t, np.int16_t] cur 71 | cdef int cur_label 72 | for kernel_idx in range(kernel_num - 2, -1, -1): 73 | while not que.empty(): 74 | cur = que.front() 75 | que.pop() 76 | cur_label = pred[cur.first, cur.second] 77 | 78 | is_edge = True 79 | for j in range(4): 80 | tmpx = cur.first + dx[j] 81 | tmpy = cur.second + dy[j] 82 | if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]: 83 | continue 84 | if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0: 85 | continue 86 | if flag[cur_label] == 1 and np.linalg.norm(emb[:, tmpx, tmpy] - mean_emb[cur_label]) > 3: 87 | continue 88 | 89 | que.push(pair[np.int16_t, np.int16_t](tmpx, tmpy)) 90 | pred[tmpx, tmpy] = cur_label 91 | is_edge = False 92 | if is_edge: 93 | nxt_que.push(cur) 94 | 95 | que, nxt_que = nxt_que, que 96 | 97 | return pred 98 | 99 | def pa(kernels, emb, min_area=0): 100 | kernel_num = kernels.shape[0] 101 | _, cc = cv2.connectedComponents(kernels[0], connectivity=4) 102 | label_num, label = cv2.connectedComponents(kernels[1], connectivity=4) 103 | 104 | return _pa(kernels[:-1], emb, label, cc, kernel_num, label_num, min_area) 105 | -------------------------------------------------------------------------------- /models/post_processing/pa/readme.txt: -------------------------------------------------------------------------------- 1 | python setup.py build_ext --inplace 2 | -------------------------------------------------------------------------------- /models/post_processing/pa/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import Extension, setup 2 | 3 | import numpy 4 | from Cython.Build import cythonize 5 | 6 | setup(ext_modules=cythonize( 7 | Extension('pa', 8 | sources=['pa.pyx'], 9 | language='c++', 10 | include_dirs=[numpy.get_include()], 11 | library_dirs=[], 12 | libraries=[], 13 | extra_compile_args=['-O3'], 14 | extra_link_args=[]))) 15 | -------------------------------------------------------------------------------- /models/post_processing/pse/__init__.py: -------------------------------------------------------------------------------- 1 | from .pse import pse 2 | 3 | __all__ = ['pse'] 4 | -------------------------------------------------------------------------------- /models/post_processing/pse/pse.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | cimport numpy as np 4 | cimport cython 5 | cimport libcpp 6 | cimport libcpp.pair 7 | cimport libcpp.queue 8 | from libcpp.pair cimport * 9 | from libcpp.queue cimport * 10 | 11 | @cython.boundscheck(False) 12 | @cython.wraparound(False) 13 | cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels, 14 | np.ndarray[np.int32_t, ndim=2] label, 15 | int kernel_num, 16 | int label_num, 17 | float min_area=0): 18 | cdef np.ndarray[np.int32_t, ndim=2] pred 19 | pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32) 20 | 21 | for label_idx in range(1, label_num): 22 | if np.sum(label == label_idx) < min_area: 23 | label[label == label_idx] = 0 24 | 25 | cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \ 26 | queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() 27 | cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que 
= \ 28 | queue[libcpp.pair.pair[np.int16_t,np.int16_t]]() 29 | cdef np.int16_t* dx = [-1, 1, 0, 0] 30 | cdef np.int16_t* dy = [0, 0, -1, 1] 31 | cdef np.int16_t tmpx, tmpy 32 | 33 | points = np.array(np.where(label > 0)).transpose((1, 0)) 34 | for point_idx in range(points.shape[0]): 35 | tmpx, tmpy = points[point_idx, 0], points[point_idx, 1] 36 | que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) 37 | pred[tmpx, tmpy] = label[tmpx, tmpy] 38 | 39 | cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur 40 | cdef int cur_label 41 | for kernel_idx in range(kernel_num - 1, -1, -1): 42 | while not que.empty(): 43 | cur = que.front() 44 | que.pop() 45 | cur_label = pred[cur.first, cur.second] 46 | 47 | is_edge = True 48 | for j in range(4): 49 | tmpx = cur.first + dx[j] 50 | tmpy = cur.second + dy[j] 51 | if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]: 52 | continue 53 | if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0: 54 | continue 55 | 56 | que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy)) 57 | pred[tmpx, tmpy] = cur_label 58 | is_edge = False 59 | if is_edge: 60 | nxt_que.push(cur) 61 | 62 | que, nxt_que = nxt_que, que 63 | 64 | return pred 65 | 66 | def pse(kernels, min_area): 67 | kernel_num = kernels.shape[0] 68 | label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4) 69 | return _pse(kernels[:-1], label, kernel_num, label_num, min_area) 70 | -------------------------------------------------------------------------------- /models/post_processing/pse/readme.txt: -------------------------------------------------------------------------------- 1 | python setup.py build_ext --inplace 2 | -------------------------------------------------------------------------------- /models/post_processing/pse/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import Extension, setup 2 | 3 | import numpy 4 | from Cython.Build import cythonize 5 | 6 | setup(ext_modules=cythonize( 7 | Extension('pse', 8 | sources=['pse.pyx'], 9 | language='c++', 10 | include_dirs=[numpy.get_include()], 11 | library_dirs=[], 12 | libraries=[], 13 | extra_compile_args=['-O3'], 14 | extra_link_args=[]))) 15 | -------------------------------------------------------------------------------- /models/psenet.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .backbone import build_backbone 8 | from .head import build_head 9 | from .neck import build_neck 10 | 11 | 12 | class PSENet(nn.Module): 13 | def __init__(self, backbone, neck, detection_head): 14 | super(PSENet, self).__init__() 15 | self.backbone = build_backbone(backbone) 16 | self.fpn = build_neck(neck) 17 | 18 | self.det_head = build_head(detection_head) 19 | 20 | def _upsample(self, x, size, scale=1): 21 | _, _, H, W = size 22 | return F.interpolate(x, size=(H // scale, W // scale), mode='bilinear') 23 | 24 | def forward(self, 25 | imgs, 26 | gt_texts=None, 27 | gt_kernels=None, 28 | training_masks=None, 29 | img_metas=None, 30 | cfg=None): 31 | outputs = dict() 32 | 33 | if not self.training and cfg.report_speed: 34 | torch.cuda.synchronize() 35 | start = time.time() 36 | 37 | # backbone 38 | f = self.backbone(imgs) 39 | if not self.training and cfg.report_speed: 40 | torch.cuda.synchronize() 41 | outputs.update(dict(backbone_time=time.time() - start)) 42 | start = time.time() 43 | 44 | # FPN 45 | f1, f2, f3, 
f4, = self.fpn(f[0], f[1], f[2], f[3]) 46 | 47 | f = torch.cat((f1, f2, f3, f4), 1) 48 | 49 | if not self.training and cfg.report_speed: 50 | torch.cuda.synchronize() 51 | outputs.update(dict(neck_time=time.time() - start)) 52 | start = time.time() 53 | 54 | # detection 55 | det_out = self.det_head(f) 56 | 57 | if not self.training and cfg.report_speed: 58 | torch.cuda.synchronize() 59 | outputs.update(dict(det_head_time=time.time() - start)) 60 | 61 | if self.training: 62 | det_out = self._upsample(det_out, imgs.size()) 63 | det_loss = self.det_head.loss(det_out, gt_texts, gt_kernels, 64 | training_masks) 65 | outputs.update(det_loss) 66 | else: 67 | det_out = self._upsample(det_out, imgs.size(), 1) 68 | det_res = self.det_head.get_results(det_out, img_metas, cfg) 69 | outputs.update(det_res) 70 | 71 | return outputs 72 | -------------------------------------------------------------------------------- /models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv_bn_relu import Conv_BN_ReLU 2 | # for PAN++ 3 | from .coordconv import CoordConv2d 4 | from .fuse_conv_bn import fuse_module 5 | 6 | __all__ = ['Conv_BN_ReLU', 'CoordConv2d', 'fuse_module'] 7 | -------------------------------------------------------------------------------- /models/utils/conv_bn_relu.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class Conv_BN_ReLU(nn.Module): 7 | def __init__(self, 8 | in_planes, 9 | out_planes, 10 | kernel_size=1, 11 | stride=1, 12 | padding=0): 13 | super(Conv_BN_ReLU, self).__init__() 14 | self.conv = nn.Conv2d(in_planes, 15 | out_planes, 16 | kernel_size=kernel_size, 17 | stride=stride, 18 | padding=padding, 19 | bias=False) 20 | self.bn = nn.BatchNorm2d(out_planes) 21 | self.relu = nn.ReLU(inplace=True) 22 | 23 | for m in self.modules(): 24 | if isinstance(m, nn.Conv2d): 25 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 26 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 27 | elif isinstance(m, nn.BatchNorm2d): 28 | m.weight.data.fill_(1) 29 | m.bias.data.zero_() 30 | 31 | def forward(self, x): 32 | return self.relu(self.bn(self.conv(x))) 33 | -------------------------------------------------------------------------------- /models/utils/coordconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.modules.conv as conv 4 | 5 | 6 | class AddCoords(nn.Module): 7 | def __init__(self, rank, with_r=False, use_cuda=True): 8 | super(AddCoords, self).__init__() 9 | self.rank = rank 10 | self.with_r = with_r 11 | self.use_cuda = use_cuda 12 | 13 | def forward(self, input_tensor): 14 | if self.rank == 1: 15 | batch_size_shape, channel_in_shape, dim_x = input_tensor.shape 16 | xx_range = torch.arange(dim_x, dtype=torch.int32) 17 | xx_channel = xx_range[None, None, :] 18 | 19 | xx_channel = xx_channel.float() / (dim_x - 1) 20 | xx_channel = xx_channel * 2 - 1 21 | xx_channel = xx_channel.repeat(batch_size_shape, 1, 1) 22 | 23 | if torch.cuda.is_available() and self.use_cuda: # note the (): the bare attribute is a function object and always truthy 24 | input_tensor = input_tensor.cuda() 25 | xx_channel = xx_channel.cuda() 26 | out = torch.cat([input_tensor, xx_channel], dim=1) 27 | 28 | if self.with_r: 29 | rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2)) 30 | out = torch.cat([out, rr], dim=1) 31 | 32 | elif self.rank == 2: 33 | batch_size_shape, channel_in_shape,\ 34 | dim_y, dim_x = input_tensor.shape 35 | xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32) 36 | yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32) 37 | 38 | xx_range = torch.arange(dim_y, dtype=torch.int32) 39 | yy_range = torch.arange(dim_x, dtype=torch.int32) 40 | xx_range = xx_range[None, None, :, None] 41 | yy_range = yy_range[None, None, :, None] 42 | 43 | xx_channel = torch.matmul(xx_range, xx_ones) 44 | yy_channel = torch.matmul(yy_range, yy_ones) 45 | 46 | # transpose y 47 | yy_channel = yy_channel.permute(0, 1, 3, 2) 48 | 49 | xx_channel = xx_channel.float() / (dim_y - 1) 50 | yy_channel = yy_channel.float() / (dim_x - 1) 51 | 52 | xx_channel = xx_channel * 2 - 1 53 | yy_channel = yy_channel * 2 - 1 54 | 55 | xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1) 56 | yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1) 57 | 58 | if torch.cuda.is_available() and self.use_cuda: 59 | input_tensor = input_tensor.cuda() 60 | xx_channel = xx_channel.cuda() 61 | yy_channel = yy_channel.cuda() 62 | out = torch.cat([input_tensor, xx_channel, yy_channel], dim=1) 63 | 64 | if self.with_r: 65 | rr = torch.sqrt( 66 | torch.pow(xx_channel - 0.5, 2) + 67 | torch.pow(yy_channel - 0.5, 2)) 68 | out = torch.cat([out, rr], dim=1) 69 | 70 | elif self.rank == 3: 71 | batch_size_shape, channel_in_shape, \ 72 | dim_z, dim_y, dim_x = input_tensor.shape 73 | xx_ones = torch.ones([1, 1, 1, 1, dim_x], dtype=torch.int32) 74 | yy_ones = torch.ones([1, 1, 1, 1, dim_y], dtype=torch.int32) 75 | zz_ones = torch.ones([1, 1, 1, 1, dim_z], dtype=torch.int32) 76 | 77 | xy_range = torch.arange(dim_y, dtype=torch.int32) 78 | xy_range = xy_range[None, None, None, :, None] 79 | 80 | yz_range = torch.arange(dim_z, dtype=torch.int32) 81 | yz_range = yz_range[None, None, None, :, None] 82 | 83 | zx_range = torch.arange(dim_x, dtype=torch.int32) 84 | zx_range = zx_range[None, None, None, :, None] 85 | 86 | xy_channel = torch.matmul(xy_range, xx_ones) 87 | xx_channel = torch.cat([xy_channel + i for i in range(dim_z)], 88 | dim=2) 89 | 90 | yz_channel = torch.matmul(yz_range, yy_ones) 91 |
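# shape note (added annotation): yz_channel is (1, 1, 1, dim_z, dim_y) here; the permute below rearranges it to (1, 1, dim_z, dim_y, 1) so the cat that follows can expand the last axis to dim_x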
yz_channel = yz_channel.permute(0, 1, 3, 4, 2) 92 | yy_channel = torch.cat([yz_channel + i for i in range(dim_x)], 93 | dim=4) 94 | 95 | zx_channel = torch.matmul(zx_range, zz_ones) 96 | zx_channel = zx_channel.permute(0, 1, 4, 2, 3) 97 | zz_channel = torch.cat([zx_channel + i for i in range(dim_y)], 98 | dim=3) 99 | 100 | if torch.cuda.is_available() and self.use_cuda: 101 | input_tensor = input_tensor.cuda() 102 | xx_channel = xx_channel.cuda() 103 | yy_channel = yy_channel.cuda() 104 | zz_channel = zz_channel.cuda() 105 | out = torch.cat([input_tensor, xx_channel, yy_channel, zz_channel], 106 | dim=1) 107 | 108 | if self.with_r: 109 | rr = torch.sqrt( 110 | torch.pow(xx_channel - 0.5, 2) + 111 | torch.pow(yy_channel - 0.5, 2) + 112 | torch.pow(zz_channel - 0.5, 2)) 113 | out = torch.cat([out, rr], dim=1) 114 | else: 115 | raise NotImplementedError 116 | 117 | return out 118 | 119 | 120 | class CoordConv1d(conv.Conv1d): 121 | def __init__(self, 122 | in_channels, 123 | out_channels, 124 | kernel_size, 125 | stride=1, 126 | padding=0, 127 | dilation=1, 128 | groups=1, 129 | bias=True, 130 | with_r=False, 131 | use_cuda=True): 132 | super(CoordConv1d, 133 | self).__init__(in_channels, out_channels, kernel_size, stride, 134 | padding, dilation, groups, bias) 135 | self.rank = 1 136 | self.addcoords = AddCoords(self.rank, with_r, use_cuda=use_cuda) 137 | self.conv = nn.Conv1d(in_channels + self.rank + int(with_r), 138 | out_channels, kernel_size, stride, padding, 139 | dilation, groups, bias) 140 | 141 | def forward(self, input_tensor): 142 | out = self.addcoords(input_tensor) 143 | out = self.conv(out) 144 | 145 | return out 146 | 147 | 148 | class CoordConv2d(conv.Conv2d): 149 | def __init__(self, 150 | in_channels, 151 | out_channels, 152 | kernel_size, 153 | stride=1, 154 | padding=0, 155 | dilation=1, 156 | groups=1, 157 | bias=True, 158 | with_r=False, 159 | use_cuda=True): 160 | super(CoordConv2d, 161 | self).__init__(in_channels, out_channels, kernel_size, stride, 162 | padding, dilation, groups, bias) 163 | self.rank = 2 164 | self.addcoords = AddCoords(self.rank, with_r, use_cuda=use_cuda) 165 | self.conv = nn.Conv2d(in_channels + self.rank + int(with_r), 166 | out_channels, kernel_size, stride, padding, 167 | dilation, groups, bias) 168 | 169 | def forward(self, input_tensor): 170 | out = self.addcoords(input_tensor) 171 | out = self.conv(out) 172 | 173 | return out 174 | 175 | 176 | class CoordConv3d(conv.Conv3d): 177 | def __init__(self, 178 | in_channels, 179 | out_channels, 180 | kernel_size, 181 | stride=1, 182 | padding=0, 183 | dilation=1, 184 | groups=1, 185 | bias=True, 186 | with_r=False, 187 | use_cuda=True): 188 | super(CoordConv3d, 189 | self).__init__(in_channels, out_channels, kernel_size, stride, 190 | padding, dilation, groups, bias) 191 | self.rank = 3 192 | self.addcoords = AddCoords(self.rank, with_r, use_cuda=use_cuda) 193 | self.conv = nn.Conv3d(in_channels + self.rank + int(with_r), 194 | out_channels, kernel_size, stride, padding, 195 | dilation, groups, bias) 196 | 197 | def forward(self, input_tensor): 198 | out = self.addcoords(input_tensor) 199 | out = self.conv(out) 200 | 201 | return out 202 | -------------------------------------------------------------------------------- /models/utils/fuse_conv_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def fuse_conv_bn(conv, bn): 6 | """During inference, the functionality of batch norm layers is turned off; 7 | only the running means and variances along channels are used, which makes it 8 | possible to fuse BN into the preceding conv layer to save computation and 9 | simplify the network structure."""
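# Folding math, for reference: with bn(z) = gamma * (z - mu) / sqrt(var + eps) + beta, the fused conv uses W' = W * factor and b' = (b - mu) * factor + beta, where factor = gamma / sqrt(var + eps) is computed per output channel below.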
10 | conv_w = conv.weight 11 | conv_b = conv.bias if conv.bias is not None else torch.zeros_like( 12 | bn.running_mean) 13 | 14 | factor = bn.weight / torch.sqrt(bn.running_var + bn.eps) 15 | conv.weight = nn.Parameter(conv_w * 16 | factor.reshape([conv.out_channels, 1, 1, 1])) 17 | conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias) 18 | return conv 19 | 20 | 21 | def fuse_module(m): 22 | last_conv = None 23 | last_conv_name = None 24 | 25 | for name, child in m.named_children(): 26 | if isinstance(child, (nn.BatchNorm2d, nn.SyncBatchNorm)): 27 | if last_conv is None: # only fuse BN that is after Conv 28 | continue 29 | fused_conv = fuse_conv_bn(last_conv, child) 30 | m._modules[last_conv_name] = fused_conv 31 | # To reduce changes, set BN as Identity instead of deleting it. 32 | m._modules[name] = nn.Identity() 33 | last_conv = None 34 | elif isinstance(child, nn.Conv2d): 35 | last_conv = child 36 | last_conv_name = name 37 | else: 38 | fuse_module(child) 39 | return m 40 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mmcv==1.3.1 2 | imgaug==0.4.0 3 | Polygon3 4 | pyclipper 5 | imutils==0.5.4 6 | Cython 7 | editdistance 8 | torch 9 | torchvision 10 | tqdm 11 | tensorboard 12 | opencv-python==4.5.1.48 13 | scipy 14 | matplotlib 15 | pillow==8.3.0 16 | -------------------------------------------------------------------------------- /train_pan_pp/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import os.path as osp 5 | import random 6 | import time 7 | import numpy as np 8 | import torch 9 | import os 10 | import sys 11 | BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 12 | sys.path.append(BASE_DIR) 13 | from mmcv import Config 14 | from models import build_model 15 | from dataset import build_data_loader 16 | 17 | from utils import AverageMeter 18 | import warnings 19 | 20 | warnings.filterwarnings("ignore") 21 | torch.manual_seed(9797) 22 | torch.cuda.manual_seed(9797) 23 | np.random.seed(9797) 24 | random.seed(9797) 25 | EPS = 1e-6 26 | 27 | from torch.utils.tensorboard import SummaryWriter 28 | 29 | 30 | 31 | def init_tensorboard(out_dir: str = 'logs'): 32 | if not os.path.exists(out_dir): # create the log directory if it does not exist 33 | os.makedirs(out_dir) 34 | 35 | writer = SummaryWriter(log_dir=out_dir) 36 | ''' 37 | https://pytorch.org/docs/stable/tensorboard.html 38 | writer. 39 | add_scalar(tag, scalar_value, global_step=None, walltime=None, new_style=False, double_precision=False) 40 | add_scalars(main_tag, tag_scalar_dict, global_step=None, walltime=None) 41 | add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW') 42 | add_images(tag, img_tensor, global_step=None, walltime=None, dataformats='NCHW') 43 | ''' 44 | 45 | # writer.close() must be called at the end (handled in the __main__ finally block below) 46 | return writer 47 |
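# Usage in this file: the __main__ block below creates the module-level `writer` via init_tensorboard('./tblogs'); train()/eval() then log with calls such as writer.add_scalar('Train-Loss', losses.avg, global_step=epoch).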
48 | def eval(train_loader, model, epoch, cfg): 49 | with torch.no_grad(): 50 | model.train() # kept in train mode so the heads return losses; torch.no_grad() still disables gradients 51 | 52 | # meters 53 | batch_time = AverageMeter(max_len=500) 54 | data_time = AverageMeter(max_len=500) 55 | losses = AverageMeter(max_len=500) 56 | losses_text = AverageMeter(max_len=500) 57 | losses_kernels = AverageMeter(max_len=500) 58 | losses_emb = AverageMeter(max_len=500) 59 | ious_text = AverageMeter(max_len=500) 60 | ious_kernel = AverageMeter(max_len=500) 61 | 62 | # start time 63 | start = time.time() 64 | for iter, data in enumerate(train_loader): 65 | # (unlike train(), no iterations are skipped here) 66 | 67 | # time cost of data loader 68 | data_time.update(time.time() - start) 69 | 70 | # prepare input 71 | data.update(dict(cfg=cfg)) 72 | 73 | # forward 74 | outputs = model(**data) 75 | # ************************************************************************* 76 | # print(outputs) 77 | 78 | # detection loss 79 | loss_text = torch.mean(outputs['loss_text']) 80 | losses_text.update(loss_text.item(), data['imgs'].size(0)) 81 | 82 | loss_kernels = torch.mean(outputs['loss_kernels']) 83 | losses_kernels.update(loss_kernels.item(), data['imgs'].size(0)) 84 | if 'loss_emb' in outputs.keys(): 85 | loss_emb = torch.mean(outputs['loss_emb']) 86 | losses_emb.update(loss_emb.item(), data['imgs'].size(0)) 87 | loss = loss_text + loss_kernels + loss_emb 88 | else: 89 | loss = loss_text + loss_kernels 90 | 91 | iou_text = torch.mean(outputs['iou_text']) 92 | ious_text.update(iou_text.item(), data['imgs'].size(0)) 93 | iou_kernel = torch.mean(outputs['iou_kernel']) 94 | ious_kernel.update(iou_kernel.item(), data['imgs'].size(0)) 95 | 96 | losses.update(loss.item(), data['imgs'].size(0)) 97 | 98 | batch_time.update(time.time() - start) 99 | 100 | # update start time 101 | start = time.time() 102 | 103 | 104 | 105 | # print log 106 | if True: 107 | writer.add_scalar('Val-Loss', losses.avg, global_step=epoch) 108 | writer.add_scalar('Val-Loss-text', losses_text.avg, global_step=epoch) 109 | writer.add_scalar('Val-Loss-kernel', losses_kernels.avg, global_step=epoch) 110 | writer.add_scalar('Val-Loss-emb', losses_emb.avg, global_step=epoch) 111 | # writer.add_scalar('Loss-rec', losses_rec.avg, global_step=epoch) 112 | 113 | writer.add_scalar('Val-IoU(text)', ious_text.avg, global_step=epoch) 114 | writer.add_scalar('Val-IoU(kernel)', ious_kernel.avg, global_step=epoch) 115 | # writer.add_scalar('ACC rec', accs_rec.avg, global_step=epoch) 116 | 117 | log = 'EVALUATION-LOG | ' \ 118 | f'Total-Time: {batch_time.avg * iter / 60.0:.0f}min | ' \ 119 | f'Loss: {losses.avg:.3f} | ' \ 120 | f'IoU(text/kernel): {ious_text.avg:.3f}/{ious_kernel.avg:.3f}' #\ 121 | 122 | # f'{" | ACC rec: " + format(accs_rec.avg, ".3f") if with_rec else ""}' 123 | print(log, flush=True) 124 | 125 | 126 | 127 | gs = 0 # global step 128 | def train(train_loader, model, optimizer, epoch, start_iter, cfg): 129 | model.train() 130 | 131 | # meters 132 | batch_time = AverageMeter(max_len=500) 133 | data_time = AverageMeter(max_len=500) 134 | 135 | losses = AverageMeter(max_len=500) 136 | losses_text = AverageMeter(max_len=500) 137 |
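# NB: max_len=500 gives each meter a sliding window, so the logged averages track roughly the last 500 updates rather than the whole epoch (see utils/average_meter.py).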
losses_kernels = AverageMeter(max_len=500) 138 | losses_emb = AverageMeter(max_len=500) 139 | losses_rec = AverageMeter(max_len=500) 140 | 141 | ious_text = AverageMeter(max_len=500) 142 | ious_kernel = AverageMeter(max_len=500) 143 | accs_rec = AverageMeter(max_len=500) 144 | 145 | with_rec = hasattr(cfg.model, 'recognition_head') 146 | 147 | # start time 148 | start = time.time() 149 | for iter, data in enumerate(train_loader): 150 | # skip previous iterations 151 | if iter < start_iter: 152 | print('Skipping iter: %d' % iter) 153 | continue 154 | 155 | # time cost of data loader 156 | data_time.update(time.time() - start) 157 | 158 | # adjust learning rate 159 | adjust_learning_rate(optimizer, train_loader, epoch, iter, cfg) 160 | 161 | # prepare input 162 | data.update(dict(cfg=cfg)) 163 | 164 | outputs = model(**data) 165 | 166 | loss_text = torch.mean(outputs['loss_text']) 167 | losses_text.update(loss_text.item(), data['imgs'].size(0)) 168 | 169 | loss_kernels = torch.mean(outputs['loss_kernels']) 170 | losses_kernels.update(loss_kernels.item(), data['imgs'].size(0)) 171 | if 'loss_emb' in outputs.keys(): 172 | loss_emb = torch.mean(outputs['loss_emb']) 173 | losses_emb.update(loss_emb.item(), data['imgs'].size(0)) 174 | loss = loss_text + loss_kernels + loss_emb 175 | else: 176 | loss = loss_text + loss_kernels 177 | 178 | iou_text = torch.mean(outputs['iou_text']) 179 | ious_text.update(iou_text.item(), data['imgs'].size(0)) 180 | iou_kernel = torch.mean(outputs['iou_kernel']) 181 | ious_kernel.update(iou_kernel.item(), data['imgs'].size(0)) 182 | 183 | # recognition loss 184 | if with_rec: 185 | loss_rec = outputs['loss_rec'] 186 | valid = loss_rec > -EPS 187 | if torch.sum(valid) > 0: 188 | loss_rec = torch.mean(loss_rec[valid]) 189 | losses_rec.update(loss_rec.item(), data['imgs'].size(0)) 190 | loss = loss + loss_rec 191 | 192 | acc_rec = outputs['acc_rec'] 193 | acc_rec = torch.mean(acc_rec[valid]) 194 | accs_rec.update(acc_rec.item(), torch.sum(valid).item()) 195 | 196 | # if cfg.debug: 197 | # from IPython import embed 198 | # embed() 199 | 200 | losses.update(loss.item(), data['imgs'].size(0)) 201 | 202 | # backward 203 | optimizer.zero_grad() 204 | loss.backward() 205 | optimizer.step() 206 | 207 | batch_time.update(time.time() - start) 208 | 209 | # update start time 210 | start = time.time() 211 | 212 | 213 | if iter % 20 == 0: 214 | length = len(train_loader) 215 | log = f'({iter + 1}/{length}) ' \ 216 | f'LR: {optimizer.param_groups[0]["lr"]:.6f} | ' \ 217 | f'Batch: {batch_time.avg:.3f}s | ' \ 218 | f'Total: {batch_time.avg * iter / 60.0:.0f}min | ' \ 219 | f'ETA: {batch_time.avg * (length - iter) / 60.0:.0f}min | ' \ 220 | f'Loss: {losses.avg:.3f} | ' \ 221 | f'Loss(text/kernel/emb{"/rec" if with_rec else ""}): ' \ 222 | f'{losses_text.avg:.3f}/{losses_kernels.avg:.3f}/' \ 223 | f'{losses_emb.avg:.3f}' \ 224 | f'{"/" + format(losses_rec.avg, ".3f") if with_rec else ""} | ' \ 225 | f'IoU(text/kernel): {ious_text.avg:.3f}/{ious_kernel.avg:.3f}' \ 226 | f'{" | ACC rec: " + format(accs_rec.avg, ".3f") if with_rec else ""}' 227 | print(log, flush=True) 228 | # print log 229 | 230 | # if iter == len(train_loader)-1: 231 | # gs = epoch*1000 + iter 232 | gs = epoch 233 | writer.add_scalar('Train-LR', optimizer.param_groups[0]["lr"], global_step=gs) 234 | writer.add_scalar('Train-Loss', losses.avg, global_step=gs) 235 | writer.add_scalar('Train-Loss-text', losses_text.avg, global_step=gs) 236 | writer.add_scalar('Train-Loss-kernel', losses_kernels.avg, global_step=gs) 237 | 
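# Train-Loss-emb / Train-Loss-rec below keep their initial 0.0 average unless the embedding / recognition heads are enabled for this run.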
writer.add_scalar('Train-Loss-emb', losses_emb.avg, global_step=gs) 238 | writer.add_scalar('Train-Loss-rec', losses_rec.avg, global_step=gs) 239 | writer.add_scalar('Train-IoU(text)', ious_text.avg, global_step=gs) 240 | writer.add_scalar('Train-IoU(kernel)', ious_kernel.avg, global_step=gs) 241 | writer.add_scalar('Train-ACC rec', accs_rec.avg, global_step=gs) 242 | 243 | 244 | def adjust_learning_rate(optimizer, dataloader, epoch, iter, cfg): 245 | schedule = cfg.train_cfg.schedule 246 | 247 | if isinstance(schedule, str): 248 | assert schedule == 'polylr', 'Error: schedule should be polylr!' 249 | cur_iter = epoch * len(dataloader) + iter 250 | max_iter_num = cfg.train_cfg.epoch * len(dataloader) 251 | lr = cfg.train_cfg.lr * (1.0 - float(cur_iter) / max_iter_num) ** 0.9 252 | elif isinstance(schedule, tuple): 253 | lr = cfg.train_cfg.lr 254 | for i in range(len(schedule)): 255 | if epoch < schedule[i]: 256 | break 257 | lr = lr * 0.1 258 | 259 | for param_group in optimizer.param_groups: 260 | param_group['lr'] = lr 261 | 262 | 263 | def save_checkpoint(state, checkpoint_path, cfg): 264 | file_path = osp.join(checkpoint_path, 'checkpoint.pth.tar') 265 | torch.save(state, file_path) 266 | 267 | if cfg.data.train.type in ['synth'] or \ 268 | (state['iter'] == 0 and 269 | state['epoch'] % 5 == 0): 270 | file_name = 'checkpoint_%dep.pth.tar' % state['epoch'] 271 | file_path = osp.join(checkpoint_path, file_name) 272 | torch.save(state, file_path) 273 | 274 | 275 | def main(args): 276 | cfg = Config.fromfile(args.config) 277 | cfg.update(dict(debug=args.debug)) 278 | cfg.data.train.update(dict(debug=args.debug)) 279 | cfg.report_speed = False 280 | #print log 281 | print(json.dumps(cfg._cfg_dict, indent=4)) 282 | 283 | if args.checkpoint is not None: 284 | checkpoint_path = args.checkpoint 285 | else: 286 | cfg_name, _ = osp.splitext(osp.basename(args.config)) 287 | checkpoint_path = osp.join('checkpoints', cfg_name) 288 | if not osp.isdir(checkpoint_path): 289 | os.makedirs(checkpoint_path) 290 | 291 | # data loader 292 | data_loader = build_data_loader(cfg.data.train) 293 | train_size = int(0.8 * len(data_loader)) 294 | test_size = len(data_loader) - train_size 295 | train_dataset, test_dataset = torch.utils.data.random_split(data_loader, [train_size, test_size]) 296 | print(f'data_loader_type:{type(data_loader)}') 297 | train_loader = torch.utils.data.DataLoader( 298 | train_dataset, 299 | batch_size=cfg.data.batch_size, 300 | shuffle=not cfg.debug, 301 | num_workers=4, 302 | drop_last=True, 303 | pin_memory=True) 304 | eval_loader = torch.utils.data.DataLoader( 305 | test_dataset, 306 | batch_size=cfg.data.batch_size, 307 | shuffle=False, 308 | num_workers=4, 309 | drop_last=True, 310 | pin_memory=True) 311 | 312 | # model 313 | if hasattr(cfg.model, 'recognition_head'): 314 | cfg.model.recognition_head.update( 315 | dict( 316 | voc=data_loader.voc, 317 | char2id=data_loader.char2id, 318 | id2char=data_loader.id2char, 319 | )) 320 | model = build_model(cfg.model) 321 | 322 | if cfg.debug: 323 | # from IPython import embed; embed() 324 | checkpoint = torch.load('checkpoints/tmp.pth.tar') 325 | model.load_state_dict(checkpoint['state_dict']) 326 | 327 | model = torch.nn.DataParallel(model).cuda() 328 | 329 | # Check if model has custom optimizer / loss 330 | if hasattr(model.module, 'optimizer'): 331 | optimizer = model.module.optimizer 332 | else: 333 | if cfg.train_cfg.optimizer == 'SGD': 334 | optimizer = torch.optim.SGD(model.parameters(), 335 | lr=cfg.train_cfg.lr, 336 | 
momentum=0.99, 337 | weight_decay=5e-4) 338 | elif cfg.train_cfg.optimizer == 'Adam': 339 | optimizer = torch.optim.Adam(model.parameters(), 340 | lr=cfg.train_cfg.lr)  # use the configured LR instead of a hardcoded 0.001 341 | 342 | start_epoch = 0 343 | start_iter = 0 344 | if hasattr(cfg.train_cfg, 'pretrain'): 345 | assert osp.isfile( 346 | cfg.train_cfg.pretrain), 'Error: no pretrained weights found!' 347 | print('Finetuning from pretrained model %s.' % cfg.train_cfg.pretrain) 348 | checkpoint = torch.load(cfg.train_cfg.pretrain) 349 | model.load_state_dict(checkpoint['state_dict']) 350 | if args.resume: 351 | assert osp.isfile(args.resume), 'Error: no checkpoint directory found!' 352 | print('Resuming from checkpoint %s.' % args.resume) 353 | checkpoint = torch.load(args.resume) 354 | # print(checkpoint) 355 | start_epoch = checkpoint['epoch'] 356 | start_iter = checkpoint['iter'] 357 | model.load_state_dict(checkpoint['state_dict']) 358 | optimizer.load_state_dict(checkpoint['optimizer']) 359 | 360 | for epoch in range(start_epoch, cfg.train_cfg.epoch): 361 | print('\nEpoch: [%d | %d]' % (epoch + 1, cfg.train_cfg.epoch)) 362 | # eval(eval_loader, model, epoch, cfg) 363 | train(train_loader, model, optimizer, epoch, start_iter, cfg) 364 | if epoch % 5 == 0: 365 | eval(eval_loader, model, epoch, cfg) 366 | state = dict(epoch=epoch + 1, 367 | iter=0, 368 | state_dict=model.state_dict(), 369 | optimizer=optimizer.state_dict()) 370 | 371 | save_checkpoint(state, checkpoint_path, cfg=cfg) 372 | 373 | 374 | if __name__ == '__main__': 375 | try: 376 | writer = init_tensorboard('./tblogs') 377 | parser = argparse.ArgumentParser(description='Hyperparams') 378 | parser.add_argument('--config', help='config file path', default='config/pan_pp/R18-AUG.py') 379 | parser.add_argument('--checkpoint', nargs='?', type=str, default=None) 380 | parser.add_argument('--resume', nargs='?', type=str, default=None) 381 | parser.add_argument('--debug', action='store_true') 382 | args = parser.parse_args() 383 | main(args) 384 | finally: 385 | writer.close() 386 | -------------------------------------------------------------------------------- /train_pan_pp/train.sh: -------------------------------------------------------------------------------- 1 | # You may follow the instructions from the vanilla repo of PAN++ [https://github.com/whai362/pan_pp.pytorch] to get started!
2 | # Edit the path below if needed 3 | CONFIG_PTH='/opt/data/private/AlphX-Code-For-DAR/config/pan_pp/R18-AUG.py' 4 | python train_pan_pp/train.py --config ${CONFIG_PTH} -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .average_meter import AverageMeter 2 | from .corrector import Corrector 3 | from .logger import Logger 4 | from .result_format import ResultFormat 5 | from .visualizer import Visualizer 6 | 7 | __all__ = ['AverageMeter', 'Corrector', 'Logger', 'ResultFormat', 'Visualizer'] 8 | -------------------------------------------------------------------------------- /utils/average_meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value.""" 3 | def __init__(self, max_len=-1): 4 | self.val = [] 5 | self.count = [] 6 | self.max_len = max_len 7 | self.avg = 0 8 | 9 | def update(self, val, n=1): 10 | self.val.append(val * n) 11 | self.count.append(n) 12 | if self.max_len > 0 and len(self.val) > self.max_len: 13 | self.val = self.val[-self.max_len:] 14 | self.count = self.count[-self.max_len:] 15 | self.avg = sum(self.val) / sum(self.count) 16 | -------------------------------------------------------------------------------- /utils/corrector.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import editdistance 3 | import mmcv 4 | 5 | 6 | class Corrector: 7 | def __init__(self, 8 | data_type, 9 | len_thres, 10 | score_thres, 11 | unalpha_score_thres, 12 | ignore_score_thres, 13 | edit_dist_thres=0, 14 | edit_dist_score_thres=0, 15 | voc_type=None, 16 | voc_path=None): 17 | 18 | self.len_thres = len_thres 19 | self.score_thres = score_thres 20 | self.unalpha_score_thres = unalpha_score_thres 21 | self.ignore_score_thres = ignore_score_thres 22 | self.edit_dist_thres = edit_dist_thres 23 | self.edit_dist_score_thres = edit_dist_score_thres 24 | self.voc_type = voc_type 25 | self.voc = self.load_voc(data_type, voc_type, voc_path) 26 | 27 | def process(self, img_metas, outputs): 28 | img_name = img_metas['img_name'][0] 29 | words = outputs['words'] 30 | word_scores = outputs['word_scores'] 31 | words = [ 32 | self.correct( 33 | word, score, 34 | self.voc if self.voc_type != 's' \ 35 | else self.voc['voc_%s.txt' % img_name] 36 | ) 37 | for word, score in zip(words, word_scores) 38 | ] 39 | outputs.update(dict(words=words)) 40 | return outputs 41 | 42 | @staticmethod 43 | def _prefix_score(a, b): 44 | prefix_s = 0 45 | for i in range(min(len(a), len(b))): 46 | if a[i] == b[i]: 47 | prefix_s += 1.0 / (i + 1) 48 | return prefix_s 49 | 50 | def correct(self, word, score, voc=None): 51 | # print(voc is None) 52 | if len(word) < self.len_thres: 53 | return None 54 | if score > self.score_thres: 55 | return word 56 | # if not word.isalpha(): 57 | # if score > self.unalpha_score_thres: 58 | # return word 59 | # return None 60 | 61 | if score < self.ignore_score_thres: 62 | return None 63 | 64 | if voc is not None: 65 | min_d = 1e10 66 | matched = '' 67 | for voc_word in voc: 68 | d = editdistance.eval(word, voc_word) 69 | prefix_s = self._prefix_score(word, voc_word) 70 | if d < min_d: 71 | matched = voc_word 72 | min_d = d 73 | max_prefix_s = prefix_s 74 | elif d == min_d and prefix_s > max_prefix_s: 75 | matched = voc_word 76 | max_prefix_s = prefix_s 77 | 78 | if min_d == 0: 
79 | break 80 | if min_d < self.edit_dist_thres or \ 81 | float(min_d) / len(word) < self.edit_dist_score_thres: 82 | return matched 83 | 84 | return None 85 | 86 | return word 87 | 88 | def load_voc(self, data_type, voc_type, voc_path): 89 | if voc_path is None: 90 | return None 91 | if 'IC15' in data_type: 92 | return self._load_voc_ic15(voc_type, voc_path) 93 | elif 'TT' in data_type: 94 | return self._load_voc_tt(voc_path) 95 | 96 | def _load_voc(self, voc_path): 97 | lines = mmcv.list_from_file(voc_path) 98 | voc = [] 99 | for line in lines: 100 | line = line.encode('utf-8').decode('utf-8-sig') 101 | line = line.replace('\xef\xbb\xbf', '') 102 | line = line.replace('\r', '').replace('\n', '') 103 | if len(line) == 0 or line[0] == '#': 104 | continue 105 | voc.append(line.lower()) 106 | 107 | return voc 108 | 109 | def _load_voc_ic15(self, voc_type, voc_path): 110 | if voc_type == 's' and osp.isdir(voc_path): 111 | voc_names = [voc_name for voc_name in 112 | mmcv.utils.scandir(voc_path, '.txt')] 113 | voc = {} 114 | for voc_name in voc_names: 115 | voc[voc_name] = self._load_voc(osp.join(voc_path, voc_name)) 116 | elif voc_type in ['g', 'w'] and osp.isfile(voc_path): 117 | voc = self._load_voc(voc_path) 118 | 119 | return voc 120 | 121 | def _load_voc_tt(self, voc_path): 122 | pass 123 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | class Logger(object): 2 | def __init__(self, fpath, title=None, resume=False): 3 | self.file = None 4 | self.resume = resume 5 | self.title = '' if title is None else title 6 | if fpath is not None: 7 | if resume: 8 | self.file = open(fpath, 'r') 9 | name = self.file.readline() 10 | self.names = name.rstrip().split('\t') 11 | self.numbers = {} 12 | for _, name in enumerate(self.names): 13 | self.numbers[name] = [] 14 | 15 | for numbers in self.file: 16 | numbers = numbers.rstrip().split('\t') 17 | for i in range(0, len(numbers)): 18 | self.numbers[self.names[i]].append(numbers[i]) 19 | self.file.close() 20 | self.file = open(fpath, 'a') 21 | else: 22 | self.file = open(fpath, 'w') 23 | 24 | def set_names(self, names): 25 | if self.resume: 26 | pass 27 | # initialize numbers as empty list 28 | self.numbers = {} 29 | self.names = names 30 | for _, name in enumerate(self.names): 31 | self.file.write(name) 32 | self.file.write('\t') 33 | self.numbers[name] = [] 34 | self.file.write('\n') 35 | self.file.flush() 36 | 37 | def append(self, numbers): 38 | assert len(self.names) == len(numbers) 39 | for index, num in enumerate(numbers): 40 | if type(num) == str: 41 | self.file.write(num) 42 | else: 43 | self.file.write('{0:.6f}'.format(num)) 44 | self.file.write('\t') 45 | self.numbers[self.names[index]].append(num) 46 | self.file.write('\n') 47 | self.file.flush() 48 | 49 | def close(self): 50 | if self.file is not None: 51 | self.file.close() 52 | -------------------------------------------------------------------------------- /utils/result_format.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import zipfile 4 | 5 | 6 | class ResultFormat(object): 7 | def __init__(self, data_type, result_path): 8 | self.data_type = data_type 9 | self.result_path = result_path 10 | 11 | if osp.isfile(result_path): 12 | os.remove(result_path) 13 | 14 | if result_path.endswith('.zip'): 15 | result_path = result_path.replace('.zip', '') 16 | 17 | if not 
osp.exists(result_path): 18 | os.makedirs(result_path) 19 | 20 | def write_result(self, img_metas, outputs): 21 | img_name = img_metas['img_name'][0] 22 | self._write_result_ctw(img_name, outputs) 23 | if 'IC15' in self.data_type: 24 | self._write_result_ic15(img_name, outputs) 25 | elif 'TT' in self.data_type: 26 | self._write_result_tt(img_name, outputs) 27 | elif 'CTW' in self.data_type: 28 | self._write_result_ctw(img_name, outputs) 29 | elif 'MSRA' in self.data_type: 30 | self._write_result_msra(img_name, outputs) 31 | 32 | def _write_result_ic15(self, img_name, outputs): 33 | assert self.result_path.endswith( 34 | '.zip'), 'Error: ic15 result should be a zip file!' 35 | 36 | tmp_folder = self.result_path.replace('.zip', '') 37 | 38 | bboxes = outputs['bboxes'] 39 | words = None 40 | if 'words' in outputs: 41 | words = outputs['words'] 42 | 43 | lines = [] 44 | for i, bbox in enumerate(bboxes): 45 | values = [int(v) for v in bbox] 46 | if words is None: 47 | line = '%d,%d,%d,%d,%d,%d,%d,%d\n' % tuple(values) 48 | lines.append(line) 49 | elif words[i] is not None: 50 | line = '%d,%d,%d,%d,%d,%d,%d,%d' % tuple( 51 | values) + ',%s\n' % words[i] 52 | lines.append(line) 53 | 54 | file_name = 'res_%s.txt' % img_name 55 | file_path = osp.join(tmp_folder, file_name) 56 | with open(file_path, 'w') as f: 57 | for line in lines: 58 | f.write(line) 59 | 60 | z = zipfile.ZipFile(self.result_path, 'a', zipfile.ZIP_DEFLATED) 61 | z.write(file_path, file_name) 62 | z.close() 63 | 64 | def _write_result_tt(self, image_name, outputs): 65 | bboxes = outputs['bboxes'] 66 | 67 | lines = [] 68 | for i, bbox in enumerate(bboxes): 69 | bbox = bbox.reshape(-1, 2)[:, ::-1].reshape(-1) 70 | values = [int(v) for v in bbox] 71 | line = '%d' % values[0] 72 | for v_id in range(1, len(values)): 73 | line += ',%d' % values[v_id] 74 | line += '\n' 75 | lines.append(line) 76 | 77 | file_name = '%s.txt' % image_name 78 | file_path = osp.join(self.result_path, file_name) 79 | with open(file_path, 'w') as f: 80 | for line in lines: 81 | f.write(line) 82 | 83 | def _write_result_ctw(self, image_name, outputs): 84 | bboxes = outputs['bboxes'] 85 | 86 | lines = [] 87 | for i, bbox in enumerate(bboxes): 88 | bbox = bbox.reshape(-1, 2)[:, ::-1].reshape(-1) 89 | values = [int(v) for v in bbox] 90 | line = '%d' % values[0] 91 | for v_id in range(1, len(values)): 92 | line += ',%d' % values[v_id] 93 | line += '\n' 94 | lines.append(line) 95 | 96 | file_name = '%s.txt' % image_name 97 | file_path = osp.join(self.result_path, file_name) 98 | with open(file_path, 'w') as f: 99 | for line in lines: 100 | f.write(line) 101 | 102 | def _write_result_msra(self, image_name, outputs): 103 | bboxes = outputs['bboxes'] 104 | 105 | lines = [] 106 | for b_idx, bbox in enumerate(bboxes): 107 | values = [int(v) for v in bbox] 108 | line = '%d' % values[0] 109 | for v_id in range(1, len(values)): 110 | line += ', %d' % values[v_id] 111 | line += '\n' 112 | lines.append(line) 113 | 114 | file_name = '%s.txt' % image_name 115 | file_path = osp.join(self.result_path, file_name) 116 | with open(file_path, 'w') as f: 117 | for line in lines: 118 | f.write(line) 119 | -------------------------------------------------------------------------------- /utils/visualizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import cv2 4 | import numpy as np 5 | 6 | 7 | class Visualizer: 8 | def __init__(self, vis_path): 9 | self.vis_path = vis_path 10 | if not osp.exists(vis_path): 
11 | os.makedirs(vis_path) 12 | 13 | def process(self, img_metas, outputs): 14 | img_path = img_metas['img_path'][0] 15 | img_name = img_metas['img_name'][0] 16 | bboxes = outputs['bboxes'] 17 | if 'words' in outputs: 18 | words = outputs['words'] 19 | else: 20 | words = [None] * len(bboxes) 21 | 22 | img = cv2.imread(img_path) 23 | for bbox, word in zip(bboxes, words): 24 | cv2.drawContours(img, [bbox.reshape(-1, 2)], -1, (0, 255, 0), 2) 25 | if word is not None: 26 | pos = np.min(bbox.reshape(-1, 2), axis=0) 27 | cv2.putText(img, word, (pos[0], pos[1]), 28 | cv2.FONT_HERSHEY_COMPLEX, 0.5, (255, 0, 0), 1) 29 | 30 | cv2.imwrite(osp.join(self.vis_path, '%s.jpg' % img_name), img) -------------------------------------------------------------------------------- /vis/34-V101P0264.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ssocean/AlphX-Code-For-DAR/6be8c05c12d373245c97f8434d3e5abe9f8a4550/vis/34-V101P0264.jpg -------------------------------------------------------------------------------- /vis/aug-vis.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from imgaug import augmenters as iaa 4 | 5 | often = lambda aug: iaa.Sometimes(0.8, aug) 6 | sometimes = lambda aug: iaa.Sometimes(0.5, aug) 7 | seldom = lambda aug: iaa.Sometimes(0.2, aug) 8 | seq_1104 = iaa.Sequential([ 9 | seldom(iaa.OneOf([ 10 | iaa.Invert(0.1), 11 | iaa.CoarsePepper(0.005, size_percent=(0, 0.005)), 12 | iaa.CoarseSaltAndPepper(0.005, size_percent=(0, 0.005)), 13 | ]),), 14 | often(iaa.OneOf([ 15 | iaa.MultiplyBrightness((0.8, 1.1)), 16 | iaa.LinearContrast((0.9, 1.1)), 17 | iaa.Multiply((0.8, 1.1), per_channel=0.2), 18 | ])), 19 | sometimes(iaa.OneOf([ 20 | iaa.JpegCompression(compression=(0, 50)), 21 | iaa.imgcorruptlike.GaussianNoise(severity=1), 22 | iaa.imgcorruptlike.ShotNoise(severity=1), 23 | iaa.imgcorruptlike.ImpulseNoise(severity=1), 24 | iaa.imgcorruptlike.SpeckleNoise(severity=1), 25 | ])), 26 | ], random_order=False) 27 | # load a sample image (HWC, BGR); vis/image_373.jpg ships with this repo 28 | img = np.array(cv2.imread('vis/image_373.jpg')) 29 | img = seq_1104(image=img) 30 | cv2.imwrite('vis/image_373_aug.jpg', img)  # save the augmented result for inspection -------------------------------------------------------------------------------- /vis/image_373.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ssocean/AlphX-Code-For-DAR/6be8c05c12d373245c97f8434d3e5abe9f8a4550/vis/image_373.jpg -------------------------------------------------------------------------------- /vis/image_553.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ssocean/AlphX-Code-For-DAR/6be8c05c12d373245c97f8434d3e5abe9f8a4550/vis/image_553.jpg --------------------------------------------------------------------------------
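`train_pan_pp/train.py` 中的 `adjust_learning_rate` 使用 poly 学习率衰减:lr = base_lr × (1 − cur_iter / max_iter)^0.9。下面给出一个可独立运行的最小示例(其中 base_lr 与迭代次数仅为示意取值,并非仓库配置中的真实数值):

```python
# Minimal standalone sketch of the 'polylr' decay implemented by
# adjust_learning_rate() in train_pan_pp/train.py.
def poly_lr(base_lr, cur_iter, max_iter, power=0.9):
    # lr = base_lr * (1 - cur_iter / max_iter) ** power
    return base_lr * (1.0 - float(cur_iter) / max_iter) ** power


if __name__ == '__main__':
    base_lr = 1e-3         # illustrative value, not from the repo configs
    max_iter = 600 * 1000  # epochs * iters_per_epoch, also illustrative
    for it in (0, max_iter // 2, max_iter - 1):
        print('iter %d -> lr %.6f' % (it, poly_lr(base_lr, it, max_iter)))
```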