├── .envrc ├── tests ├── assets │ ├── .gitattributes │ ├── realesr.mnn │ └── resizing.mnn ├── multi_threaded.rs ├── resizing.rs ├── backend.rs ├── basic.rs ├── segfault.rs └── common.rs ├── .gitmodules ├── .gitignore ├── sgconfig.yml ├── mnn-bridge ├── src │ ├── opencv.rs │ ├── lib.rs │ └── ndarray.rs └── Cargo.toml ├── tools ├── cachix │ └── push.sh ├── sg-lints │ ├── lints │ │ ├── no-unwrap.yml │ │ └── no-println.yaml │ └── utils │ │ └── is-test.yml └── bencher │ ├── Cargo.toml │ └── src │ └── cli.rs ├── mnn-sys ├── mnn_c │ ├── session_c.h │ ├── session_c.cpp │ ├── error_code_c.h │ ├── utils.h │ ├── schedule_c.h │ ├── backend_c.h │ ├── backend_c.cpp │ ├── schedule_c.cpp │ ├── utils.cpp │ ├── tensor_c.h │ ├── tensor_c.cpp │ ├── interpreter_c.h │ └── interpreter_c.cpp ├── patches │ ├── halide_type_t_64.patch │ └── mnn-tracing.patch ├── Cargo.toml ├── src │ ├── lib.rs │ └── tracing.rs └── build.rs ├── mnn-sync └── Cargo.toml ├── README.md ├── src ├── profile.rs ├── session.rs ├── lib.rs ├── error.rs ├── tensor │ ├── list.rs │ └── raw.rs ├── backend.rs └── schedule.rs ├── .github └── workflows │ ├── docs.yaml │ └── build.yaml ├── benches └── mnn-bench.rs ├── Cargo.toml ├── examples ├── simple.rs └── inspect.rs ├── flake.lock ├── flake.nix ├── deny.toml └── LICENSE /.envrc: -------------------------------------------------------------------------------- 1 | use flake 2 | -------------------------------------------------------------------------------- /tests/assets/.gitattributes: -------------------------------------------------------------------------------- 1 | *.mnn filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "mnn-sys/vendor"] 2 | path = mnn-sys/vendor 3 | url = https://github.com/alibaba/mnn 4 | -------------------------------------------------------------------------------- 
/.gitignore: -------------------------------------------------------------------------------- 1 | .direnv 2 | target 3 | .DS_Store 4 | *.mnn 5 | *.ppm 6 | lama 7 | *.json 8 | *.cache 9 | result 10 | -------------------------------------------------------------------------------- /sgconfig.yml: -------------------------------------------------------------------------------- 1 | ruleDirs: 2 | - ./tools/sg-lints/lints 3 | utilDirs: 4 | - ./tools/sg-lints/utils 5 | ignores: 6 | - mnn-sys/vendor 7 | -------------------------------------------------------------------------------- /tests/assets/realesr.mnn: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:1c46f03663e221596286dda67413e4840edff466ce240f520aff5c97cc32da28 3 | size 4865572 4 | -------------------------------------------------------------------------------- /tests/assets/resizing.mnn: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:1fd086430d897b39d97c2de595a4004cab6db7f3e06678c595074c321846ff14 3 | size 8624 4 | -------------------------------------------------------------------------------- /mnn-bridge/src/opencv.rs: -------------------------------------------------------------------------------- 1 | // pub trait OpencvToNdarray { 2 | // type H: mnn::HalideType; 3 | // fn as_ndarray(&self) -> ndarray::ArrayViewD; 4 | // } 5 | -------------------------------------------------------------------------------- /tools/cachix/push.sh: -------------------------------------------------------------------------------- 1 | cachix watch-exec mnn-rs -- nix flake check --system x86_64-linux --max-jobs 0 2 | cachix watch-exec mnn-rs -- nix flake check --system aarch64-darwin 3 | -------------------------------------------------------------------------------- /mnn-bridge/src/lib.rs: 
-------------------------------------------------------------------------------- 1 | #[cfg(feature = "ndarray")] 2 | pub mod ndarray; 3 | #[cfg(feature = "ndarray_0_15")] 4 | mod ndarray_0_15 { 5 | use ndarray_0_15 as ndarray; 6 | include!("ndarray.rs"); 7 | } 8 | -------------------------------------------------------------------------------- /mnn-sys/mnn_c/session_c.h: -------------------------------------------------------------------------------- 1 | #ifndef SESSION_C_H 2 | #define SESSION_C_H 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | typedef struct Session Session; 9 | void Session_destroy(Session *session); 10 | int Session_hasAsyncWork(Session *session); 11 | 12 | #ifdef __cplusplus 13 | } 14 | #endif 15 | 16 | #endif // SESSION_C_H 17 | -------------------------------------------------------------------------------- /tools/sg-lints/lints/no-unwrap.yml: -------------------------------------------------------------------------------- 1 | id: no-unwrap 2 | message: Do not use unwrap 3 | severity: error 4 | language: Rust 5 | rule: 6 | pattern: $ITEM.unwrap() 7 | not: 8 | inside: 9 | stopBy: end 10 | matches: is-test 11 | files: 12 | - src/**/*.rs 13 | - mnn-sync/src/*.rs 14 | - mnn-sys/src/*.rs 15 | - mnn-bridge/src/**/*.rs 16 | ignores: 17 | - build.rs 18 | - mnn-sys/vendor/**/*.rs 19 | 20 | -------------------------------------------------------------------------------- /tools/sg-lints/utils/is-test.yml: -------------------------------------------------------------------------------- 1 | id: is-test 2 | language: Rust 3 | 4 | rule: 5 | all: 6 | - kind: function_item 7 | - follows: 8 | stopBy: 9 | kind: function_item 10 | matches: test-token 11 | 12 | utils: 13 | test-token: 14 | kind: attribute_item 15 | has: 16 | kind: attribute 17 | has: 18 | any: 19 | - pattern: test 20 | - pattern: tokio::test 21 | 22 | ignores: 23 | - mnn-sys/vendor/**/*.rs 24 | -------------------------------------------------------------------------------- 
/mnn-sync/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "mnn-sync" 3 | version = "0.1.0" 4 | edition = "2021" 5 | license.workspace = true 6 | 7 | [dependencies] 8 | error-stack.workspace = true 9 | flume = { version = "0.11.0", default-features = false, features = [ 10 | "eventual-fairness", 11 | "nanorand", 12 | ] } 13 | mnn.workspace = true 14 | oneshot = "0.1.8" 15 | tracing = { version = "0.1", optional = true } 16 | 17 | [features] 18 | tracing = ["dep:tracing", "mnn/tracing"] 19 | 20 | [dev-dependencies] 21 | tracing-test = "0.2.5" 22 | -------------------------------------------------------------------------------- /mnn-sys/mnn_c/session_c.cpp: -------------------------------------------------------------------------------- 1 | #include "session_c.h" 2 | #include 3 | 4 | namespace MNN { 5 | class Session { 6 | public: 7 | bool hasAsyncWork(); 8 | }; 9 | } // namespace MNN 10 | void Session_destroy(Session *session) { 11 | auto mnn_session = reinterpret_cast(session); 12 | delete mnn_session; 13 | } 14 | 15 | int Session_hasAsyncWork(Session *session) { 16 | auto mnn_session = reinterpret_cast(session); 17 | return mnn_session->hasAsyncWork(); 18 | // return true; 19 | } 20 | -------------------------------------------------------------------------------- /mnn-bridge/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "mnn-bridge" 3 | version = "0.1.0" 4 | edition = "2021" 5 | license = { workspace = true } 6 | 7 | [dependencies] 8 | error-stack = "0.5.0" 9 | mnn = { workspace = true } 10 | ndarray = { version = "0.16", optional = true } 11 | ndarray_0_15 = { package = "ndarray", version = "0.15", optional = true } 12 | # opencv = { version = "0.92.3", default-features = false, optional = true } 13 | 14 | [features] 15 | ndarray = ["dep:ndarray"] 16 | ndarray_0_15 = ["dep:ndarray_0_15"] 17 | # opencv = ["dep:opencv"] 18 | 
19 | default = [] 20 | -------------------------------------------------------------------------------- /tools/sg-lints/lints/no-println.yaml: -------------------------------------------------------------------------------- 1 | id: no-println 2 | message: Do not use println! use `tracing::info`/`tracing::trace`/`tracing::debug` instead 3 | severity: warning 4 | language: Rust 5 | rule: 6 | kind: macro_invocation 7 | pattern: println!($$$ITEMS) 8 | not: 9 | inside: 10 | stopBy: end 11 | matches: is-test 12 | 13 | fix: tracing::info!($$$ITEMS) 14 | files: 15 | - src/**/*.rs 16 | - mnn-sync/src/*.rs 17 | - mnn-sys/src/*.rs 18 | - mnn-bridge/src/**/*.rs 19 | ignores: 20 | - build.rs 21 | - mnn-sys/build.rs 22 | - mnn-sys/vendor/**/*.rs 23 | -------------------------------------------------------------------------------- /mnn-sys/mnn_c/error_code_c.h: -------------------------------------------------------------------------------- 1 | #ifndef ERROR_CODE_C_H 2 | #define ERROR_CODE_C_H 3 | #ifdef __cplusplus 4 | extern "C" { 5 | #endif 6 | typedef enum { 7 | ERROR_CODE_NO_ERROR = 0, 8 | ERROR_CODE_OUT_OF_MEMORY = 1, 9 | ERROR_CODE_NOT_SUPPORT = 2, 10 | ERROR_CODE_COMPUTE_SIZE_ERROR = 3, 11 | ERROR_CODE_NO_EXECUTION = 4, 12 | ERROR_CODE_INVALID_VALUE = 5, 13 | // User error 14 | ERROR_CODE_INPUT_DATA_ERROR = 10, 15 | ERROR_CODE_CALL_BACK_STOP = 11, 16 | // Op Resize Error 17 | ERROR_CODE_TENSOR_NOT_SUPPORT = 20, 18 | ERROR_CODE_TENSOR_NEED_DIVIDE = 21, 19 | } ErrorCode; 20 | #ifdef __cplusplus 21 | } 22 | #endif 23 | #endif // ERROR_CODE_C_H 24 | -------------------------------------------------------------------------------- /mnn-sys/patches/halide_type_t_64.patch: -------------------------------------------------------------------------------- 1 | @@ -82,11 +82,11 @@ typedef enum halide_type_code_t 2 | * exactly 32-bits in size. */ 3 | struct halide_type_t { 4 | /** The basic type code: signed integer, unsigned integer, or floating point. 
*/ 5 | -#if __cplusplus >= 201103L 6 | +// #if __cplusplus >= 201103L 7 | HALIDE_ATTRIBUTE_ALIGN(1) halide_type_code_t code; // halide_type_code_t 8 | -#else 9 | - HALIDE_ATTRIBUTE_ALIGN(1) uint8_t code; // halide_type_code_t 10 | -#endif 11 | +// #else 12 | +// HALIDE_ATTRIBUTE_ALIGN(1) uint8_t code; // halide_type_code_t 13 | +// #endif 14 | 15 | /** The number of bits of precision of a single scalar value of this type. */ 16 | HALIDE_ATTRIBUTE_ALIGN(1) uint8_t bits; 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mnn-rs 2 | ![Codecov](https://img.shields.io/codecov/c/github/aftershootco/mnn-rs?link=https%3A%2F%2Fapp.codecov.io%2Fgithub%2Faftershootco%2Fmnn-rs) 3 | ![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/aftershootco/mnn-rs/build.yaml?link=https%3A%2F%2Fgithub.com%2Faftershootco%2Fmnn-rs%2Factions%2Fworkflows%2Fbuild.yaml) 4 | 5 | Rust wrapper over [alibaba/MNN](https://github.com/alibaba/MNN) c++ library with handwritten C wrapper over mnn 6 | 7 | If you have nix you can just build the inspect binary with 8 | 9 | ``` 10 | nix build .#inspect 11 | ``` 12 | 13 | NOTES: 14 | On windows it will only compile with --release mode 15 | There's a few issues with rustc linking to msvcrt by default and anything compiled with /MTd will not link properly 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /mnn-sys/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "mnn-sys" 3 | version = "0.1.0" 4 | edition = "2021" 5 | links = "mnn" 6 | license = { workspace = true } 7 | 8 | [build-dependencies] 9 | anyhow = "1.0.86" 10 | bindgen = { version = "0.70", features = ["experimental"] } 11 | cc = { version = "1.1.5", features = [] } 12 | cmake = { git = 
"https://github.com/blonteractor/cmake-rs", features = [ 13 | "parallel", 14 | ] } 15 | diffy = "0.4.0" 16 | dunce = "1.0.4" 17 | fs_extra = "1.3.0" 18 | itertools = "0.13.0" 19 | tap = "1.0.1" 20 | 21 | [features] 22 | vulkan = [] 23 | metal = [] 24 | coreml = ["metal"] 25 | opencl = [] 26 | openmp = [] 27 | opengl = [] 28 | mnn-threadpool = [] 29 | default = ["mnn-threadpool"] 30 | crt_static = [] 31 | 32 | [dependencies] 33 | libc = "0.2.155" 34 | once_cell = "1.20.2" 35 | tracing-core = "0.1.33" 36 | -------------------------------------------------------------------------------- /mnn-sys/mnn_c/utils.h: -------------------------------------------------------------------------------- 1 | #ifndef UTILS_H 2 | #define UTILS_H 3 | #include 4 | #include 5 | #ifdef __cplusplus 6 | extern "C" { 7 | #endif 8 | typedef struct { 9 | char *data; 10 | size_t size; 11 | } CString; 12 | CString createCString(const char *data, size_t size); 13 | void destroyCString(CString *string); 14 | // This must always be 15 | typedef struct { 16 | // Name of the tensor 17 | CString name; 18 | // Points to a raw tensor object 19 | void *tensor; 20 | } TensorInfo; 21 | 22 | typedef struct { 23 | TensorInfo *tensors; 24 | size_t size; 25 | } TensorInfoArray; 26 | 27 | TensorInfoArray *createTensorInfoArray(size_t count); 28 | void destroyTensorInfoArray(TensorInfoArray *array); 29 | TensorInfo *getTensorInfoArray(TensorInfoArray const *array, size_t index); 30 | #ifdef __cplusplus 31 | } 32 | #endif 33 | 34 | #endif // UTILS_H 35 | -------------------------------------------------------------------------------- /src/profile.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "profile")] 2 | macro_rules! 
profile { 3 | ($message: expr; $($t:tt)*) => {{ 4 | let now = std::time::Instant::now(); 5 | #[cfg(feature = "tracing")] 6 | tracing::trace!("{}: Starting", $message); 7 | let result = { 8 | $($t)* 9 | }; 10 | let elapsed = now.elapsed(); 11 | #[cfg(feature = "tracing")] 12 | tracing::info!("{}: elapsed time: {:?}", $message, elapsed); 13 | result 14 | }} 15 | } 16 | #[cfg(not(feature = "profile"))] 17 | macro_rules! profile { 18 | ($_: expr; $($t:tt)*) => { 19 | $($t)* 20 | } 21 | } 22 | pub(crate) use profile; 23 | 24 | #[test] 25 | pub fn test_profiling() { 26 | let time = std::time::Instant::now(); 27 | profile!("Testing profiling"; { 28 | std::thread::sleep(std::time::Duration::from_secs(1)); 29 | }); 30 | let time = time.elapsed(); 31 | assert!(time.as_secs() == 1); 32 | } 33 | -------------------------------------------------------------------------------- /.github/workflows/docs.yaml: -------------------------------------------------------------------------------- 1 | name: docs 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | 7 | env: 8 | CARGO_TERM_COLOR: always 9 | 10 | jobs: 11 | docs: 12 | runs-on: ubuntu-latest 13 | permissions: 14 | id-token: "write" 15 | contents: "read" 16 | pages: "write" 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | - uses: DeterminateSystems/nix-installer-action@main 21 | - uses: cachix/cachix-action@v14 22 | with: 23 | name: mnn-rs 24 | authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}' 25 | - uses: DeterminateSystems/flake-checker-action@main 26 | 27 | - name: Generate docs 28 | run: nix build .#checks.x86_64-linux.mnn-docs 29 | 30 | - name: Setup Pages 31 | uses: actions/configure-pages@v5 32 | 33 | - name: Upload artifact 34 | uses: actions/upload-pages-artifact@v3 35 | with: 36 | path: result/share/doc 37 | 38 | - name: Deploy to gh-pages 39 | id: deployment 40 | uses: actions/deploy-pages@v4 41 | 42 | -------------------------------------------------------------------------------- /tests/multi_threaded.rs: 
-------------------------------------------------------------------------------- 1 | mod common; 2 | use common::*; 3 | use mnn::ForwardType; 4 | 5 | #[cfg(test)] 6 | pub fn test_multi_threading(backend: ForwardType) -> Result<()> { 7 | let handles: Vec<_> = (1..=10) 8 | .map(move |_| std::thread::spawn(move || test_basic(backend))) 9 | .collect(); 10 | handles 11 | .into_iter() 12 | .map(|h| h.join().unwrap()) 13 | .collect::>>()?; 14 | Ok(()) 15 | } 16 | 17 | #[test] 18 | #[ignore = "takes too long"] 19 | fn test_multi_threading_cpu() { 20 | test_multi_threading(ForwardType::CPU).unwrap(); 21 | } 22 | 23 | #[cfg(feature = "metal")] 24 | #[test] 25 | #[ignore = "takes too long"] 26 | fn test_multi_threading_metal() { 27 | test_multi_threading(ForwardType::Metal).unwrap(); 28 | } 29 | 30 | #[cfg(feature = "opencl")] 31 | #[test] 32 | #[ignore = "takes too long"] 33 | fn test_multi_threading_opencl() { 34 | test_multi_threading(ForwardType::OpenCL).unwrap(); 35 | } 36 | 37 | #[test] 38 | #[ignore = "takes too long and unreliable on CI"] 39 | fn test_multi_path_cpu_cpu() { 40 | test_multipath_session(ForwardType::CPU, ForwardType::CPU).unwrap(); 41 | } 42 | -------------------------------------------------------------------------------- /tools/bencher/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "bencher" 3 | version = "0.1.0" 4 | edition = "2021" 5 | license.workspace = true 6 | 7 | [target."aarch64-apple-darwin".dependencies] 8 | mnn = { workspace = true, features = ["opencl", "serde", "metal"] } 9 | 10 | [target."x86_64-apple-darwin".dependencies] 11 | mnn = { workspace = true, features = ["opencl", "serde"] } 12 | 13 | [target."cfg(windows)".dependencies] 14 | mnn = { workspace = true, features = ["opencl", "serde"] } 15 | 16 | [dependencies] 17 | bytemuck = { version = "1.20.0", features = ["extern_crate_alloc"] } 18 | clap = { version = "4.5.22", features = ["derive", "unstable-v5"] } 19 | 
clap-verbosity-flag = { version = "3.0.1", features = [ 20 | "tracing", 21 | ], default-features = false } 22 | clap_complete = "4.5.38" 23 | console = "0.15.8" 24 | dunce = "1.0.5" 25 | error-stack = { workspace = true, features = ["serde"] } 26 | indicatif = "0.17.9" 27 | ndarray = "0.16.1" 28 | num = "0.4.3" 29 | same-file = "1.0.6" 30 | serde = { version = "1.0.215", features = ["derive"] } 31 | serde_json = "1.0.133" 32 | tempfile = "3.14.0" 33 | thiserror = "2.0.4" 34 | tracing = "0.1.41" 35 | tracing-subscriber = "0.3.19" 36 | -------------------------------------------------------------------------------- /mnn-sys/mnn_c/schedule_c.h: -------------------------------------------------------------------------------- 1 | #ifndef SCHEDULE_C_H 2 | #define SCHEDULE_C_H 3 | #include "backend_c.h" 4 | #include 5 | #include 6 | 7 | #ifdef __cplusplus 8 | extern "C" { 9 | #endif 10 | 11 | typedef struct MNNScheduleConfig MNNScheduleConfig; 12 | 13 | MNNScheduleConfig *mnnsc_create(); 14 | MNNScheduleConfig *mnnsc_clone(const MNNScheduleConfig *from); 15 | void mnnsc_destroy(MNNScheduleConfig *config); 16 | void mnnsc_set_save_tensors(MNNScheduleConfig *config, 17 | const char *const *saveTensors, 18 | size_t saveTensorsSize); 19 | void mnnsc_set_type(MNNScheduleConfig *config, MNNForwardType type); 20 | void mnnsc_set_num_threads(MNNScheduleConfig *config, int numThread); 21 | void mnnsc_set_mode(MNNScheduleConfig *config, int mode); 22 | void mnnsc_set_backup_type(MNNScheduleConfig *config, 23 | MNNForwardType backupType); 24 | void mnnsc_set_backend_config(MNNScheduleConfig *config, 25 | MNNBackendConfig *backendConfig); 26 | MNNForwardType mnnsc_get_type(MNNScheduleConfig *config); 27 | MNNForwardType mnnsc_get_backup_type(MNNScheduleConfig *config); 28 | 29 | #ifdef __cplusplus 30 | } 31 | #endif 32 | #endif // SCHEDULE_C_H 33 | -------------------------------------------------------------------------------- /benches/mnn-bench.rs: 
-------------------------------------------------------------------------------- 1 | use divan::*; 2 | #[divan::bench_group(sample_size = 5, sample_count = 5)] 3 | mod mnn_realesr_bench_with_ones { 4 | use divan::*; 5 | use mnn::*; 6 | #[divan::bench] 7 | pub fn mnn_realesr_benchmark_cpu(bencher: Bencher) { 8 | let mut net = Interpreter::from_file("tests/assets/realesr.mnn").unwrap(); 9 | let mut config = ScheduleConfig::new(); 10 | config.set_type(ForwardType::CPU); 11 | let session = net.create_session(config).unwrap(); 12 | bencher.bench_local(|| { 13 | let mut input = net.input(&session, "data").unwrap(); 14 | input.fill(1f32); 15 | net.run_session(&session).unwrap(); 16 | }); 17 | } 18 | 19 | #[cfg(feature = "opencl")] 20 | #[divan::bench] 21 | pub fn mnn_realesr_benchmark_opencl(bencher: Bencher) { 22 | let mut net = Interpreter::from_file("tests/assets/realesr.mnn").unwrap(); 23 | let mut config = ScheduleConfig::new(); 24 | config.set_type(ForwardType::OpenCL); 25 | let session = net.create_session(config).unwrap(); 26 | bencher.bench_local(|| { 27 | let mut input = net.input(&session, "data").unwrap(); 28 | input.fill(1f32); 29 | net.run_session(&session).unwrap(); 30 | net.wait(&session); 31 | }); 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [".", "mnn-bridge", "mnn-sync", "mnn-sys", "tools/bencher"] 3 | [workspace.package] 4 | license = "Apache-2.0" 5 | 6 | [package] 7 | name = "mnn" 8 | version = "0.2.0" 9 | edition = "2024" 10 | license = { workspace = true } 11 | 12 | [workspace.dependencies] 13 | mnn = { version = "0.2.0", path = "." 
} 14 | error-stack = { version = "0.5" } 15 | 16 | [dependencies] 17 | libc = "0.2" 18 | mnn-sys = { version = "0.1", path = "mnn-sys", features = [] } 19 | thiserror = "2.0" 20 | error-stack.workspace = true 21 | oneshot = "0.1" 22 | tracing = { version = "0.1.40", optional = true } 23 | dunce = "1.0.5" 24 | serde = { version = "1.0", features = ["derive"], optional = true } 25 | 26 | [features] 27 | metal = ["mnn-sys/metal"] 28 | coreml = ["mnn-sys/coreml"] 29 | vulkan = ["mnn-sys/vulkan"] 30 | opencl = ["mnn-sys/opencl"] 31 | opengl = ["mnn-sys/opengl"] 32 | crt_static = ["mnn-sys/crt_static"] 33 | # Disable mnn-threadpool to enable this 34 | openmp = ["mnn-sys/openmp"] 35 | mnn-threadpool = ["mnn-sys/mnn-threadpool"] 36 | tracing = ["dep:tracing"] 37 | profile = ["tracing"] 38 | serde = ["dep:serde"] 39 | 40 | default = ["mnn-threadpool"] 41 | 42 | 43 | [dev-dependencies] 44 | anyhow = "1.0" 45 | bytemuck = "1.17" 46 | clap = { version = "4.5", features = ["derive"] } 47 | divan = "0.1.14" 48 | tracing = "0.1.40" 49 | tracing-subscriber = "0.3.19" 50 | tracing-test = { version = "0.2.5", features = ["no-env-filter"] } 51 | 52 | [[bench]] 53 | name = "mnn-bench" 54 | harness = false 55 | 56 | [profile.rwd] 57 | debug = true 58 | inherits = "release" 59 | -------------------------------------------------------------------------------- /mnn-sys/patches/mnn-tracing.patch: -------------------------------------------------------------------------------- 1 | index 8f30cd68..77407812 100644 2 | --- a/include/MNN/MNNDefine.h 3 | +++ b/include/MNN/MNNDefine.h 4 | @@ -35,8 +35,27 @@ 5 | #define MNN_PRINT(format, ...) syslog(LOG_WARNING, format, ##__VA_ARGS__); fprintf(stderr, format, ##__VA_ARGS__) 6 | #define MNN_ERROR(format, ...) syslog(LOG_WARNING, format, ##__VA_ARGS__); fprintf(stderr, format, ##__VA_ARGS__) 7 | #else 8 | -#define MNN_PRINT(format, ...) printf(format, ##__VA_ARGS__) 9 | -#define MNN_ERROR(format, ...) 
printf(format, ##__VA_ARGS__) 10 | +enum class Level { 11 | + Info = 0, 12 | + Error = 1, 13 | +}; 14 | +extern "C" { 15 | +void mnn_ffi_emit(const char *file, size_t line, Level level, 16 | + const char *message); 17 | +} 18 | +#define MNN_PRINT(format, ...) \ 19 | + { \ 20 | + char logtmp[4096]; \ 21 | + snprintf(logtmp, 4096, format, ##__VA_ARGS__); \ 22 | + mnn_ffi_emit(__FILE__, __LINE__, Level::Info, logtmp); \ 23 | + } 24 | + 25 | +#define MNN_ERROR(format, ...) \ 26 | + { \ 27 | + char logtmp[4096]; \ 28 | + snprintf(logtmp, 4096, format, ##__VA_ARGS__); \ 29 | + mnn_ffi_emit(__FILE__, __LINE__, Level::Error, logtmp); \ 30 | + } 31 | #endif 32 | 33 | #ifdef DEBUG 34 | -------------------------------------------------------------------------------- /tests/resizing.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | use common::*; 3 | 4 | #[test] 5 | pub fn test_resizing() -> Result<()> { 6 | let model = std::fs::read("tests/assets/resizing.mnn").expect("No resizing model"); 7 | let mut net = Interpreter::from_bytes(&model).unwrap(); 8 | net.set_cache_file("resizing.cache", 128)?; 9 | let config = ScheduleConfig::default(); 10 | #[cfg(feature = "opencl")] 11 | config.set_type(ForwardType::OpenCL); 12 | let mut session = net.create_session(config).unwrap(); 13 | net.update_cache_file(&mut session)?; 14 | 15 | let now = std::time::Instant::now(); 16 | let mut mask = unsafe { net.input_unresized::(&session, "mask") }?; 17 | net.resize_tensor(&mut mask, [2048, 2048]); 18 | drop(mask); 19 | 20 | let mut og = unsafe { net.input_unresized::(&session, "original") }?; 21 | net.resize_tensor(&mut og, [2048, 2048, 3]); 22 | drop(og); 23 | 24 | let mut pain = unsafe { net.input_unresized::(&session, "inpainted") }?; 25 | net.resize_tensor(&mut pain, [2048, 2048, 3]); 26 | drop(pain); 27 | 28 | net.resize_session(&mut session); 29 | let inputs = net.inputs(&session); 30 | for tensor_info in inputs.iter() { 31 | let 
tensor = tensor_info.tensor::().unwrap(); 32 | println!( 33 | "{:13}: {:>13}", 34 | tensor_info.name(), 35 | format!("{:?}", tensor.shape()) 36 | ); 37 | let mut host = tensor.create_host_tensor_from_device(false); 38 | host.host_mut().fill(1.0); 39 | } 40 | drop(inputs); 41 | net.run_session(&session).unwrap(); 42 | println!("{:?}", now.elapsed()); 43 | Ok(()) 44 | } 45 | -------------------------------------------------------------------------------- /tests/backend.rs: -------------------------------------------------------------------------------- 1 | #![allow(unused_imports)] 2 | pub mod common; 3 | use common::*; 4 | use mnn::ForwardType; 5 | use tracing_test::traced_test; 6 | 7 | #[cfg(feature = "coreml")] 8 | #[test] 9 | #[traced_test] 10 | fn compare_cpu_and_coreml_outputs() { 11 | let mut net = mnn::Interpreter::from_file("tests/assets/realesr.mnn").unwrap(); 12 | let cpu_config = ScheduleConfig::new(); 13 | let mut coreml_config = ScheduleConfig::new(); 14 | let mut bc = BackendConfig::new(); 15 | coreml_config.set_type(ForwardType::CoreML); 16 | let cpu_session = net.create_session(cpu_config).unwrap(); 17 | let coreml_session = net.create_session(coreml_config).unwrap(); 18 | net.inputs(&cpu_session).iter().for_each(|x| { 19 | let mut tensor = x.tensor::().expect("No tensor"); 20 | tensor.fill(1.0f32); 21 | }); 22 | net.inputs(&coreml_session).iter().for_each(|x| { 23 | let mut tensor = x.tensor::().expect("No tensor"); 24 | tensor.fill(1.0f32); 25 | }); 26 | 27 | net.run_session(&cpu_session).unwrap(); 28 | net.run_session(&coreml_session).unwrap(); 29 | 30 | let cpu_outputs = net.outputs(&cpu_session); 31 | let coreml_outputs = net.outputs(&coreml_session); 32 | 33 | cpu_outputs 34 | .iter() 35 | .zip(coreml_outputs.iter()) 36 | .for_each(|(cpu, coreml)| { 37 | let cpu_tensor = cpu.tensor::().expect("No tensor"); 38 | let coreml_tensor = coreml.tensor::().expect("No tensor"); 39 | let cpu = cpu_tensor.create_host_tensor_from_device(true); 40 | let 
coreml = coreml_tensor.create_host_tensor_from_device(true); 41 | assert_eq!(cpu.host(), coreml.host()); 42 | }); 43 | } 44 | -------------------------------------------------------------------------------- /mnn-sys/mnn_c/backend_c.h: -------------------------------------------------------------------------------- 1 | #ifndef BACKEND_C_H 2 | #define BACKEND_C_H 3 | #include 4 | 5 | #ifdef __cplusplus 6 | extern "C" { 7 | #endif 8 | 9 | typedef enum { Memory_Normal = 0, Memory_High, Memory_Low } MemoryMode; 10 | typedef enum { Power_Normal = 0, Power_High, Power_Low } PowerMode; 11 | typedef enum { 12 | Precision_Normal = 0, 13 | Precision_High, 14 | Precision_Low, 15 | Precision_Low_BF16 16 | } PrecisionMode; 17 | typedef struct MNNBackendConfig MNNBackendConfig; 18 | // struct BackendConfig { 19 | // MemoryMode memory; // = Memory_Normal; 20 | // PowerMode power; // = Power_Normal; 21 | // PrecisionMode precision; // = Precision_Normal; 22 | // /** user defined context */ 23 | // union { 24 | // void *sharedContext; // = nullptr; 25 | // size_t flags; // Valid for CPU Backend 26 | // }; 27 | // }; 28 | 29 | MNNBackendConfig *mnnbc_create(); 30 | MNNBackendConfig *mnnbc_clone(const MNNBackendConfig *config); 31 | void mnnbc_destroy(MNNBackendConfig *config); 32 | void mnnbc_set_memory_mode(MNNBackendConfig *config, MemoryMode memory_mode); 33 | void mnnbc_set_power_mode(MNNBackendConfig *config, PowerMode power_mode); 34 | void mnnbc_set_precision_mode(MNNBackendConfig *config, 35 | PrecisionMode precision_mode); 36 | void mnnbc_set_shared_context(MNNBackendConfig *config, void *shared_context); 37 | void mnnbc_set_flags(MNNBackendConfig *config, size_t flags); 38 | void mnnbc_reset(MNNBackendConfig *config); 39 | 40 | MemoryMode mnnbc_get_memory_mode(MNNBackendConfig *config); 41 | PowerMode mnnbc_get_power_mode(MNNBackendConfig *config); 42 | PrecisionMode mnnbc_get_precision_mode(MNNBackendConfig *config); 43 | 44 | #ifdef __cplusplus 45 | } 46 | #endif 47 | 
#endif // BACKEND_C_H 48 | -------------------------------------------------------------------------------- /src/session.rs: -------------------------------------------------------------------------------- 1 | use crate::prelude::*; 2 | 3 | /// A session is a context in which a computation graph is executed. 4 | /// 5 | /// Inference unit. multiple sessions could share one net/interpreter. 6 | #[derive(Debug)] 7 | pub struct Session { 8 | /// Pointer to the underlying MNN session. 9 | pub(crate) inner: *mut mnn_sys::Session, 10 | /// Pointer to the underlying MNN interpreter 11 | /// # Safety Note 12 | /// Since the interpreter is actually not owned by session but it is a shared resource we can 13 | /// reasonably assume that the interpreter will outlive the session. (This is not a compile 14 | /// time gurantee yet) 15 | /// TODO: Add a proper lifetime bound to ensure the interpreter outlives the session. 16 | pub(crate) net: *mut mnn_sys::Interpreter, 17 | /// Internal session configurations. 18 | pub(crate) __session_internals: crate::SessionInternals, 19 | /// Marker to ensure the struct is not Send or Sync. 20 | pub(crate) __marker: PhantomData<()>, 21 | } 22 | 23 | /// Enum representing the internal configurations of a session. 24 | #[derive(Debug)] 25 | pub enum SessionInternals { 26 | /// Single session configuration. 27 | Single(crate::ScheduleConfig), 28 | /// Multiple session configurations. 29 | MultiSession(crate::ScheduleConfigs), 30 | } 31 | 32 | impl Session { 33 | /// Calls the destroy function on the underlying MNN session. 34 | pub fn destroy(&mut self) { 35 | unsafe { 36 | mnn_sys::Interpreter_releaseSession(self.net, self.inner); 37 | } 38 | // unsafe { mnn_sys::Session_destroy(self.inner) } 39 | } 40 | } 41 | 42 | impl Drop for Session { 43 | /// Custom drop implementation to ensure the underlying MNN session is properly destroyed. 
44 | fn drop(&mut self) { 45 | self.destroy(); 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /.github/workflows/build.yaml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | checks-matrix: 14 | runs-on: ubuntu-latest 15 | outputs: 16 | matrix: ${{ steps.set-matrix.outputs.matrix }} 17 | steps: 18 | - uses: actions/checkout@v4 19 | - uses: DeterminateSystems/nix-installer-action@main 20 | - id: set-matrix 21 | name: Generate Nix Matrix 22 | run: | 23 | set -Eeu 24 | matrix="$(nix eval --json '.#githubActions.matrix')" 25 | echo "matrix=$matrix" >> "$GITHUB_OUTPUT" 26 | 27 | checks-build: 28 | needs: checks-matrix 29 | runs-on: ${{ matrix.os }} 30 | strategy: 31 | matrix: ${{fromJSON(needs.checks-matrix.outputs.matrix)}} 32 | steps: 33 | - uses: actions/checkout@v4 34 | with: 35 | lfs: true 36 | submodules: 'recursive' 37 | - uses: DeterminateSystems/nix-installer-action@main 38 | - uses: cachix/cachix-action@v14 39 | with: 40 | name: mnn-rs 41 | authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}' 42 | - run: nix build -L '.#${{ matrix.attr }}' 43 | 44 | codecov: 45 | runs-on: ubuntu-latest 46 | permissions: 47 | id-token: "write" 48 | contents: "read" 49 | 50 | steps: 51 | - uses: actions/checkout@v4 52 | with: 53 | lfs: true 54 | submodules: 'recursive' 55 | - uses: DeterminateSystems/nix-installer-action@main 56 | - uses: cachix/cachix-action@v14 57 | with: 58 | name: mnn-rs 59 | authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}' 60 | 61 | - name: Run codecov 62 | run: nix build .#checks.x86_64-linux.mnn-llvm-cov 63 | 64 | - name: Upload coverage reports to Codecov 65 | uses: codecov/codecov-action@v4.0.1 66 | with: 67 | flags: unittests 68 | name: codecov-mnn 69 | fail_ci_if_error: true 70 | token: ${{ secrets.CODECOV_TOKEN 
}} 71 | files: ./result 72 | verbose: true 73 | 74 | -------------------------------------------------------------------------------- /tests/basic.rs: -------------------------------------------------------------------------------- 1 | pub mod common; 2 | use common::*; 3 | use mnn::ForwardType; 4 | 5 | #[test] 6 | fn test_basic_cpu() { 7 | test_basic(ForwardType::CPU).unwrap(); 8 | } 9 | #[cfg(feature = "metal")] 10 | #[test] 11 | #[ignore = "Doesn't work on ci"] 12 | fn test_basic_metal() { 13 | test_basic(ForwardType::Metal).unwrap(); 14 | } 15 | #[cfg(feature = "opencl")] 16 | #[test] 17 | #[ignore = "Doesn't work on ci"] 18 | fn test_basic_opencl() -> Result<(), Box> { 19 | let backend = ForwardType::OpenCL; 20 | let realesr = std::path::Path::new("tests/assets/realesr.mnn"); 21 | 22 | let mut net = mnn::Interpreter::from_file(realesr)?; 23 | net.set_cache_file(realesr.with_extension("cache"), 128)?; 24 | let mut config = ScheduleConfig::new(); 25 | config.set_type(backend); 26 | let mut session = net.create_session(config)?; 27 | net.update_cache_file(&mut session)?; 28 | 29 | net.inputs(&session).iter().for_each(|x| { 30 | let mut tensor = x.tensor::().expect("No tensor"); 31 | println!("{}: {:?}", x.name(), tensor.shape()); 32 | tensor.fill(1.0f32); 33 | }); 34 | net.run_session(&session)?; 35 | let outputs = net.outputs(&session); 36 | outputs.iter().for_each(|x| { 37 | let tensor = x.tensor::().expect("No tensor"); 38 | tensor.wait(ffi::MapType::MAP_TENSOR_READ, true); 39 | println!("Waiting for tensor: {}", x.name()); 40 | println!("{}: {:?}", x.name(), tensor.shape()); 41 | // let _ = tensor.create_host_tensor_from_device(true); 42 | }); 43 | 44 | // drop(outputs); 45 | // drop(session); 46 | // drop(net); 47 | Ok(()) 48 | } 49 | #[cfg(feature = "coreml")] 50 | #[test] 51 | fn test_basic_coreml() { 52 | test_basic(ForwardType::CoreML).unwrap(); 53 | } 54 | #[cfg(feature = "opengl")] 55 | #[test] 56 | fn test_basic_opengl() { 57 | 
test_basic(ForwardType::OpenGL).unwrap(); 58 | } 59 | 60 | #[test] 61 | #[ignore = "takes too long and unreliable on CI"] 62 | fn test_multi_path_cpu_cpu() { 63 | test_multipath_session(ForwardType::CPU, ForwardType::CPU).unwrap(); 64 | } 65 | 66 | // #[cfg(feature = "opencl")] 67 | // #[test] 68 | // fn test_multi_path_opencl_cpu() { 69 | // test_multipath_session(ForwardType::OpenCL, ForwardType::CPU).unwrap(); 70 | // } 71 | -------------------------------------------------------------------------------- /examples/simple.rs: -------------------------------------------------------------------------------- 1 | // use mnn::utils::*; 2 | use mnn::*; 3 | use std::path::PathBuf; 4 | 5 | #[derive(Debug, clap::Parser, Clone)] 6 | pub struct Cli { 7 | // image: PathBuf, 8 | model: PathBuf, 9 | // #[clap(short, long, default_value = "metal")] 10 | // forward: ForwardType, 11 | // #[clap(short, long, default_value = "high")] 12 | // precision: Modes, 13 | // #[clap(short = 'P', long, default_value = "high")] 14 | // power: Modes, 15 | } 16 | 17 | pub fn main() -> anyhow::Result<()> { 18 | use clap::Parser; 19 | let cli = Cli::parse(); 20 | let mut interpreter = Interpreter::from_file(cli.model)?; 21 | 22 | let mut config = ScheduleConfig::new(); 23 | config.set_type(ForwardType::CPU); 24 | let mut backend_config = BackendConfig::new(); 25 | backend_config.set_precision_mode(PrecisionMode::High); 26 | backend_config.set_power_mode(PowerMode::High); 27 | config.set_backend_config(backend_config); 28 | 29 | let now = std::time::Instant::now(); 30 | let session = interpreter.create_session(config)?; 31 | println!("create session time: {:?}", now.elapsed()); 32 | let mut image = interpreter.input(&session, "image")?; 33 | let mut mask = interpreter.input(&session, "mask")?; 34 | let mut image_tensor = image.create_host_tensor_from_device(false); 35 | image_tensor.host_mut().fill(1.0f32); 36 | image.copy_from_host_tensor(&image_tensor)?; 37 | let mut mask_tensor = 
mask.create_host_tensor_from_device(false); 38 | mask_tensor.host_mut().fill(0.7f32); 39 | let now = std::time::Instant::now(); 40 | mask.copy_from_host_tensor(&mask_tensor)?; 41 | println!("copy time: {:?}", now.elapsed()); 42 | 43 | let output = interpreter.output(&session, "output")?; 44 | // image.copy_from_host_tensor(&unit_tensor)?; 45 | 46 | let now = std::time::Instant::now(); 47 | interpreter.run_session(&session)?; 48 | output.wait(ffi::MapType::MAP_TENSOR_READ, true); 49 | println!("run time: {:?}", now.elapsed()); 50 | 51 | let now = std::time::Instant::now(); 52 | let output_tensor = output.create_host_tensor_from_device(true); 53 | println!("copy time: {:?}", now.elapsed()); 54 | 55 | let out_vec = output_tensor.host().to_vec(); 56 | let mut out_ppm = b"P6\n512 512\n255\n".to_vec(); 57 | out_ppm.extend(out_vec.iter().map(|x: &f32| *x as u8)); 58 | std::fs::write("output.ppm", out_ppm)?; 59 | 60 | Ok(()) 61 | } 62 | -------------------------------------------------------------------------------- /mnn-sys/mnn_c/backend_c.cpp: -------------------------------------------------------------------------------- 1 | #include "backend_c.h" 2 | #include 3 | 4 | MNNBackendConfig *mnnbc_create() { 5 | return reinterpret_cast(new MNN::BackendConfig()); 6 | } 7 | 8 | MNNBackendConfig *mnnbc_clone(const MNNBackendConfig *config) { 9 | return reinterpret_cast(new MNN::BackendConfig( 10 | *reinterpret_cast(config))); 11 | } 12 | 13 | void mnnbc_destroy(MNNBackendConfig *config) { 14 | delete reinterpret_cast(config); 15 | } 16 | 17 | void mnnbc_set_memory_mode(MNNBackendConfig *config, MemoryMode memory_mode) { 18 | reinterpret_cast(config)->memory = 19 | static_cast(memory_mode); 20 | } 21 | void mnnbc_set_power_mode(MNNBackendConfig *config, PowerMode power_mode) { 22 | reinterpret_cast(config)->power = 23 | static_cast(power_mode); 24 | } 25 | void mnnbc_set_precision_mode(MNNBackendConfig *config, 26 | PrecisionMode precision_mode) { 27 | 
reinterpret_cast(config)->precision = 28 | static_cast(precision_mode); 29 | } 30 | void mnnbc_set_shared_context(MNNBackendConfig *config, void *shared_context) { 31 | reinterpret_cast(config)->sharedContext = 32 | shared_context; 33 | } 34 | void mnnbc_set_flags(MNNBackendConfig *config, size_t flags) { 35 | reinterpret_cast(config)->flags = flags; 36 | } 37 | void mnnbc_reset(MNNBackendConfig *config) { 38 | reinterpret_cast(config)->memory = 39 | MNN::BackendConfig::Memory_Normal; 40 | reinterpret_cast(config)->power = 41 | MNN::BackendConfig::Power_Normal; 42 | reinterpret_cast(config)->precision = 43 | MNN::BackendConfig::Precision_Normal; 44 | reinterpret_cast(config)->sharedContext = nullptr; 45 | } 46 | 47 | MemoryMode mnnbc_get_memory_mode(MNNBackendConfig *config) { 48 | return static_cast( 49 | reinterpret_cast(config)->memory); 50 | } 51 | PowerMode mnnbc_get_power_mode(MNNBackendConfig *config) { 52 | return static_cast( 53 | reinterpret_cast(config)->power); 54 | } 55 | PrecisionMode mnnbc_get_precision_mode(MNNBackendConfig *config) { 56 | return static_cast( 57 | reinterpret_cast(config)->precision); 58 | } 59 | -------------------------------------------------------------------------------- /mnn-sys/mnn_c/schedule_c.cpp: -------------------------------------------------------------------------------- 1 | #include "schedule_c.h" 2 | #include 3 | #include 4 | 5 | MNNScheduleConfig *mnnsc_create() { 6 | auto mnnsc = new MNN::ScheduleConfig(); 7 | mnnsc->saveTensors = std::vector(); 8 | return reinterpret_cast(mnnsc); 9 | } 10 | 11 | MNNScheduleConfig *mnnsc_clone(const MNNScheduleConfig *from) { 12 | auto mnn_from = reinterpret_cast(from); 13 | auto mnn_to = new MNN::ScheduleConfig(*mnn_from); 14 | return reinterpret_cast(mnn_to); 15 | } 16 | 17 | void mnnsc_destroy(MNNScheduleConfig *config) { 18 | auto mnn_config = reinterpret_cast(config); 19 | delete mnn_config; 20 | } 21 | 22 | void mnnsc_set_save_tensors(MNNScheduleConfig *config, 23 | const 
char *const *saveTensors, 24 | size_t saveTensorsSize) { 25 | auto mnn_config = reinterpret_cast(config); 26 | auto mnn_saveTensors = 27 | std::vector(saveTensors, saveTensors + saveTensorsSize); 28 | mnn_config->saveTensors = std::move(mnn_saveTensors); 29 | } 30 | 31 | void mnnsc_set_type(MNNScheduleConfig *config, MNNForwardType type) { 32 | auto mnn_config = reinterpret_cast(config); 33 | mnn_config->type = type; 34 | } 35 | 36 | void mnnsc_set_num_threads(MNNScheduleConfig *config, int numThread) { 37 | auto mnn_config = reinterpret_cast(config); 38 | mnn_config->numThread = numThread; 39 | } 40 | 41 | void mnnsc_set_mode(MNNScheduleConfig *config, int mode) { 42 | auto mnn_config = reinterpret_cast(config); 43 | mnn_config->mode = mode; 44 | } 45 | 46 | void mnnsc_set_backup_type(MNNScheduleConfig *config, 47 | MNNForwardType backupType) { 48 | auto mnn_config = reinterpret_cast(config); 49 | mnn_config->backupType = backupType; 50 | } 51 | void mnnsc_set_backend_config(MNNScheduleConfig *config, 52 | MNNBackendConfig *backendConfig) { 53 | auto mnn_config = reinterpret_cast(config); 54 | mnn_config->backendConfig = 55 | reinterpret_cast(backendConfig); 56 | } 57 | 58 | MNNForwardType mnnsc_get_type(MNNScheduleConfig *config) { 59 | return reinterpret_cast(config)->type; 60 | } 61 | MNNForwardType mnnsc_get_backup_type(MNNScheduleConfig *config) { 62 | return reinterpret_cast(config)->backupType; 63 | } 64 | -------------------------------------------------------------------------------- /tests/segfault.rs: -------------------------------------------------------------------------------- 1 | /// This segfault on OpenCL backend if we print the tensorinfo 2 | #[cfg(feature = "opencl")] 3 | #[test] 4 | fn test_segfault_case_1_() -> Result<(), Box> { 5 | use mnn::*; 6 | let backend = ForwardType::OpenCL; 7 | let realesr = std::path::Path::new("tests/assets/realesr.mnn"); 8 | 9 | let mut net = mnn::Interpreter::from_file(realesr)?; 10 | 
net.set_cache_file(realesr.with_extension("cache"), 128)?; 11 | let mut config = ScheduleConfig::new(); 12 | config.set_type(backend); 13 | let mut session = net.create_session(config)?; 14 | net.update_cache_file(&mut session)?; 15 | 16 | net.inputs(&session).iter().for_each(|x| { 17 | let mut tensor = x.tensor::().expect("No tensor"); 18 | // println!("{}: {:?}", x.name(), tensor.shape()); 19 | println!("{:?}", x); 20 | tensor.fill(1.0f32); 21 | }); 22 | net.run_session(&session)?; 23 | let outputs = net.outputs(&session); 24 | drop(outputs); 25 | drop(session); 26 | drop(net); 27 | Ok(()) 28 | } 29 | 30 | #[test] 31 | #[ignore] 32 | pub fn test_resizing() { 33 | use mnn::*; 34 | let model = std::fs::read("tests/assets/resizing.mnn").expect("No resizing model"); 35 | let mut net = Interpreter::from_bytes(&model).unwrap(); 36 | let config = ScheduleConfig::default(); 37 | let mut session = net.create_session(config).unwrap(); 38 | 39 | loop { 40 | let inputs = net.inputs(&session); 41 | for tensor_info in inputs.iter() { 42 | let mut tensor = unsafe { tensor_info.tensor_unresized::() }.unwrap(); 43 | let mut shape = tensor.shape().as_ref().to_vec(); 44 | dbg!(&shape); 45 | shape.iter_mut().for_each(|v| { 46 | if *v == -1 { 47 | *v = 3; 48 | } 49 | }); 50 | dbg!(&shape); 51 | net.resize_tensor(&mut tensor, &shape); 52 | } 53 | drop(inputs); 54 | 55 | net.resize_session(&mut session); 56 | let inputs = net.inputs(&session); 57 | for tensor_info in inputs.iter() { 58 | let tensor = tensor_info.tensor::().unwrap(); 59 | println!( 60 | "{:13}: {:>13}", 61 | tensor_info.name(), 62 | format!("{:?}", tensor.shape()) 63 | ); 64 | let mut host = tensor.create_host_tensor_from_device(false); 65 | host.host_mut().fill(1.0); 66 | } 67 | drop(inputs); 68 | net.run_session(&session).unwrap(); 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /mnn-sys/mnn_c/utils.cpp: 
-------------------------------------------------------------------------------- 1 | #include "utils.h" 2 | #include 3 | #include 4 | #ifdef __DEBUG 5 | #include 6 | #include 7 | 8 | void code_bits_lanes(const char *name, halide_type_t *type) { 9 | printf("====================================\n"); 10 | printf("sizes: \n"); 11 | std::cout << "code: " << sizeof(type->code) << std::endl; 12 | std::cout << "bits: " << sizeof(type->bits) << std::endl; 13 | std::cout << "lanes: " << sizeof(type->lanes) << std::endl; 14 | printf("%s: cbt %d %d %d\n", name, type->code, type->bits, type->lanes); 15 | printf("sizeof(%s): %lu\n", name, sizeof(*type)); 16 | printf("====================================\n"); 17 | } 18 | #endif 19 | TensorInfoArray *createTensorInfoArray(size_t count) { 20 | TensorInfoArray *array; 21 | array = (TensorInfoArray *)malloc(sizeof(TensorInfoArray)); 22 | array->size = count; 23 | array->tensors = (TensorInfo *)malloc(count * sizeof(TensorInfo)); 24 | return array; 25 | } 26 | 27 | void destroyTensorInfoArray(TensorInfoArray *array) { 28 | for (size_t i = 0; i < array->size; i++) { 29 | destroyCString(&array->tensors[i].name); 30 | } 31 | free(array->tensors); 32 | array->tensors = NULL; 33 | array->size = 0; 34 | free(array); 35 | array = NULL; 36 | } 37 | 38 | TensorInfo *getTensorInfoArray(TensorInfoArray const *array, size_t index) { 39 | if (index >= array->size) { 40 | return NULL; 41 | } 42 | return array->tensors + index; 43 | } 44 | 45 | CString createCString(const char *str, size_t max_size) { 46 | CString cstr; 47 | // Find out the size of the input 48 | size_t size = 0; 49 | while (str[size] != '\0' || size <= max_size) { 50 | size++; 51 | } 52 | cstr.size = size; 53 | cstr.data = (char *)malloc(size + 1); 54 | if (cstr.data) { 55 | memcpy((void *)cstr.data, str, size); 56 | cstr.data[size] = '\0'; 57 | } 58 | return cstr; 59 | } 60 | 61 | void destroyCString(CString *cstr) { 62 | free(cstr->data); 63 | cstr->data = NULL; 64 | cstr->size = 
0; 65 | } 66 | 67 | #ifdef __DISABLED 68 | struct halide_type_t halide_type_to_halide_type_t(halide_type_c type) { 69 | // std::cout << sizeof(halide_type_of()) << std::endl; 70 | // std::cout 71 | // << "================halide_type_to_halide_type_t=====================" 72 | // << std::endl; 73 | auto htt = halide_type_t(type.code, type.bits, type.lanes); 74 | // code_bits_lanes("htt", &htt); 75 | return htt; 76 | } 77 | 78 | union TypeUnion { 79 | halide_type_t htt; 80 | uint64_t as_uint64; 81 | TypeUnion() {} 82 | ~TypeUnion() {} 83 | }; 84 | 85 | uint64_t halide_type_t_from(halide_type_c type) { 86 | TypeUnion tu; 87 | tu.htt = halide_type_t(type.code, type.bits, type.lanes); 88 | 89 | return tu.as_uint64; 90 | // // return reinterpret_cast(htt); 91 | // return reinterpret_cast(reinterpret_cast(&htt)); 92 | } 93 | #endif 94 | -------------------------------------------------------------------------------- /tests/common.rs: -------------------------------------------------------------------------------- 1 | pub use mnn::*; 2 | pub type Error = Box; 3 | pub type Result = std::result::Result; 4 | pub struct Model { 5 | bytes: &'static [u8], 6 | } 7 | 8 | impl Model { 9 | pub const fn new() -> Self { 10 | Model { 11 | bytes: include_bytes!("assets/realesr.mnn"), 12 | } 13 | } 14 | } 15 | 16 | impl Default for Model { 17 | fn default() -> Self { 18 | Self::new() 19 | } 20 | } 21 | 22 | impl AsRef<[u8]> for Model { 23 | fn as_ref(&self) -> &[u8] { 24 | self.bytes 25 | } 26 | } 27 | 28 | #[allow(dead_code)] 29 | pub fn test_basic(backend: ForwardType) -> Result<()> { 30 | let mut net = mnn::Interpreter::from_file("tests/assets/realesr.mnn")?; 31 | let mut config = ScheduleConfig::new(); 32 | config.set_type(backend); 33 | let session = net.create_session(config)?; 34 | net.inputs(&session).iter().for_each(|x| { 35 | let mut tensor = x.tensor::().expect("No tensor"); 36 | println!("{}: {:?}", x.name(), tensor.shape()); 37 | tensor.fill(1.0f32); 38 | }); 39 | 
net.run_session(&session)?; 40 | let outputs = net.outputs(&session); 41 | for output in outputs.iter() { 42 | println!("output: {:?}", output); 43 | let tensor = output.tensor::()?; 44 | let shape = tensor.shape(); 45 | assert_eq!(shape.as_ref(), [1, 3, 2048, 2048]); 46 | } 47 | Ok(()) 48 | } 49 | 50 | #[allow(dead_code)] 51 | pub fn test_multipath_session(backend: ForwardType, backend2: ForwardType) -> Result<()> { 52 | use mnn::BackendConfig; 53 | 54 | let mut net = mnn::Interpreter::from_bytes(Model::new())?; 55 | let mut config = ScheduleConfig::new(); 56 | config.set_type(backend); 57 | config.set_backup_type(backend); 58 | let mut bc = BackendConfig::new(); 59 | bc.set_memory_mode(mnn::MemoryMode::High); 60 | bc.set_precision_mode(mnn::PrecisionMode::High); 61 | bc.set_power_mode(mnn::PowerMode::High); 62 | let mut config2 = ScheduleConfig::new(); 63 | config2.set_type(backend2); 64 | config2.set_backup_type(backend2); 65 | let mut bc = BackendConfig::new(); 66 | bc.set_memory_mode(mnn::MemoryMode::High); 67 | bc.set_precision_mode(mnn::PrecisionMode::High); 68 | bc.set_power_mode(mnn::PowerMode::High); 69 | config2.set_backend_config(bc); 70 | 71 | let session = net.create_multipath_session([config, config2])?; 72 | { 73 | let inputs = net.inputs(&session); 74 | for input in inputs.iter() { 75 | println!("input: {:?}", input); 76 | input.tensor::()?.fill(0.0); 77 | } 78 | } 79 | net.run_session(&session)?; 80 | let outputs = net.outputs(&session); 81 | for output in outputs.iter() { 82 | println!("output: {:?}", output); 83 | let tensor = output.tensor::()?; 84 | let shape = tensor.shape(); 85 | assert_eq!(shape.as_ref(), [1, 3, 2048, 2048]); 86 | } 87 | Ok(()) 88 | } 89 | -------------------------------------------------------------------------------- /mnn-sys/mnn_c/tensor_c.h: -------------------------------------------------------------------------------- 1 | #ifndef TENSOR_C_H 2 | #define TENSOR_C_H 3 | #include "utils.h" 4 | #include 5 | #include 6 | 
#include 7 | 8 | #ifdef __cplusplus 9 | extern "C" { 10 | #endif 11 | typedef struct Tensor Tensor; 12 | typedef struct { 13 | int shape[4]; 14 | size_t size; 15 | } TensorShape; 16 | typedef enum { TENSORFLOW, CAFFE, CAFFE_C4 } DimensionType; 17 | typedef enum { HANDLE_NONE = 0, HANDLE_STRING = 1 } HandleDataType; 18 | typedef enum { MAP_TENSOR_WRITE = 0, MAP_TENSOR_READ = 1 } MapType; 19 | Tensor *Tensor_create(int dimSize, DimensionType type); 20 | Tensor *Tensor_createFromTensor(const Tensor *tensor, DimensionType type, 21 | int allocMemory); 22 | void Tensor_destroy(Tensor *tensor); 23 | Tensor *Tensor_createDevice(const int *shape, size_t shapeSize, 24 | struct halide_type_t typeCode, 25 | DimensionType dimType); 26 | Tensor *Tensor_createWith(const int *shape, size_t shapeSize, 27 | struct halide_type_t typeCode, void *data, 28 | DimensionType dimType); 29 | int Tensor_copyFromHostTensor(Tensor *deviceTensor, const Tensor *hostTensor); 30 | int Tensor_copyToHostTensor(const Tensor *deviceTensor, Tensor *hostTensor); 31 | Tensor *Tensor_createHostTensorFromDevice(const Tensor *deviceTensor, 32 | int copyData); 33 | DimensionType Tensor_getDimensionType(const Tensor *tensor); 34 | const halide_buffer_t *Tensor_buffer(const Tensor *tensor); 35 | halide_buffer_t *Tensor_buffer_mut(Tensor *tensor); 36 | const void *Tensor_host(const Tensor *tensor); 37 | void *Tensor_host_mut(Tensor *tensor); 38 | uint64_t Tensor_deviceId(const Tensor *tensor); 39 | int Tensor_dimensions(const Tensor *tensor); 40 | TensorShape Tensor_shape(const Tensor *tensor); 41 | int Tensor_size(const Tensor *tensor); 42 | size_t Tensor_usize(const Tensor *tensor); 43 | int Tensor_elementSize(const Tensor *tensor); 44 | int Tensor_width(const Tensor *tensor); 45 | int Tensor_height(const Tensor *tensor); 46 | int Tensor_channel(const Tensor *tensor); 47 | int Tensor_batch(const Tensor *tensor); 48 | int Tensor_stride(const Tensor *tensor, int index); 49 | int Tensor_length(const Tensor 
*tensor, int index); 50 | void Tensor_setStride(Tensor *tensor, int index, int stride); 51 | void Tensor_setLength(Tensor *tensor, int index, int length); 52 | int Tensor_getDeviceInfo(const Tensor *tensor, void *dst, int forwardType); 53 | void Tensor_print(const Tensor *tensor); 54 | void Tensor_printShape(const Tensor *tensor); 55 | void *Tensor_map(Tensor *tensor, MapType mtype, DimensionType dtype); 56 | void Tensor_unmap(Tensor *tensor, MapType mtype, DimensionType dtype, 57 | void *mapPtr); 58 | Tensor* Tensor_clone(const Tensor *tensor); 59 | int Tensor_wait(Tensor *tensor, MapType mtype, int finish); 60 | int Tensor_setDevicePtr(Tensor *tensor, const void *devicePtr, int memoryType); 61 | struct halide_type_t Tensor_getType(const Tensor *tensor); 62 | bool Tensor_isTypeOf(const Tensor *tensor, struct halide_type_t type); 63 | #ifdef __cplusplus 64 | } 65 | #endif 66 | #endif // TENSOR_C_H 67 | -------------------------------------------------------------------------------- /mnn-sys/src/lib.rs: -------------------------------------------------------------------------------- 1 | use std::ffi::CStr; 2 | mod tracing; 3 | 4 | pub mod cpp { 5 | #![allow(non_upper_case_globals)] 6 | #![allow(non_camel_case_types)] 7 | #![allow(non_snake_case)] 8 | include!(concat!(env!("OUT_DIR"), "/mnn_cpp.rs")); 9 | } 10 | mod sys { 11 | #![allow(non_upper_case_globals)] 12 | #![allow(non_camel_case_types)] 13 | #![allow(non_snake_case)] 14 | #![allow(clippy::manual_c_str_literals)] 15 | #![allow(clippy::suspicious_doc_comments)] 16 | include!(concat!(env!("OUT_DIR"), "/mnn_c.rs")); 17 | } 18 | pub use sys::*; 19 | impl DimensionType { 20 | pub const NHWC: Self = Self::TENSORFLOW; 21 | pub const NCHW: Self = Self::CAFFE; 22 | pub const NC4HW4: Self = Self::CAFFE_C4; 23 | } 24 | impl halide_type_t { 25 | unsafe fn new(code: halide_type_code_t, bits: u8, lanes: u16) -> Self { 26 | Self { code, bits, lanes } 27 | } 28 | } 29 | 30 | pub fn halide_type_of() -> halide_type_t { 31 | 
T::halide_type_of() 32 | } 33 | 34 | pub trait HalideType: seal::Sealed { 35 | fn halide_type_of() -> halide_type_t; 36 | } 37 | mod seal { 38 | pub trait Sealed {} 39 | } 40 | 41 | macro_rules! halide_types { 42 | ($($t:ty => $ht:expr),*) => { 43 | $( 44 | impl seal::Sealed for $t {} 45 | impl HalideType for $t { 46 | fn halide_type_of() -> halide_type_t { 47 | unsafe { 48 | $ht 49 | } 50 | } 51 | } 52 | )* 53 | }; 54 | } 55 | 56 | halide_types! { 57 | f32 => halide_type_t::new(halide_type_code_t::halide_type_float, 32, 1), 58 | f64 => halide_type_t::new(halide_type_code_t::halide_type_float, 64, 1), 59 | bool => halide_type_t::new(halide_type_code_t::halide_type_uint, 1, 1), 60 | u8 => halide_type_t::new(halide_type_code_t::halide_type_uint, 8,1), 61 | u16 => halide_type_t::new(halide_type_code_t::halide_type_uint, 16,1), 62 | u32 => halide_type_t::new(halide_type_code_t::halide_type_uint, 32,1), 63 | u64 => halide_type_t::new(halide_type_code_t::halide_type_uint, 64,1), 64 | i8 => halide_type_t::new(halide_type_code_t::halide_type_int, 8,1), 65 | i16 => halide_type_t::new(halide_type_code_t::halide_type_int, 16,1), 66 | i32 => halide_type_t::new(halide_type_code_t::halide_type_int, 32,1), 67 | i64 => halide_type_t::new(halide_type_code_t::halide_type_int, 64,1) 68 | } 69 | 70 | impl Drop for CString { 71 | fn drop(&mut self) { 72 | unsafe { destroyCString(self.as_ptr_mut()) } 73 | } 74 | } 75 | 76 | impl CString { 77 | pub fn as_ptr(&self) -> *const CString { 78 | core::ptr::addr_of!(*self) 79 | } 80 | 81 | pub fn as_ptr_mut(&mut self) -> *mut CString { 82 | core::ptr::addr_of_mut!(*self) 83 | } 84 | /// # Safety 85 | /// This function is unsafe because it dereferences a raw pointer. 
86 | pub unsafe fn to_cstr(&self) -> &CStr { 87 | unsafe { std::ffi::CStr::from_ptr(self.data) } 88 | } 89 | } 90 | 91 | impl AsRef<[i32]> for TensorShape { 92 | fn as_ref(&self) -> &[i32] { 93 | &self.shape[..self.size] 94 | } 95 | } 96 | 97 | impl halide_type_code_t { 98 | /// # Safety 99 | /// This function is unsafe because this basically truansmutes an integer to an enum. 100 | /// And if the enum is not valid, it will cause undefined behavior in rust. 101 | pub unsafe fn from_u32(code: u32) -> Self { 102 | unsafe { std::mem::transmute(code) } 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![deny(missing_docs)] 2 | //! 3 | //! Ergonomic rust bindings for [MNN](https://github.com/alibaba/MNN) 4 | //! 5 | //! The main data structures used are [`Tensor`] and [`Interpreter`]. 6 | //! [Interpreter] should be thread safe and can be used to run multiple sessions concurrently. 7 | //! [Send] / [Sync] is not implemented for Interpreter yet since we don't know how it will be used. 8 | //! 9 | //! ![Codecov](https://img.shields.io/codecov/c/github/aftershootco/mnn-rs?link=https%3A%2F%2Fapp.codecov.io%2Fgithub%2Faftershootco%2Fmnn-rs) 10 | //! ![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/aftershootco/mnn-rs/build.yaml?link=https%3A%2F%2Fgithub.com%2Faftershootco%2Fmnn-rs%2Factions%2Fworkflows%2Fbuild.yaml) 11 | //! # Example 12 | //! ```rust,no_run 13 | //! use mnn::*; 14 | //! let mut interpreter = Interpreter::from_bytes([0;100]).unwrap(); 15 | //! let mut sc = ScheduleConfig::new(); 16 | //! let session = interpreter.create_session(sc).unwrap(); 17 | //! let mut input = interpreter.input::(&session, "input").unwrap(); 18 | //! let mut tensor = input.create_host_tensor_from_device(false); 19 | //! tensor.host_mut().fill(1.0f32); 20 | //! 
input.copy_from_host_tensor(&tensor).unwrap(); 21 | //! interpreter.run_session(&session).unwrap(); 22 | //! let output = interpreter.output::(&session, "output").unwrap(); 23 | //! let mut output_tensor = output.create_host_tensor_from_device(true); 24 | //! std::fs::write("output.bin", output_tensor.host().to_vec()).unwrap(); 25 | //! ``` 26 | //! **NOTE:** The library is still in development and the API is subject to change. 27 | //! 28 | //! ## Features 29 | //! - `metal`: Enable mnn Metal backend 30 | //! - `coreml`: Enable mnn CoreML backend 31 | //! - `vulkan`: Enable mnn Vulkan backend (unimplemented from rust wrapper) 32 | //! - `opencl`: Enable mnn OpenCL backend 33 | //! - `opengl`: Enable mnn OpenGL backend (unimplemented from rust wrapper) 34 | //! - `openmp`: Enable mnn Openmp ( disable the mnn-threadpool feature to enable this) 35 | //! - `mnn-threadpool`: Enable mnn threadpool ( enabled by default can't be used with openmp) 36 | //! - `sync`: Enable sync api 37 | //! - `profile`: Enable profiling ( emits some profiling tracing events ) 38 | //! - `tracing`: Enable tracing ( emits some tracing events ) 39 | //! - `crt_static`: Link statically to the C runtime on windows (noop on other platforms) 40 | //! ## License 41 | //! This links to the MNN library which is licensed under the Apache License 2.0. 42 | //! The rust bindings are licensed under the same Apache License 2.0. 43 | //! 44 | //! ## Building 45 | //! The flake.nix provides a nix-shell with all the dependencies required to build the library. 46 | //! If not using nix you'll need to clone the git submodule to get the MNN source code in mnn-sys/vendor first 47 | //! Or you can export the MNN_SRC environment variable to point to the MNN source code. 48 | //! 49 | //! ## Compatibility Chart for current crate 50 | //! | MNN Backend | Compiles | Works | 51 | //! | ----------- | -------- | ----- | 52 | //! | CPU | ✅ | ✅ | 53 | //! | OpenCL | ✅ | ✅ | 54 | //! | Metal | ✅ | ✅ | 55 | //! 
| CoreML | ✅ | 🚸 | 56 | //! | OpenGL | ❌ | ❌ | 57 | //! | Vulkan | ❌ | ❌ | 58 | //! 59 | //! - ✅ - Works 60 | //! - 🚸 - Some models work 61 | //! - ❌ - Doesn't work 62 | 63 | /// Re-export of whole mnn-sys 64 | pub mod ffi { 65 | pub use mnn_sys::*; 66 | } 67 | 68 | mod profile; 69 | 70 | pub mod backend; 71 | /// Error handling 72 | pub mod error; 73 | /// MNN::Interpreter related items 74 | pub mod interpreter; 75 | /// Schedule configuration 76 | pub mod schedule; 77 | /// MNN::Session related items 78 | pub mod session; 79 | /// MNN::Tensor related items 80 | pub mod tensor; 81 | 82 | pub use backend::*; 83 | pub use error::*; 84 | pub use interpreter::*; 85 | pub use schedule::*; 86 | pub use session::*; 87 | pub use tensor::*; 88 | 89 | pub use ffi::HalideType; 90 | pub use ffi::MapType; 91 | 92 | /// Re-export of commonly used items 93 | pub mod prelude { 94 | pub use crate::error::*; 95 | pub(crate) use crate::profile::profile; 96 | pub use core::marker::PhantomData; 97 | pub use error_stack::{Report, ResultExt}; 98 | pub use libc::*; 99 | pub use mnn_sys::{HalideType, MapType}; 100 | } 101 | -------------------------------------------------------------------------------- /tools/bencher/src/cli.rs: -------------------------------------------------------------------------------- 1 | use std::path::PathBuf; 2 | 3 | use chumsky::prelude::*; 4 | // fn parse() -> impl Parser> 5 | // fn models() -> impl Parser> { 6 | // let model = super::ModelIO::parser(); 7 | // let comma = char(',').skip_many1(); 8 | // let models = model.sep_by(comma); 9 | // models 10 | // } 11 | pub enum ModelIOArgs { 12 | Path(PathBuf), 13 | Assert(PathBuf), 14 | InputType(super::DataTypes), 15 | OutputType(super::DataTypes), 16 | } 17 | 18 | // pub fn arg<'a, T: Clone + 'a, E: chumsky::Error<&'a T>>( 19 | // s: T, 20 | // ) -> chumsky::primitive::Just<&'a [T], &'a [T], E> { 21 | // just(&[s]) 22 | // } 23 | macro_rules! 
arg { 24 | ($s:expr) => { 25 | just::<&str, _, Simple<&str>>($s) 26 | }; 27 | } 28 | 29 | fn models<'a>() -> impl Parser<&'a str, Vec, Error = Simple<&'a str>> { 30 | let assert = choice((arg!("--assert"), arg!("-a"))) 31 | .then(path()) 32 | .map(|(_, p)| p); 33 | let data_type = choice(( 34 | arg!("f32").to(super::DataTypes::F32), 35 | arg!("u8").to(super::DataTypes::U8), 36 | )); 37 | let input_type = choice((arg!("--input-type"), arg!("-i"))) 38 | .then(data_type) 39 | .map(|(_, t)| t); 40 | let output_type = choice((arg!("--output-type"), arg!("-o"))) 41 | .then(choice(( 42 | arg!("f32").to(super::DataTypes::F32), 43 | arg!("u8").to(super::DataTypes::U8), 44 | ))) 45 | .map(|(_, t)| t); 46 | let args = choice(( 47 | // path.map(|p| ModelIOArgs::Path(p)), 48 | assert.map(|p| ModelIOArgs::Assert(p)), 49 | input_type.map(|t| ModelIOArgs::InputType(t)), 50 | output_type.map(|t| ModelIOArgs::OutputType(t)), 51 | )) 52 | .repeated(); 53 | let mios = path().then(args).map(|(p, margs)| { 54 | let mut mio = super::ModelIO::default(); 55 | mio.path = p; 56 | margs.into_iter().for_each(|arg| match arg { 57 | ModelIOArgs::Path(p) => mio.path = p, 58 | ModelIOArgs::Assert(p) => mio.assert = Some(p), 59 | ModelIOArgs::InputType(t) => mio.input_type = t, 60 | ModelIOArgs::OutputType(t) => mio.output_type = t, 61 | }); 62 | mio 63 | }); 64 | mios.repeated() 65 | } 66 | 67 | #[derive(Debug, Clone)] 68 | pub enum Flags { 69 | Verbose, 70 | Warmup(u8), 71 | Output(PathBuf), 72 | Exec, 73 | } 74 | fn flags<'a>() -> impl Parser<&'a str, Vec, Error = Simple<&'a str>> { 75 | choice(( 76 | choice((arg!("--verbose"), arg!("-v"))).to(Flags::Verbose), 77 | choice((arg!("--warmup"), arg!("-w"))) 78 | .ignore_then(any().from_str().unwrapped()) 79 | .map(Flags::Warmup), 80 | )) 81 | .repeated() 82 | } 83 | 84 | fn path<'i>() -> impl Parser<&'i str, PathBuf, Error = Simple<&'i str>> { 85 | any().map(|c| PathBuf::from(c)) 86 | } 87 | 88 | impl super::Cli { 89 | pub fn try_from_env() -> 
super::Result<Self> {
        // NOTE(review): work-in-progress — parses argv with the chumsky
        // combinators above, then `dbg!` + `unwrap` + `todo!`; not usable
        // as a real entry point yet.
        let args = std::env::args().collect::<Vec<_>>();
        let args_str = args.iter().map(|i| i.as_str()).collect::<Vec<_>>();

        // Leading token is the binary path; the rest is models or flags.
        let mio = path()
            .then(choice((models().to(()), flags().to(()))))
            .parse(args_str);

        dbg!(mio.unwrap());
        todo!()
    }
}

