├── README.md
├── crepe_compo_eval_open_clip.py
├── crepe_eval_utils.py
├── crepe_params.py
├── crepe_prod_eval_albef.py
├── crepe_prod_eval_clip.py
├── crepe_prod_eval_cyclip.py
├── crepe_prod_eval_flava.py
├── data
├── prod_hard_negatives.zip
└── syst_hard_negatives.zip
└── open_clip
├── __init__.py
├── bpe_simple_vocab_16e6.txt.gz
├── factory.py
├── loss.py
├── model.py
├── model_configs
├── RN101-quickgelu.json
├── RN101.json
├── RN50-quickgelu.json
├── RN50.json
├── RN50x16.json
├── RN50x4.json
├── ViT-B-16-plus-240.json
├── ViT-B-16-plus.json
├── ViT-B-16.json
├── ViT-B-32-plus-256.json
├── ViT-B-32-quickgelu.json
├── ViT-B-32.json
├── ViT-H-14.json
├── ViT-H-16.json
├── ViT-L-14-280.json
├── ViT-L-14-336.json
├── ViT-L-14.json
├── ViT-L-16-320.json
├── ViT-L-16.json
├── ViT-g-14.json
├── timm-efficientnetv2_rw_s.json
├── timm-resnet50d.json
├── timm-resnetaa50d.json
├── timm-resnetblur50.json
├── timm-swin_base_patch4_window7_224.json
├── timm-vit_base_patch16_224.json
├── timm-vit_base_patch32_224.json
└── timm-vit_small_patch16_224.json
├── openai.py
├── pretrained.py
├── timm_model.py
├── tokenizer.py
├── transform.py
├── utils.py
└── version.py
/README.md:
--------------------------------------------------------------------------------
1 | # CREPE
2 |
3 | This repository contains the code we used to evaluate [open_clip](https://github.com/mlfoundations/open_clip) CLIP models,
4 | the official OpenAI CLIP models, CyCLIP, FLAVA and ALBEF on compositional reasoning, as reported in our paper [CREPE: Can Vision-Language Foundation Models Reason Compositionally?](https://arxiv.org/abs/2212.07796).
5 |
6 |
7 |
8 | ## Systematicity procedure
9 |
10 |
11 | ## Productivity procedure
12 |
13 |
14 | ## Evaluation instructions
15 | In `crepe_eval_utils.py`, you can find common evaluation utility functions; you will need to replace `vg_image_paths`
16 | with the paths to the Visual Genome images on your machine. The VG images can be downloaded [here](https://drive.google.com/drive/folders/11dMtJByk7zmbQjV47PXVwfmakN3Gr5Ic?usp=share_link).
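For example, a minimal sketch of that change in `BaseCsvDataset.get_image_by_id` (the `/data/visual_genome` prefix below is a placeholder for wherever you downloaded the images):

```python
# crepe_eval_utils.py, inside BaseCsvDataset.get_image_by_id
# Point these at your local copies of the two Visual Genome image folders.
vg_image_paths = [
    '/data/visual_genome/img/VG_100K',    # placeholder path
    '/data/visual_genome/img/VG_100K_2',  # placeholder path
]
```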
17 |
18 | We evaluated all models on an NVIDIA TITAN X GPU with a CUDA version of 11.4.
19 |
20 | ### Evaluate open_clip CLIP models on systematicity and productivity
21 | You will need to install the packages required to use open_clip [here](https://github.com/mlfoundations/open_clip/blob/main/requirements.txt).
22 | You can download the [pretrained CLIP models](https://github.com/mlfoundations/open_clip#pretrained-model-details) and replace `--model-dir`
23 | with your own model checkpoint directory path in `crepe_compo_eval_open_clip.py`. (You can also modify the code to use open_clip's
24 | [pretrained model interface](https://github.com/mlfoundations/open_clip#pretrained-model-interface).)
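For instance, if the checkpoints live under a placeholder directory such as `~/clip_checkpoints`, the filenames should match the entries in `DATA2MODEL`:

```
~/clip_checkpoints/
├── rn50-quickgelu-cc12m-f000538c.pt
├── rn50-quickgelu-yfcc15m-455df137.pt
└── vit_b_32-quickgelu-laion400m_e32-46683a32.pt
```

You would then pass `--model-dir ~/clip_checkpoints` when running the evaluation below.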
25 |
26 | To evaluate all models reported in our paper, simply run:
27 |
28 | ```
29 | python -m crepe_compo_eval_open_clip --compo-type --hard-neg-types --input-dir --output-dir
30 | ```
31 |
32 | where the valid compositionality types are `systematicity` and `productivity`. The valid negative types are `atom`, `comp` and `combined` (`atom`+`comp`) for systematicity, and `atom`, `swap` and `negate` for productivity.
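As a concrete example, a systematicity run over the CC12M-pretrained models with both hard negative types might look like this (the checkpoint, input and output paths are placeholders):

```
python -m crepe_compo_eval_open_clip \
    --compo-type systematicity \
    --hard-neg-types atom comp \
    --train-dataset cc12m \
    --model-dir ~/clip_checkpoints \
    --input-dir /path/to/syst_hard_negatives \
    --output-dir log/
```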
33 |
34 | To evaluate other pretrained models, simply modify the `--train-dataset` argument and/or the `DATA2MODEL` variable in `crepe_compo_eval_open_clip.py`.
35 | **Note that the systematicity eval set should only be used to evaluate models pretrained on CC12M, YFCC15M or LAION400M.**
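As a sketch, adding another checkpoint is a one-line change to `DATA2MODEL` (the filename below is a placeholder, not a released checkpoint):

```python
# crepe_compo_eval_open_clip.py
# Outer keys are --train-dataset values, inner keys are open_clip architecture
# names, and values are checkpoint filenames expected under --model-dir.
DATA2MODEL['laion']['ViT-B-32'] = 'vit_b_32-laion400m-placeholder.pt'
```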
36 |
37 | ### Evaluate all other vision-language models on productivity
38 | For each model, you will need to clone the model's official repository, set up
39 | an environment according to its instructions and place the files `crepe_prod_eval_<model>.py`
40 | and `crepe_eval_utils.py` in the locations described below. In `crepe_params.py`, you will need to replace the `--input-dir`
41 | default with your own directory path to CREPE's productivity hard negatives test set.
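For example, assuming the zipped productivity hard negatives under `data/` unpack into one subfolder per hard negative type, the test set can be prepared like this:

```
unzip data/prod_hard_negatives.zip -d data/prod_hard_negatives
# the eval scripts then read files of the form
#   <input-dir>/<neg_type>/prod_vg_hard_negs_<neg_type>_complexity_<k>.csv
```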
42 |
43 | #### CLIP-specific instructions
44 | Clone the CLIP repository [here](https://github.com/openai/CLIP) and place `crepe_prod_eval_clip.py`
45 | and `crepe_eval_utils.py` on the top level of the repository. To evaluate models, simply run:
46 |
47 | ```
48 | python -m crepe_prod_eval_clip --model-name --hard-neg-types --output-dir
49 | ```
50 |
51 | where the valid negative types are `atom`, `swap` and `negate`, and model names are `RN50`, `RN101`, `ViT-B/32`, `ViT-B/16` and `ViT-L/14`.
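A concrete invocation might look like the following (paths are placeholders; note that `crepe_params.py` also marks `--compo-type` and `--input-dir` as required arguments):

```
python -m crepe_prod_eval_clip \
    --model-name ViT-B/32 \
    --compo-type productivity \
    --hard-neg-types atom swap negate \
    --input-dir /path/to/prod_hard_negatives \
    --output-dir log/
```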
52 |
53 | #### CyCLIP-specific instructions
54 | Clone the CyCLIP repository [here](https://github.com/goel-shashank/CyCLIP), place `crepe_prod_eval_cyclip.py`
55 | and `crepe_eval_utils.py` on the top level of the repository and download the
56 | model checkpoint under the folder `cyclip.pt` (accessible from the bottom of the
57 | repository's README). Note that `crepe_prod_eval_cyclip.py` loads the checkpoint from `best.pt` in the working directory. To evaluate models, simply run:
58 |
59 | ```
60 | python -m crepe_prod_eval_cyclip --hard-neg-types --output-dir
61 | ```
62 |
63 | #### FLAVA-specific instructions
64 | Clone the FLAVA repository [here](https://github.com/facebookresearch/multimodal) and copy `crepe_prod_eval_flava.py`
65 | and `crepe_eval_utils.py` into the folder `examples/flava/`. To evaluate models, simply run:
66 |
67 | ```
68 | python -m crepe_prod_eval_flava --hard-neg-types --output-dir
69 | ```
70 |
71 | #### ALBEF-specific instructions
72 | Clone the ALBEF repository [here](https://github.com/salesforce/ALBEF/tree/b9727e43c3040491774d1b22cc27718aa7772fac),
73 | copy `crepe_prod_eval_albef.py` and `crepe_eval_utils.py` to the top level of the repository
74 | and download the pretrained checkpoint marked '14M' from the repository, saving it as `ALBEF.pth` at the top level (the path loaded by `crepe_prod_eval_albef.py`). To evaluate models, simply run:
75 |
76 | ```
77 | python -m crepe_prod_eval_albef --hard-neg-types --output-dir
78 | ```
79 |
80 | ## Citation
81 | If you find our work helpful, please cite us:
82 |
83 | ```bibtex
84 | @article{ma2023crepe,
85 | title={CREPE: Can Vision-Language Foundation Models Reason Compositionally?},
86 | author={Zixian Ma and Jerry Hong and Mustafa Omer Gul and Mona Gandhi and Irena Gao and Ranjay Krishna},
87 | year={2023},
88 | journal={arXiv preprint arXiv:2212.07796},
89 | }
90 | ```
91 |
--------------------------------------------------------------------------------
/crepe_compo_eval_open_clip.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import logging
4 | import torch
5 | import numpy as np
6 | import torch.nn.functional as F
7 | import torchvision.transforms.functional as TF
8 | from torch.utils.data import DataLoader
9 | from torch.utils.data.distributed import DistributedSampler
10 | from dataclasses import dataclass
11 | from open_clip import tokenize, create_model_and_transforms
12 | from crepe_eval_utils import BaseCsvDataset, get_one2many_metrics, get_one2many_rank, get_metrics
13 | from crepe_params import setup_args
14 |
15 | DATA2MODEL = {
16 | 'cc12m': {
17 | 'RN50-quickgelu': 'rn50-quickgelu-cc12m-f000538c.pt'
18 | },
19 | 'yfcc': {
20 | 'RN50-quickgelu': 'rn50-quickgelu-yfcc15m-455df137.pt',
21 | 'RN101-quickgelu': 'rn101-quickgelu-yfcc15m-3e04b30e.pt'
22 | },
23 | 'laion': {
24 | 'ViT-B-16':'vit_b_16-laion400m_e32-55e67d44.pt',
25 | 'ViT-B-16-plus-240': 'vit_b_16_plus_240-laion400m_e32-699c4b84.pt',
26 | 'ViT-B-32-quickgelu': 'vit_b_32-quickgelu-laion400m_e32-46683a32.pt',
27 | 'ViT-L-14': 'vit_l_14-laion400m_e32-3d133497.pt',
28 | }
29 | }
30 |
31 | COMPO_SPLITS = ['seen_compounds', 'unseen_compounds']
32 | COMPLEXITIES = list(range(4, 13))
33 |
34 | @dataclass
35 | class DataInfo:
36 | dataloader: DataLoader
37 | sampler: DistributedSampler
38 |
39 | class CsvDataset(BaseCsvDataset):
40 | def __init__(self, input_filename, args, transforms):
41 | super().__init__(input_filename, args, transforms=transforms)
42 |
43 | def __getitem__(self, idx):
44 | raw_image = self.get_image_by_id(self.images[idx])
45 | if self.crop:
46 | raw_image = TF.crop(raw_image, self.ys[idx], self.xs[idx], self.heights[idx], self.widths[idx])
47 | image = self.transforms(raw_image)
48 | if self.one2many:
49 | texts = tokenize([str(self.captions[idx])] + list(self.hard_negs[idx]))
50 | else:
51 | texts = tokenize([str(self.captions[idx])])[0]
52 | return image, texts
53 |
54 | def get_csv_dataset(args, preprocess_fn, is_train):
55 | input_filename = args.val_data
56 | assert input_filename
57 | dataset = CsvDataset(
58 | input_filename,
59 | args,
60 | preprocess_fn)
61 | num_samples = len(dataset)
62 |
63 | sampler = None
64 | shuffle = is_train and sampler is None
65 |
66 | dataloader = DataLoader(
67 | dataset,
68 | batch_size=args.batch_size,
69 | shuffle=shuffle,
70 | num_workers=1,
71 | pin_memory=True,
72 | sampler=sampler,
73 | drop_last=is_train,
74 | )
75 | dataloader.num_samples = num_samples
76 | dataloader.num_batches = len(dataloader)
77 |
78 | return DataInfo(dataloader, sampler)
79 |
80 | def get_data(args, preprocess_fns):
81 | preprocess_train, preprocess_val = preprocess_fns
82 | data = {}
83 |
84 | data["val"] = get_csv_dataset(
85 | args, preprocess_val, is_train=False)
86 | return data
87 |
88 | def evaluate(model, data, args):
89 | metrics = {}
90 | device = torch.device(args.device)
91 | model.eval()
92 |
93 | autocast = torch.cuda.amp.autocast
94 | dataloader = data['val'].dataloader
95 |
96 | # FIXME this does not scale past small eval datasets
97 | # all_image_features @ all_text_features will blow up memory and compute very quickly
98 | all_image_features, all_text_features = [], []
99 | one2many = dataloader.dataset.one2many
100 | if one2many:
101 | all_ranks = []
102 | with torch.no_grad():
103 | for i, batch in enumerate(dataloader):
104 | images, texts = batch
105 | images = images.to(device=device, non_blocking=True)
106 | texts = texts.to(device=device, non_blocking=True)
107 |
108 | if one2many:
109 | image_features = model.encode_image(images)
110 | image_features = F.normalize(image_features, dim=-1)
111 |
112 | texts = torch.squeeze(texts, dim=0)
113 | text_features = model.encode_text(texts)
114 | text_features = F.normalize(text_features, dim=-1)
115 |
116 | rank = get_one2many_rank(image_features, text_features)
117 | all_ranks.append(rank)
118 | else:
119 | with autocast():
120 | image_features, text_features, logit_scale = model(images, texts)
121 | # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly
122 | # however, system RAM is easily exceeded and compute time becomes problematic
123 | all_image_features.append(image_features.cpu())
124 | all_text_features.append(text_features.cpu())
125 |
126 | if one2many:
127 | val_metrics = get_one2many_metrics(np.array(all_ranks))
128 | metrics.update(
129 | {**val_metrics}
130 | )
131 | else:
132 | val_metrics = get_metrics(
133 | image_features=torch.cat(all_image_features),
134 | text_features=torch.cat(all_text_features)
135 | )
136 | metrics.update(
137 | {**val_metrics}
138 | )
139 |
140 | logging.info("\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()]))
141 |
142 | return metrics
143 |
144 | def gather_params(args, hard_neg_type, split):
145 | if args.compo_type == 'systematicity':
146 | if hard_neg_type in ['atom', 'comp', 'combined']:
147 | hard_neg_key = f'valid_hard_negs_{hard_neg_type}'
148 | else:
149 | raise NotImplementedError
150 |
151 | retrieval_data_path = os.path.join(args.input_dir, f'syst_vg_hard_negs_{split}_in_{args.train_dataset}.csv')
152 |
153 | elif args.compo_type == 'productivity':
154 | hard_neg_key = 'hard_negs'
155 | if hard_neg_type in ['atom', 'negate', 'swap']:
156 | input_dir = os.path.join(args.input_dir, hard_neg_type)
157 | retrieval_data_path = os.path.join(input_dir, f'prod_vg_hard_negs_{hard_neg_type}_complexity_{split}.csv')
158 | else:
159 | raise NotImplementedError
160 | else:
161 | raise NotImplementedError
162 |
163 | args.val_data = retrieval_data_path
164 | args.one2many = True
165 | args.crop = True
166 | args.hard_neg_key = hard_neg_key
167 | args.batch_size = 1
168 | return args
169 |
170 | def main():
171 | args = setup_args()
172 | models = DATA2MODEL[args.train_dataset].keys()
173 | if args.compo_type == 'systematicity':
174 | splits = COMPO_SPLITS
175 | elif args.compo_type == 'productivity':
176 | splits = COMPLEXITIES
177 |
178 | if args.output_dir:
179 | if not os.path.exists(args.output_dir):
180 | os.mkdir(args.output_dir)
181 |
182 | if torch.cuda.is_available():
183 | device = 'cuda:0'
184 | torch.cuda.set_device(device)
185 | else:
186 | device = 'cpu'
187 | args.device = device
188 | device = torch.device(device)
189 |
190 | for model_name in models:
191 | pretrained = os.path.join(args.model_dir, DATA2MODEL[args.train_dataset][model_name])
192 | model, preprocess_train, preprocess_val = create_model_and_transforms(
193 | model_name,
194 | pretrained,
195 | precision='amp',
196 | device=device
197 | )
198 | for hard_neg_type in args.hard_neg_types:
199 | all_metrics = {}
200 | for split in splits:
201 | # params = gather_params(args, model, split)
202 | print('\n' + '*' * 45 + f' Evaluating {model_name} {args.compo_type} on HN-{hard_neg_type.upper()} test set split {split} ' + '*' * 45 + '\n')
203 | args = gather_params(args, hard_neg_type, split)
204 | # initialize datasets
205 | data = get_data(args, (preprocess_train, preprocess_val))
206 | assert len(data), 'At least one dataset must be specified.'
207 |
208 | metrics = evaluate(model, data, args)
209 |
210 | all_metrics[split] = metrics
211 |
212 | if args.output_dir:
213 | output = os.path.join(args.output_dir, f'{args.compo_type}_{args.train_dataset}_{model_name}_{hard_neg_type}_metrics.json')
214 | print("saving results to:", output)
215 | with open(output, 'w') as f:
216 | json.dump(all_metrics, f)
217 |
218 | if __name__ == "__main__":
219 | main()
220 |
--------------------------------------------------------------------------------
/crepe_eval_utils.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import logging
3 | import os
4 | from PIL import Image
5 | from dataclasses import dataclass
6 |
7 | import torch
8 | from torch.utils.data import DataLoader, Dataset
9 | import numpy as np
10 |
11 | import pandas as pd
12 |
13 | logging.basicConfig(level=logging.INFO)
14 | logger = logging.getLogger()
15 |
16 | ### DATASET CONSTRUCTION
17 |
18 | class BaseCsvDataset(Dataset):
19 | def __init__(self, input_filename, args, transforms=None):
20 | logging.debug(f'Loading csv data from {input_filename}.')
21 | df = pd.read_csv(input_filename)
22 | # print(f"Total number of examples: {len(df)}.")
23 | self.crop = args.crop
24 | if self.crop:
25 | assert 'x' in df.columns and 'y' in df.columns and 'width' in df.columns and 'height' in df.columns, "missing x, y, width, or height."
26 | self.xs = df['x'].tolist()
27 | self.ys = df['y'].tolist()
28 | self.heights = df['height'].tolist()
29 | self.widths = df['width'].tolist()
30 | # print("cropping:", self.crop)
31 | self.one2many = args.one2many
32 | # print("one2many:", self.one2many)
33 | if self.one2many:
34 | self.hard_negs = [ast.literal_eval(ls_str) for ls_str in df[args.hard_neg_key]]
35 | self.images = df[args.csv_img_key].tolist()
36 | self.captions = df[args.csv_caption_key].tolist()
37 | self.transforms = transforms
38 |
39 | def __len__(self):
40 | return len(self.captions)
41 |
42 | def get_image_by_id(self, image_id):
43 |         vg_image_paths = ['/nlp/scr/irena/data/visual_genome/img/VG_100K', '/nlp/scr/irena/data/visual_genome/img/VG_100K_2']  # replace with your local Visual Genome image directories (see README)
44 | for p in vg_image_paths:
45 | path = os.path.join(p, f"{image_id}.jpg")
46 | if os.path.exists(path):
47 | return Image.open(path).convert("RGB")
48 | raise FileNotFoundError(f'The image with id {image_id} is not found.')
49 |
50 |     def __getitem__(self, idx):
51 |         # subclasses must implement __getitem__
52 |         raise NotImplementedError('__getitem__ must be implemented by a subclass.')
53 |
54 | @dataclass
55 | class DataInfo:
56 | dataloader: DataLoader
57 |
58 | # EVALUATION UTILITIES
59 |
60 | def get_one2many_rank(image_features, text_features):
61 | logits_per_image = (image_features @ text_features.t()).detach().cpu()
62 |     ground_truth = 0 # because the ground-truth caption is placed first; see CsvDataset.__getitem__() in the eval scripts
63 | ranking = torch.argsort(logits_per_image, descending=True)
64 | pred = torch.where(ranking == ground_truth)[1].detach().cpu().numpy()
65 | return pred
66 |
67 | def get_one2many_metrics(preds, name='image_to_text'):
68 | metrics = {}
69 | metrics[f"{name}_mean_rank"] = preds.mean() + 1
70 | metrics[f"{name}_rank_std"] = preds.std()
71 | metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
72 |
73 | for k in [1, 3, 5, 10]:
74 | metrics[f"{name}_R@{k}"] = np.mean(preds < k)
75 | metrics[f"{name}_R@{k}_std"] = np.std(preds < k)
76 | return metrics
77 |
78 | def get_metrics(image_features, text_features):
79 | metrics = {}
80 | logits_per_image = (image_features @ text_features.t()).detach().cpu()
81 | logits_per_text = logits_per_image.t().detach().cpu()
82 |
83 | logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}
84 | ground_truth = torch.arange(len(text_features)).view(-1, 1)
85 |
86 | for name, logit in logits.items():
87 | ranking = torch.argsort(logit, descending=True)
88 | preds = torch.where(ranking == ground_truth)[1]
89 | preds = preds.detach().cpu().numpy()
90 | metrics[f"{name}_mean_rank"] = preds.mean() + 1
91 | metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
92 |
93 | for k in [1, 3, 5, 10]:
94 | metrics[f"{name}_R@{k}"] = np.mean(preds < k)
95 |
96 | return metrics
97 |
--------------------------------------------------------------------------------
/crepe_params.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def setup_args():
4 | parser = argparse.ArgumentParser(description="Run image2text retrieval eval.")
5 | parser.add_argument("--compo-type", required=True, type=str, default="systematicity", help="Either systematicity or productivity")
6 | parser.add_argument("--input-dir", required=True, type=str, default="/vision/group/CLIPComp/crepe/prod_hard_negatives")
7 | parser.add_argument('--hard-neg-types', required=True, type=str, nargs='+', help="The type(s) of hard negatives to include in the retrieval set.")
8 | parser.add_argument("--model-dir", type=str, default="/vision/group/clip")
9 | parser.add_argument("--output-dir", type=str, default="log/")
10 | parser.add_argument("--csv-img-key", type=str, default="image_id")
11 | parser.add_argument("--csv-caption-key", type=str, default="caption")
12 | parser.add_argument("--hard-neg-key", type=str, default="hard_negs", help="The column name of the hard negative captions.")
13 | parser.add_argument("--crop", type=bool, default=True, help="Whether to crop the image input.")
14 | parser.add_argument("--one2many", type=bool, default=True, help="Whether each image query has a different retrieval text set.")
15 | # For systematicity eval on open_clip's pretrained models with known training dataset
16 | parser.add_argument("--train-dataset", type=str, default="cc12m")
17 | # For CLIP & CyCLIP
18 | parser.add_argument("--model-name", type=str, default="RN50")
19 | # For CyCLIP
20 | parser.add_argument("--pretrained", default=False, action="store_true", help="Use the OpenAI pretrained models")
21 | args = parser.parse_args()
22 | return args
--------------------------------------------------------------------------------
/crepe_prod_eval_albef.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import os
9 | from PIL import Image
10 | from time import time
11 |
12 | import torch
13 | from torch import nn
14 | from torch.utils.data import DataLoader
15 | from torchvision import transforms
16 | import torch.nn.functional as F
17 | import torchvision.transforms.functional as TF
18 | import numpy as np
19 | import json
20 |
21 | # ALBEF:
22 | # from torchmultimodal.transforms.flava_transform import FLAVAImageTransform
23 | import ruamel.yaml as yaml
24 | from models.model_retrieval import ALBEF
25 | from models.vit import interpolate_pos_embed
26 | # from transformers import BertTokenizer
27 | from models.tokenization_bert import BertTokenizer
28 |
29 | from crepe_eval_utils import BaseCsvDataset, get_one2many_rank, get_one2many_metrics, DataInfo
30 | from crepe_params import setup_args
31 |
32 | import pandas as pd
33 |
34 | logging.basicConfig(level=logging.INFO)
35 | logger = logging.getLogger()
36 |
37 | max_text_length = 512
38 | TEXT_DEFAULT_TOKENIZER = "bert-base-uncased"
39 | text_tokenizer = BertTokenizer.from_pretrained(TEXT_DEFAULT_TOKENIZER)
40 |
41 | def collator(batch):
42 | images = torch.stack([x[0] for x in batch], dim=0)
43 | texts = torch.cat([x[1] for x in batch], dim=0)
44 | masks = torch.cat([x[2] for x in batch], dim=0)
45 |
46 | return images, texts, masks
47 |
48 | ### DATASET CONSTRUCTION
49 |
50 | def default_text_transform(texts):
51 | # Expect a list of texts
52 | tokenized_texts = []
53 | attention_masks = []
54 | start_time = time()
55 | for text in texts:
56 | tokenized = text_tokenizer(text, padding="max_length",
57 | max_length=max_text_length, truncation=True, return_tensors='pt')
58 |
59 | tokenized_texts.append(tokenized['input_ids'])
60 | attention_masks.append(tokenized['attention_mask'])
61 |
62 | tokenized_texts = torch.cat(tokenized_texts, dim=0)
63 | attention_masks = torch.cat(attention_masks, dim=0)
64 |
65 | return tokenized_texts, attention_masks
66 |
67 | class CsvDataset(BaseCsvDataset):
68 | def __init__(self, input_filename, args, config):
69 | super().__init__(input_filename, args)
70 |
71 | # albef transform:
72 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
73 | test_transform = transforms.Compose([
74 | transforms.Resize((config['image_res'],config['image_res']),interpolation=Image.BICUBIC),
75 | transforms.ToTensor(),
76 | normalize,
77 | ])
78 | self.image_transform = test_transform
79 | self.text_transform = default_text_transform
80 |
81 | def __getitem__(self, idx):
82 | raw_image = self.get_image_by_id(self.images[idx])
83 | if self.crop:
84 | raw_image = TF.crop(raw_image, self.ys[idx], self.xs[idx], self.heights[idx], self.widths[idx])
85 |         image = self.image_transform(raw_image)
86 | texts, attn_mask = self.text_transform([str(self.captions[idx])] + list(self.hard_negs[idx]))
87 |
88 | return image, texts, attn_mask
89 |
90 | def get_data(args, retrieval_data_path, config):
91 | # Get CSVDataset
92 | input_filename = retrieval_data_path
93 | dataset = CsvDataset(
94 | input_filename,
95 | args,
96 | config=config)
97 | num_samples = len(dataset)
98 | sampler = None
99 | shuffle=False
100 |
101 | dataloader = DataLoader(
102 | dataset,
103 | batch_size=16,
104 | shuffle=shuffle,
105 | num_workers=1,
106 | pin_memory=True,
107 | sampler=sampler,
108 | drop_last=False,
109 | collate_fn=collator
110 | )
111 | dataloader.num_samples = num_samples
112 | dataloader.num_batches = len(dataloader)
113 |
114 | return DataInfo(dataloader)
115 |
116 | ### EVALUATION
117 |
118 | def evaluate(model, data, complexity, negative_type):
119 | metrics = {}
120 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
121 |
122 | dataloader = data.dataloader
123 | # num_samples = 0
124 | # samples_per_val = dataloader.num_samples
125 |
126 | # cumulative_loss = 0.0
127 | # all_image_features, all_text_features = [], []
128 | one2many = dataloader.dataset.one2many
129 |     assert one2many, "Not one2many?"
130 |
131 | if one2many:
132 | all_ranks = []
133 |
134 | with torch.no_grad():
135 | for i, batch in enumerate(dataloader):
136 | images, texts, masks = batch
137 | images = images.to(device=device, non_blocking=True)
138 | texts = texts.to(device=device, non_blocking=True)
139 | masks = masks.to(device=device, non_blocking=True)
140 |
141 | if one2many:
142 | image_feat = model.visual_encoder(images)
143 | image_embed = model.vision_proj(image_feat[:,0,:])
144 | image_embed = F.normalize(image_embed,dim=-1)
145 |
146 | text_out = model.text_encoder(texts, attention_mask = masks, mode='text')
147 | text_feat = text_out.last_hidden_state
148 | text_emb = F.normalize(model.text_proj(text_feat[:,0,:]))
149 |
150 | set_size = text_emb.shape[0] // image_embed.shape[0]
151 | for j in range(image_embed.shape[0]):
152 | curr_image_emb = image_embed[j:j+1, :]
153 | curr_text_emb = text_emb[j*set_size:(j+1)*set_size, :]
154 | rank = get_one2many_rank(curr_image_emb, curr_text_emb)
155 | all_ranks.append(rank)
156 |
157 | print(f'Processed example {i*16}')
158 |
159 | metrics = get_one2many_metrics(np.array(all_ranks))
160 |
161 | # Alter output here
162 | logging.info(
163 | "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
164 | )
165 |
166 | return metrics
167 |
168 | def main():
169 | args = setup_args()
170 | if args.output_dir:
171 | output_dir = os.path.join(args.output_dir, 'albef')
172 | if not os.path.exists(output_dir):
173 | os.makedirs(output_dir)
174 | # LOAD ALBEF
175 | config_str = './configs/Retrieval_coco.yaml'
176 | config = yaml.load(open(config_str, 'r'), Loader=yaml.Loader)
177 | tokenizer = BertTokenizer.from_pretrained(TEXT_DEFAULT_TOKENIZER)
178 | albef = ALBEF(config=config, text_encoder=TEXT_DEFAULT_TOKENIZER, tokenizer=tokenizer)
179 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
180 | logger.info(f"Using device: {device}")
181 |
182 | # MODEL CHECKPOINT
183 | checkpoint = torch.load('./ALBEF.pth', map_location='cpu')
184 | state_dict = checkpoint['model']
185 |
186 |     # reshape positional embedding to accommodate the image resolution change
187 | pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],albef.visual_encoder)
188 | state_dict['visual_encoder.pos_embed'] = pos_embed_reshaped
189 | m_pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],albef.visual_encoder_m)
190 | state_dict['visual_encoder_m.pos_embed'] = m_pos_embed_reshaped
191 |
192 | for key in list(state_dict.keys()):
193 | if 'bert' in key:
194 | encoder_key = key.replace('bert.','')
195 | state_dict[encoder_key] = state_dict[key]
196 | del state_dict[key]
197 | msg = albef.load_state_dict(state_dict,strict=False)
198 | albef = albef.to(device)
199 | albef.eval()
200 |
201 | for hard_neg_type in args.hard_neg_types:
202 | all_metrics = {}
203 | # Iterate over each complexity
204 | for i in range(4, 13):
205 | print('\n' + '*' * 45 + f' Evaluating on complexity {i} ' + '*' * 45 + '\n')
206 | start_time = time()
207 | retrieval_data_path = os.path.join(args.input_dir, f'{hard_neg_type}/prod_vg_hard_negs_{hard_neg_type}_complexity_{i}.csv')
208 |
209 | data = get_data(args, retrieval_data_path, config)
210 | metrics = evaluate(albef, data, i, hard_neg_type)
211 |
212 | print(f'Complexity {i} took {time() - start_time} seconds')
213 | all_metrics[i] = metrics
214 | if args.output_dir:
215 | output = os.path.join(output_dir, f'productivity_albef_{hard_neg_type}_metrics.json')
216 | print("saving results to:", output)
217 | with open(output, 'w') as f:
218 | json.dump(all_metrics, f)
219 |
220 | if __name__ == "__main__":
221 | main()
222 |
--------------------------------------------------------------------------------
/crepe_prod_eval_clip.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from time import time
4 | import json
5 |
6 | import torch
7 | import torchvision.transforms.functional as TF
8 | import clip
9 | from torch.utils.data import DataLoader
10 | import numpy as np
11 |
12 | import pandas as pd
13 |
14 | from crepe_eval_utils import BaseCsvDataset, get_one2many_rank, get_one2many_metrics, DataInfo
15 | from crepe_params import setup_args
16 |
17 | logging.basicConfig(level=logging.INFO)
18 | logger = logging.getLogger()
19 |
20 | def collator(batch):
21 | images = torch.stack([x[0] for x in batch], dim=0)
22 | texts = torch.cat([x[1] for x in batch], dim=0)
23 |
24 | return images, texts
25 |
26 | ### DATASET CONSTRUCTION
27 |
28 | class CsvDataset(BaseCsvDataset):
29 | def __init__(self, input_filename, args, processor, device):
30 | super().__init__(input_filename, args)
31 |
32 | self.processor = processor
33 | self.device = device
34 |
35 | def __getitem__(self, idx):
36 | raw_image = self.get_image_by_id(self.images[idx])
37 | if self.crop:
38 | raw_image = TF.crop(raw_image, self.ys[idx], self.xs[idx], self.heights[idx], self.widths[idx])
39 |
40 | image = self.processor(raw_image)
41 | texts = self.process_text([str(self.captions[idx])] + list(self.hard_negs[idx]))
42 | return image, texts
43 |
44 | def process_text(self, texts):
45 | proc_text = [clip.tokenize(text, truncate=True) for text in texts]
46 | return torch.cat(proc_text)
47 |
48 | def get_data(args, retrieval_data_path, processor, device):
49 | # Get CSVDataset
50 | input_filename = retrieval_data_path
51 | dataset = CsvDataset(
52 | input_filename,
53 | args,
54 | processor,
55 | device)
56 | num_samples = len(dataset)
57 | sampler = None
58 | shuffle=False
59 |
60 | dataloader = DataLoader(
61 | dataset,
62 | batch_size=16,
63 | shuffle=shuffle,
64 | num_workers=1,
65 | pin_memory=True,
66 | sampler=sampler,
67 | drop_last=False,
68 | collate_fn=collator
69 | )
70 | dataloader.num_samples = num_samples
71 | dataloader.num_batches = len(dataloader)
72 |
73 | return DataInfo(dataloader)
74 |
75 | ### EVALUATION
76 |
77 | def evaluate(model, data, complexity, negative_type, device):
78 | metrics = {}
79 |
80 | dataloader = data.dataloader
81 | # num_samples = 0
82 | # samples_per_val = dataloader.num_samples
83 |
84 | # cumulative_loss = 0.0
85 | # all_image_features, all_text_features = [], []
86 | one2many = dataloader.dataset.one2many
87 |
88 | if one2many:
89 | all_ranks = []
90 |
91 | with torch.no_grad():
92 | for i, batch in enumerate(dataloader):
93 | images, texts = batch
94 | images = images.to(device)
95 | texts = texts.to(device)
96 |
97 | if one2many:
98 | image_emb = model.encode_image(images)
99 | image_emb /= image_emb.norm(dim = -1, keepdim = True)
100 |
101 | text_emb = model.encode_text(texts)
102 | text_emb /= text_emb.norm(dim = -1, keepdim = True)
103 |
104 | set_size = text_emb.shape[0] // image_emb.shape[0]
105 | for j in range(image_emb.shape[0]):
106 | curr_image_emb = image_emb[j:j+1, :]
107 | curr_text_emb = text_emb[j*set_size:(j+1)*set_size, :]
108 | rank = get_one2many_rank(curr_image_emb, curr_text_emb)
109 | all_ranks.append(rank)
110 |
111 | print(f'Processed example {i*16}')
112 |
113 | metrics = get_one2many_metrics(np.array(all_ranks))
114 |
115 | # Alter output here
116 | logging.info(
117 | "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
118 | )
119 |
120 | return metrics
121 |
122 | def main():
123 | args = setup_args()
124 | if args.output_dir:
125 | output_dir = os.path.join(args.output_dir, 'open_ai_clip')
126 | if not os.path.exists(output_dir):
127 | os.makedirs(output_dir)
128 | # Load the model
129 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
130 | model, preprocess = clip.load(name = args.model_name, device=device)
131 | model = model.to(device)
132 | model.eval()
133 |
134 | for hard_neg_type in args.hard_neg_types:
135 | all_metrics = {}
136 | # Iterate over each complexity
137 | for i in range(4, 13):
138 | print('\n' + '*' * 45 + f' Evaluating on complexity {i} ' + '*' * 45 + '\n')
139 | start_time = time()
140 | retrieval_data_path = os.path.join(args.input_dir, f'{hard_neg_type}/prod_vg_hard_negs_{hard_neg_type}_complexity_{i}.csv')
141 |
142 | if args.model_name == "RN50" or args.model_name == "RN101":
143 | model_save_name = args.model_name
144 | elif args.model_name == "ViT-B/32":
145 | model_save_name = 'vit_b32'
146 | elif args.model_name == "ViT-B/16":
147 | model_save_name = 'vit_b16'
148 | elif args.model_name == "ViT-L/14":
149 | model_save_name = 'vit_l14'
150 |
151 | data = get_data(args, retrieval_data_path, preprocess, device)
152 | metrics = evaluate(model, data, i, hard_neg_type, device)
153 |
154 | print(f'Complexity {i} took {time() - start_time} seconds')
155 | all_metrics[i] = metrics
156 |
157 | if args.output_dir:
158 | output = os.path.join(output_dir, f'productivity_clip_{model_save_name}_{hard_neg_type}_metrics.json')
159 | print("saving results to:", output)
160 | with open(output, 'w') as f:
161 | json.dump(all_metrics, f)
162 |
163 | if __name__ == '__main__':
164 | main()
165 |
--------------------------------------------------------------------------------
/crepe_prod_eval_cyclip.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import ast
8 | import argparse
9 | import logging
10 | import os
11 | from PIL import Image, ImageFile
12 | from dataclasses import dataclass
13 | from time import time
14 | import json
15 |
16 | import torch
17 | import torchvision.transforms.functional as TF
18 | from pkgs.openai.clip import load
19 | from torch import nn
20 | from torch.utils.data import DataLoader, Dataset
21 | import numpy as np
22 |
23 | import pandas as pd
24 |
25 | from crepe_eval_utils import BaseCsvDataset, get_one2many_rank, get_one2many_metrics, DataInfo
26 | from crepe_params import setup_args
27 |
28 | logging.basicConfig(level=logging.INFO)
29 | logger = logging.getLogger()
30 |
31 | def collator(batch):
32 | texts = []
33 |
34 | images = torch.stack([x[0] for x in batch], dim=0)
35 | texts = torch.cat([x[1] for x in batch], dim=0)
36 | attention_masks = torch.cat([x[2] for x in batch], dim=0)
37 |
38 | return images, texts, attention_masks
39 |
40 | ### DATASET CONSTRUCTION
41 |
42 | class CsvDataset(BaseCsvDataset):
43 | def __init__(self, input_filename, args, processor):
44 | super().__init__(input_filename, args)
45 |
46 | self.processor = processor
47 |
48 | def __getitem__(self, idx):
49 | raw_image = self.get_image_by_id(self.images[idx])
50 | if self.crop:
51 | raw_image = TF.crop(raw_image, self.ys[idx], self.xs[idx], self.heights[idx], self.widths[idx])
52 |
53 |
54 | image = torch.tensor(self.processor.process_image(raw_image))
55 | return_dict = self.processor.process_text([str(self.captions[idx])] + list(self.hard_negs[idx]))
56 | input_ids = return_dict['input_ids']
57 | attention_mask = return_dict['attention_mask']
58 |
59 | return image, input_ids, attention_mask
60 |
61 | def get_data(args, retrieval_data_path, processor):
62 | # Get CSVDataset
63 | input_filename = retrieval_data_path
64 | dataset = CsvDataset(
65 | input_filename,
66 | args,
67 | processor)
68 | num_samples = len(dataset)
69 | sampler = None
70 | shuffle=False
71 |
72 | dataloader = DataLoader(
73 | dataset,
74 | batch_size=16,
75 | shuffle=shuffle,
76 | num_workers=1,
77 | pin_memory=True,
78 | sampler=sampler,
79 | drop_last=False,
80 | collate_fn=collator
81 | )
82 | dataloader.num_samples = num_samples
83 | dataloader.num_batches = len(dataloader)
84 |
85 | return DataInfo(dataloader)
86 |
87 | ### EVALUATION
88 |
89 | def evaluate(model, data, complexity, negative_type):
90 | metrics = {}
91 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
92 |
93 | dataloader = data.dataloader
94 | # num_samples = 0
95 | # samples_per_val = dataloader.num_samples
96 |
97 | # cumulative_loss = 0.0
98 | # all_image_features, all_text_features = [], []
99 | one2many = dataloader.dataset.one2many
100 |
101 | if one2many:
102 | all_ranks = []
103 |
104 | with torch.no_grad():
105 | for i, batch in enumerate(dataloader):
106 | images, texts, attention_mask = batch
107 | images = images.to(device=device, non_blocking=True)
108 | texts = texts.to(device=device, non_blocking=True)
109 | attention_mask = attention_mask.to(device=device, non_blocking=True)
110 |
111 | if one2many:
112 | image_emb = model.get_image_features(images)
113 | image_emb /= image_emb.norm(dim = -1, keepdim = True)
114 |
115 | text_emb = model.get_text_features(input_ids = texts, attention_mask = attention_mask)
116 | text_emb /= text_emb.norm(dim = -1, keepdim = True)
117 |
118 | set_size = text_emb.shape[0] // image_emb.shape[0]
119 | for j in range(image_emb.shape[0]):
120 | curr_image_emb = image_emb[j:j+1, :]
121 | curr_text_emb = text_emb[j*set_size:(j+1)*set_size, :]
122 | rank = get_one2many_rank(curr_image_emb, curr_text_emb)
123 | all_ranks.append(rank)
124 |
125 | print(f'Processed example {i*16}')
126 |
127 | metrics = get_one2many_metrics(np.array(all_ranks))
128 |
129 | # Alter output here
130 | logging.info(
131 | "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
132 | )
133 |
134 | return metrics
135 |
136 | def main():
137 | args = setup_args()
138 | if args.output_dir:
139 | output_dir = os.path.join(args.output_dir, 'cyclip')
140 | if not os.path.exists(output_dir):
141 | os.makedirs(output_dir)
142 | # Load the model
143 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
144 | model, processor = load(name = args.model_name, pretrained = args.pretrained)
145 | checkpoint = torch.load('best.pt', map_location=device)
146 | state_dict = checkpoint['state_dict']
147 | if(next(iter(state_dict.items()))[0].startswith("module")):
148 | state_dict = {key[len("module."):]: value for key, value in state_dict.items()}
149 | model.load_state_dict(state_dict)
150 | model = model.to(device)
151 | model.eval()
152 |
153 | for hard_neg_type in args.hard_neg_types:
154 | all_metrics = {}
155 | # Iterate over each complexity
156 | for i in range(4, 13):
157 | print('\n' + '*' * 45 + f' Evaluating on complexity {i} ' + '*' * 45 + '\n')
158 | start_time = time()
159 | retrieval_data_path = os.path.join(args.input_dir, f'{hard_neg_type}/prod_vg_hard_negs_{hard_neg_type}_complexity_{i}.csv')
160 |
161 | data = get_data(args, retrieval_data_path, processor)
162 | metrics = evaluate(model, data, i, hard_neg_type)
163 |
164 | print(f'Complexity {i} took {time() - start_time} seconds')
165 |
166 | all_metrics[i] = metrics
167 |
168 | if args.output_dir:
169 | output = os.path.join(output_dir, f'productivity_cyclip_{args.model_name}_{hard_neg_type}_metrics.json')
170 | print("saving results to:", output)
171 | with open(output, 'w') as f:
172 | json.dump(all_metrics, f)
173 |
174 | if __name__ == "__main__":
175 | main()
176 |
--------------------------------------------------------------------------------
/crepe_prod_eval_flava.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import ast
8 |
9 | import logging
10 | import os
11 | from PIL import Image
12 | from dataclasses import dataclass
13 | from time import time
14 | import json
15 |
16 | import torch
17 | from torchmultimodal.transforms.flava_transform import FLAVAImageTransform
18 | from torch import nn
19 | from torch.utils.data import DataLoader, Dataset
20 | from torchmultimodal.models.flava.model import flava_model
21 | from transformers import BertTokenizer
22 | import torchvision.transforms.functional as TF
23 | import numpy as np
24 |
25 | import pandas as pd
26 |
27 | from crepe_eval_utils import BaseCsvDataset, get_one2many_rank, get_one2many_metrics, DataInfo
28 | from crepe_params import setup_args
29 |
30 | logging.basicConfig(level=logging.INFO)
31 | logger = logging.getLogger()
32 |
33 | max_text_length = 512
34 | TEXT_DEFAULT_TOKENIZER = "bert-base-uncased"
35 | text_tokenizer = BertTokenizer.from_pretrained(TEXT_DEFAULT_TOKENIZER)
36 |
37 | def collator(batch):
38 | texts = []
39 | images = torch.stack([x[0]["image"] for x in batch], dim=0)
40 | texts = torch.cat([x[1] for x in batch], dim=0)
41 |
42 | return images, texts
43 |
44 | ### DATASET CONSTRUCTION
45 |
46 | def default_text_transform(texts):
47 | # Expect a list of texts
48 | tokenized_texts = []
49 | start_time = time()
50 | for text in texts:
51 | tokenized = text_tokenizer(text, padding="max_length",
52 | max_length=max_text_length, truncation=True, return_tensors='pt')
53 | tokenized_texts.append(torch.LongTensor(tokenized['input_ids']))
54 | tokenized_texts = torch.cat(tokenized_texts, dim=0)
55 |
56 | return tokenized_texts
57 |
58 | class CsvDataset(BaseCsvDataset):
59 | def __init__(self, input_filename, args):
60 | super().__init__(input_filename, args)
61 |
62 | self.image_transform = FLAVAImageTransform(is_train=False)
63 | self.text_transform = default_text_transform
64 |
65 | def __getitem__(self, idx):
66 | raw_image = self.get_image_by_id(self.images[idx])
67 | if self.crop:
68 | raw_image = TF.crop(raw_image, self.ys[idx], self.xs[idx], self.heights[idx], self.widths[idx])
69 | image = self.image_transform(raw_image)
70 | if self.one2many:
71 | texts = self.text_transform([str(self.captions[idx])] + list(self.hard_negs[idx]))
72 | else:
73 | texts = self.text_transform([str(self.captions[idx])])[0]
74 | return image, texts
75 |
76 | def get_data(args, retrieval_data_path):
77 | # Get CSVDataset
78 | input_filename = retrieval_data_path
79 | dataset = CsvDataset(
80 | input_filename,
81 | args)
82 | num_samples = len(dataset)
83 | sampler = None
84 | shuffle=False
85 |
86 | dataloader = DataLoader(
87 | dataset,
88 | batch_size=8,
89 | shuffle=shuffle,
90 | num_workers=1,
91 | pin_memory=True,
92 | sampler=sampler,
93 | drop_last=False,
94 | collate_fn=collator
95 | )
96 | dataloader.num_samples = num_samples
97 | dataloader.num_batches = len(dataloader)
98 |
99 | return DataInfo(dataloader)
100 |
101 | ### EVALUATION
102 |
103 | def evaluate(model, data, complexity, negative_type):
104 | metrics = {}
105 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
106 |
107 | dataloader = data.dataloader
108 | # num_samples = 0
109 | # samples_per_val = dataloader.num_samples
110 |
111 | # cumulative_loss = 0.0
112 | # all_image_features, all_text_features = [], []
113 | one2many = dataloader.dataset.one2many
114 |     assert one2many, "Not one2many?"
115 |
116 | if one2many:
117 | all_ranks = []
118 |
119 | with torch.no_grad():
120 | for i, batch in enumerate(dataloader):
121 | images, texts = batch
122 | images = images.to(device=device, non_blocking=True)
123 | texts = texts.to(device=device, non_blocking=True)
124 |
125 | if one2many:
126 | _, image_emb = model.encode_image(images, projection=True)
127 | image_emb = nn.functional.normalize(image_emb, dim=-1)
128 | _, text_emb = model.encode_text(texts, projection=True)
129 | text_emb = nn.functional.normalize(text_emb)
130 |
131 | set_size = text_emb.shape[0] // image_emb.shape[0]
132 | for j in range(image_emb.shape[0]):
133 | curr_image_emb = image_emb[j:j+1, :]
134 | curr_text_emb = text_emb[j*set_size:(j+1)*set_size, :]
135 | rank = get_one2many_rank(curr_image_emb, curr_text_emb)
136 | all_ranks.append(rank)
137 |
138 | # print(f'Processed example {i*8}')
139 |
140 | metrics = get_one2many_metrics(np.array(all_ranks))
141 |
142 | # Alter output here
143 | logging.info(
144 | "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
145 | )
146 |
147 | return metrics
148 |
149 | def main():
150 | args = setup_args()
151 | if args.output_dir:
152 | output_dir = os.path.join(args.output_dir, 'flava')
153 | if not os.path.exists(output_dir):
154 | os.makedirs(output_dir)
155 | # Load the model
156 | flava = flava_model(pretrained=True)
157 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
158 | logger.info(f"Using device: {device}")
159 | flava = flava.to(device)
160 | flava.eval()
161 |
162 | for hard_neg_type in args.hard_neg_types:
163 | all_metrics = {}
164 | # Iterate over each complexity
165 | for i in range(4, 13):
166 | print('\n' + '*' * 45 + f' Evaluating on complexity {i} ' + '*' * 45 + '\n')
167 | start_time = time()
168 | retrieval_data_path = os.path.join(args.input_dir, f'{hard_neg_type}/prod_vg_hard_negs_{hard_neg_type}_complexity_{i}.csv')
169 |
170 | data = get_data(args, retrieval_data_path)
171 | metrics = evaluate(flava, data, i, hard_neg_type)
172 |
173 | print(f'Complexity {i} took {time() - start_time} seconds')
174 | all_metrics[i] = metrics
175 |
176 | if args.output_dir:
177 | output = os.path.join(output_dir, f'productivity_flava_{hard_neg_type}_metrics.json')
178 | print("saving results to:", output)
179 | with open(output, 'w') as f:
180 | json.dump(all_metrics, f)
181 |
182 | if __name__ == "__main__":
183 | main()
184 |
--------------------------------------------------------------------------------
/data/prod_hard_negatives.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RAIVNLab/CREPE/1fa81c425f442396fe304f170c8eb6dc0747c814/data/prod_hard_negatives.zip
--------------------------------------------------------------------------------
/data/syst_hard_negatives.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RAIVNLab/CREPE/1fa81c425f442396fe304f170c8eb6dc0747c814/data/syst_hard_negatives.zip
--------------------------------------------------------------------------------
/open_clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .factory import list_models, create_model, create_model_and_transforms, add_model_config
2 | from .loss import ClipLoss
3 | from .model import CLIP, CLIPTextCfg, CLIPVisionCfg, convert_weights_to_fp16, trace_model
4 | from .openai import load_openai_model, list_openai_models
5 | from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\
6 | get_pretrained_url, download_pretrained
7 | from .tokenizer import SimpleTokenizer, tokenize
8 | from .transform import image_transform
9 |
--------------------------------------------------------------------------------
/open_clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RAIVNLab/CREPE/1fa81c425f442396fe304f170c8eb6dc0747c814/open_clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/open_clip/factory.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import pathlib
5 | import re
6 | from copy import deepcopy
7 | from pathlib import Path
8 |
9 | import torch
10 |
11 | from .model import CLIP, convert_weights_to_fp16
12 | from .openai import load_openai_model
13 | from .pretrained import get_pretrained_url, download_pretrained
14 | from .transform import image_transform
15 |
16 |
17 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
18 | _MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
19 |
20 |
21 | def _natural_key(string_):
22 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
23 |
24 |
25 | def _rescan_model_configs():
26 | global _MODEL_CONFIGS
27 |
28 | config_ext = ('.json',)
29 | config_files = []
30 | for config_path in _MODEL_CONFIG_PATHS:
31 | if config_path.is_file() and config_path.suffix in config_ext:
32 | config_files.append(config_path)
33 | elif config_path.is_dir():
34 | for ext in config_ext:
35 | config_files.extend(config_path.glob(f'*{ext}'))
36 |
37 | for cf in config_files:
38 | with open(cf, 'r') as f:
39 | model_cfg = json.load(f)
40 | if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
41 | _MODEL_CONFIGS[cf.stem] = model_cfg
42 |
43 | _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
44 |
45 |
46 | _rescan_model_configs() # initial populate of model config registry
47 |
48 |
49 | def load_state_dict(checkpoint_path: str, map_location='cpu'):
50 | checkpoint = torch.load(checkpoint_path, map_location=map_location)
51 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
52 | state_dict = checkpoint['state_dict']
53 | else:
54 | state_dict = checkpoint
55 | if next(iter(state_dict.items()))[0].startswith('module'):
56 | state_dict = {k[7:]: v for k, v in state_dict.items()}
57 | return state_dict
58 |
59 |
60 | def create_model(
61 | model_name: str,
62 | pretrained: str = '',
63 | precision: str = 'fp32',
64 | device: torch.device = torch.device('cpu'),
65 | jit: bool = False,
66 | force_quick_gelu: bool = False,
67 | pretrained_image: bool = False,
68 | ):
69 | model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
70 |
71 | if pretrained.lower() == 'openai':
72 | logging.info(f'Loading pretrained {model_name} from OpenAI.')
73 | model = load_openai_model(model_name, device=device, jit=jit)
74 | # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
75 | if precision == "amp" or precision == "fp32":
76 | model = model.float()
77 | else:
78 | if model_name in _MODEL_CONFIGS:
79 | logging.info(f'Loading {model_name} model config.')
80 | model_cfg = deepcopy(_MODEL_CONFIGS[model_name])
81 | else:
82 | logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
83 | raise RuntimeError(f'Model config for {model_name} not found.')
84 |
85 | if force_quick_gelu:
86 | # override for use of QuickGELU on non-OpenAI transformer models
87 | model_cfg["quick_gelu"] = True
88 |
89 | if pretrained_image:
90 | if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
91 | # pretrained weight loading for timm models set via vision_cfg
92 | model_cfg['vision_cfg']['timm_model_pretrained'] = True
93 | else:
94 | assert False, 'pretrained image towers currently only supported for timm models'
95 |
96 | model = CLIP(**model_cfg)
97 |
98 | if pretrained:
99 | checkpoint_path = ''
100 | url = get_pretrained_url(model_name, pretrained)
101 | if url:
102 | checkpoint_path = download_pretrained(url)
103 | elif os.path.exists(pretrained):
104 | checkpoint_path = pretrained
105 |
106 | if checkpoint_path:
107 | logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
108 | model.load_state_dict(load_state_dict(checkpoint_path))
109 | else:
110 | logging.warning(f'Pretrained weights ({pretrained}) not found for model {model_name}.')
111 | raise RuntimeError(f'Pretrained weights ({pretrained}) not found for model {model_name}.')
112 |
113 | model.to(device=device)
114 | if precision == "fp16":
115 | assert device.type != 'cpu'
116 | convert_weights_to_fp16(model)
117 |
118 | if jit:
119 | model = torch.jit.script(model)
120 |
121 | return model
122 |
123 |
124 | def create_model_and_transforms(
125 | model_name: str,
126 | pretrained: str = '',
127 | precision: str = 'fp32',
128 | device: torch.device = torch.device('cpu'),
129 | jit: bool = False,
130 | force_quick_gelu: bool = False,
131 | pretrained_image: bool = False,
132 | ):
133 | model = create_model(
134 | model_name, pretrained, precision, device, jit,
135 | force_quick_gelu=force_quick_gelu,
136 | pretrained_image=pretrained_image)
137 | preprocess_train = image_transform(model.visual.image_size, is_train=True)
138 | preprocess_val = image_transform(model.visual.image_size, is_train=False)
139 | return model, preprocess_train, preprocess_val
140 |
141 |
142 | def list_models():
143 | """ enumerate available model architectures based on config files """
144 | return list(_MODEL_CONFIGS.keys())
145 |
146 |
147 | def add_model_config(path):
148 | """ add model config path or file and update registry """
149 | if not isinstance(path, Path):
150 | path = Path(path)
151 | _MODEL_CONFIG_PATHS.append(path)
152 | _rescan_model_configs()
153 |
--------------------------------------------------------------------------------
/open_clip/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed.nn
3 | from torch import distributed as dist, nn as nn
4 | from torch.nn import functional as F
5 |
6 | try:
7 | import horovod.torch as hvd
8 | except ImportError:
9 | hvd = None
10 |
11 |
12 | def gather_features(
13 | image_features,
14 | text_features,
15 | local_loss=False,
16 | gather_with_grad=False,
17 | rank=0,
18 | world_size=1,
19 | use_horovod=False
20 | ):
21 | if use_horovod:
22 | assert hvd is not None, 'Please install horovod'
23 | if gather_with_grad:
24 | all_image_features = hvd.allgather(image_features)
25 | all_text_features = hvd.allgather(text_features)
26 | else:
27 | with torch.no_grad():
28 | all_image_features = hvd.allgather(image_features)
29 | all_text_features = hvd.allgather(text_features)
30 | if not local_loss:
31 | # ensure grads for local rank when all_* features don't have a gradient
32 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
33 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
34 | gathered_image_features[rank] = image_features
35 | gathered_text_features[rank] = text_features
36 | all_image_features = torch.cat(gathered_image_features, dim=0)
37 | all_text_features = torch.cat(gathered_text_features, dim=0)
38 | else:
39 | # We gather tensors from all gpus
40 | if gather_with_grad:
41 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
42 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
43 | else:
44 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
45 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
46 | dist.all_gather(gathered_image_features, image_features)
47 | dist.all_gather(gathered_text_features, text_features)
48 | if not local_loss:
49 | # ensure grads for local rank when all_* features don't have a gradient
50 | gathered_image_features[rank] = image_features
51 | gathered_text_features[rank] = text_features
52 | all_image_features = torch.cat(gathered_image_features, dim=0)
53 | all_text_features = torch.cat(gathered_text_features, dim=0)
54 |
55 | return all_image_features, all_text_features
56 |
57 |
58 | class ClipLoss(nn.Module):
59 |
60 | def __init__(
61 | self,
62 | local_loss=False,
63 | gather_with_grad=False,
64 | cache_labels=False,
65 | rank=0,
66 | world_size=1,
67 | use_horovod=False,
68 | ):
69 | super().__init__()
70 | self.local_loss = local_loss
71 | self.gather_with_grad = gather_with_grad
72 | self.cache_labels = cache_labels
73 | self.rank = rank
74 | self.world_size = world_size
75 | self.use_horovod = use_horovod
76 |
77 | # cache state
78 | self.prev_num_logits = 0
79 | self.labels = {}
80 |
81 | def forward(self, image_features, text_features, logit_scale):
82 | device = image_features.device
83 | if self.world_size > 1:
84 | all_image_features, all_text_features = gather_features(
85 | image_features, text_features,
86 | self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
87 |
88 | if self.local_loss:
89 | logits_per_image = logit_scale * image_features @ all_text_features.T
90 | logits_per_text = logit_scale * text_features @ all_image_features.T
91 | else:
92 | logits_per_image = logit_scale * all_image_features @ all_text_features.T
93 | logits_per_text = logits_per_image.T
94 | else:
95 | logits_per_image = logit_scale * image_features @ text_features.T
96 | logits_per_text = logit_scale * text_features @ image_features.T
97 |
98 |         # calculate ground-truth and cache if enabled
99 | num_logits = logits_per_image.shape[0]
100 | if self.prev_num_logits != num_logits or device not in self.labels:
101 | labels = torch.arange(num_logits, device=device, dtype=torch.long)
102 | if self.world_size > 1 and self.local_loss:
103 | labels = labels + num_logits * self.rank
104 | if self.cache_labels:
105 | self.labels[device] = labels
106 | self.prev_num_logits = num_logits
107 | else:
108 | labels = self.labels[device]
109 |
110 | total_loss = (
111 | F.cross_entropy(logits_per_image, labels) +
112 | F.cross_entropy(logits_per_text, labels)
113 | ) / 2
114 | return total_loss
115 |
--------------------------------------------------------------------------------
/open_clip/model.py:
--------------------------------------------------------------------------------
1 | """ CLIP Model
2 |
3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 |
6 | from collections import OrderedDict
7 | from dataclasses import dataclass
8 | from typing import Tuple, Union, Callable, Optional
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn.functional as F
13 | from torch import nn
14 | from torch.utils.checkpoint import checkpoint
15 |
16 | from .timm_model import TimmModel
17 | from .utils import freeze_batch_norm_2d
18 |
19 |
20 | class Bottleneck(nn.Module):
21 | expansion = 4
22 |
23 | def __init__(self, inplanes, planes, stride=1):
24 | super().__init__()
25 |
26 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
27 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
28 | self.bn1 = nn.BatchNorm2d(planes)
29 | self.relu1 = nn.ReLU(inplace=True)
30 |
31 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
32 | self.bn2 = nn.BatchNorm2d(planes)
33 | self.relu2 = nn.ReLU(inplace=True)
34 |
35 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
36 |
37 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
38 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
39 | self.relu3 = nn.ReLU(inplace=True)
40 |
41 | self.downsample = None
42 | self.stride = stride
43 |
44 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
45 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
46 | self.downsample = nn.Sequential(OrderedDict([
47 | ("-1", nn.AvgPool2d(stride)),
48 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
49 | ("1", nn.BatchNorm2d(planes * self.expansion))
50 | ]))
51 |
52 | def forward(self, x: torch.Tensor):
53 | identity = x
54 |
55 | out = self.relu1(self.bn1(self.conv1(x)))
56 | out = self.relu2(self.bn2(self.conv2(out)))
57 | out = self.avgpool(out)
58 | out = self.bn3(self.conv3(out))
59 |
60 | if self.downsample is not None:
61 | identity = self.downsample(x)
62 |
63 | out += identity
64 | out = self.relu3(out)
65 | return out
66 |
67 |
68 | class AttentionPool2d(nn.Module):
69 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
70 | super().__init__()
71 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
72 | self.k_proj = nn.Linear(embed_dim, embed_dim)
73 | self.q_proj = nn.Linear(embed_dim, embed_dim)
74 | self.v_proj = nn.Linear(embed_dim, embed_dim)
75 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
76 | self.num_heads = num_heads
77 |
78 | def forward(self, x):
79 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
80 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
81 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
82 | x, _ = F.multi_head_attention_forward(
83 | query=x, key=x, value=x,
84 | embed_dim_to_check=x.shape[-1],
85 | num_heads=self.num_heads,
86 | q_proj_weight=self.q_proj.weight,
87 | k_proj_weight=self.k_proj.weight,
88 | v_proj_weight=self.v_proj.weight,
89 | in_proj_weight=None,
90 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
91 | bias_k=None,
92 | bias_v=None,
93 | add_zero_attn=False,
94 | dropout_p=0,
95 | out_proj_weight=self.c_proj.weight,
96 | out_proj_bias=self.c_proj.bias,
97 | use_separate_proj_weight=True,
98 | training=self.training,
99 | need_weights=False
100 | )
101 |
102 | return x[0]
103 |
104 |
105 | class ModifiedResNet(nn.Module):
106 | """
107 | A ResNet class that is similar to torchvision's but contains the following changes:
108 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
109 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
110 | - The final pooling layer is a QKV attention instead of an average pool
111 | """
112 |
113 | def __init__(self, layers, output_dim, heads, image_size=224, width=64):
114 | super().__init__()
115 | self.output_dim = output_dim
116 | self.image_size = image_size
117 |
118 | # the 3-layer stem
119 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
120 | self.bn1 = nn.BatchNorm2d(width // 2)
121 | self.relu1 = nn.ReLU(inplace=True)
122 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
123 | self.bn2 = nn.BatchNorm2d(width // 2)
124 | self.relu2 = nn.ReLU(inplace=True)
125 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
126 | self.bn3 = nn.BatchNorm2d(width)
127 | self.relu3 = nn.ReLU(inplace=True)
128 | self.avgpool = nn.AvgPool2d(2)
129 |
130 | # residual layers
131 | self._inplanes = width # this is a *mutable* variable used during construction
132 | self.layer1 = self._make_layer(width, layers[0])
133 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
134 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
135 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
136 |
137 | embed_dim = width * 32 # the ResNet feature dimension
138 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
139 |
140 | self.init_parameters()
141 |
142 | def _make_layer(self, planes, blocks, stride=1):
143 | layers = [Bottleneck(self._inplanes, planes, stride)]
144 |
145 | self._inplanes = planes * Bottleneck.expansion
146 | for _ in range(1, blocks):
147 | layers.append(Bottleneck(self._inplanes, planes))
148 |
149 | return nn.Sequential(*layers)
150 |
151 | def init_parameters(self):
152 | if self.attnpool is not None:
153 | std = self.attnpool.c_proj.in_features ** -0.5
154 | nn.init.normal_(self.attnpool.q_proj.weight, std=std)
155 | nn.init.normal_(self.attnpool.k_proj.weight, std=std)
156 | nn.init.normal_(self.attnpool.v_proj.weight, std=std)
157 | nn.init.normal_(self.attnpool.c_proj.weight, std=std)
158 |
159 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
160 | for name, param in resnet_block.named_parameters():
161 | if name.endswith("bn3.weight"):
162 | nn.init.zeros_(param)
163 |
164 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
165 | assert unlocked_groups == 0, 'partial locking not currently supported for this model'
166 | for param in self.parameters():
167 | param.requires_grad = False
168 | if freeze_bn_stats:
169 | freeze_batch_norm_2d(self)
170 |
171 | @torch.jit.ignore
172 | def set_grad_checkpointing(self, enable=True):
173 | # FIXME support for non-transformer
174 | pass
175 |
176 | def stem(self, x):
177 | x = self.relu1(self.bn1(self.conv1(x)))
178 | x = self.relu2(self.bn2(self.conv2(x)))
179 | x = self.relu3(self.bn3(self.conv3(x)))
180 | x = self.avgpool(x)
181 | return x
182 |
183 | def forward(self, x):
184 | x = self.stem(x)
185 | x = self.layer1(x)
186 | x = self.layer2(x)
187 | x = self.layer3(x)
188 | x = self.layer4(x)
189 | x = self.attnpool(x)
190 |
191 | return x
192 |
193 |
194 | class LayerNorm(nn.LayerNorm):
195 | """Subclass torch's LayerNorm to handle fp16."""
196 |
197 | def forward(self, x: torch.Tensor):
198 | orig_type = x.dtype
199 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
200 | return x.to(orig_type)
201 |
202 |
203 | class QuickGELU(nn.Module):
204 | # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
205 | def forward(self, x: torch.Tensor):
206 | return x * torch.sigmoid(1.702 * x)
207 |
208 |
209 | class ResidualAttentionBlock(nn.Module):
210 | def __init__(self, d_model: int, n_head: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU):
211 | super().__init__()
212 |
213 | self.attn = nn.MultiheadAttention(d_model, n_head)
214 | self.ln_1 = LayerNorm(d_model)
215 | mlp_width = int(d_model * mlp_ratio)
216 | self.mlp = nn.Sequential(OrderedDict([
217 | ("c_fc", nn.Linear(d_model, mlp_width)),
218 | ("gelu", act_layer()),
219 | ("c_proj", nn.Linear(mlp_width, d_model))
220 | ]))
221 | self.ln_2 = LayerNorm(d_model)
222 |
223 | def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
224 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
225 |
226 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
227 | x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
228 | x = x + self.mlp(self.ln_2(x))
229 | return x
230 |
231 |
232 | class Transformer(nn.Module):
233 | def __init__(self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU):
234 | super().__init__()
235 | self.width = width
236 | self.layers = layers
237 | self.grad_checkpointing = False
238 |
239 | self.resblocks = nn.ModuleList([
240 | ResidualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer)
241 | for _ in range(layers)
242 | ])
243 |
244 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
245 | for r in self.resblocks:
246 | if self.grad_checkpointing and not torch.jit.is_scripting():
247 | x = checkpoint(r, x, attn_mask)
248 | else:
249 | x = r(x, attn_mask=attn_mask)
250 | return x
251 |
252 |
253 | class VisualTransformer(nn.Module):
254 | def __init__(
255 | self, image_size: int, patch_size: int, width: int, layers: int, heads: int, mlp_ratio: float,
256 | output_dim: int, act_layer: Callable = nn.GELU):
257 | super().__init__()
258 | self.image_size = image_size
259 | self.output_dim = output_dim
260 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
261 |
262 | scale = width ** -0.5
263 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
264 | self.positional_embedding = nn.Parameter(scale * torch.randn((image_size // patch_size) ** 2 + 1, width))
265 | self.ln_pre = LayerNorm(width)
266 |
267 | self.transformer = Transformer(width, layers, heads, mlp_ratio, act_layer=act_layer)
268 |
269 | self.ln_post = LayerNorm(width)
270 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
271 |
272 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
273 | assert unlocked_groups == 0, 'partial locking not currently supported for this model'
274 | for param in self.parameters():
275 | param.requires_grad = False
276 |
277 | @torch.jit.ignore
278 | def set_grad_checkpointing(self, enable=True):
279 | self.transformer.grad_checkpointing = enable
280 |
281 | def forward(self, x: torch.Tensor):
282 | x = self.conv1(x) # shape = [*, width, grid, grid]
283 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
284 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
285 | x = torch.cat(
286 | [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
287 | x], dim=1) # shape = [*, grid ** 2 + 1, width]
288 | x = x + self.positional_embedding.to(x.dtype)
289 | x = self.ln_pre(x)
290 |
291 | x = x.permute(1, 0, 2) # NLD -> LND
292 | x = self.transformer(x)
293 | x = x.permute(1, 0, 2) # LND -> NLD
294 |
295 | x = self.ln_post(x[:, 0, :])
296 |
297 | if self.proj is not None:
298 | x = x @ self.proj
299 |
300 | return x
301 |
302 |
303 | @dataclass
304 | class CLIPVisionCfg:
305 | layers: Union[Tuple[int, int, int, int], int] = 12
306 | width: int = 768
307 | head_width: int = 64
308 | mlp_ratio: float = 4.0
309 | patch_size: int = 16
310 | image_size: Union[Tuple[int, int], int] = 224
311 | timm_model_name: str = None # a valid model name overrides layers, width, patch_size
312 | timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
313 | timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
314 | timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
315 |
316 |
317 | @dataclass
318 | class CLIPTextCfg:
319 | context_length: int = 77
320 | vocab_size: int = 49408
321 | width: int = 512
322 | heads: int = 8
323 | layers: int = 12
324 |
325 |
326 | class CLIP(nn.Module):
327 | def __init__(
328 | self,
329 | embed_dim: int,
330 | vision_cfg: CLIPVisionCfg,
331 | text_cfg: CLIPTextCfg,
332 | quick_gelu: bool = False,
333 | ):
334 | super().__init__()
335 | if isinstance(vision_cfg, dict):
336 | vision_cfg = CLIPVisionCfg(**vision_cfg)
337 | if isinstance(text_cfg, dict):
338 | text_cfg = CLIPTextCfg(**text_cfg)
339 |
340 | self.context_length = text_cfg.context_length
341 |
342 | # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
343 | # memory efficient in recent PyTorch releases (>= 1.10).
344 | # NOTE: timm models always use native GELU regardless of quick_gelu flag.
345 | act_layer = QuickGELU if quick_gelu else nn.GELU
346 |
347 | if vision_cfg.timm_model_name:
348 | self.visual = TimmModel(
349 | vision_cfg.timm_model_name,
350 | pretrained=vision_cfg.timm_model_pretrained,
351 | pool=vision_cfg.timm_pool,
352 | proj=vision_cfg.timm_proj,
353 | embed_dim=embed_dim,
354 | image_size=vision_cfg.image_size
355 | )
356 | act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
357 | elif isinstance(vision_cfg.layers, (tuple, list)):
358 | vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
359 | self.visual = ModifiedResNet(
360 | layers=vision_cfg.layers,
361 | output_dim=embed_dim,
362 | heads=vision_heads,
363 | image_size=vision_cfg.image_size,
364 | width=vision_cfg.width
365 | )
366 | else:
367 | vision_heads = vision_cfg.width // vision_cfg.head_width
368 | self.visual = VisualTransformer(
369 | image_size=vision_cfg.image_size,
370 | patch_size=vision_cfg.patch_size,
371 | width=vision_cfg.width,
372 | layers=vision_cfg.layers,
373 | heads=vision_heads,
374 | mlp_ratio=vision_cfg.mlp_ratio,
375 | output_dim=embed_dim,
376 | act_layer=act_layer,
377 | )
378 |
379 | self.transformer = Transformer(
380 | width=text_cfg.width,
381 | layers=text_cfg.layers,
382 | heads=text_cfg.heads,
383 | act_layer=act_layer,
384 | )
385 |
386 | self.vocab_size = text_cfg.vocab_size
387 | self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
388 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, text_cfg.width))
389 | self.ln_final = LayerNorm(text_cfg.width)
390 |
391 | self.text_projection = nn.Parameter(torch.empty(text_cfg.width, embed_dim))
392 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
393 | self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
394 |
395 | self.init_parameters()
396 |
397 | def init_parameters(self):
398 | nn.init.normal_(self.token_embedding.weight, std=0.02)
399 | nn.init.normal_(self.positional_embedding, std=0.01)
400 | nn.init.constant_(self.logit_scale, np.log(1 / 0.07))
401 |
402 | if hasattr(self.visual, 'init_parameters'):
403 | self.visual.init_parameters()
404 |
405 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
406 | attn_std = self.transformer.width ** -0.5
407 | fc_std = (2 * self.transformer.width) ** -0.5
408 | for block in self.transformer.resblocks:
409 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
410 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
411 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
412 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
413 |
414 | if self.text_projection is not None:
415 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
416 |
417 | def build_attention_mask(self):
418 | # create the causal attention mask for the text transformer
419 | # pytorch uses additive attention mask; fill with -inf
420 | mask = torch.empty(self.context_length, self.context_length)
421 | mask.fill_(float("-inf"))
422 | mask.triu_(1)  # zero out the diagonal and lower triangle; -inf remains above the diagonal
423 | return mask
424 |
425 | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
426 | # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
427 | self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
428 |
429 | @torch.jit.ignore
430 | def set_grad_checkpointing(self, enable=True):
431 | self.visual.set_grad_checkpointing(enable)
432 | self.transformer.grad_checkpointing = enable
433 |
434 | def encode_image(self, image):
435 | return self.visual(image)
436 |
437 | def encode_text(self, text):
438 | # print('text before embedding:', text)
439 | x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
440 | # print('text after embedding:', x)
441 | x = x + self.positional_embedding
442 | x = x.permute(1, 0, 2) # NLD -> LND
443 | x = self.transformer(x, attn_mask=self.attn_mask)
444 | x = x.permute(1, 0, 2) # LND -> NLD
445 | x = self.ln_final(x)
446 |
447 | # x.shape = [batch_size, n_ctx, transformer.width]
448 | # take features from the eot embedding (eot_token is the highest number in each sequence)
449 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
450 |
451 | return x
452 |
453 | def forward(self, image, text):
454 | if image is None:
455 | return self.encode_text(text)
456 | elif text is None:
457 | return self.encode_image(image)
458 | image_features = self.encode_image(image)
459 | image_features = F.normalize(image_features, dim=-1)
460 |
461 | text_features = self.encode_text(text)
462 | text_features = F.normalize(text_features, dim=-1)
463 |
464 | return image_features, text_features, self.logit_scale.exp()
465 |
466 |
467 | def convert_weights_to_fp16(model: nn.Module):
468 | """Convert applicable model parameters to fp16"""
469 |
470 | def _convert_weights_to_fp16(l):
471 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
472 | l.weight.data = l.weight.data.half()
473 | if l.bias is not None:
474 | l.bias.data = l.bias.data.half()
475 |
476 | if isinstance(l, nn.MultiheadAttention):
477 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
478 | tensor = getattr(l, attr)
479 | if tensor is not None:
480 | tensor.data = tensor.data.half()
481 |
482 | for name in ["text_projection", "proj"]:
483 | if hasattr(l, name):
484 | attr = getattr(l, name)
485 | if attr is not None:
486 | attr.data = attr.data.half()
487 |
488 | model.apply(_convert_weights_to_fp16)
489 |
490 |
491 | def build_model_from_openai_state_dict(state_dict: dict):
492 | vit = "visual.proj" in state_dict
493 |
494 | if vit:
495 | vision_width = state_dict["visual.conv1.weight"].shape[0]
496 | vision_layers = len(
497 | [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
498 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
499 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
500 | image_size = vision_patch_size * grid_size
501 | else:
502 | counts: list = [
503 | len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
504 | vision_layers = tuple(counts)
505 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
506 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
507 | vision_patch_size = None
508 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
509 | image_size = output_width * 32
510 |
511 | embed_dim = state_dict["text_projection"].shape[1]
512 | context_length = state_dict["positional_embedding"].shape[0]
513 | vocab_size = state_dict["token_embedding.weight"].shape[0]
514 | transformer_width = state_dict["ln_final.weight"].shape[0]
515 | transformer_heads = transformer_width // 64
516 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
517 |
518 | vision_cfg = CLIPVisionCfg(
519 | layers=vision_layers,
520 | width=vision_width,
521 | patch_size=vision_patch_size,
522 | image_size=image_size,
523 | )
524 | text_cfg = CLIPTextCfg(
525 | context_length=context_length,
526 | vocab_size=vocab_size,
527 | width=transformer_width,
528 | heads=transformer_heads,
529 | layers=transformer_layers
530 | )
531 | model = CLIP(
532 | embed_dim,
533 | vision_cfg=vision_cfg,
534 | text_cfg=text_cfg,
535 | quick_gelu=True, # OpenAI models were trained with QuickGELU
536 | )
537 |
538 | for key in ["input_resolution", "context_length", "vocab_size"]:
539 | state_dict.pop(key, None)
540 |
541 | convert_weights_to_fp16(model)
542 | model.load_state_dict(state_dict)
543 | return model.eval()
544 |
545 |
546 | def trace_model(model, batch_size=256, device=torch.device('cpu')):
547 | model.eval()
548 | image_size = model.visual.image_size
549 | example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
550 | example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
551 | model = torch.jit.trace_module(
552 | model,
553 | inputs=dict(
554 | forward=(example_images, example_text),
555 | encode_text=(example_text,),
556 | encode_image=(example_images,)
557 | ))
558 | model.visual.image_size = image_size
559 | return model
560 |
--------------------------------------------------------------------------------
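A short sketch (not from the repository) of instantiating the `CLIP` class above with randomly initialized weights and running a forward pass on dummy inputs; the config values mirror `model_configs/ViT-B-32.json`, and importability of `open_clip.model` is assumed.

```
import torch
from open_clip.model import CLIP, CLIPVisionCfg, CLIPTextCfg

vision_cfg = CLIPVisionCfg(layers=12, width=768, patch_size=32, image_size=224)
text_cfg = CLIPTextCfg(context_length=77, vocab_size=49408, width=512, heads=8, layers=12)
model = CLIP(embed_dim=512, vision_cfg=vision_cfg, text_cfg=text_cfg).eval()

images = torch.randn(2, 3, 224, 224)                        # [batch, 3, H, W]
texts = torch.randint(0, 49408, (2, 77), dtype=torch.long)  # pre-tokenized ids
with torch.no_grad():
    image_features, text_features, logit_scale = model(images, texts)
print(image_features.shape, text_features.shape)  # both torch.Size([2, 512]), L2-normalized
```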
/open_clip/model_configs/RN101-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": [
7 | 3,
8 | 4,
9 | 23,
10 | 3
11 | ],
12 | "width": 64,
13 | "patch_size": null
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 512,
19 | "heads": 8,
20 | "layers": 12
21 | }
22 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/RN101.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": [
6 | 3,
7 | 4,
8 | 23,
9 | 3
10 | ],
11 | "width": 64,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 512,
18 | "heads": 8,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/RN50-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": [
7 | 3,
8 | 4,
9 | 6,
10 | 3
11 | ],
12 | "width": 64,
13 | "patch_size": null
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 512,
19 | "heads": 8,
20 | "layers": 12
21 | }
22 | }
23 |
--------------------------------------------------------------------------------
/open_clip/model_configs/RN50.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": [
6 | 3,
7 | 4,
8 | 6,
9 | 3
10 | ],
11 | "width": 64,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 512,
18 | "heads": 8,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/RN50x16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 384,
5 | "layers": [
6 | 6,
7 | 8,
8 | 18,
9 | 8
10 | ],
11 | "width": 96,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 768,
18 | "heads": 12,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/RN50x4.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 288,
5 | "layers": [
6 | 4,
7 | 6,
8 | 10,
9 | 6
10 | ],
11 | "width": 80,
12 | "patch_size": null
13 | },
14 | "text_cfg": {
15 | "context_length": 77,
16 | "vocab_size": 49408,
17 | "width": 640,
18 | "heads": 10,
19 | "layers": 12
20 | }
21 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-B-16-plus-240.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 240,
5 | "layers": 12,
6 | "width": 896,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 640,
13 | "heads": 10,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-B-16-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 896,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 640,
13 | "heads": 10,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-B-32-plus-256.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 640,
3 | "vision_cfg": {
4 | "image_size": 256,
5 | "layers": 12,
6 | "width": 896,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 640,
13 | "heads": 10,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-B-32-quickgelu.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "quick_gelu": true,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": 12,
7 | "width": 768,
8 | "patch_size": 32
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-B-32.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 32
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 512,
13 | "heads": 8,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-H-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 32,
6 | "width": 1280,
7 | "head_width": 80,
8 | "patch_size": 14
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 1024,
14 | "heads": 16,
15 | "layers": 24
16 | }
17 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-H-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 32,
6 | "width": 1280,
7 | "head_width": 80,
8 | "patch_size": 16
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 1024,
14 | "heads": 16,
15 | "layers": 24
16 | }
17 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-L-14-280.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 280,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-L-14-336.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 336,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-L-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 14
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-L-16-320.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 320,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-L-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "patch_size": 16
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "width": 768,
13 | "heads": 12,
14 | "layers": 12
15 | }
16 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/ViT-g-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 40,
6 | "width": 1408,
7 | "head_width": 88,
8 | "mlp_ratio": 4.3637,
9 | "patch_size": 14
10 | },
11 | "text_cfg": {
12 | "context_length": 77,
13 | "vocab_size": 49408,
14 | "width": 1024,
15 | "heads": 16,
16 | "layers": 24
17 | }
18 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/timm-efficientnetv2_rw_s.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "timm_model_name": "efficientnetv2_rw_s",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "abs_attn",
7 | "timm_proj": "",
8 | "image_size": 288
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 768,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/timm-resnet50d.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "timm_model_name": "resnet50d",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "abs_attn",
7 | "timm_proj": "",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/open_clip/model_configs/timm-resnetaa50d.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "timm_model_name": "resnetaa50d",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "abs_attn",
7 | "timm_proj": "",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/open_clip/model_configs/timm-resnetblur50.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "timm_model_name": "resnetblur50",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "abs_attn",
7 | "timm_proj": "",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/open_clip/model_configs/timm-swin_base_patch4_window7_224.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "swin_base_patch4_window7_224",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/timm-vit_base_patch16_224.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "vit_base_patch16_224",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/timm-vit_base_patch32_224.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "vit_base_patch32_224",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/open_clip/model_configs/timm-vit_small_patch16_224.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "timm_model_name": "vit_small_patch16_224",
5 | "timm_model_pretrained": false,
6 | "timm_pool": "",
7 | "timm_proj": "linear",
8 | "image_size": 224
9 | },
10 | "text_cfg": {
11 | "context_length": 77,
12 | "vocab_size": 49408,
13 | "width": 512,
14 | "heads": 8,
15 | "layers": 12
16 | }
17 | }
--------------------------------------------------------------------------------
/open_clip/openai.py:
--------------------------------------------------------------------------------
1 | """ OpenAI pretrained model functions
2 |
3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 |
6 | import os
7 | import warnings
8 | from typing import Union, List
9 |
10 | import torch
11 |
12 | from .model import build_model_from_openai_state_dict
13 | from .pretrained import get_pretrained_url, list_pretrained_tag_models, download_pretrained
14 |
15 | __all__ = ["list_openai_models", "load_openai_model"]
16 |
17 |
18 | def list_openai_models() -> List[str]:
19 | """Returns the names of available CLIP models"""
20 | return list_pretrained_tag_models('openai')
21 |
22 |
23 | def load_openai_model(
24 | name: str,
25 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
26 | jit=True,
27 | ):
28 | """Load a CLIP model
29 |
30 | Parameters
31 | ----------
32 | name : str
33 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
34 | device : Union[str, torch.device]
35 | The device to put the loaded model
36 | jit : bool
37 | Whether to load the optimized JIT model (default) or more hackable non-JIT model.
38 |
39 | Returns
40 | -------
41 | model : torch.nn.Module
42 | The CLIP model
43 | Note that, unlike the original OpenAI `clip.load`, only the model is returned; the
44 | preprocessing transform is built separately via `open_clip.transform.image_transform`.
45 | """
46 | if get_pretrained_url(name, 'openai'):
47 | model_path = download_pretrained(get_pretrained_url(name, 'openai'))
48 | elif os.path.isfile(name):
49 | model_path = name
50 | else:
51 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
52 |
53 | try:
54 | # loading JIT archive
55 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
56 | state_dict = None
57 | except RuntimeError:
58 | # loading saved state dict
59 | if jit:
60 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
61 | jit = False
62 | state_dict = torch.load(model_path, map_location="cpu")
63 |
64 | if not jit:
65 | try:
66 | model = build_model_from_openai_state_dict(state_dict or model.state_dict()).to(device)
67 | except KeyError:
68 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
69 | model = build_model_from_openai_state_dict(sd).to(device)
70 |
71 | if str(device) == "cpu":
72 | model.float()
73 | return model
74 |
75 | # patch the device names
76 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
77 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
78 |
79 | def patch_device(module):
80 | try:
81 | graphs = [module.graph] if hasattr(module, "graph") else []
82 | except RuntimeError:
83 | graphs = []
84 |
85 | if hasattr(module, "forward1"):
86 | graphs.append(module.forward1.graph)
87 |
88 | for graph in graphs:
89 | for node in graph.findAllNodes("prim::Constant"):
90 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
91 | node.copyAttributes(device_node)
92 |
93 | model.apply(patch_device)
94 | patch_device(model.encode_image)
95 | patch_device(model.encode_text)
96 |
97 | # patch dtype to float32 on CPU
98 | if str(device) == "cpu":
99 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
100 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
101 | float_node = float_input.node()
102 |
103 | def patch_float(module):
104 | try:
105 | graphs = [module.graph] if hasattr(module, "graph") else []
106 | except RuntimeError:
107 | graphs = []
108 |
109 | if hasattr(module, "forward1"):
110 | graphs.append(module.forward1.graph)
111 |
112 | for graph in graphs:
113 | for node in graph.findAllNodes("aten::to"):
114 | inputs = list(node.inputs())
115 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
116 | if inputs[i].node()["value"] == 5:
117 | inputs[i].node().copyAttributes(float_node)
118 |
119 | model.apply(patch_float)
120 | patch_float(model.encode_image)
121 | patch_float(model.encode_text)
122 | model.float()
123 |
124 | # ensure image_size attr available at consistent location for both jit and non-jit
125 | model.visual.image_size = model.input_resolution.item()
126 | return model
127 |
--------------------------------------------------------------------------------
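A hedged usage sketch for `load_openai_model`; on first call it downloads the official OpenAI RN50 checkpoint into `~/.cache/clip`, so network access is assumed.

```
import torch
from open_clip.openai import list_openai_models, load_openai_model

print(list_openai_models())  # model names with an 'openai' pretrain tag, e.g. 'RN50', 'ViT-B-32', ...

# jit=False rebuilds a regular nn.Module from the checkpoint's state dict.
model = load_openai_model("RN50", device="cpu", jit=False)
image = torch.randn(1, 3, model.visual.image_size, model.visual.image_size)
with torch.no_grad():
    features = model.encode_image(image)
print(features.shape)  # torch.Size([1, 1024]) for RN50
```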
/open_clip/pretrained.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 |
6 | from tqdm import tqdm
7 |
8 | _RN50 = dict(
9 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
10 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
11 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"
12 | )
13 |
14 | _RN50_quickgelu = dict(
15 | openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
16 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
17 | cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"
18 | )
19 |
20 | _RN101 = dict(
21 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
22 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"
23 | )
24 |
25 | _RN101_quickgelu = dict(
26 | openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
27 | yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"
28 | )
29 |
30 | _RN50x4 = dict(
31 | openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
32 | )
33 |
34 | _RN50x16 = dict(
35 | openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
36 | )
37 |
38 | _RN50x64 = dict(
39 | openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
40 | )
41 |
42 | _VITB32 = dict(
43 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
44 | laion2b_e16="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth",
45 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
46 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
47 | )
48 |
49 | _VITB32_quickgelu = dict(
50 | openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
51 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
52 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
53 | )
54 |
55 | _VITB16 = dict(
56 | openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
57 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt",
58 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt",
59 | )
60 |
61 | _VITB16_PLUS_240 = dict(
62 | laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt",
63 | laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt",
64 | )
65 |
66 | _VITL14 = dict(
67 | openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
68 | laion400m_e31='https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt',
69 | laion400m_e32='https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt',
70 | )
71 |
72 | _VITL14_336 = dict(
73 | openai="https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"
74 | )
75 |
76 | _PRETRAINED = {
77 | "RN50": _RN50,
78 | "RN50-quickgelu": _RN50_quickgelu,
79 | "RN101": _RN101,
80 | "RN101-quickgelu": _RN101_quickgelu,
81 | "RN50x4": _RN50x4,
82 | "RN50x16": _RN50x16,
83 | "RN50x64": _RN50x64,
84 | "ViT-B-32": _VITB32,
85 | "ViT-B-32-quickgelu": _VITB32_quickgelu,
86 | "ViT-B-16": _VITB16,
87 | "ViT-B-16-plus-240": _VITB16_PLUS_240,
88 | "ViT-L-14": _VITL14,
89 | "ViT-L-14-336": _VITL14_336,
90 | }
91 |
92 |
93 | def list_pretrained(as_str: bool = False):
94 | """ returns list of pretrained models
95 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
96 | """
97 | return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
98 |
99 |
100 | def list_pretrained_tag_models(tag: str):
101 | """ return all models having the specified pretrain tag """
102 | models = []
103 | for k in _PRETRAINED.keys():
104 | if tag in _PRETRAINED[k]:
105 | models.append(k)
106 | return models
107 |
108 |
109 | def list_pretrained_model_tags(model: str):
110 | """ return all pretrain tags for the specified model architecture """
111 | tags = []
112 | if model in _PRETRAINED:
113 | tags.extend(_PRETRAINED[model].keys())
114 | return tags
115 |
116 |
117 | def get_pretrained_url(model: str, tag: str):
118 | if model not in _PRETRAINED:
119 | return ''
120 | model_pretrained = _PRETRAINED[model]
121 | tag = tag.lower()
122 | if tag not in model_pretrained:
123 | return ''
124 | return model_pretrained[tag]
125 |
126 |
127 | def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")):
128 | os.makedirs(root, exist_ok=True)
129 | filename = os.path.basename(url)
130 |
131 | if 'openaipublic' in url:
132 | expected_sha256 = url.split("/")[-2]
133 | else:
134 | expected_sha256 = ''
135 |
136 | download_target = os.path.join(root, filename)
137 |
138 | if os.path.exists(download_target) and not os.path.isfile(download_target):
139 | raise RuntimeError(f"{download_target} exists and is not a regular file")
140 |
141 | if os.path.isfile(download_target):
142 | if expected_sha256:
143 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
144 | return download_target
145 | else:
146 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
147 | else:
148 | return download_target
149 |
150 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
151 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
152 | while True:
153 | buffer = source.read(8192)
154 | if not buffer:
155 | break
156 |
157 | output.write(buffer)
158 | loop.update(len(buffer))
159 |
160 | if expected_sha256 and hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
161 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
162 |
163 | return download_target
164 |
--------------------------------------------------------------------------------
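A small sketch of the pretrained-weights registry above; `list_pretrained` and `get_pretrained_url` only read the in-memory tables, while `download_pretrained` would hit the network.

```
from open_clip.pretrained import list_pretrained, get_pretrained_url, list_pretrained_model_tags

print(list_pretrained(as_str=True)[:3])        # ['RN50:openai', 'RN50:yfcc15m', 'RN50:cc12m']
print(list_pretrained_model_tags('ViT-B-32'))  # ['openai', 'laion2b_e16', 'laion400m_e31', 'laion400m_e32']
url = get_pretrained_url('ViT-B-32', 'laion400m_e32')
print(url)  # GitHub release URL for the LAION-400M ViT-B/32 weights
# download_pretrained(url) would fetch the checkpoint into ~/.cache/clip and return its local path.
```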
/open_clip/timm_model.py:
--------------------------------------------------------------------------------
1 | """ timm model adapter
2 |
3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4 | """
5 | from collections import OrderedDict
6 |
7 | import torch.nn as nn
8 |
9 | try:
10 | import timm
11 | from timm.models.layers import Mlp, to_2tuple
12 | from timm.models.layers.attention_pool2d import RotAttentionPool2d
13 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
14 | except ImportError as e:
15 | timm = None
16 |
17 | from .utils import freeze_batch_norm_2d
18 |
19 |
20 | class TimmModel(nn.Module):
21 | """ timm model adapter
22 | # FIXME this adapter is a work in progress, may change in ways that break weight compat
23 | """
24 |
25 | def __init__(
26 | self,
27 | model_name,
28 | embed_dim,
29 | image_size=224,
30 | pool='avg',
31 | proj='linear',
32 | drop=0.,
33 | pretrained=False):
34 | super().__init__()
35 | if timm is None:
36 | raise RuntimeError("Please `pip install timm` to use timm models.")
37 |
38 | self.image_size = to_2tuple(image_size)
39 | self.trunk = timm.create_model(model_name, pretrained=pretrained)
40 | feat_size = self.trunk.default_cfg.get('pool_size', None)
41 | feature_ndim = 1 if not feat_size else 2
42 | if pool in ('abs_attn', 'rot_attn'):
43 | assert feature_ndim == 2
44 | # if attn pooling used, remove both classifier and default pool
45 | self.trunk.reset_classifier(0, global_pool='')
46 | else:
47 | # reset global pool if pool config set, otherwise leave as network default
48 | reset_kwargs = dict(global_pool=pool) if pool else {}
49 | self.trunk.reset_classifier(0, **reset_kwargs)
50 | prev_chs = self.trunk.num_features
51 |
52 | head_layers = OrderedDict()
53 | if pool == 'abs_attn':
54 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
55 | prev_chs = embed_dim
56 | elif pool == 'rot_attn':
57 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
58 | prev_chs = embed_dim
59 | else:
60 | assert proj, 'projection layer needed if non-attention pooling is used.'
61 |
62 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
63 | if proj == 'linear':
64 | head_layers['drop'] = nn.Dropout(drop)
65 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim)
66 | elif proj == 'mlp':
67 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
68 |
69 | self.head = nn.Sequential(head_layers)
70 |
71 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
72 | """ lock modules
73 | Args:
74 | unlocked_groups (int): leave last n layer groups unlocked (default: 0)
75 | """
76 | if not unlocked_groups:
77 | # lock full model
78 | for param in self.trunk.parameters():
79 | param.requires_grad = False
80 | if freeze_bn_stats:
81 | freeze_batch_norm_2d(self.trunk)
82 | else:
83 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change
84 | try:
85 | # FIXME import here until API stable and in an official release
86 | from timm.models.helpers import group_parameters, group_modules
87 | except ImportError:
88 | raise RuntimeError(
89 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
90 | matcher = self.trunk.group_matcher()
91 | gparams = group_parameters(self.trunk, matcher)
92 | max_layer_id = max(gparams.keys())
93 | max_layer_id = max_layer_id - unlocked_groups
94 | for group_idx in range(max_layer_id + 1):
95 | group = gparams[group_idx]
96 | for param in group:
97 | self.trunk.get_parameter(param).requires_grad = False
98 | if freeze_bn_stats:
99 | gmodules = group_modules(self.trunk, matcher, reverse=True)
100 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
101 | freeze_batch_norm_2d(self.trunk, gmodules)
102 |
103 | def forward(self, x):
104 | x = self.trunk(x)
105 | x = self.head(x)
106 | return x
107 |
--------------------------------------------------------------------------------
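A tentative sketch of wrapping a timm backbone with `TimmModel`; it assumes a `timm` release whose layout matches the import paths above, and `pretrained=False` so nothing is downloaded. The pool/proj settings mirror `model_configs/timm-vit_base_patch16_224.json`.

```
import torch
from open_clip.timm_model import TimmModel

vision_tower = TimmModel(
    'vit_base_patch16_224',  # any timm model name should work here
    embed_dim=512,
    image_size=224,
    pool='',                 # keep timm's default pooling
    proj='linear',           # project the trunk features (768-d) down to embed_dim
    pretrained=False,
)
x = torch.randn(1, 3, 224, 224)
print(vision_tower(x).shape)  # torch.Size([1, 512])
```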
/open_clip/tokenizer.py:
--------------------------------------------------------------------------------
1 | """ CLIP tokenizer
2 |
3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 | import gzip
6 | import html
7 | import os
8 | from functools import lru_cache
9 | from typing import Union, List
10 |
11 | import ftfy
12 | import regex as re
13 | import torch
14 |
15 |
16 | @lru_cache()
17 | def default_bpe():
18 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
19 |
20 |
21 | @lru_cache()
22 | def bytes_to_unicode():
23 | """
24 | Returns a list of utf-8 bytes and a corresponding list of unicode strings.
25 | The reversible bpe codes work on unicode strings.
26 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
27 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
28 | This is a significant percentage of your normal, say, 32K bpe vocab.
29 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
30 | And avoids mapping to whitespace/control characters the bpe code barfs on.
31 | """
32 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
33 | cs = bs[:]
34 | n = 0
35 | for b in range(2**8):
36 | if b not in bs:
37 | bs.append(b)
38 | cs.append(2**8+n)
39 | n += 1
40 | cs = [chr(n) for n in cs]
41 | return dict(zip(bs, cs))
42 |
43 |
44 | def get_pairs(word):
45 | """Return set of symbol pairs in a word.
46 | Word is represented as tuple of symbols (symbols being variable-length strings).
47 | """
48 | pairs = set()
49 | prev_char = word[0]
50 | for char in word[1:]:
51 | pairs.add((prev_char, char))
52 | prev_char = char
53 | return pairs
54 |
55 |
56 | def basic_clean(text):
57 | text = ftfy.fix_text(text)
58 | text = html.unescape(html.unescape(text))
59 | return text.strip()
60 |
61 |
62 | def whitespace_clean(text):
63 | text = re.sub(r'\s+', ' ', text)
64 | text = text.strip()
65 | return text
66 |
67 |
68 | class SimpleTokenizer(object):
69 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
70 | self.byte_encoder = bytes_to_unicode()
71 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
72 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
73 | # print("merges len:", len(merges))
74 | merges = merges[1:49152-256-2+1]
75 | merges = [tuple(merge.split()) for merge in merges]
76 | vocab = list(bytes_to_unicode().values())
77 | vocab = vocab + [v+'</w>' for v in vocab]
78 | for merge in merges:
79 | vocab.append(''.join(merge))
80 | if not special_tokens:
81 | special_tokens = ['<start_of_text>', '<end_of_text>']
82 | else:
83 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
84 | vocab.extend(special_tokens)
85 | # print("vocab:", len(vocab))
86 | self.encoder = dict(zip(vocab, range(len(vocab))))
87 | # print("encoder:", self.encoder)
88 | self.decoder = {v: k for k, v in self.encoder.items()}
89 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
90 | self.cache = {t:t for t in special_tokens}
91 | special = "|".join(special_tokens)
92 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
93 |
94 | self.vocab_size = len(self.encoder)
95 | self.all_special_ids = [self.encoder[t] for t in special_tokens]
96 |
97 | def bpe(self, token):
98 | if token in self.cache:
99 | return self.cache[token]
100 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
101 | pairs = get_pairs(word)
102 |
103 | if not pairs:
104 | return token+'</w>'
105 |
106 | while True:
107 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
108 | if bigram not in self.bpe_ranks:
109 | break
110 | first, second = bigram
111 | new_word = []
112 | i = 0
113 | while i < len(word):
114 | try:
115 | j = word.index(first, i)
116 | new_word.extend(word[i:j])
117 | i = j
118 | except:
119 | new_word.extend(word[i:])
120 | break
121 |
122 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
123 | new_word.append(first+second)
124 | i += 2
125 | else:
126 | new_word.append(word[i])
127 | i += 1
128 | new_word = tuple(new_word)
129 | word = new_word
130 | if len(word) == 1:
131 | break
132 | else:
133 | pairs = get_pairs(word)
134 | word = ' '.join(word)
135 | self.cache[token] = word
136 | return word
137 |
138 | def encode(self, text):
139 | bpe_tokens = []
140 | text = whitespace_clean(basic_clean(text)).lower()
141 | for token in re.findall(self.pat, text):
142 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
143 | # print("token, bpe:", token, self.bpe(token))
144 | # print(self.bpe(token).split(' '))
145 | # for bpe_token in self.bpe(token).split(' '):
146 | # print('token:', bpe_token )
147 | # print('bpe:', self.encoder[bpe_token])
148 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
149 | # print("overall bpe:", bpe_tokens)
150 | return bpe_tokens
151 |
152 | def decode(self, tokens):
153 | text = ''.join([self.decoder[token] for token in tokens])
154 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
155 | return text
156 |
157 |
158 | _tokenizer = SimpleTokenizer()
159 |
160 |
161 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
162 | """
163 | Returns the tokenized representation of given input string(s)
164 |
165 | Parameters
166 | ----------
167 | texts : Union[str, List[str]]
168 | An input string or a list of input strings to tokenize
169 | context_length : int
170 | The context length to use; all CLIP models use 77 as the context length
171 |
172 | Returns
173 | -------
174 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
175 | """
176 | if isinstance(texts, str):
177 | texts = [texts]
178 |
179 | sot_token = _tokenizer.encoder["<start_of_text>"]
180 | eot_token = _tokenizer.encoder["<end_of_text>"]
181 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
182 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
183 |
184 | for i, tokens in enumerate(all_tokens):
185 | if len(tokens) > context_length:
186 | tokens = tokens[:context_length] # Truncate
187 | result[i, :len(tokens)] = torch.tensor(tokens)
188 |
189 | return result
190 |
--------------------------------------------------------------------------------
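For reference, a minimal sketch of how the tokenizer above can be exercised on its own; it assumes the repository root is on `PYTHONPATH` so that `open_clip.tokenizer` is importable, and the captions are hypothetical.

```
from open_clip.tokenizer import SimpleTokenizer, tokenize

captions = ["a dog chasing a ball", "a red car parked by a tree"]

# tokenize() wraps each caption in <start_of_text>/<end_of_text> and pads or
# truncates it to the 77-token CLIP context window.
token_ids = tokenize(captions)
print(token_ids.shape)  # torch.Size([2, 77])

# The underlying BPE tokenizer can also be used directly; decode() roughly
# round-trips the cleaned, lower-cased caption.
tokenizer = SimpleTokenizer()
ids = tokenizer.encode(captions[0])
print(tokenizer.decode(ids))
```
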
/open_clip/transform.py:
--------------------------------------------------------------------------------
1 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, ToTensor, Resize, \
2 | CenterCrop
3 | from PIL import Image
4 |
5 | def _convert_to_rgb(image):
6 | return image.convert('RGB')
7 |
8 |
9 | def image_transform(
10 | image_size: int,
11 | is_train: bool,
12 | mean=(0.48145466, 0.4578275, 0.40821073),
13 | std=(0.26862954, 0.26130258, 0.27577711)
14 | ):
15 | normalize = Normalize(mean=mean, std=std)
16 | if is_train:
17 | return Compose([
18 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=Image.BICUBIC),
19 | _convert_to_rgb,
20 | ToTensor(),
21 | normalize,
22 | ])
23 | else:
24 | return Compose([
25 | Resize(image_size, interpolation=Image.BICUBIC),
26 | CenterCrop(image_size),
27 | _convert_to_rgb,
28 | ToTensor(),
29 | normalize,
30 | ])
31 |
--------------------------------------------------------------------------------
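The transform factory above can be used stand-alone to reproduce CLIP's image preprocessing. The sketch below is illustrative only: it assumes the repository's `open_clip` package is importable, uses a hypothetical local file `example.jpg`, and keeps the default CLIP normalization statistics.

```
from PIL import Image
from open_clip.transform import image_transform

# Evaluation-time pipeline: resize, center-crop, convert to RGB,
# convert to a tensor and normalize with the CLIP mean/std defaults.
preprocess = image_transform(image_size=224, is_train=False)

image = Image.open("example.jpg")  # hypothetical path
pixel_values = preprocess(image)
print(pixel_values.shape)  # torch.Size([3, 224, 224])
```
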
/open_clip/utils.py:
--------------------------------------------------------------------------------
1 | from torch import nn as nn
2 | from torchvision.ops.misc import FrozenBatchNorm2d
3 |
4 |
5 | def freeze_batch_norm_2d(module, module_match={}, name=''):
6 | """
7 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
8 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
9 | returned. Otherwise, the module is walked recursively and submodules are converted in place.
10 |
11 | Args:
12 | module (torch.nn.Module): Any PyTorch module.
13 | module_match (dict): Dictionary of full module names to freeze (all if empty)
14 | name (str): Full module name (prefix)
15 |
16 | Returns:
17 | torch.nn.Module: Resulting module
18 |
19 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
20 | """
21 | res = module
22 | is_match = True
23 | if module_match:
24 | is_match = name in module_match
25 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
26 | res = FrozenBatchNorm2d(module.num_features)
27 | res.num_features = module.num_features
28 | res.affine = module.affine
29 | if module.affine:
30 | res.weight.data = module.weight.data.clone().detach()
31 | res.bias.data = module.bias.data.clone().detach()
32 | res.running_mean.data = module.running_mean.data
33 | res.running_var.data = module.running_var.data
34 | res.eps = module.eps
35 | else:
36 | for child_name, child in module.named_children():
37 | full_child_name = '.'.join([name, child_name]) if name else child_name
38 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
39 | if new_child is not child:
40 | res.add_module(child_name, new_child)
41 | return res
--------------------------------------------------------------------------------
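As a quick illustration of `freeze_batch_norm_2d`, the sketch below converts the batch-norm layer of a toy `nn.Sequential` in place; the toy model is hypothetical and assumes the repository's `open_clip` package is importable.

```
from torch import nn
from torchvision.ops.misc import FrozenBatchNorm2d
from open_clip.utils import freeze_batch_norm_2d

# Toy model containing a single BatchNorm2d layer.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)

# With an empty module_match, every BatchNorm2d/SyncBatchNorm is frozen.
frozen = freeze_batch_norm_2d(model)
print(isinstance(frozen[1], FrozenBatchNorm2d))  # True
```
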
/open_clip/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '1.3.0'
2 |
--------------------------------------------------------------------------------