├── assets ├── mib_logo.png ├── circuit_track.png └── causal_variable_track.png ├── MIB-causal-variable-track ├── tasks │ ├── IOI_task │ │ ├── objects.json │ │ ├── places.json │ │ ├── names.json │ │ ├── templates.json │ │ └── ioi_task.py │ ├── hf_dataloader.py │ ├── ARC │ │ └── ARC.py │ ├── two_digit_addition_task │ │ └── arithmetic.py │ ├── simple_MCQA │ │ ├── simple_MCQA.py │ │ └── object_color_pairs.json │ └── RAVEL │ │ └── ravel.py ├── requirements.txt ├── mock_submission │ ├── featurizer.py │ └── token_position.py ├── baselines │ ├── run_quick_tests.sh │ ├── arithmetic_baselines.py │ ├── ioi_baselines │ │ ├── ioi_learn_linear_params.py │ │ ├── ioi_utils.py │ │ └── ioi_baselines.py │ ├── simple_MCQA_baselines.py │ ├── ARC_baselines.py │ └── ravel_baselines.py ├── process_all_submissions.py ├── README.md └── test_MIB_datasets │ └── test_ARC.py ├── .gitmodules ├── LICENSE └── README.md /assets/mib_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aaronmueller/MIB/HEAD/assets/mib_logo.png -------------------------------------------------------------------------------- /assets/circuit_track.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aaronmueller/MIB/HEAD/assets/circuit_track.png -------------------------------------------------------------------------------- /assets/causal_variable_track.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aaronmueller/MIB/HEAD/assets/causal_variable_track.png -------------------------------------------------------------------------------- /MIB-causal-variable-track/tasks/IOI_task/objects.json: -------------------------------------------------------------------------------- 1 | [ 2 | "ring", 3 | "kiss", 4 | "bone", 5 | "basketball", 6 | "computer", 7 | "necklace", 8 | "drink", 9 | "snack" 10 | ] -------------------------------------------------------------------------------- /MIB-causal-variable-track/tasks/IOI_task/places.json: -------------------------------------------------------------------------------- 1 | [ 2 | "store", 3 | "garden", 4 | "restaurant", 5 | "school", 6 | "hospital", 7 | "office", 8 | "house", 9 | "station" 10 | ] -------------------------------------------------------------------------------- /MIB-causal-variable-track/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | networkx 3 | matplotlib 4 | pandas 5 | numpy 6 | pyvene 7 | pytest 8 | scikit-learn 9 | ipywidgets 10 | jupyterlab 11 | ipycytoscape 12 | sae_lens 13 | transformers 14 | seaborn 15 | tensorboard 16 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "MIB-circuit-track"] 2 | path = MIB-circuit-track 3 | url = https://github.com/hannamw/MIB-circuit-track 4 | [submodule "MIB-causal-variable-track/CausalAbstraction"] 5 | path = MIB-causal-variable-track/CausalAbstraction 6 | url = https://github.com/atticusg/CausalAbstraction -------------------------------------------------------------------------------- /MIB-causal-variable-track/tasks/IOI_task/names.json: -------------------------------------------------------------------------------- 1 | [ 2 | "Michael", 3 | "Christopher", 4 | "Jessica", 5 | "Matthew", 6 | "Ashley", 7 | "Jennifer", 8 | "Joshua", 9 | 
"Amanda", 10 | "Daniel", 11 | "David", 12 | "James", 13 | "Robert", 14 | "John", 15 | "Joseph", 16 | "Andrew", 17 | "Ryan", 18 | "Brandon", 19 | "Jason", 20 | "Justin", 21 | "Sarah", 22 | "William", 23 | "Jonathan", 24 | "Stephanie", 25 | "Brian", 26 | "Nicole", 27 | "Nicholas", 28 | "Anthony", 29 | "Heather", 30 | "Eric", 31 | "Elizabeth", 32 | "Adam", 33 | "Megan", 34 | "Melissa", 35 | "Kevin", 36 | "Steven", 37 | "Thomas", 38 | "Timothy", 39 | "Christina", 40 | "Kyle", 41 | "Rachel", 42 | "Laura", 43 | "Lauren", 44 | "Amber", 45 | "Brittany", 46 | "Danielle", 47 | "Richard", 48 | "Kimberly", 49 | "Jeffrey", 50 | "Amy", 51 | "Crystal", 52 | "Michelle", 53 | "Tiffany", 54 | "Jeremy", 55 | "Benjamin", 56 | "Mark", 57 | "Emily", 58 | "Aaron", 59 | "Charles", 60 | "Rebecca", 61 | "Jacob", 62 | "Stephen", 63 | "Patrick", 64 | "Sean", 65 | "Erin", 66 | "Jamie", 67 | "Kelly", 68 | "Samantha", 69 | "Nathan", 70 | "Sara", 71 | "Dustin", 72 | "Paul", 73 | "Angela", 74 | "Tyler", 75 | "Scott", 76 | "Katherine", 77 | "Andrea", 78 | "Gregory", 79 | "Erica", 80 | "Mary", 81 | "Travis", 82 | "Lisa", 83 | "Kenneth", 84 | "Bryan", 85 | "Lindsey", 86 | "Kristen", 87 | "Jose", 88 | "Alexander", 89 | "Jesse", 90 | "Katie", 91 | "Lindsay", 92 | "Shannon", 93 | "Vanessa", 94 | "Courtney", 95 | "Christine", 96 | "Alicia", 97 | "Cody", 98 | "Allison", 99 | "Bradley", 100 | "Samuel" 101 | ] -------------------------------------------------------------------------------- /MIB-causal-variable-track/mock_submission/featurizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copy of the existing SubspaceFeaturizer implementation for submission. 3 | This file provides the same SubspaceFeaturizer functionality in a self-contained format. 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import pyvene as pv 9 | from CausalAbstraction.neural.featurizers import Featurizer 10 | 11 | 12 | class SubspaceFeaturizerModuleCopy(torch.nn.Module): 13 | def __init__(self, rotate_layer): 14 | super().__init__() 15 | self.rotate = rotate_layer 16 | 17 | def forward(self, x): 18 | r = self.rotate.weight.T 19 | f = x.to(r.dtype) @ r.T 20 | error = x - (f @ r).to(x.dtype) 21 | return f, error 22 | 23 | 24 | class SubspaceInverseFeaturizerModuleCopy(torch.nn.Module): 25 | def __init__(self, rotate_layer): 26 | super().__init__() 27 | self.rotate = rotate_layer 28 | 29 | def forward(self, f, error): 30 | r = self.rotate.weight.T 31 | return (f.to(r.dtype) @ r).to(f.dtype) + error.to(f.dtype) 32 | 33 | 34 | class SubspaceFeaturizerCopy(Featurizer): 35 | def __init__(self, shape=None, rotation_subspace=None, trainable=True, id="subspace"): 36 | assert shape is not None or rotation_subspace is not None, "Either shape or rotation_subspace must be provided." 
37 | if shape is not None: 38 | self.rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=True) 39 | elif rotation_subspace is not None: 40 | shape = rotation_subspace.shape 41 | self.rotate = pv.models.layers.LowRankRotateLayer(*shape, init_orth=False) 42 | self.rotate.weight.data.copy_(rotation_subspace) 43 | self.rotate = torch.nn.utils.parametrizations.orthogonal(self.rotate) 44 | 45 | if not trainable: 46 | self.rotate.requires_grad_(False) 47 | 48 | # Create module-based featurizer and inverse_featurizer 49 | featurizer = SubspaceFeaturizerModuleCopy(self.rotate) 50 | inverse_featurizer = SubspaceInverseFeaturizerModuleCopy(self.rotate) 51 | 52 | super().__init__(featurizer, inverse_featurizer, n_features=self.rotate.weight.shape[1], id=id) -------------------------------------------------------------------------------- /MIB-causal-variable-track/mock_submission/token_position.py: -------------------------------------------------------------------------------- 1 | """ 2 | Token position definitions for MCQA task submission. 3 | This file provides token position functions that identify key tokens in MCQA prompts. 4 | """ 5 | 6 | import re 7 | from CausalAbstraction.neural.LM_units import TokenPosition, get_last_token_index 8 | 9 | 10 | def get_token_positions(pipeline, causal_model): 11 | """ 12 | Get token positions for the simple MCQA task. 13 | 14 | Args: 15 | pipeline: The language model pipeline with tokenizer 16 | causal_model: The causal model for the task 17 | 18 | Returns: 19 | list[TokenPosition]: List of TokenPosition objects for intervention experiments 20 | """ 21 | def get_correct_symbol_index(input, pipeline, causal_model): 22 | """ 23 | Find the index of the correct answer symbol in the prompt. 24 | 25 | Args: 26 | input (Dict): The input dictionary to a causal model 27 | pipeline: The tokenizer pipeline 28 | causal_model: The causal model 29 | 30 | Returns: 31 | list[int]: List containing the index of the correct answer symbol token 32 | """ 33 | # Run the model to get the answer position 34 | output = causal_model.run_forward(input) 35 | pointer = output["answer_pointer"] 36 | correct_symbol = output[f"symbol{pointer}"] 37 | prompt = input["raw_input"] 38 | 39 | # Find all single uppercase letters in the prompt 40 | matches = list(re.finditer(r"\b[A-Z]\b", prompt)) 41 | 42 | # Find the match corresponding to our correct symbol 43 | symbol_match = None 44 | for match in matches: 45 | if prompt[match.start():match.end()] == correct_symbol: 46 | symbol_match = match 47 | break 48 | 49 | if not symbol_match: 50 | raise ValueError(f"Could not find correct symbol {correct_symbol} in prompt: {prompt}") 51 | 52 | # Get the substring up to the symbol match end 53 | substring = prompt[:symbol_match.end()] 54 | tokenized_substring = list(pipeline.load(substring)["input_ids"][0]) 55 | 56 | # The symbol token will be at the end of the substring 57 | return [len(tokenized_substring) - 1] 58 | 59 | # Create TokenPosition objects 60 | token_positions = [ 61 | TokenPosition(lambda x: get_correct_symbol_index(x, pipeline, causal_model), pipeline, id="correct_symbol"), 62 | TokenPosition(lambda x: [get_correct_symbol_index(x, pipeline, causal_model)[0]+1], pipeline, id="correct_symbol_period"), 63 | TokenPosition(lambda x: get_last_token_index(x, pipeline), pipeline, id="last_token") 64 | ] 65 | return token_positions -------------------------------------------------------------------------------- /MIB-causal-variable-track/baselines/run_quick_tests.sh: 
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Script to run all baseline scripts with --quick_test flag
4 | # This runs minimal tests with reduced dataset sizes and layers
5 | 
6 | echo "==============================================="
7 | echo "Running all baselines with --quick_test flag"
8 | echo "==============================================="
9 | 
10 | # Colors for output
11 | GREEN='\033[0;32m'
12 | RED='\033[0;31m'
13 | NC='\033[0m' # No Color
14 | 
15 | # Function to run a command and check its status
16 | run_test() {
17 |     local name=$1
18 |     shift # Remove the first argument (name) from $@
19 | 
20 |     echo -e "\n${GREEN}Running $name...${NC}"
21 |     echo "Command: $@"
22 | 
23 |     if "$@"; then
24 |         echo -e "${GREEN}✓ $name completed successfully${NC}"
25 |     else
26 |         echo -e "${RED}✗ $name failed${NC}"
27 |         exit 1
28 |     fi
29 | }
30 | 
31 | # Change to baselines directory
32 | cd "$(dirname "$0")"
33 | 
34 | # Run simple MCQA baseline
35 | run_test "Simple MCQA Baseline" python simple_MCQA_baselines.py --quick_test
36 | 
37 | # Run ARC baseline
38 | run_test "ARC Baseline" python ARC_baselines.py --quick_test
39 | 
40 | # Run arithmetic baseline
41 | run_test "Arithmetic Baseline" python arithmetic_baselines.py --quick_test
42 | 
43 | # Run RAVEL baseline
44 | run_test "RAVEL Baseline" python ravel_baselines.py --quick_test
45 | 
46 | # Run IOI linear parameter learning (needed for IOI baseline)
47 | echo -e "\n${GREEN}Preparing for IOI baseline...${NC}"
48 | run_test "IOI Linear Parameter Learning (GPT2)" python ioi_baselines/ioi_learn_linear_params.py --model gpt2 --quick_test --output_file ioi_linear_params_quick.json
49 | # Run IOI baseline with the computed parameters
50 | run_test "IOI Baseline" python ioi_baselines/ioi_baselines.py --model gpt2 --linear_params ioi_linear_params_quick.json --quick_test --run_baselines --run_brute_force
51 | 
52 | run_test "IOI Linear Parameter Learning (llama)" python ioi_baselines/ioi_learn_linear_params.py --model llama --quick_test --output_file ioi_linear_params_quick.json
53 | # Run IOI baseline with the computed parameters
54 | run_test "IOI Baseline" python ioi_baselines/ioi_baselines.py --model llama --linear_params ioi_linear_params_quick.json --quick_test --run_baselines --run_brute_force
55 | 
56 | run_test "IOI Linear Parameter Learning (qwen)" python ioi_baselines/ioi_learn_linear_params.py --model qwen --quick_test --output_file ioi_linear_params_quick.json
57 | # Run IOI baseline with the computed parameters
58 | run_test "IOI Baseline" python ioi_baselines/ioi_baselines.py --model qwen --linear_params ioi_linear_params_quick.json --quick_test --run_baselines --run_brute_force
59 | 
60 | run_test "IOI Linear Parameter Learning (gemma)" python ioi_baselines/ioi_learn_linear_params.py --model gemma --quick_test --output_file ioi_linear_params_quick.json
61 | # Run IOI baseline with the computed parameters
62 | run_test "IOI Baseline" python ioi_baselines/ioi_baselines.py --model gemma --linear_params ioi_linear_params_quick.json --quick_test --run_baselines --run_brute_force --heads_list "(3, 1)" "(4, 2)"
63 | 
64 | echo -e "\n${GREEN}===============================================${NC}"
65 | echo -e "${GREEN}All quick tests completed successfully!${NC}"
66 | echo -e "${GREEN}===============================================${NC}"
--------------------------------------------------------------------------------
/MIB-causal-variable-track/tasks/hf_dataloader.py:
-------------------------------------------------------------------------------- 1 | from datasets import load_dataset, Dataset 2 | from causal.counterfactual_dataset import CounterfactualDataset 3 | import os 4 | 5 | def load_hf_dataset(dataset_path, split, parse_fn, hf_token=None, size=None, 6 | name=None, ignore_names=[], shuffle=False, filter_fn=None): 7 | """ 8 | Load a HuggingFace dataset and reformat it to be compatible with the 9 | CounterfactualDataset class. 10 | 11 | Args: 12 | dataset_path (str): The path or name of the HF dataset 13 | split (str): Dataset split to load ("train", "test", or "validation") 14 | hf_token (str): HuggingFace authentication token 15 | size (int, optional): Number of examples to load. Defaults to None (all). 16 | name (str, optional): Sub-configuration name for the dataset. Defaults to None. 17 | parse_fn (callable, optional): A function that takes a single row from a 18 | dataset and returns a string or dict to be placed in the "input" column. 19 | If None, defaults to using row["question"] or row["prompt"]. 20 | ignore_names (list, optional): Names to ignore when looking for counterfactuals. 21 | Defaults to empty list. 22 | shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. 23 | 24 | Returns: 25 | dict: A dictionary containing CounterfactualDataset objects, one for each 26 | counterfactual type. Keys are formatted as "{counterfactual_name}_{split}". 27 | """ 28 | if hf_token is None: 29 | hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") 30 | base_dataset = load_dataset(dataset_path, name, split=split, token=hf_token) 31 | if filter_fn is not None: 32 | base_dataset = base_dataset.filter(filter_fn) 33 | 34 | if shuffle: 35 | base_dataset = base_dataset.shuffle(seed=42) 36 | 37 | if size is not None: 38 | if size > len(base_dataset): 39 | size = len(base_dataset) 40 | base_dataset = base_dataset.select(range(size)) 41 | 42 | # Retrieve all counterfactual names 43 | sample = base_dataset[0] 44 | counterfactual_names = [ 45 | k for k in sample.keys() 46 | if k.endswith('_counterfactual') 47 | and not any(name in k for name in ignore_names) 48 | ] 49 | 50 | data_dict = { 51 | counterfactual_name: {"input": [], "counterfactual_inputs": []} 52 | for counterfactual_name in counterfactual_names 53 | } 54 | 55 | for row in base_dataset: 56 | try: 57 | input_obj = parse_fn(row) 58 | except Exception as e: 59 | print(f"Error parsing input: {e} for row {row}") 60 | continue 61 | 62 | for counterfactual_name in counterfactual_names: 63 | if counterfactual_name in row: 64 | cf_data = row[counterfactual_name] 65 | else: 66 | cf_data = [] 67 | 68 | data_dict[counterfactual_name]["input"].append(input_obj) 69 | counterfactual_obj = parse_fn(cf_data) 70 | if not isinstance(counterfactual_obj, list): 71 | counterfactual_obj = [counterfactual_obj] 72 | data_dict[counterfactual_name]["counterfactual_inputs"].append( 73 | counterfactual_obj 74 | ) 75 | 76 | 77 | datasets = {} 78 | for counterfactual_name in data_dict.keys(): 79 | try: 80 | name = counterfactual_name.replace("_counterfactual", "_" + split) 81 | hf_dataset = Dataset.from_dict(data_dict[counterfactual_name]) 82 | datasets[name] = CounterfactualDataset( 83 | dataset=hf_dataset, 84 | id=f"{name}" 85 | ) 86 | except Exception as e: 87 | print( 88 | f"Error creating dataset for {counterfactual_name}: {e} " 89 | f"{type(data_dict[counterfactual_name])} " 90 | f"{data_dict[counterfactual_name]['input'][0]} " 91 | 
f"{data_dict[counterfactual_name]['counterfactual_inputs'][0]} " 92 | f"{split}" 93 | ) 94 | assert False 95 | 96 | return datasets -------------------------------------------------------------------------------- /MIB-causal-variable-track/tasks/IOI_task/templates.json: -------------------------------------------------------------------------------- 1 | [ 2 | "After the lunch, {name_A} and {name_B} went to the {place}. {name_C} gave a {object} to", 3 | "After the lunch, {name_A} and {name_B} went to the {place}. {name_C} gave an {object} to", 4 | "After {name_A} and {name_B} went to the {place}, {name_C} gave a {object} to", 5 | "After {name_A} and {name_B} went to the {place}, {name_C} gave an {object} to", 6 | "After {name_A} and {name_B} spent some time at the {place}, {name_C} offered a {object} to", 7 | "After {name_A} and {name_B} spent some time at the {place}, {name_C} offered an {object} to", 8 | "After {name_A} and {name_B} finished their work at the {place}, {name_C} gave a {object} to", 9 | "After {name_A} and {name_B} finished their work at the {place}, {name_C} gave an {object} to", 10 | "After {name_A} and {name_B} went to the {place}, {name_C} handed a {object} to", 11 | "After {name_A} and {name_B} went to the {place}, {name_C} handed an {object} to", 12 | "Afterwards, {name_A} and {name_B} went to the {place}. {name_C} gave a {object} to", 13 | "Afterwards, {name_A} and {name_B} went to the {place}. {name_C} gave an {object} to", 14 | 15 | "As {name_A} and {name_B} left the {place}, {name_C} gave a {object} to", 16 | "As {name_A} and {name_B} left the {place}, {name_C} gave an {object} to", 17 | "At the {place}, {name_A} and {name_B} found a {object}. {name_C} gave it to", 18 | 19 | "Before {name_A} and {name_B} left the {place}, {name_C} decided to give a {object} to", 20 | "Before {name_A} and {name_B} left the {place}, {name_C} decided to give an {object} to", 21 | 22 | "Friends {name_A} and {name_B} found a {object} at the {place}. {name_C} gave it to", 23 | "Friends {name_A} and {name_B} found an {object} at the {place}. {name_C} gave it to", 24 | 25 | "Just as {name_A} and {name_B} were leaving the {place}, {name_C} offered a {object} to", 26 | "Just as {name_A} and {name_B} were leaving the {place}, {name_C} offered an {object} to", 27 | 28 | "Later, {name_A} and {name_B} met at the {place}. {name_C} bought a {object} for", 29 | "Later, {name_A} and {name_B} met at the {place}. {name_C} bought an {object} for", 30 | 31 | "{name_A} and {name_B} walked to the {place}. {name_C} gave a {object} to", 32 | "{name_A} and {name_B} walked to the {place}. {name_C} gave an {object} to", 33 | "{name_A} and {name_B} decided to visit the {place}. Then, {name_C} gave a {object} to", 34 | "{name_A} and {name_B} decided to visit the {place}. Then, {name_C} gave an {object} to", 35 | "{name_A} and {name_B} were sitting together. {name_C} handed over a {object} to", 36 | "{name_A} and {name_B} were sitting together. {name_C} handed over an {object} to", 37 | "{name_A} and {name_B} were sitting at the {place}. {name_C} handed over a {object} to", 38 | "{name_A} and {name_B} were sitting at the {place}. {name_C} handed over an {object} to", 39 | "{name_A} and {name_B} were working together. {name_C} gave a {object} to", 40 | "{name_A} and {name_B} were working together. {name_C} gave an {object} to", 41 | "{name_A} and {name_B} went to the {place}. Then, {name_C} handed a {object} to", 42 | "{name_A} and {name_B} went to the {place}. 
Then, {name_C} handed an {object} to", 43 | "{name_A} and {name_B} were working at the {place}. Then, {name_C} decided to give a {object} to", 44 | "{name_A} and {name_B} were working at the {place}. Then, {name_C} decided to give an {object} to", 45 | "{name_A} and {name_B} were thinking about going to the {place}. {name_C} wanted to give a {object} to", 46 | "{name_A} and {name_B} were thinking about going to the {place}. {name_C} wanted to give an {object} to", 47 | "{name_A} and {name_B} had a lot of fun at the {place}. {name_C} gave a {object} to", 48 | "{name_A} and {name_B} had a lot of fun at the {place}. {name_C} gave an {object} to", 49 | 50 | "Once {name_A} and {name_B} finished at the {place}, {name_C} gave a {object} to", 51 | "Once {name_A} and {name_B} finished at the {place}, {name_C} gave an {object} to", 52 | "Once {name_A} and {name_B} arrived at the {place}, {name_C} gave a {object} to", 53 | "Once {name_A} and {name_B} arrived at the {place}, {name_C} gave an {object} to", 54 | 55 | "Then, {name_A} and {name_B} went to the {place}. {name_C} gave a {object} to", 56 | "Then, {name_A} and {name_B} went to the {place}. {name_C} gave an {object} to", 57 | "Then, {name_A} and {name_B} had a lot of fun at the {place}. {name_C} gave a {object} to", 58 | "Then, {name_A} and {name_B} had a lot of fun at the {place}. {name_C} gave an {object} to", 59 | "Then, {name_A} and {name_B} were working at the {place}. {name_C} decided to give a {object} to", 60 | "Then, {name_A} and {name_B} were working at the {place}. {name_C} decided to give an {object} to", 61 | "Then, {name_A} and {name_B} were thinking about going to the {place}. {name_C} wanted to give a {object} to", 62 | "Then, {name_A} and {name_B} were thinking about going to the {place}. {name_C} wanted to give an {object} to", 63 | "Then, {name_A} and {name_B} had a long argument, and afterwards {name_C} said to", 64 | "Then, {name_A} and {name_B} had a long argument. Afterwards {name_C} said to", 65 | "The {place} {name_A} and {name_B} went to had a {object}. 
{name_C} gave it to", 66 | 67 | "When {name_A} and {name_B} got a {object} at the {place}, {name_C} decided to give it to", 68 | "When {name_A} and {name_B} got a {object} at the {place}, {name_C} decided to give the {object} to", 69 | "When {name_A} and {name_B} were done exploring the {place}, {name_C} decided to give a {object} to", 70 | "When {name_A} and {name_B} were about to leave the {place}, {name_C} gave a {object} to", 71 | 72 | "While {name_A} and {name_B} were working at the {place}, {name_C} gave a {object} to", 73 | "While {name_A} and {name_B} were working at the {place}, {name_C} gave an {object} to", 74 | "While {name_A} and {name_B} were commuting to the {place}, {name_C} gave a {object} to", 75 | "While {name_A} and {name_B} were commuting to the {place}, {name_C} gave an {object} to", 76 | "While {name_A} and {name_B} relaxed at the {place}, {name_C} presented a {object} to", 77 | "While {name_A} and {name_B} relaxed at the {place}, {name_C} presented an {object} to", 78 | "While {name_A} and {name_B} relaxed, {name_C} gave a {object} to", 79 | "While {name_A} and {name_B} relaxed, {name_C} gave an {object} to", 80 | "While {name_A} and {name_B} were playing at the {place}, {name_C} gave a {object} to", 81 | "While {name_A} and {name_B} were playing at the {place}, {name_C} gave an {object} to" 82 | ] 83 | -------------------------------------------------------------------------------- /MIB-causal-variable-track/tasks/ARC/ARC.py: -------------------------------------------------------------------------------- 1 | import json, random, sys, os 2 | from pathlib import Path 3 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 4 | sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) 5 | 6 | 7 | from CausalAbstraction.causal.causal_model import CausalModel 8 | from CausalAbstraction.neural.LM_units import TokenPosition, get_last_token_index 9 | 10 | from copy import deepcopy 11 | from tasks.hf_dataloader import load_hf_dataset 12 | import re 13 | 14 | 15 | def get_causal_model(): 16 | """ 17 | Create and return the causal model for ARC Easy task. 
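    The model uses four answer symbols (symbol0-symbol3), a gold answerKey index, an
    answer_pointer that simply copies answerKey, and an answer equal to the pointed-to
    symbol with a leading space; raw_input is a placeholder overwritten by the dataset
    prompt, and raw_output echoes the answer.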
18 | """ 19 | NUM_CHOICES = 4 20 | ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" 21 | 22 | variables = ["raw_input", "answer_pointer", "answer", "answerKey", "raw_output"] + ["symbol" + str(x) for x in range(NUM_CHOICES)] 23 | 24 | values = {} 25 | values.update({"symbol" + str(x): list(ALPHABET) for x in range(NUM_CHOICES)}) 26 | values.update({"answer_pointer": list(range(NUM_CHOICES)), "answer": list(ALPHABET)}) 27 | values.update({"answerKey": list(range(NUM_CHOICES))}) 28 | # FIXED: Change None to empty list for raw_input and raw_output 29 | values.update({"raw_input": [""], "raw_output": [""]}) 30 | 31 | parents = {"answer":["answer_pointer"] + ["symbol" + str(x) for x in range(NUM_CHOICES)], 32 | "answer_pointer": ["answerKey"], 33 | "answerKey": [], 34 | "raw_output": ["answer"], 35 | "raw_input": []} 36 | parents.update({"symbol" + str(x): [] for x in range(NUM_CHOICES)}) 37 | 38 | def get_raw_input(): 39 | return "" 40 | 41 | def get_symbol(): 42 | return random.choice(list(ALPHABET)) 43 | 44 | def get_answer_pointer(answerKey): 45 | return answerKey 46 | 47 | def get_answer(answer_pointer, *symbols): 48 | return " " + symbols[answer_pointer] 49 | 50 | def get_raw_output(answer): 51 | return answer 52 | 53 | def get_answerKey(): 54 | return random.choice(list(range(NUM_CHOICES))) 55 | 56 | mechanisms = { 57 | "raw_input": get_raw_input, 58 | **{f"symbol{i}": get_symbol for i in range(NUM_CHOICES)}, 59 | "answer_pointer": get_answer_pointer, 60 | "answer": get_answer, 61 | "answerKey": get_answerKey, 62 | "raw_output": get_raw_output 63 | } 64 | 65 | # Create and initialize the model 66 | return CausalModel(variables, values, parents, mechanisms, id=f"ARC_easy") 67 | 68 | 69 | def get_counterfactual_datasets(hf=True, size=None, load_private_data=False): 70 | """ 71 | Load and return counterfactual datasets for ARC Easy task. 72 | """ 73 | # Filter function to only keep examples with exactly 4 choices 74 | def has_four_choices(example): 75 | return len(example.get("choices", {}).get("label", [])) == 4 76 | 77 | if hf: 78 | # Load dataset from HuggingFace with customized parsing 79 | datasets = {} 80 | for split in ["train", "validation", "test"]: 81 | temp = load_hf_dataset( 82 | dataset_path="mib-bench/arc_easy", 83 | split=split, 84 | parse_fn=parse_arc_easy_example, 85 | size=size, 86 | ignore_names=["symbol"], 87 | filter_fn=has_four_choices, # Add filter to only keep 4-choice questions 88 | shuffle=True # Shuffle the dataset for better training 89 | ) 90 | datasets.update(temp) 91 | 92 | if load_private_data: 93 | private = load_hf_dataset( 94 | dataset_path="mib-bench/arc_easy_private_test", 95 | split="test", 96 | parse_fn=parse_arc_easy_example, 97 | size=size, 98 | ignore_names=["symbol"], 99 | filter_fn=has_four_choices, # Add filter to only keep 4-choice questions 100 | shuffle=True # Shuffle the dataset for better training 101 | ) 102 | datasets.update({k+"private":v for k,v in private.items()}) 103 | 104 | return datasets 105 | 106 | # Non-HF implementation would go here if needed 107 | # For now, just return empty dict for consistency 108 | return {} 109 | 110 | 111 | def parse_arc_easy_example(row): 112 | """ 113 | Customized parsing function for the ARC Easy dataset. 114 | Returns a variables dict compatible with the causal model. 
115 | """ 116 | # Get the prompt string 117 | prompt_str = row.get("prompt", "") 118 | 119 | # Create variables dictionary 120 | variables_dict = { 121 | "raw_input": prompt_str, 122 | "answerKey": row["answerKey"] 123 | } 124 | 125 | # Parse choice labels 126 | choice_labels = row["choices"]["label"] 127 | for i in range(len(choice_labels)): 128 | variables_dict[f"symbol{i}"] = str(choice_labels[i]) 129 | 130 | return variables_dict 131 | 132 | 133 | def get_token_positions(pipeline, causal_model): 134 | """ 135 | Get token positions for ARC Easy task interventions. 136 | """ 137 | def get_correct_symbol_index(input_dict, pipeline, causal_model): 138 | """ 139 | Find the index of the correct answer symbol token in the prompt. 140 | 141 | Args: 142 | input_dict (dict): The input dictionary to a causal model 143 | pipeline: The tokenizer pipeline 144 | causal_model: The causal model 145 | 146 | Returns: 147 | list[int]: List containing the index of the correct answer symbol token 148 | """ 149 | # Run the model to get the answer position 150 | output = causal_model.run_forward(input_dict) 151 | pointer = output["answer_pointer"] 152 | correct_symbol = output[f"symbol{pointer}"] 153 | prompt = input_dict["raw_input"] 154 | 155 | # Find all single uppercase letters in the prompt 156 | matches = list(re.finditer(r"\b[A-Z]\b", prompt)) 157 | 158 | # Find the match corresponding to our correct symbol 159 | symbol_match = None 160 | for match in matches: 161 | if prompt[match.start():match.end()] == correct_symbol: 162 | symbol_match = match 163 | break 164 | 165 | if not symbol_match: 166 | raise ValueError(f"Could not find correct symbol {correct_symbol} in prompt: {prompt}") 167 | 168 | # Get the substring up to the symbol match end 169 | substring = prompt[:symbol_match.end()] 170 | tokenized_substring = list(pipeline.load(substring)["input_ids"][0]) 171 | 172 | # The symbol token will be at the end of the substring 173 | return [len(tokenized_substring) - 1] 174 | 175 | # Create TokenPosition objects 176 | token_positions = [ 177 | TokenPosition(lambda x: get_correct_symbol_index(x, pipeline, causal_model), pipeline, id="correct_symbol"), 178 | TokenPosition(lambda x: get_last_token_index(x, pipeline), pipeline, id="last_token") 179 | ] 180 | return token_positions 181 | 182 | 183 | def is_unique(lst): 184 | """Check if all elements in list are unique.""" 185 | return len(lst) == len(set(lst)) -------------------------------------------------------------------------------- /MIB-causal-variable-track/tasks/two_digit_addition_task/arithmetic.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | from pathlib import Path 3 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 4 | sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) 5 | 6 | from CausalAbstraction.causal.causal_model import CausalModel 7 | from CausalAbstraction.neural.LM_units import TokenPosition, get_last_token_index 8 | 9 | from tasks.hf_dataloader import load_hf_dataset 10 | 11 | from copy import deepcopy 12 | import re 13 | import random 14 | 15 | def get_causal_model(): 16 | variables = [ 17 | "raw_input", # Required by CausalModel 18 | "op1_tens", "op1_ones", 19 | "op2_tens", "op2_ones", 20 | "ones_carry", 21 | "hundreds_out", "tens_out", "ones_out", 22 | "raw_output" # Required by CausalModel 23 | ] 24 | 25 | # Allowed values for each variable. 
26 | values = { 27 | "raw_input": [""], # Placeholder, actual values generated by mechanism 28 | "op1_tens": list(range(10)), 29 | "op1_ones": list(range(10)), 30 | "op2_tens": list(range(10)), 31 | "op2_ones": list(range(10)), 32 | "ones_carry": [0, 1], 33 | "ones_out": list(range(10)), 34 | "tens_out": list(range(10)), 35 | "hundreds_out": [0, 1], 36 | "raw_output": [""] # Placeholder, actual values generated by mechanism 37 | } 38 | 39 | # Specify parent relationships for each node. 40 | parents = { 41 | "raw_input": ["op1_tens", "op1_ones", "op2_tens", "op2_ones"], # Depends on the operands 42 | "op1_tens": [], 43 | "op1_ones": [], 44 | "op2_tens": [], 45 | "op2_ones": [], 46 | "ones_carry": ["op1_ones", "op2_ones"], 47 | "ones_out": ["op1_ones", "op2_ones"], 48 | "tens_out": ["op1_tens", "op2_tens", "ones_carry"], 49 | "hundreds_out": ["op1_tens", "op2_tens", "ones_carry"], 50 | "raw_output": ["hundreds_out", "tens_out", "ones_out"] # Depends on the result digits 51 | } 52 | 53 | def raw_output_mechanism(hundreds_out, tens_out, ones_out): 54 | """ 55 | Generate the raw output string based on result digits. 56 | """ 57 | if hundreds_out == 0: 58 | if tens_out == 0: 59 | return f"{ones_out:01d}" 60 | return f"{tens_out:01d}{ones_out:01d}" 61 | return f"{hundreds_out:01d}{tens_out:01d}{ones_out:01d}" 62 | 63 | # Define the mechanisms (the functions computing each node's value). 64 | mechanisms = { 65 | # Generate the raw input based on operands 66 | "raw_input": lambda op1_tens, op1_ones, op2_tens, op2_ones: 67 | f"Q: How much is {op1_tens}{op1_ones} plus {op2_tens}{op2_ones}? A: ", 68 | 69 | # Base input nodes: randomly choose a digit 70 | "op1_tens": lambda: random.choice(list(range(10))), 71 | "op1_ones": lambda: random.choice(list(range(10))), 72 | "op2_tens": lambda: random.choice(list(range(10))), 73 | "op2_ones": lambda: random.choice(list(range(10))), 74 | 75 | # Compute carries and outputs 76 | "ones_carry": lambda op1_ones, op2_ones: 1 if op1_ones + op2_ones > 9 else 0, 77 | "ones_out": lambda op1_ones, op2_ones: (op1_ones + op2_ones) % 10, 78 | "tens_out": lambda op1_tens, op2_tens, ones_carry: (op1_tens + op2_tens + ones_carry) % 10, 79 | "hundreds_out": lambda op1_tens, op2_tens, ones_carry: 1 if op1_tens + op2_tens + ones_carry > 9 else 0, 80 | 81 | # Generate the raw output based on result digits 82 | "raw_output": raw_output_mechanism 83 | } 84 | 85 | return CausalModel(variables, values, parents, mechanisms, id="arithmetic") 86 | 87 | 88 | def get_counterfactual_datasets(hf=True, size=None, load_private_data=False): 89 | """ 90 | Load and return counterfactual datasets for arithmetic task. 
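    Loads mib-bench/arithmetic_addition (plus the private test split when requested),
    keeping only two-digit problems. Dataset keys follow "{counterfactual_name}_{split}".
    Minimal usage sketch (the exact keys depend on the counterfactual columns in the
    HF dataset, e.g. "random_train"):

        datasets = get_counterfactual_datasets(hf=True, size=100)
        train_split = datasets["random_train"]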
91 | """ 92 | if hf: 93 | # Load dataset from HuggingFace with customized parsing 94 | datasets = {} 95 | for split in ["train", "test"]: 96 | temp = load_hf_dataset( 97 | dataset_path="mib-bench/arithmetic_addition", 98 | split=split, 99 | parse_fn=parse_arithmetic_example, 100 | size=size, 101 | filter_fn=lambda example: example["num_digit"] == 2, 102 | ignore_names=["ones_op1", "ones_op2", "tens_op1", "tens_op2", "tens_carry"], 103 | shuffle=True 104 | ) 105 | datasets.update(temp) 106 | 107 | if load_private_data: 108 | private = load_hf_dataset( 109 | dataset_path="mib-bench/arithmetic_addition_private_test", 110 | split="test", 111 | parse_fn=parse_arithmetic_example, 112 | size=size, 113 | filter_fn=lambda example: example["num_digit"] == 2, 114 | ignore_names=["ones_op1", "ones_op2", "tens_op1", "tens_op2", "tens_carry"], 115 | shuffle=True 116 | ) 117 | datasets.update({k+"private":v for k,v in private.items()}) 118 | 119 | return datasets 120 | 121 | # Non-HF implementation would go here if needed 122 | return {} 123 | 124 | 125 | def parse_arithmetic_example(row): 126 | """ 127 | Customized parsing function for the arithmetic task. 128 | Returns a variables dict compatible with the causal model. 129 | """ 130 | # Get the prompt string 131 | prompt_str = row.get("prompt", "") 132 | 133 | # Parse the prompt to extract operands 134 | matches = re.findall(r"\d+", prompt_str) 135 | if len(matches) < 2: 136 | raise ValueError(f"Prompt must contain at least two numbers: {prompt_str}") 137 | 138 | # Take the last two numbers as operands 139 | op1_str, op2_str = matches[-2], matches[-1] 140 | 141 | # Parse into tens and ones 142 | op1_tens = int(op1_str[-2]) if len(op1_str) > 1 else 0 143 | op1_ones = int(op1_str[-1]) 144 | op2_tens = int(op2_str[-2]) if len(op2_str) > 1 else 0 145 | op2_ones = int(op2_str[-1]) 146 | 147 | # Return variables dict (not tuple) 148 | return { 149 | "raw_input": prompt_str, 150 | "op1_tens": op1_tens, 151 | "op1_ones": op1_ones, 152 | "op2_tens": op2_tens, 153 | "op2_ones": op2_ones 154 | } 155 | 156 | 157 | def get_token_positions(pipeline, causal_model): 158 | """ 159 | Get token positions for arithmetic task interventions. 160 | """ 161 | def get_op2_last_token_index(input_dict, pipeline): 162 | """ 163 | Find the index of the last token of the second operand. 
164 | """ 165 | # Extract prompt from input dict 166 | prompt = input_dict["raw_input"] if isinstance(input_dict, dict) else input_dict 167 | 168 | matches = list(re.finditer(r"\b\d+\b", prompt)) 169 | if len(matches) < 2: 170 | raise ValueError(f"Prompt must contain at least two numbers: {prompt}") 171 | 172 | op2_match = matches[-1] # Last match 173 | 174 | # Get the substring up to the op2 match end 175 | substring = prompt[:op2_match.end()] 176 | tokenized_substring = list(pipeline.load(substring)["input_ids"][0]) 177 | 178 | # The last token of op2 will be at the end of the substring 179 | return [len(tokenized_substring) - 1] 180 | 181 | token_positions = [ 182 | TokenPosition(lambda x: get_op2_last_token_index(x, pipeline), pipeline, id="op2_last"), 183 | TokenPosition(lambda x: get_last_token_index(x, pipeline), pipeline, id="last") 184 | ] 185 | 186 | return token_positions -------------------------------------------------------------------------------- /MIB-causal-variable-track/tasks/simple_MCQA/simple_MCQA.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | from pathlib import Path 3 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 4 | sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) 5 | 6 | 7 | from CausalAbstraction.causal.causal_model import CausalModel 8 | from tasks.hf_dataloader import load_hf_dataset 9 | 10 | from copy import deepcopy 11 | from CausalAbstraction.neural.LM_units import TokenPosition, get_last_token_index 12 | import re 13 | import random 14 | import json 15 | 16 | def get_causal_model(): 17 | path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'object_color_pairs.json') 18 | #Load grandparent directory 19 | with open(path, 'r') as f: 20 | data = json.load(f) 21 | 22 | OBJECTS = [item['object'] for item in data] 23 | COLORS = [item['color'] for item in data] 24 | COLOR_OBJECTS = [(item["color"], item["object"]) for item in data] 25 | 26 | NUM_CHOICES = 4 27 | ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" 28 | 29 | variables = ["question", "raw_input"] + ["symbol" + str(x) for x in range(NUM_CHOICES)] + ["choice" + str(x) for x in range(NUM_CHOICES)] + [ "answer_pointer", "answer", "raw_output"] 30 | 31 | values = {"choice" + str(x): COLORS for x in range(NUM_CHOICES)} 32 | values.update({"symbol" + str(x): ALPHABET for x in range(NUM_CHOICES)}) 33 | values.update({"answer_pointer": range(NUM_CHOICES), "answer": ALPHABET}) 34 | values.update({"question": COLOR_OBJECTS }) 35 | values.update({"raw_input": None, "raw_output": None}) 36 | 37 | parents = {"answer":["answer_pointer"] + ["symbol" + str(x) for x in range(NUM_CHOICES)], 38 | "answer_pointer": ["question"] + ["choice" + str(x) for x in range(NUM_CHOICES)], 39 | "raw_output": ["answer"], 40 | "raw_input": [], 41 | "question": []} 42 | parents.update({"choice" + str(x): [] for x in range(NUM_CHOICES)}) 43 | parents.update({"symbol" + str(x): [] for x in range(NUM_CHOICES)}) 44 | 45 | def get_question(): 46 | return random.choice(COLOR_OBJECTS) 47 | 48 | def get_symbol(): 49 | return random.choice(list(ALPHABET)) 50 | 51 | def get_choice(): 52 | return random.choice(COLORS) 53 | 54 | def get_answer_pointer(question, *choices): 55 | for i, choice in enumerate(choices): 56 | if choice == question[0]: 57 | return i 58 | 59 | def get_answer(answer_pointer, *symbols): 60 | return " " + symbols[answer_pointer] 61 | 62 | def output_dumper(answer): 63 | return answer 64 | 65 | mechanisms = { 66 | "raw_input": lambda: 
"", 67 | "question": get_question, 68 | **{f"symbol{i}": get_symbol for i in range(NUM_CHOICES)}, 69 | 70 | **{f"choice{i}": get_choice for i in range(NUM_CHOICES)}, 71 | 72 | "answer_pointer": get_answer_pointer, 73 | "answer": get_answer, 74 | "raw_output": output_dumper, 75 | } 76 | 77 | 78 | # Create and initialize the model 79 | return CausalModel(variables, values, parents, mechanisms, id=f"{NUM_CHOICES}_answer_MCQA") 80 | 81 | 82 | def get_counterfactual_datasets(hf=True, size=None, load_private_data=False): 83 | NUM_CHOICES = 4 # Assuming this is fixed at 4 as in the original code 84 | 85 | if hf: 86 | # Load dataset from HuggingFace with customized parsing 87 | datasets = {} 88 | for split in ["train", "validation", "test"]: 89 | temp = load_hf_dataset( 90 | dataset_path="mib-bench/copycolors_mcqa", 91 | split=split, 92 | name=f"{NUM_CHOICES}_answer_choices", 93 | parse_fn=parse_mcqa_example, 94 | size=size, 95 | ignore_names=["noun", "color", "symbol"] 96 | ) 97 | datasets.update(temp) 98 | if load_private_data: 99 | private = load_hf_dataset( 100 | dataset_path="mib-bench/copycolors_mcqa_private_test", 101 | split="test", 102 | name=f"{NUM_CHOICES}_answer_choices", 103 | parse_fn=parse_mcqa_example, 104 | size=size, 105 | ignore_names=["noun", "color", "symbol"] 106 | ) 107 | datasets.update({k+"private":v for k,v in private.items()}) 108 | 109 | return datasets 110 | 111 | 112 | def get_token_positions(pipeline, causal_model): 113 | def get_correct_symbol_index(input, pipeline, causal_model): 114 | """ 115 | Find the index of the correct answer symbol in the prompt. 116 | 117 | Args: 118 | input (Dict): The input dictionary to a causal model 119 | pipeline: The tokenizer pipeline 120 | 121 | Returns: 122 | list[int]: List containing the index of the correct answer symbol token 123 | """ 124 | # Run the model to get the answer position 125 | output = causal_model.run_forward(input) 126 | pointer = output["answer_pointer"] 127 | correct_symbol = output[f"symbol{pointer}"] 128 | prompt = input["raw_input"] 129 | 130 | # Find all single uppercase letters in the prompt 131 | matches = list(re.finditer(r"\b[A-Z]\b", prompt)) 132 | 133 | # Find the match corresponding to our correct symbol 134 | symbol_match = None 135 | for match in matches: 136 | if prompt[match.start():match.end()] == correct_symbol: 137 | symbol_match = match 138 | break 139 | 140 | if not symbol_match: 141 | raise ValueError(f"Could not find correct symbol {correct_symbol} in prompt: {prompt}") 142 | 143 | # Get the substring up to the symbol match end 144 | substring = prompt[:symbol_match.end()] 145 | tokenized_substring = list(pipeline.load(substring)["input_ids"][0]) 146 | 147 | # The symbol token will be at the end of the substring 148 | return [len(tokenized_substring) - 1] 149 | 150 | # Create TokenPosition object 151 | token_positions = [ 152 | TokenPosition(lambda x: get_correct_symbol_index(x, pipeline, causal_model), pipeline, id="correct_symbol"), 153 | TokenPosition(lambda x: [get_correct_symbol_index(x, pipeline, causal_model)[0]+1], pipeline, id="correct_symbol_period"), 154 | TokenPosition(lambda x: get_last_token_index(x, pipeline), pipeline, id="last_token") 155 | ] 156 | return token_positions 157 | 158 | 159 | def parse_mcqa_example(row): 160 | """ 161 | Customized parsing function for the MCQA task. 162 | Returns a tuple of (prompt_str, variables_dict) like the arithmetic example. 
163 | """ 164 | # Get the prompt/question text 165 | prompt_str = row.get("prompt", "") 166 | 167 | # Extract object and color information 168 | q_str = prompt_str 169 | if " is " in q_str: 170 | noun, color = q_str.split(" is ", 1) 171 | elif " are " in q_str: 172 | noun, color = q_str.split(" are ", 1) 173 | noun = noun.strip().lower() 174 | color = color.split(".", 1)[0].strip().lower() 175 | 176 | # Process choices 177 | choice_labels = row["choices"]["label"] 178 | choice_texts = row["choices"]["text"] 179 | 180 | # Create the variables dictionary 181 | variables_dict = { 182 | "question": (color, noun) 183 | } 184 | 185 | for i in range(len(choice_labels)): 186 | variables_dict[f"symbol{i}"] = str(choice_labels[i]) 187 | variables_dict[f"choice{i}"] = str(choice_texts[i]) 188 | 189 | variables_dict["raw_input"] = prompt_str 190 | 191 | # Return tuple of (prompt_str, variables_dict) to match the other file's format 192 | return variables_dict -------------------------------------------------------------------------------- /MIB-causal-variable-track/baselines/arithmetic_baselines.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 4 | from tasks.two_digit_addition_task.arithmetic import get_token_positions, get_counterfactual_datasets, get_causal_model 5 | from experiments.aggregate_experiments import residual_stream_baselines 6 | from neural.pipeline import LMPipeline 7 | from experiments.filter_experiment import FilterExperiment 8 | import torch 9 | import gc 10 | import os 11 | 12 | if __name__ == "__main__": 13 | import argparse 14 | 15 | parser = argparse.ArgumentParser(description="Run arithmetic experiments with optional flags.") 16 | parser.add_argument("--skip_gemma", action="store_true", help="Skip running experiments for Gemma model.") 17 | parser.add_argument("--skip_llama", action="store_true", help="Skip running experiments for Llama model.") 18 | parser.add_argument("--use_gpu1", action="store_true", help="Use GPU1 instead of GPU0 if available.") 19 | parser.add_argument("--methods", nargs="+", 20 | default=["full_vector", "DAS", "DBM+SVD", "DBM+PCA", "DBM", "DBM+SAE"], 21 | help="List of methods to run") 22 | parser.add_argument("--batch_size", type=int, default=256, help="Batch size for training") 23 | parser.add_argument("--eval_batch_size", type=int, default=1024, help="Batch size for evaluation") 24 | parser.add_argument("--results_dir", type=str, default="arithmetic_results", help="Directory to save results") 25 | parser.add_argument("--model_dir", type=str, default="arithmetic_models", help="Directory to save trained models") 26 | parser.add_argument("--quick_test", action="store_true", help="Run quick test with reduced dataset size and layers") 27 | args = parser.parse_args() 28 | 29 | # Clear memory before starting 30 | gc.collect() 31 | torch.cuda.empty_cache() 32 | 33 | # Device setup 34 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 35 | if args.use_gpu1 and torch.cuda.is_available(): 36 | device = "cuda:1" 37 | 38 | # Check function for evaluating model outputs 39 | def checker(output_text, expected): 40 | # Clean the output by extracting just the numbers 41 | import re 42 | numbers_in_output = re.findall(r'\d+', output_text) 43 | if not numbers_in_output: 44 | return False 45 | 46 | # Get the first number found 47 | first_number = numbers_in_output[0] 48 | 49 | return first_number == expected 50 | 51 | # 
Function to clear memory between experiments 52 | def clear_memory(): 53 | gc.collect() 54 | if torch.cuda.is_available(): 55 | torch.cuda.empty_cache() 56 | torch.cuda.synchronize() 57 | 58 | # Get counterfactual datasets and causal model 59 | dataset_size = 10 if args.quick_test else 10000 60 | counterfactual_datasets = get_counterfactual_datasets(hf=True, size=dataset_size) 61 | causal_model = get_causal_model() 62 | 63 | # Print available datasets 64 | print("Available datasets:", counterfactual_datasets.keys()) 65 | 66 | # Set up models to test 67 | models = [] 68 | if not args.skip_gemma: 69 | models.append("google/gemma-2-2b") 70 | if not args.skip_llama: 71 | models.append("meta-llama/Meta-Llama-3.1-8B-Instruct") 72 | 73 | for model_name in models: 74 | print(f"\n===== Testing model: {model_name} =====") 75 | 76 | # Set up LM Pipeline with appropriate max_new_tokens for each model 77 | if "llama" in model_name.lower(): 78 | max_new_tokens = 1 79 | elif "gemma" in model_name.lower(): 80 | max_new_tokens = 3 81 | else: 82 | max_new_tokens = 3 83 | 84 | pipeline = LMPipeline(model_name, max_new_tokens=max_new_tokens, device=device, dtype=torch.float16) 85 | pipeline.tokenizer.padding_side = "left" 86 | print("DEVICE:", pipeline.model.device) 87 | 88 | # Get a sample input and check model's prediction 89 | sampled_example = next(iter(counterfactual_datasets.values()))[0] 90 | print("INPUT:", sampled_example["input"]) 91 | print("EXPECTED OUTPUT:", causal_model.run_forward(sampled_example["input"])["raw_output"]) 92 | print("MODEL PREDICTION:", pipeline.dump(pipeline.generate(sampled_example["input"]))) 93 | 94 | # Filter the datasets based on model performance 95 | print("\nFiltering datasets based on model performance...") 96 | exp = FilterExperiment(pipeline, causal_model, checker) 97 | filtered_datasets = exp.filter(counterfactual_datasets, verbose=True, batch_size=args.eval_batch_size) 98 | 99 | # Get token positions for intervention 100 | token_positions = get_token_positions(pipeline, causal_model) 101 | 102 | # Display token highlighting for a sample 103 | print("\nToken positions highlighted in samples:") 104 | for dataset in filtered_datasets.values(): 105 | for token_position in token_positions: 106 | example = dataset[0] 107 | print(token_position.highlight_selected_token(example["input"])) 108 | break 109 | break 110 | 111 | # Clear memory before running experiments 112 | clear_memory() 113 | 114 | # Setup experiment configuration 115 | start = 0 116 | end = 1 if args.quick_test else pipeline.get_num_layers() 117 | 118 | config = { 119 | "batch_size": args.batch_size, 120 | "evaluation_batch_size": args.eval_batch_size, 121 | "training_epoch": 1, 122 | "n_features": 16, 123 | "regularization_coefficient": 0.0, 124 | "output_scores": False 125 | } 126 | 127 | # Adjust batch size for Llama 128 | if "llama" in model_name.lower(): 129 | config["batch_size"] = 256 130 | config["evaluation_batch_size"] = 1024 131 | 132 | # Prepare dataset names - based on the HF dataset structure 133 | names = ["random", "ones_carry"] 134 | 135 | # Make sure results and model directories exist 136 | if not os.path.exists(args.results_dir): 137 | os.makedirs(args.results_dir) 138 | 139 | if not os.path.exists(args.model_dir): 140 | os.makedirs(args.model_dir) 141 | 142 | # Run experiments for ones_carry variable only 143 | print(f"\nRunning experiments for target variable: ones_carry") 144 | 145 | # Prepare train and test data dictionaries 146 | train_data = {} 147 | test_data = {} 148 | 149 | for 
name in names: 150 | if name + "_train" in filtered_datasets: 151 | train_data[name + "_train"] = filtered_datasets[name + "_train"] 152 | if name + "_test" in filtered_datasets: 153 | test_data[name + "_test"] = filtered_datasets[name + "_test"] 154 | if name + "_testprivate" in filtered_datasets: 155 | test_data[name + "_testprivate"] = filtered_datasets[name + "_testprivate"] 156 | 157 | residual_stream_baselines( 158 | pipeline=pipeline, 159 | task=causal_model, 160 | token_positions=token_positions, 161 | train_data=train_data, 162 | test_data=test_data, 163 | config=config, 164 | target_variables=["ones_carry"], 165 | checker=checker, 166 | start=start, 167 | end=end, 168 | verbose=True, 169 | model_dir=os.path.join(args.model_dir, "ones_carry"), 170 | results_dir=args.results_dir, 171 | methods=args.methods 172 | ) 173 | clear_memory() 174 | 175 | # Clean up pipeline to free memory before starting next model 176 | del pipeline 177 | clear_memory() 178 | 179 | print("\nAll experiments completed.") -------------------------------------------------------------------------------- /MIB-causal-variable-track/tasks/simple_MCQA/object_color_pairs.json: -------------------------------------------------------------------------------- 1 | [ 2 | {"object": "lemon", "color": "yellow"}, 3 | {"object": "banana", "color": "yellow"}, 4 | {"object": "daffodil", "color": "yellow"}, 5 | {"object": "sunflower", "color": "yellow"}, 6 | {"object": "dandelion", "color": "yellow"}, 7 | {"object": "canary", "color": "yellow"}, 8 | {"object": "buttercup", "color": "yellow"}, 9 | {"object": "corn", "color": "yellow"}, 10 | {"object": "pineapple", "color": "yellow"}, 11 | {"object": "bumblebee", "color": "yellow"}, 12 | {"object": "apple", "color": "red"}, 13 | {"object": "cherry", "color": "red"}, 14 | {"object": "strawberry", "color": "red"}, 15 | {"object": "rose", "color": "red"}, 16 | {"object": "tomato", "color": "red"}, 17 | {"object": "cardinal", "color": "red"}, 18 | {"object": "ruby", "color": "red"}, 19 | {"object": "lobster", "color": "red"}, 20 | {"object": "brick", "color": "red"}, 21 | {"object": "barn", "color": "red"}, 22 | {"object": "grass", "color": "green"}, 23 | {"object": "emerald", "color": "green"}, 24 | {"object": "lime", "color": "green"}, 25 | {"object": "spinach", "color": "green"}, 26 | {"object": "lettuce", "color": "green"}, 27 | {"object": "kale", "color": "green"}, 28 | {"object": "broccoli", "color": "green"}, 29 | {"object": "cucumber", "color": "green"}, 30 | {"object": "asparagus", "color": "green"}, 31 | {"object": "frog", "color": "green"}, 32 | {"object": "sky", "color": "blue"}, 33 | {"object": "ocean", "color": "blue"}, 34 | {"object": "sapphire", "color": "green"}, 35 | {"object": "denim", "color": "blue"}, 36 | {"object": "jeans", "color": "blue"}, 37 | {"object": "shark", "color": "grey"}, 38 | {"object": "dolphin", "color": "grey"}, 39 | {"object": "whale", "color": "grey"}, 40 | {"object": "seal", "color": "grey"}, 41 | {"object": "mackerel", "color": "grey"}, 42 | {"object": "carrot", "color": "orange"}, 43 | {"object": "orange", "color": "orange"}, 44 | {"object": "tangerine", "color": "orange"}, 45 | {"object": "pumpkin", "color": "orange"}, 46 | {"object": "marigold", "color": "orange"}, 47 | {"object": "clementine", "color": "orange"}, 48 | {"object": "mandarin", "color": "orange"}, 49 | {"object": "basketball", "color": "orange"}, 50 | {"object": "persimmon", "color": "orange"}, 51 | {"object": "apricot", "color": "orange"}, 52 | {"object": "eggplant", 
"color": "purple"}, 53 | {"object": "grape", "color": "purple"}, 54 | {"object": "lavender", "color": "purple"}, 55 | {"object": "plum", "color": "purple"}, 56 | {"object": "amethyst", "color": "purple"}, 57 | {"object": "iris", "color": "purple"}, 58 | {"object": "lilac", "color": "purple"}, 59 | {"object": "orchid", "color": "purple"}, 60 | {"object": "thistle", "color": "purple"}, 61 | {"object": "wisteria", "color": "purple"}, 62 | {"object": "chocolate", "color": "brown"}, 63 | {"object": "coffee", "color": "brown"}, 64 | {"object": "soil", "color": "brown"}, 65 | {"object": "bark", "color": "brown"}, 66 | {"object": "chestnut", "color": "brown"}, 67 | {"object": "walnut", "color": "brown"}, 68 | {"object": "beaver", "color": "brown"}, 69 | {"object": "acorn", "color": "brown"}, 70 | {"object": "bear", "color": "brown"}, 71 | {"object": "bison", "color": "brown"}, 72 | {"object": "snow", "color": "white"}, 73 | {"object": "cloud", "color": "white"}, 74 | {"object": "milk", "color": "white"}, 75 | {"object": "pearl", "color": "white"}, 76 | {"object": "cotton", "color": "white"}, 77 | {"object": "sugar", "color": "white"}, 78 | {"object": "salt", "color": "white"}, 79 | {"object": "tooth", "color": "white"}, 80 | {"object": "paper", "color": "white"}, 81 | {"object": "flour", "color": "white"}, 82 | {"object": "coal", "color": "black"}, 83 | {"object": "panther", "color": "black"}, 84 | {"object": "obsidian", "color": "black"}, 85 | {"object": "crow", "color": "black"}, 86 | {"object": "raven", "color": "black"}, 87 | {"object": "tire", "color": "black"}, 88 | {"object": "spider", "color": "black"}, 89 | {"object": "ink", "color": "black"}, 90 | {"object": "pepper", "color": "black"}, 91 | {"object": "ash", "color": "grey"}, 92 | {"object": "shrimp", "color": "pink"}, 93 | {"object": "salmon", "color": "pink"}, 94 | {"object": "bubblegum", "color": "pink"}, 95 | {"object": "pig", "color": "pink"}, 96 | {"object": "flamingo", "color": "pink"}, 97 | {"object": "peony", "color": "pink"}, 98 | {"object": "carnation", "color": "pink"}, 99 | {"object": "tulip", "color": "pink"}, 100 | {"object": "lotus", "color": "pink"}, 101 | {"object": "grapefruit", "color": "pink"}, 102 | {"object": "spoon", "color": "silver"}, 103 | {"object": "foil", "color": "silver"}, 104 | {"object": "sterling", "color": "silver"}, 105 | {"object": "platinum", "color": "silver"}, 106 | {"object": "mirror", "color": "silver"}, 107 | {"object": "nickel", "color": "silver"}, 108 | {"object": "chrome", "color": "silver"}, 109 | {"object": "cutlery", "color": "silver"}, 110 | {"object": "mercury", "color": "silver"}, 111 | {"object": "zinc", "color": "silver"}, 112 | {"object": "ring", "color": "gold"}, 113 | {"object": "trophy", "color": "gold"}, 114 | {"object": "medal", "color": "gold"}, 115 | {"object": "crown", "color": "gold"}, 116 | {"object": "wheat", "color": "gold"}, 117 | {"object": "honey", "color": "gold"}, 118 | {"object": "brass", "color": "gold"}, 119 | {"object": "champagne", "color": "gold"}, 120 | {"object": "straw", "color": "gold"}, 121 | {"object": "coin", "color": "gold"}, 122 | {"object": "kiwi", "color": "green"}, 123 | {"object": "avocado", "color": "green"}, 124 | {"object": "moss", "color": "green"}, 125 | {"object": "seaweed", "color": "green"}, 126 | {"object": "artichoke", "color": "green"}, 127 | {"object": "pear", "color": "green"}, 128 | {"object": "zucchini", "color": "green"}, 129 | {"object": "cabbage", "color": "green"}, 130 | {"object": "basil", "color": "green"}, 131 | {"object": 
"parsley", "color": "green"}, 132 | {"object": "cranberry", "color": "red"}, 133 | {"object": "pomegranate", "color": "red"}, 134 | {"object": "radish", "color": "red"}, 135 | {"object": "maple", "color": "red"}, 136 | {"object": "fire", "color": "red"}, 137 | {"object": "ambulance", "color": "red"}, 138 | {"object": "stop sign", "color": "red"}, 139 | {"object": "heart", "color": "red"}, 140 | {"object": "blood", "color": "red"}, 141 | {"object": "wine", "color": "red"}, 142 | {"object": "cantaloupe", "color": "orange"}, 143 | {"object": "papaya", "color": "orange"}, 144 | {"object": "mango", "color": "orange"}, 145 | {"object": "sunset", "color": "orange"}, 146 | {"object": "rust", "color": "orange"}, 147 | {"object": "copper", "color": "orange"}, 148 | {"object": "tiger", "color": "orange"}, 149 | {"object": "fox", "color": "orange"}, 150 | {"object": "coral", "color": "orange"}, 151 | {"object": "autumn", "color": "orange"}, 152 | {"object": "cauliflower", "color": "white"}, 153 | {"object": "garlic", "color": "white"}, 154 | {"object": "rice", "color": "white"}, 155 | {"object": "snowflake", "color": "white"}, 156 | {"object": "chalk", "color": "white"}, 157 | {"object": "marshmallow", "color": "white"}, 158 | {"object": "cream", "color": "white"}, 159 | {"object": "foam", "color": "white"}, 160 | {"object": "ivory", "color": "white"}, 161 | {"object": "porcelain", "color": "white"}, 162 | {"object": "lead", "color": "grey"}, 163 | {"object": "steel", "color": "grey"}, 164 | {"object": "concrete", "color": "grey"}, 165 | {"object": "stone", "color": "grey"}, 166 | {"object": "elephant", "color": "grey"}, 167 | {"object": "mouse", "color": "grey"}, 168 | {"object": "pigeon", "color": "grey"}, 169 | {"object": "smoke", "color": "grey"}, 170 | {"object": "fog", "color": "grey"}, 171 | {"object": "storm cloud", "color": "grey"}, 172 | {"object": "mahogany", "color": "brown"}, 173 | {"object": "teak", "color": "brown"}, 174 | {"object": "oak", "color": "brown"}, 175 | {"object": "coconut", "color": "brown"}, 176 | {"object": "pecan", "color": "brown"}, 177 | {"object": "cinnamon", "color": "brown"}, 178 | {"object": "deer", "color": "brown"}, 179 | {"object": "moose", "color": "brown"}, 180 | {"object": "camel", "color": "brown"}, 181 | {"object": "horse", "color": "brown"} 182 | ] -------------------------------------------------------------------------------- /MIB-causal-variable-track/baselines/ioi_baselines/ioi_learn_linear_params.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 4 | sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) 5 | 6 | from tasks.IOI_task.ioi_task import get_causal_model, get_counterfactual_datasets, get_token_positions 7 | from CausalAbstraction.experiments.filter_experiment import FilterExperiment 8 | from CausalAbstraction.experiments.attention_head_experiment import PatchAttentionHeads 9 | from ioi_utils import log_diff, clear_memory, checker, filter_checker, setup_pipeline 10 | import torch 11 | import gc 12 | import os 13 | import numpy as np 14 | from sklearn.linear_model import LinearRegression 15 | import json 16 | 17 | if __name__ == "__main__": 18 | import argparse 19 | 20 | parser = argparse.ArgumentParser(description="Compute linear model parameters for IOI experiments.") 21 | parser.add_argument("--model", type=str, required=True, choices=["gpt2", "qwen", "llama", "gemma"], 22 | help="Model 
to use for parameter computation") 23 | parser.add_argument("--use_gpu1", action="store_true", help="Use GPU1 instead of GPU0 if available.") 24 | parser.add_argument("--heads_list", nargs="+", type=lambda s: eval(s), 25 | default=[(7, 3), (7, 9), (8, 6), (8, 10)], 26 | help="List of (layer, head) tuples to intervene on. Example: '(7,3)' '(7,9)'") 27 | parser.add_argument("--eval_batch_size", type=int, default=None, help="Batch size for evaluation (uses model default if not specified)") 28 | parser.add_argument("--output_file", type=str, default="ioi_linear_params.json", help="Output file for linear parameters") 29 | parser.add_argument("--quick_test", action="store_true", help="Run quick test with reduced dataset size") 30 | args = parser.parse_args() 31 | 32 | # Clear memory before starting 33 | gc.collect() 34 | torch.cuda.empty_cache() 35 | 36 | # Device setup 37 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 38 | if args.use_gpu1 and torch.cuda.is_available(): 39 | device = "cuda:1" 40 | 41 | # Get causal model and counterfactual datasets 42 | causal_model = get_causal_model({"bias": 0.0, "token_coeff": 0.0, "position_coeff": 0.0}) 43 | dataset_size = 10 if args.quick_test else None 44 | counterfactual_datasets = get_counterfactual_datasets(hf=True, size=dataset_size) 45 | 46 | # Print dataset info 47 | print("Available datasets:", counterfactual_datasets.keys()) 48 | 49 | # Get a sample to display 50 | sample_dataset = next(iter(counterfactual_datasets.values())) 51 | if len(sample_dataset) > 0: 52 | sample = sample_dataset[0] 53 | print("Sample input:", sample["input"]) 54 | 55 | print(f"\n===== Computing parameters for model: {args.model} =====") 56 | 57 | # Set up pipeline 58 | pipeline, batch_size = setup_pipeline(args.model, device, args.eval_batch_size) 59 | print("DEVICE:", pipeline.model.device) 60 | 61 | # Test model on a sample 62 | if len(sample_dataset) > 0: 63 | sample = sample_dataset[0] 64 | print("INPUT:", sample["input"]["raw_input"]) 65 | expected = causal_model.run_forward(sample["input"])["raw_output"] 66 | print("EXPECTED OUTPUT:", expected) 67 | print("MODEL PREDICTION:", pipeline.dump(pipeline.generate(sample["input"]["raw_input"]))) 68 | 69 | # Filter the datasets 70 | print("\nFiltering datasets based on model performance...") 71 | exp = FilterExperiment(pipeline, causal_model, filter_checker) 72 | filtered_datasets = exp.filter(counterfactual_datasets, verbose=True, batch_size=batch_size) 73 | 74 | # Get token positions 75 | token_positions = get_token_positions(pipeline, causal_model) 76 | 77 | # Limit heads_list for quick test 78 | if args.quick_test and len(args.heads_list) > 1: 79 | args.heads_list = args.heads_list[:1] # Use only the first head for quick test 80 | print(f"Quick test mode: limiting to heads {args.heads_list}") 81 | 82 | print("\nFitting linear model for logit differences...") 83 | 84 | # Set up for return_scores 85 | pipeline.return_scores = True 86 | 87 | # Collect data for linear regression 88 | data_to_X = { 89 | "same_train": {"position": 1, "token": 1}, 90 | "s1_io_flip_train": {"position": -1, "token": 1}, 91 | "s2_io_flip_train": {"position": -1, "token": -1}, 92 | "s1_ioi_flip_s2_ioi_flip_train": {"position": 1, "token": -1} 93 | } 94 | 95 | # Limit datasets for quick test 96 | if args.quick_test: 97 | # Use only first two datasets for quick test 98 | data_to_X = dict(list(data_to_X.items())[:2]) 99 | X, y = [], [] 100 | 101 | for counterfactual_name in data_to_X: 102 | if counterfactual_name not in 
filtered_datasets: 103 | print(f"Warning: {counterfactual_name} not found in filtered datasets, skipping...") 104 | continue 105 | 106 | experiment = PatchAttentionHeads( 107 | pipeline=pipeline, 108 | causal_model=causal_model, 109 | layer_head_list=args.heads_list, 110 | token_positions=token_positions, 111 | checker=lambda logits, params: checker(logits, params, pipeline), 112 | config={"evaluation_batch_size": batch_size, "output_scores": True, "check_raw":True} 113 | ) 114 | 115 | raw_results = experiment.perform_interventions( 116 | {counterfactual_name: filtered_datasets[counterfactual_name]}, 117 | target_variables_list=[["output_token"]], 118 | verbose=False 119 | ) 120 | 121 | raw_outputs = None 122 | losses, labels, counterfactual_y = [], [], [] 123 | 124 | for v in raw_results["dataset"][counterfactual_name].values(): 125 | for v2 in v.values(): 126 | raw_outputs = v2["raw_outputs"][0] 127 | 128 | for raw_logits, input_data in zip(raw_outputs, filtered_datasets[counterfactual_name]): 129 | actual_diff = log_diff(raw_logits, causal_model.run_forward(input_data["input"]), pipeline) 130 | high_level_output = causal_model.run_interchange( 131 | input_data["input"], 132 | {"output_token": input_data["counterfactual_inputs"][0], 133 | "output_position": input_data["counterfactual_inputs"][0]} 134 | ) 135 | loss = checker(raw_logits, high_level_output, pipeline) 136 | label = high_level_output["logit_diff"] 137 | 138 | y.append(actual_diff) 139 | counterfactual_y.append(actual_diff) 140 | X.append((data_to_X[counterfactual_name]["position"], data_to_X[counterfactual_name]["token"])) 141 | losses.append(loss) 142 | labels.append(label) 143 | 144 | # Compute and print the average y for the current counterfactual 145 | avg_y = sum(counterfactual_y) / len(counterfactual_y) if counterfactual_y else 0 146 | print(f"Average y for counterfactual '{counterfactual_name}': {avg_y}") 147 | print(f"Average label for counterfactual '{counterfactual_name}': {sum(labels) / len(labels)}") 148 | print(f"Average loss for counterfactual '{counterfactual_name}': {sum(losses) / len(losses)}") 149 | 150 | # Fit linear model 151 | model = LinearRegression() 152 | X = torch.tensor(X) 153 | y = torch.tensor(y) 154 | model.fit(X, y) 155 | 156 | # Print the coefficients 157 | print("Coefficients:", model.coef_) 158 | print("Intercept:", model.intercept_) 159 | print("Score:", model.score(X, y)) 160 | 161 | intercept = float(model.intercept_) 162 | position_coef = float(model.coef_[0]) 163 | token_coef = float(model.coef_[1]) 164 | 165 | # Store results 166 | from ioi_utils import get_model_config 167 | model_config = get_model_config(args.model) 168 | model_path = model_config["model_path"] 169 | 170 | results = { 171 | args.model: { 172 | "bias": intercept, 173 | "position_coeff": position_coef, 174 | "token_coeff": token_coef, 175 | "score": float(model.score(X, y)), 176 | "model_name": model_path 177 | } 178 | } 179 | 180 | print(f"Linear parameters for {args.model}:") 181 | print(f" bias: {intercept}") 182 | print(f" position_coeff: {position_coef}") 183 | print(f" token_coeff: {token_coef}") 184 | print(f" R² score: {model.score(X, y)}") 185 | 186 | # Save results to JSON file 187 | with open(args.output_file, 'w') as f: 188 | json.dump(results, f, indent=2) 189 | 190 | print(f"\nLinear parameters saved to {args.output_file}") 191 | print("Parameter computation completed.") -------------------------------------------------------------------------------- /MIB-causal-variable-track/tasks/IOI_task/ioi_task.py: 
-------------------------------------------------------------------------------- 1 | import json, random, os, sys 2 | from pathlib import Path 3 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 4 | sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) 5 | 6 | from CausalAbstraction.causal.causal_model import CausalModel, CounterfactualDataset 7 | from CausalAbstraction.neural.LM_units import TokenPosition, get_last_token_index 8 | import copy 9 | 10 | from copy import deepcopy 11 | from tasks.hf_dataloader import load_hf_dataset 12 | import re 13 | 14 | def get_data(path): 15 | with open(path, 'r') as f: 16 | data = json.load(f) 17 | return data 18 | 19 | def get_causal_model(parameters): 20 | """ 21 | Create and return the causal model for IOI task. 22 | """ 23 | # Load data files 24 | 25 | pos_coeff = parameters["position_coeff"] 26 | token_coeff = parameters["token_coeff"] 27 | bias = parameters["bias"] 28 | 29 | variables = ["raw_input", "output_position", "output_token", "name_A", "name_B", "name_C", 30 | "logit_diff", "raw_output"] 31 | 32 | values = { 33 | "raw_input": [""], # Placeholder 34 | "output_position": [0, 1], 35 | "output_token": [""], 36 | "name_A": [""], 37 | "name_B": [""], 38 | "name_C": [""], 39 | "logit_diff": [0.0], # Placeholder for float values 40 | "raw_output": [""] # Placeholder 41 | } 42 | 43 | parents = { 44 | "raw_input":[], 45 | "name_A": [], 46 | "name_B": [], 47 | "name_C": [], 48 | "output_token": ["name_A", "name_B", "name_C"], 49 | "output_position": ["name_A", "name_B", "name_C"], 50 | "logit_diff": ["name_A", "name_B", "name_C", "output_token", "output_position"], 51 | "raw_output": ["output_token"] 52 | } 53 | 54 | def get_name_A(): 55 | return "" 56 | 57 | def get_name_B(): 58 | return "" 59 | 60 | def get_name_C(): 61 | return "" 62 | 63 | def get_output_position(name_A, name_B, name_C): 64 | if name_C == name_A: 65 | return 1 66 | elif name_C == name_B: 67 | return 0 68 | else: 69 | return "Error" 70 | 71 | def get_output_token(name_A, name_B, name_C): 72 | if name_C == name_A: 73 | return name_B 74 | elif name_C == name_B: 75 | return name_A 76 | else: 77 | return "Error" 78 | 79 | def get_logit_diff(name_A, name_B, name_C, output_token, output_position): 80 | token_signal = None 81 | if (name_C == name_A and output_token == name_B) or (name_C == name_B and output_token == name_A): 82 | token_signal = 1 83 | elif (name_C == name_A and output_token == name_A) or (name_C == name_B and output_token == name_B): 84 | token_signal = -1 85 | 86 | position_signal = None 87 | if (name_C == name_A and output_position == 1) or (name_C == name_B and output_position == 0): 88 | position_signal = 1 89 | elif (name_C == name_A and output_position == 0) or (name_C == name_B and output_position == 1): 90 | position_signal = -1 91 | 92 | return bias + token_coeff * token_signal + pos_coeff * position_signal 93 | 94 | def get_raw_output(output_token): 95 | """Generate the raw output (just the output token).""" 96 | return output_token 97 | 98 | mechanisms = { 99 | "raw_input": lambda: "", 100 | "name_A": get_name_A, 101 | "name_B": get_name_B, 102 | "name_C": get_name_C, 103 | "output_token": get_output_token, 104 | "output_position": get_output_position, 105 | "logit_diff": get_logit_diff, 106 | "raw_output": get_raw_output 107 | } 108 | 109 | return CausalModel(variables, values, parents, mechanisms, id="ioi") 110 | 111 | def parse_ioi_example(input): 112 | templates_path = os.path.join(Path(__file__).resolve().parent.parent, 
os.path.join("IOI_task", 'templates.json')) 113 | TEMPLATES = get_data(templates_path) 114 | # Helper to convert template into regex and track variable order 115 | def extract_vars(prompt): 116 | prompt = ' '.join(prompt.split()) # Normalize whitespace 117 | 118 | def template_to_regex(template): 119 | pattern = re.escape(template) 120 | var_counts = {} 121 | 122 | # Match all {var} placeholders in order 123 | all_vars = re.findall(r"\{(name_A|name_B|name_C|place|object)\}", template) 124 | 125 | for var in all_vars: 126 | var_counts[var] = var_counts.get(var, 0) + 1 127 | 128 | if var_counts[var] == 1: 129 | group = f"(?P<{var}>[^,\.]+)" 130 | else: 131 | # Avoid redefining the same named group 132 | group = r"[^,\.]+" 133 | 134 | escaped_var = re.escape(f"{{{var}}}") 135 | pattern = pattern.replace(escaped_var, group, 1) # only replace the first occurrence 136 | 137 | return re.compile(f"^{pattern}$") 138 | 139 | for template in TEMPLATES: 140 | regex = template_to_regex(template) 141 | match = regex.match(prompt) 142 | if match: 143 | return match.groupdict(), template 144 | 145 | print(f"Prompt '{prompt}' does not match any template.") 146 | output = {} 147 | output["raw_input"] = input["prompt"] 148 | if "metadata" in input: 149 | output["name_A"] = input["metadata"]["subject"] 150 | output["name_B"] = input["metadata"]["indirect_object"] 151 | output["name_C"] = input["metadata"]["subject"] 152 | output["object"] = input["metadata"]["object"] if "object" in input["metadata"] else None 153 | output["place"] = input["metadata"]["place"] if "place" in input["metadata"] else None 154 | output["template"] = input["template"] 155 | else: 156 | variables = {} 157 | try: 158 | variables, template = extract_vars(input['prompt']) 159 | output["name_A"] = variables["name_A"] 160 | output["name_B"] = variables["name_B"] 161 | output["name_C"] = variables["name_C"] 162 | output["object"] = variables["object"] if "object" in variables else None 163 | output["place"] = variables["place"] if "place" in variables else None 164 | output["template"] = template 165 | except Exception as e: 166 | print(f"Error parsing prompt: {input['prompt']} {output}") 167 | print(e) 168 | assert False 169 | 170 | 171 | 172 | return output 173 | 174 | def get_counterfactual_datasets(hf=True, size=None, load_private_data=False): 175 | """ 176 | Load and return counterfactual datasets for IOI task. 
177 | """ 178 | 179 | 180 | # Load dataset from HuggingFace with customized parsing 181 | datasets = {} 182 | for split in ["train", "test"]: 183 | temp = load_hf_dataset( 184 | dataset_path="mib-bench/ioi", 185 | split=split, 186 | parse_fn=parse_ioi_example, 187 | size=size, 188 | ignore_names=["random", "abc"] 189 | ) 190 | datasets.update(temp) 191 | 192 | if load_private_data: 193 | private = load_hf_dataset( 194 | dataset_path="mib-bench/ioi_private_test", 195 | split="test", 196 | parse_fn=parse_ioi_example, 197 | size=size, 198 | ignore_names=["random", "abc"] 199 | ) 200 | datasets.update({k+"private":v for k,v in private.items()}) 201 | 202 | # Add "same" counterfactual dataset as post-processing step 203 | # For each existing dataset, create a "same" version where counterfactual_inputs equals input 204 | same_datasets = {} 205 | for dataset_name, dataset in datasets.items(): 206 | same_name = "same_" + dataset_name.split("_")[-1] # "same_train" 207 | 208 | # Create new dataset where counterfactual_inputs = [input] 209 | same_data = { 210 | "input": [], 211 | "counterfactual_inputs": [] 212 | } 213 | 214 | for example in dataset: 215 | same_data["input"].append(example["input"]) 216 | same_data["counterfactual_inputs"].append([copy.deepcopy(example["input"])]) 217 | 218 | same_datasets[same_name] = CounterfactualDataset.from_dict( 219 | same_data, 220 | id=same_name 221 | ) 222 | 223 | # Add the same datasets to the main datasets dict 224 | datasets.update(same_datasets) 225 | 226 | return datasets 227 | 228 | def get_token_positions(pipeline, causal_model): 229 | """ 230 | Get token positions for IOI task interventions. 231 | Returns all token positions. 232 | """ 233 | def get_all_token_positions(input_dict, pipeline): 234 | """Get all token positions in the input.""" 235 | tokens = list(range(len(pipeline.load(input_dict)['input_ids'][0]))) 236 | return tokens 237 | 238 | return [TokenPosition(lambda x: get_all_token_positions(x, pipeline), pipeline, id="all")] -------------------------------------------------------------------------------- /MIB-causal-variable-track/baselines/simple_MCQA_baselines.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 4 | from tasks.simple_MCQA.simple_MCQA import get_token_positions, get_counterfactual_datasets, get_causal_model 5 | from experiments.aggregate_experiments import residual_stream_baselines 6 | from neural.pipeline import LMPipeline 7 | from experiments.filter_experiment import FilterExperiment 8 | import torch 9 | import gc 10 | import os 11 | 12 | if __name__ == "__main__": 13 | import argparse 14 | 15 | parser = argparse.ArgumentParser(description="Run simple MCQA experiments with optional flags.") 16 | parser.add_argument("--skip_gemma", action="store_true", help="Skip running experiments for Gemma model.") 17 | parser.add_argument("--skip_llama", action="store_true", help="Skip running experiments for Llama model.") 18 | parser.add_argument("--skip_qwen", action="store_true", help="Skip running experiments for Qwen model.") 19 | parser.add_argument("--skip_answer_pointer", action="store_true", help="Skip experiments for answer_pointer variable.") 20 | parser.add_argument("--skip_answer", action="store_true", help="Skip experiments for answer variable.") 21 | parser.add_argument("--use_gpu1", action="store_true", help="Use GPU1 instead of GPU0 if available.") 22 | 
parser.add_argument("--methods", nargs="+", 23 | default=["full_vector", "DAS", "DBM+SVD", "DBM+PCA", "DBM", "DBM+SAE"], 24 | help="List of methods to run") 25 | parser.add_argument("--batch_size", type=int, default=64, help="Batch size for training") 26 | parser.add_argument("--eval_batch_size", type=int, default=1024, help="Batch size for evaluation") 27 | parser.add_argument("--results_dir", type=str, default="simple_MCQA_results", help="Directory to save results") 28 | parser.add_argument("--model_dir", type=str, default="simple_MCQA_models", help="Directory to save trained models") 29 | parser.add_argument("--quick_test", action="store_true",) 30 | args = parser.parse_args() 31 | 32 | # Clear memory before starting 33 | gc.collect() 34 | torch.cuda.empty_cache() 35 | 36 | # Device setu 37 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 38 | if args.use_gpu1 and torch.cuda.is_available(): 39 | device = "cuda:1" 40 | 41 | # Check function for evaluating model outputs 42 | def checker(output_text, expected): 43 | return expected in output_text 44 | 45 | # Function to clear memory between experiments 46 | def clear_memory(): 47 | gc.collect() 48 | if torch.cuda.is_available(): 49 | torch.cuda.empty_cache() 50 | torch.cuda.synchronize() 51 | 52 | # Get counterfactual datasets and causal model 53 | dataset_size = 10 if args.quick_test else None 54 | counterfactual_datasets = get_counterfactual_datasets(hf=True, size=dataset_size) 55 | causal_model = get_causal_model() 56 | 57 | # Print available datasets 58 | print("Available datasets:", counterfactual_datasets.keys()) 59 | 60 | # Set up models to test 61 | models = [] 62 | if not args.skip_qwen: 63 | models.append("Qwen/Qwen2.5-0.5B") 64 | if not args.skip_gemma: 65 | models.append("google/gemma-2-2b") 66 | if not args.skip_llama: 67 | models.append("meta-llama/Meta-Llama-3.1-8B-Instruct") 68 | 69 | for model_name in models: 70 | print(f"\n===== Testing model: {model_name} =====") 71 | 72 | # Set up LM Pipeline 73 | pipeline = LMPipeline(model_name, max_new_tokens=1, device=device, dtype=torch.float16) 74 | pipeline.tokenizer.padding_side = "left" 75 | print("DEVICE:", pipeline.model.device) 76 | 77 | # Get a sample input and check model's prediction 78 | sampled_example = next(iter(counterfactual_datasets.values()))[0] 79 | print("INPUT:", sampled_example["input"]) 80 | print("EXPECTED OUTPUT:", causal_model.run_forward(sampled_example["input"])["raw_output"]) 81 | print("MODEL PREDICTION:", pipeline.dump(pipeline.generate(sampled_example["input"]))) 82 | 83 | # Filter the datasets based on model performance 84 | print("\nFiltering datasets based on model performance...") 85 | exp = FilterExperiment(pipeline, causal_model, checker) 86 | filtered_datasets = exp.filter(counterfactual_datasets, verbose=True, batch_size=args.eval_batch_size) 87 | 88 | # Get token positions for intervention 89 | token_positions = get_token_positions(pipeline, causal_model) 90 | 91 | # Display token highlighting for a sample 92 | print("\nToken positions highlighted in samples:") 93 | for dataset in filtered_datasets.values(): 94 | for token_position in token_positions: 95 | example = dataset[0] 96 | print(token_position.highlight_selected_token(example["counterfactual_inputs"][0])) 97 | break 98 | break 99 | 100 | # Clear memory before running experiments 101 | clear_memory() 102 | 103 | # Setup experiment configuration 104 | start = 0 105 | end = pipeline.get_num_layers() 106 | if args.quick_test: 107 | end = 1 108 | 109 | config = { 110 | 
"batch_size": args.batch_size, 111 | "evaluation_batch_size": args.eval_batch_size, 112 | "training_epoch": 8, 113 | "n_features": 16, 114 | "regularization_coefficient": 0.0, 115 | "output_scores": False 116 | } 117 | 118 | # Prepare dataset names 119 | names = ["answerPosition", "randomLetter", "answerPosition_randomLetter"] 120 | 121 | # Make sure results and model directories exist 122 | if not os.path.exists(args.results_dir): 123 | os.makedirs(args.results_dir) 124 | 125 | if not os.path.exists(args.model_dir): 126 | os.makedirs(args.model_dir) 127 | 128 | # Run experiments for answer_pointer 129 | if not args.skip_answer_pointer: 130 | print(f"\nRunning experiments for target variable: answer_pointer") 131 | 132 | # Prepare train and test data dictionaries 133 | train_data = {} 134 | test_data = {} 135 | 136 | for name in names: 137 | if name + "_train" in filtered_datasets: 138 | train_data[name + "_train"] = filtered_datasets[name + "_train"] 139 | if name + "_test" in filtered_datasets: 140 | test_data[name + "_test"] = filtered_datasets[name + "_test"] 141 | if name + "_testprivate" in filtered_datasets: 142 | test_data[name + "_testprivate"] = filtered_datasets[name + "_testprivate"] 143 | 144 | residual_stream_baselines( 145 | pipeline=pipeline, 146 | task=causal_model, 147 | token_positions=token_positions, 148 | train_data=train_data, 149 | test_data=test_data, 150 | config=config, 151 | target_variables=["answer_pointer"], 152 | checker=checker, 153 | start=start, 154 | end=end, 155 | verbose=True, 156 | model_dir=os.path.join(args.model_dir, "answer_pointer"), 157 | results_dir=args.results_dir, 158 | methods=args.methods 159 | ) 160 | clear_memory() 161 | 162 | # Run experiments for answer (using larger feature size) 163 | if not args.skip_answer: 164 | print(f"\nRunning experiments for target variable: answer") 165 | config["n_features"] = pipeline.model.config.hidden_size // 2 166 | 167 | # Prepare train and test data dictionaries (again, in case filtered_datasets changed) 168 | train_data = {} 169 | test_data = {} 170 | 171 | for name in names: 172 | if name + "_train" in filtered_datasets: 173 | train_data[name + "_train"] = filtered_datasets[name + "_train"] 174 | if name + "_test" in filtered_datasets: 175 | test_data[name + "_test"] = filtered_datasets[name + "_test"] 176 | if name + "_testprivate" in filtered_datasets: 177 | test_data[name + "_testprivate"] = filtered_datasets[name + "_testprivate"] 178 | 179 | config["n_features"] = pipeline.model.config.hidden_size // 2 180 | residual_stream_baselines( 181 | pipeline=pipeline, 182 | task=causal_model, 183 | token_positions=token_positions, 184 | train_data=train_data, 185 | test_data=test_data, 186 | config=config, 187 | target_variables=["answer"], 188 | checker=checker, 189 | start=start, 190 | end=end, 191 | verbose=True, 192 | model_dir=os.path.join(args.model_dir, "answer"), 193 | results_dir=args.results_dir, 194 | methods=args.methods 195 | ) 196 | clear_memory() 197 | 198 | # Clean up pipeline to free memory before starting next model 199 | del pipeline 200 | clear_memory() 201 | 202 | print("\nAll experiments completed.") -------------------------------------------------------------------------------- /MIB-causal-variable-track/baselines/ioi_baselines/ioi_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import gc 3 | import numpy as np 4 | from experiments.pyvene_core import _prepare_intervenable_inputs 5 | 6 | def clear_memory(): 7 | 
"""Clear memory between experiments to prevent OOM errors.""" 8 | gc.collect() 9 | if torch.cuda.is_available(): 10 | torch.cuda.empty_cache() 11 | torch.cuda.synchronize() 12 | 13 | def log_diff(logits, params, pipeline): 14 | """ 15 | Compute the difference in logit scores between two tokens. 16 | 17 | Args: 18 | logits: Tensor containing logit scores for tokens 19 | params: Dictionary containing 'name_A', 'name_B', and 'name_C' 20 | pipeline: Pipeline object with tokenizer 21 | 22 | Returns: 23 | Tensor: logit_IO - logit_S 24 | """ 25 | # Extract names from params 26 | name_A = params["name_A"] 27 | name_B = params["name_B"] 28 | name_C = params["name_C"] 29 | 30 | if not isinstance(name_A, list): 31 | name_A = [name_A] 32 | if not isinstance(name_B, list): 33 | name_B = [name_B] 34 | if not isinstance(name_C, list): 35 | name_C = [name_C] 36 | 37 | token_id_A = [pipeline.tokenizer.encode(A, add_special_tokens=False)[0] for A in name_A] 38 | token_id_B = [pipeline.tokenizer.encode(B, add_special_tokens=False)[0] for B in name_B] 39 | token_id_C = [pipeline.tokenizer.encode(C, add_special_tokens=False)[0] for C in name_C] 40 | 41 | token_id_IO, token_id_S = [], [] 42 | for i in range(len(token_id_A)): 43 | if token_id_A[i] == token_id_C[i]: 44 | token_id_S.append(token_id_A[i]) 45 | token_id_IO.append(token_id_B[i]) 46 | elif token_id_B[i] == token_id_C[i]: 47 | token_id_S.append(token_id_B[i]) 48 | token_id_IO.append(token_id_A[i]) 49 | 50 | if isinstance(logits, tuple): 51 | logits = logits[0] 52 | 53 | # Get the logit scores for both tokens 54 | if len(logits.shape) == 3: 55 | logits = logits.squeeze(1) 56 | if len(logits.shape) == 2: 57 | # Create batch indices 58 | batch_indices = torch.arange(logits.shape[0]) 59 | 60 | # Extract specific logits using batch indices 61 | logit_S = logits[batch_indices, token_id_S] 62 | logit_IO = logits[batch_indices, token_id_IO] 63 | elif len(logits.shape) == 1: 64 | logit_S = logits[token_id_S[0]] 65 | logit_IO = logits[token_id_IO[0]] 66 | 67 | return logit_IO - logit_S 68 | 69 | def checker(logits, params, pipeline): 70 | """ 71 | Compute the squared error between the actual logit difference and the target logit difference. 72 | 73 | Args: 74 | logits: Tensor containing logit scores for tokens 75 | params: Dictionary containing 'name_A', 'name_B', 'name_C', and 'logit_diff' 76 | pipeline: Pipeline object with tokenizer 77 | 78 | Returns: 79 | Tensor: Squared error between the computed logit difference and the target logit difference 80 | """ 81 | if isinstance(logits, list): 82 | logits = logits[0] 83 | 84 | target_diff = params["logit_diff"] 85 | actual_diff = log_diff(logits, params, pipeline) 86 | if isinstance(target_diff, torch.Tensor): 87 | target_diff = target_diff.to(actual_diff.device).to(actual_diff.dtype) 88 | 89 | squared_error = (actual_diff - target_diff) ** 2 90 | 91 | return squared_error 92 | 93 | def filter_checker(output_text, expected): 94 | """ 95 | Simple checker for filtering that just checks if the expected token appears in the output. 96 | Used only for dataset filtering, not for the actual experiments. 97 | 98 | Args: 99 | output_text (str): The model's output text 100 | expected (str): The expected output 101 | 102 | Returns: 103 | bool: True if expected token appears in output 104 | """ 105 | return expected in output_text 106 | 107 | def custom_loss(logits, params, pipeline): 108 | """ 109 | Average loss function for training that handles both single examples and batches. 
110 | 111 | Args: 112 | logits: Model logits 113 | params: Parameters (can be single dict or list of dicts for batch) 114 | pipeline: Pipeline object with tokenizer 115 | 116 | Returns: 117 | Tensor: Average loss 118 | """ 119 | if isinstance(params, list): 120 | # params is a list of dicts, one for each example in the batch 121 | total_loss = 0 122 | for i, param_dict in enumerate(params): 123 | # Extract the i-th logits for this example 124 | example_logits = logits[i] if logits.dim() > 1 else logits 125 | loss = checker(example_logits, param_dict, pipeline) 126 | total_loss += loss 127 | return total_loss / len(params) 128 | else: 129 | # Single example case (original behavior) 130 | return checker(logits, params, pipeline).mean() 131 | 132 | def ioi_loss_and_metric_fn(pipeline, intervenable_model, batch, model_units_list): 133 | """ 134 | Calculate loss and evaluation metrics for IOI interventions. 135 | 136 | Uses the checker function as a metric (squared error) and custom_loss 137 | as the loss function for training. 138 | 139 | Args: 140 | pipeline: Pipeline object 141 | intervenable_model: The intervenable model 142 | batch: Batch of data 143 | model_units_list: List of model units 144 | 145 | Returns: 146 | tuple: (loss, eval_metrics, logging_info) 147 | """ 148 | # 1. Prepare intervenable inputs 149 | batched_base, batched_counterfactuals, inv_locations, feature_indices = _prepare_intervenable_inputs( 150 | pipeline, batch, model_units_list) 151 | 152 | # 2. Run the intervenable model to get logits 153 | _, counterfactual_logits = intervenable_model( 154 | batched_base, batched_counterfactuals, 155 | unit_locations=inv_locations, 156 | subspaces=feature_indices 157 | ) 158 | 159 | # 3. Extract the logits (last token position since max_new_tokens=1) 160 | logits = counterfactual_logits.logits[:, -1, :] # Shape: (batch_size, vocab_size) 161 | 162 | # 4. Get the settings/parameters from the batch 163 | # These should contain name_A, name_B, name_C, and logit_diff 164 | settings = batch['setting'] # or batch['label'] depending on how it's structured 165 | 166 | # 5. Compute loss using custom_loss 167 | loss = custom_loss(logits, settings, pipeline) 168 | 169 | # 6. Compute metrics using checker (squared errors) 170 | squared_errors = [] 171 | for i in range(len(logits)): 172 | error = checker(logits[i], settings[i], pipeline) 173 | squared_errors.append(error.item()) 174 | 175 | eval_metrics = { 176 | "mse": np.mean(squared_errors), # Mean squared error 177 | "rmse": np.sqrt(np.mean(squared_errors)) # Root mean squared error 178 | } 179 | 180 | # 7. Prepare logging info 181 | logging_info = { 182 | "batch_size": len(batch['input']), 183 | "avg_logit_diff": np.mean([s['logit_diff'] for s in settings]) 184 | } 185 | 186 | return loss, eval_metrics, logging_info 187 | 188 | def get_model_config(model_name): 189 | """ 190 | Get model configuration based on model name. 
191 | 192 | Args: 193 | model_name (str): One of "gpt2", "qwen", "llama", "gemma" 194 | 195 | Returns: 196 | dict: Configuration dictionary with model_path, batch_size, and special_config 197 | """ 198 | model_configs = { 199 | "gpt2": { 200 | "model_path": "openai-community/gpt2", 201 | "batch_size": 1024, 202 | "special_config": True # Needs special GPT2Config 203 | }, 204 | "qwen": { 205 | "model_path": "Qwen/Qwen2.5-0.5B", 206 | "batch_size": 256, 207 | "special_config": False 208 | }, 209 | "llama": { 210 | "model_path": "meta-llama/Meta-Llama-3.1-8B-Instruct", 211 | "batch_size": 256, 212 | "special_config": False 213 | }, 214 | "gemma": { 215 | "model_path": "google/gemma-2-2b", 216 | "batch_size": 256, 217 | "special_config": False 218 | } 219 | } 220 | 221 | if model_name not in model_configs: 222 | raise ValueError(f"Unknown model name: {model_name}. Choose from {list(model_configs.keys())}") 223 | 224 | return model_configs[model_name] 225 | 226 | def setup_pipeline(model_name, device, eval_batch_size=None): 227 | """ 228 | Set up the pipeline for a given model. 229 | 230 | Args: 231 | model_name (str): One of "gpt2", "qwen", "llama", "gemma" 232 | device (str): Device to use 233 | eval_batch_size (int, optional): Override default batch size 234 | 235 | Returns: 236 | tuple: (pipeline, batch_size) 237 | """ 238 | from neural.pipeline import LMPipeline 239 | 240 | config = get_model_config(model_name) 241 | model_path = config["model_path"] 242 | batch_size = eval_batch_size if eval_batch_size else config["batch_size"] 243 | 244 | if config["special_config"]: 245 | # Special configuration for GPT2 246 | from transformers import GPT2Config 247 | gpt_config = GPT2Config.from_pretrained(model_path) 248 | pipeline = LMPipeline(model_path, max_new_tokens=1, device=device, dtype=torch.float32, 249 | max_length=32, logit_labels=True, position_ids=True, config=gpt_config) 250 | else: 251 | pipeline = LMPipeline(model_path, max_new_tokens=1, device=device, dtype=torch.float16, 252 | max_length=32, logit_labels=True) 253 | 254 | pipeline.tokenizer.padding_side = "left" 255 | 256 | return pipeline, batch_size -------------------------------------------------------------------------------- /MIB-causal-variable-track/baselines/ARC_baselines.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 4 | from tasks.ARC.ARC import get_token_positions, get_counterfactual_datasets, get_causal_model 5 | from experiments.aggregate_experiments import residual_stream_baselines 6 | from neural.pipeline import LMPipeline 7 | from experiments.filter_experiment import FilterExperiment 8 | import torch 9 | import gc 10 | import os 11 | 12 | if __name__ == "__main__": 13 | import argparse 14 | 15 | parser = argparse.ArgumentParser(description="Run ARC Easy experiments with optional flags.") 16 | parser.add_argument("--skip_gemma", action="store_true", help="Skip running experiments for Gemma model.") 17 | parser.add_argument("--skip_llama", action="store_true", help="Skip running experiments for Llama model.") 18 | parser.add_argument("--skip_answer_pointer", action="store_true", help="Skip experiments for answer_pointer variable.") 19 | parser.add_argument("--skip_answer", action="store_true", help="Skip experiments for answer variable.") 20 | parser.add_argument("--use_gpu1", action="store_true", help="Use GPU1 instead of GPU0 if available.") 21 | 
parser.add_argument("--methods", nargs="+", 22 | default=["full_vector", "DAS", "DBM+SVD", "DBM+PCA", "DBM", "DBM+SAE"], 23 | help="List of methods to run") 24 | parser.add_argument("--batch_size", type=int, default=48, help="Batch size for training") 25 | parser.add_argument("--eval_batch_size", type=int, default=48, help="Batch size for evaluation") 26 | parser.add_argument("--results_dir", type=str, default="ARC_results", help="Directory to save results") 27 | parser.add_argument("--model_dir", type=str, default="ARC_models", help="Directory to save trained models") 28 | parser.add_argument("--quick_test", action="store_true", help="Run quick test with reduced dataset size and layers") 29 | args = parser.parse_args() 30 | 31 | # Clear memory before starting 32 | gc.collect() 33 | torch.cuda.empty_cache() 34 | 35 | # Device setup 36 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 37 | if args.use_gpu1 and torch.cuda.is_available(): 38 | device = "cuda:1" 39 | 40 | # Check function for evaluating model outputs 41 | def checker(output_text, expected): 42 | return expected in output_text 43 | 44 | # Function to clear memory between experiments 45 | def clear_memory(): 46 | gc.collect() 47 | if torch.cuda.is_available(): 48 | torch.cuda.empty_cache() 49 | torch.cuda.synchronize() 50 | 51 | # Get counterfactual datasets and causal model 52 | dataset_size = 10 if args.quick_test else None 53 | counterfactual_datasets = get_counterfactual_datasets(hf=True, size=dataset_size) 54 | causal_model = get_causal_model() 55 | 56 | # Print available datasets 57 | print("Available datasets:", counterfactual_datasets.keys()) 58 | 59 | # Set up models to test 60 | models = [] 61 | if not args.skip_gemma: 62 | models.append("google/gemma-2-2b") 63 | if not args.skip_llama: 64 | models.append("meta-llama/Meta-Llama-3.1-8B-Instruct") 65 | 66 | for model_name in models: 67 | print(f"\n===== Testing model: {model_name} =====") 68 | 69 | # Set up LM Pipeline 70 | pipeline = LMPipeline(model_name, max_new_tokens=1, device=device, dtype=torch.float16) 71 | pipeline.tokenizer.padding_side = "left" 72 | print("DEVICE:", pipeline.model.device) 73 | 74 | # Get a sample input and check model's prediction 75 | sampled_example = next(iter(counterfactual_datasets.values()))[0] 76 | print("INPUT:", sampled_example["input"]) 77 | print("EXPECTED OUTPUT:", causal_model.run_forward(sampled_example["input"])["raw_output"]) 78 | print("MODEL PREDICTION:", pipeline.dump(pipeline.generate(sampled_example["input"]))) 79 | 80 | # Filter the datasets based on model performance 81 | print("\nFiltering datasets based on model performance...") 82 | exp = FilterExperiment(pipeline, causal_model, checker) 83 | filtered_datasets = exp.filter(counterfactual_datasets, verbose=True, batch_size=args.eval_batch_size) 84 | 85 | # Get token positions for intervention 86 | token_positions = get_token_positions(pipeline, causal_model) 87 | 88 | # Display token highlighting for a sample 89 | print("\nToken positions highlighted in samples:") 90 | for dataset in filtered_datasets.values(): 91 | for token_position in token_positions: 92 | example = dataset[0] 93 | print(token_position.highlight_selected_token(example["counterfactual_inputs"][0])) 94 | break 95 | break 96 | 97 | # Clear memory before running experiments 98 | clear_memory() 99 | 100 | # Setup experiment configuration 101 | start = 0 102 | end = 1 if args.quick_test else pipeline.get_num_layers() 103 | 104 | config = { 105 | "batch_size": args.batch_size, 106 | 
"evaluation_batch_size": args.eval_batch_size, 107 | "training_epoch": 2, 108 | "n_features": 16, 109 | "regularization_coefficient": 0.0, 110 | "output_scores": False 111 | } 112 | if model_name == "Llama-3.1-8B-Instruct": 113 | config["batch_size"] = 16 114 | config["evaluation_batch_size"] = 16 115 | 116 | # Prepare dataset names - ARC has different counterfactual types than MCQA 117 | # Based on the original code, ARC likely has: answerPosition, randomLetter, answerPosition_randomLetter 118 | names = ["answerPosition", "randomLetter", "answerPosition_randomLetter"] 119 | 120 | # Make sure results and model directories exist 121 | if not os.path.exists(args.results_dir): 122 | os.makedirs(args.results_dir) 123 | 124 | if not os.path.exists(args.model_dir): 125 | os.makedirs(args.model_dir) 126 | 127 | # Run experiments for answer_pointer 128 | if not args.skip_answer_pointer: 129 | print(f"\nRunning experiments for target variable: answer_pointer") 130 | 131 | # Prepare train and test data dictionaries 132 | train_data = {} 133 | test_data = {} 134 | 135 | for name in names: 136 | if name + "_train" in filtered_datasets: 137 | train_data[name + "_train"] = filtered_datasets[name + "_train"] 138 | if name + "_validation" in filtered_datasets: 139 | train_data[name + "_validation"] = filtered_datasets[name + "_validation"] 140 | if name + "_test" in filtered_datasets: 141 | test_data[name + "_test"] = filtered_datasets[name + "_test"] 142 | if name + "_testprivate" in filtered_datasets: 143 | test_data[name + "_testprivate"] = filtered_datasets[name + "_testprivate"] 144 | 145 | residual_stream_baselines( 146 | pipeline=pipeline, 147 | task=causal_model, 148 | token_positions=token_positions, 149 | train_data=train_data, 150 | test_data=test_data, 151 | config=config, 152 | target_variables=["answer_pointer"], 153 | checker=checker, 154 | start=start, 155 | end=end, 156 | verbose=True, 157 | model_dir=os.path.join(args.model_dir, "answer_pointer"), 158 | results_dir=args.results_dir, 159 | methods=args.methods 160 | ) 161 | clear_memory() 162 | 163 | # Run experiments for answer (using larger feature size) 164 | if not args.skip_answer: 165 | print(f"\nRunning experiments for target variable: answer") 166 | config["n_features"] = pipeline.model.config.hidden_size // 2 167 | 168 | # Prepare train and test data dictionaries 169 | train_data = {} 170 | test_data = {} 171 | 172 | for name in names: 173 | if name + "_train" in filtered_datasets: 174 | train_data[name + "_train"] = filtered_datasets[name + "_train"] 175 | if name + "_validation" in filtered_datasets: 176 | train_data[name + "_validation"] = filtered_datasets[name + "_validation"] 177 | if name + "_test" in filtered_datasets: 178 | test_data[name + "_test"] = filtered_datasets[name + "_test"] 179 | if name + "_testprivate" in filtered_datasets: 180 | test_data[name + "_testprivate"] = filtered_datasets[name + "_testprivate"] 181 | 182 | config["n_features"] = pipeline.model.config.hidden_size // 2 183 | residual_stream_baselines( 184 | pipeline=pipeline, 185 | task=causal_model, 186 | token_positions=token_positions, 187 | train_data=train_data, 188 | test_data=test_data, 189 | config=config, 190 | target_variables=["answer"], 191 | checker=checker, 192 | start=start, 193 | end=end, 194 | verbose=True, 195 | model_dir=os.path.join(args.model_dir, "answer"), 196 | results_dir=args.results_dir, 197 | methods=args.methods 198 | ) 199 | clear_memory() 200 | 201 | # Clean up pipeline to free memory before starting next model 202 | 
del pipeline 203 | clear_memory() 204 | 205 | print("\nAll experiments completed.") -------------------------------------------------------------------------------- /MIB-causal-variable-track/tasks/RAVEL/ravel.py: -------------------------------------------------------------------------------- 1 | import sys, os, json, random, re 2 | from pathlib import Path 3 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 4 | sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) 5 | 6 | from CausalAbstraction.causal.causal_model import CausalModel 7 | from CausalAbstraction.neural.LM_units import TokenPosition, get_last_token_index 8 | 9 | 10 | # Get the current file's directory 11 | current_dir = os.path.dirname(os.path.abspath(__file__)) 12 | 13 | # Get the parent directory of the parent directory (grandparent) 14 | grandparent_dir = os.path.dirname(os.path.dirname(current_dir)) 15 | 16 | # Add the grandparent directory to the path 17 | sys.path.append(grandparent_dir) 18 | 19 | from copy import deepcopy 20 | from tasks.hf_dataloader import load_hf_dataset 21 | 22 | 23 | # Load RAVEL metadata - only keep the essential city entity data 24 | _RAVEL_DATA_DIR = os.path.dirname(os.path.abspath(__file__)) 25 | _CITY_ENTITY = {} 26 | 27 | 28 | def load_city_entity_data(): 29 | """Load city entity data if not already loaded.""" 30 | if not _CITY_ENTITY: 31 | city_data_path = os.path.join(_RAVEL_DATA_DIR, 'ravel_city_entity.json') 32 | if os.path.exists(city_data_path): 33 | _CITY_ENTITY.update(json.load(open(city_data_path))) 34 | else: 35 | # If file doesn't exist, create a minimal set for testing 36 | _CITY_ENTITY.update({ 37 | "Paris": {"Continent": "Europe", "Country": "France", "Language": "French"}, 38 | "Tokyo": {"Continent": "Asia", "Country": "Japan", "Language": "Japanese"}, 39 | "New York": {"Continent": "North America", "Country": "United States", "Language": "English"}, 40 | # Add more as needed 41 | }) 42 | 43 | 44 | def get_causal_model(): 45 | """ 46 | Create and return the causal model for RAVEL task. 
47 | """ 48 | # Ensure city data is loaded 49 | load_city_entity_data() 50 | 51 | # Define variables 52 | attributes = ["Continent", "Country", "Language"] 53 | variables = ["raw_input", "entity", "queried_attribute", "answer", "raw_output"] + attributes 54 | 55 | # Define parent relationships 56 | parents = { 57 | "raw_input": ["entity", "queried_attribute"], # raw_input depends on entity and attribute 58 | "entity": [], 59 | "queried_attribute": [], 60 | "answer": ["entity", "queried_attribute", "Continent", "Country", "Language"], 61 | "Continent": ["entity"], 62 | "Country": ["entity"], 63 | "Language": ["entity"], 64 | "raw_output": ["answer"] # raw_output depends on answer 65 | } 66 | 67 | # Define possible values for each variable 68 | values = { 69 | "raw_input": [""], # Placeholder, generated by mechanism 70 | "entity": list(_CITY_ENTITY.keys()), 71 | "queried_attribute": attributes + ["wikipedia"], # Include wikipedia as a possible query 72 | "answer": [""], # Will be populated with all possible answers 73 | "Continent": list(set(city["Continent"] for city in _CITY_ENTITY.values())), 74 | "Country": list(set(city["Country"] for city in _CITY_ENTITY.values())), 75 | "Language": list(set(city["Language"] for city in _CITY_ENTITY.values())), 76 | "raw_output": [""] # Placeholder, generated by mechanism 77 | } 78 | 79 | # Collect all possible answers 80 | all_answers = set() 81 | for city_data in _CITY_ENTITY.values(): 82 | all_answers.update(city_data.values()) 83 | all_answers.add("") # For wikipedia queries 84 | values["answer"] = list(all_answers) 85 | 86 | # Define mechanisms 87 | def get_raw_input(entity, queried_attribute): 88 | """Generate the input prompt based on entity and attribute.""" 89 | if queried_attribute == "wikipedia": 90 | return f"Q: What is {entity}? A:" 91 | else: 92 | return f"Q: What is the {queried_attribute.lower()} of {entity}? A:" 93 | 94 | def get_entity(): 95 | """Randomly select an entity.""" 96 | return random.choice(list(_CITY_ENTITY.keys())) 97 | 98 | def get_queried_attribute(): 99 | """Randomly select an attribute to query.""" 100 | return random.choice(attributes + ["wikipedia"]) 101 | 102 | def get_answer(entity, q_attr, *attr_values): 103 | """Get the answer based on entity and queried attribute.""" 104 | if q_attr == "wikipedia": 105 | return "" # Empty answer for wikipedia queries 106 | idx = attributes.index(q_attr) 107 | return attr_values[idx] 108 | 109 | def get_continent(entity): 110 | """Get continent for the entity.""" 111 | return _CITY_ENTITY[entity]["Continent"] 112 | 113 | def get_country(entity): 114 | """Get country for the entity.""" 115 | return _CITY_ENTITY[entity]["Country"] 116 | 117 | def get_language(entity): 118 | """Get language for the entity.""" 119 | return _CITY_ENTITY[entity]["Language"] 120 | 121 | def get_raw_output(answer): 122 | """Format the output.""" 123 | return f" {answer}" if answer else "" 124 | 125 | mechanisms = { 126 | "raw_input": get_raw_input, 127 | "entity": get_entity, 128 | "queried_attribute": get_queried_attribute, 129 | "answer": get_answer, 130 | "Continent": get_continent, 131 | "Country": get_country, 132 | "Language": get_language, 133 | "raw_output": get_raw_output 134 | } 135 | 136 | return CausalModel(variables, values, parents, mechanisms, id="RAVEL") 137 | 138 | 139 | def get_counterfactual_datasets(hf=True, size=None, load_private_data=False): 140 | """ 141 | Load and return counterfactual datasets for RAVEL task. 
142 | """ 143 | # Ensure city data is loaded 144 | load_city_entity_data() 145 | 146 | if hf: 147 | # Load dataset from HuggingFace with customized parsing 148 | datasets = {} 149 | for split in ["train", "val", "test"]: 150 | temp = load_hf_dataset( 151 | dataset_path="mib-bench/ravel", 152 | split=split, 153 | parse_fn=parse_ravel_example, 154 | size=size, 155 | shuffle=True 156 | ) 157 | datasets.update(temp) 158 | 159 | # Load private test set 160 | if load_private_data: 161 | private = load_hf_dataset( 162 | dataset_path="mib-bench/ravel_private_test", 163 | split="test", 164 | parse_fn=parse_ravel_example, 165 | size=size, 166 | shuffle=True 167 | ) 168 | datasets.update({k+"private": v for k, v in private.items()}) 169 | 170 | return datasets 171 | 172 | # Non-HF implementation would go here if needed 173 | return {} 174 | 175 | 176 | def get_token_positions(pipeline, causal_model): 177 | """ 178 | Get token positions for RAVEL task interventions. 179 | """ 180 | def get_entity_last_token_position(input_dict, pipeline): 181 | """ 182 | Find the last token position of the entity in the prompt. 183 | 184 | Args: 185 | input_dict: Dictionary containing the input data 186 | pipeline: LMPipeline for tokenization 187 | 188 | Returns: 189 | List containing the token index of the last entity token 190 | """ 191 | # Get the prompt and entity 192 | if isinstance(input_dict, dict): 193 | prompt = input_dict.get("raw_input", "") 194 | # Run causal model to get the entity 195 | setting = causal_model.run_forward(input_dict) 196 | entity = setting["entity"] 197 | else: 198 | # Fallback if input is just a string 199 | prompt = input_dict 200 | # Try to extract entity from the prompt 201 | entity = None 202 | for city in _CITY_ENTITY.keys(): 203 | if city in prompt: 204 | entity = city 205 | break 206 | 207 | if not entity: 208 | raise ValueError(f"Could not find entity in prompt: {prompt}") 209 | 210 | # Find the entity in the prompt 211 | entity_match = re.search(r'\b' + re.escape(entity) + r'\b', prompt) 212 | if not entity_match: 213 | raise ValueError(f"Entity '{entity}' not found in prompt: {prompt}") 214 | 215 | # Get the substring up to the end of the entity 216 | substring = prompt[:entity_match.end()] 217 | 218 | # Tokenize the substring 219 | tokens = pipeline.load(substring)["input_ids"][0] 220 | 221 | # The last token of the entity is at the end of the tokenized substring 222 | return [len(tokens) - 1] 223 | 224 | # Create TokenPosition objects 225 | token_positions = [ 226 | TokenPosition(lambda x: get_last_token_index(x, pipeline), pipeline, id="last_token"), 227 | TokenPosition(lambda x: get_entity_last_token_position(x, pipeline), pipeline, id="entity_last_token") 228 | ] 229 | 230 | return token_positions 231 | 232 | 233 | def parse_ravel_example(row): 234 | """ 235 | Convert a single dataset row into a dict for the RAVEL causal model. 
236 | 237 | Args: 238 | row: A row from the RAVEL dataset 239 | 240 | Returns: 241 | Dict containing the parsed variables 242 | """ 243 | # Extract the basic information 244 | entity = row.get("entity", "") 245 | attribute = row.get("attribute", "") 246 | prompt = row.get("prompt", "") 247 | 248 | # Create the variables dictionary 249 | variables_dict = { 250 | "entity": entity, 251 | "queried_attribute": attribute, 252 | } 253 | 254 | # If we have a prompt, we can use it as raw_input 255 | if prompt: 256 | variables_dict["raw_input"] = prompt 257 | 258 | return variables_dict -------------------------------------------------------------------------------- /MIB-causal-variable-track/process_all_submissions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Process all submission folders by running appropriate evaluation scripts and aggregating results. 4 | 5 | This script: 6 | 1. Scans a parent directory for submission folders 7 | 2. Identifies IOI tasks (folders containing subfolders with "ioi" in the name) 8 | 3. Runs ioi_evaluate_submission.py for IOI tasks 9 | 4. Runs evaluate_submission.py for other tasks 10 | 5. Aggregates results using aggregate_results.py 11 | 6. Saves aggregated results to output folder 12 | 13 | Usage: 14 | python process_all_submissions.py --parent_dir submissions/ --output_dir results/ 15 | """ 16 | 17 | import os 18 | import sys 19 | import subprocess 20 | import argparse 21 | import json 22 | import shutil 23 | from pathlib import Path 24 | 25 | 26 | def has_ioi_subfolder(submission_path): 27 | """ 28 | Check if a submission folder contains any subfolder with 'ioi' in the name. 29 | 30 | Args: 31 | submission_path (str): Path to submission folder 32 | 33 | Returns: 34 | bool: True if IOI subfolder found 35 | """ 36 | try: 37 | for item in os.listdir(submission_path): 38 | item_path = os.path.join(submission_path, item) 39 | if os.path.isdir(item_path) and 'ioi' in item.lower(): 40 | return True 41 | except Exception as e: 42 | print(f"Error checking {submission_path}: {e}") 43 | return False 44 | 45 | 46 | def run_command(cmd, description): 47 | """ 48 | Run a command and capture output. 49 | 50 | Args: 51 | cmd (list): Command and arguments 52 | description (str): Description of what the command does 53 | 54 | Returns: 55 | bool: True if successful 56 | """ 57 | print(f"\n{'='*60}") 58 | print(f"Running: {description}") 59 | print(f"Command: {' '.join(cmd)}") 60 | print(f"{'='*60}") 61 | 62 | try: 63 | result = subprocess.run(cmd, capture_output=True, text=True) 64 | 65 | # Print output 66 | if result.stdout: 67 | print(result.stdout) 68 | if result.stderr: 69 | print("STDERR:", result.stderr) 70 | 71 | if result.returncode != 0: 72 | print(f"ERROR: Command failed with return code {result.returncode}") 73 | return False 74 | 75 | return True 76 | except Exception as e: 77 | print(f"ERROR running command: {e}") 78 | return False 79 | 80 | 81 | def process_submission(submission_path, output_dir, private_data=True): 82 | """ 83 | Process a single submission folder. 
84 | 85 | Args: 86 | submission_path (str): Path to submission folder 87 | output_dir (str): Directory to save results 88 | private_data (bool): Whether to evaluate on private data 89 | 90 | Returns: 91 | dict: Results summary 92 | """ 93 | submission_name = os.path.basename(submission_path) 94 | print(f"\n{'#'*80}") 95 | print(f"# Processing submission: {submission_name}") 96 | print(f"{'#'*80}") 97 | 98 | results = { 99 | "submission": submission_name, 100 | "has_ioi": False, 101 | "evaluation_success": False, 102 | "aggregation_success": False, 103 | "error": None 104 | } 105 | 106 | # Check if this is an IOI submission 107 | is_ioi = has_ioi_subfolder(submission_path) 108 | results["has_ioi"] = is_ioi 109 | 110 | # Create output directory for this submission 111 | submission_output_dir = os.path.join(output_dir, submission_name) 112 | os.makedirs(submission_output_dir, exist_ok=True) 113 | 114 | try: 115 | if is_ioi: 116 | # Run IOI evaluation 117 | print(f"Detected IOI submission - will evaluate using ioi_linear_params.json") 118 | 119 | cmd = [ 120 | sys.executable, 121 | "ioi_evaluate_submission.py", 122 | "--submission_folder", submission_path 123 | ] 124 | 125 | if private_data: 126 | cmd.append("--private_data") 127 | 128 | success = run_command(cmd, "IOI evaluation") 129 | results["evaluation_success"] = success 130 | 131 | else: 132 | # Run standard evaluation 133 | print(f"Detected standard submission - will run evaluate_submission.py") 134 | 135 | cmd = [ 136 | sys.executable, 137 | "evaluate_submission.py", 138 | "--submission_folder", submission_path 139 | ] 140 | 141 | if private_data: 142 | cmd.append("--private_data") 143 | 144 | success = run_command(cmd, "Standard evaluation") 145 | results["evaluation_success"] = success 146 | 147 | if results["evaluation_success"]: 148 | # Find all JSON result files in the submission folder 149 | json_files = [] 150 | for root, dirs, files in os.walk(submission_path): 151 | for file in files: 152 | if file.endswith('.json') and 'results' in file: 153 | json_files.append(os.path.join(root, file)) 154 | 155 | if json_files: 156 | print(f"\nFound {len(json_files)} result files to aggregate") 157 | 158 | # Copy result files to output directory for aggregation 159 | temp_results_dir = os.path.join(submission_output_dir, "temp_results") 160 | os.makedirs(temp_results_dir, exist_ok=True) 161 | 162 | for json_file in json_files: 163 | dest_path = os.path.join(temp_results_dir, os.path.basename(json_file)) 164 | shutil.copy2(json_file, dest_path) 165 | 166 | # Run aggregation 167 | aggregated_output = os.path.join(submission_output_dir, "aggregated_results.json") 168 | 169 | cmd = [ 170 | sys.executable, 171 | "aggregate_results.py", 172 | "--folder_path", temp_results_dir, 173 | "--output", aggregated_output 174 | ] 175 | 176 | if private_data: 177 | cmd.append("--private") 178 | 179 | success = run_command(cmd, "Result aggregation") 180 | results["aggregation_success"] = success 181 | 182 | # Clean up temp directory 183 | shutil.rmtree(temp_results_dir) 184 | 185 | # Save summary 186 | summary_path = os.path.join(submission_output_dir, "processing_summary.json") 187 | with open(summary_path, 'w') as f: 188 | json.dump(results, f, indent=2) 189 | 190 | else: 191 | print("WARNING: No result files found after evaluation") 192 | results["error"] = "No result files generated" 193 | 194 | except Exception as e: 195 | print(f"ERROR processing submission: {e}") 196 | results["error"] = str(e) 197 | 198 | return results 199 | 200 | 201 | def main(): 
202 | parser = argparse.ArgumentParser(description="Process all submission folders") 203 | parser.add_argument("--parent_dir", required=True, 204 | help="Parent directory containing submission folders") 205 | parser.add_argument("--output_dir", required=True, 206 | help="Directory to save aggregated results") 207 | parser.add_argument("--private_data", action="store_true", default=True, 208 | help="Evaluate on private test data (default: True)") 209 | parser.add_argument("--specific_submission", type=str, default=None, 210 | help="Process only a specific submission folder") 211 | 212 | args = parser.parse_args() 213 | 214 | parent_dir = os.path.abspath(args.parent_dir) 215 | output_dir = os.path.abspath(args.output_dir) 216 | 217 | if not os.path.exists(parent_dir): 218 | print(f"ERROR: Parent directory does not exist: {parent_dir}") 219 | return 1 220 | 221 | # Create output directory 222 | os.makedirs(output_dir, exist_ok=True) 223 | 224 | print(f"Processing submissions in: {parent_dir}") 225 | print(f"Output directory: {output_dir}") 226 | print(f"Private data evaluation: {args.private_data}") 227 | 228 | # Find all submission folders 229 | submission_folders = [] 230 | for item in os.listdir(parent_dir): 231 | item_path = os.path.join(parent_dir, item) 232 | if os.path.isdir(item_path) and not item.startswith('.'): 233 | # Skip special directories 234 | if item in ['__pycache__', '.git', 'results', 'logs']: 235 | continue 236 | if args.specific_submission is None or item == args.specific_submission: 237 | submission_folders.append(item_path) 238 | 239 | if not submission_folders: 240 | print("ERROR: No submission folders found") 241 | return 1 242 | 243 | print(f"\nFound {len(submission_folders)} submission folders to process") 244 | 245 | # Process each submission 246 | all_results = [] 247 | successful = 0 248 | 249 | for submission_path in submission_folders: 250 | results = process_submission(submission_path, output_dir, args.private_data) 251 | all_results.append(results) 252 | 253 | if results["evaluation_success"] and results["aggregation_success"]: 254 | successful += 1 255 | 256 | # Save overall summary 257 | summary_path = os.path.join(output_dir, "overall_summary.json") 258 | with open(summary_path, 'w') as f: 259 | json.dump({ 260 | "total_submissions": len(submission_folders), 261 | "successful": successful, 262 | "results": all_results 263 | }, f, indent=2) 264 | 265 | print(f"\n{'='*80}") 266 | print(f"PROCESSING COMPLETE") 267 | print(f"Successfully processed: {successful}/{len(submission_folders)} submissions") 268 | print(f"Results saved to: {output_dir}") 269 | print(f"Overall summary: {summary_path}") 270 | print(f"{'='*80}") 271 | 272 | return 0 if successful == len(submission_folders) else 1 273 | 274 | 275 | if __name__ == "__main__": 276 | sys.exit(main()) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2025 Aaron Mueller, Atticus Geiger, Sarah Wiegreffe, and 190 | other authors of MIB: A Mechanistic Interpretability Benchmark. 
191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /MIB-causal-variable-track/README.md: -------------------------------------------------------------------------------- 1 | # MIB Causal Variable Localization Track 2 | 3 | This repository contains the implementation of the Causal Variable Localization Track of the **M**echanistic **I**nterpretability **B**enchmark (**MIB**). This track benchmarks featurization methods—i.e., methods for transforming model activations into a space where it's easier to isolate and intervene on specific causal variables. 4 | 5 |
11 | 12 | ## Leaderboard Submission Overview 13 | 14 | The Causal Variable Localization Track evaluates methods that can identify and manipulate specific causal variables within language model (LM) representations. A **high-level causal model** serves as a hypothesis about how the **low-level LM** solves the task. A **featurizer** is an invertible function that maps the activation space of a hidden vector to a new space; we refer to the dimensions of this new space as **features**. A submission to this track aligns a variable from the causal model with features of a hidden vector in the LM. As such, a submission consists of a **token position** that identifies residual stream or attention head hidden vectors in the LM and a **featurizer** that maps that hidden vector's activation space into a new feature space. **A submission can cover any number of layers, but only submissions that provide a token position and featurizer for every layer are added to the leaderboard for average performance; otherwise, only best performance is reported.** 15 | 16 | See the Jupyter notebook [ioi_example_submission.ipynb](https://github.com/aaronmueller/MIB/blob/main/MIB-causal-variable-track/ioi_example_submission.ipynb) for an example of how to produce the trained featurizer and token indices files for the IOI task, and the Jupyter notebook [example_submission.ipynb](https://github.com/aaronmueller/MIB/blob/main/MIB-causal-variable-track/example_submission.ipynb) for an example covering the remaining tasks. These notebooks save models in mock_submission folders. Currently, the [mock_submission](https://github.com/aaronmueller/MIB/blob/main/MIB-causal-variable-track/) folder contains a mock featurizer file and a mock token position file. **If no featurizer or token position is provided in the submission, the featurizers and token positions from the original baselines can still be used.** 17 | 18 | The resulting submission must be formatted in the folder structure shown below and then submitted to the leaderboard at [this link](https://huggingface.co/spaces/mib-bench/leaderboard). See an example submission Hugging Face repo at [this link](https://huggingface.co/atticusg/SampleSubmission/tree/main). Run the verification script [verify_submission.py](https://github.com/aaronmueller/MIB/blob/main/MIB-causal-variable-track/verify_submission.py) to ensure that custom token positions and featurizers load properly. When we receive your submission, we will run the evaluation scripts [evaluate_submission.py](https://github.com/aaronmueller/MIB/blob/main/MIB-causal-variable-track/evaluate_submission.py) and [ioi_evaluate_submission.py](https://github.com/aaronmueller/MIB/blob/main/MIB-causal-variable-track/ioi_evaluation_submission.py) on the private test sets and report the results on the leaderboard. 
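To make the featurizer/inverse-featurizer contract concrete, here is a minimal, self-contained sketch of an invertible rotation that exposes a small block of feature dimensions for an interchange intervention. It is illustrative only: the class and variable names (`OrthogonalFeaturizer`, `d_model`, `k`) are hypothetical, and it does not use the benchmark's `Featurizer` base class or `TokenPosition` machinery.

```python
import torch

class OrthogonalFeaturizer(torch.nn.Module):
    """Rotate a hidden vector into a feature basis; the first k dimensions
    play the role of the features aligned with a causal variable."""
    def __init__(self, d_model: int):
        super().__init__()
        q, _ = torch.linalg.qr(torch.randn(d_model, d_model))
        self.register_buffer("rotation", q)  # orthogonal, hence invertible

    def featurize(self, hidden: torch.Tensor) -> torch.Tensor:
        return hidden @ self.rotation        # activation space -> feature space

    def invert(self, features: torch.Tensor) -> torch.Tensor:
        return features @ self.rotation.T    # feature space -> activation space

d_model, k = 768, 16                          # e.g., a GPT-2 Small residual stream
featurizer = OrthogonalFeaturizer(d_model)
h_base, h_source = torch.randn(2, d_model)    # hidden vectors at the chosen token position

features = featurizer.featurize(h_base)
features[:k] = featurizer.featurize(h_source)[:k]  # interchange the aligned features
h_patched = featurizer.invert(features)            # write back into the LM's activation space
```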
19 | 20 | ## Repository Structure 21 | 22 | ``` 23 | MIB-causal-variable-track/ 24 | ├── tasks/ # Task definitions and datasets 25 | │ ├── IOI_task/ # Indirect Object Identification task 26 | │ ├── simple_MCQA/ # Simple Multiple Choice QA task 27 | │ ├── ARC/ # ARC (AI2 Reasoning Challenge) task 28 | │ ├── two_digit_addition_task/ # Arithmetic task 29 | │ └── RAVEL/ # RAVEL knowledge editing task 30 | ├── CausalAbstraction/ # Core functionality submodule 31 | │ ├── causal/ # Causal models and datasets 32 | │ ├── neural/ # Neural network units and featurizers 33 | │ └── experiments/ # Intervention experiments 34 | ├── baselines/ # Baseline method implementations 35 | ├── example_submission.ipynb # Example submission for standard tasks 36 | ├── ioi_example_submission.ipynb # Example submission for IOI task 37 | ├── evaluate_submission.py # Evaluation script for standard tasks 38 | ├── ioi_evaluate_submission.py # Evaluation script for IOI task 39 | ├── aggregate_results.py # Results aggregation script 40 | ├── process_all_submissions.py # Batch submission processing 41 | └── verify_submission.py # Submission format verification 42 | ``` 43 | 44 | ## Getting Started 45 | 46 | ### Installation 47 | 48 | 1. Clone this repository: 49 | ```bash 50 | git clone https://github.com/your-repo/MIB-causal-variable-track.git 51 | cd MIB-causal-variable-track 52 | ``` 53 | 54 | 2. Install dependencies: 55 | ```bash 56 | pip install -r requirements.txt 57 | ``` 58 | 59 | 3. Initialize the CausalAbstraction submodule: 60 | ```bash 61 | git submodule update --init --recursive 62 | ``` 63 | 64 | **Note: `sae_lens` requires Python 3.12.** 65 | 66 | ### Running Example Submissions 67 | 68 | We provide two example notebooks that demonstrate how to create submissions: 69 | 70 | 1. **Standard Tasks Example** (`example_submission.ipynb`): Shows how to train and submit featurizers for tasks like Simple MCQA, ARC, Arithmetic, and RAVEL 71 | 2. **IOI Task Example** (`ioi_example_submission.ipynb`): Demonstrates the specific workflow for the IOI (Indirect Object Identification) task, which involves attention heads 72 | 73 | ## Available Tasks 74 | 75 | The benchmark includes five tasks: 76 | 77 | - **IOI (Indirect Object Identification)**: Tests the model's ability to track entity relationships and predict the correct indirect object 78 | - **Simple MCQA**: Multiple choice question answering with object-color associations 79 | - **ARC**: AI2 Reasoning Challenge questions testing scientific reasoning 80 | - **Arithmetic**: Two-digit addition problems testing numerical computation 81 | - **RAVEL**: Knowledge editing task for factual associations about countries 82 | 83 | Each task defines specific causal variables that methods must learn to identify and manipulate. 
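Each task module exposes helper functions for loading its counterfactual datasets, its high-level causal model, and the token positions to intervene on. The sketch below mirrors the imports used in `baselines/ravel_baselines.py`; the `sys.path` setup is an assumption and may need adjusting depending on where you run it and where the `CausalAbstraction` submodule is checked out, and the small `size` argument is only for inspection.

```python
import sys
from pathlib import Path

# Assumption: run from the repository root, with the CausalAbstraction submodule initialized.
sys.path.append(str(Path("MIB-causal-variable-track").resolve()))

import torch
from tasks.RAVEL.ravel import get_counterfactual_datasets, get_causal_model, get_token_positions
from neural.pipeline import LMPipeline

datasets = get_counterfactual_datasets(hf=True, size=100)  # small sample for a quick look
causal_model = get_causal_model()
print("Available datasets:", list(datasets.keys()))

pipeline = LMPipeline("google/gemma-2-2b", max_new_tokens=2,
                      device="cuda:0" if torch.cuda.is_available() else "cpu",
                      dtype=torch.float16)
token_positions = get_token_positions(pipeline, causal_model)  # hidden-vector positions to featurize
```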
84 | 85 | ## Supported Models 86 | 87 | The benchmark evaluates methods on four language models: 88 | - GPT-2 Small 89 | - Qwen-2.5 (0.5B) 90 | - Gemma-2 (2B) 91 | - Llama-3.1 (8B) 92 | 93 | ## Creating a Submission 94 | 95 | ### Submission Structure 96 | 97 | Your submission should follow this directory structure: 98 | 99 | ``` 100 | submission_folder/ 101 | ├── featurizer.py # (Optional) Custom featurizer implementation 102 | ├── token_position.py # (Optional) Custom token position logic 103 | └── {TASK}_{MODEL}_{VARIABLE}/ 104 | ├── {ModelUnit}_featurizer 105 | ├── {ModelUnit}_inverse_featurizer 106 | └── {ModelUnit}_indices 107 | ``` 108 | 109 | For IOI tasks, you'll also need linear parameters for the high-level causal model: 110 | ``` 111 | submission_folder/ 112 | └── ioi_linear_params.json # Required for IOI submissions 113 | ``` 114 | These linear parameters can be computed from training data, e.g., 115 | ```bash 116 | python baselines/ioi_baselines/ioi_learn_linear_params.py --model qwen --heads_list "(4,5)" "(12,0)" 117 | ``` 118 | This command estimates the weights for the position and token variables in the IOI task for attention head 5 in layer 4 and attention head 0 in layer 12 of the Qwen model. 119 | 120 | 121 | ### Submission Components 122 | 123 | 1. **Featurizer Files**: PyTorch checkpoint files containing: 124 | - Trained featurizer weights 125 | - Inverse featurizer weights 126 | - Feature indices 127 | 128 | 2. **Optional Custom Code**: 129 | - `featurizer.py`: Custom featurizer class inheriting from `Featurizer` 130 | - `token_position.py`: Function returning `TokenPosition` objects 131 | 132 | 3. **Task-Specific Requirements**: 133 | - Standard tasks: Organize by task/model/variable naming convention 134 | - IOI task: Include linear parameters JSON file 135 | 136 | ### Verifying Your Submission 137 | 138 | Before submitting, verify your submission format: 139 | 140 | ```bash 141 | python verify_submission.py /path/to/your/submission 142 | ``` 143 | 144 | ## Evaluation 145 | 146 | Submissions are evaluated by: 147 | 148 | 1. Loading your trained featurizers 149 | 2. Performing interchange interventions on test datasets 150 | 3. Measuring accuracy of causal variable predictions 151 | 4. 
Computing metrics across layers and datasets 152 | 153 | Key metrics include: 154 | - **Average Score**: Mean accuracy across examples 155 | - **Average Accuracy Across Layers**: Mean of best accuracies per layer 156 | - **Highest Accuracy Across Layers**: Maximum accuracy achieved 157 | 158 | ## Running Baseline Methods Used in the Paper 159 | 160 | To replicate baseline results: 161 | 162 | ```bash 163 | # Run baselines for a specific task 164 | python baselines/simple_MCQA_baselines.py 165 | python baselines/ARC_baselines.py 166 | python baselines/arithmetic_baselines.py 167 | python baselines/ravel_baselines.py 168 | python baselines/ioi_baselines/ioi_baselines.py 169 | ``` 170 | 171 | ## Citation 172 | 173 | If you use this benchmark, please cite: 174 | 175 | ```bibtex 176 | @article{mib-2025, 177 | title = {{MIB}: A Mechanistic Interpretability Benchmark}, 178 | author = {Aaron Mueller and Atticus Geiger and Sarah Wiegreffe and Dana Arad and Iv{\'a}n Arcuschin and Adam Belfki and Yik Siu Chan and Jaden Fiotto-Kaufman and Tal Haklay and Michael Hanna and Jing Huang and Rohan Gupta and Yaniv Nikankin and Hadas Orgad and Nikhil Prakash and Anja Reusch and Aruna Sankaranarayanan and Shun Shao and Alessandro Stolfo and Martin Tutek and Amir Zur and David Bau and Yonatan Belinkov}, 179 | year = {2025}, 180 | journal = {CoRR}, 181 | volume = {arXiv:2504.13151}, 182 | url = {https://arxiv.org/abs/2504.13151v1} 183 | } 184 | ``` 185 | 186 | ## Resources 187 | 188 | - [MIB Paper](https://arxiv.org/abs/2504.13151) 189 | - [MIB Website](https://mib-bench.github.io) 190 | - [Leaderboard](https://huggingface.co/spaces/mib-bench/leaderboard) 191 | - [Datasets](https://huggingface.co/collections/mib-bench/mib-datasets-67f55273612ec3067a42a56b) 192 | - [Main MIB Repository](https://github.com/aaronmueller/MIB) 193 | -------------------------------------------------------------------------------- /MIB-causal-variable-track/baselines/ravel_baselines.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 4 | from tasks.RAVEL.ravel import get_token_positions, get_counterfactual_datasets, get_causal_model 5 | from experiments.aggregate_experiments import residual_stream_baselines 6 | from neural.pipeline import LMPipeline 7 | from experiments.filter_experiment import FilterExperiment 8 | from causal.counterfactual_dataset import CounterfactualDataset 9 | import torch 10 | import gc 11 | import os 12 | import re 13 | import random 14 | from copy import deepcopy 15 | 16 | if __name__ == "__main__": 17 | import argparse 18 | 19 | parser = argparse.ArgumentParser(description="Run RAVEL experiments with optional flags.") 20 | parser.add_argument("--skip_gemma", action="store_true", help="Skip running experiments for Gemma model.") 21 | parser.add_argument("--skip_llama", action="store_true", help="Skip running experiments for Llama model.") 22 | parser.add_argument("--use_gpu1", action="store_true", help="Use GPU1 instead of GPU0 if available.") 23 | parser.add_argument("--methods", nargs="+", 24 | default=["full_vector", "DAS", "DBM+SVD", "DBM+PCA", "DBM", "DBM+SAE"], 25 | help="List of methods to run") 26 | parser.add_argument("--batch_size", type=int, default=128, help="Batch size for training") 27 | parser.add_argument("--eval_batch_size", type=int, default=512, help="Batch size for evaluation") 28 | parser.add_argument("--results_dir", type=str, 
default="results_ravel", help="Directory to save results") 29 | parser.add_argument("--model_dir", type=str, default="ravel_models", help="Directory to save trained models") 30 | parser.add_argument("--quick_test", action="store_true", help="Run quick test with reduced dataset size and layers") 31 | args = parser.parse_args() 32 | 33 | # Clear memory before starting 34 | gc.collect() 35 | torch.cuda.empty_cache() 36 | 37 | # Device setup 38 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 39 | if args.use_gpu1 and torch.cuda.is_available(): 40 | device = "cuda:1" 41 | 42 | # Check function for evaluating model outputs 43 | def checker(output_text, expected): 44 | if output_text is None: 45 | return False 46 | 47 | output_clean = re.sub(r'[^\w\s]+', '', output_text.lower()).strip() 48 | expected_list = [e.strip().lower() for e in expected.split(',')] 49 | 50 | if any(part in output_clean for part in expected_list): 51 | return True 52 | 53 | # Edge cases 54 | if re.search(r'united states|united kingdom|czech republic', expected, re.IGNORECASE): 55 | raw_expected = expected.strip().lower().replace('the ', '') 56 | raw_output = output_text.strip().lower().replace('the ', '') 57 | if raw_output.startswith(raw_expected) or raw_output.startswith('england') or raw_output == "us": 58 | return True 59 | if re.search(r'south korea', expected, re.IGNORECASE): 60 | if output_clean.startswith('korea') or output_clean.startswith('south korea'): 61 | return True 62 | if re.search(r'persian|farsi', expected, re.IGNORECASE): 63 | if output_clean.startswith('persian') or output_clean.startswith('farsi'): 64 | return True 65 | if re.search(r'oceania', expected, re.IGNORECASE): 66 | if output_clean.startswith('australia'): 67 | return True 68 | if re.search(r'north america', expected, re.IGNORECASE): 69 | if 'north america' in output_clean or output_clean == 'na' or output_clean.startswith('america'): 70 | return True 71 | if re.search(r'mandarin|chinese', expected, re.IGNORECASE): 72 | if 'chinese' in output_clean or 'mandarin' in output_clean: 73 | return True 74 | 75 | return False 76 | 77 | # Function to clear memory between experiments 78 | def clear_memory(): 79 | gc.collect() 80 | if torch.cuda.is_available(): 81 | torch.cuda.empty_cache() 82 | torch.cuda.synchronize() 83 | 84 | def get_filtered_indices(dataset, variables_list, target_variable): 85 | """ 86 | Return a list of row indices to keep, where we keep total_size * ratio rows for each attribute. 87 | """ 88 | attr_to_indices = {attr: [] for attr in variables_list} 89 | for i in range(len(dataset)): 90 | example = dataset[i] 91 | # Access the input dict to get the queried_attribute 92 | if isinstance(example["input"], dict): 93 | attr = example["input"].get("queried_attribute") 94 | 95 | 96 | if attr in attr_to_indices: 97 | attr_to_indices[attr].append(i) 98 | 99 | half = len(attr_to_indices[target_variable]) 100 | for attr in variables_list: 101 | random.shuffle(attr_to_indices[attr]) 102 | 103 | final_indices = [] 104 | for attr in variables_list: 105 | if attr == target_variable: 106 | final_indices.extend(attr_to_indices[attr][:half]) 107 | else: 108 | final_indices.extend(attr_to_indices[attr][:int(half//2)]) 109 | 110 | random.shuffle(final_indices) 111 | return final_indices 112 | 113 | def filter_dataset_by_attribute(dataset, variables_list, target_variable): 114 | """ 115 | Filter a CounterfactualDataset by attribute, returning a new filtered dataset. 
116 | """ 117 | # Get the indices to filter 118 | final_indices = get_filtered_indices(dataset.dataset, variables_list=variables_list, target_variable=target_variable) 119 | 120 | # Create new filtered dataset 121 | filtered_hf_dataset = dataset.dataset.select(final_indices) 122 | 123 | # Return a new CounterfactualDataset with the filtered data 124 | return CounterfactualDataset(dataset=filtered_hf_dataset, id=dataset.id) 125 | 126 | # Get counterfactual datasets and causal model once 127 | dataset_size = 10 if args.quick_test else 10000 128 | counterfactual_datasets = get_counterfactual_datasets(hf=True, size=dataset_size) 129 | causal_model = get_causal_model() 130 | 131 | # Print available datasets 132 | print("Available datasets:", counterfactual_datasets.keys()) 133 | 134 | # Set up models to test 135 | models = [] 136 | if not args.skip_gemma: 137 | models.append("google/gemma-2-2b") 138 | if not args.skip_llama: 139 | models.append("meta-llama/Meta-Llama-3.1-8B-Instruct") 140 | 141 | # Make sure results and model directories exist 142 | if not os.path.exists(args.results_dir): 143 | os.makedirs(args.results_dir) 144 | 145 | if not os.path.exists(args.model_dir): 146 | os.makedirs(args.model_dir) 147 | 148 | for model_name in models: 149 | print(f"\n===== Testing model: {model_name} =====") 150 | 151 | # Set up LM Pipeline 152 | pipeline = LMPipeline(model_name, max_new_tokens=2, device=device, dtype=torch.float16) 153 | pipeline.tokenizer.padding_side = "left" 154 | print("DEVICE:", pipeline.model.device) 155 | 156 | # Model-specific batch size adjustments 157 | if "gemma" in model_name: 158 | model_batch_size = args.batch_size 159 | model_eval_batch_size = args.eval_batch_size 160 | else: 161 | # Llama typically needs smaller batch sizes 162 | model_batch_size = 32 163 | model_eval_batch_size = 256 164 | 165 | # Get a sample input and check model's prediction 166 | sampled_example = next(iter(counterfactual_datasets.values()))[0] 167 | print("INPUT:", sampled_example["input"]) 168 | print("EXPECTED OUTPUT:", causal_model.run_forward(sampled_example["input"])["raw_output"]) 169 | print("MODEL PREDICTION:", pipeline.dump(pipeline.generate(sampled_example["input"]))) 170 | 171 | # Filter the datasets based on model performance 172 | print("\nFiltering datasets based on model performance...") 173 | exp = FilterExperiment(pipeline, causal_model, checker) 174 | filtered_datasets = exp.filter(counterfactual_datasets, verbose=True, batch_size=model_eval_batch_size) 175 | 176 | # Get token positions for intervention 177 | token_positions = get_token_positions(pipeline, causal_model) 178 | 179 | # Display token highlighting for a sample 180 | print("\nToken positions highlighted in samples:") 181 | for dataset in filtered_datasets.values(): 182 | for token_position in token_positions: 183 | example = dataset[0] 184 | print(token_position.highlight_selected_token(example["input"])) 185 | break 186 | break 187 | 188 | # Clear memory before running experiments 189 | clear_memory() 190 | 191 | # Setup experiment configuration 192 | start = 0 193 | end = 1 if args.quick_test else pipeline.get_num_layers() 194 | 195 | if "gemma" in model_name: 196 | config = { 197 | "batch_size": 128, 198 | "evaluation_batch_size": 512, 199 | "training_epoch": 1, 200 | "n_features": 288, 201 | "regularization_coefficient": 0.0, 202 | "output_scores": False 203 | } 204 | else: 205 | config = { 206 | "batch_size": 48, 207 | "evaluation_batch_size": 256, 208 | "training_epoch": 1, 209 | "n_features": 512, 210 | 
"regularization_coefficient": 0.0, 211 | "output_scores": False 212 | } 213 | 214 | all_attributes = ["Continent", "Country", "Language"] 215 | 216 | # Set which variables to localize 217 | target_variables = [ "Country", "Language", "Continent"] 218 | 219 | # Set which counterfactuals to use 220 | names = ["attribute", "wikipedia"] 221 | 222 | # Run experiments for each target variable 223 | for variable in target_variables: 224 | print(f"\nRunning experiments for target variable: {variable}") 225 | 226 | # Create deep copies of filtered datasets for attribute-specific filtering 227 | temp_datasets = deepcopy(filtered_datasets) 228 | 229 | # Prepare train and test data dictionaries with attribute-specific filtering 230 | train_data = {} 231 | test_data = {} 232 | 233 | for name in names: 234 | # Process training data 235 | if name + "_train" in temp_datasets: 236 | print(f"Original {name}_train size: {len(temp_datasets[name + '_train'])}") 237 | filtered_train = filter_dataset_by_attribute( 238 | temp_datasets[name + "_train"], 239 | variables_list=all_attributes, 240 | target_variable=variable 241 | ) 242 | print(f"Filtered {name}_train size: {len(filtered_train)}") 243 | train_data[name + "_train"] = filtered_train 244 | 245 | # Process test data 246 | if name + "_test" in temp_datasets: 247 | print(f"Original {name}_test size: {len(temp_datasets[name + '_test'])}") 248 | filtered_test = filter_dataset_by_attribute( 249 | temp_datasets[name + "_test"], 250 | variables_list=all_attributes, 251 | target_variable=variable 252 | ) 253 | print(f"Filtered {name}_test size: {len(filtered_test)}") 254 | test_data[name + "_test"] = filtered_test 255 | 256 | # Process private test data 257 | if name + "_testprivate" in temp_datasets: 258 | print(f"Original {name}_testprivate size: {len(temp_datasets[name + '_testprivate'])}") 259 | filtered_test_private = filter_dataset_by_attribute( 260 | temp_datasets[name + "_testprivate"], 261 | variables_list=all_attributes, 262 | target_variable=variable 263 | ) 264 | print(f"Filtered {name}_testprivate size: {len(filtered_test_private)}") 265 | test_data[name + "_testprivate"] = filtered_test_private 266 | 267 | # Run the baseline experiments 268 | residual_stream_baselines( 269 | pipeline=pipeline, 270 | task=causal_model, 271 | token_positions=token_positions, 272 | train_data=train_data, 273 | test_data=test_data, 274 | config=config, 275 | target_variables=[variable], 276 | checker=checker, 277 | start=start, 278 | end=end, 279 | verbose=True, 280 | model_dir=os.path.join(args.model_dir, variable), 281 | results_dir=args.results_dir, 282 | methods=args.methods 283 | ) 284 | clear_memory() 285 | 286 | # Clean up pipeline to free memory before starting next model 287 | del pipeline 288 | clear_memory() 289 | 290 | print("\nAll experiments completed.") -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
9 | A benchmark for systematic comparison of featurization and localization methods.
12 | circuits · causal variables · localization · featurization · faithfulness · interchange interventions · SAEs · counterfactuals