/// Scalar data types selectable on the command line.
#[derive(Debug, Clone, ValueEnum, Default)]
pub enum DataTypes {
    #[default]
    F32,
    U8,
}

/// One model path plus its per-model I/O options.
#[derive(Debug, Clone, Args, Default)]
pub struct ModelIO {
    path: PathBuf,
    #[clap(short, long)]
    assert: Option<PathBuf>,
    #[clap(short, long, default_value = "f32")]
    input_type: DataTypes,
    #[clap(short, long, default_value = "f32")]
    output_type: DataTypes,
}
impl AsRef<Path> for ModelIO {
    fn as_ref(&self) -> &Path {
        &self.path
    }
}
/// Error types for MNN
///
/// This enum is the `error_stack` context type carried inside [`MNNError`].
#[derive(thiserror::Error, Debug)]
pub enum ErrorKind {
    /// Internal error (from MNN library)
    #[error("Internal error: {0:?}")]
    InternalError(ErrorCode),
    /// Mismatching Size for input
    #[error("Invalid input: expected {expected}, got {got}")]
    SizeMismatch {
        /// Expected size
        expected: usize,
        /// Provided size
        got: usize,
    },
    /// Failed to copy tensor
    ///
    /// Carries the raw status value returned by the MNN copy call
    /// (the wrapper errors when that call returns 0).
    #[error("Failed to copy tensor")]
    TensorCopyFailed(i32),
    /// I/O Error
    #[error("IO Error")]
    IOError,
    /// Interpreter Error
    #[error("Interpreter Error")]
    InterpreterError,
    /// ASCII Error (path, name, etc had invalid characters)
    #[error("Ascii Error")]
    AsciiError,
    /// HalideType mismatch (e.g. trying to convert from a float tensor to an int tensor)
    #[error("HalideType mismatch: got {got}")]
    HalideTypeMismatch {
        /// HalideType that was actually provided
        got: &'static str,
    },
    /// Failed to parse the Argument
    #[error("Parse Error")]
    ParseError,
    /// Error with mnn-sync crate
    #[error("Sync Error")]
    SyncError,
    /// Error with some tensor
    #[error("Tensor Error")]
    TensorError,
    /// Tried to run a dynamic tensor without resizing it first
    #[error("Dynamic Tensor Error: Tensor needs to be resized before using")]
    DynamicTensorError,
}

