├── .gitignore
├── DirectSAM
│   ├── trainer.py
│   └── utils.py
├── License
├── README.md
├── assets
│   ├── DirectSAM_qingming.jpg
│   ├── DirectSAM_visualizations.jpg
│   ├── ade20k_finetuning_visualization.jpg
│   ├── hkust_logo.png
│   ├── teaser.png
│   └── xiaobing_logo.jpg
└── requirements.txt

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

core*

# Distribution / packaging
.Python
env/
build/
dist/
eggs/
*.egg-info/
bin/
include/
lib/
local/
man/
share/
pip-wheel-metadata/
htmlcov/
.coverage
.tox/
nosetests.xml
coverage.xml

*.pt

runs/
__pycache__/

--------------------------------------------------------------------------------
/DirectSAM/trainer.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
from PIL import Image as PILImage

import torch
import torch.distributed as dist

from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation, TrainingArguments, Trainer


def annotation_to_label(label_map, line_thickness=3):
    """
    Convert a segmentation label map into a binary boundary map by drawing the
    contour of every labelled region.

    Parameters:
        label_map (PIL.Image): The input label map.
        line_thickness (int): The thickness of the lines drawn for the contours.

    Returns:
        PIL.Image: The output binary boundary label image (mode 'L', values 0/1).
    """
    label_map = np.array(label_map)
    all_contours = []
    for label_idx in np.unique(label_map):
        mask = (label_map == label_idx).astype(np.uint8)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        all_contours.append(contours)

    # Draw all contours onto one canvas and keep a single channel as the label.
    h, w = label_map.shape
    canvas = np.zeros((h, w, 3), dtype=np.uint8)
    for contours in all_contours:
        cv2.drawContours(canvas, contours, -1, (1, 1, 1), line_thickness)
    label = PILImage.fromarray(canvas[:, :, 0], mode='L')
    return label


def transforms(example_batch):
    # Applied lazily via `dataset.set_transform`, so `image_processor` (defined
    # in the __main__ block below) is available by the time batches are fetched.
    images = [x.convert("RGB") for x in example_batch["image"]]
    labels = [annotation_to_label(x) for x in example_batch["annotation"]]
    inputs = image_processor(images, labels, do_reduce_labels=False)
    return inputs


if __name__ == '__main__':

    dist.init_process_group(backend='nccl')

    dataset = load_dataset("scene_parse_150", split="train")
    dataset.set_transform(transforms)

    # Initialize from the pretrained DirectSAM checkpoint with a single output
    # channel (boundary probability).
    checkpoint = "chendelong/DirectSAM-1800px-0424"
    model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint, num_labels=1, ignore_mismatched_sizes=True)
    image_processor = AutoImageProcessor.from_pretrained(checkpoint, reduce_labels=True)

    input_resolution = 512
    image_processor.size['height'] = input_resolution
    image_processor.size['width'] = input_resolution

    if torch.distributed.get_rank() == 0:
        print(model)
        print(f"Number of parameters: {model.num_parameters()/1e6}M, trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6}M")
        print(dataset)

    training_args = TrainingArguments(
        output_dir='runs/finetune-directsam-ade20k-5ep-512px',
        learning_rate=5e-5,
        num_train_epochs=3,
        per_device_train_batch_size=4,
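
Note on launching: trainer.py calls dist.init_process_group(backend='nccl'), so it expects to be started by a distributed launcher rather than a plain python invocation. A minimal launch sketch follows; the GPU count and the choice of torchrun (rather than another launcher) are assumptions, not something the repository specifies:

    # run from the repository root; set --nproc_per_node to the number of available GPUs
    torchrun --nproc_per_node=4 DirectSAM/trainer.py
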
        gradient_accumulation_steps=1,
        save_total_limit=3,
        dataloader_num_workers=4,
        dataloader_prefetch_factor=4,
        save_strategy="epoch",
        do_eval=False,
        logging_steps=1,
        remove_unused_columns=False,
        push_to_hub=False,
        fp16=True
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
    )

    trainer.train()

--------------------------------------------------------------------------------
/DirectSAM/utils.py:
--------------------------------------------------------------------------------
import math
from itertools import product

import cv2
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt


def generate_crop_boxes(im_size, n_layers=1, overlap=0):
    """
    Generates a list of crop boxes of different sizes. Each layer
    has (2**i)**2 boxes for the ith layer. Only the boxes of the
    final layer (layer_id == n_layers) are returned.
    """

    crop_boxes, layer_idxs = [], []
    im_w, im_h = im_size

    # Original image
    crop_boxes.append([0, 0, im_w, im_h])
    layer_idxs.append(0)

    def crop_len(orig_len, n_crops, overlap):
        return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))

    for i_layer in range(n_layers):
        n_crops_per_side = 2 ** (i_layer + 1)
        # overlap = int(overlap_ratio * min(im_h, im_w) * (2 / n_crops_per_side))

        crop_w = crop_len(im_w, n_crops_per_side, overlap)
        crop_h = crop_len(im_h, n_crops_per_side, overlap)

        crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
        crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]

        # Crops in XYXY format, clipped to the image bounds
        for x0, y0 in product(crop_box_x0, crop_box_y0):
            box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
            crop_boxes.append(box)
            layer_idxs.append(i_layer + 1)

    # only keep boxes with layer_id == n_layers
    crop_boxes = [box for box, layer in zip(crop_boxes, layer_idxs) if layer == n_layers]

    return crop_boxes


def boundary_zero_padding(probs, p=15):
    # Suppress artifacts near crop borders (from https://arxiv.org/abs/2308.13779):
    # zero out a thin outer band and down-weight the band just inside it.
    zero_p = p // 3
    alpha_p = zero_p * 2

    probs[:, :alpha_p] *= 0.5
    probs[:, -alpha_p:] *= 0.5
    probs[:alpha_p, :] *= 0.5
    probs[-alpha_p:, :] *= 0.5

    probs[:, :zero_p] = 0
    probs[:, -zero_p:] = 0
    probs[:zero_p, :] = 0
    probs[-zero_p:, :] = 0

    return probs


def inference_single_image(image, image_processor, model, pyramid_layers=0, overlap=90, resolution=None):
    """
    Predict a boundary probability map for a PIL image. If pyramid_layers > 0,
    predictions on overlapping crops are accumulated on top of the global
    prediction and averaged.
    """
    if resolution:
        image_processor.size['height'] = resolution
        image_processor.size['width'] = resolution

    def run(image, bzp=0):
        encoding = image_processor(image, return_tensors="pt")
        pixel_values = encoding.pixel_values.to(model.device).to(model.dtype)

        with torch.no_grad():
            outputs = model(pixel_values=pixel_values)
            logits = outputs.logits.float().cpu()

        # Upsample the low-resolution logits back to the input image size.
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.size[::-1],
            mode="bilinear",
            align_corners=False,
        )
        probs = torch.sigmoid(upsampled_logits)[0, 0].detach().numpy()

        if bzp > 0:
            probs = boundary_zero_padding(probs, p=bzp)
        return probs

    global_probs = run(image)

    if pyramid_layers > 0:
        for layer in range(1, pyramid_layers + 1):
            boxes = generate_crop_boxes(image.size, n_layers=layer, overlap=overlap)
            for box in boxes:
                x1, y1, x2, y2 = box
                crop = image.crop(box)
                probs = run(crop, bzp=overlap)
                global_probs[y1:y2, x1:x2] += probs
        global_probs /= (pyramid_layers + 1)

    return global_probs


def probs_to_masks(probs, threshold=0.1):
    # Pixels below the boundary threshold form segment interiors; their connected
    # components are returned as boolean masks, largest first.
    binarized = (probs < threshold).astype(np.uint8)
    num_labels, labels = cv2.connectedComponents(binarized)
    masks = [labels == i for i in range(1, num_labels)]
    masks.sort(key=lambda x: x.sum(), reverse=True)
    return masks


def visualize_masks(image, masks):
    # Paint each mask with the mean color of the image pixels it covers.
    canvas = np.ones_like(image) * 255

    for i in range(len(masks)):
        mask = masks[i]
        color = np.mean(image[mask], axis=0)
        canvas[mask] = color
    return canvas


def resize_to_max_length(image, max_length):
    # Resize so that the longer side equals max_length, preserving aspect ratio.
    width, height = image.size
    if width > height:
        new_width = max_length
        new_height = int(height * (max_length / width))
    else:
        new_height = max_length
        new_width = int(width * (max_length / height))
    return image.resize((new_width, new_height))


def visualize_direct_sam_result(probs, image, show_reconstruction=True, threshold=0.01, mask_cutoff=256):

    plt.figure(figsize=(20, 10))
    plt.subplot(1, 2, 1)
    plt.imshow(probs, cmap='PuBuGn')
    plt.title('Boundary Probabilities')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(probs > threshold, cmap='PuBuGn')
    plt.title(f'Binarized Boundary Prediction (threshold={threshold})')
    plt.axis('off')

    plt.tight_layout()
    plt.show()

    if show_reconstruction:
        plt.figure(figsize=(20, 10))
        plt.subplot(1, 2, 1)
        plt.imshow(image)
        plt.title('Input Image')
        plt.axis('off')

        plt.subplot(1, 2, 2)
        masks = probs_to_masks(probs, threshold=threshold)
        masks = sorted(masks, key=lambda x: np.sum(x), reverse=True)[:mask_cutoff]

        segment_visualization = visualize_masks(np.array(image), masks)
        plt.imshow(segment_visualization)

        plt.title(f'{len(masks)} Subobject Segments')
        plt.axis('off')
        plt.tight_layout()
        plt.show()
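
A minimal end-to-end inference sketch using the helpers above, assuming it is run from the DirectSAM/ directory. The image path, the use of a CUDA device, and the pyramid/threshold settings are illustrative assumptions; the checkpoint name is the one used in trainer.py:

    import numpy as np
    from PIL import Image
    from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
    from utils import inference_single_image, probs_to_masks, visualize_masks, resize_to_max_length

    checkpoint = "chendelong/DirectSAM-1800px-0424"
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint).to("cuda").eval()

    # Load an image and cap its longer side to keep memory in check.
    image = Image.open("example.jpg").convert("RGB")   # hypothetical path
    image = resize_to_max_length(image, 1024)

    # Boundary probabilities, optionally refined with one pyramid layer of crops.
    probs = inference_single_image(image, image_processor, model, pyramid_layers=1, overlap=90)

    # Connected components of low-boundary-probability pixels become segments.
    masks = probs_to_masks(probs, threshold=0.1)
    reconstruction = visualize_masks(np.array(image), masks)
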
--------------------------------------------------------------------------------
/License:
--------------------------------------------------------------------------------
# Non-Commercial Research License

The pretrained model weights provided herein are made available strictly for non-commercial research purposes only.

This model was initialized from NVIDIA's SegFormer, which is subject to NVIDIA's [SegFormer License](https://github.com/NVlabs/SegFormer/blob/master/LICENSE), and trained on Meta's SA-1B dataset, which is subject to Meta's [SA-1B Dataset Research License](https://ai.meta.com/datasets/segment-anything-downloads/). Therefore, your use of this model must strictly adhere to the terms and conditions of both the SegFormer and SA-1B licenses.

Any commercial use is explicitly prohibited. Users must ensure compliance with the original licenses from NVIDIA and Meta when using or distributing derivative works based on this model.

No warranties or guarantees are provided with this model. The model is provided "as is."

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------