├── .idea
│   ├── .gitignore
│   ├── githubai.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── README_cn.md
├── app.py
├── demo.py
├── example
│   ├── 1.gif
│   ├── 1683122305662.png
│   ├── 1683122435166.png
│   ├── 1683134557206.png
│   ├── 2.gif
│   ├── 3.gif
│   ├── 4.gif
│   ├── 5.gif
│   ├── 6.gif
│   ├── image.jpg
│   ├── imagemask.jpg
│   ├── images.png
│   └── mask1.jpg
├── lamamodel.py
├── metaseg
│   ├── __init__.py
│   ├── falai_demo.py
│   ├── generator
│   │   ├── __init__.py
│   │   ├── automatic_mask_generator.py
│   │   ├── build_sam.py
│   │   └── predictor.py
│   ├── mask_predictor.py
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── image_encoder.py
│   │   ├── mask_decoder.py
│   │   ├── prompt_encoder.py
│   │   ├── sam.py
│   │   └── transformer.py
│   ├── sahi_predict.py
│   └── utils
│       ├── __init__.py
│       ├── amg.py
│       ├── data_utils.py
│       ├── file_utils.py
│       ├── onnx.py
│       └── transforms.py
└── requirements.txt

/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/githubai.iml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | English | [简体中文](README_cn.md)
2 | # Modify-Anything: Segment Anything Meets Video and Image Modify and Picture Video Background Replacement
3 |
4 | Modify-Anything uses YOLOv5 and YOLOv8 to detect specified targets in videos and images, then applies Segment Anything and lama_cleaner to
5 | segment, modify, and erase those targets. The extracted targets can then be placed on a new background,
6 | and both image and video backgrounds are supported.
7 |
8 |
9 | ## Installation
10 | The code requires python>=3.8, as well as pytorch>=1.7 and torchvision>=0.8. Please follow the official PyTorch instructions
11 | to install both the PyTorch and TorchVision dependencies. Installing both PyTorch and TorchVision with CUDA support
12 | is strongly recommended.
13 | To install Modify-Anything, follow these steps:
14 | - The models are downloaded automatically on the first run. If the download is too slow, download them manually and place them as follows:
15 | - Train your own YOLOv5 or YOLOv8 model to detect the targets you want to segment, modify, or erase.
16 | By default the project uses "yolov5l.pt", "yolov5l6.pt", "yolov8l.pt", and "yolov8x.pt".
17 | Download them and place them in the project root directory.
18 | - Download the Segment Anything checkpoints into the project root directory and rename them: sam_vit_h_4b8939.pth to vit_h.pth, sam_vit_l_0b3195.pth to vit_l.pth, and sam_vit_b_01ec64.pth to vit_b.pth.
19 | - Install the dependencies with `pip install ultralytics sahi fal_serverless lama_cleaner tqdm` or `pip install -r requirements.txt`.
20 | - Run `python app.py`.
21 | - Generated images, masks, and frames are written under the `output` directory; the rendered videos (`output.mp4` to `output4.mp4`) are written to the working directory.
22 |
23 |
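The checkpoint-renaming step can be scripted, together with creating the `output` subfolders that `demo.py` writes into. The sketch below is a hypothetical helper (`prepare_project`, `SAM_RENAMES`, and `OUTPUT_DIRS` are not part of the repository) and assumes the SAM checkpoints were downloaded into the project root under their original names. The folder list is taken from the paths used in `demo.py`; that file does not appear to create them, and `cv2.imwrite` does not create missing directories, so writes into a missing folder fail.

```python
from pathlib import Path

# Original SAM checkpoint names -> names this project expects (README step 18).
SAM_RENAMES = {
    "sam_vit_h_4b8939.pth": "vit_h.pth",
    "sam_vit_l_0b3195.pth": "vit_l.pth",
    "sam_vit_b_01ec64.pth": "vit_b.pth",
}

# Subfolders of output/ that demo.py writes into.
OUTPUT_DIRS = [
    "aimage", "amask", "aimages", "alama",  # single-image pipeline (lama_image_app)
    "image", "mask", "images", "lama",      # per-frame video pipeline (lama_video_app)
]


def prepare_project(root="."):
    """Hypothetical helper: rename downloaded SAM checkpoints and create output folders.

    Returns the list of new checkpoint names that were actually renamed.
    """
    root = Path(root)
    renamed = []
    for src_name, dst_name in SAM_RENAMES.items():
        src, dst = root / src_name, root / dst_name
        # Only rename when the original file exists and the target name is still free.
        if src.exists() and not dst.exists():
            src.rename(dst)
            renamed.append(dst_name)
    for name in OUTPUT_DIRS:
        # parents=True also creates output/ itself; exist_ok=True makes reruns safe.
        (root / "output" / name).mkdir(parents=True, exist_ok=True)
    return renamed
```

Calling `prepare_project()` once from the project root before `python app.py` should be enough; rerunning it is harmless because existing names and folders are left alone.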

24 | 25 |

26 | 27 | ## Modify Anything Image and Picture Video Background Replacement 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 |
36 | 37 | 38 | 39 | 40 | 41 | 42 |
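Background replacement in `demo.py` (`change_aimage`, `change_image`, `change_video`) reduces to a per-pixel blend: the saved mask is thresholded at 240 to 0/1 with `cv2.threshold(..., 240, 1, cv2.THRESH_BINARY)`, and the output keeps the original pixel where the mask is 1 and the background pixel where it is 0. A minimal NumPy-only sketch of that step, assuming the background has already been resized to the foreground size (`demo.py` does this with `cv2.resize`); the function name `composite` is hypothetical.

```python
import numpy as np


def composite(foreground, mask, background, thresh=240):
    """Hypothetical helper: keep foreground where the mask is white, background elsewhere.

    foreground, background: uint8 arrays of shape (H, W, 3), already the same size.
    mask: uint8 array of shape (H, W) or (H, W, 3) with values near 0 or 255.
    """
    if foreground.shape != background.shape:
        raise ValueError("resize the background to the foreground size first")
    # Same rule as cv2.threshold(mask, 240, 1, cv2.THRESH_BINARY): 1 where mask > thresh.
    # The threshold tolerates JPEG noise, since masks are saved as .jpg files.
    keep = (mask > thresh).astype(np.uint8)
    if keep.ndim == 2:  # grayscale mask: add a channel axis so it broadcasts over BGR
        keep = keep[..., np.newaxis]
    # Exactly one of the two terms is non-zero at every pixel, so uint8 cannot overflow.
    return (keep * foreground + (1 - keep) * background).astype(np.uint8)
```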
43 | 44 | ## Modify Anything Video and Picture Video Background Replacement 45 | 46 | 47 | 48 | 49 | 50 | 51 |
52 | 53 | 54 | 55 | 56 | 57 |
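For a video background, `change_video` reads one background frame per input frame and rewinds the background capture with `cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)` when its frame counter reaches the background's frame count, so a shorter background clip is looped. The sketch below shows the same looping idea on in-memory frame lists using `itertools.cycle` instead of `cv2.VideoCapture`; `composite_frames` is a hypothetical name, and all frames are assumed to be already resized to one common size.

```python
from itertools import cycle

import numpy as np


def composite_frames(fg_frames, masks, bg_frames, thresh=240):
    """Hypothetical helper: blend each foreground frame onto a looping background clip.

    All frames are assumed to be uint8 arrays of the same (H, W, 3) size;
    masks are (H, W) or (H, W, 3) uint8 arrays with values near 0 or 255.
    """
    out = []
    # zip() stops when the foreground ends; cycle() restarts the background clip,
    # which plays the role of rewinding cap_bg in change_video.
    for fg, mask, bg in zip(fg_frames, masks, cycle(bg_frames)):
        keep = (mask > thresh).astype(np.uint8)  # same 240 threshold as demo.py
        if keep.ndim == 2:
            keep = keep[..., np.newaxis]
        out.append((keep * fg + (1 - keep) * bg).astype(np.uint8))
    return out
```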
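58 |
59 | ## Acknowledgments
60 | - [LaMa](https://github.com/advimman/lama)
61 | - [Segment Anything](https://github.com/facebookresearch/segment-anything)
62 | - [YOLOv8](https://github.com/ultralytics/ultralytics)
63 |
64 | ## Citation
65 | If you find this work useful for your research, please cite us:
66 | ```
67 | @article{
68 | title={Modify-Anything: Segment Anything Meets Video and Image Modify and Picture Video Background Replacement},
69 | author={Zhang Jing},
70 | year={2023}
71 | }
72 | ```
73 |
--------------------------------------------------------------------------------
/README_cn.md:
--------------------------------------------------------------------------------
1 | 简体中文 | [English](README.md)
2 | # Modify-Anything: Segment Anything Meets Video and Image Modify and Picture Video Background Replacement
3 |
4 | Modify-Anything uses YOLOv5 and YOLOv8 to detect specified targets in videos and images, then applies segment-anything and lama_cleaner to
5 | segment, modify, and erase those targets. The extracted targets can be placed on a new background; both image and video backgrounds are supported.
6 |
7 | ## Installation
8 | The code requires python>=3.8, as well as pytorch>=1.7 and torchvision>=0.8. Please follow the official instructions to install the PyTorch and TorchVision dependencies.
9 | Installing both PyTorch and TorchVision with CUDA support is strongly recommended.
10 |
11 | To install Modify-Anything, follow these steps:
12 | - The models are downloaded automatically on the first run. If the download is too slow, download them manually and place them as follows:
13 | - Train your own YOLOv5 or YOLOv8 model for the targets you want to detect, segment, modify, or erase. By default the project uses "yolov5l.pt", "yolov5l6.pt", "yolov8l.pt", and "yolov8x.pt"; download them yourself and place them in the project root directory.
14 | - Download the Segment Anything checkpoints into the project root directory and rename them: sam_vit_h_4b8939.pth to vit_h.pth, sam_vit_l_0b3195.pth to vit_l.pth, and sam_vit_b_01ec64.pth to vit_b.pth.
15 | - Install the dependencies with pip install ultralytics sahi fal_serverless lama_cleaner tqdm or pip install -r requirements.txt
16 | - Run python app.py
17 | - Generated images, masks, and frames are written under the output directory; the rendered videos are written to the working directory.
18 |
19 |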

20 | 21 |

22 | 23 | ## Modify Anything Image and Picture Video Background Replacement 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 |
32 | 33 | 34 | 35 | 36 | 37 | 38 |
39 | 40 | ## Modify Anything Video and Picture Video Background Replacement 41 | 42 | 43 | 44 | 45 | 46 | 47 |
48 | 49 | 50 | 51 | 52 | 53 |
54 | 55 | ## Acknowledgments 56 | - [LaMa](https://github.com/advimman/lama) 57 | - [Segment Anything](https://github.com/facebookresearch/segment-anything) 58 | - [YOLOv8](https://github.com/ultralytics/ultralytics) 59 | 60 | ## Citation 61 | If you find this work useful for your research, please cite us: 62 | ``` 63 | @article{ 64 | title={Modify-Anything: Segment Anything Meets Video and Image Modify and Picture Video Background Replacement}, 65 | author={Zhang Jing}, 66 | year={2023} 67 | } 68 | ``` -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from demo import lama_image_app, lama_video_app, change_video, change_image, change_aimage, change_avideo 3 | 4 | available_lamamodels = ["lama", "ldm", "zits", "mat", "fcf", "manga", "sd2"] 5 | default_lamammodel = "lama" 6 | 7 | 8 | def image_app(): 9 | with gr.Blocks(): 10 | with gr.Row(): 11 | with gr.Column(): 12 | image_file = gr.Image(type="filepath").style(height=260) 13 | with gr.Row(): 14 | with gr.Column(): 15 | image_model_type = gr.Dropdown( 16 | choices=[ 17 | "vit_h", 18 | "vit_l", 19 | "vit_b", 20 | ], 21 | value="vit_l", 22 | label="Model Type", 23 | ) 24 | 25 | with gr.Row(): 26 | with gr.Column(): 27 | sahi_model_type = gr.Dropdown( 28 | choices=[ 29 | "yolov5", 30 | "yolov8", 31 | ], 32 | value="yolov8", 33 | label="Detector Model Type", 34 | ) 35 | sahi_image_size = gr.Slider( 36 | minimum=0, 37 | maximum=1600, 38 | step=32, 39 | value=640, 40 | label="Image Size", 41 | ) 42 | 43 | sahi_overlap_width = gr.Slider( 44 | minimum=0, 45 | maximum=1, 46 | step=0.1, 47 | value=0.2, 48 | label="Overlap Width", 49 | ) 50 | 51 | sahi_slice_width = gr.Slider( 52 | minimum=0, 53 | maximum=640, 54 | step=32, 55 | value=256, 56 | label="Slice Width", 57 | ) 58 | 59 | with gr.Row(): 60 | with gr.Column(): 61 | sahi_model_path = gr.Dropdown( 62 | 
choices=["yolov5l.pt", "yolov5l6.pt", "yolov8l.pt", "yolov8x.pt"], 63 | value="yolov8l.pt", 64 | label="Detector Model Path", 65 | ) 66 | selected_lamamodel = gr.Dropdown(choices=available_lamamodels, label="lama Model(lama, ldm, zits, mat, fcf, mang, sd2)", 67 | value=default_lamammodel, 68 | interactive=True) 69 | 70 | sahi_conf_th = gr.Slider( 71 | minimum=0, 72 | maximum=1, 73 | step=0.1, 74 | value=0.2, 75 | label="Confidence Threshold", 76 | ) 77 | sahi_overlap_height = gr.Slider( 78 | minimum=0, 79 | maximum=1, 80 | step=0.1, 81 | value=0.2, 82 | label="Overlap Height", 83 | ) 84 | sahi_slice_height = gr.Slider( 85 | minimum=0, 86 | maximum=640, 87 | step=32, 88 | value=256, 89 | label="Slice Height", 90 | ) 91 | image_predict = gr.Button(value="Generate vector images from targets(去目标生成矢量图片)") 92 | 93 | with gr.Column(): 94 | output_image = gr.Gallery() 95 | 96 | image_predict.click( 97 | fn=lama_image_app, 98 | inputs=[ 99 | image_file, 100 | image_model_type, 101 | selected_lamamodel, 102 | sahi_model_type, 103 | sahi_model_path, 104 | sahi_conf_th, 105 | sahi_image_size, 106 | sahi_slice_height, 107 | sahi_slice_width, 108 | sahi_overlap_height, 109 | sahi_overlap_width, 110 | ], 111 | outputs=[output_image], 112 | ) 113 | with gr.Row(): 114 | with gr.Column(): 115 | b_image = gr.Image(type="filepath") 116 | with gr.Column(): 117 | output_change = gr.Image() 118 | with gr.Row(): 119 | change_images = gr.Button(value="Target and background image(目标与背景图)") 120 | change_images.click( 121 | fn=change_aimage, 122 | inputs=[image_file, b_image], 123 | outputs=[output_change] 124 | 125 | ) 126 | with gr.Row(): 127 | with gr.Column(): 128 | b_video = gr.Video(type="filepath").style(height=260) 129 | 130 | with gr.Column(): 131 | output_change = gr.Video() 132 | with gr.Row(): 133 | change_videos = gr.Button(value="Target and Background Video(目标与背景视频)") 134 | change_videos.click( 135 | fn=change_avideo, 136 | inputs=[image_file, b_video], 137 | 
outputs=[output_change] 138 | 139 | ) 140 | 141 | 142 | def video_app(): 143 | with gr.Blocks(): 144 | with gr.Row(): 145 | with gr.Column(): 146 | sahi_image_file = gr.Video().style(height=260) 147 | sahi_autoseg_model_type = gr.Dropdown( 148 | choices=[ 149 | "vit_h", 150 | "vit_l", 151 | "vit_b", 152 | ], 153 | value="vit_l", 154 | label="Sam Model Type", 155 | ) 156 | 157 | with gr.Row(): 158 | with gr.Column(): 159 | sahi_model_type = gr.Dropdown( 160 | choices=[ 161 | "yolov5", 162 | "yolov8", 163 | ], 164 | value="yolov8", 165 | label="Detector Model Type", 166 | ) 167 | sahi_image_size = gr.Slider( 168 | minimum=0, 169 | maximum=1600, 170 | step=32, 171 | value=640, 172 | label="Image Size", 173 | ) 174 | 175 | sahi_overlap_width = gr.Slider( 176 | minimum=0, 177 | maximum=1, 178 | step=0.1, 179 | value=0.2, 180 | label="Overlap Width", 181 | ) 182 | 183 | sahi_slice_width = gr.Slider( 184 | minimum=0, 185 | maximum=640, 186 | step=32, 187 | value=256, 188 | label="Slice Width", 189 | ) 190 | 191 | with gr.Row(): 192 | with gr.Column(): 193 | sahi_model_path = gr.Dropdown( 194 | choices=["yolov5l.pt", "yolov5l6.pt", "yolov8l.pt", "yolov8x.pt"], 195 | value="yolov8l.pt", 196 | label="Detector Model Path", 197 | ) 198 | selected_lamamodel = gr.Dropdown(choices=available_lamamodels, label="lama Model(lama, ldm, zits, mat, fcf, mang, sd2)", 199 | value=default_lamammodel, 200 | interactive=True) 201 | 202 | sahi_conf_th = gr.Slider( 203 | minimum=0, 204 | maximum=1, 205 | step=0.1, 206 | value=0.2, 207 | label="Confidence Threshold", 208 | ) 209 | sahi_overlap_height = gr.Slider( 210 | minimum=0, 211 | maximum=1, 212 | step=0.1, 213 | value=0.2, 214 | label="Overlap Height", 215 | ) 216 | sahi_slice_height = gr.Slider( 217 | minimum=0, 218 | maximum=640, 219 | step=32, 220 | value=256, 221 | label="Slice Height", 222 | ) 223 | sahi_image_predict = gr.Button(value="Generate vector video by removing targets(去目标生成矢量视频)") 224 | 225 | with gr.Column(): 226 | 
output_video = gr.Video() 227 | output_video1 = gr.Video() 228 | 229 | sahi_image_predict.click( 230 | fn=lama_video_app, 231 | inputs=[ 232 | sahi_image_file, 233 | sahi_autoseg_model_type, 234 | selected_lamamodel, 235 | sahi_model_type, 236 | sahi_model_path, 237 | sahi_conf_th, 238 | sahi_image_size, 239 | sahi_slice_height, 240 | sahi_slice_width, 241 | sahi_overlap_height, 242 | sahi_overlap_width, 243 | ], 244 | outputs=[output_video, output_video1], 245 | ) 246 | with gr.Row(): 247 | with gr.Column(): 248 | b_image = gr.Image(type="filepath").style(height=260) 249 | 250 | with gr.Column(): 251 | output_change = gr.Video() 252 | 253 | with gr.Row(): 254 | change_images = gr.Button(value="Target and background image(目标与背景图)") 255 | change_images.click( 256 | fn=change_image, 257 | inputs=[sahi_image_file, b_image], 258 | outputs=[output_change]) 259 | 260 | with gr.Row(): 261 | with gr.Column(): 262 | b_video = gr.Video() 263 | with gr.Column(): 264 | output_change = gr.Video() 265 | with gr.Row(): 266 | change_videos = gr.Button(value="Target and Background Video(目标与背景视频)") 267 | change_videos.click( 268 | fn=change_video, 269 | inputs=[sahi_image_file, b_video], 270 | outputs=[output_change] 271 | 272 | ) 273 | 274 | 275 | def metaseg_app(): 276 | app = gr.Blocks() 277 | with app: 278 | with gr.Row(): 279 | with gr.Column(): 280 | with gr.Tab("Image"): 281 | image_app() 282 | 283 | with gr.Tab("Video"): 284 | video_app() 285 | 286 | app.queue(concurrency_count=1) 287 | app.launch(debug=True, enable_queue=True) 288 | 289 | 290 | if __name__ == "__main__": 291 | metaseg_app() 292 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | from metaseg import SahiAutoSegmentation,sahi_sliced_predict 2 | import cv2 3 | from tqdm import tqdm 4 | import os 5 | from lamamodel import lamamain 6 | import numpy as np 7 | 8 | def load_video(video_path, 
output_path="output.mp4"): 9 | cap = cv2.VideoCapture(video_path) 10 | frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 11 | frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 12 | fourcc = cv2.VideoWriter_fourcc(*"XVID") 13 | fps = int(cap.get(cv2.CAP_PROP_FPS)) 14 | out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height)) 15 | 16 | return cap, out 17 | def lama_image_app( 18 | image_path, 19 | sam_model_type, 20 | selected_lamamodel, 21 | detection_model_type, 22 | detection_model_path, 23 | conf_th, 24 | image_size, 25 | slice_height, 26 | slice_width, 27 | overlap_height_ratio, 28 | overlap_width_ratio, 29 | ): 30 | img = cv2.imread(image_path, cv2.IMREAD_COLOR) 31 | width = img.shape[1] 32 | height = img.shape[0] 33 | boxes = sahi_sliced_predict( 34 | image_path=img, 35 | detection_model_type=detection_model_type, # yolov8, detectron2, mmdetection, torchvision 36 | detection_model_path=detection_model_path, 37 | conf_th=conf_th, 38 | image_size=image_size, 39 | slice_height=slice_height, 40 | slice_width=slice_width, 41 | overlap_height_ratio=overlap_height_ratio, 42 | overlap_width_ratio=overlap_width_ratio, 43 | ) 44 | if len(boxes) == 0: 45 | boxes = [0, 0, 0, 0] 46 | 47 | masks =SahiAutoSegmentation().predict( 48 | source=img, 49 | model_type=sam_model_type, 50 | input_box=boxes, 51 | multimask_output=False, 52 | random_color=False, 53 | show=False, 54 | ) 55 | a = np.full((1, height, width), False) 56 | for idx, mask in enumerate(masks): 57 | if mask[0] != None: 58 | bmasks = mask.detach().cpu().numpy() 59 | mask2 = np.logical_or(a, bmasks) 60 | a[:, :, :] = mask2 61 | h, w = a.shape[-2:] 62 | mask_image = a.reshape(h, w, 1) 63 | cv2.imwrite('./output/aimage/image.jpg', img) 64 | cv2.imwrite("output/amask/mask.jpg" , mask_image * 255) 65 | mask_image = mask_image * img 66 | tmp = cv2.cvtColor(mask_image, cv2.COLOR_BGR2GRAY) 67 | _, alpha = cv2.threshold(tmp, 0, 255, cv2.THRESH_BINARY) 68 | b, g, r = cv2.split(mask_image) 69 | 
rgba = [b, g, r, alpha] 70 | dst = cv2.merge(rgba, 4) 71 | 72 | 73 | cv2.imwrite("output/aimages/images.png" , dst) 74 | 75 | lamaindex = lamamain("output/aimage/image.jpg", "output/amask/mask.jpg" , selected_lamamodel) 76 | lamaimage = cv2.cvtColor(lamaindex, cv2.COLOR_BGR2RGB) 77 | cv2.imwrite('output/alama/imagemask.jpg' , lamaimage) 78 | 79 | lists=["output/aimage/image.jpg","output/amask/mask.jpg","output/aimages/images.png","output/alama/imagemask.jpg"] 80 | 81 | 82 | return lists 83 | 84 | 85 | 86 | def lama_video_app( 87 | image_path, 88 | sam_model_type, 89 | selected_lamamodel, 90 | detection_model_type, 91 | detection_model_path, 92 | conf_th, 93 | image_size, 94 | slice_height, 95 | slice_width, 96 | overlap_height_ratio, 97 | overlap_width_ratio, 98 | ): 99 | cap, out = load_video(image_path) 100 | length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 101 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 102 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 103 | index = 1 104 | for _ in tqdm(range(length)): 105 | ret, frame = cap.read() 106 | if not ret: 107 | break 108 | 109 | boxes = sahi_sliced_predict( 110 | image_path=frame, 111 | detection_model_type=detection_model_type, # yolov8, detectron2, mmdetection, torchvision 112 | detection_model_path=detection_model_path, 113 | conf_th=conf_th, 114 | image_size=image_size, 115 | slice_height=slice_height, 116 | slice_width=slice_width, 117 | overlap_height_ratio=overlap_height_ratio, 118 | overlap_width_ratio=overlap_width_ratio, 119 | ) 120 | if len(boxes) == 0: 121 | boxes = [0, 0, 0, 0] 122 | 123 | 124 | masks = SahiAutoSegmentation().predict( 125 | source=frame, 126 | model_type=sam_model_type, 127 | input_box=boxes, 128 | multimask_output=False, 129 | random_color=False, 130 | show=False, 131 | ) 132 | a = np.full((1, height, width), False) 133 | for idx, mask in enumerate(masks): 134 | if mask[0] != None: 135 | bmasks = mask.detach().cpu().numpy() 136 | mask2 = np.logical_or(a, bmasks) 137 | a[:, :, 
:] = mask2 138 | h, w = a.shape[-2:] 139 | mask_image = a.reshape(h, w, 1) 140 | cv2.imwrite('./output/image/image%s.jpg' % index, frame) 141 | cv2.imwrite("output/mask/mask%s.jpg" % index, mask_image * 255) 142 | mask_image = mask_image * frame 143 | tmp = cv2.cvtColor(mask_image, cv2.COLOR_BGR2GRAY) 144 | _, alpha = cv2.threshold(tmp, 0, 255, cv2.THRESH_BINARY) 145 | b, g, r = cv2.split(mask_image) 146 | rgba = [b, g, r, alpha] 147 | dst = cv2.merge(rgba, 4) 148 | cv2.imwrite("output/images/images%s.png" % index, dst) 149 | 150 | lamaindex = lamamain("output/image/image%s.jpg" % index, "output/mask/mask%s.jpg" % index,selected_lamamodel) 151 | lamaimage = cv2.cvtColor(lamaindex, cv2.COLOR_BGR2RGB) 152 | cv2.imwrite('output/lama/imagemask%s.jpg' % index, lamaimage) 153 | index += 1 154 | list = os.listdir("./output/lama/") 155 | list.sort(key=lambda x: int(x.replace("imagemask", "").split('.')[0])) 156 | list1 = os.listdir("./output/images/") 157 | list1.sort(key=lambda x: int(x.replace("images", "").split('.')[0])) 158 | frame_width1 = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 159 | frame_height1 = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 160 | fourcc = cv2.VideoWriter_fourcc(*"XVID") 161 | fps = int(cap.get(cv2.CAP_PROP_FPS)) 162 | out1 = cv2.VideoWriter("output1.mp4", fourcc, fps, (frame_width1, frame_height1)) 163 | ## 使用切片将图片名称单独切开 164 | 165 | for path in list: 166 | paths = ("./output/lama/" + path) 167 | frame = cv2.imread(paths) 168 | out.write(frame) 169 | 170 | for path in list1: 171 | paths = ("./output/images/" + path) 172 | frame = cv2.imread(paths) 173 | out1.write(frame) 174 | cap.release() 175 | return "output.mp4", "output1.mp4" 176 | 177 | 178 | def change_video(video,bgvideo): 179 | cap_img = cv2.VideoCapture(video) 180 | fps = cap_img.get(cv2.CAP_PROP_FPS) 181 | total_frames = int(cap_img.get(cv2.CAP_PROP_FRAME_COUNT)) 182 | width = int(cap_img.get(cv2.CAP_PROP_FRAME_WIDTH)) 183 | height = int(cap_img.get(cv2.CAP_PROP_FRAME_HEIGHT)) 184 | cap_out = 
cv2.VideoWriter("output2.mp4", 185 | cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, 186 | (width, height)) 187 | cap_bg = cv2.VideoCapture(bgvideo) 188 | bg_frame_nums = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT) 189 | bg_frame_idx = 1 190 | with tqdm(total=total_frames) as pbar: 191 | img_frame_idx = 1 192 | while cap_img.isOpened(): 193 | ret_img, origin_img = cap_img.read() 194 | if not ret_img: 195 | break 196 | ret_bg, bg = cap_bg.read() 197 | img = cv2.imread("output/mask/mask%s.jpg" % img_frame_idx) 198 | _, mask_thr = cv2.threshold(img, 240, 1, cv2.THRESH_BINARY) 199 | if not ret_bg: 200 | break 201 | bg_frame_idx += 1 202 | if bg_frame_idx == bg_frame_nums: 203 | bg_frame_idx = 1 204 | cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0) 205 | bg = cv2.resize(bg, (width, height)) 206 | if bg.ndim == 2: 207 | bg = bg[..., np.newaxis] 208 | out = (mask_thr * origin_img + (1 - mask_thr) * bg).astype(np.uint8) 209 | cap_out.write(out) 210 | img_frame_idx += 1 211 | pbar.update(1) 212 | cap_img.release() 213 | cap_out.release() 214 | return "output2.mp4" 215 | 216 | def change_image(video,bgimage): 217 | cap_img = cv2.VideoCapture(video) 218 | fps = cap_img.get(cv2.CAP_PROP_FPS) 219 | total_frames = int(cap_img.get(cv2.CAP_PROP_FRAME_COUNT)) 220 | width = int(cap_img.get(cv2.CAP_PROP_FRAME_WIDTH)) 221 | height = int(cap_img.get(cv2.CAP_PROP_FRAME_HEIGHT)) 222 | cap_out = cv2.VideoWriter("output3.mp4", 223 | cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, 224 | (width, height)) 225 | bg=cv2.imread(bgimage) 226 | with tqdm(total=total_frames) as pbar: 227 | img_frame_idx = 1 228 | while cap_img.isOpened(): 229 | ret_img, origin_img = cap_img.read() 230 | if not ret_img: 231 | break 232 | img = cv2.imread("output/mask/mask%s.jpg" % img_frame_idx) 233 | _, mask_thr = cv2.threshold(img, 240, 1, cv2.THRESH_BINARY) 234 | bg = cv2.resize(bg, (width, height)) 235 | if bg.ndim == 2: 236 | bg = bg[..., np.newaxis] 237 | out = (mask_thr * origin_img + (1 - mask_thr) * bg).astype(np.uint8) 
238 | cap_out.write(out) 239 | img_frame_idx += 1 240 | pbar.update(1) 241 | cap_img.release() 242 | cap_out.release() 243 | return "output3.mp4" 244 | 245 | def change_aimage(image,bgimage): 246 | origin_img=cv2.imread(image) 247 | bg=cv2.imread(bgimage) 248 | img = cv2.imread("output/amask/mask.jpg") 249 | _, mask_thr = cv2.threshold(img, 240, 1, cv2.THRESH_BINARY) 250 | width = origin_img.shape[1] 251 | height = origin_img.shape[0] 252 | bg = cv2.resize(bg, (width, height)) 253 | if bg.ndim == 2: 254 | bg = bg[..., np.newaxis] 255 | out = (mask_thr * origin_img + (1 - mask_thr) * bg).astype(np.uint8) 256 | out = cv2.cvtColor(out, cv2.COLOR_BGR2RGB) 257 | return out 258 | def change_avideo(image,bgvideo): 259 | origin_img = cv2.imread(image) 260 | 261 | cap_bg = cv2.VideoCapture(bgvideo) 262 | fps = cap_bg.get(cv2.CAP_PROP_FPS) 263 | bg_frame_nums = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT) 264 | total_frames = int(cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)) 265 | width = int(cap_bg.get(cv2.CAP_PROP_FRAME_WIDTH)) 266 | height = int(cap_bg.get(cv2.CAP_PROP_FRAME_HEIGHT)) 267 | cap_out = cv2.VideoWriter("output4.mp4", 268 | cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, 269 | (width, height)) 270 | mask = cv2.imread("output/amask/mask.jpg") 271 | origin_img = cv2.resize(origin_img, (width, height)) 272 | mask = cv2.resize(mask, (width, height)) 273 | for _ in tqdm(range(total_frames)): 274 | ret_bg, bg = cap_bg.read() 275 | if not ret_bg: 276 | break 277 | _, mask_thr = cv2.threshold(mask, 240, 1, cv2.THRESH_BINARY) 278 | out = (mask_thr * origin_img + (1 - mask_thr) * bg).astype(np.uint8) 279 | cap_out.write(out) 280 | cap_bg.release() 281 | cap_out.release() 282 | return "output4.mp4" -------------------------------------------------------------------------------- /example/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/1.gif 
-------------------------------------------------------------------------------- /example/1683122305662.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/1683122305662.png -------------------------------------------------------------------------------- /example/1683122435166.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/1683122435166.png -------------------------------------------------------------------------------- /example/1683134557206.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/1683134557206.png -------------------------------------------------------------------------------- /example/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/2.gif -------------------------------------------------------------------------------- /example/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/3.gif -------------------------------------------------------------------------------- /example/4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/4.gif -------------------------------------------------------------------------------- /example/5.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/5.gif -------------------------------------------------------------------------------- /example/6.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/6.gif -------------------------------------------------------------------------------- /example/image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/image.jpg -------------------------------------------------------------------------------- /example/imagemask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/imagemask.jpg -------------------------------------------------------------------------------- /example/images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/images.png -------------------------------------------------------------------------------- /example/mask1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/example/mask1.jpg -------------------------------------------------------------------------------- /lamamodel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from lama_cleaner.parse_args import parse_args 3 | import random 4 | import time 5 | import imghdr 6 | from typing import Union 7 | 8 | import cv2 9 | import torch 10 | from loguru import 
logger 11 | from lama_cleaner.model_manager import ModelManager 12 | from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler 13 | 14 | 15 | from enum import Enum 16 | 17 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 18 | 19 | model: ModelManager = None # created in lamamain() 20 | device = torch.device(device) 21 | input_image_path: str = None 22 | is_disable_model_switch: bool = False 23 | is_desktop: bool = False 24 | from lama_cleaner.helper import ( 25 | resize_max_size, 26 | ) 27 | 28 | 29 | def diffuser_callback(i, t, latents): 30 | pass 31 | def get_image_ext(img_bytes): 32 | w = imghdr.what("", img_bytes) 33 | if w is None: 34 | w = "jpeg" 35 | return w 36 | class LDMSampler(str, Enum): 37 | ddim = "ddim" 38 | plms = "plms" 39 | def get_data(img_p ,mask_p): 40 | img = cv2.imread(str(img_p)) 41 | img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB) 42 | mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE) 43 | mask = cv2.dilate( 44 | mask, 45 | np.ones((10, 10), np.uint8), 46 | iterations=1 47 | ) 48 | img = cv2.resize(img, None, fx=1, fy= 1.0, interpolation=cv2.INTER_AREA) 49 | mask = cv2.resize(mask, None, fx=1, fy= 1.0, interpolation=cv2.INTER_NEAREST) 50 | return img, mask 51 | 52 | def process(img_p, mask_p): 53 | image, mask = get_data(img_p=img_p, mask_p=mask_p) 54 | alpha_channel = image[:, :, -1] 55 | if image.shape[:2] != mask.shape[:2]: 56 | raise ValueError(f"Mask shape {mask.shape[:2]} not equal to image shape {image.shape[:2]}") 57 | 58 | original_shape = image.shape 59 | interpolation = cv2.INTER_CUBIC 60 | 61 | size_limit: Union[int, str] = 2500 62 | 63 | if size_limit == "Original": 64 | size_limit = max(image.shape) 65 | else: 66 | size_limit = int(size_limit) 67 | 68 | config = Config( 69 | ldm_steps=1, 70 | ldm_sampler=LDMSampler.plms, 71 | hd_strategy=HDStrategy.ORIGINAL, 72 | hd_strategy_crop_margin=32, 73 | hd_strategy_crop_trigger_size=200, 74 | hd_strategy_resize_limit=200, 75 | ) 76 | if config.sd_seed == -1: 77 |
config.sd_seed = random.randint(1, 999999999) 78 | 79 | logger.info(f"Origin image shape: {original_shape}") 80 | image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation) 81 | logger.info(f"Resized image shape: {image.shape}") 82 | 83 | mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation) 84 | 85 | start = time.time() 86 | 87 | res_np_img = model(image, mask, config) # ----------------------- run the inpainting model 88 | 89 | return res_np_img 90 | 91 | 92 | def lamamain(image,mask,name): 93 | args = parse_args() 94 | global model 95 | global device 96 | global input_image_path 97 | global is_disable_model_switch 98 | global is_desktop 99 | 100 | device = torch.device(args.device) 101 | input_image_path = args.input 102 | is_disable_model_switch = args.disable_model_switch 103 | is_desktop = args.gui 104 | if is_disable_model_switch: 105 | logger.info("Started with --disable-model-switch; model switching on the frontend is disabled") 106 | 107 | model = ModelManager( 108 | name=name, 109 | device=device, 110 | hf_access_token=args.hf_access_token, 111 | sd_disable_nsfw=args.sd_disable_nsfw, 112 | sd_cpu_textencoder=args.sd_cpu_textencoder, 113 | sd_run_local=args.sd_run_local, 114 | callback=diffuser_callback, 115 | ) 116 | image=process(image,mask) 117 | image=np.uint8(image) 118 | images = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 119 | return images 120 | # cv2.imwrite('imagemask.jpg', images) 121 | # print(images) 122 | if __name__ == '__main__': 123 | lamamain("output/aimage/image.jpg", "output/amask/mask.jpg", "lama") -------------------------------------------------------------------------------- /metaseg/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | from metaseg.falai_demo import falai_automask_image, falai_manuelmask_image 8 | from metaseg.generator.automatic_mask_generator import SamAutomaticMaskGenerator 9 | from metaseg.generator.build_sam import build_sam, build_sam_vit_b, build_sam_vit_h, build_sam_vit_l, sam_model_registry 10 | from metaseg.generator.predictor import SamPredictor 11 | from metaseg.mask_predictor import SegAutoMaskPredictor, SegManualMaskPredictor 12 | from metaseg.sahi_predict import SahiAutoSegmentation, sahi_sliced_predict 13 | 14 | __version__ = "0.7.3" 15 | -------------------------------------------------------------------------------- /metaseg/falai_demo.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | from PIL import Image 4 | 5 | from metaseg.mask_predictor import SegAutoMaskPredictor, SegManualMaskPredictor 6 | from metaseg.utils.data_utils import load_server_image 7 | 8 | try: 9 | from fal_serverless import isolated 10 | except ImportError: 11 | raise ImportError("Please install FalAI library using 'pip install fal_serverless'.") 12 | 13 | 14 | @isolated(requirements=["metaseg"], keep_alive=1800, machine_type="GPU-T4") 15 | def automask_image(data, model_type="vit_b", points_per_side=16, points_per_batch=32, min_area=0): 16 | image_path, output_path = load_server_image(data) 17 | SegAutoMaskPredictor().image_predict( 18 | source=image_path, 19 | model_type=model_type, 20 | points_per_side=points_per_side, 21 | points_per_batch=points_per_batch, 22 | min_area=min_area, 23 | output_path=output_path, 24 | show=False, 25 | save=True, 26 | ) 27 | with open(output_path, "rb") as f: 28 | result = f.read() 29 | 30 | return result 31 | 32 | @isolated(requirements=["metaseg"], keep_alive=1800, machine_type="GPU-T4") 33 | def manuelmask_image( 34 | data, 35 | model_type="vit_b", 36 | input_point=[[100, 100], [200, 200]], 37 | input_label=[0, 1], 38 | input_box=[100, 100, 200, 200], 39 | multimask_output=False, 
40 | random_color=False, 41 | min_area=0, 42 | ): 43 | image_path, output_path = load_server_image(data) 44 | SegManualMaskPredictor().image_predict( 45 | source=image_path, 46 | model_type=model_type, 47 | input_point=input_point, 48 | input_label=input_label, 49 | input_box=input_box, # 50 | multimask_output=multimask_output, 51 | random_color=random_color, 52 | min_area=min_area, # 53 | output_path=output_path, 54 | show=False, 55 | save=True, 56 | ) 57 | with open(output_path, "rb") as f: 58 | result = f.read() 59 | 60 | return result 61 | 62 | 63 | def falai_automask_image(image_path, model_type="vit_b", points_per_side=16, points_per_batch=32, min_area=0): 64 | with open(image_path, "rb") as f: 65 | data = f.read() 66 | 67 | image = automask_image( 68 | data=data, 69 | model_type=model_type, 70 | points_per_side=points_per_side, 71 | points_per_batch=points_per_batch, 72 | min_area=min_area, 73 | ) 74 | image = Image.open(BytesIO(image)) 75 | return image 76 | 77 | 78 | def falai_manuelmask_image( 79 | image_path, 80 | model_type="vit_b", 81 | input_point=[[100, 100], [200, 200]], 82 | input_label=[0, 1], 83 | input_box=[100, 100, 200, 200], 84 | multimask_output=False, 85 | random_color=False, 86 | min_area=0, 87 | ): 88 | with open(image_path, "rb") as f: 89 | data = f.read() 90 | 91 | image = manuelmask_image( 92 | data=data, 93 | model_type=model_type, 94 | input_point=input_point, 95 | input_label=input_label, 96 | input_box=input_box, 97 | multimask_output=multimask_output, 98 | random_color=random_color, 99 | min_area=min_area, 100 | ) 101 | image = Image.open(BytesIO(image)) 102 | return image 103 | -------------------------------------------------------------------------------- /metaseg/generator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxaizj/Modify-Anything/e6bcd7476d32f71ad3ed53e134a045224873cb4f/metaseg/generator/__init__.py 
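`get_data` in lamamodel.py grows the mask with `cv2.dilate` and a 10×10 kernel of ones before inpainting, so the region LaMa repaints extends a few pixels past the object outline and leftover edge pixels get covered too. The plain-Python sketch below shows the idea. `dilate` is a hypothetical helper, and it assumes a symmetric (2r+1)×(2r+1) square kernel, whereas OpenCV anchors an even-sized kernel at its centre, so the real 10×10 dilation grows the mask slightly unevenly.

```python
from typing import List

Mask = List[List[int]]


def dilate(mask: Mask, radius: int) -> Mask:
    """Binary dilation of a 0/1 mask with a (2*radius+1) square kernel.

    An output pixel is 1 if any input pixel within `radius` rows and
    `radius` columns of it is 1 (a symmetric-kernel assumption).
    """
    h, w = len(mask), len(mask[0])
    out = [[0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            if mask[y][x]:
                # Paint the square neighbourhood, clipped at the image border.
                for yy in range(max(0, y - radius), min(h, y + radius + 1)):
                    for xx in range(max(0, x - radius), min(w, x + radius + 1)):
                        out[yy][xx] = 1
    return out
```

A single marked pixel in the middle of a 5×5 mask becomes a 3×3 block with `radius=1`; a larger radius widens the border that the inpainting model is asked to fill.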
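`process` caps both the image and the mask with lama-cleaner's `resize_max_size` (with `size_limit` fixed at 2500) before calling the model. To our understanding, that helper only downscales, only when the longer side exceeds the limit, and preserves the aspect ratio. The hypothetical `max_size_target` below sketches that size arithmetic under this assumption; it computes the target shape only and does no pixel resampling.

```python
from typing import Tuple


def max_size_target(h: int, w: int, size_limit: int) -> Tuple[int, int]:
    """(new_h, new_w) such that the longer side does not exceed size_limit.

    Images already within the limit are returned unchanged (assumed behaviour).
    """
    if max(h, w) <= size_limit:
        return h, w
    ratio = size_limit / max(h, w)
    # Round each side to the nearest whole pixel.
    return int(h * ratio + 0.5), int(w * ratio + 0.5)
```

With the 2500-pixel limit used in `process`, a 2000×1500 photo passes through untouched, while a 5000×3000 image would be halved to 2500×1500.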
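`SamAutomaticMaskGenerator` prompts SAM with a regular grid of points. Per its docstring, `points_per_side` points are sampled along each side (`points_per_side**2` in total), grids are normalised to [0, 1], and the per-side count in crop layer n is divided by `crop_n_points_downscale_factor**n`; `_process_crop` then multiplies the grid by the crop's (width, height). The pure-Python sketch below illustrates this. Placing each point at the centre of its grid cell is our assumption about `build_all_layer_point_grids` in metaseg/utils/amg.py (not shown here), and all three function names are hypothetical.

```python
from typing import List, Tuple

Point = Tuple[float, float]


def build_point_grid(n_per_side: int) -> List[Point]:
    """n_per_side**2 normalised (x, y) points, one at the centre of each cell."""
    coords = [(i + 0.5) / n_per_side for i in range(n_per_side)]
    # Row-major order: x varies fastest within each row of constant y.
    return [(x, y) for y in coords for x in coords]


def points_per_layer(n_per_side: int, crop_n_layers: int, downscale_factor: int) -> List[int]:
    """Points per side for layer 0 (the full image) and each crop layer."""
    return [int(n_per_side / (downscale_factor ** i)) for i in range(crop_n_layers + 1)]


def to_pixels(points: List[Point], width: int, height: int) -> List[Point]:
    """Scale normalised points to pixel coordinates of a width x height crop."""
    return [(x * width, y * height) for x, y in points]
```

With the default `points_per_side=32`, each forward pass over the full image is prompted with 1024 points, processed `points_per_batch` at a time.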
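Before encoding, `SamPredictor` resizes the image so that its longer side equals `image_encoder.img_size` (1024 in `_build_sam`, and `set_torch_image` asserts this), and `predict` maps point prompts into that resized frame with `self.transform.apply_coords(point_coords, self.original_size)`. The sketch below reproduces that mapping in plain Python, assuming `ResizeLongestSide` scales both sides by `long_side / max(h, w)` and rounds to the nearest pixel. `preprocess_shape` and `apply_coords` here are hypothetical stand-ins, not the `metaseg.utils.transforms` API.

```python
from typing import List, Sequence, Tuple


def preprocess_shape(old_h: int, old_w: int, long_side: int) -> Tuple[int, int]:
    """Target (h, w) whose longer side equals long_side, aspect ratio preserved."""
    scale = long_side / max(old_h, old_w)
    return int(old_h * scale + 0.5), int(old_w * scale + 0.5)


def apply_coords(
    points: Sequence[Tuple[float, float]],
    original_size: Tuple[int, int],
    long_side: int = 1024,
) -> List[Tuple[float, float]]:
    """Map (x, y) pixel points from the original image into the resized frame."""
    old_h, old_w = original_size  # note: (H, W), while points are (X, Y)
    new_h, new_w = preprocess_shape(old_h, old_w, long_side)
    # x scales with the width ratio, y with the height ratio.
    return [(x * new_w / old_w, y * new_h / old_h) for x, y in points]
```

A click at (320, 240) on a 640×480 photo therefore reaches the prompt encoder as (512, 384) in the 1024×768 frame; mixing up the (H, W) and (X, Y) orders is an easy way to misplace prompts.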
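`_process_batch` keeps a mask only if its stability score reaches `stability_score_thresh` (0.95 by default), computed by `calculate_stability_score(masks, mask_threshold, stability_score_offset)`. Per the docstring, the score measures how stable the binary mask is when the logit cutoff shifts by `stability_score_offset`. The sketch below assumes the score is the IoU between the masks thresholded at `mask_threshold + offset` and at `mask_threshold - offset`. The stricter mask is a subset of the looser one, so the IoU reduces to a ratio of pixel counts. The helper name is hypothetical, and SAM's `mask_threshold` is assumed to be 0.0.

```python
from typing import Sequence


def stability_score(
    logits: Sequence[Sequence[float]],
    mask_threshold: float = 0.0,
    offset: float = 1.0,
) -> float:
    """IoU of the mask binarised at a raised and at a lowered logit cutoff."""
    values = [v for row in logits for v in row]
    # Pixels above the stricter cutoff are also above the looser one,
    # so intersection = strict count and union = loose count.
    strict = sum(1 for v in values if v > mask_threshold + offset)
    loose = sum(1 for v in values if v > mask_threshold - offset)
    # Choice made for this sketch: an empty mask scores 0.0 instead of dividing by zero.
    return strict / loose if loose else 0.0
```

Masks whose logits hover near the cutoff lose many pixels when it is raised, score low, and are discarded; confident masks with steep logit edges score close to 1.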
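Duplicate masks are removed with torchvision's `batched_nms`, called with all-zero category ids (`torch.zeros(len(data["boxes"]))`), which makes it behave like ordinary greedy non-maximum suppression over XYXY boxes. Within a crop the scores are `iou_preds` and the cutoff is `box_nms_thresh`; across crops the score is `1 / box_area(crop_boxes)`, so masks from smaller crops win. The pure-Python sketch below illustrates greedy NMS under that reading. It is not torchvision's implementation, and `box_iou`/`greedy_nms` are hypothetical names.

```python
from typing import List, Sequence

Box = Sequence[float]  # (x0, y0, x1, y1)


def box_iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned XYXY boxes."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes: List[Box], scores: List[float], iou_threshold: float) -> List[int]:
    """Indices of kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep: List[int] = []
    for i in order:
        # Drop a box if it overlaps an already-kept, higher-scoring box too much.
        if all(box_iou(boxes[i], boxes[k]) <= iou_threshold for k in keep):
            keep.append(i)
    return keep
```

With the default `box_nms_thresh=0.7`, two boxes overlapping at IoU 0.81 collapse to the higher-scoring one, while a distant box survives.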
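Between stages, masks are kept as run-length encodings (`mask_to_rle_pytorch`), their area comes from `area_from_rle`, and `rle_to_mask` decodes them when `output_mode='binary_mask'`. The sketch below assumes the COCO-style uncompressed layout we believe these helpers use: pixels are read in column-major order, counts alternate between runs of 0s and 1s, and the first count is always a (possibly empty) run of 0s, so the area is the sum of every second count. All three functions are hypothetical pure-Python stand-ins for the helpers in metaseg/utils/amg.py.

```python
from typing import Dict, List

Mask = List[List[int]]


def mask_to_rle(mask: Mask) -> Dict[str, List[int]]:
    """Uncompressed RLE of a 0/1 mask, column-major, starting with a 0-run."""
    h, w = len(mask), len(mask[0])
    flat = [mask[y][x] for x in range(w) for y in range(h)]  # column-major
    counts: List[int] = []
    current, run = 0, 0  # encoding always starts by counting 0s
    for v in flat:
        if v != current:
            counts.append(run)
            current, run = v, 0
        run += 1
    counts.append(run)
    return {"size": [h, w], "counts": counts}


def area_from_rle(rle: Dict[str, List[int]]) -> int:
    # Odd-indexed counts are the runs of 1s.
    return sum(rle["counts"][1::2])


def rle_to_mask(rle: Dict[str, List[int]]) -> Mask:
    h, w = rle["size"]
    flat: List[int] = []
    value = 0
    for count in rle["counts"]:
        flat.extend([value] * count)
        value = 1 - value
    # Undo the column-major flattening: pixel (y, x) sits at index x * h + y.
    return [[flat[x * h + y] for x in range(w)] for y in range(h)]
```

Storing RLEs instead of full H×W boolean arrays is what keeps memory bounded when thousands of candidate masks are generated for a large image.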
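`SegAutoMaskPredictor.image_predict` draws its result by filling each mask with a random colour, scaling it to 35% with `cv2.addWeighted(img, 0.35, np.zeros_like(img), 0.65, 0)`, accumulating the tints with `cv2.add`, and finally adding the accumulated tint to the original image with `cv2.add`. Because `cv2.add` on uint8 images saturates, channels that would exceed 255 are clipped at 255 instead of wrapping around. The per-pixel sketch below assumes OpenCV's round-then-clip behaviour for uint8 results; `overlay_pixel` is a hypothetical helper, not part of metaseg.

```python
from typing import Iterable, Tuple

RGB = Tuple[int, int, int]


def saturate(v: float) -> int:
    """Round and clip to the uint8 range, as OpenCV does for uint8 outputs."""
    return max(0, min(255, int(round(v))))


def overlay_pixel(base: RGB, mask_colors: Iterable[RGB], alpha: float = 0.35) -> RGB:
    """Output colour of one pixel covered by every mask in mask_colors."""
    tint = [0, 0, 0]
    for color in mask_colors:
        for c in range(3):
            # addWeighted against a zero image contributes alpha * color;
            # cv2.add accumulates the tints with saturation.
            tint[c] = saturate(tint[c] + saturate(alpha * color[c]))
    # The final cv2.add(read_image, mask_image) also saturates at 255.
    return (
        saturate(base[0] + tint[0]),
        saturate(base[1] + tint[1]),
        saturate(base[2] + tint[2]),
    )
```

Since the tint is added rather than blended, bright regions under a mask wash out towards 255 in the tinted channels, and pixels covered by several overlapping masks get progressively brighter.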
-------------------------------------------------------------------------------- /metaseg/generator/automatic_mask_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Dict, List, Optional, Tuple 8 | 9 | import numpy as np 10 | import torch 11 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore 12 | 13 | from metaseg.generator.predictor import SamPredictor 14 | from metaseg.modeling import Sam 15 | from metaseg.utils.amg import ( 16 | MaskData, 17 | area_from_rle, 18 | batch_iterator, 19 | batched_mask_to_box, 20 | box_xyxy_to_xywh, 21 | build_all_layer_point_grids, 22 | calculate_stability_score, 23 | coco_encode_rle, 24 | generate_crop_boxes, 25 | is_box_near_crop_edge, 26 | mask_to_rle_pytorch, 27 | remove_small_regions, 28 | rle_to_mask, 29 | uncrop_boxes_xyxy, 30 | uncrop_masks, 31 | uncrop_points, 32 | ) 33 | 34 | 35 | class SamAutomaticMaskGenerator: 36 | def __init__( 37 | self, 38 | model: Sam, 39 | points_per_side: Optional[int] = 32, 40 | points_per_batch: int = 64, 41 | pred_iou_thresh: float = 0.88, 42 | stability_score_thresh: float = 0.95, 43 | stability_score_offset: float = 1.0, 44 | box_nms_thresh: float = 0.7, 45 | crop_n_layers: int = 0, 46 | crop_nms_thresh: float = 0.7, 47 | crop_overlap_ratio: float = 512 / 1500, 48 | crop_n_points_downscale_factor: int = 1, 49 | point_grids: Optional[List[np.ndarray]] = None, 50 | min_mask_region_area: int = 0, 51 | output_mode: str = "binary_mask", 52 | ) -> None: 53 | """ 54 | Using a SAM model, generates masks for the entire image. 55 | Generates a grid of point prompts over the image, then filters 56 | low quality and duplicate masks. 
The default settings are chosen 57 | for SAM with a ViT-H backbone. 58 | 59 | Arguments: 60 | model (Sam): The SAM model to use for mask prediction. 61 | points_per_side (int or None): The number of points to be sampled 62 | along one side of the image. The total number of points is 63 | points_per_side**2. If None, 'point_grids' must provide explicit 64 | point sampling. 65 | points_per_batch (int): Sets the number of points run simultaneously 66 | by the model. Higher numbers may be faster but use more GPU memory. 67 | pred_iou_thresh (float): A filtering threshold in [0,1], using the 68 | model's predicted mask quality. 69 | stability_score_thresh (float): A filtering threshold in [0,1], using 70 | the stability of the mask under changes to the cutoff used to binarize 71 | the model's mask predictions. 72 | stability_score_offset (float): The amount to shift the cutoff when 73 | calculating the stability score. 74 | box_nms_thresh (float): The box IoU cutoff used by non-maximal 75 | suppression to filter duplicate masks. 76 | crop_n_layers (int): If >0, mask prediction will be run again on 77 | crops of the image. Sets the number of layers to run, where each 78 | layer has 2**i_layer number of image crops. 79 | crop_nms_thresh (float): The box IoU cutoff used by non-maximal 80 | suppression to filter duplicate masks between different crops. 81 | crop_overlap_ratio (float): Sets the degree to which crops overlap. 82 | In the first crop layer, crops will overlap by this fraction of 83 | the image length. Later layers with more crops scale down this overlap. 84 | crop_n_points_downscale_factor (int): The number of points-per-side 85 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n. 86 | point_grids (list(np.ndarray) or None): A list of explicit grids 87 | of points used for sampling, normalized to [0,1]. The nth grid in the 88 | list is used in the nth crop layer. Exclusive with points_per_side.
89 | min_mask_region_area (int): If >0, postprocessing will be applied 90 | to remove disconnected regions and holes in masks with area smaller 91 | than min_mask_region_area. Requires opencv. 92 | output_mode (str): The form masks are returned in. Can be 'binary_mask', 93 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. 94 | For large resolutions, 'binary_mask' may consume large amounts of 95 | memory. 96 | """ 97 | 98 | assert (points_per_side is None) != ( 99 | point_grids is None 100 | ), "Exactly one of points_per_side or point_grids must be provided." 101 | if points_per_side is not None: 102 | self.point_grids = build_all_layer_point_grids( 103 | points_per_side, 104 | crop_n_layers, 105 | crop_n_points_downscale_factor, 106 | ) 107 | elif point_grids is not None: 108 | self.point_grids = point_grids 109 | else: 110 | raise ValueError("Can't have both points_per_side and point_grids be None.") 111 | 112 | assert output_mode in [ 113 | "binary_mask", 114 | "uncompressed_rle", 115 | "coco_rle", 116 | ], f"Unknown output_mode {output_mode}."
117 | if output_mode == "coco_rle": 118 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401 119 | 120 | if min_mask_region_area > 0: 121 | import cv2 # type: ignore # noqa: F401 122 | 123 | self.predictor = SamPredictor(model) 124 | self.points_per_batch = points_per_batch 125 | self.pred_iou_thresh = pred_iou_thresh 126 | self.stability_score_thresh = stability_score_thresh 127 | self.stability_score_offset = stability_score_offset 128 | self.box_nms_thresh = box_nms_thresh 129 | self.crop_n_layers = crop_n_layers 130 | self.crop_nms_thresh = crop_nms_thresh 131 | self.crop_overlap_ratio = crop_overlap_ratio 132 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor 133 | self.min_mask_region_area = min_mask_region_area 134 | self.output_mode = output_mode 135 | 136 | @torch.no_grad() 137 | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: 138 | """ 139 | Generates masks for the given image. 140 | 141 | Arguments: 142 | image (np.ndarray): The image to generate masks for, in HWC uint8 format. 143 | 144 | Returns: 145 | list(dict(str, any)): A list of records for masks. Each record is 146 | a dict containing the following keys: 147 | segmentation (dict(str, any) or np.ndarray): The mask. If 148 | output_mode='binary_mask', is an array of shape HW. Otherwise, 149 | is a dictionary containing the RLE. 150 | bbox (list(float)): The box around the mask, in XYWH format. 151 | area (int): The area in pixels of the mask. 152 | predicted_iou (float): The model's own prediction of the mask's 153 | quality. This is filtered by the pred_iou_thresh parameter. 154 | point_coords (list(list(float))): The point coordinates input 155 | to the model to generate this mask. 156 | stability_score (float): A measure of the mask's quality. This 157 | is filtered using the stability_score_thresh parameter. 158 | crop_box (list(float)): The crop of the image used to generate 159 | the mask, given in XYWH format.
160 | """ 161 | 162 | # Generate masks 163 | mask_data = self._generate_masks(image) 164 | 165 | # Filter small disconnected regions and holes in masks 166 | if self.min_mask_region_area > 0: 167 | mask_data = self.postprocess_small_regions( 168 | mask_data, 169 | self.min_mask_region_area, 170 | max(self.box_nms_thresh, self.crop_nms_thresh), 171 | ) 172 | 173 | # Encode masks 174 | if self.output_mode == "coco_rle": 175 | mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] 176 | elif self.output_mode == "binary_mask": 177 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] 178 | else: 179 | mask_data["segmentations"] = mask_data["rles"] 180 | 181 | # Write mask records 182 | curr_anns = [] 183 | for idx in range(len(mask_data["segmentations"])): 184 | ann = { 185 | "segmentation": mask_data["segmentations"][idx], 186 | "area": area_from_rle(mask_data["rles"][idx]), 187 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 188 | "predicted_iou": mask_data["iou_preds"][idx].item(), 189 | "point_coords": [mask_data["points"][idx].tolist()], 190 | "stability_score": mask_data["stability_score"][idx].item(), 191 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 192 | } 193 | curr_anns.append(ann) 194 | 195 | return curr_anns 196 | 197 | def _generate_masks(self, image: np.ndarray) -> MaskData: 198 | orig_size = image.shape[:2] 199 | crop_boxes, layer_idxs = generate_crop_boxes(orig_size, self.crop_n_layers, self.crop_overlap_ratio) 200 | 201 | # Iterate over image crops 202 | data = MaskData() 203 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 204 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) 205 | data.cat(crop_data) 206 | 207 | # Remove duplicate masks between crops 208 | if len(crop_boxes) > 1: 209 | # Prefer masks from smaller crops 210 | scores = 1 / box_area(data["crop_boxes"]) 211 | scores = scores.to(data["boxes"].device) 212 | keep_by_nms 
= batched_nms( 213 | data["boxes"].float(), 214 | scores, 215 | torch.zeros(len(data["boxes"])), # categories 216 | iou_threshold=self.crop_nms_thresh, 217 | ) 218 | data.filter(keep_by_nms) 219 | 220 | data.to_numpy() 221 | return data 222 | 223 | def _process_crop( 224 | self, 225 | image: np.ndarray, 226 | crop_box: List[int], 227 | crop_layer_idx: int, 228 | orig_size: Tuple[int, ...], 229 | ) -> MaskData: 230 | # Crop the image and calculate embeddings 231 | x0, y0, x1, y1 = crop_box 232 | cropped_im = image[y0:y1, x0:x1, :] 233 | cropped_im_size = cropped_im.shape[:2] 234 | self.predictor.set_image(cropped_im) 235 | 236 | # Get points for this crop 237 | points_scale = np.array(cropped_im_size)[None, ::-1] 238 | points_for_image = self.point_grids[crop_layer_idx] * points_scale 239 | 240 | # Generate masks for this crop in batches 241 | data = MaskData() 242 | for (points,) in batch_iterator(self.points_per_batch, points_for_image): 243 | batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) 244 | data.cat(batch_data) 245 | del batch_data 246 | self.predictor.reset_image() 247 | 248 | # Remove duplicates within this crop. 
249 | keep_by_nms = batched_nms( 250 | data["boxes"].float(), 251 | data["iou_preds"], 252 | torch.zeros(len(data["boxes"])), # categories 253 | iou_threshold=self.box_nms_thresh, 254 | ) 255 | data.filter(keep_by_nms) 256 | 257 | # Return to the original image frame 258 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) 259 | data["points"] = uncrop_points(data["points"], crop_box) 260 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 261 | 262 | return data 263 | 264 | def _process_batch( 265 | self, 266 | points: np.ndarray, 267 | im_size: Tuple[int, ...], 268 | crop_box: List[int], 269 | orig_size: Tuple[int, ...], 270 | ) -> MaskData: 271 | orig_h, orig_w = orig_size 272 | 273 | # Run model on this batch 274 | transformed_points = self.predictor.transform.apply_coords(points, im_size) 275 | in_points = torch.as_tensor(transformed_points, device=self.predictor.device) 276 | in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 277 | masks, iou_preds, _ = self.predictor.predict_torch( 278 | in_points[:, None, :], 279 | in_labels[:, None], 280 | multimask_output=True, 281 | return_logits=True, 282 | ) 283 | 284 | # Serialize predictions and store in MaskData 285 | data = MaskData( 286 | masks=masks.flatten(0, 1), 287 | iou_preds=iou_preds.flatten(0, 1), 288 | points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), 289 | ) 290 | del masks 291 | 292 | # Filter by predicted IoU 293 | if self.pred_iou_thresh > 0.0: 294 | keep_mask = data["iou_preds"] > self.pred_iou_thresh 295 | data.filter(keep_mask) 296 | 297 | # Calculate stability score 298 | data["stability_score"] = calculate_stability_score( 299 | data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset 300 | ) 301 | if self.stability_score_thresh > 0.0: 302 | keep_mask = data["stability_score"] >= self.stability_score_thresh 303 | data.filter(keep_mask) 304 | 305 | # Threshold masks and calculate boxes 306 | 
data["masks"] = data["masks"] > self.predictor.model.mask_threshold 307 | data["boxes"] = batched_mask_to_box(data["masks"]) 308 | 309 | # Filter boxes that touch crop boundaries 310 | keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) 311 | if not torch.all(keep_mask): 312 | data.filter(keep_mask) 313 | 314 | # Compress to RLE 315 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 316 | data["rles"] = mask_to_rle_pytorch(data["masks"]) 317 | del data["masks"] 318 | 319 | return data 320 | 321 | @staticmethod 322 | def postprocess_small_regions(mask_data: MaskData, min_area: int, nms_thresh: float) -> MaskData: 323 | """ 324 | Removes small disconnected regions and holes in masks, then reruns 325 | box NMS to remove any new duplicates. 326 | 327 | Edits mask_data in place. 328 | 329 | Requires open-cv as a dependency. 330 | """ 331 | if len(mask_data["rles"]) == 0: 332 | return mask_data 333 | 334 | # Filter small disconnected regions and holes 335 | new_masks = [] 336 | scores = [] 337 | for rle in mask_data["rles"]: 338 | mask = rle_to_mask(rle) 339 | 340 | mask, changed = remove_small_regions(mask, min_area, mode="holes") 341 | unchanged = not changed 342 | mask, changed = remove_small_regions(mask, min_area, mode="islands") 343 | unchanged = unchanged and not changed 344 | 345 | new_masks.append(torch.as_tensor(mask).unsqueeze(0)) 346 | # Give score=0 to changed masks and score=1 to unchanged masks 347 | # so NMS will prefer ones that didn't need postprocessing 348 | scores.append(float(unchanged)) 349 | 350 | # Recalculate boxes and remove any new duplicates 351 | masks = torch.cat(new_masks, dim=0) 352 | boxes = batched_mask_to_box(masks) 353 | keep_by_nms = batched_nms( 354 | boxes.float(), 355 | torch.as_tensor(scores), 356 | torch.zeros(len(boxes)), # categories 357 | iou_threshold=nms_thresh, 358 | ) 359 | 360 | # Only recalculate RLEs for masks that have changed 361 | for i_mask in keep_by_nms: 362 | 
if scores[i_mask] == 0.0: 363 | mask_torch = masks[i_mask].unsqueeze(0) 364 | mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 365 | mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 366 | mask_data.filter(keep_by_nms) 367 | 368 | return mask_data 369 | -------------------------------------------------------------------------------- /metaseg/generator/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from functools import partial 8 | 9 | import torch 10 | 11 | from metaseg.modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam, 49 | "vit_h": build_sam, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 
| vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /metaseg/generator/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import Optional, Tuple 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from metaseg.modeling import Sam 13 | from metaseg.utils.transforms import ResizeLongestSide 14 | 15 | 16 | class SamPredictor: 17 | def __init__( 18 | self, 19 | sam_model: Sam, 20 | ) -> None: 21 | """ 22 | Uses SAM to calculate the image embedding for an image, and then 23 | allow repeated, efficient mask prediction given prompts. 24 | 25 | Arguments: 26 | sam_model (Sam): The model to use for mask prediction. 27 | """ 28 | super().__init__() 29 | self.model = sam_model 30 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 31 | self.reset_image() 32 | 33 | def set_image( 34 | self, 35 | image: np.ndarray, 36 | image_format: str = "RGB", 37 | ) -> None: 38 | """ 39 | Calculates the image embeddings for the provided image, allowing 40 | masks to be predicted with the 'predict' method. 41 | 42 | Arguments: 43 | image (np.ndarray): The image for calculating masks. Expects an 44 | image in HWC uint8 format, with pixel values in [0, 255]. 45 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 46 | """ 47 | assert image_format in [ 48 | "RGB", 49 | "BGR", 50 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 51 | if image_format != self.model.image_format: 52 | image = image[..., ::-1] 53 | 54 | # Transform the image to the form expected by the model 55 | input_image = self.transform.apply_image(image) 56 | input_image_torch = torch.as_tensor(input_image, device=self.device) 57 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 58 | 59 | self.set_torch_image(input_image_torch, image.shape[:2]) 60 | 61 | @torch.no_grad() 62 | def set_torch_image( 63 | self, 64 | transformed_image: torch.Tensor, 65 | original_image_size: Tuple[int, ...], 66 | ) -> None: 67 | """ 68 | Calculates the image embeddings for the provided image, allowing 69 | masks to be predicted with the 'predict' method. 
Expects the input 70 | image to be already transformed to the format expected by the model. 71 | 72 | Arguments: 73 | transformed_image (torch.Tensor): The input image, with shape 74 | 1x3xHxW, which has been transformed with ResizeLongestSide. 75 | original_image_size (tuple(int, int)): The size of the image 76 | before transformation, in (H, W) format. 77 | """ 78 | assert ( 79 | len(transformed_image.shape) == 4 80 | and transformed_image.shape[1] == 3 81 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 82 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 83 | self.reset_image() 84 | 85 | self.original_size = original_image_size 86 | self.input_size = tuple(transformed_image.shape[-2:]) 87 | input_image = self.model.preprocess(transformed_image) 88 | self.features = self.model.image_encoder(input_image) 89 | self.is_image_set = True 90 | 91 | def predict( 92 | self, 93 | point_coords: Optional[np.ndarray] = None, 94 | point_labels: Optional[np.ndarray] = None, 95 | box: Optional[np.ndarray] = None, 96 | mask_input: Optional[np.ndarray] = None, 97 | multimask_output: bool = True, 98 | return_logits: bool = False, 99 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 100 | """ 101 | Predict masks for the given input prompts, using the currently set image. 102 | 103 | Arguments: 104 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 105 | model. Each point is in (X,Y) in pixels. 106 | point_labels (np.ndarray or None): A length N array of labels for the 107 | point prompts. 1 indicates a foreground point and 0 indicates a 108 | background point. 109 | box (np.ndarray or None): A length 4 array given a box prompt to the 110 | model, in XYXY format. 111 | mask_input (np.ndarray): A low resolution mask input to the model, typically 112 | coming from a previous prediction iteration. Has form 1xHxW, where 113 | for SAM, H=W=256. 
114 | multimask_output (bool): If true, the model will return three masks. 115 | For ambiguous input prompts (such as a single click), this will often 116 | produce better masks than a single prediction. If only a single 117 | mask is needed, the model's predicted quality score can be used 118 | to select the best mask. For non-ambiguous prompts, such as multiple 119 | input prompts, multimask_output=False can give better results. 120 | return_logits (bool): If true, returns un-thresholded masks logits 121 | instead of a binary mask. 122 | 123 | Returns: 124 | (np.ndarray): The output masks in CxHxW format, where C is the 125 | number of masks, and (H, W) is the original image size. 126 | (np.ndarray): An array of length C containing the model's 127 | predictions for the quality of each mask. 128 | (np.ndarray): An array of shape CxHxW, where C is the number 129 | of masks and H=W=256. These low resolution logits can be passed to 130 | a subsequent iteration as mask input. 131 | """ 132 | if not self.is_image_set: 133 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 134 | 135 | # Transform input prompts 136 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 137 | if point_coords is not None: 138 | assert point_labels is not None, "point_labels must be supplied if point_coords is supplied." 
139 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 140 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 141 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 142 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 143 | if box is not None: 144 | box = self.transform.apply_boxes(box, self.original_size) 145 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 146 | box_torch = box_torch[None, :] 147 | if mask_input is not None: 148 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 149 | mask_input_torch = mask_input_torch[None, :, :, :] 150 | 151 | masks, iou_predictions, low_res_masks = self.predict_torch( 152 | coords_torch, 153 | labels_torch, 154 | box_torch, 155 | mask_input_torch, 156 | multimask_output, 157 | return_logits=return_logits, 158 | ) 159 | 160 | masks = masks[0].detach().cpu().numpy() 161 | iou_predictions = iou_predictions[0].detach().cpu().numpy() 162 | low_res_masks = low_res_masks[0].detach().cpu().numpy() 163 | return masks, iou_predictions, low_res_masks 164 | 165 | @torch.no_grad() 166 | def predict_torch( 167 | self, 168 | point_coords: Optional[torch.Tensor], 169 | point_labels: Optional[torch.Tensor], 170 | boxes: Optional[torch.Tensor] = None, 171 | mask_input: Optional[torch.Tensor] = None, 172 | multimask_output: bool = True, 173 | return_logits: bool = False, 174 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 175 | """ 176 | Predict masks for the given input prompts, using the currently set image. 177 | Input prompts are batched torch tensors and are expected to already be 178 | transformed to the input frame using ResizeLongestSide. 179 | 180 | Arguments: 181 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 182 | model. Each point is in (X,Y) in pixels. 
183 | point_labels (torch.Tensor or None): A BxN array of labels for the 184 | point prompts. 1 indicates a foreground point and 0 indicates a 185 | background point. 186 | boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the 187 | model, in XYXY format. 188 | mask_input (torch.Tensor or None): A low resolution mask input to the model, typically 189 | coming from a previous prediction iteration. Has form Bx1xHxW, where 190 | for SAM, H=W=256. Masks returned by a previous iteration of the 191 | predict method do not need further transformation. 192 | multimask_output (bool): If true, the model will return three masks. 193 | For ambiguous input prompts (such as a single click), this will often 194 | produce better masks than a single prediction. If only a single 195 | mask is needed, the model's predicted quality score can be used 196 | to select the best mask. For non-ambiguous prompts, such as multiple 197 | input prompts, multimask_output=False can give better results. 198 | return_logits (bool): If true, returns un-thresholded mask logits 199 | instead of a binary mask. 200 | 201 | Returns: 202 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 203 | number of masks, and (H, W) is the original image size. 204 | (torch.Tensor): An array of shape BxC containing the model's 205 | predictions for the quality of each mask. 206 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 207 | of masks and H=W=256. These low res logits can be passed to 208 | a subsequent iteration as mask input. 209 | """ 210 | if not self.is_image_set: 211 | raise RuntimeError("An image must be set with .set_image(...)
before mask prediction.") 212 | 213 | if point_coords is not None: 214 | points = (point_coords, point_labels) 215 | else: 216 | points = None 217 | 218 | # Embed prompts 219 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 220 | points=points, 221 | boxes=boxes, 222 | masks=mask_input, 223 | ) 224 | 225 | # Predict masks 226 | low_res_masks, iou_predictions = self.model.mask_decoder( 227 | image_embeddings=self.features, 228 | image_pe=self.model.prompt_encoder.get_dense_pe(), 229 | sparse_prompt_embeddings=sparse_embeddings, 230 | dense_prompt_embeddings=dense_embeddings, 231 | multimask_output=multimask_output, 232 | ) 233 | 234 | # Upscale the masks to the original image resolution 235 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 236 | 237 | if not return_logits: 238 | masks = masks > self.model.mask_threshold 239 | 240 | return masks, iou_predictions, low_res_masks 241 | 242 | def get_image_embedding(self) -> torch.Tensor: 243 | """ 244 | Returns the image embeddings for the currently set image, with 245 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 246 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 247 | """ 248 | if not self.is_image_set: 249 | raise RuntimeError("An image must be set with .set_image(...) to generate an embedding.") 250 | assert self.features is not None, "Features must exist if an image has been set." 
251 | return self.features 252 | 253 | @property 254 | def device(self) -> torch.device: 255 | return self.model.device 256 | 257 | def reset_image(self) -> None: 258 | """Resets the currently set image.""" 259 | self.is_image_set = False 260 | self.features = None 261 | self.orig_h = None 262 | self.orig_w = None 263 | self.input_h = None 264 | self.input_w = None 265 | -------------------------------------------------------------------------------- /metaseg/mask_predictor.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from tqdm import tqdm 7 | 8 | from metaseg.generator.automatic_mask_generator import SamAutomaticMaskGenerator 9 | from metaseg.generator.predictor import SamPredictor 10 | from metaseg.generator.build_sam import sam_model_registry 11 | 12 | from metaseg.utils import ( 13 | download_model, 14 | load_box, 15 | load_image, 16 | load_mask, 17 | load_video, 18 | multi_boxes, 19 | save_image, 20 | show_image, 21 | ) 22 | 23 | 24 | class SegAutoMaskPredictor: 25 | def __init__(self): 26 | self.model = None 27 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 28 | 29 | def load_model(self, model_type): 30 | if self.model is None: 31 | self.model_path = download_model(model_type) 32 | self.model = sam_model_registry[model_type](checkpoint=self.model_path) 33 | self.model.to(device=self.device) 34 | 35 | return self.model 36 | 37 | def image_predict( 38 | self, 39 | source, 40 | model_type, 41 | points_per_side, 42 | points_per_batch, 43 | min_area, 44 | output_path="output.png", 45 | show=False, 46 | save=False, 47 | ): 48 | read_image = load_image(source) 49 | model = self.load_model(model_type) 50 | mask_generator = SamAutomaticMaskGenerator( 51 | model, points_per_side=points_per_side, points_per_batch=points_per_batch, min_mask_region_area=min_area 52 | ) 53 | 54 | masks = 
mask_generator.generate(read_image) 55 | 56 | sorted_anns = sorted(masks, key=(lambda x: x["area"]), reverse=True) 57 | mask_image = np.zeros((read_image.shape[0], read_image.shape[1], 3), dtype=np.uint8) 58 | colors = np.random.randint(0, 255, size=(256, 3), dtype=np.uint8) 59 | for i, ann in enumerate(sorted_anns): 60 | m = ann["segmentation"] 61 | img = np.zeros((m.shape[0], m.shape[1], 3), dtype=np.uint8) 62 | color = colors[i % 256] 63 | # Fill all three channels with this mask's color 64 | img[:, :, 0] = color[0] 65 | img[:, :, 1] = color[1] 66 | img[:, :, 2] = color[2] 67 | img = cv2.bitwise_and(img, img, mask=m.astype(np.uint8)) 68 | img = cv2.addWeighted(img, 0.35, np.zeros_like(img), 0.65, 0) 69 | mask_image = cv2.add(mask_image, img) 70 | 71 | combined_mask = cv2.add(read_image, mask_image) 72 | self.combined_mask = combined_mask 73 | if show: 74 | show_image(combined_mask) 75 | 76 | if save: 77 | save_image(output_path=output_path, output_image=combined_mask) 78 | 79 | return masks 80 | 81 | def video_predict(self, source, model_type, points_per_side, points_per_batch, min_area, output_path="output.mp4"): 82 | cap, out = load_video(source, output_path) 83 | length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 84 | colors = np.random.randint(0, 255, size=(256, 3), dtype=np.uint8) 85 | 86 | for _ in tqdm(range(length)): 87 | ret, frame = cap.read() 88 | if not ret: 89 | break 90 | 91 | model = self.load_model(model_type) 92 | mask_generator = SamAutomaticMaskGenerator( 93 | model, points_per_side=points_per_side, points_per_batch=points_per_batch, min_mask_region_area=min_area 94 | ) 95 | masks = mask_generator.generate(frame) 96 | 97 | if len(masks) == 0: 98 | continue 99 | 100 | sorted_anns = sorted(masks, key=(lambda x: x["area"]), reverse=True) 101 | mask_image = np.zeros( 102 | (masks[0]["segmentation"].shape[0], masks[0]["segmentation"].shape[1], 3), dtype=np.uint8 103 | ) 104 | 105 | for i, ann in enumerate(sorted_anns): 106 | m = ann["segmentation"] 107 |
color = colors[i % 256] 108 | img = np.zeros((m.shape[0], m.shape[1], 3), dtype=np.uint8) 109 | img[:, :, 0] = color[0] 110 | img[:, :, 1] = color[1] 111 | img[:, :, 2] = color[2] 112 | img = cv2.bitwise_and(img, img, mask=m.astype(np.uint8)) 113 | img = cv2.addWeighted(img, 0.35, np.zeros_like(img), 0.65, 0) 114 | mask_image = cv2.add(mask_image, img) 115 | 116 | combined_mask = cv2.add(frame, mask_image) 117 | out.write(combined_mask) 118 | 119 | out.release() 120 | cap.release() 121 | cv2.destroyAllWindows() 122 | 123 | return output_path 124 | 125 | 126 | class SegManualMaskPredictor: 127 | def __init__(self): 128 | self.model = None 129 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 130 | 131 | def load_model(self, model_type): 132 | if self.model is None: 133 | self.model_path = download_model(model_type) 134 | self.model = sam_model_registry[model_type](checkpoint=self.model_path) 135 | self.model.to(device=self.device) 136 | 137 | return self.model 138 | 139 | def image_predict( 140 | self, 141 | source, 142 | model_type, 143 | input_box=None, 144 | input_point=None, 145 | input_label=None, 146 | multimask_output=False, 147 | output_path="output.png", 148 | random_color=False, 149 | show=False, 150 | save=False, 151 | ): 152 | image = load_image(source) 153 | model = self.load_model(model_type) 154 | predictor = SamPredictor(model) 155 | predictor.set_image(image) 156 | 157 | if isinstance(input_box[0], list): 158 | input_boxes, new_boxes = multi_boxes(input_box, predictor, image) 159 | 160 | masks, _, _ = predictor.predict_torch( 161 | point_coords=None, 162 | point_labels=None, 163 | boxes=new_boxes, 164 | multimask_output=False, 165 | ) 166 | for mask in masks: 167 | mask_image = load_mask(mask.cpu().numpy(), random_color) 168 | 169 | for box in input_boxes: 170 | image = load_box(box.cpu().numpy(), image) 171 | 172 | elif isinstance(input_box[0], int): 173 | input_boxes = np.array(input_box)[None, :] 174 | 175 | masks, _, _ = predictor.predict(
176 | point_coords=input_point, 177 | point_labels=input_label, 178 | box=input_boxes, 179 | multimask_output=multimask_output, 180 | ) 181 | mask_image = load_mask(masks, random_color) 182 | image = load_box(input_box, image) 183 | 184 | combined_mask = cv2.add(image, mask_image) 185 | if save: 186 | save_image(output_path=output_path, output_image=combined_mask) 187 | 188 | if show: 189 | show_image(combined_mask) 190 | 191 | return masks 192 | 193 | def video_predict( 194 | self, 195 | source, 196 | model_type, 197 | input_box=None, 198 | input_point=None, 199 | input_label=None, 200 | multimask_output=False, 201 | output_path="output.mp4", 202 | random_color=False, 203 | ): 204 | cap, out = load_video(source, output_path) 205 | length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 206 | 207 | for _ in tqdm(range(length)): 208 | ret, frame = cap.read() 209 | if not ret: 210 | break 211 | 212 | model = self.load_model(model_type) 213 | predictor = SamPredictor(model) 214 | predictor.set_image(frame) 215 | 216 | if isinstance(input_box[0], list): 217 | input_boxes, new_boxes = multi_boxes(input_box, predictor, frame) 218 | 219 | masks, _, _ = predictor.predict_torch( 220 | point_coords=None, 221 | point_labels=None, 222 | boxes=new_boxes, 223 | multimask_output=False, 224 | ) 225 | for mask in masks: 226 | mask_image = load_mask(mask.cpu().numpy(), random_color) 227 | 228 | for box in input_boxes: 229 | frame = load_box(box.cpu().numpy(), frame) 230 | 231 | elif isinstance(input_box[0], int): 232 | input_boxes = np.array(input_box)[None, :] 233 | 234 | masks, _, _ = predictor.predict( 235 | point_coords=input_point, 236 | point_labels=input_label, 237 | box=input_boxes, 238 | multimask_output=multimask_output, 239 | ) 240 | mask_image = load_mask(masks, random_color) 241 | frame = load_box(input_box, frame) 242 | 243 | combined_mask = cv2.add(frame, mask_image) 244 | out.write(combined_mask) 245 | 246 | out.release() 247 | cap.release() 248 | cv2.destroyAllWindows() 249 |
return output_path 250 | -------------------------------------------------------------------------------- /metaseg/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from metaseg.modeling.image_encoder import ImageEncoderViT 8 | from metaseg.modeling.mask_decoder import MaskDecoder 9 | from metaseg.modeling.prompt_encoder import PromptEncoder 10 | from metaseg.modeling.sam import Sam 11 | from metaseg.modeling.transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /metaseg/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
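The per-mask coloring in `SegAutoMaskPredictor.image_predict` and `video_predict` (metaseg/mask_predictor.py, above) can be expressed without OpenCV. The sketch below is a hypothetical NumPy-only helper, `overlay_masks`, which assumes the masks are plain boolean HxW arrays rather than the `{"segmentation": ..., "area": ...}` dicts returned by the generator. It sorts masks largest first, scales each color by 0.35 (the `cv2.addWeighted` step against a zero image), and accumulates with clipping at 255 (the saturating `cv2.add`). Rounding may differ slightly from OpenCV's, so treat it as an illustration of the logic rather than a drop-in replacement.

```python
import numpy as np


def overlay_masks(image: np.ndarray, masks: list, alpha: float = 0.35, seed: int = 0) -> np.ndarray:
    """Hypothetical helper: tint boolean HxW masks onto an HxWx3 uint8 image."""
    rng = np.random.default_rng(seed)
    # One random color per mask slot, like the 256-entry palette in mask_predictor.py.
    colors = rng.integers(0, 255, size=(256, 3), dtype=np.uint8)
    # Largest masks first, matching sorted(..., key=area, reverse=True).
    ordered = sorted(masks, key=lambda m: int(m.sum()), reverse=True)

    # Accumulate in a wide integer type, then clip once. Every addend is
    # non-negative, so clipping at the end equals saturating after each addition.
    tint = np.zeros(image.shape, dtype=np.int32)
    for i, m in enumerate(ordered):
        scaled = np.round(colors[i % 256].astype(np.float64) * alpha).astype(np.int32)
        tint[m] += scaled  # broadcast the RGB triple over the masked pixels

    return np.clip(image.astype(np.int32) + tint, 0, 255).astype(np.uint8)
```

Pixels outside every mask keep their original value, and overlapping masks simply add their tints until the channel saturates at 255.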
6 | 7 | from typing import Type 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /metaseg/modeling/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
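`LayerNorm2d` in metaseg/modeling/common.py (above) normalizes an NCHW tensor over its channel dimension, independently at every spatial position. One way to sanity-check that reading is to compare it with `torch.nn.LayerNorm(C)` applied to the same tensor permuted to channels-last. The snippet copies the class so it runs standalone; the helper name `matches_channels_last_layernorm` and the tensor sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn


class LayerNorm2d(nn.Module):
    # Copy of the class in common.py so this snippet is self-contained.
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)  # per-pixel mean over channels
        s = (x - u).pow(2).mean(1, keepdim=True)  # biased variance over channels
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight[:, None, None] * x + self.bias[:, None, None]


def matches_channels_last_layernorm(n: int = 2, c: int = 8, h: int = 5, w: int = 7, eps: float = 1e-6) -> bool:
    """Hypothetical check: LayerNorm2d on NCHW == nn.LayerNorm on NHWC."""
    torch.manual_seed(0)
    x = torch.randn(n, c, h, w)
    ours = LayerNorm2d(c, eps=eps)
    ref = nn.LayerNorm(c, eps=eps)  # default affine init (ones/zeros) matches LayerNorm2d
    y_ours = ours(x)
    y_ref = ref(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    return torch.allclose(y_ours, y_ref, atol=1e-5)
```

Keeping the data in NCHW avoids two permutes around every normalization inside the convolutional neck and mask-downscaling stacks.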
6 | 7 | from typing import Optional, Tuple, Type 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from metaseg.modeling.common import LayerNorm2d, MLPBlock 14 | 15 | 16 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 17 | class ImageEncoderViT(nn.Module): 18 | def __init__( 19 | self, 20 | img_size: int = 1024, 21 | patch_size: int = 16, 22 | in_chans: int = 3, 23 | embed_dim: int = 768, 24 | depth: int = 12, 25 | num_heads: int = 12, 26 | mlp_ratio: float = 4.0, 27 | out_chans: int = 256, 28 | qkv_bias: bool = True, 29 | norm_layer: Type[nn.Module] = nn.LayerNorm, 30 | act_layer: Type[nn.Module] = nn.GELU, 31 | use_abs_pos: bool = True, 32 | use_rel_pos: bool = False, 33 | rel_pos_zero_init: bool = True, 34 | window_size: int = 0, 35 | global_attn_indexes: Tuple[int, ...] = (), 36 | ) -> None: 37 | """ 38 | Args: 39 | img_size (int): Input image size. 40 | patch_size (int): Patch size. 41 | in_chans (int): Number of input image channels. 42 | embed_dim (int): Patch embedding dimension. 43 | depth (int): Depth of ViT. 44 | num_heads (int): Number of attention heads in each ViT block. 45 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 46 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 47 | norm_layer (nn.Module): Normalization layer. 48 | act_layer (nn.Module): Activation layer. 49 | use_abs_pos (bool): If True, use absolute positional embeddings. 50 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 51 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 52 | window_size (int): Window size for window attention blocks. 53 | global_attn_indexes (list): Indexes for blocks using global attention. 
54 | """ 55 | super().__init__() 56 | self.img_size = img_size 57 | 58 | self.patch_embed = PatchEmbed( 59 | kernel_size=(patch_size, patch_size), 60 | stride=(patch_size, patch_size), 61 | in_chans=in_chans, 62 | embed_dim=embed_dim, 63 | ) 64 | 65 | self.pos_embed: Optional[nn.Parameter] = None 66 | if use_abs_pos: 67 | # Initialize absolute positional embedding with pretrain image size. 68 | self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)) 69 | 70 | self.blocks = nn.ModuleList() 71 | for i in range(depth): 72 | block = Block( 73 | dim=embed_dim, 74 | num_heads=num_heads, 75 | mlp_ratio=mlp_ratio, 76 | qkv_bias=qkv_bias, 77 | norm_layer=norm_layer, 78 | act_layer=act_layer, 79 | use_rel_pos=use_rel_pos, 80 | rel_pos_zero_init=rel_pos_zero_init, 81 | window_size=window_size if i not in global_attn_indexes else 0, 82 | input_size=(img_size // patch_size, img_size // patch_size), 83 | ) 84 | self.blocks.append(block) 85 | 86 | self.neck = nn.Sequential( 87 | nn.Conv2d( 88 | embed_dim, 89 | out_chans, 90 | kernel_size=1, 91 | bias=False, 92 | ), 93 | LayerNorm2d(out_chans), 94 | nn.Conv2d( 95 | out_chans, 96 | out_chans, 97 | kernel_size=3, 98 | padding=1, 99 | bias=False, 100 | ), 101 | LayerNorm2d(out_chans), 102 | ) 103 | 104 | def forward(self, x: torch.Tensor) -> torch.Tensor: 105 | x = self.patch_embed(x) 106 | if self.pos_embed is not None: 107 | x = x + self.pos_embed 108 | 109 | for blk in self.blocks: 110 | x = blk(x) 111 | 112 | x = self.neck(x.permute(0, 3, 1, 2)) 113 | 114 | return x 115 | 116 | 117 | class Block(nn.Module): 118 | """Transformer blocks with support of window attention and residual propagation blocks""" 119 | 120 | def __init__( 121 | self, 122 | dim: int, 123 | num_heads: int, 124 | mlp_ratio: float = 4.0, 125 | qkv_bias: bool = True, 126 | norm_layer: Type[nn.Module] = nn.LayerNorm, 127 | act_layer: Type[nn.Module] = nn.GELU, 128 | use_rel_pos: bool = False, 129 | 
rel_pos_zero_init: bool = True, 130 | window_size: int = 0, 131 | input_size: Optional[Tuple[int, int]] = None, 132 | ) -> None: 133 | """ 134 | Args: 135 | dim (int): Number of input channels. 136 | num_heads (int): Number of attention heads in each ViT block. 137 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 138 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 139 | norm_layer (nn.Module): Normalization layer. 140 | act_layer (nn.Module): Activation layer. 141 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 142 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 143 | window_size (int): Window size for window attention blocks. If it equals 0, then 144 | use global attention. 145 | input_size (tuple(int, int) or None): Input resolution for calculating the relative positional 146 | parameter size. 147 | """ 148 | super().__init__() 149 | self.norm1 = norm_layer(dim) 150 | self.attn = Attention( 151 | dim, 152 | num_heads=num_heads, 153 | qkv_bias=qkv_bias, 154 | use_rel_pos=use_rel_pos, 155 | rel_pos_zero_init=rel_pos_zero_init, 156 | input_size=input_size if window_size == 0 else (window_size, window_size), 157 | ) 158 | 159 | self.norm2 = norm_layer(dim) 160 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 161 | 162 | self.window_size = window_size 163 | 164 | def forward(self, x: torch.Tensor) -> torch.Tensor: 165 | shortcut = x 166 | x = self.norm1(x) 167 | # Window partition 168 | if self.window_size > 0: 169 | H, W = x.shape[1], x.shape[2] 170 | x, pad_hw = window_partition(x, self.window_size) 171 | 172 | x = self.attn(x) 173 | # Reverse window partition 174 | if self.window_size > 0: 175 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 176 | 177 | x = shortcut + x 178 | x = x + self.mlp(self.norm2(x)) 179 | 180 | return x 181 | 182 | 183 | class Attention(nn.Module): 184 | """Multi-head Attention block with
relative position embeddings.""" 185 | 186 | def __init__( 187 | self, 188 | dim: int, 189 | num_heads: int = 8, 190 | qkv_bias: bool = True, 191 | use_rel_pos: bool = False, 192 | rel_pos_zero_init: bool = True, 193 | input_size: Optional[Tuple[int, int]] = None, 194 | ) -> None: 195 | """ 196 | Args: 197 | dim (int): Number of input channels. 198 | num_heads (int): Number of attention heads. 199 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 200 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 201 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 202 | input_size (tuple(int, int) or None): Input resolution for calculating the relative positional 203 | parameter size. 204 | """ 205 | super().__init__() 206 | self.num_heads = num_heads 207 | head_dim = dim // num_heads 208 | self.scale = head_dim ** -0.5 209 | 210 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 211 | self.proj = nn.Linear(dim, dim) 212 | 213 | self.use_rel_pos = use_rel_pos 214 | if self.use_rel_pos: 215 | assert input_size is not None, "Input size must be provided if using relative positional encoding."
216 | # initialize relative positional embeddings 217 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 218 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 219 | 220 | def forward(self, x: torch.Tensor) -> torch.Tensor: 221 | B, H, W, _ = x.shape 222 | # qkv with shape (3, B, nHead, H * W, C) 223 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 224 | # q, k, v with shape (B * nHead, H * W, C) 225 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 226 | 227 | attn = (q * self.scale) @ k.transpose(-2, -1) 228 | 229 | if self.use_rel_pos: 230 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 231 | 232 | attn = attn.softmax(dim=-1) 233 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 234 | x = self.proj(x) 235 | 236 | return x 237 | 238 | 239 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 240 | """ 241 | Partition into non-overlapping windows with padding if needed. 242 | Args: 243 | x (tensor): input tokens with [B, H, W, C]. 244 | window_size (int): window size. 245 | 246 | Returns: 247 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
248 | (Hp, Wp): padded height and width before partition 249 | """ 250 | B, H, W, C = x.shape 251 | 252 | pad_h = (window_size - H % window_size) % window_size 253 | pad_w = (window_size - W % window_size) % window_size 254 | if pad_h > 0 or pad_w > 0: 255 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 256 | Hp, Wp = H + pad_h, W + pad_w 257 | 258 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 259 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 260 | return windows, (Hp, Wp) 261 | 262 | 263 | def window_unpartition( 264 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 265 | ) -> torch.Tensor: 266 | """ 267 | Unpartition windows into original sequences and remove padding. 268 | Args: 269 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 270 | window_size (int): window size. 271 | pad_hw (Tuple): padded height and width (Hp, Wp). 272 | hw (Tuple): original height and width (H, W) before padding. 273 | 274 | Returns: 275 | x: unpartitioned sequences with [B, H, W, C]. 276 | """ 277 | Hp, Wp = pad_hw 278 | H, W = hw 279 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 280 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 281 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 282 | 283 | if Hp > H or Wp > W: 284 | x = x[:, :H, :W, :].contiguous() 285 | return x 286 | 287 | 288 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 289 | """ 290 | Get relative positional embeddings according to the relative positions of 291 | query and key sizes. 292 | Args: 293 | q_size (int): size of query q. 294 | k_size (int): size of key k. 295 | rel_pos (Tensor): relative position embeddings (L, C). 296 | 297 | Returns: 298 | Extracted positional embeddings according to relative positions.
299 | """ 300 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 301 | # Interpolate rel pos if needed. 302 | if rel_pos.shape[0] != max_rel_dist: 303 | # Interpolate rel pos. 304 | rel_pos_resized = F.interpolate( 305 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 306 | size=max_rel_dist, 307 | mode="linear", 308 | ) 309 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 310 | else: 311 | rel_pos_resized = rel_pos 312 | 313 | # Scale the coords with short length if shapes for q and k are different. 314 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 315 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 316 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 317 | 318 | return rel_pos_resized[relative_coords.long()] 319 | 320 | 321 | def add_decomposed_rel_pos( 322 | attn: torch.Tensor, 323 | q: torch.Tensor, 324 | rel_pos_h: torch.Tensor, 325 | rel_pos_w: torch.Tensor, 326 | q_size: Tuple[int, int], 327 | k_size: Tuple[int, int], 328 | ) -> torch.Tensor: 329 | """ 330 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 331 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 332 | Args: 333 | attn (Tensor): attention map. 334 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 335 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 336 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 337 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 338 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 339 | 340 | Returns: 341 | attn (Tensor): attention map with added relative positional embeddings. 
342 | """ 343 | q_h, q_w = q_size 344 | k_h, k_w = k_size 345 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 346 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 347 | 348 | B, _, dim = q.shape 349 | r_q = q.reshape(B, q_h, q_w, dim) 350 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 351 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 352 | 353 | attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view( 354 | B, q_h * q_w, k_h * k_w 355 | ) 356 | 357 | return attn 358 | 359 | 360 | class PatchEmbed(nn.Module): 361 | """ 362 | Image to Patch Embedding. 363 | """ 364 | 365 | def __init__( 366 | self, 367 | kernel_size: Tuple[int, int] = (16, 16), 368 | stride: Tuple[int, int] = (16, 16), 369 | padding: Tuple[int, int] = (0, 0), 370 | in_chans: int = 3, 371 | embed_dim: int = 768, 372 | ) -> None: 373 | """ 374 | Args: 375 | kernel_size (Tuple): kernel size of the projection layer. 376 | stride (Tuple): stride of the projection layer. 377 | padding (Tuple): padding size of the projection layer. 378 | in_chans (int): Number of input image channels. 379 | embed_dim (int): Patch embedding dimension. 380 | """ 381 | super().__init__() 382 | 383 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) 384 | 385 | def forward(self, x: torch.Tensor) -> torch.Tensor: 386 | x = self.proj(x) 387 | # B C H W -> B H W C 388 | x = x.permute(0, 2, 3, 1) 389 | return x 390 | -------------------------------------------------------------------------------- /metaseg/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
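`window_partition` and `window_unpartition` in metaseg/modeling/image_encoder.py (above) are meant to be inverses: the first pads H and W up to multiples of `window_size` and splits the feature map into non-overlapping windows, and the second stitches the windows back together and crops the padding. The sketch below copies both functions and checks the round trip on a tensor whose sides are not multiples of the window size; the sizes used in the check are arbitrary illustrative choices.

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    # Copy of the image_encoder.py version: [B, H, W, C] -> [B * nW, ws, ws, C].
    B, H, W, C = x.shape
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        # Pad order is (C_left, C_right, W_left, W_right, H_top, H_bottom).
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w
    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    # Copy of the image_encoder.py version: [B * nW, ws, ws, C] -> [B, H, W, C].
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()  # drop the zero padding
    return x
```

For example, with H=10, W=13 and `window_size=4`, the map is padded to 12x16, which yields 3x4 = 12 windows per image; unpartitioning them and cropping back to (10, 13) should reproduce the input exactly, since only reshapes, permutes and copies are involved.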
6 | 7 | from typing import List, Tuple, Type 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | 13 | from metaseg.modeling.common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)] 62 | ) 63 | 64 |
self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth) 65 | 66 | def forward( 67 | self, 68 | image_embeddings: torch.Tensor, 69 | image_pe: torch.Tensor, 70 | sparse_prompt_embeddings: torch.Tensor, 71 | dense_prompt_embeddings: torch.Tensor, 72 | multimask_output: bool, 73 | ) -> Tuple[torch.Tensor, torch.Tensor]: 74 | """ 75 | Predict masks given image and prompt embeddings. 76 | 77 | Arguments: 78 | image_embeddings (torch.Tensor): the embeddings from the image encoder 79 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 80 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 81 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 82 | multimask_output (bool): Whether to return multiple masks or a single 83 | mask. 84 | 85 | Returns: 86 | torch.Tensor: batched predicted masks 87 | torch.Tensor: batched predictions of mask quality 88 | """ 89 | masks, iou_pred = self.predict_masks( 90 | image_embeddings=image_embeddings, 91 | image_pe=image_pe, 92 | sparse_prompt_embeddings=sparse_prompt_embeddings, 93 | dense_prompt_embeddings=dense_prompt_embeddings, 94 | ) 95 | 96 | # Select the correct mask or masks for output 97 | if multimask_output: 98 | mask_slice = slice(1, None) 99 | else: 100 | mask_slice = slice(0, 1) 101 | masks = masks[:, mask_slice, :, :] 102 | iou_pred = iou_pred[:, mask_slice] 103 | 104 | # Prepare output 105 | return masks, iou_pred 106 | 107 | def predict_masks( 108 | self, 109 | image_embeddings: torch.Tensor, 110 | image_pe: torch.Tensor, 111 | sparse_prompt_embeddings: torch.Tensor, 112 | dense_prompt_embeddings: torch.Tensor, 113 | ) -> Tuple[torch.Tensor, torch.Tensor]: 114 | """Predicts masks.
See 'forward' for more details.""" 115 | # Concatenate output tokens 116 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 117 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 118 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 119 | 120 | # Expand per-image data in batch direction to be per-mask 121 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 122 | src = src + dense_prompt_embeddings 123 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 124 | b, c, h, w = src.shape 125 | 126 | # Run the transformer 127 | hs, src = self.transformer(src, pos_src, tokens) 128 | iou_token_out = hs[:, 0, :] 129 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 130 | 131 | # Upscale mask embeddings and predict masks using the mask tokens 132 | src = src.transpose(1, 2).view(b, c, h, w) 133 | upscaled_embedding = self.output_upscaling(src) 134 | hyper_in_list: List[torch.Tensor] = [] 135 | for i in range(self.num_mask_tokens): 136 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 137 | hyper_in = torch.stack(hyper_in_list, dim=1) 138 | b, c, h, w = upscaled_embedding.shape 139 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 140 | 141 | # Generate mask quality predictions 142 | iou_pred = self.iou_prediction_head(iou_token_out) 143 | 144 | return masks, iou_pred 145 | 146 | 147 | # Lightly adapted from 148 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 149 | class MLP(nn.Module): 150 | def __init__( 151 | self, 152 | input_dim: int, 153 | hidden_dim: int, 154 | output_dim: int, 155 | num_layers: int, 156 | sigmoid_output: bool = False, 157 | ) -> None: 158 | super().__init__() 159 | self.num_layers = num_layers 160 | h = [hidden_dim] * (num_layers - 1) 161 | self.layers = 
nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) 162 | self.sigmoid_output = sigmoid_output 163 | 164 | def forward(self, x): 165 | for i, layer in enumerate(self.layers): 166 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 167 | if self.sigmoid_output: 168 | x = torch.sigmoid(x) 169 | return x 170 | -------------------------------------------------------------------------------- /metaseg/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Optional, Tuple, Type 8 | 9 | import numpy as np 10 | import torch 11 | from torch import nn 12 | 13 | from metaseg.modeling.common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (tuple(int, int)): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks.
38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 
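Example (a minimal NumPy sketch, not part of metaseg; `dense_pe`, `point_pe` and the 8-feature matrix are hypothetical stand-ins): this mirrors the random Fourier-feature encoding of PositionEmbeddingRandom in this file. It checks that a point prompt shifted to its pixel center gets the same encoding as the matching cell of the dense grid when the image and grid sizes coincide.

```python
import numpy as np

# Fixed random Gaussian matrix mapping 2-D coordinates to random frequencies.
rng = np.random.default_rng(0)
num_pos_feats = 8
gaussian_matrix = rng.standard_normal((2, num_pos_feats))  # scale = 1.0


def pe_encoding(coords):
    # coords in [0, 1]^2 with shape (..., 2) -> (..., 2 * num_pos_feats)
    coords = 2 * coords - 1                  # map [0, 1] -> [-1, 1]
    coords = coords @ gaussian_matrix        # project onto random frequencies
    coords = 2 * np.pi * coords
    return np.concatenate([np.sin(coords), np.cos(coords)], axis=-1)


def dense_pe(h, w):
    # Encode the center of every cell of an h x w grid; returns C x H x W.
    ys = (np.arange(h) + 0.5) / h            # same values as cumsum(ones) - 0.5, divided by h
    xs = (np.arange(w) + 0.5) / w
    x_grid, y_grid = np.meshgrid(xs, ys)     # both H x W, x varies along columns
    pe = pe_encoding(np.stack([x_grid, y_grid], axis=-1))
    return pe.transpose(2, 0, 1)


def point_pe(points_xy, image_hw):
    # Encode pixel coordinates (N x 2, x then y) after the +0.5 pixel-center shift.
    h, w = image_hw
    coords = (np.asarray(points_xy, dtype=float) + 0.5) / np.array([w, h])
    return pe_encoding(coords)


grid = dense_pe(4, 6)
print(grid.shape)  # (16, 4, 6)
# A point at pixel (x=2, y=1) lands exactly on grid cell [:, 1, 2].
print(np.allclose(point_pe([[2, 1]], (4, 6))[0], grid[:, 1, 2]))  # True
```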
66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 
115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 
148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor: 208 | """Positionally encode points that are not normalized to [0,1].""" 209 | coords = coords_input.clone() 210 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 211 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 212 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 213 | -------------------------------------------------------------------------------- /metaseg/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import Any, Dict, List, Tuple 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | 13 | from metaseg.modeling.image_encoder import ImageEncoderViT 14 | from metaseg.modeling.mask_decoder import MaskDecoder 15 | from metaseg.modeling.prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 
61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | a dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction.
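Note on 'masks' (a small NumPy sketch, not part of metaseg): forward binarizes the upscaled logits with `masks > self.mask_threshold`, where mask_threshold defaults to 0.0. Thresholding a logit at 0 is the same as thresholding its sigmoid probability at 0.5, which this sketch checks on a few made-up values.

```python
import numpy as np

mask_threshold = 0.0  # same default as Sam.mask_threshold

logits = np.array([[-3.0, -0.1, 0.0], [0.2, 1.5, 7.0]])
binary_from_logits = logits > mask_threshold
probs = 1.0 / (1.0 + np.exp(-logits))  # sigmoid
binary_from_probs = probs > 0.5

print(binary_from_logits.tolist())  # [[False, False, False], [True, True, True]]
print(np.array_equal(binary_from_logits, binary_from_probs))  # True
```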
96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings = self.image_encoder(input_images) 99 | 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | points=points, 108 | boxes=image_record.get("boxes", None), 109 | masks=image_record.get("mask_inputs", None), 110 | ) 111 | low_res_masks, iou_predictions = self.mask_decoder( 112 | image_embeddings=curr_embedding.unsqueeze(0), 113 | image_pe=self.prompt_encoder.get_dense_pe(), 114 | sparse_prompt_embeddings=sparse_embeddings, 115 | dense_prompt_embeddings=dense_embeddings, 116 | multimask_output=multimask_output, 117 | ) 118 | masks = self.postprocess_masks( 119 | low_res_masks, 120 | input_size=image_record["image"].shape[-2:], 121 | original_size=image_record["original_size"], 122 | ) 123 | masks = masks > self.mask_threshold 124 | outputs.append( 125 | { 126 | "masks": masks, 127 | "iou_predictions": iou_predictions, 128 | "low_res_logits": low_res_masks, 129 | } 130 | ) 131 | return outputs 132 | 133 | def postprocess_masks( 134 | self, 135 | masks: torch.Tensor, 136 | input_size: Tuple[int, ...], 137 | original_size: Tuple[int, ...], 138 | ) -> torch.Tensor: 139 | """ 140 | Remove padding and upscale masks to the original image size. 141 | 142 | Arguments: 143 | masks (torch.Tensor): Batched masks from the mask_decoder, 144 | in BxCxHxW format. 145 | input_size (tuple(int, int)): The size of the image input to the 146 | model, in (H, W) format. Used to remove padding. 147 | original_size (tuple(int, int)): The original size of the image 148 | before resizing for input to the model, in (H, W) format. 
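Example (a NumPy sketch under stated assumptions, not the metaseg implementation): the three steps of postprocess_masks are upscaling to the padded square img_size, cropping away the bottom/right padding added by preprocess, and resizing to original_size. The hypothetical helper `resize_nearest` swaps the bilinear F.interpolate for nearest-neighbour resizing so the result is easy to check by hand.

```python
import numpy as np


def resize_nearest(x, out_hw):
    # Nearest-neighbour resize of the last two axes; a stand-in for the
    # bilinear F.interpolate used by Sam.postprocess_masks.
    in_h, in_w = x.shape[-2:]
    out_h, out_w = out_hw
    rows = (np.arange(out_h) * in_h) // out_h
    cols = (np.arange(out_w) * in_w) // out_w
    return x[..., rows[:, None], cols[None, :]]


def postprocess_masks(masks, img_size, input_size, original_size):
    masks = resize_nearest(masks, (img_size, img_size))    # 1. upscale to padded square
    masks = masks[..., : input_size[0], : input_size[1]]  # 2. drop bottom/right padding
    return resize_nearest(masks, original_size)            # 3. back to original resolution


low_res = np.zeros((1, 1, 4, 4))
low_res[..., :2, :] = 1.0  # top half "on"
out = postprocess_masks(low_res, img_size=16, input_size=(12, 16), original_size=(6, 8))
print(out.shape)  # (1, 1, 6, 8)
```

Because the padded rows are cropped before the final resize, the "on" region covers the top 4 of 6 output rows rather than half of them.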
149 | 150 | Returns: 151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 152 | is given by original_size. 153 | """ 154 | masks = F.interpolate( 155 | masks, 156 | (self.image_encoder.img_size, self.image_encoder.img_size), 157 | mode="bilinear", 158 | align_corners=False, 159 | ) 160 | masks = masks[..., : input_size[0], : input_size[1]] 161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 162 | return masks 163 | 164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 165 | """Normalize pixel values and pad to a square input.""" 166 | # Normalize colors 167 | x = (x - self.pixel_mean) / self.pixel_std 168 | 169 | # Pad 170 | h, w = x.shape[-2:] 171 | padh = self.image_encoder.img_size - h 172 | padw = self.image_encoder.img_size - w 173 | x = F.pad(x, (0, padw, 0, padh)) 174 | return x 175 | -------------------------------------------------------------------------------- /metaseg/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Tuple, Type 9 | 10 | import torch 11 | from torch import Tensor, nn 12 | 13 | from metaseg.modeling.common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 
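Example (a minimal NumPy sketch, not part of metaseg): one multi-head cross-attention step in which the sparse tokens attend to a flattened image embedding, following the separate-heads / scaled dot product / softmax / recombine order of the Attention class in this file. The learned q/k/v/out projections and the internal downsampling are omitted; `multi_head_attention` is a hypothetical name and the sizes are arbitrary.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_attention(q, k, v, num_heads):
    # q: (N_q, C), k and v: (N_k, C); projections omitted for brevity.
    n_q, c = q.shape
    n_k = k.shape[0]
    d = c // num_heads
    # separate heads: (num_heads, N, d)
    qh = q.reshape(n_q, num_heads, d).transpose(1, 0, 2)
    kh = k.reshape(n_k, num_heads, d).transpose(1, 0, 2)
    vh = v.reshape(n_k, num_heads, d).transpose(1, 0, 2)
    attn = softmax(qh @ kh.transpose(0, 2, 1) / np.sqrt(d))  # (num_heads, N_q, N_k)
    out = attn @ vh                                           # (num_heads, N_q, d)
    # recombine heads: (N_q, C)
    return out.transpose(1, 0, 2).reshape(n_q, c), attn


rng = np.random.default_rng(0)
tokens = rng.standard_normal((7, 32))     # e.g. output tokens + prompt tokens
image = rng.standard_normal((8 * 8, 32))  # flattened 8x8 image embedding
out, attn = multi_head_attention(tokens, image, image, num_heads=4)
print(out.shape, attn.shape)  # (7, 32) (4, 7, 64)
```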
29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) 58 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 59 | 60 | def forward( 61 | self, 62 | image_embedding: Tensor, 63 | image_pe: Tensor, 64 | point_embedding: Tensor, 65 | ) -> Tuple[Tensor, Tensor]: 66 | """ 67 | Args: 68 | image_embedding (torch.Tensor): image to attend to. Should be shape 69 | B x embedding_dim x h x w for any h and w. 70 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 71 | have the same shape as image_embedding. 72 | point_embedding (torch.Tensor): the embedding to add to the query points. 73 | Must have shape B x N_points x embedding_dim for any N_points. 
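Example (a small NumPy sketch, not part of metaseg): forward first turns the B x C x H x W image embedding into B x HW x C image tokens via flatten(2).permute(0, 2, 1). Token n corresponds to pixel (y, x) = divmod(n, w), i.e. row-major order, and transposing back and reshaping restores the original layout.

```python
import numpy as np

b, c, h, w = 2, 3, 4, 5
image_embedding = np.arange(b * c * h * w).reshape(b, c, h, w)

# Equivalent of image_embedding.flatten(2).permute(0, 2, 1)
tokens = image_embedding.reshape(b, c, h * w).transpose(0, 2, 1)
print(tokens.shape)  # (2, 20, 3)

# Image token index n corresponds to pixel (y, x) = divmod(n, w).
y, x = divmod(13, w)
print(np.array_equal(tokens[1, 13], image_embedding[1, :, y, x]))  # True

# Inverse: B x HW x C -> B x C x H x W
restored = tokens.transpose(0, 2, 1).reshape(b, c, h, w)
```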
74 | 75 | Returns: 76 | torch.Tensor: the processed point_embedding 77 | torch.Tensor: the processed image_embedding 78 | """ 79 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 80 | bs, c, h, w = image_embedding.shape 81 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 82 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 83 | 84 | # Prepare queries 85 | queries = point_embedding 86 | keys = image_embedding 87 | 88 | # Apply transformer blocks and final layernorm 89 | for layer in self.layers: 90 | queries, keys = layer( 91 | queries=queries, 92 | keys=keys, 93 | query_pe=point_embedding, 94 | key_pe=image_pe, 95 | ) 96 | 97 | # Apply the final attention layer from the points to the image 98 | q = queries + point_embedding 99 | k = keys + image_pe 100 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 101 | queries = queries + attn_out 102 | queries = self.norm_final_attn(queries) 103 | 104 | return queries, keys 105 | 106 | 107 | class TwoWayAttentionBlock(nn.Module): 108 | def __init__( 109 | self, 110 | embedding_dim: int, 111 | num_heads: int, 112 | mlp_dim: int = 2048, 113 | activation: Type[nn.Module] = nn.ReLU, 114 | attention_downsample_rate: int = 2, 115 | skip_first_layer_pe: bool = False, 116 | ) -> None: 117 | """ 118 | A transformer block with four layers: (1) self-attention of sparse 119 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 120 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 121 | inputs.
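Example (a simplified NumPy sketch, not the metaseg implementation): the four steps of TwoWayAttentionBlock.forward in order, each followed by a residual add and a LayerNorm. For brevity it assumes single-head attention without learned projections, a LayerNorm without affine parameters, and a bias-free two-layer ReLU MLP with random weights; `two_way_block` and the sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)


def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def attention(q, k, v):
    # Single-head scaled dot-product attention; learned projections omitted.
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


def two_way_block(queries, keys, query_pe, key_pe, w1, w2):
    # (1) self-attention of the sparse tokens
    q = queries + query_pe
    queries = layer_norm(queries + attention(q, q, queries))
    # (2) tokens attend to the image
    queries = layer_norm(queries + attention(queries + query_pe, keys + key_pe, keys))
    # (3) MLP on the tokens
    queries = layer_norm(queries + np.maximum(queries @ w1, 0.0) @ w2)
    # (4) image attends to the tokens; only the keys are updated here
    keys = layer_norm(keys + attention(keys + key_pe, queries + query_pe, queries))
    return queries, keys


C = 16
queries = rng.standard_normal((5, C))
keys = rng.standard_normal((64, C))
query_pe = queries.copy()  # TwoWayTransformer passes the original point embedding as query_pe
key_pe = rng.standard_normal((64, C))
w1 = rng.standard_normal((C, 32)) * 0.1
w2 = rng.standard_normal((32, C)) * 0.1

new_q, new_k = two_way_block(queries, keys, query_pe, key_pe, w1, w2)
print(new_q.shape, new_k.shape)  # (5, 16) (64, 16)
```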
122 | 123 | Arguments: 124 | embedding_dim (int): the channel dimension of the embeddings 125 | num_heads (int): the number of heads in the attention layers 126 | mlp_dim (int): the hidden dimension of the mlp block 127 | activation (nn.Module): the activation of the mlp block 128 | skip_first_layer_pe (bool): skip the PE on the first layer 129 | """ 130 | super().__init__() 131 | self.self_attn = Attention(embedding_dim, num_heads) 132 | self.norm1 = nn.LayerNorm(embedding_dim) 133 | 134 | self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) 135 | self.norm2 = nn.LayerNorm(embedding_dim) 136 | 137 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 138 | self.norm3 = nn.LayerNorm(embedding_dim) 139 | 140 | self.norm4 = nn.LayerNorm(embedding_dim) 141 | self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) 142 | 143 | self.skip_first_layer_pe = skip_first_layer_pe 144 | 145 | def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]: 146 | # Self attention block 147 | if self.skip_first_layer_pe: 148 | queries = self.self_attn(q=queries, k=queries, v=queries) 149 | else: 150 | q = queries + query_pe 151 | attn_out = self.self_attn(q=q, k=q, v=queries) 152 | queries = queries + attn_out 153 | queries = self.norm1(queries) 154 | 155 | # Cross attention block, tokens attending to image embedding 156 | q = queries + query_pe 157 | k = keys + key_pe 158 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 159 | queries = queries + attn_out 160 | queries = self.norm2(queries) 161 | 162 | # MLP block 163 | mlp_out = self.mlp(queries) 164 | queries = queries + mlp_out 165 | queries = self.norm3(queries) 166 | 167 | # Cross attention block, image embedding attending to tokens 168 | q = queries + query_pe 169 | k = keys + key_pe 170 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) 
171 | keys = keys + attn_out 172 | keys = self.norm4(keys) 173 | 174 | return queries, keys 175 | 176 | 177 | class Attention(nn.Module): 178 | """ 179 | An attention layer that allows for downscaling the size of the embedding 180 | after projection to queries, keys, and values. 181 | """ 182 | 183 | def __init__( 184 | self, 185 | embedding_dim: int, 186 | num_heads: int, 187 | downsample_rate: int = 1, 188 | ) -> None: 189 | super().__init__() 190 | self.embedding_dim = embedding_dim 191 | self.internal_dim = embedding_dim // downsample_rate 192 | self.num_heads = num_heads 193 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 194 | 195 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 196 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 197 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 198 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 199 | 200 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 201 | b, n, c = x.shape 202 | x = x.reshape(b, n, num_heads, c // num_heads) 203 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 204 | 205 | def _recombine_heads(self, x: Tensor) -> Tensor: 206 | b, n_heads, n_tokens, c_per_head = x.shape 207 | x = x.transpose(1, 2) 208 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 209 | 210 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 211 | # Input projections 212 | q = self.q_proj(q) 213 | k = self.k_proj(k) 214 | v = self.v_proj(v) 215 | 216 | # Separate into heads 217 | q = self._separate_heads(q, self.num_heads) 218 | k = self._separate_heads(k, self.num_heads) 219 | v = self._separate_heads(v, self.num_heads) 220 | 221 | # Attention 222 | _, _, _, c_per_head = q.shape 223 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 224 | attn = attn / math.sqrt(c_per_head) 225 | attn = torch.softmax(attn, dim=-1) 226 | 227 | # Get output 228 | out = attn @ v 229 
| out = self._recombine_heads(out) 230 | out = self.out_proj(out) 231 | 232 | return out 233 | -------------------------------------------------------------------------------- /metaseg/sahi_predict.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | 7 | from metaseg import SamPredictor, sam_model_registry 8 | from metaseg.utils import download_model, load_image, multi_boxes, plt_load_box, plt_load_mask, load_image_vide 9 | 10 | 11 | def sahi_sliced_predict( 12 | image_path, 13 | detection_model_type, 14 | detection_model_path, 15 | conf_th, 16 | image_size, 17 | slice_height, 18 | slice_width, 19 | overlap_height_ratio, 20 | overlap_width_ratio, 21 | ): 22 | 23 | try: 24 | from sahi import AutoDetectionModel 25 | from sahi.predict import get_sliced_prediction 26 | except ImportError: 27 | raise ImportError("Please install SAHI library using 'pip install sahi'.") 28 | 29 | device = "cuda" if torch.cuda.is_available() else "cpu" 30 | 31 | detection_model = AutoDetectionModel.from_pretrained( 32 | image_size=image_size, 33 | model_type=detection_model_type, 34 | model_path=detection_model_path, 35 | confidence_threshold=conf_th, 36 | device=device, 37 | ) 38 | result = get_sliced_prediction( 39 | image_path, 40 | detection_model, 41 | slice_height=slice_height, 42 | slice_width=slice_width, 43 | overlap_height_ratio=overlap_height_ratio, 44 | overlap_width_ratio=overlap_width_ratio, 45 | ) 46 | 47 | 48 | output = result.object_prediction_list 49 | boxes = [] 50 | for i in output: 51 | boxes.append(i.bbox.to_xyxy()) 52 | 53 | return boxes 54 | 55 | 56 | class SahiAutoSegmentation: 57 | def __init__(self): 58 | self.model = None 59 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 60 | 61 | def load_model(self, model_type): 62 | if
self.model is None: 63 | self.model_path = download_model(model_type) 64 | self.model = sam_model_registry[model_type](checkpoint=self.model_path) 65 | self.model.to(device=self.device) 66 | 67 | return self.model 68 | 69 | def predict( 70 | self, 71 | source, 72 | model_type, 73 | input_box=None, 74 | input_point=None, 75 | input_label=None, 76 | multimask_output=False, 77 | random_color=False, 78 | show=False, 79 | save=False, 80 | ): 81 | 82 | # read_image = load_image(source) 83 | read_vide = load_image_vide(source) 84 | model = self.load_model(model_type) 85 | predictor = SamPredictor(model) 86 | predictor.set_image(read_vide) 87 | 88 | if isinstance(input_box[0], list): 89 | input_boxes, new_boxes = multi_boxes(input_box, predictor, read_vide) 90 | 91 | masks, _, _ = predictor.predict_torch( 92 | point_coords=None, 93 | point_labels=None, 94 | boxes=new_boxes, 95 | multimask_output=False, 96 | ) 97 | 98 | elif isinstance(input_box[0], int): 99 | input_boxes = np.array(input_box)[None, :] 100 | 101 | masks, _, _ = predictor.predict( 102 | point_coords=input_point, 103 | point_labels=input_label, 104 | box=input_boxes, 105 | multimask_output=multimask_output, 106 | ) 107 | else: 108 | raise ValueError("input_box must be a box [x0, y0, x1, y1] or a list of such boxes.") 109 | return masks 110 | -------------------------------------------------------------------------------- /metaseg/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from metaseg.utils.data_utils import * 8 | from metaseg.utils.file_utils import * 9 | -------------------------------------------------------------------------------- /metaseg/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved.
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from copy import deepcopy 9 | from itertools import product 10 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 11 | 12 | import numpy as np 13 | import torch 14 | 15 | 16 | class MaskData: 17 | """ 18 | A structure for storing masks and their related data in batched format. 19 | Implements basic filtering and concatenation. 20 | """ 21 | 22 | def __init__(self, **kwargs) -> None: 23 | for v in kwargs.values(): 24 | assert isinstance( 25 | v, (list, np.ndarray, torch.Tensor) 26 | ), "MaskData only supports list, numpy arrays, and torch tensors." 27 | self._stats = dict(**kwargs) 28 | 29 | def __setitem__(self, key: str, item: Any) -> None: 30 | assert isinstance( 31 | item, (list, np.ndarray, torch.Tensor) 32 | ), "MaskData only supports list, numpy arrays, and torch tensors." 33 | self._stats[key] = item 34 | 35 | def __delitem__(self, key: str) -> None: 36 | del self._stats[key] 37 | 38 | def __getitem__(self, key: str) -> Any: 39 | return self._stats[key] 40 | 41 | def items(self) -> ItemsView[str, Any]: 42 | return self._stats.items() 43 | 44 | def filter(self, keep: torch.Tensor) -> None: 45 | for k, v in self._stats.items(): 46 | if v is None: 47 | self._stats[k] = None 48 | elif isinstance(v, torch.Tensor): 49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 50 | elif isinstance(v, np.ndarray): 51 | self._stats[k] = v[keep.detach().cpu().numpy()] 52 | elif isinstance(v, list) and keep.dtype == torch.bool: 53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 54 | elif isinstance(v, list): 55 | self._stats[k] = [v[i] for i in keep] 56 | else: 57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 58 | 59 | def cat(self, new_stats: "MaskData") -> None: 60 | for k, v in new_stats.items(): 61 | if k not in self._stats or self._stats[k] is None: 62 | 
self._stats[k] = deepcopy(v) 63 | elif isinstance(v, torch.Tensor): 64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 65 | elif isinstance(v, np.ndarray): 66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 67 | elif isinstance(v, list): 68 | self._stats[k] = self._stats[k] + deepcopy(v) 69 | else: 70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 71 | 72 | def to_numpy(self) -> None: 73 | for k, v in self._stats.items(): 74 | if isinstance(v, torch.Tensor): 75 | self._stats[k] = v.detach().cpu().numpy() 76 | 77 | 78 | def is_box_near_crop_edge( 79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 80 | ) -> torch.Tensor: 81 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 88 | return torch.any(near_crop_edge, dim=1) 89 | 90 | 91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 92 | box_xywh = deepcopy(box_xyxy) 93 | box_xywh[2] = box_xywh[2] - box_xywh[0] 94 | box_xywh[3] = box_xywh[3] - box_xywh[1] 95 | return box_xywh 96 | 97 | 98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 99 | assert len(args) > 0 and all( 100 | len(a) == len(args[0]) for a in args 101 | ), "Batched iteration must have inputs of all the same size." 
102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 103 | for b in range(n_batches): 104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 105 | 106 | 107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 108 | """ 109 | Encodes masks to an uncompressed RLE, in the format expected by 110 | pycocotools. 111 | """ 112 | # Put in Fortran order and flatten h,w 113 | b, h, w = tensor.shape 114 | tensor = tensor.permute(0, 2, 1).flatten(1) 115 | 116 | # Compute change indices 117 | diff = tensor[:, 1:] ^ tensor[:, :-1] 118 | change_indices = diff.nonzero() 119 | 120 | # Encode run length 121 | out = [] 122 | for i in range(b): 123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 124 | cur_idxs = torch.cat( 125 | [ 126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 127 | cur_idxs + 1, 128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | ] 130 | ) 131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 132 | counts = [] if tensor[i, 0] == 0 else [0] 133 | counts.extend(btw_idxs.detach().cpu().tolist()) 134 | out.append({"size": [h, w], "counts": counts}) 135 | return out 136 | 137 | 138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 139 | """Compute a binary mask from an uncompressed RLE.""" 140 | h, w = rle["size"] 141 | mask = np.empty(h * w, dtype=bool) 142 | idx = 0 143 | parity = False 144 | for count in rle["counts"]: 145 | mask[idx : idx + count] = parity 146 | idx += count 147 | parity ^= True 148 | mask = mask.reshape(w, h) 149 | return mask.transpose() # Put in C order 150 | 151 | 152 | def area_from_rle(rle: Dict[str, Any]) -> int: 153 | return sum(rle["counts"][1::2]) 154 | 155 | 156 | def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor: 157 | """ 158 | Computes the stability score for a batch of masks.
The stability 159 | score is the IoU between the binary masks obtained by thresholding 160 | the predicted mask logits at high and low values. 161 | """ 162 | # One mask is always contained inside the other. 163 | # Save memory by preventing unnecessary cast to torch.int64 164 | intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) 165 | unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) 166 | return intersections / unions 167 | 168 | 169 | def build_point_grid(n_per_side: int) -> np.ndarray: 170 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 171 | offset = 1 / (2 * n_per_side) 172 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 173 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 174 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 175 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 176 | return points 177 | 178 | 179 | def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]: 180 | """Generates point grids for all crop layers.""" 181 | points_by_layer = [] 182 | for i in range(n_layers + 1): 183 | n_points = int(n_per_side / (scale_per_layer ** i)) 184 | points_by_layer.append(build_point_grid(n_points)) 185 | return points_by_layer 186 | 187 | 188 | def generate_crop_boxes( 189 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 190 | ) -> Tuple[List[List[int]], List[int]]: 191 | """ 192 | Generates a list of crop boxes of different sizes. Each layer 193 | has (2**i)**2 boxes for the ith layer.
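Worked example (a hedged sketch, not part of metaseg): the hypothetical helper `layer_crop_boxes` repeats the arithmetic of one loop iteration below for a 600x800 (H x W) image with an assumed overlap_ratio of 0.34. The overlap is int(0.34 * 600 * (2 / 2)) = 204 px, the crops are 502 px wide and 402 px tall, and adjacent crops share exactly 204 px.

```python
import math
from itertools import product


def layer_crop_boxes(im_h, im_w, i_layer, overlap_ratio):
    # Mirrors one iteration of the loop in generate_crop_boxes (boxes are XYXY).
    n_crops_per_side = 2 ** (i_layer + 1)
    overlap = int(overlap_ratio * min(im_h, im_w) * (2 / n_crops_per_side))

    def crop_len(orig_len):
        return int(math.ceil((overlap * (n_crops_per_side - 1) + orig_len) / n_crops_per_side))

    crop_w, crop_h = crop_len(im_w), crop_len(im_h)
    xs = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
    ys = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
    return [[x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] for x0, y0 in product(xs, ys)]


boxes = layer_crop_boxes(600, 800, i_layer=0, overlap_ratio=0.34)
print(boxes)
# [[0, 0, 502, 402], [0, 198, 502, 600], [298, 0, 800, 402], [298, 198, 800, 600]]
```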
194 | """ 195 | crop_boxes, layer_idxs = [], [] 196 | im_h, im_w = im_size 197 | short_side = min(im_h, im_w) 198 | 199 | # Original image 200 | crop_boxes.append([0, 0, im_w, im_h]) 201 | layer_idxs.append(0) 202 | 203 | def crop_len(orig_len, n_crops, overlap): 204 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 205 | 206 | for i_layer in range(n_layers): 207 | n_crops_per_side = 2 ** (i_layer + 1) 208 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 209 | 210 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 211 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 212 | 213 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 214 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 215 | 216 | # Crop boxes in XYXY format 217 | for x0, y0 in product(crop_box_x0, crop_box_y0): 218 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 219 | crop_boxes.append(box) 220 | layer_idxs.append(i_layer + 1) 221 | 222 | return crop_boxes, layer_idxs 223 | 224 | 225 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 226 | x0, y0, _, _ = crop_box 227 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 228 | # Check if boxes has a channel dimension 229 | if len(boxes.shape) == 3: 230 | offset = offset.unsqueeze(1) 231 | return boxes + offset 232 | 233 | 234 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 235 | x0, y0, _, _ = crop_box 236 | offset = torch.tensor([[x0, y0]], device=points.device) 237 | # Check if points has a channel dimension 238 | if len(points.shape) == 3: 239 | offset = offset.unsqueeze(1) 240 | return points + offset 241 | 242 | 243 | def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor: 244 | x0, y0, x1, y1 = crop_box 245 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 246 | return masks 247 | # Coordinate transform
masks 248 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 249 | pad = (x0, pad_x - x0, y0, pad_y - y0) 250 | return torch.nn.functional.pad(masks, pad, value=0) 251 | 252 | 253 | def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]: 254 | """ 255 | Removes small disconnected regions ("islands") or small holes in a mask, depending on mode. Returns the 256 | mask and a flag indicating whether the mask has been modified. 257 | """ 258 | import cv2 # type: ignore 259 | 260 | assert mode in ["holes", "islands"] 261 | correct_holes = mode == "holes" 262 | working_mask = (correct_holes ^ mask).astype(np.uint8) 263 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 264 | sizes = stats[:, -1][1:] # Row 0 is background label 265 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 266 | if len(small_regions) == 0: 267 | return mask, False 268 | fill_labels = [0] + small_regions 269 | if not correct_holes: 270 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 271 | # If every region is below threshold, keep largest 272 | if len(fill_labels) == 0: 273 | fill_labels = [int(np.argmax(sizes)) + 1] 274 | mask = np.isin(regions, fill_labels) 275 | return mask, True 276 | 277 | 278 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 279 | from pycocotools import mask as mask_utils # type: ignore 280 | 281 | h, w = uncompressed_rle["size"] 282 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 283 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 284 | return rle 285 | 286 | 287 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 288 | """ 289 | Calculates boxes in XYXY format around masks. Returns [0,0,0,0] for 290 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
291 | """ 292 | # torch.max below raises an error on empty inputs, just skip in this case 293 | if torch.numel(masks) == 0: 294 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 295 | 296 | # Normalize shape to CxHxW 297 | shape = masks.shape 298 | h, w = shape[-2:] 299 | if len(shape) > 2: 300 | masks = masks.flatten(0, -3) 301 | else: 302 | masks = masks.unsqueeze(0) 303 | 304 | # Get top and bottom edges 305 | in_height, _ = torch.max(masks, dim=-1) 306 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 307 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 308 | in_height_coords = in_height_coords + h * (~in_height) 309 | top_edges, _ = torch.min(in_height_coords, dim=-1) 310 | 311 | # Get left and right edges 312 | in_width, _ = torch.max(masks, dim=-2) 313 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 314 | right_edges, _ = torch.max(in_width_coords, dim=-1) 315 | in_width_coords = in_width_coords + w * (~in_width) 316 | left_edges, _ = torch.min(in_width_coords, dim=-1) 317 | 318 | # If the mask is empty the right edge will be to the left of the left edge. 
319 | # Replace these boxes with [0, 0, 0, 0] 320 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 321 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 322 | out = out * (~empty_filter).unsqueeze(-1) 323 | 324 | # Return to original shape 325 | if len(shape) > 2: 326 | out = out.reshape(*shape[:-2], 4) 327 | else: 328 | out = out[0] 329 | 330 | return out 331 | -------------------------------------------------------------------------------- /metaseg/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | def load_image(image_path): 2 | import cv2 3 | 4 | image = cv2.imread(image_path) 5 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 6 | return image 7 | def load_image_vide(image_path): 8 | import cv2 9 | # image_path is an already-decoded BGR frame (np.ndarray), not a file path 10 | 11 | image = cv2.cvtColor(image_path, cv2.COLOR_BGR2RGB) 12 | return image 13 | 14 | def load_server_image(image_path): 15 | import os 16 | from io import BytesIO 17 | from uuid import uuid4 18 | 19 | from PIL import Image 20 | # image_path holds the raw encoded image bytes, not a file path 21 | imagedir = str(uuid4()) 22 | os.makedirs(imagedir, exist_ok=True) # portable, unlike os.system("mkdir -p ...") 23 | image = Image.open(BytesIO(image_path)) 24 | if image.mode != "RGB": 25 | image = image.convert("RGB") 26 | 27 | image_path = f"{imagedir}/base_image_v0.png" 28 | output_path = f"{imagedir}/output_v0.png" 29 | image.save(image_path, format="PNG") 30 | return image_path, output_path 31 | 32 | 33 | def load_video(video_path, output_path="output.mp4"): 34 | import cv2 35 | 36 | cap = cv2.VideoCapture(video_path) 37 | frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 38 | frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 39 | fourcc = cv2.VideoWriter_fourcc(*"XVID") 40 | fps = int(cap.get(cv2.CAP_PROP_FPS)) 41 | out = cv2.VideoWriter(output_path, fourcc, fps, 
(frame_width, frame_height)) 42 | 43 | return cap, out 44 | 45 | 46 | def read_image(image_path): 47 | import cv2 48 | 49 | image = cv2.imread(image_path) 50 | image = cv2.cvtColor(image,
cv2.COLOR_BGR2RGB) 51 | return image 52 | 53 | 54 | def load_mask(mask, random_color): 55 | import numpy as np 56 | 57 | if random_color: 58 | color = np.random.rand(3) * 255 59 | else: 60 | color = np.array([100, 50, 0]) 61 | 62 | h, w = mask.shape[-2:] 63 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 64 | mask_image = mask_image.astype(np.uint8) 65 | return mask_image 66 | 67 | 68 | def load_box(box, image): 69 | import cv2 70 | 71 | x0, y0, x1, y1 = int(box[0]), int(box[1]), int(box[2]), int(box[3]) # box is XYXY 72 | cv2.rectangle(image, (x0, y0), (x1, y1), (0, 255, 0), 2) 73 | return image 74 | 75 | 76 | def plt_load_mask(mask, ax, random_color=False): 77 | import numpy as np 78 | 79 | if random_color: 80 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 81 | else: 82 | color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6]) 83 | h, w = mask.shape[-2:] 84 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 85 | ax.imshow(mask_image) 86 | 87 | 88 | def plt_load_box(box, ax): 89 | import matplotlib.pyplot as plt 90 | 91 | x0, y0 = box[0], box[1] 92 | w, h = box[2] - box[0], box[3] - box[1] 93 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)) 94 | 95 | 96 | def multi_boxes(boxes, predictor, image): 97 | import torch 98 | 99 | input_boxes = torch.tensor(boxes, device=predictor.device) 100 | transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2]) 101 | return input_boxes, transformed_boxes 102 | 103 | 104 | def show_image(output_image): 105 | import cv2 106 | 107 | cv2.imshow("output", output_image) 108 | cv2.waitKey(0) 109 | cv2.destroyAllWindows() 110 | 111 | 112 | def save_image(output_image, output_path): 113 | import cv2 114 | 115 | cv2.imwrite(output_path, output_image) 116 | -------------------------------------------------------------------------------- /metaseg/utils/file_utils.py:
-------------------------------------------------------------------------------- 1 | import os 2 | import urllib.request 3 | 4 | 5 | def download_model(model_type): 6 | """ 7 | model_type: str, A string representing the model type. It can be 'vit_h', 'vit_l', or 'vit_b'. 8 | """ 9 | 10 | # A dictionary containing model types as keys and their respective URLs as values 11 | model_urls = { 12 | "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 13 | "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 14 | "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", 15 | } 16 | 17 | # Check if the model file already exists and model_type is in model_urls 18 | filename = f"{model_type}.pth" 19 | if not os.path.exists(filename) and model_type in model_urls: 20 | url = model_urls[model_type] 21 | print(f"Downloading {model_type} model from {url}...") 22 | urllib.request.urlretrieve(url, filename) 23 | print(f"{model_type} model has been successfully downloaded and saved as '{filename}'.") 24 | elif os.path.exists(filename): 25 | print(f"{model_type} model already exists as '{filename}'. Skipping download.") 26 | else: 27 | raise ValueError("Invalid model type. It should be 'vit_h', 'vit_l', or 'vit_b'.") 28 | 29 | return filename 30 | -------------------------------------------------------------------------------- /metaseg/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import Tuple 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import functional as F 12 | 13 | from metaseg.modeling import Sam 14 | from metaseg.utils.amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information is returned. See the upstream segment-anything ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size(input_image_size: torch.Tensor, longest_side: int) -> torch.Tensor: 43 | input_image_size = input_image_size.to(torch.float32) 44 | scale = longest_side / torch.max(input_image_size) 45 | transformed_size = scale * input_image_size 46 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 47 | return transformed_size 48 | 49 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 50 | point_coords = point_coords + 0.5 51 | point_coords = point_coords / self.img_size 52 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 53 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 54 | 55 | point_embedding = point_embedding * (point_labels != -1) 56 | point_embedding = 
point_embedding +
self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1) 57 | 58 | for i in range(self.model.prompt_encoder.num_point_embeddings): 59 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[i].weight * ( 60 | point_labels == i 61 | ) 62 | 63 | return point_embedding 64 | 65 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 66 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 67 | mask_embedding = mask_embedding + (1 - has_mask_input) * self.model.prompt_encoder.no_mask_embed.weight.reshape( 68 | 1, -1, 1, 1 69 | ) 70 | return mask_embedding 71 | 72 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 73 | masks = F.interpolate( 74 | masks, 75 | size=(self.img_size, self.img_size), 76 | mode="bilinear", 77 | align_corners=False, 78 | ) 79 | 80 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size) 81 | masks = masks[..., : int(prepadded_size[0]), : int(prepadded_size[1])] 82 | 83 | orig_im_size = orig_im_size.to(torch.int64) 84 | h, w = orig_im_size[0], orig_im_size[1] 85 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 86 | return masks 87 | 88 | def select_masks( 89 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 90 | ) -> Tuple[torch.Tensor, torch.Tensor]: 91 | # Determine if we should return the multiclick mask or not from the number of points. 92 | # The reweighting is used to avoid control flow. 
93 | score_reweight = torch.tensor([[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]).to( 94 | iou_preds.device 95 | ) 96 | score = iou_preds + (num_points - 2.5) * score_reweight 97 | best_idx = torch.argmax(score, dim=1) 98 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 99 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 100 | 101 | return masks, iou_preds 102 | 103 | @torch.no_grad() 104 | def forward( 105 | self, 106 | image_embeddings: torch.Tensor, 107 | point_coords: torch.Tensor, 108 | point_labels: torch.Tensor, 109 | mask_input: torch.Tensor, 110 | has_mask_input: torch.Tensor, 111 | orig_im_size: torch.Tensor, 112 | ): 113 | sparse_embedding = self._embed_points(point_coords, point_labels) 114 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 115 | 116 | masks, scores = self.model.mask_decoder.predict_masks( 117 | image_embeddings=image_embeddings, 118 | image_pe=self.model.prompt_encoder.get_dense_pe(), 119 | sparse_prompt_embeddings=sparse_embedding, 120 | dense_prompt_embeddings=dense_embedding, 121 | ) 122 | 123 | if self.use_stability_score: 124 | scores = calculate_stability_score(masks, self.model.mask_threshold, self.stability_score_offset) 125 | 126 | if self.return_single_mask: 127 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 128 | 129 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 130 | 131 | if self.return_extra_metrics: 132 | stability_scores = calculate_stability_score( 133 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 134 | ) 135 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 136 | return upscaled_masks, scores, stability_scores, areas, masks 137 | 138 | return upscaled_masks, scores, masks 139 | -------------------------------------------------------------------------------- /metaseg/utils/transforms.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from copy import deepcopy 8 | from typing import Tuple 9 | 10 | import numpy as np 11 | import torch 12 | from torch.nn import functional as F 13 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to longest side 'target_length', and provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy arrays and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length) 40 | coords = deepcopy(coords).astype(float) 41 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 42 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 43 | return coords 44 | 45 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 46 | """ 47 | Expects a numpy array with shape Bx4. Requires the original image size 48 | in (H, W) format.
49 | """ 50 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 51 | return boxes.reshape(-1, 4) 52 | 53 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 54 | """ 55 | Expects batched images with shape BxCxHxW and float format. This 56 | transformation may not exactly match apply_image. apply_image is 57 | the transformation expected by the model. 58 | """ 59 | # BCHW layout: height and width are dims 2 and 3. May not exactly match apply_image. 60 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 61 | return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True) 62 | 63 | def apply_coords_torch(self, coords: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor: 64 | """ 65 | Expects a torch tensor with length 2 in the last dimension. Requires the 66 | original image size in (H, W) format. 67 | """ 68 | old_h, old_w = original_size 69 | new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length) 70 | coords = deepcopy(coords).to(torch.float) 71 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 72 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 73 | return coords 74 | 75 | def apply_boxes_torch(self, boxes: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor: 76 | """ 77 | Expects a torch tensor with shape Bx4. Requires the original image 78 | size in (H, W) format. 79 | """ 80 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 81 | return boxes.reshape(-1, 4) 82 | 83 | @staticmethod 84 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 85 | """ 86 | Compute the output size given input size and target long side length.
87 | """ 88 | scale = long_side_length * 1.0 / max(oldh, oldw) 89 | newh, neww = oldh * scale, oldw * scale 90 | neww = int(neww + 0.5) 91 | newh = int(newh + 0.5) 92 | return (newh, neww) 93 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.18.0 2 | aiofiles==23.1.0 3 | aiohttp==3.8.4 4 | aiosignal==1.3.1 5 | altair==4.2.2 6 | antlr4-python3-runtime==4.9.3 7 | anyio==3.6.2 8 | async-timeout==4.0.2 9 | attrs==23.1.0 10 | auth0-python==4.2.0 11 | certifi==2022.12.7 12 | cffi==1.15.1 13 | charset-normalizer==2.1.1 14 | click==8.0.4 15 | colorama==0.4.6 16 | contourpy==1.0.7 17 | cryptography==40.0.2 18 | cycler==0.11.0 19 | datadog-api-client==2.12.0 20 | Deprecated==1.2.13 21 | diffusers==0.14.0 22 | dill==0.3.5.1 23 | distlib==0.3.6 24 | entrypoints==0.4 25 | exceptiongroup==1.1.1 26 | fal-serverless==0.6.29 27 | fastapi==0.95.1 28 | ffmpy==0.3.0 29 | filelock==3.12.0 30 | fire==0.5.0 31 | Flask==1.1.4 32 | Flask-Cors==3.0.10 33 | flaskwebgui==0.3.5 34 | fonttools==4.39.3 35 | frozenlist==1.3.3 36 | fsspec==2023.4.0 37 | gradio==3.28.1 38 | gradio_client==0.1.4 39 | grpc-interceptor==0.15.1 40 | grpcio==1.54.0 41 | h11==0.14.0 42 | httpcore==0.17.0 43 | httpx==0.24.0 44 | huggingface-hub==0.14.1 45 | idna==3.4 46 | imageio==2.28.1 47 | importlib-metadata==6.0.1 48 | iniconfig==2.0.0 49 | isolate==0.12.0 50 | isolate-proto==0.0.27 51 | itsdangerous==1.1.0 52 | Jinja2==2.11.3 53 | jsonschema==4.17.3 54 | kiwisolver==1.4.4 55 | lama-cleaner==1.1.2 56 | linkify-it-py==2.0.2 57 | loguru==0.7.0 58 | markdown-it-py==2.2.0 59 | MarkupSafe==2.0.1 60 | matplotlib==3.7.1 61 | mdit-py-plugins==0.3.3 62 | mdurl==0.1.2 63 | mediapipe==0.9.3.0 64 | mpmath==1.2.1 65 | multidict==6.0.4 66 | networkx==3.0 67 | numpy==1.24.1 68 | omegaconf==2.3.0 69 | opencv-contrib-python==4.7.0.72 70 | opencv-python==4.7.0.72 71 | 
opentelemetry-api==1.17.0 72 | opentelemetry-sdk==1.17.0 73 | opentelemetry-semantic-conventions==0.38b0 74 | orjson==3.8.11 75 | packaging==23.1 76 | pandas==2.0.1 77 | piexif==1.1.3 78 | Pillow==9.3.0 79 | platformdirs==3.5.0 80 | pluggy==1.0.0 81 | portalocker==2.7.0 82 | protobuf==4.22.3 83 | psutil==5.9.5 84 | pybboxes==0.1.6 85 | pycparser==2.21 86 | pydantic==1.10.7 87 | pydub==0.25.1 88 | Pygments==2.15.1 89 | PyJWT==2.6.0 90 | pyparsing==3.0.9 91 | pyrsistent==0.19.3 92 | pytest==7.3.1 93 | python-dateutil==2.8.2 94 | python-multipart==0.0.6 95 | pytz==2023.3 96 | PyWavelets==1.4.1 97 | pywin32==306; sys_platform == "win32" 98 | PyYAML==6.0 99 | regex==2023.5.4 100 | requests==2.28.1 101 | rich==13.3.5 102 | safetensors==0.3.1 103 | sahi==0.11.13 104 | scikit-image==0.19.3 105 | scipy==1.10.1 106 | seaborn==0.12.2 107 | semantic-version==2.10.0 108 | sentry-sdk==1.21.1 109 | shapely==2.0.1 110 | six==1.16.0 111 | sniffio==1.3.0 112 | sounddevice==0.4.6 113 | starlette==0.26.1 114 | structlog==22.3.0 115 | sympy==1.11.1 116 | tblib==1.7.0 117 | termcolor==2.3.0 118 | terminaltables==3.1.10 119 | thop==0.1.1.post2209072238 120 | tifffile==2023.4.12 121 | tokenizers==0.13.3 122 | tomli==2.0.1 123 | toolz==0.12.0 124 | torch==2.0.0+cu118 125 | torchaudio==2.0.1+cu118 126 | torchvision==0.15.1+cu118 127 | tqdm==4.65.0 128 | transformers==4.27.4 129 | typing_extensions==4.4.0 130 | tzdata==2023.3 131 | uc-micro-py==1.0.2 132 | ultralytics==8.0.91 133 | urllib3==1.26.13 134 | uvicorn==0.22.0 135 | virtualenv==20.23.0 136 | websockets==11.0.2 137 | Werkzeug==1.0.1 138 | whichcraft==0.6.1 139 | win32-setctime==1.1.0 140 | wrapt==1.15.0 141 | yacs==0.1.8 142 | yarl==1.9.2 143 | zipp==3.15.0 144 | --------------------------------------------------------------------------------
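The uncompressed RLE convention used by `mask_to_rle_pytorch`, `rle_to_mask` and `area_from_rle` in `metaseg/utils/amg.py` has three parts:

- pixels are visited in column-major (Fortran) order;
- counts alternate between runs of 0s and 1s;
- the first count is always a 0-run, which may be empty.

The pure-Python sketch below illustrates that convention on nested lists instead of tensors. The function names mirror the originals, but this is an illustration under those assumptions, not a drop-in replacement.

```python
from typing import Dict, List


def mask_to_rle(mask: List[List[int]]) -> Dict[str, List[int]]:
    """Encode an H x W binary mask as uncompressed RLE (column-major order)."""
    h, w = len(mask), len(mask[0])
    # Fortran (column-major) order: walk down each column, then move right.
    flat = [mask[y][x] for x in range(w) for y in range(h)]
    counts: List[int] = []
    current, run = 0, 0  # the first run always counts 0s, so it may be empty
    for value in flat:
        if value == current:
            run += 1
        else:
            counts.append(run)
            current, run = value, 1
    counts.append(run)
    return {"size": [h, w], "counts": counts}


def rle_to_mask(rle: Dict[str, List[int]]) -> List[List[int]]:
    """Decode uncompressed RLE back into an H x W mask (row-major nested lists)."""
    h, w = rle["size"]
    flat: List[int] = []
    value = 0
    for count in rle["counts"]:
        flat.extend([value] * count)
        value ^= 1  # runs alternate 0, 1, 0, ...
    # flat is column-major, so pixel (y, x) lives at index x * h + y.
    return [[flat[x * h + y] for x in range(w)] for y in range(h)]


def area_from_rle(rle: Dict[str, List[int]]) -> int:
    # Odd-indexed runs are the foreground (1) runs.
    return sum(rle["counts"][1::2])


mask = [[1, 1, 0],
        [0, 1, 1]]
rle = mask_to_rle(mask)
print(rle)                       # {'size': [2, 3], 'counts': [0, 1, 1, 2, 1, 1]}
print(rle_to_mask(rle) == mask)  # True
print(area_from_rle(rle))        # 4
```

The leading `0` in the counts is the empty 0-run that appears whenever the first pixel in column-major order is foreground. It corresponds to the `[0]` prefix that `mask_to_rle_pytorch` adds when `tensor[i, 0]` is nonzero.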
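`calculate_stability_score` in `metaseg/utils/amg.py` relies on one containment property. The mask thresholded at `mask_threshold + threshold_offset` always lies inside the mask thresholded at `mask_threshold - threshold_offset`. So the intersection is the pixel count of the high-threshold mask, and the union is the pixel count of the low-threshold mask.

The pure-Python sketch below reproduces that arithmetic on a flat list of logits. The offset of 1.0 used in the example matches `stability_score_offset` in `SamOnnxModel`. The empty-union guard is an addition of this sketch: the tensor version would compute 0 / 0 in that case.

```python
from typing import Sequence


def stability_score(logits: Sequence[float], mask_threshold: float, threshold_offset: float) -> float:
    """IoU between the masks obtained at a high and a low threshold."""
    # Pixels above the high threshold are a subset of those above the low one,
    # so intersection = count(high) and union = count(low).
    high = sum(1 for v in logits if v > mask_threshold + threshold_offset)
    low = sum(1 for v in logits if v > mask_threshold - threshold_offset)
    return high / low if low else 0.0  # guard added in this sketch only


print(stability_score([-2.0, -0.5, 0.5, 2.0, 3.0], mask_threshold=0.0, threshold_offset=1.0))  # 0.5
```

A score near 1.0 means the mask barely changes when the threshold moves by `±threshold_offset`, which is why it is used to filter unstable masks.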
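`build_point_grid` in `metaseg/utils/amg.py` places `n_per_side` points per axis at the centres of equal cells, that is at `(2i + 1) / (2 * n_per_side)`. This equals `np.linspace(offset, 1 - offset, n_per_side)` with `offset = 1 / (2 * n_per_side)`, up to floating-point rounding. `build_all_layer_point_grids` then divides the per-side count by `scale_per_layer ** i` for crop layer `i`.

A NumPy-free sketch of both functions, returning lists of `(x, y)` tuples instead of arrays:

```python
from typing import List, Tuple

Point = Tuple[float, float]


def build_point_grid(n_per_side: int) -> List[Point]:
    """Evenly spaced (x, y) points in [0, 1] x [0, 1], one per cell centre."""
    coords = [(2 * i + 1) / (2 * n_per_side) for i in range(n_per_side)]
    # y varies slowest and x fastest, matching the stack(...).reshape(-1, 2) order.
    return [(x, y) for y in coords for x in coords]


def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[List[Point]]:
    """One grid per crop layer; layer i uses n_per_side / scale_per_layer**i points per side."""
    return [build_point_grid(int(n_per_side / (scale_per_layer ** i))) for i in range(n_layers + 1)]


print(build_point_grid(2))  # [(0.25, 0.25), (0.75, 0.25), (0.25, 0.75), (0.75, 0.75)]
print([len(g) for g in build_all_layer_point_grids(32, 2, 2)])  # [1024, 256, 64]
```

Smaller crops get sparser grids. Each crop is later resized up to the model input size, so a coarser normalized grid still samples the crop at a similar pixel density.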
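For each extra layer, `generate_crop_boxes` in `metaseg/utils/amg.py` does three things:

- splits the image into `n = 2 ** (i_layer + 1)` crops per side;
- enlarges each crop so that neighbours overlap by `overlap` pixels;
- emits XYXY boxes clipped to the image.

The sketch below copies that per-layer arithmetic into a standalone function so the numbers can be checked for a small 100 x 200 (H x W) image. `crop_boxes_for_layer` is a hypothetical name, not part of the repository.

```python
import math
from itertools import product
from typing import List


def crop_boxes_for_layer(im_h: int, im_w: int, i_layer: int, overlap_ratio: float) -> List[List[int]]:
    """XYXY crop boxes for one crop layer (hypothetical helper mirroring generate_crop_boxes)."""
    n = 2 ** (i_layer + 1)  # crops per side
    overlap = int(overlap_ratio * min(im_h, im_w) * (2 / n))

    def crop_len(orig_len: int) -> int:
        # n crops joined by (n - 1) overlaps must cover orig_len pixels.
        return int(math.ceil((overlap * (n - 1) + orig_len) / n))

    crop_w, crop_h = crop_len(im_w), crop_len(im_h)
    xs = [(crop_w - overlap) * i for i in range(n)]
    ys = [(crop_h - overlap) * i for i in range(n)]
    return [[x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] for x0, y0 in product(xs, ys)]


print(crop_boxes_for_layer(100, 200, 0, 0.2))
# [[0, 0, 110, 60], [0, 40, 110, 100], [90, 0, 200, 60], [90, 40, 200, 100]]
```

In this example the overlap is 20 px. The two columns of crops share x in [90, 110] and the two rows share y in [40, 60].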
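`uncrop_boxes_xyxy`, `uncrop_points` and `uncrop_masks` in `metaseg/utils/amg.py` map crop-local results back into full-image coordinates:

- boxes and points are shifted by the crop origin `(x0, y0)`;
- masks are zero-padded with `(left, right, top, bottom) = (x0, pad_x - x0, y0, pad_y - y0)`, the last-dimension-first order that `torch.nn.functional.pad` takes.

A list-based sketch of the box shift and the mask padding; the helper names are hypothetical:

```python
from typing import List, Tuple


def uncrop_box(box: List[int], crop_box: List[int]) -> List[int]:
    """Shift an XYXY box from crop coordinates to image coordinates."""
    x0, y0, _, _ = crop_box
    return [box[0] + x0, box[1] + y0, box[2] + x0, box[3] + y0]


def pad_amounts(crop_box: List[int], orig_h: int, orig_w: int) -> Tuple[int, int, int, int]:
    """(left, right, top, bottom) padding that places a crop back in the image."""
    x0, y0, x1, y1 = crop_box
    pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
    return (x0, pad_x - x0, y0, pad_y - y0)


def uncrop_mask(mask: List[List[int]], crop_box: List[int], orig_h: int, orig_w: int) -> List[List[int]]:
    """Zero-pad a crop-sized mask (nested lists) to the full image size."""
    left, right, top, bottom = pad_amounts(crop_box, orig_h, orig_w)
    rows = [[0] * left + row + [0] * right for row in mask]
    return [[0] * orig_w for _ in range(top)] + rows + [[0] * orig_w for _ in range(bottom)]


print(uncrop_box([0, 0, 2, 1], [1, 1, 3, 2]))                    # [1, 1, 3, 2]
print(uncrop_mask([[1, 1]], [1, 1, 3, 2], orig_h=3, orig_w=4))   # [[0, 0, 0, 0], [0, 1, 1, 0], [0, 0, 0, 0]]
```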
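For each mask, `batched_mask_to_box` in `metaseg/utils/amg.py` finds the first and last rows and columns that contain a foreground pixel. It returns them as an inclusive XYXY box, with `[0, 0, 0, 0]` for an empty mask.

A single-mask, pure-Python sketch of the same rule; `mask_to_box` is a hypothetical helper:

```python
from typing import List


def mask_to_box(mask: List[List[int]]) -> List[int]:
    """Inclusive XYXY box around the foreground of an H x W mask."""
    rows = [y for y, row in enumerate(mask) if any(row)]
    cols = [x for x in range(len(mask[0])) if any(row[x] for row in mask)]
    if not rows:  # empty mask
        return [0, 0, 0, 0]
    return [min(cols), min(rows), max(cols), max(rows)]


print(mask_to_box([[0, 0, 0, 0],
                   [0, 1, 1, 0],
                   [0, 0, 1, 0]]))  # [1, 1, 2, 2]
```

The right and bottom edges are pixel indices, not exclusive bounds. The tensor version behaves the same way because it multiplies by `torch.arange(h)` and `torch.arange(w)`.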
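`ResizeLongestSide` in `metaseg/utils/transforms.py` works in two steps:

1. `get_preprocess_shape` scales the image so its longer side equals `target_length`, rounding each side to the nearest integer.
2. `apply_coords` rescales x by `new_w / old_w` and y by `new_h / old_h`.

A `target_length` of 1024 is used below purely for illustration. With it, a 480 x 640 (H x W) frame becomes 768 x 1024, and the point (320, 240) maps to (512, 384). A dependency-free sketch of the two steps:

```python
from typing import List, Tuple


def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
    """Same rounding rule as ResizeLongestSide.get_preprocess_shape."""
    scale = long_side_length * 1.0 / max(oldh, oldw)
    return int(oldh * scale + 0.5), int(oldw * scale + 0.5)


def apply_coords(coords: List[Tuple[float, float]], original_size: Tuple[int, int],
                 target_length: int) -> List[Tuple[float, float]]:
    """Rescale (x, y) points from an (H, W) image to the resized image."""
    old_h, old_w = original_size
    new_h, new_w = get_preprocess_shape(old_h, old_w, target_length)
    return [(x * new_w / old_w, y * new_h / old_h) for x, y in coords]


print(get_preprocess_shape(480, 640, 1024))          # (768, 1024)
print(apply_coords([(320, 240)], (480, 640), 1024))  # [(512.0, 384.0)]
```

Note that `original_size` is (H, W) while each point is (x, y). Mixing up the two orders is an easy way to get silently wrong prompts.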
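`SamOnnxModel.select_masks` in `metaseg/utils/onnx.py` chooses between the single-mask and multimask outputs without an `if`; its source comment says the reweighting is used to avoid control flow. It adds `(num_points - 2.5) * 1000` to the score of mask token 0 only:

- with two or fewer prompt points, token 0 is pushed down by at least 500, so one of the multimask outputs wins;
- with three or more points, token 0 is pushed up by at least 500 and wins.

The pure-Python sketch below reproduces that argmax for one prompt. `select_mask_index` is a hypothetical helper, and the sketch assumes IoU predictions lie in [0, 1] so the ±500 shift always dominates.

```python
from typing import Sequence


def select_mask_index(iou_preds: Sequence[float], num_points: int) -> int:
    """Index of the mask SamOnnxModel.select_masks would keep for one prompt."""
    # Only token 0 (the single-mask output) is reweighted.
    reweight = [1000] + [0] * (len(iou_preds) - 1)
    scores = [p + (num_points - 2.5) * r for p, r in zip(iou_preds, reweight)]
    # The first maximum wins on ties.
    return max(range(len(scores)), key=lambda i: scores[i])


iou = [0.9, 0.5, 0.8, 0.7]
print(select_mask_index(iou, num_points=2))  # 2  (best multimask output)
print(select_mask_index(iou, num_points=3))  # 0  (single-mask output)
```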