├── .gitignore ├── README.md ├── assets ├── ic15_eval │ ├── rrc_evaluation_funcs.py │ └── script.py └── ops │ └── dcn │ ├── __init__.py │ ├── functions │ ├── __init__.py │ ├── deform_conv.py │ └── deform_pool.py │ ├── modules │ ├── __init__.py │ ├── deform_conv.py │ └── deform_pool.py │ ├── setup.py │ └── src │ ├── deform_conv_cuda.cpp │ ├── deform_conv_cuda_kernel.cu │ ├── deform_pool_cuda.cpp │ └── deform_pool_cuda_kernel.cu ├── backbones ├── __init__.py ├── mobilenetv3.py └── resnet.py ├── concern ├── __init__.py ├── average_meter.py ├── box2seg.py ├── config.py ├── convert.py ├── icdar2015_eval │ ├── __init__.py │ └── detection │ │ ├── __init__.py │ │ ├── deteval.py │ │ ├── icdar2013.py │ │ ├── iou.py │ │ └── mtwi2018.py ├── log.py ├── signal_monitor.py ├── visualizer.py └── webcv2 │ ├── __init__.py │ ├── manager.py │ ├── server.py │ └── templates │ └── index.html ├── convert_to_onnx.py ├── data ├── __init__.py ├── augmenter.py ├── data_loader.py ├── dataset.py ├── image_dataset.py ├── make_border_map.py ├── make_seg_detector_data.py ├── meta_loader.py ├── processes │ ├── __init__.py │ ├── augment_data.py │ ├── data_process.py │ ├── filter_keys.py │ ├── make_border_map.py │ ├── make_center_distance_map.py │ ├── make_center_map.py │ ├── make_center_points.py │ ├── make_icdar_data.py │ ├── make_seg_detection_data.py │ ├── normalize_image.py │ ├── random_crop_data.py │ ├── resize_image.py │ └── serialize_box.py ├── quad.py ├── random_crop_aug.py ├── simple_detection.py ├── text_lines.py ├── transform_data.py └── unpack_msgpack_data.py ├── decoders ├── __init__.py ├── balance_cross_entropy_loss.py ├── dice_loss.py ├── feature_attention.py ├── l1_loss.py ├── pss_loss.py ├── seg_detector.py ├── seg_detector_asf.py ├── seg_detector_loss.py └── simple_detection.py ├── demo.py ├── eval.py ├── experiment.py ├── experiments ├── ASF │ └── td500_resnet50_deform_thre_asf.yaml ├── base.yaml └── seg_detector │ ├── base.yaml │ ├── base_ic15.yaml │ ├── base_td500.yaml │ ├── base_totaltext.yaml │ ├── ic15_resnet18_deform_thre.yaml │ ├── ic15_resnet50_deform_thre.yaml │ ├── td500_resnet18_deform_thre.yaml │ ├── td500_resnet50_deform_thre.yaml │ ├── totaltext_mobilenet_v3_large_thre.yaml │ ├── totaltext_resnet18_deform_thre.yaml │ └── totaltext_resnet50_deform_thre.yaml ├── requirement.txt ├── structure ├── __init__.py ├── builder.py ├── measurers │ ├── __init__.py │ ├── icdar_detection_measurer.py │ └── quad_measurer.py ├── model.py ├── representers │ ├── __init__.py │ ├── seg_detector_representer.py │ └── setup.py └── visualizers │ ├── __init__.py │ └── seg_detector_visualizer.py ├── train.py ├── trainer.py └── training ├── checkpoint.py ├── learning_rate.py ├── model_saver.py └── optimizer_scheduler.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | */__pycache__/* 3 | workspace 4 | *.py[cod] 5 | *$py.class 6 | *.swp 7 | *.swo 8 | *.lock 9 | 10 | # C extensions 11 | *.so 12 | *.nfs* 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | .idea 109 | # From the naive evaluation of ICDAR15 110 | log.txt 111 | # specific directory 112 | datasets 113 | evaluation 114 | experiments/backup 115 | lib 116 | outputs 117 | results 118 | *.zip 119 | *.pyx 120 | structure/representers/setup.py 121 | demo_results -------------------------------------------------------------------------------- /assets/ops/dcn/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions.deform_conv import deform_conv, modulated_deform_conv 2 | from .functions.deform_pool import deform_roi_pooling 3 | from .modules.deform_conv import (DeformConv, ModulatedDeformConv, 4 | DeformConvPack, ModulatedDeformConvPack) 5 | from .modules.deform_pool import (DeformRoIPooling, DeformRoIPoolingPack, 6 | ModulatedDeformRoIPoolingPack) 7 | 8 | __all__ = [ 9 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 10 | 'ModulatedDeformConvPack', 'DeformRoIPooling', 'DeformRoIPoolingPack', 11 | 'ModulatedDeformRoIPoolingPack', 'deform_conv', 'modulated_deform_conv', 12 | 'deform_roi_pooling' 13 | ] 14 | -------------------------------------------------------------------------------- /assets/ops/dcn/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MhLiao/DB/65ca77a0bcfbd7114b916cf8a1e9ca85114286ce/assets/ops/dcn/functions/__init__.py -------------------------------------------------------------------------------- /assets/ops/dcn/functions/deform_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | from ..
import deform_pool_cuda 5 | 6 | 7 | class DeformRoIPoolingFunction(Function): 8 | 9 | @staticmethod 10 | def forward(ctx, 11 | data, 12 | rois, 13 | offset, 14 | spatial_scale, 15 | out_size, 16 | out_channels, 17 | no_trans, 18 | group_size=1, 19 | part_size=None, 20 | sample_per_part=4, 21 | trans_std=.0): 22 | ctx.spatial_scale = spatial_scale 23 | ctx.out_size = out_size 24 | ctx.out_channels = out_channels 25 | ctx.no_trans = no_trans 26 | ctx.group_size = group_size 27 | ctx.part_size = out_size if part_size is None else part_size 28 | ctx.sample_per_part = sample_per_part 29 | ctx.trans_std = trans_std 30 | 31 | assert 0.0 <= ctx.trans_std <= 1.0 32 | if not data.is_cuda: 33 | raise NotImplementedError 34 | 35 | n = rois.shape[0] 36 | output = data.new_empty(n, out_channels, out_size, out_size) 37 | output_count = data.new_empty(n, out_channels, out_size, out_size) 38 | deform_pool_cuda.deform_psroi_pooling_cuda_forward( 39 | data, rois, offset, output, output_count, ctx.no_trans, 40 | ctx.spatial_scale, ctx.out_channels, ctx.group_size, ctx.out_size, 41 | ctx.part_size, ctx.sample_per_part, ctx.trans_std) 42 | 43 | if data.requires_grad or rois.requires_grad or offset.requires_grad: 44 | ctx.save_for_backward(data, rois, offset) 45 | ctx.output_count = output_count 46 | 47 | return output 48 | 49 | @staticmethod 50 | def backward(ctx, grad_output): 51 | if not grad_output.is_cuda: 52 | raise NotImplementedError 53 | 54 | data, rois, offset = ctx.saved_tensors 55 | output_count = ctx.output_count 56 | grad_input = torch.zeros_like(data) 57 | grad_rois = None 58 | grad_offset = torch.zeros_like(offset) 59 | 60 | deform_pool_cuda.deform_psroi_pooling_cuda_backward( 61 | grad_output, data, rois, offset, output_count, grad_input, 62 | grad_offset, ctx.no_trans, ctx.spatial_scale, ctx.out_channels, 63 | ctx.group_size, ctx.out_size, ctx.part_size, ctx.sample_per_part, 64 | ctx.trans_std) 65 | return (grad_input, grad_rois, grad_offset, None, None, None, None, 66 | None, None, None, None) 67 | 68 | 69 | deform_roi_pooling = DeformRoIPoolingFunction.apply 70 | -------------------------------------------------------------------------------- /assets/ops/dcn/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MhLiao/DB/65ca77a0bcfbd7114b916cf8a1e9ca85114286ce/assets/ops/dcn/modules/__init__.py -------------------------------------------------------------------------------- /assets/ops/dcn/modules/deform_conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.modules.utils import _pair 6 | 7 | from ..functions.deform_conv import deform_conv, modulated_deform_conv 8 | 9 | 10 | class DeformConv(nn.Module): 11 | 12 | def __init__(self, 13 | in_channels, 14 | out_channels, 15 | kernel_size, 16 | stride=1, 17 | padding=0, 18 | dilation=1, 19 | groups=1, 20 | deformable_groups=1, 21 | bias=False): 22 | super(DeformConv, self).__init__() 23 | 24 | assert not bias 25 | assert in_channels % groups == 0, \ 26 | 'in_channels {} is not divisible by groups {}'.format( 27 | in_channels, groups) 28 | assert out_channels % groups == 0, \ 29 | 'out_channels {} is not divisible by groups {}'.format( 30 | out_channels, groups) 31 | 32 | self.in_channels = in_channels 33 | self.out_channels = out_channels 34 | self.kernel_size = _pair(kernel_size) 35 | self.stride = _pair(stride) 36 | self.padding =
_pair(padding) 37 | self.dilation = _pair(dilation) 38 | self.groups = groups 39 | self.deformable_groups = deformable_groups 40 | 41 | self.weight = nn.Parameter( 42 | torch.Tensor(out_channels, in_channels // self.groups, 43 | *self.kernel_size)) 44 | 45 | self.reset_parameters() 46 | 47 | def reset_parameters(self): 48 | n = self.in_channels 49 | for k in self.kernel_size: 50 | n *= k 51 | stdv = 1. / math.sqrt(n) 52 | self.weight.data.uniform_(-stdv, stdv) 53 | 54 | def forward(self, x, offset): 55 | return deform_conv(x, offset, self.weight, self.stride, self.padding, 56 | self.dilation, self.groups, self.deformable_groups) 57 | 58 | 59 | class DeformConvPack(DeformConv): 60 | 61 | def __init__(self, *args, **kwargs): 62 | super(DeformConvPack, self).__init__(*args, **kwargs) 63 | 64 | self.conv_offset = nn.Conv2d( 65 | self.in_channels, 66 | self.deformable_groups * 2 * self.kernel_size[0] * 67 | self.kernel_size[1], 68 | kernel_size=self.kernel_size, 69 | stride=_pair(self.stride), 70 | padding=_pair(self.padding), 71 | bias=True) 72 | self.init_offset() 73 | 74 | def init_offset(self): 75 | self.conv_offset.weight.data.zero_() 76 | self.conv_offset.bias.data.zero_() 77 | 78 | def forward(self, x): 79 | offset = self.conv_offset(x) 80 | return deform_conv(x, offset, self.weight, self.stride, self.padding, 81 | self.dilation, self.groups, self.deformable_groups) 82 | 83 | 84 | class ModulatedDeformConv(nn.Module): 85 | 86 | def __init__(self, 87 | in_channels, 88 | out_channels, 89 | kernel_size, 90 | stride=1, 91 | padding=0, 92 | dilation=1, 93 | groups=1, 94 | deformable_groups=1, 95 | bias=True): 96 | super(ModulatedDeformConv, self).__init__() 97 | self.in_channels = in_channels 98 | self.out_channels = out_channels 99 | self.kernel_size = _pair(kernel_size) 100 | self.stride = stride 101 | self.padding = padding 102 | self.dilation = dilation 103 | self.groups = groups 104 | self.deformable_groups = deformable_groups 105 | self.with_bias = bias 106 | 107 | self.weight = nn.Parameter( 108 | torch.Tensor(out_channels, in_channels // groups, 109 | *self.kernel_size)) 110 | if bias: 111 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 112 | else: 113 | self.register_parameter('bias', None) 114 | self.reset_parameters() 115 | 116 | def reset_parameters(self): 117 | n = self.in_channels 118 | for k in self.kernel_size: 119 | n *= k 120 | stdv = 1. 
/ math.sqrt(n) 121 | self.weight.data.uniform_(-stdv, stdv) 122 | if self.bias is not None: 123 | self.bias.data.zero_() 124 | 125 | def forward(self, x, offset, mask): 126 | return modulated_deform_conv(x, offset, mask, self.weight, self.bias, 127 | self.stride, self.padding, self.dilation, 128 | self.groups, self.deformable_groups) 129 | 130 | 131 | class ModulatedDeformConvPack(ModulatedDeformConv): 132 | 133 | def __init__(self, *args, **kwargs): 134 | super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) 135 | 136 | self.conv_offset_mask = nn.Conv2d( 137 | self.in_channels, 138 | self.deformable_groups * 3 * self.kernel_size[0] * 139 | self.kernel_size[1], 140 | kernel_size=self.kernel_size, 141 | stride=_pair(self.stride), 142 | padding=_pair(self.padding), 143 | bias=True) 144 | self.init_offset() 145 | 146 | def init_offset(self): 147 | self.conv_offset_mask.weight.data.zero_() 148 | self.conv_offset_mask.bias.data.zero_() 149 | 150 | def forward(self, x): 151 | out = self.conv_offset_mask(x) 152 | o1, o2, mask = torch.chunk(out, 3, dim=1) 153 | offset = torch.cat((o1, o2), dim=1) 154 | mask = torch.sigmoid(mask) 155 | return modulated_deform_conv(x, offset, mask, self.weight, self.bias, 156 | self.stride, self.padding, self.dilation, 157 | self.groups, self.deformable_groups) 158 | -------------------------------------------------------------------------------- /assets/ops/dcn/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='deform_conv', 6 | ext_modules=[ 7 | CUDAExtension('deform_conv_cuda', [ 8 | 'src/deform_conv_cuda.cpp', 9 | 'src/deform_conv_cuda_kernel.cu', 10 | ]), 11 | CUDAExtension('deform_pool_cuda', [ 12 | 'src/deform_pool_cuda.cpp', 'src/deform_pool_cuda_kernel.cu' 13 | ]), 14 | ], 15 | cmdclass={'build_ext': BuildExtension}) 16 | -------------------------------------------------------------------------------- /assets/ops/dcn/src/deform_pool_cuda.cpp: -------------------------------------------------------------------------------- 1 | // modify from 2 | // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c 3 | 4 | // based on 5 | // author: Charles Shang 6 | // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu 7 | 8 | #include <torch/extension.h> 9 | 10 | #include <cmath> 11 | #include <vector> 12 | 13 | void DeformablePSROIPoolForward( 14 | const at::Tensor data, const at::Tensor bbox, const at::Tensor trans, 15 | at::Tensor out, at::Tensor top_count, const int batch, const int channels, 16 | const int height, const int width, const int num_bbox, 17 | const int channels_trans, const int no_trans, const float spatial_scale, 18 | const int output_dim, const int group_size, const int pooled_size, 19 | const int part_size, const int sample_per_part, const float trans_std); 20 | 21 | void DeformablePSROIPoolBackwardAcc( 22 | const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox, 23 | const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad, 24 | at::Tensor trans_grad, const int batch, const int channels, 25 | const int height, const int width, const int num_bbox, 26 | const int channels_trans, const int no_trans, const float spatial_scale, 27 | const int output_dim, const int group_size, const int pooled_size, 28 | const int part_size, const int sample_per_part, const float trans_std); 29 | 30 |
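// Note (inferred from the Python caller in functions/deform_pool.py, not
// stated in this file): `input` is a contiguous NCHW CUDA tensor; `bbox`
// holds num_bbox RoIs; `out` and `top_count` are preallocated by the caller
// as (num_bbox, output_dim, pooled_size, pooled_size); `trans` carries the
// per-part offsets and is effectively unused when `no_trans` is non-zero.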
void deform_psroi_pooling_cuda_forward( 31 | at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out, 32 | at::Tensor top_count, const int no_trans, const float spatial_scale, 33 | const int output_dim, const int group_size, const int pooled_size, 34 | const int part_size, const int sample_per_part, const float trans_std) { 35 | AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); 36 | 37 | const int batch = input.size(0); 38 | const int channels = input.size(1); 39 | const int height = input.size(2); 40 | const int width = input.size(3); 41 | const int channels_trans = no_trans ? 2 : trans.size(1); 42 | 43 | const int num_bbox = bbox.size(0); 44 | if (num_bbox != out.size(0)) 45 | AT_ERROR("Output shape and bbox number won't match: (%d vs %d).", 46 | out.size(0), num_bbox); 47 | 48 | DeformablePSROIPoolForward( 49 | input, bbox, trans, out, top_count, batch, channels, height, width, 50 | num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size, 51 | pooled_size, part_size, sample_per_part, trans_std); 52 | } 53 | 54 | void deform_psroi_pooling_cuda_backward( 55 | at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans, 56 | at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad, 57 | const int no_trans, const float spatial_scale, const int output_dim, 58 | const int group_size, const int pooled_size, const int part_size, 59 | const int sample_per_part, const float trans_std) { 60 | AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous"); 61 | AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); 62 | 63 | const int batch = input.size(0); 64 | const int channels = input.size(1); 65 | const int height = input.size(2); 66 | const int width = input.size(3); 67 | const int channels_trans = no_trans ?
2 : trans.size(1); 68 | 69 | const int num_bbox = bbox.size(0); 70 | if (num_bbox != out_grad.size(0)) 71 | AT_ERROR("Output shape and bbox number won't match: (%d vs %d).", 72 | out_grad.size(0), num_bbox); 73 | 74 | DeformablePSROIPoolBackwardAcc( 75 | out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch, 76 | channels, height, width, num_bbox, channels_trans, no_trans, 77 | spatial_scale, output_dim, group_size, pooled_size, part_size, 78 | sample_per_part, trans_std); 79 | } 80 | 81 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 82 | m.def("deform_psroi_pooling_cuda_forward", &deform_psroi_pooling_cuda_forward, 83 | "deform psroi pooling forward(CUDA)"); 84 | m.def("deform_psroi_pooling_cuda_backward", 85 | &deform_psroi_pooling_cuda_backward, 86 | "deform psroi pooling backward(CUDA)"); 87 | } -------------------------------------------------------------------------------- /backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import resnet18, resnet34, resnet50, resnet101, deformable_resnet50, deformable_resnet18 2 | from .mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small -------------------------------------------------------------------------------- /concern/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : __init__.py 4 | # Author : Zhaoyi Wan 5 | # Date : 21.11.2018 6 | # Last Modified Date: 08.01.2019 7 | # Last Modified By : Zhaoyi Wan 8 | 9 | from .log import Logger 10 | from .average_meter import AverageMeter 11 | from .visualizer import Visualize 12 | from .box2seg import resize_with_coordinates, box2seg 13 | from .convert import convert 14 | -------------------------------------------------------------------------------- /concern/average_meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | def __init__(self): 4 | self.reset() 5 | 6 | def reset(self): 7 | self.val = 0 8 | self.avg = 0 9 | self.sum = 0 10 | self.count = 0 11 | 12 | def update(self, val, n=1): 13 | self.val = val 14 | self.sum += val * n 15 | self.count += n 16 | self.avg = self.sum / self.count 17 | return self 18 | -------------------------------------------------------------------------------- /concern/box2seg.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from scipy import interpolate 4 | 5 | def intersection(x, p1, p2): 6 | x1, y1 = p1 7 | x2, y2 = p2 8 | if x2 == x1: 9 | return 0 10 | k = (x - x1) / (x2 - x1) 11 | return k * (y2 - y1) + y1 12 | 13 | 14 | def midpoint(p1, p2, typed=float): 15 | return [typed((p1[0] + p2[0]) / 2), typed((p1[1] + p2[1]) / 2)] 16 | 17 | 18 | def resize_with_coordinates(image, width, height, coordinates): 19 | original_height, original_width = image.shape[:2] 20 | resized_image = cv2.resize(image, (width, height)) 21 | if coordinates is not None: 22 | assert coordinates.ndim == 2 23 | assert coordinates.shape[-1] == 2 24 | 25 | rate_x = width / original_width 26 | rate_y = height / original_height 27 | 28 | coordinates = coordinates * (rate_x, rate_y) 29 | return resized_image, coordinates 30 | 31 | 32 | def box2seg(image, boxes, label): 33 | height, width = image.shape[:2] 34 | mask = np.zeros((height, width), dtype=np.float32) 35 | seg = np.zeros((height, width),
dtype=np.float32) 36 | points = [] 37 | for box_index in range(boxes.shape[0]): 38 | box = boxes[box_index, :, :] # 4x2 39 | left_top = box[0] 40 | right_top = box[1] 41 | right_bottom = box[2] 42 | left_bottom = box[3] 43 | 44 | left = [(left_top[0] + left_bottom[0]) / 2, (left_top[1] + left_bottom[1]) / 2] 45 | right = [(right_top[0] + right_bottom[0]) / 2, (right_top[1] + right_bottom[1]) / 2] 46 | 47 | center = midpoint(left, right) 48 | points.append(midpoint(left, center)) 49 | points.append(midpoint(right, center)) 50 | 51 | poly = np.array([midpoint(left_top, center), 52 | midpoint(right_top, center), 53 | midpoint(right_bottom, center), 54 | midpoint(left_bottom, center) 55 | ]) 56 | seg = cv2.fillPoly(seg, [poly.reshape(4, 1, 2).astype(np.int32)], int(label[box_index])) 57 | 58 | left_y = intersection(0, points[0], points[1]) 59 | right_y = intersection(width, points[-1], points[-2]) 60 | points.insert(0, [0, left_y]) 61 | points.append([width, right_y]) 62 | points = np.array(points) 63 | 64 | f = interpolate.interp1d(points[:, 0], points[:, 1], fill_value='extrapolate') 65 | xnew = np.arange(0, width, 1) 66 | ynew = f(xnew).clip(0, height-1) 67 | for x in range(width - 1): 68 | mask[int(ynew[x]), x] = 1 69 | return ynew.reshape(1, -1).round(), seg 70 | -------------------------------------------------------------------------------- /concern/config.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from collections import OrderedDict 3 | 4 | import anyconfig 5 | import munch 6 | 7 | 8 | class Config(object): 9 | def __init__(self): 10 | pass 11 | 12 | def load(self, conf): 13 | conf = anyconfig.load(conf) 14 | return munch.munchify(conf) 15 | 16 | def compile(self, conf, return_packages=False): 17 | packages = conf.get('package', []) 18 | defines = {} 19 | 20 | for path in conf.get('import', []): 21 | parent_conf = self.load(path) 22 | parent_packages, parent_defines = self.compile( 23 | parent_conf, return_packages=True) 24 | packages.extend(parent_packages) 25 | defines.update(parent_defines) 26 | 27 | modules = [] 28 | for package in packages: 29 | module = importlib.import_module(package) 30 | modules.append(module) 31 | 32 | if isinstance(conf['define'], dict): 33 | conf['define'] = [conf['define']] 34 | 35 | for define in conf['define']: 36 | name = define.copy().pop('name') 37 | 38 | if not isinstance(name, str): 39 | raise RuntimeError('name must be str') 40 | 41 | defines[name] = self.compile_conf(define, defines, modules) 42 | 43 | if return_packages: 44 | return packages, defines 45 | else: 46 | return defines 47 | 48 | def compile_conf(self, conf, defines, modules): 49 | if isinstance(conf, (int, float)): 50 | return conf 51 | elif isinstance(conf, str): 52 | if conf.startswith('^'): 53 | return defines[conf[1:]] 54 | if conf.startswith('$'): 55 | return {'class': self.find_class_in_modules(conf[1:], modules)} 56 | return conf 57 | elif isinstance(conf, dict): 58 | if 'class' in conf: 59 | conf['class'] = self.find_class_in_modules( 60 | conf['class'], modules) 61 | if 'base' in conf: 62 | base = conf.copy().pop('base') 63 | 64 | if not isinstance(base, str): 65 | raise RuntimeError('base must be str') 66 | 67 | conf = { 68 | **defines[base], 69 | **conf, 70 | } 71 | return {key: self.compile_conf(value, defines, modules) for key, value in conf.items()} 72 | elif isinstance(conf, (list, tuple)): 73 | return [self.compile_conf(value, defines, modules) for value in conf] 74 | else: 75 | return conf 76 | 77 | 
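# A hypothetical illustration of the reference syntax compile_conf handles
# (editor's sketch, not a config shipped with this repo; it assumes the
# backbones package is listed under 'package:'):
#
#   define:
#     - name: backbone
#       class: backbones.resnet18
#     - name: detector
#       base: backbone
#       pretrained: true
#
# A plain 'base: backbone' merges the dict already compiled under that name
# beneath the current dict's own keys; a '^backbone' string value is replaced
# by that same compiled dict; a '$resnet18' string value becomes
# {'class': 'backbones.resnet18'} via find_class_in_modules below.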
def find_class_in_modules(self, cls, modules): 78 | if not isinstance(cls, str): 79 | raise RuntimeError('class name must be str') 80 | 81 | if cls.find('.') != -1: 82 | package, cls = cls.rsplit('.', 1) 83 | module = importlib.import_module(package) 84 | if hasattr(module, cls): 85 | return module.__name__ + '.' + cls 86 | 87 | for module in modules: 88 | if hasattr(module, cls): 89 | return module.__name__ + '.' + cls 90 | raise RuntimeError('class not found ' + cls) 91 | 92 | 93 | class State: 94 | def __init__(self, autoload=True, default=None): 95 | self.autoload = autoload 96 | self.default = default 97 | 98 | 99 | class StateMeta(type): 100 | def __new__(mcs, name, bases, attrs): 101 | current_states = [] 102 | for key, value in attrs.items(): 103 | if isinstance(value, State): 104 | current_states.append((key, value)) 105 | 106 | current_states.sort(key=lambda x: x[0]) 107 | attrs['states'] = OrderedDict(current_states) 108 | new_class = super(StateMeta, mcs).__new__(mcs, name, bases, attrs) 109 | 110 | # Walk through the MRO 111 | states = OrderedDict() 112 | for base in reversed(new_class.__mro__): 113 | if hasattr(base, 'states'): 114 | states.update(base.states) 115 | new_class.states = states 116 | 117 | for key, value in states.items(): 118 | setattr(new_class, key, value.default) 119 | 120 | return new_class 121 | 122 | 123 | class Configurable(metaclass=StateMeta): 124 | def __init__(self, *args, cmd={}, **kwargs): 125 | self.load_all(cmd=cmd, **kwargs) 126 | 127 | @staticmethod 128 | def construct_class_from_config(args): 129 | cls = Configurable.extract_class_from_args(args) 130 | return cls(**args) 131 | 132 | @staticmethod 133 | def extract_class_from_args(args): 134 | cls = args.copy().pop('class') 135 | package, cls = cls.rsplit('.', 1) 136 | module = importlib.import_module(package) 137 | cls = getattr(module, cls) 138 | return cls 139 | 140 | def load_all(self, **kwargs): 141 | for name, state in self.states.items(): 142 | if state.autoload: 143 | self.load(name, **kwargs) 144 | 145 | def load(self, state_name, **kwargs): 146 | # FIXME: kwargs should be filtered 147 | # Args passed from command line 148 | cmd = kwargs.pop('cmd', dict()) 149 | if state_name in kwargs: 150 | setattr(self, state_name, self.create_member_from_config( 151 | (kwargs[state_name], cmd))) 152 | else: 153 | setattr(self, state_name, self.states[state_name].default) 154 | 155 | def create_member_from_config(self, conf): 156 | args, cmd = conf 157 | if args is None or isinstance(args, (int, float, str)): 158 | return args 159 | elif isinstance(args, (list, tuple)): 160 | return [self.create_member_from_config((subargs, cmd)) for subargs in args] 161 | elif isinstance(args, dict): 162 | if 'class' in args: 163 | cls = self.extract_class_from_args(args) 164 | return cls(**args, cmd=cmd) 165 | return {key: self.create_member_from_config((subargs, cmd)) for key, subargs in args.items()} 166 | else: 167 | return args 168 | 169 | def dump(self): 170 | state = {} 171 | state['class'] = self.__class__.__module__ + \ 172 | '.' 
+ self.__class__.__name__ 173 | for name, value in self.states.items(): 174 | obj = getattr(self, name) 175 | state[name] = self.dump_obj(obj) 176 | return state 177 | 178 | def dump_obj(self, obj): 179 | if obj is None: 180 | return None 181 | elif hasattr(obj, 'dump'): 182 | return obj.dump() 183 | elif isinstance(obj, (int, float, str)): 184 | return obj 185 | elif isinstance(obj, (list, tuple)): 186 | return [self.dump_obj(value) for value in obj] 187 | elif isinstance(obj, dict): 188 | return {key: self.dump_obj(value) for key, value in obj.items()} 189 | else: 190 | return str(obj) 191 | 192 | -------------------------------------------------------------------------------- /concern/convert.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import cv2 3 | import base64 4 | import io 5 | import numpy as np 6 | 7 | 8 | def convert(data): 9 | if isinstance(data, dict): 10 | ndata = {} 11 | for key, value in data.items(): 12 | nkey = key.decode() 13 | if nkey == 'img': 14 | img = Image.open(io.BytesIO(value)) 15 | img = img.convert('RGB') 16 | img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) 17 | nvalue = img 18 | else: 19 | nvalue = convert(value) 20 | ndata[nkey] = nvalue 21 | return ndata 22 | elif isinstance(data, list): 23 | return [convert(item) for item in data] 24 | elif isinstance(data, bytes): 25 | return data.decode() 26 | else: 27 | return data 28 | 29 | 30 | def to_np(x): 31 | return x.cpu().data.numpy() 32 | -------------------------------------------------------------------------------- /concern/icdar2015_eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MhLiao/DB/65ca77a0bcfbd7114b916cf8a1e9ca85114286ce/concern/icdar2015_eval/__init__.py -------------------------------------------------------------------------------- /concern/icdar2015_eval/detection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MhLiao/DB/65ca77a0bcfbd7114b916cf8a1e9ca85114286ce/concern/icdar2015_eval/detection/__init__.py -------------------------------------------------------------------------------- /concern/log.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import functools 4 | import json 5 | import time 6 | from datetime import datetime 7 | 8 | from tensorboardX import SummaryWriter 9 | import yaml 10 | import cv2 11 | import numpy as np 12 | 13 | from concern.config import Configurable, State 14 | 15 | 16 | class Logger(Configurable): 17 | SUMMARY_DIR_NAME = 'summaries' 18 | VISUALIZE_NAME = 'visualize' 19 | LOG_FILE_NAME = 'output.log' 20 | ARGS_FILE_NAME = 'args.log' 21 | METRICS_FILE_NAME = 'metrics.log' 22 | 23 | database_dir = State(default='./outputs/') 24 | log_dir = State(default='workspace') 25 | verbose = State(default=False) 26 | level = State(default='info') 27 | log_interval = State(default=100) 28 | 29 | def __init__(self, **kwargs): 30 | self.load_all(**kwargs) 31 | 32 | self._make_storage() 33 | 34 | cmd = kwargs['cmd'] 35 | self.name = cmd['name'] 36 | self.log_dir = os.path.join(self.log_dir, self.name) 37 | try: 38 | self.verbose = cmd['verbose'] 39 | except: 40 | print('verbose:', self.verbose) 41 | if self.verbose: 42 | print('Initializing log dir for', self.log_dir) 43 | 44 | if not os.path.exists(self.log_dir): 45 | os.makedirs(self.log_dir) 46 | 47 | self.message_logger = 
self._init_message_logger() 48 | 49 | summary_path = os.path.join(self.log_dir, self.SUMMARY_DIR_NAME) 50 | self.tf_board_logger = SummaryWriter(summary_path) 51 | 52 | self.metrics_writer = open(os.path.join( 53 | self.log_dir, self.METRICS_FILE_NAME), 'at') 54 | 55 | self.timestamp = time.time() 56 | self.logged = -1 57 | self.speed = None 58 | self.eta_time = None 59 | 60 | def _make_storage(self): 61 | application = os.path.basename(os.getcwd()) 62 | storage_dir = os.path.join( 63 | self.database_dir, self.log_dir, application) 64 | if not os.path.exists(storage_dir): 65 | os.makedirs(storage_dir) 66 | if not os.path.exists(self.log_dir): 67 | os.symlink(storage_dir, self.log_dir) 68 | 69 | def save_dir(self, dir_name): 70 | return os.path.join(self.log_dir, dir_name) 71 | 72 | def _init_message_logger(self): 73 | message_logger = logging.getLogger('messages') 74 | message_logger.setLevel( 75 | logging.DEBUG if self.verbose else logging.INFO) 76 | formatter = logging.Formatter( 77 | '[%(levelname)s] [%(asctime)s] %(message)s') 78 | std_handler = logging.StreamHandler() 79 | std_handler.setLevel(message_logger.level) 80 | std_handler.setFormatter(formatter) 81 | 82 | file_handler = logging.FileHandler( 83 | os.path.join(self.log_dir, self.LOG_FILE_NAME)) 84 | file_handler.setLevel(message_logger.level) 85 | file_handler.setFormatter(formatter) 86 | 87 | message_logger.addHandler(std_handler) 88 | message_logger.addHandler(file_handler) 89 | return message_logger 90 | 91 | def report_time(self, name: str): 92 | if self.verbose: 93 | self.info(name + " time :" + str(time.time() - self.timestamp)) 94 | self.timestamp = time.time() 95 | 96 | def report_eta(self, steps, total, epoch): 97 | self.logged = self.logged % total + 1 98 | steps = steps % total 99 | if self.eta_time is None: 100 | self.eta_time = time.time() 101 | speed = -1 102 | else: 103 | eta_time = time.time() 104 | speed = eta_time - self.eta_time 105 | if self.speed is not None: 106 | speed = ((self.logged - 1) * self.speed + speed) / self.logged 107 | self.speed = speed 108 | self.eta_time = eta_time 109 | 110 | seconds = (total - steps) * speed 111 | hours = seconds // 3600 112 | minutes = (seconds - (hours * 3600)) // 60 113 | seconds = seconds % 60 114 | 115 | print('%d/%d batches processed in epoch %d, ETA: %2d:%2d:%2d' % 116 | (steps, total, epoch, 117 | hours, minutes, seconds), end='\r') 118 | 119 | def args(self, parameters=None): 120 | if parameters is None: 121 | with open(os.path.join(self.log_dir, self.ARGS_FILE_NAME), 'rt') as reader: 122 | return yaml.load(reader.read()) 123 | with open(os.path.join(self.log_dir, self.ARGS_FILE_NAME), 'wt') as writer: 124 | yaml.dump(parameters.dump(), writer) 125 | 126 | def metrics(self, epoch, steps, metrics_dict): 127 | results = {} 128 | for name, a in metrics_dict.items(): 129 | results[name] = {'count': a.count, 'value': float(a.avg)} 130 | self.add_scalar('metrics/' + name, a.avg, steps) 131 | result_dict = { 132 | str(datetime.now()): { 133 | 'epoch': epoch, 134 | 'steps': steps, 135 | **results 136 | } 137 | } 138 | string_result = yaml.dump(result_dict) 139 | self.info(string_result) 140 | self.metrics_writer.write(string_result) 141 | self.metrics_writer.flush() 142 | 143 | def named_number(self, name, num=None, default=0): 144 | if num is None: 145 | return int(self.has_signal(name)) or default 146 | else: 147 | with open(os.path.join(self.log_dir, name), 'w') as writer: 148 | writer.write(str(num)) 149 | return num 150 | 151 | epoch = 
functools.partialmethod(named_number, 'epoch') 152 | iter = functools.partialmethod(named_number, 'iter') 153 | 154 | def message(self, level, content): 155 | self.message_logger.__getattribute__(level)(content) 156 | 157 | def images(self, prefix, image_dict, step): 158 | for name, image in image_dict.items(): 159 | self.add_image(prefix + '/' + name, image, step, dataformats='HWC') 160 | 161 | def merge_save_images(self, name, images): 162 | for i, image in enumerate(images): 163 | if i == 0: 164 | result = image 165 | else: 166 | result = np.concatenate([result, image], 0) 167 | cv2.imwrite(os.path.join(self.vis_dir(), name+'.jpg'), result) 168 | 169 | def vis_dir(self): 170 | vis_dir = os.path.join(self.log_dir, self.VISUALIZE_NAME) 171 | if not os.path.exists(vis_dir): 172 | os.mkdir(vis_dir) 173 | return vis_dir 174 | 175 | def save_image_dict(self, images, max_size=1024): 176 | for file_name, image in images.items(): 177 | height, width = image.shape[:2] 178 | if height > width: 179 | actual_height = min(height, max_size) 180 | actual_width = int(round(actual_height * width / height)) 181 | else: 182 | actual_width = min(width, max_size) 183 | actual_height = int(round(actual_width * height / width)) 184 | image = cv2.resize(image, (actual_width, actual_height)) 185 | cv2.imwrite(os.path.join(self.vis_dir(), file_name+'.jpg'), image) 186 | 187 | def __getattr__(self, name): 188 | message_levels = set(['debug', 'info', 'warning', 'error', 'critical']) 189 | if name == '__setstate__': 190 | raise AttributeError(name) 191 | if name in message_levels: 192 | return functools.partial(self.message, name) 193 | elif hasattr(self.__dict__.get('tf_board_logger'), name): 194 | return self.tf_board_logger.__getattribute__(name) 195 | else: 196 | raise AttributeError(name) 197 | -------------------------------------------------------------------------------- /concern/signal_monitor.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class SignalMonitor(object): 5 | def __init__(self, file_path): 6 | self.file_path = file_path 7 | 8 | def get_signal(self): 9 | if self.file_path is None: 10 | return None 11 | if os.path.exists(self.file_path): 12 | with open(self.file_path) as f: 13 | data = f.read() 14 | os.remove(self.file_path) 15 | return data 16 | else: 17 | return None 18 | -------------------------------------------------------------------------------- /concern/visualizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : visualizer.py 4 | # Author : Zhaoyi Wan 5 | # Date : 08.01.2019 6 | # Last Modified Date: 02.12.2019 7 | # Last Modified By : Minghui Liao 8 | import torch 9 | import numpy as np 10 | import cv2 11 | 12 | class Visualize: 13 | @classmethod 14 | def visualize(cls, x): 15 | dimension = len(x.shape) 16 | if dimension == 2: 17 | pass 18 | elif dimension == 3: 19 | pass 20 | 21 | @classmethod 22 | def to_np(cls, x): 23 | return x.cpu().data.numpy() 24 | 25 | @classmethod 26 | def visualize_weights(cls, tensor, format='HW', normalize=True): 27 | if isinstance(tensor, torch.Tensor): 28 | x = cls.to_np(tensor.permute(format.index('H'), format.index('W'))) 29 | else: 30 | x = tensor.transpose(format.index('H'), format.index('W')) 31 | if normalize: 32 | x = (x - x.min()) / (x.max() - x.min()) 33 | # return np.tile(x * 255., (3, 1, 1)).swapaxes(0, 2).swapaxes(1, 0).astype(np.uint8) 34 | return cv2.applyColorMap((x * 255).astype(np.uint8),
cv2.COLORMAP_JET) 35 | 36 | @classmethod 37 | def visualize_points(cls, image, tensor, radius=5, normalized=True): 38 | if isinstance(tensor, torch.Tensor): 39 | points = cls.to_np(tensor) 40 | else: 41 | points = tensor 42 | if normalized: 43 | points = points * image.shape[:2][::-1] 44 | for i in range(points.shape[0]): 45 | color = np.random.randint( 46 | 0, 255, (3, ), dtype=np.uint8).astype(np.float) 47 | image = cv2.circle(image, 48 | tuple(points[i].astype(np.int32).tolist()), 49 | radius, color, thickness=radius//2) 50 | return image 51 | 52 | @classmethod 53 | def visualize_heatmap(cls, tensor, format='CHW'): 54 | if isinstance(tensor, torch.Tensor): 55 | x = cls.to_np(tensor.permute(format.index('H'), 56 | format.index('W'), format.index('C'))) 57 | else: 58 | x = tensor.transpose( 59 | format.index('H'), format.index('W'), format.index('C')) 60 | canvas = np.zeros((x.shape[0], x.shape[1], 3), dtype=np.float) 61 | 62 | for c in range(0, x.shape[-1]): 63 | color = np.random.randint( 64 | 0, 255, (3, ), dtype=np.uint8).astype(np.float) 65 | canvas += np.tile(x[:, :, c], (3, 1, 1) 66 | ).swapaxes(0, 2).swapaxes(1, 0) * color 67 | 68 | canvas = canvas.astype(np.uint8) 69 | return canvas 70 | 71 | @classmethod 72 | def visualize_classes(cls, x): 73 | canvas = np.zeros((x.shape[0], x.shape[1], 3), dtype=np.uint8) 74 | for c in range(int(x.max())): 75 | color = np.random.randint( 76 | 0, 255, (3, ), dtype=np.uint8).astype(np.float) 77 | canvas[np.where(x == c)] = color 78 | return canvas 79 | 80 | @classmethod 81 | def visualize_grid(cls, x, y, stride=16, color=(0, 0, 255), canvas=None): 82 | h, w = x.shape 83 | if canvas is None: 84 | canvas = np.zeros((h, w, 3), dtype=np.uint8) 85 | # canvas = np.concatenate([canvas, canvas], axis=1) 86 | i, j = 0, 0 87 | while i < w: 88 | j = 0 89 | while j < h: 90 | canvas = cv2.circle(canvas, (int(x[i, j] * w + 0.5), int(y[i, j] * h + 0.5)), radius=max(stride//4, 1), color=color, thickness=stride//8) 91 | j += stride 92 | i += stride 93 | return canvas 94 | 95 | @classmethod 96 | def visualize_rect(cls, canvas, _rect, color=(0, 0, 255)): 97 | rect = (_rect + 0.5).astype(np.int32) 98 | return cv2.rectangle(canvas, (rect[0], rect[1]), (rect[2], rect[3]), color) 99 | -------------------------------------------------------------------------------- /concern/webcv2/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env mdl 2 | class WebCV2: 3 | def __init__(self): 4 | import cv2 5 | self._cv2 = cv2 6 | from .manager import global_manager as gm 7 | self._gm = gm 8 | 9 | def __getattr__(self, name): 10 | if hasattr(self._gm, name): 11 | return getattr(self._gm, name) 12 | elif hasattr(self._cv2, name): 13 | return getattr(self._cv2, name) 14 | else: 15 | raise AttributeError 16 | 17 | import sys 18 | sys.modules[__name__] = WebCV2() 19 | 20 | -------------------------------------------------------------------------------- /concern/webcv2/manager.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env mdl 2 | import socket 3 | import base64 4 | import cv2 5 | import numpy as np 6 | from collections import OrderedDict 7 | 8 | from .server import get_server 9 | 10 | 11 | def jpeg_encode(img): 12 | return cv2.imencode('.png', img)[1] 13 | 14 | 15 | def get_free_port(rng, low=2000, high=10000): 16 | in_use = True 17 | while in_use: 18 | port = rng.randint(high - low) + low 19 | in_use = False 20 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 21 | 
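# Probe availability by binding; errno 98 (EADDRINUSE) means the port is
# already taken, so the loop retries with another random port.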
try: 22 | s.bind(("0.0.0.0", port)) 23 | except socket.error as e: 24 | if e.errno == 98: # port already in use 25 | in_use = True 26 | s.close() 27 | return port 28 | 29 | 30 | class Manager: 31 | def __init__(self, img_encode_method=jpeg_encode, rng=None): 32 | self._queue = OrderedDict() 33 | self._server = None 34 | self.img_encode_method = img_encode_method 35 | if rng is None: 36 | rng = np.random.RandomState(self.get_default_seed()) 37 | self.rng = rng 38 | 39 | def get_default_seed(self): 40 | return 0 41 | 42 | def imshow(self, title, img): 43 | data = self.img_encode_method(img) 44 | data = base64.b64encode(data) 45 | data = data.decode('utf8') 46 | self._queue[title] = data 47 | 48 | def waitKey(self, delay=0): 49 | if self._server is None: 50 | self.port = get_free_port(self.rng) 51 | self._server, self._conn = get_server(port=self.port) 52 | self._conn.send([delay, list(self._queue.items())]) 53 | # self._queue = OrderedDict() 54 | return self._conn.recv() 55 | 56 | global_manager = Manager() 57 | 58 | -------------------------------------------------------------------------------- /concern/webcv2/server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env mdl 2 | import os 3 | BASE_DIR = os.path.dirname(os.path.realpath(__file__)) 4 | import time 5 | import json 6 | import select 7 | import traceback 8 | import socket 9 | from multiprocessing import Process, Pipe 10 | 11 | import gevent 12 | from gevent.pywsgi import WSGIServer 13 | from geventwebsocket.handler import WebSocketHandler 14 | from flask import Flask, request, render_template, abort 15 | 16 | 17 | def log_important_msg(msg, *, padding=3): 18 | msg_len = len(msg) 19 | width = msg_len + padding * 2 + 2 20 | print('#' * width) 21 | print('#' + ' ' * (width - 2) + '#') 22 | print('#' + ' ' * padding + msg + ' ' * padding + '#') 23 | print('#' + ' ' * (width - 2) + '#') 24 | print('#' * width) 25 | 26 | 27 | def hint_url(url, port): 28 | log_important_msg( 29 | 'The server is running at: {}'.format(url)) 30 | 31 | 32 | def _set_server(conn, name='webcv2', port=7788): 33 | package = None 34 | package_alive = False 35 | 36 | app = Flask(name) 37 | app.root_path = BASE_DIR 38 | 39 | @app.route('/') 40 | def index(): 41 | return render_template('index.html', title=name) 42 | 43 | @app.route('/stream') 44 | def stream(): 45 | def poll_ws(ws, delay): 46 | return len(select.select([ws.stream.handler.rfile], [], [], delay / 1000.)[0]) > 0 47 | 48 | if request.environ.get('wsgi.websocket'): 49 | ws = request.environ['wsgi.websocket'] 50 | if ws is None: 51 | abort(404) 52 | else: 53 | should_send = True 54 | while not ws.closed: 55 | nonlocal package 56 | nonlocal package_alive 57 | if conn.poll(): 58 | package = conn.recv() 59 | package_alive = True 60 | should_send = True 61 | if not should_send: 62 | continue 63 | should_send = False 64 | if package is None: 65 | ws.send(None) 66 | else: 67 | delay, info_lst = package 68 | ws.send(json.dumps((time.time(), package_alive, delay, info_lst))) 69 | if package_alive: 70 | if delay <= 0 or poll_ws(ws, delay): 71 | message = ws.receive() 72 | if ws.closed or message is None: 73 | break 74 | try: 75 | if isinstance(message, bytes): 76 | message = message.decode('utf8') 77 | message = int(message) 78 | except: 79 | traceback.print_exc() 80 | message = -1 81 | else: 82 | message = -1 83 | conn.send(message) 84 | package_alive = False 85 | return "" 86 | 87 | http_server = WSGIServer(('', port), app, handler_class=WebSocketHandler) 88 |
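# Announce the address, then block serving requests; this function runs in a
# child process (see get_server below) and talks to the parent only through `conn`.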
hint_url('http://{}:{}'.format(socket.getfqdn(), port), port) 89 | http_server.serve_forever() 90 | 91 | 92 | def get_server(name='webcv2', port=7788): 93 | conn_server, conn_factory = Pipe() 94 | p_server = Process( 95 | target=_set_server, 96 | args=(conn_server,), 97 | kwargs=dict( 98 | name=name, port=port, 99 | ), 100 | ) 101 | p_server.daemon = True 102 | p_server.start() 103 | return p_server, conn_factory 104 | 105 | -------------------------------------------------------------------------------- /concern/webcv2/templates/index.html: -------------------------------------------------------------------------------- [The markup of this template (146 lines) was lost in text extraction; only the template expressions survive. The page titles itself with {{title}}, acts as the browser client of the /stream WebSocket endpoint defined in server.py above, and for each received package shows every imshow entry's title ({{obj[0]}}) above its image, together with a transfer-stats line: Network: {{net_speed / 1000000 | decimal(2)}} MB/s * {{download_time | decimal(2)}} s = {{package_size / 1000000 | decimal(2)}} MB.] -------------------------------------------------------------------------------- /convert_to_onnx.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import numpy as np 5 | from concern.config import Configurable, Config 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser(description='Convert model to ONNX') 10 | parser.add_argument('exp', type=str) 11 | parser.add_argument('resume', type=str, help='Resume from checkpoint') 12 | parser.add_argument('output', type=str, help='Output ONNX path') 13 | 14 | args = parser.parse_args() 15 | args = vars(args) 16 | args = {k: v for k, v in args.items() if v is not None} 17 | 18 | conf = Config() 19 | experiment_args = conf.compile(conf.load(args['exp']))['Experiment'] 20 | experiment_args.update(cmd=args) 21 | experiment = Configurable.construct_class_from_config(experiment_args) 22 | 23 | Demo(experiment, experiment_args, cmd=args).inference() 24 | 25 | 26 | class Demo: 27 | def __init__(self, experiment, args, cmd=dict()): 28 | self.RGB_MEAN = np.array([122.67891434, 116.66876762, 104.00698793]) 29 | self.experiment = experiment 30 | experiment.load('evaluation', **args) 31 | self.args = cmd 32 | self.structure = experiment.structure 33 | self.model_path = self.args['resume'] 34 | self.output_path = self.args['output'] 35 | 36 | def init_torch_tensor(self): 37 | # Use gpu or not 38 | if torch.cuda.is_available(): 39 | self.device = torch.device('cuda') 40 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 41 | else: 42 | self.device = torch.device('cpu') 43 | torch.set_default_tensor_type('torch.FloatTensor') 44 | 45 | def init_model(self): 46 | model = self.structure.builder.build(self.device) 47 | return model 48 | 49 | def resume(self, model, path): 50 | if not os.path.exists(path): 51 | print("Checkpoint not found: " + path) 52 | return 53 | states = torch.load(path, map_location=self.device) 54 | model.load_state_dict(states, strict=False) 55 | print("Resumed from " + path) 56 | 57 | def inference(self): 58 | self.init_torch_tensor() 59 | model = self.init_model() 60 | self.resume(model, self.model_path) 61 | model.eval() 62 | 63 | img = np.random.randint(0, 255, size=(960, 960, 3), dtype=np.uint8) 64 | img = img.astype(np.float32) 65 | img = (img / 255.
- 0.5) / 0.5 # torch style norm 66 | img = img.transpose((2, 0, 1)) 67 | img = torch.from_numpy(img).unsqueeze(0).float() 68 | dynamic_axes = {'input': {0: 'batch_size', 2: 'height', 3: 'width'}, 69 | 'output': {0: 'batch_size', 2: 'height', 3: 'width'}} 70 | with torch.no_grad(): 71 | img = img.to(self.device) 72 | torch.onnx.export(model.model.module, img, self.output_path, input_names=['input'], 73 | output_names=['output'], dynamic_axes=dynamic_axes, keep_initializers_as_inputs=False, 74 | verbose=False, opset_version=12) 75 | 76 | 77 | if __name__ == '__main__': 78 | main() 79 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_seg_detector_data import MakeSegDetectorData 2 | from .transform_data import TransformData 3 | from .random_crop_aug import RandomCropAug 4 | from .make_border_map import MakeBorderMap 5 | from .image_dataset import ImageDataset 6 | -------------------------------------------------------------------------------- /data/augmenter.py: -------------------------------------------------------------------------------- 1 | import imgaug 2 | import imgaug.augmenters as iaa 3 | 4 | from concern.config import Configurable, State 5 | 6 | 7 | class AugmenterBuilder(object): 8 | def __init__(self): 9 | pass 10 | 11 | def build(self, args, root=True): 12 | if args is None: 13 | return None 14 | elif isinstance(args, (int, float, str)): 15 | return args 16 | elif isinstance(args, list): 17 | if root: 18 | sequence = [self.build(value, root=False) for value in args] 19 | return iaa.Sequential(sequence) 20 | else: 21 | return getattr(iaa, args[0])( 22 | *[self.to_tuple_if_list(a) for a in args[1:]]) 23 | elif isinstance(args, dict): 24 | if 'cls' in args: 25 | cls = getattr(iaa, args['cls']) 26 | return cls(**{k: self.to_tuple_if_list(v) for k, v in args.items() if not k == 'cls'}) 27 | else: 28 | return {key: self.build(value, root=False) for key, value in args.items()} 29 | else: 30 | raise RuntimeError('unknown augmenter arg: ' + str(args)) 31 | 32 | def to_tuple_if_list(self, obj): 33 | if isinstance(obj, list): 34 | return tuple(obj) 35 | return obj 36 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset as TorchDataset 2 | 3 | from concern.config import Configurable, State 4 | 5 | 6 | class SliceDataset(TorchDataset, Configurable): 7 | dataset = State() 8 | start = State() 9 | end = State() 10 | 11 | def __init__(self, **kwargs): 12 | self.load_all(**kwargs) 13 | 14 | if self.start is None: 15 | self.start = 0 16 | if self.end is None: 17 | self.end = len(self.dataset) 18 | 19 | def __getitem__(self, idx): 20 | return self.dataset[self.start + idx] 21 | 22 | def __len__(self): 23 | return self.end - self.start 24 | -------------------------------------------------------------------------------- /data/image_dataset.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | import bisect 4 | 5 | import torch.utils.data as data 6 | import cv2 7 | import numpy as np 8 | import glob 9 | from concern.config import Configurable, State 10 | import math 11 | 12 | class ImageDataset(data.Dataset, Configurable): 13 | r'''Dataset reading from images. 
14 | Args: 15 | processes: A series of callable objects which accept the data dict as parameter and return it, 16 | typically inheriting from the `DataProcess` (data/processes/data_process.py) class. 17 | ''' 18 | data_dir = State() 19 | data_list = State() 20 | processes = State(default=[]) 21 | 22 | def __init__(self, data_dir=None, data_list=None, cmd={}, **kwargs): 23 | self.load_all(**kwargs) 24 | self.data_dir = data_dir or self.data_dir 25 | self.data_list = data_list or self.data_list 26 | if 'train' in self.data_list[0]: 27 | self.is_training = True 28 | else: 29 | self.is_training = False 30 | self.debug = cmd.get('debug', False) 31 | self.image_paths = [] 32 | self.gt_paths = [] 33 | self.get_all_samples() 34 | 35 | def get_all_samples(self): 36 | for i in range(len(self.data_dir)): 37 | with open(self.data_list[i], 'r') as fid: 38 | image_list = fid.readlines() 39 | if self.is_training: 40 | image_path=[self.data_dir[i]+'/train_images/'+timg.strip() for timg in image_list] 41 | gt_path=[self.data_dir[i]+'/train_gts/'+timg.strip()+'.txt' for timg in image_list] 42 | else: 43 | image_path=[self.data_dir[i]+'/test_images/'+timg.strip() for timg in image_list] 44 | print(self.data_dir[i]) 45 | if 'TD500' in self.data_list[i] or 'total_text' in self.data_list[i]: 46 | gt_path=[self.data_dir[i]+'/test_gts/'+timg.strip()+'.txt' for timg in image_list] 47 | else: 48 | gt_path=[self.data_dir[i]+'/test_gts/'+'gt_'+timg.strip().split('.')[0]+'.txt' for timg in image_list] 49 | self.image_paths += image_path 50 | self.gt_paths += gt_path 51 | self.num_samples = len(self.image_paths) 52 | self.targets = self.load_ann() 53 | if self.is_training: 54 | assert len(self.image_paths) == len(self.targets) 55 | 56 | def load_ann(self): 57 | res = [] 58 | for gt in self.gt_paths: 59 | lines = [] 60 | reader = open(gt, 'r').readlines() 61 | for line in reader: 62 | item = {} 63 | parts = line.strip().split(',') 64 | label = parts[-1] 65 | if 'TD' in self.data_dir[0] and label == '1': 66 | label = '###' 67 | line = [i.strip('\ufeff').strip('\xef\xbb\xbf') for i in parts] 68 | if 'icdar' in self.data_dir[0]: 69 | poly = np.array(list(map(float, line[:8]))).reshape((-1, 2)).tolist() 70 | else: 71 | num_points = math.floor((len(line) - 1) / 2) * 2 72 | poly = np.array(list(map(float, line[:num_points]))).reshape((-1, 2)).tolist() 73 | item['poly'] = poly 74 | item['text'] = label 75 | lines.append(item) 76 | res.append(lines) 77 | return res 78 | 79 | def __getitem__(self, index, retry=0): 80 | if index >= self.num_samples: 81 | index = index % self.num_samples 82 | data = {} 83 | image_path = self.image_paths[index] 84 | img = cv2.imread(image_path, cv2.IMREAD_COLOR).astype('float32') 85 | if self.is_training: 86 | data['filename'] = image_path 87 | data['data_id'] = image_path 88 | else: 89 | data['filename'] = image_path.split('/')[-1] 90 | data['data_id'] = image_path.split('/')[-1] 91 | data['image'] = img 92 | target = self.targets[index] 93 | data['lines'] = target 94 | if self.processes is not None: 95 | for data_process in self.processes: 96 | data = data_process(data) 97 | return data 98 | 99 | def __len__(self): 100 | return len(self.image_paths) 101 | -------------------------------------------------------------------------------- /data/make_border_map.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numpy as np 3 | import cv2 4 | from shapely.geometry import Polygon 5 | import pyclipper 6 | 7 | from concern.config import Configurable,
State 8 | 9 | 10 | class MakeBorderMap(Configurable): 11 | shrink_ratio = State(default=0.4) 12 | thresh_min = State(default=0.3) 13 | thresh_max = State(default=0.7) 14 | 15 | def __init__(self, cmd={}, *args, **kwargs): 16 | self.load_all(cmd=cmd, **kwargs) 17 | warnings.simplefilter("ignore") 18 | 19 | 20 | def __call__(self, data, *args, **kwargs): 21 | image = data['image'] 22 | polygons = data['polygons'] 23 | ignore_tags = data['ignore_tags'] 24 | 25 | canvas = np.zeros(image.shape[:2], dtype=np.float32) 26 | mask = np.zeros(image.shape[:2], dtype=np.float32) 27 | 28 | for i in range(len(polygons)): 29 | if ignore_tags[i]: 30 | continue 31 | self.draw_border_map(polygons[i], canvas, mask=mask) 32 | canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min 33 | data['thresh_map'] = canvas 34 | data['thresh_mask'] = mask 35 | return data 36 | 37 | def draw_border_map(self, polygon, canvas, mask): 38 | polygon = np.array(polygon) 39 | assert polygon.ndim == 2 40 | assert polygon.shape[1] == 2 41 | 42 | polygon_shape = Polygon(polygon) 43 | distance = polygon_shape.area * \ 44 | (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length 45 | subject = [tuple(l) for l in polygon] 46 | padding = pyclipper.PyclipperOffset() 47 | padding.AddPath(subject, pyclipper.JT_ROUND, 48 | pyclipper.ET_CLOSEDPOLYGON) 49 | padded_polygon = np.array(padding.Execute(distance)[0]) 50 | cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) 51 | 52 | xmin = padded_polygon[:, 0].min() 53 | xmax = padded_polygon[:, 0].max() 54 | ymin = padded_polygon[:, 1].min() 55 | ymax = padded_polygon[:, 1].max() 56 | width = xmax - xmin + 1 57 | height = ymax - ymin + 1 58 | 59 | polygon[:, 0] = polygon[:, 0] - xmin 60 | polygon[:, 1] = polygon[:, 1] - ymin 61 | 62 | xs = np.broadcast_to( 63 | np.linspace(0, width - 1, num=width).reshape(1, width), (height, width)) 64 | ys = np.broadcast_to( 65 | np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width)) 66 | 67 | distance_map = np.zeros( 68 | (polygon.shape[0], height, width), dtype=np.float32) 69 | for i in range(polygon.shape[0]): 70 | j = (i + 1) % polygon.shape[0] 71 | absolute_distance = self.distance(xs, ys, polygon[i], polygon[j]) 72 | distance_map[i] = np.clip(absolute_distance / distance, 0, 1) 73 | distance_map = distance_map.min(axis=0) 74 | 75 | xmin_valid = min(max(0, xmin), canvas.shape[1] - 1) 76 | xmax_valid = min(max(0, xmax), canvas.shape[1] - 1) 77 | ymin_valid = min(max(0, ymin), canvas.shape[0] - 1) 78 | ymax_valid = min(max(0, ymax), canvas.shape[0] - 1) 79 | canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax( 80 | 1 - distance_map[ 81 | ymin_valid-ymin:ymax_valid-ymax+height, 82 | xmin_valid-xmin:xmax_valid-xmax+width], 83 | canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1]) 84 | 85 | def distance(self, xs, ys, point_1, point_2): 86 | ''' 87 | compute the distance from point to a line 88 | ys: coordinates in the first axis 89 | xs: coordinates in the second axis 90 | point_1, point_2: (x, y), the end of the line 91 | ''' 92 | height, width = xs.shape[:2] 93 | square_distance_1 = np.square( 94 | xs - point_1[0]) + np.square(ys - point_1[1]) 95 | square_distance_2 = np.square( 96 | xs - point_2[0]) + np.square(ys - point_2[1]) 97 | square_distance = np.square( 98 | point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1]) 99 | 100 | cosin = (square_distance - square_distance_1 - square_distance_2) / \ 101 | (2 * np.sqrt(square_distance_1 * square_distance_2)) 102 | square_sin = 
1 - np.square(cosin) 103 | square_sin = np.nan_to_num(square_sin) 104 | result = np.sqrt(square_distance_1 * square_distance_2 * 105 | square_sin / square_distance) 106 | 107 | result[cosin < 0] = np.sqrt(np.fmin( 108 | square_distance_1, square_distance_2))[cosin < 0] 109 | # self.extend_line(point_1, point_2, result) 110 | return result 111 | 112 | def extend_line(self, point_1, point_2, result): 113 | ex_point_1 = (int(round(point_1[0] + (point_1[0] - point_2[0]) * (1 + self.shrink_ratio))), 114 | int(round(point_1[1] + (point_1[1] - point_2[1]) * (1 + self.shrink_ratio)))) 115 | cv2.line(result, tuple(ex_point_1), tuple(point_1), 116 | 4096.0, 1, lineType=cv2.LINE_AA, shift=0) 117 | ex_point_2 = (int(round(point_2[0] + (point_2[0] - point_1[0]) * (1 + self.shrink_ratio))), 118 | int(round(point_2[1] + (point_2[1] - point_1[1]) * (1 + self.shrink_ratio)))) 119 | cv2.line(result, tuple(ex_point_2), tuple(point_2), 120 | 4096.0, 1, lineType=cv2.LINE_AA, shift=0) 121 | return ex_point_1, ex_point_2 122 | -------------------------------------------------------------------------------- /data/make_seg_detector_data.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import cv2 5 | from shapely.geometry import Polygon 6 | import pyclipper 7 | 8 | from concern.config import Configurable, State 9 | 10 | 11 | class MakeSegDetectorData(Configurable): 12 | min_text_size = State(default=8) 13 | shrink_ratio = State(default=0.4) 14 | 15 | def __init__(self, **kwargs): 16 | self.load_all(**kwargs) 17 | 18 | def __call__(self, data, *args, **kwargs): 19 | ''' 20 | data: a dict typically returned from `MakeICDARData`, 21 | where the following keys are contained: 22 | image*, polygons*, ignore_tags*, shape, filename 23 | * means required.
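Adds keys: gt (the shrunk text-region map, shape (1, H, W)) and mask (the training mask in which ignored or undersized text instances are zeroed out).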
24 | ''' 25 | image = data['image'] 26 | polygons = data['polygons'] 27 | ignore_tags = data['ignore_tags'] 28 | image = data['image'] 29 | filename = data['filename'] 30 | 31 | h, w = image.shape[:2] 32 | polygons, ignore_tags = self.validate_polygons( 33 | polygons, ignore_tags, h, w) 34 | gt = np.zeros((1, h, w), dtype=np.float32) 35 | mask = np.ones((h, w), dtype=np.float32) 36 | for i in range(polygons.shape[0]): 37 | polygon = polygons[i] 38 | height = min(np.linalg.norm(polygon[0] - polygon[3]), 39 | np.linalg.norm(polygon[1] - polygon[2])) 40 | width = min(np.linalg.norm(polygon[0] - polygon[1]), 41 | np.linalg.norm(polygon[2] - polygon[3])) 42 | if ignore_tags[i] or min(height, width) < self.min_text_size: 43 | cv2.fillPoly(mask, polygon.astype( 44 | np.int32)[np.newaxis, :, :], 0) 45 | ignore_tags[i] = True 46 | else: 47 | polygon_shape = Polygon(polygon) 48 | distance = polygon_shape.area * \ 49 | (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length 50 | subject = [tuple(l) for l in polygons[i]] 51 | padding = pyclipper.PyclipperOffset() 52 | padding.AddPath(subject, pyclipper.JT_ROUND, 53 | pyclipper.ET_CLOSEDPOLYGON) 54 | shrinked = padding.Execute(-distance) 55 | if shrinked == []: 56 | cv2.fillPoly(mask, polygon.astype( 57 | np.int32)[np.newaxis, :, :], 0) 58 | ignore_tags[i] = True 59 | continue 60 | shrinked = np.array(shrinked[0]).reshape(-1, 2) 61 | cv2.fillPoly(gt[0], [shrinked.astype(np.int32)], 1) 62 | 63 | if filename is None: 64 | filename = '' 65 | data.update(image=image, 66 | polygons=polygons, 67 | gt=gt, mask=mask, filename=filename) 68 | return data 69 | 70 | def validate_polygons(self, polygons, ignore_tags, h, w): 71 | ''' 72 | polygons (numpy.array, required): of shape (num_instances, num_points, 2) 73 | ''' 74 | if polygons.shape[0] == 0: 75 | return polygons, ignore_tags 76 | assert polygons.shape[0] == len(ignore_tags) 77 | 78 | polygons[:, :, 0] = np.clip(polygons[:, :, 0], 0, w - 1) 79 | polygons[:, :, 1] = np.clip(polygons[:, :, 1], 0, h - 1) 80 | 81 | for i in range(polygons.shape[0]): 82 | area = self.polygon_area(polygons[i]) 83 | if abs(area) < 1: 84 | ignore_tags[i] = True 85 | if area > 0: 86 | polygons[i] = polygons[i][(0, 3, 2, 1), :] 87 | return polygons, ignore_tags 88 | 89 | def polygon_area(self, polygon): 90 | edge = [ 91 | (polygon[1][0] - polygon[0][0]) * (polygon[1][1] + polygon[0][1]), 92 | (polygon[2][0] - polygon[1][0]) * (polygon[2][1] + polygon[1][1]), 93 | (polygon[3][0] - polygon[2][0]) * (polygon[3][1] + polygon[2][1]), 94 | (polygon[0][0] - polygon[3][0]) * (polygon[0][1] + polygon[3][1]) 95 | ] 96 | return np.sum(edge) / 2. 
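# The shrink offset above follows the DB rule D = area * (1 - shrink_ratio^2) / perimeter.
# A minimal, self-contained sketch of that single step (the quadrangle below is a
# hypothetical toy example, not part of the original file):
if __name__ == '__main__':
    quad = np.array([[0, 0], [100, 0], [100, 40], [0, 40]], dtype=np.float32)
    polygon_shape = Polygon(quad)
    shrink_ratio = 0.4
    distance = polygon_shape.area * \
        (1 - np.power(shrink_ratio, 2)) / polygon_shape.length  # 4000 * 0.84 / 280 = 12.0
    padding = pyclipper.PyclipperOffset()
    padding.AddPath([tuple(p) for p in quad], pyclipper.JT_ROUND,
                    pyclipper.ET_CLOSEDPOLYGON)
    shrinked = np.array(padding.Execute(-distance)[0]).reshape(-1, 2)
    print(shrinked)  # the inner quadrangle that would be rendered into `gt`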
97 | -------------------------------------------------------------------------------- /data/meta_loader.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import hashlib 3 | import os 4 | import io 5 | import urllib.parse as urlparse 6 | import warnings 7 | import numpy as np 8 | from concern.charset_tool import stringQ2B 9 | from hanziconv import HanziConv 10 | from concern.config import Configurable, State 11 | from data.text_lines import TextLines 12 | 13 | class DataIdMetaLoader(MetaLoader): 14 | return_dict = State(default=False) 15 | scan_meta = False 16 | 17 | def __init__(self, return_dict=None, cmd={}, **kwargs): 18 | super().__init__(cmd=cmd, **kwargs) 19 | if return_dict is not None: 20 | self.return_dict = return_dict 21 | 22 | def parse_meta(self, data_id): 23 | return dict(data_id=data_id) 24 | 25 | def post_prosess(self, meta): 26 | if self.return_dict: 27 | return meta 28 | return meta['data_id'] 29 | 30 | class MetaCache(Configurable): 31 | META_FILE = 'meta_cache.pickle' 32 | client = State(default='all') 33 | 34 | def __init__(self, **kwargs): 35 | self.load_all(**kwargs) 36 | 37 | def cache(self, nori_path, meta=None): 38 | if meta is None: 39 | return self.read(nori_path) 40 | else: 41 | return self.save(nori_path, meta) 42 | 43 | def read(self, nori_path): 44 | raise NotImplementedError 45 | 46 | def save(self, nori_path, meta): 47 | raise NotImplementedError 48 | 49 | 50 | class FileMetaCache(MetaCache): 51 | storage_dir = State(default='/data/.meta_cache') 52 | 53 | def __init__(self, storage_dir=None, cmd={}, **kwargs): 54 | super(FileMetaCache, self).__init__(cmd=cmd, **kwargs) 55 | 56 | self.storage_dir = cmd.get('storage_dir', self.storage_dir) 57 | if storage_dir is not None: 58 | self.storage_dir = storage_dir 59 | self.debug = cmd.get('debug', False) 60 | 61 | def ensure_dir(self): 62 | if not os.path.exists(self.storage_dir): 63 | os.makedirs(self.storage_dir) 64 | 65 | def storage_path(self, nori_path): 66 | return os.path.join(self.storage_dir, self.hash(nori_path) + '.pickle') 67 | 68 | def hash(self, nori_path: str): 69 | return hashlib.md5(nori_path.encode('utf-8')).hexdigest() + '-' + self.client 70 | 71 | def read(self, nori_path): 72 | file_path = self.storage_path(nori_path) 73 | if not os.path.exists(file_path): 74 | warnings.warn( 75 | 'Meta cache not found: ' + file_path) 76 | warnings.warn('Now trying to read meta from nori') 77 | return None 78 | with open(file_path, 'rb') as reader: 79 | try: 80 | return pickle.load(reader) 81 | except EOFError as e: # recover from broken file 82 | if self.debug: 83 | raise e 84 | return None 85 | 86 | def save(self, nori_path, meta): 87 | self.ensure_dir() 88 | 89 | with open(self.storage_path(nori_path), 'wb') as writer: 90 | pickle.dump(meta, writer) 91 | return True 92 | -------------------------------------------------------------------------------- /data/processes/__init__.py: -------------------------------------------------------------------------------- 1 | from .normalize_image import NormalizeImage 2 | from .make_center_points import MakeCenterPoints 3 | from .resize_image import ResizeImage, ResizeData 4 | from .filter_keys import FilterKeys 5 | from .make_center_map import MakeCenterMap 6 | from .augment_data import AugmentData, AugmentDetectionData 7 | from .random_crop_data import RandomCropData 8 | from .make_icdar_data import MakeICDARData, ICDARCollectFN 9 | from .make_seg_detection_data import MakeSegDetectionData 10 | from
.make_border_map import MakeBorderMap 11 | -------------------------------------------------------------------------------- /data/processes/augment_data.py: -------------------------------------------------------------------------------- 1 | import imgaug 2 | import numpy as np 3 | 4 | from concern.config import State 5 | from .data_process import DataProcess 6 | from data.augmenter import AugmenterBuilder 7 | import cv2 8 | import math 9 | 10 | 11 | class AugmentData(DataProcess): 12 | augmenter_args = State(autoload=False) 13 | 14 | def __init__(self, **kwargs): 15 | self.augmenter_args = kwargs.get('augmenter_args') 16 | self.keep_ratio = kwargs.get('keep_ratio') 17 | self.only_resize = kwargs.get('only_resize') 18 | self.augmenter = AugmenterBuilder().build(self.augmenter_args) 19 | 20 | def may_augment_annotation(self, aug, data, shape): 21 | pass 22 | 23 | def resize_image(self, image): 24 | origin_height, origin_width, _ = image.shape 25 | resize_shape = self.augmenter_args[0][1] 26 | height = resize_shape['height'] 27 | width = resize_shape['width'] 28 | if self.keep_ratio: 29 | width = origin_width * height / origin_height 30 | N = math.ceil(width / 32) 31 | width = N * 32 32 | image = cv2.resize(image, (width, height)) 33 | return image 34 | 35 | def process(self, data): 36 | image = data['image'] 37 | aug = None 38 | shape = image.shape 39 | 40 | if self.augmenter: 41 | aug = self.augmenter.to_deterministic() 42 | if self.only_resize: 43 | data['image'] = self.resize_image(image) 44 | else: 45 | data['image'] = aug.augment_image(image) 46 | self.may_augment_annotation(aug, data, shape) 47 | 48 | filename = data.get('filename', data.get('data_id', '')) 49 | data.update(filename=filename, shape=shape[:2]) 50 | if not self.only_resize: 51 | data['is_training'] = True 52 | else: 53 | data['is_training'] = False 54 | return data 55 | 56 | 57 | class AugmentDetectionData(AugmentData): 58 | def may_augment_annotation(self, aug, data, shape): 59 | if aug is None: 60 | return data 61 | 62 | line_polys = [] 63 | for line in data['lines']: 64 | if self.only_resize: 65 | new_poly = [(p[0], p[1]) for p in line['poly']] 66 | else: 67 | new_poly = self.may_augment_poly(aug, shape, line['poly']) 68 | line_polys.append({ 69 | 'points': new_poly, 70 | 'ignore': line['text'] == '###', 71 | 'text': line['text'], 72 | }) 73 | data['polys'] = line_polys 74 | return data 75 | 76 | def may_augment_poly(self, aug, img_shape, poly): 77 | keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly] 78 | keypoints = aug.augment_keypoints( 79 | [imgaug.KeypointsOnImage(keypoints, shape=img_shape)])[0].keypoints 80 | poly = [(p.x, p.y) for p in keypoints] 81 | return poly 82 | 83 | -------------------------------------------------------------------------------- /data/processes/data_process.py: -------------------------------------------------------------------------------- 1 | from concern.config import Configurable 2 | 3 | 4 | class DataProcess(Configurable): 5 | r'''Processes of data dict.
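Subclasses implement `process(data)`; instances are callable and simply delegate to it.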
6 | ''' 7 | 8 | def __call__(self, data): 9 | return self.process(data) 10 | 11 | def process(self, data): 12 | raise NotImplementedError 13 | 14 | def render_constant(self, canvas, xmin, xmax, ymin, ymax, value=1, shrink=0): 15 | def shrink_rect(xmin, xmax, ratio): 16 | center = (xmin + xmax) / 2 17 | width = center - xmin 18 | return int(center - width * ratio + 0.5), int(center + width * ratio + 0.5) 19 | 20 | if shrink > 0: 21 | xmin, xmax = shrink_rect(xmin, xmax, shrink) 22 | ymin, ymax = shrink_rect(ymin, ymax, shrink) 23 | 24 | canvas[int(ymin+0.5):int(ymax+0.5)+1, int(xmin+0.5):int(xmax+0.5)+1] = value 25 | return canvas 26 | -------------------------------------------------------------------------------- /data/processes/filter_keys.py: -------------------------------------------------------------------------------- 1 | from concern.config import State 2 | 3 | from .data_process import DataProcess 4 | 5 | 6 | class FilterKeys(DataProcess): 7 | required = State(default=[]) 8 | superfluous = State(default=[]) 9 | 10 | def __init__(self, **kwargs): 11 | super().__init__(**kwargs) 12 | 13 | self.required_keys = set(self.required) 14 | self.superfluous_keys = set(self.superfluous) 15 | if len(self.required_keys) > 0 and len(self.superfluous_keys) > 0: 16 | raise ValueError( 17 | 'required_keys and superfluous_keys can not be specified at the same time.') 18 | 19 | def process(self, data): 20 | for key in self.required: 21 | assert key in data, '%s is required in data' % key 22 | 23 | superfluous = set(self.superfluous_keys) 24 | if len(superfluous) == 0: 25 | for key in data.keys(): 26 | if key not in self.required_keys: 27 | superfluous.add(key) 28 | 29 | for key in superfluous: 30 | del data[key] 31 | return data 32 | -------------------------------------------------------------------------------- /data/processes/make_border_map.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numpy as np 3 | import cv2 4 | from shapely.geometry import Polygon 5 | import pyclipper 6 | 7 | from concern.config import State 8 | from .data_process import DataProcess 9 | 10 | 11 | class MakeBorderMap(DataProcess): 12 | r''' 13 | Making the border map from detection data with ICDAR format. 14 | Typically following the process of class `MakeICDARData`.
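Adds `thresh_map`, whose values are scaled into [thresh_min, thresh_max], and `thresh_mask`, which marks the dilated text regions where the map is valid.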
15 | ''' 16 | shrink_ratio = State(default=0.4) 17 | thresh_min = State(default=0.3) 18 | thresh_max = State(default=0.7) 19 | 20 | def __init__(self, cmd={}, *args, **kwargs): 21 | self.load_all(cmd=cmd, **kwargs) 22 | warnings.simplefilter("ignore") 23 | 24 | def process(self, data, *args, **kwargs): 25 | r''' 26 | required keys: 27 | image, polygons, ignore_tags 28 | adding keys: 29 | thresh_map, thresh_mask 30 | ''' 31 | image = data['image'] 32 | polygons = data['polygons'] 33 | ignore_tags = data['ignore_tags'] 34 | canvas = np.zeros(image.shape[:2], dtype=np.float32) 35 | mask = np.zeros(image.shape[:2], dtype=np.float32) 36 | 37 | for i in range(len(polygons)): 38 | if ignore_tags[i]: 39 | continue 40 | self.draw_border_map(polygons[i], canvas, mask=mask) 41 | canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min 42 | data['thresh_map'] = canvas 43 | data['thresh_mask'] = mask 44 | return data 45 | 46 | def draw_border_map(self, polygon, canvas, mask): 47 | polygon = np.array(polygon) 48 | assert polygon.ndim == 2 49 | assert polygon.shape[1] == 2 50 | 51 | polygon_shape = Polygon(polygon) 52 | distance = polygon_shape.area * \ 53 | (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length 54 | subject = [tuple(l) for l in polygon] 55 | padding = pyclipper.PyclipperOffset() 56 | padding.AddPath(subject, pyclipper.JT_ROUND, 57 | pyclipper.ET_CLOSEDPOLYGON) 58 | padded_polygon = np.array(padding.Execute(distance)[0]) 59 | cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) 60 | 61 | xmin = padded_polygon[:, 0].min() 62 | xmax = padded_polygon[:, 0].max() 63 | ymin = padded_polygon[:, 1].min() 64 | ymax = padded_polygon[:, 1].max() 65 | width = xmax - xmin + 1 66 | height = ymax - ymin + 1 67 | 68 | polygon[:, 0] = polygon[:, 0] - xmin 69 | polygon[:, 1] = polygon[:, 1] - ymin 70 | 71 | xs = np.broadcast_to( 72 | np.linspace(0, width - 1, num=width).reshape(1, width), (height, width)) 73 | ys = np.broadcast_to( 74 | np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width)) 75 | 76 | distance_map = np.zeros( 77 | (polygon.shape[0], height, width), dtype=np.float32) 78 | for i in range(polygon.shape[0]): 79 | j = (i + 1) % polygon.shape[0] 80 | absolute_distance = self.distance(xs, ys, polygon[i], polygon[j]) 81 | distance_map[i] = np.clip(absolute_distance / distance, 0, 1) 82 | distance_map = distance_map.min(axis=0) 83 | 84 | xmin_valid = min(max(0, xmin), canvas.shape[1] - 1) 85 | xmax_valid = min(max(0, xmax), canvas.shape[1] - 1) 86 | ymin_valid = min(max(0, ymin), canvas.shape[0] - 1) 87 | ymax_valid = min(max(0, ymax), canvas.shape[0] - 1) 88 | canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax( 89 | 1 - distance_map[ 90 | ymin_valid-ymin:ymax_valid-ymax+height, 91 | xmin_valid-xmin:xmax_valid-xmax+width], 92 | canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1]) 93 | 94 | def distance(self, xs, ys, point_1, point_2): 95 | ''' 96 | compute the distance from point to a line 97 | ys: coordinates in the first axis 98 | xs: coordinates in the second axis 99 | point_1, point_2: (x, y), the end of the line 100 | ''' 101 | height, width = xs.shape[:2] 102 | square_distance_1 = np.square( 103 | xs - point_1[0]) + np.square(ys - point_1[1]) 104 | square_distance_2 = np.square( 105 | xs - point_2[0]) + np.square(ys - point_2[1]) 106 | square_distance = np.square( 107 | point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1]) 108 | 109 | cosin = (square_distance - square_distance_1 - square_distance_2) / \ 110 | 
(2 * np.sqrt(square_distance_1 * square_distance_2)) 111 | square_sin = 1 - np.square(cosin) 112 | square_sin = np.nan_to_num(square_sin) 113 | result = np.sqrt(square_distance_1 * square_distance_2 * 114 | square_sin / square_distance) 115 | 116 | result[cosin < 0] = np.sqrt(np.fmin( 117 | square_distance_1, square_distance_2))[cosin < 0] 118 | # self.extend_line(point_1, point_2, result) 119 | return result 120 | 121 | def extend_line(self, point_1, point_2, result): 122 | ex_point_1 = (int(round(point_1[0] + (point_1[0] - point_2[0]) * (1 + self.shrink_ratio))), 123 | int(round(point_1[1] + (point_1[1] - point_2[1]) * (1 + self.shrink_ratio)))) 124 | cv2.line(result, tuple(ex_point_1), tuple(point_1), 125 | 4096.0, 1, lineType=cv2.LINE_AA, shift=0) 126 | ex_point_2 = (int(round(point_2[0] + (point_2[0] - point_1[0]) * (1 + self.shrink_ratio))), 127 | int(round(point_2[1] + (point_2[1] - point_1[1]) * (1 + self.shrink_ratio)))) 128 | cv2.line(result, tuple(ex_point_2), tuple(point_2), 129 | 4096.0, 1, lineType=cv2.LINE_AA, shift=0) 130 | return ex_point_1, ex_point_2 131 | 132 | -------------------------------------------------------------------------------- /data/processes/make_center_distance_map.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numpy as np 3 | import cv2 4 | from shapely.geometry import Polygon 5 | import pyclipper 6 | 7 | from concern.config import State 8 | from .data_process import DataProcess 9 | 10 | 11 | class MakeCenterDistanceMap(DataProcess): 12 | r''' 13 | Making the center distance map from detection data with ICDAR format. 14 | Typically following the process of class `MakeICDARData`. 15 | ''' 16 | expansion_ratio = State(default=0.1) 17 | 18 | def __init__(self, cmd={}, *args, **kwargs): 19 | self.load_all(cmd=cmd, **kwargs) 20 | warnings.simplefilter("ignore") 21 | 22 | def process(self, data, *args, **kwargs): 23 | r''' 24 | required keys: 25 | image,
26 | lines: Instance of `TextLines`, which is defined in data/text_lines.py 27 | adding keys: 28 | distance_map 29 | ''' 30 | image = data['image'] 31 | lines = data['lines'] 32 | 33 | h, w = image.shape[:2] 34 | canvas = np.zeros(image.shape[:2], dtype=np.float32) 35 | mask = np.zeros(image.shape[:2], dtype=np.float32) 36 | for _, quad in lines: 37 | padded = self.expand_quad(quad) 38 | center_x = padded[:, 0].mean() 39 | center_y = padded[:, 1].mean() 40 | index_x, index_y = np.meshgrid(np.arange(w), np.arange(h)) 41 | self.render_distance_map(canvas, center_x, center_y, index_x, index_y) 42 | self.render_constant(mask, quad, 1) 43 | 44 | canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min 45 | data['thresh_map'] = canvas 46 | return data 47 | 48 | def expand_quad(self, polygon): 49 | polygon = np.array(polygon) 50 | assert polygon.ndim == 2 51 | assert polygon.shape[1] == 2 52 | 53 | polygon_shape = Polygon(polygon) 54 | distance = polygon_shape.area * \ 55 | (1 - np.power(self.expansion_ratio, 2)) / polygon_shape.length 56 | subject = [tuple(l) for l in polygon] 57 | padding = pyclipper.PyclipperOffset() 58 | padding.AddPath(subject, pyclipper.JT_ROUND, 59 | pyclipper.ET_CLOSEDPOLYGON) 60 | padded_polygon = np.array(padding.Execute(distance)[0]) 61 | return padded_polygon 62 | 63 | 64 | 65 | 66 | 67 | def distance(self, xs, ys, point_1, point_2): 68 | ''' 69 | compute the distance from point to a line 70 | ys: coordinates in the first axis 71 | xs: coordinates in the second axis 72 | point_1, point_2: (x, y), the end of the line 73 | ''' 74 | height, width = xs.shape[:2] 75 | square_distance_1 = np.square( 76 | xs - point_1[0]) + np.square(ys - point_1[1]) 77 | square_distance_2 = np.square( 78 | xs - point_2[0]) + np.square(ys - point_2[1]) 79 | square_distance = np.square( 80 | point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1]) 81 | 82 | cosin = (square_distance - square_distance_1 - square_distance_2) / \ 83 | (2 * np.sqrt(square_distance_1 * square_distance_2)) 84 | square_sin = 1 - np.square(cosin) 85 | square_sin = np.nan_to_num(square_sin) 86 | result = np.sqrt(square_distance_1 * square_distance_2 * 87 | square_sin / square_distance) 88 | 89 | result[cosin < 0] = np.sqrt(np.fmin( 90 | square_distance_1, square_distance_2))[cosin < 0] 91 | # self.extend_line(point_1, point_2, result) 92 | return result 93 | 94 | def extend_line(self, point_1, point_2, result): 95 | ex_point_1 = (int(round(point_1[0] + (point_1[0] - point_2[0]) * (1 + self.shrink_ratio))), 96 | int(round(point_1[1] + (point_1[1] - point_2[1]) * (1 + self.shrink_ratio)))) 97 | cv2.line(result, tuple(ex_point_1), tuple(point_1), 98 | 4096.0, 1, lineType=cv2.LINE_AA, shift=0) 99 | ex_point_2 = (int(round(point_2[0] + (point_2[0] - point_1[0]) * (1 + self.shrink_ratio))), 100 | int(round(point_2[1] + (point_2[1] - point_1[1]) * (1 + self.shrink_ratio)))) 101 | cv2.line(result, tuple(ex_point_2), tuple(point_2), 102 | 4096.0, 1, lineType=cv2.LINE_AA, shift=0) 103 | return ex_point_1, ex_point_2 104 | 105 | 106 | -------------------------------------------------------------------------------- /data/processes/make_center_map.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.ndimage.filters as fi 3 | 4 | from concern.config import State 5 | 6 | from .data_process import DataProcess 7 | 8 | 9 | class MakeCenterMap(DataProcess): 10 | max_size = State(default=32) 11 | shape
= State(default=(64, 256)) 12 | sigma_ratio = State(default=16) 13 | function_name = State(default='sample_gaussian') 14 | points_key = 'points' 15 | correlation = 0 # The formulation of gaussian is simplified when correlation is 0 16 | 17 | def process(self, data): 18 | assert self.points_key in data, '%s in data is required' % self.points_key 19 | points = data['points'] * self.shape[::-1] # N, 2 20 | assert points.shape[0] >= self.max_size 21 | func = getattr(self, self.function_name) 22 | data['charmaps'] = func(points, *self.shape) 23 | return data 24 | 25 | def gaussian(self, points, height, width): 26 | index_x, index_y = np.meshgrid(np.linspace(0, width, width), 27 | np.linspace(0, height, height)) 28 | index_x = np.repeat(index_x[np.newaxis], points.shape[0], axis=0) 29 | index_y = np.repeat(index_y[np.newaxis], points.shape[0], axis=0) 30 | mu_x = points[:, 0][:, np.newaxis, np.newaxis] 31 | mu_y = points[:, 1][:, np.newaxis, np.newaxis] 32 | mask_is_zero = ((mu_x == 0) + (mu_y == 0)) == 0 33 | result = np.reciprocal(2 * np.pi * width / self.sigma_ratio * height / self.sigma_ratio)\ 34 | * np.exp(- 0.5 * (np.square((index_x - mu_x) / width * self.sigma_ratio) + 35 | np.square((index_y - mu_y) / height * self.sigma_ratio))) 36 | 37 | result = result / \ 38 | np.maximum(result.max(axis=1, keepdims=True).max( 39 | axis=2, keepdims=True), np.finfo(np.float32).eps) 40 | result = result * mask_is_zero 41 | return result.astype(np.float32) 42 | 43 | def sample_gaussian(self, points, height, width): 44 | points = (points + 0.5).astype(np.int32) 45 | canvas = np.zeros((self.max_size, height, width), dtype=np.float32) 46 | for index in range(canvas.shape[0]): 47 | point = points[index] 48 | canvas[index, point[1], point[0]] = 1. 49 | if point.sum() > 0: 50 | fi.gaussian_filter(canvas[index], (height // self.sigma_ratio, 51 | width // self.sigma_ratio), 52 | output=canvas[index], mode='mirror') 53 | canvas[index] = canvas[index] / canvas[index].max() 54 | x_range = min(point[0], width - point[0]) 55 | canvas[index, :, :point[0] - x_range] = 0 56 | canvas[index, :, point[0] + x_range:] = 0 57 | y_range = min(point[1], height - point[1]) 58 | canvas[index, :point[1] - y_range, :] = 0 59 | canvas[index, point[1] + y_range:, :] = 0 60 | return canvas 61 | -------------------------------------------------------------------------------- /data/processes/make_center_points.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from concern.config import State 4 | from .data_process import DataProcess 5 | 6 | 7 | class MakeCenterPoints(DataProcess): 8 | box_key = State(default='charboxes') 9 | size = State(default=32) 10 | 11 | def process(self, data): 12 | shape = data['image'].shape[:2] 13 | points = np.zeros((self.size, 2), dtype=np.float32) 14 | boxes = np.array(data[self.box_key])[:self.size] 15 | 16 | size = boxes.shape[0] 17 | points[:size] = boxes.mean(axis=1) 18 | data['points'] = (points / shape[::-1]).astype(np.float32) 19 | return data 20 | -------------------------------------------------------------------------------- /data/processes/make_icdar_data.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from concern.config import Configurable, State 7 | from .data_process import DataProcess 8 | import cv2 9 | 10 | 11 | class MakeICDARData(DataProcess): 12 | shrink_ratio = State(default=0.4) 13 | 14 | def
__init__(self, debug=False, cmd={}, **kwargs): 15 | self.load_all(**kwargs) 16 | 17 | self.debug = debug 18 | if 'debug' in cmd: 19 | self.debug = cmd['debug'] 20 | 21 | def process(self, data): 22 | polygons = [] 23 | ignore_tags = [] 24 | annotations = data['polys'] 25 | for annotation in annotations: 26 | polygons.append(np.array(annotation['points'])) 27 | # polygons.append(annotation['points']) 28 | ignore_tags.append(annotation['ignore']) 29 | ignore_tags = np.array(ignore_tags, dtype=np.uint8) 30 | filename = data.get('filename', data['data_id']) 31 | if self.debug: 32 | self.draw_polygons(data['image'], polygons, ignore_tags) 33 | shape = np.array(data['shape']) 34 | return OrderedDict(image=data['image'], 35 | polygons=polygons, 36 | ignore_tags=ignore_tags, 37 | shape=shape, 38 | filename=filename, 39 | is_training=data['is_training']) 40 | 41 | def draw_polygons(self, image, polygons, ignore_tags): 42 | for i in range(len(polygons)): 43 | polygon = polygons[i].reshape(-1, 2).astype(np.int32) 44 | ignore = ignore_tags[i] 45 | if ignore: 46 | color = (255, 0, 0) # depict ignorable polygons in blue 47 | else: 48 | color = (0, 0, 255) # depict polygons in red 49 | 50 | cv2.polylines(image, [polygon], True, color, 1) 51 | polylines = staticmethod(draw_polygons) 52 | 53 | 54 | class ICDARCollectFN(Configurable): 55 | def __init__(self, *args, **kwargs): 56 | pass 57 | 58 | def __call__(self, batch): 59 | data_dict = OrderedDict() 60 | for sample in batch: 61 | for k, v in sample.items(): 62 | if k not in data_dict: 63 | data_dict[k] = [] 64 | if isinstance(v, np.ndarray): 65 | v = torch.from_numpy(v) 66 | data_dict[k].append(v) 67 | data_dict['image'] = torch.stack(data_dict['image'], 0) 68 | return data_dict 69 | 70 | -------------------------------------------------------------------------------- /data/processes/make_seg_detection_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from shapely.geometry import Polygon 4 | import pyclipper 5 | 6 | from concern.config import State 7 | from .data_process import DataProcess 8 | 9 | 10 | class MakeSegDetectionData(DataProcess): 11 | r''' 12 | Making binary mask from detection data with ICDAR format. 13 | Typically following the process of class `MakeICDARData`. 
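Each polygon is shrunk inward by the offset D = area * (1 - shrink_ratio^2) / perimeter before being rendered into `gt`; instances smaller than `min_text_size` are masked out instead.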
14 | ''' 15 | min_text_size = State(default=8) 16 | shrink_ratio = State(default=0.4) 17 | 18 | def __init__(self, **kwargs): 19 | self.load_all(**kwargs) 20 | 21 | def process(self, data): 22 | ''' 23 | required keys: 24 | image, polygons, ignore_tags, filename 25 | adding keys: 26 | gt, mask 27 | ''' 28 | image = data['image'] 29 | polygons = data['polygons'] 30 | ignore_tags = data['ignore_tags'] 31 | image = data['image'] 32 | filename = data['filename'] 33 | 34 | h, w = image.shape[:2] 35 | if data['is_training']: 36 | polygons, ignore_tags = self.validate_polygons( 37 | polygons, ignore_tags, h, w) 38 | gt = np.zeros((1, h, w), dtype=np.float32) 39 | mask = np.ones((h, w), dtype=np.float32) 40 | for i in range(len(polygons)): 41 | polygon = polygons[i] 42 | height = max(polygon[:, 1]) - min(polygon[:, 1]) 43 | width = max(polygon[:, 0]) - min(polygon[:, 0]) 44 | # height = min(np.linalg.norm(polygon[0] - polygon[3]), 45 | # np.linalg.norm(polygon[1] - polygon[2])) 46 | # width = min(np.linalg.norm(polygon[0] - polygon[1]), 47 | # np.linalg.norm(polygon[2] - polygon[3])) 48 | if ignore_tags[i] or min(height, width) < self.min_text_size: 49 | cv2.fillPoly(mask, polygon.astype( 50 | np.int32)[np.newaxis, :, :], 0) 51 | ignore_tags[i] = True 52 | else: 53 | polygon_shape = Polygon(polygon) 54 | distance = polygon_shape.area * \ 55 | (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length 56 | subject = [tuple(l) for l in polygons[i]] 57 | padding = pyclipper.PyclipperOffset() 58 | padding.AddPath(subject, pyclipper.JT_ROUND, 59 | pyclipper.ET_CLOSEDPOLYGON) 60 | shrinked = padding.Execute(-distance) 61 | if shrinked == []: 62 | cv2.fillPoly(mask, polygon.astype( 63 | np.int32)[np.newaxis, :, :], 0) 64 | ignore_tags[i] = True 65 | continue 66 | shrinked = np.array(shrinked[0]).reshape(-1, 2) 67 | cv2.fillPoly(gt[0], [shrinked.astype(np.int32)], 1) 68 | 69 | if filename is None: 70 | filename = '' 71 | data.update(image=image, 72 | polygons=polygons, 73 | gt=gt, mask=mask, filename=filename) 74 | return data 75 | 76 | def validate_polygons(self, polygons, ignore_tags, h, w): 77 | ''' 78 | polygons (numpy.array, required): of shape (num_instances, num_points, 2) 79 | ''' 80 | if len(polygons) == 0: 81 | return polygons, ignore_tags 82 | assert len(polygons) == len(ignore_tags) 83 | for polygon in polygons: 84 | polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1) 85 | polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1) 86 | 87 | for i in range(len(polygons)): 88 | area = self.polygon_area(polygons[i]) 89 | if abs(area) < 1: 90 | ignore_tags[i] = True 91 | if area > 0: 92 | polygons[i] = polygons[i][::-1, :] 93 | return polygons, ignore_tags 94 | 95 | def polygon_area(self, polygon): 96 | edge = 0 97 | for i in range(polygon.shape[0]): 98 | next_index = (i + 1) % polygon.shape[0] 99 | edge += (polygon[next_index, 0] - polygon[i, 0]) * (polygon[next_index, 1] + polygon[i, 1]) 100 | 101 | return edge / 2. 102 | 103 | -------------------------------------------------------------------------------- /data/processes/normalize_image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from .data_process import DataProcess 5 | 6 | 7 | class NormalizeImage(DataProcess): 8 | RGB_MEAN = np.array([122.67891434, 116.66876762, 104.00698793]) 9 | 10 | def process(self, data): 11 | assert 'image' in data, '`image` in data is required by this process' 12 | image = data['image'] 13 | image -= self.RGB_MEAN 14 | image /= 255.
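# the zero-centered, rescaled array is converted from HWC to a CHW float tensor below; `restore` inverts these steps for visualization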
15 | image = torch.from_numpy(image).permute(2, 0, 1).float() 16 | data['image'] = image 17 | return data 18 | 19 | @classmethod 20 | def restore(self, image): 21 | image = image.permute(1, 2, 0).to('cpu').numpy() 22 | image = image * 255. 23 | image += self.RGB_MEAN 24 | image = image.astype(np.uint8) 25 | return image 26 | -------------------------------------------------------------------------------- /data/processes/random_crop_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | from .data_process import DataProcess 5 | from concern.config import Configurable, State 6 | 7 | 8 | # random crop algorithm similar to https://github.com/argman/EAST 9 | class RandomCropData(DataProcess): 10 | size = State(default=(512, 512)) 11 | max_tries = State(default=50) 12 | min_crop_side_ratio = State(default=0.1) 13 | require_original_image = State(default=False) 14 | 15 | def __init__(self, **kwargs): 16 | self.load_all(**kwargs) 17 | 18 | def process(self, data): 19 | img = data['image'] 20 | ori_img = img 21 | ori_lines = data['polys'] 22 | 23 | all_care_polys = [line['points'] 24 | for line in data['polys'] if not line['ignore']] 25 | crop_x, crop_y, crop_w, crop_h = self.crop_area(img, all_care_polys) 26 | scale_w = self.size[0] / crop_w 27 | scale_h = self.size[1] / crop_h 28 | scale = min(scale_w, scale_h) 29 | h = int(crop_h * scale) 30 | w = int(crop_w * scale) 31 | padimg = np.zeros( 32 | (self.size[1], self.size[0], img.shape[2]), img.dtype) 33 | padimg[:h, :w] = cv2.resize( 34 | img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h)) 35 | img = padimg 36 | 37 | lines = [] 38 | for line in data['polys']: 39 | poly = ((np.array(line['points']) - 40 | (crop_x, crop_y)) * scale).tolist() 41 | if not self.is_poly_outside_rect(poly, 0, 0, w, h): 42 | lines.append({**line, 'points': poly}) 43 | data['polys'] = lines 44 | 45 | if self.require_original_image: 46 | data['image'] = ori_img 47 | else: 48 | data['image'] = img 49 | data['lines'] = ori_lines 50 | data['scale_w'] = scale 51 | data['scale_h'] = scale 52 | 53 | return data 54 | 55 | def is_poly_in_rect(self, poly, x, y, w, h): 56 | poly = np.array(poly) 57 | if poly[:, 0].min() < x or poly[:, 0].max() > x + w: 58 | return False 59 | if poly[:, 1].min() < y or poly[:, 1].max() > y + h: 60 | return False 61 | return True 62 | 63 | def is_poly_outside_rect(self, poly, x, y, w, h): 64 | poly = np.array(poly) 65 | if poly[:, 0].max() < x or poly[:, 0].min() > x + w: 66 | return True 67 | if poly[:, 1].max() < y or poly[:, 1].min() > y + h: 68 | return True 69 | return False 70 | 71 | def split_regions(self, axis): 72 | regions = [] 73 | min_axis = 0 74 | for i in range(1, axis.shape[0]): 75 | if axis[i] != axis[i-1] + 1: 76 | region = axis[min_axis:i] 77 | min_axis = i 78 | regions.append(region) 79 | return regions 80 | 81 | def random_select(self, axis, max_size): 82 | xx = np.random.choice(axis, size=2) 83 | xmin = np.min(xx) 84 | xmax = np.max(xx) 85 | xmin = np.clip(xmin, 0, max_size - 1) 86 | xmax = np.clip(xmax, 0, max_size - 1) 87 | return xmin, xmax 88 | 89 | def region_wise_random_select(self, regions, max_size): 90 | selected_index = list(np.random.choice(len(regions), 2)) 91 | selected_values = [] 92 | for index in selected_index: 93 | axis = regions[index] 94 | xx = int(np.random.choice(axis, size=1)) 95 | selected_values.append(xx) 96 | xmin = min(selected_values) 97 | xmax = max(selected_values) 98 | return xmin, xmax 99 | 100 | def 
crop_area(self, img, polys): 101 | h, w, _ = img.shape 102 | h_array = np.zeros(h, dtype=np.int32) 103 | w_array = np.zeros(w, dtype=np.int32) 104 | for points in polys: 105 | points = np.round(points, decimals=0).astype(np.int32) 106 | minx = np.min(points[:, 0]) 107 | maxx = np.max(points[:, 0]) 108 | w_array[minx:maxx] = 1 109 | miny = np.min(points[:, 1]) 110 | maxy = np.max(points[:, 1]) 111 | h_array[miny:maxy] = 1 112 | # ensure the cropped area does not cross a text instance 113 | h_axis = np.where(h_array == 0)[0] 114 | w_axis = np.where(w_array == 0)[0] 115 | 116 | if len(h_axis) == 0 or len(w_axis) == 0: 117 | return 0, 0, w, h 118 | 119 | h_regions = self.split_regions(h_axis) 120 | w_regions = self.split_regions(w_axis) 121 | 122 | for i in range(self.max_tries): 123 | if len(w_regions) > 1: 124 | xmin, xmax = self.region_wise_random_select(w_regions, w) 125 | else: 126 | xmin, xmax = self.random_select(w_axis, w) 127 | if len(h_regions) > 1: 128 | ymin, ymax = self.region_wise_random_select(h_regions, h) 129 | else: 130 | ymin, ymax = self.random_select(h_axis, h) 131 | 132 | if xmax - xmin < self.min_crop_side_ratio * w or ymax - ymin < self.min_crop_side_ratio * h: 133 | # area too small 134 | continue 135 | num_poly_in_rect = 0 136 | for poly in polys: 137 | if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, ymax - ymin): 138 | num_poly_in_rect += 1 139 | break 140 | 141 | if num_poly_in_rect > 0: 142 | return xmin, ymin, xmax - xmin, ymax - ymin 143 | 144 | return 0, 0, w, h 145 | -------------------------------------------------------------------------------- /data/processes/resize_image.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from concern.config import Configurable, State 5 | import concern.webcv2 as webcv2 6 | from .data_process import DataProcess 7 | 8 | 9 | class _ResizeImage: 10 | ''' 11 | Resize images. 12 | Inputs: 13 | image_size: two-tuple-like object (height, width). 14 | mode: the mode used to resize image. Valid options: 15 | "keep_size": keep the original size of image. 16 | "resize": arbitrarily resize the image to image_size. 17 | "keep_ratio": resize to dest height 18 | while keeping the height/width ratio of the input. 19 | "pad": pad the image to image_size after applying 20 | "keep_ratio" resize.
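In the "keep_ratio" and "pad" modes the computed width is rounded to a multiple of 32 (see `get_image_size`), which keeps sizes compatible with the backbone's total stride.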
21 | ''' 22 | MODES = ['resize', 'keep_size', 'keep_ratio', 'pad'] 23 | 24 | def __init__(self, image_size, mode): 25 | self.image_size = image_size 26 | assert mode in self.MODES 27 | self.mode = mode 28 | 29 | def resize_or_pad(self, image): 30 | if self.mode == 'keep_size': 31 | return image 32 | if self.mode == 'pad': 33 | return self.pad_image(image) 34 | 35 | assert self.mode in ['resize', 'keep_ratio'] 36 | height, width = self.get_image_size(image) 37 | image = cv2.resize(image, (width, height)) 38 | return image 39 | 40 | def get_image_size(self, image): 41 | height, width = self.image_size 42 | if self.mode == 'keep_ratio': 43 | width = max(width, int( 44 | height / image.shape[0] * image.shape[1] / 32 + 0.5) * 32) 45 | if self.mode == 'pad': 46 | width = min(width, 47 | max(int(height / image.shape[0] * image.shape[1] / 32 + 0.5) * 32, 32)) 48 | return height, width 49 | 50 | def pad_image(self, image): 51 | canvas = np.zeros((*self.image_size, 3), np.float32) 52 | height, width = self.get_image_size(image) 53 | image = cv2.resize(image, (width, height)) 54 | canvas[:, :width, :] = image 55 | return canvas 56 | 57 | 58 | class ResizeImage(_ResizeImage, DataProcess): 59 | mode = State(default='keep_ratio') 60 | image_size = State(default=[1152, 2048]) # height, width 61 | key = State(default='image') 62 | 63 | def __init__(self, cmd={}, mode=None, **kwargs): 64 | self.load_all(**kwargs) 65 | if mode is not None: 66 | self.mode = mode 67 | if 'resize_mode' in cmd: 68 | self.mode = cmd['resize_mode'] 69 | assert self.mode in self.MODES 70 | 71 | def process(self, data): 72 | data[self.key] = self.resize_or_pad(data[self.key]) 73 | return data 74 | 75 | 76 | class ResizeData(_ResizeImage, DataProcess): 77 | key = State(default='image') 78 | box_key = State(default='polygons') 79 | image_size = State(default=[64, 256]) # height, width 80 | 81 | def __init__(self, cmd={}, mode='keep_ratio', key=None, box_key=None, **kwargs): 82 | self.load_all(**kwargs) 83 | if mode is not None: 84 | self.mode = mode 85 | if key is not None: 86 | self.key = key 87 | if box_key is not None: 88 | self.box_key = box_key 89 | if 'resize_mode' in cmd: 90 | self.mode = cmd['resize_mode'] 91 | assert self.mode in self.MODES 92 | 93 | def process(self, data): 94 | height, width = data['image'].shape[:2] 95 | new_height, new_width = self.get_image_size(data['image']) 96 | data[self.key] = self.resize_or_pad(data[self.key]) 97 | 98 | charboxes = data[self.box_key] 99 | data[self.box_key] = charboxes.copy() 100 | data[self.box_key][:, :, 0] = data[self.box_key][:, :, 0] * \ 101 | new_width / width 102 | data[self.box_key][:, :, 1] = data[self.box_key][:, :, 1] * \ 103 | new_height / height 104 | return data 105 | -------------------------------------------------------------------------------- /data/processes/serialize_box.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from data.quad import Quad 4 | from concern.config import State 5 | from .data_process import DataProcess 6 | 7 | 8 | class SerializeBox(DataProcess): 9 | box_key = State(default='charboxes') 10 | format = State(default='NP2') 11 | 12 | def process(self, data): 13 | data[self.box_key] = data['lines'].quads 14 | return data 15 | 16 | 17 | class UnifyRect(SerializeBox): 18 | max_size = State(default=64) 19 | 20 | def process(self, data): 21 | h, w = data['image'].shape[:2] 22 | boxes = np.zeros((self.max_size, 4), dtype=np.float32) 23 | mask_has_box = np.zeros(self.max_size,
dtype=np.float32) 24 | data = super().process(data) 25 | quad = data[self.box_key] 26 | assert quad.shape[0] <= self.max_size 27 | boxes[:quad.shape[0]] = quad.rectify() / np.array([w, h, w, h]).reshape(1, 4) 28 | mask_has_box[:quad.shape[0]] = 1. 29 | data['boxes'] = boxes 30 | data['mask_has_box'] = mask_has_box 31 | return data 32 | -------------------------------------------------------------------------------- /data/quad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class Quad: 6 | def __init__(self, points, format='NP2'): 7 | self._rect = None 8 | self.tensorized = False 9 | self._points = None 10 | self.set_points(points, format) 11 | 12 | @property 13 | def points(self): 14 | return self._points 15 | 16 | def set_points(self, new_points, format='NP2'): 17 | order = (format.index('N'), format.index('P'), format.index('2')) 18 | 19 | if isinstance(new_points, torch.Tensor): 20 | self._points = new_points.permute(*order) 21 | self.tensorized = True 22 | else: 23 | points = np.array(new_points, dtype=np.float32) 24 | self._points = points.transpose(*order) 25 | 26 | if self.tensorized: 27 | self.tensorized = False 28 | self.tensor 29 | 30 | @points.setter 31 | def points(self, new_points): 32 | self.set_points(new_points) 33 | 34 | @property 35 | def tensor(self): 36 | if not self.tensorized: 37 | self._points = torch.from_numpy(self._points) 38 | return self._points 39 | 40 | def to(self, device): 41 | self._points = self._points.to(device) 42 | return self._points 43 | 44 | def __iter__(self): 45 | for i in range(self._points.shape[0]): 46 | if self.tensorized: 47 | yield self.tensor[i] 48 | else: 49 | yield self.points[i] 50 | 51 | 52 | def rect(self): 53 | if self._rect is None: 54 | self._rect = self.rectify() 55 | return self._rect 56 | 57 | def __getitem__(self, *args, **kwargs): 58 | return self._points.__getitem__(*args, **kwargs) 59 | 60 | def numpy(self): 61 | if not self.tensorized: 62 | return self._points 63 | return self._points.cpu().data.numpy() 64 | 65 | def rectify(self): 66 | if self.tensorized: 67 | return self.rectify_tensor() 68 | 69 | xmin = self._points[:, :, 0].min(axis=1) 70 | ymin = self._points[:, :, 1].min(axis=1) 71 | xmax = self._points[:, :, 0].max(axis=1) 72 | ymax = self._points[:, :, 1].max(axis=1) 73 | return np.stack([xmin, ymin, xmax, ymax], axis=1) 74 | 75 | def rectify_tensor(self): 76 | xmin, _ = self.tensor[:, :, 0].min(dim=1, keepdim=True) 77 | ymin, _ = self.tensor[:, :, 1].min(dim=1, keepdim=True) 78 | xmax, _ = self.tensor[:, :, 0].max(dim=1, keepdim=True) 79 | ymax, _ = self.tensor[:, :, 1].max(dim=1, keepdim=True) 80 | return torch.cat([xmin, ymin, xmax, ymax], dim=1) 81 | 82 | def __getattribute__(self, name): 83 | try: 84 | return super().__getattribute__(name) 85 | except AttributeError: 86 | return self._points.__getattribute__(name) 87 | -------------------------------------------------------------------------------- /data/random_crop_aug.py: -------------------------------------------------------------------------------- 1 | import random 2 | import cv2 3 | import numpy as np 4 | from shapely.geometry import Polygon 5 | from shapely import affinity 6 | 7 | from concern.config import Configurable, State 8 | 9 | 10 | def regular_resize(image, boxes, tags, crop_size): 11 | h, w, c = image.shape 12 | if h < w: 13 | scale_ratio = crop_size * 1.0 / w 14 | new_h = int(round(crop_size * h * 1.0 / w)) 15 | if new_h > crop_size: 16 | new_h = crop_size 17 | image =
cv2.resize(image, (crop_size, new_h)) 18 | new_img = np.zeros((crop_size, crop_size, 3)) 19 | new_img[:new_h, :, :] = image 20 | boxes *= scale_ratio 21 | else: 22 | scale_ratio = crop_size * 1.0 / h 23 | new_w = int(round(crop_size * w * 1.0 / h)) 24 | if new_w > crop_size: 25 | new_w = crop_size 26 | image = cv2.resize(image, (new_w, crop_size)) 27 | new_img = np.zeros((crop_size, crop_size, 3)) 28 | new_img[:, :new_w, :] = image 29 | boxes *= scale_ratio 30 | return new_img, boxes, tags 31 | 32 | 33 | def random_crop(image, boxes, tags, crop_size, max_tries, w_axis, h_axis, min_crop_side_ratio): 34 | h, w, c = image.shape 35 | selected_boxes = [] 36 | for i in range(max_tries): 37 | xx = np.random.choice(w_axis, size=2) 38 | xmin = np.min(xx) 39 | xmax = np.max(xx) 40 | xmin = np.clip(xmin, 0, w-1) 41 | xmax = np.clip(xmax, 0, w-1) 42 | yy = np.random.choice(h_axis, size=2) 43 | ymin = np.min(yy) 44 | ymax = np.max(yy) 45 | ymin = np.clip(ymin, 0, h-1) 46 | ymax = np.clip(ymax, 0, h-1) 47 | if xmax - xmin < min_crop_side_ratio*w or ymax - ymin < min_crop_side_ratio*h: 48 | # area too small 49 | continue 50 | if boxes.shape[0] != 0: 51 | box_axis_in_area = (boxes[:, :, 0] >= xmin) & (boxes[:, :, 0] <= xmax) \ 52 | & (boxes[:, :, 1] >= ymin) & (boxes[:, :, 1] <= ymax) 53 | selected_boxes = np.where(np.sum(box_axis_in_area, axis=1) == 4)[0] 54 | if len(selected_boxes) > 0: 55 | if (tags[selected_boxes] == False).astype(np.float64).sum() > 0: 56 | break 57 | else: 58 | selected_boxes = [] 59 | break 60 | if i == max_tries - 1: 61 | return regular_resize(image, boxes, tags, crop_size) 62 | 63 | image = image[ymin:ymax+1, xmin:xmax+1, :] 64 | boxes = boxes[selected_boxes] 65 | tags = tags[selected_boxes] 66 | boxes[:, :, 0] -= xmin 67 | boxes[:, :, 1] -= ymin 68 | return regular_resize(image, boxes, tags, crop_size) 69 | 70 | 71 | def regular_crop(image, boxes, tags, crop_size, max_tries, w_array, h_array, w_axis, h_axis, min_crop_side_ratio): 72 | h, w, c = image.shape 73 | mask_w = np.arange(w - crop_size) 74 | mask_h = np.arange(h - crop_size) 75 | keep_w = np.where(np.logical_and( 76 | w_array[mask_w] == 0, w_array[mask_w + crop_size - 1] == 0))[0] 77 | keep_h = np.where(np.logical_and( 78 | h_array[mask_h] == 0, h_array[mask_h + crop_size - 1] == 0))[0] 79 | 80 | if keep_w.size > 0 and keep_h.size > 0: 81 | for i in range(max_tries): 82 | xmin = np.random.choice(keep_w, size=1)[0] 83 | xmax = xmin + crop_size 84 | ymin = np.random.choice(keep_h, size=1)[0] 85 | ymax = ymin + crop_size 86 | if boxes.shape[0] != 0: 87 | box_axis_in_area = (boxes[:, :, 0] >= xmin) & (boxes[:, :, 0] <= xmax) \ 88 | & (boxes[:, :, 1] >= ymin) & (boxes[:, :, 1] <= ymax) 89 | selected_boxes = np.where( 90 | np.sum(box_axis_in_area, axis=1) == 4)[0] 91 | if len(selected_boxes) > 0: 92 | if (tags[selected_boxes] == False).astype(np.float64).sum() > 0: 93 | break 94 | else: 95 | selected_boxes = [] 96 | break 97 | if i == max_tries-1: 98 | return random_crop(image, boxes, tags, crop_size, max_tries, w_axis, h_axis, min_crop_side_ratio) 99 | image = image[ymin:ymax, xmin:xmax, :] 100 | boxes = boxes[selected_boxes] 101 | tags = tags[selected_boxes] 102 | boxes[:, :, 0] -= xmin 103 | boxes[:, :, 1] -= ymin 104 | return image, boxes, tags 105 | else: 106 | return random_crop(image, boxes, tags, crop_size, max_tries, w_axis, h_axis, min_crop_side_ratio) 107 | 108 | 109 | class RandomCrop(object): 110 | def __init__(self, crop_size=640, max_tries=50, min_crop_side_ratio=0.1): 111 | self.crop_size = crop_size 112 |
self.max_tries = max_tries 113 | self.min_crop_side_ratio = min_crop_side_ratio 114 | 115 | def __call__(self, image, boxes, tags): 116 | h, w, _ = image.shape 117 | h_array = np.zeros((h), dtype=np.int32) 118 | w_array = np.zeros((w), dtype=np.int32) 119 | 120 | for box in boxes: 121 | box = np.round(box, decimals=0).astype(np.int32) 122 | minx = np.min(box[:, 0]) 123 | maxx = np.max(box[:, 0]) 124 | w_array[minx:maxx] = 1 125 | miny = np.min(box[:, 1]) 126 | maxy = np.max(box[:, 1]) 127 | h_array[miny:maxy] = 1 128 | 129 | h_axis = np.where(h_array == 0)[0] 130 | w_axis = np.where(w_array == 0)[0] 131 | if len(h_axis) == 0 or len(w_axis) == 0: 132 | # resize image 133 | return regular_resize(image, boxes, tags, self.crop_size) 134 | 135 | if h <= self.crop_size + 1 or w <= self.crop_size + 1: 136 | return random_crop(image, boxes, tags, self.crop_size, self.max_tries, w_axis, h_axis, self.min_crop_side_ratio) 137 | else: 138 | return regular_crop(image, boxes, tags, self.crop_size, self.max_tries, w_array, h_array, w_axis, h_axis, self.min_crop_side_ratio) 139 | 140 | 141 | class RandomCropAug(Configurable): 142 | size = State(default=640) 143 | 144 | def __init__(self, size=640, *args, **kwargs): 145 | self.size = size or self.size 146 | self.augment = RandomCrop(size) 147 | 148 | def __call__(self, data): 149 | ''' 150 | This augmenter is supposed to follow the process of `MakeICDARData`, 151 | in which labels are mapped to this specific format: 152 | (image, polygons: (n, 4, 2), tags: [Boolean], ...) 153 | ''' 154 | image, boxes, ignore_tags = data[:3] 155 | image, boxes, ignore_tags = self.augment(image, boxes, ignore_tags) 156 | return (image, boxes, ignore_tags, *data[3:]) 157 | -------------------------------------------------------------------------------- /data/text_lines.py: -------------------------------------------------------------------------------- 1 | from .quad import Quad 2 | 3 | 4 | class TextLines: 5 | ''' 6 | The abstract class of text lines in an input image for scene text detection and recognition. 7 | Input: 8 | lines: 9 | - text: the text content of a text instance. 10 | poly: the quadrangle-box of the text instance. 11 | charboxes: the quadrangle-box of the characters inside the corresponding text instance. 12 | ''' 13 | def __init__(self, lines, with_charboxes=True): 14 | self.texts = [] 15 | quads = [] 16 | self.charboxes = [] 17 | for line in lines: 18 | self.texts.append(line['text']) 19 | quads.append(line['poly']) 20 | if with_charboxes and 'charboxes' in line: 21 | self.charboxes.append(Quad(line['charboxes'])) 22 | self.with_charboxes = len(self.charboxes) > 0 23 | self.quads = Quad(quads) 24 | self._rects = None 25 | 26 | def __iter__(self): 27 | for text, quad in zip(self.texts, self.quads): 28 | yield (text, quad) 29 | 30 | @property 31 | def rects(self): 32 | if self._rects is None: 33 | self._rects = self.quads.rectify() 34 | return self._rects 35 | 36 | def __len__(self): 37 | return len(self.texts) 38 | 39 | def char_count(self): 40 | return sum([len(t) for t in self.texts]) 41 | -------------------------------------------------------------------------------- /data/transform_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from concern.config import Configurable 5 | 6 | 7 | class TransformData(Configurable): 8 | ''' 9 | this transformation is in-place, which means that the input 10 | will be modified.
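The image is scaled to [0, 1] and normalized with the ImageNet mean/std defined below.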
11 | ''' 12 | mean = np.array([0.485, 0.456, 0.406]) 13 | std = np.array([0.229, 0.224, 0.225]) 14 | 15 | def __init__(self, **kwargs): 16 | self.load_all(**kwargs) 17 | 18 | def __call__(self, data_dict, *args, **kwargs): 19 | image = data_dict['image'].transpose(2, 0, 1) 20 | image = image / 255.0 21 | image = (image - self.mean[:, None, None]) / self.std[:, None, None] 22 | data_dict['image'] = image 23 | return data_dict 24 | -------------------------------------------------------------------------------- /data/unpack_msgpack_data.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | import cv2 4 | import numpy as np 5 | # import nori2 as nori 6 | import msgpack 7 | from PIL import Image 8 | 9 | from concern.config import Configurable, State 10 | 11 | 12 | # class UnpackMsgpackData(Configurable): 13 | # mode = State(default='BGR') 14 | 15 | # def __init__(self, cmd={}, **kwargs): 16 | # self.load_all(**kwargs) 17 | # self.fetcher = nori.Fetcher() 18 | # if 'mode' in cmd: 19 | # self.mode = cmd['mode'] 20 | 21 | # def convert_obj(self, obj): 22 | # if isinstance(obj, dict): 23 | # ndata = {} 24 | # for key, value in obj.items(): 25 | # nkey = key.decode() 26 | # if nkey == 'img': 27 | # img = Image.open(io.BytesIO(value)) 28 | # img = np.array(img.convert('RGB')) 29 | # if self.mode == 'BGR': 30 | # img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 31 | # nvalue = img 32 | # else: 33 | # nvalue = self.convert_obj(value) 34 | # ndata[nkey] = nvalue 35 | # return ndata 36 | # elif isinstance(obj, list): 37 | # return [self.convert_obj(item) for item in obj] 38 | # elif isinstance(obj, bytes): 39 | # return obj.decode() 40 | # else: 41 | # return obj 42 | 43 | # def convert(self, data): 44 | # obj = msgpack.loads(data, max_str_len=2 ** 31) 45 | # return self.convert_obj(obj) 46 | 47 | # def __call__(self, data_id, meta=None): 48 | # if meta is None: 49 | # meta = {} 50 | # item = self.convert(self.fetcher.get(data_id)) 51 | # item['data_id'] = data_id 52 | # meta.update(item) 53 | # return meta 54 | 55 | 56 | class TransformMsgpackData(UnpackMsgpackData): 57 | meta_loader = State(default=None) 58 | 59 | def __init__(self, meta_loader=None, cmd={}, **kwargs): 60 | super().__init__(cmd=cmd, meta_loader=meta_loader, **kwargs) 61 | print('transform') 62 | self.meta_loader = cmd.get('meta_loader', self.meta_loader) 63 | 64 | def __call__(self, data_id, meta): 65 | item = self.convert(self.fetcher.get(data_id)) 66 | image = item.pop('img').astype(np.float32) 67 | if self.meta_loader is not None: 68 | meta['extra'] = item 69 | data = self.meta_loader.parse_meta(data_id, meta) 70 | else: 71 | data = meta 72 | data.update(**item) 73 | data.update(image=image, data_id=data_id) 74 | return data 75 | -------------------------------------------------------------------------------- /decoders/__init__.py: -------------------------------------------------------------------------------- 1 | from .seg_detector import SegDetector 2 | from .seg_detector_asf import SegSpatialScaleDetector 3 | from .dice_loss import DiceLoss 4 | from .pss_loss import PSS_Loss 5 | from .l1_loss import MaskL1Loss 6 | from .balance_cross_entropy_loss import BalanceCrossEntropyLoss 7 | -------------------------------------------------------------------------------- /decoders/balance_cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BalanceCrossEntropyLoss(nn.Module): 6 | 
''' 7 | Balanced cross entropy loss. 8 | Shape: 9 | - Input: :math:`(N, 1, H, W)` 10 | - GT: :math:`(N, 1, H, W)`, same shape as the input 11 | - Mask: :math:`(N, H, W)`, same spatial shape as the input 12 | - Output: scalar. 13 | 14 | Examples:: 15 | 16 | >>> m = nn.Sigmoid() 17 | >>> loss = nn.BCELoss() 18 | >>> input = torch.randn(3, requires_grad=True) 19 | >>> target = torch.empty(3).random_(2) 20 | >>> output = loss(m(input), target) 21 | >>> output.backward() 22 | ''' 23 | 24 | def __init__(self, negative_ratio=3.0, eps=1e-6): 25 | super(BalanceCrossEntropyLoss, self).__init__() 26 | self.negative_ratio = negative_ratio 27 | self.eps = eps 28 | 29 | def forward(self, 30 | pred: torch.Tensor, 31 | gt: torch.Tensor, 32 | mask: torch.Tensor, 33 | return_origin=False): 34 | ''' 35 | Args: 36 | pred: shape :math:`(N, 1, H, W)`, the prediction of network 37 | gt: shape :math:`(N, 1, H, W)`, the target 38 | mask: shape :math:`(N, H, W)`, the mask indicates positive regions 39 | ''' 40 | positive = (gt[:,0,:,:] * mask).byte() 41 | negative = ((1 - gt[:,0,:,:]) * mask).byte() 42 | positive_count = int(positive.float().sum()) 43 | negative_count = min(int(negative.float().sum()), 44 | int(positive_count * self.negative_ratio)) 45 | loss = nn.functional.binary_cross_entropy( 46 | pred, gt, reduction='none')[:, 0, :, :] 47 | positive_loss = loss * positive.float() 48 | negative_loss = loss * negative.float() 49 | negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count) 50 | 51 | balance_loss = (positive_loss.sum() + negative_loss.sum()) /\ 52 | (positive_count + negative_count + self.eps) 53 | 54 | if return_origin: 55 | return balance_loss, loss 56 | return balance_loss 57 | -------------------------------------------------------------------------------- /decoders/feature_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class ScaleChannelAttention(nn.Module): 6 | def __init__(self, in_planes, out_planes, num_features, init_weight=True): 7 | super(ScaleChannelAttention, self).__init__() 8 | self.avgpool = nn.AdaptiveAvgPool2d(1) 9 | print(self.avgpool) 10 | self.fc1 = nn.Conv2d(in_planes, out_planes, 1, bias=False) 11 | self.bn = nn.BatchNorm2d(out_planes) 12 | self.fc2 = nn.Conv2d(out_planes, num_features, 1, bias=False) 13 | if init_weight: 14 | self._initialize_weights() 15 | 16 | def _initialize_weights(self): 17 | for m in self.modules(): 18 | if isinstance(m, nn.Conv2d): 19 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 20 | if m.bias is not None: 21 | nn.init.constant_(m.bias, 0) 22 | if isinstance(m ,nn.BatchNorm2d): 23 | nn.init.constant_(m.weight, 1) 24 | nn.init.constant_(m.bias, 0) 25 | 26 | def forward(self, x): 27 | global_x = self.avgpool(x) 28 | global_x = self.fc1(global_x) 29 | global_x = F.relu(self.bn(global_x)) 30 | global_x = self.fc2(global_x) 31 | global_x = F.softmax(global_x, 1) 32 | return global_x 33 | 34 | class ScaleChannelSpatialAttention(nn.Module): 35 | def __init__(self, in_planes, out_planes, num_features, init_weight=True): 36 | super(ScaleChannelSpatialAttention, self).__init__() 37 | self.channel_wise = nn.Sequential( 38 | nn.AdaptiveAvgPool2d(1), 39 | nn.Conv2d(in_planes, out_planes , 1, bias=False), 40 | # nn.BatchNorm2d(out_planes), 41 | nn.ReLU(), 42 | nn.Conv2d(out_planes, in_planes, 1, bias=False) 43 | ) 44 | self.spatial_wise = nn.Sequential( 45 | #Nx1xHxW 46 | nn.Conv2d(1, 1, 3, 
bias=False, padding=1), 47 | nn.ReLU(), 48 | nn.Conv2d(1, 1, 1, bias=False), 49 | nn.Sigmoid() 50 | ) 51 | self.attention_wise = nn.Sequential( 52 | nn.Conv2d(in_planes, num_features, 1, bias=False), 53 | nn.Sigmoid() 54 | ) 55 | if init_weight: 56 | self._initialize_weights() 57 | 58 | def _initialize_weights(self): 59 | for m in self.modules(): 60 | if isinstance(m, nn.Conv2d): 61 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 62 | if m.bias is not None: 63 | nn.init.constant_(m.bias, 0) 64 | if isinstance(m ,nn.BatchNorm2d): 65 | nn.init.constant_(m.weight, 1) 66 | nn.init.constant_(m.bias, 0) 67 | 68 | def forward(self, x): 69 | # global_x = self.avgpool(x) 70 | #shape Nx4x1x1 71 | global_x = self.channel_wise(x).sigmoid() 72 | #shape: NxCxHxW 73 | global_x = global_x + x 74 | #shape:Nx1xHxW 75 | x = torch.mean(global_x, dim=1, keepdim=True) 76 | global_x = self.spatial_wise(x) + global_x 77 | global_x = self.attention_wise(global_x) 78 | return global_x 79 | 80 | class ScaleSpatialAttention(nn.Module): 81 | def __init__(self, in_planes, out_planes, num_features, init_weight=True): 82 | super(ScaleSpatialAttention, self).__init__() 83 | self.spatial_wise = nn.Sequential( 84 | #Nx1xHxW 85 | nn.Conv2d(1, 1, 3, bias=False, padding=1), 86 | nn.ReLU(), 87 | nn.Conv2d(1, 1, 1, bias=False), 88 | nn.Sigmoid() 89 | ) 90 | self.attention_wise = nn.Sequential( 91 | nn.Conv2d(in_planes, num_features, 1, bias=False), 92 | nn.Sigmoid() 93 | ) 94 | if init_weight: 95 | self._initialize_weights() 96 | 97 | def _initialize_weights(self): 98 | for m in self.modules(): 99 | if isinstance(m, nn.Conv2d): 100 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 101 | if m.bias is not None: 102 | nn.init.constant_(m.bias, 0) 103 | if isinstance(m ,nn.BatchNorm2d): 104 | nn.init.constant_(m.weight, 1) 105 | nn.init.constant_(m.bias, 0) 106 | 107 | def forward(self, x): 108 | global_x = torch.mean(x, dim=1, keepdim=True) 109 | global_x = self.spatial_wise(global_x) + x 110 | global_x = self.attention_wise(global_x) 111 | return global_x 112 | 113 | class ScaleFeatureSelection(nn.Module): 114 | def __init__(self, in_channels, inter_channels , out_features_num=4, attention_type='scale_spatial'): 115 | super(ScaleFeatureSelection, self).__init__() 116 | self.in_channels=in_channels 117 | self.inter_channels = inter_channels 118 | self.out_features_num = out_features_num 119 | self.conv = nn.Conv2d(in_channels, inter_channels, 3, padding=1) 120 | self.type = attention_type 121 | if self.type == 'scale_spatial': 122 | self.enhanced_attention = ScaleSpatialAttention(inter_channels, inter_channels//4, out_features_num) 123 | elif self.type == 'scale_channel_spatial': 124 | self.enhanced_attention = ScaleChannelSpatialAttention(inter_channels, inter_channels // 4, out_features_num) 125 | elif self.type == 'scale_channel': 126 | self.enhanced_attention = ScaleChannelAttention(inter_channels, inter_channels//2, out_features_num) 127 | 128 | def _initialize_weights(self, m): 129 | classname = m.__class__.__name__ 130 | if classname.find('Conv') != -1: 131 | nn.init.kaiming_normal_(m.weight.data) 132 | elif classname.find('BatchNorm') != -1: 133 | m.weight.data.fill_(1.) 
134 |             m.bias.data.fill_(1e-4)
135 |     def forward(self, concat_x, features_list):
136 |         concat_x = self.conv(concat_x)
137 |         score = self.enhanced_attention(concat_x)
138 |         assert len(features_list) == self.out_features_num
139 |         if self.type not in ['scale_channel_spatial', 'scale_spatial']:
140 |             shape = features_list[0].shape[2:]
141 |             score = F.interpolate(score, size=shape, mode='bilinear')
142 |         x = []
143 |         for i in range(self.out_features_num):
144 |             x.append(score[:, i:i+1] * features_list[i])
145 |         return torch.cat(x, dim=1)
-------------------------------------------------------------------------------- /decoders/l1_loss.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class MaskL1Loss(nn.Module):
6 |     def __init__(self):
7 |         super(MaskL1Loss, self).__init__()
8 | 
9 |     def forward(self, pred: torch.Tensor, gt, mask):
10 |         mask_sum = mask.sum()
11 |         if mask_sum.item() == 0:
12 |             return mask_sum, dict(l1_loss=mask_sum)
13 |         else:
14 |             loss = (torch.abs(pred[:, 0] - gt) * mask).sum() / mask_sum
15 |             return loss, dict(l1_loss=loss)
16 | 
17 | 
18 | class BalanceL1Loss(nn.Module):
19 |     def __init__(self, negative_ratio=3.):
20 |         super(BalanceL1Loss, self).__init__()
21 |         self.negative_ratio = negative_ratio
22 | 
23 |     def forward(self, pred: torch.Tensor, gt, mask):
24 |         '''
25 |         Args:
26 |             pred: (N, 1, H, W).
27 |             gt: (N, H, W).
28 |             mask: (N, H, W).
29 |         '''
30 |         loss = torch.abs(pred[:, 0] - gt)
31 |         positive = loss * mask
32 |         negative = loss * (1 - mask)
33 |         positive_count = int(mask.sum())
34 |         negative_count = min(
35 |             int((1 - mask).sum()),
36 |             int(positive_count * self.negative_ratio))
37 |         negative_loss, _ = torch.topk(negative.view(-1), negative_count)
38 |         negative_loss = negative_loss.sum() / negative_count
39 |         positive_loss = positive.sum() / positive_count
40 |         return positive_loss + negative_loss,\
41 |             dict(l1_loss=positive_loss, neg_l1_loss=negative_loss)
42 | 
-------------------------------------------------------------------------------- /decoders/pss_loss.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | class PSS_Loss(nn.Module):
6 | 
7 |     def __init__(self, cls_loss):
8 |         super(PSS_Loss, self).__init__()
9 |         self.eps = 1e-6
10 |         self.criterion = getattr(self, cls_loss + '_loss')
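        # Dispatch note (inferred from the methods defined below, not
        # documented upstream): valid `cls_loss` values are 'dice',
        # 'dice_ohnm', 'focal', 'wbce', 'wbce_orig', 'bce', 'dice_bce' and
        # 'dice_ohnm_bce'; e.g. PSS_Loss('dice') resolves to self.dice_loss.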
11 | 
12 |     def dice_loss(self, pred, gt, m):
13 |         intersection = torch.sum(pred*gt*m)
14 |         union = torch.sum(pred*m) + torch.sum(gt*m) + self.eps
15 |         loss = 1 - 2.0*intersection/union
16 |         if loss > 1:
17 |             print(intersection, union)
18 |         return loss
19 | 
20 |     def dice_ohnm_loss(self, pred, gt, m):
21 |         pos_index = (gt == 1) * (m == 1)
22 |         neg_index = (gt == 0) * (m == 1)
23 |         pos_num = pos_index.float().sum().item()
24 |         neg_num = neg_index.float().sum().item()
25 |         if pos_num == 0 or neg_num < pos_num*3.0:
26 |             return self.dice_loss(pred, gt, m)
27 |         else:
28 |             neg_num = int(pos_num*3)
29 |             pos_pred = pred[pos_index]
30 |             neg_pred = pred[neg_index]
31 |             neg_sort, _ = torch.sort(neg_pred, descending=True)
32 |             sampled_neg_pred = neg_sort[:neg_num]
33 |             pos_gt = pos_pred.clone()
34 |             pos_gt.data.fill_(1.0)
35 |             pos_gt = pos_gt.detach()
36 |             neg_gt = sampled_neg_pred.clone()
37 |             neg_gt.data.fill_(0)
38 |             neg_gt = neg_gt.detach()
39 |             tpred = torch.cat((pos_pred, sampled_neg_pred))
40 |             tgt = torch.cat((pos_gt, neg_gt))
41 |             intersection = torch.sum(tpred * tgt)
42 |             union = torch.sum(tpred) + torch.sum(tgt) + self.eps  # use the sampled targets, not the full gt map
43 |             loss = 1 - 2.0 * intersection / union
44 |         return loss
45 | 
46 |     def focal_loss(self, pred, gt, m, alpha=0.25, gamma=0.6):
47 |         pos_mask = (gt == 1).float()
48 |         neg_mask = (gt == 0).float()
49 |         mask = alpha*pos_mask * \
50 |             torch.pow(1-pred.data, gamma)+(1-alpha) * \
51 |             neg_mask*torch.pow(pred.data, gamma)
52 |         l = F.binary_cross_entropy(pred, gt, weight=mask, reduction='none')
53 |         loss = torch.sum(l*m)/(self.eps+m.sum())
54 |         loss *= 10
55 |         return loss
56 | 
57 |     def wbce_orig_loss(self, pred, gt, m):
58 |         n, h, w = pred.size()
59 |         assert (torch.max(gt) == 1)
60 |         pos_neg_p = pred[m.byte()]
61 |         pos_neg_t = gt[m.byte()]
62 |         pos_mask = (pos_neg_t == 1).squeeze()
63 |         w = pos_mask.float() * (1 - pos_mask).sum().item() + \
64 |             (1 - pos_mask).float() * pos_mask.sum().item()
65 |         w = w / (pos_mask.size(0))
66 |         loss = F.binary_cross_entropy(pos_neg_p, pos_neg_t, w, reduction='sum')
67 |         return loss
68 | 
69 |     def wbce_loss(self, pred, gt, m):
70 |         pos_mask = (gt == 1).float()*m
71 |         neg_mask = (gt == 0).float()*m
72 |         # mask=(pos_mask*neg_mask.sum()+neg_mask*pos_mask.sum())/m.sum()
73 |         # loss=torch.sum(l)
74 |         mask = pos_mask * neg_mask.sum() / pos_mask.sum() + neg_mask
75 |         l = F.binary_cross_entropy(pred, gt, weight=mask, reduction='none')
76 |         loss = torch.sum(l)/(m.sum()+self.eps)
77 |         return loss
78 | 
79 |     def bce_loss(self, pred, gt, m):
80 |         l = F.binary_cross_entropy(pred, gt, weight=m, reduction='sum')
81 |         loss = l/(m.sum()+self.eps)
82 |         return loss
83 | 
84 |     def dice_bce_loss(self, pred, gt, m):
85 |         return (self.dice_loss(pred, gt, m) + self.bce_loss(pred, gt, m)) / 2.0
86 | 
87 |     def dice_ohnm_bce_loss(self, pred, gt, m):
88 |         return (self.dice_ohnm_loss(pred, gt, m) + self.bce_loss(pred, gt, m)) / 2.0
89 | 
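    # A rough intuition for the OHNM variant above (illustrative numbers, not
    # taken from the paper): with 1,000 positive pixels and 50,000 masked-in
    # negatives, only the 3,000 negatives with the highest predicted scores
    # (the hardest ones) are kept, so the dice computation runs at a 3:1
    # negative-to-positive ratio instead of 50:1.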
90 |     def forward(self, pred, gt, mask, gt_type='shrink'):
91 |         if gt_type == 'shrink':
92 |             loss = self.get_loss(pred, gt, mask)
93 |             return loss
94 |         elif gt_type == 'pss':
95 |             loss = self.get_loss(pred, gt[:, :4, :, :], mask)
96 |             g_g = gt[:, 4, :, :]
97 |             g_p, _ = torch.max(pred, 1)
98 |             loss += self.criterion(g_p, g_g, mask)
99 |             return loss
100 |         elif gt_type == 'both':
101 |             pss_loss = self.get_loss(pred[:, :4, :, :], gt[:, :4, :, :], mask)
102 |             g_g = gt[:, 4, :, :]
103 |             g_p, _ = torch.max(pred, 1)
104 |             pss_loss += self.criterion(g_p, g_g, mask)
105 |             shrink_loss = self.criterion(
106 |                 pred[:, 4, :, :], gt[:, 5, :, :], mask)
107 |             return pss_loss, shrink_loss
108 |         else:
109 |             raise NotImplementedError('gt_type [%s] is not implemented' % gt_type)
110 | 
111 |     def get_loss(self, pred, gt, mask):
112 |         loss = 0.  # accumulate from a plain float so the sum stays on pred's device
113 |         for ind in range(pred.size(1)):
114 |             loss += self.criterion(pred[:, ind, :, :], gt[:, ind, :, :], mask)
115 |         return loss
116 | 
-------------------------------------------------------------------------------- /decoders/seg_detector.py: --------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | 
3 | import torch
4 | import torch.nn as nn
5 | BatchNorm2d = nn.BatchNorm2d
6 | 
7 | class SegDetector(nn.Module):
8 |     def __init__(self,
9 |                  in_channels=[64, 128, 256, 512],
10 |                  inner_channels=256, k=10,
11 |                  bias=False, adaptive=False, smooth=False, serial=False,
12 |                  *args, **kwargs):
13 |         '''
14 |         bias: Whether conv layers have bias or not.
15 |         adaptive: Whether to use adaptive threshold training or not.
16 |         smooth: If true, use bilinear instead of deconv.
17 |         serial: If true, thresh prediction will combine segmentation result as input.
18 |         '''
19 |         super(SegDetector, self).__init__()
20 |         self.k = k
21 |         self.serial = serial
22 |         self.up5 = nn.Upsample(scale_factor=2, mode='nearest')
23 |         self.up4 = nn.Upsample(scale_factor=2, mode='nearest')
24 |         self.up3 = nn.Upsample(scale_factor=2, mode='nearest')
25 | 
26 |         self.in5 = nn.Conv2d(in_channels[-1], inner_channels, 1, bias=bias)
27 |         self.in4 = nn.Conv2d(in_channels[-2], inner_channels, 1, bias=bias)
28 |         self.in3 = nn.Conv2d(in_channels[-3], inner_channels, 1, bias=bias)
29 |         self.in2 = nn.Conv2d(in_channels[-4], inner_channels, 1, bias=bias)
30 | 
31 |         self.out5 = nn.Sequential(
32 |             nn.Conv2d(inner_channels, inner_channels //
33 |                       4, 3, padding=1, bias=bias),
34 |             nn.Upsample(scale_factor=8, mode='nearest'))
35 |         self.out4 = nn.Sequential(
36 |             nn.Conv2d(inner_channels, inner_channels //
37 |                       4, 3, padding=1, bias=bias),
38 |             nn.Upsample(scale_factor=4, mode='nearest'))
39 |         self.out3 = nn.Sequential(
40 |             nn.Conv2d(inner_channels, inner_channels //
41 |                       4, 3, padding=1, bias=bias),
42 |             nn.Upsample(scale_factor=2, mode='nearest'))
43 |         self.out2 = nn.Conv2d(
44 |             inner_channels, inner_channels//4, 3, padding=1, bias=bias)
45 | 
46 |         self.binarize = nn.Sequential(
47 |             nn.Conv2d(inner_channels, inner_channels //
48 |                       4, 3, padding=1, bias=bias),
49 |             BatchNorm2d(inner_channels//4),
50 |             nn.ReLU(inplace=True),
51 |             nn.ConvTranspose2d(inner_channels//4, inner_channels//4, 2, 2),
52 |             BatchNorm2d(inner_channels//4),
53 |             nn.ReLU(inplace=True),
54 |             nn.ConvTranspose2d(inner_channels//4, 1, 2, 2),
55 |             nn.Sigmoid())
56 |         self.binarize.apply(self.weights_init)
57 | 
58 |         self.adaptive = adaptive
59 |         if adaptive:
60 |             self.thresh = self._init_thresh(
61 |                 inner_channels, serial=serial, smooth=smooth, bias=bias)
62 |             self.thresh.apply(self.weights_init)
63 | 
64 |         self.in5.apply(self.weights_init)
65 |         self.in4.apply(self.weights_init)
66 |         self.in3.apply(self.weights_init)
67 |         self.in2.apply(self.weights_init)
68 |         self.out5.apply(self.weights_init)
69 |         self.out4.apply(self.weights_init)
70 |         self.out3.apply(self.weights_init)
71 |         self.out2.apply(self.weights_init)
72 | 
73 |     def weights_init(self, m):
74 |         classname = m.__class__.__name__
75 |         if classname.find('Conv') != -1:
76 |             nn.init.kaiming_normal_(m.weight.data)
77 |         elif classname.find('BatchNorm') != -1:
78 |             m.weight.data.fill_(1.)
79 |             m.bias.data.fill_(1e-4)
80 | 
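    # The adaptive threshold branch: `_init_thresh` below builds a second head
    # that predicts a per-pixel threshold map; with serial=True its input is
    # widened by one channel so the (resized) probability map can be
    # concatenated in, as done at the end of forward().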
81 |     def _init_thresh(self, inner_channels,
82 |                      serial=False, smooth=False, bias=False):
83 |         in_channels = inner_channels
84 |         if serial:
85 |             in_channels += 1
86 |         self.thresh = nn.Sequential(
87 |             nn.Conv2d(in_channels, inner_channels //
88 |                       4, 3, padding=1, bias=bias),
89 |             BatchNorm2d(inner_channels//4),
90 |             nn.ReLU(inplace=True),
91 |             self._init_upsample(inner_channels // 4, inner_channels//4, smooth=smooth, bias=bias),
92 |             BatchNorm2d(inner_channels//4),
93 |             nn.ReLU(inplace=True),
94 |             self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
95 |             nn.Sigmoid())
96 |         return self.thresh
97 | 
98 |     def _init_upsample(self,
99 |                        in_channels, out_channels,
100 |                        smooth=False, bias=False):
101 |         if smooth:
102 |             inter_out_channels = out_channels
103 |             if out_channels == 1:
104 |                 inter_out_channels = in_channels
105 |             module_list = [
106 |                 nn.Upsample(scale_factor=2, mode='nearest'),
107 |                 nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)]
108 |             if out_channels == 1:
109 |                 module_list.append(
110 |                     nn.Conv2d(in_channels, out_channels,
111 |                               kernel_size=1, stride=1, padding=0, bias=True))  # padding=0: a 1x1 conv must not change the spatial size
112 | 
113 |             return nn.Sequential(*module_list)  # unpack: nn.Sequential takes modules, not a list
114 |         else:
115 |             return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
116 | 
117 |     def forward(self, features, gt=None, masks=None, training=False):
118 |         c2, c3, c4, c5 = features
119 |         in5 = self.in5(c5)
120 |         in4 = self.in4(c4)
121 |         in3 = self.in3(c3)
122 |         in2 = self.in2(c2)
123 | 
124 |         out4 = self.up5(in5) + in4  # 1/16
125 |         out3 = self.up4(out4) + in3  # 1/8
126 |         out2 = self.up3(out3) + in2  # 1/4
127 | 
128 |         p5 = self.out5(in5)
129 |         p4 = self.out4(out4)
130 |         p3 = self.out3(out3)
131 |         p2 = self.out2(out2)
132 | 
133 |         fuse = torch.cat((p5, p4, p3, p2), 1)
134 |         # this is the pred module, not binarization module;
135 |         # We do not correct the name due to the trained model.
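        # `step_function` below implements the differentiable binarization of
        # DB, 1 / (1 + exp(-k * (x - y))), with x the probability map, y the
        # threshold map and k=50 in the provided configs. For example, a pixel
        # whose probability exceeds its threshold by 0.1 maps to
        # 1 / (1 + exp(-5)) ~= 0.993, so the soft step closely approximates a
        # hard cut while remaining differentiable.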
136 | binary = self.binarize(fuse) 137 | if self.training: 138 | result = OrderedDict(binary=binary) 139 | else: 140 | return binary 141 | if self.adaptive and self.training: 142 | if self.serial: 143 | fuse = torch.cat( 144 | (fuse, nn.functional.interpolate( 145 | binary, fuse.shape[2:])), 1) 146 | thresh = self.thresh(fuse) 147 | thresh_binary = self.step_function(binary, thresh) 148 | result.update(thresh=thresh, thresh_binary=thresh_binary) 149 | return result 150 | 151 | def step_function(self, x, y): 152 | return torch.reciprocal(1 + torch.exp(-self.k * (x - y))) 153 | -------------------------------------------------------------------------------- /decoders/simple_detection.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | from backbones.upsample_head import SimpleUpsampleHead 7 | 8 | 9 | class SimpleDetectionDecoder(nn.Module): 10 | def __init__(self, feature_channel=256): 11 | nn.Module.__init__(self) 12 | 13 | self.feature_channel = feature_channel 14 | self.head_layer = self.create_head_layer() 15 | 16 | self.pred_layers = nn.ModuleDict(self.create_pred_layers()) 17 | 18 | def create_head_layer(self): 19 | return SimpleUpsampleHead( 20 | self.feature_channel, 21 | [self.feature_channel, self.feature_channel // 2, self.feature_channel // 4] 22 | ) 23 | 24 | def create_pred_layer(self, channels): 25 | return nn.Sequential( 26 | nn.Conv2d(self.feature_channel // 4, channels, kernel_size=1, stride=1, padding=0, bias=False), 27 | ) 28 | 29 | def create_pred_layers(self): 30 | return {} 31 | 32 | def postprocess_pred(self, pred): 33 | return pred 34 | 35 | def calculate_losses(self, preds, label): 36 | raise NotImplementedError() 37 | 38 | def forward(self, input, label, meta, train): 39 | feature = self.head_layer(input) 40 | 41 | pred = {} 42 | for name, pred_layer in self.pred_layers.items(): 43 | pred[name] = pred_layer(feature) 44 | 45 | if train: 46 | losses = self.calculate_losses(pred, label) 47 | pred = self.postprocess_pred(pred) 48 | loss = sum(losses.values()) 49 | return loss, pred, losses 50 | else: 51 | pred = self.postprocess_pred(pred) 52 | return pred 53 | 54 | 55 | class SimpleSegDecoder(SimpleDetectionDecoder): 56 | def create_pred_layers(self): 57 | return { 58 | 'heatmap': self.create_pred_layer(1) 59 | } 60 | 61 | def postprocess_pred(self, pred): 62 | pred['heatmap'] = F.sigmoid(pred['heatmap']) 63 | return pred 64 | 65 | def calculate_losses(self, pred, label): 66 | heatmap = label['heatmap'] 67 | heatmap_weight = label['heatmap_weight'] 68 | 69 | heatmap_pred = pred['heatmap'] 70 | 71 | heatmap_loss = F.binary_cross_entropy_with_logits(heatmap_pred, heatmap, reduction='none') 72 | heatmap_loss = (heatmap_loss * heatmap_weight).mean(dim=(1, 2, 3)) 73 | 74 | return { 75 | 'heatmap_loss': heatmap_loss, 76 | } 77 | 78 | 79 | class SimpleEASTDecoder(SimpleDetectionDecoder): 80 | def __init__(self, feature_channels=256, densebox_ratio=1000.0, densebox_rescale_factor=512): 81 | SimpleDetectionDecoder.__init__(self, feature_channels) 82 | 83 | self.densebox_ratio = densebox_ratio 84 | self.densebox_rescale_factor = densebox_rescale_factor 85 | 86 | def create_pred_layers(self): 87 | return { 88 | 'heatmap': self.create_pred_layer(1), 89 | 'densebox': self.create_pred_layer(8), 90 | } 91 | 92 | def postprocess_pred(self, pred): 93 | pred['heatmap'] = F.sigmoid(pred['heatmap']) 94 | pred['densebox'] = pred['densebox'] * 
self.densebox_rescale_factor 95 | return pred 96 | 97 | def calculate_losses(self, pred, label): 98 | heatmap = label['heatmap'] 99 | heatmap_weight = label['heatmap_weight'] 100 | densebox = label['densebox'] / self.densebox_rescale_factor 101 | densebox_weight = label['densebox_weight'] 102 | 103 | heatmap_pred = pred['heatmap'] 104 | densebox_pred = pred['densebox'] 105 | 106 | heatmap_loss = F.binary_cross_entropy_with_logits(heatmap_pred, heatmap, reduction='none') 107 | heatmap_loss = (heatmap_loss * heatmap_weight).mean(dim=(1, 2, 3)) 108 | 109 | densebox_loss = F.mse_loss(densebox_pred, densebox, reduction='none') 110 | densebox_loss = (densebox_loss * densebox_weight).mean(dim=(1, 2, 3)) * self.densebox_ratio 111 | 112 | return { 113 | 'heatmap_loss': heatmap_loss, 114 | 'densebox_loss': densebox_loss, 115 | } 116 | 117 | 118 | class SimpleTextsnakeDecoder(SimpleDetectionDecoder): 119 | def __init__(self, feature_channels=256, radius_ratio=10.0): 120 | SimpleDetectionDecoder.__init__(self, feature_channels) 121 | 122 | self.radius_ratio = radius_ratio 123 | 124 | def create_pred_layers(self): 125 | return { 126 | 'heatmap': self.create_pred_layer(1), 127 | 'radius': self.create_pred_layer(1), 128 | } 129 | 130 | def postprocess_pred(self, pred): 131 | pred['heatmap'] = F.sigmoid(pred['heatmap']) 132 | pred['radius'] = torch.exp(pred['radius']) 133 | return pred 134 | 135 | def calculate_losses(self, pred, label): 136 | heatmap = label['heatmap'] 137 | heatmap_weight = label['heatmap_weight'] 138 | radius = torch.log(label['radius'] + 1) 139 | radius_weight = label['radius_weight'] 140 | 141 | heatmap_pred = pred['heatmap'] 142 | radius_pred = pred['radius'] 143 | 144 | heatmap_loss = F.binary_cross_entropy_with_logits(heatmap_pred, heatmap, reduction='none') 145 | heatmap_loss = (heatmap_loss * heatmap_weight).mean(dim=(1, 2, 3)) 146 | 147 | radius_loss = F.smooth_l1_loss(radius_pred, radius, reduction='none') 148 | radius_loss = (radius_loss * radius_weight).mean(dim=(1, 2, 3)) * self.radius_ratio 149 | 150 | return { 151 | 'heatmap_loss': heatmap_loss, 152 | 'radius_loss': radius_loss, 153 | } 154 | 155 | 156 | class SimpleMSRDecoder(SimpleDetectionDecoder): 157 | def __init__(self, feature_channels=256, offset_ratio=1000.0, offset_rescale_factor=512): 158 | SimpleDetectionDecoder.__init__(self, feature_channels) 159 | 160 | self.offset_ratio = offset_ratio 161 | self.offset_rescale_factor = offset_rescale_factor 162 | 163 | def create_pred_layers(self): 164 | return { 165 | 'heatmap': self.create_pred_layer(1), 166 | 'offset': self.create_pred_layer(2), 167 | } 168 | 169 | def postprocess_pred(self, pred): 170 | pred['heatmap'] = F.sigmoid(pred['heatmap']) 171 | pred['offset'] = pred['offset'] * self.offset_rescale_factor 172 | return pred 173 | 174 | def calculate_losses(self, pred, label): 175 | heatmap = label['heatmap'] 176 | heatmap_weight = label['heatmap_weight'] 177 | offset = label['offset'] / self.offset_rescale_factor 178 | offset_weight = label['offset_weight'] 179 | 180 | heatmap_pred = pred['heatmap'] 181 | offset_pred = pred['offset'] 182 | 183 | heatmap_loss = F.binary_cross_entropy_with_logits(heatmap_pred, heatmap, reduction='none') 184 | heatmap_loss = (heatmap_loss * heatmap_weight).mean(dim=(1, 2, 3)) 185 | offset_loss = F.mse_loss(offset_pred, offset, reduction='none') 186 | offset_loss = (offset_loss * offset_weight).mean(dim=(1, 2, 3)) * self.offset_ratio 187 | 188 | return { 189 | 'heatmap_loss': heatmap_loss, 190 | 'offset_loss': offset_loss, 191 | } 
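# Usage sketch (an illustration, not part of the original code; assumes a
# feature map produced by the upsample head, and eval mode):
#
#     decoder = SimpleSegDecoder(feature_channel=256)
#     pred = decoder(feature, label=None, meta=None, train=False)
#     heatmap = pred['heatmap']  # sigmoid probabilities, shape (N, 1, H, W)
#
# With train=True the decoder instead returns (loss, pred, losses), summing
# the per-task losses produced by calculate_losses. Note that
# backbones/upsample_head.py is not present in this tree, so these decoders
# appear to be experimental.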
192 | 
-------------------------------------------------------------------------------- /demo.py: --------------------------------------------------------------------------------
1 | #!python3
2 | import argparse
3 | import os
4 | import torch
5 | import cv2
6 | import numpy as np
7 | from experiment import Structure, Experiment
8 | from concern.config import Configurable, Config
9 | import math
10 | 
11 | def main():
12 |     parser = argparse.ArgumentParser(description='Text detection demo')
13 |     parser.add_argument('exp', type=str)
14 |     parser.add_argument('--resume', type=str, help='Resume from checkpoint')
15 |     parser.add_argument('--image_path', type=str, help='image path')
16 |     parser.add_argument('--result_dir', type=str, default='./demo_results/', help='path to save results')
17 |     parser.add_argument('--data', type=str,
18 |                         help='The name of dataloader which will be evaluated on.')
19 |     parser.add_argument('--image_short_side', type=int, default=736,
20 |                         help='The target length of the shorter image side; the other side is rounded up to a multiple of 32')
21 |     parser.add_argument('--thresh', type=float,
22 |                         help='The threshold to replace it in the representers')
23 |     parser.add_argument('--box_thresh', type=float, default=0.6,
24 |                         help='The threshold to replace it in the representers')
25 |     parser.add_argument('--visualize', action='store_true',
26 |                         help='visualize maps in tensorboard')
27 |     parser.add_argument('--resize', action='store_true',
28 |                         help='resize')
29 |     parser.add_argument('--polygon', action='store_true',
30 |                         help='output polygons if true')
31 |     parser.add_argument('--eager', '--eager_show', action='store_true', dest='eager_show',
32 |                         help='Show images eagerly')
33 | 
34 |     args = parser.parse_args()
35 |     args = vars(args)
36 |     args = {k: v for k, v in args.items() if v is not None}
37 | 
38 |     conf = Config()
39 |     experiment_args = conf.compile(conf.load(args['exp']))['Experiment']
40 |     experiment_args.update(cmd=args)
41 |     experiment = Configurable.construct_class_from_config(experiment_args)
42 | 
43 |     Demo(experiment, experiment_args, cmd=args).inference(args['image_path'], args['visualize'])
44 | 
45 | 
46 | class Demo:
47 |     def __init__(self, experiment, args, cmd=dict()):
48 |         self.RGB_MEAN = np.array([122.67891434, 116.66876762, 104.00698793])
49 |         self.experiment = experiment
50 |         experiment.load('evaluation', **args)
51 |         self.args = cmd
52 |         model_saver = experiment.train.model_saver
53 |         self.structure = experiment.structure
54 |         self.model_path = self.args['resume']
55 | 
56 |     def init_torch_tensor(self):
57 |         # Use gpu or not
58 |         torch.set_default_tensor_type('torch.FloatTensor')
59 |         if torch.cuda.is_available():
60 |             self.device = torch.device('cuda')
61 |             torch.set_default_tensor_type('torch.cuda.FloatTensor')
62 |         else:
63 |             self.device = torch.device('cpu')
64 | 
65 |     def init_model(self):
66 |         model = self.structure.builder.build(self.device)
67 |         return model
68 | 
69 |     def resume(self, model, path):
70 |         if not os.path.exists(path):
71 |             print("Checkpoint not found: " + path)
72 |             return
73 |         print("Resuming from " + path)
74 |         states = torch.load(
75 |             path, map_location=self.device)
76 |         model.load_state_dict(states, strict=False)
77 |         print("Resumed from " + path)
78 | 
79 |     def resize_image(self, img):
80 |         height, width, _ = img.shape
81 |         if height < width:
82 |             new_height = self.args['image_short_side']
83 |             new_width = int(math.ceil(new_height / height * width / 32) * 32)
84 |         else:
85 |             new_width = self.args['image_short_side']
86 |             new_height = int(math.ceil(new_width / width * height / 32) * 32)
87 |         resized_img = cv2.resize(img, (new_width, new_height))
88 |         return resized_img
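    # Worked example: a 720x1280 input with image_short_side=736 becomes
    # height 736 and width ceil(736 / 720 * 1280 / 32) * 32 = 41 * 32 = 1312,
    # keeping both sides multiples of 32 as the network's stride requires.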
89 | 
90 |     def load_image(self, image_path):
91 |         img = cv2.imread(image_path, cv2.IMREAD_COLOR).astype('float32')
92 |         original_shape = img.shape[:2]
93 |         img = self.resize_image(img)
94 |         img -= self.RGB_MEAN
95 |         img /= 255.
96 |         img = torch.from_numpy(img).permute(2, 0, 1).float().unsqueeze(0)
97 |         return img, original_shape
98 | 
99 |     def format_output(self, batch, output):
100 |         batch_boxes, batch_scores = output
101 |         for index in range(batch['image'].size(0)):
102 |             original_shape = batch['shape'][index]
103 |             filename = batch['filename'][index]
104 |             result_file_name = 'res_' + filename.split('/')[-1].split('.')[0] + '.txt'
105 |             result_file_path = os.path.join(self.args['result_dir'], result_file_name)
106 |             boxes = batch_boxes[index]
107 |             scores = batch_scores[index]
108 |             if self.args['polygon']:
109 |                 with open(result_file_path, 'wt') as res:
110 |                     for i, box in enumerate(boxes):
111 |                         box = np.array(box).reshape(-1).tolist()
112 |                         result = ",".join([str(int(x)) for x in box])
113 |                         score = scores[i]
114 |                         res.write(result + ',' + str(score) + "\n")
115 |             else:
116 |                 with open(result_file_path, 'wt') as res:
117 |                     for i in range(boxes.shape[0]):
118 |                         score = scores[i]
119 |                         if score < self.args['box_thresh']:
120 |                             continue
121 |                         box = boxes[i,:,:].reshape(-1).tolist()
122 |                         result = ",".join([str(int(x)) for x in box])
123 |                         res.write(result + ',' + str(score) + "\n")
124 | 
125 |     def inference(self, image_path, visualize=False):
126 |         self.init_torch_tensor()
127 |         model = self.init_model()
128 |         self.resume(model, self.model_path)
129 |         model.eval()
130 |         batch = dict()
131 |         batch['filename'] = [image_path]
132 |         img, original_shape = self.load_image(image_path)
133 |         batch['shape'] = [original_shape]
134 |         with torch.no_grad():
135 |             batch['image'] = img
136 |             pred = model.forward(batch, training=False)
137 |             output = self.structure.representer.represent(batch, pred, is_output_polygon=self.args['polygon'])
138 |             if not os.path.isdir(self.args['result_dir']):
139 |                 os.mkdir(self.args['result_dir'])
140 |             self.format_output(batch, output)
141 | 
142 |             if visualize and self.structure.visualizer:
143 |                 vis_image = self.structure.visualizer.demo_visualize(image_path, output)
144 |                 cv2.imwrite(os.path.join(self.args['result_dir'], image_path.split('/')[-1].split('.')[0]+'.jpg'), vis_image)
145 | 
146 | if __name__ == '__main__':
147 |     main()
148 | 
-------------------------------------------------------------------------------- /experiment.py: --------------------------------------------------------------------------------
1 | from concern.config import Configurable, State
2 | from concern.log import Logger
3 | from structure.builder import Builder
4 | from structure.representers import *
5 | from structure.measurers import *
6 | from structure.visualizers import *
7 | from data.data_loader import *
8 | from data import *
9 | from training.model_saver import ModelSaver
10 | from training.checkpoint import Checkpoint
11 | from training.optimizer_scheduler import OptimizerScheduler
12 | 
13 | 
14 | class Structure(Configurable):
15 |     builder = State()
16 |     representer = State()
17 |     measurer = State()
18 |     visualizer = State()
19 | 
20 |     def __init__(self, **kwargs):
21 |         self.load_all(**kwargs)
22 | 
23 |     @property
24 |     def model_name(self):
25 |         return self.builder.model_name
26 | 
27 | 
28 | class TrainSettings(Configurable):
29 |     data_loader = State()
30 | 
model_saver = State() 31 | checkpoint = State() 32 | scheduler = State() 33 | epochs = State(default=10) 34 | 35 | def __init__(self, **kwargs): 36 | kwargs['cmd'].update(is_train=True) 37 | self.load_all(**kwargs) 38 | if 'epochs' in kwargs['cmd']: 39 | self.epochs = kwargs['cmd']['epochs'] 40 | 41 | 42 | class ValidationSettings(Configurable): 43 | data_loaders = State() 44 | visualize = State() 45 | interval = State(default=100) 46 | exempt = State(default=-1) 47 | 48 | def __init__(self, **kwargs): 49 | kwargs['cmd'].update(is_train=False) 50 | self.load_all(**kwargs) 51 | 52 | cmd = kwargs['cmd'] 53 | self.visualize = cmd['visualize'] 54 | 55 | 56 | class EvaluationSettings(Configurable): 57 | data_loaders = State() 58 | visualize = State(default=True) 59 | resume = State() 60 | 61 | def __init__(self, **kwargs): 62 | self.load_all(**kwargs) 63 | 64 | 65 | class EvaluationSettings2(Configurable): 66 | structure = State() 67 | data_loaders = State() 68 | 69 | def __init__(self, **kwargs): 70 | self.load_all(**kwargs) 71 | 72 | 73 | class ShowSettings(Configurable): 74 | data_loader = State() 75 | representer = State() 76 | visualizer = State() 77 | 78 | def __init__(self, **kwargs): 79 | self.load_all(**kwargs) 80 | 81 | 82 | class Experiment(Configurable): 83 | structure = State(autoload=False) 84 | train = State() 85 | validation = State(autoload=False) 86 | evaluation = State(autoload=False) 87 | logger = State(autoload=True) 88 | 89 | def __init__(self, **kwargs): 90 | self.load('structure', **kwargs) 91 | 92 | cmd = kwargs.get('cmd', {}) 93 | if 'name' not in cmd: 94 | cmd['name'] = self.structure.model_name 95 | 96 | self.load_all(**kwargs) 97 | self.distributed = cmd.get('distributed', False) 98 | self.local_rank = cmd.get('local_rank', 0) 99 | 100 | if cmd.get('validate', False): 101 | self.load('validation', **kwargs) 102 | else: 103 | self.validation = None 104 | -------------------------------------------------------------------------------- /experiments/ASF/td500_resnet50_deform_thre_asf.yaml: -------------------------------------------------------------------------------- 1 | import: 2 | - 'experiments/seg_detector/base_td500.yaml' 3 | package: [] 4 | define: 5 | - name: 'Experiment' 6 | class: Experiment 7 | structure: 8 | class: Structure 9 | builder: 10 | class: Builder 11 | model: SegDetectorModel 12 | model_args: 13 | backbone: deformable_resnet50 14 | decoder: SegSpatialScaleDetector 15 | decoder_args: 16 | adaptive: True 17 | in_channels: [256, 512, 1024, 2048] 18 | k: 50 19 | attention_type: 'scale_channel_spatial' 20 | loss_class: L1BalanceCELoss 21 | representer: 22 | class: SegDetectorRepresenter 23 | max_candidates: 1000 24 | measurer: 25 | class: QuadMeasurer 26 | visualizer: 27 | class: SegDetectorVisualizer 28 | train: 29 | class: TrainSettings 30 | data_loader: 31 | class: DataLoader 32 | dataset: ^train_data 33 | batch_size: 16 34 | num_workers: 8 35 | checkpoint: 36 | class: Checkpoint 37 | start_epoch: 0 38 | start_iter: 0 39 | resume: null 40 | model_saver: 41 | class: ModelSaver 42 | dir_path: model 43 | save_interval: 2000 44 | signal_path: save 45 | scheduler: 46 | class: OptimizerScheduler 47 | optimizer: "SGD" 48 | optimizer_args: 49 | lr: 0.007 50 | momentum: 0.9 51 | weight_decay: 0.0001 52 | learning_rate: 53 | class: DecayLearningRate 54 | epochs: 1000 55 | epochs: 1000 56 | 57 | validation: &validate 58 | class: ValidationSettings 59 | data_loaders: 60 | icdar2015: 61 | class: DataLoader 62 | dataset: ^validate_data 63 | batch_size: 1 64 | 
num_workers: 0 65 | collect_fn: 66 | class: ICDARCollectFN 67 | visualize: false 68 | interval: 4500 69 | exempt: 1 70 | 71 | logger: 72 | class: Logger 73 | verbose: true 74 | level: info 75 | log_interval: 450 76 | 77 | evaluation: *validate -------------------------------------------------------------------------------- /experiments/base.yaml: -------------------------------------------------------------------------------- 1 | import: [] 2 | package: 3 | - 'experiment' 4 | - 'structure.model' 5 | - 'training.learning_rate' 6 | - 'data' 7 | - 'data.processes' 8 | define: [] 9 | -------------------------------------------------------------------------------- /experiments/seg_detector/base.yaml: -------------------------------------------------------------------------------- 1 | import: 2 | - 'experiments/base.yaml' 3 | package: 4 | - 'decoders.seg_detector_loss' 5 | define: 6 | - name: train_data 7 | class: ImageDataset 8 | data_dir: 9 | - './datasets/icdar2015/' 10 | data_list: 11 | - './datasets/icdar2015/train_list.txt' 12 | processes: 13 | - class: AugmentDetectionData 14 | augmenter_args: 15 | - ['Fliplr', 0.5] 16 | - {'cls': 'Affine', 'rotate': [-10, 10]} 17 | - ['Resize', [0.5, 3.0]] 18 | only_resize: False 19 | keep_ratio: False 20 | - class: RandomCropData 21 | size: [640, 640] 22 | max_tries: 10 23 | - class: MakeICDARData 24 | - class: MakeSegDetectionData 25 | - class: MakeBorderMap 26 | - class: NormalizeImage 27 | - class: FilterKeys 28 | superfluous: ['polygons', 'filename', 'shape', 'ignore_tags', 'is_training'] 29 | 30 | - name: validate_data 31 | class: ImageDataset 32 | data_dir: 33 | - './datasets/icdar2015/' 34 | data_list: 35 | - './datasets/icdar2015/test_list.txt' 36 | processes: 37 | - class: AugmentDetectionData 38 | augmenter_args: 39 | # - ['Resize', {'width': 1280, 'height': 736}] 40 | - ['Resize', {'width': 2048, 'height': 1152}] 41 | only_resize: True 42 | keep_ratio: True 43 | - class: MakeICDARData 44 | - class: MakeSegDetectionData 45 | - class: NormalizeImage 46 | -------------------------------------------------------------------------------- /experiments/seg_detector/base_ic15.yaml: -------------------------------------------------------------------------------- 1 | import: 2 | - 'experiments/base.yaml' 3 | package: 4 | - 'decoders.seg_detector_loss' 5 | define: 6 | - name: train_data 7 | class: ImageDataset 8 | data_dir: 9 | - './datasets/icdar2015/' 10 | data_list: 11 | - './datasets/icdar2015/train_list.txt' 12 | processes: 13 | - class: AugmentDetectionData 14 | augmenter_args: 15 | - ['Fliplr', 0.5] 16 | - {'cls': 'Affine', 'rotate': [-10, 10]} 17 | - ['Resize', [0.5, 3.0]] 18 | only_resize: False 19 | keep_ratio: False 20 | - class: RandomCropData 21 | size: [640, 640] 22 | max_tries: 10 23 | - class: MakeICDARData 24 | - class: MakeSegDetectionData 25 | - class: MakeBorderMap 26 | - class: NormalizeImage 27 | - class: FilterKeys 28 | superfluous: ['polygons', 'filename', 'shape', 'ignore_tags', 'is_training'] 29 | 30 | - name: validate_data 31 | class: ImageDataset 32 | data_dir: 33 | - './datasets/icdar2015/' 34 | data_list: 35 | - './datasets/icdar2015/test_list.txt' 36 | processes: 37 | - class: AugmentDetectionData 38 | augmenter_args: 39 | - ['Resize', {'width': 1280, 'height': 736}] 40 | # - ['Resize', {'width': 2048, 'height': 1152}] 41 | only_resize: True 42 | keep_ratio: False 43 | - class: MakeICDARData 44 | - class: MakeSegDetectionData 45 | - class: NormalizeImage 46 | 
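# Note on the config syntax used across these files (inferred from usage; the
# loader lives in concern/config.py): `import` merges the listed base configs,
# `define` declares named blocks, and a `^` prefix (e.g. `dataset: ^train_data`)
# references a block defined here or in an imported file.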
-------------------------------------------------------------------------------- /experiments/seg_detector/base_td500.yaml: -------------------------------------------------------------------------------- 1 | import: 2 | - 'experiments/base.yaml' 3 | package: 4 | - 'decoders.seg_detector_loss' 5 | define: 6 | - name: train_data 7 | class: ImageDataset 8 | data_dir: 9 | - './datasets/TD_TR/TD500/' 10 | - './datasets/TD_TR/TR400/' 11 | data_list: 12 | - './datasets/TD_TR/TD500/train_list.txt' 13 | - './datasets/TD_TR/TR400/train_list.txt' 14 | processes: 15 | - class: AugmentDetectionData 16 | augmenter_args: 17 | - ['Fliplr', 0.5] 18 | - {'cls': 'Affine', 'rotate': [-10, 10]} 19 | - ['Resize', [0.5, 3.0]] 20 | only_resize: False 21 | keep_ratio: False 22 | - class: RandomCropData 23 | size: [640, 640] 24 | max_tries: 10 25 | - class: MakeICDARData 26 | - class: MakeSegDetectionData 27 | - class: MakeBorderMap 28 | - class: NormalizeImage 29 | - class: FilterKeys 30 | superfluous: ['polygons', 'filename', 'shape', 'ignore_tags', 'is_training'] 31 | 32 | - name: validate_data 33 | class: ImageDataset 34 | data_dir: 35 | - './datasets/TD_TR/TD500/' 36 | data_list: 37 | - './datasets/TD_TR/TD500/test_list.txt' 38 | processes: 39 | - class: AugmentDetectionData 40 | augmenter_args: 41 | - ['Resize', {'width': 736, 'height': 736}] 42 | only_resize: True 43 | keep_ratio: True 44 | - class: MakeICDARData 45 | - class: MakeSegDetectionData 46 | - class: NormalizeImage 47 | -------------------------------------------------------------------------------- /experiments/seg_detector/base_totaltext.yaml: -------------------------------------------------------------------------------- 1 | import: 2 | - 'experiments/base.yaml' 3 | package: 4 | - 'decoders.seg_detector_loss' 5 | define: 6 | - name: train_data 7 | class: ImageDataset 8 | data_dir: 9 | - './datasets/total_text/' 10 | data_list: 11 | - './datasets/total_text/train_list.txt' 12 | processes: 13 | - class: AugmentDetectionData 14 | augmenter_args: 15 | - ['Fliplr', 0.5] 16 | - {'cls': 'Affine', 'rotate': [-10, 10]} 17 | - ['Resize', [0.5, 3.0]] 18 | only_resize: False 19 | keep_ratio: False 20 | - class: RandomCropData 21 | size: [640, 640] 22 | max_tries: 10 23 | - class: MakeICDARData 24 | - class: MakeSegDetectionData 25 | - class: MakeBorderMap 26 | - class: NormalizeImage 27 | - class: FilterKeys 28 | superfluous: ['polygons', 'filename', 'shape', 'ignore_tags', 'is_training'] 29 | 30 | - name: validate_data 31 | class: ImageDataset 32 | data_dir: 33 | - './datasets/total_text/' 34 | data_list: 35 | - './datasets/total_text/test_list.txt' 36 | processes: 37 | - class: AugmentDetectionData 38 | augmenter_args: 39 | - ['Resize', {'width': 800, 'height': 800}] 40 | only_resize: True 41 | keep_ratio: True 42 | - class: MakeICDARData 43 | - class: MakeSegDetectionData 44 | - class: NormalizeImage 45 | -------------------------------------------------------------------------------- /experiments/seg_detector/ic15_resnet18_deform_thre.yaml: -------------------------------------------------------------------------------- 1 | import: 2 | - 'experiments/seg_detector/base_ic15.yaml' 3 | package: [] 4 | define: 5 | - name: 'Experiment' 6 | class: Experiment 7 | structure: 8 | class: Structure 9 | builder: 10 | class: Builder 11 | model: SegDetectorModel 12 | model_args: 13 | backbone: deformable_resnet18 14 | decoder: SegDetector 15 | decoder_args: 16 | adaptive: True 17 | in_channels: [64, 128, 256, 512] 18 | k: 50 19 | loss_class: L1BalanceCELoss 20 | 
21 | representer: 22 | class: SegDetectorRepresenter 23 | max_candidates: 1000 24 | measurer: 25 | class: QuadMeasurer 26 | visualizer: 27 | class: SegDetectorVisualizer 28 | train: 29 | class: TrainSettings 30 | data_loader: 31 | class: DataLoader 32 | dataset: ^train_data 33 | batch_size: 16 34 | num_workers: 16 35 | checkpoint: 36 | class: Checkpoint 37 | start_epoch: 0 38 | start_iter: 0 39 | resume: null 40 | model_saver: 41 | class: ModelSaver 42 | dir_path: model 43 | save_interval: 3000 44 | signal_path: save 45 | scheduler: 46 | class: OptimizerScheduler 47 | optimizer: "SGD" 48 | optimizer_args: 49 | lr: 0.007 50 | momentum: 0.9 51 | weight_decay: 0.0001 52 | learning_rate: 53 | class: DecayLearningRate 54 | epochs: 1200 55 | epochs: 1200 56 | 57 | validation: &validate 58 | class: ValidationSettings 59 | data_loaders: 60 | icdar2015: 61 | class: DataLoader 62 | dataset: ^validate_data 63 | batch_size: 1 64 | num_workers: 16 65 | collect_fn: 66 | class: ICDARCollectFN 67 | visualize: false 68 | interval: 4500 69 | exempt: 1 70 | 71 | logger: 72 | class: Logger 73 | verbose: true 74 | level: info 75 | log_interval: 450 76 | 77 | evaluation: *validate 78 | -------------------------------------------------------------------------------- /experiments/seg_detector/ic15_resnet50_deform_thre.yaml: -------------------------------------------------------------------------------- 1 | import: 2 | - 'experiments/seg_detector/base_ic15.yaml' 3 | package: [] 4 | define: 5 | - name: 'Experiment' 6 | class: Experiment 7 | structure: 8 | class: Structure 9 | builder: 10 | class: Builder 11 | model: SegDetectorModel 12 | model_args: 13 | backbone: deformable_resnet50 14 | decoder: SegDetector 15 | decoder_args: 16 | adaptive: True 17 | in_channels: [256, 512, 1024, 2048] 18 | k: 50 19 | loss_class: L1BalanceCELoss 20 | 21 | representer: 22 | class: SegDetectorRepresenter 23 | max_candidates: 1000 24 | measurer: 25 | class: QuadMeasurer 26 | visualizer: 27 | class: SegDetectorVisualizer 28 | train: 29 | class: TrainSettings 30 | data_loader: 31 | class: DataLoader 32 | dataset: ^train_data 33 | batch_size: 16 34 | num_workers: 16 35 | checkpoint: 36 | class: Checkpoint 37 | start_epoch: 0 38 | start_iter: 0 39 | resume: null 40 | model_saver: 41 | class: ModelSaver 42 | dir_path: model 43 | save_interval: 3000 44 | signal_path: save 45 | scheduler: 46 | class: OptimizerScheduler 47 | optimizer: "SGD" 48 | optimizer_args: 49 | lr: 0.007 50 | momentum: 0.9 51 | weight_decay: 0.0001 52 | learning_rate: 53 | class: DecayLearningRate 54 | epochs: 1200 55 | epochs: 1200 56 | 57 | validation: &validate 58 | class: ValidationSettings 59 | data_loaders: 60 | icdar2015: 61 | class: DataLoader 62 | dataset: ^validate_data 63 | batch_size: 1 64 | num_workers: 16 65 | collect_fn: 66 | class: ICDARCollectFN 67 | visualize: false 68 | interval: 4500 69 | exempt: 1 70 | 71 | logger: 72 | class: Logger 73 | verbose: true 74 | level: info 75 | log_interval: 450 76 | 77 | evaluation: *validate 78 | -------------------------------------------------------------------------------- /experiments/seg_detector/td500_resnet18_deform_thre.yaml: -------------------------------------------------------------------------------- 1 | import: 2 | - 'experiments/seg_detector/base_td500.yaml' 3 | package: [] 4 | define: 5 | - name: 'Experiment' 6 | class: Experiment 7 | structure: 8 | class: Structure 9 | builder: 10 | class: Builder 11 | model: SegDetectorModel 12 | model_args: 13 | backbone: deformable_resnet18 14 | decoder: 
SegDetector 15 | decoder_args: 16 | adaptive: True 17 | in_channels: [64, 128, 256, 512] 18 | k: 50 19 | loss_class: L1BalanceCELoss 20 | 21 | representer: 22 | class: SegDetectorRepresenter 23 | max_candidates: 1000 24 | measurer: 25 | class: QuadMeasurer 26 | visualizer: 27 | class: SegDetectorVisualizer 28 | train: 29 | class: TrainSettings 30 | data_loader: 31 | class: DataLoader 32 | dataset: ^train_data 33 | batch_size: 16 34 | num_workers: 16 35 | checkpoint: 36 | class: Checkpoint 37 | start_epoch: 0 38 | start_iter: 0 39 | resume: null 40 | model_saver: 41 | class: ModelSaver 42 | dir_path: model 43 | save_interval: 18000 44 | signal_path: save 45 | scheduler: 46 | class: OptimizerScheduler 47 | optimizer: "SGD" 48 | optimizer_args: 49 | lr: 0.007 50 | momentum: 0.9 51 | weight_decay: 0.0001 52 | learning_rate: 53 | class: DecayLearningRate 54 | epochs: 1200 55 | epochs: 1200 56 | 57 | validation: &validate 58 | class: ValidationSettings 59 | data_loaders: 60 | icdar2015: 61 | class: DataLoader 62 | dataset: ^validate_data 63 | batch_size: 1 64 | num_workers: 16 65 | collect_fn: 66 | class: ICDARCollectFN 67 | visualize: false 68 | interval: 4500 69 | exempt: 1 70 | 71 | logger: 72 | class: Logger 73 | verbose: true 74 | level: info 75 | log_interval: 450 76 | 77 | evaluation: *validate 78 | -------------------------------------------------------------------------------- /experiments/seg_detector/td500_resnet50_deform_thre.yaml: -------------------------------------------------------------------------------- 1 | import: 2 | - 'experiments/seg_detector/base_td500.yaml' 3 | package: [] 4 | define: 5 | - name: 'Experiment' 6 | class: Experiment 7 | structure: 8 | class: Structure 9 | builder: 10 | class: Builder 11 | model: SegDetectorModel 12 | model_args: 13 | backbone: deformable_resnet50 14 | decoder: SegDetector 15 | decoder_args: 16 | adaptive: True 17 | in_channels: [256, 512, 1024, 2048] 18 | k: 50 19 | loss_class: L1BalanceCELoss 20 | 21 | representer: 22 | class: SegDetectorRepresenter 23 | max_candidates: 1000 24 | measurer: 25 | class: QuadMeasurer 26 | visualizer: 27 | class: SegDetectorVisualizer 28 | train: 29 | class: TrainSettings 30 | data_loader: 31 | class: DataLoader 32 | dataset: ^train_data 33 | batch_size: 16 34 | num_workers: 16 35 | checkpoint: 36 | class: Checkpoint 37 | start_epoch: 0 38 | start_iter: 0 39 | resume: null 40 | model_saver: 41 | class: ModelSaver 42 | dir_path: model 43 | save_interval: 18000 44 | signal_path: save 45 | scheduler: 46 | class: OptimizerScheduler 47 | optimizer: "SGD" 48 | optimizer_args: 49 | lr: 0.007 50 | momentum: 0.9 51 | weight_decay: 0.0001 52 | learning_rate: 53 | class: DecayLearningRate 54 | epochs: 1200 55 | epochs: 1200 56 | 57 | validation: &validate 58 | class: ValidationSettings 59 | data_loaders: 60 | icdar2015: 61 | class: DataLoader 62 | dataset: ^validate_data 63 | batch_size: 1 64 | num_workers: 16 65 | collect_fn: 66 | class: ICDARCollectFN 67 | visualize: false 68 | interval: 4500 69 | exempt: 1 70 | 71 | logger: 72 | class: Logger 73 | verbose: true 74 | level: info 75 | log_interval: 450 76 | 77 | evaluation: *validate 78 | -------------------------------------------------------------------------------- /experiments/seg_detector/totaltext_mobilenet_v3_large_thre.yaml: -------------------------------------------------------------------------------- 1 | import: 2 | - 'experiments/seg_detector/base_totaltext.yaml' 3 | package: [] 4 | define: 5 | - name: 'Experiment' 6 | class: Experiment 7 | structure: 8 | 
class: Structure 9 | builder: 10 | class: Builder 11 | model: SegDetectorModel 12 | model_args: 13 | backbone: mobilenet_v3_large 14 | decoder: SegDetector 15 | decoder_args: 16 | adaptive: True 17 | in_channels: [24, 40, 112, 960] 18 | k: 50 19 | loss_class: L1BalanceCELoss 20 | representer: 21 | class: SegDetectorRepresenter 22 | max_candidates: 1000 23 | measurer: 24 | class: QuadMeasurer 25 | visualizer: 26 | class: SegDetectorVisualizer 27 | train: 28 | class: TrainSettings 29 | data_loader: 30 | class: DataLoader 31 | dataset: ^train_data 32 | batch_size: 8 33 | num_workers: 4 34 | checkpoint: 35 | class: Checkpoint 36 | start_epoch: 0 37 | start_iter: 0 38 | resume: null 39 | model_saver: 40 | class: ModelSaver 41 | dir_path: model 42 | save_interval: 156 43 | signal_path: save 44 | scheduler: 45 | class: OptimizerScheduler 46 | optimizer: "SGD" 47 | optimizer_args: 48 | lr: 0.001 49 | momentum: 0.9 50 | weight_decay: 0.0005 51 | learning_rate: 52 | class: DecayLearningRate 53 | epochs: 100 # 1200 54 | epochs: 100 # 1200 55 | 56 | validation: &validate 57 | class: ValidationSettings 58 | data_loaders: 59 | icdar2015: 60 | class: DataLoader 61 | dataset: ^validate_data 62 | batch_size: 8 63 | num_workers: 4 64 | collect_fn: 65 | class: ICDARCollectFN 66 | visualize: false 67 | interval: 156 68 | exempt: 1 69 | 70 | logger: 71 | class: Logger 72 | verbose: true 73 | level: info 74 | log_interval: 156 75 | 76 | evaluation: *validate 77 | -------------------------------------------------------------------------------- /experiments/seg_detector/totaltext_resnet18_deform_thre.yaml: -------------------------------------------------------------------------------- 1 | import: 2 | - 'experiments/seg_detector/base_totaltext.yaml' 3 | package: [] 4 | define: 5 | - name: 'Experiment' 6 | class: Experiment 7 | structure: 8 | class: Structure 9 | builder: 10 | class: Builder 11 | model: SegDetectorModel 12 | model_args: 13 | backbone: deformable_resnet18 14 | decoder: SegDetector 15 | decoder_args: 16 | adaptive: True 17 | in_channels: [64, 128, 256, 512] 18 | k: 50 19 | loss_class: L1BalanceCELoss 20 | representer: 21 | class: SegDetectorRepresenter 22 | max_candidates: 1000 23 | measurer: 24 | class: QuadMeasurer 25 | visualizer: 26 | class: SegDetectorVisualizer 27 | train: 28 | class: TrainSettings 29 | data_loader: 30 | class: DataLoader 31 | dataset: ^train_data 32 | batch_size: 16 33 | num_workers: 16 34 | checkpoint: 35 | class: Checkpoint 36 | start_epoch: 0 37 | start_iter: 0 38 | resume: null 39 | model_saver: 40 | class: ModelSaver 41 | dir_path: model 42 | save_interval: 18000 43 | signal_path: save 44 | scheduler: 45 | class: OptimizerScheduler 46 | optimizer: "SGD" 47 | optimizer_args: 48 | lr: 0.007 49 | momentum: 0.9 50 | weight_decay: 0.0001 51 | learning_rate: 52 | class: DecayLearningRate 53 | epochs: 1200 # 1200 54 | epochs: 1200 # 1200 55 | 56 | validation: &validate 57 | class: ValidationSettings 58 | data_loaders: 59 | icdar2015: 60 | class: DataLoader 61 | dataset: ^validate_data 62 | batch_size: 1 63 | num_workers: 16 64 | collect_fn: 65 | class: ICDARCollectFN 66 | visualize: false 67 | interval: 4500 68 | exempt: 1 69 | 70 | logger: 71 | class: Logger 72 | verbose: true 73 | level: info 74 | log_interval: 450 75 | 76 | evaluation: *validate 77 | -------------------------------------------------------------------------------- /experiments/seg_detector/totaltext_resnet50_deform_thre.yaml: -------------------------------------------------------------------------------- 1 | 
import: 2 | - 'experiments/seg_detector/base_totaltext.yaml' 3 | package: [] 4 | define: 5 | - name: 'Experiment' 6 | class: Experiment 7 | structure: 8 | class: Structure 9 | builder: 10 | class: Builder 11 | model: SegDetectorModel 12 | model_args: 13 | backbone: deformable_resnet50 14 | decoder: SegDetector 15 | decoder_args: 16 | adaptive: True 17 | in_channels: [256, 512, 1024, 2048] 18 | k: 50 19 | loss_class: L1BalanceCELoss 20 | representer: 21 | class: SegDetectorRepresenter 22 | max_candidates: 1000 23 | measurer: 24 | class: QuadMeasurer 25 | visualizer: 26 | class: SegDetectorVisualizer 27 | train: 28 | class: TrainSettings 29 | data_loader: 30 | class: DataLoader 31 | dataset: ^train_data 32 | batch_size: 16 33 | num_workers: 16 34 | checkpoint: 35 | class: Checkpoint 36 | start_epoch: 0 37 | start_iter: 0 38 | resume: null 39 | model_saver: 40 | class: ModelSaver 41 | dir_path: model 42 | save_interval: 18000 43 | signal_path: save 44 | scheduler: 45 | class: OptimizerScheduler 46 | optimizer: "SGD" 47 | optimizer_args: 48 | lr: 0.007 49 | momentum: 0.9 50 | weight_decay: 0.0001 51 | learning_rate: 52 | class: DecayLearningRate 53 | epochs: 1200 # 1200 54 | epochs: 1200 # 1200 55 | 56 | validation: &validate 57 | class: ValidationSettings 58 | data_loaders: 59 | icdar2015: 60 | class: DataLoader 61 | dataset: ^validate_data 62 | batch_size: 1 63 | num_workers: 16 64 | collect_fn: 65 | class: ICDARCollectFN 66 | visualize: false 67 | interval: 4500 68 | exempt: 1 69 | 70 | logger: 71 | class: Logger 72 | verbose: true 73 | level: info 74 | log_interval: 450 75 | 76 | evaluation: *validate 77 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | pyyaml 2 | tqdm 3 | tensorboardX 4 | opencv-python==4.1.2.30 5 | anyconfig 6 | munch 7 | scipy 8 | sortedcontainers 9 | shapely 10 | pyclipper 11 | gevent 12 | gevent-websocket 13 | flask 14 | editdistance 15 | scikit-image 16 | imgaug==0.2.8 17 | 18 | -------------------------------------------------------------------------------- /structure/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MhLiao/DB/65ca77a0bcfbd7114b916cf8a1e9ca85114286ce/structure/__init__.py -------------------------------------------------------------------------------- /structure/builder.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | import structure.model 6 | from concern.config import Configurable, State 7 | 8 | 9 | class Builder(Configurable): 10 | model = State() 11 | model_args = State() 12 | 13 | def __init__(self, cmd={}, **kwargs): 14 | self.load_all(**kwargs) 15 | if 'backbone' in cmd: 16 | self.model_args['backbone'] = cmd['backbone'] 17 | 18 | @property 19 | def model_name(self): 20 | return self.model + '-' + getattr(structure.model, self.model).model_name(self.model_args) 21 | 22 | def build(self, device, distributed=False, local_rank: int = 0): 23 | Model = getattr(structure.model,self.model) 24 | model = Model(self.model_args, device, 25 | distributed=distributed, local_rank=local_rank) 26 | return model 27 | 28 | -------------------------------------------------------------------------------- /structure/measurers/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.icdar_detection_measurer import ICDARDetectionMeasurer 2 | from .quad_measurer import QuadMeasurer 3 | 4 | -------------------------------------------------------------------------------- /structure/measurers/icdar_detection_measurer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import shutil 4 | 5 | import numpy as np 6 | import json 7 | 8 | from concern import Logger, AverageMeter 9 | from concern.config import Configurable 10 | 11 | 12 | class ICDARDetectionMeasurer(Configurable): 13 | def __init__(self, **kwargs): 14 | self.visualized = False 15 | 16 | def measure(self, batch, output): 17 | pairs = [] 18 | for i in range(len(batch[-1])): 19 | pairs.append((batch[-1][i], output[i][0])) 20 | return pairs 21 | 22 | def validate_measure(self, batch, output): 23 | return self.measure(batch, output), [int(self.visualized)] 24 | 25 | def evaluate_measure(self, batch, output): 26 | return self.measure(batch, output), np.linspace(0, batch[0].shape[0]).tolist() 27 | 28 | def gather_measure(self, name, raw_metrics, logger: Logger): 29 | save_dir = os.path.join(logger.log_dir, name) 30 | shutil.rmtree(save_dir, ignore_errors=True) 31 | if not os.path.exists(save_dir): 32 | os.makedirs(save_dir) 33 | log_file_path = os.path.join(save_dir, name + '.log') 34 | count = 0 35 | for batch_pairs in raw_metrics: 36 | for _filename, boxes in batch_pairs: 37 | boxes = np.array(boxes).reshape(-1, 8).astype(np.int32) 38 | filename = 'res_' + _filename.replace('.jpg', '.txt') 39 | with open(os.path.join(save_dir, filename), 'wt') as f: 40 | if len(boxes) == 0: 41 | f.write('') 42 | for box in boxes: 43 | f.write(','.join(map(str, box)) + '\n') 44 | count += 1 45 | 46 | self.packing(save_dir) 47 | try: 48 | raw_out = subprocess.check_output(['python assets/ic15_eval/script.py -m=' + name 49 | + ' -g=assets/ic15_eval/gt.zip -s=' + 50 | os.path.join(save_dir, 'submit.zip') + 51 | '|tee -a ' + log_file_path], 52 | timeout=30, shell=True) 53 | except subprocess.TimeoutExpired: 54 | return {} 55 | raw_out = raw_out.decode().replace('Calculated!', '') 56 | dict_out = json.loads(raw_out) 57 | return {k: AverageMeter().update(v, n=count) for k, v in dict_out.items()} 58 | 59 | def packing(self, save_dir): 60 | pack_name = 'submit.zip' 61 | os.system( 62 | 'zip -r -j -q ' + 63 | os.path.join(save_dir, pack_name) + ' ' + save_dir + '/*.txt') 64 | -------------------------------------------------------------------------------- /structure/measurers/quad_measurer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from concern import Logger, AverageMeter 4 | from concern.config import Configurable 5 | from concern.icdar2015_eval.detection.iou import DetectionIoUEvaluator 6 | 7 | 8 | class QuadMeasurer(Configurable): 9 | def __init__(self, **kwargs): 10 | self.evaluator = DetectionIoUEvaluator() 11 | 12 | def measure(self, batch, output, is_output_polygon=False, box_thresh=0.6): 13 | ''' 14 | batch: (image, polygons, ignore_tags 15 | batch: a dict produced by dataloaders. 16 | image: tensor of shape (N, C, H, W). 17 | polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions. 18 | ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not. 19 | shape: the original shape of images. 20 | filename: the original filenames of images. 21 | output: (polygons, ...) 
--------------------------------------------------------------------------------
/structure/measurers/quad_measurer.py:
--------------------------------------------------------------------------------
import numpy as np

from concern import Logger, AverageMeter
from concern.config import Configurable
from concern.icdar2015_eval.detection.iou import DetectionIoUEvaluator


class QuadMeasurer(Configurable):
    def __init__(self, **kwargs):
        self.evaluator = DetectionIoUEvaluator()

    def measure(self, batch, output, is_output_polygon=False, box_thresh=0.6):
        '''
        batch: a dict produced by dataloaders.
            image: tensor of shape (N, C, H, W).
            polygons: tensor of shape (N, K, 4, 2), the polygons of objective regions.
            ignore_tags: tensor of shape (N, K), indicates whether a region is ignorable or not.
            shape: the original shape of images.
            filename: the original filenames of images.
        output: a pair (polygons, scores) produced by the representer.
        '''
        results = []
        gt_polygons_batch = batch['polygons']
        ignore_tags_batch = batch['ignore_tags']
        pred_polygons_batch = np.array(output[0])
        pred_scores_batch = np.array(output[1])
        for polygons, pred_polygons, pred_scores, ignore_tags in\
                zip(gt_polygons_batch, pred_polygons_batch, pred_scores_batch, ignore_tags_batch):
            gt = [dict(points=polygons[i], ignore=ignore_tags[i])
                  for i in range(len(polygons))]
            if is_output_polygon:
                pred = [dict(points=pred_polygons[i])
                        for i in range(len(pred_polygons))]
            else:
                pred = []
                for i in range(pred_polygons.shape[0]):
                    if pred_scores[i] >= box_thresh:
                        pred.append(dict(points=pred_polygons[i, :, :].tolist()))
            results.append(self.evaluator.evaluate_image(gt, pred))
        return results

    def validate_measure(self, batch, output, is_output_polygon=False, box_thresh=0.6):
        return self.measure(batch, output, is_output_polygon, box_thresh)

    def evaluate_measure(self, batch, output):
        # One entry per image; the original np.linspace call always produced
        # 50 samples regardless of the batch size.
        return self.measure(batch, output),\
            np.arange(batch['image'].shape[0]).tolist()

    def gather_measure(self, raw_metrics, logger: Logger):
        raw_metrics = [image_metrics
                       for batch_metrics in raw_metrics
                       for image_metrics in batch_metrics]

        result = self.evaluator.combine_results(raw_metrics)

        precision = AverageMeter()
        recall = AverageMeter()
        fmeasure = AverageMeter()

        precision.update(result['precision'], n=len(raw_metrics))
        recall.update(result['recall'], n=len(raw_metrics))
        fmeasure_score = 2 * precision.val * recall.val /\
            (precision.val + recall.val + 1e-8)
        fmeasure.update(fmeasure_score)

        return {
            'precision': precision,
            'recall': recall,
            'fmeasure': fmeasure
        }
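# Illustrative (toy data): the shapes QuadMeasurer.measure expects, per its
# docstring above. One image, one ground-truth quad, and one matching
# prediction; the coordinates and score are made up.
import numpy as np
from structure.measurers import QuadMeasurer

measurer = QuadMeasurer()
quad = np.array([[[0, 0], [10, 0], [10, 5], [0, 5]]], dtype=np.float32)  # (K=1, 4, 2)
batch = {'polygons': [quad], 'ignore_tags': [[False]]}
output = ([quad], [np.array([0.9])])  # (pred polygons, pred scores), per image
results = measurer.measure(batch, output, is_output_polygon=False, box_thresh=0.6)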
--------------------------------------------------------------------------------
/structure/model.py:
--------------------------------------------------------------------------------
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

import backbones
import decoders


class BasicModel(nn.Module):
    def __init__(self, args):
        nn.Module.__init__(self)

        self.backbone = getattr(backbones, args['backbone'])(**args.get('backbone_args', {}))
        self.decoder = getattr(decoders, args['decoder'])(**args.get('decoder_args', {}))

    def forward(self, data, *args, **kwargs):
        return self.decoder(self.backbone(data), *args, **kwargs)


def parallelize(model, distributed, local_rank):
    if distributed:
        return nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,  # expects an int, not a list
            find_unused_parameters=True)
    else:
        return nn.DataParallel(model)


class SegDetectorModel(nn.Module):
    def __init__(self, args, device, distributed: bool = False, local_rank: int = 0):
        super(SegDetectorModel, self).__init__()
        from decoders.seg_detector_loss import SegDetectorLossBuilder

        self.model = BasicModel(args)
        # for loading models
        self.model = parallelize(self.model, distributed, local_rank)
        self.criterion = SegDetectorLossBuilder(
            args['loss_class'], *args.get('loss_args', []),
            **args.get('loss_kwargs', {})).build()
        self.criterion = parallelize(self.criterion, distributed, local_rank)
        self.device = device
        self.to(self.device)

    @staticmethod
    def model_name(args):
        return os.path.join('seg_detector', args['backbone'], args['loss_class'])

    def forward(self, batch, training=True):
        if isinstance(batch, dict):
            data = batch['image'].to(self.device)
        else:
            data = batch.to(self.device)
        data = data.float()
        # NOTE: branching below follows self.training (toggled by
        # model.train()/model.eval()), not the `training` argument.
        pred = self.model(data, training=self.training)

        if self.training:
            for key, value in batch.items():
                if value is not None:
                    if hasattr(value, 'to'):
                        batch[key] = value.to(self.device)
            loss, metrics = self.criterion(pred, batch)
            return loss, pred, metrics
        return pred
--------------------------------------------------------------------------------
/structure/representers/__init__.py:
--------------------------------------------------------------------------------
from .seg_detector_representer import SegDetectorRepresenter
--------------------------------------------------------------------------------
/structure/representers/setup.py:
--------------------------------------------------------------------------------
# setup.py
from setuptools import setup, Extension  # distutils.core is deprecated since Python 3.10
from Cython.Build import cythonize
import numpy

setup(ext_modules=cythonize(Extension(
    'boxes_from_map',
    sources=['boxes_from_map.pyx'],
    language='c',
    include_dirs=[numpy.get_include()],
    library_dirs=[],
    libraries=[],
    extra_compile_args=[],
    extra_link_args=[]
)))
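# Illustrative: the Cython extension defined in setup.py above is built with
# the standard Cython/setuptools invocation (run inside structure/representers/):
#
#   python setup.py build_ext --inplace
#
# which produces a compiled boxes_from_map module next to the sources.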
--------------------------------------------------------------------------------
/structure/visualizers/__init__.py:
--------------------------------------------------------------------------------
from .seg_detector_visualizer import SegDetectorVisualizer
--------------------------------------------------------------------------------
/structure/visualizers/seg_detector_visualizer.py:
--------------------------------------------------------------------------------
import cv2
import concern.webcv2 as webcv2
import numpy as np
import torch

from concern.config import Configurable, State
from data.processes.make_icdar_data import MakeICDARData


class SegDetectorVisualizer(Configurable):
    vis_num = State(default=4)
    eager_show = State(default=False)

    def __init__(self, **kwargs):
        cmd = kwargs['cmd']
        if 'eager_show' in cmd:
            self.eager_show = cmd['eager_show']

    def visualize(self, batch, output_pair, pred):
        boxes, _ = output_pair
        result_dict = {}
        for i in range(batch['image'].size(0)):
            result_dict.update(
                self.single_visualize(batch, i, boxes[i], pred))
        if self.eager_show:
            webcv2.waitKey()
            return {}
        return result_dict

    def _visualize_heatmap(self, heatmap, canvas=None):
        if isinstance(heatmap, torch.Tensor):
            heatmap = heatmap.cpu().numpy()
        heatmap = (heatmap[0] * 255).astype(np.uint8)
        if canvas is None:
            pred_image = heatmap
        else:
            pred_image = (heatmap.reshape(
                *heatmap.shape[:2], 1).astype(np.float32) / 255 + 1) / 2 * canvas
            pred_image = pred_image.astype(np.uint8)
        return pred_image

    def single_visualize(self, batch, index, boxes, pred):
        image = batch['image'][index]
        polygons = batch['polygons'][index]
        if isinstance(polygons, torch.Tensor):
            polygons = polygons.cpu().data.numpy()
        ignore_tags = batch['ignore_tags'][index]
        original_shape = batch['shape'][index]
        filename = batch['filename'][index]
        std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
        mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
        image = (image.cpu().numpy() * std + mean).transpose(1, 2, 0) * 255
        pred_canvas = image.copy().astype(np.uint8)
        pred_canvas = cv2.resize(pred_canvas, (original_shape[1], original_shape[0]))

        if isinstance(pred, dict) and 'thresh' in pred:
            thresh = self._visualize_heatmap(pred['thresh'][index])

        if isinstance(pred, dict) and 'thresh_binary' in pred:
            thresh_binary = self._visualize_heatmap(pred['thresh_binary'][index])
            MakeICDARData.polylines(self, thresh_binary, polygons, ignore_tags)

        for box in boxes:
            box = np.array(box).astype(np.int32).reshape(-1, 2)
            cv2.polylines(pred_canvas, [box], True, (0, 255, 0), 2)
            if isinstance(pred, dict) and 'thresh_binary' in pred:
                cv2.polylines(thresh_binary, [box], True, (0, 255, 0), 1)

        if self.eager_show:
            webcv2.imshow(filename + ' output', cv2.resize(pred_canvas, (1024, 1024)))
            if isinstance(pred, dict) and 'thresh' in pred:
                webcv2.imshow(filename + ' thresh', cv2.resize(thresh, (1024, 1024)))
            if isinstance(pred, dict) and 'thresh_binary' in pred:
                webcv2.imshow(filename + ' thresh_binary', cv2.resize(thresh_binary, (1024, 1024)))
            return {}
        else:
            if isinstance(pred, dict) and 'thresh' in pred:
                return {
                    filename + '_output': pred_canvas,
                    filename + '_thresh': thresh,
                }
            else:
                return {
                    filename + '_output': pred_canvas,
                }

    def demo_visualize(self, image_path, output):
        boxes, _ = output
        boxes = boxes[0]
        original_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        original_shape = original_image.shape
        pred_canvas = original_image.copy().astype(np.uint8)
        pred_canvas = cv2.resize(pred_canvas, (original_shape[1], original_shape[0]))

        for box in boxes:
            box = np.array(box).astype(np.int32).reshape(-1, 2)
            cv2.polylines(pred_canvas, [box], True, (0, 255, 0), 2)

        return pred_canvas
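# Illustrative use of demo_visualize above (hypothetical image path, toy box
# and score; the output pair mimics what the representer produces):
import cv2

visualizer = SegDetectorVisualizer(cmd={})
boxes_batch = [[[38, 43, 920, 43, 920, 123, 38, 123]]]  # one image, one quad
scores_batch = [[0.9]]
canvas = visualizer.demo_visualize('datasets/demo.jpg', (boxes_batch, scores_batch))
cv2.imwrite('demo_result.jpg', canvas)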
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
#!python3
import argparse
import time

import torch
import yaml

from trainer import Trainer
# tagged yaml objects
from experiment import Structure, TrainSettings, ValidationSettings, Experiment
from concern.log import Logger
from data.data_loader import DataLoader
from data.image_dataset import ImageDataset
from training.checkpoint import Checkpoint
from training.model_saver import ModelSaver
from training.optimizer_scheduler import OptimizerScheduler
from concern.config import Configurable, Config


def main():
    parser = argparse.ArgumentParser(description='Text Detection Training')
    parser.add_argument('exp', type=str)
    parser.add_argument('--name', type=str)
    parser.add_argument('--batch_size', type=int, help='Batch size for training')
    parser.add_argument('--resume', type=str, help='Resume from checkpoint')
    parser.add_argument('--epochs', type=int, help='Number of training epochs')
    parser.add_argument('--num_workers', type=int, help='Number of dataloader workers')
    parser.add_argument('--start_iter', type=int,
                        help='Begin counting iterations starting from this value (should be used with resume)')
    parser.add_argument('--start_epoch', type=int,
                        help='Begin counting epochs starting from this value (should be used with resume)')
    parser.add_argument('--max_size', type=int, help='max length of label')
    parser.add_argument('--lr', type=float, help='initial learning rate')
    parser.add_argument('--optimizer', type=str, help='The optimizer to use')
    parser.add_argument('--thresh', type=float, help='Threshold override for the representers')
    parser.add_argument('--verbose', action='store_true', help='show verbose info')
    parser.add_argument('--visualize', action='store_true', help='visualize maps in tensorboard')
    parser.add_argument('--force_reload', action='store_true', dest='force_reload', help='Force reload data meta')
    parser.add_argument('--no-force_reload', action='store_false', dest='force_reload', help='Do not force reload data meta')
    parser.add_argument('--validate', action='store_true', dest='validate', help='Validate during training')
    parser.add_argument('--no-validate', action='store_false', dest='validate', help='Do not validate during training')
    parser.add_argument('--print-config-only', action='store_true', help='print config without actual training')
    parser.add_argument('--debug', action='store_true', dest='debug', help='Run with debug mode, which hacks dataset num_samples to a toy number')
    parser.add_argument('--no-debug', action='store_false', dest='debug', help='Run without debug mode')
    parser.add_argument('--benchmark', action='store_true', dest='benchmark', help='Open cudnn benchmark mode')
    parser.add_argument('--no-benchmark', action='store_false', dest='benchmark', help='Turn cudnn benchmark mode off')
    parser.add_argument('-d', '--distributed', action='store_true', dest='distributed', help='Use distributed training')
    parser.add_argument('--local_rank', dest='local_rank', default=0, type=int, help='Local rank for distributed training')
    parser.add_argument('-g', '--num_gpus', dest='num_gpus', default=4, type=int, help='The number of accessible gpus')
    parser.set_defaults(debug=False)
    parser.set_defaults(benchmark=True)

    args = parser.parse_args()
    args = vars(args)
    args = {k: v for k, v in args.items() if v is not None}

    if args['distributed']:
        torch.cuda.set_device(args['local_rank'])
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    conf = Config()
    experiment_args = conf.compile(conf.load(args['exp']))['Experiment']
    experiment_args.update(cmd=args)
    experiment = Configurable.construct_class_from_config(experiment_args)

    if not args['print_config_only']:
        torch.backends.cudnn.benchmark = args['benchmark']
        trainer = Trainer(experiment)
        trainer.train()


if __name__ == '__main__':
    main()
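# Illustrative invocations of train.py (the YAML paths come from the
# experiments/ tree; flag names match the argparse definitions above):
#
#   # single-process training
#   python train.py experiments/seg_detector/totaltext_resnet50_deform_thre.yaml --num_gpus 1
#
#   # distributed training; the launcher supplies --local_rank to each process
#   python -m torch.distributed.launch --nproc_per_node=4 train.py \
#       experiments/seg_detector/totaltext_resnet50_deform_thre.yaml -d --num_gpus 4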
--------------------------------------------------------------------------------
/training/checkpoint.py:
--------------------------------------------------------------------------------
from concern.config import Configurable, State
import os
import torch


class Checkpoint(Configurable):
    start_epoch = State(default=0)
    start_iter = State(default=0)
    resume = State()

    def __init__(self, **kwargs):
        self.load_all(**kwargs)

        cmd = kwargs['cmd']
        if 'start_epoch' in cmd:
            self.start_epoch = cmd['start_epoch']
        if 'start_iter' in cmd:
            self.start_iter = cmd['start_iter']
        if 'resume' in cmd:
            self.resume = cmd['resume']

    def restore_model(self, model, device, logger):
        if self.resume is None:
            return

        if not os.path.exists(self.resume):
            # use the logger passed in; self.logger does not exist here
            logger.warning("Checkpoint not found: " + self.resume)
            return

        logger.info("Resuming from " + self.resume)
        state_dict = torch.load(self.resume, map_location=device)
        model.load_state_dict(state_dict, strict=False)
        logger.info("Resumed from " + self.resume)

    def restore_counter(self):
        return self.start_epoch, self.start_iter
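# Illustrative Checkpoint usage (hypothetical checkpoint path; the name format
# matches ModelSaver.make_checkpoint_name further below):
ckpt = Checkpoint(cmd={'resume': 'workspace/model/model_epoch_100_minibatch_50000'})
start_epoch, start_iter = ckpt.restore_counter()
# ckpt.restore_model(model, device, logger) would then load weights with strict=False.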
--------------------------------------------------------------------------------
/training/learning_rate.py:
--------------------------------------------------------------------------------
from bisect import bisect_right
import numpy as np
import torch.optim.lr_scheduler as lr_scheduler

from concern.config import Configurable, State
from concern.signal_monitor import SignalMonitor


class ConstantLearningRate(Configurable):
    lr = State(default=0.0001)

    def __init__(self, **kwargs):
        self.load_all(**kwargs)

    def get_learning_rate(self, epoch, step):
        return self.lr


class FileMonitorLearningRate(Configurable):
    file_path = State()

    def __init__(self, **kwargs):
        self.load_all(**kwargs)

        self.monitor = SignalMonitor(self.file_path)

    def get_learning_rate(self, epoch, step):
        signal = self.monitor.get_signal()
        if signal is not None:
            return float(signal)
        return None


class PriorityLearningRate(Configurable):
    learning_rates = State()

    def __init__(self, **kwargs):
        self.load_all(**kwargs)

    def get_learning_rate(self, epoch, step):
        for learning_rate in self.learning_rates:
            lr = learning_rate.get_learning_rate(epoch, step)
            if lr is not None:
                return lr
        return None


class MultiStepLR(Configurable):
    lr = State()
    milestones = State(default=[])  # milestones must be sorted
    gamma = State(default=0.1)

    def __init__(self, cmd={}, **kwargs):
        self.load_all(**kwargs)
        self.lr = cmd.get('lr', self.lr)

    def get_learning_rate(self, epoch, step):
        return self.lr * self.gamma ** bisect_right(self.milestones, epoch)


class WarmupLR(Configurable):
    steps = State(default=4000)
    warmup_lr = State(default=1e-5)
    origin_lr = State()

    def __init__(self, cmd={}, **kwargs):
        self.load_all(**kwargs)

    def get_learning_rate(self, epoch, step):
        if epoch == 0 and step < self.steps:
            return self.warmup_lr
        return self.origin_lr.get_learning_rate(epoch, step)


class PiecewiseConstantLearningRate(Configurable):
    boundaries = State(default=[10000, 20000])
    values = State(default=[0.001, 0.0001, 0.00001])

    def __init__(self, **kwargs):
        self.load_all(**kwargs)

    def get_learning_rate(self, epoch, step):
        for boundary, value in zip(self.boundaries, self.values[:-1]):
            if step < boundary:
                return value
        return self.values[-1]


class DecayLearningRate(Configurable):
    lr = State(default=0.007)
    epochs = State(default=1200)
    factor = State(default=0.9)

    def __init__(self, **kwargs):
        self.load_all(**kwargs)

    def get_learning_rate(self, epoch, step=None):
        rate = np.power(1.0 - epoch / float(self.epochs + 1), self.factor)
        return rate * self.lr


class BuitlinLearningRate(Configurable):
    # NOTE: the misspelled name ("Buitlin" for "Builtin") is kept because
    # experiment configs may reference the class by this exact name.
    lr = State(default=0.001)
    klass = State(default='StepLR')
    args = State(default=[])
    kwargs = State(default={})

    def __init__(self, cmd={}, **kwargs):
        self.load_all(**kwargs)
        self.lr = cmd.get('lr', None) or self.lr
        self.scheduler = None

    def prepare(self, optimizer):
        self.scheduler = getattr(lr_scheduler, self.klass)(
            optimizer, *self.args, **self.kwargs)

    def get_learning_rate(self, epoch, step=None):
        if self.scheduler is None:
            # raising a plain string is a TypeError in Python 3
            raise RuntimeError('learning rate not ready (prepare it with an optimizer first)')
        self.scheduler.last_epoch = epoch
        # get_lr returns a list with one learning rate per parameter group.
        return self.scheduler.get_lr()[0]
--------------------------------------------------------------------------------
/training/model_saver.py:
--------------------------------------------------------------------------------
import os

import torch

from concern.config import Configurable, State
from concern.signal_monitor import SignalMonitor


class ModelSaver(Configurable):
    dir_path = State()
    save_interval = State(default=1000)
    signal_path = State()

    def __init__(self, **kwargs):
        self.load_all(**kwargs)

        # BUG: signal path should not be global
        self.monitor = SignalMonitor(self.signal_path)

    def maybe_save_model(self, model, epoch, step, logger):
        if step % self.save_interval == 0 or self.monitor.get_signal() is not None:
            self.save_model(model, epoch, step)
            logger.report_time('Saving')
            logger.iter(step)

    def save_model(self, model, epoch=None, step=None):
        if isinstance(model, dict):
            for name, net in model.items():
                checkpoint_name = self.make_checkpoint_name(name, epoch, step)
                self.save_checkpoint(net, checkpoint_name)
        else:
            checkpoint_name = self.make_checkpoint_name('model', epoch, step)
            self.save_checkpoint(model, checkpoint_name)

    def save_checkpoint(self, net, name):
        os.makedirs(self.dir_path, exist_ok=True)
        torch.save(net.state_dict(), os.path.join(self.dir_path, name))

    def make_checkpoint_name(self, name, epoch=None, step=None):
        if epoch is None or step is None:
            c_name = name + '_latest'
        else:
            c_name = '{}_epoch_{}_minibatch_{}'.format(name, epoch, step)
        return c_name
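# Illustrative: DecayLearningRate above implements the "poly" schedule
#   lr(epoch) = lr * (1 - epoch / (epochs + 1)) ** factor
# With the defaults (lr=0.007, epochs=1200, factor=0.9):
sched = DecayLearningRate()
print(sched.get_learning_rate(0))     # 0.007 (the base learning rate)
print(sched.get_learning_rate(600))   # ~0.0038, roughly half of the base lr
print(sched.get_learning_rate(1200))  # ~1.2e-05, near zero at the end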
--------------------------------------------------------------------------------
/training/optimizer_scheduler.py:
--------------------------------------------------------------------------------
import torch

from concern.config import Configurable, State


class OptimizerScheduler(Configurable):
    optimizer = State()
    optimizer_args = State(default={})
    learning_rate = State(autoload=False)

    def __init__(self, cmd={}, **kwargs):
        self.load_all(**kwargs)
        self.load('learning_rate', cmd=cmd, **kwargs)
        if 'lr' in cmd:
            self.optimizer_args['lr'] = cmd['lr']

    def create_optimizer(self, parameters):
        optimizer = getattr(torch.optim, self.optimizer)(
            parameters, **self.optimizer_args)
        if hasattr(self.learning_rate, 'prepare'):
            self.learning_rate.prepare(optimizer)
        return optimizer
--------------------------------------------------------------------------------
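# Illustrative (toy model): OptimizerScheduler resolves the optimizer class by
# name from torch.optim. Leaving the learning_rate state unset is assumed to be
# acceptable here, since create_optimizer only calls prepare() when present;
# in the repo it normally comes from the experiment YAML instead.
import torch.nn as nn

scheduler = OptimizerScheduler(cmd={}, optimizer='SGD',
                               optimizer_args={'lr': 0.007, 'momentum': 0.9,
                                               'weight_decay': 0.0001})
optimizer = scheduler.create_optimizer(nn.Linear(4, 2).parameters())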