├── .gitignore ├── example ├── data │ └── paper-example.png └── predict.rs ├── examples ├── data │ └── paper-example.png ├── predict_yolox.rs ├── predict_detectron2.rs ├── boxes.rs └── ocr.rs ├── src ├── models │ ├── mod.rs │ ├── detectron2.rs │ └── yolox.rs ├── lib.rs ├── error.rs ├── utils.rs ├── layout_element.rs └── ocr │ ├── hocr_ext.rs │ └── mod.rs ├── README.md ├── Cargo.toml ├── LICENSE └── Cargo.lock /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | -------------------------------------------------------------------------------- /example/data/paper-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/styrowolf/layoutparser-ort/HEAD/example/data/paper-example.png -------------------------------------------------------------------------------- /examples/data/paper-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/styrowolf/layoutparser-ort/HEAD/examples/data/paper-example.png -------------------------------------------------------------------------------- /src/models/mod.rs: -------------------------------------------------------------------------------- 1 | //! Implemented layout models. 
2 | 3 | mod detectron2; 4 | mod yolox; 5 | 6 | pub use detectron2::{Detectron2Model, Detectron2PretrainedModels}; 7 | pub use yolox::{YOLOXModel, YOLOXPretrainedModels}; 8 | -------------------------------------------------------------------------------- /example/predict.rs: -------------------------------------------------------------------------------- 1 | use layoutparser_ort::{Detectron2ONNXModel, Result}; 2 | 3 | fn main() -> Result<()> { 4 | let img = image::open("data/paper-example.png").unwrap(); 5 | 6 | let model = Detectron2ONNXModel::new(layoutparser_ort::ModelType::FasterRCNN)?; 7 | 8 | let predictions = model.predict(&img)?; 9 | 10 | println!("{:?}", predictions); 11 | 12 | Ok(()) 13 | } 14 | -------------------------------------------------------------------------------- /examples/predict_yolox.rs: -------------------------------------------------------------------------------- 1 | use layoutparser_ort::{ 2 | models::{YOLOXModel, YOLOXPretrainedModels}, 3 | Result, 4 | }; 5 | 6 | fn main() -> Result<()> { 7 | let img = image::open("examples/data/paper-example.png").unwrap(); 8 | 9 | let model = YOLOXModel::pretrained(YOLOXPretrainedModels::Tiny)?; 10 | 11 | let predictions = model.predict(&img)?; 12 | 13 | println!("{:?}", predictions); 14 | 15 | Ok(()) 16 | } 17 | -------------------------------------------------------------------------------- /examples/predict_detectron2.rs: -------------------------------------------------------------------------------- 1 | use layoutparser_ort::{ 2 | models::{Detectron2Model, Detectron2PretrainedModels}, 3 | Result, 4 | }; 5 | 6 | fn main() -> Result<()> { 7 | let img = image::open("examples/data/paper-example.png").unwrap(); 8 | 9 | let model = Detectron2Model::pretrained(Detectron2PretrainedModels::FASTER_RCNN_R_50_FPN_3X)?; 10 | 11 | let predictions = model.predict(&img)?; 12 | 13 | println!("{:?}", predictions); 14 | 15 | Ok(()) 16 | } 17 | 
-------------------------------------------------------------------------------- /examples/boxes.rs: -------------------------------------------------------------------------------- 1 | use layoutparser_ort::{ 2 | models::{Detectron2Model, Detectron2PretrainedModels}, 3 | Result, 4 | }; 5 | 6 | fn main() -> Result<()> { 7 | let img = image::open("examples/data/paper-example.png").unwrap(); 8 | 9 | let model = Detectron2Model::pretrained(Detectron2PretrainedModels::FASTER_RCNN_R_50_FPN_3X)?; 10 | 11 | let predictions = model.predict(&img)?; 12 | 13 | for pred in &predictions { 14 | println!( 15 | "Label: {}, Confidence: {}, Box: {:?}", 16 | pred.element_type, pred.confidence, pred.bbox 17 | ); 18 | } 19 | 20 | println!("{:?}", predictions); 21 | 22 | Ok(()) 23 | } 24 | -------------------------------------------------------------------------------- /examples/ocr.rs: -------------------------------------------------------------------------------- 1 | use layoutparser_ort::{ 2 | models::{Detectron2Model, Detectron2PretrainedModels}, 3 | ocr::TesseractAgent, 4 | Result, 5 | }; 6 | 7 | fn main() -> Result<()> { 8 | let img = image::open("examples/data/paper-example.png").unwrap(); 9 | 10 | let model = Detectron2Model::pretrained(Detectron2PretrainedModels::FASTER_RCNN_R_50_FPN_3X)?; 11 | 12 | let mut predictions = model.predict(&img)?; 13 | 14 | let mut agent = TesseractAgent::new()?; 15 | 16 | for pred in predictions.iter_mut().filter(|e| e.element_type == "Text") { 17 | pred.pad(5.0); 18 | agent.extract_text_to_lm(pred, &img)?; 19 | println!("{:?}", pred.text.as_ref().unwrap()); 20 | } 21 | 22 | Ok(()) 23 | } 24 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! # Overview 2 | //! 3 | //! A simplified port of [LayoutParser](https://github.com/Layout-Parser/layout-parser) for detecting layout elements on documents. 4 | //! 
Runs Detectron2 and YOLOX layout models from [unstructured-inference](https://github.com/Unstructured-IO/unstructured-inference/) 5 | //! in ONNX format through onnxruntime (bindings via [ort](https://github.com/pykeio/ort)). 6 | 7 | mod error; 8 | mod layout_element; 9 | #[cfg(feature = "ocr")] 10 | pub mod ocr; 11 | mod utils; 12 | 13 | pub use error::{Error, Result}; 14 | 15 | // re-exports 16 | pub use geo_types; 17 | pub use image; 18 | pub use ort; 19 | 20 | pub mod models; 21 | #[cfg(feature = "save")] 22 | pub use utils::save; 23 | 24 | pub use layout_element::LayoutElement; 25 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | #[derive(Error, Debug)] 4 | /// layoutparser-ort error variants 5 | pub enum Error { 6 | #[error("ort (onnxruntime) error: {0}")] 7 | /// [`ort`] (onnxruntime) error 8 | Ort(#[from] ort::Error), 9 | #[error("hf-hub error: {0}")] 10 | /// Hugging Face API error 11 | HuggingFace(#[from] hf_hub::api::sync::ApiError), 12 | #[error("tesseract error: {0}")] 13 | #[cfg(feature = "ocr")] 14 | /// Tesseract error 15 | TesseractError(#[from] tesseract::TesseractError), 16 | #[error("hocr-parser error: {0}")] 17 | #[cfg(feature = "ocr")] 18 | /// hOCR parsing error 19 | HOCRParserError(#[from] hocr_parser::HOCRParserError), 20 | #[error("hOCR element conversion error: {0}")] 21 | #[cfg(feature = "ocr")] 22 | /// hOCR element conversion error 23 | HOCRElementConversionError(#[from] crate::ocr::HOCRElementConversionError), 24 | } 25 | 26 | /// A `Result` type alias using [`enum@Error`] instances as the error variant. 
27 | pub type Result = std::result::Result; 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # layoutparser-ort 2 | 3 | A simplified port of [LayoutParser](https://github.com/Layout-Parser/layout-parser) for detecting layout elements on documents. Runs Detectron2 and YOLOX layout models from [unstructured-inference](https://github.com/Unstructured-IO/unstructured-inference/) in ONNX format through onnxruntime (bindings via [ort](https://github.com/pykeio/ort)). [Check out the examples for a quick start!](examples/) 4 | 5 | ## License 6 | 7 | `layoutparser-ort` mirrors its API from [LayoutParser](https://github.com/Layout-Parser/layout-parser) and includes preprocessing code derived from [unstructured-inference](https://github.com/Unstructured-IO/unstructured-inference/), both licensed under the Apache License 2.0. Likewise, `layoutparser-ort` is licensed under the Apache License 2.0. 8 | 9 | ## Appendix: Similar libraries 10 | - [surya](https://github.com/VikParuchuri/surya): OCR, layout analysis, reading order, line detection in 90+ languages 11 | - SegFormer (transformers: SegFormer), Donut (transformers: Donut), CRAFT (pytorch) 12 | - License: GPLv3.0 (code), cc-by-nc-sa-4.0 (models) 13 | - cc-by-nc-sa-4.0: noncommerical but author "waive[s] that for any organization under $5M USD in gross revenue in the most recent 12-month period." 
14 | - [unstructured-inference](https://github.com/Unstructured-IO/unstructured-inference/): hosted model inference code for layout parsing models 15 | - Models: Detectron2 (LayoutParser-PubLayNet-PyTorch, LayoutParser-PubLayNet-ONNX), YOLOX (probably trained on DocLayNet, Quantized, ONNX), Table-Transformer (transformers: Table Transformer), Donut (transformers: Donut) 16 | - License: Apache 2.0 17 | - [LayoutParser](https://github.com/Layout-Parser/layout-parser): A Unified Toolkit for Deep Learning Based Document Image Analysis 18 | - Models: Detectron2 19 | - License: Apache 2.0 20 | - Documentation: https://layout-parser.readthedocs.io/en/latest/api_doc/elements.html -------------------------------------------------------------------------------- /src/utils.rs: -------------------------------------------------------------------------------- 1 | use ndarray::prelude::*; 2 | use ndarray::Data; 3 | 4 | use std::cmp::Ordering; 5 | 6 | // argsort_by function from: https://github.com/rust-ndarray/ndarray/issues/1145 7 | pub fn argsort_by(arr: &ArrayBase, mut compare: F) -> Vec 8 | where 9 | S: Data, 10 | F: FnMut(&S::Elem, &S::Elem) -> Ordering, 11 | { 12 | let mut indices: Vec = (0..arr.len()).collect(); 13 | indices.sort_by(move |&i, &j| compare(&arr[i], &arr[j])); 14 | indices 15 | } 16 | 17 | pub(crate) fn vec_to_bbox(v: Vec) -> [T; 4] { 18 | return [v[0], v[1], v[2], v[3]]; 19 | } 20 | 21 | #[cfg(feature = "save")] 22 | pub(crate) mod save { 23 | use ndarray::{Array1, Array2}; 24 | 25 | pub fn savetxt(a: &Array2, filename: &str) { 26 | let file = std::fs::File::create(filename).unwrap(); 27 | let mut writer = csv::Writer::from_writer(file); 28 | for row in a.outer_iter() { 29 | writer.serialize(row.iter().collect::>()).unwrap(); 30 | } 31 | writer.flush().unwrap(); 32 | } 33 | 34 | pub fn savetxt_u32(a: &Array2, filename: &str) { 35 | let file = std::fs::File::create(filename).unwrap(); 36 | let mut writer = csv::Writer::from_writer(file); 37 | for row in 
a.outer_iter() { 38 | writer.serialize(row.iter().collect::>()).unwrap(); 39 | } 40 | writer.flush().unwrap(); 41 | } 42 | 43 | pub fn savetxt_usize(a: &Array2, filename: &str) { 44 | let file = std::fs::File::create(filename).unwrap(); 45 | let mut writer = csv::Writer::from_writer(file); 46 | for row in a.outer_iter() { 47 | writer.serialize(row.iter().collect::>()).unwrap(); 48 | } 49 | writer.flush().unwrap(); 50 | } 51 | 52 | pub fn savetxt_usize_a1(a: &Array1, filename: &str) { 53 | let file = std::fs::File::create(filename).unwrap(); 54 | let mut writer = csv::Writer::from_writer(file); 55 | for row in a.outer_iter() { 56 | writer.serialize(row.iter().collect::>()).unwrap(); 57 | } 58 | writer.flush().unwrap(); 59 | } 60 | 61 | pub fn savetxt_f32_a1(a: &Array1, filename: &str) { 62 | let file = std::fs::File::create(filename).unwrap(); 63 | let mut writer = csv::Writer::from_writer(file); 64 | for row in a.outer_iter() { 65 | writer.serialize(row.iter().collect::>()).unwrap(); 66 | } 67 | writer.flush().unwrap(); 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "layoutparser-ort" 3 | version = "0.1.0" 4 | edition = "2021" 5 | readme = "README.md" 6 | license = "Apache-2.0" 7 | description = "A simplified port of LayoutParser for detecting layout elements on documents." 
8 | homepage = "https://github.com/styrowolf/layoutparser-ort" 9 | repository = "https://github.com/styrowolf/layoutparser-ort" 10 | keywords = ["document", "analysis", "layout", "deep-learning"] 11 | categories = ["computer-vision", "multimedia::images", "science", "visualization"] 12 | exclude = [ 13 | "examples/data/*" 14 | ] 15 | 16 | [[example]] 17 | name = "ocr" 18 | required-features = ["ocr", "png"] 19 | 20 | [[example]] 21 | name = "boxes" 22 | required-features = ["png"] 23 | 24 | [[example]] 25 | name = "predict_detectron2" 26 | required-features = ["png"] 27 | 28 | [[example]] 29 | name = "predict_yolox" 30 | required-features = ["png"] 31 | 32 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 33 | 34 | [dependencies] 35 | csv = { version = "1.3.0", optional = true } 36 | geo-types = "0.7.13" 37 | hf-hub = "0.3.2" 38 | image = { version = "0.25.1", default-features = false } 39 | itertools = "0.12.1" 40 | ndarray = "0.15.6" 41 | ort = "2.0.0-rc.2" 42 | serde = { version = "1.0.199", features = ["derive"], optional = true } 43 | tesseract = { version ="0.15.1", optional = true } 44 | thiserror = "1.0.59" 45 | tracing = "0.1.40" 46 | hocr-parser = { version = "0.1.0", optional = true } 47 | 48 | [features] 49 | default = [] 50 | serde = ["dep:serde"] 51 | save = ["dep:csv"] 52 | ocr = ["dep:tesseract", "dep:hocr-parser"] 53 | 54 | # ort/onnxruntime execution providers: https://ort.pyke.io/setup/cargo-features 55 | cuda = ["ort/cuda"] 56 | tensorrt = ["ort/tensorrt"] 57 | directml = ["ort/directml"] 58 | coreml = ["ort/coreml"] 59 | rocm = ["ort/rocm"] 60 | openvino = ["ort/openvino"] 61 | onednn = ["ort/onednn"] 62 | xnnpack = ["ort/xnnpack"] 63 | qnn = ["ort/qnn"] 64 | cann = ["ort/cann"] 65 | nnapi = ["ort/nnapi"] 66 | tvm = ["ort/tvm"] 67 | acl = ["ort/acl"] 68 | armnn = ["ort/armnn"] 69 | migraphx = ["ort/migraphx"] 70 | vitis = ["ort/vitis"] 71 | rknpu = ["ort/rknpu"] 72 | 73 | # image features 74 | 
default-formats = ["image/default-formats"] 75 | rayon = ["image/rayon"] 76 | 77 | avif = ["image/avif"] 78 | bmp = ["image/bmp"] 79 | dds = ["image/dds"] 80 | exr = ["image/exr"] 81 | ff = ["image/ff"] # Farbfeld image format 82 | gif = ["image/gif"] 83 | hdr = ["image/hdr"] 84 | ico = ["image/ico"] 85 | jpeg = ["image/jpeg"] 86 | png = ["image/png"] 87 | pnm = ["image/pnm"] 88 | qoi = ["image/qoi"] 89 | tga = ["image/tga"] 90 | tiff = ["image/tiff"] 91 | webp = ["image/webp"] 92 | 93 | nasm = ["image/nasm"] 94 | avif-native = ["image/avif-native"] -------------------------------------------------------------------------------- /src/layout_element.rs: -------------------------------------------------------------------------------- 1 | use geo_types::{coord, Coord, Rect}; 2 | 3 | #[derive(Debug, Clone)] 4 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 5 | /// A detected element in a document's layout. 6 | pub struct LayoutElement { 7 | /// Bounding box of the element. 8 | pub bbox: Rect, 9 | /// Type of element. This value depends on the labels of the model used to detect this element. 10 | pub element_type: String, 11 | // Confidence for the detection. 12 | pub confidence: f32, 13 | /// Source of the detection (the name of module which detected this element). 14 | pub source: String, 15 | /// Text within this element. This field is filled after OCR. 16 | pub text: Option, 17 | } 18 | 19 | impl LayoutElement { 20 | /// Constructs a [`LayoutElement`] instance. 21 | pub fn new( 22 | x1: f32, 23 | y1: f32, 24 | x2: f32, 25 | y2: f32, 26 | element_type: &str, 27 | confidence: f32, 28 | source: &str, 29 | ) -> Self { 30 | let bbox = Rect::new(coord! { x: x1, y: y1 }, coord! { x: x2, y: y2 }); 31 | 32 | Self { 33 | bbox, 34 | element_type: element_type.to_string(), 35 | confidence, 36 | text: None, 37 | source: source.to_string(), 38 | } 39 | } 40 | 41 | /// Constructs a [`LayoutElement`] instance with text. 
42 | pub fn new_with_text( 43 | x1: f32, 44 | y1: f32, 45 | x2: f32, 46 | y2: f32, 47 | element_type: &str, 48 | text: String, 49 | confidence: f32, 50 | source: &str, 51 | ) -> Self { 52 | let bbox = Rect::new(coord! { x: x1, y: y1 }, coord! { x: x2, y: y2 }); 53 | 54 | Self { 55 | bbox, 56 | element_type: element_type.to_string(), 57 | confidence, 58 | text: Some(text), 59 | source: source.to_string(), 60 | } 61 | } 62 | 63 | /// Pads the bounding box of a [`LayoutElement`]. Useful for OCRing the element. 64 | pub fn pad(&mut self, padding: f32) { 65 | self.bbox 66 | .set_min(self.bbox.min() - coord! { x: padding, y: padding }); 67 | self.bbox 68 | .set_max(self.bbox.max() + coord! { x: padding, y: padding }); 69 | } 70 | 71 | /// Crop the section of the image according to the [`LayoutElement`]'s bounding box. 72 | pub fn crop_from_image(&self, img: &image::DynamicImage) -> image::DynamicImage { 73 | let (x1, y1) = (self.bbox.min().x as u32, self.bbox.min().y as u32); 74 | let (width, height) = (self.bbox.width() as u32, self.bbox.height() as u32); 75 | 76 | img.clone().crop(x1, y1, width, height) 77 | } 78 | 79 | /// Apply a transformation to the bounding box points. 
80 | pub fn transform(&mut self, transform: impl Fn(Coord) -> Coord) { 81 | self.bbox.set_min(transform(self.bbox.min())); 82 | self.bbox.set_max(transform(self.bbox.max())); 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/ocr/hocr_ext.rs: -------------------------------------------------------------------------------- 1 | use hocr_parser::spec_definitions::elements; 2 | use hocr_parser::spec_definitions::properties; 3 | use hocr_parser::Element; 4 | 5 | use thiserror::Error; 6 | 7 | use crate::utils::vec_to_bbox; 8 | use crate::LayoutElement; 9 | 10 | #[derive(Debug, Error)] 11 | /// hOCR element conversion error variants 12 | pub enum HOCRElementConversionError { 13 | #[error("No bounding box found in element properties")] 14 | NoBoundingBoxFound, 15 | #[error("No confidence found in element properties")] 16 | NoConfidenceFound, 17 | } 18 | 19 | pub(crate) trait HOCRElementConversion { 20 | fn get_layout_element(&self) -> Result; 21 | fn bbox(&self) -> Option<[u32; 4]>; 22 | fn confidence(&self) -> Option; 23 | fn extract_text(&self) -> String; 24 | } 25 | 26 | impl HOCRElementConversion for Element { 27 | fn get_layout_element(&self) -> Result { 28 | let [x1, y1, x2, y2] = self 29 | .bbox() 30 | .ok_or(HOCRElementConversionError::NoBoundingBoxFound)?; 31 | 32 | Ok(LayoutElement::new_with_text( 33 | x1 as f32, 34 | y1 as f32, 35 | x2 as f32, 36 | y2 as f32, 37 | &self.element_type, 38 | self.extract_text(), 39 | self.confidence() 40 | .ok_or(HOCRElementConversionError::NoConfidenceFound)?, 41 | "hocr-parser", 42 | )) 43 | } 44 | 45 | fn confidence(&self) -> Option { 46 | match self.element_type.as_str() { 47 | elements::OCRX_WORD => self 48 | .properties 49 | .iter() 50 | .find(|(n, _)| n == properties::X_WCONF)? 
51 | .1[0] 52 | .parse::() 53 | .map(|e| e / 100.0) 54 | .ok(), 55 | _ => { 56 | let children: Vec = self 57 | .children 58 | .iter() 59 | .filter_map(|e| e.confidence()) 60 | .collect(); 61 | match children.len() { 62 | 0 => None, 63 | len => { 64 | let sum: f32 = children.iter().sum(); 65 | Some(sum / len as f32) 66 | } 67 | } 68 | } 69 | } 70 | } 71 | 72 | fn bbox(&self) -> Option<[u32; 4]> { 73 | let bbox_strs = &self 74 | .properties 75 | .iter() 76 | .find(|(n, _)| n == properties::BBOX)? 77 | .1; 78 | 79 | if bbox_strs.len() != 4 { 80 | return None; 81 | } 82 | 83 | let bbox = bbox_strs 84 | .iter() 85 | .map(|s| s.parse::().unwrap()) 86 | .collect::>(); 87 | 88 | Some(vec_to_bbox(bbox)) 89 | } 90 | 91 | fn extract_text(&self) -> String { 92 | match self.element_type.as_str() { 93 | elements::OCRX_WORD => self.text.clone().unwrap_or_default(), 94 | /* to filter by confidence if you want 95 | elements::OCR_LINE | elements::OCRX_LINE => { 96 | let confidence = self.confidence().unwrap_or(0.0); 97 | if confidence < 0.5 { 98 | return "".to_string(); 99 | } else { 100 | self.children 101 | .iter() 102 | .map(|e| e.extract_text()) 103 | .collect::>() 104 | .join(" ") 105 | + "\n" 106 | } 107 | }, 108 | */ 109 | _ => { 110 | self.children 111 | .iter() 112 | .map(|e| e.extract_text()) 113 | .collect::>() 114 | .join(" ") 115 | + "\n" 116 | } 117 | } 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /src/ocr/mod.rs: -------------------------------------------------------------------------------- 1 | //! OCR agents and utilities for extracting text. 
2 | 3 | mod hocr_ext; 4 | 5 | use hocr_ext::HOCRElementConversion; 6 | pub use hocr_ext::HOCRElementConversionError; 7 | 8 | pub use hocr_parser::roxmltree; 9 | use hocr_parser::spec_definitions::elements; 10 | use hocr_parser::Element; 11 | use hocr_parser::HOCRParserError; 12 | 13 | pub use hocr_parser; 14 | 15 | use crate::{LayoutElement, Result}; 16 | use tesseract::{Tesseract, TesseractError}; 17 | 18 | /// The Tesseract OCR Agent. 19 | /// 20 | /// Constructing the agent may require a valid `tessdata` directory 21 | /// depending on the constructor. 22 | pub struct TesseractAgent { 23 | /// Arguments the Tesseract instance was built with, kept so it can be rebuilt. 24 | arguments: TesseractInitArguments, 25 | /// The Tesseract instance; only `None` while a call has moved it out, because 26 | /// `Tesseract::set_frame` consumes it. 27 | inner: Option<Tesseract>, 28 | } 29 | 30 | /// How the Tesseract instance was constructed; replayed by `reinit`. 31 | enum TesseractInitArguments { 32 | Data { 33 | data: Vec<u8>, 34 | lang: String, 35 | }, 36 | DataPath { 37 | data_path: String, 38 | lang: String, 39 | }, 40 | Default, 41 | Generic { 42 | data_path: Option<String>, 43 | lang: Option<String>, 44 | }, 45 | } 46 | 47 | impl TesseractAgent { 48 | /// Construct a new [`TesseractAgent`] for English, using Tesseract's default data path. 49 | pub fn new() -> Result<Self> { 50 | let inner = Tesseract::new(None, Some("eng")).map_err(TesseractError::from)?; 51 | 52 | Ok(Self { 53 | inner: Some(inner), 54 | arguments: TesseractInitArguments::Default, 55 | }) 56 | }
57 | 58 | /// Construct a new [`TesseractAgent`] specifying the OCR languages. 59 | /// 60 | /// The languages are joined with `+`, e.g. `&["eng", "deu"]` becomes `eng+deu`. 61 | pub fn new_with_lang(lang: &[&str]) -> Result<Self> { 62 | let lang = lang.join("+"); 63 | let inner = Tesseract::new(None, Some(&lang)).map_err(TesseractError::from)?; 64 | 65 | Ok(Self { 66 | inner: Some(inner), 67 | arguments: TesseractInitArguments::Generic { 68 | data_path: None, 69 | lang: Some(lang), 70 | }, 71 | }) 72 | } 73 | 74 | /// Construct a new [`TesseractAgent`] with the tessdata file and specifying the OCR languages. 75 | pub fn new_with_data(data: &[u8], lang: &[&str]) -> Result<Self> { 76 | let lang = lang.join("+"); 77 | let inner = Tesseract::new_with_data(data, Some(&lang), tesseract::OcrEngineMode::Default) 78 | .map_err(TesseractError::from)?; 79 | 80 | Ok(Self { 81 | inner: Some(inner), 82 | arguments: TesseractInitArguments::Data { 83 | data: data.to_vec(), 84 | lang, 85 | }, 86 | }) 87 | } 88 | 89 | /// Construct a new [`TesseractAgent`] with the tessdata path and specifying the OCR languages. 90 | pub fn new_data_path(data_path: &str, lang: &[&str]) -> Result<Self> { 91 | let lang = lang.join("+"); 92 | // data_path is tessdata, which includes the traineddata files 93 | // https://github.com/tesseract-ocr/tessdata_fast 94 | // https://github.com/tesseract-ocr/tessdata 95 | let inner = Tesseract::new_with_oem( 96 | Some(data_path), 97 | Some(&lang), 98 | tesseract::OcrEngineMode::Default, 99 | ) 100 | .map_err(TesseractError::from)?; 101 | 102 | Ok(Self { 103 | inner: Some(inner), 104 | arguments: TesseractInitArguments::DataPath { 105 | data_path: data_path.to_string(), 106 | lang, 107 | }, 108 | }) 109 | }
110 | 111 | /// Rebuilds the Tesseract instance from the stored construction arguments. 112 | /// 113 | /// # Panics 114 | /// 115 | /// Panics if Tesseract fails to initialize with arguments that already succeeded once. 116 | fn reinit(&mut self) { 117 | // UNWRAP SAFETY: we constructed Tesseract before with these arguments, so it should be safe to unwrap 118 | let tesseract = match &self.arguments { 119 | TesseractInitArguments::Data { data, lang } => { 120 | Tesseract::new_with_data(data, Some(lang.as_str()), tesseract::OcrEngineMode::Default) 121 | .unwrap() 122 | } 123 | TesseractInitArguments::DataPath { data_path, lang } => Tesseract::new_with_oem( 124 | Some(data_path.as_str()), 125 | Some(lang.as_str()), 126 | tesseract::OcrEngineMode::Default, 127 | ) 128 | .unwrap(), 129 | // Same language `TesseractAgent::new` initialized with. 130 | TesseractInitArguments::Default => Tesseract::new(None, Some("eng")).unwrap(), 131 | TesseractInitArguments::Generic { data_path, lang } => { 132 | Tesseract::new(data_path.as_deref(), lang.as_deref()).unwrap() 133 | } 134 | }; 135 | self.inner = Some(tesseract); 136 | } 137 | 138 | /// Extracts the text within a [`LayoutElement`] and adds it to the element. 139 | /// 140 | /// The element's bounding box is cropped out of `img` with [`LayoutElement::crop_from_image`] 141 | /// and the recognized text is stored in [`LayoutElement::text`]. 142 | pub fn extract_text_to_lm( 143 | &mut self, 144 | lm: &mut LayoutElement, 145 | img: &image::DynamicImage, 146 | ) -> Result<()> { 147 | let img = lm.crop_from_image(img); 148 | let text = self.extract_text(&img)?; 149 | lm.text = Some(text); 150 | Ok(()) 151 | } 152 | 153 | /// Extracts the text from an image. 154 | pub fn extract_text(&mut self, img: &image::DynamicImage) -> Result<String> { 155 | self.with_frame(img, |tess| Ok(tess.get_text().map_err(TesseractError::from)?)) 156 | }
157 | 158 | /// Extracts the text regions as [`LayoutElement`] according to the OCR [`FeatureType`] from an image. 159 | /// 160 | /// # Errors 161 | /// 162 | /// Besides Tesseract and hOCR parsing errors, returns a [`HOCRElementConversionError`] if a 163 | /// matching hOCR element has no usable bounding box or confidence. 164 | pub fn extract( 165 | &mut self, 166 | img: &image::DynamicImage, 167 | feature: FeatureType, 168 | ) -> Result<Vec<LayoutElement>> { 169 | let hocr = self.with_frame(img, |tess| { 170 | Ok(tess.get_hocr_text(0).map_err(TesseractError::from)?) 171 | })?; 172 | 173 | let document = roxmltree::Document::parse(&hocr).map_err(HOCRParserError::from)?; 174 | let root = Element::from_node(document.root_element())?; 175 | let wanted = feature.as_hocr_element(); 176 | 177 | // The root element is a candidate as well, not only its descendants. 178 | std::iter::once(&root) 179 | .chain(root.descendants()) 180 | .filter(|e| e.element_type == wanted) 181 | .map(|e| e.get_layout_element().map_err(crate::Error::from)) 182 | .collect() 183 | } 184 | 185 | /// Loads `img` into Tesseract as an RGBA frame and runs `recognize` on the loaded instance. 186 | /// 187 | /// On every return path the agent holds a Tesseract instance again: it is put back after 188 | /// `recognize` (even when that fails) and rebuilt with `reinit` if `set_frame` fails. 189 | fn with_frame<T>( 190 | &mut self, 191 | img: &image::DynamicImage, 192 | recognize: impl FnOnce(&mut Tesseract) -> Result<T>, 193 | ) -> Result<T> { 194 | let img = img.to_rgba8(); 195 | let (width, height) = img.dimensions(); 196 | let bytes_per_line = 4 * width; 197 | 198 | // `set_frame` consumes the instance, so it has to be moved out of `self` first. 199 | let inner = self 200 | .inner 201 | .take() 202 | .expect("TesseractAgent restores its Tesseract instance after every call"); 203 | 204 | let mut inner = match inner.set_frame( 205 | img.as_raw(), 206 | width as i32, 207 | height as i32, 208 | 4, 209 | bytes_per_line as i32, 210 | ) { 211 | Ok(tess) => tess, 212 | Err(err) => { 213 | // The previous instance was consumed; rebuild one so the agent stays usable. 214 | self.reinit(); 215 | return Err(TesseractError::from(err).into()); 216 | } 217 | }; 218 | 219 | let result = recognize(&mut inner); 220 | // Put the instance back even if `recognize` failed, so the next call finds it. 
221 | self.inner = Some(inner); 222 | result 223 | } 224 | }
225 | 226 | /// OCR Feature types. Useful for extracting text regions at a specific level. 227 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 228 | pub enum FeatureType { 229 | /// Page 230 | Page, 231 | /// Block 232 | Block, 233 | /// Paragraph 234 | Para, 235 | /// Line 236 | Line, 237 | /// Word 238 | Word, 239 | } 240 | 241 | impl FeatureType { 242 | /// Convert to the corresponding hOCR element type. 243 | pub fn as_hocr_element(&self) -> &'static str { 244 | match self { 245 | FeatureType::Page => elements::OCR_PAGE, 246 | FeatureType::Block => elements::OCR_CAREA, 247 | FeatureType::Para => elements::OCR_PAR, 248 | FeatureType::Line => elements::OCRX_LINE, 249 | FeatureType::Word => elements::OCRX_WORD, 250 | } 251 | } 252 | } 253 | -------------------------------------------------------------------------------- /src/models/detectron2.rs: -------------------------------------------------------------------------------- 1 | use image::imageops; 2 | use ndarray::{Array, ArrayBase, Dim, OwnedRepr}; 3 | use ort::{Session, SessionBuilder, SessionOutputs}; 4 | 5 | pub use crate::error::Result; 6 | use crate::{utils::vec_to_bbox, LayoutElement}; 7 | 8 | /// A [`Detectron2`](https://github.com/facebookresearch/detectron2)-based model. 9 | pub struct Detectron2Model { 10 | model_name: String, 11 | model: ort::Session, 12 | confidence_threshold: f32, 13 | label_map: Vec<(i64, String)>, 14 | confidence_score_index: usize, 15 | } 16 | 17 | // Variant names mirror the upstream model names, hence the non-camel-case spelling. 18 | #[allow(non_camel_case_types)] 19 | /// Pretrained Detectron2-based models from Hugging Face. 20 | pub enum Detectron2PretrainedModels { 21 | /// Faster R-CNN model from `unstructuredio/detectron2_faster_rcnn_R_50_FPN_3x`. 22 | FASTER_RCNN_R_50_FPN_3X, 23 | /// Mask R-CNN model from `unstructuredio/detectron2_mask_rcnn_X_101_32x8d_FPN_3x`. 24 | MASK_RCNN_X_101_32X8D_FPN_3x, 25 | } 26 | 27 | impl Detectron2PretrainedModels { 28 | /// Model name (the same as its Hugging Face repository). 29 | pub fn name(&self) -> &str { 30 | self.hf_repo() 31 | } 32 | 33 | /// Hugging Face repository for this model. 
34 | pub fn hf_repo(&self) -> &str { 35 | match self { 36 | Self::FASTER_RCNN_R_50_FPN_3X => "unstructuredio/detectron2_faster_rcnn_R_50_FPN_3x", 37 | Self::MASK_RCNN_X_101_32X8D_FPN_3x => { 38 | "unstructuredio/detectron2_mask_rcnn_X_101_32x8d_FPN_3x" 39 | } 40 | } 41 | } 42 | 43 | /// Path for this model file in Hugging Face repository. 44 | pub fn hf_filename(&self) -> &str { 45 | match self { 46 | Self::FASTER_RCNN_R_50_FPN_3X | Self::MASK_RCNN_X_101_32X8D_FPN_3x => "model.onnx", 47 | } 48 | }
49 | 50 | /// The label map for this model, as `(class id, label)` pairs. 51 | pub fn label_map(&self) -> Vec<(i64, String)> { 52 | // Both models share the same five labels, numbered from 0. 53 | let labels = match self { 54 | Self::FASTER_RCNN_R_50_FPN_3X | Self::MASK_RCNN_X_101_32X8D_FPN_3x => { 55 | ["Text", "Title", "List", "Table", "Figure"] 56 | } 57 | }; 58 | labels 59 | .iter() 60 | .enumerate() 61 | .map(|(i, l)| (i as i64, l.to_string())) 62 | .collect() 63 | } 64 | 65 | /// Index for the confidence score in this model's outputs. 66 | pub fn confidence_score_index(&self) -> usize { 67 | match self { 68 | Detectron2PretrainedModels::FASTER_RCNN_R_50_FPN_3X => 2, 69 | Detectron2PretrainedModels::MASK_RCNN_X_101_32X8D_FPN_3x => 3, 70 | } 71 | } 72 | } 73 | 74 | impl Detectron2Model { 75 | /// Required input image width. 76 | pub const REQUIRED_WIDTH: u32 = 800; 77 | /// Required input image height. 78 | pub const REQUIRED_HEIGHT: u32 = 1035; 79 | /// Default confidence threshold for detections. 80 | pub const DEFAULT_CONFIDENCE_THRESHOLD: f32 = 0.8; 81 | 82 | /// Construct a [`Detectron2Model`] with a pretrained model downloaded from Hugging Face. 83 | /// 84 | /// Uses [`Self::DEFAULT_CONFIDENCE_THRESHOLD`] and a default ONNX Runtime session builder. 85 | pub fn pretrained(p_model: Detectron2PretrainedModels) -> Result<Self> { 86 | Self::configure_pretrained( 87 | p_model, 88 | Self::DEFAULT_CONFIDENCE_THRESHOLD, 89 | Session::builder()?, 90 | ) 91 | } 92 | 93 | /// Construct a configured [`Detectron2Model`] with a pretrained model downloaded from Hugging Face. 94 | /// 95 | /// The ONNX file is loaded with `session_builder`; `confidence_threshold` is stored as the model's detection threshold. 96 | pub fn configure_pretrained( 97 | p_model: Detectron2PretrainedModels, 98 | confidence_threshold: f32, 99 | session_builder: SessionBuilder, 100 | ) -> Result<Self> { 101 | let api = hf_hub::api::sync::Api::new()?; 102 | let filename = api 103 | .model(p_model.hf_repo().to_string()) 104 | .get(p_model.hf_filename())?;
105 | 106 | let model = session_builder.commit_from_file(filename)?; 107 | 108 | Ok(Self { 109 | model_name: p_model.name().to_string(), 110 | model, 111 | label_map: p_model.label_map(), 112 | confidence_threshold, 113 | confidence_score_index: p_model.confidence_score_index(), 114 | }) 115 | } 116 | 117 | /// Construct a [`Detectron2Model`] from a model file. 118 | pub fn new_from_file( 119 | file_path: &str, 120 | model_name: &str, 121 | label_map: &[(i64, &str)], 122 | confidence_threshold: f32, 123 | confidence_score_index: usize, 124 | session_builder: SessionBuilder, 125 | ) -> Result<Self> { 126 | let model = session_builder.commit_from_file(file_path)?; 127 | 128 | Ok(Self { 129 | model_name: model_name.to_string(), 130 | model, 131 | label_map: label_map.iter().map(|(i, l)| (*i, l.to_string())).collect(), 132 | confidence_threshold, 133 | confidence_score_index, 134 | }) 135 | } 136 | 137 | /// Predict [`LayoutElement`]s from the image provided. 138 | /// 139 | /// If ONNX Runtime fails while running the model, the error is logged and an empty list is 140 | /// returned (the log message attributes this to blank pages); errors from building the 141 | /// inputs or from postprocessing are still returned as `Err`. 142 | pub fn predict(&self, img: &image::DynamicImage) -> Result<Vec<LayoutElement>> { 143 | let (img_width, img_height, input) = self.preprocess(img); 144 | 145 | match self.model.run(ort::inputs!["x.1" => input]?) { 146 | Ok(outputs) => self.postprocess(&outputs, img_width, img_height), 147 | Err(err) => { 148 | // Deliberate best-effort, but record which error was swallowed. 149 | tracing::warn!( 150 | error = %err, 151 | "Ignoring runtime error from onnx (likely due to encountering blank page)." 152 | ); 153 | Ok(vec![]) 154 | } 155 | } 156 | }
157 | 158 | /// Resizes `img` to [`Self::REQUIRED_WIDTH`] x [`Self::REQUIRED_HEIGHT`] and converts it to a 159 | /// `(3, height, width)` tensor of raw 0-255 `f32` RGB values (alpha dropped, no normalization). 160 | /// 161 | /// Also returns the *original* width and height, which postprocessing uses to rescale boxes. 162 | fn preprocess( 163 | &self, 164 | img: &image::DynamicImage, 165 | ) -> (u32, u32, ArrayBase<OwnedRepr<f32>, Dim<[usize; 3]>>) { 166 | let (img_width, img_height) = (img.width(), img.height()); 167 | let rgb = img 168 | .resize_exact( 169 | Self::REQUIRED_WIDTH, 170 | Self::REQUIRED_HEIGHT, 171 | imageops::FilterType::Triangle, 172 | ) 173 | .into_rgb8(); 174 | 175 | // Shape follows the required size rather than repeating the literals. 176 | let mut input = Array::zeros(( 177 | 3, 178 | Self::REQUIRED_HEIGHT as usize, 179 | Self::REQUIRED_WIDTH as usize, 180 | )); 181 | 182 | for (x, y, pixel) in rgb.enumerate_pixels() { 183 | let (x, y) = (x as usize, y as usize); 184 | let [r, g, b] = pixel.0; 185 | input[[0, y, x]] = r as f32; 186 | input[[1, y, x]] = g as f32; 187 | input[[2, y, x]] = b as f32; 188 | } 189 | 190 | (img_width, img_height, input) 191 | } 192 | 193 | fn postprocess<'s>( 194 | &self, 195 | outputs: &SessionOutputs<'s>, 196 | img_width: u32, 197 | img_height: u32, 198 | ) -> Result> { 199 | let bboxes = &outputs[0].try_extract_tensor::()?; 200 | let labels = &outputs[1].try_extract_tensor::()?; 201 | let confidence_scores = 202 | &outputs[self.confidence_score_index].try_extract_tensor::()?; 203 | 204 | let width_conversion = img_width as f32 / Self::REQUIRED_WIDTH as f32; 205 | let height_conversion = img_height as f32 / Self::REQUIRED_HEIGHT as f32; 206 | 207 | let mut elements = vec![]; 208 | 209 | for (bbox, (label, confidence_score)) in bboxes 210 | .rows() 211 | .into_iter() 212 | .zip(labels.iter().zip(confidence_scores)) 213
| { 214 | let [x1, y1, x2, y2] = vec_to_bbox(bbox.iter().copied().collect()); 215 | 216 | let detected_label = &self 217 | .label_map 218 | .iter() 219 | .find(|(l_i, _)| l_i == label) 220 | .unwrap() // SAFETY: the model always yields one of these labels 221 | .1; 222 | 223 | if *confidence_score > self.confidence_threshold as f32 { 224 | elements.push(LayoutElement::new( 225 | x1 * width_conversion, 226 | y1 * height_conversion, 227 | x2 * width_conversion, 228 | y2 * height_conversion, 229 | &detected_label, 230 | *confidence_score, 231 | &self.model_name, 232 | )) 233 | } 234 | } 235 | 236 | elements.sort_by(|a, b| a.bbox.max().y.total_cmp(&b.bbox.max().y)); 237 | 238 | return Ok(elements); 239 | } 240 | } 241 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 LayoutParser Authors 190 | Copyright 2024 Unstructured Technologies Inc. 191 | Copyright 2024 Oğuz Kurt 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | -------------------------------------------------------------------------------- /src/models/yolox.rs: -------------------------------------------------------------------------------- 1 | use image::imageops; 2 | use itertools::Itertools; 3 | use ndarray::{ 4 | concatenate, s, stack, Array, Array1, Array2, ArrayBase, ArrayD, Axis, Dim, IxDyn, OwnedRepr, 5 | }; 6 | use ort::{Session, SessionBuilder, SessionOutputs}; 7 | 8 | pub use crate::error::Result; 9 | use crate::{utils, LayoutElement}; 10 | 11 | /// A [`YOLOX`](https://github.com/Megvii-BaseDetection/YOLOX)-based model. 12 | pub struct YOLOXModel { 13 | model_name: String, 14 | model: ort::Session, 15 | is_quantized: bool, 16 | label_map: Vec<(i64, String)>, 17 | } 18 | 19 | #[derive(PartialEq)] 20 | /// Pretrained YOLOX-based models from Hugging Face. 
21 | pub enum YOLOXPretrainedModels { 22 | Large, 23 | LargeQuantized, 24 | Tiny, 25 | } 26 | 27 | impl YOLOXPretrainedModels { 28 | /// Model name. 29 | pub fn name(&self) -> &str { 30 | match self { 31 | _ => self.hf_repo(), 32 | } 33 | } 34 | 35 | /// Hugging Face repository for this model. 36 | pub fn hf_repo(&self) -> &str { 37 | match self { 38 | _ => "unstructuredio/yolo_x_layout", 39 | } 40 | } 41 | 42 | /// Path for this model file in Hugging Face repository. 43 | pub fn hf_filename(&self) -> &str { 44 | match self { 45 | YOLOXPretrainedModels::Large => "yolox_l0.05.onnx", 46 | YOLOXPretrainedModels::LargeQuantized => "yolox_l0.05_quantized.onnx", 47 | YOLOXPretrainedModels::Tiny => "yolox_tiny.onnx", 48 | } 49 | } 50 | 51 | /// The label map for this model. 52 | pub fn label_map(&self) -> Vec<(i64, String)> { 53 | match self { 54 | _ => Vec::from_iter( 55 | [ 56 | (0, "Caption"), 57 | (1, "Footnote"), 58 | (2, "Formula"), 59 | (3, "List-item"), 60 | (4, "Page-footer"), 61 | (5, "Page-header"), 62 | (6, "Picture"), 63 | (7, "Section-header"), 64 | (8, "Table"), 65 | (9, "Text"), 66 | (10, "Title"), 67 | ] 68 | .iter() 69 | .map(|(i, l)| (*i as i64, l.to_string())), 70 | ), 71 | } 72 | } 73 | } 74 | 75 | impl YOLOXModel { 76 | /// Required input image width. 77 | pub const REQUIRED_WIDTH: u32 = 768; 78 | /// Required input image height. 79 | pub const REQUIRED_HEIGHT: u32 = 1024; 80 | 81 | /// Construct a [`YOLOXModel`] with a pretrained model downloaded from Hugging Face. 
82 | pub fn pretrained(p_model: YOLOXPretrainedModels) -> Result { 83 | let session_builder = Session::builder()?; 84 | let api = hf_hub::api::sync::Api::new()?; 85 | let filename = api 86 | .model(p_model.hf_repo().to_string()) 87 | .get(p_model.hf_filename())?; 88 | 89 | let model = session_builder.commit_from_file(filename)?; 90 | 91 | Ok(Self { 92 | model_name: p_model.name().to_string(), 93 | model, 94 | label_map: p_model.label_map(), 95 | is_quantized: p_model == YOLOXPretrainedModels::LargeQuantized, 96 | }) 97 | } 98 | 99 | /// Construct a configured [`YOLOXModel`] with a pretrained model downloaded from Hugging Face. 100 | pub fn configure_pretrained( 101 | p_model: YOLOXPretrainedModels, 102 | session_builder: SessionBuilder, 103 | ) -> Result { 104 | let api = hf_hub::api::sync::Api::new()?; 105 | let filename = api 106 | .model(p_model.hf_repo().to_string()) 107 | .get(p_model.hf_filename())?; 108 | 109 | let model = session_builder.commit_from_file(filename)?; 110 | 111 | Ok(Self { 112 | model_name: p_model.name().to_string(), 113 | model, 114 | label_map: p_model.label_map(), 115 | is_quantized: p_model == YOLOXPretrainedModels::LargeQuantized, 116 | }) 117 | } 118 | 119 | /// Construct a [`YOLOXModel`] from a model file. 120 | pub fn new_from_file( 121 | file_path: &str, 122 | model_name: &str, 123 | label_map: &[(i64, &str)], 124 | is_quantized: bool, 125 | session_builder: SessionBuilder, 126 | ) -> Result { 127 | let model = session_builder.commit_from_file(file_path)?; 128 | 129 | Ok(Self { 130 | model_name: model_name.to_string(), 131 | model, 132 | label_map: label_map.iter().map(|(i, l)| (*i, l.to_string())).collect(), 133 | is_quantized, 134 | }) 135 | } 136 | 137 | /// Predict [`LayoutElement`]s from the image provided. 
138 | pub fn predict(&self, img: &image::DynamicImage) -> Result> { 139 | // UNWRAP SAFETY: shape unwraps are never a problem because we know the size of the output tensor 140 | let (input, r) = self.preprocess(img); 141 | 142 | let input_name = &self.model.inputs[0].name; 143 | 144 | let run_result = self.model.run(ort::inputs![input_name => input]?); 145 | match run_result { 146 | Ok(outputs) => { 147 | let predictions = self 148 | .postprocess(&outputs, false)? 149 | .slice(s![0, .., ..]) 150 | .to_owned(); 151 | 152 | let boxes = predictions 153 | .slice(s![.., 0..4]) 154 | .to_shape([16128, 4]) 155 | .unwrap() 156 | .to_owned(); 157 | let scores = predictions 158 | .slice(s![.., 4..5]) 159 | .to_shape([16128, 1]) 160 | .unwrap() 161 | .to_owned() 162 | * predictions.slice(s![.., 5..]); 163 | 164 | let mut boxes_xyxy: Array = ndarray::Array::ones([16128, 4]); 165 | 166 | let s0 = 167 | boxes.slice(s![.., 0]).to_owned() - (boxes.slice(s![.., 2]).to_owned() / 2.0); 168 | let s1 = 169 | boxes.slice(s![.., 1]).to_owned() - (boxes.slice(s![.., 3]).to_owned() / 2.0); 170 | let s2 = 171 | boxes.slice(s![.., 0]).to_owned() + (boxes.slice(s![.., 2]).to_owned() / 2.0); 172 | let s3 = 173 | boxes.slice(s![.., 1]).to_owned() + (boxes.slice(s![.., 3]).to_owned() / 2.0); 174 | 175 | boxes_xyxy 176 | .slice_mut(s![.., 0]) 177 | .iter_mut() 178 | .zip_eq(s0.iter()) 179 | .for_each(|(old, new)| *old = *new); 180 | boxes_xyxy 181 | .slice_mut(s![.., 1]) 182 | .iter_mut() 183 | .zip_eq(s1.iter()) 184 | .for_each(|(old, new)| *old = *new); 185 | boxes_xyxy 186 | .slice_mut(s![.., 2]) 187 | .iter_mut() 188 | .zip_eq(s2.iter()) 189 | .for_each(|(old, new)| *old = *new); 190 | boxes_xyxy 191 | .slice_mut(s![.., 3]) 192 | .iter_mut() 193 | .zip_eq(s3.iter()) 194 | .for_each(|(old, new)| *old = *new); 195 | 196 | boxes_xyxy /= r; 197 | 198 | let mut regions = vec![]; 199 | 200 | let (nms_thr, score_thr) = if self.is_quantized { 201 | (0.0, 0.07) 202 | } else { 203 | (0.1, 0.25) 204 | 
}; 205 | 206 | let dets = multiclass_nms_class_agnostic(&boxes_xyxy, &scores, nms_thr, score_thr); 207 | 208 | for det in dets.outer_iter() { 209 | let [x1, y1, x2, y2, prob, class_id] = 210 | extract_bbox_etc(&det.into_iter().copied().collect()); 211 | let detected_class = self.get_label(class_id as i64); 212 | regions.push(LayoutElement::new( 213 | x1, 214 | y1, 215 | x2, 216 | y2, 217 | &detected_class, 218 | prob, 219 | &self.model_name, 220 | )); 221 | } 222 | 223 | regions.sort_by(|a, b| a.bbox.max().y.total_cmp(&b.bbox.max().y)); 224 | 225 | return Ok(regions); 226 | } 227 | Err(_err) => { 228 | eprintln!("{_err:?}"); 229 | tracing::warn!( 230 | "Ignoring runtime error from onnx (likely due to encountering blank page)." 231 | ); 232 | return Ok(vec![]); 233 | } 234 | } 235 | } 236 | 237 | fn postprocess<'s>( 238 | &self, 239 | outputs: &SessionOutputs<'s>, 240 | p6: bool, 241 | ) -> Result>> { 242 | let output_m = &outputs[0].try_extract_tensor::()?; 243 | let mut shaped_output = output_m.to_shape([1, 16128, 16]).unwrap().to_owned(); 244 | 245 | let strides = if !p6 { 246 | vec![8, 16, 32] 247 | } else { 248 | vec![8, 16, 32, 64] 249 | }; 250 | 251 | let hsizes: Vec = strides.iter().map(|s| Self::REQUIRED_HEIGHT / s).collect(); 252 | let wsizes: Vec = strides.iter().map(|s| Self::REQUIRED_WIDTH / s).collect(); 253 | 254 | let mut grids = vec![]; 255 | let mut expanded_strides = vec![]; 256 | 257 | for (stride, (hsize, wsize)) in strides.iter().zip(hsizes.iter().zip(wsizes.iter())) { 258 | let meshgrid_res = meshgrid( 259 | &[Array1::from_iter(0..*wsize), Array1::from_iter(0..*hsize)], 260 | Indexing::Xy, 261 | ); 262 | let xv = meshgrid_res[0].to_owned(); 263 | let yv = meshgrid_res[1].to_owned(); 264 | 265 | let grid = stack![Axis(2), xv, yv] 266 | .to_shape((1, (hsize * wsize) as usize, 2)) 267 | .unwrap() 268 | .to_owned(); 269 | 270 | let shape_1 = &grid.shape()[0..2]; 271 | expanded_strides.push(Array::from_elem((shape_1[0], shape_1[1], 1), stride)); 
272 | 273 | grids.push(grid); 274 | } 275 | 276 | let grids = 277 | ndarray::concatenate(Axis(1), &grids.iter().map(|g| g.view()).collect::>()) 278 | .unwrap(); 279 | let expanded_strides = ndarray::concatenate( 280 | Axis(1), 281 | &expanded_strides 282 | .iter() 283 | .map(|g| g.view()) 284 | .collect::>(), 285 | ) 286 | .unwrap(); 287 | 288 | let s1 = (shaped_output.slice(s![.., .., 0..2]).to_owned() + grids.mapv(|e| e as f32)) 289 | * expanded_strides.mapv(|e| *e as f32); 290 | let s2 = (shaped_output 291 | .slice(s![.., .., 2..4]) 292 | .mapv(|e| e.exp()) 293 | .to_owned()) 294 | * expanded_strides.mapv(|e| *e as f32); 295 | 296 | shaped_output 297 | .slice_mut(s![.., .., 0..2]) 298 | .into_iter() 299 | .zip_eq(s1.into_iter()) 300 | .for_each(|(old, new)| { 301 | *old = new; 302 | }); 303 | 304 | shaped_output 305 | .slice_mut(s![.., .., 2..4]) 306 | .into_iter() 307 | .zip_eq(s2.into_iter()) 308 | .for_each(|(old, new)| { 309 | *old = new; 310 | }); 311 | 312 | Ok(shaped_output) 313 | } 314 | 315 | fn preprocess( 316 | &self, 317 | img: &image::DynamicImage, 318 | ) -> (ArrayBase, Dim<[usize; 4]>>, f32) { 319 | let (img_width, img_height) = (img.width(), img.height()); 320 | 321 | let mut padded_img: ArrayBase, Dim<[usize; 4]>> = Array::ones(( 322 | 1, 323 | 3, 324 | Self::REQUIRED_HEIGHT as usize, 325 | Self::REQUIRED_WIDTH as usize, 326 | )) * 114_f32; 327 | 328 | let r: f64 = f64::min( 329 | Self::REQUIRED_HEIGHT as f64 / img_height as f64, 330 | Self::REQUIRED_WIDTH as f64 / img_width as f64, 331 | ); 332 | 333 | let resized_img = img.resize_exact( 334 | (img_width as f64 * r) as u32, 335 | (img_height as f64 * r) as u32, 336 | imageops::FilterType::Triangle, 337 | ); 338 | 339 | for pixel in resized_img.into_rgba8().enumerate_pixels() { 340 | let x = pixel.0 as _; 341 | let y = pixel.1 as _; 342 | let [r, g, b, _] = pixel.2 .0; 343 | padded_img[[0, 0, y, x]] = r as f32; 344 | padded_img[[0, 1, y, x]] = g as f32; 345 | padded_img[[0, 2, y, x]] = b as f32; 
346 | } 347 | 348 | (padded_img, r as f32) 349 | } 350 | 351 | fn get_label(&self, label_id: i64) -> String { 352 | self.label_map 353 | .iter() 354 | .find(|(l_i, _)| l_i == &label_id) 355 | .unwrap() 356 | .1 357 | .clone() 358 | } 359 | } 360 | 361 | fn multiclass_nms_class_agnostic( 362 | boxes: &Array>, 363 | scores: &Array>, 364 | nms_thr: f32, 365 | score_thr: f32, 366 | ) -> Array2 { 367 | let cls_inds = Array1::from_iter(scores.axis_iter(Axis(0)).map(|e| { 368 | let (max_i, _max) = e.iter().enumerate().fold((0_usize, 0_f32), |acc, (i, e)| { 369 | let (max_i, max) = acc; 370 | if *e > max { 371 | (i, *e) 372 | } else { 373 | (max_i, max) 374 | } 375 | }); 376 | max_i 377 | })); 378 | 379 | let cls_scores = Array1::from_iter( 380 | scores 381 | .axis_iter(Axis(0)) 382 | .zip_eq(cls_inds.iter()) 383 | .map(|(e, i)| e[*i]), 384 | ); 385 | 386 | let valid_score_mask = cls_scores.mapv(|s| s > score_thr); 387 | let valid_scores = Array1::from_iter( 388 | cls_scores 389 | .iter() 390 | .zip_eq(valid_score_mask.iter()) 391 | .filter(|(_, b)| **b) 392 | .map(|(s, _)| *s), 393 | ); 394 | 395 | let valid_boxes: Array2 = to_array2( 396 | &boxes 397 | .outer_iter() 398 | .zip_eq(valid_score_mask.iter()) 399 | .filter(|(_, b)| **b) 400 | .map(|(s, _)| s.to_owned()) 401 | .collect::>(), 402 | ) 403 | .unwrap(); 404 | 405 | let valid_cls_inds = Array1::from_iter( 406 | cls_inds 407 | .iter() 408 | .zip_eq(valid_score_mask.iter()) 409 | .filter(|(_, b)| **b) 410 | .map(|(s, _)| s) 411 | .collect::>(), 412 | ); 413 | 414 | let keep = nms(&valid_boxes.to_owned(), &valid_scores, nms_thr); 415 | 416 | let valid_boxes_vec: Vec<_> = valid_boxes.outer_iter().collect(); 417 | let valid_boxes_kept = to_array2( 418 | &keep 419 | .iter() 420 | .map(|i| valid_boxes_vec[*i]) 421 | .map(|e| e.to_owned()) 422 | .collect::>(), 423 | ) 424 | .unwrap(); 425 | 426 | let valid_scores_vec: Vec<_> = valid_scores.into_iter().collect(); 427 | let valid_scores_kept = to_array2( 428 | &keep 429 | 
.iter() 430 | .map(|i| valid_scores_vec[*i]) 431 | .map(|e| Array1::from_elem(1, e)) 432 | .collect::>(), 433 | ) 434 | .unwrap(); 435 | 436 | let valid_cls_inds_vec: Vec<_> = valid_cls_inds.into_iter().collect(); 437 | let valid_cls_inds_kept = to_array2( 438 | &keep 439 | .iter() 440 | .map(|i| valid_cls_inds_vec[*i]) 441 | .map(|e| Array1::from_elem(1, e)) 442 | .collect::>(), 443 | ) 444 | .unwrap(); 445 | 446 | let dets = concatenate( 447 | Axis(1), 448 | &[ 449 | valid_boxes_kept.view(), 450 | valid_scores_kept.view(), 451 | valid_cls_inds_kept.mapv(|e| *e as f32).view(), 452 | ], 453 | ) 454 | .unwrap(); 455 | 456 | return dets; 457 | } 458 | 459 | fn nms( 460 | boxes: &Array>, 461 | scores: &Array>, 462 | nms_thr: f32, 463 | ) -> Vec { 464 | let x1 = boxes.slice(s![.., 0]); 465 | let y1 = boxes.slice(s![.., 1]); 466 | let x2 = boxes.slice(s![.., 2]); 467 | let y2 = boxes.slice(s![.., 3]); 468 | 469 | let areas = (&x2 - &x1 + 1_f32) * (&y2 - &y1 + 1_f32); 470 | let mut order = { 471 | let mut o = utils::argsort_by(&scores, |a, b| a.partial_cmp(b).unwrap()); 472 | o.reverse(); 473 | o 474 | }; 475 | 476 | let mut keep = vec![]; 477 | 478 | while !order.is_empty() { 479 | let i = order[0]; 480 | keep.push(i); 481 | 482 | let order_sliced = Array1::from_iter(order.iter().skip(1)); 483 | 484 | let xx1 = order_sliced.mapv(|o_i| f32::max(x1[i], x1[*o_i])); 485 | let yy1 = order_sliced.mapv(|o_i| f32::max(y1[i], y1[*o_i])); 486 | let xx2 = order_sliced.mapv(|o_i| f32::min(x2[i], x2[*o_i])); 487 | let yy2 = order_sliced.mapv(|o_i| f32::min(y2[i], y2[*o_i])); 488 | 489 | let w = ((&xx2 - &xx1) + 1_f32).mapv(|v| f32::max(0.0, v)); 490 | let h = ((&yy2 - &yy1) + 1_f32).mapv(|v| f32::max(0.0, v)); 491 | let inter = w * h; 492 | let ovr = &inter / (areas[i] + order_sliced.mapv(|e| areas[*e]) - &inter); 493 | 494 | let inds = Array1::from_iter( 495 | ovr.iter() 496 | .map(|e| *e <= nms_thr) 497 | .enumerate() 498 | .filter(|(_, p)| *p) 499 | .map(|(i, _)| i), 500 | ); 501 
| 502 | drop(order_sliced); 503 | 504 | order = inds.into_iter().map(|i| order[i + 1]).collect(); 505 | } 506 | 507 | return keep; 508 | } 509 | 510 | fn to_array2(source: &[Array1]) -> Result, impl std::error::Error> { 511 | let width = source.len(); 512 | let flattened: Array1 = source.into_iter().flat_map(|row| row.to_vec()).collect(); 513 | let height = if width == 0 { 514 | flattened.len() 515 | } else { 516 | flattened.len() / width 517 | }; 518 | flattened.into_shape((width, height)) 519 | } 520 | 521 | /** [x1, y1, x2, y2, prob, class_id] */ 522 | fn extract_bbox_etc(v: &Vec) -> [f32; 6] { 523 | [v[0], v[1], v[2], v[3], v[4], v[5]] 524 | } 525 | 526 | // from: https://github.com/jreniel/meshgridrs (licensed under MIT) 527 | #[derive(PartialEq)] 528 | pub(crate) enum Indexing { 529 | Xy, 530 | Ij, 531 | } 532 | // from: https://github.com/jreniel/meshgridrs (licensed under MIT) 533 | pub(crate) fn meshgrid( 534 | xi: &[Array1], 535 | indexing: Indexing, 536 | ) -> Vec, Dim>> 537 | where 538 | T: Copy, 539 | { 540 | let ndim = xi.len(); 541 | let product = xi.iter().map(|x| x.iter()).multi_cartesian_product(); 542 | 543 | let mut grids: Vec> = Vec::with_capacity(ndim); 544 | 545 | for (dim_index, _) in xi.iter().enumerate() { 546 | // Generate a flat vector with the correct repeated pattern 547 | let values: Vec = product.clone().map(|p| *p[dim_index]).collect(); 548 | 549 | let mut grid_shape: Vec = vec![1; ndim]; 550 | grid_shape[dim_index] = xi[dim_index].len(); 551 | 552 | // Determine the correct repetition for each dimension 553 | for (j, len) in xi.iter().map(|x| x.len()).enumerate() { 554 | if j != dim_index { 555 | grid_shape[j] = len; 556 | } 557 | } 558 | 559 | let grid = Array::from_shape_vec(IxDyn(&grid_shape), values).unwrap(); 560 | grids.push(grid); 561 | } 562 | 563 | // Swap axes for "xy" indexing 564 | if matches!(indexing, Indexing::Xy) && ndim > 1 { 565 | for grid in &mut grids { 566 | grid.swap_axes(0, 1); 567 | } 568 | } 569 | 570 | 
grids 571 | } 572 | -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 3 | version = 3 4 | 5 | [[package]] 6 | name = "adler" 7 | version = "1.0.2" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" 10 | 11 | [[package]] 12 | name = "ahash" 13 | version = "0.8.11" 14 | source = "registry+https://github.com/rust-lang/crates.io-index" 15 | checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" 16 | dependencies = [ 17 | "cfg-if", 18 | "once_cell", 19 | "version_check", 20 | "zerocopy", 21 | ] 22 | 23 | [[package]] 24 | name = "aho-corasick" 25 | version = "1.1.3" 26 | source = "registry+https://github.com/rust-lang/crates.io-index" 27 | checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" 28 | dependencies = [ 29 | "memchr", 30 | ] 31 | 32 | [[package]] 33 | name = "aligned-vec" 34 | version = "0.5.0" 35 | source = "registry+https://github.com/rust-lang/crates.io-index" 36 | checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1" 37 | 38 | [[package]] 39 | name = "anyhow" 40 | version = "1.0.86" 41 | source = "registry+https://github.com/rust-lang/crates.io-index" 42 | checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" 43 | 44 | [[package]] 45 | name = "approx" 46 | version = "0.5.1" 47 | source = "registry+https://github.com/rust-lang/crates.io-index" 48 | checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" 49 | dependencies = [ 50 | "num-traits", 51 | ] 52 | 53 | [[package]] 54 | name = "arbitrary" 55 | version = "1.3.2" 56 | source = "registry+https://github.com/rust-lang/crates.io-index" 57 | checksum = 
"7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110" 58 | 59 | [[package]] 60 | name = "arg_enum_proc_macro" 61 | version = "0.3.4" 62 | source = "registry+https://github.com/rust-lang/crates.io-index" 63 | checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" 64 | dependencies = [ 65 | "proc-macro2", 66 | "quote", 67 | "syn 2.0.60", 68 | ] 69 | 70 | [[package]] 71 | name = "arrayvec" 72 | version = "0.7.4" 73 | source = "registry+https://github.com/rust-lang/crates.io-index" 74 | checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" 75 | 76 | [[package]] 77 | name = "autocfg" 78 | version = "1.2.0" 79 | source = "registry+https://github.com/rust-lang/crates.io-index" 80 | checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" 81 | 82 | [[package]] 83 | name = "av-data" 84 | version = "0.4.2" 85 | source = "registry+https://github.com/rust-lang/crates.io-index" 86 | checksum = "d75b98a3525d00f920df9a2d44cc99b9cc5b7dc70d7fbb612cd755270dbe6552" 87 | dependencies = [ 88 | "byte-slice-cast", 89 | "bytes", 90 | "num-derive", 91 | "num-rational", 92 | "num-traits", 93 | "thiserror", 94 | ] 95 | 96 | [[package]] 97 | name = "av1-grain" 98 | version = "0.2.3" 99 | source = "registry+https://github.com/rust-lang/crates.io-index" 100 | checksum = "6678909d8c5d46a42abcf571271e15fdbc0a225e3646cf23762cd415046c78bf" 101 | dependencies = [ 102 | "anyhow", 103 | "arrayvec", 104 | "log", 105 | "nom", 106 | "num-rational", 107 | "v_frame", 108 | ] 109 | 110 | [[package]] 111 | name = "avif-serialize" 112 | version = "0.8.1" 113 | source = "registry+https://github.com/rust-lang/crates.io-index" 114 | checksum = "876c75a42f6364451a033496a14c44bffe41f5f4a8236f697391f11024e596d2" 115 | dependencies = [ 116 | "arrayvec", 117 | ] 118 | 119 | [[package]] 120 | name = "base64" 121 | version = "0.22.0" 122 | source = "registry+https://github.com/rust-lang/crates.io-index" 123 | checksum = 
"9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51" 124 | 125 | [[package]] 126 | name = "bindgen" 127 | version = "0.64.0" 128 | source = "registry+https://github.com/rust-lang/crates.io-index" 129 | checksum = "c4243e6031260db77ede97ad86c27e501d646a27ab57b59a574f725d98ab1fb4" 130 | dependencies = [ 131 | "bitflags 1.3.2", 132 | "cexpr", 133 | "clang-sys", 134 | "lazy_static", 135 | "lazycell", 136 | "log", 137 | "peeking_take_while", 138 | "proc-macro2", 139 | "quote", 140 | "regex", 141 | "rustc-hash", 142 | "shlex", 143 | "syn 1.0.109", 144 | "which", 145 | ] 146 | 147 | [[package]] 148 | name = "bit_field" 149 | version = "0.10.2" 150 | source = "registry+https://github.com/rust-lang/crates.io-index" 151 | checksum = "dc827186963e592360843fb5ba4b973e145841266c1357f7180c43526f2e5b61" 152 | 153 | [[package]] 154 | name = "bitflags" 155 | version = "1.3.2" 156 | source = "registry+https://github.com/rust-lang/crates.io-index" 157 | checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" 158 | 159 | [[package]] 160 | name = "bitflags" 161 | version = "2.5.0" 162 | source = "registry+https://github.com/rust-lang/crates.io-index" 163 | checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" 164 | 165 | [[package]] 166 | name = "bitreader" 167 | version = "0.3.7" 168 | source = "registry+https://github.com/rust-lang/crates.io-index" 169 | checksum = "f10043e4864d975e7f197f993ec4018636ad93946724b2571c4474d51845869b" 170 | dependencies = [ 171 | "cfg-if", 172 | ] 173 | 174 | [[package]] 175 | name = "bitstream-io" 176 | version = "2.3.0" 177 | source = "registry+https://github.com/rust-lang/crates.io-index" 178 | checksum = "7c12d1856e42f0d817a835fe55853957c85c8c8a470114029143d3f12671446e" 179 | 180 | [[package]] 181 | name = "block-buffer" 182 | version = "0.10.4" 183 | source = "registry+https://github.com/rust-lang/crates.io-index" 184 | checksum = 
"3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" 185 | dependencies = [ 186 | "generic-array", 187 | ] 188 | 189 | [[package]] 190 | name = "built" 191 | version = "0.7.3" 192 | source = "registry+https://github.com/rust-lang/crates.io-index" 193 | checksum = "c6a6c0b39c38fd754ac338b00a88066436389c0f029da5d37d1e01091d9b7c17" 194 | 195 | [[package]] 196 | name = "bumpalo" 197 | version = "3.16.0" 198 | source = "registry+https://github.com/rust-lang/crates.io-index" 199 | checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" 200 | 201 | [[package]] 202 | name = "byte-slice-cast" 203 | version = "1.2.2" 204 | source = "registry+https://github.com/rust-lang/crates.io-index" 205 | checksum = "c3ac9f8b63eca6fd385229b3675f6cc0dc5c8a5c8a54a59d4f52ffd670d87b0c" 206 | 207 | [[package]] 208 | name = "bytemuck" 209 | version = "1.15.0" 210 | source = "registry+https://github.com/rust-lang/crates.io-index" 211 | checksum = "5d6d68c57235a3a081186990eca2867354726650f42f7516ca50c28d6281fd15" 212 | 213 | [[package]] 214 | name = "byteorder" 215 | version = "1.5.0" 216 | source = "registry+https://github.com/rust-lang/crates.io-index" 217 | checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" 218 | 219 | [[package]] 220 | name = "byteorder-lite" 221 | version = "0.1.0" 222 | source = "registry+https://github.com/rust-lang/crates.io-index" 223 | checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" 224 | 225 | [[package]] 226 | name = "bytes" 227 | version = "1.6.0" 228 | source = "registry+https://github.com/rust-lang/crates.io-index" 229 | checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" 230 | 231 | [[package]] 232 | name = "cc" 233 | version = "1.0.95" 234 | source = "registry+https://github.com/rust-lang/crates.io-index" 235 | checksum = "d32a725bc159af97c3e629873bb9f88fb8cf8a4867175f76dc987815ea07c83b" 236 | dependencies = [ 237 | "jobserver", 238 | "libc", 
239 | "once_cell", 240 | ] 241 | 242 | [[package]] 243 | name = "cexpr" 244 | version = "0.6.0" 245 | source = "registry+https://github.com/rust-lang/crates.io-index" 246 | checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" 247 | dependencies = [ 248 | "nom", 249 | ] 250 | 251 | [[package]] 252 | name = "cfg-expr" 253 | version = "0.15.8" 254 | source = "registry+https://github.com/rust-lang/crates.io-index" 255 | checksum = "d067ad48b8650848b989a59a86c6c36a995d02d2bf778d45c3c5d57bc2718f02" 256 | dependencies = [ 257 | "smallvec", 258 | "target-lexicon", 259 | ] 260 | 261 | [[package]] 262 | name = "cfg-if" 263 | version = "1.0.0" 264 | source = "registry+https://github.com/rust-lang/crates.io-index" 265 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 266 | 267 | [[package]] 268 | name = "clang-sys" 269 | version = "1.7.0" 270 | source = "registry+https://github.com/rust-lang/crates.io-index" 271 | checksum = "67523a3b4be3ce1989d607a828d036249522dd9c1c8de7f4dd2dae43a37369d1" 272 | dependencies = [ 273 | "glob", 274 | "libc", 275 | "libloading", 276 | ] 277 | 278 | [[package]] 279 | name = "color_quant" 280 | version = "1.1.0" 281 | source = "registry+https://github.com/rust-lang/crates.io-index" 282 | checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" 283 | 284 | [[package]] 285 | name = "console" 286 | version = "0.15.8" 287 | source = "registry+https://github.com/rust-lang/crates.io-index" 288 | checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" 289 | dependencies = [ 290 | "encode_unicode", 291 | "lazy_static", 292 | "libc", 293 | "unicode-width", 294 | "windows-sys 0.52.0", 295 | ] 296 | 297 | [[package]] 298 | name = "core-foundation" 299 | version = "0.9.4" 300 | source = "registry+https://github.com/rust-lang/crates.io-index" 301 | checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" 302 | dependencies = [ 303 | 
"core-foundation-sys", 304 | "libc", 305 | ] 306 | 307 | [[package]] 308 | name = "core-foundation-sys" 309 | version = "0.8.6" 310 | source = "registry+https://github.com/rust-lang/crates.io-index" 311 | checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" 312 | 313 | [[package]] 314 | name = "cpufeatures" 315 | version = "0.2.12" 316 | source = "registry+https://github.com/rust-lang/crates.io-index" 317 | checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" 318 | dependencies = [ 319 | "libc", 320 | ] 321 | 322 | [[package]] 323 | name = "crc32fast" 324 | version = "1.4.0" 325 | source = "registry+https://github.com/rust-lang/crates.io-index" 326 | checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" 327 | dependencies = [ 328 | "cfg-if", 329 | ] 330 | 331 | [[package]] 332 | name = "crossbeam-deque" 333 | version = "0.8.5" 334 | source = "registry+https://github.com/rust-lang/crates.io-index" 335 | checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" 336 | dependencies = [ 337 | "crossbeam-epoch", 338 | "crossbeam-utils", 339 | ] 340 | 341 | [[package]] 342 | name = "crossbeam-epoch" 343 | version = "0.9.18" 344 | source = "registry+https://github.com/rust-lang/crates.io-index" 345 | checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" 346 | dependencies = [ 347 | "crossbeam-utils", 348 | ] 349 | 350 | [[package]] 351 | name = "crossbeam-utils" 352 | version = "0.8.20" 353 | source = "registry+https://github.com/rust-lang/crates.io-index" 354 | checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" 355 | 356 | [[package]] 357 | name = "crunchy" 358 | version = "0.2.2" 359 | source = "registry+https://github.com/rust-lang/crates.io-index" 360 | checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" 361 | 362 | [[package]] 363 | name = "crypto-common" 364 | version = "0.1.6" 365 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 366 | checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" 367 | dependencies = [ 368 | "generic-array", 369 | "typenum", 370 | ] 371 | 372 | [[package]] 373 | name = "csv" 374 | version = "1.3.0" 375 | source = "registry+https://github.com/rust-lang/crates.io-index" 376 | checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" 377 | dependencies = [ 378 | "csv-core", 379 | "itoa", 380 | "ryu", 381 | "serde", 382 | ] 383 | 384 | [[package]] 385 | name = "csv-core" 386 | version = "0.1.11" 387 | source = "registry+https://github.com/rust-lang/crates.io-index" 388 | checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" 389 | dependencies = [ 390 | "memchr", 391 | ] 392 | 393 | [[package]] 394 | name = "dav1d" 395 | version = "0.10.3" 396 | source = "registry+https://github.com/rust-lang/crates.io-index" 397 | checksum = "0d4b54a40baf633a71c6f0fb49494a7e4ee7bc26f3e727212b6cb915aa1ea1e1" 398 | dependencies = [ 399 | "av-data", 400 | "bitflags 2.5.0", 401 | "dav1d-sys", 402 | "static_assertions", 403 | ] 404 | 405 | [[package]] 406 | name = "dav1d-sys" 407 | version = "0.8.2" 408 | source = "registry+https://github.com/rust-lang/crates.io-index" 409 | checksum = "6ecb1c5e8f4dc438eedc1b534a54672fb0e0a56035dae6b50162787bd2c50e95" 410 | dependencies = [ 411 | "libc", 412 | "system-deps", 413 | ] 414 | 415 | [[package]] 416 | name = "dcv-color-primitives" 417 | version = "0.6.1" 418 | source = "registry+https://github.com/rust-lang/crates.io-index" 419 | checksum = "07ad62edfed069700a5b33af6babd29c498d7e33eb01d96ffa8841ee1841634c" 420 | dependencies = [ 421 | "paste", 422 | "wasm-bindgen", 423 | ] 424 | 425 | [[package]] 426 | name = "digest" 427 | version = "0.10.7" 428 | source = "registry+https://github.com/rust-lang/crates.io-index" 429 | checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" 430 | dependencies = [ 431 | 
"block-buffer", 432 | "crypto-common", 433 | ] 434 | 435 | [[package]] 436 | name = "dirs" 437 | version = "5.0.1" 438 | source = "registry+https://github.com/rust-lang/crates.io-index" 439 | checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" 440 | dependencies = [ 441 | "dirs-sys", 442 | ] 443 | 444 | [[package]] 445 | name = "dirs-sys" 446 | version = "0.4.1" 447 | source = "registry+https://github.com/rust-lang/crates.io-index" 448 | checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" 449 | dependencies = [ 450 | "libc", 451 | "option-ext", 452 | "redox_users", 453 | "windows-sys 0.48.0", 454 | ] 455 | 456 | [[package]] 457 | name = "either" 458 | version = "1.11.0" 459 | source = "registry+https://github.com/rust-lang/crates.io-index" 460 | checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" 461 | 462 | [[package]] 463 | name = "encode_unicode" 464 | version = "0.3.6" 465 | source = "registry+https://github.com/rust-lang/crates.io-index" 466 | checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" 467 | 468 | [[package]] 469 | name = "equivalent" 470 | version = "1.0.1" 471 | source = "registry+https://github.com/rust-lang/crates.io-index" 472 | checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" 473 | 474 | [[package]] 475 | name = "errno" 476 | version = "0.3.8" 477 | source = "registry+https://github.com/rust-lang/crates.io-index" 478 | checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" 479 | dependencies = [ 480 | "libc", 481 | "windows-sys 0.52.0", 482 | ] 483 | 484 | [[package]] 485 | name = "exr" 486 | version = "1.72.0" 487 | source = "registry+https://github.com/rust-lang/crates.io-index" 488 | checksum = "887d93f60543e9a9362ef8a21beedd0a833c5d9610e18c67abe15a5963dcb1a4" 489 | dependencies = [ 490 | "bit_field", 491 | "flume", 492 | "half", 493 | "lebe", 494 | "miniz_oxide", 495 | "rayon-core", 496 | 
"smallvec", 497 | "zune-inflate", 498 | ] 499 | 500 | [[package]] 501 | name = "fallible_collections" 502 | version = "0.4.9" 503 | source = "registry+https://github.com/rust-lang/crates.io-index" 504 | checksum = "a88c69768c0a15262df21899142bc6df9b9b823546d4b4b9a7bc2d6c448ec6fd" 505 | dependencies = [ 506 | "hashbrown 0.13.2", 507 | ] 508 | 509 | [[package]] 510 | name = "fastrand" 511 | version = "2.0.2" 512 | source = "registry+https://github.com/rust-lang/crates.io-index" 513 | checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984" 514 | 515 | [[package]] 516 | name = "fdeflate" 517 | version = "0.3.4" 518 | source = "registry+https://github.com/rust-lang/crates.io-index" 519 | checksum = "4f9bfee30e4dedf0ab8b422f03af778d9612b63f502710fc500a334ebe2de645" 520 | dependencies = [ 521 | "simd-adler32", 522 | ] 523 | 524 | [[package]] 525 | name = "filetime" 526 | version = "0.2.23" 527 | source = "registry+https://github.com/rust-lang/crates.io-index" 528 | checksum = "1ee447700ac8aa0b2f2bd7bc4462ad686ba06baa6727ac149a2d6277f0d240fd" 529 | dependencies = [ 530 | "cfg-if", 531 | "libc", 532 | "redox_syscall", 533 | "windows-sys 0.52.0", 534 | ] 535 | 536 | [[package]] 537 | name = "flate2" 538 | version = "1.0.29" 539 | source = "registry+https://github.com/rust-lang/crates.io-index" 540 | checksum = "4556222738635b7a3417ae6130d8f52201e45a0c4d1a907f0826383adb5f85e7" 541 | dependencies = [ 542 | "crc32fast", 543 | "miniz_oxide", 544 | ] 545 | 546 | [[package]] 547 | name = "flume" 548 | version = "0.11.0" 549 | source = "registry+https://github.com/rust-lang/crates.io-index" 550 | checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" 551 | dependencies = [ 552 | "spin", 553 | ] 554 | 555 | [[package]] 556 | name = "foreign-types" 557 | version = "0.3.2" 558 | source = "registry+https://github.com/rust-lang/crates.io-index" 559 | checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" 560 | 
dependencies = [ 561 | "foreign-types-shared", 562 | ] 563 | 564 | [[package]] 565 | name = "foreign-types-shared" 566 | version = "0.1.1" 567 | source = "registry+https://github.com/rust-lang/crates.io-index" 568 | checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" 569 | 570 | [[package]] 571 | name = "form_urlencoded" 572 | version = "1.2.1" 573 | source = "registry+https://github.com/rust-lang/crates.io-index" 574 | checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" 575 | dependencies = [ 576 | "percent-encoding", 577 | ] 578 | 579 | [[package]] 580 | name = "generic-array" 581 | version = "0.14.7" 582 | source = "registry+https://github.com/rust-lang/crates.io-index" 583 | checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" 584 | dependencies = [ 585 | "typenum", 586 | "version_check", 587 | ] 588 | 589 | [[package]] 590 | name = "geo-types" 591 | version = "0.7.13" 592 | source = "registry+https://github.com/rust-lang/crates.io-index" 593 | checksum = "9ff16065e5720f376fbced200a5ae0f47ace85fd70b7e54269790281353b6d61" 594 | dependencies = [ 595 | "approx", 596 | "num-traits", 597 | "serde", 598 | ] 599 | 600 | [[package]] 601 | name = "getrandom" 602 | version = "0.2.14" 603 | source = "registry+https://github.com/rust-lang/crates.io-index" 604 | checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" 605 | dependencies = [ 606 | "cfg-if", 607 | "libc", 608 | "wasi", 609 | ] 610 | 611 | [[package]] 612 | name = "gif" 613 | version = "0.13.1" 614 | source = "registry+https://github.com/rust-lang/crates.io-index" 615 | checksum = "3fb2d69b19215e18bb912fa30f7ce15846e301408695e44e0ef719f1da9e19f2" 616 | dependencies = [ 617 | "color_quant", 618 | "weezl", 619 | ] 620 | 621 | [[package]] 622 | name = "glob" 623 | version = "0.3.1" 624 | source = "registry+https://github.com/rust-lang/crates.io-index" 625 | checksum = 
"d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" 626 | 627 | [[package]] 628 | name = "half" 629 | version = "2.4.1" 630 | source = "registry+https://github.com/rust-lang/crates.io-index" 631 | checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" 632 | dependencies = [ 633 | "cfg-if", 634 | "crunchy", 635 | ] 636 | 637 | [[package]] 638 | name = "hashbrown" 639 | version = "0.13.2" 640 | source = "registry+https://github.com/rust-lang/crates.io-index" 641 | checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" 642 | dependencies = [ 643 | "ahash", 644 | ] 645 | 646 | [[package]] 647 | name = "hashbrown" 648 | version = "0.14.5" 649 | source = "registry+https://github.com/rust-lang/crates.io-index" 650 | checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" 651 | 652 | [[package]] 653 | name = "heck" 654 | version = "0.5.0" 655 | source = "registry+https://github.com/rust-lang/crates.io-index" 656 | checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" 657 | 658 | [[package]] 659 | name = "hf-hub" 660 | version = "0.3.2" 661 | source = "registry+https://github.com/rust-lang/crates.io-index" 662 | checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732" 663 | dependencies = [ 664 | "dirs", 665 | "indicatif", 666 | "log", 667 | "native-tls", 668 | "rand", 669 | "serde", 670 | "serde_json", 671 | "thiserror", 672 | "ureq", 673 | ] 674 | 675 | [[package]] 676 | name = "hocr-parser" 677 | version = "0.1.0" 678 | source = "registry+https://github.com/rust-lang/crates.io-index" 679 | checksum = "d2d715372e28707019d47f461e741bae6663692be0098adb56b0a777e6da0086" 680 | dependencies = [ 681 | "roxmltree", 682 | "thiserror", 683 | ] 684 | 685 | [[package]] 686 | name = "home" 687 | version = "0.5.9" 688 | source = "registry+https://github.com/rust-lang/crates.io-index" 689 | checksum = 
"e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" 690 | dependencies = [ 691 | "windows-sys 0.52.0", 692 | ] 693 | 694 | [[package]] 695 | name = "idna" 696 | version = "0.5.0" 697 | source = "registry+https://github.com/rust-lang/crates.io-index" 698 | checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" 699 | dependencies = [ 700 | "unicode-bidi", 701 | "unicode-normalization", 702 | ] 703 | 704 | [[package]] 705 | name = "image" 706 | version = "0.25.1" 707 | source = "registry+https://github.com/rust-lang/crates.io-index" 708 | checksum = "fd54d660e773627692c524beaad361aca785a4f9f5730ce91f42aabe5bce3d11" 709 | dependencies = [ 710 | "bytemuck", 711 | "byteorder", 712 | "color_quant", 713 | "dav1d", 714 | "dcv-color-primitives", 715 | "exr", 716 | "gif", 717 | "image-webp", 718 | "mp4parse", 719 | "num-traits", 720 | "png", 721 | "qoi", 722 | "ravif", 723 | "rayon", 724 | "rgb", 725 | "tiff", 726 | "zune-core", 727 | "zune-jpeg", 728 | ] 729 | 730 | [[package]] 731 | name = "image-webp" 732 | version = "0.1.2" 733 | source = "registry+https://github.com/rust-lang/crates.io-index" 734 | checksum = "d730b085583c4d789dfd07fdcf185be59501666a90c97c40162b37e4fdad272d" 735 | dependencies = [ 736 | "byteorder-lite", 737 | "thiserror", 738 | ] 739 | 740 | [[package]] 741 | name = "imgref" 742 | version = "1.10.1" 743 | source = "registry+https://github.com/rust-lang/crates.io-index" 744 | checksum = "44feda355f4159a7c757171a77de25daf6411e217b4cabd03bd6650690468126" 745 | 746 | [[package]] 747 | name = "indexmap" 748 | version = "2.2.6" 749 | source = "registry+https://github.com/rust-lang/crates.io-index" 750 | checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" 751 | dependencies = [ 752 | "equivalent", 753 | "hashbrown 0.14.5", 754 | ] 755 | 756 | [[package]] 757 | name = "indicatif" 758 | version = "0.17.8" 759 | source = "registry+https://github.com/rust-lang/crates.io-index" 760 | checksum = 
"763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3" 761 | dependencies = [ 762 | "console", 763 | "instant", 764 | "number_prefix", 765 | "portable-atomic", 766 | "unicode-width", 767 | ] 768 | 769 | [[package]] 770 | name = "instant" 771 | version = "0.1.12" 772 | source = "registry+https://github.com/rust-lang/crates.io-index" 773 | checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" 774 | dependencies = [ 775 | "cfg-if", 776 | ] 777 | 778 | [[package]] 779 | name = "interpolate_name" 780 | version = "0.2.4" 781 | source = "registry+https://github.com/rust-lang/crates.io-index" 782 | checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" 783 | dependencies = [ 784 | "proc-macro2", 785 | "quote", 786 | "syn 2.0.60", 787 | ] 788 | 789 | [[package]] 790 | name = "itertools" 791 | version = "0.12.1" 792 | source = "registry+https://github.com/rust-lang/crates.io-index" 793 | checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" 794 | dependencies = [ 795 | "either", 796 | ] 797 | 798 | [[package]] 799 | name = "itoa" 800 | version = "1.0.11" 801 | source = "registry+https://github.com/rust-lang/crates.io-index" 802 | checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" 803 | 804 | [[package]] 805 | name = "jobserver" 806 | version = "0.1.31" 807 | source = "registry+https://github.com/rust-lang/crates.io-index" 808 | checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e" 809 | dependencies = [ 810 | "libc", 811 | ] 812 | 813 | [[package]] 814 | name = "jpeg-decoder" 815 | version = "0.3.1" 816 | source = "registry+https://github.com/rust-lang/crates.io-index" 817 | checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0" 818 | 819 | [[package]] 820 | name = "js-sys" 821 | version = "0.3.69" 822 | source = "registry+https://github.com/rust-lang/crates.io-index" 823 | checksum = 
"29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" 824 | dependencies = [ 825 | "wasm-bindgen", 826 | ] 827 | 828 | [[package]] 829 | name = "layoutparser-ort" 830 | version = "0.1.0" 831 | dependencies = [ 832 | "csv", 833 | "geo-types", 834 | "hf-hub", 835 | "hocr-parser", 836 | "image", 837 | "itertools", 838 | "ndarray", 839 | "ort", 840 | "serde", 841 | "tesseract", 842 | "thiserror", 843 | "tracing", 844 | ] 845 | 846 | [[package]] 847 | name = "lazy_static" 848 | version = "1.4.0" 849 | source = "registry+https://github.com/rust-lang/crates.io-index" 850 | checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" 851 | 852 | [[package]] 853 | name = "lazycell" 854 | version = "1.3.0" 855 | source = "registry+https://github.com/rust-lang/crates.io-index" 856 | checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" 857 | 858 | [[package]] 859 | name = "lebe" 860 | version = "0.5.2" 861 | source = "registry+https://github.com/rust-lang/crates.io-index" 862 | checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" 863 | 864 | [[package]] 865 | name = "leptonica-plumbing" 866 | version = "1.4.0" 867 | source = "registry+https://github.com/rust-lang/crates.io-index" 868 | checksum = "cc7a74c43d6f090d39158d233f326f47cd8bba545217595c93662b4e31156f42" 869 | dependencies = [ 870 | "leptonica-sys", 871 | "libc", 872 | "thiserror", 873 | ] 874 | 875 | [[package]] 876 | name = "leptonica-sys" 877 | version = "0.4.7" 878 | source = "registry+https://github.com/rust-lang/crates.io-index" 879 | checksum = "335aadd5fa8d493d62d5596a980ce6ee823a72da45e89bcc45be3841e6d74bff" 880 | dependencies = [ 881 | "bindgen", 882 | "pkg-config", 883 | "vcpkg", 884 | ] 885 | 886 | [[package]] 887 | name = "libc" 888 | version = "0.2.153" 889 | source = "registry+https://github.com/rust-lang/crates.io-index" 890 | checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" 891 | 892 | 
[[package]] 893 | name = "libfuzzer-sys" 894 | version = "0.4.7" 895 | source = "registry+https://github.com/rust-lang/crates.io-index" 896 | checksum = "a96cfd5557eb82f2b83fed4955246c988d331975a002961b07c81584d107e7f7" 897 | dependencies = [ 898 | "arbitrary", 899 | "cc", 900 | "once_cell", 901 | ] 902 | 903 | [[package]] 904 | name = "libloading" 905 | version = "0.8.3" 906 | source = "registry+https://github.com/rust-lang/crates.io-index" 907 | checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" 908 | dependencies = [ 909 | "cfg-if", 910 | "windows-targets 0.52.5", 911 | ] 912 | 913 | [[package]] 914 | name = "libm" 915 | version = "0.2.8" 916 | source = "registry+https://github.com/rust-lang/crates.io-index" 917 | checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" 918 | 919 | [[package]] 920 | name = "libredox" 921 | version = "0.1.3" 922 | source = "registry+https://github.com/rust-lang/crates.io-index" 923 | checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" 924 | dependencies = [ 925 | "bitflags 2.5.0", 926 | "libc", 927 | ] 928 | 929 | [[package]] 930 | name = "linux-raw-sys" 931 | version = "0.4.13" 932 | source = "registry+https://github.com/rust-lang/crates.io-index" 933 | checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" 934 | 935 | [[package]] 936 | name = "lock_api" 937 | version = "0.4.12" 938 | source = "registry+https://github.com/rust-lang/crates.io-index" 939 | checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" 940 | dependencies = [ 941 | "autocfg", 942 | "scopeguard", 943 | ] 944 | 945 | [[package]] 946 | name = "log" 947 | version = "0.4.21" 948 | source = "registry+https://github.com/rust-lang/crates.io-index" 949 | checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" 950 | 951 | [[package]] 952 | name = "loop9" 953 | version = "0.1.5" 954 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 955 | checksum = "0fae87c125b03c1d2c0150c90365d7d6bcc53fb73a9acaef207d2d065860f062" 956 | dependencies = [ 957 | "imgref", 958 | ] 959 | 960 | [[package]] 961 | name = "matrixmultiply" 962 | version = "0.3.8" 963 | source = "registry+https://github.com/rust-lang/crates.io-index" 964 | checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" 965 | dependencies = [ 966 | "autocfg", 967 | "rawpointer", 968 | ] 969 | 970 | [[package]] 971 | name = "maybe-rayon" 972 | version = "0.1.1" 973 | source = "registry+https://github.com/rust-lang/crates.io-index" 974 | checksum = "8ea1f30cedd69f0a2954655f7188c6a834246d2bcf1e315e2ac40c4b24dc9519" 975 | dependencies = [ 976 | "cfg-if", 977 | "rayon", 978 | ] 979 | 980 | [[package]] 981 | name = "memchr" 982 | version = "2.7.2" 983 | source = "registry+https://github.com/rust-lang/crates.io-index" 984 | checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" 985 | 986 | [[package]] 987 | name = "minimal-lexical" 988 | version = "0.2.1" 989 | source = "registry+https://github.com/rust-lang/crates.io-index" 990 | checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" 991 | 992 | [[package]] 993 | name = "miniz_oxide" 994 | version = "0.7.2" 995 | source = "registry+https://github.com/rust-lang/crates.io-index" 996 | checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" 997 | dependencies = [ 998 | "adler", 999 | "simd-adler32", 1000 | ] 1001 | 1002 | [[package]] 1003 | name = "mp4parse" 1004 | version = "0.17.0" 1005 | source = "registry+https://github.com/rust-lang/crates.io-index" 1006 | checksum = "63a35203d3c6ce92d5251c77520acb2e57108c88728695aa883f70023624c570" 1007 | dependencies = [ 1008 | "bitreader", 1009 | "byteorder", 1010 | "fallible_collections", 1011 | "log", 1012 | "num-traits", 1013 | "static_assertions", 1014 | ] 1015 | 1016 | [[package]] 1017 | name = "nasm-rs" 1018 | 
version = "0.2.5" 1019 | source = "registry+https://github.com/rust-lang/crates.io-index" 1020 | checksum = "fe4d98d0065f4b1daf164b3eafb11974c94662e5e2396cf03f32d0bb5c17da51" 1021 | dependencies = [ 1022 | "rayon", 1023 | ] 1024 | 1025 | [[package]] 1026 | name = "native-tls" 1027 | version = "0.2.11" 1028 | source = "registry+https://github.com/rust-lang/crates.io-index" 1029 | checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e" 1030 | dependencies = [ 1031 | "lazy_static", 1032 | "libc", 1033 | "log", 1034 | "openssl", 1035 | "openssl-probe", 1036 | "openssl-sys", 1037 | "schannel", 1038 | "security-framework", 1039 | "security-framework-sys", 1040 | "tempfile", 1041 | ] 1042 | 1043 | [[package]] 1044 | name = "ndarray" 1045 | version = "0.15.6" 1046 | source = "registry+https://github.com/rust-lang/crates.io-index" 1047 | checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" 1048 | dependencies = [ 1049 | "matrixmultiply", 1050 | "num-complex", 1051 | "num-integer", 1052 | "num-traits", 1053 | "rawpointer", 1054 | ] 1055 | 1056 | [[package]] 1057 | name = "new_debug_unreachable" 1058 | version = "1.0.6" 1059 | source = "registry+https://github.com/rust-lang/crates.io-index" 1060 | checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" 1061 | 1062 | [[package]] 1063 | name = "nom" 1064 | version = "7.1.3" 1065 | source = "registry+https://github.com/rust-lang/crates.io-index" 1066 | checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" 1067 | dependencies = [ 1068 | "memchr", 1069 | "minimal-lexical", 1070 | ] 1071 | 1072 | [[package]] 1073 | name = "noop_proc_macro" 1074 | version = "0.3.0" 1075 | source = "registry+https://github.com/rust-lang/crates.io-index" 1076 | checksum = "0676bb32a98c1a483ce53e500a81ad9c3d5b3f7c920c28c24e9cb0980d0b5bc8" 1077 | 1078 | [[package]] 1079 | name = "num-bigint" 1080 | version = "0.4.5" 1081 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 1082 | checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" 1083 | dependencies = [ 1084 | "num-integer", 1085 | "num-traits", 1086 | ] 1087 | 1088 | [[package]] 1089 | name = "num-complex" 1090 | version = "0.4.5" 1091 | source = "registry+https://github.com/rust-lang/crates.io-index" 1092 | checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" 1093 | dependencies = [ 1094 | "num-traits", 1095 | ] 1096 | 1097 | [[package]] 1098 | name = "num-derive" 1099 | version = "0.4.2" 1100 | source = "registry+https://github.com/rust-lang/crates.io-index" 1101 | checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" 1102 | dependencies = [ 1103 | "proc-macro2", 1104 | "quote", 1105 | "syn 2.0.60", 1106 | ] 1107 | 1108 | [[package]] 1109 | name = "num-integer" 1110 | version = "0.1.46" 1111 | source = "registry+https://github.com/rust-lang/crates.io-index" 1112 | checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" 1113 | dependencies = [ 1114 | "num-traits", 1115 | ] 1116 | 1117 | [[package]] 1118 | name = "num-rational" 1119 | version = "0.4.2" 1120 | source = "registry+https://github.com/rust-lang/crates.io-index" 1121 | checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" 1122 | dependencies = [ 1123 | "num-bigint", 1124 | "num-integer", 1125 | "num-traits", 1126 | ] 1127 | 1128 | [[package]] 1129 | name = "num-traits" 1130 | version = "0.2.18" 1131 | source = "registry+https://github.com/rust-lang/crates.io-index" 1132 | checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" 1133 | dependencies = [ 1134 | "autocfg", 1135 | "libm", 1136 | ] 1137 | 1138 | [[package]] 1139 | name = "number_prefix" 1140 | version = "0.4.0" 1141 | source = "registry+https://github.com/rust-lang/crates.io-index" 1142 | checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" 
1143 | 1144 | [[package]] 1145 | name = "once_cell" 1146 | version = "1.19.0" 1147 | source = "registry+https://github.com/rust-lang/crates.io-index" 1148 | checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" 1149 | 1150 | [[package]] 1151 | name = "openssl" 1152 | version = "0.10.64" 1153 | source = "registry+https://github.com/rust-lang/crates.io-index" 1154 | checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" 1155 | dependencies = [ 1156 | "bitflags 2.5.0", 1157 | "cfg-if", 1158 | "foreign-types", 1159 | "libc", 1160 | "once_cell", 1161 | "openssl-macros", 1162 | "openssl-sys", 1163 | ] 1164 | 1165 | [[package]] 1166 | name = "openssl-macros" 1167 | version = "0.1.1" 1168 | source = "registry+https://github.com/rust-lang/crates.io-index" 1169 | checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" 1170 | dependencies = [ 1171 | "proc-macro2", 1172 | "quote", 1173 | "syn 2.0.60", 1174 | ] 1175 | 1176 | [[package]] 1177 | name = "openssl-probe" 1178 | version = "0.1.5" 1179 | source = "registry+https://github.com/rust-lang/crates.io-index" 1180 | checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" 1181 | 1182 | [[package]] 1183 | name = "openssl-sys" 1184 | version = "0.9.102" 1185 | source = "registry+https://github.com/rust-lang/crates.io-index" 1186 | checksum = "c597637d56fbc83893a35eb0dd04b2b8e7a50c91e64e9493e398b5df4fb45fa2" 1187 | dependencies = [ 1188 | "cc", 1189 | "libc", 1190 | "pkg-config", 1191 | "vcpkg", 1192 | ] 1193 | 1194 | [[package]] 1195 | name = "option-ext" 1196 | version = "0.2.0" 1197 | source = "registry+https://github.com/rust-lang/crates.io-index" 1198 | checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" 1199 | 1200 | [[package]] 1201 | name = "ort" 1202 | version = "2.0.0-rc.2" 1203 | source = "registry+https://github.com/rust-lang/crates.io-index" 1204 | checksum = 
"0bc80894094c6a875bfac64415ed456fa661081a278a035e22be661305c87e14" 1205 | dependencies = [ 1206 | "half", 1207 | "js-sys", 1208 | "ndarray", 1209 | "ort-sys", 1210 | "thiserror", 1211 | "tracing", 1212 | "web-sys", 1213 | ] 1214 | 1215 | [[package]] 1216 | name = "ort-sys" 1217 | version = "2.0.0-rc.2" 1218 | source = "registry+https://github.com/rust-lang/crates.io-index" 1219 | checksum = "b3d9c1373fc813d3f024d394f621f4c6dde0734c79b1c17113c3bb5bf0084bbe" 1220 | dependencies = [ 1221 | "flate2", 1222 | "sha2", 1223 | "tar", 1224 | "ureq", 1225 | ] 1226 | 1227 | [[package]] 1228 | name = "paste" 1229 | version = "1.0.15" 1230 | source = "registry+https://github.com/rust-lang/crates.io-index" 1231 | checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" 1232 | 1233 | [[package]] 1234 | name = "peeking_take_while" 1235 | version = "0.1.2" 1236 | source = "registry+https://github.com/rust-lang/crates.io-index" 1237 | checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" 1238 | 1239 | [[package]] 1240 | name = "percent-encoding" 1241 | version = "2.3.1" 1242 | source = "registry+https://github.com/rust-lang/crates.io-index" 1243 | checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" 1244 | 1245 | [[package]] 1246 | name = "pin-project-lite" 1247 | version = "0.2.14" 1248 | source = "registry+https://github.com/rust-lang/crates.io-index" 1249 | checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" 1250 | 1251 | [[package]] 1252 | name = "pkg-config" 1253 | version = "0.3.30" 1254 | source = "registry+https://github.com/rust-lang/crates.io-index" 1255 | checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" 1256 | 1257 | [[package]] 1258 | name = "png" 1259 | version = "0.17.13" 1260 | source = "registry+https://github.com/rust-lang/crates.io-index" 1261 | checksum = "06e4b0d3d1312775e782c86c91a111aa1f910cbb65e1337f9975b5f9a554b5e1" 1262 | dependencies = 
[ 1263 | "bitflags 1.3.2", 1264 | "crc32fast", 1265 | "fdeflate", 1266 | "flate2", 1267 | "miniz_oxide", 1268 | ] 1269 | 1270 | [[package]] 1271 | name = "portable-atomic" 1272 | version = "1.6.0" 1273 | source = "registry+https://github.com/rust-lang/crates.io-index" 1274 | checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" 1275 | 1276 | [[package]] 1277 | name = "ppv-lite86" 1278 | version = "0.2.17" 1279 | source = "registry+https://github.com/rust-lang/crates.io-index" 1280 | checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" 1281 | 1282 | [[package]] 1283 | name = "proc-macro2" 1284 | version = "1.0.81" 1285 | source = "registry+https://github.com/rust-lang/crates.io-index" 1286 | checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" 1287 | dependencies = [ 1288 | "unicode-ident", 1289 | ] 1290 | 1291 | [[package]] 1292 | name = "profiling" 1293 | version = "1.0.15" 1294 | source = "registry+https://github.com/rust-lang/crates.io-index" 1295 | checksum = "43d84d1d7a6ac92673717f9f6d1518374ef257669c24ebc5ac25d5033828be58" 1296 | dependencies = [ 1297 | "profiling-procmacros", 1298 | ] 1299 | 1300 | [[package]] 1301 | name = "profiling-procmacros" 1302 | version = "1.0.15" 1303 | source = "registry+https://github.com/rust-lang/crates.io-index" 1304 | checksum = "8021cf59c8ec9c432cfc2526ac6b8aa508ecaf29cd415f271b8406c1b851c3fd" 1305 | dependencies = [ 1306 | "quote", 1307 | "syn 2.0.60", 1308 | ] 1309 | 1310 | [[package]] 1311 | name = "qoi" 1312 | version = "0.4.1" 1313 | source = "registry+https://github.com/rust-lang/crates.io-index" 1314 | checksum = "7f6d64c71eb498fe9eae14ce4ec935c555749aef511cca85b5568910d6e48001" 1315 | dependencies = [ 1316 | "bytemuck", 1317 | ] 1318 | 1319 | [[package]] 1320 | name = "quick-error" 1321 | version = "2.0.1" 1322 | source = "registry+https://github.com/rust-lang/crates.io-index" 1323 | checksum = 
"a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" 1324 | 1325 | [[package]] 1326 | name = "quote" 1327 | version = "1.0.36" 1328 | source = "registry+https://github.com/rust-lang/crates.io-index" 1329 | checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" 1330 | dependencies = [ 1331 | "proc-macro2", 1332 | ] 1333 | 1334 | [[package]] 1335 | name = "rand" 1336 | version = "0.8.5" 1337 | source = "registry+https://github.com/rust-lang/crates.io-index" 1338 | checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" 1339 | dependencies = [ 1340 | "libc", 1341 | "rand_chacha", 1342 | "rand_core", 1343 | ] 1344 | 1345 | [[package]] 1346 | name = "rand_chacha" 1347 | version = "0.3.1" 1348 | source = "registry+https://github.com/rust-lang/crates.io-index" 1349 | checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" 1350 | dependencies = [ 1351 | "ppv-lite86", 1352 | "rand_core", 1353 | ] 1354 | 1355 | [[package]] 1356 | name = "rand_core" 1357 | version = "0.6.4" 1358 | source = "registry+https://github.com/rust-lang/crates.io-index" 1359 | checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" 1360 | dependencies = [ 1361 | "getrandom", 1362 | ] 1363 | 1364 | [[package]] 1365 | name = "rav1e" 1366 | version = "0.7.1" 1367 | source = "registry+https://github.com/rust-lang/crates.io-index" 1368 | checksum = "cd87ce80a7665b1cce111f8a16c1f3929f6547ce91ade6addf4ec86a8dda5ce9" 1369 | dependencies = [ 1370 | "arbitrary", 1371 | "arg_enum_proc_macro", 1372 | "arrayvec", 1373 | "av1-grain", 1374 | "bitstream-io", 1375 | "built", 1376 | "cc", 1377 | "cfg-if", 1378 | "interpolate_name", 1379 | "itertools", 1380 | "libc", 1381 | "libfuzzer-sys", 1382 | "log", 1383 | "maybe-rayon", 1384 | "nasm-rs", 1385 | "new_debug_unreachable", 1386 | "noop_proc_macro", 1387 | "num-derive", 1388 | "num-traits", 1389 | "once_cell", 1390 | "paste", 1391 | "profiling", 1392 | "rand", 1393 | 
"rand_chacha", 1394 | "simd_helpers", 1395 | "system-deps", 1396 | "thiserror", 1397 | "v_frame", 1398 | "wasm-bindgen", 1399 | ] 1400 | 1401 | [[package]] 1402 | name = "ravif" 1403 | version = "0.11.5" 1404 | source = "registry+https://github.com/rust-lang/crates.io-index" 1405 | checksum = "bc13288f5ab39e6d7c9d501759712e6969fcc9734220846fc9ed26cae2cc4234" 1406 | dependencies = [ 1407 | "avif-serialize", 1408 | "imgref", 1409 | "loop9", 1410 | "quick-error", 1411 | "rav1e", 1412 | "rayon", 1413 | "rgb", 1414 | ] 1415 | 1416 | [[package]] 1417 | name = "rawpointer" 1418 | version = "0.2.1" 1419 | source = "registry+https://github.com/rust-lang/crates.io-index" 1420 | checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" 1421 | 1422 | [[package]] 1423 | name = "rayon" 1424 | version = "1.10.0" 1425 | source = "registry+https://github.com/rust-lang/crates.io-index" 1426 | checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" 1427 | dependencies = [ 1428 | "either", 1429 | "rayon-core", 1430 | ] 1431 | 1432 | [[package]] 1433 | name = "rayon-core" 1434 | version = "1.12.1" 1435 | source = "registry+https://github.com/rust-lang/crates.io-index" 1436 | checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" 1437 | dependencies = [ 1438 | "crossbeam-deque", 1439 | "crossbeam-utils", 1440 | ] 1441 | 1442 | [[package]] 1443 | name = "redox_syscall" 1444 | version = "0.4.1" 1445 | source = "registry+https://github.com/rust-lang/crates.io-index" 1446 | checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" 1447 | dependencies = [ 1448 | "bitflags 1.3.2", 1449 | ] 1450 | 1451 | [[package]] 1452 | name = "redox_users" 1453 | version = "0.4.5" 1454 | source = "registry+https://github.com/rust-lang/crates.io-index" 1455 | checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" 1456 | dependencies = [ 1457 | "getrandom", 1458 | "libredox", 1459 | "thiserror", 1460 | 
] 1461 | 1462 | [[package]] 1463 | name = "regex" 1464 | version = "1.10.4" 1465 | source = "registry+https://github.com/rust-lang/crates.io-index" 1466 | checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" 1467 | dependencies = [ 1468 | "aho-corasick", 1469 | "memchr", 1470 | "regex-automata", 1471 | "regex-syntax", 1472 | ] 1473 | 1474 | [[package]] 1475 | name = "regex-automata" 1476 | version = "0.4.6" 1477 | source = "registry+https://github.com/rust-lang/crates.io-index" 1478 | checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" 1479 | dependencies = [ 1480 | "aho-corasick", 1481 | "memchr", 1482 | "regex-syntax", 1483 | ] 1484 | 1485 | [[package]] 1486 | name = "regex-syntax" 1487 | version = "0.8.3" 1488 | source = "registry+https://github.com/rust-lang/crates.io-index" 1489 | checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" 1490 | 1491 | [[package]] 1492 | name = "rgb" 1493 | version = "0.8.37" 1494 | source = "registry+https://github.com/rust-lang/crates.io-index" 1495 | checksum = "05aaa8004b64fd573fc9d002f4e632d51ad4f026c2b5ba95fcb6c2f32c2c47d8" 1496 | dependencies = [ 1497 | "bytemuck", 1498 | ] 1499 | 1500 | [[package]] 1501 | name = "ring" 1502 | version = "0.17.8" 1503 | source = "registry+https://github.com/rust-lang/crates.io-index" 1504 | checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" 1505 | dependencies = [ 1506 | "cc", 1507 | "cfg-if", 1508 | "getrandom", 1509 | "libc", 1510 | "spin", 1511 | "untrusted", 1512 | "windows-sys 0.52.0", 1513 | ] 1514 | 1515 | [[package]] 1516 | name = "roxmltree" 1517 | version = "0.19.0" 1518 | source = "registry+https://github.com/rust-lang/crates.io-index" 1519 | checksum = "3cd14fd5e3b777a7422cca79358c57a8f6e3a703d9ac187448d0daf220c2407f" 1520 | 1521 | [[package]] 1522 | name = "rustc-hash" 1523 | version = "1.1.0" 1524 | source = "registry+https://github.com/rust-lang/crates.io-index" 1525 | checksum 
= "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" 1526 | 1527 | [[package]] 1528 | name = "rustix" 1529 | version = "0.38.34" 1530 | source = "registry+https://github.com/rust-lang/crates.io-index" 1531 | checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" 1532 | dependencies = [ 1533 | "bitflags 2.5.0", 1534 | "errno", 1535 | "libc", 1536 | "linux-raw-sys", 1537 | "windows-sys 0.52.0", 1538 | ] 1539 | 1540 | [[package]] 1541 | name = "rustls" 1542 | version = "0.22.4" 1543 | source = "registry+https://github.com/rust-lang/crates.io-index" 1544 | checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" 1545 | dependencies = [ 1546 | "log", 1547 | "ring", 1548 | "rustls-pki-types", 1549 | "rustls-webpki", 1550 | "subtle", 1551 | "zeroize", 1552 | ] 1553 | 1554 | [[package]] 1555 | name = "rustls-pki-types" 1556 | version = "1.5.0" 1557 | source = "registry+https://github.com/rust-lang/crates.io-index" 1558 | checksum = "beb461507cee2c2ff151784c52762cf4d9ff6a61f3e80968600ed24fa837fa54" 1559 | 1560 | [[package]] 1561 | name = "rustls-webpki" 1562 | version = "0.102.3" 1563 | source = "registry+https://github.com/rust-lang/crates.io-index" 1564 | checksum = "f3bce581c0dd41bce533ce695a1437fa16a7ab5ac3ccfa99fe1a620a7885eabf" 1565 | dependencies = [ 1566 | "ring", 1567 | "rustls-pki-types", 1568 | "untrusted", 1569 | ] 1570 | 1571 | [[package]] 1572 | name = "ryu" 1573 | version = "1.0.17" 1574 | source = "registry+https://github.com/rust-lang/crates.io-index" 1575 | checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" 1576 | 1577 | [[package]] 1578 | name = "schannel" 1579 | version = "0.1.23" 1580 | source = "registry+https://github.com/rust-lang/crates.io-index" 1581 | checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" 1582 | dependencies = [ 1583 | "windows-sys 0.52.0", 1584 | ] 1585 | 1586 | [[package]] 1587 | name = "scopeguard" 1588 | version = 
"1.2.0" 1589 | source = "registry+https://github.com/rust-lang/crates.io-index" 1590 | checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" 1591 | 1592 | [[package]] 1593 | name = "security-framework" 1594 | version = "2.10.0" 1595 | source = "registry+https://github.com/rust-lang/crates.io-index" 1596 | checksum = "770452e37cad93e0a50d5abc3990d2bc351c36d0328f86cefec2f2fb206eaef6" 1597 | dependencies = [ 1598 | "bitflags 1.3.2", 1599 | "core-foundation", 1600 | "core-foundation-sys", 1601 | "libc", 1602 | "security-framework-sys", 1603 | ] 1604 | 1605 | [[package]] 1606 | name = "security-framework-sys" 1607 | version = "2.10.0" 1608 | source = "registry+https://github.com/rust-lang/crates.io-index" 1609 | checksum = "41f3cc463c0ef97e11c3461a9d3787412d30e8e7eb907c79180c4a57bf7c04ef" 1610 | dependencies = [ 1611 | "core-foundation-sys", 1612 | "libc", 1613 | ] 1614 | 1615 | [[package]] 1616 | name = "serde" 1617 | version = "1.0.199" 1618 | source = "registry+https://github.com/rust-lang/crates.io-index" 1619 | checksum = "0c9f6e76df036c77cd94996771fb40db98187f096dd0b9af39c6c6e452ba966a" 1620 | dependencies = [ 1621 | "serde_derive", 1622 | ] 1623 | 1624 | [[package]] 1625 | name = "serde_derive" 1626 | version = "1.0.199" 1627 | source = "registry+https://github.com/rust-lang/crates.io-index" 1628 | checksum = "11bd257a6541e141e42ca6d24ae26f7714887b47e89aa739099104c7e4d3b7fc" 1629 | dependencies = [ 1630 | "proc-macro2", 1631 | "quote", 1632 | "syn 2.0.60", 1633 | ] 1634 | 1635 | [[package]] 1636 | name = "serde_json" 1637 | version = "1.0.116" 1638 | source = "registry+https://github.com/rust-lang/crates.io-index" 1639 | checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813" 1640 | dependencies = [ 1641 | "itoa", 1642 | "ryu", 1643 | "serde", 1644 | ] 1645 | 1646 | [[package]] 1647 | name = "serde_spanned" 1648 | version = "0.6.6" 1649 | source = "registry+https://github.com/rust-lang/crates.io-index" 1650 | checksum 
= "79e674e01f999af37c49f70a6ede167a8a60b2503e56c5599532a65baa5969a0" 1651 | dependencies = [ 1652 | "serde", 1653 | ] 1654 | 1655 | [[package]] 1656 | name = "sha2" 1657 | version = "0.10.8" 1658 | source = "registry+https://github.com/rust-lang/crates.io-index" 1659 | checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" 1660 | dependencies = [ 1661 | "cfg-if", 1662 | "cpufeatures", 1663 | "digest", 1664 | ] 1665 | 1666 | [[package]] 1667 | name = "shlex" 1668 | version = "1.3.0" 1669 | source = "registry+https://github.com/rust-lang/crates.io-index" 1670 | checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" 1671 | 1672 | [[package]] 1673 | name = "simd-adler32" 1674 | version = "0.3.7" 1675 | source = "registry+https://github.com/rust-lang/crates.io-index" 1676 | checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" 1677 | 1678 | [[package]] 1679 | name = "simd_helpers" 1680 | version = "0.1.0" 1681 | source = "registry+https://github.com/rust-lang/crates.io-index" 1682 | checksum = "95890f873bec569a0362c235787f3aca6e1e887302ba4840839bcc6459c42da6" 1683 | dependencies = [ 1684 | "quote", 1685 | ] 1686 | 1687 | [[package]] 1688 | name = "smallvec" 1689 | version = "1.13.2" 1690 | source = "registry+https://github.com/rust-lang/crates.io-index" 1691 | checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" 1692 | 1693 | [[package]] 1694 | name = "spin" 1695 | version = "0.9.8" 1696 | source = "registry+https://github.com/rust-lang/crates.io-index" 1697 | checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" 1698 | dependencies = [ 1699 | "lock_api", 1700 | ] 1701 | 1702 | [[package]] 1703 | name = "static_assertions" 1704 | version = "1.1.0" 1705 | source = "registry+https://github.com/rust-lang/crates.io-index" 1706 | checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" 1707 | 1708 | [[package]] 1709 | name = "subtle" 
1710 | version = "2.5.0" 1711 | source = "registry+https://github.com/rust-lang/crates.io-index" 1712 | checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" 1713 | 1714 | [[package]] 1715 | name = "syn" 1716 | version = "1.0.109" 1717 | source = "registry+https://github.com/rust-lang/crates.io-index" 1718 | checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" 1719 | dependencies = [ 1720 | "proc-macro2", 1721 | "quote", 1722 | "unicode-ident", 1723 | ] 1724 | 1725 | [[package]] 1726 | name = "syn" 1727 | version = "2.0.60" 1728 | source = "registry+https://github.com/rust-lang/crates.io-index" 1729 | checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3" 1730 | dependencies = [ 1731 | "proc-macro2", 1732 | "quote", 1733 | "unicode-ident", 1734 | ] 1735 | 1736 | [[package]] 1737 | name = "system-deps" 1738 | version = "6.2.2" 1739 | source = "registry+https://github.com/rust-lang/crates.io-index" 1740 | checksum = "a3e535eb8dded36d55ec13eddacd30dec501792ff23a0b1682c38601b8cf2349" 1741 | dependencies = [ 1742 | "cfg-expr", 1743 | "heck", 1744 | "pkg-config", 1745 | "toml", 1746 | "version-compare", 1747 | ] 1748 | 1749 | [[package]] 1750 | name = "tar" 1751 | version = "0.4.40" 1752 | source = "registry+https://github.com/rust-lang/crates.io-index" 1753 | checksum = "b16afcea1f22891c49a00c751c7b63b2233284064f11a200fc624137c51e2ddb" 1754 | dependencies = [ 1755 | "filetime", 1756 | "libc", 1757 | "xattr", 1758 | ] 1759 | 1760 | [[package]] 1761 | name = "target-lexicon" 1762 | version = "0.12.14" 1763 | source = "registry+https://github.com/rust-lang/crates.io-index" 1764 | checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" 1765 | 1766 | [[package]] 1767 | name = "tempfile" 1768 | version = "3.10.1" 1769 | source = "registry+https://github.com/rust-lang/crates.io-index" 1770 | checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" 1771 | 
dependencies = [ 1772 | "cfg-if", 1773 | "fastrand", 1774 | "rustix", 1775 | "windows-sys 0.52.0", 1776 | ] 1777 | 1778 | [[package]] 1779 | name = "tesseract" 1780 | version = "0.15.1" 1781 | source = "registry+https://github.com/rust-lang/crates.io-index" 1782 | checksum = "220d5c325aa2fa6656edd8924ad9a91d7ac7b5e998fe0f083a84f7f06ec9fda7" 1783 | dependencies = [ 1784 | "tesseract-plumbing", 1785 | "tesseract-sys", 1786 | "thiserror", 1787 | ] 1788 | 1789 | [[package]] 1790 | name = "tesseract-plumbing" 1791 | version = "0.11.0" 1792 | source = "registry+https://github.com/rust-lang/crates.io-index" 1793 | checksum = "f7fb02c52201d03517af73dd0a146ac62cbd6f0155ad3dc6455d0140d6112191" 1794 | dependencies = [ 1795 | "leptonica-plumbing", 1796 | "tesseract-sys", 1797 | "thiserror", 1798 | ] 1799 | 1800 | [[package]] 1801 | name = "tesseract-sys" 1802 | version = "0.5.15" 1803 | source = "registry+https://github.com/rust-lang/crates.io-index" 1804 | checksum = "bd33f6f216124cfaf0fa86c2c0cdf04da39b6257bd78c5e44fa4fa98c3a5857b" 1805 | dependencies = [ 1806 | "bindgen", 1807 | "leptonica-sys", 1808 | "pkg-config", 1809 | "vcpkg", 1810 | ] 1811 | 1812 | [[package]] 1813 | name = "thiserror" 1814 | version = "1.0.59" 1815 | source = "registry+https://github.com/rust-lang/crates.io-index" 1816 | checksum = "f0126ad08bff79f29fc3ae6a55cc72352056dfff61e3ff8bb7129476d44b23aa" 1817 | dependencies = [ 1818 | "thiserror-impl", 1819 | ] 1820 | 1821 | [[package]] 1822 | name = "thiserror-impl" 1823 | version = "1.0.59" 1824 | source = "registry+https://github.com/rust-lang/crates.io-index" 1825 | checksum = "d1cd413b5d558b4c5bf3680e324a6fa5014e7b7c067a51e69dbdf47eb7148b66" 1826 | dependencies = [ 1827 | "proc-macro2", 1828 | "quote", 1829 | "syn 2.0.60", 1830 | ] 1831 | 1832 | [[package]] 1833 | name = "tiff" 1834 | version = "0.9.1" 1835 | source = "registry+https://github.com/rust-lang/crates.io-index" 1836 | checksum = 
"ba1310fcea54c6a9a4fd1aad794ecc02c31682f6bfbecdf460bf19533eed1e3e" 1837 | dependencies = [ 1838 | "flate2", 1839 | "jpeg-decoder", 1840 | "weezl", 1841 | ] 1842 | 1843 | [[package]] 1844 | name = "tinyvec" 1845 | version = "1.6.0" 1846 | source = "registry+https://github.com/rust-lang/crates.io-index" 1847 | checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" 1848 | dependencies = [ 1849 | "tinyvec_macros", 1850 | ] 1851 | 1852 | [[package]] 1853 | name = "tinyvec_macros" 1854 | version = "0.1.1" 1855 | source = "registry+https://github.com/rust-lang/crates.io-index" 1856 | checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" 1857 | 1858 | [[package]] 1859 | name = "toml" 1860 | version = "0.8.13" 1861 | source = "registry+https://github.com/rust-lang/crates.io-index" 1862 | checksum = "a4e43f8cc456c9704c851ae29c67e17ef65d2c30017c17a9765b89c382dc8bba" 1863 | dependencies = [ 1864 | "serde", 1865 | "serde_spanned", 1866 | "toml_datetime", 1867 | "toml_edit", 1868 | ] 1869 | 1870 | [[package]] 1871 | name = "toml_datetime" 1872 | version = "0.6.6" 1873 | source = "registry+https://github.com/rust-lang/crates.io-index" 1874 | checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf" 1875 | dependencies = [ 1876 | "serde", 1877 | ] 1878 | 1879 | [[package]] 1880 | name = "toml_edit" 1881 | version = "0.22.13" 1882 | source = "registry+https://github.com/rust-lang/crates.io-index" 1883 | checksum = "c127785850e8c20836d49732ae6abfa47616e60bf9d9f57c43c250361a9db96c" 1884 | dependencies = [ 1885 | "indexmap", 1886 | "serde", 1887 | "serde_spanned", 1888 | "toml_datetime", 1889 | "winnow", 1890 | ] 1891 | 1892 | [[package]] 1893 | name = "tracing" 1894 | version = "0.1.40" 1895 | source = "registry+https://github.com/rust-lang/crates.io-index" 1896 | checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" 1897 | dependencies = [ 1898 | "pin-project-lite", 1899 | 
"tracing-attributes", 1900 | "tracing-core", 1901 | ] 1902 | 1903 | [[package]] 1904 | name = "tracing-attributes" 1905 | version = "0.1.27" 1906 | source = "registry+https://github.com/rust-lang/crates.io-index" 1907 | checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" 1908 | dependencies = [ 1909 | "proc-macro2", 1910 | "quote", 1911 | "syn 2.0.60", 1912 | ] 1913 | 1914 | [[package]] 1915 | name = "tracing-core" 1916 | version = "0.1.32" 1917 | source = "registry+https://github.com/rust-lang/crates.io-index" 1918 | checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" 1919 | dependencies = [ 1920 | "once_cell", 1921 | ] 1922 | 1923 | [[package]] 1924 | name = "typenum" 1925 | version = "1.17.0" 1926 | source = "registry+https://github.com/rust-lang/crates.io-index" 1927 | checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" 1928 | 1929 | [[package]] 1930 | name = "unicode-bidi" 1931 | version = "0.3.15" 1932 | source = "registry+https://github.com/rust-lang/crates.io-index" 1933 | checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" 1934 | 1935 | [[package]] 1936 | name = "unicode-ident" 1937 | version = "1.0.12" 1938 | source = "registry+https://github.com/rust-lang/crates.io-index" 1939 | checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" 1940 | 1941 | [[package]] 1942 | name = "unicode-normalization" 1943 | version = "0.1.23" 1944 | source = "registry+https://github.com/rust-lang/crates.io-index" 1945 | checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" 1946 | dependencies = [ 1947 | "tinyvec", 1948 | ] 1949 | 1950 | [[package]] 1951 | name = "unicode-width" 1952 | version = "0.1.12" 1953 | source = "registry+https://github.com/rust-lang/crates.io-index" 1954 | checksum = "68f5e5f3158ecfd4b8ff6fe086db7c8467a2dfdac97fe420f2b7c4aa97af66d6" 1955 | 1956 | [[package]] 1957 | name = "untrusted" 1958 | version = 
"0.9.0" 1959 | source = "registry+https://github.com/rust-lang/crates.io-index" 1960 | checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" 1961 | 1962 | [[package]] 1963 | name = "ureq" 1964 | version = "2.9.7" 1965 | source = "registry+https://github.com/rust-lang/crates.io-index" 1966 | checksum = "d11a831e3c0b56e438a28308e7c810799e3c118417f342d30ecec080105395cd" 1967 | dependencies = [ 1968 | "base64", 1969 | "flate2", 1970 | "log", 1971 | "native-tls", 1972 | "once_cell", 1973 | "rustls", 1974 | "rustls-pki-types", 1975 | "rustls-webpki", 1976 | "serde", 1977 | "serde_json", 1978 | "url", 1979 | "webpki-roots", 1980 | ] 1981 | 1982 | [[package]] 1983 | name = "url" 1984 | version = "2.5.0" 1985 | source = "registry+https://github.com/rust-lang/crates.io-index" 1986 | checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" 1987 | dependencies = [ 1988 | "form_urlencoded", 1989 | "idna", 1990 | "percent-encoding", 1991 | ] 1992 | 1993 | [[package]] 1994 | name = "v_frame" 1995 | version = "0.3.8" 1996 | source = "registry+https://github.com/rust-lang/crates.io-index" 1997 | checksum = "d6f32aaa24bacd11e488aa9ba66369c7cd514885742c9fe08cfe85884db3e92b" 1998 | dependencies = [ 1999 | "aligned-vec", 2000 | "num-traits", 2001 | "wasm-bindgen", 2002 | ] 2003 | 2004 | [[package]] 2005 | name = "vcpkg" 2006 | version = "0.2.15" 2007 | source = "registry+https://github.com/rust-lang/crates.io-index" 2008 | checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" 2009 | 2010 | [[package]] 2011 | name = "version-compare" 2012 | version = "0.2.0" 2013 | source = "registry+https://github.com/rust-lang/crates.io-index" 2014 | checksum = "852e951cb7832cb45cb1169900d19760cfa39b82bc0ea9c0e5a14ae88411c98b" 2015 | 2016 | [[package]] 2017 | name = "version_check" 2018 | version = "0.9.4" 2019 | source = "registry+https://github.com/rust-lang/crates.io-index" 2020 | checksum = 
"49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" 2021 | 2022 | [[package]] 2023 | name = "wasi" 2024 | version = "0.11.0+wasi-snapshot-preview1" 2025 | source = "registry+https://github.com/rust-lang/crates.io-index" 2026 | checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" 2027 | 2028 | [[package]] 2029 | name = "wasm-bindgen" 2030 | version = "0.2.92" 2031 | source = "registry+https://github.com/rust-lang/crates.io-index" 2032 | checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" 2033 | dependencies = [ 2034 | "cfg-if", 2035 | "wasm-bindgen-macro", 2036 | ] 2037 | 2038 | [[package]] 2039 | name = "wasm-bindgen-backend" 2040 | version = "0.2.92" 2041 | source = "registry+https://github.com/rust-lang/crates.io-index" 2042 | checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" 2043 | dependencies = [ 2044 | "bumpalo", 2045 | "log", 2046 | "once_cell", 2047 | "proc-macro2", 2048 | "quote", 2049 | "syn 2.0.60", 2050 | "wasm-bindgen-shared", 2051 | ] 2052 | 2053 | [[package]] 2054 | name = "wasm-bindgen-macro" 2055 | version = "0.2.92" 2056 | source = "registry+https://github.com/rust-lang/crates.io-index" 2057 | checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" 2058 | dependencies = [ 2059 | "quote", 2060 | "wasm-bindgen-macro-support", 2061 | ] 2062 | 2063 | [[package]] 2064 | name = "wasm-bindgen-macro-support" 2065 | version = "0.2.92" 2066 | source = "registry+https://github.com/rust-lang/crates.io-index" 2067 | checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" 2068 | dependencies = [ 2069 | "proc-macro2", 2070 | "quote", 2071 | "syn 2.0.60", 2072 | "wasm-bindgen-backend", 2073 | "wasm-bindgen-shared", 2074 | ] 2075 | 2076 | [[package]] 2077 | name = "wasm-bindgen-shared" 2078 | version = "0.2.92" 2079 | source = "registry+https://github.com/rust-lang/crates.io-index" 2080 | checksum = 
"af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" 2081 | 2082 | [[package]] 2083 | name = "web-sys" 2084 | version = "0.3.69" 2085 | source = "registry+https://github.com/rust-lang/crates.io-index" 2086 | checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" 2087 | dependencies = [ 2088 | "js-sys", 2089 | "wasm-bindgen", 2090 | ] 2091 | 2092 | [[package]] 2093 | name = "webpki-roots" 2094 | version = "0.26.1" 2095 | source = "registry+https://github.com/rust-lang/crates.io-index" 2096 | checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009" 2097 | dependencies = [ 2098 | "rustls-pki-types", 2099 | ] 2100 | 2101 | [[package]] 2102 | name = "weezl" 2103 | version = "0.1.8" 2104 | source = "registry+https://github.com/rust-lang/crates.io-index" 2105 | checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" 2106 | 2107 | [[package]] 2108 | name = "which" 2109 | version = "4.4.2" 2110 | source = "registry+https://github.com/rust-lang/crates.io-index" 2111 | checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" 2112 | dependencies = [ 2113 | "either", 2114 | "home", 2115 | "once_cell", 2116 | "rustix", 2117 | ] 2118 | 2119 | [[package]] 2120 | name = "windows-sys" 2121 | version = "0.48.0" 2122 | source = "registry+https://github.com/rust-lang/crates.io-index" 2123 | checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" 2124 | dependencies = [ 2125 | "windows-targets 0.48.5", 2126 | ] 2127 | 2128 | [[package]] 2129 | name = "windows-sys" 2130 | version = "0.52.0" 2131 | source = "registry+https://github.com/rust-lang/crates.io-index" 2132 | checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" 2133 | dependencies = [ 2134 | "windows-targets 0.52.5", 2135 | ] 2136 | 2137 | [[package]] 2138 | name = "windows-targets" 2139 | version = "0.48.5" 2140 | source = "registry+https://github.com/rust-lang/crates.io-index" 2141 
| checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" 2142 | dependencies = [ 2143 | "windows_aarch64_gnullvm 0.48.5", 2144 | "windows_aarch64_msvc 0.48.5", 2145 | "windows_i686_gnu 0.48.5", 2146 | "windows_i686_msvc 0.48.5", 2147 | "windows_x86_64_gnu 0.48.5", 2148 | "windows_x86_64_gnullvm 0.48.5", 2149 | "windows_x86_64_msvc 0.48.5", 2150 | ] 2151 | 2152 | [[package]] 2153 | name = "windows-targets" 2154 | version = "0.52.5" 2155 | source = "registry+https://github.com/rust-lang/crates.io-index" 2156 | checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" 2157 | dependencies = [ 2158 | "windows_aarch64_gnullvm 0.52.5", 2159 | "windows_aarch64_msvc 0.52.5", 2160 | "windows_i686_gnu 0.52.5", 2161 | "windows_i686_gnullvm", 2162 | "windows_i686_msvc 0.52.5", 2163 | "windows_x86_64_gnu 0.52.5", 2164 | "windows_x86_64_gnullvm 0.52.5", 2165 | "windows_x86_64_msvc 0.52.5", 2166 | ] 2167 | 2168 | [[package]] 2169 | name = "windows_aarch64_gnullvm" 2170 | version = "0.48.5" 2171 | source = "registry+https://github.com/rust-lang/crates.io-index" 2172 | checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" 2173 | 2174 | [[package]] 2175 | name = "windows_aarch64_gnullvm" 2176 | version = "0.52.5" 2177 | source = "registry+https://github.com/rust-lang/crates.io-index" 2178 | checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" 2179 | 2180 | [[package]] 2181 | name = "windows_aarch64_msvc" 2182 | version = "0.48.5" 2183 | source = "registry+https://github.com/rust-lang/crates.io-index" 2184 | checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" 2185 | 2186 | [[package]] 2187 | name = "windows_aarch64_msvc" 2188 | version = "0.52.5" 2189 | source = "registry+https://github.com/rust-lang/crates.io-index" 2190 | checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" 2191 | 2192 | [[package]] 2193 | name = "windows_i686_gnu" 2194 | 
version = "0.48.5" 2195 | source = "registry+https://github.com/rust-lang/crates.io-index" 2196 | checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" 2197 | 2198 | [[package]] 2199 | name = "windows_i686_gnu" 2200 | version = "0.52.5" 2201 | source = "registry+https://github.com/rust-lang/crates.io-index" 2202 | checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" 2203 | 2204 | [[package]] 2205 | name = "windows_i686_gnullvm" 2206 | version = "0.52.5" 2207 | source = "registry+https://github.com/rust-lang/crates.io-index" 2208 | checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" 2209 | 2210 | [[package]] 2211 | name = "windows_i686_msvc" 2212 | version = "0.48.5" 2213 | source = "registry+https://github.com/rust-lang/crates.io-index" 2214 | checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" 2215 | 2216 | [[package]] 2217 | name = "windows_i686_msvc" 2218 | version = "0.52.5" 2219 | source = "registry+https://github.com/rust-lang/crates.io-index" 2220 | checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" 2221 | 2222 | [[package]] 2223 | name = "windows_x86_64_gnu" 2224 | version = "0.48.5" 2225 | source = "registry+https://github.com/rust-lang/crates.io-index" 2226 | checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" 2227 | 2228 | [[package]] 2229 | name = "windows_x86_64_gnu" 2230 | version = "0.52.5" 2231 | source = "registry+https://github.com/rust-lang/crates.io-index" 2232 | checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" 2233 | 2234 | [[package]] 2235 | name = "windows_x86_64_gnullvm" 2236 | version = "0.48.5" 2237 | source = "registry+https://github.com/rust-lang/crates.io-index" 2238 | checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" 2239 | 2240 | [[package]] 2241 | name = "windows_x86_64_gnullvm" 2242 | version = "0.52.5" 2243 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 2244 | checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" 2245 | 2246 | [[package]] 2247 | name = "windows_x86_64_msvc" 2248 | version = "0.48.5" 2249 | source = "registry+https://github.com/rust-lang/crates.io-index" 2250 | checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" 2251 | 2252 | [[package]] 2253 | name = "windows_x86_64_msvc" 2254 | version = "0.52.5" 2255 | source = "registry+https://github.com/rust-lang/crates.io-index" 2256 | checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" 2257 | 2258 | [[package]] 2259 | name = "winnow" 2260 | version = "0.6.9" 2261 | source = "registry+https://github.com/rust-lang/crates.io-index" 2262 | checksum = "86c949fede1d13936a99f14fafd3e76fd642b556dd2ce96287fbe2e0151bfac6" 2263 | dependencies = [ 2264 | "memchr", 2265 | ] 2266 | 2267 | [[package]] 2268 | name = "xattr" 2269 | version = "1.3.1" 2270 | source = "registry+https://github.com/rust-lang/crates.io-index" 2271 | checksum = "8da84f1a25939b27f6820d92aed108f83ff920fdf11a7b19366c27c4cda81d4f" 2272 | dependencies = [ 2273 | "libc", 2274 | "linux-raw-sys", 2275 | "rustix", 2276 | ] 2277 | 2278 | [[package]] 2279 | name = "zerocopy" 2280 | version = "0.7.34" 2281 | source = "registry+https://github.com/rust-lang/crates.io-index" 2282 | checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" 2283 | dependencies = [ 2284 | "zerocopy-derive", 2285 | ] 2286 | 2287 | [[package]] 2288 | name = "zerocopy-derive" 2289 | version = "0.7.34" 2290 | source = "registry+https://github.com/rust-lang/crates.io-index" 2291 | checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" 2292 | dependencies = [ 2293 | "proc-macro2", 2294 | "quote", 2295 | "syn 2.0.60", 2296 | ] 2297 | 2298 | [[package]] 2299 | name = "zeroize" 2300 | version = "1.7.0" 2301 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 2302 | checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" 2303 | 2304 | [[package]] 2305 | name = "zune-core" 2306 | version = "0.4.12" 2307 | source = "registry+https://github.com/rust-lang/crates.io-index" 2308 | checksum = "3f423a2c17029964870cfaabb1f13dfab7d092a62a29a89264f4d36990ca414a" 2309 | 2310 | [[package]] 2311 | name = "zune-inflate" 2312 | version = "0.2.54" 2313 | source = "registry+https://github.com/rust-lang/crates.io-index" 2314 | checksum = "73ab332fe2f6680068f3582b16a24f90ad7096d5d39b974d1c0aff0125116f02" 2315 | dependencies = [ 2316 | "simd-adler32", 2317 | ] 2318 | 2319 | [[package]] 2320 | name = "zune-jpeg" 2321 | version = "0.4.11" 2322 | source = "registry+https://github.com/rust-lang/crates.io-index" 2323 | checksum = "ec866b44a2a1fd6133d363f073ca1b179f438f99e7e5bfb1e33f7181facfe448" 2324 | dependencies = [ 2325 | "zune-core", 2326 | ] 2327 | --------------------------------------------------------------------------------