# Repository listing (kept verbatim from the dump):
# ├── example_facets.py ├── example_usage.py ├── model-card.md ├── model.py ├── requirements.txt └── tokenizer.py

# /example_facets.py:
# --------------------------------------------------------------------------------
from tqdm import tqdm
from model import CLIPImage, CLIPText
import tensorflow as tf
import os
import numpy as np

from lucid.optvis import objectives, param
import lucid.optvis.render as render
from lucid.optvis.objectives import wrap_objective, diversity
import lucid.optvis.transform as transform
from lucid.misc.io import load, save


@wrap_objective()
def l2(batch=None):
    """Objective: negative mean squared deviation of ``T("input")`` from 0.5.

    ``batch`` is accepted but not used by this objective.
    """
    def inner(T):
        return -tf.reduce_mean((T("input") - 0.5)**2)
    return inner


@wrap_objective()
def vector(layer, d, batch=None):
    """Objective: mean projection of the rectified activations of ``layer`` onto ``d``.

    The einsum contracts axis 1 of ``relu(T(layer))`` with the vector ``d``;
    the result is averaged over axes 1 and 2 and then over the remaining axis.
    ``batch`` is accepted but not used.
    """
    def inner(T):
        # "ijkl,j->ikl": axis 1 of the activations is contracted against d
        # (presumably the channel axis of an NCHW tensor -- TODO confirm).
        projected = tf.einsum("ijkl,j->ikl", tf.nn.relu(T(layer)), tf.constant(d))
        channel_obj = tf.reduce_mean(projected, [1, 2])
        return tf.reduce_mean(channel_obj)
    return inner


@wrap_objective()
def attr(obj, style_attrs, layers, strength):
    """Objective: weighted agreement between gradient attributions and ``style_attrs``.

    For every layer in ``layers`` the attribution is
    ``T(layer) * relu(stop_gradient(d obj / d T(layer)))``.  The attributions
    are concatenated on the last axis, multiplied by ``style_attrs`` (resized
    to the spatial size of the first layer) and by a per-example weight that
    runs linearly from ``strength[0]`` to ``strength[1]`` along axis 0, and
    finally summed.

    Raises:
        ValueError: if ``style_attrs`` is neither rank 2 nor rank 4.
    """
    def inner(T):
        style = tf.constant(style_attrs)
        obj_t = obj(T)
        layer_t = T(layers[0])
        # One weight per entry along axis 0 of the first layer's activations.
        w = tf.linspace(strength[0], strength[1], tf.shape(layer_t)[0])
        # Resize the style map to the layer's spatial size; tf.image.resize
        # works on the two middle axes, hence the transposes around it.
        # NOTE(review): these 4-axis permutations mean a rank-2 style_attrs
        # would fail here, so the rank-2 branch below looks unreachable -- confirm.
        style = tf.transpose(style, (0, 2, 3, 1))
        style = tf.image.resize(style, (tf.shape(layer_t)[2], tf.shape(layer_t)[3]))
        style = tf.transpose(style, (0, 3, 1, 2))
        style_rank = len(style_attrs.shape)
        flat_attrs = []
        grads = tf.gradients(obj_t, [T(layer) for layer in layers])
        for layer, grad_t in zip(layers, grads):
            layer_t = T(layer)
            # Only positive gradients contribute; stop_gradient keeps the
            # gradient itself from being optimised through.
            attr_t = layer_t * tf.nn.relu(tf.stop_gradient(grad_t))
            if style_rank == 2:
                flat_attr_t = tf.reduce_sum(attr_t, axis=(2, 3))
            elif style_rank == 4:
                flat_attr_t = attr_t
            else:
                raise ValueError(
                    "style_attrs must have rank 2 or 4, got rank %d" % style_rank)
            flat_attrs.append(flat_attr_t)
        flat_attr_t = tf.concat(flat_attrs, -1)
        return tf.reduce_sum(w[:, None, None, None] * flat_attr_t * style)
    return inner


