├── .gitignore ├── README.md ├── configs ├── base.yaml └── models │ ├── idefics2.yaml │ ├── llava.yaml │ ├── mantis.yaml │ ├── mistral.yaml │ └── vicuna.yaml ├── data └── examples │ ├── image_icl │ ├── Mr_Krabs.png │ ├── Mrs_Puff.png │ └── Sandy_Cheeks.png │ ├── instruction │ └── Swamp_Wallaby.png │ ├── task_conflict │ └── VQA_Food.png │ └── text_icl │ └── France.png ├── demo.ipynb ├── download_data.sh ├── environment.yml ├── experiment_base.sh ├── experiment_ensemble.sh └── src ├── __init__.py ├── xpatch_dataset.py ├── xpatch_evaluate.py └── xpatch_helpers.py /.gitignore: -------------------------------------------------------------------------------- 1 | **__pycache__** 2 | runs/ 3 | data/ 4 | out/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vision-Language Models Create Cross-Modal Task Representations 2 | 3 | **Grace Luo, Trevor Darrell, Amir Bar** 4 | 5 | This repository contains the data and code for the paper "Vision-Language Models Create Cross-Modal Task Representations." 6 | 7 | [[`Project Page`](https://vlm-cross-modal-reps.github.io)][[`arXiv`](https://arxiv.org/abs/2410.22330)] 8 | 9 | ## Releases 10 | - ✍️ 2025/05/01: Update project title from "Task Vectors are Cross-Modal" 11 | - 🚀 2024/10/29: Initial codebase release 12 | 13 | ## Setup 14 | This code was tested with Python 3.8. Please run the following to install the required packages and authenticate with HuggingFace, which is necessary to download some of the models. 15 | ``` 16 | conda env create -f environment.yml 17 | conda activate xpatch 18 | huggingface-cli login 19 | ``` 20 | 21 | ## Data 22 | All data can be found on our [HuggingFace page](https://huggingface.co/datasets/g-luo/vlm-cross-modal-reps/tree/main). To download and set up the data, run the following script: 23 | ``` 24 | ./download_data.sh 25 | ``` 26 | 27 | ## Demo 28 | In `demo.ipynb`, we walk through a few qualitative examples that illustrate the cross-modal nature of task vectors. Specifically, the notebook demonstrates: 29 | 30 | - **Cross-Modal Transfer.** Task vectors can be derived from text ICL examples, instructions, and image ICL examples and transferred to queries in another modality. 31 | - **LLM to VLM Transfer.** Task vectors can also be patched from the base LLM to its corresponding fine-tuned VLM. 32 | - **Vector Ensembling.** Instructions can improve the sample efficiency of text ICL when their corresponding task vectors are averaged. 33 | - **Task Conflict.** When one task is provided in the prompt and a conflicting one is patched as a task vector, the model completes one of those two tasks. 34 | 35 | ## Experiments 36 | To evaluate cross-modal patching on our six tasks, run the following scripts. Once the scripts are finished running, you can find per-task and average accuracy statistics saved as CSVs under `runs/experiments/`. 37 | 38 | - **Cross-Modal Transfer.** Run the script `./experiment_base.sh`. By default, the script is set to cross-modal patching from text ICL to image queries. The arguments to the script correspond to `feats_model` and `patch_model`, or the model to extract features from and the model to patch those features to; the expected output layout and example invocations are shown below. 
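For reference, with the default arguments the outputs of a run land roughly as follows (a sketch assuming the default `spec=icl`, `src_modality=text`, `tgt_modality=image` setting from `experiment_base.sh`; the exact path templates are defined in `configs/base.yaml`):

```
runs/
├── feats/icl_text-image/{feats_model}/val/
│   └── task-{task}_seed-{seed}.pt                 # cached task vectors
└── experiments/icl_text-image/feats-{feats_model}_patch-{patch_model}/
    ├── val/task-{task}_seed-{seed}.json           # per-layer val generations
    ├── test/task-{task}_seed-{seed}.json          # per-layer test generations
    ├── val.csv                                    # accuracy summaries
    └── test.csv
```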
39 | 40 | ``` 41 | # Default VLM Transfer 42 | ./experiment_base.sh idefics2 idefics2 43 | ./experiment_base.sh llava llava 44 | ./experiment_base.sh mantis mantis 45 | 46 | # LLM to VLM Transfer 47 | ./experiment_base.sh mistral idefics2 48 | ./experiment_base.sh vicuna llava 49 | ``` 50 | 51 | - **Vector Ensembling.** Run the script `./experiment_ensemble.sh`. The script evaluates the scaling properties of task vectors derived from text ICL examples, ranging from 5 to 30 examples. It also demonstrates the performance when averaging with instruction-based task vectors, where the instructions are defined in `configs/base.yaml`. The argument to the script corresponds to `ensemble`, i.e., whether to run only exemplar-based vectors or ensemble them with instruction-based vectors. 52 | 53 | ``` 54 | ./experiment_ensemble.sh false 55 | ./experiment_ensemble.sh true 56 | ``` 57 | 58 | ## Citing 59 | ``` 60 | @inproceedings{luo2025vlm, 61 | title={Vision-Language Models Create Cross-Modal Task Representations}, 62 | author={Grace Luo and Trevor Darrell and Amir Bar}, 63 | booktitle={ICML}, 64 | year={2025} 65 | } 66 | ``` 67 | 68 | ## Acknowledgements 69 | This codebase was implemented from scratch, inspired by design patterns and conventions from [Task Vectors](https://github.com/roeehendel/icl_task_vectors) and [Function Vectors](https://github.com/ericwtodd/function_vectors). -------------------------------------------------------------------------------- /configs/base.yaml: -------------------------------------------------------------------------------- 1 | # ======= Tasks ======= 2 | num_seeds: 3 3 | tasks: 4 | - country-capital 5 | - country-currency 6 | - animal-latin 7 | - animal-young 8 | - food-color 9 | - food-flavor 10 | instructions: 11 | - "The capital city of the country:" 12 | - "The last word of the official currency of the country:" 13 | - "The scientific name of the animal's species in latin:" 14 | - "The term for the baby of the animal:" 15 | - "The color of the food:" 16 | - "The flavor descriptor of the food:" 17 | # ======= Dataset ======= 18 | data_folder: data/annotations 19 | dataset_kwargs: 20 | src_modality: ${src_modality} 21 | tgt_modality: ${tgt_modality} 22 | include_regular: True 23 | # ====== Generate ======= 24 | generate_kwargs: 25 | max_new_tokens: 1 26 | do_sample: False 27 | # ======= Saving ======== 28 | exp: ${spec}_${src_modality}-${tgt_modality} 29 | feats_folder: runs/feats/${exp}/${feats_model}/val 30 | save_folder: runs/experiments/${exp}/feats-${feats_model}_patch-${patch_model}/${mode} 31 | # ======================== -------------------------------------------------------------------------------- /configs/models/idefics2.yaml: -------------------------------------------------------------------------------- 1 | # ======== Model ======== 2 | model_id: HuggingFaceM4/idefics2-8b 3 | model_revision: 2c42686c57fe21cf0348c9ce1077d094b72e7698 4 | patch_L: 16 5 | # ======================= -------------------------------------------------------------------------------- /configs/models/llava.yaml: -------------------------------------------------------------------------------- 1 | # ======== Model ======== 2 | model_id: llava-hf/llava-1.5-7b-hf 3 | model_revision: a272c74b2481d8aff3aa6fc2c4bf891fe57334fb 4 | patch_L: 15 5 | # ======= Dataset ======= 6 | dataset_kwargs: 7 | # !!! IMPORTANT !!! 
8 | # llava adds image tokens before text 9 | # but doesn't update input_ids 10 | image_offset: 575 11 | # ======================= -------------------------------------------------------------------------------- /configs/models/mantis.yaml: -------------------------------------------------------------------------------- 1 | # ======== Model ======== 2 | model_id: TIGER-Lab/Mantis-8B-Fuyu 3 | model_revision: ea76e7446341b399d9af2d7af390abb83fda6b28 4 | patch_L: 23 5 | # ======= Dataset ======= 6 | dataset_kwargs: 7 | # !!! IMPORTANT !!! 8 | # mantis adds "image 0:" to 9 | # delimit the images 10 | src_match: -4 11 | # ======================= -------------------------------------------------------------------------------- /configs/models/mistral.yaml: -------------------------------------------------------------------------------- 1 | # ======== Model ======== 2 | model_id: mistralai/Mistral-7B-v0.1 3 | model_revision: 7231864981174d9bee8c7687c24c8344414eae6b 4 | # ======================= -------------------------------------------------------------------------------- /configs/models/vicuna.yaml: -------------------------------------------------------------------------------- 1 | # ======== Model ======== 2 | model_id: lmsys/vicuna-7b-v1.5 3 | model_revision: 3321f76e3f527bd14065daf69dad9344000a201d 4 | # ======================= -------------------------------------------------------------------------------- /data/examples/image_icl/Mr_Krabs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g-luo/vlm_cross_modal_reps/0cc55c3dbeb677f1585b8f18b2aea8e2391ec9ce/data/examples/image_icl/Mr_Krabs.png -------------------------------------------------------------------------------- /data/examples/image_icl/Mrs_Puff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g-luo/vlm_cross_modal_reps/0cc55c3dbeb677f1585b8f18b2aea8e2391ec9ce/data/examples/image_icl/Mrs_Puff.png -------------------------------------------------------------------------------- /data/examples/image_icl/Sandy_Cheeks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g-luo/vlm_cross_modal_reps/0cc55c3dbeb677f1585b8f18b2aea8e2391ec9ce/data/examples/image_icl/Sandy_Cheeks.png -------------------------------------------------------------------------------- /data/examples/instruction/Swamp_Wallaby.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g-luo/vlm_cross_modal_reps/0cc55c3dbeb677f1585b8f18b2aea8e2391ec9ce/data/examples/instruction/Swamp_Wallaby.png -------------------------------------------------------------------------------- /data/examples/task_conflict/VQA_Food.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g-luo/vlm_cross_modal_reps/0cc55c3dbeb677f1585b8f18b2aea8e2391ec9ce/data/examples/task_conflict/VQA_Food.png -------------------------------------------------------------------------------- /data/examples/text_icl/France.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g-luo/vlm_cross_modal_reps/0cc55c3dbeb677f1585b8f18b2aea8e2391ec9ce/data/examples/text_icl/France.png -------------------------------------------------------------------------------- /download_data.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | data_root=data 3 | repo=https://huggingface.co/datasets/g-luo/vlm-cross-modal-reps 4 | 5 | zip_name=annotations 6 | wget -P ${data_root} -O ${data_root}/${zip_name}.zip ${repo}/resolve/main/data/${zip_name}.zip?download=true 7 | unzip ${data_root}/${zip_name}.zip -d ${data_root} 8 | rm -rf ${data_root}/${zip_name}.zip 9 | 10 | zip_name=images 11 | wget -P ${data_root} -O ${data_root}/${zip_name}.zip ${repo}/resolve/main/data/${zip_name}.zip?download=true 12 | unzip ${data_root}/${zip_name}.zip -d ${data_root} 13 | rm -rf ${data_root}/${zip_name}.zip -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: xpatch 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - python=3.8.19 8 | - cudatoolkit=11.7.0 9 | - pytorch=2.0.1 10 | - pip: 11 | - git+https://github.com/davidbau/baukit@main#egg=baukit 12 | - git+https://github.com/TIGER-AI-Lab/Mantis.git 13 | - accelerate==0.33.0 14 | - bitsandbytes==0.43.1 15 | - datasets==2.18.0 16 | - huggingface-hub==0.23.4 17 | - matplotlib==3.7.5 18 | - numpy==1.24.4 19 | - omegaconf==2.1.1 20 | - pandas==2.0.3 21 | - protobuf==3.20.3 22 | - scikit-learn==1.3.2 23 | - seaborn==0.13.2 24 | - sentencepiece==0.2.0 25 | - transformers==4.44.2 26 | - tqdm==4.66.4 -------------------------------------------------------------------------------- /experiment_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # =========== CLI Args ======== 3 | feats_model=$1 4 | patch_model=$2 5 | spec=icl 6 | src_modality=text 7 | tgt_modality=image 8 | patch_L=-1 9 | config_name=configs/base.yaml 10 | # ============================= 11 | 12 | cli_args="feats_model=${feats_model} patch_model=${patch_model} spec=${spec} src_modality=${src_modality} tgt_modality=${tgt_modality}" 13 | 14 | # Feature Caching 15 | python3 src/xpatch_evaluate.py ${config_name} save_feats=true mode=val ${cli_args} 16 | 17 | # Validation 18 | python3 src/xpatch_evaluate.py ${config_name} save_feats=false mode=val ${cli_args} 19 | 20 | # Test 21 | python3 src/xpatch_evaluate.py ${config_name} patch_L=${patch_L} save_feats=false mode=test ${cli_args} -------------------------------------------------------------------------------- /experiment_ensemble.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # ======== CLI Args =========== 3 | ensemble=$1 4 | feats_model=idefics2 5 | patch_model=${feats_model} 6 | src_modality=text 7 | tgt_modality=image 8 | config_name=configs/base.yaml 9 | # ============================= 10 | 11 | cli_args="feats_model=${feats_model} patch_model=${patch_model} src_modality=${src_modality} tgt_modality=${tgt_modality}" 12 | 13 | # Instruction Feature Caching 14 | instruction_args="num_seeds=1 exp=instruction_${src_modality}-${tgt_modality}_scaling" 15 | python3 src/xpatch_evaluate.py ${config_name} spec=instruction save_feats=true mode=val ${cli_args} ${instruction_args} 16 | 17 | # n=0 18 | python3 src/xpatch_evaluate.py ${config_name} spec=instruction save_feats=false mode=test allow_lower_case=true ${cli_args} ${instruction_args} 19 | 20 | # ICL Feature Caching 21 | feats_all_args="save_feats_all=true exp=icl_${src_modality}-${tgt_modality}_scaling" 22 | python3 src/xpatch_evaluate.py ${config_name} spec=icl 
save_feats=true mode=val ${cli_args} ${feats_all_args} 23 | 24 | # n=5 to 30 25 | for r in {1..6} 26 | do 27 | # Each row r in feats is composed of 5 ICL examples 28 | n=$((r * 5)) 29 | save_folder=runs/experiments/icl_${src_modality}-${tgt_modality}_scaling/feats-${feats_model}_patch-${patch_model}/test/n-${n}_ensemble-${ensemble} 30 | scaling_args="save_feats_subset=${r} save_folder=${save_folder}" 31 | python3 src/xpatch_evaluate.py ${config_name} spec=icl ensemble=${ensemble} save_feats=false mode=test allow_lower_case=true ${cli_args} ${feats_all_args} ${scaling_args} 32 | done -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/g-luo/vlm_cross_modal_reps/0cc55c3dbeb677f1585b8f18b2aea8e2391ec9ce/src/__init__.py -------------------------------------------------------------------------------- /src/xpatch_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from PIL import Image 4 | import torch 5 | 6 | import xpatch_helpers 7 | 8 | def open_image(image_path, image_size): 9 | image = Image.open(image_path) 10 | image = image.convert("RGB") 11 | width, height = image.size 12 | new_width = image_size 13 | new_height = int((image_size / width) * height) 14 | image = image.resize((new_width, new_height)) 15 | return image 16 | 17 | def format_prompt(samples, modality, remove_last=True, special_token=":", image_root="", image_size=224): 18 | prompt, images = "", [] 19 | query_fn, answer_fn = lambda x: f"Q{special_token}{x}\nA{special_token}", lambda x: f"{x}\n\n" 20 | for i, sample in enumerate(samples): 21 | if modality == "image": 22 | user = "<image>"  # image placeholder expanded by the model's processor 23 | image_path = f"{image_root}/{sample['image']}" 24 | image = open_image(image_path, image_size) 25 | images.append(image) 26 | elif modality == "text": 27 | user = sample["input"] 28 | else: 29 | raise NotImplementedError 30 | prompt += query_fn(user) 31 | asst = sample.get("output", "") 32 | if not remove_last or (i != len(samples) - 1): 33 | prompt += answer_fn(asst) 34 | return prompt, images 35 | 36 | class xPatchDataset(torch.utils.data.Dataset): 37 | def __init__( 38 | self, 39 | annotation_file, 40 | processor, 41 | num_examples=5, 42 | src_modality="text", 43 | tgt_modality="image", 44 | special_token=":", 45 | image_root="data", 46 | image_size=224, 47 | seed=1, 48 | include_regular=True, 49 | src_match=None, 50 | tgt_match=None, 51 | image_offset=0 52 | ): 53 | super().__init__() 54 | 55 | np.random.seed(seed) 56 | self.data = json.load(open(annotation_file)) 57 | self.processor = processor 58 | self.num_examples = num_examples 59 | self.src_modality = src_modality 60 | self.tgt_modality = tgt_modality 61 | self.image_root = image_root 62 | self.image_size = image_size 63 | self.include_regular = include_regular 64 | self.image_offset = image_offset 65 | self.random_state = {} 66 | 67 | if type(self.data) is dict: 68 | self.data = self.data["data"] 69 | 70 | # the processor fails when the first prompt doesn't contain an image 71 | if src_modality == "text" and tgt_modality == "image": 72 | self.src_idx, self.tgt_idx = 1, 0 73 | else: 74 | self.src_idx, self.tgt_idx = 0, 1 75 | 76 | if src_match is None: 77 | self.src_match = -3 if (self.include_regular and num_examples > 0) else -1 78 | else: 79 | self.src_match = src_match 80 | if tgt_match is None: 81 | self.tgt_match = -1 82 | else: 83 | self.tgt_match = 
tgt_match 84 | 85 | self.special_token = special_token 86 | self.find_token_kwargs = { 87 | "token": self.special_token, 88 | "token_ids": [processor.tokenizer(f"A{self.special_token}")["input_ids"][-1]] 89 | } 90 | 91 | def get_text_images(self, idx, train_idxs): 92 | test_sample = self.data[idx] 93 | test_text, test_images = format_prompt( 94 | [test_sample], 95 | self.tgt_modality, 96 | remove_last=True, 97 | special_token=self.special_token, 98 | image_root=self.image_root, 99 | image_size=self.image_size 100 | ) 101 | 102 | # Select random ICL examples 103 | train_samples = [self.data[i] for i in train_idxs] 104 | train_text, train_images = format_prompt( 105 | train_samples, 106 | self.src_modality, 107 | remove_last=False, 108 | special_token=self.special_token, 109 | image_root=self.image_root, 110 | image_size=self.image_size 111 | ) 112 | 113 | # We copy the task vector only 114 | # from an unrelated ICL example (match = -3) 115 | # Run 1 = ICL + test sample 116 | # Run 2 = ZS test sample 117 | # text = [train_text, test_text] 118 | # images = [train_images, test_images] 119 | 120 | # Set src and tgt 121 | if self.include_regular: 122 | train_text += test_text 123 | train_images += test_images 124 | text = [None, None] 125 | images = [None, None] 126 | text[self.src_idx] = train_text 127 | text[self.tgt_idx] = test_text 128 | images[self.src_idx] = train_images 129 | images[self.tgt_idx] = test_images 130 | 131 | if len(sum(images, [])) == 0: 132 | images = None 133 | 134 | return text, images 135 | 136 | def get_meta(self, meta, batch, flag, find_idx=None): 137 | idx = getattr(self, f"{flag}_idx") 138 | match = getattr(self, f"{flag}_match") 139 | if find_idx is None: 140 | find_idx = idx 141 | token = xpatch_helpers.find_token( 142 | self.processor, 143 | batch["input_ids"][find_idx], 144 | match=match, 145 | **self.find_token_kwargs 146 | ) 147 | if getattr(self, f"{flag}_modality") == "image": 148 | token += self.image_offset 149 | meta[f"{flag}_token"] = token 150 | meta[f"{flag}_idx"] = idx 151 | return meta 152 | 153 | def get_image_or_none(self, images, idx): 154 | if images is None: 155 | return None 156 | images = images[idx] 157 | if len(images) == 0: 158 | return None 159 | else: 160 | return images 161 | 162 | def __len__(self): 163 | return len(self.data) 164 | 165 | def __getitem__(self, idx, flag=None): 166 | # since __getitem__ may be called twice 167 | # ensure that train_idxs are the same 168 | if idx in self.random_state: 169 | np.random.set_state(self.random_state[idx]) 170 | else: 171 | self.random_state[idx] = np.random.get_state() 172 | all_idxs = [i for i in range(len(self.data)) if i != idx] 173 | train_idxs = np.random.permutation(all_idxs)[:self.num_examples].tolist() 174 | 175 | meta = {} 176 | text, images = self.get_text_images(idx, train_idxs) 177 | if flag == "src": 178 | batch = self.processor(text=text[self.src_idx], images=self.get_image_or_none(images, self.src_idx), return_tensors="pt") 179 | meta = self.get_meta(meta, batch, "src", 0) 180 | elif flag == "tgt": 181 | batch = self.processor(text=text[self.tgt_idx], images=self.get_image_or_none(images, self.tgt_idx), return_tensors="pt") 182 | meta = self.get_meta(meta, batch, "tgt", 0) 183 | else: 184 | batch = self.processor(text=text, images=images, padding=True, return_tensors="pt") 185 | meta = self.get_meta(meta, batch, "src") 186 | meta = self.get_meta(meta, batch, "tgt") 187 | meta["train_idxs"] = train_idxs 188 | meta["labels"] = self.processor(text=self.data[idx]["output"], 
return_tensors="pt")["input_ids"] 189 | return batch, meta -------------------------------------------------------------------------------- /src/xpatch_evaluate.py: -------------------------------------------------------------------------------- 1 | from baukit import TraceDict 2 | import copy 3 | import json 4 | import numpy as np 5 | from omegaconf import OmegaConf 6 | import os 7 | import pandas as pd 8 | import sys 9 | import torch 10 | from tqdm import tqdm 11 | 12 | import xpatch_helpers, xpatch_dataset 13 | 14 | SEED_START = 1 15 | 16 | # =========================== 17 | # Feature Loading 18 | # =========================== 19 | def get_file(folder, task=None, seed=None, filetype=None): 20 | file = [] 21 | if task is not None: 22 | file += [f"task-{task}"] 23 | if seed is not None: 24 | file += [f"seed-{seed}"] 25 | file = "_".join(file) 26 | if filetype: 27 | file += filetype 28 | return f"{folder}/{file}" 29 | 30 | def open_feats(feats_file): 31 | if not os.path.exists(feats_file): 32 | raise Exception(f"File {feats_file} does not exist!") 33 | else: 34 | print(f"Loading feats, {feats_file}") 35 | feats = torch.load(feats_file) 36 | feats = feats.detach().cpu() 37 | return feats 38 | 39 | def postprocess_feats(feats, config, seed): 40 | if config.get("save_feats_all", False): 41 | save_feats_subset = config.get("save_feats_subset") 42 | if save_feats_subset: 43 | np.random.seed(seed) 44 | random_idxs = np.random.permutation(range(feats.shape[0]))[:save_feats_subset] 45 | feats = feats[random_idxs] 46 | feats = feats.mean(dim=0) 47 | return feats 48 | 49 | def get_feats(config, task, seed): 50 | feats_file = get_file(config['feats_folder'], task, seed, ".pt") 51 | feats = open_feats(feats_file) 52 | feats = postprocess_feats(feats, config, seed) 53 | # Ensemble instruction feats (optional) 54 | if config.get("ensemble"): 55 | ensemble_feats_folder = config['feats_folder'].replace("icl", "instruction") 56 | ensemble_feats_file = get_file(ensemble_feats_folder, task, SEED_START, ".pt") 57 | ensemble_feats = open_feats(ensemble_feats_file) 58 | feats = (feats + ensemble_feats) / 2 59 | return feats 60 | 61 | # =========================== 62 | # Feature Saving 63 | # =========================== 64 | def embed_prompt(model, processor, text_model, model_config, config, text, images=None): 65 | # Assumes the special token is the last one 66 | src_token = -1 67 | src_batch = processor(text=text, images=images, return_tensors="pt") 68 | with TraceDict(text_model, layers=model_config["layer_hook_names"], retain_output=True) as ret: 69 | model.forward(**src_batch) 70 | feats = [ret[k].output[0] for k in model_config["layer_hook_names"]] 71 | feats = [feat[:, src_token, :][:, None, :] for feat in feats] 72 | feats = [feat.detach().cpu() for feat in feats] 73 | feats = torch.stack(feats) 74 | instruction_feats = [feats] 75 | return instruction_feats 76 | 77 | def embed_dataset(model, processor, text_model, model_config, config, icl_pair_xpatch_dataset): 78 | dataset_feats = [] 79 | for i in tqdm(range(len(icl_pair_xpatch_dataset))): 80 | src_batch, src_meta = icl_pair_xpatch_dataset.__getitem__(i, "src") 81 | with TraceDict(text_model, layers=model_config["layer_hook_names"], retain_output=True) as ret: 82 | xpatch_helpers.generate(model, processor, src_batch, **config["generate_kwargs"]) 83 | feats = [ret[k].output[0] for k in model_config["layer_hook_names"]] 84 | feats = [feat[:, src_meta["src_token"], :] for feat in feats] 85 | feats = [feat.detach().cpu() for feat in feats] 86 | feats = 
torch.stack(feats) 87 | dataset_feats.append(feats) 88 | return dataset_feats 89 | 90 | def save_feats(model, processor, text_model, model_config, config, task, seed, icl_pair_xpatch_dataset, instruction): 91 | # Create the feats 92 | if config["spec"] == "icl": 93 | feats = embed_dataset(model, processor, text_model, model_config, config, icl_pair_xpatch_dataset) 94 | elif config["spec"] == "instruction": 95 | feats = embed_prompt(model, processor, text_model, model_config, config, instruction) 96 | else: 97 | raise NotImplementedError("Config does not have a valid spec!") 98 | # Average across all feats 99 | feats = torch.stack(feats) 100 | if not config.get("save_feats_all", False): 101 | feats = feats.mean(dim=0) 102 | # Write the feats to file 103 | feats_file = get_file(config['feats_folder'], task, seed, ".pt") 104 | save_feats_folder = os.path.dirname(feats_file) 105 | if not os.path.exists(save_feats_folder): 106 | os.makedirs(save_feats_folder, exist_ok=True) 107 | torch.save(feats, feats_file) 108 | 109 | # =========================== 110 | # Patching 111 | # =========================== 112 | def patch_layer(model, processor, text_model, model_config, generate_kwargs, feats, L, cache_L, tgt_batch, tgt_token, patch_L=None, return_feats=False): 113 | if patch_L and L != patch_L: 114 | return None 115 | intervention_fn = xpatch_helpers.patch_output_pair( 116 | tgt_token=tgt_token, 117 | L=L, 118 | cache_L=cache_L, 119 | feats=feats 120 | ) 121 | with TraceDict(text_model, layers=model_config['layer_hook_names'], edit_output=intervention_fn, retain_output=return_feats) as ret: 122 | # Convert to text and back to input_ids 123 | # since tokenization is different based on position 124 | output_text = xpatch_helpers.generate(model, processor, tgt_batch, **generate_kwargs) 125 | output_text = output_text[0] 126 | if return_feats: 127 | feats = [ret[k].output[0] for k in model_config['layer_hook_names']] 128 | return output_text, feats 129 | else: 130 | return output_text 131 | 132 | def evaluate(model, processor, text_model, model_config, config, task, seed, icl_pair_xpatch_dataset): 133 | use_patching = config.get("use_patching", True) 134 | patch_L = config.get("patch_L") 135 | 136 | # Create save folder 137 | save_file = get_file(config['save_folder'], task, seed, ".json") 138 | save_folder = os.path.dirname(save_file) 139 | if not os.path.exists(save_folder): 140 | os.makedirs(save_folder, exist_ok=True) 141 | save_keys = ["baseline_text", "regular_text", "labels_text", "train_idxs"] 142 | results = {k: [] for k in save_keys} 143 | results["hypothesis_text"] = [[] for _ in range(model_config["n_layers"])] 144 | results["config"] = OmegaConf.to_container(config, resolve=True) 145 | json.dump(results, open(save_file, "w")) 146 | 147 | # Load cached feats 148 | feats = get_feats(config, task, seed) 149 | feats = [feat.to(model.device).to(model.dtype) for feat in feats] 150 | 151 | for i in tqdm(range(len(icl_pair_xpatch_dataset))): 152 | src_batch, src_meta = icl_pair_xpatch_dataset.__getitem__(i, "src") 153 | tgt_batch, tgt_meta = icl_pair_xpatch_dataset.__getitem__(i, "tgt") 154 | # Run Regular 155 | output_text = xpatch_helpers.generate(model, processor, src_batch, **config["generate_kwargs"]) 156 | results["regular_text"].append(output_text[0]) 157 | # Run Baseline 158 | output_text = xpatch_helpers.generate(model, processor, tgt_batch, **config["generate_kwargs"]) 159 | results["baseline_text"].append(output_text[0]) 160 | # Patch Task Vector 161 | if use_patching: 162 | 
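# Layer sweep: for each transformer layer L, patch_layer() overwrites the hidden state
# at the query's final "A:" (answer) token with the cached task vector from the same
# layer (cache_L = L) and re-generates the answer; per-layer outputs are stored in
# results["hypothesis_text"] so the best layer can later be picked on the val split via
# get_best_layer. If patch_L is set, only that single layer is patched.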
patch_layer_kwargs = { 163 | "model": model, 164 | "processor": processor, 165 | "text_model": text_model, 166 | "model_config": model_config, 167 | "generate_kwargs": config["generate_kwargs"], 168 | "feats": feats, 169 | "tgt_batch": tgt_batch, 170 | "tgt_token": tgt_meta["tgt_token"] 171 | } 172 | layer_results = [patch_layer(L=L, patch_L=patch_L, cache_L=L, **patch_layer_kwargs) for L in tqdm(range(model_config["n_layers"]))] 173 | for L, layer_result in enumerate(layer_results): 174 | results["hypothesis_text"][L].append(layer_result) 175 | # Save metadata 176 | labels = tgt_meta["labels"] 177 | train_idxs = src_meta["train_idxs"] 178 | results["labels_text"].append(processor.batch_decode(labels, skip_special_tokens=True)[0]) 179 | results["train_idxs"].append(train_idxs) 180 | json.dump(results, open(save_file, "w")) 181 | 182 | # =========================== 183 | # Setup 184 | # =========================== 185 | def load_model(model_id, model_revision=None, device="cuda"): 186 | model, processor = xpatch_helpers.load_hf_model(model_id, model_revision=model_revision, device=device) 187 | xpatch_helpers.remove_cache(model) 188 | text_model = xpatch_helpers.load_text_model(model_id, model) 189 | model_config = xpatch_helpers.get_model_config(text_model) 190 | model_config["layer_hook_names"] = [f"model.layers.{layer}" for layer in range(text_model.config.num_hidden_layers)] 191 | return model, processor, text_model, model_config 192 | 193 | def get_best_layer(tasks, save_folder): 194 | save_files = [] 195 | for task in tasks: 196 | save_file = get_file(save_folder, task, None) 197 | save_files.append(save_file) 198 | acc_info = xpatch_helpers.avg_task_accuracy(save_files, verbose=False) 199 | return acc_info["max_idx"] 200 | 201 | def get_overall_results(tasks, save_folder, allow_lower_case): 202 | patch_L = get_best_layer(tasks, save_folder) 203 | results = [] 204 | for task in tasks: 205 | save_file = get_file(save_folder, task, None) 206 | acc_info = xpatch_helpers.avg_task_accuracy([save_file], patch_L=patch_L, verbose=False, allow_lower_case=allow_lower_case) 207 | results.append(acc_info["accuracy"]) 208 | df = pd.DataFrame(np.array(results).T) 209 | df.index = ["no_context", "prompt", f"patch (L={patch_L})"] 210 | df.index.name = "method" 211 | df.columns = list(tasks) 212 | df['avg'] = df.mean(axis=1) 213 | return df 214 | 215 | def main(config): 216 | # Patch all layers by default 217 | # If patch_L == -1, find best layer by best val score 218 | # according to results files 219 | if config.get("patch_L") == -1: 220 | config["patch_L"] = get_best_layer(config["tasks"], config["save_folder"].replace(config["mode"], "val")) 221 | 222 | # Load model 223 | model_name = config["feats_model"] if config.get("save_feats", False) else config["patch_model"] 224 | model_info = OmegaConf.load(f"configs/models/{model_name}.yaml") 225 | model, processor, text_model, model_config = load_model(model_info["model_id"], model_revision=model_info.get("model_revision")) 226 | config = OmegaConf.merge(config, model_info) 227 | 228 | # Loop through tasks 229 | for t, task in enumerate(config["tasks"]): 230 | dataset_kwargs = copy.deepcopy(config["dataset_kwargs"]) 231 | dataset_kwargs["annotation_file"] = f"{config['data_folder']}/{task}_{config['mode']}.json" 232 | # Setup instruction 233 | if config["spec"] == "instruction" and config.get("instructions"): 234 | instruction = config["instructions"][t] 235 | dataset_kwargs["num_examples"] = 0 236 | else: 237 | instruction = "" 238 | # Loop through 
seeds 239 | for seed in range(SEED_START, SEED_START + config["num_seeds"]): 240 | icl_pair_xpatch_dataset = xpatch_dataset.xPatchDataset( 241 | processor=processor, 242 | seed=seed, 243 | **dataset_kwargs 244 | ) 245 | if config.get("save_feats", False): 246 | save_feats(model, processor, text_model, model_config, config, task, seed, icl_pair_xpatch_dataset, instruction) 247 | else: 248 | evaluate(model, processor, text_model, model_config, config, task, seed, icl_pair_xpatch_dataset) 249 | if not config.get("save_feats", False): 250 | results = get_overall_results(config["tasks"], config["save_folder"], config.get("allow_lower_case", False)) 251 | results.to_csv(f"{config['save_folder']}.csv") 252 | 253 | if __name__ == "__main__": 254 | config = OmegaConf.load(sys.argv[1]) 255 | if len(sys.argv) > 2: 256 | cli_overrides = OmegaConf.from_cli(sys.argv[2:]) 257 | config = OmegaConf.merge(config, cli_overrides) 258 | OmegaConf.resolve(config) 259 | main(config) -------------------------------------------------------------------------------- /src/xpatch_helpers.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import json 3 | import glob 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | 8 | from transformers import ( 9 | AutoModelForVision2Seq, 10 | AutoProcessor, 11 | BitsAndBytesConfig 12 | ) 13 | 14 | # =========================== 15 | # VLM Loaders 16 | # =========================== 17 | class Idefics2Wrapper(torch.nn.Module): 18 | def __init__(self, text_model): 19 | super().__init__() 20 | self.model = text_model 21 | self.config = text_model.config 22 | 23 | class LLMProcessor: 24 | def __init__(self, tokenizer): 25 | self.tokenizer = tokenizer 26 | 27 | def __call__(self, text=None, images=None, **kwargs): 28 | return self.tokenizer(text, **kwargs) 29 | 30 | def remove_cache(module): 31 | def new_forward(self, **kwargs): 32 | outputs = self.old_forward(**kwargs) 33 | outputs["past_key_values"] = None 34 | return outputs 35 | if not hasattr(module, "old_forward"): 36 | module.old_forward = module.forward 37 | # Fix hf method signature complaints 38 | new_forward.__signature__ = inspect.signature(module.old_forward) 39 | module.forward = new_forward.__get__(module, type(module)) 40 | 41 | def get_quantization_config(model_kwargs): 42 | torch_dtype = torch.float16 43 | model_kwargs["quantization_config"] = BitsAndBytesConfig( 44 | load_in_4bit=True, 45 | bnb_4bit_quant_type="nf4", 46 | bnb_4bit_compute_dtype=torch_dtype, 47 | bnb_4bit_quant_storage=torch_dtype, 48 | bnb_4bit_use_double_quant=True, 49 | ) 50 | return model_kwargs 51 | 52 | def get_model_config(model): 53 | model_config = { 54 | "name_or_path": model.config._name_or_path, 55 | "n_heads": model.config.num_attention_heads, 56 | "n_layers": model.config.num_hidden_layers, 57 | "resid_dim": model.config.hidden_size 58 | } 59 | return model_config 60 | 61 | def load_text_model(model_id, model): 62 | text_model_loaders = { 63 | "idefics2": lambda model: Idefics2Wrapper(model.model.text_model), 64 | "llava": lambda model: model.language_model, 65 | "Mantis": lambda model: model.language_model, 66 | } 67 | text_model_loader = None 68 | for k, v in text_model_loaders.items(): 69 | if k in model_id: 70 | text_model_loader = v 71 | break 72 | if text_model_loader is None: 73 | return model 74 | else: 75 | text_model = text_model_loader(model) 76 | return text_model 77 | 78 | def load_hf_model(model_id, model_revision=None, device="cuda", 
load_in_4bit=True): 79 | model_kwargs = {} 80 | if load_in_4bit: 81 | model_kwargs = get_quantization_config(model_kwargs) 82 | if "Mantis" in model_id: 83 | from mantis.models.mfuyu import MFuyuForCausalLM, MFuyuProcessor 84 | model = MFuyuForCausalLM.from_pretrained(model_id, revision=model_revision, device_map=device, **model_kwargs) 85 | processor = MFuyuProcessor.from_pretrained(model_id, revision=model_revision) 86 | elif "mistral" in model_id or "vicuna" in model_id: 87 | from transformers import AutoModelForCausalLM, AutoTokenizer 88 | model = AutoModelForCausalLM.from_pretrained(model_id, revision=model_revision, device_map=device, **model_kwargs) 89 | tokenizer = AutoTokenizer.from_pretrained(model_id, revision=model_revision) 90 | processor = LLMProcessor(tokenizer) 91 | else: 92 | model = AutoModelForVision2Seq.from_pretrained(model_id, revision=model_revision, device_map=device, **model_kwargs) 93 | processor = AutoProcessor.from_pretrained(model_id, revision=model_revision) 94 | processor.image_processor.do_image_splitting = False 95 | return model, processor 96 | 97 | # =========================== 98 | # Generation 99 | # =========================== 100 | def prepare_batch(inputs, device): 101 | for k, v in inputs.items(): 102 | if torch.is_tensor(v): 103 | inputs[k] = v.to(device) 104 | if type(v) is list: 105 | inputs[k] = [v_.to(device) if torch.is_tensor(v_) else v_ for v_ in v] 106 | return inputs 107 | 108 | def generate(model, processor, inputs, remove_input=True, **generate_kwargs): 109 | with torch.no_grad(): 110 | inputs = prepare_batch(inputs, model.device) 111 | output = model.generate(**inputs, **generate_kwargs) 112 | if remove_input: 113 | input_len = inputs["input_ids"].shape[1] 114 | output = output[:, input_len:] 115 | output = processor.tokenizer.batch_decode(output, skip_special_tokens=True) 116 | return output 117 | 118 | def find_token(processor, input_ids, token_ids=None, token=None, token_offset=0, match="last", pad_id=0): 119 | if token is None: 120 | idxs = (input_ids != pad_id).sum() - (token_offset + 1) 121 | idxs = idxs[None, ...] 122 | return idxs 123 | else: 124 | # Replace the last instance in src and first in tgt 125 | if token_ids is None: 126 | token_ids = processor.tokenizer(token)["input_ids"][1:] 127 | token_ids = torch.tensor(token_ids, device=input_ids.device) 128 | windows = input_ids.unfold(0, len(token_ids), 1) 129 | matches = torch.all(windows == token_ids, dim=1) 130 | matches = torch.where(matches)[0] 131 | if match == "last": 132 | matches = matches[-1] 133 | elif match == "first": 134 | matches = matches[0] 135 | else: 136 | matches = matches[match] 137 | matches = matches.item() 138 | idxs = torch.arange(matches, matches + len(token_ids)) 139 | return idxs 140 | 141 | # =========================== 142 | # Patching 143 | # =========================== 144 | """ 145 | Patching adapted from Function Vectors 146 | (Todd et. 
al., ICLR 2024) 147 | https://github.com/ericwtodd/function_vectors 148 | """ 149 | def patch_output_pair(src_idx=0, tgt_idx=0, src_token=-1, tgt_token=-1, L=None, cache_L=None, feats=None): 150 | def rep_act(output, layer_name, inputs): 151 | current_layer = int(layer_name.split(".")[2]) 152 | return_cache = type(output) is tuple 153 | if return_cache: 154 | act = output[0] 155 | cache = output[1] 156 | else: 157 | act = output 158 | if act.shape[1] == 1: 159 | print("WARNING: cache seems to be in use") 160 | if L == current_layer: 161 | act[tgt_idx, tgt_token] = feats[cache_L][src_idx, src_token].detach() 162 | if return_cache: 163 | output = (act, cache) 164 | else: 165 | output = act 166 | return output 167 | return rep_act 168 | 169 | # =========================== 170 | # Metrics 171 | # =========================== 172 | """ 173 | Metrics adapted from In-Context Learning 174 | Creates Task Vectors (Hendel et. al., EMNLP Findings 2023) 175 | https://github.com/roeehendel/icl_task_vectors 176 | """ 177 | 178 | def preprocess_text(lst, allow_lower_case=False): 179 | lst = [x.strip() for x in lst] 180 | if allow_lower_case: 181 | lst = [x.lower() for x in lst] 182 | return lst 183 | 184 | def compare_outputs(output1, output2): 185 | output1, output2 = output1.strip(), output2.strip() 186 | nonempy = len(output1) > 0 and len(output2) > 0 187 | return nonempy and (output1.startswith(output2) or output2.startswith(output1)) 188 | 189 | def compute_text_acc(pred_text, labels_text, allow_lower_case=False): 190 | pred_text = preprocess_text(pred_text, allow_lower_case=allow_lower_case) 191 | labels_text = preprocess_text(labels_text, allow_lower_case=allow_lower_case) 192 | vectorized_compare = np.vectorize(compare_outputs) 193 | correct = vectorized_compare(pred_text, labels_text) 194 | acc = correct.mean() 195 | return acc 196 | 197 | # =========================== 198 | # Logging 199 | # =========================== 200 | def get_max(x, verbose=False): 201 | max_value = np.max(x) 202 | max_indices = np.argwhere(x == max_value) 203 | max_index = max_indices[np.argmin(max_indices[:, 0])] 204 | if verbose: 205 | print("All indices with the maximum value:", max_indices) 206 | return max_index 207 | 208 | def avg_seed_accuracy(file_prefix, patch_L=None, verbose=True, allow_lower_case=False): 209 | baseline = [] 210 | regular = [] 211 | hypothesis = {} 212 | files = glob.glob(f"{file_prefix}*") 213 | if len(files) == 0: 214 | raise Exception("No files found!") 215 | for file in files: 216 | results = json.load(open(file)) 217 | baseline.append(compute_text_acc(results["baseline_text"], results["labels_text"], allow_lower_case=allow_lower_case)) 218 | regular.append(compute_text_acc(results["regular_text"], results["labels_text"], allow_lower_case=allow_lower_case)) 219 | for L in range(len(results["hypothesis_text"])): 220 | if patch_L is not None and L != patch_L: 221 | acc = 0 222 | else: 223 | results["hypothesis_text"][L] = [x if x is not None else "" for x in results["hypothesis_text"][L]] 224 | acc = compute_text_acc(results["hypothesis_text"][L], results["labels_text"], allow_lower_case=allow_lower_case) 225 | hypothesis[L] = hypothesis.get(L, []) + [acc] 226 | baseline = np.mean(baseline) 227 | regular = np.mean(regular) 228 | hypothesis = [np.mean(v) for v in hypothesis.values()] 229 | max_idx = get_max(hypothesis).item() 230 | if verbose: 231 | print(f"===== Accuracy ({len(files)} Seeds) =====") 232 | print(f"Baseline: {baseline:.2f}") 233 | print(f"Regular: {regular:.2f}") 234 | 
print(f"Hypothesis (L={max_idx}): {hypothesis[max_idx]:.2f}") 235 | return baseline, regular, hypothesis, max_idx 236 | 237 | def avg_task_accuracy(file_prefixes, patch_L=None, verbose=True, allow_lower_case=False): 238 | all_baseline, all_regular, all_hypothesis = [], [], [] 239 | for file_prefix in file_prefixes: 240 | baseline, regular, hypothesis, _ = avg_seed_accuracy(file_prefix, patch_L=patch_L, verbose=False, allow_lower_case=allow_lower_case) 241 | all_baseline.append(baseline) 242 | all_regular.append(regular) 243 | all_hypothesis.append(hypothesis) 244 | all_baseline_avg = np.mean(all_baseline) 245 | all_regular_avg = np.mean(all_regular) 246 | # For the same index in each list, compute the mean 247 | all_hypothesis_avg = [np.mean([hypothesis[L] for hypothesis in all_hypothesis]) for L in range(len(all_hypothesis[0]))] 248 | max_idx = get_max(all_hypothesis_avg).item() 249 | if verbose: 250 | print(f"===== Accuracy ({len(file_prefixes)} Tasks) =====") 251 | print(f"Baseline: {all_baseline_avg:.2f}") 252 | print(f"Regular: {all_regular_avg:.2f}") 253 | print(f"Hypothesis (L={max_idx}): {all_hypothesis_avg[max_idx]:.2f}") 254 | acc_info = { 255 | "raw": [all_baseline, all_regular, all_hypothesis], 256 | "accuracy": [all_baseline_avg, all_regular_avg, all_hypothesis_avg[max_idx]], 257 | "max_idx": max_idx 258 | } 259 | return acc_info 260 | 261 | def plot_icl(text, images): 262 | n = len(text) 263 | width = n * 4 264 | if images: 265 | height = 4 266 | else: 267 | height = 1 268 | fig, axes = plt.subplots(1, n, figsize=(width, height)) 269 | if n == 1: 270 | axes = [axes] 271 | for i in range(len(text)): 272 | if images: 273 | axes[i].imshow(images[i]) 274 | axes[i].axis('off') 275 | axes[i].set_title(text[i], wrap=True, size=24) 276 | plt.tight_layout() 277 | plt.show() --------------------------------------------------------------------------------
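For reference, the per-layer results written by `evaluate` can also be summarized directly with the accuracy helpers above. Below is a minimal sketch, assuming a finished test run in the default text-to-image ICL setting (the folder name follows the `save_folder` template in `configs/base.yaml`) and execution from the repository root:

```
# Summarize a finished run with the helpers in src/xpatch_helpers.py.
# avg_task_accuracy globs the per-seed JSON files written by evaluate() and
# reports baseline (no context), regular (prompted), and best patched-layer accuracy.
from src import xpatch_helpers

save_folder = "runs/experiments/icl_text-image/feats-idefics2_patch-idefics2/test"
tasks = [
    "country-capital", "country-currency", "animal-latin",
    "animal-young", "food-color", "food-flavor",
]
file_prefixes = [f"{save_folder}/task-{task}" for task in tasks]

acc_info = xpatch_helpers.avg_task_accuracy(file_prefixes, verbose=True)
print("Best patching layer:", acc_info["max_idx"])
```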