├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── build_model.py ├── simple_tokenizer.py ├── clip_model.py ├── clip.py └── clip_surgery_model.py ├── .gitignore ├── demo.jpg ├── figs ├── fig1.jpg ├── fig2.jpg ├── fig3.jpg ├── fig4.jpg ├── fig5.jpg ├── fig6.jpg └── fig7.jpg ├── README.md └── demo.py /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pth 3 | .ipynb_checkpoints 4 | -------------------------------------------------------------------------------- /demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmed-lab/CLIP_Surgery/HEAD/demo.jpg -------------------------------------------------------------------------------- /figs/fig1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmed-lab/CLIP_Surgery/HEAD/figs/fig1.jpg -------------------------------------------------------------------------------- /figs/fig2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmed-lab/CLIP_Surgery/HEAD/figs/fig2.jpg -------------------------------------------------------------------------------- /figs/fig3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmed-lab/CLIP_Surgery/HEAD/figs/fig3.jpg -------------------------------------------------------------------------------- /figs/fig4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmed-lab/CLIP_Surgery/HEAD/figs/fig4.jpg -------------------------------------------------------------------------------- /figs/fig5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmed-lab/CLIP_Surgery/HEAD/figs/fig5.jpg -------------------------------------------------------------------------------- /figs/fig6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmed-lab/CLIP_Surgery/HEAD/figs/fig6.jpg -------------------------------------------------------------------------------- /figs/fig7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmed-lab/CLIP_Surgery/HEAD/figs/fig7.jpg -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xmed-lab/CLIP_Surgery/HEAD/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A closer look at the explainability of Contrastive language-image pre-training ([Pattern Recognition](https://www.sciencedirect.com/science/article/abs/pii/S003132032500069X?via%3Dihub)) 2 | Early version: CLIP Surgery for Better Explainability with Enhancement in Open-Vocabulary Tasks ([arxiv](https://arxiv.org/abs/2304.05653)) 3 | 4 | ## Introduction 5 | 6 | This work focuses on the 
explainability of CLIP via its raw predictions. We identify two problems with CLIP's explainability: opposite visualization and noisy activations. We then propose CLIP Surgery, which requires no fine-tuning or additional supervision. It greatly improves the explainability of CLIP and enhances downstream open-vocabulary tasks such as multi-label recognition, semantic segmentation, interactive segmentation (specifically the Segment Anything Model, SAM), and multimodal visualization. Currently, we offer a simple demo for interpretability analysis and for converting text to point prompts for SAM. The remaining code, including evaluation and the other tasks, will be released later. 7 | 8 | Opposite visualization is caused by the wrong relations in self-attention: 9 | ![image](figs/fig1.jpg) 10 | 11 | Noisy activations are caused by redundant features across labels: 12 | ![image](figs/fig2.jpg) 13 | 14 | Our visualization results: 15 | ![image](figs/fig3.jpg) 16 | 17 | Text2Points to guide SAM: 18 | ![image](figs/fig4.jpg) 19 | 20 | Multimodal visualization: 21 | ![image](figs/fig5.jpg) 22 | 23 | Segmentation results: 24 | ![image](figs/fig6.jpg) 25 | 26 | Multi-label results: 27 | ![image](figs/fig7.jpg) 28 | 29 | ## Demo 30 | 31 | First, install SAM and download its checkpoint: 32 | ``` 33 | pip install git+https://github.com/facebookresearch/segment-anything.git 34 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth 35 | ``` 36 | 37 | Then explain CLIP via the Jupyter demo ["demo.ipynb"](https://github.com/xmed-lab/CLIP_Surgery/blob/master/demo.ipynb), or run the Python script: 38 | ``` 39 | python demo.py 40 | ``` 41 | 42 | (Note: the demo's results differ slightly from the experimental code, since it omits apex amp fp16 for ease of use.)
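
A condensed sketch of the core calls used in demo.py (see demo.py for the full visualization and SAM inference code; the label list and threshold below are only illustrative):
```
import clip
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"

# the "CS-" prefix selects the CLIP Surgery image encoder (architecture surgery)
model, preprocess = clip.load("CS-ViT-B/16", device=device)
model.eval()

pil_img = Image.open("demo.jpg")
image = preprocess(pil_img).unsqueeze(0).to(device)
texts = ['person', 'bench']
h, w = pil_img.size[1], pil_img.size[0]

with torch.no_grad():
    # image tokens from the surgery encoder, L2-normalized
    image_features = model.encode_image(image)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    # prompt-ensembled, normalized text features
    text_features = clip.encode_text_with_prompt_ensemble(model, texts, device)

    # feature surgery removes redundant activations shared across labels
    similarity = clip.clip_feature_surgery(image_features, text_features)

    # per-label heat maps at the original image size (skip the cls token)
    similarity_map = clip.get_similarity_map(similarity[:, 1:, :], (h, w))

    # optionally convert one label's map into point prompts for SAM
    points, labels = clip.similarity_map_to_points(similarity[0, 1:, 0], (h, w), t=0.8)
```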
43 | 44 | ## Cite 45 | ``` 46 | @article{LI2025111409, 47 | title = {A closer look at the explainability of Contrastive language-image pre-training}, 48 | journal = {Pattern Recognition}, 49 | volume = {162}, 50 | pages = {111409}, 51 | year = {2025}, 52 | issn = {0031-3203}, 53 | doi = {https://doi.org/10.1016/j.patcog.2025.111409}, 54 | url = {https://www.sciencedirect.com/science/article/pii/S003132032500069X}, 55 | author = {Yi Li and Hualiang Wang and Yiqun Duan and Jiheng Zhang and Xiaomeng Li} 56 | } 57 | ``` 58 | -------------------------------------------------------------------------------- /clip/build_model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .clip_model import CLIP 3 | from .clip_surgery_model import CLIPSurgery 4 | 5 | 6 | def convert_weights(model: nn.Module): 7 | """Convert applicable model parameters to fp16""" 8 | 9 | def _convert_weights_to_fp16(l): 10 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 11 | l.weight.data = l.weight.data.half() 12 | if l.bias is not None: 13 | l.bias.data = l.bias.data.half() 14 | 15 | if isinstance(l, nn.MultiheadAttention): 16 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 17 | tensor = getattr(l, attr) 18 | if tensor is not None: 19 | tensor.data = tensor.data.half() 20 | 21 | for name in ["text_projection", "proj"]: 22 | if hasattr(l, name): 23 | attr = getattr(l, name) 24 | if attr is not None: 25 | attr.data = attr.data.half() 26 | 27 | model.apply(_convert_weights_to_fp16) 28 | 29 | 30 | def build_model(name: str, state_dict: dict): 31 | vit = "visual.proj" in state_dict 32 | 33 | if vit: 34 | vision_width = state_dict["visual.conv1.weight"].shape[0] 35 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 36 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 37 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 38 | image_resolution = vision_patch_size * grid_size 39 | else: 40 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 41 | vision_layers = tuple(counts) 42 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 43 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 44 | vision_patch_size = None 45 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 46 | image_resolution = output_width * 32 47 | 48 | embed_dim = state_dict["text_projection"].shape[1] 49 | context_length = state_dict["positional_embedding"].shape[0] 50 | vocab_size = state_dict["token_embedding.weight"].shape[0] 51 | transformer_width = state_dict["ln_final.weight"].shape[0] 52 | transformer_heads = transformer_width // 64 53 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 54 | 55 | if 'CS-' in name: 56 | model = CLIPSurgery( 57 | embed_dim, 58 | image_resolution, vision_layers, vision_width, vision_patch_size, 59 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 60 | ) 61 | else: 62 | model = CLIP( 63 | embed_dim, 64 | image_resolution, vision_layers, vision_width, vision_patch_size, 65 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 66 | ) 67 | 68 | for key in ["input_resolution", 
"context_length", "vocab_size"]: 69 | if key in state_dict: 70 | del state_dict[key] 71 | 72 | #convert_weights(model) 73 | model.load_state_dict(state_dict) 74 | return model.eval() 75 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 
93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import clip 2 | import torch 3 | import cv2 4 | import numpy as np 5 | from PIL import Image 6 | from matplotlib import pyplot as plt 7 | from torchvision.transforms import Compose, Resize, ToTensor, Normalize 8 | from torchvision.transforms import InterpolationMode 9 | BICUBIC = InterpolationMode.BICUBIC 10 | from segment_anything import sam_model_registry, SamPredictor 11 | 12 | 13 | ### Init CLIP and data 14 | device = "cuda" if torch.cuda.is_available() else "cpu" 15 | model, _ = clip.load("ViT-B/16", device=device) 16 | model.eval() 17 | preprocess = Compose([Resize((224, 224), interpolation=BICUBIC), ToTensor(), 18 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))]) 19 | 20 | 21 | pil_img = Image.open("demo.jpg") 22 | cv2_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) 23 | image = preprocess(pil_img).unsqueeze(0).to(device) 24 | all_texts = ['airplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 'horse', 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform', 'potted plant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train', 'tree', 'truck', 'tv monitor', 'wall', 'water', 'window', 'wood'] 25 | target_texts = ['bench', 'person', 'ground', 'building'] 26 | 27 | ### Explain raw predictions of CLIP, which are opposite and noisy. 
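# The block below follows the vanilla CLIP path: L2-normalized image-token features are
# compared against prompt-ensembled text features by cosine similarity, and the per-token
# scores are min-max normalized and upsampled to the image size by clip.get_similarity_map.
# With the unmodified model these heat maps are often opposite (background highlighted
# instead of the target) and noisy, which is the baseline that CLIP Surgery improves on.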
28 | with torch.no_grad(): 29 | # Extract image features 30 | image_features = model.encode_image(image) 31 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 32 | 33 | # Prompt ensemble for text features with normalization 34 | text_features = clip.encode_text_with_prompt_ensemble(model, all_texts, device) 35 | 36 | # Similarity map from image tokens with min-max norm and resize, B,H,W,N 37 | features = image_features @ text_features.t() 38 | similarity_map = clip.get_similarity_map(features[:, 1:, :], cv2_img.shape[:2]) 39 | 40 | # Draw similarity map 41 | for b in range(similarity_map.shape[0]): 42 | for n in range(similarity_map.shape[-1]): 43 | if all_texts[n] not in target_texts: 44 | continue 45 | vis = (similarity_map[b, :, :, n].cpu().numpy() * 255).astype('uint8') 46 | vis = cv2.applyColorMap(vis, cv2.COLORMAP_JET) 47 | vis = cv2_img * 0.4 + vis * 0.6 48 | vis = cv2.cvtColor(vis.astype('uint8'), cv2.COLOR_BGR2RGB) 49 | print('CLIP:', all_texts[n]) 50 | plt.imshow(vis) 51 | plt.show() 52 | 53 | 54 | ### Explain CLIP via our CLIP Surgery 55 | model, preprocess = clip.load("CS-ViT-B/16", device=device) 56 | model.eval() 57 | 58 | with torch.no_grad(): 59 | # CLIP architecture surgery acts on the image encoder 60 | image_features = model.encode_image(image) 61 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 62 | 63 | # Prompt ensemble for text features with normalization 64 | text_features = clip.encode_text_with_prompt_ensemble(model, all_texts, device) 65 | 66 | # Apply feature surgery 67 | similarity = clip.clip_feature_surgery(image_features, text_features) 68 | similarity_map = clip.get_similarity_map(similarity[:, 1:, :], cv2_img.shape[:2]) 69 | 70 | # Draw similarity map 71 | for b in range(similarity_map.shape[0]): 72 | for n in range(similarity_map.shape[-1]): 73 | if all_texts[n] not in target_texts: 74 | continue 75 | vis = (similarity_map[b, :, :, n].cpu().numpy() * 255).astype('uint8') 76 | vis = cv2.applyColorMap(vis, cv2.COLORMAP_JET) 77 | vis = cv2_img * 0.4 + vis * 0.6 78 | vis = cv2.cvtColor(vis.astype('uint8'), cv2.COLOR_BGR2RGB) 79 | print('CLIP Surgery:', all_texts[n]) 80 | plt.imshow(vis) 81 | plt.show() 82 | 83 | 84 | ### CLIP Surgery using higher resolution 85 | 86 | # This preprocess for all next cases 87 | preprocess = Compose([Resize((512, 512), interpolation=BICUBIC), ToTensor(), 88 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))]) 89 | image = preprocess(pil_img).unsqueeze(0).to(device) 90 | 91 | with torch.no_grad(): 92 | # CLIP architecture surgery acts on the image encoder 93 | image_features = model.encode_image(image) 94 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 95 | 96 | # Prompt ensemble for text features with normalization 97 | text_features = clip.encode_text_with_prompt_ensemble(model, all_texts, device) 98 | 99 | # Apply feature surgery 100 | similarity = clip.clip_feature_surgery(image_features, text_features) 101 | similarity_map = clip.get_similarity_map(similarity[:, 1:, :], cv2_img.shape[:2]) 102 | 103 | # Draw similarity map 104 | for b in range(similarity_map.shape[0]): 105 | for n in range(similarity_map.shape[-1]): 106 | if all_texts[n] not in target_texts: 107 | continue 108 | vis = (similarity_map[b, :, :, n].cpu().numpy() * 255).astype('uint8') 109 | vis = cv2.applyColorMap(vis, cv2.COLORMAP_JET) 110 | vis = cv2_img * 0.4 + vis * 0.6 111 | vis = cv2.cvtColor(vis.astype('uint8'), cv2.COLOR_BGR2RGB) 112 | 
print('CLIP Surgery 512:', all_texts[n]) 113 | plt.imshow(vis) 114 | plt.show() 115 | 116 | 117 | ### CLIP Surgery for a single text, without fixed label sets 118 | texts = ['shoes'] 119 | 120 | with torch.no_grad(): 121 | # CLIP architecture surgery acts on the image encoder 122 | image_features = model.encode_image(image) 123 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 124 | 125 | # Prompt ensemble for text features with normalization 126 | text_features = clip.encode_text_with_prompt_ensemble(model, texts, device) 127 | 128 | # Extract redundant features from an empty string 129 | redundant_features = clip.encode_text_with_prompt_ensemble(model, [""], device) 130 | 131 | # Apply feature surgery for single text 132 | similarity = clip.clip_feature_surgery(image_features, text_features, redundant_features) 133 | similarity_map = clip.get_similarity_map(similarity[:, 1:, :], cv2_img.shape[:2]) 134 | 135 | # Draw similarity map 136 | for b in range(similarity_map.shape[0]): 137 | for n in range(similarity_map.shape[-1]): 138 | vis = (similarity_map[b, :, :, n].cpu().numpy() * 255).astype('uint8') 139 | vis = cv2.applyColorMap(vis, cv2.COLORMAP_JET) 140 | vis = cv2_img * 0.4 + vis * 0.6 141 | vis = cv2.cvtColor(vis.astype('uint8'), cv2.COLOR_BGR2RGB) 142 | print('CLIP Surgery for a single text:', texts[n]) 143 | plt.imshow(vis) 144 | plt.show() 145 | 146 | 147 | ### Text to points from CLIP Surgery to guide SAM 148 | 149 | # Init SAM 150 | sam_checkpoint = "sam_vit_h_4b8939.pth" 151 | model_type = "vit_h" 152 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 153 | sam.to(device=device) 154 | predictor = SamPredictor(sam) 155 | predictor.set_image(np.array(pil_img)) 156 | 157 | # Inference CLIP Surgery and SAM 158 | with torch.no_grad(): 159 | # CLIP architecture surgery acts on the image encoder 160 | image_features = model.encode_image(image) # Image resized to 512 161 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 162 | 163 | # Prompt ensemble for text features with normalization 164 | text_features = clip.encode_text_with_prompt_ensemble(model, all_texts, device) 165 | 166 | # Apply feature surgery, no batch 167 | similarity = clip.clip_feature_surgery(image_features, text_features)[0] 168 | 169 | # Inference SAM with points from CLIP Surgery 170 | for n in range(similarity.shape[-1]): 171 | if all_texts[n] not in target_texts: 172 | continue 173 | points, labels = clip.similarity_map_to_points(similarity[1:, n], cv2_img.shape[:2], t=0.8) 174 | masks, scores, logits = predictor.predict(point_labels=labels, point_coords=np.array(points), multimask_output=True) 175 | mask = masks[np.argmax(scores)] 176 | mask = mask.astype('uint8') 177 | 178 | # Visualize the results 179 | vis = cv2_img.copy() 180 | vis[mask > 0] = vis[mask > 0] // 2 + np.array([153, 255, 255], dtype=np.uint8) // 2 181 | for i, [x, y] in enumerate(points): 182 | cv2.circle(vis, (x, y), 3, (0, 102, 255) if labels[i] == 1 else (255, 102, 51), 3) 183 | vis = cv2.cvtColor(vis.astype('uint8'), cv2.COLOR_BGR2RGB) 184 | print('SAM guided by points from CLIP Surgery:', all_texts[n]) 185 | plt.imshow(vis) 186 | plt.show() 187 | 188 | print('Sometimes, the points are accurate, while the masks from SAM still need improvements.') 189 | print('I mean, some failure cases are not caused by wrong points.') 190 | 191 | 192 | ### Inference CLIP Surgery and SAM for a single text 193 | texts = ['bench'] 194 | 195 | with torch.no_grad(): 196 | # CLIP architecture 
surgery acts on the image encoder 197 | image_features = model.encode_image(image) 198 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 199 | 200 | # Prompt ensemble for text features with normalization 201 | text_features = clip.encode_text_with_prompt_ensemble(model, texts, device) 202 | 203 | # Extract redundant features from an empty string 204 | redundant_features = clip.encode_text_with_prompt_ensemble(model, [""], device) 205 | 206 | # CLIP feature surgery with costum redundant features 207 | similarity = clip.clip_feature_surgery(image_features, text_features, redundant_features)[0] 208 | 209 | # Inference SAM with points from CLIP Surgery 210 | points, labels = clip.similarity_map_to_points(similarity[1:, 0], cv2_img.shape[:2], t=0.8) 211 | masks, scores, logits = predictor.predict(point_labels=labels, point_coords=np.array(points), multimask_output=True) 212 | mask = masks[np.argmax(scores)] 213 | mask = mask.astype('uint8') 214 | 215 | # Visualize the results 216 | vis = cv2_img.copy() 217 | vis[mask > 0] = vis[mask > 0] // 2 + np.array([153, 255, 255], dtype=np.uint8) // 2 218 | for i, [x, y] in enumerate(points): 219 | cv2.circle(vis, (x, y), 3, (0, 102, 255) if labels[i] == 1 else (255, 102, 51), 3) 220 | vis = cv2.cvtColor(vis.astype('uint8'), cv2.COLOR_BGR2RGB) 221 | print('SAM & CLIP Surgery for single text:', texts[0]) 222 | plt.imshow(vis) 223 | plt.show() 224 | 225 | 226 | ### CLIP Surgery + SAM for combined targets 227 | 228 | # We use "+" to combine texts, instead of a whole sentence (obvious text may take the lead thus overlook rest) 229 | text = 'person+bench' 230 | 231 | with torch.no_grad(): 232 | # CLIP architecture surgery acts on the image encoder 233 | image_features = model.encode_image(image) 234 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 235 | 236 | # Extract redundant features from an empty string 237 | redundant_features = clip.encode_text_with_prompt_ensemble(model, [""], device) 238 | 239 | # Prompt ensemble for text features with normalization 240 | text_features = clip.encode_text_with_prompt_ensemble(model, text.split('+'), device) 241 | 242 | # Combine features after removing redundant features and min-max norm 243 | sm = clip.clip_feature_surgery(image_features, text_features, redundant_features)[0, 1:, :] 244 | sm_norm = (sm - sm.min(0, keepdim=True)[0]) / (sm.max(0, keepdim=True)[0] - sm.min(0, keepdim=True)[0]) 245 | sm_mean = sm_norm.mean(-1, keepdim=True) 246 | 247 | # get positive points from individual maps, and negative points from the mean map 248 | p, l = clip.similarity_map_to_points(sm_mean, cv2_img.shape[:2], t=0.8) 249 | num = len(p) // 2 250 | points = p[num:] # negatives in the second half 251 | labels = [l[num:]] 252 | for i in range(sm.shape[-1]): 253 | p, l = clip.similarity_map_to_points(sm[:, i], cv2_img.shape[:2], t=0.8) 254 | num = len(p) // 2 255 | points = points + p[:num] # positive in first half 256 | labels.append(l[:num]) 257 | labels = np.concatenate(labels, 0) 258 | 259 | # Inference SAM with points from CLIP Surgery 260 | masks, scores, logits = predictor.predict(point_labels=labels, point_coords=np.array(points), multimask_output=True) 261 | mask = masks[np.argmax(scores)] 262 | mask = mask.astype('uint8') 263 | 264 | # Visualize the results 265 | vis = cv2_img.copy() 266 | vis[mask > 0] = vis[mask > 0] // 2 + np.array([153, 255, 255], dtype=np.uint8) // 2 267 | for i, [x, y] in enumerate(points): 268 | cv2.circle(vis, (x, y), 3, (0, 102, 255) if 
labels[i] == 1 else (255, 102, 51), 3) 269 | vis = cv2.cvtColor(vis.astype('uint8'), cv2.COLOR_BGR2RGB) 270 | print('SAM & CLIP Surgery for texts combination:', text) 271 | plt.imshow(vis) 272 | plt.show() 273 | -------------------------------------------------------------------------------- /clip/clip_model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.relu1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.relu2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.relu3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.relu1(self.bn1(self.conv1(x))) 46 | out = self.relu2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | 72 | side = int((self.positional_embedding.shape[0] - 1) ** 0.5) 73 | new_side = int((x.shape[0] - 1) ** 0.5) 74 | 75 | # update the position embedding during inference for varied input size 76 | if side != new_side: 77 | new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2) 78 | new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear') 79 | new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2) 80 | self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], 
new_pos[0]], 0) 81 | 82 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 83 | x, _ = F.multi_head_attention_forward( 84 | query=x, key=x, value=x, 85 | embed_dim_to_check=x.shape[-1], 86 | num_heads=self.num_heads, 87 | q_proj_weight=self.q_proj.weight, 88 | k_proj_weight=self.k_proj.weight, 89 | v_proj_weight=self.v_proj.weight, 90 | in_proj_weight=None, 91 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 92 | bias_k=None, 93 | bias_v=None, 94 | add_zero_attn=False, 95 | dropout_p=0, 96 | out_proj_weight=self.c_proj.weight, 97 | out_proj_bias=self.c_proj.bias, 98 | use_separate_proj_weight=True, 99 | training=self.training, 100 | need_weights=False 101 | ) 102 | 103 | #return x[0] 104 | return x.transpose(0, 1) # return both cls token and image tokens, B,N,C 105 | 106 | 107 | class ModifiedResNet(nn.Module): 108 | """ 109 | A ResNet class that is similar to torchvision's but contains the following changes: 110 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 111 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 112 | - The final pooling layer is a QKV attention instead of an average pool 113 | """ 114 | 115 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 116 | super().__init__() 117 | self.output_dim = output_dim 118 | self.input_resolution = input_resolution 119 | 120 | # the 3-layer stem 121 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 122 | self.bn1 = nn.BatchNorm2d(width // 2) 123 | self.relu1 = nn.ReLU(inplace=True) 124 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 125 | self.bn2 = nn.BatchNorm2d(width // 2) 126 | self.relu2 = nn.ReLU(inplace=True) 127 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 128 | self.bn3 = nn.BatchNorm2d(width) 129 | self.relu3 = nn.ReLU(inplace=True) 130 | self.avgpool = nn.AvgPool2d(2) 131 | 132 | # residual layers 133 | self._inplanes = width # this is a *mutable* variable used during construction 134 | self.layer1 = self._make_layer(width, layers[0]) 135 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 136 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 137 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 138 | 139 | embed_dim = width * 32 # the ResNet feature dimension 140 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 141 | 142 | def _make_layer(self, planes, blocks, stride=1): 143 | layers = [Bottleneck(self._inplanes, planes, stride)] 144 | 145 | self._inplanes = planes * Bottleneck.expansion 146 | for _ in range(1, blocks): 147 | layers.append(Bottleneck(self._inplanes, planes)) 148 | 149 | return nn.Sequential(*layers) 150 | 151 | def forward(self, x): 152 | def stem(x): 153 | x = self.relu1(self.bn1(self.conv1(x))) 154 | x = self.relu2(self.bn2(self.conv2(x))) 155 | x = self.relu3(self.bn3(self.conv3(x))) 156 | x = self.avgpool(x) 157 | return x 158 | 159 | x = x.type(self.conv1.weight.dtype) 160 | x = stem(x) 161 | x = self.layer1(x) 162 | x = self.layer2(x) 163 | x = self.layer3(x) 164 | x = self.layer4(x) 165 | x = self.attnpool(x) 166 | 167 | return x 168 | 169 | 170 | class LayerNorm(nn.LayerNorm): 171 | """Subclass torch's LayerNorm to handle fp16.""" 172 | 173 | def forward(self, x: torch.Tensor): 174 | orig_type = x.dtype 175 | ret = 
super().forward(x.type(torch.float32)) 176 | return ret.type(orig_type) 177 | 178 | 179 | class QuickGELU(nn.Module): 180 | def forward(self, x: torch.Tensor): 181 | return x * torch.sigmoid(1.702 * x) 182 | 183 | 184 | class ResidualAttentionBlock(nn.Module): 185 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, need_weights: bool = False): 186 | super().__init__() 187 | 188 | self.attn = nn.MultiheadAttention(d_model, n_head) 189 | self.ln_1 = LayerNorm(d_model) 190 | self.mlp = nn.Sequential(OrderedDict([ 191 | ("c_fc", nn.Linear(d_model, d_model * 4)), 192 | ("gelu", QuickGELU()), 193 | ("c_proj", nn.Linear(d_model * 4, d_model)) 194 | ])) 195 | self.ln_2 = LayerNorm(d_model) 196 | self.attn_mask = attn_mask 197 | self.need_weights = need_weights 198 | 199 | def attention(self, x: torch.Tensor): 200 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 201 | if self.need_weights == False: 202 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 203 | else: 204 | return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask) 205 | 206 | def forward(self, x: torch.Tensor): 207 | if self.need_weights == False: 208 | x = x + self.attention(self.ln_1(x)) 209 | x = x + self.mlp(self.ln_2(x)) 210 | return x 211 | else: 212 | y, attn = self.attention(self.ln_1(x)) 213 | x = x + y 214 | x = x + self.mlp(self.ln_2(x)) 215 | return x 216 | 217 | 218 | class Transformer(nn.Module): 219 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False): 220 | super().__init__() 221 | self.width = width 222 | self.layers = layers 223 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, need_weights if i == layers - 1 else False) for i in range(layers)]) 224 | 225 | def forward(self, x: torch.Tensor): 226 | return self.resblocks(x) 227 | 228 | 229 | class VisionTransformer(nn.Module): 230 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 231 | super().__init__() 232 | self.input_resolution = input_resolution 233 | self.output_dim = output_dim 234 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 235 | 236 | scale = width ** -0.5 237 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 238 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 239 | self.ln_pre = LayerNorm(width) 240 | 241 | self.transformer = Transformer(width, layers, heads, need_weights=True) 242 | 243 | self.ln_post = LayerNorm(width) 244 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 245 | 246 | def forward(self, x: torch.Tensor): 247 | x = self.conv1(x) # shape = [*, width, grid, grid] 248 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 249 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 250 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 251 | x = x + self.positional_embedding.to(x.dtype) 252 | x = self.ln_pre(x) 253 | 254 | x = x.permute(1, 0, 2) # NLD -> LND 255 | x = self.transformer(x) 256 | x = x.permute(1, 0, 2) # LND -> NLD 257 | 258 | #x = self.ln_post(x[:, 0, :]) 259 | x = self.ln_post(x) # return both cls token and image tokens 260 | 261 | if 
self.proj is not None: 262 | x = x @ self.proj 263 | 264 | return x 265 | 266 | 267 | class CLIP(nn.Module): 268 | def __init__(self, 269 | embed_dim: int, 270 | # vision 271 | image_resolution: int, 272 | vision_layers: Union[Tuple[int, int, int, int], int], 273 | vision_width: int, 274 | vision_patch_size: int, 275 | # text 276 | context_length: int, 277 | vocab_size: int, 278 | transformer_width: int, 279 | transformer_heads: int, 280 | transformer_layers: int 281 | ): 282 | super().__init__() 283 | 284 | self.context_length = context_length 285 | 286 | if isinstance(vision_layers, (tuple, list)): 287 | vision_heads = vision_width * 32 // 64 288 | self.visual = ModifiedResNet( 289 | layers=vision_layers, 290 | output_dim=embed_dim, 291 | heads=vision_heads, 292 | input_resolution=image_resolution, 293 | width=vision_width 294 | ) 295 | else: 296 | vision_heads = vision_width // 64 297 | self.visual = VisionTransformer( 298 | input_resolution=image_resolution, 299 | patch_size=vision_patch_size, 300 | width=vision_width, 301 | layers=vision_layers, 302 | heads=vision_heads, 303 | output_dim=embed_dim 304 | ) 305 | 306 | self.transformer = Transformer( 307 | width=transformer_width, 308 | layers=transformer_layers, 309 | heads=transformer_heads, 310 | attn_mask=self.build_attention_mask() 311 | ) 312 | 313 | self.vocab_size = vocab_size 314 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 315 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 316 | self.ln_final = LayerNorm(transformer_width) 317 | 318 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 319 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 320 | 321 | self.initialize_parameters() 322 | 323 | def initialize_parameters(self): 324 | nn.init.normal_(self.token_embedding.weight, std=0.02) 325 | nn.init.normal_(self.positional_embedding, std=0.01) 326 | 327 | if isinstance(self.visual, ModifiedResNet): 328 | if self.visual.attnpool is not None: 329 | std = self.visual.attnpool.c_proj.in_features ** -0.5 330 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 331 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 332 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 333 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 334 | 335 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 336 | for name, param in resnet_block.named_parameters(): 337 | if name.endswith("bn3.weight"): 338 | nn.init.zeros_(param) 339 | 340 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 341 | attn_std = self.transformer.width ** -0.5 342 | fc_std = (2 * self.transformer.width) ** -0.5 343 | for block in self.transformer.resblocks: 344 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 345 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 346 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 347 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 348 | 349 | if self.text_projection is not None: 350 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 351 | 352 | def build_attention_mask(self): 353 | # lazily create causal attention mask, with full attention between the vision tokens 354 | # pytorch uses additive attention mask; fill with -inf 355 | mask = torch.empty(self.context_length, self.context_length) 356 | mask.fill_(float("-inf")) 357 | mask.triu_(1) # 
zero out the lower diagonal 358 | return mask 359 | 360 | @property 361 | def dtype(self): 362 | return self.visual.conv1.weight.dtype 363 | 364 | def encode_image(self, image): 365 | return self.visual(image.type(self.dtype)) 366 | 367 | def encode_text(self, text): 368 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 369 | 370 | x = x + self.positional_embedding.type(self.dtype) 371 | x = x.permute(1, 0, 2) # NLD -> LND 372 | x = self.transformer(x) 373 | x = x.permute(1, 0, 2) # LND -> NLD 374 | x = self.ln_final(x).type(self.dtype) 375 | 376 | # x.shape = [batch_size, n_ctx, transformer.width] 377 | # take features from the eot embedding (eot_token is the highest number in each sequence) 378 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 379 | 380 | return x 381 | 382 | def forward(self, image, text): 383 | image_features = self.encode_image(image) 384 | text_features = self.encode_text(text) 385 | 386 | # normalized features 387 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 388 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 389 | 390 | # cosine similarity as logits 391 | logit_scale = self.logit_scale.exp() 392 | logits_per_image = logit_scale * image_features @ text_features.t() 393 | logits_per_text = logits_per_image.t() 394 | 395 | # shape = [global_batch_size, global_batch_size] 396 | return logits_per_image, logits_per_text 397 | -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, ToTensor, Normalize 11 | from tqdm import tqdm 12 | import numpy as np 13 | 14 | from .build_model import build_model 15 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 16 | 17 | try: 18 | from torchvision.transforms import InterpolationMode 19 | BICUBIC = InterpolationMode.BICUBIC 20 | except ImportError: 21 | BICUBIC = Image.BICUBIC 22 | 23 | 24 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 25 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 26 | 27 | 28 | __all__ = ["available_models", "load", "tokenize", "encode_text_with_prompt_ensemble", 29 | "get_similarity_map", "clip_feature_surgery", "similarity_map_to_points"] 30 | _tokenizer = _Tokenizer() 31 | 32 | _MODELS = { 33 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 34 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 35 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 36 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 37 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 38 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 39 | "ViT-B/16": 
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 40 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 41 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 42 | "CS-RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 43 | "CS-RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 44 | "CS-RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 45 | "CS-RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 46 | "CS-RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 47 | "CS-ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 48 | "CS-ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 49 | "CS-ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 50 | "CS-ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 51 | } 52 | 53 | 54 | def _download(url: str, root: str): 55 | os.makedirs(root, exist_ok=True) 56 | filename = os.path.basename(url) 57 | 58 | expected_sha256 = url.split("/")[-2] 59 | download_target = os.path.join(root, filename) 60 | 61 | if os.path.exists(download_target) and not os.path.isfile(download_target): 62 | raise RuntimeError(f"{download_target} exists and is not a regular file") 63 | 64 | if os.path.isfile(download_target): 65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 66 | return download_target 67 | else: 68 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 69 | 70 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 71 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 72 | while True: 73 | buffer = source.read(8192) 74 | if not buffer: 75 | break 76 | 77 | output.write(buffer) 78 | loop.update(len(buffer)) 79 | 80 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 81 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 82 | 83 | return download_target 84 | 85 | 86 | def _convert_image_to_rgb(image): 87 | return image.convert("RGB") 88 | 89 | 90 | def _transform(n_px): 91 | return Compose([ 92 | Resize((n_px, n_px), interpolation=BICUBIC), 93 | #CenterCrop(n_px), # rm center crop to explain whole image 94 | _convert_image_to_rgb, 95 | ToTensor(), 96 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 97 | ]) 98 | 99 | 100 | def available_models() -> List[str]: 101 | """Returns the names of available CLIP models""" 102 | return 
list(_MODELS.keys()) 103 | 104 | 105 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 106 | """Load a CLIP model 107 | 108 | Parameters 109 | ---------- 110 | name : str 111 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 112 | 113 | device : Union[str, torch.device] 114 | The device to put the loaded model 115 | 116 | jit : bool 117 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 118 | 119 | download_root: str 120 | path to download the model files; by default, it uses "~/.cache/clip" 121 | 122 | Returns 123 | ------- 124 | model : torch.nn.Module 125 | The CLIP model 126 | 127 | preprocess : Callable[[PIL.Image], torch.Tensor] 128 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 129 | """ 130 | if name in _MODELS: 131 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 132 | elif os.path.isfile(name): 133 | model_path = name 134 | else: 135 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 136 | 137 | with open(model_path, 'rb') as opened_file: 138 | try: 139 | # loading JIT archive 140 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 141 | state_dict = None 142 | except RuntimeError: 143 | # loading saved state dict 144 | if jit: 145 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 146 | jit = False 147 | state_dict = torch.load(opened_file, map_location="cpu") 148 | 149 | if not jit: 150 | model = build_model(name, state_dict or model.state_dict()).to(device) 151 | if str(device) == "cpu": 152 | model.float() 153 | return model, _transform(model.visual.input_resolution) 154 | 155 | # patch the device names 156 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 157 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 158 | 159 | def patch_device(module): 160 | try: 161 | graphs = [module.graph] if hasattr(module, "graph") else [] 162 | except RuntimeError: 163 | graphs = [] 164 | 165 | if hasattr(module, "forward1"): 166 | graphs.append(module.forward1.graph) 167 | 168 | for graph in graphs: 169 | for node in graph.findAllNodes("prim::Constant"): 170 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 171 | node.copyAttributes(device_node) 172 | 173 | model.apply(patch_device) 174 | patch_device(model.encode_image) 175 | patch_device(model.encode_text) 176 | 177 | # patch dtype to float32 on CPU 178 | if str(device) == "cpu": 179 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 180 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 181 | float_node = float_input.node() 182 | 183 | def patch_float(module): 184 | try: 185 | graphs = [module.graph] if hasattr(module, "graph") else [] 186 | except RuntimeError: 187 | graphs = [] 188 | 189 | if hasattr(module, "forward1"): 190 | graphs.append(module.forward1.graph) 191 | 192 | for graph in graphs: 193 | for node in graph.findAllNodes("aten::to"): 194 | inputs = list(node.inputs()) 195 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 196 | if inputs[i].node()["value"] == 5: 197 | 
inputs[i].node().copyAttributes(float_node) 198 | 199 | model.apply(patch_float) 200 | patch_float(model.encode_image) 201 | patch_float(model.encode_text) 202 | 203 | model.float() 204 | 205 | return model, _transform(model.input_resolution.item()) 206 | 207 | 208 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 209 | """ 210 | Returns the tokenized representation of given input string(s) 211 | 212 | Parameters 213 | ---------- 214 | texts : Union[str, List[str]] 215 | An input string or a list of input strings to tokenize 216 | 217 | context_length : int 218 | The context length to use; all CLIP models use 77 as the context length 219 | 220 | truncate: bool 221 | Whether to truncate the text in case its encoding is longer than the context length 222 | 223 | Returns 224 | ------- 225 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 226 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 227 | """ 228 | if isinstance(texts, str): 229 | texts = [texts] 230 | 231 | sot_token = _tokenizer.encoder["<|startoftext|>"] 232 | eot_token = _tokenizer.encoder["<|endoftext|>"] 233 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 234 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 235 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 236 | else: 237 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 238 | 239 | for i, tokens in enumerate(all_tokens): 240 | if len(tokens) > context_length: 241 | if truncate: 242 | tokens = tokens[:context_length] 243 | tokens[-1] = eot_token 244 | else: 245 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 246 | result[i, :len(tokens)] = torch.tensor(tokens) 247 | 248 | return result 249 | 250 | 251 | def encode_text_with_prompt_ensemble(model, texts, device, prompt_templates=None): 252 | 253 | # using default prompt templates for ImageNet 254 | if prompt_templates == None: 255 | prompt_templates = ['a bad photo of a {}.', 'a photo of many {}.', 'a sculpture of a {}.', 'a photo of the hard to see {}.', 'a low resolution photo of the {}.', 'a rendering of a {}.', 'graffiti of a {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a tattoo of a {}.', 'the embroidered {}.', 'a photo of a hard to see {}.', 'a bright photo of a {}.', 'a photo of a clean {}.', 'a photo of a dirty {}.', 'a dark photo of the {}.', 'a drawing of a {}.', 'a photo of my {}.', 'the plastic {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a painting of the {}.', 'a painting of a {}.', 'a pixelated photo of the {}.', 'a sculpture of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a plastic {}.', 'a photo of the dirty {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a rendering of the {}.', 'a {} in a video game.', 'a photo of one {}.', 'a doodle of a {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'the origami {}.', 'the {} in a video game.', 'a sketch of a {}.', 'a doodle of the {}.', 'a origami {}.', 'a low resolution photo of a {}.', 'the toy {}.', 'a rendition of the {}.', 'a photo of the clean {}.', 'a photo of a large {}.', 'a rendition of a {}.', 'a photo of a nice {}.', 'a 
photo of a weird {}.', 'a blurry photo of a {}.', 'a cartoon {}.', 'art of a {}.', 'a sketch of the {}.', 'a embroidered {}.', 'a pixelated photo of a {}.', 'itap of the {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a plushie {}.', 'a photo of the nice {}.', 'a photo of the small {}.', 'a photo of the weird {}.', 'the cartoon {}.', 'art of the {}.', 'a drawing of the {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'the plushie {}.', 'a dark photo of a {}.', 'itap of a {}.', 'graffiti of the {}.', 'a toy {}.', 'itap of my {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'a tattoo of the {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.'] 256 | 257 | text_features = [] 258 | for t in texts: 259 | prompted_t = [template.format(t) for template in prompt_templates] 260 | prompted_t = tokenize(prompted_t).to(device) 261 | class_embeddings = model.encode_text(prompted_t) 262 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 263 | class_embedding = class_embeddings.mean(dim=0) 264 | class_embedding /= class_embedding.norm() 265 | text_features.append(class_embedding) 266 | text_features = torch.stack(text_features, dim=1).to(device).t() 267 | 268 | return text_features 269 | 270 | 271 | def get_similarity_map(sm, shape): 272 | 273 | # min-max norm 274 | sm = (sm - sm.min(1, keepdim=True)[0]) / (sm.max(1, keepdim=True)[0] - sm.min(1, keepdim=True)[0]) 275 | 276 | # reshape 277 | side = int(sm.shape[1] ** 0.5) # square output 278 | sm = sm.reshape(sm.shape[0], side, side, -1).permute(0, 3, 1, 2) 279 | 280 | # interpolate 281 | sm = torch.nn.functional.interpolate(sm, shape, mode='bilinear') 282 | sm = sm.permute(0, 2, 3, 1) 283 | 284 | return sm 285 | 286 | 287 | def clip_feature_surgery(image_features, text_features, redundant_feats=None, t=2): 288 | 289 | if redundant_feats != None: 290 | similarity = image_features @ (text_features - redundant_feats).t() 291 | 292 | else: 293 | # weights to restrain influence of obvious classes on others 294 | prob = image_features[:, :1, :] @ text_features.t() 295 | prob = (prob * 2).softmax(-1) 296 | w = prob / prob.mean(-1, keepdim=True) 297 | 298 | # element-wise multiplied features 299 | b, n_t, n_i, c = image_features.shape[0], text_features.shape[0], image_features.shape[1], image_features.shape[2] 300 | feats = image_features.reshape(b, n_i, 1, c) * text_features.reshape(1, 1, n_t, c) 301 | feats *= w.reshape(1, 1, n_t, 1) 302 | redundant_feats = feats.mean(2, keepdim=True) # along cls dim 303 | feats = feats - redundant_feats 304 | 305 | # sum the element-wise multiplied features as cosine similarity 306 | similarity = feats.sum(-1) 307 | 308 | return similarity 309 | 310 | 311 | # sm shape N_t 312 | def similarity_map_to_points(sm, shape, t=0.8, down_sample=2): 313 | side = int(sm.shape[0] ** 0.5) 314 | sm = sm.reshape(1, 1, side, side) 315 | 316 | # down sample to smooth results 317 | down_side = side // down_sample 318 | sm = torch.nn.functional.interpolate(sm, (down_side, down_side), mode='bilinear')[0, 0, :, :] 319 | h, w = sm.shape 320 | sm = sm.reshape(-1) 321 | 322 | sm = (sm - sm.min()) / (sm.max() - sm.min()) 323 | rank = sm.sort(0)[1] 324 | scale_h = float(shape[0]) / h 325 | scale_w = float(shape[1]) / w 326 | 327 | num = min((sm >= t).sum(), sm.shape[0] // 2) 328 | labels = np.ones(num * 2).astype('uint8') 329 | labels[num:] = 0 330 | points = [] 331 | 332 | # 
positives 333 | for idx in rank[-num:]: 334 | x = min((idx % w + 0.5) * scale_w, shape[1] - 1) # +0.5 to center 335 | y = min((idx // w + 0.5) * scale_h, shape[0] - 1) 336 | points.append([int(x.item()), int(y.item())]) 337 | 338 | # negatives 339 | for idx in rank[:num]: 340 | x = min((idx % w + 0.5) * scale_w, shape[1] - 1) 341 | y = min((idx // w + 0.5) * scale_h, shape[0] - 1) 342 | points.append([int(x.item()), int(y.item())]) 343 | 344 | return points, labels 345 | -------------------------------------------------------------------------------- /clip/clip_surgery_model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | expansion = 4 11 | 12 | def __init__(self, inplanes, planes, stride=1): 13 | super().__init__() 14 | 15 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 16 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.relu1 = nn.ReLU(inplace=True) 19 | 20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.relu2 = nn.ReLU(inplace=True) 23 | 24 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 25 | 26 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 27 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 28 | self.relu3 = nn.ReLU(inplace=True) 29 | 30 | self.downsample = None 31 | self.stride = stride 32 | 33 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 34 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 35 | self.downsample = nn.Sequential(OrderedDict([ 36 | ("-1", nn.AvgPool2d(stride)), 37 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 38 | ("1", nn.BatchNorm2d(planes * self.expansion)) 39 | ])) 40 | 41 | def forward(self, x: torch.Tensor): 42 | identity = x 43 | 44 | out = self.relu1(self.bn1(self.conv1(x))) 45 | out = self.relu2(self.bn2(self.conv2(out))) 46 | out = self.avgpool(out) 47 | out = self.bn3(self.conv3(out)) 48 | 49 | if self.downsample is not None: 50 | identity = self.downsample(x) 51 | 52 | out += identity 53 | out = self.relu3(out) 54 | return out 55 | 56 | 57 | # implement attention module for v-v self-attention 58 | class Attention(nn.Module): 59 | def __init__(self, out_dim, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., settings=''): 60 | super().__init__() 61 | self.num_heads = num_heads 62 | head_dim = dim // num_heads 63 | self.scale = qk_scale or head_dim ** -0.5 64 | 65 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 66 | self.attn_drop = nn.Dropout(attn_drop) 67 | self.proj = nn.Linear(out_dim, dim) 68 | self.proj_drop = nn.Dropout(proj_drop) 69 | self.settings = settings 70 | 71 | def forward(self, x): 72 | B, N, C = x.shape 73 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 74 | q, k, v = qkv[0], qkv[1], qkv[2] 75 | 76 | # original self-attention for the original path 77 | attn_ori = (q @ k.transpose(-2, -1)) * self.scale 78 | attn_ori = attn_ori.softmax(dim=-1) 79 | attn_ori = self.attn_drop(attn_ori) 80 | 81 | # replace k & q by v 82 | k = v 83 | q = k 84 | 85 | # resnets have only one self-attention, norm and larger scale perform 
better 86 | if self.settings == 'resnet': 87 | k = k / (k.norm(p=2, dim=-1, keepdim=True) + 1e-6) 88 | q = k 89 | scale = self.scale * 8 90 | else: 91 | scale = self.scale 92 | 93 | # self-attention, higher temperature for resnets performs better 94 | attn = (q @ k.transpose(-2, -1)) * scale 95 | attn = (attn).softmax(dim=-1) 96 | attn = self.attn_drop(attn) 97 | 98 | x_ori = (attn_ori @ v).transpose(1, 2).reshape(B, N, C) 99 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) # clip_surgery 100 | #x = v.transpose(1, 2).reshape(B, N, C) # mask_clip 101 | x = self.proj_drop(self.proj(x)) 102 | x_ori = self.proj_drop(self.proj(x_ori)) 103 | return [x, x_ori] 104 | 105 | 106 | class AttentionPool2d(nn.Module): 107 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 108 | super().__init__() 109 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 110 | self.k_proj = nn.Linear(embed_dim, embed_dim) 111 | self.q_proj = nn.Linear(embed_dim, embed_dim) 112 | self.v_proj = nn.Linear(embed_dim, embed_dim) 113 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 114 | self.num_heads = num_heads 115 | 116 | self.attn = None 117 | self.embed_dim = embed_dim 118 | self.num_heads = num_heads 119 | self.output_dim = output_dim 120 | 121 | 122 | def forward(self, x): 123 | # reform transformer layer after init and load weights, using v only 124 | if self.attn is None: 125 | self.attn = Attention(self.output_dim, self.embed_dim, self.num_heads, True) 126 | self.attn.qkv.weight = torch.nn.Parameter(torch.cat([self.v_proj.weight, self.v_proj.weight, self.v_proj.weight], 0)) 127 | self.attn.qkv.bias = torch.nn.Parameter(torch.cat([self.v_proj.bias, self.v_proj.bias, self.v_proj.bias])) 128 | self.attn.proj.weight = self.c_proj.weight 129 | self.attn.proj.bias = self.c_proj.bias 130 | 131 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 132 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 133 | 134 | side = int((self.positional_embedding.shape[0] - 1) ** 0.5) 135 | new_side = int((x.shape[0] - 1) ** 0.5) 136 | 137 | # update the position embedding during inference for varied input size 138 | if side != new_side: 139 | new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2) 140 | new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear') 141 | new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2) 142 | self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0) 143 | 144 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 145 | x, x_ori = self.attn(x.transpose(0, 1)) 146 | 147 | # cls token from the original path, and img tokens from the new path 148 | x[:, 0, :] = x_ori[:, 0, :] 149 | return x 150 | 151 | 152 | class ModifiedResNet(nn.Module): 153 | """ 154 | A ResNet class that is similar to torchvision's but contains the following changes: 155 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
156 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 157 | - The final pooling layer is a QKV attention instead of an average pool 158 | """ 159 | 160 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 161 | super().__init__() 162 | self.output_dim = output_dim 163 | self.input_resolution = input_resolution 164 | 165 | # the 3-layer stem 166 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 167 | self.bn1 = nn.BatchNorm2d(width // 2) 168 | self.relu1 = nn.ReLU(inplace=True) 169 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 170 | self.bn2 = nn.BatchNorm2d(width // 2) 171 | self.relu2 = nn.ReLU(inplace=True) 172 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 173 | self.bn3 = nn.BatchNorm2d(width) 174 | self.relu3 = nn.ReLU(inplace=True) 175 | self.avgpool = nn.AvgPool2d(2) 176 | 177 | # residual layers 178 | self._inplanes = width # this is a *mutable* variable used during construction 179 | self.layer1 = self._make_layer(width, layers[0]) 180 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 181 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 182 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 183 | 184 | embed_dim = width * 32 # the ResNet feature dimension 185 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 186 | 187 | def _make_layer(self, planes, blocks, stride=1): 188 | layers = [Bottleneck(self._inplanes, planes, stride)] 189 | 190 | self._inplanes = planes * Bottleneck.expansion 191 | for _ in range(1, blocks): 192 | layers.append(Bottleneck(self._inplanes, planes)) 193 | 194 | return nn.Sequential(*layers) 195 | 196 | def forward(self, x): 197 | def stem(x): 198 | x = self.relu1(self.bn1(self.conv1(x))) 199 | x = self.relu2(self.bn2(self.conv2(x))) 200 | x = self.relu3(self.bn3(self.conv3(x))) 201 | x = self.avgpool(x) 202 | return x 203 | 204 | x = x.type(self.conv1.weight.dtype) 205 | x = stem(x) 206 | x = self.layer1(x) 207 | x = self.layer2(x) 208 | x = self.layer3(x) 209 | x = self.layer4(x) 210 | x = self.attnpool(x) 211 | 212 | # shape BNC 213 | return x 214 | 215 | 216 | class LayerNorm(nn.LayerNorm): 217 | """Subclass torch's LayerNorm to handle fp16.""" 218 | 219 | def forward(self, x: torch.Tensor): 220 | orig_type = x.dtype 221 | ret = super().forward(x.type(torch.float32)) 222 | return ret.type(orig_type) 223 | 224 | 225 | class QuickGELU(nn.Module): 226 | def forward(self, x: torch.Tensor): 227 | return x * torch.sigmoid(1.702 * x) 228 | 229 | 230 | class ResidualAttentionBlock(nn.Module): 231 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 232 | super().__init__() 233 | 234 | self.attn = nn.MultiheadAttention(d_model, n_head) 235 | self.ln_1 = LayerNorm(d_model) 236 | self.mlp = nn.Sequential(OrderedDict([ 237 | ("c_fc", nn.Linear(d_model, d_model * 4)), 238 | ("gelu", QuickGELU()), 239 | ("c_proj", nn.Linear(d_model * 4, d_model)) 240 | ])) 241 | self.ln_2 = LayerNorm(d_model) 242 | self.attn_mask = attn_mask 243 | 244 | def attention(self, x: torch.Tensor): 245 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 246 | if isinstance(self.attn, Attention): 247 | x = x.transpose(0, 1) 248 | x, x_ori = self.attn(x) 249 | return [x.transpose(0, 1), x_ori.transpose(0, 1)] 250 | else: 251 | 
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 252 | 253 | def forward(self, x): 254 | 255 | # dual paths for blocks deeper than "d" 256 | if isinstance(self.attn, Attention): 257 | if isinstance(x, list): 258 | x, x_ori = x 259 | x_res = self.attention(self.ln_1(x_ori)) 260 | x_res, x_ori_res = x_res 261 | x_ori += x_ori_res 262 | x_ori = x_ori + self.mlp(self.ln_2(x_ori)) 263 | x += x_res # skip ffn for the new path 264 | return [x, x_ori] 265 | 266 | # start of dual path 267 | else: 268 | x_res = self.attention(self.ln_1(x)) 269 | if isinstance(x_res, list): 270 | x_res, x_ori_res = x_res 271 | x_ori = x + x_ori_res 272 | x_ori = x_ori + self.mlp(self.ln_2(x_ori)) 273 | x += x_res 274 | return [x, x_ori] 275 | 276 | # single path before "d" 277 | else: 278 | x = x + self.attention(self.ln_1(x)) 279 | x = x + self.mlp(self.ln_2(x)) 280 | return x 281 | 282 | 283 | class Transformer(nn.Module): 284 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False): 285 | super().__init__() 286 | self.width = width 287 | self.layers = layers 288 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for i in range(layers)]) 289 | 290 | def forward(self, x: torch.Tensor): 291 | return self.resblocks(x) 292 | 293 | 294 | class VisionTransformer(nn.Module): 295 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 296 | super().__init__() 297 | self.input_resolution = input_resolution 298 | self.output_dim = output_dim 299 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 300 | 301 | scale = width ** -0.5 302 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 303 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 304 | self.ln_pre = LayerNorm(width) 305 | 306 | self.transformer = Transformer(width, layers, heads, need_weights=True) 307 | self.attn = None 308 | self.embed_dim = width 309 | self.num_heads = heads 310 | 311 | self.ln_post = LayerNorm(width) 312 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 313 | 314 | @torch.no_grad() 315 | def forward(self, x: torch.Tensor): 316 | 317 | # reform the architecture during first inference 318 | if self.attn is None: 319 | 320 | # apply architecture surgery on the last 6 blocks 321 | for i in range(1, 7): # surgery 7, maskclip 2 322 | self.attn = Attention(self.embed_dim, self.embed_dim, self.num_heads, True) 323 | self.attn.qkv.weight.data = self.transformer.resblocks[-i].attn.in_proj_weight.clone() 324 | self.attn.qkv.bias.data = self.transformer.resblocks[-i].attn.in_proj_bias.clone() 325 | self.attn.proj.weight.data = self.transformer.resblocks[-i].attn.out_proj.weight.clone() 326 | self.attn.proj.bias.data = self.transformer.resblocks[-i].attn.out_proj.bias.clone() 327 | self.transformer.resblocks[-i].attn = self.attn 328 | 329 | x = self.conv1(x) # shape = [*, width, grid, grid] 330 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 331 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 332 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 333 | side = int((self.positional_embedding.shape[0] - 1) ** 0.5) 334 | new_side = int((x.shape[1] - 1) ** 0.5) 335 | 336 | # 
update the position embedding during inference for varied input size 337 | if side != new_side: 338 | new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2) 339 | new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear') 340 | new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2) 341 | self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0) 342 | 343 | pos = self.positional_embedding.to(x.dtype) 344 | x = x + pos 345 | x = self.ln_pre(x) 346 | 347 | x = x.permute(1, 0, 2) # NLD -> LND 348 | x, x_ori = self.transformer(x) 349 | x[0, :, :] = x_ori[0, :, :] # clip_surgery 350 | x = x.permute(1, 0, 2) # LND -> NLD 351 | 352 | x = self.ln_post(x) 353 | x = x @ self.proj 354 | 355 | return x 356 | 357 | 358 | class CLIPSurgery(nn.Module): 359 | def __init__(self, 360 | embed_dim: int, 361 | # vision 362 | image_resolution: int, 363 | vision_layers: Union[Tuple[int, int, int, int], int], 364 | vision_width: int, 365 | vision_patch_size: int, 366 | # text 367 | context_length: int, 368 | vocab_size: int, 369 | transformer_width: int, 370 | transformer_heads: int, 371 | transformer_layers: int 372 | ): 373 | super().__init__() 374 | 375 | self.context_length = context_length 376 | 377 | if isinstance(vision_layers, (tuple, list)): 378 | vision_heads = vision_width * 32 // 64 379 | self.visual = ModifiedResNet( 380 | layers=vision_layers, 381 | output_dim=embed_dim, 382 | heads=vision_heads, 383 | input_resolution=image_resolution, 384 | width=vision_width 385 | ) 386 | else: 387 | vision_heads = vision_width // 64 388 | self.visual = VisionTransformer( 389 | input_resolution=image_resolution, 390 | patch_size=vision_patch_size, 391 | width=vision_width, 392 | layers=vision_layers, 393 | heads=vision_heads, 394 | output_dim=embed_dim 395 | ) 396 | 397 | self.transformer = Transformer( 398 | width=transformer_width, 399 | layers=transformer_layers, 400 | heads=transformer_heads, 401 | attn_mask=self.build_attention_mask() 402 | ) 403 | 404 | self.vocab_size = vocab_size 405 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 406 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 407 | self.ln_final = LayerNorm(transformer_width) 408 | 409 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 410 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 411 | 412 | self.initialize_parameters() 413 | 414 | def initialize_parameters(self): 415 | nn.init.normal_(self.token_embedding.weight, std=0.02) 416 | nn.init.normal_(self.positional_embedding, std=0.01) 417 | 418 | if isinstance(self.visual, ModifiedResNet): 419 | if self.visual.attnpool is not None: 420 | std = self.visual.attnpool.c_proj.in_features ** -0.5 421 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 422 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 423 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 424 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 425 | 426 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 427 | for name, param in resnet_block.named_parameters(): 428 | if name.endswith("bn3.weight"): 429 | nn.init.zeros_(param) 430 | 431 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 432 | attn_std = self.transformer.width ** -0.5 433 | fc_std = 
(2 * self.transformer.width) ** -0.5 434 | for block in self.transformer.resblocks: 435 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 436 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 437 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 438 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 439 | 440 | if self.text_projection is not None: 441 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 442 | 443 | def build_attention_mask(self): 444 | # lazily create causal attention mask, with full attention between the vision tokens 445 | # pytorch uses additive attention mask; fill with -inf 446 | mask = torch.empty(self.context_length, self.context_length) 447 | mask.fill_(float("-inf")) 448 | mask.triu_(1) # zero out the lower diagonal 449 | return mask 450 | 451 | @property 452 | def dtype(self): 453 | return self.visual.conv1.weight.dtype 454 | 455 | def encode_image(self, image): 456 | return self.visual(image.type(self.dtype)) 457 | 458 | def encode_text(self, text): 459 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 460 | 461 | x = x + self.positional_embedding.type(self.dtype) 462 | x = x.permute(1, 0, 2) # NLD -> LND 463 | x = self.transformer(x) 464 | x = x.permute(1, 0, 2) # LND -> NLD 465 | x = self.ln_final(x).type(self.dtype) 466 | 467 | # x.shape = [batch_size, n_ctx, transformer.width] 468 | # take features from the eot embedding (eot_token is the highest number in each sequence) 469 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 470 | 471 | return x 472 | 473 | def forward(self, image, text): 474 | image_features = self.encode_image(image) 475 | text_features = self.encode_text(text) 476 | 477 | # normalized features 478 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 479 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 480 | 481 | # cosine similarity as logits 482 | logit_scale = self.logit_scale.exp() 483 | logits_per_image = logit_scale * image_features @ text_features.t() 484 | logits_per_text = logits_per_image.t() 485 | 486 | # shape = [global_batch_size, global_batch_size] 487 | return logits_per_image, logits_per_text 488 | --------------------------------------------------------------------------------
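
For orientation, the helpers in clip/clip.py and the surgery model above can be strung together roughly as follows. This is a minimal sketch in the spirit of demo.py, not code taken from the repository: it assumes the clip package re-exports load, encode_text_with_prompt_ensemble, clip_feature_surgery, get_similarity_map and similarity_map_to_points, that load accepts a CLIP Surgery model key (shown here as "CS-ViT-B/16"), and that the query text "dog" is an arbitrary example.

```
# Minimal usage sketch under the assumptions noted above; see demo.py for the exact pipeline.
import torch
from PIL import Image
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("CS-ViT-B/16", device=device)  # assumed CLIP Surgery model key
model.eval()

pil_img = Image.open("demo.jpg")
image = preprocess(pil_img).unsqueeze(0).to(device)
texts = ["dog"]  # arbitrary query

with torch.no_grad():
    # all image tokens (cls + patches), L2-normalized per token: [B, N+1, C]
    image_features = model.encode_image(image)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    # prompt-ensembled text features: [num_texts, C]
    text_features = clip.encode_text_with_prompt_ensemble(model, texts, device)

    # feature surgery suppresses features shared across classes: [B, N+1, num_texts]
    similarity = clip.clip_feature_surgery(image_features, text_features)

    # drop the cls token and upsample the patch similarities to the image size for visualization
    H, W = pil_img.size[1], pil_img.size[0]
    similarity_map = clip.get_similarity_map(similarity[:, 1:, :], (H, W))  # [B, H, W, num_texts]

    # convert the map of the first text into foreground/background point prompts
    points, labels = clip.similarity_map_to_points(similarity[0, 1:, 0], (H, W), t=0.8)
```

After converting points to a numpy array, points and labels fit the point_coords / point_labels arguments of segment-anything's SamPredictor.predict.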