def render_facet(model, neuron_obj, layers, style_attrs, strength=(0.1, 0.3),
                 l2_weight=10.0, resolution=128, alpha=False,
                 batch=9, learning_rate=0.02, steps=1024 * 4):
    """Render a faceted feature visualisation with ``render.render_vis``.

    The optimised objective is ``vector(last "image_block_4" node, neuron_obj)``
    plus one ``attr`` term per ``(style, layer)`` pair (each scaled by
    ``5 / len(layers)``) plus ``l2_weight * l2()``.  When ``alpha`` is true,
    two alpha penalties are subtracted and a random alpha-collapse transform
    is appended.

    Args:
        model: model passed to ``render.render_vis``; its ``graph_def`` is
            searched for node names containing ``"image_block_4"``.
        neuron_obj: direction vector handed to ``vector``.
        layers: layer names, paired element-wise with ``style_attrs``.
        style_attrs: one style attribution array per layer.
        strength: ``(start, stop)`` pair forwarded to ``attr``.
        l2_weight: multiplier of the ``l2`` objective.
        resolution: side length used for the parameterisation and final crop.
        alpha: whether to optimise an image with an alpha channel.
        batch: number of images optimised together.
        learning_rate: Adam learning rate.
        steps: single threshold passed to ``render_vis``.

    Returns:
        Whatever ``render.render_vis`` returns.
    """

    def mean_alpha():
        # Root-mean-square of every input channel from index 3 onwards.
        def inner(T):
            input_t = T("input")
            return tf.sqrt(tf.reduce_mean(input_t[..., 3:] ** 2))
        return objectives.Objective(inner)

    standard_transforms = [
        transform.pad(2, mode='constant', constant_value=.5),
        # Ten stacked jitters, exactly as in the hand-written list.
        *[transform.jitter(4) for _ in range(10)],
        transform.random_scale([0.995**n for n in range(-5, 80)] + [0.998**n for n in 2*list(range(20, 40))]),
        transform.random_rotate(list(range(-20, 20)) + list(range(-10, 10)) + list(range(-5, 5)) + 5*[0]),
        transform.jitter(2),
        transform.crop_or_pad_to(resolution, resolution),
    ]

    if alpha:
        standard_transforms.append(transform.collapse_alpha_random())

    def param_f():
        return param.image(resolution, batch=batch, alpha=alpha)

    optimizer = tf.train.AdamOptimizer(learning_rate)
    ultimate_layer = [n.name for n in model.graph_def.node if "image_block_4" in n.name][-1]
    obj = vector(ultimate_layer, neuron_obj)
    # The scale is evaluated inside the loop so that an empty ``layers``
    # simply adds no facet terms, as before.
    for style, layer in zip(style_attrs, layers):
        obj = obj + (5/len(layers)) * attr(obj, style, [layer], strength)
    obj = obj + l2_weight*l2()
    if alpha:
        obj -= mean_alpha()
        obj -= 1e2 * objectives.blur_alpha_each_step()
    return render.render_vis(model, obj, param_f, transforms=standard_transforms,
                             optimizer=optimizer, thresholds=(steps,))


def one_hot(ind, size=2560):
    """Return a float32 vector of length ``size`` that is 1 at ``ind`` and 0 elsewhere."""
    z = np.zeros(size, dtype=np.float32)
    z[ind] = 1
    return z


def loadnpy(url):
    """Read the blob at ``url`` fully into memory and decode it with ``np.load``."""
    import blobfile
    from io import BytesIO
    # The context manager closes the handle even if reading or decoding fails.
    with blobfile.BlobFile(url, "rb") as fp:
        return np.load(BytesIO(fp.read()))


facets = ["face", "text", "logo", "pose", "arch", "nature", "indoor"]
model = CLIPImage()
d = one_hot(100)

