├── det_images
    ├── 0.png
    ├── 1.png
    ├── 5.png
    └── 6.jpg
├── rec_images
    ├── 1.png
    ├── 10.png
    ├── 11.png
    ├── 12.png
    ├── 13.png
    ├── 14.png
    ├── 15.png
    ├── 16.png
    ├── 17.png
    ├── 18.png
    ├── 19.png
    ├── 2.png
    ├── 20.png
    ├── 21.png
    ├── 22.png
    ├── 27.png
    ├── 5.png
    ├── 7.png
    ├── 8.png
    └── 9.png
├── rec
    ├── __pycache__
    │   ├── RNN.cpython-37.pyc
    │   ├── RecSVTR.cpython-37.pyc
    │   ├── RecModel.cpython-37.pyc
    │   ├── RecCTCHead.cpython-37.pyc
    │   ├── RecSARHead.cpython-37.pyc
    │   └── RecMv1_enhance.cpython-37.pyc
    ├── RecModel.py
    ├── RecCTCHead.py
    ├── RNN.py
    ├── RecMv1_enhance.py
    ├── RecSARHead.py
    └── RecSVTR.py
├── det
    ├── __pycache__
    │   ├── DB_fpn.cpython-37.pyc
    │   ├── DetDbHead.cpython-37.pyc
    │   ├── DetModel.cpython-37.pyc
    │   ├── CommonModules.cpython-37.pyc
    │   └── DetMobilenetV3.cpython-37.pyc
    ├── DetModel.py
    ├── DetDbHead.py
    ├── CommonModules.py
    ├── DetMobilenetV3.py
    └── DB_fpn.py
├── README.md
├── paddle2torch_ppocrv3_det.py
├── torch_rec_infer.py
├── paddle2torch_ppocrv3_rec.py
├── torch_det_infer.py
├── onnx_infer.py
└── weights
    └── ppocr_keys_v1.txt

/det_images/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/det_images/0.png
--------------------------------------------------------------------------------

/det_images/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/det_images/1.png
--------------------------------------------------------------------------------

/det_images/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/det_images/5.png
--------------------------------------------------------------------------------

/det_images/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/det_images/6.jpg -------------------------------------------------------------------------------- /rec_images/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/1.png -------------------------------------------------------------------------------- /rec_images/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/10.png -------------------------------------------------------------------------------- /rec_images/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/11.png -------------------------------------------------------------------------------- /rec_images/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/12.png -------------------------------------------------------------------------------- /rec_images/13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/13.png -------------------------------------------------------------------------------- /rec_images/14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/14.png -------------------------------------------------------------------------------- /rec_images/15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/15.png 
-------------------------------------------------------------------------------- /rec_images/16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/16.png -------------------------------------------------------------------------------- /rec_images/17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/17.png -------------------------------------------------------------------------------- /rec_images/18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/18.png -------------------------------------------------------------------------------- /rec_images/19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/19.png -------------------------------------------------------------------------------- /rec_images/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/2.png -------------------------------------------------------------------------------- /rec_images/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/20.png -------------------------------------------------------------------------------- /rec_images/21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/21.png -------------------------------------------------------------------------------- /rec_images/22.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/22.png -------------------------------------------------------------------------------- /rec_images/27.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/27.png -------------------------------------------------------------------------------- /rec_images/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/5.png -------------------------------------------------------------------------------- /rec_images/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/7.png -------------------------------------------------------------------------------- /rec_images/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/8.png -------------------------------------------------------------------------------- /rec_images/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec_images/9.png -------------------------------------------------------------------------------- /rec/__pycache__/RNN.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec/__pycache__/RNN.cpython-37.pyc -------------------------------------------------------------------------------- /det/__pycache__/DB_fpn.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/det/__pycache__/DB_fpn.cpython-37.pyc -------------------------------------------------------------------------------- /rec/__pycache__/RecSVTR.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec/__pycache__/RecSVTR.cpython-37.pyc -------------------------------------------------------------------------------- /det/__pycache__/DetDbHead.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/det/__pycache__/DetDbHead.cpython-37.pyc -------------------------------------------------------------------------------- /det/__pycache__/DetModel.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/det/__pycache__/DetModel.cpython-37.pyc -------------------------------------------------------------------------------- /rec/__pycache__/RecModel.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec/__pycache__/RecModel.cpython-37.pyc -------------------------------------------------------------------------------- /rec/__pycache__/RecCTCHead.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec/__pycache__/RecCTCHead.cpython-37.pyc -------------------------------------------------------------------------------- /rec/__pycache__/RecSARHead.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec/__pycache__/RecSARHead.cpython-37.pyc
--------------------------------------------------------------------------------

/det/__pycache__/CommonModules.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/det/__pycache__/CommonModules.cpython-37.pyc
--------------------------------------------------------------------------------

/det/__pycache__/DetMobilenetV3.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/det/__pycache__/DetMobilenetV3.cpython-37.pyc
--------------------------------------------------------------------------------

/rec/__pycache__/RecMv1_enhance.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1079863482/paddle2torch_PPOCRv3/HEAD/rec/__pycache__/RecMv1_enhance.cpython-37.pyc
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
1 | # Converting PaddleOCRv3 models to PyTorch
2 | 
3 | Download the PaddleOCRv3 training models and extract them into the `weights` folder.
4 | 
5 | For an explanation of how the conversion works, see this blog post: https://blog.csdn.net/qq_39056987/article/details/124921515?spm=1001.2014.3001.5501
6 | 
--------------------------------------------------------------------------------

/det/DetModel.py:
--------------------------------------------------------------------------------
 1 | from torch import nn
 2 | from det.DetMobilenetV3 import MobileNetV3
 3 | from det.DB_fpn import DB_fpn,RSEFPN,LKPAN
 4 | from det.DetDbHead import DBHead
 5 | 
 6 | backbone_dict = {'MobileNetV3': MobileNetV3}
 7 | neck_dict = {'DB_fpn': DB_fpn,'RSEFPN':RSEFPN,'LKPAN':LKPAN}
 8 | head_dict = {'DBHead': DBHead}
 9 | 
10 | class DetModel(nn.Module):
11 |     def __init__(self,
config): 12 | super().__init__() 13 | assert 'in_channels' in config, 'in_channels must in model config' 14 | backbone_type = config.backbone.pop('type') 15 | assert backbone_type in backbone_dict, f'backbone.type must in {backbone_dict}' 16 | self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone) 17 | 18 | neck_type = config.neck.pop('type') 19 | assert neck_type in neck_dict, f'neck.type must in {neck_dict}' 20 | self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck) 21 | 22 | head_type = config.head.pop('type') 23 | assert head_type in head_dict, f'head.type must in {head_dict}' 24 | self.head = head_dict[head_type](self.neck.out_channels, **config.head) 25 | 26 | self.name = f'DetModel_{backbone_type}_{neck_type}_{head_type}' 27 | 28 | def load_3rd_state_dict(self, _3rd_name, _state): 29 | self.backbone.load_3rd_state_dict(_3rd_name, _state) 30 | self.neck.load_3rd_state_dict(_3rd_name, _state) 31 | self.head.load_3rd_state_dict(_3rd_name, _state) 32 | 33 | def forward(self, x): 34 | x = self.backbone(x) 35 | x = self.neck(x) 36 | x = self.head(x) 37 | return x -------------------------------------------------------------------------------- /rec/RecModel.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from rec.RNN import SequenceEncoder, Im2Seq,Im2Im 4 | from rec.RecSVTR import SVTRNet 5 | from rec.RecMv1_enhance import MobileNetV1Enhance 6 | 7 | from rec.RecCTCHead import CTC,MultiHead 8 | 9 | backbone_dict = {"SVTR":SVTRNet,"MobileNetV1Enhance":MobileNetV1Enhance} 10 | neck_dict = {'PPaddleRNN': SequenceEncoder, 'Im2Seq': Im2Seq,'None':Im2Im} 11 | head_dict = {'CTC': CTC,'Multi':MultiHead} 12 | 13 | 14 | class RecModel(nn.Module): 15 | def __init__(self, config): 16 | super().__init__() 17 | assert 'in_channels' in config, 'in_channels must in model config' 18 | backbone_type = config.backbone.pop('type') 19 | assert backbone_type in 
backbone_dict, f'backbone.type must in {backbone_dict}' 20 | self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone) 21 | 22 | neck_type = config.neck.pop('type') 23 | assert neck_type in neck_dict, f'neck.type must in {neck_dict}' 24 | self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck) 25 | 26 | head_type = config.head.pop('type') 27 | assert head_type in head_dict, f'head.type must in {head_dict}' 28 | self.head = head_dict[head_type](self.neck.out_channels, **config.head) 29 | 30 | self.name = f'RecModel_{backbone_type}_{neck_type}_{head_type}' 31 | 32 | def load_3rd_state_dict(self, _3rd_name, _state): 33 | self.backbone.load_3rd_state_dict(_3rd_name, _state) 34 | self.neck.load_3rd_state_dict(_3rd_name, _state) 35 | self.head.load_3rd_state_dict(_3rd_name, _state) 36 | 37 | def forward(self, x): 38 | x = self.backbone(x) 39 | x = self.neck(x) 40 | x = self.head(x) 41 | return x -------------------------------------------------------------------------------- /det/DetDbHead.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Head(nn.Module): 6 | def __init__(self, in_channels): 7 | super().__init__() 8 | 9 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels // 4, kernel_size=3, padding=1, 10 | bias=False) 11 | self.conv_bn1 = nn.BatchNorm2d(in_channels // 4) 12 | self.relu = nn.ReLU(inplace=True) 13 | self.conv2 = nn.ConvTranspose2d(in_channels=in_channels // 4, out_channels=in_channels // 4, kernel_size=2, 14 | stride=2) 15 | self.conv_bn2 = nn.BatchNorm2d(in_channels // 4) 16 | self.conv3 = nn.ConvTranspose2d(in_channels=in_channels // 4, out_channels=1, kernel_size=2, stride=2) 17 | 18 | def load_3rd_state_dict(self, _3rd_name, _state): 19 | pass 20 | 21 | def forward(self, x): 22 | x = self.conv1(x) 23 | x = self.conv_bn1(x) 24 | x = self.relu(x) 25 | x = self.conv2(x) 26 | x = 
self.conv_bn2(x)
27 |         x = self.relu(x)
28 |         x = self.conv3(x)
29 |         x = torch.sigmoid(x)
30 |         return x
31 | 
32 | 
33 | class DBHead(nn.Module):
34 |     """
35 |     Differentiable Binarization (DB) for text detection:
36 |     see https://arxiv.org/abs/1911.08947
37 |     args:
38 |         in_channels (int): number of input channels; k (int): amplification factor of the approximate step function (default 50)
39 |     """
40 | 
41 |     def __init__(self, in_channels, k=50):
42 |         super().__init__()
43 |         self.k = k
44 |         self.binarize = Head(in_channels)
45 |         self.thresh = Head(in_channels)
46 |         self.binarize.apply(self.weights_init)
47 |         self.thresh.apply(self.weights_init)
48 | 
49 |     def step_function(self, x, y):
50 |         return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
51 | 
52 |     def load_3rd_state_dict(self, _3rd_name, _state):
53 |         pass
54 | 
55 |     def forward(self, x):
56 |         shrink_maps = self.binarize(x)
57 |         if not self.training:
58 |             return shrink_maps
59 |         threshold_maps = self.thresh(x)
60 |         binary_maps = self.step_function(shrink_maps, threshold_maps)
61 |         y = torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1)
62 |         return y
63 | 
64 |     def weights_init(self, m):
65 |         classname = m.__class__.__name__
66 |         if classname.find('Conv') != -1:
67 |             nn.init.kaiming_normal_(m.weight.data)
68 |         elif classname.find('BatchNorm') != -1:
69 |             m.weight.data.fill_(1.)
70 |             m.bias.data.fill_(1e-4)
71 | 
--------------------------------------------------------------------------------

/det/CommonModules.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | from torch import nn
 3 | from torch.nn import functional as F
 4 | from collections import OrderedDict
 5 | 
 6 | 
 7 | class HSwish(nn.Module):
 8 |     def forward(self, x):
 9 |         out = x * F.relu6(x + 3, inplace=True) / 6
10 |         return out
11 | 
12 | 
13 | class HardSigmoid(nn.Module):
14 |     def __init__(self, type):
15 |         super().__init__()
16 |         self.type = type
17 | 
18 |     def forward(self, x):
19 |         if self.type == 'paddle':
20 |             x = (1.2 * x).add_(3.).clamp_(0., 6.).div_(6.)
21 | else: 22 | x = F.relu6(x + 3, inplace=True) / 6 23 | return x 24 | 25 | 26 | class HSigmoid(nn.Module): 27 | def forward(self, x): 28 | x = (1.2 * x).add_(3.).clamp_(0., 6.).div_(6.) 29 | return x 30 | 31 | 32 | class ConvBNACT(nn.Module): 33 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, act=None): 34 | super().__init__() 35 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 36 | stride=stride, padding=padding, groups=groups, 37 | bias=False) 38 | self.bn = nn.BatchNorm2d(out_channels) 39 | if act == 'relu': 40 | self.act = nn.ReLU() 41 | elif act == 'hard_swish': 42 | self.act = HSwish() 43 | elif act is None: 44 | self.act = None 45 | 46 | def load_3rd_state_dict(self, _3rd_name, _state, _name_prefix): 47 | to_load_state_dict = OrderedDict() 48 | if _3rd_name == 'paddle': 49 | to_load_state_dict['conv.weight'] = torch.Tensor(_state[f'{_name_prefix}_weights']) 50 | to_load_state_dict['bn.weight'] = torch.Tensor(_state[f'{_name_prefix}_bn_scale']) 51 | to_load_state_dict['bn.bias'] = torch.Tensor(_state[f'{_name_prefix}_bn_offset']) 52 | to_load_state_dict['bn.running_mean'] = torch.Tensor(_state[f'{_name_prefix}_bn_mean']) 53 | to_load_state_dict['bn.running_var'] = torch.Tensor(_state[f'{_name_prefix}_bn_variance']) 54 | self.load_state_dict(to_load_state_dict) 55 | else: 56 | pass 57 | 58 | def forward(self, x): 59 | x = self.conv(x) 60 | x = self.bn(x) 61 | if self.act is not None: 62 | x = self.act(x) 63 | return x 64 | 65 | 66 | class SEBlock(nn.Module): 67 | def __init__(self, in_channels, out_channels, hsigmoid_type='others', ratio=4): 68 | super().__init__() 69 | num_mid_filter = out_channels // ratio 70 | self.pool = nn.AdaptiveAvgPool2d(1) 71 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=num_mid_filter, kernel_size=1, bias=True) 72 | self.relu1 = nn.ReLU() 73 | self.conv2 = nn.Conv2d(in_channels=num_mid_filter, kernel_size=1, 
out_channels=out_channels, bias=True)
74 |         self.relu2 = HardSigmoid(hsigmoid_type)
75 | 
76 |     def load_3rd_state_dict(self, _3rd_name, _state, _name_prefix):
77 |         to_load_state_dict = OrderedDict()
78 |         if _3rd_name == 'paddle':
79 |             to_load_state_dict['conv1.weight'] = torch.Tensor(_state[f'{_name_prefix}_1_weights'])
80 |             to_load_state_dict['conv2.weight'] = torch.Tensor(_state[f'{_name_prefix}_2_weights'])
81 |             to_load_state_dict['conv1.bias'] = torch.Tensor(_state[f'{_name_prefix}_1_offset'])
82 |             to_load_state_dict['conv2.bias'] = torch.Tensor(_state[f'{_name_prefix}_2_offset'])
83 |             self.load_state_dict(to_load_state_dict)
84 |         else:
85 |             pass
86 | 
87 |     def forward(self, x):
88 |         attn = self.pool(x)
89 |         attn = self.conv1(attn)
90 |         attn = self.relu1(attn)
91 |         attn = self.conv2(attn)
92 |         attn = self.relu2(attn)
93 |         return x * attn
94 | 
--------------------------------------------------------------------------------

/paddle2torch_ppocrv3_det.py:
--------------------------------------------------------------------------------
  1 | from addict import Dict as AttrDict
  2 | import torch
  3 | import os
  4 | import torch.onnx as tr_onnx
  5 | import shutil
  6 | import tempfile
  7 | import paddle.fluid as fluid
  8 | import onnxruntime as ort
  9 | import numpy as np
 10 | from det.DetModel import DetModel
 11 | 
 12 | def load_state(path,trModule_state):
 13 |     """
 14 |     Load PaddlePaddle parameters and map them, by position, onto the PyTorch keys listed in trModule_state.
 15 |     :param path: Paddle checkpoint prefix (without the .pdparams suffix)
 16 |     :return: dict mapping PyTorch parameter names to tensors
 17 |     """
 18 |     if os.path.exists(path + '.pdopt'):
 19 |         # XXX another hack to ignore the optimizer state
 20 |         tmp = tempfile.mkdtemp()
 21 |         dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
 22 |         shutil.copy(path + '.pdparams', dst + '.pdparams')
 23 |         state = fluid.io.load_program_state(dst)
 24 |         shutil.rmtree(tmp)
 25 |     else:
 26 |         state = fluid.io.load_program_state(path)
 27 | 
 28 |     # for i, key in enumerate(state.keys()):
 29 |     #     print("{} {} ".format(i, key))
 30 | 
 31 |     state_dict = {}
 32 |     for i, key in enumerate(state.keys()):
 33 |         if key == "StructuredToParameterName@@":
 34 |             continue
 35 |         state_dict[trModule_state[i]] = torch.from_numpy(state[key])
 36 | 
 37 |     return state_dict
 38 | 
 39 | def torch_onnx_infer(model,onnx_path):
 40 |     torch_model = model
 41 |     onnx_model = ort.InferenceSession(onnx_path)
 42 |     data_arr = torch.ones(1,3,640,640)
 43 |     np_arr = np.array(data_arr).astype(np.float32)
 44 |     print("->>Comparing PyTorch and ONNX forward passes!")
 45 |     torch_infer = torch_model(data_arr).detach().numpy()
 46 |     print("torch:", torch_infer)
 47 |     onnx_infer = onnx_model.run(None,{'input':np_arr})
 48 |     print("onnx:", onnx_infer[0])
 49 |     std = np.std(torch_infer-onnx_infer[0])
 50 |     print("std:",std)
 51 | 
 52 | def torch2onnx(model,onnx_path):
 53 |     test_arr = torch.randn(1,3,640,640)
 54 |     input_names = ['input']
 55 |     output_names = ['output']
 56 |     tr_onnx.export(
 57 |         model,test_arr,onnx_path,
 58 |         verbose=False,
 59 |         opset_version=11,
 60 |         input_names=input_names,
 61 |         output_names=output_names,
 62 |         dynamic_axes={"input":{ 2:"H",
 63 |                                 3:"W",},
 64 |                       # "output":{1:"width"}
 65 |                       }
 66 |     )
 67 |     print('->>Model converted successfully!')
 68 |     torch_onnx_infer(model,onnx_path)
 69 | 
 70 | def torch2libtorch(model,lib_path):
 71 |     test_arr = torch.randn(1,3,640,640)
 72 |     traced_script_module = torch.jit.trace(model, test_arr)
 73 |     x = torch.ones(1, 3, 640, 640)
 74 |     output1 = traced_script_module(x)
 75 |     output2 = model(x)
 76 |     print(output1)
 77 |     print(output2)
 78 |     std = np.std((output1 - output2).detach().numpy())  # np.std needs a NumPy array, not a grad-tracking tensor
 79 |     print("std:",std)
 80 |     traced_script_module.save(lib_path)
 81 |     print("->>Model converted successfully!")
 82 | 
 83 | if __name__ == '__main__':
 84 |     TrModule_save = './weights/ppv3_db.pth' # pytorch save model
 85 |     PpModule_path = './weights/ch_PP-OCRv3_det_distill_train/student' # paddle train model
 86 | 
 87 |     db_config = AttrDict(
 88 |         in_channels=3,
 89 |         backbone=AttrDict(type='MobileNetV3', model_name='large',scale=0.5,pretrained=True),
 90 |         neck=AttrDict(type='RSEFPN', out_channels=96),
 91 |         head=AttrDict(type='DBHead')
 92 |     )
 93 | 
 94 |     model = DetModel(db_config)
 95 | 
 96 |     state_dict = []
 97 |
for i,key in enumerate(model.state_dict()): 98 | if 'num_batches_tracked' in key: 99 | continue 100 | state_dict.append(key) 101 | 102 | state_torch = load_state(PpModule_path,state_dict) 103 | 104 | torch.save(state_torch, TrModule_save) 105 | model.load_state_dict(state_torch) 106 | model.eval() 107 | 108 | torch2onnx(model,"./weights/ppv3_db.onnx") 109 | # torch2libtorch(model,"ppv3_db.pt") 110 | -------------------------------------------------------------------------------- /rec/RecCTCHead.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | from rec.RNN import Im2Seq,SequenceEncoder,EncoderWithSVTR 6 | from rec.RecSARHead import SARHead 7 | from addict import Dict as AttrDict 8 | 9 | class CTC(nn.Module): 10 | def __init__(self, in_channels, n_class, mid_channels=None,**kwargs): 11 | super().__init__() 12 | 13 | if mid_channels == None: 14 | self.fc = nn.Linear(in_channels, n_class) 15 | else: 16 | self.fc = nn.Sequential( 17 | nn.Linear(in_channels,mid_channels), 18 | nn.Linear(mid_channels,n_class) 19 | ) 20 | 21 | 22 | self.n_class = n_class 23 | 24 | def load_3rd_state_dict(self, _3rd_name, _state): 25 | to_load_state_dict = OrderedDict() 26 | if _3rd_name == 'paddle': 27 | if _state['ctc_fc_b_attr'].size == self.n_class: 28 | to_load_state_dict['fc.weight'] = torch.Tensor(_state['ctc_fc_w_attr'].T) 29 | to_load_state_dict['fc.bias'] = torch.Tensor(_state['ctc_fc_b_attr']) 30 | self.load_state_dict(to_load_state_dict) 31 | else: 32 | pass 33 | 34 | def forward(self, x, targets=None): 35 | 36 | return self.fc(x) 37 | 38 | class MultiHead(nn.Module): 39 | def __init__(self, in_channels, **kwargs): 40 | super().__init__() 41 | self.out_c = kwargs.get('n_class') 42 | self.head_list = kwargs.get('head_list') 43 | self.gtc_head = 'sar' 44 | # assert len(self.head_list) >= 2 45 | for idx, head_name in enumerate(self.head_list): 46 | # name = 
list(head_name)[0]
 47 |             name = head_name
 48 |             # if name == 'SARHead':
 49 |             #     # sar head
 50 |             #     sar_args = self.head_list[name]
 51 |             #     self.sar_head = eval(name)(in_channels=in_channels, out_channels=self.out_c, **sar_args)
 52 |             if name == 'CTC':
 53 |                 # ctc neck
 54 |                 self.encoder_reshape = Im2Seq(in_channels)
 55 |                 neck_args = self.head_list[name]['Neck']
 56 |                 encoder_type = neck_args.pop('name')
 57 |                 self.encoder = encoder_type
 58 |                 self.ctc_encoder = SequenceEncoder(in_channels=in_channels,encoder_type=encoder_type, **neck_args)
 59 |                 # ctc head
 60 |                 head_args = self.head_list[name]
 61 |                 self.ctc_head = eval(name)(in_channels=self.ctc_encoder.out_channels,n_class=self.out_c, **head_args)
 62 |             else:
 63 |                 raise NotImplementedError(
 64 |                     '{} is not supported in MultiHead yet'.format(name))
 65 | 
 66 |     def forward(self, x, targets=None):
 67 |         ctc_encoder = self.ctc_encoder(x)
 68 |         ctc_out = self.ctc_head(ctc_encoder, targets)
 69 |         head_out = dict()
 70 |         head_out['ctc'] = ctc_out
 71 |         head_out['ctc_neck'] = ctc_encoder
 72 |         return ctc_out  # infer
 73 | 
 74 |         # # eval mode
 75 |         # print(not self.training)
 76 |         # if not self.training:  # training
 77 |         #     return ctc_out
 78 |         # if self.gtc_head == 'sar':
 79 |         #     sar_out = self.sar_head(x, targets[1:])
 80 |         #     head_out['sar'] = sar_out
 81 |         #     return head_out
 82 |         # else:
 83 |         #     return head_out
 84 | 
 85 | 
 86 | if __name__=="__main__":
 87 | 
 88 |     config = AttrDict(head_list=AttrDict(CTC=AttrDict(Neck=AttrDict(name="svtr",dims=64,depth=2,hidden_dims=120,use_guide=True)),
 89 |                                          # SARHead=AttrDict(enc_dim=512,max_text_length=70)
 90 |                                          ),
 91 |                       )
 92 |     # config = {'head_list': {"CTC": {"Neck": {"name": "svtr", "dims": 64, "depth": 2,
 93 |     #                                           "hidden_dims": 120, "use_guide": True},
 94 |     #                                  },
 95 |     #                         "SARHead": {"enc_dim": 512, "max_text_length": 25},
 96 |     #                         },
 97 |     #           'n_class': 5963,
 98 |     #           }
 99 |     multi = MultiHead(128, n_class=6625, **config)  # unpack so head_list and n_class arrive in **kwargs
100 | 
101 |     print(multi)
--------------------------------------------------------------------------------
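
The CTC head above loads `ctc_fc_w_attr` with a `.T`, and the recognition converter (`paddle2torch_ppocrv3_rec.py`) transposes the SVTR `qkv`, `proj`, `fc1`/`fc2` and `ctc_head.fc` weights for what appears to be the same reason: as far as I can tell, Paddle's `nn.Linear` stores its weight as `[in_features, out_features]` and computes `x @ W + b`, whereas PyTorch's `nn.Linear` stores `[out_features, in_features]`. The pure-Python sketch that follows is only an illustration under that assumption; `transpose`, `paddle_style_linear`, `torch_style_linear` and `convert_linear_weights` are hypothetical names, not functions from this repository.

```python
# Sketch only. Assumed Linear weight layouts:
#   Paddle  nn.Linear: weight shape [in_features, out_features], y = x @ W + b
#   PyTorch nn.Linear: weight shape [out_features, in_features], y = x @ W.T + b
# All function names are hypothetical (not part of this repository).


def transpose(matrix):
    """Return the transpose of a 2-D list of lists."""
    return [list(col) for col in zip(*matrix)]


def paddle_style_linear(x, weight_in_out, bias):
    """y[j] = sum_i x[i] * W[i][j] + b[j], with W shaped [in, out]."""
    return [
        sum(x[i] * weight_in_out[i][j] for i in range(len(x))) + bias[j]
        for j in range(len(bias))
    ]


def torch_style_linear(x, weight_out_in, bias):
    """y[j] = sum_i W[j][i] * x[i] + b[j], with W shaped [out, in]."""
    return [
        sum(w * xi for w, xi in zip(row, x)) + b
        for row, b in zip(weight_out_in, bias)
    ]


def convert_linear_weights(paddle_state, linear_keys):
    """Copy a name -> weight mapping, transposing only the 2-D Linear
    weights named in `linear_keys`; biases and other entries are kept as-is."""
    return {
        name: transpose(value) if name in linear_keys else value
        for name, value in paddle_state.items()
    }


if __name__ == "__main__":
    x = [1.0, 2.0, 3.0]                              # in_features = 3
    w_paddle = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # [in=3, out=2]
    b = [0.5, -1.0]

    state = convert_linear_weights({"fc.weight": w_paddle, "fc.bias": b}, {"fc.weight"})
    print(paddle_style_linear(x, w_paddle, b))                          # [4.5, 4.0]
    print(torch_style_linear(x, state["fc.weight"], state["fc.bias"]))  # [4.5, 4.0]
```

Both calls print the same vector, which is the property the converter relies on. In the real scripts the equivalent step is `torch.from_numpy(state[key]).transpose(0,1)` or `.T` on the NumPy array; convolution kernels are copied without a transpose, since both frameworks store `Conv2d` weights as `[out_channels, in_channels, kH, kW]`.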
/torch_rec_infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from rec.RecModel import RecModel 3 | import torch 4 | from addict import Dict as AttrDict 5 | import cv2 6 | import numpy as np 7 | import math 8 | import time 9 | 10 | class CTCLabelConverter(object): 11 | """ Convert between text-label and text-index """ 12 | 13 | def __init__(self, character): 14 | # character (str): set of the possible characters. 15 | dict_character = [] 16 | with open(character, "rb") as fin: 17 | lines = fin.readlines() 18 | for line in lines: 19 | line = line.decode('utf-8').strip("\n").strip("\r\n") 20 | dict_character += list(line) 21 | # dict_character = list(character) 22 | 23 | self.dict = {} 24 | for i, char in enumerate(dict_character): 25 | # NOTE: 0 is reserved for 'blank' token required by CTCLoss 26 | self.dict[char] = i + 1 27 | #TODO replace ‘ ’ with special symbol 28 | self.character = ['[blank]'] + dict_character+[' '] # dummy '[blank]' token for CTCLoss (index 0) 29 | 30 | def encode(self, text, batch_max_length=None): 31 | """convert text-label into text-index. 32 | input: 33 | text: text labels of each image. [batch_size] 34 | output: 35 | text: concatenated text index for CTCLoss. 36 | [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] 37 | length: length of each text. [batch_size] 38 | """ 39 | length = [len(s) for s in text] 40 | # text = ''.join(text) 41 | # text = [self.dict[char] for char in text] 42 | d = [] 43 | batch_max_length = max(length) 44 | for s in text: 45 | t = [self.dict[char] for char in s] 46 | t.extend([0] * (batch_max_length - len(s))) 47 | d.append(t) 48 | return (torch.tensor(d, dtype=torch.long), torch.tensor(length, dtype=torch.long)) 49 | 50 | def decode(self, preds, raw=False): 51 | """ convert text-index into text-label. 
""" 52 | preds_idx = preds.argmax(axis=2) 53 | preds_prob = preds.max(axis=2) 54 | 55 | result_list = [] 56 | for word, prob in zip(preds_idx, preds_prob): 57 | if raw: 58 | result_list.append((''.join([self.character[int(i)] for i in word]), prob)) 59 | else: 60 | result = [] 61 | conf = [] 62 | for i, index in enumerate(word): 63 | if word[i] != 0 and (not (i > 0 and word[i - 1] == word[i])): 64 | # if prob[i] < 0.3: # -------------------------------------------------- 65 | # continue 66 | result.append(self.character[int(index)]) 67 | conf.append(prob[i]) 68 | result_list.append((''.join(result), conf)) 69 | return result_list 70 | 71 | def narrow_224_32(image, expected_size=(280,48)): 72 | ih, iw = image.shape[0:2] 73 | ew, eh = expected_size 74 | scale = eh / ih 75 | nh = int(ih * scale) 76 | nw = int(iw * scale) 77 | image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_CUBIC) 78 | top = 0 79 | bottom = eh - nh - top 80 | left = 0 81 | right = ew - nw - left 82 | 83 | new_img = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)) 84 | return new_img 85 | 86 | def img_nchw(img): 87 | mean = 0.5 88 | std = 0.5 89 | resize_ratio = 48 / img.shape[0] 90 | img = cv2.resize(img,(0,0),fx=resize_ratio,fy= resize_ratio,interpolation=cv2.INTER_LINEAR) 91 | # img = cv2.resize(img,(img.shape[1],32)) 92 | 93 | W = math.ceil(img.shape[1]/32)+1 94 | img = narrow_224_32(img,expected_size=(W*32,48)) 95 | img_data = (img.astype(np.float32)/255 - mean) / std 96 | img_np = img_data.transpose(2,0,1) 97 | img_np = np.expand_dims(img_np,0) 98 | return img_np 99 | 100 | if __name__ == "__main__": 101 | 102 | rec_model_path = "./weights/ppv3_rec.pth" 103 | img_path = "rec_images" 104 | dict_path = r"./weights/ppocr_keys_v1.txt" 105 | converter = CTCLabelConverter(dict_path) 106 | 107 | rec_config = AttrDict( 108 | in_channels=3, 109 | backbone=AttrDict(type='MobileNetV1Enhance', scale=0.5, last_conv_stride=[1, 2], last_pool_type='avg'), 
110 | neck=AttrDict(type='None'), 111 | head=AttrDict(type='Multi', head_list=AttrDict( 112 | CTC=AttrDict(Neck=AttrDict(name="svtr", dims=64, depth=2, hidden_dims=120, use_guide=True)), 113 | # SARHead=AttrDict(enc_dim=512,max_text_length=70) 114 | ), 115 | n_class=6625) 116 | ) 117 | 118 | rec_model = RecModel(rec_config) 119 | rec_model.load_state_dict(torch.load(rec_model_path)) 120 | rec_model.eval() 121 | 122 | path_list = os.listdir(img_path) 123 | # path_list.sort(key=lambda x: int(x[:-4])) 124 | for name in path_list: 125 | img = cv2.imread(os.path.join(img_path,name)) 126 | time1 = time.time() 127 | img_np_nchw = img_nchw(img) 128 | input_for_torch = torch.from_numpy(img_np_nchw) 129 | feat_2 = rec_model(input_for_torch).softmax(dim=2) 130 | time2 = time.time() 131 | time3 = time2 - time1 132 | feat_2 = feat_2.cpu().data 133 | txt = converter.decode(feat_2.detach().cpu().numpy()) 134 | 135 | print("name:{} txt:{} time:{}".format(name,txt,time3)) 136 | # break 137 | 138 | 139 | -------------------------------------------------------------------------------- /paddle2torch_ppocrv3_rec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from addict import Dict as AttrDict 3 | import shutil 4 | import tempfile 5 | import paddle.fluid as fluid 6 | import os 7 | import torch.onnx as tr_onnx 8 | from rec.RecModel import RecModel 9 | import onnxruntime as ort 10 | import numpy as np 11 | import cv2 12 | import math 13 | 14 | 15 | def load_state(path,trModule_state): 16 | """ 17 | 记载paddlepaddle的参数 18 | :param path: 19 | :return: 20 | """ 21 | if os.path.exists(path + '.pdopt'): 22 | # XXX another hack to ignore the optimizer state 23 | tmp = tempfile.mkdtemp() 24 | dst = os.path.join(tmp, os.path.basename(os.path.normpath(path))) 25 | shutil.copy(path + '.pdparams', dst + '.pdparams') 26 | state = fluid.io.load_program_state(dst) 27 | shutil.rmtree(tmp) 28 | else: 29 | state = fluid.io.load_program_state(path) 

    # for i, key in enumerate(state.keys()):
    #     print("{}  {} ".format(i, key))

    # Paddle stores Linear weights as [in_features, out_features], whereas
    # torch.nn.Linear expects [out_features, in_features]: these keys are transposed.
    keys = ["head.ctc_encoder.encoder.svtr_block.0.mixer.qkv.weight",
            "head.ctc_encoder.encoder.svtr_block.0.mixer.proj.weight",
            "head.ctc_encoder.encoder.svtr_block.0.mlp.fc1.weight",
            "head.ctc_encoder.encoder.svtr_block.0.mlp.fc2.weight",
            "head.ctc_encoder.encoder.svtr_block.1.mixer.qkv.weight",
            "head.ctc_encoder.encoder.svtr_block.1.mixer.proj.weight",
            "head.ctc_encoder.encoder.svtr_block.1.mlp.fc1.weight",
            "head.ctc_encoder.encoder.svtr_block.1.mlp.fc2.weight",
            "head.ctc_head.fc.weight",
            ]

    state_dict = {}
    for i, key in enumerate(state.keys()):
        if key =="StructuredToParameterName@@":
            continue
        # Positional mapping: Paddle entries 0-238 are skipped and the next 196
        # are assigned, in order, to the PyTorch keys. These offsets depend on
        # the parameter order of this particular checkpoint.
        if i > 238:
            j = i-239
            if j <= 195:
                if trModule_state[j] in keys:
                    state_dict[trModule_state[j]] = torch.from_numpy(state[key]).transpose(0,1)
                else:
                    state_dict[trModule_state[j]] = torch.from_numpy(state[key])

    return state_dict

def torch_onnx_infer(model,onnx_path):
    torch_model = model
    onnx_model = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
    data_arr = torch.ones(1,3,48,224)
    np_arr = data_arr.numpy().astype(np.float32)
    print("->> Comparing PyTorch and ONNX forward outputs")
    torch_infer = torch_model(data_arr).detach().numpy()
    print("torch:", torch_infer)
    onnx_infer = onnx_model.run(None,{'input':np_arr})
    print("onnx:", onnx_infer[0])
    std = np.std(torch_infer-onnx_infer[0])
    print("std:",std)
    # std alone misses a constant offset, so also report the worst-case difference
    print("max abs diff:", np.abs(torch_infer - onnx_infer[0]).max())

def torch2onnx(model,onnx_path):
    test_arr = torch.randn(1,3,48,224)
    input_names = ['input']
    output_names = ['output']
    tr_onnx.export(
        model,test_arr,onnx_path,
        verbose=False,
        opset_version=11,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes={"input":{3:"W"},
                      # "output":{1:"width"}
                      }
    )
    print('->> Model conversion succeeded')
    torch_onnx_infer(model,onnx_path)

def torch2libtorch(model,lib_path):
    test_arr = torch.randn(1,3,48,224)
    traced_script_module = torch.jit.trace(model, test_arr)
    x = torch.ones(1, 3, 48, 224)
    output1 = traced_script_module(x)
    output2 = model(x)
    print(output1)
    print(output2)
    traced_script_module.save(lib_path)
    print("->> Model conversion succeeded")

def narrow_224_48(image, expected_size=(280,48)):
    # resize to the target height, then pad right/bottom with gray (114) to expected_size
    ih, iw = image.shape[0:2]
    ew, eh = expected_size
    scale = eh / ih
    nh = int(ih * scale)
    nw = int(iw * scale)
    image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_CUBIC)
    top = 0
    bottom = eh - nh - top
    left = 0
    right = ew - nw - left

    new_img = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
    return new_img

def img_nchw(img):
    mean = 0.5
    std = 0.5
    resize_ratio = 48 / img.shape[0]
    img = cv2.resize(img,(0,0),fx=resize_ratio,fy= resize_ratio,interpolation=cv2.INTER_LINEAR)
    # img = cv2.resize(img,(img.shape[1],32))

    # pad the width up to a multiple of 32 plus one extra 32-pixel block
    W = math.ceil(img.shape[1]/32)+1
    img = narrow_224_48(img,expected_size=(W*32,48))
    img_data = (img.astype(np.float32)/255 - mean) / std
    img_np = img_data.transpose(2,0,1)
    img_np = np.expand_dims(img_np,0)
    return img_np

if __name__=="__main__":
    TrModule_save = './weights/ppv3_rec.pth'
    PpModule_path = './weights/ch_PP-OCRv3_rec_train/best_accuracy'

    rec_config = AttrDict(
        in_channels=3,
        backbone=AttrDict(type='MobileNetV1Enhance', scale=0.5,last_conv_stride=[1,2],last_pool_type='avg'),
        neck=AttrDict(type='None'),
        head=AttrDict(type='Multi',head_list=AttrDict(CTC=AttrDict(Neck=AttrDict(name="svtr",dims=64,depth=2,hidden_dims=120,use_guide=True)),
                                                      # SARHead=AttrDict(enc_dim=512,max_text_length=70)
                                                      ),
                      n_class=6625)
    )

    model = RecModel(rec_config)

    # ordered list of PyTorch state_dict key names (num_batches_tracked excluded)
    state_dict = []

    for i,key in enumerate(model.state_dict()):
        # print("{} {} {}".format(i,key,model.state_dict()[key].size()))
        if 'num_batches_tracked' in key:
            continue
        state_dict.append(key)

    # for i,keys in enumerate(state_dict):
    #     print("{} {}".format(i, keys))

    state_torch = load_state(PpModule_path,state_dict)
    torch.save(state_torch, TrModule_save)
    model.load_state_dict(state_torch)  # model load state
    model.eval()

    torch2onnx(model,"./weights/ppv3_rec.onnx")  # torch2onnx
    # torch2libtorch(model,"ppv3_rec.pt")  # torch2jit

--------------------------------------------------------------------------------
/rec/RNN.py:
--------------------------------------------------------------------------------
from torch import nn
import torch
from rec.RecSVTR import Block,trunc_normal_,zeros_,ones_

class Swish(nn.Module):
    def __init__(self):
        super(Swish, self).__init__()

    def forward(self,x):
        return x*torch.sigmoid(x)

class Im2Im(nn.Module):
    def __init__(self, in_channels, **kwargs):
        super().__init__()
        self.out_channels = in_channels

    def forward(self, x):
        return x

class Im2Seq(nn.Module):
    def __init__(self, in_channels, **kwargs):
        super().__init__()
        self.out_channels = in_channels

    def forward(self, x):
        # (B, C, H, W) -> (B, H*W, C)
        B, C, H, W = x.shape
        # assert H == 1
        x = x.reshape(B, C, H * W)
        x = x.permute((0, 2, 1))
        return x

class EncoderWithRNN(nn.Module):
    def __init__(self, in_channels,**kwargs):
        super(EncoderWithRNN, self).__init__()
        hidden_size = kwargs.get('hidden_size', 256)
        self.out_channels = hidden_size * 2
        self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2,batch_first=True)

    def forward(self, x):
        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        return x

class SequenceEncoder(nn.Module):
    def __init__(self, in_channels, encoder_type='rnn', **kwargs):
        super(SequenceEncoder, self).__init__()
        self.encoder_reshape = Im2Seq(in_channels)
        self.out_channels = self.encoder_reshape.out_channels
        self.encoder_type = encoder_type
        if encoder_type == 'reshape':
            self.only_reshape = True
        else:
            support_encoder_dict = {
                'reshape': Im2Seq,
                'rnn': EncoderWithRNN,
                'svtr':EncoderWithSVTR
            }
            assert encoder_type in support_encoder_dict, '{} must be in {}'.format(
                encoder_type, support_encoder_dict.keys())

            self.encoder = support_encoder_dict[encoder_type](
                self.encoder_reshape.out_channels,**kwargs)
            self.out_channels = self.encoder.out_channels
            self.only_reshape = False

    def forward(self, x):
        if self.encoder_type != 'svtr':
            x = self.encoder_reshape(x)
            if not self.only_reshape:
                x = self.encoder(x)
            return x
        else:
            # SVTR works on the 2-D feature map, so reshape to a sequence afterwards
            x = self.encoder(x)
            x = self.encoder_reshape(x)
            return x

class ConvBNLayer(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=0,
                 bias_attr=False,
                 groups=1,
                 act=nn.GELU):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
            bias=bias_attr)
        self.norm = nn.BatchNorm2d(out_channels)
        # NOTE: `act` is accepted for signature compatibility but ignored; Swish is always used
        self.act = Swish()

    def forward(self, inputs):
        out = self.conv(inputs)
        out = self.norm(out)
        out = self.act(out)
        return out


class EncoderWithSVTR(nn.Module):
    def __init__(
            self,
            in_channels,
            dims=64,  # XS
            depth=2,
            hidden_dims=120,
            use_guide=False,
            num_heads=8,
            qkv_bias=True,
            mlp_ratio=2.0,
            drop_rate=0.1,
            attn_drop_rate=0.1,
            drop_path=0.,
            qk_scale=None):
        super(EncoderWithSVTR, self).__init__()
        self.depth = depth
        self.use_guide = use_guide
        self.conv1 = ConvBNLayer(
            in_channels, in_channels // 8, padding=1)
        self.conv2 = ConvBNLayer(
            in_channels // 8, hidden_dims, kernel_size=1)

        self.svtr_block = nn.ModuleList([
            Block(
                dim=hidden_dims,
                num_heads=num_heads,
                mixer='Global',
                HW=None,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer="Swish",
                attn_drop=attn_drop_rate,
                drop_path=drop_path,
                norm_layer='nn.LayerNorm',
                epsilon=1e-05,
                prenorm=False) for i in range(depth)
        ])
        self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
        self.conv3 = ConvBNLayer(
            hidden_dims, in_channels, kernel_size=1)
        # last conv-nxn, the input is concat of input tensor and conv3 output tensor
        self.conv4 = ConvBNLayer(
            2 * in_channels, in_channels // 8, padding=1)

        self.conv1x1 = ConvBNLayer(
            in_channels // 8, dims, kernel_size=1)
        self.out_channels = dims
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def forward(self, x):
        # for use guide: PyTorch equivalent of Paddle's `stop_gradient = True`
        # (setting a `stop_gradient` attribute on a torch tensor has no effect)
        if self.use_guide:
            z = x.clone().detach()
        else:
            z = x
        # for short cut
        h = z
        # reduce dim
        z = self.conv1(z)
        z = self.conv2(z)
        # SVTR global block
        B, C, H, W = z.shape
        z = z.flatten(2).permute([0, 2, 1])
        for blk in self.svtr_block:
            z = blk(z)
        z = self.norm(z)
        # last stage
        z = z.reshape([-1, H, W, C]).permute([0, 3, 1, 2])
        z = self.conv3(z)
        z = torch.cat((h, z), dim=1)
        z = self.conv1x1(self.conv4(z))
        return z

if __name__=="__main__":
    svtrRNN = EncoderWithSVTR(56)
    print(svtrRNN)

--------------------------------------------------------------------------------
/rec/RecMv1_enhance.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class Hardswish(nn.Module):
    def forward(self, x):
        out = x * F.relu6(x + 3, inplace=True) / 6
        return out

class Hardsigmoid(nn.Module):
    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        # active: relu6(x + 3) / 6 == clamp(x / 6 + 0.5, 0, 1)
        # alternative below: clamp(1.2x + 3, 0, 6) / 6 == clamp(0.2x + 0.5, 0, 1)
        # return (1.2 * x).add_(3.).clamp_(0., 6.).div_(6.)
        return F.relu6(x + 3., inplace=True) / 6.

# class Hardsigmoid(nn.Module):
#     def __init__(self, type):
#         super().__init__()
#         self.type = type
#
#     def forward(self, x):
#         if self.type == 'paddle':
#             x = (1.2 * x).add_(3.).clamp_(0., 6.).div_(6.)
#         else:
#             x = F.relu6(x + 3, inplace=True) / 6
#         return x


class ConvBNLayer(nn.Module):
    def __init__(self,
                 num_channels,
                 filter_size,
                 num_filters,
                 stride,
                 padding,
                 num_groups=1):
        super().__init__()

        self.conv = nn.Conv2d(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            bias=False)

        self.bn = nn.BatchNorm2d(
            num_filters,
        )
        self.hardswish = Hardswish()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.hardswish(x)
        return x


class DepthwiseSeparable(nn.Module):
    def __init__(self,
                 num_channels,
                 num_filters1,
                 num_filters2,
                 num_groups,
                 stride,
                 scale,
                 dw_size=3,
                 padding=1,
                 use_se=False):
        super().__init__()
        self.use_se = use_se
        # depthwise conv (groups == channels) followed by a 1x1 pointwise conv
        self.dw_conv = ConvBNLayer(
            num_channels=num_channels,
            num_filters=int(num_filters1 * scale),
            filter_size=dw_size,
            stride=stride,
            padding=padding,
            num_groups=int(num_groups * scale))
        if use_se:
            self.se = SEModule(int(num_filters1 * scale))
        self.pw_conv = ConvBNLayer(
            num_channels=int(num_filters1 * scale),
            filter_size=1,
            num_filters=int(num_filters2 * scale),
            stride=1,
            padding=0)

    def forward(self, x):
        x = self.dw_conv(x)
        if self.use_se:
            x = self.se(x)
        x = self.pw_conv(x)
        return x


class MobileNetV1Enhance(nn.Module):
    def __init__(self,
                 in_channels=3,
                 scale=0.5,
                 last_conv_stride=1,
                 last_pool_type='max',
                 **kwargs):
        super().__init__()
        self.scale = scale
        self.block_list = []

        self.conv1 = ConvBNLayer(
            num_channels=in_channels,
            filter_size=3,
            num_filters=int(32 * scale),
            stride=2,
            padding=1)

        conv2_1 = DepthwiseSeparable(
            num_channels=int(32 * scale),
            num_filters1=32,
            num_filters2=64,
            num_groups=32,
            stride=1,
            scale=scale)
        self.block_list.append(conv2_1)

        conv2_2 = DepthwiseSeparable(
            num_channels=int(64 * scale),
            num_filters1=64,
            num_filters2=128,
            num_groups=64,
            stride=1,
            scale=scale)
        self.block_list.append(conv2_2)

        conv3_1 = DepthwiseSeparable(
            num_channels=int(128 * scale),
            num_filters1=128,
            num_filters2=128,
            num_groups=128,
            stride=1,
            scale=scale)
        self.block_list.append(conv3_1)

        conv3_2 = DepthwiseSeparable(
            num_channels=int(128 * scale),
            num_filters1=128,
            num_filters2=256,
            num_groups=128,
            stride=(2, 1),
            scale=scale)
        self.block_list.append(conv3_2)

        conv4_1 = DepthwiseSeparable(
            num_channels=int(256 * scale),
            num_filters1=256,
            num_filters2=256,
            num_groups=256,
            stride=1,
            scale=scale)
        self.block_list.append(conv4_1)

        conv4_2 = DepthwiseSeparable(
            num_channels=int(256 * scale),
            num_filters1=256,
            num_filters2=512,
            num_groups=256,
            stride=(2, 1),
            scale=scale)
        self.block_list.append(conv4_2)

        for _ in range(5):
            conv5 = DepthwiseSeparable(
                num_channels=int(512 * scale),
                num_filters1=512,
                num_filters2=512,
                num_groups=512,
                stride=1,
                dw_size=5,
                padding=2,
                scale=scale,
                use_se=False)
            self.block_list.append(conv5)

        conv5_6 = DepthwiseSeparable(
            num_channels=int(512 * scale),
            num_filters1=512,
            num_filters2=1024,
            num_groups=512,
            stride=(2, 1),
            dw_size=5,
            padding=2,
            scale=scale,
            use_se=True)
        self.block_list.append(conv5_6)

        conv6 = DepthwiseSeparable(
            num_channels=int(1024 * scale),
            num_filters1=1024,
            num_filters2=1024,
            num_groups=1024,
            stride=last_conv_stride,
            dw_size=5,
            padding=2,
            use_se=True,
            scale=scale)
        self.block_list.append(conv6)

        self.block_list = nn.Sequential(*self.block_list)
        if last_pool_type == "avg":
            self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        else:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.out_channels = int(1024 * scale)

    def forward(self, inputs):
        y = self.conv1(inputs)
        y = self.block_list(y)
        y = self.pool(y)
        return y

class SEModule(nn.Module):
    def __init__(self, channel, reduction=4):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(
            in_channels=channel,
            out_channels=channel // reduction,
            kernel_size=1,
            stride=1,
            padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(
            in_channels=channel // reduction,
            out_channels=channel,
            kernel_size=1,
            stride=1,
            padding=0)
        self.hardsigmoid = Hardsigmoid()

    def forward(self, x):
        identity = x
        x = self.avg_pool(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.hardsigmoid(x)
        x = torch.mul(identity, x)
        return x


if __name__=="__main__":

    from torchsummary import summary

    arr = torch.rand((1,3,32,224))
    model = MobileNetV1Enhance()
    summary(model, input_size=(3, 32, 224), batch_size=1)
    out = model(arr)
    print(out.size())
    # print(model)

--------------------------------------------------------------------------------
/torch_det_infer.py:
--------------------------------------------------------------------------------
import os
from det.DetModel import DetModel
import torch
from addict import Dict as AttrDict
import cv2
import numpy as np
import math
import time
import pyclipper
from shapely.geometry import Polygon


class DBPostProcess():
    def __init__(self, thresh=0.3, box_thresh=0.4, max_candidates=1000, unclip_ratio=2):
        self.min_size = 3
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio

    def __call__(self, pred, h_w_list, is_output_polygon=False):
        '''
        pred: model output with shape (N, 1, H, W); channel 0 is the
            text-region probability map
        h_w_list: list of [h, w] pairs, one per image in the batch
        is_output_polygon: return polygons if True, otherwise 4-point boxes
        '''
        pred = pred[:, 0, :, :]
        segmentation = self.binarize(pred)
        boxes_batch = []
        scores_batch = []
        for batch_index in range(pred.shape[0]):
            height, width = h_w_list[batch_index]
            boxes, scores = self.post_p(pred[batch_index], segmentation[batch_index], width, height,is_output_polygon=is_output_polygon)
            boxes_batch.append(boxes)
            scores_batch.append(scores)
        return boxes_batch, scores_batch

    def binarize(self, pred):
        return pred > self.thresh

    def post_p(self, pred, bitmap, dest_width, dest_height, is_output_polygon=False):
        '''
        bitmap: single map with shape (H, W),
            whose values are binarized as {0, 1}
        '''
        height, width = pred.shape
        boxes = []
        new_scores = []
        bitmap = bitmap.cpu().numpy()
        # findContours returns 3 values in OpenCV 3.x and 2 in 4.x;
        # the contour list is the second-to-last element in both
        contours = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]
        for contour in contours[:self.max_candidates]:
            epsilon = 0.005 * cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, epsilon, True)
            points = approx.reshape((-1, 2))
            if points.shape[0] < 4:
                continue
            score = self.box_score_fast(pred, contour.squeeze(1))
            if self.box_thresh > score:
                continue
            if points.shape[0] > 2:
                expanded = self.unclip(points, unclip_ratio=self.unclip_ratio)
                # skip regions that expand into zero or several polygons
                if len(expanded) != 1:
                    continue
                box = np.array(expanded[0])
            else:
                continue
            four_point_box, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
            if sside < self.min_size + 2:
                continue
            if not isinstance(dest_width, int):
                dest_width = dest_width.item()
                dest_height = dest_height.item()
            if not is_output_polygon:
                box = np.array(four_point_box)
            else:
                box = box.reshape(-1, 2)
            # rescale from map coordinates to destination image coordinates
            box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height)
            boxes.append(box)
            new_scores.append(score)
        return boxes, new_scores

    def unclip(self, box, unclip_ratio=1.5):
        # expand the polygon by distance = area * ratio / perimeter
        poly = Polygon(box)
        distance = poly.area * unclip_ratio / poly.length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        # return the raw list of paths: np.array() over paths of different
        # lengths raises ValueError on NumPy >= 1.24
        return offset.Execute(distance)

    def get_mini_boxes(self, contour):
        # minimum-area rectangle, corners ordered top-left, top-right, bottom-right, bottom-left
        bounding_box = cv2.minAreaRect(contour)
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        index_1, index_2, index_3, index_4 = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_1 = 0
            index_4 = 1
        else:
            index_1 = 1
            index_4 = 0
        if points[3][1] > points[2][1]:
            index_2 = 2
            index_3 = 3
        else:
            index_2 = 3
            index_3 = 2

        box = [points[index_1], points[index_2], points[index_3], points[index_4]]
        return box, min(bounding_box[1])

    def box_score_fast(self, bitmap, _box):
        # mean probability inside the contour polygon
        bitmap = bitmap.detach().cpu().numpy()
        h, w = bitmap.shape[:2]
        box = _box.copy()
        # np.int was removed in NumPy 1.24; use the builtin int
        xmin = np.clip(np.floor(box[:, 0].min()).astype(int), 0, w - 1)
        xmax = np.clip(np.ceil(box[:, 0].max()).astype(int), 0, w - 1)
        ymin = np.clip(np.floor(box[:, 1].min()).astype(int), 0, h - 1)
        ymax = np.clip(np.ceil(box[:, 1].max()).astype(int), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

def narrow(image, expected_size=(224,224)):
    # aspect-preserving resize to fit expected_size, then pad right/bottom with gray (114)
    ih, iw = image.shape[0:2]
    ew, eh = expected_size
    # scale = eh / ih
    scale = min((eh/ih),(ew/iw))
    # scale = eh / max(iw,ih)
    nh = int(ih * scale)
    nw = int(iw * scale)
    image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_CUBIC)
    top = 0
    bottom = eh - nh
    left = 0
    right = ew - nw
    new_img = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
    return new_img

def draw_bbox(img, result, color=(0, 0, 255), thickness=2):
    # img may be an image path or a BGR array
    if isinstance(img, str):
        img = cv2.imread(img)
        # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.copy()
    for point in result:
        point = point.astype(int)
        cv2.polylines(img, [point], True, color, thickness)
    return img

def img_nchw(img):
    mean = 0.5
    std = 0.5
    # fit the longer side to 640, then pad the other side to a multiple of 32 (+32)
    resize_ratio = min((640 / img.shape[0]),(640/img.shape[1]))
    img = cv2.resize(img,(0,0),fx=resize_ratio,fy= resize_ratio,interpolation=cv2.INTER_LINEAR)
    h,w= img.shape[:2]
    if h == 640:
        w = (math.ceil(w/32)+1)*32
    elif w == 640:
        h = (math.ceil(h/32)+1)*32
    img1 = narrow(img,(w,h))

    img_data = (img1.astype(np.float32)/255 - mean) / std
    img_np = img_data.transpose(2,0,1)
    img_np = np.expand_dims(img_np,0)
    return img1,img_np


if __name__ == '__main__':
    det_model_path = './weights/ppv3_db.pth'
    test_img = "./det_images"

    post_process = DBPostProcess()

    db_config = AttrDict(
        in_channels=3,
        backbone=AttrDict(type='MobileNetV3', model_name='large',scale=0.5,pretrained=True),
        neck=AttrDict(type='RSEFPN', out_channels=96),
        head=AttrDict(type='DBHead')
    )

    det_model = DetModel(db_config)
    det_model.load_state_dict(torch.load(det_model_path))
    det_model.eval()

    path_list = os.listdir(test_img)
    for name in path_list:
        img = cv2.imread(os.path.join(test_img, name))
        img0, img_np_nchw = img_nchw(img)

        input_for_torch = torch.from_numpy(img_np_nchw)
        out = det_model(input_for_torch)  # torch model infer

        box_list, score_list = post_process(out, [img0.shape[:2]], is_output_polygon=False)
        box_list, score_list = box_list[0], score_list[0]
        if len(box_list) > 0:
            # drop degenerate all-zero boxes
            idx = [x.sum() > 0 for x in box_list]
            box_list = [box_list[i] for i, v in enumerate(idx) if v]
            score_list = [score_list[i] for i, v in enumerate(idx) if v]
        else:
            box_list, score_list = [], []

        img1 = draw_bbox(img0, box_list)
        cv2.imshow("draw", img1)
        cv2.waitKey()

--------------------------------------------------------------------------------
/det/DetMobilenetV3.py:
--------------------------------------------------------------------------------
import logging
import os
import torch
from torch import nn
from det.CommonModules import ConvBNACT, SEBlock

class ResidualUnit(nn.Module):
    def __init__(self, num_in_filter, num_mid_filter, num_out_filter, stride, kernel_size, act=None, use_se=False):
        super().__init__()
        self.conv0 = ConvBNACT(in_channels=num_in_filter,
                               out_channels=num_mid_filter, kernel_size=1, stride=1,
                               padding=0, act=act)

        self.conv1 = ConvBNACT(in_channels=num_mid_filter, out_channels=num_mid_filter, kernel_size=kernel_size,
                               stride=stride,
                               padding=int((kernel_size - 1) // 2), act=act, groups=num_mid_filter)
        if use_se:
            self.se = SEBlock(in_channels=num_mid_filter, out_channels=num_mid_filter)
        else:
            self.se = None

        self.conv2 = ConvBNACT(in_channels=num_mid_filter, out_channels=num_out_filter, kernel_size=1, stride=1,
                               padding=0)
        self.not_add = num_in_filter != num_out_filter or stride != 1

    def load_3rd_state_dict(self, _3rd_name, _state, _convolution_index):
        if _3rd_name == 'paddle':
            self.conv0.load_3rd_state_dict(_3rd_name, _state, f'conv{_convolution_index}_expand')
            self.conv1.load_3rd_state_dict(_3rd_name, _state, f'conv{_convolution_index}_depthwise')
            if self.se is not None:
                self.se.load_3rd_state_dict(_3rd_name, _state, f'conv{_convolution_index}_se')
            self.conv2.load_3rd_state_dict(_3rd_name, _state, f'conv{_convolution_index}_linear')
        else:
            pass

    def forward(self, x):
        y = self.conv0(x)
        y = self.conv1(y)
        if self.se is not None:
            y = self.se(y)
        y = self.conv2(y)
        if not self.not_add:
            y = x + y
        return y


class MobileNetV3(nn.Module):
    def __init__(self, in_channels, pretrained=True, **kwargs):
        """
        the MobilenetV3 backbone network for detection module.
        Args:
            in_channels (int): number of input image channels.
            pretrained (bool): if True, load ImageNet weights from ./weights/ when the file exists.
            **kwargs: optional 'scale' (default 0.5), 'model_name' ('large' or 'small',
                default 'large') and 'disable_se' (default True).
        """
        super().__init__()
        self.scale = kwargs.get('scale', 0.5)
        model_name = kwargs.get('model_name', 'large')
        # boolean default (the string 'True' was truthy and behaved the same)
        self.disable_se = kwargs.get('disable_se', True)
        self.inplanes = 16
        if model_name == "large":
            self.cfg = [
                # k, exp, c,  se,     nl,  s,
                [3, 16, 16, False, 'relu', 1],
                [3, 64, 24, False, 'relu', 2],
                [3, 72, 24, False, 'relu', 1],
                [5, 72, 40, True, 'relu', 2],
                [5, 120, 40, True, 'relu', 1],
                [5, 120, 40, True, 'relu', 1],
                [3, 240, 80, False, 'hard_swish', 2],
                [3, 200, 80, False, 'hard_swish', 1],
                [3, 184, 80, False, 'hard_swish', 1],
                [3, 184, 80, False, 'hard_swish', 1],
                [3, 480, 112, True, 'hard_swish', 1],
                [3, 672, 112, True, 'hard_swish', 1],
                [5, 672, 160, True, 'hard_swish', 2],
                [5, 960, 160, True, 'hard_swish', 1],
                [5, 960, 160, True, 'hard_swish', 1],
            ]
            self.cls_ch_squeeze = 960
            self.cls_ch_expand = 1280
        elif model_name == "small":
            self.cfg = [
                # k, exp, c,  se,     nl,  s,
                [3, 16, 16, True, 'relu', 2],
                [3, 72, 24, False, 'relu', 2],
                [3, 88, 24, False, 'relu', 1],
                [5, 96, 40, True, 'hard_swish', 2],
                [5, 240, 40, True, 'hard_swish', 1],
                [5, 240, 40, True, 'hard_swish', 1],
                [5, 120, 48, True, 'hard_swish', 1],
                [5, 144, 48, True, 'hard_swish', 1],
                [5, 288, 96, True, 'hard_swish', 2],
                [5, 576, 96, True, 'hard_swish', 1],
                [5, 576, 96, True, 'hard_swish', 1],
            ]
            self.cls_ch_squeeze = 576
            self.cls_ch_expand = 1280
        else:
            raise NotImplementedError("mode[" + model_name +
                                      "_model] is not implemented!")

        supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
        assert self.scale in supported_scale, \
            "supported scales are {} but input scale is {}".format(supported_scale, self.scale)

        scale = self.scale
        inplanes = self.inplanes
        cfg = self.cfg
        cls_ch_squeeze = self.cls_ch_squeeze
        # conv1
        self.conv1 = ConvBNACT(in_channels=in_channels,
                               out_channels=self.make_divisible(inplanes * scale),
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               groups=1,
                               act='hard_swish')
        i = 0
        inplanes = self.make_divisible(inplanes * scale)
        self.stages = nn.ModuleList()
        block_list = []
        self.out_channels = []
        for layer_cfg in cfg:
            se = layer_cfg[3] and not self.disable_se
            # start a new stage at every stride-2 block after the first few layers
            if layer_cfg[5] == 2 and i > 2:
                self.out_channels.append(inplanes)
                self.stages.append(nn.Sequential(*block_list))
                block_list = []
            block = ResidualUnit(num_in_filter=inplanes,
                                 num_mid_filter=self.make_divisible(scale * layer_cfg[1]),
                                 num_out_filter=self.make_divisible(scale * layer_cfg[2]),
                                 act=layer_cfg[4],
                                 stride=layer_cfg[5],
                                 kernel_size=layer_cfg[0],
                                 use_se=se)
            block_list.append(block)
            inplanes = self.make_divisible(scale * layer_cfg[2])
            i += 1
        block_list.append(ConvBNACT(
            in_channels=inplanes,
            out_channels=self.make_divisible(scale * cls_ch_squeeze),
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
            act='hard_swish'))
        self.stages.append(nn.Sequential(*block_list))
        self.out_channels.append(self.make_divisible(scale * cls_ch_squeeze))

        if pretrained:
            ckpt_path = f'./weights/MobileNetV3_{model_name}_x{str(scale).replace(".", "_")}.pth'
            logger = logging.getLogger('torchocr')
            if os.path.exists(ckpt_path):
                logger.info('load imagenet weights')
                self.load_state_dict(torch.load(ckpt_path))
            else:
                logger.info(f'{ckpt_path} does not exist')

    def make_divisible(self, v, divisor=8, min_value=None):
        # round v to the nearest multiple of divisor without dropping more than 10%
        if min_value is None:
            min_value = divisor
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v

def load_3rd_state_dict(self, _3rd_name, _state): 167 | if _3rd_name == 'paddle': 168 | self.conv1.load_3rd_state_dict(_3rd_name, _state, 'conv1') 169 | m_block_index = 2 170 | for m_stage in self.stages: 171 | for m_block in m_stage: 172 | m_block.load_3rd_state_dict(_3rd_name, _state, m_block_index) 173 | m_block_index += 1 174 | self.conv2.load_3rd_state_dict(_3rd_name, _state, 'conv_last') 175 | else: 176 | pass 177 | 178 | def forward(self, x): 179 | x = self.conv1(x) 180 | out = [] 181 | for stage in self.stages: 182 | x = stage(x) 183 | out.append(x) 184 | 185 | return out 186 | 187 | 188 | if __name__ == "__main__": 189 | from torchsummary import summary 190 | 191 | input = torch.randn(1, 3, 640, 640) 192 | net = MobileNetV3(in_channels=3,disable_se=True) 193 | out = net(input) 194 | print(len(out)) 195 | print(out[0].size()) 196 | print(out[1].size()) 197 | print(out[2].size()) 198 | print(out[3].size()) 199 | 200 | summary(net, input_size=(3, 640, 640), batch_size=1, device="cpu") 201 | -------------------------------------------------------------------------------- /det/DB_fpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from det.DetMobilenetV3 import SEBlock 6 | 7 | class DSConv(nn.Module): 8 | def __init__(self, 9 | in_channels, 10 | out_channels, 11 | kernel_size, 12 | padding, 13 | stride=1, 14 | groups=None, 15 | if_act=True, 16 | act="relu", 17 | **kwargs): 18 | super(DSConv, self).__init__() 19 | if groups == None: 20 | groups = in_channels 21 | self.if_act = if_act 22 | self.act = act 23 | 24 | self.conv1 = nn.Conv2d( 25 | in_channels=in_channels, 26 | out_channels=in_channels, 27 | kernel_size=kernel_size, 28 | stride=stride, 29 | padding=padding, 30 | groups=groups, 31 | bias=False) 32 | self.conv1 = nn.Conv2d(in_channels=in_channels, 33 | out_channels=in_channels, 34 | kernel_size=kernel_size, 35 | 
stride=stride, 36 | padding=padding, 37 | groups=groups, 38 | bias=False) 39 | 40 | self.bn1 = nn.BatchNorm2d(in_channels) 41 | 42 | self.conv2 = nn.Conv2d( 43 | in_channels=in_channels, 44 | out_channels=int(in_channels * 4), 45 | kernel_size=1, 46 | stride=1, 47 | bias=False) 48 | 49 | self.bn2 = nn.BatchNorm2d(int(in_channels * 4)) 50 | 51 | self.conv3 = nn.Conv2d( 52 | in_channels=int(in_channels * 4), 53 | out_channels=out_channels, 54 | kernel_size=1, 55 | stride=1, 56 | bias=False) 57 | self._c = [in_channels, out_channels] 58 | if in_channels != out_channels: 59 | self.conv_end = nn.Conv2d( 60 | in_channels=in_channels, 61 | out_channels=out_channels, 62 | kernel_size=1, 63 | stride=1, 64 | bias=False) 65 | 66 | def forward(self, inputs): 67 | 68 | x = self.conv1(inputs) 69 | x = self.bn1(x) 70 | 71 | x = self.conv2(x) 72 | x = self.bn2(x) 73 | if self.if_act: 74 | if self.act == "relu": 75 | x = F.relu(x) 76 | elif self.act == "hardswish": 77 | x = F.hardswish(x) 78 | else: 79 | print("The activation function({}) is selected incorrectly.". 
80 | format(self.act)) 81 | raise ValueError("unsupported activation: {}".format(self.act)) 82 | 83 | x = self.conv3(x) 84 | if self._c[0] != self._c[1]: 85 | x = x + self.conv_end(inputs) 86 | return x 87 | 88 | class DB_fpn(nn.Module): 89 | def __init__(self, in_channels, out_channels=256, **kwargs): 90 | """ 91 | :param in_channels: output channel counts of the backbone feature maps (c2, c3, c4, c5) 92 | :param kwargs: 93 | """ 94 | super().__init__() 95 | inplace = True 96 | self.out_channels = out_channels 97 | # reduce layers 98 | self.in2_conv = nn.Conv2d(in_channels[0], self.out_channels, kernel_size=1, bias=False) 99 | self.in3_conv = nn.Conv2d(in_channels[1], self.out_channels, kernel_size=1, bias=False) 100 | self.in4_conv = nn.Conv2d(in_channels[2], self.out_channels, kernel_size=1, bias=False) 101 | self.in5_conv = nn.Conv2d(in_channels[3], self.out_channels, kernel_size=1, bias=False) 102 | # Smooth layers 103 | self.p5_conv = nn.Conv2d(self.out_channels, self.out_channels // 4, kernel_size=3, padding=1, bias=False) 104 | self.p4_conv = nn.Conv2d(self.out_channels, self.out_channels // 4, kernel_size=3, padding=1, bias=False) 105 | self.p3_conv = nn.Conv2d(self.out_channels, self.out_channels // 4, kernel_size=3, padding=1, bias=False) 106 | self.p2_conv = nn.Conv2d(self.out_channels, self.out_channels // 4, kernel_size=3, padding=1, bias=False) 107 | 108 | def _upsample_add(self, x, y): 109 | return F.interpolate(x, scale_factor=2) + y 110 | 111 | def _upsample_cat(self, p2, p3, p4, p5): 112 | p3 = F.interpolate(p3, scale_factor=2) 113 | p4 = F.interpolate(p4, scale_factor=4) 114 | p5 = F.interpolate(p5, scale_factor=8) 115 | return torch.cat([p5, p4, p3, p2], dim=1) 116 | 117 | def forward(self, x): 118 | c2, c3, c4, c5 = x 119 | in5 = self.in5_conv(c5) 120 | in4 = self.in4_conv(c4) 121 | in3 = self.in3_conv(c3) 122 | in2 = self.in2_conv(c2) 123 | 124 | out4 = self._upsample_add(in5, in4) 125 | out3 = self._upsample_add(out4, in3) 126 | out2 = self._upsample_add(out3, in2) 127 | 128 | p5 = self.p5_conv(in5) 129 | p4 = self.p4_conv(out4) 130 | p3 =
self.p3_conv(out3) 131 | p2 = self.p2_conv(out2) 132 | 133 | x = self._upsample_cat(p2, p3, p4, p5) 134 | return x 135 | 136 | class RSELayer(nn.Module): 137 | def __init__(self, in_channels, out_channels, kernel_size, shortcut=True): 138 | super(RSELayer, self).__init__() 139 | # weight_attr = torch.nn.init.kaiming_uniform() 140 | self.out_channels = out_channels 141 | self.in_conv = nn.Conv2d( 142 | in_channels=in_channels, 143 | out_channels=self.out_channels, 144 | kernel_size=kernel_size, 145 | padding=int(kernel_size // 2), 146 | # weight_attr=ParamAttr(initializer=weight_attr), 147 | bias=False) 148 | self.se_block = SEBlock(self.out_channels,self.out_channels) 149 | self.shortcut = shortcut 150 | 151 | def forward(self, ins): 152 | x = self.in_conv(ins) 153 | if self.shortcut: 154 | out = x + self.se_block(x) 155 | else: 156 | out = self.se_block(x) 157 | return out 158 | 159 | 160 | class RSEFPN(nn.Module): 161 | def __init__(self, in_channels, out_channels=256, shortcut=True, **kwargs): 162 | super(RSEFPN, self).__init__() 163 | self.out_channels = out_channels 164 | self.ins_conv = nn.ModuleList() 165 | self.inp_conv = nn.ModuleList() 166 | 167 | for i in range(len(in_channels)): 168 | self.ins_conv.append( 169 | RSELayer( 170 | in_channels[i], 171 | out_channels, 172 | kernel_size=1, 173 | shortcut=shortcut)) 174 | self.inp_conv.append( 175 | RSELayer( 176 | out_channels, 177 | out_channels // 4, 178 | kernel_size=3, 179 | shortcut=shortcut)) 180 | 181 | def _upsample_add(self, x, y): 182 | return F.interpolate(x, scale_factor=2) + y 183 | 184 | def _upsample_cat(self, p2, p3, p4, p5): 185 | p3 = F.interpolate(p3, scale_factor=2) 186 | p4 = F.interpolate(p4, scale_factor=4) 187 | p5 = F.interpolate(p5, scale_factor=8) 188 | return torch.cat([p5, p4, p3, p2], dim=1) 189 | 190 | def forward(self, x): 191 | c2, c3, c4, c5 = x 192 | 193 | in5 = self.ins_conv[3](c5) 194 | in4 = self.ins_conv[2](c4) 195 | in3 = self.ins_conv[1](c3) 196 | in2 = 
self.ins_conv[0](c2) 197 | 198 | out4 = self._upsample_add(in5, in4) 199 | out3 = self._upsample_add(out4, in3) 200 | out2 = self._upsample_add(out3, in2) 201 | 202 | p5 = self.inp_conv[3](in5) 203 | p4 = self.inp_conv[2](out4) 204 | p3 = self.inp_conv[1](out3) 205 | p2 = self.inp_conv[0](out2) 206 | 207 | x = self._upsample_cat(p2, p3, p4, p5) 208 | return x 209 | 210 | 211 | class LKPAN(nn.Module): 212 | def __init__(self, in_channels, out_channels, mode='large', **kwargs): 213 | super(LKPAN, self).__init__() 214 | self.out_channels = out_channels 215 | # weight_attr = torch.nn.init.kaiming_uniform() 216 | 217 | self.ins_conv = nn.ModuleList() 218 | self.inp_conv = nn.ModuleList() 219 | # pan head 220 | self.pan_head_conv = nn.ModuleList() 221 | self.pan_lat_conv = nn.ModuleList() 222 | 223 | if mode.lower() == 'lite': 224 | p_layer = DSConv 225 | elif mode.lower() == 'large': 226 | p_layer = nn.Conv2d 227 | else: 228 | raise ValueError( 229 | "mode can only be one of ['lite', 'large'], but received {}". 
230 | format(mode)) 231 | 232 | for i in range(len(in_channels)): 233 | self.ins_conv.append( 234 | nn.Conv2d( 235 | in_channels=in_channels[i], 236 | out_channels=self.out_channels, 237 | kernel_size=1, 238 | # weight_attr=ParamAttr(initializer=weight_attr), 239 | bias=False)) 240 | 241 | self.inp_conv.append( 242 | p_layer( 243 | in_channels=self.out_channels, 244 | out_channels=self.out_channels // 4, 245 | kernel_size=9, 246 | padding=4, 247 | # weight_attr=ParamAttr(initializer=weight_attr), 248 | bias=False)) # was bias_attr (Paddle's name), which nn.Conv2d rejects in 'large' mode 249 | 250 | if i > 0: 251 | self.pan_head_conv.append( 252 | nn.Conv2d( 253 | in_channels=self.out_channels // 4, 254 | out_channels=self.out_channels // 4, 255 | kernel_size=3, 256 | padding=1, 257 | stride=2, 258 | # weight_attr=ParamAttr(initializer=weight_attr), 259 | bias=False)) 260 | self.pan_lat_conv.append( 261 | p_layer( 262 | in_channels=self.out_channels // 4, 263 | out_channels=self.out_channels // 4, 264 | kernel_size=9, 265 | padding=4, 266 | # weight_attr=ParamAttr(initializer=weight_attr), 267 | bias=False)) 268 | 269 | def _upsample_add(self, x, y): 270 | return F.interpolate(x, scale_factor=2) + y 271 | 272 | def _upsample_cat(self, p2, p3, p4, p5): 273 | p3 = F.interpolate(p3, scale_factor=2) 274 | p4 = F.interpolate(p4, scale_factor=4) 275 | p5 = F.interpolate(p5, scale_factor=8) 276 | return torch.cat([p5, p4, p3, p2], dim=1) 277 | 278 | def forward(self, x): 279 | c2, c3, c4, c5 = x 280 | 281 | in5 = self.ins_conv[3](c5) 282 | in4 = self.ins_conv[2](c4) 283 | in3 = self.ins_conv[1](c3) 284 | in2 = self.ins_conv[0](c2) 285 | 286 | out4 = self._upsample_add(in5, in4) 287 | out3 = self._upsample_add(out4, in3) 288 | out2 = self._upsample_add(out3, in2) 289 | 290 | f5 = self.inp_conv[3](in5) 291 | f4 = self.inp_conv[2](out4) 292 | f3 = self.inp_conv[1](out3) 293 | f2 = self.inp_conv[0](out2) 294 | 295 | pan3 = f3 + self.pan_head_conv[0](f2) 296 | pan4 = f4 + self.pan_head_conv[1](pan3) 297 | pan5 = f5 + self.pan_head_conv[2](pan4)
298 | 299 | p2 = self.pan_lat_conv[0](f2) 300 | p3 = self.pan_lat_conv[1](pan3) 301 | p4 = self.pan_lat_conv[2](pan4) 302 | p5 = self.pan_lat_conv[3](pan5) 303 | 304 | x = self._upsample_cat(p2, p3, p4, p5) 305 | return x 306 | 307 | -------------------------------------------------------------------------------- /rec/RecSARHead.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import math 5 | 6 | class SAREncoder(nn.Module): 7 | """ 8 | Args: 9 | enc_bi_rnn (bool): If True, use bidirectional RNN in encoder. 10 | enc_drop_rnn (float): Dropout probability of RNN layer in encoder. 11 | enc_gru (bool): If True, use GRU, else LSTM in encoder. 12 | d_model (int): Dim of channels from backbone. 13 | d_enc (int): Dim of encoder RNN layer. 14 | mask (bool): If True, mask padding in RNN sequence. 15 | """ 16 | 17 | def __init__(self, 18 | enc_bi_rnn=False, 19 | enc_drop_rnn=0.1, 20 | enc_gru=False, 21 | d_model=512, 22 | d_enc=512, 23 | mask=True, 24 | **kwargs): 25 | super().__init__() 26 | assert isinstance(enc_bi_rnn, bool) 27 | assert isinstance(enc_drop_rnn, (int, float)) 28 | assert 0 <= enc_drop_rnn < 1.0 29 | assert isinstance(enc_gru, bool) 30 | assert isinstance(d_model, int) 31 | assert isinstance(d_enc, int) 32 | assert isinstance(mask, bool) 33 | 34 | self.enc_bi_rnn = enc_bi_rnn 35 | self.enc_drop_rnn = enc_drop_rnn 36 | self.mask = mask 37 | 38 | # LSTM Encoder 39 | if enc_bi_rnn: 40 | # direction = 'bidirectional' 41 | bidirectional = True 42 | else: 43 | # direction = 'forward' 44 | bidirectional = False 45 | kwargs = dict( 46 | input_size=d_model, 47 | hidden_size=d_enc, 48 | num_layers=2, 49 | batch_first=True, # Paddle's time_major=False; PyTorch RNNs default to (T, B, C) input 50 | dropout=enc_drop_rnn, 51 | bidirectional=bidirectional) 52 | if enc_gru: 53 | self.rnn_encoder = nn.GRU(**kwargs) 54 | else: 55 | self.rnn_encoder = nn.LSTM(**kwargs) 56 | 57 | # global feature transformation 58 |
encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1) 59 | self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size) 60 | 61 | def forward(self, feat, img_metas=None): 62 | if img_metas is not None: 63 | assert len(img_metas[0]) == feat.shape[0] 64 | 65 | valid_ratios = None 66 | if img_metas is not None and self.mask: 67 | valid_ratios = img_metas[-1] 68 | 69 | h_feat = feat.shape[2] # bsz c h w 70 | feat_v = F.max_pool2d( 71 | feat, kernel_size=(h_feat, 1), stride=1, padding=0) 72 | feat_v = feat_v.squeeze(2) # bsz * C * W 73 | feat_v = feat_v.permute([0, 2, 1]) 74 | # feat_v = paddle.transpose(feat_v, perm=[0, 2, 1]) # bsz * W * C 75 | holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C 76 | 77 | if valid_ratios is not None: 78 | valid_hf = [] 79 | T = holistic_feat.shape[1] 80 | for i in range(len(valid_ratios)): 81 | valid_step = min(T, math.ceil(T * valid_ratios[i])) - 1 82 | valid_hf.append(holistic_feat[i, valid_step, :]) 83 | valid_hf = torch.stack(valid_hf, dim=0) 84 | else: 85 | valid_hf = holistic_feat[:, -1, :] # bsz * C 86 | holistic_feat = self.linear(valid_hf) # bsz * C 87 | 88 | return holistic_feat 89 | 90 | 91 | class BaseDecoder(nn.Module): 92 | def __init__(self, **kwargs): 93 | super().__init__() 94 | 95 | def forward_train(self, feat, out_enc, targets, img_metas): 96 | raise NotImplementedError 97 | 98 | def forward_test(self, feat, out_enc, img_metas): 99 | raise NotImplementedError 100 | 101 | def forward(self, 102 | feat, 103 | out_enc, 104 | label=None, 105 | img_metas=None, 106 | train_mode=True): 107 | self.train_mode = train_mode 108 | 109 | if train_mode: 110 | return self.forward_train(feat, out_enc, label, img_metas) 111 | return self.forward_test(feat, out_enc, img_metas) 112 | 113 | 114 | class ParallelSARDecoder(BaseDecoder): 115 | """ 116 | Args: 117 | out_channels (int): Output class number. 118 | enc_bi_rnn (bool): If True, use bidirectional RNN in encoder. 
119 | dec_bi_rnn (bool): If True, use bidirectional RNN in decoder. 120 | dec_drop_rnn (float): Dropout of RNN layer in decoder. 121 | dec_gru (bool): If True, use GRU, else LSTM in decoder. 122 | d_model (int): Dim of channels from backbone. 123 | d_enc (int): Dim of encoder RNN layer. 124 | d_k (int): Dim of channels of attention module. 125 | pred_dropout (float): Dropout probability of prediction layer. 126 | max_seq_len (int): Maximum sequence length for decoding. 127 | mask (bool): If True, mask padding in feature map. 128 | start_idx (int): Index of start token. 129 | padding_idx (int): Index of padding token. 130 | pred_concat (bool): If True, concat glimpse feature from 131 | attention with holistic feature and hidden state. 132 | """ 133 | 134 | def __init__( 135 | self, 136 | out_channels, # 90 + unknown + start + padding 137 | enc_bi_rnn=False, 138 | dec_bi_rnn=False, 139 | dec_drop_rnn=0.0, 140 | dec_gru=False, 141 | d_model=512, 142 | d_enc=512, 143 | d_k=64, 144 | pred_dropout=0.1, 145 | max_text_length=30, 146 | mask=True, 147 | pred_concat=True, 148 | **kwargs): 149 | super().__init__() 150 | 151 | self.num_classes = out_channels 152 | self.enc_bi_rnn = enc_bi_rnn 153 | self.d_k = d_k 154 | self.start_idx = out_channels - 2 155 | self.padding_idx = out_channels - 1 156 | self.max_seq_len = max_text_length 157 | self.mask = mask 158 | self.pred_concat = pred_concat 159 | 160 | encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1) 161 | decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1) 162 | 163 | # 2D attention layer 164 | self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k) 165 | self.conv3x3_1 = nn.Conv2d( 166 | d_model, d_k, kernel_size=3, stride=1, padding=1) 167 | self.conv1x1_2 = nn.Linear(d_k, 1) 168 | 169 | # Decoder RNN layer 170 | if dec_bi_rnn: 171 | # direction = 'bidirectional' 172 | bidirectional = True 173 | else: 174 | bidirectional = False 175 | # direction = 'forward' 176 | 177 | kwargs = dict( 178 | 
input_size=encoder_rnn_out_size, 179 | hidden_size=encoder_rnn_out_size, 180 | num_layers=2, 181 | batch_first=True, # Paddle's time_major=False; decoder_input is bsz * (seq_len + 1) * C 182 | dropout=dec_drop_rnn, 183 | bidirectional=bidirectional) 184 | if dec_gru: 185 | self.rnn_decoder = nn.GRU(**kwargs) 186 | else: 187 | self.rnn_decoder = nn.LSTM(**kwargs) 188 | 189 | # Decoder input embedding 190 | self.embedding = nn.Embedding( 191 | self.num_classes, 192 | encoder_rnn_out_size, 193 | padding_idx=self.padding_idx) 194 | 195 | # Prediction layer 196 | self.pred_dropout = nn.Dropout(pred_dropout) 197 | pred_num_classes = self.num_classes - 1 198 | if pred_concat: 199 | fc_in_channel = decoder_rnn_out_size + d_model + encoder_rnn_out_size 200 | else: 201 | fc_in_channel = d_model 202 | self.prediction = nn.Linear(fc_in_channel, pred_num_classes) 203 | 204 | def _2d_attention(self, 205 | decoder_input, 206 | feat, 207 | holistic_feat, 208 | valid_ratios=None): 209 | 210 | y = self.rnn_decoder(decoder_input)[0] 211 | # y: bsz * (seq_len + 1) * hidden_size 212 | 213 | attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size 214 | bsz, seq_len, attn_size = attn_query.shape 215 | attn_query = torch.unsqueeze(attn_query, dim=3) 216 | attn_query = torch.unsqueeze(attn_query, dim=4) 217 | # (bsz, seq_len + 1, attn_size, 1, 1) 218 | 219 | attn_key = self.conv3x3_1(feat) 220 | # bsz * attn_size * h * w 221 | attn_key = attn_key.unsqueeze(1) 222 | # bsz * 1 * attn_size * h * w 223 | 224 | attn_weight = torch.tanh(torch.add(attn_key, attn_query)) 225 | 226 | # bsz * (seq_len + 1) * attn_size * h * w 227 | attn_weight = attn_weight.permute([0,1,3,4,2]) 228 | # attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2]) 229 | # bsz * (seq_len + 1) * h * w * attn_size 230 | attn_weight = self.conv1x1_2(attn_weight) 231 | # bsz * (seq_len + 1) * h * w * 1 232 | bsz, T, h, w, c = attn_weight.shape 233 | assert c == 1 234 | 235 | if valid_ratios is not None: 236 | # cal mask of attention weight 237 | for i in
range(len(valid_ratios)): 238 | valid_width = min(w, math.ceil(w * valid_ratios[i])) 239 | if valid_width < w: 240 | attn_weight[i, :, :, valid_width:, :] = float('-inf') 241 | 242 | attn_weight = torch.reshape(attn_weight, [bsz, T, -1]) 243 | attn_weight = F.softmax(attn_weight, dim=-1) 244 | 245 | attn_weight = torch.reshape(attn_weight, [bsz, T, h, w, c]) 246 | attn_weight = attn_weight.permute([0,1,4,2,3]) 247 | # attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3]) 248 | # attn_weight: bsz * T * c * h * w 249 | # feat: bsz * c * h * w 250 | attn_feat = torch.sum(torch.multiply(feat.unsqueeze(1), attn_weight), 251 | (3, 4), 252 | keepdim=False) 253 | # bsz * (seq_len + 1) * C 254 | 255 | # Linear transformation 256 | if self.pred_concat: 257 | hf_c = holistic_feat.shape[-1] 258 | # holistic_feat = paddle.expand(holistic_feat, shape=[bsz, seq_len, hf_c]) 259 | holistic_feat = holistic_feat.expand([bsz, seq_len, hf_c]) 260 | y = self.prediction(torch.cat((y, attn_feat, holistic_feat), 2)) 261 | else: 262 | y = self.prediction(attn_feat) 263 | # bsz * (seq_len + 1) * num_classes 264 | if self.train_mode: 265 | y = self.pred_dropout(y) 266 | 267 | return y 268 | 269 | def forward_train(self, feat, out_enc, label, img_metas): 270 | ''' 271 | img_metas: [label, valid_ratio] 272 | ''' 273 | if img_metas is not None: 274 | assert len(img_metas[0]) == feat.shape[0] 275 | 276 | valid_ratios = None 277 | if img_metas is not None and self.mask: 278 | valid_ratios = img_metas[-1] 279 | 280 | lab_embedding = self.embedding(label) 281 | # bsz * seq_len * emb_dim 282 | out_enc = out_enc.unsqueeze(1) 283 | # bsz * 1 * emb_dim 284 | in_dec = torch.cat((out_enc, lab_embedding), dim=1) 285 | # bsz * (seq_len + 1) * C 286 | out_dec = self._2d_attention( 287 | in_dec, feat, out_enc, valid_ratios=valid_ratios) 288 | # bsz * (seq_len + 1) * num_classes 289 | 290 | return out_dec[:, 1:, :] # bsz * seq_len * num_classes 291 | 292 | def forward_test(self, feat, out_enc, 
img_metas): 293 | if img_metas is not None: 294 | assert len(img_metas[0]) == feat.shape[0] 295 | 296 | valid_ratios = None 297 | if img_metas is not None and self.mask: 298 | valid_ratios = img_metas[-1] 299 | 300 | seq_len = self.max_seq_len 301 | bsz = feat.shape[0] 302 | start_token = torch.full( 303 | (bsz, ), fill_value=self.start_idx, dtype=torch.long, device=feat.device) # same device as the embedding weights 304 | # bsz 305 | start_token = self.embedding(start_token) 306 | # bsz * emb_dim 307 | emb_dim = start_token.shape[1] 308 | start_token = start_token.unsqueeze(1) 309 | start_token = start_token.expand([bsz, seq_len, emb_dim]) 310 | # start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim]) 311 | # bsz * seq_len * emb_dim 312 | out_enc = out_enc.unsqueeze(1) 313 | # bsz * 1 * emb_dim 314 | decoder_input = torch.cat((out_enc, start_token), dim=1) 315 | # bsz * (seq_len + 1) * emb_dim 316 | 317 | outputs = [] 318 | for i in range(1, seq_len + 1): 319 | decoder_output = self._2d_attention( 320 | decoder_input, feat, out_enc, valid_ratios=valid_ratios) 321 | char_output = decoder_output[:, i, :] # bsz * num_classes 322 | char_output = F.softmax(char_output, -1) 323 | outputs.append(char_output) 324 | max_idx = torch.argmax(char_output, dim=1, keepdim=False) 325 | char_embedding = self.embedding(max_idx) # bsz * emb_dim 326 | if i < seq_len: 327 | decoder_input[:, i + 1, :] = char_embedding 328 | 329 | outputs = torch.stack(outputs, 1) # bsz * seq_len * num_classes 330 | 331 | return outputs 332 | 333 | class SARHead(nn.Module): 334 | def __init__(self, 335 | in_channels, 336 | out_channels, 337 | enc_dim=512, 338 | max_text_length=30, 339 | enc_bi_rnn=False, 340 | enc_drop_rnn=0.1, 341 | enc_gru=False, 342 | dec_bi_rnn=False, 343 | dec_drop_rnn=0.0, 344 | dec_gru=False, 345 | d_k=512, 346 | pred_dropout=0.1, 347 | pred_concat=True, 348 | **kwargs): 349 | super(SARHead, self).__init__() 350 | 351 | # encoder module 352 | self.encoder = SAREncoder( 353 | enc_bi_rnn=enc_bi_rnn, 354 | enc_drop_rnn=enc_drop_rnn,
355 | enc_gru=enc_gru, 356 | d_model=in_channels, 357 | d_enc=enc_dim) 358 | 359 | # decoder module 360 | self.decoder = ParallelSARDecoder( 361 | out_channels=out_channels, 362 | enc_bi_rnn=enc_bi_rnn, 363 | dec_bi_rnn=dec_bi_rnn, 364 | dec_drop_rnn=dec_drop_rnn, 365 | dec_gru=dec_gru, 366 | d_model=in_channels, 367 | d_enc=enc_dim, 368 | d_k=d_k, 369 | pred_dropout=pred_dropout, 370 | max_text_length=max_text_length, 371 | pred_concat=pred_concat) 372 | 373 | def forward(self, feat, targets=None): 374 | ''' 375 | img_metas: [label, valid_ratio] 376 | ''' 377 | 378 | holistic_feat = self.encoder(feat, targets) # bsz c 379 | 380 | if self.training: 381 | label = targets[0] # label 382 | label = torch.tensor(label).long() 383 | final_out = self.decoder( 384 | feat, holistic_feat, label, img_metas=targets) 385 | else: 386 | final_out = self.decoder( 387 | feat, 388 | holistic_feat, 389 | label=None, 390 | img_metas=targets, 391 | train_mode=False) 392 | # (bsz, seq_len, num_classes) 393 | 394 | return final_out 395 | 396 | if __name__=="__main__": 397 | sarh = SARHead(512,6625) 398 | print(sarh) -------------------------------------------------------------------------------- /rec/RecSVTR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch.nn.init import trunc_normal_,constant,normal_,zeros_,ones_ 5 | from torch.nn import functional 6 | 7 | def drop_path(x, drop_prob=0., training=False): 8 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 9 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 10 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... 11 | """ 12 | if drop_prob == 0. 
or not training: 13 | return x 14 | keep_prob = torch.tensor(1 - drop_prob) 15 | shape = (x.size()[0], ) + (1, ) * (x.ndim - 1) 16 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) # sample on x's device 17 | random_tensor = torch.floor(random_tensor) # binarize 18 | output = x.divide(keep_prob) * random_tensor 19 | return output 20 | 21 | class Swish(nn.Module): 22 | def __init__(self): 23 | super(Swish, self).__init__() 24 | 25 | def forward(self,x): 26 | return x*torch.sigmoid(x) 27 | 28 | class ConvBNLayer(nn.Module): 29 | def __init__(self, 30 | in_channels, 31 | out_channels, 32 | kernel_size=3, 33 | stride=1, 34 | padding=0, 35 | bias_attr=False, 36 | groups=1, 37 | act=nn.GELU): 38 | super().__init__() 39 | self.conv = nn.Conv2d( 40 | in_channels=in_channels, 41 | out_channels=out_channels, 42 | kernel_size=kernel_size, 43 | stride=stride, 44 | padding=padding, 45 | groups=groups, 46 | # weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()), 47 | bias=bias_attr) 48 | self.norm = nn.BatchNorm2d(out_channels) 49 | self.act = act() 50 | 51 | def forward(self, inputs): 52 | out = self.conv(inputs) 53 | out = self.norm(out) 54 | out = self.act(out) 55 | return out 56 | 57 | 58 | class DropPath(nn.Module): 59 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
60 | """ 61 | 62 | def __init__(self, drop_prob=None): 63 | super(DropPath, self).__init__() 64 | self.drop_prob = drop_prob 65 | 66 | def forward(self, x): 67 | return drop_path(x, self.drop_prob, self.training) 68 | 69 | 70 | class Identity(nn.Module): 71 | def __init__(self): 72 | super(Identity, self).__init__() 73 | 74 | def forward(self, input): 75 | return input 76 | 77 | 78 | class Mlp(nn.Module): 79 | def __init__(self, 80 | in_features, 81 | hidden_features=None, 82 | out_features=None, 83 | act_layer=nn.GELU, 84 | drop=0.): 85 | super().__init__() 86 | out_features = out_features or in_features 87 | hidden_features = hidden_features or in_features 88 | self.fc1 = nn.Linear(in_features, hidden_features) 89 | if isinstance(act_layer, str): 90 | self.act = Swish() 91 | else: 92 | self.act = act_layer() 93 | self.fc2 = nn.Linear(hidden_features, out_features) 94 | self.drop = nn.Dropout(drop) 95 | 96 | def forward(self, x): 97 | x = self.fc1(x) 98 | x = self.act(x) 99 | x = self.drop(x) 100 | x = self.fc2(x) 101 | x = self.drop(x) 102 | return x 103 | 104 | 105 | class ConvMixer(nn.Module): 106 | def __init__( 107 | self, 108 | dim, 109 | num_heads=8, 110 | HW=(8, 25), 111 | local_k=(3, 3), ): 112 | super().__init__() 113 | self.HW = HW 114 | self.dim = dim 115 | self.local_mixer = nn.Conv2d( 116 | dim, 117 | dim, 118 | local_k, 119 | 1, (local_k[0] // 2, local_k[1] // 2), 120 | groups=num_heads, 121 | # weight_attr=ParamAttr(initializer=KaimingNormal()) 122 | ) 123 | 124 | def forward(self, x): 125 | h = self.HW[0] 126 | w = self.HW[1] 127 | x = x.permute(0, 2, 1).reshape(-1, self.dim, h, w) # (B, N, C) -> (B, C, H, W); Paddle's reshape([0, ...]) has no PyTorch equivalent 128 | x = self.local_mixer(x) 129 | x = x.flatten(2).permute(0, 2, 1) # (B, C, H, W) -> (B, N, C) 130 | return x 131 | 132 | 133 | class Attention(nn.Module): 134 | def __init__(self, 135 | dim, 136 | num_heads=8, 137 | mixer='Global', 138 | HW=(8, 25), 139 | local_k=(7, 11), 140 | qkv_bias=False, 141 | qk_scale=None, 142 | attn_drop=0., 143 | proj_drop=0.): 144 | super().__init__() 145 |
self.num_heads = num_heads 146 | head_dim = dim // num_heads 147 | self.scale = qk_scale or head_dim**-0.5 148 | 149 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 150 | self.attn_drop = nn.Dropout(attn_drop) 151 | self.proj = nn.Linear(dim, dim) 152 | self.proj_drop = nn.Dropout(proj_drop) 153 | self.HW = HW 154 | if HW is not None: 155 | H = HW[0] 156 | W = HW[1] 157 | self.N = H * W 158 | self.C = dim 159 | if mixer == 'Local' and HW is not None: 160 | hk = local_k[0] 161 | wk = local_k[1] 162 | mask = torch.ones([H * W, H + hk - 1, W + wk - 1]) 163 | for h in range(0, H): 164 | for w in range(0, W): 165 | mask[h * W + w, h:h + hk, w:w + wk] = 0. 166 | mask_paddle = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk // 167 | 2].flatten(1) 168 | mask_inf = torch.full([H * W, H * W],fill_value=float('-inf')) 169 | mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf) 170 | self.register_buffer('mask', mask[None, None, :], persistent=False) # buffer moves with .to(device); persistent=False keeps it out of state_dict, like the former plain attribute 171 | # self.mask = mask.unsqueeze([0, 1]) 172 | self.mixer = mixer 173 | 174 | def forward(self, x): 175 | if self.HW is not None: 176 | N = self.N 177 | C = self.C 178 | else: 179 | _, N, C = x.shape 180 | qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C //self.num_heads)).permute((2, 0, 3, 1, 4)) 181 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 182 | 183 | attn = (q.matmul(k.permute((0, 1, 3, 2)))) 184 | if self.mixer == 'Local': 185 | attn += self.mask 186 | attn = functional.softmax(attn, dim=-1) 187 | attn = self.attn_drop(attn) 188 | 189 | x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C)) 190 | x = self.proj(x) 191 | x = self.proj_drop(x) 192 | return x 193 | 194 | 195 | class Block(nn.Module): 196 | def __init__(self, 197 | dim, 198 | num_heads, 199 | mixer='Global', 200 | local_mixer=(7, 11), 201 | HW=(8, 25), 202 | mlp_ratio=4., 203 | qkv_bias=False, 204 | qk_scale=None, 205 | drop=0., 206 | attn_drop=0., 207 | drop_path=0., 208 | act_layer=nn.GELU, 209 | norm_layer='nn.LayerNorm', 210 | epsilon=1e-6, 211 | prenorm=True): 212 |
super().__init__() 213 | if isinstance(norm_layer, str): 214 | self.norm1 = eval(norm_layer)(dim, eps=epsilon) 215 | else: 216 | self.norm1 = norm_layer(dim) 217 | if mixer == 'Global' or mixer == 'Local': 218 | 219 | self.mixer = Attention( 220 | dim, 221 | num_heads=num_heads, 222 | mixer=mixer, 223 | HW=HW, 224 | local_k=local_mixer, 225 | qkv_bias=qkv_bias, 226 | qk_scale=qk_scale, 227 | attn_drop=attn_drop, 228 | proj_drop=drop) 229 | elif mixer == 'Conv': 230 | self.mixer = ConvMixer( 231 | dim, num_heads=num_heads, HW=HW, local_k=local_mixer) 232 | else: 233 | raise TypeError("The mixer must be one of [Global, Local, Conv]") 234 | 235 | self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity() 236 | if isinstance(norm_layer, str): 237 | self.norm2 = eval(norm_layer)(dim, eps=epsilon) 238 | else: 239 | self.norm2 = norm_layer(dim) 240 | mlp_hidden_dim = int(dim * mlp_ratio) 241 | self.mlp_ratio = mlp_ratio 242 | self.mlp = Mlp(in_features=dim, 243 | hidden_features=mlp_hidden_dim, 244 | act_layer=act_layer, 245 | drop=drop) 246 | self.prenorm = prenorm 247 | 248 | def forward(self, x): 249 | if self.prenorm: 250 | x = self.norm1(x + self.drop_path(self.mixer(x))) 251 | x = self.norm2(x + self.drop_path(self.mlp(x))) 252 | else: 253 | x = x + self.drop_path(self.mixer(self.norm1(x))) 254 | x = x + self.drop_path(self.mlp(self.norm2(x))) 255 | return x 256 | 257 | 258 | class PatchEmbed(nn.Module): 259 | """ Image to Patch Embedding 260 | """ 261 | 262 | def __init__(self, 263 | img_size=(32, 100), 264 | in_channels=3, 265 | embed_dim=768, 266 | sub_num=2): 267 | super().__init__() 268 | num_patches = (img_size[1] // (2 ** sub_num)) * \ 269 | (img_size[0] // (2 ** sub_num)) 270 | self.img_size = img_size 271 | self.num_patches = num_patches 272 | self.embed_dim = embed_dim 273 | self.norm = None 274 | if sub_num == 2: 275 | self.proj = nn.Sequential( 276 | ConvBNLayer( 277 | in_channels=in_channels, 278 | out_channels=embed_dim // 2, 279 | 
kernel_size=3, 280 | stride=2, 281 | padding=1, 282 | act=nn.GELU, 283 | bias_attr=False), 284 | ConvBNLayer( 285 | in_channels=embed_dim // 2, 286 | out_channels=embed_dim, 287 | kernel_size=3, 288 | stride=2, 289 | padding=1, 290 | act=nn.GELU, 291 | bias_attr=False)) 292 | if sub_num == 3: 293 | self.proj = nn.Sequential( 294 | ConvBNLayer( 295 | in_channels=in_channels, 296 | out_channels=embed_dim // 4, 297 | kernel_size=3, 298 | stride=2, 299 | padding=1, 300 | act=nn.GELU, 301 | bias_attr=False), 302 | ConvBNLayer( 303 | in_channels=embed_dim // 4, 304 | out_channels=embed_dim // 2, 305 | kernel_size=3, 306 | stride=2, 307 | padding=1, 308 | act=nn.GELU, 309 | bias_attr=False), 310 | ConvBNLayer( 311 | in_channels=embed_dim // 2, 312 | out_channels=embed_dim, 313 | kernel_size=3, 314 | stride=2, 315 | padding=1, 316 | act=nn.GELU, 317 | bias_attr=False)) 318 | 319 | def forward(self, x): 320 | B, C, H, W = x.shape 321 | assert H == self.img_size[0] and W == self.img_size[1], \ 322 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
323 | x = self.proj(x).flatten(2).permute(0, 2, 1) 324 | return x 325 | 326 | 327 | class SubSample(nn.Module): 328 | def __init__(self, 329 | in_channels, 330 | out_channels, 331 | types='Pool', 332 | stride=(2, 1), 333 | sub_norm='nn.LayerNorm', 334 | act=None): 335 | super().__init__() 336 | self.types = types 337 | if types == 'Pool': 338 | self.avgpool = nn.AvgPool2d( 339 | kernel_size=(3, 5), stride=stride, padding=(1, 2)) 340 | self.maxpool = nn.MaxPool2d( 341 | kernel_size=(3, 5), stride=stride, padding=(1, 2)) 342 | self.proj = nn.Linear(in_channels, out_channels) 343 | else: 344 | self.conv = nn.Conv2d( 345 | in_channels, 346 | out_channels, 347 | kernel_size=3, 348 | stride=stride, 349 | padding=1, 350 | # weight_attr=ParamAttr(initializer=KaimingNormal()) 351 | ) 352 | self.norm = eval(sub_norm)(out_channels) 353 | if act is not None: 354 | self.act = act() 355 | else: 356 | self.act = None 357 | 358 | def forward(self, x): 359 | 360 | if self.types == 'Pool': 361 | x1 = self.avgpool(x) 362 | x2 = self.maxpool(x) 363 | x = (x1 + x2) * 0.5 364 | out = self.proj(x.flatten(2).permute((0, 2, 1))) 365 | else: 366 | x = self.conv(x) 367 | out = x.flatten(2).permute((0, 2, 1)) 368 | out = self.norm(out) 369 | if self.act is not None: 370 | out = self.act(out) 371 | 372 | return out 373 | 374 | 375 | class SVTRNet(nn.Module): 376 | def __init__( 377 | self, 378 | img_size=[48, 100], 379 | in_channels=3, 380 | embed_dim=[64, 128, 256], 381 | depth=[3, 6, 3], 382 | num_heads=[2, 4, 8], 383 | mixer=['Local'] * 6 + ['Global'] * 384 | 6, # Local atten, Global atten, Conv 385 | local_mixer=[[7, 11], [7, 11], [7, 11]], 386 | patch_merging='Conv', # Conv, Pool, None 387 | mlp_ratio=4, 388 | qkv_bias=True, 389 | qk_scale=None, 390 | drop_rate=0., 391 | last_drop=0.1, 392 | attn_drop_rate=0., 393 | drop_path_rate=0.1, 394 | norm_layer='nn.LayerNorm', 395 | sub_norm='nn.LayerNorm', 396 | epsilon=1e-6, 397 | out_channels=192, 398 | out_char_num=25, 399 | 
            block_unit='Block',
            act='nn.GELU',
            last_stage=True,
            sub_num=2,
            prenorm=True,
            use_lenhead=False,
            **kwargs):
        super().__init__()
        self.img_size = img_size
        self.embed_dim = embed_dim
        self.out_channels = out_channels
        self.prenorm = prenorm
        patch_merging = patch_merging if patch_merging in ('Conv', 'Pool') else None
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            in_channels=in_channels,
            embed_dim=embed_dim[0],
            sub_num=sub_num)
        num_patches = self.patch_embed.num_patches
        self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))
        # self.pos_embed = self.create_parameter(
        #     shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_)

        # self.add_parameter("pos_embed", self.pos_embed)

        self.pos_drop = nn.Dropout(p=drop_rate)
        Block_unit = eval(block_unit)

        dpr = np.linspace(0, drop_path_rate, sum(depth))
        self.blocks1 = nn.ModuleList([
            Block_unit(
                dim=embed_dim[0],
                num_heads=num_heads[0],
                mixer=mixer[0:depth[0]][i],
                HW=self.HW,
                local_mixer=local_mixer[0],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer=eval(act),
                attn_drop=attn_drop_rate,
                drop_path=dpr[0:depth[0]][i],
                norm_layer=norm_layer,
                epsilon=epsilon,
                prenorm=prenorm) for i in range(depth[0])
        ])
        if patch_merging is not None:
            self.sub_sample1 = SubSample(
                embed_dim[0],
                embed_dim[1],
                sub_norm=sub_norm,
                stride=[2, 1],
                types=patch_merging)
            HW = [self.HW[0] // 2, self.HW[1]]
        else:
            HW = self.HW
        self.patch_merging = patch_merging
        self.blocks2 = nn.ModuleList([
            Block_unit(
                dim=embed_dim[1],
                num_heads=num_heads[1],
                mixer=mixer[depth[0]:depth[0] + depth[1]][i],
                HW=HW,
                local_mixer=local_mixer[1],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer=eval(act),
                attn_drop=attn_drop_rate,
                drop_path=dpr[depth[0]:depth[0] + depth[1]][i],
                norm_layer=norm_layer,
                epsilon=epsilon,
                prenorm=prenorm) for i in range(depth[1])
        ])
        if patch_merging is not None:
            self.sub_sample2 = SubSample(
                embed_dim[1],
                embed_dim[2],
                sub_norm=sub_norm,
                stride=[2, 1],
                types=patch_merging)
            HW = [self.HW[0] // 4, self.HW[1]]
        else:
            HW = self.HW
        self.blocks3 = nn.ModuleList([
            Block_unit(
                dim=embed_dim[2],
                num_heads=num_heads[2],
                mixer=mixer[depth[0] + depth[1]:][i],
                HW=HW,
                local_mixer=local_mixer[2],
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer=eval(act),
                attn_drop=attn_drop_rate,
                drop_path=dpr[depth[0] + depth[1]:][i],
                norm_layer=norm_layer,
                epsilon=epsilon,
                prenorm=prenorm) for i in range(depth[2])
        ])
        self.last_stage = last_stage
        if last_stage:
            self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num))
            self.last_conv = nn.Conv2d(
                in_channels=embed_dim[2],
                out_channels=self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False)
            self.hardswish = nn.Hardswish()
            self.dropout = nn.Dropout(p=last_drop)
        if not prenorm:
            # torch.nn.LayerNorm takes `eps`; Paddle's `epsilon` keyword would raise a TypeError
            self.norm = eval(norm_layer)(embed_dim[-1], eps=epsilon)
        self.use_lenhead = use_lenhead
        if use_lenhead:
            self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
            self.hardswish_len = nn.Hardswish()
            self.dropout_len = nn.Dropout(p=last_drop)

        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

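    # SVTRNet stage-by-stage shapes (editor's note, not part of the original
    # code; assumes the default arguments img_size=[48, 100], sub_num=2,
    # patch_merging='Conv'):
    #   patch_embed -> 12 x 25 grid, 300 tokens of dim 64   (self.HW = [12, 25])
    #   sub_sample1 -> stride (2, 1): 6 x 25 grid, 150 tokens of dim 128
    #   sub_sample2 -> stride (2, 1): 3 x 25 grid, 75 tokens of dim 256
    #   avg_pool    -> (B, 256, 1, 25); last_conv (1x1) -> (B, 192, 1, 25)
    # A kernel-3, padding-1 conv with stride 2 maps a side of length n to
    # floor((n - 1) / 2) + 1 (48 -> 24 -> 12, 100 -> 50 -> 25), while stride 1
    # keeps the width at 25. For a (1, 3, 48, 100) input, the expected output
    # shape is therefore (1, 192, 1, 25).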
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks1:
            x = blk(x)
        if self.patch_merging is not None:
            x = self.sub_sample1(
                x.permute([0, 2, 1]).reshape(
                    [-1, self.embed_dim[0], self.HW[0], self.HW[1]]))
        for blk in self.blocks2:
            x = blk(x)
        if self.patch_merging is not None:
            x = self.sub_sample2(
                x.permute([0, 2, 1]).reshape(
                    [-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]))
        for blk in self.blocks3:
            x = blk(x)
        if not self.prenorm:
            x = self.norm(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        if self.use_lenhead:
            len_x = self.len_conv(x.mean(1))
            len_x = self.dropout_len(self.hardswish_len(len_x))
        if self.last_stage:
            if self.patch_merging is not None:
                h = self.HW[0] // 4
            else:
                h = self.HW[0]
            x = self.avg_pool(
                x.permute([0, 2, 1]).reshape(
                    [-1, self.embed_dim[2], h, self.HW[1]]))
            x = self.last_conv(x)
            x = self.hardswish(x)
            x = self.dropout(x)
        if self.use_lenhead:
            return x, len_x
        return x


if __name__ == "__main__":
    a = torch.rand(1, 3, 48, 100)
    svtr = SVTRNet()

    out = svtr(a)
    print(svtr)
    print(out.size())
--------------------------------------------------------------------------------
/onnx_infer.py:
--------------------------------------------------------------------------------
import os
import sys
import cv2
import time
import onnx
import math
import copy
import onnxruntime
import numpy as np
import pyclipper
from shapely.geometry import Polygon

# Image preprocessing classes required by the PaddleOCR detection module
class NormalizeImage(object):
    """ normalize image such as subtract mean, divide std
    """

    def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
        if isinstance(scale, str):
            scale = eval(scale)
        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
        mean = mean if mean is not None else [0.485, 0.456, 0.406]
        std = std if std is not None else [0.229, 0.224, 0.225]

        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
        self.mean = np.array(mean).reshape(shape).astype('float32')
        self.std = np.array(std).reshape(shape).astype('float32')

    def __call__(self, data):
        img = data['image']
        from PIL import Image
        if isinstance(img, Image.Image):
            img = np.array(img)

        assert isinstance(img,
                          np.ndarray), "invalid input 'img' in NormalizeImage"
        data['image'] = (
            img.astype('float32') * self.scale - self.mean) / self.std
        return data


class ToCHWImage(object):
    """ convert hwc image to chw image
    """

    def __init__(self, **kwargs):
        pass

    def __call__(self, data):
        img = data['image']
        from PIL import Image
        if isinstance(img, Image.Image):
            img = np.array(img)
        data['image'] = img.transpose((2, 0, 1))
        return data


class KeepKeys(object):
    def __init__(self, keep_keys, **kwargs):
        self.keep_keys = keep_keys

    def __call__(self, data):
        data_list = []
        for key in self.keep_keys:
            data_list.append(data[key])
        return data_list

class DetResizeForTest(object):
    def __init__(self, **kwargs):
        super(DetResizeForTest, self).__init__()
        self.resize_type = 0
        self.limit_side_len = kwargs['limit_side_len']
        # stored for config compatibility; resize_image_type0 always limits the longer side
        self.limit_type = kwargs.get('limit_type', 'min')

    def __call__(self, data):
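        # Editor's worked example (illustrative, not from the original code):
        # with limit_side_len=2500, a 500 x 640 (h x w) image is not shrunk
        # (ratio 1.0). Each side is then rounded to the nearest multiple of 32,
        # giving a 512 x 640 network input, and data['shape'] becomes
        # [500, 640, 1.024, 1.0], i.e. (src_h, src_w, ratio_h, ratio_w).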
        img = data['image']

        src_h, src_w, _ = img.shape
        img, [ratio_h, ratio_w] = self.resize_image_type0(img)

        data['image'] = img
        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
        return data

    def resize_image_type0(self, img):
        """
        resize image so that each side is a multiple of 32, as required by the network
        args:
            img(array): array with shape [h, w, c]
        return(tuple):
            img, (ratio_h, ratio_w)
        """
        limit_side_len = self.limit_side_len
        h, w, _ = img.shape

        # limit the max side
        if max(h, w) > limit_side_len:
            if h > w:
                ratio = float(limit_side_len) / h
            else:
                ratio = float(limit_side_len) / w
        else:
            ratio = 1.
        resize_h = int(h * ratio)
        resize_w = int(w * ratio)

        # round each side to the nearest multiple of 32
        resize_h = int(round(resize_h / 32) * 32)
        resize_w = int(round(resize_w / 32) * 32)

        try:
            if int(resize_w) <= 0 or int(resize_h) <= 0:
                return None, (None, None)
            img = cv2.resize(img, (int(resize_w), int(resize_h)))
        except Exception:
            print(img.shape, resize_w, resize_h)
            sys.exit(0)
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)
        # return img, np.array([h, w])
        return img, [ratio_h, ratio_w]

### Post-processing of detection results (produces the text boxes)
class DBPostProcess(object):
    """
    The post process for Differentiable Binarization (DB).
    """

    def __init__(self,
                 thresh=0.3,
                 box_thresh=0.7,
                 max_candidates=1000,
                 unclip_ratio=2.0,
                 use_dilation=False,
                 **kwargs):
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio
        self.min_size = 3
        self.dilation_kernel = None if not use_dilation else np.array(
            [[1, 1], [1, 1]])

    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        '''
        _bitmap: single map with shape (H, W),
                whose values are binarized as {0, 1}
        '''

        bitmap = _bitmap
        height, width = bitmap.shape
        # cv2.imshow("mask",(bitmap * 255).astype(np.uint8))
        outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
                                cv2.CHAIN_APPROX_SIMPLE)
        if len(outs) == 3:  # OpenCV 3.x returns (image, contours, hierarchy)
            img, contours, _ = outs[0], outs[1], outs[2]
        elif len(outs) == 2:  # OpenCV 4.x returns (contours, hierarchy)
            contours, _ = outs[0], outs[1]

        num_contours = min(len(contours), self.max_candidates)

        boxes = []
        scores = []
        for index in range(num_contours):
            contour = contours[index]
            points, sside = self.get_mini_boxes(contour)
            if sside < self.min_size:
                continue
            points = np.array(points)
            score = self.box_score_fast(pred, points.reshape(-1, 2))
            if self.box_thresh > score:
                continue

            box = self.unclip(points).reshape(-1, 1, 2)
            box, sside = self.get_mini_boxes(box)
            if sside < self.min_size + 2:
                continue
            box = np.array(box)

            box[:, 0] = np.clip(
                np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height)
            boxes.append(box.astype(np.int16))
            scores.append(score)
        return np.array(boxes, dtype=np.int16), scores

    def unclip(self, box):
        unclip_ratio = self.unclip_ratio
        poly = Polygon(box)
        distance = poly.area * unclip_ratio / poly.length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        expanded = np.array(offset.Execute(distance))
        return expanded

    def get_mini_boxes(self, contour):
        bounding_box = cv2.minAreaRect(contour)
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        index_1, index_2, index_3, index_4 = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_1 = 0
            index_4 = 1
        else:
            index_1 = 1
            index_4 = 0
        if points[3][1] > points[2][1]:
            index_2 = 2
            index_3 = 3
        else:
            index_2 = 3
            index_3 = 2

        box = [
            points[index_1], points[index_2], points[index_3], points[index_4]
        ]
        return box, min(bounding_box[1])

    def box_score_fast(self, bitmap, _box):
        h, w = bitmap.shape[:2]
        box = _box.copy()
        # np.int was removed in NumPy 1.24, so use an explicit integer dtype
        xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1)
        xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1)
        ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1)
        ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    def __call__(self, outs_dict, shape_list):
        pred = outs_dict
        pred = pred[:, 0, :, :]
        segmentation = pred > self.thresh

        boxes_batch = []
        for batch_index in range(pred.shape[0]):
            src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]  # source size and resize ratios
            if self.dilation_kernel is not None:
                mask = cv2.dilate(
                    np.array(segmentation[batch_index]).astype(np.uint8),
                    self.dilation_kernel)
            else:
                mask = segmentation[batch_index]
            boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
                                                   src_w, src_h)
            boxes_batch.append({'points': boxes})
        return boxes_batch

## Decode the recognition results from the model output
class process_pred(object):
    def __init__(self, character_dict_path=None, character_type='ch', use_space_char=False):
        self.character_str = ''
        with open(character_dict_path, 'rb') as fin:
            lines = fin.readlines()
            for line in lines:
                line = line.decode('utf-8').strip('\n').strip('\r\n')
                self.character_str += line
        if use_space_char:
            self.character_str += ' '
        dict_character = list(self.character_str)

        dict_character = self.add_special_char(dict_character)
        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i
        self.character = dict_character

    def add_special_char(self, dict_character):
        # index 0 is reserved for the CTC blank
        dict_character = ['blank'] + dict_character
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        result_list = []
        ignored_tokens = [0]
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] in ignored_tokens:
                    continue
                if is_remove_duplicate:
                    if idx > 0 and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]:
                        continue
                char_list.append(self.character[int(text_index[batch_idx][idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            result_list.append((text, np.mean(conf_list)))
        return result_list

    def __call__(self, preds, label=None):
        if not isinstance(preds, np.ndarray):
            preds = np.array(preds)
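        # Greedy CTC decoding (editor's illustration, not from the original code):
        # argmax picks the most likely class at each time step. decode() with
        # is_remove_duplicate=True then collapses consecutive repeats and drops
        # the blank (index 0). For example, with per-step indices
        #   [0, 5, 5, 0, 5, 7, 7]  ->  kept indices [5, 5, 7]
        # The blank between the two runs of 5 is what lets a doubled character
        # survive. The reported confidence is the mean of the kept steps' max
        # probabilities.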
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
        if label is None:
            return text
        label = self.decode(label)
        return text, label


class det_rec_functions(object):

    def __init__(self, image, det_path, rec_path, keys_path):
        self.img = image.copy()
        self.det_file = det_path
        self.small_rec_file = rec_path
        self.onet_det_session = onnxruntime.InferenceSession(self.det_file)
        self.onet_rec_session = onnxruntime.InferenceSession(self.small_rec_file)
        self.infer_before_process_op, self.det_re_process_op = self.get_process()
        self.postprocess_op = process_pred(keys_path, 'ch', use_space_char=True)

    ## Image preprocessing pipeline
    def transform(self, data, ops=None):
        """ transform """
        if ops is None:
            ops = []
        for op in ops:
            data = op(data)
            if data is None:
                return None
        return data

    def create_operators(self, op_param_list, global_config=None):
        """
        create operators based on the config

        Args:
            params(list): a dict list, used to create some operators
        """
        assert isinstance(op_param_list, list), ('operator config should be a list')
        ops = []
        for operator in op_param_list:
            assert isinstance(operator,
                              dict) and len(operator) == 1, "yaml format error"
            op_name = list(operator)[0]
            param = {} if operator[op_name] is None else operator[op_name]
            if global_config is not None:
                param.update(global_config)
            op = eval(op_name)(**param)
            ops.append(op)
        return ops

    ### Post-processing of the detection boxes
    def order_points_clockwise(self, pts):
        """
        reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
        # sort the points based on their x-coordinates
        """
        xSorted = pts[np.argsort(pts[:, 0]), :]

        # grab the left-most and right-most points from the sorted
        # x-coordinate points
        leftMost = xSorted[:2, :]
        rightMost = xSorted[2:, :]

        # now, sort the left-most coordinates according to their
        # y-coordinates so we can grab the top-left and bottom-left
        # points, respectively
        leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
        (tl, bl) = leftMost

        rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
        (tr, br) = rightMost

        rect = np.array([tl, tr, br, bl], dtype="float32")
        return rect

    def clip_det_res(self, points, img_height, img_width):
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    ### Define the image preprocessing and the detection post-processing
    def get_process(self):
        det_db_thresh = 0.3
        det_db_box_thresh = 0.4
        max_candidates = 2000
        unclip_ratio = 1.6
        use_dilation = True

        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': 2500,
                'limit_type': 'max'
            }
        }, {
            'NormalizeImage': {
                'std': [0.5, 0.5, 0.5],
                'mean': [0.5, 0.5, 0.5],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]

        infer_before_process_op = self.create_operators(pre_process_list)
        det_re_process_op = DBPostProcess(det_db_thresh, det_db_box_thresh, max_candidates, unclip_ratio, use_dilation)
        return infer_before_process_op, det_re_process_op

    def sorted_boxes(self, dt_boxes):
        """
        Sort text boxes in order from top to bottom, left to right
        args:
            dt_boxes(array): detected text boxes with shape [N, 4, 2]
        return:
            sorted boxes (list of N arrays, each with shape [4, 2])
        """
        num_boxes = dt_boxes.shape[0]
        sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
        _boxes = list(sorted_boxes)

        for i in range(num_boxes - 1):
            if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
                    (_boxes[i + 1][0][0] < _boxes[i][0][0]):
                tmp = _boxes[i]
                _boxes[i] = _boxes[i + 1]
                _boxes[i + 1] = tmp
        return _boxes

    ### Preprocess a text crop for the recognition model
    def resize_norm_img(self, img, max_wh_ratio):
        imgC, imgH, imgW = 3, 48, 100  # imgW is recomputed from max_wh_ratio below
        assert imgC == img.shape[2]
        imgW = int((imgH * max_wh_ratio))
        h, w = img.shape[:2]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    ## Run the detection model on the image
    def get_boxes(self):
        img_ori = self.img
        img_part = img_ori.copy()
        data_part = {'image': img_part}
        data_part = self.transform(data_part, self.infer_before_process_op)
        img_part, shape_part_list = data_part
        img_part = np.expand_dims(img_part, axis=0)
        shape_part_list = np.expand_dims(shape_part_list, axis=0)
        inputs_part = {self.onet_det_session.get_inputs()[0].name: img_part}
        outs_part = self.onet_det_session.run(None, inputs_part)

        post_res_part = self.det_re_process_op(outs_part[0], shape_part_list)
        dt_boxes_part = post_res_part[0]['points']
        dt_boxes_part = self.filter_tag_det_res(dt_boxes_part, img_ori.shape)
        dt_boxes_part = self.sorted_boxes(dt_boxes_part)
        return dt_boxes_part

    ### Crop (and rectify) the image region enclosed by a bounding box
    def get_rotate_crop_image(self, img, points):
        img_crop_width = int(
            max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3])))
        img_crop_height = int(
            max(
                np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2])))
        pts_std = np.float32([[0, 0], [img_crop_width, 0],
                              [img_crop_width, img_crop_height],
                              [0, img_crop_height]])
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M, (img_crop_width, img_crop_height),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC)
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img

    ### Run recognition on a single cropped image
    def get_img_res(self, onnx_model, img, process_op):
        h, w = img.shape[:2]
        img = self.resize_norm_img(img, w * 1.0 / h)
        img = img[np.newaxis, :]
        inputs = {onnx_model.get_inputs()[0].name: img}
        outs = onnx_model.run(None, inputs)
        result = process_op(outs[0])
        return result

    def recognition_img(self, dt_boxes):
        img_ori = self.img
        img = img_ori.copy()
        ### Recognition
        ## Crop the text regions given by the detected boxes
        img_list = []
        for box in dt_boxes:
            tmp_box = copy.deepcopy(box)
            cv2.rectangle(img_ori, (int(box[0][0]), int(box[0][1])), (int(box[2][0]), int(box[2][1])), (0, 0, 255), 1)
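            # Editor's illustration (not from the original code): for an
            # axis-aligned box tl=(10, 10), tr=(110, 10), br=(110, 40),
            # bl=(10, 40), get_rotate_crop_image yields a 100 x 30 (w x h)
            # crop. A crop whose height/width ratio is >= 1.5 (e.g. 30 wide,
            # 60 tall) is rotated 90 degrees counterclockwise by np.rot90
            # before recognition.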
            img_crop = self.get_rotate_crop_image(img, tmp_box)
            img_list.append(img_crop)

        ## Recognize each cropped image
        results = []
        results_info = []
        for pic in img_list:
            res = self.get_img_res(self.onet_rec_session, pic, self.postprocess_op)
            results.append(res[0])
            results_info.append(res)
        return results, results_info, img_ori

    def draw_bbox(self, img_path, result, color=(255, 0, 0), thickness=2):
        import cv2
        if isinstance(img_path, str):
            img_path = cv2.imread(img_path)
            # img_path = cv2.cvtColor(img_path, cv2.COLOR_BGR2RGB)
        img_path = img_path.copy()
        for point in result:
            point = point.astype(int)
            cv2.polylines(img_path, [point], True, color, thickness)
        return img_path


if __name__ == '__main__':
    import os

    img_path = "det_images"  # folder of test images
    det_file = './weights/ppv3_db.onnx'  # DBNet detection model; must be exported with dynamic input shapes
    rec_file = './weights/ppv3_rec.onnx'  # recognition model; must be exported with dynamic input shapes
    keys_file = './weights/ppocr_keys_v1.txt'  # character dictionary for recognition

    for name in os.listdir(img_path):
        time1 = time.time()
        image = cv2.imread(os.path.join(img_path, name))
        resize_ratio = min((640 / image.shape[0]), (640 / image.shape[1]))
        image = cv2.resize(image, (0, 0), fx=resize_ratio, fy=resize_ratio, interpolation=cv2.INTER_LINEAR)
        # OCR: detection followed by recognition
        ocr_sys = det_rec_functions(image, det_file, rec_file, keys_file)
        # get the detection boxes
        dt_boxes = ocr_sys.get_boxes()
        # recognize; results: one (text, confidence) tuple per box,
        # results_info: the same tuple wrapped in a one-element list
        results, results_info, img_draw = ocr_sys.recognition_img(dt_boxes)

        time2 = time.time()
        print("Image: {}  recognition time: {:.3f} s".format(name, time2 - time1))
        for txt in results:
            print(txt)
        print()
        cv2.imshow("img_draw", img_draw)
        cv2.waitKey()
--------------------------------------------------------------------------------
/weights/ppocr_keys_v1.txt:
-------------------------------------------------------------------------------- 1 | ' 2 | 疗 3 | 绚 4 | 诚 5 | 娇 6 | 溜 7 | 题 8 | 贿 9 | 者 10 | 廖 11 | 更 12 | 纳 13 | 加 14 | 奉 15 | 公 16 | 一 17 | 就 18 | 汴 19 | 计 20 | 与 21 | 路 22 | 房 23 | 原 24 | 妇 25 | 2 26 | 0 27 | 8 28 | - 29 | 7 30 | 其 31 | > 32 | : 33 | ] 34 | , 35 | , 36 | 骑 37 | 刈 38 | 全 39 | 消 40 | 昏 41 | 傈 42 | 安 43 | 久 44 | 钟 45 | 嗅 46 | 不 47 | 影 48 | 处 49 | 驽 50 | 蜿 51 | 资 52 | 关 53 | 椤 54 | 地 55 | 瘸 56 | 专 57 | 问 58 | 忖 59 | 票 60 | 嫉 61 | 炎 62 | 韵 63 | 要 64 | 月 65 | 田 66 | 节 67 | 陂 68 | 鄙 69 | 捌 70 | 备 71 | 拳 72 | 伺 73 | 眼 74 | 网 75 | 盎 76 | 大 77 | 傍 78 | 心 79 | 东 80 | 愉 81 | 汇 82 | 蹿 83 | 科 84 | 每 85 | 业 86 | 里 87 | 航 88 | 晏 89 | 字 90 | 平 91 | 录 92 | 先 93 | 1 94 | 3 95 | 彤 96 | 鲶 97 | 产 98 | 稍 99 | 督 100 | 腴 101 | 有 102 | 象 103 | 岳 104 | 注 105 | 绍 106 | 在 107 | 泺 108 | 文 109 | 定 110 | 核 111 | 名 112 | 水 113 | 过 114 | 理 115 | 让 116 | 偷 117 | 率 118 | 等 119 | 这 120 | 发 121 | ” 122 | 为 123 | 含 124 | 肥 125 | 酉 126 | 相 127 | 鄱 128 | 七 129 | 编 130 | 猥 131 | 锛 132 | 日 133 | 镀 134 | 蒂 135 | 掰 136 | 倒 137 | 辆 138 | 栾 139 | 栗 140 | 综 141 | 涩 142 | 州 143 | 雌 144 | 滑 145 | 馀 146 | 了 147 | 机 148 | 块 149 | 司 150 | 宰 151 | 甙 152 | 兴 153 | 矽 154 | 抚 155 | 保 156 | 用 157 | 沧 158 | 秩 159 | 如 160 | 收 161 | 息 162 | 滥 163 | 页 164 | 疑 165 | 埠 166 | ! 167 | ! 
168 | 姥 169 | 异 170 | 橹 171 | 钇 172 | 向 173 | 下 174 | 跄 175 | 的 176 | 椴 177 | 沫 178 | 国 179 | 绥 180 | 獠 181 | 报 182 | 开 183 | 民 184 | 蜇 185 | 何 186 | 分 187 | 凇 188 | 长 189 | 讥 190 | 藏 191 | 掏 192 | 施 193 | 羽 194 | 中 195 | 讲 196 | 派 197 | 嘟 198 | 人 199 | 提 200 | 浼 201 | 间 202 | 世 203 | 而 204 | 古 205 | 多 206 | 倪 207 | 唇 208 | 饯 209 | 控 210 | 庚 211 | 首 212 | 赛 213 | 蜓 214 | 味 215 | 断 216 | 制 217 | 觉 218 | 技 219 | 替 220 | 艰 221 | 溢 222 | 潮 223 | 夕 224 | 钺 225 | 外 226 | 摘 227 | 枋 228 | 动 229 | 双 230 | 单 231 | 啮 232 | 户 233 | 枇 234 | 确 235 | 锦 236 | 曜 237 | 杜 238 | 或 239 | 能 240 | 效 241 | 霜 242 | 盒 243 | 然 244 | 侗 245 | 电 246 | 晁 247 | 放 248 | 步 249 | 鹃 250 | 新 251 | 杖 252 | 蜂 253 | 吒 254 | 濂 255 | 瞬 256 | 评 257 | 总 258 | 隍 259 | 对 260 | 独 261 | 合 262 | 也 263 | 是 264 | 府 265 | 青 266 | 天 267 | 诲 268 | 墙 269 | 组 270 | 滴 271 | 级 272 | 邀 273 | 帘 274 | 示 275 | 已 276 | 时 277 | 骸 278 | 仄 279 | 泅 280 | 和 281 | 遨 282 | 店 283 | 雇 284 | 疫 285 | 持 286 | 巍 287 | 踮 288 | 境 289 | 只 290 | 亨 291 | 目 292 | 鉴 293 | 崤 294 | 闲 295 | 体 296 | 泄 297 | 杂 298 | 作 299 | 般 300 | 轰 301 | 化 302 | 解 303 | 迂 304 | 诿 305 | 蛭 306 | 璀 307 | 腾 308 | 告 309 | 版 310 | 服 311 | 省 312 | 师 313 | 小 314 | 规 315 | 程 316 | 线 317 | 海 318 | 办 319 | 引 320 | 二 321 | 桧 322 | 牌 323 | 砺 324 | 洄 325 | 裴 326 | 修 327 | 图 328 | 痫 329 | 胡 330 | 许 331 | 犊 332 | 事 333 | 郛 334 | 基 335 | 柴 336 | 呼 337 | 食 338 | 研 339 | 奶 340 | 律 341 | 蛋 342 | 因 343 | 葆 344 | 察 345 | 戏 346 | 褒 347 | 戒 348 | 再 349 | 李 350 | 骁 351 | 工 352 | 貂 353 | 油 354 | 鹅 355 | 章 356 | 啄 357 | 休 358 | 场 359 | 给 360 | 睡 361 | 纷 362 | 豆 363 | 器 364 | 捎 365 | 说 366 | 敏 367 | 学 368 | 会 369 | 浒 370 | 设 371 | 诊 372 | 格 373 | 廓 374 | 查 375 | 来 376 | 霓 377 | 室 378 | 溆 379 | ¢ 380 | 诡 381 | 寥 382 | 焕 383 | 舜 384 | 柒 385 | 狐 386 | 回 387 | 戟 388 | 砾 389 | 厄 390 | 实 391 | 翩 392 | 尿 393 | 五 394 | 入 395 | 径 396 | 惭 397 | 喹 398 | 股 399 | 宇 400 | 篝 401 | | 402 | ; 403 | 美 404 | 期 405 | 云 406 | 九 407 | 祺 408 | 扮 409 | 靠 410 | 锝 411 | 槌 412 | 系 413 | 企 414 | 酰 415 | 阊 416 | 暂 417 | 蚕 
418 | 忻 419 | 豁 420 | 本 421 | 羹 422 | 执 423 | 条 424 | 钦 425 | H 426 | 獒 427 | 限 428 | 进 429 | 季 430 | 楦 431 | 于 432 | 芘 433 | 玖 434 | 铋 435 | 茯 436 | 未 437 | 答 438 | 粘 439 | 括 440 | 样 441 | 精 442 | 欠 443 | 矢 444 | 甥 445 | 帷 446 | 嵩 447 | 扣 448 | 令 449 | 仔 450 | 风 451 | 皈 452 | 行 453 | 支 454 | 部 455 | 蓉 456 | 刮 457 | 站 458 | 蜡 459 | 救 460 | 钊 461 | 汗 462 | 松 463 | 嫌 464 | 成 465 | 可 466 | . 467 | 鹤 468 | 院 469 | 从 470 | 交 471 | 政 472 | 怕 473 | 活 474 | 调 475 | 球 476 | 局 477 | 验 478 | 髌 479 | 第 480 | 韫 481 | 谗 482 | 串 483 | 到 484 | 圆 485 | 年 486 | 米 487 | / 488 | * 489 | 友 490 | 忿 491 | 检 492 | 区 493 | 看 494 | 自 495 | 敢 496 | 刃 497 | 个 498 | 兹 499 | 弄 500 | 流 501 | 留 502 | 同 503 | 没 504 | 齿 505 | 星 506 | 聆 507 | 轼 508 | 湖 509 | 什 510 | 三 511 | 建 512 | 蛔 513 | 儿 514 | 椋 515 | 汕 516 | 震 517 | 颧 518 | 鲤 519 | 跟 520 | 力 521 | 情 522 | 璺 523 | 铨 524 | 陪 525 | 务 526 | 指 527 | 族 528 | 训 529 | 滦 530 | 鄣 531 | 濮 532 | 扒 533 | 商 534 | 箱 535 | 十 536 | 召 537 | 慷 538 | 辗 539 | 所 540 | 莞 541 | 管 542 | 护 543 | 臭 544 | 横 545 | 硒 546 | 嗓 547 | 接 548 | 侦 549 | 六 550 | 露 551 | 党 552 | 馋 553 | 驾 554 | 剖 555 | 高 556 | 侬 557 | 妪 558 | 幂 559 | 猗 560 | 绺 561 | 骐 562 | 央 563 | 酐 564 | 孝 565 | 筝 566 | 课 567 | 徇 568 | 缰 569 | 门 570 | 男 571 | 西 572 | 项 573 | 句 574 | 谙 575 | 瞒 576 | 秃 577 | 篇 578 | 教 579 | 碲 580 | 罚 581 | 声 582 | 呐 583 | 景 584 | 前 585 | 富 586 | 嘴 587 | 鳌 588 | 稀 589 | 免 590 | 朋 591 | 啬 592 | 睐 593 | 去 594 | 赈 595 | 鱼 596 | 住 597 | 肩 598 | 愕 599 | 速 600 | 旁 601 | 波 602 | 厅 603 | 健 604 | 茼 605 | 厥 606 | 鲟 607 | 谅 608 | 投 609 | 攸 610 | 炔 611 | 数 612 | 方 613 | 击 614 | 呋 615 | 谈 616 | 绩 617 | 别 618 | 愫 619 | 僚 620 | 躬 621 | 鹧 622 | 胪 623 | 炳 624 | 招 625 | 喇 626 | 膨 627 | 泵 628 | 蹦 629 | 毛 630 | 结 631 | 5 632 | 4 633 | 谱 634 | 识 635 | 陕 636 | 粽 637 | 婚 638 | 拟 639 | 构 640 | 且 641 | 搜 642 | 任 643 | 潘 644 | 比 645 | 郢 646 | 妨 647 | 醪 648 | 陀 649 | 桔 650 | 碘 651 | 扎 652 | 选 653 | 哈 654 | 骷 655 | 楷 656 | 亿 657 | 明 658 | 缆 659 | 脯 660 | 监 661 | 睫 662 | 逻 663 | 婵 664 | 共 665 | 赴 666 | 淝 667 | 凡 
668 | 惦 669 | 及 670 | 达 671 | 揖 672 | 谩 673 | 澹 674 | 减 675 | 焰 676 | 蛹 677 | 番 678 | 祁 679 | 柏 680 | 员 681 | 禄 682 | 怡 683 | 峤 684 | 龙 685 | 白 686 | 叽 687 | 生 688 | 闯 689 | 起 690 | 细 691 | 装 692 | 谕 693 | 竟 694 | 聚 695 | 钙 696 | 上 697 | 导 698 | 渊 699 | 按 700 | 艾 701 | 辘 702 | 挡 703 | 耒 704 | 盹 705 | 饪 706 | 臀 707 | 记 708 | 邮 709 | 蕙 710 | 受 711 | 各 712 | 医 713 | 搂 714 | 普 715 | 滇 716 | 朗 717 | 茸 718 | 带 719 | 翻 720 | 酚 721 | ( 722 | 光 723 | 堤 724 | 墟 725 | 蔷 726 | 万 727 | 幻 728 | 〓 729 | 瑙 730 | 辈 731 | 昧 732 | 盏 733 | 亘 734 | 蛀 735 | 吉 736 | 铰 737 | 请 738 | 子 739 | 假 740 | 闻 741 | 税 742 | 井 743 | 诩 744 | 哨 745 | 嫂 746 | 好 747 | 面 748 | 琐 749 | 校 750 | 馊 751 | 鬣 752 | 缂 753 | 营 754 | 访 755 | 炖 756 | 占 757 | 农 758 | 缀 759 | 否 760 | 经 761 | 钚 762 | 棵 763 | 趟 764 | 张 765 | 亟 766 | 吏 767 | 茶 768 | 谨 769 | 捻 770 | 论 771 | 迸 772 | 堂 773 | 玉 774 | 信 775 | 吧 776 | 瞠 777 | 乡 778 | 姬 779 | 寺 780 | 咬 781 | 溏 782 | 苄 783 | 皿 784 | 意 785 | 赉 786 | 宝 787 | 尔 788 | 钰 789 | 艺 790 | 特 791 | 唳 792 | 踉 793 | 都 794 | 荣 795 | 倚 796 | 登 797 | 荐 798 | 丧 799 | 奇 800 | 涵 801 | 批 802 | 炭 803 | 近 804 | 符 805 | 傩 806 | 感 807 | 道 808 | 着 809 | 菊 810 | 虹 811 | 仲 812 | 众 813 | 懈 814 | 濯 815 | 颞 816 | 眺 817 | 南 818 | 释 819 | 北 820 | 缝 821 | 标 822 | 既 823 | 茗 824 | 整 825 | 撼 826 | 迤 827 | 贲 828 | 挎 829 | 耱 830 | 拒 831 | 某 832 | 妍 833 | 卫 834 | 哇 835 | 英 836 | 矶 837 | 藩 838 | 治 839 | 他 840 | 元 841 | 领 842 | 膜 843 | 遮 844 | 穗 845 | 蛾 846 | 飞 847 | 荒 848 | 棺 849 | 劫 850 | 么 851 | 市 852 | 火 853 | 温 854 | 拈 855 | 棚 856 | 洼 857 | 转 858 | 果 859 | 奕 860 | 卸 861 | 迪 862 | 伸 863 | 泳 864 | 斗 865 | 邡 866 | 侄 867 | 涨 868 | 屯 869 | 萋 870 | 胭 871 | 氡 872 | 崮 873 | 枞 874 | 惧 875 | 冒 876 | 彩 877 | 斜 878 | 手 879 | 豚 880 | 随 881 | 旭 882 | 淑 883 | 妞 884 | 形 885 | 菌 886 | 吲 887 | 沱 888 | 争 889 | 驯 890 | 歹 891 | 挟 892 | 兆 893 | 柱 894 | 传 895 | 至 896 | 包 897 | 内 898 | 响 899 | 临 900 | 红 901 | 功 902 | 弩 903 | 衡 904 | 寂 905 | 禁 906 | 老 907 | 棍 908 | 耆 909 | 渍 910 | 织 911 | 害 912 | 氵 913 | 渑 914 | 布 915 | 载 916 | 靥 917 | 嗬 
918 | 虽 919 | 苹 920 | 咨 921 | 娄 922 | 库 923 | 雉 924 | 榜 925 | 帜 926 | 嘲 927 | 套 928 | 瑚 929 | 亲 930 | 簸 931 | 欧 932 | 边 933 | 6 934 | 腿 935 | 旮 936 | 抛 937 | 吹 938 | 瞳 939 | 得 940 | 镓 941 | 梗 942 | 厨 943 | 继 944 | 漾 945 | 愣 946 | 憨 947 | 士 948 | 策 949 | 窑 950 | 抑 951 | 躯 952 | 襟 953 | 脏 954 | 参 955 | 贸 956 | 言 957 | 干 958 | 绸 959 | 鳄 960 | 穷 961 | 藜 962 | 音 963 | 折 964 | 详 965 | ) 966 | 举 967 | 悍 968 | 甸 969 | 癌 970 | 黎 971 | 谴 972 | 死 973 | 罩 974 | 迁 975 | 寒 976 | 驷 977 | 袖 978 | 媒 979 | 蒋 980 | 掘 981 | 模 982 | 纠 983 | 恣 984 | 观 985 | 祖 986 | 蛆 987 | 碍 988 | 位 989 | 稿 990 | 主 991 | 澧 992 | 跌 993 | 筏 994 | 京 995 | 锏 996 | 帝 997 | 贴 998 | 证 999 | 糠 1000 | 才 1001 | 黄 1002 | 鲸 1003 | 略 1004 | 炯 1005 | 饱 1006 | 四 1007 | 出 1008 | 园 1009 | 犀 1010 | 牧 1011 | 容 1012 | 汉 1013 | 杆 1014 | 浈 1015 | 汰 1016 | 瑷 1017 | 造 1018 | 虫 1019 | 瘩 1020 | 怪 1021 | 驴 1022 | 济 1023 | 应 1024 | 花 1025 | 沣 1026 | 谔 1027 | 夙 1028 | 旅 1029 | 价 1030 | 矿 1031 | 以 1032 | 考 1033 | s 1034 | u 1035 | 呦 1036 | 晒 1037 | 巡 1038 | 茅 1039 | 准 1040 | 肟 1041 | 瓴 1042 | 詹 1043 | 仟 1044 | 褂 1045 | 译 1046 | 桌 1047 | 混 1048 | 宁 1049 | 怦 1050 | 郑 1051 | 抿 1052 | 些 1053 | 余 1054 | 鄂 1055 | 饴 1056 | 攒 1057 | 珑 1058 | 群 1059 | 阖 1060 | 岔 1061 | 琨 1062 | 藓 1063 | 预 1064 | 环 1065 | 洮 1066 | 岌 1067 | 宀 1068 | 杲 1069 | 瀵 1070 | 最 1071 | 常 1072 | 囡 1073 | 周 1074 | 踊 1075 | 女 1076 | 鼓 1077 | 袭 1078 | 喉 1079 | 简 1080 | 范 1081 | 薯 1082 | 遐 1083 | 疏 1084 | 粱 1085 | 黜 1086 | 禧 1087 | 法 1088 | 箔 1089 | 斤 1090 | 遥 1091 | 汝 1092 | 奥 1093 | 直 1094 | 贞 1095 | 撑 1096 | 置 1097 | 绱 1098 | 集 1099 | 她 1100 | 馅 1101 | 逗 1102 | 钧 1103 | 橱 1104 | 魉 1105 | [ 1106 | 恙 1107 | 躁 1108 | 唤 1109 | 9 1110 | 旺 1111 | 膘 1112 | 待 1113 | 脾 1114 | 惫 1115 | 购 1116 | 吗 1117 | 依 1118 | 盲 1119 | 度 1120 | 瘿 1121 | 蠖 1122 | 俾 1123 | 之 1124 | 镗 1125 | 拇 1126 | 鲵 1127 | 厝 1128 | 簧 1129 | 续 1130 | 款 1131 | 展 1132 | 啃 1133 | 表 1134 | 剔 1135 | 品 1136 | 钻 1137 | 腭 1138 | 损 1139 | 清 1140 | 锶 1141 | 统 1142 | 涌 1143 | 寸 1144 | 滨 1145 | 贪 1146 | 链 1147 | 吠 1148 | 冈 
1149 | 伎 1150 | 迥 1151 | 咏 1152 | 吁 1153 | 览 1154 | 防 1155 | 迅 1156 | 失 1157 | 汾 1158 | 阔 1159 | 逵 1160 | 绀 1161 | 蔑 1162 | 列 1163 | 川 1164 | 凭 1165 | 努 1166 | 熨 1167 | 揪 1168 | 利 1169 | 俱 1170 | 绉 1171 | 抢 1172 | 鸨 1173 | 我 1174 | 即 1175 | 责 1176 | 膦 1177 | 易 1178 | 毓 1179 | 鹊 1180 | 刹 1181 | 玷 1182 | 岿 1183 | 空 1184 | 嘞 1185 | 绊 1186 | 排 1187 | 术 1188 | 估 1189 | 锷 1190 | 违 1191 | 们 1192 | 苟 1193 | 铜 1194 | 播 1195 | 肘 1196 | 件 1197 | 烫 1198 | 审 1199 | 鲂 1200 | 广 1201 | 像 1202 | 铌 1203 | 惰 1204 | 铟 1205 | 巳 1206 | 胍 1207 | 鲍 1208 | 康 1209 | 憧 1210 | 色 1211 | 恢 1212 | 想 1213 | 拷 1214 | 尤 1215 | 疳 1216 | 知 1217 | S 1218 | Y 1219 | F 1220 | D 1221 | A 1222 | 峄 1223 | 裕 1224 | 帮 1225 | 握 1226 | 搔 1227 | 氐 1228 | 氘 1229 | 难 1230 | 墒 1231 | 沮 1232 | 雨 1233 | 叁 1234 | 缥 1235 | 悴 1236 | 藐 1237 | 湫 1238 | 娟 1239 | 苑 1240 | 稠 1241 | 颛 1242 | 簇 1243 | 后 1244 | 阕 1245 | 闭 1246 | 蕤 1247 | 缚 1248 | 怎 1249 | 佞 1250 | 码 1251 | 嘤 1252 | 蔡 1253 | 痊 1254 | 舱 1255 | 螯 1256 | 帕 1257 | 赫 1258 | 昵 1259 | 升 1260 | 烬 1261 | 岫 1262 | 、 1263 | 疵 1264 | 蜻 1265 | 髁 1266 | 蕨 1267 | 隶 1268 | 烛 1269 | 械 1270 | 丑 1271 | 盂 1272 | 梁 1273 | 强 1274 | 鲛 1275 | 由 1276 | 拘 1277 | 揉 1278 | 劭 1279 | 龟 1280 | 撤 1281 | 钩 1282 | 呕 1283 | 孛 1284 | 费 1285 | 妻 1286 | 漂 1287 | 求 1288 | 阑 1289 | 崖 1290 | 秤 1291 | 甘 1292 | 通 1293 | 深 1294 | 补 1295 | 赃 1296 | 坎 1297 | 床 1298 | 啪 1299 | 承 1300 | 吼 1301 | 量 1302 | 暇 1303 | 钼 1304 | 烨 1305 | 阂 1306 | 擎 1307 | 脱 1308 | 逮 1309 | 称 1310 | P 1311 | 神 1312 | 属 1313 | 矗 1314 | 华 1315 | 届 1316 | 狍 1317 | 葑 1318 | 汹 1319 | 育 1320 | 患 1321 | 窒 1322 | 蛰 1323 | 佼 1324 | 静 1325 | 槎 1326 | 运 1327 | 鳗 1328 | 庆 1329 | 逝 1330 | 曼 1331 | 疱 1332 | 克 1333 | 代 1334 | 官 1335 | 此 1336 | 麸 1337 | 耧 1338 | 蚌 1339 | 晟 1340 | 例 1341 | 础 1342 | 榛 1343 | 副 1344 | 测 1345 | 唰 1346 | 缢 1347 | 迹 1348 | 灬 1349 | 霁 1350 | 身 1351 | 岁 1352 | 赭 1353 | 扛 1354 | 又 1355 | 菡 1356 | 乜 1357 | 雾 1358 | 板 1359 | 读 1360 | 陷 1361 | 徉 1362 | 贯 1363 | 郁 1364 | 虑 1365 | 变 1366 | 钓 1367 | 菜 1368 | 圾 1369 | 现 1370 | 琢 
1371 | 式 1372 | 乐 1373 | 维 1374 | 渔 1375 | 浜 1376 | 左 1377 | 吾 1378 | 脑 1379 | 钡 1380 | 警 1381 | T 1382 | 啵 1383 | 拴 1384 | 偌 1385 | 漱 1386 | 湿 1387 | 硕 1388 | 止 1389 | 骼 1390 | 魄 1391 | 积 1392 | 燥 1393 | 联 1394 | 踢 1395 | 玛 1396 | 则 1397 | 窿 1398 | 见 1399 | 振 1400 | 畿 1401 | 送 1402 | 班 1403 | 钽 1404 | 您 1405 | 赵 1406 | 刨 1407 | 印 1408 | 讨 1409 | 踝 1410 | 籍 1411 | 谡 1412 | 舌 1413 | 崧 1414 | 汽 1415 | 蔽 1416 | 沪 1417 | 酥 1418 | 绒 1419 | 怖 1420 | 财 1421 | 帖 1422 | 肱 1423 | 私 1424 | 莎 1425 | 勋 1426 | 羔 1427 | 霸 1428 | 励 1429 | 哼 1430 | 帐 1431 | 将 1432 | 帅 1433 | 渠 1434 | 纪 1435 | 婴 1436 | 娩 1437 | 岭 1438 | 厘 1439 | 滕 1440 | 吻 1441 | 伤 1442 | 坝 1443 | 冠 1444 | 戊 1445 | 隆 1446 | 瘁 1447 | 介 1448 | 涧 1449 | 物 1450 | 黍 1451 | 并 1452 | 姗 1453 | 奢 1454 | 蹑 1455 | 掣 1456 | 垸 1457 | 锴 1458 | 命 1459 | 箍 1460 | 捉 1461 | 病 1462 | 辖 1463 | 琰 1464 | 眭 1465 | 迩 1466 | 艘 1467 | 绌 1468 | 繁 1469 | 寅 1470 | 若 1471 | 毋 1472 | 思 1473 | 诉 1474 | 类 1475 | 诈 1476 | 燮 1477 | 轲 1478 | 酮 1479 | 狂 1480 | 重 1481 | 反 1482 | 职 1483 | 筱 1484 | 县 1485 | 委 1486 | 磕 1487 | 绣 1488 | 奖 1489 | 晋 1490 | 濉 1491 | 志 1492 | 徽 1493 | 肠 1494 | 呈 1495 | 獐 1496 | 坻 1497 | 口 1498 | 片 1499 | 碰 1500 | 几 1501 | 村 1502 | 柿 1503 | 劳 1504 | 料 1505 | 获 1506 | 亩 1507 | 惕 1508 | 晕 1509 | 厌 1510 | 号 1511 | 罢 1512 | 池 1513 | 正 1514 | 鏖 1515 | 煨 1516 | 家 1517 | 棕 1518 | 复 1519 | 尝 1520 | 懋 1521 | 蜥 1522 | 锅 1523 | 岛 1524 | 扰 1525 | 队 1526 | 坠 1527 | 瘾 1528 | 钬 1529 | @ 1530 | 卧 1531 | 疣 1532 | 镇 1533 | 譬 1534 | 冰 1535 | 彷 1536 | 频 1537 | 黯 1538 | 据 1539 | 垄 1540 | 采 1541 | 八 1542 | 缪 1543 | 瘫 1544 | 型 1545 | 熹 1546 | 砰 1547 | 楠 1548 | 襁 1549 | 箐 1550 | 但 1551 | 嘶 1552 | 绳 1553 | 啤 1554 | 拍 1555 | 盥 1556 | 穆 1557 | 傲 1558 | 洗 1559 | 盯 1560 | 塘 1561 | 怔 1562 | 筛 1563 | 丿 1564 | 台 1565 | 恒 1566 | 喂 1567 | 葛 1568 | 永 1569 | ¥ 1570 | 烟 1571 | 酒 1572 | 桦 1573 | 书 1574 | 砂 1575 | 蚝 1576 | 缉 1577 | 态 1578 | 瀚 1579 | 袄 1580 | 圳 1581 | 轻 1582 | 蛛 1583 | 超 1584 | 榧 1585 | 遛 1586 | 姒 1587 | 奘 1588 | 铮 1589 | 右 1590 | 荽 1591 | 望 1592 | 偻 
1593 | 卡 1594 | 丶 1595 | 氰 1596 | 附 1597 | 做 1598 | 革 1599 | 索 1600 | 戚 1601 | 坨 1602 | 桷 1603 | 唁 1604 | 垅 1605 | 榻 1606 | 岐 1607 | 偎 1608 | 坛 1609 | 莨 1610 | 山 1611 | 殊 1612 | 微 1613 | 骇 1614 | 陈 1615 | 爨 1616 | 推 1617 | 嗝 1618 | 驹 1619 | 澡 1620 | 藁 1621 | 呤 1622 | 卤 1623 | 嘻 1624 | 糅 1625 | 逛 1626 | 侵 1627 | 郓 1628 | 酌 1629 | 德 1630 | 摇 1631 | ※ 1632 | 鬃 1633 | 被 1634 | 慨 1635 | 殡 1636 | 羸 1637 | 昌 1638 | 泡 1639 | 戛 1640 | 鞋 1641 | 河 1642 | 宪 1643 | 沿 1644 | 玲 1645 | 鲨 1646 | 翅 1647 | 哽 1648 | 源 1649 | 铅 1650 | 语 1651 | 照 1652 | 邯 1653 | 址 1654 | 荃 1655 | 佬 1656 | 顺 1657 | 鸳 1658 | 町 1659 | 霭 1660 | 睾 1661 | 瓢 1662 | 夸 1663 | 椁 1664 | 晓 1665 | 酿 1666 | 痈 1667 | 咔 1668 | 侏 1669 | 券 1670 | 噎 1671 | 湍 1672 | 签 1673 | 嚷 1674 | 离 1675 | 午 1676 | 尚 1677 | 社 1678 | 锤 1679 | 背 1680 | 孟 1681 | 使 1682 | 浪 1683 | 缦 1684 | 潍 1685 | 鞅 1686 | 军 1687 | 姹 1688 | 驶 1689 | 笑 1690 | 鳟 1691 | 鲁 1692 | 》 1693 | 孽 1694 | 钜 1695 | 绿 1696 | 洱 1697 | 礴 1698 | 焯 1699 | 椰 1700 | 颖 1701 | 囔 1702 | 乌 1703 | 孔 1704 | 巴 1705 | 互 1706 | 性 1707 | 椽 1708 | 哞 1709 | 聘 1710 | 昨 1711 | 早 1712 | 暮 1713 | 胶 1714 | 炀 1715 | 隧 1716 | 低 1717 | 彗 1718 | 昝 1719 | 铁 1720 | 呓 1721 | 氽 1722 | 藉 1723 | 喔 1724 | 癖 1725 | 瑗 1726 | 姨 1727 | 权 1728 | 胱 1729 | 韦 1730 | 堑 1731 | 蜜 1732 | 酋 1733 | 楝 1734 | 砝 1735 | 毁 1736 | 靓 1737 | 歙 1738 | 锲 1739 | 究 1740 | 屋 1741 | 喳 1742 | 骨 1743 | 辨 1744 | 碑 1745 | 武 1746 | 鸠 1747 | 宫 1748 | 辜 1749 | 烊 1750 | 适 1751 | 坡 1752 | 殃 1753 | 培 1754 | 佩 1755 | 供 1756 | 走 1757 | 蜈 1758 | 迟 1759 | 翼 1760 | 况 1761 | 姣 1762 | 凛 1763 | 浔 1764 | 吃 1765 | 飘 1766 | 债 1767 | 犟 1768 | 金 1769 | 促 1770 | 苛 1771 | 崇 1772 | 坂 1773 | 莳 1774 | 畔 1775 | 绂 1776 | 兵 1777 | 蠕 1778 | 斋 1779 | 根 1780 | 砍 1781 | 亢 1782 | 欢 1783 | 恬 1784 | 崔 1785 | 剁 1786 | 餐 1787 | 榫 1788 | 快 1789 | 扶 1790 | ‖ 1791 | 濒 1792 | 缠 1793 | 鳜 1794 | 当 1795 | 彭 1796 | 驭 1797 | 浦 1798 | 篮 1799 | 昀 1800 | 锆 1801 | 秸 1802 | 钳 1803 | 弋 1804 | 娣 1805 | 瞑 1806 | 夷 1807 | 龛 1808 | 苫 1809 | 拱 1810 | 致 1811 | % 1812 | 嵊 1813 | 障 1814 | 隐 
1815 | 弑 1816 | 初 1817 | 娓 1818 | 抉 1819 | 汩 1820 | 累 1821 | 蓖 1822 | " 1823 | 唬 1824 | 助 1825 | 苓 1826 | 昙 1827 | 押 1828 | 毙 1829 | 破 1830 | 城 1831 | 郧 1832 | 逢 1833 | 嚏 1834 | 獭 1835 | 瞻 1836 | 溱 1837 | 婿 1838 | 赊 1839 | 跨 1840 | 恼 1841 | 璧 1842 | 萃 1843 | 姻 1844 | 貉 1845 | 灵 1846 | 炉 1847 | 密 1848 | 氛 1849 | 陶 1850 | 砸 1851 | 谬 1852 | 衔 1853 | 点 1854 | 琛 1855 | 沛 1856 | 枳 1857 | 层 1858 | 岱 1859 | 诺 1860 | 脍 1861 | 榈 1862 | 埂 1863 | 征 1864 | 冷 1865 | 裁 1866 | 打 1867 | 蹴 1868 | 素 1869 | 瘘 1870 | 逞 1871 | 蛐 1872 | 聊 1873 | 激 1874 | 腱 1875 | 萘 1876 | 踵 1877 | 飒 1878 | 蓟 1879 | 吆 1880 | 取 1881 | 咙 1882 | 簋 1883 | 涓 1884 | 矩 1885 | 曝 1886 | 挺 1887 | 揣 1888 | 座 1889 | 你 1890 | 史 1891 | 舵 1892 | 焱 1893 | 尘 1894 | 苏 1895 | 笈 1896 | 脚 1897 | 溉 1898 | 榨 1899 | 诵 1900 | 樊 1901 | 邓 1902 | 焊 1903 | 义 1904 | 庶 1905 | 儋 1906 | 蟋 1907 | 蒲 1908 | 赦 1909 | 呷 1910 | 杞 1911 | 诠 1912 | 豪 1913 | 还 1914 | 试 1915 | 颓 1916 | 茉 1917 | 太 1918 | 除 1919 | 紫 1920 | 逃 1921 | 痴 1922 | 草 1923 | 充 1924 | 鳕 1925 | 珉 1926 | 祗 1927 | 墨 1928 | 渭 1929 | 烩 1930 | 蘸 1931 | 慕 1932 | 璇 1933 | 镶 1934 | 穴 1935 | 嵘 1936 | 恶 1937 | 骂 1938 | 险 1939 | 绋 1940 | 幕 1941 | 碉 1942 | 肺 1943 | 戳 1944 | 刘 1945 | 潞 1946 | 秣 1947 | 纾 1948 | 潜 1949 | 銮 1950 | 洛 1951 | 须 1952 | 罘 1953 | 销 1954 | 瘪 1955 | 汞 1956 | 兮 1957 | 屉 1958 | r 1959 | 林 1960 | 厕 1961 | 质 1962 | 探 1963 | 划 1964 | 狸 1965 | 殚 1966 | 善 1967 | 煊 1968 | 烹 1969 | 〒 1970 | 锈 1971 | 逯 1972 | 宸 1973 | 辍 1974 | 泱 1975 | 柚 1976 | 袍 1977 | 远 1978 | 蹋 1979 | 嶙 1980 | 绝 1981 | 峥 1982 | 娥 1983 | 缍 1984 | 雀 1985 | 徵 1986 | 认 1987 | 镱 1988 | 谷 1989 | = 1990 | 贩 1991 | 勉 1992 | 撩 1993 | 鄯 1994 | 斐 1995 | 洋 1996 | 非 1997 | 祚 1998 | 泾 1999 | 诒 2000 | 饿 2001 | 撬 2002 | 威 2003 | 晷 2004 | 搭 2005 | 芍 2006 | 锥 2007 | 笺 2008 | 蓦 2009 | 候 2010 | 琊 2011 | 档 2012 | 礁 2013 | 沼 2014 | 卵 2015 | 荠 2016 | 忑 2017 | 朝 2018 | 凹 2019 | 瑞 2020 | 头 2021 | 仪 2022 | 弧 2023 | 孵 2024 | 畏 2025 | 铆 2026 | 突 2027 | 衲 2028 | 车 2029 | 浩 2030 | 气 2031 | 茂 2032 | 悖 2033 | 厢 2034 | 枕 2035 | 酝 2036 | 戴 
2037 | 湾 2038 | 邹 2039 | 飚 2040 | 攘 2041 | 锂 2042 | 写 2043 | 宵 2044 | 翁 2045 | 岷 2046 | 无 2047 | 喜 2048 | 丈 2049 | 挑 2050 | 嗟 2051 | 绛 2052 | 殉 2053 | 议 2054 | 槽 2055 | 具 2056 | 醇 2057 | 淞 2058 | 笃 2059 | 郴 2060 | 阅 2061 | 饼 2062 | 底 2063 | 壕 2064 | 砚 2065 | 弈 2066 | 询 2067 | 缕 2068 | 庹 2069 | 翟 2070 | 零 2071 | 筷 2072 | 暨 2073 | 舟 2074 | 闺 2075 | 甯 2076 | 撞 2077 | 麂 2078 | 茌 2079 | 蔼 2080 | 很 2081 | 珲 2082 | 捕 2083 | 棠 2084 | 角 2085 | 阉 2086 | 媛 2087 | 娲 2088 | 诽 2089 | 剿 2090 | 尉 2091 | 爵 2092 | 睬 2093 | 韩 2094 | 诰 2095 | 匣 2096 | 危 2097 | 糍 2098 | 镯 2099 | 立 2100 | 浏 2101 | 阳 2102 | 少 2103 | 盆 2104 | 舔 2105 | 擘 2106 | 匪 2107 | 申 2108 | 尬 2109 | 铣 2110 | 旯 2111 | 抖 2112 | 赘 2113 | 瓯 2114 | 居 2115 | ˇ 2116 | 哮 2117 | 游 2118 | 锭 2119 | 茏 2120 | 歌 2121 | 坏 2122 | 甚 2123 | 秒 2124 | 舞 2125 | 沙 2126 | 仗 2127 | 劲 2128 | 潺 2129 | 阿 2130 | 燧 2131 | 郭 2132 | 嗖 2133 | 霏 2134 | 忠 2135 | 材 2136 | 奂 2137 | 耐 2138 | 跺 2139 | 砀 2140 | 输 2141 | 岖 2142 | 媳 2143 | 氟 2144 | 极 2145 | 摆 2146 | 灿 2147 | 今 2148 | 扔 2149 | 腻 2150 | 枝 2151 | 奎 2152 | 药 2153 | 熄 2154 | 吨 2155 | 话 2156 | q 2157 | 额 2158 | 慑 2159 | 嘌 2160 | 协 2161 | 喀 2162 | 壳 2163 | 埭 2164 | 视 2165 | 著 2166 | 於 2167 | 愧 2168 | 陲 2169 | 翌 2170 | 峁 2171 | 颅 2172 | 佛 2173 | 腹 2174 | 聋 2175 | 侯 2176 | 咎 2177 | 叟 2178 | 秀 2179 | 颇 2180 | 存 2181 | 较 2182 | 罪 2183 | 哄 2184 | 岗 2185 | 扫 2186 | 栏 2187 | 钾 2188 | 羌 2189 | 己 2190 | 璨 2191 | 枭 2192 | 霉 2193 | 煌 2194 | 涸 2195 | 衿 2196 | 键 2197 | 镝 2198 | 益 2199 | 岢 2200 | 奏 2201 | 连 2202 | 夯 2203 | 睿 2204 | 冥 2205 | 均 2206 | 糖 2207 | 狞 2208 | 蹊 2209 | 稻 2210 | 爸 2211 | 刿 2212 | 胥 2213 | 煜 2214 | 丽 2215 | 肿 2216 | 璃 2217 | 掸 2218 | 跚 2219 | 灾 2220 | 垂 2221 | 樾 2222 | 濑 2223 | 乎 2224 | 莲 2225 | 窄 2226 | 犹 2227 | 撮 2228 | 战 2229 | 馄 2230 | 软 2231 | 络 2232 | 显 2233 | 鸢 2234 | 胸 2235 | 宾 2236 | 妲 2237 | 恕 2238 | 埔 2239 | 蝌 2240 | 份 2241 | 遇 2242 | 巧 2243 | 瞟 2244 | 粒 2245 | 恰 2246 | 剥 2247 | 桡 2248 | 博 2249 | 讯 2250 | 凯 2251 | 堇 2252 | 阶 2253 | 滤 2254 | 卖 2255 | 斌 2256 | 骚 2257 | 彬 2258 | 兑 
2259 | 磺 2260 | 樱 2261 | 舷 2262 | 两 2263 | 娱 2264 | 福 2265 | 仃 2266 | 差 2267 | 找 2268 | 桁 2269 | ÷ 2270 | 净 2271 | 把 2272 | 阴 2273 | 污 2274 | 戬 2275 | 雷 2276 | 碓 2277 | 蕲 2278 | 楚 2279 | 罡 2280 | 焖 2281 | 抽 2282 | 妫 2283 | 咒 2284 | 仑 2285 | 闱 2286 | 尽 2287 | 邑 2288 | 菁 2289 | 爱 2290 | 贷 2291 | 沥 2292 | 鞑 2293 | 牡 2294 | 嗉 2295 | 崴 2296 | 骤 2297 | 塌 2298 | 嗦 2299 | 订 2300 | 拮 2301 | 滓 2302 | 捡 2303 | 锻 2304 | 次 2305 | 坪 2306 | 杩 2307 | 臃 2308 | 箬 2309 | 融 2310 | 珂 2311 | 鹗 2312 | 宗 2313 | 枚 2314 | 降 2315 | 鸬 2316 | 妯 2317 | 阄 2318 | 堰 2319 | 盐 2320 | 毅 2321 | 必 2322 | 杨 2323 | 崃 2324 | 俺 2325 | 甬 2326 | 状 2327 | 莘 2328 | 货 2329 | 耸 2330 | 菱 2331 | 腼 2332 | 铸 2333 | 唏 2334 | 痤 2335 | 孚 2336 | 澳 2337 | 懒 2338 | 溅 2339 | 翘 2340 | 疙 2341 | 杷 2342 | 淼 2343 | 缙 2344 | 骰 2345 | 喊 2346 | 悉 2347 | 砻 2348 | 坷 2349 | 艇 2350 | 赁 2351 | 界 2352 | 谤 2353 | 纣 2354 | 宴 2355 | 晃 2356 | 茹 2357 | 归 2358 | 饭 2359 | 梢 2360 | 铡 2361 | 街 2362 | 抄 2363 | 肼 2364 | 鬟 2365 | 苯 2366 | 颂 2367 | 撷 2368 | 戈 2369 | 炒 2370 | 咆 2371 | 茭 2372 | 瘙 2373 | 负 2374 | 仰 2375 | 客 2376 | 琉 2377 | 铢 2378 | 封 2379 | 卑 2380 | 珥 2381 | 椿 2382 | 镧 2383 | 窨 2384 | 鬲 2385 | 寿 2386 | 御 2387 | 袤 2388 | 铃 2389 | 萎 2390 | 砖 2391 | 餮 2392 | 脒 2393 | 裳 2394 | 肪 2395 | 孕 2396 | 嫣 2397 | 馗 2398 | 嵇 2399 | 恳 2400 | 氯 2401 | 江 2402 | 石 2403 | 褶 2404 | 冢 2405 | 祸 2406 | 阻 2407 | 狈 2408 | 羞 2409 | 银 2410 | 靳 2411 | 透 2412 | 咳 2413 | 叼 2414 | 敷 2415 | 芷 2416 | 啥 2417 | 它 2418 | 瓤 2419 | 兰 2420 | 痘 2421 | 懊 2422 | 逑 2423 | 肌 2424 | 往 2425 | 捺 2426 | 坊 2427 | 甩 2428 | 呻 2429 | 〃 2430 | 沦 2431 | 忘 2432 | 膻 2433 | 祟 2434 | 菅 2435 | 剧 2436 | 崆 2437 | 智 2438 | 坯 2439 | 臧 2440 | 霍 2441 | 墅 2442 | 攻 2443 | 眯 2444 | 倘 2445 | 拢 2446 | 骠 2447 | 铐 2448 | 庭 2449 | 岙 2450 | 瓠 2451 | ′ 2452 | 缺 2453 | 泥 2454 | 迢 2455 | 捶 2456 | ? 2457 | ? 
2458 | 郏 2459 | 喙 2460 | 掷 2461 | 沌 2462 | 纯 2463 | 秘 2464 | 种 2465 | 听 2466 | 绘 2467 | 固 2468 | 螨 2469 | 团 2470 | 香 2471 | 盗 2472 | 妒 2473 | 埚 2474 | 蓝 2475 | 拖 2476 | 旱 2477 | 荞 2478 | 铀 2479 | 血 2480 | 遏 2481 | 汲 2482 | 辰 2483 | 叩 2484 | 拽 2485 | 幅 2486 | 硬 2487 | 惶 2488 | 桀 2489 | 漠 2490 | 措 2491 | 泼 2492 | 唑 2493 | 齐 2494 | 肾 2495 | 念 2496 | 酱 2497 | 虚 2498 | 屁 2499 | 耶 2500 | 旗 2501 | 砦 2502 | 闵 2503 | 婉 2504 | 馆 2505 | 拭 2506 | 绅 2507 | 韧 2508 | 忏 2509 | 窝 2510 | 醋 2511 | 葺 2512 | 顾 2513 | 辞 2514 | 倜 2515 | 堆 2516 | 辋 2517 | 逆 2518 | 玟 2519 | 贱 2520 | 疾 2521 | 董 2522 | 惘 2523 | 倌 2524 | 锕 2525 | 淘 2526 | 嘀 2527 | 莽 2528 | 俭 2529 | 笏 2530 | 绑 2531 | 鲷 2532 | 杈 2533 | 择 2534 | 蟀 2535 | 粥 2536 | 嗯 2537 | 驰 2538 | 逾 2539 | 案 2540 | 谪 2541 | 褓 2542 | 胫 2543 | 哩 2544 | 昕 2545 | 颚 2546 | 鲢 2547 | 绠 2548 | 躺 2549 | 鹄 2550 | 崂 2551 | 儒 2552 | 俨 2553 | 丝 2554 | 尕 2555 | 泌 2556 | 啊 2557 | 萸 2558 | 彰 2559 | 幺 2560 | 吟 2561 | 骄 2562 | 苣 2563 | 弦 2564 | 脊 2565 | 瑰 2566 | 〈 2567 | 诛 2568 | 镁 2569 | 析 2570 | 闪 2571 | 剪 2572 | 侧 2573 | 哟 2574 | 框 2575 | 螃 2576 | 守 2577 | 嬗 2578 | 燕 2579 | 狭 2580 | 铈 2581 | 缮 2582 | 概 2583 | 迳 2584 | 痧 2585 | 鲲 2586 | 俯 2587 | 售 2588 | 笼 2589 | 痣 2590 | 扉 2591 | 挖 2592 | 满 2593 | 咋 2594 | 援 2595 | 邱 2596 | 扇 2597 | 歪 2598 | 便 2599 | 玑 2600 | 绦 2601 | 峡 2602 | 蛇 2603 | 叨 2604 | 〖 2605 | 泽 2606 | 胃 2607 | 斓 2608 | 喋 2609 | 怂 2610 | 坟 2611 | 猪 2612 | 该 2613 | 蚬 2614 | 炕 2615 | 弥 2616 | 赞 2617 | 棣 2618 | 晔 2619 | 娠 2620 | 挲 2621 | 狡 2622 | 创 2623 | 疖 2624 | 铕 2625 | 镭 2626 | 稷 2627 | 挫 2628 | 弭 2629 | 啾 2630 | 翔 2631 | 粉 2632 | 履 2633 | 苘 2634 | 哦 2635 | 楼 2636 | 秕 2637 | 铂 2638 | 土 2639 | 锣 2640 | 瘟 2641 | 挣 2642 | 栉 2643 | 习 2644 | 享 2645 | 桢 2646 | 袅 2647 | 磨 2648 | 桂 2649 | 谦 2650 | 延 2651 | 坚 2652 | 蔚 2653 | 噗 2654 | 署 2655 | 谟 2656 | 猬 2657 | 钎 2658 | 恐 2659 | 嬉 2660 | 雒 2661 | 倦 2662 | 衅 2663 | 亏 2664 | 璩 2665 | 睹 2666 | 刻 2667 | 殿 2668 | 王 2669 | 算 2670 | 雕 2671 | 麻 2672 | 丘 2673 | 柯 2674 | 骆 2675 | 丸 2676 | 塍 2677 | 谚 2678 | 添 2679 | 鲈 
2680 | 垓 2681 | 桎 2682 | 蚯 2683 | 芥 2684 | 予 2685 | 飕 2686 | 镦 2687 | 谌 2688 | 窗 2689 | 醚 2690 | 菀 2691 | 亮 2692 | 搪 2693 | 莺 2694 | 蒿 2695 | 羁 2696 | 足 2697 | J 2698 | 真 2699 | 轶 2700 | 悬 2701 | 衷 2702 | 靛 2703 | 翊 2704 | 掩 2705 | 哒 2706 | 炅 2707 | 掐 2708 | 冼 2709 | 妮 2710 | l 2711 | 谐 2712 | 稚 2713 | 荆 2714 | 擒 2715 | 犯 2716 | 陵 2717 | 虏 2718 | 浓 2719 | 崽 2720 | 刍 2721 | 陌 2722 | 傻 2723 | 孜 2724 | 千 2725 | 靖 2726 | 演 2727 | 矜 2728 | 钕 2729 | 煽 2730 | 杰 2731 | 酗 2732 | 渗 2733 | 伞 2734 | 栋 2735 | 俗 2736 | 泫 2737 | 戍 2738 | 罕 2739 | 沾 2740 | 疽 2741 | 灏 2742 | 煦 2743 | 芬 2744 | 磴 2745 | 叱 2746 | 阱 2747 | 榉 2748 | 湃 2749 | 蜀 2750 | 叉 2751 | 醒 2752 | 彪 2753 | 租 2754 | 郡 2755 | 篷 2756 | 屎 2757 | 良 2758 | 垢 2759 | 隗 2760 | 弱 2761 | 陨 2762 | 峪 2763 | 砷 2764 | 掴 2765 | 颁 2766 | 胎 2767 | 雯 2768 | 绵 2769 | 贬 2770 | 沐 2771 | 撵 2772 | 隘 2773 | 篙 2774 | 暖 2775 | 曹 2776 | 陡 2777 | 栓 2778 | 填 2779 | 臼 2780 | 彦 2781 | 瓶 2782 | 琪 2783 | 潼 2784 | 哪 2785 | 鸡 2786 | 摩 2787 | 啦 2788 | 俟 2789 | 锋 2790 | 域 2791 | 耻 2792 | 蔫 2793 | 疯 2794 | 纹 2795 | 撇 2796 | 毒 2797 | 绶 2798 | 痛 2799 | 酯 2800 | 忍 2801 | 爪 2802 | 赳 2803 | 歆 2804 | 嘹 2805 | 辕 2806 | 烈 2807 | 册 2808 | 朴 2809 | 钱 2810 | 吮 2811 | 毯 2812 | 癜 2813 | 娃 2814 | 谀 2815 | 邵 2816 | 厮 2817 | 炽 2818 | 璞 2819 | 邃 2820 | 丐 2821 | 追 2822 | 词 2823 | 瓒 2824 | 忆 2825 | 轧 2826 | 芫 2827 | 谯 2828 | 喷 2829 | 弟 2830 | 半 2831 | 冕 2832 | 裙 2833 | 掖 2834 | 墉 2835 | 绮 2836 | 寝 2837 | 苔 2838 | 势 2839 | 顷 2840 | 褥 2841 | 切 2842 | 衮 2843 | 君 2844 | 佳 2845 | 嫒 2846 | 蚩 2847 | 霞 2848 | 佚 2849 | 洙 2850 | 逊 2851 | 镖 2852 | 暹 2853 | 唛 2854 | & 2855 | 殒 2856 | 顶 2857 | 碗 2858 | 獗 2859 | 轭 2860 | 铺 2861 | 蛊 2862 | 废 2863 | 恹 2864 | 汨 2865 | 崩 2866 | 珍 2867 | 那 2868 | 杵 2869 | 曲 2870 | 纺 2871 | 夏 2872 | 薰 2873 | 傀 2874 | 闳 2875 | 淬 2876 | 姘 2877 | 舀 2878 | 拧 2879 | 卷 2880 | 楂 2881 | 恍 2882 | 讪 2883 | 厩 2884 | 寮 2885 | 篪 2886 | 赓 2887 | 乘 2888 | 灭 2889 | 盅 2890 | 鞣 2891 | 沟 2892 | 慎 2893 | 挂 2894 | 饺 2895 | 鼾 2896 | 杳 2897 | 树 2898 | 缨 2899 | 丛 2900 | 絮 2901 | 娌 
2902 | 臻 2903 | 嗳 2904 | 篡 2905 | 侩 2906 | 述 2907 | 衰 2908 | 矛 2909 | 圈 2910 | 蚜 2911 | 匕 2912 | 筹 2913 | 匿 2914 | 濞 2915 | 晨 2916 | 叶 2917 | 骋 2918 | 郝 2919 | 挚 2920 | 蚴 2921 | 滞 2922 | 增 2923 | 侍 2924 | 描 2925 | 瓣 2926 | 吖 2927 | 嫦 2928 | 蟒 2929 | 匾 2930 | 圣 2931 | 赌 2932 | 毡 2933 | 癞 2934 | 恺 2935 | 百 2936 | 曳 2937 | 需 2938 | 篓 2939 | 肮 2940 | 庖 2941 | 帏 2942 | 卿 2943 | 驿 2944 | 遗 2945 | 蹬 2946 | 鬓 2947 | 骡 2948 | 歉 2949 | 芎 2950 | 胳 2951 | 屐 2952 | 禽 2953 | 烦 2954 | 晌 2955 | 寄 2956 | 媾 2957 | 狄 2958 | 翡 2959 | 苒 2960 | 船 2961 | 廉 2962 | 终 2963 | 痞 2964 | 殇 2965 | 々 2966 | 畦 2967 | 饶 2968 | 改 2969 | 拆 2970 | 悻 2971 | 萄 2972 | £ 2973 | 瓿 2974 | 乃 2975 | 訾 2976 | 桅 2977 | 匮 2978 | 溧 2979 | 拥 2980 | 纱 2981 | 铍 2982 | 骗 2983 | 蕃 2984 | 龋 2985 | 缬 2986 | 父 2987 | 佐 2988 | 疚 2989 | 栎 2990 | 醍 2991 | 掳 2992 | 蓄 2993 | x 2994 | 惆 2995 | 颜 2996 | 鲆 2997 | 榆 2998 | 〔 2999 | 猎 3000 | 敌 3001 | 暴 3002 | 谥 3003 | 鲫 3004 | 贾 3005 | 罗 3006 | 玻 3007 | 缄 3008 | 扦 3009 | 芪 3010 | 癣 3011 | 落 3012 | 徒 3013 | 臾 3014 | 恿 3015 | 猩 3016 | 托 3017 | 邴 3018 | 肄 3019 | 牵 3020 | 春 3021 | 陛 3022 | 耀 3023 | 刊 3024 | 拓 3025 | 蓓 3026 | 邳 3027 | 堕 3028 | 寇 3029 | 枉 3030 | 淌 3031 | 啡 3032 | 湄 3033 | 兽 3034 | 酷 3035 | 萼 3036 | 碚 3037 | 濠 3038 | 萤 3039 | 夹 3040 | 旬 3041 | 戮 3042 | 梭 3043 | 琥 3044 | 椭 3045 | 昔 3046 | 勺 3047 | 蜊 3048 | 绐 3049 | 晚 3050 | 孺 3051 | 僵 3052 | 宣 3053 | 摄 3054 | 冽 3055 | 旨 3056 | 萌 3057 | 忙 3058 | 蚤 3059 | 眉 3060 | 噼 3061 | 蟑 3062 | 付 3063 | 契 3064 | 瓜 3065 | 悼 3066 | 颡 3067 | 壁 3068 | 曾 3069 | 窕 3070 | 颢 3071 | 澎 3072 | 仿 3073 | 俑 3074 | 浑 3075 | 嵌 3076 | 浣 3077 | 乍 3078 | 碌 3079 | 褪 3080 | 乱 3081 | 蔟 3082 | 隙 3083 | 玩 3084 | 剐 3085 | 葫 3086 | 箫 3087 | 纲 3088 | 围 3089 | 伐 3090 | 决 3091 | 伙 3092 | 漩 3093 | 瑟 3094 | 刑 3095 | 肓 3096 | 镳 3097 | 缓 3098 | 蹭 3099 | 氨 3100 | 皓 3101 | 典 3102 | 畲 3103 | 坍 3104 | 铑 3105 | 檐 3106 | 塑 3107 | 洞 3108 | 倬 3109 | 储 3110 | 胴 3111 | 淳 3112 | 戾 3113 | 吐 3114 | 灼 3115 | 惺 3116 | 妙 3117 | 毕 3118 | 珐 3119 | 缈 3120 | 虱 3121 | 盖 3122 | 羰 3123 | 鸿 
3124 | 磅 3125 | 谓 3126 | 髅 3127 | 娴 3128 | 苴 3129 | 唷 3130 | 蚣 3131 | 霹 3132 | 抨 3133 | 贤 3134 | 唠 3135 | 犬 3136 | 誓 3137 | 逍 3138 | 庠 3139 | 逼 3140 | 麓 3141 | 籼 3142 | 釉 3143 | 呜 3144 | 碧 3145 | 秧 3146 | 氩 3147 | 摔 3148 | 霄 3149 | 穸 3150 | 纨 3151 | 辟 3152 | 妈 3153 | 映 3154 | 完 3155 | 牛 3156 | 缴 3157 | 嗷 3158 | 炊 3159 | 恩 3160 | 荔 3161 | 茆 3162 | 掉 3163 | 紊 3164 | 慌 3165 | 莓 3166 | 羟 3167 | 阙 3168 | 萁 3169 | 磐 3170 | 另 3171 | 蕹 3172 | 辱 3173 | 鳐 3174 | 湮 3175 | 吡 3176 | 吩 3177 | 唐 3178 | 睦 3179 | 垠 3180 | 舒 3181 | 圜 3182 | 冗 3183 | 瞿 3184 | 溺 3185 | 芾 3186 | 囱 3187 | 匠 3188 | 僳 3189 | 汐 3190 | 菩 3191 | 饬 3192 | 漓 3193 | 黑 3194 | 霰 3195 | 浸 3196 | 濡 3197 | 窥 3198 | 毂 3199 | 蒡 3200 | 兢 3201 | 驻 3202 | 鹉 3203 | 芮 3204 | 诙 3205 | 迫 3206 | 雳 3207 | 厂 3208 | 忐 3209 | 臆 3210 | 猴 3211 | 鸣 3212 | 蚪 3213 | 栈 3214 | 箕 3215 | 羡 3216 | 渐 3217 | 莆 3218 | 捍 3219 | 眈 3220 | 哓 3221 | 趴 3222 | 蹼 3223 | 埕 3224 | 嚣 3225 | 骛 3226 | 宏 3227 | 淄 3228 | 斑 3229 | 噜 3230 | 严 3231 | 瑛 3232 | 垃 3233 | 椎 3234 | 诱 3235 | 压 3236 | 庾 3237 | 绞 3238 | 焘 3239 | 廿 3240 | 抡 3241 | 迄 3242 | 棘 3243 | 夫 3244 | 纬 3245 | 锹 3246 | 眨 3247 | 瞌 3248 | 侠 3249 | 脐 3250 | 竞 3251 | 瀑 3252 | 孳 3253 | 骧 3254 | 遁 3255 | 姜 3256 | 颦 3257 | 荪 3258 | 滚 3259 | 萦 3260 | 伪 3261 | 逸 3262 | 粳 3263 | 爬 3264 | 锁 3265 | 矣 3266 | 役 3267 | 趣 3268 | 洒 3269 | 颔 3270 | 诏 3271 | 逐 3272 | 奸 3273 | 甭 3274 | 惠 3275 | 攀 3276 | 蹄 3277 | 泛 3278 | 尼 3279 | 拼 3280 | 阮 3281 | 鹰 3282 | 亚 3283 | 颈 3284 | 惑 3285 | 勒 3286 | 〉 3287 | 际 3288 | 肛 3289 | 爷 3290 | 刚 3291 | 钨 3292 | 丰 3293 | 养 3294 | 冶 3295 | 鲽 3296 | 辉 3297 | 蔻 3298 | 画 3299 | 覆 3300 | 皴 3301 | 妊 3302 | 麦 3303 | 返 3304 | 醉 3305 | 皂 3306 | 擀 3307 | 〗 3308 | 酶 3309 | 凑 3310 | 粹 3311 | 悟 3312 | 诀 3313 | 硖 3314 | 港 3315 | 卜 3316 | z 3317 | 杀 3318 | 涕 3319 | ± 3320 | 舍 3321 | 铠 3322 | 抵 3323 | 弛 3324 | 段 3325 | 敝 3326 | 镐 3327 | 奠 3328 | 拂 3329 | 轴 3330 | 跛 3331 | 袱 3332 | e 3333 | t 3334 | 沉 3335 | 菇 3336 | 俎 3337 | 薪 3338 | 峦 3339 | 秭 3340 | 蟹 3341 | 历 3342 | 盟 3343 | 菠 3344 | 寡 3345 | 液 
3346 | 肢 3347 | 喻 3348 | 染 3349 | 裱 3350 | 悱 3351 | 抱 3352 | 氙 3353 | 赤 3354 | 捅 3355 | 猛 3356 | 跑 3357 | 氮 3358 | 谣 3359 | 仁 3360 | 尺 3361 | 辊 3362 | 窍 3363 | 烙 3364 | 衍 3365 | 架 3366 | 擦 3367 | 倏 3368 | 璐 3369 | 瑁 3370 | 币 3371 | 楞 3372 | 胖 3373 | 夔 3374 | 趸 3375 | 邛 3376 | 惴 3377 | 饕 3378 | 虔 3379 | 蝎 3380 | § 3381 | 哉 3382 | 贝 3383 | 宽 3384 | 辫 3385 | 炮 3386 | 扩 3387 | 饲 3388 | 籽 3389 | 魏 3390 | 菟 3391 | 锰 3392 | 伍 3393 | 猝 3394 | 末 3395 | 琳 3396 | 哚 3397 | 蛎 3398 | 邂 3399 | 呀 3400 | 姿 3401 | 鄞 3402 | 却 3403 | 歧 3404 | 仙 3405 | 恸 3406 | 椐 3407 | 森 3408 | 牒 3409 | 寤 3410 | 袒 3411 | 婆 3412 | 虢 3413 | 雅 3414 | 钉 3415 | 朵 3416 | 贼 3417 | 欲 3418 | 苞 3419 | 寰 3420 | 故 3421 | 龚 3422 | 坭 3423 | 嘘 3424 | 咫 3425 | 礼 3426 | 硷 3427 | 兀 3428 | 睢 3429 | 汶 3430 | ’ 3431 | 铲 3432 | 烧 3433 | 绕 3434 | 诃 3435 | 浃 3436 | 钿 3437 | 哺 3438 | 柜 3439 | 讼 3440 | 颊 3441 | 璁 3442 | 腔 3443 | 洽 3444 | 咐 3445 | 脲 3446 | 簌 3447 | 筠 3448 | 镣 3449 | 玮 3450 | 鞠 3451 | 谁 3452 | 兼 3453 | 姆 3454 | 挥 3455 | 梯 3456 | 蝴 3457 | 谘 3458 | 漕 3459 | 刷 3460 | 躏 3461 | 宦 3462 | 弼 3463 | b 3464 | 垌 3465 | 劈 3466 | 麟 3467 | 莉 3468 | 揭 3469 | 笙 3470 | 渎 3471 | 仕 3472 | 嗤 3473 | 仓 3474 | 配 3475 | 怏 3476 | 抬 3477 | 错 3478 | 泯 3479 | 镊 3480 | 孰 3481 | 猿 3482 | 邪 3483 | 仍 3484 | 秋 3485 | 鼬 3486 | 壹 3487 | 歇 3488 | 吵 3489 | 炼 3490 | < 3491 | 尧 3492 | 射 3493 | 柬 3494 | 廷 3495 | 胧 3496 | 霾 3497 | 凳 3498 | 隋 3499 | 肚 3500 | 浮 3501 | 梦 3502 | 祥 3503 | 株 3504 | 堵 3505 | 退 3506 | L 3507 | 鹫 3508 | 跎 3509 | 凶 3510 | 毽 3511 | 荟 3512 | 炫 3513 | 栩 3514 | 玳 3515 | 甜 3516 | 沂 3517 | 鹿 3518 | 顽 3519 | 伯 3520 | 爹 3521 | 赔 3522 | 蛴 3523 | 徐 3524 | 匡 3525 | 欣 3526 | 狰 3527 | 缸 3528 | 雹 3529 | 蟆 3530 | 疤 3531 | 默 3532 | 沤 3533 | 啜 3534 | 痂 3535 | 衣 3536 | 禅 3537 | w 3538 | i 3539 | h 3540 | 辽 3541 | 葳 3542 | 黝 3543 | 钗 3544 | 停 3545 | 沽 3546 | 棒 3547 | 馨 3548 | 颌 3549 | 肉 3550 | 吴 3551 | 硫 3552 | 悯 3553 | 劾 3554 | 娈 3555 | 马 3556 | 啧 3557 | 吊 3558 | 悌 3559 | 镑 3560 | 峭 3561 | 帆 3562 | 瀣 3563 | 涉 3564 | 咸 3565 | 疸 3566 | 滋 3567 | 泣 
3568 | 翦 3569 | 拙 3570 | 癸 3571 | 钥 3572 | 蜒 3573 | + 3574 | 尾 3575 | 庄 3576 | 凝 3577 | 泉 3578 | 婢 3579 | 渴 3580 | 谊 3581 | 乞 3582 | 陆 3583 | 锉 3584 | 糊 3585 | 鸦 3586 | 淮 3587 | I 3588 | B 3589 | N 3590 | 晦 3591 | 弗 3592 | 乔 3593 | 庥 3594 | 葡 3595 | 尻 3596 | 席 3597 | 橡 3598 | 傣 3599 | 渣 3600 | 拿 3601 | 惩 3602 | 麋 3603 | 斛 3604 | 缃 3605 | 矮 3606 | 蛏 3607 | 岘 3608 | 鸽 3609 | 姐 3610 | 膏 3611 | 催 3612 | 奔 3613 | 镒 3614 | 喱 3615 | 蠡 3616 | 摧 3617 | 钯 3618 | 胤 3619 | 柠 3620 | 拐 3621 | 璋 3622 | 鸥 3623 | 卢 3624 | 荡 3625 | 倾 3626 | ^ 3627 | _ 3628 | 珀 3629 | 逄 3630 | 萧 3631 | 塾 3632 | 掇 3633 | 贮 3634 | 笆 3635 | 聂 3636 | 圃 3637 | 冲 3638 | 嵬 3639 | M 3640 | 滔 3641 | 笕 3642 | 值 3643 | 炙 3644 | 偶 3645 | 蜱 3646 | 搐 3647 | 梆 3648 | 汪 3649 | 蔬 3650 | 腑 3651 | 鸯 3652 | 蹇 3653 | 敞 3654 | 绯 3655 | 仨 3656 | 祯 3657 | 谆 3658 | 梧 3659 | 糗 3660 | 鑫 3661 | 啸 3662 | 豺 3663 | 囹 3664 | 猾 3665 | 巢 3666 | 柄 3667 | 瀛 3668 | 筑 3669 | 踌 3670 | 沭 3671 | 暗 3672 | 苁 3673 | 鱿 3674 | 蹉 3675 | 脂 3676 | 蘖 3677 | 牢 3678 | 热 3679 | 木 3680 | 吸 3681 | 溃 3682 | 宠 3683 | 序 3684 | 泞 3685 | 偿 3686 | 拜 3687 | 檩 3688 | 厚 3689 | 朐 3690 | 毗 3691 | 螳 3692 | 吞 3693 | 媚 3694 | 朽 3695 | 担 3696 | 蝗 3697 | 橘 3698 | 畴 3699 | 祈 3700 | 糟 3701 | 盱 3702 | 隼 3703 | 郜 3704 | 惜 3705 | 珠 3706 | 裨 3707 | 铵 3708 | 焙 3709 | 琚 3710 | 唯 3711 | 咚 3712 | 噪 3713 | 骊 3714 | 丫 3715 | 滢 3716 | 勤 3717 | 棉 3718 | 呸 3719 | 咣 3720 | 淀 3721 | 隔 3722 | 蕾 3723 | 窈 3724 | 饨 3725 | 挨 3726 | 煅 3727 | 短 3728 | 匙 3729 | 粕 3730 | 镜 3731 | 赣 3732 | 撕 3733 | 墩 3734 | 酬 3735 | 馁 3736 | 豌 3737 | 颐 3738 | 抗 3739 | 酣 3740 | 氓 3741 | 佑 3742 | 搁 3743 | 哭 3744 | 递 3745 | 耷 3746 | 涡 3747 | 桃 3748 | 贻 3749 | 碣 3750 | 截 3751 | 瘦 3752 | 昭 3753 | 镌 3754 | 蔓 3755 | 氚 3756 | 甲 3757 | 猕 3758 | 蕴 3759 | 蓬 3760 | 散 3761 | 拾 3762 | 纛 3763 | 狼 3764 | 猷 3765 | 铎 3766 | 埋 3767 | 旖 3768 | 矾 3769 | 讳 3770 | 囊 3771 | 糜 3772 | 迈 3773 | 粟 3774 | 蚂 3775 | 紧 3776 | 鲳 3777 | 瘢 3778 | 栽 3779 | 稼 3780 | 羊 3781 | 锄 3782 | 斟 3783 | 睁 3784 | 桥 3785 | 瓮 3786 | 蹙 3787 | 祉 3788 | 醺 3789 | 鼻 
3790 | 昱 3791 | 剃 3792 | 跳 3793 | 篱 3794 | 跷 3795 | 蒜 3796 | 翎 3797 | 宅 3798 | 晖 3799 | 嗑 3800 | 壑 3801 | 峻 3802 | 癫 3803 | 屏 3804 | 狠 3805 | 陋 3806 | 袜 3807 | 途 3808 | 憎 3809 | 祀 3810 | 莹 3811 | 滟 3812 | 佶 3813 | 溥 3814 | 臣 3815 | 约 3816 | 盛 3817 | 峰 3818 | 磁 3819 | 慵 3820 | 婪 3821 | 拦 3822 | 莅 3823 | 朕 3824 | 鹦 3825 | 粲 3826 | 裤 3827 | 哎 3828 | 疡 3829 | 嫖 3830 | 琵 3831 | 窟 3832 | 堪 3833 | 谛 3834 | 嘉 3835 | 儡 3836 | 鳝 3837 | 斩 3838 | 郾 3839 | 驸 3840 | 酊 3841 | 妄 3842 | 胜 3843 | 贺 3844 | 徙 3845 | 傅 3846 | 噌 3847 | 钢 3848 | 栅 3849 | 庇 3850 | 恋 3851 | 匝 3852 | 巯 3853 | 邈 3854 | 尸 3855 | 锚 3856 | 粗 3857 | 佟 3858 | 蛟 3859 | 薹 3860 | 纵 3861 | 蚊 3862 | 郅 3863 | 绢 3864 | 锐 3865 | 苗 3866 | 俞 3867 | 篆 3868 | 淆 3869 | 膀 3870 | 鲜 3871 | 煎 3872 | 诶 3873 | 秽 3874 | 寻 3875 | 涮 3876 | 刺 3877 | 怀 3878 | 噶 3879 | 巨 3880 | 褰 3881 | 魅 3882 | 灶 3883 | 灌 3884 | 桉 3885 | 藕 3886 | 谜 3887 | 舸 3888 | 薄 3889 | 搀 3890 | 恽 3891 | 借 3892 | 牯 3893 | 痉 3894 | 渥 3895 | 愿 3896 | 亓 3897 | 耘 3898 | 杠 3899 | 柩 3900 | 锔 3901 | 蚶 3902 | 钣 3903 | 珈 3904 | 喘 3905 | 蹒 3906 | 幽 3907 | 赐 3908 | 稗 3909 | 晤 3910 | 莱 3911 | 泔 3912 | 扯 3913 | 肯 3914 | 菪 3915 | 裆 3916 | 腩 3917 | 豉 3918 | 疆 3919 | 骜 3920 | 腐 3921 | 倭 3922 | 珏 3923 | 唔 3924 | 粮 3925 | 亡 3926 | 润 3927 | 慰 3928 | 伽 3929 | 橄 3930 | 玄 3931 | 誉 3932 | 醐 3933 | 胆 3934 | 龊 3935 | 粼 3936 | 塬 3937 | 陇 3938 | 彼 3939 | 削 3940 | 嗣 3941 | 绾 3942 | 芽 3943 | 妗 3944 | 垭 3945 | 瘴 3946 | 爽 3947 | 薏 3948 | 寨 3949 | 龈 3950 | 泠 3951 | 弹 3952 | 赢 3953 | 漪 3954 | 猫 3955 | 嘧 3956 | 涂 3957 | 恤 3958 | 圭 3959 | 茧 3960 | 烽 3961 | 屑 3962 | 痕 3963 | 巾 3964 | 赖 3965 | 荸 3966 | 凰 3967 | 腮 3968 | 畈 3969 | 亵 3970 | 蹲 3971 | 偃 3972 | 苇 3973 | 澜 3974 | 艮 3975 | 换 3976 | 骺 3977 | 烘 3978 | 苕 3979 | 梓 3980 | 颉 3981 | 肇 3982 | 哗 3983 | 悄 3984 | 氤 3985 | 涠 3986 | 葬 3987 | 屠 3988 | 鹭 3989 | 植 3990 | 竺 3991 | 佯 3992 | 诣 3993 | 鲇 3994 | 瘀 3995 | 鲅 3996 | 邦 3997 | 移 3998 | 滁 3999 | 冯 4000 | 耕 4001 | 癔 4002 | 戌 4003 | 茬 4004 | 沁 4005 | 巩 4006 | 悠 4007 | 湘 4008 | 洪 4009 | 痹 4010 | 锟 4011 | 循 
4012 | 谋 4013 | 腕 4014 | 鳃 4015 | 钠 4016 | 捞 4017 | 焉 4018 | 迎 4019 | 碱 4020 | 伫 4021 | 急 4022 | 榷 4023 | 奈 4024 | 邝 4025 | 卯 4026 | 辄 4027 | 皲 4028 | 卟 4029 | 醛 4030 | 畹 4031 | 忧 4032 | 稳 4033 | 雄 4034 | 昼 4035 | 缩 4036 | 阈 4037 | 睑 4038 | 扌 4039 | 耗 4040 | 曦 4041 | 涅 4042 | 捏 4043 | 瞧 4044 | 邕 4045 | 淖 4046 | 漉 4047 | 铝 4048 | 耦 4049 | 禹 4050 | 湛 4051 | 喽 4052 | 莼 4053 | 琅 4054 | 诸 4055 | 苎 4056 | 纂 4057 | 硅 4058 | 始 4059 | 嗨 4060 | 傥 4061 | 燃 4062 | 臂 4063 | 赅 4064 | 嘈 4065 | 呆 4066 | 贵 4067 | 屹 4068 | 壮 4069 | 肋 4070 | 亍 4071 | 蚀 4072 | 卅 4073 | 豹 4074 | 腆 4075 | 邬 4076 | 迭 4077 | 浊 4078 | } 4079 | 童 4080 | 螂 4081 | 捐 4082 | 圩 4083 | 勐 4084 | 触 4085 | 寞 4086 | 汊 4087 | 壤 4088 | 荫 4089 | 膺 4090 | 渌 4091 | 芳 4092 | 懿 4093 | 遴 4094 | 螈 4095 | 泰 4096 | 蓼 4097 | 蛤 4098 | 茜 4099 | 舅 4100 | 枫 4101 | 朔 4102 | 膝 4103 | 眙 4104 | 避 4105 | 梅 4106 | 判 4107 | 鹜 4108 | 璜 4109 | 牍 4110 | 缅 4111 | 垫 4112 | 藻 4113 | 黔 4114 | 侥 4115 | 惚 4116 | 懂 4117 | 踩 4118 | 腰 4119 | 腈 4120 | 札 4121 | 丞 4122 | 唾 4123 | 慈 4124 | 顿 4125 | 摹 4126 | 荻 4127 | 琬 4128 | ~ 4129 | 斧 4130 | 沈 4131 | 滂 4132 | 胁 4133 | 胀 4134 | 幄 4135 | 莜 4136 | Z 4137 | 匀 4138 | 鄄 4139 | 掌 4140 | 绰 4141 | 茎 4142 | 焚 4143 | 赋 4144 | 萱 4145 | 谑 4146 | 汁 4147 | 铒 4148 | 瞎 4149 | 夺 4150 | 蜗 4151 | 野 4152 | 娆 4153 | 冀 4154 | 弯 4155 | 篁 4156 | 懵 4157 | 灞 4158 | 隽 4159 | 芡 4160 | 脘 4161 | 俐 4162 | 辩 4163 | 芯 4164 | 掺 4165 | 喏 4166 | 膈 4167 | 蝈 4168 | 觐 4169 | 悚 4170 | 踹 4171 | 蔗 4172 | 熠 4173 | 鼠 4174 | 呵 4175 | 抓 4176 | 橼 4177 | 峨 4178 | 畜 4179 | 缔 4180 | 禾 4181 | 崭 4182 | 弃 4183 | 熊 4184 | 摒 4185 | 凸 4186 | 拗 4187 | 穹 4188 | 蒙 4189 | 抒 4190 | 祛 4191 | 劝 4192 | 闫 4193 | 扳 4194 | 阵 4195 | 醌 4196 | 踪 4197 | 喵 4198 | 侣 4199 | 搬 4200 | 仅 4201 | 荧 4202 | 赎 4203 | 蝾 4204 | 琦 4205 | 买 4206 | 婧 4207 | 瞄 4208 | 寓 4209 | 皎 4210 | 冻 4211 | 赝 4212 | 箩 4213 | 莫 4214 | 瞰 4215 | 郊 4216 | 笫 4217 | 姝 4218 | 筒 4219 | 枪 4220 | 遣 4221 | 煸 4222 | 袋 4223 | 舆 4224 | 痱 4225 | 涛 4226 | 母 4227 | 〇 4228 | 启 4229 | 践 4230 | 耙 4231 | 绲 4232 | 盘 4233 | 遂 
4234 | 昊 4235 | 搞 4236 | 槿 4237 | 诬 4238 | 纰 4239 | 泓 4240 | 惨 4241 | 檬 4242 | 亻 4243 | 越 4244 | C 4245 | o 4246 | 憩 4247 | 熵 4248 | 祷 4249 | 钒 4250 | 暧 4251 | 塔 4252 | 阗 4253 | 胰 4254 | 咄 4255 | 娶 4256 | 魔 4257 | 琶 4258 | 钞 4259 | 邻 4260 | 扬 4261 | 杉 4262 | 殴 4263 | 咽 4264 | 弓 4265 | 〆 4266 | 髻 4267 | 】 4268 | 吭 4269 | 揽 4270 | 霆 4271 | 拄 4272 | 殖 4273 | 脆 4274 | 彻 4275 | 岩 4276 | 芝 4277 | 勃 4278 | 辣 4279 | 剌 4280 | 钝 4281 | 嘎 4282 | 甄 4283 | 佘 4284 | 皖 4285 | 伦 4286 | 授 4287 | 徕 4288 | 憔 4289 | 挪 4290 | 皇 4291 | 庞 4292 | 稔 4293 | 芜 4294 | 踏 4295 | 溴 4296 | 兖 4297 | 卒 4298 | 擢 4299 | 饥 4300 | 鳞 4301 | 煲 4302 | ‰ 4303 | 账 4304 | 颗 4305 | 叻 4306 | 斯 4307 | 捧 4308 | 鳍 4309 | 琮 4310 | 讹 4311 | 蛙 4312 | 纽 4313 | 谭 4314 | 酸 4315 | 兔 4316 | 莒 4317 | 睇 4318 | 伟 4319 | 觑 4320 | 羲 4321 | 嗜 4322 | 宜 4323 | 褐 4324 | 旎 4325 | 辛 4326 | 卦 4327 | 诘 4328 | 筋 4329 | 鎏 4330 | 溪 4331 | 挛 4332 | 熔 4333 | 阜 4334 | 晰 4335 | 鳅 4336 | 丢 4337 | 奚 4338 | 灸 4339 | 呱 4340 | 献 4341 | 陉 4342 | 黛 4343 | 鸪 4344 | 甾 4345 | 萨 4346 | 疮 4347 | 拯 4348 | 洲 4349 | 疹 4350 | 辑 4351 | 叙 4352 | 恻 4353 | 谒 4354 | 允 4355 | 柔 4356 | 烂 4357 | 氏 4358 | 逅 4359 | 漆 4360 | 拎 4361 | 惋 4362 | 扈 4363 | 湟 4364 | 纭 4365 | 啕 4366 | 掬 4367 | 擞 4368 | 哥 4369 | 忽 4370 | 涤 4371 | 鸵 4372 | 靡 4373 | 郗 4374 | 瓷 4375 | 扁 4376 | 廊 4377 | 怨 4378 | 雏 4379 | 钮 4380 | 敦 4381 | E 4382 | 懦 4383 | 憋 4384 | 汀 4385 | 拚 4386 | 啉 4387 | 腌 4388 | 岸 4389 | f 4390 | 痼 4391 | 瞅 4392 | 尊 4393 | 咀 4394 | 眩 4395 | 飙 4396 | 忌 4397 | 仝 4398 | 迦 4399 | 熬 4400 | 毫 4401 | 胯 4402 | 篑 4403 | 茄 4404 | 腺 4405 | 凄 4406 | 舛 4407 | 碴 4408 | 锵 4409 | 诧 4410 | 羯 4411 | 後 4412 | 漏 4413 | 汤 4414 | 宓 4415 | 仞 4416 | 蚁 4417 | 壶 4418 | 谰 4419 | 皑 4420 | 铄 4421 | 棰 4422 | 罔 4423 | 辅 4424 | 晶 4425 | 苦 4426 | 牟 4427 | 闽 4428 | \ 4429 | 烃 4430 | 饮 4431 | 聿 4432 | 丙 4433 | 蛳 4434 | 朱 4435 | 煤 4436 | 涔 4437 | 鳖 4438 | 犁 4439 | 罐 4440 | 荼 4441 | 砒 4442 | 淦 4443 | 妤 4444 | 黏 4445 | 戎 4446 | 孑 4447 | 婕 4448 | 瑾 4449 | 戢 4450 | 钵 4451 | 枣 4452 | 捋 4453 | 砥 4454 | 衩 4455 | 狙 
4456 | 桠 4457 | 稣 4458 | 阎 4459 | 肃 4460 | 梏 4461 | 诫 4462 | 孪 4463 | 昶 4464 | 婊 4465 | 衫 4466 | 嗔 4467 | 侃 4468 | 塞 4469 | 蜃 4470 | 樵 4471 | 峒 4472 | 貌 4473 | 屿 4474 | 欺 4475 | 缫 4476 | 阐 4477 | 栖 4478 | 诟 4479 | 珞 4480 | 荭 4481 | 吝 4482 | 萍 4483 | 嗽 4484 | 恂 4485 | 啻 4486 | 蜴 4487 | 磬 4488 | 峋 4489 | 俸 4490 | 豫 4491 | 谎 4492 | 徊 4493 | 镍 4494 | 韬 4495 | 魇 4496 | 晴 4497 | U 4498 | 囟 4499 | 猜 4500 | 蛮 4501 | 坐 4502 | 囿 4503 | 伴 4504 | 亭 4505 | 肝 4506 | 佗 4507 | 蝠 4508 | 妃 4509 | 胞 4510 | 滩 4511 | 榴 4512 | 氖 4513 | 垩 4514 | 苋 4515 | 砣 4516 | 扪 4517 | 馏 4518 | 姓 4519 | 轩 4520 | 厉 4521 | 夥 4522 | 侈 4523 | 禀 4524 | 垒 4525 | 岑 4526 | 赏 4527 | 钛 4528 | 辐 4529 | 痔 4530 | 披 4531 | 纸 4532 | 碳 4533 | “ 4534 | 坞 4535 | 蠓 4536 | 挤 4537 | 荥 4538 | 沅 4539 | 悔 4540 | 铧 4541 | 帼 4542 | 蒌 4543 | 蝇 4544 | a 4545 | p 4546 | y 4547 | n 4548 | g 4549 | 哀 4550 | 浆 4551 | 瑶 4552 | 凿 4553 | 桶 4554 | 馈 4555 | 皮 4556 | 奴 4557 | 苜 4558 | 佤 4559 | 伶 4560 | 晗 4561 | 铱 4562 | 炬 4563 | 优 4564 | 弊 4565 | 氢 4566 | 恃 4567 | 甫 4568 | 攥 4569 | 端 4570 | 锌 4571 | 灰 4572 | 稹 4573 | 炝 4574 | 曙 4575 | 邋 4576 | 亥 4577 | 眶 4578 | 碾 4579 | 拉 4580 | 萝 4581 | 绔 4582 | 捷 4583 | 浍 4584 | 腋 4585 | 姑 4586 | 菖 4587 | 凌 4588 | 涞 4589 | 麽 4590 | 锢 4591 | 桨 4592 | 潢 4593 | 绎 4594 | 镰 4595 | 殆 4596 | 锑 4597 | 渝 4598 | 铬 4599 | 困 4600 | 绽 4601 | 觎 4602 | 匈 4603 | 糙 4604 | 暑 4605 | 裹 4606 | 鸟 4607 | 盔 4608 | 肽 4609 | 迷 4610 | 綦 4611 | 『 4612 | 亳 4613 | 佝 4614 | 俘 4615 | 钴 4616 | 觇 4617 | 骥 4618 | 仆 4619 | 疝 4620 | 跪 4621 | 婶 4622 | 郯 4623 | 瀹 4624 | 唉 4625 | 脖 4626 | 踞 4627 | 针 4628 | 晾 4629 | 忒 4630 | 扼 4631 | 瞩 4632 | 叛 4633 | 椒 4634 | 疟 4635 | 嗡 4636 | 邗 4637 | 肆 4638 | 跆 4639 | 玫 4640 | 忡 4641 | 捣 4642 | 咧 4643 | 唆 4644 | 艄 4645 | 蘑 4646 | 潦 4647 | 笛 4648 | 阚 4649 | 沸 4650 | 泻 4651 | 掊 4652 | 菽 4653 | 贫 4654 | 斥 4655 | 髂 4656 | 孢 4657 | 镂 4658 | 赂 4659 | 麝 4660 | 鸾 4661 | 屡 4662 | 衬 4663 | 苷 4664 | 恪 4665 | 叠 4666 | 希 4667 | 粤 4668 | 爻 4669 | 喝 4670 | 茫 4671 | 惬 4672 | 郸 4673 | 绻 4674 | 庸 4675 | 撅 4676 | 碟 4677 | 宄 
4678 | 妹 4679 | 膛 4680 | 叮 4681 | 饵 4682 | 崛 4683 | 嗲 4684 | 椅 4685 | 冤 4686 | 搅 4687 | 咕 4688 | 敛 4689 | 尹 4690 | 垦 4691 | 闷 4692 | 蝉 4693 | 霎 4694 | 勰 4695 | 败 4696 | 蓑 4697 | 泸 4698 | 肤 4699 | 鹌 4700 | 幌 4701 | 焦 4702 | 浠 4703 | 鞍 4704 | 刁 4705 | 舰 4706 | 乙 4707 | 竿 4708 | 裔 4709 | 。 4710 | 茵 4711 | 函 4712 | 伊 4713 | 兄 4714 | 丨 4715 | 娜 4716 | 匍 4717 | 謇 4718 | 莪 4719 | 宥 4720 | 似 4721 | 蝽 4722 | 翳 4723 | 酪 4724 | 翠 4725 | 粑 4726 | 薇 4727 | 祢 4728 | 骏 4729 | 赠 4730 | 叫 4731 | Q 4732 | 噤 4733 | 噻 4734 | 竖 4735 | 芗 4736 | 莠 4737 | 潭 4738 | 俊 4739 | 羿 4740 | 耜 4741 | O 4742 | 郫 4743 | 趁 4744 | 嗪 4745 | 囚 4746 | 蹶 4747 | 芒 4748 | 洁 4749 | 笋 4750 | 鹑 4751 | 敲 4752 | 硝 4753 | 啶 4754 | 堡 4755 | 渲 4756 | 揩 4757 | 』 4758 | 携 4759 | 宿 4760 | 遒 4761 | 颍 4762 | 扭 4763 | 棱 4764 | 割 4765 | 萜 4766 | 蔸 4767 | 葵 4768 | 琴 4769 | 捂 4770 | 饰 4771 | 衙 4772 | 耿 4773 | 掠 4774 | 募 4775 | 岂 4776 | 窖 4777 | 涟 4778 | 蔺 4779 | 瘤 4780 | 柞 4781 | 瞪 4782 | 怜 4783 | 匹 4784 | 距 4785 | 楔 4786 | 炜 4787 | 哆 4788 | 秦 4789 | 缎 4790 | 幼 4791 | 茁 4792 | 绪 4793 | 痨 4794 | 恨 4795 | 楸 4796 | 娅 4797 | 瓦 4798 | 桩 4799 | 雪 4800 | 嬴 4801 | 伏 4802 | 榔 4803 | 妥 4804 | 铿 4805 | 拌 4806 | 眠 4807 | 雍 4808 | 缇 4809 | ‘ 4810 | 卓 4811 | 搓 4812 | 哌 4813 | 觞 4814 | 噩 4815 | 屈 4816 | 哧 4817 | 髓 4818 | 咦 4819 | 巅 4820 | 娑 4821 | 侑 4822 | 淫 4823 | 膳 4824 | 祝 4825 | 勾 4826 | 姊 4827 | 莴 4828 | 胄 4829 | 疃 4830 | 薛 4831 | 蜷 4832 | 胛 4833 | 巷 4834 | 芙 4835 | 芋 4836 | 熙 4837 | 闰 4838 | 勿 4839 | 窃 4840 | 狱 4841 | 剩 4842 | 钏 4843 | 幢 4844 | 陟 4845 | 铛 4846 | 慧 4847 | 靴 4848 | 耍 4849 | k 4850 | 浙 4851 | 浇 4852 | 飨 4853 | 惟 4854 | 绗 4855 | 祜 4856 | 澈 4857 | 啼 4858 | 咪 4859 | 磷 4860 | 摞 4861 | 诅 4862 | 郦 4863 | 抹 4864 | 跃 4865 | 壬 4866 | 吕 4867 | 肖 4868 | 琏 4869 | 颤 4870 | 尴 4871 | 剡 4872 | 抠 4873 | 凋 4874 | 赚 4875 | 泊 4876 | 津 4877 | 宕 4878 | 殷 4879 | 倔 4880 | 氲 4881 | 漫 4882 | 邺 4883 | 涎 4884 | 怠 4885 | $ 4886 | 垮 4887 | 荬 4888 | 遵 4889 | 俏 4890 | 叹 4891 | 噢 4892 | 饽 4893 | 蜘 4894 | 孙 4895 | 筵 4896 | 疼 4897 | 鞭 4898 | 羧 4899 | 牦 
4900 | 箭 4901 | 潴 4902 | c 4903 | 眸 4904 | 祭 4905 | 髯 4906 | 啖 4907 | 坳 4908 | 愁 4909 | 芩 4910 | 驮 4911 | 倡 4912 | 巽 4913 | 穰 4914 | 沃 4915 | 胚 4916 | 怒 4917 | 凤 4918 | 槛 4919 | 剂 4920 | 趵 4921 | 嫁 4922 | v 4923 | 邢 4924 | 灯 4925 | 鄢 4926 | 桐 4927 | 睽 4928 | 檗 4929 | 锯 4930 | 槟 4931 | 婷 4932 | 嵋 4933 | 圻 4934 | 诗 4935 | 蕈 4936 | 颠 4937 | 遭 4938 | 痢 4939 | 芸 4940 | 怯 4941 | 馥 4942 | 竭 4943 | 锗 4944 | 徜 4945 | 恭 4946 | 遍 4947 | 籁 4948 | 剑 4949 | 嘱 4950 | 苡 4951 | 龄 4952 | 僧 4953 | 桑 4954 | 潸 4955 | 弘 4956 | 澶 4957 | 楹 4958 | 悲 4959 | 讫 4960 | 愤 4961 | 腥 4962 | 悸 4963 | 谍 4964 | 椹 4965 | 呢 4966 | 桓 4967 | 葭 4968 | 攫 4969 | 阀 4970 | 翰 4971 | 躲 4972 | 敖 4973 | 柑 4974 | 郎 4975 | 笨 4976 | 橇 4977 | 呃 4978 | 魁 4979 | 燎 4980 | 脓 4981 | 葩 4982 | 磋 4983 | 垛 4984 | 玺 4985 | 狮 4986 | 沓 4987 | 砜 4988 | 蕊 4989 | 锺 4990 | 罹 4991 | 蕉 4992 | 翱 4993 | 虐 4994 | 闾 4995 | 巫 4996 | 旦 4997 | 茱 4998 | 嬷 4999 | 枯 5000 | 鹏 5001 | 贡 5002 | 芹 5003 | 汛 5004 | 矫 5005 | 绁 5006 | 拣 5007 | 禺 5008 | 佃 5009 | 讣 5010 | 舫 5011 | 惯 5012 | 乳 5013 | 趋 5014 | 疲 5015 | 挽 5016 | 岚 5017 | 虾 5018 | 衾 5019 | 蠹 5020 | 蹂 5021 | 飓 5022 | 氦 5023 | 铖 5024 | 孩 5025 | 稞 5026 | 瑜 5027 | 壅 5028 | 掀 5029 | 勘 5030 | 妓 5031 | 畅 5032 | 髋 5033 | W 5034 | 庐 5035 | 牲 5036 | 蓿 5037 | 榕 5038 | 练 5039 | 垣 5040 | 唱 5041 | 邸 5042 | 菲 5043 | 昆 5044 | 婺 5045 | 穿 5046 | 绡 5047 | 麒 5048 | 蚱 5049 | 掂 5050 | 愚 5051 | 泷 5052 | 涪 5053 | 漳 5054 | 妩 5055 | 娉 5056 | 榄 5057 | 讷 5058 | 觅 5059 | 旧 5060 | 藤 5061 | 煮 5062 | 呛 5063 | 柳 5064 | 腓 5065 | 叭 5066 | 庵 5067 | 烷 5068 | 阡 5069 | 罂 5070 | 蜕 5071 | 擂 5072 | 猖 5073 | 咿 5074 | 媲 5075 | 脉 5076 | 【 5077 | 沏 5078 | 貅 5079 | 黠 5080 | 熏 5081 | 哲 5082 | 烁 5083 | 坦 5084 | 酵 5085 | 兜 5086 | × 5087 | 潇 5088 | 撒 5089 | 剽 5090 | 珩 5091 | 圹 5092 | 乾 5093 | 摸 5094 | 樟 5095 | 帽 5096 | 嗒 5097 | 襄 5098 | 魂 5099 | 轿 5100 | 憬 5101 | 锡 5102 | 〕 5103 | 喃 5104 | 皆 5105 | 咖 5106 | 隅 5107 | 脸 5108 | 残 5109 | 泮 5110 | 袂 5111 | 鹂 5112 | 珊 5113 | 囤 5114 | 捆 5115 | 咤 5116 | 误 5117 | 徨 5118 | 闹 5119 | 淙 5120 | 芊 5121 | 淋 
5122 | 怆 5123 | 囗 5124 | 拨 5125 | 梳 5126 | 渤 5127 | R 5128 | G 5129 | 绨 5130 | 蚓 5131 | 婀 5132 | 幡 5133 | 狩 5134 | 麾 5135 | 谢 5136 | 唢 5137 | 裸 5138 | 旌 5139 | 伉 5140 | 纶 5141 | 裂 5142 | 驳 5143 | 砼 5144 | 咛 5145 | 澄 5146 | 樨 5147 | 蹈 5148 | 宙 5149 | 澍 5150 | 倍 5151 | 貔 5152 | 操 5153 | 勇 5154 | 蟠 5155 | 摈 5156 | 砧 5157 | 虬 5158 | 够 5159 | 缁 5160 | 悦 5161 | 藿 5162 | 撸 5163 | 艹 5164 | 摁 5165 | 淹 5166 | 豇 5167 | 虎 5168 | 榭 5169 | ˉ 5170 | 吱 5171 | d 5172 | ° 5173 | 喧 5174 | 荀 5175 | 踱 5176 | 侮 5177 | 奋 5178 | 偕 5179 | 饷 5180 | 犍 5181 | 惮 5182 | 坑 5183 | 璎 5184 | 徘 5185 | 宛 5186 | 妆 5187 | 袈 5188 | 倩 5189 | 窦 5190 | 昂 5191 | 荏 5192 | 乖 5193 | K 5194 | 怅 5195 | 撰 5196 | 鳙 5197 | 牙 5198 | 袁 5199 | 酞 5200 | X 5201 | 痿 5202 | 琼 5203 | 闸 5204 | 雁 5205 | 趾 5206 | 荚 5207 | 虻 5208 | 涝 5209 | 《 5210 | 杏 5211 | 韭 5212 | 偈 5213 | 烤 5214 | 绫 5215 | 鞘 5216 | 卉 5217 | 症 5218 | 遢 5219 | 蓥 5220 | 诋 5221 | 杭 5222 | 荨 5223 | 匆 5224 | 竣 5225 | 簪 5226 | 辙 5227 | 敕 5228 | 虞 5229 | 丹 5230 | 缭 5231 | 咩 5232 | 黟 5233 | m 5234 | 淤 5235 | 瑕 5236 | 咂 5237 | 铉 5238 | 硼 5239 | 茨 5240 | 嶂 5241 | 痒 5242 | 畸 5243 | 敬 5244 | 涿 5245 | 粪 5246 | 窘 5247 | 熟 5248 | 叔 5249 | 嫔 5250 | 盾 5251 | 忱 5252 | 裘 5253 | 憾 5254 | 梵 5255 | 赡 5256 | 珙 5257 | 咯 5258 | 娘 5259 | 庙 5260 | 溯 5261 | 胺 5262 | 葱 5263 | 痪 5264 | 摊 5265 | 荷 5266 | 卞 5267 | 乒 5268 | 髦 5269 | 寐 5270 | 铭 5271 | 坩 5272 | 胗 5273 | 枷 5274 | 爆 5275 | 溟 5276 | 嚼 5277 | 羚 5278 | 砬 5279 | 轨 5280 | 惊 5281 | 挠 5282 | 罄 5283 | 竽 5284 | 菏 5285 | 氧 5286 | 浅 5287 | 楣 5288 | 盼 5289 | 枢 5290 | 炸 5291 | 阆 5292 | 杯 5293 | 谏 5294 | 噬 5295 | 淇 5296 | 渺 5297 | 俪 5298 | 秆 5299 | 墓 5300 | 泪 5301 | 跻 5302 | 砌 5303 | 痰 5304 | 垡 5305 | 渡 5306 | 耽 5307 | 釜 5308 | 讶 5309 | 鳎 5310 | 煞 5311 | 呗 5312 | 韶 5313 | 舶 5314 | 绷 5315 | 鹳 5316 | 缜 5317 | 旷 5318 | 铊 5319 | 皱 5320 | 龌 5321 | 檀 5322 | 霖 5323 | 奄 5324 | 槐 5325 | 艳 5326 | 蝶 5327 | 旋 5328 | 哝 5329 | 赶 5330 | 骞 5331 | 蚧 5332 | 腊 5333 | 盈 5334 | 丁 5335 | ` 5336 | 蜚 5337 | 矸 5338 | 蝙 5339 | 睨 5340 | 嚓 5341 | 僻 5342 | 鬼 5343 | 醴 
5344 | 夜 5345 | 彝 5346 | 磊 5347 | 笔 5348 | 拔 5349 | 栀 5350 | 糕 5351 | 厦 5352 | 邰 5353 | 纫 5354 | 逭 5355 | 纤 5356 | 眦 5357 | 膊 5358 | 馍 5359 | 躇 5360 | 烯 5361 | 蘼 5362 | 冬 5363 | 诤 5364 | 暄 5365 | 骶 5366 | 哑 5367 | 瘠 5368 | 」 5369 | 臊 5370 | 丕 5371 | 愈 5372 | 咱 5373 | 螺 5374 | 擅 5375 | 跋 5376 | 搏 5377 | 硪 5378 | 谄 5379 | 笠 5380 | 淡 5381 | 嘿 5382 | 骅 5383 | 谧 5384 | 鼎 5385 | 皋 5386 | 姚 5387 | 歼 5388 | 蠢 5389 | 驼 5390 | 耳 5391 | 胬 5392 | 挝 5393 | 涯 5394 | 狗 5395 | 蒽 5396 | 孓 5397 | 犷 5398 | 凉 5399 | 芦 5400 | 箴 5401 | 铤 5402 | 孤 5403 | 嘛 5404 | 坤 5405 | V 5406 | 茴 5407 | 朦 5408 | 挞 5409 | 尖 5410 | 橙 5411 | 诞 5412 | 搴 5413 | 碇 5414 | 洵 5415 | 浚 5416 | 帚 5417 | 蜍 5418 | 漯 5419 | 柘 5420 | 嚎 5421 | 讽 5422 | 芭 5423 | 荤 5424 | 咻 5425 | 祠 5426 | 秉 5427 | 跖 5428 | 埃 5429 | 吓 5430 | 糯 5431 | 眷 5432 | 馒 5433 | 惹 5434 | 娼 5435 | 鲑 5436 | 嫩 5437 | 讴 5438 | 轮 5439 | 瞥 5440 | 靶 5441 | 褚 5442 | 乏 5443 | 缤 5444 | 宋 5445 | 帧 5446 | 删 5447 | 驱 5448 | 碎 5449 | 扑 5450 | 俩 5451 | 俄 5452 | 偏 5453 | 涣 5454 | 竹 5455 | 噱 5456 | 皙 5457 | 佰 5458 | 渚 5459 | 唧 5460 | 斡 5461 | # 5462 | 镉 5463 | 刀 5464 | 崎 5465 | 筐 5466 | 佣 5467 | 夭 5468 | 贰 5469 | 肴 5470 | 峙 5471 | 哔 5472 | 艿 5473 | 匐 5474 | 牺 5475 | 镛 5476 | 缘 5477 | 仡 5478 | 嫡 5479 | 劣 5480 | 枸 5481 | 堀 5482 | 梨 5483 | 簿 5484 | 鸭 5485 | 蒸 5486 | 亦 5487 | 稽 5488 | 浴 5489 | { 5490 | 衢 5491 | 束 5492 | 槲 5493 | j 5494 | 阁 5495 | 揍 5496 | 疥 5497 | 棋 5498 | 潋 5499 | 聪 5500 | 窜 5501 | 乓 5502 | 睛 5503 | 插 5504 | 冉 5505 | 阪 5506 | 苍 5507 | 搽 5508 | 「 5509 | 蟾 5510 | 螟 5511 | 幸 5512 | 仇 5513 | 樽 5514 | 撂 5515 | 慢 5516 | 跤 5517 | 幔 5518 | 俚 5519 | 淅 5520 | 覃 5521 | 觊 5522 | 溶 5523 | 妖 5524 | 帛 5525 | 侨 5526 | 曰 5527 | 妾 5528 | 泗 5529 | · 5530 | : 5531 | 瀘 5532 | 風 5533 | Ë 5534 | ( 5535 | ) 5536 | ∶ 5537 | 紅 5538 | 紗 5539 | 瑭 5540 | 雲 5541 | 頭 5542 | 鶏 5543 | 財 5544 | 許 5545 | • 5546 | ¥ 5547 | 樂 5548 | 焗 5549 | 麗 5550 | — 5551 | ; 5552 | 滙 5553 | 東 5554 | 榮 5555 | 繪 5556 | 興 5557 | … 5558 | 門 5559 | 業 5560 | π 5561 | 楊 5562 | 國 5563 | 顧 5564 | é 5565 | 盤 
5566 | 寳 5567 | Λ 5568 | 龍 5569 | 鳳 5570 | 島 5571 | 誌 5572 | 緣 5573 | 結 5574 | 銭 5575 | 萬 5576 | 勝 5577 | 祎 5578 | 璟 5579 | 優 5580 | 歡 5581 | 臨 5582 | 時 5583 | 購 5584 | = 5585 | ★ 5586 | 藍 5587 | 昇 5588 | 鐵 5589 | 觀 5590 | 勅 5591 | 農 5592 | 聲 5593 | 畫 5594 | 兿 5595 | 術 5596 | 發 5597 | 劉 5598 | 記 5599 | 專 5600 | 耑 5601 | 園 5602 | 書 5603 | 壴 5604 | 種 5605 | Ο 5606 | ● 5607 | 褀 5608 | 號 5609 | 銀 5610 | 匯 5611 | 敟 5612 | 锘 5613 | 葉 5614 | 橪 5615 | 廣 5616 | 進 5617 | 蒄 5618 | 鑽 5619 | 阝 5620 | 祙 5621 | 貢 5622 | 鍋 5623 | 豊 5624 | 夬 5625 | 喆 5626 | 團 5627 | 閣 5628 | 開 5629 | 燁 5630 | 賓 5631 | 館 5632 | 酡 5633 | 沔 5634 | 順 5635 | + 5636 | 硚 5637 | 劵 5638 | 饸 5639 | 陽 5640 | 車 5641 | 湓 5642 | 復 5643 | 萊 5644 | 氣 5645 | 軒 5646 | 華 5647 | 堃 5648 | 迮 5649 | 纟 5650 | 戶 5651 | 馬 5652 | 學 5653 | 裡 5654 | 電 5655 | 嶽 5656 | 獨 5657 | マ 5658 | シ 5659 | サ 5660 | ジ 5661 | 燘 5662 | 袪 5663 | 環 5664 | ❤ 5665 | 臺 5666 | 灣 5667 | 専 5668 | 賣 5669 | 孖 5670 | 聖 5671 | 攝 5672 | 線 5673 | ▪ 5674 | α 5675 | 傢 5676 | 俬 5677 | 夢 5678 | 達 5679 | 莊 5680 | 喬 5681 | 貝 5682 | 薩 5683 | 劍 5684 | 羅 5685 | 壓 5686 | 棛 5687 | 饦 5688 | 尃 5689 | 璈 5690 | 囍 5691 | 醫 5692 | G 5693 | I 5694 | A 5695 | # 5696 | N 5697 | 鷄 5698 | 髙 5699 | 嬰 5700 | 啓 5701 | 約 5702 | 隹 5703 | 潔 5704 | 賴 5705 | 藝 5706 | ~ 5707 | 寶 5708 | 籣 5709 | 麺 5710 |   5711 | 嶺 5712 | √ 5713 | 義 5714 | 網 5715 | 峩 5716 | 長 5717 | ∧ 5718 | 魚 5719 | 機 5720 | 構 5721 | ② 5722 | 鳯 5723 | 偉 5724 | L 5725 | B 5726 | 㙟 5727 | 畵 5728 | 鴿 5729 | ' 5730 | 詩 5731 | 溝 5732 | 嚞 5733 | 屌 5734 | 藔 5735 | 佧 5736 | 玥 5737 | 蘭 5738 | 織 5739 | 1 5740 | 3 5741 | 9 5742 | 0 5743 | 7 5744 | 點 5745 | 砭 5746 | 鴨 5747 | 鋪 5748 | 銘 5749 | 廳 5750 | 弍 5751 | ‧ 5752 | 創 5753 | 湯 5754 | 坶 5755 | ℃ 5756 | 卩 5757 | 骝 5758 | & 5759 | 烜 5760 | 荘 5761 | 當 5762 | 潤 5763 | 扞 5764 | 係 5765 | 懷 5766 | 碶 5767 | 钅 5768 | 蚨 5769 | 讠 5770 | ☆ 5771 | 叢 5772 | 爲 5773 | 埗 5774 | 涫 5775 | 塗 5776 | → 5777 | 楽 5778 | 現 5779 | 鯨 5780 | 愛 5781 | 瑪 5782 | 鈺 5783 | 忄 5784 | 悶 5785 | 藥 5786 | 飾 5787 | 樓 
5788 | 視 5789 | 孬 5790 | ㆍ 5791 | 燚 5792 | 苪 5793 | 師 5794 | ① 5795 | 丼 5796 | 锽 5797 | │ 5798 | 韓 5799 | 標 5800 | è 5801 | 兒 5802 | 閏 5803 | 匋 5804 | 張 5805 | 漢 5806 | Ü 5807 | 髪 5808 | 會 5809 | 閑 5810 | 檔 5811 | 習 5812 | 裝 5813 | の 5814 | 峯 5815 | 菘 5816 | 輝 5817 | И 5818 | 雞 5819 | 釣 5820 | 億 5821 | 浐 5822 | K 5823 | O 5824 | R 5825 | 8 5826 | H 5827 | E 5828 | P 5829 | T 5830 | W 5831 | D 5832 | S 5833 | C 5834 | M 5835 | F 5836 | 姌 5837 | 饹 5838 | » 5839 | 晞 5840 | 廰 5841 | ä 5842 | 嵯 5843 | 鷹 5844 | 負 5845 | 飲 5846 | 絲 5847 | 冚 5848 | 楗 5849 | 澤 5850 | 綫 5851 | 區 5852 | ❋ 5853 | ← 5854 | 質 5855 | 靑 5856 | 揚 5857 | ③ 5858 | 滬 5859 | 統 5860 | 産 5861 | 協 5862 | ﹑ 5863 | 乸 5864 | 畐 5865 | 經 5866 | 運 5867 | 際 5868 | 洺 5869 | 岽 5870 | 為 5871 | 粵 5872 | 諾 5873 | 崋 5874 | 豐 5875 | 碁 5876 | ɔ 5877 | V 5878 | 2 5879 | 6 5880 | 齋 5881 | 誠 5882 | 訂 5883 | ´ 5884 | 勑 5885 | 雙 5886 | 陳 5887 | 無 5888 | í 5889 | 泩 5890 | 媄 5891 | 夌 5892 | 刂 5893 | i 5894 | c 5895 | t 5896 | o 5897 | r 5898 | a 5899 | 嘢 5900 | 耄 5901 | 燴 5902 | 暃 5903 | 壽 5904 | 媽 5905 | 靈 5906 | 抻 5907 | 體 5908 | 唻 5909 | É 5910 | 冮 5911 | 甹 5912 | 鎮 5913 | 錦 5914 | ʌ 5915 | 蜛 5916 | 蠄 5917 | 尓 5918 | 駕 5919 | 戀 5920 | 飬 5921 | 逹 5922 | 倫 5923 | 貴 5924 | 極 5925 | Я 5926 | Й 5927 | 寬 5928 | 磚 5929 | 嶪 5930 | 郎 5931 | 職 5932 | | 5933 | 間 5934 | n 5935 | d 5936 | 剎 5937 | 伈 5938 | 課 5939 | 飛 5940 | 橋 5941 | 瘊 5942 | № 5943 | 譜 5944 | 骓 5945 | 圗 5946 | 滘 5947 | 縣 5948 | 粿 5949 | 咅 5950 | 養 5951 | 濤 5952 | 彳 5953 | ® 5954 | % 5955 | Ⅱ 5956 | 啰 5957 | 㴪 5958 | 見 5959 | 矞 5960 | 薬 5961 | 糁 5962 | 邨 5963 | 鲮 5964 | 顔 5965 | 罱 5966 | З 5967 | 選 5968 | 話 5969 | 贏 5970 | 氪 5971 | 俵 5972 | 競 5973 | 瑩 5974 | 繡 5975 | 枱 5976 | β 5977 | 綉 5978 | á 5979 | 獅 5980 | 爾 5981 | ™ 5982 | 麵 5983 | 戋 5984 | 淩 5985 | 徳 5986 | 個 5987 | 劇 5988 | 場 5989 | 務 5990 | 簡 5991 | 寵 5992 | h 5993 | 實 5994 | 膠 5995 | 轱 5996 | 圖 5997 | 築 5998 | 嘣 5999 | 樹 6000 | 㸃 6001 | 營 6002 | 耵 6003 | 孫 6004 | 饃 6005 | 鄺 6006 | 飯 6007 | 麯 6008 | 遠 6009 | 輸 
6010 | 坫 6011 | 孃 6012 | 乚 6013 | 閃 6014 | 鏢 6015 | ㎡ 6016 | 題 6017 | 廠 6018 | 關 6019 | ↑ 6020 | 爺 6021 | 將 6022 | 軍 6023 | 連 6024 | 篦 6025 | 覌 6026 | 參 6027 | 箸 6028 | - 6029 | 窠 6030 | 棽 6031 | 寕 6032 | 夀 6033 | 爰 6034 | 歐 6035 | 呙 6036 | 閥 6037 | 頡 6038 | 熱 6039 | 雎 6040 | 垟 6041 | 裟 6042 | 凬 6043 | 勁 6044 | 帑 6045 | 馕 6046 | 夆 6047 | 疌 6048 | 枼 6049 | 馮 6050 | 貨 6051 | 蒤 6052 | 樸 6053 | 彧 6054 | 旸 6055 | 靜 6056 | 龢 6057 | 暢 6058 | 㐱 6059 | 鳥 6060 | 珺 6061 | 鏡 6062 | 灡 6063 | 爭 6064 | 堷 6065 | 廚 6066 | Ó 6067 | 騰 6068 | 診 6069 | ┅ 6070 | 蘇 6071 | 褔 6072 | 凱 6073 | 頂 6074 | 豕 6075 | 亞 6076 | 帥 6077 | 嘬 6078 | ⊥ 6079 | 仺 6080 | 桖 6081 | 複 6082 | 饣 6083 | 絡 6084 | 穂 6085 | 顏 6086 | 棟 6087 | 納 6088 | ▏ 6089 | 濟 6090 | 親 6091 | 設 6092 | 計 6093 | 攵 6094 | 埌 6095 | 烺 6096 | ò 6097 | 頤 6098 | 燦 6099 | 蓮 6100 | 撻 6101 | 節 6102 | 講 6103 | 濱 6104 | 濃 6105 | 娽 6106 | 洳 6107 | 朿 6108 | 燈 6109 | 鈴 6110 | 護 6111 | 膚 6112 | 铔 6113 | 過 6114 | 補 6115 | Z 6116 | U 6117 | 5 6118 | 4 6119 | 坋 6120 | 闿 6121 | 䖝 6122 | 餘 6123 | 缐 6124 | 铞 6125 | 貿 6126 | 铪 6127 | 桼 6128 | 趙 6129 | 鍊 6130 | [ 6131 | 㐂 6132 | 垚 6133 | 菓 6134 | 揸 6135 | 捲 6136 | 鐘 6137 | 滏 6138 | 𣇉 6139 | 爍 6140 | 輪 6141 | 燜 6142 | 鴻 6143 | 鮮 6144 | 動 6145 | 鹞 6146 | 鷗 6147 | 丄 6148 | 慶 6149 | 鉌 6150 | 翥 6151 | 飮 6152 | 腸 6153 | ⇋ 6154 | 漁 6155 | 覺 6156 | 來 6157 | 熘 6158 | 昴 6159 | 翏 6160 | 鲱 6161 | 圧 6162 | 鄉 6163 | 萭 6164 | 頔 6165 | 爐 6166 | 嫚 6167 | г 6168 | 貭 6169 | 類 6170 | 聯 6171 | 幛 6172 | 輕 6173 | 訓 6174 | 鑒 6175 | 夋 6176 | 锨 6177 | 芃 6178 | 珣 6179 | 䝉 6180 | 扙 6181 | 嵐 6182 | 銷 6183 | 處 6184 | ㄱ 6185 | 語 6186 | 誘 6187 | 苝 6188 | 歸 6189 | 儀 6190 | 燒 6191 | 楿 6192 | 內 6193 | 粢 6194 | 葒 6195 | 奧 6196 | 麥 6197 | 礻 6198 | 滿 6199 | 蠔 6200 | 穵 6201 | 瞭 6202 | 態 6203 | 鱬 6204 | 榞 6205 | 硂 6206 | 鄭 6207 | 黃 6208 | 煙 6209 | 祐 6210 | 奓 6211 | 逺 6212 | * 6213 | 瑄 6214 | 獲 6215 | 聞 6216 | 薦 6217 | 讀 6218 | 這 6219 | 樣 6220 | 決 6221 | 問 6222 | 啟 6223 | 們 6224 | 執 6225 | 説 6226 | 轉 6227 | 單 6228 | 隨 6229 | 唘 6230 | 帶 6231 | 倉 
6232 | 庫 6233 | 還 6234 | 贈 6235 | 尙 6236 | 皺 6237 | ■ 6238 | 餅 6239 | 產 6240 | ○ 6241 | ∈ 6242 | 報 6243 | 狀 6244 | 楓 6245 | 賠 6246 | 琯 6247 | 嗮 6248 | 禮 6249 | ` 6250 | 傳 6251 | > 6252 | ≤ 6253 | 嗞 6254 | Φ 6255 | ≥ 6256 | 換 6257 | 咭 6258 | ∣ 6259 | ↓ 6260 | 曬 6261 | ε 6262 | 応 6263 | 寫 6264 | ″ 6265 | 終 6266 | 様 6267 | 純 6268 | 費 6269 | 療 6270 | 聨 6271 | 凍 6272 | 壐 6273 | 郵 6274 | ü 6275 | 黒 6276 | ∫ 6277 | 製 6278 | 塊 6279 | 調 6280 | 軽 6281 | 確 6282 | 撃 6283 | 級 6284 | 馴 6285 | Ⅲ 6286 | 涇 6287 | 繹 6288 | 數 6289 | 碼 6290 | 證 6291 | 狒 6292 | 処 6293 | 劑 6294 | < 6295 | 晧 6296 | 賀 6297 | 衆 6298 | ] 6299 | 櫥 6300 | 兩 6301 | 陰 6302 | 絶 6303 | 對 6304 | 鯉 6305 | 憶 6306 | ◎ 6307 | p 6308 | e 6309 | Y 6310 | 蕒 6311 | 煖 6312 | 頓 6313 | 測 6314 | 試 6315 | 鼽 6316 | 僑 6317 | 碩 6318 | 妝 6319 | 帯 6320 | ≈ 6321 | 鐡 6322 | 舖 6323 | 權 6324 | 喫 6325 | 倆 6326 | ˋ 6327 | 該 6328 | 悅 6329 | ā 6330 | 俫 6331 | . 6332 | f 6333 | s 6334 | b 6335 | m 6336 | k 6337 | g 6338 | u 6339 | j 6340 | 貼 6341 | 淨 6342 | 濕 6343 | 針 6344 | 適 6345 | 備 6346 | l 6347 | / 6348 | 給 6349 | 謢 6350 | 強 6351 | 觸 6352 | 衛 6353 | 與 6354 | ⊙ 6355 | $ 6356 | 緯 6357 | 變 6358 | ⑴ 6359 | ⑵ 6360 | ⑶ 6361 | ㎏ 6362 | 殺 6363 | ∩ 6364 | 幚 6365 | ─ 6366 | 價 6367 | ▲ 6368 | 離 6369 | ú 6370 | ó 6371 | 飄 6372 | 烏 6373 | 関 6374 | 閟 6375 | ﹝ 6376 | ﹞ 6377 | 邏 6378 | 輯 6379 | 鍵 6380 | 驗 6381 | 訣 6382 | 導 6383 | 歷 6384 | 屆 6385 | 層 6386 | ▼ 6387 | 儱 6388 | 錄 6389 | 熳 6390 | ē 6391 | 艦 6392 | 吋 6393 | 錶 6394 | 辧 6395 | 飼 6396 | 顯 6397 | ④ 6398 | 禦 6399 | 販 6400 | 気 6401 | 対 6402 | 枰 6403 | 閩 6404 | 紀 6405 | 幹 6406 | 瞓 6407 | 貊 6408 | 淚 6409 | △ 6410 | 眞 6411 | 墊 6412 | Ω 6413 | 獻 6414 | 褲 6415 | 縫 6416 | 緑 6417 | 亜 6418 | 鉅 6419 | 餠 6420 | { 6421 | } 6422 | ◆ 6423 | 蘆 6424 | 薈 6425 | █ 6426 | ◇ 6427 | 溫 6428 | 彈 6429 | 晳 6430 | 粧 6431 | 犸 6432 | 穩 6433 | 訊 6434 | 崬 6435 | 凖 6436 | 熥 6437 | П 6438 | 舊 6439 | 條 6440 | 紋 6441 | 圍 6442 | Ⅳ 6443 | 筆 6444 | 尷 6445 | 難 6446 | 雜 6447 | 錯 6448 | 綁 6449 | 識 6450 | 頰 6451 | 鎖 6452 | 艶 6453 | □ 
6454 | 殁 6455 | 殼 6456 | ⑧ 6457 | ├ 6458 | ▕ 6459 | 鵬 6460 | ǐ 6461 | ō 6462 | ǒ 6463 | 糝 6464 | 綱 6465 | ▎ 6466 | μ 6467 | 盜 6468 | 饅 6469 | 醬 6470 | 籤 6471 | 蓋 6472 | 釀 6473 | 鹽 6474 | 據 6475 | à 6476 | ɡ 6477 | 辦 6478 | ◥ 6479 | 彐 6480 | ┌ 6481 | 婦 6482 | 獸 6483 | 鲩 6484 | 伱 6485 | ī 6486 | 蒟 6487 | 蒻 6488 | 齊 6489 | 袆 6490 | 腦 6491 | 寧 6492 | 凈 6493 | 妳 6494 | 煥 6495 | 詢 6496 | 偽 6497 | 謹 6498 | 啫 6499 | 鯽 6500 | 騷 6501 | 鱸 6502 | 損 6503 | 傷 6504 | 鎻 6505 | 髮 6506 | 買 6507 | 冏 6508 | 儥 6509 | 両 6510 | ﹢ 6511 | ∞ 6512 | 載 6513 | 喰 6514 | z 6515 | 羙 6516 | 悵 6517 | 燙 6518 | 曉 6519 | 員 6520 | 組 6521 | 徹 6522 | 艷 6523 | 痠 6524 | 鋼 6525 | 鼙 6526 | 縮 6527 | 細 6528 | 嚒 6529 | 爯 6530 | ≠ 6531 | 維 6532 | " 6533 | 鱻 6534 | 壇 6535 | 厍 6536 | 帰 6537 | 浥 6538 | 犇 6539 | 薡 6540 | 軎 6541 | ² 6542 | 應 6543 | 醜 6544 | 刪 6545 | 緻 6546 | 鶴 6547 | 賜 6548 | 噁 6549 | 軌 6550 | 尨 6551 | 镔 6552 | 鷺 6553 | 槗 6554 | 彌 6555 | 葚 6556 | 濛 6557 | 請 6558 | 溇 6559 | 緹 6560 | 賢 6561 | 訪 6562 | 獴 6563 | 瑅 6564 | 資 6565 | 縤 6566 | 陣 6567 | 蕟 6568 | 栢 6569 | 韻 6570 | 祼 6571 | 恁 6572 | 伢 6573 | 謝 6574 | 劃 6575 | 涑 6576 | 總 6577 | 衖 6578 | 踺 6579 | 砋 6580 | 凉 6581 | 籃 6582 | 駿 6583 | 苼 6584 | 瘋 6585 | 昽 6586 | 紡 6587 | 驊 6588 | 腎 6589 | ﹗ 6590 | 響 6591 | 杋 6592 | 剛 6593 | 嚴 6594 | 禪 6595 | 歓 6596 | 槍 6597 | 傘 6598 | 檸 6599 | 檫 6600 | 炣 6601 | 勢 6602 | 鏜 6603 | 鎢 6604 | 銑 6605 | 尐 6606 | 減 6607 | 奪 6608 | 惡 6609 | θ 6610 | 僮 6611 | 婭 6612 | 臘 6613 | ū 6614 | ì 6615 | 殻 6616 | 鉄 6617 | ∑ 6618 | 蛲 6619 | 焼 6620 | 緖 6621 | 續 6622 | 紹 6623 | 懮 --------------------------------------------------------------------------------