impl MNNError {
    /// Wrap an [ErrorKind] in a fresh report, capturing the caller's
    /// location (via `#[track_caller]`) in the backtrace.
    #[track_caller]
    #[doc(hidden)]
    pub fn new(kind: ErrorKind) -> Self {
        let kind = error_stack::Report::new(kind);
        Self { kind }
    }

    /// Wrap a raw MNN [ErrorCode] as an internal error.
    #[track_caller]
    pub(crate) fn from_error_code(code: ErrorCode) -> Self {
        Self::new(ErrorKind::InternalError(code))
    }

    /// Return the inner [error_stack::Report] containing the error
    #[inline(always)]
    pub fn into_inner(self) -> error_stack::Report<ErrorKind> {
        self.kind
    }
}

impl From<ErrorKind> for MNNError
{
    #[track_caller]
    fn from(kind: ErrorKind) -> Self {
        Self::new(kind)
    }
}

/// Early-returns an `Err(MNNError)` when `cond` is false.
///
/// Forms:
/// - `ensure!(cond, kind)` — wrap `kind` directly.
/// - `ensure!(cond, kind; printable, ...)` — additionally attach printables.
/// - `ensure!(cond, from, to)` — build a report from `from`, then
///   `change_context(to)`.
/// - `ensure!(cond, from, to; printable, ...)` — same, plus printables.
macro_rules! ensure {
    ($cond:expr, $kind:expr) => {
        if !($cond) {
            return Err(crate::error::MNNError::new($kind));
        }
    };

    ($cond:expr, $kind:expr; $($printable:expr),*) => {
        if !($cond) {
            return Err(crate::error::MNNError::new($kind)
                $(.attach_printable($printable))*
            )
        }
    };


    ($cond:expr, $from:expr, $to:expr) => {
        if (!$cond) {
            return Err(error_stack::Report::new($from).change_context($to));
        }
    };
    ($cond:expr, $from:expr, $to:expr; $($printable:expr),*) => {
        if (!$cond) {
            return Err(error_stack::Report::new($from)
                .change_context($to)
                $(.attach_printable($printable))*
            )
        }
    };
}

/// Constructs an [crate::error::MNNError], optionally re-contextualizing a
/// source error (`error!(kind, from)`).
macro_rules! error {
    ($kind:expr) => {
        crate::error::MNNError::new($kind)
    };
    ($kind:expr, $from:expr) => {
        crate::error::MNNError::from(error_stack::Report::new($from).change_context($kind))
    };
}

pub(crate) use ensure;
pub(crate) use error;

impl From<error_stack::Report<ErrorKind>> for MNNError {
    #[track_caller]
    fn from(report: error_stack::Report<ErrorKind>) -> Self {
        Self { kind: report }
    }
}

impl MNNError {
    // Attach an extra printable to the underlying report (used by `ensure!`).
    pub(crate) fn attach_printable(
        self,
        printable: impl core::fmt::Display + core::fmt::Debug + Send + Sync + 'static,
    ) -> Self {
        let kind = self.kind.attach_printable(printable);
        Self { kind }
    }
}

use mnn::*;
use std::path::PathBuf;

/// CLI for the model-inspection example.
#[derive(Debug, clap::Parser, Clone)]
pub struct Cli {
    model: PathBuf,
    #[clap(short, long)]
    forward: ForwardType,
    #[clap(short, long,
default_value = "high")]
    power: PowerMode,
    #[clap(short = 'P', long, default_value = "high")]
    precision: PrecisionMode,
    #[clap(short, long, default_value = "high")]
    memory: MemoryMode,
    #[clap(short, long, default_value = "f32")]
    output_data_type: DataType,
    #[clap(short, long, default_value = "f32")]
    input_data_type: DataType,
    // Number of inference iterations; 0 now skips the run loop entirely.
    #[clap(short, long, default_value = "1")]
    loops: usize,
    #[clap(short, long)]
    no_cache: bool,
}

/// Element type used when filling inputs and draining outputs.
#[derive(Debug, Clone, clap::ValueEnum)]
pub enum DataType {
    F32,
    U8,
    I8,
}

/// `time!(exprs; label)` evaluates `exprs`, prints the elapsed wall time
/// tagged with `label`, and yields the value of the last expression.
macro_rules! time {
    ($($x:expr),+ ; $text: expr) => {
        {
            let start = std::time::Instant::now();
            let result = { $($x);+ };
            let elapsed = start.elapsed();
            println!("{}: took: {:?}", $text, elapsed);
            result
        }
    };
    ($($x:expr),+) => {
        time!($($x),+; "")
    };
}

/// Loads a model, prints session info, then runs `--loops` inference
/// iterations, filling every input with 1 and waiting on every output.
pub fn main() -> anyhow::Result<()> {
    use clap::Parser;
    let cli = Cli::parse();
    let mut interpreter = Interpreter::from_file(&cli.model)?;
    if !cli.no_cache {
        interpreter.set_cache_file(cli.model.with_extension("cache"), 128)?;
    }

    tracing_subscriber::fmt()
        .event_format(
            tracing_subscriber::fmt::format()
                .with_file(true)
                .with_line_number(true),
        )
        .init();

    let mut config = ScheduleConfig::new();
    config.set_type(cli.forward);
    let mut session = time!(interpreter.create_session(config)?; "create session");
    if !cli.no_cache {
        interpreter.update_cache_file(&mut session)?;
    }

    println!("--------------------------------Info--------------------------------");
    let mem = interpreter.memory(&session)?;
    let flops = interpreter.flops(&session)?;
    println!("Memory: {:?}MiB", mem);
    println!("Flops : {:?}M", flops);
    println!("ResizeStatus : {:?}", interpreter.resize_status(&session)?);

    // BUGFIX: the original `let mut current = 0; loop { ... }` incremented and
    // checked the counter at the *end* of the body, so `--loops 0` still ran a
    // full iteration. `for` over `0..cli.loops` runs exactly `loops` times.
    time!(for _ in 0..cli.loops {
        println!("--------------------------------Inputs--------------------------------");
        interpreter.inputs(&session).iter().for_each(|x| {
            match cli.input_data_type {
                DataType::F32 => {
                    let mut tensor = x.tensor::<f32>().expect("No tensor");
                    println!("{}: {:?}", x.name(), tensor.shape());
                    tensor.fill(1.0f32);
                },
                DataType::U8 => {
                    let mut tensor = x.tensor::<u8>().expect("No tensor");
                    println!("{}: {:?}", x.name(), tensor.shape());
                    tensor.fill(1u8);
                },
                DataType::I8 => {
                    let mut tensor = x.tensor::<i8>().expect("No tensor");
                    println!("{}: {:?}", x.name(), tensor.shape());
                    tensor.fill(1i8);
                },
            };
        });

        println!("Running session");
        interpreter.run_session(&session)?;
        println!("--------------------------------Outputs--------------------------------");
        let outputs = interpreter.outputs(&session);
        outputs.iter().for_each(|x| {
            match cli.output_data_type {
                DataType::F32 => {
                    let tensor = x.tensor::<f32>().expect("No tensor");
                    println!("{}: {:?}", x.name(), tensor.shape());
                    time!(tensor.wait(MapType::MAP_TENSOR_READ, true); format!("Waiting for tensor: {}", x.name()));
                },
                DataType::U8 => {
                    let tensor = x.tensor::<u8>().expect("No tensor");
                    println!("{}: {:?}", x.name(), tensor.shape());
                    time!(tensor.wait(MapType::MAP_TENSOR_READ, true); format!("Waiting for tensor: {}", x.name()));
                },
                DataType::I8 => {
                    let tensor = x.tensor::<i8>().expect("No tensor");
                    println!("{}: {:?}", x.name(), tensor.shape());
                    time!(tensor.wait(MapType::MAP_TENSOR_READ, true); format!("Waiting for tensor: {}", x.name()));
                },
            };
        });
    }; "run loop");
    Ok(())
}
#![deny(missing_docs)]
use crate::{Device, RawTensor, RefMut, Tensor, prelude::*};
use mnn_sys::HalideType;