# Every second "Relu_2" node of image_block_3; this does not depend on the
# facet, so it is computed once.
layernames = [n.name for n in model.graph_def.node if ("image_block_3" in n.name) and ("Relu_2" in n.name)][::2]

for facet in facets:
    style_attrs = [loadnpy(f"https://openaipublic.blob.core.windows.net/clip/facets/{model.name}/{layername}/{facet}_spatial.npy") for layername in layernames]
    for l2_weight in [10]:
        img = render_facet(model,
                           d,
                           layernames,
                           style_attrs,
                           l2_weight = l2_weight,
                           strength = (0.1, 5.0),
                           alpha = False,
                           resolution = 256)
        save(img[0][-1], f"/root/{facet}.png")

# --------------------------------------------------------------------------------
# /example_usage.py:
# --------------------------------------------------------------------------------
from tokenizer import SimpleTokenizer
from model import CLIPImage, CLIPText
import tensorflow as tf
from lucid.misc.io import load
import numpy as np


def imresize(img, size, scale=255):
    """Resize ``img`` to ``size`` with bicubic interpolation via PIL.

    The array is multiplied by ``scale`` and cast to uint8 for PIL, and the
    resized result is returned as float32 divided by ``scale`` again.
    """
    from PIL import Image
    im = Image.fromarray((img*scale).astype(np.uint8))
    return np.array(im.resize(size, Image.BICUBIC)).astype(np.float32)/scale


tokenizer = SimpleTokenizer()

tf.reset_default_graph()
inp_text, T_text = CLIPText().load()
inp_img, T_img = CLIPImage().load()

captions = ["This is a dog", "This is a cat", "This is a dog and a cat"]
# tokenize() returns a one-element list when padding, hence the [0].
tokens = [tokenizer.tokenize(caption)[0] for caption in captions]

img = imresize(load("https://openaipublic.blob.core.windows.net/clarity/dog_cat.jpeg"), [288,288])

# The session is closed as soon as both embeddings have been computed.
with tf.Session() as sess:
    text_embd = sess.run(T_text("text_post/l2_normalize"), {inp_text: tokens})
    img_embd = sess.run(T_img("l2_normalize"), {inp_img: [img]})

# One score per caption against the single image.
scores = (text_embd @ img_embd.T)[:,0]

for score, caption in zip(scores, captions):
    print(caption, score)

