├── about.png
├── installation.png
├── README_zh.md
├── README.md
└── scripts
└── segment_anything.py
/about.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Erickrus/stable-diffusion-webui-segment-anything/HEAD/about.png
--------------------------------------------------------------------------------
/installation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Erickrus/stable-diffusion-webui-segment-anything/HEAD/installation.png
--------------------------------------------------------------------------------
/README_zh.md:
--------------------------------------------------------------------------------
1 | # A universal image segmentation (Segment Anything) extension for stable-diffusion-webui
2 |
3 | ## Installation
4 |
5 | ### Install dependencies
6 | ```shell
7 | pip3 install opencv-python matplotlib onnx onnxruntime
8 | pip3 install 'git+https://github.com/facebookresearch/segment-anything.git'
9 | ```
10 |
11 | Make sure pytorch>=2.0.0 is installed.
12 |
13 | ### SAM weights
14 |
15 | Automatic weight download currently only works on Linux/Mac, since it relies on wget. On Windows, download the weights manually and place them at models/sam/sam_vit_h_4b8939.pth:
16 |
17 | https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
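
If wget is not available, the checkpoint can also be fetched with a few lines of Python run from the webui root directory. This is a minimal sketch using only the standard library; it is not part of the extension itself:

```python
import os
import urllib.request

# Target path that scripts/segment_anything.py checks for at startup.
model_dir = "models/sam"
checkpoint = os.path.join(model_dir, "sam_vit_h_4b8939.pth")

os.makedirs(model_dir, exist_ok=True)
if not os.path.exists(checkpoint):
    urllib.request.urlretrieve(
        "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        checkpoint,
    )
```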
18 |
19 | ### Install the extension
20 |
21 | Go to the `Extensions` tab, click `Install from URL`, enter `https://github.com/Erickrus/stable-diffusion-webui-segment-anything`, and click `Install`.
22 |
23 |
24 |
25 | After the script is installed, please restart the webui.
26 |
27 |
28 |
29 |
30 | ## Usage
31 |
32 | Upload your image to the left-hand image panel, draw/click small dots on the object you want to keep, then click the Segment button.
33 |
34 | Please keep the brush radius around **5**. Do not use a large radius; it interferes with the search for segmentation points and degrades the final result.
35 |
36 | 
37 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # An extension for stable-diffusion-webui that segments image elements
2 |
3 | [中文版说明](https://github.com/Erickrus/stable-diffusion-webui-segment-anything/blob/main/README_zh.md)
4 |
5 | ## Installation
6 |
7 | ### Install dependencies
8 |
9 | ```shell
10 | pip3 install opencv-python matplotlib onnx onnxruntime
11 | pip3 install 'git+https://github.com/facebookresearch/segment-anything.git'
12 | ```
13 |
14 | Ensure pytorch>=2.0.0 is installed.
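
A quick way to verify this is a small check like the following (a sketch, not part of the extension):

```python
import torch

# The extension moves the SAM model onto the GPU when one is available,
# so CUDA availability is worth checking along with the version.
print(torch.__version__)          # expect >= 2.0.0
print(torch.cuda.is_available())
```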
15 |
16 | ### Install extension from webui
17 |
18 | Go to the `Extensions` tab, click `Install from URL`, enter `https://github.com/Erickrus/stable-diffusion-webui-segment-anything`, and click `Install`.
19 |
20 |
21 |
22 | Please restart the webui after the script is installed.
23 |
24 |
25 |
26 |
27 | ## How to use
28 |
29 | Upload the image you want to segment, draw/click tiny dot(s) on the object in the left-hand image, then click the Segment button.
30 |
31 | Please keep the brush radius around **5**. Don't use a large radius: the extension locates your dots by searching for small contours in the sketch mask, so large strokes won't be detected.
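
The reason for the small radius can be seen in `scripts/segment_anything.py`: the extension recovers your clicks from the sketch layer by looking for small contours and uses their centers as SAM point prompts. A condensed sketch of that logic (the area threshold matches the one in the script):

```python
import cv2
import numpy as np

def find_input_points(mask):
    # Invert the sketch layer and Otsu-threshold it so the drawn dots become white blobs.
    gray = 255 - np.array(mask)
    thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
    # Each sufficiently small contour is treated as one point prompt.
    cnts = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = cnts[0] if len(cnts) == 2 else cnts[1]
    return [[int(c[:, :, 0].mean()), int(c[:, :, 1].mean())]
            for c in cnts if cv2.contourArea(c) < 100]
```

Strokes much wider than the default radius produce contours whose area exceeds that threshold, so they are silently ignored.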
32 |
33 | 
34 |
--------------------------------------------------------------------------------
/scripts/segment_anything.py:
--------------------------------------------------------------------------------
1 | import html
2 |
3 |
4 | # installation
5 | # pip3 install opencv-python matplotlib onnx onnxruntime
6 | # pip3 install 'git+https://github.com/facebookresearch/segment-anything.git'
7 | # wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
8 |
9 |
10 |
11 | import cv2
12 | import numpy as np
13 | import torch
14 |
15 | from PIL import Image
16 |
17 | #import matplotlib.pyplot as plt
18 |
19 | import os
20 |
21 | from segment_anything import sam_model_registry, SamPredictor
22 | from segment_anything.utils.onnx import SamOnnxModel
23 |
24 | import onnxruntime
25 | from onnxruntime.quantization import QuantType
26 | from onnxruntime.quantization.quantize import quantize_dynamic
27 |
28 | from modules import script_callbacks, shared
29 |
30 | import gradio as gr
32 |
33 | model_dir = "models/sam/"
34 | checkpoint = model_dir + "sam_vit_h_4b8939.pth"
35 | model_type = "vit_h"
36 | if not os.path.exists(checkpoint):
37 |     os.makedirs(model_dir, exist_ok=True)  # create models/sam/ if needed (os.mkdir fails when it already exists)
38 |     os.system("wget -O %s https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" % checkpoint)
39 |
40 |
41 |
42 | sam = sam_model_registry[model_type](checkpoint=checkpoint)
43 |
44 |
45 |
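# Export the SAM mask decoder to ONNX (following the official segment-anything ONNX
# example) and write a dynamically quantized copy of it under models/sam/.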
46 | def extract_onnx():
47 | import warnings
48 | onnx_model_path = model_dir + "sam_onnx_example.onnx"
49 | onnx_model = SamOnnxModel(sam, return_single_mask=True)
50 |
51 | dynamic_axes = {
52 | "point_coords": {1: "num_points"},
53 | "point_labels": {1: "num_points"},
54 | }
55 |
56 | embed_dim = sam.prompt_encoder.embed_dim
57 | embed_size = sam.prompt_encoder.image_embedding_size
58 | mask_input_size = [4 * x for x in embed_size]
59 | dummy_inputs = {
60 | "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
61 | "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
62 | "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
63 | "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
64 | "has_mask_input": torch.tensor([1], dtype=torch.float),
65 | "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
66 | }
67 | output_names = ["masks", "iou_predictions", "low_res_masks"]
68 |
69 | with warnings.catch_warnings():
70 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
71 | warnings.filterwarnings("ignore", category=UserWarning)
72 | with open(onnx_model_path, "wb") as f:
73 | torch.onnx.export(
74 | onnx_model,
75 | tuple(dummy_inputs.values()),
76 | f,
77 | export_params=True,
78 | verbose=False,
79 | opset_version=17,
80 | do_constant_folding=True,
81 | input_names=list(dummy_inputs.keys()),
82 | output_names=output_names,
83 | dynamic_axes=dynamic_axes,
84 | )
85 | onnx_model_quantized_path = model_dir + "sam_onnx_quantized_example.onnx"
86 | quantize_dynamic(
87 | model_input=onnx_model_path,
88 | model_output=onnx_model_quantized_path,
89 | optimize_model=True,
90 | per_channel=False,
91 | reduce_range=False,
92 | weight_type=QuantType.QUInt8,
93 | )
94 |
95 | onnx_model_quantized_path = model_dir + "sam_onnx_quantized_example.onnx"
96 | if not os.path.exists(onnx_model_quantized_path):
97 | extract_onnx()
98 | onnx_model_path = onnx_model_quantized_path
99 | ort_session = onnxruntime.InferenceSession(onnx_model_path)
100 |
101 | sam.to(device="cuda" if torch.cuda.is_available() else "cpu")  # fall back to CPU when no GPU is present
102 | predictor = SamPredictor(sam)
103 |
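# Recover the tiny dots the user drew on the sketch layer: invert and Otsu-threshold the
# mask, find external contours, and return the centers of contours with area below 100 px.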
104 | def find_input_points(mask):
105 | input_points = []
106 | gray = 255 - np.array(mask)
107 | thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
108 |
109 | cnts = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
110 | cnts = cnts[0] if len(cnts) == 2 else cnts[1]
111 | for c in cnts:
112 | area = cv2.contourArea(c)
113 | if area < 100:
114 | x, y = c[:,:,0].mean(), c[:,:,1].mean()
115 | input_points.append([int(x), int(y)])
116 | return input_points
117 |
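# Gradio callback: turn the drawn dots into SAM point prompts, embed the image with the
# PyTorch encoder, run the quantized ONNX decoder, and white out everything outside the mask.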
118 | def segment_anything(im):
119 | img = im['image'].convert("RGB")
120 | mask = im['mask'].convert("L")
121 | input_points = find_input_points(mask)
122 |
123 | image = np.array(img)
124 |
125 | predictor.set_image(image)
126 | image_embedding = predictor.get_image_embedding().cpu().numpy()
127 |
128 | input_point = np.array(input_points)
129 |     input_label = np.ones(len(input_point)).astype(np.int64)  # np.int was removed in NumPy 1.24
130 | print(input_point)
131 |
132 | onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
133 | onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
134 |
135 | onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)
136 |
137 | onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
138 | onnx_has_mask_input = np.zeros(1, dtype=np.float32)
139 |
140 | ort_inputs = {
141 | "image_embeddings": image_embedding,
142 | "point_coords": onnx_coord,
143 | "point_labels": onnx_label,
144 | "mask_input": onnx_mask_input,
145 | "has_mask_input": onnx_has_mask_input,
146 | "orig_im_size": np.array(image.shape[:2], dtype=np.float32)
147 | }
148 |
149 | masks, _, low_res_logits = ort_session.run(None, ort_inputs)
150 | masks = masks > predictor.model.mask_threshold
151 |
152 | masks = np.squeeze(np.squeeze(masks,0),0)
153 | masks = masks.astype(np.uint8)
154 | masks = np.repeat(masks[...,None], 3, axis=2)
155 | image = im['image'].convert("RGB")
156 |     white = (255 * np.ones([image.size[1], image.size[0], 3])).astype(np.uint8)  # PIL .size is (W, H); numpy arrays are (H, W, C)
157 | image = np.array(image) * masks + white * (1 - masks)
158 |
159 | return Image.fromarray(image)
160 |
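# Build the "Segment Anything" webui tab: a sketchable input image, an output image,
# and a Segment button wired to segment_anything().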
161 | def add_tab():
162 | with gr.Blocks(analytics_enabled=False) as ui:
163 | with gr.Tabs() as tabs:
164 | with gr.Blocks():
165 | with gr.Row():
166 | im = gr.Image(type="pil", tool="sketch", source='upload', brush_radius=5, label="input")
167 | output_im = gr.Image(type="pil", label="output")
168 | with gr.Tab("Segment Anything", id="input_image"):
169 | segment_anything_btn = gr.Button(value="Segment", variant="primary")
170 |
171 | segment_anything_btn.click(
172 | fn=segment_anything,
173 | inputs=[im],
174 | outputs=[output_im],
175 | )
176 |
177 | return [(ui, "Segment Anything", "segment_anything")]
178 |
179 | script_callbacks.on_ui_tabs(add_tab)
180 |
--------------------------------------------------------------------------------