/// Owned wrapper over the C-side `TensorInfoArray` listing a session's
/// tensors; the array is freed on drop.
#[repr(transparent)]
pub struct TensorList<'t> {
    pub(crate) inner: *const mnn_sys::TensorInfoArray,
    pub(crate) __marker: PhantomData<&'t ()>,
}

impl core::fmt::Debug for TensorList<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_list().entries(self.iter()).finish()
    }
}

impl Drop for TensorList<'_> {
    fn drop(&mut self) {
        // Owns the C array: release it. The const->mut cast is required by the
        // destroy signature — presumably sound because the array was handed out
        // uniquely by the C API; TODO confirm against interpreter_c.cpp.
        unsafe { mnn_sys::destroyTensorInfoArray(self.inner.cast_mut()) }
    }
}

impl<'t> TensorList<'t> {
    // Takes ownership of a C-allocated array pointer; `'t` ties the list to
    // whatever the caller borrowed it from.
    pub(crate) fn from_ptr(inner: *const mnn_sys::TensorInfoArray) -> Self {
        Self {
            inner,
            __marker: PhantomData,
        }
    }

    /// Returns the size of the tensor list
    pub fn size(&self) -> usize {
        unsafe { (*self.inner).size }
    }

    /// Get the tensor at the given index
    ///
    /// Returns `None` when `index` is out of bounds or the C side yields a
    /// null entry.
    pub fn get(&self, index: usize) -> Option<TensorInfo<'t, '_>> {
        if index >= self.size() {
            None
        } else {
            let gtinfo = unsafe { mnn_sys::getTensorInfoArray(self.inner, index) };
            if !gtinfo.is_null() {
                Some(TensorInfo {
                    tensor_info: gtinfo,
                    __marker: PhantomData,
                })
            } else {
                None
            }
        }
    }

    /// Get an iterator over the tensor list
    pub fn iter(&self) -> TensorListIter<'t, '_> {
        TensorListIter {
            tensor_list: self,
            idx: 0,
        }
    }
}

impl<'t, 'tl: 't> IntoIterator for &'tl TensorList<'t> {
    type Item = TensorInfo<'t, 'tl>;
    type IntoIter = TensorListIter<'t, 'tl>;

    fn into_iter(self) -> Self::IntoIter {
        TensorListIter {
            tensor_list: self,
            idx: 0,
        }
    }
}

/// Index-based iterator over a [`TensorList`].
pub struct TensorListIter<'t, 'tl> {
    tensor_list: &'tl TensorList<'t>,
    idx: usize,
}
impl<'t, 'tl>
Iterator for TensorListIter<'t, 'tl> {
    type Item = TensorInfo<'t, 'tl>;
    // Delegates to `TensorList::get`, which returns `None` past the end.
    fn next(&mut self) -> Option<Self::Item> {
        let idx = self.idx;
        self.idx += 1;
        self.tensor_list.get(idx)
    }
}

/// Borrowed metadata (name + tensor pointer) for one entry of a
/// [`TensorList`].
#[repr(transparent)]
pub struct TensorInfo<'t, 'tl> {
    pub(crate) tensor_info: *mut mnn_sys::TensorInfo,
    pub(crate) __marker: PhantomData<&'tl TensorList<'t>>,
}

impl core::fmt::Debug for TensorInfo<'_, '_> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let tensor = self.raw_tensor();
        let shape = tensor.shape();
        f.debug_struct("TensorInfo")
            .field("name", &self.name())
            .field("tensor", &shape)
            .finish()
    }
}

impl<'t, 'tl> TensorInfo<'t, 'tl> {
    /// Name of the tensor as recorded by MNN.
    ///
    /// # Panics
    /// Panics (`expect`) if the C-side name is not valid UTF-8.
    pub fn name(&self) -> &'tl str {
        debug_assert!(!self.tensor_info.is_null());
        unsafe { (*self.tensor_info).name.to_cstr() }
            .to_str()
            .expect("Tensor name is not utf-8")
    }

    /// Typed, borrowed handle to the underlying tensor.
    ///
    /// # Errors
    /// - [ErrorKind::DynamicTensorError] if the shape still contains `-1`
    ///   (tensor not yet resized).
    /// - [ErrorKind::HalideTypeMismatch] if `H` does not match the tensor's
    ///   element type.
    pub fn tensor<H: HalideType>(&self) -> Result<Tensor<RefMut<'t, Device<H>>>> {
        debug_assert!(!self.tensor_info.is_null());
        unsafe { debug_assert!(!(*self.tensor_info).tensor.is_null()) };
        let tensor = unsafe { Tensor::from_ptr((*self.tensor_info).tensor.cast()) };
        let shape = tensor.shape();
        ensure!(!shape.as_ref().contains(&-1), ErrorKind::DynamicTensorError);
        ensure!(
            tensor.is_type_of::<H>(),
            ErrorKind::HalideTypeMismatch {
                got: std::any::type_name::<H>(),
            }
        );
        Ok(tensor)
    }

    /// # Safety
    /// The shape is not checked, so further calls into the interpreter with
    /// this tensor might be **unsafe**; marked unsafe for that reason.
    pub unsafe fn tensor_unresized<H: HalideType>(&self) -> Result<Tensor<RefMut<'t, Device<H>>>> {
        debug_assert!(!self.tensor_info.is_null());
        unsafe { debug_assert!(!(*self.tensor_info).tensor.is_null()) };
        let tensor = unsafe { Tensor::from_ptr((*self.tensor_info).tensor.cast()) };
        ensure!(
            tensor.is_type_of::<H>(),
            ErrorKind::HalideTypeMismatch {
got: std::any::type_name::<H>(),
            }
        );
        Ok(tensor)
    }

    /// This function return's the raw tensor without any sort of
    /// type-checking or shape-checking
    pub fn raw_tensor(&self) -> RawTensor<'t> {
        debug_assert!(!self.tensor_info.is_null());
        unsafe { debug_assert!(!(*self.tensor_info).tensor.is_null()) };
        RawTensor::from_ptr(unsafe { (*self.tensor_info).tensor.cast() })
    }
}

use crate::prelude::*;
use core::marker::PhantomData;
use mnn_sys::HalideType;
/// A raw tensor type that doesn't have any guarantees
/// and will be unconditionally dropped
#[repr(transparent)]
pub struct RawTensor<'r> {
    pub(crate) inner: *mut mnn_sys::Tensor,
    pub(crate) __marker: PhantomData<&'r ()>,
}

// NOTE(review): Drop is deliberately not implemented — callers must use
// `destroy()` when they own the tensor; this earlier impl is kept as record.
// impl<'r> core::ops::Drop for RawTensor<'r> {
//     fn drop(&mut self) {
//         unsafe {
//             mnn_sys::Tensor_destroy(self.inner);
//         }
//     }
// }

impl RawTensor<'_> {
    /// Creates a new host tensor from the device tensor
    ///
    /// # Panics
    /// Asserts that the C allocation succeeded (non-null).
    pub fn create_host_tensor_from_device(&self, copy_data: bool) -> RawTensor<'static> {
        let tensor =
            unsafe { mnn_sys::Tensor_createHostTensorFromDevice(self.inner, copy_data as i32) };
        // crate::ensure!(!tensor.is_null(), ErrorKind::TensorError);
        assert!(!tensor.is_null());
        RawTensor {
            inner: tensor,
            __marker: PhantomData,
        }
    }

    /// Copies the data from a host tensor to the self tensor
    ///
    /// Errors with [ErrorKind::TensorCopyFailed] when the C call returns 0.
    pub fn copy_from_host_tensor(&mut self, tensor: &RawTensor) -> Result<()> {
        let ret = unsafe { mnn_sys::Tensor_copyFromHostTensor(self.inner, tensor.inner) };
        crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret));
        Ok(())
    }
    /// Copies the data from the self tensor to a host tensor
    ///
    /// Errors with [ErrorKind::TensorCopyFailed] (carrying the raw return
    /// value) when the underlying C call reports failure by returning 0.
    pub fn copy_to_host_tensor(&self, tensor: &mut RawTensor) -> Result<()> {
        let ret = unsafe { mnn_sys::Tensor_copyToHostTensor(self.inner, tensor.inner) };
        crate::ensure!(ret != 0, ErrorKind::TensorCopyFailed(ret));
        Ok(())
    }

    /// Returns the shape of the tensor
    pub fn shape(&self) -> crate::TensorShape {
        unsafe { mnn_sys::Tensor_shape(self.inner) }.into()
    }

    /// Returns the dimension type of the tensor
    pub fn get_dimension_type(&self) -> super::DimensionType {
        debug_assert!(!self.inner.is_null());
        From::from(unsafe { mnn_sys::Tensor_getDimensionType(self.inner) })
    }

    /// Cleans up the tensor by calling the destructor of the tensor
    ///
    /// NOTE(review): `RawTensor` has no `Drop` impl, so skipping this leaks
    /// the underlying MNN tensor — presumably intentional for borrowed
    /// pointers; confirm at each call site.
    pub fn destroy(self) {
        unsafe {
            mnn_sys::Tensor_destroy(self.inner);
        }
    }

    /// Returns the size of the tensor when counted by bytes
    pub fn size(&self) -> usize {
        unsafe { mnn_sys::Tensor_usize(self.inner) }
    }

    /// Returns the size of the tensor when counted by elements
    pub fn element_size(&self) -> usize {
        unsafe { mnn_sys::Tensor_elementSize(self.inner) as usize }
    }

    /// Returns the number of dimensions of the tensor
    pub fn dimensions(&self) -> usize {
        unsafe { mnn_sys::Tensor_dimensions(self.inner) as usize }
    }

    /// Returns the width of the tensor
    pub fn width(&self) -> u32 {
        unsafe { mnn_sys::Tensor_width(self.inner) as u32 }
    }

    /// Returns the height of the tensor
    pub fn height(&self) -> u32 {
        unsafe { mnn_sys::Tensor_height(self.inner) as u32 }
    }

    /// Returns the channel of the tensor
    pub fn channel(&self) -> u32 {
        unsafe { mnn_sys::Tensor_channel(self.inner) as u32 }
    }

    /// Returns true if the tensor is unsized and dynamic (needs to be resized to work)
    ///
    /// A `-1` in any dimension marks the tensor as dynamic/unresized.
    pub fn is_dynamic_unsized(&self) -> bool {
        self.shape().as_ref().contains(&-1)
    }
    /// Waits for the tensor to be ready
    pub fn wait(&self, map_type: MapType, finish: bool) {
        unsafe {
            mnn_sys::Tensor_wait(self.inner, map_type, finish as i32);
        }
    }

    /// # Safety
    /// This is very unsafe do not use this unless you know what you are doing
    /// Gives a raw pointer to the tensor's data
    pub unsafe fn unchecked_host_ptr(&self) -> *mut c_void {
        debug_assert!(!self.inner.is_null());
        let data = unsafe { mnn_sys::Tensor_host_mut(self.inner) };
        debug_assert!(!data.is_null());
        data
    }

    /// # Safety
    /// This is very unsafe do not use this unless you know what you are doing
    /// Gives a mutable byte slice to the tensor's data
    pub unsafe fn unchecked_host_bytes(&mut self) -> &mut [u8] {
        // SAFETY (caller): host pointer must reference at least `size()`
        // valid, exclusively-borrowed bytes.
        unsafe { core::slice::from_raw_parts_mut(self.unchecked_host_ptr().cast(), self.size()) }
    }

    /// # Safety
    /// This is very unsafe do not use this unless you know what you are doing
    /// Reinterprets this raw tensor as a typed [super::Tensor] with no
    /// type or shape verification.
    pub unsafe fn to_concrete<T: super::TensorType>(self) -> super::Tensor<T>
    where
        T::H: HalideType,
    {
        unsafe { super::Tensor::from_ptr(self.inner) }
    }

    // Wrap an existing C tensor pointer without taking ownership semantics.
    pub(crate) fn from_ptr(inner: *mut mnn_sys::Tensor) -> Self {
        Self {
            inner,
            __marker: PhantomData,
        }
    }
}

use error_stack::*;
use ndarray::*;

/// Context/error marker type for mnn <-> ndarray conversions.
#[derive(Debug)]
pub struct MnnBridge;
impl Context for MnnBridge {}
impl core::fmt::Display for MnnBridge {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        write!(f, "MnnBridgeError")
    }
}

/// Borrow an MNN tensor as an ndarray view.
pub trait MnnToNdarray {
    type H: mnn::HalideType;
    /// Panicking convenience wrapper around [`Self::try_as_ndarray`].
    fn as_ndarray(&self) -> ndarray::ArrayViewD<'_, Self::H> {
self.try_as_ndarray::() 17 | .expect("Failed to create ndarray::ArrayViewD from mnn::Tensor") 18 | } 19 | fn try_as_ndarray(&self) -> Result, MnnBridge>; 20 | } 21 | 22 | pub trait MnnToNdarrayMut { 23 | type H: mnn::HalideType; 24 | fn as_ndarray_mut(&mut self) -> ndarray::ArrayViewMut { 25 | self.try_as_ndarray_mut::() 26 | .expect("Failed to create ndarray::ArrayViewMutD from mnn::Tensor") 27 | } 28 | fn try_as_ndarray_mut( 29 | &mut self, 30 | ) -> Result, MnnBridge>; 31 | } 32 | 33 | pub trait NdarrayToMnn { 34 | type H: mnn::HalideType; 35 | fn as_mnn_tensor(&self) -> Option>>>; 36 | } 37 | 38 | pub trait NdarrayToMnnMut { 39 | type H: mnn::HalideType; 40 | fn as_mnn_tensor_mut(&mut self) -> Option>>>; 41 | } 42 | 43 | const _: () = { 44 | impl MnnToNdarray for mnn::Tensor 45 | where 46 | T: mnn::TensorType + mnn::HostTensorType, 47 | T::H: mnn::HalideType, 48 | { 49 | type H = T::H; 50 | fn try_as_ndarray( 51 | &self, 52 | ) -> Result, MnnBridge> { 53 | let shape = self 54 | .shape() 55 | .as_ref() 56 | .into_iter() 57 | .copied() 58 | .map(|i| i as usize) 59 | .collect::>(); 60 | let data = self.host(); 61 | Ok(ndarray::ArrayViewD::from_shape(shape, data) 62 | .change_context(MnnBridge)? 63 | .into_dimensionality() 64 | .change_context(MnnBridge)?) 65 | } 66 | } 67 | 68 | impl MnnToNdarrayMut for mnn::Tensor 69 | where 70 | T: mnn::TensorType + mnn::MutableTensorType + mnn::HostTensorType, 71 | T::H: mnn::HalideType, 72 | { 73 | type H = T::H; 74 | fn try_as_ndarray_mut( 75 | &mut self, 76 | ) -> Result, MnnBridge> { 77 | let shape = self 78 | .shape() 79 | .as_ref() 80 | .into_iter() 81 | .copied() 82 | .map(|i| i as usize) 83 | .collect::>(); 84 | let data = self.host_mut(); 85 | Ok(ndarray::ArrayViewMutD::from_shape(shape, data) 86 | .change_context(MnnBridge)? 87 | .into_dimensionality() 88 | .change_context(MnnBridge)?) 
89 | } 90 | } 91 | 92 | impl NdarrayToMnn for ndarray::ArrayBase 93 | where 94 | A: ndarray::Data, 95 | D: ndarray::Dimension, 96 | T: mnn::HalideType, 97 | { 98 | type H = T; 99 | fn as_mnn_tensor(&self) -> Option>>> { 100 | let shape = self.shape().iter().map(|i| *i as i32).collect::>(); 101 | let data = self.as_slice()?; 102 | Some(mnn::Tensor::borrowed(shape, data)) 103 | } 104 | } 105 | 106 | impl NdarrayToMnnMut for ndarray::ArrayBase 107 | where 108 | A: ndarray::DataMut, 109 | D: ndarray::Dimension, 110 | T: mnn::HalideType, 111 | { 112 | type H = T; 113 | fn as_mnn_tensor_mut(&mut self) -> Option>>> { 114 | let shape = self.shape().iter().map(|i| *i as i32).collect::>(); 115 | let data = self.as_slice_mut()?; 116 | Some(mnn::Tensor::borrowed_mut(shape, data)) 117 | } 118 | } 119 | }; 120 | #[test] 121 | pub fn test_tensor_to_ndarray_ref() { 122 | let mut tensor: mnn::Tensor> = 123 | mnn::Tensor::new([1, 2, 3], mnn::DimensionType::Caffe); 124 | tensor.fill(64); 125 | let ndarr = tensor.as_ndarray(); 126 | let ndarr_other = ndarray::Array3::from_shape_vec([1, 2, 3], [64; 6].to_vec()).unwrap(); 127 | assert_eq!(ndarr, ndarr_other); 128 | } 129 | #[test] 130 | pub fn test_tensor_to_ndarray_ref_mut() { 131 | let mut data = vec![100; 8 * 8 * 3]; 132 | let mut tensor: mnn::Tensor>> = 133 | mnn::Tensor::borrowed_mut([8, 8, 3], &mut data); 134 | let mut ndarray = tensor.as_ndarray_mut::(); 135 | ndarray.fill(600); 136 | assert_eq!(data, [600; 8 * 8 * 3]); 137 | } 138 | #[test] 139 | pub fn test_ndarray_to_tensor_ref_mut() { 140 | let mut arr = ndarray::Array3::from_shape_vec([1, 2, 3], [64; 6].to_vec()).unwrap(); 141 | arr.as_mnn_tensor_mut().unwrap().fill(600); 142 | assert_eq!(arr.as_slice().unwrap(), &[600; 6]); 143 | } 144 | #[test] 145 | pub fn test_ndarray_to_tensor_ref() { 146 | let arr = ndarray::Array3::from_shape_vec([1, 2, 3], [64; 6].to_vec()).unwrap(); 147 | let t = arr.as_mnn_tensor().unwrap(); 148 | assert_eq!(t.host(), &[64; 6]); 149 | } 150 | 
// This is mostly adapted from tracing-gstreamer crate's implementation
use once_cell::sync::OnceCell;
use std::sync::atomic::AtomicUsize;
use std::sync::{PoisonError, RwLock};
use std::{collections::BTreeMap, ffi::c_char};
use tracing_core::{field::FieldSet, identify_callsite, Callsite, Interest, Kind, Metadata};

// Three-state encoding of `tracing_core::Interest`, stored in an atomic so
// it can be read lock-free on the event hot path.
pub const CALLSITE_INTEREST_NEVER: usize = 1;
pub const CALLSITE_INTEREST_SOMETIMES: usize = 2;
pub const CALLSITE_INTEREST_ALWAYS: usize = 3;

/// Severity levels the C++ side can emit (layout must match the C header —
/// `#[repr(C)]`).
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)]
pub enum Level {
    Info = 0,
    Error = 1,
}

impl From<Level> for tracing_core::Level {
    fn from(value: Level) -> Self {
        match value {
            Level::Info => tracing_core::Level::INFO,
            Level::Error => tracing_core::Level::ERROR,
        }
    }
}

/// Process-wide cache of lazily created callsites, keyed by
/// (level, line, file).
pub struct DynamicCallsites {
    callsites: RwLock<Map>,
}

type Map = BTreeMap<Key<'static>, &'static MnnCallsite>;

impl DynamicCallsites {
    pub(crate) fn get() -> &'static Self {
        static MAP: OnceCell<DynamicCallsites> = OnceCell::new();
        MAP.get_or_init(|| DynamicCallsites {
            callsites: RwLock::new(Map::new()),
        })
    }

    /// Returns the cached callsite for (level, line, file), creating and
    /// registering a new one on first use.
    fn callsite_for(
        &'static self,
        level: Level,
        line: Option<u32>,
        file: Option<&'static str>,
    ) -> &'static MnnCallsite {
        // Lock poisoning is ignored deliberately: the map stays usable even
        // if a writer panicked mid-insert.
        let mut guard = self
            .callsites
            .write()
            .unwrap_or_else(PoisonError::into_inner);
        let lookup_key = Key { level, line, file };
        if let Some(callsite) = guard.get(&lookup_key) {
            return callsite;
        }
        let callsite = MnnCallsite::make_static(&lookup_key);
        // Re-key on the metadata's own copy of `file` so the stored key
        // borrows from the leaked (truly 'static) callsite.
        let key = Key::<'static> {
            level,
            line,
            file: callsite.metadata.file(),
        };
        guard.insert(key, callsite);
