├── README.md
├── assets
├── 1.jpg
└── ed_1.jpg
└── edge_detect.py
/README.md:
--------------------------------------------------------------------------------
1 | # Segment Anything Edge Detection
2 |
3 | This project reproduces edge detection built on Segment Anything, following the approach described in the [Segment Anything paper](https://ai.facebook.com/research/publications/segment-anything/) and [Segment Anything issues#226](https://github.com/facebookresearch/segment-anything/issues/226). Please let us know if you run into any problems.
4 |
5 |
6 | ## Installation
7 |
8 | The code requires `python>=3.8`, as well as `pytorch>=1.7` and `torchvision>=0.8`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. Installing both PyTorch and TorchVision with CUDA support is strongly recommended.
9 |
10 | Install Segment Anything:
11 |
12 | ```shell
13 | pip install git+https://github.com/facebookresearch/segment-anything.git
14 | ```
15 |
16 | Install opencv:
17 |
18 | ```shell
19 | pip install opencv-python
20 | ```
21 |
22 | ## Run
23 |
24 | - Download weights
25 |
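26 | Download a model checkpoint as described in [Segment Anything](https://github.com/facebookresearch/segment-anything#model-checkpoints) (an awesome piece of work). By default the script loads `./weight/sam_vit_h_4b8939.pth` with `--model_type vit_h`; pass `--sam_checkpoint` and `--model_type` to use a different checkpoint.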
27 |
28 | - Run edge detection
29 |
30 | ```shell
31 | python edge_detect.py --edge_dir EDGE_DIR --save_dir SAVE_DIR [--sam_checkpoint SAM_CHECKPOINT] [--model_type MODEL_TYPE] [--device DEVICE]
32 | ```
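33 | `EDGE_DIR` is the directory of input images (not edges). For each image, an edge map with the same file name is written to `SAVE_DIR`, which is created if it does not exist.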
34 |
35 |
36 | ## Demo
37 |
38 | ![input image](assets/1.jpg)
39 | ![edge map](assets/ed_1.jpg)
40 |
41 |
42 |
43 | ## References
44 |
45 | - [Segment Anything project](https://github.com/facebookresearch/segment-anything)
46 | - [Segment Anything paper](https://ai.facebook.com/research/publications/segment-anything/)
47 |
--------------------------------------------------------------------------------
/assets/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Meng-Sang/segment-anything-edge-detect/86d9e78117990553a55848bcf322ac05a1c1aebf/assets/1.jpg
--------------------------------------------------------------------------------
/assets/ed_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Meng-Sang/segment-anything-edge-detect/86d9e78117990553a55848bcf322ac05a1c1aebf/assets/ed_1.jpg
--------------------------------------------------------------------------------
/edge_detect.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 |
5 | import cv2
6 | import numpy as np
7 | import tqdm
8 |
9 | sys.path.append("..")
10 | from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
11 |
12 |
def load_model(sam_checkpoint, model_type, device):
    """Build a SAM automatic mask generator ready for inference.

    Args:
        sam_checkpoint: Path to the SAM weights passed to the model builder.
        model_type: Key into ``sam_model_registry`` selecting the backbone
            (the CLI default is "vit_h").
        device: Device the model is moved to (the CLI default is "cpu").

    Returns:
        A ``SamAutomaticMaskGenerator`` wrapping the loaded model.
    """
    build_sam = sam_model_registry[model_type]
    model = build_sam(checkpoint=sam_checkpoint)
    model.to(device=device)

    # Point-grid density and quality filters for automatic mask generation.
    generator_settings = dict(
        points_per_side=16,
        pred_iou_thresh=0.86,
        stability_score_thresh=0.7,
        crop_n_layers=0,
        crop_n_points_downscale_factor=2,
        # Small-region removal is an OpenCV-based post-processing step.
        min_mask_region_area=100,
    )
    return SamAutomaticMaskGenerator(model=model, **generator_settings)
26 |
27 |
def predict(net, image):
    """Run mask generation for a single image.

    Args:
        net: Any object exposing ``generate(image)``, such as the generator
            returned by ``load_model``.
        image: Image handed unchanged to ``net.generate`` (``main`` passes an
            RGB array).

    Returns:
        Exactly what ``net.generate`` returns; for SAM this is the list of
        mask records consumed by ``get_edge_from_masks``.
    """
    return net.generate(image)
31 |
32 |
def get_edge_from_masks(masks, threshold=100, ksize=3, shape=None):
    """Combine the boundaries of SAM masks into a single edge map.

    For every mask record, the ``"segmentation"`` array is scaled by 255 and
    cast to ``uint8``. Its Sobel gradient magnitude is min-max normalised to
    0..255 and binarised at ``threshold``. The per-mask edges are merged by
    pixel-wise maximum, so a pixel lying on the boundary of several
    overlapping masks stays at 255 instead of accumulating past the 8-bit
    range.

    Args:
        masks: Sequence of mask records. Each must provide a ``"segmentation"``
            entry; presumably a 2-D boolean array, all of the same shape.
        threshold: Cut-off applied to the normalised gradient magnitude.
        ksize: Sobel kernel size passed to ``cv2.Sobel``.
        shape: Shape of the returned map. Defaults to the shape of the first
            mask; required when ``masks`` is empty.

    Returns:
        A float64 array of the given shape whose pixels are 0 or 255.

    Raises:
        ValueError: If ``masks`` is empty and ``shape`` is not given.
    """
    if shape is None:
        if not masks:
            raise ValueError("masks is empty; pass `shape` to get a blank edge map")
        shape = np.shape(masks[0]["segmentation"])
    edge_image = np.zeros(shape)
    for mask in masks:
        gray_image = (np.asarray(mask["segmentation"]) * 255).astype(np.uint8)
        sobel_x = cv2.Sobel(gray_image, cv2.CV_64F, 1, 0, ksize=ksize)
        sobel_y = cv2.Sobel(gray_image, cv2.CV_64F, 0, 1, ksize=ksize)
        magnitude = np.sqrt(sobel_x ** 2 + sobel_y ** 2)
        magnitude = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
        _, binary_image = cv2.threshold(magnitude, threshold, 255, cv2.THRESH_BINARY)
        # Union rather than sum: overlapping masks share boundary pixels, and
        # summing would produce 510, 765, ... which are not valid 8-bit values.
        np.maximum(edge_image, binary_image, out=edge_image)
    return edge_image
44 |
45 |
def main(edge_dir, save_dir, model_args):
    """Write a SAM-derived edge map for every readable image in ``edge_dir``.

    Args:
        edge_dir: Directory holding the input images (despite the name, it
            contains source images, not edges).
        save_dir: Output directory, created if missing. Each edge map is saved
            under the same file name as its source image.
        model_args: Namespace providing ``sam_checkpoint``, ``model_type`` and
            ``device``, forwarded to ``load_model``.

    Raises:
        FileNotFoundError: If ``edge_dir`` is not an existing directory.
    """
    # Validate before listing the directory (listdir would raise its own,
    # less helpful error) and before the slow model load.
    if not os.path.isdir(edge_dir):
        raise FileNotFoundError(f"{edge_dir} does not exist or is not a directory")
    os.makedirs(save_dir, exist_ok=True)
    # Sorted for a deterministic processing order across platforms.
    image_name_list = sorted(os.listdir(edge_dir))

    net = load_model(model_args.sam_checkpoint, model_args.model_type, model_args.device)
    for image_name in tqdm.tqdm(image_name_list):
        image = cv2.imread(os.path.join(edge_dir, image_name))
        # imread returns None (no exception) for sub-directories and files it
        # cannot decode; cvtColor would crash on that.
        if image is None:
            tqdm.tqdm.write(f"Skipping {image_name}: not a readable image")
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        masks = predict(net, image)
        if masks:
            edge_image = get_edge_from_masks(masks)
        else:
            # No masks were generated, so the edge map is blank.
            edge_image = np.zeros(image.shape[:2])
        # Clip to the 8-bit range and cast explicitly instead of relying on
        # imwrite's implicit conversion of a float array.
        edge_image = np.clip(edge_image, 0, 255).astype(np.uint8)
        out_path = os.path.join(save_dir, image_name)
        try:
            written = cv2.imwrite(out_path, edge_image)
        except cv2.error:
            # Raised e.g. when no encoder matches the file extension.
            written = False
        if not written:
            tqdm.tqdm.write(f"Failed to write edge map for {image_name}")
    print("Finished !!!")
60 |
61 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog="Generation Edge",
                                     description="Generate Edge Image with Segment Anything and Sobel",
                                     allow_abbrev=True)
    parser.add_argument("--edge_dir", required=True, type=str,
                        help="directory containing the input images")
    parser.add_argument("--save_dir", required=True, type=str,
                        help="directory where edge maps are written (created if missing)")
    parser.add_argument("--sam_checkpoint", default="./weight/sam_vit_h_4b8939.pth", type=str,
                        help="path to the SAM checkpoint (default: %(default)s)")
    # Restrict to registered model types so a typo fails at parse time with a
    # usage message instead of a KeyError inside load_model.
    parser.add_argument("--model_type", default="vit_h", type=str,
                        choices=sorted(sam_model_registry),
                        help="SAM backbone; should match the checkpoint (default: %(default)s)")
    parser.add_argument("--device", default="cpu", type=str,
                        help="device to run the model on, e.g. cpu or cuda (default: %(default)s)")
    args = parser.parse_args()
    main(args.edge_dir, args.save_dir, args)
74 |
--------------------------------------------------------------------------------