├── FNPC_Interface.py ├── README.md ├── Single_Slice_to_Volume.py ├── loss_aug_sam.py ├── prompt_aug.py ├── segment_anything ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── automatic_mask_generator.cpython-36.pyc │ ├── build_sam.cpython-36.pyc │ └── predictor.cpython-36.pyc ├── automatic_mask_generator.py ├── build_sam.py ├── modeling │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── common.cpython-36.pyc │ │ ├── image_encoder.cpython-36.pyc │ │ ├── mask_decoder.cpython-36.pyc │ │ ├── prompt_encoder.cpython-36.pyc │ │ ├── sam.cpython-36.pyc │ │ └── transformer.cpython-36.pyc │ ├── common.py │ ├── image_encoder.py │ ├── mask_decoder.py │ ├── prompt_encoder.py │ ├── sam.py │ └── transformer.py ├── predictor.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── amg.cpython-36.pyc │ └── transforms.cpython-36.pyc │ ├── amg.py │ ├── onnx.py │ └── transforms.py ├── show.py └── uncertainty_refine.py /FNPC_Interface.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import cv2 7 | from tqdm import tqdm 8 | import argparse 9 | import matplotlib.pyplot as plt 10 | import warnings 11 | 12 | warnings.filterwarnings('ignore') 13 | 14 | from show import * 15 | from segment_anything import sam_model_registry, SamPredictor 16 | 17 | from itertools import combinations 18 | 19 | from scipy.spatial import ConvexHull 20 | import numpy as np 21 | import itertools 22 | import scipy 23 | from scipy.spatial.distance import cdist 24 | from prompt_aug import * 25 | from loss_aug_sam import * 26 | from uncertainty_refine import * 27 | from PIL import Image 28 | 29 | def get_arguments(): 30 | parser = argparse.ArgumentParser() 31 | 32 | parser.add_argument('--data', type=str, 33 | default="E:/Kidney/FNPC_dataset") 34 | parser.add_argument('--outdir', type=str, 35 | default='/result1') 36 | 37 | parser.add_argument('--ckpt', type=str, default='sam_vit_l_0b3195.pth') 38 | parser.add_argument('--ref_idx', type=str, default='000') # TODO change to "000" when using kidney 39 | parser.add_argument('--sam_type', type=str, default='vit_l')# 'vit_h', 'vit_t', 'vit_l' 40 | 41 | args = parser.parse_args() 42 | return args 43 | 44 | def main(): 45 | args = get_arguments() 46 | print("Args:", args) 47 | 48 | images_path = args.data + '/Images/' 49 | masks_path = args.data + '/Annotations_heavy_rough/' 50 | gts_path = args.data + '/GTs/' 51 | output_path = './outputs/' + args.outdir 52 | 53 | if not os.path.exists('./outputs/'): 54 | os.mkdir('./outputs/') 55 | 56 | all_dice_original = [] 57 | all_assd_original = [] 58 | all_hf_original = [] 59 | 60 | all_dice_ave = [] 61 | all_assd_ave = [] 62 | all_hf_ave = [] 63 | 64 | all_dice_FP_final = [] 65 | all_assd_FP_final = [] 66 | all_hf_FP_final = [] 67 | 68 | all_dice_FN_final = [] 69 | all_assd_FN_final = [] 70 | all_hf_FN_final = [] 71 | 72 | for obj_name in os.listdir(images_path): 73 | if ".DS" not in obj_name: 74 | # print("object name is:", obj_name) # object name is: Subject2_20230330_135719_1 75 | subject_dice_original, subject_assd_original, subject_hf_original, subject_dice_ave, subject_assd_ave, \ 76 | subject_hf_ave, subject_dice_FP_final, subject_assd_FP_final,subject_hf_FP_final, subject_dice_FN_final, \ 77 | subject_assd_FN_final,subject_hf_FN_final = FNPC(args, obj_name, images_path, masks_path, output_path, gts_path) 78 | 79 | all_dice_original += 
subject_dice_original 80 | all_assd_original += subject_assd_original 81 | all_hf_original += subject_hf_original 82 | 83 | all_dice_ave += subject_dice_ave 84 | all_assd_ave += subject_assd_ave 85 | all_hf_ave += subject_hf_ave 86 | 87 | all_dice_FP_final += subject_dice_FP_final 88 | all_assd_FP_final += subject_assd_FP_final 89 | all_hf_FP_final += subject_hf_FP_final 90 | 91 | all_dice_FN_final += subject_dice_FN_final 92 | all_assd_FN_final += subject_assd_FN_final 93 | all_hf_FN_final += subject_hf_FN_final 94 | 95 | mean_dice_FN_final = np.mean(all_dice_FN_final) 96 | mean_assd_FN_final = np.mean(all_assd_FN_final ) 97 | mean_hf_FN_final = np.mean(all_hf_FN_final) 98 | std_dice_FN_final = np.std(all_dice_FN_final) 99 | std_assd_FN_final = np.std(all_assd_FN_final) 100 | std_hf_FN_final = np.std(all_hf_FN_final) 101 | 102 | mean_dice_FP_final = np.mean(all_dice_FP_final) 103 | mean_assd_FP_final = np.mean(all_assd_FP_final) 104 | mean_hf_FP_final = np.mean(all_hf_FP_final) 105 | std_dice_FP_final = np.std(all_dice_FP_final) 106 | std_assd_FP_final = np.std(all_assd_FP_final) 107 | std_hf_FP_final = np.std(all_hf_FP_final) 108 | 109 | mean_dice_original = np.mean(all_dice_original) 110 | mean_assd_original = np.mean(all_assd_original) 111 | mean_hf_original = np.mean(all_hf_original) 112 | std_dice_original = np.std(all_dice_original) 113 | std_assd_original = np.std(all_assd_original) 114 | std_hf_original = np.std(all_hf_original) 115 | 116 | mean_dice_ave = np.mean(all_dice_ave) 117 | mean_assd_ave = np.mean(all_assd_ave) 118 | mean_hf_ave = np.mean(all_hf_ave) 119 | std_dice_ave = np.std(all_dice_ave) 120 | std_assd_ave = np.std(all_assd_ave) 121 | std_hf_ave = np.std(all_hf_ave) 122 | 123 | # uncommon the following part if you want to observe results for each data 124 | # all_dice_original_str = ', '.join(map(str, all_dice_original)) 125 | # all_assd_original_str = ', '.join(map(str, all_assd_original)) 126 | # all_hf_original_str = ', '.join(map(str, all_hf_original)) 127 | # 128 | # all_dice_ave_str = ', '.join(map(str, all_dice_ave)) 129 | # all_assd_ave_str = ', '.join(map(str, all_assd_ave)) 130 | # all_hf_ave_str = ', '.join(map(str, all_hf_ave)) 131 | # 132 | # all_dice_FP_final_str = ', '.join(map(str,all_dice_FP_final)) 133 | # all_assd_FP_final_str = ', '.join(map(str, all_assd_FP_final)) 134 | # all_hf_FP_final_str = ', '.join(map(str, all_hf_FP_final)) 135 | # 136 | # all_dice_FN_final_str = ', '.join(map(str, all_dice_FN_final)) 137 | # all_assd_FN_final_str = ', '.join(map(str, all_assd_FN_final)) 138 | # all_hf_FN_final_str = ', '.join(map(str, all_hf_FN_final)) 139 | 140 | 141 | with open(output_path+"/output.txt", "w") as file: 142 | # file.write(f"subject dice original: {all_dice_original_str}\n") 143 | # file.write(f"subject dice ave: {all_dice_ave_str}\n") 144 | # file.write(f"subject dice final: {all_dice_final_str}\n") 145 | # 146 | # file.write("\n") 147 | # file.write(f"subject assd original: {all_assd_original_str}\n") 148 | # file.write(f"subject assd final: {all_assd_final_str}\n") 149 | # file.write(f"subject assd ave: {all_assd_ave_str}\n") 150 | # file.write("\n") 151 | # 152 | # file.write(f"subject hf original: {all_hf_original_str}\n") 153 | # file.write(f"subject hf ave: {all_hf_ave_str}\n") 154 | # file.write(f"subject hf final: {all_hf_final_str}\n") 155 | # file.write("\n") 156 | 157 | file.write(f"mean_dice_original: {mean_dice_original}, std: {std_dice_original}\n") 158 | file.write(f"mean_dice_ave: {mean_dice_ave}, std: {std_dice_ave}\n") 
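# Naming note for the metrics written here: "original" is SAM run with the single input box (mask_list[0]),
# "ave" is the thresholded mean over the augmented boxes, "FN_final" is after false-negative correction only,
# and "FNFP_final" is after both FN and FP correction (it is computed from the *_FP_final variables).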
159 | file.write(f"mean_dice_FN_final: {mean_dice_FN_final}, std: {std_dice_FN_final}\n") 160 | file.write(f"mean_dice_FNFP_final: {mean_dice_FP_final}, std: {std_dice_FP_final}\n") 161 | file.write("\n") 162 | file.write(f"mean_assd_original: {mean_assd_original}, std: {std_assd_original}\n") 163 | file.write(f"mean_assd_ave: {mean_assd_ave}, std: {std_assd_ave}\n") 164 | file.write(f"mean_assd_FN_final: {mean_assd_FN_final}, std: {std_assd_FN_final}\n") 165 | file.write(f"mean_assd_FNFP_final: {mean_assd_FP_final}, std: {std_assd_FP_final}\n") 166 | file.write("\n") 167 | file.write(f"mean_hf_original: {mean_hf_original}, std: {std_hf_original}\n") 168 | file.write(f"mean_hf_ave: {mean_hf_ave}, std: {std_hf_ave}\n") 169 | file.write(f"mean_hf_FN_final: {mean_hf_FN_final}, std: {std_hf_FN_final}\n") 170 | file.write(f"mean_hf_FNFP_final: {mean_hf_FP_final}, std: {std_hf_FP_final}\n") 171 | file.write("\n") 172 | 173 | 174 | def FNPC(args, obj_name, images_path, masks_path, output_path, gts_path): 175 | print("\n------------> Segment " + obj_name) 176 | 177 | M = # scale ratio 178 | N = # number of sampled bounding box 179 | ave_thre = 180 | uncertain_thre = # uncertain_thre ratio 181 | 182 | # for kidney the value of FN should be small, which means the range should be small 183 | # 184 | fna = # include more FN which outside the avemask but inside the UM, with value fna< and < fnb 185 | fnb = 186 | fpa = # exclude FP which inside the avemask and UM, with valuefpb 187 | fpb = 188 | 189 | # Path preparation 190 | test_images_path = os.path.join(images_path, obj_name) 191 | 192 | output_path = os.path.join(output_path, obj_name) 193 | os.makedirs(output_path, exist_ok=True) 194 | 195 | print("======> Load SAM") 196 | if args.sam_type == 'vit_h': 197 | # TODO add a weights/ 198 | # sam_type, sam_ckpt = 'vit_h', 'weights/sam_vit_h_4b8939.pth' 199 | sam_type, sam_ckpt = 'vit_h', 'weights/sam_vit_h_4b8939.pth' 200 | sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda() 201 | sam.eval() 202 | elif args.sam_type == 'vit_t': 203 | sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt' 204 | device = "cuda" if torch.cuda.is_available() else "cpu" 205 | sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device) 206 | sam.eval() 207 | 208 | # here add more sam archs 209 | elif args.sam_type == 'vit_l': 210 | sam_type, sam_ckpt = 'vit_l', 'weights/sam_vit_l_0b3195.pth' 211 | device = "cuda" if torch.cuda.is_available() else "cpu" 212 | sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device) 213 | sam.eval() 214 | 215 | predictor = SamPredictor(sam) 216 | # print(os.listdir(test_images_path)) 217 | sorted_files = sorted(os.listdir(test_images_path)) 218 | 219 | print('======> Start Testing') 220 | 221 | subject_dice_original = [] 222 | subject_assd_original = [] 223 | subject_hf_original = [] 224 | 225 | subject_dice_ave = [] 226 | subject_assd_ave = [] 227 | subject_hf_ave = [] 228 | 229 | subject_dice_FN_final = [] 230 | subject_assd_FN_final = [] 231 | subject_hf_FN_final = [] 232 | 233 | subject_dice_FP_final = [] 234 | subject_assd_FP_final = [] 235 | subject_hf_FP_final = [] 236 | 237 | for idx in tqdm(range(len(os.listdir(test_images_path)))): 238 | # Load test image 239 | # test_idx = '%03d' % test_idx # TODO change to %03d when using kidney data 240 | test_idx = sorted_files[idx][:-4]#get the part except the '.png' 241 | # TODO change the datatype 242 | test_image_path = test_images_path + '/' + test_idx + '.png' 243 | test_image = 
cv2.imread(test_image_path) 244 | test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB) 245 | 246 | ref_mask_path = os.path.join(masks_path, obj_name, test_idx + '.png') 247 | ref_mask = cv2.imread(ref_mask_path) 248 | ref_mask = cv2.cvtColor(ref_mask, cv2.COLOR_BGR2RGB) 249 | 250 | gt_mask_path = os.path.join(gts_path, obj_name, test_idx + '.png') 251 | gt_mask = cv2.imread(gt_mask_path) 252 | # print("GT mask shape:", gt_mask.shape) # (128, 128, 3) 253 | # print("GT mask max value:", np.max(gt_mask)) # 255 254 | gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_BGR2GRAY) 255 | # print("GT mask shape:", gt_mask.shape) # (128, 128) 256 | # print("GT mask max value:", np.max(gt_mask)) # 255 257 | _, gt_mask = cv2.threshold(gt_mask, np.max(gt_mask) // 2, 1, cv2.THRESH_BINARY) 258 | # print("GT mask max value:", np.max(gt_mask)) 259 | 260 | disply_original_image = test_image.copy() # 261 | disply_original_image2 = test_image.copy() 262 | 263 | # Image feature encoding 264 | y, x = np.nonzero(ref_mask[:,:,0]) 265 | if y.size == 0: # in case some input doesn't have a bounding box 266 | x_min = 64 - 10 267 | x_max = 64 + 10 268 | y_min = 64 - 10 269 | y_max = 64 + 10 270 | else: 271 | x_min = x.min() 272 | x_max = x.max() 273 | y_min = y.min() 274 | y_max = y.max() 275 | input_box = np.array([x_min, y_min, x_max, y_max]) 276 | 277 | # ###### bbox augmentation Start ####### 278 | new_bbox_list, radius, Center_set = sample_points2(input_box, M, N) 279 | 280 | # Uncomment this part to visualize the bbox augmentation 281 | # save_path_demo = os.path.join(output_path, f'Demo_{test_idx}.png') 282 | # save_path_boxs = os.path.join(output_path, f'BoxDemo_{test_idx}.png') 283 | # visualize_and_save(disply_original_image, new_bbox_list, VT1_set, VT2_set, VT3_set, VT4_set, radius, save_path_demo, save_path_boxs) 284 | 285 | # generate a mask for each bbox in the list 286 | mask_list = [] 287 | for i in range(len(new_bbox_list)): 288 | predictor.set_image(test_image) 289 | # print("the box i is:", new_bbox_list[i][None, :]) 290 | masks, scores, logits, _ = predictor.predict( 291 | point_coords=None, 292 | point_labels=None, 293 | box=new_bbox_list[i][None, :], 294 | multimask_output=False, 295 | return_logits = False 296 | ) 297 | mask_list.append(masks) 298 | 299 | # calculate the aleatoric_uncertainty 300 | input_mask_list = mask_list.copy() 301 | aleatoric_uncertainty_map = calculate_aleatoric_uncertainty(input_mask_list) # 128x128 302 | original_uncertainty_map = aleatoric_uncertainty_map.copy() 303 | 304 | ave_mask = np.mean(mask_list, axis=0) 305 | ave_mask = np.where(ave_mask >= ave_thre * (np.max(ave_mask)), 1, 0) 306 | 307 | # FNPC process: 308 | # For the outputs: 309 | # FN_UH: FN mask selected by the high uncertainty above the threshold 310 | # FN_xUH: FN area on the image selected by the high uncertainty above the threshold 311 | # FN_final_mask: FN-corrected mask. 
Note this mask is not refined by the bbox, so it may contain some outside-box region, but this doesn't influence the result 312 | # FP_UH, FP_xUH: FP mask and image area selected by the high uncertainty above the threshold 313 | # FP_final_mask: final FN+FP corrected mask; this mask will be refined by the BBox before visualization 314 | 315 | # # if you want to use a hardcoded version, please use uc_refine_hardcode 316 | # FN_final_mask, FN_UH, FN_xUH, FN_condition_mask, FP_final_mask, FP_UH, FP_xUH, FP_condition_mask = uc_refine_hardcode(ave_mask.copy(), 317 | # aleatoric_uncertainty_map.copy(), 318 | # test_image.copy(), threshold_uc=uncertain_thre, fn_alpha = fna, 319 | # fn_beta = fnb,fp_alpha = fpa, fp_beta = fpb) 320 | # otherwise, use uc_refine_correct 321 | FN_final_mask, FN_UH, FN_xUH, FN_condition_mask, FP_final_mask, FP_UH, FP_xUH, FP_condition_mask = uc_refine_correct(ave_mask.copy(), 322 | aleatoric_uncertainty_map.copy(), 323 | test_image.copy(), threshold_uc=uncertain_thre, fn_alpha=fna, 324 | fn_beta=fnb, fp_alpha=fpa, fp_beta=fpb) 325 | 326 | # ****** Using the initial bounding box to refine the ave_mask and the final mask ****** # 327 | # The prior knowledge is that the regions outside the initial bounding box are all true negatives 328 | refine_mask = np.where(ref_mask[:, :, 0] > 0, 1, 0) 329 | ave_mask[0] = ave_mask[0]*refine_mask 330 | FP_final_mask[0] = FP_final_mask[0]*refine_mask 331 | 332 | # print(FP_final_mask.shape) 333 | 334 | print("the max number of FP_xUH is:", np.max(FP_xUH)) 335 | print("the min number of FP_xUH is:", np.min(FP_xUH)) 336 | 337 | # the mask_list is in true/false format, turn it into 0, 1 338 | original_mask = mask_list[0][0].copy() 339 | original_mask = original_mask.astype(int) 340 | 341 | # calculate Dice, ASSD, hausdorff 342 | case_dice_original = dice_score(gt_mask, original_mask) 343 | case_assd_original = assd(gt_mask, original_mask) 344 | case_hf_original = hausdorff_distance(gt_mask, original_mask) 345 | 346 | case_dice_ave = dice_score(gt_mask, ave_mask[0]) 347 | case_assd_ave = assd(gt_mask, ave_mask[0]) 348 | case_hf_ave = hausdorff_distance(gt_mask, ave_mask[0]) 349 | 350 | # final_mask = FP_final_mask.copy() 351 | print("FP_final_mask.copy().shape",FP_final_mask.copy().shape) 352 | print("FP_final_mask.copy() max", np.max(FP_final_mask.copy())) 353 | case_dice_FP_final = dice_score(gt_mask, FP_final_mask[0]) 354 | case_assd_FP_final = assd(gt_mask, FP_final_mask[0]) 355 | case_hf_FP_final = hausdorff_distance(gt_mask, FP_final_mask[0]) 356 | 357 | case_dice_FN_final = dice_score(gt_mask, FN_final_mask[0]) 358 | case_assd_FN_final = assd(gt_mask, FN_final_mask[0]) 359 | case_hf_FN_final = hausdorff_distance(gt_mask, FN_final_mask[0]) 360 | 361 | subject_dice_original.append(case_dice_original) 362 | subject_assd_original.append(case_assd_original) 363 | subject_hf_original.append(case_hf_original) 364 | 365 | subject_dice_ave.append(case_dice_ave) 366 | subject_assd_ave.append(case_assd_ave) 367 | subject_hf_ave.append(case_hf_ave) 368 | 369 | subject_dice_FN_final.append(case_dice_FN_final) 370 | subject_assd_FN_final.append(case_assd_FN_final) 371 | subject_hf_FN_final.append(case_hf_FN_final) 372 | 373 | subject_dice_FP_final.append(case_dice_FP_final) 374 | subject_assd_FP_final.append(case_assd_FP_final) 375 | subject_hf_FP_final.append(case_hf_FP_final) 376 | 377 | # visualization of the FNPC process 378 | 379 | fig, axs = plt.subplots(3, 4, figsize=(10 * 3, 10 * 4)) 380 | 381 | axs[0 , 0].imshow(test_image) 382 | show_mask(original_mask, axs[0 , 0]) 
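# Panel layout of the 3x4 comparison figure built below:
# row 0: original-box result / averaged result / uncertainty map / ground truth;
# row 1: FN refinement (FN_UH, FN_xUH, FN_condition_mask, FN-corrected result);
# row 2: FP refinement (FP_UH, FP_xUH, FP_condition_mask, FN+FP-corrected result).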
383 | show_box(new_bbox_list[0], axs[0 , 0], color="yellow") 384 | axs[0 , 0].set_title(f'Ori_D_{case_dice_original:.2f}_A_{case_assd_original:.2f}_H_{case_hf_original:.2f}', 385 | fontsize=18) 386 | 387 | axs[0 ,1].imshow(test_image) 388 | show_mask(ave_mask, axs[0 ,1]) 389 | show_box(new_bbox_list[0], axs[0 ,1], color="yellow") 390 | axs[0 ,1].set_title(f'Ave_D_{case_dice_ave:.2f}_A_{case_assd_ave:.2f}_H_{case_hf_ave:.2f}', fontsize=18) 391 | 392 | # 393 | vis_img = test_image.copy() 394 | # axs[4].imshow(test_image) 395 | # show_uncertainty(aleatoric_uncertainty_map, plt.gca()) 396 | overlay_uncertainty_on_image(vis_img, original_uncertainty_map, axs[0, 2]) 397 | axs[0, 2].set_title(f"Uncertainty Map", fontsize=18) 398 | 399 | axs[0 ,3].imshow(test_image) 400 | show_mask(np.expand_dims(gt_mask, axis=0), axs[0 ,3]) 401 | show_box(new_bbox_list[0], axs[0 ,3], color="yellow") 402 | axs[0 ,3].set_title(f'GT', fontsize=18) 403 | 404 | # FN UH 405 | axs[1, 0].imshow(FN_UH) 406 | axs[1, 0].set_title(f"FN UH", fontsize=18) 407 | 408 | # # FN xUH 409 | axs[1, 1].imshow(FN_xUH, cmap='gray', vmin=0, vmax=255) 410 | axs[1, 1].set_title(f"FN xUH", fontsize=18) 411 | 412 | # # FN condition 413 | axs[1, 2].imshow(FN_condition_mask) 414 | axs[1, 2].set_title(f"FN_condition_mask", fontsize=18) 415 | 416 | axs[1, 3].imshow(test_image) 417 | show_mask(FN_final_mask, axs[1, 3]) 418 | show_box(new_bbox_list[0], axs[1, 3], color="yellow") 419 | axs[1, 3].set_title(f'FN_D_{case_dice_FN_final:.2f}_A_{case_assd_FN_final:.2f}_H_{case_hf_FN_final:.2f}', fontsize=18) 420 | 421 | # FP UH 422 | axs[2, 0].imshow(FP_UH) 423 | axs[2, 0].set_title(f"FP UH", fontsize=18) 424 | 425 | # # FP xUH 426 | axs[2, 1].imshow(FP_xUH, cmap='gray', vmin=0, vmax=255) 427 | axs[2, 1].set_title(f"FP xUH", fontsize=18) 428 | 429 | # # FN condition 430 | axs[2, 2].imshow(FP_condition_mask) 431 | axs[2, 2].set_title(f"FP_condition_mask", fontsize=18) 432 | 433 | axs[2, 3].imshow(test_image) 434 | show_mask(FP_final_mask, axs[2, 3]) 435 | show_box(new_bbox_list[0], axs[2, 3], color="yellow") 436 | axs[2, 3].set_title(f'FP_FN_D_{case_dice_FP_final:.2f}_A_{case_assd_FP_final:.2f}_H_{case_hf_FP_final:.2f}', fontsize=18) 437 | 438 | 439 | # Save the subplot panel as a single image 440 | vis_mask_output_path = os.path.join(output_path, f'Results_Comparison_{test_idx}.png') 441 | plt.savefig(vis_mask_output_path, format='png') 442 | 443 | 444 | ######### here save all the prediction masks for the reconstruction 445 | # save all the predicted mask for the next loop: 446 | ori_save_name = os.path.join(output_path, "ori_fake_mask_" + str(test_idx) + ".png") # since it's from 0 447 | ave_save_name = os.path.join(output_path, "ave_fake_mask_" + str(test_idx) + ".png") 448 | final_save_name = os.path.join(output_path, "final_fake_mask_" + str(test_idx) + ".png") 449 | 450 | slice_y = (mask_list[0][0].copy().astype(np.float32) - np.min(mask_list[0][0].copy().astype(np.float32))) / ( 451 | np.max(mask_list[0][0].copy().astype(np.float32)) - 452 | np.min(mask_list[0][ 453 | 0].copy().astype(np.float32)) + 1e-10) * 255 # normalize to 0-255 454 | slice_y = slice_y.astype(np.uint8) 455 | Image.fromarray(slice_y).save(ori_save_name) # save with PIL instead of plt 456 | 457 | slice_y = (ave_mask[0].copy().astype(np.float32) - np.min(ave_mask[0].copy().astype(np.float32))) / ( 458 | np.max(ave_mask[0].copy().astype(np.float32)) - 459 | np.min(ave_mask[ 460 | 0].copy().astype(np.float32)) + 1e-10) * 255 # normalize to 0-255 461 | slice_y = 
slice_y.astype(np.uint8) 462 | Image.fromarray(slice_y).save(ave_save_name) # save with PIL instead of plt 463 | 464 | slice_y = (FP_final_mask[0].copy().astype(np.float32) - np.min(FP_final_mask[0].copy().astype(np.float32))) / ( 465 | np.max(FP_final_mask[0].copy().astype(np.float32)) - 466 | np.min(FP_final_mask[ 467 | 0].copy().astype(np.float32)) + 1e-10) * 255 # normalize to 0-255 468 | slice_y = slice_y.astype(np.uint8) 469 | Image.fromarray(slice_y).save(final_save_name) # save with PIL instead of plt 470 | 471 | 472 | return subject_dice_original, subject_assd_original, subject_hf_original, subject_dice_ave, subject_assd_ave, \ 473 | subject_hf_ave, subject_dice_FP_final, subject_assd_FP_final,subject_hf_FP_final, subject_dice_FN_final, subject_assd_FN_final,subject_hf_FN_final 474 | 475 | 476 | if __name__ == "__main__": 477 | main() 478 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FNPC 2 | The official implementation for paper: False Negative/Positive Control for SAM on Noisy Medical Images 3 | [Arxiv] https://arxiv.org/abs/2308.10382 4 | -------------------------------------------------------------------------------- /Single_Slice_to_Volume.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | import os 6 | import cv2 7 | from tqdm import tqdm 8 | import argparse 9 | import matplotlib.pyplot as plt 10 | import warnings 11 | 12 | warnings.filterwarnings('ignore') 13 | 14 | from show import * 15 | from segment_anything import sam_model_registry, SamPredictor 16 | 17 | from itertools import combinations 18 | 19 | from scipy.spatial import ConvexHull 20 | import numpy as np 21 | import itertools 22 | import scipy 23 | from scipy.spatial.distance import cdist 24 | from prompt_aug import * 25 | from loss_aug_sam import * 26 | from natsort import natsorted 27 | from PIL import Image 28 | from uncertainty_refine import * 29 | # follows 2D_to_3D_aug5_2 but utilize the bounding box refine 30 | 31 | def get_arguments(): 32 | parser = argparse.ArgumentParser() 33 | 34 | parser.add_argument('--data', type=str, 35 | default="/dataset/Placenta/manualSeg/for_SAM/Reconstruction") 36 | parser.add_argument('--outdir', type=str, 37 | default='Plancenta/Reconstruction/subjectx') 38 | 39 | parser.add_argument('--ckpt', type=str, default='sam_vit_l_0b3195.pth') 40 | parser.add_argument('--ref_idx', type=str, default='000') # TODO change to "000" when using kidney 41 | parser.add_argument('--sam_type', type=str, default='vit_l')# 'vit_h', 'vit_t', 'vit_l' 42 | 43 | args = parser.parse_args() 44 | return args 45 | 46 | 47 | def main(): 48 | args = get_arguments() 49 | print("Args:", args) 50 | 51 | images_path = args.data + '/Images_Subjectx/' 52 | masks_path = args.data + '/Annotations_0002/' 53 | gts_path = args.data + '/GTs/' 54 | output_path = './outputs/' + args.outdir 55 | 56 | if not os.path.exists('./outputs/'): 57 | os.mkdir('./outputs/') 58 | 59 | all_dice_original = [] 60 | all_assd_original = [] 61 | all_hf_original = [] 62 | 63 | all_dice_ave = [] 64 | all_assd_ave = [] 65 | all_hf_ave = [] 66 | 67 | all_dice_final = [] 68 | all_assd_final = [] 69 | all_hf_final = [] 70 | 71 | for obj_name in os.listdir(images_path): 72 | if ".DS" not in obj_name: 73 | # print("object name is:", obj_name) # object name is: Subject2_20230330_135719_1 74 | 
subject_dice_original, subject_assd_original, subject_hf_original, subject_dice_ave, subject_assd_ave, \ 75 | subject_hf_ave, subject_dice_final, subject_assd_final,subject_hf_final = SS2V(args, obj_name, images_path, masks_path, output_path, gts_path) 76 | 77 | def SS2V(args, obj_name, images_path, masks_path, output_path, gts_path): 78 | print("\n------------> Segment " + obj_name) 79 | M = # scale ratio 80 | N = # number of sampled bounding box 81 | ave_thre = 82 | uncertain_thre = # uncertain_thre ratio 83 | 84 | # for kidney the value of FN should be small, which means the range should be small 85 | # 86 | fna = # include more FN which outside the avemask but inside the UM, with value fna< and < fnb 87 | fnb = 88 | fpa = # exclude FP which inside the avemask and UM, with valuefpb 89 | fpb = 90 | 91 | # Path preparation 92 | test_images_path = os.path.join(images_path, obj_name) 93 | 94 | output_path = os.path.join(output_path, obj_name) 95 | os.makedirs(output_path, exist_ok=True) 96 | 97 | # Load images and masks 98 | # ref_image = cv2.imread(ref_image_path) 99 | # ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB) 100 | # 101 | # ref_mask = cv2.imread(ref_mask_path) 102 | # ref_mask = cv2.cvtColor(ref_mask, cv2.COLOR_BGR2RGB) 103 | 104 | print("======> Load SAM") 105 | if args.sam_type == 'vit_h': 106 | # TODO add a weights/ 107 | # sam_type, sam_ckpt = 'vit_h', 'weights/sam_vit_h_4b8939.pth' 108 | sam_type, sam_ckpt = 'vit_h', 'weights/sam_vit_h_4b8939.pth' 109 | sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).cuda() 110 | sam.eval() 111 | elif args.sam_type == 'vit_t': 112 | sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt' 113 | device = "cuda" if torch.cuda.is_available() else "cpu" 114 | sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device) 115 | sam.eval() 116 | 117 | # here add more sam archs 118 | elif args.sam_type == 'vit_l': 119 | sam_type, sam_ckpt = 'vit_l', 'weights/sam_vit_l_0b3195.pth' 120 | device = "cuda" if torch.cuda.is_available() else "cpu" 121 | sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device) 122 | sam.eval() 123 | 124 | predictor = SamPredictor(sam) 125 | # print(os.listdir(test_images_path)) 126 | # print(test_images_path) 127 | sorted_files = natsorted(os.listdir(test_images_path)) 128 | # print(sorted_files) 129 | 130 | print('======> Start Testing') 131 | 132 | subject_dice_original = [] 133 | subject_assd_original = [] 134 | subject_hf_original = [] 135 | 136 | subject_dice_ave = [] 137 | subject_assd_ave = [] 138 | subject_hf_ave = [] 139 | 140 | subject_dice_final = [] 141 | subject_assd_final = [] 142 | subject_hf_final = [] 143 | 144 | # the start slice number and the end slice number 145 | start_num = 59 146 | end_num = 94 147 | minx = 2 148 | miny = 2 149 | maxx = 2 150 | maxy = 2 151 | for idx in tqdm(range(start_num, end_num+1)): 152 | print(idx) 153 | # Load test image 154 | # TODO change to %03d when using kidney data 155 | test_idx = sorted_files[idx][:-4]#get the part except the '.png' 156 | print("*********************",test_idx) 157 | # TODO change the datatype 158 | # test_image_path = test_images_path + '/' + test_idx + '.jpg' 159 | test_image_path = test_images_path + '/' + test_idx + '.png' 160 | # print("test image path:",test_image_path) 161 | test_image = cv2.imread(test_image_path) 162 | print("test_image shape:", test_image.shape) 163 | test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB) 164 | 165 | # if the input image is the middle slice, extract the reference 
mask 166 | if idx == start_num: 167 | ref_mask_path = os.path.join(masks_path, obj_name, "slice_" +str(start_num)+".png") 168 | print(ref_mask_path) 169 | ref_mask = cv2.imread(ref_mask_path) 170 | ref_mask = cv2.cvtColor(ref_mask, cv2.COLOR_BGR2RGB) 171 | else: # if other slices, get the bounding box from the previous slice 172 | ref_mask_path = os.path.join(output_path, "final_fake_box_" + str(idx - 1) + ".png")# get the fake_box generated from the previous slice 173 | ref_mask = cv2.imread(ref_mask_path) 174 | ref_mask = cv2.cvtColor(ref_mask, cv2.COLOR_BGR2RGB) 175 | 176 | 177 | gt_mask_path = os.path.join(gts_path, obj_name, test_idx + '.png') 178 | gt_mask = cv2.imread(gt_mask_path) 179 | gt_mask = cv2.cvtColor(gt_mask, cv2.COLOR_BGR2GRAY) 180 | _, gt_mask = cv2.threshold(gt_mask, np.max(gt_mask) // 2, 1, cv2.THRESH_BINARY) 181 | 182 | # Image feature encoding 183 | y, x = np.nonzero(ref_mask[:,:,0]) 184 | if y.size == 0: 185 | x_min = 0 186 | x_max = 128 187 | y_min = 0 188 | y_max = 128 189 | else: 190 | x_min = x.min() 191 | x_max = x.max() 192 | y_min = y.min() 193 | y_max = y.max() 194 | 195 | input_box = np.array([x_min, y_min, x_max, y_max]) 196 | 197 | # ###### bbox augmentation Start ####### 198 | new_bbox_list, radius, Center_set = sample_points2(input_box, M, N) 199 | 200 | save_path_demo = os.path.join(output_path, f'Demo_{test_idx}.png') 201 | save_path_boxs = os.path.join(output_path, f'BoxDemo_{test_idx}.png') 202 | # visualize_and_save(disply_original_image, new_bbox_list, VT1_set, VT2_set, VT3_set, VT4_set, radius, save_path_demo, save_path_boxs) 203 | 204 | # generate a mask for each bbox in the list 205 | mask_list = [] 206 | for i in range(len(new_bbox_list)): 207 | predictor.set_image(test_image) 208 | # print("the box i is:", new_bbox_list[i][None, :]) 209 | masks, scores, logits, _ = predictor.predict( 210 | point_coords=None, 211 | point_labels=None, 212 | box=new_bbox_list[i][None, :], 213 | multimask_output=False 214 | ) 215 | # print('masks shape:', masks.shape) # (1, 128, 128) 216 | mask_list.append(masks) 217 | best_idx = 0 218 | 219 | # calculate the aleatoric_uncertainty 220 | # (pixel-wise p*(1-p) across the augmented-box predictions) 221 | input_mask_list = mask_list.copy() 222 | aleatoric_uncertainty_map = calculate_aleatoric_uncertainty(input_mask_list) # 128x128 223 | 224 | ave_mask = np.mean(mask_list, axis=0) 225 | ave_mask = np.where(ave_mask >= ave_thre * (np.max(ave_mask)), 1, 0) 226 | 227 | # refine the ave_mask using the uncertainty map 228 | # final_mask = simple_threshold(aleatoric_uncertainty_map, uncertainty_threshold = 0.5) 229 | FN_final_mask, FN_UH, FN_xUH, FN_condition_mask, FP_final_mask, FP_UH, FP_xUH, FP_condition_mask = uc_refine_hardcode( 230 | ave_mask.copy(), 231 | aleatoric_uncertainty_map.copy(), 232 | test_image.copy(), threshold_uc=uncertain_thre, fn_alpha=fna, 233 | fn_beta=fnb, fp_alpha=fpa, fp_beta=fpb) 234 | 235 | # using the initial bounding box to refine the ave_mask and the final mask 236 | refine_mask = np.where(ref_mask[:, :, 0] > 0, 1, 0) 237 | ave_mask[0] = ave_mask[0] * refine_mask 238 | FP_final_mask[0] = FP_final_mask[0] * refine_mask 239 | 240 | ### **** generate new mask for the next image **** #### 241 | y_new, x_new = np.nonzero(FP_final_mask[0].copy()) 242 | if y_new.size == 0: 243 | x_new_min = 0 244 | x_new_max = 128 245 | y_new_min = 0 246 | y_new_max = 128 247 | else: 248 | x_new_min = x_new.min() 249 | x_new_max = x_new.max() 250 | y_new_min = y_new.min() 251 | y_new_max = y_new.max() 252 | 253 | 254 | # compare the new mask with old mask to correct the position 255 | if 
np.abs(x_new_min - x_min) > minx: 256 | x_new_min = x_min.copy() 257 | if np.abs(y_new_min - y_min) > miny: 258 | y_new_min = y_min.copy() 259 | if np.abs(x_new_max - x_max) > maxx: 260 | x_new_max = x_max.copy() 261 | if np.abs(y_new_max - y_max) > maxy: 262 | y_new_max = y_max.copy() 263 | 264 | # generate the fake mask for the next input 265 | new_box = np.array([x_new_min, y_new_min, x_new_max, y_new_max]) 266 | new_binary_mask = create_mask_from_bbox(new_box, size=128) 267 | 268 | new_binary_mask_name = os.path.join(output_path,"final_fake_box_" + str(idx) + ".png") 269 | Image.fromarray(new_binary_mask).save(new_binary_mask_name) # save with PIL instead of plt 270 | 271 | # the mask_list is in true or false farmat, turn it into 0, 1 272 | original_mask = mask_list[0][0].copy() 273 | original_mask = original_mask.astype(int) 274 | 275 | 276 | 277 | # compare results we get 278 | fig, axs = plt.subplots(1, 5, figsize=(10 * 5, 10)) 279 | 280 | axs[0].imshow(test_image) 281 | show_mask(original_mask, axs[0]) 282 | show_box(new_bbox_list[0], axs[0], color="yellow") 283 | # axs[0].set_title(f'Ori_D_{case_dice_original:.2f}_A_{case_assd_original:.2f}_H_{case_hf_original:.2f}', fontsize=18) 284 | axs[0].set_title(f'Ori', 285 | fontsize=18) 286 | 287 | axs[1].imshow(test_image) 288 | show_mask(ave_mask, axs[1]) 289 | show_box(new_bbox_list[0], axs[1], color="yellow") 290 | axs[1].set_title(f'Ave', fontsize=18) 291 | 292 | axs[2].imshow(test_image) 293 | show_mask(FP_final_mask, axs[2]) 294 | show_box(new_bbox_list[0], axs[2], color="yellow") 295 | axs[2].set_title(f'Fin', fontsize=18) 296 | 297 | axs[3].imshow(test_image) 298 | show_mask(np.expand_dims(gt_mask, axis=0), axs[3]) 299 | show_box(new_bbox_list[0], axs[3], color="yellow") 300 | axs[3].set_title(f'GT', fontsize=18) 301 | 302 | # visualize uncertainty map 303 | vis_img = test_image.copy() 304 | overlay_uncertainty_on_image(vis_img, aleatoric_uncertainty_map.copy(), axs[4]) 305 | axs[4].set_title(f"Uncertainty Map", fontsize=18) 306 | 307 | # Save the subplot panel as a single image 308 | vis_mask_output_path = os.path.join(output_path, f'Results_Comparison_{test_idx}.png') 309 | plt.savefig(vis_mask_output_path, format='png') 310 | 311 | # save all the predicted mask for the next loop: 312 | ori_save_name = os.path.join(output_path, "ori_fake_mask_" + str(idx) + ".png") # since it's from 0 313 | ave_save_name = os.path.join(output_path, "ave_fake_mask_" + str(idx) + ".png") 314 | final_save_name = os.path.join(output_path,"final_fake_mask_" + str(idx) + ".png") 315 | 316 | slice_y = (mask_list[0][0].copy().astype(np.float32) - np.min(mask_list[0][0].copy().astype(np.float32))) / (np.max(mask_list[0][0].copy().astype(np.float32)) - 317 | np.min(mask_list[0][ 318 | 0].copy().astype(np.float32)) + 1e-10) * 255 # normalize to 0-255 319 | slice_y = slice_y.astype(np.uint8) 320 | Image.fromarray(slice_y).save(ori_save_name) # save with PIL instead of plt 321 | 322 | slice_y = (ave_mask[0].copy().astype(np.float32) - np.min(ave_mask[0].copy().astype(np.float32))) / (np.max(ave_mask[0].copy().astype(np.float32)) - 323 | np.min(ave_mask[ 324 | 0].copy().astype(np.float32)) + 1e-10) * 255 # normalize to 0-255 325 | slice_y = slice_y.astype(np.uint8) 326 | Image.fromarray(slice_y).save(ave_save_name) # save with PIL instead of plt 327 | 328 | slice_y = (FP_final_mask[0].copy().astype(np.float32) - np.min(FP_final_mask[0].copy().astype(np.float32))) / (np.max(FP_final_mask[0].copy().astype(np.float32)) - 329 | np.min(FP_final_mask[ 330 | 
0].copy().astype(np.float32)) + 1e-10) * 255 # normalize to 0-255 331 | slice_y = slice_y.astype(np.uint8) 332 | Image.fromarray(slice_y).save(final_save_name) # save with PIL instead of plt 333 | 334 | return subject_dice_original, subject_assd_original, subject_hf_original, subject_dice_ave, subject_assd_ave, \ 335 | subject_hf_ave, subject_dice_final, subject_assd_final,subject_hf_final 336 | 337 | if __name__ == "__main__": 338 | main() 339 | -------------------------------------------------------------------------------- /loss_aug_sam.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.distance import directed_hausdorff 3 | import numpy as np 4 | from scipy.spatial import distance 5 | 6 | # def assd(mask1, mask2): 7 | # # Get the set of non-zero points, representing the surface 8 | # surface_mask1 = np.argwhere(mask1) 9 | # surface_mask2 = np.argwhere(mask2) 10 | # 11 | # # Compute distances from surface_mask1 to surface_mask2 and vice versa 12 | # forward_distance = np.mean([np.min(distance.cdist(surface_mask1[i:i+1], surface_mask2, 'euclidean')) for i in range(surface_mask1.shape[0])]) 13 | # backward_distance = np.mean([np.min(distance.cdist(surface_mask2[i:i+1], surface_mask1, 'euclidean')) for i in range(surface_mask2.shape[0])]) 14 | # 15 | # # Compute the symmetric distance 16 | # symmetric_distance = np.mean([forward_distance, backward_distance]) 17 | # 18 | # return symmetric_distance 19 | 20 | from scipy.spatial import distance 21 | import numpy as np 22 | 23 | def assd(mask1, mask2): 24 | # Check if mask1 or mask2 is all zero 25 | if np.all(mask1 == 0) or np.all(mask2 == 0): 26 | return 1 27 | 28 | # Get the set of non-zero points, representing the surface 29 | surface_mask1 = np.argwhere(mask1) 30 | surface_mask2 = np.argwhere(mask2) 31 | 32 | # Compute distances from surface_mask1 to surface_mask2 and vice versa 33 | forward_distance = np.mean([np.min(distance.cdist(surface_mask1[i:i+1], surface_mask2, 'euclidean')) for i in range(surface_mask1.shape[0])]) 34 | backward_distance = np.mean([np.min(distance.cdist(surface_mask2[i:i+1], surface_mask1, 'euclidean')) for i in range(surface_mask2.shape[0])]) 35 | 36 | # Compute the symmetric distance 37 | symmetric_distance = np.mean([forward_distance, backward_distance]) 38 | 39 | return symmetric_distance 40 | 41 | 42 | 43 | 44 | def dice_score(mask1, mask2): 45 | # Flatten the masks 46 | mask1 = mask1.flatten() 47 | mask2 = mask2.flatten() 48 | 49 | # Compute Dice coefficient 50 | intersect = np.sum(mask1 * mask2) 51 | return (2. * intersect) / (np.sum(mask1) + np.sum(mask2)) 52 | 53 | def hausdorff_distance(mask1, mask2): 54 | # Compute Hausdorff distance 55 | return max(directed_hausdorff(mask1, mask2)[0], directed_hausdorff(mask2, mask1)[0]) 56 | 57 | 58 | def calculate_fp_fn(mask_gt, mask_pred): 59 | """ 60 | Calculate FP and FN rates for binary masks. 61 | 62 | Parameters: 63 | mask_gt (numpy.ndarray): Ground truth binary mask, shape (D, D). 64 | mask_pred (numpy.ndarray): Predicted binary mask, shape (D, D). 65 | 66 | Returns: 67 | float, float: FP rate, FN rate 68 | """ 69 | assert mask_gt.shape == mask_pred.shape, "Shape mismatch between ground truth and predicted masks." 
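# The encoding below maps each pixel pair (gt, pred) to a single label 2*gt + pred:
# 0 -> TN (gt=0, pred=0), 1 -> FP (gt=0, pred=1), 2 -> FN (gt=1, pred=0), 3 -> TP (gt=1, pred=1),
# so np.bincount(labels, minlength=4) directly yields the TN/FP/FN/TP counts unpacked further down.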
70 | 71 | # Flatten the masks and combine 72 | labels = 2 * mask_gt.reshape(-1) + mask_pred.reshape(-1) 73 | 74 | # Count occurrences of 00, 01, 10, 11 75 | counts = np.bincount(labels, minlength=4) 76 | 77 | TN, FP, FN, TP = counts 78 | 79 | # Calculate FP and FN rates 80 | fp_rate = FP / (FP + TN) if FP + TN != 0 else 0 81 | fn_rate = FN / (FN + TP) if FN + TP != 0 else 0 82 | 83 | return fp_rate, fn_rate 84 | -------------------------------------------------------------------------------- /prompt_aug.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn import functional as F 4 | import os 5 | import cv2 6 | from tqdm import tqdm 7 | import argparse 8 | import matplotlib.pyplot as plt 9 | import warnings 10 | warnings.filterwarnings('ignore') 11 | from show import * 12 | from per_segment_anything import sam_model_registry, SamPredictor 13 | from itertools import combinations 14 | from scipy.spatial import ConvexHull 15 | import numpy as np 16 | import itertools 17 | import scipy 18 | from scipy.spatial.distance import cdist 19 | import numpy as np 20 | import matplotlib.pyplot as plt 21 | # this folder have all the augmentation functions need for the dataaug 22 | 23 | import numpy as np 24 | import random 25 | 26 | import matplotlib.pyplot as plt 27 | import cv2 28 | import math 29 | import cv2 30 | import math 31 | import matplotlib as mpl 32 | # import numpy as np 33 | from bridson import poisson_disc_samples 34 | 35 | from bridson import poisson_disc_samples 36 | 37 | def create_mask_from_bbox(input_box, size=128): 38 | # Initialize an empty mask of size 128x128 39 | mask = np.zeros((size, size), dtype=np.uint8) 40 | 41 | # Extract the bounding box coordinates 42 | x_min, y_min, x_max, y_max = input_box 43 | 44 | # Ensure the bounding box coordinates are within the range 45 | x_min = max(0, min(size-1, x_min)) 46 | y_min = max(0, min(size-1, y_min)) 47 | x_max = max(0, min(size, x_max)) # x_max is exclusive 48 | y_max = max(0, min(size, y_max)) # y_max is exclusive 49 | 50 | # Set the region inside the bounding box to 1 51 | mask[y_min:y_max+1, x_min:x_max+1] = 255 52 | 53 | return mask 54 | 55 | def visualize_and_save2(image, new_bbox_list, Center_Sets, radius, save_path): 56 | # Convert image to RGB if it's grayscale 57 | if len(image.shape) == 2: 58 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) 59 | radius = math.ceil(radius+1) 60 | 61 | # Mark original bounding box and its vertices in yellow and draw circles 62 | x_min, y_min, x_max, y_max = map(int, new_bbox_list[0]) 63 | cv2.rectangle(image, (x_min, y_min), (x_max, y_max), (0, 255, 255), 1) # Yellow bounding box, the thickness == 1 64 | 65 | # Mark first point in Center_Sets with a yellow circle 66 | # x, y = map(int, Center_Sets[0]) #since center_sets doesn't contains the original points so... 
67 | x, y = map(int, np.array([(x_max + x_min) / 2, (y_max + y_min) / 2])) 68 | cv2.circle(image, (x, y), radius, (0, 255, 255), 1) # Yellow circle, the -1 means the circle will be fullfilled 69 | 70 | # Mark rest of the points from Center_Sets in red 71 | for pt in Center_Sets[1:]: 72 | pt = tuple(map(int, pt)) 73 | # cv2.circle(image, pt, radius, (0, 0, 255), -1) # Red circle 74 | cv2.drawMarker(image, pt, (0, 0, 255), markerSize=1, thickness=1) # red dot 75 | 76 | cv2.drawMarker(image, tuple(map(int, np.array([(x_max + x_min) / 2, (y_max + y_min) / 2]))), (0, 255, 255), markerSize=2, 77 | thickness=1) # yellow dot for the center 78 | 79 | # Save image 80 | cv2.imwrite(save_path, image) 81 | 82 | def visualize_and_save(image, new_bbox_list, VT1_set, VT2_set, VT3_set, VT4_set, radius, save_path1, save_path2): 83 | # Convert image to RGB if it's grayscale 84 | image2 = image.copy() 85 | if len(image.shape) == 2: 86 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) 87 | radius = math.ceil(radius+1) 88 | ## The circle drawing in OpenCV might be slightly different from the method used to generate the points. 89 | # It's possible that the points are indeed within the circle according to the sampling method, 90 | # but the drawn circle appears slightly smaller due to how OpenCV interprets the radius and center point. 91 | 92 | # Mark original bounding box and its vertices in yellow and draw circles 93 | x_min, y_min, x_max, y_max = map(int, new_bbox_list[0]) 94 | cv2.rectangle(image, (x_min, y_min), (x_max, y_max), (0, 255, 255), 1) # Yellow bounding box, the thickness == 1 95 | cv2.circle(image, (x_min, y_min), radius, (0, 255, 255), -1) # Top-left, the -1 means the circle will be fullfilled 96 | cv2.circle(image, (x_max, y_min), radius, (0, 255, 255), -1) # Top-right 97 | cv2.circle(image, (x_max, y_max), radius, (0, 255, 255), -1) # Bottom-right 98 | cv2.circle(image, (x_min, y_max), radius, (0, 255, 255), -1) # Bottom-left 99 | 100 | # Mark all sample points from VT sets in blue 101 | for VT_set in [VT1_set, VT2_set, VT3_set, VT4_set]: 102 | for pt in VT_set: 103 | pt = tuple(map(int, pt)) 104 | # cv2.circle(image, pt, 1, (255, 0, 0), -1) # Blue dot 105 | # cv2.drawMarker(image, pt, (255, 0, 0), markerType=cv2.MARKER_CROSS, markerSize=1, thickness=1)# Blue dot 106 | cv2.drawMarker(image, pt, (255, 0, 0), markerSize=1, thickness=1) # Blue dot 107 | 108 | # Mark vertices of new_bbox_list[1:] in red 109 | for bbox in new_bbox_list[1:]: 110 | x_min, y_min, x_max, y_max = map(int, bbox) 111 | # cv2.circle(image, (x_min, y_min), 1, (0, 0, 255), -1) # Top-left 112 | # cv2.circle(image, (x_max, y_min), 1, (0, 0, 255), -1) # Top-right 113 | # cv2.circle(image, (x_max, y_max), 1, (0, 0, 255), -1) # Bottom-right 114 | # cv2.circle(image, (x_min, y_max), 1, (0, 0, 255), -1) # Bottom-left 115 | cv2.drawMarker(image, (x_min, y_min), (0, 0, 255), markerType=cv2.MARKER_CROSS, markerSize=1, thickness=1)# Top-left 116 | cv2.drawMarker(image, (x_max, y_min), (0, 0, 255), markerType=cv2.MARKER_CROSS, markerSize=1, 117 | thickness=1) # Top-right 118 | cv2.drawMarker(image, (x_max, y_max), (0, 0, 255), markerType=cv2.MARKER_CROSS, markerSize=1, 119 | thickness=1) # Bottom-right 120 | cv2.drawMarker(image, (x_min, y_max), (0, 0, 255), markerType=cv2.MARKER_CROSS, markerSize=1, 121 | thickness=1) # Bottom-left 122 | # Save image 123 | cv2.imwrite(save_path1, image) 124 | 125 | def sample_points3(input_box, M, N): 126 | # follows sample_points2, but using blue noise 127 | # If you still feel that the distribution 
of points is uneven, it may be because this sampling method is highly random, which may result in more points in some areas and fewer points in other areas. If you need even distribution, 128 | # You may want to use other sampling strategies, such as selecting points evenly on a grid within a circle, or using a technique called "blue noise". Blue noise is a special kind of noise that is uniform globally but maintains a certain degree of randomness locally. 129 | # Blue noise sampling is a sampling method that produces uniformly distributed points throughout the entire area while trying to maintain the minimum distance between adjacent points. 130 | # Unpack the bounding box 131 | x_min, y_min, x_max, y_max = input_box 132 | 133 | # Calculate the center and side lengths 134 | C = np.array([(x_max + x_min) / 2, (y_max + y_min) / 2]) 135 | ES = min(x_max - x_min, y_max - y_min) 136 | 137 | # Calculate the radius 138 | radius = ES / M 139 | 140 | # Function to sample points from a circle 141 | def sample_points_from_circle(center, radius, num_points): 142 | points = poisson_disc_samples(width=2*radius, height=2*radius, r=radius/num_points) 143 | # Shift points to the right location 144 | points = [[point[0] + center[0] - radius, point[1] + center[1] - radius] for point in points] 145 | return np.array(points) 146 | 147 | # Sample N points from the circle 148 | Center_Sets = sample_points_from_circle(C, radius, N) 149 | 150 | # Generate new bounding boxes 151 | new_bbox_list = [input_box] 152 | for i in range(N): 153 | # Get new center point 154 | new_C = Center_Sets[i] 155 | 156 | # Calculate the new bounding box 157 | x_min_new = new_C[0] - (x_max - x_min) / 2 158 | y_min_new = new_C[1] - (y_max - y_min) / 2 159 | x_max_new = new_C[0] + (x_max - x_min) / 2 160 | y_max_new = new_C[1] + (y_max - y_min) / 2 161 | 162 | new_bbox = np.array([x_min_new, y_min_new, x_max_new, y_max_new]) 163 | new_bbox_list.append(new_bbox) 164 | 165 | return new_bbox_list, radius, Center_Sets 166 | 167 | 168 | def sample_points2(input_box, M, N): 169 | # Unpack the bounding box 170 | x_min, y_min, x_max, y_max = input_box 171 | 172 | # Calculate the center and side lengths 173 | C = np.array([(x_max + x_min) / 2, (y_max + y_min) / 2]) 174 | ES = min(x_max - x_min, y_max - y_min) 175 | 176 | # Calculate the radius 177 | radius = ES / M 178 | 179 | # Function to sample points from a circle 180 | def sample_points_from_circle(center, radius, num_points): 181 | points = [] 182 | for _ in range(num_points): 183 | # Sample a point in polar coordinates and convert it to cartesian 184 | theta = 2 * np.pi * random.random() 185 | r = radius * np.sqrt(random.random()) 186 | x = center[0] + r * np.cos(theta) 187 | y = center[1] + r * np.sin(theta) 188 | points.append([x, y]) 189 | return np.array(points) 190 | 191 | # Sample N points from the circle 192 | Center_Sets = sample_points_from_circle(C, radius, N) 193 | 194 | # Generate new bounding boxes 195 | new_bbox_list = [input_box] 196 | for i in range(N): 197 | # Get new center point 198 | new_C = Center_Sets[i] 199 | 200 | # Calculate the new bounding box 201 | x_min_new = new_C[0] - (x_max - x_min) / 2 202 | y_min_new = new_C[1] - (y_max - y_min) / 2 203 | x_max_new = new_C[0] + (x_max - x_min) / 2 204 | y_max_new = new_C[1] + (y_max - y_min) / 2 205 | 206 | new_bbox = np.array([x_min_new, y_min_new, x_max_new, y_max_new]) 207 | new_bbox_list.append(new_bbox) 208 | 209 | return new_bbox_list, radius, Center_Sets 210 | 211 | 212 | def sample_points(input_box, M, N, S, 
sample_type = "MC"): 213 | # 2) scale ratio M, 3) number of sampled bounding box: N, 4) number of sampled_points S 214 | # Unpack the bounding box 215 | x_min, y_min, x_max, y_max = input_box 216 | 217 | # Calculate the vertices and side lengths 218 | VT1 = np.array([x_min, y_min]) 219 | VT2 = np.array([x_max, y_min]) 220 | VT3 = np.array([x_max, y_max]) 221 | VT4 = np.array([x_min, y_max]) 222 | ES = min(x_max - x_min, y_max - y_min) 223 | 224 | # Sample S points for each vertex 225 | def sample_points_from_circle(center, radius, num_points): 226 | points = [] 227 | for _ in range(num_points): 228 | # Sample a point in polar coordinates and convert it to cartesian 229 | theta = 2 * np.pi * random.random() 230 | r = radius * np.sqrt(random.random()) 231 | x = center[0] + r * np.cos(theta) 232 | y = center[1] + r * np.sin(theta) 233 | points.append([x, y]) 234 | return np.array(points) 235 | 236 | def sample_points_from_gaussian(center, radius, num_points): 237 | # The mean would be the center of the circle, and the covariance matrix would be a 2D identity matrix scaled 238 | # by the desired variance (which could be set to (radius**2)/2 to ensure that approximately 95% of points fall 239 | # within the circle). 240 | 241 | # a Gaussian distribution is not bounded, so it's still possible to get points outside the circle. 242 | 243 | 244 | # Set the mean and covariance 245 | mean = center 246 | covariance = np.eye(2) * (radius ** 2) / 2 247 | 248 | # Sample points from the Gaussian distribution 249 | points = np.random.multivariate_normal(mean, covariance, num_points) 250 | 251 | return points 252 | 253 | radius = ES / M 254 | 255 | if sample_type == 'MC': 256 | VT1_set = sample_points_from_circle(VT1, radius, S) 257 | VT2_set = sample_points_from_circle(VT2, radius, S) 258 | VT3_set = sample_points_from_circle(VT3, radius, S) 259 | VT4_set = sample_points_from_circle(VT4, radius, S) 260 | else: 261 | VT1_set = sample_points_from_gaussian(VT1, radius, S) 262 | VT2_set = sample_points_from_gaussian(VT2, radius, S) 263 | VT3_set = sample_points_from_gaussian(VT3, radius, S) 264 | VT4_set = sample_points_from_gaussian(VT4, radius, S) 265 | 266 | # Generate new bounding boxes 267 | new_bbox_list = [input_box] 268 | for _ in range(N): 269 | # Sample a point from each set 270 | new_VT1 = VT1_set[random.randint(0, S-1)] 271 | new_VT2 = VT2_set[random.randint(0, S-1)] 272 | new_VT3 = VT3_set[random.randint(0, S-1)] 273 | new_VT4 = VT4_set[random.randint(0, S-1)] 274 | 275 | # Calculate the new bounding box 276 | x_min_new = min(new_VT1[0], new_VT4[0]) 277 | y_min_new = min(new_VT1[1], new_VT2[1]) 278 | x_max_new = max(new_VT2[0], new_VT3[0]) 279 | y_max_new = max(new_VT3[1], new_VT4[1]) 280 | 281 | new_bbox = np.array([x_min_new, y_min_new, x_max_new, y_max_new]) 282 | new_bbox_list.append(new_bbox) 283 | 284 | return new_bbox_list, radius, VT1_set, VT2_set, VT3_set, VT4_set 285 | 286 | 287 | def calculate_aleatoric_uncertainty(mask_list): 288 | # stack the mask samples in the mask_list to calculate the frequency of each pixel 289 | all_masks = np.vstack(mask_list) 290 | 291 | # calculate the frequency of 1 for each pixel location 292 | frequency = np.mean(all_masks, axis=0) 293 | 294 | # calculate the aleatoric uncertainty for each frequency location 295 | aleatoric_uncertainty = frequency * (1 - frequency) 296 | 297 | # change to an entrophy type: 298 | # aleatoric_uncertainty = -0.5*(frequency * np.log(frequency + 10e-7)+(1 - frequency) * np.log((1 - frequency) + 10e-7)) 299 | 300 | return 
aleatoric_uncertainty 301 | 302 | 303 | def visualize_uncertainty2(uncertainty_map): 304 | plt.imshow(uncertainty_map, cmap='hot', interpolation='nearest') 305 | plt.colorbar() 306 | plt.title('Aleatoric Uncertainty Map') 307 | plt.show() 308 | 309 | def show_uncertainty(uncertainty_map, ax): 310 | ax.imshow(uncertainty_map,cmap='hot', interpolation='nearest') 311 | 312 | def overlay_uncertainty_on_image(image, uncertainty_map, ax): 313 | """ 314 | Superimpose the uncertainty map on the image and display it. 315 | 316 | parameter: 317 | image (numpy.ndarray): Input image, a single-channel image of shape (1, height, width). 318 | uncertainty_map (numpy.ndarray): uncertainty map, shape (height, width). 319 | 320 | return: 321 | There is no return value, and the superimposed image is displayed directly. 322 | 323 | """ 324 | image_rgb = image 325 | 326 | # Create a colormap from white to red 327 | cmap = plt.get_cmap('hot') 328 | # cmap.set_under('red') # sets color for smallest values 329 | # cmap.set_over('white') # sets color for largest values 330 | 331 | # Map the colors in the colormap to the range 0-0.25 332 | norm = mpl.colors.Normalize(vmin=0, vmax=0.25) 333 | mapper = mpl.cm.ScalarMappable(norm=norm, cmap=cmap) 334 | uncertainty_map_rgb = cmap(norm(uncertainty_map)) 335 | 336 | uncertainty_map_rgb = (uncertainty_map_rgb[..., :3] * 255).astype(np.uint8) # ignore alpha channel 337 | 338 | # overlay the uncertainty map on the img 339 | overlay_image = np.copy(image_rgb) 340 | overlay_image[..., 0] = (overlay_image[..., 0] * 0.7 + uncertainty_map_rgb[..., 0] * 0.3) 341 | overlay_image[..., 1] = (overlay_image[..., 1] * 0.7 + uncertainty_map_rgb[..., 1] * 0.3) 342 | overlay_image[..., 2] = (overlay_image[..., 2] * 0.7 + uncertainty_map_rgb[..., 2] * 0.3) 343 | 344 | # Show superimposed images 345 | im = ax.imshow(overlay_image) 346 | 347 | # colorbar 348 | plt.colorbar(mapper, ax=ax, extend='neither') # set extend from 'both' to 'neither' to remove the arrow in the colorbar 349 | 350 | 351 | 352 | def simple_threshold(aleatoric_uncertainty_map , uncertainty_threshold = 0.5): 353 | # thresholding to get final mask 354 | final_mask = np.where(aleatoric_uncertainty_map <= uncertainty_threshold, 1, 0) 355 | 356 | return final_mask 357 | 358 | 359 | def mask_adjustment(binary_mask, uncertainty_map, threshold_ratio = 0.2): 360 | # Get the threshold 361 | threshold = threshold_ratio * np.max(uncertainty_map) 362 | 363 | # Identify the pixels in the uncertainty map that are above the threshold 364 | above_threshold = uncertainty_map > threshold 365 | 366 | # Apply the condition to the binary mask 367 | binary_mask[0, above_threshold] = 0 368 | 369 | return binary_mask 370 | 371 | def mask_adjustment2(binary_mask, uncertainty_map, img, threshold_ratio = 0.2, thre_fp = 0.5, thre_fn=0.5 ): 372 | # Get the threshold 373 | threshold = threshold_ratio * np.max(uncertainty_map) 374 | 375 | # Identify the pixels in the uncertainty map that are above the threshold 376 | above_threshold = uncertainty_map > threshold 377 | 378 | # Average over the channels of img 379 | img_avg = np.mean(img, axis=2) 380 | 381 | # Calculate the average value of img under the binary_mask 382 | mean_value = np.mean(img_avg[binary_mask[0] > 0]) 383 | 384 | # For the binary mask, 385 | # if the pixels are in the above_threshold and the corresponding positional pixel in img larger than mean_value, set it to be 0, 386 | # binary_mask[0, (above_threshold) & (img_avg > mean_value*thre_fp)] = 0 387 | binary_mask[0, above_threshold] 
= 0 388 | 389 | # if the pixels are in the above_threshold and the corresponding positional pixel in img smaller or equal to mean_value, set to be 1 390 | binary_mask[0, (above_threshold) & (img_avg <= mean_value* thre_fn)] = 1 391 | 392 | return binary_mask 393 | 394 | 395 | def mask_adjustment3(binary_mask, uncertainty_map, img, threshold_ratio = 0.2, thre_fp = 0.5, thre_fn=0.5 ): 396 | # Get the threshold 397 | threshold = threshold_ratio * np.max(uncertainty_map) 398 | 399 | # Identify the pixels in the uncertainty map that are above the threshold 400 | above_threshold = uncertainty_map > threshold 401 | 402 | # Average over the channels of img 403 | img_avg = np.mean(img, axis=2) 404 | 405 | 406 | # For the binary mask, 407 | # if the pixels are in the above_threshold and the corresponding positional pixel in img larger than mean_value, set it to be 0, 408 | # binary_mask[0, (above_threshold) & (img_avg > mean_value*thre_fp)] = 0 409 | binary_mask[0, above_threshold] = 0 410 | 411 | # Calculate the average value of img under the binary_mask, after threshold by uncertainty map 412 | mean_value = np.mean(img_avg[binary_mask[0] > 0]) 413 | 414 | # if the pixels are in the above_threshold and the corresponding positional pixel in img smaller or equal to mean_value, set to be 1 415 | binary_mask[0, (above_threshold) & (img_avg <= mean_value* thre_fn)] = 1 416 | 417 | return binary_mask 418 | 419 | def mask_adjustment4(binary_mask, uncertainty_map, img, threshold_ratio = 0.2, thre_fp = 0.5, thre_fn=0.5 ): 420 | # Get the threshold 421 | threshold = np.min(uncertainty_map)+threshold_ratio * (np.max(uncertainty_map) - np.min(uncertainty_map)) 422 | 423 | # Identify the pixels in the uncertainty map that are above the threshold 424 | above_threshold = uncertainty_map > threshold 425 | 426 | original_binary_range = binary_mask>0 427 | 428 | # Average over the channels of img 429 | img_avg = np.mean(img, axis=2) 430 | 431 | 432 | # For the binary mask, 433 | # if the pixels are in the above_threshold and the corresponding positional pixel in img larger than mean_value, set it to be 0, 434 | # binary_mask[0, (above_threshold) & (img_avg > mean_value*thre_fp)] = 0 435 | binary_mask[0, above_threshold] = 0 436 | 437 | # Calculate the average value of img under the binary_mask, after threshold by uncertainty map 438 | mean_value = np.mean(img_avg[binary_mask[0] > 0]) 439 | 440 | # print("original_binary_rang shape:", original_binary_range.shape) # (1, 256, 256) 441 | # print("above_threshold shpae:", above_threshold.shape) # (256, 256) 442 | # print("img_avg <= mean_value* thre_fn:", (img_avg <= mean_value* thre_fn).shape) # (256, 256) 443 | 444 | # if the pixels are in the above_threshold and the corresponding positional pixel in img smaller or equal to mean_value, set to be 1 445 | binary_mask[0,(original_binary_range[0])&(above_threshold) & (img_avg <= mean_value* thre_fn)] = 1 446 | 447 | return binary_mask -------------------------------------------------------------------------------- /segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
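# How this vendored package is used by the FNPC scripts (a minimal sketch mirroring FNPC_Interface.py;
# the checkpoint path, variable names, and the 4-value unpacking of predict() follow that script rather
# than the upstream SAM API):
#
#   import torch
#   from segment_anything import sam_model_registry, SamPredictor
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   sam = sam_model_registry['vit_l'](checkpoint='weights/sam_vit_l_0b3195.pth').to(device=device)
#   sam.eval()
#   predictor = SamPredictor(sam)
#   predictor.set_image(test_image)  # HxWx3 uint8 RGB image
#   masks, scores, logits, _ = predictor.predict(point_coords=None, point_labels=None,
#                                                box=input_box[None, :], multimask_output=False)
#
# Note: the 'vit_t' (MobileSAM) option referenced in the scripts assumes a registry entry and a
# 'weights/mobile_sam.pt' checkpoint that are not part of the original segment_anything release.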
6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /segment_anything/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyimaging/FNPC/237048543f82dd417af7df5d8312cebc762e9e3e/segment_anything/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /segment_anything/__pycache__/automatic_mask_generator.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyimaging/FNPC/237048543f82dd417af7df5d8312cebc762e9e3e/segment_anything/__pycache__/automatic_mask_generator.cpython-36.pyc -------------------------------------------------------------------------------- /segment_anything/__pycache__/build_sam.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyimaging/FNPC/237048543f82dd417af7df5d8312cebc762e9e3e/segment_anything/__pycache__/build_sam.cpython-36.pyc -------------------------------------------------------------------------------- /segment_anything/__pycache__/predictor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyimaging/FNPC/237048543f82dd417af7df5d8312cebc762e9e3e/segment_anything/__pycache__/predictor.cpython-36.pyc -------------------------------------------------------------------------------- /segment_anything/automatic_mask_generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torchvision.ops.boxes import batched_nms, box_area # type: ignore 10 | 11 | from typing import Any, Dict, List, Optional, Tuple 12 | 13 | from .modeling import Sam 14 | from .predictor import SamPredictor 15 | from .utils.amg import ( 16 | MaskData, 17 | area_from_rle, 18 | batch_iterator, 19 | batched_mask_to_box, 20 | box_xyxy_to_xywh, 21 | build_all_layer_point_grids, 22 | calculate_stability_score, 23 | coco_encode_rle, 24 | generate_crop_boxes, 25 | is_box_near_crop_edge, 26 | mask_to_rle_pytorch, 27 | remove_small_regions, 28 | rle_to_mask, 29 | uncrop_boxes_xyxy, 30 | uncrop_masks, 31 | uncrop_points, 32 | ) 33 | 34 | 35 | class SamAutomaticMaskGenerator: 36 | def __init__( 37 | self, 38 | model: Sam, 39 | points_per_side: Optional[int] = 32, 40 | points_per_batch: int = 64, 41 | pred_iou_thresh: float = 0.88, 42 | stability_score_thresh: float = 0.95, 43 | stability_score_offset: float = 1.0, 44 | box_nms_thresh: float = 0.7, 45 | crop_n_layers: int = 0, 46 | crop_nms_thresh: float = 0.7, 47 | crop_overlap_ratio: float = 512 / 1500, 48 | crop_n_points_downscale_factor: int = 1, 49 | point_grids: Optional[List[np.ndarray]] = None, 50 | min_mask_region_area: int = 0, 51 | output_mode: str = "binary_mask", 52 | ) -> None: 53 | """ 54 | Using a SAM model, generates masks for the entire image. 
55 | Generates a grid of point prompts over the image, then filters 56 | low quality and duplicate masks. The default settings are chosen 57 | for SAM with a ViT-H backbone. 58 | 59 | Arguments: 60 | model (Sam): The SAM model to use for mask prediction. 61 | points_per_side (int or None): The number of points to be sampled 62 | along one side of the image. The total number of points is 63 | points_per_side**2. If None, 'point_grids' must provide explicit 64 | point sampling. 65 | points_per_batch (int): Sets the number of points run simultaneously 66 | by the model. Higher numbers may be faster but use more GPU memory. 67 | pred_iou_thresh (float): A filtering threshold in [0,1], using the 68 | model's predicted mask quality. 69 | stability_score_thresh (float): A filtering threshold in [0,1], using 70 | the stability of the mask under changes to the cutoff used to binarize 71 | the model's mask predictions. 72 | stability_score_offset (float): The amount to shift the cutoff when 73 | calculated the stability score. 74 | box_nms_thresh (float): The box IoU cutoff used by non-maximal 75 | suppression to filter duplicate masks. 76 | crop_n_layers (int): If >0, mask prediction will be run again on 77 | crops of the image. Sets the number of layers to run, where each 78 | layer has 2**i_layer number of image crops. 79 | crop_nms_thresh (float): The box IoU cutoff used by non-maximal 80 | suppression to filter duplicate masks between different crops. 81 | crop_overlap_ratio (float): Sets the degree to which crops overlap. 82 | In the first crop layer, crops will overlap by this fraction of 83 | the image length. Later layers with more crops scale down this overlap. 84 | crop_n_points_downscale_factor (int): The number of points-per-side 85 | sampled in layer n is scaled down by crop_n_points_downscale_factor**n. 86 | point_grids (list(np.ndarray) or None): A list over explicit grids 87 | of points used for sampling, normalized to [0,1]. The nth grid in the 88 | list is used in the nth crop layer. Exclusive with points_per_side. 89 | min_mask_region_area (int): If >0, postprocessing will be applied 90 | to remove disconnected regions and holes in masks with area smaller 91 | than min_mask_region_area. Requires opencv. 92 | output_mode (str): The form masks are returned in. Can be 'binary_mask', 93 | 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. 94 | For large resolutions, 'binary_mask' may consume large amounts of 95 | memory. 96 | """ 97 | 98 | assert (points_per_side is None) != ( 99 | point_grids is None 100 | ), "Exactly one of points_per_side or point_grid must be provided." 101 | if points_per_side is not None: 102 | self.point_grids = build_all_layer_point_grids( 103 | points_per_side, 104 | crop_n_layers, 105 | crop_n_points_downscale_factor, 106 | ) 107 | elif point_grids is not None: 108 | self.point_grids = point_grids 109 | else: 110 | raise ValueError("Can't have both points_per_side and point_grid be None.") 111 | 112 | assert output_mode in [ 113 | "binary_mask", 114 | "uncompressed_rle", 115 | "coco_rle", 116 | ], f"Unknown output_mode {output_mode}." 
117 | if output_mode == "coco_rle": 118 | from pycocotools import mask as mask_utils # type: ignore # noqa: F401 119 | 120 | if min_mask_region_area > 0: 121 | import cv2 # type: ignore # noqa: F401 122 | 123 | self.predictor = SamPredictor(model) 124 | self.points_per_batch = points_per_batch 125 | self.pred_iou_thresh = pred_iou_thresh 126 | self.stability_score_thresh = stability_score_thresh 127 | self.stability_score_offset = stability_score_offset 128 | self.box_nms_thresh = box_nms_thresh 129 | self.crop_n_layers = crop_n_layers 130 | self.crop_nms_thresh = crop_nms_thresh 131 | self.crop_overlap_ratio = crop_overlap_ratio 132 | self.crop_n_points_downscale_factor = crop_n_points_downscale_factor 133 | self.min_mask_region_area = min_mask_region_area 134 | self.output_mode = output_mode 135 | 136 | @torch.no_grad() 137 | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: 138 | """ 139 | Generates masks for the given image. 140 | 141 | Arguments: 142 | image (np.ndarray): The image to generate masks for, in HWC uint8 format. 143 | 144 | Returns: 145 | list(dict(str, any)): A list over records for masks. Each record is 146 | a dict containing the following keys: 147 | segmentation (dict(str, any) or np.ndarray): The mask. If 148 | output_mode='binary_mask', is an array of shape HW. Otherwise, 149 | is a dictionary containing the RLE. 150 | bbox (list(float)): The box around the mask, in XYWH format. 151 | area (int): The area in pixels of the mask. 152 | predicted_iou (float): The model's own prediction of the mask's 153 | quality. This is filtered by the pred_iou_thresh parameter. 154 | point_coords (list(list(float))): The point coordinates input 155 | to the model to generate this mask. 156 | stability_score (float): A measure of the mask's quality. This 157 | is filtered on using the stability_score_thresh parameter. 158 | crop_box (list(float)): The crop of the image used to generate 159 | the mask, given in XYWH format. 
160 | """ 161 | 162 | # Generate masks 163 | mask_data = self._generate_masks(image) 164 | 165 | # Filter small disconnected regions and holes in masks 166 | if self.min_mask_region_area > 0: 167 | mask_data = self.postprocess_small_regions( 168 | mask_data, 169 | self.min_mask_region_area, 170 | max(self.box_nms_thresh, self.crop_nms_thresh), 171 | ) 172 | 173 | # Encode masks 174 | if self.output_mode == "coco_rle": 175 | mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] 176 | elif self.output_mode == "binary_mask": 177 | mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] 178 | else: 179 | mask_data["segmentations"] = mask_data["rles"] 180 | 181 | # Write mask records 182 | curr_anns = [] 183 | for idx in range(len(mask_data["segmentations"])): 184 | ann = { 185 | "segmentation": mask_data["segmentations"][idx], 186 | "area": area_from_rle(mask_data["rles"][idx]), 187 | "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), 188 | "predicted_iou": mask_data["iou_preds"][idx].item(), 189 | "point_coords": [mask_data["points"][idx].tolist()], 190 | "stability_score": mask_data["stability_score"][idx].item(), 191 | "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), 192 | } 193 | curr_anns.append(ann) 194 | 195 | return curr_anns 196 | 197 | def _generate_masks(self, image: np.ndarray) -> MaskData: 198 | orig_size = image.shape[:2] 199 | crop_boxes, layer_idxs = generate_crop_boxes( 200 | orig_size, self.crop_n_layers, self.crop_overlap_ratio 201 | ) 202 | 203 | # Iterate over image crops 204 | data = MaskData() 205 | for crop_box, layer_idx in zip(crop_boxes, layer_idxs): 206 | crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) 207 | data.cat(crop_data) 208 | 209 | # Remove duplicate masks between crops 210 | if len(crop_boxes) > 1: 211 | # Prefer masks from smaller crops 212 | scores = 1 / box_area(data["crop_boxes"]) 213 | scores = scores.to(data["boxes"].device) 214 | keep_by_nms = batched_nms( 215 | data["boxes"].float(), 216 | scores, 217 | torch.zeros_like(data["boxes"][:, 0]), # categories 218 | iou_threshold=self.crop_nms_thresh, 219 | ) 220 | data.filter(keep_by_nms) 221 | 222 | data.to_numpy() 223 | return data 224 | 225 | def _process_crop( 226 | self, 227 | image: np.ndarray, 228 | crop_box: List[int], 229 | crop_layer_idx: int, 230 | orig_size: Tuple[int, ...], 231 | ) -> MaskData: 232 | # Crop the image and calculate embeddings 233 | x0, y0, x1, y1 = crop_box 234 | cropped_im = image[y0:y1, x0:x1, :] 235 | cropped_im_size = cropped_im.shape[:2] 236 | self.predictor.set_image(cropped_im) 237 | 238 | # Get points for this crop 239 | points_scale = np.array(cropped_im_size)[None, ::-1] 240 | points_for_image = self.point_grids[crop_layer_idx] * points_scale 241 | 242 | # Generate masks for this crop in batches 243 | data = MaskData() 244 | for (points,) in batch_iterator(self.points_per_batch, points_for_image): 245 | batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) 246 | data.cat(batch_data) 247 | del batch_data 248 | self.predictor.reset_image() 249 | 250 | # Remove duplicates within this crop. 
251 | keep_by_nms = batched_nms( 252 | data["boxes"].float(), 253 | data["iou_preds"], 254 | torch.zeros_like(data["boxes"][:, 0]), # categories 255 | iou_threshold=self.box_nms_thresh, 256 | ) 257 | data.filter(keep_by_nms) 258 | 259 | # Return to the original image frame 260 | data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) 261 | data["points"] = uncrop_points(data["points"], crop_box) 262 | data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) 263 | 264 | return data 265 | 266 | def _process_batch( 267 | self, 268 | points: np.ndarray, 269 | im_size: Tuple[int, ...], 270 | crop_box: List[int], 271 | orig_size: Tuple[int, ...], 272 | ) -> MaskData: 273 | orig_h, orig_w = orig_size 274 | 275 | # Run model on this batch 276 | transformed_points = self.predictor.transform.apply_coords(points, im_size) 277 | in_points = torch.as_tensor(transformed_points, device=self.predictor.device) 278 | in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) 279 | masks, iou_preds, _ = self.predictor.predict_torch( 280 | in_points[:, None, :], 281 | in_labels[:, None], 282 | multimask_output=True, 283 | return_logits=True, 284 | ) 285 | 286 | # Serialize predictions and store in MaskData 287 | data = MaskData( 288 | masks=masks.flatten(0, 1), 289 | iou_preds=iou_preds.flatten(0, 1), 290 | points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), 291 | ) 292 | del masks 293 | 294 | # Filter by predicted IoU 295 | if self.pred_iou_thresh > 0.0: 296 | keep_mask = data["iou_preds"] > self.pred_iou_thresh 297 | data.filter(keep_mask) 298 | 299 | # Calculate stability score 300 | data["stability_score"] = calculate_stability_score( 301 | data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset 302 | ) 303 | if self.stability_score_thresh > 0.0: 304 | keep_mask = data["stability_score"] >= self.stability_score_thresh 305 | data.filter(keep_mask) 306 | 307 | # Threshold masks and calculate boxes 308 | data["masks"] = data["masks"] > self.predictor.model.mask_threshold 309 | data["boxes"] = batched_mask_to_box(data["masks"]) 310 | 311 | # Filter boxes that touch crop boundaries 312 | keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) 313 | if not torch.all(keep_mask): 314 | data.filter(keep_mask) 315 | 316 | # Compress to RLE 317 | data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) 318 | data["rles"] = mask_to_rle_pytorch(data["masks"]) 319 | del data["masks"] 320 | 321 | return data 322 | 323 | @staticmethod 324 | def postprocess_small_regions( 325 | mask_data: MaskData, min_area: int, nms_thresh: float 326 | ) -> MaskData: 327 | """ 328 | Removes small disconnected regions and holes in masks, then reruns 329 | box NMS to remove any new duplicates. 330 | 331 | Edits mask_data in place. 332 | 333 | Requires open-cv as a dependency. 
334 | """ 335 | if len(mask_data["rles"]) == 0: 336 | return mask_data 337 | 338 | # Filter small disconnected regions and holes 339 | new_masks = [] 340 | scores = [] 341 | for rle in mask_data["rles"]: 342 | mask = rle_to_mask(rle) 343 | 344 | mask, changed = remove_small_regions(mask, min_area, mode="holes") 345 | unchanged = not changed 346 | mask, changed = remove_small_regions(mask, min_area, mode="islands") 347 | unchanged = unchanged and not changed 348 | 349 | new_masks.append(torch.as_tensor(mask).unsqueeze(0)) 350 | # Give score=0 to changed masks and score=1 to unchanged masks 351 | # so NMS will prefer ones that didn't need postprocessing 352 | scores.append(float(unchanged)) 353 | 354 | # Recalculate boxes and remove any new duplicates 355 | masks = torch.cat(new_masks, dim=0) 356 | boxes = batched_mask_to_box(masks) 357 | keep_by_nms = batched_nms( 358 | boxes.float(), 359 | torch.as_tensor(scores), 360 | torch.zeros_like(boxes[:, 0]), # categories 361 | iou_threshold=nms_thresh, 362 | ) 363 | 364 | # Only recalculate RLEs for masks that have changed 365 | for i_mask in keep_by_nms: 366 | if scores[i_mask] == 0.0: 367 | mask_torch = masks[i_mask].unsqueeze(0) 368 | mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] 369 | mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly 370 | mask_data.filter(keep_by_nms) 371 | 372 | return mask_data 373 | -------------------------------------------------------------------------------- /segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
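A hedged sketch of single-point prompting with SamPredictor, mirroring the transform / predict_torch calls used in _process_batch above; the checkpoint path and the synthetic image are assumptions (without a checkpoint the registry returns a randomly initialized model).

import numpy as np
import torch
from segment_anything import sam_model_registry, SamPredictor

sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")     # placeholder checkpoint path
predictor = SamPredictor(sam)

image = np.zeros((480, 640, 3), dtype=np.uint8)                           # stand-in for a real HWC uint8 RGB image
predictor.set_image(image)

point = np.array([[320.0, 240.0]])                                        # one (x, y) click in image coordinates
coords = predictor.transform.apply_coords(point, image.shape[:2])
in_points = torch.as_tensor(coords, device=predictor.device)
in_labels = torch.ones(len(in_points), dtype=torch.int, device=predictor.device)

masks, iou_preds, _ = predictor.predict_torch(
    in_points[:, None, :], in_labels[:, None], multimask_output=True
)
best_mask = masks[0, torch.argmax(iou_preds[0])]                          # highest-scoring of the candidate masks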
6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyimaging/FNPC/237048543f82dd417af7df5d8312cebc762e9e3e/segment_anything/modeling/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/common.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyimaging/FNPC/237048543f82dd417af7df5d8312cebc762e9e3e/segment_anything/modeling/__pycache__/common.cpython-36.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/image_encoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyimaging/FNPC/237048543f82dd417af7df5d8312cebc762e9e3e/segment_anything/modeling/__pycache__/image_encoder.cpython-36.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/mask_decoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyimaging/FNPC/237048543f82dd417af7df5d8312cebc762e9e3e/segment_anything/modeling/__pycache__/mask_decoder.cpython-36.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/prompt_encoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyimaging/FNPC/237048543f82dd417af7df5d8312cebc762e9e3e/segment_anything/modeling/__pycache__/prompt_encoder.cpython-36.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/sam.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyimaging/FNPC/237048543f82dd417af7df5d8312cebc762e9e3e/segment_anything/modeling/__pycache__/sam.cpython-36.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/__pycache__/transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyimaging/FNPC/237048543f82dd417af7df5d8312cebc762e9e3e/segment_anything/modeling/__pycache__/transformer.cpython-36.pyc -------------------------------------------------------------------------------- /segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /segment_anything/modeling/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from typing import Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d, MLPBlock 14 | 15 | 16 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 17 | class ImageEncoderViT(nn.Module): 18 | def __init__( 19 | self, 20 | img_size: int = 1024, 21 | patch_size: int = 16, 22 | in_chans: int = 3, 23 | embed_dim: int = 768, 24 | depth: int = 12, 25 | num_heads: int = 12, 26 | mlp_ratio: float = 4.0, 27 | out_chans: int = 256, 28 | qkv_bias: bool = True, 29 | norm_layer: Type[nn.Module] = nn.LayerNorm, 30 | act_layer: Type[nn.Module] = nn.GELU, 31 | use_abs_pos: bool = True, 32 | use_rel_pos: bool = False, 33 | rel_pos_zero_init: bool = True, 34 | window_size: int = 0, 35 | global_attn_indexes: Tuple[int, ...] = (), 36 | ) -> None: 37 | """ 38 | Args: 39 | img_size (int): Input image size. 40 | patch_size (int): Patch size. 41 | in_chans (int): Number of input image channels. 42 | embed_dim (int): Patch embedding dimension. 43 | depth (int): Depth of ViT. 44 | num_heads (int): Number of attention heads in each ViT block. 45 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 46 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 47 | norm_layer (nn.Module): Normalization layer. 48 | act_layer (nn.Module): Activation layer. 49 | use_abs_pos (bool): If True, use absolute positional embeddings. 50 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 51 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 
52 | window_size (int): Window size for window attention blocks. 53 | global_attn_indexes (list): Indexes for blocks using global attention. 54 | """ 55 | super().__init__() 56 | self.img_size = img_size 57 | 58 | self.patch_embed = PatchEmbed( 59 | kernel_size=(patch_size, patch_size), 60 | stride=(patch_size, patch_size), 61 | in_chans=in_chans, 62 | embed_dim=embed_dim, 63 | ) 64 | 65 | self.pos_embed: Optional[nn.Parameter] = None 66 | if use_abs_pos: 67 | # Initialize absolute positional embedding with pretrain image size. 68 | self.pos_embed = nn.Parameter( 69 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) 70 | ) 71 | 72 | self.blocks = nn.ModuleList() 73 | for i in range(depth): 74 | block = Block( 75 | dim=embed_dim, 76 | num_heads=num_heads, 77 | mlp_ratio=mlp_ratio, 78 | qkv_bias=qkv_bias, 79 | norm_layer=norm_layer, 80 | act_layer=act_layer, 81 | use_rel_pos=use_rel_pos, 82 | rel_pos_zero_init=rel_pos_zero_init, 83 | window_size=window_size if i not in global_attn_indexes else 0, 84 | input_size=(img_size // patch_size, img_size // patch_size), 85 | ) 86 | self.blocks.append(block) 87 | 88 | self.neck = nn.Sequential( 89 | nn.Conv2d( 90 | embed_dim, 91 | out_chans, 92 | kernel_size=1, 93 | bias=False, 94 | ), 95 | LayerNorm2d(out_chans), 96 | nn.Conv2d( 97 | out_chans, 98 | out_chans, 99 | kernel_size=3, 100 | padding=1, 101 | bias=False, 102 | ), 103 | LayerNorm2d(out_chans), 104 | ) 105 | 106 | def forward(self, x: torch.Tensor) -> torch.Tensor: 107 | x = self.patch_embed(x) 108 | if self.pos_embed is not None: 109 | x = x + self.pos_embed 110 | 111 | for blk in self.blocks: 112 | x = blk(x) 113 | 114 | x = self.neck(x.permute(0, 3, 1, 2)) 115 | 116 | return x 117 | 118 | 119 | class Block(nn.Module): 120 | """Transformer blocks with support of window attention and residual propagation blocks""" 121 | 122 | def __init__( 123 | self, 124 | dim: int, 125 | num_heads: int, 126 | mlp_ratio: float = 4.0, 127 | qkv_bias: bool = True, 128 | norm_layer: Type[nn.Module] = nn.LayerNorm, 129 | act_layer: Type[nn.Module] = nn.GELU, 130 | use_rel_pos: bool = False, 131 | rel_pos_zero_init: bool = True, 132 | window_size: int = 0, 133 | input_size: Optional[Tuple[int, int]] = None, 134 | ) -> None: 135 | """ 136 | Args: 137 | dim (int): Number of input channels. 138 | num_heads (int): Number of attention heads in each ViT block. 139 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 140 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 141 | norm_layer (nn.Module): Normalization layer. 142 | act_layer (nn.Module): Activation layer. 143 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 144 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 145 | window_size (int): Window size for window attention blocks. If it equals 0, then 146 | use global attention. 147 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 148 | positional parameter size. 
149 | """ 150 | super().__init__() 151 | self.norm1 = norm_layer(dim) 152 | self.attn = Attention( 153 | dim, 154 | num_heads=num_heads, 155 | qkv_bias=qkv_bias, 156 | use_rel_pos=use_rel_pos, 157 | rel_pos_zero_init=rel_pos_zero_init, 158 | input_size=input_size if window_size == 0 else (window_size, window_size), 159 | ) 160 | 161 | self.norm2 = norm_layer(dim) 162 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 163 | 164 | self.window_size = window_size 165 | 166 | def forward(self, x: torch.Tensor) -> torch.Tensor: 167 | shortcut = x 168 | x = self.norm1(x) 169 | # Window partition 170 | if self.window_size > 0: 171 | H, W = x.shape[1], x.shape[2] 172 | x, pad_hw = window_partition(x, self.window_size) 173 | 174 | x = self.attn(x) 175 | # Reverse window partition 176 | if self.window_size > 0: 177 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 178 | 179 | x = shortcut + x 180 | x = x + self.mlp(self.norm2(x)) 181 | 182 | return x 183 | 184 | 185 | class Attention(nn.Module): 186 | """Multi-head Attention block with relative position embeddings.""" 187 | 188 | def __init__( 189 | self, 190 | dim: int, 191 | num_heads: int = 8, 192 | qkv_bias: bool = True, 193 | use_rel_pos: bool = False, 194 | rel_pos_zero_init: bool = True, 195 | input_size: Optional[Tuple[int, int]] = None, 196 | ) -> None: 197 | """ 198 | Args: 199 | dim (int): Number of input channels. 200 | num_heads (int): Number of attention heads. 201 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 202 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 203 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 204 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 205 | positional parameter size. 206 | """ 207 | super().__init__() 208 | self.num_heads = num_heads 209 | head_dim = dim // num_heads 210 | self.scale = head_dim**-0.5 211 | 212 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 213 | self.proj = nn.Linear(dim, dim) 214 | 215 | self.use_rel_pos = use_rel_pos 216 | if self.use_rel_pos: 217 | assert ( 218 | input_size is not None 219 | ), "Input size must be provided if using relative positional encoding." 220 | # initialize relative positional embeddings 221 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 222 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 223 | 224 | def forward(self, x: torch.Tensor) -> torch.Tensor: 225 | B, H, W, _ = x.shape 226 | # qkv with shape (3, B, nHead, H * W, C) 227 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 228 | # q, k, v with shape (B * nHead, H * W, C) 229 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 230 | 231 | attn = (q * self.scale) @ k.transpose(-2, -1) 232 | 233 | if self.use_rel_pos: 234 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 235 | 236 | attn = attn.softmax(dim=-1) 237 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 238 | x = self.proj(x) 239 | 240 | return x 241 | 242 | 243 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 244 | """ 245 | Partition into non-overlapping windows with padding if needed. 246 | Args: 247 | x (tensor): input tokens with [B, H, W, C]. 248 | window_size (int): window size. 
249 | 250 | Returns: 251 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 252 | (Hp, Wp): padded height and width before partition 253 | """ 254 | B, H, W, C = x.shape 255 | 256 | pad_h = (window_size - H % window_size) % window_size 257 | pad_w = (window_size - W % window_size) % window_size 258 | if pad_h > 0 or pad_w > 0: 259 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 260 | Hp, Wp = H + pad_h, W + pad_w 261 | 262 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 263 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 264 | return windows, (Hp, Wp) 265 | 266 | 267 | def window_unpartition( 268 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 269 | ) -> torch.Tensor: 270 | """ 271 | Window unpartition into original sequences and removing padding. 272 | Args: 273 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 274 | window_size (int): window size. 275 | pad_hw (Tuple): padded height and width (Hp, Wp). 276 | hw (Tuple): original height and width (H, W) before padding. 277 | 278 | Returns: 279 | x: unpartitioned sequences with [B, H, W, C]. 280 | """ 281 | Hp, Wp = pad_hw 282 | H, W = hw 283 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 284 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 285 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 286 | 287 | if Hp > H or Wp > W: 288 | x = x[:, :H, :W, :].contiguous() 289 | return x 290 | 291 | 292 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 293 | """ 294 | Get relative positional embeddings according to the relative positions of 295 | query and key sizes. 296 | Args: 297 | q_size (int): size of query q. 298 | k_size (int): size of key k. 299 | rel_pos (Tensor): relative position embeddings (L, C). 300 | 301 | Returns: 302 | Extracted positional embeddings according to relative positions. 303 | """ 304 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 305 | # Interpolate rel pos if needed. 306 | if rel_pos.shape[0] != max_rel_dist: 307 | # Interpolate rel pos. 308 | rel_pos_resized = F.interpolate( 309 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 310 | size=max_rel_dist, 311 | mode="linear", 312 | ) 313 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 314 | else: 315 | rel_pos_resized = rel_pos 316 | 317 | # Scale the coords with short length if shapes for q and k are different. 318 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 319 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 320 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 321 | 322 | return rel_pos_resized[relative_coords.long()] 323 | 324 | 325 | def add_decomposed_rel_pos( 326 | attn: torch.Tensor, 327 | q: torch.Tensor, 328 | rel_pos_h: torch.Tensor, 329 | rel_pos_w: torch.Tensor, 330 | q_size: Tuple[int, int], 331 | k_size: Tuple[int, int], 332 | ) -> torch.Tensor: 333 | """ 334 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 335 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 336 | Args: 337 | attn (Tensor): attention map. 338 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 
339 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 340 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 341 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 342 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 343 | 344 | Returns: 345 | attn (Tensor): attention map with added relative positional embeddings. 346 | """ 347 | q_h, q_w = q_size 348 | k_h, k_w = k_size 349 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 350 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 351 | 352 | B, _, dim = q.shape 353 | r_q = q.reshape(B, q_h, q_w, dim) 354 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 355 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 356 | 357 | attn = ( 358 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 359 | ).view(B, q_h * q_w, k_h * k_w) 360 | 361 | return attn 362 | 363 | 364 | class PatchEmbed(nn.Module): 365 | """ 366 | Image to Patch Embedding. 367 | """ 368 | 369 | def __init__( 370 | self, 371 | kernel_size: Tuple[int, int] = (16, 16), 372 | stride: Tuple[int, int] = (16, 16), 373 | padding: Tuple[int, int] = (0, 0), 374 | in_chans: int = 3, 375 | embed_dim: int = 768, 376 | ) -> None: 377 | """ 378 | Args: 379 | kernel_size (Tuple): kernel size of the projection layer. 380 | stride (Tuple): stride of the projection layer. 381 | padding (Tuple): padding size of the projection layer. 382 | in_chans (int): Number of input image channels. 383 | embed_dim (int): Patch embedding dimension. 384 | """ 385 | super().__init__() 386 | 387 | self.proj = nn.Conv2d( 388 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 389 | ) 390 | 391 | def forward(self, x: torch.Tensor) -> torch.Tensor: 392 | x = self.proj(x) 393 | # B C H W -> B H W C 394 | x = x.permute(0, 2, 3, 1) 395 | return x 396 | -------------------------------------------------------------------------------- /segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 
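A small shape sketch for window_partition / window_unpartition above, with illustrative sizes:

import torch
from segment_anything.modeling.image_encoder import window_partition, window_unpartition

x = torch.randn(2, 50, 50, 64)                  # B, H, W, C with H and W not multiples of the window size
windows, pad_hw = window_partition(x, window_size=14)
print(windows.shape, pad_hw)                    # torch.Size([32, 14, 14, 64]) (56, 56): 2 images x 4*4 windows each
y = window_unpartition(windows, 14, pad_hw, (50, 50))
print(torch.equal(x, y))                        # True: the padding is cropped away again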
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for output 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | iou_pred = iou_pred[:, mask_slice] 108 | 109 | # Prepare output 110 | return masks, iou_pred 111 | 112 | def predict_masks( 113 | self, 114 | image_embeddings: torch.Tensor, 115 | image_pe: torch.Tensor, 116 | sparse_prompt_embeddings: torch.Tensor, 117 | dense_prompt_embeddings: torch.Tensor, 118 | ) -> Tuple[torch.Tensor, torch.Tensor]: 119 | """Predicts masks. 
See 'forward' for more details.""" 120 | # Concatenate output tokens 121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 124 | 125 | # Expand per-image data in batch direction to be per-mask 126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 127 | src = src + dense_prompt_embeddings 128 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 129 | b, c, h, w = src.shape 130 | 131 | # Run the transformer 132 | hs, src = self.transformer(src, pos_src, tokens) 133 | iou_token_out = hs[:, 0, :] 134 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 135 | 136 | # Upscale mask embeddings and predict masks using the mask tokens 137 | src = src.transpose(1, 2).view(b, c, h, w) 138 | upscaled_embedding = self.output_upscaling(src) 139 | hyper_in_list: List[torch.Tensor] = [] 140 | for i in range(self.num_mask_tokens): 141 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 142 | hyper_in = torch.stack(hyper_in_list, dim=1) 143 | b, c, h, w = upscaled_embedding.shape 144 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 145 | 146 | # Generate mask quality predictions 147 | iou_pred = self.iou_prediction_head(iou_token_out) 148 | 149 | return masks, iou_pred 150 | 151 | 152 | # Lightly adapted from 153 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 154 | class MLP(nn.Module): 155 | def __init__( 156 | self, 157 | input_dim: int, 158 | hidden_dim: int, 159 | output_dim: int, 160 | num_layers: int, 161 | sigmoid_output: bool = False, 162 | ) -> None: 163 | super().__init__() 164 | self.num_layers = num_layers 165 | h = [hidden_dim] * (num_layers - 1) 166 | self.layers = nn.ModuleList( 167 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 168 | ) 169 | self.sigmoid_output = sigmoid_output 170 | 171 | def forward(self, x): 172 | for i, layer in enumerate(self.layers): 173 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 174 | if self.sigmoid_output: 175 | x = F.sigmoid(x) 176 | return x 177 | -------------------------------------------------------------------------------- /segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 
34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 
115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 
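A hedged shape sketch chaining the PromptEncoder and MaskDecoder above with random weights; the hyperparameters mirror _build_sam, and the image embedding is a random stand-in for the ViT output.

import torch
from segment_anything.modeling import PromptEncoder, MaskDecoder, TwoWayTransformer

prompt_encoder = PromptEncoder(embed_dim=256, image_embedding_size=(64, 64),
                               input_image_size=(1024, 1024), mask_in_chans=16)
mask_decoder = MaskDecoder(
    transformer_dim=256,
    transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
)

points = (torch.tensor([[[512.0, 512.0]]]), torch.tensor([[1]]))          # one foreground click
boxes = torch.tensor([[256.0, 256.0, 768.0, 768.0]])                      # one box in xyxy, 1024x1024 input frame
sparse, dense = prompt_encoder(points=points, boxes=boxes, masks=None)
print(sparse.shape, dense.shape)            # torch.Size([1, 3, 256]) torch.Size([1, 256, 64, 64])

image_embeddings = torch.randn(1, 256, 64, 64)
masks, iou_pred = mask_decoder(
    image_embeddings=image_embeddings,
    image_pe=prompt_encoder.get_dense_pe(),
    sparse_prompt_embeddings=sparse,
    dense_prompt_embeddings=dense,
    multimask_output=True,
)
print(masks.shape, iou_pred.shape)          # torch.Size([1, 3, 256, 256]) torch.Size([1, 3])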
63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | as dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | image_embeddings = self.image_encoder(input_images) 99 | 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | points=points, 108 | boxes=image_record.get("boxes", None), 109 | masks=image_record.get("mask_inputs", None), 110 | ) 111 | low_res_masks, iou_predictions = self.mask_decoder( 112 | image_embeddings=curr_embedding.unsqueeze(0), 113 | image_pe=self.prompt_encoder.get_dense_pe(), 114 | sparse_prompt_embeddings=sparse_embeddings, 115 | dense_prompt_embeddings=dense_embeddings, 116 | multimask_output=multimask_output, 117 | ) 118 | masks = self.postprocess_masks( 119 | low_res_masks, 120 | input_size=image_record["image"].shape[-2:], 121 | original_size=image_record["original_size"], 122 | ) 123 | masks = masks > self.mask_threshold 124 | outputs.append( 125 | { 126 | "masks": masks, 127 | "iou_predictions": iou_predictions, 128 | "low_res_logits": low_res_masks, 129 | } 130 | ) 131 | return outputs 132 | 133 | def postprocess_masks( 134 | self, 135 | masks: torch.Tensor, 136 | input_size: Tuple[int, ...], 137 | original_size: Tuple[int, ...], 138 | ) -> torch.Tensor: 139 | """ 140 | Remove padding and upscale masks to the original image size. 141 | 142 | Arguments: 143 | masks (torch.Tensor): Batched masks from the mask_decoder, 144 | in BxCxHxW format. 145 | input_size (tuple(int, int)): The size of the image input to the 146 | model, in (H, W) format. Used to remove padding. 147 | original_size (tuple(int, int)): The original size of the image 148 | before resizing for input to the model, in (H, W) format. 
149 | 150 | Returns: 151 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 152 | is given by original_size. 153 | """ 154 | masks = F.interpolate( 155 | masks, 156 | (self.image_encoder.img_size, self.image_encoder.img_size), 157 | mode="bilinear", 158 | align_corners=False, 159 | ) 160 | masks = masks[..., : input_size[0], : input_size[1]] 161 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 162 | return masks 163 | 164 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 165 | """Normalize pixel values and pad to a square input.""" 166 | # Normalize colors 167 | x = (x - self.pixel_mean) / self.pixel_std 168 | 169 | # Pad 170 | h, w = x.shape[-2:] 171 | padh = self.image_encoder.img_size - h 172 | padw = self.image_encoder.img_size - w 173 | x = F.pad(x, (0, padw, 0, padh)) 174 | return x 175 | -------------------------------------------------------------------------------- /segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /segment_anything/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from segment_anything.modeling import Sam 11 | 12 | from typing import Optional, Tuple 13 | 14 | from .utils.transforms import ResizeLongestSide 15 | 16 | 17 | class SamPredictor: 18 | def __init__( 19 | self, 20 | sam_model: Sam, 21 | ) -> None: 22 | """ 23 | Uses SAM to calculate the image embedding for an image, and then 24 | allow repeated, efficient mask prediction given prompts. 25 | 26 | Arguments: 27 | sam_model (Sam): The model to use for mask prediction. 28 | """ 29 | super().__init__() 30 | self.model = sam_model 31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 32 | self.reset_image() 33 | 34 | def set_image( 35 | self, 36 | image: np.ndarray, 37 | image_format: str = "RGB", 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. 
Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 52 | if image_format != self.model.image_format: 53 | image = image[..., ::-1] 54 | 55 | # Transform the image to the form expected by the model 56 | input_image = self.transform.apply_image(image) 57 | input_image_torch = torch.as_tensor(input_image, device=self.device) 58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 59 | 60 | self.set_torch_image(input_image_torch, image.shape[:2]) 61 | 62 | @torch.no_grad() 63 | def set_torch_image( 64 | self, 65 | transformed_image: torch.Tensor, 66 | original_image_size: Tuple[int, ...], 67 | ) -> None: 68 | """ 69 | Calculates the image embeddings for the provided image, allowing 70 | masks to be predicted with the 'predict' method. Expects the input 71 | image to be already transformed to the format expected by the model. 72 | 73 | Arguments: 74 | transformed_image (torch.Tensor): The input image, with shape 75 | 1x3xHxW, which has been transformed with ResizeLongestSide. 76 | original_image_size (tuple(int, int)): The size of the image 77 | before transformation, in (H, W) format. 78 | """ 79 | assert ( 80 | len(transformed_image.shape) == 4 81 | and transformed_image.shape[1] == 3 82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 84 | self.reset_image() 85 | 86 | self.original_size = original_image_size 87 | self.input_size = tuple(transformed_image.shape[-2:]) 88 | input_image = self.model.preprocess(transformed_image) 89 | self.features = self.model.image_encoder(input_image) 90 | self.is_image_set = True 91 | 92 | def predict( 93 | self, 94 | point_coords: Optional[np.ndarray] = None, 95 | point_labels: Optional[np.ndarray] = None, 96 | box: Optional[np.ndarray] = None, 97 | mask_input: Optional[np.ndarray] = None, 98 | multimask_output: bool = True, 99 | return_logits: bool = False, 100 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 101 | """ 102 | Predict masks for the given input prompts, using the currently set image. 103 | 104 | Arguments: 105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 106 | model. Each point is in (X,Y) in pixels. 107 | point_labels (np.ndarray or None): A length N array of labels for the 108 | point prompts. 1 indicates a foreground point and 0 indicates a 109 | background point. 110 | box (np.ndarray or None): A length 4 array given a box prompt to the 111 | model, in XYXY format. 112 | mask_input (np.ndarray): A low resolution mask input to the model, typically 113 | coming from a previous prediction iteration. Has form 1xHxW, where 114 | for SAM, H=W=256. 115 | multimask_output (bool): If true, the model will return three masks. 116 | For ambiguous input prompts (such as a single click), this will often 117 | produce better masks than a single prediction. If only a single 118 | mask is needed, the model's predicted quality score can be used 119 | to select the best mask. For non-ambiguous prompts, such as multiple 120 | input prompts, multimask_output=False can give better results. 121 | return_logits (bool): If true, returns un-thresholded masks logits 122 | instead of a binary mask. 
123 | 124 | Returns: 125 | (np.ndarray): The output masks in CxHxW format, where C is the 126 | number of masks, and (H, W) is the original image size. 127 | (np.ndarray): An array of length C containing the model's 128 | predictions for the quality of each mask. 129 | (np.ndarray): An array of shape CxHxW, where C is the number 130 | of masks and H=W=256. These low resolution logits can be passed to 131 | a subsequent iteration as mask input. 132 | """ 133 | if not self.is_image_set: 134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 135 | 136 | # Transform input prompts 137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 138 | if point_coords is not None: 139 | assert ( 140 | point_labels is not None 141 | ), "point_labels must be supplied if point_coords is supplied." 142 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 146 | if box is not None: 147 | box = self.transform.apply_boxes(box, self.original_size) 148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 149 | box_torch = box_torch[None, :] 150 | if mask_input is not None: 151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 152 | mask_input_torch = mask_input_torch[None, :, :, :] 153 | 154 | masks, iou_predictions, low_res_masks = self.predict_torch( 155 | coords_torch, 156 | labels_torch, 157 | box_torch, 158 | mask_input_torch, 159 | multimask_output, 160 | return_logits=return_logits, 161 | ) 162 | 163 | masks_np = masks[0].detach().cpu().numpy() 164 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy() 165 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy() 166 | return masks_np, iou_predictions_np, low_res_masks_np 167 | 168 | @torch.no_grad() 169 | def predict_torch( 170 | self, 171 | point_coords: Optional[torch.Tensor], 172 | point_labels: Optional[torch.Tensor], 173 | boxes: Optional[torch.Tensor] = None, 174 | mask_input: Optional[torch.Tensor] = None, 175 | multimask_output: bool = True, 176 | return_logits: bool = False, 177 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 178 | """ 179 | Predict masks for the given input prompts, using the currently set image. 180 | Input prompts are batched torch tensors and are expected to already be 181 | transformed to the input frame using ResizeLongestSide. 182 | 183 | Arguments: 184 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 185 | model. Each point is in (X,Y) in pixels. 186 | point_labels (torch.Tensor or None): A BxN array of labels for the 187 | point prompts. 1 indicates a foreground point and 0 indicates a 188 | background point. 189 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the 190 | model, in XYXY format. 191 | mask_input (np.ndarray): A low resolution mask input to the model, typically 192 | coming from a previous prediction iteration. Has form Bx1xHxW, where 193 | for SAM, H=W=256. Masks returned by a previous iteration of the 194 | predict method do not need further transformation. 195 | multimask_output (bool): If true, the model will return three masks. 
196 | For ambiguous input prompts (such as a single click), this will often 197 | produce better masks than a single prediction. If only a single 198 | mask is needed, the model's predicted quality score can be used 199 | to select the best mask. For non-ambiguous prompts, such as multiple 200 | input prompts, multimask_output=False can give better results. 201 | return_logits (bool): If true, returns un-thresholded masks logits 202 | instead of a binary mask. 203 | 204 | Returns: 205 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 206 | number of masks, and (H, W) is the original image size. 207 | (torch.Tensor): An array of shape BxC containing the model's 208 | predictions for the quality of each mask. 209 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 210 | of masks and H=W=256. These low res logits can be passed to 211 | a subsequent iteration as mask input. 212 | """ 213 | if not self.is_image_set: 214 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 215 | 216 | if point_coords is not None: 217 | points = (point_coords, point_labels) 218 | else: 219 | points = None 220 | 221 | # Embed prompts 222 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 223 | points=points, 224 | boxes=boxes, 225 | masks=mask_input, 226 | ) 227 | 228 | # Predict masks 229 | low_res_masks, iou_predictions = self.model.mask_decoder( 230 | image_embeddings=self.features, 231 | image_pe=self.model.prompt_encoder.get_dense_pe(), 232 | sparse_prompt_embeddings=sparse_embeddings, 233 | dense_prompt_embeddings=dense_embeddings, 234 | multimask_output=multimask_output, 235 | ) 236 | 237 | # Upscale the masks to the original image resolution 238 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 239 | 240 | if not return_logits: 241 | masks = masks > self.model.mask_threshold 242 | 243 | return masks, iou_predictions, low_res_masks 244 | 245 | def get_image_embedding(self) -> torch.Tensor: 246 | """ 247 | Returns the image embeddings for the currently set image, with 248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 249 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 250 | """ 251 | if not self.is_image_set: 252 | raise RuntimeError( 253 | "An image must be set with .set_image(...) to generate an embedding." 254 | ) 255 | assert self.features is not None, "Features must exist if an image has been set." 256 | return self.features 257 | 258 | @property 259 | def device(self) -> torch.device: 260 | return self.model.device 261 | 262 | def reset_image(self) -> None: 263 | """Resets the currently set image.""" 264 | self.is_image_set = False 265 | self.features = None 266 | self.orig_h = None 267 | self.orig_w = None 268 | self.input_h = None 269 | self.input_w = None 270 | -------------------------------------------------------------------------------- /segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /segment_anything/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyimaging/FNPC/237048543f82dd417af7df5d8312cebc762e9e3e/segment_anything/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /segment_anything/utils/__pycache__/amg.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyimaging/FNPC/237048543f82dd417af7df5d8312cebc762e9e3e/segment_anything/utils/__pycache__/amg.cpython-36.pyc -------------------------------------------------------------------------------- /segment_anything/utils/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xyimaging/FNPC/237048543f82dd417af7df5d8312cebc762e9e3e/segment_anything/utils/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /segment_anything/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import math 11 | from copy import deepcopy 12 | from itertools import product 13 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 14 | 15 | 16 | class MaskData: 17 | """ 18 | A structure for storing masks and their related data in batched format. 19 | Implements basic filtering and concatenation. 20 | """ 21 | 22 | def __init__(self, **kwargs) -> None: 23 | for v in kwargs.values(): 24 | assert isinstance( 25 | v, (list, np.ndarray, torch.Tensor) 26 | ), "MaskData only supports list, numpy arrays, and torch tensors." 27 | self._stats = dict(**kwargs) 28 | 29 | def __setitem__(self, key: str, item: Any) -> None: 30 | assert isinstance( 31 | item, (list, np.ndarray, torch.Tensor) 32 | ), "MaskData only supports list, numpy arrays, and torch tensors." 
33 | self._stats[key] = item 34 | 35 | def __delitem__(self, key: str) -> None: 36 | del self._stats[key] 37 | 38 | def __getitem__(self, key: str) -> Any: 39 | return self._stats[key] 40 | 41 | def items(self) -> ItemsView[str, Any]: 42 | return self._stats.items() 43 | 44 | def filter(self, keep: torch.Tensor) -> None: 45 | for k, v in self._stats.items(): 46 | if v is None: 47 | self._stats[k] = None 48 | elif isinstance(v, torch.Tensor): 49 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 50 | elif isinstance(v, np.ndarray): 51 | self._stats[k] = v[keep.detach().cpu().numpy()] 52 | elif isinstance(v, list) and keep.dtype == torch.bool: 53 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 54 | elif isinstance(v, list): 55 | self._stats[k] = [v[i] for i in keep] 56 | else: 57 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 58 | 59 | def cat(self, new_stats: "MaskData") -> None: 60 | for k, v in new_stats.items(): 61 | if k not in self._stats or self._stats[k] is None: 62 | self._stats[k] = deepcopy(v) 63 | elif isinstance(v, torch.Tensor): 64 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 65 | elif isinstance(v, np.ndarray): 66 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 67 | elif isinstance(v, list): 68 | self._stats[k] = self._stats[k] + deepcopy(v) 69 | else: 70 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 71 | 72 | def to_numpy(self) -> None: 73 | for k, v in self._stats.items(): 74 | if isinstance(v, torch.Tensor): 75 | self._stats[k] = v.detach().cpu().numpy() 76 | 77 | 78 | def is_box_near_crop_edge( 79 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 80 | ) -> torch.Tensor: 81 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 82 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 83 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) 84 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 85 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 86 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 87 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 88 | return torch.any(near_crop_edge, dim=1) 89 | 90 | 91 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 92 | box_xywh = deepcopy(box_xyxy) 93 | box_xywh[2] = box_xywh[2] - box_xywh[0] 94 | box_xywh[3] = box_xywh[3] - box_xywh[1] 95 | return box_xywh 96 | 97 | 98 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 99 | assert len(args) > 0 and all( 100 | len(a) == len(args[0]) for a in args 101 | ), "Batched iteration must have inputs of all the same size." 102 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 103 | for b in range(n_batches): 104 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 105 | 106 | 107 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 108 | """ 109 | Encodes masks to an uncompressed RLE, in the format expected by 110 | pycoco tools. 
111 | """ 112 | # Put in fortran order and flatten h,w 113 | b, h, w = tensor.shape 114 | tensor = tensor.permute(0, 2, 1).flatten(1) 115 | 116 | # Compute change indices 117 | diff = tensor[:, 1:] ^ tensor[:, :-1] 118 | change_indices = diff.nonzero() 119 | 120 | # Encode run length 121 | out = [] 122 | for i in range(b): 123 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 124 | cur_idxs = torch.cat( 125 | [ 126 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 127 | cur_idxs + 1, 128 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | ] 130 | ) 131 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 132 | counts = [] if tensor[i, 0] == 0 else [0] 133 | counts.extend(btw_idxs.detach().cpu().tolist()) 134 | out.append({"size": [h, w], "counts": counts}) 135 | return out 136 | 137 | 138 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 139 | """Compute a binary mask from an uncompressed RLE.""" 140 | h, w = rle["size"] 141 | mask = np.empty(h * w, dtype=bool) 142 | idx = 0 143 | parity = False 144 | for count in rle["counts"]: 145 | mask[idx : idx + count] = parity 146 | idx += count 147 | parity ^= True 148 | mask = mask.reshape(w, h) 149 | return mask.transpose() # Put in C order 150 | 151 | 152 | def area_from_rle(rle: Dict[str, Any]) -> int: 153 | return sum(rle["counts"][1::2]) 154 | 155 | 156 | def calculate_stability_score( 157 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 158 | ) -> torch.Tensor: 159 | """ 160 | Computes the stability score for a batch of masks. The stability 161 | score is the IoU between the binary masks obtained by thresholding 162 | the predicted mask logits at high and low values. 163 | """ 164 | # One mask is always contained inside the other. 165 | # Save memory by preventing unnecessary cast to torch.int64 166 | intersections = ( 167 | (masks > (mask_threshold + threshold_offset)) 168 | .sum(-1, dtype=torch.int16) 169 | .sum(-1, dtype=torch.int32) 170 | ) 171 | unions = ( 172 | (masks > (mask_threshold - threshold_offset)) 173 | .sum(-1, dtype=torch.int16) 174 | .sum(-1, dtype=torch.int32) 175 | ) 176 | return intersections / unions 177 | 178 | 179 | def build_point_grid(n_per_side: int) -> np.ndarray: 180 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 181 | offset = 1 / (2 * n_per_side) 182 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 183 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 184 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 185 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 186 | return points 187 | 188 | 189 | def build_all_layer_point_grids( 190 | n_per_side: int, n_layers: int, scale_per_layer: int 191 | ) -> List[np.ndarray]: 192 | """Generates point grids for all crop layers.""" 193 | points_by_layer = [] 194 | for i in range(n_layers + 1): 195 | n_points = int(n_per_side / (scale_per_layer**i)) 196 | points_by_layer.append(build_point_grid(n_points)) 197 | return points_by_layer 198 | 199 | 200 | def generate_crop_boxes( 201 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 202 | ) -> Tuple[List[List[int]], List[int]]: 203 | """ 204 | Generates a list of crop boxes of different sizes. Each layer 205 | has (2**i)**2 boxes for the ith layer. 
206 | """ 207 | crop_boxes, layer_idxs = [], [] 208 | im_h, im_w = im_size 209 | short_side = min(im_h, im_w) 210 | 211 | # Original image 212 | crop_boxes.append([0, 0, im_w, im_h]) 213 | layer_idxs.append(0) 214 | 215 | def crop_len(orig_len, n_crops, overlap): 216 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 217 | 218 | for i_layer in range(n_layers): 219 | n_crops_per_side = 2 ** (i_layer + 1) 220 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 221 | 222 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 223 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 224 | 225 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 226 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 227 | 228 | # Crops in XYWH format 229 | for x0, y0 in product(crop_box_x0, crop_box_y0): 230 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 231 | crop_boxes.append(box) 232 | layer_idxs.append(i_layer + 1) 233 | 234 | return crop_boxes, layer_idxs 235 | 236 | 237 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 238 | x0, y0, _, _ = crop_box 239 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 240 | # Check if boxes has a channel dimension 241 | if len(boxes.shape) == 3: 242 | offset = offset.unsqueeze(1) 243 | return boxes + offset 244 | 245 | 246 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 247 | x0, y0, _, _ = crop_box 248 | offset = torch.tensor([[x0, y0]], device=points.device) 249 | # Check if points has a channel dimension 250 | if len(points.shape) == 3: 251 | offset = offset.unsqueeze(1) 252 | return points + offset 253 | 254 | 255 | def uncrop_masks( 256 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 257 | ) -> torch.Tensor: 258 | x0, y0, x1, y1 = crop_box 259 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 260 | return masks 261 | # Coordinate transform masks 262 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 263 | pad = (x0, pad_x - x0, y0, pad_y - y0) 264 | return torch.nn.functional.pad(masks, pad, value=0) 265 | 266 | 267 | def remove_small_regions( 268 | mask: np.ndarray, area_thresh: float, mode: str 269 | ) -> Tuple[np.ndarray, bool]: 270 | """ 271 | Removes small disconnected regions and holes in a mask. Returns the 272 | mask and an indicator of if the mask has been modified. 
273 | """ 274 | import cv2 # type: ignore 275 | 276 | assert mode in ["holes", "islands"] 277 | correct_holes = mode == "holes" 278 | working_mask = (correct_holes ^ mask).astype(np.uint8) 279 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 280 | sizes = stats[:, -1][1:] # Row 0 is background label 281 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 282 | if len(small_regions) == 0: 283 | return mask, False 284 | fill_labels = [0] + small_regions 285 | if not correct_holes: 286 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 287 | # If every region is below threshold, keep largest 288 | if len(fill_labels) == 0: 289 | fill_labels = [int(np.argmax(sizes)) + 1] 290 | mask = np.isin(regions, fill_labels) 291 | return mask, True 292 | 293 | 294 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 295 | from pycocotools import mask as mask_utils # type: ignore 296 | 297 | h, w = uncompressed_rle["size"] 298 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 299 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 300 | return rle 301 | 302 | 303 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 304 | """ 305 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 306 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 307 | """ 308 | # torch.max below raises an error on empty inputs, just skip in this case 309 | if torch.numel(masks) == 0: 310 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 311 | 312 | # Normalize shape to CxHxW 313 | shape = masks.shape 314 | h, w = shape[-2:] 315 | if len(shape) > 2: 316 | masks = masks.flatten(0, -3) 317 | else: 318 | masks = masks.unsqueeze(0) 319 | 320 | # Get top and bottom edges 321 | in_height, _ = torch.max(masks, dim=-1) 322 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 323 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 324 | in_height_coords = in_height_coords + h * (~in_height) 325 | top_edges, _ = torch.min(in_height_coords, dim=-1) 326 | 327 | # Get left and right edges 328 | in_width, _ = torch.max(masks, dim=-2) 329 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 330 | right_edges, _ = torch.max(in_width_coords, dim=-1) 331 | in_width_coords = in_width_coords + w * (~in_width) 332 | left_edges, _ = torch.min(in_width_coords, dim=-1) 333 | 334 | # If the mask is empty the right edge will be to the left of the left edge. 335 | # Replace these boxes with [0, 0, 0, 0] 336 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 337 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 338 | out = out * (~empty_filter).unsqueeze(-1) 339 | 340 | # Return to original shape 341 | if len(shape) > 2: 342 | out = out.reshape(*shape[:-2], 4) 343 | else: 344 | out = out[0] 345 | 346 | return out 347 | -------------------------------------------------------------------------------- /segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 
96 | # The reweighting is used to avoid control flow. 97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 
37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 
97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /show.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import cv2 5 | 6 | 7 | 8 | def show_mask(mask, ax, random_color=False): 9 | if random_color: 10 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 11 | else: 12 | color = np.array([30/255, 144/255, 255/255, 0.4]) 13 | h, w = mask.shape[-2:] 14 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 15 | ax.imshow(mask_image) 16 | 17 | 18 | def show_points(coords, labels, ax, marker_size=375): 19 | pos_points = coords[labels==1] 20 | neg_points = coords[labels==0] 21 | ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 22 | ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 23 | 24 | 25 | def show_box(box, ax, color="green"): 26 | x0, y0 = box[0], box[1] 27 | w, h = box[2] - box[0], box[3] - box[1] 28 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=color, facecolor=(0,0,0,0), lw=4)) 29 | # ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=color, facecolor=color, lw=4)) # fill the bbox 30 | 31 | 32 | # TODO: I added this helper here 33 | def get_bbox(mask): 34 | rows = np.any(mask, axis=1) 35 | cols = np.any(mask, axis=0) 36 | y0, y1 = np.where(rows)[0][[0, -1]] 37 | x0, x1 = np.where(cols)[0][[0, -1]] 38 | 39 | return [x0, y0, x1, y1] -------------------------------------------------------------------------------- /uncertainty_refine.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn import functional as F 4 | import os 5 | import cv2 6 | from tqdm import tqdm 7 | import argparse 8 | import matplotlib.pyplot as plt 9 | import warnings 10 | warnings.filterwarnings('ignore') 11 | from show import * 12 | from segment_anything import sam_model_registry, SamPredictor 13 | from itertools import combinations 14 | from scipy.spatial import ConvexHull 15 | import numpy as np 16 | import itertools 17 | import scipy 18 | from scipy.spatial.distance import cdist 19 | import numpy as np 20 | import matplotlib.pyplot as plt 21 | # this file has all the augmentation functions needed for the data augmentation 22 | 23 | import numpy as np 24 | import random 25 | 26 | import matplotlib.pyplot as plt 27 | import cv2 28 | import math 29 | import cv2 30 | import math 31 | # import numpy as np 32 | from bridson import poisson_disc_samples 33 | 34 | from bridson import poisson_disc_samples 35 | 36 | 37 | # the uncertainty-based FNPC block in hardcode version 38 | def uc_refine_hardcode(binary_mask, uncertainty_map, img, threshold_uc = 0.2, fn_alpha = 0.5, fn_beta = 0.7,fp_alpha = 0.5, fp_beta = 0.7 ): 39 | # based on uc_refine, but does not use the mean value; uses the intensity range directly 40 | # binary_mask: 1xDxD (0, 1), the pseudo label 41 | # uncertainty_map: DxD (0 - 0.25) 42 | # img: DxDx3 (0, 255), the original image 43 | 44 | # Average over the channels of img 45 | img_avg = np.mean(img, axis=2) 46 | 47 | # calculate the mean xt(yt) 48 | mean_value = np.mean(img_avg * binary_mask[0]) 49 | print("the mean value is:",
mean_value) 50 | 51 | # Get the uncertainty threshold value 52 | uc_threshold = np.min(uncertainty_map) + threshold_uc * (np.max(uncertainty_map) - np.min(uncertainty_map)) 53 | 54 | # Identify U_thre (turn the uncertainty map into a binary mask) 55 | U_thre = uncertainty_map > uc_threshold 56 | 57 | ################## do the FN correction ################### 58 | # get the 1 - yt 59 | # We know that the region outside the bbox should be true negative, so this should be written 60 | # as bbox - binary_mask. However, the current imperfect version does not affect the final results, since we 61 | # refine the final mask with the BBox in FNPC_Interface.py 62 | inverse_binary_mask = 1 - binary_mask # get the region that is not covered by the pseudo mask 63 | 64 | # get the UH = (1-yt)*U_thre 65 | FN_UH = U_thre * inverse_binary_mask[0] # DxD 66 | 67 | # get the xUH 68 | FN_xUH = img_avg * FN_UH # DxD 69 | 70 | # pseudo label FN refine 71 | # Create a boolean mask where True indicates the condition is met 72 | FN_condition_mask = (fn_alpha < FN_xUH) & (FN_xUH < fn_beta) 73 | 74 | # Use numpy's logical_or function to keep the old mask values where the condition is not met 75 | FN_new_mask = np.logical_or(binary_mask, FN_condition_mask) # broadcasts the DxD condition against the 1xDxD mask, so the result is 1xDxD 76 | FN_output_mask = FN_new_mask # 1xDxD 77 | 78 | ################## do the FP correction ################### 79 | FP_UH = U_thre * binary_mask[0] # DxD 80 | 81 | # get the xUH 82 | FP_xUH = img_avg * FP_UH # DxD 83 | # print(FP_xUH) 84 | 85 | # pseudo label FP refine 86 | # Create a boolean mask where True indicates the condition is met 87 | FP_condition_mask = ((fp_alpha > FP_xUH) | (FP_xUH > fp_beta)) & (FP_UH > 0) 88 | 89 | # Use numpy's logical_and to remove the pixels where the FP condition is met 90 | FP_new_mask = np.logical_and(FN_output_mask, np.logical_not( 91 | FP_condition_mask)) # broadcasts the DxD condition against the 1xDxD mask, so the result is 1xDxD 92 | 93 | # output_mask = np.expand_dims(new_mask, axis=0) 94 | FP_output_mask = FP_new_mask.astype(int) # 1xDxD 95 | return FN_output_mask, FN_UH, FN_xUH, FN_condition_mask, FP_output_mask, FP_UH, FP_xUH, FP_condition_mask 96 | 97 | # the uncertainty-based FNPC block in ratio version 98 | def uc_refine_correct(binary_mask, uncertainty_map, img, threshold_uc = 0.2, fn_alpha = 0.5, fn_beta = 0.7,fp_alpha = 0.5, fp_beta = 0.7 ): 99 | # binary_mask: 1xDxD (0, 1), the pseudo label 100 | # uncertainty_map: DxD (0 - 0.25) 101 | # img: DxDx3 (0, 255), the original image 102 | 103 | # Average over the channels of img 104 | img_avg = np.mean(img, axis=2) 105 | 106 | # calculate the mean xt(yt) 107 | mean_value = np.mean(img_avg[(binary_mask[0] > 0)]) # for the placenta task 108 | 109 | print("the mean value is:", mean_value) 110 | 111 | # Get the uncertainty threshold value 112 | uc_threshold = np.min(uncertainty_map) + threshold_uc * (np.max(uncertainty_map) - np.min(uncertainty_map)) 113 | 114 | # Identify U_thre (turn the uncertainty map into a binary mask) 115 | U_thre = uncertainty_map > uc_threshold 116 | 117 | ################## do the FN correction ################### 118 | # get the 1 - yt 119 | # We know that the region outside the bbox should be true negative, so this should be written 120 | # as bbox - binary_mask.
However, the current imperfect version does not affect the final results, since we 121 | # refine the final mask with the BBox in FNPC_Interface.py 122 | inverse_binary_mask = 1 - binary_mask # get the region that is not covered by the pseudo mask 123 | 124 | # get the UH = (1-yt)*U_thre 125 | FN_UH = U_thre * inverse_binary_mask[0] # DxD 126 | 127 | # get the xUH 128 | FN_xUH = img_avg * FN_UH # DxD 129 | 130 | # pseudo label FN refine 131 | # Create a boolean mask where True indicates the condition is met 132 | FN_condition_mask = (mean_value * fn_alpha < FN_xUH) & (FN_xUH < mean_value * fn_beta) 133 | 134 | # Use numpy's logical_or function to keep the old mask values where the condition is not met 135 | FN_new_mask = np.logical_or(binary_mask, FN_condition_mask) # broadcasts the DxD condition against the 1xDxD mask, so the result is 1xDxD 136 | 137 | # output_mask = np.expand_dims(new_mask, axis=0) 138 | FN_output_mask = FN_new_mask # 1xDxD 139 | 140 | ################## do the FP correction ################### 141 | FP_UH = U_thre * binary_mask[0] # DxD 142 | 143 | # get the xUH 144 | FP_xUH = img_avg * FP_UH # DxD 145 | 146 | # pseudo label FP refine 147 | # Create a boolean mask where True indicates the condition is met 148 | FP_condition_mask = ((mean_value * fp_alpha > FP_xUH) | (FP_xUH > mean_value * fp_beta)) & (FP_UH > 0) 149 | # FP_condition_mask = (FP_xUH > mean_value * fp_beta) 150 | 151 | # Use numpy's logical_and to remove the pixels where the FP condition is met 152 | FP_new_mask = np.logical_and(FN_output_mask, np.logical_not( 153 | FP_condition_mask)) # broadcasts the DxD condition against the 1xDxD mask, so the result is 1xDxD 154 | 155 | # output_mask = np.expand_dims(new_mask, axis=0) 156 | FP_output_mask = FP_new_mask.astype(int) # 1xDxD 157 | return FN_output_mask, FN_UH, FN_xUH, FN_condition_mask, FP_output_mask, FP_UH, FP_xUH, FP_condition_mask --------------------------------------------------------------------------------
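For orientation, the sketch below shows how the pieces documented above are typically wired together: SamPredictor embeds the image once, SAM is run once per jittered box prompt, the stack of predictions yields a mean pseudo label and a per-pixel variance map (consistent with the [0, 0.25] uncertainty range noted in the docstrings), and uc_refine_correct applies the FN and FP corrections. This is a minimal illustration, not the repository's pipeline: the image and mask paths, the ten-sample box jitter, and the +/-5 pixel range are assumptions; the actual prompt augmentation and evaluation logic live in prompt_aug.py and FNPC_Interface.py.

# Illustrative sketch only (not part of the repository). Paths and the simple
# box jitter below are hypothetical stand-ins for the real augmentation code.
import numpy as np
import cv2

from segment_anything import sam_model_registry, SamPredictor
from show import get_bbox
from uncertainty_refine import uc_refine_correct

sam = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
predictor = SamPredictor(sam)

img = cv2.cvtColor(cv2.imread("example_slice.png"), cv2.COLOR_BGR2RGB)  # HxWx3 uint8, hypothetical path
predictor.set_image(img)  # the image embedding is computed once and reused

rough = cv2.imread("example_rough_mask.png", 0) > 0   # rough annotation, hypothetical path
box = np.array(get_bbox(rough))                       # XYXY box prompt from the rough mask

# Run SAM once per jittered box and stack the binary predictions.
rng = np.random.default_rng(0)
preds = []
for _ in range(10):
    jittered = box + rng.integers(-5, 6, size=4)      # simple box augmentation (assumption)
    masks, scores, _ = predictor.predict(box=jittered, multimask_output=False)
    preds.append(masks[0].astype(np.float32))
preds = np.stack(preds)                               # N x H x W

p = preds.mean(axis=0)
binary_mask = (p > 0.5).astype(int)[None, ...]        # 1 x H x W pseudo label
uncertainty_map = p * (1.0 - p)                       # per-pixel variance, bounded by 0.25

outs = uc_refine_correct(binary_mask, uncertainty_map, img,
                         threshold_uc=0.2, fn_alpha=0.5, fn_beta=0.7,
                         fp_alpha=0.5, fp_beta=0.7)
FP_output_mask = outs[4]                              # FN-corrected, then FP-corrected mask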