tracing_core::callsite::register(callsite); 64 | callsite 65 | } 66 | } 67 | 68 | impl Callsite for MnnCallsite { 69 | fn set_interest(&self, interest: Interest) { 70 | self.interest.store( 71 | match () { 72 | _ if interest.is_never() => CALLSITE_INTEREST_NEVER, 73 | _ if interest.is_always() => CALLSITE_INTEREST_ALWAYS, 74 | _ => CALLSITE_INTEREST_SOMETIMES, 75 | }, 76 | std::sync::atomic::Ordering::Release, 77 | ); 78 | } 79 | 80 | fn metadata(&self) -> &Metadata<'_> { 81 | &self.metadata 82 | } 83 | } 84 | #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] 85 | pub struct Key<'k> { 86 | level: Level, 87 | line: Option, 88 | file: Option<&'k str>, 89 | } 90 | 91 | impl DynamicCallsites {} 92 | 93 | pub struct MnnCallsite { 94 | interest: AtomicUsize, 95 | metadata: Metadata<'static>, 96 | } 97 | 98 | impl MnnCallsite { 99 | pub fn make_static(key: &Key<'static>) -> &'static Self { 100 | unsafe { 101 | use std::alloc::GlobalAlloc as _; 102 | let callsite_layout = std::alloc::Layout::new::(); 103 | let alloc = std::alloc::System.alloc(callsite_layout); 104 | let callsite = alloc as *mut MnnCallsite; 105 | // No allocation for string required as they are static by default 106 | callsite.write(MnnCallsite { 107 | interest: AtomicUsize::new(0), 108 | metadata: Metadata::new( 109 | "", 110 | "mnn_ffi_emit", 111 | key.level.into(), 112 | key.file, 113 | key.line, 114 | None, 115 | FieldSet::new(&["message"], identify_callsite!(&*callsite)), 116 | Kind::EVENT, 117 | ), 118 | }); 119 | &*callsite 120 | } 121 | } 122 | 123 | pub(crate) fn interest(&self) -> Interest { 124 | match self.interest.load(std::sync::atomic::Ordering::Acquire) { 125 | CALLSITE_INTEREST_NEVER => Interest::never(), 126 | CALLSITE_INTEREST_SOMETIMES => Interest::sometimes(), 127 | CALLSITE_INTEREST_ALWAYS => Interest::always(), 128 | _ => panic!("attempting to obtain callsite's interest before its been set"), 129 | } 130 | } 131 | } 132 | 133 | #[no_mangle] 134 | extern "C" fn 
mnn_ffi_emit(
    file: *const c_char,
    line: libc::size_t,
    level: Level,
    message: *const c_char,
) {
    // Tracing bridge called from the C++ side (see the mnn-tracing patch).
    //
    // NOTE(review): `file` and `message` are assumed to be valid,
    // NUL-terminated C strings — a null pointer here would be UB in
    // `CStr::from_ptr`; confirm the C side can never pass null. Extending
    // `file` to `&'static str` likewise assumes the C side passes string
    // literals (`__FILE__`) — confirm.
    //
    // Everything runs inside `catch_unwind` because a panic must not unwind
    // across the `extern "C"` boundary (that would be UB); on panic, abort.
    std::panic::catch_unwind(|| {
        let file: &'static str = unsafe {
            core::ffi::CStr::from_ptr(file)
                .to_str()
                .expect("Invalid filename for C file")
        };

        // Look up (or lazily create + register) the callsite for this
        // (level, line, file) triple.
        let callsite = DynamicCallsites::get().callsite_for(level, Some(line as u32), Some(file));
        let interest = callsite.interest();
        if interest.is_never() {
            return;
        }
        let meta = callsite.metadata();
        tracing_core::dispatcher::get_default(move |dispatcher| {
            if !dispatcher.enabled(meta) {
                return;
            }
            let fields = meta.fields();
            let message = unsafe {
                std::ffi::CStr::from_ptr(message)
                    .to_str()
                    .expect("Invalid message for C message")
            };

            // The callsite declared exactly one field ("message"); bind the
            // C string to it and emit the event.
            let message_value = &message as &dyn tracing_core::field::Value;
            let message_field = fields
                .into_iter()
                .next()
                .expect("Failed to get message field");
            let values = &[(&message_field, Some(message_value))];
            let valueset = fields.value_set(values);

            let event = tracing_core::Event::new(meta, &valueset);

            dispatcher.event(&event);
        });
    })
    .unwrap_or_else(|_e| {
        eprintln!("Panic in mnn_ffi_emit aborting");
        // Cannot let the panic escape the ffi boundary
        std::process::abort();
    })
}
"original": { 14 | "owner": "rustsec", 15 | "repo": "advisory-db", 16 | "type": "github" 17 | } 18 | }, 19 | "crane": { 20 | "locked": { 21 | "lastModified": 1748970125, 22 | "narHash": "sha256-UDyigbDGv8fvs9aS95yzFfOKkEjx1LO3PL3DsKopohA=", 23 | "owner": "ipetkov", 24 | "repo": "crane", 25 | "rev": "323b5746d89e04b22554b061522dfce9e4c49b18", 26 | "type": "github" 27 | }, 28 | "original": { 29 | "owner": "ipetkov", 30 | "repo": "crane", 31 | "type": "github" 32 | } 33 | }, 34 | "flake-utils": { 35 | "inputs": { 36 | "systems": "systems" 37 | }, 38 | "locked": { 39 | "lastModified": 1731533236, 40 | "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", 41 | "owner": "numtide", 42 | "repo": "flake-utils", 43 | "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", 44 | "type": "github" 45 | }, 46 | "original": { 47 | "owner": "numtide", 48 | "repo": "flake-utils", 49 | "type": "github" 50 | } 51 | }, 52 | "flake-utils_2": { 53 | "inputs": { 54 | "systems": "systems_2" 55 | }, 56 | "locked": { 57 | "lastModified": 1731533236, 58 | "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", 59 | "owner": "numtide", 60 | "repo": "flake-utils", 61 | "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", 62 | "type": "github" 63 | }, 64 | "original": { 65 | "owner": "numtide", 66 | "repo": "flake-utils", 67 | "type": "github" 68 | } 69 | }, 70 | "mnn": { 71 | "flake": false, 72 | "locked": { 73 | "lastModified": 1749173738, 74 | "narHash": "sha256-pNljvQ4xMZ4VmuxQyXt+boNBZD0+UZNpNLrWrj8Rtfw=", 75 | "owner": "alibaba", 76 | "repo": "MNN", 77 | "rev": "ebdada82634300956e08bd4056ecfeb1e4f23b32", 78 | "type": "github" 79 | }, 80 | "original": { 81 | "owner": "alibaba", 82 | "ref": "3.2.0", 83 | "repo": "MNN", 84 | "type": "github" 85 | } 86 | }, 87 | "mnn-overlay": { 88 | "inputs": { 89 | "flake-utils": "flake-utils_2", 90 | "mnn": "mnn", 91 | "nixpkgs": [ 92 | "nixpkgs" 93 | ] 94 | }, 95 | "locked": { 96 | "lastModified": 1749204972, 97 | "narHash": 
"sha256-ICLU408iwxZA7uETBmEBuuForBIPLvonuy1hW/fuiME=", 98 | "owner": "uttarayan21", 99 | "repo": "mnn-nix-overlay", 100 | "rev": "7b97393977689e851a6840a8e1cbea058e67363a", 101 | "type": "github" 102 | }, 103 | "original": { 104 | "owner": "uttarayan21", 105 | "repo": "mnn-nix-overlay", 106 | "type": "github" 107 | } 108 | }, 109 | "mnn-src": { 110 | "flake": false, 111 | "locked": { 112 | "lastModified": 1749173738, 113 | "narHash": "sha256-pNljvQ4xMZ4VmuxQyXt+boNBZD0+UZNpNLrWrj8Rtfw=", 114 | "owner": "alibaba", 115 | "repo": "MNN", 116 | "rev": "ebdada82634300956e08bd4056ecfeb1e4f23b32", 117 | "type": "github" 118 | }, 119 | "original": { 120 | "owner": "alibaba", 121 | "ref": "3.2.0", 122 | "repo": "MNN", 123 | "type": "github" 124 | } 125 | }, 126 | "nix-github-actions": { 127 | "inputs": { 128 | "nixpkgs": [ 129 | "nixpkgs" 130 | ] 131 | }, 132 | "locked": { 133 | "lastModified": 1737420293, 134 | "narHash": "sha256-F1G5ifvqTpJq7fdkT34e/Jy9VCyzd5XfJ9TO8fHhJWE=", 135 | "owner": "nix-community", 136 | "repo": "nix-github-actions", 137 | "rev": "f4158fa080ef4503c8f4c820967d946c2af31ec9", 138 | "type": "github" 139 | }, 140 | "original": { 141 | "owner": "nix-community", 142 | "repo": "nix-github-actions", 143 | "type": "github" 144 | } 145 | }, 146 | "nixpkgs": { 147 | "locked": { 148 | "lastModified": 1748929857, 149 | "narHash": "sha256-lcZQ8RhsmhsK8u7LIFsJhsLh/pzR9yZ8yqpTzyGdj+Q=", 150 | "owner": "nixos", 151 | "repo": "nixpkgs", 152 | "rev": "c2a03962b8e24e669fb37b7df10e7c79531ff1a4", 153 | "type": "github" 154 | }, 155 | "original": { 156 | "owner": "nixos", 157 | "ref": "nixos-unstable", 158 | "repo": "nixpkgs", 159 | "type": "github" 160 | } 161 | }, 162 | "root": { 163 | "inputs": { 164 | "advisory-db": "advisory-db", 165 | "crane": "crane", 166 | "flake-utils": "flake-utils", 167 | "mnn-overlay": "mnn-overlay", 168 | "mnn-src": "mnn-src", 169 | "nix-github-actions": "nix-github-actions", 170 | "nixpkgs": "nixpkgs", 171 | "rust-overlay": "rust-overlay" 
172 | } 173 | }, 174 | "rust-overlay": { 175 | "inputs": { 176 | "nixpkgs": [ 177 | "nixpkgs" 178 | ] 179 | }, 180 | "locked": { 181 | "lastModified": 1749177458, 182 | "narHash": "sha256-9HNq3EHZIvvxXQyEn0sYOywcESF1Xqw2Q8J1ZewcXuk=", 183 | "owner": "oxalica", 184 | "repo": "rust-overlay", 185 | "rev": "d58933b88cef7a05e9677e94352fd6fedba402cd", 186 | "type": "github" 187 | }, 188 | "original": { 189 | "owner": "oxalica", 190 | "repo": "rust-overlay", 191 | "type": "github" 192 | } 193 | }, 194 | "systems": { 195 | "locked": { 196 | "lastModified": 1681028828, 197 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 198 | "owner": "nix-systems", 199 | "repo": "default", 200 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 201 | "type": "github" 202 | }, 203 | "original": { 204 | "owner": "nix-systems", 205 | "repo": "default", 206 | "type": "github" 207 | } 208 | }, 209 | "systems_2": { 210 | "locked": { 211 | "lastModified": 1681028828, 212 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 213 | "owner": "nix-systems", 214 | "repo": "default", 215 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 216 | "type": "github" 217 | }, 218 | "original": { 219 | "owner": "nix-systems", 220 | "repo": "default", 221 | "type": "github" 222 | } 223 | } 224 | }, 225 | "root": "root", 226 | "version": 7 227 | } 228 | -------------------------------------------------------------------------------- /mnn-sys/mnn_c/tensor_c.cpp: -------------------------------------------------------------------------------- 1 | #include "tensor_c.h" 2 | #include "MNN/Tensor.hpp" 3 | #include "utils.h" 4 | #include 5 | #ifdef __DEBUG 6 | #include 7 | void code_bits_lanes(const char *name, halide_type_t *type) { 8 | printf("====================================\n"); 9 | printf("sizes: \n"); 10 | std::cout << "code: " << sizeof(type->code) << std::endl; 11 | std::cout << "bits: " << sizeof(type->bits) << std::endl; 12 | std::cout << "lanes: " << 
sizeof(type->lanes) << std::endl; 13 | printf("%s: cbt %d %d %d\n", name, type->code, type->bits, type->lanes); 14 | printf("sizeof(%s): %lu\n", name, sizeof(*type)); 15 | printf("====================================\n"); 16 | } 17 | #endif 18 | extern "C" { 19 | Tensor *Tensor_create(int dimSize, DimensionType type) { 20 | return reinterpret_cast( 21 | new MNN::Tensor(dimSize, static_cast(type))); 22 | } 23 | Tensor *Tensor_createFromTensor(const Tensor *tensor, DimensionType type, 24 | int allocMemory) { 25 | return reinterpret_cast(new MNN::Tensor( 26 | reinterpret_cast(tensor), 27 | static_cast(type), allocMemory)); 28 | } 29 | void Tensor_destroy(Tensor *tensor) { 30 | delete reinterpret_cast(tensor); 31 | } 32 | Tensor *Tensor_createDevice(const int *shape, size_t shapeSize, 33 | halide_type_t typeCode, DimensionType dimType) { 34 | std::vector shapeVec(shape, shape + shapeSize); 35 | return reinterpret_cast(MNN::Tensor::createDevice( 36 | shapeVec, typeCode, static_cast(dimType))); 37 | } 38 | Tensor *Tensor_createWith(const int *shape, size_t shapeSize, 39 | halide_type_t typeCode, void *data, 40 | DimensionType dimType) { 41 | std::vector shapeVec(shape, shape + shapeSize); 42 | auto mnn_tensor = 43 | MNN::Tensor::create(shapeVec, typeCode, data, 44 | static_cast(dimType)); 45 | return reinterpret_cast(mnn_tensor); 46 | } 47 | 48 | int Tensor_copyFromHostTensor(Tensor *deviceTensor, const Tensor *hostTensor) { 49 | return reinterpret_cast(deviceTensor) 50 | ->copyFromHostTensor(reinterpret_cast(hostTensor)); 51 | } 52 | int Tensor_copyToHostTensor(const Tensor *deviceTensor, Tensor *hostTensor) { 53 | return reinterpret_cast(deviceTensor) 54 | ->copyToHostTensor(reinterpret_cast(hostTensor)); 55 | } 56 | Tensor *Tensor_createHostTensorFromDevice(const Tensor *deviceTensor, 57 | int copyData) { 58 | return reinterpret_cast(MNN::Tensor::createHostTensorFromDevice( 59 | reinterpret_cast(deviceTensor), copyData)); 60 | } 61 | const void *Tensor_host(const 
Tensor *tensor) { 62 | return reinterpret_cast(tensor)->host(); 63 | } 64 | 65 | void *Tensor_host_mut(Tensor *tensor) { 66 | return reinterpret_cast(tensor)->host(); 67 | } 68 | 69 | uint64_t Tensor_deviceId(const Tensor *tensor) { 70 | return reinterpret_cast(tensor)->deviceId(); 71 | } 72 | 73 | int Tensor_dimensions(const Tensor *tensor) { 74 | return reinterpret_cast(tensor)->dimensions(); 75 | } 76 | /** 77 | * @brief get all dimensions' extent. 78 | * @return dimensions' extent. 79 | */ 80 | TensorShape Tensor_shape(const Tensor *tensor) { 81 | auto shapeVec = reinterpret_cast(tensor)->shape(); 82 | TensorShape shape; 83 | shape.size = shapeVec.size(); 84 | for (size_t i = 0; i < shapeVec.size(); i++) { 85 | shape.shape[i] = shapeVec[i]; 86 | } 87 | return shape; 88 | } 89 | 90 | int Tensor_size(const Tensor *tensor) { 91 | return reinterpret_cast(tensor)->size(); 92 | } 93 | size_t Tensor_usize(const Tensor *tensor) { 94 | return reinterpret_cast(tensor)->usize(); 95 | } 96 | int Tensor_elementSize(const Tensor *tensor) { 97 | return reinterpret_cast(tensor)->elementSize(); 98 | } 99 | int Tensor_width(const Tensor *tensor) { 100 | return reinterpret_cast(tensor)->width(); 101 | } 102 | int Tensor_height(const Tensor *tensor) { 103 | return reinterpret_cast(tensor)->height(); 104 | } 105 | int Tensor_channel(const Tensor *tensor) { 106 | return reinterpret_cast(tensor)->channel(); 107 | } 108 | int Tensor_batch(const Tensor *tensor) { 109 | return reinterpret_cast(tensor)->batch(); 110 | } 111 | int Tensor_stride(const Tensor *tensor, int index) { 112 | return reinterpret_cast(tensor)->stride(index); 113 | } 114 | int Tensor_length(const Tensor *tensor, int index) { 115 | return reinterpret_cast(tensor)->length(index); 116 | } 117 | void Tensor_setStride(Tensor *tensor, int index, int stride) { 118 | reinterpret_cast(tensor)->setStride(index, stride); 119 | } 120 | void Tensor_setLength(Tensor *tensor, int index, int length) { 121 | 
reinterpret_cast(tensor)->setLength(index, length); 122 | } 123 | 124 | int Tensor_getDeviceInfo(const Tensor *tensor, void *dst, int forwardType) { 125 | return reinterpret_cast(tensor)->getDeviceInfo( 126 | dst, forwardType); 127 | } 128 | void Tensor_print(const Tensor *tensor) { 129 | reinterpret_cast(tensor)->print(); 130 | } 131 | void Tensor_printShape(const Tensor *tensor) { 132 | reinterpret_cast(tensor)->printShape(); 133 | } 134 | void *Tensor_map(Tensor *tensor, MapType mtype, DimensionType dtype) { 135 | return reinterpret_cast(tensor)->map( 136 | static_cast(mtype), 137 | static_cast(dtype)); 138 | } 139 | void Tensor_unmap(Tensor *tensor, MapType mtype, DimensionType dtype, 140 | void *mapPtr) { 141 | reinterpret_cast(tensor)->unmap( 142 | static_cast(mtype), 143 | static_cast(dtype), mapPtr); 144 | } 145 | int Tensor_wait(Tensor *tensor, MapType mtype, int finish) { 146 | return reinterpret_cast(tensor)->wait( 147 | static_cast(mtype), finish); 148 | } 149 | int Tensor_setDevicePtr(Tensor *tensor, const void *devicePtr, int memoryType) { 150 | return reinterpret_cast(tensor)->setDevicePtr(devicePtr, 151 | memoryType); 152 | } 153 | 154 | const halide_buffer_t *Tensor_buffer(const Tensor *tensor) { 155 | return &reinterpret_cast(tensor)->buffer(); 156 | } 157 | 158 | halide_buffer_t *Tensor_buffer_mut(Tensor *tensor) { 159 | return &reinterpret_cast(tensor)->buffer(); 160 | } 161 | DimensionType Tensor_getDimensionType(const Tensor *tensor) { 162 | return static_cast( 163 | reinterpret_cast(tensor)->getDimensionType()); 164 | } 165 | halide_type_t Tensor_getType(const Tensor *tensor) { 166 | auto mnn_tensor = reinterpret_cast(tensor); 167 | return mnn_tensor->getType(); 168 | } 169 | 170 | bool Tensor_isTypeOf(const Tensor *tensor, struct halide_type_t other) { 171 | auto my = Tensor_getType(tensor); 172 | auto ret = (my.code == other.code && my.bits == other.bits && 173 | my.lanes == other.lanes); 174 | return ret; 175 | } 176 | 177 | Tensor 
*Tensor_clone(const Tensor *tensor) { 178 | auto mnn_tensor = reinterpret_cast(tensor); 179 | auto ret = MNN::Tensor::clone(mnn_tensor, true); 180 | return reinterpret_cast(ret); 181 | } 182 | 183 | } // extern "C" 184 | -------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "A simple rust flake using rust-overlay and craneLib"; 3 | 4 | inputs = { 5 | nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable"; 6 | flake-utils.url = "github:numtide/flake-utils"; 7 | crane.url = "github:ipetkov/crane"; 8 | nix-github-actions = { 9 | url = "github:nix-community/nix-github-actions"; 10 | inputs.nixpkgs.follows = "nixpkgs"; 11 | }; 12 | rust-overlay = { 13 | url = "github:oxalica/rust-overlay"; 14 | inputs.nixpkgs.follows = "nixpkgs"; 15 | }; 16 | mnn-overlay = { 17 | url = "github:uttarayan21/mnn-nix-overlay"; 18 | inputs.nixpkgs.follows = "nixpkgs"; 19 | }; 20 | advisory-db = { 21 | url = "github:rustsec/advisory-db"; 22 | flake = false; 23 | }; 24 | mnn-src = { 25 | url = "github:alibaba/MNN/3.2.0"; 26 | flake = false; 27 | }; 28 | }; 29 | 30 | outputs = { 31 | self, 32 | crane, 33 | flake-utils, 34 | nixpkgs, 35 | rust-overlay, 36 | mnn-overlay, 37 | advisory-db, 38 | nix-github-actions, 39 | mnn-src, 40 | ... 
41 | }: 42 | flake-utils.lib.eachDefaultSystem ( 43 | system: let 44 | pkgs = import nixpkgs { 45 | inherit system; 46 | overlays = [ 47 | rust-overlay.overlays.default 48 | (final: prev: { 49 | mnn = mnn-overlay.packages.${system}.mnn.override { 50 | src = mnn-src; 51 | buildConverter = true; 52 | enableMetal = true; 53 | enableOpencl = true; 54 | }; 55 | }) 56 | ]; 57 | }; 58 | inherit (pkgs) lib; 59 | 60 | version = "latest"; 61 | 62 | rustToolchain = pkgs.rust-bin.stable.${version}.default; 63 | rustToolchainWithLLvmTools = pkgs.rust-bin.stable.${version}.default.override { 64 | extensions = ["rust-src" "llvm-tools"]; 65 | }; 66 | rustToolchainWithRustAnalyzer = pkgs.rust-bin.stable.${version}.default.override ({ 67 | extensions = ["rust-docs" "rust-src" "rust-analyzer"]; 68 | } 69 | // (lib.optionalAttrs pkgs.stdenv.isDarwin { 70 | targets = ["aarch64-apple-darwin" "x86_64-apple-darwin"]; 71 | })); 72 | craneLib = (crane.mkLib pkgs).overrideToolchain rustToolchain; 73 | craneLibLLvmTools = (crane.mkLib pkgs).overrideToolchain rustToolchainWithLLvmTools; 74 | 75 | src = lib.sources.sourceFilesBySuffices ./. 
[".rs" ".toml" ".patch" ".mnn" ".h" ".cpp" ".svg" "lock"]; 76 | MNN_SRC = pkgs.applyPatches { 77 | name = "mnn-src"; 78 | src = mnn-src; 79 | patches = [./mnn-sys/patches/mnn-tracing.patch]; 80 | }; 81 | commonArgs = { 82 | inherit src MNN_SRC; 83 | pname = "mnn"; 84 | doCheck = false; 85 | LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib"; 86 | nativeBuildInputs = with pkgs; [ 87 | cmake 88 | llvmPackages.libclang.lib 89 | clang 90 | pkg-config 91 | ]; 92 | buildInputs = with pkgs; 93 | [] 94 | ++ (lib.optionals pkgs.stdenv.isLinux [ 95 | ocl-icd 96 | opencl-headers 97 | ]) 98 | ++ (lib.optionals pkgs.stdenv.isDarwin [ 99 | apple-sdk_13 100 | ]); 101 | }; 102 | cargoArtifacts = craneLib.buildPackage commonArgs; 103 | in rec { 104 | checks = 105 | { 106 | mnn-clippy = craneLib.cargoClippy (commonArgs 107 | // { 108 | inherit cargoArtifacts; 109 | cargoClippyExtraArgs = "--all-targets -- --deny warnings"; 110 | }); 111 | mnn-docs = craneLib.cargoDoc (commonArgs 112 | // { 113 | inherit cargoArtifacts; 114 | cargoDocExtraArgs = "-p mnn -p mnn-sys"; 115 | }); 116 | mnn-fmt = craneLib.cargoFmt {inherit src;}; 117 | # Audit dependencies 118 | mnn-audit = 119 | craneLib.cargoAudit.override { 120 | cargo-audit = pkgs.cargo-audit; 121 | } { 122 | inherit src advisory-db; 123 | }; 124 | 125 | # Audit licenses 126 | mnn-deny = craneLib.cargoDeny { 127 | inherit src; 128 | }; 129 | mnn-nextest = craneLib.cargoNextest (commonArgs 130 | // { 131 | inherit cargoArtifacts; 132 | partitions = 1; 133 | partitionType = "count"; 134 | }); 135 | mnn-sys-clippy = craneLib.cargoClippy (commonArgs 136 | // { 137 | inherit cargoArtifacts; 138 | cargoClippyExtraArgs = "-p mnn-sys --all-targets -- --deny warnings"; 139 | }); 140 | mnn-sys-nextest = craneLib.cargoNextest (commonArgs 141 | // { 142 | inherit cargoArtifacts; 143 | partitions = 1; 144 | partitionType = "count"; 145 | cargoExtraArgs = "-p mnn-sys"; 146 | }); 147 | # mnn-asan = let 148 | # rustPlatform = 
pkgs.makeRustPlatform { 149 | # cargo = nightlyToolchain; 150 | # rustc = nightlyToolchain; 151 | # }; 152 | # in 153 | # rustPlatform.buildRustPackage ( 154 | # commonArgs 155 | # // { 156 | # inherit src; 157 | # name = "mnn-leaks"; 158 | # cargoLock = { 159 | # lockFile = ./Cargo.lock; 160 | # outputHashes = { 161 | # "cmake-0.1.50" = "sha256-GM2D7dpb2i2S6qYVM4HYk5B40TwKCmGQnUPfXksyf0M="; 162 | # }; 163 | # }; 164 | # 165 | # buildPhase = '' 166 | # cargo test --target aarch64-apple-darwin 167 | # ''; 168 | # RUSTFLAGS = "-Zsanitizer=address"; 169 | # ASAN_OPTIONS = "detect_leaks=1"; 170 | # # MNN_COMPILE = "NO"; 171 | # # MNN_LIB_DIR = "${pkgs.mnn}/lib"; 172 | # } 173 | # ); 174 | } 175 | // lib.optionalAttrs (!pkgs.stdenv.isDarwin) { 176 | mnn-llvm-cov = craneLibLLvmTools.cargoLlvmCov (commonArgs // {inherit cargoArtifacts;}); 177 | }; 178 | packages = rec { 179 | mnn = craneLib.buildPackage (commonArgs 180 | // { 181 | inherit cargoArtifacts; 182 | }); 183 | inspect = craneLib.buildPackage (commonArgs 184 | // { 185 | inherit cargoArtifacts; 186 | pname = "inspect"; 187 | cargoExtraArgs = 188 | "--example inspect" 189 | + ( 190 | lib.optionalString pkgs.stdenv.isDarwin " --features opencl,metal,coreml" # + lib.optionalString pkgs.stdenv.isAarch64 ",metal,coreml" 191 | ); 192 | }); 193 | bencher = craneLib.buildPackage (commonArgs 194 | // { 195 | inherit cargoArtifacts; 196 | pname = "bencher"; 197 | cargoExtraArgs = "--package bencher"; 198 | }); 199 | default = mnn; 200 | }; 201 | 202 | devShells = { 203 | default = pkgs.mkShell (commonArgs 204 | // { 205 | MNN_SRC = null; 206 | LLDB_DEBUGSERVER_PATH = "/Applications/Xcode.app/Contents/SharedFrameworks/LLDB.framework/Versions/A/Resources/debugserver"; 207 | packages = with pkgs; 208 | [ 209 | cargo-audit 210 | cargo-deny 211 | cargo-hakari 212 | cargo-nextest 213 | cargo-semver-checks 214 | clang 215 | git 216 | git-lfs 217 | llvm 218 | llvmPackages.lldb 219 | nushell 220 | rust-bindgen 221 | 
google-cloud-sdk 222 | rustToolchainWithRustAnalyzer 223 | mnn 224 | ] 225 | ++ ( 226 | lib.optionals pkgs.stdenv.isLinux [ 227 | cargo-llvm-cov 228 | ] 229 | ); 230 | }); 231 | }; 232 | } 233 | ) 234 | // { 235 | githubActions = nix-github-actions.lib.mkGithubMatrix { 236 | checks = nixpkgs.lib.getAttrs ["x86_64-linux" "aarch64-darwin"] self.checks; 237 | }; 238 | }; 239 | } 240 | -------------------------------------------------------------------------------- /mnn-sys/mnn_c/interpreter_c.h: -------------------------------------------------------------------------------- 1 | #ifndef INTERPRETER_C_H 2 | #define INTERPRETER_C_H 3 | #include "backend_c.h" 4 | #include "error_code_c.h" 5 | #include "schedule_c.h" 6 | #include "session_c.h" 7 | #include "tensor_c.h" 8 | #include "utils.h" 9 | #include 10 | #include 11 | #ifdef __cplusplus 12 | extern "C" { 13 | #endif 14 | typedef struct Interpreter Interpreter; 15 | typedef struct Backend Backend; 16 | 17 | /** acquire runtime status by Runtime::getCurrentStatus with following keys, 18 | */ 19 | enum RuntimeStatus { 20 | /** 21 | * get status whether this runtime support 16-bits float point arithmetic 22 | */ 23 | STATUS_SUPPORT_FP16, 24 | /** 25 | * get status whether this runtime support dot-product arithmetic 26 | */ 27 | STATUS_SUPPORT_DOT_PRODUCT, 28 | /** 29 | * get status whether this runtime support power-low (means low priority for 30 | * opencl) 31 | */ 32 | STATUS_SUPPORT_POWER_LOW, 33 | /** 34 | * emum total number 35 | */ 36 | STATUS_COUNT 37 | }; 38 | 39 | // typedef struct { 40 | // char **saveTensors; 41 | // size_t saveTensorsSize; 42 | // MNNForwardType type; 43 | // union { 44 | // int numThread; 45 | // int mode; 46 | // }; 47 | // struct { 48 | // char **inputs; 49 | // size_t inputsSize; 50 | // char **outputs; 51 | // size_t outputsSize; 52 | // int mode; 53 | // } path; 54 | // MNNForwardType backupType; 55 | // MNNBackendConfig *backendConfig; 56 | // } ScheduleConfig; 57 | 58 | #if 0 59 | 
typedef struct { 60 | std::map> *runtimeMap; 61 | std::shared_ptr *defaultRuntime; 62 | } RuntimeInfo; 63 | #endif 64 | 65 | void modelPrintIO(const char *model); 66 | 67 | /** 68 | * @brief get mnn version info. 69 | * @return mnn version string. 70 | */ 71 | const char *getVersion(); 72 | /** 73 | * @brief create net from file. 74 | * @param file given file. 75 | * @return created net if success, NULL otherwise. 76 | */ 77 | Interpreter *Interpreter_createFromFile(const char *file); 78 | /** 79 | * @brief create net from buffer. 80 | * @param buffer given data buffer. 81 | * @param size size of data buffer. 82 | * @return created net if success, NULL otherwise. 83 | */ 84 | Interpreter *Interpreter_createFromBuffer(const void *buffer, size_t size); 85 | void Interpreter_destroy(Interpreter *interpreter); 86 | typedef enum { 87 | /** About CallBack, Default Session_Debug*/ 88 | /** runSessionWithCallBack is allowed and can get internal op info*/ 89 | Session_Debug = 0, 90 | /** runSessionWithCallBack is not valid and can't get any info of op in 91 | session*/ 92 | Session_Release = 1, 93 | 94 | /** About input tensor, Default Session_Input_Inside*/ 95 | /** The input tensor is alloced by session, input data after session resized*/ 96 | Session_Input_Inside = 2, 97 | /** The input tensor is alloced by user, set input data before session 98 | resize*/ 99 | Session_Input_User = 3, 100 | 101 | /** The output tensor depends on session, and can't be separate used*/ 102 | Session_Output_Inside = 4, 103 | /** The output tensor can be separated from session*/ 104 | Session_Output_User = 5, 105 | 106 | /** Try Resize Session when create Session or not, default direct: */ 107 | Session_Resize_Direct = 6, 108 | Session_Resize_Defer = 7, 109 | 110 | /** Determine the Execution's forward type is determine by user or auto 111 | determine */ 112 | Session_Backend_Fix = 113 | 8, // Use the backend user set, when not support use default backend 114 | Session_Backend_Auto = 9, //
Auto Determine the Op type by MNN 115 | 116 | /** Determine static memory whether recycle in resizeSession or just cache the 117 | memory */ 118 | Session_Memory_Collect = 119 | 10, // Recycle static memory when session resize in case memory explosion 120 | Session_Memory_Cache = 11, // Cache the static memory for next forward usage 121 | 122 | /** Determine whether use codegen function */ 123 | Session_Codegen_Disable = 124 | 12, // Disable codegen in case extra build codegen cost 125 | Session_Codegen_Enable = 13, // Enable codegen 126 | 127 | /** Dynamic Resize Optimization */ 128 | Session_Resize_Check = 14, // Open Trace for resize 129 | Session_Resize_Fix = 15, // Apply Resize Optimization 130 | } SessionMode; 131 | void Interpreter_setSessionMode(Interpreter *interpreter, SessionMode mode); 132 | void Interpreter_setCacheFile(Interpreter *interpreter, const char *cacheFile, 133 | size_t keySize); 134 | void Interpreter_setExternalFile(Interpreter *interpreter, const char *file, 135 | size_t flag); 136 | ErrorCode Interpreter_updateCacheFile(Interpreter *interpreter, 137 | Session *session); 138 | void Interpreter_setSessionHint(Interpreter *interpreter, int mode, int value); 139 | // RuntimeInfo *Interpreter_createRuntime(const ScheduleConfig *configs, 140 | // size_t configSize); 141 | Session *Interpreter_createSession(Interpreter *interpreter, 142 | const MNNScheduleConfig *config); 143 | // Session *Interpreter_createSessionWithRuntime(Interpreter *interpreter, 144 | // const ScheduleConfig *config, 145 | // const RuntimeInfo *runtime); 146 | Session * 147 | Interpreter_createMultiPathSession(Interpreter *interpreter, 148 | const MNNScheduleConfig *const *configs, 149 | size_t configSize); 150 | // Session *Interpreter_createMultiPathSessionWithRuntime( 151 | // Interpreter *interpreter, const ScheduleConfig *configs, size_t 152 | // configSize, const RuntimeInfo *runtime); 153 | int Interpreter_releaseSession(Interpreter *interpreter, Session *session);
154 | void Interpreter_resizeSession(Interpreter *interpreter, Session *session); 155 | void Interpreter_resizeSessionWithFlag(Interpreter *interpreter, 156 | Session *session, int needRelloc); 157 | void Interpreter_releaseModel(Interpreter *interpreter); 158 | // std::pair 159 | // Interpreter_getModelBuffer(const Interpreter *interpreter); 160 | const char *Interpreter_getModelVersion(const Interpreter *interpreter); 161 | ErrorCode Interpreter_updateSessionToModel(Interpreter *interpreter, 162 | Session *session); 163 | ErrorCode Interpreter_runSession(const Interpreter *interpreter, 164 | Session *session); 165 | // ErrorCode Interpreter_runSessionWithCallBack(const Interpreter *interpreter, 166 | // const Session *session, 167 | // void *before, void *end, int 168 | // sync); 169 | ErrorCode Interpreter_runSessionWithCallBackInfo(const Interpreter *interpreter, 170 | const Session *session, 171 | void *before, void *end, 172 | int sync); 173 | Tensor *Interpreter_getSessionInput(Interpreter *interpreter, 174 | const Session *session, const char *name); 175 | Tensor *Interpreter_getSessionOutput(Interpreter *interpreter, 176 | const Session *session, const char *name); 177 | int Interpreter_getSessionInfo(Interpreter *interpreter, const Session *session, 178 | int code, void *ptr); 179 | TensorInfoArray const * 180 | Interpreter_getSessionOutputAll(const Interpreter *interpreter, 181 | const Session *session); 182 | 183 | TensorInfoArray const * 184 | Interpreter_getSessionInputAll(const Interpreter *interpreter, 185 | const Session *session); 186 | void Interpreter_resizeTensor(Interpreter *interpreter, Tensor *tensor, 187 | const int *dims, size_t dimsSize); 188 | void Interpreter_resizeTensorByNCHW(Interpreter *interpreter, Tensor *tensor, 189 | int batch, int channel, int height, 190 | int width); 191 | const Backend *Interpreter_getBackend(const Interpreter *interpreter, 192 | const Session *session, 193 | const Tensor *tensor); 194 | const char 
*Interpreter_bizCode(const Interpreter *interpreter); 195 | const char *Interpreter_uuid(const Interpreter *interpreter); 196 | 197 | const char *OperatorInfo_name(const void *op); 198 | const char *OperatorInfo_type(const void *op); 199 | float OperatorInfo_flops(const void *op); 200 | 201 | #ifdef __cplusplus 202 | } 203 | #endif 204 | #endif // INTERPRETER_C_H 205 | -------------------------------------------------------------------------------- /deny.toml: -------------------------------------------------------------------------------- 1 | # This template contains all of the possible sections and their default values 2 | 3 | # Note that all fields that take a lint level have these possible values: 4 | # * deny - An error will be produced and the check will fail 5 | # * warn - A warning will be produced, but the check will not fail 6 | # * allow - No warning or error will be produced, though in some cases a note 7 | # will be 8 | 9 | # The values provided in this template are the default values that will be used 10 | # when any section or field is not specified in your own configuration 11 | 12 | # Root options 13 | 14 | # The graph table configures how the dependency graph is constructed and thus 15 | # which crates the checks are performed against 16 | [graph] 17 | # If 1 or more target triples (and optionally, target_features) are specified, 18 | # only the specified targets will be checked when running `cargo deny check`. 19 | # This means, if a particular package is only ever used as a target specific 20 | # dependency, such as, for example, the `nix` crate only being used via the 21 | # `target_family = "unix"` configuration, that only having windows targets in 22 | # this list would mean the nix crate, as well as any of its exclusive 23 | # dependencies not shared by any other crates, would be ignored, as the target 24 | # list here is effectively saying which targets you are building for. 
25 | targets = [ 26 | # The triple can be any string, but only the target triples built in to 27 | # rustc (as of 1.40) can be checked against actual config expressions 28 | #"x86_64-unknown-linux-musl", 29 | # You can also specify which target_features you promise are enabled for a 30 | # particular target. target_features are currently not validated against 31 | # the actual valid features supported by the target architecture. 32 | #{ triple = "wasm32-unknown-unknown", features = ["atomics"] }, 33 | ] 34 | # When creating the dependency graph used as the source of truth when checks are 35 | # executed, this field can be used to prune crates from the graph, removing them 36 | # from the view of cargo-deny. This is an extremely heavy hammer, as if a crate 37 | # is pruned from the graph, all of its dependencies will also be pruned unless 38 | # they are connected to another crate in the graph that hasn't been pruned, 39 | # so it should be used with care. The identifiers are [Package ID Specifications] 40 | # (https://doc.rust-lang.org/cargo/reference/pkgid-spec.html) 41 | #exclude = [] 42 | # If true, metadata will be collected with `--all-features`. Note that this can't 43 | # be toggled off if true, if you want to conditionally enable `--all-features` it 44 | # is recommended to pass `--all-features` on the cmd line instead 45 | all-features = false 46 | # If true, metadata will be collected with `--no-default-features`. The same 47 | # caveat with `all-features` applies 48 | no-default-features = false 49 | # If set, these feature will be enabled when collecting metadata. If `--features` 50 | # is specified on the cmd line they will take precedence over this option. 51 | #features = [] 52 | 53 | # The output table provides options for how/if diagnostics are outputted 54 | [output] 55 | # When outputting inclusion graphs in diagnostics that include features, this 56 | # option can be used to specify the depth at which feature edges will be added. 
57 | # This option is included since the graphs can be quite large and the addition 58 | # of features from the crate(s) to all of the graph roots can be far too verbose. 59 | # This option can be overridden via `--feature-depth` on the cmd line 60 | feature-depth = 1 61 | 62 | # This section is considered when running `cargo deny check advisories` 63 | # More documentation for the advisories section can be found here: 64 | # https://embarkstudios.github.io/cargo-deny/checks/advisories/cfg.html 65 | [advisories] 66 | # The path where the advisory databases are cloned/fetched into 67 | #db-path = "$CARGO_HOME/advisory-dbs" 68 | # The url(s) of the advisory databases to use 69 | #db-urls = ["https://github.com/rustsec/advisory-db"] 70 | # A list of advisory IDs to ignore. Note that ignored advisories will still 71 | # output a note when they are encountered. 72 | ignore = [ 73 | #"RUSTSEC-0000-0000", 74 | #{ id = "RUSTSEC-0000-0000", reason = "you can specify a reason the advisory is ignored" }, 75 | #"a-crate-that-is-yanked@0.1.1", # you can also ignore yanked crate versions if you wish 76 | #{ crate = "a-crate-that-is-yanked@0.1.1", reason = "you can specify why you are ignoring the yanked crate" }, 77 | ] 78 | # If this is true, then cargo deny will use the git executable to fetch advisory database. 79 | # If this is false, then it uses a built-in git library. 80 | # Setting this to true can be helpful if you have special authentication requirements that cargo-deny does not support. 81 | # See Git Authentication for more information about setting up git authentication. 
82 | #git-fetch-with-cli = true 83 | 84 | # This section is considered when running `cargo deny check licenses` 85 | # More documentation for the licenses section can be found here: 86 | # https://embarkstudios.github.io/cargo-deny/checks/licenses/cfg.html 87 | [licenses] 88 | # List of explicitly allowed licenses 89 | # See https://spdx.org/licenses/ for list of possible licenses 90 | # [possible values: any SPDX 3.11 short identifier (+ optional exception)]. 91 | allow = ["MIT", "Apache-2.0", "BSD-3-Clause", "ISC", "Unicode-DFS-2016", "Zlib"] 92 | # The confidence threshold for detecting a license from license text. 93 | # The higher the value, the more closely the license text must be to the 94 | # canonical license text of a valid SPDX license file. 95 | # [possible values: any between 0.0 and 1.0]. 96 | confidence-threshold = 0.8 97 | # Allow 1 or more licenses on a per-crate basis, so that particular licenses 98 | # aren't accepted for every possible crate as with the normal allow list 99 | exceptions = [ 100 | # Each entry is the crate and version constraint, and its specific allow 101 | # list 102 | #{ allow = ["Zlib"], crate = "adler32" }, 103 | ] 104 | 105 | # Some crates don't have (easily) machine readable licensing information, 106 | # adding a clarification entry for it allows you to manually specify the 107 | # licensing information 108 | #[[licenses.clarify]] 109 | # The package spec the clarification applies to 110 | #crate = "ring" 111 | # The SPDX expression for the license requirements of the crate 112 | #expression = "MIT AND ISC AND OpenSSL" 113 | # One or more files in the crate's source used as the "source of truth" for 114 | # the license expression. 
If the contents match, the clarification will be used 115 | # when running the license check, otherwise the clarification will be ignored 116 | # and the crate will be checked normally, which may produce warnings or errors 117 | # depending on the rest of your configuration 118 | #license-files = [ 119 | # Each entry is a crate relative path, and the (opaque) hash of its contents 120 | #{ path = "LICENSE", hash = 0xbd0eed23 } 121 | #] 122 | 123 | [licenses.private] 124 | # If true, ignores workspace crates that aren't published, or are only 125 | # published to private registries. 126 | # To see how to mark a crate as unpublished (to the official registry), 127 | # visit https://doc.rust-lang.org/cargo/reference/manifest.html#the-publish-field. 128 | ignore = false 129 | # One or more private registries that you might publish crates to, if a crate 130 | # is only published to private registries, and ignore is true, the crate will 131 | # not have its license(s) checked 132 | registries = [ 133 | #"https://sekretz.com/registry 134 | ] 135 | 136 | # This section is considered when running `cargo deny check bans`. 137 | # More documentation about the 'bans' section can be found here: 138 | # https://embarkstudios.github.io/cargo-deny/checks/bans/cfg.html 139 | [bans] 140 | # Lint level for when multiple versions of the same crate are detected 141 | multiple-versions = "warn" 142 | # Lint level for when a crate version requirement is `*` 143 | wildcards = "allow" 144 | # The graph highlighting used when creating dotgraphs for crates 145 | # with multiple versions 146 | # * lowest-version - The path to the lowest versioned duplicate is highlighted 147 | # * simplest-path - The path to the version with the fewest edges is highlighted 148 | # * all - Both lowest-version and simplest-path are used 149 | highlight = "all" 150 | # The default lint level for `default` features for crates that are members of 151 | # the workspace that is being checked. 
This can be overridden by allowing/denying 152 | # `default` on a crate-by-crate basis if desired. 153 | workspace-default-features = "allow" 154 | # The default lint level for `default` features for external crates that are not 155 | # members of the workspace. This can be overridden by allowing/denying `default` 156 | # on a crate-by-crate basis if desired. 157 | external-default-features = "allow" 158 | # List of crates that are allowed. Use with care! 159 | allow = [ 160 | #"ansi_term@0.11.0", 161 | #{ crate = "ansi_term@0.11.0", reason = "you can specify a reason it is allowed" }, 162 | ] 163 | # List of crates to deny 164 | deny = [ 165 | #"ansi_term@0.11.0", 166 | #{ crate = "ansi_term@0.11.0", reason = "you can specify a reason it is banned" }, 167 | # Wrapper crates can optionally be specified to allow the crate when it 168 | # is a direct dependency of the otherwise banned crate 169 | #{ crate = "ansi_term@0.11.0", wrappers = ["this-crate-directly-depends-on-ansi_term"] }, 170 | ] 171 | 172 | # List of features to allow/deny 173 | # Each entry is the name of a crate and a version range. If the version is 174 | # not specified, all versions will be matched. 175 | #[[bans.features]] 176 | #crate = "reqwest" 177 | # Features to not allow 178 | #deny = ["json"] 179 | # Features to allow 180 | #allow = [ 181 | # "rustls", 182 | # "__rustls", 183 | # "__tls", 184 | # "hyper-rustls", 185 | # "rustls", 186 | # "rustls-pemfile", 187 | # "rustls-tls-webpki-roots", 188 | # "tokio-rustls", 189 | # "webpki-roots", 190 | #] 191 | # If true, the allowed features must exactly match the enabled feature set. If 192 | # this is set there is no point setting `deny` 193 | #exact = true 194 | 195 | # Certain crates/versions that will be skipped when doing duplicate detection.
196 | skip = [ 197 | #"ansi_term@0.11.0", 198 | #{ crate = "ansi_term@0.11.0", reason = "you can specify a reason why it can't be updated/removed" }, 199 | ] 200 | # Similarly to `skip` allows you to skip certain crates during duplicate 201 | # detection. Unlike skip, it also includes the entire tree of transitive 202 | # dependencies starting at the specified crate, up to a certain depth, which is 203 | # by default infinite. 204 | skip-tree = [ 205 | #"ansi_term@0.11.0", # will be skipped along with _all_ of its direct and transitive dependencies 206 | #{ crate = "ansi_term@0.11.0", depth = 20 }, 207 | ] 208 | 209 | # This section is considered when running `cargo deny check sources`. 210 | # More documentation about the 'sources' section can be found here: 211 | # https://embarkstudios.github.io/cargo-deny/checks/sources/cfg.html 212 | [sources] 213 | # Lint level for what to happen when a crate from a crate registry that is not 214 | # in the allow list is encountered 215 | unknown-registry = "warn" 216 | # Lint level for what to happen when a crate from a git repository that is not 217 | # in the allow list is encountered 218 | unknown-git = "warn" 219 | # List of URLs for allowed crate registries. Defaults to the crates.io index 220 | # if not specified. If it is specified but empty, no registries are allowed. 
221 | allow-registry = ["https://github.com/rust-lang/crates.io-index"] 222 | # List of URLs for allowed Git repositories 223 | allow-git = [] 224 | 225 | [sources.allow-org] 226 | # github.com organizations to allow git sources for 227 | github = [] 228 | # gitlab.com organizations to allow git sources for 229 | gitlab = [] 230 | # bitbucket.org organizations to allow git sources for 231 | bitbucket = [] 232 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 AfterShoot 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/backend.rs: -------------------------------------------------------------------------------- 1 | //! 
The backend module contains the data types for the backend configuration 2 | 3 | use crate::prelude::*; 4 | use std::str::FromStr; 5 | 6 | use mnn_sys::*; 7 | 8 | /// BackendConfig is a struct that holds the configuration for the backend 9 | /// memory: [MemoryMode] 10 | /// power: [PowerMode] 11 | /// precision: [PrecisionMode] 12 | #[repr(transparent)] 13 | pub struct BackendConfig { 14 | pub(crate) inner: *mut MNNBackendConfig, 15 | __marker: core::marker::PhantomData<()>, 16 | } 17 | 18 | impl core::fmt::Debug for BackendConfig { 19 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 20 | f.debug_struct("BackendConfig") 21 | .field("memory", &self.get_memory_mode()) 22 | .field("power", &self.get_power_mode()) 23 | .field("precision", &self.get_precision_mode()) 24 | .finish() 25 | } 26 | } 27 | 28 | #[cfg(feature = "serde")] 29 | impl serde::Serialize for BackendConfig { 30 | fn serialize(&self, serializer: S) -> Result 31 | where 32 | S: serde::ser::Serializer, 33 | { 34 | use serde::ser::SerializeStruct; 35 | let mut state = serializer.serialize_struct("BackendConfig", 3)?; 36 | state.serialize_field("memory", &self.get_memory_mode())?; 37 | state.serialize_field("power", &self.get_power_mode())?; 38 | state.serialize_field("precision", &self.get_precision_mode())?; 39 | state.end() 40 | } 41 | } 42 | 43 | impl Clone for BackendConfig { 44 | fn clone(&self) -> Self { 45 | unsafe { 46 | let inner = mnn_sys::mnnbc_clone(self.inner); 47 | Self { 48 | inner, 49 | __marker: core::marker::PhantomData, 50 | } 51 | } 52 | } 53 | } 54 | 55 | impl Drop for BackendConfig { 56 | fn drop(&mut self) { 57 | unsafe { 58 | mnn_sys::mnnbc_destroy(self.inner); 59 | } 60 | } 61 | } 62 | 63 | impl Default for BackendConfig { 64 | fn default() -> Self { 65 | Self::new() 66 | } 67 | } 68 | 69 | /// PowerModes depend on if the specific backend has support for it 70 | #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] 71 | #[cfg_attr(feature = "serde", 
derive(serde::Serialize, serde::Deserialize))] 72 | pub enum PowerMode { 73 | /// Low power mode 74 | Low, 75 | /// Normal power mode 76 | Normal, 77 | /// High power mode 78 | High, 79 | } 80 | 81 | impl PowerMode { 82 | fn to_mnn_sys(self) -> mnn_sys::PowerMode { 83 | match self { 84 | Self::Low => mnn_sys::PowerMode::Power_Low, 85 | Self::Normal => mnn_sys::PowerMode::Power_Normal, 86 | Self::High => mnn_sys::PowerMode::Power_High, 87 | } 88 | } 89 | 90 | /// Returns a string representation of the power mode 91 | pub fn to_str(self) -> &'static str { 92 | match self { 93 | Self::Low => "low", 94 | Self::Normal => "normal", 95 | Self::High => "high", 96 | } 97 | } 98 | 99 | fn from_mnn_sys(mode: mnn_sys::PowerMode) -> Self { 100 | match mode { 101 | mnn_sys::PowerMode::Power_Low => Self::Low, 102 | mnn_sys::PowerMode::Power_Normal => Self::Normal, 103 | mnn_sys::PowerMode::Power_High => Self::High, 104 | _ => Self::Normal, 105 | } 106 | } 107 | } 108 | 109 | impl FromStr for PowerMode { 110 | type Err = MNNError; 111 | fn from_str(s: &str) -> Result { 112 | match s { 113 | "low" => Ok(Self::Low), 114 | "normal" => Ok(Self::Normal), 115 | "high" => Ok(Self::High), 116 | _ => { 117 | Err(error!(ErrorKind::ParseError) 118 | .attach_printable(format!("invalid power mode: {s}"))) 119 | } 120 | } 121 | } 122 | } 123 | 124 | impl FromStr for MemoryMode { 125 | type Err = MNNError; 126 | fn from_str(s: &str) -> Result { 127 | match s { 128 | "low" => Ok(Self::Low), 129 | "normal" => Ok(Self::Normal), 130 | "high" => Ok(Self::High), 131 | _ => { 132 | Err(error!(ErrorKind::ParseError) 133 | .attach_printable(format!("invalid memory mode: {s}"))) 134 | } 135 | } 136 | } 137 | } 138 | 139 | impl FromStr for PrecisionMode { 140 | type Err = MNNError; 141 | fn from_str(s: &str) -> Result { 142 | match s { 143 | "low" => Ok(Self::Low), 144 | "normal" => Ok(Self::Normal), 145 | "high" => Ok(Self::High), 146 | "low_bf16" => Ok(Self::LowBf16), 147 | _ => 
Err(error!(ErrorKind::ParseError) 148 | .attach_printable(format!("invalid precision mode: {s}"))), 149 | } 150 | } 151 | } 152 | 153 | /// MemoryModes depend on if the specific backend has support for it 154 | #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] 155 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 156 | pub enum MemoryMode { 157 | /// Low memory mode 158 | Low, 159 | /// Normal memory mode 160 | Normal, 161 | /// High memory mode 162 | High, 163 | } 164 | 165 | impl MemoryMode { 166 | fn to_mnn_sys(self) -> mnn_sys::MemoryMode { 167 | match self { 168 | Self::Low => mnn_sys::MemoryMode::Memory_Low, 169 | Self::Normal => mnn_sys::MemoryMode::Memory_Normal, 170 | Self::High => mnn_sys::MemoryMode::Memory_High, 171 | } 172 | } 173 | 174 | /// Returns a string representation of the memory mode 175 | pub fn to_str(self) -> &'static str { 176 | match self { 177 | Self::Low => "low", 178 | Self::Normal => "normal", 179 | Self::High => "high", 180 | } 181 | } 182 | 183 | fn from_mnn_sys(mode: mnn_sys::MemoryMode) -> Self { 184 | match mode { 185 | mnn_sys::MemoryMode::Memory_Low => Self::Low, 186 | mnn_sys::MemoryMode::Memory_Normal => Self::Normal, 187 | mnn_sys::MemoryMode::Memory_High => Self::High, 188 | _ => Self::Normal, 189 | } 190 | } 191 | } 192 | 193 | /// PrecisionModes depend on if the specific backend has support for it 194 | #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] 195 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 196 | pub enum PrecisionMode { 197 | /// Normal precision mode 198 | Normal = 0, 199 | /// High precision mode 200 | High, 201 | /// Low precision mode 202 | Low, 203 | /// Low precision mode with BF16 204 | LowBf16, 205 | } 206 | impl PrecisionMode { 207 | pub(crate) fn to_mnn_sys(self) -> mnn_sys::PrecisionMode { 208 | match self { 209 | Self::LowBf16 => mnn_sys::PrecisionMode::Precision_Low_BF16, 210 | Self::Low => mnn_sys::PrecisionMode::Precision_Low, 211 | 
Self::Normal => mnn_sys::PrecisionMode::Precision_Normal, 212 | Self::High => mnn_sys::PrecisionMode::Precision_High, 213 | } 214 | } 215 | 216 | /// Returns a string representation of the precision mode 217 | pub fn to_str(self) -> &'static str { 218 | match self { 219 | Self::LowBf16 => "low_bf16", 220 | Self::Low => "low", 221 | Self::Normal => "normal", 222 | Self::High => "high", 223 | } 224 | } 225 | 226 | fn from_mnn_sys(mode: mnn_sys::PrecisionMode) -> Self { 227 | match mode { 228 | mnn_sys::PrecisionMode::Precision_Low_BF16 => Self::LowBf16, 229 | mnn_sys::PrecisionMode::Precision_Low => Self::Low, 230 | mnn_sys::PrecisionMode::Precision_Normal => Self::Normal, 231 | mnn_sys::PrecisionMode::Precision_High => Self::High, 232 | _ => Self::Normal, 233 | } 234 | } 235 | } 236 | 237 | impl BackendConfig { 238 | /// Create a new backend config 239 | pub fn new() -> Self { 240 | unsafe { 241 | let inner = mnnbc_create(); 242 | Self { 243 | inner, 244 | __marker: core::marker::PhantomData, 245 | } 246 | } 247 | } 248 | 249 | /// Sets the [MemoryMode] for the backend 250 | pub fn set_memory_mode(&mut self, mode: MemoryMode) { 251 | unsafe { 252 | mnn_sys::mnnbc_set_memory_mode(self.inner, mode.to_mnn_sys()); 253 | } 254 | } 255 | 256 | /// Sets the [MemoryMode] for the backend 257 | pub fn with_memory_mode(mut self, mode: MemoryMode) -> Self { 258 | self.set_memory_mode(mode); 259 | self 260 | } 261 | 262 | /// Gets the [MemoryMode] for the backend 263 | pub fn get_memory_mode(&self) -> MemoryMode { 264 | unsafe { MemoryMode::from_mnn_sys(mnn_sys::mnnbc_get_memory_mode(self.inner)) } 265 | } 266 | 267 | /// Sets the [PowerMode] for the backend 268 | pub fn set_power_mode(&mut self, mode: PowerMode) { 269 | unsafe { 270 | mnn_sys::mnnbc_set_power_mode(self.inner, mode.to_mnn_sys()); 271 | } 272 | } 273 | 274 | /// Sets the [PowerMode] for the backend 275 | pub fn with_power_mode(mut self, mode: PowerMode) -> Self { 276 | self.set_power_mode(mode); 277 | self 278 | 
} 279 | 280 | /// Gets the [PowerMode] for the backend 281 | pub fn get_power_mode(&self) -> PowerMode { 282 | unsafe { PowerMode::from_mnn_sys(mnn_sys::mnnbc_get_power_mode(self.inner)) } 283 | } 284 | 285 | /// Sets the [PrecisionMode] for the backend 286 | pub fn set_precision_mode(&mut self, mode: PrecisionMode) { 287 | unsafe { 288 | mnn_sys::mnnbc_set_precision_mode(self.inner, mode.to_mnn_sys()); 289 | } 290 | } 291 | 292 | /// Sets the [PrecisionMode] for the backend 293 | pub fn with_precision_mode(mut self, mode: PrecisionMode) -> Self { 294 | self.set_precision_mode(mode); 295 | self 296 | } 297 | 298 | /// Gets the [PrecisionMode] for the backend 299 | pub fn get_precision_mode(&self) -> PrecisionMode { 300 | unsafe { PrecisionMode::from_mnn_sys(mnn_sys::mnnbc_get_precision_mode(self.inner)) } 301 | } 302 | 303 | /// Sets the flags for the backend 304 | /// What the flag represents is depends on each backend or isn't documented 305 | pub fn set_flags(&mut self, flags: usize) { 306 | unsafe { 307 | mnn_sys::mnnbc_set_flags(self.inner, flags); 308 | } 309 | } 310 | 311 | /// Sets the flags for the backend 312 | pub fn with_flags(mut self, flags: usize) -> Self { 313 | self.set_flags(flags); 314 | self 315 | } 316 | 317 | /// # Safety 318 | /// This just binds to the underlying unsafe api and should be used only if you know what you 319 | /// are doing 320 | pub unsafe fn set_shared_context(&mut self, shared_context: *mut libc::c_void) { 321 | unsafe { 322 | mnn_sys::mnnbc_set_shared_context(self.inner, shared_context); 323 | } 324 | } 325 | 326 | /// # Safety 327 | /// This just binds to the underlying unsafe api and should be used only if you know what you 328 | /// are doing 329 | pub unsafe fn with_shared_context(mut self, shared_context: *mut libc::c_void) -> Self { 330 | unsafe { 331 | self.set_shared_context(shared_context); 332 | } 333 | self 334 | } 335 | } 336 | 337 | #[test] 338 | fn test_backend_config() { 339 | let mut config = 
BackendConfig::new(); 340 | config.set_memory_mode(MemoryMode::Low); 341 | config.set_power_mode(PowerMode::Low); 342 | config.set_precision_mode(PrecisionMode::Low); 343 | let config = std::hint::black_box(config.clone()); 344 | assert_eq!(config.get_memory_mode(), MemoryMode::Low); 345 | assert_eq!(config.get_power_mode(), PowerMode::Low); 346 | assert_eq!(config.get_precision_mode(), PrecisionMode::Low); 347 | let config = config 348 | .with_memory_mode(MemoryMode::Normal) 349 | .with_power_mode(PowerMode::Normal) 350 | .with_precision_mode(PrecisionMode::Normal); 351 | assert_eq!(config.get_memory_mode(), MemoryMode::Normal); 352 | assert_eq!(config.get_power_mode(), PowerMode::Normal); 353 | assert_eq!(config.get_precision_mode(), PrecisionMode::Normal); 354 | } 355 | -------------------------------------------------------------------------------- /mnn-sys/mnn_c/interpreter_c.cpp: -------------------------------------------------------------------------------- 1 | #include "interpreter_c.h" 2 | #include "MNN/Interpreter.hpp" 3 | #include 4 | #include 5 | #include 6 | #include 7 | extern "C" { 8 | // int rust_closure_callback_runner(void *closure, Tensor *const *tensors, 9 | // size_t tensorCount, const char *opName); 10 | int rust_closure_callback_runner_op(void *closure, Tensor *const *tensors, 11 | size_t tensorCount, const void *op); 12 | 13 | void modelPrintIO(const char *model) { 14 | auto net = MNN::Interpreter::createFromFile(model); 15 | MNN::ScheduleConfig config; 16 | config.numThread = 4; 17 | config.type = MNN_FORWARD_METAL; 18 | MNN::Session *session = net->createSession(config); 19 | auto inputs = net->getSessionInputAll(session); 20 | for (auto input : inputs) { 21 | std::cout << "Input: " << input.first << " "; 22 | input.second->printShape(); 23 | } 24 | auto outputs = net->getSessionOutputAll(session); 25 | for (auto output : outputs) { 26 | std::cout << "Output: " << output.first << " "; 27 | output.second->printShape(); 28 | } 29 | } 30 | 
31 | // const char *getVersion() { return MNN::getVersion(); } 32 | Interpreter *Interpreter_createFromFile(const char *file) { 33 | return reinterpret_cast( 34 | MNN::Interpreter::createFromFile(file)); 35 | } 36 | Interpreter *Interpreter_createFromBuffer(const void *buffer, size_t size) { 37 | return reinterpret_cast( 38 | MNN::Interpreter::createFromBuffer(buffer, size)); 39 | } 40 | void Interpreter_destroy(Interpreter *interpreter) { 41 | auto mnn_interpreter = reinterpret_cast(interpreter); 42 | MNN::Interpreter::destroy(mnn_interpreter); 43 | } 44 | void Interpreter_setSessionMode(Interpreter *interpreter, SessionMode mode) { 45 | auto mnn_interpreter = reinterpret_cast(interpreter); 46 | mnn_interpreter->setSessionMode( 47 | static_cast(mode)); 48 | } 49 | void Interpreter_setCacheFile(Interpreter *interpreter, const char *cacheFile, 50 | size_t keySize) { 51 | auto mnn_interpreter = reinterpret_cast(interpreter); 52 | mnn_interpreter->setCacheFile(cacheFile, keySize); 53 | } 54 | void Interpreter_setExternalFile(Interpreter *interpreter, const char *file, 55 | size_t flag) { 56 | auto mnn_interpreter = reinterpret_cast(interpreter); 57 | mnn_interpreter->setExternalFile(file, flag); 58 | } 59 | ErrorCode Interpreter_updateCacheFile(Interpreter *interpreter, 60 | Session *session) { 61 | auto mnn_interpreter = reinterpret_cast(interpreter); 62 | auto mnn_session = reinterpret_cast(session); 63 | return static_cast(mnn_interpreter->updateCacheFile(mnn_session)); 64 | } 65 | void Interpreter_setSessionHint(Interpreter *interpreter, int mode, int value) { 66 | auto mnn_interpreter = reinterpret_cast(interpreter); 67 | mnn_interpreter->setSessionHint(static_cast(mode), 68 | value); 69 | } 70 | // RuntimeInfo* Interpreter_createRuntime(const ScheduleConfig* configs, size_t 71 | // configSize) { 72 | // std::vector cppConfigs(configSize); 73 | // for (size_t i = 0; i < configSize; ++i) { 74 | // cppConfigs[i].saveTensors.assign(configs[i].saveTensors, 75 | // 
configs[i].saveTensors + configs[i].saveTensorsSize); 76 | // cppConfigs[i].type = configs[i].type; 77 | // cppConfigs[i].numThread = configs[i].numThread; 78 | // cppConfigs[i].path.inputs.assign(configs[i].path.inputs, 79 | // configs[i].path.inputs + configs[i].path.inputsSize); 80 | // cppConfigs[i].path.outputs.assign(configs[i].path.outputs, 81 | // configs[i].path.outputs + configs[i].path.outputsSize); 82 | // cppConfigs[i].path.mode = 83 | // static_cast(configs[i].path.mode); 84 | // cppConfigs[i].backupType = configs[i].backupType; 85 | // cppConfigs[i].backendConfig = configs[i].backendConfig; 86 | // } 87 | // auto runtimeInfo = MNN::Interpreter::createRuntime(cppConfigs); 88 | // return new RuntimeInfo{new std::map>(runtimeInfo.first), new 90 | // std::shared_ptr(runtimeInfo.second)}; 91 | // } 92 | Session *Interpreter_createSession(Interpreter *interpreter, 93 | const MNNScheduleConfig *config) { 94 | auto mnn_interpreter = reinterpret_cast(interpreter); 95 | auto mnn_schedule_config = 96 | reinterpret_cast(config); 97 | 98 | return reinterpret_cast( 99 | mnn_interpreter->createSession(*mnn_schedule_config)); 100 | } 101 | // Session* Interpreter_createSessionWithRuntime(Interpreter* interpreter, const 102 | // ScheduleConfig* config, const RuntimeInfo* runtime) { 103 | // MNN::ScheduleConfig cppConfig; 104 | // cppConfig.saveTensors.assign(config->saveTensors, config->saveTensors + 105 | // config->saveTensorsSize); cppConfig.type = config->type; 106 | // cppConfig.numThread = config->numThread; 107 | // cppConfig.path.inputs.assign(config->path.inputs, config->path.inputs + 108 | // config->path.inputsSize); 109 | // cppConfig.path.outputs.assign(config->path.outputs, config->path.outputs 110 | // + config->path.outputsSize); cppConfig.path.mode = 111 | // static_cast(config->path.mode); 112 | // cppConfig.backupType = config->backupType; 113 | // cppConfig.backendConfig = config->backendConfig; 114 | // return 
interpreter->createSession(cppConfig, *runtime); 115 | // } 116 | // Session *Interpreter_createMultiPathSession(Interpreter *interpreter, 117 | // const MNNScheduleConfig *configs, 118 | // size_t configSize) { 119 | // 120 | // auto mnn_configs = reinterpret_cast(configs); 121 | // std::vector cppConfigs(mnn_configs, 122 | // mnn_configs + configSize); 123 | // auto mnn_interpreter = reinterpret_cast(interpreter); 124 | // return reinterpret_cast( 125 | // mnn_interpreter->createMultiPathSession(cppConfigs)); 126 | // } 127 | Session * 128 | Interpreter_createMultiPathSession(Interpreter *interpreter, 129 | const MNNScheduleConfig *const *configs, 130 | size_t configSize) { 131 | auto mnn_configs = 132 | reinterpret_cast(configs); 133 | std::vector s_configs; 134 | for (size_t i = 0; i < configSize; ++i) { 135 | s_configs.push_back(*mnn_configs[i]); 136 | } 137 | // std::vector cppConfigs(mnn_configs, 138 | // mnn_configs + configSize); 139 | // Create a std::vector from 140 | // std::vector 141 | // auto s_configs = 142 | // std::vector(cppConfigs.begin(), cppConfigs.end()); 143 | auto mnn_interpreter = reinterpret_cast(interpreter); 144 | MNN::Session *session = mnn_interpreter->createMultiPathSession(s_configs); 145 | return reinterpret_cast(session); 146 | } 147 | 148 | // Session* Interpreter_createMultiPathSessionWithRuntime(Interpreter* 149 | // interpreter, const ScheduleConfig* configs, size_t configSize, const 150 | // RuntimeInfo* runtime) { 151 | // } 152 | int Interpreter_releaseSession(Interpreter *interpreter, Session *session) { 153 | auto mnn_interpreter = reinterpret_cast(interpreter); 154 | auto mnn_session = reinterpret_cast(session); 155 | return mnn_interpreter->releaseSession(mnn_session); 156 | } 157 | void Interpreter_resizeSession(Interpreter *interpreter, Session *session) { 158 | auto mnn_interpreter = reinterpret_cast(interpreter); 159 | auto mnn_session = reinterpret_cast(session); 160 | mnn_interpreter->resizeSession(mnn_session); 
161 | } 162 | void Interpreter_resizeSessionWithFlag(Interpreter *interpreter, 163 | Session *session, int needRelloc) { 164 | auto mnn_interpreter = reinterpret_cast(interpreter); 165 | auto mnn_session = reinterpret_cast(session); 166 | mnn_interpreter->resizeSession(mnn_session, needRelloc); 167 | } 168 | void Interpreter_releaseModel(Interpreter *interpreter) { 169 | auto mnn_interpreter = reinterpret_cast(interpreter); 170 | mnn_interpreter->releaseModel(); 171 | } 172 | // std::pair Interpreter_getModelBuffer(const Interpreter* 173 | // interpreter) { 174 | // auto mnn_interpreter = reinterpret_cast(interpreter); return mnn_interpreter->getModelBuffer(); 176 | // } 177 | const char *Interpreter_getModelVersion(const Interpreter *interpreter) { 178 | auto mnn_interpreter = 179 | reinterpret_cast(interpreter); 180 | return mnn_interpreter->getModelVersion(); 181 | } 182 | ErrorCode Interpreter_updateSessionToModel(Interpreter *interpreter, 183 | Session *session) { 184 | auto mnn_interpreter = reinterpret_cast(interpreter); 185 | auto mnn_session = reinterpret_cast(session); 186 | return static_cast( 187 | mnn_interpreter->updateSessionToModel(mnn_session)); 188 | } 189 | ErrorCode Interpreter_runSession(const Interpreter *interpreter, 190 | Session *session) { 191 | auto mnn_interpreter = 192 | reinterpret_cast(interpreter); 193 | auto mnn_session = reinterpret_cast(session); 194 | return static_cast(mnn_interpreter->runSession(mnn_session)); 195 | } 196 | // ErrorCode Interpreter_runSessionWithCallBack(const Interpreter *interpreter, 197 | // const Session *session, 198 | // void *before, void *end, 199 | // int sync) { 200 | // MNN::TensorCallBack beforeCpp = 201 | // [before](const std::vector &tensors, 202 | // const std::string &opName) { 203 | // if (before == nullptr) { 204 | // return true; 205 | // } 206 | // return static_cast(rust_closure_callback_runner( 207 | // before, reinterpret_cast(tensors.data()), 208 | // tensors.size(), opName.c_str())); 
209 | // }; 210 | // 211 | // MNN::TensorCallBack endCpp = [end](const std::vector 212 | // &tensors, 213 | // const std::string &opName) { 214 | // if (end == nullptr) { 215 | // return true; 216 | // } 217 | // return static_cast(rust_closure_callback_runner( 218 | // end, reinterpret_cast(tensors.data()), 219 | // tensors.size(), opName.c_str())); 220 | // }; 221 | // auto net = reinterpret_cast(interpreter); 222 | // auto sess = reinterpret_cast(session); 223 | // auto ret = net->runSessionWithCallBack(sess, beforeCpp, endCpp, 224 | // static_cast(sync)); 225 | // return static_cast(ret); 226 | // } 227 | 228 | ErrorCode Interpreter_runSessionWithCallBackInfo(const Interpreter *interpreter, 229 | const Session *session, 230 | void *before, void *end, 231 | int sync) { 232 | MNN::TensorCallBackWithInfo beforeCpp = 233 | [before](const std::vector &tensors, 234 | const MNN::OperatorInfo *op) { 235 | if (before == nullptr) { 236 | return true; 237 | } 238 | return static_cast(rust_closure_callback_runner_op( 239 | before, reinterpret_cast(tensors.data()), 240 | tensors.size(), reinterpret_cast(op))); 241 | }; 242 | MNN::TensorCallBackWithInfo endCpp = 243 | [end](const std::vector &tensors, 244 | const MNN::OperatorInfo *op) { 245 | if (end == nullptr) { 246 | return true; 247 | } 248 | return static_cast(rust_closure_callback_runner_op( 249 | end, reinterpret_cast(tensors.data()), 250 | tensors.size(), reinterpret_cast(op))); 251 | }; 252 | auto net = reinterpret_cast(interpreter); 253 | auto sess = reinterpret_cast(session); 254 | auto ret = net->runSessionWithCallBackInfo(sess, beforeCpp, endCpp, 255 | static_cast(sync)); 256 | return static_cast(ret); 257 | } 258 | 259 | Tensor *Interpreter_getSessionInput(Interpreter *interpreter, 260 | const Session *session, const char *name) { 261 | auto mnn_interpreter = reinterpret_cast(interpreter); 262 | auto mnn_session = reinterpret_cast(session); 263 | return reinterpret_cast( 264 | 
mnn_interpreter->getSessionInput(mnn_session, name)); 265 | } 266 | 267 | Tensor *Interpreter_getSessionOutput(Interpreter *interpreter, 268 | const Session *session, const char *name) { 269 | auto mnn_interpreter = reinterpret_cast(interpreter); 270 | auto mnn_session = reinterpret_cast(session); 271 | return reinterpret_cast( 272 | mnn_interpreter->getSessionOutput(mnn_session, name)); 273 | } 274 | int Interpreter_getSessionInfo(Interpreter *interpreter, const Session *session, 275 | int code, void *ptr) { 276 | auto mnn_interpreter = reinterpret_cast(interpreter); 277 | auto mnn_session = reinterpret_cast(session); 278 | auto ret = mnn_interpreter->getSessionInfo( 279 | mnn_session, static_cast(code), ptr); 280 | return static_cast(ret); 281 | } 282 | TensorInfoArray const * 283 | Interpreter_getSessionOutputAll(const Interpreter *interpreter, 284 | const Session *session) { 285 | auto mnn_interpreter = 286 | reinterpret_cast(interpreter); 287 | auto mnn_session = reinterpret_cast(session); 288 | auto outputMap = mnn_interpreter->getSessionOutputAll(mnn_session); 289 | auto out = createTensorInfoArray(outputMap.size()); 290 | size_t index = 0; 291 | for (const auto &entry : outputMap) { 292 | out->tensors[index].name = 293 | createCString(entry.first.c_str(), entry.first.size()); 294 | out->tensors[index].tensor = static_cast(entry.second); 295 | ++index; 296 | } 297 | return out; 298 | } 299 | TensorInfoArray const * 300 | Interpreter_getSessionInputAll(const Interpreter *interpreter, 301 | const Session *session) { 302 | auto mnn_interpreter = 303 | reinterpret_cast(interpreter); 304 | auto mnn_session = reinterpret_cast(session); 305 | auto inputMap = mnn_interpreter->getSessionInputAll(mnn_session); 306 | auto in = createTensorInfoArray(inputMap.size()); 307 | size_t index = 0; 308 | for (const auto &entry : inputMap) { 309 | in->tensors[index].name = 310 | createCString(entry.first.c_str(), entry.first.size()); 311 | in->tensors[index].tensor = 
static_cast(entry.second); 312 | ++index; 313 | } 314 | return in; 315 | } 316 | void Interpreter_resizeTensor(Interpreter *interpreter, Tensor *tensor, 317 | const int *dims, size_t dimsSize) { 318 | std::vector cppDims(dims, dims + dimsSize); 319 | auto mnn_interpreter = reinterpret_cast(interpreter); 320 | auto mnn_tensor = reinterpret_cast(tensor); 321 | mnn_interpreter->resizeTensor(mnn_tensor, cppDims); 322 | } 323 | void Interpreter_resizeTensorByNCHW(Interpreter *interpreter, Tensor *tensor, 324 | int batch, int channel, int height, 325 | int width) { 326 | auto mnn_interpreter = reinterpret_cast(interpreter); 327 | auto mnn_tensor = reinterpret_cast(tensor); 328 | mnn_interpreter->resizeTensor(mnn_tensor, batch, channel, height, width); 329 | } 330 | const Backend *Interpreter_getBackend(const Interpreter *interpreter, 331 | const Session *session, 332 | const Tensor *tensor) { 333 | auto mnn_interpreter = 334 | reinterpret_cast(interpreter); 335 | auto mnn_session = reinterpret_cast(session); 336 | auto mnn_tensor = reinterpret_cast(tensor); 337 | return reinterpret_cast( 338 | mnn_interpreter->getBackend(mnn_session, mnn_tensor)); 339 | } 340 | const char *Interpreter_bizCode(const Interpreter *interpreter) { 341 | auto mnn_interpreter = 342 | reinterpret_cast(interpreter); 343 | return mnn_interpreter->bizCode(); 344 | } 345 | const char *Interpreter_uuid(const Interpreter *interpreter) { 346 | auto mnn_interpreter = 347 | reinterpret_cast(interpreter); 348 | return mnn_interpreter->uuid(); 349 | } 350 | const char *OperatorInfo_name(const void *op) { 351 | return reinterpret_cast(op)->name().c_str(); 352 | } 353 | const char *OperatorInfo_type(const void *op) { 354 | return reinterpret_cast(op)->type().c_str(); 355 | } 356 | float OperatorInfo_flops(const void *op) { 357 | return reinterpret_cast(op)->flops(); 358 | } 359 | } // extern "C" 360 | -------------------------------------------------------------------------------- /src/schedule.rs: 
-------------------------------------------------------------------------------- 1 | use mnn_sys::*; 2 | use std::{ffi::CString, mem::ManuallyDrop}; 3 | 4 | use crate::{BackendConfig, prelude::*}; 5 | 6 | /// Backend used for running the model 7 | /// 8 | /// The `ForwardType` enum is used to specify the backend that will be used for forward computation 9 | /// in the MNN framework. Each variant corresponds to a different backend, which may be enabled 10 | /// or disabled based on the features enabled in the build configuration. 11 | /// 12 | /// # Variants 13 | /// 14 | /// - `All`: Use all available backends. 15 | /// - `Auto`: Automatically select the best backend based on the current environment and hardware. 16 | /// - `CPU`: Use the CPU for computation. 17 | /// - `Metal`: Use the Metal backend for computation (requires the `metal` feature). 18 | /// - `OpenCL`: Use the OpenCL backend for computation (requires the `opencl` feature). 19 | /// - `OpenGL`: Use the OpenGL backend for computation (requires the `opengl` feature). 20 | /// - `Vulkan`: Use the Vulkan backend for computation (requires the `vulkan` feature). 21 | /// - `CoreML`: Use the CoreML backend for computation (requires the `coreml` feature). 22 | /// 23 | /// # Example 24 | /// 25 | /// ```rust 26 | /// use mnn::schedule::ForwardType; 27 | /// 28 | /// let forward_type = ForwardType::Auto; 29 | /// println!("Selected forward type: {:?}", forward_type); 30 | /// ``` 31 | /// 32 | /// # Note 33 | /// 34 | /// The availability of certain variants depends on the features enabled during the build. 35 | /// For example, the `Metal` variant is only available if the `metal` feature is enabled. 36 | #[derive(Debug, Copy, Clone, Default, PartialEq, Eq)] 37 | #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 38 | pub enum ForwardType { 39 | /// Use all available backends. 
40 | All, 41 | #[default] 42 | /// Try to automatically select the best backend based on the current environment and hardware. 43 | Auto, 44 | /// Use the CPU for computation. 45 | CPU, 46 | #[cfg(feature = "metal")] 47 | /// Use the Metal backend for computation. 48 | Metal, 49 | #[cfg(feature = "opencl")] 50 | /// Use the OpenCL backend for computation. 51 | OpenCL, 52 | /// Use the Vulkan backend for computation. 53 | #[cfg(feature = "vulkan")] 54 | Vulkan, 55 | /// Use the CoreML backend for computation. 56 | #[cfg(feature = "coreml")] 57 | CoreML, 58 | } 59 | 60 | impl ForwardType { 61 | /// Convert the `ForwardType` enum to the corresponding C++ `MNNForwardType` enum. 62 | fn to_mnn_sys(self) -> MNNForwardType { 63 | match self { 64 | ForwardType::Auto => MNNForwardType::MNN_FORWARD_AUTO, 65 | ForwardType::All => MNNForwardType::MNN_FORWARD_ALL, 66 | ForwardType::CPU => MNNForwardType::MNN_FORWARD_CPU, 67 | #[cfg(feature = "metal")] 68 | ForwardType::Metal => MNNForwardType::MNN_FORWARD_METAL, 69 | #[cfg(feature = "opencl")] 70 | ForwardType::OpenCL => MNNForwardType::MNN_FORWARD_OPENCL, 71 | #[cfg(feature = "opengl")] 72 | ForwardType::OpenGL => MNNForwardType::MNN_FORWARD_OPENGL, 73 | #[cfg(feature = "vulkan")] 74 | ForwardType::Vulkan => MNNForwardType::MNN_FORWARD_VULKAN, 75 | #[cfg(feature = "coreml")] 76 | ForwardType::CoreML => MNNForwardType::MNN_FORWARD_NN, 77 | } 78 | } 79 | 80 | fn from_mnn_sys(mode: MNNForwardType) -> Self { 81 | match mode { 82 | MNNForwardType::MNN_FORWARD_AUTO => ForwardType::Auto, 83 | MNNForwardType::MNN_FORWARD_ALL => ForwardType::All, 84 | MNNForwardType::MNN_FORWARD_CPU => ForwardType::CPU, 85 | #[cfg(feature = "metal")] 86 | MNNForwardType::MNN_FORWARD_METAL => ForwardType::Metal, 87 | #[cfg(feature = "opencl")] 88 | MNNForwardType::MNN_FORWARD_OPENCL => ForwardType::OpenCL, 89 | #[cfg(feature = "opengl")] 90 | MNNForwardType::MNN_FORWARD_OPENGL => ForwardType::OpenGL, 91 | #[cfg(feature = "vulkan")] 92 | 
MNNForwardType::MNN_FORWARD_VULKAN => ForwardType::Vulkan, 93 | #[cfg(feature = "coreml")] 94 | MNNForwardType::MNN_FORWARD_NN => ForwardType::CoreML, 95 | _ => ForwardType::Auto, 96 | } 97 | } 98 | 99 | /// List all available `ForwardType` variants as string slices. 100 | fn list() -> Vec<&'static str> { 101 | vec![ 102 | "auto", 103 | "all", 104 | "cpu", 105 | #[cfg(feature = "metal")] 106 | "metal", 107 | #[cfg(feature = "opencl")] 108 | "opencl", 109 | #[cfg(feature = "opengl")] 110 | "opengl", 111 | #[cfg(feature = "vulkan")] 112 | "vulkan", 113 | #[cfg(feature = "coreml")] 114 | "coreml", 115 | ] 116 | } 117 | 118 | /// Convert the `ForwardType` enum to a string slice. 119 | pub fn to_str(self) -> &'static str { 120 | match self { 121 | ForwardType::Auto => "auto", 122 | ForwardType::All => "all", 123 | ForwardType::CPU => "cpu", 124 | #[cfg(feature = "metal")] 125 | ForwardType::Metal => "metal", 126 | #[cfg(feature = "opencl")] 127 | ForwardType::OpenCL => "opencl", 128 | #[cfg(feature = "opengl")] 129 | ForwardType::OpenGL => "opengl", 130 | #[cfg(feature = "vulkan")] 131 | ForwardType::Vulkan => "vulkan", 132 | #[cfg(feature = "coreml")] 133 | ForwardType::CoreML => "coreml", 134 | } 135 | } 136 | } 137 | 138 | impl core::str::FromStr for ForwardType { 139 | type Err = MNNError; 140 | 141 | fn from_str(s: &str) -> Result { 142 | match s { 143 | "auto" => Ok(ForwardType::Auto), 144 | "all" => Ok(ForwardType::All), 145 | "cpu" => Ok(ForwardType::CPU), 146 | #[cfg(feature = "metal")] 147 | "metal" => Ok(ForwardType::Metal), 148 | #[cfg(feature = "opencl")] 149 | "opencl" => Ok(ForwardType::OpenCL), 150 | #[cfg(feature = "opengl")] 151 | "opengl" => Ok(ForwardType::OpenGL), 152 | #[cfg(feature = "vulkan")] 153 | "vulkan" => Ok(ForwardType::Vulkan), 154 | #[cfg(feature = "coreml")] 155 | "coreml" => Ok(ForwardType::CoreML), 156 | _ => Err(MNNError::new(crate::ErrorKind::ParseError) 157 | .attach_printable(format!( 158 | "Invalid ForwardType: {s}, maybe you 
might need to enable feature {s}" 159 | )) 160 | .attach_printable(format!( 161 | "Valid ForwardType: {}", 162 | ForwardType::list().join(", ") 163 | ))), 164 | } 165 | } 166 | } 167 | 168 | /// Configuration for scheduling the forward computation in MNN. 169 | /// 170 | /// The `ScheduleConfig` struct is used to configure various parameters for scheduling the forward 171 | /// computation in the MNN framework. It allows setting the type of backend, the number of threads, 172 | /// the mode of computation, and other options. 173 | /// 174 | /// # Example 175 | /// 176 | /// ```rust 177 | /// use mnn::schedule::{ScheduleConfig, ForwardType}; 178 | /// 179 | /// let mut config = ScheduleConfig::new(); 180 | /// config.set_type(ForwardType::Auto); 181 | /// config.set_num_threads(4); 182 | /// config.set_mode(0); 183 | /// ``` 184 | /// 185 | /// # Fields 186 | /// 187 | /// - `inner`: A raw pointer to the underlying `MNNScheduleConfig` structure. 188 | /// - `backend_config`: Specifies backend-specific configurations. 189 | /// - `__marker`: A marker to ensure the struct is `!Send` by default. 190 | /// 191 | /// # Methods 192 | /// 193 | /// - `new() -> Self`: Creates a new `ScheduleConfig` with default settings. 194 | /// - `as_ptr_mut(&mut self) -> *mut MNNScheduleConfig`: Returns a mutable raw pointer to the underlying `MNNScheduleConfig`. 195 | /// - `set_save_tensors(&mut self, save_tensors: &[&str]) -> Result<()>`: Sets the tensors to be saved during computation. 196 | /// - `set_type(&mut self, forward_type: ForwardType)`: Sets the type of backend to be used for computation. 197 | /// - `set_num_threads(&mut self, num_threads: i32)`: Sets the number of threads to be used for computation. 198 | /// - `set_mode(&mut self, mode: i32)`: Sets the mode of computation. 199 | /// - `set_backup_type(&mut self, backup_type: ForwardType)`: Sets the backup type of backend to be used if the primary backend fails. 
200 | /// - `set_backend_config(&mut self, backend_config: impl Into>)`: Sets the backend-specific configuration. 201 | /// 202 | /// # Safety 203 | /// 204 | /// The `ScheduleConfig` struct contains raw pointers and interacts with the underlying C API of MNN. 205 | /// Users should be cautious when using this struct to avoid undefined behavior. 206 | /// 207 | /// # Warning 208 | /// 209 | /// **Warning:** The `Drop` implementation for `ScheduleConfig` ensures that the underlying `MNNScheduleConfig` 210 | /// is properly destroyed when the struct goes out of scope. Users should not manually free the `inner` pointer. 211 | pub struct ScheduleConfig { 212 | pub(crate) inner: *mut MNNScheduleConfig, 213 | pub(crate) backend_config: Option, 214 | pub(crate) __marker: core::marker::PhantomData<()>, 215 | } 216 | 217 | impl core::fmt::Debug for ScheduleConfig { 218 | fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { 219 | f.debug_struct("ScheduleConfig") 220 | .field("type", &self.get_type()) 221 | .field("backup_type", &self.get_backup_type()) 222 | .field("backend_config", &self.backend_config) 223 | .finish() 224 | } 225 | } 226 | 227 | #[cfg(feature = "serde")] 228 | impl serde::Serialize for ScheduleConfig { 229 | fn serialize(&self, serializer: S) -> Result { 230 | use serde::ser::SerializeStruct; 231 | let mut state = serializer.serialize_struct("ScheduleConfig", 3)?; 232 | state.serialize_field("type", &self.get_type())?; 233 | state.serialize_field("backup_type", &self.get_backup_type())?; 234 | state.serialize_field("backend_config", &self.backend_config)?; 235 | state.end() 236 | } 237 | } 238 | 239 | impl Clone for ScheduleConfig { 240 | fn clone(&self) -> Self { 241 | unsafe { 242 | let inner = mnnsc_clone(self.inner); 243 | Self { 244 | inner, 245 | backend_config: self.backend_config.clone(), 246 | __marker: core::marker::PhantomData, 247 | } 248 | } 249 | } 250 | } 251 | 252 | impl Drop for ScheduleConfig { 253 | fn drop(&mut self) { 
254 | unsafe { 255 | mnn_sys::mnnsc_destroy(self.inner); 256 | } 257 | } 258 | } 259 | 260 | unsafe impl Send for ScheduleConfig {} 261 | 262 | impl Default for ScheduleConfig { 263 | fn default() -> Self { 264 | Self::new() 265 | } 266 | } 267 | 268 | impl ScheduleConfig { 269 | /// Returns a mutable raw pointer to the underlying `MNNScheduleConfig`. 270 | pub fn as_ptr_mut(&mut self) -> *mut MNNScheduleConfig { 271 | self.inner 272 | } 273 | 274 | /// Creates a new `ScheduleConfig` with default settings. 275 | pub fn new() -> Self { 276 | unsafe { 277 | let inner = mnnsc_create(); 278 | Self { 279 | inner, 280 | backend_config: None, 281 | __marker: core::marker::PhantomData, 282 | } 283 | } 284 | } 285 | 286 | /// Sets the tensors to be saved during computation. 287 | /// 288 | /// # Arguments 289 | /// 290 | /// - `save_tensors`: A slice of tensor names to be saved. 291 | /// 292 | /// # Errors 293 | /// 294 | /// Returns an error if any of the tensor names contain null bytes. 295 | pub fn set_save_tensors(&mut self, save_tensors: &[&str]) -> Result<&mut Self> { 296 | let vec_cstring = save_tensors 297 | .iter() 298 | .map(|s| std::ffi::CString::new(*s).map_err(|e| error!(ErrorKind::AsciiError, e))) 299 | .collect::>>()?; 300 | let vec_cstr = vec_cstring 301 | .iter() 302 | .map(|s: &CString| s.as_c_str().as_ptr()) 303 | .collect::>(); 304 | unsafe { mnnsc_set_save_tensors(self.inner, vec_cstr.as_ptr(), vec_cstr.len()) } 305 | Ok(self) 306 | } 307 | 308 | /// Sets the type of backend to be used for computation. 309 | /// 310 | /// # Arguments 311 | /// 312 | /// - `forward_type`: The type of backend to be used. 313 | pub fn set_type(&mut self, forward_type: ForwardType) -> &mut Self { 314 | unsafe { 315 | mnnsc_set_type(self.inner, forward_type.to_mnn_sys()); 316 | } 317 | self 318 | } 319 | 320 | /// Sets the type of backend to be used for computation. 
321 | pub fn with_type(mut self, forward_type: ForwardType) -> Self { 322 | self.set_type(forward_type); 323 | self 324 | } 325 | 326 | /// Gets the type of backend to be used for computation. 327 | pub fn get_type(&self) -> ForwardType { 328 | unsafe { ForwardType::from_mnn_sys(mnnsc_get_type(self.inner)) } 329 | } 330 | 331 | /// Sets the number of threads to be used for computation. 332 | /// 333 | /// # Arguments 334 | /// 335 | /// - `num_threads`: The number of threads to be used. 336 | pub fn set_num_threads(&mut self, num_threads: i32) -> &mut Self { 337 | unsafe { 338 | mnnsc_set_num_threads(self.inner, num_threads); 339 | } 340 | self 341 | } 342 | 343 | /// Sets the number of threads to be used for computation. 344 | pub fn with_num_threads(mut self, num_threads: i32) -> Self { 345 | self.set_num_threads(num_threads); 346 | self 347 | } 348 | 349 | /// Sets the mode of computation. 350 | /// 351 | /// # Arguments 352 | /// 353 | /// - `mode`: The mode of computation. 354 | pub fn set_mode(&mut self, mode: i32) -> &mut Self { 355 | unsafe { 356 | mnnsc_set_mode(self.inner, mode); 357 | } 358 | self 359 | } 360 | 361 | /// Sets the mode of computation. 362 | pub fn with_mode(mut self, mode: i32) -> Self { 363 | self.set_mode(mode); 364 | self 365 | } 366 | 367 | /// Sets the backup type of backend to be used if the primary backend fails. 368 | /// 369 | /// # Arguments 370 | /// 371 | /// - `backup_type`: The backup type of backend to be used. 372 | pub fn set_backup_type(&mut self, backup_type: ForwardType) -> &mut Self { 373 | unsafe { 374 | mnnsc_set_backup_type(self.inner, backup_type.to_mnn_sys()); 375 | } 376 | self 377 | } 378 | 379 | /// Sets the backup type of backend to be used if the primary backend fails. 380 | pub fn with_backup_type(mut self, backup_type: ForwardType) -> Self { 381 | self.set_backup_type(backup_type); 382 | self 383 | } 384 | 385 | /// Gets the backup type of backend to be used if the primary backend fails. 
386 | pub fn get_backup_type(&self) -> ForwardType { 387 | unsafe { ForwardType::from_mnn_sys(mnnsc_get_backup_type(self.inner)) } 388 | } 389 | 390 | /// Sets the backend-specific configuration. 391 | /// 392 | /// # Arguments 393 | /// 394 | /// - `backend_config`: specifies additional backend-specific configurations. 395 | pub fn set_backend_config( 396 | &mut self, 397 | backend_config: impl Into>, 398 | ) -> &mut Self { 399 | self.backend_config = backend_config.into(); 400 | let ptr = if let Some(ref b) = self.backend_config { 401 | b.inner 402 | } else { 403 | core::ptr::null_mut() 404 | }; 405 | unsafe { 406 | mnnsc_set_backend_config(self.inner, ptr); 407 | } 408 | self 409 | } 410 | 411 | /// Sets the backend-specific configuration. 412 | pub fn with_backend_config(mut self, backend_config: impl Into>) -> Self { 413 | self.set_backend_config(backend_config); 414 | self 415 | } 416 | } 417 | 418 | /// A list of `ScheduleConfig` objects to be used for scheduling the forward computation in MNN. 419 | #[derive(Debug)] 420 | pub struct ScheduleConfigs { 421 | pub(crate) inner: Vec<*const MNNScheduleConfig>, 422 | pub(crate) backend_configs: Vec>, 423 | } 424 | 425 | impl Drop for ScheduleConfigs { 426 | fn drop(&mut self) { 427 | unsafe { 428 | for i in self.inner.iter() { 429 | mnnsc_destroy(*i.cast()); 430 | } 431 | } 432 | } 433 | } 434 | 435 | impl ScheduleConfigs { 436 | /// Pushed a new `ScheduleConfig` to the list of configurations. 437 | pub fn push(&mut self, config: ScheduleConfig) { 438 | let mut config = ManuallyDrop::new(config); 439 | self.inner.push(config.inner); 440 | self.backend_configs.push(config.backend_config.take()); 441 | } 442 | 443 | /// Creates a new (empty) `ScheduleConfigs` with the specified capacity. 
444 | pub fn with_capacity(capacity: usize) -> Self { 445 | Self { 446 | inner: Vec::with_capacity(capacity), 447 | backend_configs: Vec::with_capacity(capacity), 448 | } 449 | } 450 | 451 | /// Creates a new (empty) `ScheduleConfigs` with default settings. 452 | pub const fn new() -> Self { 453 | Self { 454 | inner: Vec::new(), 455 | backend_configs: Vec::new(), 456 | } 457 | } 458 | } 459 | 460 | impl Default for ScheduleConfigs { 461 | fn default() -> Self { 462 | Self::new() 463 | } 464 | } 465 | 466 | impl FromIterator for ScheduleConfigs { 467 | fn from_iter>(iter: T) -> Self { 468 | let iter = iter.into_iter(); 469 | let mut ret = Self::with_capacity(iter.size_hint().1.unwrap_or_default()); 470 | iter.for_each(|item| { 471 | ret.push(item); 472 | }); 473 | ret 474 | } 475 | } 476 | 477 | unsafe impl Send for ScheduleConfigs {} 478 | -------------------------------------------------------------------------------- /mnn-sys/build.rs: -------------------------------------------------------------------------------- 1 | use ::tap::*; 2 | use anyhow::*; 3 | #[cfg(unix)] 4 | use std::os::unix::fs::PermissionsExt; 5 | use std::{ 6 | path::{Path, PathBuf}, 7 | sync::LazyLock, 8 | }; 9 | const VENDOR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/vendor"); 10 | const MANIFEST_DIR: &str = env!("CARGO_MANIFEST_DIR"); 11 | static TARGET_OS: LazyLock = 12 | LazyLock::new(|| std::env::var("CARGO_CFG_TARGET_OS").expect("CARGO_CFG_TARGET_OS not set")); 13 | static TARGET_ARCH: LazyLock = LazyLock::new(|| { 14 | std::env::var("CARGO_CFG_TARGET_ARCH").expect("CARGO_CFG_TARGET_ARCH not found") 15 | }); 16 | static EMSCRIPTEN_CACHE: LazyLock = LazyLock::new(|| { 17 | let emscripten_cache = std::process::Command::new("em-config") 18 | .arg("CACHE") 19 | .output() 20 | .expect("Failed to get emscripten cache") 21 | .stdout; 22 | let emscripten_cache = std::str::from_utf8(&emscripten_cache) 23 | .expect("Failed to parse emscripten cache") 24 | .trim() 25 | .to_string(); 26 | 
// (tail of the EMSCRIPTEN_CACHE initializer, continued from the lines above)
    emscripten_cache
});

