├── .dockerignore ├── .github ├── dependabot.yml └── workflows │ └── builld.yaml ├── .gitignore ├── Cargo.toml ├── Dockerfile ├── LICENSE ├── README.md └── src ├── bbox.rs ├── bin └── main.rs ├── detection ├── mod.rs └── segformer.rs ├── error.rs ├── hf.rs ├── lib.rs ├── postprocess.rs ├── preprocess.rs └── recognition ├── mbart.rs ├── mod.rs └── swin_transformer.rs /.dockerignore: -------------------------------------------------------------------------------- 1 | .github/ 2 | target/ 3 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "cargo" 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "monthly" 12 | -------------------------------------------------------------------------------- /.github/workflows/builld.yaml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | env: 10 | OPENCV_VERSION: "4.9.0" 11 | 12 | jobs: 13 | build-windows: 14 | runs-on: windows-latest 15 | env: 16 | OPENCV_LINK_LIBS: "opencv_world490" 17 | OPENCV_LINK_PATHS: "C:/tools/opencv/build/x64/vc16/lib,C:/tools/opencv/build/x64/vc15/lib" 18 | OPENCV_INCLUDE_PATHS: "C:/tools/opencv/build/include" 19 | steps: 20 | - uses: actions/checkout@v4 21 | 22 | - name: Setup OpenCV and LLVM 23 | run: choco install opencv llvm 24 | 25 | - name: Setup Rust 26 | uses: dtolnay/rust-toolchain@stable 27 | 28 | - name: Run check on Windows 29 | run: cargo check --features=cli --bins 30 | 31 | build-macos: 32 | runs-on: macos-latest 33 | steps: 34 | - uses: actions/checkout@v4 35 | 36 | - name: Setup OpenCV and LLVM 37 | run: brew install opencv llvm 38 | 39 | - name: Setup Rust 40 | uses: dtolnay/rust-toolchain@stable 41 | 42 | - name: Run check on macOS 43 | run: | 44 | echo "setting DYLD_FALLBACK_LIBRARY_PATH to $(xcode-select --print-path)/Toolchains/XcodeDefault.xctoolchain/usr/lib/" 45 | export DYLD_FALLBACK_LIBRARY_PATH="$(xcode-select --print-path)/Toolchains/XcodeDefault.xctoolchain/usr/lib/" 46 | cargo check --features=cli,metal --bins 47 | 48 | build: 49 | runs-on: ubuntu-latest 50 | steps: 51 | - uses: actions/checkout@v4 52 | 53 | - name: Setup OpenCV for Ubuntu 54 | run: | 55 | sudo apt update -y 56 | sudo apt install -y libopencv-dev clang libclang-dev 57 | 58 | - name: Setup Rust 59 | uses: dtolnay/rust-toolchain@stable 60 | 61 | - name: Run check on Ubuntu 62 | run: cargo check --features=cli --bin surya 63 | 64 | - name: Run fmt check 65 | run: cargo fmt --all --check 66 | 67 | - name: Run clippy 68 | run: cargo clippy --features=cli --bin surya --lib 69 | 70 | - name: Run unit tests 71 | run: cargo test --features=cli --bin surya --lib 72 | 73 | - name: Test run surya 74 | run: cargo run --features=cli --bin surya -- --help 75 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by 
https://www.toptal.com/developers/gitignore/api/rust 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=rust 3 | 4 | ### Rust ### 5 | # Generated by Cargo 6 | # will have compiled files and executables 7 | debug/ 8 | target/ 9 | 10 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 11 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 12 | Cargo.lock 13 | 14 | # These are backup files generated by rustfmt 15 | **/*.rs.bk 16 | 17 | # MSVC Windows builds of rustc generate these, which store debugging information 18 | *.pdb 19 | 20 | # End of https://www.toptal.com/developers/gitignore/api/rust 21 | surya_output/ 22 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "surya" 3 | version = "0.4.0" 4 | edition = "2021" 5 | description = "Surya is a multilingual document OCR toolkit, original implementation in Python and PyTorch" 6 | license = "Apache-2.0" 7 | authors = ["Jiayu Liu "] 8 | repository = "https://github.com/jimexist/surya-rs" 9 | default-run = "surya" 10 | 11 | [dependencies] 12 | anyhow = { version = "1.0.79", optional = true } 13 | candle-core = { version = "0.6.0" } 14 | candle-nn = { version = "0.6.0" } 15 | clap = { version = "4.5.11", features = ["derive"], optional = true } 16 | env_logger = { version = "0.11.0" } 17 | hf-hub = { version = "0.3.2" } 18 | log = { version = "0.4.20" } 19 | opencv = { version = "0.93.1", default-features = false, features = [ 20 | 'imgproc', 21 | 'imgcodecs', 22 | ] } 23 | serde = { version = "1.0.196" } 24 | serde_json = { version = "1.0.112" } 25 | accelerate-src = { version = "0.3.2", optional = true } 26 | intel-mkl-src = { version = "0.8.1", features = [ 27 | "mkl-static-lp64-iomp", 28 | ], optional = true } 29 | thiserror = { version = "1.0.56" } 30 | 31 | [features] 32 | default = ["cli"] 33 | metal = ["candle-core/metal", "candle-nn/metal"] 34 | accelerate = [ 35 | "accelerate-src", 36 | "candle-core/accelerate", 37 | "candle-nn/accelerate", 38 | ] 39 | mkl = ["intel-mkl-src", "candle-core/mkl", "candle-nn/mkl"] 40 | cli = ["clap", "anyhow"] 41 | 42 | [[bin]] 43 | name = "surya" 44 | path = "src/bin/main.rs" 45 | required-features = ["cli"] 46 | 47 | [dev-dependencies] 48 | float-cmp = "0.10.0" 49 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM rust:1.80-slim as builder 2 | 3 | ENV OPEN_CV_VERSION="4.10.0" 4 | 5 | RUN apt-get update && apt-get install -y \ 6 | build-essential \ 7 | clang \ 8 | libclang-dev \ 9 | libssl-dev \ 10 | wget \ 11 | zip \ 12 | cmake 13 | 14 | WORKDIR /usr/src/opencv 15 | 16 | RUN wget -O opencv.zip https://github.com/opencv/opencv/archive/refs/tags/${OPEN_CV_VERSION}.zip && \ 17 | unzip opencv.zip && \ 18 | rm opencv.zip 19 | 20 | RUN wget -O opencv_contrib.zip https://github.com/opencv/opencv_contrib/archive/refs/tags/${OPEN_CV_VERSION}.zip && \ 21 | unzip opencv_contrib.zip && \ 22 | rm opencv_contrib.zip 23 | 24 | WORKDIR /usr/src/opencv/build 25 | 26 | RUN cmake -DCMAKE_BUILD_TYPE=Release \ 27 | -DBUILD_SHARED_LIBS=NO \ 28 | -DCMAKE_INSTALL_PREFIX=/opt/opencv \ 29 | -DBUILD_DOCS=OFF \ 30 | -DBUILD_EXAMPLES=OFF \ 31 | -DBUILD_TESTS=OFF \ 32 | -DBUILD_PERF_TESTS=OFF \ 33 | -DBUILD_ITT=OFF \ 34 | -DBUILD_IPP_IW=OFF \ 35 | 
-DWITH_PNG=OFF \ 36 | -DWITH_JPEG=OFF \ 37 | -DWITH_TIFF=OFF \ 38 | -DWITH_WEBP=OFF \ 39 | -DWITH_OPENJPEG=OFF \ 40 | -DWITH_JASPER=OFF \ 41 | -DWITH_OPENEXR=OFF \ 42 | -DWITH_V4L=OFF \ 43 | -DWITH_CAROTENE=OFF \ 44 | -DBUILD_opencv_java=OFF \ 45 | -DBUILD_opencv_python=OFF \ 46 | -DOPENCV_EXTRA_MODULES_PATH=../opencv_contrib-${OPEN_CV_VERSION}/modules \ 47 | ../opencv-${OPEN_CV_VERSION} 48 | 49 | RUN cmake --build . --target install --config Release --parallel 8 50 | 51 | RUN cmake --install . --prefix /opt/opencv 52 | 53 | WORKDIR /usr/src/surya 54 | 55 | COPY . . 56 | 57 | RUN OPENCV_LINK_LIBS="opencv_imgcodecs,opencv_imgproc,opencv_core" \ 58 | OPENCV_LINK_PATHS="/opt/opencv/lib,/opt/opencv/lib/opencv4/3rdparty,/usr/lib/$(uname -m)-linux-gnu" \ 59 | OPENCV_INCLUDE_PATHS="/opt/opencv/include,/opt/opencv/include/opencv4" \ 60 | OPENSSL_LIB_DIR="/usr/lib/$(uname -m)-linux-gnu" \ 61 | OPENSSL_INCLUDE_DIR="/usr/include/openssl" \ 62 | cargo install --path . --features "cli" 63 | 64 | FROM debian:bookworm-slim 65 | 66 | RUN apt-get update && \ 67 | apt-get install -y libssl-dev && \ 68 | rm -rf /var/lib/apt/lists/* 69 | 70 | WORKDIR /usr/local/bin 71 | 72 | COPY --from=builder /usr/local/cargo/bin/surya /usr/local/bin/surya 73 | 74 | ENTRYPOINT ["surya"] 75 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # surya-rs 2 | 3 | [![Build](https://github.com/Jimexist/surya-rs/actions/workflows/builld.yaml/badge.svg)](https://github.com/Jimexist/surya-rs/actions/workflows/builld.yaml) 4 | [![Crates.io Version](https://img.shields.io/crates/v/surya)](https://crates.io/crates/surya) 5 | 6 | Rust implementation of [surya][surya], a multilingual document OCR toolkit. 
7 | The implementation is based on a modified version of Segformer, [OpenCV][opencv], and
8 | the Donut transformer.
9 | 
10 | Please refer to the original project for more details on the licensing of the weights.
11 | 
12 | ## Roadmap
13 | 
14 | This project is still in development; feel free to star it and check back.
15 | 
16 | - [x] image input pre-processing
17 | - [x] detection - segformer
18 | - [x] detection - weights loading
19 | - [x] detection - heatmap and affinity map
20 | - [x] detection - bboxes
21 | - [x] detection - image splitting and stitching
22 | - [ ] recognition - swin encoder
23 | - [ ] recognition - MoE MBart
24 | - [ ] recognition - donut transformer loading
25 | - [ ] benchmark
26 | - [ ] quantization
27 | 
28 | ## How to build and install
29 | 
30 | Set up the Rust toolchain if you haven't already:
31 | 
32 | ```bash
33 | # visit https://rustup.rs/ for more detailed information
34 | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
35 | ```
36 | 
37 | Install `llvm` and `opencv` (example on macOS):
38 | 
39 | ```bash
40 | brew install llvm opencv
41 | ```
42 | 
43 | Build and install the binary:
44 | 
45 | ```bash
46 | # run this first on an M1 Mac
47 | export DYLD_FALLBACK_LIBRARY_PATH="$(xcode-select --print-path)/usr/lib/"
48 | # run this first on other Macs
49 | export DYLD_FALLBACK_LIBRARY_PATH="$(xcode-select --print-path)/Toolchains/XcodeDefault.xctoolchain/"
50 | # optionally you can include features like accelerate, metal, mkl, etc.
51 | cargo install --path . --features=cli
52 | ```
53 | 
54 | The built binary does _not_ include the weights file itself; it will instead be downloaded via the Hugging Face Hub API. Once downloaded, the weights file is cached in the Hugging Face cache directory.
55 | 
56 | Check `-h` for help:
57 | 
58 | ```text
59 | Surya is a multilingual document OCR toolkit, original implementation in Python and PyTorch
60 | 
61 | Usage: surya [OPTIONS] <IMAGE>
62 | 
63 | Arguments:
64 |   <IMAGE>  path to image
65 | 
66 | Options:
67 |       --detection-batch-size <DETECTION_BATCH_SIZE>
68 |           detection batch size, if not supplied defaults to 2 on CPU and 16 on GPU
69 |       --detection-model-repo <DETECTION_MODEL_REPO>
70 |           detection model's hugging face repo [default: vikp/surya_det]
71 |       --weights-file-name <WEIGHTS_FILE_NAME>
72 |           detection model's weights file name [default: model.safetensors]
73 |       --config-file-name <CONFIG_FILE_NAME>
74 |           detection model's config file name [default: config.json]
75 |       --non-max-suppression-threshold <NON_MAX_SUPPRESSION_THRESHOLD>
76 |           a value between 0.0 and 1.0 to filter low density part of heatmap [default: 0.35]
77 |       --extract-text-threshold <EXTRACT_TEXT_THRESHOLD>
78 |           a value between 0.0 and 1.0 to filter out bbox with low heatmap density [default: 0.6]
79 |       --bbox-area-threshold <BBOX_AREA_THRESHOLD>
80 |           a pixel threshold to filter out small area bbox [default: 10]
81 |       --recognition-batch-size <RECOGNITION_BATCH_SIZE>
82 |           recognition batch size, if not supplied defaults to 8 on CPU and 256 on GPU
83 |       --recognition-model-repo <RECOGNITION_MODEL_REPO>
84 |           recognition model's hugging face repo [default: vikp/surya_rec]
85 |       --output-dir <OUTPUT_DIR>
86 |           output directory, under which the input image will be generating a subdirectory [default: ./surya_output]
87 |       --polygons
88 |           whether to output polygons json file
89 |       --image
90 |           whether to generate bbox image
91 |       --heatmap
92 |           whether to generate heatmap
93 |       --affinity-map
94 |           whether to generate affinity map
95 |       --device <DEVICE_TYPE>
96 |           device type, if not specified will try to use GPU or Metal [possible values: cpu, gpu, metal]
97 |       --verbose
98 |           whether to enable verbose mode
99 |   -h, --help
100 |           Print help
101 |   -V, --version
102 |           Print version
103 | ```
104 | 
105 | You can also use the `SURYA_LOG` environment variable to control the logging level:
106 | 
107 | ```bash
108 | export SURYA_LOG=warn # or debug, info, etc.
109 | ```
110 | 
111 | ## Library
112 | 
113 | This crate is also published as a library for other Rust projects to use.
114 | 
115 | [surya]: https://github.com/VikParuchuri/surya
116 | [opencv]: https://crates.io/crates/opencv

--------------------------------------------------------------------------------
/src/bbox.rs:
--------------------------------------------------------------------------------
1 | use log::debug;
2 | use opencv::core::{
3 |     self, max_mat_f64, min_mat_f64, Mat, Point, Point2f, Rect, Scalar, Size, Vector, CV_32S,
4 | };
5 | use opencv::prelude::*;
6 | use opencv::{imgcodecs, imgproc};
7 | use std::path::Path;
8 | 
9 | #[derive(Debug, Clone)]
10 | pub struct BBox {
11 |     pub polygon: [Point2f; 4],
12 | }
13 | 
14 | impl BBox {
15 |     fn rescale(&self, heatmap_size: Size, image_with_padding_size: Size) -> crate::Result<Self> {
16 |         let (h_scaler, w_scaler) = (
17 |             image_with_padding_size.height as f32 / heatmap_size.height as f32,
18 |             image_with_padding_size.width as f32 / heatmap_size.width as f32,
19 |         );
20 |         let mut polygon = [Point2f::default(); 4];
21 |         for (i, point) in self.polygon.iter().enumerate() {
22 |             let x = point.x * w_scaler;
23 |             let y = point.y * h_scaler;
24 |             polygon[i] = Point2f::new(x, y);
25 |         }
26 |         Ok(Self { polygon })
27 |     }
28 | 
29 |     fn draw_on_image(&self, image: &mut Mat) -> crate::Result<()> {
30 |         let points: Vector<Point> = self
31 |             .polygon
32 |             .iter()
33 |             .map(|point| {
34 |                 let x = point.x as i32;
35 |                 let y = point.y as i32;
36 |                 Point::new(x, y)
37 |             })
38 |             .collect();
39 |         imgproc::polylines(
40 |             image,
41 |             &points,
42 |             true,
43 |             Scalar::new(0., 0., 255., 0.),
44 |             1,
45 |             opencv::imgproc::LINE_8,
46 |             0,
47 |         )?;
48 |         Ok(())
49 |     }
50 | }
51 | 
52 | /// https://docs.rs/opencv/0.88.8/opencv/imgproc/fn.threshold.html
53 | fn image_threshold(mat: &Mat, non_max_suppression_threshold: f64) -> crate::Result<Mat> {
54 |     let mut r = Mat::default();
55 |     let max_val = 1.0;
56 |     imgproc::threshold(
57 |         &mat,
58 |         &mut r,
59 |         non_max_suppression_threshold,
60 |         max_val,
61 |         imgproc::THRESH_BINARY,
62 |     )?;
63 |     let r = min_mat_f64(&r, 1.0)?.to_mat()?;
64 |     let r = max_mat_f64(&r, 0.0)?.to_mat()?;
65 |     Ok(r)
66 | }
67 | 
68 | /// https://docs.rs/opencv/0.88.8/opencv/prelude/trait.MatTraitConst.html#method.convert_to
69 | fn image_f32_to_u8(mat: Mat) -> crate::Result<Mat> {
70 |     let mut r = Mat::default();
71 |     let alpha = 255.0;
72 |     let beta = 0.0;
73 |     mat.convert_to(&mut r, core::CV_8UC1, alpha, beta)?;
74 |     Ok(r)
75 | }
76 | 
77 | /// https://docs.rs/opencv/0.88.8/opencv/imgproc/fn.connected_components.html
78 | fn image_to_connected_components(mat: Mat) -> crate::Result<(Mat, Mat, Mat)> {
79 |     let mut labels: Mat = Default::default();
80 |     let mut stats: Mat = Default::default();
81 |     let mut centroids: Mat = Default::default();
82 |     imgproc::connected_components_with_stats(
83 |         &mat,
84 |         &mut labels,
85 |         &mut stats,
86 |         &mut centroids,
87 |         4,
88 |         CV_32S,
89 |     )?;
90 |     Ok((labels, stats, centroids))
91 | }
92 | 
93 | fn heatmap_label_max(heatmap: &Mat, labels: &Mat, label: i32) -> crate::Result<f64> {
94 |     let mut mask = Mat::default();
95 |     core::compare(labels, &(label as f64), &mut mask, opencv::core::CMP_EQ)?;
96 |     let mut max_value = 0.0;
97 |     core::min_max_loc(heatmap, None, Some(&mut max_value), None, None, &mask)?;
98 |     Ok(max_value)
99 | }
100 | 
101 | fn get_dilation_matrix(segmap: &mut Mat, stats_row: &[i32]) -> crate::Result<Mat> {
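    // connected-component stats columns (OpenCV CC_STAT_*): bounding-box left, top, width, height, and pixel area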
102 |     let (x, y, w, h, area) = (
103 |         stats_row[imgproc::CC_STAT_LEFT as usize],
104 |         stats_row[imgproc::CC_STAT_TOP as usize],
105 |         stats_row[imgproc::CC_STAT_WIDTH as usize],
106 |         stats_row[imgproc::CC_STAT_HEIGHT as usize],
107 |         stats_row[imgproc::CC_STAT_AREA as usize],
108 |     );
109 |     let niter = {
110 |         let niter = (area * w.min(h)) as f64 / (w * h) as f64;
111 |         (niter.sqrt() * 2.0) as i32
112 |     };
113 |     let roi = {
114 |         let sx = (x - niter).max(0);
115 |         let sy = (y - niter).max(0);
116 |         let ex = (x + w + niter + 1).min(segmap.cols());
117 |         let ey = (y + h + niter + 1).min(segmap.rows());
118 |         Rect::new(sx, sy, ex - sx, ey - sy)
119 |     };
120 |     let mut roi = Mat::roi(segmap, roi)?.clone_pointee();
121 |     let kernel = imgproc::get_structuring_element(
122 |         imgproc::MORPH_RECT,
123 |         Size::new(1 + niter, 1 + niter),
124 |         Point::new(-1, -1),
125 |     )?;
126 |     imgproc::dilate(
127 |         segmap,
128 |         &mut roi,
129 |         &kernel,
130 |         Point::new(-1, -1), // default anchor
131 |         1,
132 |         core::BORDER_CONSTANT, // border type
133 |         Scalar::default(),     // border value
134 |     )?;
135 |     Ok(roi)
136 | }
137 | 
138 | fn connected_area_to_bbox(
139 |     labels: &Mat,
140 |     stats_row: &[i32],
141 |     label: i32,
142 | ) -> crate::Result<[Point2f; 4]> {
143 |     let mut segmap = Mat::default();
144 |     core::compare(&labels, &(label as f64), &mut segmap, opencv::core::CMP_EQ)?;
145 | 
146 |     let dilated_roi = get_dilation_matrix(&mut segmap, stats_row)?;
147 |     dilated_roi.copy_to(&mut segmap)?;
148 | 
149 |     let mut non_zero = Mat::default();
150 |     core::find_non_zero(&segmap, &mut non_zero)?;
151 |     let rotated_rect = imgproc::min_area_rect(&non_zero)?;
152 |     let mut points = [Point2f::default(); 4];
153 |     rotated_rect.points(&mut points)?;
154 |     Ok(points)
155 | }
156 | 
157 | pub fn draw_bboxes<P: AsRef<Path>>(
158 |     image: &mut Mat,
159 |     heatmap_size: Size,
160 |     image_with_padding_size: Size,
161 |     bboxes: &[BBox],
162 |     output_file: P,
163 | ) -> crate::Result<()> {
164 |     debug!(
165 |         "image size={:?}, heatmap_size={:?}, image_with_padding_size={:?}",
166 |         image.size()?,
167 |         heatmap_size,
168 |         image_with_padding_size
169 |     );
170 |     for bbox in bboxes {
171 |         bbox.rescale(heatmap_size, image_with_padding_size)?
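            // rescale maps heatmap-space coordinates back to the padded original image before drawing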
172 |             .draw_on_image(image)?;
173 |     }
174 |     let params = Vector::<i32>::new();
175 |     imgcodecs::imwrite(
176 |         output_file.as_ref().as_os_str().to_str().unwrap(),
177 |         image,
178 |         &params,
179 |     )?;
180 |     Ok(())
181 | }
182 | 
183 | /// generate bboxes from the heatmap, rescaled to the original size
184 | pub fn generate_bbox(
185 |     heatmap: &Mat,
186 |     non_max_suppression_threshold: f64,
187 |     extract_text_threshold: f64,
188 |     bbox_area_threshold: i32,
189 | ) -> crate::Result<Vec<BBox>> {
190 |     let labels = image_threshold(heatmap, non_max_suppression_threshold)?;
191 |     let labels = image_f32_to_u8(labels)?;
192 |     let (labels, stats, centroids) = image_to_connected_components(labels)?;
193 |     debug!("labels {:?}", labels);
194 |     debug!("stats {:?}", stats);
195 |     debug!("centroids {:?}", centroids);
196 | 
197 |     debug_assert_eq!(
198 |         centroids.rows(),
199 |         stats.rows(),
200 |         "centroids and stats rows must be equal"
201 |     );
202 |     debug_assert_eq!(5, stats.cols(), "stats must have 5 columns");
203 |     debug_assert_eq!(2, centroids.cols(), "centroids must have 2 columns");
204 | 
205 |     let mut bboxes = Vec::new();
206 |     // 0 is background so skip it
207 |     for label in 1..stats.rows() {
208 |         let stats_row = stats.at_row::<i32>(label)?;
209 |         let area = stats_row[imgproc::CC_STAT_AREA as usize];
210 |         if area < bbox_area_threshold {
211 |             continue;
212 |         }
213 |         let max_value = heatmap_label_max(heatmap, &labels, label)?;
214 |         if max_value < extract_text_threshold {
215 |             continue;
216 |         }
217 |         let polygon = connected_area_to_bbox(&labels, stats_row, label)?;
218 |         bboxes.push(BBox { polygon });
219 |     }
220 |     debug!(
221 |         "bbox filtering, before={}, after={} bboxes",
222 |         stats.rows(),
223 |         bboxes.len()
224 |     );
225 |     Ok(bboxes)
226 | }

--------------------------------------------------------------------------------
/src/bin/main.rs:
--------------------------------------------------------------------------------
1 | #[cfg(feature = "accelerate")]
2 | extern crate accelerate_src;
3 | #[cfg(feature = "mkl")]
4 | extern crate intel_mkl_src;
5 | 
6 | use candle_core::{Device, IndexOp, Module, Tensor};
7 | use clap::{Parser, ValueEnum};
8 | use env_logger::Env;
9 | use log::{debug, info};
10 | use opencv::hub_prelude::MatTraitConst;
11 | use std::fs::File;
12 | use std::io::BufWriter;
13 | use std::io::Write;
14 | use std::path::PathBuf;
15 | use std::time::Instant;
16 | use surya::bbox::{draw_bboxes, generate_bbox};
17 | use surya::detection::SemanticSegmentationModel;
18 | use surya::hf::HfModel;
19 | use surya::hf::HfModelInfo;
20 | use surya::postprocess::save_image;
21 | use surya::preprocess::{image_to_tensor, read_chunked_resized_image, read_image};
22 | use surya::recognition::RecognitionModel;
23 | 
24 | #[derive(Debug, ValueEnum, Clone, Copy)]
25 | enum DeviceType {
26 |     Cpu,
27 |     Gpu,
28 |     #[cfg(feature = "metal")]
29 |     Metal,
30 | }
31 | 
32 | impl TryInto<Device> for DeviceType {
33 |     type Error = candle_core::Error;
34 | 
35 |     fn try_into(self) -> Result<Device, Self::Error> {
36 |         match self {
37 |             Self::Cpu => Ok(Device::Cpu),
38 |             Self::Gpu => Device::new_cuda(0),
39 |             #[cfg(feature = "metal")]
40 |             Self::Metal => Device::new_metal(0),
41 |         }
42 |     }
43 | }
44 | 
45 | #[derive(Parser, Debug)]
46 | #[command(author, version, about, long_about = None)]
47 | struct Cli {
48 |     #[arg(help = "path to image")]
49 |     image: PathBuf,
50 | 
51 |     #[arg(
52 |         long,
53 |         help = "detection batch size, if not supplied defaults to 2 on CPU and 16 on GPU"
54 |     )]
55 |     detection_batch_size: Option<usize>,
56 | 
57 |     #[arg(
58 |         long,
= "vikp/surya_det", 60 | help = "detection model's hugging face repo" 61 | )] 62 | detection_model_repo: String, 63 | 64 | #[arg( 65 | long, 66 | default_value = "model.safetensors", 67 | help = "detection model's weights file name" 68 | )] 69 | detection_weights_file_name: String, 70 | 71 | #[arg( 72 | long, 73 | default_value = "config.json", 74 | help = "detection model's config file name" 75 | )] 76 | detection_config_file_name: String, 77 | 78 | #[arg( 79 | long, 80 | default_value_t = 0.35, 81 | help = "a value between 0.0 and 1.0 to filter low density part of heatmap" 82 | )] 83 | non_max_suppression_threshold: f64, 84 | 85 | #[arg( 86 | long, 87 | default_value_t = 0.6, 88 | help = "a value between 0.0 and 1.0 to filter out bbox with low heatmap density" 89 | )] 90 | extract_text_threshold: f64, 91 | 92 | #[arg( 93 | long, 94 | default_value_t = 10, 95 | help = "a pixel threshold to filter out small area bbox" 96 | )] 97 | bbox_area_threshold: usize, 98 | 99 | #[arg( 100 | long, 101 | help = "recognition batch size, if not supplied defaults to 8 on CPU and 256 on GPU" 102 | )] 103 | recognition_batch_size: Option, 104 | 105 | #[arg( 106 | long, 107 | default_value = "vikp/surya_rec", 108 | help = "recognition model's hugging face repo" 109 | )] 110 | recognition_model_repo: String, 111 | 112 | #[arg( 113 | long, 114 | default_value = "model.safetensors", 115 | help = "recognition model's weights file name" 116 | )] 117 | recognition_weights_file_name: String, 118 | 119 | #[arg( 120 | long, 121 | default_value = "config.json", 122 | help = "recognition model's config file name" 123 | )] 124 | recognition_config_file_name: String, 125 | 126 | #[arg( 127 | long, 128 | default_value = "./surya_output", 129 | help = "output directory, under which the input image will be generating a subdirectory" 130 | )] 131 | output_dir: PathBuf, 132 | 133 | #[arg( 134 | long = "polygons", 135 | default_value_t = true, 136 | help = "whether to output polygons json file" 137 | )] 138 | output_polygons: bool, 139 | 140 | #[arg( 141 | long = "image", 142 | default_value_t = true, 143 | help = "whether to generate bbox image" 144 | )] 145 | generate_bbox_image: bool, 146 | 147 | #[arg( 148 | long = "heatmap", 149 | default_value_t = true, 150 | help = "whether to generate heatmap" 151 | )] 152 | generate_heatmap: bool, 153 | 154 | #[arg( 155 | long = "affinity-map", 156 | default_value_t = true, 157 | help = "whether to generate affinity map" 158 | )] 159 | generate_affinity_map: bool, 160 | 161 | #[arg( 162 | long = "device", 163 | value_enum, 164 | help = "device type, if not specified will try to use GPU or Metal" 165 | )] 166 | device_type: Option, 167 | 168 | #[arg(long, help = "whether to enable verbose mode")] 169 | verbose: bool, 170 | } 171 | 172 | impl Cli { 173 | fn get_detection_model(&self, device: &Device) -> surya::Result { 174 | SemanticSegmentationModel::from_hf( 175 | HfModelInfo { 176 | model_type: "detection", 177 | repo: self.detection_model_repo.clone(), 178 | weights_file: self.detection_weights_file_name.clone(), 179 | config_file: self.detection_config_file_name.clone(), 180 | }, 181 | device, 182 | ) 183 | } 184 | 185 | fn get_recognition_model(&self, device: &Device) -> surya::Result { 186 | RecognitionModel::from_hf( 187 | HfModelInfo { 188 | model_type: "recognition", 189 | repo: self.recognition_model_repo.clone(), 190 | weights_file: self.recognition_weights_file_name.clone(), 191 | config_file: self.recognition_config_file_name.clone(), 192 | }, 193 | device, 194 | ) 195 | 
195 |     }
196 | }
197 | 
198 | fn main() -> surya::Result<()> {
199 |     let args = Cli::parse();
200 |     let env = Env::new().filter_or("SURYA_LOG", if args.verbose { "debug" } else { "info" });
201 |     env_logger::init_from_env(env);
202 | 
203 |     assert!(
204 |         0.0 <= args.non_max_suppression_threshold && args.non_max_suppression_threshold <= 1.0,
205 |         "non-max-suppression-threshold must be between 0.0 and 1.0"
206 |     );
207 |     assert!(
208 |         0.0 <= args.extract_text_threshold && args.extract_text_threshold <= 1.0,
209 |         "extract-text-threshold must be between 0.0 and 1.0"
210 |     );
211 |     assert!(
212 |         args.bbox_area_threshold > 0,
213 |         "bbox-area-threshold must be > 0"
214 |     );
215 | 
216 |     let device = match args.device_type {
217 |         Some(device_type) => device_type.try_into()?,
218 |         None => Device::new_cuda(0)
219 |             .or_else(|_| Device::new_metal(0))
220 |             .unwrap_or(Device::Cpu),
221 |     };
222 | 
223 |     debug!("using device {:?}", device);
224 | 
225 |     let image_chunks = read_chunked_resized_image(&args.image)?;
226 | 
227 |     // join the output dir with the input image's base name
228 |     let output_dir = args.image.file_stem().expect("failed to get file stem");
229 |     let output_dir = args.output_dir.join(output_dir);
230 |     std::fs::DirBuilder::new()
231 |         .recursive(true)
232 |         .create(output_dir.clone())?;
233 |     info!("generating output to {:?}", output_dir);
234 | 
235 |     let detection_model = args.get_detection_model(&device)?;
236 |     // let recognition_model = args.get_recognition_model(&device)?;
237 | 
238 |     let batch_size = args.detection_batch_size.unwrap_or(match device {
239 |         Device::Cpu => 2,
240 |         Device::Cuda(_) | Device::Metal(_) => 16,
241 |     });
242 |     let image_tensors: Vec<Tensor> = image_chunks
243 |         .resized_chunks
244 |         .iter()
245 |         .map(|img| image_to_tensor(img, &device))
246 |         .collect::<surya::Result<Vec<_>>>()?;
247 | 
248 |     let mut heatmaps = Vec::new();
249 |     let mut affinity_maps = Vec::new();
250 |     for batch in image_tensors.chunks(batch_size) {
251 |         let batch_size = batch.len();
252 |         let batch = Tensor::stack(batch, 0)?;
253 |         info!(
254 |             "starting segformer inference with batch size {}...",
255 |             batch_size,
256 |         );
257 |         let now = Instant::now();
258 |         let segmentation = detection_model.forward(&batch)?;
259 |         info!("inference took {:.3}s", now.elapsed().as_secs_f32());
260 |         for i in 0..batch_size {
261 |             let heatmap: Tensor = segmentation.i(i)?.squeeze(0)?.i(0)?;
262 |             let affinity_map: Tensor = segmentation.i(i)?.squeeze(0)?.i(1)?;
263 |             heatmaps.push(heatmap);
264 |             affinity_maps.push(affinity_map);
265 |         }
266 |     }
267 | 
268 |     let heatmap = image_chunks.stitch_image_tensors(heatmaps)?;
269 |     let affinity_map = image_chunks.stitch_image_tensors(affinity_maps)?;
270 | 
271 |     debug!("heatmap {:?}", heatmap);
272 |     debug!("affinity_map {:?}", affinity_map);
273 | 
274 |     let bboxes = generate_bbox(
275 |         &heatmap,
276 |         args.non_max_suppression_threshold,
277 |         args.extract_text_threshold,
278 |         args.bbox_area_threshold as i32,
279 |     )?;
280 | 
281 |     if args.output_polygons {
282 |         let output_file = output_dir.join("polygons.jsonl");
283 |         let mut buf_writer = BufWriter::new(File::create(&output_file)?);
284 |         for bbox in &bboxes {
285 |             let polygons: Vec<(f32, f32)> = bbox
286 |                 .polygon
287 |                 .iter()
288 |                 .map(|p| {
289 |                     let precision = 1.0e3;
290 |                     let x = (p.x * precision).round() / precision;
291 |                     let y = (p.y * precision).round() / precision;
292 |                     (x, y)
293 |                 })
294 |                 .collect();
295 |             serde_json::to_writer(&mut buf_writer, &polygons)?;
296 |             writeln!(&mut buf_writer)?;
297 |         }
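        // flush buffered writes so the JSONL file is complete on disk before logging success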
298 |         buf_writer.flush()?;
299 |         info!("polygons json file {:?} generated", output_file);
300 |     }
301 | 
302 |     if args.generate_bbox_image {
303 |         let mut image = read_image(args.image)?;
304 |         let output_file = output_dir.join("bbox.png");
305 |         draw_bboxes(
306 |             &mut image,
307 |             heatmap.size()?,
308 |             image_chunks.original_size_with_padding,
309 |             &bboxes,
310 |             &output_file,
311 |         )?;
312 |         info!("bbox image {:?} generated", output_file);
313 |     }
314 | 
315 |     if args.generate_heatmap {
316 |         let output_file = output_dir.join("heatmap.png");
317 |         let image = image_chunks.resize_heatmap_to_image(heatmap)?;
318 |         save_image(&image, &output_file)?;
319 |         info!("heatmap image {:?} generated", output_file);
320 |     }
321 | 
322 |     if args.generate_affinity_map {
323 |         let output_file = output_dir.join("affinity_map.png");
324 |         let image = image_chunks.resize_heatmap_to_image(affinity_map)?;
325 |         save_image(&image, &output_file)?;
326 |         info!("affinity map image {:?} generated", output_file);
327 |     }
328 | 
329 |     Ok(())
330 | }

--------------------------------------------------------------------------------
/src/detection/mod.rs:
--------------------------------------------------------------------------------
1 | //! Detection module, consisting of the segformer implementation
2 | 
3 | mod segformer;
4 | 
5 | use crate::error::Result;
6 | use crate::hf::HfModel;
7 | use candle_core::Device;
8 | use candle_nn::VarBuilder;
9 | pub use segformer::Config;
10 | pub use segformer::SemanticSegmentationModel;
11 | use std::path::PathBuf;
12 | 
13 | impl HfModel for SemanticSegmentationModel {
14 |     fn from_hf_files(config: PathBuf, weights: PathBuf, device: &Device) -> Result<Self> {
15 |         let config = serde_json::from_str(&std::fs::read_to_string(config)?)?;
16 |         let vb = unsafe {
17 |             VarBuilder::from_mmaped_safetensors(&[weights], candle_core::DType::F32, device)?
18 |         };
19 |         Self::new(&config, 2, vb).map_err(Into::into)
20 |     }
21 | }

--------------------------------------------------------------------------------
/src/detection/segformer.rs:
--------------------------------------------------------------------------------
1 | //! Segformer implementation
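//! The encoder runs four stages of overlapped patch embedding, efficient (sequence-reduced) self-attention, and Mix-FFN blocks; an all-MLP decode head fuses the per-stage feature maps.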
2 | use candle_core::{Module, ModuleT, Result, Tensor, D};
3 | use candle_nn::{
4 |     conv2d, conv2d_no_bias, layer_norm, linear, Activation, Conv2d, Conv2dConfig, Linear,
5 |     VarBuilder,
6 | };
7 | 
8 | // https://github.com/huggingface/transformers/blob/main/src/transformers/models/segformer/configuration_segformer.py
9 | #[derive(Debug, Clone, PartialEq, serde::Deserialize)]
10 | pub struct Config {
11 |     pub num_channels: usize,
12 |     pub num_encoder_blocks: usize,
13 |     pub depths: Vec<usize>,
14 |     pub sr_ratios: Vec<usize>,
15 |     pub hidden_sizes: Vec<usize>,
16 |     pub patch_sizes: Vec<usize>,
17 |     pub strides: Vec<usize>,
18 |     pub num_attention_heads: Vec<usize>,
19 |     pub mlp_ratios: Vec<usize>,
20 |     pub hidden_act: candle_nn::Activation,
21 |     pub layer_norm_eps: f64,
22 |     pub decoder_layer_hidden_size: usize,
23 |     pub decoder_hidden_size: usize,
24 | }
25 | 
26 | impl Config {
27 |     pub fn new() -> Self {
28 |         Self {
29 |             num_channels: 3,
30 |             num_encoder_blocks: 4,
31 |             depths: vec![3, 4, 9, 3],
32 |             sr_ratios: vec![8, 4, 2, 1],
33 |             hidden_sizes: vec![64, 128, 320, 512],
34 |             patch_sizes: vec![7, 3, 3, 3],
35 |             strides: vec![4, 2, 2, 2],
36 |             num_attention_heads: vec![1, 2, 5, 8],
37 |             mlp_ratios: vec![4, 4, 4, 4],
38 |             hidden_act: candle_nn::Activation::Gelu,
39 |             layer_norm_eps: 1e-6,
40 |             decoder_layer_hidden_size: 192,
41 |             decoder_hidden_size: 768,
42 |         }
43 |     }
44 | }
45 | 
46 | impl Default for Config {
47 |     fn default() -> Self {
48 |         Self::new()
49 |     }
50 | }
51 | 
52 | #[derive(Debug, Clone)]
53 | struct SegformerOverlapPatchEmbeddings {
54 |     projection: Conv2d,
55 |     layer_norm: candle_nn::LayerNorm,
56 | }
57 | 
58 | impl SegformerOverlapPatchEmbeddings {
59 |     fn new(
60 |         config: &Config,
61 |         patch_size: usize,
62 |         stride: usize,
63 |         num_channels: usize,
64 |         hidden_size: usize,
65 |         vb: VarBuilder,
66 |     ) -> Result<Self> {
67 |         let projection = conv2d(
68 |             num_channels,
69 |             hidden_size,
70 |             patch_size,
71 |             Conv2dConfig {
72 |                 stride,
73 |                 padding: patch_size / 2,
74 |                 ..Default::default()
75 |             },
76 |             vb.pp("proj"),
77 |         )?;
78 |         let layer_norm =
79 |             candle_nn::layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm"))?;
80 |         Ok(Self {
81 |             projection,
82 |             layer_norm,
83 |         })
84 |     }
85 | }
86 | 
87 | impl Module for SegformerOverlapPatchEmbeddings {
88 |     fn forward(&self, x: &Tensor) -> Result<Tensor> {
89 |         let embeddings = self.projection.forward(x)?;
90 |         let shape = embeddings.shape();
91 |         // [B, C, H, W] -> [B, H * W, C]
92 |         let embeddings = embeddings.flatten_from(2)?.transpose(1, 2)?;
93 |         let embeddings = self.layer_norm.forward(&embeddings)?;
94 |         // [B, H * W, C] -> [B, C, H, W]
95 |         let embeddings = embeddings.transpose(1, 2)?.reshape(shape)?;
96 |         Ok(embeddings)
97 |     }
98 | }
99 | 
100 | #[derive(Debug, Clone)]
101 | struct SegformerEfficientSelfAttention {
102 |     num_attention_heads: usize,
103 |     attention_head_size: usize,
104 |     query: Linear,
105 |     key: Linear,
106 |     value: Linear,
107 |     sr: Option<Conv2d>,
108 |     layer_norm: Option<candle_nn::LayerNorm>,
109 | }
110 | 
111 | impl SegformerEfficientSelfAttention {
112 |     fn new(
113 |         config: &Config,
114 |         hidden_size: usize,
115 |         num_attention_heads: usize,
116 |         sequence_reduction_ratio: usize,
117 |         vb: VarBuilder,
118 |     ) -> Result<Self> {
119 |         if hidden_size % num_attention_heads != 0 {
120 |             candle_core::bail!(
121 |                 "The hidden size {} is not a multiple of the number of attention heads {}",
122 |                 hidden_size,
123 |                 num_attention_heads,
124 |             );
125 |         }
126 |         let attention_head_size = hidden_size / num_attention_heads;
127 |         let all_head_size = num_attention_heads * attention_head_size;
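        // all_head_size equals hidden_size here; the divisibility check above guarantees an exact split across heads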
128 |         let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
129 |         let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
130 |         let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
131 |         let (sr, layer_norm) = if sequence_reduction_ratio > 1 {
132 |             (
133 |                 Some(conv2d(
134 |                     hidden_size,
135 |                     hidden_size,
136 |                     sequence_reduction_ratio,
137 |                     Conv2dConfig {
138 |                         stride: sequence_reduction_ratio,
139 |                         ..Default::default()
140 |                     },
141 |                     vb.pp("sr"),
142 |                 )?),
143 |                 Some(candle_nn::layer_norm(
144 |                     hidden_size,
145 |                     config.layer_norm_eps,
146 |                     vb.pp("layer_norm"),
147 |                 )?),
148 |             )
149 |         } else {
150 |             (None, None)
151 |         };
152 |         Ok(Self {
153 |             num_attention_heads,
154 |             attention_head_size,
155 |             query,
156 |             key,
157 |             value,
158 |             sr,
159 |             layer_norm,
160 |         })
161 |     }
162 | 
163 |     fn transpose_for_scores(&self, hidden_states: Tensor) -> Result<Tensor> {
164 |         let (batch, seq_length, _) = hidden_states.shape().dims3()?;
165 |         let new_shape = &[
166 |             batch,
167 |             seq_length,
168 |             self.num_attention_heads,
169 |             self.attention_head_size,
170 |         ];
171 |         let hidden_states = hidden_states.reshape(new_shape)?;
172 |         let hidden_states = hidden_states.permute((0, 2, 1, 3))?;
173 |         Ok(hidden_states)
174 |     }
175 | }
176 | 
177 | impl Module for SegformerEfficientSelfAttention {
178 |     fn forward(&self, x: &Tensor) -> Result<Tensor> {
179 |         // [B, C, H, W] -> [B, H * W, C]
180 |         let hidden_states = x.flatten_from(2)?.permute((0, 2, 1))?;
181 |         let query = self
182 |             .transpose_for_scores(self.query.forward(&hidden_states)?)?
183 |             .contiguous()?;
184 |         let hidden_states = if let (Some(sr), Some(layer_norm)) = (&self.sr, &self.layer_norm) {
185 |             let hidden_states = sr.forward(x)?;
186 |             // [B, C, H, W] -> [B, H * W, C]
187 |             let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
188 |             layer_norm.forward(&hidden_states)?
189 |         } else {
190 |             // already [B, H * W, C]
191 |             hidden_states
192 |         };
193 |         // standard self-attention
194 |         let key = self
195 |             .transpose_for_scores(self.key.forward(&hidden_states)?)?
196 |             .contiguous()?;
197 |         let value = self
198 |             .transpose_for_scores(self.value.forward(&hidden_states)?)?
199 |             .contiguous()?;
200 |         let attention_scores =
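            // scaled dot-product attention: softmax(Q K^T / sqrt(d_head)) V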
201 |             (query.matmul(&key.t()?)? / f64::sqrt(self.attention_head_size as f64))?;
202 |         let attention_scores = candle_nn::ops::softmax_last_dim(&attention_scores)?;
203 |         let result = attention_scores.matmul(&value)?;
204 |         let result = result.permute((0, 2, 1, 3))?.contiguous()?;
205 |         result.flatten_from(D::Minus2)
206 |     }
207 | }
208 | 
209 | #[derive(Debug, Clone)]
210 | struct SegformerSelfOutput {
211 |     dense: Linear,
212 | }
213 | 
214 | impl SegformerSelfOutput {
215 |     fn new(hidden_size: usize, vb: VarBuilder) -> Result<Self> {
216 |         let dense = linear(hidden_size, hidden_size, vb.pp("dense"))?;
217 |         Ok(Self { dense })
218 |     }
219 | }
220 | 
221 | impl Module for SegformerSelfOutput {
222 |     fn forward(&self, x: &Tensor) -> Result<Tensor> {
223 |         self.dense.forward(x)
224 |     }
225 | }
226 | 
227 | #[derive(Debug, Clone)]
228 | struct SegformerAttention {
229 |     attention: SegformerEfficientSelfAttention,
230 |     output: SegformerSelfOutput,
231 | }
232 | 
233 | impl SegformerAttention {
234 |     fn new(
235 |         config: &Config,
236 |         hidden_size: usize,
237 |         num_attention_heads: usize,
238 |         sequence_reduction_ratio: usize,
239 |         vb: VarBuilder,
240 |     ) -> Result<Self> {
241 |         let attention = SegformerEfficientSelfAttention::new(
242 |             config,
243 |             hidden_size,
244 |             num_attention_heads,
245 |             sequence_reduction_ratio,
246 |             vb.pp("self"),
247 |         )?;
248 |         let output = SegformerSelfOutput::new(hidden_size, vb.pp("output"))?;
249 |         Ok(Self { attention, output })
250 |     }
251 | }
252 | 
253 | impl Module for SegformerAttention {
254 |     fn forward(&self, x: &Tensor) -> Result<Tensor> {
255 |         let attention_output = self.attention.forward(x)?;
256 |         self.output.forward(&attention_output)
257 |     }
258 | }
259 | 
260 | #[derive(Debug, Clone)]
261 | struct SegformerDWConv {
262 |     dw_conv: Conv2d,
263 | }
264 | 
265 | impl SegformerDWConv {
266 |     fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
267 |         let dw_conv = conv2d(
268 |             dim,
269 |             dim,
270 |             3,
271 |             Conv2dConfig {
272 |                 stride: 1,
273 |                 padding: 1,
274 |                 groups: dim,
275 |                 ..Default::default()
276 |             },
277 |             vb.pp("dwconv"),
278 |         )?;
279 |         Ok(Self { dw_conv })
280 |     }
281 | }
282 | 
283 | impl Module for SegformerDWConv {
284 |     fn forward(&self, x: &Tensor) -> Result<Tensor> {
285 |         self.dw_conv.forward(x)
286 |     }
287 | }
288 | 
289 | #[derive(Debug, Clone)]
290 | struct SegformerMixFFN {
291 |     dense1: Linear,
292 |     dw_conv: SegformerDWConv,
293 |     act: Activation,
294 |     dense2: Linear,
295 | }
296 | 
297 | impl SegformerMixFFN {
298 |     fn new(
299 |         config: &Config,
300 |         in_features: usize,
301 |         hidden_features: usize,
302 |         out_features: usize,
303 |         vb: VarBuilder,
304 |     ) -> Result<Self> {
305 |         let dense1 = linear(in_features, hidden_features, vb.pp("dense1"))?;
306 |         let dw_conv = SegformerDWConv::new(hidden_features, vb.pp("dwconv"))?;
307 |         let act = config.hidden_act;
308 |         let dense2 = linear(hidden_features, out_features, vb.pp("dense2"))?;
309 |         Ok(Self {
310 |             dense1,
311 |             dw_conv,
312 |             act,
313 |             dense2,
314 |         })
315 |     }
316 | }
317 | 
318 | impl Module for SegformerMixFFN {
319 |     fn forward(&self, x: &Tensor) -> Result<Tensor> {
320 |         let (batch, _, height, width) = x.shape().dims4()?;
321 |         let hidden_states = self
322 |             .dense1
323 |             .forward(&x.flatten_from(2)?.permute((0, 2, 1))?)?;
324 |         let channels = hidden_states.dim(2)?;
325 |         let hidden_states = self.dw_conv.forward(
326 |             &hidden_states
327 |                 .permute((0, 2, 1))?
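                // [B, H * W, C] -> [B, C, H, W] so the depthwise conv sees the spatial layout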
328 |                 .reshape((batch, channels, height, width))?,
329 |         )?;
330 |         let hidden_states = self.act.forward(&hidden_states)?;
331 |         let hidden_states = self
332 |             .dense2
333 |             .forward(&hidden_states.flatten_from(2)?.permute((0, 2, 1))?)?;
334 |         let channels = hidden_states.dim(2)?;
335 |         hidden_states
336 |             .permute((0, 2, 1))?
337 |             .reshape((batch, channels, height, width))
338 |     }
339 | }
340 | 
341 | #[derive(Debug, Clone)]
342 | struct SegformerLayer {
343 |     layer_norm_1: candle_nn::LayerNorm,
344 |     attention: SegformerAttention,
345 |     layer_norm_2: candle_nn::LayerNorm,
346 |     mlp: SegformerMixFFN,
347 | }
348 | 
349 | impl SegformerLayer {
350 |     fn new(
351 |         config: &Config,
352 |         hidden_size: usize,
353 |         num_attention_heads: usize,
354 |         sequence_reduction_ratio: usize,
355 |         mlp_ratio: usize,
356 |         vb: VarBuilder,
357 |     ) -> Result<Self> {
358 |         let layer_norm_1 = layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm_1"))?;
359 |         let attention = SegformerAttention::new(
360 |             config,
361 |             hidden_size,
362 |             num_attention_heads,
363 |             sequence_reduction_ratio,
364 |             vb.pp("attention"),
365 |         )?;
366 |         let layer_norm_2 = layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm_2"))?;
367 |         let mlp = SegformerMixFFN::new(
368 |             config,
369 |             hidden_size,
370 |             hidden_size * mlp_ratio,
371 |             hidden_size,
372 |             vb.pp("mlp"),
373 |         )?;
374 |         Ok(Self {
375 |             layer_norm_1,
376 |             attention,
377 |             layer_norm_2,
378 |             mlp,
379 |         })
380 |     }
381 | }
382 | 
383 | impl Module for SegformerLayer {
384 |     fn forward(&self, x: &Tensor) -> Result<Tensor> {
385 |         let shape = x.shape().dims4()?;
386 |         // [B, C, H, W] -> [B, H * W, C]
387 |         let hidden_states = x.flatten_from(2)?.permute((0, 2, 1))?;
388 |         let layer_norm_output = self.layer_norm_1.forward(&hidden_states)?;
389 |         let layer_norm_output = layer_norm_output.permute((0, 2, 1))?.reshape(shape)?;
390 |         // attention takes in [B, C, H, W] in order to properly do conv2d (and output [B, H * W, C])
391 |         let attention_output = self.attention.forward(&layer_norm_output)?;
392 |         let hidden_states = (attention_output + hidden_states)?;
393 |         let layer_norm_output = self.layer_norm_2.forward(&hidden_states)?;
394 |         let mlp_output = self
395 |             .mlp
396 |             .forward(&layer_norm_output.permute((0, 2, 1))?.reshape(shape)?)?;
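        // second residual connection: add the Mix-FFN output back onto the attention stream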
397 |         hidden_states.permute((0, 2, 1))?.reshape(shape)? + mlp_output
398 |     }
399 | }
400 | 
401 | #[derive(Debug, Clone)]
402 | struct SegformerEncoder {
403 |     /// config file
404 |     config: Config,
405 |     /// a list of embeddings
406 |     patch_embeddings: Vec<SegformerOverlapPatchEmbeddings>,
407 |     /// a list of attention blocks, each consisting of layers
408 |     blocks: Vec<Vec<SegformerLayer>>,
409 |     /// a final list of layer norms
410 |     layer_norms: Vec<candle_nn::LayerNorm>,
411 | }
412 | 
413 | impl SegformerEncoder {
414 |     fn new(config: Config, vb: VarBuilder) -> Result<Self> {
415 |         let mut patch_embeddings = Vec::with_capacity(config.num_encoder_blocks);
416 |         let mut blocks = Vec::with_capacity(config.num_encoder_blocks);
417 |         let mut layer_norms = Vec::with_capacity(config.num_encoder_blocks);
418 |         for i in 0..config.num_encoder_blocks {
419 |             let patch_size = config.patch_sizes[i];
420 |             let stride = config.strides[i];
421 |             let hidden_size = config.hidden_sizes[i];
422 |             let num_channels = if i == 0 {
423 |                 config.num_channels
424 |             } else {
425 |                 config.hidden_sizes[i - 1]
426 |             };
427 |             patch_embeddings.push(SegformerOverlapPatchEmbeddings::new(
428 |                 &config,
429 |                 patch_size,
430 |                 stride,
431 |                 num_channels,
432 |                 hidden_size,
433 |                 vb.pp(&format!("patch_embeddings.{}", i)),
434 |             )?);
435 |             let mut layers = Vec::with_capacity(config.depths[i]);
436 |             for j in 0..config.depths[i] {
437 |                 let sequence_reduction_ratio = config.sr_ratios[i];
438 |                 let num_attention_heads = config.num_attention_heads[i];
439 |                 let mlp_ratio = config.mlp_ratios[i];
440 |                 layers.push(SegformerLayer::new(
441 |                     &config,
442 |                     hidden_size,
443 |                     num_attention_heads,
444 |                     sequence_reduction_ratio,
445 |                     mlp_ratio,
446 |                     vb.pp(&format!("block.{}.{}", i, j)),
447 |                 )?);
448 |             }
449 |             blocks.push(layers);
450 |             layer_norms.push(layer_norm(
451 |                 hidden_size,
452 |                 config.layer_norm_eps,
453 |                 vb.pp(&format!("layer_norm.{}", i)),
454 |             )?);
455 |         }
456 |         Ok(Self {
457 |             config,
458 |             patch_embeddings,
459 |             blocks,
460 |             layer_norms,
461 |         })
462 |     }
463 | }
464 | 
465 | impl ModuleWithHiddenStates for SegformerEncoder {
466 |     fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
467 |         let mut all_hidden_states = Vec::with_capacity(self.config.num_encoder_blocks);
468 |         let mut hidden_states = x.clone();
469 |         for i in 0..self.config.num_encoder_blocks {
470 |             hidden_states = self.patch_embeddings[i].forward(&hidden_states)?;
471 |             for layer in &self.blocks[i] {
472 |                 hidden_states = layer.forward(&hidden_states)?;
473 |             }
474 |             let shape = hidden_states.shape().dims4()?;
475 |             hidden_states =
476 |                 self.layer_norms[i].forward(&hidden_states.flatten_from(2)?.permute((0, 2, 1))?)?;
477 |             hidden_states = hidden_states.permute((0, 2, 1))?.reshape(shape)?;
478 |             all_hidden_states.push(hidden_states.clone());
479 |         }
480 |         Ok(all_hidden_states)
481 |     }
482 | }
483 | 
484 | #[derive(Debug, Clone)]
485 | struct SegformerModel {
486 |     encoder: SegformerEncoder,
487 | }
488 | 
489 | impl SegformerModel {
490 |     fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
491 |         let encoder = SegformerEncoder::new(config.clone(), vb.pp("encoder"))?;
492 |         Ok(Self { encoder })
493 |     }
494 | }
495 | 
496 | impl ModuleWithHiddenStates for SegformerModel {
497 |     fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
498 |         self.encoder.forward(x)
499 |     }
500 | }
501 | 
502 | #[derive(Debug, Clone)]
503 | struct SegformerMLP {
504 |     proj: Linear,
505 | }
506 | 
507 | impl SegformerMLP {
508 |     fn new(config: &Config, input_dim: usize, vb: VarBuilder) -> Result<Self> {
509 |         let proj = linear(input_dim, config.decoder_layer_hidden_size, vb.pp("proj"))?;
510 |         Ok(Self { proj })
511 |     }
512 | }
513 | 

#[derive(Debug, Clone)]
struct SegformerModel {
    encoder: SegformerEncoder,
}

impl SegformerModel {
    fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
        let encoder = SegformerEncoder::new(config.clone(), vb.pp("encoder"))?;
        Ok(Self { encoder })
    }
}

impl ModuleWithHiddenStates for SegformerModel {
    fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
        self.encoder.forward(x)
    }
}

#[derive(Debug, Clone)]
struct SegformerMLP {
    proj: Linear,
}

impl SegformerMLP {
    fn new(config: &Config, input_dim: usize, vb: VarBuilder) -> Result<Self> {
        let proj = linear(input_dim, config.decoder_layer_hidden_size, vb.pp("proj"))?;
        Ok(Self { proj })
    }
}

impl Module for SegformerMLP {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        self.proj.forward(x)
    }
}

trait ModuleWithHiddenStates {
    fn forward(&self, xs: &Tensor) -> Result<Vec<Tensor>>;
}

#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_config_json_load() {
        let raw_json = r#"{
            "_name_or_path": "line_detector_192_aug/checkpoint-72000",
            "architectures": ["SegformerForRegressionMask"],
            "attention_probs_dropout_prob": 0.0,
            "classifier_dropout_prob": 0.1,
            "decoder_hidden_size": 768,
            "decoder_layer_hidden_size": 192,
            "decoder_upsample_rate": 2,
            "depths": [3, 4, 9, 3],
            "downsampling_rates": [1, 4, 8, 16],
            "drop_path_rate": 0.1,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.0,
            "hidden_sizes": [64, 128, 320, 512],
            "id2label": {"0": "blank", "1": "text"},
            "image_size": 224,
            "initializer_range": 0.02,
            "label2id": {"blank": 0, "text": 1},
            "layer_norm_eps": 1e-06,
            "mlp_ratios": [4, 4, 4, 4],
            "model_type": "segformer",
            "num_attention_heads": [1, 2, 5, 8],
            "num_channels": 3,
            "num_encoder_blocks": 4,
            "patch_sizes": [7, 3, 3, 3],
            "reshape_last_stage": true,
            "semantic_loss_ignore_index": -1,
            "sr_ratios": [8, 4, 2, 1],
            "strides": [4, 2, 2, 2],
            "torch_dtype": "float32",
            "transformers_version": "4.36.0"
        }"#;
        let config: Config = serde_json::from_str(raw_json).unwrap();
        assert_eq!(vec![4, 2, 2, 2], config.strides);
        assert_eq!(1e-6, config.layer_norm_eps);
        assert_eq!(Config::default(), config);
    }
}

#[derive(Debug, Clone)]
struct SegformerDecodeHead {
    linear_c: Vec<SegformerMLP>,
    linear_fuse: candle_nn::Conv2d,
    batch_norm: candle_nn::BatchNorm,
    classifier: candle_nn::Conv2d,
}

impl SegformerDecodeHead {
    fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
        let mut linear_c = Vec::with_capacity(config.num_encoder_blocks);
        for i in 0..config.num_encoder_blocks {
            let hidden_size = config.hidden_sizes[i];
            linear_c.push(SegformerMLP::new(
                config,
                hidden_size,
                vb.pp(&format!("linear_c.{}", i)),
            )?);
        }
        let linear_fuse = conv2d_no_bias(
            config.decoder_layer_hidden_size * config.num_encoder_blocks,
            config.decoder_hidden_size,
            1,
            Conv2dConfig::default(),
            vb.pp("linear_fuse"),
        )?;
        let batch_norm = candle_nn::batch_norm(
            config.decoder_hidden_size,
            config.layer_norm_eps,
            vb.pp("batch_norm"),
        )?;
        let classifier = conv2d(
            config.decoder_hidden_size,
            num_labels,
            1,
            Conv2dConfig::default(),
            vb.pp("classifier"),
        )?;
        Ok(Self {
            linear_c,
            linear_fuse,
            batch_norm,
            classifier,
        })
    }
    fn forward(&self, encoder_hidden_states: &[Tensor]) -> Result<Tensor> {
        if encoder_hidden_states.len() != self.linear_c.len() {
            candle_core::bail!(
                "The number of encoder hidden states {} is not equal to the number of linear layers {}",
                encoder_hidden_states.len(),
                self.linear_c.len()
            )
        }
        // most fine layer
        let (_, _, upsample_height, upsample_width) = encoder_hidden_states[0].shape().dims4()?;
        let mut hidden_states = Vec::with_capacity(self.linear_c.len());
        for (hidden_state, mlp) in encoder_hidden_states.iter().zip(&self.linear_c) {
            let (batch, _, height, width) = hidden_state.shape().dims4()?;
            let hidden_state = mlp.forward(&hidden_state.flatten_from(2)?.permute((0, 2, 1))?)?;
            let hidden_state = hidden_state.permute((0, 2, 1))?.reshape((
                batch,
                hidden_state.dim(2)?,
                height,
                width,
            ))?;
            let hidden_state = hidden_state.upsample_nearest2d(upsample_height, upsample_width)?;
            hidden_states.push(hidden_state);
        }
        hidden_states.reverse();
        let hidden_states = Tensor::cat(&hidden_states, 1)?;
        let hidden_states = self.linear_fuse.forward(&hidden_states)?;
        let hidden_states = self.batch_norm.forward_t(&hidden_states, false)?;
        let hidden_states = hidden_states.relu()?;
        self.classifier.forward(&hidden_states)
    }
}

#[derive(Debug, Clone)]
pub struct SemanticSegmentationModel {
    segformer: SegformerModel,
    decode_head: SegformerDecodeHead,
}

impl SemanticSegmentationModel {
    pub fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
        let segformer = SegformerModel::new(config, vb.pp("segformer"))?;
        let decode_head = SegformerDecodeHead::new(config, num_labels, vb.pp("decode_head"))?;
        Ok(Self {
            segformer,
            decode_head,
        })
    }
}

impl Module for SemanticSegmentationModel {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let hidden_states = self.segformer.forward(x)?;
        let hidden_states = self.decode_head.forward(&hidden_states)?;
        let result = candle_nn::ops::sigmoid(&hidden_states)?;
        Ok(result)
    }
}
--------------------------------------------------------------------------------
/src/error.rs:
--------------------------------------------------------------------------------
use std::io;
use thiserror::Error;

#[derive(Error, Debug)]
pub enum SuryaError {
    #[error("Candle error: {0}")]
    CandleError(#[from] candle_core::Error),
    #[error("OpenCV error: {0}")]
    OpenCVError(#[from] opencv::Error),
    #[error("IO error: {0}")]
    IoError(#[from] io::Error),
    #[error("Json deser error: {0}")]
    JsonDeserError(#[from] serde_json::Error),
    #[error("Hugging Face Hub error: {0}")]
    ApiError(#[from] hf_hub::api::sync::ApiError),
}

pub type Result<T> = std::result::Result<T, SuryaError>;
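// An illustrative sketch (not part of this crate's source) of what the #[from]
// conversions above buy: a fallible candle call can use `?` inside a function
// returning the crate-wide Result alias, with the error converted automatically.
fn heatmap_mean(heatmap: &candle_core::Tensor) -> Result<f32> {
    // candle_core::Error becomes SuryaError::CandleError via the #[from] impl
    Ok(heatmap.mean_all()?.to_scalar::<f32>()?)
}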
--------------------------------------------------------------------------------
/src/hf.rs:
--------------------------------------------------------------------------------
//! HuggingFace API

use crate::error::Result;
use candle_core::Device;
use hf_hub::api::sync::ApiBuilder;
use log::debug;
use std::path::PathBuf;

pub struct HfModelInfo {
    pub model_type: &'static str,
    pub repo: String,
    pub weights_file: String,
    pub config_file: String,
}

impl HfModelInfo {
    pub fn download_model_files(&self) -> Result<(PathBuf, PathBuf)> {
        let api = ApiBuilder::new().with_progress(true).build()?;
        let repo = api.model(self.repo.clone());
        debug!(
            "using {} model from HuggingFace repo '{}'",
            self.model_type, self.repo,
        );
        let model_file = repo.get(&self.weights_file)?;
        debug!(
            "using {} weights file '{}'",
            self.model_type, self.weights_file
        );
        let config_file = repo.get(&self.config_file)?;
        debug!(
            "using {} config file '{}'",
            self.model_type, self.config_file
        );
        Ok((config_file, model_file))
    }
}

pub trait HfModel {
    fn from_hf(info: HfModelInfo, device: &Device) -> Result<Self>
    where
        Self: Sized,
    {
        let (config_file, model_file) = info.download_model_files()?;
        Self::from_hf_files(config_file, model_file, device)
    }

    fn from_hf_files(config: PathBuf, weights: PathBuf, device: &Device) -> Result<Self>
    where
        Self: Sized;
}
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
pub mod bbox;
pub mod detection;
pub mod error;
pub mod hf;
pub mod postprocess;
pub mod preprocess;
pub mod recognition;

pub use error::Result;
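// An illustrative downstream usage sketch (assumptions, not crate source): any type
// implementing HfModel can be fetched and built in one call. The repo and file names
// below mirror the ones used in the swin_transformer test further down; the real CLI
// defaults may differ.
fn load_recognition_model(
    device: &candle_core::Device,
) -> surya::Result<surya::recognition::RecognitionModel> {
    use surya::hf::{HfModel, HfModelInfo};
    let info = HfModelInfo {
        model_type: "recognition",
        repo: "vikp/surya_rec".to_string(),
        weights_file: "model.safetensors".to_string(),
        config_file: "config.json".to_string(),
    };
    surya::recognition::RecognitionModel::from_hf(info, device)
}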
--------------------------------------------------------------------------------
/src/postprocess.rs:
--------------------------------------------------------------------------------
use std::path::Path;

use candle_core::Tensor;
use opencv::core::{self, Mat, Vector};
use opencv::imgcodecs;
use opencv::imgproc;
use opencv::prelude::*;

pub struct ImageChunks {
    pub resized_chunks: Vec<Mat>,
    pub padding: i32,
    pub original_size: core::Size,
    pub original_size_with_padding: core::Size,
}

impl ImageChunks {
    pub fn stitch_image_tensors(&self, images: Vec<Tensor>) -> crate::Result<Mat> {
        let image_chunks = images
            .into_iter()
            .map(heatmap_tensor_to_mat)
            .collect::<crate::Result<Vec<Mat>>>()?;
        let mut image = Mat::default();
        let image_chunks = Vector::<Mat>::from_iter(image_chunks);
        core::vconcat(&image_chunks, &mut image)?;
        Ok(image)
    }

    pub fn resize_heatmap_to_image(&self, heatmap: Mat) -> crate::Result<Mat> {
        // convert image [0,1) to 255 grayscale image
        let mut gray_scale_image = Mat::default();
        heatmap.convert_to(&mut gray_scale_image, core::CV_8UC1, 255.0, 0.0)?;
        // resize image
        let mut resized_image = Mat::default();
        imgproc::resize(
            &gray_scale_image,
            &mut resized_image,
            self.original_size_with_padding,
            0.0,
            0.0,
            opencv::imgproc::INTER_LINEAR,
        )?;
        let result = Mat::roi(
            &resized_image,
            core::Rect::new(0, 0, self.original_size.width, self.original_size.height),
        )?
        .clone_pointee();
        Ok(result)
    }
}

fn heatmap_tensor_to_mat(heatmap: Tensor) -> crate::Result<Mat> {
    let (height, width) = heatmap.dims2()?;
    debug_assert_eq!(height, width, "original heatmap must be square");
    let heatmap: Vec<Vec<f32>> = heatmap.to_vec2()?;
    let mut img =
        unsafe { Mat::new_size(core::Size::new(width as i32, height as i32), core::CV_32F)? };
    for (x, row) in heatmap.iter().enumerate() {
        for (y, &value) in row.iter().enumerate() {
            *(img.at_2d_mut::<f32>(x as i32, y as i32)?) = value;
        }
    }
    Ok(img)
}

/// save the image to output_path
pub fn save_image<P: AsRef<Path>>(image: &Mat, output_path: P) -> crate::Result<()> {
    imgcodecs::imwrite(
        output_path.as_ref().as_os_str().to_str().unwrap(),
        image,
        &core::Vector::new(),
    )?;
    Ok(())
}
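// A minimal end-to-end sketch (assumptions, not crate source) of how these pieces
// compose with src/preprocess.rs below: chunk an image, run a detection model over
// each chunk, stitch the per-chunk heatmaps back together, resize to the original
// size, and save. `model` is assumed to be any Module producing a (1, 1, H, W)
// heatmap per chunk, with H == W.
fn detect_and_save<M: candle_core::Module>(
    model: &M,
    device: &candle_core::Device,
    input: &Path,
    output: &Path,
) -> crate::Result<()> {
    let chunks = crate::preprocess::read_chunked_resized_image(input)?;
    let mut heatmaps = Vec::new();
    for chunk in &chunks.resized_chunks {
        let x = crate::preprocess::image_to_tensor(chunk, device)?.unsqueeze(0)?;
        // squeeze (1, 1, H, W) down to the (H, W) shape heatmap_tensor_to_mat expects
        let y = model.forward(&x)?.squeeze(0)?.squeeze(0)?;
        heatmaps.push(y);
    }
    let stitched = chunks.stitch_image_tensors(heatmaps)?;
    let resized = chunks.resize_heatmap_to_image(stitched)?;
    save_image(&resized, output)
}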
--------------------------------------------------------------------------------
/src/preprocess.rs:
--------------------------------------------------------------------------------
use crate::postprocess::ImageChunks;
use candle_core::{Device, Tensor};
use log::debug;
use opencv::{
    core::{self},
    imgcodecs::{self, IMREAD_COLOR},
    imgproc,
    prelude::*,
};
use std::path::Path;

const INPUT_IMAGE_SIZE: i32 = 896;
const IMAGE_CHUNK_HEIGHT: i32 = 1200;

/// load an image from path, pad and split it into chunks of height [IMAGE_CHUNK_HEIGHT],
/// resize each chunk to [INPUT_IMAGE_SIZE] square, and return the chunks along with the
/// original size
pub fn read_chunked_resized_image<P: AsRef<Path>>(image_path: P) -> crate::Result<ImageChunks> {
    let image = read_image(image_path)?;
    let original_size = core::Size::new(image.cols(), image.rows());

    let num_chunks = (original_size.height as f32 / IMAGE_CHUNK_HEIGHT as f32).ceil() as usize;
    debug_assert!(num_chunks > 0, "image must have at least one chunk");

    // pad the image with black pixels to make it divisible by chunk_height
    let mut padding: i32 = original_size.height % IMAGE_CHUNK_HEIGHT;
    if padding > 0 {
        padding = IMAGE_CHUNK_HEIGHT - padding;
        debug_assert!(padding > 0, "padding must be (still) greater than 0");
    }
    debug!(
        "image size is (w, h)=({}, {}), padding with {}",
        original_size.width, original_size.height, padding
    );

    let image = if padding > 0 {
        let mut padded_image = Mat::default();
        core::copy_make_border(
            &image,
            &mut padded_image,
            0,
            padding,
            0,
            0,
            core::BORDER_CONSTANT,
            core::Scalar::all(0.),
        )?;
        padded_image
    } else {
        image
    };
    debug_assert_eq!(
        image.rows() % IMAGE_CHUNK_HEIGHT,
        0,
        "image height must be divisible by {}",
        IMAGE_CHUNK_HEIGHT
    );

    let resized_chunks = (0..num_chunks)
        .map(|i| {
            let start = (i as i32) * IMAGE_CHUNK_HEIGHT;
            let roi: core::Rect_<i32> = core::Rect::new(0, start, image.cols(), IMAGE_CHUNK_HEIGHT);
            let chunk = Mat::roi(&image, roi)?.clone_pointee();
            let size = core::Size::new(INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE);
            resize(chunk, size)
        })
        .collect::<crate::Result<Vec<Mat>>>()?;

    Ok(ImageChunks {
        resized_chunks,
        padding,
        original_size,
        original_size_with_padding: core::Size::new(
            original_size.width,
            original_size.height + padding,
        ),
    })
}

/// read image into a matrix
pub fn read_image<P: AsRef<Path>>(image_path: P) -> crate::Result<Mat> {
    let image = imgcodecs::imread(image_path.as_ref().to_str().unwrap(), IMREAD_COLOR)?;
    Ok(image)
}

/// load dynamic image into a device tensor
pub fn image_to_tensor(input: &Mat, device: &Device) -> crate::Result<Tensor> {
    let mut image = Mat::default();
    // Convert the image to RGB (OpenCV reads images in BGR format by default)
    imgproc::cvt_color(input, &mut image, imgproc::COLOR_BGR2RGB, 0)?;
    // Get the dimensions of the image
    let size = image.size()?;
    let width = size.width;
    let height = size.height;
    // Convert the Mat to a slice of u8 and then to a Tensor and reshape it
    let data = Tensor::from_slice(
        image.data_bytes()?,
        (height as usize, width as usize, 3),
        device,
    )?
    .permute((2, 0, 1))?;
    let mean = Tensor::new(&[0.485f32, 0.456, 0.406], device)?.reshape((3, 1, 1))?;
    let std = Tensor::new(&[0.229f32, 0.224, 0.225], device)?.reshape((3, 1, 1))?;
    Ok((data.to_dtype(candle_core::DType::F32)? / 255.)?
        .broadcast_sub(&mean)?
        .broadcast_div(&std)?)
}

fn resize(image: Mat, new_size: core::Size) -> crate::Result<Mat> {
    let mut resized_image = Mat::default();
    imgproc::resize(
        &image,
        &mut resized_image,
        new_size,
        0.0,
        0.0,
        imgproc::INTER_LINEAR,
    )?;
    Ok(resized_image)
}
--------------------------------------------------------------------------------
/src/recognition/mbart.rs:
--------------------------------------------------------------------------------
//! MBart with MOE
use std::collections::HashMap;

use candle_core::{Module, Result, Tensor};
use candle_nn::{Activation, VarBuilder};

// TODO this is a placeholder

#[derive(Debug, Clone, serde::Deserialize)]
pub(crate) struct MBartConfig {
    activation_function: Activation,
    id2label: HashMap<String, String>,
    langs: HashMap<String, usize>,
    vocab_size: usize,
    moe_layers: Vec<usize>,
    d_model: usize,
    d_expert: usize,
    decoder_attention_heads: usize,
    decoder_ffn_dim: usize,
    decoder_layers: usize,
    kv_heads: usize,
    max_position_embeddings: usize,
}

#[derive(Debug, Clone)]
pub(crate) struct MBart {}

impl MBart {
    pub(crate) fn new(_config: &MBartConfig, _vb: VarBuilder) -> Result<Self> {
        Ok(Self {})
    }
}

impl Module for MBart {
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        Ok(input.clone())
    }
}
"eos_token_id": 2, 79 | "exponential_decay_length_penalty": null, 80 | "finetuning_task": null, 81 | "forced_bos_token_id": null, 82 | "forced_eos_token_id": 2, 83 | "id2label": { 84 | "0": "LABEL_0", 85 | "1": "LABEL_1" 86 | }, 87 | "init_std": 0.02, 88 | "is_decoder": true, 89 | "is_encoder_decoder": false, 90 | "kv_heads": 4, 91 | "label2id": { 92 | "LABEL_0": 0, 93 | "LABEL_1": 1 94 | }, 95 | "langs": { 96 | "af": 65539, 97 | "am": 65540, 98 | "ar": 65541, 99 | "as": 65542, 100 | "az": 65543, 101 | "be": 65544, 102 | "bg": 65545, 103 | "bn": 65546, 104 | "br": 65547, 105 | "bs": 65548, 106 | "ca": 65549, 107 | "cs": 65550, 108 | "cy": 65551, 109 | "da": 65552, 110 | "de": 65553, 111 | "el": 65554, 112 | "en": 65555, 113 | "eo": 65556, 114 | "es": 65557, 115 | "et": 65558, 116 | "eu": 65559, 117 | "fa": 65560, 118 | "fi": 65561, 119 | "fr": 65562, 120 | "fy": 65563, 121 | "ga": 65564, 122 | "gd": 65565, 123 | "gl": 65566, 124 | "gu": 65567, 125 | "ha": 65568, 126 | "he": 65569, 127 | "hi": 65570, 128 | "hr": 65571, 129 | "hu": 65572, 130 | "hy": 65573, 131 | "id": 65574, 132 | "is": 65575, 133 | "it": 65576, 134 | "ja": 65577, 135 | "jv": 65578, 136 | "ka": 65579, 137 | "kk": 65580, 138 | "km": 65581, 139 | "kn": 65582, 140 | "ko": 65583, 141 | "ku": 65584, 142 | "ky": 65585, 143 | "la": 65586, 144 | "lo": 65587, 145 | "lt": 65588, 146 | "lv": 65589, 147 | "mg": 65590, 148 | "mk": 65591, 149 | "ml": 65592, 150 | "mn": 65593, 151 | "mr": 65594, 152 | "ms": 65595, 153 | "my": 65596, 154 | "ne": 65597, 155 | "nl": 65598, 156 | "no": 65599, 157 | "om": 65600, 158 | "or": 65601, 159 | "pa": 65602, 160 | "pl": 65603, 161 | "ps": 65604, 162 | "pt": 65605, 163 | "ro": 65606, 164 | "ru": 65607, 165 | "sa": 65608, 166 | "sd": 65609, 167 | "si": 65610, 168 | "sk": 65611, 169 | "sl": 65612, 170 | "so": 65613, 171 | "sq": 65614, 172 | "sr": 65615, 173 | "su": 65616, 174 | "sv": 65617, 175 | "sw": 65618, 176 | "ta": 65619, 177 | "te": 65620, 178 | "th": 65621, 179 | "tl": 65622, 180 | "tr": 65623, 181 | "ug": 65624, 182 | "uk": 65625, 183 | "ur": 65626, 184 | "uz": 65627, 185 | "vi": 65628, 186 | "xh": 65629, 187 | "yi": 65630, 188 | "zh": 65631 189 | }, 190 | "length_penalty": 1.0, 191 | "max_length": 256, 192 | "max_position_embeddings": 1536, 193 | "min_length": 0, 194 | "model_type": "mbart", 195 | "moe_layers": [ 196 | 3 197 | ], 198 | "no_repeat_ngram_size": 0, 199 | "num_beam_groups": 1, 200 | "num_beams": 1, 201 | "num_decoder_layers": 6, 202 | "num_hidden_layers": 12, 203 | "num_return_sequences": 1, 204 | "output_attentions": false, 205 | "output_hidden_states": false, 206 | "output_scores": false, 207 | "pad_token_id": 1, 208 | "prefix": null, 209 | "problem_type": null, 210 | "pruned_heads": {}, 211 | "remove_invalid_values": false, 212 | "repetition_penalty": 1.0, 213 | "return_dict": true, 214 | "return_dict_in_generate": false, 215 | "scale_embedding": true, 216 | "sep_token_id": null, 217 | "suppress_tokens": null, 218 | "task_specific_params": null, 219 | "temperature": 1.0, 220 | "tf_legacy_loss": false, 221 | "tie_encoder_decoder": false, 222 | "tie_word_embeddings": true, 223 | "tokenizer_class": null, 224 | "top_k": 50, 225 | "top_p": 1.0, 226 | "torch_dtype": "float32", 227 | "torchscript": false, 228 | "typical_p": 1.0, 229 | "use_bfloat16": false, 230 | "use_cache": true, 231 | "use_moe": true, 232 | "vocab_size": 65792 233 | }"#; 234 | let deserialized: MBartConfig = serde_json::from_str(raw_json).unwrap(); 235 | assert_eq!(deserialized.langs.len(), 93); 236 | 
        assert_eq!(deserialized.vocab_size, 65792);
        assert_eq!(deserialized.moe_layers, vec![3]);
        assert_eq!(deserialized.d_model, 1024);
        assert_eq!(deserialized.d_expert, 1024);
        assert_eq!(deserialized.decoder_attention_heads, 16);
        assert_eq!(deserialized.decoder_ffn_dim, 4096);
        assert_eq!(deserialized.decoder_layers, 7);
        assert_eq!(deserialized.kv_heads, 4);
        assert_eq!(deserialized.max_position_embeddings, 1536);
    }
}
--------------------------------------------------------------------------------
/src/recognition/mod.rs:
--------------------------------------------------------------------------------
//! The recognition module consists of a donut encoder and an MBart decoder

mod mbart;
mod swin_transformer;

use crate::hf::HfModel;
use candle_core::{Device, Module, Result, Tensor};
use candle_nn::VarBuilder;
use mbart::MBart;
use mbart::MBartConfig;
use std::path::PathBuf;
use swin_transformer::SwinConfig;
use swin_transformer::SwinModel;

#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    encoder: SwinConfig,
    decoder: MBartConfig,
}

#[derive(Debug, Clone)]
pub struct RecognitionModel {
    encoder: SwinModel,
    decoder: MBart,
}

impl RecognitionModel {
    pub fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
        let encoder = SwinModel::new(&config.encoder, vb.pp("encoder"))?;
        let decoder = MBart::new(&config.decoder, vb.pp("decoder"))?;
        Ok(Self { encoder, decoder })
    }
}

impl Module for RecognitionModel {
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        let encoded = self.encoder.forward(input)?;
        self.decoder.forward(&encoded)
    }
}

impl HfModel for RecognitionModel {
    fn from_hf_files(
        config: PathBuf,
        weights: PathBuf,
        device: &Device,
    ) -> crate::error::Result<Self> {
        let config = serde_json::from_str(&std::fs::read_to_string(config)?)?;
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights], candle_core::DType::F16, device)?
        };
        Self::new(&config, vb).map_err(Into::into)
    }
}
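// An illustrative sketch (assumptions, not crate source): running the recognition model
// on one text-line image. Per the default SwinConfig below the encoder expects
// (batch, 3, 196, 896), and the weights are loaded as F16, so the input must match.
// Note the decoder is still a placeholder (MBart::forward is the identity), so this
// currently returns the pooled Swin features rather than decoded token ids.
fn encode_line(model: &RecognitionModel, device: &Device) -> Result<Tensor> {
    let x = Tensor::zeros(&[1, 3, 196, 896], candle_core::DType::F16, device)?;
    model.forward(&x)
}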
--------------------------------------------------------------------------------
/src/recognition/swin_transformer.rs:
--------------------------------------------------------------------------------
//! Swin Transformer
//!
//! The Swin Transformer was proposed in Swin Transformer: Hierarchical Vision Transformer using
//! Shifted Windows by Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin,
//! Baining Guo.
//!
//! <https://huggingface.co/docs/transformers/model_doc/swin>

use candle_core::{DType, Device, IndexOp, Module, Result, Shape, Tensor, D};
use candle_nn::{
    conv2d, layer_norm, linear, linear_no_bias, Activation, Conv2d, Conv2dConfig, LayerNorm,
    Linear, VarBuilder,
};

#[derive(Debug, Clone, serde::Deserialize)]
pub(crate) struct SwinConfig {
    pub image_size: (usize, usize),
    pub patch_size: usize,
    pub num_channels: usize,
    pub embed_dim: usize,
    pub depths: Vec<usize>,
    pub num_heads: Vec<usize>,
    pub window_size: usize,
    pub mlp_ratio: f64,
    pub qkv_bias: bool,
    pub hidden_act: Activation,
    pub use_absolute_embeddings: bool,
    pub initializer_range: f64,
    pub layer_norm_eps: f64,
}

impl SwinConfig {
    fn tiny_patch4_window7_224() -> Self {
        Self {
            image_size: (224, 224),
            patch_size: 4,
            num_channels: 3,
            embed_dim: 96,
            depths: vec![2, 2, 6, 2],
            num_heads: vec![3, 6, 12, 24],
            window_size: 7,
            mlp_ratio: 4.0,
            qkv_bias: true,
            hidden_act: Activation::Gelu,
            use_absolute_embeddings: false,
            initializer_range: 0.02,
            layer_norm_eps: 1e-05,
        }
    }
}

impl Default for SwinConfig {
    /// this defaults to the surya swin encoder config
    fn default() -> Self {
        Self {
            image_size: (196, 896),
            patch_size: 4,
            num_channels: 3,
            embed_dim: 128,
            depths: vec![2, 2, 14, 2],
            num_heads: vec![4, 8, 16, 32],
            window_size: 7,
            mlp_ratio: 4.0,
            qkv_bias: true,
            hidden_act: Activation::Gelu,
            use_absolute_embeddings: true,
            initializer_range: 0.02,
            layer_norm_eps: 1e-05,
        }
    }
}

#[derive(Debug, Clone)]
struct SwinPatchEmbeddings {
    projection: Conv2d,
    patch_size: usize,
    num_channels: usize,
    num_patches: usize,
}

impl SwinPatchEmbeddings {
    fn new(config: &SwinConfig, vb: VarBuilder) -> Result<Self> {
        let num_channels = config.num_channels;
        let patch_size = config.patch_size;
        let hidden_size = config.embed_dim;
        let projection = conv2d(
            num_channels,
            hidden_size,
            patch_size,
            Conv2dConfig {
                stride: patch_size,
                ..Default::default()
            },
            vb.pp("projection"),
        )?;
        let num_patches = config.image_size.0 * config.image_size.1 / patch_size / patch_size;
        Ok(Self {
            projection,
            patch_size,
            num_patches,
            num_channels,
        })
    }

    fn maybe_pad(&self, tensor: &Tensor, height: usize, width: usize) -> Result<Tensor> {
        debug_assert_eq!(
            4,
            tensor.dims().len(),
            "Input tensor must have 4 dimensions"
        );
        let tensor = if width % self.patch_size != 0 {
            let pad = self.patch_size - (width % self.patch_size);
            tensor.pad_with_zeros(3, 0, pad)?
        } else {
            tensor.clone()
        };
        let tensor = if height % self.patch_size != 0 {
            let pad = self.patch_size - (height % self.patch_size);
            tensor.pad_with_zeros(2, 0, pad)?
        } else {
            tensor.clone()
        };
        Ok(tensor)
    }
}
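// A quick sanity check (illustrative, not crate source) of the patch arithmetic above:
// with the default surya config, a (196, 896) image and patch_size 4 yield a 49 x 224
// grid, i.e. 10976 patches, each projected from 3 channels to embed_dim = 128.
fn patch_grid(config: &SwinConfig) -> (usize, usize, usize) {
    let h = config.image_size.0 / config.patch_size; // 196 / 4 = 49
    let w = config.image_size.1 / config.patch_size; // 896 / 4 = 224
    (h, w, h * w) // (49, 224, 10976)
}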
impl Module for SwinPatchEmbeddings {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let (_, c, h, w) = x.dims4()?;
        if c != self.num_channels {
            candle_core::bail!("Input channels must be equal to num_channels");
        }
        let x = self.maybe_pad(x, h, w)?;
        let embedding = self.projection.forward(&x)?;
        Ok(embedding)
    }
}

#[derive(Debug, Clone)]
struct SwinEmbeddings {
    patch_embeddings: SwinPatchEmbeddings,
    position_embeddings: Option<Tensor>,
    norm: LayerNorm,
}

impl SwinEmbeddings {
    fn new(config: &SwinConfig, vb: VarBuilder) -> Result<Self> {
        let patch_embeddings = SwinPatchEmbeddings::new(config, vb.pp("patch_embeddings"))?;
        let norm = layer_norm(config.embed_dim, config.layer_norm_eps, vb.pp("norm"))?;
        let position_embeddings = if config.use_absolute_embeddings {
            let position_embedding = vb.get(
                (1, patch_embeddings.num_patches + 1, config.embed_dim),
                "position_embeddings",
            )?;
            Some(position_embedding)
        } else {
            None
        };
        Ok(Self {
            patch_embeddings,
            position_embeddings,
            norm,
        })
    }
}

impl Module for SwinEmbeddings {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = self.patch_embeddings.forward(x)?;
        let (b, c, h, w) = x.dims4()?;
        let x = {
            let x = x.flatten_from(2)?.permute((0, 2, 1))?;
            let x = self.norm.forward(&x)?;
            x.permute((0, 2, 1))?.reshape(&[b, c, h, w])?
        };
        let x = if let Some(position_embedding) = &self.position_embeddings {
            let seq_len = h * w;
            let position_embedding = position_embedding.i((.., ..seq_len))?;
            let x = x.flatten_from(2)?.permute((0, 2, 1))?;
            let x = x.broadcast_add(&position_embedding)?;
            x.permute((0, 2, 1))?.reshape(&[b, c, h, w])?
        } else {
            x.clone()
        };
        Ok(x)
    }
}
#[derive(Debug, Clone)]
struct SwinIntermediate {
    dense: Linear,
    intermediate_act_fn: Activation,
}

impl SwinIntermediate {
    fn new(config: &SwinConfig, dim: usize, vb: VarBuilder) -> Result<Self> {
        let dense = linear(
            dim,
            (dim as f64 * config.mlp_ratio) as usize,
            vb.pp("dense"),
        )?;
        let intermediate_act_fn = config.hidden_act;
        Ok(Self {
            dense,
            intermediate_act_fn,
        })
    }
}

impl Module for SwinIntermediate {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = self.dense.forward(x)?;
        let x = self.intermediate_act_fn.forward(&x)?;
        Ok(x)
    }
}

#[derive(Debug, Clone)]
struct SwinSelfOutput {
    dense: Linear,
}

impl SwinSelfOutput {
    fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
        let dense = linear(dim, dim, vb.pp("dense"))?;
        Ok(Self { dense })
    }
}

impl Module for SwinSelfOutput {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = self.dense.forward(x)?;
        Ok(x)
    }
}

#[derive(Debug, Clone)]
struct SwinOutput {
    dense: Linear,
}

impl SwinOutput {
    fn new(config: &SwinConfig, dim: usize, vb: VarBuilder) -> Result<Self> {
        let dense = linear(dim * config.mlp_ratio as usize, dim, vb.pp("dense"))?;
        Ok(Self { dense })
    }
}

impl Module for SwinOutput {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = self.dense.forward(x)?;
        Ok(x)
    }
}

#[derive(Debug, Clone)]
struct SwinPatchMerging {
    reduction: Linear,
    norm: LayerNorm,
}

impl SwinPatchMerging {
    fn new(config: &SwinConfig, dim: usize, vb: VarBuilder) -> Result<Self> {
        let reduction = linear_no_bias(dim * 4, dim * 2, vb.pp("reduction"))?;
        let norm = layer_norm(4 * dim, config.layer_norm_eps, vb.pp("norm"))?;
        Ok(Self { reduction, norm })
    }

    fn maybe_pad(x: &Tensor) -> Result<Tensor> {
        let (_, h, w, _) = x.dims4()?;
        let x = if h % 2 == 1 {
            x.pad_with_zeros(1, 0, 1)?
        } else {
            x.clone()
        };
        let x = if w % 2 == 1 {
            x.pad_with_zeros(2, 0, 1)?
        } else {
            x.clone()
        };
        Ok(x)
    }
}

impl Module for SwinPatchMerging {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = Self::maybe_pad(x)?;
        let (b, h, w, c) = x.dims4()?;
        let input_feature = {
            let x = x.reshape((b, 2, h / 2, 2, w / 2, c))?;
            let input_feature_0 = x.i((.., 0, .., 0, .., ..))?.squeeze(1)?.squeeze(2)?;
            let input_feature_1 = x.i((.., 0, .., 1, .., ..))?.squeeze(1)?.squeeze(2)?;
            let input_feature_2 = x.i((.., 1, .., 0, .., ..))?.squeeze(1)?.squeeze(2)?;
            let input_feature_3 = x.i((.., 1, .., 1, .., ..))?.squeeze(1)?.squeeze(2)?;
            let x = Tensor::cat(
                &[
                    input_feature_0,
                    input_feature_1,
                    input_feature_2,
                    input_feature_3,
                ],
                D::Minus1,
            )?;
            x.reshape((b, (), 4 * c))?
        };
        let x = self.norm.forward(&input_feature)?;
        self.reduction.forward(&x)
    }
}
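// Shape bookkeeping for the patch merging above (an illustrative sketch, not crate
// source): the four pixels of every 2x2 block are concatenated on the channel axis,
// so (B, H, W, C) becomes (B, H/2 * W/2, 4C), and the linear reduction maps 4C -> 2C.
// The test below feeds (1, 7, 7, 96): after padding to 8x8 this is (1, 16, 384),
// reduced to (1, 16, 192).
fn merged_shape(b: usize, h: usize, w: usize, c: usize) -> (usize, usize, usize) {
    let (h, w) = (h + h % 2, w + w % 2); // maybe_pad rounds odd sizes up
    (b, (h / 2) * (w / 2), 2 * c) // 4C concat, then 4C -> 2C reduction
}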
#[derive(Debug, Clone)]
struct SwinSelfAttention {
    num_attention_heads: usize,
    attention_head_size: usize,
    relative_position_bias: Tensor,
    query: Linear,
    key: Linear,
    value: Linear,
}

impl SwinSelfAttention {
    fn new(dim: usize, num_heads: usize, window_size: usize, vb: VarBuilder) -> Result<Self> {
        let num_attention_heads = num_heads;
        let attention_head_size = dim / num_attention_heads;
        // let all_head_size = num_attention_heads * attention_head_size;
        let query = linear(dim, dim, vb.pp("query"))?;
        let key = linear(dim, dim, vb.pp("key"))?;
        let value = linear(dim, dim, vb.pp("value"))?;
        let relative_position_bias_table = vb.get(
            (
                (2 * window_size - 1) * (2 * window_size - 1),
                num_attention_heads,
            ),
            "relative_position_bias_table",
        )?;
        let relative_position_index = Self::generate_relative_position_index(
            window_size,
            relative_position_bias_table.device(),
        )?
        .flatten_all()?;
        let relative_position_bias = relative_position_bias_table.i(&relative_position_index)?;
        let relative_position_bias = relative_position_bias
            .reshape((
                window_size * window_size,
                window_size * window_size,
                num_attention_heads,
            ))?
            .permute((2, 0, 1))?
            .contiguous()?
            .unsqueeze(0)?;
        Ok(Self {
            num_attention_heads,
            attention_head_size,
            relative_position_bias,
            query,
            key,
            value,
        })
    }

    fn generate_relative_position_index(window_size: usize, device: &Device) -> Result<Tensor> {
        debug_assert!(window_size > 1, "window_size must be greater than 1");
        let window_size = window_size as i64;
        let h = Tensor::arange(0, window_size, device)?;
        let w = Tensor::arange(0, window_size, device)?;
        let xy_indexing = false; // use ij indexing
        let grids = Tensor::meshgrid(&[h, w], xy_indexing)?;
        let grid = Tensor::stack(&grids, 0)?.flatten_from(1)?;
        let grid = {
            let (_, w) = grid.shape().dims2()?;
            let left = grid.unsqueeze(2)?.repeat(Shape::from_dims(&[1, 1, w]))?;
            let right = grid.unsqueeze(1)?.repeat(Shape::from_dims(&[1, w, 1]))?;
            (left - right)?
        };
        let relative_grid = {
            let bias = Tensor::full(window_size - 1, grid.shape().clone(), device)?;
            let relative_grid = (grid + bias)?;
            let m1 = relative_grid.i(0)?;
            let m2 = relative_grid.i(1)?;
            let scalar = Tensor::full(2 * window_size - 1, m1.shape().clone(), device)?;
            let m1 = (m1 * scalar)?;
            Tensor::stack(&[m1, m2], 2)?
        };
        relative_grid.sum(2)?.to_dtype(DType::U32)
    }
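    // Worked example of the indexing above (illustrative comment, not crate source):
    // for window size w, each axis offset dh, dw lies in [-(w-1), w-1]; shifting both
    // by w-1 and scaling the row offset by 2w-1 gives
    //     index = (dh + w - 1) * (2w - 1) + (dw + w - 1),
    // a unique slot in the (2w-1)^2 bias table. For w = 3, a token attending to itself
    // (dh = dw = 0) maps to 2 * 5 + 2 = 12, the diagonal value checked in
    // test_generate_relative_position_index below.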
    fn transpose_for_scores(&self, x: &Tensor) -> Result<Tensor> {
        let (b, n, _) = x.shape().dims3()?;
        x.reshape((b, n, self.num_attention_heads, self.attention_head_size))?
            .permute((0, 2, 1, 3))?
            .contiguous()
    }
}

impl Module for SwinSelfAttention {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        debug_assert_eq!(3, x.dims().len(), "Input tensor must have 3 dimensions");
        let key_layer = self.transpose_for_scores(&self.key.forward(x)?)?;
        let query_layer = self.transpose_for_scores(&self.query.forward(x)?)?;
        let value_layer = self.transpose_for_scores(&self.value.forward(x)?)?;
        let attention_scores = (query_layer.matmul(&key_layer.t()?))?;
        let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
        let attention_scores = attention_scores.broadcast_add(&self.relative_position_bias)?;
        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
        let context_layer = attention_probs.matmul(&value_layer)?;
        let context_layer = context_layer.permute((0, 2, 1, 3))?;
        context_layer.flatten_from(2)
    }
}

#[derive(Debug, Clone)]
struct SwinAttention {
    self_attention: SwinSelfAttention,
    output: SwinSelfOutput,
}

impl SwinAttention {
    fn new(dim: usize, num_heads: usize, window_size: usize, vb: VarBuilder) -> Result<Self> {
        let self_attention = SwinSelfAttention::new(dim, num_heads, window_size, vb.pp("self"))?;
        let output = SwinSelfOutput::new(dim, vb.pp("output"))?;
        Ok(Self {
            self_attention,
            output,
        })
    }
}

impl Module for SwinAttention {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = self.self_attention.forward(x)?;
        self.output.forward(&x)
    }
}

#[derive(Debug, Clone)]
struct SwinLayer {
    shift_size: usize,
    window_size: usize,
    layernorm_before: LayerNorm,
    attention: SwinAttention,
    layernorm_after: LayerNorm,
    intermediate: SwinIntermediate,
    output: SwinOutput,
}

impl SwinLayer {
    fn new(
        config: &SwinConfig,
        dim: usize,
        num_heads: usize,
        shift_size: Option<usize>,
        vb: VarBuilder,
    ) -> Result<Self> {
        let layer_norm_eps = config.layer_norm_eps;
        let layernorm_before = layer_norm(dim, layer_norm_eps, vb.pp("layernorm_before"))?;
        let attention = SwinAttention::new(dim, num_heads, config.window_size, vb.pp("attention"))?;
        let layernorm_after = layer_norm(dim, layer_norm_eps, vb.pp("layernorm_after"))?;
        let intermediate = SwinIntermediate::new(config, dim, vb.pp("intermediate"))?;
        let output = SwinOutput::new(config, dim, vb.pp("output"))?;
        Ok(Self {
            shift_size: shift_size.unwrap_or_default(),
            window_size: config.window_size,
            layernorm_before,
            attention,
            layernorm_after,
            intermediate,
            output,
        })
    }

    /// the x tensor should be in [b, h, w, c] format
    fn maybe_pad(
        x: &Tensor,
        window_size: usize,
        height: usize,
        width: usize,
    ) -> Result<(Tensor, bool)> {
        let mut padded = false;
        let x = if width % window_size != 0 {
            padded = true;
            let pad_right = window_size - (width % window_size);
            x.pad_with_zeros(2, 0, pad_right)?
        } else {
            x.clone()
        };
        let x = if height % window_size != 0 {
            padded = true;
            let pad_bottom = window_size - (height % window_size);
            x.pad_with_zeros(1, 0, pad_bottom)?
        } else {
            x.clone()
        };
        Ok((x, padded))
    }
    fn get_attn_mask(
        shift_size: usize,
        window_size: usize,
        height: usize,
        width: usize,
        dtype: DType,
        device: &Device,
    ) -> Result<Option<Tensor>> {
        if shift_size > 0 {
            let xs = [0, height - window_size, height - shift_size, height];
            let ys = [0, width - window_size, width - shift_size, width];
            let mut count = 0i64;
            let mut rows = vec![];
            for (xs, xe) in xs.iter().zip(&xs[1..]) {
                let mut cols = vec![];
                for (ys, ye) in ys.iter().zip(&ys[1..]) {
                    let shape = (xe - xs, ye - ys);
                    let tensor = Tensor::full(count, shape, device)?;
                    cols.push(tensor);
                    count += 1;
                }
                let row = Tensor::cat(&cols, 1)?;
                rows.push(row);
            }
            let mask = Tensor::cat(&rows, 0)?;
            debug_assert_eq!(
                mask.dims2()?,
                (height, width),
                "mask shape must match input shape"
            );
            let mask = mask.unsqueeze(0)?.unsqueeze(3)?.to_dtype(dtype)?;
            let mask = Self::window_partition(&mask, window_size)?;
            let mask = mask.reshape(((), window_size * window_size))?;
            log::debug!("mask dim {:?}", mask);
            let mask = (mask.unsqueeze(1)?.broadcast_sub(&mask.unsqueeze(2)?))?;
            let mask = (mask.ne(0i64)?.to_dtype(dtype)? * -100.0f64)?;
            Ok(Some(mask))
        } else {
            Ok(None)
        }
    }

    fn window_partition(x: &Tensor, window_size: usize) -> Result<Tensor> {
        let (b, h, w, c) = x.dims4()?;
        debug_assert!(
            h % window_size == 0 && w % window_size == 0,
            "input resolution must be divisible by window size"
        );
        let x = x.reshape((
            b,
            h / window_size,
            window_size,
            w / window_size,
            window_size,
            c,
        ))?;
        let x = x.permute((0, 1, 3, 2, 4, 5))?.contiguous()?;
        x.reshape(((), window_size, window_size, c))
    }

    fn window_reverse(x: &Tensor, window_size: usize, h: usize, w: usize) -> Result<Tensor> {
        let (b, _, _, c) = x.dims4()?;
        debug_assert!(
            h % window_size == 0 && w % window_size == 0,
            "input resolution must be divisible by window size"
        );
        let b = b * window_size * window_size / h / w;
        let x = x.reshape((
            b,
            h / window_size,
            w / window_size,
            window_size,
            window_size,
            c,
        ))?;
        let x = x.permute((0, 1, 3, 2, 4, 5))?.contiguous()?;
        x.reshape((b, h, w, c))
    }

    /// if window size is larger than input resolution, we don't partition windows
    fn get_shift_and_window_size(&self, h: usize, w: usize) -> (usize, usize) {
        let min = h.min(w);
        if min <= self.window_size {
            (0, min)
        } else {
            (self.shift_size, self.window_size)
        }
    }
}
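// The forward pass below still carries a TODO: the shifted-window attention mask is
// computed but never applied. A minimal sketch of how it would be used (an assumption
// based on the reference Swin implementation, not this crate's code): expose the
// per-window axis in the scores and broadcast-add the (num_windows, N, N) mask before
// the softmax, assuming the batch dimension is divisible by num_windows.
fn apply_window_mask(attention_scores: &Tensor, mask: &Tensor) -> Result<Tensor> {
    // scores: (batch * num_windows, heads, N, N); mask: (num_windows, N, N)
    let (bw, heads, n, _) = attention_scores.dims4()?;
    let (num_windows, _, _) = mask.dims3()?;
    let scores = attention_scores.reshape((bw / num_windows, num_windows, heads, n, n))?;
    let scores = scores.broadcast_add(&mask.reshape((1, num_windows, 1, n, n))?)?;
    scores.reshape((bw, heads, n, n))
}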
impl Module for SwinLayer {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let (b, h, w, c) = x.dims4()?;
        let (shift_size, window_size) = self.get_shift_and_window_size(h, w);
        let shortcut = x;
        let x = {
            let x = x.flatten(1, 2)?;
            let x = self.layernorm_before.forward(&x)?;
            // note no permutation
            x.reshape((b, h, w, c))?
        };

        let (x, was_padded) = Self::maybe_pad(&x, window_size, h, w)?;
        let (_, padded_h, padded_w, _) = x.shape().dims4()?;

        // shift
        let x = if shift_size > 0 {
            let x = x.roll(-(shift_size as i32), 1)?;
            x.roll(-(shift_size as i32), 2)?
        } else {
            x
        };

        // partition
        let x = Self::window_partition(&x, window_size)?;

        // attention
        let x = {
            let (b, w1, w2, c) = x.dims4()?;
            debug_assert_eq!(w1, w2, "window size must be square");
            debug_assert_eq!(w1, window_size, "window size must be equal to window_size");
            let x = x.reshape((b, (), c))?;

            if Self::get_attn_mask(shift_size, window_size, w1, w2, x.dtype(), x.device())?
                .is_some()
            {
                // TODO apply the attention mask to the shifted windows
                log::warn!("TODO must apply attention mask!");
            }

            let x = self.attention.forward(&x)?;
            x.reshape((b, w1, w2, c))?
        };

        // un-partition
        let x = Self::window_reverse(&x, window_size, padded_h, padded_w)?;

        // un-shift
        let x = if shift_size > 0 {
            let x = x.roll(shift_size as i32, 1)?;
            x.roll(shift_size as i32, 2)?
        } else {
            x
        };

        let x = if was_padded {
            x.narrow(1, 0, h)?.narrow(2, 0, w)?
        } else {
            x
        };

        let hidden_states = (x + shortcut)?;

        let x = self.layernorm_after.forward(&hidden_states)?;
        let x = self.intermediate.forward(&x)?;
        let x = self.output.forward(&x)?;

        x + hidden_states
    }
}
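// An illustrative note (not crate source): the shift/un-shift pair above is a cyclic
// roll, so the two operations are exact inverses as long as the same shift is applied
// to the same axes; a tiny round-trip sketch under that assumption:
fn roll_round_trip(x: &Tensor, shift: usize) -> Result<Tensor> {
    let shifted = x.roll(-(shift as i32), 1)?.roll(-(shift as i32), 2)?;
    shifted.roll(shift as i32, 1)?.roll(shift as i32, 2) // equals x element-wise
}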
#[derive(Debug, Clone)]
struct SwinStage {
    blocks: Vec<SwinLayer>,
    downsample: Option<SwinPatchMerging>,
}

impl SwinStage {
    fn new(
        config: &SwinConfig,
        dim: usize,
        depth: usize,
        num_heads: usize,
        downsample: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let blocks = (0..depth)
            .map(|i| {
                let shift_size = if i % 2 == 0 {
                    None
                } else {
                    Some(config.window_size / 2)
                };
                SwinLayer::new(
                    config,
                    dim,
                    num_heads,
                    shift_size,
                    vb.pp(&format!("blocks.{}", i)),
                )
            })
            .collect::<Result<Vec<_>>>()?;
        let downsample = if downsample {
            Some(SwinPatchMerging::new(config, dim, vb.pp("downsample"))?)
        } else {
            None
        };
        Ok(Self { blocks, downsample })
    }
}

impl Module for SwinStage {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let mut x = x.clone();
        for block in &self.blocks {
            x = block.forward(&x)?;
        }
        if let Some(downsample) = &self.downsample {
            x = downsample.forward(&x)?;
        }
        Ok(x)
    }
}

#[derive(Debug, Clone)]
struct SwinEncoder {
    layers: Vec<SwinStage>,
}

impl SwinEncoder {
    fn new(config: &SwinConfig, vb: VarBuilder) -> Result<Self> {
        let layers = (0..config.depths.len())
            .map(|i| {
                let dim = config.embed_dim * 2_usize.pow(i as u32);
                let depth = config.depths[i];
                let num_heads = config.num_heads[i];
                let downsample = i < config.depths.len() - 1;
                SwinStage::new(
                    config,
                    dim,
                    depth,
                    num_heads,
                    downsample,
                    vb.pp(&format!("layers.{}", i)),
                )
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { layers })
    }
}

impl Module for SwinEncoder {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let mut x = x.clone();
        for layer in &self.layers {
            x = layer.forward(&x)?;
        }
        Ok(x)
    }
}

#[derive(Debug, Clone)]
pub(crate) struct SwinModel {
    embeddings: SwinEmbeddings,
    encoder: SwinEncoder,
}

impl SwinModel {
    pub(crate) fn new(config: &SwinConfig, vb: VarBuilder) -> Result<Self> {
        let embeddings = SwinEmbeddings::new(config, vb.pp("embeddings"))?;
        let encoder = SwinEncoder::new(config, vb.pp("encoder"))?;
        Ok(Self {
            embeddings,
            encoder,
        })
    }
}

impl Module for SwinModel {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = self.embeddings.forward(x)?;
        let x = self.encoder.forward(&x)?;
        // this is the same as adaptive avg pool with output size 1
        x.mean(1)
    }
}
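// Shape progression sketch (illustrative, not crate source) for the default surya
// config: embed_dim 128 doubles at each of the three downsampling stages, giving
// per-stage channel widths 128, 256, 512, 1024 alongside depths [2, 2, 14, 2] and
// heads [4, 8, 16, 32].
fn stage_dims(config: &SwinConfig) -> Vec<usize> {
    (0..config.depths.len())
        .map(|i| config.embed_dim * 2_usize.pow(i as u32)) // 128, 256, 512, 1024
        .collect()
}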
"no_repeat_ngram_size": 0, 836 | "num_beam_groups": 1, 837 | "num_beams": 1, 838 | "num_channels": 3, 839 | "num_heads": [ 840 | 4, 841 | 8, 842 | 16, 843 | 32 844 | ], 845 | "num_layers": 4, 846 | "num_return_sequences": 1, 847 | "output_attentions": false, 848 | "output_hidden_states": false, 849 | "output_scores": false, 850 | "pad_token_id": null, 851 | "patch_size": 4, 852 | "path_norm": true, 853 | "prefix": null, 854 | "problem_type": null, 855 | "pruned_heads": {}, 856 | "qkv_bias": true, 857 | "remove_invalid_values": false, 858 | "repetition_penalty": 1.0, 859 | "return_dict": true, 860 | "return_dict_in_generate": false, 861 | "sep_token_id": null, 862 | "suppress_tokens": null, 863 | "task_specific_params": null, 864 | "temperature": 1.0, 865 | "tf_legacy_loss": false, 866 | "tie_encoder_decoder": false, 867 | "tie_word_embeddings": true, 868 | "tokenizer_class": null, 869 | "top_k": 50, 870 | "top_p": 1.0, 871 | "torch_dtype": "float32", 872 | "torchscript": false, 873 | "typical_p": 1.0, 874 | "use_2d_embeddings": false, 875 | "use_absolute_embeddings": true, 876 | "use_bfloat16": false, 877 | "window_size": 7 878 | }"#; 879 | let config: SwinConfig = serde_json::from_str(config_raw).unwrap(); 880 | let default_config = SwinConfig::default(); 881 | assert_eq!(config.image_size, default_config.image_size); 882 | assert_eq!(config.patch_size, default_config.patch_size); 883 | assert_eq!(config.num_channels, default_config.num_channels); 884 | assert_eq!(config.embed_dim, default_config.embed_dim); 885 | assert_eq!(config.depths, default_config.depths); 886 | assert_eq!(config.num_heads, default_config.num_heads); 887 | assert_eq!(config.window_size, default_config.window_size); 888 | assert_eq!(config.mlp_ratio, default_config.mlp_ratio); 889 | assert_eq!(config.qkv_bias, default_config.qkv_bias); 890 | assert_eq!(config.hidden_act, default_config.hidden_act); 891 | assert_eq!( 892 | config.use_absolute_embeddings, 893 | default_config.use_absolute_embeddings 894 | ); 895 | assert_eq!(config.initializer_range, default_config.initializer_range); 896 | assert_eq!(config.layer_norm_eps, default_config.layer_norm_eps); 897 | } 898 | 899 | #[test] 900 | fn test_swin_patch_embeddings() -> Result<()> { 901 | let device = Device::Cpu; 902 | let vb = VarBuilderArgs::zeros(DType::F32, &device); 903 | let config = SwinConfig::tiny_patch4_window7_224(); 904 | let module = SwinPatchEmbeddings::new(&config, vb)?; 905 | let x = Tensor::zeros(&[1, 3, 224, 224], DType::F32, &device)?; 906 | let result = module.forward(&x)?; 907 | assert_eq!(result.dims(), &[1, config.embed_dim, 56, 56]); 908 | Ok(()) 909 | } 910 | 911 | // this is expensive, run using `cargo t -- --ignored` 912 | #[test] 913 | #[ignore] 914 | fn test_embeddings_value_compare() -> anyhow::Result<()> { 915 | use crate::hf::HfModelInfo; 916 | use crate::recognition::Config; 917 | use float_cmp::approx_eq; 918 | 919 | let device = Device::Cpu; 920 | let dtype = DType::F16; 921 | let model_info = HfModelInfo { 922 | model_type: "test", 923 | repo: "vikp/surya_rec".into(), 924 | weights_file: "model.safetensors".into(), 925 | config_file: "config.json".into(), 926 | }; 927 | let (config_file, model_file) = model_info.download_model_files()?; 928 | let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? 
    // this is expensive, run using `cargo t -- --ignored`
    #[test]
    #[ignore]
    fn test_embeddings_value_compare() -> anyhow::Result<()> {
        use crate::hf::HfModelInfo;
        use crate::recognition::Config;
        use float_cmp::approx_eq;

        let device = Device::Cpu;
        let dtype = DType::F16;
        let model_info = HfModelInfo {
            model_type: "test",
            repo: "vikp/surya_rec".into(),
            weights_file: "model.safetensors".into(),
            config_file: "config.json".into(),
        };
        let (config_file, model_file) = model_info.download_model_files()?;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
        let config =
            serde_json::from_str::<Config>(&std::fs::read_to_string(config_file)?)?.encoder;
        let embedding = SwinEmbeddings::new(&config, vb.pp("encoder").pp("embeddings"))?;
        let x = Tensor::ones(
            &[
                1,
                config.num_channels,
                config.image_size.0,
                config.image_size.1,
            ],
            dtype,
            &device,
        )?;
        {
            let y = embedding.patch_embeddings.forward(&x)?;
            let y_sum: f32 = y.to_dtype(DType::F32)?.sum_all()?.to_scalar()?;
            // assert_eq!(y_sum, -112499.0938);
            assert!(approx_eq!(f32, y_sum, -112_499.09, epsilon = 20.));
        }
        {
            let y = embedding.forward(&x)?;
            let y_sum: f32 = y.to_dtype(DType::F32)?.sum_all()?.to_scalar()?;
            // assert_eq!(y_sum, 140062.19);
            assert!(approx_eq!(f32, y_sum, 140062.19, epsilon = 45.));
        }
        Ok(())
    }

    #[test]
    fn test_swin_embeddings() -> Result<()> {
        let device = Device::Cpu;
        let vb = VarBuilderArgs::zeros(DType::F32, &device);
        let config = SwinConfig::tiny_patch4_window7_224();
        let module = SwinEmbeddings::new(&config, vb)?;
        let x = Tensor::zeros(&[1, 3, 224, 224], DType::F32, &device)?;
        let result = module.forward(&x)?;
        assert_eq!(result.dims(), &[1, config.embed_dim, 56, 56]);
        Ok(())
    }

    #[test]
    fn test_generate_relative_position_index() -> Result<()> {
        let device = Device::Cpu;
        let result = SwinSelfAttention::generate_relative_position_index(3, &device)?;
        assert_eq!(
            result.to_vec2::<u32>()?,
            [
                [12, 11, 10, 7, 6, 5, 2, 1, 0],
                [13, 12, 11, 8, 7, 6, 3, 2, 1],
                [14, 13, 12, 9, 8, 7, 4, 3, 2],
                [17, 16, 15, 12, 11, 10, 7, 6, 5],
                [18, 17, 16, 13, 12, 11, 8, 7, 6],
                [19, 18, 17, 14, 13, 12, 9, 8, 7],
                [22, 21, 20, 17, 16, 15, 12, 11, 10],
                [23, 22, 21, 18, 17, 16, 13, 12, 11],
                [24, 23, 22, 19, 18, 17, 14, 13, 12]
            ]
        );
        Ok(())
    }

    #[test]
    fn test_swin_self_attention() -> Result<()> {
        let device = Device::Cpu;
        let dim = 96;
        let window_size = 7;
        let vb = VarBuilderArgs::zeros(DType::F32, &device);
        let module: SwinSelfAttention = SwinSelfAttention::new(dim, 3, window_size, vb)?;
        let x = Tensor::zeros(&[1, window_size * window_size, dim], DType::F32, &device)?;
        let result = module.forward(&x)?;
        assert_eq!(result.dims(), &[1, window_size * window_size, dim]);
        Ok(())
    }

    #[test]
    fn test_window_partition() -> Result<()> {
        let device = Device::Cpu;
        let dim = 96;
        let x = Tensor::zeros(&[1, 56, 56, dim], DType::F32, &device)?;
        let result = SwinLayer::window_partition(&x, 7)?;
        assert_eq!(result.dims(), &[56 * 56 / 7 / 7, 7, 7, dim]);
        Ok(())
    }

    #[test]
    fn test_window_reverse() -> Result<()> {
        let device = Device::Cpu;
        let dim = 96;
        let x = Tensor::zeros(&[56 * 56 / 7 / 7, 7, 7, dim], DType::F32, &device)?;
        let result = SwinLayer::window_reverse(&x, 7, 56, 56)?;
        assert_eq!(result.dims(), &[1, 56, 56, dim]);
        Ok(())
    }

    #[test]
    fn test_window_full_cycle() -> Result<()> {
        let device = Device::Cpu;
        let dim = 96;
        let x = Tensor::zeros(&[1, 56, 56, dim], DType::F32, &device)?;
        let window_size = 7;
        let x = SwinLayer::window_partition(&x, window_size)?;
        let x = SwinLayer::window_reverse(&x, window_size, 56, 56)?;
        assert_eq!(x.dims(), &[1, 56, 56, dim]);
        Ok(())
    }
    #[test]
    fn test_swin_layer() -> Result<()> {
        let device = Device::Cpu;
        let config = SwinConfig::tiny_patch4_window7_224();
        let dim = 96;
        let vb = VarBuilderArgs::zeros(DType::F32, &device);
        let x = Tensor::zeros(&[1, 56, 56, 96], DType::F32, &device)?;
        let module = SwinLayer::new(&config, dim, 3, None, vb)?;
        let result = module.forward(&x)?;
        assert_eq!(result.dims(), &[1, 56, 56, 96]);
        Ok(())
    }

    #[test]
    fn test_get_attn_mask() -> Result<()> {
        let device = Device::Cpu;
        let result = SwinLayer::get_attn_mask(0, 7, 56, 56, DType::F32, &device)?;
        assert!(result.is_none());

        let result = SwinLayer::get_attn_mask(3, 7, 56, 56, DType::F32, &device)?;
        assert!(result.is_some());
        let result = result.unwrap();
        assert_eq!(result.dims(), &[64, 49, 49]);

        let result = SwinLayer::get_attn_mask(1, 2, 4, 4, DType::I64, &device)?;
        assert!(result.is_some());
        let result = result.unwrap();
        assert_eq!(result.dims(), &[4, 4, 4]);
        assert_eq!(
            result.to_vec3::<i64>()?,
            [
                [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
                [
                    [0, -100, 0, -100],
                    [-100, 0, -100, 0],
                    [0, -100, 0, -100],
                    [-100, 0, -100, 0]
                ],
                [
                    [0, 0, -100, -100],
                    [0, 0, -100, -100],
                    [-100, -100, 0, 0],
                    [-100, -100, 0, 0]
                ],
                [
                    [0, -100, -100, -100],
                    [-100, 0, -100, -100],
                    [-100, -100, 0, -100],
                    [-100, -100, -100, 0]
                ]
            ]
        );
        Ok(())
    }

    #[test]
    fn test_swin_patch_merging() -> Result<()> {
        let device = Device::Cpu;
        let dim = 96;
        let vb = VarBuilderArgs::zeros(DType::F32, &device);
        let config = SwinConfig::tiny_patch4_window7_224();
        let module = SwinPatchMerging::new(&config, dim, vb)?;
        let x = Tensor::zeros(&[1, 7, 7, 96], DType::F32, &device)?;
        let result = module.forward(&x)?;
        assert_eq!(result.dims(), &[1, 4 * 4, 96 * 2]);
        Ok(())
    }

    // skip this test for now
    // #[test]
    #[allow(dead_code)]
    fn test_swin_encoder() -> Result<()> {
        let device = Device::Cpu;
        let vb = VarBuilderArgs::zeros(DType::F32, &device);
        let config = SwinConfig::tiny_patch4_window7_224();
        let module = SwinEncoder::new(&config, vb)?;
        let x = Tensor::zeros(&[1, 56, 56, 96], DType::F32, &device)?;
        let result = module.forward(&x)?;
        assert_eq!(result.dims(), &[1, 49, 768]);
        Ok(())
    }
}
--------------------------------------------------------------------------------