├── .gitignore
├── requirements.txt
├── datasets
│   ├── labels
│   │   ├── Apex_Legends_2022_05_02_20_43_42.txt
│   │   ├── Apex_Legends_2022_05_01_20_43_28.txt
│   │   └── Apex_Legends_2022_05_01_22_32_05.txt
│   ├── labels_refine
│   │   ├── Apex_Legends_2022_05_02_20_43_42.txt
│   │   ├── Apex_Legends_2022_05_01_20_43_28.txt
│   │   └── Apex_Legends_2022_05_01_22_32_05.txt
│   ├── demo_img.png
│   └── images
│       ├── Apex_Legends_2022_05_01_20_43_28.png
│       ├── Apex_Legends_2022_05_01_22_32_05.png
│       └── Apex_Legends_2022_05_02_20_43_42.png
├── README.md
├── visualization.py
├── utils.py
└── run.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__
sam_vit_h_4b8939.pth
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
streamlit
tqdm
click
opencv-python
--------------------------------------------------------------------------------
/datasets/labels/Apex_Legends_2022_05_02_20_43_42.txt:
--------------------------------------------------------------------------------
0 0.519025 0.488668 0.499225 0.73206
--------------------------------------------------------------------------------
/datasets/labels/Apex_Legends_2022_05_01_20_43_28.txt:
--------------------------------------------------------------------------------
0 0.537095 0.513901 0.466028 0.807452
--------------------------------------------------------------------------------
/datasets/labels/Apex_Legends_2022_05_01_22_32_05.txt:
--------------------------------------------------------------------------------
0 0.560685 0.487348 0.182621 0.164612
--------------------------------------------------------------------------------
/datasets/labels_refine/Apex_Legends_2022_05_02_20_43_42.txt:
--------------------------------------------------------------------------------
0 0.53046875 0.498 0.6265625 0.758
--------------------------------------------------------------------------------
/datasets/labels_refine/Apex_Legends_2022_05_01_20_43_28.txt:
--------------------------------------------------------------------------------
0 0.54140625 0.5145 0.4109375 0.745
--------------------------------------------------------------------------------
/datasets/labels_refine/Apex_Legends_2022_05_01_22_32_05.txt:
--------------------------------------------------------------------------------
0 0.525 0.49453125 0.165625 0.1796875
--------------------------------------------------------------------------------
/datasets/demo_img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NTUYWANG103/SAM-BoudingBox-Refine/HEAD/datasets/demo_img.png
--------------------------------------------------------------------------------
/datasets/images/Apex_Legends_2022_05_01_20_43_28.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NTUYWANG103/SAM-BoudingBox-Refine/HEAD/datasets/images/Apex_Legends_2022_05_01_20_43_28.png
--------------------------------------------------------------------------------
/datasets/images/Apex_Legends_2022_05_01_22_32_05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NTUYWANG103/SAM-BoudingBox-Refine/HEAD/datasets/images/Apex_Legends_2022_05_01_22_32_05.png
--------------------------------------------------------------------------------
/datasets/images/Apex_Legends_2022_05_02_20_43_42.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NTUYWANG103/SAM-BoudingBox-Refine/HEAD/datasets/images/Apex_Legends_2022_05_02_20_43_42.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SAM-BoudingBox-Refine
This repository refines YOLO-format bounding boxes with the Segment Anything Model (SAM): each original box is used as a prompt for SAM, and the box is then replaced by the minimal rectangle enclosing the predicted segmentation mask.

![demo_img](datasets/demo_img.png)
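# Label Format
Labels use the standard YOLO text format, one object per line, with all four box values normalized to [0, 1] by the image width and height (see `datasets/labels` for samples):

```
<class_id> <x_center> <y_center> <width> <height>
```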
# Environment
My environment uses Python 3.9 with CUDA 11.3:
```
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
pip install git+https://github.com/facebookresearch/segment-anything.git
pip install -r requirements.txt
```

# SAM Model Download
Click the links below to download the checkpoint for the corresponding model type.

- **`default` or `vit_h`: [ViT-H SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)**
- `vit_l`: [ViT-L SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth)
- `vit_b`: [ViT-B SAM model.](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth)

# Inference
```
python run.py --image_dir datasets/images --label_dir datasets/labels --refined_label_dir datasets/labels_refine --checkpoint sam_vit_h_4b8939.pth --model_type vit_h
```
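For each image, `run.py` prompts SAM with every original box and writes the minimal rectangle around the predicted mask back in YOLO format. On the bundled sample data, for example, the first label changes as follows:

```
# datasets/labels/Apex_Legends_2022_05_02_20_43_42.txt (original)
0 0.519025 0.488668 0.499225 0.73206

# datasets/labels_refine/Apex_Legends_2022_05_02_20_43_42.txt (refined)
0 0.53046875 0.498 0.6265625 0.758
```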
# Visualization
```
streamlit run visualization.py
```

# Star History

[![Star History Chart](https://api.star-history.com/svg?repos=NTUYWANG103/SAM-BoudingBox-Refine&type=Date)](https://star-history.com/#NTUYWANG103/SAM-BoudingBox-Refine&Date)
--------------------------------------------------------------------------------
/visualization.py:
--------------------------------------------------------------------------------
import streamlit as st
import cv2
import os
from utils import read_yolo_label, draw_bounding_boxes

# Streamlit interface
st.title('YOLO Bounding Box Visualizer')

# Input fields for directory paths
image_dir = st.text_input('Image Directory', 'datasets/images')
label_dir = st.text_input('Label Directory', 'datasets/labels')
refined_label_dir = st.text_input('Refined Label Directory', 'datasets/labels_refine')

if not os.path.isdir(image_dir) or not os.path.isdir(label_dir):
    st.error('Invalid directory path(s)')
else:
    image_files = sorted(os.listdir(image_dir))
    if image_files:
        # Choose an index to visualize
        max_index = len(image_files) - 1
        index = st.number_input('Image Index', min_value=0, max_value=max_index, value=0, step=1)

        image_name = image_files[index]
        image_path = os.path.join(image_dir, image_name)
        original_label_path = os.path.join(label_dir, os.path.splitext(image_name)[0] + '.txt')
        refined_label_path = os.path.join(refined_label_dir, os.path.splitext(image_name)[0] + '.txt')

        col1, col2 = st.columns(2)

        # Display original bounding boxes
        if os.path.exists(original_label_path):
            original_labels = read_yolo_label(original_label_path)
            original_image_with_boxes = draw_bounding_boxes(image_path, original_labels)
            col1.image(original_image_with_boxes, caption="Original Bounding Boxes", use_column_width=True)
        else:
            st.error('Label file not found')

        # Display refined bounding boxes
        if os.path.exists(refined_label_path):
            refined_labels = read_yolo_label(refined_label_path)
            refined_image_with_boxes = draw_bounding_boxes(image_path, refined_labels)
            col2.image(refined_image_with_boxes, caption="Refined Bounding Boxes", use_column_width=True)
        else:
            st.error('Refined label file not found')
    else:
        st.error('No images found in the specified directory')
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2

# Function to read YOLO labels
def read_yolo_label(label_path):
    labels = []
    with open(label_path, 'r') as file:
        for line in file:
            if not line.strip():
                continue
            class_id, x_center, y_center, width, height = map(float, line.split())
            labels.append((class_id, x_center, y_center, width, height))
    return labels

# Function to convert YOLO format to pixel rectangle coordinates
def yolo_to_coords(yolo_coords, image_shape):
    x_center, y_center, width, height = yolo_coords
    x_center *= image_shape[1]  # Scale by image width
    y_center *= image_shape[0]  # Scale by image height
    width *= image_shape[1]
    height *= image_shape[0]

    x_min = int(x_center - width / 2)
    y_min = int(y_center - height / 2)
    x_max = int(x_center + width / 2)
    y_max = int(y_center + height / 2)

    return (x_min, y_min, x_max, y_max)

# Function to convert rectangle coordinates to YOLO format
def coords_to_yolo(bboxes, image_shape):
    yolo_bboxes = []
    for bbox in bboxes:
        if bbox is None:
            # Keep a placeholder so indices stay aligned with the input labels
            yolo_bboxes.append(None)
            continue

        x_center = (bbox[0] + bbox[2]) / 2
        y_center = (bbox[1] + bbox[3]) / 2
        width = bbox[2] - bbox[0]
        height = bbox[3] - bbox[1]

        x_center /= image_shape[1]  # Normalize by image width
        y_center /= image_shape[0]  # Normalize by image height
        width /= image_shape[1]
        height /= image_shape[0]

        yolo_bboxes.append((x_center, y_center, width, height))

    return yolo_bboxes
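# Example round trip (hypothetical numbers, assuming a 1280x720 image,
# i.e. image_shape == (height, width) == (720, 1280)):
#   yolo_to_coords((0.5, 0.5, 0.2, 0.4), (720, 1280)) -> (512, 216, 768, 504)
#   coords_to_yolo([(512, 216, 768, 504)], (720, 1280)) -> [(0.5, 0.5, 0.2, 0.4)]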
# Function to find the minimal bounding rectangle for each segmentation mask
def find_minimal_rectangles(masks):
    minimal_rectangles = []
    for mask in masks:
        # Convert the boolean/0-1 mask to uint8
        mask = (mask * 255).astype(np.uint8)

        # Threshold at the midpoint to get a clean binary image
        _, binary_mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

        contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        if contours:
            bounding_rects = [cv2.boundingRect(c) for c in contours]
            x_min = min([x for x, y, w, h in bounding_rects])
            y_min = min([y for x, y, w, h in bounding_rects])
            x_max = max([x + w for x, y, w, h in bounding_rects])
            y_max = max([y + h for x, y, w, h in bounding_rects])

            minimal_rectangles.append((x_min, y_min, x_max, y_max))
        else:
            minimal_rectangles.append(None)

    return minimal_rectangles

# Function to draw bounding boxes on an image
def draw_bounding_boxes(image_path, labels):
    image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    for label in labels:
        _, x_center, y_center, width, height = label
        x_min, y_min, x_max, y_max = yolo_to_coords((x_center, y_center, width, height), image.shape)
        cv2.rectangle(image, (x_min, y_min), (x_max, y_max), (255, 0, 0), 2)
    return image
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
import cv2
import os
import torch
import click
from segment_anything import sam_model_registry, SamPredictor
from tqdm import tqdm
from utils import read_yolo_label, yolo_to_coords, coords_to_yolo, find_minimal_rectangles

@click.command()
@click.option('--image_dir', default='datasets/images', help='Directory containing images.')
@click.option('--label_dir', default='datasets/labels', help='Directory containing YOLO format labels.')
@click.option('--refined_label_dir', default='datasets/labels_refine', help='Directory to save refined labels.')
@click.option('--checkpoint', default='sam_vit_h_4b8939.pth', help='Path to the SAM model checkpoint.')
@click.option('--model_type', default='vit_h', help='Type of the SAM model.')
def refine_bounding_boxes(image_dir, label_dir, refined_label_dir, checkpoint, model_type):
    print(f"Refining bounding boxes for images in {image_dir}, saving to {refined_label_dir}...")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Initialize the Segment Anything model
    sam = sam_model_registry[model_type](checkpoint=checkpoint)
    sam.to(device=device)
    predictor = SamPredictor(sam)

    os.makedirs(refined_label_dir, exist_ok=True)

    for image_name in tqdm(os.listdir(image_dir)):
        image_path = os.path.join(image_dir, image_name)
        label_path = os.path.join(label_dir, os.path.splitext(image_name)[0] + '.txt')

        if not os.path.exists(label_path):
            print(f"Label file not found for {image_name}")
            continue

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_shape = image.shape[:2]
        predictor.set_image(image)

        labels = read_yolo_label(label_path)
        input_boxes = []

        if len(labels) == 0:
            yolo_bboxes = []
        else:
            for label in labels:
                class_id, x_center, y_center, width, height = label
                rect_coords = yolo_to_coords((x_center, y_center, width, height), image_shape)
                input_boxes.append(rect_coords)

            # Prompt SAM with all original boxes at once
            input_boxes = torch.tensor(input_boxes, device=predictor.device)
            transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
            masks, _, _ = predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes,
                multimask_output=False,
            )

            # Shrink each mask to its minimal enclosing rectangle
            minimal_rectangles = find_minimal_rectangles(masks.squeeze(1).cpu().numpy())
            yolo_bboxes = coords_to_yolo(minimal_rectangles, image_shape)

        # Save refined labels, skipping boxes whose mask produced no contours
        refined_label_path = os.path.join(refined_label_dir, os.path.splitext(image_name)[0] + '.txt')
        with open(refined_label_path, 'w') as file:
            for idx, yolo_bbox in enumerate(yolo_bboxes):
                if yolo_bbox:
                    class_id = labels[idx][0]
                    file.write(f"{int(class_id)} {yolo_bbox[0]} {yolo_bbox[1]} {yolo_bbox[2]} {yolo_bbox[3]}\n")

if __name__ == '__main__':
    refine_bounding_boxes()
--------------------------------------------------------------------------------