├── .gitignore ├── docs ├── error │ ├── missDll.png │ └── index.md ├── test_images │ ├── test_1.png │ ├── test_2.png │ ├── test_3.png │ └── test_4.png ├── staticLinking │ └── index.md └── README_zh-Hans.md ├── src ├── ocr_error.rs ├── base_net.rs ├── scale_param.rs ├── ocr_result.rs ├── angle_net.rs ├── ocr_utils.rs ├── crnn_net.rs ├── lib.rs ├── ocr_lite.rs └── db_net.rs ├── Cargo.toml ├── README.md └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | target 2 | Cargo.lock 3 | test 4 | models 5 | test_output -------------------------------------------------------------------------------- /docs/error/missDll.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mg-chao/paddle-ocr-rs/HEAD/docs/error/missDll.png -------------------------------------------------------------------------------- /docs/test_images/test_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mg-chao/paddle-ocr-rs/HEAD/docs/test_images/test_1.png -------------------------------------------------------------------------------- /docs/test_images/test_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mg-chao/paddle-ocr-rs/HEAD/docs/test_images/test_2.png -------------------------------------------------------------------------------- /docs/test_images/test_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mg-chao/paddle-ocr-rs/HEAD/docs/test_images/test_3.png -------------------------------------------------------------------------------- /docs/test_images/test_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mg-chao/paddle-ocr-rs/HEAD/docs/test_images/test_4.png 
-------------------------------------------------------------------------------- /docs/error/index.md: -------------------------------------------------------------------------------- 1 | ## 问题处理 2 | 3 | ### 缺少 *.dll 文件 4 | 5 | ![缺少 DLL 文件](./missDll.png) 6 | 7 | 参考[静态链接](/docs/staticLinking/index.md)文档 -------------------------------------------------------------------------------- /src/ocr_error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | #[derive(Error, Debug)] 4 | pub enum OcrError { 5 | #[error("Ort error")] 6 | Ort(#[from] ort::Error), 7 | #[error("Io error")] 8 | Io(#[from] std::io::Error), 9 | #[error("Session not initialized")] 10 | ImageError(#[from] image::ImageError), 11 | #[error("Image error")] 12 | SessionNotInitialized, 13 | } 14 | -------------------------------------------------------------------------------- /docs/staticLinking/index.md: -------------------------------------------------------------------------------- 1 | ## 静态链接 2 | 3 | 主要是处理 [ort](https://ort.pyke.io/) 的静态链接,[参考文档](https://ort.pyke.io/setup/linking#static-linking)。 4 | 5 | ### Windows 示例: 6 | 7 | 在 [onnxruntime-build](https://github.com/supertone-inc/onnxruntime-build/releases) 下载 Windows 平台的 lib。 8 | 9 | 解压后将 lib 文件夹放到项目根目录下(注意是解压包里的 lib 文件夹),然后配置 .cargo/config.toml: 10 | 11 | ```toml 12 | [target.x86_64-pc-windows-msvc] 13 | rustflags = ["-Ctarget-feature=+crt-static"] 14 | [target.i686-pc-windows-msvc] 15 | rustflags = ["-Ctarget-feature=+crt-static"] 16 | 17 | [env] 18 | ORT_LIB_LOCATION = "./lib" 19 | ``` 20 | 21 | 环境变量 `ORT_LIB_LOCATION` 可以自由配置,路径应指向文件夹。 22 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "paddle-ocr-rs" 3 | version = "0.6.1" 4 | edition = "2024" 5 | readme = "README.md" 6 | keywords = ["paddle", "ocr", "onnx"] 7 | license = 
"Apache-2.0" 8 | description = "Use Rust to call Paddle OCR models via ONNX Runtime for image text recognition." 9 | repository = "https://github.com/mg-chao/paddle-ocr-rs" 10 | authors = ["mg-chao "] 11 | 12 | [dependencies] 13 | serde_json = "1" 14 | serde = { version = "1", features = ["derive"] } 15 | imageproc = "0.25" 16 | ndarray = "0.16" 17 | ort = { version = "2.0.0-rc.10", default-features = false, features = [ 18 | "ndarray", 19 | ] } 20 | geo-types = "0.7" 21 | geo-clipper = "0.9" 22 | clipper-sys = "0.8" 23 | thiserror = "2" 24 | image = "0.25" 25 | 26 | [features] 27 | default = ["download-binaries", "copy-dylibs", "ort/default"] 28 | download-binaries = ["ort/download-binaries"] 29 | copy-dylibs = ["ort/copy-dylibs"] 30 | -------------------------------------------------------------------------------- /src/base_net.rs: -------------------------------------------------------------------------------- 1 | use ort::session::{ 2 | builder::{GraphOptimizationLevel, SessionBuilder}, 3 | Session, 4 | }; 5 | 6 | use crate::ocr_error::OcrError; 7 | 8 | pub trait BaseNet { 9 | fn new() -> Self; 10 | 11 | fn get_session_builder( 12 | &self, 13 | num_thread: usize, 14 | builder_fn: Option Result>, 15 | ) -> Result { 16 | let builder = Session::builder()?; 17 | let builder = match builder_fn { 18 | Some(custom) => custom(builder)?, 19 | None => builder 20 | .with_optimization_level(GraphOptimizationLevel::Level2)? 21 | .with_intra_threads(num_thread)? 
22 | .with_inter_threads(num_thread)?, 23 | }; 24 | 25 | Ok(builder) 26 | } 27 | 28 | fn set_input_names(&mut self, input_names: Vec); 29 | fn set_session(&mut self, session: Option); 30 | 31 | fn init(&mut self, session: Session) { 32 | let input_names: Vec = session 33 | .inputs 34 | .iter() 35 | .map(|input| input.name.clone()) 36 | .collect(); 37 | 38 | self.set_input_names(input_names); 39 | self.set_session(Some(session)); 40 | } 41 | 42 | fn init_model( 43 | &mut self, 44 | path: &str, 45 | num_thread: usize, 46 | builder_fn: Option Result>, 47 | ) -> Result<(), OcrError> { 48 | let session = self 49 | .get_session_builder(num_thread, builder_fn)? 50 | .commit_from_file(path)?; 51 | self.init(session); 52 | 53 | Ok(()) 54 | } 55 | 56 | fn init_model_from_memory( 57 | &mut self, 58 | model_bytes: &[u8], 59 | num_thread: usize, 60 | builder_fn: Option Result>, 61 | ) -> Result<(), OcrError> { 62 | let session = self 63 | .get_session_builder(num_thread, builder_fn)? 64 | .commit_from_memory(model_bytes)?; 65 | 66 | self.init(session); 67 | 68 | Ok(()) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /src/scale_param.rs: -------------------------------------------------------------------------------- 1 | #[derive(Debug)] 2 | pub struct ScaleParam { 3 | pub src_width: u32, 4 | pub src_height: u32, 5 | pub dst_width: u32, 6 | pub dst_height: u32, 7 | pub scale_width: f32, 8 | pub scale_height: f32, 9 | } 10 | 11 | impl ScaleParam { 12 | pub fn new( 13 | src_width: u32, 14 | src_height: u32, 15 | dst_width: u32, 16 | dst_height: u32, 17 | scale_width: f32, 18 | scale_height: f32, 19 | ) -> Self { 20 | Self { 21 | src_width, 22 | src_height, 23 | dst_width, 24 | dst_height, 25 | scale_width, 26 | scale_height, 27 | } 28 | } 29 | 30 | pub fn get_scale_param(src: &image::RgbImage, target_size: u32) -> Self { 31 | let src_width = src.width(); 32 | let src_height = src.height(); 33 | let mut dst_width; 34 | let mut 
/// Resize parameters computed for the detection net's input.
///
/// Holds the original (`src_*`) and resized (`dst_*`) dimensions, plus the
/// effective per-axis scale factors (`dst / src`) used to map detection
/// results back onto the source image.
#[derive(Debug)]
pub struct ScaleParam {
    pub src_width: u32,
    pub src_height: u32,
    pub dst_width: u32,
    pub dst_height: u32,
    pub scale_width: f32,
    pub scale_height: f32,
}

impl ScaleParam {
    /// Plain field-by-field constructor.
    pub fn new(
        src_width: u32,
        src_height: u32,
        dst_width: u32,
        dst_height: u32,
        scale_width: f32,
        scale_height: f32,
    ) -> Self {
        Self {
            src_width,
            src_height,
            dst_width,
            dst_height,
            scale_width,
            scale_height,
        }
    }

    /// Compute resize parameters for `src` so that its longer side scales
    /// towards `target_size`, then round each destination dimension down to
    /// a multiple of 32 with a floor of 32.
    ///
    /// NOTE(review): the multiple-of-32 constraint presumably matches the
    /// detection model's stride requirement — confirm against db_net.
    pub fn get_scale_param(src: &image::RgbImage, target_size: u32) -> Self {
        let src_width = src.width();
        let src_height = src.height();
        let mut dst_width;
        let mut dst_height;

        // Scale by the longer side so the whole image fits target_size.
        let ratio: f32 = if src_width > src_height {
            target_size as f32 / src_width as f32
        } else {
            target_size as f32 / src_height as f32
        };

        dst_width = (src_width as f32 * ratio) as u32;
        dst_height = (src_height as f32 * ratio) as u32;

        // Round down to a multiple of 32, clamped to at least 32.
        if dst_width % 32 != 0 {
            dst_width = (dst_width / 32) * 32;
            dst_width = dst_width.max(32);
        }
        if dst_height % 32 != 0 {
            dst_height = (dst_height / 32) * 32;
            dst_height = dst_height.max(32);
        }

        // Effective scale actually applied after the rounding above
        // (may differ slightly from `ratio`).
        let scale_width = dst_width as f32 / src_width as f32;
        let scale_height = dst_height as f32 / src_height as f32;

        Self::new(
            src_width,
            src_height,
            dst_width,
            dst_height,
            scale_width,
            scale_height,
        )
    }
}

impl std::fmt::Display for ScaleParam {
    // Compact single-line rendering of all six fields, used for logging.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "src_width:{},src_height:{},dst_width:{},dst_height:{},scale_width:{},scale_height:{}",
            self.src_width,
            self.src_height,
            self.dst_width,
            self.dst_height,
            self.scale_width,
            self.scale_height
        )
    }
}
self.points[1].x, 26 | self.points[1].y, 27 | self.points[2].x, 28 | self.points[2].y, 29 | self.points[3].x, 30 | self.points[3].y, 31 | ) 32 | } 33 | } 34 | 35 | #[derive(Debug, Default)] 36 | pub struct Angle { 37 | pub index: i32, 38 | pub score: f32, 39 | } 40 | 41 | impl fmt::Display for Angle { 42 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 43 | let header = if self.index >= 0 { 44 | "Angle" 45 | } else { 46 | "AngleDisabled" 47 | }; 48 | write!( 49 | f, 50 | "{}[Index({}), Score({})]", 51 | header, self.index, self.score 52 | ) 53 | } 54 | } 55 | 56 | #[derive(Debug, Default)] 57 | pub struct TextLine { 58 | pub text: String, 59 | pub text_score: f32, 60 | } 61 | 62 | impl fmt::Display for TextLine { 63 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 64 | write!( 65 | f, 66 | "TextLine[Text({}),TextScore({})]", 67 | self.text, self.text_score 68 | ) 69 | } 70 | } 71 | 72 | #[derive(Debug, Serialize, Deserialize)] 73 | pub struct TextBlock { 74 | pub box_points: Vec, 75 | pub box_score: f32, 76 | 77 | pub angle_index: i32, 78 | pub angle_score: f32, 79 | 80 | pub text: String, 81 | pub text_score: f32, 82 | } 83 | 84 | #[derive(Serialize, Deserialize)] 85 | pub struct OcrResult { 86 | pub text_blocks: Vec, 87 | } 88 | 89 | impl fmt::Display for OcrResult { 90 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 91 | let mut str_builder = String::with_capacity(0); 92 | for text_block in &self.text_blocks { 93 | write!( 94 | str_builder, 95 | "TextBlock[BoxPointsLen({}), BoxScore({}), AngleIndex({}), AngleScore({}), Text({}), TextScore({})]", 96 | text_block.box_points.len(), 97 | text_block.box_score, 98 | text_block.angle_index, 99 | text_block.angle_score, 100 | text_block.text, 101 | text_block.text_score 102 | )?; 103 | } 104 | f.write_str(&str_builder) 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /src/angle_net.rs: 
-------------------------------------------------------------------------------- 1 | use crate::{base_net::BaseNet, ocr_error::OcrError, ocr_result::Angle, ocr_utils::OcrUtils}; 2 | 3 | use ort::{ 4 | inputs, 5 | session::{Session, SessionOutputs}, 6 | value::Tensor, 7 | }; 8 | 9 | const MEAN_VALUES: [f32; 3] = [127.5, 127.5, 127.5]; 10 | const NORM_VALUES: [f32; 3] = [1.0 / 127.5, 1.0 / 127.5, 1.0 / 127.5]; 11 | const ANGLE_DST_WIDTH: u32 = 192; 12 | const ANGLE_DST_HEIGHT: u32 = 48; 13 | const ANGLE_COLS: usize = 2; 14 | 15 | #[derive(Debug)] 16 | pub struct AngleNet { 17 | session: Option, 18 | input_names: Vec, 19 | } 20 | 21 | impl BaseNet for AngleNet { 22 | fn new() -> Self { 23 | Self { 24 | session: None, 25 | input_names: Vec::new(), 26 | } 27 | } 28 | 29 | fn set_input_names(&mut self, input_names: Vec) { 30 | self.input_names = input_names; 31 | } 32 | 33 | fn set_session(&mut self, session: Option) { 34 | self.session = session; 35 | } 36 | } 37 | 38 | impl AngleNet { 39 | pub fn get_angles( 40 | &mut self, 41 | part_imgs: &[image::RgbImage], 42 | do_angle: bool, 43 | most_angle: bool, 44 | ) -> Result, OcrError> { 45 | let mut angles = Vec::new(); 46 | 47 | if do_angle { 48 | for img in part_imgs { 49 | let angle = self.get_angle(img)?; 50 | angles.push(angle); 51 | } 52 | } else { 53 | angles.extend(part_imgs.iter().map(|_| Angle::default())); 54 | } 55 | 56 | if do_angle && most_angle { 57 | let sum: i32 = angles.iter().map(|x| x.index).sum(); 58 | let half_percent = angles.len() as f32 / 2.0; 59 | let most_angle_index = if (sum as f32) < half_percent { 0 } else { 1 }; 60 | 61 | for angle in angles.iter_mut() { 62 | angle.index = most_angle_index; 63 | } 64 | } 65 | 66 | Ok(angles) 67 | } 68 | 69 | fn get_angle(&mut self, img_src: &image::RgbImage) -> Result { 70 | let Some(session) = &mut self.session else { 71 | return Err(OcrError::SessionNotInitialized); 72 | }; 73 | 74 | let angle_img = image::imageops::resize( 75 | img_src, 76 | 
ANGLE_DST_WIDTH, 77 | ANGLE_DST_HEIGHT, 78 | image::imageops::FilterType::Triangle, 79 | ); 80 | 81 | let input_tensors = 82 | OcrUtils::substract_mean_normalize(&angle_img, &MEAN_VALUES, &NORM_VALUES); 83 | 84 | let input_tensors = Tensor::from_array(input_tensors)?; 85 | 86 | let outputs = session.run(inputs![self.input_names[0].clone() => input_tensors])?; 87 | 88 | let angle = Self::score_to_angle(&outputs, ANGLE_COLS)?; 89 | 90 | Ok(angle) 91 | } 92 | 93 | fn score_to_angle( 94 | output_tensor: &SessionOutputs, 95 | angle_cols: usize, 96 | ) -> Result { 97 | let (_, red_data) = output_tensor.iter().next().unwrap(); 98 | 99 | let src_data: Vec = red_data.try_extract_tensor::()?.1.to_vec(); 100 | 101 | let mut angle = Angle::default(); 102 | let mut max_value = f32::MIN; 103 | let mut angle_index = 0; 104 | 105 | for (i, value) in src_data.iter().take(angle_cols).enumerate() { 106 | if i == 0 || value > &max_value { 107 | max_value = *value; 108 | angle_index = i as i32; 109 | } 110 | } 111 | 112 | angle.index = angle_index; 113 | angle.score = max_value; 114 | Ok(angle) 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /docs/README_zh-Hans.md: -------------------------------------------------------------------------------- 1 | ## paddle-ocr-rs 2 | 3 | 使用 Rust 通过 ONNX Runtime 调用 Paddle OCR 模型进行图片文字识别。 4 | 5 | ### 示例 6 | 7 | 8 | ```rust 9 | use crate::{ocr_error::OcrError, ocr_lite::OcrLite}; 10 | 11 | fn run_test() -> Result<(), OcrError> { 12 | let mut ocr = OcrLite::new(); 13 | ocr.init_models( 14 | "./models/ch_PP-OCRv5_mobile_det.onnx", 15 | "./models/ch_ppocr_mobile_v2.0_cls_infer.onnx", 16 | "./models/ch_PP-OCRv5_rec_mobile_infer.onnx", 17 | 2, 18 | )?; 19 | 20 | println!("===test_1==="); 21 | let res = ocr.detect_from_path( 22 | "./docs/test_images/test_1.png", 23 | 50, 24 | 1024, 25 | 0.5, 26 | 0.3, 27 | 1.6, 28 | false, 29 | false, 30 | )?; 31 | res.text_blocks.iter().for_each(|item| { 32 | 
println!("text: {} score: {}", item.text, item.text_score); 33 | }); 34 | 35 | println!("===test_2==="); 36 | let res = ocr.detect_from_path( 37 | "./docs/test_images/test_2.png", 38 | 50, 39 | 1024, 40 | 0.5, 41 | 0.3, 42 | 1.6, 43 | false, 44 | false, 45 | )?; 46 | res.text_blocks.iter().for_each(|item| { 47 | println!("text: {} score: {}", item.text, item.text_score); 48 | }); 49 | 50 | // 通过 image 读取图片 51 | println!("===test_3==="); 52 | let test_three_img = image::open("./docs/test_images/test_3.png") 53 | .unwrap() 54 | .to_rgb8(); 55 | let res = ocr.detect(&test_three_img, 50, 1024, 0.5, 0.3, 1.6, true, false)?; 56 | res.text_blocks.iter().for_each(|item| { 57 | println!("text: {} score: {}", item.text, item.text_score); 58 | }); 59 | 60 | Ok(()) 61 | } 62 | 63 | // 某些情况下角度纠正会得出错误结果,支持角度纠正回退,当角度纠正后的文本识别得分低于指定值(或为 NaN)时,将使用进行角度纠正前的图片进行识别 64 | fn run_test_angle_rollback() -> Result<(), OcrError> { 65 | let mut ocr = OcrLite::new(); 66 | ocr.init_models( 67 | "./models/ch_PP-OCRv4_det_infer.onnx", 68 | "./models/ch_ppocr_mobile_v2.0_cls_infer.onnx", 69 | "./models/ch_PP-OCRv4_rec_infer.onnx", 70 | 2, 71 | )?; 72 | 73 | 74 | println!("===test_angle_ori==="); 75 | let test_img = image::open("./docs/test_images/test_4.png") 76 | .unwrap() 77 | .to_rgb8(); 78 | let res = ocr.detect(&test_img, 50, 1024, 0.5, 0.3, 1.6, true, false)?; 79 | res.text_blocks.iter().for_each(|item| { 80 | println!("text: {} score: {}", item.text, item.text_score); 81 | }); 82 | 83 | 84 | println!("===test_angle_rollback==="); 85 | let test_img = image::open("./docs/test_images/test_4.png") 86 | .unwrap() 87 | .to_rgb8(); 88 | let res = 89 | ocr.detect_angle_rollback(&test_img, 50, 1024, 0.5, 0.3, 1.6, true, false, 0.8)?; 90 | res.text_blocks.iter().for_each(|item| { 91 | println!("text: {} score: {}", item.text, item.text_score); 92 | }); 93 | Ok(()) 94 | } 95 | ``` 96 | 97 | ### 参考开发环境 98 | 99 | | 依赖 | 版本号 | 100 | | ---------- | ----------------------------- | 101 | | rustc | 1.84.1 
(e71f9a9a9 2025-01-27) | 102 | | cargo | 1.84.1 (66221abde 2024-11-19) | 103 | | OS | Windows 11 24H2 | 104 | | Paddle OCR | 4 | 105 | 106 | ### 文档 107 | 108 | [运行报错](/docs/error/index.md) 109 | 110 | [静态链接](/docs/staticLinking/index.md) 111 | 112 | ### 模型来源 113 | 114 | [RapidOCR Docs](https://rapidai.github.io/RapidOCRDocs/main/model_list/) 115 | 116 | ### 相关事项 117 | 118 | 代码参考自 [RapidOcrOnnx](https://github.com/RapidAI/RapidOcrOnnx),已使用 image 和 imageproc 代替 OpenCV 进行图片相关的实现。 119 | 120 | ### 效果展示 121 | 122 | #### test_1.png 123 | 124 | ![test_1](/docs/test_images/test_1.png) 125 | 126 | ```bash 127 | text: 使用Rust 通过ONNX Runtime 调用 Paddle OCR 模型进行图片文字识别。 score: 0.95269924 128 | text: paddle-ocr-rs score: 0.9979071 129 | ``` 130 | 131 | #### test_2.png 132 | 133 | ![test_2](/docs/test_images/test_2.png) 134 | 135 | ```bash 136 | text: 母婴用品连锁 score: 0.99713486 137 | ``` 138 | 139 | #### test_3.png 140 | 141 | ![test_3](/docs/test_images/test_3.png) 142 | 143 | #### 输出预览 144 | 145 | ```bash 146 | text: salta sobre o cao preguicoso. score: 0.9794339 147 | text: perezoso. A raposa marrom rapida score: 0.9970329 148 | text: marron rapido salta sobre el perro score: 0.9995695 149 | text: salta sopra il cane pigro. El zorro score: 0.99923337 150 | text: paresseux. La volpe marrone rapida score: 0.9991456 151 | text: 《rapide> saute par-dessus le chien score: 0.9685502 152 | text: uber den faulen Hund. Le renard brun score: 0.988613 153 | text: Der ,schnelle" braune Fuchs springt score: 0.97560924 154 | text: from aspammer@website.com is spam. score: 0.98167914 155 | text: & duck/goose, as 12.5% of E-mail score: 0.98472834 156 | text: Over the $43,456.78 #90 dog score: 0.9847551 157 | text: The (quick) [brown] {fox} jumps! 
score: 0.98300403 158 | ``` 159 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [简体中文](./docs/README_zh-Hans.md) 2 | 3 | ## paddle-ocr-rs 4 | 5 | Use Rust to call Paddle OCR models via ONNX Runtime for image text recognition. 6 | 7 | ### Example 8 | 9 | ```rust 10 | use crate::{ocr_error::OcrError, ocr_lite::OcrLite}; 11 | 12 | fn run_test() -> Result<(), OcrError> { 13 | let mut ocr = OcrLite::new(); 14 | ocr.init_models( 15 | "./models/ch_PP-OCRv5_mobile_det.onnx", 16 | "./models/ch_ppocr_mobile_v2.0_cls_infer.onnx", 17 | "./models/ch_PP-OCRv5_rec_mobile_infer.onnx", 18 | 2, 19 | )?; 20 | 21 | println!("===test_1==="); 22 | let res = ocr.detect_from_path( 23 | "./docs/test_images/test_1.png", 24 | 50, 25 | 1024, 26 | 0.5, 27 | 0.3, 28 | 1.6, 29 | false, 30 | false, 31 | )?; 32 | res.text_blocks.iter().for_each(|item| { 33 | println!("text: {} score: {}", item.text, item.text_score); 34 | }); 35 | 36 | println!("===test_2==="); 37 | let res = ocr.detect_from_path( 38 | "./docs/test_images/test_2.png", 39 | 50, 40 | 1024, 41 | 0.5, 42 | 0.3, 43 | 1.6, 44 | false, 45 | false, 46 | )?; 47 | res.text_blocks.iter().for_each(|item| { 48 | println!("text: {} score: {}", item.text, item.text_score); 49 | }); 50 | 51 | // 通过 image 读取图片 52 | println!("===test_3==="); 53 | let test_three_img = image::open("./docs/test_images/test_3.png") 54 | .unwrap() 55 | .to_rgb8(); 56 | let res = ocr.detect(&test_three_img, 50, 1024, 0.5, 0.3, 1.6, true, false)?; 57 | res.text_blocks.iter().for_each(|item| { 58 | println!("text: {} score: {}", item.text, item.text_score); 59 | }); 60 | 61 | Ok(()) 62 | } 63 | 64 | // 某些情况下角度纠正会得出错误结果,支持角度纠正回退,当角度纠正后的文本识别得分低于指定值(或为 NaN)时,将使用进行角度纠正前的图片进行识别 65 | fn run_test_angle_rollback() -> Result<(), OcrError> { 66 | let mut ocr = OcrLite::new(); 67 | ocr.init_models( 68 | "./models/ch_PP-OCRv4_det_infer.onnx", 69 | 
"./models/ch_ppocr_mobile_v2.0_cls_infer.onnx", 70 | "./models/ch_PP-OCRv4_rec_infer.onnx", 71 | 2, 72 | )?; 73 | 74 | 75 | println!("===test_angle_ori==="); 76 | let test_img = image::open("./docs/test_images/test_4.png") 77 | .unwrap() 78 | .to_rgb8(); 79 | let res = ocr.detect(&test_img, 50, 1024, 0.5, 0.3, 1.6, true, false)?; 80 | res.text_blocks.iter().for_each(|item| { 81 | println!("text: {} score: {}", item.text, item.text_score); 82 | }); 83 | 84 | 85 | println!("===test_angle_rollback==="); 86 | let test_img = image::open("./docs/test_images/test_4.png") 87 | .unwrap() 88 | .to_rgb8(); 89 | let res = 90 | ocr.detect_angle_rollback(&test_img, 50, 1024, 0.5, 0.3, 1.6, true, false, 0.8)?; 91 | res.text_blocks.iter().for_each(|item| { 92 | println!("text: {} score: {}", item.text, item.text_score); 93 | }); 94 | Ok(()) 95 | } 96 | ``` 97 | 98 | ### Reference Development Environment 99 | 100 | | Dependency | Version | 101 | | ---------- | ----------------------------- | 102 | | rustc | 1.84.1 (e71f9a9a9 2025-01-27) | 103 | | cargo | 1.84.1 (66221abde 2024-11-19) | 104 | | OS | Windows 11 24H2 | 105 | | Paddle OCR | 4 | 106 | 107 | ### Documentation 108 | 109 | [Error Handling](/docs/error/index.md) 110 | 111 | [Static Linking Reference](/docs/staticLinking/index.md) 112 | 113 | ### Model Source 114 | 115 | [RapidOCR Docs](https://rapidai.github.io/RapidOCRDocs/main/model_list/) 116 | 117 | ### Related Notes 118 | 119 | The code is referenced from [RapidOcrOnnx](https://github.com/RapidAI/RapidOcrOnnx), and has replaced OpenCV with image and imageproc libraries for image-related implementations. 
120 | 121 | ### Results Demonstration 122 | 123 | #### test_1.png 124 | 125 | ![test_1](/docs/test_images/test_1.png) 126 | 127 | ```bash 128 | text: 使用Rust 通过ONNX Runtime 调用 Paddle OCR 模型进行图片文字识别。 score: 0.95269924 129 | text: paddle-ocr-rs score: 0.9979071 130 | ``` 131 | 132 | #### test_2.png 133 | 134 | ![test_2](/docs/test_images/test_2.png) 135 | 136 | ```bash 137 | text: 母婴用品连锁 score: 0.99713486 138 | ``` 139 | 140 | #### test_3.png 141 | 142 | ![test_3](/docs/test_images/test_3.png) 143 | 144 | #### Output Preview 145 | 146 | ```bash 147 | text: salta sobre o cao preguicoso. score: 0.9794339 148 | text: perezoso. A raposa marrom rapida score: 0.9970329 149 | text: marron rapido salta sobre el perro score: 0.9995695 150 | text: salta sopra il cane pigro. El zorro score: 0.99923337 151 | text: paresseux. La volpe marrone rapida score: 0.9991456 152 | text: 《rapide> saute par-dessus le chien score: 0.9685502 153 | text: uber den faulen Hund. Le renard brun score: 0.988613 154 | text: Der ,schnelle" braune Fuchs springt score: 0.97560924 155 | text: from aspammer@website.com is spam. score: 0.98167914 156 | text: & duck/goose, as 12.5% of E-mail score: 0.98472834 157 | text: Over the $43,456.78 #90 dog score: 0.9847551 158 | text: The (quick) [brown] {fox} jumps! 
score: 0.98300403 159 | ``` 160 | -------------------------------------------------------------------------------- /src/ocr_utils.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | ocr_error::OcrError, 3 | ocr_result::{Point, TextBox}, 4 | }; 5 | use image::imageops; 6 | use imageproc::geometric_transformations::{Interpolation, Projection}; 7 | use ndarray::{Array, Array4}; 8 | 9 | pub struct OcrUtils; 10 | 11 | impl OcrUtils { 12 | pub fn substract_mean_normalize( 13 | img_src: &image::RgbImage, 14 | mean_vals: &[f32], 15 | norm_vals: &[f32], 16 | ) -> Array4 { 17 | let cols = img_src.width(); 18 | let rows = img_src.height(); 19 | let channels = 3; 20 | 21 | let mut input_tensor = Array::zeros((1, channels as usize, rows as usize, cols as usize)); 22 | 23 | // 获取图像数据 24 | unsafe { 25 | for r in 0..rows { 26 | for c in 0..cols { 27 | for ch in 0..channels { 28 | let idx = (r * cols * channels + c * channels + ch) as usize; 29 | let value = img_src.get_unchecked(idx).to_owned(); 30 | let data = value as f32 * norm_vals[ch as usize] 31 | - mean_vals[ch as usize] * norm_vals[ch as usize]; 32 | input_tensor[[0, ch as usize, r as usize, c as usize]] = data; 33 | } 34 | } 35 | } 36 | } 37 | 38 | input_tensor 39 | } 40 | 41 | pub fn make_padding( 42 | img_src: &image::RgbImage, 43 | padding: u32, 44 | ) -> Result { 45 | if padding == 0 { 46 | return Ok(img_src.clone()); 47 | } 48 | 49 | let width = img_src.width(); 50 | let height = img_src.height(); 51 | 52 | let mut padding_src = image::RgbImage::new(width + 2 * padding, height + 2 * padding); 53 | imageproc::drawing::draw_filled_rect_mut( 54 | &mut padding_src, 55 | imageproc::rect::Rect::at(0, 0).of_size(width + 2 * padding, height + 2 * padding), 56 | image::Rgb([255, 255, 255]), 57 | ); 58 | 59 | image::imageops::replace(&mut padding_src, img_src, padding as i64, padding as i64); 60 | 61 | Ok(padding_src) 62 | } 63 | 64 | pub fn get_part_images( 65 | img_src: 
&image::RgbImage, 66 | text_boxes: &[TextBox], 67 | ) -> Vec { 68 | text_boxes 69 | .iter() 70 | .map(|text_box| Self::get_rotate_crop_image(img_src, &text_box.points)) 71 | .collect() 72 | } 73 | 74 | pub fn get_rotate_crop_image( 75 | img_src: &image::RgbImage, 76 | box_points: &[Point], 77 | ) -> image::RgbImage { 78 | let mut points = box_points.to_vec(); 79 | 80 | // 计算边界框 81 | let (min_x, min_y, max_x, max_y) = points.iter().fold( 82 | (u32::MAX, u32::MAX, 0u32, 0u32), 83 | |(min_x, min_y, max_x, max_y), point| { 84 | ( 85 | min_x.min(point.x), 86 | min_y.min(point.y), 87 | max_x.max(point.x), 88 | max_y.max(point.y), 89 | ) 90 | }, 91 | ); 92 | 93 | // 裁剪图像 94 | let img_crop = 95 | imageops::crop_imm(img_src, min_x, min_y, max_x - min_x, max_y - min_y).to_image(); 96 | 97 | for point in &mut points { 98 | point.x -= min_x; 99 | point.y -= min_y; 100 | } 101 | 102 | let img_crop_width = ((points[0].x as i32 - points[1].x as i32).pow(2) as f32 103 | + (points[0].y as i32 - points[1].y as i32).pow(2) as f32) 104 | .sqrt() as u32; 105 | let img_crop_height = ((points[0].x as i32 - points[3].x as i32).pow(2) as f32 106 | + (points[0].y as i32 - points[3].y as i32).pow(2) as f32) 107 | .sqrt() as u32; 108 | 109 | let src_points = [ 110 | (points[0].x as f32, points[0].y as f32), 111 | (points[1].x as f32, points[1].y as f32), 112 | (points[2].x as f32, points[2].y as f32), 113 | (points[3].x as f32, points[3].y as f32), 114 | ]; 115 | 116 | let dst_points = [ 117 | (0.0, 0.0), 118 | (img_crop_width as f32, 0.0), 119 | (img_crop_width as f32, img_crop_height as f32), 120 | (0.0, img_crop_height as f32), 121 | ]; 122 | 123 | let projection = Projection::from_control_points(src_points, dst_points) 124 | .expect("Failed to create projection transformation"); 125 | 126 | let mut part_img = image::RgbImage::new(img_crop_width, img_crop_height); 127 | imageproc::geometric_transformations::warp_into( 128 | &img_crop, 129 | &projection, 130 | Interpolation::Nearest, 131 | 
image::Rgb([255, 255, 255]), 132 | &mut part_img, 133 | ); 134 | 135 | // 根据需要旋转图像 136 | if part_img.height() >= part_img.width() * 3 / 2 { 137 | let mut rotated = image::RgbImage::new(part_img.height(), part_img.width()); 138 | 139 | for (x, y, pixel) in part_img.enumerate_pixels() { 140 | rotated.put_pixel(y, part_img.width() - 1 - x, *pixel); 141 | } 142 | 143 | rotated 144 | } else { 145 | part_img 146 | } 147 | } 148 | 149 | pub fn mat_rotate_clock_wise_180(src: &mut image::RgbImage) { 150 | imageops::rotate180_in_place(src); 151 | } 152 | 153 | pub fn calculate_mean_with_mask( 154 | img: &image::ImageBuffer, Vec>, 155 | mask: &image::ImageBuffer, Vec>, 156 | ) -> f32 { 157 | let mut sum: f32 = 0.0; 158 | let mut mask_count = 0; 159 | 160 | assert_eq!(img.width(), mask.width()); 161 | assert_eq!(img.height(), mask.height()); 162 | 163 | for y in 0..img.height() { 164 | for x in 0..img.width() { 165 | let mask_value = mask.get_pixel(x, y)[0]; 166 | if mask_value > 0 { 167 | let pixel = img.get_pixel(x, y); 168 | sum += pixel[0]; 169 | mask_count += 1; 170 | } 171 | } 172 | } 173 | 174 | if mask_count == 0 { 175 | return 0.0; 176 | } 177 | 178 | sum / mask_count as f32 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /src/crnn_net.rs: -------------------------------------------------------------------------------- 1 | use ort::session::Session; 2 | use ort::value::Tensor; 3 | use ort::{inputs, session::builder::SessionBuilder}; 4 | use std::collections::HashMap; 5 | 6 | use crate::{base_net::BaseNet, ocr_error::OcrError, ocr_result::TextLine, ocr_utils::OcrUtils}; 7 | 8 | const CRNN_DST_HEIGHT: u32 = 48; 9 | const MEAN_VALUES: [f32; 3] = [127.5, 127.5, 127.5]; 10 | const NORM_VALUES: [f32; 3] = [1.0 / 127.5, 1.0 / 127.5, 1.0 / 127.5]; 11 | 12 | #[derive(Debug)] 13 | pub struct CrnnNet { 14 | session: Option, 15 | keys: Vec, 16 | input_names: Vec, 17 | } 18 | 19 | impl BaseNet for CrnnNet { 20 | fn new() -> 
Self {
        Self {
            session: None,
            keys: Vec::new(),
            input_names: Vec::new(),
        }
    }

    fn set_input_names(&mut self, input_names: Vec<String>) {
        self.input_names = input_names;
    }

    fn set_session(&mut self, session: Option<Session>) {
        self.session = session;
    }
}

impl CrnnNet {
    /// Loads the recognition model from `path` and reads the character
    /// dictionary out of the model metadata.
    pub fn init_model(
        &mut self,
        path: &str,
        num_thread: usize,
        builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
    ) -> Result<(), OcrError> {
        BaseNet::init_model(self, path, num_thread, builder_fn)?;

        self.keys = self.get_keys()?;

        Ok(())
    }

    /// Loads the recognition model from `path` and the character dictionary
    /// from a separate text file (for models without an embedded dictionary).
    pub fn init_model_dict_file(
        &mut self,
        path: &str,
        num_thread: usize,
        builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
        dict_file_path: &str,
    ) -> Result<(), OcrError> {
        BaseNet::init_model(self, path, num_thread, builder_fn)?;

        self.read_keys_from_file(dict_file_path)?;

        Ok(())
    }

    /// Loads the recognition model from an in-memory ONNX buffer; the
    /// dictionary is read from the model metadata.
    pub fn init_model_from_memory(
        &mut self,
        model_bytes: &[u8],
        num_thread: usize,
        builder_fn: Option<fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>>,
    ) -> Result<(), OcrError> {
        BaseNet::init_model_from_memory(self, model_bytes, num_thread, builder_fn)?;

        self.keys = self.get_keys()?;

        Ok(())
    }

    /// Reads the character list embedded in the model metadata under the
    /// `character` key. Index 0 is the CTC blank (`"#"`) and a single space
    /// is appended at the end, matching the PaddleOCR dictionary layout.
    fn get_keys(&mut self) -> Result<Vec<String>, OcrError> {
        // Keep the error handling simple: for a well-formed model these
        // lookups cannot fail, so a panic here means the model file is broken.
        let model_character_list = self
            .session
            .as_ref()
            .expect("crnn_net session not initialized")
            .metadata()
            .expect("crnn_net metadata not initialized")
            .custom("character")
            .expect("crnn_net character not initialized[0]")
            .expect("crnn_net character not initialized[1]");

        // Rough capacity estimate is good enough: most entries are a 3-byte
        // UTF-8 character plus a newline, i.e. roughly 4 bytes per key.
        let mut keys = Vec::with_capacity((model_character_list.len() as f32 / 3.9) as usize);

        keys.push("#".to_string());

        keys.extend(model_character_list.split('\n').map(|s| s.to_string()));

        keys.push(" ".to_string());

        Ok(keys)
    }

    /// Loads the character dictionary from a plain-text file, one key per
    /// line.
    ///
    /// FIX: a trailing `'\r'` (Windows CRLF line endings) is now stripped
    /// from every entry so recognized text is not polluted with carriage
    /// returns. Splitting on `'\n'` is kept (rather than `lines()`) so the
    /// entry count — and therefore the index mapping the model relies on —
    /// is unchanged.
    fn read_keys_from_file(&mut self, path: &str) -> Result<(), OcrError> {
        let content = std::fs::read_to_string(path)?;

        self.keys = content
            .split('\n')
            .map(|line| line.trim_end_matches('\r').to_string())
            .collect();

        Ok(())
    }

    /// Recognizes every cropped line image.
    ///
    /// If a line was rotated 180 degrees by the angle classifier and its
    /// recognition score comes back NaN or below `angle_rollback_threshold`,
    /// the original (un-rotated) crop stored in `angle_rollback_records` is
    /// recognized instead, undoing a harmful angle correction.
    pub fn get_text_lines(
        &mut self,
        part_imgs: &[image::RgbImage],
        angle_rollback_records: &HashMap<usize, image::RgbImage>,
        angle_rollback_threshold: f32,
    ) -> Result<Vec<TextLine>, OcrError> {
        let mut text_lines = Vec::with_capacity(part_imgs.len());

        for (index, img) in part_imgs.iter().enumerate() {
            let mut text_line = self.get_text_line(img)?;

            if (text_line.text_score.is_nan() || text_line.text_score < angle_rollback_threshold)
                && let Some(angle_rollback_record) = angle_rollback_records.get(&index)
            {
                text_line = self.get_text_line(angle_rollback_record)?;
            }

            text_lines.push(text_line);
        }

        Ok(text_lines)
    }

    /// Recognizes a single text-line image: resize to the fixed CRNN input
    /// height (preserving aspect ratio), normalize, run the session and
    /// decode the per-timestep scores into text.
    fn get_text_line(&mut self, img_src: &image::RgbImage) -> Result<TextLine, OcrError> {
        let Some(session) = &mut self.session else {
            return Err(OcrError::SessionNotInitialized);
        };

        let scale = CRNN_DST_HEIGHT as f32 / img_src.height() as f32;
        let dst_width = (img_src.width() as f32 * scale) as u32;

        let src_resize = image::imageops::resize(
            img_src,
            dst_width,
            CRNN_DST_HEIGHT,
            image::imageops::FilterType::Triangle,
        );

        let input_tensors =
            OcrUtils::substract_mean_normalize(&src_resize, &MEAN_VALUES, &NORM_VALUES);

        let input_tensors = Tensor::from_array(input_tensors)?;

        let outputs = session.run(inputs![self.input_names[0].clone() => input_tensors])?;

        let (_, red_data) = outputs
            .iter()
            .next()
            .expect("crnn_net model produced no outputs");

        // Output shape is [batch, timesteps, num_classes].
        let (shape, src_data) = red_data.try_extract_tensor::<f32>()?;
        let height = shape[1] as usize;
        let width = shape[2] as usize;
        // NOTE(review): this copies the whole output tensor on every line;
        // the borrowed slice could be passed to `score_to_text_line` directly.
        let src_data: Vec<f32> = src_data.to_vec();

Self::score_to_text_line(&src_data, height, width, &self.keys) 165 | } 166 | 167 | fn score_to_text_line( 168 | output_data: &[f32], 169 | height: usize, 170 | width: usize, 171 | keys: &[String], 172 | ) -> Result { 173 | let mut text_line = TextLine::default(); 174 | let mut last_index = 0; 175 | 176 | let mut text_score_sum = 0.0; 177 | let mut text_socre_count = 0; 178 | for i in 0..height { 179 | let start = i * width; 180 | let stop = (i + 1) * width; 181 | let slice = &output_data[start..stop.min(output_data.len())]; 182 | 183 | let (max_index, max_value) = 184 | slice 185 | .iter() 186 | .enumerate() 187 | .fold((0, f32::MIN), |(max_idx, max_val), (idx, &val)| { 188 | if val > max_val { 189 | (idx, val) 190 | } else { 191 | (max_idx, max_val) 192 | } 193 | }); 194 | 195 | if max_index > 0 && max_index < keys.len() && !(i > 0 && max_index == last_index) { 196 | text_line.text.push_str(&keys[max_index]); 197 | text_score_sum += max_value; 198 | text_socre_count += 1; 199 | } 200 | last_index = max_index; 201 | } 202 | 203 | text_line.text_score = text_score_sum / text_socre_count as f32; 204 | Ok(text_line) 205 | } 206 | } 207 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::too_many_arguments)] 2 | 3 | pub mod angle_net; 4 | pub mod base_net; 5 | pub mod crnn_net; 6 | pub mod db_net; 7 | pub mod ocr_error; 8 | pub mod ocr_lite; 9 | pub mod ocr_result; 10 | pub mod ocr_utils; 11 | pub mod scale_param; 12 | 13 | #[cfg(test)] 14 | mod tests { 15 | use crate::{ocr_error::OcrError, ocr_lite::OcrLite}; 16 | use std::fs; 17 | use std::io::{Cursor, Read}; 18 | 19 | #[test] 20 | fn run_test() -> Result<(), OcrError> { 21 | let mut ocr = OcrLite::new(); 22 | ocr.init_models( 23 | "./models/ch_PP-OCRv5_mobile_det.onnx", 24 | "./models/ch_ppocr_mobile_v2.0_cls_infer.onnx", 25 | "./models/ch_PP-OCRv5_rec_mobile_infer.onnx", 
26 | 2, 27 | )?; 28 | 29 | println!("===test_1==="); 30 | let res = ocr.detect_from_path( 31 | "./docs/test_images/test_1.png", 32 | 50, 33 | 1024, 34 | 0.5, 35 | 0.3, 36 | 1.6, 37 | false, 38 | false, 39 | )?; 40 | res.text_blocks.iter().for_each(|item| { 41 | println!("text: {} score: {}", item.text, item.text_score); 42 | }); 43 | println!("===test_2==="); 44 | let res = ocr.detect_from_path( 45 | "./docs/test_images/test_2.png", 46 | 50, 47 | 1024, 48 | 0.5, 49 | 0.3, 50 | 1.6, 51 | false, 52 | false, 53 | )?; 54 | res.text_blocks.iter().for_each(|item| { 55 | println!("text: {} score: {}", item.text, item.text_score); 56 | }); 57 | 58 | // 通过 image 读取图片 59 | println!("===test_3==="); 60 | let test_three_img = image::open("./docs/test_images/test_3.png") 61 | .unwrap() 62 | .to_rgb8(); 63 | let res = ocr.detect(&test_three_img, 50, 1024, 0.5, 0.3, 1.6, true, false)?; 64 | res.text_blocks.iter().for_each(|item| { 65 | println!("text: {} score: {}", item.text, item.text_score); 66 | }); 67 | 68 | Ok(()) 69 | } 70 | 71 | #[test] 72 | fn run_test_from_memory() -> Result<(), OcrError> { 73 | let det_bytes = fs::read("./models/ch_PP-OCRv4_det_infer.onnx")?; 74 | let cls_bytes = fs::read("./models/ch_ppocr_mobile_v2.0_cls_infer.onnx")?; 75 | let rec_bytes = fs::read("./models/ch_PP-OCRv4_rec_infer.onnx")?; 76 | 77 | let mut ocr = OcrLite::new(); 78 | ocr.init_models_from_memory(&det_bytes, &cls_bytes, &rec_bytes, 2)?; 79 | 80 | println!("===test_from_memory==="); 81 | let test_img = image::open("./docs/test_images/test_1.png") 82 | .unwrap() 83 | .to_rgb8(); 84 | let res = ocr.detect(&test_img, 50, 1024, 0.5, 0.3, 1.6, false, false)?; 85 | res.text_blocks.iter().for_each(|item| { 86 | println!("text: {} score: {}", item.text, item.text_score); 87 | }); 88 | 89 | Ok(()) 90 | } 91 | 92 | #[test] 93 | fn run_test_from_cursor() -> Result<(), OcrError> { 94 | let mut det_file = fs::File::open("./models/ch_PP-OCRv4_det_infer.onnx")?; 95 | let mut cls_file = 
fs::File::open("./models/ch_ppocr_mobile_v2.0_cls_infer.onnx")?; 96 | let mut rec_file = fs::File::open("./models/ch_PP-OCRv4_rec_infer.onnx")?; 97 | 98 | let mut det_buffer = Vec::new(); 99 | let mut cls_buffer = Vec::new(); 100 | let mut rec_buffer = Vec::new(); 101 | 102 | det_file.read_to_end(&mut det_buffer)?; 103 | cls_file.read_to_end(&mut cls_buffer)?; 104 | rec_file.read_to_end(&mut rec_buffer)?; 105 | 106 | let det_cursor = Cursor::new(det_buffer); 107 | let cls_cursor = Cursor::new(cls_buffer); 108 | let rec_cursor = Cursor::new(rec_buffer); 109 | 110 | let det_bytes = det_cursor.into_inner(); 111 | let cls_bytes = cls_cursor.into_inner(); 112 | let rec_bytes = rec_cursor.into_inner(); 113 | 114 | let mut ocr = OcrLite::new(); 115 | ocr.init_models_from_memory(&det_bytes, &cls_bytes, &rec_bytes, 2)?; 116 | 117 | println!("===test_from_cursor==="); 118 | let test_img = image::open("./docs/test_images/test_2.png") 119 | .unwrap() 120 | .to_rgb8(); 121 | let res = ocr.detect(&test_img, 50, 1024, 0.5, 0.3, 1.6, false, false)?; 122 | res.text_blocks.iter().for_each(|item| { 123 | println!("text: {} score: {}", item.text, item.text_score); 124 | }); 125 | 126 | Ok(()) 127 | } 128 | 129 | #[test] 130 | fn run_test_angle_rollback() -> Result<(), OcrError> { 131 | let mut ocr = OcrLite::new(); 132 | ocr.init_models( 133 | "./models/ch_PP-OCRv4_det_infer.onnx", 134 | "./models/ch_ppocr_mobile_v2.0_cls_infer.onnx", 135 | "./models/ch_PP-OCRv4_rec_infer.onnx", 136 | 2, 137 | )?; 138 | 139 | println!("===test_angle_ori==="); 140 | let test_img = image::open("./docs/test_images/test_4.png") 141 | .unwrap() 142 | .to_rgb8(); 143 | let res = ocr.detect(&test_img, 50, 1024, 0.5, 0.3, 1.6, true, false)?; 144 | res.text_blocks.iter().for_each(|item| { 145 | println!("text: {} score: {}", item.text, item.text_score); 146 | }); 147 | 148 | println!("===test_angle_rollback==="); 149 | let test_img = image::open("./docs/test_images/test_4.png") 150 | .unwrap() 151 | 
.to_rgb8(); 152 | let res = 153 | ocr.detect_angle_rollback(&test_img, 50, 1024, 0.5, 0.3, 1.6, true, false, 0.8)?; 154 | res.text_blocks.iter().for_each(|item| { 155 | println!("text: {} score: {}", item.text, item.text_score); 156 | }); 157 | 158 | Ok(()) 159 | } 160 | 161 | #[test] 162 | fn run_test_from_custom() -> Result<(), OcrError> { 163 | let mut ocr = OcrLite::new(); 164 | ocr.init_models_custom( 165 | "./models/ch_PP-OCRv5_mobile_det.onnx", 166 | "./models/ch_ppocr_mobile_v2.0_cls_infer.onnx", 167 | "./models/ch_PP-OCRv5_rec_mobile_infer.onnx", 168 | |builder| builder.with_inter_threads(2)?.with_intra_threads(2), 169 | )?; 170 | 171 | println!("===test_from_custom==="); 172 | let res = ocr.detect_from_path( 173 | "./docs/test_images/test_4.png", 174 | 50, 175 | 1024, 176 | 0.5, 177 | 0.3, 178 | 1.6, 179 | false, 180 | false, 181 | )?; 182 | res.text_blocks.iter().for_each(|item| { 183 | println!("text: {} score: {}", item.text, item.text_score); 184 | }); 185 | 186 | Ok(()) 187 | } 188 | 189 | #[test] 190 | fn run_test_from_custom_with_dict() -> Result<(), OcrError> { 191 | let mut ocr = OcrLite::new(); 192 | ocr.init_models_with_dict( 193 | "./models/ch_PP-OCRv5_mobile_det.onnx", 194 | "./models/ch_ppocr_mobile_v2.0_cls_infer.onnx", 195 | "./models/ch_PP-OCRv5_rec_mobile_infer_no_dict.onnx", 196 | "./models/dict.txt", 197 | 2, 198 | )?; 199 | 200 | println!("===test_from_custom_with_dict==="); 201 | let res = ocr.detect_from_path( 202 | "./docs/test_images/test_4.png", 203 | 50, 204 | 1024, 205 | 0.5, 206 | 0.3, 207 | 1.6, 208 | false, 209 | false, 210 | )?; 211 | res.text_blocks.iter().for_each(|item| { 212 | println!("text: {} score: {}", item.text, item.text_score); 213 | }); 214 | 215 | Ok(()) 216 | } 217 | } 218 | -------------------------------------------------------------------------------- /src/ocr_lite.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use 
image::ImageBuffer;
use ort::session::builder::SessionBuilder;

use crate::{
    angle_net::AngleNet,
    base_net::BaseNet,
    crnn_net::CrnnNet,
    db_net::DbNet,
    ocr_error::OcrError,
    ocr_result::{OcrResult, Point, TextBlock},
    ocr_utils::OcrUtils,
    scale_param::ScaleParam,
};

/// High-level OCR pipeline: DbNet (text detection) + AngleNet (180-degree
/// orientation classification) + CrnnNet (text recognition).
#[derive(Debug)]
pub struct OcrLite {
    db_net: DbNet,
    angle_net: AngleNet,
    crnn_net: CrnnNet,
}

impl Default for OcrLite {
    fn default() -> Self {
        Self::new()
    }
}

impl OcrLite {
    /// Creates a pipeline with uninitialized networks; call one of the
    /// `init_models*` methods before detecting.
    pub fn new() -> Self {
        Self {
            db_net: DbNet::new(),
            angle_net: AngleNet::new(),
            crnn_net: CrnnNet::new(),
        }
    }

    /// Loads the three models (detection, classification, recognition) from
    /// disk, each running with `num_thread` threads.
    pub fn init_models(
        &mut self,
        det_path: &str,
        cls_path: &str,
        rec_path: &str,
        num_thread: usize,
    ) -> Result<(), OcrError> {
        self.db_net.init_model(det_path, num_thread, None)?;
        self.angle_net.init_model(cls_path, num_thread, None)?;
        self.crnn_net.init_model(rec_path, num_thread, None)?;
        Ok(())
    }

    /// Like [`Self::init_models`], but reads the recognition dictionary from
    /// `dict_path` instead of the recognition model's metadata.
    pub fn init_models_with_dict(
        &mut self,
        det_path: &str,
        cls_path: &str,
        rec_path: &str,
        dict_path: &str,
        num_thread: usize,
    ) -> Result<(), OcrError> {
        self.db_net.init_model(det_path, num_thread, None)?;
        self.angle_net.init_model(cls_path, num_thread, None)?;
        self.crnn_net
            .init_model_dict_file(rec_path, num_thread, None, dict_path)?;
        Ok(())
    }

    /// Loads the three models with a caller-supplied session-builder hook
    /// (thread counts, execution providers, ...).
    pub fn init_models_custom(
        &mut self,
        det_path: &str,
        cls_path: &str,
        rec_path: &str,
        builder_fn: fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>,
    ) -> Result<(), OcrError> {
        // The thread-count argument is ignored when a builder hook is given.
        self.db_net.init_model(det_path, 0, Some(builder_fn))?;
        self.angle_net.init_model(cls_path, 0, Some(builder_fn))?;
        self.crnn_net.init_model(rec_path, 0, Some(builder_fn))?;
        Ok(())
    }

    /// Loads the three models from in-memory ONNX buffers.
    pub fn init_models_from_memory(
        &mut self,
        det_bytes: &[u8],
        cls_bytes: &[u8],
        rec_bytes: &[u8],
        num_thread: usize,
    ) -> Result<(), OcrError> {
        self.db_net
            .init_model_from_memory(det_bytes, num_thread, None)?;
        self.angle_net
            .init_model_from_memory(cls_bytes, num_thread, None)?;
        self.crnn_net
            .init_model_from_memory(rec_bytes, num_thread, None)?;
        Ok(())
    }

    /// Loads the three models from in-memory ONNX buffers with a
    /// caller-supplied session-builder hook.
    pub fn init_models_from_memory_custom(
        &mut self,
        det_bytes: &[u8],
        cls_bytes: &[u8],
        rec_bytes: &[u8],
        builder_fn: fn(SessionBuilder) -> Result<SessionBuilder, ort::Error>,
    ) -> Result<(), OcrError> {
        self.db_net
            .init_model_from_memory(det_bytes, 0, Some(builder_fn))?;
        self.angle_net
            .init_model_from_memory(cls_bytes, 0, Some(builder_fn))?;
        self.crnn_net
            .init_model_from_memory(rec_bytes, 0, Some(builder_fn))?;
        Ok(())
    }

    /// Shared front half of every `detect*` entry point: pads the image,
    /// computes the resize target and delegates to `detect_once`.
    fn detect_base(
        &mut self,
        img_src: &image::RgbImage,
        padding: u32,
        max_side_len: u32,
        box_score_thresh: f32,
        box_thresh: f32,
        un_clip_ratio: f32,
        do_angle: bool,
        most_angle: bool,
        angle_rollback: bool,
        angle_rollback_threshold: f32,
    ) -> Result<OcrResult, OcrError> {
        // Clamp the resize target to the image's own longest side when
        // max_side_len is 0 or exceeds it, then add room for the border on
        // both sides.
        let origin_max_side = img_src.width().max(img_src.height());
        let clamped = if max_side_len == 0 || max_side_len > origin_max_side {
            origin_max_side
        } else {
            max_side_len
        };
        let resize = clamped + 2 * padding;

        let padding_src = OcrUtils::make_padding(img_src, padding)?;

        let scale = ScaleParam::get_scale_param(&padding_src, resize);

        self.detect_once(
            &padding_src,
            &scale,
            padding,
            box_score_thresh,
            box_thresh,
            un_clip_ratio,
            do_angle,
            most_angle,
            angle_rollback,
            angle_rollback_threshold,
        )
    }

    /// Detects and recognizes text in `img_src`.
    ///
    /// # Arguments
    ///
157 | /// - `img_src` (`&image`) - 图片 158 | /// - `padding` (`u32`) - 变换图片时添加边框的宽度(提高检测效果) 159 | /// - `max_side_len` (`u32`) - 变换图片后图片宽和高保留的最大边长(超出该尺寸的图片将缩小) 160 | /// - `box_score_thresh` (`f32`) - 检测存在文本的区域的分值阈值 161 | /// - `do_angle` (`bool`) - 是否进行角度检测 162 | /// ``` 163 | pub fn detect( 164 | &mut self, 165 | img_src: &image::RgbImage, 166 | padding: u32, 167 | max_side_len: u32, 168 | box_score_thresh: f32, 169 | box_thresh: f32, 170 | un_clip_ratio: f32, 171 | do_angle: bool, 172 | most_angle: bool, 173 | ) -> Result { 174 | self.detect_base( 175 | img_src, 176 | padding, 177 | max_side_len, 178 | box_score_thresh, 179 | box_thresh, 180 | un_clip_ratio, 181 | do_angle, 182 | most_angle, 183 | false, 184 | 0.0, 185 | ) 186 | } 187 | 188 | /// 支持角度回滚的检测图片 189 | /// 在 do_angle 为 true 时生效,如果图片经过了角度纠正,但识别效果过差,则取消角度纠正 190 | /// 191 | /// # Arguments 192 | /// 193 | /// - `&self` (`undefined`) - Describe this parameter. 194 | /// - `img_src` (`&image`) - 图片 195 | /// - `padding` (`u32`) - 变换图片时添加的边框的宽度(提高检测效果) 196 | /// - `max_side_len` (`u32`) - 变换图片后图片宽和高保留的最大边长(超出该尺寸的图片将缩小) 197 | /// - `box_score_thresh` (`f32`) - 检测存在文本的区域的分值阈值 198 | /// - `do_angle` (`bool`) - 是否进行角度检测 199 | /// - `angle_rollback_threshold` (`f32`) - 角度回滚的阈值,如果识别到的文字得分低于该值(或等于 NaN),则取消角度回滚 200 | /// ``` 201 | pub fn detect_angle_rollback( 202 | &mut self, 203 | img_src: &image::RgbImage, 204 | padding: u32, 205 | max_side_len: u32, 206 | box_score_thresh: f32, 207 | box_thresh: f32, 208 | un_clip_ratio: f32, 209 | do_angle: bool, 210 | most_angle: bool, 211 | angle_rollback_threshold: f32, 212 | ) -> Result { 213 | self.detect_base( 214 | img_src, 215 | padding, 216 | max_side_len, 217 | box_score_thresh, 218 | box_thresh, 219 | un_clip_ratio, 220 | do_angle, 221 | most_angle, 222 | true, 223 | angle_rollback_threshold, 224 | ) 225 | } 226 | 227 | pub fn detect_from_path( 228 | &mut self, 229 | img_path: &str, 230 | padding: u32, 231 | max_side_len: u32, 232 | box_score_thresh: f32, 233 | 
box_thresh: f32, 234 | un_clip_ratio: f32, 235 | do_angle: bool, 236 | most_angle: bool, 237 | ) -> Result { 238 | let img_src = image::open(img_path)?.to_rgb8(); 239 | 240 | self.detect( 241 | &img_src, 242 | padding, 243 | max_side_len, 244 | box_score_thresh, 245 | box_thresh, 246 | un_clip_ratio, 247 | do_angle, 248 | most_angle, 249 | ) 250 | } 251 | 252 | fn detect_once( 253 | &mut self, 254 | img_src: &image::RgbImage, 255 | scale: &ScaleParam, 256 | padding: u32, 257 | box_score_thresh: f32, 258 | box_thresh: f32, 259 | un_clip_ratio: f32, 260 | do_angle: bool, 261 | most_angle: bool, 262 | angle_rollback: bool, 263 | angle_rollback_threshold: f32, 264 | ) -> Result { 265 | let text_boxes = self.db_net.get_text_boxes( 266 | img_src, 267 | scale, 268 | box_score_thresh, 269 | box_thresh, 270 | un_clip_ratio, 271 | )?; 272 | 273 | let part_images = OcrUtils::get_part_images(img_src, &text_boxes); 274 | 275 | let angles = self 276 | .angle_net 277 | .get_angles(&part_images, do_angle, most_angle)?; 278 | 279 | let mut rotated_images: Vec = Vec::with_capacity(part_images.len()); 280 | 281 | // 角度纠正回滚 282 | let mut angle_rollback_records = 283 | HashMap::, Vec>>::new(); 284 | 285 | for (index, (angle, mut part_image)) in 286 | angles.iter().zip(part_images.into_iter()).enumerate() 287 | { 288 | if angle.index == 1 { 289 | if angle_rollback { 290 | // 保留原始副本 291 | angle_rollback_records.insert(index, part_image.clone()); 292 | } 293 | 294 | OcrUtils::mat_rotate_clock_wise_180(&mut part_image); 295 | } 296 | rotated_images.push(part_image); 297 | } 298 | 299 | let text_lines = self.crnn_net.get_text_lines( 300 | &rotated_images, 301 | &angle_rollback_records, 302 | angle_rollback_threshold, 303 | )?; 304 | 305 | let mut text_blocks = Vec::with_capacity(text_lines.len()); 306 | for i in 0..text_lines.len() { 307 | text_blocks.push(TextBlock { 308 | box_points: text_boxes[i] 309 | .points 310 | .iter() 311 | .map(|p| Point { 312 | x: ((p.x as f32) - padding as f32) 
as u32, 313 | y: ((p.y as f32) - padding as f32) as u32, 314 | }) 315 | .collect(), 316 | box_score: text_boxes[i].score, 317 | angle_index: angles[i].index, 318 | angle_score: angles[i].score, 319 | text: text_lines[i].text.clone(), 320 | text_score: text_lines[i].text_score, 321 | }); 322 | } 323 | 324 | Ok(OcrResult { text_blocks }) 325 | } 326 | } 327 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 paddle-ocr-rust Contributors 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/db_net.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | base_net::BaseNet, 3 | ocr_error::OcrError, 4 | ocr_result::{self, TextBox}, 5 | ocr_utils::OcrUtils, 6 | scale_param::ScaleParam, 7 | }; 8 | use geo_clipper::{Clipper, EndType, JoinType}; 9 | use geo_types::{Coord, LineString, Polygon}; 10 | use ort::{inputs, session::SessionOutputs}; 11 | use ort::{session::Session, value::Tensor}; 12 | use std::cmp::Ordering; 13 | 14 | const MEAN_VALUES: [f32; 3] = [ 15 | 0.485_f32 * 255_f32, 16 | 0.456_f32 * 255_f32, 17 | 0.406_f32 * 255_f32, 18 | ]; 19 | const NORM_VALUES: [f32; 3] = [ 20 | 1.0_f32 / 0.229_f32 / 255.0_f32, 21 | 1.0_f32 / 0.224_f32 / 255.0_f32, 22 | 1.0_f32 / 0.225_f32 / 255.0_f32, 23 | ]; 24 | 25 | #[derive(Debug)] 26 | pub struct DbNet { 27 | session: Option, 28 | input_names: Vec, 29 | } 30 | 31 | impl BaseNet for DbNet { 32 | fn new() -> Self { 33 | Self { 34 | session: None, 35 | input_names: Vec::new(), 36 | } 37 | } 38 | 39 | fn set_input_names(&mut self, 
input_names: Vec<String>) {
        self.input_names = input_names;
    }

    fn set_session(&mut self, session: Option<Session>) {
        self.session = session;
    }
}

