├── .gitignore ├── CONTRIBUTING.md ├── INTRODUCTION.md ├── README.md ├── assets ├── chinese_charset.dic ├── ic15_eval │ ├── gt.zip │ ├── rrc_evaluation_funcs.py │ └── script.py ├── ops │ └── dcn │ │ ├── __init__.py │ │ ├── functions │ │ ├── __init__.py │ │ ├── deform_conv.py │ │ └── deform_pool.py │ │ ├── modules │ │ ├── __init__.py │ │ ├── deform_conv.py │ │ └── deform_pool.py │ │ ├── setup.py │ │ └── src │ │ ├── deform_conv_cuda.cpp │ │ ├── deform_conv_cuda_kernel.cu │ │ ├── deform_pool_cuda.cpp │ │ └── deform_pool_cuda_kernel.cu └── train_lexicon.txt ├── backbones ├── __init__.py ├── base.py ├── crnn.py ├── feature_pyramid.py ├── fpn_top_down.py ├── ppm.py ├── resnet.py ├── resnet_dilated.py ├── resnet_fpn.py ├── resnet_ppm.py └── upsample_head.py ├── concern ├── __init__.py ├── average_meter.py ├── box2seg.py ├── charset_tool.py ├── charsets.py ├── config.py ├── convert.py ├── cv.py ├── distributed.py ├── icdar2015_eval │ ├── __init__.py │ └── detection │ │ ├── __init__.py │ │ ├── deteval.py │ │ ├── icdar2013.py │ │ ├── iou.py │ │ └── mtwi2018.py ├── log.py ├── nori_reader.py ├── redis_meta.py ├── signal_monitor.py ├── tensorboard.py ├── test_log.py ├── textsnake.py ├── visualizer.py └── webcv2 │ ├── __init__.py │ ├── manager.py │ ├── server.py │ └── templates │ └── index.html ├── config.py ├── data ├── __init__.py ├── augmenter.py ├── crop_file_dataset.py ├── data_loader.py ├── dataset.py ├── east.py ├── file_dataset.py ├── list_dataset.py ├── lmdb_dataset.py ├── local_csv.py ├── meta.py ├── meta_loader.py ├── meta_loaders │ ├── __init__.py │ ├── charbox_meta_loader.py │ ├── data_id_meta_loader.py │ ├── detection_meta_loader.py │ ├── json_meta_loader.py │ ├── lmdb_meta_loader.py │ ├── meta_cache.py │ ├── meta_loader.py │ ├── nori_meta_loader.py │ ├── recognition_meta_loader.py │ ├── redis_meta.py │ └── text_lines_meta_loader.py ├── mingle_dataset.py ├── mnist.py ├── nori_dataset.py ├── processes │ ├── __init__.py │ ├── augment_data.py │ ├── charboxes_from_textlines.py │ ├── data_process.py │ ├── extract_detetion_data.py │ ├── filter_keys.py │ ├── make_border_map.py │ ├── make_center_distance_map.py │ ├── make_center_map.py │ ├── make_center_points.py │ ├── make_decouple_map.py │ ├── make_density_map.py │ ├── make_icdar_data.py │ ├── make_keypoint_map.py │ ├── make_recognition_label.py │ ├── make_seg_detection_data.py │ ├── make_seg_recognition_label.py │ ├── normalize_image.py │ ├── random_crop_data.py │ ├── resize_image.py │ └── serialize_box.py ├── quad.py ├── simple_detection.py ├── text_lines.py ├── textsnake.py └── unpack_msgpack_data.py ├── decoders ├── __init__.py ├── attention_decoder.py ├── balance_cross_entropy_loss.py ├── classification.py ├── crnn.py ├── ctc_decoder.py ├── ctc_decoder2d.py ├── ctc_loss.py ├── ctc_loss2d.py ├── dice_loss.py ├── east.py ├── l1_loss.py ├── pss_loss.py ├── seg_detector.py ├── seg_detector_loss.py ├── seg_recognizer.py ├── simple_detection.py └── textsnake.py ├── eval.py ├── experiment.py ├── experiments ├── base.yaml ├── recognition │ ├── community-base.yaml │ ├── crnn-lmdb.yaml │ ├── crnn.yaml │ ├── fpn50-attention-decoder.yaml │ └── res50-ppm-2d-ctc.yaml └── seg_detector │ ├── community-base.yaml │ └── seg_detector_db.yaml ├── ops ├── __init__.py └── ctc_2d │ ├── csrc │ ├── ctc2d.cpp │ ├── ctc2d.h │ └── cuda │ │ ├── ctc2d.h │ │ ├── ctc2d_cuda.cu │ │ └── ctc2d_cuda_kernel.cu │ ├── ctc_loss_2d.py │ └── setup.py ├── requirement.txt ├── scripts ├── json_to_lmdb.py └── nori_to_lmdb.py ├── structure ├── __init__.py ├── builder.py ├── 
ensemble_model.py ├── measurers │   ├── __init__.py │   ├── classification_measurer.py │   ├── grid_sampling_measurer.py │   ├── icdar_detection_measurer.py │   ├── quad_measurer.py │   ├── sequence_recognition_measurer.py │   ├── simple_detection.py │   └── textsnake.py ├── model.py ├── models │   └── maskrcnn_benchmark │   │   └── __init__.py ├── representers │   ├── __init__.py │   ├── classification_representer.py │   ├── ctc_representer.py │   ├── ctc_representer2d.py │   ├── east.py │   ├── ensemble_ctc_representer.py │   ├── integral_regression_representer.py │   ├── mask_rcnn.py │   ├── seg_detector_representer.py │   ├── seg_recognition_representer.py │   ├── sequence_recognition_representer.py │   ├── simple_detection.py │   └── textsnake.py ├── visualizer.py └── visualizers │   ├── __init__.py │   ├── ctc_visualizer2d.py │   ├── east.py │   ├── seg_detector_visualizer.py │   ├── seg_recognition_visualizer.py │   ├── sequence_recognition_visualizer.py │   ├── simple_detection.py │   └── textsnake.py ├── train.py ├── trainer.py └── training ├── checkpoint.py ├── learning_rate.py ├── model_saver.py └── optimizer_scheduler.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | */__pycache__/* 3 | train_log/ 4 | train_log 5 | workspace 6 | tensorboards 7 | *.py[cod] 8 | *$py.class 9 | *.swp 10 | *.swo 11 | *.lock 12 | 13 | # C extensions 14 | *.so 15 | *.nfs* 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | .idea 110 | log.txt # From the naive evaluation of ICDAR15 111 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to MegReader 2 | 3 | You are very welcome to contribute to MegReader. We will be grateful if you submit pull requests, help revise documentation, report bugs, or request algorithms that you consider valuable. 4 | 5 | ## Issues 6 | 7 | We use issues for bug reports and feature requests.
8 | 9 | ## Code Structure 10 | -------------------------------------------------------------------------------- /INTRODUCTION.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Megvii-CSG/MegReader/4cb5635d109f81f020780f39de9fca29a44c9f82/INTRODUCTION.md -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MegReader 2 | A project for research in text detection and recognition using PyTorch 1.2. 3 | 4 | This project originated from the research repo of the CSG-Algorithm team of Megvii (https://megvii.com), which heavily relies on closed-source libraries. 5 | We are transferring models into this repo gradually; released implementations are listed in [Progress](#progress). 6 | 7 | ## Highlights 8 | 9 | - Implementations of representative text detection and recognition methods. 10 | - An effective framework for conducting experiments: we use yaml files to configure experiments, making it convenient to run them. 11 | - Thorough logging features which make it easy to follow and analyze experimental results. 12 | - CPU/GPU compatible for training and inference. 13 | - Distributed training support. 14 | 15 | ## Install 16 | 17 | ### Requirements 18 | 19 | `pip install -r requirement.txt` 20 | 21 | - Python 3.7 22 | - PyTorch 1.2 and CUDA 10.0. 23 | - gcc 5.5 (important for compiling) 24 | 25 | ### Compile cuda ops (If needed) 26 | ``` 27 | cd PATH_TO_OPS 28 | 29 | python setup.py build_ext --inplace 30 | ``` 31 | Ops that may be used: 32 | - DeformableConvV2 `assets/ops/dcn` 33 | - CTC2DLoss `ops/ctc_2d` 34 | 35 | ### Configuration (optional) 36 | 37 | Edit configurations in `config.py`. 38 | 39 | ## Training 40 | 41 | See detailed options: `python3 train.py --help` 42 | 43 | ## Datasets 44 | We provide a data loading implementation with annotations packed in json for a quick start. 45 | Data in lmdb format is also available. 46 | You can refer to the usage in the [demo](experiments/recognition/crnn-lmdb.yaml). 47 | Datasets used in our recognition experiments can be downloaded from [onedrive](https://megvii-my.sharepoint.cn/:f:/g/personal/wanzhaoyi_megvii_com/EjkcrpmiW6hJrUKY-0fEBRABvNMtYniUPfWLVptMmy9-6w?e=bJaYFo). A transform [script](scripts/json_to_lmdb.py) is provided to convert json-format data to lmdb. 48 | 49 | ### Non-distributed 50 | 51 | `python3 train.py PATH_TO_EXPERIMENT.yaml --validate --visualize --name NAME_OF_EXPERIMENT` 52 | 53 | Below are some configurations of the released recognition models: 54 | 55 | - CRNN: `experiments/recognition/crnn.yaml` 56 | - 2D CTC: `experiments/recognition/res50-ppm-2d-ctc.yaml` 57 | - Attention Decoder: `experiments/recognition/fpn50-attention-decoder.yaml` 58 | 59 | ### Distributed (recommended for multi-GPU training) 60 | 61 | `python3 -m torch.distributed.launch --nproc_per_node=NUM_GPUS train.py PATH_TO_EXPERIMENT.yaml -d --validate` 62 | 63 | 66 | 67 | 68 | ## Evaluating 69 | 70 | See detailed options: `python3 eval.py --help`. 71 | 72 | Keep-ratio testing is recommended: `python3 eval.py PATH_TO_EXPERIMENT.yaml --resize_mode keep_ratio` 73 | 74 | 75 | ### Model zoo 76 | Trained models are coming soon.
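In the meantime, the released configurations can be trained from scratch. A minimal sketch of a full run, assuming a 4-GPU machine and using only the flags documented above:

```bash
# Distributed training of the CRNN recognizer, then keep-ratio evaluation
python3 -m torch.distributed.launch --nproc_per_node=4 \
    train.py experiments/recognition/crnn.yaml -d --validate
python3 eval.py experiments/recognition/crnn.yaml --resize_mode keep_ratio
```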
77 | 80 | 81 | ## Progress 82 | ### Recognition Methods 83 | - [x] 2D CTC 84 | - [x] CRNN 85 | - [x] Attention Decoder 86 | - [ ] Rectification 87 | 88 | ### Detection Methods 89 | - [x] Text Snake 90 | - [x] EAST 91 | 92 | ### End-to-end 93 | - [ ] Mask Text Spotter 94 | 95 | ## Contributing 96 | 97 | [Contributing.md](CONTRIBUTING.md) 98 | -------------------------------------------------------------------------------- /assets/ic15_eval/gt.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Megvii-CSG/MegReader/4cb5635d109f81f020780f39de9fca29a44c9f82/assets/ic15_eval/gt.zip -------------------------------------------------------------------------------- /assets/ops/dcn/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions.deform_conv import deform_conv, modulated_deform_conv 2 | from .functions.deform_pool import deform_roi_pooling 3 | from .modules.deform_conv import (DeformConv, ModulatedDeformConv, 4 | DeformConvPack, ModulatedDeformConvPack) 5 | from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack, 6 | ModulatedDeformRoIPoolingPack) 7 | 8 | __all__ = [ 9 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 10 | 'ModulatedDeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack', 11 | 'ModulatedDeformRoIPoolingPack', 'deform_conv', 'modulated_deform_conv', 12 | 'deform_roi_pooling' 13 | ] 14 | -------------------------------------------------------------------------------- /assets/ops/dcn/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Megvii-CSG/MegReader/4cb5635d109f81f020780f39de9fca29a44c9f82/assets/ops/dcn/functions/__init__.py -------------------------------------------------------------------------------- /assets/ops/dcn/functions/deform_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | from .. 
import deform_pool_cuda 5 | 6 | 7 | class DeformRoIPoolingFunction(Function): 8 | 9 | @staticmethod 10 | def forward(ctx, 11 | data, 12 | rois, 13 | offset, 14 | spatial_scale, 15 | out_size, 16 | out_channels, 17 | no_trans, 18 | group_size=1, 19 | part_size=None, 20 | sample_per_part=4, 21 | trans_std=.0): 22 | ctx.spatial_scale = spatial_scale 23 | ctx.out_size = out_size 24 | ctx.out_channels = out_channels 25 | ctx.no_trans = no_trans 26 | ctx.group_size = group_size 27 | ctx.part_size = out_size if part_size is None else part_size 28 | ctx.sample_per_part = sample_per_part 29 | ctx.trans_std = trans_std 30 | 31 | assert 0.0 <= ctx.trans_std <= 1.0 32 | if not data.is_cuda: 33 | raise NotImplementedError 34 | 35 | n = rois.shape[0] 36 | output = data.new_empty(n, out_channels, out_size, out_size) 37 | output_count = data.new_empty(n, out_channels, out_size, out_size) 38 | deform_pool_cuda.deform_psroi_pooling_cuda_forward( 39 | data, rois, offset, output, output_count, ctx.no_trans, 40 | ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size, 41 | ctx.part_size, ctx.sample_per_part, ctx.trans_std) 42 | 43 | if data.requires_grad or rois.requires_grad or offset.requires_grad: 44 | ctx.save_for_backward(data, rois, offset) 45 | ctx.output_count = output_count 46 | 47 | return output 48 | 49 | @staticmethod 50 | def backward(ctx, grad_output): 51 | if not grad_output.is_cuda: 52 | raise NotImplementedError 53 | 54 | data, rois, offset = ctx.saved_tensors 55 | output_count = ctx.output_count 56 | grad_input = torch.zeros_like(data) 57 | grad_rois = None 58 | grad_offset = torch.zeros_like(offset) 59 | 60 | deform_pool_cuda.deform_psroi_pooling_cuda_backward( 61 | grad_output, data, rois, offset, output_count, grad_input, 62 | grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels, 63 | ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part, 64 | ctx.trans_std) 65 | return (grad_input, grad_rois, grad_offset, None, None, None, None, 66 | None, None, None, None) 67 | 68 | 69 | deform_roi_pooling = DeformRoIPoolingFunction.apply 70 | -------------------------------------------------------------------------------- /assets/ops/dcn/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Megvii-CSG/MegReader/4cb5635d109f81f020780f39de9fca29a44c9f82/assets/ops/dcn/modules/__init__.py -------------------------------------------------------------------------------- /assets/ops/dcn/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='deform_conv', 6 | ext_modules=[ 7 | CUDAExtension('deform_conv_cuda', [ 8 | 'src/deform_conv_cuda.cpp', 9 | 'src/deform_conv_cuda_kernel.cu', 10 | ]), 11 | CUDAExtension('deform_pool_cuda', [ 12 | 'src/deform_pool_cuda.cpp', 'src/deform_pool_cuda_kernel.cu' 13 | ]), 14 | ], 15 | cmdclass={'build_ext': BuildExtension}) 16 | -------------------------------------------------------------------------------- /backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet_fpn import Resnet18FPN, Resnet34FPN, Resnet50FPN, Resnet101FPN, Resnet152FPN 2 | from .resnet import resnet18, resnet34, resnet50, resnet101, deformable_resnet50 3 | from .crnn import crnn_backbone 4 | from .resnet_ppm import resnet50dilated_ppm 5 | 
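# Usage sketch (illustrative): backbones can be imported directly from this package, e.g.
#   from backbones import Resnet50FPN
#   backbone = Resnet50FPN(resnet_pretrained=True)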
-------------------------------------------------------------------------------- /backbones/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1, has_bias=False): 6 | "3x3 convolution with padding" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 8 | padding=1, bias=has_bias) 9 | 10 | 11 | def conv3x3_bn_relu(in_planes, out_planes, stride=1): 12 | return nn.Sequential( 13 | conv3x3(in_planes, out_planes, stride), 14 | nn.BatchNorm2d(out_planes), 15 | nn.ReLU(inplace=True), 16 | ) 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /backbones/crnn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class CRNN(nn.Module): 5 | 6 | def __init__(self, imgH, nc, nclass, nh): 7 | super(CRNN, self).__init__() 8 | assert imgH % 16 == 0, 'imgH has to be a multiple of 16' 9 | 10 | self.kernels = [3, 3, 3, 3, 3, 3, 2] 11 | self.paddings = [1, 1, 1, 1, 1, 1, 0] 12 | self.strides = [1, 1, 1, 1, 1, 1, 1] 13 | self.channels = [64, 128, 256, 256, 512, 512, 512, nc] 14 | 15 | conv0 = nn.Sequential( 16 | self._make_layer(0), 17 | nn.MaxPool2d((2, 2)) 18 | ) 19 | conv1 = nn.Sequential( 20 | self._make_layer(1), 21 | nn.MaxPool2d((2, 2)) 22 | ) 23 | conv2 = self._make_layer(2, True) 24 | conv3 = nn.Sequential( 25 | self._make_layer(3), 26 | nn.MaxPool2d((2, 2), (2, 1), (0, 1)) 27 | ) 28 | conv4 = self._make_layer(4, True) 29 | conv5 = nn.Sequential( 30 | self._make_layer(5), 31 | nn.MaxPool2d((2, 2), (2, 1), (0, 1)) 32 | ) 33 | conv6 = self._make_layer(6, True) 34 | 35 | self.cnn = nn.Sequential( 36 | conv0, 37 | conv1, 38 | conv2, 39 | conv3, 40 | conv4, 41 | conv5, 42 | conv6 43 | ) 44 | 45 | 46 | def _make_layer(self, i, batch_normalization=False): 47 | in_channel = self.channels[i - 1] 48 | out_channel = self.channels[i] 49 | layer = list() 50 | layer.append(nn.Conv2d(in_channel, out_channel, self.kernels[i], self.strides[i], self.paddings[i])) 51 | if batch_normalization: 52 | layer.append(nn.BatchNorm2d(out_channel)) 53 | else: 54 | layer.append(nn.ReLU()) 55 | return nn.Sequential(*layer) 56 | 57 | def forward(self, input): 58 | # conv features 59 | return self.cnn(input) 60 | 61 | 62 | def crnn_backbone(imgH=32, nc=3, nclass=37, nh=256): 63 | return CRNN(imgH, nc, nclass, nh) 64 | -------------------------------------------------------------------------------- /backbones/feature_pyramid.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class FeaturePyramid(nn.Module): 5 | def __init__(self, bottom_up, top_down): 6 | nn.Module.__init__(self) 7 | 8 | self.bottom_up = bottom_up 9 | self.top_down = top_down 10 | 11 | def forward(self, feature): 12 | pyramid_features = self.bottom_up(feature) 13 | feature = self.top_down(pyramid_features[::-1]) 14 | return feature 15 | -------------------------------------------------------------------------------- /backbones/fpn_top_down.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FPNTopDown(nn.Module): 7 | def __init__(self, pyramid_channels, feature_channel): 8 | nn.Module.__init__(self) 9 | 10 | self.reduction_layers = nn.ModuleList() 11 | for pyramid_channel in pyramid_channels: 12 | 
reduction_layer = nn.Conv2d(pyramid_channel, feature_channel, kernel_size=1, stride=1, padding=0, bias=False) 13 | self.reduction_layers.append(reduction_layer) 14 | 15 | self.merge_layer = nn.Conv2d(feature_channel, feature_channel, kernel_size=3, stride=1, padding=1, bias=False) 16 | 17 | def upsample_add(self, x, y): 18 | _, _, H, W = y.size() 19 | return F.interpolate(x, size=(H, W), mode='bilinear') + y 20 | 21 | def forward(self, pyramid_features): 22 | feature = None 23 | for pyramid_feature, reduction_layer in zip(pyramid_features, self.reduction_layers): 24 | pyramid_feature = reduction_layer(pyramid_feature) 25 | if feature is None: 26 | feature = pyramid_feature 27 | else: 28 | feature = self.upsample_add(feature, pyramid_feature) 29 | feature = self.merge_layer(feature) 30 | return feature 31 | -------------------------------------------------------------------------------- /backbones/ppm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .base import conv3x3_bn_relu, conv3x3 4 | 5 | 6 | class PPMDeepsup(nn.Module): 7 | def __init__(self, inner_channels=256, fc_dim=2048, 8 | pool_scales=(1, 2, 3, 6)): 9 | super(PPMDeepsup, self).__init__() 10 | 11 | self.ppm = [] 12 | for scale in pool_scales: 13 | self.ppm.append(nn.Sequential( 14 | nn.AdaptiveAvgPool2d(scale), 15 | nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), 16 | nn.BatchNorm2d(512), 17 | nn.ReLU(inplace=True) 18 | )) 19 | self.ppm = nn.ModuleList(self.ppm) 20 | self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) 21 | 22 | self.conv_last = nn.Sequential( 23 | nn.Conv2d(fc_dim+len(pool_scales)*512, 512, 24 | kernel_size=3, padding=1, bias=False), 25 | nn.BatchNorm2d(512), 26 | nn.ReLU(inplace=True), 27 | nn.Dropout2d(0.1), 28 | nn.Conv2d(512, inner_channels, kernel_size=1) 29 | ) 30 | 31 | def forward(self, conv_out, segSize=None): 32 | conv5 = conv_out[-1] 33 | 34 | input_size = conv5.size() 35 | ppm_out = [conv5] 36 | for pool_scale in self.ppm: 37 | ppm_out.append(nn.functional.interpolate( 38 | pool_scale(conv5), 39 | (input_size[2], input_size[3]), 40 | mode='bilinear', align_corners=False)) 41 | ppm_out = torch.cat(ppm_out, 1) 42 | 43 | x = self.conv_last(ppm_out) 44 | return x 45 | 46 | -------------------------------------------------------------------------------- /backbones/resnet_dilated.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ResnetDilated(nn.Module): 6 | def __init__(self, orig_resnet, dilate_scale=8): 7 | super(ResnetDilated, self).__init__() 8 | from functools import partial 9 | 10 | if dilate_scale == 8: 11 | orig_resnet.layer3.apply( 12 | partial(self._nostride_dilate, dilate=2)) 13 | orig_resnet.layer4.apply( 14 | partial(self._nostride_dilate, dilate=4)) 15 | elif dilate_scale == 16: 16 | orig_resnet.layer4.apply( 17 | partial(self._nostride_dilate, dilate=2)) 18 | 19 | # take pretrained resnet, except AvgPool and FC 20 | self.conv1 = orig_resnet.conv1 21 | self.bn1 = orig_resnet.bn1 22 | self.relu1 = orig_resnet.relu1 23 | self.conv2 = orig_resnet.conv2 24 | self.bn2 = orig_resnet.bn2 25 | self.relu2 = orig_resnet.relu2 26 | self.conv3 = orig_resnet.conv3 27 | self.bn3 = orig_resnet.bn3 28 | self.relu3 = orig_resnet.relu3 29 | self.maxpool = orig_resnet.maxpool 30 | self.layer1 = orig_resnet.layer1 31 | self.layer2 = orig_resnet.layer2 32 | self.layer3 = orig_resnet.layer3 33 | self.layer4 = 
orig_resnet.layer4 34 | 35 | def _nostride_dilate(self, m, dilate): 36 | classname = m.__class__.__name__ 37 | if classname.find('Conv') != -1: 38 | # the convolution with stride 39 | if m.stride == (2, 2): 40 | m.stride = (1, 1) 41 | if m.kernel_size == (3, 3): 42 | m.dilation = (dilate//2, dilate//2) 43 | m.padding = (dilate//2, dilate//2) 44 | # other convolutions 45 | else: 46 | if m.kernel_size == (3, 3): 47 | m.dilation = (dilate, dilate) 48 | m.padding = (dilate, dilate) 49 | 50 | def forward(self, x, return_feature_maps=True): 51 | conv_out = [] 52 | 53 | x = self.relu1(self.bn1(self.conv1(x))) 54 | x = self.relu2(self.bn2(self.conv2(x))) 55 | x = self.relu3(self.bn3(self.conv3(x))) 56 | x = self.maxpool(x) 57 | 58 | x = self.layer1(x) 59 | conv_out.append(x) 60 | x = self.layer2(x) 61 | conv_out.append(x) 62 | x = self.layer3(x) 63 | conv_out.append(x) 64 | x = self.layer4(x) 65 | conv_out.append(x) 66 | 67 | if return_feature_maps: 68 | return conv_out 69 | return x 70 | -------------------------------------------------------------------------------- /backbones/resnet_fpn.py: -------------------------------------------------------------------------------- 1 | from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152 2 | from .fpn_top_down import FPNTopDown 3 | from .feature_pyramid import FeaturePyramid 4 | 5 | 6 | def Resnet18FPN(resnet_pretrained=True): 7 | bottom_up = resnet18(pretrained=resnet_pretrained) 8 | top_down = FPNTopDown([512, 256, 128, 64], 256) 9 | feature_pyramid = FeaturePyramid(bottom_up, top_down) 10 | return feature_pyramid 11 | 12 | 13 | def Resnet34FPN(resnet_pretrained=True): 14 | bottom_up = resnet34(pretrained=resnet_pretrained) 15 | top_down = FPNTopDown([512, 256, 128, 64], 256) 16 | feature_pyramid = FeaturePyramid(bottom_up, top_down) 17 | return feature_pyramid 18 | 19 | 20 | def Resnet50FPN(resnet_pretrained=True): 21 | bottom_up = resnet50(pretrained=resnet_pretrained) 22 | top_down = FPNTopDown([2048, 1024, 512, 256], 256) 23 | feature_pyramid = FeaturePyramid(bottom_up, top_down) 24 | return feature_pyramid 25 | 26 | 27 | def Resnet101FPN(resnet_pretrained=True): 28 | bottom_up = resnet101(pretrained=resnet_pretrained) 29 | top_down = FPNTopDown([2048, 1024, 512, 256], 256) 30 | feature_pyramid = FeaturePyramid(bottom_up, top_down) 31 | return feature_pyramid 32 | 33 | 34 | def Resnet152FPN(resnet_pretrained=True): 35 | bottom_up = resnet152(pretrained=resnet_pretrained) 36 | top_down = FPNTopDown([2048, 1024, 512, 256], 256) 37 | feature_pyramid = FeaturePyramid(bottom_up, top_down) 38 | return feature_pyramid 39 | -------------------------------------------------------------------------------- /backbones/resnet_ppm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152 5 | from .resnet_dilated import ResnetDilated 6 | from .ppm import PPMDeepsup 7 | 8 | 9 | def resnet50dilated_ppm(resnet_pretrained=False, **kwargs): 10 | resnet = resnet50(pretrained=resnet_pretrained) 11 | resnet_dilated = ResnetDilated(resnet, dilate_scale=8) 12 | ppm = PPMDeepsup(**kwargs) 13 | return nn.Sequential(resnet_dilated, ppm) 14 | -------------------------------------------------------------------------------- /backbones/upsample_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def SimpleUpsampleHead(feature_channel,
layer_channels): 5 | modules = [] 6 | modules.append(nn.Conv2d(feature_channel, layer_channels[0], kernel_size=3, stride=1, padding=1, bias=False)) 7 | for layer_index in range(len(layer_channels) - 1): 8 | modules.extend([ 9 | nn.BatchNorm2d(layer_channels[layer_index]), 10 | nn.ReLU(inplace=True), 11 | nn.ConvTranspose2d(layer_channels[layer_index], layer_channels[layer_index + 1], kernel_size=2, stride=2, padding=0, bias=False), 12 | ]) 13 | return nn.Sequential(*modules) 14 | -------------------------------------------------------------------------------- /concern/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : __init__.py 4 | # Author : Zhaoyi Wan 5 | # Date : 21.11.2018 6 | # Last Modified Date: 08.01.2019 7 | # Last Modified By : Zhaoyi Wan 8 | 9 | from .log import Logger 10 | from .average_meter import AverageMeter 11 | from .visualizer import Visualize 12 | from .box2seg import resize_with_coordinates, box2seg 13 | from .convert import convert 14 | -------------------------------------------------------------------------------- /concern/average_meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | 4 | def __init__(self): 5 | self.reset() 6 | 7 | def reset(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def update(self, val, n=1): 14 | self.val = val 15 | self.sum += val * n 16 | self.count += n 17 | self.avg = self.sum / self.count 18 | return self 19 | -------------------------------------------------------------------------------- /concern/box2seg.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from scipy import interpolate 4 | 5 | def intersection(x, p1, p2): 6 | x1, y1 = p1 7 | x2, y2 = p2 8 | if x2 == x1: 9 | return 0 10 | k = (x - x1) / (x2 - x1) 11 | return k * (y2 - y1) + y1 12 | 13 | 14 | def midpoint(p1, p2, typed=float): 15 | return [typed((p1[0] + p2[0]) / 2), typed((p1[1] + p2[1]) / 2)] 16 | 17 | 18 | def resize_with_coordinates(image, width, height, coordinates): 19 | original_height, original_width = image.shape[:2] 20 | resized_image = cv2.resize(image, (width, height)) 21 | if coordinates is not None: 22 | assert coordinates.ndim == 2 23 | assert coordinates.shape[-1] == 2 24 | 25 | rate_x = width / original_width 26 | rate_y = height / original_height 27 | 28 | coordinates = coordinates * (rate_x, rate_y) 29 | return resized_image, coordinates 30 | 31 | 32 | def box2seg(image, boxes, label): 33 | height, width = image.shape[:2] 34 | mask = np.zeros((height, width), dtype=np.float32) 35 | seg = np.zeros((height, width), dtype=np.float32) 36 | points = [] 37 | for box_index in range(boxes.shape[0]): 38 | box = boxes[box_index, :, :] # 4x2 39 | left_top = box[0] 40 | right_top = box[1] 41 | right_bottom = box[2] 42 | left_bottom = box[3] 43 | 44 | left = [(left_top[0] + left_bottom[0]) / 2, (left_top[1] + left_bottom[1]) / 2] 45 | right = [(right_top[0] + right_bottom[0]) / 2, (right_top[1] + right_bottom[1]) / 2] 46 | 47 | center = midpoint(left, right) 48 | points.append(midpoint(left, center)) 49 | points.append(midpoint(right, center)) 50 | 51 | poly = np.array([midpoint(left_top, center), 52 | midpoint(right_top, center), 53 | midpoint(right_bottom, center), 54 | midpoint(left_bottom, center) 55 | ]) 
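        # The midpoints above form a quad halfway between the box corners and the
        # box center, i.e. the original box shrunk by half; cv2.fillPoly below
        # fills it with this box's label value.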
56 | seg = cv2.fillPoly(seg, [poly.reshape(4, 1, 2).astype(np.int32)], int(label[box_index])) 57 | 58 | left_y = intersection(0, points[0], points[1]) 59 | right_y = intersection(width, points[-1], points[-2]) 60 | points.insert(0, [0, left_y]) 61 | points.append([width, right_y]) 62 | points = np.array(points) 63 | 64 | f = interpolate.interp1d(points[:, 0], points[:, 1], fill_value='extrapolate') 65 | xnew = np.arange(0, width, 1) 66 | ynew = f(xnew).clip(0, height-1) 67 | for x in range(width - 1): 68 | mask[int(ynew[x]), x] = 1 69 | return ynew.reshape(1, -1).round(), seg 70 | -------------------------------------------------------------------------------- /concern/charset_tool.py: -------------------------------------------------------------------------------- 1 | 2 | """Check whether a unicode character is a Chinese character.""" 3 | 4 | 5 | def is_chinese(uchar): 6 | if uchar >= u'\u4e00' and uchar <= u'\u9fa5': 7 | return True 8 | else: 9 | return False 10 | 11 | 12 | """Check whether a unicode character is a digit.""" 13 | 14 | 15 | def is_number(uchar): 16 | if uchar >= u'\u0030' and uchar <= u'\u0039': 17 | return True 18 | else: 19 | return False 20 | 21 | 22 | """Check whether a unicode character is an English letter.""" 23 | 24 | 25 | def is_alphabet(uchar): 26 | if (uchar >= u'\u0041' and uchar <= u'\u005a') or (uchar >= u'\u0061' and uchar <= u'\u007a'): 27 | return True 28 | else: 29 | return False 30 | 31 | 32 | """Check whether a character is anything other than a Chinese character, digit, or English letter.""" 33 | 34 | 35 | def is_other(uchar): 36 | if not (is_chinese(uchar) or is_number(uchar) or is_alphabet(uchar)): 37 | return True 38 | else: 39 | return False 40 | 41 | 42 | """Convert a half-width character to full-width.""" 43 | 44 | 45 | def B2Q(uchar): 46 | inside_code = ord(uchar) 47 | if inside_code < 0x0020 or inside_code > 0x7e: # not a half-width character, return it unchanged 48 | return uchar 49 | if inside_code == 0x0020: # apart from space, the conversion rule is: half-width = full-width - 0xfee0 50 | inside_code = 0x3000 51 | else: 52 | inside_code += 0xfee0 53 | return chr(inside_code) 54 | 55 | 56 | """Convert a full-width character to half-width.""" 57 | 58 | 59 | def Q2B(uchar): 60 | inside_code = ord(uchar) 61 | if inside_code == 0x3000: 62 | inside_code = 0x0020 63 | else: 64 | inside_code -= 0xfee0 65 | if inside_code < 0x0020 or inside_code > 0x7e: # if the result is not a half-width character, return the original 66 | return uchar 67 | return chr(inside_code) 68 | 69 | 70 | """Convert a full-width string to half-width.""" 71 | 72 | 73 | def stringQ2B(ustring): 74 | return "".join([Q2B(uchar) for uchar in ustring]) 75 | 76 | 77 | """Convert UTF-8 bytes to a unicode string.""" 78 | 79 | 80 | def convert_toUnicode(string): 81 | # if not isinstance(string, unicode): 82 | return string.decode('UTF-8') 83 | -------------------------------------------------------------------------------- /concern/convert.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import cv2 3 | import base64 4 | import io 5 | import numpy as np 6 | 7 | 8 | def convert(data): 9 | if isinstance(data, dict): 10 | ndata = {} 11 | for key, value in data.items(): 12 | nkey = key.decode() 13 | if nkey == 'img': 14 | img = Image.open(io.BytesIO(value)) 15 | img = img.convert('RGB') 16 | img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) 17 | nvalue = img 18 | else: 19 | nvalue = convert(value) 20 | ndata[nkey] = nvalue 21 | return ndata 22 | elif isinstance(data, list): 23 | return [convert(item) for item in data] 24 | elif isinstance(data, bytes): 25 | return data.decode() 26 | else: 27 | return data 28 | 29 | 30 | def to_np(x): 31 | return x.cpu().data.numpy() 32 | -------------------------------------------------------------------------------- /concern/cv.py: -------------------------------------------------------------------------------- 1
| import cv2 2 | import numpy as np 3 | 4 | 5 | def min_area_rect(poly): 6 | poly = cv2.minAreaRect(np.array(poly, 'float32')) 7 | if poly[2] < -45: 8 | poly = (poly[0], poly[1], poly[2] + 180) 9 | else: 10 | poly = (poly[0], poly[1][::-1], poly[2] + 90) 11 | poly = cv2.boxPoints(poly) 12 | return poly 13 | -------------------------------------------------------------------------------- /concern/icdar2015_eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Megvii-CSG/MegReader/4cb5635d109f81f020780f39de9fca29a44c9f82/concern/icdar2015_eval/__init__.py -------------------------------------------------------------------------------- /concern/icdar2015_eval/detection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Megvii-CSG/MegReader/4cb5635d109f81f020780f39de9fca29a44c9f82/concern/icdar2015_eval/detection/__init__.py -------------------------------------------------------------------------------- /concern/nori_reader.py: -------------------------------------------------------------------------------- 1 | import config 2 | 3 | 4 | def NoriReader(paths=[]): 5 | import nori2 as nori 6 | from nori2.multi import MultiSourceReader 7 | if config.community_version: 8 | return MultiSourceReader(paths) 9 | else: 10 | return nori.Fetcher(paths) 11 | -------------------------------------------------------------------------------- /concern/redis_meta.py: -------------------------------------------------------------------------------- 1 | import os 2 | import bisect 3 | 4 | import redis 5 | 6 | 7 | class RedisMeta: 8 | def __init__(self, connection, prefix, key=None): 9 | self.connection = connection 10 | self._key = key 11 | self.prefix = prefix 12 | 13 | def __getitem__(self, key): 14 | if self.key() is None: 15 | if self.connection.exists(self.key(key)): 16 | return type(self)(self.connection, prefix=self.prefix, key=key) 17 | raise KeyError 18 | return self.connection.lindex(self.key(), key).decode() 19 | 20 | def __contains__(self, key): 21 | if self.key() is None: 22 | assert isinstance(key, str) 23 | return self.connection.exists(self.key(key)) 24 | else: 25 | assert isinstance(key, int) 26 | return -len(self) <= key < len(self) 27 | 28 | def __iter__(self): 29 | if self.key() is None: 30 | yield from self.keys() 31 | else: 32 | for i in range(len(self)): 33 | yield self.__getitem__(i) 34 | 35 | def __add__(self, another): 36 | return ConcateMetaRedisMeta(self, another) 37 | 38 | def get(self, key, default=None): 39 | assert self.key() is None 40 | 41 | def key(self, key=None): 42 | if key is None: 43 | key = self._key 44 | if key is None: 45 | return None 46 | return os.path.join(self.prefix, key) 47 | 48 | def keys(self): 49 | assert self.key() is None 50 | keys = self.connection.smembers(self.key('__keys__')) 51 | return {key.decode() for key in keys} 52 | 53 | def items(self): 54 | assert self.key() is None 55 | tuples = [] 56 | for key in self.keys(): 57 | tuples.append((key, self.__getitem__(key))) 58 | return tuples 59 | 60 | 61 | def __len__(self): 62 | if self.key() is None: 63 | return len(self.keys()) 64 | return self.connection.llen(self.key()) 65 | 66 | 67 | class ConcateMetaRedisMeta: 68 | def __init__(self, *meta_list): 69 | milestones = [] 70 | start = 0 71 | for meta in meta_list: 72 | assert meta.key() is not None 73 | start += len(meta) 74 | milestones.append(start) 75 | 76 | self.milestones = milestones 77 |
self.num_samples = start 78 | self.meta_list = list(meta_list) 79 | 80 | def __getitem__(self, index): 81 | meta_index = bisect.bisect(self.milestones, index) 82 | if meta_index == 0: 83 | real_index = index 84 | else: 85 | real_index = index - self.milestones[meta_index - 1] 86 | return self.meta_list[meta_index][real_index] 87 | 88 | def __len__(self): 89 | return self.num_samples 90 | 91 | def __add__(self, another): 92 | self.num_samples += len(another) 93 | self.milestones.append(self.num_samples) 94 | self.meta_list.append(another) 95 | return self 96 | -------------------------------------------------------------------------------- /concern/signal_monitor.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class SignalMonitor(object): 5 | def __init__(self, file_path): 6 | self.file_path = file_path 7 | 8 | def get_signal(self): 9 | if self.file_path is None: 10 | return None 11 | if os.path.exists(self.file_path): 12 | with open(self.file_path) as f: 13 | data = f.read() 14 | os.remove(self.file_path) 15 | return data 16 | else: 17 | return None 18 | -------------------------------------------------------------------------------- /concern/tensorboard.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : concern/tensorboard.py 4 | # Author : Zhaoyi Wan 5 | # Date : 03.01.2019 6 | # Last Modified Date: 03.01.2019 7 | # Last Modified By : Zhaoyi Wan 8 | -------------------------------------------------------------------------------- /concern/test_log.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : test_log.py 4 | # Author : Zhaoyi Wan 5 | # Date : 06.01.2019 6 | # Last Modified Date: 06.01.2019 7 | # Last Modified By : Zhaoyi Wan 8 | 9 | import argparse 10 | import os 11 | import torch 12 | import torch.utils.data as data 13 | 14 | from models import Builder 15 | from concern import Log, AverageMeter 16 | from data import NoriDataset 17 | from tqdm import tqdm 18 | import time 19 | import config 20 | import glob 21 | 22 | parser = argparse.ArgumentParser(description='Text Recognition Training') 23 | parser.add_argument('--name', default='', type=str) 24 | parser.add_argument('--batch_size', default=256, type=int, help='Batch size for training') 25 | parser.add_argument('--resume', default='', type=str, help='Resume from checkpoint') 26 | parser.add_argument('--num_workers', default=8, type=int, help='Number of workers used in dataloading') 27 | parser.add_argument('--iterations', default=1000000, type=int, help='Number of training iterations') 28 | parser.add_argument('--start_iter', default=0, type=int, help='Begin counting iterations starting from this value (should be used with resume)') 29 | parser.add_argument('--start_epoch', default=0, type=int, help='Begin counting epoch starting from this value (should be used with resume)') 30 | parser.add_argument('--use_isilon', default=True, type=bool, help='Use isilon to save logs and checkpoints') 31 | parser.add_argument('--num_gpus', default=4, type=int, help='number of gpus to use') 32 | parser.add_argument('--backbone', default='resnet50_fpn', type=str, help='The backbone to use') 33 | parser.add_argument('--decoder', default='1d_ctc', type=str, help='The ctc formulation to use') 34 | parser.add_argument('--lr', default=1e-3, type=float, help='initial learning rate') 35
| parser.add_argument('--aug', default=True, type=bool, help='add data augmentation') 36 | parser.add_argument('--verbose', default=True, type=bool, help='show verbose info') 37 | parser.add_argument('--validate', action='store_true', dest='validate', help='Validate during training') 38 | parser.add_argument('--no-validate', action='store_false', dest='validate', help='Validate during training') 39 | parser.set_defaults(validate=True) 40 | parser.add_argument('--nori', default='/unsullied/sharefs/_csg_algorithm/OCR/zhangjian02/data/ocr-data/synth-data/SynthText/croped_sorted.nori', type=str, help='nori_file_path') 41 | parser.add_argument('--test_nori', default='/unsullied/sharefs/wanzhaoyi/data/text-recognition-test/*.nori', type=str, help='nori_file_path for validation') 42 | 43 | args = parser.parse_args() 44 | name = args.name 45 | if name == '': 46 | name = args.backbone + '-' + args.decoder 47 | logger = Log('test_logger', args.use_isilon, args.verbose) 48 | logger.args(args) 49 | 50 | 51 | epoch = 0 52 | for e in range(3): 53 | for i in range(28385): 54 | loss = i * 1e-7 55 | logger.add_scalar('loss', loss, i + epoch * 28385); 56 | 57 | if i % 100 == 0: 58 | logger.info('iter: %6d, epoch: %3d, loss: %.6f, lr: %f' % (i, epoch, loss, 1e-4)) 59 | 60 | epoch += 1 61 | 62 | -------------------------------------------------------------------------------- /concern/visualizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | import config 5 | 6 | 7 | class Visualize: 8 | @classmethod 9 | def visualize(cls, x): 10 | dimension = len(x.shape) 11 | if dimension == 2: 12 | pass 13 | elif dimension == 3: 14 | pass 15 | 16 | @classmethod 17 | def to_np(cls, x): 18 | return x.cpu().data.numpy() 19 | 20 | @classmethod 21 | def visualize_weights(cls, tensor, format='HW', normalize=True): 22 | if isinstance(tensor, torch.Tensor): 23 | x = cls.to_np(tensor.permute(format.index('H'), format.index('W'))) 24 | else: 25 | x = tensor.transpose(format.index('H'), format.index('W')) 26 | if normalize: 27 | x = (x - x.min()) / (x.max() - x.min()) 28 | return cv2.applyColorMap((x * 255).astype(np.uint8), cv2.COLORMAP_JET) 29 | 30 | @classmethod 31 | def visualize_points(cls, image, tensor, radius=5, normalized=True): 32 | if isinstance(tensor, torch.Tensor): 33 | points = cls.to_np(tensor) 34 | else: 35 | points = tensor 36 | if normalized: 37 | points = points * image.shape[:2][::-1] 38 | for i in range(points.shape[0]): 39 | color = np.random.randint( 40 | 0, 255, (3, ), dtype=np.uint8).astype(np.float) 41 | image = cv2.circle(image, 42 | tuple(points[i].astype(np.int32).tolist()), 43 | radius, color, thickness=radius//2) 44 | return image 45 | 46 | @classmethod 47 | def visualize_heatmap(cls, tensor, format='CHW'): 48 | if isinstance(tensor, torch.Tensor): 49 | x = cls.to_np(tensor.permute(format.index('H'), 50 | format.index('W'), format.index('C'))) 51 | else: 52 | x = tensor.transpose( 53 | format.index('H'), format.index('W'), format.index('C')) 54 | canvas = np.zeros((x.shape[0], x.shape[1], 3), dtype=np.float) 55 | 56 | for c in range(0, x.shape[-1]): 57 | color = np.random.randint( 58 | 0, 255, (3, ), dtype=np.uint8).astype(np.float) 59 | canvas += np.tile(x[:, :, c], (3, 1, 1) 60 | ).swapaxes(0, 2).swapaxes(1, 0) * color 61 | 62 | canvas = canvas.astype(np.uint8) 63 | return canvas 64 | 65 | @classmethod 66 | def visualize_classes(cls, x): 67 | canvas = np.zeros((x.shape[0], x.shape[1], 3), dtype=np.uint8) 
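        # Assign a random color to each class id in [0, x.max()); pixels holding
        # the maximum id keep the zero-initialized (black) canvas value.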
68 | for c in range(int(x.max())): 69 | color = np.random.randint( 70 | 0, 255, (3, ), dtype=np.uint8).astype(np.float) 71 | canvas[np.where(x == c)] = color 72 | return canvas 73 | 74 | @classmethod 75 | def visualize_grid(cls, x, y, stride=16, color=(0, 0, 255), canvas=None): 76 | h, w = x.shape 77 | if canvas is None: 78 | canvas = np.zeros((h, w, 3), dtype=np.uint8) 79 | # canvas = np.concatenate([canvas, canvas], axis=1) 80 | i, j = 0, 0 81 | while i < w: 82 | j = 0 83 | while j < h: 84 | canvas = cv2.circle(canvas, (int(x[i, j] * w + 0.5), int(y[i, j] * h + 0.5)), radius=max(stride//4, 1), color=color, thickness=stride//8) 85 | j += stride 86 | i += stride 87 | return canvas 88 | 89 | @classmethod 90 | def visualize_rect(cls, canvas, _rect, color=(0, 0, 255)): 91 | rect = (_rect + 0.5).astype(np.int32) 92 | return cv2.rectangle(canvas, (rect[0], rect[1]), (rect[2], rect[3]), color) 93 | -------------------------------------------------------------------------------- /concern/webcv2/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env mdl 2 | class WebCV2: 3 | def __init__(self): 4 | import cv2 5 | self._cv2 = cv2 6 | from .manager import global_manager as gm 7 | self._gm = gm 8 | 9 | def __getattr__(self, name): 10 | if hasattr(self._gm, name): 11 | return getattr(self._gm, name) 12 | elif hasattr(self._cv2, name): 13 | return getattr(self._cv2, name) 14 | else: 15 | raise AttributeError 16 | 17 | import sys 18 | sys.modules[__name__] = WebCV2() 19 | 20 | -------------------------------------------------------------------------------- /concern/webcv2/manager.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env mdl 2 | import socket 3 | import base64 4 | import cv2 5 | import numpy as np 6 | from collections import OrderedDict 7 | 8 | from .server import get_server 9 | 10 | 11 | def jpeg_encode(img): 12 | return cv2.imencode('.png', img)[1] 13 | 14 | 15 | def get_free_port(rng, low=2000, high=10000): 16 | in_use = True 17 | while in_use: 18 | port = rng.randint(high - low) + low 19 | in_use = False 20 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 21 | try: 22 | s.bind(("0.0.0.0", port)) 23 | except socket.error as e: 24 | if e.errno == 98: # port already in use 25 | in_use = True 26 | s.close() 27 | return port 28 | 29 | 30 | class Manager: 31 | def __init__(self, img_encode_method=jpeg_encode, rng=None): 32 | self._queue = OrderedDict() 33 | self._server = None 34 | self.img_encode_method = img_encode_method 35 | if rng is None: 36 | rng = np.random.RandomState(self.get_default_seed()) 37 | self.rng = rng 38 | 39 | def get_default_seed(self): 40 | return 0 41 | 42 | def imshow(self, title, img): 43 | data = self.img_encode_method(img) 44 | data = base64.b64encode(data) 45 | data = data.decode('utf8') 46 | self._queue[title] = data 47 | 48 | def waitKey(self, delay=0): 49 | if self._server is None: 50 | self.port = get_free_port(self.rng) 51 | self._server, self._conn = get_server(port=self.port) 52 | self._conn.send([delay, list(self._queue.items())]) 53 | # self._queue = OrderedDict() 54 | return self._conn.recv() 55 | 56 | global_manager = Manager() 57 | 58 | -------------------------------------------------------------------------------- /concern/webcv2/server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env mdl 2 | import os 3 | BASE_DIR = os.path.dirname(os.path.realpath(__file__)) 4 | import time 5 | 
import json 6 | import select 7 | import traceback 8 | import socket 9 | from multiprocessing import Process, Pipe 10 | 11 | import gevent 12 | from gevent.pywsgi import WSGIServer 13 | from geventwebsocket.handler import WebSocketHandler 14 | from flask import Flask, request, render_template, abort 15 | 16 | 17 | def log_important_msg(msg, *, padding=3): 18 | msg_len = len(msg) 19 | width = msg_len + padding * 2 + 2 20 | print('#' * width) 21 | print('#' + ' ' * (width - 2) + '#') 22 | print('#' + ' ' * padding + msg + ' ' * padding + '#') 23 | print('#' + ' ' * (width - 2) + '#') 24 | print('#' * width) 25 | 26 | 27 | def hint_url(url, port): 28 | log_important_msg( 29 | 'The server is running at: {}'.format(url)) 30 | 31 | 32 | def _set_server(conn, name='webcv2', port=7788): 33 | package = None 34 | package_alive = False 35 | 36 | app = Flask(name) 37 | app.root_path = BASE_DIR 38 | 39 | @app.route('/') 40 | def index(): 41 | return render_template('index.html', title=name) 42 | 43 | @app.route('/stream') 44 | def stream(): 45 | def poll_ws(ws, delay): 46 | return len(select.select([ws.stream.handler.rfile], [], [], delay / 1000.)[0]) > 0 47 | 48 | if request.environ.get('wsgi.websocket'): 49 | ws = request.environ['wsgi.websocket'] 50 | if ws is None: 51 | abort(404) 52 | else: 53 | should_send = True 54 | while not ws.closed: 55 | nonlocal package 56 | nonlocal package_alive 57 | if conn.poll(): 58 | package = conn.recv() 59 | package_alive = True 60 | should_send = True 61 | if not should_send: 62 | continue 63 | should_send = False 64 | if package is None: 65 | ws.send(None) 66 | else: 67 | delay, info_lst = package 68 | ws.send(json.dumps((time.time(), package_alive, delay, info_lst))) 69 | if package_alive: 70 | if delay <= 0 or poll_ws(ws, delay): 71 | message = ws.receive() 72 | if ws.closed or message is None: 73 | break 74 | try: 75 | if isinstance(message, bytes): 76 | message = message.decode('utf8') 77 | message = int(message) 78 | except: 79 | traceback.print_exc() 80 | message = -1 81 | else: 82 | message = -1 83 | conn.send(message) 84 | package_alive = False 85 | return "" 86 | 87 | http_server = WSGIServer(('', port), app, handler_class=WebSocketHandler) 88 | hint_url('http://{}:{}'.format(socket.getfqdn(), port), port) 89 | http_server.serve_forever() 90 | 91 | 92 | def get_server(name='webcv2', port=7788): 93 | conn_server, conn_factory = Pipe() 94 | p_server = Process( 95 | target=_set_server, 96 | args=(conn_server,), 97 | kwargs=dict( 98 | name=name, port=port, 99 | ), 100 | ) 101 | p_server.daemon = True 102 | p_server.start() 103 | return p_server, conn_factory 104 | 105 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # DO NOT CHANGE THIS. Whether to use community (open-source) libraries or not. 2 | community_version = True 3 | 4 | db_path = '/data/workspace' # The path for log storage. 5 | oss_host = "http://oss.wh-a.brainpp.cn" # The oss base url (if used). 6 | # Whether to use nori (packed in vendor/nori2-1.9.18-py3-none-any.whl) for data packing. 7 | will_use_nori = False 8 | will_use_lmdb = True 9 | 10 | # Only needed when RedisMetaCache is applied for caching meta.
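# The address below is an internal default; point it to your own redis instance before use.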
11 | redis_host = '10.251.160.97' 12 | redis_port = 6379 13 | 14 | sync_bn = False 15 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .file_dataset import FileDataset 2 | from .mingle_dataset import MingleDataset 3 | from .lmdb_dataset import LMDBDataset -------------------------------------------------------------------------------- /data/augmenter.py: -------------------------------------------------------------------------------- 1 | import imgaug 2 | import imgaug.augmenters as iaa 3 | 4 | from concern.config import Configurable, State 5 | 6 | 7 | class AugmenterBuilder(object): 8 | def __init__(self): 9 | pass 10 | 11 | def build(self, args, root=True): 12 | if args is None: 13 | return None 14 | elif isinstance(args, (int, float, str)): 15 | return args 16 | elif isinstance(args, list): 17 | if root: 18 | sequence = [self.build(value, root=False) for value in args] 19 | return iaa.Sequential(sequence) 20 | else: 21 | return getattr(iaa, args[0])( 22 | *[self.to_tuple_if_list(a) for a in args[1:]]) 23 | elif isinstance(args, dict): 24 | if 'cls' in args: 25 | cls = getattr(iaa, args['cls']) 26 | return cls(**{k: self.to_tuple_if_list(v) for k, v in args.items() if not k == 'cls'}) 27 | else: 28 | return {key: self.build(value, root=False) for key, value in args.items()} 29 | else: 30 | raise RuntimeError('unknown augmenter arg: ' + str(args)) 31 | 32 | def to_tuple_if_list(self, obj): 33 | if isinstance(obj, list): 34 | return tuple(obj) 35 | return obj 36 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from torch.utils.data import Dataset as TorchDataset 3 | 4 | from concern.config import Configurable, State 5 | 6 | 7 | class Dataset(TorchDataset, Configurable): 8 | r'''Dataset reading data. 9 | Args: 10 | meta_loader: Instance of MetaLoader, defined in data/meta_loaders/meta_loader.py, 11 | for decoding annotations from packed files. 12 | unpack: Callable instance which unpacks data from packed bytes such as pickle or msgpack. 13 | If not provided, `default_unpack` will be invoked. 14 | processes: A series of Callable objects, which accept and return the data dict, 15 | typically inheriting from the `DataProcess` (data/processes/data_process.py) class.
16 | ''' 17 | meta_loader = State() 18 | unpack = State(default=None) 19 | split = State(default=1) 20 | 21 | processes = State(default=[]) 22 | 23 | def prepare_meta_single(self, path_name): 24 | return self.meta_loader.load_meta(path_name) 25 | 26 | def list_or_pattern(self, path): 27 | if isinstance(path, str): 28 | if '*' in path: 29 | return glob.glob(path) 30 | else: 31 | return [path] 32 | else: 33 | return path 34 | 35 | def __getitem__(self, index, retry=0): 36 | if index >= self.num_samples: 37 | index = index % self.num_samples 38 | 39 | data_id = self.data_ids[index] 40 | meta = dict() 41 | 42 | for key in self.meta: 43 | meta[key] = self.meta[key][index] 44 | 45 | try: 46 | data = self.unpack(data_id, meta) 47 | if self.processes is not None: 48 | for data_process in self.processes: 49 | data = data_process(data) 50 | except Exception as e: 51 | if self.debug or retry > 10: 52 | raise e 53 | return self.__getitem__(index + 1, retry=retry+1) 54 | return data 55 | 56 | def truncate(self, rank, total): 57 | clip_size = self.num_samples // total 58 | start = rank * clip_size 59 | if rank == total - 1: 60 | ending = self.num_samples 61 | else: 62 | ending = start + clip_size 63 | new_data_ids = self.data_ids[start:ending] 64 | del self.data_ids 65 | self.data_ids = new_data_ids 66 | 67 | new_meta = dict() 68 | for key in self.meta.keys(): 69 | new_meta[key] = self.meta[key][start:ending] 70 | for key in new_meta.keys(): 71 | del self.meta[key] 72 | self.meta = new_meta 73 | self.num_samples = len(self.data_ids) 74 | self.truncated = True 75 | 76 | def select(self, indices): 77 | if self.truncated: 78 | return 79 | new_data_ids = [self.data_ids[i] for i in indices] 80 | del self.data_ids 81 | self.data_ids = new_data_ids 82 | 83 | new_meta = { 84 | key: [self.meta[key][i] for i in indices] for key in self.meta.keys() 85 | } 86 | del self.meta 87 | self.meta = new_meta 88 | # self.num_samples = len(self.data_ids) 89 | self.truncated = True 90 | 91 | def __len__(self): 92 | if self.debug: 93 | return 512000 94 | return self.num_samples // self.split 95 | -------------------------------------------------------------------------------- /data/east.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import cv2 4 | import numpy as np 5 | from shapely.geometry import Polygon 6 | 7 | from concern.config import Configurable, State 8 | 9 | 10 | class MakeEASTData(Configurable): 11 | shrink = State(default=0.5) 12 | background_weight = State(default=3.0) 13 | 14 | def __init__(self, **kwargs): 15 | self.load_all(**kwargs) 16 | 17 | def find_polygon_radius(self, poly): 18 | poly = Polygon(poly) 19 | low = 0 20 | high = 65536 21 | while high - low > 0.1: 22 | mid = (high + low) / 2 23 | area = poly.buffer(-mid).area 24 | if area > 0.1: 25 | low = mid 26 | else: 27 | high = mid 28 | return high 29 | 30 | def __call__(self, data, *args, **kwargs): 31 | image, label, meta = data 32 | lines = label['polys'] 33 | image_id = meta['data_id'] 34 | 35 | h, w = image.shape[:2] 36 | 37 | heatmap = np.zeros((h, w), np.float32) 38 | heatmap_weight = np.zeros((h, w), np.float32) 39 | densebox = np.zeros((8, h, w), np.float32) 40 | densebox_weight = np.zeros((h, w), np.float32) 41 | train_mask = np.ones((h, w), np.float32) 42 | 43 | densebox_anchor = np.indices((h, w))[::-1].astype(np.float32) 44 | 45 | for line in lines: 46 | poly = line['points'] 47 | 48 | assert(len(poly) == 4) 49 | quad = poly 50 | 51 | radius = self.find_polygon_radius(poly) 52 | 
shrinked_poly = list(Polygon(poly).buffer(-radius * self.shrink).exterior.coords)[:-1] 53 | shrinked_poly_points = np.array([shrinked_poly], np.int32) 54 | 55 | cv2.fillConvexPoly(heatmap, shrinked_poly_points, 1.0) 56 | cv2.fillConvexPoly(densebox_weight, shrinked_poly_points, 1.0) 57 | 58 | for i in range(0, 4): 59 | for j in range(0, 2): 60 | cv2.fillConvexPoly(densebox[i * 2 + j], shrinked_poly_points, float(quad[i][j])) 61 | 62 | if line['ignore']: 63 | cv2.fillConvexPoly(train_mask, np.array([poly], np.int32), 0.0) 64 | 65 | heatmap_neg = np.logical_and(heatmap == 0, train_mask) 66 | heatmap_pos = np.logical_and(heatmap > 0, train_mask) 67 | 68 | heatmap_weight[heatmap_neg] = self.background_weight 69 | heatmap_weight[heatmap_pos] = train_mask.sum() / max(heatmap_pos.sum(), train_mask.sum() * 0.05) 70 | 71 | densebox_weight = densebox_weight * train_mask 72 | 73 | densebox = densebox - np.tile(densebox_anchor, (4, 1, 1)) 74 | 75 | # to pytorch channel sequence 76 | image = image.transpose(2, 0, 1) 77 | 78 | label = { 79 | 'heatmap': heatmap[np.newaxis], 80 | 'heatmap_weight': heatmap_weight[np.newaxis], 81 | 'densebox': densebox, 82 | 'densebox_weight': densebox_weight[np.newaxis], 83 | } 84 | meta = { 85 | 'image_id': image_id, 86 | 'lines': lines, 87 | } 88 | return image, label, pickle.dumps(meta) 89 | -------------------------------------------------------------------------------- /data/file_dataset.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import cv2 4 | from concern.config import Configurable, State 5 | from concern.log import Log 6 | from .dataset import Dataset 7 | from concern.distributed import is_main 8 | 9 | 10 | class FileDataset(Dataset, Configurable): 11 | r'''Dataset reading from files. 12 | Args: 13 | file_paths: Pattern or list, required, the file_paths containing data and annotations. 
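Example (hypothetical paths; a glob pattern and an explicit list are
        treated equivalently by `list_or_pattern`):
            FileDataset(file_paths='train/*.jpg', ...)
            FileDataset(file_paths=['train/a.jpg', 'train/b.jpg'], ...)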
14 | ''' 15 | file_paths = State() 16 | 17 | def __init__(self, path=None, file_paths=None, cmd={}, **kwargs): 18 | self.load_all(**kwargs) 19 | 20 | if file_paths is None: 21 | file_paths = path 22 | self.file_paths = self.list_or_pattern(file_paths) or self.file_paths 23 | 24 | self.debug = cmd.get('debug', False) 25 | self.prepare() 26 | 27 | def prepare(self): 28 | self.meta = self.prepare_meta(self.file_paths) 29 | 30 | if self.unpack is None: 31 | self.unpack = self.default_unpack 32 | 33 | self.data_ids = self.meta.get('data_ids', self.meta.get('data_id', [])) 34 | if self.debug: 35 | self.data_ids = self.data_ids[:32] 36 | self.num_samples = len(self.data_ids) 37 | if is_main(): 38 | print(self.num_samples, 'images found') 39 | return self 40 | 41 | def prepare_meta(self, path_or_list): 42 | def add(a_dict: dict, another_dict: dict): 43 | for key, value in another_dict.items(): 44 | if key in a_dict: 45 | a_dict[key] = a_dict[key] + value 46 | else: 47 | a_dict[key] = value 48 | return a_dict 49 | 50 | if isinstance(path_or_list, list): 51 | return functools.reduce(add, [self.prepare_meta(path) for path in path_or_list]) 52 | 53 | return self.prepare_meta_single(path_or_list) 54 | 55 | def default_unpack(self, data_id, meta): 56 | image = cv2.imread(data_id, cv2.IMREAD_COLOR).astype('float32') 57 | meta['image'] = image 58 | return meta 59 | -------------------------------------------------------------------------------- /data/list_dataset.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | 4 | from torch.utils.data import Dataset as TorchDataset 5 | 6 | import cv2 7 | import numpy as np 8 | from matplotlib.pyplot import imread 9 | 10 | from concern.config import Configurable, State 11 | 12 | 13 | class ListDataset(TorchDataset, Configurable): 14 | list_file = State() 15 | processes = State() 16 | image_path = State() 17 | gt_path = State() 18 | 19 | def __init__(self, 20 | debug=False, 21 | list_file=None, image_path=None, gt_path=None, 22 | **kwargs): 23 | self.load_all(**kwargs) 24 | 25 | self.image_path = image_path or self.image_path 26 | self.gt_path = gt_path or self.gt_path 27 | self.list_file = list_file or self.list_file 28 | self.gt_paths, self.image_paths =\ 29 | self.load_meta() # FIXME: this should be removed 30 | 31 | def load_meta(self): 32 | base_path = os.path.dirname(self.list_file) 33 | gt_base_path = os.path.join(base_path, 'gts') 34 | image_base_path = self.image_path 35 | gt_paths = [] 36 | image_paths = [] 37 | with open(self.list_file, 'rt') as list_reader: 38 | for _line in list_reader.readlines(): 39 | line = _line.strip() 40 | gt_paths.append(os.path.join(gt_base_path, line + '.txt')) 41 | image_paths.append(os.path.join(image_base_path, line)) 42 | print(len(gt_paths), 'images found') 43 | self.loaded = True 44 | return gt_paths, image_paths 45 | 46 | def __getitem__(self, index): 47 | if not self.loaded: 48 | self.gt_paths, self.image_paths = self.load_meta() 49 | gt_path = self.gt_paths[index] 50 | image_path = self.image_paths[index] 51 | data = (gt_path, image_path) 52 | for process in self.processes: 53 | data = process(data) 54 | return data 55 | 56 | def __len__(self): 57 | return len(self.gt_paths) 58 | 59 | 60 | class UnpackTxtData(Configurable): 61 | def __init__(self, **kwargs): 62 | pass 63 | 64 | def __call__(self, data): 65 | gt_path, image_path = data 66 | image_name = os.path.basename(image_path) 67 | 68 | lines = [] 69 | with open(gt_path) as reader: 70 | for line in reader.readlines(): 71 | line =
line.replace('\ufeff', '').strip() 72 | # the byte-order mark, if any, has already been stripped above 73 | 74 | line_list = line.split(',') 75 | assert len(line_list) == 9 76 | points = np.array([float(scalar) 77 | for scalar in line_list[:-1]]).reshape(-1, 2) 78 | lines.append(dict( 79 | poly=points.tolist(), 80 | text=line_list[-1], 81 | filename=image_name)) 82 | 83 | return dict( 84 | img=imread(image_path), 85 | lines=lines, 86 | data_id=image_name, 87 | filename=image_name) 88 | -------------------------------------------------------------------------------- /data/lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import cv2 3 | import numpy as np 4 | import pickle 5 | import lmdb 6 | import os 7 | from .dataset import Dataset 8 | from concern.config import Configurable, State 9 | from concern.distributed import is_main 10 | 11 | 12 | class LMDBDataset(Dataset, Configurable): 13 | r'''Dataset reading from lmdb. 14 | Args: 15 | lmdb_paths: Pattern or list, required, the path of `data.mdb`, 16 | e.g. `the/path/`, `['the/path/a/', 'the/path/b/']` 17 | ''' 18 | lmdb_paths = State() 19 | 20 | def __init__(self, lmdb_paths=None, cmd={}, **kwargs): 21 | self.load_all(**kwargs) 22 | self.lmdb_paths = self.list_or_pattern(lmdb_paths) or self.lmdb_paths 23 | self.debug = cmd.get('debug', False) 24 | self.envs = [] 25 | self.txns = {} 26 | self.prepare() 27 | self.truncated = False 28 | self.data = None 29 | 30 | def __del__(self): 31 | for env in self.envs: 32 | env.close() 33 | 34 | def prepare_meta(self, path_or_list): 35 | def add(a_dict: dict, another_dict: dict): 36 | new_dict = dict() 37 | for key, value in another_dict.items(): 38 | if key in a_dict: 39 | new_dict[key] = a_dict[key] + value 40 | else: 41 | new_dict[key] = value 42 | return new_dict 43 | 44 | if isinstance(path_or_list, list): 45 | return functools.reduce(add, [self.prepare_meta(path) for path in path_or_list]) 46 | 47 | assert type(path_or_list) == str, path_or_list 48 | 49 | return self.prepare_meta_single(path_or_list) 50 | 51 | def prepare_meta_single(self, path_name): 52 | return self.meta_loader.load_meta(path_name) 53 | 54 | def prepare(self): 55 | self.meta = self.prepare_meta(self.lmdb_paths) 56 | if self.unpack is None: 57 | self.unpack = self.default_unpack 58 | # prepare lmdb environments 59 | for path in self.lmdb_paths: 60 | path = os.path.join(path, '') 61 | env = lmdb.open(path, max_dbs=1, lock=False) 62 | db_image = env.open_db('image'.encode()) 63 | self.envs.append(env) 64 | self.txns[path] = env.begin(db=db_image) 65 | # The fetcher is supposed to be initialized in the 66 | # sub-processes, or it will cause CRC Error.
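# (The pattern mirrors NoriDataset below: a reader handle created in the
# parent process must not be shared with DataLoader workers, so it is only
# built lazily inside each worker after the fork.)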
67 | self.fetcher = None 68 | 69 | self.data_ids = self.meta.get('data_ids', self.meta.get('data_id', [])) 70 | if self.debug: 71 | self.data_ids = self.data_ids[:32] 72 | self.num_samples = len(self.data_ids) 73 | 74 | if is_main(): 75 | print(self.num_samples, 'images found') 76 | return self 77 | 78 | def search_image(self, data_id, path): 79 | maybe_image = self.txns[path].get(data_id) 80 | assert maybe_image is not None, 'image %s not found at %s' % ( 81 | data_id, path) 82 | return maybe_image 83 | 84 | def default_unpack(self, data_id, meta): 85 | data = self.search_image(data_id, meta['db_path']) 86 | image = np.fromstring(data, dtype=np.uint8) 87 | image = cv2.imdecode(image, cv2.IMREAD_COLOR).astype('float32') 88 | meta['image'] = image 89 | return meta 90 | -------------------------------------------------------------------------------- /data/local_csv.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torch.utils.data import Dataset as TorchDataset 4 | 5 | import cv2 6 | import numpy as np 7 | 8 | from concern.config import Configurable, State 9 | 10 | 11 | class LocalCSVDataset(TorchDataset, Configurable): 12 | csv_path = State() 13 | processes = State() 14 | debug = State(default=False) 15 | 16 | def __init__(self, cmd={}, **kwargs): 17 | self.load_all(cmd=cmd, **kwargs) 18 | 19 | self.load_meta() 20 | self.debug = cmd.get('debug', False) 21 | 22 | def load_meta(self): 23 | self.meta = [] 24 | for textline in open(self.csv_path): 25 | tokens = textline.strip().split('\t') 26 | filename = tokens[0] 27 | filepath = os.path.join(os.path.dirname(self.csv_path), filename) 28 | lines_coords = tokens[1::2] 29 | lines_text = tokens[2::2] 30 | lines = [] 31 | for coords, text in zip(lines_coords, lines_text): 32 | poly = np.array(list(map(int, coords[1:-1].split(',')))).reshape(4, 2).tolist() 33 | lines.append({ 34 | 'poly': poly, 35 | 'text': text, 36 | }) 37 | self.meta.append({ 38 | 'img': cv2.imread(filepath, cv2.IMREAD_COLOR), 39 | 'lines': lines, 40 | 'filename': filename, 41 | 'data_id': filename, 42 | }) 43 | 44 | print(len(self.meta), 'images found') 45 | 46 | def __getitem__(self, index): 47 | data = self.meta[index] 48 | for process in self.processes: 49 | data = process(data) 50 | return data 51 | 52 | def __len__(self): 53 | return len(self.meta) 54 | -------------------------------------------------------------------------------- /data/meta.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Megvii-CSG/MegReader/4cb5635d109f81f020780f39de9fca29a44c9f82/data/meta.py -------------------------------------------------------------------------------- /data/meta_loader.py: -------------------------------------------------------------------------------- 1 | from data.meta_loaders import * 2 | -------------------------------------------------------------------------------- /data/meta_loaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_id_meta_loader import DataIdMetaLoader 2 | from .recognition_meta_loader import RecognitionMetaLoader 3 | from .text_lines_meta_loader import TextLinesMetaLoader 4 | from .charbox_meta_loader import CharboxMetaLoader 5 | from .meta_cache import OSSMetaCache, FileMetaCache, RedisMetaCache 6 | from .lmdb_meta_loader import LMDBMetaLoader 7 | from .detection_meta_loader import DetectionMetaLoader 
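A note on the `LocalCSVDataset` above: `load_meta` expects each csv line to hold a filename followed by tab-separated (coords, text) pairs, with the eight polygon coordinates wrapped in brackets, since it slices `coords[1:-1]` before splitting on commas. An illustrative line (hypothetical values; `\t` marks tabs):

img_001.jpg\t[10,20,110,20,110,60,10,60]\thello\t[120,20,200,20,200,60,120,60]\tworld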
-------------------------------------------------------------------------------- /data/meta_loaders/charbox_meta_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from concern.config import State 4 | from .recognition_meta_loader import RecognitionMetaLoader 5 | 6 | 7 | class CharboxMetaLoader(RecognitionMetaLoader): 8 | charbox_key = State(default='charboxes') 9 | transpose = State(default=False) 10 | 11 | def __init__(self, charbox_key=None, transpose=None, cmd={}, **kwargs): 12 | super().__init__(cmd=cmd, **kwargs) 13 | print('load CharBox') 14 | if charbox_key is not None: 15 | self.charbox_key = charbox_key 16 | if transpose is not None: 17 | self.transpose = transpose 18 | 19 | def parse_meta(self, data_id, meta): 20 | parsed = super().parse_meta(data_id, meta) 21 | if parsed is None: 22 | return 23 | 24 | charbox = np.array(self.get_annotation(meta)[self.charbox_key]) 25 | if self.transpose: 26 | charbox = charbox.transpose(2, 1, 0) 27 | parsed['charboxes'] = charbox 28 | return parsed 29 | -------------------------------------------------------------------------------- /data/meta_loaders/data_id_meta_loader.py: -------------------------------------------------------------------------------- 1 | from concern.config import State 2 | from .meta_loader import MetaLoader 3 | 4 | 5 | class DataIdMetaLoader(MetaLoader): 6 | return_dict = State(default=False) 7 | scan_meta = False 8 | 9 | def __init__(self, return_dict=None, cmd={}, **kwargs): 10 | super().__init__(cmd=cmd, **kwargs) 11 | if return_dict is not None: 12 | self.return_dict = return_dict 13 | 14 | def parse_meta(self, data_id): 15 | return dict(data_id=data_id) 16 | 17 | def post_prosess(self, meta): 18 | if self.return_dict: 19 | return meta 20 | return meta['data_id'] 21 | -------------------------------------------------------------------------------- /data/meta_loaders/detection_meta_loader.py: -------------------------------------------------------------------------------- 1 | from concern.config import State 2 | from .meta_loader import MetaLoader 3 | 4 | 5 | class DetectionMetaLoader(MetaLoader): 6 | key = State(default='gt') 7 | scan_meta = True 8 | 9 | def __init__(self, key=None, cmd={}, **kwargs): 10 | super().__init__(cmd=cmd, **kwargs) 11 | if key is not None: 12 | self.key = key 13 | 14 | def parse_meta(self, data_id, meta): 15 | return dict(data_ids=data_id, gt=self.get_annotation(meta)[self.key]) 16 | -------------------------------------------------------------------------------- /data/meta_loaders/json_meta_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from concern.config import Configurable, State 4 | 5 | 6 | class JsonMetaLoader(Configurable): 7 | cache = State() 8 | force_reload = State(default=False) 9 | image_folder = State(default='images') 10 | json_file = State(default='meta.json') 11 | 12 | scan_meta = True 13 | scan_data = False 14 | post_prosess = None 15 | 16 | def __init__(self, force_reload=None, cmd={}, **kwargs): 17 | self.load_all(cmd=cmd, **kwargs) 18 | self.force_reload = cmd.get('force_reload', self.force_reload) 19 | if force_reload is not None: 20 | self.force_reload = force_reload 21 | 22 | def load_meta(self, json_path): 23 | if not self.force_reload and self.cache is not None: 24 | meta = self.cache.read(json_path) 25 | if meta is not None: 26 | return meta 27 | 28 | meta_info = dict() 29 | valid_count = 0 30 | with 
open(os.path.join(json_path, self.json_file)) as reader: 31 | for line in reader.readlines(): 32 | line = line.strip() 33 | single_meta = json.loads(line) 34 | data_id = os.path.join(json_path, single_meta['filename']) 35 | args_dict = dict(data_id=data_id) 36 | 37 | if self.scan_data: 38 | with open(data_id, 'rb') as image_reader: # read the raw image bytes for this sample 39 | data = image_reader.read() 40 | args_dict.update(data=data) 41 | elif self.scan_meta: 42 | args_dict.update(meta=single_meta) 43 | 44 | meta_instance = self.parse_meta(**args_dict) 45 | if meta_instance is None: 46 | continue 47 | 48 | valid_count += 1 49 | if valid_count % 100000 == 0: 50 | print("%d instances processed" % valid_count) 51 | 52 | for key in meta_instance: 53 | the_list = meta_info.get(key, []) 54 | the_list.append(meta_instance[key]) 55 | meta_info[key] = the_list 56 | 57 | print(valid_count, 'instances found') 58 | if self.post_prosess is not None: 59 | meta_info = self.post_prosess(meta_info) 60 | 61 | if self.cache is not None: 62 | self.cache.save(json_path, meta_info) 63 | 64 | return meta_info 65 | 66 | def parse_meta(self, data_id, meta): 67 | raise NotImplementedError 68 | 69 | def get_annotation(self, meta): 70 | return meta['extra'] 71 | 72 | def same_dir_with(self, full_path, dest): 73 | return os.path.join(os.path.dirname(full_path), dest) 74 | -------------------------------------------------------------------------------- /data/meta_loaders/lmdb_meta_loader.py: -------------------------------------------------------------------------------- 1 | import lmdb 2 | import pickle 3 | from collections import defaultdict 4 | from concern.config import Configurable, State 5 | import time 6 | import os 7 | 8 | 9 | class LMDBMetaLoader(Configurable): 10 | cache = State() 11 | force_reload = State(default=False) 12 | post_prosess = None 13 | 14 | def __init__(self, force_reload=None, cmd={}, **kwargs): 15 | self.load_all(cmd=cmd, **kwargs) 16 | self.force_reload = cmd.get('force_reload', self.force_reload) 17 | 18 | # lmdb environments 19 | self.envs = {} 20 | self.txns = {} 21 | if force_reload is not None: 22 | self.force_reload = force_reload 23 | 24 | def __del__(self): 25 | for path in self.envs: 26 | self.envs[path].close() 27 | 28 | def load_meta(self, lmdb_path): 29 | lmdb_path = os.path.join(lmdb_path, '') 30 | if not self.force_reload and self.cache is not None: 31 | meta = self.cache.read(lmdb_path) 32 | if meta is not None: 33 | return meta 34 | meta_info = defaultdict(list) 35 | valid_count = 0 36 | if lmdb_path not in self.envs: 37 | env = lmdb.open(lmdb_path, max_dbs=1, lock=False) 38 | self.envs[lmdb_path] = env 39 | db_extra = env.open_db('extra'.encode()) 40 | self.txns[lmdb_path] = env.begin(db=db_extra) 41 | 42 | txn = self.txns[lmdb_path] 43 | cursor = txn.cursor() 44 | for data_id, value in cursor: 45 | args_tuple = (data_id, ) 46 | if self.scan_meta: 47 | args_tuple = tuple((*args_tuple, pickle.loads(value))) 48 | meta_instance = self.parse_meta(*args_tuple) 49 | if meta_instance is None: 50 | continue 51 | meta_instance['db_path'] = lmdb_path 52 | valid_count += 1 53 | if valid_count % 100000 == 0: 54 | print("%d instances processed" % valid_count) 55 | for key in meta_instance: 56 | meta_info[key].append(meta_instance[key]) 57 | 58 | print(valid_count, 'instances found') 59 | if self.post_prosess is not None: 60 | meta_info = self.post_prosess(meta_info) 61 | 62 | if self.cache is not None: 63 | self.cache.save(lmdb_path, meta_info) 64 | 65 | return meta_info 66 | 67 | def parse_meta(self,
data_id, meta): 68 | raise NotImplementedError 69 | 70 | def get_annotation(self, meta): 71 | return meta['extra'] 72 | -------------------------------------------------------------------------------- /data/meta_loaders/meta_loader.py: -------------------------------------------------------------------------------- 1 | import config 2 | 3 | assert not (config.will_use_nori and config.will_use_lmdb), 'only one metaloader can be used' 4 | if config.will_use_nori: 5 | from .nori_meta_loader import NoriMetaLoader 6 | MetaLoader = NoriMetaLoader 7 | elif config.will_use_lmdb: 8 | from .lmdb_meta_loader import LMDBMetaLoader 9 | MetaLoader = LMDBMetaLoader 10 | else: 11 | from .json_meta_loader import JsonMetaLoader 12 | MetaLoader = JsonMetaLoader 13 | -------------------------------------------------------------------------------- /data/meta_loaders/nori_meta_loader.py: -------------------------------------------------------------------------------- 1 | from concern.config import Configurable, State 2 | import nori2 as nori 3 | 4 | 5 | class NoriMetaLoader(Configurable): 6 | cache = State() 7 | force_reload = State(default=False) 8 | 9 | scan_meta = True 10 | scan_data = False 11 | post_prosess = None 12 | 13 | def __init__(self, force_reload=None, cmd={}, **kwargs): 14 | self.load_all(cmd=cmd, **kwargs) 15 | self.force_reload = cmd.get('force_reload', self.force_reload) 16 | if force_reload is not None: 17 | self.force_reload = force_reload 18 | 19 | def load_meta(self, nori_path): 20 | if not self.force_reload and self.cache is not None: 21 | meta = self.cache.read(nori_path) 22 | if meta is not None: 23 | return meta 24 | 25 | meta_info = dict() 26 | valid_count = 0 27 | with nori.open(nori_path) as reader: 28 | for data_id, data, meta in reader.scan( 29 | scan_data=self.scan_data, scan_meta=self.scan_meta): 30 | args_tuple = (data_id, ) 31 | if self.scan_data: 32 | args_tuple = tuple((*args_tuple, data)) 33 | if self.scan_meta: 34 | args_tuple = tuple((*args_tuple, meta)) 35 | meta_instance = self.parse_meta(*args_tuple) 36 | if meta_instance is None: 37 | continue 38 | valid_count += 1 39 | if valid_count % 100000 == 0: 40 | print("%d instances processd" % valid_count) 41 | for key in meta_instance: 42 | the_list = meta_info.get(key, []) 43 | the_list.append(meta_instance[key]) 44 | meta_info[key] = the_list 45 | 46 | print(valid_count, 'instances found') 47 | if self.post_prosess is not None: 48 | meta_info = self.post_prosess(meta_info) 49 | 50 | if self.cache is not None: 51 | self.cache.save(nori_path, meta_info) 52 | 53 | return meta_info 54 | 55 | def parse_meta(self, data_id, meta): 56 | raise NotImplementedError 57 | 58 | def get_annotation(self, meta): 59 | return meta['extra'] 60 | -------------------------------------------------------------------------------- /data/meta_loaders/recognition_meta_loader.py: -------------------------------------------------------------------------------- 1 | from hanziconv import HanziConv 2 | 3 | from concern.charset_tool import stringQ2B 4 | from concern.config import State 5 | from .meta_loader import MetaLoader 6 | 7 | 8 | class RecognitionMetaLoader(MetaLoader): 9 | skip_vertical = State(default=False) 10 | case_sensitive = State(default=False) 11 | simplify = State(default=False) 12 | key = State(default='words') 13 | scan_meta = True 14 | 15 | def __init__(self, key=None, cmd={}, **kwargs): 16 | super().__init__(cmd=cmd, **kwargs) 17 | if key is not None: 18 | self.key = key 19 | 20 | def may_simplify(self, words): 21 | garbled = 
stringQ2B(words) 22 | if self.simplify: 23 | return HanziConv.toSimplified(garbled) 24 | return garbled 25 | 26 | def parse_meta(self, data_id, meta): 27 | word = self.may_simplify(self.get_annotation(meta)[self.key]) 28 | vertical = self.get_annotation(meta).get('vertical', False) 29 | if self.skip_vertical and vertical: 30 | return None 31 | 32 | if word == '###': 33 | return None 34 | return dict(data_ids=data_id, gt=word) 35 | -------------------------------------------------------------------------------- /data/meta_loaders/redis_meta.py: -------------------------------------------------------------------------------- 1 | import redis 2 | 3 | 4 | class RedisMeta: 5 | def __init__(self, socket_path): 6 | redis.StrictRedis() 7 | -------------------------------------------------------------------------------- /data/meta_loaders/text_lines_meta_loader.py: -------------------------------------------------------------------------------- 1 | from data.text_lines import TextLines 2 | from .meta_loader import MetaLoader 3 | 4 | 5 | class TextLinesMetaLoader(MetaLoader): 6 | def parse_meta(self, data_id, meta): 7 | return dict( 8 | data_id=data_id, 9 | lines=TextLines(self.get_annotation(meta)['lines'])) 10 | -------------------------------------------------------------------------------- /data/mingle_dataset.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | 3 | import torch.utils.data as data 4 | 5 | from concern.config import Configurable, State 6 | 7 | 8 | class MingleDataset(data.Dataset, Configurable): 9 | datasets = State(default=[]) 10 | 11 | def __init__(self, cmd={}, **kwargs): 12 | self.load_all(cmd=cmd, **kwargs) 13 | 14 | ratios = [] 15 | sizes = [] 16 | indices = [] 17 | self.data_sources = [] 18 | for i in range(len(self.datasets)): 19 | dataset_dict = self.datasets[i] 20 | ratio = dataset_dict['ratio'] 21 | size = len(dataset_dict['dataset']) 22 | if size == 0: 23 | continue 24 | indices.append(i) 25 | ratios.append(ratio) 26 | sizes.append(size) 27 | self.data_sources.append(dataset_dict['dataset']) 28 | 29 | ratio_sum = sum(ratios) 30 | ratios = [r / ratio_sum for r in ratios] 31 | total = sum(sizes) 32 | for index, ratio in enumerate(ratios): 33 | quota = ratio * total 34 | if sizes[index] < quota: 35 | total = int(sizes[index] / ratio + 0.5) 36 | 37 | milestones = [] 38 | for ratio in ratios[:-1]: 39 | milestones.append(int(ratio * total + 0.5)) 40 | self.milestones = milestones 41 | self.total = total 42 | print('total', self.total) 43 | 44 | def __len__(self): 45 | return self.total 46 | 47 | def __getitem__(self, index): 48 | dataset_index = bisect.bisect(self.milestones, index) 49 | dataset = self.data_sources[dataset_index] 50 | if dataset_index == 0: 51 | real_index = index 52 | else: 53 | real_index = index - self.milestones[dataset_index - 1] 54 | return dataset.__getitem__(real_index) 55 | -------------------------------------------------------------------------------- /data/mnist.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | 3 | from concern.config import Configurable, State 4 | 5 | 6 | class MNistDataset(Configurable, torchvision.datasets.MNIST): 7 | root = State() 8 | is_train = State(autoload=False) 9 | 10 | def __init__(self, **kwargs): 11 | self.load_all(**kwargs) 12 | 13 | cmd = kwargs['cmd'] 14 | self.is_train = cmd['is_train'] 15 | 16 | transform = torchvision.transforms.Compose([ 17 | torchvision.transforms.Resize((64, 64)), 18 | 
torchvision.transforms.Grayscale(num_output_channels=3), 19 | torchvision.transforms.ToTensor(), 20 | ]) 21 | torchvision.datasets.MNIST.__init__( 22 | self, self.root, 23 | train=self.is_train, download=True, transform=transform 24 | ) 25 | -------------------------------------------------------------------------------- /data/nori_dataset.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import cv2 3 | import numpy as np 4 | 5 | from .dataset import Dataset 6 | from concern.config import Configurable, State 7 | from concern.nori_reader import NoriReader 8 | from concern.distributed import is_main 9 | 10 | 11 | class NoriDataset(Dataset, Configurable): 12 | r'''Dataset reading from nori. 13 | Args: 14 | nori_paths: Pattern or list, required, the noris containing data, 15 | e.g. `the/path/*.nori`, `['a.nori', 'b.nori']` 16 | ''' 17 | nori_paths = State() 18 | 19 | def __init__(self, nori_paths=None, cmd={}, **kwargs): 20 | self.load_all(**kwargs) 21 | 22 | self.nori_paths = self.list_or_pattern(nori_paths) or self.nori_paths 23 | self.debug = cmd.get('debug', False) 24 | 25 | self.prepare() 26 | self.truncated = False 27 | self.data = None 28 | 29 | def prepare_meta(self, path_or_list): 30 | def add(a_dict: dict, another_dict: dict): 31 | new_dict = dict() 32 | for key, value in another_dict.items(): 33 | if key in a_dict: 34 | new_dict[key] = a_dict[key] + value 35 | else: 36 | new_dict[key] = value 37 | return new_dict 38 | 39 | if isinstance(path_or_list, list): 40 | return functools.reduce(add, [self.prepare_meta(path) for path in path_or_list]) 41 | 42 | assert type(path_or_list) == str, path_or_list 43 | assert path_or_list.endswith('.nori') or path_or_list.endswith('.nori/') 44 | return self.prepare_meta_single(path_or_list) 45 | 46 | def prepare_meta_single(self, path_name): 47 | return self.meta_loader.load_meta(path_name) 48 | 49 | def prepare(self): 50 | self.meta = self.prepare_meta(self.nori_paths) 51 | if self.unpack is None: 52 | self.unpack = self.default_unpack 53 | 54 | # The fetcher is supposed to be initialized in the 55 | # sub-processes, or it will cause CRC Error. 
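# (The same lazy pattern as in LMDBDataset: default_unpack below builds the
# NoriReader on first use, inside the worker process.)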
56 | self.fetcher = None 57 | 58 | self.data_ids = self.meta.get('data_ids', self.meta.get('data_id', [])) 59 | if self.debug: 60 | self.data_ids = self.data_ids[:32] 61 | self.num_samples = len(self.data_ids) 62 | if is_main(): 63 | print(self.num_samples, 'images found') 64 | return self 65 | 66 | def default_unpack(self, data_id, meta): 67 | if self.fetcher is None: 68 | self.fetcher = NoriReader(self.nori_paths) 69 | data = self.fetcher.get(data_id) 70 | image = np.frombuffer(data, dtype=np.uint8) 71 | image = cv2.imdecode(image, cv2.IMREAD_COLOR).astype('float32') 72 | meta['image'] = image 73 | return meta 74 | -------------------------------------------------------------------------------- /data/processes/__init__.py: -------------------------------------------------------------------------------- 1 | from .normalize_image import NormalizeImage 2 | from .make_recognition_label import MakeRecognitionLabel 3 | from .make_keypoint_map import MakeKeyPointMap 4 | from .make_center_points import MakeCenterPoints 5 | from .resize_image import ResizeImage, ResizeData 6 | from .make_seg_recognition_label import MakeSegRecognitionLabel 7 | from .filter_keys import FilterKeys 8 | from .make_center_map import MakeCenterMap 9 | from .augment_data import AugmentData, AugmentDetectionData 10 | from .random_crop_data import RandomCropData 11 | from .make_icdar_data import MakeICDARData, ICDARCollectFN 12 | from .make_seg_detection_data import MakeSegDetectionData 13 | from .make_border_map import MakeBorderMap 14 | from .extract_detetion_data import ExtractDetectionData 15 | from .make_decouple_map import MakeDecoupleMap 16 | -------------------------------------------------------------------------------- /data/processes/augment_data.py: -------------------------------------------------------------------------------- 1 | import imgaug 2 | import numpy as np 3 | 4 | from concern.config import State 5 | from .data_process import DataProcess 6 | from data.augmenter import AugmenterBuilder 7 | 8 | 9 | class AugmentData(DataProcess): 10 | augmenter_args = State(autoload=False) 11 | 12 | def __init__(self, **kwargs): 13 | self.augmenter_args = kwargs.get('augmenter_args') 14 | self.augmenter = AugmenterBuilder().build(self.augmenter_args) 15 | 16 | def may_augment_annotation(self, aug, data, shape): 17 | pass 18 | 19 | def process(self, data): 20 | image = data['image'] 21 | aug = None 22 | shape = image.shape 23 | 24 | if self.augmenter: 25 | aug = self.augmenter.to_deterministic() 26 | data['image'] = aug.augment_image(image) 27 | self.may_augment_annotation(aug, data, shape) 28 | 29 | filename = data.get('filename', data.get('data_id', '')) 30 | data.update(filename=filename, shape=shape[:2]) 31 | return data 32 | 33 | 34 | class AugmentDetectionData(AugmentData): 35 | def may_augment_annotation(self, aug, data, shape): 36 | if aug is None: 37 | return data 38 | 39 | line_polys = [] 40 | for line in data['lines']: 41 | line_polys.append({ 42 | 'points': self.may_augment_poly(aug, shape, line['poly']), 43 | 'ignore': line['text'] == '###', 44 | 'text': line['text'], 45 | }) 46 | data['polys'] = line_polys 47 | return data 48 | 49 | def may_augment_poly(self, aug, img_shape, poly): 50 | keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly] 51 | keypoints = aug.augment_keypoints( 52 | [imgaug.KeypointsOnImage(keypoints, shape=img_shape)])[0].keypoints 53 | poly = [(p.x, p.y) for p in keypoints] 54 | return poly 55 | 56 | 57 | class AugmentTextLine(AugmentData): 58 | def may_augment_annotation(self, aug, data,
shape): 59 | if aug is None: 60 | return data 61 | 62 | lines = data['lines'] 63 | points_shape = lines.quads.points.shape 64 | lines.quads.points = self.may_augment_poly( 65 | aug, shape, lines.quads.points.reshape(-1, 2)).reshape(*points_shape) 66 | for quads in lines.charboxes: 67 | points_shape = quads.points.shape 68 | quads.points = self.may_augment_poly( 69 | aug, shape, quads.points.reshape(-1, 2)).reshape(*points_shape) 70 | return data 71 | 72 | def may_augment_poly(self, aug, img_shape, poly): 73 | keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly] 74 | keypoints = aug.augment_keypoints( 75 | [imgaug.KeypointsOnImage(keypoints, shape=img_shape)])[0].keypoints 76 | poly = [(p.x, p.y) for p in keypoints] 77 | return np.array(poly, dtype=np.float32) 78 | -------------------------------------------------------------------------------- /data/processes/charboxes_from_textlines.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .data_process import DataProcess 4 | 5 | 6 | class CharboxesFromTextlines(DataProcess): 7 | def process(self, data): 8 | text_lines = data['lines'] 9 | 10 | charboxes = None 11 | shape = None 12 | for boxes in text_lines.charboxes: 13 | shape = boxes.shape[1:] 14 | if charboxes is None: 15 | charboxes = boxes 16 | else: 17 | charboxes = np.concatenate([charboxes, boxes], axis=0) 18 | charboxes = np.concatenate(charboxes, axis=0) 19 | if shape is not None: 20 | charboxes = charboxes.reshape(-1, *shape) 21 | data['charboxes'] = charboxes 22 | 23 | return data 24 | -------------------------------------------------------------------------------- /data/processes/data_process.py: -------------------------------------------------------------------------------- 1 | from concern.config import Configurable 2 | 3 | 4 | class DataProcess(Configurable): 5 | r'''Processes of data dict. 
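A minimal sketch of a subclass (hypothetical; the real processes live in
    this directory):
        class AddShape(DataProcess):
            def process(self, data):
                data['shape'] = data['image'].shape[:2]
                return data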
6 | ''' 7 | 8 | def __call__(self, data): 9 | return self.process(data) 10 | 11 | def process(self, data): 12 | raise NotImplementedError 13 | 14 | def render_constant(self, canvas, xmin, xmax, ymin, ymax, value=1, shrink=0): 15 | def shrink_rect(xmin, xmax, ratio): 16 | center = (xmin + xmax) / 2 17 | width = center - xmin 18 | return int(center - width * ratio + 0.5), int(center + width * ratio + 0.5) 19 | 20 | if shrink > 0: 21 | xmin, xmax = shrink_rect(xmin, xmax, shrink) 22 | ymin, ymax = shrink_rect(ymin, ymax, shrink) 23 | 24 | canvas[int(ymin+0.5):int(ymax+0.5)+1, int(xmin+0.5):int(xmax+0.5)+1] = value 25 | return canvas 26 | -------------------------------------------------------------------------------- /data/processes/extract_detetion_data.py: -------------------------------------------------------------------------------- 1 | import imgaug 2 | from concern.config import Configurable, State 3 | from data.augmenter import AugmenterBuilder 4 | 5 | 6 | class ExtractDetectionData(Configurable): 7 | augmenter_args = State(autoload=False) 8 | augmenter_class = State(default=None) 9 | 10 | def __init__(self, augmenter_class=None, **kwargs): 11 | self.augmenter_args = kwargs.get('augmenter_args') 12 | self.augmenter_class = augmenter_class or self.augmenter_class 13 | if self.augmenter_class is not None: 14 | self.augmenter = eval(self.augmenter_class)().augmenter 15 | else: 16 | self.augmenter = AugmenterBuilder().build(self.augmenter_args) 17 | 18 | def may_augment_poly(self, aug, img_shape, poly): 19 | if aug is None: 20 | return poly 21 | keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly] 22 | keypoints = aug.augment_keypoints( 23 | [imgaug.KeypointsOnImage(keypoints, shape=img_shape)])[0].keypoints 24 | poly = [(p.x, p.y) for p in keypoints] 25 | return poly 26 | 27 | def __call__(self, data): 28 | img = data['img'] 29 | aug = None 30 | shape = img.shape 31 | 32 | if self.augmenter: 33 | aug = self.augmenter.to_deterministic() 34 | img = aug.augment_image(data['img']) 35 | 36 | line_polys = [] 37 | for line in data['lines']: 38 | line_polys.append({ 39 | 'points': self.may_augment_poly(aug, shape, line['poly']), 40 | 'ignore': line['text'] == '###', 41 | 'text': line['text'], 42 | }) 43 | filename = data.get('filename', data.get('data_id', '')) 44 | label = { 45 | 'polys': line_polys 46 | } 47 | meta = { 48 | 'data_id': data['data_id'], 49 | 'filename': filename, 50 | 'shape': shape[:2] 51 | } 52 | return img, label, meta 53 | -------------------------------------------------------------------------------- /data/processes/filter_keys.py: -------------------------------------------------------------------------------- 1 | from concern.config import State 2 | 3 | from .data_process import DataProcess 4 | 5 | 6 | class FilterKeys(DataProcess): 7 | required = State(default=[]) 8 | superfluous = State(default=[]) 9 | 10 | def __init__(self, **kwargs): 11 | super().__init__(**kwargs) 12 | 13 | self.required_keys = set(self.required) 14 | self.superfluous_keys = set(self.superfluous) 15 | if len(self.required_keys) > 0 and len(self.superfluous_keys) > 0: 16 | raise ValueError( 17 | 'required_keys and superfluous_keys can not be specified at the same time.') 18 | 19 | def process(self, data): 20 | for key in self.required: 21 | assert key in data, '%s is required in data' % key 22 | 23 | superfluous = set(self.superfluous_keys) # copy, so the configured set is not mutated 24 | if len(superfluous) == 0: 25 | for key in data.keys(): 26 | if key not in self.required_keys: 27 | superfluous.add(key) 28 | 29 | for key in
superfluous: 30 | del data[key] 31 | return data 32 | -------------------------------------------------------------------------------- /data/processes/make_center_map.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.ndimage.filters as fi 3 | 4 | from concern.config import State 5 | 6 | from .data_process import DataProcess 7 | 8 | 9 | class MakeCenterMap(DataProcess): 10 | max_size = State(default=32) 11 | shape = State(default=(64, 256)) 12 | sigma_ratio = State(default=16) 13 | function_name = State(default='sample_gaussian') 14 | points_key = 'points' 15 | correlation = 0 # The formulation of guassian is simplified when correlation is 0 16 | 17 | def process(self, data): 18 | assert self.points_key in data, '%s in data is required' % self.points_key 19 | points = data['points'] * self.shape[::-1] # N, 2 20 | assert points.shape[0] >= self.max_size 21 | func = getattr(self, self.function_name) 22 | data['charmaps'] = func(points, *self.shape) 23 | return data 24 | 25 | def gaussian(self, points, height, width): 26 | index_x, index_y = np.meshgrid(np.linspace(0, width, width), 27 | np.linspace(0, height, height)) 28 | index_x = np.repeat(index_x[np.newaxis], points.shape[0], axis=0) 29 | index_y = np.repeat(index_y[np.newaxis], points.shape[0], axis=0) 30 | mu_x = points[:, 0][:, np.newaxis, np.newaxis] 31 | mu_y = points[:, 1][:, np.newaxis, np.newaxis] 32 | mask_is_zero = ((mu_x == 0) + (mu_y == 0)) == 0 33 | result = np.reciprocal(2 * np.pi * width / self.sigma_ratio * height / self.sigma_ratio)\ 34 | * np.exp(- 0.5 * (np.square((index_x - mu_x) / width * self.sigma_ratio) + 35 | np.square((index_y - mu_y) / height * self.sigma_ratio))) 36 | 37 | result = result / \ 38 | np.maximum(result.max(axis=1, keepdims=True).max( 39 | axis=2, keepdims=True), np.finfo(np.float32).eps) 40 | result = result * mask_is_zero 41 | return result.astype(np.float32) 42 | 43 | def sample_gaussian(self, points, height, width): 44 | points = (points + 0.5).astype(np.int32) 45 | canvas = np.zeros((self.max_size, height, width), dtype=np.float32) 46 | for index in range(canvas.shape[0]): 47 | point = points[index] 48 | canvas[index, point[1], point[0]] = 1. 
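# A unit impulse is placed at each rounded center point; the Gaussian filter
# below spreads it into a soft peak, which is renormalized to a maximum of 1
# and then windowed so each map is non-zero only around its own point.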
49 | if point.sum() > 0: 50 | fi.gaussian_filter(canvas[index], (height // self.sigma_ratio, 51 | width // self.sigma_ratio), 52 | output=canvas[index], mode='mirror') 53 | canvas[index] = canvas[index] / canvas[index].max() 54 | x_range = min(point[0], width - point[0]) 55 | canvas[index, :, :point[0] - x_range] = 0 56 | canvas[index, :, point[0] + x_range:] = 0 57 | y_range = min(point[1], height - point[1]) 58 | canvas[index, :point[1] - y_range, :] = 0 59 | canvas[index, point[1] + y_range:, :] = 0 60 | return canvas 61 | -------------------------------------------------------------------------------- /data/processes/make_center_points.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from concern.config import State 4 | from .data_process import DataProcess 5 | 6 | 7 | class MakeCenterPoints(DataProcess): 8 | box_key = State(default='charboxes') 9 | size = State(default=32) 10 | 11 | def process(self, data): 12 | shape = data['image'].shape[:2] 13 | points = np.zeros((self.size, 2), dtype=np.float32) 14 | boxes = np.array(data[self.box_key])[:self.size] 15 | 16 | size = boxes.shape[0] 17 | points[:size] = boxes.mean(axis=1) 18 | data['points'] = (points / shape[::-1]).astype(np.float32) 19 | return data 20 | -------------------------------------------------------------------------------- /data/processes/make_decouple_map.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.ndimage.filters as fi 3 | 4 | from concern.config import State 5 | 6 | from .data_process import DataProcess 7 | 8 | 9 | class MakeDecoupleMap(DataProcess): 10 | max_size = State(default=32) 11 | shape = State(default=(64, 256)) 12 | sigma = State(default=2) 13 | summation = State(default=False) 14 | box_key = State(default='charboxes') 15 | function = State(default='gaussian') 16 | 17 | def process(self, data): 18 | assert self.box_key in data, '%s in data is required' % self.box_key 19 | shape = data['image'].shape[:2] 20 | boxes = np.array(data[self.box_key]) 21 | 22 | ratio_x = shape[1] / self.shape[1] 23 | boxes[:, :, 0] = (boxes[:, :, 0] / ratio_x).clip(0, self.shape[1]) 24 | ratio_y = shape[0] / self.shape[0] 25 | boxes[:, :, 1] = (boxes[:, :, 1] / ratio_y).clip(0, self.shape[0]) 26 | boxes = (boxes + .5).astype(np.int32) 27 | xmins = boxes[:, :, 0].min(axis=1) 28 | xmaxs = np.maximum(boxes[:, :, 0].max(axis=1), xmins + 1) 29 | 30 | ymins = boxes[:, :, 1].min(axis=1) 31 | ymaxs = np.maximum(boxes[:, :, 1].max(axis=1), ymins + 1) 32 | 33 | if self.summation: 34 | canvas = np.zeros(self.shape, dtype=np.int32) 35 | else: 36 | canvas = np.zeros((self.max_size, *self.shape), dtype=np.float32) 37 | 38 | mask = np.zeros(self.shape, dtype=np.float32) 39 | orders = self.orders(data) 40 | for i in range(xmins.shape[0]): 41 | function = getattr(self, 'render_' + self.function) 42 | order = min(orders[i], self.max_size - 1) # clamp to the last channel to avoid indexing past canvas 43 | if self.summation: 44 | function(canvas, xmins[i], xmaxs[i], ymins[i], ymaxs[i], 45 | value=order+1, shrink=0.6) 46 | else: 47 | function(canvas[order], xmins[i], xmaxs[i], ymins[i], ymaxs[i]) 48 | self.render_gaussian(mask, xmins[i], xmaxs[i], ymins[i], ymaxs[i]) 49 | 50 | data['ordermaps'] = canvas 51 | data['mask'] = mask 52 | return data 53 | 54 | def orders(self, data): 55 | orders = [] 56 | if 'lines' in data: 57 | for text in data['lines'].texts: 58 | orders += list(range(len(text))) 59 | else: 60 | orders = list(range(data[self.box_key].shape[0])) 61 |
return orders 62 | 63 | def render_gaussian_thresh(self, canvas, xmin, xmax, ymin, ymax, 64 | value=1, thresh=0.2, shrink=None): 65 | out = np.zeros_like(canvas).astype(np.float32) 66 | out[(ymax+ymin+1)//2, (xmax+xmin+1)//2] = 1. 67 | h, w = canvas.shape[:2] 68 | out = fi.gaussian_filter(out, (self.sigma, self.sigma), 69 | output=out, mode='mirror') 70 | out = out / out.max() 71 | canvas[out > thresh] = value 72 | 73 | def render_gaussian(self, canvas, xmin, xmax, ymin, ymax): 74 | out = np.zeros_like(canvas) 75 | out[(ymax+ymin+1)//2, (xmax+xmin+1)//2] = 1. 76 | h, w = canvas.shape[:2] 77 | fi.gaussian_filter(out, (self.sigma, self.sigma), 78 | output=out, mode='mirror') 79 | out = out / out.max() 80 | canvas[out > canvas] = out[out > canvas] 81 | 82 | def render_gaussian_fast(self, canvas, xmin, xmax, ymin, ymax): 83 | out = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.float32) 84 | out[(ymax-ymin+1)//2, (xmax-xmin+1)//2] = 1. 85 | h, w = canvas.shape[:2] 86 | fi.gaussian_filter(out, (self.sigma, self.sigma), 87 | output=out, mode='mirror') 88 | out = out / out.max() 89 | canvas[ymin:ymax+1, xmin:xmax+1] = np.maximum(out, canvas[ymin:ymax+1, xmin:xmax+1]) 90 | -------------------------------------------------------------------------------- /data/processes/make_icdar_data.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | from concern.config import Configurable, State 8 | from .data_process import DataProcess 9 | 10 | 11 | class MakeICDARData(DataProcess): 12 | shrink_ratio = State(default=0.4) 13 | 14 | def __init__(self, debug=False, cmd={}, **kwargs): 15 | self.load_all(**kwargs) 16 | 17 | self.debug = debug 18 | if 'debug' in cmd: 19 | self.debug = cmd['debug'] 20 | 21 | def process(self, data): 22 | polygons = [] 23 | ignore_tags = [] 24 | annotations = data['polys'] 25 | for annotation in annotations: 26 | polygons.append(annotation['points']) 27 | ignore_tags.append(annotation['ignore']) 28 | ignore_tags = np.array(ignore_tags, dtype=np.uint8) 29 | polygons = np.array(polygons) 30 | filename = data.get('filename', data['data_id']) 31 | if self.debug: 32 | self.draw_polygons(data['image'], polygons, ignore_tags) 33 | shape = np.array(data['shape']) 34 | return OrderedDict(image=data['image'], 35 | polygons=polygons, 36 | ignore_tags=ignore_tags, 37 | shape=shape, 38 | filename=filename) 39 | 40 | def draw_polygons(self, image, polygons, ignore_tags): 41 | import cv2 42 | for i in range(polygons.shape[0]): 43 | polygon = polygons[i].reshape(4, 2).astype(np.int32) 44 | ignore = ignore_tags[i] 45 | if ignore: 46 | color = (255, 0, 0) # depict ignorable polygons in blue 47 | else: 48 | color = (0, 0, 255) # depict polygons in red 49 | 50 | cv2.polylines(image, [polygon], True, color, 1) 51 | 52 | polylines = staticmethod(draw_polygons) 53 | 54 | 55 | class ICDARCollectFN(Configurable): 56 | def __init__(self, *args, **kwargs): 57 | pass 58 | 59 | def __call__(self, batch): 60 | data_dict = OrderedDict() 61 | for sample in batch: 62 | for k, v in sample.items(): 63 | if k not in data_dict: 64 | data_dict[k] = [] 65 | if isinstance(v, np.ndarray): 66 | v = torch.from_numpy(v) 67 | data_dict[k].append(v) 68 | data_dict['image'] = torch.stack(data_dict['image'], 0) 69 | return data_dict 70 | -------------------------------------------------------------------------------- /data/processes/make_keypoint_map.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | from concern.config import State 5 | 6 | from .data_process import DataProcess 7 | 8 | 9 | class MakeKeyPointMap(DataProcess): 10 | max_size = State(default=32) 11 | box_key = State(default='charboxes') 12 | shape = State(default=[16, 64]) 13 | 14 | def process(self, data): 15 | assert self.box_key in data, '%s in data is required' % self.box_key 16 | 17 | ori_h, ori_w = data['image'].shape[:2] 18 | charboxes = data[self.box_key] 19 | height, width = self.shape 20 | ratio_h, ratio_w = float(height) / ori_h, float(width) / ori_w 21 | boxes = np.zeros_like(charboxes) 22 | boxes[..., 0] = (charboxes[..., 0] * ratio_w + 0.5).astype(np.int32) 23 | boxes[..., 1] = (charboxes[..., 1] * ratio_h + 0.5).astype(np.int32) 24 | charmaps = self.gen_keypoint_map(boxes.astype(np.float32), self.shape[0], self.shape[1]) 25 | data['charmaps'] = charmaps 26 | return data 27 | 28 | def get_gaussian(self, h, w, m): 29 | xs, ys = np.meshgrid(np.arange(w), np.arange(h)) 30 | g_map = np.exp(-m * (np.power((xs * 2 / w - 1), 2) + np.power((ys * 2 / h - 1), 2))) 31 | return g_map 32 | 33 | def gen_keypoint_map(self, boxes, h, w): 34 | maps = np.zeros((self.max_size, h, w), dtype=np.float32) 35 | for ind, box in enumerate(boxes): 36 | _, _, box_w, box_h = cv2.boundingRect(box) 37 | src = np.array([[0, box_h], [box_w, box_h], [box_w, 0], [0, 0]]).astype(np.float32) 38 | g = self.get_gaussian(box_h, box_w, m=2) 39 | M = cv2.getPerspectiveTransform(src, box.astype(np.float32)) 40 | maps[ind] = cv2.warpPerspective(g, M, (w, h)) 41 | return maps 42 | -------------------------------------------------------------------------------- /data/processes/make_recognition_label.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from concern.config import State 4 | from concern.charsets import DefaultCharset 5 | 6 | from .data_process import DataProcess 7 | 8 | 9 | class MakeRecognitionLabel(DataProcess): 10 | charset = State(default=DefaultCharset()) 11 | max_size = State(default=32) 12 | 13 | def process(self, data): 14 | assert 'gt' in data, '`gt` in data is required by this process' 15 | gt = data['gt'] 16 | label = self.gt_to_label(gt) 17 | data['label'] = label 18 | if label.sum() == 0: 19 | raise ValueError('Empty Label') # FIXME: package into a class.
20 | 21 | length = len(gt) 22 | if self.max_size is not None: 23 | length = min(length, self.max_size) 24 | length = np.array(length, dtype=np.int32) 25 | data['length'] = length 26 | return data 27 | 28 | def gt_to_label(self, gt, image=None): 29 | if self.max_size is not None: 30 | return self.charset.string_to_label(gt)[:self.max_size] 31 | else: 32 | return self.charset.string_to_label(gt) 33 | -------------------------------------------------------------------------------- /data/processes/make_seg_recognition_label.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from shapely.geometry import Polygon 4 | import pyclipper 5 | import torch 6 | import torch.nn.functional as F 7 | from concern.config import State 8 | from concern.charsets import DefaultCharset 9 | 10 | from .data_process import DataProcess 11 | 12 | 13 | class MakeSegRecognitionLabel(DataProcess): 14 | charset = State(default=DefaultCharset()) 15 | shrink = State(default=True) 16 | shrink_ratio = State(default=0.25) 17 | exempt_chars = State(default=list('')) 18 | shape = State(default=(16, 64)) 19 | 20 | max_size = 256 21 | 22 | def process(self, data: dict): 23 | assert 'charboxes' in data, 'charboxes in data is required' 24 | ori_height, ori_width = data['image'].shape[:2] 25 | charboxes = data['charboxes'] 26 | gt = data['gt'] 27 | height, width = self.shape 28 | ratio_h, ratio_w = float(height) / ori_height, float(width) / ori_width 29 | assert len(charboxes) == len(gt) 30 | mask = np.zeros((height, width), dtype=np.float32) 31 | classify = np.zeros((height, width), dtype=np.int32) 32 | order_map = np.zeros((height, width), dtype=np.int32) 33 | shrink_ratio = self.shrink_ratio 34 | boxes = np.zeros_like(charboxes) 35 | boxes[..., 0] = (charboxes[..., 0] * ratio_w + 0.5).astype(np.int32) 36 | boxes[..., 1] = (charboxes[..., 1] * ratio_h + 0.5).astype(np.int32) 37 | for box_index, box in enumerate(boxes): 38 | class_code = self.charset.index(gt[box_index]) 39 | if self.shrink: 40 | if self.charset.is_empty(class_code) or gt[box_index] in self.exempt_chars: 41 | shrink_ratio = 0 42 | try: 43 | rect = self.poly_to_rect(box, shrink_ratio) 44 | except AssertionError: 45 | # invalid poly 46 | continue 47 | else: 48 | rect = box 49 | if rect is not None: 50 | self.render_rect(mask, rect) 51 | self.render_rect(classify, rect, class_code) 52 | self.render_rect(order_map, rect, box_index + 1) 53 | data['mask'] = mask 54 | data['classify'] = classify 55 | if classify.sum() == 0: 56 | raise ValueError('gt is empty!')
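# order_map stores, per pixel, the 1-based index of the character box that
# covers it; zero marks background.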
57 | data['ordermaps'] = order_map 58 | return data 59 | 60 | def poly_to_rect(self, poly, shrink_ratio=None): 61 | if shrink_ratio is None: 62 | shrink_ratio = self.shrink_ratio 63 | polygon_shape = Polygon(poly) 64 | distance = polygon_shape.area * \ 65 | (1 - np.power(shrink_ratio, 2)) / polygon_shape.length 66 | subject = [tuple(l) for l in poly] 67 | padding = pyclipper.PyclipperOffset() 68 | padding.AddPath(subject, pyclipper.JT_ROUND, 69 | pyclipper.ET_CLOSEDPOLYGON) 70 | shrinked = padding.Execute(-distance) 71 | if shrinked == []: 72 | return None 73 | return shrinked 74 | 75 | def render_rect(self, canvas, poly, value=1): 76 | poly = np.array(poly, dtype=np.int32).reshape(-1, 2) 77 | return cv2.fillPoly(canvas, [poly], value) 78 | ''' 79 | xmin, xmax = poly[:, 0].min(), poly[:, 0].max() 80 | ymin, ymax = poly[:, 1].min(), poly[:, 1].max() 81 | canvas[ymin:ymax+1, xmin:xmax+1] = value 82 | return canvas 83 | ''' 84 | -------------------------------------------------------------------------------- /data/processes/normalize_image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from .data_process import DataProcess 5 | 6 | 7 | class NormalizeImage(DataProcess): 8 | RGB_MEAN = np.array([122.67891434, 116.66876762, 104.00698793]) 9 | 10 | def process(self, data): 11 | assert 'image' in data, '`image` in data is required by this process' 12 | image = data['image'] 13 | image -= self.RGB_MEAN 14 | image /= 255. 15 | image = torch.from_numpy(image).permute(2, 0, 1).float() 16 | data['image'] = image 17 | return data 18 | 19 | @classmethod 20 | def restore(self, image): 21 | image = image.permute(1, 2, 0).to('cpu').numpy() 22 | image = image * 255. 23 | image += self.RGB_MEAN 24 | image = image.astype(np.uint8) 25 | return image 26 | -------------------------------------------------------------------------------- /data/processes/serialize_box.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from data.quad import Quad 4 | from concern.config import State 5 | from .data_process import DataProcess 6 | 7 | 8 | class SerializeBox(DataProcess): 9 | box_key = State(default='charboxes') 10 | format = State(default='NP2') 11 | 12 | def process(self, data): 13 | data[self.box_key] = data['lines'].quads 14 | return data 15 | 16 | 17 | class UnifyRect(SerializeBox): 18 | max_size = State(default=64) 19 | 20 | def process(self, data): 21 | h, w = data['image'].shape[:2] 22 | boxes = np.zeros((self.max_size, 4), dtype=np.float32) 23 | mask_has_box = np.zeros(self.max_size, dtype=np.float32) 24 | data = super().process(data) 25 | quad = data[self.box_key] 26 | assert quad.shape[0] <= self.max_size 27 | boxes[:quad.shape[0]] = quad.rectify() / np.array([w, h, w, h]).reshape(1, 4) 28 | mask_has_box[:quad.shape[0]] = 1. 
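# The rectified boxes are normalized to [0, 1] by the image width and height;
# mask_has_box flags which of the max_size slots hold a real box.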
29 | data['boxes'] = boxes 30 | data['mask_has_box'] = mask_has_box 31 | return data 32 | -------------------------------------------------------------------------------- /data/quad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class Quad: 6 | def __init__(self, points, format='NP2'): 7 | self._rect = None 8 | self.tensorized = False 9 | self._points = None 10 | self.set_points(points, format) 11 | 12 | @property 13 | def points(self): 14 | return self._points 15 | 16 | def set_points(self, new_points, format='NP2'): 17 | order = (format.index('N'), format.index('P'), format.index('2')) 18 | 19 | if isinstance(new_points, torch.Tensor): 20 | self._points = new_points.permute(*order) 21 | self.tensorized = True 22 | else: 23 | points = np.array(new_points, dtype=np.float32) 24 | self._points = points.transpose(*order) 25 | 26 | if self.tensorized and not isinstance(self._points, torch.Tensor): 27 | self.tensorized = False 28 | self.tensor # re-tensorize the freshly assigned numpy points 29 | 30 | @points.setter 31 | def points(self, new_points): 32 | self.set_points(new_points) 33 | 34 | @property 35 | def tensor(self): 36 | if not self.tensorized: 37 | self._points = torch.from_numpy(self._points); self.tensorized = True 38 | return self._points 39 | 40 | def to(self, device): 41 | self._points = self._points.to(device) 42 | return self._points 43 | 44 | def __iter__(self): 45 | for i in range(self._points.shape[0]): 46 | if self.tensorized: 47 | yield self.tensor[i] 48 | else: 49 | yield self.points[i] 50 | 51 | 52 | def rect(self): 53 | if self._rect is None: 54 | self._rect = self.rectify() 55 | return self._rect 56 | 57 | def __getitem__(self, *args, **kwargs): 58 | return self._points.__getitem__(*args, **kwargs) 59 | 60 | def numpy(self): 61 | if not self.tensorized: 62 | return self._points 63 | return self._points.cpu().data.numpy() 64 | 65 | def rectify(self): 66 | if self.tensorized: 67 | return self.rectify_tensor() 68 | 69 | xmin = self._points[:, :, 0].min(axis=1) 70 | ymin = self._points[:, :, 1].min(axis=1) 71 | xmax = self._points[:, :, 0].max(axis=1) 72 | ymax = self._points[:, :, 1].max(axis=1) 73 | return np.stack([xmin, ymin, xmax, ymax], axis=1) 74 | 75 | def rectify_tensor(self): 76 | xmin, _ = self.tensor[:, :, 0].min(dim=1, keepdim=True) 77 | ymin, _ = self.tensor[:, :, 1].min(dim=1, keepdim=True) 78 | xmax, _ = self.tensor[:, :, 0].max(dim=1, keepdim=True) 79 | ymax, _ = self.tensor[:, :, 1].max(dim=1, keepdim=True) 80 | return torch.cat([xmin, ymin, xmax, ymax], dim=1) 81 | 82 | def __getattribute__(self, name): 83 | try: 84 | return super().__getattribute__(name) 85 | except AttributeError: 86 | return self._points.__getattribute__(name) 87 | -------------------------------------------------------------------------------- /data/text_lines.py: -------------------------------------------------------------------------------- 1 | from .quad import Quad 2 | 3 | 4 | class TextLines: 5 | ''' 6 | The abstract class of text lines in an input image for scene text detection and recognition. 7 | Input: 8 | lines: 9 | - text: the text content of a text instance. 10 | poly: the quadrangle box of the text instance. 11 | charboxes: the quadrangle boxes of the characters inside the corresponding text instance.
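Example (illustrative values; charboxes is optional):
        lines = [dict(text='hi',
                      poly=[[0, 0], [20, 0], [20, 10], [0, 10]],
                      charboxes=[...])]
        text_lines = TextLines(lines)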
12 |     '''
13 |     def __init__(self, lines, with_charboxes=True):
14 |         self.texts = []
15 |         quads = []
16 |         self.charboxes = []
17 |         for line in lines:
18 |             self.texts.append(line['text'])
19 |             quads.append(line['poly'])
20 |             if with_charboxes and 'charboxes' in line:
21 |                 self.charboxes.append(Quad(line['charboxes']))
22 |         self.with_charboxes = len(self.charboxes) > 0
23 |         self.quads = Quad(quads)
24 |         self._rects = None
25 | 
26 |     def __iter__(self):
27 |         for text, quad in zip(self.texts, self.quads):
28 |             yield (text, quad)
29 | 
30 |     @property
31 |     def rects(self):
32 |         if self._rects is None:
33 |             self._rects = self.quads.rectify()
34 |         return self._rects
35 | 
36 |     def __len__(self):
37 |         return len(self.texts)
38 | 
39 |     def char_count(self):
40 |         return sum([len(t) for t in self.texts])
41 | 
-------------------------------------------------------------------------------- /data/unpack_msgpack_data.py: --------------------------------------------------------------------------------
1 | import io
2 | 
3 | import cv2
4 | import numpy as np
5 | import msgpack
6 | import config
7 | import os
8 | import lmdb
9 | from PIL import Image
10 | 
11 | from concern.config import Configurable, State
12 | if config.will_use_nori:
13 |     import nori2 as nori  # optional internal dependency (see scripts/nori_to_lmdb.py)
14 | class UnpackMsgpackData(Configurable):
15 |     mode = State(default='BGR')
16 | 
17 |     def __init__(self, cmd={}, **kwargs):
18 |         self.load_all(**kwargs)
19 |         if config.will_use_nori:
20 |             self.fetcher = nori.Fetcher()
21 |         elif config.will_use_lmdb:
22 |             self.envs = {}
23 |             self.txns = {}
24 |         if 'mode' in cmd:
25 |             self.mode = cmd['mode']
26 | 
27 |     def convert_obj(self, obj):
28 |         if isinstance(obj, dict):
29 |             ndata = {}
30 |             for key, value in obj.items():
31 |                 nkey = key.decode()
32 |                 if nkey == 'img':
33 |                     img = Image.open(io.BytesIO(value))
34 |                     img = np.array(img.convert('RGB'))
35 |                     if self.mode == 'BGR':
36 |                         img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
37 |                     nvalue = img
38 |                 else:
39 |                     nvalue = self.convert_obj(value)
40 |                 ndata[nkey] = nvalue
41 |             return ndata
42 |         elif isinstance(obj, list):
43 |             return [self.convert_obj(item) for item in obj]
44 |         elif isinstance(obj, bytes):
45 |             return obj.decode()
46 |         else:
47 |             return obj
48 | 
49 |     def convert(self, data):
50 |         obj = msgpack.loads(data, max_str_len=2 ** 31)
51 |         return self.convert_obj(obj)
52 | 
53 |     def __call__(self, data_id, meta=None):
54 |         if meta is None:
55 |             meta = {}
56 |         item = self.convert(self.fetcher.get(data_id))
57 |         item['data_id'] = data_id
58 |         meta.update(item)
59 |         return meta
60 | 
61 | 
62 | class TransformMsgpackData(UnpackMsgpackData):
63 |     meta_loader = State(default=None)
64 | 
65 |     def __init__(self, meta_loader=None, cmd={}, **kwargs):
66 |         super().__init__(cmd=cmd, meta_loader=meta_loader, **kwargs)
67 | 
68 |         self.meta_loader = cmd.get('meta_loader', self.meta_loader)
69 | 
70 |     def get_item(self, data_id, meta):
71 |         if config.will_use_nori:
72 |             item = self.fetcher.get(data_id)
73 |         elif config.will_use_lmdb:
74 |             db_path = meta['db_path']
75 |             if db_path not in self.envs:
76 |                 path = os.path.join(db_path, '')
77 |                 env = lmdb.open(db_path, max_dbs=1, lock=False)
78 |                 db_image = env.open_db('image'.encode())
79 |                 self.envs[db_path] = env
80 |                 self.txns[db_path] = env.begin(db=db_image)
81 |             item = self.txns[db_path].get(data_id)
82 |         else:
83 |             raise NotImplementedError
84 |         return item
85 | 
86 |     def __call__(self, data_id, meta):
87 |         item = self.get_item(data_id, meta)
88 |         item = self.convert(item)
89 |         image = item.pop('img').astype(np.float32)
90 |         if self.meta_loader is not None:
91 |
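# --- Illustration (not from the repo): a msgpack round trip matching
# `convert` above. With `raw=True` (the pre-1.0 default this code targets),
# keys and values come back as bytes, which is why `convert_obj` decodes
# them and routes 'img' payloads through PIL.
import io
import msgpack
import numpy as np
from PIL import Image

buf = io.BytesIO()
Image.fromarray(np.zeros((4, 4, 3), dtype=np.uint8)).save(buf, format='PNG')
packed = msgpack.dumps({b'img': buf.getvalue(), b'text': b'hello'})
obj = msgpack.loads(packed, raw=True)  # keys/values arrive as bytes
assert b'img' in obj and obj[b'text'] == b'hello'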
meta['extra'] = item 92 | data = self.meta_loader.parse_meta(data_id, meta) 93 | else: 94 | data = meta 95 | data.update(**item) 96 | data.update(image=image, data_id=data_id) 97 | return data 98 | -------------------------------------------------------------------------------- /decoders/__init__.py: -------------------------------------------------------------------------------- 1 | from .classification import ClassificationDecoder 2 | from .attention_decoder import AttentionDecoder 3 | from .textsnake import TextsnakeDecoder 4 | from .east import EASTDecoder 5 | from .dice_loss import DiceLoss 6 | from .pss_loss import PSS_Loss 7 | from .ctc_decoder2d import CTCDecoder2D 8 | from .simple_detection import SimpleSegDecoder, SimpleEASTDecoder, SimpleTextsnakeDecoder, SimpleMSRDecoder 9 | from .ctc_decoder import CTCDecoder 10 | from .l1_loss import MaskL1Loss 11 | from .balance_cross_entropy_loss import BalanceCrossEntropyLoss 12 | from .crnn import CRNNDecoder 13 | from .seg_recognizer import SegRecognizer 14 | from .seg_detector import SegDetector -------------------------------------------------------------------------------- /decoders/balance_cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BalanceCrossEntropyLoss(nn.Module): 6 | ''' 7 | Balanced cross entropy loss. 8 | Shape: 9 | - Input: :math:`(N, 1, H, W)` 10 | - GT: :math:`(N, 1, H, W)`, same shape as the input 11 | - Mask: :math:`(N, H, W)`, same spatial shape as the input 12 | - Output: scalar. 13 | 14 | Examples:: 15 | 16 | >>> m = nn.Sigmoid() 17 | >>> loss = nn.BCELoss() 18 | >>> input = torch.randn(3, requires_grad=True) 19 | >>> target = torch.empty(3).random_(2) 20 | >>> output = loss(m(input), target) 21 | >>> output.backward() 22 | ''' 23 | 24 | def __init__(self, negative_ratio=3.0, eps=1e-6): 25 | super(BalanceCrossEntropyLoss, self).__init__() 26 | self.negative_ratio = negative_ratio 27 | self.eps = eps 28 | 29 | def forward(self, 30 | pred: torch.Tensor, 31 | gt: torch.Tensor, 32 | mask: torch.Tensor, 33 | return_origin=False): 34 | ''' 35 | Args: 36 | pred: shape :math:`(N, 1, H, W)`, the prediction of network 37 | gt: shape :math:`(N, 1, H, W)`, the target 38 | mask: shape :math:`(N, H, W)`, the mask indicates positive regions 39 | ''' 40 | positive = (gt * mask).byte() 41 | negative = ((1 - gt) * mask).byte() 42 | positive_count = int(positive.float().sum()) 43 | negative_count = min(int(negative.float().sum()), 44 | int(positive_count * self.negative_ratio)) 45 | loss = nn.functional.binary_cross_entropy( 46 | pred, gt, reduction='none')[:, 0, :, :] 47 | positive_loss = loss * positive.float() 48 | negative_loss = loss * negative.float() 49 | negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count) 50 | 51 | balance_loss = (positive_loss.sum() + negative_loss.sum()) /\ 52 | (positive_count + negative_count + self.eps) 53 | 54 | if return_origin: 55 | return balance_loss, loss 56 | return balance_loss 57 | -------------------------------------------------------------------------------- /decoders/classification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ClassificationDecoder(nn.Module): 6 | 7 | def __init__(self): 8 | super(ClassificationDecoder, self).__init__() 9 | 10 | self.fc = torch.nn.Linear(256, 10) 11 | self.criterion = torch.nn.CrossEntropyLoss() 12 | 13 | def forward(self, 
feature_map, targets=None, train=False): 14 | x = torch.max(torch.max(feature_map, dim=3)[0], dim=2)[0] 15 | x = self.fc(x) 16 | pred = x 17 | if train: 18 | loss = self.criterion(pred, targets) 19 | return loss, pred 20 | else: 21 | return pred 22 | -------------------------------------------------------------------------------- /decoders/ctc_decoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : ctc_decoder.py 4 | # Author : Zhaoyi Wan 5 | # Date : 18.12.2018 6 | # Last Modified Date: 20.01.2019 7 | # Last Modified By : Zhaoyi Wan 8 | 9 | import torch 10 | import torch.nn as nn 11 | from concern.charsets import DefaultCharset 12 | 13 | 14 | class CTCDecoder(nn.Module): 15 | def __init__(self, in_channels, charset=DefaultCharset(), inner_channels=256, **kwargs): 16 | super(CTCDecoder, self).__init__() 17 | self.ctc_loss = nn.CTCLoss(reduction='mean') 18 | self.inner_channels = inner_channels 19 | self.encode = self._init_encoder(in_channels) 20 | 21 | self.pred_conv = nn.Conv2d( 22 | inner_channels, len(charset), kernel_size=1, bias=True, padding=0) 23 | self.softmax = nn.LogSoftmax(dim=1) 24 | 25 | self.blank = 0 26 | if 'blank' in kwargs: 27 | self.blank = kwargs['blank'] 28 | 29 | def _init_encoder(self, in_channels, stride=(2, 1), padding=(0, 1)): 30 | encode = nn.Sequential( 31 | self.conv_bn_relu(in_channels, self.inner_channels), 32 | self.conv_bn_relu(self.inner_channels, self.inner_channels), 33 | nn.MaxPool2d((2, 2), (2, 2), (0, 0)), 34 | self.conv_bn_relu(self.inner_channels, self.inner_channels), 35 | self.conv_bn_relu(self.inner_channels, self.inner_channels), 36 | nn.MaxPool2d(stride, stride, (0, 0)), 37 | self.conv_bn_relu(self.inner_channels, self.inner_channels), 38 | self.conv_bn_relu(self.inner_channels, self.inner_channels), 39 | nn.MaxPool2d(stride, stride, (0, 0)), 40 | self.conv_bn_relu(self.inner_channels, self.inner_channels, 41 | kernel_size=(2, 3), 42 | stride=stride, padding=padding), 43 | ) 44 | return encode 45 | 46 | def conv_bn_relu(self, input_channels, output_channels, 47 | kernel_size=3, stride=1, padding=1): 48 | return nn.Sequential(nn.Conv2d( 49 | input_channels, output_channels, 50 | kernel_size=kernel_size, stride=stride, padding=padding), 51 | nn.BatchNorm2d(output_channels), 52 | nn.ReLU(inplace=True),) 53 | 54 | def forward(self, feature, targets=None, lengths=None, train=False): 55 | pred = self.encode(feature) 56 | pred = self.pred_conv(pred) 57 | if train: 58 | pred = self.softmax(pred) 59 | pred = pred.select(2, 0) # N, C, W 60 | pred = pred.permute(2, 0, 1) # W, N, C 61 | input_lengths = torch.zeros((feature.size()[0], ), dtype=torch.int) + 32 62 | loss = self.ctc_loss(pred, targets, input_lengths, lengths) 63 | return loss, pred.permute(1, 2, 0) 64 | else: 65 | pred = nn.functional.softmax(pred, dim=1) 66 | return pred 67 | -------------------------------------------------------------------------------- /decoders/ctc_decoder2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from concern.charsets import DefaultCharset 5 | 6 | 7 | class CTCDecoder2D(nn.Module): 8 | def __init__(self, in_channels, charset=DefaultCharset(), 9 | inner_channels=256, stride=1, blank=0, **kwargs): 10 | super(CTCDecoder2D, self).__init__() 11 | self.charset = charset 12 | from ops import ctc_loss_2d 13 | self.ctc_loss = ctc_loss_2d 14 | 15 | self.inner_channels = 
inner_channels 16 | self.pred_mask = nn.Sequential( 17 | nn.AvgPool2d(kernel_size=(stride, stride), 18 | stride=(stride, stride)), 19 | nn.Conv2d(in_channels, inner_channels, kernel_size=3, padding=1), 20 | nn.Conv2d(inner_channels, 1, kernel_size=1), 21 | nn.Softmax(dim=2)) 22 | 23 | self.pred_classify = nn.Sequential( 24 | nn.AvgPool2d(kernel_size=(stride, stride), 25 | stride=(stride, stride)), 26 | nn.Conv2d(in_channels, inner_channels, kernel_size=3, padding=1), 27 | nn.Conv2d(inner_channels, len(charset), kernel_size=1)) 28 | self.blank = blank 29 | self.tiny = torch.tensor(torch.finfo().tiny, requires_grad=False) 30 | self.register_buffer('saved_tiny', self.tiny) 31 | 32 | def forward(self, feature, targets=None, lengths=None, train=False, 33 | masks=None, segs=None): 34 | tiny = self.saved_tiny 35 | if isinstance(feature, tuple): 36 | feature = feature[-1] 37 | masking = self.pred_mask(feature) 38 | # mask = masking / torch.max(masking.sum(dim=2, keepdim=True), tiny) 39 | mask = masking 40 | classify = self.pred_classify(feature) 41 | classify = nn.functional.softmax(classify, dim=1) 42 | if self.training: 43 | pred = mask * classify # N, C, H ,W 44 | pred = torch.log(torch.max(pred, tiny)) 45 | pred = pred.permute(3, 2, 0, 1).contiguous() # W, H, N, C 46 | input_lengths = torch.zeros( 47 | (feature.size()[0], )) + pred.shape[0] 48 | loss = self.ctc_loss(pred, targets.long(), input_lengths.long().to( 49 | pred.device), lengths.long()) / lengths.float() 50 | # return loss, pred.permute(2, 3, 1, 0) 51 | return loss, pred 52 | else: 53 | return classify, mask 54 | 55 | def mask_loss(self, mask, weight, gt): 56 | batch_size, _, height, _ = mask.shape 57 | loss = nn.functional.nll_loss( 58 | (mask.permute(0, 1, 3, 2).reshape(-1, height) + self.saved_tiny).log(), 59 | gt.reshape(-1), reduction='none').view(batch_size, -1).mean(1) 60 | return loss * weight 61 | 62 | def classify_loss(self, classify, weight, gt): 63 | batch_size, classes = classify.shape[:2] 64 | loss = nn.functional.nll_loss( 65 | (classify.permute(0, 2, 3, 1).reshape(-1, classes) + self.saved_tiny).log(), 66 | gt.reshape(-1), reduction='none').view(batch_size, -1) 67 | position_weights = (gt.view(batch_size, -1) == self.blank).float() 68 | loss = loss * position_weights 69 | 70 | return loss.mean(1) * weight 71 | -------------------------------------------------------------------------------- /decoders/east.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class EASTDecoder(nn.Module): 7 | def __init__(self, channels=256, heatmap_ratio=1.0, densebox_ratio=0.01, densebox_rescale_factor=512): 8 | nn.Module.__init__(self) 9 | 10 | self.heatmap_ratio = heatmap_ratio 11 | self.densebox_ratio = densebox_ratio 12 | self.densebox_rescale_factor = densebox_rescale_factor 13 | 14 | self.head_layer = nn.Sequential( 15 | nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1), 16 | nn.BatchNorm2d(channels), 17 | nn.ReLU(inplace=True), 18 | nn.ConvTranspose2d(channels, channels // 2, kernel_size=2, stride=2, padding=0), 19 | nn.BatchNorm2d(channels // 2), 20 | nn.ReLU(inplace=True), 21 | nn.ConvTranspose2d(channels // 2, channels // 4, kernel_size=2, stride=2, padding=0), 22 | ) 23 | 24 | self.heatmap_pred_layer = nn.Sequential( 25 | nn.Conv2d(channels // 4, 1, kernel_size=1, stride=1, padding=0), 26 | ) 27 | 28 | self.densebox_pred_layer = nn.Sequential( 29 | nn.Conv2d(channels // 4, 8, 
kernel_size=1, stride=1, padding=0), 30 | ) 31 | 32 | def forward(self, input, label, meta, train): 33 | heatmap = label['heatmap'] 34 | heatmap_weight = label['heatmap_weight'] 35 | densebox = label['densebox'] 36 | densebox_weight = label['densebox_weight'] 37 | 38 | feature = self.head_layer(input) 39 | heatmap_pred = self.heatmap_pred_layer(feature) 40 | densebox_pred = self.densebox_pred_layer(feature) * self.densebox_rescale_factor 41 | 42 | heatmap_loss = F.binary_cross_entropy_with_logits(heatmap_pred, heatmap, reduction='none') 43 | heatmap_loss = (heatmap_loss * heatmap_weight).mean(dim=(1, 2, 3)) 44 | densebox_loss = F.mse_loss(densebox_pred, densebox, reduction='none') 45 | densebox_loss = (densebox_loss * densebox_weight).mean(dim=(1, 2, 3)) 46 | 47 | loss = heatmap_loss * self.heatmap_ratio + densebox_loss * self.densebox_ratio 48 | 49 | pred = { 50 | 'heatmap': F.sigmoid(heatmap_pred), 51 | 'densebox': densebox_pred, 52 | } 53 | metrics = { 54 | 'heatmap_loss': heatmap_loss, 55 | 'densebox_loss': densebox_loss, 56 | } 57 | if train: 58 | return loss, pred, metrics 59 | else: 60 | return pred 61 | -------------------------------------------------------------------------------- /decoders/l1_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MaskL1Loss(nn.Module): 6 | def __init__(self): 7 | super(MaskL1Loss, self).__init__() 8 | 9 | def forward(self, pred: torch.Tensor, gt, mask): 10 | loss = (torch.abs(pred[:, 0] - gt) * mask).sum() / mask.sum() 11 | return loss, dict(l1_loss=loss) 12 | 13 | 14 | class BalanceL1Loss(nn.Module): 15 | def __init__(self, negative_ratio=3.): 16 | super(BalanceL1Loss, self).__init__() 17 | self.negative_ratio = negative_ratio 18 | 19 | def forward(self, pred: torch.Tensor, gt, mask): 20 | ''' 21 | Args: 22 | pred: (N, 1, H, W). 23 | gt: (N, H, W). 24 | mask: (N, H, W). 
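# --- Illustration (not from the repo): the hard-negative balancing used by
# BalanceL1Loss below (and BalanceCrossEntropyLoss above). Negatives are
# capped at negative_ratio times the positive count, and only the
# largest-loss negatives are kept, via torch.topk. Toy tensors only.
import torch

loss = torch.rand(2, 8, 8)                  # per-pixel loss
mask = (torch.rand(2, 8, 8) > 0.8).float()  # positive-region mask
positive_count = int(mask.sum())
negative_count = min(int((1 - mask).sum()), positive_count * 3)
hard_neg, _ = torch.topk((loss * (1 - mask)).view(-1), negative_count)
balanced = (loss * mask).sum() / max(positive_count, 1) + \
    hard_neg.sum() / max(negative_count, 1)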
25 | ''' 26 | loss = torch.abs(pred[:, 0] - gt) 27 | positive = loss * mask 28 | negative = loss * (1 - mask) 29 | positive_count = int(mask.sum()) 30 | negative_count = min( 31 | int((1 - mask).sum()), 32 | int(positive_count * self.negative_ratio)) 33 | negative_loss, _ = torch.topk(negative.view(-1), negative_count) 34 | negative_loss = negative_loss.sum() / negative_count 35 | positive_loss = positive.sum() / positive_count 36 | return positive_loss + negative_loss,\ 37 | dict(l1_loss=positive_loss, nge_l1_loss=negative_loss) 38 | -------------------------------------------------------------------------------- /experiment.py: -------------------------------------------------------------------------------- 1 | from concern.config import Configurable, State 2 | from concern.log import Logger 3 | from structure.builder import Builder 4 | from structure.representers import * 5 | from structure.measurers import * 6 | from structure.visualizer import TrivalVisualizer 7 | from structure.visualizers import * 8 | from data.data_loader import * 9 | from data import * 10 | from training.model_saver import ModelSaver 11 | from training.checkpoint import Checkpoint 12 | from training.optimizer_scheduler import OptimizerScheduler 13 | 14 | 15 | class Structure(Configurable): 16 | builder = State() 17 | representer = State() 18 | measurer = State() 19 | visualizer = State() 20 | 21 | def __init__(self, **kwargs): 22 | self.load_all(**kwargs) 23 | 24 | @property 25 | def model_name(self): 26 | return self.builder.model_name 27 | 28 | 29 | class TrainSettings(Configurable): 30 | data_loader = State() 31 | model_saver = State() 32 | checkpoint = State() 33 | scheduler = State() 34 | epochs = State(default=10) 35 | 36 | def __init__(self, **kwargs): 37 | kwargs['cmd'].update(is_train=True) 38 | self.load_all(**kwargs) 39 | if 'epochs' in kwargs['cmd']: 40 | self.epochs = kwargs['cmd']['epochs'] 41 | 42 | 43 | class ValidationSettings(Configurable): 44 | data_loaders = State() 45 | visualize = State() 46 | interval = State(default=100) 47 | exempt = State(default=-1) 48 | 49 | def __init__(self, **kwargs): 50 | kwargs['cmd'].update(is_train=False) 51 | self.load_all(**kwargs) 52 | 53 | cmd = kwargs['cmd'] 54 | self.visualize = cmd['visualize'] 55 | 56 | 57 | class EvaluationSettings(Configurable): 58 | data_loaders = State() 59 | visualize = State(default=True) 60 | resume = State() 61 | 62 | def __init__(self, **kwargs): 63 | self.load_all(**kwargs) 64 | 65 | 66 | class EvaluationSettings2(Configurable): 67 | structure = State() 68 | data_loaders = State() 69 | 70 | def __init__(self, **kwargs): 71 | self.load_all(**kwargs) 72 | 73 | 74 | class ShowSettings(Configurable): 75 | data_loader = State() 76 | representer = State() 77 | visualizer = State() 78 | 79 | def __init__(self, **kwargs): 80 | self.load_all(**kwargs) 81 | 82 | 83 | class Experiment(Configurable): 84 | structure = State(autoload=False) 85 | train = State() 86 | validation = State(autoload=False) 87 | evaluation = State(autoload=False) 88 | logger = State(autoload=True) 89 | 90 | def __init__(self, cmd={}, **kwargs): 91 | self.load('structure', cmd=cmd, **kwargs) 92 | 93 | if 'name' not in cmd: 94 | cmd['name'] = self.structure.model_name 95 | 96 | self.load_all(cmd=cmd, **kwargs) 97 | self.distributed = cmd.get('distributed', False) 98 | self.local_rank = cmd.get('local_rank', 0) 99 | 100 | if cmd.get('validate', False): 101 | self.load('validation', cmd=cmd, **kwargs) 102 | else: 103 | self.validation = None 104 | 
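(Editor's note: concern/config.py is not part of this dump, so the snippet below is an assumed, simplified reconstruction of the `Configurable`/`State` pattern used by `experiment.py` above and driven by the YAML files that follow: class-level `State` declarations name the keys that `load_all` copies from parsed YAML `define:` entries onto the instance, with `^name` referencing another defined object.)

class State:
    def __init__(self, default=None, autoload=True):
        self.default = default
        self.autoload = autoload


class Configurable:
    def load_all(self, **kwargs):
        # copy every declared State from kwargs, falling back to its default
        for name, attr in type(self).__dict__.items():
            if isinstance(attr, State):
                setattr(self, name, kwargs.get(name, attr.default))


class TrainSettingsSketch(Configurable):
    epochs = State(default=10)


settings = TrainSettingsSketch()
settings.load_all(epochs=5)
assert settings.epochs == 5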
-------------------------------------------------------------------------------- /experiments/base.yaml: -------------------------------------------------------------------------------- 1 | import: [] 2 | package: 3 | - 'experiment' 4 | - 'structure.model' 5 | - 'training.learning_rate' 6 | - 'data' 7 | - 'data.nori_dataset' 8 | - 'data.processes' 9 | define: [] 10 | -------------------------------------------------------------------------------- /experiments/recognition/community-base.yaml: -------------------------------------------------------------------------------- 1 | package: # Packages will be automatically imported 2 | - 'data.data_loader' 3 | - 'data.file_dataset' 4 | - 'data.meta_loader' 5 | - 'data.processes' 6 | - 'concern.charsets' 7 | - 'data.unpack_msgpack_data' 8 | 9 | define: 10 | # Charsts are defined in concern/charsets.py 11 | - name: charset 12 | class: EnglishCharset 13 | case_sensitive: false 14 | 15 | # Metaloaders are defined in data/meta_loaders/meta_loader.py 16 | - name: meta_loader 17 | class: RecognitionMetaLoader 18 | cache: 19 | class: FileMetaCache 20 | force_reload: false 21 | 22 | # Training data for recognition experiments. 23 | - name: train_data 24 | class: DataLoader 25 | dataset: 26 | class: FileDataset 27 | file_paths: 28 | - '/data/text-spotter-data/syth90k/' 29 | - '/data/text-spotter-data/synth-text/cropped/' 30 | meta_loader: ^meta_loader 31 | processes: 32 | - class: MakeRecognitionLabel 33 | - class: ResizeImage 34 | image_size: [64, 256] 35 | mode: resize 36 | - class: NormalizeImage 37 | - class: FilterKeys 38 | required: ['image', 'label', 'length'] 39 | batch_size: 256 40 | 41 | - name: iiit 42 | class: DataLoader 43 | dataset: 44 | class: FileDataset 45 | file_paths: ["/data/text-spotter-data/test/iiit/"] 46 | max_size: 32 47 | num_workers: 16 48 | meta_loader: ^meta_loader 49 | charset: ^charset 50 | processes: 51 | - class: MakeRecognitionLabel 52 | - class: ResizeImage 53 | image_size: [64, 256] 54 | mode: resize 55 | - class: NormalizeImage 56 | - class: FilterKeys 57 | required: ['image', 'label'] 58 | batch_size: 16 59 | -------------------------------------------------------------------------------- /experiments/recognition/crnn-lmdb.yaml: -------------------------------------------------------------------------------- 1 | import: 2 | - 'experiments/base.yaml' 3 | - 'experiments/recognition/community-base.yaml' 4 | package: 5 | - 'structure.visualizers.sequence_recognition_visualizer' 6 | - 'concern.charsets' 7 | define: 8 | - name: train_data 9 | class: DataLoader 10 | dataset: 11 | class: LMDBDataset 12 | lmdb_paths: 13 | - "/data/text-spotter-data/synthtext/" 14 | #- "/data/text-spotter-data/iiit/" 15 | meta_loader: ^meta_loader 16 | processes: 17 | - class: MakeRecognitionLabel 18 | - class: ResizeImage 19 | image_size: [32, 128] 20 | mode: resize 21 | - class: NormalizeImage 22 | - class: FilterKeys 23 | required: ['image', 'label', 'length'] 24 | batch_size: 256 25 | 26 | - name: iiit 27 | class: DataLoader 28 | dataset: 29 | class: LMDBDataset 30 | lmdb_paths: 31 | - "/data/text-spotter-data/test/iiit/" 32 | max_size: 32 33 | num_workers: 16 34 | meta_loader: ^meta_loader 35 | charset: ^charset 36 | processes: 37 | - class: MakeRecognitionLabel 38 | - class: ResizeImage 39 | image_size: [32, 128] 40 | mode: resize 41 | - class: NormalizeImage 42 | - class: FilterKeys 43 | required: ['image', 'label'] 44 | batch_size: 16 45 | 46 | - name: BasicStructure 47 | class: Structure 48 | builder: 49 | class: Builder 50 | model: 
SequenceRecognitionModel 51 | model_args: 52 | backbone: crnn_backbone 53 | decoder: CRNNDecoder 54 | decoder_args: 55 | in_channels: 512 56 | inner_channels: 256 57 | need_reduce: False 58 | charset: ^charset 59 | representer: 60 | class: CTCRepresenter 61 | charset: ^charset 62 | measurer: 63 | class: SequenceRecognitionMeasurer 64 | 65 | - name: 'Experiment' 66 | class: Experiment 67 | structure: ^BasicStructure 68 | 69 | train: 70 | class: TrainSettings 71 | trainer_name: 'SequenceRecognizer' 72 | data_loader: ^train_data 73 | checkpoint: 74 | class: Checkpoint 75 | start_epoch: 0 76 | start_iter: 0 77 | resume: null 78 | model_saver: 79 | class: ModelSaver 80 | dir_path: "./model" 81 | save_interval: 8000 82 | signal_path: null 83 | scheduler: 84 | class: OptimizerScheduler 85 | optimizer: "Adam" 86 | learning_rate: 87 | class: MultiStepLR 88 | milestones: [3, 4, 5] 89 | gamma: 0.1 90 | lr: 0.001 91 | epochs: 5 92 | 93 | validation: &validate 94 | class: ValidationSettings 95 | data_loaders: 96 | iiit: ^iiit 97 | interval: 8000 98 | exempt: -1 99 | evaluation: *validate 100 | 101 | logger: 102 | class: Logger 103 | verbose: true 104 | level: "info" 105 | log_interval: 200 106 | -------------------------------------------------------------------------------- /experiments/recognition/crnn.yaml: -------------------------------------------------------------------------------- 1 | import: 2 | - 'experiments/base.yaml' 3 | - 'experiments/recognition/community-base.yaml' 4 | package: 5 | - 'structure.visualizers.sequence_recognition_visualizer' 6 | - 'concern.charsets' 7 | define: 8 | - name: train_data 9 | class: DataLoader 10 | dataset: 11 | class: FileDataset 12 | file_paths: 13 | - '/data/text-spotter-data/syth90k/' 14 | - '/data/text-spotter-data/synth-text/cropped/' 15 | meta_loader: ^meta_loader 16 | processes: 17 | - class: MakeRecognitionLabel 18 | - class: ResizeImage 19 | image_size: [32, 128] 20 | mode: resize 21 | - class: NormalizeImage 22 | - class: FilterKeys 23 | required: ['image', 'label', 'length'] 24 | batch_size: 256 25 | 26 | - name: iiit 27 | class: DataLoader 28 | dataset: 29 | class: FileDataset 30 | file_paths: ["/data/text-spotter-data/test/iiit/"] 31 | max_size: 32 32 | num_workers: 16 33 | meta_loader: ^meta_loader 34 | charset: ^charset 35 | processes: 36 | - class: MakeRecognitionLabel 37 | - class: ResizeImage 38 | image_size: [32, 128] 39 | mode: resize 40 | - class: NormalizeImage 41 | - class: FilterKeys 42 | required: ['image', 'label'] 43 | batch_size: 16 44 | 45 | - name: BasicStructure 46 | class: Structure 47 | builder: 48 | class: Builder 49 | model: SequenceRecognitionModel 50 | model_args: 51 | backbone: crnn_backbone 52 | decoder: CRNNDecoder 53 | decoder_args: 54 | in_channels: 512 55 | inner_channels: 256 56 | need_reduce: False 57 | charset: ^charset 58 | representer: 59 | class: CTCRepresenter 60 | charset: ^charset 61 | measurer: 62 | class: SequenceRecognitionMeasurer 63 | 64 | - name: 'Experiment' 65 | class: Experiment 66 | structure: ^BasicStructure 67 | 68 | train: 69 | class: TrainSettings 70 | trainer_name: 'SequenceRecognizer' 71 | data_loader: ^train_data 72 | checkpoint: 73 | class: Checkpoint 74 | start_epoch: 0 75 | start_iter: 0 76 | resume: null 77 | model_saver: 78 | class: ModelSaver 79 | dir_path: "./model" 80 | save_interval: 8000 81 | signal_path: null 82 | scheduler: 83 | class: OptimizerScheduler 84 | optimizer: "Adam" 85 | learning_rate: 86 | class: MultiStepLR 87 | milestones: [3, 4, 5] 88 | gamma: 0.1 89 | lr: 0.001 90 
| epochs: 5 91 | 92 | validation: &validate 93 | class: ValidationSettings 94 | data_loaders: 95 | iiit: ^iiit 96 | interval: 8000 97 | exempt: -1 98 | evaluation: *validate 99 | 100 | logger: 101 | class: Logger 102 | verbose: true 103 | level: "info" 104 | log_interval: 200 105 | -------------------------------------------------------------------------------- /experiments/recognition/fpn50-attention-decoder.yaml: -------------------------------------------------------------------------------- 1 | import: 2 | - 'experiments/base.yaml' 3 | - 'experiments/recognition/community-base.yaml' 4 | package: 5 | - 'structure.visualizers.sequence_recognition_visualizer' 6 | - 'concern.charsets' 7 | define: 8 | - name: BasicStructure 9 | class: Structure 10 | builder: 11 | class: Builder 12 | model: SequenceRecognitionModel 13 | model_args: 14 | backbone: Resnet50FPN 15 | decoder: AttentionDecoder 16 | decoder_args: 17 | in_channels: 256 18 | charset: ^charset 19 | representer: 20 | class: SequenceRecognitionRepresenter 21 | charset: ^charset 22 | measurer: 23 | class: SequenceRecognitionMeasurer 24 | 25 | - name: 'Experiment' 26 | class: Experiment 27 | structure: ^BasicStructure 28 | 29 | train: 30 | class: TrainSettings 31 | trainer_name: 'SequenceRecognizer' 32 | data_loader: ^train_data 33 | checkpoint: 34 | class: Checkpoint 35 | start_epoch: 0 36 | start_iter: 0 37 | resume: null 38 | model_saver: 39 | class: ModelSaver 40 | dir_path: "./model" 41 | save_interval: 8000 42 | signal_path: null 43 | scheduler: 44 | class: OptimizerScheduler 45 | optimizer: "Adam" 46 | learning_rate: 47 | class: MultiStepLR 48 | milestones: [3, 4, 5] 49 | gamma: 0.1 50 | lr: 0.001 51 | epochs: 5 52 | 53 | validation: &validate 54 | class: ValidationSettings 55 | data_loaders: 56 | iiit: ^iiit 57 | interval: 8000 58 | exempt: 0 59 | evaluation: *validate 60 | 61 | logger: 62 | class: Logger 63 | verbose: true 64 | level: "info" 65 | log_interval: 200 66 | -------------------------------------------------------------------------------- /experiments/recognition/res50-ppm-2d-ctc.yaml: -------------------------------------------------------------------------------- 1 | import: 2 | - 'experiments/base.yaml' 3 | - 'experiments/recognition/community-base.yaml' 4 | package: 5 | - 'structure.visualizers.sequence_recognition_visualizer' 6 | - 'concern.charsets' 7 | define: 8 | - name: BasicStructure 9 | class: Structure 10 | builder: 11 | class: Builder 12 | model: SequenceRecognitionModel 13 | model_args: 14 | backbone: resnet50dilated_ppm 15 | decoder: CTCDecoder2D 16 | decoder_args: 17 | in_channels: 256 18 | charset: ^charset 19 | representer: 20 | class: CTCRepresenter2D 21 | charset: ^charset 22 | measurer: 23 | class: SequenceRecognitionMeasurer 24 | visualizer: 25 | class: CTCVisualizer2D 26 | 27 | - name: 'Experiment' 28 | class: Experiment 29 | structure: ^BasicStructure 30 | 31 | train: 32 | class: TrainSettings 33 | trainer_name: 'SequenceRecognizer' 34 | data_loader: ^train_data 35 | checkpoint: 36 | class: Checkpoint 37 | start_epoch: 0 38 | start_iter: 0 39 | resume: null 40 | model_saver: 41 | class: ModelSaver 42 | dir_path: "./model" 43 | save_interval: 8000 44 | signal_path: null 45 | scheduler: 46 | class: OptimizerScheduler 47 | optimizer: "Adam" 48 | learning_rate: 49 | class: MultiStepLR 50 | milestones: [3, 4, 5] 51 | gamma: 0.1 52 | lr: 0.001 53 | epochs: 5 54 | 55 | validation: &validate 56 | class: ValidationSettings 57 | data_loaders: 58 | iiit: ^iiit 59 | interval: 8000 60 | exempt: -1 61 | 
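# --- Editor's note (comments only, not part of this config): `&validate`
# above declares a YAML anchor on the ValidationSettings mapping, and
# `evaluation: *validate` just below reuses that same node via an alias, so
# validation and evaluation share one definition. The caret form (`^iiit`,
# `^charset`, `^meta_loader`) is this codebase's own reference to a
# `define:`d object, resolved by its config loader rather than by YAML.
# Minimal standalone anchor/alias example:
#   defaults: &base {lr: 0.001, gamma: 0.1}
#   train: *base   # identical mapping reused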
evaluation: *validate 62 | 63 | logger: 64 | class: Logger 65 | verbose: true 66 | level: "info" 67 | log_interval: 200 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /experiments/seg_detector/community-base.yaml: -------------------------------------------------------------------------------- 1 | import: 2 | - 'experiments/base.yaml' 3 | package: 4 | - 'decoders.seg_detector_loss' 5 | - 'data.meta_loader' 6 | - 'data.unpack_msgpack_data' 7 | define: 8 | - name: meta_loader 9 | class: DetectionMetaLoader 10 | cache: 11 | class: FileMetaCache 12 | force_reload: false 13 | 14 | # id meta loader is recommended to save memory 15 | - name: id_meta_loader 16 | class: DataIdMetaLoader 17 | force_reload: false 18 | cache: 19 | class: FileMetaCache 20 | return_dict: true 21 | 22 | - name: train_data 23 | class: LMDBDataset 24 | meta_loader: ^id_meta_loader 25 | lmdb_paths: 26 | - '/data/text-spotter-data/MLT-2017/train' 27 | unpack: 28 | class: TransformMsgpackData 29 | processes: 30 | - class: AugmentDetectionData 31 | augmenter_args: 32 | - ['Fliplr', 0.5] 33 | - {'cls': 'Affine', 'rotate': [-10, 10]} 34 | - ['Resize', [0.5, 3.0]] 35 | - class: RandomCropData 36 | size: [640, 640] 37 | max_tries: 10 38 | - class: MakeICDARData 39 | - class: MakeSegDetectionData 40 | - class: MakeBorderMap 41 | - class: NormalizeImage 42 | - class: FilterKeys 43 | superfluous: ['polygons', 'filename', 'shape', 'ignore_tags'] 44 | 45 | - name: validate_data 46 | class: LMDBDataset 47 | lmdb_paths: 48 | - '/data/text-spotter-data/MLT-2017/validate' 49 | meta_loader: ^id_meta_loader 50 | unpack: 51 | class: TransformMsgpackData 52 | processes: 53 | - class: AugmentDetectionData 54 | augmenter_args: 55 | - ['Resize', {'width': 1024, 'height': 576}] 56 | - class: MakeICDARData 57 | - class: MakeSegDetectionData 58 | - class: MakeBorderMap 59 | - class: NormalizeImage -------------------------------------------------------------------------------- /experiments/seg_detector/seg_detector_db.yaml: -------------------------------------------------------------------------------- 1 | import: 2 | - 'experiments/seg_detector/community-base.yaml' 3 | package: [] 4 | define: 5 | - name: train_data 6 | class: LMDBDataset 7 | meta_loader: ^id_meta_loader 8 | lmdb_paths: 9 | - '/data/text-spotter-data/MLT-2017/train' 10 | unpack: 11 | class: TransformMsgpackData 12 | processes: 13 | - class: AugmentDetectionData 14 | augmenter_args: 15 | - ['Fliplr', 0.5] 16 | - {'cls': 'Affine', 'rotate': [-10, 10]} 17 | - ['Resize', [0.5, 3.0]] 18 | - class: RandomCropData 19 | size: [640, 640] 20 | max_tries: 10 21 | - class: MakeICDARData 22 | - class: MakeSegDetectionData 23 | - class: MakeBorderMap 24 | - class: NormalizeImage 25 | - class: FilterKeys 26 | superfluous: ['polygons', 'filename', 'shape', 'ignore_tags'] 27 | 28 | - name: validate_data 29 | class: LMDBDataset 30 | lmdb_paths: 31 | - '/data/text-spotter-data/MLT-2017/validate' 32 | meta_loader: ^id_meta_loader 33 | unpack: 34 | class: TransformMsgpackData 35 | processes: 36 | - class: AugmentDetectionData 37 | augmenter_args: 38 | - ['Resize', {'width': 1024, 'height': 576}] 39 | - class: MakeICDARData 40 | - class: MakeSegDetectionData 41 | - class: MakeBorderMap 42 | - class: NormalizeImage 43 | 44 | 45 | - name: 'Experiment' 46 | class: Experiment 47 | structure: 48 | class: Structure 49 | builder: 50 | class: Builder 51 | model: SegDetectorModel 52 | model_args: 53 | backbone: deformable_resnet50 54 | decoder: 
SegDetector 55 | decoder_args: 56 | adaptive: True 57 | in_channels: [256, 512, 1024, 2048] 58 | k: 50 59 | loss_class: L1BalanceCELoss 60 | 61 | representer: 62 | class: SegDetectorRepresenter 63 | max_candidates: 1000 64 | measurer: 65 | class: QuadMeasurer 66 | visualizer: 67 | class: SegDetectorVisualizer 68 | train: 69 | class: TrainSettings 70 | data_loader: 71 | class: DataLoader 72 | dataset: ^train_data 73 | batch_size: 16 74 | num_workers: 8 75 | checkpoint: 76 | class: Checkpoint 77 | start_epoch: 0 78 | start_iter: 0 79 | resume: null 80 | model_saver: 81 | class: ModelSaver 82 | dir_path: model 83 | save_interval: 1800 84 | signal_path: save 85 | scheduler: 86 | class: OptimizerScheduler 87 | optimizer: "SGD" 88 | optimizer_args: 89 | lr: 0.007 90 | momentum: 0.9 91 | weight_decay: 0.0001 92 | learning_rate: 93 | class: DecayLearningRate 94 | epochs: 400 95 | epochs: 400 96 | 97 | validation: &validate 98 | class: ValidationSettings 99 | data_loaders: 100 | MLT-17: 101 | class: DataLoader 102 | dataset: ^validate_data 103 | batch_size: 4 104 | num_workers: 8 105 | collect_fn: 106 | class: ICDARCollectFN 107 | visualize: false 108 | interval: 1800 109 | exempt: -1 110 | 111 | logger: 112 | class: Logger 113 | verbose: true 114 | level: info 115 | log_interval: 450 116 | 117 | evaluation: *validate 118 | -------------------------------------------------------------------------------- /ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .ctc_2d.ctc_loss_2d import CTCLoss2DFunction, ctc_loss_2d 2 | -------------------------------------------------------------------------------- /ops/ctc_2d/csrc/ctc2d.cpp: -------------------------------------------------------------------------------- 1 | #include "ctc2d.h" 2 | 3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 4 | m.def("ctc2d_forward", &ctc2d_forward, "ctc2d_forward (cuda)"); 5 | m.def("ctc2d_backward", &ctc2d_backward, "ctc2d_backward (cuda)"); 6 | } -------------------------------------------------------------------------------- /ops/ctc_2d/csrc/ctc2d.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef WITH_CUDA 4 | #include "cuda/ctc2d.h" 5 | #endif 6 | 7 | std::tuple ctc2d_forward( 8 | at::Tensor log_probs, at::Tensor targets, at::Tensor input_lengths, at::Tensor target_lengths, 9 | int64_t BLANK, float TINY 10 | ) { 11 | if (log_probs.type().is_cuda()) { 12 | #ifdef WITH_CUDA 13 | return ctc2d_cuda_forward( 14 | log_probs, targets, input_lengths, target_lengths, BLANK, TINY 15 | ); 16 | #else 17 | AT_ERROR("Not compiled with GPU support"); 18 | #endif 19 | } 20 | AT_ERROR("Not implemented on the CPU"); 21 | } 22 | 23 | 24 | at::Tensor ctc2d_backward( 25 | at::Tensor grad_out, 26 | at::Tensor log_probs, at::Tensor targets, at::Tensor input_lengths, at::Tensor target_lengths, 27 | at::Tensor neg_log_likelihood, at::Tensor log_alpha, 28 | int64_t BLANK 29 | ) { 30 | if (log_probs.type().is_cuda()) { 31 | #ifdef WITH_CUDA 32 | return ctc2d_cuda_backward( 33 | grad_out, 34 | log_probs, targets, input_lengths, target_lengths, 35 | neg_log_likelihood, log_alpha, 36 | BLANK 37 | ); 38 | #else 39 | AT_ERROR("Not compiled with GPU support"); 40 | #endif 41 | } 42 | AT_ERROR("Not implemented on the CPU"); 43 | } -------------------------------------------------------------------------------- /ops/ctc_2d/csrc/cuda/ctc2d.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 
#include 3 | 4 | 5 | std::tuple ctc2d_cuda_forward( 6 | const at::Tensor log_probs, const at::Tensor targets, 7 | const at::Tensor input_lengths, const at::Tensor target_lengths, 8 | int64_t BLANK, float TINY 9 | ); 10 | 11 | 12 | at::Tensor ctc2d_cuda_backward( 13 | const at::Tensor grad_out, 14 | const at::Tensor log_probs, const at::Tensor targets, 15 | const at::Tensor input_lengths, const at::Tensor target_lengths, 16 | const at::Tensor neg_log_likelihood, const at::Tensor log_alpha, 17 | int64_t BLANK 18 | ); -------------------------------------------------------------------------------- /ops/ctc_2d/csrc/cuda/ctc2d_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | 13 | void print_tsize(at::Tensor t, const char *msg); 14 | 15 | 16 | std::tuple ctc2d_gpu_template( 17 | at::Tensor log_probs, at::Tensor targets, 18 | at::Tensor input_lengths, at::Tensor target_lengths, int64_t BLANK, float TINY 19 | ); 20 | 21 | 22 | at::Tensor ctc2d_gpu_backward_template( 23 | const at::Tensor grad_out, const at::Tensor log_probs, const at::Tensor targets, 24 | const at::Tensor input_lengths, const at::Tensor target_lengths, 25 | const at::Tensor neg_log_likelihood, const at::Tensor log_alpha, 26 | int64_t BLANK 27 | ); 28 | 29 | 30 | std::tuple ctc2d_cuda_forward( 31 | const at::Tensor log_probs, const at::Tensor targets, 32 | const at::Tensor input_lengths, const at::Tensor target_lengths, 33 | int64_t BLANK, float TINY 34 | ) { 35 | AT_CHECK(log_probs.is_contiguous(), "log_probs tensor has to be contiguous"); 36 | 37 | // shape check 38 | int64_t batch_size = log_probs.size(2); 39 | int64_t num_labels = log_probs.size(3); 40 | AT_CHECK((0 <= BLANK) && (BLANK < num_labels), "blank must be in label range"); 41 | AT_CHECK(input_lengths.size(0) == batch_size, "input_lengths must be of size batch_size"); 42 | AT_CHECK(target_lengths.size(0) == batch_size, "target_lengths must be of size batch_size"); 43 | 44 | return ctc2d_gpu_template(log_probs, targets, input_lengths, target_lengths, BLANK, TINY); 45 | } 46 | 47 | 48 | at::Tensor ctc2d_cuda_backward( 49 | const at::Tensor grad_out, const at::Tensor log_probs, const at::Tensor targets, 50 | const at::Tensor input_lengths, const at::Tensor target_lengths, 51 | const at::Tensor neg_log_likelihood, const at::Tensor log_alpha, 52 | int64_t BLANK 53 | ) { 54 | return ctc2d_gpu_backward_template( 55 | grad_out, 56 | log_probs, targets, input_lengths, target_lengths, 57 | neg_log_likelihood, log_alpha, 58 | BLANK 59 | ); 60 | } -------------------------------------------------------------------------------- /ops/ctc_2d/ctc_loss_2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | from . 
import ctc_2d_csrc 5 | 6 | 7 | class CTCLoss2DFunction(Function): 8 | 9 | @staticmethod 10 | def forward(ctx, log_probs, targets, input_lengths, target_lengths, blank=0): 11 | ctx.blank = blank 12 | if not log_probs.is_cuda: 13 | raise NotImplementedError 14 | 15 | neg_log_likelihood, log_alpha = ctc_2d_csrc.ctc2d_forward( 16 | log_probs, targets, input_lengths, target_lengths, blank, torch.finfo().tiny) 17 | 18 | if log_probs.requires_grad: 19 | ctx.save_for_backward(log_probs, targets, input_lengths, 20 | target_lengths, neg_log_likelihood, log_alpha) 21 | return neg_log_likelihood 22 | 23 | @staticmethod 24 | def backward(ctx, grad_output): 25 | log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha = ctx.saved_tensors 26 | 27 | grad_log_probs = torch.ones(2, 3) 28 | if ctx.needs_input_grad[0]: 29 | grad_log_probs = ctc_2d_csrc.ctc2d_backward( 30 | grad_output, log_probs, targets, input_lengths, target_lengths, 31 | neg_log_likelihood, log_alpha, 32 | ctx.blank 33 | ) 34 | return grad_log_probs, None, None, None, None 35 | 36 | 37 | ctc_loss_2d = CTCLoss2DFunction.apply 38 | -------------------------------------------------------------------------------- /ops/ctc_2d/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import glob 4 | import os 5 | 6 | import torch 7 | from setuptools import find_packages 8 | from setuptools import setup 9 | from torch.utils.cpp_extension import CUDA_HOME 10 | from torch.utils.cpp_extension import CppExtension 11 | from torch.utils.cpp_extension import CUDAExtension 12 | 13 | requirements = ["torch", "torchvision"] 14 | 15 | 16 | op_name = 'ctc_2d' 17 | def get_extensions(): 18 | this_dir = os.path.dirname(os.path.abspath(__file__)) 19 | ext_modules = [] 20 | extensions_dir = os.path.join(this_dir, "csrc") 21 | 22 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 23 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 24 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 25 | 26 | sources = main_file + source_cpu 27 | extension = CppExtension 28 | 29 | extra_compile_args = {"cxx": []} 30 | define_macros = [] 31 | 32 | if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1": 33 | extension = CUDAExtension 34 | sources += source_cuda 35 | define_macros += [("WITH_CUDA", None)] 36 | extra_compile_args["nvcc"] = [ 37 | "-DCUDA_HAS_FP16=1", 38 | "-D__CUDA_NO_HALF_OPERATORS__", 39 | "-D__CUDA_NO_HALF_CONVERSIONS__", 40 | "-D__CUDA_NO_HALF2_OPERATORS__", 41 | ] 42 | 43 | sources = [os.path.join(extensions_dir, s) for s in sources] 44 | include_dirs = [extensions_dir] 45 | ext_modules.append( 46 | extension( 47 | op_name + "_csrc", 48 | sources, 49 | include_dirs=include_dirs, 50 | define_macros=define_macros, 51 | extra_compile_args=extra_compile_args, 52 | ) 53 | ) 54 | return ext_modules 55 | 56 | 57 | setup( 58 | name=op_name, 59 | version="0.0", 60 | packages=find_packages(), 61 | ext_modules=get_extensions(), 62 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 63 | ) 64 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | Pillow 2 | msgpack 3 | numpy 4 | opencv-python==3.4.5.20 5 | shapely 6 | imgaug 7 | PyYAML 8 | torch 9 | torchvision 10 | tensorboardX 11 | hanziconv 12 | pyclipper 13 | gevent 14 | gevent-websocket 15 
| anyconfig 16 | munch 17 | sortedcontainers 18 | ipdb 19 | editdistance 20 | tqdm 21 | lmdb 22 | redis 23 | -e git+https://github.com/NVIDIA/apex.git#egg=apex-master 24 | -------------------------------------------------------------------------------- /scripts/json_to_lmdb.py: -------------------------------------------------------------------------------- 1 | import lmdb 2 | import json 3 | from fire import Fire 4 | from collections import defaultdict 5 | import os 6 | import pickle 7 | from tqdm import tqdm 8 | 9 | 10 | def main(json_path=None, lmdb_path=None): 11 | assert json_path is not None, 'json_path is needed' 12 | if lmdb_path is None: 13 | lmdb_path = json_path 14 | 15 | meta = os.path.join(json_path, 'meta.json') 16 | data_ids = [] 17 | value = {} 18 | env = lmdb.Environment(lmdb_path, subdir=True, 19 | map_size=int(1e9), max_dbs=2, lock=False) 20 | db_extra = env.open_db('extra'.encode(), create=True) 21 | db_image = env.open_db('image'.encode(), create=True) 22 | with open(meta, 'r') as meta_reader: 23 | for line in tqdm(meta_reader): 24 | single_meta = json.loads(line) 25 | data_id = os.path.join(json_path, single_meta['filename']) 26 | data_id = str(data_id.encode('utf-8').decode('utf-8')) 27 | with open(data_id.encode(), 'rb') as file_reader: 28 | image = file_reader.read() 29 | value['extra'] = {} 30 | for key in single_meta['extra']: 31 | value['extra'][key] = single_meta['extra'][key] 32 | with env.begin(write=True) as lmdb_writer: 33 | lmdb_writer.put(data_id.encode(), 34 | pickle.dumps(value), db=db_extra) 35 | with env.begin(write=True) as image_writer: 36 | image_writer.put(data_id.encode(), image, db=db_image) 37 | env.close() 38 | 39 | 40 | if __name__ == "__main__": 41 | Fire(main) 42 | -------------------------------------------------------------------------------- /scripts/nori_to_lmdb.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import lmdb 3 | import nori2 as nori 4 | from tqdm import tqdm 5 | from fire import Fire 6 | 7 | 8 | def main(nori_path, lmdb_path=None): 9 | if lmdb_path is None: 10 | lmdb_path = nori_path 11 | env = lmdb.Environment(lmdb_path, map_size=int( 12 | 5e10), writemap=True, max_dbs=2, lock=False) 13 | fetcher = nori.Fetcher(nori_path) 14 | db_extra = env.open_db('extra'.encode(), create=True) 15 | db_image = env.open_db('image'.encode(), create=True) 16 | with nori.open(nori_path, 'r') as nr: 17 | with env.begin(write=True) as writer: 18 | for data_id, data, meta in tqdm(nr.scan()): 19 | value = {} 20 | image = fetcher.get(data_id) 21 | value['extra'] = {} 22 | for key in meta['extra']: 23 | value['extra'][key] = meta['extra'][key] 24 | writer.put(data_id.encode(), pickle.dumps(value), db=db_extra) 25 | writer.put(data_id.encode(), image, db=db_image) 26 | env.close() 27 | print('Finished') 28 | 29 | 30 | if __name__ == '__main__': 31 | Fire(main) 32 | -------------------------------------------------------------------------------- /structure/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Megvii-CSG/MegReader/4cb5635d109f81f020780f39de9fca29a44c9f82/structure/__init__.py -------------------------------------------------------------------------------- /structure/builder.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | import structure.model 6 | from structure.ensemble_model import EnsembleModel 7 
| from concern.config import Configurable, State 8 | 9 | 10 | class Builder(Configurable): 11 | model = State() 12 | model_args = State() 13 | 14 | def __init__(self, cmd={}, **kwargs): 15 | self.load_all(**kwargs) 16 | if 'backbone' in cmd: 17 | self.model_args['backbone'] = cmd['backbone'] 18 | 19 | @property 20 | def model_name(self): 21 | return self.model + '-' + getattr(structure.model, self.model).model_name(self.model_args) 22 | 23 | def build(self, device, distributed=False, local_rank: int = 0): 24 | Model = getattr(structure.model,self.model) 25 | model = Model(self.model_args, device, 26 | distributed=distributed, local_rank=local_rank) 27 | return model 28 | 29 | 30 | class EnsembleBuilder(Configurable): 31 | '''Ensemble multiple models into one model 32 | Input: 33 | builders: A dict which consists of several builders. 34 | Example: 35 | >>> builder: 36 | class: EnsembleBuilder 37 | builders: 38 | ctc: 39 | model: CTCModel 40 | atten: 41 | model: AttentionDecoderModel 42 | ''' 43 | builders = State(default={}) 44 | 45 | def __init__(self, cmd={}, **kwargs): 46 | resume_paths = dict() 47 | for key, value_dict in kwargs['builders'].items(): 48 | resume_paths[key] = value_dict.pop('resume') 49 | self.resume_paths = resume_paths 50 | self.load_all(**kwargs) 51 | 52 | @property 53 | def model_name(self): 54 | return 'ensembled-model' 55 | 56 | def build(self, device, *args, **kwargs): 57 | models = OrderedDict() 58 | for key, builder in self.builders.items(): 59 | models[key] = builder.build(device=device, *args, **kwargs) 60 | models[key].load_state_dict(torch.load( 61 | self.resume_paths[key], map_location=device), 62 | strict=True) 63 | return EnsembleModel(models) 64 | -------------------------------------------------------------------------------- /structure/ensemble_model.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import pickle 4 | 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from backbones import * 11 | from decoders import * 12 | 13 | 14 | class EnsembleModel(nn.Module): 15 | def __init__(self, models, *args, **kwargs): 16 | super(EnsembleModel, self).__init__() 17 | self.models = nn.ModuleDict(models) 18 | 19 | def forward(self, batch, select_key=None, training=False): 20 | pred = dict() 21 | 22 | for key, module in self.models.items(): 23 | if select_key is not None and key != select_key: 24 | continue 25 | pred[key] = module(batch, training) 26 | return pred 27 | -------------------------------------------------------------------------------- /structure/measurers/__init__.py: -------------------------------------------------------------------------------- 1 | from .textsnake import TextsnakeMeasurer 2 | from .classification_measurer import ClassificationMeasurer 3 | from .sequence_recognition_measurer import SequenceRecognitionMeasurer 4 | from .icdar_detection_measurer import ICDARDetectionMeasurer 5 | from .quad_measurer import QuadMeasurer 6 | from .grid_sampling_measurer import GridSamplingMeasurer 7 | -------------------------------------------------------------------------------- /structure/measurers/classification_measurer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from concern import AverageMeter 4 | from concern.config import Configurable 5 | 6 | 7 | class ClassificationMeasurer(Configurable): 8 | def __init__(self, **kwargs): 9 | pass 10 | 11 | def measure(self, batch, output): 12 | 
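# --- Illustration (not from the repo): the ensembling pattern from
# ensemble_model.py above. nn.ModuleDict registers every sub-model's
# parameters under its key, and a `select_key` can restrict the forward
# pass to one member. Toy modules stand in for real recognition models.
import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, batch):
        return batch.sum()

ensemble = nn.ModuleDict({'ctc': Toy(), 'atten': Toy()})
batch = torch.ones(2, 3)
pred = {key: m(batch) for key, m in ensemble.items() if key == 'ctc'}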
correct = torch.eq(output.cpu(), batch[1]).numpy() 13 | return correct 14 | 15 | def validate_measure(self, batch, output): 16 | return self.measure(batch, output), [0] 17 | 18 | def gather_measure(self, raw_metrics, logger=None): 19 | accuracy_meter = AverageMeter() 20 | for raw_metric in raw_metrics: 21 | total = raw_metric.shape[0] 22 | accuracy = raw_metric.sum() / total 23 | accuracy_meter.update(accuracy, total) 24 | 25 | return { 26 | 'accuracy': accuracy_meter, 27 | } 28 | -------------------------------------------------------------------------------- /structure/measurers/grid_sampling_measurer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from concern import AverageMeter 4 | from concern.config import Configurable 5 | 6 | 7 | class GridSamplingMeasurer(Configurable): 8 | def __init__(self, **kwargs): 9 | pass 10 | 11 | def measure(self, batch, output): 12 | return 0 13 | 14 | def validate_measure(self, batch, output): 15 | return 1, [0] 16 | 17 | def gather_measure(self, raw_metrics, logger=None): 18 | return { 19 | 'accuracy': 0, 20 | } 21 | -------------------------------------------------------------------------------- /structure/measurers/icdar_detection_measurer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import shutil 4 | 5 | import numpy as np 6 | import json 7 | 8 | from concern import Logger, AverageMeter 9 | from concern.config import Configurable 10 | 11 | 12 | class ICDARDetectionMeasurer(Configurable): 13 | def __init__(self, **kwargs): 14 | self.visualized = False 15 | 16 | def measure(self, batch, output): 17 | pairs = [] 18 | for i in range(len(batch[-1])): 19 | pairs.append((batch[-1][i], output[i][0])) 20 | return pairs 21 | 22 | def validate_measure(self, batch, output): 23 | return self.measure(batch, output), [int(self.visualized)] 24 | 25 | def evaluate_measure(self, batch, output): 26 | return self.measure(batch, output), np.linspace(0, batch[0].shape[0]).tolist() 27 | 28 | def gather_measure(self, name, raw_metrics, logger: Logger): 29 | save_dir = os.path.join(logger.log_dir, name) 30 | shutil.rmtree(save_dir, ignore_errors=True) 31 | if not os.path.exists(save_dir): 32 | os.makedirs(save_dir) 33 | log_file_path = os.path.join(save_dir, name + '.log') 34 | count = 0 35 | for batch_pairs in raw_metrics: 36 | for _filename, boxes in batch_pairs: 37 | boxes = np.array(boxes).reshape(-1, 8).astype(np.int32) 38 | filename = 'res_' + _filename.replace('.jpg', '.txt') 39 | with open(os.path.join(save_dir, filename), 'wt') as f: 40 | if len(boxes) == 0: 41 | f.write('') 42 | for box in boxes: 43 | f.write(','.join(map(str, box)) + '\n') 44 | count += 1 45 | 46 | self.packing(save_dir) 47 | try: 48 | raw_out = subprocess.check_output(['python assets/ic15_eval/script.py -m=' + name 49 | + ' -g=assets/ic15_eval/gt.zip -s=' + 50 | os.path.join(save_dir, 'submit.zip') + 51 | '|tee -a ' + log_file_path], 52 | timeout=30, shell=True) 53 | except subprocess.TimeoutExpired: 54 | return {} 55 | raw_out = raw_out.decode().replace('Calculated!', '') 56 | dict_out = json.loads(raw_out) 57 | return {k: AverageMeter().update(v, n=count) for k, v in dict_out.items()} 58 | 59 | def packing(self, save_dir): 60 | pack_name = 'submit.zip' 61 | os.system( 62 | 'zip -r -j -q ' + 63 | os.path.join(save_dir, pack_name) + ' ' + save_dir + '/*.txt') 64 | -------------------------------------------------------------------------------- 
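(Editor's note: concern/average_meter.py is not included in this dump; the sketch below is an assumed minimal reconstruction of the AverageMeter these measurers rely on — a running weighted average whose `update` returns the meter, matching the chained `AverageMeter().update(v, n=count)` usage above.)

class AverageMeterSketch:
    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0, 0, 0, 0.

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        return self  # allows chaining as in the ICDAR measurer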
/structure/measurers/quad_measurer.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | from concern import Logger, AverageMeter
4 | from concern.config import Configurable
5 | from concern.icdar2015_eval.detection.iou import DetectionIoUEvaluator
6 | 
7 | 
8 | class QuadMeasurer(Configurable):
9 |     def __init__(self, **kwargs):
10 |         self.evaluator = DetectionIoUEvaluator()
11 | 
12 |     def measure(self, batch, output):
13 |         '''
14 |         batch: a dict produced by dataloaders.
15 |             image: tensor of shape (N, C, H, W).
16 |             polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions.
17 |             ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not.
18 |             shape: the original shape of images.
19 |             filename: the original filenames of images.
20 |         output: (polygons, ...)
21 |         '''
22 |         results = []
23 |         gt_polygons_batch = batch['polygons']
24 |         ignore_tags_batch = batch['ignore_tags']
25 |         pred_polygons_batch = np.array(output[0])
26 |         for polygons, pred_polygons, ignore_tags in\
27 |                 zip(gt_polygons_batch, pred_polygons_batch, ignore_tags_batch):
28 |             gt = [dict(points=polygons[i], ignore=ignore_tags[i])
29 |                   for i in range(len(polygons))]
30 |             pred = [dict(points=pred_polygons[i])
31 |                     for i in range(len(pred_polygons))]
32 |             results.append(self.evaluator.evaluate_image(gt, pred))
33 |         return results
34 | 
35 |     def validate_measure(self, batch, output):
36 |         return self.measure(batch, output), [0]
37 | 
38 |     def evaluate_measure(self, batch, output):
39 |         return self.measure(batch, output),\
40 |             np.linspace(0, batch['image'].shape[0]).tolist()
41 | 
42 |     def gather_measure(self, raw_metrics, logger: Logger):
43 |         raw_metrics = [image_metrics
44 |                        for batch_metrics in raw_metrics
45 |                        for image_metrics in batch_metrics]
46 | 
47 |         result = self.evaluator.combine_results(raw_metrics)
48 | 
49 |         precision = AverageMeter()
50 |         recall = AverageMeter()
51 |         fmeasure = AverageMeter()
52 | 
53 |         precision.update(result['precision'], n=len(raw_metrics))
54 |         recall.update(result['recall'], n=len(raw_metrics))
55 |         fmeasure_score = 2 * precision.val * recall.val /\
56 |             (precision.val + recall.val + 1e-8)
57 |         fmeasure.update(fmeasure_score)
58 | 
59 |         return {
60 |             'precision': precision,
61 |             'recall': recall,
62 |             'fmeasure': fmeasure
63 |         }
-------------------------------------------------------------------------------- /structure/measurers/textsnake.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | from concern import AverageMeter
4 | from concern.config import Configurable
5 | from concern.icdar2015_eval.detection.iou import DetectionIoUEvaluator
6 | 
7 | 
8 | class TextsnakeMeasurer(Configurable):
9 |     def __init__(self, **kwargs):
10 |         self.evaluator = DetectionIoUEvaluator()
11 | 
12 |     def validate_measure(self, batch, output):
13 |         return self.measure(batch, output), [0]
14 | 
15 |     def evaluate_measure(self, batch, output):
16 |         return self.measure(batch, output), np.linspace(0, batch[0].shape[0]).tolist()
17 | 
18 |     def measure(self, batch, output):
19 |         batch_meta = output['meta']
20 |         batch_gt_polys = output['polygons_gt']
21 |         batch_pred_polys = output['contours_pred']
22 | 
23 |         results = []
24 |         for gt_polys, pred_polys, meta in zip(batch_gt_polys, batch_pred_polys, batch_meta):
25 |             gt = [{'points': points, 'ignore': not cares} for points, cares in zip(gt_polys, meta['cares'])]
26 |             pred = [{'points':
points} for points in pred_polys] 27 | result = self.evaluator.evaluate_image(gt, pred) 28 | results.append(result) 29 | 30 | return results 31 | 32 | def gather_measure(self, raw_metrics, logger=None): 33 | raw_metrics = [image_metrics for batch_metrics in raw_metrics for image_metrics in batch_metrics] 34 | 35 | result = self.evaluator.combine_results(raw_metrics) 36 | 37 | precision = AverageMeter() 38 | recall = AverageMeter() 39 | 40 | precision.update(result['precision'], n=len(raw_metrics)) 41 | recall.update(result['recall'], n=len(raw_metrics)) 42 | 43 | return { 44 | 'precision': precision, 45 | 'recall': recall, 46 | } 47 | -------------------------------------------------------------------------------- /structure/models/maskrcnn_benchmark/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Megvii-CSG/MegReader/4cb5635d109f81f020780f39de9fca29a44c9f82/structure/models/maskrcnn_benchmark/__init__.py -------------------------------------------------------------------------------- /structure/representers/__init__.py: -------------------------------------------------------------------------------- 1 | from .textsnake import TextsnakeRepresenter 2 | from .sequence_recognition_representer import SequenceRecognitionRepresenter 3 | from .classification_representer import ClassificationRepresenter 4 | from .ctc_representer import CTCRepresenter 5 | from .ctc_representer2d import CTCRepresenter2D 6 | from .seg_recognition_representer import SegRecognitionRepresenter 7 | from .integral_regression_representer import IntegralRegressionRepresenter 8 | from .seg_detector_representer import SegDetectorRepresenter -------------------------------------------------------------------------------- /structure/representers/classification_representer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from concern.config import Configurable 3 | 4 | 5 | class ClassificationRepresenter(Configurable): 6 | def __init__(self, **kwargs): 7 | pass 8 | 9 | def represent(self, batch, pred): 10 | return torch.argmax(pred, dim=1) 11 | -------------------------------------------------------------------------------- /structure/representers/ctc_representer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import config 4 | from .sequence_recognition_representer import SequenceRecognitionRepresenter 5 | from concern.config import State 6 | 7 | 8 | class CTCRepresenter(SequenceRecognitionRepresenter): 9 | def represent(self, batch, pred): 10 | ''' 11 | decode ctc using greedy search 12 | pred: (N, C, 1, W) 13 | return: 14 | output: { 15 | 'label_string': string of gt label, 16 | 'pred_string': string of prediction 17 | } 18 | ''' 19 | labels = batch['label'] 20 | pred = torch.argmax(pred, dim=1)  # (N, 1, W) 21 | pred = pred.select(1, 0) # N, W 22 | output = torch.zeros( 23 | pred.shape[0], pred.shape[-1], dtype=torch.int) + self.charset.blank 24 | for i in range(pred.shape[0]): 25 | valid = 0 26 | previous = self.charset.blank 27 | for j in range(pred.shape[1]): 28 | c = pred[i][j] 29 | if c == previous or c == self.charset.unknown: 30 | continue 31 | if not c == self.charset.blank: 32 | output[i][valid] = c 33 | valid += 1 34 | previous = c 35 | 36 | result = [] 37 | for i in range(labels.shape[0]): 38 | label_str = self.label_to_string(labels[i]) 39 | pred_str = self.label_to_string(output[i]) 40 | result.append({ 41 | 'label_string': 
label_str, 42 | 'pred_string': pred_str 43 | }) 44 | 45 | return result 46 | -------------------------------------------------------------------------------- /structure/representers/ctc_representer2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from concern.config import State 4 | from .sequence_recognition_representer import SequenceRecognitionRepresenter 5 | 6 | 7 | class CTCRepresenter2D(SequenceRecognitionRepresenter): 8 | max_size = State(default=32) 9 | 10 | def represent(self, batch, pred): 11 | ''' 12 | This class is supposed to be used with 13 | the measurer `SequenceRecognitionMeasurer` 14 | ''' 15 | classify, mask = pred 16 | labels = batch['label'] 17 | 18 | ''' 19 | classify: (N, C, H, W) 20 | mask: (N, 1, H, W) 21 | return: 22 | output: { 23 | 'label_string': string of gt label, 24 | 'pred_string': string of prediction 25 | } 26 | ''' 27 | heatmap = classify * mask 28 | classify = classify.to('cpu') 29 | mask = mask.to('cpu') 30 | paths = heatmap.max(1, keepdim=True)[0].argmax( 31 | 2, keepdim=True) # (N, 1, 1, W) 32 | C = classify.size(1) 33 | paths = paths.repeat(1, C, 1, 1) # (N, C, 1, W) 34 | selected_probabilities = heatmap.gather(2, paths) # (N, C, W) 35 | pred = selected_probabilities.argmax(1).squeeze(1) # (N, W) 36 | output = torch.zeros( 37 | pred.shape[0], pred.shape[-1], dtype=torch.int) + self.charset.blank 38 | pred = pred.to('cpu') 39 | output = output.to('cpu') 40 | 41 | for i in range(pred.shape[0]): 42 | valid = 0 43 | previous = self.charset.blank 44 | for j in range(pred.shape[1]): 45 | c = pred[i][j] 46 | if c == previous or c == self.charset.unknown: 47 | continue 48 | if not c == self.charset.blank: 49 | output[i][valid] = c 50 | valid += 1 51 | previous = c 52 | 53 | result = [] 54 | for i in range(labels.shape[0]): 55 | result.append({ 56 | 'label_string': self.label_to_string(labels[i]), 57 | 'pred_string': self.label_to_string(output[i]), 58 | 'mask': mask[i][0], 59 | 'classify': classify[i] 60 | }) 61 | 62 | return result 63 | -------------------------------------------------------------------------------- /structure/representers/east.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from concern.convert import to_np 7 | from concern.config import Configurable, State 8 | 9 | 10 | class EASTRepresenter(Configurable): 11 | heatmap_thr = State(default=0.5) 12 | 13 | def __init__(self, **kwargs): 14 | self.load_all(**kwargs) 15 | 16 | def get_polygons(self, heatmask, densebox): 17 | _, contours, _ = cv2.findContours(heatmask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 18 | 19 | polygons = [] 20 | for contour in contours: 21 | points = [] 22 | for x, y in contour[:, 0]: 23 | quad = densebox[:, y, x].reshape(4, 2) + (x, y) 24 | points.extend(quad) 25 | quad = cv2.boxPoints(cv2.minAreaRect(np.array(points, np.float32))) 26 | polygons.append(quad) 27 | 28 | return polygons 29 | 30 | def represent_batch(self, batch): 31 | image, label, meta = batch 32 | batch_size = image.shape[0] 33 | 34 | output = { 35 | 'image': to_np(image), 36 | 'heatmap': to_np(label['heatmap'][:, 0]), 37 | 'heatmap_weight': to_np(label['heatmap_weight']), 38 | 'densebox': to_np(label['densebox']), 39 | 'densebox_weight': to_np(label['densebox_weight']), 40 | 'meta': [pickle.loads(value) for value in meta], 41 | } 42 | output['heatmask'] = (output['heatmap'] > self.heatmap_thr).astype('uint8') 43 | 44 | 
output['polygons'] = [self.get_polygons( 45 | output['heatmask'][i], 46 | output['densebox'][i], 47 | ) for i in range(batch_size)] 48 | 49 | return output 50 | 51 | def represent(self, batch, pred): 52 | image, label, meta = batch 53 | batch_size = image.shape[0] 54 | 55 | output = self.represent_batch(batch) 56 | output = { 57 | **output, 58 | 'heatmap_pred': to_np(pred['heatmap'][:, 0]), 59 | 'densebox_pred': to_np(pred['densebox']), 60 | } 61 | output['heatmask_pred'] = (output['heatmap_pred'] > self.heatmap_thr).astype('uint8') 62 | output['polygons_pred'] = [self.get_polygons( 63 | output['heatmask_pred'][i], 64 | output['densebox_pred'][i], 65 | ) for i in range(batch_size)] 66 | 67 | return output 68 | -------------------------------------------------------------------------------- /structure/representers/ensemble_ctc_representer.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | import config 6 | from concern.config import State 7 | from .ctc_representer import CTCRepresenter 8 | from concern.charsets import Charset 9 | import concern.webcv2 as webcv2 10 | from data.nori_dataset import NoriDataset 11 | 12 | 13 | class EnsembleCTCRepresenter(CTCRepresenter): 14 | '''decode multiple ensembled ctc models. 15 | ''' 16 | charsets = State(default={}) 17 | offset = State(default=0.5) 18 | 19 | def get_charset(self, key): 20 | return self.charsets.get(key, self.charset) 21 | 22 | def one_hot_to_chars(self, score: torch.Tensor, charset: Charset): 23 | ''' 24 | Args: 25 | score: (C, 1, W) 26 | charset: The corresponding charset. 27 | Return: 28 | chars: the chars with maximum scores. 29 | scores: corresponding scores. 30 | ''' 31 | pred = torch.argmax(score, dim=0) 32 | pred = pred[0] 33 | chars = [] 34 | scores = [] 35 | 36 | for w in range(pred.shape[0]): 37 | chars.append(charset[pred[w]]) 38 | scores.append(score[pred[w], 0, w]) 39 | return chars, scores 40 | 41 | def represent(self, batch, preds: dict, max_size=config.max_size): 42 | _, labels, _ = batch 43 | batch_size = labels.shape[0] 44 | 45 | output = [] 46 | string_scores = [] 47 | for batch_index in range(batch_size): 48 | pred_sequences = OrderedDict() 49 | pred_scores = OrderedDict() 50 | for key, pred in preds.items(): 51 | chars, scores = self.one_hot_to_chars( 52 | pred[batch_index], self.get_charset(key)) 53 | pred_sequences[key] = chars 54 | pred_scores[key] = scores 55 | 56 | result_string, pred_score = self.merge_encode(pred_sequences, pred_scores) 57 | output.append(result_string) 58 | string_scores.append(pred_score) 59 | 60 | result = [] 61 | for i in range(labels.shape[0]): 62 | result.append({ 63 | 'label_string': self.label_to_string(labels[i]), 64 | 'pred_string': ''.join(output[i]), 65 | 'score': string_scores[i], 66 | }) 67 | 68 | return result 69 | 70 | def merge_encode(self, sequences, scores): 71 | result = [] 72 | previous = self.charset.blank_char 73 | main_sequence = sequences['main'] 74 | 75 | score_sum = 0 76 | for index, _char in enumerate(main_sequence): 77 | char, score = self.choose_char(_char, 78 | [s[index] for s in sequences.values()], 79 | [s[index] for s in scores.values()]) 80 | 81 | if char == previous or self.charset.is_empty_char(char): 82 | previous = char 83 | continue 84 | else: 85 | previous = char 86 | result.append(previous) 87 | score_sum += score 88 | if score_sum > 0: 89 | score_sum /= len(result) 90 | return result, score_sum 91 | 92 | def choose_char(self, char, substitudes, 
scores): 93 | # if not self.charset.is_empty_char(char): 94 | # return char 95 | 96 | max_score = -1 97 | index = None 98 | for i, (char, score) in enumerate(zip(substitudes, scores)): 99 | if self.charset.is_empty_char(char): 100 | score -= self.offset 101 | if score > max_score: 102 | max_score = score 103 | index = i 104 | return substitudes[index], max_score 105 | -------------------------------------------------------------------------------- /structure/representers/mask_rcnn.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | from concern.config import Configurable 4 | 5 | 6 | class MaskRCNNRepresenter(Configurable): 7 | def __init__(self, **kwargs): 8 | self.load_all(**kwargs) 9 | 10 | def represent(self, batch, pred): 11 | image, label, meta = batch 12 | 13 | output = { 14 | 'meta': [pickle.loads(value) for value in meta], 15 | 'polygons_pred': pred, 16 | } 17 | return output 18 | -------------------------------------------------------------------------------- /structure/representers/seg_recognition_representer.py: -------------------------------------------------------------------------------- 1 | from concern.config import State 2 | from .sequence_recognition_representer import SequenceRecognitionRepresenter 3 | import config 4 | 5 | import cv2 6 | import numpy as np 7 | 8 | 9 | class SegRecognitionRepresenter(SequenceRecognitionRepresenter): 10 | max_candidates = State(default=64) 11 | min_size = State(default=2) 12 | thresh = State(default=0.3) 13 | box_thresh = State(default=0.7) 14 | 15 | def represent(self, batch, pred): 16 | labels = batch['label'] 17 | mask = pred['mask'] 18 | classify = pred['classify'] 19 | result = [] 20 | 21 | for batch_index in range(mask.shape[0]): 22 | label_string = self.label_to_string(labels[batch_index]) 23 | result_dict = self.result_from_heatmap(mask[batch_index], classify[batch_index]) 24 | result_dict.update(label_string=label_string) 25 | result.append(result_dict) 26 | return result 27 | 28 | def result_from_heatmap(self, bitmap, heatmap): 29 | bitmap = (bitmap > self.thresh).data.cpu().numpy() 30 | score_map = heatmap.data.cpu().detach().numpy() 31 | result = [] 32 | _, contours, _ = cv2.findContours( 33 | (bitmap*255).astype(np.uint8), 34 | cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) 35 | 36 | boxes = [] 37 | for contour in contours[:self.max_candidates]: 38 | points, sside = self.get_boxes(contour) 39 | 40 | #if sside < self.min_size: 41 | # print('by min side') 42 | # continue 43 | boxes.append(np.array(points).reshape(-1, 2)) 44 | 45 | for points in sorted(boxes, key=lambda x: x[0][0]): 46 | score, char_index = self.box_score_fast(bitmap, points.reshape(-1, 2), score_map) 47 | #if self.box_thresh > score: 48 | # continue 49 | result.append(char_index) 50 | 51 | pred_string = self.charset.label_to_string(result) 52 | return dict(mask=bitmap, classify=score_map, pred_string=pred_string) 53 | 54 | def get_boxes(self, contour): 55 | bounding_box = cv2.minAreaRect(contour) 56 | points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) 57 | 58 | index_1, index_2, index_3, index_4 = 0, 1, 2, 3 59 | if points[1][1] > points[0][1]: 60 | index_1 = 0 61 | index_4 = 1 62 | else: 63 | index_1 = 1 64 | index_4 = 0 65 | if points[3][1] > points[2][1]: 66 | index_2 = 2 67 | index_3 = 3 68 | else: 69 | index_2 = 3 70 | index_3 = 2 71 | 72 | box = [points[index_1], points[index_2], 73 | points[index_3], points[index_4]] 74 | return box, min(bounding_box[1]) 75 | 76 | def 
box_score_fast(self, bitmap, _box, score_map): 77 | h, w = bitmap.shape[:2] 78 | box = _box.copy() 79 | xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) 80 | xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1) 81 | ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1) 82 | ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1) 83 | 84 | mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) 85 | box[:, 0] = box[:, 0] - xmin 86 | box[:, 1] = box[:, 1] - ymin 87 | cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) 88 | score = cv2.mean(bitmap[ymin:ymax+1, xmin:xmax+1], mask)[0] 89 | char_index = int(score_map[1:, ymin:ymax+1, xmin:xmax+1].mean(axis=1).mean(axis=1).argmax()) + 1 90 | return score, char_index 91 | -------------------------------------------------------------------------------- /structure/representers/sequence_recognition_representer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import config 4 | from concern.config import Configurable, State 5 | from concern.charsets import DefaultCharset 6 | 7 | import cv2 8 | import numpy as np 9 | import concern.webcv2 as webcv 10 | 11 | 12 | class SequenceRecognitionRepresenter(Configurable): 13 | charset = State(default=DefaultCharset()) 14 | 15 | def __init__(self, cmd={}, **kwargs): 16 | self.load_all(**kwargs) 17 | 18 | def label_to_string(self, label): 19 | return self.charset.label_to_string(label) 20 | 21 | def represent(self, batch, pred): 22 | images, labels = batch['image'], batch['label'] 23 | mask = torch.ones(pred.shape[0], dtype=torch.int).to(pred.device) 24 | 25 | for i in range(pred.shape[1]): 26 | mask = ( 27 | 1 - (pred[:, i] == self.charset.blank).type(torch.int)) * mask 28 | pred[:, i] = pred[:, i] * mask + self.charset.blank * (1 - mask) 29 | 30 | output = [] 31 | for i in range(labels.shape[0]): 32 | label_str = self.label_to_string(labels[i]) 33 | pred_str = self.label_to_string(pred[i]) 34 | if False and label_str != pred_str: 35 | print('label: %s , pred: %s' % (label_str, pred_str)) 36 | img = (np.clip(images[i].cpu().data.numpy().transpose( 37 | 1, 2, 0) + 0.5, 0, 1) * 255).astype('uint8') 38 | webcv.imshow('【 pred: <%s> , label: <%s> 】' % ( 39 | pred_str, label_str), np.array(img, dtype=np.uint8)) 40 | if webcv.waitKey() == ord('q'): 41 | continue 42 | output.append({ 43 | 'label_string': label_str, 44 | 'pred_string': pred_str 45 | }) 46 | 47 | return output 48 | 49 | 50 | class SequenceRecognitionEvaluationRepresenter(Configurable): 51 | charset = State(default=DefaultCharset()) 52 | 53 | def __init__(self, cmd={}, **kwargs): 54 | self.load_all(**kwargs) 55 | 56 | def label_to_string(self, label): 57 | return self.charset.label_to_string(label) 58 | 59 | def represent(self, batch, pred): 60 | images, labels, lengths = batch 61 | mask = torch.ones(pred.shape[0], dtype=torch.int) 62 | 63 | for i in range(pred.shape[1]): 64 | mask = ( 65 | 1 - (pred[:, i] == self.charset.blank).type(torch.int)) * mask 66 | pred[:, i] = pred[:, i] * mask + self.charset.blank * (1 - mask) 67 | 68 | output = [] 69 | for i in range(images.shape[0]): 70 | pred_str = self.label_to_string(pred[i]) 71 | output.append(pred_str) 72 | return output 73 | -------------------------------------------------------------------------------- /structure/visualizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import numpy as np 4 | 5 | 
from concern.config import Configurable, State 6 | 7 | 8 | class TrivalVisualizer(Configurable): 9 | def __init__(self, **kwargs): 10 | pass 11 | 12 | def visualize(self, batch, output): 13 | return {} 14 | -------------------------------------------------------------------------------- /structure/visualizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .textsnake import TextsnakeVisualizer 2 | from .seg_recognition_visualizer import SegRecognitionVisualizer 3 | from .ctc_visualizer2d import CTCVisualizer2D 4 | from .seg_detector_visualizer import SegDetectorVisualizer 5 | # from .sequence_recognition_visualizer import SequenceRecognitionVisualizer 6 | -------------------------------------------------------------------------------- /structure/visualizers/ctc_visualizer2d.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from concern.config import Configurable, State 5 | from concern.visualizer import Visualize 6 | from data.processes import NormalizeImage 7 | import concern.webcv2 as webcv2 8 | 9 | 10 | class CTCVisualizer2D(Configurable): 11 | eager_show = State(default=False, cmd_key='eager_show') 12 | 13 | def visualize(self, batch, output, interested): 14 | return self.visualize_batch(batch, output) 15 | 16 | def visualize_batch(self, batch, output): 17 | visualization = dict() 18 | for index, output_dict in enumerate(output): 19 | image = batch['image'][index] 20 | image = NormalizeImage.restore(image) 21 | 22 | mask = output_dict['mask'] 23 | mask = cv2.resize(Visualize.visualize_weights(mask), image.shape[:2][::-1]) 24 | 25 | classify = output_dict['classify'] 26 | classify = cv2.resize(Visualize.visualize_heatmap(classify, format='CHW'), image.shape[:2][::-1]) 27 | 28 | canvas = np.concatenate([image, mask, classify], axis=0) 29 | key = "【%s-%s】" % (output_dict['label_string'], output_dict['pred_string']) 30 | vis_dict = { 31 | key: canvas 32 | } 33 | 34 | if self.eager_show: 35 | for k, v in vis_dict.items(): 36 | # if output_dict['label_string'] != output_dict['pred_string']: 37 | webcv2.imshow(k, v) 38 | visualization.update(mask=mask, classify=classify, image=image) 39 | if self.eager_show: 40 | webcv2.waitKey() 41 | return visualization 42 | -------------------------------------------------------------------------------- /structure/visualizers/east.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from concern.config import Configurable, State 5 | 6 | 7 | class EASTVisualizer(Configurable): 8 | vis_num = State(default=4) 9 | 10 | def __init__(self, **kwargs): 11 | pass 12 | 13 | def visualize_detection(self, image, heatmap, heatmask, densebox, polygons): 14 | image_show = image.transpose(1, 2, 0).copy() 15 | 16 | h, w = image_show.shape[:2] 17 | densebox_image = np.zeros((h, w, 3), 'uint8') 18 | densebox_anchor = np.indices((h, w))[::-1] 19 | 20 | colors = [(64, 64, 255), (64, 255, 255), (64, 255, 64), (255, 255, 64)] 21 | for i in range(0, 4): 22 | points = densebox[i * 2: i * 2 + 2] + densebox_anchor 23 | points = points[np.tile(heatmask[np.newaxis], (2, 1, 1)) > 0] 24 | points = points.reshape(2, -1).astype('int32') 25 | mask = np.logical_and.reduce([points[0] >= 0, points[0] < w, points[1] >= 0, points[1] < h]) 26 | densebox_image[points[1, mask], points[0, mask]] = colors[i] 27 | 28 | image_show = cv2.polylines(image_show, np.array(polygons, 'int32'), 
True, (0, 0, 255), 1) 29 | 30 | result_image = np.concatenate([ 31 | image_show, 32 | cv2.cvtColor((heatmap * 255).astype('uint8'), cv2.COLOR_GRAY2BGR), 33 | cv2.cvtColor((heatmask * 255).astype('uint8'), cv2.COLOR_GRAY2BGR), 34 | densebox_image, 35 | ], axis=1) 36 | 37 | return result_image 38 | 39 | def visualize_weight(self, idx, output): 40 | heatmap_weight = output['heatmap_weight'][idx] 41 | densebox_weight = output['densebox_weight'][idx] 42 | 43 | result_image = np.concatenate([ 44 | cv2.cvtColor(((heatmap_weight / heatmap_weight.max())[0] * 255).astype('uint8'), cv2.COLOR_GRAY2BGR), 45 | cv2.cvtColor((densebox_weight[0] * 255).astype('uint8'), cv2.COLOR_GRAY2BGR), 46 | ], axis=1) 47 | 48 | return result_image 49 | 50 | def visualize_pred_detection(self, idx, output): 51 | image = output['image'][idx] 52 | heatmap = output['heatmap_pred'][idx] 53 | heatmask = output['heatmask_pred'][idx] 54 | densebox = output['densebox_pred'][idx] 55 | polygons = output['polygons_pred'][idx] 56 | 57 | return self.visualize_detection(image, heatmap, heatmask, densebox, polygons) 58 | 59 | def visualize_gt_detection(self, idx, output): 60 | image = output['image'][idx] 61 | heatmap = output['heatmap'][idx] 62 | heatmask = output['heatmask'][idx] 63 | densebox = output['densebox'][idx] 64 | polygons = output['polygons'][idx] 65 | 66 | return self.visualize_detection(image, heatmap, heatmask, densebox, polygons) 67 | 68 | def get_image(self, idx, output): 69 | return np.concatenate([ 70 | self.visualize_gt_detection(idx, output), 71 | self.visualize_weight(idx, output), 72 | self.visualize_pred_detection(idx, output), 73 | ], axis=1) 74 | 75 | def gt_get_image(self, idx, output): 76 | return np.concatenate([ 77 | self.visualize_gt_detection(idx, output), 78 | self.visualize_weight(idx, output), 79 | ], axis=1) 80 | 81 | def visualize(self, batch, output, interested): 82 | images = {} 83 | for i in range(min(self.vis_num, len(output['image']))): 84 | show = self.get_image(i, output) 85 | images['image_%d' % i] = show.astype(np.uint8) 86 | return images 87 | 88 | def visualize_batch(self, batch, output): 89 | images = {} 90 | for i in range(len(output['image'])): 91 | show = self.gt_get_image(i, output) 92 | images['image_%d' % i] = show.astype(np.uint8) 93 | return images 94 | -------------------------------------------------------------------------------- /structure/visualizers/seg_detector_visualizer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import concern.webcv2 as webcv2 3 | import numpy as np 4 | import torch 5 | 6 | from concern.config import Configurable, State 7 | from data.processes.make_icdar_data import MakeICDARData 8 | 9 | 10 | class SegDetectorVisualizer(Configurable): 11 | vis_num = State(default=4) 12 | eager_show = State(default=False) 13 | 14 | def __init__(self, **kwargs): 15 | cmd = kwargs['cmd'] 16 | if 'eager_show' in cmd: 17 | self.eager_show = cmd['eager_show'] 18 | 19 | def visualize(self, batch, output_pair, interested): 20 | output, pred = output_pair 21 | result_dict = {} 22 | for i in range(batch['image'].size(0)): 23 | result_dict.update( 24 | self.single_visualize(batch, i, output[i], pred)) 25 | if self.eager_show: 26 | webcv2.waitKey() 27 | return {} 28 | return result_dict 29 | 30 | def _visualize_heatmap(self, heatmap, canvas=None): 31 | if isinstance(heatmap, torch.Tensor): 32 | heatmap = heatmap.detach().cpu().numpy() 33 | heatmap = (heatmap[0] * 255).astype(np.uint8) 34 | if canvas is None: 35 | 
pred_image = heatmap 36 | else: 37 | pred_image = (heatmap.reshape( 38 | *heatmap.shape[:2], 1).astype(np.float32) / 255 + 1) / 2 * canvas 39 | pred_image = pred_image.astype(np.uint8) 40 | return pred_image 41 | 42 | def single_visualize(self, batch, index, output, pred): 43 | image = batch['image'][index] 44 | polygons = batch['polygons'][index] 45 | if isinstance(polygons, torch.Tensor): 46 | polygons = polygons.cpu().data.numpy() 47 | ignore_tags = batch['ignore_tags'][index] 48 | original_shape = batch['shape'][index] 49 | filename = batch['filename'][index] 50 | std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1) 51 | mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1) 52 | image = (image.cpu().numpy() * std + mean).transpose(1, 2, 0) * 255 53 | pred_canvas = image.copy().astype(np.uint8) 54 | original_shape = tuple(original_shape.tolist()) 55 | pred_canvas = self._visualize_heatmap(pred['binary'][index], pred_canvas) 56 | 57 | if 'thresh' in pred: 58 | thresh = self._visualize_heatmap(pred['thresh'][index]) 59 | 60 | if 'thresh_binary' in pred: 61 | thresh_binary = self._visualize_heatmap(pred['thresh_binary'][index]) 62 | MakeICDARData.polylines(self, thresh_binary, polygons, ignore_tags) 63 | MakeICDARData.polylines(self, pred_canvas, polygons, ignore_tags) 64 | 65 | for box in output: 66 | box = np.array(box).astype(np.int32).reshape(-1, 2) 67 | cv2.polylines(pred_canvas, [box], True, (0, 255, 0), 1) 68 | if 'thresh_binary' in pred: 69 | cv2.polylines(thresh_binary, [box], True, (0, 255, 0), 1) 70 | 71 | if self.eager_show: 72 | webcv2.imshow(filename + ' output', cv2.resize(pred_canvas, (1024, 1024))) 73 | if 'thresh' in pred: 74 | webcv2.imshow(filename + ' thresh', cv2.resize(thresh, (1024, 1024))) 75 | webcv2.imshow(filename + ' pred', cv2.resize(pred_canvas, (1024, 1024))) 76 | if 'thresh_binary' in pred: 77 | webcv2.imshow(filename + ' thresh_binary', cv2.resize(thresh_binary, (1024, 1024))) 78 | return {} 79 | else: 80 | return { 81 | filename + '_output': pred_canvas, 82 | filename + '_pred': np.expand_dims(thresh_binary, 2) if 'thresh_binary' in pred else None  # guard: thresh_binary is undefined when the key is absent 83 | } 84 | -------------------------------------------------------------------------------- /structure/visualizers/seg_recognition_visualizer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | import concern.webcv2 as webcv2 5 | from concern.config import Configurable, State 6 | from concern.visualizer import Visualize 7 | from data.processes import NormalizeImage 8 | 9 | 10 | class SegRecognitionVisualizer(Configurable): 11 | eager_show = State(default=False) 12 | 13 | def __init__(self, cmd, **kwargs): 14 | self.load_all(cmd=cmd, **kwargs) 15 | self.eager_show = cmd.get('eager_show', self.eager_show) 16 | 17 | def visualize(self, batch, output, interested): 18 | return self.visualize_batch(batch, output) 19 | 20 | def visualize_batch(self, batch, output): 21 | visualization = dict() 22 | for index, output_dict in enumerate(output): 23 | image = batch['image'][index] 24 | image = NormalizeImage.restore(image) 25 | 26 | mask = output_dict['mask'] 27 | mask = cv2.resize(Visualize.visualize_weights(mask), image.shape[:2][::-1]) 28 | 29 | classify = output_dict['classify'] 30 | classify = cv2.resize(Visualize.visualize_heatmap(classify, format='CHW'), image.shape[:2][::-1]) 31 | 32 | canvas = np.concatenate([image, mask, classify], axis=0) 33 | key = "【%s-%s】" % (output_dict['label_string'], output_dict['pred_string']) 34 | 
vis_dict = { 35 | key: canvas 36 | } 37 | 38 | if self.eager_show: 39 | for k, v in vis_dict.items(): 40 | # if output_dict['label_string'] != output_dict['pred_string']: 41 | webcv2.imshow(k, v) 42 | visualization.update(vis_dict) 43 | if self.eager_show: 44 | webcv2.waitKey() 45 | return visualization 46 | -------------------------------------------------------------------------------- /structure/visualizers/sequence_recognition_visualizer.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from concern.config import Configurable, State 5 | import concern.webcv2 as webcv2 6 | from concern.charsets import DefaultCharset 7 | from data.processes.normalize_image import NormalizeImage 8 | 9 | 10 | class SequenceRecognitionVisualizer(Configurable): 11 | charset = State(default=DefaultCharset()) 12 | 13 | def __init__(self, cmd={}, **kwargs): 14 | self.eager = cmd.get('eager_show', False) 15 | self.load_all(**kwargs) 16 | 17 | def visualize(self, batch, output, interested): 18 | return self.visualize_batch(batch, output) 19 | 20 | def visualize_batch(self, batch, output): 21 | images, labels, lengths = batch['image'], batch['label'], batch['length'] 22 | for i in range(images.shape[0]): 23 | image = NormalizeImage.restore(images[i]) 24 | gt = self.charset.label_to_string(labels[i]) 25 | webcv2.imshow(output[i]['pred_string'] + '_' + str(i) + '_' + gt, image) 26 | # folder = 'images/dropout/lexicon/' 27 | # np.save(folder + output[i]['pred_string'] + '_' + gt + '_' + batch['data_ids'][i], image) 28 | webcv2.waitKey() 29 | return { 30 | 'image': (np.clip(batch['image'][0].cpu().data.numpy().transpose(1, 2, 0) + 0.5, 0, 1) * 255).astype( 31 | 'uint8') 32 | } 33 | -------------------------------------------------------------------------------- /structure/visualizers/textsnake.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import numpy as np 4 | 5 | from concern.config import Configurable, State 6 | 7 | 8 | class TextsnakeVisualizer(Configurable): 9 | vis_num = State(default=4) 10 | 11 | def __init__(self, **kwargs): 12 | pass 13 | 14 | def gt_get_image(self, idx, output): 15 | img = output['image'][idx] 16 | img_show = img.copy() 17 | 18 | tr_mask = output['tr_mask'][idx] 19 | tcl_mask = output['tcl_mask'][idx] 20 | 21 | polygons_gt = output['polygons_gt'][idx] 22 | contours_gt = output['contours_gt'][idx] 23 | 24 | contours_gt_im = np.zeros_like(img_show) 25 | for contour in contours_gt: 26 | cv2.fillPoly( 27 | contours_gt_im, np.array([contour], 'int32'), 28 | (np.random.randint(0, 255), np.random.randint(0, 255), np.random.randint(0, 255)) 29 | ) 30 | 31 | gt_vis = self.visualize_detection(img_show, tr_mask, tcl_mask, polygons_gt) 32 | gt_vis = np.concatenate([gt_vis, contours_gt_im], axis=1) 33 | 34 | return gt_vis 35 | 36 | def pred_get_image(self, idx, output): 37 | img = output['image'][idx] 38 | 39 | tr_pred = output['tr_pred'][idx] 40 | tcl_pred = output['tcl_pred'][idx] 41 | 42 | # visualization 43 | img_show = img.copy() 44 | contours = output['contours_pred'][idx] 45 | contours_im = np.zeros_like(img_show) 46 | for contour in contours: 47 | cv2.fillPoly( 48 | contours_im, np.array([contour], 'int32'), 49 | (np.random.randint(0, 255), np.random.randint(0, 255), np.random.randint(0, 255)) 50 | ) 51 | 52 | predict_vis = self.visualize_detection(img_show, tr_pred, tcl_pred, contours) 53 | predict_vis = np.concatenate([predict_vis, 
contours_im], axis=1) 54 | 55 | return predict_vis 56 | 57 | def get_image(self, idx, output): 58 | gt_vis = self.gt_get_image(idx, output) 59 | predict_vis = self.pred_get_image(idx, output) 60 | return np.concatenate([predict_vis, gt_vis], axis=0) 61 | 62 | def visualize_detection(self, image, tr, tcl, contours): 63 | image_show = image.copy() 64 | for contour in contours: 65 | image_show = cv2.polylines(image_show, np.array([contour], 'int32'), True, (0, 0, 255), 2) 66 | tr = cv2.cvtColor((tr * 255).astype('uint8'), cv2.COLOR_GRAY2BGR) 67 | tcl = cv2.cvtColor((tcl * 255).astype('uint8'), cv2.COLOR_GRAY2BGR) 68 | image_show = np.concatenate([image_show, tr, tcl], axis=1) 69 | return image_show 70 | 71 | def visualize(self, batch, output, interested): 72 | images = {} 73 | for i in range(min(self.vis_num, len(output['image']))): 74 | show = self.get_image(i, output) 75 | images['image_%d' % i] = show.astype(np.uint8) 76 | return images 77 | 78 | def visualize_batch(self, batch, output): 79 | images = {} 80 | for i in range(len(output['image'])): 81 | show = self.gt_get_image(i, output) 82 | train_mask = cv2.cvtColor((output['train_mask'][i] * 255).astype('uint8'), cv2.COLOR_GRAY2BGR) 83 | show = np.concatenate([show, train_mask], axis=1) 84 | images['image_%d' % i] = show.astype(np.uint8) 85 | return images 86 | -------------------------------------------------------------------------------- /training/checkpoint.py: -------------------------------------------------------------------------------- 1 | from concern.config import Configurable, State 2 | import os 3 | import torch 4 | 5 | 6 | class Checkpoint(Configurable): 7 | start_epoch = State(default=0) 8 | start_iter = State(default=0) 9 | resume = State() 10 | 11 | def __init__(self, **kwargs): 12 | self.load_all(**kwargs) 13 | 14 | cmd = kwargs['cmd'] 15 | if 'start_epoch' in cmd: 16 | self.start_epoch = cmd['start_epoch'] 17 | if 'start_iter' in cmd: 18 | self.start_iter = cmd['start_iter'] 19 | if 'resume' in cmd: 20 | self.resume = cmd['resume'] 21 | 22 | def restore_model(self, model, device, logger): 23 | if self.resume is None: 24 | return 25 | 26 | if not os.path.exists(self.resume): 27 | logger.warning("Checkpoint not found: " + 28 | self.resume) 29 | return 30 | 31 | logger.info("Resuming from " + self.resume) 32 | state_dict = torch.load(self.resume, map_location=device) 33 | model.load_state_dict(state_dict, strict=False) 34 | logger.info("Resumed from " + self.resume) 35 | 36 | def restore_counter(self): 37 | return self.start_epoch, self.start_iter 38 | -------------------------------------------------------------------------------- /training/model_saver.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | from concern.config import Configurable, State 7 | from concern.signal_monitor import SignalMonitor 8 | 9 | 10 | class ModelSaver(Configurable): 11 | dir_path = State() 12 | save_interval = State(default=1000) 13 | signal_path = State() 14 | 15 | def __init__(self, **kwargs): 16 | self.load_all(**kwargs) 17 | 18 | # BUG: signal path should not be global 19 | self.monitor = SignalMonitor(self.signal_path) 20 | 21 | def maybe_save_model(self, model, epoch, step, logger): 22 | if step % self.save_interval == 0 or self.monitor.get_signal() is not None: 23 | self.save_model(model, epoch, step) 24 | logger.report_time('Saving ') 25 | logger.iter(step) 26 | 27 | def save_model(self, model, epoch=None, step=None): 
28 | if isinstance(model, dict): 29 | for name, net in model.items(): 30 | checkpoint_name = self.make_checkpoint_name(name, epoch, step) 31 | self.save_checkpoint(net, checkpoint_name) 32 | else: 33 | checkpoint_name = self.make_checkpoint_name('model', epoch, step) 34 | self.save_checkpoint(model, checkpoint_name) 35 | 36 | def save_checkpoint(self, net, name): 37 | if dist.is_available() and dist.is_initialized() and not dist.get_rank() == 0: 38 | return 39 | os.makedirs(self.dir_path, exist_ok=True) 40 | torch.save(net.state_dict(), os.path.join(self.dir_path, name)) 41 | 42 | def make_checkpoint_name(self, name, epoch=None, step=None): 43 | if epoch is None or step is None: 44 | c_name = name + '_latest' 45 | else: 46 | c_name = '{}_epoch_{}_minibatch_{}'.format(name, epoch, step) 47 | return c_name 48 | -------------------------------------------------------------------------------- /training/optimizer_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from concern.config import Configurable, State 4 | 5 | 6 | class OptimizerScheduler(Configurable): 7 | optimizer = State() 8 | optimizer_args = State(default={}) 9 | learning_rate = State(autoload=False) 10 | 11 | def __init__(self, cmd={}, **kwargs): 12 | self.load_all(**kwargs) 13 | self.load('learning_rate', cmd=cmd, **kwargs) 14 | if 'lr' in cmd: 15 | self.optimizer_args['lr'] = cmd['lr'] 16 | 17 | def create_optimizer(self, parameters): 18 | optimizer = getattr(torch.optim, self.optimizer)( 19 | parameters, **self.optimizer_args) 20 | if hasattr(self.learning_rate, 'prepare'): 21 | self.learning_rate.prepare(optimizer) 22 | return optimizer 23 | --------------------------------------------------------------------------------
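
The greedy CTC decoding implemented in both ctc_representer.py and ctc_representer2d.py above follows the same collapse rule: take the per-frame argmax indices, merge consecutive repeats, then drop blanks (the 2D variant first selects, for every column, the row with the highest masked score before collapsing). The snippet below is a minimal standalone sketch of that rule, not repository code; the blank index 0, the sample indices, and the helper name `ctc_collapse` are illustrative assumptions.

```python
# Sketch of the greedy CTC collapse used by CTCRepresenter.represent and
# CTCRepresenter2D.represent. Assumptions: blank index is 0; `unknown`
# mirrors charset.unknown and is skipped without updating `previous`,
# matching the loop in the repository code.
def ctc_collapse(frame_indices, blank=0, unknown=None):
    decoded, previous = [], blank
    for c in frame_indices:
        if c == previous or c == unknown:
            continue  # merge consecutive repeats; ignore 'unknown'
        if c != blank:
            decoded.append(c)  # keep non-blank symbols only
        previous = c
    return decoded

# With 'a'=1 and 'b'=2, frames [1, 1, 0, 1, 2, 2, 0] decode to "aab":
# the blank between the two 'a' runs keeps them from being merged.
assert ctc_collapse([1, 1, 0, 1, 2, 2, 0]) == [1, 1, 2]
```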