├── README.md ├── crepe_compo_eval_open_clip.py ├── crepe_eval_utils.py ├── crepe_params.py ├── crepe_prod_eval_albef.py ├── crepe_prod_eval_clip.py ├── crepe_prod_eval_cyclip.py ├── crepe_prod_eval_flava.py ├── data ├── prod_hard_negatives.zip └── syst_hard_negatives.zip └── open_clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── factory.py ├── loss.py ├── model.py ├── model_configs ├── RN101-quickgelu.json ├── RN101.json ├── RN50-quickgelu.json ├── RN50.json ├── RN50x16.json ├── RN50x4.json ├── ViT-B-16-plus-240.json ├── ViT-B-16-plus.json ├── ViT-B-16.json ├── ViT-B-32-plus-256.json ├── ViT-B-32-quickgelu.json ├── ViT-B-32.json ├── ViT-H-14.json ├── ViT-H-16.json ├── ViT-L-14-280.json ├── ViT-L-14-336.json ├── ViT-L-14.json ├── ViT-L-16-320.json ├── ViT-L-16.json ├── ViT-g-14.json ├── timm-efficientnetv2_rw_s.json ├── timm-resnet50d.json ├── timm-resnetaa50d.json ├── timm-resnetblur50.json ├── timm-swin_base_patch4_window7_224.json ├── timm-vit_base_patch16_224.json ├── timm-vit_base_patch32_224.json └── timm-vit_small_patch16_224.json ├── openai.py ├── pretrained.py ├── timm_model.py ├── tokenizer.py ├── transform.py ├── utils.py └── version.py /README.md: -------------------------------------------------------------------------------- 1 | # CREPE 2 | 3 | This repository contains the code we used to evaluate [open_clip](https://github.com/mlfoundations/open_clip) CLIP models, 4 | the official OpenAI CLIP models, CyCLIP, FLAVA and ALBEF on compositional reasoning, as reported in our paper [CREPE: Can Vision-Language Foundation Models Reason Compositionally?](https://arxiv.org/abs/2212.07796). 5 | 6 | 7 | 8 | ## Systematicity procedure 9 | 10 | 11 | ## Productivity procedure 12 | 13 | 14 | ## Evaluation instructions 15 | `crepe_eval_utils.py` contains common evaluation utility functions; you will need to replace `vg_image_paths` 16 | with the paths to the Visual Genome images on your machine. The VG images can be downloaded [here](https://drive.google.com/drive/folders/11dMtJByk7zmbQjV47PXVwfmakN3Gr5Ic?usp=share_link). 17 | 18 | We evaluated all models on an NVIDIA TITAN X GPU with CUDA 11.4. 19 | 20 | ### Evaluate open_clip CLIP models on systematicity and productivity 21 | You will need to install the packages required to use open_clip, listed [here](https://github.com/mlfoundations/open_clip/blob/main/requirements.txt). 22 | You can download the [pretrained CLIP models](https://github.com/mlfoundations/open_clip#pretrained-model-details) and point `--model-dir` 23 | to your own model checkpoint directory when running `crepe_compo_eval_open_clip.py`. (You can also modify the code to use open_clip's 24 | [pretrained model interface](https://github.com/mlfoundations/open_clip#pretrained-model-interface).) 25 | 26 | To evaluate all models reported in our paper, simply run: 27 | 28 | ``` 29 | python -m crepe_compo_eval_open_clip --compo-type <compo_type> --hard-neg-types <hard_neg_types> --input-dir <input_dir> --output-dir <output_dir> 30 | ``` 31 | 32 | where the valid compositionality types are `systematicity` and `productivity`. The valid negative types are `atom`, `comp` and `combined` (`atom`+`comp`) for systematicity, and `atom`, `swap` and `negate` for productivity. 33 | 34 | To evaluate other pretrained models, simply modify the `--train-dataset` argument and/or the `DATA2MODEL` variable in `crepe_compo_eval_open_clip.py`.
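For example, to run the systematicity evaluation for the CC12M-pretrained checkpoints with both atom and compound hard negatives, the invocation would look something like the following (the paths are placeholders for your own checkpoint and data directories):

```
python -m crepe_compo_eval_open_clip --compo-type systematicity --hard-neg-types atom comp --train-dataset cc12m --model-dir /path/to/open_clip_checkpoints --input-dir /path/to/syst_hard_negatives --output-dir log/
```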
35 | **Note that the systematicity eval set should only be used to evaluate models pretrained on CC12M, YFCC15M or LAION400M.** 36 | 37 | ### Evaluate all other vision-language models on productivity 38 | For each model, you will need to clone the model's official repository, set up 39 | an environment according to its instructions and place the files `crepe_prod_eval_<model>.py` 40 | and `crepe_eval_utils.py` in the locations described below. In `crepe_params.py`, you will need to replace the default value of `--input-dir` 41 | with your own directory path to CREPE's productivity hard negatives test set. 42 | 43 | #### CLIP-specific instructions 44 | Clone the CLIP repository [here](https://github.com/openai/CLIP) and place `crepe_prod_eval_clip.py` 45 | and `crepe_eval_utils.py` at the top level of the repository. To evaluate models, simply run: 46 | 47 | ``` 48 | python -m crepe_prod_eval_clip --model-name <model_name> --hard-neg-types <hard_neg_types> --output-dir <output_dir> 49 | ``` 50 | 51 | where the valid negative types are `atom`, `swap` and `negate`, and the valid model names are `RN50`, `RN101`, `ViT-B/32`, `ViT-B/16` and `ViT-L/14`. 52 | 53 | #### CyCLIP-specific instructions 54 | Clone the CyCLIP repository [here](https://github.com/goel-shashank/CyCLIP), place `crepe_prod_eval_cyclip.py` 55 | and `crepe_eval_utils.py` at the top level of the repository and download the 56 | model checkpoint under the folder `cyclip.pt` (linked at the bottom of the 57 | repository's README). Note that `crepe_prod_eval_cyclip.py` loads the checkpoint from `best.pt` in the working directory. To evaluate models, simply run: 58 | 59 | ``` 60 | python -m crepe_prod_eval_cyclip --hard-neg-types <hard_neg_types> --output-dir <output_dir> 61 | ``` 62 | 63 | #### FLAVA-specific instructions 64 | Clone the FLAVA repository [here](https://github.com/facebookresearch/multimodal) and copy `crepe_prod_eval_flava.py` 65 | and `crepe_eval_utils.py` into the folder `examples/flava/`. To evaluate models, simply run: 66 | 67 | ``` 68 | python -m crepe_prod_eval_flava --hard-neg-types <hard_neg_types> --output-dir <output_dir> 69 | ``` 70 | 71 | #### ALBEF-specific instructions 72 | Clone the ALBEF repository [here](https://github.com/salesforce/ALBEF/tree/b9727e43c3040491774d1b22cc27718aa7772fac), 73 | copy `crepe_prod_eval_albef.py` and `crepe_eval_utils.py` to the top level of the repository 74 | and download the pretrained checkpoint marked '14M' from the repository.
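When run from the top level of the ALBEF repository, `crepe_prod_eval_albef.py` reads its config from `./configs/Retrieval_coco.yaml` and loads the checkpoint from `./ALBEF.pth`, so save or rename the downloaded 14M checkpoint accordingly, for example (the downloaded filename below is an assumption; substitute the actual name of your file):

```
mv ALBEF_14M.pth ALBEF.pth
```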
To evaluate models, simply run: 75 | 76 | ``` 77 | python -m crepe_prod_eval_albef --hard-neg-types --output-dir 78 | ``` 79 | 80 | ## Citation 81 | If you find our work helpful, please cite us: 82 | 83 | ```bibtex 84 | @article{ma2023crepe, 85 | title={CREPE: Can Vision-Language Foundation Models Reason Compositionally?}, 86 | author={Zixian Ma and Jerry Hong and Mustafa Omer Gul and Mona Gandhi and Irena Gao and Ranjay Krishna}, 87 | year={2023}, 88 | journal={arXiv preprint arXiv:2212.07796}, 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /crepe_compo_eval_open_clip.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | import torch 5 | import numpy as np 6 | import torch.nn.functional as F 7 | import torchvision.transforms.functional as TF 8 | from torch.utils.data import DataLoader 9 | from torch.utils.data.distributed import DistributedSampler 10 | from dataclasses import dataclass 11 | from open_clip import tokenize, create_model_and_transforms 12 | from crepe_eval_utils import BaseCsvDataset, get_one2many_metrics, get_one2many_rank, get_metrics 13 | from crepe_params import setup_args 14 | 15 | DATA2MODEL = { 16 | 'cc12m': { 17 | 'RN50-quickgelu': 'rn50-quickgelu-cc12m-f000538c.pt' 18 | }, 19 | 'yfcc': { 20 | 'RN50-quickgelu': 'rn50-quickgelu-yfcc15m-455df137.pt', 21 | 'RN101-quickgelu': 'rn101-quickgelu-yfcc15m-3e04b30e.pt' 22 | }, 23 | 'laion': { 24 | 'ViT-B-16':'vit_b_16-laion400m_e32-55e67d44.pt', 25 | 'ViT-B-16-plus-240': 'vit_b_16_plus_240-laion400m_e32-699c4b84.pt', 26 | 'ViT-B-32-quickgelu': 'vit_b_32-quickgelu-laion400m_e32-46683a32.pt', 27 | 'ViT-L-14': 'vit_l_14-laion400m_e32-3d133497.pt', 28 | } 29 | } 30 | 31 | COMPO_SPLITS = ['seen_compounds', 'unseen_compounds'] 32 | COMPLEXITIES = list(range(4, 13)) 33 | 34 | @dataclass 35 | class DataInfo: 36 | dataloader: DataLoader 37 | sampler: DistributedSampler 38 | 39 | class CsvDataset(BaseCsvDataset): 40 | def __init__(self, input_filename, args, transforms): 41 | super().__init__(input_filename, args, transforms=transforms) 42 | 43 | def __getitem__(self, idx): 44 | raw_image = self.get_image_by_id(self.images[idx]) 45 | if self.crop: 46 | raw_image = TF.crop(raw_image, self.ys[idx], self.xs[idx], self.heights[idx], self.widths[idx]) 47 | image = self.transforms(raw_image) 48 | if self.one2many: 49 | texts = tokenize([str(self.captions[idx])] + list(self.hard_negs[idx])) 50 | else: 51 | texts = tokenize([str(self.captions[idx])])[0] 52 | return image, texts 53 | 54 | def get_csv_dataset(args, preprocess_fn, is_train): 55 | input_filename = args.val_data 56 | assert input_filename 57 | dataset = CsvDataset( 58 | input_filename, 59 | args, 60 | preprocess_fn) 61 | num_samples = len(dataset) 62 | 63 | sampler = None 64 | shuffle = is_train and sampler is None 65 | 66 | dataloader = DataLoader( 67 | dataset, 68 | batch_size=args.batch_size, 69 | shuffle=shuffle, 70 | num_workers=1, 71 | pin_memory=True, 72 | sampler=sampler, 73 | drop_last=is_train, 74 | ) 75 | dataloader.num_samples = num_samples 76 | dataloader.num_batches = len(dataloader) 77 | 78 | return DataInfo(dataloader, sampler) 79 | 80 | def get_data(args, preprocess_fns): 81 | preprocess_train, preprocess_val = preprocess_fns 82 | data = {} 83 | 84 | data["val"] = get_csv_dataset( 85 | args, preprocess_val, is_train=False) 86 | return data 87 | 88 | def evaluate(model, data, args): 89 | metrics = {} 90 | device = 
torch.device(args.device) 91 | model.eval() 92 | 93 | autocast = torch.cuda.amp.autocast 94 | dataloader = data['val'].dataloader 95 | 96 | # FIXME this does not scale past small eval datasets 97 | # all_image_features @ all_text_features will blow up memory and compute very quickly 98 | all_image_features, all_text_features = [], [] 99 | one2many = dataloader.dataset.one2many 100 | if one2many: 101 | all_ranks = [] 102 | with torch.no_grad(): 103 | for i, batch in enumerate(dataloader): 104 | images, texts = batch 105 | images = images.to(device=device, non_blocking=True) 106 | texts = texts.to(device=device, non_blocking=True) 107 | 108 | if one2many: 109 | image_features = model.encode_image(images) 110 | image_features = F.normalize(image_features, dim=-1) 111 | 112 | texts = torch.squeeze(texts, dim=0) 113 | text_features = model.encode_text(texts) 114 | text_features = F.normalize(text_features, dim=-1) 115 | 116 | rank = get_one2many_rank(image_features, text_features) 117 | all_ranks.append(rank) 118 | else: 119 | with autocast(): 120 | image_features, text_features, logit_scale = model(images, texts) 121 | # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly 122 | # however, system RAM is easily exceeded and compute time becomes problematic 123 | all_image_features.append(image_features.cpu()) 124 | all_text_features.append(text_features.cpu()) 125 | 126 | if one2many: 127 | val_metrics = get_one2many_metrics(np.array(all_ranks)) 128 | metrics.update( 129 | {**val_metrics} 130 | ) 131 | else: 132 | val_metrics = get_metrics( 133 | image_features=torch.cat(all_image_features), 134 | text_features=torch.cat(all_text_features) 135 | ) 136 | metrics.update( 137 | {**val_metrics} 138 | ) 139 | 140 | logging.info("\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])) 141 | 142 | return metrics 143 | 144 | def gather_params(args, hard_neg_type, split): 145 | if args.compo_type == 'systematicity': 146 | if hard_neg_type in ['atom', 'comp', 'combined']: 147 | hard_neg_key = f'valid_hard_negs_{hard_neg_type}' 148 | else: 149 | raise NotImplementedError 150 | 151 | retrieval_data_path = os.path.join(args.input_dir, f'syst_vg_hard_negs_{split}_in_{args.train_dataset}.csv') 152 | 153 | elif args.compo_type == 'productivity': 154 | hard_neg_key = 'hard_negs' 155 | if hard_neg_type in ['atom', 'negate', 'swap']: 156 | input_dir = os.path.join(args.input_dir, hard_neg_type) 157 | retrieval_data_path = os.path.join(input_dir, f'prod_vg_hard_negs_{hard_neg_type}_complexity_{split}.csv') 158 | else: 159 | raise NotImplementedError 160 | else: 161 | raise NotImplementedError 162 | 163 | args.val_data = retrieval_data_path 164 | args.one2many = True 165 | args.crop = True 166 | args.hard_neg_key = hard_neg_key 167 | args.batch_size = 1 168 | return args 169 | 170 | def main(): 171 | args = setup_args() 172 | models = DATA2MODEL[args.train_dataset].keys() 173 | if args.compo_type == 'systematicity': 174 | splits = COMPO_SPLITS 175 | elif args.compo_type == 'productivity': 176 | splits = COMPLEXITIES 177 | 178 | if args.output_dir: 179 | if not os.path.exists(args.output_dir): 180 | os.mkdir(args.output_dir) 181 | 182 | if torch.cuda.is_available(): 183 | device = 'cuda:0' 184 | torch.cuda.set_device(device) 185 | else: 186 | device = 'cpu' 187 | args.device = device 188 | device = torch.device(device) 189 | 190 | for model_name in models: 191 | pretrained = os.path.join(args.model_dir, DATA2MODEL[args.train_dataset][model_name]) 192 | model, preprocess_train, 
preprocess_val = create_model_and_transforms( 193 | model_name, 194 | pretrained, 195 | precision='amp', 196 | device=device 197 | ) 198 | for hard_neg_type in args.hard_neg_types: 199 | all_metrics = {} 200 | for split in splits: 201 | # params = gather_params(args, model, split) 202 | print('\n' + '*' * 45 + f' Evaluating {model_name} {args.compo_type} on HN-{hard_neg_type.upper()} test set split {split} ' + '*' * 45 + '\n') 203 | args = gather_params(args, hard_neg_type, split) 204 | # initialize datasets 205 | data = get_data(args, (preprocess_train, preprocess_val)) 206 | assert len(data), 'At least one dataset must be specified.' 207 | 208 | metrics = evaluate(model, data, args) 209 | 210 | all_metrics[split] = metrics 211 | 212 | if args.output_dir: 213 | output = os.path.join(args.output_dir, f'{args.compo_type}_{args.train_dataset}_{model_name}_{hard_neg_type}_metrics.json') 214 | print("saving results to:", output) 215 | with open(output, 'w') as f: 216 | json.dump(all_metrics, f) 217 | 218 | if __name__ == "__main__": 219 | main() 220 | -------------------------------------------------------------------------------- /crepe_eval_utils.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import logging 3 | import os 4 | from PIL import Image 5 | from dataclasses import dataclass 6 | 7 | import torch 8 | from torch.utils.data import DataLoader, Dataset 9 | import numpy as np 10 | 11 | import pandas as pd 12 | 13 | logging.basicConfig(level=logging.INFO) 14 | logger = logging.getLogger() 15 | 16 | ### DATASET CONSTRUCTION 17 | 18 | class BaseCsvDataset(Dataset): 19 | def __init__(self, input_filename, args, transforms=None): 20 | logging.debug(f'Loading csv data from {input_filename}.') 21 | df = pd.read_csv(input_filename) 22 | # print(f"Total number of examples: {len(df)}.") 23 | self.crop = args.crop 24 | if self.crop: 25 | assert 'x' in df.columns and 'y' in df.columns and 'width' in df.columns and 'height' in df.columns, "missing x, y, width, or height." 
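# Per-example bounding boxes: each subclass's __getitem__ crops the raw image to (y, x, height, width) before applying the model-specific transform.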
26 | self.xs = df['x'].tolist() 27 | self.ys = df['y'].tolist() 28 | self.heights = df['height'].tolist() 29 | self.widths = df['width'].tolist() 30 | # print("cropping:", self.crop) 31 | self.one2many = args.one2many 32 | # print("one2many:", self.one2many) 33 | if self.one2many: 34 | self.hard_negs = [ast.literal_eval(ls_str) for ls_str in df[args.hard_neg_key]] 35 | self.images = df[args.csv_img_key].tolist() 36 | self.captions = df[args.csv_caption_key].tolist() 37 | self.transforms = transforms 38 | 39 | def __len__(self): 40 | return len(self.captions) 41 | 42 | def get_image_by_id(self, image_id): 43 | vg_image_paths = ['/nlp/scr/irena/data/visual_genome/img/VG_100K', '/nlp/scr/irena/data/visual_genome/img/VG_100K_2'] 44 | for p in vg_image_paths: 45 | path = os.path.join(p, f"{image_id}.jpg") 46 | if os.path.exists(path): 47 | return Image.open(path).convert("RGB") 48 | raise FileNotFoundError(f'The image with id {image_id} is not found.') 49 | 50 | def __getitem__(self, idx): 51 | print("Not yet implemented.") 52 | assert(False) 53 | 54 | @dataclass 55 | class DataInfo: 56 | dataloader: DataLoader 57 | 58 | # EVALUATION UTILITIES 59 | 60 | def get_one2many_rank(image_features, text_features): 61 | logits_per_image = (image_features @ text_features.t()).detach().cpu() 62 | ground_truth = 0 # because the grountruth caption is placed first, see CsvDataset.__getitem__() in data.py 63 | ranking = torch.argsort(logits_per_image, descending=True) 64 | pred = torch.where(ranking == ground_truth)[1].detach().cpu().numpy() 65 | return pred 66 | 67 | def get_one2many_metrics(preds, name='image_to_text'): 68 | metrics = {} 69 | metrics[f"{name}_mean_rank"] = preds.mean() + 1 70 | metrics[f"{name}_rank_std"] = preds.std() 71 | metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1 72 | 73 | for k in [1, 3, 5, 10]: 74 | metrics[f"{name}_R@{k}"] = np.mean(preds < k) 75 | metrics[f"{name}_R@{k}_std"] = np.std(preds < k) 76 | return metrics 77 | 78 | def get_metrics(image_features, text_features): 79 | metrics = {} 80 | logits_per_image = (image_features @ text_features.t()).detach().cpu() 81 | logits_per_text = logits_per_image.t().detach().cpu() 82 | 83 | logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text} 84 | ground_truth = torch.arange(len(text_features)).view(-1, 1) 85 | 86 | for name, logit in logits.items(): 87 | ranking = torch.argsort(logit, descending=True) 88 | preds = torch.where(ranking == ground_truth)[1] 89 | preds = preds.detach().cpu().numpy() 90 | metrics[f"{name}_mean_rank"] = preds.mean() + 1 91 | metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1 92 | 93 | for k in [1, 3, 5, 10]: 94 | metrics[f"{name}_R@{k}"] = np.mean(preds < k) 95 | 96 | return metrics 97 | -------------------------------------------------------------------------------- /crepe_params.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def setup_args(): 4 | parser = argparse.ArgumentParser(description="Run image2text retrieval eval.") 5 | parser.add_argument("--compo-type", required=True, type=str, default="systematicity", help="Either systematicity or productivity") 6 | parser.add_argument("--input-dir", required=True, type=str, default="/vision/group/CLIPComp/crepe/prod_hard_negatives") 7 | parser.add_argument('--hard-neg-types', required=True, type=str, nargs='+', help="The type(s) of hard negatives to include in the retrieval set.") 8 | parser.add_argument("--model-dir", type=str, 
default="/vision/group/clip") 9 | parser.add_argument("--output-dir", type=str, default="log/") 10 | parser.add_argument("--csv-img-key", type=str, default="image_id") 11 | parser.add_argument("--csv-caption-key", type=str, default="caption") 12 | parser.add_argument("--hard-neg-key", type=str, default="hard_negs", help="The column name of the hard negative captions.") 13 | parser.add_argument("--crop", type=bool, default=True, help="Whether to crop the image input.") 14 | parser.add_argument("--one2many", type=bool, default=True, help="Whether each image query has a different retrieval text set.") 15 | # For systematicity eval on open_clip's pretrained models with known training dataset 16 | parser.add_argument("--train-dataset", type=str, default="cc12m") 17 | # For CLIP & CyCLIP 18 | parser.add_argument("--model-name", type=str, default="RN50") 19 | # For CyCLIP 20 | parser.add_argument("--pretrained", default=False, action="store_true", help="Use the OpenAI pretrained models") 21 | args = parser.parse_args() 22 | return args -------------------------------------------------------------------------------- /crepe_prod_eval_albef.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | from PIL import Image 10 | from time import time 11 | 12 | import torch 13 | from torch import nn 14 | from torch.utils.data import DataLoader 15 | from torchvision import transforms 16 | import torch.nn.functional as F 17 | import torchvision.transforms.functional as TF 18 | import numpy as np 19 | import json 20 | 21 | # ALBEF: 22 | # from torchmultimodal.transforms.flava_transform import FLAVAImageTransform 23 | import ruamel.yaml as yaml 24 | from models.model_retrieval import ALBEF 25 | from models.vit import interpolate_pos_embed 26 | # from transformers import BertTokenizer 27 | from models.tokenization_bert import BertTokenizer 28 | 29 | from crepe_eval_utils import BaseCsvDataset, get_one2many_rank, get_one2many_metrics, DataInfo 30 | from crepe_params import setup_args 31 | 32 | import pandas as pd 33 | 34 | logging.basicConfig(level=logging.INFO) 35 | logger = logging.getLogger() 36 | 37 | max_text_length = 512 38 | TEXT_DEFAULT_TOKENIZER = "bert-base-uncased" 39 | text_tokenizer = BertTokenizer.from_pretrained(TEXT_DEFAULT_TOKENIZER) 40 | 41 | def collator(batch): 42 | images = torch.stack([x[0] for x in batch], dim=0) 43 | texts = torch.cat([x[1] for x in batch], dim=0) 44 | masks = torch.cat([x[2] for x in batch], dim=0) 45 | 46 | return images, texts, masks 47 | 48 | ### DATASET CONSTRUCTION 49 | 50 | def default_text_transform(texts): 51 | # Expect a list of texts 52 | tokenized_texts = [] 53 | attention_masks = [] 54 | start_time = time() 55 | for text in texts: 56 | tokenized = text_tokenizer(text, padding="max_length", 57 | max_length=max_text_length, truncation=True, return_tensors='pt') 58 | 59 | tokenized_texts.append(tokenized['input_ids']) 60 | attention_masks.append(tokenized['attention_mask']) 61 | 62 | tokenized_texts = torch.cat(tokenized_texts, dim=0) 63 | attention_masks = torch.cat(attention_masks, dim=0) 64 | 65 | return tokenized_texts, attention_masks 66 | 67 | class CsvDataset(BaseCsvDataset): 68 | def __init__(self, input_filename, args, config): 69 | 
super().__init__(input_filename, args) 70 | 71 | # albef transform: 72 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 73 | test_transform = transforms.Compose([ 74 | transforms.Resize((config['image_res'],config['image_res']),interpolation=Image.BICUBIC), 75 | transforms.ToTensor(), 76 | normalize, 77 | ]) 78 | self.image_transform = test_transform 79 | self.text_transform = default_text_transform 80 | 81 | def __getitem__(self, idx): 82 | raw_image = self.get_image_by_id(self.images[idx]) 83 | if self.crop: 84 | raw_image = TF.crop(raw_image, self.ys[idx], self.xs[idx], self.heights[idx], self.widths[idx]) 85 | image = self.transforms(raw_image) 86 | texts, attn_mask = self.text_transform([str(self.captions[idx])] + list(self.hard_negs[idx])) 87 | 88 | return image, texts, attn_mask 89 | 90 | def get_data(args, retrieval_data_path, config): 91 | # Get CSVDataset 92 | input_filename = retrieval_data_path 93 | dataset = CsvDataset( 94 | input_filename, 95 | args, 96 | config=config) 97 | num_samples = len(dataset) 98 | sampler = None 99 | shuffle=False 100 | 101 | dataloader = DataLoader( 102 | dataset, 103 | batch_size=16, 104 | shuffle=shuffle, 105 | num_workers=1, 106 | pin_memory=True, 107 | sampler=sampler, 108 | drop_last=False, 109 | collate_fn=collator 110 | ) 111 | dataloader.num_samples = num_samples 112 | dataloader.num_batches = len(dataloader) 113 | 114 | return DataInfo(dataloader) 115 | 116 | ### EVALUATION 117 | 118 | def evaluate(model, data, complexity, negative_type): 119 | metrics = {} 120 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 121 | 122 | dataloader = data.dataloader 123 | # num_samples = 0 124 | # samples_per_val = dataloader.num_samples 125 | 126 | # cumulative_loss = 0.0 127 | # all_image_features, all_text_features = [], [] 128 | one2many = dataloader.dataset.one2many 129 | assert(one2many, "Not one2many?") 130 | 131 | if one2many: 132 | all_ranks = [] 133 | 134 | with torch.no_grad(): 135 | for i, batch in enumerate(dataloader): 136 | images, texts, masks = batch 137 | images = images.to(device=device, non_blocking=True) 138 | texts = texts.to(device=device, non_blocking=True) 139 | masks = masks.to(device=device, non_blocking=True) 140 | 141 | if one2many: 142 | image_feat = model.visual_encoder(images) 143 | image_embed = model.vision_proj(image_feat[:,0,:]) 144 | image_embed = F.normalize(image_embed,dim=-1) 145 | 146 | text_out = model.text_encoder(texts, attention_mask = masks, mode='text') 147 | text_feat = text_out.last_hidden_state 148 | text_emb = F.normalize(model.text_proj(text_feat[:,0,:])) 149 | 150 | set_size = text_emb.shape[0] // image_embed.shape[0] 151 | for j in range(image_embed.shape[0]): 152 | curr_image_emb = image_embed[j:j+1, :] 153 | curr_text_emb = text_emb[j*set_size:(j+1)*set_size, :] 154 | rank = get_one2many_rank(curr_image_emb, curr_text_emb) 155 | all_ranks.append(rank) 156 | 157 | print(f'Processed example {i*16}') 158 | 159 | metrics = get_one2many_metrics(np.array(all_ranks)) 160 | 161 | # Alter output here 162 | logging.info( 163 | "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) 164 | ) 165 | 166 | return metrics 167 | 168 | def main(): 169 | args = setup_args() 170 | if args.output_dir: 171 | output_dir = os.path.join(args.output_dir, 'albef') 172 | if not os.path.exists(output_dir): 173 | os.makedirs(output_dir) 174 | # LOAD ALBEF 175 | config_str = './configs/Retrieval_coco.yaml' 176 | config = 
yaml.load(open(config_str, 'r'), Loader=yaml.Loader) 177 | tokenizer = BertTokenizer.from_pretrained(TEXT_DEFAULT_TOKENIZER) 178 | albef = ALBEF(config=config, text_encoder=TEXT_DEFAULT_TOKENIZER, tokenizer=tokenizer) 179 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 180 | logger.info(f"Using device: {device}") 181 | 182 | # MODEL CHECKPOINT 183 | checkpoint = torch.load('./ALBEF.pth', map_location='cpu') 184 | state_dict = checkpoint['model'] 185 | 186 | # reshape positional embedding to accomodate for image resolution change 187 | pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],albef.visual_encoder) 188 | state_dict['visual_encoder.pos_embed'] = pos_embed_reshaped 189 | m_pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],albef.visual_encoder_m) 190 | state_dict['visual_encoder_m.pos_embed'] = m_pos_embed_reshaped 191 | 192 | for key in list(state_dict.keys()): 193 | if 'bert' in key: 194 | encoder_key = key.replace('bert.','') 195 | state_dict[encoder_key] = state_dict[key] 196 | del state_dict[key] 197 | msg = albef.load_state_dict(state_dict,strict=False) 198 | albef = albef.to(device) 199 | albef.eval() 200 | 201 | for hard_neg_type in args.hard_neg_types: 202 | all_metrics = {} 203 | # Iterate over each complexity 204 | for i in range(4, 13): 205 | print('\n' + '*' * 45 + f' Evaluating on complexity {i} ' + '*' * 45 + '\n') 206 | start_time = time() 207 | retrieval_data_path = os.path.join(args.input_dir, f'{hard_neg_type}/prod_vg_hard_negs_{hard_neg_type}_complexity_{i}.csv') 208 | 209 | data = get_data(args, retrieval_data_path, config) 210 | metrics = evaluate(albef, data, i, hard_neg_type) 211 | 212 | print(f'Complexity {i} took {time() - start_time} seconds') 213 | all_metrics[i] = metrics 214 | if args.output_dir: 215 | output = os.path.join(output_dir, f'productivity_albef_{hard_neg_type}_metrics.json') 216 | print("saving results to:", output) 217 | with open(output, 'w') as f: 218 | json.dump(all_metrics, f) 219 | 220 | if __name__ == "__main__": 221 | main() 222 | -------------------------------------------------------------------------------- /crepe_prod_eval_clip.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from time import time 4 | import json 5 | 6 | import torch 7 | import torchvision.transforms.functional as TF 8 | import clip 9 | from torch.utils.data import DataLoader 10 | import numpy as np 11 | 12 | import pandas as pd 13 | 14 | from crepe_eval_utils import BaseCsvDataset, get_one2many_rank, get_one2many_metrics, DataInfo 15 | from crepe_params import setup_args 16 | 17 | logging.basicConfig(level=logging.INFO) 18 | logger = logging.getLogger() 19 | 20 | def collator(batch): 21 | images = torch.stack([x[0] for x in batch], dim=0) 22 | texts = torch.cat([x[1] for x in batch], dim=0) 23 | 24 | return images, texts 25 | 26 | ### DATASET CONSTRUCTION 27 | 28 | class CsvDataset(BaseCsvDataset): 29 | def __init__(self, input_filename, args, processor, device): 30 | super().__init__(input_filename, args) 31 | 32 | self.processor = processor 33 | self.device = device 34 | 35 | def __getitem__(self, idx): 36 | raw_image = self.get_image_by_id(self.images[idx]) 37 | if self.crop: 38 | raw_image = TF.crop(raw_image, self.ys[idx], self.xs[idx], self.heights[idx], self.widths[idx]) 39 | 40 | image = self.processor(raw_image) 41 | texts = self.process_text([str(self.captions[idx])] + list(self.hard_negs[idx])) 
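# The ground-truth caption is tokenized first, followed by its hard negatives, so get_one2many_rank() treats index 0 as the correct match.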
42 | return image, texts 43 | 44 | def process_text(self, texts): 45 | proc_text = [clip.tokenize(text, truncate=True) for text in texts] 46 | return torch.cat(proc_text) 47 | 48 | def get_data(args, retrieval_data_path, processor, device): 49 | # Get CSVDataset 50 | input_filename = retrieval_data_path 51 | dataset = CsvDataset( 52 | input_filename, 53 | args, 54 | processor, 55 | device) 56 | num_samples = len(dataset) 57 | sampler = None 58 | shuffle=False 59 | 60 | dataloader = DataLoader( 61 | dataset, 62 | batch_size=16, 63 | shuffle=shuffle, 64 | num_workers=1, 65 | pin_memory=True, 66 | sampler=sampler, 67 | drop_last=False, 68 | collate_fn=collator 69 | ) 70 | dataloader.num_samples = num_samples 71 | dataloader.num_batches = len(dataloader) 72 | 73 | return DataInfo(dataloader) 74 | 75 | ### EVALUATION 76 | 77 | def evaluate(model, data, complexity, negative_type, device): 78 | metrics = {} 79 | 80 | dataloader = data.dataloader 81 | # num_samples = 0 82 | # samples_per_val = dataloader.num_samples 83 | 84 | # cumulative_loss = 0.0 85 | # all_image_features, all_text_features = [], [] 86 | one2many = dataloader.dataset.one2many 87 | 88 | if one2many: 89 | all_ranks = [] 90 | 91 | with torch.no_grad(): 92 | for i, batch in enumerate(dataloader): 93 | images, texts = batch 94 | images = images.to(device) 95 | texts = texts.to(device) 96 | 97 | if one2many: 98 | image_emb = model.encode_image(images) 99 | image_emb /= image_emb.norm(dim = -1, keepdim = True) 100 | 101 | text_emb = model.encode_text(texts) 102 | text_emb /= text_emb.norm(dim = -1, keepdim = True) 103 | 104 | set_size = text_emb.shape[0] // image_emb.shape[0] 105 | for j in range(image_emb.shape[0]): 106 | curr_image_emb = image_emb[j:j+1, :] 107 | curr_text_emb = text_emb[j*set_size:(j+1)*set_size, :] 108 | rank = get_one2many_rank(curr_image_emb, curr_text_emb) 109 | all_ranks.append(rank) 110 | 111 | print(f'Processed example {i*16}') 112 | 113 | metrics = get_one2many_metrics(np.array(all_ranks)) 114 | 115 | # Alter output here 116 | logging.info( 117 | "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) 118 | ) 119 | 120 | return metrics 121 | 122 | def main(): 123 | args = setup_args() 124 | if args.output_dir: 125 | output_dir = os.path.join(args.output_dir, 'open_ai_clip') 126 | if not os.path.exists(output_dir): 127 | os.makedirs(output_dir) 128 | # Load the model 129 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 130 | model, preprocess = clip.load(name = args.model_name, device=device) 131 | model = model.to(device) 132 | model.eval() 133 | 134 | for hard_neg_type in args.hard_neg_types: 135 | all_metrics = {} 136 | # Iterate over each complexity 137 | for i in range(4, 13): 138 | print('\n' + '*' * 45 + f' Evaluating on complexity {i} ' + '*' * 45 + '\n') 139 | start_time = time() 140 | retrieval_data_path = os.path.join(args.input_dir, f'{hard_neg_type}/prod_vg_hard_negs_{hard_neg_type}_complexity_{i}.csv') 141 | 142 | if args.model_name == "RN50" or args.model_name == "RN101": 143 | model_save_name = args.model_name 144 | elif args.model_name == "ViT-B/32": 145 | model_save_name = 'vit_b32' 146 | elif args.model_name == "ViT-B/16": 147 | model_save_name = 'vit_b16' 148 | elif args.model_name == "ViT-L/14": 149 | model_save_name = 'vit_l14' 150 | 151 | data = get_data(args, retrieval_data_path, preprocess, device) 152 | metrics = evaluate(model, data, i, hard_neg_type, device) 153 | 154 | print(f'Complexity {i} took {time() - start_time} seconds') 155 | 
all_metrics[i] = metrics 156 | 157 | if args.output_dir: 158 | output = os.path.join(output_dir, f'productivity_clip_{model_save_name}_{hard_neg_type}_metrics.json') 159 | print("saving results to:", output) 160 | with open(output, 'w') as f: 161 | json.dump(all_metrics, f) 162 | 163 | if __name__ == '__main__': 164 | main() 165 | -------------------------------------------------------------------------------- /crepe_prod_eval_cyclip.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import ast 8 | import argparse 9 | import logging 10 | import os 11 | from PIL import Image, ImageFile 12 | from dataclasses import dataclass 13 | from time import time 14 | import json 15 | 16 | import torch 17 | import torchvision.transforms.functional as TF 18 | from pkgs.openai.clip import load 19 | from torch import nn 20 | from torch.utils.data import DataLoader, Dataset 21 | import numpy as np 22 | 23 | import pandas as pd 24 | 25 | from crepe_eval_utils import BaseCsvDataset, get_one2many_rank, get_one2many_metrics, DataInfo 26 | from crepe_params import setup_args 27 | 28 | logging.basicConfig(level=logging.INFO) 29 | logger = logging.getLogger() 30 | 31 | def collator(batch): 32 | texts = [] 33 | 34 | images = torch.stack([x[0] for x in batch], dim=0) 35 | texts = torch.cat([x[1] for x in batch], dim=0) 36 | attention_masks = torch.cat([x[2] for x in batch], dim=0) 37 | 38 | return images, texts, attention_masks 39 | 40 | ### DATASET CONSTRUCTION 41 | 42 | class CsvDataset(BaseCsvDataset): 43 | def __init__(self, input_filename, args, processor): 44 | super().__init__(input_filename, args) 45 | 46 | self.processor = processor 47 | 48 | def __getitem__(self, idx): 49 | raw_image = self.get_image_by_id(self.images[idx]) 50 | if self.crop: 51 | raw_image = TF.crop(raw_image, self.ys[idx], self.xs[idx], self.heights[idx], self.widths[idx]) 52 | 53 | 54 | image = torch.tensor(self.processor.process_image(raw_image)) 55 | return_dict = self.processor.process_text([str(self.captions[idx])] + list(self.hard_negs[idx])) 56 | input_ids = return_dict['input_ids'] 57 | attention_mask = return_dict['attention_mask'] 58 | 59 | return image, input_ids, attention_mask 60 | 61 | def get_data(args, retrieval_data_path, processor): 62 | # Get CSVDataset 63 | input_filename = retrieval_data_path 64 | dataset = CsvDataset( 65 | input_filename, 66 | args, 67 | processor) 68 | num_samples = len(dataset) 69 | sampler = None 70 | shuffle=False 71 | 72 | dataloader = DataLoader( 73 | dataset, 74 | batch_size=16, 75 | shuffle=shuffle, 76 | num_workers=1, 77 | pin_memory=True, 78 | sampler=sampler, 79 | drop_last=False, 80 | collate_fn=collator 81 | ) 82 | dataloader.num_samples = num_samples 83 | dataloader.num_batches = len(dataloader) 84 | 85 | return DataInfo(dataloader) 86 | 87 | ### EVALUATION 88 | 89 | def evaluate(model, data, complexity, negative_type): 90 | metrics = {} 91 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 92 | 93 | dataloader = data.dataloader 94 | # num_samples = 0 95 | # samples_per_val = dataloader.num_samples 96 | 97 | # cumulative_loss = 0.0 98 | # all_image_features, all_text_features = [], [] 99 | one2many = dataloader.dataset.one2many 100 | 101 | if one2many: 102 | all_ranks = [] 103 | 104 | with 
torch.no_grad(): 105 | for i, batch in enumerate(dataloader): 106 | images, texts, attention_mask = batch 107 | images = images.to(device=device, non_blocking=True) 108 | texts = texts.to(device=device, non_blocking=True) 109 | attention_mask = attention_mask.to(device=device, non_blocking=True) 110 | 111 | if one2many: 112 | image_emb = model.get_image_features(images) 113 | image_emb /= image_emb.norm(dim = -1, keepdim = True) 114 | 115 | text_emb = model.get_text_features(input_ids = texts, attention_mask = attention_mask) 116 | text_emb /= text_emb.norm(dim = -1, keepdim = True) 117 | 118 | set_size = text_emb.shape[0] // image_emb.shape[0] 119 | for j in range(image_emb.shape[0]): 120 | curr_image_emb = image_emb[j:j+1, :] 121 | curr_text_emb = text_emb[j*set_size:(j+1)*set_size, :] 122 | rank = get_one2many_rank(curr_image_emb, curr_text_emb) 123 | all_ranks.append(rank) 124 | 125 | print(f'Processed example {i*16}') 126 | 127 | metrics = get_one2many_metrics(np.array(all_ranks)) 128 | 129 | # Alter output here 130 | logging.info( 131 | "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) 132 | ) 133 | 134 | return metrics 135 | 136 | def main(): 137 | args = setup_args() 138 | if args.output_dir: 139 | output_dir = os.path.join(args.output_dir, 'cyclip') 140 | if not os.path.exists(output_dir): 141 | os.makedirs(output_dir) 142 | # Load the model 143 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 144 | model, processor = load(name = args.model_name, pretrained = args.pretrained) 145 | checkpoint = torch.load('best.pt', map_location=device) 146 | state_dict = checkpoint['state_dict'] 147 | if(next(iter(state_dict.items()))[0].startswith("module")): 148 | state_dict = {key[len("module."):]: value for key, value in state_dict.items()} 149 | model.load_state_dict(state_dict) 150 | model = model.to(device) 151 | model.eval() 152 | 153 | for hard_neg_type in args.hard_neg_types: 154 | all_metrics = {} 155 | # Iterate over each complexity 156 | for i in range(4, 13): 157 | print('\n' + '*' * 45 + f' Evaluating on complexity {i} ' + '*' * 45 + '\n') 158 | start_time = time() 159 | retrieval_data_path = os.path.join(args.input_dir, f'{hard_neg_type}/prod_vg_hard_negs_{hard_neg_type}_complexity_{i}.csv') 160 | 161 | data = get_data(args, retrieval_data_path, processor) 162 | metrics = evaluate(model, data, i, hard_neg_type) 163 | 164 | print(f'Complexity {i} took {time() - start_time} seconds') 165 | 166 | all_metrics[i] = metrics 167 | 168 | if args.output_dir: 169 | output = os.path.join(output_dir, f'productivity_cyclip_{args.model_name}_{hard_neg_type}_metrics.json') 170 | print("saving results to:", output) 171 | with open(output, 'w') as f: 172 | json.dump(all_metrics, f) 173 | 174 | if __name__ == "__main__": 175 | main() 176 | -------------------------------------------------------------------------------- /crepe_prod_eval_flava.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
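# Productivity evaluation of FLAVA: each Visual Genome image crop is ranked against its ground-truth caption plus CREPE hard negatives using FLAVA's projected, normalized image and text embeddings.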
6 | 7 | import ast 8 | 9 | import logging 10 | import os 11 | from PIL import Image 12 | from dataclasses import dataclass 13 | from time import time 14 | import json 15 | 16 | import torch 17 | from torchmultimodal.transforms.flava_transform import FLAVAImageTransform 18 | from torch import nn 19 | from torch.utils.data import DataLoader, Dataset 20 | from torchmultimodal.models.flava.model import flava_model 21 | from transformers import BertTokenizer 22 | import torchvision.transforms.functional as TF 23 | import numpy as np 24 | 25 | import pandas as pd 26 | 27 | from crepe_eval_utils import BaseCsvDataset, get_one2many_rank, get_one2many_metrics, DataInfo 28 | from crepe_params import setup_args 29 | 30 | logging.basicConfig(level=logging.INFO) 31 | logger = logging.getLogger() 32 | 33 | max_text_length = 512 34 | TEXT_DEFAULT_TOKENIZER = "bert-base-uncased" 35 | text_tokenizer = BertTokenizer.from_pretrained(TEXT_DEFAULT_TOKENIZER) 36 | 37 | def collator(batch): 38 | texts = [] 39 | images = torch.stack([x[0]["image"] for x in batch], dim=0) 40 | texts = torch.cat([x[1] for x in batch], dim=0) 41 | 42 | return images, texts 43 | 44 | ### DATASET CONSTRUCTION 45 | 46 | def default_text_transform(texts): 47 | # Expect a list of texts 48 | tokenized_texts = [] 49 | start_time = time() 50 | for text in texts: 51 | tokenized = text_tokenizer(text, padding="max_length", 52 | max_length=max_text_length, truncation=True, return_tensors='pt') 53 | tokenized_texts.append(torch.LongTensor(tokenized['input_ids'])) 54 | tokenized_texts = torch.cat(tokenized_texts, dim=0) 55 | 56 | return tokenized_texts 57 | 58 | class CsvDataset(BaseCsvDataset): 59 | def __init__(self, input_filename, args): 60 | super().__init__(input_filename, args) 61 | 62 | self.image_transform = FLAVAImageTransform(is_train=False) 63 | self.text_transform = default_text_transform 64 | 65 | def __getitem__(self, idx): 66 | raw_image = self.get_image_by_id(self.images[idx]) 67 | if self.crop: 68 | raw_image = TF.crop(raw_image, self.ys[idx], self.xs[idx], self.heights[idx], self.widths[idx]) 69 | image = self.image_transform(raw_image) 70 | if self.one2many: 71 | texts = self.text_transform([str(self.captions[idx])] + list(self.hard_negs[idx])) 72 | else: 73 | texts = self.text_transform([str(self.captions[idx])])[0] 74 | return image, texts 75 | 76 | def get_data(args, retrieval_data_path): 77 | # Get CSVDataset 78 | input_filename = retrieval_data_path 79 | dataset = CsvDataset( 80 | input_filename, 81 | args) 82 | num_samples = len(dataset) 83 | sampler = None 84 | shuffle=False 85 | 86 | dataloader = DataLoader( 87 | dataset, 88 | batch_size=8, 89 | shuffle=shuffle, 90 | num_workers=1, 91 | pin_memory=True, 92 | sampler=sampler, 93 | drop_last=False, 94 | collate_fn=collator 95 | ) 96 | dataloader.num_samples = num_samples 97 | dataloader.num_batches = len(dataloader) 98 | 99 | return DataInfo(dataloader) 100 | 101 | ### EVALUATION 102 | 103 | def evaluate(model, data, complexity, negative_type, output_path): 104 | metrics = {} 105 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 106 | 107 | dataloader = data.dataloader 108 | # num_samples = 0 109 | # samples_per_val = dataloader.num_samples 110 | 111 | # cumulative_loss = 0.0 112 | # all_image_features, all_text_features = [], [] 113 | one2many = dataloader.dataset.one2many 114 | assert(one2many, "Not one2many?") 115 | 116 | if one2many: 117 | all_ranks = [] 118 | 119 | with torch.no_grad(): 120 | for i, batch in enumerate(dataloader): 121 | 
images, texts = batch 122 | images = images.to(device=device, non_blocking=True) 123 | texts = texts.to(device=device, non_blocking=True) 124 | 125 | if one2many: 126 | _, image_emb = model.encode_image(images, projection=True) 127 | image_emb = nn.functional.normalize(image_emb, dim=-1) 128 | _, text_emb = model.encode_text(texts, projection=True) 129 | text_emb = nn.functional.normalize(text_emb) 130 | 131 | set_size = text_emb.shape[0] // image_emb.shape[0] 132 | for j in range(image_emb.shape[0]): 133 | curr_image_emb = image_emb[j:j+1, :] 134 | curr_text_emb = text_emb[j*set_size:(j+1)*set_size, :] 135 | rank = get_one2many_rank(curr_image_emb, curr_text_emb) 136 | all_ranks.append(rank) 137 | 138 | # print(f'Processed example {i*8}') 139 | 140 | metrics = get_one2many_metrics(np.array(all_ranks)) 141 | 142 | # Alter output here 143 | logging.info( 144 | "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]) 145 | ) 146 | 147 | return metrics 148 | 149 | def main(): 150 | args = setup_args() 151 | if args.output_dir: 152 | output_dir = os.path.join(args.output_dir, 'flava') 153 | if not os.path.exists(output_dir): 154 | os.makedirs(output_dir) 155 | # Load the model 156 | flava = flava_model(pretrained=True) 157 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 158 | logger.info(f"Using device: {device}") 159 | flava = flava.to(device) 160 | flava.eval() 161 | 162 | for hard_neg_type in args.hard_neg_types: 163 | all_metrics = {} 164 | # Iterate over each complexity 165 | for i in range(4, 13): 166 | print('\n' + '*' * 45 + f' Evaluating on complexity {i} ' + '*' * 45 + '\n') 167 | start_time = time() 168 | retrieval_data_path = os.path.join(args.input_dir, f'{hard_neg_type}/prod_vg_hard_negs_{hard_neg_type}_complexity_{i}.csv') 169 | 170 | data = get_data(args, retrieval_data_path) 171 | metrics = evaluate(flava, data, i, hard_neg_type) 172 | 173 | print(f'Complexity {i} took {time() - start_time} seconds') 174 | all_metrics[i] = metrics 175 | 176 | if args.output_dir: 177 | output = os.path.join(output_dir, f'productivity_flava_{hard_neg_type}_metrics.json') 178 | print("saving results to:", output) 179 | with open(output, 'w') as f: 180 | json.dump(all_metrics, f) 181 | 182 | if __name__ == "__main__": 183 | main() 184 | -------------------------------------------------------------------------------- /data/prod_hard_negatives.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RAIVNLab/CREPE/1fa81c425f442396fe304f170c8eb6dc0747c814/data/prod_hard_negatives.zip -------------------------------------------------------------------------------- /data/syst_hard_negatives.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RAIVNLab/CREPE/1fa81c425f442396fe304f170c8eb6dc0747c814/data/syst_hard_negatives.zip -------------------------------------------------------------------------------- /open_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import list_models, create_model, create_model_and_transforms, add_model_config 2 | from .loss import ClipLoss 3 | from .model import CLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_fp16, trace_model 4 | from .openai import load_openai_model, list_openai_models 5 | from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\ 6 | get_pretrained_url, download_pretrained 7 | from 
.tokenizer import SimpleTokenizer, tokenize 8 | from .transform import image_transform 9 | -------------------------------------------------------------------------------- /open_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RAIVNLab/CREPE/1fa81c425f442396fe304f170c8eb6dc0747c814/open_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /open_clip/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pathlib 5 | import re 6 | from copy import deepcopy 7 | from pathlib import Path 8 | 9 | import torch 10 | 11 | from .model import CLIP, convert_weights_to_fp16 12 | from .openai import load_openai_model 13 | from .pretrained import get_pretrained_url, download_pretrained 14 | from .transform import image_transform 15 | 16 | 17 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] 18 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 19 | 20 | 21 | def _natural_key(string_): 22 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 23 | 24 | 25 | def _rescan_model_configs(): 26 | global _MODEL_CONFIGS 27 | 28 | config_ext = ('.json',) 29 | config_files = [] 30 | for config_path in _MODEL_CONFIG_PATHS: 31 | if config_path.is_file() and config_path.suffix in config_ext: 32 | config_files.append(config_path) 33 | elif config_path.is_dir(): 34 | for ext in config_ext: 35 | config_files.extend(config_path.glob(f'*{ext}')) 36 | 37 | for cf in config_files: 38 | with open(cf, 'r') as f: 39 | model_cfg = json.load(f) 40 | if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): 41 | _MODEL_CONFIGS[cf.stem] = model_cfg 42 | 43 | _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} 44 | 45 | 46 | _rescan_model_configs() # initial populate of model config registry 47 | 48 | 49 | def load_state_dict(checkpoint_path: str, map_location='cpu'): 50 | checkpoint = torch.load(checkpoint_path, map_location=map_location) 51 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 52 | state_dict = checkpoint['state_dict'] 53 | else: 54 | state_dict = checkpoint 55 | if next(iter(state_dict.items()))[0].startswith('module'): 56 | state_dict = {k[7:]: v for k, v in state_dict.items()} 57 | return state_dict 58 | 59 | 60 | def create_model( 61 | model_name: str, 62 | pretrained: str = '', 63 | precision: str = 'fp32', 64 | device: torch.device = torch.device('cpu'), 65 | jit: bool = False, 66 | force_quick_gelu: bool = False, 67 | pretrained_image: bool = False, 68 | ): 69 | model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names 70 | 71 | if pretrained.lower() == 'openai': 72 | logging.info(f'Loading pretrained {model_name} from OpenAI.') 73 | model = load_openai_model(model_name, device=device, jit=jit) 74 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 75 | if precision == "amp" or precision == "fp32": 76 | model = model.float() 77 | else: 78 | if model_name in _MODEL_CONFIGS: 79 | logging.info(f'Loading {model_name} model config.') 80 | model_cfg = deepcopy(_MODEL_CONFIGS[model_name]) 81 | else: 82 | logging.error(f'Model config for {model_name} not found; available models {list_models()}.') 83 | raise RuntimeError(f'Model 
config for {model_name} not found.') 84 | 85 | if force_quick_gelu: 86 | # override for use of QuickGELU on non-OpenAI transformer models 87 | model_cfg["quick_gelu"] = True 88 | 89 | if pretrained_image: 90 | if 'timm_model_name' in model_cfg.get('vision_cfg', {}): 91 | # pretrained weight loading for timm models set via vision_cfg 92 | model_cfg['vision_cfg']['timm_model_pretrained'] = True 93 | else: 94 | assert False, 'pretrained image towers currently only supported for timm models' 95 | 96 | model = CLIP(**model_cfg) 97 | 98 | if pretrained: 99 | checkpoint_path = '' 100 | url = get_pretrained_url(model_name, pretrained) 101 | if url: 102 | checkpoint_path = download_pretrained(url) 103 | elif os.path.exists(pretrained): 104 | checkpoint_path = pretrained 105 | 106 | if checkpoint_path: 107 | logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') 108 | model.load_state_dict(load_state_dict(checkpoint_path)) 109 | else: 110 | logging.warning(f'Pretrained weights ({pretrained}) not found for model {model_name}.') 111 | raise RuntimeError(f'Pretrained weights ({pretrained}) not found for model {model_name}.') 112 | 113 | model.to(device=device) 114 | if precision == "fp16": 115 | assert device.type != 'cpu' 116 | convert_weights_to_fp16(model) 117 | 118 | if jit: 119 | model = torch.jit.script(model) 120 | 121 | return model 122 | 123 | 124 | def create_model_and_transforms( 125 | model_name: str, 126 | pretrained: str = '', 127 | precision: str = 'fp32', 128 | device: torch.device = torch.device('cpu'), 129 | jit: bool = False, 130 | force_quick_gelu: bool = False, 131 | pretrained_image: bool = False, 132 | ): 133 | model = create_model( 134 | model_name, pretrained, precision, device, jit, 135 | force_quick_gelu=force_quick_gelu, 136 | pretrained_image=pretrained_image) 137 | preprocess_train = image_transform(model.visual.image_size, is_train=True) 138 | preprocess_val = image_transform(model.visual.image_size, is_train=False) 139 | return model, preprocess_train, preprocess_val 140 | 141 | 142 | def list_models(): 143 | """ enumerate available model architectures based on config files """ 144 | return list(_MODEL_CONFIGS.keys()) 145 | 146 | 147 | def add_model_config(path): 148 | """ add model config path or file and update registry """ 149 | if not isinstance(path, Path): 150 | path = Path(path) 151 | _MODEL_CONFIG_PATHS.append(path) 152 | _rescan_model_configs() 153 | -------------------------------------------------------------------------------- /open_clip/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed.nn 3 | from torch import distributed as dist, nn as nn 4 | from torch.nn import functional as F 5 | 6 | try: 7 | import horovod.torch as hvd 8 | except ImportError: 9 | hvd = None 10 | 11 | 12 | def gather_features( 13 | image_features, 14 | text_features, 15 | local_loss=False, 16 | gather_with_grad=False, 17 | rank=0, 18 | world_size=1, 19 | use_horovod=False 20 | ): 21 | if use_horovod: 22 | assert hvd is not None, 'Please install horovod' 23 | if gather_with_grad: 24 | all_image_features = hvd.allgather(image_features) 25 | all_text_features = hvd.allgather(text_features) 26 | else: 27 | with torch.no_grad(): 28 | all_image_features = hvd.allgather(image_features) 29 | all_text_features = hvd.allgather(text_features) 30 | if not local_loss: 31 | # ensure grads for local rank when all_* features don't have a gradient 32 | gathered_image_features = 
list(all_image_features.chunk(world_size, dim=0)) 33 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) 34 | gathered_image_features[rank] = image_features 35 | gathered_text_features[rank] = text_features 36 | all_image_features = torch.cat(gathered_image_features, dim=0) 37 | all_text_features = torch.cat(gathered_text_features, dim=0) 38 | else: 39 | # We gather tensors from all gpus 40 | if gather_with_grad: 41 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) 42 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) 43 | else: 44 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 45 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 46 | dist.all_gather(gathered_image_features, image_features) 47 | dist.all_gather(gathered_text_features, text_features) 48 | if not local_loss: 49 | # ensure grads for local rank when all_* features don't have a gradient 50 | gathered_image_features[rank] = image_features 51 | gathered_text_features[rank] = text_features 52 | all_image_features = torch.cat(gathered_image_features, dim=0) 53 | all_text_features = torch.cat(gathered_text_features, dim=0) 54 | 55 | return all_image_features, all_text_features 56 | 57 | 58 | class ClipLoss(nn.Module): 59 | 60 | def __init__( 61 | self, 62 | local_loss=False, 63 | gather_with_grad=False, 64 | cache_labels=False, 65 | rank=0, 66 | world_size=1, 67 | use_horovod=False, 68 | ): 69 | super().__init__() 70 | self.local_loss = local_loss 71 | self.gather_with_grad = gather_with_grad 72 | self.cache_labels = cache_labels 73 | self.rank = rank 74 | self.world_size = world_size 75 | self.use_horovod = use_horovod 76 | 77 | # cache state 78 | self.prev_num_logits = 0 79 | self.labels = {} 80 | 81 | def forward(self, image_features, text_features, logit_scale): 82 | device = image_features.device 83 | if self.world_size > 1: 84 | all_image_features, all_text_features = gather_features( 85 | image_features, text_features, 86 | self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) 87 | 88 | if self.local_loss: 89 | logits_per_image = logit_scale * image_features @ all_text_features.T 90 | logits_per_text = logit_scale * text_features @ all_image_features.T 91 | else: 92 | logits_per_image = logit_scale * all_image_features @ all_text_features.T 93 | logits_per_text = logits_per_image.T 94 | else: 95 | logits_per_image = logit_scale * image_features @ text_features.T 96 | logits_per_text = logit_scale * text_features @ image_features.T 97 | 98 | # calculated ground-truth and cache if enabled 99 | num_logits = logits_per_image.shape[0] 100 | if self.prev_num_logits != num_logits or device not in self.labels: 101 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 102 | if self.world_size > 1 and self.local_loss: 103 | labels = labels + num_logits * self.rank 104 | if self.cache_labels: 105 | self.labels[device] = labels 106 | self.prev_num_logits = num_logits 107 | else: 108 | labels = self.labels[device] 109 | 110 | total_loss = ( 111 | F.cross_entropy(logits_per_image, labels) + 112 | F.cross_entropy(logits_per_text, labels) 113 | ) / 2 114 | return total_loss 115 | -------------------------------------------------------------------------------- /open_clip/model.py: -------------------------------------------------------------------------------- 1 | """ CLIP Model 2 | 3 | Adapted from 
https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | from collections import OrderedDict 7 | from dataclasses import dataclass 8 | from typing import Tuple, Union, Callable, Optional 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import nn 14 | from torch.utils.checkpoint import checkpoint 15 | 16 | from .timm_model import TimmModel 17 | from .utils import freeze_batch_norm_2d 18 | 19 | 20 | class Bottleneck(nn.Module): 21 | expansion = 4 22 | 23 | def __init__(self, inplanes, planes, stride=1): 24 | super().__init__() 25 | 26 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 27 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.relu1 = nn.ReLU(inplace=True) 30 | 31 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 32 | self.bn2 = nn.BatchNorm2d(planes) 33 | self.relu2 = nn.ReLU(inplace=True) 34 | 35 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 36 | 37 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 38 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 39 | self.relu3 = nn.ReLU(inplace=True) 40 | 41 | self.downsample = None 42 | self.stride = stride 43 | 44 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 45 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 46 | self.downsample = nn.Sequential(OrderedDict([ 47 | ("-1", nn.AvgPool2d(stride)), 48 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 49 | ("1", nn.BatchNorm2d(planes * self.expansion)) 50 | ])) 51 | 52 | def forward(self, x: torch.Tensor): 53 | identity = x 54 | 55 | out = self.relu1(self.bn1(self.conv1(x))) 56 | out = self.relu2(self.bn2(self.conv2(out))) 57 | out = self.avgpool(out) 58 | out = self.bn3(self.conv3(out)) 59 | 60 | if self.downsample is not None: 61 | identity = self.downsample(x) 62 | 63 | out += identity 64 | out = self.relu3(out) 65 | return out 66 | 67 | 68 | class AttentionPool2d(nn.Module): 69 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 70 | super().__init__() 71 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 72 | self.k_proj = nn.Linear(embed_dim, embed_dim) 73 | self.q_proj = nn.Linear(embed_dim, embed_dim) 74 | self.v_proj = nn.Linear(embed_dim, embed_dim) 75 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 76 | self.num_heads = num_heads 77 | 78 | def forward(self, x): 79 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 80 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 81 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 82 | x, _ = F.multi_head_attention_forward( 83 | query=x, key=x, value=x, 84 | embed_dim_to_check=x.shape[-1], 85 | num_heads=self.num_heads, 86 | q_proj_weight=self.q_proj.weight, 87 | k_proj_weight=self.k_proj.weight, 88 | v_proj_weight=self.v_proj.weight, 89 | in_proj_weight=None, 90 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 91 | bias_k=None, 92 | bias_v=None, 93 | add_zero_attn=False, 94 | dropout_p=0, 95 | out_proj_weight=self.c_proj.weight, 96 | out_proj_bias=self.c_proj.bias, 97 | use_separate_proj_weight=True, 98 | training=self.training, 99 | need_weights=False 100 
| ) 101 | 102 | return x[0] 103 | 104 | 105 | class ModifiedResNet(nn.Module): 106 | """ 107 | A ResNet class that is similar to torchvision's but contains the following changes: 108 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 109 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 110 | - The final pooling layer is a QKV attention instead of an average pool 111 | """ 112 | 113 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 114 | super().__init__() 115 | self.output_dim = output_dim 116 | self.image_size = image_size 117 | 118 | # the 3-layer stem 119 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 120 | self.bn1 = nn.BatchNorm2d(width // 2) 121 | self.relu1 = nn.ReLU(inplace=True) 122 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 123 | self.bn2 = nn.BatchNorm2d(width // 2) 124 | self.relu2 = nn.ReLU(inplace=True) 125 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 126 | self.bn3 = nn.BatchNorm2d(width) 127 | self.relu3 = nn.ReLU(inplace=True) 128 | self.avgpool = nn.AvgPool2d(2) 129 | 130 | # residual layers 131 | self._inplanes = width # this is a *mutable* variable used during construction 132 | self.layer1 = self._make_layer(width, layers[0]) 133 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 134 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 135 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 136 | 137 | embed_dim = width * 32 # the ResNet feature dimension 138 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 139 | 140 | self.init_parameters() 141 | 142 | def _make_layer(self, planes, blocks, stride=1): 143 | layers = [Bottleneck(self._inplanes, planes, stride)] 144 | 145 | self._inplanes = planes * Bottleneck.expansion 146 | for _ in range(1, blocks): 147 | layers.append(Bottleneck(self._inplanes, planes)) 148 | 149 | return nn.Sequential(*layers) 150 | 151 | def init_parameters(self): 152 | if self.attnpool is not None: 153 | std = self.attnpool.c_proj.in_features ** -0.5 154 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 155 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 156 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 157 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 158 | 159 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 160 | for name, param in resnet_block.named_parameters(): 161 | if name.endswith("bn3.weight"): 162 | nn.init.zeros_(param) 163 | 164 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 165 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 166 | for param in self.parameters(): 167 | param.requires_grad = False 168 | if freeze_bn_stats: 169 | freeze_batch_norm_2d(self) 170 | 171 | @torch.jit.ignore 172 | def set_grad_checkpointing(self, enable=True): 173 | # FIXME support for non-transformer 174 | pass 175 | 176 | def stem(self, x): 177 | x = self.relu1(self.bn1(self.conv1(x))) 178 | x = self.relu2(self.bn2(self.conv2(x))) 179 | x = self.relu3(self.bn3(self.conv3(x))) 180 | x = self.avgpool(x) 181 | return x 182 | 183 | def forward(self, x): 184 | x = self.stem(x) 185 | x = self.layer1(x) 186 | x = self.layer2(x) 187 | x = self.layer3(x) 188 | x = self.layer4(x) 189 | x = self.attnpool(x) 190 | 191 | return x 192 
| 193 | 194 | class LayerNorm(nn.LayerNorm): 195 | """Subclass torch's LayerNorm to handle fp16.""" 196 | 197 | def forward(self, x: torch.Tensor): 198 | orig_type = x.dtype 199 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 200 | return x.to(orig_type) 201 | 202 | 203 | class QuickGELU(nn.Module): 204 | # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory 205 | def forward(self, x: torch.Tensor): 206 | return x * torch.sigmoid(1.702 * x) 207 | 208 | 209 | class ResidualAttentionBlock(nn.Module): 210 | def __init__(self, d_model: int, n_head: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU): 211 | super().__init__() 212 | 213 | self.attn = nn.MultiheadAttention(d_model, n_head) 214 | self.ln_1 = LayerNorm(d_model) 215 | mlp_width = int(d_model * mlp_ratio) 216 | self.mlp = nn.Sequential(OrderedDict([ 217 | ("c_fc", nn.Linear(d_model, mlp_width)), 218 | ("gelu", act_layer()), 219 | ("c_proj", nn.Linear(mlp_width, d_model)) 220 | ])) 221 | self.ln_2 = LayerNorm(d_model) 222 | 223 | def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 224 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] 225 | 226 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 227 | x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) 228 | x = x + self.mlp(self.ln_2(x)) 229 | return x 230 | 231 | 232 | class Transformer(nn.Module): 233 | def __init__(self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU): 234 | super().__init__() 235 | self.width = width 236 | self.layers = layers 237 | self.grad_checkpointing = False 238 | 239 | self.resblocks = nn.ModuleList([ 240 | ResidualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer) 241 | for _ in range(layers) 242 | ]) 243 | 244 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 245 | for r in self.resblocks: 246 | if self.grad_checkpointing and not torch.jit.is_scripting(): 247 | x = checkpoint(r, x, attn_mask) 248 | else: 249 | x = r(x, attn_mask=attn_mask) 250 | return x 251 | 252 | 253 | class VisualTransformer(nn.Module): 254 | def __init__( 255 | self, image_size: int, patch_size: int, width: int, layers: int, heads: int, mlp_ratio: float, 256 | output_dim: int, act_layer: Callable = nn.GELU): 257 | super().__init__() 258 | self.image_size = image_size 259 | self.output_dim = output_dim 260 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 261 | 262 | scale = width ** -0.5 263 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 264 | self.positional_embedding = nn.Parameter(scale * torch.randn((image_size // patch_size) ** 2 + 1, width)) 265 | self.ln_pre = LayerNorm(width) 266 | 267 | self.transformer = Transformer(width, layers, heads, mlp_ratio, act_layer=act_layer) 268 | 269 | self.ln_post = LayerNorm(width) 270 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 271 | 272 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 273 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 274 | for param in self.parameters(): 275 | param.requires_grad = False 276 | 277 | @torch.jit.ignore 278 | def set_grad_checkpointing(self, enable=True): 279 | self.transformer.grad_checkpointing = enable 280 | 281 | def forward(self, x: torch.Tensor): 282 | x = self.conv1(x) # shape = [*, width, grid, grid] 283 | x = 
x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 284 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 285 | x = torch.cat( 286 | [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 287 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 288 | x = x + self.positional_embedding.to(x.dtype) 289 | x = self.ln_pre(x) 290 | 291 | x = x.permute(1, 0, 2) # NLD -> LND 292 | x = self.transformer(x) 293 | x = x.permute(1, 0, 2) # LND -> NLD 294 | 295 | x = self.ln_post(x[:, 0, :]) 296 | 297 | if self.proj is not None: 298 | x = x @ self.proj 299 | 300 | return x 301 | 302 | 303 | @dataclass 304 | class CLIPVisionCfg: 305 | layers: Union[Tuple[int, int, int, int], int] = 12 306 | width: int = 768 307 | head_width: int = 64 308 | mlp_ratio: float = 4.0 309 | patch_size: int = 16 310 | image_size: Union[Tuple[int, int], int] = 224 311 | timm_model_name: str = None # a valid model name overrides layers, width, patch_size 312 | timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model 313 | timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') 314 | timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') 315 | 316 | 317 | @dataclass 318 | class CLIPTextCfg: 319 | context_length: int = 77 320 | vocab_size: int = 49408 321 | width: int = 512 322 | heads: int = 8 323 | layers: int = 12 324 | 325 | 326 | class CLIP(nn.Module): 327 | def __init__( 328 | self, 329 | embed_dim: int, 330 | vision_cfg: CLIPVisionCfg, 331 | text_cfg: CLIPTextCfg, 332 | quick_gelu: bool = False, 333 | ): 334 | super().__init__() 335 | if isinstance(vision_cfg, dict): 336 | vision_cfg = CLIPVisionCfg(**vision_cfg) 337 | if isinstance(text_cfg, dict): 338 | text_cfg = CLIPTextCfg(**text_cfg) 339 | 340 | self.context_length = text_cfg.context_length 341 | 342 | # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more 343 | # memory efficient in recent PyTorch releases (>= 1.10). 344 | # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
345 | act_layer = QuickGELU if quick_gelu else nn.GELU 346 | 347 | if vision_cfg.timm_model_name: 348 | self.visual = TimmModel( 349 | vision_cfg.timm_model_name, 350 | pretrained=vision_cfg.timm_model_pretrained, 351 | pool=vision_cfg.timm_pool, 352 | proj=vision_cfg.timm_proj, 353 | embed_dim=embed_dim, 354 | image_size=vision_cfg.image_size 355 | ) 356 | act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models 357 | elif isinstance(vision_cfg.layers, (tuple, list)): 358 | vision_heads = vision_cfg.width * 32 // vision_cfg.head_width 359 | self.visual = ModifiedResNet( 360 | layers=vision_cfg.layers, 361 | output_dim=embed_dim, 362 | heads=vision_heads, 363 | image_size=vision_cfg.image_size, 364 | width=vision_cfg.width 365 | ) 366 | else: 367 | vision_heads = vision_cfg.width // vision_cfg.head_width 368 | self.visual = VisualTransformer( 369 | image_size=vision_cfg.image_size, 370 | patch_size=vision_cfg.patch_size, 371 | width=vision_cfg.width, 372 | layers=vision_cfg.layers, 373 | heads=vision_heads, 374 | mlp_ratio=vision_cfg.mlp_ratio, 375 | output_dim=embed_dim, 376 | act_layer=act_layer, 377 | ) 378 | 379 | self.transformer = Transformer( 380 | width=text_cfg.width, 381 | layers=text_cfg.layers, 382 | heads=text_cfg.heads, 383 | act_layer=act_layer, 384 | ) 385 | 386 | self.vocab_size = text_cfg.vocab_size 387 | self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width) 388 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, text_cfg.width)) 389 | self.ln_final = LayerNorm(text_cfg.width) 390 | 391 | self.text_projection = nn.Parameter(torch.empty(text_cfg.width, embed_dim)) 392 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 393 | self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) 394 | 395 | self.init_parameters() 396 | 397 | def init_parameters(self): 398 | nn.init.normal_(self.token_embedding.weight, std=0.02) 399 | nn.init.normal_(self.positional_embedding, std=0.01) 400 | nn.init.constant_(self.logit_scale, np.log(1 / 0.07)) 401 | 402 | if hasattr(self.visual, 'init_parameters'): 403 | self.visual.init_parameters() 404 | 405 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 406 | attn_std = self.transformer.width ** -0.5 407 | fc_std = (2 * self.transformer.width) ** -0.5 408 | for block in self.transformer.resblocks: 409 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 410 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 411 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 412 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 413 | 414 | if self.text_projection is not None: 415 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 416 | 417 | def build_attention_mask(self): 418 | # lazily create causal attention mask, with full attention between the vision tokens 419 | # pytorch uses additive attention mask; fill with -inf 420 | mask = torch.empty(self.context_length, self.context_length) 421 | mask.fill_(float("-inf")) 422 | mask.triu_(1) # zero out the lower diagonal 423 | return mask 424 | 425 | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): 426 | # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 427 | self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) 428 | 429 | @torch.jit.ignore 430 | def set_grad_checkpointing(self, enable=True): 431 | self.visual.set_grad_checkpointing(enable) 432 | 
self.transformer.grad_checkpointing = enable 433 | 434 | def encode_image(self, image): 435 | return self.visual(image) 436 | 437 | def encode_text(self, text): 438 | # print('text before embedding:', text) 439 | x = self.token_embedding(text) # [batch_size, n_ctx, d_model] 440 | # print('text after embedding:', x) 441 | x = x + self.positional_embedding 442 | x = x.permute(1, 0, 2) # NLD -> LND 443 | x = self.transformer(x, attn_mask=self.attn_mask) 444 | x = x.permute(1, 0, 2) # LND -> NLD 445 | x = self.ln_final(x) 446 | 447 | # x.shape = [batch_size, n_ctx, transformer.width] 448 | # take features from the eot embedding (eot_token is the highest number in each sequence) 449 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 450 | 451 | return x 452 | 453 | def forward(self, image, text): 454 | if image is None: 455 | return self.encode_text(text) 456 | elif text is None: 457 | return self.encode_image(image) 458 | image_features = self.encode_image(image) 459 | image_features = F.normalize(image_features, dim=-1) 460 | 461 | text_features = self.encode_text(text) 462 | text_features = F.normalize(text_features, dim=-1) 463 | 464 | return image_features, text_features, self.logit_scale.exp() 465 | 466 | 467 | def convert_weights_to_fp16(model: nn.Module): 468 | """Convert applicable model parameters to fp16""" 469 | 470 | def _convert_weights_to_fp16(l): 471 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 472 | l.weight.data = l.weight.data.half() 473 | if l.bias is not None: 474 | l.bias.data = l.bias.data.half() 475 | 476 | if isinstance(l, nn.MultiheadAttention): 477 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 478 | tensor = getattr(l, attr) 479 | if tensor is not None: 480 | tensor.data = tensor.data.half() 481 | 482 | for name in ["text_projection", "proj"]: 483 | if hasattr(l, name): 484 | attr = getattr(l, name) 485 | if attr is not None: 486 | attr.data = attr.data.half() 487 | 488 | model.apply(_convert_weights_to_fp16) 489 | 490 | 491 | def build_model_from_openai_state_dict(state_dict: dict): 492 | vit = "visual.proj" in state_dict 493 | 494 | if vit: 495 | vision_width = state_dict["visual.conv1.weight"].shape[0] 496 | vision_layers = len( 497 | [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 498 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 499 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 500 | image_size = vision_patch_size * grid_size 501 | else: 502 | counts: list = [ 503 | len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 504 | vision_layers = tuple(counts) 505 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 506 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 507 | vision_patch_size = None 508 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 509 | image_size = output_width * 32 510 | 511 | embed_dim = state_dict["text_projection"].shape[1] 512 | context_length = state_dict["positional_embedding"].shape[0] 513 | vocab_size = state_dict["token_embedding.weight"].shape[0] 514 | transformer_width = state_dict["ln_final.weight"].shape[0] 515 | transformer_heads = transformer_width // 64 516 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 517 | 
518 | vision_cfg = CLIPVisionCfg( 519 | layers=vision_layers, 520 | width=vision_width, 521 | patch_size=vision_patch_size, 522 | image_size=image_size, 523 | ) 524 | text_cfg = CLIPTextCfg( 525 | context_length=context_length, 526 | vocab_size=vocab_size, 527 | width=transformer_width, 528 | heads=transformer_heads, 529 | layers=transformer_layers 530 | ) 531 | model = CLIP( 532 | embed_dim, 533 | vision_cfg=vision_cfg, 534 | text_cfg=text_cfg, 535 | quick_gelu=True, # OpenAI models were trained with QuickGELU 536 | ) 537 | 538 | for key in ["input_resolution", "context_length", "vocab_size"]: 539 | state_dict.pop(key, None) 540 | 541 | convert_weights_to_fp16(model) 542 | model.load_state_dict(state_dict) 543 | return model.eval() 544 | 545 | 546 | def trace_model(model, batch_size=256, device=torch.device('cpu')): 547 | model.eval() 548 | image_size = model.visual.image_size 549 | example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) 550 | example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) 551 | model = torch.jit.trace_module( 552 | model, 553 | inputs=dict( 554 | forward=(example_images, example_text), 555 | encode_text=(example_text,), 556 | encode_image=(example_images,) 557 | )) 558 | model.visual.image_size = image_size 559 | return model 560 | -------------------------------------------------------------------------------- /open_clip/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /open_clip/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /open_clip/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- 
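The RN50 and RN101 JSON files above, like the remaining configs that follow, simply populate the `CLIPVisionCfg` and `CLIPTextCfg` dataclasses from `model.py`; `factory.py` looks them up by file name. Below is a minimal sketch of how that registry might be used, assuming the vendored `open_clip` package is importable; the config directory path is a placeholder, and `cc12m` is just one example tag from `pretrained.py` (a local checkpoint path also works):

```
import torch
from open_clip.factory import add_model_config, list_models, create_model_and_transforms

# Register an extra directory of *.json model configs (placeholder path) and list what is available.
add_model_config("/path/to/extra_model_configs")
print(list_models())

# Build a model plus train/val preprocessing transforms for a registered architecture.
model, preprocess_train, preprocess_val = create_model_and_transforms(
    "RN50",
    pretrained="cc12m",
    device=torch.device("cpu"),
)
```

The `-quickgelu` variants exist because the OpenAI checkpoints were trained with QuickGELU activations, so those checkpoints should be paired with a config that sets `quick_gelu: true`.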
/open_clip/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-B-16-plus-240.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 240, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-B-16-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-B-32-plus-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 256, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-B-32.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-H-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 16 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-L-14-280.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 280, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-L-16-320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 320, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-L-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 
224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /open_clip/model_configs/ViT-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /open_clip/model_configs/timm-efficientnetv2_rw_s.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "efficientnetv2_rw_s", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 288 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/model_configs/timm-resnet50d.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "resnet50d", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /open_clip/model_configs/timm-resnetaa50d.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "resnetaa50d", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /open_clip/model_configs/timm-resnetblur50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "resnetblur50", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "abs_attn", 7 | "timm_proj": "", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /open_clip/model_configs/timm-swin_base_patch4_window7_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "swin_base_patch4_window7_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 
| } -------------------------------------------------------------------------------- /open_clip/model_configs/timm-vit_base_patch16_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_base_patch16_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/model_configs/timm-vit_base_patch32_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_base_patch32_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/model_configs/timm-vit_small_patch16_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_small_patch16_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /open_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import Union, List 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict 13 | from .pretrained import get_pretrained_url, list_pretrained_tag_models, download_pretrained 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_tag_models('openai') 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 26 | jit=True, 27 | ): 28 | """Load a CLIP model 29 | 30 | Parameters 31 | ---------- 32 | name : str 33 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 34 | device : Union[str, torch.device] 35 | The device to put the loaded model 36 | jit : bool 37 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 
38 | 39 | Returns 40 | ------- 41 | model : torch.nn.Module 42 | The CLIP model 43 | preprocess : Callable[[PIL.Image], torch.Tensor] 44 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 45 | """ 46 | if get_pretrained_url(name, 'openai'): 47 | model_path = download_pretrained(get_pretrained_url(name, 'openai')) 48 | elif os.path.isfile(name): 49 | model_path = name 50 | else: 51 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 52 | 53 | try: 54 | # loading JIT archive 55 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 56 | state_dict = None 57 | except RuntimeError: 58 | # loading saved state dict 59 | if jit: 60 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 61 | jit = False 62 | state_dict = torch.load(model_path, map_location="cpu") 63 | 64 | if not jit: 65 | try: 66 | model = build_model_from_openai_state_dict(state_dict or model.state_dict()).to(device) 67 | except KeyError: 68 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 69 | model = build_model_from_openai_state_dict(sd).to(device) 70 | 71 | if str(device) == "cpu": 72 | model.float() 73 | return model 74 | 75 | # patch the device names 76 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 77 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 78 | 79 | def patch_device(module): 80 | try: 81 | graphs = [module.graph] if hasattr(module, "graph") else [] 82 | except RuntimeError: 83 | graphs = [] 84 | 85 | if hasattr(module, "forward1"): 86 | graphs.append(module.forward1.graph) 87 | 88 | for graph in graphs: 89 | for node in graph.findAllNodes("prim::Constant"): 90 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 91 | node.copyAttributes(device_node) 92 | 93 | model.apply(patch_device) 94 | patch_device(model.encode_image) 95 | patch_device(model.encode_text) 96 | 97 | # patch dtype to float32 on CPU 98 | if str(device) == "cpu": 99 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 100 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 101 | float_node = float_input.node() 102 | 103 | def patch_float(module): 104 | try: 105 | graphs = [module.graph] if hasattr(module, "graph") else [] 106 | except RuntimeError: 107 | graphs = [] 108 | 109 | if hasattr(module, "forward1"): 110 | graphs.append(module.forward1.graph) 111 | 112 | for graph in graphs: 113 | for node in graph.findAllNodes("aten::to"): 114 | inputs = list(node.inputs()) 115 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 116 | if inputs[i].node()["value"] == 5: 117 | inputs[i].node().copyAttributes(float_node) 118 | 119 | model.apply(patch_float) 120 | patch_float(model.encode_image) 121 | patch_float(model.encode_text) 122 | model.float() 123 | 124 | # ensure image_size attr available at consistent location for both jit and non-jit 125 | model.visual.image_size = model.input_resolution.item() 126 | return model 127 | -------------------------------------------------------------------------------- /open_clip/pretrained.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | 6 | from tqdm import tqdm 7 | 8 | _RN50 = dict( 9 | 
openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 10 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 11 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt" 12 | ) 13 | 14 | _RN50_quickgelu = dict( 15 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 16 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", 17 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt" 18 | ) 19 | 20 | _RN101 = dict( 21 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 22 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt" 23 | ) 24 | 25 | _RN101_quickgelu = dict( 26 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 27 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt" 28 | ) 29 | 30 | _RN50x4 = dict( 31 | openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 32 | ) 33 | 34 | _RN50x16 = dict( 35 | openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 36 | ) 37 | 38 | _RN50x64 = dict( 39 | openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 40 | ) 41 | 42 | _VITB32 = dict( 43 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 44 | laion2b_e16="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth", 45 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 46 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 47 | ) 48 | 49 | _VITB32_quickgelu = dict( 50 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 51 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", 52 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", 53 | ) 54 | 55 | _VITB16 = dict( 56 | openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 57 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt", 58 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt", 59 | ) 60 | 61 | _VITB16_PLUS_240 = dict( 62 | 
laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt", 63 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt", 64 | ) 65 | 66 | _VITL14 = dict( 67 | openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 68 | laion400m_e31='https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt', 69 | laion400m_e32='https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt', 70 | ) 71 | 72 | _VITL14_336 = dict( 73 | openai="https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt" 74 | ) 75 | 76 | _PRETRAINED = { 77 | "RN50": _RN50, 78 | "RN50-quickgelu": _RN50_quickgelu, 79 | "RN101": _RN101, 80 | "RN101-quickgelu": _RN101_quickgelu, 81 | "RN50x4": _RN50x4, 82 | "RN50x16": _RN50x16, 83 | "RN50x64": _RN50x64, 84 | "ViT-B-32": _VITB32, 85 | "ViT-B-32-quickgelu": _VITB32_quickgelu, 86 | "ViT-B-16": _VITB16, 87 | "ViT-B-16-plus-240": _VITB16_PLUS_240, 88 | "ViT-L-14": _VITL14, 89 | "ViT-L-14-336": _VITL14_336, 90 | } 91 | 92 | 93 | def list_pretrained(as_str: bool = False): 94 | """ returns list of pretrained models 95 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True 96 | """ 97 | return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] 98 | 99 | 100 | def list_pretrained_tag_models(tag: str): 101 | """ return all models having the specified pretrain tag """ 102 | models = [] 103 | for k in _PRETRAINED.keys(): 104 | if tag in _PRETRAINED[k]: 105 | models.append(k) 106 | return models 107 | 108 | 109 | def list_pretrained_model_tags(model: str): 110 | """ return all pretrain tags for the specified model architecture """ 111 | tags = [] 112 | if model in _PRETRAINED: 113 | tags.extend(_PRETRAINED[model].keys()) 114 | return tags 115 | 116 | 117 | def get_pretrained_url(model: str, tag: str): 118 | if model not in _PRETRAINED: 119 | return '' 120 | model_pretrained = _PRETRAINED[model] 121 | tag = tag.lower() 122 | if tag not in model_pretrained: 123 | return '' 124 | return model_pretrained[tag] 125 | 126 | 127 | def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")): 128 | os.makedirs(root, exist_ok=True) 129 | filename = os.path.basename(url) 130 | 131 | if 'openaipublic' in url: 132 | expected_sha256 = url.split("/")[-2] 133 | else: 134 | expected_sha256 = '' 135 | 136 | download_target = os.path.join(root, filename) 137 | 138 | if os.path.exists(download_target) and not os.path.isfile(download_target): 139 | raise RuntimeError(f"{download_target} exists and is not a regular file") 140 | 141 | if os.path.isfile(download_target): 142 | if expected_sha256: 143 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 144 | return download_target 145 | else: 146 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 147 | else: 148 | return download_target 149 | 150 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 151 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 152 | while True: 153 | buffer = 
source.read(8192) 154 | if not buffer: 155 | break 156 | 157 | output.write(buffer) 158 | loop.update(len(buffer)) 159 | 160 | if expected_sha256 and hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 161 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match") 162 | 163 | return download_target 164 | -------------------------------------------------------------------------------- /open_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch.nn as nn 8 | 9 | try: 10 | import timm 11 | from timm.models.layers import Mlp, to_2tuple 12 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 13 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 14 | except ImportError as e: 15 | timm = None 16 | 17 | from .utils import freeze_batch_norm_2d 18 | 19 | 20 | class TimmModel(nn.Module): 21 | """ timm model adapter 22 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 23 | """ 24 | 25 | def __init__( 26 | self, 27 | model_name, 28 | embed_dim, 29 | image_size=224, 30 | pool='avg', 31 | proj='linear', 32 | drop=0., 33 | pretrained=False): 34 | super().__init__() 35 | if timm is None: 36 | raise RuntimeError("Please `pip install timm` to use timm models.") 37 | 38 | self.image_size = to_2tuple(image_size) 39 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 40 | feat_size = self.trunk.default_cfg.get('pool_size', None) 41 | feature_ndim = 1 if not feat_size else 2 42 | if pool in ('abs_attn', 'rot_attn'): 43 | assert feature_ndim == 2 44 | # if attn pooling used, remove both classifier and default pool 45 | self.trunk.reset_classifier(0, global_pool='') 46 | else: 47 | # reset global pool if pool config set, otherwise leave as network default 48 | reset_kwargs = dict(global_pool=pool) if pool else {} 49 | self.trunk.reset_classifier(0, **reset_kwargs) 50 | prev_chs = self.trunk.num_features 51 | 52 | head_layers = OrderedDict() 53 | if pool == 'abs_attn': 54 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 55 | prev_chs = embed_dim 56 | elif pool == 'rot_attn': 57 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 58 | prev_chs = embed_dim 59 | else: 60 | assert proj, 'projection layer needed if non-attention pooling is used.'
61 | 62 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 63 | if proj == 'linear': 64 | head_layers['drop'] = nn.Dropout(drop) 65 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim) 66 | elif proj == 'mlp': 67 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) 68 | 69 | self.head = nn.Sequential(head_layers) 70 | 71 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 72 | """ lock modules 73 | Args: 74 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 75 | """ 76 | if not unlocked_groups: 77 | # lock full model 78 | for param in self.trunk.parameters(): 79 | param.requires_grad = False 80 | if freeze_bn_stats: 81 | freeze_batch_norm_2d(self.trunk) 82 | else: 83 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 84 | try: 85 | # FIXME import here until API stable and in an official release 86 | from timm.models.helpers import group_parameters, group_modules 87 | except ImportError: 88 | raise RuntimeError( 89 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 90 | matcher = self.trunk.group_matcher() 91 | gparams = group_parameters(self.trunk, matcher) 92 | max_layer_id = max(gparams.keys()) 93 | max_layer_id = max_layer_id - unlocked_groups 94 | for group_idx in range(max_layer_id + 1): 95 | group = gparams[group_idx] 96 | for param in group: 97 | self.trunk.get_parameter(param).requires_grad = False 98 | if freeze_bn_stats: 99 | gmodules = group_modules(self.trunk, matcher, reverse=True) 100 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 101 | freeze_batch_norm_2d(self.trunk, gmodules) 102 | 103 | def forward(self, x): 104 | x = self.trunk(x) 105 | x = self.head(x) 106 | return x 107 | -------------------------------------------------------------------------------- /open_clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | 16 | @lru_cache() 17 | def default_bpe(): 18 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 19 | 20 | 21 | @lru_cache() 22 | def bytes_to_unicode(): 23 | """ 24 | Returns list of utf-8 byte and a corresponding list of unicode strings. 25 | The reversible bpe codes work on unicode strings. 26 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 27 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 28 | This is a signficant percentage of your normal, say, 32K bpe vocab. 29 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 30 | And avoids mapping to whitespace/control characters the bpe code barfs on. 31 | """ 32 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 33 | cs = bs[:] 34 | n = 0 35 | for b in range(2**8): 36 | if b not in bs: 37 | bs.append(b) 38 | cs.append(2**8+n) 39 | n += 1 40 | cs = [chr(n) for n in cs] 41 | return dict(zip(bs, cs)) 42 | 43 | 44 | def get_pairs(word): 45 | """Return set of symbol pairs in a word. 
46 | Word is represented as tuple of symbols (symbols being variable-length strings). 47 | """ 48 | pairs = set() 49 | prev_char = word[0] 50 | for char in word[1:]: 51 | pairs.add((prev_char, char)) 52 | prev_char = char 53 | return pairs 54 | 55 | 56 | def basic_clean(text): 57 | text = ftfy.fix_text(text) 58 | text = html.unescape(html.unescape(text)) 59 | return text.strip() 60 | 61 | 62 | def whitespace_clean(text): 63 | text = re.sub(r'\s+', ' ', text) 64 | text = text.strip() 65 | return text 66 | 67 | 68 | class SimpleTokenizer(object): 69 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 70 | self.byte_encoder = bytes_to_unicode() 71 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 72 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 73 | # print("merges len:", len(merges)) 74 | merges = merges[1:49152-256-2+1] 75 | merges = [tuple(merge.split()) for merge in merges] 76 | vocab = list(bytes_to_unicode().values()) 77 | vocab = vocab + [v+'</w>' for v in vocab] 78 | for merge in merges: 79 | vocab.append(''.join(merge)) 80 | if not special_tokens: 81 | special_tokens = ['<start_of_text>', '<end_of_text>'] 82 | else: 83 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens 84 | vocab.extend(special_tokens) 85 | # print("vocab:", len(vocab)) 86 | self.encoder = dict(zip(vocab, range(len(vocab)))) 87 | # print("encoder:", self.encoder) 88 | self.decoder = {v: k for k, v in self.encoder.items()} 89 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 90 | self.cache = {t:t for t in special_tokens} 91 | special = "|".join(special_tokens) 92 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 93 | 94 | self.vocab_size = len(self.encoder) 95 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 96 | 97 | def bpe(self, token): 98 | if token in self.cache: 99 | return self.cache[token] 100 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 101 | pairs = get_pairs(word) 102 | 103 | if not pairs: 104 | return token+'</w>' 105 | 106 | while True: 107 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 108 | if bigram not in self.bpe_ranks: 109 | break 110 | first, second = bigram 111 | new_word = [] 112 | i = 0 113 | while i < len(word): 114 | try: 115 | j = word.index(first, i) 116 | new_word.extend(word[i:j]) 117 | i = j 118 | except: 119 | new_word.extend(word[i:]) 120 | break 121 | 122 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 123 | new_word.append(first+second) 124 | i += 2 125 | else: 126 | new_word.append(word[i]) 127 | i += 1 128 | new_word = tuple(new_word) 129 | word = new_word 130 | if len(word) == 1: 131 | break 132 | else: 133 | pairs = get_pairs(word) 134 | word = ' '.join(word) 135 | self.cache[token] = word 136 | return word 137 | 138 | def encode(self, text): 139 | bpe_tokens = [] 140 | text = whitespace_clean(basic_clean(text)).lower() 141 | for token in re.findall(self.pat, text): 142 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 143 | # print("token, bpe:", token, self.bpe(token)) 144 | # print(self.bpe(token).split(' ')) 145 | # for bpe_token in self.bpe(token).split(' '): 146 | # print('token:', bpe_token ) 147 | # print('bpe:', self.encoder[bpe_token]) 148 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 149 | # print("overall bpe:", bpe_tokens) 150 | return bpe_tokens 151 | 152 | def decode(self, tokens): 153 | text = ''.join([self.decoder[token] for token in
158 | _tokenizer = SimpleTokenizer() 159 | 160 | 161 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 162 | """ 163 | Returns the tokenized representation of the given input string(s) 164 | 165 | Parameters 166 | ---------- 167 | texts : Union[str, List[str]] 168 | An input string or a list of input strings to tokenize 169 | context_length : int 170 | The context length to use; all CLIP models use 77 as the context length 171 | 172 | Returns 173 | ------- 174 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 175 | """ 176 | if isinstance(texts, str): 177 | texts = [texts] 178 | 179 | sot_token = _tokenizer.encoder["<start_of_text>"] 180 | eot_token = _tokenizer.encoder["<end_of_text>"] 181 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 182 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 183 | 184 | for i, tokens in enumerate(all_tokens): 185 | if len(tokens) > context_length: 186 | tokens = tokens[:context_length] # Truncate 187 | result[i, :len(tokens)] = torch.tensor(tokens) 188 | 189 | return result 190 | -------------------------------------------------------------------------------- /open_clip/transform.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, ToTensor, Resize, \ 2 | CenterCrop 3 | from PIL import Image 4 | 5 | def _convert_to_rgb(image): 6 | return image.convert('RGB') 7 | 8 | 9 | def image_transform( 10 | image_size: int, 11 | is_train: bool, 12 | mean=(0.48145466, 0.4578275, 0.40821073), 13 | std=(0.26862954, 0.26130258, 0.27577711) 14 | ): 15 | normalize = Normalize(mean=mean, std=std) 16 | if is_train: 17 | return Compose([ 18 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=Image.BICUBIC), 19 | _convert_to_rgb, 20 | ToTensor(), 21 | normalize, 22 | ]) 23 | else: 24 | return Compose([ 25 | Resize(image_size, interpolation=Image.BICUBIC), 26 | CenterCrop(image_size), 27 | _convert_to_rgb, 28 | ToTensor(), 29 | normalize, 30 | ]) 31 | -------------------------------------------------------------------------------- /open_clip/utils.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torchvision.ops.misc import FrozenBatchNorm2d 3 | 4 | 5 | def freeze_batch_norm_2d(module, module_match={}, name=''): 6 | """ 7 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of the provided module into `FrozenBatchNorm2d`. If `module` is 8 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 9 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 10 | 11 | Args: 12 | module (torch.nn.Module): Any PyTorch module.
13 | module_match (dict): Dictionary of full module names to freeze (all if empty) 14 | name (str): Full module name (prefix) 15 | 16 | Returns: 17 | torch.nn.Module: Resulting module 18 | 19 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 20 | """ 21 | res = module 22 | is_match = True 23 | if module_match: 24 | is_match = name in module_match 25 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 26 | res = FrozenBatchNorm2d(module.num_features) 27 | res.num_features = module.num_features 28 | res.affine = module.affine 29 | if module.affine: 30 | res.weight.data = module.weight.data.clone().detach() 31 | res.bias.data = module.bias.data.clone().detach() 32 | res.running_mean.data = module.running_mean.data 33 | res.running_var.data = module.running_var.data 34 | res.eps = module.eps 35 | else: 36 | for child_name, child in module.named_children(): 37 | full_child_name = '.'.join([name, child_name]) if name else child_name 38 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 39 | if new_child is not child: 40 | res.add_module(child_name, new_child) 41 | return res -------------------------------------------------------------------------------- /open_clip/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.3.0' 2 | --------------------------------------------------------------------------------
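Putting the pieces above together, a minimal preprocessing sketch that ties `tokenizer.py` and `transform.py` together; the image path and captions are placeholders, and the resulting tensors are the kind of input a CLIP-style model's image and text encoders expect:

```
from PIL import Image

from open_clip.tokenizer import tokenize
from open_clip.transform import image_transform

# Validation-style transform: resize, center-crop, RGB conversion, tensor conversion, normalization.
preprocess = image_transform(image_size=224, is_train=False)

image = preprocess(Image.open("path/to/a/visual_genome_image.jpg")).unsqueeze(0)  # shape [1, 3, 224, 224]
texts = tokenize(["a dog chasing a ball", "a ball chasing a dog"])                # shape [2, 77] LongTensor
```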