├── .gitignore
├── README.md
├── eval.py
├── model_conf
│   └── vq-ViT-L-14-k64
│       └── config.json
├── requirements.txt
├── setup.py
├── train_vqclip_from_preembed.py
├── training_conf
│   ├── VQ-ViT-L-14-affine.yaml
│   └── VQ-ViT-L-14.yaml
└── vq_clip
    ├── __init__.py
    ├── cosine_annealing_warmup.py
    ├── embedding_dataset.py
    ├── eval.py
    ├── modeling_vq_adapter.py
    ├── modeling_vq_clip.py
    ├── modules.py
    ├── perplexity.py
    └── trainer.py

/.gitignore:
--------------------------------------------------------------------------------
1 | **/*/__pycache__/
2 | __pycache__/
3 | ckpt/
4 | out/
5 | build/
6 | vq_clip.egg-info/
7 | 
8 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VQ-CLIP
2 | 
3 | Fine-tune a CLIP model with a vector quantization bottleneck layer over the output embeddings. The quantization step is applied only to the final normalized CLIP embedding, so it can be trained on a dataset of frozen CLIP embeddings.
4 | 
5 | # Pretrained VQ-CLIP models
6 | 
7 | ### On top of openai/ViT-L-14
8 | 
9 | These models were trained for roughly one epoch on datacomp medium, with a batch size of 16384. See `training_conf/VQ-ViT-L-14.yaml` for the training parameters that were used.
10 | 
11 | * [k=64, 32 heads, multi-headed VQ](https://huggingface.co/adams-story/vq-ViT-L-14-k64-d32-ema/tree/main): gets 0.642 @1 on ImageNet validation. Trained with an EMA codebook rather than a learnable one.
12 | 
13 | * [k=32, 32 heads, residual quantization](https://huggingface.co/adams-story/vq-ViT-L-14-k32): gets 0.51 @1 on ImageNet validation.
14 | 
15 | * [k=64, 32 heads, VQ with affine parameters](https://huggingface.co/adams-story/vq-ViT-L-14-k64-d32): gets 0.586 @1 on ImageNet validation.
16 | 
17 | # Set up env
18 | 
19 | ```
20 | $ conda create -n vq-clip
21 | $ conda activate vq-clip
22 | $ conda install pip -y
23 | $ pip install -r requirements.txt
24 | ```
25 | 
26 | # Load a pretrained model
27 | 
28 | This will print a bunch of lines to the console complaining about missing `clip_model` weights in the state dict. Don't worry about it; the CLIP weights are loaded from the `clip_path` argument.
29 | 
30 | ```python
31 | from PIL import Image
32 | import requests
33 | from vq_clip import VQCLIPModel
34 | from transformers import CLIPProcessor
35 | 
36 | model = VQCLIPModel.from_pretrained_clip(clip_path="openai/clip-vit-large-patch14", vision_vq_adapter_path="adams-story/vq-ViT-L-14-k32", )
37 | 
38 | # make a prediction
39 | processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
40 | 
41 | url = "http://images.cocodataset.org/val2017/000000039769.jpg"
42 | image = Image.open(requests.get(url, stream=True).raw)
43 | 
44 | inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)
45 | 
46 | outputs = model(**inputs)
47 | logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
48 | probs = logits_per_image.softmax(dim=1)  # take the softmax to get the label probabilities
49 | print(probs)
50 | codes = outputs.image_codes  # the vq codes
51 | ```
52 | 
53 | 
54 | # Set up training data
55 | 
56 | You can train VQ-CLIP from a dataset of text-image CLIP embeddings. You can find these on [HuggingFace](https://huggingface.co/mlfoundations); I'd recommend using the image/text embeddings from the [datacomp 1B](https://huggingface.co/datasets/mlfoundations/datacomp_1b) dataset.
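Alternatively, you can build your own embedding dataset. The data loader in `vq_clip/embedding_dataset.py` reads paired `image/*.npy` and `text/*.npy` shards with matching file names, so any frozen CLIP embeddings written in that layout should work. The sketch below is not part of this repo (the output folder, shard name, and example image/caption are placeholders); it shows roughly how a shard pair could be produced with the same CLIP backbone:

```python
# Sketch only (not part of this repo): pre-embed your own image-text pairs with
# the same frozen CLIP backbone, in the image/ + text/ .npy layout that
# vq_clip/embedding_dataset.py globs for. Paths and file names are placeholders.
import os

import numpy as np
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

images = [Image.open("example.jpg")]  # placeholder image
captions = ["a photo of a cat"]       # placeholder caption

with torch.no_grad():
    img_emb = model.get_image_features(**processor(images=images, return_tensors="pt"))
    txt_emb = model.get_text_features(**processor(text=captions, return_tensors="pt", padding=True))
    # the VQ bottleneck is applied to the normalized CLIP embedding, so store normalized vectors
    img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
    txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)

# one .npy shard per modality; the image and text shards must share a file name
# and line up row-for-row
os.makedirs("my_embeds/image", exist_ok=True)
os.makedirs("my_embeds/text", exist_ok=True)
np.save("my_embeds/image/00000000.npy", img_emb.numpy())
np.save("my_embeds/text/00000000.npy", txt_emb.numpy())
```

To use the pre-computed datacomp embeddings instead, download them as described below.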
57 | 
58 | Only the .npz files are needed; these can be downloaded using the huggingface `snapshot_download` function.
59 | 
60 | This code downloads the dataset into the current directory:
61 | 
62 | ```python
63 | import sys
64 | from huggingface_hub import snapshot_download
65 | size = 'medium'
66 | assert size in {"small", "medium", "large", "xlarge"}
67 | 
68 | snapshot_download(repo_id=f"mlfoundations/datacomp_{size}", repo_type="dataset", cache_dir="./hf-cache", local_dir=f"./{size}/metadata/", local_dir_use_symlinks=True, resume_download=True, allow_patterns="*.npz", max_workers=4)
69 | 
70 | print("\ndone.")
71 | ```
72 | 
73 | You can manually cut a single .npz file from the downloaded data to be used as the validation set.
74 | 
75 | # Training
76 | 
77 | ```
78 | python train_vqclip_from_preembed.py fit -c training_conf/VQ-ViT-L-14.yaml --data.path_train /path/to/size/metadata/ --data.path_val /path/to/validation/metadata/ --model.vq_clip_config_path model_conf/vq-ViT-L-14-k64/config.json
79 | ```
80 | 
81 | By default, training uses ~7GB of VRAM, and saves a checkpoint and runs evaluation every 1000 steps.
82 | 
83 | Training output is saved in the `out/` folder and can be viewed using TensorBoard.
84 | 
85 | # ImageNet evaluation
86 | 
87 | * Download and extract the ImageNet 2012 validation folder: https://academictorrents.com/details/207ebd69f80a3707f035cd91a114466a270e044d
88 | 
89 | * Change the folder structure into a format suitable for the pytorch ImageFolder dataset using [this script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh).
90 | 
91 | * Run evaluation:
92 | 
93 | ```python
94 | import torch
95 | from vq_clip import VQCLIPModel
96 | from transformers import CLIPProcessor
97 | from vq_clip.eval import zero_shot_eval
98 | 
99 | imagenet_path = "/path/to/imagenet/val"
100 | validation_batch_size = 128
101 | device = "cuda" if torch.cuda.is_available() else "cpu"
102 | 
103 | model = VQCLIPModel.from_pretrained_clip(clip_path="openai/clip-vit-large-patch14", vision_vq_adapter_path="adams-story/vq-ViT-L-14-k32", ).to(device)
104 | processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
105 | 
106 | with torch.no_grad():
107 |     with torch.autocast(device):
108 |         top1, top5 = zero_shot_eval(model, processor, imagenet_path, validation_batch_size)
109 | print(top1, top5)
110 | ```
111 | 
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import CLIPProcessor
3 | from vq_clip import VQCLIPModel
4 | from vq_clip.eval import zero_shot_eval
5 | 
6 | 
7 | def evaluate(
8 |     imagenet_path: str = "",
9 |     pretrained_clip_url: str = "openai/clip-vit-large-patch14",
10 |     vq_vision_model_url: str = "adams-story/vq-ViT-L-14-k64-d32",
11 |     batch_size: int = 1024,
12 | ):
13 | 
14 |     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
15 | 
16 |     model = VQCLIPModel.from_pretrained_clip(pretrained_clip_url, vision_vq_adapter_path=vq_vision_model_url)
17 |     print("loaded vqclip")
18 | 
19 |     processor = CLIPProcessor.from_pretrained(pretrained_clip_url)
20 | 
21 |     model = model.to(device)
22 | 
23 |     with torch.no_grad():
24 |         with torch.autocast(str(device)):
25 |             res = zero_shot_eval(model, processor, imagenet_path, batch_size = batch_size)
26 |     print(res)
27 | 
28 | 
29 | if __name__ == "__main__":
30 |     import jsonargparse
31 |     jsonargparse.CLI(evaluate)
32 | 
--------------------------------------------------------------------------------
/model_conf/vq-ViT-L-14-k64/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "VQAdapterModel" 4 | ], 5 | "clip_dim": 768, 6 | "codebook_lr": 10.0, 7 | "is_rq": false, 8 | "mlp_dim": 1028, 9 | "mlp_hidden_dim": 512, 10 | "mlp_layers": 1, 11 | "rq_quantize_dropout": true, 12 | "rq_quantize_dropout_cutoff_index": 1, 13 | "rq_quantize_dropout_multiple_of": 4, 14 | "torch_dtype": "float32", 15 | "transformers_version": "4.30.2", 16 | "vq_accept_image_fmap": false, 17 | "vq_affine_param": false, 18 | "vq_affine_param_batch_decay": 0.99, 19 | "vq_affine_param_codebook_decay": 0.9, 20 | "vq_channel_last": true, 21 | "vq_codebook_dim": 32, 22 | "vq_codebook_size": 64, 23 | "vq_commitment_use_cross_entropy_loss": false, 24 | "vq_commitment_weight": 0.1, 25 | "vq_decay": 0.85, 26 | "vq_ema_update": true, 27 | "vq_eps": 1e-05, 28 | "vq_heads": 32, 29 | "vq_kmeans_init": false, 30 | "vq_kmeans_iters": 20, 31 | "vq_learnable_codebook": false, 32 | "vq_orthogonal_reg_active_codes_only": false, 33 | "vq_orthogonal_reg_max_codes": null, 34 | "vq_orthogonal_reg_weight": 0.0, 35 | "vq_reinmax": false, 36 | "vq_sample_codebook_temp": 1.0, 37 | "vq_separate_codebook_per_head": true, 38 | "vq_stochastic_sample_codes": true, 39 | "vq_straight_through": false, 40 | "vq_sync_affine_param": false, 41 | "vq_sync_codebook": false, 42 | "vq_sync_kmeans": true, 43 | "vq_sync_update_v": 0.0, 44 | "vq_threshold_ema_dead_code": 2, 45 | "vq_use_cosine_sim": false 46 | } 47 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | vector-quantize-pytorch @ git+https://github.com/lucidrains/vector-quantize-pytorch 2 | einops==0.6.1 3 | lightning==2.0.3 4 | numpy==1.24.3 5 | torch==2.0.1 6 | torchvision==0.15.2 7 | tqdm==4.65.0 8 | transformers==4.30.2 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, find_packages 4 | 5 | setup(name='vq-clip', 6 | version='0.1', 7 | packages=find_packages(), 8 | ) 9 | 10 | -------------------------------------------------------------------------------- /train_vqclip_from_preembed.py: -------------------------------------------------------------------------------- 1 | from lightning.pytorch.callbacks import ModelCheckpoint 2 | from lightning.pytorch.cli import LightningCLI 3 | from lightning.pytorch.callbacks.lr_monitor import LearningRateMonitor 4 | import torch 5 | 6 | torch.set_float32_matmul_precision("high") 7 | 8 | # Workaround for 'too many open files' error 9 | # import torch.multiprocessing 10 | # torch.multiprocessing.set_sharing_strategy('file_system') 11 | 12 | from vq_clip.embedding_dataset import LightningEmbeddingDataModule 13 | from vq_clip.trainer import LightningVQCLIPTrainer 14 | 15 | 16 | def main(): 17 | cli = LightningCLI( 18 | LightningVQCLIPTrainer, 19 | LightningEmbeddingDataModule, 20 | trainer_defaults={ 21 | "callbacks": [ 22 | LearningRateMonitor(logging_interval="step"), 23 | ModelCheckpoint( 24 | save_last=True, 25 | ), 26 | ], 27 | }, 28 | ) 29 | 30 | 31 | if __name__ in {"__console__", "__main__"}: 32 | main() 33 | -------------------------------------------------------------------------------- /training_conf/VQ-ViT-L-14-affine.yaml: -------------------------------------------------------------------------------- 1 | # lightning.pytorch==2.1.0dev 2 | seed_everything: true 3 | trainer: 4 | accelerator: auto 5 | 
strategy: auto 6 | devices: auto 7 | num_nodes: 1 8 | precision: bf16-mixed 9 | logger: null 10 | callbacks: null 11 | fast_dev_run: false 12 | max_epochs: 3 13 | min_epochs: null 14 | max_steps: -1 15 | min_steps: null 16 | max_time: null 17 | limit_train_batches: null 18 | limit_val_batches: null 19 | limit_test_batches: null 20 | limit_predict_batches: null 21 | overfit_batches: 0.0 22 | val_check_interval: 2000 23 | check_val_every_n_epoch: 1 24 | num_sanity_val_steps: 0 25 | log_every_n_steps: 1 26 | enable_checkpointing: true 27 | enable_progress_bar: null 28 | enable_model_summary: null 29 | accumulate_grad_batches: 1 30 | gradient_clip_val: 0.5 31 | gradient_clip_algorithm: null 32 | deterministic: null 33 | benchmark: null 34 | inference_mode: true 35 | use_distributed_sampler: true 36 | profiler: null 37 | detect_anomaly: false 38 | barebones: false 39 | plugins: null 40 | sync_batchnorm: false 41 | reload_dataloaders_every_n_epochs: 0 42 | default_root_dir: out/VQ-ViT-L-14/ 43 | 44 | model: 45 | pretrained_clip_url: openai/clip-vit-large-patch14 46 | warmup_steps: 2000 47 | max_lr: 8.e-4 48 | min_lr: 5.e-6 49 | lr_gamma: 0.45 50 | lr_cycle_steps: 20000 51 | 52 | imagenet_path: null 53 | validation_batch_size: 128 54 | 55 | torch_compile: false 56 | data: 57 | key_name: l14 58 | path_train: null 59 | path_val: null 60 | batch_size: 16384 61 | ckpt_path: null 62 | -------------------------------------------------------------------------------- /training_conf/VQ-ViT-L-14.yaml: -------------------------------------------------------------------------------- 1 | # lightning.pytorch==2.1.0dev 2 | seed_everything: true 3 | trainer: 4 | accelerator: auto 5 | strategy: auto 6 | devices: auto 7 | num_nodes: 1 8 | precision: bf16-mixed 9 | logger: null 10 | callbacks: null 11 | fast_dev_run: false 12 | max_epochs: 2 13 | min_epochs: null 14 | max_steps: -1 15 | min_steps: null 16 | max_time: null 17 | limit_train_batches: null 18 | limit_val_batches: null 19 | limit_test_batches: null 20 | limit_predict_batches: null 21 | overfit_batches: 0.0 22 | val_check_interval: 1000 23 | check_val_every_n_epoch: 1 24 | num_sanity_val_steps: 0 25 | log_every_n_steps: 1 26 | enable_checkpointing: true 27 | enable_progress_bar: null 28 | enable_model_summary: null 29 | accumulate_grad_batches: 1 30 | gradient_clip_val: 1.0 31 | gradient_clip_algorithm: null 32 | deterministic: null 33 | benchmark: null 34 | inference_mode: true 35 | use_distributed_sampler: true 36 | profiler: null 37 | detect_anomaly: false 38 | barebones: false 39 | plugins: null 40 | sync_batchnorm: false 41 | reload_dataloaders_every_n_epochs: 0 42 | default_root_dir: out/VQ-ViT-L-14/ 43 | 44 | model: 45 | pretrained_clip_url: openai/clip-vit-large-patch14 46 | warmup_steps: 50 47 | max_lr: 3.e-4 48 | min_lr: 5.e-6 49 | lr_gamma: 0.25 50 | lr_cycle_steps: 20000 51 | 52 | imagenet_path: null 53 | validation_batch_size: 512 54 | 55 | torch_compile: false 56 | data: 57 | path_train: null 58 | path_val: null 59 | batch_size: 24576 60 | ckpt_path: null 61 | -------------------------------------------------------------------------------- /vq_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_vq_clip import VQCLIPModel, VQCLIPOutput, VQCLIPConfig 2 | from .modeling_vq_adapter import VQAdapterModel, VQAdapterConfig 3 | -------------------------------------------------------------------------------- /vq_clip/cosine_annealing_warmup.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Thanks: https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup/blob/master/cosine_annealing_warmup/scheduler.py 3 | """ 4 | import math 5 | import torch 6 | from torch.optim.lr_scheduler import _LRScheduler 7 | 8 | 9 | class CosineAnnealingWarmupRestarts(_LRScheduler): 10 | """ 11 | optimizer (Optimizer): Wrapped optimizer. 12 | first_cycle_steps (int): First cycle step size. 13 | cycle_mult(float): Cycle steps magnification. Default: -1. 14 | max_lr(float): First cycle's max learning rate. Default: 0.1. 15 | min_lr(float): Min learning rate. Default: 0.001. 16 | warmup_steps(int): Linear warmup step size. Default: 0. 17 | gamma(float): Decrease rate of max learning rate by cycle. Default: 1. 18 | last_epoch (int): The index of last epoch. Default: -1. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | optimizer: torch.optim.Optimizer, 24 | first_cycle_steps: int, 25 | cycle_mult: float = 1.0, 26 | max_lr: float = 0.1, 27 | min_lr: float = 0.001, 28 | warmup_steps: int = 0, 29 | gamma: float = 1.0, 30 | last_epoch: int = -1, 31 | ): 32 | assert warmup_steps < first_cycle_steps 33 | 34 | self.first_cycle_steps = first_cycle_steps # first cycle step size 35 | self.cycle_mult = cycle_mult # cycle steps magnification 36 | self.base_max_lr = max_lr # first max learning rate 37 | self.max_lr = max_lr # max learning rate in the current cycle 38 | self.min_lr = min_lr # min learning rate 39 | self.warmup_steps = warmup_steps # warmup step size 40 | self.gamma = gamma # decrease rate of max learning rate by cycle 41 | 42 | self.cur_cycle_steps = first_cycle_steps # first cycle step size 43 | self.cycle = 0 # cycle count 44 | self.step_in_cycle = last_epoch # step size of the current cycle 45 | 46 | super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch) 47 | 48 | # set learning rate min_lr 49 | self.init_lr() 50 | 51 | def init_lr(self): 52 | self.base_lrs = [] 53 | for param_group in self.optimizer.param_groups: 54 | param_group["lr"] = self.min_lr 55 | self.base_lrs.append(self.min_lr) 56 | 57 | def get_lr(self): 58 | if self.step_in_cycle == -1: 59 | return self.base_lrs 60 | elif self.step_in_cycle < self.warmup_steps: 61 | return [ 62 | (self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps 63 | + base_lr 64 | for base_lr in self.base_lrs 65 | ] 66 | else: 67 | return [ 68 | base_lr 69 | + (self.max_lr - base_lr) 70 | * ( 71 | 1 72 | + math.cos( 73 | math.pi 74 | * (self.step_in_cycle - self.warmup_steps) 75 | / (self.cur_cycle_steps - self.warmup_steps) 76 | ) 77 | ) 78 | / 2 79 | for base_lr in self.base_lrs 80 | ] 81 | 82 | def step(self, epoch=None): 83 | if epoch is None: 84 | epoch = self.last_epoch + 1 85 | self.step_in_cycle = self.step_in_cycle + 1 86 | if self.step_in_cycle >= self.cur_cycle_steps: 87 | self.cycle += 1 88 | self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps 89 | self.cur_cycle_steps = ( 90 | int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) 91 | + self.warmup_steps 92 | ) 93 | else: 94 | if epoch >= self.first_cycle_steps: 95 | if self.cycle_mult == 1.0: 96 | self.step_in_cycle = epoch % self.first_cycle_steps 97 | self.cycle = epoch // self.first_cycle_steps 98 | else: 99 | n = int( 100 | math.log( 101 | ( 102 | epoch / self.first_cycle_steps * (self.cycle_mult - 1) 103 | + 1 104 | ), 105 | self.cycle_mult, 106 | ) 107 | ) 108 | self.cycle = n 109 | self.step_in_cycle = epoch - int( 110 | 
self.first_cycle_steps 111 | * (self.cycle_mult**n - 1) 112 | / (self.cycle_mult - 1) 113 | ) 114 | self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** ( 115 | n 116 | ) 117 | else: 118 | self.cur_cycle_steps = self.first_cycle_steps 119 | self.step_in_cycle = epoch 120 | 121 | self.max_lr = self.base_max_lr * (self.gamma**self.cycle) 122 | self.last_epoch = math.floor(epoch) 123 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 124 | param_group["lr"] = lr 125 | -------------------------------------------------------------------------------- /vq_clip/embedding_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | For training from a pre-embedded dataset of image-text pairs from a clip model 3 | 4 | 5 | Make sure that the CLIP model you use is the same as the one used to obtain the 6 | pre embeddings 7 | """ 8 | 9 | import torch.utils.data 10 | from math import ceil 11 | from typing import List 12 | import lightning.pytorch as pl 13 | import numpy as np 14 | from glob import glob 15 | import re 16 | import os 17 | 18 | from torch.utils.data import DataLoader, Dataset, IterableDataset 19 | 20 | 21 | def get_file_code(filename: str): 22 | return os.path.basename(filename).split(".")[-2] 23 | 24 | 25 | def random_sort(*lists): 26 | indices = np.random.permutation(len(lists[0])) 27 | return [l[i] for i in indices for l in lists] 28 | 29 | 30 | class IterableImageTextPairDataset(IterableDataset): 31 | def __init__(self, path: str, batch_size: int, ): 32 | self.img_files = glob(path + "/image/*.npy") 33 | self.img_files.sort(key=get_file_code) 34 | 35 | self.txt_files = glob(path + "/text/*.npy") 36 | self.txt_files.sort(key=get_file_code) 37 | 38 | assert len(self.img_files) > 0 39 | assert len(self.img_files) == len(self.txt_files) 40 | 41 | for txt_file, img_file in zip(self.txt_files, self.img_files): 42 | assert os.path.basename(txt_file) == os.path.basename(img_file), f'{txt_file} != {img_file}' 43 | 44 | self.batch_size = batch_size 45 | 46 | def __iter__(self): 47 | worker_info = torch.utils.data.get_worker_info() 48 | if worker_info is None: 49 | start, end = 0, len(self.img_files) 50 | else: 51 | n_files_per_worker = len(self.img_files) // worker_info.num_workers 52 | worker_id = worker_info.id 53 | start = worker_id * n_files_per_worker 54 | end = min(start + n_files_per_worker, len(self.img_files)) 55 | return ImageTextPairDatasetWorker( 56 | self.img_files[start:end], 57 | self.txt_files[start:end], 58 | batch_size=self.batch_size, 59 | ) 60 | 61 | 62 | class ImageTextPairDatasetWorker(IterableDataset): 63 | def __init__(self, img_files: List[str], txt_files: List[str], batch_size: int, ): 64 | self.img_files = img_files 65 | self.txt_files = txt_files 66 | assert len(self.img_files) == len(self.txt_files) 67 | 68 | # increasing this number improves random sampling 69 | self.num_files_to_load = 2 70 | 71 | self.batch_size = batch_size 72 | 73 | self.file_i = 0 74 | self.batch_i = 0 75 | self.__iter_file() 76 | 77 | def __iter_file(self): 78 | num_files_to_load = min( 79 | len(self.txt_files) - self.file_i, self.num_files_to_load 80 | ) 81 | print("__iter_file loading files ", num_files_to_load) 82 | text_data = [] 83 | image_data = [] 84 | n_loaded = 0 85 | while n_loaded < num_files_to_load: 86 | print("Loading files", self.txt_files[self.file_i], self.img_files[self.file_i]) 87 | try: 88 | img_dat = np.load(self.img_files[self.file_i]) 89 | txt_dat = np.load(self.txt_files[self.file_i]) 90 | 
91 |                 assert len(img_dat) == len(txt_dat)
92 | 
93 |                 text_data.append(img_dat)
94 |                 image_data.append(txt_dat)
95 |                 n_loaded += 1
96 |                 assert len(text_data[-1]) == len(image_data[-1])
97 |                 assert len(text_data[0]) == len(image_data[0])
98 |             except Exception as e:
99 |                 print("error loading files", self.img_files[self.file_i], self.txt_files[self.file_i], e)
100 |             self.file_i += 1
101 | 
102 |         text_data = np.concatenate(text_data, axis=0)
103 |         image_data = np.concatenate(image_data, axis=0)
104 | 
105 |         rnd_indices = np.random.permutation(len(text_data))
106 |         text_data = text_data[rnd_indices]
107 |         image_data = image_data[rnd_indices]
108 | 
109 |         self.text_data = np.array_split(
110 |             text_data, ceil(len(text_data) / self.batch_size)
111 |         )
112 |         self.image_data = np.array_split(
113 |             image_data, ceil(len(image_data) / self.batch_size)
114 |         )
115 |         assert len(self.text_data) == len(self.image_data)
116 |         self.batch_i = 0
117 | 
118 |     def __iter__(self):
119 |         return self
120 | 
121 |     def __len__(self):
122 |         n_files = len(self.img_files)
123 |         num_rows_per_file = 500000
124 |         return n_files * num_rows_per_file // self.batch_size
125 | 
126 |     def __next__(self):
127 |         if self.batch_i >= len(self.image_data):
128 |             if self.file_i >= len(self.img_files):
129 |                 raise StopIteration
130 |             else:
131 |                 self.__iter_file()
132 |         _ret = self.image_data[self.batch_i], self.text_data[self.batch_i]
133 |         self.batch_i += 1
134 |         return _ret
135 | 
136 | 
137 | class LightningEmbeddingDataModule(pl.LightningDataModule):
138 |     def __init__(
139 |         self,
140 |         path_train: str,
141 |         path_val: str,
142 |         batch_size: int,
143 |         *_,
144 |         **kwargs
145 |     ):
146 |         """
147 |         path_train: some path to a directory with the following subfolders:
148 |             image/*.npy
149 |             text/*.npy
150 | 
151 |         This module will iterate over all rows of all npy files.
152 |         """
153 |         super().__init__()
154 |         self.batch_size = batch_size
155 | 
156 |         self.ds_train = IterableImageTextPairDataset(path_train, batch_size, )
157 |         self.ds_test = IterableImageTextPairDataset(path_val, batch_size, )
158 | 
159 |     def train_dataloader(self):
160 |         return DataLoader(self.ds_train, num_workers=4, batch_size=None, shuffle=False)
161 | 
162 |     def val_dataloader(self):
163 |         return DataLoader(self.ds_test, num_workers=1, batch_size=None, shuffle=False)
164 | 
--------------------------------------------------------------------------------
/vq_clip/eval.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from itertools import islice
3 | from typing import Callable, Optional, Sequence, Union
4 | 
5 | import torch
6 | from torch.utils.data import DataLoader
7 | import torchvision
8 | from tqdm import tqdm
9 | from transformers import CLIPModel, CLIPProcessor
10 | 
11 | 
12 | def batched(iterable, n):
13 |     """Batch data into lists of length *n*. The last batch may be shorter.
14 | NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl 15 | """ 16 | it = iter(iterable) 17 | while True: 18 | batch = list(islice(it, n)) 19 | if not batch: 20 | break 21 | yield batch 22 | 23 | 24 | def accuracy(output, target, topk=(1,)): 25 | pred = output.topk(max(topk), 1, True, True)[1].t() 26 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 27 | return [ 28 | float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) 29 | for k in topk 30 | ] 31 | 32 | 33 | def build_zero_shot_classifier( 34 | model: CLIPModel, 35 | processor: CLIPProcessor, 36 | classnames: Sequence[str], 37 | templates: Sequence[Union[Callable, str]], 38 | num_classes_per_batch: Optional[int] = 10, 39 | use_tqdm: bool = False, 40 | ): 41 | """Build zero-shot classifier weights by iterating over class names in batches 42 | Args: 43 | model: CLIP model instance 44 | tokenizer: CLIP tokenizer instance 45 | classnames: A sequence of class (label) names 46 | templates: A sequence of callables or format() friendly strings to produce templates per class name 47 | num_classes_per_batch: The number of classes to batch together in each forward, all if None 48 | device: Device to use. 49 | use_tqdm: Enable TQDM progress bar. 50 | """ 51 | assert isinstance(templates, Sequence) and len(templates) > 0 52 | assert isinstance(classnames, Sequence) and len(classnames) > 0 53 | use_format = isinstance(templates[0], str) 54 | num_templates = len(templates) 55 | num_classes = len(classnames) 56 | if use_tqdm: 57 | import tqdm 58 | 59 | num_iter = ( 60 | 1 61 | if num_classes_per_batch is None 62 | else ((num_classes - 1) // num_classes_per_batch + 1) 63 | ) 64 | iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) 65 | else: 66 | iter_wrap = iter 67 | 68 | def _process_batch(batch_classnames): 69 | num_batch_classes = len(batch_classnames) 70 | texts = [ 71 | template.format(c) if use_format else template(c) 72 | for c in batch_classnames 73 | for template in templates 74 | ] 75 | inputs = processor(text=texts, return_tensors="pt", padding=True).to( 76 | model.device 77 | ) 78 | class_embeddings = model.get_text_features(**inputs) 79 | class_embeddings = class_embeddings.reshape( 80 | num_batch_classes, num_templates, -1 81 | ).mean(dim=1) 82 | class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) 83 | class_embeddings = class_embeddings.T 84 | return class_embeddings 85 | 86 | with torch.no_grad(): 87 | if num_classes_per_batch: 88 | batched_embeds = [ 89 | _process_batch(batch) 90 | for batch in iter_wrap(batched(classnames, num_classes_per_batch)) 91 | ] 92 | zeroshot_weights = torch.cat(batched_embeds, dim=1) 93 | else: 94 | zeroshot_weights = _process_batch(classnames) 95 | return zeroshot_weights 96 | 97 | 98 | def zero_shot_eval( 99 | model: CLIPModel, processor: CLIPProcessor, imagenet_path: str, batch_size=32 100 | ): 101 | was_training = model.training 102 | model.eval() 103 | print("Starting zero-shot imagenet.") 104 | 105 | ds = torchvision.datasets.ImageFolder(imagenet_path) 106 | 107 | def collate(x): 108 | images, labels = list(zip(*x)) 109 | images = processor(images=images, return_tensors="pt")["pixel_values"] 110 | labels = torch.LongTensor(labels) 111 | return images, labels 112 | 113 | dl = DataLoader( 114 | ds, batch_size=batch_size, shuffle=False, collate_fn=collate, num_workers=2 115 | ) 116 | 117 | print("Building zero-shot classifier") 118 | classifier = build_zero_shot_classifier( 119 | model, 
120 | processor, 121 | classnames=IMAGENET_CLASSNAMES, 122 | templates=OPENAI_IMAGENET_TEMPLATES, 123 | num_classes_per_batch=10, 124 | use_tqdm=True, 125 | ) 126 | 127 | print("Using classifier") 128 | with torch.no_grad(): 129 | top1, top5, n = 0.0, 0.0, 0.0 130 | for images, target in tqdm(dl, unit_scale=batch_size): 131 | images = images.to(device=model.device) 132 | target = target.to(model.device) 133 | 134 | # predict 135 | # inputs = processor(images=images, return_tensors='pt') 136 | image_features = model.get_image_features(pixel_values=images) 137 | logits = model.logit_scale * image_features @ classifier 138 | 139 | # measure accuracy 140 | acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 141 | top1 += acc1 142 | top5 += acc5 143 | n += images.size(0) 144 | 145 | top1 = top1 / n 146 | top5 = top5 / n 147 | print("top1:", top1, "top5:", top5) 148 | 149 | print("Finished zero-shot imagenet.") 150 | 151 | if was_training: 152 | model.train() 153 | 154 | return top1, top5 155 | 156 | 157 | OPENAI_IMAGENET_TEMPLATES = ( 158 | lambda c: f"a bad photo of a {c}.", 159 | lambda c: f"a photo of many {c}.", 160 | lambda c: f"a sculpture of a {c}.", 161 | lambda c: f"a photo of the hard to see {c}.", 162 | lambda c: f"a low resolution photo of the {c}.", 163 | lambda c: f"a rendering of a {c}.", 164 | lambda c: f"graffiti of a {c}.", 165 | lambda c: f"a bad photo of the {c}.", 166 | lambda c: f"a cropped photo of the {c}.", 167 | lambda c: f"a tattoo of a {c}.", 168 | lambda c: f"the embroidered {c}.", 169 | lambda c: f"a photo of a hard to see {c}.", 170 | lambda c: f"a bright photo of a {c}.", 171 | lambda c: f"a photo of a clean {c}.", 172 | lambda c: f"a photo of a dirty {c}.", 173 | lambda c: f"a dark photo of the {c}.", 174 | lambda c: f"a drawing of a {c}.", 175 | lambda c: f"a photo of my {c}.", 176 | lambda c: f"the plastic {c}.", 177 | lambda c: f"a photo of the cool {c}.", 178 | lambda c: f"a close-up photo of a {c}.", 179 | lambda c: f"a black and white photo of the {c}.", 180 | lambda c: f"a painting of the {c}.", 181 | lambda c: f"a painting of a {c}.", 182 | lambda c: f"a pixelated photo of the {c}.", 183 | lambda c: f"a sculpture of the {c}.", 184 | lambda c: f"a bright photo of the {c}.", 185 | lambda c: f"a cropped photo of a {c}.", 186 | lambda c: f"a plastic {c}.", 187 | lambda c: f"a photo of the dirty {c}.", 188 | lambda c: f"a jpeg corrupted photo of a {c}.", 189 | lambda c: f"a blurry photo of the {c}.", 190 | lambda c: f"a photo of the {c}.", 191 | lambda c: f"a good photo of the {c}.", 192 | lambda c: f"a rendering of the {c}.", 193 | lambda c: f"a {c} in a video game.", 194 | lambda c: f"a photo of one {c}.", 195 | lambda c: f"a doodle of a {c}.", 196 | lambda c: f"a close-up photo of the {c}.", 197 | lambda c: f"a photo of a {c}.", 198 | lambda c: f"the origami {c}.", 199 | lambda c: f"the {c} in a video game.", 200 | lambda c: f"a sketch of a {c}.", 201 | lambda c: f"a doodle of the {c}.", 202 | lambda c: f"a origami {c}.", 203 | lambda c: f"a low resolution photo of a {c}.", 204 | lambda c: f"the toy {c}.", 205 | lambda c: f"a rendition of the {c}.", 206 | lambda c: f"a photo of the clean {c}.", 207 | lambda c: f"a photo of a large {c}.", 208 | lambda c: f"a rendition of a {c}.", 209 | lambda c: f"a photo of a nice {c}.", 210 | lambda c: f"a photo of a weird {c}.", 211 | lambda c: f"a blurry photo of a {c}.", 212 | lambda c: f"a cartoon {c}.", 213 | lambda c: f"art of a {c}.", 214 | lambda c: f"a sketch of the {c}.", 215 | lambda c: f"a embroidered {c}.", 
216 | lambda c: f"a pixelated photo of a {c}.", 217 | lambda c: f"itap of the {c}.", 218 | lambda c: f"a jpeg corrupted photo of the {c}.", 219 | lambda c: f"a good photo of a {c}.", 220 | lambda c: f"a plushie {c}.", 221 | lambda c: f"a photo of the nice {c}.", 222 | lambda c: f"a photo of the small {c}.", 223 | lambda c: f"a photo of the weird {c}.", 224 | lambda c: f"the cartoon {c}.", 225 | lambda c: f"art of the {c}.", 226 | lambda c: f"a drawing of the {c}.", 227 | lambda c: f"a photo of the large {c}.", 228 | lambda c: f"a black and white photo of a {c}.", 229 | lambda c: f"the plushie {c}.", 230 | lambda c: f"a dark photo of a {c}.", 231 | lambda c: f"itap of a {c}.", 232 | lambda c: f"graffiti of the {c}.", 233 | lambda c: f"a toy {c}.", 234 | lambda c: f"itap of my {c}.", 235 | lambda c: f"a photo of a cool {c}.", 236 | lambda c: f"a photo of a small {c}.", 237 | lambda c: f"a tattoo of the {c}.", 238 | ) 239 | 240 | 241 | # a much smaller subset of above prompts 242 | # from https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb 243 | SIMPLE_IMAGENET_TEMPLATES = ( 244 | lambda c: f"itap of a {c}.", 245 | lambda c: f"a bad photo of the {c}.", 246 | lambda c: f"a origami {c}.", 247 | lambda c: f"a photo of the large {c}.", 248 | lambda c: f"a {c} in a video game.", 249 | lambda c: f"art of the {c}.", 250 | lambda c: f"a photo of the small {c}.", 251 | ) 252 | 253 | 254 | IMAGENET_CLASSNAMES = ( 255 | "tench", 256 | "goldfish", 257 | "great white shark", 258 | "tiger shark", 259 | "hammerhead shark", 260 | "electric ray", 261 | "stingray", 262 | "rooster", 263 | "hen", 264 | "ostrich", 265 | "brambling", 266 | "goldfinch", 267 | "house finch", 268 | "junco", 269 | "indigo bunting", 270 | "American robin", 271 | "bulbul", 272 | "jay", 273 | "magpie", 274 | "chickadee", 275 | "American dipper", 276 | "kite (bird of prey)", 277 | "bald eagle", 278 | "vulture", 279 | "great grey owl", 280 | "fire salamander", 281 | "smooth newt", 282 | "newt", 283 | "spotted salamander", 284 | "axolotl", 285 | "American bullfrog", 286 | "tree frog", 287 | "tailed frog", 288 | "loggerhead sea turtle", 289 | "leatherback sea turtle", 290 | "mud turtle", 291 | "terrapin", 292 | "box turtle", 293 | "banded gecko", 294 | "green iguana", 295 | "Carolina anole", 296 | "desert grassland whiptail lizard", 297 | "agama", 298 | "frilled-necked lizard", 299 | "alligator lizard", 300 | "Gila monster", 301 | "European green lizard", 302 | "chameleon", 303 | "Komodo dragon", 304 | "Nile crocodile", 305 | "American alligator", 306 | "triceratops", 307 | "worm snake", 308 | "ring-necked snake", 309 | "eastern hog-nosed snake", 310 | "smooth green snake", 311 | "kingsnake", 312 | "garter snake", 313 | "water snake", 314 | "vine snake", 315 | "night snake", 316 | "boa constrictor", 317 | "African rock python", 318 | "Indian cobra", 319 | "green mamba", 320 | "sea snake", 321 | "Saharan horned viper", 322 | "eastern diamondback rattlesnake", 323 | "sidewinder rattlesnake", 324 | "trilobite", 325 | "harvestman", 326 | "scorpion", 327 | "yellow garden spider", 328 | "barn spider", 329 | "European garden spider", 330 | "southern black widow", 331 | "tarantula", 332 | "wolf spider", 333 | "tick", 334 | "centipede", 335 | "black grouse", 336 | "ptarmigan", 337 | "ruffed grouse", 338 | "prairie grouse", 339 | "peafowl", 340 | "quail", 341 | "partridge", 342 | "african grey parrot", 343 | "macaw", 344 | "sulphur-crested cockatoo", 345 | "lorikeet", 346 | "coucal", 347 | "bee eater", 348 | 
"hornbill", 349 | "hummingbird", 350 | "jacamar", 351 | "toucan", 352 | "duck", 353 | "red-breasted merganser", 354 | "goose", 355 | "black swan", 356 | "tusker", 357 | "echidna", 358 | "platypus", 359 | "wallaby", 360 | "koala", 361 | "wombat", 362 | "jellyfish", 363 | "sea anemone", 364 | "brain coral", 365 | "flatworm", 366 | "nematode", 367 | "conch", 368 | "snail", 369 | "slug", 370 | "sea slug", 371 | "chiton", 372 | "chambered nautilus", 373 | "Dungeness crab", 374 | "rock crab", 375 | "fiddler crab", 376 | "red king crab", 377 | "American lobster", 378 | "spiny lobster", 379 | "crayfish", 380 | "hermit crab", 381 | "isopod", 382 | "white stork", 383 | "black stork", 384 | "spoonbill", 385 | "flamingo", 386 | "little blue heron", 387 | "great egret", 388 | "bittern bird", 389 | "crane bird", 390 | "limpkin", 391 | "common gallinule", 392 | "American coot", 393 | "bustard", 394 | "ruddy turnstone", 395 | "dunlin", 396 | "common redshank", 397 | "dowitcher", 398 | "oystercatcher", 399 | "pelican", 400 | "king penguin", 401 | "albatross", 402 | "grey whale", 403 | "killer whale", 404 | "dugong", 405 | "sea lion", 406 | "Chihuahua", 407 | "Japanese Chin", 408 | "Maltese", 409 | "Pekingese", 410 | "Shih Tzu", 411 | "King Charles Spaniel", 412 | "Papillon", 413 | "toy terrier", 414 | "Rhodesian Ridgeback", 415 | "Afghan Hound", 416 | "Basset Hound", 417 | "Beagle", 418 | "Bloodhound", 419 | "Bluetick Coonhound", 420 | "Black and Tan Coonhound", 421 | "Treeing Walker Coonhound", 422 | "English foxhound", 423 | "Redbone Coonhound", 424 | "borzoi", 425 | "Irish Wolfhound", 426 | "Italian Greyhound", 427 | "Whippet", 428 | "Ibizan Hound", 429 | "Norwegian Elkhound", 430 | "Otterhound", 431 | "Saluki", 432 | "Scottish Deerhound", 433 | "Weimaraner", 434 | "Staffordshire Bull Terrier", 435 | "American Staffordshire Terrier", 436 | "Bedlington Terrier", 437 | "Border Terrier", 438 | "Kerry Blue Terrier", 439 | "Irish Terrier", 440 | "Norfolk Terrier", 441 | "Norwich Terrier", 442 | "Yorkshire Terrier", 443 | "Wire Fox Terrier", 444 | "Lakeland Terrier", 445 | "Sealyham Terrier", 446 | "Airedale Terrier", 447 | "Cairn Terrier", 448 | "Australian Terrier", 449 | "Dandie Dinmont Terrier", 450 | "Boston Terrier", 451 | "Miniature Schnauzer", 452 | "Giant Schnauzer", 453 | "Standard Schnauzer", 454 | "Scottish Terrier", 455 | "Tibetan Terrier", 456 | "Australian Silky Terrier", 457 | "Soft-coated Wheaten Terrier", 458 | "West Highland White Terrier", 459 | "Lhasa Apso", 460 | "Flat-Coated Retriever", 461 | "Curly-coated Retriever", 462 | "Golden Retriever", 463 | "Labrador Retriever", 464 | "Chesapeake Bay Retriever", 465 | "German Shorthaired Pointer", 466 | "Vizsla", 467 | "English Setter", 468 | "Irish Setter", 469 | "Gordon Setter", 470 | "Brittany dog", 471 | "Clumber Spaniel", 472 | "English Springer Spaniel", 473 | "Welsh Springer Spaniel", 474 | "Cocker Spaniel", 475 | "Sussex Spaniel", 476 | "Irish Water Spaniel", 477 | "Kuvasz", 478 | "Schipperke", 479 | "Groenendael dog", 480 | "Malinois", 481 | "Briard", 482 | "Australian Kelpie", 483 | "Komondor", 484 | "Old English Sheepdog", 485 | "Shetland Sheepdog", 486 | "collie", 487 | "Border Collie", 488 | "Bouvier des Flandres dog", 489 | "Rottweiler", 490 | "German Shepherd Dog", 491 | "Dobermann", 492 | "Miniature Pinscher", 493 | "Greater Swiss Mountain Dog", 494 | "Bernese Mountain Dog", 495 | "Appenzeller Sennenhund", 496 | "Entlebucher Sennenhund", 497 | "Boxer", 498 | "Bullmastiff", 499 | "Tibetan Mastiff", 500 | "French Bulldog", 501 | 
"Great Dane", 502 | "St. Bernard", 503 | "husky", 504 | "Alaskan Malamute", 505 | "Siberian Husky", 506 | "Dalmatian", 507 | "Affenpinscher", 508 | "Basenji", 509 | "pug", 510 | "Leonberger", 511 | "Newfoundland dog", 512 | "Great Pyrenees dog", 513 | "Samoyed", 514 | "Pomeranian", 515 | "Chow Chow", 516 | "Keeshond", 517 | "brussels griffon", 518 | "Pembroke Welsh Corgi", 519 | "Cardigan Welsh Corgi", 520 | "Toy Poodle", 521 | "Miniature Poodle", 522 | "Standard Poodle", 523 | "Mexican hairless dog (xoloitzcuintli)", 524 | "grey wolf", 525 | "Alaskan tundra wolf", 526 | "red wolf or maned wolf", 527 | "coyote", 528 | "dingo", 529 | "dhole", 530 | "African wild dog", 531 | "hyena", 532 | "red fox", 533 | "kit fox", 534 | "Arctic fox", 535 | "grey fox", 536 | "tabby cat", 537 | "tiger cat", 538 | "Persian cat", 539 | "Siamese cat", 540 | "Egyptian Mau", 541 | "cougar", 542 | "lynx", 543 | "leopard", 544 | "snow leopard", 545 | "jaguar", 546 | "lion", 547 | "tiger", 548 | "cheetah", 549 | "brown bear", 550 | "American black bear", 551 | "polar bear", 552 | "sloth bear", 553 | "mongoose", 554 | "meerkat", 555 | "tiger beetle", 556 | "ladybug", 557 | "ground beetle", 558 | "longhorn beetle", 559 | "leaf beetle", 560 | "dung beetle", 561 | "rhinoceros beetle", 562 | "weevil", 563 | "fly", 564 | "bee", 565 | "ant", 566 | "grasshopper", 567 | "cricket insect", 568 | "stick insect", 569 | "cockroach", 570 | "praying mantis", 571 | "cicada", 572 | "leafhopper", 573 | "lacewing", 574 | "dragonfly", 575 | "damselfly", 576 | "red admiral butterfly", 577 | "ringlet butterfly", 578 | "monarch butterfly", 579 | "small white butterfly", 580 | "sulphur butterfly", 581 | "gossamer-winged butterfly", 582 | "starfish", 583 | "sea urchin", 584 | "sea cucumber", 585 | "cottontail rabbit", 586 | "hare", 587 | "Angora rabbit", 588 | "hamster", 589 | "porcupine", 590 | "fox squirrel", 591 | "marmot", 592 | "beaver", 593 | "guinea pig", 594 | "common sorrel horse", 595 | "zebra", 596 | "pig", 597 | "wild boar", 598 | "warthog", 599 | "hippopotamus", 600 | "ox", 601 | "water buffalo", 602 | "bison", 603 | "ram (adult male sheep)", 604 | "bighorn sheep", 605 | "Alpine ibex", 606 | "hartebeest", 607 | "impala (antelope)", 608 | "gazelle", 609 | "arabian camel", 610 | "llama", 611 | "weasel", 612 | "mink", 613 | "European polecat", 614 | "black-footed ferret", 615 | "otter", 616 | "skunk", 617 | "badger", 618 | "armadillo", 619 | "three-toed sloth", 620 | "orangutan", 621 | "gorilla", 622 | "chimpanzee", 623 | "gibbon", 624 | "siamang", 625 | "guenon", 626 | "patas monkey", 627 | "baboon", 628 | "macaque", 629 | "langur", 630 | "black-and-white colobus", 631 | "proboscis monkey", 632 | "marmoset", 633 | "white-headed capuchin", 634 | "howler monkey", 635 | "titi monkey", 636 | "Geoffroy's spider monkey", 637 | "common squirrel monkey", 638 | "ring-tailed lemur", 639 | "indri", 640 | "Asian elephant", 641 | "African bush elephant", 642 | "red panda", 643 | "giant panda", 644 | "snoek fish", 645 | "eel", 646 | "silver salmon", 647 | "rock beauty fish", 648 | "clownfish", 649 | "sturgeon", 650 | "gar fish", 651 | "lionfish", 652 | "pufferfish", 653 | "abacus", 654 | "abaya", 655 | "academic gown", 656 | "accordion", 657 | "acoustic guitar", 658 | "aircraft carrier", 659 | "airliner", 660 | "airship", 661 | "altar", 662 | "ambulance", 663 | "amphibious vehicle", 664 | "analog clock", 665 | "apiary", 666 | "apron", 667 | "trash can", 668 | "assault rifle", 669 | "backpack", 670 | "bakery", 671 | "balance beam", 672 | 
"balloon", 673 | "ballpoint pen", 674 | "Band-Aid", 675 | "banjo", 676 | "baluster / handrail", 677 | "barbell", 678 | "barber chair", 679 | "barbershop", 680 | "barn", 681 | "barometer", 682 | "barrel", 683 | "wheelbarrow", 684 | "baseball", 685 | "basketball", 686 | "bassinet", 687 | "bassoon", 688 | "swimming cap", 689 | "bath towel", 690 | "bathtub", 691 | "station wagon", 692 | "lighthouse", 693 | "beaker", 694 | "military hat (bearskin or shako)", 695 | "beer bottle", 696 | "beer glass", 697 | "bell tower", 698 | "baby bib", 699 | "tandem bicycle", 700 | "bikini", 701 | "ring binder", 702 | "binoculars", 703 | "birdhouse", 704 | "boathouse", 705 | "bobsleigh", 706 | "bolo tie", 707 | "poke bonnet", 708 | "bookcase", 709 | "bookstore", 710 | "bottle cap", 711 | "hunting bow", 712 | "bow tie", 713 | "brass memorial plaque", 714 | "bra", 715 | "breakwater", 716 | "breastplate", 717 | "broom", 718 | "bucket", 719 | "buckle", 720 | "bulletproof vest", 721 | "high-speed train", 722 | "butcher shop", 723 | "taxicab", 724 | "cauldron", 725 | "candle", 726 | "cannon", 727 | "canoe", 728 | "can opener", 729 | "cardigan", 730 | "car mirror", 731 | "carousel", 732 | "tool kit", 733 | "cardboard box / carton", 734 | "car wheel", 735 | "automated teller machine", 736 | "cassette", 737 | "cassette player", 738 | "castle", 739 | "catamaran", 740 | "CD player", 741 | "cello", 742 | "mobile phone", 743 | "chain", 744 | "chain-link fence", 745 | "chain mail", 746 | "chainsaw", 747 | "storage chest", 748 | "chiffonier", 749 | "bell or wind chime", 750 | "china cabinet", 751 | "Christmas stocking", 752 | "church", 753 | "movie theater", 754 | "cleaver", 755 | "cliff dwelling", 756 | "cloak", 757 | "clogs", 758 | "cocktail shaker", 759 | "coffee mug", 760 | "coffeemaker", 761 | "spiral or coil", 762 | "combination lock", 763 | "computer keyboard", 764 | "candy store", 765 | "container ship", 766 | "convertible", 767 | "corkscrew", 768 | "cornet", 769 | "cowboy boot", 770 | "cowboy hat", 771 | "cradle", 772 | "construction crane", 773 | "crash helmet", 774 | "crate", 775 | "infant bed", 776 | "Crock Pot", 777 | "croquet ball", 778 | "crutch", 779 | "cuirass", 780 | "dam", 781 | "desk", 782 | "desktop computer", 783 | "rotary dial telephone", 784 | "diaper", 785 | "digital clock", 786 | "digital watch", 787 | "dining table", 788 | "dishcloth", 789 | "dishwasher", 790 | "disc brake", 791 | "dock", 792 | "dog sled", 793 | "dome", 794 | "doormat", 795 | "drilling rig", 796 | "drum", 797 | "drumstick", 798 | "dumbbell", 799 | "Dutch oven", 800 | "electric fan", 801 | "electric guitar", 802 | "electric locomotive", 803 | "entertainment center", 804 | "envelope", 805 | "espresso machine", 806 | "face powder", 807 | "feather boa", 808 | "filing cabinet", 809 | "fireboat", 810 | "fire truck", 811 | "fire screen", 812 | "flagpole", 813 | "flute", 814 | "folding chair", 815 | "football helmet", 816 | "forklift", 817 | "fountain", 818 | "fountain pen", 819 | "four-poster bed", 820 | "freight car", 821 | "French horn", 822 | "frying pan", 823 | "fur coat", 824 | "garbage truck", 825 | "gas mask or respirator", 826 | "gas pump", 827 | "goblet", 828 | "go-kart", 829 | "golf ball", 830 | "golf cart", 831 | "gondola", 832 | "gong", 833 | "gown", 834 | "grand piano", 835 | "greenhouse", 836 | "radiator grille", 837 | "grocery store", 838 | "guillotine", 839 | "hair clip", 840 | "hair spray", 841 | "half-track", 842 | "hammer", 843 | "hamper", 844 | "hair dryer", 845 | "hand-held computer", 846 | "handkerchief", 847 | "hard 
disk drive", 848 | "harmonica", 849 | "harp", 850 | "combine harvester", 851 | "hatchet", 852 | "holster", 853 | "home theater", 854 | "honeycomb", 855 | "hook", 856 | "hoop skirt", 857 | "gymnastic horizontal bar", 858 | "horse-drawn vehicle", 859 | "hourglass", 860 | "iPod", 861 | "clothes iron", 862 | "carved pumpkin", 863 | "jeans", 864 | "jeep", 865 | "T-shirt", 866 | "jigsaw puzzle", 867 | "rickshaw", 868 | "joystick", 869 | "kimono", 870 | "knee pad", 871 | "knot", 872 | "lab coat", 873 | "ladle", 874 | "lampshade", 875 | "laptop computer", 876 | "lawn mower", 877 | "lens cap", 878 | "letter opener", 879 | "library", 880 | "lifeboat", 881 | "lighter", 882 | "limousine", 883 | "ocean liner", 884 | "lipstick", 885 | "slip-on shoe", 886 | "lotion", 887 | "music speaker", 888 | "loupe magnifying glass", 889 | "sawmill", 890 | "magnetic compass", 891 | "messenger bag", 892 | "mailbox", 893 | "tights", 894 | "one-piece bathing suit", 895 | "manhole cover", 896 | "maraca", 897 | "marimba", 898 | "mask", 899 | "matchstick", 900 | "maypole", 901 | "maze", 902 | "measuring cup", 903 | "medicine cabinet", 904 | "megalith", 905 | "microphone", 906 | "microwave oven", 907 | "military uniform", 908 | "milk can", 909 | "minibus", 910 | "miniskirt", 911 | "minivan", 912 | "missile", 913 | "mitten", 914 | "mixing bowl", 915 | "mobile home", 916 | "ford model t", 917 | "modem", 918 | "monastery", 919 | "monitor", 920 | "moped", 921 | "mortar and pestle", 922 | "graduation cap", 923 | "mosque", 924 | "mosquito net", 925 | "vespa", 926 | "mountain bike", 927 | "tent", 928 | "computer mouse", 929 | "mousetrap", 930 | "moving van", 931 | "muzzle", 932 | "metal nail", 933 | "neck brace", 934 | "necklace", 935 | "baby pacifier", 936 | "notebook computer", 937 | "obelisk", 938 | "oboe", 939 | "ocarina", 940 | "odometer", 941 | "oil filter", 942 | "pipe organ", 943 | "oscilloscope", 944 | "overskirt", 945 | "bullock cart", 946 | "oxygen mask", 947 | "product packet / packaging", 948 | "paddle", 949 | "paddle wheel", 950 | "padlock", 951 | "paintbrush", 952 | "pajamas", 953 | "palace", 954 | "pan flute", 955 | "paper towel", 956 | "parachute", 957 | "parallel bars", 958 | "park bench", 959 | "parking meter", 960 | "railroad car", 961 | "patio", 962 | "payphone", 963 | "pedestal", 964 | "pencil case", 965 | "pencil sharpener", 966 | "perfume", 967 | "Petri dish", 968 | "photocopier", 969 | "plectrum", 970 | "Pickelhaube", 971 | "picket fence", 972 | "pickup truck", 973 | "pier", 974 | "piggy bank", 975 | "pill bottle", 976 | "pillow", 977 | "ping-pong ball", 978 | "pinwheel", 979 | "pirate ship", 980 | "drink pitcher", 981 | "block plane", 982 | "planetarium", 983 | "plastic bag", 984 | "plate rack", 985 | "farm plow", 986 | "plunger", 987 | "Polaroid camera", 988 | "pole", 989 | "police van", 990 | "poncho", 991 | "pool table", 992 | "soda bottle", 993 | "plant pot", 994 | "potter's wheel", 995 | "power drill", 996 | "prayer rug", 997 | "printer", 998 | "prison", 999 | "missile", 1000 | "projector", 1001 | "hockey puck", 1002 | "punching bag", 1003 | "purse", 1004 | "quill", 1005 | "quilt", 1006 | "race car", 1007 | "racket", 1008 | "radiator", 1009 | "radio", 1010 | "radio telescope", 1011 | "rain barrel", 1012 | "recreational vehicle", 1013 | "fishing casting reel", 1014 | "reflex camera", 1015 | "refrigerator", 1016 | "remote control", 1017 | "restaurant", 1018 | "revolver", 1019 | "rifle", 1020 | "rocking chair", 1021 | "rotisserie", 1022 | "eraser", 1023 | "rugby ball", 1024 | "ruler measuring stick", 
1025 | "sneaker", 1026 | "safe", 1027 | "safety pin", 1028 | "salt shaker", 1029 | "sandal", 1030 | "sarong", 1031 | "saxophone", 1032 | "scabbard", 1033 | "weighing scale", 1034 | "school bus", 1035 | "schooner", 1036 | "scoreboard", 1037 | "CRT monitor", 1038 | "screw", 1039 | "screwdriver", 1040 | "seat belt", 1041 | "sewing machine", 1042 | "shield", 1043 | "shoe store", 1044 | "shoji screen / room divider", 1045 | "shopping basket", 1046 | "shopping cart", 1047 | "shovel", 1048 | "shower cap", 1049 | "shower curtain", 1050 | "ski", 1051 | "balaclava ski mask", 1052 | "sleeping bag", 1053 | "slide rule", 1054 | "sliding door", 1055 | "slot machine", 1056 | "snorkel", 1057 | "snowmobile", 1058 | "snowplow", 1059 | "soap dispenser", 1060 | "soccer ball", 1061 | "sock", 1062 | "solar thermal collector", 1063 | "sombrero", 1064 | "soup bowl", 1065 | "keyboard space bar", 1066 | "space heater", 1067 | "space shuttle", 1068 | "spatula", 1069 | "motorboat", 1070 | "spider web", 1071 | "spindle", 1072 | "sports car", 1073 | "spotlight", 1074 | "stage", 1075 | "steam locomotive", 1076 | "through arch bridge", 1077 | "steel drum", 1078 | "stethoscope", 1079 | "scarf", 1080 | "stone wall", 1081 | "stopwatch", 1082 | "stove", 1083 | "strainer", 1084 | "tram", 1085 | "stretcher", 1086 | "couch", 1087 | "stupa", 1088 | "submarine", 1089 | "suit", 1090 | "sundial", 1091 | "sunglasses", 1092 | "sunglasses", 1093 | "sunscreen", 1094 | "suspension bridge", 1095 | "mop", 1096 | "sweatshirt", 1097 | "swim trunks / shorts", 1098 | "swing", 1099 | "electrical switch", 1100 | "syringe", 1101 | "table lamp", 1102 | "tank", 1103 | "tape player", 1104 | "teapot", 1105 | "teddy bear", 1106 | "television", 1107 | "tennis ball", 1108 | "thatched roof", 1109 | "front curtain", 1110 | "thimble", 1111 | "threshing machine", 1112 | "throne", 1113 | "tile roof", 1114 | "toaster", 1115 | "tobacco shop", 1116 | "toilet seat", 1117 | "torch", 1118 | "totem pole", 1119 | "tow truck", 1120 | "toy store", 1121 | "tractor", 1122 | "semi-trailer truck", 1123 | "tray", 1124 | "trench coat", 1125 | "tricycle", 1126 | "trimaran", 1127 | "tripod", 1128 | "triumphal arch", 1129 | "trolleybus", 1130 | "trombone", 1131 | "hot tub", 1132 | "turnstile", 1133 | "typewriter keyboard", 1134 | "umbrella", 1135 | "unicycle", 1136 | "upright piano", 1137 | "vacuum cleaner", 1138 | "vase", 1139 | "vaulted or arched ceiling", 1140 | "velvet fabric", 1141 | "vending machine", 1142 | "vestment", 1143 | "viaduct", 1144 | "violin", 1145 | "volleyball", 1146 | "waffle iron", 1147 | "wall clock", 1148 | "wallet", 1149 | "wardrobe", 1150 | "military aircraft", 1151 | "sink", 1152 | "washing machine", 1153 | "water bottle", 1154 | "water jug", 1155 | "water tower", 1156 | "whiskey jug", 1157 | "whistle", 1158 | "hair wig", 1159 | "window screen", 1160 | "window shade", 1161 | "Windsor tie", 1162 | "wine bottle", 1163 | "airplane wing", 1164 | "wok", 1165 | "wooden spoon", 1166 | "wool", 1167 | "split-rail fence", 1168 | "shipwreck", 1169 | "sailboat", 1170 | "yurt", 1171 | "website", 1172 | "comic book", 1173 | "crossword", 1174 | "traffic or street sign", 1175 | "traffic light", 1176 | "dust jacket", 1177 | "menu", 1178 | "plate", 1179 | "guacamole", 1180 | "consomme", 1181 | "hot pot", 1182 | "trifle", 1183 | "ice cream", 1184 | "popsicle", 1185 | "baguette", 1186 | "bagel", 1187 | "pretzel", 1188 | "cheeseburger", 1189 | "hot dog", 1190 | "mashed potatoes", 1191 | "cabbage", 1192 | "broccoli", 1193 | "cauliflower", 1194 | "zucchini", 1195 | 
"spaghetti squash", 1196 | "acorn squash", 1197 | "butternut squash", 1198 | "cucumber", 1199 | "artichoke", 1200 | "bell pepper", 1201 | "cardoon", 1202 | "mushroom", 1203 | "Granny Smith apple", 1204 | "strawberry", 1205 | "orange", 1206 | "lemon", 1207 | "fig", 1208 | "pineapple", 1209 | "banana", 1210 | "jackfruit", 1211 | "cherimoya (custard apple)", 1212 | "pomegranate", 1213 | "hay", 1214 | "carbonara", 1215 | "chocolate syrup", 1216 | "dough", 1217 | "meatloaf", 1218 | "pizza", 1219 | "pot pie", 1220 | "burrito", 1221 | "red wine", 1222 | "espresso", 1223 | "tea cup", 1224 | "eggnog", 1225 | "mountain", 1226 | "bubble", 1227 | "cliff", 1228 | "coral reef", 1229 | "geyser", 1230 | "lakeshore", 1231 | "promontory", 1232 | "sandbar", 1233 | "beach", 1234 | "valley", 1235 | "volcano", 1236 | "baseball player", 1237 | "bridegroom", 1238 | "scuba diver", 1239 | "rapeseed", 1240 | "daisy", 1241 | "yellow lady's slipper", 1242 | "corn", 1243 | "acorn", 1244 | "rose hip", 1245 | "horse chestnut seed", 1246 | "coral fungus", 1247 | "agaric", 1248 | "gyromitra", 1249 | "stinkhorn mushroom", 1250 | "earth star fungus", 1251 | "hen of the woods mushroom", 1252 | "bolete", 1253 | "corn cob", 1254 | "toilet paper", 1255 | ) 1256 | -------------------------------------------------------------------------------- /vq_clip/modeling_vq_adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.optim import SGD, Adagrad 4 | 5 | from transformers import PreTrainedModel, PretrainedConfig 6 | from .modules import Block 7 | from vector_quantize_pytorch import VectorQuantize, ResidualVQ 8 | from .perplexity import calculate_perplexity 9 | 10 | 11 | class VQAdapterConfig(PretrainedConfig): 12 | def __init__( 13 | self, 14 | # vq args 15 | # All of these have 'vq_' appended to the start 16 | vq_codebook_size: int = 32, 17 | vq_codebook_dim: int = 32, 18 | vq_heads: int = 32, 19 | vq_separate_codebook_per_head: bool = True, 20 | vq_decay: float = 0.85, 21 | vq_eps: float = 1e-5, 22 | vq_kmeans_init: bool = False, 23 | vq_kmeans_iters: int = 20, 24 | vq_sync_kmeans: bool = True, 25 | vq_use_cosine_sim: bool = False, 26 | vq_threshold_ema_dead_code: int = 0, 27 | vq_channel_last: bool = True, 28 | vq_accept_image_fmap: bool = False, 29 | vq_commitment_weight: float = 1.0, 30 | vq_commitment_use_cross_entropy_loss: bool = False, 31 | vq_orthogonal_reg_weight: float = 0.0, 32 | vq_orthogonal_reg_active_codes_only: bool = False, 33 | vq_orthogonal_reg_max_codes: bool = None, 34 | vq_stochastic_sample_codes: bool = True, 35 | vq_sample_codebook_temp: float = 1.0, 36 | vq_straight_through: bool = False, 37 | vq_reinmax: bool = False, 38 | # using reinmax for improved straight-through, assuming straight through helps at all 39 | vq_sync_codebook: bool = False, 40 | vq_sync_affine_param: bool = False, 41 | vq_ema_update: bool = True, 42 | vq_learnable_codebook: bool = False, 43 | vq_affine_param: bool = False, 44 | vq_affine_param_batch_decay: float = 0.99, 45 | vq_affine_param_codebook_decay: float = 0.9, 46 | # the v that controls optimistic vs pessimistic update for synchronous update rule (21) https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf 47 | vq_sync_update_v: float = 0.0, 48 | 49 | # codebook optimizer 50 | codebook_lr: float = 10., 51 | 52 | # rq_specific args 53 | rq_quantize_dropout=False, 54 | rq_quantize_dropout_cutoff_index=0, 55 | rq_quantize_dropout_multiple_of=1, 56 | 57 | # nn args 58 | is_rq: bool = 
True, 59 | mlp_dim: int = 1028, 60 | mlp_hidden_dim: int = 512, 61 | mlp_layers: int = 1, 62 | # Default clip dim for L models 63 | clip_dim: int = 768, 64 | **kwargs 65 | ): 66 | super().__init__(**kwargs) 67 | 68 | self.vq_codebook_size = vq_codebook_size 69 | self.vq_codebook_dim = vq_codebook_dim 70 | self.vq_heads = vq_heads 71 | self.vq_separate_codebook_per_head = vq_separate_codebook_per_head 72 | self.vq_decay = vq_decay 73 | self.vq_eps = vq_eps 74 | self.vq_kmeans_init = vq_kmeans_init 75 | self.vq_kmeans_iters = vq_kmeans_iters 76 | self.vq_sync_kmeans = vq_sync_kmeans 77 | self.vq_use_cosine_sim = vq_use_cosine_sim 78 | self.vq_threshold_ema_dead_code = vq_threshold_ema_dead_code 79 | self.vq_channel_last = vq_channel_last 80 | self.vq_accept_image_fmap = vq_accept_image_fmap 81 | self.vq_commitment_weight = vq_commitment_weight 82 | self.vq_commitment_use_cross_entropy_loss = vq_commitment_use_cross_entropy_loss 83 | self.vq_orthogonal_reg_weight = vq_orthogonal_reg_weight 84 | self.vq_orthogonal_reg_active_codes_only = vq_orthogonal_reg_active_codes_only 85 | self.vq_orthogonal_reg_max_codes = vq_orthogonal_reg_max_codes 86 | self.vq_stochastic_sample_codes = vq_stochastic_sample_codes 87 | self.vq_sample_codebook_temp = vq_sample_codebook_temp 88 | self.vq_straight_through = vq_straight_through 89 | self.vq_reinmax = vq_reinmax 90 | self.vq_sync_codebook = vq_sync_codebook 91 | self.vq_sync_affine_param = vq_sync_affine_param 92 | self.vq_ema_update = vq_ema_update 93 | self.vq_learnable_codebook = vq_learnable_codebook 94 | self.vq_affine_param = vq_affine_param 95 | self.vq_affine_param_batch_decay = vq_affine_param_batch_decay 96 | self.vq_affine_param_codebook_decay = vq_affine_param_codebook_decay 97 | self.vq_sync_update_v = vq_sync_update_v 98 | 99 | self.codebook_lr=codebook_lr 100 | 101 | self.rq_quantize_dropout=rq_quantize_dropout 102 | self.rq_quantize_dropout_cutoff_index=rq_quantize_dropout_cutoff_index 103 | self.rq_quantize_dropout_multiple_of=rq_quantize_dropout_multiple_of 104 | 105 | self.is_rq = is_rq 106 | self.mlp_dim = mlp_dim 107 | self.mlp_hidden_dim = mlp_hidden_dim 108 | self.mlp_layers = mlp_layers 109 | self.clip_dim = clip_dim 110 | 111 | 112 | class VQAdapterModel(PreTrainedModel): 113 | config_class = VQAdapterConfig 114 | 115 | def __init__(self, config: VQAdapterConfig): 116 | super().__init__(config) 117 | 118 | quantizer_args = { 119 | k.removeprefix("vq_"): v 120 | for k, v in config.to_dict().items() 121 | if k.startswith("vq_") 122 | } 123 | 124 | #if quantizer_args['learnable_codebook']: 125 | #quantizer_args['in_place_codebook_optimizer'] = lambda *args, **kwargs: Adagrad(*args, lr=config.codebook_lr, **kwargs) 126 | 127 | quantizer_args["dim"] = config.clip_dim 128 | if config.is_rq: 129 | rq_args = { 130 | k.removeprefix("rq_"): v 131 | for k, v in config.to_dict().items() 132 | if k.startswith("rq_") 133 | } 134 | quantizer_args.update(rq_args) 135 | quantizer_args["heads"] = 1 136 | quantizer_args["num_quantizers"] = config.vq_heads 137 | self.vq = ResidualVQ(**quantizer_args) 138 | else: 139 | self.vq = VectorQuantize(**quantizer_args) 140 | 141 | self.in_feature_net = nn.Sequential( 142 | # input is assumed to an already normalized clip embedding 143 | nn.Linear(config.clip_dim, config.mlp_dim, bias=False), 144 | nn.GELU(), 145 | nn.LayerNorm(config.mlp_dim), 146 | *[ 147 | Block(config.mlp_dim, config.mlp_hidden_dim) 148 | for _ in range(config.mlp_layers) 149 | ], 150 | nn.Linear(config.mlp_dim, config.clip_dim, bias=False), 
151 | # normalize before passing to VQ? 152 | # nn.GELU(), 153 | # nn.LayerNorm(args.clip_dim), 154 | ) 155 | 156 | self.out_feature_net = nn.Identity() 157 | 158 | def decode(self, codes: torch.LongTensor): 159 | z = self.vq.get_codes_from_indices(codes) 160 | z = self.vq.project_out(z) 161 | return z 162 | 163 | def _init_weights(self, _): 164 | pass 165 | 166 | def forward(self, z, return_perplexity=False): 167 | """ 168 | z: B by D 169 | """ 170 | z = self.in_feature_net(z) 171 | z, codes, loss = self.vq(z.unsqueeze(1)) 172 | loss = loss.mean() 173 | z = z.squeeze(1) 174 | codes = codes.squeeze(1) 175 | if return_perplexity: 176 | perplexity = calculate_perplexity(codes, self.config.vq_codebook_size) 177 | else: 178 | perplexity = None 179 | z = self.out_feature_net(z) 180 | 181 | return dict(z=z, codes=codes, perplexity=perplexity, loss=loss) 182 | -------------------------------------------------------------------------------- /vq_clip/modeling_vq_clip.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | from typing import Optional, Union, Tuple 3 | from dataclasses import dataclass 4 | import torch 5 | from transformers.models.clip.modeling_clip import CLIPOutput, clip_loss 6 | from transformers import CLIPConfig, CLIPModel, PreTrainedModel, PretrainedConfig 7 | 8 | from .modeling_vq_adapter import VQAdapterModel, VQAdapterConfig 9 | 10 | 11 | @dataclass 12 | class VQCLIPOutput(CLIPOutput): 13 | text_codes: torch.LongTensor = None 14 | image_codes: torch.LongTensor = None 15 | quantization_loss: torch.FloatTensor = None 16 | contrastive_loss: torch.FloatTensor = None 17 | perplexity: torch.FloatTensor = None 18 | 19 | 20 | class VQCLIPConfig(PretrainedConfig): 21 | model_type = "VQCLIP" 22 | 23 | def __init__( 24 | self, 25 | clip_config_dict: dict = CLIPConfig().to_dict(), 26 | vision_vq_adapter_config_dict: Optional[dict] = VQAdapterConfig().to_dict(), 27 | text_vq_adapter_config_dict: Optional[dict] = None, 28 | **kwargs, 29 | ): 30 | self.clip_config_dict = clip_config_dict 31 | self.vision_vq_adapter_config_dict = vision_vq_adapter_config_dict 32 | self.text_vq_adapter_config_dict = text_vq_adapter_config_dict 33 | super().__init__(**kwargs) 34 | 35 | 36 | class VQCLIPModel(PreTrainedModel): 37 | config_class = VQCLIPConfig 38 | 39 | def __init__(self, config: VQCLIPConfig): 40 | super().__init__(config) 41 | 42 | self.clip_config = CLIPConfig.from_dict(config.clip_config_dict) 43 | self.clip_model = CLIPModel(self.clip_config) 44 | 45 | self.vision_vq_adapter, self.text_vq_adapter = None, None 46 | 47 | if config.vision_vq_adapter_config_dict: 48 | self.vision_vq_adapter = VQAdapterModel( 49 | VQAdapterConfig.from_dict(config.vision_vq_adapter_config_dict) 50 | ) 51 | if config.text_vq_adapter_config_dict: 52 | self.text_vq_adapter = VQAdapterModel( 53 | VQAdapterConfig.from_dict(config.text_vq_adapter_config_dict) 54 | ) 55 | 56 | def _init_weights(self, module): 57 | pass 58 | 59 | def get_text_features( 60 | self, 61 | input_ids: Optional[torch.Tensor] = None, 62 | attention_mask: Optional[torch.Tensor] = None, 63 | position_ids: Optional[torch.Tensor] = None, 64 | output_attentions: Optional[bool] = None, 65 | output_hidden_states: Optional[bool] = None, 66 | return_dict: Optional[bool] = None, 67 | return_codes: Optional[bool] = None, 68 | ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, Union[torch.FloatTensor, None]]]: 69 | # Use CLIP model's config for some fields (if specified) instead of those of vision 
& text components. 70 | output_attentions = ( 71 | output_attentions 72 | if output_attentions is not None 73 | else self.clip_config.output_attentions 74 | ) 75 | output_hidden_states = ( 76 | output_hidden_states 77 | if output_hidden_states is not None 78 | else self.clip_config.output_hidden_states 79 | ) 80 | return_dict = ( 81 | return_dict if return_dict is not None else self.clip_config.use_return_dict 82 | ) 83 | 84 | text_outputs = self.clip_model.text_model( 85 | input_ids=input_ids, 86 | attention_mask=attention_mask, 87 | position_ids=position_ids, 88 | output_attentions=output_attentions, 89 | output_hidden_states=output_hidden_states, 90 | return_dict=return_dict, 91 | ) 92 | 93 | pooled_output = text_outputs[1] 94 | text_features = self.clip_model.text_projection(pooled_output) 95 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 96 | 97 | if self.text_vq_adapter: 98 | _res = self.text_vq_adapter(text_features) 99 | text_features = _res['z'] 100 | text_codes = _res['codes'] 101 | else: 102 | text_codes = None 103 | 104 | if return_codes: 105 | return text_features, text_codes 106 | 107 | return text_features 108 | 109 | def get_image_features( 110 | self, 111 | pixel_values: Optional[torch.FloatTensor] = None, 112 | output_attentions: Optional[bool] = None, 113 | output_hidden_states: Optional[bool] = None, 114 | return_dict: Optional[bool] = None, 115 | return_codes: Optional[bool] = None, 116 | ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, Union[torch.FloatTensor, None]]]: 117 | # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. 118 | output_attentions = ( 119 | output_attentions 120 | if output_attentions is not None 121 | else self.clip_config.output_attentions 122 | ) 123 | output_hidden_states = ( 124 | output_hidden_states 125 | if output_hidden_states is not None 126 | else self.clip_config.output_hidden_states 127 | ) 128 | return_dict = ( 129 | return_dict if return_dict is not None else self.clip_config.use_return_dict 130 | ) 131 | 132 | vision_outputs = self.clip_model.vision_model( 133 | pixel_values=pixel_values, 134 | output_attentions=output_attentions, 135 | output_hidden_states=output_hidden_states, 136 | return_dict=return_dict, 137 | ) 138 | 139 | pooled_output = vision_outputs[1] # pooled_output 140 | image_features = self.clip_model.visual_projection(pooled_output) 141 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 142 | if self.vision_vq_adapter: 143 | _res = self.vision_vq_adapter(image_features) 144 | image_features = _res['z'] 145 | image_codes = _res['codes'] 146 | else: 147 | image_codes = None 148 | 149 | if return_codes: 150 | return image_features, image_codes 151 | 152 | return image_features 153 | 154 | @property 155 | def logit_scale(self): 156 | return self.clip_model.logit_scale 157 | 158 | def forward( 159 | self, 160 | input_ids: Optional[torch.LongTensor] = None, 161 | pixel_values: Optional[torch.FloatTensor] = None, 162 | attention_mask: Optional[torch.Tensor] = None, 163 | position_ids: Optional[torch.LongTensor] = None, 164 | return_loss: Optional[bool] = None, 165 | output_attentions: Optional[bool] = None, 166 | output_hidden_states: Optional[bool] = None, 167 | return_dict: Optional[bool] = None, 168 | return_perplexity: Optional[bool]=None, 169 | ) -> Union[Tuple, VQCLIPOutput]: 170 | # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. 
171 | output_attentions = ( 172 | output_attentions 173 | if output_attentions is not None 174 | else self.clip_config.output_attentions 175 | ) 176 | output_hidden_states = ( 177 | output_hidden_states 178 | if output_hidden_states is not None 179 | else self.clip_config.output_hidden_states 180 | ) 181 | return_dict = ( 182 | return_dict if return_dict is not None else self.clip_config.use_return_dict 183 | ) 184 | 185 | vision_outputs = self.clip_model.vision_model( 186 | pixel_values=pixel_values, 187 | output_attentions=output_attentions, 188 | output_hidden_states=output_hidden_states, 189 | return_dict=return_dict, 190 | ) 191 | 192 | text_outputs = self.clip_model.text_model( 193 | input_ids=input_ids, 194 | attention_mask=attention_mask, 195 | position_ids=position_ids, 196 | output_attentions=output_attentions, 197 | output_hidden_states=output_hidden_states, 198 | return_dict=return_dict, 199 | ) 200 | 201 | image_embeds = vision_outputs[1] 202 | image_embeds = self.clip_model.visual_projection(image_embeds) 203 | 204 | text_embeds = text_outputs[1] 205 | text_embeds = self.clip_model.text_projection(text_embeds) 206 | 207 | # normalized features 208 | image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) 209 | text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) 210 | 211 | # quantization 212 | image_codes, text_codes = None, None 213 | perplexity = 0.0 214 | vq_loss = 0.0 215 | if self.text_vq_adapter: 216 | res = self.text_vq_adapter(text_embeds, return_perplexity=return_perplexity) 217 | text_embeds = res["z"] 218 | text_codes = res["codes"] 219 | perplexity += res["perplexity"] if return_perplexity else 0.0 220 | vq_loss += res["loss"] 221 | if self.vision_vq_adapter: 222 | res = self.vision_vq_adapter(image_embeds, return_perplexity=return_perplexity) 223 | image_embeds = res["z"] 224 | image_codes = res["codes"] 225 | perplexity += res["perplexity"] if return_perplexity else 0.0 226 | vq_loss += res["loss"] 227 | if self.vision_vq_adapter and self.text_vq_adapter: 228 | # averages 229 | perplexity = perplexity / 2.0 230 | vq_loss = vq_loss / 2.0 231 | 232 | # cosine similarity as logits 233 | logit_scale = self.clip_model.logit_scale.exp() 234 | logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale 235 | logits_per_image = logits_per_text.t() 236 | 237 | loss = None 238 | if return_loss: 239 | loss = clip_loss(logits_per_text) 240 | 241 | total_loss = None 242 | if loss is not None: 243 | total_loss = loss + vq_loss 244 | 245 | if not return_dict: 246 | output = ( 247 | logits_per_image, 248 | logits_per_text, 249 | text_embeds, 250 | image_embeds, 251 | text_outputs, 252 | vision_outputs, 253 | ) 254 | return ((loss,) + output) if loss is not None else output 255 | 256 | return VQCLIPOutput( 257 | loss=total_loss, 258 | quantization_loss=vq_loss, 259 | contrastive_loss=loss, 260 | logits_per_image=logits_per_image, 261 | logits_per_text=logits_per_text, 262 | text_embeds=text_embeds, 263 | text_codes=text_codes, 264 | image_embeds=image_embeds, 265 | image_codes=image_codes, 266 | ) 267 | 268 | @staticmethod 269 | def from_pretrained_clip(clip_path: str, vision_vq_adapter_path: Optional[str] = None, text_vq_adapter_path: Optional[str] = None): 270 | """ 271 | load only the adapter from the vq_clip_path, and load the clip model 272 | from the clip_path 273 | """ 274 | 275 | assert not (text_vq_adapter_path is None and vision_vq_adapter_path is None) 276 | 277 | clip_config = 
CLIPConfig.from_pretrained(clip_path).to_dict() 278 | if vision_vq_adapter_path is not None: 279 | vision_vq_config = VQAdapterConfig.from_pretrained(vision_vq_adapter_path).to_dict() 280 | else: vision_vq_config = None 281 | if text_vq_adapter_path is not None: 282 | text_vq_config = VQAdapterConfig.from_pretrained(text_vq_adapter_path).to_dict() 283 | else: text_vq_config = None 284 | 285 | vq_clip_config = VQCLIPConfig(clip_config_dict=clip_config, vision_vq_adapter_config_dict=vision_vq_config, text_vq_adapter_config_dict=text_vq_config) 286 | 287 | init_provider = contextlib.suppress 288 | try: 289 | from accelerate import init_empty_weights 290 | init_provider = init_empty_weights 291 | except ImportError: 292 | print("Could not do vq-clip lazy init") 293 | 294 | with init_provider(): 295 | vq_clip = VQCLIPModel(vq_clip_config) 296 | 297 | clip: CLIPModel = CLIPModel.from_pretrained(clip_path) 298 | vq_clip.clip_model = clip 299 | 300 | if vision_vq_adapter_path is not None: 301 | vision_vq_adapter = VQAdapterModel.from_pretrained(vision_vq_adapter_path) 302 | vq_clip.vision_vq_adapter = vision_vq_adapter 303 | if text_vq_adapter_path is not None: 304 | text_vq_adapter = VQAdapterModel.from_pretrained(text_vq_adapter_path) 305 | vq_clip.text_vq_adapter = text_vq_adapter 306 | 307 | return vq_clip 308 | -------------------------------------------------------------------------------- /vq_clip/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class Bottleneck(nn.Sequential): 7 | def __init__(self, in_dim, hidden_dim): 8 | c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False) 9 | c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False) 10 | super().__init__(*[c_fc1, c_fc2]) 11 | 12 | 13 | class Block(nn.Module): 14 | def __init__(self, in_dim, hidden_dim) -> None: 15 | super().__init__() 16 | self.norm = nn.LayerNorm(in_dim) 17 | self.mlp = MLP(in_dim, hidden_dim) 18 | 19 | def forward(self, x: torch.Tensor) -> torch.Tensor: 20 | x = x + self.mlp(self.norm(x)) 21 | return x 22 | 23 | 24 | class MLP(nn.Module): 25 | def __init__(self, in_dim, hidden_dim): 26 | super().__init__() 27 | self.c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False) 28 | self.c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False) 29 | self.c_proj = nn.Linear(hidden_dim, in_dim, bias=False) 30 | 31 | def forward(self, x: torch.Tensor): 32 | x = F.silu(self.c_fc1(x)) * self.c_fc2(x) 33 | x = self.c_proj(x) 34 | return x 35 | -------------------------------------------------------------------------------- /vq_clip/perplexity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def calculate_perplexity(codes, codebook_size, null_index=-1): 4 | """ 5 | Perplexity is 2^(H(p)) where H(p) is the entropy over the codebook likelihood 6 | 7 | The null index is assumed to be -1; perplexity is only calculated over the 8 | non-null codes 9 | """ 10 | dtype, device = codes.dtype, codes.device 11 | codes = codes.flatten() 12 | codes = codes[codes != null_index] 13 | src = torch.ones_like(codes) 14 | counts = torch.zeros(codebook_size).to(dtype).to(device) 15 | counts = counts.scatter_add_(0, codes, src) 16 | 17 | probs = counts / codes.numel() 18 | # Entropy H(x) when p(x)=0 is defined as 0 19 | logits = torch.log2(probs) 20 | logits[probs == 0.0] = 0.0 21 | entropy = -torch.sum(probs * logits) 22 | return 2**entropy 23 | 24 | 
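For reference, here is a quick sanity check of `calculate_perplexity`; the code tensors below are made-up placeholders rather than outputs of a real VQ adapter. A codebook whose entries are used roughly uniformly gives a perplexity close to the codebook size, while a collapsed codebook (every sample mapped to the same entry) gives a perplexity of 1:

```python
import torch

from vq_clip.perplexity import calculate_perplexity

# Hypothetical batch of code indices: 8 samples x 32 heads, codebook of 64 entries.
codebook_size = 64
codes = torch.randint(0, codebook_size, (8, 32))

# Roughly uniform usage -> perplexity approaches the codebook size (at most 64 here).
print(calculate_perplexity(codes, codebook_size))

# Collapsed codebook: every head always picks entry 0 -> perplexity of exactly 1.
print(calculate_perplexity(torch.zeros(8, 32, dtype=torch.long), codebook_size))
```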
-------------------------------------------------------------------------------- /vq_clip/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import lightning.pytorch as pl 4 | from torch.optim import AdamW 5 | from transformers import ( 6 | CLIPModel, 7 | CLIPProcessor, 8 | ) 9 | from transformers.models.clip.modeling_clip import clip_loss 10 | 11 | from .modeling_vq_clip import VQCLIPConfig, VQCLIPModel 12 | from .modeling_vq_adapter import VQAdapterModel, VQAdapterConfig 13 | from .cosine_annealing_warmup import CosineAnnealingWarmupRestarts 14 | from .eval import zero_shot_eval 15 | 16 | 17 | def models_eq(model1:nn.Module, model2:nn.Module): 18 | model1 = model1.to(model2.device) 19 | sd1 = model1.state_dict() 20 | sd2 = model2.state_dict() 21 | def _check(sd1, sd2): 22 | for k,v in sd1.items(): 23 | get = sd2.get(k) 24 | if get is None: 25 | print(k, "not in model2") 26 | return False 27 | if not torch.equal(v, get): 28 | print(k, v, " ne ", get) 29 | return False 30 | return True 31 | 32 | return _check(sd1, sd2) and _check(sd2, sd1) 33 | 34 | 35 | def clip_loss_from_embeds(text_embeds, image_embeds, logit_scale): 36 | logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale 37 | return clip_loss(logits_per_text) 38 | 39 | 40 | class LightningVQCLIPTrainer(pl.LightningModule): 41 | """ 42 | Trainer for a vision VQ adapter on top of a CLIP vision tower 43 | """ 44 | 45 | def __init__( 46 | self, 47 | vision_vq_config_path: str = "./model_conf/vq-ViT-L-14-k64/config.json", 48 | # pretrained clip args 49 | pretrained_clip_url: str = "openai/clip-vit-base-patch32", 50 | # training_specific args 51 | warmup_steps: int = 100, 52 | max_lr: float = 8e-4, 53 | min_lr: float = 5e-5, 54 | lr_gamma: float = 0.4, 55 | lr_cycle_steps: int = 500, 56 | torch_compile: bool = False, 57 | # eval 58 | imagenet_path: str = "", 59 | validation_batch_size: int = 512, 60 | ): 61 | super().__init__() 62 | self.vision_vq_config_path = vision_vq_config_path 63 | self.vision_vq_config = VQAdapterConfig.from_pretrained(vision_vq_config_path) 64 | self.vision_vq_adapter = VQAdapterModel(self.vision_vq_config) 65 | 66 | self.clip_url = pretrained_clip_url 67 | self.imagenet_path = imagenet_path 68 | self.validation_batch_size = validation_batch_size 69 | 70 | 71 | if torch_compile: 72 | self.vision_vq_adapter = torch.compile(self.vision_vq_adapter) 73 | 74 | self.warmup_steps = warmup_steps 75 | self.max_lr = max_lr 76 | self.min_lr = min_lr 77 | self.lr_gamma = lr_gamma 78 | self.lr_cycle_steps = lr_cycle_steps 79 | 80 | self.save_hyperparameters() 81 | 82 | def on_save_checkpoint(self, _): 83 | self.save_hf(self.logger.log_dir + "/hf/") 84 | 85 | def save_hf(self, path: str = ""): 86 | self.vision_vq_adapter.save_pretrained(path) 87 | 88 | def step(self, img_emb, text_emb): 89 | """ 90 | img_emb normalized image embedding tensor batch from CLIP 91 | text_emb normalized text embedding tensor batch from CLIP 92 | """ 93 | with torch.no_grad(): 94 | # TODO Assumes logit_scale.exp() = 100. 
95 | pre_quant_contrastive_loss = clip_loss_from_embeds(img_emb, text_emb, 100.0) 96 | 97 | res = self.vision_vq_adapter(img_emb, return_perplexity=True) 98 | img_emb = res["z"] 99 | quant_loss = res["loss"] 100 | perplexity = res["perplexity"] 101 | 102 | img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True) 103 | 104 | contrastive_loss = clip_loss_from_embeds(img_emb, text_emb, 100.0) 105 | 106 | loss = contrastive_loss + quant_loss 107 | 108 | logs = dict( 109 | quant_loss=quant_loss, 110 | contrastive_loss=contrastive_loss, 111 | loss=loss, 112 | perplexity=perplexity, 113 | pre_quant_contrastive_loss=pre_quant_contrastive_loss, 114 | ) 115 | 116 | return loss, logs 117 | 118 | def validation_step(self, batch, batch_idx): 119 | if batch_idx == 0: 120 | # Build a temporary VQ-CLIP model from the current adapter weights 121 | tmp_dir = "/tmp/vq-vision/" 122 | self.save_hf(tmp_dir) 123 | vq_clip = VQCLIPModel.from_pretrained_clip(self.clip_url, vision_vq_adapter_path=tmp_dir) 124 | vq_clip.to(self.device) 125 | # uncomment to verify that performance without adapters matches the normal pretrained CLIP 126 | # vq_clip.vision_vq_adapter = None 127 | # vq_clip.text_vq_adapter = None 128 | 129 | processor = CLIPProcessor.from_pretrained(self.clip_url) 130 | 131 | if self.imagenet_path:  # skip the ImageNet eval when no path is configured 132 | with torch.no_grad(): 133 | with torch.autocast("cuda"): 134 | top1, top5 = zero_shot_eval( 135 | vq_clip, 136 | processor, 137 | self.imagenet_path, 138 | self.validation_batch_size, 139 | ) 140 | self.log_dict(dict(imagenet_top1=top1, imagenet_top5=top5)) 141 | 142 | self.eval() 143 | img_emb, text_emb = batch 144 | loss, logs = self.step(img_emb, text_emb) 145 | 146 | self.log_dict({"v_" + k: v for k, v in logs.items()}) 147 | return loss 148 | 149 | def training_step(self, batch, _): 150 | img_emb, text_emb = batch 151 | loss, logs = self.step(img_emb, text_emb) 152 | self.log_dict({"t_" + k: v for k, v in logs.items()}) 153 | return loss 154 | 155 | def configure_optimizers(self): 156 | optimizer = AdamW(params=self.vision_vq_adapter.parameters()) 157 | learning_rate_scheduler = CosineAnnealingWarmupRestarts( 158 | optimizer, 159 | first_cycle_steps=self.lr_cycle_steps, 160 | max_lr=self.max_lr, 161 | min_lr=self.min_lr, 162 | warmup_steps=self.warmup_steps, 163 | gamma=self.lr_gamma, 164 | ) 165 | return [optimizer], [ 166 | {"scheduler": learning_rate_scheduler, "interval": "step", "frequency": 1} 167 | ] 168 | --------------------------------------------------------------------------------
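The training module above only consumes batches of precomputed `(image_embedding, text_embedding)` pairs; the frozen CLIP model itself is never run during a training step. As a rough illustration of that contract, here is a minimal sketch that drives `LightningVQCLIPTrainer` with random placeholder embeddings. The random tensors, batch size, and step count are assumptions for illustration only; the real pipeline is `train_vqclip_from_preembed.py` with `embedding_dataset.py` and the YAML files in `training_conf/`.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl

from vq_clip.trainer import LightningVQCLIPTrainer

# Placeholder stand-ins for precomputed, normalized CLIP ViT-L/14 embeddings (768-d).
img_emb = F.normalize(torch.randn(256, 768), dim=-1)
text_emb = F.normalize(torch.randn(256, 768), dim=-1)
loader = DataLoader(TensorDataset(img_emb, text_emb), batch_size=64)

module = LightningVQCLIPTrainer(
    vision_vq_config_path="./model_conf/vq-ViT-L-14-k64/config.json",
    pretrained_clip_url="openai/clip-vit-large-patch14",
)

# limit_val_batches=0 skips validation (and with it the ImageNet zero-shot eval).
trainer = pl.Trainer(max_steps=10, limit_val_batches=0, accelerator="auto")
trainer.fit(module, train_dataloaders=loader)
```

In the real setup these hyperparameters come from the YAML configs in `training_conf/`; the sketch is only meant to show the batch shape that `training_step` expects.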