├── .gitignore ├── LICENSE.txt ├── README.md ├── compute_ablations.py ├── compute_complete_text_set.py ├── compute_prs.py ├── compute_segmentations.py ├── compute_siglip.py ├── compute_text_projection.py ├── compute_text_set_projection.py ├── compute_use_specific_heads.py ├── demo.ipynb ├── demo_siglip.ipynb ├── environment.yml ├── images ├── catdog.png └── teaser.png ├── nns.ipynb ├── output_dir └── binary_waterbirds_labels.npy ├── prs_hook.py ├── text_descriptions ├── google_3498_english.txt └── image_descriptions_general.txt └── utils ├── __init__.py ├── binary_waterbirds.py ├── constants.py ├── cub_classes.py ├── factory.py ├── hook.py ├── imagenet_classes.py ├── imagenet_segmentation.py ├── misc.py ├── model.py ├── model_configs ├── EVA01-g-14-plus.json ├── EVA01-g-14.json ├── EVA02-B-16.json ├── EVA02-E-14-plus.json ├── EVA02-E-14.json ├── EVA02-L-14-336.json ├── EVA02-L-14.json ├── ViT-B-16-plus-240.json ├── ViT-B-16-plus.json ├── ViT-B-16.json ├── ViT-B-32-plus-256.json ├── ViT-B-32-quickgelu.json ├── ViT-B-32.json ├── ViT-H-14.json ├── ViT-H-16.json ├── ViT-L-14-280.json ├── ViT-L-14-336.json ├── ViT-L-14.json ├── ViT-L-16-320.json ├── ViT-L-16.json ├── ViT-M-16-alt.json ├── ViT-M-16.json ├── ViT-M-32-alt.json ├── ViT-M-32.json ├── ViT-S-16-alt.json ├── ViT-S-16.json ├── ViT-S-32-alt.json ├── ViT-S-32.json ├── ViT-bigG-14.json ├── ViT-e-14.json ├── ViT-g-14.json ├── coca_ViT-B-32.json ├── coca_ViT-L-14.json ├── coca_base.json ├── coca_roberta-ViT-B-32.json ├── mt5-base-ViT-B-32.json ├── mt5-xl-ViT-H-14.json ├── roberta-ViT-B-32.json ├── swin_base_patch4_window7_224.json ├── vit_medium_patch16_gap_256.json ├── vit_relpos_medium_patch16_cls_224.json ├── xlm-roberta-base-ViT-B-32.json └── xlm-roberta-large-ViT-H-14.json ├── modified_resnet.py ├── openai_models.py ├── openai_templates.py ├── pretrained.py ├── segmentation_utils.py ├── siglip ├── configuration_siglip.py ├── image_processing_siglip.py ├── image_processing_siglip_fast.py ├── modeling_siglip.py └── processing_siglip.py ├── timm_model.py ├── tokenizer.py ├── transform.py ├── transformer.py ├── visualization.py └── vocab └── bpe_simple_vocab_16e6.txt.gz /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | __pycache__/ 3 | */*.mat 4 | utils/__pycache__ 5 | imagenet_seg/ 6 | run/ -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yossi Gandelsman 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Interpreting CLIP's Image Representation via Text-Based Decomposition 2 | Official PyTorch Implementation 3 | 4 | ### [Paper](https://arxiv.org/abs/2310.05916) | [Project Page](https://yossigandelsman.github.io/clip_decomposition/) 5 | 6 | [Yossi Gandelsman](https://yossigandelsman.github.io/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/) and [Jacob Steinhardt](https://jsteinhardt.stat.berkeley.edu/) 7 | 8 | ![Teaser](images/teaser.png) 9 | 10 | 🔥 Check out [our latest work](https://yossigandelsman.github.io/clip_neurons/) on interpreting neurons in CLIP with text. 11 | 12 | ### Setup 13 | We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment: 14 | 15 | ```bash 16 | conda env create -f environment.yml 17 | conda activate prsclip 18 | ``` 19 | ### Preprocessing 20 | To obtain the projected residual stream components for the ImageNet validation set, including the contributions from multi-head attentions and MLPs, please run one of the following instructions: 21 | 22 | ```bash 23 | python compute_prs.py --dataset imagenet --device cuda:0 --model ViT-H-14 --pretrained laion2b_s32b_b79k --data_path 24 | python compute_prs.py --dataset imagenet --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k --data_path 25 | python compute_prs.py --dataset imagenet --device cuda:0 --model ViT-B-16 --pretrained laion2b_s34b_b88k --data_path 26 | ``` 27 | 28 | To obtain the precomputed text representations of the ImageNet classes, please run: 29 | ```bash 30 | python compute_text_projection.py --dataset imagenet --device cuda:0 --model ViT-H-14 --pretrained laion2b_s32b_b79k 31 | python compute_text_projection.py --dataset imagenet --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k 32 | python compute_text_projection.py --dataset imagenet --device cuda:0 --model ViT-B-16 --pretrained laion2b_s34b_b88k 33 | ``` 34 | 35 | ### Mean-ablations 36 | To verify that the MLPs and the attention from the class token to itself can be mean-ablated, please run: 37 | 38 | ```bash 39 | python compute_ablations.py --model ViT-H-14 40 | python compute_ablations.py --model ViT-L-14 41 | python compute_ablations.py --model ViT-B-16 42 | ``` 43 | 44 | ### Convert text labels to representation 45 | To convert the text labels for TextSpan to CLIP text representations, please run: 46 | 47 | ```bash 48 | python compute_text_set_projection.py --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k --data_path text_descriptions/google_3498_english.txt 49 | python compute_text_set_projection.py --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k --data_path text_descriptions/image_descriptions_general.txt 50 | ``` 51 | 52 | ### ImageNet segmentation 53 | Please download the dataset from [here](http://calvin-vision.net/bigstuff/proj-imagenet/data/gtsegs_ijcv.mat): 54 | 55 | ```bash 56 | mkdir imagenet_seg 57 | cd imagenet_seg 58 | wget http://calvin-vision.net/bigstuff/proj-imagenet/data/gtsegs_ijcv.mat 59 | ``` 60 | 61 | To get the evaluation results, please run: 62 | 63 | 
```bash 64 | python compute_segmentations.py --device cuda:0 --model ViT-H-14 --pretrained laion2b_s32b_b79k --data_path imagenet_seg/gtsegs_ijcv.mat --save_img 65 | python compute_segmentations.py --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k --data_path imagenet_seg/gtsegs_ijcv.mat --save_img 66 | python compute_segmentations.py --device cuda:0 --model ViT-B-16 --pretrained laion2b_s34b_b88k --data_path imagenet_seg/gtsegs_ijcv.mat --save_img 67 | ``` 68 | The `--save_img` flag additionally saves the input images, predicted masks, and heatmaps to the results folder. 69 | 70 | 71 | ### TextSpan 72 | 73 | To find meaningful directions for all the attention heads, run: 74 | ```bash 75 | python compute_complete_text_set.py --device cuda:0 --model ViT-B-16 --texts_per_head 20 --num_of_last_layers 4 --text_descriptions image_descriptions_general 76 | python compute_complete_text_set.py --device cuda:0 --model ViT-L-14 --texts_per_head 20 --num_of_last_layers 4 --text_descriptions image_descriptions_general 77 | python compute_complete_text_set.py --device cuda:0 --model ViT-H-14 --texts_per_head 20 --num_of_last_layers 4 --text_descriptions image_descriptions_general 78 | ``` 79 | 80 | ### Other datasets 81 | To download the Waterbirds dataset, run: 82 | ```bash 83 | wget https://nlp.stanford.edu/data/dro/waterbird_complete95_forest2water2.tar.gz 84 | tar -xf waterbird_complete95_forest2water2.tar.gz 85 | ``` 86 | To compute the overall accuracy, run: 87 | ```bash 88 | python compute_prs.py --dataset binary_waterbirds --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k --data_path 89 | python compute_text_projection.py --dataset binary_waterbirds --device cuda:0 --model ViT-L-14 --pretrained laion2b_s32b_b82k 90 | python compute_use_specific_heads.py --model ViT-L-14 --dataset binary_waterbirds 91 | ``` 92 | 93 | ### Spatial decomposition 94 | Please see a demo for the spatial decomposition of CLIP in `demo.ipynb`. 95 | 96 | 97 | ### Nearest neighbors search 98 | Please see the nearest neighbors search demo in `nns.ipynb`. 99 | 100 | ### BibTeX 101 | 102 | ```bibtex 103 | @inproceedings{ 104 | gandelsman2024interpreting, 105 | title={Interpreting {CLIP}'s Image Representation via Text-Based Decomposition}, 106 | author={Yossi Gandelsman and Alexei A.
Efros and Jacob Steinhardt}, 107 | booktitle={The Twelfth International Conference on Learning Representations}, 108 | year={2024}, 109 | url={https://openreview.net/forum?id=5Ca9sSzuDp} 110 | } 111 | ``` 112 | -------------------------------------------------------------------------------- /compute_ablations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os.path 4 | import argparse 5 | import einops 6 | from pathlib import Path 7 | 8 | import tqdm 9 | from utils.misc import accuracy 10 | 11 | 12 | def get_args_parser(): 13 | parser = argparse.ArgumentParser("Ablations part", add_help=False) 14 | 15 | # Model parameters 16 | parser.add_argument( 17 | "--model", 18 | default="ViT-H-14", 19 | type=str, 20 | metavar="MODEL", 21 | help="Name of model to use", 22 | ) 23 | # Dataset parameters 24 | parser.add_argument("--num_workers", default=10, type=int) 25 | parser.add_argument( 26 | "--figures_dir", default="./output_dir", help="path where data is saved" 27 | ) 28 | parser.add_argument( 29 | "--input_dir", default="./output_dir", help="path where data is saved" 30 | ) 31 | parser.add_argument( 32 | "--dataset", 33 | type=str, 34 | default="imagenet", 35 | help="imagenet, waterbirds, cub, binary_waterbirds", 36 | ) 37 | return parser 38 | 39 | 40 | def main(args): 41 | 42 | attns = np.load(os.path.join(args.input_dir, f"{args.dataset}_attn_{args.model}.npy"), mmap_mode="r") # [b, l, h, d] 43 | mlps = np.load(os.path.join(args.input_dir, f"{args.dataset}_mlp_{args.model}.npy"), mmap_mode="r") # [b, l+1, d] 44 | with open( 45 | os.path.join(args.input_dir, f"{args.dataset}_classifier_{args.model}.npy"), 46 | "rb", 47 | ) as f: 48 | classifier = np.load(f) 49 | if args.dataset == "imagenet": 50 | labels = np.array([i // 50 for i in range(attns.shape[0])]) 51 | else: 52 | with open( 53 | os.path.join(args.input_dir, f"{args.dataset}_labels.npy"), "rb" 54 | ) as f: 55 | labels = np.load(f) 56 | baseline = attns.sum(axis=(1, 2)) + mlps.sum(axis=1) 57 | baseline_acc = ( 58 | accuracy( 59 | torch.from_numpy(baseline @ classifier).float(), torch.from_numpy(labels) 60 | )[0] 61 | * 100 62 | ) 63 | print("Baseline:", baseline_acc) 64 | mlps_mean = einops.repeat(mlps.mean(axis=0), "l d -> b l d", b=attns.shape[0]) 65 | mlps_ablation = attns.sum(axis=(1, 2)) + mlps_mean.sum(axis=1) 66 | mlps_ablation_acc = ( 67 | accuracy( 68 | torch.from_numpy(mlps_ablation @ classifier).float(), 69 | torch.from_numpy(labels), 70 | )[0] 71 | * 100 72 | ) 73 | print("+ MLPs ablation:", mlps_ablation_acc) 74 | mlps_no_layers = mlps.sum(axis=1) 75 | attns_no_cls = attns.sum(axis=2) 76 | with open( 77 | os.path.join(args.input_dir, f"{args.dataset}_cls_attn_{args.model}.npy"), "rb" 78 | ) as f: 79 | cls_attn = np.load(f) # [b, l, d] 80 | attns_no_cls = attns_no_cls - cls_attn + cls_attn.mean(axis=0)[np.newaxis, :, :] 81 | no_cls_ablation = attns_no_cls.sum(axis=1) + mlps_no_layers 82 | no_cls_acc = ( 83 | accuracy( 84 | torch.from_numpy(no_cls_ablation @ classifier).float(), 85 | torch.from_numpy(labels), 86 | )[0] 87 | * 100 88 | ) 89 | print("+ CLS ablation:", no_cls_acc) 90 | mlp_and_no_cls_ablation = attns_no_cls.sum(axis=1) + mlps_mean.sum(axis=1) 91 | mlp_and_no_cls_ablation_acc = ( 92 | accuracy( 93 | torch.from_numpy(mlp_and_no_cls_ablation @ classifier).float(), 94 | torch.from_numpy(labels), 95 | )[0] 96 | * 100 97 | ) 98 | print("+ MLPs + CLS ablation:", mlp_and_no_cls_ablation_acc) 99 | no_heads_attentions = 
attns.sum(axis=(2)) 100 | all_accuracies = [baseline_acc] 101 | for layer in range(attns.shape[1]): 102 | current_model = ( 103 | np.sum( 104 | np.mean(no_heads_attentions[:, :layer], axis=0, keepdims=True), axis=1 105 | ) 106 | + np.mean(no_heads_attentions[:, layer], axis=0, keepdims=True) 107 | + np.sum(no_heads_attentions[:, layer + 1 :], axis=1) 108 | ) 109 | current_accuracy = ( 110 | accuracy( 111 | torch.from_numpy((mlps_no_layers + current_model) @ classifier).float(), 112 | torch.from_numpy(labels), 113 | )[0] 114 | * 100 115 | ) 116 | all_accuracies.append(current_accuracy) 117 | print("Attention ablations:", all_accuracies) 118 | 119 | 120 | if __name__ == "__main__": 121 | args = get_args_parser() 122 | args = args.parse_args() 123 | if args.figures_dir: 124 | Path(args.figures_dir).mkdir(parents=True, exist_ok=True) 125 | main(args) 126 | -------------------------------------------------------------------------------- /compute_complete_text_set.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | import glob 6 | import sys 7 | import os 8 | import einops 9 | from torch.utils.data import DataLoader 10 | import tqdm 11 | import argparse 12 | from torchvision.datasets import ImageNet 13 | from pathlib import Path 14 | 15 | from utils.misc import accuracy 16 | 17 | 18 | @torch.no_grad() 19 | def replace_with_iterative_removal(data, text_features, texts, iters, rank, device): 20 | results = [] 21 | u, s, vh = np.linalg.svd(data, full_matrices=False) 22 | vh = vh[:rank] 23 | text_features = ( 24 | vh.T.dot(np.linalg.inv(vh.dot(vh.T)).dot(vh)).dot(text_features.T).T 25 | ) # Project the text to the span of W_OV 26 | data = torch.from_numpy(data).float().to(device) 27 | mean_data = data.mean(dim=0, keepdim=True) 28 | data = data - mean_data 29 | reconstruct = einops.repeat(mean_data, "A B -> (C A) B", C=data.shape[0]) 30 | reconstruct = reconstruct.detach().cpu().numpy() 31 | text_features = torch.from_numpy(text_features).float().to(device) 32 | for i in range(iters): 33 | projection = data @ text_features.T 34 | projection_std = projection.std(axis=0).detach().cpu().numpy() 35 | top_n = np.argmax(projection_std) 36 | results.append(texts[top_n]) 37 | text_norm = text_features[top_n] @ text_features[top_n].T 38 | reconstruct += ( 39 | ( 40 | (data @ text_features[top_n] / text_norm)[:, np.newaxis] 41 | * text_features[top_n][np.newaxis, :] 42 | ) 43 | .detach() 44 | .cpu() 45 | .numpy() 46 | ) 47 | data = data - ( 48 | (data @ text_features[top_n] / text_norm)[:, np.newaxis] 49 | * text_features[top_n][np.newaxis, :] 50 | ) 51 | text_features = ( 52 | text_features 53 | - (text_features @ text_features[top_n] / text_norm)[:, np.newaxis] 54 | * text_features[top_n][np.newaxis, :] 55 | ) 56 | return reconstruct, results 57 | 58 | 59 | def get_args_parser(): 60 | parser = argparse.ArgumentParser("Completeness part", add_help=False) 61 | 62 | # Model parameters 63 | parser.add_argument( 64 | "--model", 65 | default="ViT-H-14", 66 | type=str, 67 | metavar="MODEL", 68 | help="Name of model to use", 69 | ) 70 | # Dataset parameters 71 | parser.add_argument("--num_workers", default=10, type=int) 72 | parser.add_argument( 73 | "--output_dir", default="./output_dir", help="path where data is saved" 74 | ) 75 | parser.add_argument( 76 | "--input_dir", default="./output_dir", help="path where data is saved" 77 | ) 78 | parser.add_argument( 79 | "--text_descriptions", 80 | 
default="image_descriptions_per_class", 81 | type=str, 82 | help="name of the evalauted text set", 83 | ) 84 | parser.add_argument( 85 | "--text_dir", 86 | default="./text_descriptions", 87 | type=str, 88 | help="The folder with the text files", 89 | ) 90 | parser.add_argument( 91 | "--dataset", type=str, default="imagenet", help="imagenet or waterbirds" 92 | ) 93 | parser.add_argument( 94 | "--num_of_last_layers", 95 | type=int, 96 | default=8, 97 | help="How many attention layers to replace.", 98 | ) 99 | parser.add_argument( 100 | "--w_ov_rank", type=int, default=80, help="The rank of the OV matrix" 101 | ) 102 | parser.add_argument( 103 | "--texts_per_head", 104 | type=int, 105 | default=10, 106 | help="The number of text examples per head.", 107 | ) 108 | parser.add_argument("--device", default="cuda:0", help="device to use for testing") 109 | return parser 110 | 111 | 112 | def main(args): 113 | with open( 114 | os.path.join(args.input_dir, f"{args.dataset}_attn_{args.model}.npy"), "rb" 115 | ) as f: 116 | attns = np.load(f) # [b, l, h, d] 117 | with open( 118 | os.path.join(args.input_dir, f"{args.dataset}_mlp_{args.model}.npy"), "rb" 119 | ) as f: 120 | mlps = np.load(f) # [b, l+1, d] 121 | with open( 122 | os.path.join(args.input_dir, f"{args.dataset}_classifier_{args.model}.npy"), 123 | "rb", 124 | ) as f: 125 | classifier = np.load(f) 126 | print(f"Number of layers: {attns.shape[1]}") 127 | all_images = set() 128 | # Mean-ablate the other parts 129 | for i in tqdm.trange(attns.shape[1] - args.num_of_last_layers): 130 | for head in range(attns.shape[2]): 131 | attns[:, i, head] = np.mean(attns[:, i, head], axis=0, keepdims=True) 132 | # Load text: 133 | with open( 134 | os.path.join(args.input_dir, f"{args.text_descriptions}_{args.model}.npy"), "rb" 135 | ) as f: 136 | text_features = np.load(f) 137 | with open(os.path.join(args.text_dir, f"{args.text_descriptions}.txt"), "r") as f: 138 | lines = [i.replace("\n", "") for i in f.readlines()] 139 | with open( 140 | os.path.join( 141 | args.output_dir, 142 | f"{args.dataset}_completeness_{args.text_descriptions}_top_{args.texts_per_head}_heads_{args.model}.txt", 143 | ), 144 | "w", 145 | ) as w: 146 | for i in tqdm.trange(attns.shape[1] - args.num_of_last_layers, attns.shape[1]): 147 | for head in range(attns.shape[2]): 148 | reconstruct, results = replace_with_iterative_removal( 149 | attns[:, i, head], 150 | text_features, 151 | lines, 152 | args.texts_per_head, 153 | args.w_ov_rank, 154 | args.device, 155 | ) 156 | attns[:, i, head] = reconstruct 157 | all_images |= set(results) 158 | w.write(f"------------------\n") 159 | w.write(f"Layer {i}, Head {head}\n") 160 | w.write(f"------------------\n") 161 | for text in results: 162 | w.write(f"{text}\n") 163 | 164 | mean_ablated_and_replaced = mlps.sum(axis=1) + attns.sum(axis=(1, 2)) 165 | projections = torch.from_numpy(mean_ablated_and_replaced).float().to( 166 | args.device 167 | ) @ torch.from_numpy(classifier).float().to(args.device) 168 | labels = np.array([i // 50 for i in range(attns.shape[0])]) 169 | current_accuracy = ( 170 | accuracy(projections.cpu(), torch.from_numpy(labels))[0] * 100.0 171 | ) 172 | print( 173 | f"Current accuracy:", 174 | current_accuracy, 175 | "\nNumber of texts:", 176 | len(all_images), 177 | ) 178 | w.write(f"------------------\n") 179 | w.write( 180 | f"Current accuracy: {current_accuracy}\nNumber of texts: {len(all_images)}" 181 | ) 182 | 183 | 184 | if __name__ == "__main__": 185 | args = get_args_parser() 186 | args = args.parse_args() 187 | if 
args.output_dir: 188 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 189 | main(args) 190 | -------------------------------------------------------------------------------- /compute_prs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | import os.path 5 | import argparse 6 | from pathlib import Path 7 | 8 | from torch.utils.data import DataLoader 9 | import tqdm 10 | from utils.factory import create_model_and_transforms, get_tokenizer 11 | from utils.binary_waterbirds import BinaryWaterbirds 12 | from prs_hook import hook_prs_logger 13 | from torchvision.datasets import CIFAR100, CIFAR10, ImageNet, ImageFolder 14 | 15 | 16 | def get_args_parser(): 17 | parser = argparse.ArgumentParser("Project Residual Stream", add_help=False) 18 | parser.add_argument("--batch_size", default=2, type=int, help="Batch size") 19 | # Model parameters 20 | parser.add_argument( 21 | "--model", 22 | default="ViT-H-14", 23 | type=str, 24 | metavar="MODEL", 25 | help="Name of model to use", 26 | ) 27 | parser.add_argument("--pretrained", default="laion2b_s32b_b79k", type=str) 28 | # Dataset parameters 29 | parser.add_argument( 30 | "--data_path", default="/shared/group/ilsvrc", type=str, help="dataset path" 31 | ) 32 | parser.add_argument( 33 | "--dataset", type=str, default="imagenet", help="imagenet, cub or waterbirds" 34 | ) 35 | parser.add_argument("--num_workers", default=10, type=int) 36 | parser.add_argument( 37 | "--output_dir", default="./output_dir", help="path where to save" 38 | ) 39 | parser.add_argument("--device", default="cuda:0", help="device to use for testing") 40 | return parser 41 | 42 | 43 | def main(args): 44 | """Calculates the projected residual stream for a dataset.""" 45 | model, _, preprocess = create_model_and_transforms( 46 | args.model, pretrained=args.pretrained 47 | ) 48 | model.to(args.device) 49 | model.eval() 50 | context_length = model.context_length 51 | vocab_size = model.vocab_size 52 | 53 | print( 54 | "Model parameters:", 55 | f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}", 56 | ) 57 | print("Context length:", context_length) 58 | print("Vocab size:", vocab_size) 59 | print("Len of res:", len(model.visual.transformer.resblocks)) 60 | 61 | prs = hook_prs_logger(model, args.device) 62 | 63 | # Data: 64 | if args.dataset == "imagenet": 65 | ds = ImageNet(root=args.data_path, split="val", transform=preprocess) 66 | elif args.dataset == "binary_waterbirds": 67 | ds = BinaryWaterbirds(root=args.data_path, split="test", transform=preprocess) 68 | elif args.dataset == "CIFAR100": 69 | ds = CIFAR100( 70 | root=args.data_path, download=True, train=False, transform=preprocess 71 | ) 72 | elif args.dataset == "CIFAR10": 73 | ds = CIFAR10( 74 | root=args.data_path, download=True, train=False, transform=preprocess 75 | ) 76 | else: 77 | ds = ImageFolder(root=args.data_path, transform=preprocess) 78 | dataloader = DataLoader( 79 | ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers 80 | ) 81 | 82 | attention_results = [] 83 | mlp_results = [] 84 | cls_to_cls_results = [] 85 | for i, (image, _) in enumerate(tqdm.tqdm(dataloader)): 86 | with torch.no_grad(): 87 | prs.reinit() 88 | representation = model.encode_image( 89 | image.to(args.device), attn_method="head", normalize=False 90 | ) 91 | attentions, mlps = prs.finalize(representation) 92 | attentions = attentions.detach().cpu().numpy() # [b, l, n, h, d] 93 | mlps = 
mlps.detach().cpu().numpy() # [b, l+1, d] 94 | attention_results.append( 95 | np.sum(attentions, axis=2) 96 | ) # Reduce the spatial dimension 97 | mlp_results.append(mlps) 98 | cls_to_cls_results.append( 99 | np.sum(attentions[:, :, 0], axis=2) 100 | ) # Store the cls->cls attention, reduce the heads 101 | with open( 102 | os.path.join(args.output_dir, f"{args.dataset}_attn_{args.model}.npy"), "wb" 103 | ) as f: 104 | np.save(f, np.concatenate(attention_results, axis=0)) 105 | with open( 106 | os.path.join(args.output_dir, f"{args.dataset}_mlp_{args.model}.npy"), "wb" 107 | ) as f: 108 | np.save(f, np.concatenate(mlp_results, axis=0)) 109 | with open( 110 | os.path.join(args.output_dir, f"{args.dataset}_cls_attn_{args.model}.npy"), "wb" 111 | ) as f: 112 | np.save(f, np.concatenate(cls_to_cls_results, axis=0)) 113 | 114 | 115 | if __name__ == "__main__": 116 | args = get_args_parser() 117 | args = args.parse_args() 118 | if args.output_dir: 119 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 120 | main(args) 121 | -------------------------------------------------------------------------------- /compute_segmentations.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import numpy as np 4 | import scipy 5 | import torchvision.transforms as transforms 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader 8 | from PIL import Image 9 | import imageio 10 | import cv2 11 | import os 12 | from pathlib import Path 13 | import tqdm 14 | from utils.factory import create_model_and_transforms 15 | from utils.imagenet_segmentation import ImagenetSegmentation 16 | from utils.segmentation_utils import ( 17 | batch_pix_accuracy, 18 | batch_intersection_union, 19 | get_ap_scores, 20 | Saver, 21 | ) 22 | from sklearn.metrics import precision_recall_curve 23 | from prs_hook import hook_prs_logger 24 | 25 | 26 | # Args 27 | def get_args_parser(): 28 | parser = argparse.ArgumentParser(description="Segmentation scores") 29 | parser.add_argument("--save_img", action="store_true", default=False, help="") 30 | parser.add_argument( 31 | "--train_dataset", 32 | type=str, 33 | default="imagenet_seg", 34 | help="The name of the dataset", 35 | ) 36 | parser.add_argument( 37 | "--classifier_dataset", 38 | type=str, 39 | default="imagenet", 40 | help="The name of the classifier dataset", 41 | ) 42 | parser.add_argument("--image_size", default=224, type=int, help="Image size") 43 | parser.add_argument("--thr", type=float, default=0.0, help="threshold") 44 | parser.add_argument( 45 | "--data_path", 46 | default="imagenet_seg/gtsegs_ijcv.mat", 47 | type=str, 48 | help="dataset path", 49 | ) 50 | parser.add_argument("--num_workers", default=10, type=int) 51 | parser.add_argument("--classifier_dir", default="./output_dir/") 52 | parser.add_argument("--batch_size", default=1, type=int, help="Batch size") 53 | # Model parameters 54 | parser.add_argument( 55 | "--model", 56 | default="ViT-H-14", 57 | type=str, 58 | metavar="MODEL", 59 | help="Name of model to use", 60 | ) 61 | parser.add_argument("--pretrained", default="laion2b_s32b_b79k", type=str) 62 | parser.add_argument( 63 | "--output_dir", default="./output_dir", help="path where to save" 64 | ) 65 | parser.add_argument("--device", default="cuda:0", help="device to use for testing") 66 | return parser 67 | 68 | 69 | @torch.no_grad() 70 | def eval_batch(model, prs, image, labels, index, args, classifier, saver): 71 | # Save input image 72 | if args.save_img: 73 | 
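        # Added note: the input image and its ground-truth mask are written under
        # the saver's results_dir/input/ folder for later visual comparison with
        # the predicted masks saved further below.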
# Saves one image from each batch 74 | img = image[0].permute(1, 2, 0).data.cpu().numpy() 75 | img = 255 * (img - img.min()) / (img.max() - img.min()) 76 | img = img.astype("uint8") 77 | Image.fromarray(img, "RGB").save( 78 | os.path.join(saver.results_dir, "input/{}_input.png".format(index)) 79 | ) 80 | Image.fromarray( 81 | (labels.repeat(3, 1, 1).permute(1, 2, 0).data.cpu().numpy() * 255).astype( 82 | "uint8" 83 | ), 84 | "RGB", 85 | ).save(os.path.join(saver.results_dir, "input/{}_mask.png".format(index))) 86 | 87 | # Get the model attention maps: 88 | prs.reinit() 89 | representation = model.encode_image( 90 | image.to(args.device), attn_method="head", normalize=False 91 | ) 92 | attentions, _ = prs.finalize(representation) 93 | attentions = attentions.detach().cpu() # [b, l, n, h, d] 94 | chosen_class = (representation.detach().cpu().numpy() @ classifier).argmax(axis=1) 95 | patches = args.image_size // model.visual.patch_size[0] 96 | attentions_collapse = attentions[:, :, 1:].sum(axis=(1, 3)) 97 | class_heatmap = ( 98 | attentions_collapse.detach().cpu().numpy() @ classifier 99 | ) # [b, n, classes] 100 | results = [] 101 | for i in range(image.shape[0]): 102 | normalized = class_heatmap[i, :, chosen_class[i]] - np.mean( 103 | class_heatmap[i], axis=1 104 | ) 105 | results.append(normalized) 106 | results = torch.from_numpy( 107 | np.stack(results, axis=0).reshape((attentions.shape[0], patches, patches)) 108 | ) 109 | 110 | Res = torch.nn.functional.interpolate( 111 | results[:, np.newaxis], scale_factor=model.visual.patch_size[0], mode="bilinear" 112 | ).to(args.device) 113 | Res = torch.clip(Res, 0, Res.max()) 114 | # threshold between FG and BG is the mean 115 | Res = (Res - Res.min()) / (Res.max() - Res.min()) 116 | 117 | ret = Res.mean() 118 | 119 | Res_1 = Res.gt(ret).type(Res.type()) 120 | Res_0 = Res.le(ret).type(Res.type()) 121 | 122 | Res_1_AP = Res 123 | Res_0_AP = 1 - Res 124 | 125 | Res_1[Res_1 != Res_1] = 0 126 | Res_0[Res_0 != Res_0] = 0 127 | Res_1_AP[Res_1_AP != Res_1_AP] = 0 128 | Res_0_AP[Res_0_AP != Res_0_AP] = 0 129 | 130 | # TEST 131 | pred = Res.clamp(min=args.thr) / Res.max() 132 | pred = pred.view(-1).data.cpu().numpy() 133 | target = labels.view(-1).data.cpu().numpy() 134 | 135 | output = torch.cat((Res_0, Res_1), 1) 136 | output_AP = torch.cat((Res_0_AP, Res_1_AP), 1) 137 | 138 | if args.save_img: 139 | # Save predicted mask 140 | mask = F.interpolate(Res_1, [args.image_size, args.image_size], mode="bilinear") 141 | mask = mask[0].squeeze().data.cpu().numpy() 142 | mask = 255 * mask 143 | mask = mask.astype("uint8") 144 | imageio.imsave( 145 | os.path.join(args.exp_img_path, "mask_" + str(index) + ".jpg"), mask 146 | ) 147 | 148 | relevance = F.interpolate(Res, [args.image_size, args.image_size], mode="bicubic") 149 | relevance = relevance[0].permute(1, 2, 0).data.cpu().numpy() 150 | hm = np.sum(relevance, axis=-1) 151 | hm = np.clip(255.0 * hm / hm.max(), 0, 255.0).astype(np.uint8) 152 | high = cv2.cvtColor(cv2.applyColorMap(hm, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB) 153 | imageio.imsave( 154 | os.path.join(args.exp_img_path, "heatmap_" + str(index) + ".jpg"), high 155 | ) 156 | 157 | # Evaluate Segmentation 158 | batch_inter, batch_union, batch_correct, batch_label = 0, 0, 0, 0 159 | batch_ap = 0 160 | 161 | # Segmentation resutls 162 | correct, labeled = batch_pix_accuracy(output[0].data.cpu(), labels[0]) 163 | inter, union = batch_intersection_union(output[0].data.cpu(), labels[0], 2) 164 | batch_correct += correct 165 | batch_label += labeled 166 | 
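    # Added note: the per-batch intersection/union counts (2 classes: background/foreground)
    # accumulated below are aggregated into the mIoU reported in main().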
batch_inter += inter 167 | batch_union += union 168 | ap = np.nan_to_num(get_ap_scores(output_AP, labels)) 169 | batch_ap += ap 170 | 171 | return batch_correct, batch_label, batch_inter, batch_union, batch_ap, pred, target 172 | 173 | 174 | def _create_saver_and_folders(args): 175 | saver = Saver(args) 176 | saver.results_dir = os.path.join(saver.experiment_dir, "results") 177 | if not os.path.exists(saver.results_dir): 178 | os.makedirs(saver.results_dir) 179 | if not os.path.exists(os.path.join(saver.results_dir, "input")): 180 | os.makedirs(os.path.join(saver.results_dir, "input")) 181 | if not os.path.exists(os.path.join(saver.results_dir, "explain")): 182 | os.makedirs(os.path.join(saver.results_dir, "explain")) 183 | 184 | args.exp_img_path = os.path.join(saver.results_dir, "explain/img") 185 | if not os.path.exists(args.exp_img_path): 186 | os.makedirs(args.exp_img_path) 187 | return saver 188 | 189 | 190 | def main(args): 191 | # Model 192 | model, _, preprocess = create_model_and_transforms( 193 | args.model, pretrained=args.pretrained 194 | ) 195 | model.to(args.device) 196 | model.eval() 197 | context_length = model.context_length 198 | vocab_size = model.vocab_size 199 | 200 | print( 201 | "Model parameters:", 202 | f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}", 203 | ) 204 | print("Context length:", context_length) 205 | print("Vocab size:", vocab_size) 206 | print("Len of res:", len(model.visual.transformer.resblocks)) 207 | 208 | prs = hook_prs_logger(model, args.device) 209 | # Data 210 | target_transform = transforms.Compose( 211 | [ 212 | transforms.Resize((args.image_size, args.image_size), Image.NEAREST), 213 | ] 214 | ) 215 | 216 | ds = ImagenetSegmentation( 217 | args.data_path, transform=preprocess, target_transform=target_transform 218 | ) 219 | dl = DataLoader( 220 | ds, 221 | batch_size=args.batch_size, 222 | shuffle=False, 223 | num_workers=args.num_workers, 224 | drop_last=False, 225 | ) 226 | iterator = tqdm.tqdm(dl) 227 | # Saver 228 | saver = _create_saver_and_folders(args) 229 | # Classifier 230 | with open( 231 | os.path.join( 232 | args.classifier_dir, 233 | f"{args.classifier_dataset}_classifier_{args.model}.npy", 234 | ), 235 | "rb", 236 | ) as f: 237 | classifier = np.load(f) 238 | # Eval in loop 239 | total_inter, total_union, total_correct, total_label = ( 240 | np.int64(0), 241 | np.int64(0), 242 | np.int64(0), 243 | np.int64(0), 244 | ) 245 | total_ap = [] 246 | 247 | predictions, targets = [], [] 248 | for batch_idx, (image, labels) in enumerate(iterator): 249 | 250 | images = image.to(args.device) 251 | labels = labels.to(args.device) 252 | 253 | correct, labeled, inter, union, ap, pred, target = eval_batch( 254 | model, prs, images, labels, batch_idx, args, classifier, saver 255 | ) 256 | 257 | predictions.append(pred) 258 | targets.append(target) 259 | 260 | total_correct += correct.astype("int64") 261 | total_label += labeled.astype("int64") 262 | total_inter += inter.astype("int64") 263 | total_union += union.astype("int64") 264 | total_ap += [ap] 265 | pixAcc = ( 266 | np.float64(1.0) 267 | * total_correct 268 | / (np.spacing(1, dtype=np.float64) + total_label) 269 | ) 270 | IoU = ( 271 | np.float64(1.0) 272 | * total_inter 273 | / (np.spacing(1, dtype=np.float64) + total_union) 274 | ) 275 | mIoU = IoU.mean() 276 | mAp = np.mean(total_ap) 277 | iterator.set_description( 278 | "pixAcc: %.4f, mIoU: %.4f, mAP: %.4f" % (pixAcc, mIoU, mAp) 279 | ) 280 | 281 | predictions = np.concatenate(predictions) 282 | targets = 
np.concatenate(targets) 283 | pr, rc, thr = precision_recall_curve(targets, predictions) 284 | np.save(os.path.join(saver.experiment_dir, "precision.npy"), pr) 285 | np.save(os.path.join(saver.experiment_dir, "recall.npy"), rc) 286 | 287 | txtfile = os.path.join(saver.experiment_dir, "result_mIoU_%.4f.txt" % mIoU) 288 | fh = open(txtfile, "w") 289 | print("Mean IoU over %d classes: %.4f\n" % (2, mIoU)) 290 | print("Pixel-wise Accuracy: %2.2f%%\n" % (pixAcc * 100)) 291 | print("Mean AP over %d classes: %.4f\n" % (2, mAp)) 292 | 293 | fh.write("Mean IoU over %d classes: %.4f\n" % (2, mIoU)) 294 | fh.write("Pixel-wise Accuracy: %2.2f%%\n" % (pixAcc * 100)) 295 | fh.write("Mean AP over %d classes: %.4f\n" % (2, mAp)) 296 | fh.close() 297 | 298 | 299 | if __name__ == "__main__": 300 | args = get_args_parser() 301 | args = args.parse_args() 302 | if args.output_dir: 303 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 304 | main(args) 305 | -------------------------------------------------------------------------------- /compute_siglip.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | import os.path 5 | import argparse 6 | from pathlib import Path 7 | from matplotlib import pyplot as plt 8 | from torch.utils.data import DataLoader 9 | import tqdm 10 | from utils.factory import create_model_and_transforms, get_tokenizer 11 | from torchvision.datasets import ImageFolder 12 | from utils.siglip.modeling_siglip import SiglipVisionModel, SiglipTextModel 13 | from utils.siglip.processing_siglip import SiglipProcessor 14 | from transformers import AutoTokenizer 15 | from utils.openai_templates import OPENAI_IMAGENET_TEMPLATES 16 | from utils.imagenet_classes import imagenet_classes 17 | import torch.nn.functional as F 18 | from typing import Union, Any 19 | from compute_complete_text_set import replace_with_iterative_removal 20 | 21 | class ImageNet(ImageFolder): 22 | def __init__(self, root: Union[str, Path], split: str = "train", **kwargs: Any) -> None: 23 | wnid_to_classes = torch.load(os.path.join(root, "meta.bin"), weights_only=True)[0] 24 | super().__init__(os.path.join(root, split), **kwargs) 25 | 26 | self.wnids = self.classes 27 | self.wnid_to_idx = self.class_to_idx 28 | self.classes = [wnid_to_classes[wnid] for wnid in self.wnids] 29 | self.class_to_idx = {cls: idx for idx, clss in enumerate(self.classes) for cls in clss} 30 | 31 | 32 | def get_args_parser(): 33 | parser = argparse.ArgumentParser("Project Residual Stream", add_help=False) 34 | parser.add_argument("--batch_size", default=256, type=int, help="Batch size") 35 | # Model parameters 36 | parser.add_argument( 37 | "--model", 38 | default="google/siglip-so400m-patch14-384", 39 | type=str, 40 | help="Name of model to use", 41 | ) 42 | # Dataset parameters 43 | parser.add_argument( 44 | "--data_path", default="/mnt/data/yossi/ILSVRC2012/", type=str, help="dataset path" 45 | ) 46 | parser.add_argument("--save_everything", action="store_true", help="save everything") 47 | parser.add_argument("--num_workers", default=10, type=int) 48 | parser.add_argument( 49 | "--output_dir", default="./output_dir", help="path where to save" 50 | ) 51 | parser.add_argument("--device", default="cuda:0", help="device to use for testing") 52 | parser.add_argument("--compute_text_spans", action="store_true", help="compute text spans") 53 | parser.add_argument("--text_descriptions", default="text_descriptions/image_descriptions_general.txt", 
type=str, help="text descriptions to use") 54 | return parser 55 | 56 | 57 | def compute_zeroshot_weights(model, model_name, tokenizer, classnames, device, templates, use_format=False): 58 | max_length = { 59 | 'google/siglip-so400m-patch14-384': 64, 60 | 'google/siglip-base-patch16-224': 64 61 | } 62 | model.eval() 63 | zeroshot_weights = [] 64 | with torch.no_grad(): 65 | for classname in tqdm.tqdm(classnames): 66 | texts = [template.format(c=classname) if use_format else template(classname) for template in templates] 67 | inputs = tokenizer(texts, truncation=False, padding="max_length", max_length=max_length[model_name], return_tensors="pt") 68 | inputs = {k: v.to(device) for k, v in inputs.items()} 69 | outputs = model(**inputs) 70 | class_embedding = F.normalize(outputs.pooler_output, dim=-1).mean(dim=0) 71 | class_embedding /= class_embedding.norm() 72 | zeroshot_weights.append(class_embedding.cpu()) 73 | zeroshot_weights = torch.stack(zeroshot_weights, dim=0) 74 | return zeroshot_weights 75 | 76 | @torch.no_grad() 77 | def get_text_features(model, model_name, tokenizer, lines, 78 | device, batch_size): 79 | max_length = { 80 | 'google/siglip-so400m-patch14-384': 64, 81 | 'google/siglip-base-patch16-224': 64 82 | } 83 | model.eval() 84 | zeroshot_weights = [] 85 | for i in tqdm.trange(0, len(lines), batch_size): 86 | texts = [l.replace('\n', '') for l in lines[i:i+batch_size]] 87 | inputs = tokenizer(texts, truncation=False, padding="max_length", max_length=max_length[model_name], return_tensors="pt") 88 | inputs = {k: v.to(device) for k, v in inputs.items()} 89 | outputs = model(**inputs) 90 | class_embedding = F.normalize(outputs.pooler_output, dim=-1) 91 | zeroshot_weights.append(class_embedding.detach().cpu()) 92 | zeroshot_weights = torch.concatenate(zeroshot_weights, dim=0) 93 | return zeroshot_weights 94 | 95 | 96 | # Minimal PRS hook for out.post and mlp_output 97 | class PRSHook: 98 | def __init__(self, collapse_spatial: bool = False): 99 | self.attention_records = [] 100 | self.mlp_records = [] 101 | self.collapse_spatial = collapse_spatial 102 | 103 | def save_attention(self, ret, **kwargs): 104 | if self.collapse_spatial: 105 | to_return = ret.sum(axis=2).detach().cpu() 106 | self.attention_records.append(to_return) 107 | else: 108 | self.attention_records.append(ret.detach().cpu()) 109 | return ret 110 | 111 | def save_mlp(self, ret, **kwargs): 112 | self.mlp_records.append(ret.detach().cpu()) 113 | return ret 114 | 115 | def finalize(self): 116 | self.attention_records = torch.cat(self.attention_records, dim=0) 117 | self.mlp_records = torch.cat(self.mlp_records, dim=0) 118 | return {"attention_records": self.attention_records, "mlp_records": self.mlp_records} 119 | 120 | def compute_accuracy(features, labels, zeroshot_weights): 121 | zeroshot_weights = zeroshot_weights.to(features.device) # (1000, D) 122 | logits = features @ zeroshot_weights.t() # (N, 1000) 123 | preds = logits.argmax(dim=1) 124 | correct = (preds.cpu() == labels).sum().item() 125 | total = labels.size(0) 126 | acc = correct / total * 100 127 | return acc, correct, total 128 | 129 | @torch.no_grad() 130 | def main(args): 131 | """Calculates the projected residual stream for a dataset and zeroshot weights.""" 132 | model = SiglipVisionModel.from_pretrained(args.model) 133 | model.to(args.device) 134 | model.eval() 135 | processor = SiglipProcessor.from_pretrained(args.model, use_fast=True) 136 | print( 137 | "Model parameters:", 138 | f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}", 
139 | ) 140 | 141 | # Data: 142 | transform = lambda x: processor(images=x, return_tensors="pt") 143 | ds = ImageNet(root=args.data_path, split="val", transform=transform) 144 | dataloader = DataLoader( 145 | ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers 146 | ) 147 | # Zeroshot weights for ImageNet 148 | print("Computing zeroshot weights for ImageNet...") 149 | text_model = SiglipTextModel.from_pretrained(args.model) 150 | text_model.to(args.device) 151 | tokenizer = AutoTokenizer.from_pretrained(args.model) 152 | zeroshot_weights_path = os.path.join(args.output_dir, f"imagenet_zeroshot_weights_{args.model.replace('/', '_')}.npy") 153 | if os.path.exists(zeroshot_weights_path): 154 | print(f"Loading zeroshot weights from {zeroshot_weights_path}") 155 | zeroshot_weights = torch.from_numpy(np.load(zeroshot_weights_path)) 156 | else: 157 | zeroshot_weights = compute_zeroshot_weights(text_model, args.model, tokenizer, imagenet_classes, args.device, templates=OPENAI_IMAGENET_TEMPLATES) 158 | np.save(zeroshot_weights_path, zeroshot_weights.numpy()) 159 | 160 | prs_hook = PRSHook(collapse_spatial=True) 161 | 162 | # # Register the hook on the model's hook manager 163 | model.hook.register('pooling_head.attention.out.post', prs_hook.save_attention) 164 | model.hook.register('pooling_head.mlp_output', prs_hook.save_mlp) 165 | 166 | # Compute representation accuracy 167 | representation_results = [] 168 | for i, (inputs, labels) in enumerate(tqdm.tqdm(dataloader)): 169 | inputs['pixel_values'] = inputs['pixel_values'].squeeze(1).to(args.device) 170 | outputs = model(**inputs) 171 | representation_results.append(outputs.pooler_output) 172 | # Save labels for accuracy calculation 173 | if i == 0: 174 | all_labels = labels.clone() 175 | else: 176 | all_labels = torch.cat([all_labels, labels], dim=0) 177 | representation_results = torch.cat(representation_results, dim=0) # (N, D) 178 | constant_bias = model.vision_model.head.attention.out_proj.bias.detach().cpu() 179 | # To be more percise, there is also a bias in the mlp output 180 | mlp_bias = model.vision_model.head.mlp.fc2.bias.detach().cpu() 181 | # Compute representation accuracy 182 | print("Computing accuracy (representation)...") 183 | acc, correct, total = compute_accuracy(representation_results, all_labels, zeroshot_weights) 184 | print(f"Top-1 representation accuracy: {acc:.2f}% ({correct}/{total})") 185 | 186 | attn_and_mlp_results = prs_hook.finalize() 187 | attn_results = attn_and_mlp_results["attention_records"] 188 | mlp_results = attn_and_mlp_results["mlp_records"] 189 | # Compute mlp accuracy 190 | print("Computing accuracy (mlp)...") 191 | acc, correct, total = compute_accuracy(mlp_results[:, 0] + constant_bias, all_labels, zeroshot_weights) 192 | print(f"Top-1 mlp accuracy: {acc:.2f}% ({correct}/{total})") 193 | 194 | # Compute attention accuracy 195 | print("Computing accuracy (attention)...") 196 | acc, correct, total = compute_accuracy(attn_results.sum(axis=2)[:, 0] + constant_bias + mlp_bias, all_labels, zeroshot_weights) 197 | print(f"Top-1 attention accuracy: {acc:.2f}% ({correct}/{total})") 198 | 199 | # Compute attention + mlp accuracy 200 | print("Computing accuracy (attention + mlp for sanity check)...") 201 | acc, correct, total = compute_accuracy(mlp_results[:, 0] + attn_results.sum(axis=2)[:, 0] + constant_bias, all_labels, zeroshot_weights) 202 | print(f"Top-1 attention + mlp accuracy: {acc:.2f}% ({correct}/{total})") 203 | 204 | # Optionally, save to disk: 205 | if args.save_everything: 
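        # Added note: this persists the dict returned by PRSHook.finalize()
        # ({'attention_records', 'mlp_records'}) so the per-head decomposition
        # can be re-analyzed later without re-running the vision model.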
206 | torch.save(attn_and_mlp_results, os.path.join(args.output_dir, f"siglip_{args.model.replace('/', '_')}_prs.pt")) 207 | 208 | # Compute text features 209 | with open(args.text_descriptions, 'r') as f: 210 | lines = f.readlines() 211 | base, name = os.path.split(args.text_descriptions) 212 | name = name.replace('.txt', '') 213 | text_features_path = os.path.join(args.output_dir, f'{name}_{args.model.replace('/', '_')}.npy') 214 | if os.path.exists(text_features_path): 215 | print(f"Loading text features from {text_features_path}") 216 | text_features = np.load(text_features_path) 217 | else: 218 | text_features = get_text_features(text_model, args.model, tokenizer, lines, args.device, args.batch_size).detach().cpu().numpy() 219 | with open(text_features_path, 'wb') as f: 220 | np.save(f, text_features) 221 | print(f"Saved text features to {text_features_path}") 222 | print(f"Text features shape: {text_features.shape}") 223 | non_spatial_results = attn_results[:, 0] # (N, h, D) 224 | non_spatial_results_centered = non_spatial_results - non_spatial_results.mean(dim=0, keepdim=True) # (1, h, D) 225 | # Check how orthogonal the heads are 226 | print("Checking how orthogonal the heads are...") 227 | orthogonalities = torch.zeros((non_spatial_results.shape[1], non_spatial_results.shape[1])) 228 | for batch_idx in range(0, non_spatial_results.shape[0], args.batch_size): 229 | examples = non_spatial_results_centered[batch_idx:batch_idx+args.batch_size] 230 | orthogonalities += torch.abs(torch.einsum('nhd,ngd->nhg', 231 | F.normalize(examples, dim=-1), 232 | F.normalize(examples, dim=-1))).sum(dim=0).detach().cpu() 233 | orthogonalities = orthogonalities.detach().cpu().numpy() / (non_spatial_results.shape[0]) 234 | with open(os.path.join(args.output_dir, f'{name}_{args.model.replace('/', '_')}_orthogonalities.npy'), 'wb') as f: 235 | np.save(f, orthogonalities) 236 | plt.figure(figsize=(10, 10)) 237 | plt.imshow(orthogonalities - np.eye(orthogonalities.shape[0])) 238 | plt.colorbar() 239 | plt.savefig(os.path.join(args.output_dir, f'{name}_{args.model.replace('/', '_')}_orthogonalities.pdf')) 240 | plt.close() 241 | print(f"Saved orthogonalities to {os.path.join(args.output_dir, f'{name}_{args.model.replace('/', '_')}_orthogonalities.npy')}") 242 | 243 | if args.compute_text_spans: 244 | print(f"Non-spatial results shape: {non_spatial_results.shape}") 245 | print(f"Text features shape: {text_features.shape}") 246 | for head in range(non_spatial_results.shape[1]): 247 | reconstruct, results = replace_with_iterative_removal( 248 | non_spatial_results[:, head].detach().cpu().numpy(), 249 | text_features, 250 | lines, 251 | non_spatial_results.shape[-1], 252 | non_spatial_results.shape[-1], 253 | args.device) 254 | print('--------------------------------') 255 | print(f"Head {head}") 256 | for text in results: 257 | print(text.replace('\n', '')) 258 | print("--------------------------------") 259 | 260 | 261 | if __name__ == "__main__": 262 | args = get_args_parser() 263 | args = args.parse_args() 264 | if args.output_dir: 265 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 266 | main(args) 267 | -------------------------------------------------------------------------------- /compute_text_projection.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | import glob 6 | import sys 7 | import os.path 8 | import argparse 9 | import datetime 10 | import json 11 | from pathlib 
import Path 12 | from torch import nn 13 | from torch.nn import functional as F 14 | from torch.utils.data import DataLoader 15 | import tqdm 16 | from utils.factory import create_model_and_transforms, get_tokenizer 17 | from utils.openai_templates import OPENAI_IMAGENET_TEMPLATES 18 | from utils.imagenet_classes import imagenet_classes 19 | from utils.cub_classes import cub_classes, waterbird_classes 20 | 21 | 22 | def get_args_parser(): 23 | parser = argparse.ArgumentParser('Get classifier weights', add_help=False) 24 | # Model parameters 25 | parser.add_argument('--model', default='ViT-H-14', type=str, metavar='MODEL', 26 | help='Name of model to use') 27 | parser.add_argument('--dataset', default='imagenet', help='waterbirds or imagenet') 28 | parser.add_argument('--pretrained', default='laion2b_s32b_b79k', type=str) 29 | # Dataset parameters 30 | parser.add_argument('--output_dir', default='./output_dir', 31 | help='path where to save') 32 | parser.add_argument('--device', default='cuda:0', 33 | help='device to use for testing') 34 | return parser 35 | 36 | 37 | 38 | def zero_shot_classifier(model, tokenizer, classnames, templates, 39 | device, amp=True, use_format=False): 40 | """ 41 | This function returns zero-shot vectors for each class in order 42 | to use it for zero-shot classification. 43 | 44 | 45 | model: 46 | CLIP-like model with `encode_text` 47 | 48 | tokenizer: 49 | text tokenizer, i.e. convert list of strings to torch.Tensor of integers 50 | 51 | classnames: list of str 52 | name of classes 53 | 54 | templates: list of str 55 | templates to use. 56 | 57 | Returns 58 | ------- 59 | 60 | torch.Tensor of shape (N,C) where N is the number 61 | of templates, and C is the number of classes. 62 | """ 63 | autocast = torch.cuda.amp.autocast 64 | with torch.no_grad(), autocast(): 65 | zeroshot_weights = [] 66 | for classname in tqdm.tqdm(classnames): 67 | texts = [template.format(c=classname) if use_format else template(classname) for template in templates] 68 | texts = tokenizer(texts).to(device) # tokenize 69 | class_embeddings = model.encode_text(texts) 70 | class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) 71 | class_embedding /= class_embedding.norm() 72 | zeroshot_weights.append(class_embedding) 73 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) 74 | return zeroshot_weights 75 | 76 | 77 | def main(args): 78 | """Calculates the classifier projection weights.""" 79 | model, _, preprocess = create_model_and_transforms(args.model, pretrained=args.pretrained) 80 | tokenizer = get_tokenizer(args.model) 81 | model.to(args.device) 82 | model.eval() 83 | context_length = model.context_length 84 | vocab_size = model.vocab_size 85 | 86 | print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}") 87 | print("Context length:", context_length) 88 | print("Vocab size:", vocab_size) 89 | classes = { 90 | 'imagenet': imagenet_classes, 91 | 'waterbirds': cub_classes, 92 | 'binary_waterbirds': waterbird_classes, 93 | 'cub': cub_classes}[args.dataset] 94 | classifier = zero_shot_classifier(model, tokenizer, classes, OPENAI_IMAGENET_TEMPLATES, args.device) 95 | with open(os.path.join(args.output_dir, f'{args.dataset}_classifier_{args.model}.npy'), 'wb') as f: 96 | np.save(f, classifier.detach().cpu().numpy()) 97 | 98 | 99 | if __name__ == '__main__': 100 | args = get_args_parser() 101 | args = args.parse_args() 102 | if args.output_dir: 103 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 104 | main(args) 
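# Added illustrative sketch (not part of the original pipeline): the classifier
# saved above is an [embedding_dim, num_classes] matrix of normalized text
# embeddings, and downstream scripts (e.g. compute_ablations.py) score image
# representations with a plain matrix product. A minimal, hypothetical usage
# example follows; `image_features` ([N, embedding_dim]) and `labels` ([N]) are
# assumed inputs, and the default path uses this script's default output naming.
def _zero_shot_accuracy_example(image_features, labels,
                                classifier_path='./output_dir/imagenet_classifier_ViT-H-14.npy'):
    from utils.misc import accuracy  # same helper used by compute_ablations.py
    classifier = np.load(classifier_path)  # [embedding_dim, num_classes]
    logits = torch.from_numpy(image_features @ classifier).float()
    return accuracy(logits, torch.from_numpy(labels))[0] * 100  # top-1 accuracy (%)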
-------------------------------------------------------------------------------- /compute_text_set_projection.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | import glob 6 | import sys 7 | import os.path 8 | import argparse 9 | import datetime 10 | import json 11 | from pathlib import Path 12 | from torch import nn 13 | from torch.nn import functional as F 14 | from torch.utils.data import DataLoader 15 | import tqdm 16 | from utils.factory import create_model_and_transforms, get_tokenizer 17 | 18 | 19 | def get_args_parser(): 20 | parser = argparse.ArgumentParser('Get text list weights', add_help=False) 21 | # Model parameters 22 | parser.add_argument('--batch_size', default=2048, type=int, 23 | help='Batch size') 24 | parser.add_argument('--model', default='ViT-H-14', type=str, metavar='MODEL', 25 | help='Name of model to use') 26 | parser.add_argument('--pretrained', default='laion2b_s32b_b79k', type=str) 27 | # Dataset parameters 28 | parser.add_argument('--data_path', default='text_descriptions/image_descriptions_general.txt', 29 | type=str, help='dataset path') 30 | parser.add_argument('--num_workers', default=10, type=int) 31 | parser.add_argument('--output_dir', default='./output_dir', 32 | help='path where to save') 33 | parser.add_argument('--device', default='cuda:0', 34 | help='device to use for testing') 35 | return parser 36 | 37 | 38 | 39 | def get_text_features(model, tokenizer, lines, 40 | device, batch_size, amp=True, use_format=False): 41 | """ 42 | This function returns zero-shot vectors for each class in order 43 | to use it for zero-shot classification. 44 | 45 | 46 | model: 47 | CLIP-like model with `encode_text` 48 | 49 | tokenizer: 50 | text tokenizer, i.e. convert list of strings to torch.Tensor of integers 51 | 52 | lines: list of str 53 | name of classes 54 | 55 | Returns 56 | ------- 57 | 58 | torch.Tensor of shape (N,C) where N is the number 59 | of templates, and C is the number of classes. 
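    Note (added clarification): in this script each line of the input text file
    is encoded and normalized independently, with no per-class template
    averaging, so the returned tensor has shape (N, D), where N is the number of
    input lines and D is the embedding dimension.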
60 | """ 61 | autocast = torch.cuda.amp.autocast 62 | with torch.no_grad(), autocast(): 63 | zeroshot_weights = [] 64 | for i in tqdm.trange(0, len(lines), batch_size): 65 | texts = lines[i:i+batch_size] 66 | texts = tokenizer(texts).to(device) # tokenize 67 | class_embeddings = model.encode_text(texts) 68 | class_embeddings = F.normalize(class_embeddings, dim=-1) 69 | zeroshot_weights.append(class_embeddings.detach().cpu()) 70 | zeroshot_weights = torch.concatenate(zeroshot_weights, dim=0) 71 | return zeroshot_weights 72 | 73 | 74 | def main(args): 75 | """Calculates the classifier projection weights.""" 76 | model, _, preprocess = create_model_and_transforms(args.model, pretrained=args.pretrained) 77 | tokenizer = get_tokenizer(args.model) 78 | model.to(args.device) 79 | model.eval() 80 | context_length = model.context_length 81 | vocab_size = model.vocab_size 82 | 83 | print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}") 84 | print("Context length:", context_length) 85 | print("Vocab size:", vocab_size) 86 | with open(args.data_path, 'r') as f: 87 | lines = f.readlines() 88 | base, name = os.path.split(args.data_path) 89 | name = name.replace('.txt', '') 90 | features = get_text_features(model, tokenizer, lines, args.device, args.batch_size) 91 | with open(os.path.join(args.output_dir, f'{name}_{args.model}.npy'), 'wb') as f: 92 | np.save(f, features.numpy()) 93 | 94 | 95 | if __name__ == '__main__': 96 | args = get_args_parser() 97 | args = args.parse_args() 98 | if args.output_dir: 99 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 100 | main(args) -------------------------------------------------------------------------------- /compute_use_specific_heads.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os.path 4 | import argparse 5 | import einops 6 | from pathlib import Path 7 | import random 8 | import tqdm 9 | from utils.misc import accuracy 10 | 11 | 12 | def full_accuracy(preds, labels, locs_attributes): 13 | locs_labels = labels.detach().cpu().numpy() 14 | accs = {} 15 | for i in [0, 1]: 16 | for j in [0, 1]: 17 | locs = np.logical_and(locs_labels == i, locs_attributes == j) 18 | accs[f"({i}, {j})"] = accuracy(preds[locs], labels[locs])[0] * 100 19 | accs[f"full"] = accuracy(preds, labels)[0] * 100 20 | return accs 21 | 22 | 23 | def get_args_parser(): 24 | parser = argparse.ArgumentParser("Ablations part", add_help=False) 25 | 26 | # Model parameters 27 | parser.add_argument( 28 | "--model", 29 | default="ViT-H-14", 30 | type=str, 31 | metavar="MODEL", 32 | help="Name of model to use", 33 | ) 34 | # Dataset parameters 35 | parser.add_argument("--num_workers", default=10, type=int) 36 | parser.add_argument( 37 | "--figures_dir", default="./output_dir", help="path where data is saved" 38 | ) 39 | parser.add_argument( 40 | "--input_dir", default="./output_dir", help="path where data is saved" 41 | ) 42 | parser.add_argument( 43 | "--dataset", 44 | type=str, 45 | default="binary_waterbirds", 46 | help="imagenet, waterbirds, waterbirds_binary or cub", 47 | ) 48 | return parser 49 | 50 | 51 | def main(args): 52 | if args.model == "ViT-H-14": 53 | to_mean_ablate_setting = [(31, 12), (30, 11), (29, 4)] 54 | to_mean_ablate_geo = [(31, 8), (30, 15), (30, 12), (30, 6), (29, 14), (29, 8)] 55 | elif args.model == "ViT-L-14": 56 | to_mean_ablate_geo = [(21, 1), (22, 12), (22, 13), (21, 11), (21, 14), (23, 6)] 57 | to_mean_ablate_setting = [ 58 | (21, 
3), 59 | (21, 6), 60 | (21, 8), 61 | (21, 13), 62 | (22, 2), 63 | (22, 12), 64 | (22, 15), 65 | (23, 1), 66 | (23, 3), 67 | (23, 5), 68 | ] 69 | elif args.model == "ViT-B-16": 70 | to_mean_ablate_setting = [(11, 3), (10, 11), (10, 10), (9, 8), (9, 6)] 71 | to_mean_ablate_geo = [(11, 6), (11, 0)] 72 | else: 73 | raise ValueError('model not analyzed') 74 | to_mean_ablate_output = to_mean_ablate_geo + to_mean_ablate_setting 75 | with open( 76 | os.path.join(args.input_dir, f"{args.dataset}_attn_{args.model}.npy"), "rb" 77 | ) as f: 78 | attns = np.load(f) # [b, l, h, d] 79 | with open( 80 | os.path.join(args.input_dir, f"{args.dataset}_mlp_{args.model}.npy"), "rb" 81 | ) as f: 82 | mlps = np.load(f) # [b, l+1, d] 83 | with open( 84 | os.path.join(args.input_dir, f"{args.dataset}_classifier_{args.model}.npy"), 85 | "rb", 86 | ) as f: 87 | classifier = np.load(f) 88 | 89 | if args.dataset == "imagenet": 90 | labels = np.array([i // 50 for i in range(attns.shape[0])]) 91 | else: 92 | with open( 93 | os.path.join(args.input_dir, f"{args.dataset}_labels.npy"), "rb" 94 | ) as f: 95 | labels = np.load(f) 96 | labels = labels[:, :, 0] 97 | baseline = attns.sum(axis=(1, 2)) + mlps.sum(axis=1) 98 | baseline_acc = full_accuracy( 99 | torch.from_numpy(baseline @ classifier).float(), 100 | torch.from_numpy(labels[:, 0]), 101 | labels[:, 1], 102 | ) 103 | print("Baseline:", baseline_acc) 104 | for layer, head in to_mean_ablate_output: 105 | attns[:, layer, head, :] = np.mean( 106 | attns[:, layer, head, :], axis=0, keepdims=True 107 | ) 108 | for layer in range(attns.shape[1] - 4): 109 | for head in range(attns.shape[2]): 110 | attns[:, layer, head, :] = np.mean( 111 | attns[:, layer, head, :], axis=0, keepdims=True 112 | ) 113 | for layer in range(mlps.shape[1]): 114 | mlps[:, layer] = np.mean(mlps[:, layer], axis=0, keepdims=True) 115 | ablated = attns.sum(axis=(1, 2)) + mlps.sum(axis=1) 116 | ablated_acc = full_accuracy( 117 | torch.from_numpy(ablated @ classifier).float(), 118 | torch.from_numpy(labels[:, 0]), 119 | labels[:, 1], 120 | ) 121 | print("Replaced:", ablated_acc) 122 | 123 | 124 | if __name__ == "__main__": 125 | args = get_args_parser() 126 | args = args.parse_args() 127 | if args.figures_dir: 128 | Path(args.figures_dir).mkdir(parents=True, exist_ok=True) 129 | main(args) 130 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: prsclip 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python >= 3.8 7 | - pytorch >= 1.13 8 | - torchvision 9 | - pytorch-cuda=11.7 10 | - pip: 11 | - timm 12 | - einops 13 | - ftfy 14 | - scipy 15 | - imageio 16 | - h5py 17 | - scikit-image 18 | - scikit-learn 19 | - opencv-python -------------------------------------------------------------------------------- /images/catdog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yossigandelsman/clip_text_span/49f70f31bb13437a870ff8de340626b225500a22/images/catdog.png -------------------------------------------------------------------------------- /images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yossigandelsman/clip_text_span/49f70f31bb13437a870ff8de340626b225500a22/images/teaser.png -------------------------------------------------------------------------------- /output_dir/binary_waterbirds_labels.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yossigandelsman/clip_text_span/49f70f31bb13437a870ff8de340626b225500a22/output_dir/binary_waterbirds_labels.npy -------------------------------------------------------------------------------- /prs_hook.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | import glob 6 | import sys 7 | import argparse 8 | import datetime 9 | import json 10 | from pathlib import Path 11 | 12 | 13 | class PRSLogger(object): 14 | def __init__(self, model, device, spatial: bool = True): 15 | self.current_layer = 0 16 | self.device = device 17 | self.attentions = [] 18 | self.mlps = [] 19 | self.spatial = spatial 20 | self.post_ln_std = None 21 | self.post_ln_mean = None 22 | self.model = model 23 | 24 | @torch.no_grad() 25 | def compute_attentions_spatial(self, ret): 26 | assert len(ret.shape) == 5, "Verify that you use method=`head` and not method=`head_no_spatial`" # [b, n, m, h, d] 27 | assert self.spatial, "Verify that you use method=`head` and not method=`head_no_spatial`" 28 | bias_term = self.model.visual.transformer.resblocks[ 29 | self.current_layer 30 | ].attn.out_proj.bias 31 | self.current_layer += 1 32 | return_value = ret[:, 0].detach().cpu() # This is only for the cls token 33 | self.attentions.append( 34 | return_value 35 | + bias_term[np.newaxis, np.newaxis, np.newaxis].cpu() 36 | / (return_value.shape[1] * return_value.shape[2]) 37 | ) # [b, n, h, d] 38 | return ret 39 | 40 | @torch.no_grad() 41 | def compute_attentions_non_spatial(self, ret): 42 | assert len(ret.shape) == 4, "Verify that you use method=`head_no_spatial` and not method=`head`" # [b, n, h, d] 43 | assert not self.spatial, "Verify that you use method=`head_no_spatial` and not method=`head`" 44 | bias_term = self.model.visual.transformer.resblocks[ 45 | self.current_layer 46 | ].attn.out_proj.bias 47 | self.current_layer += 1 48 | return_value = ret[:, 0].detach().cpu() # This is only for the cls token 49 | self.attentions.append( 50 | return_value 51 | + bias_term[np.newaxis, np.newaxis].cpu() 52 | / (return_value.shape[1]) 53 | ) # [b, h, d] 54 | return ret 55 | 56 | @torch.no_grad() 57 | def compute_mlps(self, ret): 58 | self.mlps.append(ret[:, 0].detach().cpu()) # [b, d] 59 | return ret 60 | 61 | @torch.no_grad() 62 | def log_post_ln_mean(self, ret): 63 | self.post_ln_mean = ret.detach().cpu() # [b, 1] 64 | return ret 65 | 66 | @torch.no_grad() 67 | def log_post_ln_std(self, ret): 68 | self.post_ln_std = ret.detach().cpu() # [b, 1] 69 | return ret 70 | 71 | def _normalize_mlps(self): 72 | len_intermediates = self.attentions.shape[1] + self.mlps.shape[1] 73 | # This is just the normalization layer: 74 | mean_centered = ( 75 | self.mlps 76 | - self.post_ln_mean[:, :, np.newaxis].to(self.device) / len_intermediates 77 | ) 78 | weighted_mean_centered = ( 79 | self.model.visual.ln_post.weight.detach().to(self.device) * mean_centered 80 | ) 81 | weighted_mean_by_std = weighted_mean_centered / self.post_ln_std[ 82 | :, :, np.newaxis 83 | ].to(self.device) 84 | bias_term = ( 85 | self.model.visual.ln_post.bias.detach().to(self.device) / len_intermediates 86 | ) 87 | post_ln = weighted_mean_by_std + bias_term 88 | return post_ln @ self.model.visual.proj.detach().to(self.device) 89 | 90 | def _normalize_attentions_spatial(self): 91 | len_intermediates = self.attentions.shape[1] + self.mlps.shape[1] # 2*l + 
1 92 | normalization_term = ( 93 | self.attentions.shape[2] * self.attentions.shape[3] 94 | ) # n * h 95 | # This is just the normalization layer: 96 | mean_centered = self.attentions - self.post_ln_mean[ 97 | :, :, np.newaxis, np.newaxis, np.newaxis 98 | ].to(self.device) / (len_intermediates * normalization_term) 99 | weighted_mean_centered = ( 100 | self.model.visual.ln_post.weight.detach().to(self.device) * mean_centered 101 | ) 102 | weighted_mean_by_std = weighted_mean_centered / self.post_ln_std[ 103 | :, :, np.newaxis, np.newaxis, np.newaxis 104 | ].to(self.device) 105 | bias_term = self.model.visual.ln_post.bias.detach().to(self.device) / ( 106 | len_intermediates * normalization_term 107 | ) 108 | post_ln = weighted_mean_by_std + bias_term 109 | return post_ln @ self.model.visual.proj.detach().to(self.device) 110 | 111 | def _normalize_attentions_non_spatial(self): 112 | len_intermediates = self.attentions.shape[1] + self.mlps.shape[1] # 2*l + 1 113 | normalization_term = ( 114 | self.attentions.shape[2] 115 | ) # h 116 | # This is just the normalization layer: 117 | mean_centered = self.attentions - self.post_ln_mean[ 118 | :, :, np.newaxis, np.newaxis 119 | ].to(self.device) / (len_intermediates * normalization_term) 120 | weighted_mean_centered = ( 121 | self.model.visual.ln_post.weight.detach().to(self.device) * mean_centered 122 | ) 123 | weighted_mean_by_std = weighted_mean_centered / self.post_ln_std[ 124 | :, :, np.newaxis, np.newaxis 125 | ].to(self.device) 126 | bias_term = self.model.visual.ln_post.bias.detach().to(self.device) / ( 127 | len_intermediates * normalization_term 128 | ) 129 | post_ln = weighted_mean_by_std + bias_term 130 | return post_ln @ self.model.visual.proj.detach().to(self.device) 131 | 132 | @torch.no_grad() 133 | def finalize(self, representation): 134 | """We calculate the post-ln scaling, project it and normalize by the last norm.""" 135 | self.attentions = torch.stack(self.attentions, axis=1).to( 136 | self.device 137 | ) # [b, l, n, h, d] 138 | self.mlps = torch.stack(self.mlps, axis=1).to(self.device) # [b, l + 1, d] 139 | if self.spatial: 140 | projected_attentions = self._normalize_attentions_spatial() 141 | else: 142 | projected_attentions = self._normalize_attentions_non_spatial() 143 | projected_mlps = self._normalize_mlps() 144 | norm = representation.norm(dim=-1).detach() 145 | if self.spatial: 146 | return ( 147 | projected_attentions 148 | / norm[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis], 149 | projected_mlps / norm[:, np.newaxis, np.newaxis], 150 | ) 151 | return ( 152 | projected_attentions 153 | / norm[:, np.newaxis, np.newaxis, np.newaxis], 154 | projected_mlps / norm[:, np.newaxis, np.newaxis], 155 | ) 156 | 157 | def reinit(self): 158 | self.current_layer = 0 159 | self.attentions = [] 160 | self.mlps = [] 161 | self.post_ln_mean = None 162 | self.post_ln_std = None 163 | torch.cuda.empty_cache() 164 | 165 | 166 | def hook_prs_logger(model, device, spatial: bool = True): 167 | """Hooks a projected residual stream logger to the model.""" 168 | prs = PRSLogger(model, device, spatial=spatial) 169 | if spatial: 170 | model.hook_manager.register( 171 | "visual.transformer.resblocks.*.attn.out.post", prs.compute_attentions_spatial 172 | ) 173 | else: 174 | model.hook_manager.register( 175 | "visual.transformer.resblocks.*.attn.out.post", prs.compute_attentions_non_spatial 176 | ) 177 | model.hook_manager.register( 178 | "visual.transformer.resblocks.*.mlp.c_proj.post", prs.compute_mlps 179 | ) 180 | 
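    # Registering "visual.ln_pre_post" with compute_mlps as well means that
    # PRSLogger.mlps collects l + 1 entries per image: the post-ln_pre CLS
    # embedding followed by one entry per MLP block, which is the [b, l + 1, d]
    # shape assumed in PRSLogger.finalize().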
model.hook_manager.register("visual.ln_pre_post", prs.compute_mlps) 181 | model.hook_manager.register("visual.ln_post.mean", prs.log_post_ln_mean) 182 | model.hook_manager.register("visual.ln_post.sqrt_var", prs.log_post_ln_std) 183 | return prs 184 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 2 | from utils.factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss 3 | from utils.factory import list_models, add_model_config, get_model_config, load_checkpoint 4 | from utils.pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ 5 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 6 | from utils.tokenizer import SimpleTokenizer, tokenize, decode 7 | from utils.transform import image_transform, AugmentationCfg 8 | from utils.openai_templates import OPENAI_IMAGENET_TEMPLATES -------------------------------------------------------------------------------- /utils/binary_waterbirds.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | from typing import Any, Callable, cast, Dict, List, Optional, Tuple 4 | from typing import Union 5 | 6 | from PIL import Image 7 | import pandas as pd 8 | from torchvision.datasets import VisionDataset 9 | import torch 10 | 11 | 12 | def pil_loader(path: str) -> Image.Image: 13 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 14 | with open(path, "rb") as f: 15 | img = Image.open(f) 16 | return img.convert("RGB") 17 | 18 | class BinaryWaterbirds(VisionDataset): 19 | def __init__( 20 | self, 21 | root: str, 22 | split: str, 23 | loader: Callable[[str], Any] = pil_loader, 24 | transform: Optional[Callable] = None, 25 | target_transform: Optional[Callable] = None, 26 | ) -> None: 27 | super().__init__(root, transform=transform, target_transform=target_transform) 28 | 29 | self.loader = loader 30 | csv = pd.read_csv(os.path.join(root, 'metadata.csv')) 31 | split = {'test': 2, 'valid': 1, 'train': 0}[split] 32 | csv = csv[csv['split'] == split] 33 | self.samples = [(os.path.join(root, csv.iloc[i]['img_filename']), csv.iloc[i]['y']) for i in range(len(csv))] 34 | 35 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 36 | """ 37 | Args: 38 | index (int): Index 39 | Returns: 40 | tuple: (sample, target) where target is class_index of the target class. 
41 | """ 42 | path, target = self.samples[index] 43 | sample = self.loader(path) 44 | if self.transform is not None: 45 | sample = self.transform(sample) 46 | if self.target_transform is not None: 47 | target = self.target_transform(target) 48 | 49 | return sample, target 50 | 51 | def __len__(self) -> int: 52 | return len(self.samples) 53 | -------------------------------------------------------------------------------- /utils/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) -------------------------------------------------------------------------------- /utils/cub_classes.py: -------------------------------------------------------------------------------- 1 | cub_classes = ['Black footed Albatross', 'Laysan Albatross', 'Sooty Albatross', 'Groove billed Ani', 'Crested Auklet', 'Least Auklet', 'Parakeet Auklet', 'Rhinoceros Auklet', 'Brewer Blackbird', 'Red winged Blackbird', 'Rusty Blackbird', 'Yellow headed Blackbird', 'Bobolink', 'Indigo Bunting', 'Lazuli Bunting', 'Painted Bunting', 'Cardinal', 'Spotted Catbird', 'Gray Catbird', 'Yellow breasted Chat', 'Eastern Towhee', 'Chuck will Widow', 'Brandt Cormorant', 'Red faced Cormorant', 'Pelagic Cormorant', 'Bronzed Cowbird', 'Shiny Cowbird', 'Brown Creeper', 'American Crow', 'Fish Crow', 'Black billed Cuckoo', 'Mangrove Cuckoo', 'Yellow billed Cuckoo', 'Gray crowned Rosy Finch', 'Purple Finch', 'Northern Flicker', 'Acadian Flycatcher', 'Great Crested Flycatcher', 'Least Flycatcher', 'Olive sided Flycatcher', 'Scissor tailed Flycatcher', 'Vermilion Flycatcher', 'Yellow bellied Flycatcher', 'Frigatebird', 'Northern Fulmar', 'Gadwall', 'American Goldfinch', 'European Goldfinch', 'Boat tailed Grackle', 'Eared Grebe', 'Horned Grebe', 'Pied billed Grebe', 'Western Grebe', 'Blue Grosbeak', 'Evening Grosbeak', 'Pine Grosbeak', 'Rose breasted Grosbeak', 'Pigeon Guillemot', 'California Gull', 'Glaucous winged Gull', 'Heermann Gull', 'Herring Gull', 'Ivory Gull', 'Ring billed Gull', 'Slaty backed Gull', 'Western Gull', 'Anna Hummingbird', 'Ruby throated Hummingbird', 'Rufous Hummingbird', 'Green Violetear', 'Long tailed Jaeger', 'Pomarine Jaeger', 'Blue Jay', 'Florida Jay', 'Green Jay', 'Dark eyed Junco', 'Tropical Kingbird', 'Gray Kingbird', 'Belted Kingfisher', 'Green Kingfisher', 'Pied Kingfisher', 'Ringed Kingfisher', 'White breasted Kingfisher', 'Red legged Kittiwake', 'Horned Lark', 'Pacific Loon', 'Mallard', 'Western Meadowlark', 'Hooded Merganser', 'Red breasted Merganser', 'Mockingbird', 'Nighthawk', 'Clark Nutcracker', 'White breasted Nuthatch', 'Baltimore Oriole', 'Hooded Oriole', 'Orchard Oriole', 'Scott Oriole', 'Ovenbird', 'Brown Pelican', 'White Pelican', 'Western Wood Pewee', 'Sayornis', 'American Pipit', 'Whip poor Will', 'Horned Puffin', 'Common Raven', 'White necked Raven', 'American Redstart', 'Geococcyx', 'Loggerhead Shrike', 'Great Grey Shrike', 'Baird Sparrow', 'Black throated Sparrow', 'Brewer Sparrow', 'Chipping Sparrow', 'Clay colored Sparrow', 'House Sparrow', 'Field Sparrow', 'Fox Sparrow', 'Grasshopper Sparrow', 'Harris Sparrow', 'Henslow Sparrow', 'Le Conte Sparrow', 'Lincoln Sparrow', 'Nelson Sharp tailed Sparrow', 'Savannah Sparrow', 'Seaside Sparrow', 'Song Sparrow', 'Tree Sparrow', 'Vesper Sparrow', 'White crowned Sparrow', 'White throated Sparrow', 'Cape Glossy Starling', 'Bank Swallow', 'Barn Swallow', 'Cliff Swallow', 'Tree Swallow', 'Scarlet 
Tanager', 'Summer Tanager', 'Artic Tern', 'Black Tern', 'Caspian Tern', 'Common Tern', 'Elegant Tern', 'Forsters Tern', 'Least Tern', 'Green tailed Towhee', 'Brown Thrasher', 'Sage Thrasher', 'Black capped Vireo', 'Blue headed Vireo', 'Philadelphia Vireo', 'Red eyed Vireo', 'Warbling Vireo', 'White eyed Vireo', 'Yellow throated Vireo', 'Bay breasted Warbler', 'Black and white Warbler', 'Black throated Blue Warbler', 'Blue winged Warbler', 'Canada Warbler', 'Cape May Warbler', 'Cerulean Warbler', 'Chestnut sided Warbler', 'Golden winged Warbler', 'Hooded Warbler', 'Kentucky Warbler', 'Magnolia Warbler', 'Mourning Warbler', 'Myrtle Warbler', 'Nashville Warbler', 'Orange crowned Warbler', 'Palm Warbler', 'Pine Warbler', 'Prairie Warbler', 'Prothonotary Warbler', 'Swainson Warbler', 'Tennessee Warbler', 'Wilson Warbler', 'Worm eating Warbler', 'Yellow Warbler', 'Northern Waterthrush', 'Louisiana Waterthrush', 'Bohemian Waxwing', 'Cedar Waxwing', 'American Three toed Woodpecker', 'Pileated Woodpecker', 'Red bellied Woodpecker', 'Red cockaded Woodpecker', 'Red headed Woodpecker', 'Downy Woodpecker', 'Bewick Wren', 'Cactus Wren', 'Carolina Wren', 'House Wren', 'Marsh Wren', 'Rock Wren', 'Winter Wren', 'Common Yellowthroat'] 2 | waterbird_classes = ['landbird', 'waterbird'] -------------------------------------------------------------------------------- /utils/factory.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import pathlib 5 | import re 6 | from copy import deepcopy 7 | from pathlib import Path 8 | from typing import Any, Dict, Optional, Tuple, Union 9 | 10 | import torch 11 | 12 | from utils.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 13 | from utils.model import CLIP, convert_to_custom_text_state_dict,\ 14 | resize_pos_embed, get_cast_dtype 15 | from utils.openai_models import load_openai_model 16 | from utils.pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\ 17 | list_pretrained_tags_by_model, download_pretrained_from_hf 18 | from utils.transform import image_transform, AugmentationCfg 19 | from utils.tokenizer import HFTokenizer, tokenize 20 | 21 | 22 | HF_HUB_PREFIX = 'hf-hub:' 23 | _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] 24 | _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs 25 | 26 | 27 | def _natural_key(string_): 28 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] 29 | 30 | 31 | def _rescan_model_configs(): 32 | global _MODEL_CONFIGS 33 | 34 | config_ext = ('.json',) 35 | config_files = [] 36 | for config_path in _MODEL_CONFIG_PATHS: 37 | if config_path.is_file() and config_path.suffix in config_ext: 38 | config_files.append(config_path) 39 | elif config_path.is_dir(): 40 | for ext in config_ext: 41 | config_files.extend(config_path.glob(f'*{ext}')) 42 | 43 | for cf in config_files: 44 | with open(cf, 'r') as f: 45 | model_cfg = json.load(f) 46 | if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): 47 | _MODEL_CONFIGS[cf.stem] = model_cfg 48 | 49 | _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} 50 | 51 | 52 | _rescan_model_configs() # initial populate of model config registry 53 | 54 | 55 | def list_models(): 56 | """ enumerate available model architectures based on config files """ 57 | return list(_MODEL_CONFIGS.keys()) 58 | 59 | 60 | def add_model_config(path): 61 | """ add model 
config path or file and update registry """ 62 | if not isinstance(path, Path): 63 | path = Path(path) 64 | _MODEL_CONFIG_PATHS.append(path) 65 | _rescan_model_configs() 66 | 67 | 68 | def get_model_config(model_name): 69 | if model_name in _MODEL_CONFIGS: 70 | return deepcopy(_MODEL_CONFIGS[model_name]) 71 | else: 72 | return None 73 | 74 | 75 | def get_tokenizer(model_name): 76 | if model_name.startswith(HF_HUB_PREFIX): 77 | tokenizer = HFTokenizer(model_name[len(HF_HUB_PREFIX):]) 78 | else: 79 | config = get_model_config(model_name) 80 | tokenizer = HFTokenizer( 81 | config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize 82 | return tokenizer 83 | 84 | 85 | def load_state_dict(checkpoint_path: str, map_location='cpu'): 86 | checkpoint = torch.load(checkpoint_path, map_location=map_location) 87 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: 88 | state_dict = checkpoint['state_dict'] 89 | else: 90 | state_dict = checkpoint 91 | if next(iter(state_dict.items()))[0].startswith('module'): 92 | state_dict = {k[7:]: v for k, v in state_dict.items()} 93 | return state_dict 94 | 95 | 96 | def load_checkpoint(model, checkpoint_path, strict=True): 97 | state_dict = load_state_dict(checkpoint_path) 98 | # detect old format and make compatible with new format 99 | if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): 100 | state_dict = convert_to_custom_text_state_dict(state_dict) 101 | resize_pos_embed(state_dict, model) 102 | incompatible_keys = model.load_state_dict(state_dict, strict=strict) 103 | return incompatible_keys 104 | 105 | 106 | def create_model( 107 | model_name: str, 108 | pretrained: Optional[str] = None, 109 | precision: str = 'fp32', 110 | device: Union[str, torch.device] = 'cpu', 111 | jit: bool = False, 112 | force_quick_gelu: bool = False, 113 | force_custom_text: bool = False, 114 | force_patch_dropout: Optional[float] = None, 115 | force_image_size: Optional[Union[int, Tuple[int, int]]] = None, 116 | pretrained_image: bool = False, 117 | pretrained_hf: bool = True, 118 | cache_dir: Optional[str] = None, 119 | output_dict: Optional[bool] = None, 120 | require_pretrained: bool = False, 121 | ): 122 | has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) 123 | if has_hf_hub_prefix: 124 | model_id = model_name[len(HF_HUB_PREFIX):] 125 | checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) 126 | config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir) 127 | 128 | with open(config_path, 'r', encoding='utf-8') as f: 129 | config = json.load(f) 130 | pretrained_cfg = config['preprocess_cfg'] 131 | model_cfg = config['model_cfg'] 132 | else: 133 | model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names 134 | checkpoint_path = None 135 | pretrained_cfg = {} 136 | model_cfg = None 137 | 138 | if isinstance(device, str): 139 | device = torch.device(device) 140 | 141 | if pretrained and pretrained.lower() == 'openai': 142 | logging.info(f'Loading pretrained {model_name} from OpenAI.') 143 | model = load_openai_model( 144 | model_name, 145 | precision=precision, 146 | device=device, 147 | cache_dir=cache_dir, 148 | ) 149 | else: 150 | model_cfg = model_cfg or get_model_config(model_name) 151 | if model_cfg is not None: 152 | logging.info(f'Loaded {model_name} model config.') 153 | else: 154 | logging.error(f'Model config for {model_name} not found; available models 
{list_models()}.') 155 | raise RuntimeError(f'Model config for {model_name} not found.') 156 | 157 | if force_quick_gelu: 158 | # override for use of QuickGELU on non-OpenAI transformer models 159 | model_cfg["quick_gelu"] = True 160 | 161 | if force_patch_dropout is not None: 162 | # override the default patch dropout value 163 | model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout 164 | 165 | if force_image_size is not None: 166 | # override model config's image size 167 | model_cfg["vision_cfg"]["image_size"] = force_image_size 168 | 169 | is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) 170 | if pretrained_image: 171 | if is_timm_model: 172 | # pretrained weight loading for timm models set via vision_cfg 173 | model_cfg['vision_cfg']['timm_model_pretrained'] = True 174 | else: 175 | assert False, 'pretrained image towers currently only supported for timm models' 176 | 177 | # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes 178 | cast_dtype = get_cast_dtype(precision) 179 | is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) 180 | custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model 181 | 182 | if custom_text: 183 | if is_hf_model: 184 | model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf 185 | if "coca" in model_name: 186 | raise ValueError('Coca is not implemented') 187 | model = CoCa(**model_cfg, cast_dtype=cast_dtype) 188 | else: 189 | raise ValueError('CustomTextCLIP is not implemented') 190 | model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) 191 | else: 192 | model = CLIP(**model_cfg, cast_dtype=cast_dtype) 193 | 194 | if precision in ("fp16", "bf16"): 195 | dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 196 | # manual mixed precision that matches original OpenAI behaviour 197 | if is_timm_model: 198 | # FIXME this is a bit janky, create timm based model in low-precision and 199 | # then cast only LayerNormFp32 instances back to float32 so they don't break. 200 | # Why? The convert_weights_to_lp fn only works with native models. 201 | model.to(device=device, dtype=dtype) 202 | from transformer import LayerNormFp32 203 | def _convert_ln(m): 204 | if isinstance(m, LayerNormFp32): 205 | m.weight.data = m.weight.data.to(torch.float32) 206 | m.bias.data = m.bias.data.to(torch.float32) 207 | model.apply(_convert_ln) 208 | else: 209 | model.to(device=device) 210 | convert_weights_to_lp(model, dtype=dtype) 211 | elif precision in ("pure_fp16", "pure_bf16"): 212 | dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 213 | model.to(device=device, dtype=dtype) 214 | else: 215 | model.to(device=device) 216 | 217 | pretrained_loaded = False 218 | if pretrained: 219 | checkpoint_path = '' 220 | pretrained_cfg = get_pretrained_cfg(model_name, pretrained) 221 | if pretrained_cfg: 222 | checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) 223 | elif os.path.exists(pretrained): 224 | checkpoint_path = pretrained 225 | 226 | if checkpoint_path: 227 | logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') 228 | load_checkpoint(model, checkpoint_path) 229 | else: 230 | error_str = ( 231 | f'Pretrained weights ({pretrained}) not found for model {model_name}.' 
232 | f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') 233 | logging.warning(error_str) 234 | raise RuntimeError(error_str) 235 | pretrained_loaded = True 236 | elif has_hf_hub_prefix: 237 | logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') 238 | load_checkpoint(model, checkpoint_path) 239 | pretrained_loaded = True 240 | 241 | if require_pretrained and not pretrained_loaded: 242 | # callers of create_model_from_pretrained always expect pretrained weights 243 | raise RuntimeError( 244 | f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') 245 | 246 | # set image / mean metadata from pretrained_cfg if available, or use default 247 | model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN 248 | model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD 249 | 250 | if output_dict and hasattr(model, "output_dict"): 251 | model.output_dict = True 252 | 253 | if jit: 254 | model = torch.jit.script(model) 255 | 256 | return model 257 | 258 | 259 | def create_loss(args): 260 | if args.distill: 261 | return DistillClipLoss( 262 | local_loss=args.local_loss, 263 | gather_with_grad=args.gather_with_grad, 264 | cache_labels=True, 265 | rank=args.rank, 266 | world_size=args.world_size, 267 | use_horovod=args.horovod, 268 | ) 269 | elif "coca" in args.model.lower(): 270 | return CoCaLoss( 271 | caption_loss_weight=args.coca_caption_loss_weight, 272 | clip_loss_weight=args.coca_contrastive_loss_weight, 273 | local_loss=args.local_loss, 274 | gather_with_grad=args.gather_with_grad, 275 | cache_labels=True, 276 | rank=args.rank, 277 | world_size=args.world_size, 278 | use_horovod=args.horovod, 279 | ) 280 | return ClipLoss( 281 | local_loss=args.local_loss, 282 | gather_with_grad=args.gather_with_grad, 283 | cache_labels=True, 284 | rank=args.rank, 285 | world_size=args.world_size, 286 | use_horovod=args.horovod, 287 | ) 288 | 289 | 290 | def create_model_and_transforms( 291 | model_name: str, 292 | pretrained: Optional[str] = None, 293 | precision: str = 'fp32', 294 | device: Union[str, torch.device] = 'cpu', 295 | jit: bool = False, 296 | force_quick_gelu: bool = False, 297 | force_custom_text: bool = False, 298 | force_patch_dropout: Optional[float] = None, 299 | force_image_size: Optional[Union[int, Tuple[int, int]]] = None, 300 | pretrained_image: bool = False, 301 | pretrained_hf: bool = True, 302 | image_mean: Optional[Tuple[float, ...]] = None, 303 | image_std: Optional[Tuple[float, ...]] = None, 304 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, 305 | cache_dir: Optional[str] = None, 306 | output_dict: Optional[bool] = None, 307 | ): 308 | model = create_model( 309 | model_name, 310 | pretrained, 311 | precision=precision, 312 | device=device, 313 | jit=jit, 314 | force_quick_gelu=force_quick_gelu, 315 | force_custom_text=force_custom_text, 316 | force_patch_dropout=force_patch_dropout, 317 | force_image_size=force_image_size, 318 | pretrained_image=pretrained_image, 319 | pretrained_hf=pretrained_hf, 320 | cache_dir=cache_dir, 321 | output_dict=output_dict, 322 | ) 323 | 324 | image_mean = image_mean or getattr(model.visual, 'image_mean', None) 325 | image_std = image_std or getattr(model.visual, 'image_std', None) 326 | preprocess_train = image_transform( 327 | model.visual.image_size, 328 | is_train=True, 329 | mean=image_mean, 330 | std=image_std, 331 | aug_cfg=aug_cfg, 332 | ) 333 | preprocess_val = image_transform( 334 | 
model.visual.image_size, 335 | is_train=False, 336 | mean=image_mean, 337 | std=image_std, 338 | ) 339 | 340 | return model, preprocess_train, preprocess_val 341 | 342 | 343 | def create_model_from_pretrained( 344 | model_name: str, 345 | pretrained: Optional[str] = None, 346 | precision: str = 'fp32', 347 | device: Union[str, torch.device] = 'cpu', 348 | jit: bool = False, 349 | force_quick_gelu: bool = False, 350 | force_custom_text: bool = False, 351 | force_image_size: Optional[Union[int, Tuple[int, int]]] = None, 352 | return_transform: bool = True, 353 | image_mean: Optional[Tuple[float, ...]] = None, 354 | image_std: Optional[Tuple[float, ...]] = None, 355 | cache_dir: Optional[str] = None, 356 | ): 357 | model = create_model( 358 | model_name, 359 | pretrained, 360 | precision=precision, 361 | device=device, 362 | jit=jit, 363 | force_quick_gelu=force_quick_gelu, 364 | force_custom_text=force_custom_text, 365 | force_image_size=force_image_size, 366 | cache_dir=cache_dir, 367 | require_pretrained=True, 368 | ) 369 | 370 | if not return_transform: 371 | return model 372 | 373 | image_mean = image_mean or getattr(model.visual, 'image_mean', None) 374 | image_std = image_std or getattr(model.visual, 'image_std', None) 375 | preprocess = image_transform( 376 | model.visual.image_size, 377 | is_train=False, 378 | mean=image_mean, 379 | std=image_std, 380 | ) 381 | 382 | return model, preprocess -------------------------------------------------------------------------------- /utils/hook.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Text, Callable, List 2 | from collections import defaultdict 3 | 4 | 5 | class HookManager(object): 6 | def __init__(self, hook_dict: Dict[Text, List[Callable]] = None): 7 | self.hook_dict = hook_dict or defaultdict(list) 8 | self.called = defaultdict(int) 9 | self.forks = dict() 10 | 11 | def register(self, name: Text, func: Callable): 12 | assert name 13 | found_successor = False 14 | for header, d in self.forks.items(): 15 | if name.startswith(header.split('.')[0]+'.'): 16 | next_ = name[len(header.split('.')[0]+'.'):].split('.')[0] 17 | prev_ = header.split('.')[0] 18 | if next_.isnumeric() and prev_ + '.' + next_ == header: 19 | d.register(name[len(header)+1:], func) 20 | elif next_ == '*': 21 | d.register(name[len(prev_ + '.*')+1:], func) 22 | else: 23 | d.register(name[len(header)+1:], func) 24 | found_successor = True 25 | if not found_successor: 26 | self.hook_dict[name].append(func) 27 | 28 | def unregister(self, name: Text, func: Callable): 29 | assert name 30 | found_successor = False 31 | for header, d in self.forks.items(): 32 | if name.startswith(header.split('.')[0]+'.'): 33 | next_ = name[len(header.split('.')[0]+'.'):].split('.')[0] 34 | prev_ = header.split('.')[0] 35 | if next_.isnumeric() and prev_ + '.' 
+ next_ == header: 36 | d.register(name[len(header)+1:], func) 37 | elif next_ == '*': 38 | d.register(name[len(prev_ + '.*')+1:], func) 39 | else: 40 | d.register(name[len(header)+1:], func) 41 | found_successor = True 42 | if not found_successor and func in self.hook_dict[name]: 43 | self.hook_dict[name].remove(func) 44 | 45 | def __call__(self, name: Text, **kwargs): 46 | if name in self.hook_dict: 47 | self.called[name] += 1 48 | for function in self.hook_dict[name]: 49 | ret = function(**kwargs) 50 | if len(self.hook_dict[name]) > 1: 51 | last = self.hook_dict[name][-1] 52 | # print(f'The last returned value comes from func {last}') 53 | return ret 54 | else: 55 | return kwargs['ret'] 56 | 57 | def fork(self, name): 58 | if name in self.forks: 59 | raise ValueError(f'Forking with the same name is not allowed. Already forked with {name}.') 60 | filtered_hooks = [(k[len(name)+1:], v) for k, v in self.hook_dict.items() if k.startswith(name+'.')] 61 | filtered_hooks_d = defaultdict(list) 62 | for i, j in filtered_hooks: 63 | if isinstance(j, list): 64 | filtered_hooks_d[i].extend(j) 65 | else: 66 | filtered_hooks_d[i].append(j) 67 | new_hook = HookManager(filtered_hooks_d) 68 | self.forks[name] = new_hook 69 | return new_hook 70 | 71 | def fork_iterative(self, name, iteration): 72 | filtered_hooks = [(k[len(name+'.'+str(iteration))+1:], v) for k, v in self.hook_dict.items() if k.startswith(name+'.'+str(iteration)+'.')] 73 | filtered_hooks += [(k[len(name+'.*')+1:], v) for k, v in self.hook_dict.items() if k.startswith(name+'.*.')] 74 | filtered_hooks_d = defaultdict(list) 75 | for i, j in filtered_hooks: 76 | if isinstance(j, list): 77 | filtered_hooks_d[i].extend(j) 78 | else: 79 | filtered_hooks_d[i].append(j) 80 | new_hook = HookManager(filtered_hooks_d) 81 | self.forks[name+'.'+str(iteration)] = new_hook 82 | return new_hook 83 | 84 | def finalize(self): 85 | for name in self.hook_dict.keys(): 86 | if self.called[name] == 0: 87 | raise ValueError(f'Hook {name} was registered but never used!') -------------------------------------------------------------------------------- /utils/imagenet_classes.py: -------------------------------------------------------------------------------- 1 | imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf 
spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. 
Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", 
"storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", 
"missile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] 
-------------------------------------------------------------------------------- /utils/imagenet_segmentation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.utils.data as data 4 | import numpy as np 5 | 6 | from torchvision.datasets import ImageNet 7 | 8 | from PIL import Image, ImageFilter 9 | import h5py 10 | from glob import glob 11 | 12 | 13 | class ImagenetSegmentation(data.Dataset): 14 | CLASSES = 2 15 | 16 | def __init__(self, 17 | path, 18 | transform=None, 19 | target_transform=None): 20 | self.path = path 21 | self.transform = transform 22 | self.target_transform = target_transform 23 | self.h5py = None 24 | tmp = h5py.File(path, 'r') 25 | self.data_length = len(tmp['/value/img']) 26 | tmp.close() 27 | del tmp 28 | 29 | def __getitem__(self, index): 30 | 31 | if self.h5py is None: 32 | self.h5py = h5py.File(self.path, 'r') 33 | 34 | img = np.array(self.h5py[self.h5py['/value/img'][index, 0]]).transpose((2, 1, 0)) 35 | target = np.array(self.h5py[self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]]).transpose((1, 0)) 36 | 37 | img = Image.fromarray(img).convert('RGB') 38 | target = Image.fromarray(target) 39 | 40 | if self.transform is not None: 41 | img = self.transform(img) 42 | 43 | if self.target_transform is not None: 44 | target = np.array(self.target_transform(target)).astype('int32') 45 | target = torch.from_numpy(target).long() 46 | 47 | return img, target 48 | 49 | def __len__(self): 50 | return self.data_length 51 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | 4 | import torch 5 | from torch import nn as nn 6 | from torchvision.ops.misc import FrozenBatchNorm2d 7 | 8 | 9 | def freeze_batch_norm_2d(module, module_match={}, name=''): 10 | """ 11 | Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is 12 | itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and 13 | returned. Otherwise, the module is walked recursively and submodules are converted in place. 14 | 15 | Args: 16 | module (torch.nn.Module): Any PyTorch module. 
17 | module_match (dict): Dictionary of full module names to freeze (all if empty) 18 | name (str): Full module name (prefix) 19 | 20 | Returns: 21 | torch.nn.Module: Resulting module 22 | 23 | Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 24 | """ 25 | res = module 26 | is_match = True 27 | if module_match: 28 | is_match = name in module_match 29 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 30 | res = FrozenBatchNorm2d(module.num_features) 31 | res.num_features = module.num_features 32 | res.affine = module.affine 33 | if module.affine: 34 | res.weight.data = module.weight.data.clone().detach() 35 | res.bias.data = module.bias.data.clone().detach() 36 | res.running_mean.data = module.running_mean.data 37 | res.running_var.data = module.running_var.data 38 | res.eps = module.eps 39 | else: 40 | for child_name, child in module.named_children(): 41 | full_child_name = '.'.join([name, child_name]) if name else child_name 42 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 43 | if new_child is not child: 44 | res.add_module(child_name, new_child) 45 | return res 46 | 47 | 48 | # From PyTorch internals 49 | def _ntuple(n): 50 | def parse(x): 51 | if isinstance(x, collections.abc.Iterable): 52 | return x 53 | return tuple(repeat(x, n)) 54 | return parse 55 | 56 | 57 | to_1tuple = _ntuple(1) 58 | to_2tuple = _ntuple(2) 59 | to_3tuple = _ntuple(3) 60 | to_4tuple = _ntuple(4) 61 | to_ntuple = lambda n, x: _ntuple(n)(x) 62 | 63 | # Replaces all linear layers with linear_replacement 64 | # TODO: add int8 support for other linear layers including attn and convnets 65 | def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): 66 | for name, module in model.named_children(): 67 | if len(list(module.children())) > 0: 68 | replace_linear(module, linear_replacement, include_modules, copy_weights) 69 | 70 | if isinstance(module, torch.nn.Linear) and name in include_modules: 71 | old_module = model._modules[name] 72 | model._modules[name] = linear_replacement( 73 | module.in_features, 74 | module.out_features, 75 | module.bias is not None, 76 | ) 77 | if copy_weights: 78 | model._modules[name].weight.data.copy_(old_module.weight.data) 79 | if model._modules[name].bias is not None: 80 | model._modules[name].bias.data.copy_(old_module.bias) 81 | 82 | return model 83 | 84 | def convert_int8_model_to_inference_mode(model): 85 | for m in model.modules(): 86 | if hasattr(m, 'prepare_for_eval'): 87 | int8_original_dtype = m.weight.dtype 88 | m.prepare_for_eval() 89 | m.int8_original_dtype = int8_original_dtype 90 | 91 | 92 | def accuracy(output, target, topk=(1,)): 93 | """ 94 | Compute top-k accuracy 95 | 96 | output: torch.Tensor 97 | shape (N, C) where N is the number of examples, C the number of classes. 98 | these are the logits. 99 | 100 | target: torch.Tensor 101 | shape (N,) where N is the number of examples. Groundtruth class id of each example. 
102 | 103 | topk: tuple 104 | which topk to compute, e.g., topk=(1,5) will compute top-1 and top-5 accuracies 105 | 106 | Returns 107 | ------- 108 | 109 | list of top-k accuracies in the same order as `topk` 110 | """ 111 | pred = output.topk(max(topk), 1, True, True)[1].t() 112 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 113 | n = len(target) 114 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) / n for k in topk] 115 | -------------------------------------------------------------------------------- /utils/model.py: -------------------------------------------------------------------------------- 1 | """ CLIP Model 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | from dataclasses import dataclass 6 | import logging 7 | import math 8 | from typing import Optional, Tuple, Union, Text 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import nn 14 | from torch.utils.checkpoint import checkpoint 15 | 16 | 17 | from utils.modified_resnet import ModifiedResNet 18 | from utils.timm_model import TimmModel 19 | from utils.transformer import LayerNorm, QuickGELU, VisionTransformer, TextTransformer, Attention 20 | from utils.misc import to_2tuple 21 | from utils.hook import HookManager 22 | 23 | 24 | @dataclass 25 | class CLIPVisionCfg: 26 | layers: Union[Tuple[int, int, int, int], int] = 12 27 | width: int = 768 28 | head_width: int = 64 29 | mlp_ratio: float = 4.0 30 | patch_size: int = 16 31 | image_size: Union[Tuple[int, int], int] = 224 32 | 33 | ls_init_value: Optional[float] = None # layer scale initial value 34 | patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results 35 | input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design 36 | global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) 37 | attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer 38 | n_queries: int = 256 # n_queries for attentional pooler 39 | attn_pooler_heads: int = 8 # n heads for attentional_pooling 40 | output_tokens: bool = False 41 | 42 | timm_model_name: str = None # a valid model name overrides layers, width, patch_size 43 | timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model 44 | timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') 45 | timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') 46 | timm_proj_bias: bool = False # enable bias final projection 47 | timm_drop: float = 0. 
# head dropout 48 | timm_drop_path: Optional[float] = None # backbone stochastic depth 49 | 50 | 51 | 52 | 53 | def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): 54 | """Convert applicable model parameters to low-precision (bf16 or fp16)""" 55 | 56 | def _convert_weights(l): 57 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 58 | l.weight.data = l.weight.data.to(dtype) 59 | if l.bias is not None: 60 | l.bias.data = l.bias.data.to(dtype) 61 | 62 | if isinstance(l, (nn.MultiheadAttention, Attention)): 63 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 64 | tensor = getattr(l, attr) 65 | if tensor is not None: 66 | tensor.data = tensor.data.to(dtype) 67 | 68 | if isinstance(l, (CLIP, TextTransformer)): 69 | # convert text nn.Parameter projections 70 | attr = getattr(l, "text_projection", None) 71 | if attr is not None: 72 | attr.data = attr.data.to(dtype) 73 | 74 | if isinstance(l, VisionTransformer): 75 | # convert vision nn.Parameter projections 76 | attr = getattr(l, "proj", None) 77 | if attr is not None: 78 | attr.data = attr.data.to(dtype) 79 | 80 | model.apply(_convert_weights) 81 | 82 | convert_weights_to_fp16 = convert_weights_to_lp # backwards compat 83 | 84 | 85 | @dataclass 86 | class CLIPTextCfg: 87 | context_length: int = 77 88 | vocab_size: int = 49408 89 | width: int = 512 90 | heads: int = 8 91 | layers: int = 12 92 | ls_init_value: Optional[float] = None # layer scale initial value 93 | hf_model_name: str = None 94 | hf_tokenizer_name: str = None 95 | hf_model_pretrained: bool = True 96 | proj: str = 'mlp' 97 | pooler_type: str = 'mean_pooler' 98 | embed_cls: bool = False 99 | pad_id: int = 0 100 | output_tokens: bool = False 101 | 102 | 103 | def get_cast_dtype(precision: str): 104 | cast_dtype = None 105 | if precision == 'bf16': 106 | cast_dtype = torch.bfloat16 107 | elif precision == 'fp16': 108 | cast_dtype = torch.float16 109 | return cast_dtype 110 | 111 | 112 | def get_input_dtype(precision: str): 113 | input_dtype = None 114 | if precision in ('bf16', 'pure_bf16'): 115 | input_dtype = torch.bfloat16 116 | elif precision in ('fp16', 'pure_fp16'): 117 | input_dtype = torch.float16 118 | return input_dtype 119 | 120 | 121 | def _build_vision_tower( 122 | embed_dim: int, 123 | vision_cfg: CLIPVisionCfg, 124 | quick_gelu: bool = False, 125 | cast_dtype: Optional[torch.dtype] = None, 126 | hook: Optional[HookManager]= None, 127 | ): 128 | if isinstance(vision_cfg, dict): 129 | vision_cfg = CLIPVisionCfg(**vision_cfg) 130 | 131 | # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more 132 | # memory efficient in recent PyTorch releases (>= 1.10). 133 | # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
134 | act_layer = QuickGELU if quick_gelu else nn.GELU 135 | 136 | if vision_cfg.timm_model_name: 137 | visual = TimmModel( 138 | vision_cfg.timm_model_name, 139 | pretrained=vision_cfg.timm_model_pretrained, 140 | pool=vision_cfg.timm_pool, 141 | proj=vision_cfg.timm_proj, 142 | proj_bias=vision_cfg.timm_proj_bias, 143 | drop=vision_cfg.timm_drop, 144 | drop_path=vision_cfg.timm_drop_path, 145 | patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None, 146 | embed_dim=embed_dim, 147 | image_size=vision_cfg.image_size, 148 | hook=hook, 149 | ) 150 | elif isinstance(vision_cfg.layers, (tuple, list)): 151 | vision_heads = vision_cfg.width * 32 // vision_cfg.head_width 152 | visual = ModifiedResNet( 153 | layers=vision_cfg.layers, 154 | output_dim=embed_dim, 155 | heads=vision_heads, 156 | image_size=vision_cfg.image_size, 157 | width=vision_cfg.width, 158 | hook=hook, 159 | ) 160 | else: 161 | vision_heads = vision_cfg.width // vision_cfg.head_width 162 | norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm 163 | visual = VisionTransformer( 164 | image_size=vision_cfg.image_size, 165 | patch_size=vision_cfg.patch_size, 166 | width=vision_cfg.width, 167 | layers=vision_cfg.layers, 168 | heads=vision_heads, 169 | mlp_ratio=vision_cfg.mlp_ratio, 170 | ls_init_value=vision_cfg.ls_init_value, 171 | patch_dropout=vision_cfg.patch_dropout, 172 | input_patchnorm=vision_cfg.input_patchnorm, 173 | global_average_pool=vision_cfg.global_average_pool, 174 | attentional_pool=vision_cfg.attentional_pool, 175 | n_queries=vision_cfg.n_queries, 176 | attn_pooler_heads=vision_cfg.attn_pooler_heads, 177 | output_tokens=vision_cfg.output_tokens, 178 | output_dim=embed_dim, 179 | act_layer=act_layer, 180 | norm_layer=norm_layer, 181 | hook=hook, 182 | ) 183 | 184 | return visual 185 | 186 | 187 | def _build_text_tower( 188 | embed_dim: int, 189 | text_cfg: CLIPTextCfg, 190 | quick_gelu: bool = False, 191 | cast_dtype: Optional[torch.dtype] = None, 192 | hook: Optional[HookManager] = None, 193 | ): 194 | if isinstance(text_cfg, dict): 195 | text_cfg = CLIPTextCfg(**text_cfg) 196 | 197 | if text_cfg.hf_model_name: 198 | from hf_model import HFTextEncoder 199 | text = HFTextEncoder( 200 | text_cfg.hf_model_name, 201 | output_dim=embed_dim, 202 | proj=text_cfg.proj, 203 | pooler_type=text_cfg.pooler_type, 204 | pretrained=text_cfg.hf_model_pretrained, 205 | output_tokens=text_cfg.output_tokens, 206 | ) 207 | else: 208 | act_layer = QuickGELU if quick_gelu else nn.GELU 209 | norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm 210 | 211 | text = TextTransformer( 212 | context_length=text_cfg.context_length, 213 | vocab_size=text_cfg.vocab_size, 214 | width=text_cfg.width, 215 | heads=text_cfg.heads, 216 | layers=text_cfg.layers, 217 | ls_init_value=text_cfg.ls_init_value, 218 | output_dim=embed_dim, 219 | embed_cls=text_cfg.embed_cls, 220 | output_tokens=text_cfg.output_tokens, 221 | pad_id=text_cfg.pad_id, 222 | act_layer=act_layer, 223 | norm_layer=norm_layer, 224 | ) 225 | return text 226 | 227 | 228 | class CLIP(nn.Module): 229 | output_dict: torch.jit.Final[bool] 230 | 231 | def __init__( 232 | self, 233 | embed_dim: int, 234 | vision_cfg: CLIPVisionCfg, 235 | text_cfg: CLIPTextCfg, 236 | quick_gelu: bool = False, 237 | cast_dtype: Optional[torch.dtype] = None, 238 | output_dict: bool = False, 239 | hook: Optional[HookManager] = None, 240 | ): 241 | super().__init__() 242 | self.hook_manager = hook or HookManager() 
243 | self.output_dict = output_dict 244 | self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype, self.hook_manager.fork('visual')) 245 | 246 | text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype, self.hook_manager.fork('textual')) 247 | self.transformer = text.transformer 248 | self.context_length = text.context_length 249 | self.vocab_size = text.vocab_size 250 | self.token_embedding = text.token_embedding 251 | self.positional_embedding = text.positional_embedding 252 | self.ln_final = text.ln_final 253 | self.text_projection = text.text_projection 254 | self.register_buffer('attn_mask', text.attn_mask, persistent=False) 255 | 256 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 257 | 258 | @torch.jit.ignore 259 | def set_grad_checkpointing(self, enable=True): 260 | self.visual.set_grad_checkpointing(enable) 261 | self.transformer.grad_checkpointing = enable 262 | 263 | def encode_image(self, image, normalize: bool = False, attn_method: Text = 'direct'): 264 | features = self.visual(image, attn_method=attn_method) 265 | return F.normalize(features, dim=-1) if normalize else features 266 | 267 | def encode_text(self, text, normalize: bool = False): 268 | cast_dtype = self.transformer.get_cast_dtype() 269 | 270 | x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] 271 | 272 | x = x + self.positional_embedding.to(cast_dtype) 273 | # x = x.permute(1, 0, 2) # NLD -> LND 274 | x = self.transformer(x, attn_mask=self.attn_mask) 275 | # x = x.permute(1, 0, 2) # LND -> NLD 276 | x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] 277 | # take features from the eot embedding (eot_token is the highest number in each sequence) 278 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 279 | return F.normalize(x, dim=-1) if normalize else x 280 | 281 | def forward( 282 | self, 283 | image: Optional[torch.Tensor] = None, 284 | text: Optional[torch.Tensor] = None, 285 | ): 286 | image_features = self.encode_image(image, normalize=True) if image is not None else None 287 | text_features = self.encode_text(text, normalize=True) if text is not None else None 288 | if self.output_dict: 289 | return { 290 | "image_features": image_features, 291 | "text_features": text_features, 292 | "logit_scale": self.logit_scale.exp() 293 | } 294 | return image_features, text_features, self.logit_scale.exp() 295 | 296 | 297 | # used to maintain checkpoint compatibility 298 | def convert_to_custom_text_state_dict(state_dict: dict): 299 | if 'text_projection' in state_dict: 300 | # old format state_dict, move text tower -> .text 301 | new_state_dict = {} 302 | for k, v in state_dict.items(): 303 | if any(k.startswith(p) for p in ( 304 | 'text_projection', 305 | 'positional_embedding', 306 | 'token_embedding', 307 | 'transformer', 308 | 'ln_final', 309 | )): 310 | k = 'text.' 
+ k 311 | new_state_dict[k] = v 312 | return new_state_dict 313 | return state_dict 314 | 315 | 316 | def build_model_from_openai_state_dict( 317 | state_dict: dict, 318 | quick_gelu=True, 319 | cast_dtype=torch.float16, 320 | ): 321 | vit = "visual.proj" in state_dict 322 | 323 | if vit: 324 | vision_width = state_dict["visual.conv1.weight"].shape[0] 325 | vision_layers = len( 326 | [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 327 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 328 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 329 | image_size = vision_patch_size * grid_size 330 | else: 331 | counts: list = [ 332 | len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 333 | vision_layers = tuple(counts) 334 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 335 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 336 | vision_patch_size = None 337 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 338 | image_size = output_width * 32 339 | 340 | embed_dim = state_dict["text_projection"].shape[1] 341 | context_length = state_dict["positional_embedding"].shape[0] 342 | vocab_size = state_dict["token_embedding.weight"].shape[0] 343 | transformer_width = state_dict["ln_final.weight"].shape[0] 344 | transformer_heads = transformer_width // 64 345 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 346 | 347 | vision_cfg = CLIPVisionCfg( 348 | layers=vision_layers, 349 | width=vision_width, 350 | patch_size=vision_patch_size, 351 | image_size=image_size, 352 | ) 353 | text_cfg = CLIPTextCfg( 354 | context_length=context_length, 355 | vocab_size=vocab_size, 356 | width=transformer_width, 357 | heads=transformer_heads, 358 | layers=transformer_layers, 359 | ) 360 | model = CLIP( 361 | embed_dim, 362 | vision_cfg=vision_cfg, 363 | text_cfg=text_cfg, 364 | quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU 365 | cast_dtype=cast_dtype, 366 | ) 367 | 368 | for key in ["input_resolution", "context_length", "vocab_size"]: 369 | state_dict.pop(key, None) 370 | 371 | convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 372 | model.load_state_dict(state_dict) 373 | return model.eval() 374 | 375 | 376 | def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): 377 | # Rescale the grid of position embeddings when loading from state_dict 378 | old_pos_embed = state_dict.get('visual.positional_embedding', None) 379 | if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): 380 | return 381 | grid_size = to_2tuple(model.visual.grid_size) 382 | extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) 383 | new_seq_len = grid_size[0] * grid_size[1] + extra_tokens 384 | if new_seq_len == old_pos_embed.shape[0]: 385 | return 386 | 387 | if extra_tokens: 388 | pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] 389 | else: 390 | pos_emb_tok, pos_emb_img = None, old_pos_embed 391 | old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) 392 | 393 | logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) 394 | pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 
3, 1, 2) 395 | pos_emb_img = F.interpolate( 396 | pos_emb_img, 397 | size=grid_size, 398 | mode=interpolation, 399 | antialias=antialias, 400 | align_corners=False, 401 | ) 402 | pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] 403 | if pos_emb_tok is not None: 404 | new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) 405 | else: 406 | new_pos_embed = pos_emb_img 407 | state_dict['visual.positional_embedding'] = new_pos_embed -------------------------------------------------------------------------------- /utils/model_configs/EVA01-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva_giant_patch14_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /utils/model_configs/EVA01-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva_giant_patch14_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /utils/model_configs/EVA02-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_base_patch16_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /utils/model_configs/EVA02-E-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_enormous_patch14_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1280, 14 | "heads": 20, 15 | "layers": 32 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /utils/model_configs/EVA02-E-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_enormous_patch14_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /utils/model_configs/EVA02-L-14-336.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "timm_model_name": "eva02_large_patch14_clip_336", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /utils/model_configs/EVA02-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "timm_model_name": "eva02_large_patch14_clip_224", 6 | "timm_model_pretrained": false, 7 | "timm_pool": "token", 8 | "timm_proj": null 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 768, 14 | "heads": 12, 15 | "layers": 12 16 | }, 17 | "custom_text": true 18 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-B-16-plus-240.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 240, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-B-16-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-B-32-plus-256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "image_size": 256, 5 | "layers": 12, 6 | "width": 896, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 640, 13 | "heads": 10, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-B-32-quickgelu.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- 
/utils/model_configs/ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-H-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 16 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 1024, 14 | "heads": 16, 15 | "layers": 24 16 | } 17 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-L-14-280.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 280, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-L-16-320.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 320, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-L-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | 
"image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 768, 13 | "heads": 12, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-M-16-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16, 8 | "ls_init_value": 1e-4 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 384, 14 | "heads": 6, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-M-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-M-32-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-M-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 512, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 512, 13 | "heads": 8, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-S-16-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 10 15 | } 16 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-S-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 16 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-S-32-alt.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 256, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 256, 13 | "heads": 4, 14 | "layers": 
10 15 | } 16 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-S-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 384, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 384, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "context_length": 77, 11 | "vocab_size": 49408, 12 | "width": 384, 13 | "heads": 6, 14 | "layers": 12 15 | } 16 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-bigG-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 48, 6 | "width": 1664, 7 | "head_width": 104, 8 | "mlp_ratio": 4.9231, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 32 17 | } 18 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-e-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1280, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 56, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.5715, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1280, 15 | "heads": 20, 16 | "layers": 36 17 | } 18 | } -------------------------------------------------------------------------------- /utils/model_configs/ViT-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14 10 | }, 11 | "text_cfg": { 12 | "context_length": 77, 13 | "vocab_size": 49408, 14 | "width": 1024, 15 | "heads": 16, 16 | "layers": 24 17 | } 18 | } -------------------------------------------------------------------------------- /utils/model_configs/coca_ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32, 8 | "attentional_pool": true, 9 | "attn_pooler_heads": 8, 10 | "output_tokens": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 76, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12, 18 | "embed_cls": true, 19 | "output_tokens": true 20 | }, 21 | "multimodal_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 49408, 24 | "width": 512, 25 | "heads": 8, 26 | "layers": 12, 27 | "attn_pooler_heads": 8 28 | }, 29 | "custom_text": true 30 | } -------------------------------------------------------------------------------- /utils/model_configs/coca_ViT-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "patch_size": 14, 8 | "attentional_pool": true, 9 | "attn_pooler_heads": 8, 10 | "output_tokens": true 11 | }, 12 | "text_cfg": { 13 | "context_length": 76, 14 | "vocab_size": 49408, 15 | "width": 768, 16 | "heads": 12, 17 | "layers": 12, 18 | "embed_cls": true, 19 | "output_tokens": true 20 | }, 21 | "multimodal_cfg": { 22 | "context_length": 76, 23 | 
"vocab_size": 49408, 24 | "width": 768, 25 | "heads": 12, 26 | "layers": 12, 27 | "attn_pooler_heads": 12 28 | }, 29 | "custom_text": true 30 | } 31 | -------------------------------------------------------------------------------- /utils/model_configs/coca_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "multimodal_cfg": { 4 | "width": 768, 5 | "context_length": 76, 6 | "vocab_size": 64000, 7 | "mlp_ratio": 4, 8 | "layers": 12, 9 | "dim_head": 64, 10 | "heads": 12, 11 | "n_queries": 256, 12 | "attn_pooler_heads": 8 13 | }, 14 | "vision_cfg": { 15 | "image_size": 288, 16 | "layers": 12, 17 | "width": 768, 18 | "patch_size": 18, 19 | "output_tokens": true 20 | }, 21 | "text_cfg": { 22 | "context_length": 76, 23 | "vocab_size": 64000, 24 | "layers": 12, 25 | "heads": 12, 26 | "width": 768, 27 | "embed_cls": true, 28 | "output_tokens": true 29 | }, 30 | "custom_text": true 31 | } -------------------------------------------------------------------------------- /utils/model_configs/coca_roberta-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32, 8 | "output_tokens": true 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "roberta-base", 12 | "hf_tokenizer_name": "roberta-base", 13 | "proj": "linear", 14 | "width": 768, 15 | "output_tokens": true 16 | }, 17 | "multimodal_cfg": { 18 | "context_length": 76, 19 | "width": 768, 20 | "heads": 8, 21 | "layers": 12 22 | }, 23 | "custom_text": true 24 | } 25 | -------------------------------------------------------------------------------- /utils/model_configs/mt5-base-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "google/mt5-base", 11 | "hf_tokenizer_name": "google/mt5-base", 12 | "proj": "mlp", 13 | "pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /utils/model_configs/mt5-xl-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "google/mt5-xl", 12 | "hf_tokenizer_name": "google/mt5-xl", 13 | "proj": "mlp", 14 | "pooler_type": "mean_pooler" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /utils/model_configs/roberta-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "quick_gelu": true, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 12, 7 | "width": 768, 8 | "patch_size": 32 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "roberta-base", 12 | "hf_tokenizer_name": "roberta-base", 13 | "proj": "mlp", 14 | "pooler_type": "mean_pooler" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /utils/model_configs/swin_base_patch4_window7_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 640, 3 | "vision_cfg": { 4 | "timm_model_name": 
"swin_base_patch4_window7_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 640, 14 | "heads": 10, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /utils/model_configs/vit_medium_patch16_gap_256.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_medium_patch16_gap_256", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 256 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /utils/model_configs/vit_relpos_medium_patch16_cls_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "timm_model_name": "vit_relpos_medium_patch16_cls_224", 5 | "timm_model_pretrained": false, 6 | "timm_pool": "", 7 | "timm_proj": "linear", 8 | "image_size": 224 9 | }, 10 | "text_cfg": { 11 | "context_length": 77, 12 | "vocab_size": 49408, 13 | "width": 512, 14 | "heads": 8, 15 | "layers": 12 16 | } 17 | } -------------------------------------------------------------------------------- /utils/model_configs/xlm-roberta-base-ViT-B-32.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 32 8 | }, 9 | "text_cfg": { 10 | "hf_model_name": "xlm-roberta-base", 11 | "hf_tokenizer_name": "xlm-roberta-base", 12 | "proj": "mlp", 13 | "pooler_type": "mean_pooler" 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /utils/model_configs/xlm-roberta-large-ViT-H-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 32, 6 | "width": 1280, 7 | "head_width": 80, 8 | "patch_size": 14 9 | }, 10 | "text_cfg": { 11 | "hf_model_name": "xlm-roberta-large", 12 | "hf_tokenizer_name": "xlm-roberta-large", 13 | "proj": "mlp", 14 | "pooler_type": "mean_pooler" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /utils/modified_resnet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from utils.misc import freeze_batch_norm_2d 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.act1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.act2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.act3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.act1(self.bn1(self.conv1(x))) 46 | out = self.act2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.act3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x, key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0., 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | 92 | return x[0] 93 | 94 | 95 | class ModifiedResNet(nn.Module): 96 | """ 97 | A ResNet class that is similar to torchvision's but contains the following changes: 98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 100 | - The final pooling layer is a QKV attention instead of an average pool 101 | """ 102 | 103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 104 | super().__init__() 105 | self.output_dim = output_dim 106 | self.image_size = image_size 107 | 108 | # the 3-layer stem 109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 110 | self.bn1 = nn.BatchNorm2d(width // 2) 111 | self.act1 = nn.ReLU(inplace=True) 112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(width // 2) 114 | self.act2 = nn.ReLU(inplace=True) 115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 116 | self.bn3 = nn.BatchNorm2d(width) 117 | self.act3 = nn.ReLU(inplace=True) 118 | self.avgpool = nn.AvgPool2d(2) 119 | 120 | # residual layers 121 | self._inplanes = width # this is a *mutable* variable used during construction 122 | self.layer1 = self._make_layer(width, layers[0]) 123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 126 | 127 | embed_dim = width * 32 # the ResNet feature dimension 128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 129 | 130 | self.init_parameters() 131 | 132 | def _make_layer(self, planes, blocks, stride=1): 133 | layers = [Bottleneck(self._inplanes, planes, stride)] 134 | 135 | self._inplanes = planes * Bottleneck.expansion 136 | for _ in range(1, blocks): 137 | layers.append(Bottleneck(self._inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def init_parameters(self): 142 | if self.attnpool is not None: 143 | std = self.attnpool.c_proj.in_features ** -0.5 144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 148 | 149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 150 | for name, param in resnet_block.named_parameters(): 151 | if name.endswith("bn3.weight"): 152 | nn.init.zeros_(param) 153 | 154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 156 | for param in self.parameters(): 157 | param.requires_grad = False 158 | if freeze_bn_stats: 159 | freeze_batch_norm_2d(self) 160 | 161 | @torch.jit.ignore 162 | def set_grad_checkpointing(self, enable=True): 163 | # FIXME support for non-transformer 164 | pass 165 | 166 | def stem(self, x): 167 | x = self.act1(self.bn1(self.conv1(x))) 168 | x = self.act2(self.bn2(self.conv2(x))) 169 | x = self.act3(self.bn3(self.conv3(x))) 170 | x = self.avgpool(x) 171 | return x 172 | 173 | def forward(self, x): 174 | x = self.stem(x) 175 | x = self.layer1(x) 176 | x = self.layer2(x) 177 | x = self.layer3(x) 178 | x = self.layer4(x) 179 | x = self.attnpool(x) 180 | 181 | return x -------------------------------------------------------------------------------- /utils/openai_models.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. 
Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from utils.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 13 | from utils.model import build_model_from_openai_state_dict, get_cast_dtype 14 | from utils.pretrained import * 15 | 16 | __all__ = ["list_openai_models", "load_openai_model"] 17 | 18 | 19 | def list_openai_models() -> List[str]: 20 | """Returns the names of available CLIP models""" 21 | return list_pretrained_models_by_tag('openai') 22 | 23 | 24 | def load_openai_model( 25 | name: str, 26 | precision: Optional[str] = None, 27 | device: Optional[Union[str, torch.device]] = None, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | cache_dir : Optional[str] 41 | The directory to cache the downloaded model weights 42 | 43 | Returns 44 | ------- 45 | model : torch.nn.Module 46 | The CLIP model 47 | preprocess : Callable[[PIL.Image], torch.Tensor] 48 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 49 | """ 50 | if device is None: 51 | device = "cuda" if torch.cuda.is_available() else "cpu" 52 | if precision is None: 53 | precision = 'fp32' if device == 'cpu' else 'fp16' 54 | 55 | if get_pretrained_url(name, 'openai'): 56 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 57 | elif os.path.isfile(name): 58 | model_path = name 59 | else: 60 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 61 | 62 | try: 63 | # loading JIT archive 64 | model = torch.jit.load(model_path, map_location="cpu").eval() 65 | state_dict = None 66 | except RuntimeError: 67 | # loading saved state dict 68 | state_dict = torch.load(model_path, map_location="cpu") 69 | 70 | # Build a non-jit model from the OpenAI jitted model state dict 71 | cast_dtype = get_cast_dtype(precision) 72 | try: 73 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 74 | except KeyError: 75 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 76 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 77 | 78 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 79 | model = model.to(device) 80 | # FIXME support pure fp16/bf16 precision modes 81 | if precision != 'fp16': 82 | model.float() 83 | if precision == 'bf16': 84 | # for bf16, convert back to low-precision 85 | convert_weights_to_lp(model, dtype=torch.bfloat16) 86 | 87 | # add mean / std attributes for consistency with OpenCLIP models 88 | model.visual.image_mean = OPENAI_DATASET_MEAN 89 | model.visual.image_std = OPENAI_DATASET_STD 90 | return model -------------------------------------------------------------------------------- /utils/openai_templates.py: -------------------------------------------------------------------------------- 1 | 2 | OPENAI_IMAGENET_TEMPLATES = ( 3 | lambda c: f'a bad photo of a {c}.', 4 | lambda c: f'a photo of many {c}.', 5 | lambda c: f'a sculpture of a {c}.', 6 | lambda c: f'a 
photo of the hard to see {c}.', 7 | lambda c: f'a low resolution photo of the {c}.', 8 | lambda c: f'a rendering of a {c}.', 9 | lambda c: f'graffiti of a {c}.', 10 | lambda c: f'a bad photo of the {c}.', 11 | lambda c: f'a cropped photo of the {c}.', 12 | lambda c: f'a tattoo of a {c}.', 13 | lambda c: f'the embroidered {c}.', 14 | lambda c: f'a photo of a hard to see {c}.', 15 | lambda c: f'a bright photo of a {c}.', 16 | lambda c: f'a photo of a clean {c}.', 17 | lambda c: f'a photo of a dirty {c}.', 18 | lambda c: f'a dark photo of the {c}.', 19 | lambda c: f'a drawing of a {c}.', 20 | lambda c: f'a photo of my {c}.', 21 | lambda c: f'the plastic {c}.', 22 | lambda c: f'a photo of the cool {c}.', 23 | lambda c: f'a close-up photo of a {c}.', 24 | lambda c: f'a black and white photo of the {c}.', 25 | lambda c: f'a painting of the {c}.', 26 | lambda c: f'a painting of a {c}.', 27 | lambda c: f'a pixelated photo of the {c}.', 28 | lambda c: f'a sculpture of the {c}.', 29 | lambda c: f'a bright photo of the {c}.', 30 | lambda c: f'a cropped photo of a {c}.', 31 | lambda c: f'a plastic {c}.', 32 | lambda c: f'a photo of the dirty {c}.', 33 | lambda c: f'a jpeg corrupted photo of a {c}.', 34 | lambda c: f'a blurry photo of the {c}.', 35 | lambda c: f'a photo of the {c}.', 36 | lambda c: f'a good photo of the {c}.', 37 | lambda c: f'a rendering of the {c}.', 38 | lambda c: f'a {c} in a video game.', 39 | lambda c: f'a photo of one {c}.', 40 | lambda c: f'a doodle of a {c}.', 41 | lambda c: f'a close-up photo of the {c}.', 42 | lambda c: f'a photo of a {c}.', 43 | lambda c: f'the origami {c}.', 44 | lambda c: f'the {c} in a video game.', 45 | lambda c: f'a sketch of a {c}.', 46 | lambda c: f'a doodle of the {c}.', 47 | lambda c: f'a origami {c}.', 48 | lambda c: f'a low resolution photo of a {c}.', 49 | lambda c: f'the toy {c}.', 50 | lambda c: f'a rendition of the {c}.', 51 | lambda c: f'a photo of the clean {c}.', 52 | lambda c: f'a photo of a large {c}.', 53 | lambda c: f'a rendition of a {c}.', 54 | lambda c: f'a photo of a nice {c}.', 55 | lambda c: f'a photo of a weird {c}.', 56 | lambda c: f'a blurry photo of a {c}.', 57 | lambda c: f'a cartoon {c}.', 58 | lambda c: f'art of a {c}.', 59 | lambda c: f'a sketch of the {c}.', 60 | lambda c: f'a embroidered {c}.', 61 | lambda c: f'a pixelated photo of a {c}.', 62 | lambda c: f'itap of the {c}.', 63 | lambda c: f'a jpeg corrupted photo of the {c}.', 64 | lambda c: f'a good photo of a {c}.', 65 | lambda c: f'a plushie {c}.', 66 | lambda c: f'a photo of the nice {c}.', 67 | lambda c: f'a photo of the small {c}.', 68 | lambda c: f'a photo of the weird {c}.', 69 | lambda c: f'the cartoon {c}.', 70 | lambda c: f'art of the {c}.', 71 | lambda c: f'a drawing of the {c}.', 72 | lambda c: f'a photo of the large {c}.', 73 | lambda c: f'a black and white photo of a {c}.', 74 | lambda c: f'the plushie {c}.', 75 | lambda c: f'a dark photo of a {c}.', 76 | lambda c: f'itap of a {c}.', 77 | lambda c: f'graffiti of the {c}.', 78 | lambda c: f'a toy {c}.', 79 | lambda c: f'itap of my {c}.', 80 | lambda c: f'a photo of a cool {c}.', 81 | lambda c: f'a photo of a small {c}.', 82 | lambda c: f'a tattoo of the {c}.', 83 | ) 84 | 85 | -------------------------------------------------------------------------------- /utils/siglip/configuration_siglip.py: -------------------------------------------------------------------------------- 1 | """Siglip model configuration""" 2 | 3 | from transformers.configuration_utils import PretrainedConfig 4 | from 
transformers.utils import logging 5 | 6 | 7 | logger = logging.get_logger(__name__) 8 | 9 | 10 | class SiglipTextConfig(PretrainedConfig): 11 | r""" 12 | This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a 13 | Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a 14 | configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip 15 | [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. 16 | 17 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 18 | documentation from [`PretrainedConfig`] for more information. 19 | 20 | Args: 21 | vocab_size (`int`, *optional*, defaults to 32000): 22 | Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by 23 | the `inputs_ids` passed when calling [`SiglipModel`]. 24 | hidden_size (`int`, *optional*, defaults to 768): 25 | Dimensionality of the encoder layers and the pooler layer. 26 | intermediate_size (`int`, *optional*, defaults to 3072): 27 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 28 | num_hidden_layers (`int`, *optional*, defaults to 12): 29 | Number of hidden layers in the Transformer encoder. 30 | num_attention_heads (`int`, *optional*, defaults to 12): 31 | Number of attention heads for each attention layer in the Transformer encoder. 32 | max_position_embeddings (`int`, *optional*, defaults to 64): 33 | The maximum sequence length that this model might ever be used with. Typically set this to something large 34 | just in case (e.g., 512 or 1024 or 2048). 35 | hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): 36 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 37 | `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. 38 | layer_norm_eps (`float`, *optional*, defaults to 1e-06): 39 | The epsilon used by the layer normalization layers. 40 | attention_dropout (`float`, *optional*, defaults to 0.0): 41 | The dropout ratio for the attention probabilities. 42 | pad_token_id (`int`, *optional*, defaults to 1): 43 | The id of the padding token in the vocabulary. 44 | bos_token_id (`int`, *optional*, defaults to 49406): 45 | The id of the beginning-of-sequence token in the vocabulary. 46 | eos_token_id (`int`, *optional*, defaults to 49407): 47 | The id of the end-of-sequence token in the vocabulary. 48 | projection_size (`int`, *optional*, defaults to `hidden_size`): 49 | The size of the projection head. 
50 | 51 | Example: 52 | 53 | ```python 54 | >>> from transformers import SiglipTextConfig, SiglipTextModel 55 | 56 | >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration 57 | >>> configuration = SiglipTextConfig() 58 | 59 | >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration 60 | >>> model = SiglipTextModel(configuration) 61 | 62 | >>> # Accessing the model configuration 63 | >>> configuration = model.config 64 | ```""" 65 | 66 | model_type = "siglip_text_model" 67 | base_config_key = "text_config" 68 | 69 | def __init__( 70 | self, 71 | vocab_size=32000, 72 | hidden_size=768, 73 | intermediate_size=3072, 74 | num_hidden_layers=12, 75 | num_attention_heads=12, 76 | max_position_embeddings=64, 77 | hidden_act="gelu_pytorch_tanh", 78 | layer_norm_eps=1e-6, 79 | attention_dropout=0.0, 80 | # This differs from `CLIPTokenizer`'s default and from openai/siglip 81 | # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 82 | pad_token_id=1, 83 | bos_token_id=49406, 84 | eos_token_id=49407, 85 | projection_size=None, 86 | **kwargs, 87 | ): 88 | super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) 89 | 90 | self.vocab_size = vocab_size 91 | self.hidden_size = hidden_size 92 | self.intermediate_size = intermediate_size 93 | self.num_hidden_layers = num_hidden_layers 94 | self.num_attention_heads = num_attention_heads 95 | self.max_position_embeddings = max_position_embeddings 96 | self.layer_norm_eps = layer_norm_eps 97 | self.hidden_act = hidden_act 98 | self.attention_dropout = attention_dropout 99 | self.projection_size = projection_size if projection_size is not None else hidden_size 100 | 101 | 102 | class SiglipVisionConfig(PretrainedConfig): 103 | r""" 104 | This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a 105 | Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a 106 | configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip 107 | [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. 108 | 109 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 110 | documentation from [`PretrainedConfig`] for more information. 111 | 112 | Args: 113 | hidden_size (`int`, *optional*, defaults to 768): 114 | Dimensionality of the encoder layers and the pooler layer. 115 | intermediate_size (`int`, *optional*, defaults to 3072): 116 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 117 | num_hidden_layers (`int`, *optional*, defaults to 12): 118 | Number of hidden layers in the Transformer encoder. 119 | num_attention_heads (`int`, *optional*, defaults to 12): 120 | Number of attention heads for each attention layer in the Transformer encoder. 121 | num_channels (`int`, *optional*, defaults to 3): 122 | Number of channels in the input images. 123 | image_size (`int`, *optional*, defaults to 224): 124 | The size (resolution) of each image. 125 | patch_size (`int`, *optional*, defaults to 16): 126 | The size (resolution) of each patch. 
127 | hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): 128 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 129 | `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. 130 | layer_norm_eps (`float`, *optional*, defaults to 1e-06): 131 | The epsilon used by the layer normalization layers. 132 | attention_dropout (`float`, *optional*, defaults to 0.0): 133 | The dropout ratio for the attention probabilities. 134 | 135 | Example: 136 | 137 | ```python 138 | >>> from transformers import SiglipVisionConfig, SiglipVisionModel 139 | 140 | >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration 141 | >>> configuration = SiglipVisionConfig() 142 | 143 | >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration 144 | >>> model = SiglipVisionModel(configuration) 145 | 146 | >>> # Accessing the model configuration 147 | >>> configuration = model.config 148 | ```""" 149 | 150 | model_type = "siglip_vision_model" 151 | base_config_key = "vision_config" 152 | 153 | def __init__( 154 | self, 155 | hidden_size=768, 156 | intermediate_size=3072, 157 | num_hidden_layers=12, 158 | num_attention_heads=12, 159 | num_channels=3, 160 | image_size=224, 161 | patch_size=16, 162 | hidden_act="gelu_pytorch_tanh", 163 | layer_norm_eps=1e-6, 164 | attention_dropout=0.0, 165 | **kwargs, 166 | ): 167 | super().__init__(**kwargs) 168 | 169 | self.hidden_size = hidden_size 170 | self.intermediate_size = intermediate_size 171 | self.num_hidden_layers = num_hidden_layers 172 | self.num_attention_heads = num_attention_heads 173 | self.num_channels = num_channels 174 | self.patch_size = patch_size 175 | self.image_size = image_size 176 | self.attention_dropout = attention_dropout 177 | self.layer_norm_eps = layer_norm_eps 178 | self.hidden_act = hidden_act 179 | 180 | 181 | class SiglipConfig(PretrainedConfig): 182 | r""" 183 | [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to 184 | instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs. 185 | Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip 186 | [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. 187 | 188 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 189 | documentation from [`PretrainedConfig`] for more information. 190 | 191 | Args: 192 | text_config (`dict`, *optional*): 193 | Dictionary of configuration options used to initialize [`SiglipTextConfig`]. 194 | vision_config (`dict`, *optional*): 195 | Dictionary of configuration options used to initialize [`SiglipVisionConfig`]. 196 | kwargs (*optional*): 197 | Dictionary of keyword arguments. 
198 | 199 | Example: 200 | 201 | ```python 202 | >>> from transformers import SiglipConfig, SiglipModel 203 | 204 | >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration 205 | >>> configuration = SiglipConfig() 206 | 207 | >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration 208 | >>> model = SiglipModel(configuration) 209 | 210 | >>> # Accessing the model configuration 211 | >>> configuration = model.config 212 | 213 | >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig 214 | >>> from transformers import SiglipTextConfig, SiglipVisionConfig 215 | 216 | >>> # Initializing a SiglipText and SiglipVision configuration 217 | >>> config_text = SiglipTextConfig() 218 | >>> config_vision = SiglipVisionConfig() 219 | 220 | >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision) 221 | ```""" 222 | 223 | model_type = "siglip" 224 | sub_configs = {"text_config": SiglipTextConfig, "vision_config": SiglipVisionConfig} 225 | 226 | def __init__(self, text_config=None, vision_config=None, **kwargs): 227 | super().__init__(**kwargs) 228 | 229 | if text_config is None: 230 | text_config = {} 231 | logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.") 232 | 233 | if vision_config is None: 234 | vision_config = {} 235 | logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.") 236 | 237 | self.text_config = SiglipTextConfig(**text_config) 238 | self.vision_config = SiglipVisionConfig(**vision_config) 239 | 240 | self.initializer_factor = 1.0 241 | 242 | @classmethod 243 | def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs): 244 | r""" 245 | Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision 246 | model configuration. 247 | 248 | Returns: 249 | [`SiglipConfig`]: An instance of a configuration object 250 | """ 251 | 252 | return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) 253 | 254 | 255 | __all__ = ["SiglipConfig", "SiglipTextConfig", "SiglipVisionConfig"] -------------------------------------------------------------------------------- /utils/siglip/image_processing_siglip.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Image processor class for SigLIP.""" 16 | 17 | from typing import Dict, List, Optional, Union 18 | 19 | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict 20 | from transformers.image_transforms import ( 21 | convert_to_rgb, 22 | resize, 23 | to_channel_dimension_format, 24 | ) 25 | from transformers.image_utils import ( 26 | IMAGENET_STANDARD_MEAN, 27 | IMAGENET_STANDARD_STD, 28 | ChannelDimension, 29 | ImageInput, 30 | PILImageResampling, 31 | infer_channel_dimension_format, 32 | is_scaled_image, 33 | make_flat_list_of_images, 34 | to_numpy_array, 35 | valid_images, 36 | validate_preprocess_arguments, 37 | ) 38 | from transformers.utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging 39 | 40 | 41 | logger = logging.get_logger(__name__) 42 | 43 | 44 | if is_vision_available(): 45 | import PIL 46 | 47 | 48 | class SiglipImageProcessor(BaseImageProcessor): 49 | r""" 50 | Constructs a SigLIP image processor. 51 | 52 | Args: 53 | do_resize (`bool`, *optional*, defaults to `True`): 54 | Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by 55 | `do_resize` in the `preprocess` method. 56 | size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`): 57 | Size of the image after resizing. Can be overridden by `size` in the `preprocess` method. 58 | resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): 59 | Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. 60 | do_rescale (`bool`, *optional*, defaults to `True`): 61 | Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in 62 | the `preprocess` method. 63 | rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): 64 | Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` 65 | method. 66 | do_normalize (`bool`, *optional*, defaults to `True`): 67 | Whether to normalize the image by the specified mean and standard deviation. Can be overridden by 68 | `do_normalize` in the `preprocess` method. 69 | image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): 70 | Mean to use if normalizing the image. This is a float or list of floats the length of the number of 71 | channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. 72 | image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`): 73 | Standard deviation to use if normalizing the image. This is a float or list of floats the length of the 74 | number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` 75 | method. 76 | do_convert_rgb (`bool`, *optional*, defaults to `True`): 77 | Whether to convert the image to RGB. 
78 | """ 79 | 80 | model_input_names = ["pixel_values"] 81 | 82 | def __init__( 83 | self, 84 | do_resize: bool = True, 85 | size: Optional[Dict[str, int]] = None, 86 | resample: PILImageResampling = PILImageResampling.BICUBIC, 87 | do_rescale: bool = True, 88 | rescale_factor: Union[int, float] = 1 / 255, 89 | do_normalize: bool = True, 90 | image_mean: Optional[Union[float, List[float]]] = None, 91 | image_std: Optional[Union[float, List[float]]] = None, 92 | do_convert_rgb: Optional[bool] = None, 93 | **kwargs, 94 | ) -> None: 95 | super().__init__(**kwargs) 96 | size = size if size is not None else {"height": 224, "width": 224} 97 | image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN 98 | image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD 99 | 100 | self.do_resize = do_resize 101 | self.size = size 102 | self.resample = resample 103 | self.do_rescale = do_rescale 104 | self.rescale_factor = rescale_factor 105 | self.do_normalize = do_normalize 106 | self.image_mean = image_mean 107 | self.image_std = image_std 108 | self.do_convert_rgb = do_convert_rgb 109 | 110 | @filter_out_non_signature_kwargs() 111 | def preprocess( 112 | self, 113 | images: ImageInput, 114 | do_resize: Optional[bool] = None, 115 | size: Optional[Dict[str, int]] = None, 116 | resample: PILImageResampling = None, 117 | do_rescale: Optional[bool] = None, 118 | rescale_factor: Optional[float] = None, 119 | do_normalize: Optional[bool] = None, 120 | image_mean: Optional[Union[float, List[float]]] = None, 121 | image_std: Optional[Union[float, List[float]]] = None, 122 | return_tensors: Optional[Union[str, TensorType]] = None, 123 | data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, 124 | input_data_format: Optional[Union[str, ChannelDimension]] = None, 125 | do_convert_rgb: Optional[bool] = None, 126 | ) -> PIL.Image.Image: 127 | """ 128 | Preprocess an image or batch of images. 129 | 130 | Args: 131 | images (`ImageInput`): 132 | Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If 133 | passing in images with pixel values between 0 and 1, set `do_rescale=False`. 134 | do_resize (`bool`, *optional*, defaults to `self.do_resize`): 135 | Whether to resize the image. 136 | size (`Dict[str, int]`, *optional*, defaults to `self.size`): 137 | Size of the image after resizing. 138 | resample (`int`, *optional*, defaults to `self.resample`): 139 | Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only 140 | has an effect if `do_resize` is set to `True`. 141 | do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): 142 | Whether to rescale the image. 143 | rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): 144 | Rescale factor to rescale the image by if `do_rescale` is set to `True`. 145 | do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): 146 | Whether to normalize the image. 147 | image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): 148 | Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. 149 | image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): 150 | Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to 151 | `True`. 152 | return_tensors (`str` or `TensorType`, *optional*): 153 | The type of tensors to return. Can be one of: 154 | - Unset: Return a list of `np.ndarray`. 
155 | - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. 156 | - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. 157 | - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. 158 | - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. 159 | data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): 160 | The channel dimension format for the output image. Can be one of: 161 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 162 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 163 | - Unset: Use the channel dimension format of the input image. 164 | input_data_format (`ChannelDimension` or `str`, *optional*): 165 | The channel dimension format for the input image. If unset, the channel dimension format is inferred 166 | from the input image. Can be one of: 167 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 168 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 169 | - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 170 | do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): 171 | Whether to convert the image to RGB. 172 | """ 173 | do_resize = do_resize if do_resize is not None else self.do_resize 174 | size = size if size is not None else self.size 175 | size = get_size_dict(size, param_name="size", default_to_square=False) 176 | resample = resample if resample is not None else self.resample 177 | do_rescale = do_rescale if do_rescale is not None else self.do_rescale 178 | rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor 179 | do_normalize = do_normalize if do_normalize is not None else self.do_normalize 180 | image_mean = image_mean if image_mean is not None else self.image_mean 181 | image_std = image_std if image_std is not None else self.image_std 182 | do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb 183 | 184 | images = make_flat_list_of_images(images) 185 | 186 | if not valid_images(images): 187 | raise ValueError( 188 | "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " 189 | "torch.Tensor, tf.Tensor or jax.ndarray." 190 | ) 191 | validate_preprocess_arguments( 192 | do_rescale=do_rescale, 193 | rescale_factor=rescale_factor, 194 | do_normalize=do_normalize, 195 | image_mean=image_mean, 196 | image_std=image_std, 197 | do_resize=do_resize, 198 | size=size, 199 | resample=resample, 200 | ) 201 | if do_convert_rgb: 202 | images = [convert_to_rgb(image) for image in images] 203 | 204 | # All transformations expect numpy arrays. 205 | images = [to_numpy_array(image) for image in images] 206 | 207 | if do_rescale and is_scaled_image(images[0]): 208 | logger.warning_once( 209 | "It looks like you are trying to rescale already rescaled images. If the input" 210 | " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." 211 | ) 212 | 213 | if input_data_format is None: 214 | # We assume that all images have the same channel dimension format. 
215 | input_data_format = infer_channel_dimension_format(images[0]) 216 | 217 | if do_resize: 218 | height, width = size["height"], size["width"] 219 | images = [ 220 | resize(image=image, size=(height, width), resample=resample, input_data_format=input_data_format) 221 | for image in images 222 | ] 223 | 224 | if do_rescale: 225 | images = [ 226 | self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) 227 | for image in images 228 | ] 229 | 230 | if do_normalize: 231 | images = [ 232 | self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) 233 | for image in images 234 | ] 235 | 236 | images = [ 237 | to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images 238 | ] 239 | 240 | data = {"pixel_values": images} 241 | return BatchFeature(data=data, tensor_type=return_tensors) 242 | 243 | 244 | __all__ = ["SiglipImageProcessor"] -------------------------------------------------------------------------------- /utils/siglip/image_processing_siglip_fast.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Fast Image processor class for SigLIP.""" 16 | 17 | from transformers.image_processing_utils_fast import BaseImageProcessorFast 18 | from transformers.image_utils import ( 19 | IMAGENET_STANDARD_MEAN, 20 | IMAGENET_STANDARD_STD, 21 | PILImageResampling, 22 | ) 23 | 24 | class SiglipImageProcessorFast(BaseImageProcessorFast): 25 | resample = PILImageResampling.BICUBIC 26 | image_mean = IMAGENET_STANDARD_MEAN 27 | image_std = IMAGENET_STANDARD_STD 28 | size = {"height": 224, "width": 224} 29 | default_to_square = False 30 | do_resize = True 31 | do_rescale = True 32 | do_normalize = True 33 | 34 | 35 | __all__ = ["SiglipImageProcessorFast"] -------------------------------------------------------------------------------- /utils/siglip/processing_siglip.py: -------------------------------------------------------------------------------- 1 | """ 2 | Image/Text processor class for SigLIP. 3 | """ 4 | 5 | from typing import List, Optional, Union 6 | 7 | from transformers.feature_extraction_utils import BatchFeature 8 | from transformers.image_utils import ImageInput 9 | from transformers.processing_utils import ProcessorMixin 10 | from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy 11 | from transformers.utils import TensorType 12 | 13 | 14 | class SiglipProcessor(ProcessorMixin): 15 | r""" 16 | Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor. 17 | 18 | [`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. 
See the 19 | [`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information. 20 | 21 | Args: 22 | image_processor ([`SiglipImageProcessor`]): 23 | The image processor is a required input. 24 | tokenizer ([`SiglipTokenizer`]): 25 | The tokenizer is a required input. 26 | """ 27 | 28 | attributes = ["image_processor", "tokenizer"] 29 | image_processor_class = ("SiglipImageProcessor", "SiglipImageProcessorFast") 30 | tokenizer_class = "AutoTokenizer" 31 | 32 | def __init__(self, image_processor, tokenizer): 33 | super().__init__(image_processor, tokenizer) 34 | 35 | def __call__( 36 | self, 37 | text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, 38 | images: ImageInput = None, 39 | padding: Union[bool, str, PaddingStrategy] = False, 40 | truncation: Union[bool, str, TruncationStrategy] = None, 41 | max_length: Optional[int] = None, 42 | return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, 43 | ) -> BatchFeature: 44 | """ 45 | Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text` 46 | and `kwargs` arguments to SiglipTokenizer's [`~SiglipTokenizer.__call__`] if `text` is not `None` to encode 47 | the text. To prepare the image(s), this method forwards the `images` argument to 48 | SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring 49 | of the above two methods for more information. 50 | 51 | Args: 52 | text (`str`, `List[str]`, `List[List[str]]`): 53 | The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings 54 | (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set 55 | `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). 56 | images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): 57 | The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch 58 | tensor. Both channels-first and channels-last formats are supported. 59 | padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): 60 | Select a strategy to pad the returned sequences (according to the model's padding side and padding 61 | index) among: 62 | - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 63 | sequence is provided). 64 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum 65 | acceptable input length for the model if that argument is not provided. 66 | - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different 67 | lengths). 68 | max_length (`int`, *optional*): 69 | Maximum length of the returned list and optionally padding length (see above). 70 | truncation (`bool`, *optional*): 71 | Activates truncation to cut input sequences longer than `max_length` to `max_length`. 72 | return_tensors (`str` or [`~utils.TensorType`], *optional*): 73 | If set, will return tensors of a particular framework. Acceptable values are: 74 | 75 | - `'tf'`: Return TensorFlow `tf.constant` objects. 76 | - `'pt'`: Return PyTorch `torch.Tensor` objects. 77 | - `'np'`: Return NumPy `np.ndarray` objects. 78 | - `'jax'`: Return JAX `jnp.ndarray` objects. 
79 | 80 | Returns: 81 | [`BatchFeature`]: A [`BatchFeature`] with the following fields: 82 | 83 | - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. 84 | - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when 85 | `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not 86 | `None`). 87 | - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 88 | """ 89 | 90 | if text is None and images is None: 91 | raise ValueError("You have to specify either text or images. Both cannot be none.") 92 | 93 | if text is not None: 94 | encoding = self.tokenizer( 95 | text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length 96 | ) 97 | 98 | if images is not None: 99 | image_features = self.image_processor(images, return_tensors=return_tensors) 100 | 101 | if text is not None and images is not None: 102 | encoding.update(image_features) 103 | return encoding 104 | elif text is not None: 105 | return encoding 106 | else: 107 | return BatchFeature(data=dict(**image_features), tensor_type=return_tensors) 108 | 109 | def decode(self, *args, **kwargs): 110 | """ 111 | This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to 112 | the docstring of this method for more information. 113 | """ 114 | return self.tokenizer.decode(*args, **kwargs) 115 | 116 | def batch_decode(self, *args, **kwargs): 117 | """ 118 | This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please 119 | refer to the docstring of this method for more information. 120 | """ 121 | return self.tokenizer.batch_decode(*args, **kwargs) 122 | 123 | @property 124 | # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip 125 | def model_input_names(self): 126 | tokenizer_input_names = self.tokenizer.model_input_names 127 | image_processor_input_names = self.image_processor.model_input_names 128 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) 129 | 130 | 131 | __all__ = ["SiglipProcessor"] -------------------------------------------------------------------------------- /utils/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 
4 | """ 5 | import logging 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | try: 12 | import timm 13 | from timm.models.layers import Mlp, to_2tuple 14 | try: 15 | # old timm imports < 0.8.1 16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 18 | except ImportError: 19 | # new timm imports >= 0.8.1 20 | from timm.layers import RotAttentionPool2d 21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 22 | except ImportError: 23 | timm = None 24 | 25 | from utils.misc import freeze_batch_norm_2d 26 | 27 | 28 | class TimmModel(nn.Module): 29 | """ timm model adapter 30 | """ 31 | 32 | def __init__( 33 | self, 34 | model_name, 35 | embed_dim, 36 | image_size=224, 37 | pool='avg', 38 | proj='linear', 39 | proj_bias=False, 40 | drop=0., 41 | drop_path=None, 42 | patch_drop=None, 43 | pretrained=False, 44 | ): 45 | super().__init__() 46 | if timm is None: 47 | raise RuntimeError("Please `pip install timm` to use timm models.") 48 | self.image_size = to_2tuple(image_size) 49 | 50 | # setup kwargs that may not be common across all models 51 | timm_kwargs = {} 52 | if drop_path is not None: 53 | timm_kwargs['drop_path_rate'] = drop_path 54 | if patch_drop is not None: 55 | timm_kwargs['patch_drop_rate'] = patch_drop 56 | 57 | custom_pool = pool in ('abs_attn', 'rot_attn') 58 | if not proj and not custom_pool: 59 | # use network classifier head as projection if no proj specified and no custom pooling used 60 | self.trunk = timm.create_model( 61 | model_name, 62 | num_classes=embed_dim, 63 | global_pool=pool, 64 | pretrained=pretrained, 65 | **timm_kwargs, 66 | ) 67 | prev_chs = embed_dim 68 | else: 69 | self.trunk = timm.create_model( 70 | model_name, 71 | pretrained=pretrained, 72 | **timm_kwargs, 73 | ) 74 | feat_size = self.trunk.default_cfg.get('pool_size', None) 75 | feature_ndim = 1 if not feat_size else 2 76 | if custom_pool: 77 | assert feature_ndim == 2 78 | # if attn pooling used, remove both classifier and default pool 79 | self.trunk.reset_classifier(0, global_pool='') 80 | else: 81 | # reset global pool if pool config set, otherwise leave as network default 82 | reset_kwargs = dict(global_pool=pool) if pool else {} 83 | self.trunk.reset_classifier(0, **reset_kwargs) 84 | prev_chs = self.trunk.num_features 85 | 86 | head_layers = OrderedDict() 87 | 88 | # Add custom pooling to head 89 | if pool == 'abs_attn': 90 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 91 | prev_chs = embed_dim 92 | elif pool == 'rot_attn': 93 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 94 | prev_chs = embed_dim 95 | 96 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 97 | if proj == 'linear': 98 | head_layers['drop'] = nn.Dropout(drop) 99 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 100 | elif proj == 'mlp': 101 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) 102 | else: 103 | assert not proj, f'Unknown projection type {proj}.' 
104 | 105 | self.head = nn.Sequential(head_layers) 106 | 107 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 108 | """ lock modules 109 | Args: 110 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 111 | """ 112 | if not unlocked_groups: 113 | # lock full model 114 | for param in self.trunk.parameters(): 115 | param.requires_grad = False 116 | if freeze_bn_stats: 117 | freeze_batch_norm_2d(self.trunk) 118 | else: 119 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 120 | try: 121 | # FIXME import here until API stable and in an official release 122 | from timm.models.helpers import group_parameters, group_modules 123 | except ImportError: 124 | raise RuntimeError( 125 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 126 | matcher = self.trunk.group_matcher() 127 | gparams = group_parameters(self.trunk, matcher) 128 | max_layer_id = max(gparams.keys()) 129 | max_layer_id = max_layer_id - unlocked_groups 130 | for group_idx in range(max_layer_id + 1): 131 | group = gparams[group_idx] 132 | for param in group: 133 | self.trunk.get_parameter(param).requires_grad = False 134 | if freeze_bn_stats: 135 | gmodules = group_modules(self.trunk, matcher, reverse=True) 136 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 137 | freeze_batch_norm_2d(self.trunk, gmodules) 138 | 139 | @torch.jit.ignore 140 | def set_grad_checkpointing(self, enable=True): 141 | try: 142 | self.trunk.set_grad_checkpointing(enable) 143 | except Exception as e: 144 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') 145 | 146 | def forward(self, x): 147 | x = self.trunk(x) 148 | x = self.head(x) 149 | return x -------------------------------------------------------------------------------- /utils/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | # https://stackoverflow.com/q/62691279 16 | import os 17 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 18 | 19 | 20 | @lru_cache() 21 | def default_bpe(): 22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "vocab/bpe_simple_vocab_16e6.txt.gz") 23 | 24 | 25 | @lru_cache() 26 | def bytes_to_unicode(): 27 | """ 28 | Returns list of utf-8 byte and a corresponding list of unicode strings. 29 | The reversible bpe codes work on unicode strings. 30 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 31 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 32 | This is a significant percentage of your normal, say, 32K bpe vocab. 33 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 34 | And avoids mapping to whitespace/control characters the bpe code barfs on. 
35 | """ 36 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 37 | cs = bs[:] 38 | n = 0 39 | for b in range(2**8): 40 | if b not in bs: 41 | bs.append(b) 42 | cs.append(2**8+n) 43 | n += 1 44 | cs = [chr(n) for n in cs] 45 | return dict(zip(bs, cs)) 46 | 47 | 48 | def get_pairs(word): 49 | """Return set of symbol pairs in a word. 50 | Word is represented as tuple of symbols (symbols being variable-length strings). 51 | """ 52 | pairs = set() 53 | prev_char = word[0] 54 | for char in word[1:]: 55 | pairs.add((prev_char, char)) 56 | prev_char = char 57 | return pairs 58 | 59 | 60 | def basic_clean(text): 61 | text = ftfy.fix_text(text) 62 | text = html.unescape(html.unescape(text)) 63 | return text.strip() 64 | 65 | 66 | def whitespace_clean(text): 67 | text = re.sub(r'\s+', ' ', text) 68 | text = text.strip() 69 | return text 70 | 71 | 72 | class SimpleTokenizer(object): 73 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 74 | self.byte_encoder = bytes_to_unicode() 75 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 76 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 77 | merges = merges[1:49152-256-2+1] 78 | merges = [tuple(merge.split()) for merge in merges] 79 | vocab = list(bytes_to_unicode().values()) 80 | vocab = vocab + [v+'</w>' for v in vocab] 81 | for merge in merges: 82 | vocab.append(''.join(merge)) 83 | if not special_tokens: 84 | special_tokens = ['<start_of_text>', '<end_of_text>'] 85 | else: 86 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens 87 | vocab.extend(special_tokens) 88 | self.encoder = dict(zip(vocab, range(len(vocab)))) 89 | self.decoder = {v: k for k, v in self.encoder.items()} 90 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 91 | self.cache = {t:t for t in special_tokens} 92 | special = "|".join(special_tokens) 93 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 94 | 95 | self.vocab_size = len(self.encoder) 96 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 97 | 98 | def bpe(self, token): 99 | if token in self.cache: 100 | return self.cache[token] 101 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 102 | pairs = get_pairs(word) 103 | 104 | if not pairs: 105 | return token+'</w>' 106 | 107 | while True: 108 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 109 | if bigram not in self.bpe_ranks: 110 | break 111 | first, second = bigram 112 | new_word = [] 113 | i = 0 114 | while i < len(word): 115 | try: 116 | j = word.index(first, i) 117 | new_word.extend(word[i:j]) 118 | i = j 119 | except: 120 | new_word.extend(word[i:]) 121 | break 122 | 123 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 124 | new_word.append(first+second) 125 | i += 2 126 | else: 127 | new_word.append(word[i]) 128 | i += 1 129 | new_word = tuple(new_word) 130 | word = new_word 131 | if len(word) == 1: 132 | break 133 | else: 134 | pairs = get_pairs(word) 135 | word = ' '.join(word) 136 | self.cache[token] = word 137 | return word 138 | 139 | def encode(self, text): 140 | bpe_tokens = [] 141 | text = whitespace_clean(basic_clean(text)).lower() 142 | for token in re.findall(self.pat, text): 143 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 144 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 145 | return bpe_tokens 146 | 147 | def decode(self, tokens): 148 | text = ''.join([self.decoder[token] for token in tokens])
149 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 150 | return text 151 | 152 | 153 | _tokenizer = SimpleTokenizer() 154 | 155 | def decode(output_ids: torch.Tensor): 156 | output_ids = output_ids.cpu().numpy() 157 | return _tokenizer.decode(output_ids) 158 | 159 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 160 | """ 161 | Returns the tokenized representation of given input string(s) 162 | 163 | Parameters 164 | ---------- 165 | texts : Union[str, List[str]] 166 | An input string or a list of input strings to tokenize 167 | context_length : int 168 | The context length to use; all CLIP models use 77 as the context length 169 | 170 | Returns 171 | ------- 172 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 173 | """ 174 | if isinstance(texts, str): 175 | texts = [texts] 176 | 177 | sot_token = _tokenizer.encoder["<start_of_text>"] 178 | eot_token = _tokenizer.encoder["<end_of_text>"] 179 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 180 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 181 | 182 | for i, tokens in enumerate(all_tokens): 183 | if len(tokens) > context_length: 184 | tokens = tokens[:context_length] # Truncate 185 | tokens[-1] = eot_token 186 | result[i, :len(tokens)] = torch.tensor(tokens) 187 | 188 | return result 189 | 190 | 191 | class HFTokenizer: 192 | """HuggingFace tokenizer wrapper""" 193 | 194 | def __init__(self, tokenizer_name: str): 195 | from transformers import AutoTokenizer 196 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 197 | 198 | def save_pretrained(self, dest): 199 | self.tokenizer.save_pretrained(dest) 200 | 201 | def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: 202 | # same cleaning as for default tokenizer, except lowercasing 203 | # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance 204 | if isinstance(texts, str): 205 | texts = [texts] 206 | texts = [whitespace_clean(basic_clean(text)) for text in texts] 207 | input_ids = self.tokenizer( 208 | texts, 209 | return_tensors='pt', 210 | max_length=context_length, 211 | padding='max_length', 212 | truncation=True, 213 | ).input_ids 214 | return input_ids -------------------------------------------------------------------------------- /utils/transform.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from dataclasses import dataclass, asdict 3 | from typing import Any, Dict, Optional, Sequence, Tuple, Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torchvision.transforms.functional as F 8 | 9 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 10 | CenterCrop 11 | 12 | from utils.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 13 | 14 | 15 | @dataclass 16 | class AugmentationCfg: 17 | scale: Tuple[float, float] = (0.9, 1.0) 18 | ratio: Optional[Tuple[float, float]] = None 19 | color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None 20 | interpolation: Optional[str] = None 21 | re_prob: Optional[float] = None 22 | re_count: Optional[int] = None 23 | use_timm: bool = False 24 | 25 | 26 | class ResizeMaxSize(nn.Module): 27 | 28 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): 29 | 
super().__init__() 30 | if not isinstance(max_size, int): 31 | raise TypeError(f"Size should be int. Got {type(max_size)}") 32 | self.max_size = max_size 33 | self.interpolation = interpolation 34 | self.fn = min if fn == 'min' else min 35 | self.fill = fill 36 | 37 | def forward(self, img): 38 | if isinstance(img, torch.Tensor): 39 | height, width = img.shape[:2] 40 | else: 41 | width, height = img.size 42 | scale = self.max_size / float(max(height, width)) 43 | if scale != 1.0: 44 | new_size = tuple(round(dim * scale) for dim in (height, width)) 45 | img = F.resize(img, new_size, self.interpolation) 46 | pad_h = self.max_size - new_size[0] 47 | pad_w = self.max_size - new_size[1] 48 | img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) 49 | return img 50 | 51 | 52 | def _convert_to_rgb(image): 53 | return image.convert('RGB') 54 | 55 | 56 | def image_transform( 57 | image_size: int, 58 | is_train: bool, 59 | mean: Optional[Tuple[float, ...]] = None, 60 | std: Optional[Tuple[float, ...]] = None, 61 | resize_longest_max: bool = False, 62 | fill_color: int = 0, 63 | aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, 64 | ): 65 | mean = mean or OPENAI_DATASET_MEAN 66 | if not isinstance(mean, (list, tuple)): 67 | mean = (mean,) * 3 68 | 69 | std = std or OPENAI_DATASET_STD 70 | if not isinstance(std, (list, tuple)): 71 | std = (std,) * 3 72 | 73 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 74 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 75 | image_size = image_size[0] 76 | 77 | if isinstance(aug_cfg, dict): 78 | aug_cfg = AugmentationCfg(**aug_cfg) 79 | else: 80 | aug_cfg = aug_cfg or AugmentationCfg() 81 | normalize = Normalize(mean=mean, std=std) 82 | if is_train: 83 | aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} 84 | use_timm = aug_cfg_dict.pop('use_timm', False) 85 | if use_timm: 86 | from timm.data import create_transform # timm can still be optional 87 | if isinstance(image_size, (tuple, list)): 88 | assert len(image_size) >= 2 89 | input_size = (3,) + image_size[-2:] 90 | else: 91 | input_size = (3, image_size, image_size) 92 | # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time 93 | aug_cfg_dict.setdefault('interpolation', 'random') 94 | aug_cfg_dict.setdefault('color_jitter', None) # disable by default 95 | train_transform = create_transform( 96 | input_size=input_size, 97 | is_training=True, 98 | hflip=0., 99 | mean=mean, 100 | std=std, 101 | re_mode='pixel', 102 | **aug_cfg_dict, 103 | ) 104 | else: 105 | train_transform = Compose([ 106 | RandomResizedCrop( 107 | image_size, 108 | scale=aug_cfg_dict.pop('scale'), 109 | interpolation=InterpolationMode.BICUBIC, 110 | ), 111 | _convert_to_rgb, 112 | ToTensor(), 113 | normalize, 114 | ]) 115 | if aug_cfg_dict: 116 | warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') 117 | return train_transform 118 | else: 119 | if resize_longest_max: 120 | transforms = [ 121 | ResizeMaxSize(image_size, fill=fill_color) 122 | ] 123 | else: 124 | transforms = [ 125 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 126 | CenterCrop(image_size), 127 | ] 128 | transforms.extend([ 129 | _convert_to_rgb, 130 | ToTensor(), 131 | normalize, 132 | ]) 133 | return Compose(transforms) -------------------------------------------------------------------------------- 
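
For reference, a minimal usage sketch of the `image_transform` helper defined in `utils/transform.py` above (not part of the repository; the 224-pixel resolution is an assumption and should match the model configuration you load):

```python
# Build the standard eval-time preprocessing pipeline:
# Resize -> CenterCrop -> convert to RGB -> ToTensor -> Normalize.
from PIL import Image

from utils.transform import image_transform

preprocess_val = image_transform(
    image_size=224,  # assumed input resolution
    is_train=False,  # evaluation path, no augmentation
    mean=None,       # None falls back to OPENAI_DATASET_MEAN
    std=None,        # None falls back to OPENAI_DATASET_STD
)

img = Image.open("images/catdog.png")     # sample image shipped with the repo
batch = preprocess_val(img).unsqueeze(0)  # tensor of shape [1, 3, 224, 224]
```
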
/utils/visualization.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torchvision import transforms 3 | 4 | 5 | def _convert_to_rgb(image): 6 | return image.convert("RGB") 7 | 8 | 9 | visualization_preprocess = transforms.Compose( 10 | [ 11 | transforms.Resize(size=224, interpolation=Image.BICUBIC), 12 | transforms.CenterCrop(size=(224, 224)), 13 | _convert_to_rgb, 14 | ] 15 | ) 16 | 17 | 18 | def image_grid(imgs, rows, cols): 19 | assert len(imgs) == rows * cols 20 | 21 | w, h = imgs[0].size 22 | grid = Image.new("RGB", size=(cols * w, rows * h)) 23 | grid_w, grid_h = grid.size 24 | 25 | for i, img in enumerate(imgs): 26 | grid.paste(img, box=(i % cols * w, i // cols * h)) 27 | return grid 28 | -------------------------------------------------------------------------------- /utils/vocab/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yossigandelsman/clip_text_span/49f70f31bb13437a870ff8de340626b225500a22/utils/vocab/bpe_simple_vocab_16e6.txt.gz --------------------------------------------------------------------------------
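
For reference, a minimal usage sketch of the vendored CLIP tokenizer in `utils/tokenizer.py` (not part of the repository; the prompt strings are arbitrary examples):

```python
# Tokenize a batch of prompts into fixed-length (77-token) CLIP input ids.
from utils.tokenizer import tokenize

tokens = tokenize(["a photo of a cat", "a photo of a dog"])
print(tokens.shape)  # torch.Size([2, 77]); sequences are zero-padded to the context length
```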