/// Whether to compile MNN from the vendored sources.
///
/// Controlled by the `MNN_COMPILE` env var ("1"/"true"/"yes" vs
/// "0"/"false"/"no"); unset or unrecognized values default to `true`.
static MNN_COMPILE: LazyLock<bool> = LazyLock::new(|| {
    std::env::var("MNN_COMPILE")
        .ok()
        .and_then(|v| match v.as_str() {
            "1" | "true" | "yes" => Some(true),
            "0" | "false" | "no" => Some(false),
            _ => None,
        })
        .unwrap_or(true)
});

/// Anchor line searched for in `HalideRuntime.h` when applying the
/// 64-bit `halide_type_t` patch in `main`.
const HALIDE_SEARCH: &str =
    r#"HALIDE_ATTRIBUTE_ALIGN(1) halide_type_code_t code; // halide_type_code_t"#;
/// The stock MNN logging macros that get swapped for the tracing shim below.
const TRACING_SEARCH: &str = "#define MNN_PRINT(format, ...) printf(format, ##__VA_ARGS__)\n#define MNN_ERROR(format, ...) printf(format, ##__VA_ARGS__)";
// NOTE(review): the exact intra-line whitespace of this C++ snippet (alignment
// of the `\` line continuations) was lost in extraction; token content below is
// faithful, the alignment is reconstructed — confirm against the repo.
const TRACING_REPLACE: &str = r#"
enum class Level {
  Info = 0,
  Error = 1,
};
extern "C" {
void mnn_ffi_emit(const char *file, size_t line, Level level,
                  const char *message);
}
#define MNN_PRINT(format, ...) \
  { \
    char logtmp[4096]; \
    snprintf(logtmp, 4096, format, ##__VA_ARGS__); \
    mnn_ffi_emit(__FILE__, __LINE__, Level::Info, logtmp); \
  }

#define MNN_ERROR(format, ...) \
  { \
    char logtmp[4096]; \
    snprintf(logtmp, 4096, format, ##__VA_ARGS__); \
    mnn_ffi_emit(__FILE__, __LINE__, Level::Error, logtmp); \
  }
"#;

/// Fails the build early with a helpful message when the MNN submodule has
/// not been checked out (vendor dir missing or empty).
fn ensure_vendor_exists(vendor: impl AsRef<Path>) -> Result<()> {
    if vendor
        .as_ref()
        .read_dir()
        .with_context(|| format!("Vendor directory missing: {}", vendor.as_ref().display()))?
        .flatten()
        .count()
        == 0
    {
        anyhow::bail!("Vendor not found maybe you need to run \"git submodule update --init\"")
    }
    Ok(())
}

/// Build-script entry point: copies + patches the vendored MNN sources into
/// OUT_DIR, optionally compiles MNN via cmake, builds the C shim, generates
/// bindgen bindings, and emits all cargo link directives.
fn main() -> Result<()> {
    println!("cargo:rerun-if-changed=build.rs");
    println!("cargo:rerun-if-env-changed=MNN_SRC");
    let out_dir = PathBuf::from(std::env::var("OUT_DIR")?);
    // MNN_SRC overrides the vendored submodule location.
    let source = PathBuf::from(
        std::env::var("MNN_SRC")
            .ok()
            .unwrap_or_else(|| VENDOR.into()),
    );

    ensure_vendor_exists(&source)?;

    let vendor = out_dir.join("vendor");
    // The copy + patch step runs only once; an existing OUT_DIR/vendor is
    // assumed to be already patched.
    if !vendor.exists() {
        fs_extra::dir::copy(
            &source,
            &vendor,
            &fs_extra::dir::CopyOptions::new()
                .overwrite(true)
                .copy_inside(true),
        )
        .context("Failed to copy vendor")?;
        let intptr = vendor.join("include").join("MNN").join("HalideRuntime.h");
        // The vendored files may be read-only; make the header writable first.
        #[cfg(unix)]
        std::fs::set_permissions(&intptr, std::fs::Permissions::from_mode(0o644))?;

        use itertools::Itertools;
        let intptr_contents = std::fs::read_to_string(&intptr)?;
        let patched = intptr_contents.lines().collect::<Vec<_>>();
        if let Some((idx, _)) = patched
            .iter()
            .find_position(|line| line.contains(HALIDE_SEARCH))
        {
            // Drop the line preceding the anchor and the 3 lines after it
            // (the halide_type_t bit-field packing this crate patches away).
            let patched = patched
                .into_iter()
                .enumerate()
                .filter(|(c_idx, _)| !(*c_idx == idx - 1 || (idx + 1..=idx + 3).contains(c_idx)))
                .map(|(_, c)| c)
                .collect::<Vec<_>>();

            std::fs::write(intptr, patched.join("\n"))?;
        }

        // Reroute MNN's printf-based logging through the tracing FFI shim.
        let mnn_define = vendor.join("include").join("MNN").join("MNNDefine.h");
        let patched =
            std::fs::read_to_string(&mnn_define)?.replace(TRACING_SEARCH, TRACING_REPLACE);
        #[cfg(unix)]
        std::fs::set_permissions(&mnn_define, std::fs::Permissions::from_mode(0o644))?;
        std::fs::write(mnn_define, patched)?;
    }

    if *MNN_COMPILE {
        let install_dir = out_dir.join("mnn-install");
        build_cmake(&vendor, &install_dir)?;
        println!(
            "cargo:rustc-link-search=native={}",
            install_dir.join("lib").display()
        );
    } else if let core::result::Result::Ok(lib_dir) = std::env::var("MNN_LIB_DIR") {
        println!("cargo:rustc-link-search=native={}", lib_dir);
    } else {
        // Build scripts have no caller to return errors to; panic aborts the
        // build with a clear message.
        panic!("MNN_LIB_DIR not set while MNN_COMPILE is false");
    }

    mnn_c_build(PathBuf::from(MANIFEST_DIR).join("mnn_c"), &vendor)
        .with_context(|| "Failed to build mnn_c")?;
    mnn_c_bindgen(&vendor, &out_dir).with_context(|| "Failed to generate mnn_c bindings")?;
    mnn_cpp_bindgen(&vendor, &out_dir).with_context(|| "Failed to generate mnn_cpp bindings")?;
    println!("cargo:include={vendor}/include", vendor = vendor.display());
    if *TARGET_OS == "macos" {
        // Apple frameworks required by the optional GPU/ML backends.
        #[cfg(feature = "metal")]
        println!("cargo:rustc-link-lib=framework=Foundation");
        #[cfg(feature = "metal")]
        println!("cargo:rustc-link-lib=framework=CoreGraphics");
        #[cfg(feature = "metal")]
        println!("cargo:rustc-link-lib=framework=Metal");
        #[cfg(feature = "coreml")]
        println!("cargo:rustc-link-lib=framework=CoreML");
        #[cfg(feature = "coreml")]
        println!("cargo:rustc-link-lib=framework=CoreVideo");
        #[cfg(feature = "opencl")]
        println!("cargo:rustc-link-lib=framework=OpenCL");
        #[cfg(feature = "opengl")]
        println!("cargo:rustc-link-lib=framework=OpenGL");
    } else {
        // #[cfg(feature = "opencl")]
        // println!("cargo:rustc-link-lib=static=opencl");
    }
    if is_emscripten() {
        // println!("cargo:rustc-link-lib=static=stdc++");
        let emscripten_cache = std::process::Command::new("em-config")
            .arg("CACHE")
            .output()?
            .stdout;
        let emscripten_cache = std::str::from_utf8(&emscripten_cache)?.trim();
        let wasm32_emscripten_libs =
            PathBuf::from(emscripten_cache).join("sysroot/lib/wasm32-emscripten");
        println!(
            "cargo:rustc-link-search=native={}",
            wasm32_emscripten_libs.display()
        );
    }
    println!("cargo:rustc-link-lib=static=MNN");
    Ok(())
}

/// Generates the Rust bindings for the hand-written C shim in `mnn_c/`.
///
/// Registers rerun-if-changed for every shim file, then feeds the shim headers
/// through bindgen with the same feature flags the C++ build uses.
pub fn mnn_c_bindgen(vendor: impl AsRef<Path>, out: impl AsRef<Path>) -> Result<()> {
    let vendor = vendor.as_ref();
    let mnn_c = PathBuf::from(MANIFEST_DIR).join("mnn_c");
    mnn_c.read_dir()?.flatten().for_each(|e| {
        rerun_if_changed(e.path());
    });
    const HEADERS: &[&str] = &[
        "error_code_c.h",
        "interpreter_c.h",
        "tensor_c.h",
        "backend_c.h",
        "schedule_c.h",
    ];

    let bindings = bindgen::Builder::default()
        // .clang_args(["-x", "c++"])
        .clang_arg(CxxOption::VULKAN.cxx())
        .clang_arg(CxxOption::METAL.cxx())
        .clang_arg(CxxOption::COREML.cxx())
        .clang_arg(CxxOption::OPENCL.cxx())
        .pipe(|builder| {
            if is_emscripten() {
                println!("cargo:rustc-cdylib-link-arg=-fvisibility=default");
                builder
                    .clang_arg("-fvisibility=default")
                    .clang_arg("--target=wasm32-emscripten")
                    .clang_arg(format!("-I{}/sysroot/include", emscripten_cache()))
            } else {
                builder
            }
        })
        .clang_arg(format!("-I{}", vendor.join("include").to_string_lossy()))
        .pipe(|generator| {
            HEADERS.iter().fold(generator, |gen, header| {
                gen.header(mnn_c.join(header).to_string_lossy())
            })
        })
        .newtype_enum("MemoryMode")
        .newtype_enum("PowerMode")
        .newtype_enum("PrecisionMode")
        .constified_enum_module("SessionMode")
        .rustified_enum("DimensionType")
        .rustified_enum("HandleDataType")
        .rustified_enum("MapType")
        .rustified_enum("halide_type_code_t")
        .rustified_enum("ErrorCode")
.rustified_enum("MNNGpuMode") 236 | .rustified_enum("MNNForwardType") 237 | .rustified_enum("RuntimeStatus") 238 | .no_copy("CString") 239 | .generate_cstr(true) 240 | .generate_inline_functions(true) 241 | .size_t_is_usize(true) 242 | .emit_diagnostics() 243 | .detect_include_paths(std::env::var("TARGET") == std::env::var("HOST")) 244 | .ctypes_prefix("core::ffi") 245 | // .tap(|d| { 246 | // // eprintln!("Full bindgen: {}", d.command_line_flags().join(" ")); 247 | // std::fs::write("bindgen.txt", d.command_line_flags().join(" ")).ok(); 248 | // }) 249 | .generate()?; 250 | bindings.write_to_file(out.as_ref().join("mnn_c.rs"))?; 251 | Ok(()) 252 | } 253 | 254 | pub fn mnn_cpp_bindgen(vendor: impl AsRef, out: impl AsRef) -> Result<()> { 255 | let vendor = vendor.as_ref(); 256 | let bindings = bindgen::Builder::default() 257 | .clang_args(["-x", "c++"]) 258 | .clang_args(["-std=c++14"]) 259 | .clang_arg(CxxOption::VULKAN.cxx()) 260 | .clang_arg(CxxOption::METAL.cxx()) 261 | .clang_arg(CxxOption::COREML.cxx()) 262 | .clang_arg(CxxOption::OPENCL.cxx()) 263 | .clang_arg(format!("-I{}", vendor.join("include").to_string_lossy())) 264 | .generate_cstr(true) 265 | .generate_inline_functions(true) 266 | .size_t_is_usize(true) 267 | .emit_diagnostics() 268 | .ctypes_prefix("core::ffi") 269 | .header( 270 | vendor 271 | .join("include") 272 | .join("MNN") 273 | .join("Interpreter.hpp") 274 | .to_string_lossy(), 275 | ) 276 | .allowlist_item(".*SessionInfoCode.*"); 277 | // let cmd = bindings.command_line_flags().join(" "); 278 | // println!("cargo:warn=bindgen: {}", cmd); 279 | let bindings = bindings.generate()?; 280 | bindings.write_to_file(out.as_ref().join("mnn_cpp.rs"))?; 281 | Ok(()) 282 | } 283 | 284 | pub fn mnn_c_build(path: impl AsRef, vendor: impl AsRef) -> Result<()> { 285 | let mnn_c = path.as_ref(); 286 | let files = mnn_c.read_dir()?.flatten().map(|e| e.path()).filter(|e| { 287 | e.extension() == Some(std::ffi::OsStr::new("cpp")) 288 | || e.extension() == 
Some(std::ffi::OsStr::new("c")) 289 | }); 290 | let vendor = vendor.as_ref(); 291 | cc::Build::new() 292 | .include(vendor.join("include")) 293 | // .includes(vulkan_includes(vendor)) 294 | .pipe(|config| { 295 | #[cfg(feature = "vulkan")] 296 | config.define("MNN_VULKAN", "1"); 297 | #[cfg(feature = "opengl")] 298 | config.define("MNN_OPENGL", "1"); 299 | #[cfg(feature = "metal")] 300 | config.define("MNN_METAL", "1"); 301 | #[cfg(feature = "coreml")] 302 | config.define("MNN_COREML", "1"); 303 | #[cfg(feature = "opencl")] 304 | config.define("MNN_OPENCL", "ON"); 305 | if is_emscripten() { 306 | config.compiler("emcc"); 307 | // We can't compile wasm32-unknown-unknown with emscripten 308 | config.target("wasm32-unknown-emscripten"); 309 | config.cpp_link_stdlib("c++-noexcept"); 310 | } 311 | #[cfg(feature = "crt_static")] 312 | config.static_crt(true); 313 | config 314 | }) 315 | .cpp(true) 316 | .static_flag(true) 317 | .files(files) 318 | .std("c++14") 319 | // .pipe(|build| { 320 | // let c = build.get_compiler(); 321 | // use std::io::Write; 322 | // writeln!( 323 | // std::fs::File::create("./command.txt").unwrap(), 324 | // "{:?}", 325 | // c.to_command() 326 | // ) 327 | // .unwrap(); 328 | // build 329 | // }) 330 | .try_compile("mnn_c") 331 | .context("Failed to compile mnn_c library")?; 332 | Ok(()) 333 | } 334 | 335 | pub fn build_cmake(path: impl AsRef, install: impl AsRef) -> Result<()> { 336 | let threads = std::thread::available_parallelism()?; 337 | cmake::Config::new(path) 338 | .define("CMAKE_CXX_STANDARD", "14") 339 | .parallel(threads.get() as u8) 340 | .define("MNN_BUILD_SHARED_LIBS", "OFF") 341 | .define("MNN_SEP_BUILD", "OFF") 342 | .define("MNN_PORTABLE_BUILD", "ON") 343 | .define("MNN_USE_SYSTEM_LIB", "OFF") 344 | .define("MNN_BUILD_CONVERTER", "OFF") 345 | .define("MNN_BUILD_TOOLS", "OFF") 346 | .define("CMAKE_INSTALL_PREFIX", install.as_ref()) 347 | // https://github.com/rust-lang/rust/issues/39016 348 | // 
https://github.com/rust-lang/cc-rs/pull/717 349 | // .define("CMAKE_BUILD_TYPE", "Release") 350 | .pipe(|config| { 351 | config.define("MNN_WIN_RUNTIME_MT", CxxOption::CRT_STATIC.cmake_value()); 352 | config.define("MNN_USE_THREAD_POOL", CxxOption::THREADPOOL.cmake_value()); 353 | config.define("MNN_OPENMP", CxxOption::OPENMP.cmake_value()); 354 | config.define("MNN_VULKAN", CxxOption::VULKAN.cmake_value()); 355 | config.define("MNN_METAL", CxxOption::METAL.cmake_value()); 356 | config.define("MNN_COREML", CxxOption::COREML.cmake_value()); 357 | config.define("MNN_OPENCL", CxxOption::OPENCL.cmake_value()); 358 | config.define("MNN_OPENGL", CxxOption::OPENGL.cmake_value()); 359 | // config.define("CMAKE_CXX_FLAGS", "-O0"); 360 | // #[cfg(windows)] 361 | if *TARGET_OS == "windows" { 362 | config.define("CMAKE_CXX_FLAGS", "-DWIN32=1"); 363 | } 364 | 365 | if is_emscripten() { 366 | config 367 | .define("CMAKE_C_COMPILER", "emcc") 368 | .define("CMAKE_CXX_COMPILER", "em++") 369 | .target("wasm32-unknown-emscripten"); 370 | } 371 | config 372 | }) 373 | .build(); 374 | Ok(()) 375 | } 376 | 377 | // pub fn try_patch_file(patch: impl AsRef, file: impl AsRef) -> Result<()> { 378 | // let patch = dunce::canonicalize(patch)?; 379 | // rerun_if_changed(&patch); 380 | // let patch = std::fs::read_to_string(&patch)?; 381 | // let patch = diffy::Patch::from_str(&patch)?; 382 | // let file_path = file.as_ref(); 383 | // let file = std::fs::read_to_string(file_path).context("Failed to read input file")?; 384 | // let patched_file = 385 | // diffy::apply(&file, &patch).context("Failed to apply patches using diffy")?; 386 | // std::fs::write(file_path, patched_file)?; 387 | // Ok(()) 388 | // } 389 | 390 | pub fn rerun_if_changed(path: impl AsRef) { 391 | println!("cargo:rerun-if-changed={}", path.as_ref().display()); 392 | } 393 | 394 | // pub fn vulkan_includes(vendor: impl AsRef) -> Vec { 395 | // let vendor = vendor.as_ref(); 396 | // let vulkan_dir = 
vendor.join("source/backend/vulkan");
//     if cfg!(feature = "vulkan") {
//         vec![
//             vulkan_dir.clone(),
//             vulkan_dir.join("runtime"),
//             vulkan_dir.join("component"),
//             // IDK If the order is important but the cmake file does it like this
//             vulkan_dir.join("buffer/execution"),
//             vulkan_dir.join("buffer/backend"),
//             vulkan_dir.join("buffer"),
//             vulkan_dir.join("buffer/shaders"),
//             vendor.join("schema/current"),
//             vendor.join("3rd_party/flatbuffers/include"),
//             vendor.join("source"),
//         ]
//     } else {
//         vec![]
//     }
// }

/// True when we are building for the `wasm32-emscripten` target.
pub fn is_emscripten() -> bool {
    *TARGET_OS == "emscripten" && *TARGET_ARCH == "wasm32"
}

/// Emscripten cache root (resolved lazily elsewhere, via `em-config CACHE`).
pub fn emscripten_cache() -> &'static str {
    &EMSCRIPTEN_CACHE
}