# --------------------------------------------------------------------------------
# /model-card.md:
# --------------------------------------------------------------------------------
# 1 | # Model Card: CLIP
# 2 |
# 3 | Inspired by [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993) and [Lessons from Archives (Jo & Gebru)](https://arxiv.org/pdf/1912.10389.pdf), we’re providing some accompanying information about the multimodal model.
# 4 |
# 5 | ## Model Details
# 6 |
# 7 | The CLIP model was developed by researchers at OpenAI to learn about what contributes to robustness in computer vision tasks. The model was also developed to test the ability of models to generalize to arbitrary image classification tasks in a zero-shot manner. It was not developed for general model deployment - to deploy models like CLIP, researchers will first need to carefully study their capabilities in relation to the specific context they’re being deployed within.
# 8 |
# 9 | ### Model Date
# 10 |
# 11 | January 2021
# 12 |
# 13 | ### Model Type
# 14 |
# 15 | The base model uses a ResNet50 with several modifications as an image encoder and uses a masked self-attention Transformer as a text encoder. These encoders are trained to maximize the similarity of (image, text) pairs via a contrastive loss. There is also a variant of the model where the ResNet image encoder is replaced with a Vision Transformer.
# 16 |
# 17 | ### Model Version
# 18 |
# 19 | We are releasing RN50x4, a RN50 scaled up 4x according to the [EfficientNet](https://arxiv.org/abs/1905.11946) scaling rule.
# 20 |
# 21 | Please see the paper linked below for further details about their specification.
22 | 23 | ### Documents 24 | 25 | - [Blog Post](https://openai.com/blog/clip/) 26 | - [CLIP Paper](https://arxiv.org/abs/2103.00020) 27 | 28 | 29 | 30 | ## Model Use 31 | 32 | ### Intended Use 33 | 34 | The model is intended as a research output for research communities. We hope that this model will enable researchers to better understand and explore zero-shot, arbitrary image classification. We also hope it can be used for interdisciplinary studies of the potential impact of such models - the CLIP paper includes a discussion of potential downstream impacts to provide an example for this sort of analysis. 35 | 36 | #### Primary intended uses 37 | 38 | The primary intended users of these models are AI researchers. 39 | 40 | We primarily imagine the model will be used by researchers to better understand robustness, generalization, and other capabilities, biases, and constraints of computer vision models. 41 | 42 | ### Out-of-Scope Use Cases 43 | 44 | **Any** deployed use case of the model - whether commercial or not - is currently out of scope. Non-deployed use cases, such as image search in a constrained environment, are also not recommended unless there is thorough in-domain testing of the model with a specific, fixed class taxonomy. This is because our safety assessment demonstrated a high need for task-specific testing, especially given the variability of CLIP’s performance with different class taxonomies. This makes untested and unconstrained deployment of the model in any use case currently potentially harmful. 45 | 46 | Certain use cases which would fall under the domain of surveillance and facial recognition are always out-of-scope regardless of performance of the model. This is because the use of artificial intelligence for tasks such as these can currently be premature, given the lack of testing norms and checks to ensure its fair use. 
47 | 48 | Since the model has not been purposefully trained in or evaluated on any languages other than English, its use should be limited to English-language use cases. 49 | 50 | 51 | 52 | ## Data 53 | 54 | The model was trained on publicly available image-caption data. This was done through a combination of crawling a handful of websites and using commonly-used pre-existing image datasets such as [YFCC100M](http://projects.dfki.uni-kl.de/yfcc100m/). A large portion of the data comes from our crawling of the internet. This means that the data is more representative of people and societies most connected to the internet, which tend to skew towards more developed nations, and younger, male users. 55 | 56 | ### Data Mission Statement 57 | 58 | Our goal with building this dataset was to test out robustness and generalizability in computer vision tasks. As a result, the focus was on gathering large quantities of data from different publicly-available internet data sources. The data was gathered in a mostly non-interventionist manner. However, we only crawled websites that had policies against excessively violent and adult images and allowed us to filter out such content. We do not intend for this dataset to be used as the basis for any commercial or deployed model and will not be releasing the dataset. 59 | 60 | 61 | 62 | ## Performance and Limitations 63 | 64 | ### Performance 65 | 66 | We have evaluated the performance of CLIP on a wide range of benchmarks across a variety of computer vision datasets, ranging from OCR to texture recognition to fine-grained classification. 
The paper describes model performance on the following datasets: 67 | 68 | - Food101 69 | - CIFAR10 70 | - CIFAR100 71 | - Birdsnap 72 | - SUN397 73 | - Stanford Cars 74 | - FGVC Aircraft 75 | - VOC2007 76 | - DTD 77 | - Oxford-IIIT Pet dataset 78 | - Caltech101 79 | - Flowers102 80 | - MNIST 81 | - SVHN 82 | - IIIT5K 83 | - Hateful Memes 84 | - SST-2 85 | - UCF101 86 | - Kinetics700 87 | - Country211 88 | - CLEVR Counting 89 | - KITTI Distance 90 | - STL-10 91 | - RareAct 92 | - Flickr30 93 | - MSCOCO 94 | - ImageNet 95 | - ImageNet-A 96 | - ImageNet-R 97 | - ImageNet Sketch 98 | - ObjectNet (ImageNet Overlap) 99 | - Youtube-BB 100 | - ImageNet-Vid 101 | 102 | ### Limitations 103 | 104 | CLIP and our analysis of it have a number of limitations. CLIP currently struggles with respect to certain tasks such as fine-grained classification and counting objects. CLIP also poses issues with regard to fairness and bias, which we discuss in the paper and briefly in the next section. Additionally, our approach to testing CLIP also has an important limitation: in many cases we have used linear probes to evaluate the performance of CLIP and there is evidence suggesting that linear probes can underestimate model performance. 105 | 106 | ### Bias and Fairness 107 | 108 | We find that the performance of CLIP - and the specific biases it exhibits - can depend significantly on class design and the choices one makes for categories to include and exclude. We tested the risk of certain kinds of denigration with CLIP by classifying images of people from [Fairface](https://arxiv.org/abs/1908.04913) into crime-related and non-human animal categories. We found significant disparities with respect to race and gender. Additionally, we found that these disparities could shift based on how the classes were constructed. (Details are captured in the Broader Impacts section of the paper.) 
# (model-card.md, final lines -- kept verbatim from the dump)
# 109 |
# 110 | We also tested the performance of CLIP on gender, race and age classification using the Fairface dataset (We default to using race categories as they are constructed in the Fairface dataset.) in order to assess quality of performance across different demographics. We found accuracy >96% across all races for gender classification with ‘Middle Eastern’ having the highest accuracy (98.4%) and ‘White’ having the lowest (96.5%). Additionally, CLIP averaged ~93% for racial classification and ~63% for age classification. Our use of evaluations to test for gender, race and age classification as well as denigration harms is simply to evaluate performance of the model across people and surface potential risks and not to demonstrate an endorsement/enthusiasm for such tasks.
# 111 |
# 112 |
# 113 |
# 114 | ## Feedback
# 115 |
# 116 | ### Where to send questions or comments about the model
# 117 |
# 118 | Please use [this Google Form](https://forms.gle/Uv7afRH5dvY34ZEs9)
# 119 |
# --------------------------------------------------------------------------------
# /model.py:
# --------------------------------------------------------------------------------
from lucid.modelzoo.vision_base import Model
from lucid.optvis import render
import tensorflow as tf
from lucid.misc.io import load, save


