├── .gitignore
├── LICENSE
├── Multitasking
│   ├── __init__.py
│   ├── clip_load.py
│   ├── clip_models.py
│   ├── dataset_utils
│   │   ├── __init__.py
│   │   ├── cub200_dataset.py
│   │   └── generic_dataset.py
│   ├── main.py
│   ├── metrics.py
│   ├── models.py
│   ├── swin_models.py
│   ├── trainer.py
│   └── utils.py
├── OPENCLIP
│   ├── __init__.py
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── coca_model.py
│   ├── constants.py
│   ├── factory.py
│   ├── generation_utils.py
│   ├── hf_configs.py
│   ├── hf_model.py
│   ├── loss.py
│   ├── model.py
│   ├── model_configs
│   │   ├── RN101-quickgelu.json
│   │   ├── RN101.json
│   │   ├── RN50-quickgelu.json
│   │   ├── RN50.json
│   │   ├── RN50x16.json
│   │   ├── RN50x4.json
│   │   ├── RN50x64.json
│   │   ├── ViT-B-16-plus-240.json
│   │   ├── ViT-B-16-plus.json
│   │   ├── ViT-B-16.json
│   │   ├── ViT-B-32-plus-256.json
│   │   ├── ViT-B-32-quickgelu.json
│   │   ├── ViT-B-32.json
│   │   ├── ViT-H-14.json
│   │   ├── ViT-H-16.json
│   │   ├── ViT-L-14-280.json
│   │   ├── ViT-L-14-336.json
│   │   ├── ViT-L-14.json
│   │   ├── ViT-L-16-320.json
│   │   ├── ViT-L-16.json
│   │   ├── ViT-M-16-alt.json
│   │   ├── ViT-M-16.json
│   │   ├── ViT-M-32-alt.json
│   │   ├── ViT-M-32.json
│   │   ├── ViT-S-16-alt.json
│   │   ├── ViT-S-16.json
│   │   ├── ViT-S-32-alt.json
│   │   ├── ViT-S-32.json
│   │   ├── ViT-bigG-14.json
│   │   ├── ViT-e-14.json
│   │   ├── ViT-g-14.json
│   │   ├── coca_ViT-B-32.json
│   │   ├── coca_ViT-L-14.json
│   │   ├── coca_base.json
│   │   ├── coca_roberta-ViT-B-32.json
│   │   ├── convnext_base.json
│   │   ├── convnext_base_w.json
│   │   ├── convnext_base_w_320.json
│   │   ├── convnext_large.json
│   │   ├── convnext_large_d.json
│   │   ├── convnext_large_d_320.json
│   │   ├── convnext_small.json
│   │   ├── convnext_tiny.json
│   │   ├── convnext_xlarge.json
│   │   ├── convnext_xxlarge.json
│   │   ├── convnext_xxlarge_320.json
│   │   ├── mt5-base-ViT-B-32.json
│   │   ├── mt5-xl-ViT-H-14.json
│   │   ├── roberta-ViT-B-32.json
│   │   ├── swin_base_patch4_window7_224.json
│   │   ├── vit_medium_patch16_gap_256.json
│   │   ├── vit_relpos_medium_patch16_cls_224.json
│   │   ├── xlm-roberta-base-ViT-B-32.json
│   │   └── xlm-roberta-large-ViT-H-14.json
│   ├── modified_resnet.py
│   ├── openai.py
│   ├── pretrained.py
│   ├── push_to_hf_hub.py
│   ├── timm_model.py
│   ├── tokenizer.py
│   ├── transform.py
│   ├── transformer.py
│   ├── transformer_lora.py
│   ├── utils.py
│   └── version.py
├── README.md
├── setup.py
└── visual.png

/.gitignore:
--------------------------------------------------------------------------------
1 | Results/**
2 | Oracle_seg/**
3 | Datasets/**
4 | Models/**
5 | Tokenizers/**
6 | .idea
7 | *.pyc
8 | **/__pycache__
9 | *.pt
10 | *.egg-info
11 | *.tar.gz
12 | 
13 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 OpenAI
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /Multitasking/__init__.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | -------------------------------------------------------------------------------- /Multitasking/clip_load.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from https://github.com/openai/CLIP 3 | """ 4 | 5 | 6 | import hashlib 7 | import os 8 | import urllib 9 | import warnings 10 | from typing import Any, Union, List 11 | from pkg_resources import packaging 12 | 13 | import torch 14 | from PIL import Image 15 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 16 | from tqdm import tqdm 17 | 18 | from Multitasking.clip_models import build_model 19 | 20 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 21 | 22 | try: 23 | from torchvision.transforms import InterpolationMode 24 | BICUBIC = InterpolationMode.BICUBIC 25 | except ImportError: 26 | BICUBIC = Image.BICUBIC 27 | 28 | 29 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 30 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 31 | 32 | 33 | __all__ = ["available_models", "load", "tokenize"] 34 | _tokenizer = _Tokenizer() 35 | 36 | _MODELS = { 37 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 38 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 39 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 40 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 41 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 42 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 43 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 44 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 45 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 46 | } 47 | 48 | 49 | def _download(url: str, root: str): 50 | os.makedirs(root, exist_ok=True) 51 | filename = os.path.basename(url) 52 | 53 | expected_sha256 = url.split("/")[-2] 54 | download_target = os.path.join(root, filename) 55 | 56 | if os.path.exists(download_target) and not os.path.isfile(download_target): 57 | raise RuntimeError(f"{download_target} exists and is not a regular file") 58 | 59 | if os.path.isfile(download_target): 60 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 61 | return download_target 62 | else: 63 | 
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 64 | 65 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 66 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 67 | while True: 68 | buffer = source.read(8192) 69 | if not buffer: 70 | break 71 | 72 | output.write(buffer) 73 | loop.update(len(buffer)) 74 | 75 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 76 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 77 | 78 | return download_target 79 | 80 | 81 | def _convert_image_to_rgb(image): 82 | return image.convert("RGB") 83 | 84 | 85 | def _transform(n_px): 86 | return Compose([ 87 | Resize(n_px, interpolation=BICUBIC), 88 | CenterCrop(n_px), 89 | _convert_image_to_rgb, 90 | ToTensor(), 91 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 92 | ]) 93 | 94 | 95 | def available_models() -> List[str]: 96 | """Returns the names of available CLIP models""" 97 | return list(_MODELS.keys()) 98 | 99 | 100 | def load(name: str, params: dict, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 101 | """Load a CLIP model 102 | 103 | Parameters 104 | ---------- 105 | name : str 106 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 107 | 108 | device : Union[str, torch.device] 109 | The device to put the loaded model 110 | 111 | jit : bool 112 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 113 | 114 | download_root: str 115 | path to download the model files; by default, it uses "~/.cache/clip" 116 | 117 | Returns 118 | ------- 119 | model : torch.nn.Module 120 | The CLIP model 121 | 122 | preprocess : Callable[[PIL.Image], torch.Tensor] 123 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 124 | """ 125 | if name in _MODELS: 126 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 127 | elif os.path.isfile(name): 128 | model_path = name 129 | else: 130 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 131 | 132 | with open(model_path, 'rb') as opened_file: 133 | try: 134 | # loading JIT archive 135 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 136 | state_dict = None 137 | except RuntimeError: 138 | # loading saved state dict 139 | if jit: 140 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 141 | jit = False 142 | state_dict = torch.load(opened_file, map_location="cpu") 143 | 144 | if not jit: 145 | model = build_model(state_dict or model.state_dict(), params).to(device) 146 | if str(device) == "cpu": 147 | model.float() 148 | return model, _transform(model.visual.input_resolution) 149 | 150 | # patch the device names 151 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 152 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 153 | 154 | def patch_device(module): 155 | try: 156 | graphs = [module.graph] if hasattr(module, "graph") else [] 157 | except RuntimeError: 158 | graphs = [] 159 | 160 | if hasattr(module, "forward1"): 161 | graphs.append(module.forward1.graph) 162 | 163 | for graph in graphs: 164 | for node in graph.findAllNodes("prim::Constant"): 165 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 166 | node.copyAttributes(device_node) 167 | 168 | model.apply(patch_device) 169 | patch_device(model.encode_image) 170 | patch_device(model.encode_text) 171 | 172 | # patch dtype to float32 on CPU 173 | if str(device) == "cpu": 174 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 175 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 176 | float_node = float_input.node() 177 | 178 | def patch_float(module): 179 | try: 180 | graphs = [module.graph] if hasattr(module, "graph") else [] 181 | except RuntimeError: 182 | graphs = [] 183 | 184 | if hasattr(module, "forward1"): 185 | graphs.append(module.forward1.graph) 186 | 187 | for graph in graphs: 188 | for node in graph.findAllNodes("aten::to"): 189 | inputs = list(node.inputs()) 190 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 191 | if inputs[i].node()["value"] == 5: 192 | inputs[i].node().copyAttributes(float_node) 193 | 194 | model.apply(patch_float) 195 | patch_float(model.encode_image) 196 | patch_float(model.encode_text) 197 | 198 | model.float() 199 | 200 | return model, _transform(model.input_resolution.item()) 201 | 202 | 203 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 204 | """ 205 | Returns the tokenized representation of given input string(s) 206 | 207 | Parameters 208 | ---------- 209 | texts : Union[str, List[str]] 210 | An input string or a list of input strings to tokenize 211 | 212 | context_length : int 213 | The context length to use; all CLIP models use 77 as the context length 214 | 215 | truncate: bool 216 | Whether to truncate the text in case its encoding is longer than the context length 217 | 218 | Returns 219 | ------- 220 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 221 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
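    Example (illustrative sketch; the prompt strings are placeholders and the default context length of 77 is assumed):
        tokens = tokenize(["a photo of a bird", "a photo of a dog"])
        # tokens.shape == (2, 77); each row holds <|startoftext|>, the BPE tokens of the text,
        # <|endoftext|>, then zero padding up to context_length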
222 | """ 223 | if isinstance(texts, str): 224 | texts = [texts] 225 | 226 | sot_token = _tokenizer.encoder["<|startoftext|>"] 227 | eot_token = _tokenizer.encoder["<|endoftext|>"] 228 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 229 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 230 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 231 | else: 232 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 233 | 234 | for i, tokens in enumerate(all_tokens): 235 | if len(tokens) > context_length: 236 | if truncate: 237 | tokens = tokens[:context_length] 238 | tokens[-1] = eot_token 239 | else: 240 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 241 | result[i, :len(tokens)] = torch.tensor(tokens) 242 | 243 | return result 244 | -------------------------------------------------------------------------------- /Multitasking/dataset_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import * -------------------------------------------------------------------------------- /Multitasking/dataset_utils/generic_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import torch 4 | import torchvision 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | from torchvision.transforms import Resize, Normalize 8 | from torchvision.transforms import Compose, RandomAdjustSharpness, ToTensor, InterpolationMode 9 | from torch.nn import Module 10 | from tqdm import tqdm 11 | from PIL import Image 12 | 13 | from Multitasking.utils import rand, get_tokenizer, randint 14 | 15 | 16 | class GenericDataset(Dataset): 17 | 18 | def __init__(self, path, set_name, params): 19 | self.params = params 20 | self.path = path 21 | self.set_name = set_name 22 | self.attr_from = dict() 23 | self.collate_fn = GenericCollateFunction() 24 | self.samples = None 25 | self.preprocess_fn = None 26 | self.input_size = self.params["input_size"] 27 | self.norm_fn = None 28 | self.image_id_class_id_mapping = None 29 | self.oracle_matching_keys = None 30 | self.init_dataset() 31 | self.tokenizer = get_tokenizer(self.params["model_name"]) 32 | self.open_tokenizer = "open" in self.params["model_name"] 33 | self.dataset_name = os.path.basename(self.path) 34 | self.model_name = self.params["model_name"].replace("/", "-") 35 | if self.params["other_vision_encoder"] is None: 36 | archi_name = self.model_name 37 | else: 38 | archi_name = "{}_{}".format(self.model_name, self.params["other_vision_encoder"]) 39 | self.preprocess_foldpath = os.path.join("Preprocess", self.dataset_name, archi_name, self.set_name) 40 | self.class_names = self.get_class_names() 41 | self.num_classes = len(self.class_names) 42 | 43 | self.samples = self.load_samples() 44 | 45 | self.image_preprocess_function = Compose([ 46 | ToTensor(), 47 | ToRGB(), 48 | ]) 49 | 50 | self.da_function = Compose([ 51 | RandomAdjustSharpness(sharpness_factor=1.5, p=1) 52 | ]) 53 | 54 | self.swin_normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 55 | 56 | self.text_preprocessing() 57 | 58 | def init_dataset(self): 59 | pass 60 | 61 | def get_class_names(self): 62 | pass 63 | 64 | def get_image_by_id(self, img_id): 65 | pass 66 | 67 | def get_image_ids(self): 68 | pass 69 | 70 | def update_info_resize(self, sample, init_size, new_size): 71 | pass 72 | 73 | def 
update_info_crop(self, sample, init_size, new_size, start_h, start_w): 74 | pass 75 | 76 | def update_info_hflip(self, sample, size): 77 | pass 78 | 79 | def before_preprocess_sample(self, sample, preprocess_fn): 80 | return sample 81 | 82 | def after_preprocess_sample(self, sample): 83 | return sample 84 | 85 | def text_preprocessing(self): 86 | pass 87 | 88 | def __len__(self): 89 | return len(self.samples) 90 | 91 | def __getitem__(self, idx): 92 | sample = copy.deepcopy(self.samples[idx]) 93 | 94 | # load pseudolabel for localization if oracle (UIO/OFA) is used 95 | if self.params["load_oracle_segmentation"]: 96 | sample["oracle_seg"] = self.load_oracle_segmentation(sample) 97 | 98 | # IF CLIP IMAGE ENCODER 99 | if self.params["other_vision_encoder"] is None: 100 | self.before_preprocess_sample(sample, self.preprocess_fn) 101 | sample["image"] = self.preprocess_fn(Image.fromarray(sample["image"])) 102 | if self.set_name == "train": 103 | sample = self.apply_custom_da(sample) 104 | sample = self.after_preprocess_sample(sample) 105 | sample["image"] = self.norm_fn(sample["image"]) 106 | # IF SWIN IMAGE ENCODER 107 | else: 108 | sample["image"] = self.image_preprocess_function(sample["image"]) 109 | sample = self.resize(sample, (510, 510)) 110 | if self.set_name == "train": 111 | sample = self.apply_custom_da(sample) 112 | sample = self.random_crop(sample, self.input_size) 113 | else: 114 | sample = self.center_crop(sample, self.input_size) 115 | sample = self.after_preprocess_sample(sample) 116 | sample["image"] = self.swin_normalize(sample["image"]) 117 | 118 | return sample 119 | 120 | def denormalize(self, image): 121 | if self.params["other_vision_encoder"] is None: 122 | norm_fn = self.norm_fn 123 | else: 124 | norm_fn = self.swin_normalize 125 | return ((image * torch.tensor(norm_fn.std).view(3, 1, 1) + torch.tensor(norm_fn.mean).view(3, 1, 1)) * 255).to(torch.int) 126 | 127 | def apply_custom_da(self, sample): 128 | sample["image"] = self.da_function(sample["image"]) 129 | if rand() < 0.1: 130 | sample["image"] = torchvision.transforms.functional.hflip(sample["image"]) 131 | self.update_info_hflip(sample, sample["image"].shape[1:]) 132 | return sample 133 | 134 | def random_crop(self, sample, crop_size): 135 | h, w = sample["image"].shape[1:] 136 | diff_h = h - crop_size[0] 137 | diff_w = w - crop_size[1] 138 | assert diff_w >= 0 and diff_h >= 0 139 | start_h = randint(0, diff_h + 1) 140 | start_w = randint(0, diff_w + 1) 141 | return self.crop(sample, start_h, start_w, crop_size[0], crop_size[1]) 142 | 143 | def center_crop(self, sample, crop_size): 144 | h, w = sample["image"].shape[1:] 145 | diff_h = h - crop_size[0] 146 | diff_w = w - crop_size[1] 147 | assert diff_w >= 0 and diff_h >= 0 148 | start_h = diff_h // 2 149 | start_w = diff_w // 2 150 | return self.crop(sample, start_h, start_w, crop_size[0], crop_size[1]) 151 | 152 | def crop(self, sample, start_h, start_w, crop_height, crop_width): 153 | h, w = sample["image"].shape[1:] 154 | sample["image"] = sample["image"][:, start_h:start_h+crop_height, start_w:start_w+crop_width] 155 | self.update_info_crop(sample, (h, w), (crop_height, crop_width), start_h, start_w) 156 | return sample 157 | 158 | def resize(self, sample, new_size): 159 | h, w = sample["image"].shape[1:] 160 | sample["image"] = Resize(new_size, InterpolationMode.BILINEAR, antialias=True)(sample["image"]) 161 | self.update_info_resize(sample, (h, w), new_size) 162 | return sample 163 | 164 | def load_samples(self): 165 | samples = list() 166 | image_ids = 
self.get_image_ids() 167 | bar = tqdm(image_ids) 168 | bar.set_description("Loading {} samples".format(self.set_name)) 169 | for i, img_id in enumerate(bar): 170 | samples.append(self.load_sample(img_id, i)) 171 | return samples 172 | 173 | def load_sample(self, img_id, sample_id): 174 | img = self.get_image_by_id(img_id) 175 | class_id = self.image_id_class_id_mapping[img_id] 176 | sample = { 177 | "image": img, 178 | "class_id": class_id, 179 | "class_name": self.class_names[class_id], 180 | "image_id": img_id, 181 | "sample_id": sample_id, 182 | "image_name": img_id, 183 | } 184 | return sample 185 | 186 | def set_preprocess(self, preprocess_fn): 187 | self.preprocess_fn = Compose([tr for tr in preprocess_fn.transforms if not isinstance(tr, Normalize)]) 188 | self.norm_fn = [tr for tr in preprocess_fn.transforms if isinstance(tr, Normalize)][0] 189 | 190 | def load_oracle_segmentation(self, sample): 191 | items = list() 192 | for i in self.oracle_matching_keys.values(): 193 | items.extend(i) 194 | items = list(np.unique(items)) 195 | 196 | # load segmentation masks 197 | image_name = os.path.basename(sample["image_name"]).split(".")[0] 198 | size, shape = (sample["image"][:, :, 0].size, sample["image"][:, :, 0].shape) if len(sample["image"].shape) == 3 else (sample["image"].size, sample["image"].shape) 199 | segmentations_maps = dict() 200 | filepath = os.path.join("Oracle_seg", "merged", self.dataset_name, self.params["oracle"], "{}.npz".format(image_name)) 201 | masks = torch.load(filepath) 202 | for item in items: 203 | segmentations_maps[item] = torch.tensor(np.unpackbits(masks[item], count=size).reshape(shape).view(bool), dtype=torch.bool) 204 | 205 | # generate mask per attribute 206 | masks = list() 207 | for i in range(self.num_attr): 208 | attr_name = self.attr_names[i].split("::")[0] 209 | attr_words = attr_name\ 210 | .replace("has_shape", "birdshape")\ 211 | .replace("-", " ").replace("_", " ").split(" ") 212 | mask = torch.zeros((sample["image"].shape[:2]), dtype=torch.bool) 213 | for word in attr_words: 214 | if word in segmentations_maps.keys(): 215 | mask = torch.logical_or(mask, segmentations_maps[word]) 216 | if word not in self.oracle_matching_keys: 217 | continue 218 | for key in self.oracle_matching_keys[word]: 219 | mask = torch.logical_or(mask, segmentations_maps[key]) 220 | masks.append(mask) 221 | return torch.stack(masks, dim=0) 222 | 223 | def oracle_mask_to_patch_loc(self, sample): 224 | image_size = sample["image"].shape[-2:] 225 | patch_size = self.params["patch_size"] 226 | num_patches = image_size[0] // patch_size, image_size[1] // patch_size 227 | masks = sample["oracle_seg"] 228 | masks = masks.reshape(masks.size(0), num_patches[0], patch_size, num_patches[1], patch_size).permute(0, 1, 3, 2, 4) 229 | return torch.sum(masks, dim=[3, 4]) > 0.25*patch_size**2 230 | 231 | 232 | class GenericCollateFunction: 233 | 234 | def __init__(self): 235 | pass 236 | 237 | def __call__(self, batch_data): 238 | data = { 239 | "class_ids": torch.tensor([data["class_id"] for data in batch_data]), 240 | "class_names": [data["class_name"] for data in batch_data], 241 | "images": torch.stack([data["image"] for data in batch_data]), 242 | "sample_ids": [data["sample_id"] for data in batch_data], 243 | "sample_names": [data["image_name"] for data in batch_data], 244 | } 245 | if "oracle_attr_location" in batch_data[0]: 246 | data["oracle_attr_location"] = torch.stack([data["oracle_attr_location"] for data in batch_data], dim=0) 247 | return data 248 | 249 | 250 | # 
preprocessing 251 | class ToRGB(Module): 252 | 253 | def __init__(self): 254 | super().__init__() 255 | 256 | def forward(self, x): 257 | assert isinstance(x, torch.Tensor) and x.size(0) in [1, 3] 258 | if x.size(0) == 1: 259 | x = torch.cat([x, x, x], dim=0) 260 | return x 261 | 262 | -------------------------------------------------------------------------------- /Multitasking/main.py: -------------------------------------------------------------------------------- 1 | from torch.optim import Adam 2 | from Multitasking.dataset_utils.cub200_dataset import CUBDataset, check_or_download_CUB200_dataset, check_or_download_CUB200_oracle_seg 3 | from Multitasking.trainer import Trainer 4 | from Multitasking.utils import seed_all, none_or_str, arg_to_bool 5 | from Multitasking.models import check_or_download_model_weights 6 | import torch 7 | import argparse 8 | 9 | 10 | seed_all(0) 11 | 12 | 13 | def main(args): 14 | clip_model_name = args.model 15 | dataset_name = "CUB200" 16 | 17 | check_or_download_CUB200_dataset() 18 | 19 | # Get patch size from model name 20 | patch_size = clip_model_name.split("/")[1] if "/" in clip_model_name else None 21 | if "@" in patch_size: 22 | patch_size = patch_size.split("@")[0] 23 | patch_size = int(patch_size) if patch_size else None 24 | 25 | dataset_params = { 26 | "patch_size": patch_size if args.image_encoder is None else 32, 27 | "model_name": clip_model_name, 28 | "load_oracle_segmentation": args.train_oracle_loc, 29 | "oracle": args.oracle, 30 | "other_vision_encoder": args.image_encoder, 31 | "input_size": (224, 224) if args.image_encoder is None else (384, 384), 32 | } 33 | 34 | dataset_path = "./Datasets/{}".format(dataset_name) 35 | dataset_class = CUBDataset 36 | 37 | datasets = { 38 | "train": dataset_class(dataset_path, "train", dataset_params), 39 | "valid": dataset_class(dataset_path, "valid", dataset_params), 40 | } 41 | 42 | params = { 43 | "training": { 44 | "output_fold_name": "{}_{}".format(clip_model_name.replace("/", "-"), args.output_name), 45 | "load_weights": "last", 46 | "max_num_epochs": args.num_epochs, 47 | "eval_on_valid_interval": 2, 48 | "num_iter_display_value_update": 100, 49 | "gpu_index": "0", 50 | "use_amp": torch.cuda.is_available(), 51 | "batch_size": { 52 | "train": args.batch_size, 53 | "valid": 4*args.batch_size, 54 | }, 55 | "gradient_acc": 200, 56 | "metric_names": { 57 | "train": ["loss", "loss_class", "loss_attr", "loss_loc", "accuracy", "top5", "attr_mAP", "loc_mAP"], 58 | "valid": ["accuracy", "top5", "attr_mAP", "loc_mAP", ], 59 | }, 60 | "optimizer": { 61 | "class": Adam, 62 | "args": { 63 | "lr": 1e-5, 64 | "amsgrad": False, 65 | "betas": (0.9, 0.98), 66 | "eps": 1e-6, 67 | "weight_decay": 0.2 68 | } 69 | }, 70 | "train_class_image": args.train_class, 71 | "train_attr_image": args.train_attr, 72 | "train_loc": args.train_loc, 73 | "train_oracle_loc": args.train_oracle_loc, 74 | "train_class_linear": args.train_proj_class, 75 | "train_attr_linear": args.train_proj_attr, 76 | "loss_weights": { 77 | "class": args.weight_class, 78 | "attr": args.weight_attr, 79 | "loc": args.weight_loc, 80 | "proj_class": args.weight_proj_class, 81 | "proj_attr": args.weight_proj_attr, 82 | "oracle_loc": args.weight_oracle_loc, 83 | }, 84 | "metric_to_focus": "accuracy", 85 | "expected_metric_value": "high", 86 | "use_negative_attr": args.neg_attributes, 87 | }, 88 | "model": { 89 | "clip_model": clip_model_name, 90 | "classif_linear": False, 91 | "attr_linear": False, 92 | "input_size": dataset_params["input_size"], 93 | 
"config": { 94 | # goal: [transformer_name, freeze, part to freeze] 95 | "vision": ["clip_vision", args.adapter_image, "backbone"], 96 | "class": ["clip_text", args.adapter_text, 'backbone'], 97 | "attr": ["clip_text", args.adapter_text, "backbone"], 98 | }, 99 | } 100 | } 101 | if args.image_encoder is not None: 102 | params["model"]["config"]["vision"][0] = args.image_encoder 103 | 104 | if params["training"]["train_class_linear"]: 105 | params["training"]["metric_names"]["train"].append("loss_proj") 106 | params["model"]["classif_linear"] = True 107 | 108 | if params["training"]["train_attr_linear"]: 109 | params["training"]["metric_names"]["train"].append("loss_alpha") 110 | params["model"]["attr_linear"] = True 111 | 112 | if params["training"]["train_oracle_loc"]: 113 | params["training"]["metric_names"]["train"].append("loss_loc_oracle") 114 | 115 | trainer = Trainer(datasets, params) 116 | 117 | if args.train: 118 | trainer.train() 119 | 120 | if args.eval: 121 | trainer.free_memory() 122 | trainer.load_weights("last") 123 | trainer.evaluate_classification("valid", output=True) 124 | 125 | 126 | parser = argparse.ArgumentParser() 127 | parser.add_argument("--output-name", type=str, default="my_expe") 128 | parser.add_argument("--batch-size", type=int, default=2) 129 | parser.add_argument("--num-epochs", type=int, default=100) 130 | parser.add_argument("--oracle", type=str, default="OFA") 131 | 132 | parser.add_argument("--model", type=str, default="open-ViT-L/14") 133 | parser.add_argument("--image-encoder", type=none_or_str, default="swin_vision") 134 | parser.add_argument("--load-pretrain", type=arg_to_bool, default=False) 135 | 136 | parser.add_argument("--train", type=arg_to_bool, default=True) 137 | parser.add_argument("--eval", type=arg_to_bool, default=True) 138 | 139 | parser.add_argument("--train-class", type=arg_to_bool, default=True) 140 | parser.add_argument("--train-attr", type=arg_to_bool, default=True) 141 | parser.add_argument("--train-loc", type=arg_to_bool, default=True) 142 | parser.add_argument("--train-oracle-loc", type=arg_to_bool, default=False) 143 | parser.add_argument("--train-proj-class", type=arg_to_bool, default=True) 144 | parser.add_argument("--train-proj-attr", type=arg_to_bool, default=False) 145 | 146 | parser.add_argument("--weight-class", type=int, default=1) 147 | parser.add_argument("--weight-attr", type=int, default=1) 148 | parser.add_argument("--weight-loc", type=int, default=1) 149 | parser.add_argument("--weight-oracle-loc", type=int, default=1) 150 | parser.add_argument("--weight-proj-attr", type=int, default=1) 151 | parser.add_argument("--weight-proj-class", type=int, default=1) 152 | 153 | parser.add_argument("--adapter-image", type=arg_to_bool, default=False) 154 | parser.add_argument("--adapter-text", type=arg_to_bool, default=True) 155 | 156 | parser.add_argument("--neg-attributes", type=arg_to_bool, default=True) 157 | 158 | args = parser.parse_args() 159 | 160 | if args.load_pretrain: 161 | args.output_name = "swin_clip_text_finetuned" 162 | check_or_download_model_weights() 163 | 164 | if args.train_oracle_loc: 165 | check_or_download_CUB200_oracle_seg() 166 | 167 | main(args) 168 | -------------------------------------------------------------------------------- /Multitasking/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class MetricManager: 6 | 7 | def __init__(self, metric_names, output_path, dataset): 8 | self.dataset = dataset 9 
| self.metric_names = metric_names 10 | self.output_path = output_path 11 | self.epoch_metrics = None 12 | 13 | self.linked_metrics = { 14 | "attr_mAP": ["attr_AP_by_mask", "num_attr_by_mask", "attr_classes"], 15 | "attr_mAP_linear": ["attr_mAP_linear_attr_AP_by_mask", "attr_mAP_linear_num_attr_by_mask", "attr_classes"], 16 | } 17 | 18 | self.init_metrics() 19 | 20 | def init_metrics(self): 21 | self.epoch_metrics = { 22 | "nb_samples": list(), 23 | "sample_ids": list(), 24 | "sample_names": list(), 25 | "preds_classif": list(), 26 | "preds_attr": list() 27 | 28 | } 29 | 30 | for metric_name in self.metric_names: 31 | if metric_name in self.linked_metrics: 32 | for linked_metric_name in self.linked_metrics[metric_name]: 33 | if linked_metric_name not in self.epoch_metrics.keys(): 34 | self.epoch_metrics[linked_metric_name] = list() 35 | else: 36 | self.epoch_metrics[metric_name] = list() 37 | 38 | def add_batch_values(self, batch_values): 39 | batch_metrics = self.compute_metrics_from_batch_values(batch_values) 40 | for key in batch_metrics.keys(): 41 | if key in self.epoch_metrics: 42 | self.epoch_metrics[key] += batch_metrics[key] 43 | 44 | def compute_metrics_from_batch_values(self, batch_values): 45 | metrics = dict() 46 | for v in batch_values.keys(): 47 | metrics[v] = batch_values[v] 48 | for metric_name in self.metric_names: 49 | if metric_name == "accuracy": 50 | metrics[metric_name] = accuracy(metrics["preds_classif"], metrics["gt_classif"]) 51 | elif metric_name == "top5": 52 | metrics[metric_name] = topk_accuracy(metrics["preds_classif"], metrics["gt_classif"], k=5) 53 | elif "accuracy_" in metric_name: 54 | values = metrics[metric_name] 55 | metrics[metric_name] = accuracy(values, metrics["gt_classif"]) 56 | name = "top5_" + "_".join(metric_name.split("_")[1:]) 57 | metrics[name] = topk_accuracy(values, metrics["gt_classif"], k=5) 58 | elif metric_name == "attr_mAP": 59 | metrics = metrics | compute_sample_mAP(metrics["preds_attr"], metrics["gt_attr"], metrics['certainty_masks']) 60 | elif "attr_mAP_" in metric_name: 61 | temp = compute_sample_mAP(metrics[metric_name], metrics["gt_attr"], metrics['certainty_masks']) 62 | metrics["{}_num_attr_by_mask".format(metric_name)] = temp["num_attr_by_mask"] 63 | metrics["{}_attr_AP_by_mask".format(metric_name)] = temp["attr_AP_by_mask"] 64 | elif metric_name == "loc_mAP": 65 | metrics[metric_name] = compute_loc_mAP(metrics["preds_loc"], metrics["gt_loc"]) 66 | return metrics 67 | 68 | def get_display_values(self): 69 | metric_names = self.metric_names.copy() 70 | display_values = dict() 71 | 72 | for metric_name in metric_names: 73 | if metric_name in ["accuracy", "top5"]: 74 | value = 100 * np.mean(self.epoch_metrics[metric_name]) 75 | elif "accuracy_" in metric_name or 'top5_' in metric_name: 76 | value = 100 * np.mean(self.epoch_metrics[metric_name]) 77 | elif metric_name in ["attr_mAP", ]: 78 | value = 100 * compute_mAP_by_class_and_global(self.epoch_metrics["attr_AP_by_mask"], self.epoch_metrics["num_attr_by_mask"], self.epoch_metrics["attr_classes"]) 79 | elif "attr_mAP_" in metric_name: 80 | value = 100 * compute_mAP_by_class_and_global(self.epoch_metrics["{}_attr_AP_by_mask".format(metric_name)], self.epoch_metrics["{}_num_attr_by_mask".format(metric_name)], self.epoch_metrics["attr_classes"]) 81 | elif metric_name in ["loc_mAP", ]: 82 | value, _ = compute_loc_mAP_by_attr_and_global(self.epoch_metrics["loc_mAP"]) 83 | value = 100 * value 84 | elif "loss" in metric_name: 85 | value = None 86 | if 
len(self.epoch_metrics[metric_name]) > 0: 87 | mask = np.array(self.epoch_metrics[metric_name]) != None 88 | weights = np.array(self.epoch_metrics["nb_samples"])[mask] 89 | if np.sum(weights) > 0: 90 | value = np.average(np.array(self.epoch_metrics[metric_name])[mask], weights=weights) 91 | else: 92 | value = self.epoch_metrics[metric_name] 93 | 94 | display_values[metric_name] = round(value, 2) if value is not None else None 95 | 96 | return display_values 97 | 98 | 99 | def compute_sample_mAP(preds, gt, certainty_masks): 100 | num_samples = gt.size(0) 101 | num_masks = certainty_masks.size(0) 102 | AP = torch.zeros((num_samples, num_masks)) 103 | num_attr = torch.zeros((num_samples, num_masks)) 104 | for i, pred in enumerate(preds): 105 | # ordering predictions 106 | pred_attr, order_attr = preds[i].sort(descending=True) 107 | gt_attr = gt[i, order_attr] 108 | mask_attr = certainty_masks[:, i, order_attr].bool() 109 | for m in range(num_masks): 110 | m_gt = gt_attr[mask_attr[m]] 111 | num_correct = torch.sum(m_gt) 112 | if num_correct <= 0: 113 | continue 114 | precision = torch.cumsum(m_gt, dim=0) / torch.arange(1, m_gt.size(0)+1) 115 | recall = torch.cumsum(m_gt, dim=0) / num_correct 116 | max_precision = torch.cummax(precision.flip(dims=(0, )), dim=0)[0].flip(dims=(0, )) 117 | shift_recall = torch.clone(recall) 118 | shift_recall[1:] = shift_recall[:-1].clone() 119 | shift_recall[0] = 0 120 | recall_diff = recall - shift_recall 121 | AP[i, m] = torch.dot(max_precision, recall_diff) 122 | num_attr[i, m] = m_gt.size(0) 123 | return { 124 | "num_attr_by_mask": num_attr.tolist(), 125 | "attr_AP_by_mask": AP.tolist(), 126 | } 127 | 128 | 129 | def compute_mAP_by_class_and_global(AP_by_sample_by_mask, num_attr_by_sample_by_mask, classes): 130 | num_masks = len(AP_by_sample_by_mask[0]) 131 | classes = torch.tensor(classes) 132 | AP_by_sample_by_mask = torch.tensor(AP_by_sample_by_mask, dtype=torch.float) 133 | num_attr_by_sample_by_mask = torch.tensor(num_attr_by_sample_by_mask, dtype=torch.float) 134 | class_unique = torch.unique(classes) 135 | AP_by_class_by_mask = dict() 136 | for class_id in class_unique: 137 | mask = classes == class_id 138 | APs = AP_by_sample_by_mask[mask] 139 | num_attrs = num_attr_by_sample_by_mask[mask] 140 | total_attrs = torch.sum(num_attrs, dim=0) 141 | class_AP = torch.bmm(APs.permute(1, 0).unsqueeze(1), num_attrs.permute(1, 0).unsqueeze(2)).view(num_masks) / total_attrs 142 | class_AP[total_attrs <= 0] = 0 143 | AP_by_class_by_mask[int(class_id)] = (class_AP, total_attrs) 144 | 145 | APs = torch.stack([AP_by_class_by_mask[k][0] for k in AP_by_class_by_mask.keys()], dim=0) 146 | num_attrs = torch.stack([AP_by_class_by_mask[k][1] for k in AP_by_class_by_mask.keys()], dim=0) 147 | total_attrs = torch.sum(num_attrs, dim=0) 148 | mAP = torch.bmm(APs.permute(1, 0).unsqueeze(1), num_attrs.permute(1, 0).unsqueeze(2)).view(num_masks) / total_attrs 149 | 150 | for k in AP_by_class_by_mask: 151 | AP_by_class_by_mask[k] = AP_by_class_by_mask[k][0].numpy() 152 | 153 | return mAP[-1].item() 154 | 155 | 156 | def compute_loc_mAP_by_attr_and_global(list_AP): 157 | # list_AP (N x A) 158 | # N: num samples 159 | # A: nm attributes 160 | list_AP = np.array(list_AP) 161 | mask = list_AP == None 162 | num_AP_by_attr = np.sum(~mask, axis=0) 163 | list_AP[mask] = 0 164 | list_AP = list_AP.astype(float) 165 | mAP_by_attr = np.sum(list_AP, axis=0) / num_AP_by_attr 166 | mask_nan = np.isnan(mAP_by_attr) 167 | mAP_by_attr[mask_nan] = 0 168 | mAP = np.sum(mAP_by_attr) / np.sum(~mask_nan) 
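    # Illustrative note: the per-sample/per-attribute APs aggregated here come from the interpolated
    # average-precision computation used throughout this module (compute_sample_mAP, compute_mAP,
    # compute_batch_mAP). Worked micro-example with ground truth [1, 0, 1] after sorting by
    # descending score:
    #   precision = [1, 1/2, 2/3], recall = [1/2, 1/2, 1]
    #   interpolated precision (cummax from the right) = [1, 2/3, 2/3]
    #   recall increments = [1/2, 0, 1/2]  ->  AP = 1*(1/2) + (2/3)*0 + (2/3)*(1/2) = 5/6
    # None entries (per sample-attribute pair in list_AP, and per attribute in the returned
    # mAP_by_attr) mark cases with no positive ground truth; they are excluded from the averages
    # computed above rather than being counted as zero.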
169 | mAP_by_attr = mAP_by_attr.astype(object) 170 | mAP_by_attr[mask_nan] = None 171 | return mAP, mAP_by_attr 172 | 173 | 174 | def compute_loc_mAP(preds, gts): 175 | preds = preds.permute(0, 2, 1) 176 | gts = gts.permute(0, 2, 1) 177 | B, A, L = preds.size() 178 | AP = compute_batch_mAP(preds.reshape(B * A, L), gts.reshape(B * A, L)) 179 | loc_AP = np.array(AP).reshape((B, A)).tolist() 180 | return loc_AP 181 | 182 | 183 | def compute_mAP(scores, gt): 184 | if torch.sum(gt) == 0: 185 | return None 186 | ordered_scores, indices = torch.sort(scores, descending=True) 187 | ordered_gt = gt[indices] 188 | num_correct = torch.sum(gt) 189 | precision = torch.cumsum(ordered_gt, dim=0) / torch.arange(1, ordered_gt.size(0) + 1) 190 | recall = torch.cumsum(ordered_gt, dim=0) / num_correct 191 | max_precision = torch.cummax(precision.flip(dims=(0,)), dim=0)[0].flip(dims=(0,)) 192 | shift_recall = torch.clone(recall) 193 | shift_recall[1:] = shift_recall[:-1].clone() 194 | shift_recall[0] = 0 195 | recall_diff = recall - shift_recall 196 | return float(torch.dot(max_precision, recall_diff)) 197 | 198 | 199 | def compute_batch_mAP(scores, gt): 200 | ordered_scores, indices = torch.sort(scores, descending=True, dim=1) 201 | ordered_gt = torch.stack([gt[i][indices[i]] for i in range(gt.size(0))], dim=0) 202 | num_correct = torch.sum(gt, dim=1) 203 | precision = torch.cumsum(ordered_gt, dim=1) / torch.arange(1, ordered_gt.size(1) + 1) 204 | recall = torch.cumsum(ordered_gt, dim=1) / num_correct.unsqueeze(1) 205 | max_precision = torch.cummax(precision.flip(dims=(1,)), dim=1)[0].flip(dims=(1,)) 206 | shift_recall = torch.clone(recall) 207 | shift_recall[:, 1:] = shift_recall[:, :-1].clone() 208 | shift_recall[:, 0] = 0 209 | recall_diff = recall - shift_recall 210 | batch_AP = (max_precision.unsqueeze(1) @ recall_diff.unsqueeze(2)).squeeze(2).squeeze(1) 211 | batch_AP = [float(ap) if c > 0 else None for ap, c in zip(batch_AP, num_correct)] 212 | return batch_AP 213 | 214 | 215 | def compute_top1_by_attr_and_global(matching, global_num_attr, k): 216 | weights_by_attr = np.array([a[k] for a in global_num_attr]) 217 | value_by_attr = np.array([a[k] for a in matching]) 218 | num_attr = weights_by_attr.shape[1] 219 | avg_by_attr = np.array([np.average([a[i] for a in value_by_attr], weights=[a[i] for a in weights_by_attr]) if np.sum([a[i] for a in weights_by_attr]) > 0 else None for i in range(num_attr)]) 220 | global_weights_by_attr = np.array([np.sum([w[i] for w in weights_by_attr]) for i in range(num_attr)]) 221 | mask = np.array([a != None for a in avg_by_attr]) 222 | global_avg = 100*np.average(avg_by_attr[mask], weights=global_weights_by_attr[mask]) if np.sum(global_weights_by_attr[mask]) > 0 else 0 223 | avg_by_attr = [100 * a if a is not None else a for a in avg_by_attr] 224 | avg_weights_by_attr = [np.mean(a[a > 0]) if (a > 0).max() else None for a in weights_by_attr.T] 225 | return avg_by_attr, avg_weights_by_attr, global_avg 226 | 227 | 228 | def topk_attr_by_mask(prediction, gt, certainty_masks, k=5): 229 | batch_size = prediction.size(0) 230 | num_mask = certainty_masks.size(0) 231 | num_attr_by_mask = certainty_masks.sum(dim=2) 232 | masked_similarity = certainty_masks * prediction 233 | topk = masked_similarity.topk(k, dim=2)[1] 234 | masked_matching = torch.zeros(topk.size()) 235 | metrics = dict() 236 | for k in range(num_mask): 237 | for i in range(batch_size): 238 | masked_matching[k][i] = gt[i].index_select(0, topk[k, i]) 239 | metrics["attr_top1_by_mask"] = masked_matching[:, :, 
0].permute(1, 0).tolist() 240 | metrics["attr_top5_by_mask"] = torch.mean(masked_matching, dim=2).permute(1, 0).tolist() 241 | metrics["num_attr_by_mask"] = num_attr_by_mask.permute(1, 0).tolist() 242 | return metrics 243 | 244 | 245 | def topk_attr_mul_by_mask(attr_bin_mul_mapping, prediction, gt, certainty_masks): 246 | batch_size = prediction.size(0) 247 | num_mask = certainty_masks.size(0) 248 | masked_similarity = certainty_masks * prediction 249 | gt_by_mul_attr = [gt.index_select(dim=1, index=torch.tensor(attr_bin_mul_mapping[mul_id])) for mul_id in sorted(attr_bin_mul_mapping.keys())] 250 | gt_num_correct_by_mask = torch.stack([torch.sum(gt_mul, dim=1)for gt_mul in gt_by_mul_attr], dim=0).permute(1, 0) 251 | similarity_by_mul_attr_by_mask = [masked_similarity.index_select(dim=2, index=torch.tensor(attr_bin_mul_mapping[mul_id])) for mul_id in sorted(attr_bin_mul_mapping.keys())] 252 | num_attr_by_mask = torch.stack([(sim != 0).sum(dim=2) for sim in similarity_by_mul_attr_by_mask], dim=2) 253 | num_attr = len(gt_by_mul_attr) 254 | matching = torch.zeros((num_mask, batch_size, num_attr)) 255 | 256 | for m in range(num_mask): 257 | for b in range(batch_size): 258 | for a in range(num_attr): 259 | matching[m, b, a] = gt_by_mul_attr[a][b, similarity_by_mul_attr_by_mask[a][m, b, :].argmax()] 260 | metrics = dict() 261 | metrics["attr_mul_top1_by_mask"] = [matching[:, b, :].tolist() for b in range(batch_size)] 262 | metrics["num_attr_mul_by_attr_by_mask"] = [(torch.stack([torch.logical_and(gt_num_correct_by_mask[b] >= 1, num_attr_by_mask[m, b] >= 1) for m in range(num_mask)] , dim=0) * num_attr_by_mask[:, b, :]).tolist() for b in range(batch_size)] 263 | return metrics 264 | 265 | 266 | def accuracy(prediction, gt): 267 | return (torch.argmax(prediction, dim=1) == gt).int().numpy().tolist() 268 | 269 | 270 | def topk_accuracy(prediction, gt, k=5): 271 | return torch.sum(prediction.topk(k, dim=1)[1] == gt.unsqueeze(1), dim=1).numpy().tolist() 272 | 273 | -------------------------------------------------------------------------------- /Multitasking/models.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torch 3 | from torch.nn import Module, ModuleDict 4 | from torch.nn import Linear 5 | from Multitasking.clip_load import load 6 | 7 | from OPENCLIP.factory import create_model_and_transforms 8 | from Multitasking.swin_models import SwinTransformer 9 | import wget 10 | 11 | 12 | OPEN_MODELS = { 13 | "open-ViT-L/14": 'hf-hub:laion/CLIP-ViT-L-14-laion2B-s32B-b82K', 14 | "open-ViT-H/14": 'hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K', 15 | "open-ViT-G/14": 'hf-hub:laion/CLIP-ViT-bigG-14-laion2B-39B-b160k', 16 | } 17 | 18 | 19 | class ClipManager(Module): 20 | 21 | def __init__(self, params): 22 | super(ClipManager, self).__init__() 23 | self.params = params 24 | self.attr_linear = None 25 | self.classif_linear = None 26 | 27 | if params["clip_model"] in ["RN50", "RN101", "RN50x4", "RN50x16", "RN50x64", "ViT-B/32", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px"]: 28 | self.clip, preprocess = load(params["clip_model"], self.params, device=params["device"], jit=False, download_root="./Models") 29 | self.preprocess = { 30 | "train": preprocess, 31 | "eval": preprocess 32 | } 33 | self.open_model = False 34 | else: 35 | self.clip, preprocess_train, preprocess_val = create_model_and_transforms(OPEN_MODELS[params["clip_model"]], device=params["device"], jit=False, cache_dir="./Models") 36 | self.preprocess = { 37 | "train": preprocess_train, 38 
| "eval": preprocess_val 39 | } 40 | self.open_model = True 41 | 42 | self.params["emb_dim"] = self.clip.visual.output_dim 43 | self.params["text_dim"] = self.clip.text_projection.size(0) 44 | 45 | self.additional_vision_models = ModuleDict() 46 | for key in self.params["config"].keys(): 47 | name, freeze, model_part = self.params["config"][key] 48 | if name == "swin_vision": 49 | self.additional_vision_models[key] = SwinTransformer(num_classes=self.params["emb_dim"], img_size=params["input_size"][0], window_size=12, embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48]) 50 | swin_path = "Models/swin_large_patch4_window12_384_22k.pth" 51 | if not os.path.exists(swin_path): 52 | wget.download("https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth", out="Models") 53 | weights = torch.load(swin_path, map_location="cpu")["model"] 54 | del weights["head.weight"] 55 | del weights["head.bias"] 56 | self.additional_vision_models[key].load_state_dict(weights, strict=False) 57 | self.clip.visual = None 58 | 59 | self.params["visual_dim"] = self.clip.visual.proj.size(0) if self.params["config"]["vision"][0] != "swin_vision" else self.additional_vision_models["vision"].head.weight.size(1) 60 | 61 | if params["classif_linear"]: 62 | self.classif_linear = Linear(self.params["emb_dim"], self.params["num_classes"]) 63 | 64 | if params["attr_linear"]: 65 | self.attr_linear = Linear(self.params["emb_dim"], self.params["num_attributes"]) 66 | 67 | self.visual_loc_proj = Linear(self.params["visual_dim"], self.params["emb_dim"]) 68 | 69 | 70 | @property 71 | def dtype(self): 72 | return self.clip.visual.conv1.weight.dtype if self.clip.visual is not None else self.additional_vision_models["vision"].head.weight.dtype 73 | 74 | def normalize_embedding(self, emb): 75 | return emb / emb.norm(dim=-1, keepdim=True) 76 | 77 | def encode_text(self, text, return_after=None): 78 | pos = text.argmax(dim=-1) 79 | seq = self.clip.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 80 | seq = seq + self.clip.positional_embedding.type(self.dtype) 81 | 82 | seq = seq.permute(1, 0, 2) # NLD -> LND 83 | seq = self.clip.transformer.all_but_last_layer(seq, attn_mask=None if not self.open_model else self.clip.attn_mask) 84 | seq = seq.permute(1, 0, 2) 85 | 86 | if return_after == "transformer-1": 87 | return seq, seq, pos 88 | 89 | seq = seq.permute(1, 0, 2) 90 | seq = self.clip.transformer.last_layer(seq, attn_mask=None if not self.open_model else self.clip.attn_mask) 91 | seq = seq.permute(1, 0, 2) # LND -> NLD 92 | 93 | seq = self.clip.ln_final(seq).type(self.dtype) 94 | 95 | # x.shape = [batch_size, n_ctx, transformer.width] 96 | # take features from the eot embedding (eot_token is the highest number in each sequence) 97 | x = seq[torch.arange(seq.shape[0]), pos] 98 | if return_after == "transformer": 99 | return x, seq, pos 100 | 101 | x = x @ self.clip.text_projection 102 | if return_after == "proj": 103 | return x, seq, pos 104 | 105 | x = self.normalize_embedding(x) 106 | return x, seq, pos 107 | 108 | def proj_text(self, x, seq, pos, features_from=None): 109 | last_tr = proj = norm = False 110 | if features_from == "transformer-1": 111 | last_tr = proj = norm = True 112 | elif features_from == "transformer": 113 | proj = norm = True 114 | elif features_from == "proj": 115 | norm = True 116 | 117 | if last_tr: 118 | seq = seq.permute(1, 0, 2) 119 | seq = self.clip.transformer.last_layer(seq, attn_mask=None if not self.open_model else 
self.clip.attn_mask) 120 | seq = seq.permute(1, 0, 2) # LND -> NLD 121 | seq = self.clip.ln_final(seq).type(self.dtype) 122 | 123 | # x.shape = [batch_size, n_ctx, transformer.width] 124 | # take features from the eot embedding (eot_token is the highest number in each sequence) 125 | x = seq[torch.arange(seq.shape[0]), pos] 126 | 127 | if proj: 128 | x = x @ self.clip.text_projection 129 | 130 | if norm: 131 | x = self.normalize_embedding(x) 132 | 133 | return x, seq 134 | 135 | def encode_image(self, x, return_after=None): 136 | x = x.type(self.dtype) 137 | 138 | if "vision" in self.additional_vision_models: 139 | return self.additional_vision_models["vision"](x, return_after=return_after) 140 | 141 | if self.open_model and self.clip.visual.input_patchnorm: 142 | # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') 143 | x = x.reshape(x.shape[0], x.shape[1], self.clip.visual.grid_size[0], self.clip.visual.patch_size[0], self.clip.visual.grid_size[1], 144 | self.clip.visual.patch_size[1]) 145 | x = x.permute(0, 2, 4, 1, 3, 5) 146 | x = x.reshape(x.shape[0], self.clip.visual.grid_size[0] * self.clip.visual.grid_size[1], -1) 147 | x = self.clip.visual.patchnorm_pre_ln(x) 148 | x = self.clip.visual.conv1(x) 149 | else: 150 | x = self.clip.visual.conv1(x) # shape = [*, width, grid, grid] 151 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 152 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 153 | 154 | x = torch.cat([self.clip.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 155 | x = x + self.clip.visual.positional_embedding.to(x.dtype) 156 | 157 | if self.open_model: 158 | x = self.clip.visual.patch_dropout(x) 159 | 160 | x = self.clip.visual.ln_pre(x) 161 | 162 | x = x.permute(1, 0, 2) # NLD -> LND 163 | seq = self.clip.visual.transformer.all_but_last_layer(x) 164 | seq = seq.permute(1, 0, 2) 165 | 166 | if return_after == "transformer-1": 167 | return seq, seq 168 | 169 | seq = seq.permute(1, 0, 2) 170 | seq = self.clip.visual.transformer.last_layer(seq) 171 | seq = seq.permute(1, 0, 2) # LND -> NLD 172 | 173 | if self.open_model: 174 | if self.clip.visual.attn_pool is not None: 175 | seq = self.clip.visual.attn_pool(seq) 176 | seq = self.clip.visual.ln_post(seq) 177 | x = self.clip.visual._global_pool(seq) 178 | else: 179 | x = self.clip.visual._global_pool(seq) 180 | x = self.clip.visual.ln_post(x) 181 | seq = self.clip.visual.ln_post(seq) 182 | else: 183 | seq = self.clip.visual.ln_post(seq) 184 | x = seq[:, 0, :] 185 | 186 | if return_after == "transformer": 187 | return x, seq 188 | 189 | if self.clip.visual.proj is not None: 190 | x = x @ self.clip.visual.proj 191 | 192 | if return_after == "proj": 193 | return x, seq 194 | 195 | return x, seq 196 | 197 | def proj_vision(self, x, seq, features_from=None): 198 | if "vision" in self.additional_vision_models: 199 | return self.additional_vision_models["vision"].proj(x, seq, features_from) 200 | 201 | last_tr = proj = False 202 | if features_from == "transformer-1": 203 | last_tr = proj = True 204 | elif features_from == "transformer": 205 | proj = True 206 | 207 | if last_tr: 208 | seq = seq.permute(1, 0, 2) 209 | seq = self.clip.visual.transformer.last_layer(seq) 210 | seq = seq.permute(1, 0, 2) # LND -> NLD 211 | 212 | if self.open_model: 213 | if self.clip.visual.attn_pool is not None: 214 | seq = self.clip.visual.attn_pool(seq) 215 | seq = self.clip.visual.ln_post(seq) 216 | x = 
self.clip.visual._global_pool(seq) 217 | else: 218 | x = self.clip.visual._global_pool(seq) 219 | x = self.clip.visual.ln_post(x) 220 | seq = self.clip.visual.ln_post(seq) 221 | else: 222 | seq = self.clip.visual.ln_post(seq) 223 | x = seq[:, 0, :] 224 | 225 | if proj: 226 | x = x @ self.clip.visual.proj 227 | 228 | return x, seq 229 | 230 | def proj_visual_for_loc(self, x): 231 | return self.visual_loc_proj(x) 232 | 233 | def compute_class_linear_scores(self, vision_embedding): 234 | return self.classif_linear(vision_embedding) 235 | 236 | def compute_attr_linear_scores(self, vision_embedding): 237 | return self.attr_linear(vision_embedding) 238 | 239 | 240 | def check_or_download_model_weights(): 241 | fold = os.path.join("Results", "open-ViT-L-14_swin_clip_text_finetuned", "model") 242 | os.makedirs(fold, exist_ok=True) 243 | if len(os.listdir(fold)) == 0: 244 | print("Downloading pre-trained model weights, this can take a moment...") 245 | wget.download("https://zenodo.org/record/8124014/files/last_100.pt?download=1", out=fold) 246 | print("Download completed") -------------------------------------------------------------------------------- /Multitasking/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import os 5 | from Multitasking.models import OPEN_MODELS 6 | from clip.simple_tokenizer import SimpleTokenizer 7 | from OPENCLIP.factory import get_tokenizer as open_get_tokenizer 8 | 9 | 10 | def randint(low, high): 11 | """ 12 | call torch.randint to preserve random among dataloader workers 13 | """ 14 | return int(torch.randint(low, high, (1, ))) 15 | 16 | 17 | def rand(): 18 | """ 19 | call torch.rand to preserve random among dataloader workers 20 | """ 21 | return float(torch.rand((1, ))) 22 | 23 | 24 | def pad_images(data, padding_value): 25 | """ 26 | data: list of numpy array 27 | """ 28 | x_lengths = [x.shape[0] for x in data] 29 | y_lengths = [x.shape[1] for x in data] 30 | longest_x = max(x_lengths) 31 | longest_y = max(y_lengths) 32 | padded_data = np.ones((len(data), longest_x, longest_y, data[0].shape[2])) * padding_value 33 | for i, xy_len in enumerate(zip(x_lengths, y_lengths)): 34 | x_len, y_len = xy_len 35 | padded_data[i, :x_len, :y_len, ...] 
= data[i] 36 | return padded_data 37 | 38 | 39 | def seed_all(seed): 40 | random.seed(seed) 41 | os.environ['PYTHONHASHSEED'] = str(seed) 42 | np.random.seed(seed) 43 | torch.manual_seed(seed) 44 | torch.cuda.manual_seed(seed) 45 | torch.backends.cudnn.deterministic = True 46 | torch.backends.cudnn.benchmark = True 47 | 48 | 49 | def get_tokenizer(model_name): 50 | if "open-" in model_name: 51 | return open_get_tokenizer(OPEN_MODELS[model_name].replace("open-", ""), cache_dir="Tokenizers").tokenizer 52 | return SimpleTokenizer() 53 | 54 | 55 | def arg_to_bool(arg): 56 | return arg.lower() == "true" 57 | 58 | 59 | def none_or_str(value): 60 | if value.lower() == 'none': 61 | return None 62 | return value -------------------------------------------------------------------------------- /OPENCLIP/__init__.py: -------------------------------------------------------------------------------- 1 | from .coca_model import CoCa 2 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 3 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss 4 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 5 | from .loss import ClipLoss, DistillClipLoss, CoCaLoss 6 | from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ 7 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype 8 | from .openai import load_openai_model, list_openai_models 9 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ 10 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 11 | from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub 12 | from .tokenizer import SimpleTokenizer, tokenize, decode 13 | from .transform import image_transform, AugmentationCfg 14 | -------------------------------------------------------------------------------- /OPENCLIP/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FactoDeepLearning/MultitaskVLFM/cde10ef783a5cdeaf1e06016560dfb6cc2a3ffa2/OPENCLIP/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /OPENCLIP/coca_model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | import numpy as np 7 | from dataclasses import dataclass 8 | 9 | from .transformer import ( 10 | LayerNormFp32, 11 | LayerNorm, 12 | QuickGELU, 13 | MultimodalTransformer, 14 | ) 15 | from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower 16 | 17 | try: 18 | from transformers import ( 19 | BeamSearchScorer, 20 | LogitsProcessorList, 21 | TopPLogitsWarper, 22 | TopKLogitsWarper, 23 | RepetitionPenaltyLogitsProcessor, 24 | MinLengthLogitsProcessor, 25 | MaxLengthCriteria, 26 | StoppingCriteriaList 27 | ) 28 | 29 | GENERATION_TYPES = { 30 | "top_k": TopKLogitsWarper, 31 | "top_p": TopPLogitsWarper, 32 | "beam_search": "beam_search" 33 | } 34 | except ImportError as e: 35 | pass 36 | 37 | 38 | 39 | @dataclass 40 | class MultimodalCfg(CLIPTextCfg): 41 | mlp_ratio: int = 4 42 | dim_head: int = 64 43 | heads: int = 8 44 | n_queries: int = 256 45 | attn_pooler_heads: int = 8 46 | 47 | 48 | def _build_text_decoder_tower( 49 | embed_dim, 50 | 
multimodal_cfg, 51 | quick_gelu: bool = False, 52 | cast_dtype: Optional[torch.dtype] = None, 53 | ): 54 | multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg 55 | act_layer = QuickGELU if quick_gelu else nn.GELU 56 | norm_layer = ( 57 | LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm 58 | ) 59 | 60 | decoder = MultimodalTransformer( 61 | context_length=multimodal_cfg.context_length, 62 | width=multimodal_cfg.width, 63 | heads=multimodal_cfg.heads, 64 | layers=multimodal_cfg.layers, 65 | ls_init_value=multimodal_cfg.ls_init_value, 66 | output_dim=embed_dim, 67 | act_layer=act_layer, 68 | norm_layer=norm_layer, 69 | ) 70 | 71 | return decoder 72 | 73 | 74 | class CoCa(nn.Module): 75 | def __init__( 76 | self, 77 | embed_dim, 78 | multimodal_cfg: MultimodalCfg, 79 | text_cfg: CLIPTextCfg, 80 | vision_cfg: CLIPVisionCfg, 81 | quick_gelu: bool = False, 82 | cast_dtype: Optional[torch.dtype] = None, 83 | pad_id: int = 0, 84 | ): 85 | super().__init__() 86 | multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg 87 | text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg 88 | vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg 89 | 90 | self.text = _build_text_tower( 91 | embed_dim=embed_dim, 92 | text_cfg=text_cfg, 93 | quick_gelu=quick_gelu, 94 | cast_dtype=cast_dtype, 95 | ) 96 | 97 | vocab_size = ( 98 | text_cfg.vocab_size # for hf models 99 | if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None 100 | else text_cfg.vocab_size 101 | ) 102 | 103 | self.visual = _build_vision_tower( 104 | embed_dim=embed_dim, 105 | vision_cfg=vision_cfg, 106 | quick_gelu=quick_gelu, 107 | cast_dtype=cast_dtype, 108 | ) 109 | 110 | self.text_decoder = _build_text_decoder_tower( 111 | vocab_size, 112 | multimodal_cfg=multimodal_cfg, 113 | quick_gelu=quick_gelu, 114 | cast_dtype=cast_dtype, 115 | ) 116 | 117 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 118 | self.pad_id = pad_id 119 | 120 | @torch.jit.ignore 121 | def set_grad_checkpointing(self, enable=True): 122 | self.visual.set_grad_checkpointing(enable) 123 | self.text.set_grad_checkpointing(enable) 124 | self.text_decoder.set_grad_checkpointing(enable) 125 | 126 | def _encode_image(self, images, normalize=True): 127 | image_latent, tokens_embs = self.visual(images) 128 | image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent 129 | return image_latent, tokens_embs 130 | 131 | def _encode_text(self, text, normalize=True, embed_cls=True): 132 | text = text[:, :-1] if embed_cls else text # make space for CLS token 133 | text_latent, token_emb = self.text(text) 134 | text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent 135 | return text_latent, token_emb 136 | 137 | def encode_image(self, images, normalize=True): 138 | image_latent, _ = self._encode_image(images, normalize=normalize) 139 | return image_latent 140 | 141 | def encode_text(self, text, normalize=True, embed_cls=True): 142 | text_latent, _ = self._encode_text(text, normalize=normalize, embed_cls=embed_cls) 143 | return text_latent 144 | 145 | def forward(self, image, text, embed_cls=True, image_latent=None, image_embs=None): 146 | text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls) 147 | if image_latent is None or image_embs is None: 148 | image_latent, image_embs = self._encode_image(image) 149 | 
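# Note on the label slice just below: with embed_cls=True (the default), _encode_text drops the last position
# of `text` before the text tower, so token_embs comes back with text.shape[1] - 1 positions, assuming the text
# tower appends a learned CLS embedding and leaves it out of the returned token embeddings (the usual open_clip
# TextTransformer behaviour). Under that assumption, text[:, -token_embs.shape[1]:] is equivalent to text[:, 1:],
# i.e. the decoder logits are scored against ordinary next-token targets by CoCaLoss.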
150 | # TODO: add assertion to avoid bugs? 151 | labels = text[:, -token_embs.shape[1]:] 152 | 153 | logits = self.text_decoder(image_embs, token_embs) 154 | return { 155 | "image_features": image_latent, 156 | "text_features": text_latent, 157 | "logits": logits, 158 | "labels": labels, 159 | "logit_scale": self.logit_scale.exp() 160 | } 161 | 162 | 163 | # taking many ideas and components from HuggingFace GenerationMixin 164 | # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation 165 | def generate( 166 | self, 167 | image, 168 | text=None, 169 | seq_len=30, 170 | max_seq_len=77, 171 | temperature=1., 172 | generation_type="beam_search", 173 | top_p=0.1, # keep tokens in the 1 - top_p quantile 174 | top_k=1, # keeps the top_k most probable tokens 175 | pad_token_id=None, 176 | eos_token_id=None, 177 | sot_token_id=None, 178 | num_beams=6, 179 | num_beam_groups=3, 180 | min_seq_len=5, 181 | stopping_criteria=None, 182 | repetition_penalty=1.0, 183 | fixed_output_length=False # if True output.shape == (batch_size, seq_len) 184 | ): 185 | 186 | assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len" 187 | with torch.no_grad(): 188 | sot_token_id = 49406 if sot_token_id is None else sot_token_id 189 | eos_token_id = 49407 if eos_token_id is None else eos_token_id 190 | pad_token_id = self.pad_id if pad_token_id is None else pad_token_id 191 | logit_processor = LogitsProcessorList( 192 | [ 193 | MinLengthLogitsProcessor(min_seq_len, eos_token_id), 194 | RepetitionPenaltyLogitsProcessor(repetition_penalty), 195 | ] 196 | ) 197 | 198 | if stopping_criteria is None: 199 | stopping_criteria = [MaxLengthCriteria(max_length=seq_len)] 200 | 201 | stopping_criteria = StoppingCriteriaList( 202 | stopping_criteria 203 | ) 204 | 205 | device = image.device 206 | 207 | if generation_type == "beam_search": 208 | output = self._generate_beamsearch( 209 | image_inputs = image, 210 | pad_token_id=pad_token_id, 211 | eos_token_id=eos_token_id, 212 | sot_token_id=sot_token_id, 213 | num_beams=num_beams, 214 | num_beam_groups=num_beam_groups, 215 | min_seq_len=min_seq_len, 216 | stopping_criteria=stopping_criteria, 217 | logit_processor=logit_processor, 218 | ) 219 | if fixed_output_length and output.shape[1] < seq_len: 220 | return torch.cat( 221 | (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id), 222 | dim=1 223 | ) 224 | return output 225 | 226 | elif generation_type == "top_p": 227 | logit_warper = GENERATION_TYPES[generation_type](top_p) 228 | elif generation_type == "top_k": 229 | logit_warper = GENERATION_TYPES[generation_type](top_k) 230 | else: 231 | raise ValueError( 232 | f"generation_type has to be one of " 233 | f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}." 
234 | ) 235 | 236 | image_latent, image_embs = self._encode_image(image) 237 | 238 | if text is None: 239 | text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id 240 | 241 | was_training = self.training 242 | num_dims = len(text.shape) 243 | 244 | if num_dims == 1: 245 | text = text[None, :] 246 | 247 | cur_len = text.shape[1] 248 | self.eval() 249 | out = text 250 | 251 | 252 | 253 | while True: 254 | x = out[:, -max_seq_len:] 255 | cur_len = x.shape[1] 256 | logits = self(image, x, image_latent=image_latent, image_embs=image_embs, embed_cls=False)["logits"][:, -1] 257 | mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id) 258 | sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id 259 | 260 | if mask.all(): 261 | if not fixed_output_length: 262 | break 263 | else: 264 | logits = logits[~mask, :] 265 | filtered_logits = logit_processor(x[~mask, :], logits) 266 | filtered_logits = logit_warper(x[~mask, :], filtered_logits) 267 | probs = F.softmax(filtered_logits / temperature, dim=-1) 268 | 269 | if (cur_len + 1 == seq_len): 270 | sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id 271 | else: 272 | sample[~mask, :] = torch.multinomial(probs, 1) 273 | 274 | out = torch.cat((out, sample), dim=-1) 275 | 276 | cur_len += 1 277 | 278 | if stopping_criteria(out, None): 279 | break 280 | 281 | if num_dims == 1: 282 | out = out.squeeze(0) 283 | 284 | self.train(was_training) 285 | return out 286 | 287 | def _generate_beamsearch( 288 | self, 289 | image_inputs, 290 | pad_token_id=None, 291 | eos_token_id=None, 292 | sot_token_id=None, 293 | num_beams=6, 294 | num_beam_groups=3, 295 | min_seq_len=5, 296 | stopping_criteria=None, 297 | logit_processor=None, 298 | logit_warper=None, 299 | ): 300 | 301 | device = image_inputs.device 302 | batch_size = image_inputs.shape[0] 303 | image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0) 304 | image_latent, image_embs = self._encode_image(image_inputs) 305 | 306 | input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long) 307 | input_ids = input_ids * sot_token_id 308 | beam_scorer = BeamSearchScorer( 309 | batch_size=batch_size, 310 | num_beams=num_beams, 311 | device=device, 312 | num_beam_groups=num_beam_groups, 313 | ) 314 | # instantiate logits processors 315 | logits_processor = ( 316 | LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)]) 317 | if logit_processor is None 318 | else logit_processor 319 | ) 320 | 321 | batch_size = len(beam_scorer._beam_hyps) 322 | num_beams = beam_scorer.num_beams 323 | num_beam_groups = beam_scorer.num_beam_groups 324 | num_sub_beams = num_beams // num_beam_groups 325 | batch_beam_size, cur_len = input_ids.shape 326 | beam_indices = None 327 | 328 | if num_beams * batch_size != batch_beam_size: 329 | raise ValueError( 330 | f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." 331 | ) 332 | 333 | beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) 334 | # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in 335 | # the same group don't produce same tokens everytime. 
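# Worked example of the initialisation below, using this method's defaults (num_beams=6, num_beam_groups=3,
# so num_sub_beams=2): beam_scores starts as a (batch_size, 6) tensor filled with -1e9, and
# beam_scores[:, ::2] = 0 zeroes beams 0, 2 and 4, the first beam of groups {0,1}, {2,3} and {4,5},
# so only one beam per group carries a finite score into the first decoding step.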
336 | beam_scores[:, ::num_sub_beams] = 0 337 | beam_scores = beam_scores.view((batch_size * num_beams,)) 338 | 339 | while True: 340 | 341 | # predicted tokens in cur_len step 342 | current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) 343 | 344 | # indices which will form the beams in the next time step 345 | reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) 346 | 347 | # do one decoder step on all beams of all sentences in batch 348 | model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs) 349 | outputs = self( 350 | model_inputs['images'], 351 | model_inputs['text'], 352 | embed_cls=False, 353 | image_latent=image_latent, 354 | image_embs=image_embs 355 | ) 356 | 357 | for beam_group_idx in range(num_beam_groups): 358 | group_start_idx = beam_group_idx * num_sub_beams 359 | group_end_idx = min(group_start_idx + num_sub_beams, num_beams) 360 | group_size = group_end_idx - group_start_idx 361 | 362 | # indices of beams of current group among all sentences in batch 363 | batch_group_indices = [] 364 | 365 | for batch_idx in range(batch_size): 366 | batch_group_indices.extend( 367 | [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] 368 | ) 369 | group_input_ids = input_ids[batch_group_indices] 370 | 371 | # select outputs of beams of currentg group only 372 | next_token_logits = outputs['logits'][batch_group_indices, -1, :] 373 | vocab_size = next_token_logits.shape[-1] 374 | 375 | next_token_scores_processed = logits_processor( 376 | group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx 377 | ) 378 | next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) 379 | next_token_scores = next_token_scores.expand_as(next_token_scores_processed) 380 | 381 | # reshape for beam search 382 | next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) 383 | 384 | next_token_scores, next_tokens = torch.topk( 385 | next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True 386 | ) 387 | 388 | next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") 389 | next_tokens = next_tokens % vocab_size 390 | 391 | # stateless 392 | process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None 393 | beam_outputs = beam_scorer.process( 394 | group_input_ids, 395 | next_token_scores, 396 | next_tokens, 397 | next_indices, 398 | pad_token_id=pad_token_id, 399 | eos_token_id=eos_token_id, 400 | beam_indices=process_beam_indices, 401 | ) 402 | beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] 403 | beam_next_tokens = beam_outputs["next_beam_tokens"] 404 | beam_idx = beam_outputs["next_beam_indices"] 405 | 406 | input_ids[batch_group_indices] = group_input_ids[beam_idx] 407 | group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) 408 | current_tokens[batch_group_indices] = group_input_ids[:, -1] 409 | 410 | # (beam_idx // group_size) -> batch_idx 411 | # (beam_idx % group_size) -> offset of idx inside the group 412 | reordering_indices[batch_group_indices] = ( 413 | num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size) 414 | ) 415 | 416 | input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) 417 | 418 | # increase cur_len 419 | cur_len = cur_len + 1 420 | if beam_scorer.is_done or 
stopping_criteria(input_ids, None): 421 | break 422 | 423 | final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None 424 | sequence_outputs = beam_scorer.finalize( 425 | input_ids, 426 | beam_scores, 427 | next_tokens, 428 | next_indices, 429 | pad_token_id=pad_token_id, 430 | eos_token_id=eos_token_id, 431 | max_length=stopping_criteria.max_length, 432 | beam_indices=final_beam_indices, 433 | ) 434 | return sequence_outputs['sequences'] 435 | 436 | 437 | def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs): 438 | if past: 439 | input_ids = input_ids[:, -1].unsqueeze(-1) 440 | 441 | attention_mask = kwargs.get("attention_mask", None) 442 | position_ids = kwargs.get("position_ids", None) 443 | 444 | if attention_mask is not None and position_ids is None: 445 | # create position_ids on the fly for batch generation 446 | position_ids = attention_mask.long().cumsum(-1) - 1 447 | position_ids.masked_fill_(attention_mask == 0, 1) 448 | else: 449 | position_ids = None 450 | return { 451 | "text": input_ids, 452 | "images": image_inputs, 453 | "past_key_values": past, 454 | "position_ids": position_ids, 455 | "attention_mask": attention_mask, 456 | } 457 | -------------------------------------------------------------------------------- /OPENCLIP/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | -------------------------------------------------------------------------------- /OPENCLIP/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pathlib 5 | import re 6 | from copy import deepcopy 7 | from pathlib import Path 8 | from typing import Any, Dict, Optional, Tuple, Union 9 | 10 | import torch 11 | 12 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 13 | from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ 14 | resize_pos_embed, get_cast_dtype 15 | from .coca_model import CoCa 16 | from .loss import ClipLoss, DistillClipLoss, CoCaLoss 17 | from .openai import load_openai_model 18 | from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf 19 | from .transform import image_transform, AugmentationCfg 20 | from .tokenizer import HFTokenizer, tokenize 21 | 22 | 23 | HF_HUB_PREFIX = 'hf-hub:' 24 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] 25 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 26 | 27 | 28 | def _natural_key(string_): 29 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 30 | 31 | 32 | def _rescan_model_configs(): 33 | global _MODEL_CONFIGS 34 | 35 | config_ext = ('.json',) 36 | config_files = [] 37 | for config_path in _MODEL_CONFIG_PATHS: 38 | if config_path.is_file() and config_path.suffix in config_ext: 39 | config_files.append(config_path) 40 | elif config_path.is_dir(): 41 | for ext in config_ext: 42 | config_files.extend(config_path.glob(f'*{ext}')) 43 | 44 | for cf in config_files: 45 | with open(cf, 'r') as f: 46 | model_cfg = json.load(f) 47 | if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): 48 | _MODEL_CONFIGS[cf.stem] = model_cfg 49 | 50 | _MODEL_CONFIGS = {k: v for k, v in 
sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} 51 | 52 | 53 | _rescan_model_configs() # initial populate of model config registry 54 | 55 | 56 | def list_models(): 57 | """ enumerate available model architectures based on config files """ 58 | return list(_MODEL_CONFIGS.keys()) 59 | 60 | 61 | def add_model_config(path): 62 | """ add model config path or file and update registry """ 63 | if not isinstance(path, Path): 64 | path = Path(path) 65 | _MODEL_CONFIG_PATHS.append(path) 66 | _rescan_model_configs() 67 | 68 | 69 | def get_model_config(model_name): 70 | if model_name in _MODEL_CONFIGS: 71 | return deepcopy(_MODEL_CONFIGS[model_name]) 72 | else: 73 | return None 74 | 75 | 76 | def get_tokenizer(model_name, cache_dir=None): 77 | if model_name.startswith(HF_HUB_PREFIX): 78 | tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):], cache_dir) 79 | else: 80 | config = get_model_config(model_name) 81 | tokenizer = HFTokenizer( 82 | config['text_cfg']['hf_tokenizer_name'], cache_dir) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize 83 | return tokenizer 84 | 85 | 86 | def load_state_dict(checkpoint_path: str, map_location='cpu'): 87 | checkpoint = torch.load(checkpoint_path, map_location=map_location) 88 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 89 | state_dict = checkpoint['state_dict'] 90 | else: 91 | state_dict = checkpoint 92 | if next(iter(state_dict.items()))[0].startswith('module'): 93 | state_dict = {k[7:]: v for k, v in state_dict.items()} 94 | return state_dict 95 | 96 | 97 | def load_checkpoint(model, checkpoint_path, strict=True): 98 | state_dict = load_state_dict(checkpoint_path) 99 | # detect old format and make compatible with new format 100 | if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): 101 | state_dict = convert_to_custom_text_state_dict(state_dict) 102 | resize_pos_embed(state_dict, model) 103 | incompatible_keys = model.load_state_dict(state_dict, strict=False) 104 | return incompatible_keys 105 | 106 | 107 | def create_model( 108 | model_name: str, 109 | pretrained: Optional[str] = None, 110 | precision: str = 'fp32', 111 | device: Union[str, torch.device] = 'cpu', 112 | jit: bool = False, 113 | force_quick_gelu: bool = False, 114 | force_custom_text: bool = False, 115 | force_patch_dropout: Optional[float] = None, 116 | force_image_size: Optional[Union[int, Tuple[int, int]]] = None, 117 | pretrained_image: bool = False, 118 | pretrained_hf: bool = True, 119 | cache_dir: Optional[str] = None, 120 | output_dict: Optional[bool] = None, 121 | require_pretrained: bool = False, 122 | ): 123 | has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) 124 | if has_hf_hub_prefix: 125 | model_id = model_name[len(HF_HUB_PREFIX):] 126 | checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) 127 | config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir) 128 | 129 | with open(config_path, 'r', encoding='utf-8') as f: 130 | config = json.load(f) 131 | pretrained_cfg = config['preprocess_cfg'] 132 | model_cfg = config['model_cfg'] 133 | else: 134 | model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names 135 | checkpoint_path = None 136 | pretrained_cfg = {} 137 | model_cfg = None 138 | 139 | if isinstance(device, str): 140 | device = torch.device(device) 141 | 142 | if pretrained and pretrained.lower() == 'openai': 143 | logging.info(f'Loading pretrained {model_name} from OpenAI.') 144 
| model = load_openai_model( 145 | model_name, 146 | precision=precision, 147 | device=device, 148 | jit=jit, 149 | cache_dir=cache_dir, 150 | ) 151 | 152 | # to always output dict even if it is clip 153 | if output_dict and hasattr(model, "output_dict"): 154 | model.output_dict = True 155 | else: 156 | model_cfg = model_cfg or get_model_config(model_name) 157 | if model_cfg is not None: 158 | logging.info(f'Loaded {model_name} model config.') 159 | else: 160 | logging.error(f'Model config for {model_name} not found; available models {list_models()}.') 161 | raise RuntimeError(f'Model config for {model_name} not found.') 162 | 163 | if force_quick_gelu: 164 | # override for use of QuickGELU on non-OpenAI transformer models 165 | model_cfg["quick_gelu"] = True 166 | 167 | if force_patch_dropout is not None: 168 | # override the default patch dropout value 169 | model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout 170 | 171 | if force_image_size is not None: 172 | # override model config's image size 173 | model_cfg["vision_cfg"]["image_size"] = force_image_size 174 | 175 | if pretrained_image: 176 | if 'timm_model_name' in model_cfg.get('vision_cfg', {}): 177 | # pretrained weight loading for timm models set via vision_cfg 178 | model_cfg['vision_cfg']['timm_model_pretrained'] = True 179 | else: 180 | assert False, 'pretrained image towers currently only supported for timm models' 181 | 182 | cast_dtype = get_cast_dtype(precision) 183 | is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) 184 | custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model 185 | 186 | if custom_text: 187 | if is_hf_model: 188 | model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf 189 | if "coca" in model_name: 190 | model = CoCa(**model_cfg, cast_dtype=cast_dtype) 191 | else: 192 | model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) 193 | else: 194 | model = CLIP(**model_cfg, cast_dtype=cast_dtype) 195 | 196 | pretrained_loaded = False 197 | if pretrained: 198 | checkpoint_path = '' 199 | pretrained_cfg = get_pretrained_cfg(model_name, pretrained) 200 | if pretrained_cfg: 201 | checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) 202 | elif os.path.exists(pretrained): 203 | checkpoint_path = pretrained 204 | 205 | if checkpoint_path: 206 | logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') 207 | load_checkpoint(model, checkpoint_path) 208 | else: 209 | error_str = ( 210 | f'Pretrained weights ({pretrained}) not found for model {model_name}.' 
211 | f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') 212 | logging.warning(error_str) 213 | raise RuntimeError(error_str) 214 | pretrained_loaded = True 215 | elif has_hf_hub_prefix: 216 | logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') 217 | load_checkpoint(model, checkpoint_path) 218 | pretrained_loaded = True 219 | 220 | if require_pretrained and not pretrained_loaded: 221 | # callers of create_model_from_pretrained always expect pretrained weights 222 | raise RuntimeError( 223 | f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') 224 | 225 | model.to(device=device) 226 | if precision in ("fp16", "bf16"): 227 | convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16) 228 | 229 | # set image / mean metadata from pretrained_cfg if available, or use default 230 | model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN 231 | model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD 232 | 233 | # to always output dict even if it is clip 234 | if output_dict and hasattr(model, "output_dict"): 235 | model.output_dict = True 236 | 237 | if jit: 238 | model = torch.jit.script(model) 239 | 240 | return model 241 | 242 | 243 | def create_loss(args): 244 | if args.distill: 245 | return DistillClipLoss( 246 | local_loss=args.local_loss, 247 | gather_with_grad=args.gather_with_grad, 248 | cache_labels=True, 249 | rank=args.rank, 250 | world_size=args.world_size, 251 | use_horovod=args.horovod, 252 | ) 253 | elif "coca" in args.model.lower(): 254 | return CoCaLoss( 255 | caption_loss_weight=args.coca_caption_loss_weight, 256 | clip_loss_weight=args.coca_contrastive_loss_weight, 257 | local_loss=args.local_loss, 258 | gather_with_grad=args.gather_with_grad, 259 | cache_labels=True, 260 | rank=args.rank, 261 | world_size=args.world_size, 262 | use_horovod=args.horovod, 263 | ) 264 | return ClipLoss( 265 | local_loss=args.local_loss, 266 | gather_with_grad=args.gather_with_grad, 267 | cache_labels=True, 268 | rank=args.rank, 269 | world_size=args.world_size, 270 | use_horovod=args.horovod, 271 | ) 272 | 273 | 274 | def create_model_and_transforms( 275 | model_name: str, 276 | pretrained: Optional[str] = None, 277 | precision: str = 'fp32', 278 | device: Union[str, torch.device] = 'cpu', 279 | jit: bool = False, 280 | force_quick_gelu: bool = False, 281 | force_custom_text: bool = False, 282 | force_patch_dropout: Optional[float] = None, 283 | force_image_size: Optional[Union[int, Tuple[int, int]]] = None, 284 | pretrained_image: bool = False, 285 | pretrained_hf: bool = True, 286 | image_mean: Optional[Tuple[float, ...]] = None, 287 | image_std: Optional[Tuple[float, ...]] = None, 288 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, 289 | cache_dir: Optional[str] = None, 290 | output_dict: Optional[bool] = None, 291 | ): 292 | model = create_model( 293 | model_name, 294 | pretrained, 295 | precision=precision, 296 | device=device, 297 | jit=jit, 298 | force_quick_gelu=force_quick_gelu, 299 | force_custom_text=force_custom_text, 300 | force_patch_dropout=force_patch_dropout, 301 | force_image_size=force_image_size, 302 | pretrained_image=pretrained_image, 303 | pretrained_hf=pretrained_hf, 304 | cache_dir=cache_dir, 305 | output_dict=output_dict, 306 | ) 307 | 308 | image_mean = image_mean or getattr(model.visual, 'image_mean', None) 309 | image_std = image_std or getattr(model.visual, 'image_std', 
None) 310 | preprocess_train = image_transform( 311 | model.visual.image_size, 312 | is_train=True, 313 | mean=image_mean, 314 | std=image_std, 315 | aug_cfg=aug_cfg, 316 | ) 317 | preprocess_val = image_transform( 318 | model.visual.image_size, 319 | is_train=False, 320 | mean=image_mean, 321 | std=image_std, 322 | ) 323 | 324 | return model, preprocess_train, preprocess_val 325 | 326 | 327 | def create_model_from_pretrained( 328 | model_name: str, 329 | pretrained: Optional[str] = None, 330 | precision: str = 'fp32', 331 | device: Union[str, torch.device] = 'cpu', 332 | jit: bool = False, 333 | force_quick_gelu: bool = False, 334 | force_custom_text: bool = False, 335 | force_image_size: Optional[Union[int, Tuple[int, int]]] = None, 336 | return_transform: bool = True, 337 | image_mean: Optional[Tuple[float, ...]] = None, 338 | image_std: Optional[Tuple[float, ...]] = None, 339 | cache_dir: Optional[str] = None, 340 | ): 341 | model = create_model( 342 | model_name, 343 | pretrained, 344 | precision=precision, 345 | device=device, 346 | jit=jit, 347 | force_quick_gelu=force_quick_gelu, 348 | force_custom_text=force_custom_text, 349 | force_image_size=force_image_size, 350 | cache_dir=cache_dir, 351 | require_pretrained=True, 352 | ) 353 | 354 | if not return_transform: 355 | return model 356 | 357 | image_mean = image_mean or getattr(model.visual, 'image_mean', None) 358 | image_std = image_std or getattr(model.visual, 'image_std', None) 359 | preprocess = image_transform( 360 | model.visual.image_size, 361 | is_train=False, 362 | mean=image_mean, 363 | std=image_std, 364 | ) 365 | 366 | return model, preprocess 367 | -------------------------------------------------------------------------------- /OPENCLIP/generation_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FactoDeepLearning/MultitaskVLFM/cde10ef783a5cdeaf1e06016560dfb6cc2a3ffa2/OPENCLIP/generation_utils.py -------------------------------------------------------------------------------- /OPENCLIP/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": 
"block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | } 46 | -------------------------------------------------------------------------------- /OPENCLIP/hf_model.py: -------------------------------------------------------------------------------- 1 | """ huggingface model adapter 2 | 3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in Multitasking model. 4 | """ 5 | 6 | import re 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch import TensorType 11 | 12 | try: 13 | import transformers 14 | from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig 15 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ 16 | BaseModelOutputWithPoolingAndCrossAttentions 17 | except ImportError as e: 18 | transformers = None 19 | 20 | 21 | class BaseModelOutput: 22 | pass 23 | 24 | 25 | class PretrainedConfig: 26 | pass 27 | 28 | from .hf_configs import arch_dict 29 | 30 | 31 | # utils 32 | def _camel2snake(s): 33 | return re.sub(r'(? 1: 91 | all_image_features, all_text_features = gather_features( 92 | image_features, text_features, 93 | self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) 94 | 95 | if self.local_loss: 96 | logits_per_image = logit_scale * image_features @ all_text_features.T 97 | logits_per_text = logit_scale * text_features @ all_image_features.T 98 | else: 99 | logits_per_image = logit_scale * all_image_features @ all_text_features.T 100 | logits_per_text = logits_per_image.T 101 | else: 102 | logits_per_image = logit_scale * image_features @ text_features.T 103 | logits_per_text = logit_scale * text_features @ image_features.T 104 | 105 | return logits_per_image, logits_per_text 106 | 107 | def forward(self, image_features, text_features, logit_scale, output_dict=False): 108 | device = image_features.device 109 | logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) 110 | 111 | # calculated ground-truth and cache if enabled 112 | num_logits = logits_per_image.shape[0] 113 | if self.prev_num_logits != num_logits or device not in self.labels: 114 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 115 | if self.world_size > 1 and self.local_loss: 116 | labels = labels + num_logits * self.rank 117 | if self.cache_labels: 118 | self.labels[device] = labels 119 | self.prev_num_logits = num_logits 120 | else: 121 | labels = self.labels[device] 122 | 123 | total_loss = ( 124 | F.cross_entropy(logits_per_image, labels) + 125 | F.cross_entropy(logits_per_text, labels) 126 | ) / 2 127 | 128 | return {"contrastive_loss": total_loss} if output_dict else total_loss 129 | 130 | 131 | class CoCaLoss(ClipLoss): 132 | def __init__( 133 | self, 134 | caption_loss_weight, 135 | clip_loss_weight, 136 | pad_id=0, # pad_token for open_clip custom tokenizer 137 | local_loss=False, 138 | gather_with_grad=False, 139 | cache_labels=False, 140 | rank=0, 141 | world_size=1, 142 | use_horovod=False, 143 | ): 144 | super().__init__( 145 | local_loss=local_loss, 146 | gather_with_grad=gather_with_grad, 147 | cache_labels=cache_labels, 148 | rank=rank, 149 | world_size=world_size, 150 | use_horovod=use_horovod 151 | ) 152 | 153 | self.clip_loss_weight = clip_loss_weight 154 | self.caption_loss_weight = caption_loss_weight 155 | self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) 156 | 157 | def forward(self, image_features, 
text_features, logits, labels, logit_scale, output_dict=False): 158 | clip_loss = super().forward(image_features, text_features, logit_scale) 159 | clip_loss = self.clip_loss_weight * clip_loss 160 | 161 | caption_loss = self.caption_loss( 162 | logits.permute(0, 2, 1), 163 | labels, 164 | ) 165 | caption_loss = caption_loss * self.caption_loss_weight 166 | 167 | if output_dict: 168 | return {"contrastive_loss": clip_loss, "caption_loss": caption_loss} 169 | 170 | return clip_loss, caption_loss 171 | 172 | 173 | class DistillClipLoss(ClipLoss): 174 | 175 | def dist_loss(self, teacher_logits, student_logits): 176 | return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0) 177 | 178 | def forward(self, image_features, text_features, logit_scale, dist_image_features, dist_text_features, dist_logit_scale, output_dict=False): 179 | logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) 180 | dist_logits_per_image, dist_logits_per_text = self.get_logits(dist_image_features, dist_text_features, dist_logit_scale) 181 | 182 | #FIXME: remove some of this duplicate code. 183 | # calculated ground-truth and cache if enabled 184 | device = image_features.device 185 | num_logits = logits_per_image.shape[0] 186 | if self.prev_num_logits != num_logits or device not in self.labels: 187 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 188 | if self.world_size > 1 and self.local_loss: 189 | labels = labels + num_logits * self.rank 190 | if self.cache_labels: 191 | self.labels[device] = labels 192 | self.prev_num_logits = num_logits 193 | else: 194 | labels = self.labels[device] 195 | 196 | contrastive_loss = ( 197 | F.cross_entropy(logits_per_image, labels) + 198 | F.cross_entropy(logits_per_text, labels) 199 | ) / 2 200 | distill_loss = ( 201 | self.dist_loss(dist_logits_per_image, logits_per_image) + 202 | self.dist_loss(dist_logits_per_text, logits_per_text) 203 | ) / 2 204 | 205 | if output_dict: 206 | return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss} 207 | 208 | return contrastive_loss, distill_loss -------------------------------------------------------------------------------- /OPENCLIP/model.py: -------------------------------------------------------------------------------- 1 | """ Multitasking Model 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | from dataclasses import dataclass 6 | import logging 7 | import math 8 | from typing import Optional, Tuple, Union 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import nn 14 | from torch.utils.checkpoint import checkpoint 15 | 16 | from .hf_model import HFTextEncoder 17 | from .modified_resnet import ModifiedResNet 18 | from .timm_model import TimmModel 19 | from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer 20 | from .utils import to_2tuple 21 | 22 | 23 | @dataclass 24 | class CLIPVisionCfg: 25 | layers: Union[Tuple[int, int, int, int], int] = 12 26 | width: int = 768 27 | head_width: int = 64 28 | mlp_ratio: float = 4.0 29 | patch_size: int = 16 30 | image_size: Union[Tuple[int, int], int] = 224 31 | ls_init_value: Optional[float] = None # layer scale initial value 32 | patch_dropout: float = 0. 
# what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results 33 | input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design 34 | global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) 35 | attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer 36 | n_queries: int = 256 # n_queries for attentional pooler 37 | attn_pooler_heads: int = 8 # n heads for attentional_pooling 38 | timm_model_name: str = None # a valid model name overrides layers, width, patch_size 39 | timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model 40 | timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') 41 | timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') 42 | timm_proj_bias: bool = False # enable bias final projection 43 | timm_drop: float = 0. # head dropout 44 | timm_drop_path: Optional[float] = None # backbone stochastic depth 45 | output_tokens: bool = False 46 | 47 | 48 | @dataclass 49 | class CLIPTextCfg: 50 | context_length: int = 77 51 | vocab_size: int = 49408 52 | width: int = 512 53 | heads: int = 8 54 | layers: int = 12 55 | ls_init_value: Optional[float] = None # layer scale initial value 56 | hf_model_name: str = None 57 | hf_tokenizer_name: str = None 58 | hf_model_pretrained: bool = True 59 | proj: str = 'mlp' 60 | pooler_type: str = 'mean_pooler' 61 | embed_cls: bool = False 62 | pad_id: int = 0 63 | output_tokens: bool = False 64 | 65 | 66 | def get_cast_dtype(precision: str): 67 | cast_dtype = None 68 | if precision == 'bf16': 69 | cast_dtype = torch.bfloat16 70 | elif precision == 'fp16': 71 | cast_dtype = torch.float16 72 | return cast_dtype 73 | 74 | 75 | def _build_vision_tower( 76 | embed_dim: int, 77 | vision_cfg: CLIPVisionCfg, 78 | quick_gelu: bool = False, 79 | cast_dtype: Optional[torch.dtype] = None 80 | ): 81 | if isinstance(vision_cfg, dict): 82 | vision_cfg = CLIPVisionCfg(**vision_cfg) 83 | 84 | # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more 85 | # memory efficient in recent PyTorch releases (>= 1.10). 86 | # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
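# For reference: QuickGELU (imported from .transformer) is the sigmoid approximation x * sigmoid(1.702 * x)
# used by the original OpenAI CLIP weights, while nn.GELU evaluates the exact Gaussian-CDF form; the two are
# close but not identical, so quick_gelu should match how the checkpoint being loaded was trained.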
87 | act_layer = QuickGELU if quick_gelu else nn.GELU 88 | 89 | if vision_cfg.timm_model_name: 90 | visual = TimmModel( 91 | vision_cfg.timm_model_name, 92 | pretrained=vision_cfg.timm_model_pretrained, 93 | pool=vision_cfg.timm_pool, 94 | proj=vision_cfg.timm_proj, 95 | proj_bias=vision_cfg.timm_proj_bias, 96 | drop=vision_cfg.timm_drop, 97 | drop_path=vision_cfg.timm_drop_path, 98 | embed_dim=embed_dim, 99 | image_size=vision_cfg.image_size, 100 | ) 101 | act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models 102 | elif isinstance(vision_cfg.layers, (tuple, list)): 103 | vision_heads = vision_cfg.width * 32 // vision_cfg.head_width 104 | visual = ModifiedResNet( 105 | layers=vision_cfg.layers, 106 | output_dim=embed_dim, 107 | heads=vision_heads, 108 | image_size=vision_cfg.image_size, 109 | width=vision_cfg.width, 110 | ) 111 | else: 112 | vision_heads = vision_cfg.width // vision_cfg.head_width 113 | norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm 114 | visual = VisionTransformer( 115 | image_size=vision_cfg.image_size, 116 | patch_size=vision_cfg.patch_size, 117 | width=vision_cfg.width, 118 | layers=vision_cfg.layers, 119 | heads=vision_heads, 120 | mlp_ratio=vision_cfg.mlp_ratio, 121 | ls_init_value=vision_cfg.ls_init_value, 122 | patch_dropout=vision_cfg.patch_dropout, 123 | input_patchnorm=vision_cfg.input_patchnorm, 124 | global_average_pool=vision_cfg.global_average_pool, 125 | attentional_pool=vision_cfg.attentional_pool, 126 | n_queries=vision_cfg.n_queries, 127 | attn_pooler_heads=vision_cfg.attn_pooler_heads, 128 | output_tokens=vision_cfg.output_tokens, 129 | output_dim=embed_dim, 130 | act_layer=act_layer, 131 | norm_layer=norm_layer, 132 | ) 133 | 134 | return visual 135 | 136 | 137 | def _build_text_tower( 138 | embed_dim: int, 139 | text_cfg: CLIPTextCfg, 140 | quick_gelu: bool = False, 141 | cast_dtype: Optional[torch.dtype] = None, 142 | ): 143 | if isinstance(text_cfg, dict): 144 | text_cfg = CLIPTextCfg(**text_cfg) 145 | 146 | if text_cfg.hf_model_name: 147 | text = HFTextEncoder( 148 | text_cfg.hf_model_name, 149 | output_dim=embed_dim, 150 | proj=text_cfg.proj, 151 | pooler_type=text_cfg.pooler_type, 152 | pretrained=text_cfg.hf_model_pretrained, 153 | output_tokens=text_cfg.output_tokens, 154 | ) 155 | else: 156 | act_layer = QuickGELU if quick_gelu else nn.GELU 157 | norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm 158 | 159 | text = TextTransformer( 160 | context_length=text_cfg.context_length, 161 | vocab_size=text_cfg.vocab_size, 162 | width=text_cfg.width, 163 | heads=text_cfg.heads, 164 | layers=text_cfg.layers, 165 | ls_init_value=text_cfg.ls_init_value, 166 | output_dim=embed_dim, 167 | embed_cls=text_cfg.embed_cls, 168 | output_tokens=text_cfg.output_tokens, 169 | pad_id=text_cfg.pad_id, 170 | act_layer=act_layer, 171 | norm_layer=norm_layer, 172 | ) 173 | return text 174 | 175 | 176 | class CLIP(nn.Module): 177 | output_dict: torch.jit.Final[bool] 178 | 179 | def __init__( 180 | self, 181 | embed_dim: int, 182 | vision_cfg: CLIPVisionCfg, 183 | text_cfg: CLIPTextCfg, 184 | quick_gelu: bool = False, 185 | cast_dtype: Optional[torch.dtype] = None, 186 | output_dict: bool = False, 187 | ): 188 | super().__init__() 189 | self.output_dict = output_dict 190 | self.context_length = text_cfg["context_length"] 191 | self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) 192 | 193 | text = _build_text_tower(embed_dim, 
text_cfg, quick_gelu, cast_dtype) 194 | self.transformer = text.transformer 195 | self.vocab_size = text.vocab_size 196 | self.token_embedding = text.token_embedding 197 | self.positional_embedding = text.positional_embedding 198 | self.ln_final = text.ln_final 199 | self.text_projection = text.text_projection 200 | self.output_dim = text.output_dim 201 | self.cls_embed = text.cls_emb 202 | self.pad_id = text.pad_id 203 | self.output_tokens = text.output_tokens 204 | self.register_buffer('attn_mask', text.attn_mask, persistent=False) 205 | 206 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 207 | 208 | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): 209 | # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 210 | self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) 211 | 212 | @torch.jit.ignore 213 | def set_grad_checkpointing(self, enable=True): 214 | self.visual.set_grad_checkpointing(enable) 215 | self.transformer.grad_checkpointing = enable 216 | 217 | def encode_image(self, image, normalize: bool = False): 218 | features = self.visual(image) 219 | return F.normalize(features, dim=-1) if normalize else features 220 | 221 | def encode_text(self, text, normalize: bool = False): 222 | cast_dtype = self.transformer.get_cast_dtype() 223 | 224 | x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] 225 | 226 | x = x + self.positional_embedding.to(cast_dtype) 227 | x = x.permute(1, 0, 2) # NLD -> LND 228 | x = self.transformer(x, attn_mask=self.attn_mask) 229 | x = x.permute(1, 0, 2) # LND -> NLD 230 | x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] 231 | # take features from the eot embedding (eot_token is the highest number in each sequence) 232 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 233 | return F.normalize(x, dim=-1) if normalize else x 234 | 235 | def forward(self, image, text): 236 | image_features = self.encode_image(image, normalize=True) 237 | text_features = self.encode_text(text, normalize=True) 238 | if self.output_dict: 239 | return { 240 | "image_features": image_features, 241 | "text_features": text_features, 242 | "logit_scale": self.logit_scale.exp() 243 | } 244 | return image_features, text_features, self.logit_scale.exp() 245 | 246 | 247 | class CustomTextCLIP(nn.Module): 248 | output_dict: torch.jit.Final[bool] 249 | 250 | def __init__( 251 | self, 252 | embed_dim: int, 253 | vision_cfg: CLIPVisionCfg, 254 | text_cfg: CLIPTextCfg, 255 | quick_gelu: bool = False, 256 | cast_dtype: Optional[torch.dtype] = None, 257 | output_dict: bool = False, 258 | ): 259 | super().__init__() 260 | self.output_dict = output_dict 261 | self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) 262 | self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) 263 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 264 | 265 | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): 266 | # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 267 | self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) 268 | 269 | def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): 270 | self.text.lock(unlocked_layers, freeze_layer_norm) 271 | 272 | @torch.jit.ignore 273 | def set_grad_checkpointing(self, enable=True): 274 | self.visual.set_grad_checkpointing(enable) 275 | self.text.set_grad_checkpointing(enable) 276 
| 277 | def encode_image(self, image, normalize: bool = False): 278 | features = self.visual(image) 279 | return F.normalize(features, dim=-1) if normalize else features 280 | 281 | def encode_text(self, text, normalize: bool = False): 282 | features = self.text(text) 283 | return F.normalize(features, dim=-1) if normalize else features 284 | 285 | def forward(self, image, text): 286 | image_features = self.encode_image(image, normalize=True) 287 | text_features = self.encode_text(text, normalize=True) 288 | if self.output_dict: 289 | return { 290 | "image_features": image_features, 291 | "text_features": text_features, 292 | "logit_scale": self.logit_scale.exp() 293 | } 294 | return image_features, text_features, self.logit_scale.exp() 295 | 296 | 297 | def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): 298 | """Convert applicable model parameters to low-precision (bf16 or fp16)""" 299 | 300 | def _convert_weights(l): 301 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 302 | l.weight.data = l.weight.data.to(dtype) 303 | if l.bias is not None: 304 | l.bias.data = l.bias.data.to(dtype) 305 | 306 | if isinstance(l, (nn.MultiheadAttention, Attention)): 307 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 308 | tensor = getattr(l, attr) 309 | if tensor is not None: 310 | tensor.data = tensor.data.to(dtype) 311 | 312 | for name in ["text_projection", "proj"]: 313 | if hasattr(l, name): 314 | attr = getattr(l, name) 315 | if attr is not None: 316 | attr.data = attr.data.to(dtype) 317 | 318 | model.apply(_convert_weights) 319 | 320 | 321 | convert_weights_to_fp16 = convert_weights_to_lp # backwards compat 322 | 323 | 324 | # used to maintain checkpoint compatibility 325 | def convert_to_custom_text_state_dict(state_dict: dict): 326 | if 'text_projection' in state_dict: 327 | # old format state_dict, move text tower -> .text 328 | new_state_dict = {} 329 | for k, v in state_dict.items(): 330 | if any(k.startswith(p) for p in ( 331 | 'text_projection', 332 | 'positional_embedding', 333 | 'token_embedding', 334 | 'transformer', 335 | 'ln_final', 336 | )): 337 | k = 'text.' 
+ k 338 | new_state_dict[k] = v 339 | return new_state_dict 340 | return state_dict 341 | 342 | 343 | def build_model_from_openai_state_dict( 344 | state_dict: dict, 345 | quick_gelu=True, 346 | cast_dtype=torch.float16, 347 | ): 348 | vit = "visual.proj" in state_dict 349 | 350 | if vit: 351 | vision_width = state_dict["visual.conv1.weight"].shape[0] 352 | vision_layers = len( 353 | [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 354 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 355 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 356 | image_size = vision_patch_size * grid_size 357 | else: 358 | counts: list = [ 359 | len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 360 | vision_layers = tuple(counts) 361 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 362 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 363 | vision_patch_size = None 364 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 365 | image_size = output_width * 32 366 | 367 | embed_dim = state_dict["text_projection"].shape[1] 368 | context_length = state_dict["positional_embedding"].shape[0] 369 | vocab_size = state_dict["token_embedding.weight"].shape[0] 370 | transformer_width = state_dict["ln_final.weight"].shape[0] 371 | transformer_heads = transformer_width // 64 372 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 373 | 374 | vision_cfg = CLIPVisionCfg( 375 | layers=vision_layers, 376 | width=vision_width, 377 | patch_size=vision_patch_size, 378 | image_size=image_size, 379 | ) 380 | text_cfg = CLIPTextCfg( 381 | context_length=context_length, 382 | vocab_size=vocab_size, 383 | width=transformer_width, 384 | heads=transformer_heads, 385 | layers=transformer_layers, 386 | ) 387 | model = CLIP( 388 | embed_dim, 389 | vision_cfg=vision_cfg, 390 | text_cfg=text_cfg, 391 | quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU 392 | cast_dtype=cast_dtype, 393 | ) 394 | 395 | for key in ["input_resolution", "context_length", "vocab_size"]: 396 | state_dict.pop(key, None) 397 | 398 | convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 399 | model.load_state_dict(state_dict) 400 | return model.eval() 401 | 402 | 403 | def trace_model(model, batch_size=256, device=torch.device('cpu')): 404 | model.eval() 405 | image_size = model.visual.image_size 406 | example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) 407 | example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) 408 | model = torch.jit.trace_module( 409 | model, 410 | inputs=dict( 411 | forward=(example_images, example_text), 412 | encode_text=(example_text,), 413 | encode_image=(example_images,) 414 | )) 415 | model.visual.image_size = image_size 416 | return model 417 | 418 | 419 | def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): 420 | # Rescale the grid of position embeddings when loading from state_dict 421 | old_pos_embed = state_dict.get('visual.positional_embedding', None) 422 | if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): 423 | return 424 | grid_size = to_2tuple(model.visual.grid_size) 425 | extra_tokens = 1 # FIXME detect different token configs (ie no 
class token, or more) 426 | new_seq_len = grid_size[0] * grid_size[1] + extra_tokens 427 | if new_seq_len == old_pos_embed.shape[0]: 428 | return 429 | 430 | if extra_tokens: 431 | pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] 432 | else: 433 | pos_emb_tok, pos_emb_img = None, old_pos_embed 434 | old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) 435 | 436 | logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) 437 | pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) 438 | pos_emb_img = F.interpolate( 439 | pos_emb_img, 440 | size=grid_size, 441 | mode=interpolation, 442 | antialias=antialias, 443 | align_corners=False, 444 | ) 445 | pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] 446 | if pos_emb_tok is not None: 447 | new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) 448 | else: 449 | new_pos_embed = pos_emb_img 450 | state_dict['visual.positional_embedding'] = new_pos_embed 451 | -------------------------------------------------------------------------------- /OPENCLIP/model_configs/RN101-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 23, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/RN101.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 23, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/RN50-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": [ 7 | 3, 8 | 4, 9 | 6, 10 | 3 11 | ], 12 | "width": 64, 13 | "patch_size": null 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 512, 19 | "heads": 8, 20 | "layers": 12 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /OPENCLIP/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": [ 6 | 3, 7 | 4, 8 | 6, 9 | 3 10 | ], 11 | "width": 64, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 512, 18 | "heads": 8, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/RN50x16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 384, 5 | "layers": [ 6 | 6, 7 | 8, 8 | 18, 9 | 8 10 | ], 11 | "width": 96, 12 | 
"patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 768, 18 | "heads": 12, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/RN50x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 288, 5 | "layers": [ 6 | 4, 7 | 6, 8 | 10, 9 | 6 10 | ], 11 | "width": 80, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 640, 18 | "heads": 10, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/RN50x64.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 448, 5 | "layers": [ 6 | 3, 7 | 15, 8 | 36, 9 | 10 10 | ], 11 | "width": 128, 12 | "patch_size": null 13 | }, 14 | "text_cfg": { 15 | "context_length": 77, 16 | "vocab_size": 49408, 17 | "width": 1024, 18 | "heads": 16, 19 | "layers": 12 20 | } 21 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-B-16-plus-240.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 240, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-B-16-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-B-32-plus-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 256, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | 
"vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-H-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 16 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-L-14-280.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 280, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-L-16-320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 320, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } 
-------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-L-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-M-16-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16, 8 | "ls_init_value": 1e-4 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 384, 14 | "heads": 6, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-M-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-M-32-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-M-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-S-16-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 10 15 | } 16 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-S-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-S-32-alt.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 10 15 | } 16 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-S-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-bigG-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 1664, 7 | "head_width": 104, 8 | "mlp_ratio": 4.9231, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 32 17 | } 18 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-e-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 56, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.5715, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 36 17 | } 18 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/ViT-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/coca_ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32, 8 | "attentional_pool": true, 9 | "attn_pooler_heads": 8, 10 | "output_tokens": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 76, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12, 18 | "embed_cls": true, 19 | "output_tokens": true 20 | }, 21 | "multimodal_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 49408, 24 | "width": 512, 25 | "heads": 8, 26 | "layers": 12, 27 | "attn_pooler_heads": 8 28 | }, 29 | "custom_text": true 30 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/coca_ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | 
"patch_size": 14, 8 | "attentional_pool": true, 9 | "attn_pooler_heads": 8, 10 | "output_tokens": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 76, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 12, 18 | "embed_cls": true, 19 | "output_tokens": true 20 | }, 21 | "multimodal_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 49408, 24 | "width": 768, 25 | "heads": 12, 26 | "layers": 12, 27 | "attn_pooler_heads": 12 28 | }, 29 | "custom_text": true 30 | } 31 | -------------------------------------------------------------------------------- /OPENCLIP/model_configs/coca_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "multimodal_cfg": { 4 | "width": 768, 5 | "context_length": 76, 6 | "vocab_size": 64000, 7 | "mlp_ratio": 4, 8 | "layers": 12, 9 | "dim_head": 64, 10 | "heads": 12, 11 | "n_queries": 256, 12 | "attn_pooler_heads": 8 13 | }, 14 | "vision_cfg": { 15 | "image_size": 288, 16 | "layers": 12, 17 | "width": 768, 18 | "patch_size": 18, 19 | "output_tokens": true 20 | }, 21 | "text_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 64000, 24 | "layers": 12, 25 | "heads": 12, 26 | "width": 768, 27 | "embed_cls": true, 28 | "output_tokens": true 29 | }, 30 | "custom_text": true 31 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/coca_roberta-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32, 8 | "output_tokens": true 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "roberta-base", 12 | "hf_tokenizer_name": "roberta-base", 13 | "proj": "linear", 14 | "width": 768, 15 | "output_tokens": true 16 | }, 17 | "multimodal_cfg": { 18 | "context_length": 76, 19 | "width": 768, 20 | "heads": 8, 21 | "layers": 12 22 | }, 23 | "custom_text": true 24 | } 25 | -------------------------------------------------------------------------------- /OPENCLIP/model_configs/convnext_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/convnext_base_w.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 640, 16 | "heads": 10, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/convnext_base_w_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_base", 5 | "timm_model_pretrained": 
false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 640, 16 | "heads": 10, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/convnext_large.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/convnext_large_d.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "mlp", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 16 18 | } 19 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/convnext_large_d_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_large", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "mlp", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 16 18 | } 19 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/convnext_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_small", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/convnext_tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_tiny", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 224 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/convnext_xlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | 
"timm_model_name": "convnext_xlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 20 18 | } 19 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/convnext_xxlarge.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xxlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 256 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 24 18 | } 19 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/convnext_xxlarge_320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "timm_model_name": "convnext_xxlarge", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "timm_drop": 0.0, 9 | "timm_drop_path": 0.1, 10 | "image_size": 320 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1024, 16 | "heads": 16, 17 | "layers": 24 18 | } 19 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/mt5-base-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "google/mt5-base", 11 | "hf_tokenizer_name": "google/mt5-base", 12 | "proj": "mlp", 13 | "pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /OPENCLIP/model_configs/mt5-xl-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "google/mt5-xl", 12 | "hf_tokenizer_name": "google/mt5-xl", 13 | "proj": "mlp", 14 | "pooler_type": "mean_pooler" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /OPENCLIP/model_configs/roberta-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "roberta-base", 12 | "hf_tokenizer_name": "roberta-base", 13 | "proj": "mlp", 14 | "pooler_type": "mean_pooler" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /OPENCLIP/model_configs/swin_base_patch4_window7_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": "swin_base_patch4_window7_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | 
"timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 640, 14 | "heads": 10, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/vit_medium_patch16_gap_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_medium_patch16_gap_256", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 256 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/vit_relpos_medium_patch16_cls_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_relpos_medium_patch16_cls_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /OPENCLIP/model_configs/xlm-roberta-base-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "xlm-roberta-base", 11 | "hf_tokenizer_name": "xlm-roberta-base", 12 | "proj": "mlp", 13 | "pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /OPENCLIP/model_configs/xlm-roberta-large-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "xlm-roberta-large", 12 | "hf_tokenizer_name": "xlm-roberta-large", 13 | "proj": "mlp", 14 | "pooler_type": "mean_pooler" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /OPENCLIP/modified_resnet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from open_clip.utils import freeze_batch_norm_2d 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.act1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.act2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.act3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.act1(self.bn1(self.conv1(x))) 46 | out = self.act2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.act3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x, key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0., 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | 92 | return x[0] 93 | 94 | 95 | class ModifiedResNet(nn.Module): 96 | """ 97 | A ResNet class that is similar to torchvision's but contains the following changes: 98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 100 | - The final pooling layer is a QKV attention instead of an average pool 101 | """ 102 | 103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 104 | super().__init__() 105 | self.output_dim = output_dim 106 | self.image_size = image_size 107 | 108 | # the 3-layer stem 109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 110 | self.bn1 = nn.BatchNorm2d(width // 2) 111 | self.act1 = nn.ReLU(inplace=True) 112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(width // 2) 114 | self.act2 = nn.ReLU(inplace=True) 115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 116 | self.bn3 = nn.BatchNorm2d(width) 117 | self.act3 = nn.ReLU(inplace=True) 118 | self.avgpool = nn.AvgPool2d(2) 119 | 120 | # residual layers 121 | self._inplanes = width # this is a *mutable* variable used during construction 122 | self.layer1 = self._make_layer(width, layers[0]) 123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 126 | 127 | embed_dim = width * 32 # the ResNet feature dimension 128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 129 | 130 | self.init_parameters() 131 | 132 | def _make_layer(self, planes, blocks, stride=1): 133 | layers = [Bottleneck(self._inplanes, planes, stride)] 134 | 135 | self._inplanes = planes * Bottleneck.expansion 136 | for _ in range(1, blocks): 137 | layers.append(Bottleneck(self._inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def init_parameters(self): 142 | if self.attnpool is not None: 143 | std = self.attnpool.c_proj.in_features ** -0.5 144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 148 | 149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 150 | for name, param in resnet_block.named_parameters(): 151 | if name.endswith("bn3.weight"): 152 | nn.init.zeros_(param) 153 | 154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 156 | for param in self.parameters(): 157 | param.requires_grad = False 158 | if freeze_bn_stats: 159 | freeze_batch_norm_2d(self) 160 | 161 | @torch.jit.ignore 162 | def set_grad_checkpointing(self, enable=True): 163 | # FIXME support for non-transformer 164 | pass 165 | 166 | def stem(self, x): 167 | x = self.act1(self.bn1(self.conv1(x))) 168 | x = self.act2(self.bn2(self.conv2(x))) 169 | x = self.act3(self.bn3(self.conv3(x))) 170 | x = self.avgpool(x) 171 | return x 172 | 173 | def forward(self, x): 174 | x = self.stem(x) 175 | x = self.layer1(x) 176 | x = self.layer2(x) 177 | x = self.layer3(x) 178 | x = self.layer4(x) 179 | x = self.attnpool(x) 180 | 181 | return x 182 | -------------------------------------------------------------------------------- /OPENCLIP/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. 
Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available Multitasking models""" 20 | return list_pretrained_models_by_tag('openai') 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | precision: Optional[str] = None, 26 | device: Optional[Union[str, torch.device]] = None, 27 | jit: bool = True, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a Multitasking model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | jit : bool 41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 42 | cache_dir : Optional[str] 43 | The directory to cache the downloaded model weights 44 | 45 | Returns 46 | ------- 47 | model : torch.nn.Module 48 | The Multitasking model 49 | preprocess : Callable[[PIL.Image], torch.Tensor] 50 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 51 | """ 52 | if device is None: 53 | device = "cuda" if torch.cuda.is_available() else "cpu" 54 | if precision is None: 55 | precision = 'fp32' if device == 'cpu' else 'fp16' 56 | 57 | if get_pretrained_url(name, 'openai'): 58 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 59 | elif os.path.isfile(name): 60 | model_path = name 61 | else: 62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 63 | 64 | try: 65 | # loading JIT archive 66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 67 | state_dict = None 68 | except RuntimeError: 69 | # loading saved state dict 70 | if jit: 71 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 72 | jit = False 73 | state_dict = torch.load(model_path, map_location="cpu") 74 | 75 | if not jit: 76 | # Build a non-jit model from the OpenAI jitted model state dict 77 | cast_dtype = get_cast_dtype(precision) 78 | try: 79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 80 | except KeyError: 81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 83 | 84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 85 | model = model.to(device) 86 | if precision.startswith('amp') or precision == 'fp32': 87 | model.float() 88 | elif precision == 'bf16': 89 | convert_weights_to_lp(model, dtype=torch.bfloat16) 90 | 91 | return model 92 | 93 | # patch the device names 94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 96 | 97 | def patch_device(module): 98 | try: 99 | graphs = [module.graph] if hasattr(module, "graph") else [] 100 | except RuntimeError: 101 | graphs = [] 102 | 103 | if hasattr(module, "forward1"): 104 | graphs.append(module.forward1.graph) 105 | 106 | for graph in graphs: 107 | for node in graph.findAllNodes("prim::Constant"): 108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 109 | node.copyAttributes(device_node) 110 | 111 | model.apply(patch_device) 112 | patch_device(model.encode_image) 113 | patch_device(model.encode_text) 114 | 115 | # patch dtype to float32 (typically for CPU) 116 | if precision == 'fp32': 117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 119 | float_node = float_input.node() 120 | 121 | def patch_float(module): 122 | try: 123 | graphs = [module.graph] if hasattr(module, "graph") else [] 124 | except RuntimeError: 125 | graphs = [] 126 | 127 | if hasattr(module, "forward1"): 128 | graphs.append(module.forward1.graph) 129 | 130 | for graph in graphs: 131 | for node in graph.findAllNodes("aten::to"): 132 | inputs = list(node.inputs()) 133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 134 | if inputs[i].node()["value"] == 5: 135 | inputs[i].node().copyAttributes(float_node) 136 | 137 | model.apply(patch_float) 138 | patch_float(model.encode_image) 139 | patch_float(model.encode_text) 140 | model.float() 141 | 142 | # ensure image_size attr available at consistent location for both jit and non-jit 143 | model.visual.image_size = model.input_resolution.item() 144 | return model 145 | -------------------------------------------------------------------------------- /OPENCLIP/pretrained.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from functools import partial 6 | from typing import Dict, Union 7 | 8 | from tqdm import tqdm 9 | 10 | from .version import __version__ 11 | 12 | try: 13 | from huggingface_hub import hf_hub_download 14 | hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__) 15 | _has_hf_hub = True 16 | except ImportError: 17 | hf_hub_download = None 18 | _has_hf_hub = False 19 | 20 | 21 | def _pcfg(url='', hf_hub='', mean=None, std=None): 22 
| return dict( 23 | url=url, 24 | hf_hub=hf_hub, 25 | mean=mean, 26 | std=std, 27 | ) 28 | 29 | 30 | _RN50 = dict( 31 | openai=_pcfg( 32 | "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), 33 | yfcc15m=_pcfg( 34 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), 35 | cc12m=_pcfg( 36 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), 37 | ) 38 | 39 | _RN50_quickgelu = dict( 40 | openai=_pcfg( 41 | "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), 42 | yfcc15m=_pcfg( 43 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), 44 | cc12m=_pcfg( 45 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), 46 | ) 47 | 48 | _RN101 = dict( 49 | openai=_pcfg( 50 | "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), 51 | yfcc15m=_pcfg( 52 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), 53 | ) 54 | 55 | _RN101_quickgelu = dict( 56 | openai=_pcfg( 57 | "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), 58 | yfcc15m=_pcfg( 59 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), 60 | ) 61 | 62 | _RN50x4 = dict( 63 | openai=_pcfg( 64 | "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"), 65 | ) 66 | 67 | _RN50x16 = dict( 68 | openai=_pcfg( 69 | "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"), 70 | ) 71 | 72 | _RN50x64 = dict( 73 | openai=_pcfg( 74 | "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"), 75 | ) 76 | 77 | _VITB32 = dict( 78 | openai=_pcfg( 79 | "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), 80 | laion400m_e31=_pcfg( 81 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), 82 | laion400m_e32=_pcfg( 83 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), 84 | laion2b_e16=_pcfg( 85 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), 86 | laion2b_s34b_b79k=_pcfg(hf_hub='laion/Multitasking-ViT-B-32-laion2B-s34B-b79K/') 87 | ) 88 | 89 | _VITB32_quickgelu = dict( 90 | openai=_pcfg( 91 | "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), 92 | laion400m_e31=_pcfg( 93 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), 94 | laion400m_e32=_pcfg( 95 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), 96 | ) 97 | 98 | _VITB16 = dict( 99 | openai=_pcfg( 100 | 
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), 101 | laion400m_e31=_pcfg( 102 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), 103 | laion400m_e32=_pcfg( 104 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), 105 | # laion400m_32k=_pcfg( 106 | # url="", 107 | # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 108 | # laion400m_64k=_pcfg( 109 | # url="", 110 | # mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 111 | laion2b_s34b_b88k=_pcfg(hf_hub='laion/Multitasking-ViT-B-16-laion2B-s34B-b88K/'), 112 | ) 113 | 114 | _VITB16_PLUS_240 = dict( 115 | laion400m_e31=_pcfg( 116 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), 117 | laion400m_e32=_pcfg( 118 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), 119 | ) 120 | 121 | _VITL14 = dict( 122 | openai=_pcfg( 123 | "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), 124 | laion400m_e31=_pcfg( 125 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), 126 | laion400m_e32=_pcfg( 127 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), 128 | laion2b_s32b_b82k=_pcfg( 129 | hf_hub='laion/Multitasking-ViT-L-14-laion2B-s32B-b82K/', 130 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 131 | ) 132 | 133 | _VITL14_336 = dict( 134 | openai=_pcfg( 135 | "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), 136 | ) 137 | 138 | _VITH14 = dict( 139 | laion2b_s32b_b79k=_pcfg(hf_hub='laion/Multitasking-ViT-H-14-laion2B-s32B-b79K/'), 140 | ) 141 | 142 | _VITg14 = dict( 143 | laion2b_s12b_b42k=_pcfg(hf_hub='laion/Multitasking-ViT-g-14-laion2B-s12B-b42K/'), 144 | ) 145 | 146 | _VITbigG14 = dict( 147 | laion2b_s39b_b160k=_pcfg(hf_hub='laion/Multitasking-ViT-bigG-14-laion2B-39B-b160k/'), 148 | ) 149 | 150 | _robertaViTB32 = dict( 151 | laion2b_s12b_b32k=_pcfg(hf_hub='laion/Multitasking-ViT-B-32-roberta-base-laion2B-s12B-b32k/'), 152 | ) 153 | 154 | _xlmRobertaBaseViTB32 = dict( 155 | laion5b_s13b_b90k=_pcfg(hf_hub='laion/Multitasking-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'), 156 | ) 157 | 158 | _xlmRobertaLargeFrozenViTH14 = dict( 159 | frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/Multitasking-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'), 160 | ) 161 | 162 | _convnext_base = dict( 163 | laion400m_s13b_b51k=_pcfg(hf_hub='laion/Multitasking-convnext_base-laion400M-s13B-b51K/'), 164 | ) 165 | 166 | _convnext_base_w = dict( 167 | laion2b_s13b_b82k=_pcfg(hf_hub='laion/Multitasking-convnext_base_w-laion2B-s13B-b82K/'), 168 | laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/Multitasking-convnext_base_w-laion2B-s13B-b82K-augreg/'), 169 | laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/Multitasking-convnext_base_w-laion_aesthetic-s13B-b82K/'), 170 | ) 171 | 172 | _convnext_base_w_320 = dict( 173 | laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/Multitasking-convnext_base_w_320-laion_aesthetic-s13B-b82K/'), 174 | laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/Multitasking-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), 175 | ) 176 | 177 | 
_convnext_large_d = dict( 178 | laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/Multitasking-convnext_large_d.laion2B-s26B-b102K-augreg/'), 179 | ) 180 | 181 | _convnext_large_d_320 = dict( 182 | laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/Multitasking-convnext_large_d_320.laion2B-s29B-b131K-ft/'), 183 | laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/Multitasking-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'), 184 | ) 185 | 186 | _coca_VITB32 = dict( 187 | laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'), 188 | mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/') 189 | ) 190 | 191 | _coca_VITL14 = dict( 192 | laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'), 193 | mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/') 194 | ) 195 | 196 | 197 | _PRETRAINED = { 198 | "RN50": _RN50, 199 | "RN50-quickgelu": _RN50_quickgelu, 200 | "RN101": _RN101, 201 | "RN101-quickgelu": _RN101_quickgelu, 202 | "RN50x4": _RN50x4, 203 | "RN50x16": _RN50x16, 204 | "RN50x64": _RN50x64, 205 | "ViT-B-32": _VITB32, 206 | "ViT-B-32-quickgelu": _VITB32_quickgelu, 207 | "ViT-B-16": _VITB16, 208 | "ViT-B-16-plus-240": _VITB16_PLUS_240, 209 | "ViT-L-14": _VITL14, 210 | "ViT-L-14-336": _VITL14_336, 211 | "ViT-H-14": _VITH14, 212 | "ViT-g-14": _VITg14, 213 | "ViT-bigG-14": _VITbigG14, 214 | "roberta-ViT-B-32": _robertaViTB32, 215 | "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32, 216 | "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14, 217 | "convnext_base": _convnext_base, 218 | "convnext_base_w": _convnext_base_w, 219 | "convnext_base_w_320": _convnext_base_w_320, 220 | "convnext_large_d": _convnext_large_d, 221 | "convnext_large_d_320": _convnext_large_d_320, 222 | "coca_ViT-B-32": _coca_VITB32, 223 | "coca_ViT-L-14": _coca_VITL14, 224 | } 225 | 226 | 227 | def _clean_tag(tag: str): 228 | # normalize pretrained tags 229 | return tag.lower().replace('-', '_') 230 | 231 | 232 | def list_pretrained(as_str: bool = False): 233 | """ returns list of pretrained models 234 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True 235 | """ 236 | return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] 237 | 238 | 239 | def list_pretrained_models_by_tag(tag: str): 240 | """ return all models having the specified pretrain tag """ 241 | models = [] 242 | tag = _clean_tag(tag) 243 | for k in _PRETRAINED.keys(): 244 | if tag in _PRETRAINED[k]: 245 | models.append(k) 246 | return models 247 | 248 | 249 | def list_pretrained_tags_by_model(model: str): 250 | """ return all pretrain tags for the specified model architecture """ 251 | tags = [] 252 | if model in _PRETRAINED: 253 | tags.extend(_PRETRAINED[model].keys()) 254 | return tags 255 | 256 | 257 | def is_pretrained_cfg(model: str, tag: str): 258 | if model not in _PRETRAINED: 259 | return False 260 | return _clean_tag(tag) in _PRETRAINED[model] 261 | 262 | 263 | def get_pretrained_cfg(model: str, tag: str): 264 | if model not in _PRETRAINED: 265 | return {} 266 | model_pretrained = _PRETRAINED[model] 267 | return model_pretrained.get(_clean_tag(tag), {}) 268 | 269 | 270 | def get_pretrained_url(model: str, tag: str): 271 | cfg = get_pretrained_cfg(model, _clean_tag(tag)) 272 | return cfg.get('url', '') 273 | 274 | 275 | def download_pretrained_from_url( 276 | url: str, 277 | cache_dir: Union[str, None] = None, 278 | ): 279 | if not 
cache_dir: 280 | cache_dir = os.path.expanduser("~/.cache/clip") 281 | os.makedirs(cache_dir, exist_ok=True) 282 | filename = os.path.basename(url) 283 | 284 | if 'openaipublic' in url: 285 | expected_sha256 = url.split("/")[-2] 286 | elif 'mlfoundations' in url: 287 | expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] 288 | else: 289 | expected_sha256 = '' 290 | 291 | download_target = os.path.join(cache_dir, filename) 292 | 293 | if os.path.exists(download_target) and not os.path.isfile(download_target): 294 | raise RuntimeError(f"{download_target} exists and is not a regular file") 295 | 296 | if os.path.isfile(download_target): 297 | if expected_sha256: 298 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): 299 | return download_target 300 | else: 301 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 302 | else: 303 | return download_target 304 | 305 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 306 | with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 307 | while True: 308 | buffer = source.read(8192) 309 | if not buffer: 310 | break 311 | 312 | output.write(buffer) 313 | loop.update(len(buffer)) 314 | 315 | if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): 316 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 317 | 318 | return download_target 319 | 320 | 321 | def has_hf_hub(necessary=False): 322 | if not _has_hf_hub and necessary: 323 | # if no HF Hub module installed, and it is necessary to continue, raise error 324 | raise RuntimeError( 325 | 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') 326 | return _has_hf_hub 327 | 328 | 329 | def download_pretrained_from_hf( 330 | model_id: str, 331 | filename: str = 'open_clip_pytorch_model.bin', 332 | revision=None, 333 | cache_dir: Union[str, None] = None, 334 | ): 335 | has_hf_hub(True) 336 | cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) 337 | return cached_file 338 | 339 | 340 | def download_pretrained( 341 | cfg: Dict, 342 | force_hf_hub: bool = False, 343 | cache_dir: Union[str, None] = None, 344 | ): 345 | target = '' 346 | if not cfg: 347 | return target 348 | 349 | download_url = cfg.get('url', '') 350 | download_hf_hub = cfg.get('hf_hub', '') 351 | if download_hf_hub and force_hf_hub: 352 | # use HF hub even if url exists 353 | download_url = '' 354 | 355 | if download_url: 356 | target = download_pretrained_from_url(download_url, cache_dir=cache_dir) 357 | elif download_hf_hub: 358 | has_hf_hub(True) 359 | # we assume the hf_hub entries in pretrained config combine model_id + filename in 360 | # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and 361 | # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. 
362 | model_id, filename = os.path.split(download_hf_hub) 363 | if filename: 364 | target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) 365 | else: 366 | target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) 367 | 368 | return target 369 | -------------------------------------------------------------------------------- /OPENCLIP/push_to_hf_hub.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | from tempfile import TemporaryDirectory 5 | from typing import Optional, Tuple 6 | 7 | import torch 8 | 9 | try: 10 | from huggingface_hub import ( 11 | create_repo, 12 | get_hf_file_metadata, 13 | hf_hub_download, 14 | hf_hub_url, 15 | repo_type_and_id_from_hf_id, 16 | upload_folder, 17 | ) 18 | from huggingface_hub.utils import EntryNotFoundError 19 | _has_hf_hub = True 20 | except ImportError: 21 | _has_hf_hub = False 22 | 23 | from .factory import create_model_from_pretrained, get_model_config, get_tokenizer 24 | from .tokenizer import HFTokenizer 25 | 26 | 27 | def save_config_for_hf( 28 | model, 29 | config_path: str, 30 | model_config: Optional[dict] 31 | ): 32 | preprocess_cfg = { 33 | 'mean': model.visual.image_mean, 34 | 'std': model.visual.image_std, 35 | } 36 | hf_config = { 37 | 'model_cfg': model_config, 38 | 'preprocess_cfg': preprocess_cfg, 39 | } 40 | 41 | with config_path.open('w') as f: 42 | json.dump(hf_config, f, indent=2) 43 | 44 | 45 | def save_for_hf( 46 | model, 47 | tokenizer: HFTokenizer, 48 | model_config: dict, 49 | save_directory: str, 50 | weights_filename='open_clip_pytorch_model.bin', 51 | config_filename='open_clip_config.json', 52 | ): 53 | save_directory = Path(save_directory) 54 | save_directory.mkdir(exist_ok=True, parents=True) 55 | 56 | weights_path = save_directory / weights_filename 57 | torch.save(model.state_dict(), weights_path) 58 | 59 | tokenizer.save_pretrained(save_directory) 60 | 61 | config_path = save_directory / config_filename 62 | save_config_for_hf(model, config_path, model_config=model_config) 63 | 64 | 65 | def push_to_hf_hub( 66 | model, 67 | tokenizer, 68 | model_config: Optional[dict], 69 | repo_id: str, 70 | commit_message: str = 'Add model', 71 | token: Optional[str] = None, 72 | revision: Optional[str] = None, 73 | private: bool = False, 74 | create_pr: bool = False, 75 | model_card: Optional[dict] = None, 76 | ): 77 | if not isinstance(tokenizer, HFTokenizer): 78 | # default Multitasking tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 79 | tokenizer = HFTokenizer('openai/clip-vit-large-patch14') 80 | 81 | # Create repo if it doesn't exist yet 82 | repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) 83 | 84 | # Infer complete repo_id from repo_url 85 | # Can be different from the input `repo_id` if repo_owner was implicit 86 | _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) 87 | repo_id = f"{repo_owner}/{repo_name}" 88 | 89 | # Check if README file already exist in repo 90 | try: 91 | get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) 92 | has_readme = True 93 | except EntryNotFoundError: 94 | has_readme = False 95 | 96 | # Dump model and push to Hub 97 | with TemporaryDirectory() as tmpdir: 98 | # Save model weights and config. 
99 | save_for_hf( 100 | model, 101 | tokenizer=tokenizer, 102 | model_config=model_config, 103 | save_directory=tmpdir, 104 | ) 105 | 106 | # Add readme if it does not exist 107 | if not has_readme: 108 | model_card = model_card or {} 109 | model_name = repo_id.split('/')[-1] 110 | readme_path = Path(tmpdir) / "README.md" 111 | readme_text = generate_readme(model_card, model_name) 112 | readme_path.write_text(readme_text) 113 | 114 | # Upload model and return 115 | return upload_folder( 116 | repo_id=repo_id, 117 | folder_path=tmpdir, 118 | revision=revision, 119 | create_pr=create_pr, 120 | commit_message=commit_message, 121 | ) 122 | 123 | 124 | def push_pretrained_to_hf_hub( 125 | model_name, 126 | pretrained: str, 127 | repo_id: str, 128 | image_mean: Optional[Tuple[float, ...]] = None, 129 | image_std: Optional[Tuple[float, ...]] = None, 130 | commit_message: str = 'Add model', 131 | token: Optional[str] = None, 132 | revision: Optional[str] = None, 133 | private: bool = False, 134 | create_pr: bool = False, 135 | model_card: Optional[dict] = None, 136 | ): 137 | model, preprocess_eval = create_model_from_pretrained( 138 | model_name, 139 | pretrained=pretrained, 140 | image_mean=image_mean, 141 | image_std=image_std, 142 | ) 143 | 144 | model_config = get_model_config(model_name) 145 | assert model_config 146 | 147 | tokenizer = get_tokenizer(model_name) 148 | 149 | push_to_hf_hub( 150 | model=model, 151 | tokenizer=tokenizer, 152 | model_config=model_config, 153 | repo_id=repo_id, 154 | commit_message=commit_message, 155 | token=token, 156 | revision=revision, 157 | private=private, 158 | create_pr=create_pr, 159 | model_card=model_card, 160 | ) 161 | 162 | 163 | def generate_readme(model_card: dict, model_name: str): 164 | readme_text = "---\n" 165 | readme_text += "tags:\n- zero-shot-image-classification\n- clip\n" 166 | readme_text += "library_tag: open_clip\n" 167 | readme_text += f"license: {model_card.get('license', 'mit')}\n" 168 | if 'details' in model_card and 'Dataset' in model_card['details']: 169 | readme_text += 'datasets:\n' 170 | readme_text += f"- {model_card['details']['Dataset'].lower()}\n" 171 | readme_text += "---\n" 172 | readme_text += f"# Model card for {model_name}\n" 173 | if 'description' in model_card: 174 | readme_text += f"\n{model_card['description']}\n" 175 | if 'details' in model_card: 176 | readme_text += f"\n## Model Details\n" 177 | for k, v in model_card['details'].items(): 178 | if isinstance(v, (list, tuple)): 179 | readme_text += f"- **{k}:**\n" 180 | for vi in v: 181 | readme_text += f" - {vi}\n" 182 | elif isinstance(v, dict): 183 | readme_text += f"- **{k}:**\n" 184 | for ki, vi in v.items(): 185 | readme_text += f" - {ki}: {vi}\n" 186 | else: 187 | readme_text += f"- **{k}:** {v}\n" 188 | if 'usage' in model_card: 189 | readme_text += f"\n## Model Usage\n" 190 | readme_text += model_card['usage'] 191 | readme_text += '\n' 192 | 193 | if 'comparison' in model_card: 194 | readme_text += f"\n## Model Comparison\n" 195 | readme_text += model_card['comparison'] 196 | readme_text += '\n' 197 | 198 | if 'citation' in model_card: 199 | readme_text += f"\n## Citation\n" 200 | if not isinstance(model_card['citation'], (list, tuple)): 201 | citations = [model_card['citation']] 202 | else: 203 | citations = model_card['citation'] 204 | for c in citations: 205 | readme_text += f"```bibtex\n{c}\n```\n" 206 | 207 | return readme_text 208 | 209 | 210 | if __name__ == "__main__": 211 | parser = argparse.ArgumentParser(description="Push to Hugging Face Hub") 
212 | parser.add_argument( 213 | "--model", type=str, help="Name of the model to use.", 214 | ) 215 | parser.add_argument( 216 | "--pretrained", type=str, 217 | help="Use pretrained Multitasking model weights with the specified tag or file path.", 218 | ) 219 | parser.add_argument( 220 | "--repo-id", type=str, 221 | help="Destination HF Hub repo-id, i.e. 'organization/model_id'.", 222 | ) 223 | parser.add_argument( 224 | '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', 225 | help='Override default image mean value of dataset') 226 | parser.add_argument( 227 | '--image-std', type=float, nargs='+', default=None, metavar='STD', 228 | help='Override default image std deviation of dataset') 229 | args = parser.parse_args() 230 | 231 | print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}') 232 | 233 | # FIXME add support to pass model_card json / template from file via cmd line 234 | 235 | push_pretrained_to_hf_hub( 236 | args.model, 237 | args.pretrained, 238 | args.repo_id, 239 | image_mean=args.image_mean, # override image mean/std if trained w/ non defaults 240 | image_std=args.image_std, 241 | ) 242 | 243 | print(f'{args.model} saved.') 244 | -------------------------------------------------------------------------------- /OPENCLIP/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in a Multitasking model. 4 | """ 5 | import logging 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | try: 12 | import timm 13 | from timm.models.layers import Mlp, to_2tuple 14 | try: 15 | # old timm imports < 0.8.1 16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 18 | except ImportError: 19 | # new timm imports >= 0.8.1 20 | from timm.layers import RotAttentionPool2d 21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 22 | except ImportError: 23 | timm = None 24 | 25 | from .utils import freeze_batch_norm_2d 26 | 27 | 28 | class TimmModel(nn.Module): 29 | """ timm model adapter 30 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 31 | """ 32 | 33 | def __init__( 34 | self, 35 | model_name, 36 | embed_dim, 37 | image_size=224, 38 | pool='avg', 39 | proj='linear', 40 | proj_bias=False, 41 | drop=0., 42 | drop_path=None, 43 | pretrained=False, 44 | ): 45 | super().__init__() 46 | if timm is None: 47 | raise RuntimeError("Please `pip install timm` to use timm models.") 48 | 49 | self.image_size = to_2tuple(image_size) 50 | timm_kwargs = {} 51 | if drop_path is not None: 52 | timm_kwargs['drop_path_rate'] = drop_path 53 | self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs) 54 | feat_size = self.trunk.default_cfg.get('pool_size', None) 55 | feature_ndim = 1 if not feat_size else 2 56 | if pool in ('abs_attn', 'rot_attn'): 57 | assert feature_ndim == 2 58 | # if attn pooling used, remove both classifier and default pool 59 | self.trunk.reset_classifier(0, global_pool='') 60 | else: 61 | # reset global pool if pool config set, otherwise leave as network default 62 | reset_kwargs = dict(global_pool=pool) if pool else {} 63 | self.trunk.reset_classifier(0, **reset_kwargs) 64 | prev_chs = self.trunk.num_features 65 | 66 | 
head_layers = OrderedDict() 67 | if pool == 'abs_attn': 68 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 69 | prev_chs = embed_dim 70 | elif pool == 'rot_attn': 71 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 72 | prev_chs = embed_dim 73 | else: 74 | assert proj, 'projection layer needed if non-attention pooling is used.' 75 | 76 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 77 | if proj == 'linear': 78 | head_layers['drop'] = nn.Dropout(drop) 79 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 80 | elif proj == 'mlp': 81 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) 82 | 83 | self.head = nn.Sequential(head_layers) 84 | 85 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 86 | """ lock modules 87 | Args: 88 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 89 | """ 90 | if not unlocked_groups: 91 | # lock full model 92 | for param in self.trunk.parameters(): 93 | param.requires_grad = False 94 | if freeze_bn_stats: 95 | freeze_batch_norm_2d(self.trunk) 96 | else: 97 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 98 | try: 99 | # FIXME import here until API stable and in an official release 100 | from timm.models.helpers import group_parameters, group_modules 101 | except ImportError: 102 | raise RuntimeError( 103 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 104 | matcher = self.trunk.group_matcher() 105 | gparams = group_parameters(self.trunk, matcher) 106 | max_layer_id = max(gparams.keys()) 107 | max_layer_id = max_layer_id - unlocked_groups 108 | for group_idx in range(max_layer_id + 1): 109 | group = gparams[group_idx] 110 | for param in group: 111 | self.trunk.get_parameter(param).requires_grad = False 112 | if freeze_bn_stats: 113 | gmodules = group_modules(self.trunk, matcher, reverse=True) 114 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 115 | freeze_batch_norm_2d(self.trunk, gmodules) 116 | 117 | @torch.jit.ignore 118 | def set_grad_checkpointing(self, enable=True): 119 | try: 120 | self.trunk.set_grad_checkpointing(enable) 121 | except Exception as e: 122 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') 123 | 124 | def forward(self, x): 125 | x = self.trunk(x) 126 | x = self.head(x) 127 | return x 128 | -------------------------------------------------------------------------------- /OPENCLIP/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ Multitasking tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | # https://stackoverflow.com/q/62691279 16 | import os 17 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 18 | 19 | 20 | @lru_cache() 21 | def default_bpe(): 22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 23 | 24 | 25 | @lru_cache() 26 | def bytes_to_unicode(): 27 | """ 28 | Returns list of utf-8 byte and a corresponding list of unicode strings. 
29 | The reversible bpe codes work on unicode strings. 30 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 31 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 32 | This is a significant percentage of your normal, say, 32K bpe vocab. 33 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 34 | And avoids mapping to whitespace/control characters the bpe code barfs on. 35 | """ 36 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 37 | cs = bs[:] 38 | n = 0 39 | for b in range(2**8): 40 | if b not in bs: 41 | bs.append(b) 42 | cs.append(2**8+n) 43 | n += 1 44 | cs = [chr(n) for n in cs] 45 | return dict(zip(bs, cs)) 46 | 47 | 48 | def get_pairs(word): 49 | """Return set of symbol pairs in a word. 50 | Word is represented as tuple of symbols (symbols being variable-length strings). 51 | """ 52 | pairs = set() 53 | prev_char = word[0] 54 | for char in word[1:]: 55 | pairs.add((prev_char, char)) 56 | prev_char = char 57 | return pairs 58 | 59 | 60 | def basic_clean(text): 61 | text = ftfy.fix_text(text) 62 | text = html.unescape(html.unescape(text)) 63 | return text.strip() 64 | 65 | 66 | def whitespace_clean(text): 67 | text = re.sub(r'\s+', ' ', text) 68 | text = text.strip() 69 | return text 70 | 71 | 72 | class SimpleTokenizer(object): 73 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 74 | self.byte_encoder = bytes_to_unicode() 75 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 76 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 77 | merges = merges[1:49152-256-2+1] 78 | merges = [tuple(merge.split()) for merge in merges] 79 | vocab = list(bytes_to_unicode().values()) 80 | vocab = vocab + [v+'</w>' for v in vocab] 81 | for merge in merges: 82 | vocab.append(''.join(merge)) 83 | if not special_tokens: 84 | special_tokens = ['<start_of_text>', '<end_of_text>'] 85 | else: 86 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens 87 | vocab.extend(special_tokens) 88 | self.encoder = dict(zip(vocab, range(len(vocab)))) 89 | self.decoder = {v: k for k, v in self.encoder.items()} 90 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 91 | self.cache = {t:t for t in special_tokens} 92 | special = "|".join(special_tokens) 93 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 94 | 95 | self.vocab_size = len(self.encoder) 96 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 97 | 98 | def bpe(self, token): 99 | if token in self.cache: 100 | return self.cache[token] 101 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 102 | pairs = get_pairs(word) 103 | 104 | if not pairs: 105 | return token+'</w>' 106 | 107 | while True: 108 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 109 | if bigram not in self.bpe_ranks: 110 | break 111 | first, second = bigram 112 | new_word = [] 113 | i = 0 114 | while i < len(word): 115 | try: 116 | j = word.index(first, i) 117 | new_word.extend(word[i:j]) 118 | i = j 119 | except: 120 | new_word.extend(word[i:]) 121 | break 122 | 123 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 124 | new_word.append(first+second) 125 | i += 2 126 | else: 127 | new_word.append(word[i]) 128 | i += 1 129 | new_word = tuple(new_word) 130 | word = new_word 131 | if len(word) == 1: 132 | break 133 | else: 134 | pairs = get_pairs(word) 135 | word = ' '.join(word)
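# Cache the fully merged word so later occurrences of this token skip the merge loop above.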
136 | self.cache[token] = word 137 | return word 138 | 139 | def encode(self, text): 140 | bpe_tokens = [] 141 | text = whitespace_clean(basic_clean(text)).lower() 142 | for token in re.findall(self.pat, text): 143 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 144 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 145 | return bpe_tokens 146 | 147 | def decode(self, tokens): 148 | text = ''.join([self.decoder[token] for token in tokens]) 149 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 150 | return text 151 | 152 | 153 | _tokenizer = SimpleTokenizer() 154 | 155 | def decode(output_ids: torch.Tensor): 156 | output_ids = output_ids.cpu().numpy() 157 | return _tokenizer.decode(output_ids) 158 | 159 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 160 | """ 161 | Returns the tokenized representation of given input string(s) 162 | 163 | Parameters 164 | ---------- 165 | texts : Union[str, List[str]] 166 | An input string or a list of input strings to tokenize 167 | context_length : int 168 | The context length to use; all Multitasking models use 77 as the context length 169 | 170 | Returns 171 | ------- 172 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 173 | """ 174 | if isinstance(texts, str): 175 | texts = [texts] 176 | 177 | sot_token = _tokenizer.encoder["<start_of_text>"] 178 | eot_token = _tokenizer.encoder["<end_of_text>"] 179 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 180 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 181 | 182 | for i, tokens in enumerate(all_tokens): 183 | if len(tokens) > context_length: 184 | tokens = tokens[:context_length] # Truncate 185 | tokens[-1] = eot_token 186 | result[i, :len(tokens)] = torch.tensor(tokens) 187 | 188 | return result 189 | 190 | 191 | class HFTokenizer: 192 | """HuggingFace tokenizer wrapper""" 193 | 194 | def __init__(self, tokenizer_name: str, cache_dir=None): 195 | from transformers import AutoTokenizer 196 | self.cache_dir = os.path.join(cache_dir, tokenizer_name) 197 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, cache_dir=self.cache_dir) 198 | 199 | def save_pretrained(self, dest): 200 | self.tokenizer.save_pretrained(dest, cache_dir=self.cache_dir) 201 | 202 | def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: 203 | # same cleaning as for default tokenizer, except lowercasing 204 | # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance 205 | if isinstance(texts, str): 206 | texts = [texts] 207 | texts = [whitespace_clean(basic_clean(text)) for text in texts] 208 | input_ids = self.tokenizer( 209 | texts, 210 | return_tensors='pt', 211 | max_length=context_length, 212 | padding='max_length', 213 | truncation=True, 214 | ).input_ids 215 | return input_ids 216 | -------------------------------------------------------------------------------- /OPENCLIP/transform.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from dataclasses import dataclass, asdict 3 | from typing import Any, Dict, Optional, Sequence, Tuple, Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torchvision.transforms.functional as F 8 | 9 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, 
InterpolationMode, ToTensor, Resize, \ 10 | CenterCrop 11 | 12 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 13 | 14 | 15 | @dataclass 16 | class AugmentationCfg: 17 | scale: Tuple[float, float] = (0.9, 1.0) 18 | ratio: Optional[Tuple[float, float]] = None 19 | color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None 20 | interpolation: Optional[str] = None 21 | re_prob: Optional[float] = None 22 | re_count: Optional[int] = None 23 | use_timm: bool = False 24 | 25 | 26 | class ResizeMaxSize(nn.Module): 27 | 28 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): 29 | super().__init__() 30 | if not isinstance(max_size, int): 31 | raise TypeError(f"Size should be int. Got {type(max_size)}") 32 | self.max_size = max_size 33 | self.interpolation = interpolation 34 | self.fn = min if fn == 'min' else min 35 | self.fill = fill 36 | 37 | def forward(self, img): 38 | if isinstance(img, torch.Tensor): 39 | height, width = img.shape[:2] 40 | else: 41 | width, height = img.size 42 | scale = self.max_size / float(max(height, width)) 43 | if scale != 1.0: 44 | new_size = tuple(round(dim * scale) for dim in (height, width)) 45 | img = F.resize(img, new_size, self.interpolation) 46 | pad_h = self.max_size - new_size[0] 47 | pad_w = self.max_size - new_size[1] 48 | img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) 49 | return img 50 | 51 | 52 | def _convert_to_rgb(image): 53 | return image.convert('RGB') 54 | 55 | 56 | def image_transform( 57 | image_size: int, 58 | is_train: bool, 59 | mean: Optional[Tuple[float, ...]] = None, 60 | std: Optional[Tuple[float, ...]] = None, 61 | resize_longest_max: bool = False, 62 | fill_color: int = 0, 63 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, 64 | ): 65 | mean = mean or OPENAI_DATASET_MEAN 66 | if not isinstance(mean, (list, tuple)): 67 | mean = (mean,) * 3 68 | 69 | std = std or OPENAI_DATASET_STD 70 | if not isinstance(std, (list, tuple)): 71 | std = (std,) * 3 72 | 73 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 74 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 75 | image_size = image_size[0] 76 | 77 | if isinstance(aug_cfg, dict): 78 | aug_cfg = AugmentationCfg(**aug_cfg) 79 | else: 80 | aug_cfg = aug_cfg or AugmentationCfg() 81 | normalize = Normalize(mean=mean, std=std) 82 | if is_train: 83 | aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} 84 | use_timm = aug_cfg_dict.pop('use_timm', False) 85 | if use_timm: 86 | from timm.data import create_transform # timm can still be optional 87 | if isinstance(image_size, (tuple, list)): 88 | assert len(image_size) >= 2 89 | input_size = (3,) + image_size[-2:] 90 | else: 91 | input_size = (3, image_size, image_size) 92 | # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time 93 | aug_cfg_dict.setdefault('interpolation', 'random') 94 | aug_cfg_dict.setdefault('color_jitter', None) # disable by default 95 | train_transform = create_transform( 96 | input_size=input_size, 97 | is_training=True, 98 | hflip=0., 99 | mean=mean, 100 | std=std, 101 | re_mode='pixel', 102 | **aug_cfg_dict, 103 | ) 104 | else: 105 | train_transform = Compose([ 106 | RandomResizedCrop( 107 | image_size, 108 | scale=aug_cfg_dict.pop('scale'), 109 | interpolation=InterpolationMode.BICUBIC, 110 | ), 111 | _convert_to_rgb, 112 | ToTensor(), 113 | normalize, 114 
| ]) 115 | if aug_cfg_dict: 116 | warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') 117 | return train_transform 118 | else: 119 | if resize_longest_max: 120 | transforms = [ 121 | ResizeMaxSize(image_size, fill=fill_color) 122 | ] 123 | else: 124 | transforms = [ 125 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 126 | CenterCrop(image_size), 127 | ] 128 | transforms.extend([ 129 | _convert_to_rgb, 130 | ToTensor(), 131 | normalize, 132 | ]) 133 | return Compose(transforms) 134 | -------------------------------------------------------------------------------- /OPENCLIP/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | 4 | from torch import nn as nn 5 | from torchvision.ops.misc import FrozenBatchNorm2d 6 | 7 | 8 | def freeze_batch_norm_2d(module, module_match={}, name=''): 9 | """ 10 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 11 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 12 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 13 | 14 | Args: 15 | module (torch.nn.Module): Any PyTorch module. 16 | module_match (dict): Dictionary of full module names to freeze (all if empty) 17 | name (str): Full module name (prefix) 18 | 19 | Returns: 20 | torch.nn.Module: Resulting module 21 | 22 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 23 | """ 24 | res = module 25 | is_match = True 26 | if module_match: 27 | is_match = name in module_match 28 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 29 | res = FrozenBatchNorm2d(module.num_features) 30 | res.num_features = module.num_features 31 | res.affine = module.affine 32 | if module.affine: 33 | res.weight.data = module.weight.data.clone().detach() 34 | res.bias.data = module.bias.data.clone().detach() 35 | res.running_mean.data = module.running_mean.data 36 | res.running_var.data = module.running_var.data 37 | res.eps = module.eps 38 | else: 39 | for child_name, child in module.named_children(): 40 | full_child_name = '.'.join([name, child_name]) if name else child_name 41 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 42 | if new_child is not child: 43 | res.add_module(child_name, new_child) 44 | return res 45 | 46 | 47 | # From PyTorch internals 48 | def _ntuple(n): 49 | def parse(x): 50 | if isinstance(x, collections.abc.Iterable): 51 | return x 52 | return tuple(repeat(x, n)) 53 | return parse 54 | 55 | 56 | to_1tuple = _ntuple(1) 57 | to_2tuple = _ntuple(2) 58 | to_3tuple = _ntuple(3) 59 | to_4tuple = _ntuple(4) 60 | to_ntuple = lambda n, x: _ntuple(n)(x) 61 | -------------------------------------------------------------------------------- /OPENCLIP/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '2.13.0' 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Leveraging Vision-Language Foundation Models for Fine-Grained Downstream Tasks 2 | 3 | This repository is the official implementation of the paper: 4 | "Leveraging Vision-Language 
Foundation Models for Fine-Grained Downstream Tasks" 5 | 6 | ![image](visual.png) 7 | 8 | The paper is available [here](https://arxiv.org/abs/2307.06795). 9 | 10 | The pre-trained model is available [here](https://zenodo.org/record/8124014). 11 | 12 | ## Getting Started 13 | 14 | Configuration used for the paper: 15 | - Python: 3.10.9 16 | - PyTorch: 1.13.0 17 | - CUDA: 11.7 18 | - CUDNN: 8500 19 | 20 | ### Installation 21 | ```commandline 22 | git clone https://github.com/FactoDeepLearning/MultitaskVLFM.git 23 | cd MultitaskVLFM 24 | 25 | # In your virtualenv, for instance: 26 | conda create --name VLFM python=3.10 27 | conda activate VLFM 28 | 29 | pip install -e . 30 | ``` 31 | 32 | 33 | ## Reproducing results 34 | ```commandline 35 | python3 Multitasking/main.py --load-pretrain=True --train=False 36 | ``` 37 | 38 | ## Training 39 | ```commandline 40 | python3 Multitasking/main.py 41 | ``` 42 | 43 | Here is a list of the available arguments: 44 | ```commandline 45 | --output-name # Name of the output folder [default="my_expe"] 46 | --batch-size # Size of the training batch [default=2] 47 | --num-epochs # Number of training epochs [default=100] 48 | 49 | --oracle # Which oracle to use [default="OFA"] 50 | # Other option is "UIO", ignored if "--train-oracle-loc" is False 51 | 52 | --model # Which model to use as basis [default="open-ViT-L/14"] 53 | # Official CLIP: "ViT-B/32", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px" 54 | # CLIP trained on LAION: "open-ViT-L/14", "open-ViT-H-14", "open-ViT-G-14" 55 | 56 | --image-encoder # Which image encoder to use [default="swin_vision"] 57 | # None to keep the CLIP image encoder, or "swin_vision" 58 | 59 | --load-pretrain # Whether to load pre-trained weights or not [default=False] 60 | # Must be used with default values for --train-X options 61 | 62 | --train # Whether to train the model until reaching "--num-epochs" or not [default=True] 63 | --eval # Whether to evaluate the model on the test set or not [default=True] 64 | 65 | --train-class # Whether to use classification loss (through embedding similarity) or not [default=True] 66 | --train-attr # Whether to use attribute detection loss (through embedding similarity) or not [default=True] 67 | --train-loc # Whether to use attribute localization loss (with expert annotation as ground truth) or not [default=True] 68 | --train-oracle-loc # Whether to use attribute localization loss (with oracle annotation as ground truth) or not [default=False] 69 | --train-proj-class # Whether to use classification loss (through projection) or not [default=True] 70 | --train-proj-attr # Whether to use attribute detection loss (through projection) or not [default=False] 71 | 72 | # Weights for the different losses 73 | --weight-class [default=1] 74 | --weight-attr [default=1] 75 | --weight-loc [default=1] 76 | --weight-oracle-loc [default=1] 77 | --weight-proj-attr [default=1] 78 | --weight-proj-class [default=1] 79 | 80 | --adapter-image # Whether to only train image encoder last layers or not [default=False] 81 | --adapter-text # Whether to only train text encoder last layers or not [default=True] 82 | --neg-attributes # Whether to use the positive/negative prompt formulation or not [default=True] 83 | ``` 84 | 85 | ## Citation 86 | 87 | ```bibtex 88 | @misc{Coquenet2023d, 89 | title={Leveraging Vision-Language Foundation Models for Fine-Grained Downstream Tasks}, 90 | author={Denis Coquenet and Clément Rambour and Emanuele Dalsasso and Nicolas Thome}, 91 | year={2023}, 92 | eprint={2307.06795}, 93 | 
archivePrefix={arXiv}, 94 | primaryClass={cs.CV} 95 | } 96 | ``` 97 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_namespace_packages 2 | 3 | setup(name='MultitaskVLFM', 4 | packages=find_namespace_packages(include=["Multitasking", "Multitasking.*"]), 5 | version='1.0.0', 6 | install_requires=[ 7 | "torch==1.13.0", 8 | "torchvision==0.14.0", 9 | "tensorboard", 10 | "scikit-learn", 11 | "clip", 12 | "open_clip_torch", 13 | "transformers", 14 | "tqdm", 15 | "pillow", 16 | "einops", 17 | "wget", 18 | "clip @ git+https://github.com/openai/CLIP.git" 19 | ] 20 | ) -------------------------------------------------------------------------------- /visual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FactoDeepLearning/MultitaskVLFM/cde10ef783a5cdeaf1e06016560dfb6cc2a3ffa2/visual.png --------------------------------------------------------------------------------