/// Tri-state value of a C/C++/CMake build option.
#[derive(Debug, Clone, Copy)]
pub enum CxxOptionValue {
    On,
    Off,
    /// Free-form setting (a path, a number, …) passed through verbatim.
    Value(&'static str),
}

impl From<bool> for CxxOptionValue {
    fn from(b: bool) -> Self {
        // Delegate to the const constructor so the two conversion paths
        // cannot drift apart.
        Self::from_bool(b)
    }
}

impl CxxOptionValue {
    /// `const` equivalent of `From<bool>`, usable in `const` initializers.
    pub const fn from_bool(value: bool) -> Self {
        match value {
            true => Self::On,
            false => Self::Off,
        }
    }
}

impl From<&'static str> for CxxOptionValue {
    fn from(s: &'static str) -> Self {
        match s {
            "ON" => Self::On,
            "OFF" => Self::Off,
            _ => Self::Value(s),
        }
    }
}

/// A named build switch plus its value, renderable either as a CMake `-D`
/// definition or as a preprocessor define for clang/cc.
#[derive(Debug, Clone, Copy)]
pub struct CxxOption {
    pub name: &'static str,
    pub value: CxxOptionValue,
}

/// Builds a `CxxOption` named `$cxx` that is On exactly when the cargo
/// feature `$feature` is enabled at compile time.
macro_rules! cxx_option_from_feature {
    ($feature:literal, $cxx:literal) => {{
        CxxOption::from_bool($cxx, cfg!(feature = $feature))
    }};
}

