├── README.md ├── assets ├── 1.jpg └── ed_1.jpg └── edge_detect.py /README.md: -------------------------------------------------------------------------------- 1 | # Segment Anything Edge Detection 2 | 3 | This project reproduces the edge-detection application of Segment Anything, following the description in the [Segment Anything paper](https://ai.facebook.com/research/publications/segment-anything/) and [Segment Anything issue #226](https://github.com/facebookresearch/segment-anything/issues/226). Please let us know if you find any problems. 4 | 5 | 6 | ## Installation 7 | 8 | The code requires `python>=3.8`, as well as `pytorch>=1.7` and `torchvision>=0.8`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. Installing both PyTorch and TorchVision with CUDA support is strongly recommended. 9 | 10 | Install Segment Anything: 11 | 12 | ```shell 13 | pip install git+https://github.com/facebookresearch/segment-anything.git 14 | ``` 15 | 16 | Install OpenCV and tqdm (both are imported by `edge_detect.py`): 17 | 18 | ```shell 19 | pip install opencv-python tqdm 20 | ``` 21 | 22 | ## Run 23 | 24 | - Download the model weights 25 | 26 | Please download a checkpoint from the [Segment Anything](https://github.com/facebookresearch/segment-anything#model-checkpoints) repository (an awesome piece of work). By default the script loads `./weight/sam_vit_h_4b8939.pth` with model type `vit_h`. 27 | 28 | - Run edge detection 29 | 30 | ```shell 31 | python edge_detect.py --edge_dir EDGE_DIR --save_dir SAVE_DIR [--sam_checkpoint SAM_CHECKPOINT] [--model_type MODEL_TYPE] [--device DEVICE] 32 | ``` 33 | 34 | `EDGE_DIR` is the directory of input images; one edge map per image is written to `SAVE_DIR` under the same file name. By default the model runs on the CPU (`--device cpu`). 35 | 36 | ## Demo 37 | 38 |

39 | 40 | 41 |

"""Generate edge maps for a directory of images with Segment Anything (SAM).

Every image is segmented with SAM's automatic mask generator. The outline of
each predicted mask is extracted with a Sobel filter, and the per-mask outlines
are merged into one binary (0/255) edge image that is saved under the input
image's file name.

References:
    - Segment Anything project: https://github.com/facebookresearch/segment-anything
    - Segment Anything paper:
      https://ai.facebook.com/research/publications/segment-anything/
"""
import argparse
import os
import sys

import cv2
import numpy as np
import tqdm

# NOTE(review): this appends the parent of the *current working directory*
# (not of this file); presumably it lets a `segment_anything` checkout next to
# this project be imported without installing it.
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor  # noqa: E402,F401


def load_model(sam_checkpoint, model_type, device):
    """Build a SAM model and wrap it in an automatic mask generator.

    Args:
        sam_checkpoint: Path to the SAM checkpoint file.
        model_type: Key of ``sam_model_registry`` selecting the backbone
            (e.g. ``"vit_h"``); an unknown key raises ``KeyError``.
        device: Device the model is moved to (e.g. ``"cpu"`` or ``"cuda"``).

    Returns:
        A ``SamAutomaticMaskGenerator`` configured with the settings used for
        edge detection.
    """
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    return SamAutomaticMaskGenerator(
        model=sam,
        points_per_side=16,
        pred_iou_thresh=0.86,
        stability_score_thresh=0.7,
        crop_n_layers=0,
        crop_n_points_downscale_factor=2,
        min_mask_region_area=100,  # Requires open-cv to run post-processing
    )


def predict(net, image):
    """Run the mask generator on one image and return its list of mask records.

    ``image`` is passed to ``net.generate`` unchanged; ``main`` supplies an
    RGB array converted from OpenCV's BGR order.
    """
    return net.generate(image)