impl DbNet {
    /// Runs the detection model over `img_src` (resized per `scale`) and
    /// returns the candidate text boxes in source-image coordinates.
    pub fn get_text_boxes(
        &mut self,
        img_src: &image::RgbImage,
        scale: &ScaleParam,
        box_score_thresh: f32,
        box_thresh: f32,
        un_clip_ratio: f32,
    ) -> Result<Vec<TextBox>, OcrError> {
        let Some(session) = &mut self.session else {
            return Err(OcrError::SessionNotInitialized);
        };

        let src_resize = image::imageops::resize(
            img_src,
            scale.dst_width,
            scale.dst_height,
            image::imageops::FilterType::Triangle,
        );

        let input_tensors =
            OcrUtils::substract_mean_normalize(&src_resize, &MEAN_VALUES, &NORM_VALUES);

        let tensor = Tensor::from_array(input_tensors)?;

        let outputs = session.run(inputs![self.input_names[0].clone() => tensor])?;

        // FIX: pass `scale` straight through instead of rebuilding an
        // identical ScaleParam field by field.
        Self::get_text_boxes_core(
            &outputs,
            src_resize.height(),
            src_resize.width(),
            scale,
            box_score_thresh,
            box_thresh,
            un_clip_ratio,
        )
    }

    /// Post-processes the detection output (a `rows` x `cols` probability
    /// map): threshold, dilate, extract contours, score and unclip each
    /// candidate, and map the surviving boxes back to source coordinates.
    fn get_text_boxes_core(
        output_tensor: &SessionOutputs,
        rows: u32,
        cols: u32,
        s: &ScaleParam,
        box_score_thresh: f32,
        box_thresh: f32,
        un_clip_ratio: f32,
    ) -> Result<Vec<TextBox>, OcrError> {
        // Boxes whose shortest side is below this (in resized pixels) are
        // discarded as noise.
        let max_side_thresh = 3.0;
        let mut rs_boxes = Vec::new();

        let (_, red_data) = output_tensor.iter().next().unwrap();

        let pred_data: Vec<f32> = red_data.try_extract_tensor::<f32>()?.1.to_vec();

        // Quantize the 0..1 probability map to u8 for thresholding/contours.
        let cbuf_data: Vec<u8> = pred_data
            .iter()
            .map(|pixel| (pixel * 255.0) as u8)
            .collect();

        let pred_img: image::ImageBuffer<image::Luma<f32>, Vec<f32>> =
            image::ImageBuffer::from_vec(cols, rows, pred_data).unwrap();

        let cbuf_img = image::GrayImage::from_vec(cols, rows, cbuf_data).unwrap();

        let threshold_img = imageproc::contrast::threshold(
            &cbuf_img,
            (box_thresh * 255.0) as u8,
            imageproc::contrast::ThresholdType::Binary,
        );

        // Dilate to merge nearby strokes into connected text regions.
        let dilate_img = imageproc::morphology::dilate(
            &threshold_img,
            imageproc::distance_transform::Norm::LInf,
            1,
        );

        let img_contours: Vec<imageproc::contours::Contour<i32>> =
            imageproc::contours::find_contours(&dilate_img);

        for contour in img_contours {
            if contour.points.len() <= 2 {
                continue;
            }

            let mut max_side = 0.0;
            let min_box = Self::get_mini_box(&contour.points, &mut max_side)?;
            if max_side < max_side_thresh {
                continue;
            }

            let score = Self::get_score(&contour, &pred_img)?;
            if score < box_score_thresh {
                continue;
            }

            let clip_box = Self::unclip(&min_box, un_clip_ratio)?;
            if clip_box.is_empty() {
                continue;
            }

            // FIX: the element-by-element copy of `clip_box` into a second
            // Vec was redundant; borrow it directly.
            let mut max_side_clip = 0.0;
            let clip_min_box = Self::get_mini_box(&clip_box, &mut max_side_clip)?;
            if max_side_clip < max_side_thresh + 2.0 {
                continue;
            }

            // Map the box corners back to source-image coordinates, clamped
            // to the source bounds.
            let final_points = clip_min_box
                .iter()
                .map(|item| {
                    let ptx = ((item.x / s.scale_width) as u32).min(s.src_width);
                    let pty = ((item.y / s.scale_height) as u32).min(s.src_height);
                    ocr_result::Point { x: ptx, y: pty }
                })
                .collect();

            rs_boxes.push(TextBox {
                score,
                points: final_points,
            });
        }

        Ok(rs_boxes)
    }