impl CxxOption {
    // Internal const constructor used by the feature macro.
    const fn from_bool(name: &'static str, value: bool) -> Self {
        Self {
            name,
            value: CxxOptionValue::from_bool(value),
        }
    }
    pub const VULKAN: CxxOption = cxx_option_from_feature!("vulkan", "MNN_VULKAN");
    pub const METAL: CxxOption = cxx_option_from_feature!("metal", "MNN_METAL");
    pub const COREML: CxxOption = cxx_option_from_feature!("coreml", "MNN_COREML");
    pub const OPENCL: CxxOption = cxx_option_from_feature!("opencl", "MNN_OPENCL");
    pub const OPENMP: CxxOption = cxx_option_from_feature!("openmp", "MNN_OPENMP");
    pub const OPENGL: CxxOption = cxx_option_from_feature!("opengl", "MNN_OPENGL");
    // FIX: was `cxx_option_from_feature!("opengl", "MNN_WIN_RUNTIME_MT")` —
    // a copy-paste of the OPENGL line above, which made the Windows static
    // CRT switch track the `opengl` feature. It must track `crt_static`,
    // the feature `mnn_c_build` uses for `config.static_crt(true)`, or the
    // two halves of the build link against different CRTs.
    pub const CRT_STATIC: CxxOption = cxx_option_from_feature!("crt_static", "MNN_WIN_RUNTIME_MT");
    pub const THREADPOOL: CxxOption =
        cxx_option_from_feature!("mnn-threadpool", "MNN_USE_THREAD_POOL");

    /// Creates an option with an explicit value (`bool`, `"ON"`/`"OFF"`, or
    /// any free-form string).
    pub fn new(name: &'static str, value: impl Into<CxxOptionValue>) -> Self {
        Self {
            name,
            value: value.into(),
        }
    }

    /// Returns the option forced to On.
    pub fn on(mut self) -> Self {
        self.value = CxxOptionValue::On;
        self
    }

    /// Returns the option forced to Off.
    pub fn off(mut self) -> Self {
        self.value = CxxOptionValue::Off;
        self
    }

    /// Returns the option carrying a free-form value.
    pub fn with_value(mut self, value: &'static str) -> Self {
        self.value = CxxOptionValue::Value(value);
        self
    }

    /// Renders as a CMake command-line definition, e.g. `-DMNN_VULKAN=ON`.
    pub fn cmake(&self) -> String {
        match &self.value {
            CxxOptionValue::On => format!("-D{}=ON", self.name),
            CxxOptionValue::Off => format!("-D{}=OFF", self.name),
            CxxOptionValue::Value(v) => format!("-D{}={}", self.name, v),
        }
    }

    /// Renders just the CMake value (`"ON"` / `"OFF"` / the raw string).
    pub fn cmake_value(&self) -> &'static str {
        match &self.value {
            CxxOptionValue::On => "ON",
            CxxOptionValue::Off => "OFF",
            CxxOptionValue::Value(v) => v,
        }
    }

    /// Renders as a preprocessor define for clang/cc, e.g. `-DMNN_VULKAN=1`.
    pub fn cxx(&self) -> String {
        match &self.value {
            CxxOptionValue::On => format!("-D{}=1", self.name),
            CxxOptionValue::Off => format!("-D{}=0", self.name),
            CxxOptionValue::Value(v) => format!("-D{}={}", self.name, v),
        }
    }

    /// True unless the option is explicitly Off (free-form values count as
    /// enabled).
    pub fn enabled(&self) -> bool {
        match self.value {
            CxxOptionValue::On => true,
            CxxOptionValue::Off => false,
            CxxOptionValue::Value(_) => true,
        }
    }
}