def get_edge_from_masks(masks, shape=None):
    """Merge the Sobel outlines of all masks into one binary edge map.

    Args:
        masks: Mask records from ``SamAutomaticMaskGenerator.generate``. Only
            each record's ``"segmentation"`` entry is read; it is assumed to
            be a 2-D mask array (boolean in SAM's default output mode).
        shape: Optional ``(H, W)`` of the output. Only needed when ``masks``
            is empty; otherwise the size is taken from the first mask.

    Returns:
        A ``uint8`` array of shape (H, W) holding 255 on mask outlines and 0
        elsewhere.

    Raises:
        ValueError: If ``masks`` is empty and ``shape`` is not given.
    """
    if masks:
        shape = masks[0]["segmentation"].shape
    elif shape is None:
        raise ValueError("no masks given and no output shape specified")

    # uint8 union of the per-mask outlines: overlapping outlines stay at 255
    # instead of summing past the 8-bit range.
    edge_image = np.zeros(tuple(shape[:2]), dtype=np.uint8)
    for mask in masks:
        gray_image = (np.array(mask["segmentation"]) * 255).astype(np.uint8)
        sobel_x = cv2.Sobel(gray_image, cv2.CV_64F, 1, 0, ksize=3)
        sobel_y = cv2.Sobel(gray_image, cv2.CV_64F, 0, 1, ksize=3)
        sobel = np.sqrt(sobel_x ** 2 + sobel_y ** 2)
        # Rescale the gradient magnitude to 0..255 so one fixed threshold
        # applies to every mask.
        sobel = cv2.normalize(sobel, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
        _, binary_image = cv2.threshold(sobel, 100, 255, cv2.THRESH_BINARY)
        np.maximum(edge_image, binary_image, out=edge_image)
    return edge_image


def main(edge_dir, save_dir, model_args):
    """Write an edge map for every readable image in ``edge_dir`` to ``save_dir``.

    Args:
        edge_dir: Directory of input images.
        save_dir: Output directory (created if missing); each edge map keeps
            its input image's file name.
        model_args: Object with ``sam_checkpoint``, ``model_type`` and
            ``device`` attributes (the parsed command-line namespace).

    Raises:
        FileNotFoundError: If ``edge_dir`` is not an existing directory.
    """
    # Validate paths before the (slow) model load so bad input fails fast.
    if not os.path.isdir(edge_dir):
        raise FileNotFoundError(f"{edge_dir} does not exist or is not a directory")
    os.makedirs(save_dir, exist_ok=True)

    net = load_model(model_args.sam_checkpoint, model_args.model_type, model_args.device)
    image_name_list = sorted(os.listdir(edge_dir))
    for image_name in tqdm.tqdm(image_name_list):
        image = cv2.imread(os.path.join(edge_dir, image_name))
        if image is None:
            # cv2.imread signals unreadable / non-image files by returning
            # None rather than raising; skip them instead of aborting the run.
            tqdm.tqdm.write(f"Skipping {image_name}: not a readable image")
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        masks = predict(net, image)
        edge_image = get_edge_from_masks(masks, shape=image.shape[:2])
        save_path = os.path.join(save_dir, image_name)
        if not cv2.imwrite(save_path, edge_image):
            tqdm.tqdm.write(f"Failed to write {save_path}")
    print("Finished !!!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog="Generation Edge",
                                     description="Generate Edge Image with Segment Anything and Sobel",
                                     allow_abbrev=True)
    parser.add_argument("--edge_dir", required=True, type=str,
                        help="directory of input images")
    parser.add_argument("--save_dir", required=True, type=str,
                        help="output directory for edge maps (created if missing)")
    parser.add_argument("--sam_checkpoint", default="./weight/sam_vit_h_4b8939.pth", type=str)
    # Reject unknown backbones at parse time instead of a KeyError after startup.
    parser.add_argument("--model_type", default="vit_h", type=str,
                        choices=sorted(sam_model_registry))
    parser.add_argument("--device", default="cpu", type=str)
    args = parser.parse_args()
    main(args.edge_dir, args.save_dir, args)