├── images
│   ├── room.jpg
│   └── demo_room.png
├── config
│   └── default.yaml
├── README.md
├── feature_autogenerator.py
├── segment.py
└── seganyclip.py
/images/room.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minfenli/Segment-Anything-CLIP/HEAD/images/room.jpg
--------------------------------------------------------------------------------
/images/demo_room.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minfenli/Segment-Anything-CLIP/HEAD/images/demo_room.png
--------------------------------------------------------------------------------
/config/default.yaml:
--------------------------------------------------------------------------------
1 | use_openclip: False
2 | downsample: 8
3 | batch_size: 1024
4 | sam_checkpoint_loc: './checkpoints/'
5 | image_loc: './images/room.jpg'
6 | feature_loc: './room_8.pt'
7 | replace_feature: True
8 | output_dir: './output/'
9 | query: ['lamp', 'dog', 'table', 'plant', 'pillow', 'blanket', 'furniture', 'wooden floor', 'couch', 'room', 'drawing']
10 | use_prompt_ensemble: True
11 | top_k: 2
12 | resolution: 3
13 | similarity_threshold: 0.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Per-pixel Features: Mating Segment-Anything with CLIP
2 | This repository generates per-pixel features using the pretrained models [Segment-Anything](https://github.com/facebookresearch/segment-anything) and [CLIP](https://github.com/openai/CLIP). The pixel-aligned features are useful for downstream tasks such as visual grounding and VQA. First, SAM generates segmentation masks. Then, the cropped images are fed into CLIP to extract semantic features. Finally, each pixel is assigned semantic features according to its associated masks.
3 |
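In code, the pipeline looks roughly like the sketch below (a minimal sketch that mirrors `feature_autogenerator.py`; the checkpoint directory and image path are placeholders taken from `config/default.yaml`, and a CUDA device is assumed):

```
import cv2
import torch

from seganyclip import AutoSegmentAnything, OpenAICLIP, AutoSegAnyCLIP

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 1) SAM proposes segmentation masks for the whole image.
segany = AutoSegmentAnything('./checkpoints/', 'default', device)
# 2) CLIP encodes a crop for every mask.
segclip = AutoSegAnyCLIP(segany, OpenAICLIP(), device)

image = cv2.cvtColor(cv2.imread('./images/room.jpg'), cv2.COLOR_BGR2RGB)
# The scripts downsample the input (default factor 8) to keep runtime manageable.
image = cv2.resize(image, (image.shape[1] // 8, image.shape[0] // 8), interpolation=cv2.INTER_AREA)
segany.set_image(image)

# 3) Each pixel collects the CLIP features of the masks that cover it.
features = segclip.encode_image()  # shape (H, W, n_objs, 512)
```
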
4 | Here, we show open-vocabulary segmentation without any training or finetuning.
5 |
6 | | Input Image | Segmentation Result |
7 | | :---: | :---: |
8 | | ![input](images/room.jpg) | ![segmentation](images/demo_room.png) |
9 |
10 |
11 | ## Prepare
12 | 1. Install [Segment-Anything](https://github.com/facebookresearch/segment-anything) and [CLIP](https://github.com/openai/CLIP) (or [OpenCLIP](https://github.com/mlfoundations/open_clip)).
13 | 2. Download one of the [SAM checkpoints](https://github.com/facebookresearch/segment-anything#model-checkpoints) from the SAM repository; a quick check that both models load is sketched below.
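A minimal sketch for confirming that both models load (it assumes the ViT-H checkpoint `sam_vit_h_4b8939.pth` was saved under `./checkpoints/`):

```
import clip
import torch
from segment_anything import sam_model_registry

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the SAM ViT-H checkpoint and the CLIP ViT-B/16 weights used by this repository.
sam = sam_model_registry['vit_h'](checkpoint='./checkpoints/sam_vit_h_4b8939.pth').to(device)
clip_model, preprocess = clip.load('ViT-B/16', device=device)
print('SAM and CLIP loaded.')
```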
14 |
15 | ## Demo
16 | You can generate per-pixel features for an image:
17 | ```
18 | python feature_autogenerator.py --image_path {image_path} --output_dir {output_dir} --output_name {feature_file_name} --checkpoint_dir {checkpoint_dir}
19 | ```
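The saved `.pt` file is a tensor of per-pixel CLIP features with shape `(H, W, n_objs, 512)`, where `n_objs` (3 by default) is the number of masks kept per pixel. For example, assuming the command above was run with `--output_name room`, `--downsample 8`, and `--output_dir ./output/` (placeholder values; the file lands in an `openclip/` subfolder instead of `clip/` when `--openclip` is used), the features can be inspected like this:

```
import torch

# Each pixel stores CLIP features for up to n_objs masks that cover it.
features = torch.load('./output/clip/room_8.pt', map_location='cpu')
print(features.shape)  # torch.Size([H, W, 3, 512])
```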
20 | Or directly generate segmentation results from a config file (see `config/default.yaml`):
21 | ```
22 | python segment.py --config_path {config_path}
23 | ```
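`segment.py` reads the config, (re)generates the per-pixel features if needed, encodes the text queries with CLIP, and writes one `similarity_max{k}.png` segmentation map per top-k rank (plus a copy of the config) into `output_dir`. The core query step looks roughly like this minimal sketch (it reuses `OpenAICLIP` from `seganyclip.py` and the feature file from the example above; the class list is illustrative):

```
import torch
from seganyclip import OpenAICLIP

clip_model = OpenAICLIP()
features = torch.load('./output/clip/room_8.pt')  # (H, W, n_objs, 512)

# Per-pixel argmax over the text queries, one map per top-k rank.
segment_maps = clip_model.predict_similarity_objects_with_feature_attention_batch(
    features, ['lamp', 'table', 'couch'], prompt_ensemble=True, batch_size=1024, top_k=1)
print(segment_maps[0].shape)  # (H, W) map of class indices, -1 where below the similarity threshold
```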
24 |
25 | ## Acknowledgement
26 | 1. [Segment-Anything](https://github.com/facebookresearch/segment-anything)
27 | 2. [CLIP](https://github.com/openai/CLIP)
28 | 3. [OpenCLIP](https://github.com/mlfoundations/open_clip)
29 |
30 | ## Citation
31 | If you find this work useful for your research, please consider citing this repo:
32 |
33 | ```
34 | @misc{mingfengli_seganyclip,
35 | title={Per-pixel Features: Mating Segment-Anything with CLIP},
36 | author={Li, Ming-Feng},
37 | url={https://github.com/justin871030/Segment-Anything-CLIP},
38 | year={2023}
39 | }
40 | ```
--------------------------------------------------------------------------------
/feature_autogenerator.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import cv2
4 | import os
5 |
6 | from seganyclip import AutoSegmentAnything, OpenAICLIP, OpenCLIP, AutoSegAnyCLIP
7 |
8 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
9 |
10 | def main(args):
11 | clip = OpenCLIP() if args.openclip else OpenAICLIP()
12 |     segany = AutoSegmentAnything(args.checkpoint_dir, args.model_type, device)
13 | segclip = AutoSegAnyCLIP(segany, clip, device)
14 |
15 | image = cv2.imread(args.image_path)
16 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
17 | image = cv2.resize(image, (image.shape[1]//args.downsample, image.shape[0]//args.downsample), interpolation=cv2.INTER_AREA)
18 |
19 | segany.set_image(image)
20 |
21 | output_dir = os.path.join(args.output_dir, 'clip' if not args.openclip else 'openclip')
22 |
23 | if not os.path.isdir(output_dir):
24 | os.makedirs(output_dir)
25 |
26 | image_features_multires_crop = segclip.encode_image(bbox_crop=args.bbox_crop, extent_segmentation_mask=1)
27 | torch.save(image_features_multires_crop,
28 | os.path.join(output_dir,
29 | f'{args.output_name}_{args.downsample}.pt'))
30 |
31 | del image_features_multires_crop
32 | torch.cuda.empty_cache()
33 |
34 | if __name__ == "__main__":
35 | parser = argparse.ArgumentParser(description='Per-Pixel Semantic Feature Generator')
36 | parser.add_argument('--image_path', type=str, required=True, help='path of the image to generate features.')
37 | parser.add_argument('--output_dir', type=str, required=True, help='path to the directory to save features.')
38 | parser.add_argument('--output_name', type=str, required=True, help="file name of image features")
39 | parser.add_argument('--downsample', default=8, type=int, help='the scale of downsampling on the input image.')
40 |     parser.add_argument('--openclip', action='store_true', help='use OpenCLIP as the feature extractor instead of OpenAI CLIP.')
41 |     parser.add_argument('--bbox_crop', default=True, action=argparse.BooleanOptionalAction, help='crop each object with its bounding box (keeping background context) instead of masking out the background before encoding; pass --no-bbox_crop to disable.')
42 | parser.add_argument('--checkpoint_dir', default='', type=str, help='path to the directory of SAM checkpoints.')
43 |     parser.add_argument('--model_type', default='default', type=str, help='the type of SAM model: default, vit_h, vit_l, or vit_b.')
44 |
45 | args = parser.parse_args()
46 |
47 | main(args)
--------------------------------------------------------------------------------
/segment.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import cv2
4 | import os
5 | import yaml
6 | import numpy as np
7 |
8 | import matplotlib.pyplot as plt
9 | import matplotlib.colors as mcolors
10 |
11 | from seganyclip import OpenAICLIP, OpenCLIP, AutoSegmentAnything, AutoSegAnyCLIP
12 |
13 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
14 |
15 | np.random.seed(2)
16 | color_samples = [np.array([np.random.randint(0, 255) for _ in range(3)]) for _ in range(100)]
17 |
18 | def save_segment_map(segment_map, classes, save_loc):
19 | H, W = segment_map.shape
20 | color_map = np.ones((H, W, 3), dtype='uint8') * 155
21 | n_classes = len(classes)
22 | classes = classes.copy()
23 | color_masks = []
24 | for i in range(n_classes):
25 | if classes[i] in mcolors.cnames.keys():
26 | h = mcolors.cnames[classes[i]].strip('#')
27 | color_mask = np.array([int(h[i:i+2], 16) for i in (0, 2, 4)])
28 | else:
29 | color_mask = color_samples[i%len(color_samples)]
30 | color_map[segment_map==i] = color_mask
31 | color_masks.append((*(color_mask/255.),1))
32 | classes.insert(0, 'none')
33 | color_masks.insert(0, (0.6, 0.6, 0.6, 1))
34 | handles = [plt.Rectangle((0, 0), 0, 0, color=color_masks[i], label=classes[i]) for i in range(n_classes+1)]
35 | plt.legend(handles=handles, framealpha=0.4, title='class')
36 | plt.imshow(color_map)
37 | plt.axis('off')
38 |     plt.savefig(save_loc, bbox_inches='tight', pad_inches=0)
39 |     plt.close()
40 |
41 | def main(args):
42 | with open(args.config_path) as f:
43 | config = yaml.load(f, Loader=yaml.FullLoader)
44 |
45 | output_dir = config['output_dir']
46 |
47 | if not os.path.isdir(output_dir):
48 | os.makedirs(output_dir)
49 |
50 | feature_loc = config['feature_loc']
51 |
52 | use_openclip = config['use_openclip']
53 |
54 | querys = config['query']
55 | use_prompt_ensemble = config['use_prompt_ensemble']
56 | batch_size = config['batch_size']
57 | top_k = config['top_k']
58 | threshold = config['similarity_threshold']
59 |
60 | assert(top_k >= 1)
61 |
62 | if os.path.exists(feature_loc) and not config['replace_feature']:
63 | print(f'Load model.')
64 | clip = OpenCLIP() if use_openclip else OpenAICLIP()
65 | print(f'Load features from {feature_loc}.')
66 | image_features = torch.load(feature_loc)
67 | else:
68 | print(f'Load model.')
69 | clip = OpenCLIP() if use_openclip else OpenAICLIP()
70 | segany = AutoSegmentAnything(config['sam_checkpoint_loc'], "default", device)
71 | segclip = AutoSegAnyCLIP(segany, clip, device)
72 |
73 |         print(f'No reusable features found at {feature_loc}.')
74 |         print(f'Generating and saving features to \'{feature_loc}\'.')
75 |
76 | image_loc = config['image_loc']
77 |         downsample = config['downsample']
78 | resolution = config['resolution']
79 |
80 | assert(resolution > 0 and resolution <= 3)
81 |
82 | image = cv2.imread(image_loc)
83 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
84 |         image = cv2.resize(image, (image.shape[1]//downsample, image.shape[0]//downsample), interpolation=cv2.INTER_AREA)
85 |
86 | segany.set_image(image)
87 | segclip.n_objs = resolution
88 | image_features = segclip.encode_image(extent_segmentation_mask=1, bbox_crop=False)
89 |
90 | torch.save(image_features, feature_loc)
91 |
92 | del segany, segclip
93 | torch.cuda.empty_cache()
94 |
95 | similarity_argmax_top_k = clip.predict_similarity_objects_with_feature_attention_batch(image_features, querys, use_prompt_ensemble, batch_size, top_k, threshold)
96 |
97 | similarity_argmax_top_k = [similarity_argmax.cpu().numpy() for similarity_argmax in similarity_argmax_top_k]
98 |
99 | for k, similarity_argmax in enumerate(similarity_argmax_top_k):
100 | save_segment_map(similarity_argmax, querys, os.path.join(output_dir, f'similarity_max{k}.png'))
101 |
102 | with open(os.path.join(output_dir, 'config.yaml'), 'w') as f:
103 | yaml.dump(config, f)
104 |
105 |
106 | if __name__ == "__main__":
107 | parser = argparse.ArgumentParser(description='Per-Pixel Semantic Feature Generator')
108 | parser.add_argument('--config_path', type=str, required=True, help='path of the config.')
109 |
110 | args = parser.parse_args()
111 |
112 | main(args)
--------------------------------------------------------------------------------
/seganyclip.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
3 | import numpy as np
4 | import open_clip
5 | import clip
6 | from PIL import Image
7 | import torch
8 | from tqdm import tqdm
9 | import os
10 | import math
11 | import cv2
12 |
13 | class SegmentAnything:
14 | def __init__(self,
15 | data_dir="/media/public_dataset/segany/",
16 | model_type="vit_h",
17 | device='cuda'):
18 | data_dir = data_dir
19 | model_type = model_type
20 | checkpoint_name = {
21 | "default": 'sam_vit_h_4b8939.pth',
22 | "vit_h": 'sam_vit_h_4b8939.pth',
23 | "vit_l": 'sam_vit_l_0b3195.pth',
24 | "vit_b": 'sam_vit_b_01ec64.pth'
25 | }
26 | sam_checkpoint = checkpoint_name[model_type]
27 |
28 | sam_checkpoint = data_dir + sam_checkpoint
29 |
30 | self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
31 | self.sam.to(device=device)
32 |
33 | self.predictor = SamPredictor(self.sam)
34 | self.device = device
35 |
36 | def set_image(self, image):
37 | self.image = image
38 | self.predictor.set_image(image)
39 |
40 | def predict_object_coord(self, coord, without_mask=False):
41 | # predict objects by the input coordinate (x, y)
42 |         # return up to 3 objects predicted by Segment-Anything, with their scores
43 | # mask out pixels that are not related to objects if 'without_mask'==False
44 | # crop rectangles related to objects if 'without_mask'==True
45 | input_point = np.array([[coord[0], coord[1]]])
46 | input_label = np.array([1])
47 | masks, scores, logits = self.predictor.predict(
48 | point_coords=input_point,
49 | point_labels=input_label,
50 | multimask_output=True,
51 | )
52 |
53 | objects = []
54 | object_scores = []
55 | for i, (mask, score) in enumerate(zip(masks, scores)):
56 | if not mask.any():
57 | continue
58 | image_object = self.image.copy()
59 | if not without_mask:
60 | image_object[np.logical_not(mask)] = (255,255,255)
61 | xmin, ymin, xmax, ymax = self.from_mask_to_bbox(mask)
62 | image_object = image_object[ymin:ymax+1, xmin:xmax+1]
63 | objects.append(image_object)
64 | object_scores.append(score)
65 |
66 | return objects, object_scores
67 |
68 | def crop_multires_around_coord(self, coord, down=2, res=3):
69 | # crop rectangles in multi-resolutions around the input coordinate (x, y)
70 |
71 | H, W, _ = self.image.shape
72 |
73 | assert(down > 0 and res > 0)
74 |
75 | objects = []
76 | H_, W_ = H//down, W//down
77 | for _ in range(res):
78 | H_, W_ = H_//down, W_//down
79 | image_object = self.image.copy()
80 | xmin, ymin, xmax, ymax = max(coord[0]-W_, 0), max(coord[1]-H_, 0), min(coord[0]+W_, W), min(coord[1]+H_, H)
81 | image_object = image_object[ymin:ymax, xmin:xmax]
82 | objects.append(image_object)
83 |
84 | return objects
85 |
86 | @staticmethod
87 | def from_mask_to_bbox(mask):
88 | mask_indices = np.where(mask)
89 | xmin, ymin, xmax, ymax = min(mask_indices[1]), min(mask_indices[0]), max(mask_indices[1]), max(mask_indices[0])
90 | return xmin, ymin, xmax, ymax
91 |
92 | @staticmethod
93 | def make_per_pixel_point_prompt(image_size):
94 | # image_size: H x W
95 | x = np.arange(image_size[1])
96 | y = np.arange(image_size[0])
97 | xv, yv = np.meshgrid(x, y)
98 |
99 | points = np.stack([yv, xv], axis=-1).reshape(-1, 1, 2)
100 | labels = np.ones(image_size[0]*image_size[1]).reshape(-1, 1)
101 | return points, labels
102 |
103 |
104 | class AutoSegmentAnything:
105 | def __init__(self,
106 | data_dir="./checkpoints/",
107 | model_type="vit_h",
108 | device='cuda'):
109 | data_dir = data_dir
110 | model_type = model_type
111 | checkpoint_name = {
112 | "default": 'sam_vit_h_4b8939.pth',
113 | "vit_h": 'sam_vit_h_4b8939.pth',
114 | "vit_l": 'sam_vit_l_0b3195.pth',
115 | "vit_b": 'sam_vit_b_01ec64.pth'
116 | }
117 | sam_checkpoint = checkpoint_name[model_type]
118 |
119 | sam_checkpoint = os.path.join(data_dir, sam_checkpoint)
120 |
121 | self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
122 | self.sam.to(device=device)
123 |
124 | self.generator = SamAutomaticMaskGenerator(
125 | model=self.sam,
126 | points_per_side = 32,
127 | points_per_batch = 64,
128 | pred_iou_thresh = 0.86,
129 | stability_score_thresh = 0.92,
130 | stability_score_offset = 1.0,
131 | box_nms_thresh = 0.75,
132 | crop_n_layers = 2,
133 | crop_nms_thresh = 0.75,
134 | crop_overlap_ratio = 0.66,
135 | crop_n_points_downscale_factor = 2,
136 | min_mask_region_area = 100
137 | )
138 |
139 | # self.generator = SamAutomaticMaskGenerator(
140 | # model=self.sam,
141 | # points_per_side=64,
142 | # pred_iou_thresh=0.8,
143 | # stability_score_thresh=0.8,
144 | # crop_n_layers=0,
145 | # crop_n_points_downscale_factor=0,
146 | # min_mask_region_area=100, # Requires open-cv to run post-processing
147 | # )
148 |
149 | self.device = device
150 |
151 | def set_image(self, image):
152 | self.image = image
153 |
154 | def generate_masks(self):
155 | return self.generator.generate(self.image)
156 |
157 | class CLIP:
158 | def __init__(self, similarity_scale=10):
159 | self.similarity_scale = similarity_scale
160 |
161 | def set_similarity_scale(self, similarity_scale):
162 | self.similarity_scale = similarity_scale
163 |
164 | def encode_image(self, image):
165 | image = self.preprocess(Image.fromarray(image)).unsqueeze(0).to(device=self.device).half()
166 | with torch.no_grad():
167 | image_features = self.model.encode_image(image)
168 | image_features /= image_features.norm(dim=-1, keepdim=True)
169 | return image_features
170 |
171 | def encode_text(self, text_list):
172 | text = self.tokenizer(text_list).to(device=self.device)
173 | with torch.no_grad():
174 | text_features = self.model.encode_text(text)
175 | text_features /= text_features.norm(dim=-1, keepdim=True)
176 | return text_features
177 |
178 | def encode_text_with_prompt_ensemble(self, text_list, prompt_templates=None):
179 | # using default prompt templates for ImageNet
180 |         if prompt_templates is None:
181 | # prompt_templates = ['a bad photo of a {}.', 'a photo of many {}.', 'a sculpture of a {}.', 'a photo of the hard to see {}.', 'a low resolution photo of the {}.', 'a rendering of a {}.', 'graffiti of a {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a tattoo of a {}.', 'the embroidered {}.', 'a photo of a hard to see {}.', 'a bright photo of a {}.', 'a photo of a clean {}.', 'a photo of a dirty {}.', 'a dark photo of the {}.', 'a drawing of a {}.', 'a photo of my {}.', 'the plastic {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a painting of the {}.', 'a painting of a {}.', 'a pixelated photo of the {}.', 'a sculpture of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a plastic {}.', 'a photo of the dirty {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a rendering of the {}.', 'a {} in a video game.', 'a photo of one {}.', 'a doodle of a {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'the origami {}.', 'the {} in a video game.', 'a sketch of a {}.', 'a doodle of the {}.', 'a origami {}.', 'a low resolution photo of a {}.', 'the toy {}.', 'a rendition of the {}.', 'a photo of the clean {}.', 'a photo of a large {}.', 'a rendition of a {}.', 'a photo of a nice {}.', 'a photo of a weird {}.', 'a blurry photo of a {}.', 'a cartoon {}.', 'art of a {}.', 'a sketch of the {}.', 'a embroidered {}.', 'a pixelated photo of a {}.', 'itap of the {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a plushie {}.', 'a photo of the nice {}.', 'a photo of the small {}.', 'a photo of the weird {}.', 'the cartoon {}.', 'art of the {}.', 'a drawing of the {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'the plushie {}.', 'a dark photo of a {}.', 'itap of a {}.', 'graffiti of the {}.', 'a toy {}.', 'itap of my {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'a tattoo of the {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.']
182 | # easier ones
183 | prompt_templates = ['a photo of a {}.', 'This is a photo of a {}', 'This is a photo of a small {}', 'This is a photo of a medium {}', 'This is a photo of a large {}', 'This is a photo of a {}', 'This is a photo of a small {}', 'This is a photo of a medium {}', 'This is a photo of a large {}', 'a photo of a {} in the scene', 'a photo of a {} in the scene', 'There is a {} in the scene', 'There is the {} in the scene', 'This is a {} in the scene', 'This is the {} in the scene', 'This is one {} in the scene']
184 |
185 | with torch.no_grad():
186 | text_features = []
187 | for t in text_list:
188 | prompted_t = [template.format(t) for template in prompt_templates]
189 | class_embeddings = self.encode_text(prompted_t)
190 | class_embedding = class_embeddings.mean(dim=0)
191 | class_embedding /= class_embedding.norm()
192 | text_features.append(class_embedding)
193 | text_features = torch.stack(text_features, dim=0)
194 |
195 | return text_features
196 |
197 | def predict_similarity_objects_with_feature_attention_batch(self, image_features, text, prompt_ensemble=False, batch_size=1024, top_k=1, threshold=0., projection=None):
198 |
199 |         # use prompt templates to prompt the input texts
200 | text_features = self.encode_text(text) if not prompt_ensemble else self.encode_text_with_prompt_ensemble(text)
201 | if projection is not None:
202 | text_features = projection(text_features.float()).half()
203 |
204 | image_shape = image_features.shape[:2]
205 |
206 | batches = self.separate_image_features_batches(image_features, batch_size)
207 | batches_similarity = [[] for _ in range(top_k)]
208 |
209 | for image_features in batches:
210 |             # no need to fuse features if there is only one object feature per pixel
211 | if image_features.shape[-2] != 1:
212 | feature_similarity = self.similarity_scale * (image_features @ text_features.T)
213 | feature_similarity = torch.moveaxis(feature_similarity, -1, 0).softmax(axis=-1)
214 |
215 | image_features = (image_features[None,...] * feature_similarity[...,None]).sum(axis=-2)
216 |                 # overall similarity of the objects detected at a pixel with each text prompt
217 | similarity = self.similarity_scale * (image_features @ text_features.T)
218 | similarity = torch.stack([similarity[i, ..., i] for i in range(len(similarity))])
219 | else:
220 | similarity = torch.moveaxis((image_features @ text_features.T), -1, 0).squeeze(-1)
221 |
222 | for i in range(top_k):
223 | similarity_max = similarity.max(0).values
224 | similarity_argmax = similarity.argmax(0)
225 |                 similarity[similarity_argmax, torch.arange(similarity.shape[1], device=similarity.device)] = -1  # mask out the current best class per pixel before taking the next-best
226 | similarity_argmax[similarity_max <= threshold] = -1
227 | batches_similarity[i].append(similarity_argmax.cpu())
228 |
229 | similarity_argmax = [self.merge_image_features_batches(batches_similarity[i], image_shape) for i in range(top_k)]
230 |
231 |         # returns a list of top_k maps of shape (H, W); each entry is the index of the best-matching text per pixel (-1 if below threshold)
232 | return similarity_argmax
233 |
234 | @staticmethod
235 | def separate_image_features_batches(image_features, batch_size=1024):
236 | H, W = image_features.shape[:2]
237 | image_features = image_features.reshape(H*W, *image_features.shape[2:])
238 | batch_indices = []
239 | idx = 0
240 | while idx < H*W:
241 | batch_indices.append((idx, min(idx+batch_size, H*W)))
242 | idx += batch_size
243 |
244 | batches = []
245 | for start, end in batch_indices:
246 | batches.append(image_features[start:end])
247 |
248 | return batches
249 |
250 | @staticmethod
251 | def merge_image_features_batches(batches, image_shape):
252 | H, W = image_shape
253 | image_features = torch.cat(batches, axis=0)
254 |
255 | return image_features.reshape(H, W, *image_features.shape[1:])
256 |
257 |
258 | class OpenAICLIP(CLIP):
259 | def __init__(self,
260 | model_type='ViT-B/16',
261 | device='cuda'):
262 | super().__init__()
263 |
264 | self.model, self.preprocess = clip.load(model_type)
265 | self.model.to(device=device)
266 | self.tokenizer = clip.tokenize
267 | self.device = device
268 |
269 | class OpenCLIP(CLIP):
270 | def __init__(self,
271 | model_type='ViT-B-16',
272 | pretrained='laion2b_s34b_b88k',
273 | device='cuda'):
274 | super().__init__()
275 |
276 |         self.model, _, self.preprocess = open_clip.create_model_and_transforms(model_type, pretrained=pretrained, precision="fp16")
277 | self.model.to(device=device)
278 | self.tokenizer = open_clip.get_tokenizer(model_type)
279 | self.device = device
280 |
281 |
282 | class SegAnyCLIP:
283 | def __init__(self,
284 | segany,
285 | clip,
286 | device='cuda'):
287 | self.segany = segany
288 | self.clip = clip
289 | self.clip_n_dims = 512
290 | self.n_objs = 3
291 | self.zeros = torch.zeros((1, self.clip_n_dims), device=device).half()
292 | self.device = device
293 |
294 | def encode_image(self, without_mask=False):
295 |         # predict per-pixel features for the 'image' in 'SegmentAnything' based on predicted objects
296 | # mask out pixels that are not related to objects if 'without_mask'==False
297 | # crop rectangles related to objects if 'without_mask'==True
298 |         # output shape: (H, W, n_objs, clip_n_dims)
299 |
300 | image = self.segany.image
301 | coords = self.make_per_pixel_point_prompt(image.shape)
302 |
303 | H, W, _ = image.shape
304 |
305 | image_pixel_embeddings = []
306 |
307 | with torch.no_grad():
308 | for coord in tqdm(coords):
309 | objects, _ = self.segany.predict_object_coord(coord, without_mask=without_mask)
310 | pixel_embeddings = []
311 | for single_object in objects:
312 | pixel_embeddings.append(self.clip.encode_image(single_object))
313 | for _ in range(self.n_objs-len(objects)):
314 | pixel_embeddings.append(self.zeros)
315 | pixel_embeddings = torch.cat(pixel_embeddings, dim=0)
316 | image_pixel_embeddings.append(pixel_embeddings)
317 |
318 | image_pixel_embeddings = torch.cat(image_pixel_embeddings, axis=0).reshape(H, W, self.n_objs, self.clip_n_dims)
319 | return image_pixel_embeddings
320 |
321 | def encode_image_multires_crop(self):
322 |         # predict per-pixel features for the 'image' in 'SegmentAnything' by cropping
323 |         # rectangles at multiple resolutions around each pixel
324 |         # output shape: (H, W, n_objs, clip_n_dims)
325 |
326 | image = self.segany.image
327 | coords = self.make_per_pixel_point_prompt(image.shape)
328 |
329 | H, W, _ = image.shape
330 |
331 | image_pixel_embeddings = []
332 |
333 | for coord in tqdm(coords):
334 | objects = self.segany.crop_multires_around_coord(coord)
335 | pixel_embeddings = []
336 | for single_object in objects:
337 | pixel_embeddings.append(self.clip.encode_image(single_object))
338 | for _ in range(self.n_objs-len(objects)):
339 | pixel_embeddings.append(self.zeros)
340 | pixel_embeddings = torch.cat(pixel_embeddings, dim=0)
341 | image_pixel_embeddings.append(pixel_embeddings)
342 |
343 | image_pixel_embeddings = torch.cat(image_pixel_embeddings, axis=0).reshape(H, W, self.n_objs, self.clip_n_dims)
344 | return image_pixel_embeddings
345 |
346 | @staticmethod
347 | def make_per_pixel_point_prompt(image_size):
348 | # image_size: H x W
349 | x = np.arange(image_size[1])
350 | y = np.arange(image_size[0])
351 | xv, yv = np.meshgrid(x, y)
352 | points = np.stack([xv, yv], axis=-1).reshape(-1, 2)
353 | return points
354 |
355 |
356 | class AutoSegAnyCLIP:
357 | def __init__(self,
358 | segany,
359 | clip,
360 | device='cuda'):
361 | self.segany = segany
362 | self.clip = clip
363 | self.clip_n_dims = 512
364 | self.n_objs = 3
365 | self.zeros = torch.zeros((1, self.clip_n_dims), device=device).half()
366 | self.device = device
367 |
368 | def encode_image(self, bbox_crop=False, extent_segmentation_mask=0, blur=False):
369 |         # predict per-pixel features for the 'image' in 'SegmentAnything' based on predicted objects
370 |         # mask out pixels that are not related to objects if 'bbox_crop'==False
371 |         # keep a darkened (or blurred) background inside the bounding-box crop if 'bbox_crop'==True
372 |         # extent_segmentation_mask: number of one-pixel dilations applied to each segmentation mask for bigger coverage
373 |         # output shape: (H, W, n_objs, clip_n_dims)
374 |
375 | image = self.segany.image
376 |
377 | H, W, _ = image.shape
378 |
379 | image_pixel_embeddings = []
380 |
381 | masks = self.segany.generate_masks()
382 |
383 | check_mask_covered = torch.zeros(image.shape[:2])
384 |
385 | for i, mask in enumerate(masks):
386 | masks[i]['segmentation'] = self.segmentmap_extent_multi(mask['segmentation'], extent_segmentation_mask)
387 | check_mask_covered[masks[i]['segmentation']] = 1
388 |
389 | point_to_mask = {}
390 | for y in range(H):
391 | for x in range(W):
392 | point_to_mask[(x, y)] = []
393 | for i, mask in enumerate(masks):
394 | ys, xs = np.where(mask['segmentation'])
395 | for x, y in zip(xs, ys):
396 | point_to_mask[(x, y)] += [i]
397 |
398 | objects = []
399 | object_scores = []
400 | object_areas = []
401 |
402 | background_color = np.array([255.,255.,255.])*0.
403 |
404 | for i, mask in enumerate(masks):
405 | image_object = image.copy().astype('float')
406 | if bbox_crop:
407 | if blur:
408 | image_blur = cv2.GaussianBlur(image_object, (5, 5), 0)
409 | image_object[np.logical_not(mask['segmentation'])] = image_blur[np.logical_not(mask['segmentation'])]
410 | image_object[np.logical_not(mask['segmentation'])] *= 0.75
411 | image_object[np.logical_not(mask['segmentation'])] += background_color * 0.25
412 | else:
413 | image_object[np.logical_not(mask['segmentation'])] = background_color
414 | image_object = image_object.astype('uint8')
415 | xmin, ymin, xmax, ymax = self.from_mask_to_bbox(mask['segmentation'], extent=0.01)
416 | image_object = image_object[ymin:ymax+1, xmin:xmax+1]
417 | objects.append(image_object)
418 | object_scores.append(mask['predicted_iou'])
419 | object_areas.append(mask['area'])
420 |
421 | for point in point_to_mask.keys():
422 | point_to_mask[point] = sorted(point_to_mask[point], key=lambda x: (object_areas[x], object_scores[x]), reverse=True)
423 |
424 | objects_embeddings = []
425 | for single_object in objects:
426 | objects_embeddings.append(self.clip.encode_image(single_object))
427 |
428 | # self.zeros = self.clip.encode_image(image)
429 | # with torch.no_grad():
430 | # image_crop = image.copy()
431 | # mask_covered = (check_mask_covered==1)
432 | # mask_not_covered = (check_mask_covered==0)
433 | # image_crop[mask_covered] = (0, 0, 0)
434 | # if mask_not_covered.any():
435 | # xmin, ymin, xmax, ymax = self.from_mask_to_bbox(mask_not_covered)
436 | # image_crop = image_crop[ymin:ymax+1, xmin:xmax+1]
437 | # self.zeros = self.clip.encode_image(image_crop)
438 |
439 | image_pixel_embeddings = []
440 | for y in range(H):
441 | for x in range(W):
442 | pixel_embeddings = [objects_embeddings[object_id] for object_id in point_to_mask[(x, y)][:self.n_objs]]
443 | for i in range(self.n_objs-len(pixel_embeddings)):
444 | pixel_embeddings.append(self.zeros)
445 | image_pixel_embeddings.append(torch.cat(pixel_embeddings, axis=0))
446 | image_pixel_embeddings = torch.cat(image_pixel_embeddings, axis=0).reshape(H, W, self.n_objs, self.clip_n_dims)
447 |
448 | return image_pixel_embeddings
449 |
450 | def encode_image_concept_fusion(self, bbox_crop=False, extent_segmentation_mask=0):
451 |         # predict per-pixel features for the 'image' in 'SegmentAnything' based on predicted objects,
452 |         # fusing per-object and whole-image CLIP features; mask out unrelated pixels if 'bbox_crop'==False,
453 |         # or keep the background inside a plain bounding-box crop if 'bbox_crop'==True
454 |         # extent_segmentation_mask: number of one-pixel dilations applied to each segmentation mask for bigger coverage
455 |         # output shape: (H, W, n_objs, clip_n_dims)
456 |
457 | image = self.segany.image
458 |
459 | H, W, _ = image.shape
460 |
461 | image_pixel_embeddings = []
462 |
463 | masks = self.segany.generate_masks()
464 |
465 | for i, mask in enumerate(masks):
466 | masks[i]['segmentation'] = self.segmentmap_extent_multi(mask['segmentation'], extent_segmentation_mask)
467 |
468 | point_to_mask = {}
469 | for y in range(H):
470 | for x in range(W):
471 | point_to_mask[(x, y)] = []
472 | for i, mask in enumerate(masks):
473 | ys, xs = np.where(mask['segmentation'])
474 | for x, y in zip(xs, ys):
475 | point_to_mask[(x, y)] += [i]
476 |
477 | objects = []
478 | object_scores = []
479 | object_areas = []
480 | for i, mask in enumerate(masks):
481 | image_object = image.copy()
482 | if not bbox_crop:
483 | image_object[np.logical_not(mask['segmentation'])] = (255,255,255)
484 | xmin, ymin, xmax, ymax = self.from_mask_to_bbox(mask['segmentation'])
485 | image_object = image_object[ymin:ymax+1, xmin:xmax+1]
486 | objects.append(image_object)
487 | object_scores.append(mask['predicted_iou'])
488 | object_areas.append(mask['area'])
489 |
490 | for point in point_to_mask.keys():
491 | point_to_mask[point] = sorted(point_to_mask[point], key=lambda x: (object_areas[x], object_scores[x]), reverse=True)
492 |
493 | objects_embeddings = []
494 | for single_object in objects:
495 | objects_embeddings.append(self.clip.encode_image(single_object))
496 |
497 | image_embeddings = self.clip.encode_image(image)
498 | objects_embeddings = torch.cat(objects_embeddings, axis=0)
499 | objects_local_global_similarity = (image_embeddings @ objects_embeddings.T).squeeze(0)
500 | objects_cross_similarity = (objects_embeddings @ objects_embeddings.T)
501 | objects_self_similarity = torch.stack([objects_cross_similarity[i, i] for i in range(len(objects_cross_similarity))])
502 | objects_avg_cross_similarity = ((objects_cross_similarity.sum(axis=-1) - objects_self_similarity)) / (len(objects_cross_similarity)-1)
503 |
504 | t = 1
505 | w_global = ((objects_local_global_similarity + objects_avg_cross_similarity)/t).softmax(-1)
506 |
507 | objects_embeddings_fusion = ((w_global[:, None] * image_embeddings) + ((1 - w_global[:, None])*objects_embeddings))
508 | objects_embeddings_fusion /= objects_embeddings_fusion.norm(dim=-1, keepdim=True)
509 |
510 | image_pixel_embeddings = []
511 | for y in range(H):
512 | for x in range(W):
513 | pixel_embeddings = [objects_embeddings_fusion[object_id][None, :] for object_id in point_to_mask[(x, y)][:self.n_objs]]
514 | for i in range(self.n_objs-len(pixel_embeddings)):
515 | pixel_embeddings.append(self.zeros)
516 | image_pixel_embeddings.append(torch.cat(pixel_embeddings, axis=0))
517 | image_pixel_embeddings = torch.cat(image_pixel_embeddings, axis=0).reshape(H, W, self.n_objs, self.clip_n_dims)
518 |
519 | return image_pixel_embeddings
520 |
521 | @staticmethod
522 | def from_mask_to_bbox(mask, extent=0, sqrt=False):
523 | H, W = mask.shape[:2]
524 | mask_indices = np.where(mask)
525 | xmin, ymin, xmax, ymax = min(mask_indices[1]), min(mask_indices[0]), max(mask_indices[1]), max(mask_indices[0])
526 | if extent > 0:
527 | if sqrt:
528 | x_extent, y_extent = math.ceil(math.sqrt((xmax-xmin)*extent)), math.ceil(math.sqrt((ymax-ymin)*extent))
529 | # x_extent, y_extent = max(x_extent, y_extent), max(x_extent, y_extent)
530 | else:
531 | x_extent, y_extent = math.ceil((xmax-xmin)*extent), math.ceil((ymax-ymin)*extent)
532 | # x_extent, y_extent = max(x_extent, y_extent), max(x_extent, y_extent)
533 | xmin, ymin, xmax, ymax = max(xmin-x_extent, 0), max(ymin-y_extent, 0), min(xmax+x_extent, W-1), min(ymax+y_extent, H-1)
534 | return xmin, ymin, xmax, ymax
535 |
536 | @staticmethod
537 | def segmentmap_extent(segmentmap):
538 | segmentmap_extent = segmentmap.copy()
539 | H, W = segmentmap.shape
540 | ys, xs = np.where(segmentmap)
541 | for x, y in zip(xs, ys):
542 | extents = [(x, max(0, y-1)), (max(0, x-1), y), (min(W-1, x+1), y), (x, min(H-1, y+1))]
543 | for x, y in extents:
544 | segmentmap_extent[y, x] = True
545 | return segmentmap_extent
546 |
547 | def segmentmap_extent_multi(self, segmentmap, time=2):
548 | segmentmap_temp = segmentmap.copy()
549 | for _ in range(time):
550 | segmentmap_temp = self.segmentmap_extent(segmentmap_temp)
551 | return segmentmap_temp
--------------------------------------------------------------------------------