class CLIPImage(Model):
    """Lucid ``Model`` wrapper around the RN50_4x CLIP image graph."""

    image_value_range = (0, 255)
    input_name = 'input_image'

    def __init__(self):
        self.model_name = "RN50_4x"
        self.image_shape = [288, 288, 3]
        self.model_path = "https://openaipublic.blob.core.windows.net/clip/tf/RN50_4x/084ee9c176da32014b0ebe42cd7ca66e/image32.pb"

    def load(self, inp=None):
        """Import the image graph and return ``(input_tensor, T)``.

        Args:
            inp: tensor to feed the model; when ``None`` a float32 placeholder
                of shape ``(None, image_shape[0], image_shape[1], 3)`` is created.

        Returns:
            The input tensor and the accessor returned by ``render.import_model``.
        """
        import tensorflow as tf
        # Identity test: ``== None`` on a tensor or array is not a plain
        # "was the argument omitted" check.
        if inp is None:
            self.inp = tf.placeholder(shape=(None, self.image_shape[0], self.image_shape[1], 3), dtype=tf.float32)
        else:
            self.inp = inp
        self.T = render.import_model(self, self.inp, self.inp)
        return self.inp, self.T


class CLIPText(Model):
    """Lucid ``Model`` wrapper around the RN50_4x CLIP text graph."""

    input_name = 'tokens'

    def __init__(self):
        self.model_name = "RN50_4x_text"
        self.model_path = "https://openaipublic.blob.core.windows.net/clip/tf/RN50_4x/da21bc82c7bba068aa8163333438354c/text32.pb"

    def load(self, O=None):
        """Import the text graph under the ``"text"`` name scope.

        Args:
            O: int32 token tensor; when ``None`` a ``[None, None]`` int32
                placeholder is created.

        Returns:
            ``(O, T)`` where ``T(name)`` looks up the tensor ``"text/<name>:0"``
            in the graph that was the default graph at load time.
        """
        import tensorflow as tf
        if O is None:
            self.O = tf.placeholder(tf.int32, [None, None])
        else:
            self.O = O
        tf.import_graph_def(self.graph_def, {self.input_name: self.O}, name="text")
        gph = tf.get_default_graph()
        self.T = lambda x: gph.get_tensor_by_name("text/" + x + ":0")
        return self.O, self.T

