├── about.png
├── installation.png
├── README_zh.md
├── README.md
└── scripts
    └── segment_anything.py

/about.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Erickrus/stable-diffusion-webui-segment-anything/HEAD/about.png
--------------------------------------------------------------------------------
/installation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Erickrus/stable-diffusion-webui-segment-anything/HEAD/installation.png
--------------------------------------------------------------------------------
/README_zh.md:
--------------------------------------------------------------------------------
# stable-diffusion-webui Segment Anything extension

## Installation

### Install dependencies
```shell
pip3 install opencv-python matplotlib onnx onnxruntime
pip3 install 'git+https://github.com/facebookresearch/segment-anything.git'
```

Make sure pytorch>=2.0.0 is installed.

### SAM weights

Automatic weight download currently only works on Linux/Mac, because it relies on `wget`. On Windows, download the weights manually and place them at `models/sam/sam_vit_h_4b8939.pth`:

https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

### Install the extension

Go to the `Extensions` tab, click `Install from URL`, enter `https://github.com/Erickrus/stable-diffusion-webui-segment-anything`, and click `Install`.

Please restart the webui after the extension is installed.

## Usage

Upload your image to the image panel on the left, draw/click tiny dots on it, then click the Segment button.

Please keep the brush radius around **5**. A larger radius interferes with locating the segmentation points and degrades the final result.

![](about.png)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# An extension for stable-diffusion-webui that segments image elements

[中文版说明](https://github.com/Erickrus/stable-diffusion-webui-segment-anything/blob/main/README_zh.md)

## Installation

### Install dependencies

```shell
pip3 install opencv-python matplotlib onnx onnxruntime
pip3 install 'git+https://github.com/facebookresearch/segment-anything.git'
```

Make sure pytorch>=2.0.0 is installed.

### SAM weights

Automatic weight download currently only works on Linux/Mac, because it relies on `wget`. On Windows, download the weights manually and place them at `models/sam/sam_vit_h_4b8939.pth`:

https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

### Install extension from webui

Go to the `Extensions` tab, click `Install from URL`, enter `https://github.com/Erickrus/stable-diffusion-webui-segment-anything`, and click `Install`.

Please restart the webui after the extension is installed.

## How to use

Upload the image to segment, draw/click tiny point(s) on the left-side image, and click the Segment button.

Please make sure the brush radius is around **5**. Don't use a large radius: the script finds prompt points by looking for small brush marks, so large strokes won't be detected and the segmentation will fail.
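
Tip: if the automatic weight download fails (it shells out to `wget`, which may be missing), the checkpoint from the **SAM weights** section above can be fetched manually. A minimal shell sketch, run from the webui root directory, using the same path the extension expects:

```shell
mkdir -p models/sam
wget -O models/sam/sam_vit_h_4b8939.pth \
  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
```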

![](about.png)
--------------------------------------------------------------------------------
/scripts/segment_anything.py:
--------------------------------------------------------------------------------
import os

# installation:
# pip3 install opencv-python matplotlib onnx onnxruntime
# pip3 install 'git+https://github.com/facebookresearch/segment-anything.git'
# wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

import cv2
import numpy as np
import torch

from PIL import Image

from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.onnx import SamOnnxModel

import onnxruntime
from onnxruntime.quantization import QuantType
from onnxruntime.quantization.quantize import quantize_dynamic

import gradio as gr

from modules import script_callbacks

# Download the SAM checkpoint on first use (requires wget; see the README
# for manual download instructions on Windows).
model_dir = "models/sam/"
checkpoint = model_dir + "sam_vit_h_4b8939.pth"
model_type = "vit_h"
if not os.path.exists(checkpoint):
    os.makedirs(model_dir, exist_ok=True)
    os.system("wget -O %s https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" % checkpoint)

sam = sam_model_registry[model_type](checkpoint=checkpoint)


def extract_onnx():
    """Export the SAM mask decoder to ONNX, then write a quantized copy."""
    import warnings
    onnx_model_path = model_dir + "sam_onnx_example.onnx"
    onnx_model = SamOnnxModel(sam, return_single_mask=True)

    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    mask_input_size = [4 * x for x in embed_size]
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }
    output_names = ["masks", "iou_predictions", "low_res_masks"]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(onnx_model_path, "wb") as f:
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=17,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )
    # Quantize the exported decoder to shrink it and speed up CPU inference.
    onnx_model_quantized_path = model_dir + "sam_onnx_quantized_example.onnx"
    quantize_dynamic(
        model_input=onnx_model_path,
        model_output=onnx_model_quantized_path,
        optimize_model=True,
        per_channel=False,
        reduce_range=False,
        weight_type=QuantType.QUInt8,
    )


onnx_model_quantized_path = model_dir + "sam_onnx_quantized_example.onnx"
if not os.path.exists(onnx_model_quantized_path):
    extract_onnx()
ort_session = onnxruntime.InferenceSession(onnx_model_quantized_path)

# Fall back to CPU when no CUDA device is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam.to(device=device)
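# SamPredictor wraps the full SAM model: set_image() runs the heavyweight
# image encoder once per uploaded image and caches the embedding, while the
# quantized ONNX decoder above turns each set of point prompts into a mask.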
predictor = SamPredictor(sam)


def find_input_points(mask):
    """Locate the tiny brush marks the user drew and return their centers."""
    input_points = []
    gray = 255 - np.array(mask)
    thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]

    cnts = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = cnts[0] if len(cnts) == 2 else cnts[1]
    for c in cnts:
        # Only small blobs count as prompt points; large strokes are ignored,
        # which is why the brush radius should stay around 5.
        area = cv2.contourArea(c)
        if area < 100:
            x, y = c[:, :, 0].mean(), c[:, :, 1].mean()
            input_points.append([int(x), int(y)])
    return input_points


def segment_anything(im):
    img = im['image'].convert("RGB")
    mask = im['mask'].convert("L")
    input_points = find_input_points(mask)
    if len(input_points) == 0:
        # No prompt points were drawn; return the input unchanged.
        return img

    image = np.array(img)

    predictor.set_image(image)
    image_embedding = predictor.get_image_embedding().cpu().numpy()

    input_point = np.array(input_points)
    input_label = np.ones(len(input_point)).astype(int)

    # Append the padding point/label expected by the exported ONNX decoder.
    onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
    onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)

    onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)

    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    onnx_has_mask_input = np.zeros(1, dtype=np.float32)

    ort_inputs = {
        "image_embeddings": image_embedding,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
        "orig_im_size": np.array(image.shape[:2], dtype=np.float32),
    }

    masks, _, _ = ort_session.run(None, ort_inputs)
    masks = masks > predictor.model.mask_threshold

    masks = np.squeeze(np.squeeze(masks, 0), 0)
    masks = masks.astype(np.uint8)
    masks = np.repeat(masks[..., None], 3, axis=2)
    # Composite the selected region onto a white background.
    white = np.full_like(image, 255)
    image = image * masks + white * (1 - masks)

    return Image.fromarray(image)


def add_tab():
    with gr.Blocks(analytics_enabled=False) as ui:
        with gr.Tabs() as tabs:
            with gr.Blocks():
                with gr.Row():
                    im = gr.Image(type="pil", tool="sketch", source='upload', brush_radius=5, label="input")
                    output_im = gr.Image(type="pil", label="output")
                with gr.Tab("Segment Anything", id="input_image"):
                    segment_anything_btn = gr.Button(value="Segment", variant="primary")

                segment_anything_btn.click(
                    fn=segment_anything,
                    inputs=[im],
                    outputs=[output_im],
                )

    return [(ui, "Segment Anything", "segment_anything")]


script_callbacks.on_ui_tabs(add_tab)
--------------------------------------------------------------------------------
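
For quick testing outside the webui UI, the segmentation path can be exercised directly. A hypothetical smoke test — the file names are placeholders, and importing the script still requires the webui's `modules` package on the import path:

```python
# Hypothetical standalone check of segment_anything(); "photo.png" and
# "dots.png" are placeholder files. "dots.png" stands in for gradio's sketch
# layer: small white marks on a black background, one per object point.
from PIL import Image

im = {
    "image": Image.open("photo.png"),
    "mask": Image.open("dots.png"),
}
result = segment_anything(im)  # returns the selection on a white background
result.save("segmented.png")
```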