├── docs ├── logo.png ├── cases.png ├── cnocr-wx.png └── cnstd-qq.jpg ├── examples ├── jd.jpg ├── 1_res.png ├── 2_res.png ├── life.jpg ├── beauty.png ├── beauty2.jpg ├── beauty3.jpg ├── mfd │ ├── en1.jpg │ ├── en2.jpg │ ├── zh.jpg │ ├── zh1.jpg │ ├── zh2.jpg │ ├── zh3.jpg │ ├── zh4.jpg │ ├── zh5.jpg │ ├── out-en2.jpg │ ├── out-zh4.jpg │ └── out-zh5.jpg ├── taobao.jpg ├── taobao2.jpg ├── taobao3.jpg ├── taobao4.jpg ├── taobao5.jpg ├── layout │ └── out-zh.jpg ├── train_config_gpu.json └── train_config.json ├── requirements.in ├── gpu.Makefile ├── tests ├── test_utils.py ├── test_transforms.py ├── test_models.py ├── test_cnstd.py ├── test_lr_schedulers.py └── test_rapidocr.py ├── Makefile ├── apps └── mfd │ └── anno.Makefile ├── cnstd ├── yolov7 │ ├── __init__.py │ ├── consts.py │ ├── yolov7-mfd.yaml │ ├── yolov7-tiny-layout.yaml │ ├── yolov7-tiny-mfd.yaml │ └── autoanchor.py ├── __version__.py ├── datasets │ ├── __init__.py │ └── util.py ├── transforms │ ├── __init__.py │ ├── random_crop.py │ ├── resize.py │ ├── base.py │ └── utils.py ├── __init__.py ├── utils │ ├── __init__.py │ ├── common_types.py │ ├── repr.py │ ├── geometry.py │ └── _utils.py ├── ppocr │ ├── __init__.py │ ├── postprocess │ │ ├── cls_postprocess.py │ │ ├── __init__.py │ │ └── db_postprocess.py │ ├── consts.py │ ├── opt_utils.py │ ├── angle_classifier.py │ └── rapid_detector.py ├── model │ ├── __init__.py │ └── fpn.py ├── app.py ├── yolo_detector.py ├── lr_scheduler.py ├── cn_std.py ├── trainer.py └── consts.py ├── .gitignore ├── scripts ├── detect_images.py ├── convert_label_studio_to_yolov7.py ├── generate_idx_file.py └── gen_label_studio_json.py ├── setup.py ├── requirements.txt └── RELEASE.md /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/docs/logo.png -------------------------------------------------------------------------------- /docs/cases.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/docs/cases.png -------------------------------------------------------------------------------- /examples/jd.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/jd.jpg -------------------------------------------------------------------------------- /docs/cnocr-wx.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/docs/cnocr-wx.png -------------------------------------------------------------------------------- /docs/cnstd-qq.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/docs/cnstd-qq.jpg -------------------------------------------------------------------------------- /examples/1_res.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/1_res.png -------------------------------------------------------------------------------- /examples/2_res.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/2_res.png -------------------------------------------------------------------------------- /examples/life.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/life.jpg -------------------------------------------------------------------------------- /examples/beauty.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/beauty.png -------------------------------------------------------------------------------- /examples/beauty2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/beauty2.jpg -------------------------------------------------------------------------------- /examples/beauty3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/beauty3.jpg -------------------------------------------------------------------------------- /examples/mfd/en1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/mfd/en1.jpg -------------------------------------------------------------------------------- /examples/mfd/en2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/mfd/en2.jpg -------------------------------------------------------------------------------- /examples/mfd/zh.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/mfd/zh.jpg -------------------------------------------------------------------------------- /examples/mfd/zh1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/mfd/zh1.jpg -------------------------------------------------------------------------------- /examples/mfd/zh2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/mfd/zh2.jpg -------------------------------------------------------------------------------- /examples/mfd/zh3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/mfd/zh3.jpg -------------------------------------------------------------------------------- /examples/mfd/zh4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/mfd/zh4.jpg -------------------------------------------------------------------------------- /examples/mfd/zh5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/mfd/zh5.jpg -------------------------------------------------------------------------------- /examples/taobao.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/taobao.jpg -------------------------------------------------------------------------------- /examples/taobao2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/taobao2.jpg 
-------------------------------------------------------------------------------- /examples/taobao3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/taobao3.jpg -------------------------------------------------------------------------------- /examples/taobao4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/taobao4.jpg -------------------------------------------------------------------------------- /examples/taobao5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/taobao5.jpg -------------------------------------------------------------------------------- /examples/mfd/out-en2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/mfd/out-en2.jpg -------------------------------------------------------------------------------- /examples/mfd/out-zh4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/mfd/out-zh4.jpg -------------------------------------------------------------------------------- /examples/mfd/out-zh5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/mfd/out-zh5.jpg -------------------------------------------------------------------------------- /examples/layout/out-zh.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/breezedeus/CnSTD/HEAD/examples/layout/out-zh.jpg -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | --index-url https://pypi.tuna.tsinghua.edu.cn/simple 2 | --extra-index-url https://pypi.org/simple 3 | 4 | click 5 | tqdm 6 | pyyaml 7 | unidecode 8 | torch>=1.8.0 9 | torchvision>=0.9.0 10 | numpy 11 | scipy 12 | pandas 13 | pytorch-lightning>=1.6.0 14 | pillow>=5.3.0 15 | opencv-python>=4.0.0 16 | shapely 17 | # Polygon3 18 | pyclipper 19 | matplotlib 20 | seaborn 21 | onnx 22 | onnxruntime 23 | huggingface_hub 24 | ultralytics 25 | rapidocr>=3.0 -------------------------------------------------------------------------------- /gpu.Makefile: -------------------------------------------------------------------------------- 1 | MODEL_NAME = db_resnet18 2 | 3 | train: 4 | cnstd train -m $(MODEL_NAME) --train-config-fp examples/train_config_gpu.json -i data 5 | 6 | predict: 7 | cnstd predict -m $(MODEL_NAME) --model_epoch 29 --rotated-bbox --box-score-thresh 0.3 --resized-shape 768,768 \ 8 | --context cuda:0 -i examples -o prediction 9 | 10 | package: 11 | python setup.py sdist bdist_wheel 12 | 13 | VERSION := $(shell sed -n "s/^__version__ = '\(.*\)'/\1/p" cnstd/__version__.py) 14 | upload: 15 | python -m twine upload dist/cnstd-$(VERSION)* --verbose 16 | 17 | 18 | .PHONY: train predict package upload -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from cnstd.utils.utils import sort_boxes 4 | 5 | 6 | def four_to_eight(box): 7 | x1, y1, x2, y2 = box 8 | return [[x1, y1], [x2, y1], [x2, y2], [x1, y2]] 9 | 10 | 11 | def test_sort_boxes(): 12 | # Some box coordinates used for testing 13 | boxes = [ 14 | [0, 2, 20, 18], 15 | [21, 1, 40, 19], 16 | [0, 20, 20, 40], 17 | [21, 20, 40, 40], 18 | ] 19 | boxes = [{'box': four_to_eight(box)} for box in boxes] 20 | out = sort_boxes(boxes, key='box') 21 | print(out) 22 | -------------------------------------------------------------------------------- /examples/train_config_gpu.json: -------------------------------------------------------------------------------- 1 | { 2 | "debug": false, 3 | "preserve_aspect_ratio": true, 4 | 5 | "vocab_fp": "label_cn.txt", 6 | "data_root_dir": "/home/ein/jinlong/std_data", 7 | 8 | "fpn_type": "fpn", 9 | "rotated_bbox": true, 10 | "resized_shape": [3, 768, 768], 11 | 12 | "gpus": [1], 13 | "epochs": 50, 14 | "batch_size": 16, 15 | "num_workers": 10, 16 | "pin_memory": true, 17 | "optimizer": "adam", 18 | "learning_rate": 1e-3, 19 | "weight_decay": 0, 20 | "lr_scheduler": { 21 | "name": "cos_warmup" 22 | }, 23 | "precision": 16, 24 | "limit_train_batches": 1.0, 25 | "limit_val_batches": 1.0, 26 | "pl_checkpoint_monitor": "iou_epoch", 27 | "pl_checkpoint_mode": "max" 28 | } 29 | -------------------------------------------------------------------------------- /examples/train_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "debug": false, 3 | "preserve_aspect_ratio": true, 4 | 5 | "vocab_fp": "label_cn.txt", 6 | "data_root_dir": "data", 7 | 8 | "fpn_type": "pan", 9 | "auto_rotate_whole_image": false, 10 | "rotated_bbox": true, 11 | "resized_shape": [3, 512, 512], 12 | 13 | "gpus": 0, 14 | "epochs": 2, 15 | "batch_size": 4, 16 | "num_workers": 0, 17 | "pin_memory": false, 18 | "optimizer": "adam", 19 | "learning_rate": 3e-3, 20 | "weight_decay": 0, 21 | "lr_scheduler": { 22 | "name": "cos_warmup", 23 | "min_lr_mult_factor": 0.1, 24 | "warmup_epochs": 0.1 25 | }, 26 | "precision": 32, 27 | "limit_train_batches": 1.0, 28 | "limit_val_batches": 1.0, 29 | "pl_checkpoint_monitor": "iou_epoch", 30 | "pl_checkpoint_mode": "max" 31 | } 32 |
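Both train_config.json and train_config_gpu.json above are plain JSON files passed to `cnstd train` via `--train-config-fp` (see the Makefiles below); note that JSON allows neither comments nor trailing commas. A minimal sketch, assuming the files live under examples/ as shipped, of loading one config and overriding a few fields for a quick local run:

import json

# Load the CPU training config and shrink it for a smoke test.
with open('examples/train_config.json') as f:
    config = json.load(f)

config['epochs'] = 1                     # shorten the run
config['batch_size'] = 2                 # fit into limited memory
config['resized_shape'] = [3, 384, 384]  # smaller inputs, faster epochs

# Save it so it can be passed as `cnstd train --train-config-fp train_config_local.json`.
with open('train_config_local.json', 'w') as f:
    json.dump(config, f, indent=2)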
-------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | MODEL_NAME = db_resnet18 2 | 3 | train: 4 | cnstd train -m $(MODEL_NAME) --train-config-fp examples/train_config.json -i data/icdar2015 5 | 6 | predict: 7 | cnstd predict -m $(MODEL_NAME) --model_epoch 29 --rotated-bbox --box-score-thresh 0.3 --resized-shape 768,768 \ 8 | --context cpu -i examples -o prediction 9 | 10 | layout: 11 | cnstd analyze -m layout --conf-thresh 0.25 --resized-shape 800 --img-fp examples/mfd/zh.jpg 12 | 13 | mfd: 14 | cnstd analyze -m mfd --conf-thresh 0.25 --resized-shape 700 --img-fp examples/mfd/zh4.jpg 15 | 16 | demo: 17 | pip install streamlit 18 | streamlit run cnstd/app.py 19 | 20 | package: 21 | rm -rf build 22 | python setup.py sdist bdist_wheel 23 | 24 | VERSION := $(shell sed -n "s/^__version__ = '\(.*\)'/\1/p" cnstd/__version__.py) 25 | upload: 26 | python -m twine upload dist/cnstd-$(VERSION)* --verbose 27 | 28 | 29 | .PHONY: train predict layout mfd demo package upload
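The Makefile's `predict` target drives detection through the CLI; the same flow is available from Python. A minimal sketch, mirroring the usage in tests/test_cnstd.py and scripts/detect_images.py (the output structure noted in the comments follows scripts/detect_images.py; the example image is one of the files under examples/):

from cnstd import CnStd
from cnstd.utils import imsave

# Same model name the Makefile uses; run on CPU here.
std = CnStd('db_resnet18', rotated_bbox=True, context='cpu')
std_out = std.detect(
    'examples/taobao.jpg',
    resized_shape=(768, 768),
    preserve_aspect_ratio=True,
    box_score_thresh=0.3,
)
# `std_out['detected_texts']` holds one entry per detected text box,
# each carrying a cropped image patch under 'cropped_img'.
for idx, info in enumerate(std_out['detected_texts']):
    imsave(info['cropped_img'], f'cropped-{idx}.jpg', normalized=False)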
-------------------------------------------------------------------------------- /apps/mfd/anno.Makefile: -------------------------------------------------------------------------------- 1 | MODEL_TYPE = 'yolov7' 2 | MODEL_FP = '/home/ein/.cnstd/1.2/analysis/mfd-yolov7-epoch224-20230613.pt' 3 | LABEL_STUDIO_LOCAL_FILES_DOCUMENT_ROOT = '/data/jinlong/std_data' 4 | INPUT_IMAGE_DIR = '/data/jinlong/std_data/call_images/images/2023-02-27_2023-03-05' 5 | 6 | # Generate a detection-result file (in JSON format); it can be imported into Label Studio to create tasks awaiting annotation 7 | predict: 8 | python scripts/gen_label_studio_json.py --model-type $(MODEL_TYPE) --model-fp $(MODEL_FP) \ 9 | --resized-shape 608 -l $(LABEL_STUDIO_LOCAL_FILES_DOCUMENT_ROOT) -i $(INPUT_IMAGE_DIR) -o 'prediction_results.json' 10 | 11 | # Convert the annotation results into the file format required for model training in breezedeus/yolov7 12 | convert_to_yolov7: 13 | python scripts/convert_label_studio_to_yolov7.py --anno-json-fp-list 'annotation.json' \ 14 | --index-prefix 'data/call_images/images/2023-02-27_2023-03-05' \ 15 | --out-labels-dir '/data/jinlong/std_data/call_images/labels/2023-02-27_2023-03-05' --out-index-fp 'train.txt' 16 | 17 | .PHONY: predict convert_to_yolov7 18 | -------------------------------------------------------------------------------- /cnstd/yolov7/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2022, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | -------------------------------------------------------------------------------- /cnstd/__version__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021-2023, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | __version__ = '1.2.6.1' 21 | -------------------------------------------------------------------------------- /cnstd/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements.
See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | from .dataset import StdDataset, StdDataModule 21 | -------------------------------------------------------------------------------- /tests/test_transforms.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision.transforms as T 6 | 7 | from cnstd.utils import ( 8 | set_logger, 9 | data_dir, 10 | load_model_params, 11 | imsave, 12 | imread, 13 | ) 14 | 15 | EXAMPLE_DIR = Path(__file__).parent.parent / 'examples' 16 | 17 | 18 | def test_transforms(): 19 | train_transform = T.Compose( # MUST NOT include `Resize` 20 | [ 21 | # T.RandomInvert(p=1.0), 22 | T.RandomPosterize(bits=4, p=1.0), 23 | T.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), 24 | # T.ColorJitter(brightness=0.3, contrast=0.2, saturation=0.2, hue=0.2), 25 | # T.RandomEqualize(p=0.3), 26 | # T.RandomApply([T.GaussianBlur(kernel_size=3)], p=0.5), 27 | ] 28 | ) 29 | img_fp = EXAMPLE_DIR / '1_res.png' 30 | img = imread(img_fp) 31 | img = train_transform(torch.from_numpy(img)) 32 | imsave(img.numpy().transpose((1, 2, 0)), 'test-transformed.png', normalized=False) 33 | -------------------------------------------------------------------------------- /cnstd/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | # Credits: adapted from https://github.com/mindee/doctr 20 | 21 | from .base import * 22 | from .resize import Resize 23 | from .random_crop import random_crop 24 | -------------------------------------------------------------------------------- /cnstd/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). 
3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | from .detector import Detector 21 | from .ppocr import PPDetector 22 | from .yolov7.layout_analyzer import LayoutAnalyzer, save_layout_img 23 | 24 | from .cn_std import CnStd 25 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import os 3 | import sys 4 | import pytest 5 | 6 | import torch 7 | 8 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 9 | sys.path.insert(1, os.path.dirname(os.path.abspath(__file__))) 10 | 11 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 12 | example_dir = os.path.join(root_dir, 'examples') 13 | 14 | from cnstd.consts import MODEL_CONFIGS 15 | from cnstd.model.dbnet import gen_dbnet 16 | 17 | 18 | def test_db_mobilenet(): 19 | model = gen_dbnet(MODEL_CONFIGS['db_mobilenet_v3'], pretrained=False, pretrained_backbone=False) 20 | input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) 21 | out = model(input_tensor) 22 | print(out.keys()) 23 | print(out['preds'][0][0].shape) 24 | 25 | 26 | def test_db_shufflenet(): 27 | model = gen_dbnet(MODEL_CONFIGS['db_shufflenet_v2'], pretrained=False, pretrained_backbone=False) 28 | input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) 29 | out = model(input_tensor) 30 | print(out.keys()) 31 | print(out['preds'][0][0].shape) 32 | -------------------------------------------------------------------------------- /cnstd/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
19 | # Credits: adapted from https://github.com/mindee/doctr 20 | 21 | from .geometry import * 22 | from .common_types import * 23 | from .metrics import * 24 | from .utils import * 25 | from ._utils import * 26 | -------------------------------------------------------------------------------- /tests/test_cnstd.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import os 3 | import sys 4 | import pytest 5 | from PIL import Image 6 | 7 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | sys.path.insert(1, os.path.dirname(os.path.abspath(__file__))) 9 | 10 | from cnstd import CnStd 11 | from cnstd.consts import AVAILABLE_MODELS 12 | 13 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 14 | example_dir = os.path.join(root_dir, 'examples') 15 | 16 | 17 | def test_ppocr_models(): 18 | model_name, model_backend = 'ch_PP-OCRv5_det_server', 'onnx' 19 | img_fp = os.path.join(example_dir, 'beauty2.jpg') 20 | std = CnStd(model_name, model_backend=model_backend, use_angle_clf=True) 21 | img = Image.open(img_fp) 22 | box_info_list = std.detect(img) 23 | print(len(box_info_list)) 24 | 25 | 26 | @pytest.mark.parametrize('model_name, model_backend', AVAILABLE_MODELS.all_models()) 27 | def test_cnstd(model_name, model_backend): 28 | img_fp = os.path.join(example_dir, 'beauty2.jpg') 29 | std = CnStd(model_name, model_backend=model_backend, rotated_bbox=False) 30 | box_info_list = std.detect(img_fp) 31 | print(len(box_info_list)) 32 | -------------------------------------------------------------------------------- /cnstd/ppocr/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2022, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | from ..consts import AVAILABLE_MODELS 21 | from .consts import MODEL_LABELS_FILE_DICT, PP_SPACE 22 | from .pp_detector import PPDetector 23 | from .rapid_detector import RapidDetector 24 | 25 | AVAILABLE_MODELS.register_models(MODEL_LABELS_FILE_DICT, space=PP_SPACE) 26 | -------------------------------------------------------------------------------- /cnstd/utils/common_types.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. 
The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | # Credits: adapted from https://github.com/mindee/doctr 20 | 21 | from typing import Tuple, List 22 | 23 | __all__ = ['Point2D', 'BoundingBox', 'RotatedBbox', 'Polygon4P', 'Polygon'] 24 | 25 | 26 | Point2D = Tuple[float, float] 27 | BoundingBox = Tuple[Point2D, Point2D] 28 | RotatedBbox = Tuple[float, float, float, float, float] 29 | Polygon4P = Tuple[Point2D, Point2D, Point2D, Point2D] 30 | Polygon = List[Point2D] 31 | -------------------------------------------------------------------------------- /cnstd/yolov7/consts.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2022, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | # Credits to: CDLA Dataset --- https://github.com/buptlihang/CDLA 20 | 21 | # Analysis 22 | CATEGORY_DICT = { 23 | 'layout': [ # Layout Analysis 24 | '_background_', 25 | 'Text', 26 | 'Title', 27 | 'Figure', 28 | 'Figure caption', 29 | 'Table', 30 | 'Table caption', 31 | 'Header', 32 | 'Footer', 33 | 'Reference', 34 | 'Equation', 35 | ], 36 | 'mfd': ['embedding', 'isolated'], # Mathematical Formula Detection 37 | } 38 | -------------------------------------------------------------------------------- /cnstd/ppocr/postprocess/cls_postprocess.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2022, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. 
You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | # Credits: adapted from https://github.com/PaddlePaddle/PaddleOCR 20 | 21 | class ClsPostProcess(object): 22 | """ Convert between text-label and text-index """ 23 | 24 | def __init__(self, label_list, **kwargs): 25 | super(ClsPostProcess, self).__init__() 26 | self.label_list = label_list 27 | 28 | def __call__(self, preds, label=None, *args, **kwargs): 29 | pred_idxs = preds.argmax(axis=1) 30 | decode_out = [(self.label_list[idx], preds[i, idx]) 31 | for i, idx in enumerate(pred_idxs)] 32 | if label is None: 33 | return decode_out 34 | label = [(self.label_list[idx], 1.0) for idx in label] 35 | return decode_out, label 36 | -------------------------------------------------------------------------------- /tests/test_lr_schedulers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.optim import lr_scheduler 4 | import matplotlib.pyplot as plt 5 | 6 | from cnstd.lr_scheduler import WarmupCosineAnnealingRestarts 7 | 8 | 9 | class NullModule(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | self.fc = nn.Linear(1, 1) 13 | 14 | 15 | ori_lr = 5e-4 16 | model = NullModule() 17 | optimizer = torch.optim.Adam(model.parameters()) 18 | 19 | 20 | def plot_lr(scheduler, step=900): 21 | lrs = [] 22 | for i in range(step): 23 | lr = optimizer.param_groups[0]['lr'] 24 | scheduler.step() 25 | lrs.append(lr) 26 | 27 | plt.plot(lrs) 28 | plt.show() 29 | 30 | 31 | def test_CosineAnnealingWarmRestarts(): 32 | CAW = lr_scheduler.CosineAnnealingWarmRestarts( 33 | optimizer, T_0=200, T_mult=1, eta_min=ori_lr / 10.0 34 | ) 35 | plot_lr(CAW, step=1000) 36 | 37 | 38 | def test_WarmupCosineAnnealingRestarts(): 39 | CAW = WarmupCosineAnnealingRestarts( 40 | optimizer, 41 | first_cycle_steps=95600, 42 | cycle_mult=1.0, 43 | max_lr=0.001, 44 | min_lr=0.0001, 45 | warmup_steps=100, 46 | gamma=1.0, 47 | ) 48 | plot_lr(CAW, step=95600) 49 | 50 | 51 | def test_CyclicLR(): 52 | Cyc = lr_scheduler.CyclicLR( 53 | optimizer, 54 | base_lr=ori_lr / 10.0, 55 | max_lr=ori_lr, 56 | step_size_up=200, 57 | cycle_momentum=False, 58 | ) 59 | 60 | plot_lr(Cyc, 1000) 61 | 62 | 63 | def test_OneCycleLR(): 64 | Cyc = lr_scheduler.OneCycleLR( 65 | optimizer, max_lr=0.1, epochs=20, steps_per_epoch=50, 66 | ) 67 | 68 | plot_lr(Cyc, 1000) 69 | -------------------------------------------------------------------------------- /cnstd/ppocr/postprocess/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2022, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. 
You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | # Credits: adapted from https://github.com/PaddlePaddle/PaddleOCR 20 | 21 | import copy 22 | 23 | __all__ = ['build_post_process'] 24 | 25 | from .db_postprocess import DBPostProcess, DistillationDBPostProcess 26 | from .cls_postprocess import ClsPostProcess 27 | 28 | 29 | def build_post_process(config, global_config=None): 30 | support_dict = [ 31 | 'DBPostProcess', 32 | 'ClsPostProcess', 33 | 'DistillationDBPostProcess', 34 | ] 35 | 36 | config = copy.deepcopy(config) 37 | module_name = config.pop('name') 38 | if module_name == "None": 39 | return 40 | if global_config is not None: 41 | config.update(global_config) 42 | assert module_name in support_dict, Exception( 43 | 'post process only support {}'.format(support_dict)) 44 | module_class = eval(module_name)(**config) 45 | return module_class 46 | -------------------------------------------------------------------------------- /cnstd/model/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
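# Usage sketch for the `gen_model` factory defined below (argument values here are
# illustrative, not prescribed; supported model names are the keys of MODEL_CONFIGS
# in cnstd/consts.py, e.g. the 'db_resnet18' used in the Makefiles):
#
#     model = gen_model(
#         'db_resnet18',
#         pretrained_backbone=True,
#         rotated_bbox=True,           # support non-horizontal boxes
#         input_shape=(3, 768, 768),   # [C, H, W] after resizing
#     )
#
# Unsupported names raise a KeyError, as enforced at the top of gen_model().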
19 | 20 | from copy import deepcopy 21 | 22 | from .dbnet import gen_dbnet, DBNet 23 | from ..consts import MODEL_CONFIGS 24 | 25 | 26 | def gen_model(model_name: str, pretrained_backbone: bool = True, **kwargs) -> DBNet: 27 | """ 28 | 29 | Args: 30 | model_name: 31 | pretrained_backbone: whether to use a pretrained backbone model 32 | **kwargs: 33 | 'rotated_bbox': bool, whether to support non-horizontal (rotated) boxes 34 | 'pretrained': bool, whether to use a pretrained model 35 | 'input_shape': Tuple[int, int, int], size of the input image after resizing: [C, H, W] 36 | 37 | Returns: a DBNet model 38 | 39 | """ 40 | if model_name not in MODEL_CONFIGS: 41 | raise KeyError('got unsupported model name: %s' % model_name) 42 | 43 | config = deepcopy(MODEL_CONFIGS[model_name]) 44 | config.update(**kwargs) 45 | return gen_dbnet( 46 | config, pretrained_backbone=pretrained_backbone, **kwargs 47 | ) 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | data/ 93 | debugs/ 94 | predictions/ 95 | output-* 96 | out-* 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | 111 | .vscode/ 112 | .idea/ 113 | result/ 114 | *.pyc 115 | -------------------------------------------------------------------------------- /cnstd/ppocr/consts.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2022, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License.
You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | 21 | MODEL_LABELS_FILE_DICT = { 22 | ('ch_PP-OCRv3_det', 'onnx'): { 23 | 'url': 'ch_PP-OCRv3_det_infer-onnx.zip', 24 | }, 25 | ('ch_PP-OCRv2_det', 'onnx'): { 26 | 'url': 'ch_PP-OCRv2_det_infer-onnx.zip', 27 | }, 28 | ('en_PP-OCRv3_det', 'onnx'): { 29 | 'url': 'en_PP-OCRv3_det_infer-onnx.zip', 30 | # 'detector': 'RapidDetector', 31 | # 'repo': 'breezedeus/cnstd-ppocr-en_PP-OCRv3_det', 32 | }, 33 | ('ch_PP-OCRv4_det', 'onnx'): { 34 | 'detector': 'RapidDetector', 35 | 'repo': 'breezedeus/cnstd-ppocr-ch_PP-OCRv4_det', 36 | }, 37 | ('ch_PP-OCRv4_det_server', 'onnx'): { 38 | 'detector': 'RapidDetector', 39 | 'repo': 'breezedeus/cnstd-ppocr-ch_PP-OCRv4_det_server', 40 | }, 41 | ('ch_PP-OCRv5_det', 'onnx'): { 42 | 'detector': 'RapidDetector', 43 | 'repo': 'breezedeus/cnstd-ppocr-ch_PP-OCRv5_det', 44 | }, 45 | ('ch_PP-OCRv5_det_server', 'onnx'): { 46 | 'detector': 'RapidDetector', 47 | 'repo': 'breezedeus/cnstd-ppocr-ch_PP-OCRv5_det_server', 48 | }, 49 | } 50 | 51 | PP_SPACE = 'ppocr' 52 | -------------------------------------------------------------------------------- /cnstd/ppocr/opt_utils.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2022, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
19 | # Credits: adapted from https://github.com/PaddlePaddle/PaddleOCR 20 | 21 | from .img_operators import * 22 | 23 | 24 | def transform(data, ops=None): 25 | """ transform """ 26 | if ops is None: 27 | ops = [] 28 | for op in ops: 29 | data = op(data) 30 | if data is None: 31 | return None 32 | return data 33 | 34 | 35 | def create_operators(op_param_list, global_config=None): 36 | """ 37 | create operators based on the config 38 | 39 | Args: 40 | params(list): a dict list, used to create some operators 41 | """ 42 | assert isinstance(op_param_list, list), ('operator config should be a list') 43 | ops = [] 44 | for operator in op_param_list: 45 | assert isinstance(operator, 46 | dict) and len(operator) == 1, "yaml format error" 47 | op_name = list(operator)[0] 48 | param = {} if operator[op_name] is None else operator[op_name] 49 | if global_config is not None: 50 | param.update(global_config) 51 | op = eval(op_name)(**param) 52 | ops.append(op) 53 | return ops 54 | -------------------------------------------------------------------------------- /scripts/detect_images.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os 4 | 5 | from cnstd import CnStd 6 | from cnstd.utils import imsave 7 | 8 | 9 | def read_idx_file(idx_fp): 10 | img_label_pairs = [] 11 | with open(idx_fp) as f: 12 | for line in f: 13 | img_fp, gt_fp = line.strip().split('\t') 14 | img_fp = img_fp.split('\\')[-1] 15 | img_label_pairs.append((img_fp, gt_fp)) 16 | return img_label_pairs 17 | 18 | 19 | def main(): 20 | root_data_dir = '/Users/king/Documents/beiye-Ein/语料/ocr/From-CnOCR-Users/ocr' 21 | index_fp = os.path.join(root_data_dir, 'train.tsv') 22 | image_dir = os.path.join(root_data_dir, 'train_pic') 23 | out_index_fp = open(os.path.join(root_data_dir, 'train_cleaned.tsv'), 'w') 24 | out_image_dir = os.path.join(root_data_dir, 'train_pic_cleaned') 25 | if not os.path.exists(out_image_dir): 26 | os.makedirs(out_image_dir) 27 | img_label_pairs = read_idx_file(index_fp) 28 | std_model_name = 'db_shufflenet_v2_small' 29 | std = CnStd( 30 | std_model_name, 31 | rotated_bbox=False, 32 | context='cpu', 33 | ) 34 | resized_shape = (384, 384) 35 | 36 | num_success = 0 37 | num_total = len(img_label_pairs) 38 | for idx, (img_fp, label) in enumerate(img_label_pairs): 39 | if idx % 100 == 0: 40 | print(f'{idx=}, {num_success=}') 41 | std_out = std.detect( 42 | os.path.join(image_dir, img_fp), 43 | resized_shape=resized_shape, 44 | preserve_aspect_ratio=True, 45 | box_score_thresh=0.3, 46 | ) 47 | # if img_fp == 'A_ISO_18.JPG': 48 | # breakpoint() 49 | if len(std_out['detected_texts']) != 1: 50 | continue 51 | cropped_img = std_out['detected_texts'][0]['cropped_img'] 52 | h, w = cropped_img.shape[:2] 53 | if w < 2.5 * h: 54 | continue 55 | 56 | imsave(cropped_img, os.path.join(out_image_dir, img_fp), normalized=False) 57 | out_index_fp.write(f'{img_fp}\t{label}\n') 58 | num_success += 1 59 | 60 | print(f'Totally, {num_total=}, {num_success=}.') 61 | out_index_fp.close() 62 | 63 | 64 | if __name__ == '__main__': 65 | main() 66 | -------------------------------------------------------------------------------- /cnstd/utils/repr.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. 
See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | # Credits: adapted from https://github.com/mindee/doctr 20 | 21 | __all__ = ['NestedObject'] 22 | 23 | 24 | def _addindent(s_, num_spaces): 25 | s = s_.split('\n') 26 | # don't do anything for single-line stuff 27 | if len(s) == 1: 28 | return s_ 29 | first = s.pop(0) 30 | s = [(num_spaces * ' ') + line for line in s] 31 | s = '\n'.join(s) 32 | s = first + '\n' + s 33 | return s 34 | 35 | 36 | class NestedObject: 37 | def extra_repr(self) -> str: 38 | return '' 39 | 40 | def __repr__(self): 41 | # We treat the extra repr like the sub-object, one item per line 42 | extra_lines = [] 43 | extra_repr = self.extra_repr() 44 | # empty string will be split into list [''] 45 | if extra_repr: 46 | extra_lines = extra_repr.split('\n') 47 | child_lines = [] 48 | if hasattr(self, '_children_names'): 49 | for key in self._children_names: 50 | child = getattr(self, key) 51 | if isinstance(child, list) and len(child) > 0: 52 | child_str = ",\n".join([repr(subchild) for subchild in child]) 53 | if len(child) > 1: 54 | child_str = _addindent(f"\n{child_str},", 2) + '\n' 55 | child_str = f"[{child_str}]" 56 | else: 57 | child_str = repr(child) 58 | child_str = _addindent(child_str, 2) 59 | child_lines.append('(' + key + '): ' + child_str) 60 | lines = extra_lines + child_lines 61 | 62 | main_str = self.__class__.__name__ + '(' 63 | if lines: 64 | # simple one-liner info, which most builtin Modules will use 65 | if len(extra_lines) == 1 and not child_lines: 66 | main_str += extra_lines[0] 67 | else: 68 | main_str += '\n ' + '\n '.join(lines) + '\n' 69 | 70 | main_str += ')' 71 | return main_str 72 | -------------------------------------------------------------------------------- /scripts/convert_label_studio_to_yolov7.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Convert JSON files exported from Label Studio into the format required for training the YoloV7 MFD model 3 | import os 4 | import json 5 | from argparse import ArgumentParser 6 | 7 | 8 | LABEL_MAPPINGS = {'embedding': 0, 'isolated': 1} 9 | 10 | 11 | def bboxs_to_file(bboxs, out_fp): 12 | with open(out_fp, 'w') as f: 13 | for bbox in bboxs: 14 | bbox_str = ' '.join(map(str, bbox)) 15 | f.write(f'{bbox_str}\n') 16 | 17 | 18 | def main(): 19 | parser = ArgumentParser() 20 | parser.add_argument( 21 | '--anno-json-fp-list', 22 | nargs='+', 23 | type=str, 24 | help='annotation JSON files exported from Label Studio; multiple files are allowed', 25 | ) 26 | parser.add_argument( 27 | '--index-prefix', type=str, default='images', help='prefix prepended to each image path in the output index file', 28 | ) 29 | parser.add_argument( 30 | '--out-labels-dir', 31 | type=str, 32 | default='labels', 33 | help='output directory for the labels; should follow the "**/**/labels/**" pattern', 34 | ) 35 | parser.add_argument( 36 | '--out-index-fp', type=str, default='train.txt', help='output annotation index file in the YoloV7 MFD format' 37 | ) 38 | args = parser.parse_args()
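    # Note on the conversion below: in a Label Studio export, each annotation 'value'
    # stores x/y/width/height as percentages (0-100) of the image size. The loop turns
    # every rectangle into an 8-value polygon [x0, y0, x1, y1, x2, y2, x3, y3] and
    # multiplies by 0.01, yielding the normalized [0, 1] coordinates that the YoloV7
    # label files written by `bboxs_to_file` expect.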
39 | # img_root_dir = os.path.join(args.out_root_dir, 'images') 40 | # label_root_dir = os.path.join(args.out_root_dir, 'labels') 41 | label_root_dir = args.out_labels_dir 42 | 43 | fp_list = [] 44 | for json_fp in args.anno_json_fp_list: 45 | ori_contents = json.load(open(json_fp)) 46 | for info in ori_contents: 47 | fp_url = info['data']['image'] 48 | remove_len = len(r'/data/local-files/?d=') 49 | fp = fp_url[remove_len:] 50 | annotations = info['annotations'][0].get('result', []) 51 | if not annotations: 52 | continue 53 | 54 | fn = os.path.basename(fp) 55 | # fp_list.append(os.path.join(img_root_dir, fp)) 56 | fp_list.append(os.path.join(args.index_prefix, fn)) 57 | 58 | # label_dir = os.path.join(label_root_dir, os.path.dirname(fp)) 59 | label_dir = label_root_dir 60 | if not os.path.exists(label_dir): 61 | os.makedirs(label_dir) 62 | label_fp = os.path.join(label_dir, fn.rsplit('.', maxsplit=1)[0] + '.txt') 63 | 64 | label_infos = [] 65 | for annotation in annotations: 66 | value = annotation['value'] 67 | x, y = value['x'], value['y'] 68 | width, height = value['width'], value['height'] 69 | # to [x0, y0, x1, y1, x2, y2, x3, y3] 70 | bbox = [x, y, x + width, y, x + width, y + height, x, y + height] 71 | bbox = [v * 0.01 for v in bbox] 72 | 73 | label = value['rectanglelabels'][0] 74 | label_id = LABEL_MAPPINGS[label] 75 | 76 | label_infos.append([label_id] + bbox) 77 | bboxs_to_file(label_infos, label_fp) 78 | 79 | with open(args.out_index_fp, 'w') as f: 80 | for fp in fp_list: 81 | f.write(fp + '\n') 82 | 83 | 84 | if __name__ == '__main__': 85 | main() 86 | -------------------------------------------------------------------------------- /cnstd/transforms/random_crop.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
19 | # Credits: adapted from https://github.com/MhLiao/DB 20 | 21 | from copy import deepcopy 22 | from typing import List, Dict, Any 23 | 24 | from PIL import Image 25 | import numpy as np 26 | 27 | 28 | def random_crop( 29 | image: Image.Image, 30 | boxes: List[Dict[str, Any]], 31 | max_tries, 32 | w_axis, 33 | h_axis, 34 | min_crop_side_ratio, 35 | ): 36 | """Randomly pick a crop box and keep only the image content inside it""" 37 | w, h = image.size 38 | selected_boxes = [] 39 | for i in range(max_tries): 40 | xx = np.random.choice(w_axis, size=2) 41 | xmin = np.min(xx) 42 | xmax = np.max(xx) 43 | xmin = np.clip(xmin, 0, w - 1) 44 | xmax = np.clip(xmax, 0, w - 1) 45 | yy = np.random.choice(h_axis, size=2) 46 | ymin = np.min(yy) 47 | ymax = np.max(yy) 48 | ymin = np.clip(ymin, 0, h - 1) 49 | ymax = np.clip(ymax, 0, h - 1) 50 | if ( 51 | xmax - xmin < min_crop_side_ratio * w 52 | or ymax - ymin < min_crop_side_ratio * h 53 | ): 54 | # area too small 55 | continue 56 | if len(boxes) != 0: 57 | selected_boxes = np.array( 58 | [ 59 | idx 60 | for idx, box in enumerate(boxes) 61 | if _in_area(box['poly'], xmin, xmax, ymin, ymax) 62 | ] 63 | ) 64 | if len([i for i in selected_boxes if boxes[i]['text'] != '###']) > 0: 65 | break 66 | else: 67 | selected_boxes = [] 68 | break 69 | if i == max_tries - 1: 70 | return image, boxes 71 | 72 | new_image = image.crop((xmin, ymin, xmax, ymax)) 73 | new_boxes = [] 74 | for i in selected_boxes: 75 | box = deepcopy(boxes[i]) 76 | box['poly'][:, 0] -= xmin 77 | box['poly'][:, 1] -= ymin 78 | new_boxes.append(box) 79 | return new_image, new_boxes 80 | 81 | 82 | def _in_area(box, xmin, xmax, ymin, ymax) -> bool: 83 | box_axis_in_area = ( 84 | (box[:, 0] >= xmin) 85 | & (box[:, 0] <= xmax) 86 | & (box[:, 1] >= ymin) 87 | & (box[:, 1] <= ymax) 88 | ) 89 | return np.sum(box_axis_in_area) == 4 90 | -------------------------------------------------------------------------------- /cnstd/transforms/resize.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License.
19 | # Credits: adapted from https://github.com/mindee/doctr 20 | 21 | import math 22 | import torch 23 | from torchvision.transforms import transforms as T 24 | from torchvision.transforms import functional as F 25 | from torch.nn.functional import pad 26 | from typing import Tuple 27 | 28 | 29 | class Resize(T.Resize): 30 | def __init__( 31 | self, 32 | size: Tuple[int, int], # [H, W] 33 | interpolation=F.InterpolationMode.BILINEAR, 34 | preserve_aspect_ratio: bool = False, 35 | symmetric_pad: bool = False, 36 | ) -> None: 37 | super().__init__(size, interpolation) 38 | self.preserve_aspect_ratio = preserve_aspect_ratio 39 | self.symmetric_pad = symmetric_pad 40 | 41 | def forward(self, img: torch.Tensor) -> torch.Tensor: 42 | """ 43 | 44 | Args: 45 | img: [C, H, W] 46 | 47 | Returns: 48 | 49 | """ 50 | target_ratio = self.size[0] / self.size[1] 51 | actual_ratio = img.shape[-2] / img.shape[-1] 52 | if not self.preserve_aspect_ratio or (target_ratio == actual_ratio): 53 | return super().forward(img) 54 | else: 55 | # Resize 56 | if actual_ratio > target_ratio: 57 | tmp_size = (self.size[0], int(self.size[0] / actual_ratio)) 58 | else: 59 | tmp_size = (int(self.size[1] * actual_ratio), self.size[1]) 60 | 61 | # Scale image 62 | if tuple(img.shape[1:]) != tmp_size: 63 | img = F.resize(img, tmp_size, self.interpolation) 64 | # Pad (inverted in pytorch) 65 | _pad = (0, self.size[1] - img.shape[-1], 0, self.size[0] - img.shape[-2]) 66 | if self.symmetric_pad: 67 | half_pad = (math.ceil(_pad[1] / 2), math.ceil(_pad[3] / 2)) 68 | _pad = (half_pad[0], _pad[1] - half_pad[0], half_pad[1], _pad[3] - half_pad[1]) 69 | return pad(img, _pad) 70 | 71 | def __repr__(self) -> str: 72 | interpolate_str = self.interpolation.value 73 | _repr = f"output_size={self.size}, interpolation='{interpolate_str}'" 74 | if self.preserve_aspect_ratio: 75 | _repr += f", preserve_aspect_ratio={self.preserve_aspect_ratio}, symmetric_pad={self.symmetric_pad}" 76 | return f"{self.__class__.__name__}({_repr})" 77 | -------------------------------------------------------------------------------- /tests/test_rapidocr.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
19 | 20 | import os 21 | import pytest 22 | import torch 23 | from pathlib import Path 24 | 25 | from rapidocr import RapidOCR, EngineType, LangDet, ModelType, OCRVersion, LangRec 26 | from rapidocr.utils import LoadImage 27 | from rapidocr.ch_ppocr_det import TextDetector 28 | 29 | from cnstd.utils import set_logger 30 | from cnstd.ppocr.rapid_detector import RapidDetector, Config 31 | 32 | logger = set_logger() 33 | 34 | 35 | def test_whole_pipeline(): 36 | engine = RapidOCR( 37 | params={ 38 | "Det.engine_type": EngineType.ONNXRUNTIME, 39 | "Det.lang_type": LangDet.CH, 40 | "Det.model_type": ModelType.SERVER, 41 | "Det.ocr_version": OCRVersion.PPOCRV5, 42 | "Rec.engine_type": EngineType.ONNXRUNTIME, 43 | "Rec.lang_type": LangRec.CH, 44 | "Rec.model_type": ModelType.SERVER, 45 | "Rec.ocr_version": OCRVersion.PPOCRV5, 46 | } 47 | ) 48 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 49 | example_dir = Path(root_dir) / "examples" 50 | img_path = example_dir / 'multi-line_cn1.png' 51 | result = engine(img_path) 52 | print(result) 53 | 54 | 55 | def test_det(): 56 | config = Config(Config.DEFAULT_CFG) 57 | engine = TextDetector(config) 58 | 59 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 60 | example_dir = Path(root_dir) / "docs" 61 | img_path = example_dir / "cnocr-wx.png" 62 | 63 | load_img = LoadImage() 64 | 65 | result = engine(load_img(img_path)) 66 | print(result) 67 | 68 | 69 | def test_rapid_detector(): 70 | # 测试直接指定模型文件路径 71 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 72 | model_fp = os.path.join(root_dir, "models", "ch_PP-OCRv4_det_infer.onnx") 73 | detector = RapidDetector( 74 | model_name="ch_PP-OCRv5_det", 75 | # model_fp=model_fp, 76 | ) 77 | 78 | example_dir = Path(root_dir) / "docs" 79 | img_path = example_dir / "cnocr-wx.png" 80 | 81 | result = detector.detect(img_path) 82 | print(result) 83 | assert isinstance(result, dict) 84 | assert "rotated_angle" in result 85 | assert "detected_texts" in result 86 | assert isinstance(result["detected_texts"], list) 87 | if len(result["detected_texts"]) > 0: 88 | box = result["detected_texts"][0] 89 | assert "box" in box 90 | assert "score" in box 91 | assert box["box"].shape == (4, 2) 92 | assert isinstance(box["score"], float) 93 | 94 | # 测试使用默认参数 95 | detector = RapidDetector() 96 | result = detector.detect(img_path) 97 | print(result) 98 | assert isinstance(result, dict) 99 | assert "rotated_angle" in result 100 | assert "detected_texts" in result 101 | 102 | # 测试错误的模型名称 103 | with pytest.raises(NotImplementedError): 104 | RapidDetector(model_name="invalid") 105 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # coding: utf-8 3 | # Copyright (C) 2021-2023, [Breezedeus](https://github.com/breezedeus). 4 | # Licensed to the Apache Software Foundation (ASF) under one 5 | # or more contributor license agreements. See the NOTICE file 6 | # distributed with this work for additional information 7 | # regarding copyright ownership. The ASF licenses this file 8 | # to you under the Apache License, Version 2.0 (the 9 | # "License"); you may not use this file except in compliance 10 | # with the License. 
You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, 15 | # software distributed under the License is distributed on an 16 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 17 | # KIND, either express or implied. See the License for the 18 | # specific language governing permissions and limitations 19 | # under the License. 20 | 21 | import os 22 | from setuptools import find_packages, setup 23 | from pathlib import Path 24 | 25 | PACKAGE_NAME = "cnstd" 26 | 27 | here = Path(__file__).parent 28 | 29 | long_description = (here / "README.md").read_text(encoding="utf-8") 30 | 31 | about = {} 32 | exec( 33 | (here / PACKAGE_NAME.replace('.', os.path.sep) / "__version__.py").read_text( 34 | encoding="utf-8" 35 | ), 36 | about, 37 | ) 38 | 39 | required = [ 40 | 'click', 41 | 'tqdm', 42 | 'pyyaml', 43 | 'unidecode', 44 | "torch>=1.8.0", 45 | "torchvision>=0.9.0", 46 | 'numpy', 47 | 'scipy', 48 | 'pandas', 49 | "pytorch-lightning", 50 | 'pillow>=5.3.0', 51 | 'opencv-python>=4.0.0', 52 | 'shapely', 53 | # 'Polygon3', 54 | 'pyclipper', 55 | 'matplotlib', 56 | 'seaborn', 57 | "onnx", 58 | "huggingface_hub", 59 | "ultralytics", 60 | "rapidocr>=3.0", 61 | ] 62 | 63 | extras_require = { 64 | "ort-cpu": ["onnxruntime"], 65 | "ort-gpu": ["onnxruntime-gpu"], 66 | "dev": ["pip-tools", "pytest"], 67 | } 68 | 69 | entry_points = """ 70 | [console_scripts] 71 | cnstd = cnstd.cli:cli 72 | """ 73 | 74 | setup( 75 | name=PACKAGE_NAME, 76 | version=about['__version__'], 77 | description="Python3 package for Chinese/English Scene Text Detection (STD), Mathematical Formula Detection (MFD), " 78 | "and Layout Analysis, with free pretrained models", 79 | long_description=long_description, 80 | long_description_content_type="text/markdown", 81 | author='breezedeus', 82 | author_email='breezedeus@163.com', 83 | license='Apache 2.0', 84 | url='https://github.com/breezedeus/cnstd', 85 | platforms=["Mac", "Linux", "Windows"], 86 | packages=find_packages(), 87 | entry_points=entry_points, 88 | include_package_data=True, 89 | data_files=[ 90 | ( 91 | '', 92 | [ 93 | 'cnstd/yolov7/yolov7-tiny-layout.yaml', 94 | 'cnstd/yolov7/yolov7-tiny-mfd.yaml', 95 | 'cnstd/yolov7/yolov7-mfd.yaml', 96 | ], 97 | ) 98 | ], 99 | install_requires=required, 100 | extras_require=extras_require, 101 | zip_safe=False, 102 | classifiers=[ 103 | 'Development Status :: 4 - Beta', 104 | 'Operating System :: OS Independent', 105 | 'Intended Audience :: Developers', 106 | 'License :: OSI Approved :: Apache Software License', 107 | 'Programming Language :: Python', 108 | 'Programming Language :: Python :: Implementation', 109 | 'Programming Language :: Python :: 3', 110 | 'Programming Language :: Python :: 3.7', 111 | 'Programming Language :: Python :: 3.8', 112 | 'Programming Language :: Python :: 3.9', 113 | 'Programming Language :: Python :: 3.10', 114 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 115 | ], 116 | ) 117 | -------------------------------------------------------------------------------- /cnstd/model/fpn.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. 
See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | # Credits: adapted from https://github.com/mindee/doctr 20 | 21 | import logging 22 | from typing import List 23 | 24 | import torch 25 | from torch import nn 26 | from torchvision.ops.deform_conv import DeformConv2d 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | class FeaturePyramidNetwork(nn.Module): 32 | def __init__( 33 | self, in_channels: List[int], out_channels: int, deform_conv: bool = False, 34 | ) -> None: 35 | 36 | super().__init__() 37 | 38 | out_chans = out_channels // len(in_channels) 39 | 40 | conv_layer = DeformConv2d if deform_conv else nn.Conv2d 41 | 42 | self.in_branches = nn.ModuleList( 43 | [ 44 | nn.Sequential( 45 | conv_layer(chans, out_channels, 1, bias=False), 46 | nn.BatchNorm2d(out_channels), 47 | nn.ReLU(inplace=True), 48 | ) 49 | for idx, chans in enumerate(in_channels) 50 | ] 51 | ) 52 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 53 | self.out_branches = nn.ModuleList( 54 | [ 55 | nn.Sequential( 56 | conv_layer(out_channels, out_chans, 3, padding=1, bias=False), 57 | nn.BatchNorm2d(out_chans), 58 | nn.ReLU(inplace=True), 59 | nn.Upsample( 60 | scale_factor=2 ** idx, mode='bilinear', align_corners=True 61 | ), 62 | ) 63 | for idx, chans in enumerate(in_channels) 64 | ] 65 | ) 66 | 67 | def forward(self, x: List[torch.Tensor]) -> torch.Tensor: 68 | if len(x) != len(self.out_branches): 69 | raise AssertionError 70 | # Conv1x1 to get the same number of channels 71 | _x: List[torch.Tensor] = [branch(t) for branch, t in zip(self.in_branches, x)] 72 | out: List[torch.Tensor] = self._merge(_x) 73 | 74 | # Conv and final upsampling 75 | out = [branch(t) for branch, t in zip(self.out_branches, out[::-1])] 76 | 77 | return torch.cat(out, dim=1) 78 | 79 | def _merge(self, _x: List[torch.Tensor]) -> List[torch.Tensor]: 80 | return self._merge_small_to_large(_x) 81 | 82 | def _merge_small_to_large(self, _x: List[torch.Tensor]) -> List[torch.Tensor]: 83 | out: List[torch.Tensor] = [_x[-1]] 84 | for t in _x[:-1][::-1]: 85 | out.append(self.upsample(out[-1]) + t) 86 | return out 87 | 88 | 89 | class PathAggregationNetwork(FeaturePyramidNetwork): 90 | """ 91 | 参考:https://github.dev/RangiLyu/nanodet 。 92 | This is an implementation of the `PAN in Path Aggregation Network 93 | ` . 
94 | """ 95 | def __init__( 96 | self, in_channels: List[int], out_channels: int, deform_conv: bool = False, 97 | ) -> None: 98 | super().__init__(in_channels, out_channels, deform_conv) 99 | self.downsample = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True) 100 | 101 | def _merge(self, _x: List[torch.Tensor]) -> List[torch.Tensor]: 102 | _x = self._merge_small_to_large(_x) 103 | return self._merge_large_to_small(_x) 104 | 105 | def _merge_large_to_small(self, _x: List[torch.Tensor]) -> List[torch.Tensor]: 106 | out = [v for v in _x] 107 | for i in range(len(out)-1, 1, -1): 108 | out[i-1] += self.downsample(out[i]) 109 | return out 110 | -------------------------------------------------------------------------------- /cnstd/transforms/base.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
19 | # Credits: adapted from https://github.com/mindee/doctr 20 | 21 | import random 22 | from typing import List, Any, Callable, Dict, Tuple 23 | import numpy as np 24 | 25 | from ..utils import normalize_img_array 26 | from ..utils.repr import NestedObject 27 | from .utils import invert_colors, rotate 28 | 29 | 30 | __all__ = ['NormalizeAug', 'ColorInversion', 'OneOf', 'RandomApply', 'RandomRotate'] 31 | 32 | 33 | class NormalizeAug(object): 34 | def __call__(self, img): 35 | return normalize_img_array(img) 36 | 37 | 38 | class ColorInversion(NestedObject): 39 | """Applies the following tranformation to a tensor (image or batch of images): 40 | convert to grayscale, colorize (shift 0-values randomly), and then invert colors 41 | 42 | Example:: 43 | >>> transfo = ColorInversion(min_val=0.6) 44 | >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) 45 | 46 | Args: 47 | min_val: range [min_val, 1] to colorize RGB pixels 48 | """ 49 | def __init__(self, min_val: float = 0.5) -> None: 50 | self.min_val = min_val 51 | 52 | def extra_repr(self) -> str: 53 | return f"min_val={self.min_val}" 54 | 55 | def __call__(self, img: Any) -> Any: 56 | return invert_colors(img, self.min_val) 57 | 58 | 59 | class OneOf(NestedObject): 60 | """Randomly apply one of the input transformations 61 | 62 | Example:: 63 | >>> transfo = OneOf([JpegQuality(), Gamma()]) 64 | >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) 65 | 66 | Args: 67 | transforms: list of transformations, one only will be picked 68 | """ 69 | 70 | _children_names: List[str] = ['transforms'] 71 | 72 | def __init__(self, transforms: List[Callable[[Any], Any]]) -> None: 73 | self.transforms = transforms 74 | 75 | def __call__(self, img: Any) -> Any: 76 | # Pick transformation 77 | transfo = self.transforms[int(random.random() * len(self.transforms))] 78 | # Apply 79 | return transfo(img) 80 | 81 | 82 | class RandomApply(NestedObject): 83 | """Apply with a probability p the input transformation 84 | 85 | Example:: 86 | >>> transfo = RandomApply(Gamma(), p=.5) 87 | >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) 88 | 89 | Args: 90 | transform: transformation to apply 91 | p: probability to apply 92 | """ 93 | def __init__(self, transform: Callable[[Any], Any], p: float = .5) -> None: 94 | self.transform = transform 95 | self.p = p 96 | 97 | def extra_repr(self) -> str: 98 | return f"transform={self.transform}, p={self.p}" 99 | 100 | def __call__(self, img: Any) -> Any: 101 | if random.random() < self.p: 102 | return self.transform(img) 103 | return img 104 | 105 | 106 | class RandomRotate(NestedObject): 107 | """Randomly rotate a tensor image 108 | 109 | Args: 110 | max_angle: maximum angle for rotation, in degrees. Angles will be uniformly picked in 111 | [-max_angle, max_angle] 112 | """ 113 | def __init__(self, max_angle: float = 25.) 
-> None: 114 | self.max_angle = max_angle 115 | 116 | def extra_repr(self) -> str: 117 | return f"max_angle={self.max_angle}" 118 | 119 | def __call__(self, img: Any, target: Dict[str, np.ndarray]) -> Tuple[Any, Dict[str, np.ndarray]]: 120 | angle = random.uniform(-self.max_angle, self.max_angle) 121 | img, target['boxes'] = rotate(img, target['boxes'], angle) 122 | return img, target 123 | -------------------------------------------------------------------------------- /cnstd/transforms/utils.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | # Credits: adapted from https://github.com/mindee/doctr 20 | 21 | import torch 22 | from torchvision.transforms import functional as F 23 | from copy import deepcopy 24 | import numpy as np 25 | from typing import Tuple 26 | from ..utils.geometry import rotate_boxes 27 | 28 | __all__ = ["invert_colors", "rotate", "crop_detection"] 29 | 30 | 31 | def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor: 32 | out = F.rgb_to_grayscale(img, num_output_channels=3) 33 | # Random RGB shift 34 | shift_shape = [img.shape[0], 3, 1, 1] if img.ndim == 4 else [3, 1, 1] 35 | rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape) 36 | # Inverse the color 37 | if out.dtype == torch.uint8: 38 | out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8) 39 | else: 40 | out = out * rgb_shift.to(dtype=out.dtype) 41 | # Inverse the color 42 | out = 255 - out if out.dtype == torch.uint8 else 1 - out 43 | return out 44 | 45 | 46 | def rotate( 47 | img: torch.Tensor, 48 | boxes: np.ndarray, 49 | angle: float, 50 | ) -> Tuple[torch.Tensor, np.ndarray]: 51 | """Rotate image around the center, interpolation=NEAREST, pad with 0 (black) 52 | 53 | Args: 54 | img: image to rotate 55 | boxes: array of boxes to rotate as well 56 | angle: angle in degrees. 
+: counter-clockwise, -: clockwise 57 | 58 | Returns: 59 | A tuple of rotated img (tensor), rotated boxes (np array) 60 | """ 61 | rotated_img = F.rotate(img, angle=angle, fill=0) # Interpolation NEAREST by default 62 | _boxes = deepcopy(boxes) 63 | if boxes.dtype == int: 64 | # Compute relative boxes 65 | _boxes = _boxes.astype(float) 66 | _boxes[:, [0, 2]] = _boxes[:, [0, 2]] / img.shape[2] 67 | _boxes[:, [1, 3]] = _boxes[:, [1, 3]] / img.shape[1] 68 | # Compute rotated bboxes: xmin, ymin, xmax, ymax --> x, y, w, h, alpha 69 | r_boxes = rotate_boxes(_boxes, angle=angle, min_angle=0) 70 | if boxes.dtype == int: 71 | # Back to absolute boxes 72 | r_boxes[:, [0, 2]] *= img.shape[2] 73 | r_boxes[:, [1, 3]] *= img.shape[1] 74 | return rotated_img, r_boxes 75 | 76 | 77 | def crop_detection( 78 | img: torch.Tensor, 79 | boxes: np.ndarray, 80 | crop_box: Tuple[int, int, int, int] 81 | ) -> Tuple[torch.Tensor, np.ndarray]: 82 | """Crop and image and associated bboxes 83 | 84 | Args: 85 | img: image to crop 86 | boxes: array of boxes to clip, absolute (int) or relative (float) 87 | crop_box: box (xmin, ymin, xmax, ymax) to crop the image. Absolute coords. 88 | 89 | Returns: 90 | A tuple of cropped image, cropped boxes, where the image is not resized. 91 | """ 92 | xmin, ymin, xmax, ymax = crop_box 93 | croped_img = F.crop( 94 | img, ymin, xmin, ymax - ymin, xmax - xmin 95 | ) 96 | if boxes.dtype == int: # absolute boxes 97 | # Clip boxes 98 | boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], xmin, xmax) 99 | boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], ymin, ymax) 100 | else: # relative boxes 101 | h, w = img.shape[-2:] 102 | # Clip boxes 103 | boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], xmin / w, xmax / w) 104 | boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], ymin / h, ymax / h) 105 | # Remove 0-sized boxes 106 | zero_height = boxes[:, 1] == boxes[:, 3] 107 | zero_width = boxes[:, 0] == boxes[:, 2] 108 | empty_boxes = np.logical_or(zero_height, zero_width) 109 | boxes = boxes[~empty_boxes] 110 | 111 | return croped_img, boxes 112 | -------------------------------------------------------------------------------- /scripts/generate_idx_file.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | from itertools import chain 5 | 6 | 7 | def generate_icdar2015_idx_pairs( 8 | img_dir, img_prefix_dir, label_prefix_dir, label_prefix_fn='' 9 | ): 10 | imglst = glob.glob1(img_dir, '*g') 11 | imglst.sort() 12 | idx_pairs = [] 13 | for img in imglst: 14 | img_fp = os.path.join(img_prefix_dir, img) 15 | name = img.rsplit('.', maxsplit=1)[0] 16 | label_fp = os.path.join(label_prefix_dir, label_prefix_fn + name + '.txt') 17 | idx_pairs.append((img_fp, label_fp)) 18 | 19 | print(f'{len(idx_pairs)} pairs are generated') 20 | return idx_pairs 21 | 22 | 23 | def generate_icpr_mtwi_2018_idx_pairs(img_dir, img_prefix_dir, label_prefix_dir): 24 | imglst = glob.glob1(img_dir, '*g') 25 | imglst.sort() 26 | idx_pairs = [] 27 | for img in imglst: 28 | img_fp = os.path.join(img_prefix_dir, img) 29 | name = img.rsplit('.', maxsplit=1)[0] 30 | label_fp = os.path.join(label_prefix_dir, name + '.txt') 31 | idx_pairs.append((img_fp, label_fp)) 32 | 33 | print(f'{len(idx_pairs)} pairs are generated') 34 | return idx_pairs 35 | 36 | 37 | def save_idx_file(idx_pairs, output_fp): 38 | with open(output_fp, 'w') as f: 39 | for one_pair in idx_pairs: 40 | f.write('\t'.join(one_pair) + '\n') 41 | 42 | 43 | def icdar2015(): 44 | img_dir = 
'data/icdar2015/train/images' 45 | img_prefix_dir = 'icdar2015/train/images' 46 | label_prefix_dir = 'icdar2015/train/gts' 47 | label_prefix_fn = 'gt_' 48 | idx_pairs = generate_icdar2015_idx_pairs( 49 | img_dir, img_prefix_dir, label_prefix_dir, label_prefix_fn 50 | ) 51 | save_idx_file(idx_pairs, 'data/icdar2015/train.tsv') 52 | 53 | 54 | def icpr_mtwi_2018(): 55 | for i in ('1000', '9000'): 56 | img_dir = '/home/ein/jinlong/std_data/ICPR-MTWI-2018/train/image_%s' % i 57 | img_prefix_dir = 'ICPR-MTWI-2018/train/image_%s' % i 58 | label_prefix_dir = 'ICPR-MTWI-2018/train/txt_%s' % i 59 | idx_pairs = generate_icpr_mtwi_2018_idx_pairs( 60 | img_dir, img_prefix_dir, label_prefix_dir 61 | ) 62 | save_idx_file( 63 | idx_pairs, '/home/ein/jinlong/std_data/ICPR-MTWI-2018/train_%s.tsv' % i 64 | ) 65 | 66 | 67 | def icdar_rctw_2017(): 68 | img_dir = '/home/ein/jinlong/std_data/ICDAR-RCTW-2017/train_images' 69 | img_prefix_dir = 'ICDAR-RCTW-2017/train_images' 70 | label_prefix_dir = 'ICDAR-RCTW-2017/train_gts' 71 | idx_pairs = generate_icpr_mtwi_2018_idx_pairs( 72 | img_dir, img_prefix_dir, label_prefix_dir 73 | ) 74 | save_idx_file(idx_pairs, '/home/ein/jinlong/std_data/ICDAR-RCTW-2017/train.tsv') 75 | 76 | 77 | def generate_icpr_2019_lstv_idx_pairs( 78 | label_json_fp, out_label_dir, img_prefix_dir, label_prefix_dir 79 | ): 80 | idx_pairs = [] 81 | labels = json.load(open(label_json_fp)) 82 | for fname, info_list in labels.items(): 83 | new_info = [] 84 | for box_info in info_list: 85 | if len(box_info['points']) != 4: 86 | continue 87 | box = list(chain(*box_info['points'])) 88 | 89 | text = box_info['transcription'] 90 | if box_info['illegibility']: 91 | text = '###' 92 | 93 | box.append(text) 94 | new_info.append(list(map(str, box))) 95 | 96 | if new_info: 97 | label_fn = f'{fname}.txt' 98 | label_fp = os.path.join(out_label_dir, label_fn) 99 | with open(label_fp, 'w') as f: 100 | for line in new_info: 101 | f.write(','.join(line) + '\n') 102 | 103 | img_fn = f'{fname}.jpg' 104 | img_fp = os.path.join(img_prefix_dir, img_fn) 105 | label_fp = os.path.join(label_prefix_dir, label_fn) 106 | idx_pairs.append((img_fp, label_fp)) 107 | return idx_pairs 108 | 109 | 110 | def icdar_2019_lstv(): 111 | label_json_fp = '/home/ein/jinlong/std_data/ICDAR2019-LSTV/train_full_labels.json' 112 | label_dir = '/home/ein/jinlong/std_data/ICDAR2019-LSTV/train_gts' 113 | img_prefix_dir = 'ICDAR2019-LSTV/train_images' 114 | label_prefix_dir = 'ICDAR2019-LSTV/train_gts' 115 | if not os.path.exists(label_dir): 116 | os.makedirs(label_dir) 117 | 118 | idx_pairs = generate_icpr_2019_lstv_idx_pairs( 119 | label_json_fp, label_dir, img_prefix_dir, label_prefix_dir 120 | ) 121 | print(f'{len(idx_pairs)} pairs are generated') 122 | save_idx_file(idx_pairs, '/home/ein/jinlong/std_data/ICDAR2019-LSTV/train.tsv') 123 | 124 | 125 | if __name__ == '__main__': 126 | # icpr_mtwi_2018() 127 | # icdar_rctw_2017() 128 | icdar_2019_lstv() 129 | -------------------------------------------------------------------------------- /cnstd/utils/geometry.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. 
The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | # Credits: adapted from https://github.com/mindee/doctr 20 | 21 | from typing import List, Union 22 | import numpy as np 23 | import cv2 24 | from .common_types import BoundingBox, Polygon4P, RotatedBbox 25 | 26 | __all__ = ['rbbox_to_polygon', 'bbox_to_polygon', 'polygon_to_bbox', 'polygon_to_rbbox', 27 | 'resolve_enclosing_bbox', 'resolve_enclosing_bbox', 'fit_rbbox', 'rotate_boxes'] 28 | 29 | 30 | def bbox_to_polygon(bbox: BoundingBox) -> Polygon4P: 31 | return bbox[0], (bbox[1][0], bbox[0][1]), (bbox[0][0], bbox[1][1]), bbox[1] 32 | 33 | 34 | def rbbox_to_polygon(rbbox: RotatedBbox) -> Polygon4P: 35 | (x, y, w, h, alpha) = rbbox 36 | return cv2.boxPoints(((float(x), float(y)), (float(w), float(h)), float(alpha))) 37 | 38 | 39 | def fit_rbbox(pts: np.ndarray) -> RotatedBbox: 40 | ((x, y), (w, h), alpha) = cv2.minAreaRect(pts) 41 | return x, y, w, h, alpha 42 | 43 | 44 | def polygon_to_bbox(polygon: Polygon4P) -> BoundingBox: 45 | x, y = zip(*polygon) 46 | return (min(x), min(y)), (max(x), max(y)) 47 | 48 | 49 | def polygon_to_rbbox(polygon: Polygon4P) -> RotatedBbox: 50 | cnt = np.array(polygon).reshape((-1, 1, 2)).astype(np.float32) 51 | return fit_rbbox(cnt) 52 | 53 | 54 | def resolve_enclosing_bbox(bboxes: Union[List[BoundingBox], np.ndarray]) -> Union[BoundingBox, np.ndarray]: 55 | """Compute enclosing bbox either from: 56 | 57 | - an array of boxes: (*, 5), where boxes have this shape: 58 | (xmin, ymin, xmax, ymax, score) 59 | 60 | - a list of BoundingBox 61 | 62 | Return a (1, 5) array (enclosing boxarray), or a BoundingBox 63 | """ 64 | if isinstance(bboxes, np.ndarray): 65 | xmin, ymin, xmax, ymax, score = np.split(bboxes, 5, axis=1) 66 | return np.array([xmin.min(), ymin.min(), xmax.max(), ymax.max(), score.mean()]) 67 | else: 68 | x, y = zip(*[point for box in bboxes for point in box]) 69 | return (min(x), min(y)), (max(x), max(y)) 70 | 71 | 72 | def resolve_enclosing_rbbox(rbboxes: List[RotatedBbox]) -> RotatedBbox: 73 | pts = np.asarray([pt for rbbox in rbboxes for pt in rbbox_to_polygon(rbbox)], np.float32) 74 | return fit_rbbox(pts) 75 | 76 | 77 | def rotate_boxes( 78 | boxes: np.ndarray, 79 | angle: float = 0., 80 | min_angle: float = 1. 81 | ) -> np.ndarray: 82 | """Rotate a batch of straight bounding boxes (xmin, ymin, xmax, ymax) of an angle, 83 | if angle > min_angle, around the center of the page. 84 | 85 | Args: 86 | boxes: (N, 4) array of RELATIVE boxes 87 | angle: angle between -90 and +90 degrees 88 | min_angle: minimum angle to rotate boxes 89 | 90 | Returns: 91 | A batch of rotated boxes (N, 5): (x, y, w, h, alpha) or a batch of straight bounding boxes 92 | """ 93 | # If small angle, return boxes (no rotation) 94 | if abs(angle) < min_angle or abs(angle) > 90 - min_angle: 95 | return boxes 96 | # Compute rotation matrix 97 | angle_rad = angle * np.pi / 180. 
# compute radian angle for np functions 98 | rotation_mat = np.array([ 99 | [np.cos(angle_rad), -np.sin(angle_rad)], 100 | [np.sin(angle_rad), np.cos(angle_rad)] 101 | ], dtype=boxes.dtype) 102 | # Compute unrotated boxes 103 | x_unrotated, y_unrotated = (boxes[:, 0] + boxes[:, 2]) / 2, (boxes[:, 1] + boxes[:, 3]) / 2 104 | width, height = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1] 105 | # Rotate centers 106 | centers = np.stack((x_unrotated, y_unrotated), axis=-1) 107 | rotated_centers = .5 + np.matmul(centers - .5, np.transpose(rotation_mat)) 108 | x_center, y_center = rotated_centers[:, 0], rotated_centers[:, 1] 109 | # Compute rotated boxes 110 | rotated_boxes = np.stack((x_center, y_center, width, height, angle * np.ones_like(boxes[:, 0])), axis=1) 111 | return rotated_boxes 112 | -------------------------------------------------------------------------------- /cnstd/app.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
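A small numeric sketch of `rotate_boxes` above (illustrative; the import path follows this repo's layout). Boxes are relative `(xmin, ymin, xmax, ymax)` coordinates; angles below `min_angle` leave them untouched, while larger angles return `(x, y, w, h, alpha)` boxes rotated around the page center:

import numpy as np

from cnstd.utils.geometry import rotate_boxes  # path as in this repo

boxes = np.array([[0.2, 0.2, 0.4, 0.3]])  # one relative box
same = rotate_boxes(boxes, angle=0.5)     # |0.5| < min_angle=1.0 -> returned as-is
rotated = rotate_boxes(boxes, angle=10.0)
print(rotated.shape)  # (1, 5): (x_center, y_center, w, h, alpha=10.0)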
19 | 20 | from collections import OrderedDict 21 | 22 | import numpy as np 23 | from PIL import Image 24 | import streamlit as st 25 | 26 | from cnstd import CnStd 27 | from cnstd.utils import plot_for_debugging, pil_to_numpy 28 | from cnstd.consts import AVAILABLE_MODELS as STD_MODELS 29 | 30 | try: 31 | from cnocr import CnOcr 32 | from cnocr.consts import AVAILABLE_MODELS 33 | 34 | cnocr_available = True 35 | except Exception: 36 | cnocr_available = False 37 | 38 | 39 | @st.cache(allow_output_mutation=True) 40 | def get_ocr_model(ocr_model_name): 41 | if not cnocr_available: 42 | return None 43 | model_name, model_backend = ocr_model_name 44 | return CnOcr(model_name, model_backend=model_backend) 45 | 46 | 47 | @st.cache(allow_output_mutation=True) 48 | def get_std_model(std_model_name, rotated_bbox, use_angle_clf): 49 | model_name, model_backend = std_model_name 50 | return CnStd( 51 | model_name, 52 | model_backend=model_backend, 53 | rotated_bbox=rotated_bbox, 54 | use_angle_clf=use_angle_clf, 55 | ) 56 | 57 | 58 | def visualize_std(img, std_out, box_score_thresh): 59 | img = pil_to_numpy(img).transpose((1, 2, 0)).astype(np.uint8) 60 | 61 | plot_for_debugging( 62 | img, std_out['detected_texts'], box_score_thresh, './streamlit-app' 63 | ) 64 | st.subheader('STD Result') 65 | st.image('./streamlit-app-result.png') 66 | # st.image('./streamlit-app-crops.png') 67 | 68 | 69 | def visualize_ocr(ocr, std_out): 70 | st.empty() 71 | st.subheader('OCR Result') 72 | ocr_res = OrderedDict({'文本': []}) 73 | ocr_res['概率值'] = [] 74 | for box_info in std_out['detected_texts']: 75 | cropped_img = box_info['cropped_img'] # 检测出的文本框 76 | try: 77 | ocr_out = ocr.ocr_for_single_line(cropped_img) 78 | prob, text = ocr_out[1], ocr_out[0] 79 | except: 80 | prob, text = 0.0, '' 81 | ocr_res['概率值'].append(prob) 82 | ocr_res['文本'].append(text) 83 | st.table(ocr_res) 84 | 85 | 86 | def main(): 87 | st.sidebar.header('CnStd 设置') 88 | models = list(STD_MODELS.all_models()) 89 | models.sort() 90 | std_model_name = st.sidebar.selectbox( 91 | '模型名称', models, index=models.index(('ch_PP-OCRv4_det', 'onnx')) 92 | ) 93 | rotated_bbox = st.sidebar.checkbox('是否检测带角度文本框', value=True) 94 | use_angle_clf = st.sidebar.checkbox('是否使用角度预测模型校正文本框', value=False) 95 | st.sidebar.subheader('resize 后图片(长边)大小') 96 | new_size = st.sidebar.slider('高宽尺寸', min_value=124, max_value=4096, value=768) 97 | st.sidebar.subheader('检测参数') 98 | box_score_thresh = st.sidebar.slider( 99 | '得分阈值(低于阈值的结果会被过滤掉)', min_value=0.05, max_value=0.95, value=0.3 100 | ) 101 | min_box_size = st.sidebar.slider( 102 | '框大小阈值(更小的文本框会被过滤掉)', min_value=4, max_value=50, value=10 103 | ) 104 | std = get_std_model(std_model_name, rotated_bbox, use_angle_clf) 105 | 106 | if cnocr_available: 107 | st.sidebar.markdown("""---""") 108 | st.sidebar.header('CnOcr 设置') 109 | all_models = list(AVAILABLE_MODELS.all_models()) 110 | all_models.sort() 111 | idx = all_models.index(('densenet_lite_136-fc', 'onnx')) 112 | ocr_model_name = st.sidebar.selectbox('选择模型', all_models, index=idx) 113 | ocr = get_ocr_model(ocr_model_name) 114 | 115 | st.markdown( 116 | '# 开源文本检测和识别工具 [CnStd](https://github.com/breezedeus/cnstd) 和 ' 117 | '[CnOcr](https://github.com/breezedeus/cnocr) 演示 Demo' 118 | ) 119 | st.subheader('选择待检测图片') 120 | content_file = st.file_uploader('', type=["png", "jpg", "jpeg", "webp"]) 121 | if content_file is None: 122 | st.stop() 123 | 124 | try: 125 | img = Image.open(content_file) 126 | 127 | std_out = std.detect( 128 | img, 129 | resized_shape=new_size, 130 | 
preserve_aspect_ratio=True, 131 | box_score_thresh=box_score_thresh, 132 | min_box_size=min_box_size, 133 | ) 134 | visualize_std(img, std_out, box_score_thresh) 135 | 136 | if cnocr_available: 137 | visualize_ocr(ocr, std_out) 138 | except Exception as e: 139 | st.error(e) 140 | 141 | 142 | if __name__ == '__main__': 143 | main() 144 | -------------------------------------------------------------------------------- /scripts/gen_label_studio_json.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # 生成检测结果(json格式)文件,这个文件可以导入到label studio中,生成待标注的任务 3 | from collections import OrderedDict 4 | from glob import glob 5 | import json 6 | from argparse import ArgumentParser 7 | import tqdm 8 | from pathlib import Path 9 | 10 | import cv2 11 | 12 | from cnstd import LayoutAnalyzer 13 | from cnstd.utils import read_img 14 | 15 | 16 | def to_json(total_width, total_height, box_type, x0, y0, w, h, _id): 17 | return { 18 | "original_width": total_width, 19 | "original_height": total_height, 20 | "image_rotation": 0, 21 | "value": { 22 | "x": x0, 23 | "y": y0, 24 | "width": w, 25 | "height": h, 26 | "rotation": 0, 27 | "rectanglelabels": [box_type], 28 | }, 29 | "id": str(_id), 30 | "from_name": "label", 31 | "to_name": "image", 32 | "type": "rectanglelabels", 33 | "origin": "manual", 34 | } 35 | 36 | 37 | def deduplicate_images(img_dir): 38 | # 对文件夹下的图片做去重 39 | def calculate_image_hash(image_path): 40 | # with open(img_fp, 'rb') as f: 41 | # image_data = f.read() 42 | # return hashlib.md5(image_data).hexdigest() 43 | image = cv2.imread(image_path) 44 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 45 | resized = cv2.resize(gray, (8, 8), interpolation=cv2.INTER_AREA) 46 | _, threshold = cv2.threshold( 47 | resized, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU 48 | ) 49 | 50 | return sum([2 ** i for (i, v) in enumerate(threshold.flatten()) if v]) 51 | 52 | img_fp_list = glob('{}/*g'.format(img_dir), recursive=True) 53 | print(f'{len(img_fp_list)} images found in {img_dir}') 54 | outs = OrderedDict() 55 | for img_fp in tqdm.tqdm(img_fp_list): 56 | img_hash = calculate_image_hash(img_fp) 57 | 58 | # 将特征值与文件名存储在字典中 59 | if img_hash not in outs: 60 | outs[img_hash] = img_fp 61 | print(f'{len(outs)} different images kept after deduplication') 62 | return list(outs.values()) 63 | 64 | 65 | def main(): 66 | parser = ArgumentParser() 67 | parser.add_argument( 68 | '-t', 69 | '--model-type', 70 | type=str, 71 | default='yolov7', 72 | help='模型类型。当前支持 [`yolov7_tiny`, `yolov7`]', 73 | ) 74 | parser.add_argument( 75 | '-p', 76 | '--model-fp', 77 | type=str, 78 | default='epoch_124-mfd.pt', 79 | help='使用训练好的模型。默认为 `None`,表示使用系统自带的预训练模型', 80 | ) 81 | parser.add_argument( 82 | "--resized-shape", type=int, default=608, help='分析时把图片resize到此大小再进行。默认为 `608`', 83 | ) 84 | parser.add_argument( 85 | '-l', 86 | '--local-file-doc-root-dir', 87 | type=str, 88 | required=True, 89 | help='这个路径对应 Label Studio 启动时使用的 LABEL_STUDIO_LOCAL_FILES_DOCUMENT_ROOT 值', 90 | ) 91 | parser.add_argument( 92 | '-i', '--img-dir', type=str, required=True, help='image directory' 93 | ) 94 | parser.add_argument( 95 | '-o', 96 | '--out-json-fp', 97 | type=str, 98 | default='prediction_results.json', 99 | help='output json file', 100 | ) 101 | args = parser.parse_args() 102 | img_dir = args.img_dir 103 | 104 | analyzer = LayoutAnalyzer( 105 | model_name='mfd', model_type=args.model_type, model_fp=args.model_fp 106 | ) 107 | 108 | img_fp_list = deduplicate_images(img_dir) 109 | 110 | total_json = 
[] 111 | num_boxes = 0 112 | for img_fp in tqdm.tqdm(img_fp_list): 113 | img0 = read_img(img_fp) 114 | width, height = img0.size 115 | out = analyzer.analyze(img0, resized_shape=args.resized_shape) 116 | 117 | results = [] 118 | for box_info in out: 119 | num_boxes += 1 120 | # box with 4 points to (x0, y0, w, h) 121 | box = box_info['box'] 122 | w = box[2][0] - box[0][0] 123 | h = box[2][1] - box[0][1] 124 | info = to_json( 125 | width, 126 | height, 127 | box_info['type'], 128 | 100 * box[0][0] / width, 129 | 100 * box[0][1] / height, 130 | 100 * w / width, 131 | 100 * h / height, 132 | num_boxes, 133 | ) 134 | results.append(info) 135 | 136 | predictions = [{"model_version": "one", "score": 0.5, "result": results}] 137 | # 这个路径要相对于Label Studio初始化时设置的 LABEL_STUDIO_LOCAL_FILES_DOCUMENT_ROOT 值, 138 | # 例如:如果图片绝对路径为 `/home/user1/images/1.jpg`,而 `LABEL_STUDIO_LOCAL_FILES_DOCUMENT_ROOT` 为 `/home/user1`, 139 | # 则下面字典中的 `image` 中对应的路径应该为 `image/1.jpg`, 140 | # 此时 `image` 应该取值为 `/data/local-files/?d=image/1.jpg` 。 141 | # 注:如果下面代码输出的文件路径有问题,改一下以下几行的逻辑就行 142 | local_fp = Path(img_fp).relative_to(args.local_file_doc_root_dir) 143 | data = { 144 | # "image": img_fp, 145 | "image": "/data/local-files/?d=" 146 | + str(local_fp), 147 | } 148 | total_json.append({"data": data, "predictions": predictions}) 149 | 150 | json.dump(total_json, open(args.out_json_fp, 'w'), indent=2, ensure_ascii=False) 151 | 152 | 153 | if __name__ == '__main__': 154 | main() 155 | -------------------------------------------------------------------------------- /cnstd/yolo_detector.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2022-2024, [Breezedeus](https://www.breezedeus.com). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | # YOLO Detector based on Ultralytics. 20 | 21 | from pathlib import Path 22 | from typing import Union, Optional, Any, List, Dict, Tuple 23 | import logging 24 | 25 | from PIL import Image 26 | import numpy as np 27 | from ultralytics import YOLO 28 | 29 | from .utils import sort_boxes, dedup_boxes, xyxy24p, select_device, expand_box_by_margin 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | class YoloDetector(object): 35 | def __init__( 36 | self, 37 | *, 38 | model_path: Optional[str] = None, 39 | device: Optional[str] = None, 40 | static_resized_shape: Optional[Union[int, Tuple[int, int]]] = None, 41 | **kwargs, 42 | ): 43 | """ 44 | YOLO Detector based on Ultralytics. 45 | Args: 46 | model_path (optional str): model path, default is None. 47 | device (optional str): device to use, default is None. 48 | static_resized_shape (optional int or tuple): static resized shape, default is None. 
49 | When it is not None, the input image will be resized to this shape before detection, 50 | ignoring the input parameter `resized_shape` if .detect() is called. 51 | Some format of models may require a fixed input size, such as CoreML. 52 | **kwargs (): other parameters. 53 | """ 54 | self.device = select_device(device) 55 | self.static_resized_shape = static_resized_shape 56 | self.model = YOLO(model_path, task='detect') 57 | 58 | def __call__(self, *args, **kwargs): 59 | """参考函数 `self.detect()` 。""" 60 | return self.detect(*args, **kwargs) 61 | 62 | def detect( 63 | self, 64 | img_list: Union[ 65 | str, 66 | Path, 67 | Image.Image, 68 | np.ndarray, 69 | List[Union[str, Path, Image.Image, np.ndarray]], 70 | ], 71 | resized_shape: int = 768, 72 | box_margin: int = 0, 73 | conf: float = 0.25, 74 | **kwargs, 75 | ) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]: 76 | """ 77 | 对指定图片(列表)进行目标检测。 78 | 79 | Args: 80 | img_list (str or list): 待识别图片或图片列表;如果是 `np.ndarray`,则应该是shape为 `[H, W, 3]` 的 RGB 格式数组 81 | resized_shape (int or tuple): (H, W); 把图片resize到此大小再做分析;默认值为 `700` 82 | box_margin (int): 对识别出的内容框往外扩展的像素大小;默认值为 `2` 83 | conf (float): 分数阈值;默认值为 `0.25` 84 | **kwargs (): 其他预测使用的参数,以及以下值 85 | - dedup_thrsh: 去重时使用的阈值;默认值为 `0.1` 86 | 87 | Returns: 一张图片的结果为一个list,其中每个元素表示识别出的版面中的一个元素,包含以下信息: 88 | * type: 版面元素对应的类型;可选值来自:`self.categories` ; 89 | * box: 版面元素对应的矩形框;np.ndarray, shape: (4, 2),对应 box 4个点的坐标值 (x, y) ; 90 | * score: 得分,越高表示越可信 。 91 | 92 | """ 93 | dedup_thrsh = kwargs.pop('dedup_thrsh') if 'dedup_thrsh' in kwargs else 0.1 94 | single = not isinstance(img_list, (list, tuple)) 95 | # Ultralytics 需要的 ndarray 是 HWC,BGR 格式 96 | if isinstance(img_list, np.ndarray): 97 | img_list = img_list[:, :, ::-1] 98 | elif isinstance(img_list, list): 99 | img_list = [ 100 | img[:, :, ::-1] if isinstance(img, np.ndarray) else img 101 | for img in img_list 102 | ] 103 | 104 | if self.static_resized_shape is not None: 105 | resized_shape = self.static_resized_shape 106 | batch_results = self.model.predict( 107 | img_list, imgsz=resized_shape, conf=conf, device=self.device, **kwargs 108 | ) 109 | outs = [] 110 | for res in batch_results: 111 | boxes = res.boxes.xyxy.cpu().numpy().tolist() 112 | scores = res.boxes.conf.cpu().numpy().tolist() 113 | labels = res.boxes.cls.cpu().int().numpy().tolist() 114 | categories = res.names 115 | height, width = res.orig_shape 116 | one_out = [] 117 | for box, score, label in zip(boxes, scores, labels): 118 | box = expand_box_by_margin(box, box_margin, (height, width)) 119 | box = xyxy24p(box, ret_type=np.array) 120 | one_out.append({'box': box, 'score': score, 'type': categories[label]}) 121 | 122 | one_out = sort_boxes(one_out, key='box') 123 | one_out = dedup_boxes(one_out, threshold=dedup_thrsh) 124 | outs.append(one_out) 125 | 126 | if single and len(outs) == 1: 127 | return outs[0] 128 | return outs 129 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.10 3 | # by the following command: 4 | # 5 | # pip-compile --output-file=requirements.txt requirements.in 6 | # 7 | --index-url https://pypi.tuna.tsinghua.edu.cn/simple 8 | --extra-index-url https://pypi.org/simple 9 | 10 | aiohappyeyeballs==2.4.3 11 | # via aiohttp 12 | aiohttp==3.11.7 13 | # via fsspec 14 | aiosignal==1.3.1 15 | # via aiohttp 16 | antlr4-python3-runtime==4.9.3 17 | # via omegaconf 18 | 
async-timeout==5.0.1 19 | # via aiohttp 20 | attrs==24.2.0 21 | # via aiohttp 22 | certifi==2024.8.30 23 | # via requests 24 | charset-normalizer==3.4.0 25 | # via requests 26 | click==8.1.7 27 | # via -r requirements.in 28 | coloredlogs==15.0.1 29 | # via onnxruntime 30 | colorlog==6.9.0 31 | # via rapidocr 32 | contourpy==1.3.0 33 | # via matplotlib 34 | cycler==0.12.1 35 | # via matplotlib 36 | filelock==3.16.1 37 | # via 38 | # huggingface-hub 39 | # torch 40 | flatbuffers==24.3.25 41 | # via onnxruntime 42 | fonttools==4.55.0 43 | # via matplotlib 44 | frozenlist==1.5.0 45 | # via 46 | # aiohttp 47 | # aiosignal 48 | fsspec[http]==2024.10.0 49 | # via 50 | # huggingface-hub 51 | # pytorch-lightning 52 | # torch 53 | huggingface-hub==0.26.2 54 | # via -r requirements.in 55 | humanfriendly==10.0 56 | # via coloredlogs 57 | idna==3.10 58 | # via 59 | # requests 60 | # yarl 61 | jinja2==3.1.4 62 | # via torch 63 | kiwisolver==1.4.7 64 | # via matplotlib 65 | lightning-utilities==0.11.9 66 | # via 67 | # pytorch-lightning 68 | # torchmetrics 69 | markupsafe==3.0.2 70 | # via jinja2 71 | matplotlib==3.9.2 72 | # via 73 | # -r requirements.in 74 | # seaborn 75 | # ultralytics 76 | mpmath==1.3.0 77 | # via sympy 78 | multidict==6.1.0 79 | # via 80 | # aiohttp 81 | # yarl 82 | networkx==3.2.1 83 | # via torch 84 | numpy==1.26.4 85 | # via 86 | # -r requirements.in 87 | # contourpy 88 | # matplotlib 89 | # onnx 90 | # onnxruntime 91 | # opencv-python 92 | # pandas 93 | # rapidocr 94 | # scipy 95 | # seaborn 96 | # shapely 97 | # torchmetrics 98 | # torchvision 99 | # ultralytics 100 | # ultralytics-thop 101 | omegaconf==2.3.0 102 | # via rapidocr 103 | onnx==1.17.0 104 | # via -r requirements.in 105 | onnxruntime==1.19.2 106 | # via -r requirements.in 107 | opencv-python==4.10.0.84 108 | # via 109 | # -r requirements.in 110 | # rapidocr 111 | # ultralytics 112 | packaging==24.2 113 | # via 114 | # huggingface-hub 115 | # lightning-utilities 116 | # matplotlib 117 | # onnxruntime 118 | # pytorch-lightning 119 | # torchmetrics 120 | pandas==2.2.3 121 | # via 122 | # -r requirements.in 123 | # seaborn 124 | # ultralytics 125 | pillow==11.0.0 126 | # via 127 | # -r requirements.in 128 | # matplotlib 129 | # rapidocr 130 | # torchvision 131 | # ultralytics 132 | propcache==0.2.0 133 | # via 134 | # aiohttp 135 | # yarl 136 | protobuf==5.28.3 137 | # via 138 | # onnx 139 | # onnxruntime 140 | psutil==6.1.0 141 | # via ultralytics 142 | py-cpuinfo==9.0.0 143 | # via ultralytics 144 | pyclipper==1.3.0.post6 145 | # via 146 | # -r requirements.in 147 | # rapidocr 148 | pyparsing==3.2.0 149 | # via matplotlib 150 | python-dateutil==2.9.0.post0 151 | # via 152 | # matplotlib 153 | # pandas 154 | pytorch-lightning==2.4.0 155 | # via -r requirements.in 156 | pytz==2024.2 157 | # via pandas 158 | pyyaml==6.0.2 159 | # via 160 | # -r requirements.in 161 | # huggingface-hub 162 | # omegaconf 163 | # pytorch-lightning 164 | # rapidocr 165 | # ultralytics 166 | rapidocr==3.2.0 167 | # via -r requirements.in 168 | requests==2.32.3 169 | # via 170 | # huggingface-hub 171 | # rapidocr 172 | # ultralytics 173 | scipy==1.13.1 174 | # via 175 | # -r requirements.in 176 | # ultralytics 177 | seaborn==0.13.2 178 | # via 179 | # -r requirements.in 180 | # ultralytics 181 | shapely==2.0.6 182 | # via 183 | # -r requirements.in 184 | # rapidocr 185 | six==1.16.0 186 | # via 187 | # python-dateutil 188 | # rapidocr 189 | sympy==1.13.1 190 | # via 191 | # onnxruntime 192 | # torch 193 | torch==2.5.1 194 | # via 195 | # -r 
requirements.in 196 | # pytorch-lightning 197 | # torchmetrics 198 | # torchvision 199 | # ultralytics 200 | # ultralytics-thop 201 | torchmetrics==1.6.0 202 | # via pytorch-lightning 203 | torchvision==0.20.1 204 | # via 205 | # -r requirements.in 206 | # ultralytics 207 | tqdm==4.67.0 208 | # via 209 | # -r requirements.in 210 | # huggingface-hub 211 | # pytorch-lightning 212 | # rapidocr 213 | # ultralytics 214 | typing-extensions==4.12.2 215 | # via 216 | # huggingface-hub 217 | # lightning-utilities 218 | # multidict 219 | # pytorch-lightning 220 | # torch 221 | tzdata==2024.2 222 | # via pandas 223 | ultralytics==8.3.36 224 | # via -r requirements.in 225 | ultralytics-thop==2.0.12 226 | # via ultralytics 227 | unidecode==1.3.8 228 | # via -r requirements.in 229 | urllib3==2.2.3 230 | # via requests 231 | yarl==1.18.0 232 | # via aiohttp 233 | 234 | # The following packages are considered to be unsafe in a requirements file: 235 | # setuptools 236 | -------------------------------------------------------------------------------- /cnstd/yolov7/yolov7-mfd.yaml: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2022, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
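The YOLOv7 model configs below are plain YAML and can be inspected programmatically; a minimal sketch (the file path assumes the repo root as the working directory). Note that the MFD anchors differ from the COCO defaults used in the layout config further down, reflecting the typical shapes of formula regions:

import yaml  # pyyaml, already listed in setup.py's dependencies

with open('cnstd/yolov7/yolov7-mfd.yaml') as f:
    cfg = yaml.safe_load(f)
print(cfg['nc'])          # 80 (number of classes)
print(cfg['anchors'][0])  # [15, 10, 30, 16, 43, 47]: the three P3/8 anchors (w, h)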
19 | # Credits to: https://github.com/WongKinYiu/yolov7, forked to https://github.com/breezedeus/yolov7 20 | 21 | # parameters 22 | nc: 80 # number of classes 23 | depth_multiple: 1.0 # model depth multiple 24 | width_multiple: 1.0 # layer channel multiple 25 | 26 | # anchors 27 | anchors: 28 | - [15,10, 30,16, 43,47] # P3/8, 3 anchors, wh 29 | - [61,20, 65,17, 119,49] # P4/16, 3 anchors, wh 30 | - [116,30, 216,80, 453,226] # P5/32, 3 anchors, wh 31 | 32 | # yolov7 backbone 33 | backbone: 34 | # [from, number, module, args] 35 | [[-1, 1, Conv, [32, 3, 1]], # 0 36 | 37 | [-1, 1, Conv, [64, 3, 2]], # 1-P1/2 38 | [-1, 1, Conv, [64, 3, 1]], 39 | 40 | [-1, 1, Conv, [128, 3, 2]], # 3-P2/4 41 | [-1, 1, Conv, [64, 1, 1]], 42 | [-2, 1, Conv, [64, 1, 1]], 43 | [-1, 1, Conv, [64, 3, 1]], 44 | [-1, 1, Conv, [64, 3, 1]], 45 | [-1, 1, Conv, [64, 3, 1]], 46 | [-1, 1, Conv, [64, 3, 1]], 47 | [[-1, -3, -5, -6], 1, Concat, [1]], 48 | [-1, 1, Conv, [256, 1, 1]], # 11 49 | 50 | [-1, 1, MP, []], 51 | [-1, 1, Conv, [128, 1, 1]], 52 | [-3, 1, Conv, [128, 1, 1]], 53 | [-1, 1, Conv, [128, 3, 2]], 54 | [[-1, -3], 1, Concat, [1]], # 16-P3/8 55 | [-1, 1, Conv, [128, 1, 1]], 56 | [-2, 1, Conv, [128, 1, 1]], 57 | [-1, 1, Conv, [128, 3, 1]], 58 | [-1, 1, Conv, [128, 3, 1]], 59 | [-1, 1, Conv, [128, 3, 1]], 60 | [-1, 1, Conv, [128, 3, 1]], 61 | [[-1, -3, -5, -6], 1, Concat, [1]], 62 | [-1, 1, Conv, [512, 1, 1]], # 24 63 | 64 | [-1, 1, MP, []], 65 | [-1, 1, Conv, [256, 1, 1]], 66 | [-3, 1, Conv, [256, 1, 1]], 67 | [-1, 1, Conv, [256, 3, 2]], 68 | [[-1, -3], 1, Concat, [1]], # 29-P4/16 69 | [-1, 1, Conv, [256, 1, 1]], 70 | [-2, 1, Conv, [256, 1, 1]], 71 | [-1, 1, Conv, [256, 3, 1]], 72 | [-1, 1, Conv, [256, 3, 1]], 73 | [-1, 1, Conv, [256, 3, 1]], 74 | [-1, 1, Conv, [256, 3, 1]], 75 | [[-1, -3, -5, -6], 1, Concat, [1]], 76 | [-1, 1, Conv, [1024, 1, 1]], # 37 77 | 78 | [-1, 1, MP, []], 79 | [-1, 1, Conv, [512, 1, 1]], 80 | [-3, 1, Conv, [512, 1, 1]], 81 | [-1, 1, Conv, [512, 3, 2]], 82 | [[-1, -3], 1, Concat, [1]], # 42-P5/32 83 | [-1, 1, Conv, [256, 1, 1]], 84 | [-2, 1, Conv, [256, 1, 1]], 85 | [-1, 1, Conv, [256, 3, 1]], 86 | [-1, 1, Conv, [256, 3, 1]], 87 | [-1, 1, Conv, [256, 3, 1]], 88 | [-1, 1, Conv, [256, 3, 1]], 89 | [[-1, -3, -5, -6], 1, Concat, [1]], 90 | [-1, 1, Conv, [1024, 1, 1]], # 50 91 | ] 92 | 93 | # yolov7 head 94 | head: 95 | [[-1, 1, SPPCSPC, [512]], # 51 96 | 97 | [-1, 1, Conv, [256, 1, 1]], 98 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 99 | [37, 1, Conv, [256, 1, 1]], # route backbone P4 100 | [[-1, -2], 1, Concat, [1]], 101 | 102 | [-1, 1, Conv, [256, 1, 1]], 103 | [-2, 1, Conv, [256, 1, 1]], 104 | [-1, 1, Conv, [128, 3, 1]], 105 | [-1, 1, Conv, [128, 3, 1]], 106 | [-1, 1, Conv, [128, 3, 1]], 107 | [-1, 1, Conv, [128, 3, 1]], 108 | [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], 109 | [-1, 1, Conv, [256, 1, 1]], # 63 110 | 111 | [-1, 1, Conv, [128, 1, 1]], 112 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 113 | [24, 1, Conv, [128, 1, 1]], # route backbone P3 114 | [[-1, -2], 1, Concat, [1]], 115 | 116 | [-1, 1, Conv, [128, 1, 1]], 117 | [-2, 1, Conv, [128, 1, 1]], 118 | [-1, 1, Conv, [64, 3, 1]], 119 | [-1, 1, Conv, [64, 3, 1]], 120 | [-1, 1, Conv, [64, 3, 1]], 121 | [-1, 1, Conv, [64, 3, 1]], 122 | [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], 123 | [-1, 1, Conv, [128, 1, 1]], # 75 124 | 125 | [-1, 1, MP, []], 126 | [-1, 1, Conv, [128, 1, 1]], 127 | [-3, 1, Conv, [128, 1, 1]], 128 | [-1, 1, Conv, [128, 3, 2]], 129 | [[-1, -3, 63], 1, Concat, [1]], 130 | 131 | [-1, 1, Conv, [256, 1, 1]], 132 | [-2, 1, 
Conv, [256, 1, 1]], 133 | [-1, 1, Conv, [128, 3, 1]], 134 | [-1, 1, Conv, [128, 3, 1]], 135 | [-1, 1, Conv, [128, 3, 1]], 136 | [-1, 1, Conv, [128, 3, 1]], 137 | [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], 138 | [-1, 1, Conv, [256, 1, 1]], # 88 139 | 140 | [-1, 1, MP, []], 141 | [-1, 1, Conv, [256, 1, 1]], 142 | [-3, 1, Conv, [256, 1, 1]], 143 | [-1, 1, Conv, [256, 3, 2]], 144 | [[-1, -3, 51], 1, Concat, [1]], 145 | 146 | [-1, 1, Conv, [512, 1, 1]], 147 | [-2, 1, Conv, [512, 1, 1]], 148 | [-1, 1, Conv, [256, 3, 1]], 149 | [-1, 1, Conv, [256, 3, 1]], 150 | [-1, 1, Conv, [256, 3, 1]], 151 | [-1, 1, Conv, [256, 3, 1]], 152 | [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], 153 | [-1, 1, Conv, [512, 1, 1]], # 101 154 | 155 | [75, 1, RepConv, [256, 3, 1]], 156 | [88, 1, RepConv, [512, 3, 1]], 157 | [101, 1, RepConv, [1024, 3, 1]], 158 | 159 | [[102,103,104], 1, IDetect, [nc, anchors]], # Detect(P3, P4, P5) 160 | ] 161 | -------------------------------------------------------------------------------- /cnstd/yolov7/yolov7-tiny-layout.yaml: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2022, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
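The tiny-layout config below is the `yolov7_tiny` variant of the detection models. For context, a usage sketch of the `LayoutAnalyzer` front end that consumes these models, mirroring how scripts/gen_label_studio_json.py calls it (the image path is illustrative):

from cnstd import LayoutAnalyzer
from cnstd.utils import read_img

analyzer = LayoutAnalyzer(model_name='mfd', model_type='yolov7_tiny')
img = read_img('examples/mfd/zh4.jpg')
out = analyzer.analyze(img, resized_shape=608)
for box_info in out:
    print(box_info['type'], box_info['box'])  # category name and 4x2 box corners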
19 | # Credits to: https://github.com/WongKinYiu/yolov7, forked to https://github.com/breezedeus/yolov7 20 | 21 | # parameters 22 | nc: 80 # number of classes 23 | depth_multiple: 1.0 # model depth multiple 24 | width_multiple: 1.0 # layer channel multiple 25 | 26 | # anchors 27 | anchors: 28 | - [10,13, 16,30, 33,23] # P3/8 29 | - [30,61, 62,45, 59,119] # P4/16 30 | - [116,90, 156,198, 373,326] # P5/32 31 | 32 | # yolov7-tiny backbone 33 | backbone: 34 | # [from, number, module, args] c2, k=1, s=1, p=None, g=1, act=True 35 | [[-1, 1, Conv, [32, 3, 2, None, 1, nn.LeakyReLU(0.1)]], # 0-P1/2 36 | 37 | [-1, 1, Conv, [64, 3, 2, None, 1, nn.LeakyReLU(0.1)]], # 1-P2/4 38 | 39 | [-1, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 40 | [-2, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 41 | [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 42 | [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 43 | [[-1, -2, -3, -4], 1, Concat, [1]], 44 | [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 7 45 | 46 | [-1, 1, MP, []], # 8-P3/8 47 | [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 48 | [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 49 | [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 50 | [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 51 | [[-1, -2, -3, -4], 1, Concat, [1]], 52 | [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 14 53 | 54 | [-1, 1, MP, []], # 15-P4/16 55 | [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 56 | [-2, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 57 | [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 58 | [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 59 | [[-1, -2, -3, -4], 1, Concat, [1]], 60 | [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 21 61 | 62 | [-1, 1, MP, []], # 22-P5/32 63 | [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 64 | [-2, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 65 | [-1, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 66 | [-1, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 67 | [[-1, -2, -3, -4], 1, Concat, [1]], 68 | [-1, 1, Conv, [512, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 28 69 | ] 70 | 71 | # yolov7-tiny head 72 | head: 73 | [[-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 74 | [-2, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 75 | [-1, 1, SP, [5]], 76 | [-2, 1, SP, [9]], 77 | [-3, 1, SP, [13]], 78 | [[-1, -2, -3, -4], 1, Concat, [1]], 79 | [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 80 | [[-1, -7], 1, Concat, [1]], 81 | [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 37 82 | 83 | [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 84 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 85 | [21, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # route backbone P4 86 | [[-1, -2], 1, Concat, [1]], 87 | 88 | [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 89 | [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 90 | [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 91 | [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 92 | [[-1, -2, -3, -4], 1, Concat, [1]], 93 | [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 47 94 | 95 | [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 96 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 97 | [14, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # route backbone P3 98 | [[-1, -2], 1, Concat, [1]], 99 | 100 | [-1, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 101 | [-2, 1, Conv, [32, 1, 1, None, 1, 
nn.LeakyReLU(0.1)]], 102 | [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 103 | [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 104 | [[-1, -2, -3, -4], 1, Concat, [1]], 105 | [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 57 106 | 107 | [-1, 1, Conv, [128, 3, 2, None, 1, nn.LeakyReLU(0.1)]], 108 | [[-1, 47], 1, Concat, [1]], 109 | 110 | [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 111 | [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 112 | [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 113 | [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 114 | [[-1, -2, -3, -4], 1, Concat, [1]], 115 | [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 65 116 | 117 | [-1, 1, Conv, [256, 3, 2, None, 1, nn.LeakyReLU(0.1)]], 118 | [[-1, 37], 1, Concat, [1]], 119 | 120 | [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 121 | [-2, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 122 | [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 123 | [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 124 | [[-1, -2, -3, -4], 1, Concat, [1]], 125 | [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 73 126 | 127 | [57, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 128 | [65, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 129 | [73, 1, Conv, [512, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 130 | 131 | [[74,75,76], 1, IDetect, [nc, anchors]], # Detect(P3, P4, P5) 132 | ] 133 | -------------------------------------------------------------------------------- /cnstd/yolov7/yolov7-tiny-mfd.yaml: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2022, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
19 | # Credits to: https://github.com/WongKinYiu/yolov7, forked to https://github.com/breezedeus/yolov7 20 | 21 | # parameters 22 | nc: 80 # number of classes 23 | depth_multiple: 1.0 # model depth multiple 24 | width_multiple: 1.0 # layer channel multiple 25 | 26 | # anchors 27 | anchors: 28 | - [15,10, 30,16, 43,47] # P3/8, 3 anchors, wh 29 | - [61,20, 65,17, 119,49] # P4/16, 3 anchors, wh 30 | - [116,30, 216,80, 453,226] # P5/32, 3 anchors, wh 31 | 32 | # yolov7-tiny backbone 33 | backbone: 34 | # [from, number, module, args] c2, k=1, s=1, p=None, g=1, act=True 35 | [[-1, 1, Conv, [32, 3, 2, None, 1, nn.LeakyReLU(0.1)]], # 0-P1/2 36 | 37 | [-1, 1, Conv, [64, 3, 2, None, 1, nn.LeakyReLU(0.1)]], # 1-P2/4 38 | 39 | [-1, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 40 | [-2, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 41 | [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 42 | [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 43 | [[-1, -2, -3, -4], 1, Concat, [1]], 44 | [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 7 45 | 46 | [-1, 1, MP, []], # 8-P3/8 47 | [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 48 | [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 49 | [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 50 | [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 51 | [[-1, -2, -3, -4], 1, Concat, [1]], 52 | [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 14 53 | 54 | [-1, 1, MP, []], # 15-P4/16 55 | [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 56 | [-2, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 57 | [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 58 | [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 59 | [[-1, -2, -3, -4], 1, Concat, [1]], 60 | [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 21 61 | 62 | [-1, 1, MP, []], # 22-P5/32 63 | [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 64 | [-2, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 65 | [-1, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 66 | [-1, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 67 | [[-1, -2, -3, -4], 1, Concat, [1]], 68 | [-1, 1, Conv, [512, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 28 69 | ] 70 | 71 | # yolov7-tiny head 72 | head: 73 | [[-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 74 | [-2, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 75 | [-1, 1, SP, [5]], 76 | [-2, 1, SP, [9]], 77 | [-3, 1, SP, [13]], 78 | [[-1, -2, -3, -4], 1, Concat, [1]], 79 | [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 80 | [[-1, -7], 1, Concat, [1]], 81 | [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 37 82 | 83 | [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 84 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 85 | [21, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # route backbone P4 86 | [[-1, -2], 1, Concat, [1]], 87 | 88 | [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 89 | [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 90 | [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 91 | [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]], 92 | [[-1, -2, -3, -4], 1, Concat, [1]], 93 | [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 47 94 | 95 | [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 96 | [-1, 1, nn.Upsample, [None, 2, 'nearest']], 97 | [14, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # route backbone P3 98 | [[-1, -2], 1, Concat, [1]], 99 | 100 | [-1, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]], 101 | [-2, 1, 
Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
102 |    [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
103 |    [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
104 |    [[-1, -2, -3, -4], 1, Concat, [1]],
105 |    [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 57
106 |
107 |    [-1, 1, Conv, [128, 3, 2, None, 1, nn.LeakyReLU(0.1)]],
108 |    [[-1, 47], 1, Concat, [1]],
109 |
110 |    [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
111 |    [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
112 |    [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
113 |    [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
114 |    [[-1, -2, -3, -4], 1, Concat, [1]],
115 |    [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 65
116 |
117 |    [-1, 1, Conv, [256, 3, 2, None, 1, nn.LeakyReLU(0.1)]],
118 |    [[-1, 37], 1, Concat, [1]],
119 |
120 |    [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
121 |    [-2, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
122 |    [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
123 |    [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
124 |    [[-1, -2, -3, -4], 1, Concat, [1]],
125 |    [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],  # 73
126 |
127 |    [57, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
128 |    [65, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
129 |    [73, 1, Conv, [512, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
130 |
131 |    [[74,75,76], 1, IDetect, [nc, anchors]],  # Detect(P3, P4, P5)
132 |   ]
133 |
-------------------------------------------------------------------------------- /RELEASE.md: --------------------------------------------------------------------------------

# Release Notes

## Update 2025.06.27: Release V1.2.6.1

Major Changes:

* Fixed known bugs

主要变更:

* 修复已知 bug

## Update 2025.06.25: Release V1.2.6

Major Changes:

* Integrated the latest PP-OCRv5 text detection functionality based on RapidOCR for even faster inference
* Added support for PP-OCRv5 detection models: `ch_PP-OCRv5_det` and `ch_PP-OCRv5_det_server`
* Fixed some known bugs

主要变更:

* 基于 RapidOCR 集成 PPOCRv5 最新版文本检测功能,提供更快的推理速度
* 新增支持 PP-OCRv5 检测模型:`ch_PP-OCRv5_det` 和 `ch_PP-OCRv5_det_server`
* 修复部分已知 bug

## Update 2024.12.08: Release V1.2.5.2

Bug Fixes:

* Fixed a compatibility issue with setting environment variables on Windows
* Use `subprocess.run` instead of `os.system` for better cross-platform support

问题修复:

* 修复在 Windows 系统下设置环境变量的兼容性问题
* 使用 subprocess.run 替代 os.system 以提供更好的跨平台支持

## Update 2024.11.30: Release V1.2.5.1

Major Changes:

* `en_PP-OCRv3_det` still uses the previous version and does not use `RapidDetector`

主要变更:

* en_PP-OCRv3_det 依旧使用之前的版本,不使用 RapidDetector

## Update 2024.11.24: Release V1.2.5

Major Changes:

* Integrated the latest PP-OCRv4 text detection functionality based on RapidOCR for faster inference
* Added support for PP-OCRv4 detection models, including standard and server versions
* Added support for the PP-OCRv3 English detection model
* Optimized model downloading, with support for domestic (CN) mirrors

主要变更:

* 基于 RapidOCR 集成 PPOCRv4 最新版文本检测功能,提供更快的推理速度
* 新增支持 PP-OCRv4 检测模型,包括标准版和服务器版
* 新增支持 PP-OCRv3 英文检测模型
* 优化模型下载功能,支持从国内镜像下载模型文件

## Update 2024.06.22: Release V1.2.4.2

Major Changes:

* Added a new parameter `static_resized_shape` when initializing `YoloDetector`, used to resize input images to a fixed size. Some model formats, such as `CoreML`, require fixed-size input images during inference. A usage sketch follows below.

主要变更:

* `YoloDetector` 初始化时加入了参数 `static_resized_shape`, 用于把输入图片 resize 为固定大小。某些格式的模型在推理时需要固定大小的输入图片,如 `CoreML`。
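A minimal sketch of the new parameter (an editor's illustration, not from the source tree: the shape value and image path are hypothetical, and the import path follows `cnstd/yolo_detector.py`):

```python
from cnstd.yolo_detector import YoloDetector

# Hypothetical fixed input size, e.g. for a CoreML export that
# requires static input shapes at inference time.
detector = YoloDetector(static_resized_shape=768)
out = detector.detect('examples/mfd/zh4.jpg')  # doubly nested list; see V1.2.4.1 below
```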
## Update 2024.06.17: Release V1.2.4.1

Major Changes:

* Fixed a bug in the `detect` method of `YoloDetector`: when the input is a single file, the output was not a doubly nested list.

主要变更:

* 修复了 `YoloDetector` 中 `detect` 方法的一个bug:输入为单个文件时,输出不是双层嵌套的 list。

## Update 2024.06.16: Release V1.2.4

Major Changes:

* Support for the YOLO Detector based on Ultralytics.

主要变更:

* 支持基于 Ultralytics 的 YOLO Detector。

## Update 2024.04.10: Release V1.2.3.6

Major Changes:

* The CN OSS became unavailable; the default model download location changed from `CN` to `HF`.

## Update 2023.10.09: Release V1.2.3.5

Major Changes:

* The model download location can now be selected via the environment variable `CNSTD_DOWNLOAD_SOURCE`.
* Added the parameters `model_categories` and `model_arch_yaml` to `LayoutAnalyzer`, for specifying the model's category names and its architecture YAML.

## Update 2023.09.23: Release V1.2.3.4

Major Changes:

* Added compatibility with newer versions of `onnxruntime` (ORT): the `providers` argument is now passed explicitly to `InferenceSession`.
* Removed the `onnxruntime` dependency from `setup.py`; it is now declared on demand via `extras_require`:
  * `cnstd[ort-cpu]`: `onnxruntime`;
  * `cnstd[ort-gpu]`: `onnxruntime-gpu`.

## Update 2023.09.21: Release V1.2.3.3

Major Changes:

* Drawing now prefers a fixed group of colors.
* The environment variable `HF_TOKEN` can be set when downloading models, enabling downloads from private repos.

## Update 2023.07.02: Release V1.2.3.2

Major Changes:

* Fixed a bug in the handling of the `device` argument. Thanks to @Shadow-Alex.

## Update 2023.06.30: Release V1.2.3.1

Major Changes:

* Fixed an issue where detection boxes could fall outside the image after ratio conversion.

## Update 2023.06.30: Release V1.2.3

Major Changes:

* Fixed automatic downloading of model files. HuggingFace appears to have changed its download logic, which broke auto-download in earlier versions; this release fixes it. Since HuggingFace is blocked in mainland China, downloading from there still requires a **VPN**.
* Updated the version requirements of the dependencies.

## Update 2023.06.20

Major Changes:

* Retrained the **MFD YoloV7** model on newly annotated data; the new model is deployed on the [P2T web app](https://p2t.behye.com). Details: [Pix2Text (P2T) 新版公式检测模型 | Breezedeus.com](https://www.breezedeus.com/article/p2t-mfd-20230613).
* The previous MFD YoloV7 model is now available for download by Planet members. Details: [P2T YoloV7 数学公式检测模型开放给星球会员下载 | Breezedeus.com](https://www.breezedeus.com/article/p2t-yolov7-for-zsxq-20230619).
* Added several Label Studio related scripts, see [scripts](scripts). For example: running the bundled MFD model over a directory of images to generate a JSON file that can be imported into Label Studio; and converting the JSON exported from Label Studio into the data format needed for training the MFD model. Note that the MFD training code lives in [yolov7](https://github.com/breezedeus/yolov7) (`dev` branch).

## Update 2023.02.19: Release V1.2.2

Major Changes:

* Trained bigger, more accurate MFD models for the [P2T web app](https://p2t.behye.com).
* Improved the sorting algorithm for detected boxes, so their order better matches human reading order.

## Update 2023.02.01: Release V1.2.1

Major Changes:

* Added **Mathematical Formula Detection** (**MFD**) and **Layout Analysis** models based on **YOLOv7**, with pretrained models ready to use.
* Fixed an incompatibility with Numpy>=1.24.

## Update 2022.07.07: Release cnstd V1.2

Major Changes:

* Added support for [**PaddleOCR**](https://github.com/PaddlePaddle/PaddleOCR) detection models;
* Adjusted the representation of `box` in detection results, unifying it as the coordinates of `4` points (a sketch of consuming this format follows below);
* Fixed known bugs.
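A minimal sketch of consuming the 4-point `box` format (an editor's illustration with the current default model; the fields follow the `detect` docstring in `cnstd/cn_std.py` below, and the image is one of this repo's examples):

```python
from cnstd import CnStd

std = CnStd()  # current default detection model: 'ch_PP-OCRv5_det'
out = std.detect('examples/taobao.jpg')  # single input -> single result dict
for info in out['detected_texts']:
    box = info['box']      # np.ndarray, shape (4, 2): the four corner points (x, y)
    score = info['score']  # float; higher means more reliable
```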
## Update 2022.05.27: Release cnstd V1.1.2

Major Changes:

* Compatible with `opencv-python >=4.5.2`; fixed an image-inversion issue and a drawing error.

## Update 2021.09.20: Release cnstd V1.1.0

Compared with V1.0.0, the main changes in **V1.1.0** include:

* bugfixes: fixed many issues discovered during training;
* slightly adjusted the init interface of the main detection class **`CnStd`**, removing the `model_epoch` parameter;
* added **ShuffleNet** support to the backbone;
* tuned the training hyperparameters, improving detection accuracy;
* more pretrained models to choose from, with the smallest model down to a **7.5M** file size.

## Update 2021.08.26: Release cnstd V1.0.0

* As MXNet became increasingly niche, the implementation moved from MXNet to **PyTorch**;
* detection speed improved greatly, with runtime dropping by nearly an order of magnitude;
* detection accuracy also improved considerably;
* better usability: the detection interface exposes more flexible parameters, so different application scenarios can try different settings for better results;
* a richer set of pretrained models, usable out of the box.

## Update 2020.07.01: Release cnstd V0.1.1

`CnStd.detect()` gained the input parameter `kwargs`. Keys currently used:

* "height_border": the vertical border ratio kept when cropping images; the total top-plus-bottom border is height * height_border; defaults to 0.05;
* "width_border": the horizontal border ratio kept when cropping images; the total left-plus-right border is height * width_border; defaults to 0.0.

bugfix:

* Fixed an inference bug on GPU: https://github.com/breezedeus/cnstd/issues/3

## Update 2020.06.02: Release cnstd V0.1.0

First release. Main features:

* Scene text detection (STD) with PSENet, supporting two backbone models: `mobilenetv3` and `resnet50_v1b`.
-------------------------------------------------------------------------------- /cnstd/lr_scheduler.py: --------------------------------------------------------------------------------
1 | # coding: utf-8
2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus).
3 | # Licensed to the Apache Software Foundation (ASF) under one
4 | # or more contributor license agreements. See the NOTICE file
5 | # distributed with this work for additional information
6 | # regarding copyright ownership. The ASF licenses this file
7 | # to you under the Apache License, Version 2.0 (the
8 | # "License"); you may not use this file except in compliance
9 | # with the License. You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing,
14 | # software distributed under the License is distributed on an
15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 | # KIND, either express or implied. See the License for the
17 | # specific language governing permissions and limitations
18 | # under the License.
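# A minimal sketch of the config dict consumed by `get_lr_scheduler` below
# (hypothetical values; the key names are taken from the code in this file, and
# 'steps_per_epoch' is filled in at runtime by `PlTrainer.fit()` in trainer.py):
#
#   config = {
#       'learning_rate': 1e-3,
#       'epochs': 20,
#       'steps_per_epoch': 100,
#       'lr_scheduler': {
#           'name': 'cos_warmup',
#           'min_lr_mult_factor': 0.1,
#           'warmup_epochs': 0.1,
#       },
#   }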
19 |
20 | from copy import deepcopy
21 | import math
22 |
23 | import torch
24 | from torch.optim.lr_scheduler import (
25 |     _LRScheduler,
26 |     StepLR,
27 |     LambdaLR,
28 |     CyclicLR,
29 |     CosineAnnealingWarmRestarts,
30 |     MultiStepLR,
31 |     OneCycleLR,
32 | )
33 |
34 |
35 | def get_lr_scheduler(config, optimizer):
36 |     orig_lr = config['learning_rate']
37 |     lr_sch_config = deepcopy(config['lr_scheduler'])
38 |     lr_sch_name = lr_sch_config.pop('name')
39 |     epochs = config['epochs']
40 |     steps_per_epoch = config['steps_per_epoch']
41 |
42 |     if lr_sch_name == 'multi_step':
43 |         milestones = [v * steps_per_epoch for v in lr_sch_config['milestones']]
44 |         return MultiStepLR(
45 |             optimizer, milestones=milestones, gamma=lr_sch_config['gamma'],
46 |         )
47 |     elif lr_sch_name == 'cos_warmup':
48 |         min_lr_mult_factor = lr_sch_config.get('min_lr_mult_factor', 0.1)
49 |         warmup_epochs = lr_sch_config.get('warmup_epochs', 0.1)
50 |         return WarmupCosineAnnealingRestarts(
51 |             optimizer,
52 |             first_cycle_steps=steps_per_epoch * epochs,
53 |             max_lr=orig_lr,
54 |             min_lr=orig_lr * min_lr_mult_factor,
55 |             warmup_steps=int(steps_per_epoch * warmup_epochs),
56 |         )
57 |     elif lr_sch_name == 'cos_anneal':
58 |         # one cycle every 5 epochs
59 |         return CosineAnnealingWarmRestarts(
60 |             optimizer, T_0=5 * steps_per_epoch, T_mult=1, eta_min=orig_lr * 0.1
61 |         )
62 |     elif lr_sch_name == 'cyclic':
63 |         return CyclicLR(
64 |             optimizer,
65 |             base_lr=orig_lr / 10.0,
66 |             max_lr=orig_lr,
67 |             step_size_up=5 * steps_per_epoch,  # rise from base_lr to max_lr over 5 epochs
68 |             cycle_momentum=False,
69 |         )
70 |     elif lr_sch_name == 'one_cycle':
71 |         return OneCycleLR(
72 |             optimizer, max_lr=orig_lr, epochs=epochs, steps_per_epoch=steps_per_epoch,
73 |         )
74 |
75 |     step_size = lr_sch_config['step_size']
76 |     gamma = lr_sch_config['gamma']
77 |     if step_size is None or gamma is None:
78 |         return LambdaLR(optimizer, lr_lambda=lambda _: 1)
79 |     return StepLR(optimizer, step_size, gamma=gamma)
80 |
81 |
82 | class WarmupCosineAnnealingRestarts(_LRScheduler):
83 |     """
84 |     from https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup/blob/master/cosine_annealing_warmup/scheduler.py
85 |
86 |     optimizer (Optimizer): Wrapped optimizer.
87 |     first_cycle_steps (int): First cycle step size.
88 |     cycle_mult(float): Cycle steps magnification. Default: 1.0.
89 |     max_lr(float): First cycle's max learning rate. Default: 0.1.
90 |     min_lr(float): Min learning rate. Default: 0.001.
91 |     warmup_steps(int): Linear warmup step size. Default: 0.
92 |     gamma(float): Decrease rate of max learning rate by cycle. Default: 1.
93 |     last_epoch (int): The index of last epoch. Default: -1.
94 | """ 95 | 96 | def __init__( 97 | self, 98 | optimizer: torch.optim.Optimizer, 99 | first_cycle_steps: int, 100 | cycle_mult: float = 1.0, 101 | max_lr: float = 0.1, 102 | min_lr: float = 0.001, 103 | warmup_steps: int = 0, 104 | gamma: float = 1.0, 105 | last_epoch: int = -1, 106 | ): 107 | assert warmup_steps < first_cycle_steps 108 | 109 | self.first_cycle_steps = first_cycle_steps # first cycle step size 110 | self.cycle_mult = cycle_mult # cycle steps magnification 111 | self.base_max_lr = max_lr # first max learning rate 112 | self.max_lr = max_lr # max learning rate in the current cycle 113 | self.min_lr = min_lr # min learning rate 114 | self.warmup_steps = warmup_steps # warmup step size 115 | self.gamma = gamma # decrease rate of max learning rate by cycle 116 | 117 | self.cur_cycle_steps = first_cycle_steps # first cycle step size 118 | self.cycle = 0 # cycle count 119 | self.step_in_cycle = last_epoch # step size of the current cycle 120 | 121 | super().__init__(optimizer, last_epoch) 122 | 123 | # set learning rate min_lr 124 | self.init_lr() 125 | 126 | def init_lr(self): 127 | self.base_lrs = [] 128 | for param_group in self.optimizer.param_groups: 129 | param_group['lr'] = self.min_lr 130 | self.base_lrs.append(self.min_lr) 131 | 132 | def get_lr(self): 133 | if self.step_in_cycle == -1: 134 | return self.base_lrs 135 | elif self.step_in_cycle < self.warmup_steps: 136 | return [ 137 | (self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps 138 | + base_lr 139 | for base_lr in self.base_lrs 140 | ] 141 | else: 142 | return [ 143 | base_lr 144 | + (self.max_lr - base_lr) 145 | * ( 146 | 1 147 | + math.cos( 148 | math.pi 149 | * (self.step_in_cycle - self.warmup_steps) 150 | / (self.cur_cycle_steps - self.warmup_steps) 151 | ) 152 | ) 153 | / 2 154 | for base_lr in self.base_lrs 155 | ] 156 | 157 | def step(self, epoch=None): 158 | if epoch is None: 159 | epoch = self.last_epoch + 1 160 | self.step_in_cycle = self.step_in_cycle + 1 161 | if self.step_in_cycle >= self.cur_cycle_steps: 162 | self.cycle += 1 163 | self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps 164 | self.cur_cycle_steps = ( 165 | int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) 166 | + self.warmup_steps 167 | ) 168 | else: 169 | if epoch >= self.first_cycle_steps: 170 | if self.cycle_mult == 1.0: 171 | self.step_in_cycle = epoch % self.first_cycle_steps 172 | self.cycle = epoch // self.first_cycle_steps 173 | else: 174 | n = int( 175 | math.log( 176 | ( 177 | epoch / self.first_cycle_steps * (self.cycle_mult - 1) 178 | + 1 179 | ), 180 | self.cycle_mult, 181 | ) 182 | ) 183 | self.cycle = n 184 | self.step_in_cycle = epoch - int( 185 | self.first_cycle_steps 186 | * (self.cycle_mult ** n - 1) 187 | / (self.cycle_mult - 1) 188 | ) 189 | self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** ( 190 | n 191 | ) 192 | else: 193 | self.cur_cycle_steps = self.first_cycle_steps 194 | self.step_in_cycle = epoch 195 | 196 | self.max_lr = self.base_max_lr * (self.gamma ** self.cycle) 197 | self.last_epoch = math.floor(epoch) 198 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 199 | param_group['lr'] = lr 200 | -------------------------------------------------------------------------------- /cnstd/ppocr/angle_classifier.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2022, [Breezedeus](https://github.com/breezedeus). 
3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | # Credits: adapted from https://github.com/PaddlePaddle/PaddleOCR 20 | 21 | import os 22 | import math 23 | import logging 24 | import traceback 25 | from pathlib import Path 26 | from typing import Union, Optional, Any, List, Dict 27 | 28 | import cv2 29 | import numpy as np 30 | 31 | from ..consts import MODEL_VERSION, ANGLE_CLF_MODELS, ANGLE_CLF_SPACE, DOWNLOAD_SOURCE 32 | from ..utils import data_dir, get_model_file 33 | from .postprocess import build_post_process 34 | from .utility import ( 35 | get_image_file_list, 36 | check_and_read_gif, 37 | create_predictor, 38 | parse_args, 39 | ) 40 | 41 | logger = logging.getLogger(__name__) 42 | 43 | 44 | class AngleClassifier(object): 45 | def __init__( 46 | self, 47 | model_name: str = 'ch_ppocr_mobile_v2.0_cls', 48 | *, 49 | model_fp: Optional[str] = None, 50 | clf_image_shape='3, 48, 192', 51 | clf_batch_num=6, 52 | clf_thresh=0.9, 53 | label_list=['0', '180'], # 只支持0和180两个角度,参考:https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.6/doc/doc_ch/angle_class.md # noqa 54 | root: Union[str, Path] = data_dir(), 55 | ): 56 | self._model_name = model_name 57 | self._model_backend = 'onnx' 58 | self.clf_image_shape = [int(v) for v in clf_image_shape.split(",")] 59 | self.clf_batch_num = clf_batch_num 60 | self.clf_thresh = clf_thresh 61 | 62 | self._assert_and_prepare_model_files(model_fp, root) 63 | 64 | postprocess_params = { 65 | 'name': 'ClsPostProcess', 66 | "label_list": label_list, 67 | } 68 | self.postprocess_op = build_post_process(postprocess_params) 69 | self.predictor, self.input_tensor, self.output_tensors, _ = create_predictor( 70 | self._model_fp, 'cls', logger 71 | ) 72 | 73 | def _assert_and_prepare_model_files(self, model_fp, root): 74 | if model_fp is not None and not os.path.isfile(model_fp): 75 | raise FileNotFoundError('can not find model file %s' % model_fp) 76 | 77 | if model_fp is not None: 78 | self._model_fp = model_fp 79 | return 80 | 81 | self._model_dir = os.path.join(root, MODEL_VERSION, ANGLE_CLF_SPACE) 82 | model_fp = os.path.join(self._model_dir, '%s_infer.onnx' % self._model_name) 83 | if not os.path.isfile(model_fp): 84 | logger.warning('can not find model file %s' % model_fp) 85 | if (self._model_name, self._model_backend) not in ANGLE_CLF_MODELS: 86 | raise NotImplementedError( 87 | '%s is not a downloadable model' 88 | % ((self._model_name, self._model_backend),) 89 | ) 90 | url = ANGLE_CLF_MODELS[(self._model_name, self._model_backend)]['url'] 91 | 92 | get_model_file(url, self._model_dir, download_source=DOWNLOAD_SOURCE) # download the .zip file and unzip 93 | 94 | self._model_fp = model_fp 95 | logger.info('use model: %s' % self._model_fp) 96 | 97 | def 
resize_norm_img(self, img):
98 |         imgC, imgH, imgW = self.clf_image_shape
99 |         h = img.shape[0]
100 |         w = img.shape[1]
101 |         ratio = w / float(h)
102 |         if math.ceil(imgH * ratio) > imgW:
103 |             resized_w = imgW
104 |         else:
105 |             resized_w = int(math.ceil(imgH * ratio))
106 |         resized_image = cv2.resize(img, (resized_w, imgH))
107 |         resized_image = resized_image.astype('float32')
108 |         if self.clf_image_shape[0] == 1:
109 |             resized_image = resized_image / 255
110 |             resized_image = resized_image[np.newaxis, :]
111 |         else:
112 |             resized_image = resized_image.transpose((2, 0, 1)) / 255
113 |         resized_image -= 0.5
114 |         resized_image /= 0.5
115 |         padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
116 |         padding_im[:, :, 0:resized_w] = resized_image
117 |         return padding_im
118 |
119 |     def __call__(self, img_list):
120 |         """
121 |
122 |         Args:
123 |             img_list (list): each element with shape [H, W, 3], RGB-formatted image
124 |
125 |         Returns:
126 |             img_list (list): rotated images, each element with shape [H, W, 3], RGB-formatted image
127 |             cls_res (list): each element is a `[label, score]` pair
128 |
129 |         """
130 |         img_list = [cv2.cvtColor(img, cv2.COLOR_RGB2BGR) for img in img_list]
131 |
132 |         img_num = len(img_list)
133 |         # Calculate the aspect ratio of all text bars
134 |         width_list = []
135 |         for img in img_list:
136 |             width_list.append(img.shape[1] / float(img.shape[0]))
137 |         # Sorting can speed up the cls process
138 |         indices = np.argsort(np.array(width_list))
139 |
140 |         cls_res = [['', 0.0]] * img_num
141 |         batch_num = self.clf_batch_num
142 |         for beg_img_no in range(0, img_num, batch_num):
143 |
144 |             end_img_no = min(img_num, beg_img_no + batch_num)
145 |             norm_img_batch = []
146 |             max_wh_ratio = 0
147 |             for ino in range(beg_img_no, end_img_no):
148 |                 h, w = img_list[indices[ino]].shape[0:2]
149 |                 wh_ratio = w * 1.0 / h
150 |                 max_wh_ratio = max(max_wh_ratio, wh_ratio)
151 |             for ino in range(beg_img_no, end_img_no):
152 |                 norm_img = self.resize_norm_img(img_list[indices[ino]])
153 |                 norm_img = norm_img[np.newaxis, :]
154 |                 norm_img_batch.append(norm_img)
155 |             norm_img_batch = np.concatenate(norm_img_batch)
156 |             norm_img_batch = norm_img_batch.copy()
157 |
158 |             input_dict = {}
159 |             input_dict[self.input_tensor.name] = norm_img_batch
160 |             outputs = self.predictor.run(self.output_tensors, input_dict)
161 |             prob_out = outputs[0]
162 |             cls_result = self.postprocess_op(prob_out)
163 |             for rno in range(len(cls_result)):
164 |                 label, score = cls_result[rno]
165 |                 cls_res[indices[beg_img_no + rno]] = [label, score]
166 |                 if '180' in label and score > self.clf_thresh:
167 |                     img_list[indices[beg_img_no + rno]] = cv2.rotate(
168 |                         img_list[indices[beg_img_no + rno]], 1  # 1 == cv2.ROTATE_180
169 |                     )
170 |
171 |         img_list = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_list]
172 |         return img_list, cls_res
173 |
174 |
175 | def main(args):
176 |     image_file_list = get_image_file_list(args.image_dir)
177 |     text_classifier = AngleClassifier()  # use the default model
178 |     valid_image_file_list = []
179 |     img_list = []
180 |     for image_file in image_file_list:
181 |         img, flag = check_and_read_gif(image_file)
182 |         if not flag:
183 |             img = cv2.imread(image_file)
184 |         if img is None:
185 |             logger.info("error in loading image:{}".format(image_file))
186 |             continue
187 |         valid_image_file_list.append(image_file)
188 |         img_list.append(img)
189 |     try:
190 |         img_list, cls_res = text_classifier(img_list)
191 |     except Exception as E:
192 |         logger.info(traceback.format_exc())
193 |         logger.info(E)
194 |         exit()
195 |     for ino in range(len(img_list)):
196 |         logger.info(
197 |             "Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[ino])
198 |         )
199 |
200 |
201 | if __name__ == "__main__":
202 |     main(parse_args())
203 |
--------------------------------------------------------------------------------
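A quick sanity check of `AngleClassifier` as defined above (an editor's sketch, not part of the source tree; the example image comes from this repo, and the default model file is downloaded on first use):

```python
import numpy as np
from PIL import Image

from cnstd.ppocr.angle_classifier import AngleClassifier

clf = AngleClassifier()  # defaults to 'ch_ppocr_mobile_v2.0_cls'
img = np.asarray(Image.open('examples/taobao.jpg').convert('RGB'))  # RGB, (H, W, 3)
rotated_imgs, cls_res = clf([img])  # rotated images plus [label, score] pairs
print(cls_res)  # e.g. [['0', 0.99]]
```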
| "Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[ino]) 198 | ) 199 | 200 | 201 | if __name__ == "__main__": 202 | main(parse_args()) 203 | -------------------------------------------------------------------------------- /cnstd/yolov7/autoanchor.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2022, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | # Credits to: https://github.com/WongKinYiu/yolov7, forked to https://github.com/breezedeus/yolov7 20 | # Auto-anchor utils 21 | 22 | import numpy as np 23 | import torch 24 | import yaml 25 | from scipy.cluster.vq import kmeans 26 | from tqdm import tqdm 27 | 28 | from .general import colorstr 29 | 30 | 31 | def check_anchor_order(m): 32 | # Check anchor order against stride order for YOLO Detect() module m, and correct if necessary 33 | a = m.anchor_grid.prod(-1).view(-1) # anchor area 34 | da = a[-1] - a[0] # delta a 35 | ds = m.stride[-1] - m.stride[0] # delta s 36 | if da.sign() != ds.sign(): # same order 37 | print('Reversing anchor order') 38 | m.anchors[:] = m.anchors.flip(0) 39 | m.anchor_grid[:] = m.anchor_grid.flip(0) 40 | 41 | 42 | def check_anchors(dataset, model, thr=4.0, imgsz=640): 43 | # Check anchor fit to data, recompute if necessary 44 | prefix = colorstr('autoanchor: ') 45 | print(f'\n{prefix}Analyzing anchors... ', end='') 46 | m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect() 47 | shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True) 48 | scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale 49 | wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh 50 | 51 | def metric(k): # compute metric 52 | r = wh[:, None] / k[None] 53 | x = torch.min(r, 1. / r).min(2)[0] # ratio metric 54 | best = x.max(1)[0] # best_x 55 | aat = (x > 1. / thr).float().sum(1).mean() # anchors above threshold 56 | bpr = (best > 1. / thr).float().mean() # best possible recall 57 | return bpr, aat 58 | 59 | anchors = m.anchor_grid.clone().cpu().view(-1, 2) # current anchors 60 | bpr, aat = metric(anchors) 61 | print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='') 62 | if bpr < 0.98: # threshold to recompute 63 | print('. 
Attempting to improve anchors, please wait...') 64 | na = m.anchor_grid.numel() // 2 # number of anchors 65 | try: 66 | anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False) 67 | except Exception as e: 68 | print(f'{prefix}ERROR: {e}') 69 | new_bpr = metric(anchors)[0] 70 | if new_bpr > bpr: # replace anchors 71 | anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors) 72 | m.anchor_grid[:] = anchors.clone().view_as(m.anchor_grid) # for inference 73 | check_anchor_order(m) 74 | m.anchors[:] = anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss 75 | print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.') 76 | else: 77 | print(f'{prefix}Original anchors better than new anchors. Proceeding with original anchors.') 78 | print('') # newline 79 | 80 | 81 | def kmean_anchors(path='./data/coco.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True): 82 | """ Creates kmeans-evolved anchors from training dataset 83 | 84 | Arguments: 85 | path: path to dataset *.yaml, or a loaded dataset 86 | n: number of anchors 87 | img_size: image size used for training 88 | thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0 89 | gen: generations to evolve anchors using genetic algorithm 90 | verbose: print all results 91 | 92 | Return: 93 | k: kmeans evolved anchors 94 | 95 | Usage: 96 | from utils.autoanchor import *; _ = kmean_anchors() 97 | """ 98 | thr = 1. / thr 99 | prefix = colorstr('autoanchor: ') 100 | 101 | def metric(k, wh): # compute metrics 102 | r = wh[:, None] / k[None] 103 | x = torch.min(r, 1. / r).min(2)[0] # ratio metric 104 | # x = wh_iou(wh, torch.tensor(k)) # iou metric 105 | return x, x.max(1)[0] # x, best_x 106 | 107 | def anchor_fitness(k): # mutation fitness 108 | _, best = metric(torch.tensor(k, dtype=torch.float32), wh) 109 | return (best * (best > thr).float()).mean() # fitness 110 | 111 | def print_results(k): 112 | k = k[np.argsort(k.prod(1))] # sort small to large 113 | x, best = metric(k, wh0) 114 | bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr 115 | print(f'{prefix}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr') 116 | print(f'{prefix}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' 117 | f'past_thr={x[x > thr].mean():.3f}-mean: ', end='') 118 | for i, x in enumerate(k): 119 | print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg 120 | return k 121 | 122 | if isinstance(path, str): # *.yaml file 123 | with open(path) as f: 124 | data_dict = yaml.load(f, Loader=yaml.SafeLoader) # model dict 125 | from utils.datasets import LoadImagesAndLabels 126 | dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True) 127 | else: 128 | dataset = path # dataset 129 | 130 | # Get label wh 131 | shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True) 132 | wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh 133 | 134 | # Filter 135 | i = (wh0 < 3.0).any(1).sum() 136 | if i: 137 | print(f'{prefix}WARNING: Extremely small objects found. 
{i} of {len(wh0)} labels are < 3 pixels in size.')
138 |     wh = wh0[(wh0 >= 2.0).any(1)]  # filter > 2 pixels
139 |     # wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1)  # multiply by random scale 0-1
140 |
141 |     # Kmeans calculation
142 |     print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...')
143 |     s = wh.std(0)  # sigmas for whitening
144 |     k, dist = kmeans(wh / s, n, iter=30)  # points, mean distance
145 |     assert len(k) == n, f'{prefix}ERROR: scipy.cluster.vq.kmeans requested {n} points but returned only {len(k)}'
146 |     k *= s
147 |     wh = torch.tensor(wh, dtype=torch.float32)  # filtered
148 |     wh0 = torch.tensor(wh0, dtype=torch.float32)  # unfiltered
149 |     k = print_results(k)
150 |
151 |     # Plot
152 |     # k, d = [None] * 20, [None] * 20
153 |     # for i in tqdm(range(1, 21)):
154 |     #     k[i-1], d[i-1] = kmeans(wh / s, i)  # points, mean distance
155 |     #     fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
156 |     #     ax = ax.ravel()
157 |     #     ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
158 |     # fig, ax = plt.subplots(1, 2, figsize=(14, 7))  # plot wh
159 |     # ax[0].hist(wh[wh[:, 0]<100, 0],400)
160 |     # ax[1].hist(wh[wh[:, 1]<100, 1],400)
161 |     # fig.savefig('wh.png', dpi=200)
162 |
163 |     # Evolve
164 |     npr = np.random
165 |     f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1  # fitness, anchor shape, mutation prob, sigma
166 |     pbar = tqdm(range(gen), desc=f'{prefix}Evolving anchors with Genetic Algorithm:')  # progress bar
167 |     for _ in pbar:
168 |         v = np.ones(sh)
169 |         while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
170 |             v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
171 |         kg = (k.copy() * v).clip(min=2.0)
172 |         fg = anchor_fitness(kg)
173 |         if fg > f:
174 |             f, k = fg, kg.copy()
175 |             pbar.desc = f'{prefix}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
176 |             if verbose:
177 |                 print_results(k)
178 |
179 |     return print_results(k)
180 |
-------------------------------------------------------------------------------- /cnstd/ppocr/postprocess/db_postprocess.py: --------------------------------------------------------------------------------
1 | # coding: utf-8
2 | # Copyright (C) 2022, [Breezedeus](https://github.com/breezedeus).
3 | # Licensed to the Apache Software Foundation (ASF) under one
4 | # or more contributor license agreements. See the NOTICE file
5 | # distributed with this work for additional information
6 | # regarding copyright ownership. The ASF licenses this file
7 | # to you under the Apache License, Version 2.0 (the
8 | # "License"); you may not use this file except in compliance
9 | # with the License. You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing,
14 | # software distributed under the License is distributed on an
15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 | # KIND, either express or implied. See the License for the
17 | # specific language governing permissions and limitations
18 | # under the License.
19 | # Credits: adapted from https://github.com/PaddlePaddle/PaddleOCR
20 | # This code is referred from:
21 | # https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py
22 |
23 | import numpy as np
24 | import cv2
25 | from shapely.geometry import Polygon
26 | import pyclipper
27 |
28 |
29 | class DBPostProcess(object):
30 |     """
31 |     The post process for Differentiable Binarization (DB). 
32 | """ 33 | 34 | def __init__(self, 35 | bin_thresh=0.3, 36 | box_thresh=0.6, 37 | max_candidates=1000, 38 | unclip_ratio=2.0, 39 | use_dilation=False, 40 | score_mode="fast", 41 | **kwargs): 42 | self.bin_thresh = bin_thresh 43 | self.box_thresh = box_thresh 44 | self.max_candidates = max_candidates 45 | self.unclip_ratio = unclip_ratio 46 | self.min_size = 3 47 | self.score_mode = score_mode 48 | assert score_mode in [ 49 | "slow", "fast" 50 | ], "Score mode must be in [slow, fast] but got: {}".format(score_mode) 51 | 52 | self.dilation_kernel = None if not use_dilation else np.array( 53 | [[1, 1], [1, 1]]) 54 | 55 | def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height, box_thresh): 56 | ''' 57 | _bitmap: single map with shape (1, H, W), 58 | whose values are binarized as {0, 1} 59 | ''' 60 | 61 | bitmap = _bitmap 62 | height, width = bitmap.shape 63 | 64 | outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, 65 | cv2.CHAIN_APPROX_SIMPLE) 66 | if len(outs) == 3: 67 | img, contours, _ = outs[0], outs[1], outs[2] 68 | elif len(outs) == 2: 69 | contours, _ = outs[0], outs[1] 70 | 71 | num_contours = min(len(contours), self.max_candidates) 72 | 73 | boxes = [] 74 | scores = [] 75 | for index in range(num_contours): 76 | contour = contours[index] 77 | points, sside = self.get_mini_boxes(contour) 78 | if sside < self.min_size: 79 | continue 80 | points = np.array(points) 81 | if self.score_mode == "fast": 82 | score = self.box_score_fast(pred, points.reshape(-1, 2)) 83 | else: 84 | score = self.box_score_slow(pred, contour) 85 | if box_thresh > score: 86 | continue 87 | 88 | box = self.unclip(points).reshape(-1, 1, 2) 89 | box, sside = self.get_mini_boxes(box) 90 | if sside < self.min_size + 2: 91 | continue 92 | box = np.array(box) 93 | 94 | box[:, 0] = np.clip( 95 | np.round(box[:, 0] / width * dest_width), 0, dest_width) 96 | box[:, 1] = np.clip( 97 | np.round(box[:, 1] / height * dest_height), 0, dest_height) 98 | boxes.append(box.astype(np.int16)) 99 | scores.append(score) 100 | return np.array(boxes, dtype=np.int16), scores 101 | 102 | def unclip(self, box): 103 | unclip_ratio = self.unclip_ratio 104 | poly = Polygon(box) 105 | distance = poly.area * unclip_ratio / poly.length 106 | offset = pyclipper.PyclipperOffset() 107 | offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) 108 | expanded = np.array(offset.Execute(distance)) 109 | return expanded 110 | 111 | def get_mini_boxes(self, contour): 112 | bounding_box = cv2.minAreaRect(contour) 113 | points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) 114 | 115 | index_1, index_2, index_3, index_4 = 0, 1, 2, 3 116 | if points[1][1] > points[0][1]: 117 | index_1 = 0 118 | index_4 = 1 119 | else: 120 | index_1 = 1 121 | index_4 = 0 122 | if points[3][1] > points[2][1]: 123 | index_2 = 2 124 | index_3 = 3 125 | else: 126 | index_2 = 3 127 | index_3 = 2 128 | 129 | box = [ 130 | points[index_1], points[index_2], points[index_3], points[index_4] 131 | ] 132 | return box, min(bounding_box[1]) 133 | 134 | def box_score_fast(self, bitmap, _box): 135 | ''' 136 | box_score_fast: use bbox mean score as the mean score 137 | ''' 138 | h, w = bitmap.shape[:2] 139 | box = _box.copy() 140 | xmin = np.clip(np.floor(box[:, 0].min()).astype(int), 0, w - 1) 141 | xmax = np.clip(np.ceil(box[:, 0].max()).astype(int), 0, w - 1) 142 | ymin = np.clip(np.floor(box[:, 1].min()).astype(int), 0, h - 1) 143 | ymax = np.clip(np.ceil(box[:, 1].max()).astype(int), 0, h - 1) 144 | 145 | mask = 
np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
146 |         box[:, 0] = box[:, 0] - xmin
147 |         box[:, 1] = box[:, 1] - ymin
148 |         cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
149 |         return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
150 |
151 |     def box_score_slow(self, bitmap, contour):
152 |         '''
153 |         box_score_slow: use polygon mean score as the mean score
154 |         '''
155 |         h, w = bitmap.shape[:2]
156 |         contour = contour.copy()
157 |         contour = np.reshape(contour, (-1, 2))
158 |
159 |         xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
160 |         xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
161 |         ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
162 |         ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
163 |
164 |         mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
165 |
166 |         contour[:, 0] = contour[:, 0] - xmin
167 |         contour[:, 1] = contour[:, 1] - ymin
168 |
169 |         cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
170 |         return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
171 |
172 |     def __call__(self, outs_dict, shape_list, *, box_thresh=None):
173 |         if box_thresh is None:
174 |             box_thresh = self.box_thresh
175 |         pred = outs_dict['maps']
176 |         pred = pred[:, 0, :, :]  # (N, 1, H, W) -> (N, H, W)
177 |         segmentation = pred > self.bin_thresh  # binarize the probability map
178 |
179 |         boxes_batch = []
180 |         for batch_index in range(pred.shape[0]):
181 |             src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
182 |             if self.dilation_kernel is not None:
183 |                 mask = cv2.dilate(
184 |                     np.array(segmentation[batch_index]).astype(np.uint8),
185 |                     self.dilation_kernel)
186 |             else:
187 |                 mask = segmentation[batch_index]
188 |             boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
189 |                                                    src_w, src_h, box_thresh)
190 |
191 |             boxes_batch.append({'points': boxes, 'scores': scores})
192 |         return boxes_batch
193 |
194 |
195 | class DistillationDBPostProcess(object):
196 |     def __init__(self,
197 |                  model_name=["student"],
198 |                  key=None,
199 |                  thresh=0.3,
200 |                  box_thresh=0.6,
201 |                  max_candidates=1000,
202 |                  unclip_ratio=1.5,
203 |                  use_dilation=False,
204 |                  score_mode="fast",
205 |                  **kwargs):
206 |         self.model_name = model_name
207 |         self.key = key
208 |         self.post_process = DBPostProcess(
209 |             bin_thresh=thresh,
210 |             box_thresh=box_thresh,
211 |             max_candidates=max_candidates,
212 |             unclip_ratio=unclip_ratio,
213 |             use_dilation=use_dilation,
214 |             score_mode=score_mode)
215 |
216 |     def __call__(self, predicts, shape_list):
217 |         results = {}
218 |         for k in self.model_name:
219 |             results[k] = self.post_process(predicts[k], shape_list=shape_list)
220 |         return results
221 |
-------------------------------------------------------------------------------- /cnstd/cn_std.py: --------------------------------------------------------------------------------
1 | # coding: utf-8
2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus).
3 | # Licensed to the Apache Software Foundation (ASF) under one
4 | # or more contributor license agreements. See the NOTICE file
5 | # distributed with this work for additional information
6 | # regarding copyright ownership. The ASF licenses this file
7 | # to you under the Apache License, Version 2.0 (the
8 | # "License"); you may not use this file except in compliance
9 | # with the License. 
You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | from __future__ import absolute_import 21 | 22 | import logging 23 | import traceback 24 | from pathlib import Path 25 | from typing import Tuple, List, Dict, Union, Any, Optional 26 | 27 | from PIL import Image 28 | import numpy as np 29 | 30 | from .consts import AVAILABLE_MODELS 31 | from .detector import Detector 32 | from .ppocr import PP_SPACE, PPDetector, RapidDetector 33 | from .ppocr.angle_classifier import AngleClassifier 34 | from .utils import data_dir 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | 39 | class CnStd(object): 40 | """ 41 | 场景文字检测器(Scene Text Detection)。虽然名字中有个"Cn"(Chinese),但其实也可以轻松识别英文的。 42 | """ 43 | 44 | def __init__( 45 | self, 46 | model_name: str = 'ch_PP-OCRv5_det', 47 | *, 48 | auto_rotate_whole_image: bool = False, 49 | rotated_bbox: bool = True, 50 | context: str = 'cpu', 51 | model_fp: Optional[str] = None, 52 | model_backend: str = 'onnx', # ['pytorch', 'onnx'] 53 | root: Union[str, Path] = data_dir(), 54 | use_angle_clf: bool = False, 55 | angle_clf_configs: Optional[dict] = None, 56 | **kwargs, 57 | ): 58 | """ 59 | Args: 60 | model_name: 模型名称。默认为 'ch_PP-OCRv5_det' 61 | auto_rotate_whole_image: 是否自动对整张图片进行旋转调整。默认为False 62 | rotated_bbox: 是否支持检测带角度的文本框;默认为 True,表示支持;取值为 False 时,只检测水平或垂直的文本 63 | context: 'cpu', or 'gpu'。表明预测时是使用CPU还是GPU。默认为CPU 64 | model_fp: 如果不使用系统自带的模型,可以通过此参数直接指定所使用的模型文件('.ckpt' 文件) 65 | model_backend (str): 'pytorch', or 'onnx'。表明预测时是使用 PyTorch 版本模型,还是使用 ONNX 版本模型。 66 | 同样的模型,ONNX 版本的预测速度一般是 PyTorch 版本的2倍左右。默认为 'onnx'。 67 | root: 模型文件所在的根目录。 68 | Linux/Mac下默认值为 `~/.cnstd`,表示模型文件所处文件夹类似 `~/.cnstd/1.2/db_resnet18` 69 | Windows下默认值为 `C:/Users//AppData/Roaming/cnstd`。 70 | use_angle_clf (bool): 对于检测出的文本框,是否使用角度分类模型进行调整(检测出的文本框可能会存在倒转180度的情况)。 71 | 默认为 `False` 72 | angle_clf_configs (dict): 角度分类模型对应的参数取值,主要包含以下值: 73 | - model_name: 模型名称。默认为 'ch_ppocr_mobile_v2.0_cls' 74 | - model_fp: 如果不使用系统自带的模型,可以通过此参数直接指定所使用的模型文件('.onnx' 文件)。默认为 `None` 75 | 具体可参考类 `AngleClassifier` 的说明 76 | """ 77 | self.space = AVAILABLE_MODELS.get_space(model_name, model_backend) 78 | if self.space is None: 79 | logger.warning( 80 | 'no available model is found for name %s and backend %s' 81 | % (model_name, model_backend) 82 | ) 83 | model_backend = 'onnx' if model_backend == 'pytorch' else 'pytorch' 84 | logger.warning( 85 | 'trying to use name %s and backend %s' % (model_name, model_backend) 86 | ) 87 | self.space = AVAILABLE_MODELS.get_space(model_name, model_backend) 88 | 89 | if self.space == AVAILABLE_MODELS.CNSTD_SPACE: 90 | det_cls = Detector 91 | elif self.space == PP_SPACE: 92 | det_name = AVAILABLE_MODELS.get_value(model_name, model_backend, 'detector') 93 | det_cls = RapidDetector if det_name == 'RapidDetector' else PPDetector 94 | else: 95 | raise NotImplementedError( 96 | '%s is not supported currently' % ((model_name, model_backend),) 97 | ) 98 | 99 | self.det_model = det_cls( 100 | model_name=model_name, 101 | auto_rotate_whole_image=auto_rotate_whole_image, 102 | rotated_bbox=rotated_bbox, 103 | context=context, 104 | model_fp=model_fp, 105 | model_backend=model_backend, 106 | 
root=root, 107 | **kwargs, 108 | ) 109 | 110 | self.use_angle_clf = use_angle_clf 111 | if self.use_angle_clf: 112 | angle_clf_configs = angle_clf_configs or dict() 113 | angle_clf_configs['root'] = root 114 | self.angle_clf = AngleClassifier(**angle_clf_configs) 115 | 116 | def detect( 117 | self, 118 | img_list: Union[ 119 | str, 120 | Path, 121 | Image.Image, 122 | np.ndarray, 123 | List[Union[str, Path, Image.Image, np.ndarray]], 124 | ], 125 | resized_shape: Union[int, Tuple[int, int]] = (768, 768), 126 | preserve_aspect_ratio: bool = True, 127 | min_box_size: int = 8, 128 | box_score_thresh: float = 0.3, 129 | batch_size: int = 20, 130 | **kwargs, 131 | ) -> Union[Dict[str, Any], List[Dict[str, Any]]]: 132 | """ 133 | 检测图片中的文本。 134 | Args: 135 | img_list: 支持对单个图片或者多个图片(列表)的检测。每个值可以是图片路径,或者已经读取进来 PIL.Image.Image 或 np.ndarray, 136 | 格式应该是 RGB 3通道,shape: (height, width, 3), 取值:[0, 255] 137 | resized_shape: `int` or `tuple`, `tuple` 含义为 (height, width), `int` 则表示高宽都为此值; 138 | 检测前,先把原始图片resize到接近此大小(只是接近,未必相等)。默认为 `(768, 768)`。 139 | 注:这个取值对检测结果的影响较大,可以针对自己的应用多尝试几组值,再选出最优值。 140 | 例如 (512, 768), (768, 768), (768, 1024)等。 141 | preserve_aspect_ratio: 对原始图片resize时是否保持高宽比不变。默认为 `True`。 142 | min_box_size: 如果检测出的文本框高度或者宽度低于此值,此文本框会被过滤掉。默认为 `8`,也即高或者宽低于 `8` 的文本框会被过滤去掉。 143 | box_score_thresh: 过滤掉得分低于此值的文本框。默认为 `0.3`。 144 | batch_size: 待处理图片很多时,需要分批处理,每批图片的数量由此参数指定。默认为 `20`。 145 | kwargs: 保留参数,目前未被使用。 146 | 147 | Returns: 148 | List[Dict], 每个Dict对应一张图片的检测结果。Dict 中包含以下 keys: 149 | * 'rotated_angle': float, 整张图片旋转的角度。只有 auto_rotate_whole_image==True 才可能非0。 150 | * 'detected_texts': list, 每个元素存储了检测出的一个框的信息,使用词典记录,包括以下几个值: 151 | 'box':检测出的文字对应的矩形框;np.ndarray, shape: (4, 2),对应 box 4个点的坐标值 (x, y) ; 152 | 'score':得分;float 类型;分数越高表示越可靠; 153 | 'cropped_img':对应'box'中的图片patch(RGB格式),会把倾斜的图片旋转为水平。 154 | np.ndarray 类型,shape: (height, width, 3), 取值范围:[0, 255]; 155 | 156 | 示例: 157 | [{'box': array([[416, 77], 158 | [486, 13], 159 | [800, 325], 160 | [730, 390]], dtype=int32), 161 | 'score': 0.8, 'cropped_img': array([[[25, 20, 24], 162 | [26, 21, 25], 163 | [25, 20, 24], 164 | ..., 165 | [11, 11, 13], 166 | [11, 11, 13], 167 | [11, 11, 13]]], dtype=uint8)}, 168 | ... 
169 | ] 170 | 171 | """ 172 | single = False 173 | if isinstance(img_list, (list, tuple)): 174 | pass 175 | elif isinstance(img_list, (str, Path, Image.Image, np.ndarray)): 176 | img_list = [img_list] 177 | single = True 178 | else: 179 | raise TypeError('type %s is not supported now' % str(type(img_list))) 180 | 181 | outs = self.det_model.detect( 182 | img_list, 183 | resized_shape=calibrate_resized_shape(resized_shape), 184 | preserve_aspect_ratio=preserve_aspect_ratio, 185 | min_box_size=min_box_size, 186 | box_score_thresh=box_score_thresh, 187 | batch_size=batch_size, 188 | ) 189 | 190 | if self.use_angle_clf: 191 | for out in outs: 192 | crop_img_list = [info['cropped_img'] for info in out['detected_texts']] 193 | try: 194 | crop_img_list, angle_list = self.angle_clf(crop_img_list) 195 | for info, crop_img in zip(out['detected_texts'], crop_img_list): 196 | info['cropped_img'] = crop_img 197 | except Exception as e: 198 | logger.info(traceback.format_exc()) 199 | logger.info(e) 200 | 201 | return outs[0] if single else outs 202 | 203 | 204 | def calibrate_resized_shape(resized_shape): 205 | if isinstance(resized_shape, int): 206 | resized_shape = (resized_shape, resized_shape) 207 | 208 | def calibrate(ori): 209 | return max(int(round(ori / 32) * 32), 32) 210 | 211 | return calibrate(resized_shape[0]), calibrate(resized_shape[1]) 212 | -------------------------------------------------------------------------------- /cnstd/trainer.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
19 | # Credits: adapted from https://github.com/mindee/doctr
20 |
21 | import logging
22 | from pathlib import Path
23 | from typing import Any, Optional, Union, List
24 |
25 | import torch
26 | import torch.optim as optim
27 | from torch import nn
28 | import pytorch_lightning as pl
29 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
30 | from torch.utils.data import DataLoader
31 |
32 | from .lr_scheduler import get_lr_scheduler
33 | from .utils import LocalizationConfusion
34 |
35 | logger = logging.getLogger(__name__)
36 |
37 |
38 | def get_optimizer(name: str, model, learning_rate, weight_decay):
39 |     r"""Init the Optimizer
40 |
41 |     Returns:
42 |         torch.optim: the optimizer
43 |     """
44 |     OPTIMIZERS = {
45 |         'adam': optim.Adam,
46 |         'adamw': optim.AdamW,
47 |         'sgd': optim.SGD,
48 |         'adagrad': optim.Adagrad,
49 |         'rmsprop': optim.RMSprop,
50 |     }
51 |
52 |     try:
53 |         opt_cls = OPTIMIZERS[name.lower()]
54 |         optimizer = opt_cls(
55 |             model.parameters(), lr=learning_rate, weight_decay=weight_decay
56 |         )
57 |     except KeyError:
58 |         logger.warning('Received unrecognized optimizer %s; falling back to Adam', name)
59 |         optimizer = optim.Adam(
60 |             model.parameters(), lr=learning_rate, weight_decay=weight_decay
61 |         )
62 |     return optimizer
63 |
64 |
65 | class WrapperLightningModule(pl.LightningModule):
66 |     def __init__(self, config, model):
67 |         super().__init__()
68 |         self.config = config
69 |         self.model = model
70 |         self._optimizer = get_optimizer(
71 |             config['optimizer'],
72 |             self.model,
73 |             config['learning_rate'],
74 |             config.get('weight_decay', 0),
75 |         )
76 |
77 |         expected_img_shape = model.cfg['input_shape']
78 |         self.val_metric = LocalizationConfusion(
79 |             rotated_bbox=self.model.rotated_bbox, mask_shape=expected_img_shape[1:]
80 |         )
81 |
82 |     def forward(self, x):
83 |         return self.model(x)
84 |
85 |     def training_step(self, batch, batch_idx):
86 |         if hasattr(self.model, 'set_current_epoch'):
87 |             self.model.set_current_epoch(self.current_epoch)
88 |         else:
89 |             setattr(self.model, 'current_epoch', self.current_epoch)
90 |         res = self.model.calculate_loss(batch)
91 |
92 |         # update lr scheduler
93 |         sch = self.lr_schedulers()
94 |         sch.step()
95 |
96 |         losses = res['loss']
97 |         self.log(
98 |             'train_loss',
99 |             losses.item(),
100 |             on_step=True,
101 |             on_epoch=True,
102 |             prog_bar=True,
103 |             logger=True,
104 |         )
105 |         return losses
106 |
107 |     def validation_step(self, batch, batch_idx):
108 |         if hasattr(self.model, 'validation_step'):
109 |             return self.model.validation_step(batch, batch_idx, self)
110 |
111 |         res = self.model.calculate_loss(
112 |             batch, return_model_output=True, return_preds=True
113 |         )
114 |         losses = res['loss']
115 |         val_metrics = {'val_loss': losses.item()}
116 |         self.log_dict(
117 |             val_metrics, on_step=True, on_epoch=True, prog_bar=True, logger=True,
118 |         )
119 |
120 |         pred_boxes = [boxes[:, :-1] for boxes in res['preds'][0]]  # the last column is the score; drop it
121 |         gt_boxes = []
122 |         for boxes, ignores in zip(batch['polygons'], batch['ignore_tags']):
123 |             boxes = [box for idx, box in enumerate(boxes) if not ignores[idx]]
124 |             gt_boxes.append(boxes)
125 |         metric_res = self.val_metric.update(gt_boxes, pred_boxes)
126 |         val_metrics = {name + '_step': val for name, val in metric_res.items()}
127 |         self.log_dict(
128 |             val_metrics, on_step=True, on_epoch=False, prog_bar=True, logger=True,
129 |         )
130 |
131 |         return losses
132 |
133 |     def validation_epoch_end(self, losses_list) -> None:
134 |         metric_res = self.val_metric.summary()
135 |         val_metrics = {name + '_epoch': val for name, val in metric_res.items()}
136 |         self.log_dict(
137 |             val_metrics, on_step=False, on_epoch=True, prog_bar=True, logger=True,
138 |         )
139 |         self.val_metric.reset()
140 |
141 |     def configure_optimizers(self):
142 |         return [self._optimizer], [get_lr_scheduler(self.config, self._optimizer)]
143 |
144 |
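# A minimal sketch of a config dict accepted by `PlTrainer` below (hypothetical
# values; only keys actually read in this module and in lr_scheduler.py are
# shown, and 'steps_per_epoch' is derived from the train dataloader in `fit()`):
#
#   config = {
#       'optimizer': 'adam',
#       'learning_rate': 1e-3,
#       'weight_decay': 0,
#       'epochs': 20,
#       'lr_scheduler': {'name': 'step', 'step_size': None, 'gamma': None},
#       'pl_checkpoint_monitor': 'val_loss',  # optional
#   }
#   PlTrainer(config).fit(model, train_dataloader=train_dl, val_dataloaders=val_dl)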
'_epoch': val for name, val in metric_res.items()} 136 | self.log_dict( 137 | val_metrics, on_step=False, on_epoch=True, prog_bar=True, logger=True, 138 | ) 139 | self.val_metric.reset() 140 | 141 | def configure_optimizers(self): 142 | return [self._optimizer], [get_lr_scheduler(self.config, self._optimizer)] 143 | 144 | 145 | class PlTrainer(object): 146 | """ 147 | 封装 PyTorch Lightning 的训练器。 148 | """ 149 | 150 | def __init__(self, config, ckpt_fn=None): 151 | self.config = config 152 | 153 | lr_monitor = LearningRateMonitor(logging_interval='step') 154 | callbacks = [lr_monitor] 155 | 156 | mode = self.config.get('pl_checkpoint_mode', 'min') 157 | monitor = self.config.get('pl_checkpoint_monitor') 158 | fn_fields = ckpt_fn or [] 159 | fn_fields.append('{epoch:03d}') 160 | if monitor: 161 | fn_fields.append('{' + monitor + ':.4f}') 162 | checkpoint_callback = ModelCheckpoint( 163 | monitor=monitor, 164 | mode=mode, 165 | filename='-'.join(fn_fields), 166 | save_last=True, 167 | save_top_k=5, 168 | ) 169 | callbacks.append(checkpoint_callback) 170 | 171 | self.pl_trainer = pl.Trainer( 172 | limit_train_batches=self.config.get('limit_train_batches', 1.0), 173 | limit_val_batches=self.config.get('limit_val_batches', 1.0), 174 | num_sanity_val_steps=2, 175 | log_every_n_steps=10, 176 | gpus=self.config.get('gpus'), 177 | max_epochs=self.config.get('epochs', 20), 178 | precision=self.config.get('precision', 32), 179 | callbacks=callbacks, 180 | stochastic_weight_avg=True, 181 | ) 182 | 183 | def fit( 184 | self, 185 | model: nn.Module, 186 | train_dataloader: Any = None, 187 | val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, 188 | datamodule: Optional[pl.LightningDataModule] = None, 189 | resume_from_checkpoint: Optional[Union[Path, str]] = None, 190 | ): 191 | r""" 192 | Runs the full optimization routine. 193 | 194 | Args: 195 | model: Model to fit. 196 | 197 | train_dataloader: Either a single PyTorch DataLoader or a collection of these 198 | (list, dict, nested lists and dicts). In the case of multiple dataloaders, please 199 | see this :ref:`page ` 200 | val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples. 201 | If the model has a predefined val_dataloaders method this will be skipped 202 | datamodule: A instance of :class:`LightningDataModule`. 203 | resume_from_checkpoint: Path/URL of the checkpoint from which training is resumed. If there is 204 | no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint, 205 | training will start from the beginning of the next epoch. 
206 | """ 207 | steps_per_epoch = ( 208 | len(train_dataloader) 209 | if train_dataloader is not None 210 | else len(datamodule.train_dataloader()) 211 | ) 212 | self.config['steps_per_epoch'] = steps_per_epoch 213 | if resume_from_checkpoint is not None: 214 | pl_module = WrapperLightningModule.load_from_checkpoint( 215 | resume_from_checkpoint, config=self.config, model=model 216 | ) 217 | self.pl_trainer = pl.Trainer(resume_from_checkpoint=resume_from_checkpoint) 218 | else: 219 | pl_module = WrapperLightningModule(self.config, model) 220 | 221 | self.pl_trainer.fit(pl_module, train_dataloader, val_dataloaders, datamodule) 222 | 223 | fields = self.pl_trainer.checkpoint_callback.best_model_path.rsplit( 224 | '.', maxsplit=1 225 | ) 226 | fields[0] += '-model' 227 | output_model_fp = '.'.join(fields) 228 | resave_model( 229 | self.pl_trainer.checkpoint_callback.best_model_path, output_model_fp 230 | ) 231 | self.saved_model_file = output_model_fp 232 | 233 | 234 | def resave_model(module_fp, output_model_fp, map_location=None): 235 | """PlTrainer存储的文件对应其 `pl_module` 模块,需利用此函数转存为 `model` 对应的模型文件。""" 236 | checkpoint = torch.load(module_fp, map_location=map_location) 237 | state_dict = {} 238 | for k, v in checkpoint['state_dict'].items(): 239 | state_dict[k.split('.', maxsplit=1)[1]] = v 240 | torch.save({'state_dict': state_dict}, output_model_fp) 241 | -------------------------------------------------------------------------------- /cnstd/datasets/util.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
19 | 20 | import numpy as np 21 | import random 22 | import cv2 23 | from shapely.geometry import Polygon 24 | import pyclipper 25 | import os 26 | 27 | 28 | def random_horizontal_flip(imgs): 29 | if random.random() < 0.5: 30 | for i in range(len(imgs)): 31 | imgs[i] = np.flip(imgs[i], axis=1).copy() 32 | return imgs 33 | 34 | 35 | def random_crop(imgs, img_size): 36 | """ 37 | 38 | :param imgs: 包含img和kernel 39 | :param img_size: 40 | :return: 41 | """ 42 | h, w = imgs[0].shape[0:2] 43 | th, tw = img_size 44 | if w == tw and h == th: 45 | return imgs 46 | 47 | if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0: 48 | tl = np.min(np.where(imgs[1] > 0), axis=1) - img_size 49 | tl[tl < 0] = 0 50 | br = np.max(np.where(imgs[1] > 0), axis=1) - img_size 51 | br[br < 0] = 0 52 | br[0] = min(br[0], h - th) 53 | br[1] = min(br[1], w - tw) 54 | 55 | i = random.randint(tl[0], br[0]) 56 | j = random.randint(tl[1], br[1]) 57 | else: 58 | i = random.randint(0, h - th) 59 | j = random.randint(0, w - tw) 60 | 61 | # return i, j, th, tw 62 | for idx in range(len(imgs)): 63 | if len(imgs[idx].shape) == 3: 64 | imgs[idx] = imgs[idx][i : i + th, j : j + tw, :] 65 | else: 66 | imgs[idx] = imgs[idx][i : i + th, j : j + tw] 67 | return imgs 68 | 69 | 70 | def random_rotate(imgs): 71 | angle = np.random.uniform(-10, 10) 72 | cols = imgs[0].shape[1] 73 | rows = imgs[0].shape[0] 74 | 75 | M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1) 76 | for idx in range(len(imgs)): 77 | imgs[idx] = cv2.warpAffine(imgs[idx], M, (cols, rows)) 78 | 79 | return imgs 80 | 81 | 82 | def poly_offset(img, poly, dis): 83 | subj_poly = np.array(poly) 84 | # Polygon(subj_poly).area, Polygon(subj_poly).length 85 | pco = pyclipper.PyclipperOffset() 86 | pco.AddPath(subj_poly, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) 87 | solution = pco.Execute(-1.0 * dis) 88 | ss = np.array(solution) 89 | cv2.fillPoly(img, ss.astype(np.int32), 1) 90 | return img 91 | 92 | 93 | def cal_offset(poly, r, max_shr=20): 94 | area, length = Polygon(poly).area, Polygon(poly).length 95 | r = r * r 96 | d = area * (1 - r) / (length + 0.005) + 0.5 97 | d = min(int(d), max_shr) 98 | return d 99 | 100 | 101 | def shrink_polys(img, polys, tags, mini_scale_ratio, num_kernels=6): 102 | h, w = img.shape[:2] 103 | f = lambda x: 1.0 - (1.0 - mini_scale_ratio) / (num_kernels - 1.0) * x 104 | r = [f(i + 1) for i in range(num_kernels)] 105 | training_mask = np.ones((h, w), dtype=np.float32) 106 | kernel_maps = np.zeros((h, w, num_kernels), dtype=np.float32) 107 | score_map = np.zeros((h, w), dtype=np.float32) 108 | for poly, tag in zip(polys, tags): 109 | poly = np.array(poly, dtype=np.float32).reshape((-1, 2)) 110 | cv2.fillPoly(score_map, poly.astype(np.int32)[np.newaxis, :, :], 1) 111 | 112 | for i, val in enumerate(r): 113 | tmp_score_map = np.zeros((h, w), dtype=np.float32) 114 | for poly, tag in zip(polys, tags): 115 | poly = np.array(poly, dtype=np.float32).reshape((-1, 2)) 116 | if tag: 117 | cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0) 118 | d = cal_offset(poly, val) 119 | tmp_score_map = poly_offset(tmp_score_map, poly, d) 120 | kernel_maps[:, :, i] = tmp_score_map 121 | # return [kernel_maps[:, :, i] for i in xrange(num_kernels)], training_mask 122 | return score_map, kernel_maps, training_mask 123 | 124 | 125 | def parse_lines(filename): 126 | with open(filename, 'r') as f: 127 | lines = f.readlines() 128 | 129 | # print(filename) 130 | text_polys = [] 131 | text_tags = [] 132 | if not os.path.exists(filename): 133 | 
        return np.array(text_polys, dtype=np.float32), np.array(text_tags, dtype=bool)
134 |     for line in lines:
135 |         line = line.strip('\n').split(',')
136 |         # print(line)
137 |         label = line[-1]
138 |         # strip BOM: \ufeff for Python 3, \xef\xbb\xbf for Python 2
139 |         line = [i.strip('\ufeff').strip('\xef\xbb\xbf') for i in line]
140 |         if 10 > len(line) > 7:
141 |             x1, y1, x2, y2, x3, y3, x4, y4 = list(map(float, line[:8]))
142 |             text_polys.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]])
143 |         elif 7 > len(line) > 3:
144 |             x0, y0, x1, y1 = list(map(float, line[:4]))
145 |             text_polys.append([[x0, y0], [x1, y0], [x1, y1], [x0, y1]])
146 |
147 |         else:
148 |             continue
149 |         if label == '*' or label == '###':
150 |             # "Do Not Care": probably text that is too small, or something else entirely
151 |             text_tags.append(True)
152 |         else:
153 |             text_tags.append(False)
154 |     return np.array(text_polys, dtype=np.float32), np.array(text_tags, dtype=bool)
155 |
156 |
157 | def scale(img, long_size=2240):
158 |     h, w = img.shape[0:2]
159 |     scale = long_size * 1.0 / max(h, w)
160 |     img = cv2.resize(img, dsize=None, fx=scale, fy=scale)
161 |     return img
162 |
163 |
164 | def random_scale(img, text_polys, min_side=640):
165 |     h, w = img.shape[:2]
166 |     scale = 1.0
167 |     if max(h, w) > 1280.0:
168 |         scale = 1280.0 / max(h, w)
169 |     img = cv2.resize(img, dsize=None, fx=scale, fy=scale)
170 |     if text_polys is not None:
171 |         text_polys *= scale
172 |         text_polys = np.array(text_polys)
173 |
174 |     h, w = img.shape[:2]
175 |     random_scale = np.array([0.5, 1.0, 2.0, 3.0])
176 |     scale = np.random.choice(random_scale)
177 |
178 |     if min(h, w) * scale < min_side:
179 |         scale = (min_side + 10) * 1.0 / min(h, w)
180 |     img = cv2.resize(img, dsize=None, fx=scale, fy=scale)
181 |     if text_polys is not None:
182 |         text_polys *= scale
183 |         text_polys = np.array(text_polys)
184 |     return img, text_polys
185 |
186 |
187 | def save_images(imgs):
188 |     for i, item in enumerate(imgs):
189 |         cv2.imwrite('img_{}.png'.format(i), item * 255)
190 |
191 |
192 | def dist(a, b):
193 |     return np.sqrt(np.sum((a - b) ** 2))
194 |
195 |
196 | def perimeter(bbox):
197 |     peri = 0.0
198 |     for i in range(bbox.shape[0]):
199 |         peri += dist(bbox[i], bbox[(i + 1) % bbox.shape[0]])
200 |     return peri
201 |
202 |
203 | def shrink(bboxes, rate, max_shr=20):
204 |     rate = rate * rate
205 |     shrinked_bboxes = []
206 |     for bbox in bboxes:
207 |         area = Polygon(bbox).area
208 |         peri = perimeter(bbox)
209 |
210 |         pco = pyclipper.PyclipperOffset()
211 |         pco.AddPath(bbox, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
212 |         offset = min(int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr)
213 |
214 |         shrinked_bbox = pco.Execute(-offset)
215 |         if len(shrinked_bbox) == 0:
216 |             shrinked_bboxes.append(bbox)
217 |             continue
218 |
219 |         shrinked_bbox = np.array(shrinked_bbox[0])
220 |         if shrinked_bbox.shape[0] <= 2:
221 |             shrinked_bboxes.append(bbox)
222 |             continue
223 |
224 |         shrinked_bboxes.append(shrinked_bbox)
225 |
226 |     return np.array(shrinked_bboxes)
227 |
228 |
229 | def process_data(image, bboxes, label, num_kernels=6):
230 |     img, bboxes = random_scale(image, bboxes)
231 |
232 |     gt_text = np.zeros(img.shape[0:2], dtype='uint8')
233 |     training_mask = np.ones(img.shape[0:2], dtype='uint8')
234 |     if bboxes.shape[0] > 0:
235 |         for i in range(bboxes.shape[0]):
236 |             cv2.drawContours(
237 |                 gt_text, bboxes[i][np.newaxis, :, :].astype(np.int32), -1, i + 1, -1
238 |             )
239 |             if label[i]:  # True means "Do Not Care": probably text that is too small, or something else entirely
240 |                 cv2.drawContours(
241 |                     training_mask,
242 |                     bboxes[i][np.newaxis, :, :].astype(np.int32),
243 |                     -1,
244 |                     0,
245 |                     -1,
246
| ) 247 | 248 | gt_kernals = [] 249 | f = lambda x: 1.0 - (1.0 - 0.5) / (num_kernels) * x 250 | rates = [f(i + 1) for i in range(num_kernels)] 251 | 252 | # from large kernel to small kernel 253 | for rate in rates: 254 | gt_kernal = np.zeros(img.shape[0:2], dtype='uint8') 255 | kernal_bboxes = shrink(bboxes, rate) 256 | for i in range(bboxes.shape[0]): 257 | cv2.drawContours( 258 | gt_kernal, 259 | kernal_bboxes[i][np.newaxis, :, :].astype(np.int32), 260 | -1, 261 | 1, 262 | -1, 263 | ) 264 | gt_kernals.append(gt_kernal) 265 | 266 | imgs = [img, gt_text, training_mask] 267 | imgs.extend(gt_kernals) 268 | 269 | imgs = random_horizontal_flip(imgs) 270 | imgs = random_rotate(imgs) 271 | imgs = random_crop(imgs, (640, 640)) 272 | 273 | img, gt_text, training_mask, gt_kernals = imgs[0], imgs[1], imgs[2], imgs[3:] 274 | 275 | gt_text[gt_text > 0] = 1 276 | gt_kernals = np.array(gt_kernals) 277 | 278 | return img, gt_text, gt_kernals, training_mask 279 | -------------------------------------------------------------------------------- /cnstd/consts.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021-2023, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
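# Usage sketch for `MODEL_CONFIGS` below (illustrative): each entry pairs a
# torchvision backbone constructor with the FPN layer names and channel counts
# used to assemble a DBNet-style detector.
#
#   cfg = MODEL_CONFIGS['db_resnet18']
#   backbone = cfg['backbone']()  # plain torchvision constructor
#   print(cfg['fpn_layers'], cfg['fpn_channels'], cfg['input_shape'])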
19 | 20 | import os 21 | import logging 22 | from pathlib import Path 23 | from typing import Tuple, Set, Dict, Any, Optional, Union 24 | from copy import deepcopy 25 | from collections import OrderedDict 26 | 27 | from torchvision.models import ( 28 | resnet50, 29 | resnet34, 30 | resnet18, 31 | mobilenet_v3_large, 32 | mobilenet_v3_small, 33 | shufflenet_v2_x1_0, 34 | shufflenet_v2_x1_5, 35 | shufflenet_v2_x2_0, 36 | ) 37 | 38 | from .__version__ import __version__ 39 | 40 | logger = logging.getLogger(__name__) 41 | 42 | 43 | # 模型版本只对应到第二层,第三层的改动表示模型兼容。 44 | # 如: __version__ = '1.0.*',对应的 MODEL_VERSION 都是 '1.0' 45 | MODEL_VERSION = '.'.join(__version__.split('.', maxsplit=2)[:2]) 46 | VOCAB_FP = Path(__file__).parent.parent / 'label_cn.txt' 47 | # Which OSS source will be used for downloading model files, 'CN' or 'HF' 48 | DOWNLOAD_SOURCE = os.environ.get('CNSTD_DOWNLOAD_SOURCE', 'HF') 49 | HF_ENDPOINT_LIST = ['https://huggingface.co', 'https://hf-mirror.com'] 50 | 51 | MODEL_CONFIGS: Dict[str, Dict[str, Any]] = { 52 | 'db_resnet50': { 53 | 'backbone': resnet50, 54 | 'backbone_submodule': None, 55 | 'fpn_layers': ['layer1', 'layer2', 'layer3', 'layer4'], 56 | 'fpn_channels': [256, 512, 1024, 2048], 57 | 'input_shape': (3, 768, 768), # resize后输入模型的图片大小, 即 `resized_shape` 58 | 'url': None, 59 | }, 60 | 'db_resnet34': { 61 | 'backbone': resnet34, 62 | 'backbone_submodule': None, 63 | 'fpn_layers': ['layer1', 'layer2', 'layer3', 'layer4'], 64 | 'fpn_channels': [64, 128, 256, 512], 65 | 'input_shape': (3, 768, 768), 66 | 'url': None, 67 | }, 68 | 'db_resnet18': { 69 | 'backbone': resnet18, 70 | 'backbone_submodule': None, 71 | 'fpn_layers': ['layer1', 'layer2', 'layer3', 'layer4'], 72 | 'fpn_channels': [64, 128, 256, 512], 73 | 'input_shape': (3, 768, 768), 74 | 'url': None, 75 | }, 76 | 'db_mobilenet_v3': { 77 | 'backbone': mobilenet_v3_large, 78 | 'backbone_submodule': 'features', 79 | 'fpn_layers': ['3', '6', '12', '16'], 80 | 'fpn_channels': [24, 40, 112, 960], 81 | 'input_shape': (3, 768, 768), 82 | 'url': None, 83 | }, 84 | 'db_mobilenet_v3_small': { 85 | 'backbone': mobilenet_v3_small, 86 | 'backbone_submodule': 'features', 87 | 'fpn_layers': ['1', '3', '8', '12'], 88 | 'fpn_channels': [16, 24, 48, 576], 89 | 'input_shape': (3, 768, 768), 90 | 'url': None, 91 | }, 92 | 'db_shufflenet_v2': { 93 | 'backbone': shufflenet_v2_x2_0, 94 | 'backbone_submodule': None, 95 | 'fpn_layers': ['maxpool', 'stage2', 'stage3', 'stage4'], 96 | 'fpn_channels': [24, 244, 488, 976], 97 | 'input_shape': (3, 768, 768), 98 | 'url': None, 99 | }, 100 | 'db_shufflenet_v2_small': { 101 | 'backbone': shufflenet_v2_x1_5, 102 | 'backbone_submodule': None, 103 | 'fpn_layers': ['maxpool', 'stage2', 'stage3', 'stage4'], 104 | 'fpn_channels': [24, 176, 352, 704], 105 | 'input_shape': (3, 768, 768), 106 | 'url': None, 107 | }, 108 | 'db_shufflenet_v2_tiny': { 109 | 'backbone': shufflenet_v2_x1_0, 110 | 'backbone_submodule': None, 111 | 'fpn_layers': ['maxpool', 'stage2', 'stage3', 'stage4'], 112 | 'fpn_channels': [24, 116, 232, 464], 113 | 'input_shape': (3, 768, 768), 114 | 'url': None, 115 | }, 116 | } 117 | 118 | HF_HUB_REPO_ID = "breezedeus/cnstd-cnocr-models" 119 | HF_HUB_SUBFOLDER = "models/cnstd/%s" % MODEL_VERSION 120 | CN_OSS_ENDPOINT = ( 121 | "https://sg-models.oss-cn-beijing.aliyuncs.com/cnstd/%s/" % MODEL_VERSION 122 | ) 123 | 124 | 125 | def format_hf_hub_url(url: str) -> dict: 126 | return { 127 | 'repo_id': HF_HUB_REPO_ID, 128 | 'subfolder': HF_HUB_SUBFOLDER, 129 | 'filename': url, 130 | 'cn_oss': 
CN_OSS_ENDPOINT, 131 | } 132 | 133 | 134 | class AvailableModels(object): 135 | CNSTD_SPACE = '__cnstd__' 136 | 137 | # name: (epochs, url) 138 | # 免费模型 139 | FREE_MODELS = OrderedDict( 140 | { 141 | ('db_resnet34', 'pytorch'): { 142 | 'model_epoch': 41, 143 | 'fpn_type': 'pan', 144 | 'url': 'db_resnet34-pan.zip', 145 | }, 146 | ('db_resnet18', 'pytorch'): { 147 | 'model_epoch': 34, 148 | 'fpn_type': 'pan', 149 | 'url': 'db_resnet18-pan.zip', 150 | }, 151 | ('db_mobilenet_v3', 'pytorch'): { 152 | 'model_epoch': 47, 153 | 'fpn_type': 'pan', 154 | 'url': 'db_mobilenet_v3-pan.zip', 155 | }, 156 | ('db_mobilenet_v3_small', 'pytorch'): { 157 | 'model_epoch': 37, 158 | 'fpn_type': 'pan', 159 | 'url': 'db_mobilenet_v3_small-pan.zip', 160 | }, 161 | ('db_shufflenet_v2', 'pytorch'): { 162 | 'model_epoch': 41, 163 | 'fpn_type': 'pan', 164 | 'url': 'db_shufflenet_v2-pan.zip', 165 | }, 166 | ('db_shufflenet_v2_small', 'pytorch'): { 167 | 'model_epoch': 34, 168 | 'fpn_type': 'pan', 169 | 'url': 'db_shufflenet_v2_small-pan.zip', 170 | }, 171 | } 172 | ) 173 | 174 | # 付费模型 175 | PAID_MODELS = OrderedDict( 176 | { 177 | # ('db_shufflenet_v2_tiny', 'pytorch'): { 178 | # 'model_epoch': 48, 179 | # 'fpn_type': 'pan', 180 | # 'url': 'db_shufflenet_v2_tiny-pan.zip', 181 | # }, 182 | } 183 | ) 184 | 185 | CNSTD_MODELS = deepcopy(FREE_MODELS) 186 | CNSTD_MODELS.update(PAID_MODELS) 187 | 188 | OUTER_MODELS = {} 189 | 190 | def all_models(self) -> Set[Tuple[str, str]]: 191 | return set(self.CNSTD_MODELS.keys()) | set(self.OUTER_MODELS.keys()) 192 | 193 | def __contains__(self, model_name_backend: Tuple[str, str]) -> bool: 194 | return model_name_backend in self.all_models() 195 | 196 | def register_models(self, model_dict: Dict[Tuple[str, str], Any], space: str): 197 | assert not space.startswith('__') 198 | for key, val in model_dict.items(): 199 | if key in self.CNSTD_MODELS or key in self.OUTER_MODELS: 200 | logger.warning( 201 | 'model %s has already existed, and will be ignored' % key 202 | ) 203 | continue 204 | val = deepcopy(val) 205 | val['space'] = space 206 | self.OUTER_MODELS[key] = val 207 | 208 | def get_space(self, model_name, model_backend) -> Optional[str]: 209 | if (model_name, model_backend) in self.CNSTD_MODELS: 210 | return self.CNSTD_SPACE 211 | elif (model_name, model_backend) in self.OUTER_MODELS: 212 | return self.OUTER_MODELS[(model_name, model_backend)]['space'] 213 | return None 214 | 215 | def get_value(self, model_name, model_backend, key) -> Optional[Any]: 216 | if (model_name, model_backend) in self.CNSTD_MODELS: 217 | info = self.CNSTD_MODELS[(model_name, model_backend)] 218 | elif (model_name, model_backend) in self.OUTER_MODELS: 219 | info = self.OUTER_MODELS[(model_name, model_backend)] 220 | else: 221 | logger.warning( 222 | 'no url is found for model %s' % ((model_name, model_backend),) 223 | ) 224 | return None 225 | return info.get(key) 226 | 227 | def get_epoch(self, model_name, model_backend) -> Optional[int]: 228 | return self.get_value(model_name, model_backend, 'model_epoch') 229 | 230 | def get_fpn_type(self, model_name, model_backend) -> Optional[int]: 231 | return self.get_value(model_name, model_backend, 'fpn_type') 232 | 233 | def get_url(self, model_name, model_backend) -> Optional[dict]: 234 | url = self.get_value(model_name, model_backend, 'url') 235 | if url: 236 | url = format_hf_hub_url(url) 237 | 238 | return url 239 | 240 | 241 | AVAILABLE_MODELS = AvailableModels() 242 | 243 | ANGLE_CLF_SPACE = 'angle_clf' 244 | ANGLE_CLF_MODELS = { 245 | 
('ch_ppocr_mobile_v2.0_cls', 'onnx'): { 246 | 'url': format_hf_hub_url('ch_ppocr_mobile_v2.0_cls_infer-onnx.zip') 247 | } 248 | } 249 | 250 | ANALYSIS_SPACE = 'analysis' 251 | ANALYSIS_MODELS = { 252 | 'layout': { 253 | ('yolov7_tiny', 'pytorch'): { 254 | 'url': format_hf_hub_url('yolov7_tiny_layout-pytorch.zip'), 255 | 'arch_yaml': Path(__file__).parent / 'yolov7' / 'yolov7-tiny-layout.yaml', 256 | } 257 | }, 258 | 'mfd': { 259 | ('yolov7_tiny', 'pytorch'): { 260 | 'url': format_hf_hub_url('yolov7_tiny_mfd-pytorch.zip'), 261 | 'arch_yaml': Path(__file__).parent / 'yolov7' / 'yolov7-tiny-mfd.yaml', 262 | }, 263 | ('yolov7', 'pytorch'): { 264 | 'url': format_hf_hub_url('yolov7_mfd-pytorch.zip'), 265 | 'arch_yaml': Path(__file__).parent / 'yolov7' / 'yolov7-mfd.yaml', 266 | }, 267 | }, 268 | } 269 | -------------------------------------------------------------------------------- /cnstd/utils/_utils.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus). 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 
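# Usage sketch for `extract_crops` below (illustrative; the image and boxes are
# synthetic). Boxes may be given as relative (xmin, ymin, xmax, ymax) in [0, 1]:
#
#   img = np.zeros((100, 200, 3), dtype=np.uint8)
#   boxes = np.array([[0.1, 0.2, 0.5, 0.6]], dtype=np.float32)
#   crops = extract_crops(img, boxes)  # -> list with one cropped patch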
19 | # Credits: adapted from https://github.com/mindee/doctr
20 |
21 | import numpy as np
22 | import cv2
23 | from math import floor
24 | from typing import List
25 | from statistics import median_low
26 |
27 | __all__ = ['estimate_orientation', 'extract_crops', 'extract_rcrops', 'rotate_page', 'get_bitmap_angle']
28 |
29 |
30 | def extract_crops(img: np.ndarray, boxes: np.ndarray) -> List[np.ndarray]:
31 |     """Create cropped images from a list of bounding boxes
32 |
33 |     Args:
34 |         img: input image
35 |         boxes: bounding boxes of shape (N, 4) where N is the number of boxes, and the relative
36 |             coordinates (xmin, ymin, xmax, ymax)
37 |
38 |     Returns:
39 |         list of cropped images
40 |     """
41 |     if boxes.shape[0] == 0:
42 |         return []
43 |     if boxes.shape[1] != 4:
44 |         raise AssertionError("boxes are expected to be relative and in order (xmin, ymin, xmax, ymax)")
45 |
46 |     # Project relative coordinates
47 |     _boxes = boxes.copy()
48 |     if 'float' in str(_boxes.dtype) and _boxes.max() <= 1.0:
49 |         _boxes[:, [0, 2]] *= img.shape[1]
50 |         _boxes[:, [1, 3]] *= img.shape[0]
51 |     _boxes = _boxes.round().astype(int)
52 |     # Expand by one pixel so the max coordinates are included when slicing
53 |     _boxes[:, 2:] += 1  # columns (xmax, ymax), not rows
54 |     _boxes[:, :2] -= 1  # columns (xmin, ymin), not rows
55 |     _boxes[_boxes < 0] = 0
56 |     return [img[box[1]: box[3], box[0]: box[2]] for box in _boxes]
57 |
58 |
59 | def extract_rcrops(img: np.ndarray, boxes: np.ndarray, dtype=np.float32) -> List[np.ndarray]:
60 |     """Create cropped images from a list of rotated bounding boxes
61 |
62 |     Args:
63 |         img: input image
64 |         boxes: bounding boxes of shape (N, 5) where N is the number of boxes, and the relative
65 |             coordinates (x, y, w, h, alpha)
66 |
67 |     Returns:
68 |         list of cropped images
69 |     """
70 |     if boxes.shape[0] == 0:
71 |         return []
72 |     if boxes.shape[1] != 5:
73 |         raise AssertionError("boxes are expected to be relative and in order (x, y, w, h, alpha)")
74 |
75 |     # Project relative coordinates
76 |     _boxes = boxes.copy()
77 |     if 'float' in str(_boxes.dtype) and _boxes[:, 0:4].max() <= 1.0:
78 |         _boxes[:, [0, 2]] *= img.shape[1]
79 |         _boxes[:, [1, 3]] *= img.shape[0]
80 |
81 |     crops = []
82 |     for box in _boxes:
83 |         x, y, w, h, alpha = box.astype(dtype)
84 |         vertical_box = False
85 |         if (abs(alpha) < 3 and w * 1.3 < h) or (90 - abs(alpha) < 3 and w > h * 1.3):
86 |             vertical_box = True
87 |
88 |         process_func = _process_vertical_box if vertical_box else _process_horizontal_box
89 |         crop = process_func(img, box, dtype)
90 |
91 |         crops.append(crop)
92 |
93 |     return crops
94 |
95 |
96 | def _process_horizontal_box(img, box, dtype):
97 |     x, y, w, h, alpha = box.astype(dtype)
98 |     if alpha > 80 and w < h:  # for opencv-python >= 4.5.2
99 |         alpha -= 90
100 |         w, h = h, w
101 |     clockwise = False
102 |     if w > h:
103 |         clockwise = True
104 |     if clockwise:
105 |         # 1 -------- 2
106 |         # |          |
107 |         # * -------- 3
108 |         dst_pts = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1]], dtype=dtype)
109 |     else:
110 |         # * -------- 1
111 |         # |          |
112 |         # 3 -------- 2
113 |         # dst_pts = np.array([[h - 1, 0], [h - 1, w - 1], [0, w - 1]], dtype=dtype)
114 |         # 2 -------- 3
115 |         # |          |
116 |         # 1 -------- *
117 |         dst_pts = np.array([[0, w - 1], [0, 0], [h - 1, 0]], dtype=dtype)
118 |     # The transformation matrix
119 |     src_pts = cv2.boxPoints(((x, y), (w, h), alpha))
120 |     M = cv2.getAffineTransform(src_pts[1:, :], dst_pts)
121 |     # Warp the rotated rectangle
122 |     if clockwise:
123 |         crop = cv2.warpAffine(img, M, (int(w), int(h)))
124 |     else:
125 |         crop = cv2.warpAffine(img, M, (int(h), int(w)))
126 |     return crop
127 |
128 |
129 | def _process_vertical_box(img, box, dtype):
130 |     x, y, w, h, alpha = box.astype(dtype)
131 |     clockwise = False
132 |     if w > h:
133 |         clockwise = True
134 |     if clockwise:
135 |         # 2 ------- 3
136 |         # |         |
137 |         # |         |
138 |         # |         |
139 |         # |         |
140 |         # 1 ------- *
141 |         dst_pts = np.array([[0, w - 1], [0, 0], [h - 1, 0]], dtype=dtype)
142 |         # dst_pts = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1]], dtype=dtype)
143 |     else:
144 |         # 1 ------- 2
145 |         # |         |
146 |         # |         |
147 |         # |         |
148 |         # |         |
149 |         # * ------- 3
150 |         dst_pts = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1]], dtype=dtype)
151 |     # The transformation matrix
152 |     src_pts = cv2.boxPoints(((x, y), (w, h), alpha))
153 |     M = cv2.getAffineTransform(src_pts[1:, :], dst_pts)
154 |     # Warp the rotated rectangle
155 |     if clockwise:
156 |         crop = cv2.warpAffine(img, M, (int(h), int(w)))
157 |     else:
158 |         crop = cv2.warpAffine(img, M, (int(w), int(h)))
159 |     return crop
160 |
161 |
162 | def rotate_page(
163 |     image: np.ndarray,
164 |     angle: float = 0.,
165 |     min_angle: float = 1.
166 | ) -> np.ndarray:
167 |     """Rotate an image counterclockwise by an angle alpha (negative angle to go clockwise).
168 |
169 |     Args:
170 |         image: numpy tensor to rotate
171 |         angle: rotation angle in degrees, between -90 and +90
172 |         min_angle: min. angle in degrees to rotate a page
173 |
174 |     Returns:
175 |         Rotated array, padded with 0 by default.
176 |     """
177 |     if abs(angle) < min_angle or abs(angle) > 90 - min_angle:
178 |         return image
179 |
180 |     height, width = image.shape[:2]
181 |     center = (width / 2, height / 2)  # cv2 expects the center as (x, y)
182 |     rot_mat = cv2.getRotationMatrix2D(center, angle, 1.0)
183 |     return cv2.warpAffine(image, rot_mat, (width, height))
184 |
185 |
186 | def get_max_width_length_ratio(contour: np.ndarray) -> float:
187 |     """
188 |     Get the maximum shape ratio of a contour.
189 |     Args:
190 |         contour: the contour from cv2.findContours
191 |
192 |     Returns: the maximum shape ratio
193 |
194 |     """
195 |     _, (w, h), _ = cv2.minAreaRect(contour)
196 |     return max(w / h, h / w)
197 |
198 |
199 | def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_lines: float = 5) -> float:
200 |     """Estimate the angle of the general document orientation based on the
201 |     lines of the document and the assumption that they should be horizontal.
202 |
203 |     Args:
204 |         img: the image to analyze
205 |         n_ct: the number of contours used for the orientation estimation
206 |         ratio_threshold_for_lines: the w/h ratio used to discriminate lines
207 |     Returns:
208 |         the angle of the general document orientation
209 |     """
210 |     gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
211 |     gray_img = cv2.medianBlur(gray_img, 5)
212 |     thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
213 |
214 |     # try to merge words into lines
215 |     (h, w) = img.shape[:2]
216 |     k_x = max(1, (floor(w / 100)))
217 |     k_y = max(1, (floor(h / 100)))
218 |     kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y))
219 |     thresh = cv2.dilate(thresh, kernel, iterations=1)
220 |
221 |     # extract contours
222 |     contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
223 |
224 |     # Sort contours
225 |     contours = sorted(contours, key=get_max_width_length_ratio, reverse=True)
226 |
227 |     angles = []
228 |     for contour in contours[:n_ct]:
229 |         _, (w, h), angle = cv2.minAreaRect(contour)
230 |         if w / h > ratio_threshold_for_lines:  # select only contours with a line-like ratio
231 |             angles.append(angle)
232 |         elif w / h < 1 / ratio_threshold_for_lines:  # if lines are vertical, subtract 90 degrees
233 |             angles.append(angle - 90)
234 |     return -median_low(angles)
235 |
236 |
237 | def get_bitmap_angle(bitmap: np.ndarray, n_ct: int = 20, std_max: float = 3.) -> float:
238 |     """From a binarized segmentation map, find contours and fit min area rectangles to determine the page angle
239 |
240 |     Args:
241 |         bitmap: binarized segmentation map
242 |         n_ct: number of contours to use to fit the page angle
243 |         std_max: maximum deviation of the angle distribution to consider the mean angle reliable
244 |
245 |     Returns:
246 |         The angle of the page
247 |     """
248 |     # Find all contours on the binarized seg map
249 |     contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
250 |     # Sort contours
251 |     contours = sorted(contours, key=cv2.contourArea, reverse=True)
252 |
253 |     # Find the largest contours and fit angles
254 |     # Track heights and widths to find the aspect ratio (to determine if rotation is clockwise)
255 |     angles, heights, widths = [], [], []
256 |     for ct in contours[:n_ct]:
257 |         _, (w, h), alpha = cv2.minAreaRect(ct)
258 |         widths.append(w)
259 |         heights.append(h)
260 |         angles.append(alpha)
261 |
262 |     if np.std(angles) > std_max:
263 |         # Edge case with angles of both 0 and 90°, or multi_oriented docs
264 |         angle = 0.
265 |     else:
266 |         angle = -np.mean(angles)
267 |         # Determine rotation direction (clockwise/counterclockwise)
268 |         # Angle coverage: [-90°, +90°], half of the quadrant
269 |         if np.sum(widths) < np.sum(heights):  # CounterClockwise
270 |             angle = 90 + angle
271 |
272 |     return angle
273 |
--------------------------------------------------------------------------------
/cnstd/ppocr/rapid_detector.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | # Copyright (C) 2024, [Breezedeus](https://github.com/breezedeus).
3 | # Licensed to the Apache Software Foundation (ASF) under one
4 | # or more contributor license agreements. See the NOTICE file
5 | # distributed with this work for additional information
6 | # regarding copyright ownership. The ASF licenses this file
7 | # to you under the Apache License, Version 2.0 (the
8 | # "License"); you may not use this file except in compliance
9 | # with the License.
You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | import os 21 | import logging 22 | from pathlib import Path 23 | from copy import deepcopy 24 | from typing import Tuple, List, Dict, Union, Any, Optional 25 | 26 | import numpy as np 27 | from PIL import Image 28 | import cv2 29 | # from rapidocr_onnxruntime.ch_ppocr_det import TextDetector 30 | # from rapidocr_onnxruntime import RapidOCR 31 | from rapidocr import EngineType, LangDet, ModelType, OCRVersion 32 | from rapidocr.utils.typings import TaskType 33 | from rapidocr.ch_ppocr_det import TextDetector 34 | 35 | from ..consts import AVAILABLE_MODELS, MODEL_VERSION 36 | from ..utils import read_img, data_dir, prepare_model_files 37 | from .utility import get_rotate_crop_image 38 | from .consts import PP_SPACE 39 | 40 | logger = logging.getLogger(__name__) 41 | 42 | class Config(dict): 43 | DEFAULT_CFG = { 44 | "engine_type": EngineType.ONNXRUNTIME, 45 | "lang_type": LangDet.CH, 46 | "model_type": ModelType.SERVER, 47 | "ocr_version": OCRVersion.PPOCRV5, 48 | "task_type": TaskType.DET, 49 | "model_path": None, 50 | "model_dir": None, 51 | "limit_side_len": 736, 52 | "limit_type": "min", 53 | "std": [0.5, 0.5, 0.5], 54 | "mean": [0.5, 0.5, 0.5], 55 | "thresh": 0.3, 56 | "box_thresh": 0.5, 57 | "max_candidates": 1000, 58 | "unclip_ratio": 1.6, 59 | "use_dilation": True, 60 | "score_mode": "fast", 61 | "engine_cfg": { 62 | "intra_op_num_threads": -1, 63 | "inter_op_num_threads": -1, 64 | "enable_cpu_mem_arena": False, 65 | "cpu_ep_cfg": {"arena_extend_strategy": "kSameAsRequested"}, 66 | "use_cuda": False, 67 | "cuda_ep_cfg": { 68 | "device_id": 0, 69 | "arena_extend_strategy": "kNextPowerOfTwo", 70 | "cudnn_conv_algo_search": "EXHAUSTIVE", 71 | "do_copy_in_default_stream": True, 72 | }, 73 | "use_dml": False, 74 | "dm_ep_cfg": None, 75 | "use_cann": False, 76 | "cann_ep_cfg": { 77 | "device_id": 0, 78 | "arena_extend_strategy": "kNextPowerOfTwo", 79 | "npu_mem_limit": 21474836480, 80 | "op_select_impl_mode": "high_performance", 81 | "optypelist_for_implmode": "Gelu", 82 | "enable_cann_graph": True, 83 | }, 84 | }, 85 | } 86 | 87 | def __init__(self, *args, **kwargs): 88 | super().__init__() 89 | data = dict(*args, **kwargs) 90 | for k, v in data.items(): 91 | if isinstance(v, dict): 92 | v = Config(v) 93 | self[k] = v 94 | 95 | def __getattr__(self, name): 96 | try: 97 | return self[name] 98 | except KeyError: 99 | raise AttributeError(name) 100 | 101 | def __setattr__(self, name, value): 102 | self[name] = value 103 | 104 | 105 | class RapidDetector(object): 106 | """ 107 | 场景文字检测器(Scene Text Detection),使用 rapidocr_onnxruntime 中的 TextDetector。 108 | """ 109 | 110 | def __init__( 111 | self, 112 | model_name: str = 'ch_PP-OCRv5_det', 113 | *, 114 | model_fp: Optional[str] = None, 115 | root: Union[str, Path] = data_dir(), 116 | context: str = 'cpu', # ['cpu', 'gpu'] 117 | limit_side_len: int = 736, 118 | limit_type: str = "min", 119 | thresh: float = 0.3, 120 | box_thresh: float = 0.5, 121 | max_candidates: int = 1000, 122 | unclip_ratio: float = 1.6, 123 | use_dilation: bool = True, 124 | score_mode: str = "fast", 125 | **kwargs, 126 | 
    ):
127 |         """
128 |         Args:
129 |             model_name: model name, e.g. 'ch_PP-OCRv5_det' (the default)
130 |             model_fp: to bypass the bundled models, point this directly at the model file to use (an '.onnx' file)
131 |             root: root directory for the model files; defaults to `~/.cnstd`
132 |             context: device to run on, 'cpu' or 'gpu'; defaults to 'cpu'
133 |             limit_side_len: limit on the longest side of the image; defaults to 736
134 |             limit_type: limit type, 'min' or 'max'; defaults to 'min'
135 |             thresh: binarization threshold; defaults to 0.3
136 |             box_thresh: text-box threshold; defaults to 0.5
137 |             max_candidates: maximum number of candidate boxes; defaults to 1000
138 |             unclip_ratio: expansion ratio for text boxes; defaults to 1.6
139 |             use_dilation: whether to apply dilation; defaults to True
140 |             score_mode: scoring mode, 'fast' or 'slow'; defaults to 'fast'
141 |             kwargs: other parameters
142 |         """
143 |         self._model_name = model_name
144 |         self._model_backend = 'onnx'
145 |         self._assert_and_prepare_model_files(model_fp, root)
146 |         use_gpu = context.lower() not in ('cpu', 'mps')
147 |
148 |         config = deepcopy(Config.DEFAULT_CFG)  # copy, so the shared class-level defaults are not mutated
149 |         config["engine_cfg"]["use_cuda"] = use_gpu
150 |         if "engine_cfg" in kwargs:
151 |             config["engine_cfg"].update(kwargs["engine_cfg"])
152 |
153 |         config.update({
154 |             "limit_side_len": limit_side_len,
155 |             "limit_type": limit_type,
156 |             "thresh": thresh,
157 |             "box_thresh": box_thresh,
158 |             "max_candidates": max_candidates,
159 |             "unclip_ratio": unclip_ratio,
160 |             "use_dilation": use_dilation,
161 |             "score_mode": score_mode,
162 |             "model_path": self._model_fp,
163 |         })
164 |         # derive model_type and ocr_version from model_name
165 |         config["model_type"] = ModelType.SERVER if "server" in model_name else ModelType.MOBILE
166 |         config["ocr_version"] = OCRVersion.PPOCRV5 if "v5" in model_name else OCRVersion.PPOCRV4
167 |
168 |         config = Config(config)
169 |         self._detector = TextDetector(config)
170 |
171 |     def _assert_and_prepare_model_files(self, model_fp, root):
172 |         if model_fp is not None and not os.path.isfile(model_fp):
173 |             raise FileNotFoundError('cannot find model file %s' % model_fp)
174 |
175 |         if model_fp is not None:
176 |             self._model_fp = model_fp
177 |             return
178 |
179 |         root = os.path.join(root, MODEL_VERSION)
180 |         self._model_dir = os.path.join(root, PP_SPACE, self._model_name)
181 |         model_fp = os.path.join(self._model_dir, '%s_infer.onnx' % self._model_name)
182 |         if not os.path.isfile(model_fp):
183 |             logger.warning('cannot find model file %s' % model_fp)
184 |             if (self._model_name, self._model_backend) not in AVAILABLE_MODELS:
185 |                 raise NotImplementedError(
186 |                     '%s is not a downloadable model'
187 |                     % ((self._model_name, self._model_backend),)
188 |                 )
189 |             remote_repo = AVAILABLE_MODELS.get_value(self._model_name, self._model_backend, 'repo')
190 |             model_fp = prepare_model_files(model_fp, remote_repo)
191 |
192 |         self._model_fp = model_fp
193 |         logger.info('use model: %s' % self._model_fp)
194 |
195 |     def detect(
196 |         self,
197 |         img_list: Union[
198 |             str,
199 |             Path,
200 |             Image.Image,
201 |             np.ndarray,
202 |             List[Union[str, Path, Image.Image, np.ndarray]],
203 |         ],
204 |         **kwargs,
205 |     ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
206 |         """
207 |         Detect texts in the given image(s).
208 |         Args:
209 |             img_list: a single image or a list of images. Each value can be an image path, or an already-loaded
210 |                 PIL.Image.Image or np.ndarray in RGB format with 3 channels, shape: (height, width, 3), values: [0, 255]
211 |             kwargs: other parameters; currently unused
212 |
213 |         Returns:
214 |             List[Dict], one Dict per image, with the following keys:
215 |             * 'rotated_angle': float, the angle by which the whole image was rotated; can only be non-zero when auto_rotate_whole_image==True.
216 |             * 'detected_texts': list, one entry per detected box, each a dict with the following values:
217 |                 'box': the rectangle around the detected text; np.ndarray, shape: (4, 2), the (x, y) coordinates of the box's 4 corners;
218 |                 'score': confidence score; float; the higher, the more reliable;
219 |                 'cropped_img': the image patch (RGB) for 'box', with tilted patches rotated to horizontal.
220 |                     np.ndarray, shape: (height, width, 3), values: [0, 255];
221 |         """
222 |         single = False
223 |         if isinstance(img_list, (list, tuple)):
224 |             pass
225 |         elif isinstance(img_list, (str, Path, Image.Image, np.ndarray)):
226 |             img_list = [img_list]
227 |             single = True
228 |         else:
229 |             raise TypeError('type %s is not supported now' % str(type(img_list)))
230 |
231 |         out = []
232 |         for img in img_list:
233 |             if isinstance(img, (str, Path)):
234 |                 if not os.path.isfile(img):
235 |                     raise FileNotFoundError(img)
236 |                 img = read_img(img)
237 |             if isinstance(img, Image.Image):
238 |                 img = np.array(img)
239 |
240 |             if not isinstance(img, np.ndarray):
241 |                 raise TypeError('type %s is not supported now' % str(type(img)))
242 |
243 |             # rapidocr expects BGR images
244 |             if len(img.shape) == 3 and img.shape[2] == 3:
245 |                 img = img[..., ::-1]  # RGB to BGR
246 |
247 |             det_out = self._detector(img)
248 |             if det_out is None or det_out.boxes is None or len(det_out.boxes) < 1:
249 |                 out.append({
250 |                     'rotated_angle': 0.0,  # rapidocr does not support auto-rotation
251 |                     'detected_texts': [],
252 |                 })
253 |                 continue
254 |
255 |             # boxes = self._detector.sorted_boxes(boxes)
256 |
257 |             # build the returned results
258 |             detected_texts = []
259 |             for box, score in zip(det_out.boxes, det_out.scores):
260 |                 box = np.array(box).astype(np.int32)
261 |                 img_crop = get_rotate_crop_image(img, deepcopy(box))
262 |                 img_crop = cv2.cvtColor(img_crop, cv2.COLOR_BGR2RGB)
263 |                 detected_texts.append({
264 |                     'box': box,
265 |                     'score': score,
266 |                     'cropped_img': img_crop.astype('uint8'),
267 |                 })
268 |
269 |             out.append({
270 |                 'rotated_angle': 0.0,  # rapidocr does not support auto-rotation
271 |                 'detected_texts': detected_texts,
272 |             })
273 |
274 |         return out[0] if single else out
275 |
--------------------------------------------------------------------------------
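# Usage sketch for `RapidDetector` above (illustrative; model files are fetched
# on first use, and the example image ships with this repo):
#
#   det = RapidDetector(model_name='ch_PP-OCRv5_det', context='cpu')
#   result = det.detect('examples/taobao.jpg')
#   for info in result['detected_texts']:
#       print(info['box'].tolist(), info['score'])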