# --------------------------------------------------------------------------------
# /requirements.txt:
# --------------------------------------------------------------------------------
# 1 | ftfy
# 2 | regex
# 3 | tqdm
# 4 | lucid
# 5 | blobfile
# 6 | tensorflow-gpu==1.13.2
# --------------------------------------------------------------------------------
# /tokenizer.py:
# --------------------------------------------------------------------------------
# By Alec Radford

import html
import ftfy
import json
import regex as re
from functools import lru_cache
import tensorflow as tf
import blobfile
import numpy as np


def pad(x, pad_length = 76):
    """Return a zero array of length ``pad_length`` whose first ``len(x)`` entries are ``x``."""
    z = np.zeros((pad_length))
    z[0:len(x)] = x
    return z


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            # Bytes outside the three ranges above are mapped to code points
            # starting at 256.
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    """Run ``ftfy.fix_text``, unescape HTML entities twice, and strip the result."""
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    """Collapse every whitespace run to a single space and strip the ends."""
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    """Byte-level BPE tokenizer built from a merges file."""

    def __init__(self, bpe_path = None):
        """Build the vocabulary and merge ranks.

        Args:
            bpe_path: ``None`` to read the default merges file, a string
                path/URL that is opened with ``blobfile.BlobFile``, or an
                already-open text file object exposing ``read()``.
        """
        if bpe_path is None:
            bpe_path = 'https://openaipublic.blob.core.windows.net/clip/bpe_simple_vocab_16e6.txt'
        if isinstance(bpe_path, str):
            # Handles opened here are closed here; a caller-supplied file
            # object is left open for the caller to manage.
            with blobfile.BlobFile(bpe_path, 'r') as fp:
                merges_text = fp.read()
        else:
            merges_text = bpe_path.read()
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
        merges = merges_text.split('\n')
        # Skip the first line and keep 49152-256-2 merge rules.
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        # Every byte symbol also exists in an end-of-word form.  The marker
        # must be non-empty, otherwise these entries duplicate the ones above
        # and the ids assigned below no longer line up.
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v:k for k,v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>':'<|startoftext|>', '<|endoftext|>':'<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Apply the learned merges to ``token``; return its symbols joined by spaces.

        Results are memoised in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        # The last symbol carries the end-of-word marker.
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # Lowest-ranked (earliest learned) pair first; unknown pairs rank last.
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # tuple.index found no further occurrence of ``first``.
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean and lowercase ``text`` and return its list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Map token ids back to text, turning end-of-word markers into spaces."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text

    def tokenize(self, text, n_text = 76, pad = True):
        """Encode ``text`` wrapped in start/end tokens.

        At most ``n_text - 1`` encoded tokens are kept.  With ``pad`` true the
        result is a one-element list holding the sequence zero-padded to
        ``n_text + 1`` entries; otherwise the unpadded sequence is returned.
        """
        sot = self.encoder['<|startoftext|>']
        eot = self.encoder['<|endoftext|>']
        tokens = self.encode(text)
        tokens = [sot]+tokens[:n_text-1]+[eot]
        if pad:
            return [tokens + [0]*(n_text+1-len(tokens))]
        else:
            return tokens

    def sot(self):
        """Return the id of ``<|startoftext|>``."""
        return self.encoder['<|startoftext|>']

    def eot(self):
        """Return the id of ``<|endoftext|>``."""
        return self.encoder['<|endoftext|>']

# --------------------------------------------------------------------------------