├── images ├── room.jpg └── demo_room.png ├── config └── default.yaml ├── README.md ├── feature_autogenerator.py ├── segment.py └── seganyclip.py /images/room.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minfenli/Segment-Anything-CLIP/HEAD/images/room.jpg -------------------------------------------------------------------------------- /images/demo_room.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minfenli/Segment-Anything-CLIP/HEAD/images/demo_room.png -------------------------------------------------------------------------------- /config/default.yaml: -------------------------------------------------------------------------------- 1 | use_openclip: False 2 | dowmsample: 8 3 | batch_size: 1024 4 | sam_checkpoint_loc: './checkpoints/' 5 | image_loc: './images/room.jpg' 6 | feature_loc: './room_8.pt' 7 | replace_feature: True 8 | output_dir: './output/' 9 | query: ['lamp', 'dog', 'table', 'plant', 'pillow', 'blanket', 'furniture', 'wooden floor', 'couch', 'room', 'drawing'] 10 | use_prompt_ensemble: True 11 | top_k: 2 12 | resolution: 3 13 | similarity_threshold: 0. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Per-pixel Features: Mating Segment-Anything with CLIP 2 | This repository aims to generate per-pixel features using pretrained models, [Segment-Anything](https://github.com/facebookresearch/segment-anything) and [CLIP](https://github.com/openai/CLIP). The pixel-aligned features are useful for downstream tasks such as visual grounding and VQA. First, we use the SAM to generate segmetation masks. Then, cropped images are sent into CLIP to extract semantic features. Finally, each pixel will be assigned semantic features according to its associated masks. 3 | 4 | Here, we show open-vocabulary segmentation without any training and finetuning. 5 | 6 | | Input Image | Segment Segmentation| 7 | | :---: | :---:| 8 | | image | image| 9 | 10 | 11 | ## Prepare 12 | 1. You may need to install [Segment-Anything](https://github.com/facebookresearch/segment-anything) and [CLIP](https://github.com/openai/CLIP) (or, [OpenCLIP](https://github.com/mlfoundations/open_clip)). 13 | 2. Download one of [SAM](https://github.com/facebookresearch/segment-anything#model-checkpoints) checkpoints from the SAM repository. 14 | 15 | ## Demo 16 | You can generate per-pixel features of an image. 17 | ``` 18 | python feature_autogenerator.py --image_path {image_path} --output_path {output_path} --output_name {feature_file_name} --checkpoint_dir {checkpoint_dir} 19 | ``` 20 | Or directly generate segmentation results by the given config file. 21 | ``` 22 | python segment.py --config_path {config_path} 23 | ``` 24 | 25 | ## Acknowledgement 26 | 1. [Segment-Anything](https://github.com/facebookresearch/segment-anything) 27 | 2. [CLIP](https://github.com/openai/CLIP) 28 | 3. [OpenCLIP](https://github.com/mlfoundations/open_clip) 29 | 30 | ## Citation 31 | If you find this work useful for your research, please consider citing this repo: 32 | 33 | ``` 34 | @misc{mingfengli_seganyclip, 35 | title={Per-pixel Features: Mating Segment-Anything with CLIP}, 36 | author={Li, Ming-Feng}, 37 | url={https://github.com/justin871030/Segment-Anything-CLIP}, 38 | year={2023} 39 | } 40 | ``` -------------------------------------------------------------------------------- /feature_autogenerator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import cv2 4 | import os 5 | 6 | from seganyclip import AutoSegmentAnything, OpenAICLIP, OpenCLIP, AutoSegAnyCLIP 7 | 8 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 9 | 10 | def main(args): 11 | clip = OpenCLIP() if args.openclip else OpenAICLIP() 12 | segany = AutoSegmentAnything(args.checkpoint_dir, "default", device) 13 | segclip = AutoSegAnyCLIP(segany, clip, device) 14 | 15 | image = cv2.imread(args.image_path) 16 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 17 | image = cv2.resize(image, (image.shape[1]//args.downsample, image.shape[0]//args.downsample), interpolation=cv2.INTER_AREA) 18 | 19 | segany.set_image(image) 20 | 21 | output_dir = os.path.join(args.output_dir, 'clip' if not args.openclip else 'openclip') 22 | 23 | if not os.path.isdir(output_dir): 24 | os.makedirs(output_dir) 25 | 26 | image_features_multires_crop = segclip.encode_image(bbox_crop=args.bbox_crop, extent_segmentation_mask=1) 27 | torch.save(image_features_multires_crop, 28 | os.path.join(output_dir, 29 | f'{args.output_name}_{args.downsample}.pt')) 30 | 31 | del image_features_multires_crop 32 | torch.cuda.empty_cache() 33 | 34 | if __name__ == "__main__": 35 | parser = argparse.ArgumentParser(description='Per-Pixel Semantic Feature Generator') 36 | parser.add_argument('--image_path', type=str, required=True, help='path of the image to generate features.') 37 | parser.add_argument('--output_dir', type=str, required=True, help='path to the directory to save features.') 38 | parser.add_argument('--output_name', type=str, required=True, help="file name of image features") 39 | parser.add_argument('--downsample', default=8, type=int, help='the scale of downsampling on the input image.') 40 | parser.add_argument('--openclip', default=False, type=bool, help='using OpenCLIP as feature extractor instead of openai-CLIP.') 41 | parser.add_argument('--bbox_crop', default=True, type=bool, help='using bbox instead of segmentation mask to crop the mask before encoding.') 42 | parser.add_argument('--checkpoint_dir', default='', type=str, help='path to the directory of SAM checkpoints.') 43 | parser.add_argument('--model_type', default='default', type=str, help='the type of the SAM.') 44 | 45 | args = parser.parse_args() 46 | 47 | main(args) -------------------------------------------------------------------------------- /segment.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import cv2 4 | import os 5 | import yaml 6 | import numpy as np 7 | 8 | import matplotlib.pyplot as plt 9 | import matplotlib.colors as mcolors 10 | 11 | from seganyclip import OpenAICLIP, OpenCLIP, AutoSegmentAnything, AutoSegAnyCLIP 12 | 13 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 14 | 15 | np.random.seed(2) 16 | color_samples = [np.array([np.random.randint(0, 255) for _ in range(3)]) for _ in range(100)] 17 | 18 | def save_segment_map(segment_map, classes, save_loc): 19 | H, W = segment_map.shape 20 | color_map = np.ones((H, W, 3), dtype='uint8') * 155 21 | n_classes = len(classes) 22 | classes = classes.copy() 23 | color_masks = [] 24 | for i in range(n_classes): 25 | if classes[i] in mcolors.cnames.keys(): 26 | h = mcolors.cnames[classes[i]].strip('#') 27 | color_mask = np.array([int(h[i:i+2], 16) for i in (0, 2, 4)]) 28 | else: 29 | color_mask = color_samples[i%len(color_samples)] 30 | color_map[segment_map==i] = color_mask 31 | color_masks.append((*(color_mask/255.),1)) 32 | classes.insert(0, 'none') 33 | color_masks.insert(0, (0.6, 0.6, 0.6, 1)) 34 | handles = [plt.Rectangle((0, 0), 0, 0, color=color_masks[i], label=classes[i]) for i in range(n_classes+1)] 35 | plt.legend(handles=handles, framealpha=0.4, title='class') 36 | plt.imshow(color_map) 37 | plt.axis('off') 38 | plt.savefig(save_loc, bbox_inches='tight', pad_inches=0) 39 | 40 | 41 | def main(args): 42 | with open(args.config_path) as f: 43 | config = yaml.load(f, Loader=yaml.FullLoader) 44 | 45 | output_dir = config['output_dir'] 46 | 47 | if not os.path.isdir(output_dir): 48 | os.makedirs(output_dir) 49 | 50 | feature_loc = config['feature_loc'] 51 | 52 | use_openclip = config['use_openclip'] 53 | 54 | querys = config['query'] 55 | use_prompt_ensemble = config['use_prompt_ensemble'] 56 | batch_size = config['batch_size'] 57 | top_k = config['top_k'] 58 | threshold = config['similarity_threshold'] 59 | 60 | assert(top_k >= 1) 61 | 62 | if os.path.exists(feature_loc) and not config['replace_feature']: 63 | print(f'Load model.') 64 | clip = OpenCLIP() if use_openclip else OpenAICLIP() 65 | print(f'Load features from {feature_loc}.') 66 | image_features = torch.load(feature_loc) 67 | else: 68 | print(f'Load model.') 69 | clip = OpenCLIP() if use_openclip else OpenAICLIP() 70 | segany = AutoSegmentAnything(config['sam_checkpoint_loc'], "default", device) 71 | segclip = AutoSegAnyCLIP(segany, clip, device) 72 | 73 | print(f'Not find features at {feature_loc}.') 74 | print(f'Generate and store features at \'{feature_loc}\'.') 75 | 76 | image_loc = config['image_loc'] 77 | dowmsample = config['dowmsample'] 78 | resolution = config['resolution'] 79 | 80 | assert(resolution > 0 and resolution <= 3) 81 | 82 | image = cv2.imread(image_loc) 83 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 84 | image = cv2.resize(image, (image.shape[1]//dowmsample, image.shape[0]//dowmsample), interpolation=cv2.INTER_AREA) 85 | 86 | segany.set_image(image) 87 | segclip.n_objs = resolution 88 | image_features = segclip.encode_image(extent_segmentation_mask=1, bbox_crop=False) 89 | 90 | torch.save(image_features, feature_loc) 91 | 92 | del segany, segclip 93 | torch.cuda.empty_cache() 94 | 95 | similarity_argmax_top_k = clip.predict_similarity_objects_with_feature_attention_batch(image_features, querys, use_prompt_ensemble, batch_size, top_k, threshold) 96 | 97 | similarity_argmax_top_k = [similarity_argmax.cpu().numpy() for similarity_argmax in similarity_argmax_top_k] 98 | 99 | for k, similarity_argmax in enumerate(similarity_argmax_top_k): 100 | save_segment_map(similarity_argmax, querys, os.path.join(output_dir, f'similarity_max{k}.png')) 101 | 102 | with open(os.path.join(output_dir, 'config.yaml'), 'w') as f: 103 | yaml.dump(config, f) 104 | 105 | 106 | if __name__ == "__main__": 107 | parser = argparse.ArgumentParser(description='Per-Pixel Semantic Feature Generator') 108 | parser.add_argument('--config_path', type=str, required=True, help='path of the config.') 109 | 110 | args = parser.parse_args() 111 | 112 | main(args) -------------------------------------------------------------------------------- /seganyclip.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator 3 | import numpy as np 4 | import open_clip 5 | import clip 6 | from PIL import Image 7 | import torch 8 | from tqdm import tqdm 9 | import os 10 | import math 11 | import cv2 12 | 13 | class SegmentAnything: 14 | def __init__(self, 15 | data_dir="/media/public_dataset/segany/", 16 | model_type="vit_h", 17 | device='cuda'): 18 | data_dir = data_dir 19 | model_type = model_type 20 | checkpoint_name = { 21 | "default": 'sam_vit_h_4b8939.pth', 22 | "vit_h": 'sam_vit_h_4b8939.pth', 23 | "vit_l": 'sam_vit_l_0b3195.pth', 24 | "vit_b": 'sam_vit_b_01ec64.pth' 25 | } 26 | sam_checkpoint = checkpoint_name[model_type] 27 | 28 | sam_checkpoint = data_dir + sam_checkpoint 29 | 30 | self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 31 | self.sam.to(device=device) 32 | 33 | self.predictor = SamPredictor(self.sam) 34 | self.device = device 35 | 36 | def set_image(self, image): 37 | self.image = image 38 | self.predictor.set_image(image) 39 | 40 | def predict_object_coord(self, coord, without_mask=False): 41 | # predict objects by the input coordinate (x, y) 42 | # return 3 or less object predicted by Seg-Anything with their scores 43 | # mask out pixels that are not related to objects if 'without_mask'==False 44 | # crop rectangles related to objects if 'without_mask'==True 45 | input_point = np.array([[coord[0], coord[1]]]) 46 | input_label = np.array([1]) 47 | masks, scores, logits = self.predictor.predict( 48 | point_coords=input_point, 49 | point_labels=input_label, 50 | multimask_output=True, 51 | ) 52 | 53 | objects = [] 54 | object_scores = [] 55 | for i, (mask, score) in enumerate(zip(masks, scores)): 56 | if not mask.any(): 57 | continue 58 | image_object = self.image.copy() 59 | if not without_mask: 60 | image_object[np.logical_not(mask)] = (255,255,255) 61 | xmin, ymin, xmax, ymax = self.from_mask_to_bbox(mask) 62 | image_object = image_object[ymin:ymax+1, xmin:xmax+1] 63 | objects.append(image_object) 64 | object_scores.append(score) 65 | 66 | return objects, object_scores 67 | 68 | def crop_multires_around_coord(self, coord, down=2, res=3): 69 | # crop rectangles in multi-resolutions around the input coordinate (x, y) 70 | 71 | H, W, _ = self.image.shape 72 | 73 | assert(down > 0 and res > 0) 74 | 75 | objects = [] 76 | H_, W_ = H//down, W//down 77 | for _ in range(res): 78 | H_, W_ = H_//down, W_//down 79 | image_object = self.image.copy() 80 | xmin, ymin, xmax, ymax = max(coord[0]-W_, 0), max(coord[1]-H_, 0), min(coord[0]+W_, W), min(coord[1]+H_, H) 81 | image_object = image_object[ymin:ymax, xmin:xmax] 82 | objects.append(image_object) 83 | 84 | return objects 85 | 86 | @staticmethod 87 | def from_mask_to_bbox(mask): 88 | mask_indices = np.where(mask) 89 | xmin, ymin, xmax, ymax = min(mask_indices[1]), min(mask_indices[0]), max(mask_indices[1]), max(mask_indices[0]) 90 | return xmin, ymin, xmax, ymax 91 | 92 | @staticmethod 93 | def make_per_pixel_point_prompt(image_size): 94 | # image_size: H x W 95 | x = np.arange(image_size[1]) 96 | y = np.arange(image_size[0]) 97 | xv, yv = np.meshgrid(x, y) 98 | 99 | points = np.stack([yv, xv], axis=-1).reshape(-1, 1, 2) 100 | labels = np.ones(image_size[0]*image_size[1]).reshape(-1, 1) 101 | return points, labels 102 | 103 | 104 | class AutoSegmentAnything: 105 | def __init__(self, 106 | data_dir="./checkpoints/", 107 | model_type="vit_h", 108 | device='cuda'): 109 | data_dir = data_dir 110 | model_type = model_type 111 | checkpoint_name = { 112 | "default": 'sam_vit_h_4b8939.pth', 113 | "vit_h": 'sam_vit_h_4b8939.pth', 114 | "vit_l": 'sam_vit_l_0b3195.pth', 115 | "vit_b": 'sam_vit_b_01ec64.pth' 116 | } 117 | sam_checkpoint = checkpoint_name[model_type] 118 | 119 | sam_checkpoint = os.path.join(data_dir, sam_checkpoint) 120 | 121 | self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 122 | self.sam.to(device=device) 123 | 124 | self.generator = SamAutomaticMaskGenerator( 125 | model=self.sam, 126 | points_per_side = 32, 127 | points_per_batch = 64, 128 | pred_iou_thresh = 0.86, 129 | stability_score_thresh = 0.92, 130 | stability_score_offset = 1.0, 131 | box_nms_thresh = 0.75, 132 | crop_n_layers = 2, 133 | crop_nms_thresh = 0.75, 134 | crop_overlap_ratio = 0.66, 135 | crop_n_points_downscale_factor = 2, 136 | min_mask_region_area = 100 137 | ) 138 | 139 | # self.generator = SamAutomaticMaskGenerator( 140 | # model=self.sam, 141 | # points_per_side=64, 142 | # pred_iou_thresh=0.8, 143 | # stability_score_thresh=0.8, 144 | # crop_n_layers=0, 145 | # crop_n_points_downscale_factor=0, 146 | # min_mask_region_area=100, # Requires open-cv to run post-processing 147 | # ) 148 | 149 | self.device = device 150 | 151 | def set_image(self, image): 152 | self.image = image 153 | 154 | def generate_masks(self): 155 | return self.generator.generate(self.image) 156 | 157 | class CLIP: 158 | def __init__(self, similarity_scale=10): 159 | self.similarity_scale = similarity_scale 160 | 161 | def set_similarity_scale(self, similarity_scale): 162 | self.similarity_scale = similarity_scale 163 | 164 | def encode_image(self, image): 165 | image = self.preprocess(Image.fromarray(image)).unsqueeze(0).to(device=self.device).half() 166 | with torch.no_grad(): 167 | image_features = self.model.encode_image(image) 168 | image_features /= image_features.norm(dim=-1, keepdim=True) 169 | return image_features 170 | 171 | def encode_text(self, text_list): 172 | text = self.tokenizer(text_list).to(device=self.device) 173 | with torch.no_grad(): 174 | text_features = self.model.encode_text(text) 175 | text_features /= text_features.norm(dim=-1, keepdim=True) 176 | return text_features 177 | 178 | def encode_text_with_prompt_ensemble(self, text_list, prompt_templates=None): 179 | # using default prompt templates for ImageNet 180 | if prompt_templates == None: 181 | # prompt_templates = ['a bad photo of a {}.', 'a photo of many {}.', 'a sculpture of a {}.', 'a photo of the hard to see {}.', 'a low resolution photo of the {}.', 'a rendering of a {}.', 'graffiti of a {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a tattoo of a {}.', 'the embroidered {}.', 'a photo of a hard to see {}.', 'a bright photo of a {}.', 'a photo of a clean {}.', 'a photo of a dirty {}.', 'a dark photo of the {}.', 'a drawing of a {}.', 'a photo of my {}.', 'the plastic {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a painting of the {}.', 'a painting of a {}.', 'a pixelated photo of the {}.', 'a sculpture of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a plastic {}.', 'a photo of the dirty {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a rendering of the {}.', 'a {} in a video game.', 'a photo of one {}.', 'a doodle of a {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'the origami {}.', 'the {} in a video game.', 'a sketch of a {}.', 'a doodle of the {}.', 'a origami {}.', 'a low resolution photo of a {}.', 'the toy {}.', 'a rendition of the {}.', 'a photo of the clean {}.', 'a photo of a large {}.', 'a rendition of a {}.', 'a photo of a nice {}.', 'a photo of a weird {}.', 'a blurry photo of a {}.', 'a cartoon {}.', 'art of a {}.', 'a sketch of the {}.', 'a embroidered {}.', 'a pixelated photo of a {}.', 'itap of the {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a plushie {}.', 'a photo of the nice {}.', 'a photo of the small {}.', 'a photo of the weird {}.', 'the cartoon {}.', 'art of the {}.', 'a drawing of the {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'the plushie {}.', 'a dark photo of a {}.', 'itap of a {}.', 'graffiti of the {}.', 'a toy {}.', 'itap of my {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'a tattoo of the {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.'] 182 | # easier ones 183 | prompt_templates = ['a photo of a {}.', 'This is a photo of a {}', 'This is a photo of a small {}', 'This is a photo of a medium {}', 'This is a photo of a large {}', 'This is a photo of a {}', 'This is a photo of a small {}', 'This is a photo of a medium {}', 'This is a photo of a large {}', 'a photo of a {} in the scene', 'a photo of a {} in the scene', 'There is a {} in the scene', 'There is the {} in the scene', 'This is a {} in the scene', 'This is the {} in the scene', 'This is one {} in the scene'] 184 | 185 | with torch.no_grad(): 186 | text_features = [] 187 | for t in text_list: 188 | prompted_t = [template.format(t) for template in prompt_templates] 189 | class_embeddings = self.encode_text(prompted_t) 190 | class_embedding = class_embeddings.mean(dim=0) 191 | class_embedding /= class_embedding.norm() 192 | text_features.append(class_embedding) 193 | text_features = torch.stack(text_features, dim=0) 194 | 195 | return text_features 196 | 197 | def predict_similarity_objects_with_feature_attention_batch(self, image_features, text, prompt_ensemble=False, batch_size=1024, top_k=1, threshold=0., projection=None): 198 | 199 | # use prompt templetes to prompt input texts 200 | text_features = self.encode_text(text) if not prompt_ensemble else self.encode_text_with_prompt_ensemble(text) 201 | if projection is not None: 202 | text_features = projection(text_features.float()).half() 203 | 204 | image_shape = image_features.shape[:2] 205 | 206 | batches = self.separate_image_features_batches(image_features, batch_size) 207 | batches_similarity = [[] for _ in range(top_k)] 208 | 209 | for image_features in batches: 210 | # don't need fuse features if only one dim. 211 | if image_features.shape[-2] != 1: 212 | feature_similarity = self.similarity_scale * (image_features @ text_features.T) 213 | feature_similarity = torch.moveaxis(feature_similarity, -1, 0).softmax(axis=-1) 214 | 215 | image_features = (image_features[None,...] * feature_similarity[...,None]).sum(axis=-2) 216 | # overall similarity of objects (that detected from a pixel) with the text prompt 217 | similarity = self.similarity_scale * (image_features @ text_features.T) 218 | similarity = torch.stack([similarity[i, ..., i] for i in range(len(similarity))]) 219 | else: 220 | similarity = torch.moveaxis((image_features @ text_features.T), -1, 0).squeeze(-1) 221 | 222 | for i in range(top_k): 223 | similarity_max = similarity.max(0).values 224 | similarity_argmax = similarity.argmax(0) 225 | similarity[similarity_argmax] = -1 226 | similarity_argmax[similarity_max <= threshold] = -1 227 | batches_similarity[i].append(similarity_argmax.cpu()) 228 | 229 | similarity_argmax = [self.merge_image_features_batches(batches_similarity[i], image_shape) for i in range(top_k)] 230 | 231 | # similarity (len(text), len(text)): similarity scores when each text as input (input_text_for_attention, H, W, relation_with_each_text) 232 | return similarity_argmax 233 | 234 | @staticmethod 235 | def separate_image_features_batches(image_features, batch_size=1024): 236 | H, W = image_features.shape[:2] 237 | image_features = image_features.reshape(H*W, *image_features.shape[2:]) 238 | batch_indices = [] 239 | idx = 0 240 | while idx < H*W: 241 | batch_indices.append((idx, min(idx+batch_size, H*W))) 242 | idx += batch_size 243 | 244 | batches = [] 245 | for start, end in batch_indices: 246 | batches.append(image_features[start:end]) 247 | 248 | return batches 249 | 250 | @staticmethod 251 | def merge_image_features_batches(batches, image_shape): 252 | H, W = image_shape 253 | image_features = torch.cat(batches, axis=0) 254 | 255 | return image_features.reshape(H, W, *image_features.shape[1:]) 256 | 257 | 258 | class OpenAICLIP(CLIP): 259 | def __init__(self, 260 | model_type='ViT-B/16', 261 | device='cuda'): 262 | super().__init__() 263 | 264 | self.model, self.preprocess = clip.load(model_type) 265 | self.model.to(device=device) 266 | self.tokenizer = clip.tokenize 267 | self.device = device 268 | 269 | class OpenCLIP(CLIP): 270 | def __init__(self, 271 | model_type='ViT-B-16', 272 | pretrained='laion2b_s34b_b88k', 273 | device='cuda'): 274 | super().__init__() 275 | 276 | self.model, _, self.preprocess = open_clip.create_model_and_transforms('ViT-B-16', pretrained=pretrained, precision="fp16") 277 | self.model.to(device=device) 278 | self.tokenizer = open_clip.get_tokenizer(model_type) 279 | self.device = device 280 | 281 | 282 | class SegAnyCLIP: 283 | def __init__(self, 284 | segany, 285 | clip, 286 | device='cuda'): 287 | self.segany = segany 288 | self.clip = clip 289 | self.clip_n_dims = 512 290 | self.n_objs = 3 291 | self.zeros = torch.zeros((1, self.clip_n_dims), device=device).half() 292 | self.device = device 293 | 294 | def encode_image(self, without_mask=False): 295 | # predect per-pixels features for 'image' in 'SegmentAnything' based on predicted objects 296 | # mask out pixels that are not related to objects if 'without_mask'==False 297 | # crop rectangles related to objects if 'without_mask'==True 298 | # output shape: (H, W, clip_n_dims, n_objs) 299 | 300 | image = self.segany.image 301 | coords = self.make_per_pixel_point_prompt(image.shape) 302 | 303 | H, W, _ = image.shape 304 | 305 | image_pixel_embeddings = [] 306 | 307 | with torch.no_grad(): 308 | for coord in tqdm(coords): 309 | objects, _ = self.segany.predict_object_coord(coord, without_mask=without_mask) 310 | pixel_embeddings = [] 311 | for single_object in objects: 312 | pixel_embeddings.append(self.clip.encode_image(single_object)) 313 | for _ in range(self.n_objs-len(objects)): 314 | pixel_embeddings.append(self.zeros) 315 | pixel_embeddings = torch.cat(pixel_embeddings, dim=0) 316 | image_pixel_embeddings.append(pixel_embeddings) 317 | 318 | image_pixel_embeddings = torch.cat(image_pixel_embeddings, axis=0).reshape(H, W, self.n_objs, self.clip_n_dims) 319 | return image_pixel_embeddings 320 | 321 | def encode_image_multires_crop(self): 322 | # predect per-pixels features for 'image' in 'SegmentAnything' by cropping 323 | # crop rectangles related to objects if 'without_mask'==True 324 | # output shape: (H, W, clip_n_dims, n_objs) 325 | 326 | image = self.segany.image 327 | coords = self.make_per_pixel_point_prompt(image.shape) 328 | 329 | H, W, _ = image.shape 330 | 331 | image_pixel_embeddings = [] 332 | 333 | for coord in tqdm(coords): 334 | objects = self.segany.crop_multires_around_coord(coord) 335 | pixel_embeddings = [] 336 | for single_object in objects: 337 | pixel_embeddings.append(self.clip.encode_image(single_object)) 338 | for _ in range(self.n_objs-len(objects)): 339 | pixel_embeddings.append(self.zeros) 340 | pixel_embeddings = torch.cat(pixel_embeddings, dim=0) 341 | image_pixel_embeddings.append(pixel_embeddings) 342 | 343 | image_pixel_embeddings = torch.cat(image_pixel_embeddings, axis=0).reshape(H, W, self.n_objs, self.clip_n_dims) 344 | return image_pixel_embeddings 345 | 346 | @staticmethod 347 | def make_per_pixel_point_prompt(image_size): 348 | # image_size: H x W 349 | x = np.arange(image_size[1]) 350 | y = np.arange(image_size[0]) 351 | xv, yv = np.meshgrid(x, y) 352 | points = np.stack([xv, yv], axis=-1).reshape(-1, 2) 353 | return points 354 | 355 | 356 | class AutoSegAnyCLIP: 357 | def __init__(self, 358 | segany, 359 | clip, 360 | device='cuda'): 361 | self.segany = segany 362 | self.clip = clip 363 | self.clip_n_dims = 512 364 | self.n_objs = 3 365 | self.zeros = torch.zeros((1, self.clip_n_dims), device=device).half() 366 | self.device = device 367 | 368 | def encode_image(self, bbox_crop=False, extent_segmentation_mask=0, blur=False): 369 | # predect per-pixels features for 'image' in 'SegmentAnything' based on predicted objects 370 | # mask out pixels that are not related to objects if 'bbox_crop'==False 371 | # crop rectangles related to objects if 'bbox_crop'==True 372 | # extent_segmentation_mask: extent pixels of an area from each segmentation mask for bigger coverage 373 | # output shape: (H, W, clip_n_dims, n_objs) 374 | 375 | image = self.segany.image 376 | 377 | H, W, _ = image.shape 378 | 379 | image_pixel_embeddings = [] 380 | 381 | masks = self.segany.generate_masks() 382 | 383 | check_mask_covered = torch.zeros(image.shape[:2]) 384 | 385 | for i, mask in enumerate(masks): 386 | masks[i]['segmentation'] = self.segmentmap_extent_multi(mask['segmentation'], extent_segmentation_mask) 387 | check_mask_covered[masks[i]['segmentation']] = 1 388 | 389 | point_to_mask = {} 390 | for y in range(H): 391 | for x in range(W): 392 | point_to_mask[(x, y)] = [] 393 | for i, mask in enumerate(masks): 394 | ys, xs = np.where(mask['segmentation']) 395 | for x, y in zip(xs, ys): 396 | point_to_mask[(x, y)] += [i] 397 | 398 | objects = [] 399 | object_scores = [] 400 | object_areas = [] 401 | 402 | background_color = np.array([255.,255.,255.])*0. 403 | 404 | for i, mask in enumerate(masks): 405 | image_object = image.copy().astype('float') 406 | if bbox_crop: 407 | if blur: 408 | image_blur = cv2.GaussianBlur(image_object, (5, 5), 0) 409 | image_object[np.logical_not(mask['segmentation'])] = image_blur[np.logical_not(mask['segmentation'])] 410 | image_object[np.logical_not(mask['segmentation'])] *= 0.75 411 | image_object[np.logical_not(mask['segmentation'])] += background_color * 0.25 412 | else: 413 | image_object[np.logical_not(mask['segmentation'])] = background_color 414 | image_object = image_object.astype('uint8') 415 | xmin, ymin, xmax, ymax = self.from_mask_to_bbox(mask['segmentation'], extent=0.01) 416 | image_object = image_object[ymin:ymax+1, xmin:xmax+1] 417 | objects.append(image_object) 418 | object_scores.append(mask['predicted_iou']) 419 | object_areas.append(mask['area']) 420 | 421 | for point in point_to_mask.keys(): 422 | point_to_mask[point] = sorted(point_to_mask[point], key=lambda x: (object_areas[x], object_scores[x]), reverse=True) 423 | 424 | objects_embeddings = [] 425 | for single_object in objects: 426 | objects_embeddings.append(self.clip.encode_image(single_object)) 427 | 428 | # self.zeros = self.clip.encode_image(image) 429 | # with torch.no_grad(): 430 | # image_crop = image.copy() 431 | # mask_covered = (check_mask_covered==1) 432 | # mask_not_covered = (check_mask_covered==0) 433 | # image_crop[mask_covered] = (0, 0, 0) 434 | # if mask_not_covered.any(): 435 | # xmin, ymin, xmax, ymax = self.from_mask_to_bbox(mask_not_covered) 436 | # image_crop = image_crop[ymin:ymax+1, xmin:xmax+1] 437 | # self.zeros = self.clip.encode_image(image_crop) 438 | 439 | image_pixel_embeddings = [] 440 | for y in range(H): 441 | for x in range(W): 442 | pixel_embeddings = [objects_embeddings[object_id] for object_id in point_to_mask[(x, y)][:self.n_objs]] 443 | for i in range(self.n_objs-len(pixel_embeddings)): 444 | pixel_embeddings.append(self.zeros) 445 | image_pixel_embeddings.append(torch.cat(pixel_embeddings, axis=0)) 446 | image_pixel_embeddings = torch.cat(image_pixel_embeddings, axis=0).reshape(H, W, self.n_objs, self.clip_n_dims) 447 | 448 | return image_pixel_embeddings 449 | 450 | def encode_image_concept_fusion(self, bbox_crop=False, extent_segmentation_mask=0): 451 | # predect per-pixels features for 'image' in 'SegmentAnything' based on predicted objects 452 | # mask out pixels that are not related to objects if 'bbox_crop'==False 453 | # crop rectangles related to objects if 'bbox_crop'==True 454 | # extent_segmentation_mask: extent pixels of area from each segmentation mask for bigger coverage 455 | # output shape: (H, W, clip_n_dims, n_objs) 456 | 457 | image = self.segany.image 458 | 459 | H, W, _ = image.shape 460 | 461 | image_pixel_embeddings = [] 462 | 463 | masks = self.segany.generate_masks() 464 | 465 | for i, mask in enumerate(masks): 466 | masks[i]['segmentation'] = self.segmentmap_extent_multi(mask['segmentation'], extent_segmentation_mask) 467 | 468 | point_to_mask = {} 469 | for y in range(H): 470 | for x in range(W): 471 | point_to_mask[(x, y)] = [] 472 | for i, mask in enumerate(masks): 473 | ys, xs = np.where(mask['segmentation']) 474 | for x, y in zip(xs, ys): 475 | point_to_mask[(x, y)] += [i] 476 | 477 | objects = [] 478 | object_scores = [] 479 | object_areas = [] 480 | for i, mask in enumerate(masks): 481 | image_object = image.copy() 482 | if not bbox_crop: 483 | image_object[np.logical_not(mask['segmentation'])] = (255,255,255) 484 | xmin, ymin, xmax, ymax = self.from_mask_to_bbox(mask['segmentation']) 485 | image_object = image_object[ymin:ymax+1, xmin:xmax+1] 486 | objects.append(image_object) 487 | object_scores.append(mask['predicted_iou']) 488 | object_areas.append(mask['area']) 489 | 490 | for point in point_to_mask.keys(): 491 | point_to_mask[point] = sorted(point_to_mask[point], key=lambda x: (object_areas[x], object_scores[x]), reverse=True) 492 | 493 | objects_embeddings = [] 494 | for single_object in objects: 495 | objects_embeddings.append(self.clip.encode_image(single_object)) 496 | 497 | image_embeddings = self.clip.encode_image(image) 498 | objects_embeddings = torch.cat(objects_embeddings, axis=0) 499 | objects_local_global_similarity = (image_embeddings @ objects_embeddings.T).squeeze(0) 500 | objects_cross_similarity = (objects_embeddings @ objects_embeddings.T) 501 | objects_self_similarity = torch.stack([objects_cross_similarity[i, i] for i in range(len(objects_cross_similarity))]) 502 | objects_avg_cross_similarity = ((objects_cross_similarity.sum(axis=-1) - objects_self_similarity)) / (len(objects_cross_similarity)-1) 503 | 504 | t = 1 505 | w_global = ((objects_local_global_similarity + objects_avg_cross_similarity)/t).softmax(-1) 506 | 507 | objects_embeddings_fusion = ((w_global[:, None] * image_embeddings) + ((1 - w_global[:, None])*objects_embeddings)) 508 | objects_embeddings_fusion /= objects_embeddings_fusion.norm(dim=-1, keepdim=True) 509 | 510 | image_pixel_embeddings = [] 511 | for y in range(H): 512 | for x in range(W): 513 | pixel_embeddings = [objects_embeddings_fusion[object_id][None, :] for object_id in point_to_mask[(x, y)][:self.n_objs]] 514 | for i in range(self.n_objs-len(pixel_embeddings)): 515 | pixel_embeddings.append(self.zeros) 516 | image_pixel_embeddings.append(torch.cat(pixel_embeddings, axis=0)) 517 | image_pixel_embeddings = torch.cat(image_pixel_embeddings, axis=0).reshape(H, W, self.n_objs, self.clip_n_dims) 518 | 519 | return image_pixel_embeddings 520 | 521 | @staticmethod 522 | def from_mask_to_bbox(mask, extent=0, sqrt=False): 523 | H, W = mask.shape[:2] 524 | mask_indices = np.where(mask) 525 | xmin, ymin, xmax, ymax = min(mask_indices[1]), min(mask_indices[0]), max(mask_indices[1]), max(mask_indices[0]) 526 | if extent > 0: 527 | if sqrt: 528 | x_extent, y_extent = math.ceil(math.sqrt((xmax-xmin)*extent)), math.ceil(math.sqrt((ymax-ymin)*extent)) 529 | # x_extent, y_extent = max(x_extent, y_extent), max(x_extent, y_extent) 530 | else: 531 | x_extent, y_extent = math.ceil((xmax-xmin)*extent), math.ceil((ymax-ymin)*extent) 532 | # x_extent, y_extent = max(x_extent, y_extent), max(x_extent, y_extent) 533 | xmin, ymin, xmax, ymax = max(xmin-x_extent, 0), max(ymin-y_extent, 0), min(xmax+x_extent, W-1), min(ymax+y_extent, H-1) 534 | return xmin, ymin, xmax, ymax 535 | 536 | @staticmethod 537 | def segmentmap_extent(segmentmap): 538 | segmentmap_extent = segmentmap.copy() 539 | H, W = segmentmap.shape 540 | ys, xs = np.where(segmentmap) 541 | for x, y in zip(xs, ys): 542 | extents = [(x, max(0, y-1)), (max(0, x-1), y), (min(W-1, x+1), y), (x, min(H-1, y+1))] 543 | for x, y in extents: 544 | segmentmap_extent[y, x] = True 545 | return segmentmap_extent 546 | 547 | def segmentmap_extent_multi(self, segmentmap, time=2): 548 | segmentmap_temp = segmentmap.copy() 549 | for _ in range(time): 550 | segmentmap_temp = self.segmentmap_extent(segmentmap_temp) 551 | return segmentmap_temp --------------------------------------------------------------------------------