    /// Computes the minimum-area bounding rectangle of `contour_points` and
    /// returns its four corners ordered top-left, top-right, bottom-right,
    /// bottom-left. `min_edge_size` receives the shorter edge length.
    ///
    /// NOTE(review): the point types were reconstructed as `i32` input /
    /// `f32` output — confirm against `unclip`'s return type.
    fn get_mini_box(
        contour_points: &[imageproc::point::Point<i32>],
        min_edge_size: &mut f32,
    ) -> Result<Vec<imageproc::point::Point<f32>>, OcrError> {
        let rect = imageproc::geometry::min_area_rect(contour_points);

        let mut rect_points: Vec<imageproc::point::Point<f32>> = rect
            .iter()
            .map(|p| imageproc::point::Point::new(p.x as f32, p.y as f32))
            .collect();

        let width = ((rect_points[0].x - rect_points[1].x).powi(2)
            + (rect_points[0].y - rect_points[1].y).powi(2))
        .sqrt();
        let height = ((rect_points[1].x - rect_points[2].x).powi(2)
            + (rect_points[1].y - rect_points[2].y).powi(2))
        .sqrt();

        *min_edge_size = width.min(height);

        // FIX: replace the hand-rolled comparator with `total_cmp`. The
        // coordinates come from integer pixel positions, so NaN / -0.0 (the
        // only cases where the two orderings differ) cannot occur.
        rect_points.sort_by(|a, b| a.x.total_cmp(&b.x));

        // Of the two left-most points the lower one (larger y) is the
        // bottom-left corner; likewise on the right-hand side.
        let (index_1, index_4) = if rect_points[1].y > rect_points[0].y {
            (0, 1)
        } else {
            (1, 0)
        };
        let (index_2, index_3) = if rect_points[3].y > rect_points[2].y {
            (2, 3)
        } else {
            (3, 2)
        };

        Ok(vec![
            rect_points[index_1],
            rect_points[index_2],
            rect_points[index_3],
            rect_points[index_4],
        ])
    }

