├── CMakeLists.txt
├── README.md
├── android_build.sh
├── export_script
│   ├── demo.py
│   └── to_onnx.py
├── include
│   ├── MNN
│   │   ├── AutoTime.hpp
│   │   ├── ErrorCode.hpp
│   │   ├── HalideRuntime.h
│   │   ├── ImageProcess.hpp
│   │   ├── Interpreter.hpp
│   │   ├── MNNDefine.h
│   │   ├── MNNForwardType.h
│   │   ├── MNNSharedContext.h
│   │   ├── Matrix.h
│   │   ├── Rect.h
│   │   ├── Tensor.hpp
│   │   ├── expr
│   │   │   ├── Executor.hpp
│   │   │   ├── ExecutorScope.hpp
│   │   │   ├── Expr.hpp
│   │   │   ├── ExprCreator.hpp
│   │   │   ├── MathOp.hpp
│   │   │   ├── Module.hpp
│   │   │   ├── NeuralNetWorkOp.hpp
│   │   │   ├── Optimizer.hpp
│   │   │   └── Scope.hpp
│   │   └── plugin
│   │       ├── PluginContext.hpp
│   │       ├── PluginKernel.hpp
│   │       └── PluginShapeInference.hpp
│   ├── cv
│   │   ├── calib3d.hpp
│   │   ├── cv.hpp
│   │   ├── imgcodecs.hpp
│   │   ├── imgproc
│   │   │   ├── color.hpp
│   │   │   ├── draw.hpp
│   │   │   ├── filter.hpp
│   │   │   ├── geometric.hpp
│   │   │   ├── histograms.hpp
│   │   │   ├── imgproc.hpp
│   │   │   ├── miscellaneous.hpp
│   │   │   └── structural.hpp
│   │   └── types.hpp
│   ├── pipeline.hpp
│   └── tokenizer.hpp
├── libs
│   └── README.md
├── resource
│   ├── alphas.txt
│   ├── demo.jpg
│   ├── logo.png
│   └── vocab.txt
└── src
    ├── main.cpp
    ├── pipeline.cpp
    └── tokenizer.cpp
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.0)
2 | project(stable-diffusion-mnn)
3 | 
4 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
5 | 
6 | option(BUILD_FOR_ANDROID "Build for android." OFF)
7 | 
8 | # include dir
9 | include_directories(${CMAKE_CURRENT_LIST_DIR}/include/)
10 | 
11 | # libs dir
12 | link_directories(${CMAKE_CURRENT_LIST_DIR}/libs)
13 | 
14 | # source files
15 | FILE(GLOB SRCS ${CMAKE_CURRENT_LIST_DIR}/src/*.cpp)
16 | 
17 | if (BUILD_FOR_ANDROID)
18 |     add_library(MNN SHARED IMPORTED)
19 |     add_library(MNN_Express SHARED IMPORTED)
20 |     add_library(MNNOpenCV SHARED IMPORTED)
21 |     set_target_properties(
22 |         MNN
23 |         PROPERTIES IMPORTED_LOCATION
24 |         ${CMAKE_CURRENT_LIST_DIR}/libs/libMNN.so
25 |     )
26 |     set_target_properties(
27 |         MNN_Express
28 |         PROPERTIES IMPORTED_LOCATION
29 |         ${CMAKE_CURRENT_LIST_DIR}/libs/libMNN_Express.so
30 |     )
31 |     set_target_properties(
32 |         MNNOpenCV
33 |         PROPERTIES IMPORTED_LOCATION
34 |         ${CMAKE_CURRENT_LIST_DIR}/libs/libMNNOpenCV.so
35 |     )
36 |     add_executable(main ${SRCS})
37 |     target_link_libraries(main MNN MNN_Express MNNOpenCV log)
38 | else()
39 |     # target
40 |     add_executable(main ${SRCS})
41 |     if (MSVC)
42 |         target_link_libraries(main MNN)
43 |     else()
44 |         target_link_libraries(main MNN MNN_Express MNNOpenCV)
45 |     endif()
46 | endif()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![mnn-stable-diffusion](resource/logo.png)
2 | 
3 | # mnn-stable-diffusion
4 | 
5 | ## Usage
6 | 
7 | ### 1. Compile MNN library
8 | #### Linux/Mac
9 | ```bash
10 | git clone https://github.com/alibaba/MNN.git
11 | cd MNN
12 | mkdir build && cd build
13 | cmake -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON ..
14 | make -j8
15 | cp libMNN.so express/libMNN_Express.so tools/cv/libMNNOpenCV.so /path/to/stable-diffusion-mnn/libs
16 | ```
17 | 
18 | #### Windows
19 | ```bash
20 | # Visual Studio xxxx Developer Command Prompt
21 | powershell
22 | git clone https://github.com/alibaba/MNN.git
23 | cd MNN
24 | mkdir build && cd build
25 | cmake -G "Ninja" -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON ..
26 | ninja
27 | cp MNN.dll MNN.lib /path/to/stable-diffusion-mnn/build
28 | ```
29 | 
30 | ### 2. Download Models
31 | Download the models from the `github release` page to `/path/to/mnn-stable-diffusion/resource`:
32 | ```bash
33 | cd resource
34 | wget https://github.com/wangzhaode/mnn-stable-diffusion/releases/download/v0.1/text_encoder.mnn
35 | wget https://github.com/wangzhaode/mnn-stable-diffusion/releases/download/v0.1/vae_decoder.mnn
36 | wget https://github.com/wangzhaode/mnn-stable-diffusion/releases/download/v0.1/unet.mnn
37 | ```
38 | 
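These `*.mnn` files are what the C++ pipeline loads at runtime. As a rough illustration of how such a model can be driven with the `Interpreter` API declared in `include/MNN/Interpreter.hpp`, here is a minimal sketch for the text encoder; the tensor names (`input_ids`, `last_hidden_state`) are taken from the ONNX export script and are assumed to survive conversion, and this is only an illustration, not the project's actual `src/pipeline.cpp`:

```cpp
#include <MNN/Interpreter.hpp>

#include <cstring>
#include <memory>
#include <vector>

int main() {
    // Load the converted text encoder and create a CPU session (4 threads).
    std::shared_ptr<MNN::Interpreter> net(
        MNN::Interpreter::createFromFile("resource/text_encoder.mnn"),
        MNN::Interpreter::destroy);
    MNN::ScheduleConfig config;
    config.type      = MNN_FORWARD_CPU;
    config.numThread = 4;
    auto session = net->createSession(config);

    // Write the token ids into the input tensor (assumed name "input_ids").
    auto input = net->getSessionInput(session, "input_ids");
    std::vector<int> tokenIds(input->elementSize(), 0); // real ids come from the tokenizer
    ::memcpy(input->host<int>(), tokenIds.data(), tokenIds.size() * sizeof(int));

    // Run and read back the hidden states (assumed name "last_hidden_state").
    net->runSession(session);
    auto output = net->getSessionOutput(session, "last_hidden_state");
    const float* hidden = output->host<float>();
    (void)hidden; // this becomes encoder_hidden_states for the UNet
    return 0;
}
```

Writing through `host<T>()` like this is only appropriate for the CPU backend; on GPU backends the input and output tensors are normally exchanged through separate host tensors.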
39 | ### 3. Build and Run
40 | 
41 | #### Linux/Mac
42 | ```bash
43 | mkdir build
44 | cd build
45 | cmake ..
46 | make -j4
47 | ./main "飞流直下三千尺,疑是银河落九天,唐诗,水墨,国画。" demo.jpg
48 | [##################################################] [100%] [iter time: 411.441000 ms]
49 | SUCCESS! write to demo.jpg
50 | ```
51 | #### Windows
52 | ```bash
53 | # Visual Studio xxxx Developer Command Prompt
54 | powershell
55 | mkdir build
56 | cd build
57 | cmake -G "Ninja" ..
58 | ninja
59 | ./main "飞流直下三千尺,疑是银河落九天,唐诗,水墨,国画。" demo.jpg
60 | [##################################################] [100%] [iter time: 411.441000 ms]
61 | SUCCESS! write to demo.jpg
62 | ```
63 | #### Android
64 | ```bash
65 | mkdir build
66 | cd build
67 | ../android_build.sh && make -j4
68 | adb push main ../libs/*.so /data/local/tmp/
69 | adb push ../resource /data/local/tmp/
70 | adb shell
71 | cd /data/local/tmp/
72 | ./main "飞流直下三千尺,疑是银河落九天,唐诗,水墨,国画。" demo.jpg
73 | [##################################################] [100%] [iter time: 411.441000 ms]
74 | SUCCESS! write to demo.jpg
75 | ```
76 | 
77 | ![demo.jpg](./resource/demo.jpg)
78 | 
79 | ## Speed
80 | 
81 | | device         | speed (ms/iter) |
82 | |----------------|-----------------|
83 | | RTX 3060Ti     | 421             |
84 | | Core i7-13700K | 15985           |
85 | 
86 | 
87 | ## Prompt
88 | - 不要使用中文标点。(Don't use Chinese punctuation marks, such as commas and periods.)
89 | - 给出画的属性。(Give the painting some attributes.)
90 | 
91 | ## Ref
92 | https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1
93 | 
--------------------------------------------------------------------------------
/android_build.sh:
--------------------------------------------------------------------------------
1 | cmake .. \
2 |     -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
3 |     -DANDROID_STL=c++_static \
4 |     -DANDROID_ABI="arm64-v8a" \
5 |     -DANDROID_NATIVE_API_LEVEL=android-21 \
6 |     -DCMAKE_BUILD_TYPE=Release \
7 |     -DBUILD_FOR_ANDROID=ON
8 | 
--------------------------------------------------------------------------------
/export_script/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers import StableDiffusionPipeline
3 | torch.backends.cudnn.benchmark = True
4 | # pipe = StableDiffusionPipeline.from_pretrained("IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1", torch_dtype=torch.float16)
5 | pipe = StableDiffusionPipeline.from_pretrained("IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1", torch_dtype=torch.float32)
6 | pipe.to('cuda')
7 | 
8 | prompt = '飞流直下三千尺,油画'
9 | image = pipe(prompt, guidance_scale=7.5).images[0]
10 | image.save("飞流.png")
11 | 
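`guidance_scale=7.5` above enables classifier-free guidance: each denoising step runs the UNet on a batch of two latents (unconditional and text-conditioned) and blends the two noise predictions. A minimal C++ sketch of that blend, assuming `uncond` and `cond` point at the two halves of the UNet's `out_sample` batch; this mirrors the standard diffusers behaviour and is not necessarily the exact code in `src/pipeline.cpp`:

```cpp
#include <cstddef>
#include <vector>

// Classifier-free guidance: noise = uncond + scale * (cond - uncond).
std::vector<float> guidanceBlend(const float* uncond, const float* cond,
                                 size_t n, float scale /* e.g. 7.5 */) {
    std::vector<float> noise(n);
    for (size_t i = 0; i < n; ++i) {
        noise[i] = uncond[i] + scale * (cond[i] - uncond[i]);
    }
    return noise;
}
```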
--------------------------------------------------------------------------------
/export_script/to_onnx.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import argparse
16 | import os
17 | import shutil
18 | from pathlib import Path
19 | 
20 | import torch
21 | from torch.onnx import export
22 | 
23 | import onnx
24 | from diffusers import OnnxStableDiffusionPipeline, StableDiffusionPipeline
25 | # from diffusers.onnx_utils import OnnxRuntimeModel
26 | from packaging import version
27 | 
28 | 
29 | is_torch_less_than_1_11 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.11")
30 | 
31 | export_torchscript = False
32 | 
33 | def model_export(
34 |     model,
35 |     model_args: tuple,
36 |     output_path: Path,
37 |     ordered_input_names,
38 |     output_names,
39 |     dynamic_axes,
40 |     opset,
41 |     use_external_data_format=False,
42 | ):
43 |     output_path.parent.mkdir(parents=True, exist_ok=True)
44 |     if export_torchscript:
45 |         traced_model = torch.jit.trace(model, model_args, strict=False)
46 |         fp = output_path.as_posix()
47 |         fp = fp.replace('onnx', 'pt')
48 |         traced_model.save(fp)
49 |         return
50 |     # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
51 |     # so we check the torch version for backwards compatibility
52 |     if is_torch_less_than_1_11:
53 |         export(
54 |             model,
55 |             model_args,
56 |             f=output_path.as_posix(),
57 |             input_names=ordered_input_names,
58 |             output_names=output_names,
59 |             dynamic_axes=dynamic_axes,
60 |             do_constant_folding=True,
61 |             use_external_data_format=use_external_data_format,
62 |             enable_onnx_checker=True,
63 |             opset_version=opset,
64 |         )
65 |     else:
66 |         export(
67 |             model,
68 |             model_args,
69 |             f=output_path.as_posix(),
70 |             input_names=ordered_input_names,
71 |             output_names=output_names,
72 |             dynamic_axes=dynamic_axes,
73 |             do_constant_folding=True,
74 |             opset_version=opset,
75 |         )
76 | 
77 | 
78 | @torch.no_grad()
79 | def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
80 |     dtype = torch.float16 if fp16 else torch.float32
81 |     if fp16 and torch.cuda.is_available():
82 |         device = "cuda"
83 |     elif fp16 and not torch.cuda.is_available():
84 |         raise ValueError("`float16` model export is only supported on GPUs with CUDA")
85 |     else:
86 |         device = "cpu"
87 |     device = "cpu"
88 |     pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
89 |     output_path = Path(output_path)
90 | 
91 |     # TEXT ENCODER
92 |     num_tokens = pipeline.text_encoder.config.max_position_embeddings
93 |     text_hidden_size = pipeline.text_encoder.config.hidden_size
94 |     text_input = pipeline.tokenizer(
95 |         "A sample prompt",
96 |         padding="max_length",
97 |         max_length=pipeline.tokenizer.model_max_length,
98 |         truncation=True,
99 |         return_tensors="pt",
100 |     )
101 |     model_export(
102 |         pipeline.text_encoder,
103 |         # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
104 |         model_args=(text_input.input_ids.to(device=device, dtype=torch.int32)),
105 |         output_path=output_path / "text_encoder" / "model.onnx",
106 |         ordered_input_names=["input_ids"],
107 |         output_names=["last_hidden_state", "pooler_output"],
108 |         dynamic_axes={
109
| "input_ids": {0: "batch", 1: "sequence"}, 110 | }, 111 | opset=opset, 112 | ) 113 | del pipeline.text_encoder 114 | 115 | # UNET 116 | unet_in_channels = pipeline.unet.config.in_channels 117 | unet_sample_size = pipeline.unet.config.sample_size 118 | unet_path = output_path / "unet" / "model.onnx" 119 | model_export( 120 | pipeline.unet, 121 | model_args=( 122 | torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype), 123 | torch.randn(2).to(device=device, dtype=torch.int32), 124 | torch.randn(2, num_tokens, text_hidden_size).to(device=device, dtype=dtype), 125 | # False, 126 | ), 127 | output_path=unet_path, 128 | ordered_input_names=["sample", "timestep", "encoder_hidden_states", "return_dict"], 129 | output_names=["out_sample"], # has to be different from "sample" for correct tracing 130 | dynamic_axes={ 131 | "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, 132 | "timestep": {0: "batch"}, 133 | "encoder_hidden_states": {0: "batch", 1: "sequence"}, 134 | }, 135 | opset=opset, 136 | use_external_data_format=True, # UNet is > 2GB, so the weights need to be split 137 | ) 138 | unet_model_path = str(unet_path.absolute().as_posix()) 139 | unet_dir = os.path.dirname(unet_model_path) 140 | unet = onnx.load(unet_model_path) 141 | # clean up existing tensor files 142 | shutil.rmtree(unet_dir) 143 | os.mkdir(unet_dir) 144 | # collate external tensor files into one 145 | onnx.save_model( 146 | unet, 147 | unet_model_path, 148 | save_as_external_data=True, 149 | all_tensors_to_one_file=True, 150 | location="weights.pb", 151 | convert_attribute=False, 152 | ) 153 | del pipeline.unet 154 | 155 | # VAE ENCODER 156 | vae_encoder = pipeline.vae 157 | vae_in_channels = vae_encoder.config.in_channels 158 | vae_sample_size = vae_encoder.config.sample_size 159 | # need to get the raw tensor output (sample) from the encoder 160 | vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(sample, return_dict)[0].sample() 161 | model_export( 162 | vae_encoder, 163 | model_args=( 164 | torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(device=device, dtype=dtype), 165 | False, 166 | ), 167 | output_path=output_path / "vae_encoder" / "model.onnx", 168 | ordered_input_names=["sample", "return_dict"], 169 | output_names=["latent_sample"], 170 | dynamic_axes={ 171 | "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, 172 | }, 173 | opset=opset, 174 | ) 175 | 176 | # VAE DECODER 177 | vae_decoder = pipeline.vae 178 | vae_latent_channels = vae_decoder.config.latent_channels 179 | vae_out_channels = vae_decoder.config.out_channels 180 | # forward only through the decoder part 181 | vae_decoder.forward = vae_encoder.decode 182 | model_export( 183 | vae_decoder, 184 | model_args=( 185 | torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(device=device, dtype=dtype), 186 | False, 187 | ), 188 | output_path=output_path / "vae_decoder" / "model.onnx", 189 | ordered_input_names=["latent_sample", "return_dict"], 190 | output_names=["sample"], 191 | dynamic_axes={ 192 | "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"}, 193 | }, 194 | opset=opset, 195 | ) 196 | del pipeline.vae 197 | ''' 198 | # SAFETY CHECKER 199 | if pipeline.safety_checker is not None: 200 | safety_checker = pipeline.safety_checker 201 | clip_num_channels = safety_checker.config.vision_config.num_channels 202 | clip_image_size = safety_checker.config.vision_config.image_size 203 | safety_checker.forward = 
safety_checker.forward_onnx 204 | model_export( 205 | pipeline.safety_checker, 206 | model_args=( 207 | torch.randn( 208 | 1, 209 | clip_num_channels, 210 | clip_image_size, 211 | clip_image_size, 212 | ).to(device=device, dtype=dtype), 213 | torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(device=device, dtype=dtype), 214 | ), 215 | output_path=output_path / "safety_checker" / "model.onnx", 216 | ordered_input_names=["clip_input", "images"], 217 | output_names=["out_images", "has_nsfw_concepts"], 218 | dynamic_axes={ 219 | "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"}, 220 | "images": {0: "batch", 1: "height", 2: "width", 3: "channels"}, 221 | }, 222 | opset=opset, 223 | ) 224 | del pipeline.safety_checker 225 | safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker") 226 | else: 227 | safety_checker = None 228 | 229 | onnx_pipeline = OnnxStableDiffusionPipeline( 230 | vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"), 231 | vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"), 232 | text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"), 233 | tokenizer=pipeline.tokenizer, 234 | unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"), 235 | scheduler=pipeline.scheduler, 236 | safety_checker=safety_checker, 237 | feature_extractor=pipeline.feature_extractor, 238 | ) 239 | 240 | onnx_pipeline.save_pretrained(output_path) 241 | print("ONNX pipeline saved to", output_path) 242 | 243 | del pipeline 244 | del onnx_pipeline 245 | _ = OnnxStableDiffusionPipeline.from_pretrained(output_path, provider="CPUExecutionProvider") 246 | print("ONNX pipeline is loadable") 247 | ''' 248 | 249 | 250 | if __name__ == "__main__": 251 | parser = argparse.ArgumentParser() 252 | 253 | parser.add_argument( 254 | "--model_path", 255 | type=str, 256 | required=True, 257 | help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).", 258 | ) 259 | 260 | parser.add_argument("--output_path", type=str, required=True, help="Path to the output model.") 261 | 262 | parser.add_argument( 263 | "--opset", 264 | default=14, 265 | type=int, 266 | help="The version of the ONNX operator set to use.", 267 | ) 268 | parser.add_argument("--fp16", action="store_true", default=False, help="Export the models in `float16` mode") 269 | 270 | args = parser.parse_args() 271 | 272 | convert_models(args.model_path, args.output_path, args.opset, args.fp16) 273 | -------------------------------------------------------------------------------- /include/MNN/AutoTime.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // AutoTime.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2018/07/27. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_AutoTime_hpp 10 | #define MNN_AutoTime_hpp 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | namespace MNN { 17 | 18 | class MNN_PUBLIC Timer { 19 | public: 20 | Timer(); 21 | ~Timer(); 22 | Timer(const Timer&) = delete; 23 | Timer(const Timer&&) = delete; 24 | Timer& operator=(const Timer&) = delete; 25 | Timer& operator=(const Timer&&) = delete; 26 | 27 | // reset timer 28 | void reset(); 29 | // get duration (us) from init or latest reset. 30 | uint64_t durationInUs(); 31 | 32 | // Get Current Time 33 | uint64_t current() const { 34 | return mLastResetTime; 35 | } 36 | protected: 37 | uint64_t mLastResetTime; 38 | }; 39 | 40 | /** time tracing util. 
prints duration between init and deinit. */ 41 | class MNN_PUBLIC AutoTime : Timer { 42 | public: 43 | AutoTime(int line, const char* func); 44 | ~AutoTime(); 45 | AutoTime(const AutoTime&) = delete; 46 | AutoTime(const AutoTime&&) = delete; 47 | AutoTime& operator=(const AutoTime&) = delete; 48 | AutoTime& operator=(const AutoTime&&) = delete; 49 | 50 | private: 51 | int mLine; 52 | char* mName; 53 | }; 54 | } // namespace MNN 55 | 56 | #ifdef MNN_OPEN_TIME_TRACE 57 | #define AUTOTIME MNN::AutoTime ___t(__LINE__, __func__) 58 | #else 59 | #define AUTOTIME 60 | #endif 61 | 62 | #endif /* AutoTime_hpp */ 63 | -------------------------------------------------------------------------------- /include/MNN/ErrorCode.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // ErrorCode.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2018/09/18. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_ErrorCode_h 10 | #define MNN_ErrorCode_h 11 | 12 | namespace MNN { 13 | enum ErrorCode { 14 | #ifdef NO_ERROR 15 | #undef NO_ERROR 16 | #endif // NO_ERROR 17 | NO_ERROR = 0, 18 | OUT_OF_MEMORY = 1, 19 | NOT_SUPPORT = 2, 20 | COMPUTE_SIZE_ERROR = 3, 21 | NO_EXECUTION = 4, 22 | INVALID_VALUE = 5, 23 | 24 | // User error 25 | INPUT_DATA_ERROR = 10, 26 | CALL_BACK_STOP = 11, 27 | 28 | // Op Resize Error 29 | TENSOR_NOT_SUPPORT = 20, 30 | TENSOR_NEED_DIVIDE = 21, 31 | }; 32 | } // namespace MNN 33 | 34 | #endif /* ErrorCode_h */ 35 | -------------------------------------------------------------------------------- /include/MNN/HalideRuntime.h: -------------------------------------------------------------------------------- 1 | #ifndef MNN_HALIDE_HALIDERUNTIME_H 2 | #define MNN_HALIDE_HALIDERUNTIME_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #ifdef __cplusplus 9 | extern "C" { 10 | #endif 11 | 12 | // Note that you should not use "inline" along with HALIDE_ALWAYS_INLINE; 13 | // it is not necessary, and may produce warnings for some build configurations. 14 | #ifdef _MSC_VER 15 | #define HALIDE_ALWAYS_INLINE __forceinline 16 | #define HALIDE_NEVER_INLINE __declspec(noinline) 17 | #else 18 | #define HALIDE_ALWAYS_INLINE __attribute__((always_inline)) inline 19 | #define HALIDE_NEVER_INLINE __attribute__((noinline)) 20 | #endif 21 | 22 | /** \file 23 | * 24 | * This file declares the routines used by Halide internally in its 25 | * runtime. On platforms that support weak linking, these can be 26 | * replaced with user-defined versions by defining an extern "C" 27 | * function with the same name and signature. 28 | * 29 | * When doing Just In Time (JIT) compilation methods on the Func being 30 | * compiled must be called instead. The corresponding methods are 31 | * documented below. 32 | * 33 | * All of these functions take a "void *user_context" parameter as their 34 | * first argument; if the Halide kernel that calls back to any of these 35 | * functions has been compiled with the UserContext feature set on its Target, 36 | * then the value of that pointer passed from the code that calls the 37 | * Halide kernel is piped through to the function. 38 | * 39 | * Some of these are also useful to call when using the default 40 | * implementation. E.g. halide_shutdown_thread_pool. 41 | * 42 | * Note that even on platforms with weak linking, some linker setups 43 | * may not respect the override you provide. E.g. 
if the override is 44 | * in a shared library and the halide object files are linked directly 45 | * into the output, the builtin versions of the runtime functions will 46 | * be called. See your linker documentation for more details. On 47 | * Linux, LD_DYNAMIC_WEAK=1 may help. 48 | * 49 | */ 50 | 51 | // Forward-declare to suppress warnings if compiling as C. 52 | struct halide_buffer_t; 53 | 54 | /** Types in the halide type system. They can be ints, unsigned ints, 55 | * or floats (of various bit-widths), or a handle (which is always 64-bits). 56 | * Note that the int/uint/float values do not imply a specific bit width 57 | * (the bit width is expected to be encoded in a separate value). 58 | */ 59 | typedef enum halide_type_code_t 60 | { 61 | halide_type_int = 0, //!< signed integers 62 | halide_type_uint = 1, //!< unsigned integers 63 | halide_type_float = 2, //!< floating point numbers 64 | halide_type_handle = 3 //!< opaque pointer type (void *) 65 | } halide_type_code_t; 66 | 67 | // Note that while __attribute__ can go before or after the declaration, 68 | // __declspec apparently is only allowed before. 69 | #ifndef HALIDE_ATTRIBUTE_ALIGN 70 | #ifdef _MSC_VER 71 | #define HALIDE_ATTRIBUTE_ALIGN(x) __declspec(align(x)) 72 | #else 73 | #define HALIDE_ATTRIBUTE_ALIGN(x) __attribute__((aligned(x))) 74 | #endif 75 | #endif 76 | 77 | /** A runtime tag for a type in the halide type system. Can be ints, 78 | * unsigned ints, or floats of various bit-widths (the 'bits' 79 | * field). Can also be vectors of the same (by setting the 'lanes' 80 | * field to something larger than one). This struct should be 81 | * exactly 32-bits in size. */ 82 | struct halide_type_t { 83 | /** The basic type code: signed integer, unsigned integer, or floating point. */ 84 | #if __cplusplus >= 201103L 85 | HALIDE_ATTRIBUTE_ALIGN(1) halide_type_code_t code; // halide_type_code_t 86 | #else 87 | HALIDE_ATTRIBUTE_ALIGN(1) uint8_t code; // halide_type_code_t 88 | #endif 89 | 90 | /** The number of bits of precision of a single scalar value of this type. */ 91 | HALIDE_ATTRIBUTE_ALIGN(1) uint8_t bits; 92 | 93 | /** How many elements in a vector. This is 1 for scalar types. */ 94 | HALIDE_ATTRIBUTE_ALIGN(2) uint16_t lanes; 95 | 96 | #ifdef __cplusplus 97 | /** Construct a runtime representation of a Halide type from: 98 | * code: The fundamental type from an enum. 99 | * bits: The bit size of one element. 100 | * lanes: The number of vector elements in the type. */ 101 | HALIDE_ALWAYS_INLINE halide_type_t(halide_type_code_t code, uint8_t bits, uint16_t lanes = 1) 102 | : code(code), bits(bits), lanes(lanes) { 103 | } 104 | 105 | /** Default constructor is required e.g. to declare halide_trace_event 106 | * instances. */ 107 | HALIDE_ALWAYS_INLINE halide_type_t() : code((halide_type_code_t)0), bits(0), lanes(0) {} 108 | 109 | /** Compare two types for equality. */ 110 | HALIDE_ALWAYS_INLINE bool operator==(const halide_type_t &other) const { 111 | return (code == other.code && 112 | bits == other.bits && 113 | lanes == other.lanes); 114 | } 115 | 116 | HALIDE_ALWAYS_INLINE bool operator!=(const halide_type_t &other) const { 117 | return !(*this == other); 118 | } 119 | 120 | /** Size in bytes for a single element, even if width is not 1, of this type. */ 121 | HALIDE_ALWAYS_INLINE int bytes() const { return (bits + 7) / 8; } 122 | #endif 123 | }; 124 | 125 | /** An opaque struct containing per-GPU API implementations of the 126 | * device functions. 
*/ 127 | struct halide_device_interface_impl_t; 128 | 129 | /** Each GPU API provides a halide_device_interface_t struct pointing 130 | * to the code that manages device allocations. You can access these 131 | * functions directly from the struct member function pointers, or by 132 | * calling the functions declared below. Note that the global 133 | * functions are not available when using Halide as a JIT compiler. 134 | * If you are using raw halide_buffer_t in that context you must use 135 | * the function pointers in the device_interface struct. 136 | * 137 | * The function pointers below are currently the same for every GPU 138 | * API; only the impl field varies. These top-level functions do the 139 | * bookkeeping that is common across all GPU APIs, and then dispatch 140 | * to more API-specific functions via another set of function pointers 141 | * hidden inside the impl field. 142 | */ 143 | struct halide_device_interface_t { 144 | int (*device_malloc)(void *user_context, struct halide_buffer_t *buf, 145 | const struct halide_device_interface_t *device_interface); 146 | int (*device_free)(void *user_context, struct halide_buffer_t *buf); 147 | int (*device_sync)(void *user_context, struct halide_buffer_t *buf); 148 | void (*device_release)(void *user_context, 149 | const struct halide_device_interface_t *device_interface); 150 | int (*copy_to_host)(void *user_context, struct halide_buffer_t *buf); 151 | int (*copy_to_device)(void *user_context, struct halide_buffer_t *buf, 152 | const struct halide_device_interface_t *device_interface); 153 | int (*device_and_host_malloc)(void *user_context, struct halide_buffer_t *buf, 154 | const struct halide_device_interface_t *device_interface); 155 | int (*device_and_host_free)(void *user_context, struct halide_buffer_t *buf); 156 | int (*buffer_copy)(void *user_context, struct halide_buffer_t *src, 157 | const struct halide_device_interface_t *dst_device_interface, struct halide_buffer_t *dst); 158 | int (*device_crop)(void *user_context, const struct halide_buffer_t *src, 159 | struct halide_buffer_t *dst); 160 | int (*device_release_crop)(void *user_context, struct halide_buffer_t *buf); 161 | int (*wrap_native)(void *user_context, struct halide_buffer_t *buf, uint64_t handle, 162 | const struct halide_device_interface_t *device_interface); 163 | int (*detach_native)(void *user_context, struct halide_buffer_t *buf); 164 | const struct halide_device_interface_impl_t *impl; 165 | }; 166 | 167 | typedef struct halide_dimension_t { 168 | int32_t min, extent, stride; 169 | 170 | // Per-dimension flags. None are defined yet (This is reserved for future use). 
171 | uint32_t flags; 172 | 173 | #ifdef __cplusplus 174 | HALIDE_ALWAYS_INLINE halide_dimension_t() : min(0), extent(0), stride(0), flags(0) {} 175 | HALIDE_ALWAYS_INLINE halide_dimension_t(int32_t m, int32_t e, int32_t s, uint32_t f = 0) : 176 | min(m), extent(e), stride(s), flags(f) {} 177 | 178 | HALIDE_ALWAYS_INLINE bool operator==(const halide_dimension_t &other) const { 179 | return (min == other.min) && 180 | (extent == other.extent) && 181 | (stride == other.stride) && 182 | (flags == other.flags); 183 | } 184 | 185 | HALIDE_ALWAYS_INLINE bool operator!=(const halide_dimension_t &other) const { 186 | return !(*this == other); 187 | } 188 | #endif 189 | } halide_dimension_t; 190 | 191 | #ifdef __cplusplus 192 | } // extern "C" 193 | #endif 194 | 195 | typedef enum {halide_buffer_flag_host_dirty = 1, 196 | halide_buffer_flag_device_dirty = 2} halide_buffer_flags; 197 | 198 | /** 199 | * The raw representation of an image passed around by generated 200 | * Halide code. It includes some stuff to track whether the image is 201 | * not actually in main memory, but instead on a device (like a 202 | * GPU). For a more convenient C++ wrapper, use Halide::Buffer. */ 203 | typedef struct halide_buffer_t { 204 | /** A device-handle for e.g. GPU memory used to back this buffer. */ 205 | uint64_t device; 206 | 207 | /** The interface used to interpret the above handle. */ 208 | const struct halide_device_interface_t *device_interface; 209 | 210 | /** A pointer to the start of the data in main memory. In terms of 211 | * the Halide coordinate system, this is the address of the min 212 | * coordinates (defined below). */ 213 | uint8_t* host; 214 | 215 | /** flags with various meanings. */ 216 | uint64_t flags; 217 | 218 | /** The type of each buffer element. */ 219 | struct halide_type_t type; 220 | 221 | /** The dimensionality of the buffer. */ 222 | int32_t dimensions; 223 | 224 | /** The shape of the buffer. Halide does not own this array - you 225 | * must manage the memory for it yourself. */ 226 | halide_dimension_t *dim; 227 | 228 | /** Pads the buffer up to a multiple of 8 bytes */ 229 | void *padding; 230 | } halide_buffer_t; 231 | 232 | 233 | #ifdef __cplusplus 234 | 235 | namespace { 236 | template struct check_is_pointer; 237 | template struct check_is_pointer {}; 238 | } 239 | 240 | /** Construct the halide equivalent of a C type */ 241 | template 242 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 243 | // Create a compile-time error if T is not a pointer (without 244 | // using any includes - this code goes into the runtime). 
245 | check_is_pointer check; 246 | (void)check; 247 | return halide_type_t(halide_type_handle, 64); 248 | } 249 | 250 | template<> 251 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 252 | return halide_type_t(halide_type_float, 32); 253 | } 254 | 255 | template<> 256 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 257 | return halide_type_t(halide_type_float, 64); 258 | } 259 | 260 | template<> 261 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 262 | return halide_type_t(halide_type_uint, 1); 263 | } 264 | 265 | template<> 266 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 267 | return halide_type_t(halide_type_uint, 8); 268 | } 269 | 270 | template<> 271 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 272 | return halide_type_t(halide_type_uint, 16); 273 | } 274 | 275 | template<> 276 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 277 | return halide_type_t(halide_type_uint, 32); 278 | } 279 | 280 | template<> 281 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 282 | return halide_type_t(halide_type_uint, 64); 283 | } 284 | 285 | template<> 286 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 287 | return halide_type_t(halide_type_int, 8); 288 | } 289 | 290 | template<> 291 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 292 | return halide_type_t(halide_type_int, 16); 293 | } 294 | 295 | template<> 296 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 297 | return halide_type_t(halide_type_int, 32); 298 | } 299 | 300 | template<> 301 | HALIDE_ALWAYS_INLINE halide_type_t halide_type_of() { 302 | return halide_type_t(halide_type_int, 64); 303 | } 304 | 305 | #endif 306 | 307 | #endif // HALIDE_HALIDERUNTIME_H 308 | -------------------------------------------------------------------------------- /include/MNN/ImageProcess.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // ImageProcess.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2018/09/19. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_ImageProcess_hpp 10 | #define MNN_ImageProcess_hpp 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | namespace MNN { 17 | namespace CV { 18 | enum ImageFormat { 19 | RGBA = 0, 20 | RGB = 1, 21 | BGR = 2, 22 | GRAY = 3, 23 | BGRA = 4, 24 | YCrCb = 5, 25 | YUV = 6, 26 | HSV = 7, 27 | XYZ = 8, 28 | BGR555 = 9, 29 | BGR565 = 10, 30 | YUV_NV21 = 11, 31 | YUV_NV12 = 12, 32 | YUV_I420 = 13, 33 | HSV_FULL = 14, 34 | }; 35 | 36 | enum Filter { NEAREST = 0, BILINEAR = 1, BICUBIC = 2 }; 37 | 38 | enum Wrap { CLAMP_TO_EDGE = 0, ZERO = 1, REPEAT = 2 }; 39 | 40 | /** 41 | * handle image process for tensor. 42 | * step: 43 | * 1: Do transform compute and get points 44 | * 2: Sample line and do format convert 45 | * 3: Turn RGBA to float tensor, and do sub and normalize 46 | */ 47 | class MNN_PUBLIC ImageProcess { 48 | public: 49 | struct Inside; 50 | struct Config { 51 | /** data filter */ 52 | Filter filterType = NEAREST; 53 | /** format of source data */ 54 | ImageFormat sourceFormat = RGBA; 55 | /** format of destination data */ 56 | ImageFormat destFormat = RGBA; 57 | 58 | // Only valid if the dest type is float 59 | float mean[4] = {0.0f, 0.0f, 0.0f, 0.0f}; 60 | float normal[4] = {1.0f, 1.0f, 1.0f, 1.0f}; 61 | 62 | /** edge wrapper */ 63 | Wrap wrap = CLAMP_TO_EDGE; 64 | }; 65 | 66 | public: 67 | /** 68 | * @brief create image process with given config for given tensor. 69 | * @param config given config. 70 | * @param dstTensor given tensor. 
71 | * @return image processor. 72 | */ 73 | static ImageProcess* create(const Config& config, const Tensor* dstTensor = nullptr); 74 | 75 | /** 76 | * @brief create image process with given config for given tensor. 77 | * @param means given means 78 | * @param meanCount given means count 79 | * @param normals given normals 80 | * @param normalCount given normal count 81 | * @param sourceFormat format of source data 82 | * @param destFormat format of destination data 83 | * @param dstTensor given tensor. 84 | * @return image processor. 85 | */ 86 | static ImageProcess* create(const ImageFormat sourceFormat = RGBA, const ImageFormat destFormat = RGBA, 87 | const float* means = nullptr, const int meanCount = 0, const float* normals = nullptr, 88 | const int normalCount = 0, const Tensor* dstTensor = nullptr); 89 | 90 | ~ImageProcess(); 91 | static void destroy(ImageProcess* imageProcess); 92 | 93 | /** 94 | * @brief get affine transform matrix. 95 | * @return affine transform matrix. 96 | */ 97 | inline const Matrix& matrix() const { 98 | return mTransform; 99 | } 100 | void setMatrix(const Matrix& matrix); 101 | 102 | /** 103 | * @brief convert source data to given tensor. 104 | * @param source source data. 105 | * @param iw source width. 106 | * @param ih source height. 107 | * @param stride number of elements per row. eg: 100 width RGB contains at least 300 elements. 108 | * @param dest given tensor. 109 | * @return result code. 110 | */ 111 | ErrorCode convert(const uint8_t* source, int iw, int ih, int stride, Tensor* dest); 112 | 113 | /** 114 | * @brief convert source data to given tensor. 115 | * @param source source data. 116 | * @param iw source width. 117 | * @param ih source height. 118 | * @param stride number of elements per row. eg: 100 width RGB contains at least 300 elements. 119 | * @param dest dest data. 120 | * @param ow output width. 121 | * @param oh output height. 122 | * @param outputBpp output bpp, if 0, set as the save and config.destFormat. 123 | * @param outputStride output stride, if 0, set as ow * outputBpp. 124 | * @param type Only support halide_type_of and halide_type_of. 125 | * @return result code. 126 | */ 127 | ErrorCode convert(const uint8_t* source, int iw, int ih, int stride, void* dest, int ow, int oh, int outputBpp = 0, 128 | int outputStride = 0, halide_type_t type = halide_type_of()); 129 | 130 | /** 131 | * @brief create tensor with given data. 132 | * @param w image width. 133 | * @param h image height. 134 | * @param bpp bytes per pixel. 135 | * @param p pixel data pointer. 136 | * @return created tensor. 137 | */ 138 | template 139 | static Tensor* createImageTensor(int w, int h, int bpp, void* p = nullptr) { 140 | return createImageTensor(halide_type_of(), w, h, bpp, p); 141 | } 142 | static Tensor* createImageTensor(halide_type_t type, int w, int h, int bpp, void* p = nullptr); 143 | 144 | /** 145 | * @brief set padding value when wrap=ZERO. 146 | * @param value padding value. 147 | * @return void. 148 | */ 149 | void setPadding(uint8_t value) { 150 | mPaddingValue = value; 151 | } 152 | 153 | /** 154 | * @brief set to draw mode. 155 | * @param void 156 | * @return void. 157 | */ 158 | void setDraw(); 159 | 160 | /** 161 | * @brief draw color to regions of img. 162 | * @param img the image to draw. 163 | * @param w the image's width. 164 | * @param h the image's height. 165 | * @param c the image's channel. 
166 | * @param regions the regions to draw, size is [num * 3] contain num x { y, xl, xr } 167 | * @param num regions num 168 | * @param color the color to draw. 169 | * @return void. 170 | */ 171 | void draw(uint8_t* img, int w, int h, int c, const int* regions, int num, const uint8_t* color); 172 | private: 173 | ImageProcess(const Config& config); 174 | Matrix mTransform; 175 | Matrix mTransformInvert; 176 | Inside* mInside; 177 | uint8_t mPaddingValue = 0; 178 | }; 179 | } // namespace CV 180 | } // namespace MNN 181 | 182 | #endif /* ImageProcess_hpp */ 183 | -------------------------------------------------------------------------------- /include/MNN/Interpreter.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Interpreter.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2018/07/23. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_Interpreter_hpp 10 | #define MNN_Interpreter_hpp 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | namespace MNN { 21 | 22 | /** session schedule config */ 23 | struct ScheduleConfig { 24 | /** which tensor should be kept */ 25 | std::vector saveTensors; 26 | /** forward type */ 27 | MNNForwardType type = MNN_FORWARD_CPU; 28 | /** CPU:number of threads in parallel , Or GPU: mode setting*/ 29 | union { 30 | int numThread = 4; 31 | int mode; 32 | }; 33 | 34 | /** subpath to run */ 35 | struct Path { 36 | std::vector inputs; 37 | std::vector outputs; 38 | 39 | enum Mode { 40 | /** 41 | * Op Mode 42 | * - inputs means the source op, can NOT be empty. 43 | * - outputs means the sink op, can be empty. 44 | * The path will start from source op, then flow when encounter the sink op. 45 | * The sink op will not be compute in this path. 46 | */ 47 | Op = 0, 48 | 49 | /** 50 | * Tensor Mode 51 | * - inputs means the inputs tensors, can NOT be empty. 52 | * - outputs means the outputs tensors, can NOT be empty. 53 | * It will find the pipeline that compute outputs from inputs. 54 | */ 55 | Tensor = 1 56 | }; 57 | 58 | /** running mode */ 59 | Mode mode = Op; 60 | }; 61 | Path path; 62 | 63 | /** backup backend used to create execution when desinated backend do NOT support any op */ 64 | MNNForwardType backupType = MNN_FORWARD_CPU; 65 | 66 | /** extra backend config */ 67 | BackendConfig* backendConfig = nullptr; 68 | }; 69 | 70 | class Session; 71 | struct Content; 72 | class Tensor; 73 | class Backend; 74 | class Runtime; 75 | 76 | class MNN_PUBLIC OperatorInfo { 77 | struct Info; 78 | 79 | public: 80 | /** Operator's name*/ 81 | const std::string& name() const; 82 | 83 | /** Operator's type*/ 84 | const std::string& type() const; 85 | 86 | /** Operator's flops, in M*/ 87 | float flops() const; 88 | 89 | protected: 90 | OperatorInfo(); 91 | ~OperatorInfo(); 92 | Info* mContent; 93 | }; 94 | 95 | typedef std::function&, const std::string& /*opName*/)> TensorCallBack; 96 | typedef std::function&, const OperatorInfo*)> TensorCallBackWithInfo; 97 | typedef std::pair< std::map>, std::shared_ptr> RuntimeInfo; 98 | 99 | /** 100 | * @brief get mnn version info. 101 | * @return mnn version string. 102 | */ 103 | MNN_PUBLIC const char* getVersion(); 104 | 105 | /** net data holder. multiple sessions could share same net. */ 106 | class MNN_PUBLIC Interpreter { 107 | public: 108 | /** 109 | * @brief create net from file. 110 | * @param file given file. 111 | * @return created net if success, NULL otherwise. 
112 | */ 113 | static Interpreter* createFromFile(const char* file); 114 | /** 115 | * @brief create net from buffer. 116 | * @param buffer given data buffer. 117 | * @param size size of data buffer. 118 | * @return created net if success, NULL otherwise. 119 | */ 120 | static Interpreter* createFromBuffer(const void* buffer, size_t size); 121 | ~Interpreter(); 122 | 123 | /** 124 | * @brief destroy Interpreter 125 | * @param model given Interpreter to release. 126 | */ 127 | static void destroy(Interpreter* net); 128 | 129 | enum SessionMode { 130 | /** About CallBack, Default Session_Debug*/ 131 | /** runSessionWithCallBack is allowed and can get internal op info*/ 132 | Session_Debug = 0, 133 | /** runSessionWithCallBack is not valid and can't get any info of op in session*/ 134 | Session_Release = 1, 135 | 136 | /** About input tenosr, Default Session_Input_Inside*/ 137 | /** The input tensor is alloced by session, input data after session resized*/ 138 | Session_Input_Inside = 2, 139 | /** The input tensor is alloced by user, set input data before session resize*/ 140 | Session_Input_User = 3, 141 | 142 | /** The output tensor depends on session, and can't be separate used*/ 143 | Session_Output_Inside = 4, 144 | /** The output tensor can be separated from session*/ 145 | Session_Output_User = 5, 146 | 147 | /** Try Resize Session when create Session or not, default direct: */ 148 | Session_Resize_Direct = 6, 149 | Session_Resize_Defer = 7, 150 | 151 | /** Determine the Execution's forward type is determine by user or auto determine */ 152 | Session_Backend_Fix = 8, // Use the backend user set, when not support use default backend 153 | Session_Backend_Auto = 9, // Auto Determine the Op type by MNN 154 | }; 155 | /** 156 | * @brief The API shoud be called before create session. 157 | * @param mode session mode 158 | */ 159 | void setSessionMode(SessionMode mode); 160 | 161 | /** 162 | * @brief The API shoud be called before create session. 163 | * If the cache exist, try to load cache from file. 164 | * After createSession, try to save cache to file. 165 | * @param cacheFile cache file name 166 | * @param keySize depercerate, for future use. 167 | */ 168 | void setCacheFile(const char* cacheFile, size_t keySize = 128); 169 | 170 | /** 171 | * @brief The API shoud be called after last resize session. 172 | * If resize session generate new cache info, try to rewrite cache file. 173 | * If resize session do not generate any new cache info, just do nothing. 174 | * @param session giveb session 175 | * @param flag Protected param, not used now 176 | */ 177 | ErrorCode updateCacheFile(Session *session, int flag = 0); 178 | 179 | enum HintMode { 180 | // Max Op number for async tuning 181 | MAX_TUNING_NUMBER = 0, 182 | }; 183 | /** 184 | * @brief The API shoud be called before create session. 185 | * @param mode Hint type 186 | * @param value Hint value 187 | */ 188 | void setSessionHint(HintMode mode, int value); 189 | public: 190 | /** 191 | * @brief create runtimeInfo separately with schedule config. 192 | * @param configs session schedule configs. 193 | */ 194 | static RuntimeInfo createRuntime(const std::vector& configs); 195 | 196 | /** 197 | * @brief create session with schedule config. created session will be managed in net. 198 | * @param config session schedule config. 199 | * @return created session if success, NULL otherwise. 
200 | */ 201 | Session* createSession(const ScheduleConfig& config); 202 | 203 | /** 204 | * @brief create session with schedule config and user-specified runtime. 205 | * @param config session schedule config, runtime runtimeInfo used by the created session. 206 | * @return created session if success, NULL otherwise. 207 | */ 208 | Session* createSession(const ScheduleConfig& config, const RuntimeInfo& runtime); 209 | 210 | /** 211 | * @brief create multi-path session with schedule configs. created session will be managed in net. 212 | * @param configs session schedule configs. 213 | * @return created session if success, NULL otherwise. 214 | */ 215 | Session* createMultiPathSession(const std::vector& configs); 216 | 217 | /** 218 | * @brief create multi-path session with schedule configs and user-specified runtime. 219 | created session will be managed in net. 220 | * @param configs session schedule configs. 221 | * @return created session if success, NULL otherwise. 222 | */ 223 | Session* createMultiPathSession(const std::vector& configs, const RuntimeInfo& runtime); 224 | 225 | /** 226 | * @brief release session. 227 | * @param session given session. 228 | * @return true if given session is held by net and is freed. 229 | */ 230 | bool releaseSession(Session* session); 231 | 232 | /** 233 | * @brief call this function to get tensors ready. output tensor buffer (host or deviceId) should be retrieved 234 | * after resize of any input tensor. 235 | * @param session given session. 236 | */ 237 | void resizeSession(Session* session); 238 | 239 | /** 240 | * @brief call this function to get tensors ready. output tensor buffer (host or deviceId) should be retrieved 241 | * after resize of any input tensor. 242 | * @param session given session. 243 | * @param needRelloc, 1 means need realloc. 244 | */ 245 | void resizeSession(Session* session, int needRelloc); 246 | 247 | 248 | /** 249 | * @brief call this function if don't need resize or create session any more, it will save a few memory that equal 250 | * to the size of model buffer 251 | */ 252 | void releaseModel(); 253 | 254 | /** 255 | * @brief Get the model buffer for user to save 256 | * @return std::make_pair(modelBuffer, modelSize). 257 | * @example: 258 | * std::ofstream output("trainResult.alinn") 259 | * auto buffer = net->getModelBuffer(); 260 | * output.write((const char*)buffer.first, buffer.second); 261 | */ 262 | std::pair getModelBuffer() const; 263 | 264 | /** 265 | * @brief Get the model's version info. 266 | * @return const char* of model's version info like "2.0.0"; 267 | * If model is not loaded or model no version info, return "version info not found". 268 | */ 269 | const char* getModelVersion() const; 270 | 271 | /** 272 | * @brief update Session's Tensor to model's Const Op 273 | * @param session given session. 274 | * @return result of running. 275 | */ 276 | ErrorCode updateSessionToModel(Session* session); 277 | 278 | /** 279 | * @brief run session. 280 | * @param session given session. 281 | * @return result of running. 282 | */ 283 | ErrorCode runSession(Session* session) const; 284 | 285 | /* 286 | * @brief run session. 287 | * @param session given session. 288 | * @param before callback before each op. return true to run the op; return false to skip the op. 289 | * @param after callback after each op. return true to continue running; return false to interrupt the session. 290 | * @param sync synchronously wait for finish of execution or not. 291 | * @return result of running. 
292 | */ 293 | ErrorCode runSessionWithCallBack(const Session* session, const TensorCallBack& before, const TensorCallBack& end, 294 | bool sync = false) const; 295 | 296 | /* 297 | * @brief run session. 298 | * @param session given session. 299 | * @param before callback before each op. return true to run the op; return false to skip the op. 300 | * @param after callback after each op. return true to continue running; return false to interrupt the session. 301 | * @param sync synchronously wait for finish of execution or not. 302 | * @return result of running. 303 | */ 304 | ErrorCode runSessionWithCallBackInfo(const Session* session, const TensorCallBackWithInfo& before, 305 | const TensorCallBackWithInfo& end, bool sync = false) const; 306 | 307 | /** 308 | * @brief get input tensor for given name. 309 | * @param session given session. 310 | * @param name given name. if NULL, return first input. 311 | * @return tensor if found, NULL otherwise. 312 | */ 313 | Tensor* getSessionInput(const Session* session, const char* name); 314 | /** 315 | * @brief get output tensor for given name. 316 | * @param session given session. 317 | * @param name given name. if NULL, return first output. 318 | * @return tensor if found, NULL otherwise. 319 | */ 320 | Tensor* getSessionOutput(const Session* session, const char* name); 321 | 322 | enum SessionInfoCode { 323 | /** memory session used in MB, float* */ 324 | MEMORY = 0, 325 | 326 | /** float operation needed in session in M, float* */ 327 | FLOPS = 1, 328 | 329 | /** Backends in session in M, int*, length >= 1 + number of configs when create session */ 330 | BACKENDS = 2, 331 | 332 | /** Resize Info, int*, 0: ready to execute, 1: need malloc, 2: need resize */ 333 | RESIZE_STATUS = 3, 334 | 335 | ALL 336 | }; 337 | 338 | /** 339 | * @brief get session info 340 | * @param session given session. 341 | * @param code given info code. 342 | * @param ptr given info ptr, see SessionInfoCode for detail 343 | * @return true if support the code, false otherwise. 344 | */ 345 | bool getSessionInfo(const Session* session, SessionInfoCode code, void* ptr); 346 | 347 | /** 348 | * @brief get all output tensors. 349 | * @param session given session. 350 | * @return all output tensors mapped with name. 351 | */ 352 | const std::map& getSessionOutputAll(const Session* session) const; 353 | /** 354 | * @brief get all input tensors. 355 | * @param session given session. 356 | * @return all input tensors mapped with name. 357 | */ 358 | const std::map& getSessionInputAll(const Session* session) const; 359 | 360 | public: 361 | /** 362 | * @brief resize given tensor. 363 | * @param tensor given tensor. 364 | * @param dims new dims. at most 6 dims. 365 | */ 366 | void resizeTensor(Tensor* tensor, const std::vector& dims); 367 | 368 | /** 369 | * @brief resize given tensor by nchw. 370 | * @param batch / N. 371 | * @param channel / C. 372 | * @param height / H. 373 | * @param width / W 374 | */ 375 | void resizeTensor(Tensor* tensor, int batch, int channel, int height, int width); 376 | 377 | /** 378 | * @brief get backend used to create given tensor. 379 | * @param session given session. 380 | * @param tensor given tensor. 381 | * @return backend used to create given tensor, may be NULL. 382 | */ 383 | const Backend* getBackend(const Session* session, const Tensor* tensor) const; 384 | 385 | /** 386 | * @brief get business code (model identifier). 387 | * @return business code. 
388 | */ 389 | const char* bizCode() const; 390 | 391 | /** 392 | * @brief get model UUID 393 | * @return Model UUID. 394 | */ 395 | const char* uuid() const; 396 | 397 | private: 398 | static Interpreter* createFromBufferInternal(Content* net, bool enforceAuth); 399 | 400 | Content* mNet = nullptr; 401 | Interpreter(Content* net); 402 | 403 | Interpreter(const Interpreter&) = delete; 404 | Interpreter(const Interpreter&&) = delete; 405 | Interpreter& operator=(const Interpreter&) = delete; 406 | Interpreter& operator=(const Interpreter&&) = delete; 407 | void waitSessionFinish(const Session* session) const; 408 | #ifdef MNN_INTERNAL_ENABLED 409 | void logForRunSession(const Session* session, float time, const char* api) const; 410 | #endif 411 | }; 412 | } // namespace MNN 413 | 414 | #endif /* Interpreter_hpp */ 415 | -------------------------------------------------------------------------------- /include/MNN/MNNDefine.h: -------------------------------------------------------------------------------- 1 | // 2 | // MNNDefine.h 3 | // MNN 4 | // 5 | // Created by MNN on 2018/08/09. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNNDefine_h 10 | #define MNNDefine_h 11 | 12 | #include 13 | #include 14 | 15 | #if defined(__APPLE__) 16 | #include 17 | #if TARGET_OS_IPHONE 18 | #define MNN_BUILD_FOR_IOS 19 | #endif 20 | #endif 21 | 22 | #ifdef MNN_USE_LOGCAT 23 | #include 24 | #define MNN_ERROR(format, ...) __android_log_print(ANDROID_LOG_ERROR, "MNNJNI", format, ##__VA_ARGS__) 25 | #define MNN_PRINT(format, ...) __android_log_print(ANDROID_LOG_INFO, "MNNJNI", format, ##__VA_ARGS__) 26 | #elif defined MNN_BUILD_FOR_IOS 27 | // on iOS, stderr prints to XCode debug area and syslog prints Console. You need both. 28 | #include 29 | #define MNN_PRINT(format, ...) syslog(LOG_WARNING, format, ##__VA_ARGS__); fprintf(stderr, format, ##__VA_ARGS__) 30 | #define MNN_ERROR(format, ...) syslog(LOG_WARNING, format, ##__VA_ARGS__); fprintf(stderr, format, ##__VA_ARGS__) 31 | #else 32 | #define MNN_PRINT(format, ...) printf(format, ##__VA_ARGS__) 33 | #define MNN_ERROR(format, ...) printf(format, ##__VA_ARGS__) 34 | #endif 35 | 36 | #ifdef DEBUG 37 | #define MNN_ASSERT(x) \ 38 | { \ 39 | int res = (x); \ 40 | if (!res) { \ 41 | MNN_ERROR("Error for %s, %d\n", __FILE__, __LINE__); \ 42 | assert(res); \ 43 | } \ 44 | } 45 | #else 46 | #define MNN_ASSERT(x) 47 | #endif 48 | 49 | #define FUNC_PRINT(x) MNN_PRINT(#x "=%d in %s, %d \n", x, __func__, __LINE__); 50 | #define FUNC_PRINT_ALL(x, type) MNN_PRINT(#x "=" #type " %" #type " in %s, %d \n", x, __func__, __LINE__); 51 | 52 | #define MNN_CHECK(success, log) \ 53 | if(!(success)){ \ 54 | MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \ 55 | } 56 | 57 | #if defined(_MSC_VER) 58 | #if defined(BUILDING_MNN_DLL) 59 | #define MNN_PUBLIC __declspec(dllexport) 60 | #elif defined(USING_MNN_DLL) 61 | #define MNN_PUBLIC __declspec(dllimport) 62 | #else 63 | #define MNN_PUBLIC 64 | #endif 65 | #else 66 | #define MNN_PUBLIC __attribute__((visibility("default"))) 67 | #endif 68 | #define STR_IMP(x) #x 69 | #define STR(x) STR_IMP(x) 70 | #define MNN_VERSION_MAJOR 2 71 | #define MNN_VERSION_MINOR 2 72 | #define MNN_VERSION_PATCH 1 73 | #define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." 
STR(MNN_VERSION_PATCH) 74 | #endif /* MNNDefine_h */ 75 | -------------------------------------------------------------------------------- /include/MNN/MNNForwardType.h: -------------------------------------------------------------------------------- 1 | // 2 | // MNNForwardType.h 3 | // MNN 4 | // 5 | // Created by MNN on 2019/01/19. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNNForwardType_h 10 | #define MNNForwardType_h 11 | #include 12 | #include 13 | 14 | typedef enum { 15 | MNN_FORWARD_CPU = 0, 16 | 17 | /* 18 | Firtly find the first available backends not equal to CPU 19 | If no other backends, use cpu 20 | */ 21 | MNN_FORWARD_AUTO = 4, 22 | 23 | /*Hand write metal*/ 24 | MNN_FORWARD_METAL = 1, 25 | 26 | /*NVIDIA GPU API*/ 27 | MNN_FORWARD_CUDA = 2, 28 | 29 | /*Android / Common Device GPU API*/ 30 | MNN_FORWARD_OPENCL = 3, 31 | MNN_FORWARD_OPENGL = 6, 32 | MNN_FORWARD_VULKAN = 7, 33 | 34 | /*Android 8.1's NNAPI, Not Support yet. CoreML Now*/ 35 | MNN_FORWARD_NN = 5, 36 | 37 | /*User can use API from Backend.hpp to add or search Backend*/ 38 | MNN_FORWARD_USER_0 = 8, 39 | MNN_FORWARD_USER_1 = 9, 40 | MNN_FORWARD_USER_2 = 10, 41 | MNN_FORWARD_USER_3 = 11, 42 | 43 | MNN_FORWARD_ALL, 44 | 45 | /* Apply arm extension instruction set to accelerate some Ops, this forward type 46 | is only used in MNN internal, and will be active automatically when user set forward type 47 | to be MNN_FORWARD_CPU and extension instruction set is valid on hardware. 48 | */ 49 | MNN_FORWARD_CPU_EXTENSION 50 | 51 | } MNNForwardType; 52 | 53 | typedef enum { 54 | // choose one tuning mode Only 55 | MNN_GPU_TUNING_NONE = 1 << 0,/* Forbidden tuning, performance not good */ 56 | MNN_GPU_TUNING_HEAVY = 1 << 1,/* heavily tuning, usually not suggested */ 57 | MNN_GPU_TUNING_WIDE = 1 << 2,/* widely tuning, performance good. 
Default */ 58 | MNN_GPU_TUNING_NORMAL = 1 << 3,/* normal tuning, performance may be ok */ 59 | MNN_GPU_TUNING_FAST = 1 << 4,/* fast tuning, performance may not good */ 60 | 61 | // choose one opencl memory mode Only 62 | /* User can try OpenCL_MEMORY_BUFFER and OpenCL_MEMORY_IMAGE both, 63 | then choose the better one according to performance*/ 64 | MNN_GPU_MEMORY_BUFFER = 1 << 6,/* User assign mode */ 65 | MNN_GPU_MEMORY_IMAGE = 1 << 7,/* User assign mode */ 66 | } MNNGpuMode; 67 | 68 | #ifdef __cplusplus 69 | namespace MNN { 70 | struct BackendConfig { 71 | enum MemoryMode { Memory_Normal = 0, Memory_High, Memory_Low }; 72 | 73 | MemoryMode memory = Memory_Normal; 74 | 75 | enum PowerMode { Power_Normal = 0, Power_High, Power_Low }; 76 | 77 | PowerMode power = Power_Normal; 78 | 79 | enum PrecisionMode { Precision_Normal = 0, Precision_High, Precision_Low, Precision_Low_BF16 }; 80 | 81 | PrecisionMode precision = Precision_Normal; 82 | 83 | /** user defined context */ 84 | union { 85 | void* sharedContext = nullptr; 86 | size_t flags; // Valid for CPU Backend 87 | }; 88 | }; 89 | 90 | /** acquire runtime status by Runtime::getCurrentStatus with following keys, 91 | */ 92 | enum RuntimeStatus { 93 | /** 94 | * get status whether this runtime support 16-bits float point arithmetic 95 | */ 96 | STATUS_SUPPORT_FP16, 97 | /** 98 | * get status whether this runtime support dot-product arithmetic 99 | */ 100 | STATUS_SUPPORT_DOT_PRODUCT, 101 | /** 102 | * emum total number 103 | */ 104 | STATUS_COUNT 105 | }; 106 | 107 | }; // namespace MNN 108 | #endif 109 | #endif /* MNNForwardType_h */ 110 | -------------------------------------------------------------------------------- /include/MNN/MNNSharedContext.h: -------------------------------------------------------------------------------- 1 | // 2 | // MNNSharedContext.h 3 | // MNN 4 | // 5 | // Created by MNN on 2018/10/11. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNNSharedContext_h 10 | #define MNNSharedContext_h 11 | #include "MNNDefine.h" 12 | #ifdef __cplusplus 13 | extern "C" { 14 | #endif 15 | 16 | #include /*uint32_t*/ 17 | #ifdef MNN_VULKAN 18 | 19 | struct MNNVulkanContext { 20 | VkInstance pInstance; 21 | VkPhysicalDevice pPhysicalDevice; 22 | VkDevice pDevice; 23 | VkQueue pQueue; 24 | uint32_t iQueueFamilyIndex; 25 | }; 26 | 27 | #endif 28 | 29 | #ifdef MNN_METAL 30 | struct MNNMetalSharedContext { 31 | id device; 32 | id queue; 33 | }; 34 | 35 | struct MNNMetalTensorContent { 36 | id buffer; 37 | int32_t offset; 38 | id texture; 39 | int32_t forFuture[8]; 40 | }; 41 | 42 | MNN_PUBLIC int MNNMetalGetTensorContent(MNNMetalTensorContent* content, void* tensor); 43 | #endif 44 | 45 | 46 | #ifdef __cplusplus 47 | } 48 | #endif 49 | 50 | #endif /* MNNSharedContext_h */ 51 | -------------------------------------------------------------------------------- /include/MNN/Rect.h: -------------------------------------------------------------------------------- 1 | // 2 | // Rect.h 3 | // MNN 4 | // 5 | // Modified by jiangxiaotang on 2018/09/19. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | /* 10 | * Copyright 2006 The Android Open Source Project 11 | * 12 | * Use of this source code is governed by a BSD-style license that can be 13 | * found in the LICENSE file. 14 | */ 15 | 16 | /* Generated by tools/bookmaker from include/core/Rect.h and docs/SkRect_Reference.bmh 17 | on 2018-07-13 08:15:11. 
Additional documentation and examples can be found at: 18 | https://skia.org/user/api/SkRect_Reference 19 | 20 | You may edit either file directly. Structural changes to public interfaces require 21 | editing both files. After editing docs/SkRect_Reference.bmh, run: 22 | bookmaker -b docs -i include/core/Rect.h -p 23 | to create an updated version of this file. 24 | */ 25 | 26 | #ifndef MNN_Rect_DEFINED 27 | #define MNN_Rect_DEFINED 28 | 29 | #include 30 | #include 31 | #include 32 | #include 33 | 34 | namespace MNN { 35 | namespace CV { 36 | 37 | struct Point { 38 | float fX; 39 | float fY; 40 | 41 | void set(float x, float y) { 42 | fX = x; 43 | fY = y; 44 | } 45 | }; 46 | 47 | /** \struct Rect 48 | Rect holds four float coordinates describing the upper and 49 | lower bounds of a rectangle. Rect may be created from outer bounds or 50 | from position, width, and height. Rect describes an area; if its right 51 | is less than or equal to its left, or if its bottom is less than or equal to 52 | its top, it is considered empty. 53 | */ 54 | struct MNN_PUBLIC Rect { 55 | float fLeft; //!< smaller x-axis bounds 56 | float fTop; //!< smaller y-axis bounds 57 | float fRight; //!< larger x-axis bounds 58 | float fBottom; //!< larger y-axis bounds 59 | 60 | /** Returns constructed Rect set to (0, 0, 0, 0). 61 | Many other rectangles are empty; if left is equal to or greater than right, 62 | or if top is equal to or greater than bottom. Setting all members to zero 63 | is a convenience, but does not designate a special empty rectangle. 64 | 65 | @return bounds (0, 0, 0, 0) 66 | */ 67 | static constexpr Rect MakeEmpty() { 68 | return Rect{0, 0, 0, 0}; 69 | } 70 | 71 | #ifdef SK_SUPPORT_LEGACY_RECTMAKELARGEST 72 | /** Deprecated. 73 | */ 74 | static Rect MakeLargest() { 75 | return {SK_ScalarMin, SK_ScalarMin, SK_ScalarMax, SK_ScalarMax}; 76 | } 77 | #endif 78 | 79 | /** Returns constructed Rect set to float values (0, 0, w, h). Does not 80 | validate input; w or h may be negative. 81 | 82 | Passing integer values may generate a compiler warning since Rect cannot 83 | represent 32-bit integers exactly. Use SkIRect for an exact integer rectangle. 84 | 85 | @param w float width of constructed Rect 86 | @param h float height of constructed Rect 87 | @return bounds (0, 0, w, h) 88 | */ 89 | static constexpr Rect MakeWH(float w, float h) { 90 | return Rect{0, 0, w, h}; 91 | } 92 | 93 | /** Returns constructed Rect set to integer values (0, 0, w, h). Does not validate 94 | input; w or h may be negative. 95 | 96 | Use to avoid a compiler warning that input may lose precision when stored. 97 | Use SkIRect for an exact integer rectangle. 98 | 99 | @param w integer width of constructed Rect 100 | @param h integer height of constructed Rect 101 | @return bounds (0, 0, w, h) 102 | */ 103 | static Rect MakeIWH(int w, int h) { 104 | Rect r; 105 | r.set(0, 0, (float)(w), (float)(h)); 106 | return r; 107 | } 108 | 109 | /** Returns constructed Rect set to (l, t, r, b). Does not sort input; Rect may 110 | result in fLeft greater than fRight, or fTop greater than fBottom. 111 | 112 | @param l float stored in fLeft 113 | @param t float stored in fTop 114 | @param r float stored in fRight 115 | @param b float stored in fBottom 116 | @return bounds (l, t, r, b) 117 | */ 118 | static constexpr Rect MakeLTRB(float l, float t, float r, float b) { 119 | return Rect{l, t, r, b}; 120 | } 121 | 122 | /** Returns constructed Rect set to (x, y, x + w, y + h). Does not validate input; 123 | w or h may be negative. 
124 | 125 | @param x stored in fLeft 126 | @param y stored in fTop 127 | @param w added to x and stored in fRight 128 | @param h added to y and stored in fBottom 129 | @return bounds at (x, y) with width w and height h 130 | */ 131 | static constexpr Rect MakeXYWH(float x, float y, float w, float h) { 132 | return Rect{x, y, x + w, y + h}; 133 | } 134 | 135 | /** Returns true if fLeft is equal to or greater than fRight, or if fTop is equal 136 | to or greater than fBottom. Call sort() to reverse rectangles with negative 137 | width() or height(). 138 | 139 | @return true if width() or height() are zero or negative 140 | */ 141 | bool isEmpty() const { 142 | // We write it as the NOT of a non-empty rect, so we will return true if any values 143 | // are NaN. 144 | return !(fLeft < fRight && fTop < fBottom); 145 | } 146 | 147 | /** Returns true if fLeft is equal to or less than fRight, or if fTop is equal 148 | to or less than fBottom. Call sort() to reverse rectangles with negative 149 | width() or height(). 150 | 151 | @return true if width() or height() are zero or positive 152 | */ 153 | bool isSorted() const { 154 | return fLeft <= fRight && fTop <= fBottom; 155 | } 156 | 157 | /** Returns left edge of Rect, if sorted. Call isSorted() to see if Rect is valid. 158 | Call sort() to reverse fLeft and fRight if needed. 159 | 160 | @return fLeft 161 | */ 162 | float x() const { 163 | return fLeft; 164 | } 165 | 166 | /** Returns top edge of Rect, if sorted. Call isEmpty() to see if Rect may be invalid, 167 | and sort() to reverse fTop and fBottom if needed. 168 | 169 | @return fTop 170 | */ 171 | float y() const { 172 | return fTop; 173 | } 174 | 175 | /** Returns left edge of Rect, if sorted. Call isSorted() to see if Rect is valid. 176 | Call sort() to reverse fLeft and fRight if needed. 177 | 178 | @return fLeft 179 | */ 180 | float left() const { 181 | return fLeft; 182 | } 183 | 184 | /** Returns top edge of Rect, if sorted. Call isEmpty() to see if Rect may be invalid, 185 | and sort() to reverse fTop and fBottom if needed. 186 | 187 | @return fTop 188 | */ 189 | float top() const { 190 | return fTop; 191 | } 192 | 193 | /** Returns right edge of Rect, if sorted. Call isSorted() to see if Rect is valid. 194 | Call sort() to reverse fLeft and fRight if needed. 195 | 196 | @return fRight 197 | */ 198 | float right() const { 199 | return fRight; 200 | } 201 | 202 | /** Returns bottom edge of Rect, if sorted. Call isEmpty() to see if Rect may be invalid, 203 | and sort() to reverse fTop and fBottom if needed. 204 | 205 | @return fBottom 206 | */ 207 | float bottom() const { 208 | return fBottom; 209 | } 210 | 211 | /** Returns span on the x-axis. This does not check if Rect is sorted, or if 212 | result fits in 32-bit float; result may be negative or infinity. 213 | 214 | @return fRight minus fLeft 215 | */ 216 | float width() const { 217 | return fRight - fLeft; 218 | } 219 | 220 | /** Returns span on the y-axis. This does not check if Rect is sorted, or if 221 | result fits in 32-bit float; result may be negative or infinity. 222 | 223 | @return fBottom minus fTop 224 | */ 225 | float height() const { 226 | return fBottom - fTop; 227 | } 228 | 229 | /** Returns average of left edge and right edge. Result does not change if Rect 230 | is sorted. Result may overflow to infinity if Rect is far from the origin. 
231 | 232 | @return midpoint in x 233 | */ 234 | float centerX() const { 235 | // don't use floatHalf(fLeft + fBottom) as that might overflow before the 0.5 236 | return 0.5f * (fLeft) + 0.5f * (fRight); 237 | } 238 | 239 | /** Returns average of top edge and bottom edge. Result does not change if Rect 240 | is sorted. 241 | 242 | @return midpoint in y 243 | */ 244 | float centerY() const { 245 | // don't use floatHalf(fTop + fBottom) as that might overflow before the 0.5 246 | return 0.5f * (fTop) + 0.5f * (fBottom); 247 | } 248 | 249 | /** Sets Rect to (0, 0, 0, 0). 250 | 251 | Many other rectangles are empty; if left is equal to or greater than right, 252 | or if top is equal to or greater than bottom. Setting all members to zero 253 | is a convenience, but does not designate a special empty rectangle. 254 | */ 255 | void setEmpty() { 256 | *this = MakeEmpty(); 257 | } 258 | 259 | /** Sets Rect to (left, top, right, bottom). 260 | left and right are not sorted; left is not necessarily less than right. 261 | top and bottom are not sorted; top is not necessarily less than bottom. 262 | 263 | @param left stored in fLeft 264 | @param top stored in fTop 265 | @param right stored in fRight 266 | @param bottom stored in fBottom 267 | */ 268 | void set(float left, float top, float right, float bottom) { 269 | fLeft = left; 270 | fTop = top; 271 | fRight = right; 272 | fBottom = bottom; 273 | } 274 | 275 | /** Sets Rect to (left, top, right, bottom). 276 | left and right are not sorted; left is not necessarily less than right. 277 | top and bottom are not sorted; top is not necessarily less than bottom. 278 | 279 | @param left stored in fLeft 280 | @param top stored in fTop 281 | @param right stored in fRight 282 | @param bottom stored in fBottom 283 | */ 284 | void setLTRB(float left, float top, float right, float bottom) { 285 | this->set(left, top, right, bottom); 286 | } 287 | 288 | /** Sets Rect to (left, top, right, bottom). 289 | All parameters are promoted from integer to scalar. 290 | left and right are not sorted; left is not necessarily less than right. 291 | top and bottom are not sorted; top is not necessarily less than bottom. 292 | 293 | @param left promoted to float and stored in fLeft 294 | @param top promoted to float and stored in fTop 295 | @param right promoted to float and stored in fRight 296 | @param bottom promoted to float and stored in fBottom 297 | */ 298 | void iset(int left, int top, int right, int bottom) { 299 | fLeft = (float)(left); 300 | fTop = (float)(top); 301 | fRight = (float)(right); 302 | fBottom = (float)(bottom); 303 | } 304 | 305 | /** Sets Rect to (0, 0, width, height). 306 | width and height may be zero or negative. width and height are promoted from 307 | integer to float, large values may lose precision. 308 | 309 | @param width promoted to float and stored in fRight 310 | @param height promoted to float and stored in fBottom 311 | */ 312 | void isetWH(int width, int height) { 313 | fLeft = fTop = 0; 314 | fRight = (float)(width); 315 | fBottom = (float)(height); 316 | } 317 | 318 | /** Sets Rect to (x, y, x + width, y + height). Does not validate input; 319 | width or height may be negative. 
320 | 321 | @param x stored in fLeft 322 | @param y stored in fTop 323 | @param width added to x and stored in fRight 324 | @param height added to y and stored in fBottom 325 | */ 326 | void setXYWH(float x, float y, float width, float height) { 327 | fLeft = x; 328 | fTop = y; 329 | fRight = x + width; 330 | fBottom = y + height; 331 | } 332 | 333 | /** Sets Rect to (0, 0, width, height). Does not validate input; 334 | width or height may be negative. 335 | 336 | @param width stored in fRight 337 | @param height stored in fBottom 338 | */ 339 | void setWH(float width, float height) { 340 | fLeft = 0; 341 | fTop = 0; 342 | fRight = width; 343 | fBottom = height; 344 | } 345 | 346 | /** Returns Rect offset by (dx, dy). 347 | 348 | If dx is negative, Rect returned is moved to the left. 349 | If dx is positive, Rect returned is moved to the right. 350 | If dy is negative, Rect returned is moved upward. 351 | If dy is positive, Rect returned is moved downward. 352 | 353 | @param dx added to fLeft and fRight 354 | @param dy added to fTop and fBottom 355 | @return Rect offset on axes, with original width and height 356 | */ 357 | Rect makeOffset(float dx, float dy) const { 358 | return MakeLTRB(fLeft + dx, fTop + dy, fRight + dx, fBottom + dy); 359 | } 360 | 361 | /** Returns Rect, inset by (dx, dy). 362 | 363 | If dx is negative, Rect returned is wider. 364 | If dx is positive, Rect returned is narrower. 365 | If dy is negative, Rect returned is taller. 366 | If dy is positive, Rect returned is shorter. 367 | 368 | @param dx added to fLeft and subtracted from fRight 369 | @param dy added to fTop and subtracted from fBottom 370 | @return Rect inset symmetrically left and right, top and bottom 371 | */ 372 | Rect makeInset(float dx, float dy) const { 373 | return MakeLTRB(fLeft + dx, fTop + dy, fRight - dx, fBottom - dy); 374 | } 375 | 376 | /** Returns Rect, outset by (dx, dy). 377 | 378 | If dx is negative, Rect returned is narrower. 379 | If dx is positive, Rect returned is wider. 380 | If dy is negative, Rect returned is shorter. 381 | If dy is positive, Rect returned is taller. 382 | 383 | @param dx subtracted to fLeft and added from fRight 384 | @param dy subtracted to fTop and added from fBottom 385 | @return Rect outset symmetrically left and right, top and bottom 386 | */ 387 | Rect makeOutset(float dx, float dy) const { 388 | return MakeLTRB(fLeft - dx, fTop - dy, fRight + dx, fBottom + dy); 389 | } 390 | 391 | /** Offsets Rect by adding dx to fLeft, fRight; and by adding dy to fTop, fBottom. 392 | 393 | If dx is negative, moves Rect to the left. 394 | If dx is positive, moves Rect to the right. 395 | If dy is negative, moves Rect upward. 396 | If dy is positive, moves Rect downward. 397 | 398 | @param dx offset added to fLeft and fRight 399 | @param dy offset added to fTop and fBottom 400 | */ 401 | void offset(float dx, float dy) { 402 | fLeft += dx; 403 | fTop += dy; 404 | fRight += dx; 405 | fBottom += dy; 406 | } 407 | 408 | /** Offsets Rect so that fLeft equals newX, and fTop equals newY. width and height 409 | are unchanged. 410 | 411 | @param newX stored in fLeft, preserving width() 412 | @param newY stored in fTop, preserving height() 413 | */ 414 | void offsetTo(float newX, float newY) { 415 | fRight += newX - fLeft; 416 | fBottom += newY - fTop; 417 | fLeft = newX; 418 | fTop = newY; 419 | } 420 | 421 | /** Insets Rect by (dx, dy). 422 | 423 | If dx is positive, makes Rect narrower. 424 | If dx is negative, makes Rect wider. 425 | If dy is positive, makes Rect shorter. 
426 | If dy is negative, makes Rect taller. 427 | 428 | @param dx added to fLeft and subtracted from fRight 429 | @param dy added to fTop and subtracted from fBottom 430 | */ 431 | void inset(float dx, float dy) { 432 | fLeft += dx; 433 | fTop += dy; 434 | fRight -= dx; 435 | fBottom -= dy; 436 | } 437 | 438 | /** Outsets Rect by (dx, dy). 439 | 440 | If dx is positive, makes Rect wider. 441 | If dx is negative, makes Rect narrower. 442 | If dy is positive, makes Rect taller. 443 | If dy is negative, makes Rect shorter. 444 | 445 | @param dx subtracted to fLeft and added from fRight 446 | @param dy subtracted to fTop and added from fBottom 447 | */ 448 | void outset(float dx, float dy) { 449 | this->inset(-dx, -dy); 450 | } 451 | 452 | private: 453 | static bool Intersects(float al, float at, float ar, float ab, float bl, float bt, float br, float bb) { 454 | float L = std::max(al, bl); 455 | float R = std::min(ar, br); 456 | float T = std::max(at, bt); 457 | float B = std::min(ab, bb); 458 | return L < R && T < B; 459 | } 460 | 461 | public: 462 | /** Constructs Rect to intersect from (left, top, right, bottom). Does not sort 463 | construction. 464 | 465 | Returns true if Rect intersects construction. 466 | Returns false if either construction or Rect is empty, or do not intersect. 467 | 468 | @param left x-axis minimum of constructed Rect 469 | @param top y-axis minimum of constructed Rect 470 | @param right x-axis maximum of constructed Rect 471 | @param bottom y-axis maximum of constructed Rect 472 | @return true if construction and Rect have area in common 473 | */ 474 | bool intersects(float left, float top, float right, float bottom) const { 475 | return Intersects(fLeft, fTop, fRight, fBottom, left, top, right, bottom); 476 | } 477 | 478 | /** Returns true if Rect intersects r. 479 | Returns false if either r or Rect is empty, or do not intersect. 480 | 481 | @param r Rect to intersect 482 | @return true if r and Rect have area in common 483 | */ 484 | bool intersects(const Rect& r) const { 485 | return Intersects(fLeft, fTop, fRight, fBottom, r.fLeft, r.fTop, r.fRight, r.fBottom); 486 | } 487 | 488 | /** Returns true if a intersects b. 489 | Returns false if either a or b is empty, or do not intersect. 490 | 491 | @param a Rect to intersect 492 | @param b Rect to intersect 493 | @return true if a and b have area in common 494 | */ 495 | static bool Intersects(const Rect& a, const Rect& b) { 496 | return Intersects(a.fLeft, a.fTop, a.fRight, a.fBottom, b.fLeft, b.fTop, b.fRight, b.fBottom); 497 | } 498 | 499 | /** Sets Rect to the union of itself and r. 500 | 501 | Asserts if r is empty and SK_DEBUG is defined. 502 | If Rect is empty, sets Rect to r. 503 | 504 | May produce incorrect results if r is empty. 505 | 506 | @param r expansion Rect 507 | */ 508 | void joinNonEmptyArg(const Rect& r) { 509 | MNN_ASSERT(!r.isEmpty()); 510 | // if we are empty, just assign 511 | if (fLeft >= fRight || fTop >= fBottom) { 512 | *this = r; 513 | } else { 514 | this->joinPossiblyEmptyRect(r); 515 | } 516 | } 517 | 518 | /** Sets Rect to the union of itself and the construction. 519 | 520 | May produce incorrect results if Rect or r is empty. 521 | 522 | @param r expansion Rect 523 | */ 524 | void joinPossiblyEmptyRect(const Rect& r) { 525 | fLeft = std::min(fLeft, r.left()); 526 | fTop = std::min(fTop, r.top()); 527 | fRight = std::max(fRight, r.right()); 528 | fBottom = std::max(fBottom, r.bottom()); 529 | } 530 | 531 | /** Returns true if: fLeft <= x < fRight && fTop <= y < fBottom. 
532 | Returns false if Rect is empty. 533 | 534 | @param x test Point x-coordinate 535 | @param y test Point y-coordinate 536 | @return true if (x, y) is inside Rect 537 | */ 538 | bool contains(float x, float y) const { 539 | return x >= fLeft && x < fRight && y >= fTop && y < fBottom; 540 | } 541 | 542 | /** Swaps fLeft and fRight if fLeft is greater than fRight; and swaps 543 | fTop and fBottom if fTop is greater than fBottom. Result may be empty; 544 | and width() and height() will be zero or positive. 545 | */ 546 | void sort() { 547 | using std::swap; 548 | if (fLeft > fRight) { 549 | swap(fLeft, fRight); 550 | } 551 | 552 | if (fTop > fBottom) { 553 | swap(fTop, fBottom); 554 | } 555 | } 556 | 557 | /** Returns Rect with fLeft and fRight swapped if fLeft is greater than fRight; and 558 | with fTop and fBottom swapped if fTop is greater than fBottom. Result may be empty; 559 | and width() and height() will be zero or positive. 560 | 561 | @return sorted Rect 562 | */ 563 | Rect makeSorted() const { 564 | return MakeLTRB(std::min(fLeft, fRight), std::min(fTop, fBottom), std::max(fLeft, fRight), 565 | std::max(fTop, fBottom)); 566 | } 567 | 568 | /** Returns pointer to first scalar in Rect, to treat it as an array with four 569 | entries. 570 | 571 | @return pointer to fLeft 572 | */ 573 | const float* asScalars() const { 574 | return &fLeft; 575 | } 576 | }; 577 | 578 | } // namespace CV 579 | } // namespace MNN 580 | #endif 581 | -------------------------------------------------------------------------------- /include/MNN/Tensor.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Tensor.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2018/08/14. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_Tensor_hpp 10 | #define MNN_Tensor_hpp 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | namespace MNN { 17 | 18 | /** 19 | * data container. 20 | * data for host tensor is saved in `host` field. its memory is allocated malloc directly. 21 | * data for device tensor is saved in `deviceId` field. its memory is allocated by session's backend. 22 | * usually, device tensors are created by engine (like net, session). 23 | * meanwhile, host tensors could be created by engine or user. 24 | */ 25 | class MNN_PUBLIC Tensor { 26 | public: 27 | struct InsideDescribe; 28 | 29 | /** dimension type used to create tensor */ 30 | enum DimensionType { 31 | /** for tensorflow net type. uses NHWC as data format. */ 32 | TENSORFLOW, 33 | /** for caffe net type. uses NCHW as data format. */ 34 | CAFFE, 35 | /** for caffe net type. uses NC4HW4 as data format. */ 36 | CAFFE_C4 37 | }; 38 | 39 | /** handle type */ 40 | enum HandleDataType { 41 | /** default handle type */ 42 | HANDLE_NONE = 0, 43 | /** string handle type */ 44 | HANDLE_STRING = 1 45 | }; 46 | 47 | /** Tensor map type : Read or Write*/ 48 | enum MapType { 49 | /** map Tensor for writing data*/ 50 | MAP_TENSOR_WRITE = 0, 51 | MAP_TENSOR_READ = 1 52 | }; 53 | 54 | public: 55 | /** 56 | * @brief create a tensor with dimension size and type without acquire memory for data. 57 | * @param dimSize dimension size. 58 | * @param type dimension type. 59 | */ 60 | Tensor(int dimSize = 4, DimensionType type = CAFFE); 61 | 62 | /** 63 | * @brief create a tensor with same shape as given tensor. 64 | * @param tensor shape provider. 65 | * @param type dimension type. 66 | * @param allocMemory acquire memory for data or not. 67 | * @warning tensor data won't be copied. 
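// [Inserted example - not part of the upstream headers] Small usage sketch for the
// MNN::CV::Rect struct defined above: build two rects, sort one defensively, and test
// for overlap. The 512x512 bounds are an assumption (the demo's output resolution),
// not something this header prescribes.
static bool roiOverlapsImage(float x, float y, float w, float h) {
    MNN::CV::Rect roi    = MNN::CV::Rect::MakeXYWH(x, y, w, h);
    MNN::CV::Rect bounds = MNN::CV::Rect::MakeWH(512.0f, 512.0f);
    roi.sort();                                    // guards against negative w/h inputs
    return MNN::CV::Rect::Intersects(roi, bounds);
}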
68 | */ 69 | Tensor(const Tensor* tensor, DimensionType type = CAFFE, bool allocMemory = true); 70 | 71 | /** deinitializer */ 72 | ~Tensor(); 73 | 74 | private: 75 | Tensor(bool deepCopy, const Tensor* tensor); 76 | // remove all assignment operator 77 | Tensor(const Tensor& tensor) = delete; 78 | Tensor(const Tensor&& tensor) = delete; 79 | Tensor& operator=(const Tensor&) = delete; 80 | Tensor& operator=(const Tensor&&) = delete; 81 | 82 | public: 83 | /** 84 | * @brief create tensor with shape, data type and dimension type. 85 | * @param shape tensor shape. 86 | * @param type data type. 87 | * @param dimType dimension type. 88 | * @return created tensor. 89 | * @warning memory for data won't be acquired. call backend's onAcquireBuffer to get memory ready. 90 | */ 91 | static Tensor* createDevice(const std::vector& shape, halide_type_t type, DimensionType dimType = TENSORFLOW); 92 | 93 | /** 94 | * @brief create tensor with shape and dimension type. data type is represented by `T`. 95 | * @param shape tensor shape. 96 | * @param dimType dimension type. 97 | * @return created tensor. 98 | * @warning memory for data won't be acquired. call backend's onAcquireBuffer to get memory ready. 99 | */ 100 | template 101 | static Tensor* createDevice(const std::vector& shape, DimensionType dimType = TENSORFLOW) { 102 | return createDevice(shape, halide_type_of(), dimType); 103 | } 104 | 105 | /** 106 | * @brief create tensor with shape, data type, data and dimension type. 107 | * @param shape tensor shape. 108 | * @param type data type. 109 | * @param data data to save. 110 | * @param dimType dimension type. 111 | * @return created tensor. 112 | */ 113 | static Tensor* create(const std::vector& shape, halide_type_t type, void* data = NULL, 114 | DimensionType dimType = TENSORFLOW); 115 | 116 | /** 117 | * @brief create tensor with shape, data and dimension type. data type is represented by `T`. 118 | * @param shape tensor shape. 119 | * @param data data to save. 120 | * @param dimType dimension type. 121 | * @return created tensor. 122 | */ 123 | template 124 | static Tensor* create(const std::vector& shape, void* data = NULL, DimensionType dimType = TENSORFLOW) { 125 | return create(shape, halide_type_of(), data, dimType); 126 | } 127 | 128 | /** 129 | * @brief copy tensor. 130 | * @param src tensor 131 | * @param deepCopy whether create new content and copy, currently only support deepCopy = false 132 | */ 133 | static Tensor* clone(const Tensor* src, bool deepCopy = false); 134 | 135 | /** 136 | * @brief delete tensor. 137 | * @param src tensor 138 | */ 139 | static void destroy(Tensor* tensor); 140 | public: 141 | /** 142 | * @brief for DEVICE tensor, copy data from given host tensor. 143 | * @param hostTensor host tensor, the data provider. 144 | * @return true for DEVICE tensor, and false for HOST tensor. 145 | */ 146 | bool copyFromHostTensor(const Tensor* hostTensor); 147 | 148 | /** 149 | * @brief for DEVICE tensor, copy data to given host tensor. 150 | * @param hostTensor host tensor, the data consumer. 151 | * @return true for DEVICE tensor, and false for HOST tensor. 152 | */ 153 | bool copyToHostTensor(Tensor* hostTensor) const; 154 | 155 | /** 156 | * @brief create HOST tensor from DEVICE tensor, with or without data copying. 157 | * @param deviceTensor given device tensor. 158 | * @param copyData copy data or not. 159 | * @return created host tensor. 
160 | */ 161 | static Tensor* createHostTensorFromDevice(const Tensor* deviceTensor, bool copyData = true); 162 | 163 | public: 164 | const halide_buffer_t& buffer() const { 165 | return mBuffer; 166 | } 167 | halide_buffer_t& buffer() { 168 | return mBuffer; 169 | } 170 | 171 | /** 172 | * @brief get dimension type. 173 | * @return dimension type. 174 | */ 175 | DimensionType getDimensionType() const; 176 | 177 | /** 178 | * @brief handle data type. used when data type code is halide_type_handle. 179 | * @return handle data type. 180 | */ 181 | HandleDataType getHandleDataType() const; 182 | 183 | /** 184 | * @brief set data type. 185 | * @param type data type defined in 'Type_generated.h'. 186 | */ 187 | void setType(int type); 188 | 189 | /** 190 | * @brief get data type. 191 | * @return data type. 192 | */ 193 | inline halide_type_t getType() const { 194 | return mBuffer.type; 195 | } 196 | 197 | /** 198 | * @brief visit host memory, data type is represented by `T`. 199 | * @return data point in `T` type. 200 | */ 201 | template 202 | T* host() const { 203 | return (T*)mBuffer.host; 204 | } 205 | 206 | /** 207 | * @brief visit device memory. 208 | * @return device data ID. what the ID means varies between backends. 209 | */ 210 | uint64_t deviceId() const { 211 | return mBuffer.device; 212 | } 213 | 214 | public: 215 | int dimensions() const { 216 | return mBuffer.dimensions; 217 | } 218 | 219 | /** 220 | * @brief get all dimensions' extent. 221 | * @return dimensions' extent. 222 | */ 223 | std::vector shape() const; 224 | 225 | /** 226 | * @brief calculate number of bytes needed to store data taking reordering flag into account. 227 | * @return bytes needed to store data 228 | */ 229 | int size() const; 230 | 231 | /** 232 | * @brief calculate number of elements needed to store data taking reordering flag into account. 233 | * @return elements needed to store data 234 | */ 235 | inline int elementSize() const { 236 | return size() / mBuffer.type.bytes(); 237 | } 238 | 239 | public: 240 | inline int width() const { 241 | if (getDimensionType() == TENSORFLOW) { 242 | return mBuffer.dim[2].extent; 243 | } 244 | 245 | return mBuffer.dim[3].extent; 246 | } 247 | inline int height() const { 248 | if (getDimensionType() == TENSORFLOW) { 249 | return mBuffer.dim[1].extent; 250 | } 251 | return mBuffer.dim[2].extent; 252 | } 253 | inline int channel() const { 254 | if (getDimensionType() == TENSORFLOW) { 255 | return mBuffer.dim[3].extent; 256 | } 257 | return mBuffer.dim[1].extent; 258 | } 259 | inline int batch() const { 260 | return mBuffer.dim[0].extent; 261 | } 262 | 263 | // visit dimension's extent & stride 264 | inline int stride(int index) const { 265 | return mBuffer.dim[index].stride; 266 | } 267 | inline int length(int index) const { 268 | return mBuffer.dim[index].extent; 269 | } 270 | inline void setStride(int index, int stride) { 271 | mBuffer.dim[index].stride = stride; 272 | } 273 | inline void setLength(int index, int length) { 274 | mBuffer.dim[index].extent = length; 275 | } 276 | 277 | public: 278 | /** 279 | * @brief print tensor data. for DEBUG use only. 
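// [Inserted example - not part of the upstream headers] Creating and filling a host
// tensor with the Tensor API declared above. The 1x4x64x64 NCHW (CAFFE) shape mirrors a
// typical 512x512 latent buffer and is only an assumption for illustration.
static void fillLatent() {
    MNN::Tensor* latent = MNN::Tensor::create<float>({1, 4, 64, 64}, nullptr, MNN::Tensor::CAFFE);
    float* data = latent->host<float>();           // host buffer owned by the tensor
    for (int i = 0; i < latent->elementSize(); ++i) {
        data[i] = 0.0f;                            // the real pipeline would write Gaussian noise here
    }
    MNN::Tensor::destroy(latent);
}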
280 | */ 281 | void print() const; 282 | 283 | /** 284 | *@brief print tensor shape 285 | */ 286 | void printShape() const; 287 | 288 | public: 289 | /** 290 | * @brief map/umap GPU Tensor, to get host ptr 291 | */ 292 | void* map(MapType mtype, DimensionType dtype); 293 | void unmap(MapType mtype, DimensionType dtype, void* mapPtr); 294 | /** 295 | * @brief wait until the tensor is ready to read / write 296 | * @param mtype wait for read or write 297 | * @param finish wait for command flush or finish 298 | */ 299 | int wait(MapType mtype, bool finish); 300 | private: 301 | halide_buffer_t mBuffer; 302 | struct InsideDescribe* mDescribe; 303 | 304 | private: 305 | friend class TensorUtils; 306 | }; 307 | } // namespace MNN 308 | 309 | #endif /* Tensor_hpp */ 310 | -------------------------------------------------------------------------------- /include/MNN/expr/Executor.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Executor.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/07/25. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | #ifndef MNN_Executor_hpp 9 | #define MNN_Executor_hpp 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | namespace MNN { 19 | class Backend; 20 | class Execution; 21 | class Runtime; 22 | struct Op; 23 | namespace Express { 24 | struct RuntimeAttr; 25 | class MNN_PUBLIC Executor { 26 | public: 27 | class ComputeCache; 28 | struct Unit; 29 | struct DebugTools; 30 | /**Internal Usage Begin*/ 31 | static void setShapeDirty(ComputeCache* cache); 32 | static void setContentDirty(ComputeCache* cache); 33 | static Tensor* getOutput(ComputeCache* cache, int offset); 34 | static std::pair, std::shared_ptr> getBackends(ComputeCache* cache); 35 | static void* mapOutput(ComputeCache* cache, int offset, Tensor* dest); 36 | struct Requirement { 37 | std::vector contentNeedContent; 38 | std::vector shapeNeedContent; 39 | }; 40 | ~Executor(); 41 | Requirement getRequirement(Expr* expr) const; 42 | ErrorCode computeInfo(Expr* expr); 43 | void makeCache(const std::vector& expr, bool forceCPU = false); 44 | ErrorCode runCache(std::shared_ptr cache); 45 | bool lazyEval = true; 46 | /**Internal Usage End*/ 47 | 48 | void setGlobalExecutorConfig(MNNForwardType type, const BackendConfig& config, int numberThread); 49 | int getCurrentRuntimeStatus(RuntimeStatus statusEnum); 50 | enum GCFlag { 51 | FULL, 52 | PART 53 | }; 54 | void gc(GCFlag flag = FULL); 55 | static std::shared_ptr getGlobalExecutor(); 56 | 57 | static std::shared_ptr newExecutor(MNNForwardType type, 58 | const BackendConfig& config, 59 | int numberThread); 60 | void resetProfile(); 61 | void dumpProfile(); 62 | /**Internal Usage Begin*/ 63 | void addOpCostTime(int op, float costTime); 64 | void addOpCostTime(const std::string& type, float costTime); 65 | void addOpFlops(const std::string& type, float flops); 66 | class Profiler; 67 | /**Internal Usage End*/ 68 | static RuntimeInfo getRuntime(); 69 | void setCallBack(TensorCallBackWithInfo&& before, TensorCallBackWithInfo&& after); 70 | const DebugTools* getDebugTools() const { 71 | return mDebug.get(); 72 | } 73 | class MNN_PUBLIC RuntimeManager { 74 | public: 75 | ~RuntimeManager(); 76 | /** 77 | * @param configs : schedule configs. 78 | * @param cacheName : full path for cache file. Note: should choose location for reading and writing. 
79 | */ 80 | static RuntimeManager* createRuntimeManager(const ScheduleConfig& config); 81 | 82 | /** 83 | * @param rtmgr : the rtmgr to destroy 84 | */ 85 | static void destroy(RuntimeManager* rtmgr); 86 | 87 | /** 88 | * Deceperate, the same as createRuntimeManager(configs[0]) 89 | * @param configs : schedule configs. 90 | * @param cacheName : full path for cache file. Note: should choose location for reading and writing. 91 | */ 92 | static RuntimeManager* createRuntimeManager(std::vector& configs); 93 | 94 | /** 95 | * @brief set cache file. when file not exist -- create it, when file exist -- load it. 96 | * When should use : When choose GPU backend or use AUTO backend. 97 | * Calling Position: calling after createRuntimeManager. 98 | */ 99 | void setCache(std::string cacheName); 100 | 101 | /** 102 | * @brief update cache file 103 | * When should use : Together with setCache API. calling for first inference and when input shape is changed. 104 | * Calling Position : calling after inference done. 105 | */ 106 | void updateCache(); 107 | std::vector isBackendSupport(const std::vector type); 108 | friend class Executor; 109 | void setMode(Interpreter::SessionMode mode); 110 | void setHint(Interpreter::HintMode mode, int value); 111 | bool getInfo(Interpreter::SessionInfoCode code, void* ptr); 112 | BackendConfig* getBnConfig(); 113 | const RuntimeAttr* getInside() const { 114 | return mInside; 115 | } 116 | private: 117 | RuntimeAttr* mInside; 118 | friend class StaticModule; 119 | RuntimeManager(); 120 | }; 121 | private: 122 | void _makeCache(const std::vector& outputs, bool forceCPU); 123 | void _create(const std::vector& outputs, std::set>&& inputCaches, std::set>&& inputNode, bool forceCPU); 124 | 125 | void _visit(EXPRP expr, std::set>& inputCaches, std::set>& inputNode); 126 | std::map, std::shared_ptr> mRuntimes; 127 | 128 | Executor(std::shared_ptr backend, MNNForwardType type, int numberThread); 129 | std::mutex mMutex; 130 | std::shared_ptr mProfiler; 131 | std::shared_ptr mDebug; 132 | 133 | std::pair mFirstType; 134 | }; 135 | } // namespace Express 136 | } // namespace MNN 137 | #endif 138 | -------------------------------------------------------------------------------- /include/MNN/expr/ExecutorScope.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // ExecutorScope.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2020/10/26. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_EXPR_EXECUTOR_SCOPE_HPP_ 10 | #define MNN_EXPR_EXECUTOR_SCOPE_HPP_ 11 | 12 | #include 13 | 14 | namespace MNN { 15 | namespace Express { 16 | 17 | struct MNN_PUBLIC ExecutorScope final { 18 | public: 19 | ExecutorScope() = delete; 20 | explicit ExecutorScope(const ExecutorScope&) = delete; 21 | explicit ExecutorScope(const std::shared_ptr& current); 22 | 23 | explicit ExecutorScope(const std::string& scope_name, 24 | const std::shared_ptr& current); 25 | 26 | virtual ~ExecutorScope(); 27 | 28 | static const std::shared_ptr Current(); 29 | }; 30 | 31 | } // namespace MNN 32 | } // namespace Express 33 | #endif // MNN_EXPR_EXECUTOR_SCOPE_HPP_ 34 | -------------------------------------------------------------------------------- /include/MNN/expr/Expr.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Expr.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/06/10. 
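// [Inserted example - not part of the upstream headers] Creating a dedicated Executor and
// activating it with ExecutorScope, using only the Executor / ExecutorScope declarations
// above. Thread count and precision are assumptions; the project may configure these differently.
static void runWithOwnExecutor() {
    MNN::BackendConfig config;
    config.precision = MNN::BackendConfig::Precision_Low;
    std::shared_ptr<MNN::Express::Executor> exe =
        MNN::Express::Executor::newExecutor(MNN_FORWARD_CPU, config, 4);
    MNN::Express::ExecutorScope scope(exe);        // expressions evaluated in this scope use `exe`
    // ... run the text encoder / UNet / VAE decoder here ...
}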
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_Expr_hpp 10 | #define MNN_Expr_hpp 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | namespace MNN { 21 | struct BufferStorage; 22 | struct OpT; 23 | struct Op; 24 | struct NetT; 25 | class Tensor; 26 | namespace Express { 27 | class Variable; 28 | class Expr; 29 | class Executor; 30 | typedef std::shared_ptr EXPRP; 31 | typedef std::weak_ptr WeakEXPRP; 32 | typedef std::vector INTS; 33 | enum Dimensionformat { NHWC, NC4HW4, NCHW }; 34 | class MNN_PUBLIC VARP { 35 | public: 36 | VARP() { 37 | // Do nothing 38 | } 39 | VARP(std::shared_ptr c) { 40 | mContent = std::move(c); 41 | } 42 | VARP(Variable* c) { 43 | mContent.reset(c); 44 | } 45 | Variable* get() const { 46 | return mContent.get(); 47 | } 48 | ~ VARP() { 49 | // Do nothing 50 | } 51 | VARP(const VARP& var) { 52 | mContent = var.mContent; 53 | } 54 | VARP(VARP&& var) { 55 | mContent = std::move(var.mContent); 56 | } 57 | VARP operator+(VARP var) const; 58 | VARP operator-(VARP var) const; 59 | VARP operator*(VARP var) const; 60 | VARP operator/(VARP var) const; 61 | VARP mean(INTS dims) const; 62 | VARP sum(INTS dims) const; 63 | 64 | bool operator==(const VARP& var) const { 65 | return var.mContent == mContent; 66 | } 67 | bool operator<(const VARP& var) const { 68 | return mContent < var.mContent; 69 | } 70 | bool operator<=(const VARP& var) const { 71 | return mContent <= var.mContent; 72 | } 73 | VARP& operator=(const VARP& var) { 74 | mContent = var.mContent; 75 | return *this; 76 | } 77 | VARP& operator=(Variable* var) { 78 | mContent.reset(var); 79 | return *this; 80 | } 81 | Variable* operator->() const { 82 | return mContent.get(); 83 | } 84 | enum InputType { 85 | INPUT = 0, 86 | CONSTANT = 1, 87 | TRAINABLE = 2, 88 | }; 89 | bool fix(InputType type) const; 90 | private: 91 | friend class Variable; 92 | std::shared_ptr mContent; 93 | }; 94 | inline bool operator==(Variable* src, VARP dst) { 95 | return src == dst.get(); 96 | } 97 | inline bool operator!=(Variable* src, VARP dst) { 98 | return src != dst.get(); 99 | } 100 | // inline bool operator<(VARP src, VARP dst) { 101 | // return src.get() < dst.get(); 102 | // } 103 | typedef std::vector VARPS; 104 | 105 | class MNN_PUBLIC Variable { 106 | public: 107 | struct Info { 108 | Dimensionformat order = NHWC; 109 | INTS dim; 110 | halide_type_t type; 111 | int size; 112 | void syncSize(); 113 | }; 114 | const std::string& name() const; 115 | void setName(const std::string& name); 116 | std::pair expr() const { 117 | return std::make_pair(mFrom, mFromIndex); 118 | } 119 | // If compute info error, return nullptr 120 | const Info* getInfo(); 121 | bool resize(INTS dims); 122 | template 123 | const T* readMap() { 124 | return (const T*)readInternal(); 125 | } 126 | 127 | template 128 | T* writeMap() { 129 | return (T*)writeInternal(); 130 | } 131 | 132 | //Depecerate 133 | void unMap(); 134 | 135 | bool input(VARP src); 136 | static void replace(VARP dst, VARP src); 137 | 138 | static VARP create(EXPRP expr, int index = 0); 139 | 140 | static std::vector load(const char* fileName); 141 | static std::map loadMap(const char* fileName); 142 | static std::vector load(const uint8_t* buffer, size_t length); 143 | static std::map loadMap(const uint8_t* buffer, size_t length); 144 | static std::pair, std::map> getInputAndOutput(const std::map& allVariable); 145 | static std::vector mapToSequence(const std::map& source); 146 | static 
std::vector getExecuteOrder(const std::vector& output); 147 | static void save(const std::vector& vars, const char* fileName); 148 | static std::vector save(const std::vector& vars); 149 | static void save(const std::vector& vars, NetT* dest); 150 | 151 | // Pack a few Variable to compute in one pipeline 152 | static void prepareCompute(const std::vector& vars, bool forceCPU = false); 153 | static void compute(const std::vector& vars, bool forceCPU = false); 154 | 155 | size_t linkNumber() const; 156 | const std::vector& toExprs() const; 157 | void setExpr(EXPRP expr, int index) { 158 | mFrom = expr; 159 | mFromIndex = index; 160 | } 161 | 162 | private: 163 | Variable(EXPRP expr, int index) { 164 | mFrom = expr; 165 | mFromIndex = index; 166 | } 167 | 168 | void* readInternal(bool forShape = false); 169 | void* writeInternal(bool inform=true); 170 | void informDirty(); 171 | 172 | friend class Expr; 173 | EXPRP mFrom; 174 | int mFromIndex; 175 | }; 176 | class MNN_PUBLIC Expr { 177 | public: 178 | struct Inside; 179 | enum MemoryType { 180 | COPY, 181 | MOVE, 182 | REF 183 | }; 184 | static EXPRP create(Tensor* tensor, bool own = false); 185 | 186 | static EXPRP create(Variable::Info&& info, const void* ptr, VARP::InputType type, MemoryType copy = COPY); 187 | static EXPRP create(const OpT* op, std::vector inputs, int outputSize = 1); 188 | static EXPRP create(std::shared_ptr extra, std::vector&& inputs, int outputSize = 1); 189 | static EXPRP create(std::unique_ptr&& op, std::vector inputs, int outputSize = 1) { 190 | return create(op.get(), inputs, outputSize); 191 | } 192 | void setName(const std::string& name); 193 | 194 | const Op* get() const { 195 | return mOp; 196 | } 197 | const std::vector& inputs() const { 198 | return mInputs; 199 | } 200 | int outputSize() const { 201 | return (int)mOutputNames.size(); 202 | } 203 | static void replace(EXPRP oldExpr, EXPRP newExpr); 204 | bool requireInfo(); 205 | void visitOutputs(const std::function& visit); 206 | static void visit(EXPRP expr, const std::function& before, const std::function& after); 207 | 208 | const std::vector& outputs() const { 209 | return mTo; 210 | } 211 | ~Expr(); 212 | 213 | bool visited() const { 214 | return mVisited; 215 | } 216 | void setVisited(bool visited) { 217 | mVisited = visited; 218 | } 219 | const std::string& name() const { 220 | return mName; 221 | } 222 | const std::string& outputName(int index) { 223 | return mOutputNames[index]; 224 | } 225 | 226 | VARP::InputType inputType() const {return mType;} 227 | Variable::Info* outputInfo(int index) const; 228 | std::shared_ptr extra() const { 229 | return mStorage; 230 | } 231 | bool setInfoDirty(); 232 | std::shared_ptr inside() const { 233 | return mInside; 234 | } 235 | bool valid() const { 236 | return mValid; 237 | } 238 | 239 | private: 240 | static void _addLinkForInputs(EXPRP expr); 241 | 242 | Expr(int outputSize); 243 | Expr(Tensor* tensor, bool own = false); 244 | 245 | friend class Variable; 246 | friend class VARP; 247 | VARP::InputType mType; 248 | const Op* mOp; 249 | std::vector mInputs; 250 | std::vector mOutputNames; 251 | 252 | bool mValid = true; 253 | std::shared_ptr mStorage; 254 | std::string mName; 255 | std::shared_ptr mInside = nullptr; 256 | bool mVisited = false; 257 | std::vector mTo; 258 | 259 | }; 260 | } // namespace Express 261 | } // namespace MNN 262 | 263 | #endif /* Expr_hpp */ 264 | -------------------------------------------------------------------------------- /include/MNN/expr/ExprCreator.hpp: 
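// [Inserted example - not part of the upstream headers] Loading variables from an .mnn
// file and reading one back on the host with the Variable / VARP API from Expr.hpp above.
// The file name is an assumption taken from the resource folder described in the README.
static void inspectFirstOutput(const char* mnnPath /* e.g. "resource/vae_decoder.mnn" */) {
    std::vector<MNN::Express::VARP> vars = MNN::Express::Variable::load(mnnPath);
    if (vars.empty()) {
        return;                                    // load failed or the file contains no variables
    }
    const MNN::Express::Variable::Info* info = vars[0]->getInfo();
    if (info == nullptr) {
        return;                                    // shape inference failed
    }
    const float* host = vars[0]->readMap<float>(); // map host memory for reading
    (void)host;
}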
-------------------------------------------------------------------------------- 1 | // 2 | // ExprCreator.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/06/27. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_ExprCreator_hpp 10 | #define MNN_ExprCreator_hpp 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | #endif 17 | -------------------------------------------------------------------------------- /include/MNN/expr/MathOp.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // MathOp.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/06/27. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_MathOp_HPP 10 | #define MNN_MathOp_HPP 11 | 12 | namespace MNN { 13 | namespace Express { 14 | //BinaryOPs 15 | MNN_PUBLIC VARP _Add(VARP x, VARP y); 16 | MNN_PUBLIC VARP _Subtract(VARP x, VARP y); 17 | MNN_PUBLIC VARP _Multiply(VARP x, VARP y); 18 | MNN_PUBLIC VARP _Divide(VARP x, VARP y); 19 | MNN_PUBLIC VARP _Pow(VARP x, VARP y); 20 | MNN_PUBLIC VARP _Minimum(VARP x, VARP y); 21 | MNN_PUBLIC VARP _Maximum(VARP x, VARP y); 22 | MNN_PUBLIC VARP _BiasAdd(VARP value, VARP bias); 23 | MNN_PUBLIC VARP _Greater(VARP x, VARP y); 24 | MNN_PUBLIC VARP _GreaterEqual(VARP x, VARP y); 25 | MNN_PUBLIC VARP _Less(VARP x, VARP y); 26 | MNN_PUBLIC VARP _FloorDiv(VARP x, VARP y); 27 | MNN_PUBLIC VARP _SquaredDifference(VARP x, VARP y); 28 | MNN_PUBLIC VARP _Equal(VARP x, VARP y); 29 | MNN_PUBLIC VARP _LessEqual(VARP x, VARP y); 30 | MNN_PUBLIC VARP _FloorMod(VARP x, VARP y); 31 | MNN_PUBLIC VARP _Atan2(VARP x, VARP y); 32 | MNN_PUBLIC VARP _LogicalOr(VARP x, VARP y); 33 | MNN_PUBLIC VARP _NotEqual(VARP x, VARP y); 34 | MNN_PUBLIC VARP _BitwiseAnd(VARP x, VARP y); 35 | MNN_PUBLIC VARP _BitwiseOr(VARP x, VARP y); 36 | MNN_PUBLIC VARP _BitwiseXor(VARP x, VARP y); 37 | 38 | //UnaryOPs 39 | MNN_PUBLIC VARP _Sign(VARP a); 40 | MNN_PUBLIC VARP _Abs(VARP x); 41 | MNN_PUBLIC VARP _Negative(VARP x); 42 | MNN_PUBLIC VARP _Floor(VARP x); 43 | MNN_PUBLIC VARP _Round(VARP x); 44 | MNN_PUBLIC VARP _Ceil(VARP x); 45 | MNN_PUBLIC VARP _Square(VARP x); 46 | MNN_PUBLIC VARP _Sqrt(VARP x); 47 | MNN_PUBLIC VARP _Rsqrt(VARP x); 48 | MNN_PUBLIC VARP _Exp(VARP x); 49 | MNN_PUBLIC VARP _Log(VARP x); 50 | MNN_PUBLIC VARP _Sin(VARP x); 51 | MNN_PUBLIC VARP _Sinh(VARP x); 52 | MNN_PUBLIC VARP _Cos(VARP x); 53 | MNN_PUBLIC VARP _Cosh(VARP x); 54 | MNN_PUBLIC VARP _Tan(VARP x); 55 | MNN_PUBLIC VARP _Asin(VARP x); 56 | MNN_PUBLIC VARP _Asinh(VARP x); 57 | MNN_PUBLIC VARP _Acos(VARP x); 58 | MNN_PUBLIC VARP _Acosh(VARP x); 59 | MNN_PUBLIC VARP _Atan(VARP x); 60 | MNN_PUBLIC VARP _Atanh(VARP x); 61 | MNN_PUBLIC VARP _Reciprocal(VARP x); 62 | MNN_PUBLIC VARP _Log1p(VARP x); 63 | MNN_PUBLIC VARP _Gelu(VARP x); 64 | //Only one but not in UnaryOPs 65 | MNN_PUBLIC VARP _Tanh(VARP x); 66 | MNN_PUBLIC VARP _Sigmoid(VARP x); 67 | MNN_PUBLIC VARP _Erf(VARP x); 68 | MNN_PUBLIC VARP _Erfc(VARP x); 69 | MNN_PUBLIC VARP _Erfinv(VARP x); 70 | MNN_PUBLIC VARP _Expm1(VARP x); 71 | MNN_PUBLIC VARP _Hardswish(VARP x); 72 | 73 | //ReduceOPs 74 | MNN_PUBLIC VARP _ReduceSum(VARP input_variable, INTS axis = {}, bool keepDims = false); 75 | MNN_PUBLIC VARP _ReduceMean(VARP input_variable, INTS axis = {}, bool keepDims = false); 76 | MNN_PUBLIC VARP _ReduceMax(VARP input_variable, INTS axis = {}, bool keepDims = false); 77 | MNN_PUBLIC VARP _ReduceMin(VARP input_variable, INTS axis = {}, bool keepDims = false); 78 | MNN_PUBLIC VARP _ReduceProd(VARP 
input_variable, INTS axis = {}, bool keepDims = false); 79 | MNN_PUBLIC VARP _ReduceAny(VARP input_variable, INTS axis = {}, bool keepDims = false); 80 | MNN_PUBLIC VARP _ReduceAll(VARP input_variable, INTS axis = {}, bool keepDims = false); 81 | 82 | MNN_PUBLIC VARP _ReduceSumMutable(VARP input_variable, VARP axis, bool keepDims = false); 83 | MNN_PUBLIC VARP _ReduceMeanMutable(VARP input_variable, VARP axis, bool keepDims = false); 84 | MNN_PUBLIC VARP _ReduceMaxMutable(VARP input_variable, VARP axis, bool keepDims = false); 85 | MNN_PUBLIC VARP _ReduceMinMutable(VARP input_variable, VARP axis, bool keepDims = false); 86 | MNN_PUBLIC VARP _ReduceProdMutable(VARP input_variable, VARP axis, bool keepDims = false); 87 | MNN_PUBLIC VARP _ReduceAnyMutable(VARP input_variable, VARP axis, bool keepDims = false); 88 | MNN_PUBLIC VARP _ReduceAllMutable(VARP input_variable, VARP axis, bool keepDims = false); 89 | 90 | //EltwiseOPs 91 | MNN_PUBLIC VARP _Prod(VARP a, VARP b, std::vector coeff); 92 | MNN_PUBLIC VARP _Sum(VARP a, VARP b, std::vector coeff); 93 | MNN_PUBLIC VARP _Max(VARP a, VARP b, std::vector coeff); 94 | MNN_PUBLIC VARP _Sub(VARP a, VARP b, std::vector coeff); 95 | MNN_PUBLIC VARP _EltwiseProdInt8(VARP x, VARP y, 96 | std::vector x_weight, std::vector x_bias, std::vector x_scale, std::vector x_tensorScale, 97 | std::vector y_weight, std::vector y_bias, std::vector y_scale, std::vector y_tensorScale, 98 | std::vector output_weight, std::vector output_bias, std::vector output_scale, std::vector output_tensorScale); 99 | MNN_PUBLIC VARP _EltwiseSumInt8(VARP x, VARP y, 100 | std::vector x_weight, std::vector x_bias, std::vector x_scale, std::vector x_tensorScale, 101 | std::vector y_weight, std::vector y_bias, std::vector y_scale, std::vector y_tensorScale, 102 | std::vector output_weight, std::vector output_bias, std::vector output_scale, std::vector output_tensorScale); 103 | MNN_PUBLIC VARP _EltwiseSubInt8(VARP x, VARP y, 104 | std::vector x_weight, std::vector x_bias, std::vector x_scale, std::vector x_tensorScale, 105 | std::vector y_weight, std::vector y_bias, std::vector y_scale, std::vector y_tensorScale, 106 | std::vector output_weight, std::vector output_bias, std::vector output_scale, std::vector output_tensorScale); 107 | MNN_PUBLIC VARP _EltwiseMaxInt8(VARP x, VARP y, 108 | std::vector x_weight, std::vector x_bias, std::vector x_scale, std::vector x_tensorScale, 109 | std::vector y_weight, std::vector y_bias, std::vector y_scale, std::vector y_tensorScale, 110 | std::vector output_weight, std::vector output_bias, std::vector output_scale, std::vector output_tensorScale); 111 | MNN_PUBLIC VARP _Mod(VARP x, VARP y); 112 | 113 | 114 | //OtherOPs 115 | template 116 | VARP _Cast(VARP x) { 117 | return _Cast(x, halide_type_of()); 118 | } 119 | MNN_PUBLIC VARP _Cast(VARP x, halide_type_t dtype); 120 | MNN_PUBLIC VARP _MatMul(VARP a, VARP b, bool tranposeA = false, bool tranposeB = false); 121 | MNN_PUBLIC VARP _Normalize(VARP x, int32_t acrossSpatial, int32_t channelShared, float eps, std::vector scale); 122 | MNN_PUBLIC VARP _ArgMax(VARP input, int axis = 0); 123 | MNN_PUBLIC VARP _ArgMin(VARP input, int axis = 0); 124 | MNN_PUBLIC VARP _BatchMatMul(VARP x, VARP y, bool adj_x = false, bool adj_y = false); 125 | MNN_PUBLIC VARP _UnravelIndex(VARP indices, VARP dims); 126 | MNN_PUBLIC VARP _ScatterNd(VARP indices, VARP updates, VARP shape); 127 | MNN_PUBLIC VARP _ScatterNd(VARP indices, VARP updates, VARP shape, VARP input); 128 | MNN_PUBLIC VARP _ScatterNd(VARP indices, VARP 
updates, VARP shape, int reduction); 129 | MNN_PUBLIC VARP _ScatterNd(VARP indices, VARP updates, VARP shape, VARP input, int reduction); 130 | MNN_PUBLIC VARP _ScatterElements(VARP data, VARP indices, VARP updates, int reduction = -1); 131 | MNN_PUBLIC VARP _ScatterElements(VARP data, VARP indices, VARP updates, VARP axis, int reduction = -1); 132 | MNN_PUBLIC VARP _OneHot(VARP indices, VARP depth, VARP onValue, VARP offValue, int axis = -1); 133 | MNN_PUBLIC VARP _BroadcastTo(VARP a, VARP shape); 134 | MNN_PUBLIC VARP _LinSpace(VARP start, VARP stop, VARP num); 135 | 136 | MNN_PUBLIC VARP _RandomUnifom(VARP shape, halide_type_t dtype, float low = 0.0f, float high = 1.0f, int seed0 = 0, int seed1 = 0); 137 | MNN_PUBLIC VARP _CumSum(VARP x, int axis, bool exclusive = false, bool reverse = false); 138 | MNN_PUBLIC VARP _CumProd(VARP x, int axis); 139 | MNN_PUBLIC VARPS _Svd(VARP x); 140 | MNN_PUBLIC VARP _Histogram(VARP x, int bin, int min, int max, int channel = -1); 141 | }; // namespace Express 142 | }; // namespace MNN 143 | 144 | #endif /* MathOp_HPP */ 145 | -------------------------------------------------------------------------------- /include/MNN/expr/Module.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Module.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/11/25. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_Train_Module_hpp 10 | #define MNN_Train_Module_hpp 11 | 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | 19 | namespace MNN { 20 | namespace Express { 21 | struct SubGraph; 22 | class MNN_PUBLIC Module { 23 | public: 24 | Module() = default; 25 | virtual ~Module() = default; 26 | virtual std::vector onForward(const std::vector& inputs) = 0; 27 | Express::VARP forward(Express::VARP input); 28 | std::vector parameters() const; 29 | bool loadParameters(const std::vector& parameters); 30 | void setIsTraining(const bool isTraining); 31 | bool getIsTraining(); 32 | void clearCache(); 33 | 34 | const std::string& name() const { 35 | return mName; 36 | }; 37 | void setName(std::string name) { 38 | mName = std::move(name); 39 | } 40 | const std::string type() const { 41 | return mType; 42 | } 43 | void setType(std::string type) { 44 | mType = std::move(type); 45 | } 46 | // Return the parameter index 47 | int addParameter(Express::VARP parameter); 48 | 49 | void setParameter(Express::VARP parameter, int index); 50 | static Module* createEmpty(const std::vector& parameters); 51 | 52 | struct BackendInfo { 53 | MNNForwardType type = MNN_FORWARD_CPU; 54 | BackendConfig* config = nullptr; 55 | }; 56 | 57 | struct Config { 58 | // Load module as dynamic, default static 59 | bool dynamic = false; 60 | 61 | // for static mode, if the shape is mutable, set true, otherwise set false to avoid resizeSession freqencily 62 | bool shapeMutable = true; 63 | // Pre-rearrange weights or not. Disabled by default. 64 | // The weights will be rearranged in a general way, so the best implementation 65 | // may not be adopted if `rearrange` is enabled. 
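// [Inserted example - not part of the upstream headers] The classifier-free guidance step
// of a diffusion loop written with the VARP operators / math ops declared above:
// noise = uncond + scale * (cond - uncond). Function and parameter names are illustrative,
// not taken from this project's sources.
static MNN::Express::VARP guidedNoise(MNN::Express::VARP noiseUncond,
                                      MNN::Express::VARP noiseCond,
                                      MNN::Express::VARP guidanceScale) {
    return noiseUncond + guidanceScale * (noiseCond - noiseUncond);
}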
66 | bool rearrange = false; 67 | 68 | BackendInfo* backend = nullptr; 69 | }; 70 | static Module* load(const std::vector& inputs, const std::vector& outputs, const uint8_t* buffer, size_t length, const Config* config = nullptr); 71 | static Module* load(const std::vector& inputs, const std::vector& outputs, const char* fileName, const Config* config = nullptr); 72 | // Shared RuntimeManager 73 | static Module* load(const std::vector& inputs, const std::vector& outputs, const char* fileName, const std::shared_ptr rtMgr, const Config* config = nullptr); 74 | static Module* load(const std::vector& inputs, const std::vector& outputs, const uint8_t* buffer, size_t length, const std::shared_ptr rtMgr, const Config* config = nullptr); 75 | 76 | static Module* extract(std::vector inputs, std::vector outputs, bool fortrain, const std::map& subGraph = {}); 77 | 78 | static Module* clone(const Module* module, const bool shareParams = false); 79 | 80 | struct Info { 81 | // Input info load from model 82 | std::vector inputs; 83 | // The Module's defaultFormat, NCHW or NHWC 84 | Dimensionformat defaultFormat; 85 | // Runtime Info 86 | std::shared_ptr runTimeManager; 87 | // Input Names By Order 88 | std::vector inputNames; 89 | // Output Names By Order 90 | std::vector outputNames; 91 | // The MNNConvert's Version build the module 92 | std::string version; 93 | }; 94 | const Info* getInfo() const; 95 | class CloneContext { 96 | public: 97 | CloneContext() = default; 98 | explicit CloneContext(const bool shareParams) 99 | : mShareParams(shareParams) {} 100 | virtual ~CloneContext() = default; 101 | 102 | const bool shareParams() const { return mShareParams; } 103 | 104 | EXPRP getOrClone(const EXPRP expr); 105 | VARP getOrClone(const VARP var); 106 | 107 | private: 108 | bool mShareParams = false; 109 | std::unordered_map mExprMap; 110 | std::unordered_map mVarMap; 111 | }; 112 | 113 | virtual Module* clone(CloneContext* ctx) const { 114 | return nullptr; 115 | } 116 | void registerModel(const std::vector>& children); 117 | 118 | static void destroy(Module* m); 119 | protected: 120 | virtual void onClearCache() { 121 | } 122 | 123 | Module* cloneBaseTo(CloneContext* ctx, Module* module) const; 124 | 125 | private: 126 | void _collectParameters(std::vector& result) const; 127 | std::vector> mChildren; 128 | std::vector mParameters; 129 | bool mIsTraining = true; 130 | std::string mName; 131 | std::string mType; 132 | }; 133 | 134 | struct SubGraph { 135 | std::vector inputs; 136 | std::vector outputs; 137 | std::shared_ptr m; 138 | }; 139 | 140 | } // namespace Train 141 | } // namespace MNN 142 | 143 | #endif 144 | -------------------------------------------------------------------------------- /include/MNN/expr/NeuralNetWorkOp.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // NeuralNetWorkOp.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/06/27. 
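// [Inserted example - not part of the upstream headers] Loading one of the released .mnn
// models through the Module interface above, sharing a RuntimeManager. The tensor names
// and file path are placeholders for illustration, not this project's actual names;
// ScheduleConfig (used to create the RuntimeManager) lives in MNN/Interpreter.hpp and is
// not reproduced here.
static std::unique_ptr<MNN::Express::Module> loadUNet(
        std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr) {
    MNN::Express::Module::Config config;
    config.shapeMutable = false;                   // fixed input shapes avoid repeated resizing
    MNN::Express::Module* raw = MNN::Express::Module::load(
        {"sample", "timestep", "encoder_hidden_states"},   // placeholder input names
        {"noise_pred"},                                     // placeholder output name
        "resource/unet.mnn", rtMgr, &config);
    return std::unique_ptr<MNN::Express::Module>(raw);
}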
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_NeuralNetWorkOp_HPP 10 | #define MNN_NeuralNetWorkOp_HPP 11 | #include <MNN/expr/Expr.hpp> 12 | 13 | namespace MNN { 14 | namespace Express { 15 | enum PaddingMode {CAFFE, VALID, SAME}; 16 | enum PoolingMode {MAXPOOL, AVEPOOL}; 17 | enum PadValueMode {CONSTANT, REFLECT, SYMMETRIC, EDGE}; 18 | MNN_PUBLIC VARP _Input(INTS shape = {}, Dimensionformat data_format = NC4HW4, halide_type_t dtype = halide_type_of<float>()) ; 19 | MNN_PUBLIC VARP _Clone(VARP source, bool deepCopy = false); 20 | 21 | MNN_PUBLIC VARP _Scalar(const void* ptr, halide_type_t type); 22 | 23 | template <typename T> 24 | VARP _Scalar(T value) { 25 | return _Scalar(&value, halide_type_of<T>()); 26 | } 27 | 28 | 29 | MNN_PUBLIC VARP _Const(float value, INTS shape = {}, Dimensionformat format = NHWC); 30 | MNN_PUBLIC VARP _Const(const void* ptr, INTS shape = {}, Dimensionformat format = NHWC, 31 | halide_type_t type = halide_type_of<float>()); 32 | MNN_PUBLIC VARP _TrainableParam(float value, INTS dims, Dimensionformat format); 33 | MNN_PUBLIC VARP _TrainableParam(const void* ptr, INTS dims, Dimensionformat format, 34 | halide_type_t type = halide_type_of<float>()); 35 | MNN_PUBLIC VARP _InnerProduct(std::vector<float>&& weight, std::vector<float>&& bias, VARP x, INTS outputShape); 36 | MNN_PUBLIC VARP _Conv(VARP weight, VARP bias, VARP x, PaddingMode pad = VALID, INTS stride = {1, 1}, 37 | INTS dilate = {1, 1}, int group = 1, INTS pads = {0, 0}); 38 | 39 | MNN_PUBLIC VARP _Conv(float weight, float bias, VARP x, INTS channel, INTS kernelSize, PaddingMode pad = VALID, 40 | INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1); 41 | MNN_PUBLIC VARP _Conv(std::vector<float>&& weight, std::vector<float>&& bias, VARP x, INTS channel, INTS kernelSize, 42 | PaddingMode pad = VALID, INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1, INTS pads = {0, 0}, bool relu = false, bool relu6 = false, int nbits = 8); 43 | MNN_PUBLIC VARP _Conv(std::vector<int8_t>&& weight, std::vector<float>&& bias, VARP x, INTS channel, INTS kernelSize, 44 | PaddingMode pad = VALID, INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1, INTS pads = {0, 0}, bool relu = false, bool relu6 = false); 45 | MNN_PUBLIC VARP _Deconv(VARP weight, VARP bias, VARP x, PaddingMode pad = VALID, INTS stride = {1, 1}, 46 | INTS dilate = {1, 1}, int group = 1, INTS pads = {0, 0}); 47 | 48 | MNN_PUBLIC VARP _Deconv(std::vector<float>&& weight, std::vector<float>&& bias, VARP x, INTS channel, INTS kernelSize, 49 | PaddingMode pad, INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1, INTS pads = {0, 0}, bool relu = false, bool relu6 = false); 50 | 51 | MNN_PUBLIC VARP _MaxPool(VARP x, INTS kernel, INTS stride = {1, 1}, PaddingMode pad = VALID, INTS pads= {0, 0}); 52 | MNN_PUBLIC VARP _AvePool(VARP x, INTS kernel, INTS stride = {1, 1}, PaddingMode pad = VALID, INTS pads= {0, 0}); 53 | MNN_PUBLIC VARP _Reshape(VARP x, INTS shape, Dimensionformat original_format = NCHW); 54 | MNN_PUBLIC VARP _Reshape(VARP x, VARP shape); 55 | MNN_PUBLIC VARP _Scale(VARP x, int channels, std::vector<float>&& scales, std::vector<float>&& bias); 56 | 57 | MNN_PUBLIC VARP _Relu(VARP x, float slope = 0.0f); 58 | MNN_PUBLIC VARP _Relu6(VARP x, float minValue = 0.0f, float maxValue = 6.0f); 59 | MNN_PUBLIC VARP _PRelu(VARP x, std::vector<float> &&slopes); 60 | MNN_PUBLIC VARP _Softmax(VARP logits, int axis = -1); 61 | MNN_PUBLIC VARP _Softplus(VARP features); 62 | MNN_PUBLIC VARP _Softsign(VARP features); 63 | MNN_PUBLIC std::vector<VARP> _Split(VARP value, INTS size_splits, int axis = 0); 64 | MNN_PUBLIC VARP _Slice(VARP x, VARP starts, VARP sizes); 65 | MNN_PUBLIC VARP _StridedSlice(VARP input, VARP begin, VARP end, VARP strided, 66 | int32_t beginMask, int32_t endMask, int32_t ellipsisMask, 67 | int32_t newAxisMask, int32_t shrinkAxisMask); 68 | MNN_PUBLIC VARP _StridedSliceWrite(VARP input, VARP begin, VARP end, VARP strided, VARP write, 69 | int32_t beginMask, int32_t endMask, int32_t ellipsisMask, 70 | int32_t newAxisMask, int32_t shrinkAxisMask); 71 | MNN_PUBLIC VARP _Concat(VARPS values, int axis); 72 | MNN_PUBLIC VARP _Convert(VARP input, Dimensionformat format); 73 | MNN_PUBLIC VARP _Transpose(VARP x, INTS perm); 74 | MNN_PUBLIC VARP _Transpose(VARP x, VARP perm); 75 | MNN_PUBLIC VARP _ChannelShuffle(VARP x, int group); 76 | MNN_PUBLIC VARP _ChangeInputFormat(VARP input, Dimensionformat format); 77 | MNN_PUBLIC VARP _Conv2DBackPropFilter(VARP input, VARP inputGrad, INTS kernelSize, PaddingMode pad = VALID, INTS stride = {1, 1}, INTS dilate = {1, 1}, int group = 1, INTS pads = {0, 0}); 78 | MNN_PUBLIC VARP _PoolGrad(VARP originInput, VARP originOutput, VARP inputGrad, INTS kernel, INTS stride, PoolingMode type, PaddingMode pad = VALID, INTS pads= {0, 0}); 79 | // FIXME: move the api to Array Ops 80 | MNN_PUBLIC VARP _ReverseSequence(VARP x, VARP y, int batchDim, int seqDim); 81 | // FIXME: move the api to Image Ops 82 | MNN_PUBLIC VARP _Crop(VARP images, VARP size, int axis, INTS offset); 83 | MNN_PUBLIC VARP _Resize(VARP images, float xScale, float yScale); 84 | MNN_PUBLIC VARP _Pad(VARP x, VARP paddings, PadValueMode mode = CONSTANT); 85 | MNN_PUBLIC VARP _ExpandDims(VARP input, int axis); 86 | MNN_PUBLIC VARP _ExpandDims(VARP input, VARP axis); 87 | 88 | MNN_PUBLIC VARP _Shape(VARP input, bool nchw = false); 89 | MNN_PUBLIC VARP _Stack(VARPS values, int axis=0); 90 | enum InterpolationMethod {BILINEAR, NEAREST}; 91 | MNN_PUBLIC VARP _CropAndResize(VARP image, VARP boxes, VARP box_ind, VARP crop_size, 92 | InterpolationMethod method, float extrapolation_value = 0.0); 93 | MNN_PUBLIC VARP _Fill(VARP dims, VARP value); 94 | MNN_PUBLIC VARP _Tile(VARP input, VARP multiples); 95 | MNN_PUBLIC VARP _Gather(VARP params, VARP indices); 96 | MNN_PUBLIC VARP _GatherV2(VARP params, VARP indices, VARP axis = nullptr); 97 | MNN_PUBLIC VARP _Squeeze(VARP input, INTS axis = {}); 98 | MNN_PUBLIC VARP _Unsqueeze(VARP input, INTS axis = {}); 99 | MNN_PUBLIC VARP _BatchToSpaceND(VARP input, VARP block_shape, VARP crops); 100 | MNN_PUBLIC VARP _GatherND(VARP params, VARP indices); 101 | MNN_PUBLIC VARP _GatherElements(VARP params, VARP indices); 102 | MNN_PUBLIC VARP _GatherElements(VARP params, VARP indices, VARP axis); 103 | MNN_PUBLIC VARP _Selu(VARP features, float scale, float alpha); 104 | MNN_PUBLIC VARP _Size(VARP input); 105 | MNN_PUBLIC VARP _Elu(VARP features, float alpha=1.0); 106 | MNN_PUBLIC VARP _Threshold(VARP features, float alpha=1.0); 107 | MNN_PUBLIC VARP _MatrixBandPart(VARP input, VARP num_lower, VARP num_upper); 108 | MNN_PUBLIC std::vector<VARP> _Moments(VARP x, INTS axis, VARP shift, bool keepDims); 109 | MNN_PUBLIC VARP _SetDiff1D(VARP x, VARP y); 110 | MNN_PUBLIC VARP _SpaceToDepth(VARP input, int block_size); 111 | MNN_PUBLIC VARP _SpaceToBatchND(VARP input, VARP block_shape, VARP paddings); 112 | MNN_PUBLIC VARP _ZerosLike(VARP input); 113 | MNN_PUBLIC std::vector<VARP> _Unstack(VARP value, int axis=0); 114 | MNN_PUBLIC VARP _Rank(VARP input); 115 | MNN_PUBLIC VARP _Range(VARP start, VARP limit, VARP delta); 116 | MNN_PUBLIC VARP _DepthToSpace(VARP input, int block_size); 117 | MNN_PUBLIC VARP _PriorBox(VARP feature, VARP image, 118 | std::vector<float> min_size, std::vector<float> max_size, std::vector<float> aspect_ratio, 119 | bool flip, bool clip, std::vector<float> variance, 120 | unsigned int img_h, unsigned int img_w, float step_h, float step_w, float offset = 0.5); 121 | MNN_PUBLIC VARP _Permute(VARP input, INTS dims); 122 | MNN_PUBLIC VARP _DetectionOutput(VARP location, VARP confidence, VARP priorbox, 123 | unsigned int num_classes, bool share_location, int background_label_id, 124 | float nms_threshhold, int nms_topk, int code_type, 125 | bool variance_encoded_in_target, 126 | int keep_top_k, float confidence_threshold, float visualize_threshold); 127 | MNN_PUBLIC std::vector<VARP> _DetectionPostProcess(VARP encode_boxes, VARP class_predictions, VARP anchors, 128 | int num_classes, int max_detections, 129 | int max_class_per_detection, int detections_per_class, 130 | float nms_threshold, float iou_threshold, 131 | bool use_regular_nms, std::vector<float> centersize_encoding); 132 | MNN_PUBLIC VARP _Interp(VARPS xs, float widthScale, float heightScale, int outputWidth, int outputHeight, int resizeType, bool alignCorners); 133 | 134 | MNN_PUBLIC VARP _ZeroGrad(VARP x); 135 | 136 | // Int8 Inference 137 | MNN_PUBLIC VARP _Conv(std::vector<int8_t>&& weight, std::vector<int>&& bias, std::vector<float>&& scale, VARP x, INTS channel, INTS kernelSize, 138 | PaddingMode pad, INTS stride, INTS dilate, int group, INTS pads, bool relu, int nbits = 8); 139 | MNN_PUBLIC VARP _Conv(std::vector<int8_t>&& weight, std::vector<int>&& bias, std::vector<float>&& scale, 140 | VARP x, INTS channel, INTS kernelSize, 141 | PaddingMode pad, INTS stride, INTS dilate, int group, INTS pads, bool relu, 142 | int8_t inputZeroPoint, int8_t outputZeroPoint, 143 | int8_t minValue, int8_t maxValue, bool accumulateToInt16); 144 | MNN_PUBLIC VARP _Conv(std::vector<int8_t>&& weight, std::vector<float>&& bias, std::vector<float>&& weightScale, 145 | VARP x, INTS channel, INTS kernelSize, 146 | PaddingMode pad, INTS stride, INTS dilate, int group, INTS pads, bool relu, 147 | float scaleIn, float scaleOut, 148 | int8_t inputZeroPoint, int8_t outputZeroPoint, 149 | int8_t minValue, int8_t maxValue, float weightClampValue, bool accumulateToInt16); 150 | MNN_PUBLIC VARP _CosineSimilarity(VARP input0, VARP input1, VARP inputDim); 151 | 152 | enum GridSamplePaddingMode {GRID_SAMPLE_PADDING_ZEROS, GRID_SAMPLE_PADDING_BORDER, GRID_SAMPLE_PADDING_REFLECTION}; 153 | MNN_PUBLIC VARP _GridSample(VARP input, VARP grid, InterpolationMethod mode=BILINEAR, GridSamplePaddingMode paddingMode=GRID_SAMPLE_PADDING_ZEROS, bool alignCorners=false); 154 | MNN_PUBLIC VARP _FloatToInt8(VARP x, VARP scale, char minValue, char maxValue); 155 | MNN_PUBLIC VARP _FloatToInt8(VARP x, VARP scale, int8_t minValue, int8_t maxValue, int8_t zeroPoint); 156 | MNN_PUBLIC VARP _Int8ToFloat(VARP x, VARP scale); 157 | MNN_PUBLIC VARP _Int8ToFloat(VARP x, VARP scale, int8_t zeroPoint); 158 | 159 | MNN_PUBLIC VARP _Select(VARP select, VARP input0, VARP input1); 160 | MNN_PUBLIC std::vector<VARP> _TopKV2(VARP input0, VARP input1); 161 | MNN_PUBLIC VARP _ImageProcess(VARP input, CV::ImageProcess::Config config, CV::Matrix matrix, int oh, int ow, int oc, int dtype, uint8_t padVal = 0); 162 | MNN_PUBLIC VARP _Where(VARP x); 163 | MNN_PUBLIC VARP _Sort(VARP x, int axis = -1, bool arg = false, bool descend = false); 164 | MNN_PUBLIC VARP _Raster(const std::vector<VARP>& vars, const std::vector<int>& regions, const std::vector<int>& shape); 165 | MNN_PUBLIC VARP _Nms(VARP boxes, VARP scores, int maxDetections, float iouThreshold = -1, float scoreThreshold = -1); 166 | MNN_PUBLIC VARP _Im2Col(VARP x, INTS
kernelSize, INTS dilate, INTS pads, INTS stride); 167 | MNN_PUBLIC VARP _Col2Im(VARP x, VARP outputShape, INTS kernelSize, INTS dilate, INTS pads, INTS stride); 168 | 169 | } // namespace Express 170 | } // namespace MNN 171 | 172 | #endif /* NeuralNetWorkOp_HPP */ 173 | -------------------------------------------------------------------------------- /include/MNN/expr/Optimizer.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Optimizer.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2019/08/20. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | #ifndef Optimizer_hpp 9 | #define Optimizer_hpp 10 | #include 11 | #include 12 | 13 | namespace MNN { 14 | namespace Express { 15 | class MNN_PUBLIC Optimizer { 16 | public: 17 | enum Device { 18 | CPU = 0, 19 | GPU = 1, 20 | OTHER = 2, 21 | AUTO = 3 22 | }; 23 | struct Config { 24 | Device device = CPU; 25 | MNNForwardType forwardType = MNN_FORWARD_ALL; 26 | int numThread = 4; 27 | }; 28 | static std::shared_ptr create(Config config); 29 | struct Cost { 30 | float compute; // MFlops 31 | float memory; // MB 32 | }; 33 | class Parameters { 34 | public: 35 | Parameters(int n); 36 | virtual ~Parameters(); 37 | 38 | float* get() const { 39 | return mValue; 40 | } 41 | int size() const { 42 | return mSize; 43 | } 44 | 45 | private: 46 | float* mValue; 47 | int mSize; 48 | }; 49 | virtual std::shared_ptr onGetParameters(const std::vector& outputs) { 50 | return nullptr; 51 | } 52 | 53 | //Given paramters and measure cost, the parameters must be the same as onGetParameters 54 | virtual Cost onMeasure(const std::vector& outputs, std::shared_ptr parameters = nullptr) = 0; 55 | 56 | //Modify the output directly, the parameters must be the same as onGetParameters 57 | virtual bool onExecute(const std::vector& outputs, std::shared_ptr parameters = nullptr) = 0; 58 | 59 | Optimizer() = default; 60 | virtual ~Optimizer() = default; 61 | }; 62 | } // namespace Express 63 | } // namespace MNN 64 | #endif 65 | -------------------------------------------------------------------------------- /include/MNN/expr/Scope.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // RuntimeScope.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2020/10/26. 
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_EXPR_SCOPE_HPP_ 10 | #define MNN_EXPR_SCOPE_HPP_ 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include 18 | 19 | namespace MNN { 20 | namespace Express { 21 | 22 | template 23 | class Scope { 24 | public: 25 | Scope(); 26 | virtual ~Scope() = default; 27 | 28 | struct ScopedContent { 29 | std::string scope_name; 30 | T content; 31 | }; 32 | void EnterScope(const ScopedContent& current); 33 | void EnterScope(const T& current); 34 | void EnterScope(const std::string& scope_name, const T& current); 35 | 36 | void ExitScope(); 37 | 38 | const ScopedContent& Current() const; 39 | const T Content() const; 40 | 41 | int ScopedLevel() const { return scoped_level_; } 42 | 43 | private: 44 | std::string MakeScopeName(const std::string& prefix, int level) const; 45 | 46 | mutable std::mutex mutex_; 47 | int scoped_level_ = 0; 48 | std::vector scoped_contents_; 49 | }; 50 | 51 | template 52 | Scope::Scope() : scoped_level_(0) { 53 | } 54 | 55 | template 56 | void Scope::EnterScope(const ScopedContent& current) { 57 | std::lock_guard lock(mutex_); 58 | ++scoped_level_; 59 | scoped_contents_.push_back(current); 60 | } 61 | 62 | template 63 | void Scope::EnterScope(const T& current) { 64 | EnterScope("scope", current); 65 | } 66 | 67 | template 68 | void Scope::EnterScope(const std::string& scope_name, 69 | const T& current) { 70 | std::lock_guard lock(mutex_); 71 | int scoped_level = ScopedLevel(); 72 | std::string name = MakeScopeName(scope_name, scoped_level++); 73 | ScopedContent content{name, current}; 74 | ++scoped_level_; 75 | scoped_contents_.push_back(content); 76 | } 77 | 78 | template 79 | void Scope::ExitScope() { 80 | std::lock_guard lock(mutex_); 81 | --scoped_level_; 82 | scoped_contents_.resize(scoped_level_); 83 | } 84 | 85 | template 86 | const typename Scope::ScopedContent& Scope::Current() const { 87 | std::lock_guard lock(mutex_); 88 | MNN_CHECK(scoped_contents_.size() > 0, "Scope level should not be 0."); 89 | return scoped_contents_.back(); 90 | } 91 | 92 | template 93 | const T Scope::Content() const { 94 | std::lock_guard lock(mutex_); 95 | if (scoped_contents_.empty()) { 96 | return nullptr; 97 | } 98 | return scoped_contents_.back().content; 99 | } 100 | 101 | template 102 | std::string Scope::MakeScopeName(const std::string& prefix, 103 | int level) const { 104 | char s[16]; 105 | snprintf(s, 16, "%d", level); 106 | return prefix + "/" + std::string(s); 107 | } 108 | 109 | } // namespace Express 110 | } // namespace MNN 111 | 112 | #endif // MNN_EXPR_SCOPE_HPP_ 113 | -------------------------------------------------------------------------------- /include/MNN/plugin/PluginContext.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // ShapeInference.h 3 | // MNN 4 | // 5 | // Created by MNN on 2020/04/05. 
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_PLUGIN_PLUGIN_CONTEXT_HPP_ 10 | #define MNN_PLUGIN_PLUGIN_CONTEXT_HPP_ 11 | 12 | #include 13 | #include 14 | 15 | #include // Backend 16 | #include 17 | #include "Tensor_generated.h" 18 | 19 | namespace MNN { 20 | namespace plugin { 21 | 22 | class MNN_PUBLIC PluginContext { 23 | public: 24 | PluginContext() = delete; 25 | PluginContext(const std::vector& inputs, // NOLINT 26 | const std::vector& outputs); 27 | 28 | virtual ~PluginContext() = default; 29 | 30 | const std::vector& inputs() const { 31 | return inputs_; 32 | } 33 | const std::vector& outputs() const { 34 | return outputs_; 35 | } 36 | 37 | const Tensor* input(const int index) const; 38 | const Tensor* output(const int index) const; 39 | 40 | Tensor* output(const int index); 41 | 42 | bool hasAttr(const std::string& name) const; 43 | 44 | bool setAttr(const std::string& name, const Attribute* attr); 45 | 46 | void setAttrs(const std::unordered_map& attrs); 48 | 49 | const Attribute* getAttr(const std::string& name) const; 50 | 51 | const std::unordered_map& getAttrs() const; 52 | 53 | protected: 54 | const std::vector& inputs_; 55 | const std::vector& outputs_; 56 | std::unordered_map attrs_; 57 | }; 58 | 59 | class MNN_PUBLIC InferShapeContext : public PluginContext { 60 | public: 61 | InferShapeContext() = delete; 62 | InferShapeContext(const std::vector& inputs, // NOLINT 63 | const std::vector& outputs); 64 | 65 | virtual ~InferShapeContext() = default; 66 | }; 67 | 68 | class MNN_PUBLIC CPUKernelContext : public PluginContext { 69 | public: 70 | CPUKernelContext() = delete; 71 | CPUKernelContext(const std::string& op_type, // NOLINT 72 | Backend* backend, // NOLINT 73 | const std::vector& inputs, // NOLINT 74 | const std::vector& outputs); 75 | 76 | virtual ~CPUKernelContext() = default; 77 | 78 | Backend* backend() const { 79 | return backend_; 80 | } 81 | 82 | const std::string& op_type() const { 83 | return op_type_; 84 | } 85 | 86 | private: 87 | const std::string op_type_ = ""; 88 | Backend* backend_ = nullptr; 89 | }; 90 | 91 | inline PluginContext::PluginContext(const std::vector& inputs, // NOLINT 92 | const std::vector& outputs) // NOLINT 93 | : inputs_(inputs), outputs_(outputs) { 94 | } 95 | 96 | inline const Tensor* PluginContext::input(const int index) const { 97 | MNN_ASSERT(index < inputs_.size()); 98 | return inputs_.at(index); 99 | } 100 | 101 | inline const Tensor* PluginContext::output(const int index) const { 102 | MNN_ASSERT(index < outputs_.size()); 103 | return outputs_.at(index); 104 | } 105 | 106 | inline Tensor* PluginContext::output(const int index) { 107 | MNN_ASSERT(index < outputs_.size()); 108 | return outputs_.at(index); 109 | } 110 | 111 | inline bool PluginContext::hasAttr(const std::string& name) const { 112 | return attrs_.count(name) > 0; 113 | } 114 | 115 | inline bool PluginContext::setAttr(const std::string& name, // NOLINT 116 | const Attribute* attr) { 117 | return attrs_.emplace(name, attr).second; 118 | } 119 | 120 | inline void PluginContext::setAttrs( // NOLINT 121 | const std::unordered_map& attrs) { 122 | attrs_ = attrs; 123 | } 124 | 125 | inline const Attribute* PluginContext::getAttr(const std::string& name) const { 126 | const auto& it = attrs_.find(name); 127 | MNN_ASSERT(it != attrs_.end()); 128 | return it->second; 129 | } 130 | 131 | inline const std::unordered_map& // NOLINT 132 | PluginContext::getAttrs() const { 133 | return attrs_; 134 | } 135 | 136 | } // namespace plugin 137 | } 
// namespace MNN 138 | 139 | #endif // MNN_PLUGIN_PLUGIN_CONTEXT_HPP_ 140 | -------------------------------------------------------------------------------- /include/MNN/plugin/PluginKernel.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // ShapeInference.h 3 | // MNN 4 | // 5 | // Created by MNN on 2020/04/05. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_PLUGIN_PLUGIN_KERNEL_HPP_ 10 | #define MNN_PLUGIN_PLUGIN_KERNEL_HPP_ 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | namespace MNN { 19 | namespace plugin { 20 | 21 | template 22 | class MNN_PUBLIC ComputeKernel { 23 | public: 24 | ComputeKernel() = default; 25 | virtual ~ComputeKernel() = default; 26 | virtual bool compute(KernelContextT* ctx) = 0; 27 | }; 28 | 29 | class MNN_PUBLIC CPUComputeKernel : public ComputeKernel { 30 | public: 31 | using ContextT = CPUKernelContext; 32 | using KernelT = CPUComputeKernel; 33 | 34 | CPUComputeKernel() = default; 35 | virtual ~CPUComputeKernel() = default; 36 | virtual bool init(CPUKernelContext* ctx) = 0; 37 | virtual bool compute(CPUKernelContext* ctx) = 0; 38 | }; 39 | 40 | template 41 | class MNN_PUBLIC ComputeKernelRegistry { 42 | public: 43 | typedef std::function Factory; 44 | static std::unordered_map* getFactoryMap(); 45 | 46 | static bool add(const std::string& name, Factory factory); 47 | 48 | static PluginKernelT* get(const std::string& name); 49 | }; 50 | 51 | template 52 | struct ComputeKernelRegistrar { 53 | ComputeKernelRegistrar(const std::string& name) { 54 | ComputeKernelRegistry::add(name, []() { // NOLINT 55 | return new PluginKernelT; // NOLINT 56 | }); 57 | } 58 | }; 59 | 60 | #define REGISTER_PLUGIN_COMPUTE_KERNEL(name, computeKernel) \ 61 | namespace { \ 62 | static auto _plugin_compute_kernel_##name##_ __attribute__((unused)) = \ 63 | ComputeKernelRegistrar(#name); \ 64 | } // namespace 65 | 66 | } // namespace plugin 67 | } // namespace MNN 68 | 69 | #endif // MNN_PLUGIN_PLUGIN_KERNEL_HPP_ 70 | -------------------------------------------------------------------------------- /include/MNN/plugin/PluginShapeInference.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // ShapeInference.h 3 | // MNN 4 | // 5 | // Created by MNN on 2020/04/05. 
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MNN_PLUGIN_PLUGIN_SHAPE_INFERENCE_HPP_ 10 | #define MNN_PLUGIN_PLUGIN_SHAPE_INFERENCE_HPP_ 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | namespace MNN { 19 | namespace plugin { 20 | 21 | class MNN_PUBLIC InferShapeKernel { 22 | public: 23 | virtual ~InferShapeKernel() = default; 24 | virtual bool compute(InferShapeContext* ctx) = 0; 25 | }; 26 | 27 | class MNN_PUBLIC InferShapeKernelRegister { 28 | public: 29 | // typedef InferShapeKernel* (*Factory)(); 30 | typedef std::function Factory; 31 | static std::unordered_map* getFactoryMap(); 32 | 33 | static bool add(const std::string& name, Factory factory); 34 | 35 | static InferShapeKernel* get(const std::string& name); 36 | }; 37 | 38 | template 39 | struct InferShapeKernelRegistrar { 40 | InferShapeKernelRegistrar(const std::string& name) { 41 | InferShapeKernelRegister::add(name, []() { // NOLINT 42 | return new PluginKernel; // NOLINT 43 | }); 44 | } 45 | }; 46 | 47 | #define REGISTER_PLUGIN_OP(name, inferShapeKernel) \ 48 | namespace { \ 49 | static auto _plugin_infer_shape_##name##_ __attribute__((unused)) = \ 50 | InferShapeKernelRegistrar(#name); \ 51 | } // namespace 52 | 53 | } // namespace plugin 54 | } // namespace MNN 55 | 56 | #endif // MNN_PLUGIN_PLUGIN_SHAPE_INFERENCE_HPP_ 57 | -------------------------------------------------------------------------------- /include/cv/calib3d.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // calib3d.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2022/07/14. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef CALIB3D_HPP 10 | #define CALIB3D_HPP 11 | 12 | #include 13 | #include 14 | #include 15 | #include "types.hpp" 16 | 17 | namespace MNN { 18 | namespace CV { 19 | 20 | MNN_PUBLIC VARP Rodrigues(VARP src); 21 | 22 | MNN_PUBLIC std::pair solvePnP(VARP objectPoints, VARP imagePoints, VARP cameraMatrix, VARP distCoeffs, 23 | bool useExtrinsicGuess = false); 24 | 25 | } // CV 26 | } // MNN 27 | #endif // CALIB3D_HPP 28 | -------------------------------------------------------------------------------- /include/cv/cv.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // cv.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2021/09/02. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef CV_HPP 10 | #define CV_HPP 11 | 12 | #include "types.hpp" 13 | #include "calib3d.hpp" 14 | #include "imgcodecs.hpp" 15 | #include "imgproc/imgproc.hpp" 16 | 17 | #endif // CV_HPP 18 | -------------------------------------------------------------------------------- /include/cv/imgcodecs.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // imgcodecs.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2021/08/26. 
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef IMGCODECS_HPP 10 | #define IMGCODECS_HPP 11 | 12 | #include 13 | #include 14 | 15 | namespace MNN { 16 | namespace CV { 17 | using namespace Express; 18 | 19 | enum ImreadModes { 20 | IMREAD_GRAYSCALE = 0, // uint8_t gray 21 | IMREAD_COLOR = 1, // uint8_t bgr 22 | IMREAD_ANYDEPTH = 4, // float bgr 23 | }; 24 | 25 | enum ImwriteFlags { 26 | IMWRITE_JPEG_QUALITY = 1, // jpg, default is 95 27 | }; 28 | 29 | MNN_PUBLIC bool haveImageReader(const std::string& filename); 30 | 31 | MNN_PUBLIC bool haveImageWriter(const std::string& filename); 32 | 33 | MNN_PUBLIC VARP imdecode(const std::vector& buf, int flags); 34 | 35 | MNN_PUBLIC std::pair> imencode(std::string ext, VARP img, 36 | const std::vector& params = std::vector()); 37 | 38 | MNN_PUBLIC VARP imread(const std::string& filename, int flags = IMREAD_COLOR); 39 | 40 | MNN_PUBLIC bool imwrite(const std::string& filename, VARP img, 41 | const std::vector& params = std::vector()); 42 | 43 | } // CV 44 | } // MNN 45 | #endif // IMGCODECS_HPP 46 | -------------------------------------------------------------------------------- /include/cv/imgproc/color.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // color.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2021/08/18. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef COLOR_HPP 10 | #define COLOR_HPP 11 | 12 | #include 13 | #include 14 | #include "../types.hpp" 15 | 16 | namespace MNN { 17 | namespace CV { 18 | using namespace Express; 19 | 20 | enum ColorConversionCodes { 21 | COLOR_BGR2BGRA = 0, 22 | COLOR_RGB2RGBA = COLOR_BGR2BGRA, 23 | COLOR_BGRA2BGR = 1, 24 | COLOR_RGBA2RGB = COLOR_BGRA2BGR, 25 | COLOR_BGR2RGBA = 2, 26 | COLOR_RGB2BGRA = COLOR_BGR2RGBA, 27 | COLOR_RGBA2BGR = 3, 28 | COLOR_BGRA2RGB = COLOR_RGBA2BGR, 29 | COLOR_BGR2RGB = 4, 30 | COLOR_RGB2BGR = COLOR_BGR2RGB, 31 | COLOR_BGRA2RGBA = 5, 32 | COLOR_RGBA2BGRA = COLOR_BGRA2RGBA, 33 | COLOR_BGR2GRAY = 6, 34 | COLOR_RGB2GRAY = 7, 35 | COLOR_GRAY2BGR = 8, 36 | COLOR_GRAY2RGB = COLOR_GRAY2BGR, 37 | COLOR_GRAY2BGRA = 9, 38 | COLOR_GRAY2RGBA = COLOR_GRAY2BGRA, 39 | COLOR_BGRA2GRAY = 10, 40 | COLOR_RGBA2GRAY = 11, 41 | COLOR_BGR2BGR565 = 12, 42 | COLOR_RGB2BGR565 = 13, 43 | COLOR_BGR5652BGR = 14, 44 | COLOR_BGR5652RGB = 15, 45 | COLOR_BGRA2BGR565 = 16, 46 | COLOR_RGBA2BGR565 = 17, 47 | COLOR_BGR5652BGRA = 18, 48 | COLOR_BGR5652RGBA = 19, 49 | COLOR_GRAY2BGR565 = 20, 50 | COLOR_BGR5652GRAY = 21, 51 | COLOR_BGR2BGR555 = 22, 52 | COLOR_RGB2BGR555 = 23, 53 | COLOR_BGR5552BGR = 24, 54 | COLOR_BGR5552RGB = 25, 55 | COLOR_BGRA2BGR555 = 26, 56 | COLOR_RGBA2BGR555 = 27, 57 | COLOR_BGR5552BGRA = 28, 58 | COLOR_BGR5552RGBA = 29, 59 | COLOR_GRAY2BGR555 = 30, 60 | COLOR_BGR5552GRAY = 31, 61 | COLOR_BGR2XYZ = 32, 62 | COLOR_RGB2XYZ = 33, 63 | COLOR_XYZ2BGR = 34, 64 | COLOR_XYZ2RGB = 35, 65 | COLOR_BGR2YCrCb = 36, 66 | COLOR_RGB2YCrCb = 37, 67 | COLOR_YCrCb2BGR = 38, 68 | COLOR_YCrCb2RGB = 39, 69 | COLOR_BGR2HSV = 40, 70 | COLOR_RGB2HSV = 41, 71 | COLOR_BGR2Lab = 44, 72 | COLOR_RGB2Lab = 45, 73 | COLOR_BGR2Luv = 50, 74 | COLOR_RGB2Luv = 51, 75 | COLOR_BGR2HLS = 52, 76 | COLOR_RGB2HLS = 53, 77 | COLOR_HSV2BGR = 54, 78 | COLOR_HSV2RGB = 55, 79 | COLOR_Lab2BGR = 56, 80 | COLOR_Lab2RGB = 57, 81 | COLOR_Luv2BGR = 58, 82 | COLOR_Luv2RGB = 59, 83 | COLOR_HLS2BGR = 60, 84 | COLOR_HLS2RGB = 61, 85 | COLOR_BGR2HSV_FULL = 66, 86 | COLOR_RGB2HSV_FULL = 67, 87 | COLOR_BGR2HLS_FULL = 68, 88 | COLOR_RGB2HLS_FULL = 69, 
89 | COLOR_HSV2BGR_FULL = 70, 90 | COLOR_HSV2RGB_FULL = 71, 91 | COLOR_HLS2BGR_FULL = 72, 92 | COLOR_HLS2RGB_FULL = 73, 93 | COLOR_LBGR2Lab = 74, 94 | COLOR_LRGB2Lab = 75, 95 | COLOR_LBGR2Luv = 76, 96 | COLOR_LRGB2Luv = 77, 97 | COLOR_Lab2LBGR = 78, 98 | COLOR_Lab2LRGB = 79, 99 | COLOR_Luv2LBGR = 80, 100 | COLOR_Luv2LRGB = 81, 101 | COLOR_BGR2YUV = 82, 102 | COLOR_RGB2YUV = 83, 103 | COLOR_YUV2BGR = 84, 104 | COLOR_YUV2RGB = 85, 105 | COLOR_YUV2RGB_NV12 = 90, 106 | COLOR_YUV2BGR_NV12 = 91, 107 | COLOR_YUV2RGB_NV21 = 92, 108 | COLOR_YUV2BGR_NV21 = 93, 109 | COLOR_YUV420sp2RGB = COLOR_YUV2RGB_NV21, 110 | COLOR_YUV420sp2BGR = COLOR_YUV2BGR_NV21, 111 | COLOR_YUV2RGBA_NV12 = 94, 112 | COLOR_YUV2BGRA_NV12 = 95, 113 | COLOR_YUV2RGBA_NV21 = 96, 114 | COLOR_YUV2BGRA_NV21 = 97, 115 | COLOR_YUV420sp2RGBA = COLOR_YUV2RGBA_NV21, 116 | COLOR_YUV420sp2BGRA = COLOR_YUV2BGRA_NV21, 117 | COLOR_YUV2RGB_YV12 = 98, 118 | COLOR_YUV2BGR_YV12 = 99, 119 | COLOR_YUV2RGB_IYUV = 100, 120 | COLOR_YUV2BGR_IYUV = 101, 121 | COLOR_YUV2RGB_I420 = COLOR_YUV2RGB_IYUV, 122 | COLOR_YUV2BGR_I420 = COLOR_YUV2BGR_IYUV, 123 | COLOR_YUV420p2RGB = COLOR_YUV2RGB_YV12, 124 | COLOR_YUV420p2BGR = COLOR_YUV2BGR_YV12, 125 | COLOR_YUV2RGBA_YV12 = 102, 126 | COLOR_YUV2BGRA_YV12 = 103, 127 | COLOR_YUV2RGBA_IYUV = 104, 128 | COLOR_YUV2BGRA_IYUV = 105, 129 | COLOR_YUV2RGBA_I420 = COLOR_YUV2RGBA_IYUV, 130 | COLOR_YUV2BGRA_I420 = COLOR_YUV2BGRA_IYUV, 131 | COLOR_YUV420p2RGBA = COLOR_YUV2RGBA_YV12, 132 | COLOR_YUV420p2BGRA = COLOR_YUV2BGRA_YV12, 133 | COLOR_YUV2GRAY_420 = 106, 134 | COLOR_YUV2GRAY_NV21 = COLOR_YUV2GRAY_420, 135 | COLOR_YUV2GRAY_NV12 = COLOR_YUV2GRAY_420, 136 | COLOR_YUV2GRAY_YV12 = COLOR_YUV2GRAY_420, 137 | COLOR_YUV2GRAY_IYUV = COLOR_YUV2GRAY_420, 138 | COLOR_YUV2GRAY_I420 = COLOR_YUV2GRAY_420, 139 | COLOR_YUV420sp2GRAY = COLOR_YUV2GRAY_420, 140 | COLOR_YUV420p2GRAY = COLOR_YUV2GRAY_420, 141 | COLOR_YUV2RGB_UYVY = 107, 142 | COLOR_YUV2BGR_UYVY = 108, 143 | COLOR_YUV2RGB_Y422 = COLOR_YUV2RGB_UYVY, 144 | COLOR_YUV2BGR_Y422 = COLOR_YUV2BGR_UYVY, 145 | COLOR_YUV2RGB_UYNV = COLOR_YUV2RGB_UYVY, 146 | COLOR_YUV2BGR_UYNV = COLOR_YUV2BGR_UYVY, 147 | COLOR_YUV2RGBA_UYVY = 111, 148 | COLOR_YUV2BGRA_UYVY = 112, 149 | COLOR_YUV2RGBA_Y422 = COLOR_YUV2RGBA_UYVY, 150 | COLOR_YUV2BGRA_Y422 = COLOR_YUV2BGRA_UYVY, 151 | COLOR_YUV2RGBA_UYNV = COLOR_YUV2RGBA_UYVY, 152 | COLOR_YUV2BGRA_UYNV = COLOR_YUV2BGRA_UYVY, 153 | COLOR_YUV2RGB_YUY2 = 115, 154 | COLOR_YUV2BGR_YUY2 = 116, 155 | COLOR_YUV2RGB_YVYU = 117, 156 | COLOR_YUV2BGR_YVYU = 118, 157 | COLOR_YUV2RGB_YUYV = COLOR_YUV2RGB_YUY2, 158 | COLOR_YUV2BGR_YUYV = COLOR_YUV2BGR_YUY2, 159 | COLOR_YUV2RGB_YUNV = COLOR_YUV2RGB_YUY2, 160 | COLOR_YUV2BGR_YUNV = COLOR_YUV2BGR_YUY2, 161 | COLOR_YUV2RGBA_YUY2 = 119, 162 | COLOR_YUV2BGRA_YUY2 = 120, 163 | COLOR_YUV2RGBA_YVYU = 121, 164 | COLOR_YUV2BGRA_YVYU = 122, 165 | COLOR_YUV2RGBA_YUYV = COLOR_YUV2RGBA_YUY2, 166 | COLOR_YUV2BGRA_YUYV = COLOR_YUV2BGRA_YUY2, 167 | COLOR_YUV2RGBA_YUNV = COLOR_YUV2RGBA_YUY2, 168 | COLOR_YUV2BGRA_YUNV = COLOR_YUV2BGRA_YUY2, 169 | COLOR_YUV2GRAY_UYVY = 123, 170 | COLOR_YUV2GRAY_YUY2 = 124, 171 | COLOR_YUV2GRAY_Y422 = COLOR_YUV2GRAY_UYVY, 172 | COLOR_YUV2GRAY_UYNV = COLOR_YUV2GRAY_UYVY, 173 | COLOR_YUV2GRAY_YVYU = COLOR_YUV2GRAY_YUY2, 174 | COLOR_YUV2GRAY_YUYV = COLOR_YUV2GRAY_YUY2, 175 | COLOR_YUV2GRAY_YUNV = COLOR_YUV2GRAY_YUY2, 176 | COLOR_RGBA2mRGBA = 125, 177 | COLOR_mRGBA2RGBA = 126, 178 | COLOR_RGB2YUV_I420 = 127, 179 | COLOR_BGR2YUV_I420 = 128, 180 | COLOR_RGB2YUV_IYUV = COLOR_RGB2YUV_I420, 181 | COLOR_BGR2YUV_IYUV = 
COLOR_BGR2YUV_I420, 182 | COLOR_RGBA2YUV_I420 = 129, 183 | COLOR_BGRA2YUV_I420 = 130, 184 | COLOR_RGBA2YUV_IYUV = COLOR_RGBA2YUV_I420, 185 | COLOR_BGRA2YUV_IYUV = COLOR_BGRA2YUV_I420, 186 | COLOR_RGB2YUV_YV12 = 131, 187 | COLOR_BGR2YUV_YV12 = 132, 188 | COLOR_RGBA2YUV_YV12 = 133, 189 | COLOR_BGRA2YUV_YV12 = 134, 190 | COLOR_COLORCVT_MAX = 143 191 | }; 192 | 193 | 194 | MNN_PUBLIC VARP cvtColor(VARP src, int code, int dstCn = 0); 195 | MNN_PUBLIC VARP cvtColorTwoPlane(VARP src1, VARP src2, int code); 196 | MNN_PUBLIC VARP demosaicing(VARP src, int code, int dstCn = 0); 197 | 198 | } // CV 199 | } // MNN 200 | #endif // COLOR_HPP 201 | -------------------------------------------------------------------------------- /include/cv/imgproc/draw.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // draw.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2021/08/26. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef DRAW_HPP 10 | #define DRAW_HPP 11 | 12 | #include 13 | #include 14 | #include 15 | #include "../types.hpp" 16 | 17 | namespace MNN { 18 | namespace CV { 19 | 20 | enum LineTypes { 21 | FILLED = -1, 22 | LINE_4 = 4, 23 | LINE_8 = 8, 24 | LINE_AA = 16 25 | }; 26 | 27 | MNN_PUBLIC void arrowedLine(VARP& img, Point pt1, Point pt2, const Scalar& color, 28 | int thickness=1, int line_type=8, int shift=0, double tipLength=0.1); 29 | MNN_PUBLIC void circle(VARP& img, Point center, int radius, const Scalar& color, 30 | int thickness=1, int line_type=8, int shift=0); 31 | 32 | MNN_PUBLIC void line(VARP& img, Point pt1, Point pt2, const Scalar& color, 33 | int thickness = 1, int lineType = LINE_8, int shift = 0); 34 | 35 | MNN_PUBLIC void rectangle(VARP& img, Point pt1, Point pt2, const Scalar& color, 36 | int thickness = 1, int lineType = LINE_8, int shift = 0); 37 | 38 | MNN_PUBLIC void drawContours(VARP& img, std::vector> _contours, int contourIdx, const Scalar& color, 39 | int thickness = 1, int lineType = LINE_8); 40 | 41 | MNN_PUBLIC void fillPoly(VARP& img, std::vector> pts, const Scalar& color, 42 | int line_type = LINE_8, int shift = 0, Point offset = {0, 0}); 43 | } // CV 44 | } // MNN 45 | #endif // DRAW_HPP 46 | -------------------------------------------------------------------------------- /include/cv/imgproc/filter.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // filter.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2021/08/18. 
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef FILTER_HPP 10 | #define FILTER_HPP 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include "../types.hpp" 17 | 18 | namespace MNN { 19 | namespace CV { 20 | 21 | MNN_PUBLIC VARP bilateralFilter(VARP src, int d, double sigmaColor, double sigmaSpace, 22 | int borderType = REFLECT); 23 | 24 | MNN_PUBLIC VARP blur(VARP src, Size ksize, int borderType = REFLECT); 25 | 26 | MNN_PUBLIC VARP boxFilter(VARP src, int ddepth, Size ksize, 27 | bool normalize = true, int borderType = REFLECT); 28 | 29 | MNN_PUBLIC VARP dilate(VARP src, VARP kernel, 30 | int iterations = 1, int borderType = CONSTANT); 31 | 32 | MNN_PUBLIC VARP filter2D(VARP src, int ddepth, VARP kernel, 33 | double delta = 0, int borderType = REFLECT); 34 | 35 | MNN_PUBLIC VARP GaussianBlur(VARP src, Size ksize, double sigmaX, 36 | double sigmaY = 0, int borderType = REFLECT); 37 | 38 | MNN_PUBLIC std::pair getDerivKernels(int dx, int dy, int ksize, 39 | bool normalize = false); 40 | 41 | MNN_PUBLIC VARP getGaborKernel(Size ksize, double sigma, double theta, double lambd, 42 | double gamma, double psi = MNN_PI * 0.5); 43 | 44 | MNN_PUBLIC VARP getGaussianKernel(int n, double sigma); 45 | 46 | MNN_PUBLIC VARP getStructuringElement(int shape, Size ksize); 47 | 48 | MNN_PUBLIC VARP Laplacian(VARP src, int ddepth, int ksize = 1, 49 | double scale = 1, double delta = 0, int borderType = REFLECT); 50 | 51 | MNN_PUBLIC VARP pyrDown(VARP src, Size dstsize = {}, int borderType = REFLECT); 52 | 53 | MNN_PUBLIC VARP pyrUp(VARP src, Size dstsize = {}, int borderType = REFLECT); 54 | 55 | MNN_PUBLIC VARP Scharr(VARP src, int ddepth, int dx, int dy, 56 | double scale = 1, double delta = 0, int borderType = REFLECT); 57 | 58 | MNN_PUBLIC VARP sepFilter2D(VARP src, int ddepth, VARP& kernelX, VARP& kernelY, 59 | double delta = 0, int borderType = REFLECT); 60 | 61 | MNN_PUBLIC VARP Sobel(VARP src, int ddepth, int dx, int dy, int ksize = 3, 62 | double scale = 1, double delta = 0, int borderType = REFLECT); 63 | 64 | MNN_PUBLIC std::pair spatialGradient(VARP src, int ksize = 3, 65 | int borderType = REFLECT); 66 | 67 | MNN_PUBLIC VARP sqrBoxFilter(VARP src, int ddepth, Size ksize, 68 | bool normalize = true, int borderType = REFLECT); 69 | } // CV 70 | } // MNN 71 | #endif // FILTER_HPP 72 | -------------------------------------------------------------------------------- /include/cv/imgproc/geometric.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // geometric.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2021/08/19. 
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef GEOMETRIC_HPP 10 | #define GEOMETRIC_HPP 11 | 12 | #include 13 | #include 14 | #include 15 | #include "../types.hpp" 16 | 17 | namespace MNN { 18 | namespace CV { 19 | 20 | enum InterpolationFlags { 21 | INTER_NEAREST = 0, 22 | INTER_LINEAR = 1, 23 | INTER_CUBIC = 2, 24 | INTER_AREA = 3, 25 | INTER_LANCZOS4 = 4, 26 | INTER_LINEAR_EXACT = 5, 27 | INTER_NEAREST_EXACT = 6, 28 | INTER_MAX = 7, 29 | WARP_FILL_OUTLIERS = 8, 30 | WARP_INVERSE_MAP = 16 31 | }; 32 | 33 | enum BorderTypes { 34 | BORDER_CONSTANT = 0, 35 | BORDER_REPLICATE = 1, 36 | BORDER_REFLECT = 2, 37 | BORDER_WRAP = 3, 38 | BORDER_REFLECT_101 = 4, 39 | BORDER_TRANSPARENT = 5, 40 | BORDER_REFLECT101 = BORDER_REFLECT_101, 41 | BORDER_DEFAULT = BORDER_REFLECT_101, 42 | BORDER_ISOLATED = 16 43 | }; 44 | 45 | MNN_PUBLIC std::pair convertMaps(VARP map1, VARP map2, int dstmap1type, 46 | bool nninterpolation = false); 47 | 48 | MNN_PUBLIC Matrix getAffineTransform(const Point src[], const Point dst[]); 49 | 50 | MNN_PUBLIC Matrix getPerspectiveTransform(const Point src[], const Point dst[]); 51 | 52 | MNN_PUBLIC VARP getRectSubPix(VARP image, Size patchSize, Point center); 53 | 54 | MNN_PUBLIC Matrix getRotationMatrix2D(Point center, double angle, double scale); 55 | 56 | MNN_PUBLIC Matrix invertAffineTransform(Matrix M); 57 | 58 | MNN_PUBLIC VARP resize(VARP src, Size dsize, double fx = 0, double fy = 0, 59 | int interpolation = INTER_LINEAR, int code = -1, 60 | std::vector mean = {}, std::vector norm = {}); 61 | 62 | MNN_PUBLIC VARP warpAffine(VARP src, Matrix M, Size dsize, 63 | int flags = INTER_LINEAR, int borderMode = BORDER_CONSTANT, int borderValue = 0, 64 | int code = -1, std::vector mean = {}, std::vector norm = {}); 65 | 66 | MNN_PUBLIC VARP warpPerspective(VARP src, Matrix M, Size dsize, 67 | int flags = INTER_LINEAR, int borderMode = BORDER_CONSTANT, 68 | int borderValue = 0); 69 | 70 | MNN_PUBLIC VARP undistortPoints(VARP src, VARP cameraMatrix, VARP distCoeffs); 71 | } // CV 72 | } // MNN 73 | #endif // GEOMETRIC_HPP 74 | -------------------------------------------------------------------------------- /include/cv/imgproc/histograms.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // histograms.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2022/07/26. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef HISTOGRAMS_HPP 10 | #define HISTOGRAMS_HPP 11 | 12 | #include 13 | #include 14 | 15 | namespace MNN { 16 | namespace CV { 17 | using namespace Express; 18 | 19 | MNN_PUBLIC VARP calcHist(VARPS images, const std::vector& channels, VARP mask, 20 | const std::vector& histSize, const std::vector& ranges, bool accumulate = false); 21 | 22 | } // CV 23 | } // MNN 24 | #endif // HISTOGRAMS_HPP 25 | -------------------------------------------------------------------------------- /include/cv/imgproc/imgproc.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // imgproc.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2021/08/13. 
6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef IMGPROC_HPP 10 | #define IMGPROC_HPP 11 | 12 | #include "filter.hpp" 13 | #include "geometric.hpp" 14 | #include "draw.hpp" 15 | #include "miscellaneous.hpp" 16 | #include "color.hpp" 17 | #include "structural.hpp" 18 | #include "histograms.hpp" 19 | 20 | #endif // IMGPROC_HPP 21 | -------------------------------------------------------------------------------- /include/cv/imgproc/miscellaneous.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // miscellaneous.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2021/08/20. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef MISCELLANEOUS_HPP 10 | #define MISCELLANEOUS_HPP 11 | 12 | #include 13 | #include 14 | 15 | namespace MNN { 16 | namespace CV { 17 | using namespace Express; 18 | 19 | enum ThresholdTypes { 20 | THRESH_BINARY = 0, 21 | THRESH_BINARY_INV = 1, 22 | THRESH_TRUNC = 2, 23 | THRESH_TOZERO = 3, 24 | THRESH_TOZERO_INV = 4, 25 | THRESH_MASK = 7, 26 | THRESH_OTSU = 8, 27 | THRESH_TRIANGLE = 16 28 | }; 29 | 30 | MNN_PUBLIC VARP adaptiveThreshold(VARP src, double maxValue, int adaptiveMethod, int thresholdType, int blockSize, double C); 31 | 32 | MNN_PUBLIC VARP blendLinear(VARP src1, VARP src2, VARP weight1, VARP weight2); 33 | 34 | MNN_PUBLIC void distanceTransform(VARP src, VARP& dst, VARP& labels, int distanceType, int maskSize, int labelType = 0); 35 | 36 | MNN_PUBLIC int floodFill(VARP image, std::pair seedPoint, float newVal); 37 | 38 | MNN_PUBLIC VARP integral(VARP src, int sdepth = -1); 39 | 40 | MNN_PUBLIC VARP threshold(VARP src, double thresh, double maxval, int type); 41 | 42 | } // CV 43 | } // MNN 44 | #endif // MISCELLANEOUS_HPP 45 | -------------------------------------------------------------------------------- /include/cv/imgproc/structural.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // structural.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2021/12/01. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef STRUCTURAL_HPP 10 | #define STRUCTURAL_HPP 11 | 12 | #include 13 | #include "cv/types.hpp" 14 | 15 | namespace MNN { 16 | namespace CV { 17 | 18 | enum RetrievalModes { 19 | RETR_EXTERNAL = 0, 20 | RETR_LIST = 1, 21 | RETR_CCOMP = 2, 22 | RETR_TREE = 3, 23 | RETR_FLOODFILL = 4 24 | }; 25 | 26 | enum ContourApproximationModes { 27 | CHAIN_APPROX_NONE = 1, 28 | CHAIN_APPROX_SIMPLE = 2, 29 | CHAIN_APPROX_TC89_L1 = 3, 30 | CHAIN_APPROX_TC89_KCOS = 4 31 | }; 32 | 33 | class RotatedRect 34 | { 35 | public: 36 | //! default constructor 37 | RotatedRect() {} 38 | //! returns the rectangle mass center 39 | Point2f center; 40 | //! returns width and height of the rectangle 41 | Size2f size; 42 | //! returns the rotation angle. When the angle is 0, 90, 180, 270 etc., the rectangle becomes an up-right rectangle. 
43 | float angle; 44 | }; 45 | typedef std::vector POINTS; 46 | 47 | MNN_PUBLIC std::vector findContours(VARP image, int mode, int method, Point offset = {0, 0}); 48 | MNN_PUBLIC double contourArea(VARP _contour, bool oriented = false); 49 | MNN_PUBLIC std::vector convexHull(VARP _points, bool clockwise = false, bool returnPoints = true); 50 | MNN_PUBLIC RotatedRect minAreaRect(VARP _points); 51 | MNN_PUBLIC Rect2i boundingRect(VARP points); 52 | MNN_PUBLIC int connectedComponentsWithStats(VARP image, VARP& labels, VARP& statsv, VARP& centroids, int connectivity = 8); 53 | MNN_PUBLIC VARP boxPoints(RotatedRect box); 54 | } // CV 55 | } // MNN 56 | #endif // STRUCTURAL_HPP 57 | -------------------------------------------------------------------------------- /include/cv/types.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // types.hpp 3 | // MNN 4 | // 5 | // Created by MNN on 2021/08/18. 6 | // Copyright © 2018, Alibaba Group Holding Limited 7 | // 8 | 9 | #ifndef TYPES_HPP 10 | #define TYPES_HPP 11 | 12 | #include 13 | 14 | namespace MNN { 15 | namespace CV { 16 | 17 | using namespace Express; 18 | 19 | #define MNN_PI 3.1415926535897932384626433832795 20 | 21 | typedef signed char schar; 22 | typedef unsigned char uchar; 23 | 24 | // Size Start 25 | template class Size_ 26 | { 27 | public: 28 | typedef _Tp value_type; 29 | 30 | //! default constructor 31 | Size_(); 32 | Size_(_Tp _width, _Tp _height); 33 | Size_(const Size_& sz); 34 | Size_(Size_&& sz); 35 | 36 | Size_& operator = (const Size_& sz); 37 | Size_& operator = (Size_&& sz); 38 | //! the area (width*height) 39 | _Tp area() const; 40 | //! aspect ratio (width/height) 41 | double aspectRatio() const; 42 | //! true if empty 43 | bool empty() const; 44 | 45 | //! conversion of another data type. 
46 | template operator Size_<_Tp2>() const; 47 | 48 | _Tp width; //!< the width 49 | _Tp height; //!< the height 50 | }; 51 | 52 | typedef Size_ Size2i; 53 | typedef Size_ Size2l; 54 | typedef Size_ Size2f; 55 | typedef Size_ Size2d; 56 | typedef Size2i Size; 57 | 58 | template inline 59 | Size_<_Tp>::Size_() 60 | : width(0), height(0) {} 61 | 62 | template inline 63 | Size_<_Tp>::Size_(_Tp _width, _Tp _height) 64 | : width(_width), height(_height) {} 65 | 66 | template inline 67 | Size_<_Tp>::Size_(const Size_& sz) 68 | : width(sz.width), height(sz.height) {} 69 | 70 | template inline 71 | Size_<_Tp>::Size_(Size_&& sz) 72 | : width(std::move(sz.width)), height(std::move(sz.height)) {} 73 | 74 | template template inline 75 | Size_<_Tp>::operator Size_<_Tp2>() const 76 | { 77 | return Size_<_Tp2>(static_cast<_Tp2>(width), static_cast<_Tp2>(height)); 78 | } 79 | 80 | template inline 81 | Size_<_Tp>& Size_<_Tp>::operator = (const Size_<_Tp>& sz) 82 | { 83 | width = sz.width; height = sz.height; 84 | return *this; 85 | } 86 | 87 | template inline 88 | Size_<_Tp>& Size_<_Tp>::operator = (Size_<_Tp>&& sz) 89 | { 90 | width = std::move(sz.width); height = std::move(sz.height); 91 | return *this; 92 | } 93 | 94 | template inline 95 | _Tp Size_<_Tp>::area() const 96 | { 97 | return width * height; 98 | } 99 | 100 | template inline 101 | bool Size_<_Tp>::empty() const 102 | { 103 | return width <= 0 || height <= 0; 104 | } 105 | 106 | template static inline 107 | Size_<_Tp>& operator *= (Size_<_Tp>& a, _Tp b) 108 | { 109 | a.width *= b; 110 | a.height *= b; 111 | return a; 112 | } 113 | 114 | template static inline 115 | Size_<_Tp> operator * (const Size_<_Tp>& a, _Tp b) 116 | { 117 | Size_<_Tp> tmp(a); 118 | tmp *= b; 119 | return tmp; 120 | } 121 | 122 | template static inline 123 | Size_<_Tp>& operator /= (Size_<_Tp>& a, _Tp b) 124 | { 125 | a.width /= b; 126 | a.height /= b; 127 | return a; 128 | } 129 | 130 | template static inline 131 | Size_<_Tp> operator / (const Size_<_Tp>& a, _Tp b) 132 | { 133 | Size_<_Tp> tmp(a); 134 | tmp /= b; 135 | return tmp; 136 | } 137 | 138 | template static inline 139 | Size_<_Tp>& operator += (Size_<_Tp>& a, const Size_<_Tp>& b) 140 | { 141 | a.width += b.width; 142 | a.height += b.height; 143 | return a; 144 | } 145 | 146 | template static inline 147 | Size_<_Tp> operator + (const Size_<_Tp>& a, const Size_<_Tp>& b) 148 | { 149 | Size_<_Tp> tmp(a); 150 | tmp += b; 151 | return tmp; 152 | } 153 | 154 | template static inline 155 | Size_<_Tp>& operator -= (Size_<_Tp>& a, const Size_<_Tp>& b) 156 | { 157 | a.width -= b.width; 158 | a.height -= b.height; 159 | return a; 160 | } 161 | 162 | template static inline 163 | Size_<_Tp> operator - (const Size_<_Tp>& a, const Size_<_Tp>& b) 164 | { 165 | Size_<_Tp> tmp(a); 166 | tmp -= b; 167 | return tmp; 168 | } 169 | 170 | template static inline 171 | bool operator == (const Size_<_Tp>& a, const Size_<_Tp>& b) 172 | { 173 | return a.width == b.width && a.height == b.height; 174 | } 175 | 176 | template static inline 177 | bool operator != (const Size_<_Tp>& a, const Size_<_Tp>& b) 178 | { 179 | return !(a == b); 180 | } 181 | // Size End 182 | // Point Start 183 | template class Point_ 184 | { 185 | public: 186 | typedef _Tp value_type; 187 | 188 | //! 
default constructor 189 | Point_(); 190 | Point_(_Tp _x, _Tp _y); 191 | Point_(const Point_& pt); 192 | Point_(Point_&& pt); 193 | Point_(const Size_<_Tp>& sz); 194 | 195 | Point_& operator = (const Point_& pt); 196 | Point_& operator = (Point_&& pt); 197 | template operator Point_<_Tp2>() const; 198 | 199 | _Tp x; //!< x coordinate of the point 200 | _Tp y; //!< y coordinate of the point 201 | }; 202 | 203 | typedef Point_ Point2i; 204 | typedef Point_ Point2l; 205 | typedef Point_ Point2f; 206 | typedef Point_ Point2d; 207 | 208 | template inline 209 | Point_<_Tp>::Point_() 210 | : x(0), y(0) {} 211 | 212 | template inline 213 | Point_<_Tp>::Point_(_Tp _x, _Tp _y) 214 | : x(_x), y(_y) {} 215 | 216 | template inline 217 | Point_<_Tp>::Point_(const Point_& pt) 218 | : x(pt.x), y(pt.y) {} 219 | 220 | template inline 221 | Point_<_Tp>::Point_(Point_&& pt) 222 | : x(std::move(pt.x)), y(std::move(pt.y)) {} 223 | 224 | template inline 225 | Point_<_Tp>::Point_(const Size_<_Tp>& sz) 226 | : x(sz.width), y(sz.height) {} 227 | 228 | template inline 229 | Point_<_Tp>& Point_<_Tp>::operator = (const Point_& pt) 230 | { 231 | x = pt.x; y = pt.y; 232 | return *this; 233 | } 234 | 235 | template inline 236 | Point_<_Tp>& Point_<_Tp>::operator = (Point_&& pt) 237 | { 238 | x = std::move(pt.x); y = std::move(pt.y); 239 | return *this; 240 | } 241 | 242 | template template inline 243 | Point_<_Tp>::operator Point_<_Tp2>() const 244 | { 245 | return Point_<_Tp2>(static_cast<_Tp2>(x), static_cast<_Tp2>(y)); 246 | } 247 | 248 | template static inline 249 | Point_<_Tp>& operator += (Point_<_Tp>& a, const Point_<_Tp>& b) 250 | { 251 | a.x += b.x; 252 | a.y += b.y; 253 | return a; 254 | } 255 | 256 | template static inline 257 | Point_<_Tp> operator - (const Point_<_Tp>& a, const Point_<_Tp>& b) 258 | { 259 | return Point_<_Tp>( static_cast<_Tp>(a.x - b.x), static_cast<_Tp>(a.y - b.y) ); 260 | } 261 | 262 | template static inline 263 | bool operator != (const Point_<_Tp>& a, const Point_<_Tp>& b) 264 | { 265 | return a.x != b.x || a.y != b.y; 266 | } 267 | // Point End 268 | // Rect Start 269 | template class Rect_ 270 | { 271 | public: 272 | typedef _Tp value_type; 273 | 274 | //! default constructor 275 | Rect_(); 276 | Rect_(_Tp _x, _Tp _y, _Tp _width, _Tp _height); 277 | Rect_(const Rect_& r); 278 | Rect_(Rect_&& r); 279 | Rect_(const Point_<_Tp>& org, const Size_<_Tp>& sz); 280 | Rect_(const Point_<_Tp>& pt1, const Point_<_Tp>& pt2); 281 | 282 | Rect_& operator = ( const Rect_& r ); 283 | Rect_& operator = ( Rect_&& r ); 284 | //! the top-left corner 285 | Point_<_Tp> tl() const; 286 | //! the bottom-right corner 287 | Point_<_Tp> br() const; 288 | 289 | //! size (width, height) of the rectangle 290 | Size_<_Tp> size() const; 291 | //! area (width*height) of the rectangle 292 | _Tp area() const; 293 | //! 
true if empty 294 | bool empty() const; 295 | 296 | _Tp x; //!< x coordinate of the top-left corner 297 | _Tp y; //!< y coordinate of the top-left corner 298 | _Tp width; //!< width of the rectangle 299 | _Tp height; //!< height of the rectangle 300 | }; 301 | 302 | typedef Rect_ Rect2i; 303 | typedef Rect_ Rect2f; 304 | typedef Rect_ Rect2d; 305 | 306 | template inline 307 | Rect_<_Tp>::Rect_() 308 | : x(0), y(0), width(0), height(0) {} 309 | 310 | template inline 311 | Rect_<_Tp>::Rect_(_Tp _x, _Tp _y, _Tp _width, _Tp _height) 312 | : x(_x), y(_y), width(_width), height(_height) {} 313 | 314 | template inline 315 | Rect_<_Tp>::Rect_(const Rect_<_Tp>& r) 316 | : x(r.x), y(r.y), width(r.width), height(r.height) {} 317 | 318 | template inline 319 | Rect_<_Tp>::Rect_(Rect_<_Tp>&& r) 320 | : x(std::move(r.x)), y(std::move(r.y)), width(std::move(r.width)), height(std::move(r.height)) {} 321 | 322 | template inline 323 | Rect_<_Tp>::Rect_(const Point_<_Tp>& org, const Size_<_Tp>& sz) 324 | : x(org.x), y(org.y), width(sz.width), height(sz.height) {} 325 | 326 | template inline 327 | Rect_<_Tp>::Rect_(const Point_<_Tp>& pt1, const Point_<_Tp>& pt2) 328 | { 329 | x = std::min(pt1.x, pt2.x); 330 | y = std::min(pt1.y, pt2.y); 331 | width = std::max(pt1.x, pt2.x) - x; 332 | height = std::max(pt1.y, pt2.y) - y; 333 | } 334 | 335 | template inline 336 | Rect_<_Tp>& Rect_<_Tp>::operator = ( const Rect_<_Tp>& r ) 337 | { 338 | x = r.x; 339 | y = r.y; 340 | width = r.width; 341 | height = r.height; 342 | return *this; 343 | } 344 | 345 | template inline 346 | Rect_<_Tp>& Rect_<_Tp>::operator = ( Rect_<_Tp>&& r ) 347 | { 348 | x = std::move(r.x); 349 | y = std::move(r.y); 350 | width = std::move(r.width); 351 | height = std::move(r.height); 352 | return *this; 353 | } 354 | 355 | template inline 356 | Point_<_Tp> Rect_<_Tp>::tl() const 357 | { 358 | return Point_<_Tp>(x,y); 359 | } 360 | 361 | template inline 362 | Point_<_Tp> Rect_<_Tp>::br() const 363 | { 364 | return Point_<_Tp>(x + width, y + height); 365 | } 366 | 367 | template inline 368 | Size_<_Tp> Rect_<_Tp>::size() const 369 | { 370 | return Size_<_Tp>(width, height); 371 | } 372 | 373 | template inline 374 | _Tp Rect_<_Tp>::area() const 375 | { 376 | const _Tp result = width * height; 377 | return result; 378 | } 379 | 380 | template inline 381 | bool Rect_<_Tp>::empty() const 382 | { 383 | return width <= 0 || height <= 0; 384 | } 385 | // Rect 386 | // Scalar Start 387 | template class Scalar_ { 388 | public: 389 | //! 
default constructor 390 | Scalar_(); 391 | Scalar_(_Tp _r, _Tp _g, _Tp _b) { 392 | val[0] = _r; 393 | val[1] = _g; 394 | val[2] = _b; 395 | val[3] = 255; 396 | }; 397 | Scalar_(_Tp _r, _Tp _g, _Tp _b, _Tp _a) { 398 | val[0] = _r; 399 | val[1] = _g; 400 | val[2] = _b; 401 | val[3] = _a; 402 | }; 403 | _Tp val[4]; 404 | }; 405 | typedef Scalar_ Scalar; 406 | // Scalar End 407 | 408 | static void getVARPSize(VARP var, int* height, int* width, int* channel) { 409 | auto info = var->getInfo(); 410 | auto dims = info->dim; 411 | int num = dims.size(); 412 | if (num < 2) return; 413 | if (num == 2) { 414 | *height = dims[0]; 415 | *width = dims[1]; 416 | *channel = 1; 417 | } else if (num == 3) { 418 | *height = dims[0]; 419 | *width = dims[1]; 420 | *channel = dims[2]; 421 | } else if (info->order == NHWC) { 422 | *channel = dims[num - 1]; 423 | *width = dims[num - 2]; 424 | *height = dims[num - 3]; 425 | } else { // NCHW 426 | *width = dims[num - 1]; 427 | *height = dims[num - 2]; 428 | *channel = dims[num - 3]; 429 | } 430 | } 431 | static int getVARPHeight(VARP var) { 432 | int h, w, c; 433 | getVARPSize(var, &h, &w, &c); 434 | return h; 435 | } 436 | static int getVARPWidth(VARP var) { 437 | int h, w, c; 438 | getVARPSize(var, &h, &w, &c); 439 | return w; 440 | } 441 | static int getVARPChannel(VARP var) { 442 | int h, w, c; 443 | getVARPSize(var, &h, &w, &c); 444 | return c; 445 | } 446 | static int getVARPByte(VARP var) { 447 | return var->getInfo()->type.bytes(); 448 | } 449 | } // CV 450 | } // MNN 451 | #endif // TYPES_HPP 452 | -------------------------------------------------------------------------------- /include/pipeline.hpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | using namespace MNN; 9 | using namespace MNN::Express; 10 | 11 | namespace diffusion { 12 | 13 | class Pipeline { 14 | public: 15 | Pipeline(std::string modelPath); 16 | ~Pipeline() = default; 17 | bool run(const std::string& sentence, const std::string& img_name); 18 | private: 19 | void loadNet(std::string modelPath); 20 | void runNet(); 21 | VARP step_plms(VARP sample, VARP model_output, int index); 22 | std::unique_ptr text_encoder(const std::vector& ids); 23 | VARP unet(std::unique_ptr text_embeddings); 24 | VARP vae_decoder(VARP latent); 25 | private: 26 | std::unique_ptr mNet; 27 | MNN::Session* mSession; 28 | std::map mInputs, mOutputs; 29 | std::string mModelPath; 30 | // step_plms 31 | std::vector mTimeSteps; 32 | std::vector mAlphas; 33 | std::vector mEts; 34 | VARP mSample; 35 | }; 36 | 37 | } -------------------------------------------------------------------------------- /include/tokenizer.hpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | namespace diffusion { 6 | 7 | class tokenizer { 8 | public: 9 | tokenizer(std::string dictPath); 10 | ~tokenizer() {} 11 | int word(std::string word); 12 | std::vector sentence(std::string sentence, int maxlen = 0); 13 | private: 14 | int mStartIdx, mEndIdx; 15 | std::unordered_map mWordDict; 16 | }; 17 | 18 | } // diffusion -------------------------------------------------------------------------------- /libs/README.md: -------------------------------------------------------------------------------- 1 | # Libs 2 | 3 | copy libs from `MNN` 4 | -------------------------------------------------------------------------------- /resource/alphas.txt: 
-------------------------------------------------------------------------------- 1 | 0.9991 2 | 0.9983 3 | 0.9974 4 | 0.9966 5 | 0.9957 6 | 0.9948 7 | 0.9940 8 | 0.9931 9 | 0.9922 10 | 0.9913 11 | 0.9904 12 | 0.9895 13 | 0.9886 14 | 0.9877 15 | 0.9868 16 | 0.9859 17 | 0.9850 18 | 0.9841 19 | 0.9832 20 | 0.9822 21 | 0.9813 22 | 0.9804 23 | 0.9794 24 | 0.9785 25 | 0.9776 26 | 0.9766 27 | 0.9757 28 | 0.9747 29 | 0.9737 30 | 0.9728 31 | 0.9718 32 | 0.9708 33 | 0.9698 34 | 0.9689 35 | 0.9679 36 | 0.9669 37 | 0.9659 38 | 0.9649 39 | 0.9639 40 | 0.9629 41 | 0.9619 42 | 0.9609 43 | 0.9599 44 | 0.9588 45 | 0.9578 46 | 0.9568 47 | 0.9557 48 | 0.9547 49 | 0.9537 50 | 0.9526 51 | 0.9516 52 | 0.9505 53 | 0.9495 54 | 0.9484 55 | 0.9473 56 | 0.9463 57 | 0.9452 58 | 0.9441 59 | 0.9430 60 | 0.9420 61 | 0.9409 62 | 0.9398 63 | 0.9387 64 | 0.9376 65 | 0.9365 66 | 0.9354 67 | 0.9343 68 | 0.9332 69 | 0.9320 70 | 0.9309 71 | 0.9298 72 | 0.9287 73 | 0.9275 74 | 0.9264 75 | 0.9252 76 | 0.9241 77 | 0.9229 78 | 0.9218 79 | 0.9206 80 | 0.9195 81 | 0.9183 82 | 0.9171 83 | 0.9160 84 | 0.9148 85 | 0.9136 86 | 0.9124 87 | 0.9112 88 | 0.9100 89 | 0.9089 90 | 0.9077 91 | 0.9065 92 | 0.9052 93 | 0.9040 94 | 0.9028 95 | 0.9016 96 | 0.9004 97 | 0.8992 98 | 0.8979 99 | 0.8967 100 | 0.8955 101 | 0.8942 102 | 0.8930 103 | 0.8917 104 | 0.8905 105 | 0.8892 106 | 0.8880 107 | 0.8867 108 | 0.8854 109 | 0.8842 110 | 0.8829 111 | 0.8816 112 | 0.8804 113 | 0.8791 114 | 0.8778 115 | 0.8765 116 | 0.8752 117 | 0.8739 118 | 0.8726 119 | 0.8713 120 | 0.8700 121 | 0.8687 122 | 0.8674 123 | 0.8661 124 | 0.8647 125 | 0.8634 126 | 0.8621 127 | 0.8607 128 | 0.8594 129 | 0.8581 130 | 0.8567 131 | 0.8554 132 | 0.8540 133 | 0.8527 134 | 0.8513 135 | 0.8500 136 | 0.8486 137 | 0.8473 138 | 0.8459 139 | 0.8445 140 | 0.8431 141 | 0.8418 142 | 0.8404 143 | 0.8390 144 | 0.8376 145 | 0.8362 146 | 0.8348 147 | 0.8334 148 | 0.8320 149 | 0.8306 150 | 0.8292 151 | 0.8278 152 | 0.8264 153 | 0.8250 154 | 0.8236 155 | 0.8221 156 | 0.8207 157 | 0.8193 158 | 0.8179 159 | 0.8164 160 | 0.8150 161 | 0.8136 162 | 0.8121 163 | 0.8107 164 | 0.8092 165 | 0.8078 166 | 0.8063 167 | 0.8049 168 | 0.8034 169 | 0.8019 170 | 0.8005 171 | 0.7990 172 | 0.7975 173 | 0.7960 174 | 0.7946 175 | 0.7931 176 | 0.7916 177 | 0.7901 178 | 0.7886 179 | 0.7871 180 | 0.7856 181 | 0.7842 182 | 0.7827 183 | 0.7812 184 | 0.7796 185 | 0.7781 186 | 0.7766 187 | 0.7751 188 | 0.7736 189 | 0.7721 190 | 0.7706 191 | 0.7690 192 | 0.7675 193 | 0.7660 194 | 0.7645 195 | 0.7629 196 | 0.7614 197 | 0.7599 198 | 0.7583 199 | 0.7568 200 | 0.7552 201 | 0.7537 202 | 0.7521 203 | 0.7506 204 | 0.7490 205 | 0.7475 206 | 0.7459 207 | 0.7444 208 | 0.7428 209 | 0.7412 210 | 0.7397 211 | 0.7381 212 | 0.7365 213 | 0.7350 214 | 0.7334 215 | 0.7318 216 | 0.7302 217 | 0.7286 218 | 0.7271 219 | 0.7255 220 | 0.7239 221 | 0.7223 222 | 0.7207 223 | 0.7191 224 | 0.7175 225 | 0.7159 226 | 0.7143 227 | 0.7127 228 | 0.7111 229 | 0.7095 230 | 0.7079 231 | 0.7063 232 | 0.7047 233 | 0.7031 234 | 0.7015 235 | 0.6999 236 | 0.6982 237 | 0.6966 238 | 0.6950 239 | 0.6934 240 | 0.6918 241 | 0.6901 242 | 0.6885 243 | 0.6869 244 | 0.6852 245 | 0.6836 246 | 0.6820 247 | 0.6803 248 | 0.6787 249 | 0.6771 250 | 0.6754 251 | 0.6738 252 | 0.6722 253 | 0.6705 254 | 0.6689 255 | 0.6672 256 | 0.6656 257 | 0.6639 258 | 0.6623 259 | 0.6606 260 | 0.6590 261 | 0.6573 262 | 0.6557 263 | 0.6540 264 | 0.6524 265 | 0.6507 266 | 0.6490 267 | 0.6474 268 | 0.6457 269 | 0.6441 270 | 0.6424 271 | 0.6407 272 | 0.6391 273 | 0.6374 274 | 0.6357 275 | 0.6341 276 | 
0.6324 277 | 0.6307 278 | 0.6291 279 | 0.6274 280 | 0.6257 281 | 0.6241 282 | 0.6224 283 | 0.6207 284 | 0.6190 285 | 0.6174 286 | 0.6157 287 | 0.6140 288 | 0.6123 289 | 0.6107 290 | 0.6090 291 | 0.6073 292 | 0.6056 293 | 0.6039 294 | 0.6023 295 | 0.6006 296 | 0.5989 297 | 0.5972 298 | 0.5955 299 | 0.5939 300 | 0.5922 301 | 0.5905 302 | 0.5888 303 | 0.5871 304 | 0.5855 305 | 0.5838 306 | 0.5821 307 | 0.5804 308 | 0.5787 309 | 0.5770 310 | 0.5754 311 | 0.5737 312 | 0.5720 313 | 0.5703 314 | 0.5686 315 | 0.5669 316 | 0.5652 317 | 0.5636 318 | 0.5619 319 | 0.5602 320 | 0.5585 321 | 0.5568 322 | 0.5551 323 | 0.5535 324 | 0.5518 325 | 0.5501 326 | 0.5484 327 | 0.5467 328 | 0.5450 329 | 0.5434 330 | 0.5417 331 | 0.5400 332 | 0.5383 333 | 0.5366 334 | 0.5350 335 | 0.5333 336 | 0.5316 337 | 0.5299 338 | 0.5282 339 | 0.5266 340 | 0.5249 341 | 0.5232 342 | 0.5215 343 | 0.5199 344 | 0.5182 345 | 0.5165 346 | 0.5148 347 | 0.5132 348 | 0.5115 349 | 0.5098 350 | 0.5082 351 | 0.5065 352 | 0.5048 353 | 0.5032 354 | 0.5015 355 | 0.4998 356 | 0.4982 357 | 0.4965 358 | 0.4948 359 | 0.4932 360 | 0.4915 361 | 0.4898 362 | 0.4882 363 | 0.4865 364 | 0.4849 365 | 0.4832 366 | 0.4816 367 | 0.4799 368 | 0.4782 369 | 0.4766 370 | 0.4749 371 | 0.4733 372 | 0.4716 373 | 0.4700 374 | 0.4684 375 | 0.4667 376 | 0.4651 377 | 0.4634 378 | 0.4618 379 | 0.4601 380 | 0.4585 381 | 0.4569 382 | 0.4552 383 | 0.4536 384 | 0.4520 385 | 0.4503 386 | 0.4487 387 | 0.4471 388 | 0.4455 389 | 0.4438 390 | 0.4422 391 | 0.4406 392 | 0.4390 393 | 0.4374 394 | 0.4357 395 | 0.4341 396 | 0.4325 397 | 0.4309 398 | 0.4293 399 | 0.4277 400 | 0.4261 401 | 0.4245 402 | 0.4229 403 | 0.4213 404 | 0.4197 405 | 0.4181 406 | 0.4165 407 | 0.4149 408 | 0.4133 409 | 0.4117 410 | 0.4101 411 | 0.4086 412 | 0.4070 413 | 0.4054 414 | 0.4038 415 | 0.4022 416 | 0.4007 417 | 0.3991 418 | 0.3975 419 | 0.3960 420 | 0.3944 421 | 0.3928 422 | 0.3913 423 | 0.3897 424 | 0.3882 425 | 0.3866 426 | 0.3850 427 | 0.3835 428 | 0.3819 429 | 0.3804 430 | 0.3789 431 | 0.3773 432 | 0.3758 433 | 0.3742 434 | 0.3727 435 | 0.3712 436 | 0.3697 437 | 0.3681 438 | 0.3666 439 | 0.3651 440 | 0.3636 441 | 0.3621 442 | 0.3605 443 | 0.3590 444 | 0.3575 445 | 0.3560 446 | 0.3545 447 | 0.3530 448 | 0.3515 449 | 0.3500 450 | 0.3485 451 | 0.3470 452 | 0.3456 453 | 0.3441 454 | 0.3426 455 | 0.3411 456 | 0.3396 457 | 0.3382 458 | 0.3367 459 | 0.3352 460 | 0.3338 461 | 0.3323 462 | 0.3308 463 | 0.3294 464 | 0.3279 465 | 0.3265 466 | 0.3250 467 | 0.3236 468 | 0.3222 469 | 0.3207 470 | 0.3193 471 | 0.3178 472 | 0.3164 473 | 0.3150 474 | 0.3136 475 | 0.3122 476 | 0.3107 477 | 0.3093 478 | 0.3079 479 | 0.3065 480 | 0.3051 481 | 0.3037 482 | 0.3023 483 | 0.3009 484 | 0.2995 485 | 0.2981 486 | 0.2967 487 | 0.2953 488 | 0.2940 489 | 0.2926 490 | 0.2912 491 | 0.2899 492 | 0.2885 493 | 0.2871 494 | 0.2858 495 | 0.2844 496 | 0.2831 497 | 0.2817 498 | 0.2804 499 | 0.2790 500 | 0.2777 501 | 0.2763 502 | 0.2750 503 | 0.2737 504 | 0.2723 505 | 0.2710 506 | 0.2697 507 | 0.2684 508 | 0.2671 509 | 0.2658 510 | 0.2645 511 | 0.2631 512 | 0.2618 513 | 0.2606 514 | 0.2593 515 | 0.2580 516 | 0.2567 517 | 0.2554 518 | 0.2541 519 | 0.2528 520 | 0.2516 521 | 0.2503 522 | 0.2490 523 | 0.2478 524 | 0.2465 525 | 0.2453 526 | 0.2440 527 | 0.2428 528 | 0.2415 529 | 0.2403 530 | 0.2391 531 | 0.2378 532 | 0.2366 533 | 0.2354 534 | 0.2341 535 | 0.2329 536 | 0.2317 537 | 0.2305 538 | 0.2293 539 | 0.2281 540 | 0.2269 541 | 0.2257 542 | 0.2245 543 | 0.2233 544 | 0.2221 545 | 0.2209 546 | 0.2198 547 | 0.2186 548 | 0.2174 549 | 
0.2163 550 | 0.2151 551 | 0.2139 552 | 0.2128 553 | 0.2116 554 | 0.2105 555 | 0.2093 556 | 0.2082 557 | 0.2071 558 | 0.2059 559 | 0.2048 560 | 0.2037 561 | 0.2026 562 | 0.2014 563 | 0.2003 564 | 0.1992 565 | 0.1981 566 | 0.1970 567 | 0.1959 568 | 0.1948 569 | 0.1937 570 | 0.1926 571 | 0.1915 572 | 0.1905 573 | 0.1894 574 | 0.1883 575 | 0.1872 576 | 0.1862 577 | 0.1851 578 | 0.1841 579 | 0.1830 580 | 0.1820 581 | 0.1809 582 | 0.1799 583 | 0.1788 584 | 0.1778 585 | 0.1768 586 | 0.1757 587 | 0.1747 588 | 0.1737 589 | 0.1727 590 | 0.1717 591 | 0.1707 592 | 0.1696 593 | 0.1686 594 | 0.1677 595 | 0.1667 596 | 0.1657 597 | 0.1647 598 | 0.1637 599 | 0.1627 600 | 0.1618 601 | 0.1608 602 | 0.1598 603 | 0.1589 604 | 0.1579 605 | 0.1569 606 | 0.1560 607 | 0.1550 608 | 0.1541 609 | 0.1532 610 | 0.1522 611 | 0.1513 612 | 0.1504 613 | 0.1494 614 | 0.1485 615 | 0.1476 616 | 0.1467 617 | 0.1458 618 | 0.1449 619 | 0.1440 620 | 0.1431 621 | 0.1422 622 | 0.1413 623 | 0.1404 624 | 0.1395 625 | 0.1386 626 | 0.1378 627 | 0.1369 628 | 0.1360 629 | 0.1352 630 | 0.1343 631 | 0.1334 632 | 0.1326 633 | 0.1317 634 | 0.1309 635 | 0.1301 636 | 0.1292 637 | 0.1284 638 | 0.1276 639 | 0.1267 640 | 0.1259 641 | 0.1251 642 | 0.1243 643 | 0.1235 644 | 0.1227 645 | 0.1219 646 | 0.1211 647 | 0.1203 648 | 0.1195 649 | 0.1187 650 | 0.1179 651 | 0.1171 652 | 0.1163 653 | 0.1155 654 | 0.1148 655 | 0.1140 656 | 0.1132 657 | 0.1125 658 | 0.1117 659 | 0.1110 660 | 0.1102 661 | 0.1095 662 | 0.1087 663 | 0.1080 664 | 0.1073 665 | 0.1065 666 | 0.1058 667 | 0.1051 668 | 0.1044 669 | 0.1036 670 | 0.1029 671 | 0.1022 672 | 0.1015 673 | 0.1008 674 | 0.1001 675 | 0.0994 676 | 0.0987 677 | 0.0980 678 | 0.0973 679 | 0.0967 680 | 0.0960 681 | 0.0953 682 | 0.0946 683 | 0.0940 684 | 0.0933 685 | 0.0926 686 | 0.0920 687 | 0.0913 688 | 0.0907 689 | 0.0900 690 | 0.0894 691 | 0.0887 692 | 0.0881 693 | 0.0875 694 | 0.0868 695 | 0.0862 696 | 0.0856 697 | 0.0850 698 | 0.0844 699 | 0.0837 700 | 0.0831 701 | 0.0825 702 | 0.0819 703 | 0.0813 704 | 0.0807 705 | 0.0801 706 | 0.0795 707 | 0.0789 708 | 0.0784 709 | 0.0778 710 | 0.0772 711 | 0.0766 712 | 0.0761 713 | 0.0755 714 | 0.0749 715 | 0.0744 716 | 0.0738 717 | 0.0732 718 | 0.0727 719 | 0.0721 720 | 0.0716 721 | 0.0711 722 | 0.0705 723 | 0.0700 724 | 0.0694 725 | 0.0689 726 | 0.0684 727 | 0.0679 728 | 0.0673 729 | 0.0668 730 | 0.0663 731 | 0.0658 732 | 0.0653 733 | 0.0648 734 | 0.0643 735 | 0.0638 736 | 0.0633 737 | 0.0628 738 | 0.0623 739 | 0.0618 740 | 0.0613 741 | 0.0608 742 | 0.0604 743 | 0.0599 744 | 0.0594 745 | 0.0589 746 | 0.0585 747 | 0.0580 748 | 0.0575 749 | 0.0571 750 | 0.0566 751 | 0.0562 752 | 0.0557 753 | 0.0553 754 | 0.0548 755 | 0.0544 756 | 0.0539 757 | 0.0535 758 | 0.0531 759 | 0.0526 760 | 0.0522 761 | 0.0518 762 | 0.0514 763 | 0.0509 764 | 0.0505 765 | 0.0501 766 | 0.0497 767 | 0.0493 768 | 0.0489 769 | 0.0485 770 | 0.0481 771 | 0.0477 772 | 0.0473 773 | 0.0469 774 | 0.0465 775 | 0.0461 776 | 0.0457 777 | 0.0453 778 | 0.0450 779 | 0.0446 780 | 0.0442 781 | 0.0438 782 | 0.0435 783 | 0.0431 784 | 0.0427 785 | 0.0424 786 | 0.0420 787 | 0.0416 788 | 0.0413 789 | 0.0409 790 | 0.0406 791 | 0.0402 792 | 0.0399 793 | 0.0395 794 | 0.0392 795 | 0.0389 796 | 0.0385 797 | 0.0382 798 | 0.0379 799 | 0.0375 800 | 0.0372 801 | 0.0369 802 | 0.0365 803 | 0.0362 804 | 0.0359 805 | 0.0356 806 | 0.0353 807 | 0.0350 808 | 0.0347 809 | 0.0343 810 | 0.0340 811 | 0.0337 812 | 0.0334 813 | 0.0331 814 | 0.0328 815 | 0.0325 816 | 0.0323 817 | 0.0320 818 | 0.0317 819 | 0.0314 820 | 0.0311 821 | 0.0308 822 | 
0.0305 823 | 0.0303 824 | 0.0300 825 | 0.0297 826 | 0.0295 827 | 0.0292 828 | 0.0289 829 | 0.0286 830 | 0.0284 831 | 0.0281 832 | 0.0279 833 | 0.0276 834 | 0.0274 835 | 0.0271 836 | 0.0268 837 | 0.0266 838 | 0.0264 839 | 0.0261 840 | 0.0259 841 | 0.0256 842 | 0.0254 843 | 0.0251 844 | 0.0249 845 | 0.0247 846 | 0.0244 847 | 0.0242 848 | 0.0240 849 | 0.0237 850 | 0.0235 851 | 0.0233 852 | 0.0231 853 | 0.0229 854 | 0.0226 855 | 0.0224 856 | 0.0222 857 | 0.0220 858 | 0.0218 859 | 0.0216 860 | 0.0214 861 | 0.0212 862 | 0.0210 863 | 0.0207 864 | 0.0205 865 | 0.0203 866 | 0.0201 867 | 0.0200 868 | 0.0198 869 | 0.0196 870 | 0.0194 871 | 0.0192 872 | 0.0190 873 | 0.0188 874 | 0.0186 875 | 0.0184 876 | 0.0182 877 | 0.0181 878 | 0.0179 879 | 0.0177 880 | 0.0175 881 | 0.0174 882 | 0.0172 883 | 0.0170 884 | 0.0168 885 | 0.0167 886 | 0.0165 887 | 0.0163 888 | 0.0162 889 | 0.0160 890 | 0.0158 891 | 0.0157 892 | 0.0155 893 | 0.0154 894 | 0.0152 895 | 0.0151 896 | 0.0149 897 | 0.0147 898 | 0.0146 899 | 0.0144 900 | 0.0143 901 | 0.0142 902 | 0.0140 903 | 0.0139 904 | 0.0137 905 | 0.0136 906 | 0.0134 907 | 0.0133 908 | 0.0132 909 | 0.0130 910 | 0.0129 911 | 0.0127 912 | 0.0126 913 | 0.0125 914 | 0.0123 915 | 0.0122 916 | 0.0121 917 | 0.0120 918 | 0.0118 919 | 0.0117 920 | 0.0116 921 | 0.0115 922 | 0.0113 923 | 0.0112 924 | 0.0111 925 | 0.0110 926 | 0.0109 927 | 0.0107 928 | 0.0106 929 | 0.0105 930 | 0.0104 931 | 0.0103 932 | 0.0102 933 | 0.0101 934 | 0.0100 935 | 0.0098 936 | 0.0097 937 | 0.0096 938 | 0.0095 939 | 0.0094 940 | 0.0093 941 | 0.0092 942 | 0.0091 943 | 0.0090 944 | 0.0089 945 | 0.0088 946 | 0.0087 947 | 0.0086 948 | 0.0085 949 | 0.0084 950 | 0.0083 951 | 0.0082 952 | 0.0082 953 | 0.0081 954 | 0.0080 955 | 0.0079 956 | 0.0078 957 | 0.0077 958 | 0.0076 959 | 0.0075 960 | 0.0074 961 | 0.0074 962 | 0.0073 963 | 0.0072 964 | 0.0071 965 | 0.0070 966 | 0.0070 967 | 0.0069 968 | 0.0068 969 | 0.0067 970 | 0.0066 971 | 0.0066 972 | 0.0065 973 | 0.0064 974 | 0.0063 975 | 0.0063 976 | 0.0062 977 | 0.0061 978 | 0.0061 979 | 0.0060 980 | 0.0059 981 | 0.0058 982 | 0.0058 983 | 0.0057 984 | 0.0056 985 | 0.0056 986 | 0.0055 987 | 0.0054 988 | 0.0054 989 | 0.0053 990 | 0.0053 991 | 0.0052 992 | 0.0051 993 | 0.0051 994 | 0.0050 995 | 0.0049 996 | 0.0049 997 | 0.0048 998 | 0.0048 999 | 0.0047 1000 | 0.0047
--------------------------------------------------------------------------------
/resource/demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangzhaode/mnn-stable-diffusion/f19782e2e44055eaffe968ece7f3e297a29346b8/resource/demo.jpg
--------------------------------------------------------------------------------
/resource/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangzhaode/mnn-stable-diffusion/f19782e2e44055eaffe968ece7f3e297a29346b8/resource/logo.png
--------------------------------------------------------------------------------
/src/main.cpp:
--------------------------------------------------------------------------------
1 | #include <cstdio>   // header name assumed (angle-bracket include was stripped); needed for printf
2 | #include "pipeline.hpp"
3 |
4 | int main(int argc, const char* argv[]) {
5 |     if (argc < 3) {
6 |         printf("Usage: ./main <sentence> <img_name>\n");   // argument placeholders reconstructed from argv usage below
7 |         return 1;
8 |     }
9 |     auto sentence = argv[1];
10 |     auto img_name = argv[2];
11 |     printf("input sentence: %s\n", sentence);
12 |     printf("output img_name: %s\n", img_name);
13 |     diffusion::Pipeline pipeline("../resource");
14 |     pipeline.run(sentence, img_name);
15 |     return 0;
16 | }
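A note on `resource/alphas.txt` above: `Pipeline::Pipeline` in `src/pipeline.cpp` below reads these 1000 values and `step_plms` indexes them by timestep, i.e. the file holds the cumulative alpha products of the training noise schedule. The sketch below is not part of the repository; it shows one way such a table could be regenerated, assuming the standard Stable Diffusion v1 scaled-linear schedule (beta_start 0.00085, beta_end 0.012, 1000 steps), which is consistent with the endpoints 0.9991 and 0.0047 seen above. Both the schedule parameters and the output formatting are assumptions, so verify against the shipped file before replacing it.

```python
# Illustrative only, not part of this repo: regenerate an alphas table shaped like resource/alphas.txt.
# Assumes the SD v1 "scaled_linear" beta schedule; exact rounding of the shipped file may differ.
import numpy as np

num_train_timesteps = 1000
beta_start, beta_end = 0.00085, 0.012

# "scaled_linear": betas are linear in sqrt space, then squared
betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)  # what step_plms looks up as mAlphas[timestep]

with open("alphas.txt", "w") as f:
    for a in alphas_cumprod:
        f.write("%.4f\n" % a)
```

The first value comes out as 1 - 0.00085 and the last close to 0.0047, which appears to be why the endpoints line up with the file shipped in `resource/`.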
--------------------------------------------------------------------------------
/src/pipeline.cpp:
--------------------------------------------------------------------------------
1 | #include <fstream>   // header names on lines 1-3 and 7-8 were stripped; reconstructed from usage below
2 | #include <chrono>
3 | #include <random>
4 | #include "pipeline.hpp"
5 | #include "tokenizer.hpp"
6 | // #define MNN_OPEN_TIME_TRACE
7 | #include <MNN/AutoTime.hpp>   // assumed: provides AUTOTIME (see runNet)
8 | #include <cv/cv.hpp>          // assumed: MNN OpenCV (cvtColor, imwrite, namespace CV)
9 |
10 | using namespace CV;
11 |
12 | namespace diffusion {
13 |
14 | void display_progress(int cur, int total){
15 |     putchar('\r');
16 |     printf("[");
17 |     for (int i = 0; i < cur; i++) putchar('#');
18 |     for (int i = 0; i < total - cur; i++) putchar('-');
19 |     printf("]");
20 |     fprintf(stdout, " [%3d%%]", cur * 100 / total);
21 |     if (cur == total) putchar('\n');
22 |     fflush(stdout);
23 | }
24 |
25 | Pipeline::Pipeline(std::string modelPath) : mModelPath(modelPath) {
26 |     std::ifstream alphaFile(modelPath + "/alphas.txt");
27 |     int index = 0;
28 |     float alpha;
29 |     while (alphaFile >> alpha) {
30 |         mAlphas.push_back(alpha);
31 |     }
32 |     mTimeSteps = {   // 51 entries: the repeated 961 comes from the PLMS (PNDM) schedule, it is not a typo
33 |         981, 961, 961, 941, 921, 901, 881, 861, 841, 821, 801, 781, 761, 741,
34 |         721, 701, 681, 661, 641, 621, 601, 581, 561, 541, 521, 501, 481, 461,
35 |         441, 421, 401, 381, 361, 341, 321, 301, 281, 261, 241, 221, 201, 181,
36 |         161, 141, 121, 101, 81, 61, 41, 21, 1
37 |     };
38 | }
39 |
40 | void Pipeline::loadNet(std::string modelPath) {
41 |     mNet.reset(Interpreter::createFromFile(modelPath.c_str()));
42 |     ScheduleConfig config;
43 | #if 1
44 |     config.type = MNN_FORWARD_CUDA;
45 |     BackendConfig backendConfig;
46 |     backendConfig.precision = BackendConfig::Precision_Normal;
47 | #else
48 |     config.type = MNN_FORWARD_CPU;
49 |     config.numThread = 12;
50 |     BackendConfig backendConfig;
51 |     backendConfig.precision = BackendConfig::Precision_Normal;
52 | #endif
53 |     config.backendConfig = &backendConfig;
54 |     mSession = mNet->createSession(config);
55 |     mNet->releaseModel();
56 | }
57 |
58 | void Pipeline::runNet() {
59 |     // AUTOTIME;
60 |     auto t1 = std::chrono::high_resolution_clock::now();
61 |     mNet->runSession(mSession);
62 |     auto t2 = std::chrono::high_resolution_clock::now();
63 |     auto time = std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count();   // duration type reconstructed: microseconds, printed as ms below
64 |     printf(" [iter time: %f ms]", time / 1000.0);
65 | }
66 |
67 | std::unique_ptr<Tensor> Pipeline::text_encoder(const std::vector<int>& ids) {
68 |     loadNet(mModelPath + "/text_encoder.mnn");
69 |     auto input = mNet->getSessionInput(mSession, NULL);
70 |     auto output = mNet->getSessionOutput(mSession, NULL);
71 |     std::unique_ptr<Tensor> idsTensor(Tensor::create(input->shape(), input->getType(), const_cast<int*>(ids.data()), input->getDimensionType()));
72 |     input->copyFromHostTensor(idsTensor.get());
73 |     runNet();
74 |     std::unique_ptr<Tensor> text_embeddings(new Tensor(output, output->getDimensionType()));
75 |     output->copyToHostTensor(text_embeddings.get());
76 |     mNet.reset();
77 |     return text_embeddings;
78 | }
79 |
80 | VARP Pipeline::step_plms(VARP sample, VARP model_output, int index) {
81 |     int timestep = mTimeSteps[index];
82 |     int prev_timestep = 0;
83 |     if (index + 1 < mTimeSteps.size()) {
84 |         prev_timestep = mTimeSteps[index + 1];
85 |     }
86 |     if (index != 1) {
87 |         if (mEts.size() >= 4) {
88 |             mEts[mEts.size() - 4] = nullptr;
89 |         }
90 |         mEts.push_back(model_output);
91 |     } else {
92 |         timestep = mTimeSteps[0];
93 |         prev_timestep = mTimeSteps[1];
94 |     }
95 |     int ets = mEts.size() - 1;
96 |     if (index == 0) {
97 |         mSample = sample;
98 |     } else if (index == 1) {
99 |         model_output = (model_output + mEts[ets]) * _Const(0.5);
100 |         sample = mSample;
101 |     } else if (ets == 1) {
102 |         model_output = (_Const(3.0) * mEts[ets] - mEts[ets-1]) * _Const(0.5);
103 |     } else if (ets == 2) {
104 |         model_output = (_Const(23.0) * mEts[ets] - _Const(16.0) * mEts[ets-1] + _Const(5.0) * mEts[ets-2]) * _Const(1.0 / 12.0);
105 |     } else if (ets >= 3) {
106 |         model_output = _Const(1. / 24.) * (_Const(55.0) * mEts[ets] - _Const(59.0) * mEts[ets-1] + _Const(37.0) * mEts[ets-2] - _Const(9.0) * mEts[ets-3]);
107 |     }
108 |     auto alpha_prod_t = mAlphas[timestep];
109 |     auto alpha_prod_t_prev = mAlphas[prev_timestep];
110 |     auto beta_prod_t = 1 - alpha_prod_t;
111 |     auto beta_prod_t_prev = 1 - alpha_prod_t_prev;
112 |     auto sample_coeff = std::sqrt(alpha_prod_t_prev / alpha_prod_t);
113 |     auto model_output_denom_coeff = alpha_prod_t * std::sqrt(beta_prod_t_prev) + std::sqrt(alpha_prod_t * beta_prod_t * alpha_prod_t_prev);
114 |     auto prev_sample = _Scalar(sample_coeff) * sample - _Scalar((alpha_prod_t_prev - alpha_prod_t)/model_output_denom_coeff) * model_output;
115 |     return prev_sample;
116 | }
117 |
118 | VARP Pipeline::unet(std::unique_ptr<Tensor> text_embeddings) {
119 |     loadNet(mModelPath + "/unet.mnn");
120 |     auto sample = mNet->getSessionInput(mSession, "sample");
121 |     auto timestep = mNet->getSessionInput(mSession, "timestep");
122 |     auto encoder_hidden_states = mNet->getSessionInput(mSession, "encoder_hidden_states");
123 |     auto output = mNet->getSessionOutput(mSession, NULL);
124 |
125 |     std::unique_ptr<Tensor> latents(new Tensor(sample, Tensor::CAFFE));
126 |     std::unique_ptr<Tensor> timestepVal(new Tensor(timestep, timestep->getDimensionType()));
127 |     std::unique_ptr<Tensor> pred(new Tensor(output, output->getDimensionType()));
128 |
129 |     std::mt19937 rng;
130 |     rng.seed(std::random_device()());
131 |     std::normal_distribution<float> normal(0, 1);
132 |     std::vector<float> initVal(16384);
133 |     for (int i = 0; i < 16384; i++) {
134 |         initVal[i] = normal(rng);
135 |     }
136 |     VARP latentvar = _Const(initVal.data(), {1, 4, 64, 64}, NCHW);
137 |     int zero = 0, one = 1;
138 |     for (int i = 0; i < mTimeSteps.size(); i++) {
139 |         display_progress(i, 50);
140 |         memcpy(latents->host<uint8_t>(), latentvar->readMap<float>(), 65536);          // template args reconstructed; byte pointer assumed so +65536 below is a byte offset
141 |         memcpy(latents->host<uint8_t>() + 65536, latentvar->readMap<float>(), 65536);  // copies the same latent into the second batch entry of the {2, 4, 64, 64} sample
142 |         timestepVal->host<int>()[0] = mTimeSteps[i];                                   // int32 timestep assumed; adjust if the exported dtype differs
143 |         sample->copyFromHostTensor(latents.get());
144 |         timestep->copyFromHostTensor(timestepVal.get());
145 |         encoder_hidden_states->copyFromHostTensor(text_embeddings.get());
146 |         runNet();
147 |         output->copyToHostTensor(pred.get());
148 |         auto noise_pred = Variable::create(Expr::create(pred.get(), false));
149 |         auto noise_pred_uncond = _Gather(noise_pred, _Const(&zero, {1}, NHWC, halide_type_of<int>()));
150 |         auto noise_pred_text = _Gather(noise_pred, _Const(&one, {1}, NHWC, halide_type_of<int>()));
151 |         noise_pred = _Const(7.5) * (noise_pred_text - noise_pred_uncond) + noise_pred_uncond;   // classifier-free guidance, scale 7.5
152 |         latentvar = step_plms(latentvar, noise_pred, i);
153 |     }
154 |     latentvar.fix(VARP::CONSTANT);
155 |     return latentvar;
156 | }
157 |
158 | VARP Pipeline::vae_decoder(VARP latent) {
159 |     latent = latent * _Const(1 / 0.18215);
160 |     loadNet(mModelPath + "/vae_decoder.mnn");
161 |     auto input = mNet->getSessionInput(mSession, NULL);
162 |     auto output = mNet->getSessionOutput(mSession, NULL);
163 |     std::unique_ptr<Tensor> latentTensor(Tensor::create(input->shape(), input->getType(), const_cast<float*>(latent->readMap<float>()), input->getDimensionType()));
164 |     input->copyFromHostTensor(latentTensor.get());
165 |     runNet();
166 |     std::unique_ptr<Tensor> sampleTensor(new Tensor(output, output->getDimensionType()));
167 |     output->copyToHostTensor(sampleTensor.get());
168 |     auto image = Variable::create(Expr::create(sampleTensor.get(), false));
169 |     image = _Relu6(image * _Const(0.5) + _Const(0.5), 0, 1);
170 |     image = _Squeeze(_Transpose(image, {0, 2, 3, 1}));
171 |     image = _Cast(_Round(image * _Const(255.0)), halide_type_of<uint8_t>());
172 |     image = cvtColor(image, COLOR_BGR2RGB);
173 |     image.fix(VARP::CONSTANT);
174 |     return image;
175 | }
176 |
177 | bool Pipeline::run(const std::string& sentence, const std::string& img_name) {
178 |     diffusion::tokenizer tok(mModelPath + "/vocab.txt");
179 |     auto ids = tok.sentence(sentence, 512);
180 |     auto text_embeddings = text_encoder(ids);
181 |     auto latent = unet(std::move(text_embeddings));
182 |     auto image = vae_decoder(latent);
183 |     bool res = imwrite(img_name, image);
184 |     if (res) printf("SUCCESS! write to %s\n", img_name.c_str());
185 |     return res;
186 | }
187 |
188 | }
189 |
--------------------------------------------------------------------------------
/src/tokenizer.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include <fstream>   // header names on lines 2-4 were stripped; reconstructed from usage (std::ifstream, std::string, std::vector)
3 | #include <string>
4 | #include <vector>
5 | #include "tokenizer.hpp"
6 |
7 | namespace diffusion {
8 |
9 | tokenizer::tokenizer(std::string dictPath) {
10 |     std::ifstream dictFile(dictPath);
11 |     int index = 0;
12 |     std::string word;
13 |     while (dictFile >> word) {
14 |         mWordDict.insert(std::make_pair(std::move(word), index++));
15 |     }
16 |     mStartIdx = this->word("[CLS]");
17 |     mEndIdx = this->word("[SEP]");
18 | }
19 |
20 | int tokenizer::word(std::string word) {
21 |     const auto& iter = mWordDict.find(word);
22 |     if (iter != mWordDict.end()) {
23 |         return iter->second;
24 |     }
25 |     return -1;   // unknown token
26 | }
27 |
28 | std::vector<int> tokenizer::sentence(std::string sentence, int maxlen) {
29 |     std::vector<int> ids(maxlen * 2, 0);
30 |     // uncond: [CLS][SEP] followed by zero padding in the first half
31 |     ids[0] = mStartIdx;
32 |     ids[1] = mEndIdx;
33 |     // ids of the actual prompt fill the second half
34 |     int idx = maxlen;
35 |     ids[idx++] = mStartIdx;
36 |     for (size_t i = 0; i < sentence.length();) {
37 |         int wordlen = 1;   // UTF-8: the leading byte determines the character length
38 |         if ((sentence[i] & 0xf8) == 0xf0) {
39 |             wordlen = 4;
40 |         } else if ((sentence[i] & 0xf0) == 0xe0) {
41 |             wordlen = 3;
42 |         } else if ((sentence[i] & 0xe0) == 0xc0) {
43 |             wordlen = 2;
44 |         }
45 |         if ((i + wordlen) > sentence.length()) {
46 |             wordlen = 1;
47 |         }
48 |         std::string word = sentence.substr(i, wordlen);
49 |         ids[idx++] = this->word(word);
50 |         i += wordlen;
51 |     }
52 |     ids[idx++] = mEndIdx;
53 |     return ids;
54 | }
55 |
56 | } // diffusion
--------------------------------------------------------------------------------