├── .idea
│   ├── .gitignore
│   ├── githubai.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── README_cn.md
├── app.py
├── demo.py
├── example
│   ├── 1.gif
│   ├── 1683122305662.png
│   ├── 1683122435166.png
│   ├── 1683134557206.png
│   ├── 2.gif
│   ├── 3.gif
│   ├── 4.gif
│   ├── 5.gif
│   ├── 6.gif
│   ├── image.jpg
│   ├── imagemask.jpg
│   ├── images.png
│   └── mask1.jpg
├── lamamodel.py
├── metaseg
│   ├── __init__.py
│   ├── falai_demo.py
│   ├── generator
│   │   ├── __init__.py
│   │   ├── automatic_mask_generator.py
│   │   ├── build_sam.py
│   │   └── predictor.py
│   ├── mask_predictor.py
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── image_encoder.py
│   │   ├── mask_decoder.py
│   │   ├── prompt_encoder.py
│   │   ├── sam.py
│   │   └── transformer.py
│   ├── sahi_predict.py
│   └── utils
│       ├── __init__.py
│       ├── amg.py
│       ├── data_utils.py
│       ├── file_utils.py
│       ├── onnx.py
│       └── transforms.py
└── requirements.txt
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/githubai.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | English | [简体中文](README_cn.md)
2 | # Modify-Anything: Segment Anything Meets Video and Image Modify and Picture Video Background Replacement
3 |
4 | Modify-Anything uses YOLOv5/YOLOv8 to detect the specified targets in videos and images, then applies Segment Anything and
5 | lama_cleaner to segment, modify, or erase those targets and extract them. The extracted target can then be composited onto
6 | a new background, using either an image or a video as the background.
7 |
8 |
9 | ## Installation
10 | The code requires python>=3.8, as well as pytorch>=1.7 and torchvision>=0.8. Please follow the instructions here
11 | to install both PyTorch and TorchVision dependencies. Installing both PyTorch and TorchVision with CUDA support
12 | is strongly recommended.
13 | To install Modify-Anything, follow these steps:
14 | - On the first run, the required models are downloaded automatically. If the download is too slow, download them manually and place them as described below.
15 | - Train your own YOLOv5 or YOLOv8 model for the targets you want to detect, segment, modify, or erase.
16 | The default detector weights used in this project are "yolov5l.pt", "yolov5l6.pt", "yolov8l.pt", and "yolov8x.pt".
17 | Please download them and place them in the project root directory.
18 | - Download the Segment Anything checkpoints and place them in the project root directory, renaming sam_vit_h_4b8939.pth to vit_h.pth, sam_vit_l_0b3195.pth to vit_l.pth, and sam_vit_b_01ec64.pth to vit_b.pth.
19 | - Install the dependencies: `pip install ultralytics sahi fal_serverless lama_cleaner tqdm` or `pip install -r requirements.txt`.
20 | - Run `python app.py` to launch the Gradio interface; for scripted use, see the sketch below.
21 | - All generated results are written to the `output` directory.
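22 |
23 | The same pipeline can also be driven from Python without the UI. A minimal sketch, using the bundled `example/image.jpg` and the same defaults as `app.py` (it assumes `vit_l.pth` and `yolov8l.pt` are in the project root and that the `output/aimage`, `output/amask`, `output/aimages`, and `output/alama` folders exist):
24 |
25 | ```python
26 | from demo import lama_image_app
27 |
28 | results = lama_image_app(
29 |     "example/image.jpg",  # input image
30 |     "vit_l",              # SAM model type
31 |     "lama",               # inpainting model
32 |     "yolov8",             # detector type
33 |     "yolov8l.pt",         # detector weights
34 |     0.2,                  # confidence threshold
35 |     640,                  # image size
36 |     256, 256,             # slice height / width
37 |     0.2, 0.2,             # overlap height / width ratios
38 | )
39 | print(results)  # [original, mask, extracted target (PNG), inpainted background]
40 | ```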
22 |
23 |
24 |
25 |
26 |
27 | ## Modify Anything Image and Picture Video Background Replacement
28 |
29 |
30 |
31 | *(demo images from the example directory were shown here)*
34 |
35 |
36 |
37 |
41 |
42 |
43 |
44 | ## Modify Anything Video and Picture Video Background Replacement
45 |
46 |
47 | *(demo GIFs from the example directory were shown here)*
50 |
51 |
52 |
53 |
56 |
57 |
58 |
59 | ## Acknowledgments
60 | - [LaMa](https://github.com/advimman/lama)
61 | - [Segment Anything](https://github.com/facebookresearch/segment-anything)
62 | - [YOLOv8](https://github.com/ultralytics/ultralytics)
63 |
64 | ## Citation
65 | If you find this work useful for your research, please cite us:
66 | ```
67 | @article{zhang2023modifyanything,
68 | title={Modify-Anything: Segment Anything Meets Video and Image Modify and Picture Video Background Replacement},
69 | author={Zhang Jing},
70 | year={2023}
71 | }
72 | ```
73 |
--------------------------------------------------------------------------------
/README_cn.md:
--------------------------------------------------------------------------------
1 | 简体中文 | [English](README.md)
2 | # Modify-Anything: Segment Anything Meets Video and Image Modify and Picture Video Background Replacement
3 |
4 | Modify-Anything 基于YOLOv5、YOLOv8对视频和图片中的指定目标进行检测,再应用segment-anything、lama_cleaner对目标进行
5 | 分割、修改、擦除,并可对得到的目标图片/视频进行背景替换,既可替换为图片背景,也可替换为视频背景。
6 |
7 | ## 安装
8 | 该代码需要python>=3.8,以及pytorch>=1.7和torchvision>=0.8。请按照此处的说明安装PyTorch和TorchVision依赖项。
9 | 强烈建议在支持CUDA的情况下同时安装PyTorch和TorchVision。
10 |
11 | 要安装Modify Anything,请执行以下步骤:
12 | - 第一次运行会自动下载模型,如下载过慢,请手动下载并按如下方式放置
13 | - 训练自己的要检测分割,修改,擦除的yolo5或者yolov8模型,本项目用的默认"yolov5l.pt", "yolov5l6.pt", "yolov8l.pt", "yolov8x.pt"模型请自己下载放入项目根目录中
14 | - 下载Segment-anything模型放入项目根目录中,sam_vit_h_4b8939.pth 改为vit_h.pth,sam_vit_l_0b3195.pth改为vit_l.pth,sam_vit_b_01ec64.pth改为vit_b.pth
15 | - 安装pip install ultralytics sahi fal_serverless lama_cleaner tqdm 或者pip install -r requirements.txt
16 | - 运行 python app.py 启动Web界面;脚本化换背景调用示例见下方
17 | - 生成结果全在output目录
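18 |
19 | 生成掩码后,也可以不经Web界面,直接调用 `demo.py` 中的换背景函数。下面是一个最小示例(假设已先对同一张图片运行过图片流程、`output/amask/mask.jpg` 已生成;背景图路径仅作示意):
20 |
21 | ```python
22 | import cv2
23 | from demo import change_aimage
24 |
25 | # 将目标合成到新的背景图上(要求 output/amask/mask.jpg 已存在)
26 | result = change_aimage("example/image.jpg", "example/1683122305662.png")
27 | cv2.imwrite("output/replaced.jpg", cv2.cvtColor(result, cv2.COLOR_RGB2BGR))  # 函数返回 RGB,保存前转回 BGR
28 | ```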
18 |
19 |
20 |
21 |
22 |
23 | ## Modify Anything Image and Picture Video Background Replacement
24 |
25 |
26 |
27 | *(此处原为 example 目录中的演示图片)*
30 |
31 |
32 |
33 |
37 |
38 |
39 |
40 | ## Modify Anything Video and Picture Video Background Replacement
41 |
42 |
43 | *(此处原为 example 目录中的演示GIF)*
46 |
47 |
48 |
49 |
52 |
53 |
54 |
55 | ## Acknowledgments
56 | - [LaMa](https://github.com/advimman/lama)
57 | - [Segment Anything](https://github.com/facebookresearch/segment-anything)
58 | - [YOLOv8](https://github.com/ultralytics/ultralytics)
59 |
60 | ## Citation
61 | If you find this work useful for your research, please cite us:
62 | ```
63 | @article{zhang2023modifyanything,
64 | title={Modify-Anything: Segment Anything Meets Video and Image Modify and Picture Video Background Replacement},
65 | author={Zhang Jing},
66 | year={2023}
67 | }
68 | ```
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from demo import lama_image_app, lama_video_app, change_video, change_image, change_aimage, change_avideo
3 |
4 | available_lamamodels = ["lama", "ldm", "zits", "mat", "fcf", "manga", "sd2"]
5 | default_lamammodel = "lama"
6 |
7 |
8 | def image_app():
9 | with gr.Blocks():
10 | with gr.Row():
11 | with gr.Column():
12 | image_file = gr.Image(type="filepath").style(height=260)
13 | with gr.Row():
14 | with gr.Column():
15 | image_model_type = gr.Dropdown(
16 | choices=[
17 | "vit_h",
18 | "vit_l",
19 | "vit_b",
20 | ],
21 | value="vit_l",
22 | label="Model Type",
23 | )
24 |
25 | with gr.Row():
26 | with gr.Column():
27 | sahi_model_type = gr.Dropdown(
28 | choices=[
29 | "yolov5",
30 | "yolov8",
31 | ],
32 | value="yolov8",
33 | label="Detector Model Type",
34 | )
35 | sahi_image_size = gr.Slider(
36 | minimum=0,
37 | maximum=1600,
38 | step=32,
39 | value=640,
40 | label="Image Size",
41 | )
42 |
43 | sahi_overlap_width = gr.Slider(
44 | minimum=0,
45 | maximum=1,
46 | step=0.1,
47 | value=0.2,
48 | label="Overlap Width",
49 | )
50 |
51 | sahi_slice_width = gr.Slider(
52 | minimum=0,
53 | maximum=640,
54 | step=32,
55 | value=256,
56 | label="Slice Width",
57 | )
58 |
59 | with gr.Row():
60 | with gr.Column():
61 | sahi_model_path = gr.Dropdown(
62 | choices=["yolov5l.pt", "yolov5l6.pt", "yolov8l.pt", "yolov8x.pt"],
63 | value="yolov8l.pt",
64 | label="Detector Model Path",
65 | )
66 | selected_lamamodel = gr.Dropdown(choices=available_lamamodels, label="lama Model(lama, ldm, zits, mat, fcf, manga, sd2)",
67 | value=default_lamammodel,
68 | interactive=True)
69 |
70 | sahi_conf_th = gr.Slider(
71 | minimum=0,
72 | maximum=1,
73 | step=0.1,
74 | value=0.2,
75 | label="Confidence Threshold",
76 | )
77 | sahi_overlap_height = gr.Slider(
78 | minimum=0,
79 | maximum=1,
80 | step=0.1,
81 | value=0.2,
82 | label="Overlap Height",
83 | )
84 | sahi_slice_height = gr.Slider(
85 | minimum=0,
86 | maximum=640,
87 | step=32,
88 | value=256,
89 | label="Slice Height",
90 | )
91 | image_predict = gr.Button(value="Generate vector images from targets(去目标生成矢量图片)")
92 |
93 | with gr.Column():
94 | output_image = gr.Gallery()
95 |
96 | image_predict.click(
97 | fn=lama_image_app,
98 | inputs=[
99 | image_file,
100 | image_model_type,
101 | selected_lamamodel,
102 | sahi_model_type,
103 | sahi_model_path,
104 | sahi_conf_th,
105 | sahi_image_size,
106 | sahi_slice_height,
107 | sahi_slice_width,
108 | sahi_overlap_height,
109 | sahi_overlap_width,
110 | ],
111 | outputs=[output_image],
112 | )
113 | with gr.Row():
114 | with gr.Column():
115 | b_image = gr.Image(type="filepath")
116 | with gr.Column():
117 | output_change = gr.Image()
118 | with gr.Row():
119 | change_images = gr.Button(value="Target and background image(目标与背景图)")
120 | change_images.click(
121 | fn=change_aimage,
122 | inputs=[image_file, b_image],
123 | outputs=[output_change]
124 |
125 | )
126 | with gr.Row():
127 | with gr.Column():
128 | b_video = gr.Video(type="filepath").style(height=260)
129 |
130 | with gr.Column():
131 | output_change = gr.Video()
132 | with gr.Row():
133 | change_videos = gr.Button(value="Target and Background Video(目标与背景视频)")
134 | change_videos.click(
135 | fn=change_avideo,
136 | inputs=[image_file, b_video],
137 | outputs=[output_change]
138 |
139 | )
140 |
141 |
142 | def video_app():
143 | with gr.Blocks():
144 | with gr.Row():
145 | with gr.Column():
146 | sahi_image_file = gr.Video().style(height=260)
147 | sahi_autoseg_model_type = gr.Dropdown(
148 | choices=[
149 | "vit_h",
150 | "vit_l",
151 | "vit_b",
152 | ],
153 | value="vit_l",
154 | label="Sam Model Type",
155 | )
156 |
157 | with gr.Row():
158 | with gr.Column():
159 | sahi_model_type = gr.Dropdown(
160 | choices=[
161 | "yolov5",
162 | "yolov8",
163 | ],
164 | value="yolov8",
165 | label="Detector Model Type",
166 | )
167 | sahi_image_size = gr.Slider(
168 | minimum=0,
169 | maximum=1600,
170 | step=32,
171 | value=640,
172 | label="Image Size",
173 | )
174 |
175 | sahi_overlap_width = gr.Slider(
176 | minimum=0,
177 | maximum=1,
178 | step=0.1,
179 | value=0.2,
180 | label="Overlap Width",
181 | )
182 |
183 | sahi_slice_width = gr.Slider(
184 | minimum=0,
185 | maximum=640,
186 | step=32,
187 | value=256,
188 | label="Slice Width",
189 | )
190 |
191 | with gr.Row():
192 | with gr.Column():
193 | sahi_model_path = gr.Dropdown(
194 | choices=["yolov5l.pt", "yolov5l6.pt", "yolov8l.pt", "yolov8x.pt"],
195 | value="yolov8l.pt",
196 | label="Detector Model Path",
197 | )
198 | selected_lamamodel = gr.Dropdown(choices=available_lamamodels, label="lama Model(lama, ldm, zits, mat, fcf, manga, sd2)",
199 | value=default_lamammodel,
200 | interactive=True)
201 |
202 | sahi_conf_th = gr.Slider(
203 | minimum=0,
204 | maximum=1,
205 | step=0.1,
206 | value=0.2,
207 | label="Confidence Threshold",
208 | )
209 | sahi_overlap_height = gr.Slider(
210 | minimum=0,
211 | maximum=1,
212 | step=0.1,
213 | value=0.2,
214 | label="Overlap Height",
215 | )
216 | sahi_slice_height = gr.Slider(
217 | minimum=0,
218 | maximum=640,
219 | step=32,
220 | value=256,
221 | label="Slice Height",
222 | )
223 | sahi_image_predict = gr.Button(value="Generate vector video by removing targets(去目标生成矢量视频)")
224 |
225 | with gr.Column():
226 | output_video = gr.Video()
227 | output_video1 = gr.Video()
228 |
229 | sahi_image_predict.click(
230 | fn=lama_video_app,
231 | inputs=[
232 | sahi_image_file,
233 | sahi_autoseg_model_type,
234 | selected_lamamodel,
235 | sahi_model_type,
236 | sahi_model_path,
237 | sahi_conf_th,
238 | sahi_image_size,
239 | sahi_slice_height,
240 | sahi_slice_width,
241 | sahi_overlap_height,
242 | sahi_overlap_width,
243 | ],
244 | outputs=[output_video, output_video1],
245 | )
246 | with gr.Row():
247 | with gr.Column():
248 | b_image = gr.Image(type="filepath").style(height=260)
249 |
250 | with gr.Column():
251 | output_change = gr.Video()
252 |
253 | with gr.Row():
254 | change_images = gr.Button(value="Target and background image(目标与背景图)")
255 | change_images.click(
256 | fn=change_image,
257 | inputs=[sahi_image_file, b_image],
258 | outputs=[output_change])
259 |
260 | with gr.Row():
261 | with gr.Column():
262 | b_video = gr.Video()
263 | with gr.Column():
264 | output_change = gr.Video()
265 | with gr.Row():
266 | change_videos = gr.Button(value="Target and Background Video(目标与背景视频)")
267 | change_videos.click(
268 | fn=change_video,
269 | inputs=[sahi_image_file, b_video],
270 | outputs=[output_change]
271 |
272 | )
273 |
274 |
275 | def metaseg_app():
276 | app = gr.Blocks()
277 | with app:
278 | with gr.Row():
279 | with gr.Column():
280 | with gr.Tab("Image"):
281 | image_app()
282 |
283 | with gr.Tab("Video"):
284 | video_app()
285 |
286 | app.queue(concurrency_count=1)
287 | app.launch(debug=True, enable_queue=True)
288 |
289 |
290 | if __name__ == "__main__":
291 | metaseg_app()
292 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | from metaseg import SahiAutoSegmentation,sahi_sliced_predict
2 | import cv2
3 | from tqdm import tqdm
4 | import os
5 | from lamamodel import lamamain
6 | import numpy as np
7 |
8 | def load_video(video_path, output_path="output.mp4"):
9 | cap = cv2.VideoCapture(video_path)
10 | frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
11 | frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
12 | fourcc = cv2.VideoWriter_fourcc(*"XVID")
13 | fps = int(cap.get(cv2.CAP_PROP_FPS))
14 | out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
15 |
16 | return cap, out
17 | def lama_image_app(
18 | image_path,
19 | sam_model_type,
20 | selected_lamamodel,
21 | detection_model_type,
22 | detection_model_path,
23 | conf_th,
24 | image_size,
25 | slice_height,
26 | slice_width,
27 | overlap_height_ratio,
28 | overlap_width_ratio,
29 | ):
30 | img = cv2.imread(image_path, cv2.IMREAD_COLOR)
31 | width = img.shape[1]
32 | height = img.shape[0]
33 | boxes = sahi_sliced_predict(
34 | image_path=img,
35 | detection_model_type=detection_model_type, # yolov8, detectron2, mmdetection, torchvision
36 | detection_model_path=detection_model_path,
37 | conf_th=conf_th,
38 | image_size=image_size,
39 | slice_height=slice_height,
40 | slice_width=slice_width,
41 | overlap_height_ratio=overlap_height_ratio,
42 | overlap_width_ratio=overlap_width_ratio,
43 | )
44 | if len(boxes) == 0:
45 | boxes = [0, 0, 0, 0]
46 |
47 | masks =SahiAutoSegmentation().predict(
48 | source=img,
49 | model_type=sam_model_type,
50 | input_box=boxes,
51 | multimask_output=False,
52 | random_color=False,
53 | show=False,
54 | )
55 | a = np.full((1, height, width), False)
56 | for idx, mask in enumerate(masks):
57 | if mask[0] is not None:
58 | bmasks = mask.detach().cpu().numpy()
59 | mask2 = np.logical_or(a, bmasks)
60 | a[:, :, :] = mask2
61 | h, w = a.shape[-2:]
62 | mask_image = a.reshape(h, w, 1)
63 | cv2.imwrite('./output/aimage/image.jpg', img)
64 | cv2.imwrite("output/amask/mask.jpg" , mask_image * 255)
65 | mask_image = mask_image * img
66 | tmp = cv2.cvtColor(mask_image, cv2.COLOR_BGR2GRAY)
67 | _, alpha = cv2.threshold(tmp, 0, 255, cv2.THRESH_BINARY)
68 | b, g, r = cv2.split(mask_image)
69 | rgba = [b, g, r, alpha]
70 | dst = cv2.merge(rgba, 4)
71 |
72 |
73 | cv2.imwrite("output/aimages/images.png" , dst)
74 |
75 | lamaindex = lamamain("output/aimage/image.jpg", "output/amask/mask.jpg" , selected_lamamodel)
76 | lamaimage = cv2.cvtColor(lamaindex, cv2.COLOR_BGR2RGB)
77 | cv2.imwrite('output/alama/imagemask.jpg' , lamaimage)
78 |
79 | lists=["output/aimage/image.jpg","output/amask/mask.jpg","output/aimages/images.png","output/alama/imagemask.jpg"]
80 |
81 |
82 | return lists
83 |
84 |
85 |
86 | def lama_video_app(
87 | image_path,
88 | sam_model_type,
89 | selected_lamamodel,
90 | detection_model_type,
91 | detection_model_path,
92 | conf_th,
93 | image_size,
94 | slice_height,
95 | slice_width,
96 | overlap_height_ratio,
97 | overlap_width_ratio,
98 | ):
99 | cap, out = load_video(image_path)
100 | length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
101 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
102 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
103 | index = 1
104 | for _ in tqdm(range(length)):
105 | ret, frame = cap.read()
106 | if not ret:
107 | break
108 |
109 | boxes = sahi_sliced_predict(
110 | image_path=frame,
111 | detection_model_type=detection_model_type, # yolov8, detectron2, mmdetection, torchvision
112 | detection_model_path=detection_model_path,
113 | conf_th=conf_th,
114 | image_size=image_size,
115 | slice_height=slice_height,
116 | slice_width=slice_width,
117 | overlap_height_ratio=overlap_height_ratio,
118 | overlap_width_ratio=overlap_width_ratio,
119 | )
120 | if len(boxes) == 0:
121 | boxes = [0, 0, 0, 0]
122 |
123 |
124 | masks = SahiAutoSegmentation().predict(
125 | source=frame,
126 | model_type=sam_model_type,
127 | input_box=boxes,
128 | multimask_output=False,
129 | random_color=False,
130 | show=False,
131 | )
132 | a = np.full((1, height, width), False)
133 | for idx, mask in enumerate(masks):
134 | if mask[0] is not None:
135 | bmasks = mask.detach().cpu().numpy()
136 | mask2 = np.logical_or(a, bmasks)
137 | a[:, :, :] = mask2
138 | h, w = a.shape[-2:]
139 | mask_image = a.reshape(h, w, 1)
140 | cv2.imwrite('./output/image/image%s.jpg' % index, frame)
141 | cv2.imwrite("output/mask/mask%s.jpg" % index, mask_image * 255)
142 | mask_image = mask_image * frame
143 | tmp = cv2.cvtColor(mask_image, cv2.COLOR_BGR2GRAY)
144 | _, alpha = cv2.threshold(tmp, 0, 255, cv2.THRESH_BINARY)
145 | b, g, r = cv2.split(mask_image)
146 | rgba = [b, g, r, alpha]
147 | dst = cv2.merge(rgba, 4)
148 | cv2.imwrite("output/images/images%s.png" % index, dst)
149 |
150 | lamaindex = lamamain("output/image/image%s.jpg" % index, "output/mask/mask%s.jpg" % index,selected_lamamodel)
151 | lamaimage = cv2.cvtColor(lamaindex, cv2.COLOR_BGR2RGB)
152 | cv2.imwrite('output/lama/imagemask%s.jpg' % index, lamaimage)
153 | index += 1
154 | list = os.listdir("./output/lama/")
155 | list.sort(key=lambda x: int(x.replace("imagemask", "").split('.')[0]))
156 | list1 = os.listdir("./output/images/")
157 | list1.sort(key=lambda x: int(x.replace("images", "").split('.')[0]))
158 | frame_width1 = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
159 | frame_height1 = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
160 | fourcc = cv2.VideoWriter_fourcc(*"XVID")
161 | fps = int(cap.get(cv2.CAP_PROP_FPS))
162 | out1 = cv2.VideoWriter("output1.mp4", fourcc, fps, (frame_width1, frame_height1))
163 | ## The frame file names were sorted above by the numeric index sliced out of each name
164 |
165 | for path in list:
166 | paths = ("./output/lama/" + path)
167 | frame = cv2.imread(paths)
168 | out.write(frame)
169 |
170 | for path in list1:
171 | paths = ("./output/images/" + path)
172 | frame = cv2.imread(paths)
173 | out1.write(frame)
174 | cap.release()
175 | return "output.mp4", "output1.mp4"
176 |
177 |
178 | def change_video(video,bgvideo):
179 | cap_img = cv2.VideoCapture(video)
180 | fps = cap_img.get(cv2.CAP_PROP_FPS)
181 | total_frames = int(cap_img.get(cv2.CAP_PROP_FRAME_COUNT))
182 | width = int(cap_img.get(cv2.CAP_PROP_FRAME_WIDTH))
183 | height = int(cap_img.get(cv2.CAP_PROP_FRAME_HEIGHT))
184 | cap_out = cv2.VideoWriter("output2.mp4",
185 | cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps,
186 | (width, height))
187 | cap_bg = cv2.VideoCapture(bgvideo)
188 | bg_frame_nums = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
189 | bg_frame_idx = 1
190 | with tqdm(total=total_frames) as pbar:
191 | img_frame_idx = 1
192 | while cap_img.isOpened():
193 | ret_img, origin_img = cap_img.read()
194 | if not ret_img:
195 | break
196 | ret_bg, bg = cap_bg.read()
197 | img = cv2.imread("output/mask/mask%s.jpg" % img_frame_idx)
198 | _, mask_thr = cv2.threshold(img, 240, 1, cv2.THRESH_BINARY)
199 | if not ret_bg:
200 | break
201 | bg_frame_idx += 1
202 | if bg_frame_idx == bg_frame_nums:
203 | bg_frame_idx = 1
204 | cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)
205 | bg = cv2.resize(bg, (width, height))
206 | if bg.ndim == 2:
207 | bg = bg[..., np.newaxis]
208 | out = (mask_thr * origin_img + (1 - mask_thr) * bg).astype(np.uint8)  # keep target pixels where the mask is set, background elsewhere
209 | cap_out.write(out)
210 | img_frame_idx += 1
211 | pbar.update(1)
212 | cap_img.release()
213 | cap_out.release()
214 | return "output2.mp4"
215 |
216 | def change_image(video,bgimage):
217 | cap_img = cv2.VideoCapture(video)
218 | fps = cap_img.get(cv2.CAP_PROP_FPS)
219 | total_frames = int(cap_img.get(cv2.CAP_PROP_FRAME_COUNT))
220 | width = int(cap_img.get(cv2.CAP_PROP_FRAME_WIDTH))
221 | height = int(cap_img.get(cv2.CAP_PROP_FRAME_HEIGHT))
222 | cap_out = cv2.VideoWriter("output3.mp4",
223 | cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps,
224 | (width, height))
225 | bg=cv2.imread(bgimage)
226 | with tqdm(total=total_frames) as pbar:
227 | img_frame_idx = 1
228 | while cap_img.isOpened():
229 | ret_img, origin_img = cap_img.read()
230 | if not ret_img:
231 | break
232 | img = cv2.imread("output/mask/mask%s.jpg" % img_frame_idx)
233 | _, mask_thr = cv2.threshold(img, 240, 1, cv2.THRESH_BINARY)
234 | bg = cv2.resize(bg, (width, height))
235 | if bg.ndim == 2:
236 | bg = bg[..., np.newaxis]
237 | out = (mask_thr * origin_img + (1 - mask_thr) * bg).astype(np.uint8)
238 | cap_out.write(out)
239 | img_frame_idx += 1
240 | pbar.update(1)
241 | cap_img.release()
242 | cap_out.release()
243 | return "output3.mp4"
244 |
245 | def change_aimage(image,bgimage):
246 | origin_img=cv2.imread(image)
247 | bg=cv2.imread(bgimage)
248 | img = cv2.imread("output/amask/mask.jpg")
249 | _, mask_thr = cv2.threshold(img, 240, 1, cv2.THRESH_BINARY)
250 | width = origin_img.shape[1]
251 | height = origin_img.shape[0]
252 | bg = cv2.resize(bg, (width, height))
253 | if bg.ndim == 2:
254 | bg = bg[..., np.newaxis]
255 | out = (mask_thr * origin_img + (1 - mask_thr) * bg).astype(np.uint8)
256 | out = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
257 | return out
258 | def change_avideo(image,bgvideo):
259 | origin_img = cv2.imread(image)
260 |
261 | cap_bg = cv2.VideoCapture(bgvideo)
262 | fps = cap_bg.get(cv2.CAP_PROP_FPS)
263 | bg_frame_nums = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
264 | total_frames = int(cap_bg.get(cv2.CAP_PROP_FRAME_COUNT))
265 | width = int(cap_bg.get(cv2.CAP_PROP_FRAME_WIDTH))
266 | height = int(cap_bg.get(cv2.CAP_PROP_FRAME_HEIGHT))
267 | cap_out = cv2.VideoWriter("output4.mp4",
268 | cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps,
269 | (width, height))
270 | mask = cv2.imread("output/amask/mask.jpg")
271 | origin_img = cv2.resize(origin_img, (width, height))
272 | mask = cv2.resize(mask, (width, height))
273 | for _ in tqdm(range(total_frames)):
274 | ret_bg, bg = cap_bg.read()
275 | if not ret_bg:
276 | break
277 | _, mask_thr = cv2.threshold(mask, 240, 1, cv2.THRESH_BINARY)
278 | out = (mask_thr * origin_img + (1 - mask_thr) * bg).astype(np.uint8)
279 | cap_out.write(out)
280 | cap_bg.release()
281 | cap_out.release()
282 | return "output4.mp4"
--------------------------------------------------------------------------------
/example/1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/1.gif
--------------------------------------------------------------------------------
/example/1683122305662.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/1683122305662.png
--------------------------------------------------------------------------------
/example/1683122435166.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/1683122435166.png
--------------------------------------------------------------------------------
/example/1683134557206.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/1683134557206.png
--------------------------------------------------------------------------------
/example/2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/2.gif
--------------------------------------------------------------------------------
/example/3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/3.gif
--------------------------------------------------------------------------------
/example/4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/4.gif
--------------------------------------------------------------------------------
/example/5.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/5.gif
--------------------------------------------------------------------------------
/example/6.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/6.gif
--------------------------------------------------------------------------------
/example/image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/image.jpg
--------------------------------------------------------------------------------
/example/imagemask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/imagemask.jpg
--------------------------------------------------------------------------------
/example/images.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/images.png
--------------------------------------------------------------------------------
/example/mask1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/mask1.jpg
--------------------------------------------------------------------------------
/lamamodel.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from lama_cleaner.parse_args import parse_args
3 | import random
4 | import time
5 | import imghdr
6 | from typing import Union
7 |
8 | import cv2
9 | import torch
10 | from loguru import logger
11 | from lama_cleaner.model_manager import ModelManager
12 | from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler
13 |
14 |
15 | from enum import Enum
16 |
17 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
18 |
19 | model = ModelManager(name="lama", device=device)  # default model; lamamain() replaces it with the selected one
20 | device = torch.device(device)
21 | input_image_path: str = None
22 | is_disable_model_switch: bool = False
23 | is_desktop: bool = False
24 | from lama_cleaner.helper import (
25 | resize_max_size,
26 | )
27 |
28 |
29 | def diffuser_callback(i, t, latents):
30 | pass
31 | def get_image_ext(img_bytes):
32 | w = imghdr.what("", img_bytes)
33 | if w is None:
34 | w = "jpeg"
35 | return w
36 | class LDMSampler(str, Enum):
37 | ddim = "ddim"
38 | plms = "plms"
39 | def get_data(img_p ,mask_p):
40 | img = cv2.imread(str(img_p))
41 | img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
42 | mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
43 | mask = cv2.dilate(
44 | mask,
45 | np.ones((10, 10), np.uint8),
46 | iterations=1
47 | )
48 | img = cv2.resize(img, None, fx=1, fy= 1.0, interpolation=cv2.INTER_AREA)
49 | mask = cv2.resize(mask, None, fx=1, fy= 1.0, interpolation=cv2.INTER_NEAREST)
50 | return img, mask
51 |
52 | def process(img_p, mask_p):
53 | image, mask = get_data(img_p=img_p, mask_p=mask_p)
54 | alpha_channel = image[:, :, -1]
55 | if image.shape[:2] != mask.shape[:2]:
56 | return f"Mask shape{mask.shape[:2]} not queal to Image shape{image.shape[:2]}", 400
57 |
58 | original_shape = image.shape
59 | interpolation = cv2.INTER_CUBIC
60 |
61 | size_limit: Union[int, str] = 2500
62 |
63 | if size_limit == "Original":
64 | size_limit = max(image.shape)
65 | else:
66 | size_limit = int(size_limit)
67 |
68 | config = Config(
69 | ldm_steps=1,
70 | ldm_sampler=LDMSampler.plms,
71 | hd_strategy=HDStrategy.ORIGINAL,
72 | hd_strategy_crop_margin=32,
73 | hd_strategy_crop_trigger_size=200,
74 | hd_strategy_resize_limit=200,
75 | )
76 | if config.sd_seed == -1:
77 | config.sd_seed = random.randint(1, 999999999)
78 |
79 | logger.info(f"Origin image shape: {original_shape}")
80 | image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
81 | logger.info(f"Resized image shape: {image.shape}")
82 |
83 | mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
84 |
85 | start = time.time()
86 |
87 | res_np_img = model(image, mask, config) # -----------------------导入模型
88 |
89 | return res_np_img
90 |
91 |
92 | def lamamain(image,mask,name):
93 | args = parse_args()
94 | global model
95 | global device
96 | global input_image_path
97 | global is_disable_model_switch
98 | global is_desktop
99 |
100 | device = torch.device(args.device)
101 | input_image_path = args.input
102 | is_disable_model_switch = args.disable_model_switch
103 | is_desktop = args.gui
104 | if is_disable_model_switch:
105 | logger.info(f"Start with --disable-model-switch, model switch on frontend is disable")
106 |
107 | model = ModelManager(
108 | name=name,
109 | device=device,
110 | hf_access_token=args.hf_access_token,
111 | sd_disable_nsfw=args.sd_disable_nsfw,
112 | sd_cpu_textencoder=args.sd_cpu_textencoder,
113 | sd_run_local=args.sd_run_local,
114 | callback=diffuser_callback,
115 | )
116 | image=process(image,mask)
117 | image=np.uint8(image)
118 | images = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
119 | return images
120 | # cv2.imwrite('imagemask.jpg', images)
121 | # print(images)
122 | if __name__ == '__main__':
123 | lamamain("example/image.jpg", "example/mask1.jpg", "lama")  # quick test with the bundled example image and mask
--------------------------------------------------------------------------------
/metaseg/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from metaseg.falai_demo import falai_automask_image, falai_manuelmask_image
8 | from metaseg.generator.automatic_mask_generator import SamAutomaticMaskGenerator
9 | from metaseg.generator.build_sam import build_sam, build_sam_vit_b, build_sam_vit_h, build_sam_vit_l, sam_model_registry
10 | from metaseg.generator.predictor import SamPredictor
11 | from metaseg.mask_predictor import SegAutoMaskPredictor, SegManualMaskPredictor
12 | from metaseg.sahi_predict import SahiAutoSegmentation, sahi_sliced_predict
13 |
14 | __version__ = "0.7.3"
15 |
--------------------------------------------------------------------------------
/metaseg/falai_demo.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 |
3 | from PIL import Image
4 |
5 | from metaseg.mask_predictor import SegAutoMaskPredictor, SegManualMaskPredictor
6 | from metaseg.utils.data_utils import load_server_image
7 |
8 | try:
9 | from fal_serverless import isolated
10 | except ImportError:
11 | raise ImportError("Please install FalAI library using 'pip install fal_serverless'.")
12 |
13 |
14 | @isolated(requirements=["metaseg"], keep_alive=1800, machine_type="GPU-T4")
15 | def automask_image(data, model_type="vit_b", points_per_side=16, points_per_batch=32, min_area=0):
16 | image_path, output_path = load_server_image(data)
17 | SegAutoMaskPredictor().image_predict(
18 | source=image_path,
19 | model_type=model_type,
20 | points_per_side=points_per_side,
21 | points_per_batch=points_per_batch,
22 | min_area=min_area,
23 | output_path=output_path,
24 | show=False,
25 | save=True,
26 | )
27 | with open(output_path, "rb") as f:
28 | result = f.read()
29 |
30 | return result
31 |
32 | @isolated(requirements=["metaseg"], keep_alive=1800, machine_type="GPU-T4")
33 | def manuelmask_image(
34 | data,
35 | model_type="vit_b",
36 | input_point=[[100, 100], [200, 200]],
37 | input_label=[0, 1],
38 | input_box=[100, 100, 200, 200],
39 | multimask_output=False,
40 | random_color=False,
41 | min_area=0,
42 | ):
43 | image_path, output_path = load_server_image(data)
44 | SegManualMaskPredictor().image_predict(
45 | source=image_path,
46 | model_type=model_type,
47 | input_point=input_point,
48 | input_label=input_label,
49 | input_box=input_box, #
50 | multimask_output=multimask_output,
51 | random_color=random_color,
52 | min_area=min_area, #
53 | output_path=output_path,
54 | show=False,
55 | save=True,
56 | )
57 | with open(output_path, "rb") as f:
58 | result = f.read()
59 |
60 | return result
61 |
62 |
63 | def falai_automask_image(image_path, model_type="vit_b", points_per_side=16, points_per_batch=32, min_area=0):
64 | with open(image_path, "rb") as f:
65 | data = f.read()
66 |
67 | image = automask_image(
68 | data=data,
69 | model_type=model_type,
70 | points_per_side=points_per_side,
71 | points_per_batch=points_per_batch,
72 | min_area=min_area,
73 | )
74 | image = Image.open(BytesIO(image))
75 | return image
76 |
77 |
78 | def falai_manuelmask_image(
79 | image_path,
80 | model_type="vit_b",
81 | input_point=[[100, 100], [200, 200]],
82 | input_label=[0, 1],
83 | input_box=[100, 100, 200, 200],
84 | multimask_output=False,
85 | random_color=False,
86 | min_area=0,
87 | ):
88 | with open(image_path, "rb") as f:
89 | data = f.read()
90 |
91 | image = manuelmask_image(
92 | data=data,
93 | model_type=model_type,
94 | input_point=input_point,
95 | input_label=input_label,
96 | input_box=input_box,
97 | multimask_output=multimask_output,
98 | random_color=random_color,
99 | min_area=min_area,
100 | )
101 | image = Image.open(BytesIO(image))
102 | return image
103 |
--------------------------------------------------------------------------------
/metaseg/generator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/metaseg/generator/__init__.py
--------------------------------------------------------------------------------
/metaseg/generator/automatic_mask_generator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Any, Dict, List, Optional, Tuple
8 |
9 | import numpy as np
10 | import torch
11 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore
12 |
13 | from metaseg.generator.predictor import SamPredictor
14 | from metaseg.modeling import Sam
15 | from metaseg.utils.amg import (
16 | MaskData,
17 | area_from_rle,
18 | batch_iterator,
19 | batched_mask_to_box,
20 | box_xyxy_to_xywh,
21 | build_all_layer_point_grids,
22 | calculate_stability_score,
23 | coco_encode_rle,
24 | generate_crop_boxes,
25 | is_box_near_crop_edge,
26 | mask_to_rle_pytorch,
27 | remove_small_regions,
28 | rle_to_mask,
29 | uncrop_boxes_xyxy,
30 | uncrop_masks,
31 | uncrop_points,
32 | )
33 |
34 |
35 | class SamAutomaticMaskGenerator:
36 | def __init__(
37 | self,
38 | model: Sam,
39 | points_per_side: Optional[int] = 32,
40 | points_per_batch: int = 64,
41 | pred_iou_thresh: float = 0.88,
42 | stability_score_thresh: float = 0.95,
43 | stability_score_offset: float = 1.0,
44 | box_nms_thresh: float = 0.7,
45 | crop_n_layers: int = 0,
46 | crop_nms_thresh: float = 0.7,
47 | crop_overlap_ratio: float = 512 / 1500,
48 | crop_n_points_downscale_factor: int = 1,
49 | point_grids: Optional[List[np.ndarray]] = None,
50 | min_mask_region_area: int = 0,
51 | output_mode: str = "binary_mask",
52 | ) -> None:
53 | """
54 | Using a SAM model, generates masks for the entire image.
55 | Generates a grid of point prompts over the image, then filters
56 | low quality and duplicate masks. The default settings are chosen
57 | for SAM with a ViT-H backbone.
58 |
59 | Arguments:
60 | model (Sam): The SAM model to use for mask prediction.
61 | points_per_side (int or None): The number of points to be sampled
62 | along one side of the image. The total number of points is
63 | points_per_side**2. If None, 'point_grids' must provide explicit
64 | point sampling.
65 | points_per_batch (int): Sets the number of points run simultaneously
66 | by the model. Higher numbers may be faster but use more GPU memory.
67 | pred_iou_thresh (float): A filtering threshold in [0,1], using the
68 | model's predicted mask quality.
69 | stability_score_thresh (float): A filtering threshold in [0,1], using
70 | the stability of the mask under changes to the cutoff used to binarize
71 | the model's mask predictions.
72 | stability_score_offset (float): The amount to shift the cutoff when
73 | calculating the stability score.
74 | box_nms_thresh (float): The box IoU cutoff used by non-maximal
75 | suppression to filter duplicate masks.
76 | crop_n_layers (int): If >0, mask prediction will be run again on
77 | crops of the image. Sets the number of layers to run, where each
78 | layer has 2**i_layer number of image crops.
79 | crop_nms_thresh (float): The box IoU cutoff used by non-maximal
80 | suppression to filter duplicate masks between different crops.
81 | crop_overlap_ratio (float): Sets the degree to which crops overlap.
82 | In the first crop layer, crops will overlap by this fraction of
83 | the image length. Later layers with more crops scale down this overlap.
84 | crop_n_points_downscale_factor (int): The number of points-per-side
85 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
86 | point_grids (list(np.ndarray) or None): A list over explicit grids
87 | of points used for sampling, normalized to [0,1]. The nth grid in the
88 | list is used in the nth crop layer. Exclusive with points_per_side.
89 | min_mask_region_area (int): If >0, postprocessing will be applied
90 | to remove disconnected regions and holes in masks with area smaller
91 | than min_mask_region_area. Requires opencv.
92 | output_mode (str): The form masks are returned in. Can be 'binary_mask',
93 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
94 | For large resolutions, 'binary_mask' may consume large amounts of
95 | memory.
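96 |
97 | Example (sketch; assumes a local vit_b.pth checkpoint):
98 | from metaseg import SamAutomaticMaskGenerator, sam_model_registry
99 | sam = sam_model_registry["vit_b"](checkpoint="vit_b.pth")
100 | mask_generator = SamAutomaticMaskGenerator(sam)
101 | masks = mask_generator.generate(image)  # image: HWC uint8 np.ndarray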
96 | """
97 |
98 | assert (points_per_side is None) != (
99 | point_grids is None
100 | ), "Exactly one of points_per_side or point_grid must be provided."
101 | if points_per_side is not None:
102 | self.point_grids = build_all_layer_point_grids(
103 | points_per_side,
104 | crop_n_layers,
105 | crop_n_points_downscale_factor,
106 | )
107 | elif point_grids is not None:
108 | self.point_grids = point_grids
109 | else:
110 | raise ValueError("Can't have both points_per_side and point_grid be None.")
111 |
112 | assert output_mode in [
113 | "binary_mask",
114 | "uncompressed_rle",
115 | "coco_rle",
116 | ], f"Unknown output_mode {output_mode}."
117 | if output_mode == "coco_rle":
118 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401
119 |
120 | if min_mask_region_area > 0:
121 | import cv2 # type: ignore # noqa: F401
122 |
123 | self.predictor = SamPredictor(model)
124 | self.points_per_batch = points_per_batch
125 | self.pred_iou_thresh = pred_iou_thresh
126 | self.stability_score_thresh = stability_score_thresh
127 | self.stability_score_offset = stability_score_offset
128 | self.box_nms_thresh = box_nms_thresh
129 | self.crop_n_layers = crop_n_layers
130 | self.crop_nms_thresh = crop_nms_thresh
131 | self.crop_overlap_ratio = crop_overlap_ratio
132 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
133 | self.min_mask_region_area = min_mask_region_area
134 | self.output_mode = output_mode
135 |
136 | @torch.no_grad()
137 | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
138 | """
139 | Generates masks for the given image.
140 |
141 | Arguments:
142 | image (np.ndarray): The image to generate masks for, in HWC uint8 format.
143 |
144 | Returns:
145 | list(dict(str, any)): A list over records for masks. Each record is
146 | a dict containing the following keys:
147 | segmentation (dict(str, any) or np.ndarray): The mask. If
148 | output_mode='binary_mask', is an array of shape HW. Otherwise,
149 | is a dictionary containing the RLE.
150 | bbox (list(float)): The box around the mask, in XYWH format.
151 | area (int): The area in pixels of the mask.
152 | predicted_iou (float): The model's own prediction of the mask's
153 | quality. This is filtered by the pred_iou_thresh parameter.
154 | point_coords (list(list(float))): The point coordinates input
155 | to the model to generate this mask.
156 | stability_score (float): A measure of the mask's quality. This
157 | is filtered on using the stability_score_thresh parameter.
158 | crop_box (list(float)): The crop of the image used to generate
159 | the mask, given in XYWH format.
160 | """
161 |
162 | # Generate masks
163 | mask_data = self._generate_masks(image)
164 |
165 | # Filter small disconnected regions and holes in masks
166 | if self.min_mask_region_area > 0:
167 | mask_data = self.postprocess_small_regions(
168 | mask_data,
169 | self.min_mask_region_area,
170 | max(self.box_nms_thresh, self.crop_nms_thresh),
171 | )
172 |
173 | # Encode masks
174 | if self.output_mode == "coco_rle":
175 | mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
176 | elif self.output_mode == "binary_mask":
177 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
178 | else:
179 | mask_data["segmentations"] = mask_data["rles"]
180 |
181 | # Write mask records
182 | curr_anns = []
183 | for idx in range(len(mask_data["segmentations"])):
184 | ann = {
185 | "segmentation": mask_data["segmentations"][idx],
186 | "area": area_from_rle(mask_data["rles"][idx]),
187 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
188 | "predicted_iou": mask_data["iou_preds"][idx].item(),
189 | "point_coords": [mask_data["points"][idx].tolist()],
190 | "stability_score": mask_data["stability_score"][idx].item(),
191 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
192 | }
193 | curr_anns.append(ann)
194 |
195 | return curr_anns
196 |
197 | def _generate_masks(self, image: np.ndarray) -> MaskData:
198 | orig_size = image.shape[:2]
199 | crop_boxes, layer_idxs = generate_crop_boxes(orig_size, self.crop_n_layers, self.crop_overlap_ratio)
200 |
201 | # Iterate over image crops
202 | data = MaskData()
203 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
204 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
205 | data.cat(crop_data)
206 |
207 | # Remove duplicate masks between crops
208 | if len(crop_boxes) > 1:
209 | # Prefer masks from smaller crops
210 | scores = 1 / box_area(data["crop_boxes"])
211 | scores = scores.to(data["boxes"].device)
212 | keep_by_nms = batched_nms(
213 | data["boxes"].float(),
214 | scores,
215 | torch.zeros(len(data["boxes"])), # categories
216 | iou_threshold=self.crop_nms_thresh,
217 | )
218 | data.filter(keep_by_nms)
219 |
220 | data.to_numpy()
221 | return data
222 |
223 | def _process_crop(
224 | self,
225 | image: np.ndarray,
226 | crop_box: List[int],
227 | crop_layer_idx: int,
228 | orig_size: Tuple[int, ...],
229 | ) -> MaskData:
230 | # Crop the image and calculate embeddings
231 | x0, y0, x1, y1 = crop_box
232 | cropped_im = image[y0:y1, x0:x1, :]
233 | cropped_im_size = cropped_im.shape[:2]
234 | self.predictor.set_image(cropped_im)
235 |
236 | # Get points for this crop
237 | points_scale = np.array(cropped_im_size)[None, ::-1]
238 | points_for_image = self.point_grids[crop_layer_idx] * points_scale
239 |
240 | # Generate masks for this crop in batches
241 | data = MaskData()
242 | for (points,) in batch_iterator(self.points_per_batch, points_for_image):
243 | batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
244 | data.cat(batch_data)
245 | del batch_data
246 | self.predictor.reset_image()
247 |
248 | # Remove duplicates within this crop.
249 | keep_by_nms = batched_nms(
250 | data["boxes"].float(),
251 | data["iou_preds"],
252 | torch.zeros(len(data["boxes"])), # categories
253 | iou_threshold=self.box_nms_thresh,
254 | )
255 | data.filter(keep_by_nms)
256 |
257 | # Return to the original image frame
258 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
259 | data["points"] = uncrop_points(data["points"], crop_box)
260 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
261 |
262 | return data
263 |
264 | def _process_batch(
265 | self,
266 | points: np.ndarray,
267 | im_size: Tuple[int, ...],
268 | crop_box: List[int],
269 | orig_size: Tuple[int, ...],
270 | ) -> MaskData:
271 | orig_h, orig_w = orig_size
272 |
273 | # Run model on this batch
274 | transformed_points = self.predictor.transform.apply_coords(points, im_size)
275 | in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
276 | in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
277 | masks, iou_preds, _ = self.predictor.predict_torch(
278 | in_points[:, None, :],
279 | in_labels[:, None],
280 | multimask_output=True,
281 | return_logits=True,
282 | )
283 |
284 | # Serialize predictions and store in MaskData
285 | data = MaskData(
286 | masks=masks.flatten(0, 1),
287 | iou_preds=iou_preds.flatten(0, 1),
288 | points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
289 | )
290 | del masks
291 |
292 | # Filter by predicted IoU
293 | if self.pred_iou_thresh > 0.0:
294 | keep_mask = data["iou_preds"] > self.pred_iou_thresh
295 | data.filter(keep_mask)
296 |
297 | # Calculate stability score
298 | data["stability_score"] = calculate_stability_score(
299 | data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
300 | )
301 | if self.stability_score_thresh > 0.0:
302 | keep_mask = data["stability_score"] >= self.stability_score_thresh
303 | data.filter(keep_mask)
304 |
305 | # Threshold masks and calculate boxes
306 | data["masks"] = data["masks"] > self.predictor.model.mask_threshold
307 | data["boxes"] = batched_mask_to_box(data["masks"])
308 |
309 | # Filter boxes that touch crop boundaries
310 | keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
311 | if not torch.all(keep_mask):
312 | data.filter(keep_mask)
313 |
314 | # Compress to RLE
315 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
316 | data["rles"] = mask_to_rle_pytorch(data["masks"])
317 | del data["masks"]
318 |
319 | return data
320 |
321 | @staticmethod
322 | def postprocess_small_regions(mask_data: MaskData, min_area: int, nms_thresh: float) -> MaskData:
323 | """
324 | Removes small disconnected regions and holes in masks, then reruns
325 | box NMS to remove any new duplicates.
326 |
327 | Edits mask_data in place.
328 |
329 | Requires open-cv as a dependency.
330 | """
331 | if len(mask_data["rles"]) == 0:
332 | return mask_data
333 |
334 | # Filter small disconnected regions and holes
335 | new_masks = []
336 | scores = []
337 | for rle in mask_data["rles"]:
338 | mask = rle_to_mask(rle)
339 |
340 | mask, changed = remove_small_regions(mask, min_area, mode="holes")
341 | unchanged = not changed
342 | mask, changed = remove_small_regions(mask, min_area, mode="islands")
343 | unchanged = unchanged and not changed
344 |
345 | new_masks.append(torch.as_tensor(mask).unsqueeze(0))
346 | # Give score=0 to changed masks and score=1 to unchanged masks
347 | # so NMS will prefer ones that didn't need postprocessing
348 | scores.append(float(unchanged))
349 |
350 | # Recalculate boxes and remove any new duplicates
351 | masks = torch.cat(new_masks, dim=0)
352 | boxes = batched_mask_to_box(masks)
353 | keep_by_nms = batched_nms(
354 | boxes.float(),
355 | torch.as_tensor(scores),
356 | torch.zeros(len(boxes)), # categories
357 | iou_threshold=nms_thresh,
358 | )
359 |
360 | # Only recalculate RLEs for masks that have changed
361 | for i_mask in keep_by_nms:
362 | if scores[i_mask] == 0.0:
363 | mask_torch = masks[i_mask].unsqueeze(0)
364 | mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
365 | mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
366 | mask_data.filter(keep_by_nms)
367 |
368 | return mask_data
369 |
--------------------------------------------------------------------------------
/metaseg/generator/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from functools import partial
8 |
9 | import torch
10 |
11 | from metaseg.modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
12 |
13 |
14 | def build_sam_vit_h(checkpoint=None):
15 | return _build_sam(
16 | encoder_embed_dim=1280,
17 | encoder_depth=32,
18 | encoder_num_heads=16,
19 | encoder_global_attn_indexes=[7, 15, 23, 31],
20 | checkpoint=checkpoint,
21 | )
22 |
23 |
24 | build_sam = build_sam_vit_h
25 |
26 |
27 | def build_sam_vit_l(checkpoint=None):
28 | return _build_sam(
29 | encoder_embed_dim=1024,
30 | encoder_depth=24,
31 | encoder_num_heads=16,
32 | encoder_global_attn_indexes=[5, 11, 17, 23],
33 | checkpoint=checkpoint,
34 | )
35 |
36 |
37 | def build_sam_vit_b(checkpoint=None):
38 | return _build_sam(
39 | encoder_embed_dim=768,
40 | encoder_depth=12,
41 | encoder_num_heads=12,
42 | encoder_global_attn_indexes=[2, 5, 8, 11],
43 | checkpoint=checkpoint,
44 | )
45 |
46 |
47 | sam_model_registry = {
48 | "default": build_sam,
49 | "vit_h": build_sam,
50 | "vit_l": build_sam_vit_l,
51 | "vit_b": build_sam_vit_b,
52 | }
53 |
54 |
55 | def _build_sam(
56 | encoder_embed_dim,
57 | encoder_depth,
58 | encoder_num_heads,
59 | encoder_global_attn_indexes,
60 | checkpoint=None,
61 | ):
62 | prompt_embed_dim = 256
63 | image_size = 1024
64 | vit_patch_size = 16
65 | image_embedding_size = image_size // vit_patch_size
66 | sam = Sam(
67 | image_encoder=ImageEncoderViT(
68 | depth=encoder_depth,
69 | embed_dim=encoder_embed_dim,
70 | img_size=image_size,
71 | mlp_ratio=4,
72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
73 | num_heads=encoder_num_heads,
74 | patch_size=vit_patch_size,
75 | qkv_bias=True,
76 | use_rel_pos=True,
77 | global_attn_indexes=encoder_global_attn_indexes,
78 | window_size=14,
79 | out_chans=prompt_embed_dim,
80 | ),
81 | prompt_encoder=PromptEncoder(
82 | embed_dim=prompt_embed_dim,
83 | image_embedding_size=(image_embedding_size, image_embedding_size),
84 | input_image_size=(image_size, image_size),
85 | mask_in_chans=16,
86 | ),
87 | mask_decoder=MaskDecoder(
88 | num_multimask_outputs=3,
89 | transformer=TwoWayTransformer(
90 | depth=2,
91 | embedding_dim=prompt_embed_dim,
92 | mlp_dim=2048,
93 | num_heads=8,
94 | ),
95 | transformer_dim=prompt_embed_dim,
96 | iou_head_depth=3,
97 | iou_head_hidden_dim=256,
98 | ),
99 | pixel_mean=[123.675, 116.28, 103.53],
100 | pixel_std=[58.395, 57.12, 57.375],
101 | )
102 | sam.eval()
103 | if checkpoint is not None:
104 | with open(checkpoint, "rb") as f:
105 | state_dict = torch.load(f)
106 | sam.load_state_dict(state_dict)
107 | return sam
108 |
--------------------------------------------------------------------------------
/metaseg/generator/predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Optional, Tuple
8 |
9 | import numpy as np
10 | import torch
11 |
12 | from metaseg.modeling import Sam
13 | from metaseg.utils.transforms import ResizeLongestSide
14 |
15 |
16 | class SamPredictor:
17 | def __init__(
18 | self,
19 | sam_model: Sam,
20 | ) -> None:
21 | """
22 | Uses SAM to calculate the image embedding for an image, and then
23 | allow repeated, efficient mask prediction given prompts.
24 |
25 | Arguments:
26 | sam_model (Sam): The model to use for mask prediction.
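27 |
28 | Example (sketch; assumes a local vit_b.pth checkpoint):
29 | from metaseg import SamPredictor, sam_model_registry
30 | predictor = SamPredictor(sam_model_registry["vit_b"](checkpoint="vit_b.pth"))
31 | predictor.set_image(image)  # image: HWC uint8 RGB np.ndarray
32 | masks, scores, low_res_logits = predictor.predict(box=np.array([100, 100, 400, 400]))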
27 | """
28 | super().__init__()
29 | self.model = sam_model
30 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
31 | self.reset_image()
32 |
33 | def set_image(
34 | self,
35 | image: np.ndarray,
36 | image_format: str = "RGB",
37 | ) -> None:
38 | """
39 | Calculates the image embeddings for the provided image, allowing
40 | masks to be predicted with the 'predict' method.
41 |
42 | Arguments:
43 | image (np.ndarray): The image for calculating masks. Expects an
44 | image in HWC uint8 format, with pixel values in [0, 255].
45 | image_format (str): The color format of the image, in ['RGB', 'BGR'].
46 | """
47 | assert image_format in [
48 | "RGB",
49 | "BGR",
50 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
51 | if image_format != self.model.image_format:
52 | image = image[..., ::-1]
53 |
54 | # Transform the image to the form expected by the model
55 | input_image = self.transform.apply_image(image)
56 | input_image_torch = torch.as_tensor(input_image, device=self.device)
57 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
58 |
59 | self.set_torch_image(input_image_torch, image.shape[:2])
60 |
61 | @torch.no_grad()
62 | def set_torch_image(
63 | self,
64 | transformed_image: torch.Tensor,
65 | original_image_size: Tuple[int, ...],
66 | ) -> None:
67 | """
68 | Calculates the image embeddings for the provided image, allowing
69 | masks to be predicted with the 'predict' method. Expects the input
70 | image to be already transformed to the format expected by the model.
71 |
72 | Arguments:
73 | transformed_image (torch.Tensor): The input image, with shape
74 | 1x3xHxW, which has been transformed with ResizeLongestSide.
75 | original_image_size (tuple(int, int)): The size of the image
76 | before transformation, in (H, W) format.
77 | """
78 | assert (
79 | len(transformed_image.shape) == 4
80 | and transformed_image.shape[1] == 3
81 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
82 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
83 | self.reset_image()
84 |
85 | self.original_size = original_image_size
86 | self.input_size = tuple(transformed_image.shape[-2:])
87 | input_image = self.model.preprocess(transformed_image)
88 | self.features = self.model.image_encoder(input_image)
89 | self.is_image_set = True
90 |
91 | def predict(
92 | self,
93 | point_coords: Optional[np.ndarray] = None,
94 | point_labels: Optional[np.ndarray] = None,
95 | box: Optional[np.ndarray] = None,
96 | mask_input: Optional[np.ndarray] = None,
97 | multimask_output: bool = True,
98 | return_logits: bool = False,
99 |     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
100 | """
101 | Predict masks for the given input prompts, using the currently set image.
102 |
103 | Arguments:
104 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the
105 | model. Each point is in (X,Y) in pixels.
106 | point_labels (np.ndarray or None): A length N array of labels for the
107 | point prompts. 1 indicates a foreground point and 0 indicates a
108 | background point.
109 |             box (np.ndarray or None): A length 4 array giving a box prompt to the
110 | model, in XYXY format.
111 | mask_input (np.ndarray): A low resolution mask input to the model, typically
112 | coming from a previous prediction iteration. Has form 1xHxW, where
113 | for SAM, H=W=256.
114 | multimask_output (bool): If true, the model will return three masks.
115 | For ambiguous input prompts (such as a single click), this will often
116 | produce better masks than a single prediction. If only a single
117 | mask is needed, the model's predicted quality score can be used
118 | to select the best mask. For non-ambiguous prompts, such as multiple
119 | input prompts, multimask_output=False can give better results.
120 |             return_logits (bool): If true, returns un-thresholded mask logits
121 | instead of a binary mask.
122 |
123 | Returns:
124 | (np.ndarray): The output masks in CxHxW format, where C is the
125 | number of masks, and (H, W) is the original image size.
126 | (np.ndarray): An array of length C containing the model's
127 | predictions for the quality of each mask.
128 | (np.ndarray): An array of shape CxHxW, where C is the number
129 | of masks and H=W=256. These low resolution logits can be passed to
130 | a subsequent iteration as mask input.
131 | """
132 | if not self.is_image_set:
133 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
134 |
135 | # Transform input prompts
136 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
137 | if point_coords is not None:
138 | assert point_labels is not None, "point_labels must be supplied if point_coords is supplied."
139 | point_coords = self.transform.apply_coords(point_coords, self.original_size)
140 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
141 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
142 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
143 | if box is not None:
144 | box = self.transform.apply_boxes(box, self.original_size)
145 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
146 | box_torch = box_torch[None, :]
147 | if mask_input is not None:
148 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
149 | mask_input_torch = mask_input_torch[None, :, :, :]
150 |
151 | masks, iou_predictions, low_res_masks = self.predict_torch(
152 | coords_torch,
153 | labels_torch,
154 | box_torch,
155 | mask_input_torch,
156 | multimask_output,
157 | return_logits=return_logits,
158 | )
159 |
160 | masks = masks[0].detach().cpu().numpy()
161 | iou_predictions = iou_predictions[0].detach().cpu().numpy()
162 | low_res_masks = low_res_masks[0].detach().cpu().numpy()
163 | return masks, iou_predictions, low_res_masks
164 |
165 | @torch.no_grad()
166 | def predict_torch(
167 | self,
168 | point_coords: Optional[torch.Tensor],
169 | point_labels: Optional[torch.Tensor],
170 | boxes: Optional[torch.Tensor] = None,
171 | mask_input: Optional[torch.Tensor] = None,
172 | multimask_output: bool = True,
173 | return_logits: bool = False,
174 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
175 | """
176 | Predict masks for the given input prompts, using the currently set image.
177 | Input prompts are batched torch tensors and are expected to already be
178 | transformed to the input frame using ResizeLongestSide.
179 |
180 | Arguments:
181 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
182 | model. Each point is in (X,Y) in pixels.
183 | point_labels (torch.Tensor or None): A BxN array of labels for the
184 | point prompts. 1 indicates a foreground point and 0 indicates a
185 | background point.
186 |             boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
187 | model, in XYXY format.
188 |             mask_input (torch.Tensor or None): A low resolution mask input to the model, typically
189 | coming from a previous prediction iteration. Has form Bx1xHxW, where
190 | for SAM, H=W=256. Masks returned by a previous iteration of the
191 | predict method do not need further transformation.
192 | multimask_output (bool): If true, the model will return three masks.
193 | For ambiguous input prompts (such as a single click), this will often
194 | produce better masks than a single prediction. If only a single
195 | mask is needed, the model's predicted quality score can be used
196 | to select the best mask. For non-ambiguous prompts, such as multiple
197 | input prompts, multimask_output=False can give better results.
198 |             return_logits (bool): If true, returns un-thresholded mask logits
199 | instead of a binary mask.
200 |
201 | Returns:
202 | (torch.Tensor): The output masks in BxCxHxW format, where C is the
203 | number of masks, and (H, W) is the original image size.
204 | (torch.Tensor): An array of shape BxC containing the model's
205 | predictions for the quality of each mask.
206 | (torch.Tensor): An array of shape BxCxHxW, where C is the number
207 | of masks and H=W=256. These low res logits can be passed to
208 | a subsequent iteration as mask input.
209 | """
210 | if not self.is_image_set:
211 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
212 |
213 | if point_coords is not None:
214 | points = (point_coords, point_labels)
215 | else:
216 | points = None
217 |
218 | # Embed prompts
219 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
220 | points=points,
221 | boxes=boxes,
222 | masks=mask_input,
223 | )
224 |
225 | # Predict masks
226 | low_res_masks, iou_predictions = self.model.mask_decoder(
227 | image_embeddings=self.features,
228 | image_pe=self.model.prompt_encoder.get_dense_pe(),
229 | sparse_prompt_embeddings=sparse_embeddings,
230 | dense_prompt_embeddings=dense_embeddings,
231 | multimask_output=multimask_output,
232 | )
233 |
234 | # Upscale the masks to the original image resolution
235 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
236 |
237 | if not return_logits:
238 | masks = masks > self.model.mask_threshold
239 |
240 | return masks, iou_predictions, low_res_masks
241 |
242 | def get_image_embedding(self) -> torch.Tensor:
243 | """
244 | Returns the image embeddings for the currently set image, with
245 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are
246 | the embedding spatial dimension of SAM (typically C=256, H=W=64).
247 | """
248 | if not self.is_image_set:
249 | raise RuntimeError("An image must be set with .set_image(...) to generate an embedding.")
250 | assert self.features is not None, "Features must exist if an image has been set."
251 | return self.features
252 |
253 | @property
254 | def device(self) -> torch.device:
255 | return self.model.device
256 |
257 | def reset_image(self) -> None:
258 | """Resets the currently set image."""
259 | self.is_image_set = False
260 | self.features = None
261 | self.orig_h = None
262 | self.orig_w = None
263 | self.input_h = None
264 | self.input_w = None
265 |
--------------------------------------------------------------------------------
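A minimal usage sketch for SamPredictor (not part of the repository files). The "vit_b" registry key, the checkpoint filename and the image path are assumptions; substitute whichever model and files you actually use.

import cv2
import numpy as np

from metaseg.generator.build_sam import sam_model_registry
from metaseg.generator.predictor import SamPredictor

# Build SAM and wrap it in the predictor ("vit_b" / checkpoint path are assumptions).
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")
predictor = SamPredictor(sam)

# OpenCV loads BGR, so pass image_format="BGR"; the image embedding is computed once here.
image = cv2.imread("image.jpg")
predictor.set_image(image, image_format="BGR")

# One foreground click at pixel (x, y) = (500, 375); label 1 marks a foreground point.
masks, scores, low_res_logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,
)
print(masks.shape, scores.shape)  # (3, H, W) and (3,) for a single ambiguous click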
/metaseg/mask_predictor.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from tqdm import tqdm
7 |
8 | from metaseg.generator.automatic_mask_generator import SamAutomaticMaskGenerator
9 | from metaseg.generator.predictor import SamPredictor
10 | from metaseg.generator.build_sam import sam_model_registry
11 |
12 | from metaseg.utils import (
13 | download_model,
14 | load_box,
15 | load_image,
16 | load_mask,
17 | load_video,
18 | multi_boxes,
19 | save_image,
20 | show_image,
21 | )
22 |
23 |
24 | class SegAutoMaskPredictor:
25 | def __init__(self):
26 | self.model = None
27 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
28 |
29 | def load_model(self, model_type):
30 | if self.model is None:
31 | self.model_path = download_model(model_type)
32 | self.model = sam_model_registry[model_type](checkpoint=self.model_path)
33 | self.model.to(device=self.device)
34 |
35 | return self.model
36 |
37 | def image_predict(
38 | self,
39 | source,
40 | model_type,
41 | points_per_side,
42 | points_per_batch,
43 | min_area,
44 | output_path="output.png",
45 | show=False,
46 | save=False,
47 | ):
48 | read_image = load_image(source)
49 | model = self.load_model(model_type)
50 | mask_generator = SamAutomaticMaskGenerator(
51 | model, points_per_side=points_per_side, points_per_batch=points_per_batch, min_mask_region_area=min_area
52 | )
53 |
54 | masks = mask_generator.generate(read_image)
55 |
56 | sorted_anns = sorted(masks, key=(lambda x: x["area"]), reverse=True)
57 | mask_image = np.zeros((masks[0]["segmentation"].shape[0], masks[0]["segmentation"].shape[1], 3), dtype=np.uint8)
58 | colors = np.random.randint(0, 255, size=(256, 3), dtype=np.uint8)
59 | for i, ann in enumerate(sorted_anns):
60 | m = ann["segmentation"]
61 | img = np.ones((m.shape[0], m.shape[1], 3), dtype=np.uint8)
62 | color = colors[i % 256]
63 |             # Paint this mask's color onto all three channels
64 |             img[:, :, 0] = color[0]
65 |             img[:, :, 1] = color[1]
66 |             img[:, :, 2] = color[2]
67 | img = cv2.bitwise_and(img, img, mask=m.astype(np.uint8))
68 | img = cv2.addWeighted(img, 0.35, np.zeros_like(img), 0.65, 0)
69 | mask_image = cv2.add(mask_image, img)
70 |
71 | combined_mask = cv2.add(read_image, mask_image)
72 | self.combined_mask = combined_mask
73 | if show:
74 | show_image(combined_mask)
75 |
76 | if save:
77 | save_image(output_path=output_path, output_image=combined_mask)
78 |
79 | return masks
80 |
81 | def video_predict(self, source, model_type, points_per_side, points_per_batch, min_area, output_path="output.mp4"):
82 | cap, out = load_video(source, output_path)
83 | length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
84 | colors = np.random.randint(0, 255, size=(256, 3), dtype=np.uint8)
85 |
86 | for _ in tqdm(range(length)):
87 | ret, frame = cap.read()
88 | if not ret:
89 | break
90 |
91 | model = self.load_model(model_type)
92 | mask_generator = SamAutomaticMaskGenerator(
93 | model, points_per_side=points_per_side, points_per_batch=points_per_batch, min_mask_region_area=min_area
94 | )
95 | masks = mask_generator.generate(frame)
96 |
97 | if len(masks) == 0:
98 | continue
99 |
100 | sorted_anns = sorted(masks, key=(lambda x: x["area"]), reverse=True)
101 | mask_image = np.zeros(
102 | (masks[0]["segmentation"].shape[0], masks[0]["segmentation"].shape[1], 3), dtype=np.uint8
103 | )
104 |
105 | for i, ann in enumerate(sorted_anns):
106 | m = ann["segmentation"]
107 | color = colors[i % 256]
108 | img = np.zeros((m.shape[0], m.shape[1], 3), dtype=np.uint8)
109 | img[:, :, 0] = color[0]
110 | img[:, :, 1] = color[1]
111 | img[:, :, 2] = color[2]
112 | img = cv2.bitwise_and(img, img, mask=m.astype(np.uint8))
113 | img = cv2.addWeighted(img, 0.35, np.zeros_like(img), 0.65, 0)
114 | mask_image = cv2.add(mask_image, img)
115 |
116 | combined_mask = cv2.add(frame, mask_image)
117 | out.write(combined_mask)
118 |
119 | out.release()
120 | cap.release()
121 | cv2.destroyAllWindows()
122 |
123 | return output_path
124 |
125 |
126 | class SegManualMaskPredictor:
127 | def __init__(self):
128 | self.model = None
129 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
130 |
131 | def load_model(self, model_type):
132 | if self.model is None:
133 | self.model_path = download_model(model_type)
134 | self.model = sam_model_registry[model_type](checkpoint=self.model_path)
135 | self.model.to(device=self.device)
136 |
137 | return self.model
138 |
139 | def image_predict(
140 | self,
141 | source,
142 | model_type,
143 | input_box=None,
144 | input_point=None,
145 | input_label=None,
146 | multimask_output=False,
147 | output_path="output.png",
148 | random_color=False,
149 | show=False,
150 | save=False,
151 | ):
152 | image = load_image(source)
153 | model = self.load_model(model_type)
154 | predictor = SamPredictor(model)
155 | predictor.set_image(image)
156 |
157 |         if isinstance(input_box[0], list):
158 | input_boxes, new_boxes = multi_boxes(input_box, predictor, image)
159 |
160 | masks, _, _ = predictor.predict_torch(
161 | point_coords=None,
162 | point_labels=None,
163 | boxes=new_boxes,
164 | multimask_output=False,
165 | )
166 | for mask in masks:
167 | mask_image = load_mask(mask.cpu().numpy(), random_color)
168 |
169 | for box in input_boxes:
170 | image = load_box(box.cpu().numpy(), image)
171 |
172 |         elif isinstance(input_box[0], int):
173 | input_boxes = np.array(input_box)[None, :]
174 |
175 | masks, _, _ = predictor.predict(
176 | point_coords=input_point,
177 | point_labels=input_label,
178 | box=input_boxes,
179 | multimask_output=multimask_output,
180 | )
181 | mask_image = load_mask(masks, random_color)
182 | image = load_box(input_box, image)
183 |
184 | combined_mask = cv2.add(image, mask_image)
185 | if save:
186 | save_image(output_path=output_path, output_image=combined_mask)
187 |
188 | if show:
189 | show_image(combined_mask)
190 |
191 | return masks
192 |
193 | def video_predict(
194 | self,
195 | source,
196 | model_type,
197 | input_box=None,
198 | input_point=None,
199 | input_label=None,
200 | multimask_output=False,
201 | output_path="output.mp4",
202 | random_color=False,
203 | ):
204 | cap, out = load_video(source, output_path)
205 | length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
206 |
207 | for _ in tqdm(range(length)):
208 | ret, frame = cap.read()
209 | if not ret:
210 | break
211 |
212 | model = self.load_model(model_type)
213 | predictor = SamPredictor(model)
214 | predictor.set_image(frame)
215 |
216 |             if isinstance(input_box[0], list):
217 | input_boxes, new_boxes = multi_boxes(input_box, predictor, frame)
218 |
219 | masks, _, _ = predictor.predict_torch(
220 | point_coords=None,
221 | point_labels=None,
222 | boxes=new_boxes,
223 | multimask_output=False,
224 | )
225 | for mask in masks:
226 | mask_image = load_mask(mask.cpu().numpy(), random_color)
227 |
228 | for box in input_boxes:
229 | frame = load_box(box.cpu().numpy(), frame)
230 |
231 |             elif isinstance(input_box[0], int):
232 | input_boxes = np.array(input_box)[None, :]
233 |
234 | masks, _, _ = predictor.predict(
235 | point_coords=input_point,
236 | point_labels=input_label,
237 | box=input_boxes,
238 | multimask_output=multimask_output,
239 | )
240 | mask_image = load_mask(masks, random_color)
241 | frame = load_box(input_box, frame)
242 |
243 | combined_mask = cv2.add(frame, mask_image)
244 | out.write(combined_mask)
245 |
246 | out.release()
247 | cap.release()
248 | cv2.destroyAllWindows()
249 | return output_path
250 |
--------------------------------------------------------------------------------
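The two wrappers above are the main entry points of the package; a short sketch of how they are typically called follows. The "vit_b" model key and the file paths are assumptions (the checkpoint itself is fetched by download_model on first use).

import numpy as np

from metaseg.mask_predictor import SegAutoMaskPredictor, SegManualMaskPredictor

# Fully automatic segmentation of one image.
SegAutoMaskPredictor().image_predict(
    source="image.jpg",          # assumed input path
    model_type="vit_b",          # assumed registry key
    points_per_side=16,
    points_per_batch=64,
    min_area=0,
    output_path="auto_mask.png",
    show=False,
    save=True,
)

# Prompted segmentation with one XYXY box plus a single foreground point.
SegManualMaskPredictor().image_predict(
    source="image.jpg",
    model_type="vit_b",
    input_box=[100, 100, 400, 400],
    input_point=np.array([[250, 250]]),
    input_label=np.array([1]),
    multimask_output=False,
    output_path="manual_mask.png",
    save=True,
)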
/metaseg/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from metaseg.modeling.image_encoder import ImageEncoderViT
8 | from metaseg.modeling.mask_decoder import MaskDecoder
9 | from metaseg.modeling.prompt_encoder import PromptEncoder
10 | from metaseg.modeling.sam import Sam
11 | from metaseg.modeling.transformer import TwoWayTransformer
12 |
--------------------------------------------------------------------------------
/metaseg/modeling/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Type
8 |
9 | import torch
10 | import torch.nn as nn
11 |
12 |
13 | class MLPBlock(nn.Module):
14 | def __init__(
15 | self,
16 | embedding_dim: int,
17 | mlp_dim: int,
18 | act: Type[nn.Module] = nn.GELU,
19 | ) -> None:
20 | super().__init__()
21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim)
22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim)
23 | self.act = act()
24 |
25 | def forward(self, x: torch.Tensor) -> torch.Tensor:
26 | return self.lin2(self.act(self.lin1(x)))
27 |
28 |
29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
31 | class LayerNorm2d(nn.Module):
32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
33 | super().__init__()
34 | self.weight = nn.Parameter(torch.ones(num_channels))
35 | self.bias = nn.Parameter(torch.zeros(num_channels))
36 | self.eps = eps
37 |
38 | def forward(self, x: torch.Tensor) -> torch.Tensor:
39 | u = x.mean(1, keepdim=True)
40 | s = (x - u).pow(2).mean(1, keepdim=True)
41 | x = (x - u) / torch.sqrt(s + self.eps)
42 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
43 | return x
44 |
--------------------------------------------------------------------------------
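A quick shape check for the two building blocks defined above; nothing beyond the code in common.py is assumed.

import torch

from metaseg.modeling.common import LayerNorm2d, MLPBlock

x = torch.randn(2, 256, 64, 64)                                  # NCHW feature map
print(LayerNorm2d(256)(x).shape)                                 # torch.Size([2, 256, 64, 64]), normalized over channels

tokens = torch.randn(2, 10, 256)                                 # batch of 10 tokens with 256 channels
print(MLPBlock(embedding_dim=256, mlp_dim=2048)(tokens).shape)   # torch.Size([2, 10, 256])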
/metaseg/modeling/image_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Optional, Tuple, Type
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from metaseg.modeling.common import LayerNorm2d, MLPBlock
14 |
15 |
16 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
17 | class ImageEncoderViT(nn.Module):
18 | def __init__(
19 | self,
20 | img_size: int = 1024,
21 | patch_size: int = 16,
22 | in_chans: int = 3,
23 | embed_dim: int = 768,
24 | depth: int = 12,
25 | num_heads: int = 12,
26 | mlp_ratio: float = 4.0,
27 | out_chans: int = 256,
28 | qkv_bias: bool = True,
29 | norm_layer: Type[nn.Module] = nn.LayerNorm,
30 | act_layer: Type[nn.Module] = nn.GELU,
31 | use_abs_pos: bool = True,
32 | use_rel_pos: bool = False,
33 | rel_pos_zero_init: bool = True,
34 | window_size: int = 0,
35 | global_attn_indexes: Tuple[int, ...] = (),
36 | ) -> None:
37 | """
38 | Args:
39 | img_size (int): Input image size.
40 | patch_size (int): Patch size.
41 | in_chans (int): Number of input image channels.
42 | embed_dim (int): Patch embedding dimension.
43 | depth (int): Depth of ViT.
44 | num_heads (int): Number of attention heads in each ViT block.
45 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
46 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
47 | norm_layer (nn.Module): Normalization layer.
48 | act_layer (nn.Module): Activation layer.
49 | use_abs_pos (bool): If True, use absolute positional embeddings.
50 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
51 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
52 | window_size (int): Window size for window attention blocks.
53 | global_attn_indexes (list): Indexes for blocks using global attention.
54 | """
55 | super().__init__()
56 | self.img_size = img_size
57 |
58 | self.patch_embed = PatchEmbed(
59 | kernel_size=(patch_size, patch_size),
60 | stride=(patch_size, patch_size),
61 | in_chans=in_chans,
62 | embed_dim=embed_dim,
63 | )
64 |
65 | self.pos_embed: Optional[nn.Parameter] = None
66 | if use_abs_pos:
67 | # Initialize absolute positional embedding with pretrain image size.
68 | self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))
69 |
70 | self.blocks = nn.ModuleList()
71 | for i in range(depth):
72 | block = Block(
73 | dim=embed_dim,
74 | num_heads=num_heads,
75 | mlp_ratio=mlp_ratio,
76 | qkv_bias=qkv_bias,
77 | norm_layer=norm_layer,
78 | act_layer=act_layer,
79 | use_rel_pos=use_rel_pos,
80 | rel_pos_zero_init=rel_pos_zero_init,
81 | window_size=window_size if i not in global_attn_indexes else 0,
82 | input_size=(img_size // patch_size, img_size // patch_size),
83 | )
84 | self.blocks.append(block)
85 |
86 | self.neck = nn.Sequential(
87 | nn.Conv2d(
88 | embed_dim,
89 | out_chans,
90 | kernel_size=1,
91 | bias=False,
92 | ),
93 | LayerNorm2d(out_chans),
94 | nn.Conv2d(
95 | out_chans,
96 | out_chans,
97 | kernel_size=3,
98 | padding=1,
99 | bias=False,
100 | ),
101 | LayerNorm2d(out_chans),
102 | )
103 |
104 | def forward(self, x: torch.Tensor) -> torch.Tensor:
105 | x = self.patch_embed(x)
106 | if self.pos_embed is not None:
107 | x = x + self.pos_embed
108 |
109 | for blk in self.blocks:
110 | x = blk(x)
111 |
112 | x = self.neck(x.permute(0, 3, 1, 2))
113 |
114 | return x
115 |
116 |
117 | class Block(nn.Module):
118 | """Transformer blocks with support of window attention and residual propagation blocks"""
119 |
120 | def __init__(
121 | self,
122 | dim: int,
123 | num_heads: int,
124 | mlp_ratio: float = 4.0,
125 | qkv_bias: bool = True,
126 | norm_layer: Type[nn.Module] = nn.LayerNorm,
127 | act_layer: Type[nn.Module] = nn.GELU,
128 | use_rel_pos: bool = False,
129 | rel_pos_zero_init: bool = True,
130 | window_size: int = 0,
131 | input_size: Optional[Tuple[int, int]] = None,
132 | ) -> None:
133 | """
134 | Args:
135 | dim (int): Number of input channels.
136 | num_heads (int): Number of attention heads in each ViT block.
137 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
138 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
139 | norm_layer (nn.Module): Normalization layer.
140 | act_layer (nn.Module): Activation layer.
141 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
142 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
143 | window_size (int): Window size for window attention blocks. If it equals 0, then
144 | use global attention.
145 |             input_size (tuple(int, int) or None): Input resolution for calculating the relative positional
146 | parameter size.
147 | """
148 | super().__init__()
149 | self.norm1 = norm_layer(dim)
150 | self.attn = Attention(
151 | dim,
152 | num_heads=num_heads,
153 | qkv_bias=qkv_bias,
154 | use_rel_pos=use_rel_pos,
155 | rel_pos_zero_init=rel_pos_zero_init,
156 | input_size=input_size if window_size == 0 else (window_size, window_size),
157 | )
158 |
159 | self.norm2 = norm_layer(dim)
160 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
161 |
162 | self.window_size = window_size
163 |
164 | def forward(self, x: torch.Tensor) -> torch.Tensor:
165 | shortcut = x
166 | x = self.norm1(x)
167 | # Window partition
168 | if self.window_size > 0:
169 | H, W = x.shape[1], x.shape[2]
170 | x, pad_hw = window_partition(x, self.window_size)
171 |
172 | x = self.attn(x)
173 | # Reverse window partition
174 | if self.window_size > 0:
175 | x = window_unpartition(x, self.window_size, pad_hw, (H, W))
176 |
177 | x = shortcut + x
178 | x = x + self.mlp(self.norm2(x))
179 |
180 | return x
181 |
182 |
183 | class Attention(nn.Module):
184 | """Multi-head Attention block with relative position embeddings."""
185 |
186 | def __init__(
187 | self,
188 | dim: int,
189 | num_heads: int = 8,
190 | qkv_bias: bool = True,
191 | use_rel_pos: bool = False,
192 | rel_pos_zero_init: bool = True,
193 | input_size: Optional[Tuple[int, int]] = None,
194 | ) -> None:
195 | """
196 | Args:
197 | dim (int): Number of input channels.
198 | num_heads (int): Number of attention heads.
199 |             qkv_bias (bool): If True, add a learnable bias to query, key, value.
200 |             use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
201 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
202 |             input_size (tuple(int, int) or None): Input resolution for calculating the relative positional
203 | parameter size.
204 | """
205 | super().__init__()
206 | self.num_heads = num_heads
207 | head_dim = dim // num_heads
208 | self.scale = head_dim ** -0.5
209 |
210 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
211 | self.proj = nn.Linear(dim, dim)
212 |
213 | self.use_rel_pos = use_rel_pos
214 | if self.use_rel_pos:
215 | assert input_size is not None, "Input size must be provided if using relative positional encoding."
216 | # initialize relative positional embeddings
217 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
218 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
219 |
220 | def forward(self, x: torch.Tensor) -> torch.Tensor:
221 | B, H, W, _ = x.shape
222 | # qkv with shape (3, B, nHead, H * W, C)
223 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
224 | # q, k, v with shape (B * nHead, H * W, C)
225 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
226 |
227 | attn = (q * self.scale) @ k.transpose(-2, -1)
228 |
229 | if self.use_rel_pos:
230 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
231 |
232 | attn = attn.softmax(dim=-1)
233 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
234 | x = self.proj(x)
235 |
236 | return x
237 |
238 |
239 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
240 | """
241 | Partition into non-overlapping windows with padding if needed.
242 | Args:
243 | x (tensor): input tokens with [B, H, W, C].
244 | window_size (int): window size.
245 |
246 | Returns:
247 | windows: windows after partition with [B * num_windows, window_size, window_size, C].
248 | (Hp, Wp): padded height and width before partition
249 | """
250 | B, H, W, C = x.shape
251 |
252 | pad_h = (window_size - H % window_size) % window_size
253 | pad_w = (window_size - W % window_size) % window_size
254 | if pad_h > 0 or pad_w > 0:
255 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
256 | Hp, Wp = H + pad_h, W + pad_w
257 |
258 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
259 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
260 | return windows, (Hp, Wp)
261 |
262 |
263 | def window_unpartition(
264 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
265 | ) -> torch.Tensor:
266 | """
267 |     Window unpartition into original sequences and remove padding.
268 | Args:
269 |         windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
270 | window_size (int): window size.
271 | pad_hw (Tuple): padded height and width (Hp, Wp).
272 | hw (Tuple): original height and width (H, W) before padding.
273 |
274 | Returns:
275 | x: unpartitioned sequences with [B, H, W, C].
276 | """
277 | Hp, Wp = pad_hw
278 | H, W = hw
279 | B = windows.shape[0] // (Hp * Wp // window_size // window_size)
280 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
281 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
282 |
283 | if Hp > H or Wp > W:
284 | x = x[:, :H, :W, :].contiguous()
285 | return x
286 |
287 |
288 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
289 | """
290 | Get relative positional embeddings according to the relative positions of
291 | query and key sizes.
292 | Args:
293 | q_size (int): size of query q.
294 | k_size (int): size of key k.
295 | rel_pos (Tensor): relative position embeddings (L, C).
296 |
297 | Returns:
298 | Extracted positional embeddings according to relative positions.
299 | """
300 | max_rel_dist = int(2 * max(q_size, k_size) - 1)
301 | # Interpolate rel pos if needed.
302 | if rel_pos.shape[0] != max_rel_dist:
303 | # Interpolate rel pos.
304 | rel_pos_resized = F.interpolate(
305 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
306 | size=max_rel_dist,
307 | mode="linear",
308 | )
309 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
310 | else:
311 | rel_pos_resized = rel_pos
312 |
313 | # Scale the coords with short length if shapes for q and k are different.
314 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
315 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
316 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
317 |
318 | return rel_pos_resized[relative_coords.long()]
319 |
320 |
321 | def add_decomposed_rel_pos(
322 | attn: torch.Tensor,
323 | q: torch.Tensor,
324 | rel_pos_h: torch.Tensor,
325 | rel_pos_w: torch.Tensor,
326 | q_size: Tuple[int, int],
327 | k_size: Tuple[int, int],
328 | ) -> torch.Tensor:
329 | """
330 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
331 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
332 | Args:
333 | attn (Tensor): attention map.
334 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
335 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
336 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
337 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
338 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
339 |
340 | Returns:
341 | attn (Tensor): attention map with added relative positional embeddings.
342 | """
343 | q_h, q_w = q_size
344 | k_h, k_w = k_size
345 | Rh = get_rel_pos(q_h, k_h, rel_pos_h)
346 | Rw = get_rel_pos(q_w, k_w, rel_pos_w)
347 |
348 | B, _, dim = q.shape
349 | r_q = q.reshape(B, q_h, q_w, dim)
350 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
351 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
352 |
353 | attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
354 | B, q_h * q_w, k_h * k_w
355 | )
356 |
357 | return attn
358 |
359 |
360 | class PatchEmbed(nn.Module):
361 | """
362 | Image to Patch Embedding.
363 | """
364 |
365 | def __init__(
366 | self,
367 | kernel_size: Tuple[int, int] = (16, 16),
368 | stride: Tuple[int, int] = (16, 16),
369 | padding: Tuple[int, int] = (0, 0),
370 | in_chans: int = 3,
371 | embed_dim: int = 768,
372 | ) -> None:
373 | """
374 | Args:
375 | kernel_size (Tuple): kernel size of the projection layer.
376 | stride (Tuple): stride of the projection layer.
377 | padding (Tuple): padding size of the projection layer.
378 | in_chans (int): Number of input image channels.
379 |             embed_dim (int): Patch embedding dimension.
380 | """
381 | super().__init__()
382 |
383 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
384 |
385 | def forward(self, x: torch.Tensor) -> torch.Tensor:
386 | x = self.proj(x)
387 | # B C H W -> B H W C
388 | x = x.permute(0, 2, 3, 1)
389 | return x
390 |
--------------------------------------------------------------------------------
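A shape-level sketch of the encoder with a deliberately tiny configuration (the released SAM encoders use img_size=1024 and are far deeper); all hyper-parameter values here are illustrative assumptions.

import torch

from metaseg.modeling.image_encoder import ImageEncoderViT

encoder = ImageEncoderViT(
    img_size=256,                # SAM itself uses 1024, giving a 64x64 feature grid
    patch_size=16,
    embed_dim=96,
    depth=2,
    num_heads=3,
    out_chans=256,
    window_size=14,
    global_attn_indexes=(1,),    # only the last block attends globally
)
x = torch.randn(1, 3, 256, 256)  # padded, normalized square input
with torch.no_grad():
    features = encoder(x)
print(features.shape)            # torch.Size([1, 256, 16, 16]): 256/16 patches per side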
/metaseg/modeling/mask_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import List, Tuple, Type
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 |
13 | from metaseg.modeling.common import LayerNorm2d
14 |
15 |
16 | class MaskDecoder(nn.Module):
17 | def __init__(
18 | self,
19 | *,
20 | transformer_dim: int,
21 | transformer: nn.Module,
22 | num_multimask_outputs: int = 3,
23 | activation: Type[nn.Module] = nn.GELU,
24 | iou_head_depth: int = 3,
25 | iou_head_hidden_dim: int = 256,
26 | ) -> None:
27 | """
28 | Predicts masks given an image and prompt embeddings, using a
29 |         transformer architecture.
30 |
31 | Arguments:
32 | transformer_dim (int): the channel dimension of the transformer
33 | transformer (nn.Module): the transformer used to predict masks
34 | num_multimask_outputs (int): the number of masks to predict
35 | when disambiguating masks
36 | activation (nn.Module): the type of activation to use when
37 | upscaling masks
38 | iou_head_depth (int): the depth of the MLP used to predict
39 | mask quality
40 | iou_head_hidden_dim (int): the hidden dimension of the MLP
41 | used to predict mask quality
42 | """
43 | super().__init__()
44 | self.transformer_dim = transformer_dim
45 | self.transformer = transformer
46 |
47 | self.num_multimask_outputs = num_multimask_outputs
48 |
49 | self.iou_token = nn.Embedding(1, transformer_dim)
50 | self.num_mask_tokens = num_multimask_outputs + 1
51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
52 |
53 | self.output_upscaling = nn.Sequential(
54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
55 | LayerNorm2d(transformer_dim // 4),
56 | activation(),
57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
58 | activation(),
59 | )
60 | self.output_hypernetworks_mlps = nn.ModuleList(
61 | [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for i in range(self.num_mask_tokens)]
62 | )
63 |
64 | self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)
65 |
66 | def forward(
67 | self,
68 | image_embeddings: torch.Tensor,
69 | image_pe: torch.Tensor,
70 | sparse_prompt_embeddings: torch.Tensor,
71 | dense_prompt_embeddings: torch.Tensor,
72 | multimask_output: bool,
73 | ) -> Tuple[torch.Tensor, torch.Tensor]:
74 | """
75 | Predict masks given image and prompt embeddings.
76 |
77 | Arguments:
78 | image_embeddings (torch.Tensor): the embeddings from the image encoder
79 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
80 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
81 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
82 | multimask_output (bool): Whether to return multiple masks or a single
83 | mask.
84 |
85 | Returns:
86 | torch.Tensor: batched predicted masks
87 | torch.Tensor: batched predictions of mask quality
88 | """
89 | masks, iou_pred = self.predict_masks(
90 | image_embeddings=image_embeddings,
91 | image_pe=image_pe,
92 | sparse_prompt_embeddings=sparse_prompt_embeddings,
93 | dense_prompt_embeddings=dense_prompt_embeddings,
94 | )
95 |
96 |         # Select the correct mask or masks for output
97 | if multimask_output:
98 | mask_slice = slice(1, None)
99 | else:
100 | mask_slice = slice(0, 1)
101 | masks = masks[:, mask_slice, :, :]
102 | iou_pred = iou_pred[:, mask_slice]
103 |
104 | # Prepare output
105 | return masks, iou_pred
106 |
107 | def predict_masks(
108 | self,
109 | image_embeddings: torch.Tensor,
110 | image_pe: torch.Tensor,
111 | sparse_prompt_embeddings: torch.Tensor,
112 | dense_prompt_embeddings: torch.Tensor,
113 | ) -> Tuple[torch.Tensor, torch.Tensor]:
114 | """Predicts masks. See 'forward' for more details."""
115 | # Concatenate output tokens
116 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
117 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
118 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
119 |
120 | # Expand per-image data in batch direction to be per-mask
121 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
122 | src = src + dense_prompt_embeddings
123 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
124 | b, c, h, w = src.shape
125 |
126 | # Run the transformer
127 | hs, src = self.transformer(src, pos_src, tokens)
128 | iou_token_out = hs[:, 0, :]
129 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
130 |
131 | # Upscale mask embeddings and predict masks using the mask tokens
132 | src = src.transpose(1, 2).view(b, c, h, w)
133 | upscaled_embedding = self.output_upscaling(src)
134 | hyper_in_list: List[torch.Tensor] = []
135 | for i in range(self.num_mask_tokens):
136 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
137 | hyper_in = torch.stack(hyper_in_list, dim=1)
138 | b, c, h, w = upscaled_embedding.shape
139 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
140 |
141 | # Generate mask quality predictions
142 | iou_pred = self.iou_prediction_head(iou_token_out)
143 |
144 | return masks, iou_pred
145 |
146 |
147 | # Lightly adapted from
148 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
149 | class MLP(nn.Module):
150 | def __init__(
151 | self,
152 | input_dim: int,
153 | hidden_dim: int,
154 | output_dim: int,
155 | num_layers: int,
156 | sigmoid_output: bool = False,
157 | ) -> None:
158 | super().__init__()
159 | self.num_layers = num_layers
160 | h = [hidden_dim] * (num_layers - 1)
161 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
162 | self.sigmoid_output = sigmoid_output
163 |
164 | def forward(self, x):
165 | for i, layer in enumerate(self.layers):
166 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
167 | if self.sigmoid_output:
168 | x = F.sigmoid(x)
169 | return x
170 |
--------------------------------------------------------------------------------
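A shape-level sketch of the decoder in isolation. Random tensors stand in for the real image, positional and prompt embeddings, and the TwoWayTransformer settings below are assumptions chosen only so that the dimensions line up (transformer_dim must match embedding_dim).

import torch

from metaseg.modeling.mask_decoder import MaskDecoder
from metaseg.modeling.transformer import TwoWayTransformer

decoder = MaskDecoder(
    transformer_dim=256,
    transformer=TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048),
)
image_embeddings = torch.randn(1, 256, 64, 64)   # image encoder output
image_pe = torch.randn(1, 256, 64, 64)           # dense positional encoding
sparse = torch.randn(1, 2, 256)                  # e.g. an embedded point prompt plus padding token
dense = torch.randn(1, 256, 64, 64)              # embedded mask prompt (or the no-mask embedding)

with torch.no_grad():
    masks, iou_pred = decoder(
        image_embeddings=image_embeddings,
        image_pe=image_pe,
        sparse_prompt_embeddings=sparse,
        dense_prompt_embeddings=dense,
        multimask_output=True,
    )
print(masks.shape, iou_pred.shape)               # torch.Size([1, 3, 256, 256]) torch.Size([1, 3])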
/metaseg/modeling/prompt_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Any, Optional, Tuple, Type
8 |
9 | import numpy as np
10 | import torch
11 | from torch import nn
12 |
13 | from metaseg.modeling.common import LayerNorm2d
14 |
15 |
16 | class PromptEncoder(nn.Module):
17 | def __init__(
18 | self,
19 | embed_dim: int,
20 | image_embedding_size: Tuple[int, int],
21 | input_image_size: Tuple[int, int],
22 | mask_in_chans: int,
23 | activation: Type[nn.Module] = nn.GELU,
24 | ) -> None:
25 | """
26 | Encodes prompts for input to SAM's mask decoder.
27 |
28 | Arguments:
29 | embed_dim (int): The prompts' embedding dimension
30 | image_embedding_size (tuple(int, int)): The spatial size of the
31 | image embedding, as (H, W).
32 |             input_image_size (tuple(int, int)): The padded size of the image as input
33 | to the image encoder, as (H, W).
34 | mask_in_chans (int): The number of hidden channels used for
35 | encoding input masks.
36 | activation (nn.Module): The activation to use when encoding
37 | input masks.
38 | """
39 | super().__init__()
40 | self.embed_dim = embed_dim
41 | self.input_image_size = input_image_size
42 | self.image_embedding_size = image_embedding_size
43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
44 |
45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
47 | self.point_embeddings = nn.ModuleList(point_embeddings)
48 | self.not_a_point_embed = nn.Embedding(1, embed_dim)
49 |
50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
51 | self.mask_downscaling = nn.Sequential(
52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
53 | LayerNorm2d(mask_in_chans // 4),
54 | activation(),
55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
56 | LayerNorm2d(mask_in_chans),
57 | activation(),
58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
59 | )
60 | self.no_mask_embed = nn.Embedding(1, embed_dim)
61 |
62 | def get_dense_pe(self) -> torch.Tensor:
63 | """
64 | Returns the positional encoding used to encode point prompts,
65 | applied to a dense set of points the shape of the image encoding.
66 |
67 | Returns:
68 | torch.Tensor: Positional encoding with shape
69 | 1x(embed_dim)x(embedding_h)x(embedding_w)
70 | """
71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0)
72 |
73 | def _embed_points(
74 | self,
75 | points: torch.Tensor,
76 | labels: torch.Tensor,
77 | pad: bool,
78 | ) -> torch.Tensor:
79 | """Embeds point prompts."""
80 | points = points + 0.5 # Shift to center of pixel
81 | if pad:
82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
84 | points = torch.cat([points, padding_point], dim=1)
85 | labels = torch.cat([labels, padding_label], dim=1)
86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
87 | point_embedding[labels == -1] = 0.0
88 | point_embedding[labels == -1] += self.not_a_point_embed.weight
89 | point_embedding[labels == 0] += self.point_embeddings[0].weight
90 | point_embedding[labels == 1] += self.point_embeddings[1].weight
91 | return point_embedding
92 |
93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
94 | """Embeds box prompts."""
95 | boxes = boxes + 0.5 # Shift to center of pixel
96 | coords = boxes.reshape(-1, 2, 2)
97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight
99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight
100 | return corner_embedding
101 |
102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
103 | """Embeds mask inputs."""
104 | mask_embedding = self.mask_downscaling(masks)
105 | return mask_embedding
106 |
107 | def _get_batch_size(
108 | self,
109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
110 | boxes: Optional[torch.Tensor],
111 | masks: Optional[torch.Tensor],
112 | ) -> int:
113 | """
114 | Gets the batch size of the output given the batch size of the input prompts.
115 | """
116 | if points is not None:
117 | return points[0].shape[0]
118 | elif boxes is not None:
119 | return boxes.shape[0]
120 | elif masks is not None:
121 | return masks.shape[0]
122 | else:
123 | return 1
124 |
125 | def _get_device(self) -> torch.device:
126 | return self.point_embeddings[0].weight.device
127 |
128 | def forward(
129 | self,
130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
131 | boxes: Optional[torch.Tensor],
132 | masks: Optional[torch.Tensor],
133 | ) -> Tuple[torch.Tensor, torch.Tensor]:
134 | """
135 | Embeds different types of prompts, returning both sparse and dense
136 | embeddings.
137 |
138 | Arguments:
139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
140 | and labels to embed.
141 | boxes (torch.Tensor or none): boxes to embed
142 | masks (torch.Tensor or none): masks to embed
143 |
144 | Returns:
145 | torch.Tensor: sparse embeddings for the points and boxes, with shape
146 | BxNx(embed_dim), where N is determined by the number of input points
147 | and boxes.
148 | torch.Tensor: dense embeddings for the masks, in the shape
149 | Bx(embed_dim)x(embed_H)x(embed_W)
150 | """
151 | bs = self._get_batch_size(points, boxes, masks)
152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
153 | if points is not None:
154 | coords, labels = points
155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
157 | if boxes is not None:
158 | box_embeddings = self._embed_boxes(boxes)
159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
160 |
161 | if masks is not None:
162 | dense_embeddings = self._embed_masks(masks)
163 | else:
164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
166 | )
167 |
168 | return sparse_embeddings, dense_embeddings
169 |
170 |
171 | class PositionEmbeddingRandom(nn.Module):
172 | """
173 | Positional encoding using random spatial frequencies.
174 | """
175 |
176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
177 | super().__init__()
178 | if scale is None or scale <= 0.0:
179 | scale = 1.0
180 | self.register_buffer(
181 | "positional_encoding_gaussian_matrix",
182 | scale * torch.randn((2, num_pos_feats)),
183 | )
184 |
185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
186 | """Positionally encode points that are normalized to [0,1]."""
187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
188 | coords = 2 * coords - 1
189 | coords = coords @ self.positional_encoding_gaussian_matrix
190 | coords = 2 * np.pi * coords
191 | # outputs d_1 x ... x d_n x C shape
192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
193 |
194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor:
195 | """Generate positional encoding for a grid of the specified size."""
196 | h, w = size
197 | device: Any = self.positional_encoding_gaussian_matrix.device
198 | grid = torch.ones((h, w), device=device, dtype=torch.float32)
199 | y_embed = grid.cumsum(dim=0) - 0.5
200 | x_embed = grid.cumsum(dim=1) - 0.5
201 | y_embed = y_embed / h
202 | x_embed = x_embed / w
203 |
204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
205 | return pe.permute(2, 0, 1) # C x H x W
206 |
207 | def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
208 | """Positionally encode points that are not normalized to [0,1]."""
209 | coords = coords_input.clone()
210 | coords[:, :, 0] = coords[:, :, 0] / image_size[1]
211 | coords[:, :, 1] = coords[:, :, 1] / image_size[0]
212 | return self._pe_encoding(coords.to(torch.float)) # B x N x C
213 |
--------------------------------------------------------------------------------
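A small sketch of the prompt encoder on its own. The sizes (256-dim embeddings, a 64x64 embedding grid for a 1024x1024 padded input) follow the shapes quoted in the docstrings above; mask_in_chans=16 is an assumption.

import torch

from metaseg.modeling.prompt_encoder import PromptEncoder

enc = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)
points = torch.tensor([[[512.0, 384.0]]])   # Bx1x2 point coordinates, in input-image pixels
labels = torch.tensor([[1]])                # 1 = foreground point
sparse, dense = enc(points=(points, labels), boxes=None, masks=None)
print(sparse.shape)                         # torch.Size([1, 2, 256]): the point plus a padding token
print(dense.shape)                          # torch.Size([1, 256, 64, 64]): the no-mask embedding
print(enc.get_dense_pe().shape)             # torch.Size([1, 256, 64, 64])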
/metaseg/modeling/sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Any, Dict, List, Tuple
8 |
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 |
13 | from metaseg.modeling.image_encoder import ImageEncoderViT
14 | from metaseg.modeling.mask_decoder import MaskDecoder
15 | from metaseg.modeling.prompt_encoder import PromptEncoder
16 |
17 |
18 | class Sam(nn.Module):
19 | mask_threshold: float = 0.0
20 | image_format: str = "RGB"
21 |
22 | def __init__(
23 | self,
24 | image_encoder: ImageEncoderViT,
25 | prompt_encoder: PromptEncoder,
26 | mask_decoder: MaskDecoder,
27 | pixel_mean: List[float] = [123.675, 116.28, 103.53],
28 | pixel_std: List[float] = [58.395, 57.12, 57.375],
29 | ) -> None:
30 | """
31 | SAM predicts object masks from an image and input prompts.
32 |
33 | Arguments:
34 | image_encoder (ImageEncoderViT): The backbone used to encode the
35 | image into image embeddings that allow for efficient mask prediction.
36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts.
37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings
38 | and encoded prompts.
39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
40 | pixel_std (list(float)): Std values for normalizing pixels in the input image.
41 | """
42 | super().__init__()
43 | self.image_encoder = image_encoder
44 | self.prompt_encoder = prompt_encoder
45 | self.mask_decoder = mask_decoder
46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
48 |
49 | @property
50 | def device(self) -> Any:
51 | return self.pixel_mean.device
52 |
53 | @torch.no_grad()
54 | def forward(
55 | self,
56 | batched_input: List[Dict[str, Any]],
57 | multimask_output: bool,
58 | ) -> List[Dict[str, torch.Tensor]]:
59 | """
60 | Predicts masks end-to-end from provided images and prompts.
61 | If prompts are not known in advance, using SamPredictor is
62 | recommended over calling the model directly.
63 |
64 | Arguments:
65 | batched_input (list(dict)): A list over input images, each a
66 | dictionary with the following keys. A prompt key can be
67 | excluded if it is not present.
68 | 'image': The image as a torch tensor in 3xHxW format,
69 | already transformed for input to the model.
70 | 'original_size': (tuple(int, int)) The original size of
71 | the image before transformation, as (H, W).
72 | 'point_coords': (torch.Tensor) Batched point prompts for
73 | this image, with shape BxNx2. Already transformed to the
74 | input frame of the model.
75 | 'point_labels': (torch.Tensor) Batched labels for point prompts,
76 | with shape BxN.
77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
78 | Already transformed to the input frame of the model.
79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
80 | in the form Bx1xHxW.
81 | multimask_output (bool): Whether the model should predict multiple
82 | disambiguating masks, or return a single mask.
83 |
84 | Returns:
85 | (list(dict)): A list over input images, where each element is
86 |             a dictionary with the following keys.
87 |               'masks': (torch.Tensor) Batched binary mask predictions,
88 |                 with shape BxCxHxW, where B is the number of input prompts,
89 |                 C is determined by multimask_output, and (H, W) is the
90 | original size of the image.
91 | 'iou_predictions': (torch.Tensor) The model's predictions
92 | of mask quality, in shape BxC.
93 | 'low_res_logits': (torch.Tensor) Low resolution logits with
94 | shape BxCxHxW, where H=W=256. Can be passed as mask input
95 | to subsequent iterations of prediction.
96 | """
97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
98 | image_embeddings = self.image_encoder(input_images)
99 |
100 | outputs = []
101 | for image_record, curr_embedding in zip(batched_input, image_embeddings):
102 | if "point_coords" in image_record:
103 | points = (image_record["point_coords"], image_record["point_labels"])
104 | else:
105 | points = None
106 | sparse_embeddings, dense_embeddings = self.prompt_encoder(
107 | points=points,
108 | boxes=image_record.get("boxes", None),
109 | masks=image_record.get("mask_inputs", None),
110 | )
111 | low_res_masks, iou_predictions = self.mask_decoder(
112 | image_embeddings=curr_embedding.unsqueeze(0),
113 | image_pe=self.prompt_encoder.get_dense_pe(),
114 | sparse_prompt_embeddings=sparse_embeddings,
115 | dense_prompt_embeddings=dense_embeddings,
116 | multimask_output=multimask_output,
117 | )
118 | masks = self.postprocess_masks(
119 | low_res_masks,
120 | input_size=image_record["image"].shape[-2:],
121 | original_size=image_record["original_size"],
122 | )
123 | masks = masks > self.mask_threshold
124 | outputs.append(
125 | {
126 | "masks": masks,
127 | "iou_predictions": iou_predictions,
128 | "low_res_logits": low_res_masks,
129 | }
130 | )
131 | return outputs
132 |
133 | def postprocess_masks(
134 | self,
135 | masks: torch.Tensor,
136 | input_size: Tuple[int, ...],
137 | original_size: Tuple[int, ...],
138 | ) -> torch.Tensor:
139 | """
140 | Remove padding and upscale masks to the original image size.
141 |
142 | Arguments:
143 | masks (torch.Tensor): Batched masks from the mask_decoder,
144 | in BxCxHxW format.
145 | input_size (tuple(int, int)): The size of the image input to the
146 | model, in (H, W) format. Used to remove padding.
147 | original_size (tuple(int, int)): The original size of the image
148 | before resizing for input to the model, in (H, W) format.
149 |
150 | Returns:
151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
152 | is given by original_size.
153 | """
154 | masks = F.interpolate(
155 | masks,
156 | (self.image_encoder.img_size, self.image_encoder.img_size),
157 | mode="bilinear",
158 | align_corners=False,
159 | )
160 | masks = masks[..., : input_size[0], : input_size[1]]
161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
162 | return masks
163 |
164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor:
165 | """Normalize pixel values and pad to a square input."""
166 | # Normalize colors
167 | x = (x - self.pixel_mean) / self.pixel_std
168 |
169 | # Pad
170 | h, w = x.shape[-2:]
171 | padh = self.image_encoder.img_size - h
172 | padw = self.image_encoder.img_size - w
173 | x = F.pad(x, (0, padw, 0, padh))
174 | return x
175 |
--------------------------------------------------------------------------------
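An end-to-end batched forward sketch for Sam itself. The "vit_b" registry key and checkpoint filename are assumptions, and a zero image stands in for real data; SamPredictor, shown earlier in this package, is the more convenient interface when prompts arrive interactively.

import numpy as np
import torch

from metaseg.generator.build_sam import sam_model_registry
from metaseg.utils.transforms import ResizeLongestSide

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")       # assumed key / path
transform = ResizeLongestSide(sam.image_encoder.img_size)

image = np.zeros((480, 640, 3), dtype=np.uint8)                     # placeholder HWC RGB image
input_image = transform.apply_image(image)                          # longest side resized to 1024
image_torch = torch.as_tensor(input_image).permute(2, 0, 1).contiguous()

# One foreground point, transformed into the resized input frame.
point = transform.apply_coords(np.array([[320.0, 240.0]]), image.shape[:2])
batched_input = [{
    "image": image_torch,                                           # 3xHxW; preprocess() pads and normalizes
    "original_size": image.shape[:2],
    "point_coords": torch.as_tensor(point, dtype=torch.float)[None, :, :],  # 1xNx2
    "point_labels": torch.ones((1, 1), dtype=torch.int),                    # 1xN
}]
outputs = sam(batched_input, multimask_output=True)
print(outputs[0]["masks"].shape)                                    # torch.Size([1, 3, 480, 640])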
/metaseg/modeling/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | from typing import Tuple, Type
9 |
10 | import torch
11 | from torch import Tensor, nn
12 |
13 | from metaseg.modeling.common import MLPBlock
14 |
15 |
16 | class TwoWayTransformer(nn.Module):
17 | def __init__(
18 | self,
19 | depth: int,
20 | embedding_dim: int,
21 | num_heads: int,
22 | mlp_dim: int,
23 | activation: Type[nn.Module] = nn.ReLU,
24 | attention_downsample_rate: int = 2,
25 | ) -> None:
26 | """
27 | A transformer decoder that attends to an input image using
28 | queries whose positional embedding is supplied.
29 |
30 | Args:
31 | depth (int): number of layers in the transformer
32 | embedding_dim (int): the channel dimension for the input embeddings
33 | num_heads (int): the number of heads for multihead attention. Must
34 | divide embedding_dim
35 | mlp_dim (int): the channel dimension internal to the MLP block
36 | activation (nn.Module): the activation to use in the MLP block
37 | """
38 | super().__init__()
39 | self.depth = depth
40 | self.embedding_dim = embedding_dim
41 | self.num_heads = num_heads
42 | self.mlp_dim = mlp_dim
43 | self.layers = nn.ModuleList()
44 |
45 | for i in range(depth):
46 | self.layers.append(
47 | TwoWayAttentionBlock(
48 | embedding_dim=embedding_dim,
49 | num_heads=num_heads,
50 | mlp_dim=mlp_dim,
51 | activation=activation,
52 | attention_downsample_rate=attention_downsample_rate,
53 | skip_first_layer_pe=(i == 0),
54 | )
55 | )
56 |
57 | self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
58 | self.norm_final_attn = nn.LayerNorm(embedding_dim)
59 |
60 | def forward(
61 | self,
62 | image_embedding: Tensor,
63 | image_pe: Tensor,
64 | point_embedding: Tensor,
65 | ) -> Tuple[Tensor, Tensor]:
66 | """
67 | Args:
68 | image_embedding (torch.Tensor): image to attend to. Should be shape
69 | B x embedding_dim x h x w for any h and w.
70 | image_pe (torch.Tensor): the positional encoding to add to the image. Must
71 | have the same shape as image_embedding.
72 | point_embedding (torch.Tensor): the embedding to add to the query points.
73 | Must have shape B x N_points x embedding_dim for any N_points.
74 |
75 | Returns:
76 | torch.Tensor: the processed point_embedding
77 | torch.Tensor: the processed image_embedding
78 | """
79 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C
80 | bs, c, h, w = image_embedding.shape
81 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
82 | image_pe = image_pe.flatten(2).permute(0, 2, 1)
83 |
84 | # Prepare queries
85 | queries = point_embedding
86 | keys = image_embedding
87 |
88 | # Apply transformer blocks and final layernorm
89 | for layer in self.layers:
90 | queries, keys = layer(
91 | queries=queries,
92 | keys=keys,
93 | query_pe=point_embedding,
94 | key_pe=image_pe,
95 | )
96 |
97 |         # Apply the final attention layer from the points to the image
98 | q = queries + point_embedding
99 | k = keys + image_pe
100 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
101 | queries = queries + attn_out
102 | queries = self.norm_final_attn(queries)
103 |
104 | return queries, keys
105 |
106 |
107 | class TwoWayAttentionBlock(nn.Module):
108 | def __init__(
109 | self,
110 | embedding_dim: int,
111 | num_heads: int,
112 | mlp_dim: int = 2048,
113 | activation: Type[nn.Module] = nn.ReLU,
114 | attention_downsample_rate: int = 2,
115 | skip_first_layer_pe: bool = False,
116 | ) -> None:
117 | """
118 | A transformer block with four layers: (1) self-attention of sparse
119 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
120 | block on sparse inputs, and (4) cross attention of dense inputs to sparse
121 | inputs.
122 |
123 | Arguments:
124 | embedding_dim (int): the channel dimension of the embeddings
125 | num_heads (int): the number of heads in the attention layers
126 | mlp_dim (int): the hidden dimension of the mlp block
127 | activation (nn.Module): the activation of the mlp block
128 | skip_first_layer_pe (bool): skip the PE on the first layer
129 | """
130 | super().__init__()
131 | self.self_attn = Attention(embedding_dim, num_heads)
132 | self.norm1 = nn.LayerNorm(embedding_dim)
133 |
134 | self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
135 | self.norm2 = nn.LayerNorm(embedding_dim)
136 |
137 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
138 | self.norm3 = nn.LayerNorm(embedding_dim)
139 |
140 | self.norm4 = nn.LayerNorm(embedding_dim)
141 | self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
142 |
143 | self.skip_first_layer_pe = skip_first_layer_pe
144 |
145 | def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
146 | # Self attention block
147 | if self.skip_first_layer_pe:
148 | queries = self.self_attn(q=queries, k=queries, v=queries)
149 | else:
150 | q = queries + query_pe
151 | attn_out = self.self_attn(q=q, k=q, v=queries)
152 | queries = queries + attn_out
153 | queries = self.norm1(queries)
154 |
155 | # Cross attention block, tokens attending to image embedding
156 | q = queries + query_pe
157 | k = keys + key_pe
158 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
159 | queries = queries + attn_out
160 | queries = self.norm2(queries)
161 |
162 | # MLP block
163 | mlp_out = self.mlp(queries)
164 | queries = queries + mlp_out
165 | queries = self.norm3(queries)
166 |
167 | # Cross attention block, image embedding attending to tokens
168 | q = queries + query_pe
169 | k = keys + key_pe
170 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
171 | keys = keys + attn_out
172 | keys = self.norm4(keys)
173 |
174 | return queries, keys
175 |
176 |
177 | class Attention(nn.Module):
178 | """
179 | An attention layer that allows for downscaling the size of the embedding
180 | after projection to queries, keys, and values.
181 | """
182 |
183 | def __init__(
184 | self,
185 | embedding_dim: int,
186 | num_heads: int,
187 | downsample_rate: int = 1,
188 | ) -> None:
189 | super().__init__()
190 | self.embedding_dim = embedding_dim
191 | self.internal_dim = embedding_dim // downsample_rate
192 | self.num_heads = num_heads
193 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
194 |
195 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
196 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
197 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
198 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
199 |
200 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
201 | b, n, c = x.shape
202 | x = x.reshape(b, n, num_heads, c // num_heads)
203 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
204 |
205 | def _recombine_heads(self, x: Tensor) -> Tensor:
206 | b, n_heads, n_tokens, c_per_head = x.shape
207 | x = x.transpose(1, 2)
208 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
209 |
210 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
211 | # Input projections
212 | q = self.q_proj(q)
213 | k = self.k_proj(k)
214 | v = self.v_proj(v)
215 |
216 | # Separate into heads
217 | q = self._separate_heads(q, self.num_heads)
218 | k = self._separate_heads(k, self.num_heads)
219 | v = self._separate_heads(v, self.num_heads)
220 |
221 | # Attention
222 | _, _, _, c_per_head = q.shape
223 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
224 | attn = attn / math.sqrt(c_per_head)
225 | attn = torch.softmax(attn, dim=-1)
226 |
227 | # Get output
228 | out = attn @ v
229 | out = self._recombine_heads(out)
230 | out = self.out_proj(out)
231 |
232 | return out
233 |
--------------------------------------------------------------------------------
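The two-way transformer above is easiest to sanity-check with dummy tensors. The sketch below only verifies output shapes; the 256-dim embeddings, 64x64 feature grid, and 5 point tokens are illustrative assumptions, not values taken from this repository's configs.

```python
# Minimal shape check for TwoWayTransformer with dummy inputs.
import torch

from metaseg.modeling.transformer import TwoWayTransformer

transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)

image_embedding = torch.randn(1, 256, 64, 64)  # B x C x H x W
image_pe = torch.randn(1, 256, 64, 64)         # same shape as the image embedding
point_embedding = torch.randn(1, 5, 256)       # B x N_points x C

queries, keys = transformer(image_embedding, image_pe, point_embedding)
print(queries.shape)  # torch.Size([1, 5, 256])
print(keys.shape)     # torch.Size([1, 4096, 256]) -- one token per 64x64 grid cell
```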
/metaseg/sahi_predict.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import torch
5 | from PIL import Image
6 |
7 | from metaseg import SamPredictor, sam_model_registry
8 | from metaseg.utils import download_model, load_image, multi_boxes, plt_load_box, plt_load_mask, load_image_vide
9 |
10 |
11 | def sahi_sliced_predict(
12 | image_path,
13 | detection_model_type,
14 | detection_model_path,
15 | conf_th,
16 | image_size,
17 | slice_height,
18 | slice_width,
19 | overlap_height_ratio,
20 | overlap_width_ratio,
21 | ):
22 |
23 | try:
24 | from sahi import AutoDetectionModel
25 | from sahi.predict import get_prediction, get_sliced_prediction
26 | except ImportError:
27 | raise ImportError("Please install SAHI library using 'pip install sahi'.")
28 |
29 | device = "cuda" if torch.cuda.is_available() else "cpu"
30 |
31 | detection_model = AutoDetectionModel.from_pretrained(
32 | image_size=image_size,
33 | model_type=detection_model_type,
34 | model_path=detection_model_path,
35 | confidence_threshold=conf_th,
36 | device=device,
37 | )
38 | result = get_sliced_prediction(
39 | image_path,
40 | detection_model,
41 | slice_height=slice_height,
42 | slice_width=slice_width,
43 | overlap_height_ratio=overlap_height_ratio,
44 | overlap_width_ratio=overlap_width_ratio,
45 | )
46 |
47 |     # use the sliced prediction result from above; do not overwrite it with a full-image pass
48 | output = result.object_prediction_list
49 | boxes = []
50 | for i in output:
51 | boxes.append(i.bbox.to_xyxy())
52 |
53 | return boxes
54 |
55 |
56 | class SahiAutoSegmentation:
57 | def __init__(self):
58 | self.model = None
59 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
60 |
61 | def load_model(self, model_type):
62 | if self.model is None:
63 | self.model_path = download_model(model_type)
64 | self.model = sam_model_registry[model_type](checkpoint=self.model_path)
65 | self.model.to(device=self.device)
66 |
67 | return self.model
68 |
69 | def predict(
70 | self,
71 | source,
72 | model_type,
73 | input_box=None,
74 | input_point=None,
75 | input_label=None,
76 | multimask_output=False,
77 | random_color=False,
78 | show=False,
79 | save=False,
80 | ):
81 |
82 | # read_image = load_image(source)
83 | read_vide = load_image_vide(source)
84 | model = self.load_model(model_type)
85 | predictor = SamPredictor(model)
86 | predictor.set_image(read_vide)
87 |
88 | if type(input_box[0]) == list:
89 | input_boxes, new_boxes = multi_boxes(input_box, predictor, read_vide)
90 |
91 | masks, _, _ = predictor.predict_torch(
92 | point_coords=None,
93 | point_labels=None,
94 | boxes=new_boxes,
95 | multimask_output=False,
96 | )
97 |
98 | elif type(input_box[0]) == int:
99 | input_boxes = np.array(input_box)[None, :]
100 |
101 | masks, _, _ = predictor.predict(
102 | point_coords=input_point,
103 | point_labels=input_label,
104 | box=input_boxes,
105 | multimask_output=multimask_output,
106 | )
107 |
108 |
109 | return masks
110 |
--------------------------------------------------------------------------------
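For reference, a hedged usage sketch of the two helpers above: run sliced detection with SAHI, then feed the resulting boxes to SAM. The file path, the `yolov8` model type string, and the detector/SAM checkpoint names are assumptions for illustration.

```python
# Sketch: SAHI sliced detection -> SAM box-prompted segmentation.
import cv2

from metaseg.sahi_predict import SahiAutoSegmentation, sahi_sliced_predict

boxes = sahi_sliced_predict(
    image_path="example/image.jpg",       # assumed input image
    detection_model_type="yolov8",        # SAHI detector type string
    detection_model_path="yolov8l.pt",    # assumed detector weights
    conf_th=0.25,
    image_size=640,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)

frame = cv2.imread("example/image.jpg")   # BGR array, as SahiAutoSegmentation expects
masks = SahiAutoSegmentation().predict(
    source=frame,
    model_type="vit_b",                   # downloads the SAM vit_b checkpoint on first use
    input_box=boxes,
)
```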
/metaseg/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from metaseg.utils.data_utils import *
8 | from metaseg.utils.file_utils import *
9 |
--------------------------------------------------------------------------------
/metaseg/utils/amg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | from copy import deepcopy
9 | from itertools import product
10 | from typing import Any, Dict, Generator, ItemsView, List, Tuple
11 |
12 | import numpy as np
13 | import torch
14 |
15 |
16 | class MaskData:
17 | """
18 | A structure for storing masks and their related data in batched format.
19 | Implements basic filtering and concatenation.
20 | """
21 |
22 | def __init__(self, **kwargs) -> None:
23 | for v in kwargs.values():
24 | assert isinstance(
25 | v, (list, np.ndarray, torch.Tensor)
26 | ), "MaskData only supports list, numpy arrays, and torch tensors."
27 | self._stats = dict(**kwargs)
28 |
29 | def __setitem__(self, key: str, item: Any) -> None:
30 | assert isinstance(
31 | item, (list, np.ndarray, torch.Tensor)
32 | ), "MaskData only supports list, numpy arrays, and torch tensors."
33 | self._stats[key] = item
34 |
35 | def __delitem__(self, key: str) -> None:
36 | del self._stats[key]
37 |
38 | def __getitem__(self, key: str) -> Any:
39 | return self._stats[key]
40 |
41 | def items(self) -> ItemsView[str, Any]:
42 | return self._stats.items()
43 |
44 | def filter(self, keep: torch.Tensor) -> None:
45 | for k, v in self._stats.items():
46 | if v is None:
47 | self._stats[k] = None
48 | elif isinstance(v, torch.Tensor):
49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
50 | elif isinstance(v, np.ndarray):
51 | self._stats[k] = v[keep.detach().cpu().numpy()]
52 | elif isinstance(v, list) and keep.dtype == torch.bool:
53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
54 | elif isinstance(v, list):
55 | self._stats[k] = [v[i] for i in keep]
56 | else:
57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
58 |
59 | def cat(self, new_stats: "MaskData") -> None:
60 | for k, v in new_stats.items():
61 | if k not in self._stats or self._stats[k] is None:
62 | self._stats[k] = deepcopy(v)
63 | elif isinstance(v, torch.Tensor):
64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0)
65 | elif isinstance(v, np.ndarray):
66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
67 | elif isinstance(v, list):
68 | self._stats[k] = self._stats[k] + deepcopy(v)
69 | else:
70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
71 |
72 | def to_numpy(self) -> None:
73 | for k, v in self._stats.items():
74 | if isinstance(v, torch.Tensor):
75 | self._stats[k] = v.detach().cpu().numpy()
76 |
77 |
78 | def is_box_near_crop_edge(
79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
80 | ) -> torch.Tensor:
81 | """Filter masks at the edge of a crop, but not at the edge of the original image."""
82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
88 | return torch.any(near_crop_edge, dim=1)
89 |
90 |
91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
92 | box_xywh = deepcopy(box_xyxy)
93 | box_xywh[2] = box_xywh[2] - box_xywh[0]
94 | box_xywh[3] = box_xywh[3] - box_xywh[1]
95 | return box_xywh
96 |
97 |
98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
99 | assert len(args) > 0 and all(
100 | len(a) == len(args[0]) for a in args
101 | ), "Batched iteration must have inputs of all the same size."
102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
103 | for b in range(n_batches):
104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
105 |
106 |
107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
108 | """
109 | Encodes masks to an uncompressed RLE, in the format expected by
110 | pycoco tools.
111 | """
112 | # Put in fortran order and flatten h,w
113 | b, h, w = tensor.shape
114 | tensor = tensor.permute(0, 2, 1).flatten(1)
115 |
116 | # Compute change indices
117 | diff = tensor[:, 1:] ^ tensor[:, :-1]
118 | change_indices = diff.nonzero()
119 |
120 | # Encode run length
121 | out = []
122 | for i in range(b):
123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1]
124 | cur_idxs = torch.cat(
125 | [
126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
127 | cur_idxs + 1,
128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
129 | ]
130 | )
131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
132 | counts = [] if tensor[i, 0] == 0 else [0]
133 | counts.extend(btw_idxs.detach().cpu().tolist())
134 | out.append({"size": [h, w], "counts": counts})
135 | return out
136 |
137 |
138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
139 | """Compute a binary mask from an uncompressed RLE."""
140 | h, w = rle["size"]
141 | mask = np.empty(h * w, dtype=bool)
142 | idx = 0
143 | parity = False
144 | for count in rle["counts"]:
145 | mask[idx : idx + count] = parity
146 | idx += count
147 | parity ^= True
148 | mask = mask.reshape(w, h)
149 | return mask.transpose() # Put in C order
150 |
151 |
152 | def area_from_rle(rle: Dict[str, Any]) -> int:
153 | return sum(rle["counts"][1::2])
154 |
155 |
156 | def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
157 | """
158 | Computes the stability score for a batch of masks. The stability
159 | score is the IoU between the binary masks obtained by thresholding
160 | the predicted mask logits at high and low values.
161 | """
162 | # One mask is always contained inside the other.
163 |     # Save memory by preventing unnecessary cast to torch.int64
164 | intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
165 | unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
166 | return intersections / unions
167 |
168 |
169 | def build_point_grid(n_per_side: int) -> np.ndarray:
170 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
171 | offset = 1 / (2 * n_per_side)
172 | points_one_side = np.linspace(offset, 1 - offset, n_per_side)
173 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
174 | points_y = np.tile(points_one_side[:, None], (1, n_per_side))
175 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
176 | return points
177 |
178 |
179 | def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
180 | """Generates point grids for all crop layers."""
181 | points_by_layer = []
182 | for i in range(n_layers + 1):
183 | n_points = int(n_per_side / (scale_per_layer ** i))
184 | points_by_layer.append(build_point_grid(n_points))
185 | return points_by_layer
186 |
187 |
188 | def generate_crop_boxes(
189 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
190 | ) -> Tuple[List[List[int]], List[int]]:
191 | """
192 | Generates a list of crop boxes of different sizes. Each layer
193 | has (2**i)**2 boxes for the ith layer.
194 | """
195 | crop_boxes, layer_idxs = [], []
196 | im_h, im_w = im_size
197 | short_side = min(im_h, im_w)
198 |
199 | # Original image
200 | crop_boxes.append([0, 0, im_w, im_h])
201 | layer_idxs.append(0)
202 |
203 | def crop_len(orig_len, n_crops, overlap):
204 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
205 |
206 | for i_layer in range(n_layers):
207 | n_crops_per_side = 2 ** (i_layer + 1)
208 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
209 |
210 | crop_w = crop_len(im_w, n_crops_per_side, overlap)
211 | crop_h = crop_len(im_h, n_crops_per_side, overlap)
212 |
213 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
214 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
215 |
216 | # Crops in XYWH format
217 | for x0, y0 in product(crop_box_x0, crop_box_y0):
218 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
219 | crop_boxes.append(box)
220 | layer_idxs.append(i_layer + 1)
221 |
222 | return crop_boxes, layer_idxs
223 |
224 |
225 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
226 | x0, y0, _, _ = crop_box
227 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
228 | # Check if boxes has a channel dimension
229 | if len(boxes.shape) == 3:
230 | offset = offset.unsqueeze(1)
231 | return boxes + offset
232 |
233 |
234 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
235 | x0, y0, _, _ = crop_box
236 | offset = torch.tensor([[x0, y0]], device=points.device)
237 | # Check if points has a channel dimension
238 | if len(points.shape) == 3:
239 | offset = offset.unsqueeze(1)
240 | return points + offset
241 |
242 |
243 | def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
244 | x0, y0, x1, y1 = crop_box
245 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
246 | return masks
247 | # Coordinate transform masks
248 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
249 | pad = (x0, pad_x - x0, y0, pad_y - y0)
250 | return torch.nn.functional.pad(masks, pad, value=0)
251 |
252 |
253 | def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
254 | """
255 | Removes small disconnected regions and holes in a mask. Returns the
256 | mask and an indicator of if the mask has been modified.
257 | """
258 | import cv2 # type: ignore
259 |
260 | assert mode in ["holes", "islands"]
261 | correct_holes = mode == "holes"
262 | working_mask = (correct_holes ^ mask).astype(np.uint8)
263 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
264 | sizes = stats[:, -1][1:] # Row 0 is background label
265 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
266 | if len(small_regions) == 0:
267 | return mask, False
268 | fill_labels = [0] + small_regions
269 | if not correct_holes:
270 | fill_labels = [i for i in range(n_labels) if i not in fill_labels]
271 | # If every region is below threshold, keep largest
272 | if len(fill_labels) == 0:
273 | fill_labels = [int(np.argmax(sizes)) + 1]
274 | mask = np.isin(regions, fill_labels)
275 | return mask, True
276 |
277 |
278 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
279 | from pycocotools import mask as mask_utils # type: ignore
280 |
281 | h, w = uncompressed_rle["size"]
282 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
283 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
284 | return rle
285 |
286 |
287 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
288 | """
289 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
290 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
291 | """
292 | # torch.max below raises an error on empty inputs, just skip in this case
293 | if torch.numel(masks) == 0:
294 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
295 |
296 | # Normalize shape to CxHxW
297 | shape = masks.shape
298 | h, w = shape[-2:]
299 | if len(shape) > 2:
300 | masks = masks.flatten(0, -3)
301 | else:
302 | masks = masks.unsqueeze(0)
303 |
304 | # Get top and bottom edges
305 | in_height, _ = torch.max(masks, dim=-1)
306 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
307 | bottom_edges, _ = torch.max(in_height_coords, dim=-1)
308 | in_height_coords = in_height_coords + h * (~in_height)
309 | top_edges, _ = torch.min(in_height_coords, dim=-1)
310 |
311 | # Get left and right edges
312 | in_width, _ = torch.max(masks, dim=-2)
313 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
314 | right_edges, _ = torch.max(in_width_coords, dim=-1)
315 | in_width_coords = in_width_coords + w * (~in_width)
316 | left_edges, _ = torch.min(in_width_coords, dim=-1)
317 |
318 | # If the mask is empty the right edge will be to the left of the left edge.
319 | # Replace these boxes with [0, 0, 0, 0]
320 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
321 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
322 | out = out * (~empty_filter).unsqueeze(-1)
323 |
324 | # Return to original shape
325 | if len(shape) > 2:
326 | out = out.reshape(*shape[:-2], 4)
327 | else:
328 | out = out[0]
329 |
330 | return out
331 |
--------------------------------------------------------------------------------
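The RLE helpers above round-trip cleanly, which is a quick way to check them after changes. A small sketch with a toy 4x6 mask (an illustrative assumption):

```python
# Round-trip check: encode a boolean mask batch to RLE, then recover it.
import numpy as np
import torch

from metaseg.utils.amg import area_from_rle, mask_to_rle_pytorch, rle_to_mask

mask = torch.zeros(1, 4, 6, dtype=torch.bool)  # B x H x W
mask[0, 1:3, 2:5] = True                       # a 2x3 foreground block

rles = mask_to_rle_pytorch(mask)
print(area_from_rle(rles[0]))                  # 6 foreground pixels
recovered = rle_to_mask(rles[0])               # H x W numpy bool array
assert np.array_equal(recovered, mask[0].numpy())
```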
/metaseg/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | def load_image(image_path):
2 | import cv2
3 |
4 | image = cv2.imread(image_path)
5 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
6 | return image
7 | def load_image_vide(image_path):
8 |     import cv2
9 |
10 |     # Despite the parameter name, this expects an in-memory BGR image array (e.g. a video frame).
11 |     image = cv2.cvtColor(image_path, cv2.COLOR_BGR2RGB)
12 |     return image
13 |
14 | def load_server_image(image_path):
15 | import os
16 | from io import BytesIO
17 | from uuid import uuid4
18 |
19 | from PIL import Image
20 |
21 | imagedir = str(uuid4())
22 |     os.makedirs(imagedir, exist_ok=True)
23 | image = Image.open(BytesIO(image_path))
24 | if image.mode != "RGB":
25 | image = image.convert("RGB")
26 |
27 | image_path = f"{imagedir}/base_image_v0.png"
28 | output_path = f"{imagedir}/output_v0.png"
29 | image.save(image_path, format="PNG")
30 | return image_path, output_path
31 |
32 |
33 | def load_video(video_path, output_path="output.mp4"):
34 | import cv2
35 |
36 | cap = cv2.VideoCapture(video_path)
37 | frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
38 | frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
39 | fourcc = cv2.VideoWriter_fourcc(*"XVID")
40 | fps = int(cap.get(cv2.CAP_PROP_FPS))
41 | out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
42 |
43 | return cap, out
44 |
45 |
46 | def read_image(image_path):
47 | import cv2
48 |
49 | image = cv2.imread(image_path)
50 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
51 | return image
52 |
53 |
54 | def load_mask(mask, random_color):
55 | import numpy as np
56 |
57 | if random_color:
58 | color = np.random.rand(3) * 255
59 | else:
60 | color = np.array([100, 50, 0])
61 |
62 | h, w = mask.shape[-2:]
63 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
64 | mask_image = mask_image.astype(np.uint8)
65 | return mask_image
66 |
67 |
68 | def load_box(box, image):
69 | import cv2
70 |
71 | x, y, w, h = int(box[0]), int(box[1]), int(box[2]), int(box[3])
72 | cv2.rectangle(image, (x, y), (w, h), (0, 255, 0), 2)
73 | return image
74 |
75 |
76 | def plt_load_mask(mask, ax, random_color=False):
77 | import numpy as np
78 |
79 | if random_color:
80 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
81 | else:
82 | color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
83 | h, w = mask.shape[-2:]
84 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
85 | ax.imshow(mask_image)
86 |
87 |
88 | def plt_load_box(box, ax):
89 | import matplotlib.pyplot as plt
90 |
91 | x0, y0 = box[0], box[1]
92 | w, h = box[2] - box[0], box[3] - box[1]
93 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2))
94 |
95 |
96 | def multi_boxes(boxes, predictor, image):
97 | import torch
98 |
99 | input_boxes = torch.tensor(boxes, device=predictor.device)
100 | transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
101 | return input_boxes, transformed_boxes
102 |
103 |
104 | def show_image(output_image):
105 | import cv2
106 |
107 | cv2.imshow("output", output_image)
108 | cv2.waitKey(0)
109 | cv2.destroyAllWindows()
110 |
111 |
112 | def save_image(output_image, output_path):
113 | import cv2
114 |
115 | cv2.imwrite(output_path, output_image)
116 |
--------------------------------------------------------------------------------
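A minimal sketch tying the helpers above together: read an image, colour a dummy mask, blend it over the image, draw the matching box, and save the result. The example path and the 50..150 coordinates are assumptions.

```python
# Sketch: overlay a toy mask and box on an image using the data_utils helpers.
import numpy as np

from metaseg.utils.data_utils import load_box, load_image, load_mask, save_image

image = load_image("example/image.jpg")                  # RGB ndarray
mask = np.zeros(image.shape[:2], dtype=np.uint8)
mask[50:150, 50:150] = 1                                 # square dummy mask

overlay = load_mask(mask, random_color=True)             # H x W x 3 coloured mask
blended = (0.6 * image + 0.4 * overlay).astype(np.uint8)
boxed = load_box([50, 50, 150, 150], blended)            # rectangle drawn in place
save_image(boxed, "output_boxed.png")                    # cv2.imwrite expects BGR, so R/B appear swapped
```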
/metaseg/utils/file_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import urllib.request
3 |
4 |
5 | def download_model(model_type):
6 | """
7 | model_type: str, A string representing the model type. It can be 'vit_h', 'vit_l', or 'vit_b'.
8 | """
9 |
10 | # A dictionary containing model types as keys and their respective URLs as values
11 | model_urls = {
12 | "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
13 | "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
14 | "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
15 | }
16 |
17 | # Check if the model file already exists and model_type is in model_urls
18 | filename = f"{model_type}.pth"
19 | if not os.path.exists(filename) and model_type in model_urls:
20 | url = model_urls[model_type]
21 | print(f"Downloading {model_type} model from {url}...")
22 | urllib.request.urlretrieve(url, filename)
23 | print(f"{model_type} model has been successfully downloaded and saved as '{filename}'.")
24 | elif os.path.exists(filename):
25 | print(f"{model_type} model already exists as '{filename}'. Skipping download.")
26 | else:
27 | raise ValueError("Invalid model type. It should be 'vit_h', 'vit_l', or 'vit_b'.")
28 |
29 | return filename
30 |
--------------------------------------------------------------------------------
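`download_model` pairs with the `sam_model_registry` exported from the package, as used in `sahi_predict.py`. A short, hedged sketch (the `vit_b` choice is simply the smallest checkpoint):

```python
# Sketch: download a SAM checkpoint and build the model from it.
import torch

from metaseg import sam_model_registry
from metaseg.utils.file_utils import download_model

checkpoint = download_model("vit_b")                     # saves vit_b.pth on first use
sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
sam.to("cuda" if torch.cuda.is_available() else "cpu")
```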
/metaseg/utils/onnx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Tuple
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torch.nn import functional as F
12 |
13 | from metaseg.modeling import Sam
14 | from metaseg.utils.amg import calculate_stability_score
15 |
16 |
17 | class SamOnnxModel(nn.Module):
18 | """
19 | This model should not be called directly, but is used in ONNX export.
20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
21 | with some functions modified to enable model tracing. Also supports extra
22 |     options controlling what information is returned. See the ONNX export script for details.
23 | """
24 |
25 | def __init__(
26 | self,
27 | model: Sam,
28 | return_single_mask: bool,
29 | use_stability_score: bool = False,
30 | return_extra_metrics: bool = False,
31 | ) -> None:
32 | super().__init__()
33 | self.mask_decoder = model.mask_decoder
34 | self.model = model
35 | self.img_size = model.image_encoder.img_size
36 | self.return_single_mask = return_single_mask
37 | self.use_stability_score = use_stability_score
38 | self.stability_score_offset = 1.0
39 | self.return_extra_metrics = return_extra_metrics
40 |
41 | @staticmethod
42 | def resize_longest_image_size(input_image_size: torch.Tensor, longest_side: int) -> torch.Tensor:
43 | input_image_size = input_image_size.to(torch.float32)
44 | scale = longest_side / torch.max(input_image_size)
45 | transformed_size = scale * input_image_size
46 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
47 | return transformed_size
48 |
49 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
50 | point_coords = point_coords + 0.5
51 | point_coords = point_coords / self.img_size
52 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
53 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
54 |
55 | point_embedding = point_embedding * (point_labels != -1)
56 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
57 |
58 | for i in range(self.model.prompt_encoder.num_point_embeddings):
59 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[i].weight * (
60 | point_labels == i
61 | )
62 |
63 | return point_embedding
64 |
65 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
66 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
67 | mask_embedding = mask_embedding + (1 - has_mask_input) * self.model.prompt_encoder.no_mask_embed.weight.reshape(
68 | 1, -1, 1, 1
69 | )
70 | return mask_embedding
71 |
72 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
73 | masks = F.interpolate(
74 | masks,
75 | size=(self.img_size, self.img_size),
76 | mode="bilinear",
77 | align_corners=False,
78 | )
79 |
80 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size)
81 | masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])]
82 |
83 | orig_im_size = orig_im_size.to(torch.int64)
84 | h, w = orig_im_size[0], orig_im_size[1]
85 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
86 | return masks
87 |
88 | def select_masks(
89 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
90 | ) -> Tuple[torch.Tensor, torch.Tensor]:
91 | # Determine if we should return the multiclick mask or not from the number of points.
92 | # The reweighting is used to avoid control flow.
93 | score_reweight = torch.tensor([[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]).to(
94 | iou_preds.device
95 | )
96 | score = iou_preds + (num_points - 2.5) * score_reweight
97 | best_idx = torch.argmax(score, dim=1)
98 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
99 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
100 |
101 | return masks, iou_preds
102 |
103 | @torch.no_grad()
104 | def forward(
105 | self,
106 | image_embeddings: torch.Tensor,
107 | point_coords: torch.Tensor,
108 | point_labels: torch.Tensor,
109 | mask_input: torch.Tensor,
110 | has_mask_input: torch.Tensor,
111 | orig_im_size: torch.Tensor,
112 | ):
113 | sparse_embedding = self._embed_points(point_coords, point_labels)
114 | dense_embedding = self._embed_masks(mask_input, has_mask_input)
115 |
116 | masks, scores = self.model.mask_decoder.predict_masks(
117 | image_embeddings=image_embeddings,
118 | image_pe=self.model.prompt_encoder.get_dense_pe(),
119 | sparse_prompt_embeddings=sparse_embedding,
120 | dense_prompt_embeddings=dense_embedding,
121 | )
122 |
123 | if self.use_stability_score:
124 | scores = calculate_stability_score(masks, self.model.mask_threshold, self.stability_score_offset)
125 |
126 | if self.return_single_mask:
127 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
128 |
129 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
130 |
131 | if self.return_extra_metrics:
132 | stability_scores = calculate_stability_score(
133 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset
134 | )
135 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
136 | return upscaled_masks, scores, stability_scores, areas, masks
137 |
138 | return upscaled_masks, scores, masks
139 |
--------------------------------------------------------------------------------
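Exporting `SamOnnxModel` follows the reference SAM ONNX export recipe. The condensed sketch below assumes the prompt encoder exposes `embed_dim` and `image_embedding_size` as in upstream SAM; the checkpoint name, dummy shapes, point count, and output file name are illustrative.

```python
# Condensed ONNX export sketch for SamOnnxModel (assumed shapes and names).
import torch

from metaseg import sam_model_registry
from metaseg.utils.onnx import SamOnnxModel

sam = sam_model_registry["vit_b"](checkpoint="vit_b.pth")
onnx_model = SamOnnxModel(sam, return_single_mask=True)

embed_dim = sam.prompt_encoder.embed_dim                 # assumed upstream attribute names
embed_size = sam.prompt_encoder.image_embedding_size
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size),
    "point_coords": torch.randint(0, 1024, (1, 5, 2)).float(),
    "point_labels": torch.randint(0, 4, (1, 5)).float(),
    "mask_input": torch.randn(1, 1, 4 * embed_size[0], 4 * embed_size[1]),
    "has_mask_input": torch.tensor([1.0]),
    "orig_im_size": torch.tensor([1500.0, 2250.0]),
}

torch.onnx.export(
    onnx_model,
    tuple(dummy_inputs.values()),
    "sam_vit_b.onnx",                                    # assumed output file name
    input_names=list(dummy_inputs.keys()),
    output_names=["masks", "iou_predictions", "low_res_masks"],
    dynamic_axes={"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}},
    opset_version=17,
)
```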
/metaseg/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from copy import deepcopy
8 | from typing import Tuple
9 |
10 | import numpy as np
11 | import torch
12 | from torch.nn import functional as F
13 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore
14 |
15 |
16 | class ResizeLongestSide:
17 | """
18 | Resizes images to longest side 'target_length', as well as provides
19 | methods for resizing coordinates and boxes. Provides methods for
20 | transforming both numpy array and batched torch tensors.
21 | """
22 |
23 | def __init__(self, target_length: int) -> None:
24 | self.target_length = target_length
25 |
26 | def apply_image(self, image: np.ndarray) -> np.ndarray:
27 | """
28 | Expects a numpy array with shape HxWxC in uint8 format.
29 | """
30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
31 | return np.array(resize(to_pil_image(image), target_size))
32 |
33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
34 | """
35 | Expects a numpy array of length 2 in the final dimension. Requires the
36 | original image size in (H, W) format.
37 | """
38 | old_h, old_w = original_size
39 | new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
40 | coords = deepcopy(coords).astype(float)
41 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
42 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
43 | return coords
44 |
45 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
46 | """
47 | Expects a numpy array shape Bx4. Requires the original image size
48 | in (H, W) format.
49 | """
50 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
51 | return boxes.reshape(-1, 4)
52 |
53 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
54 | """
55 | Expects batched images with shape BxCxHxW and float format. This
56 | transformation may not exactly match apply_image. apply_image is
57 | the transformation expected by the model.
58 | """
59 | # Expects an image in BCHW format. May not exactly match apply_image.
60 |         target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
61 | return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)
62 |
63 | def apply_coords_torch(self, coords: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
64 | """
65 | Expects a torch tensor with length 2 in the last dimension. Requires the
66 | original image size in (H, W) format.
67 | """
68 | old_h, old_w = original_size
69 | new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
70 | coords = deepcopy(coords).to(torch.float)
71 | coords[..., 0] = coords[..., 0] * (new_w / old_w)
72 | coords[..., 1] = coords[..., 1] * (new_h / old_h)
73 | return coords
74 |
75 | def apply_boxes_torch(self, boxes: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
76 | """
77 | Expects a torch tensor with shape Bx4. Requires the original image
78 | size in (H, W) format.
79 | """
80 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
81 | return boxes.reshape(-1, 4)
82 |
83 | @staticmethod
84 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
85 | """
86 | Compute the output size given input size and target long side length.
87 | """
88 | scale = long_side_length * 1.0 / max(oldh, oldw)
89 | newh, neww = oldh * scale, oldw * scale
90 | neww = int(neww + 0.5)
91 | newh = int(newh + 0.5)
92 | return (newh, neww)
93 |
--------------------------------------------------------------------------------
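A short sketch of `ResizeLongestSide`, assuming the usual SAM target of 1024 pixels on the longest side; the 2000x3000 image and the box coordinates are toy values.

```python
# Sketch: resize an image to a 1024-pixel longest side and rescale a box with it.
import numpy as np

from metaseg.utils.transforms import ResizeLongestSide

transform = ResizeLongestSide(1024)

image = np.zeros((2000, 3000, 3), dtype=np.uint8)        # H x W x C
resized = transform.apply_image(image)
print(resized.shape)                                     # (683, 1024, 3)

box = np.array([[300, 200, 900, 700]])                   # XYXY in original coordinates
print(transform.apply_boxes(box, image.shape[:2]))       # same box in resized coordinates
```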
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.18.0
2 | aiofiles==23.1.0
3 | aiohttp==3.8.4
4 | aiosignal==1.3.1
5 | altair==4.2.2
6 | antlr4-python3-runtime==4.9.3
7 | anyio==3.6.2
8 | async-timeout==4.0.2
9 | attrs==23.1.0
10 | auth0-python==4.2.0
11 | certifi==2022.12.7
12 | cffi==1.15.1
13 | charset-normalizer==2.1.1
14 | click==8.0.4
15 | colorama==0.4.6
16 | contourpy==1.0.7
17 | cryptography==40.0.2
18 | cycler==0.11.0
19 | datadog-api-client==2.12.0
20 | Deprecated==1.2.13
21 | diffusers==0.14.0
22 | dill==0.3.5.1
23 | distlib==0.3.6
24 | entrypoints==0.4
25 | exceptiongroup==1.1.1
26 | fal-serverless==0.6.29
27 | fastapi==0.95.1
28 | ffmpy==0.3.0
29 | filelock==3.12.0
30 | fire==0.5.0
31 | Flask==1.1.4
32 | Flask-Cors==3.0.10
33 | flaskwebgui==0.3.5
34 | fonttools==4.39.3
35 | frozenlist==1.3.3
36 | fsspec==2023.4.0
37 | gradio==3.28.1
38 | gradio_client==0.1.4
39 | grpc-interceptor==0.15.1
40 | grpcio==1.54.0
41 | h11==0.14.0
42 | httpcore==0.17.0
43 | httpx==0.24.0
44 | huggingface-hub==0.14.1
45 | idna==3.4
46 | imageio==2.28.1
47 | importlib-metadata==6.0.1
48 | iniconfig==2.0.0
49 | isolate==0.12.0
50 | isolate-proto==0.0.27
51 | itsdangerous==1.1.0
52 | Jinja2==2.11.3
53 | jsonschema==4.17.3
54 | kiwisolver==1.4.4
55 | lama-cleaner==1.1.2
56 | linkify-it-py==2.0.2
57 | loguru==0.7.0
58 | markdown-it-py==2.2.0
59 | MarkupSafe==2.0.1
60 | matplotlib==3.7.1
61 | mdit-py-plugins==0.3.3
62 | mdurl==0.1.2
63 | mediapipe==0.9.3.0
64 | mpmath==1.2.1
65 | multidict==6.0.4
66 | networkx==3.0
67 | numpy==1.24.1
68 | omegaconf==2.3.0
69 | opencv-contrib-python==4.7.0.72
70 | opencv-python==4.7.0.72
71 | opentelemetry-api==1.17.0
72 | opentelemetry-sdk==1.17.0
73 | opentelemetry-semantic-conventions==0.38b0
74 | orjson==3.8.11
75 | packaging==23.1
76 | pandas==2.0.1
77 | piexif==1.1.3
78 | Pillow==9.3.0
79 | platformdirs==3.5.0
80 | pluggy==1.0.0
81 | portalocker==2.7.0
82 | protobuf==4.22.3
83 | psutil==5.9.5
84 | pybboxes==0.1.6
85 | pycparser==2.21
86 | pydantic==1.10.7
87 | pydub==0.25.1
88 | Pygments==2.15.1
89 | PyJWT==2.6.0
90 | pyparsing==3.0.9
91 | pyrsistent==0.19.3
92 | pytest==7.3.1
93 | python-dateutil==2.8.2
94 | python-multipart==0.0.6
95 | pytz==2023.3
96 | PyWavelets==1.4.1
97 | pywin32==306
98 | PyYAML==6.0
99 | regex==2023.5.4
100 | requests==2.28.1
101 | rich==13.3.5
102 | safetensors==0.3.1
103 | sahi==0.11.13
104 | scikit-image==0.19.3
105 | scipy==1.10.1
106 | seaborn==0.12.2
107 | semantic-version==2.10.0
108 | sentry-sdk==1.21.1
109 | shapely==2.0.1
110 | six==1.16.0
111 | sniffio==1.3.0
112 | sounddevice==0.4.6
113 | starlette==0.26.1
114 | structlog==22.3.0
115 | sympy==1.11.1
116 | tblib==1.7.0
117 | termcolor==2.3.0
118 | terminaltables==3.1.10
119 | thop==0.1.1.post2209072238
120 | tifffile==2023.4.12
121 | tokenizers==0.13.3
122 | tomli==2.0.1
123 | toolz==0.12.0
124 | torch==2.0.0+cu118
125 | torchaudio==2.0.1+cu118
126 | torchvision==0.15.1+cu118
127 | tqdm==4.65.0
128 | transformers==4.27.4
129 | typing_extensions==4.4.0
130 | tzdata==2023.3
131 | uc-micro-py==1.0.2
132 | ultralytics==8.0.91
133 | urllib3==1.26.13
134 | uvicorn==0.22.0
135 | virtualenv==20.23.0
136 | websockets==11.0.2
137 | Werkzeug==1.0.1
138 | whichcraft==0.6.1
139 | win32-setctime==1.1.0
140 | wrapt==1.15.0
141 | yacs==0.1.8
142 | yarl==1.9.2
143 | zipp==3.15.0
144 |
--------------------------------------------------------------------------------