    /// Mean probability of the prediction map within the contour region.
    fn get_score(
        contour: &imageproc::contours::Contour<i32>,
        f_map_mat: &image::ImageBuffer<image::Luma<f32>, Vec<f32>>,
    ) -> Result<f32, OcrError> {
        // Initialize the bounding-box extremes.
        let mut xmin = i32::MAX;
        let mut xmax = i32::MIN;
        let mut ymin = i32::MAX;
        let mut ymax = i32::MIN;

        // Find the contour's bounding box.
        for point in contour.points.iter() {
            let x = point.x;
            let y = point.y;

            if x < xmin {
                xmin = x;
            }
            if x > xmax {
                xmax = x;
            }
            if y < ymin {
                ymin = y;
            }
            if y > ymax {
                ymax = y;
            }
        }

        let width = f_map_mat.width() as i32;
        let 
height = f_map_mat.height() as i32; 280 | 281 | xmin = xmin.max(0).min(width - 1); 282 | xmax = xmax.max(0).min(width - 1); 283 | ymin = ymin.max(0).min(height - 1); 284 | ymax = ymax.max(0).min(height - 1); 285 | 286 | let roi_width = xmax - xmin + 1; 287 | let roi_height = ymax - ymin + 1; 288 | 289 | if roi_width <= 0 || roi_height <= 0 { 290 | return Ok(0.0); 291 | } 292 | 293 | let mut mask = image::GrayImage::new(roi_width as u32, roi_height as u32); 294 | 295 | let mut pts = Vec::>::new(); 296 | for point in contour.points.iter() { 297 | pts.push(imageproc::point::Point::new(point.x - xmin, point.y - ymin)); 298 | } 299 | 300 | imageproc::drawing::draw_polygon_mut(&mut mask, pts.as_slice(), image::Luma([255])); 301 | 302 | let cropped_img = image::imageops::crop_imm( 303 | f_map_mat, 304 | xmin as u32, 305 | ymin as u32, 306 | roi_width as u32, 307 | roi_height as u32, 308 | ) 309 | .to_image(); 310 | 311 | let mean = OcrUtils::calculate_mean_with_mask(&cropped_img, &mask); 312 | 313 | Ok(mean) 314 | } 315 | 316 | fn unclip( 317 | box_points: &[imageproc::point::Point], 318 | unclip_ratio: f32, 319 | ) -> Result>, OcrError> { 320 | let points_arr = box_points.to_vec(); 321 | 322 | let clip_rect_width = ((points_arr[0].x - points_arr[1].x).powi(2) 323 | + (points_arr[0].y - points_arr[1].y).powi(2)) 324 | .sqrt(); 325 | let clip_rect_height = ((points_arr[1].x - points_arr[2].x).powi(2) 326 | + (points_arr[1].y - points_arr[2].y).powi(2)) 327 | .sqrt(); 328 | 329 | if clip_rect_height < 1.001 && clip_rect_width < 1.001 { 330 | return Ok(Vec::new()); 331 | } 332 | 333 | let mut the_cliper_pts = Vec::new(); 334 | for pt in box_points { 335 | let a1 = Coord { 336 | x: pt.x as f64, 337 | y: pt.y as f64, 338 | }; 339 | the_cliper_pts.push(a1); 340 | } 341 | 342 | let area = Self::signed_polygon_area(box_points).abs(); 343 | let length = Self::length_of_points(box_points); 344 | let distance = area * unclip_ratio / length as f32; 345 | 346 | let co = 
Polygon::new(LineString::new(the_cliper_pts), vec![]); 347 | let solution = co 348 | .offset( 349 | distance as f64, 350 | JoinType::Round(2.0), 351 | EndType::ClosedPolygon, 352 | 1.0, 353 | ) 354 | .0; 355 | 356 | if solution.is_empty() { 357 | return Ok(Vec::new()); 358 | } 359 | 360 | let mut ret_pts = Vec::new(); 361 | for ip in solution.first().unwrap().exterior().points() { 362 | ret_pts.push(imageproc::point::Point::new(ip.x() as i32, ip.y() as i32)); 363 | } 364 | 365 | Ok(ret_pts) 366 | } 367 | 368 | fn signed_polygon_area(points: &[imageproc::point::Point]) -> f32 { 369 | let num_points = points.len(); 370 | let mut pts = Vec::with_capacity(num_points + 1); 371 | pts.extend_from_slice(points); 372 | pts.push(points[0]); 373 | 374 | let mut area = 0.0; 375 | for i in 0..num_points { 376 | area += (pts[i + 1].x - pts[i].x) * (pts[i + 1].y + pts[i].y) / 2.0; 377 | } 378 | 379 | area 380 | } 381 | 382 | fn length_of_points(box_points: &[imageproc::point::Point]) -> f64 { 383 | if box_points.is_empty() { 384 | return 0.0; 385 | } 386 | 387 | let mut length = 0.0; 388 | let pt = box_points[0]; 389 | let mut x0 = pt.x as f64; 390 | let mut y0 = pt.y as f64; 391 | 392 | let mut box_with_first = Vec::from(box_points); 393 | box_with_first.push(pt); 394 | 395 | (1..box_with_first.len()).for_each(|idx| { 396 | let pts = box_with_first[idx]; 397 | let x1 = pts.x as f64; 398 | let y1 = pts.y as f64; 399 | let dx = x1 - x0; 400 | let dy = y1 - y0; 401 | 402 | length += (dx * dx + dy * dy).sqrt(); 403 | 404 | x0 = x1; 405 | y0 = y1; 406 | }); 407 | 408 | length 409 | } 410 | } 411 | --------------------------------------------------------------------------------