├── cpp ├── libs │ └── .gitkeep ├── include │ └── .gitkeep ├── CMakeLists.txt ├── README.md └── sam_demo.cpp ├── resource ├── logo.png ├── res.jpg └── truck.jpg ├── python ├── README.md └── segment_anything_example.py └── README.md /cpp/libs/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cpp/include/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /resource/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzhaode/mnn-segment-anything/HEAD/resource/logo.png -------------------------------------------------------------------------------- /resource/res.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzhaode/mnn-segment-anything/HEAD/resource/res.jpg -------------------------------------------------------------------------------- /resource/truck.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzhaode/mnn-segment-anything/HEAD/resource/truck.jpg -------------------------------------------------------------------------------- /python/README.md: -------------------------------------------------------------------------------- 1 | # Usage 2 | 3 | ## Install MNN 4 | ``` 5 | pip install MNN 6 | ``` 7 | 8 | ## Run Demo 9 | ``` 10 | python segment_anything_example.py --embed embed.mnn --sam segment.mnn --img ../resource/truck.jpg 11 | # for the edge model, add `--edge` 12 | python segment_anything_example.py --embed edge_embed.mnn --sam edge_segment.mnn --img ../resource/truck.jpg --edge 13 | ``` -------------------------------------------------------------------------------- /cpp/CMakeLists.txt: 
-------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10) 2 | project(sam) 3 | 4 | set(CMAKE_CXX_STANDARD 11) 5 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 6 | 7 | # include dir 8 | include_directories(${CMAKE_CURRENT_LIST_DIR}/include/) 9 | 10 | # libs dir 11 | link_directories(${CMAKE_CURRENT_LIST_DIR}/libs) 12 | 13 | # source files 14 | FILE(GLOB SRCS ${CMAKE_CURRENT_LIST_DIR}/*.cpp) 15 | 16 | # target 17 | add_executable(sam_demo ${SRCS}) 18 | 19 | # link 20 | if (MSVC) 21 | target_link_libraries(sam_demo MNN) 22 | else() 23 | target_link_libraries(sam_demo MNN MNN_Express MNNOpenCV) 24 | # liblog is Android-only; linking it on desktop Linux/macOS fails 25 | if (ANDROID) 26 | target_link_libraries(sam_demo log) 27 | endif() 28 | endif() 29 | -------------------------------------------------------------------------------- /cpp/README.md: -------------------------------------------------------------------------------- 1 | # Usage 2 | 3 | ## Compile MNN library 4 | ### Linux/Mac 5 | ```bash 6 | git clone https://github.com/alibaba/MNN.git 7 | # copy header files 8 | cp -r MNN/include . 9 | cp -r MNN/tools/cv/include . 10 | cd MNN 11 | mkdir build && cd build 12 | cmake -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON .. 13 | make -j8 14 | cd ../.. 15 | cp MNN/build/libMNN.so MNN/build/express/libMNN_Express.so MNN/build/tools/cv/libMNNOpenCV.so ./libs 16 | ``` 17 | 18 | ### Windows 19 | ```bash 20 | # Visual Studio xxxx Developer Command Prompt 21 | powershell 22 | git clone https://github.com/alibaba/MNN.git 23 | # copy header files 24 | cp -r MNN/include . 25 | cp -r MNN/tools/cv/include . 26 | cd MNN 27 | mkdir build && cd build 28 | cmake -G "Ninja" -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON .. 29 | ninja 30 | cd ../.. 31 | cp MNN/build/MNN.dll MNN/build/MNN.lib ./libs 32 | ``` 33 | 34 | ## Build and Run 35 | 36 | #### Linux/Mac 37 | ```bash 38 | mkdir build && cd build 39 | cmake .. 
40 | make -j4 41 | ./sam_demo embed.mnn segment.mnn ../../resource/truck.jpg 42 | # for the edge model, append `1` 43 | ./sam_demo edge_embed.mnn edge_segment.mnn ../../resource/truck.jpg 1 44 | ``` 45 | #### Windows 46 | ```bash 47 | # Visual Studio xxxx Developer Command Prompt 48 | powershell 49 | mkdir build && cd build 50 | cmake -G "Ninja" .. 51 | ninja 52 | ./sam_demo embed.mnn segment.mnn ../../resource/truck.jpg 53 | # for the edge model, append `1` 54 | ./sam_demo edge_embed.mnn edge_segment.mnn ../../resource/truck.jpg 1 55 | ``` 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![mnn-segment-anything](resource/logo.png) 2 | 3 | # mnn-segment-anything 4 | 5 | ## 说明 6 | `mnn-segment-anything`支持Python与C++在Linux, macOS, Windows, Android与iOS上运行。目前支持以下模型: 7 | - [segment-anything](https://github.com/facebookresearch/segment-anything)的`vit_b`, `vit_l`, `vit_h` 8 | - [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) 9 | - [EdgeSAM](https://github.com/chongzhou96/EdgeSAM) 10 | 11 | ## 模型文件 12 | 13 | - 端侧部署建议使用模型: 14 | - edge_sam 15 | - [edge_embed.mnn](https://github.com/wangzhaode/mnn-segment-anything/releases/download/edge_mnn/edge_embed.mnn): `21 MB` 16 | - [edge_segment.mnn](https://github.com/wangzhaode/mnn-segment-anything/releases/download/edge_mnn/edge_segment.mnn): `19.3 MB` 17 | - mobile_sam 18 | - [mobile_embed.mnn](https://github.com/wangzhaode/mnn-segment-anything/releases/download/mobile_mnn/mobile_embed.mnn): `26.7 MB` 19 | - [mobile_segment.mnn](https://github.com/wangzhaode/mnn-segment-anything/releases/download/mobile_mnn/mobile_segment.mnn): `19.7 MB` 20 | 21 | | model | onnx | mnn | 22 | |:---------:|:------:|:------:| 23 | | edge_sam | [![Download][download-e-onnx]][release-e-onnx] | [![Download][download-e-mnn]][release-e-mnn] | 24 | | mobile_sam | [![Download][download-m-onnx]][release-m-onnx] | 
[![Download][download-m-mnn]][release-m-mnn] | 25 | | sam_vit_b | [![Download][download-b-onnx]][release-b-onnx] | [![Download][download-b-mnn]][release-b-mnn] | 26 | | sam_vit_l | [![Download][download-l-onnx]][release-l-onnx] | [![Download][download-l-mnn]][release-l-mnn] | 27 | | sam_vit_h | [![Download][download-h-onnx]][release-h-onnx] | [![Download][download-h-mnn]][release-h-mnn] | 28 | 29 | [download-e-onnx]: https://img.shields.io/github/downloads/wangzhaode/mnn-segment-anything/edge_onnx/total 30 | [download-m-onnx]: https://img.shields.io/github/downloads/wangzhaode/mnn-segment-anything/mobile_onnx/total 31 | [download-b-onnx]: https://img.shields.io/github/downloads/wangzhaode/mnn-segment-anything/vit_b_onnx/total 32 | [download-l-onnx]: https://img.shields.io/github/downloads/wangzhaode/mnn-segment-anything/vit_l_onnx/total 33 | [download-h-onnx]: https://img.shields.io/github/downloads/wangzhaode/mnn-segment-anything/vit_h_onnx/total 34 | 35 | [download-e-mnn]: https://img.shields.io/github/downloads/wangzhaode/mnn-segment-anything/edge_mnn/total 36 | [download-m-mnn]: https://img.shields.io/github/downloads/wangzhaode/mnn-segment-anything/mobile_mnn/total 37 | [download-b-mnn]: https://img.shields.io/github/downloads/wangzhaode/mnn-segment-anything/vit_b_mnn/total 38 | [download-l-mnn]: https://img.shields.io/github/downloads/wangzhaode/mnn-segment-anything/vit_l_mnn/total 39 | [download-h-mnn]: https://img.shields.io/github/downloads/wangzhaode/mnn-segment-anything/vit_h_mnn/total 40 | 41 | [release-e-onnx]: https://github.com/wangzhaode/mnn-segment-anything/releases/tag/edge_onnx 42 | [release-m-onnx]: https://github.com/wangzhaode/mnn-segment-anything/releases/tag/mobile_onnx 43 | [release-b-onnx]: https://github.com/wangzhaode/mnn-segment-anything/releases/tag/vit_b_onnx 44 | [release-l-onnx]: https://github.com/wangzhaode/mnn-segment-anything/releases/tag/vit_l_onnx 45 | [release-h-onnx]: 
https://github.com/wangzhaode/mnn-segment-anything/releases/tag/vit_h_onnx 46 | 47 | [release-e-mnn]: https://github.com/wangzhaode/mnn-segment-anything/releases/tag/edge_mnn 48 | [release-m-mnn]: https://github.com/wangzhaode/mnn-segment-anything/releases/tag/mobile_mnn 49 | [release-b-mnn]: https://github.com/wangzhaode/mnn-segment-anything/releases/tag/vit_b_mnn 50 | [release-l-mnn]: https://github.com/wangzhaode/mnn-segment-anything/releases/tag/vit_l_mnn 51 | [release-h-mnn]: https://github.com/wangzhaode/mnn-segment-anything/releases/tag/vit_h_mnn 52 | 53 | ## 示例代码 54 | - [Python](./python/) 55 | - [C++](./cpp) 56 | 57 | ## 示例输出 58 | ![res](resource/res.jpg) 59 | -------------------------------------------------------------------------------- /python/segment_anything_example.py: -------------------------------------------------------------------------------- 1 | #-- coding:utf8 -- 2 | import argparse 3 | import time 4 | import MNN 5 | import MNN.numpy as np 6 | import MNN.cv as cv2 7 | 8 | def inference(emed, sam, img, precision, backend, thread, is_edge): 9 | mask_threshold = 0.0 10 | # 0. load model 11 | config = {} 12 | config['precision'] = precision 13 | config['backend'] = backend 14 | config['numThread'] = thread 15 | rt = MNN.nn.create_runtime_manager((config,)) 16 | embed = MNN.nn.load_module_from_file(emed, [], [], runtime_manager=rt) 17 | if is_edge: 18 | sam_inputs = ['point_coords', 'point_labels', 'image_embeddings'] 19 | sam_outputs = ['masks', 'scores'] 20 | else: 21 | sam_inputs = ['point_coords', 'point_labels', 'image_embeddings', 'has_mask_input', 'mask_input', 'orig_im_size'] 22 | sam_outputs = ['iou_predictions', 'low_res_masks', 'masks'] 23 | sam = MNN.nn.load_module_from_file(sam, sam_inputs, sam_outputs, runtime_manager=rt) 24 | # 1. 
preprocess 25 | image = cv2.imread(img) 26 | origin_h, origin_w, _ = image.shape 27 | length = 1024 28 | if origin_h > origin_w: 29 | new_w = round(origin_w * float(length) / origin_h) 30 | new_h = length 31 | else: 32 | new_h = round(origin_h * float(length) / origin_w) 33 | new_w = length 34 | scale_w = new_w / origin_w 35 | sclae_h = new_h / origin_h 36 | input_var = cv2.resize(image, (new_w, new_h), 0., 0., cv2.INTER_LINEAR, -1, [123.675, 116.28, 103.53], [1/58.395, 1/57.12, 1/57.375]) 37 | input_var = np.pad(input_var, [[0, length - new_h], [0, length - new_w], [0, 0]], 'constant') 38 | input_var = np.expand_dims(input_var, 0) 39 | # 2. embedding forward 40 | input_var = MNN.expr.convert(input_var, MNN.expr.NC4HW4) 41 | t1 = time.time() 42 | output_var = embed.forward(input_var) 43 | t2 = time.time() 44 | print('# 1. embedding times: {} ms'.format((t2 - t1) * 1000)) 45 | image_embedding = MNN.expr.convert(output_var, MNN.expr.NCHW) 46 | # 3. segment forward 47 | points = [[500, 375]] 48 | sclaes = [scale_w, sclae_h] 49 | input_point = np.array(points) * sclaes 50 | input_label = np.array([1]) 51 | point_coords = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :] 52 | point_labels = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32) 53 | if is_edge: 54 | input_vars = [point_coords, point_labels, image_embedding] 55 | else: 56 | orig_im_size = np.array([float(origin_h), float(origin_w)], dtype=np.float32) 57 | mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32) 58 | has_mask_input = np.zeros(1, dtype=np.float32) 59 | input_vars = [point_coords, point_labels, image_embedding, has_mask_input, mask_input, orig_im_size] 60 | t1 = time.time() 61 | output_vars = sam.onForward(input_vars) 62 | t2 = time.time() 63 | print('# 1. segment times: {} ms'.format((t2 - t1) * 1000)) 64 | # 4. 
postprocess: draw masks and point 65 | if is_edge: 66 | low_res_masks = output_vars[0] 67 | low_res_masks = MNN.expr.convert(low_res_masks, MNN.expr.NC4HW4) 68 | print(low_res_masks.data_format) 69 | _, _, h, w, = low_res_masks.shape 70 | masks = MNN.expr.resize(low_res_masks, length / w, length / h) 71 | masks = masks[:, :, :new_h, :new_w] 72 | masks = MNN.expr.resize(masks, origin_w / new_w, origin_h / new_h) 73 | else: 74 | masks = output_vars[2] 75 | masks = MNN.expr.convert(masks, MNN.expr.NCHW).squeeze([0])[0] 76 | masks = (masks > mask_threshold).reshape([origin_h, origin_w, 1]) 77 | color = np.array([30, 144, 255]).reshape([1, 1, -1]) 78 | image = (image + masks * color).astype(np.uint8) 79 | for point in points: 80 | cv2.circle(image, point, 10, (0, 0, 255), 5) 81 | cv2.imwrite('res.jpg', image) 82 | 83 | if __name__ == "__main__": 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument('--embed', type=str, required=True, help='the embedding model path') 86 | parser.add_argument('--sam', type=str, required=True, help='the sam model path') 87 | parser.add_argument('--img', type=str, required=True, help='the input image path') 88 | parser.add_argument('--precision', type=str, default='normal', help='inference precision: normal, low, high, lowBF') 89 | parser.add_argument('--backend', type=str, default='CPU', help='inference backend: CPU, OPENCL, OPENGL, NN, VULKAN, METAL, TRT, CUDA, HIAI') 90 | parser.add_argument('--thread', type=int, default=4, help='inference using thread: int') 91 | parser.add_argument('--edge', action='store_true', help='using edge sam model.') 92 | args = parser.parse_args() 93 | inference(args.embed, args.sam, args.img, args.precision, args.backend, args.thread, args.edge) -------------------------------------------------------------------------------- /cpp/sam_demo.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 
#include 7 | #include 8 | 9 | #include 10 | 11 | using namespace MNN; 12 | using namespace MNN::Express; 13 | using namespace MNN::CV; 14 | 15 | int main(int argc, const char* argv[]) { 16 | if (argc < 4) { 17 | MNN_PRINT("Usage: ./sam_demo.out embed.mnn sam.mnn input.jpg [is_edge] [forwardType] [precision] [thread]\n"); 18 | return 0; 19 | } 20 | bool is_edge = false; 21 | int thread = 4; 22 | int precision = 0; 23 | int forwardType = MNN_FORWARD_CPU; 24 | if (argc >= 5) { 25 | is_edge = atoi(argv[4]); 26 | } 27 | if (argc >= 6) { 28 | forwardType = atoi(argv[5]); 29 | } 30 | if (argc >= 7) { 31 | precision = atoi(argv[6]); 32 | } 33 | if (argc >= 8) { 34 | thread = atoi(argv[7]); 35 | } 36 | float mask_threshold = 0; 37 | MNN::ScheduleConfig sConfig; 38 | sConfig.type = static_cast(forwardType); 39 | sConfig.numThread = thread; 40 | BackendConfig bConfig; 41 | bConfig.precision = static_cast(precision); 42 | sConfig.backendConfig = &bConfig; 43 | std::shared_ptr rtmgr = std::shared_ptr(Executor::RuntimeManager::createRuntimeManager(sConfig)); 44 | if(rtmgr == nullptr) { 45 | MNN_ERROR("Empty RuntimeManger\n"); 46 | return 0; 47 | } 48 | // rtmgr->setCache(".cachefile"); 49 | std::shared_ptr embed(Module::load(std::vector{}, std::vector{}, argv[1], rtmgr)); 50 | std::vector sam_inputs = {"point_coords", "point_labels", "image_embeddings", "has_mask_input", "mask_input", "orig_im_size"}; 51 | std::vector sam_outputs = {"iou_predictions", "low_res_masks", "masks"}; 52 | if (is_edge) { 53 | sam_inputs = {"point_coords", "point_labels", "image_embeddings"}; 54 | sam_outputs = {"masks", "scores"}; 55 | } 56 | std::shared_ptr sam(Module::load(sam_inputs, sam_outputs, argv[2], rtmgr)); 57 | auto image = imread(argv[3]); 58 | // 1. 
preprocess 59 | auto dims = image->getInfo()->dim; 60 | int origin_h = dims[0]; 61 | int origin_w = dims[1]; 62 | int length = 1024; 63 | int new_h, new_w; 64 | if (origin_h > origin_w) { 65 | new_w = round(origin_w * (float)length / origin_h); 66 | new_h = length; 67 | } else { 68 | new_h = round(origin_h * (float)length / origin_w); 69 | new_w = length; 70 | } 71 | float scale_w = (float)new_w / origin_w; 72 | float scale_h = (float)new_h / origin_h; 73 | auto input_var = resize(image, Size(new_w, new_h), 0, 0, INTER_LINEAR, -1, {123.675, 116.28, 103.53}, {1/58.395, 1/57.12, 1/57.375}); 74 | std::vector padvals { 0, length - new_h, 0, length - new_w, 0, 0 }; 75 | auto pads = _Const(static_cast(padvals.data()), {3, 2}, NCHW, halide_type_of()); 76 | input_var = _Pad(input_var, pads, CONSTANT); 77 | input_var = _Unsqueeze(input_var, {0}); 78 | // 2. image embedding 79 | input_var = _Convert(input_var, NC4HW4); 80 | auto st = std::chrono::system_clock::now(); 81 | auto outputs = embed->onForward({input_var}); 82 | auto et = std::chrono::system_clock::now(); 83 | auto duration = std::chrono::duration_cast(et - st); 84 | printf("# 1. embedding times: %f ms\n", duration.count() * 1e-3); 85 | 86 | auto image_embedding = _Convert(outputs[0], NCHW); 87 | 88 | // 3. 
segment 89 | auto build_input = [](std::vector data, std::vector shape) { 90 | return _Const(static_cast(data.data()), shape, NCHW, halide_type_of()); 91 | }; 92 | // build inputs 93 | std::vector points = {500, 375}; 94 | auto scale_points = points; 95 | for (int i = 0; i < scale_points.size() / 2; i++) { 96 | scale_points[2 * i] = scale_points[2 * i] * scale_w; 97 | scale_points[2 * i + 1] = scale_points[2 * i + 1] * scale_h; 98 | } 99 | scale_points.push_back(0); 100 | scale_points.push_back(0); 101 | auto point_coords = build_input(scale_points, {1, 2, 2}); 102 | auto point_labels = build_input({1, -1}, {1, 2}); 103 | std::vector input_vars; 104 | if (is_edge) { 105 | input_vars = {point_coords, point_labels, image_embedding}; 106 | } else { 107 | auto orig_im_size = build_input({static_cast(origin_h), static_cast(origin_w)}, {2}); 108 | auto has_mask_input = build_input({0}, {1}); 109 | std::vector zeros(256*256, 0.f); 110 | auto mask_input = build_input(zeros, {1, 1, 256, 256}); 111 | input_vars = {point_coords, point_labels, image_embedding, has_mask_input, mask_input, orig_im_size}; 112 | } 113 | st = std::chrono::system_clock::now(); 114 | auto output_vars = sam->onForward(input_vars); 115 | et = std::chrono::system_clock::now(); 116 | duration = std::chrono::duration_cast(et - st); 117 | printf("# 2. segment times: %f ms\n", duration.count() * 1e-3); 118 | // 4. 
postprocess: draw mask and point 119 | // MobileSam has multi channel masks, get first 120 | VARP masks; 121 | if (is_edge) { 122 | masks = output_vars[0]; 123 | auto dims = masks->getInfo()->dim; 124 | int h = dims[2], w = dims[3]; 125 | masks = _Convert(masks, NC4HW4); 126 | masks = _Resize(masks, length/w, length/h); 127 | int sliceStartData[] = {0, 0, 0, 0}, sliceEndData[] = {-1, -1, new_h, new_w}; 128 | masks = _Slice(masks, _Const(sliceStartData, {4}, NCHW), _Const(sliceEndData, {4}, NCHW)); 129 | masks = _Resize(masks, (float)origin_w/new_w, (float)origin_h/new_h); 130 | } else { 131 | masks = output_vars[2]; 132 | } 133 | masks = _Convert(masks, NCHW); 134 | masks = _Gather(_Squeeze(masks, {0}), _Scalar(0)); 135 | masks = _Greater(masks, _Scalar(mask_threshold)); 136 | masks = _Reshape(masks, {origin_h, origin_w, 1}); 137 | std::vector color_vec {30, 144, 255}; 138 | auto color = _Const(static_cast(color_vec.data()), {1, 1, 3}, NCHW, halide_type_of()); 139 | image = _Cast(_Cast(image) + masks * color); 140 | auto ptr = image->readMap(); 141 | for (int i = 0; i < points.size() / 2; i++) { 142 | float x = points[2 * i]; 143 | float y = points[2 * i + 1]; 144 | circle(image, {x, y}, 10, {0, 0, 255}, 5); 145 | } 146 | if (imwrite("res.jpg", image)) { 147 | MNN_PRINT("result image write to `res.jpg`.\n"); 148 | } 149 | // rtmgr->updateCache(); 150 | return 0; 151 | } 152 | --------------------------------------------------------------------------------