├── eap ├── __init__.py ├── attr_patching_per_input_dim_per_neuron.py ├── eap_wrapper.py ├── attr_patching.py └── eap_graph.py ├── docs ├── paper.pdf ├── img │ ├── arrow.gif │ ├── probing_acc.png │ ├── opening_heuristics.png │ ├── llama3-8b_localization.png │ ├── prompt_knockout_across_training.png │ ├── heuristics_intersection_across_training.png │ └── llama3_8b_70b_prompt_knockout_per_layer.png └── index.html ├── .gitignore ├── README.md ├── LICENSE ├── model_analysis_consts.py ├── metrics.py ├── script_linear_probe.py ├── component.py ├── attention_analysis.py ├── visualization_utils.py ├── linear_probing.py ├── circuit_utils.py ├── activation_patching.py ├── script_circuit_localization.py ├── circuit.py ├── script_per_neuron_analysis.py ├── script_topk_neuron_eval.py ├── script_eval_pythia_faithfulness_only_mutual_neurons.py ├── prompt_generation.py ├── path_patching.py ├── heuristics_analysis.py ├── general_utils.py └── script_analyze_model_heuristics.py /eap/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/technion-cs-nlp/llm-arithmetic-heuristics/HEAD/docs/paper.pdf -------------------------------------------------------------------------------- /docs/img/arrow.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/technion-cs-nlp/llm-arithmetic-heuristics/HEAD/docs/img/arrow.gif -------------------------------------------------------------------------------- /docs/img/probing_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/technion-cs-nlp/llm-arithmetic-heuristics/HEAD/docs/img/probing_acc.png -------------------------------------------------------------------------------- /docs/img/opening_heuristics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/technion-cs-nlp/llm-arithmetic-heuristics/HEAD/docs/img/opening_heuristics.png -------------------------------------------------------------------------------- /docs/img/llama3-8b_localization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/technion-cs-nlp/llm-arithmetic-heuristics/HEAD/docs/img/llama3-8b_localization.png -------------------------------------------------------------------------------- /docs/img/prompt_knockout_across_training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/technion-cs-nlp/llm-arithmetic-heuristics/HEAD/docs/img/prompt_knockout_across_training.png -------------------------------------------------------------------------------- /docs/img/heuristics_intersection_across_training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/technion-cs-nlp/llm-arithmetic-heuristics/HEAD/docs/img/heuristics_intersection_across_training.png -------------------------------------------------------------------------------- /docs/img/llama3_8b_70b_prompt_knockout_per_layer.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/technion-cs-nlp/llm-arithmetic-heuristics/HEAD/docs/img/llama3_8b_70b_prompt_knockout_per_layer.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | third_party/ 2 | nanoGPT/ 3 | mistral/ 4 | llama/ 5 | transformer_lens_101/ 6 | TransformerLensCode/ 7 | mcd_pp/ 8 | debug/ 9 | *.pyc 10 | .ipynb_checkpoints/ 11 | *attn_grid* 12 | *.log 13 | data/mean_cache* 14 | per_neuron_* 15 | mlp_input_* 16 | *.pt 17 | *.png 18 | .vscode 19 | data/pythia* 20 | data/old 21 | figs 22 | old/ 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # llm-arithmetic-heuristics 2 | 3 | # [ICLR 2025] Arithmetic Without Algorithms: Language Models Solve Math With a Bag of Heuristics 4 | 5 | Official code for the experiments (and [website](https://technion-cs-nlp.github.io/llm-arithmetic-heuristics/)) of the ["Arithmetic Without Algorithms" paper](https://arxiv.org/abs/2410.21272), accepted to ICLR 2025. 6 | 7 | 8 | ## Repository structure 9 | * The notebook files contain the experimentation code and the code used to generate the data, results and figures for the paper. Specifically, `llm-arithmetic-analysis-main.ipynb` contains most of the code for the general LLM arithmetic analysis presented in the paper, and `pythia-heuristics-analysis-notebook.ipynb` contains the code for the experiments across training checkpoints, presented in section 5 of the paper. 10 | * The script files (`script_*.py`) contain standalone versions of parts of the notebook code, meant to be run as GPU jobs. 11 | * The remaining files implement the procedures used in the experiments (activation patching, faithfulness evaluation, the heuristic classification algorithm, etc.). 12 | * `docs` contains the code for the project website. 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 technion-cs-nlp 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /model_analysis_consts.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | @dataclass 4 | class ModelAnalysisConsts: 5 | max_single_token: int # The highest number n for which all numbers in [0, n] are represented by a single token 6 | first_heuristics_layer: int # The earliest layer where the model begins promoting the correct answer in the final position 7 | topk_neurons_per_layer: int # How many neurons in each middle- and late-layer MLP are required for high faithfulness? 8 | mlp_activations_also_negative: bool # Can mlp_post activations be negative? In models with GatedMLPs, this is True 9 | 10 | PYTHIA_6_9B_CONSTS = ModelAnalysisConsts(max_single_token=530, first_heuristics_layer=14, topk_neurons_per_layer=200, mlp_activations_also_negative=False) 11 | LLAMA3_70B_CONSTS = ModelAnalysisConsts(max_single_token=999, first_heuristics_layer=39, topk_neurons_per_layer=400, mlp_activations_also_negative=True) 12 | LLAMA3_8B_CONSTS = ModelAnalysisConsts(max_single_token=999, first_heuristics_layer=16, topk_neurons_per_layer=200, mlp_activations_also_negative=True) 13 | GPTJ_CONSTS = ModelAnalysisConsts(max_single_token=520, first_heuristics_layer=17, topk_neurons_per_layer=200, mlp_activations_also_negative=False) -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def logit_diff(logits: torch.Tensor, 5 | clean_labels: torch.Tensor, 6 | corrupt_labels: torch.Tensor): 7 | return logits.gather(1, clean_labels) - logits.gather(1, corrupt_labels) 8 | 9 | 10 | def indirect_effect(pre_patch_probs: torch.Tensor, 11 | post_patch_probs: torch.Tensor, 12 | clean_labels: torch.Tensor, 13 | corrupt_labels: torch.Tensor): 14 | """ 15 | Measure indirect effect of a patch on probabilities, as described in Eq. 2 of "Understanding Arithmetic 16 | Reasoning in Language Models using Causal Mediation Analysis". 17 | 18 | Args: 19 | pre_patch_probs (torch.Tensor (batch, vocab_size)): The probabilities before patching. 20 | post_patch_probs (torch.Tensor (batch, vocab_size)): The probabilities after patching. 21 | clean_labels (torch.Tensor(batch, 1)): The labels of the clean answers. 22 | corrupt_labels (torch.Tensor(batch, 1)): The label of the corrupt answers. 23 | 24 | Returns: 25 | torch.Tensor((batch,), dtype=torch.float32): The indirect effects for each prompt in the batch. The IE is not limited in magnitude. 
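        As implemented below, the returned value is
            IE = 1/2 * [ (p_post(a_corrupt) - p_pre(a_corrupt)) / p_pre(a_corrupt)
                       + (p_pre(a_clean) - p_post(a_clean)) / p_post(a_clean) ]
        where p_pre / p_post are the pre- and post-patch probabilities of a label.
        Illustrative example (made-up numbers): with p_pre(a_clean)=0.8, p_post(a_clean)=0.4,
        p_pre(a_corrupt)=0.1 and p_post(a_corrupt)=0.5, the result is
        ((0.5 - 0.1) / 0.1 + (0.8 - 0.4) / 0.4) / 2 = 2.5.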
26 | """ 27 | a = (post_patch_probs.gather(1, corrupt_labels) - pre_patch_probs.gather(1, corrupt_labels)) / pre_patch_probs.gather(1, corrupt_labels) 28 | b = (pre_patch_probs.gather(1, clean_labels) - post_patch_probs.gather(1, clean_labels)) / post_patch_probs.gather(1, clean_labels) 29 | return (a + b).squeeze(1) / 2 30 | -------------------------------------------------------------------------------- /script_linear_probe.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import pickle 4 | import os 5 | import torch 6 | import logging 7 | from prompt_generation import separate_prompts_and_answers, OPERATORS 8 | from component import Component 9 | from general_utils import generate_activations, load_model, get_model_consts 10 | from linear_probing import linear_probe_across_layers 11 | 12 | 13 | torch.set_grad_enabled(False) 14 | device = 'cuda' 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--model_name', type=str, help='Name of the model to be loaded') 20 | parser.add_argument('--model_path', type=str, help='Path to the model to be loaded') 21 | args = parser.parse_args() 22 | return args 23 | 24 | 25 | def main(): 26 | args = parse_args() 27 | model_name, model_path = args.model_name, args.model_path 28 | logging.info("Loading model") 29 | model = load_model(model_name, model_path, device) 30 | 31 | max_op = 300 32 | model_consts = get_model_consts(model_name) 33 | max_answer_value = model_consts.max_single_token 34 | large_prompts_and_answers = pickle.load(open(fr'./data/{model_name}/large_prompts_and_answers_max_op={max_op}.pkl', 'rb')) 35 | 36 | results_path = f"./data/{model_name}/probe_accs.pt" 37 | if os.path.exists(results_path): 38 | probe_accs = torch.load(results_path) 39 | else: 40 | probe_accs = {} 41 | 42 | for operator_idx in range(len(OPERATORS)): 43 | activations = None 44 | for pos_to_probe in [4, 3, 2, 1]: 45 | if (operator_idx, pos_to_probe) in probe_accs: 46 | print(f"Found results file for {operator_idx=}, {pos_to_probe=}, ") 47 | continue 48 | print(f"{operator_idx=}, {pos_to_probe=}") 49 | components = [Component('resid_post', layer=i) for i in range(model.cfg.n_layers)] 50 | correct_prompts = separate_prompts_and_answers(large_prompts_and_answers[operator_idx])[0] 51 | random.shuffle(correct_prompts) 52 | if activations is None: 53 | activations = generate_activations(model, correct_prompts, components, pos=None) 54 | pos_activations = {i: activations[i][:, pos_to_probe] for i in range(model.cfg.n_layers)} 55 | answers = torch.tensor([int(eval(prompt[:-1])) for prompt in correct_prompts]) 56 | probe_accs[(operator_idx, pos_to_probe)] = linear_probe_across_layers(model, pos_activations, answers, max_answer_value)[1] # [1] to get only test accs 57 | torch.save(probe_accs, results_path) 58 | 59 | 60 | if __name__ == '__main__': 61 | main() -------------------------------------------------------------------------------- /component.py: -------------------------------------------------------------------------------- 1 | import transformer_lens as lens 2 | 3 | 4 | class Component(): 5 | """ 6 | A wrapper class for a hookable component in a residual path in a transformer model. 7 | This extends the normal hooks functionality in transformer_lens by adding an optional 8 | head_idx parameter. 
9 | """ 10 | def __init__(self, hook_name, layer=None, head=None, neurons=None): 11 | self.hook_name = hook_name 12 | self.layer = layer 13 | self.head_idx = head 14 | self.neuron_indices = tuple(neurons) if neurons is not None else None # Currently only supported for MLP neurons; Converted to tuple for hashability 15 | 16 | def __hash__(self): 17 | return hash((self.hook_name, self.layer, self.head_idx, self.neuron_indices)) 18 | 19 | def __eq__(self, other): 20 | # Compare two components by value and not by reference 21 | return self.hook_name == other.hook_name and \ 22 | self.layer == other.layer and \ 23 | self.head_idx == other.head_idx and \ 24 | self.neuron_indices == other.neuron_indices 25 | 26 | def valid_hook_name(self, layer=None) -> int: 27 | """ 28 | Get a valid hook name for this component, which can be used to set a TransformerLens hook. 29 | This valid name is compatible with TransformerLens, thus does not contain any head / neuron information. 30 | 31 | Args: 32 | layer (int): The layer to get the valid hook name for. If None, the layer is taken from the component. 33 | """ 34 | return lens.utils.get_act_name(name=self.hook_name, layer=layer or self.layer) 35 | 36 | @property 37 | def full_hook_name(self) -> str: 38 | """ 39 | Get the full hook name (without regard to the layer) for visualization purposes. 40 | """ 41 | if self.head_idx is not None: 42 | return f'{self.hook_name}.head{self.head_idx}' 43 | elif self.neuron_indices is not None: 44 | return f'{self.hook_name}.specific_neurons' 45 | return self.hook_name 46 | 47 | @property 48 | def is_mlp(self) -> bool: 49 | """ 50 | Check if the component is an MLP component. 51 | """ 52 | return 'mlp' in self.hook_name 53 | 54 | @property 55 | def is_attn(self) -> bool: 56 | """ 57 | Check if the component is an attention component. 58 | """ 59 | valid_hook_name = self.valid_hook_name() 60 | for attn_hook_name in ['attn', 'hook_q', 'hook_k', 'hook_v', 'hook_z', 'hook_pattern', 'hook_result']: 61 | if attn_hook_name in valid_hook_name: 62 | return True 63 | return False 64 | 65 | @property 66 | def is_qkv(self) -> bool: 67 | """ 68 | Checks if the component is a hook on either the Q/K/V tensors (post projection). 69 | """ 70 | valid_hook_name = self.valid_hook_name() 71 | return 'attn.hook_q' in valid_hook_name or 'attn.hook_k' in valid_hook_name or 'attn.hook_v' in valid_hook_name 72 | 73 | @property 74 | def is_resid(self) -> bool: 75 | """ 76 | Check if the component is a residual stream component. 77 | """ 78 | return 'resid' in self.hook_name 79 | 80 | def __repr__(self) -> str: 81 | """ 82 | Get the full hook name (with the layer). 
83 | """ 84 | return f'blocks.{self.layer}.{self.full_hook_name}' 85 | 86 | def __lt__(self, other) -> bool: 87 | return self.layer < other.layer or \ 88 | (self.layer == other.layer and self.head_idx is not None and other.head_idx is not None and self.head_idx < other.head_idx) 89 | -------------------------------------------------------------------------------- /attention_analysis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformer_lens as lens 3 | from tqdm import tqdm 4 | from typing import List 5 | 6 | def two_operands_arithmetic_qk_heatmap(model: lens.HookedTransformer, 7 | operator: str = '+', 8 | maximal_operand_value: int = 100, 9 | dst_token_position: int = -1, 10 | show_progress: bool = True): 11 | """ 12 | Visualize the attention maps of a model, based on all possible combinations of 13 | input tokens at two operand positions (y axis for the first operand, x axis for 14 | the second operand). 15 | The resulting visualization is a heatmap (one for each layer, head index and src position), 16 | wehre the value at (x, y) represents the attention value at the (layer, head) from the last dst token 17 | to the src position. 18 | 19 | Args: 20 | model (lens.HookedTransformer): The model to visualize. 21 | operator (str): The operator to use for the calculation. 22 | maximal_operand_value (int): The maximal value for each of the two operands. 23 | show_progress (bool): Whether to show a progress bar. 24 | Returns: 25 | torch.Tensor (n_layers, n_heads, len(tokens), len(tokens), positions_per_prompt) - 26 | A matrix of attention maps where a cell at (l,h,x,y,p) represents the attention 27 | value at head h at layer l, from the dst token at the given position to token at 28 | position p, where the first operand is x and the second operand is y. 29 | """ 30 | positions_per_prompt = 5 # BOS, op1, operator, op2, = 31 | attention_pattern_values = torch.zeros((model.cfg.n_layers, model.cfg.n_heads, 32 | maximal_operand_value, maximal_operand_value, 33 | positions_per_prompt), dtype=torch.float16) 34 | 35 | progress = lambda x: tqdm(x) if show_progress else x 36 | for operand1 in progress(range(maximal_operand_value)): 37 | prompts = [f'{operand1}{operator}{operand2}=' for operand2 in range(0, maximal_operand_value)] 38 | dataloader = torch.utils.data.DataLoader(prompts, batch_size=32, shuffle=False) 39 | cur_idx = 0 40 | for batch in dataloader: 41 | _, cache = model.run_with_cache(batch) 42 | for layer in range(model.cfg.n_layers): 43 | for head_idx in range(model.cfg.n_heads): 44 | attention_pattern_values[layer, head_idx, operand1, cur_idx:cur_idx+len(batch), :] = \ 45 | cache[f'blocks.{layer}.attn.hook_pattern'][:, head_idx, dst_token_position, :] 46 | cur_idx += len(batch) 47 | del cache 48 | torch.cuda.empty_cache() 49 | 50 | return attention_pattern_values 51 | 52 | 53 | def ov_transition_analysis(model: lens.HookedTransformer, 54 | layer: int, 55 | head: int, 56 | words: List[str]): 57 | """ 58 | Visualize the OV transition matrix, defined as - 59 | W_Transition = W_U^T @ W_V @ W_O @ W_U 60 | 61 | Which defines how an OV circuit connects pairs of (input_token, output_token). 62 | This is taken from the paper "Analyzing Transformers in Embedding Space" (https://arxiv.org/abs/2209.02535). 63 | 64 | Args: 65 | model (lens.HookedTransformer): The model to visualize. 66 | layer (int): The layer of the attention head to look at. 67 | head (int): The head index of the attention head to look at. 
68 | words (List[str]): A list of possible tokens to be considered for the (input, output) pairs. 69 | """ 70 | tokens = model.to_tokens(words, prepend_bos=False).view(-1) # T 71 | W_U = model.unembed.W_U[:, tokens] # d_model, T 72 | W_O = model.blocks[layer].attn.W_O[head] # d_head, d_model 73 | W_V = model.blocks[layer].attn.W_V[head] # d_model, d_head 74 | transition_matrix = W_U.T @ W_V @ W_O @ W_U # T, T 75 | return transition_matrix -------------------------------------------------------------------------------- /visualization_utils.py: -------------------------------------------------------------------------------- 1 | import transformer_lens as lens 2 | import plotly.express as px 3 | import plotly.graph_objects as go 4 | import circuitsvis as cv 5 | import torch 6 | from typing import List 7 | from component import Component 8 | 9 | 10 | def imshow(tensor, **kwargs): 11 | px.imshow(lens.utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show() 12 | 13 | 14 | def line(tensor, **kwargs): 15 | px.line(y=tensor, **kwargs).show() 16 | 17 | 18 | def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs): 19 | px.scatter(y=lens.utils.to_numpy(y), x=lens.utils.to_numpy(x), labels={"x": xaxis, "y": yaxis, "color": caxis}, **kwargs).show() 20 | 21 | 22 | def scatter_with_labels(x, y, hovertext, color=None, mode='markers', **layout_kwargs): 23 | fig = go.Figure(data=go.Scatter(x=x, y=y, mode=mode, hovertext=hovertext, marker=dict(color=color))) 24 | fig.update_layout(**layout_kwargs).show() 25 | 26 | 27 | def multiple_lines(x, y, line_titles, add_vlines_at_maximum=False, show_fig=True, hovertext=None, colors=None, **layout_kwargs): 28 | traces = [] 29 | colors = colors or px.colors.qualitative.Plotly 30 | for i in range(len(line_titles)): 31 | trace = go.Scatter(x=x, y=y[i], mode='lines', name=line_titles[i], hovertext=hovertext, line=dict(color=colors[i % len(colors)])) 32 | traces.append(trace) 33 | 34 | fig = go.Figure(traces) 35 | 36 | if add_vlines_at_maximum: 37 | for i, trace in enumerate(traces): 38 | fig.add_vline(x[y[i].argmax()], line_dash="dash", line_color=colors[i]) 39 | 40 | fig.update_layout(**layout_kwargs) 41 | 42 | if show_fig: 43 | fig.show() 44 | else: 45 | return fig 46 | 47 | 48 | def visualize_arithmetic_attention_patterns(model: lens.HookedTransformer, 49 | components: List[Component], 50 | prompts: List[str], 51 | use_bos_token: bool = True, 52 | return_raw_patterns: bool = False): 53 | """ 54 | Visualize the resulting attention patterns for a list of attention heads, averaged across a list of prompts. 55 | 56 | Args: 57 | model (lens.HookedTransformer): The model to visualize. 58 | components (List[Component]): The attention heads to visualize. If any components are not attention heads, they are ignored. 59 | prompts (List[str]): The prompts to pass through the model and average over. 60 | use_bos_token (bool): Should the BOS token be part of the prompt passed through the model. 61 | return_raw_patterns (bool): If True, the raw activation in the pattern hook are also returned. 62 | Returns: 63 | (circuitsvis.utils.render.RenderedHTML) - The rendered HTML visualization. 
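        If return_raw_patterns is True, a (html, patterns) tuple is returned instead, where patterns
        contains one averaged (dst_pos, src_pos) slice per attention-head component.
        Example (illustrative; the chosen heads and prompt are arbitrary):
            heads = [Component('z', layer=14, head=3), Component('z', layer=20, head=1)]
            html = visualize_arithmetic_attention_patterns(model, heads, ['226+68='])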
64 | """ 65 | prompt_loader = torch.utils.data.DataLoader(prompts, batch_size=32, shuffle=False) 66 | 67 | labels = [f'{head_component.layer}H{head_component.head_idx}' for head_component in components] 68 | patterns = [[] for _ in range(len(components))] 69 | for batch in prompt_loader: 70 | _, cache = model.run_with_cache(batch, return_type='logits', prepend_bos=use_bos_token) 71 | 72 | for i, head_component in enumerate(components): 73 | if head_component.head_idx is None: 74 | # Ignore non-head components (mlp etc) 75 | continue 76 | patterns[i].append(cache['pattern', head_component.layer].cpu()[:, head_component.head_idx]) 77 | del cache 78 | 79 | patterns = [torch.cat(p).mean(dim=0) for p in patterns] # Unify batches and mean across prompts to single tensor 80 | patterns = torch.stack(patterns, dim=0) # Convert list to a single tensor for visualization 81 | 82 | # Get the axis labels. In case the prompts are averaged 83 | if len(prompts) == 1: 84 | str_tokens = model.to_str_tokens(prompts[0], prepend_bos=use_bos_token) 85 | else: 86 | str_tokens = ['operand1', 'operator', 'operand2', '='] 87 | if use_bos_token: 88 | str_tokens.insert(0, 'BOS') 89 | 90 | heads_html = cv.attention.attention_heads(attention=patterns, tokens=str_tokens, attention_head_names=labels) 91 | if return_raw_patterns: 92 | return heads_html, patterns 93 | else: 94 | return heads_html 95 | -------------------------------------------------------------------------------- /linear_probing.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import transformer_lens as lens 3 | import torch 4 | import torch.nn as nn 5 | from general_utils import Metric 6 | 7 | 8 | def linear_probe_across_layers(model: lens.HookedTransformer, 9 | features: Dict[int, torch.Tensor], 10 | labels: torch.Tensor, 11 | possible_label_count: int = 100, 12 | train_test_split_percent: float = 0.8, 13 | train_epochs: int = 20, 14 | train_lr: float = 3e-4, 15 | device: str = 'cuda', 16 | verbose: bool = True, 17 | ): 18 | """ 19 | Perform a linear probing experiment across layers. 20 | The experiment trains a linear model on top of features (for every layer) to extract given labels, 21 | and measure the success rate. 22 | 23 | Args: 24 | model (lens.HookedTransformer): The TransformerLens model to probe. 25 | features (Dict[int, torch.Tensor]): The features extracted from the model to use as training and testing data for the probe. 26 | labels (torch.Tensor): The labels for the features. 27 | possible_label_count (int): The number of possible labels. Determines the output dimension of the linear probe. 28 | train_test_split_percent (float): The percentage of the data to use for training. 29 | train_epochs (int): The number of epochs to train the linear model. 30 | train_lr (float): The learning rate for the linear probe. 31 | device (str): The device to use for training the linear probe. 32 | verbose (bool): Whether to print accuracies and additional information during the training and testing of the linear probe. 33 | 34 | Returns: 35 | tuple(List(float), List(float)): The probing (train_accuracies, test_accuracies), where each list contains all layer accuracies. 
36 | """ 37 | probe_accs = ([], []) 38 | 39 | for layer_to_probe in features.keys(): 40 | layer_features = features[layer_to_probe] 41 | 42 | # Define the probing datasets 43 | train_test_split = int(train_test_split_percent * len(layer_features)) 44 | train_dataset = torch.utils.data.TensorDataset(layer_features[:train_test_split], labels[:train_test_split]) 45 | test_dataset = torch.utils.data.TensorDataset(layer_features[train_test_split:], labels[train_test_split:]) 46 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True) 47 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False) 48 | 49 | # Define the probing model, optimizer and loss 50 | linear_model = nn.Sequential( 51 | nn.Linear(model.cfg.d_model, possible_label_count), 52 | ) 53 | linear_model.to(device) 54 | linear_model.requires_grad_(True) 55 | optimizer = torch.optim.Adam(linear_model.parameters(), lr=train_lr) 56 | loss_fn = nn.CrossEntropyLoss() 57 | 58 | # Training and testing loop 59 | with torch.set_grad_enabled(True): 60 | for epoch in range(train_epochs): 61 | acc = Metric() 62 | for batch_idx, (batch_features, batch_answers) in enumerate(train_loader): 63 | optimizer.zero_grad() 64 | batch_features, batch_answers = batch_features.to(device), batch_answers.to(device) 65 | logits = linear_model(batch_features) 66 | loss = loss_fn(logits, batch_answers) 67 | acc.update((logits.argmax(dim=1) == batch_answers).float().mean().item()) 68 | loss.backward() 69 | optimizer.step() 70 | 71 | with torch.no_grad(): 72 | test_acc = Metric() 73 | for batch_idx, (batch_features, batch_answers) in enumerate(test_loader): 74 | batch_features, batch_answers = batch_features.to(device), batch_answers.to(device) 75 | logits = linear_model(batch_features) 76 | test_acc.update((logits.argmax(dim=1) == batch_answers).float().mean().item()) 77 | 78 | if verbose: 79 | print(f'Epoch {epoch+1}/{train_epochs}: Loss {loss.item():.3f}\t Train Accuracy: {acc.avg}\t Test Accuracy: {test_acc.avg :.3f}') 80 | 81 | if verbose: 82 | print(f'Layer {layer_to_probe}, Test Accuracy: {test_acc.avg :.3f}') 83 | 84 | probe_accs[0].append(acc.avg) 85 | probe_accs[1].append(test_acc.avg) 86 | 87 | return probe_accs 88 | -------------------------------------------------------------------------------- /eap/attr_patching_per_input_dim_per_neuron.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import random 3 | import torch 4 | import transformer_lens as lens 5 | from tqdm import tqdm 6 | from typing import List, Tuple 7 | from metrics import indirect_effect 8 | 9 | from prompt_generation import separate_prompts_and_answers 10 | 11 | 12 | # THIS IS A HACKEY VERSION OF "full_attribution_patching" as it exists 13 | # in attr_patching.py. 14 | # This is used only for finding the gradient of each input dimension at each mlp_in neuron input. 
15 | 16 | 17 | def full_attribution_patching_per_input_dim_per_neuron( 18 | model: lens.HookedTransformer, 19 | prompts_and_answers: List[Tuple[str, str]], 20 | corrupt_prompts_and_answers: List[Tuple[str, str]] = None, 21 | metric: str = "IE", 22 | pos=-1, 23 | batch_size: int = 1, 24 | ): 25 | model.requires_grad_(True) 26 | 27 | forward_hook_names = ["ln2.hook_normalized"] # ['hook_mlp_in'] 28 | backward_hook_names = ["mlp.hook_pre"] 29 | 30 | forward_hook_filter = partial( 31 | should_measure_hook, measurable_hooks=forward_hook_names 32 | ) 33 | backward_hook_filter = partial( 34 | should_measure_hook, measurable_hooks=backward_hook_names 35 | ) 36 | 37 | # Choose a random corrupt prompt for each prompt, if not given 38 | if corrupt_prompts_and_answers is None: 39 | corrupt_prompts_and_answers = [] 40 | for prompt_idx in range(len(prompts_and_answers)): 41 | # Choose a random prompt to corrupt with, without any limitations other than choosing a different prompt 42 | corrupt_prompt_idx = random.choice( 43 | list(set(range(len(prompts_and_answers))) - {prompt_idx}) 44 | ) 45 | corrupt_prompts_and_answers.append(prompts_and_answers[corrupt_prompt_idx]) 46 | 47 | prompts, answers = separate_prompts_and_answers(prompts_and_answers) 48 | corrupt_prompts, corrupt_answers = separate_prompts_and_answers( 49 | corrupt_prompts_and_answers 50 | ) 51 | labels, corrupt_labels = model.to_tokens( 52 | answers, prepend_bos=False 53 | ), model.to_tokens(corrupt_answers, prepend_bos=False) 54 | 55 | attr_patching_scores = {} 56 | 57 | for idx in tqdm(range(0, len(prompts), batch_size)): 58 | prompt_batch, corrupt_prompt_batch = ( 59 | prompts[idx : idx + batch_size], 60 | corrupt_prompts[idx : idx + batch_size], 61 | ) 62 | label_batch, corrupt_label_batch = ( 63 | labels[idx : idx + batch_size], 64 | corrupt_labels[idx : idx + batch_size], 65 | ) 66 | 67 | # First forward pass to get corrupt cache 68 | _, corrupt_cache = model.run_with_cache(corrupt_prompt_batch) 69 | corrupt_cache = { 70 | k: v 71 | for (k, v) in corrupt_cache.cache_dict.items() 72 | if forward_hook_filter(k) 73 | } 74 | 75 | # Second forward pass to get corrupt cache 76 | clean_logits_orig, clean_cache = model.run_with_cache(prompt_batch) 77 | clean_cache = { 78 | k: v for (k, v) in clean_cache.cache_dict.items() if forward_hook_filter(k) 79 | } 80 | 81 | # Calculate the difference between every two parallel activation cache elements 82 | diff_cache = {} 83 | for k in clean_cache: 84 | diff_cache[k] = (corrupt_cache[k] - clean_cache[k])[:, pos, :].cpu() 85 | 86 | def backward_hook_fn(grad, hook, attr_patching_scores): 87 | matching_hook_name_key = f"blocks.{hook.layer()}.{forward_hook_names[0]}" 88 | if matching_hook_name_key not in attr_patching_scores: 89 | # attr_patching_scores[matching_hook_name_key] = torch.zeros((model.cfg.d_model,), device='cpu') # for full attribution (not per neuron per dim) 90 | # attr_patching_scores[matching_hook_name_key] = torch.zeros(model.cfg.d_model, model.cfg.d_mlp, device='cpu') # for per neuron per input dim attribution 91 | attr_patching_scores[matching_hook_name_key] = torch.zeros( 92 | len(prompts), model.cfg.d_model, model.cfg.d_mlp, device="cpu" 93 | ) # for per neuron per input dim attribution 94 | 95 | # Full attribution - not per neuron per dim, but should work (on dimension level) - THIS IS FOR DEBUGGING 96 | # grad_a_pre_act = model.blocks[hook.layer()].mlp.W_in # The gradient of the pre_act output w.r.t the input - (d_model, d_mlp) 97 | # grad_l = grad[:, pos, :] # The gradient of the 
loss metric w.r.t. the pre_act activation (d_mlp, ) 98 | # attr_patching_scores[matching_hook_name_key] += (diff_cache[matching_hook_name_key] * (grad_l @ grad_a_pre_act.transpose(0, 1))).sum(dim=0).cpu() 99 | # attr_patching_scores[matching_hook_name_key] += (diff_cache[matching_hook_name_key] * grad[:, pos, :]).sum(dim=0).cpu() 100 | 101 | # Per neuron per input dim attribution 102 | grad_L_wrt_e = (model.blocks[hook.layer()].mlp.W_in * grad[:, pos, :]).cpu() 103 | # attr_patching_scores[matching_hook_name_key] += (diff_cache[matching_hook_name_key].unsqueeze(-1) * grad_L_wrt_e).sum(dim=0) 104 | attr_patching_scores[matching_hook_name_key][idx : idx + batch_size] = ( 105 | diff_cache[matching_hook_name_key].unsqueeze(-1) * grad_L_wrt_e 106 | ) 107 | 108 | model.reset_hooks() 109 | model.add_hook( 110 | name=backward_hook_filter, 111 | hook=partial(backward_hook_fn, attr_patching_scores=attr_patching_scores), 112 | dir="bwd", 113 | ) 114 | with torch.set_grad_enabled(True): 115 | clean_logits = model(prompt_batch, return_type="logits") 116 | if metric == "IE": 117 | value = indirect_effect( 118 | clean_logits_orig[:, -1].softmax(dim=-1), 119 | clean_logits[:, -1].softmax(dim=-1), 120 | label_batch, 121 | corrupt_label_batch, 122 | ).mean(dim=0) 123 | else: 124 | raise ValueError(f"Unknown metric {metric}") 125 | value.backward() 126 | model.zero_grad() 127 | 128 | del diff_cache 129 | 130 | # for hook_name in attr_patching_scores.keys(): 131 | # attr_patching_scores[hook_name] = (attr_patching_scores[hook_name] / len(prompts)).cpu() 132 | 133 | model.reset_hooks() 134 | model.requires_grad_(False) 135 | torch.cuda.empty_cache() 136 | 137 | return attr_patching_scores 138 | 139 | 140 | def should_measure_hook(hook_name, measurable_hooks): 141 | if any([h in hook_name for h in measurable_hooks]): 142 | return True 143 | -------------------------------------------------------------------------------- /circuit_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformer_lens as lens 3 | from component import Component 4 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 5 | from typing import List 6 | 7 | # The valid early and late components in a residual path. 8 | # Valid early components write to the residual stream, and valid late components read from the residual stream. 9 | VALID_EARLY_COMPONENTS = ['hook_z', 'hook_attn_out', 'hook_mlp_out', 'hook_resid_pre', 'hook_resid_post'] 10 | VALID_LATE_COMPONENTS = ['hook_q_input', 'hook_k_input', 'hook_v_input', 'hook_mlp_in', 'hook_resid_pre', 'hook_resid_post'] 11 | 12 | 13 | def is_valid_path(early_component: Component, late_component: Component) -> bool: 14 | """ 15 | Check if a path from the early component to another late component is supported. 16 | A path is considered valid if the current (early) component writes to the residual stream, 17 | and the late component reads from the residual stream. 18 | The components can process information in different shapes (For example, hook_z -> hook_resid_post), 19 | but in this case a projection (using W_O matrix) is performed. 20 | 21 | NOTE: This function is used for path patching experiments. 22 | 23 | Args: 24 | early_component (Component): The early component to check the path from. 25 | late_component (Component): The late component to check the path to. 
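    Examples (illustrative):
        is_valid_path(Component('z', layer=3, head=5), Component('mlp_in', layer=8))    # True
        is_valid_path(Component('z', layer=3, head=5), Component('mlp_post', layer=8))  # False ('mlp_post' does not read directly from the residual stream)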
26 | """ 27 | # hook_name can be either like "z" or like "hook_z", thus we check if its contained in any valid name 28 | return any(early_component.hook_name in valid_name for valid_name in VALID_EARLY_COMPONENTS) and \ 29 | any(late_component.hook_name in valid_name for valid_name in VALID_LATE_COMPONENTS) 30 | 31 | 32 | def is_earlier_component(model_cfg: HookedTransformerConfig, 33 | early_component: Component, 34 | late_component: Component) -> bool: 35 | """ 36 | Check if the early component is earlier than the late component in the model. 37 | """ 38 | if early_component.layer < late_component.layer: 39 | return True 40 | elif early_component.layer == late_component.layer: 41 | # If both components are in the same layer, we say the early component is earlier only in several cases - 42 | # 1. Attention -> MLP when they are not parallel 43 | # 2. Attention/MLP/resid_pre -> resid_post 44 | # 3. resid_pre -> Attention/MLP/resid_post 45 | if early_component.is_attn and late_component.is_mlp and not model_cfg.parallel_attn_mlp: 46 | return True 47 | elif 'resid_post' in late_component.valid_hook_name() or 'resid_pre' in early_component.valid_hook_name(): 48 | return True 49 | else: 50 | return False 51 | else: 52 | return False 53 | 54 | 55 | def topk_effective_components(model: lens.HookedTransformer, 56 | effect_map: torch.Tensor, 57 | k: int = 3, 58 | effect_threshold: float = None, 59 | heads_only: bool = False): 60 | """ 61 | Get the most effective components in the effect map. 62 | 63 | Args: 64 | effect_map (torch.Tensor): The effect map to get the most effective components from. 65 | Should be of shape (c, l), where c is the number of components in each layer (first heads, then MLP), 66 | and l is the number of layers. 67 | k (int): The number of components to return. 68 | effect_threshold (float): The threshold to filter the components by. 69 | If None, no filtering is applied. 70 | heads_only (bool): If true, only attention heads are considered and MLPs are ignored. 71 | Returns: 72 | dict (Component -> float): A dictionary mapping the most effective components to their effect. 73 | """ 74 | if heads_only: 75 | effect_map = effect_map[:, :-1] # Ignore last column, where MLP information should be 76 | 77 | # Make up a list of the most effective components for each C_1 components (to create "C_2") 78 | most_effective_components = {} 79 | indices = torch.topk(effect_map.flatten(), k=k, dim=0).indices 80 | layers, heads = indices // effect_map.shape[1], indices % effect_map.shape[1] 81 | for layer, head in zip(layers, heads): 82 | layer, head = layer.item(), head.item() 83 | if head == model.cfg.n_heads: 84 | # is mlp 85 | most_effective_components[Component('mlp_out', layer=layer)] = effect_map[layer, -1] 86 | else: 87 | # is head 88 | most_effective_components[Component('z', layer=layer, head=head)] = effect_map[layer, head] 89 | 90 | if effect_threshold is not None: 91 | most_effective_components = {c:e for c, e in most_effective_components.items() if e.abs() > effect_threshold} 92 | 93 | return most_effective_components 94 | 95 | 96 | def convert_late_to_early(components: List[Component]): 97 | """ 98 | Convert late components (q_input, k_input, v_input, mlp_in) to early components (z, mlp_out). 99 | 100 | Args: 101 | components (list[Component]): The components to convert. 102 | 103 | Return: 104 | list[Component]: The converted components. 
105 | """ 106 | converted = [] 107 | for comp in components: 108 | if 'q_input' in comp.hook_name or 'k_input' in comp.hook_name or 'v_input' in comp.hook_name: 109 | converted.append(Component('z', layer=comp.layer, head=comp.head_idx)) 110 | elif 'mlp_in' in comp.hook_name: 111 | converted.append(Component('mlp_out', layer=comp.layer)) 112 | else: 113 | # Component is already early 114 | converted.append(comp) 115 | return converted 116 | 117 | 118 | def convert_early_to_late(components: List[Component]): 119 | """ 120 | Convert early components (z, mlp_out) to late components (q_input, k_input, v_input, mlp_in). 121 | 122 | Args: 123 | components (list[Component]): The components to convert. 124 | 125 | Return: 126 | list[Component]: The converted components. 127 | """ 128 | converted = [] 129 | for comp in components: 130 | if 'z' in comp.hook_name: 131 | converted.append(Component('q_input', layer=comp.layer, head=comp.head_idx)) 132 | converted.append(Component('k_input', layer=comp.layer, head=comp.head_idx)) 133 | converted.append(Component('v_input', layer=comp.layer, head=comp.head_idx)) 134 | elif 'mlp_out' in comp.hook_name: 135 | converted.append(Component('mlp_in', layer=comp.layer)) 136 | else: 137 | # Component is already late 138 | converted.append(comp) 139 | return converted -------------------------------------------------------------------------------- /activation_patching.py: -------------------------------------------------------------------------------- 1 | from general_utils import set_deterministic 2 | from functools import partial 3 | import transformer_lens as lens 4 | import torch 5 | import random 6 | from metrics import indirect_effect 7 | from prompt_generation import separate_prompts_and_answers 8 | from typing import List, Tuple 9 | 10 | def activation_patching_experiment(model: lens.HookedTransformer, 11 | prompts_and_answers: List[Tuple[str, str]], 12 | metric: str='IE', 13 | hookpoint_name: str='mlp_post', 14 | n_shots: int=0, 15 | token_pos: int=-1, 16 | hook_func_overload=None, 17 | corrupt_prompts_and_answers: List[Tuple[str, str]]=None, 18 | random_seed: int=None): 19 | """ 20 | Performs an activation patching experiment. 21 | Each prompt is passed through the model, and at each layer, the activations are patched with the activations from another prompt. 22 | The effect of this patching is measured by the metric and averaged over all prompts. 23 | 24 | Args: 25 | model (lens.HookedTransformer): The model to patch. 26 | prompts_and_answers (List[Tuple[str, str]]): A list of (prompt, answer) tuples. 27 | metric (str): The metric to use. Can either be 'IE' (indirect effect) or 'IE-logits' (indirect effect on logits). 28 | hookpoint_name (str): The name of the hookpoint to patch. Defaults to patching MLP output (mlp_post). See transformer_lens for more details. 29 | n_shots (int): The number of pre-prompt examples to use. For example, for n_shots=1, the prompt '5+4=' might pass through the model as '13+27=40;5+4='. 30 | The shots are chosen randomly from the prompts_and_answers list (excluding the current clean and corrupt prompt). 31 | token_pos (int): The token position to patch. Defaults to -1 (last token). If None, all token positions are patched. 32 | hook_func_overload (Callable): A function to overload the hook function with. If None, the default hook function is used. 33 | If this is not None, other hook-related arguments are ignored. 
34 | corrupt_prompts_and_answers (List[Tuple[str, str]]): A list of (prompt, answer) tuples to use as the corrupt prompts. 35 | If None, the corrupt prompts are chosen randomly from the prompts_and_answers list (such that no prompt 36 | is used as its own corrupt prompt). 37 | random_seed (int): The random seed to use for the experiment. If None, the seed is not set. 38 | Returns: 39 | torch.Tensor (n_prompts, n_layers): The metric results for each prompt and layer. 40 | """ 41 | if random_seed is not None: 42 | # Set random seed for reproducibility 43 | set_deterministic(seed=random_seed) 44 | 45 | metric_results = torch.zeros((len(prompts_and_answers), model.cfg.n_layers, ), dtype=torch.float32) 46 | 47 | # Define a default hooking function, which works for patching MLP / full attention output activations 48 | def default_patching_hook(value, hook, cache, token_pos): 49 | """ 50 | A hook that works for some of the more common modules (MLP outputs, Attention outputs). 51 | """ 52 | if token_pos is None: 53 | value = cache[hook.name] 54 | else: 55 | value[:, token_pos, :] = cache[hook.name][:, token_pos, :] 56 | 57 | return value 58 | 59 | hook_func = default_patching_hook if hook_func_overload is None else hook_func_overload 60 | 61 | # Choose a random corrupt prompt for each prompt, if not given 62 | if corrupt_prompts_and_answers is None: 63 | corrupt_prompts_and_answers = [] 64 | for prompt_idx in range(len(prompts_and_answers)): 65 | # Choose a random prompt to corrupt with, without any limitations other than choosing a different prompt 66 | corrupt_prompt_idx = random.choice(list(set(range(len(prompts_and_answers))) - {prompt_idx})) 67 | corrupt_prompts_and_answers.append(prompts_and_answers[corrupt_prompt_idx]) 68 | 69 | clean_prompts, clean_answers = separate_prompts_and_answers(prompts_and_answers) 70 | corrupt_prompts, corrupt_answers = separate_prompts_and_answers(corrupt_prompts_and_answers) 71 | clean_labels = model.to_tokens(clean_answers, prepend_bos=False) 72 | corrupt_labels = model.to_tokens(corrupt_answers, prepend_bos=False) 73 | 74 | # Add pre-prompt examples for each prompt, according to number of shots 75 | for i in range(n_shots): 76 | for prompt_idx in range(len(prompts_and_answers)): 77 | shot_prompt_idx = random.choice(list(set(range(len(prompts_and_answers))) - {prompt_idx, corrupt_prompt_idx})) 78 | shot_prompt = f'{clean_prompts[shot_prompt_idx]}={clean_answers[shot_prompt_idx]}' 79 | clean_prompts[prompt_idx] = shot_prompt + '\n' + clean_prompts[prompt_idx] 80 | corrupt_prompts[prompt_idx] = shot_prompt + '\n' + corrupt_prompts[prompt_idx] 81 | 82 | # Run both prompt batches to get the logits and activation cache 83 | clean_logits, clean_cache = model.run_with_cache(clean_prompts, return_type='logits') 84 | corrupt_logits, corrupt_cache = model.run_with_cache(corrupt_prompts, return_type='logits') 85 | 86 | # Patch each layer and measure the effect metric 87 | hook_fn_with_cache = partial(hook_func, cache=corrupt_cache, token_pos=token_pos) 88 | for layer in range(model.cfg.n_layers): 89 | patched_logits = model.run_with_hooks(clean_prompts, 90 | fwd_hooks=[(lens.utils.get_act_name(name=hookpoint_name, layer=layer), hook_fn_with_cache)], 91 | return_type='logits') 92 | if metric == 'IE': 93 | metric_results[:, layer] = indirect_effect(clean_logits[:, -1].softmax(dim=-1).to(model.cfg.device), 94 | patched_logits[:, -1].softmax(dim=-1).to(model.cfg.device), 95 | clean_labels.to(model.cfg.device), 96 | corrupt_labels.to(model.cfg.device)) 97 | elif metric == 
'IE-Logits': 98 | metric_results[:, layer] = indirect_effect(clean_logits[:, -1].to(model.cfg.device), 99 | patched_logits[:, -1].to(model.cfg.device), 100 | clean_labels.to(model.cfg.device), 101 | corrupt_labels.to(model.cfg.device)) 102 | else: 103 | raise ValueError(f"Unknown metric {metric}") 104 | 105 | return metric_results 106 | -------------------------------------------------------------------------------- /script_circuit_localization.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import pickle 4 | import os 5 | import torch 6 | from functools import partial 7 | from prompt_generation import OPERATORS 8 | from general_utils import set_deterministic, load_model 9 | from eap.attr_patching import node_attribution_patching 10 | from activation_patching import activation_patching_experiment 11 | 12 | 13 | torch.set_grad_enabled(False) 14 | device = 'cuda' 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--model_name', type=str, help='Name of the model to be loaded') 20 | parser.add_argument('--model_path', type=str, help='Path to the model to be loaded') 21 | parser.add_argument('--do_attribution', action='store_true', help='Whether to perform node attribution instead of activation patching') 22 | args = parser.parse_args() 23 | return args 24 | 25 | 26 | def manual_ap_localization(model_name, model_path): 27 | """ 28 | Find the arithmetic circuit using activation patching on each component (head or MLP) of the model. 29 | """ 30 | model = load_model(model_name, model_path, device) 31 | 32 | results_file_path = f'./data/{model_name}/ie_maps_activation_patching.pt' 33 | max_op = 300 34 | analysis_prompts_file_path = fr'./data/{model_name}/large_prompts_and_answers_max_op={max_op}.pkl' 35 | set_deterministic(42) 36 | 37 | # Load pre-calculated prompts and answers 38 | with open(analysis_prompts_file_path, 'rb') as f: 39 | large_prompts_and_answers = pickle.load(f) 40 | for i in range(len(large_prompts_and_answers)): 41 | random.shuffle(large_prompts_and_answers[i]) 42 | correct_prompts_and_answers = [pa[:50] for pa in large_prompts_and_answers] 43 | 44 | seeds = [42, 412, 32879, 123, 436] 45 | if os.path.exists(results_file_path): 46 | ie_maps = torch.load(results_file_path) 47 | else: 48 | ie_maps = {} 49 | 50 | def head_hooking_func(value, hook, head_index, token_pos, cache): 51 | if token_pos is None: 52 | value[:, :, head_index, :] = cache[hook.name][:, :, head_index, :] # For z hooking 53 | else: 54 | value[:, token_pos, head_index, :] = cache[hook.name][:, token_pos, head_index, :] # For z hooking 55 | return value 56 | 57 | # Patch each component (MLP and attention head) at each token position, for each operator and each random seed. 
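    # Token positions assume a prompt of the form '<op1><operator><op2>=' with a prepended BOS token,
    # i.e. position 0 is BOS, 1 is the first operand, 2 is the operator, 3 is the second operand and 4 is '='.
    # (This only holds when each operand fits in a single token; see max_single_token in model_analysis_consts.py.)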
58 | for token_pos in [4, 3, 2, 1]: 59 | for operator_idx in range(len(OPERATORS)): 60 | for seed in seeds: 61 | if (operator_idx, token_pos, seed) in ie_maps.keys(): 62 | continue 63 | print(f"{operator_idx=}, {token_pos=}, {seed=}") 64 | correct_pa = correct_prompts_and_answers[operator_idx] 65 | corrupt_pa = random.sample(sum(correct_prompts_and_answers, []), len(correct_pa)) 66 | ie_maps[(operator_idx, token_pos, seed)] = torch.zeros((model.cfg.n_layers, model.cfg.n_heads + 1), dtype=torch.float32) 67 | 68 | # MLP 69 | ie_maps[(operator_idx, token_pos, seed)][:, -1] = activation_patching_experiment(model, correct_pa, 70 | corrupt_prompts_and_answers=corrupt_pa, hookpoint_name='mlp_post', 71 | metric='IE', token_pos=token_pos, random_seed=seed).mean(dim=0) 72 | # Attention heads 73 | for head_idx in range(model.cfg.n_heads): 74 | head_hook_fn = partial(head_hooking_func, head_index=head_idx) 75 | ie_maps[(operator_idx, token_pos, seed)][:, head_idx] = activation_patching_experiment(model, correct_pa, 76 | corrupt_prompts_and_answers=corrupt_pa, hookpoint_name='z', 77 | metric='IE', token_pos=token_pos, 78 | hook_func_overload=head_hook_fn, random_seed=seed).mean(dim=0) 79 | # Save the results after each calculation to avoid losing them 80 | torch.save(ie_maps, results_file_path) 81 | 82 | 83 | def node_attr_patching_localization(model_name, model_path): 84 | """ 85 | Find the arithmetic circuit using node attribution patching on each component (head or MLP) of the model. 86 | Faster than activation patching, but less accurate. 87 | """ 88 | # Load the model into CPU because backward pass in the GPU takes up too much memory 89 | model = load_model(model_name, model_path, 'cpu') 90 | max_op = 300 91 | analysis_prompts_file_path = fr'./data/{model_name}/large_prompts_and_answers_max_op={max_op}.pkl' 92 | set_deterministic(42) 93 | 94 | # Load the pre-calculated prompts and answers 95 | with open(analysis_prompts_file_path, 'rb') as f: 96 | large_prompts_and_answers = pickle.load(f) 97 | for i in range(len(large_prompts_and_answers)): 98 | random.shuffle(large_prompts_and_answers[i]) 99 | correct_prompts_and_answers = [pa[:50] for pa in large_prompts_and_answers] 100 | 101 | results_file_path = f'./data/{model_name}/node_attribution_results.pt' 102 | if os.path.exists(results_file_path): 103 | attribution_results = torch.load(results_file_path) 104 | else: 105 | attribution_results = {} 106 | 107 | seeds = [42, 412, 32879, 123] 108 | for operator_idx in range(len(OPERATORS)): 109 | prompts_and_answers = correct_prompts_and_answers[operator_idx] 110 | for seed in seeds: 111 | if (operator_idx, seed) in attribution_results: 112 | print(f"Found results file for {operator_idx=}, {seed=}") 113 | continue 114 | print(f"{operator_idx=}, {seed=}") 115 | set_deterministic(seed) 116 | corrupt_prompts_and_answers = random.sample(sum(correct_prompts_and_answers, []), len(prompts_and_answers)) 117 | attribution_scores = node_attribution_patching(model, prompts_and_answers, corrupt_prompts_and_answers, metric='IE', batch_size=10) 118 | scores_tensor = torch.zeros((5, model.cfg.n_layers, model.cfg.n_heads + 1), dtype=torch.float32) 119 | for layer in range(model.cfg.n_layers): 120 | for head in range(model.cfg.n_heads): 121 | scores_tensor[:, layer, head] = attribution_scores[f'blocks.{layer}.attn.hook_z'].mean(dim=0)[:, head].sum(dim=-1) 122 | scores_tensor[:, layer, -1] = attribution_scores[f'blocks.{layer}.mlp.hook_post'].mean(dim=0).sum(dim=-1) 123 | attribution_results[(operator_idx, seed)] = 
scores_tensor 124 | 125 | # Save the results after each calculation to avoid losing them 126 | torch.save(attribution_results, results_file_path) 127 | 128 | 129 | def main(): 130 | args = parse_args() 131 | if args.do_attribution: 132 | node_attr_patching_localization(args.model_name, args.model_path) 133 | else: 134 | manual_ap_localization(args.model_name, args.model_path) 135 | 136 | 137 | if __name__ == '__main__': 138 | main() -------------------------------------------------------------------------------- /eap/eap_wrapper.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from functools import partial 3 | from typing import Callable, List, Union 4 | 5 | import einops 6 | import torch 7 | import tqdm 8 | from jaxtyping import Float, Int 9 | from torch import Tensor 10 | from tqdm import tqdm 11 | 12 | from transformer_lens import HookedTransformer 13 | from transformer_lens.hook_points import HookPoint 14 | 15 | from eap.eap_graph import EAPGraph 16 | 17 | def EAP_corrupted_forward_hook( 18 | activations: Union[Float[Tensor, "batch_size seq_len n_heads d_model"], Float[Tensor, "batch_size seq_len d_model"]], 19 | hook: HookPoint, 20 | upstream_activations_difference: Float[Tensor, "batch_size seq_len n_upstream_nodes d_model"], 21 | graph: EAPGraph 22 | ): 23 | hook_slice = graph.get_hook_slice(hook.name) 24 | if activations.ndim == 3: 25 | # We are in the case of a residual layer or MLP 26 | # Activations have shape [batch_size, seq_len, d_model] 27 | # We need to add an extra dimension to make it [batch_size, seq_len, 1, d_model] 28 | # The hook slice is a slice of length 1 29 | upstream_activations_difference[:, :, hook_slice, :] = activations.unsqueeze(-2)#-activations.unsqueeze(-2) 30 | elif activations.ndim == 4: 31 | # We are in the case of an attention layer 32 | # Activations have shape [batch_size, seq_len, n_heads, d_model] 33 | upstream_activations_difference[:, :, hook_slice, :] = activations #-activations 34 | 35 | def EAP_clean_forward_hook( 36 | activations: Union[Float[Tensor, "batch_size seq_len n_heads d_model"], Float[Tensor, "batch_size seq_len d_model"]], 37 | hook: HookPoint, 38 | upstream_activations_difference: Float[Tensor, "batch_size seq_len n_upstream_nodes d_model"], 39 | graph: EAPGraph 40 | ): 41 | hook_slice = graph.get_hook_slice(hook.name) 42 | if activations.ndim == 3: 43 | upstream_activations_difference[:, :, hook_slice, :] += -activations.unsqueeze(-2)#activations.unsqueeze(-2) 44 | elif activations.ndim == 4: 45 | upstream_activations_difference[:, :, hook_slice, :] += -activations#activations 46 | 47 | def EAP_clean_backward_hook( 48 | grad: Union[Float[Tensor, "batch_size seq_len n_heads d_model"], Float[Tensor, "batch_size seq_len d_model"]], 49 | hook: HookPoint, 50 | upstream_activations_difference: Float[Tensor, "batch_size seq_len n_upstream_nodes d_model"], 51 | graph: EAPGraph, 52 | attn_head_repeat_factor: int = None, 53 | pos: int = None, 54 | ): 55 | hook_slice = graph.get_hook_slice(hook.name) 56 | 57 | # we get the slice of all upstream nodes that come before this downstream node 58 | earlier_upstream_nodes_slice = graph.get_slice_previous_upstream_nodes(hook) 59 | 60 | # grad has shape [batch_size, seq_len, n_heads, d_model] or [batch_size, seq_len, d_model] 61 | # we want to multiply it by the upstream activations difference 62 | if grad.ndim == 3: 63 | grad_expanded = grad.unsqueeze(-2) # Shape: [batch_size, seq_len, 1, d_model] 64 | else: 65 | if attn_head_repeat_factor is 
not None and ('hook_k' in hook.name or 'hook_v' in hook.name): 66 | grad_expanded = torch.repeat_interleave(grad, dim=-2, repeats=attn_head_repeat_factor) # Shape: [batch_size, seq_len, n_heads (real), d_model] 67 | else: 68 | grad_expanded = grad # Shape: [batch_size, seq_len, n_heads, d_model] 69 | 70 | # we compute the mean over the batch_size and seq_len dimensions 71 | result = torch.matmul( 72 | upstream_activations_difference[:, :, earlier_upstream_nodes_slice], 73 | grad_expanded.transpose(-1, -2) 74 | ) 75 | 76 | if pos is None: 77 | result = result.sum(dim=0).sum(dim=0) # we sum over the batch_size and seq_len dimensions 78 | else: 79 | result = result[:, pos].sum(dim=0) 80 | 81 | graph.eap_scores[earlier_upstream_nodes_slice, hook_slice] += result 82 | 83 | 84 | def EAP( 85 | model: HookedTransformer, 86 | clean_tokens: Int[Tensor, "batch_size seq_len"], 87 | corrupted_tokens: Int[Tensor, "batch_size seq_len"], 88 | metric: Callable, 89 | upstream_nodes: List[str]=None, 90 | downstream_nodes: List[str]=None, 91 | batch_size: int=1, 92 | pos: int = None 93 | ): 94 | 95 | graph = EAPGraph(model.cfg, upstream_nodes, downstream_nodes) 96 | 97 | assert clean_tokens.shape == corrupted_tokens.shape, "Shape mismatch between clean and corrupted tokens" 98 | num_prompts, seq_len = clean_tokens.shape[0], clean_tokens.shape[1] 99 | 100 | assert num_prompts % batch_size == 0, "Number of prompts must be divisible by batch size" 101 | 102 | upstream_activations_difference = torch.zeros( 103 | (batch_size, seq_len, graph.n_upstream_nodes, model.cfg.d_model), 104 | device=model.cfg.device, 105 | dtype=model.cfg.dtype, 106 | requires_grad=False 107 | ) 108 | 109 | # set the EAP scores to zero 110 | graph.reset_scores() 111 | 112 | upstream_hook_filter = lambda name: name.endswith(tuple(graph.upstream_hooks)) 113 | downstream_hook_filter = lambda name: name.endswith(tuple(graph.downstream_hooks)) 114 | 115 | corruped_upstream_hook_fn = partial( 116 | EAP_corrupted_forward_hook, 117 | upstream_activations_difference=upstream_activations_difference, 118 | graph=graph 119 | ) 120 | 121 | clean_upstream_hook_fn = partial( 122 | EAP_clean_forward_hook, 123 | upstream_activations_difference=upstream_activations_difference, 124 | graph=graph 125 | ) 126 | 127 | attn_head_repeat_factor = None if model.cfg.n_key_value_heads is None else model.cfg.n_heads // model.cfg.n_key_value_heads 128 | clean_downstream_hook_fn = partial( 129 | EAP_clean_backward_hook, 130 | upstream_activations_difference=upstream_activations_difference, 131 | graph=graph, 132 | attn_head_repeat_factor = attn_head_repeat_factor, 133 | pos = pos 134 | ) 135 | 136 | for idx in tqdm(range(0, num_prompts, batch_size)): 137 | # we first perform a forward pass on the corrupted input 138 | model.add_hook(upstream_hook_filter, corruped_upstream_hook_fn, "fwd") 139 | 140 | # we don't need gradients for this forward pass 141 | # we'll take the gradients when we perform the forward pass on the clean input 142 | with torch.no_grad(): 143 | corrupted_tokens = corrupted_tokens.to(model.cfg.device) 144 | model(corrupted_tokens[idx:idx+batch_size], return_type=None) 145 | 146 | # now we perform a forward and backward pass on the clean input 147 | model.reset_hooks() 148 | model.add_hook(upstream_hook_filter, clean_upstream_hook_fn, "fwd") 149 | model.add_hook(downstream_hook_filter, clean_downstream_hook_fn, "bwd") 150 | 151 | clean_tokens = clean_tokens.to(model.cfg.device) 152 | value = metric(model(clean_tokens[idx:idx+batch_size], 
return_type="logits"), idx=idx, batch_size=batch_size) 153 | value.backward() 154 | 155 | # Reset the gradients and zero the activation-differences tensor before the next batch 156 | model.zero_grad() 157 | upstream_activations_difference *= 0 158 | 159 | del upstream_activations_difference 160 | gc.collect() 161 | torch.cuda.empty_cache() 162 | model.reset_hooks() 163 | 164 | graph.eap_scores /= num_prompts 165 | graph.eap_scores = graph.eap_scores.cpu() 166 | 167 | return graph 168 | -------------------------------------------------------------------------------- /circuit.py: -------------------------------------------------------------------------------- 1 | from component import Component 2 | import torch 3 | 4 | 5 | class Circuit(): 6 | """ 7 | A class representing a circuit in a transformer model. 8 | A circuit is a set of Component objects, which are connected as a DAG. Each connection represents the 9 | effect of an early component on a late component. 10 | """ 11 | def __init__(self, model_cfg) -> None: 12 | """ 13 | Initialize an empty circuit, as part of a transformer model. 14 | 15 | Args: 16 | model_cfg (lens.HookedTransformerConfig): The configuration of the transformer model. 17 | """ 18 | # Each key is an (early_component, late_component) tuple, and each value represents the effect of patching the early component 19 | # on the late component (via path patching). In case the late component is a pseudo "logits" component, the value is the effect 20 | # of activation patching the early component. 21 | self.components = set() 22 | self.edges = {} 23 | self.model_cfg = model_cfg 24 | 25 | def add_component(self, component, patching_effects=None) -> None: 26 | """ 27 | Add a component to the circuit, and connect it to the logits affected by it / to other components which affect it. 28 | 29 | Args: 30 | component (Component): The component to add. 31 | patching_effects (torch.Tensor): This is a matrix of the effects of early components on the given component (via path patching). 32 | The tensor must be of shape (num_layers, 1) in case the early components are only MLPs, 33 | (num_layers, num_heads) in case the early components are only z vectors, 34 | or (num_layers, num_heads + 1) in case the early components are both z vectors and MLPs (the MLP should be the last column). 35 | If None, the component is added without any edges to it. 36 | """ 37 | # The effects argument is a matrix representing the effect of many early components (in a certain structure) on the late component 38 | assert patching_effects is None or len(patching_effects.shape) == 2 39 | 40 | self.components.add(component) 41 | 42 | if patching_effects is not None: 43 | layer_count, component_count = patching_effects.shape 44 | assert layer_count == self.model_cfg.n_layers, \ 45 | 'The number of rows in the patching effects matrix must be equal to the number of layers in the model.'
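        # Illustrative shape note (hypothetical values, not from the original code): for a config with
        # n_layers=32 and n_heads=32, a combined z + MLP effects matrix has shape (32, 33),
        # with the MLP effect of each layer stored in the last column of its row.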
46 | 47 | mlp_column_count = 1 48 | z_column_count = self.model_cfg.n_heads 49 | assert component_count in [mlp_column_count, z_column_count, z_column_count + mlp_column_count], \ 50 | f'Invalid number of columns in the patching effects matrix (got {component_count}, expected one of ' \ 51 | f'{[mlp_column_count, z_column_count, z_column_count + mlp_column_count]}' 52 | 53 | if component_count == mlp_column_count or component_count == (z_column_count + mlp_column_count): 54 | # Patching effects include effects of MLPs on component (in the last column) 55 | for layer in range(layer_count): 56 | early_component = Component('mlp_out', layer=layer) 57 | self.edges[(early_component, component)] = patching_effects[layer, -1] 58 | 59 | if component_count == z_column_count or component_count == (z_column_count + mlp_column_count): 60 | for layer in range(layer_count): 61 | for head_idx in range(z_column_count): 62 | early_component = Component('z', head=head_idx, layer=layer) 63 | self.edges[(early_component, component)] = patching_effects[layer, head_idx] 64 | 65 | def remove_component(self, component) -> None: 66 | self.components.remove(component) 67 | for edge in list(self.edges.keys()): 68 | if edge[1] == component: 69 | del self.edges[edge] 70 | 71 | def get_component_patching_effects(self, component, include_attn=True, include_mlp=True, is_component_late=True, zero_non_existing_edges=False) -> float: 72 | """ 73 | Get the patching effects of all early components on a given component / the effects of the given component on later components. 74 | 75 | Args: 76 | component (Component): The component to get the patching effects for. 77 | include_attn (bool): If True, the patching effects of attention heads are included. 78 | include_mlp (bool): If True, the patching effects of MLPs are included. 79 | is_component_late (bool): If True, the component is considered a late component, and the patching effects of early components on it are returned. 80 | If False, the component is considered an early component, and the patching effects of it on later components are returned. 81 | zero_non_existing_edges (bool): If True, the patching effects of non-existing edges are returned as 0. Otherwise, an exception is raised. 82 | Returns: 83 | (torch.tensor) - A tensor of shape (num_layers, num_components_per_layer) representing the patching effects of early components on the given component. 84 | The num of components per layer is determined by the flags include_attn and include_mlp (it can be one of 1, n_heads or n_heads + 1). 85 | """ 86 | assert include_attn or include_mlp, 'You must request at least one of the patching effects (attn heads or mlp) on the component' 87 | patching_effect_columns = (self.model_cfg.n_heads if include_attn else 0) + (1 if include_mlp else 0) 88 | effect_matrix = torch.zeros((self.model_cfg.n_layers, patching_effect_columns)) 89 | for layer in range(self.model_cfg.n_layers): 90 | if include_attn: 91 | for head_idx in range(self.model_cfg.n_heads): 92 | if is_component_late: 93 | head_component = Component('z', head=head_idx, layer=layer) 94 | effect_matrix[layer, head_idx] = self.edges.get((head_component, component), 0 if zero_non_existing_edges else None) 95 | else: 96 | # Edges are saved such that q_input/k_input/v_input are the late components, so searching for a late component with name=z doesnt work. 
97 | effect_matrix[layer, head_idx] = self.edges.get((component, Component('q_input', head=head_idx, layer=layer)), 0 if zero_non_existing_edges else None) + \ 98 | self.edges.get((component, Component('k_input', head=head_idx, layer=layer)), 0 if zero_non_existing_edges else None) + \ 99 | self.edges.get((component, Component('v_input', head=head_idx, layer=layer)), 0 if zero_non_existing_edges else None) 100 | if include_mlp: 101 | if is_component_late: 102 | mlp_component = Component('mlp_out', layer=layer) 103 | effect_matrix[layer, -1] = self.edges.get((mlp_component, component), 0 if zero_non_existing_edges else None) 104 | else: 105 | mlp_component = Component('mlp_in', layer=layer) 106 | effect_matrix[layer, -1] = self.edges.get((component, mlp_component), 0 if zero_non_existing_edges else None) 107 | return effect_matrix 108 | -------------------------------------------------------------------------------- /script_per_neuron_analysis.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import transformer_lens as lens 3 | import torch 4 | import random 5 | import os 6 | from functools import partial 7 | from tqdm import tqdm 8 | from general_utils import set_deterministic, get_hook_dim 9 | from prompt_generation import separate_prompts_and_answers 10 | from typing import List, Tuple, Dict 11 | from metrics import indirect_effect 12 | from component import Component 13 | 14 | 15 | def per_neuron_ap_experiment(model: lens.HookedTransformer, 16 | prompts_and_answers: List[Tuple[str, str]], 17 | hook_component: Component, 18 | corrupt_prompts_and_answers: List[Tuple[str, str]] = None, 19 | metric: str = 'IE', 20 | token_pos: int = -1, 21 | random_seed: int = None): 22 | """ 23 | Conducts a activation patching experiment on individual neurons of a given component. 24 | Args: 25 | model (lens.HookedTransformer): The transformer model to be analyzed. 26 | prompts_and_answers (List[Tuple[str, str]]): A list of tuples containing prompts and their corresponding answers. 27 | hook_component (Component): The component of the model to hook into for patching. 28 | corrupt_prompts_and_answers (List[Tuple[str, str]], optional): A list of tuples containing corrupt prompts and their corresponding answers. 29 | If not provided, random corrupt prompts will be chosen. Defaults to None. 30 | metric (str, optional): The metric to be used for evaluation. Currently, only 'IE' (Indirect Effect) is supported. Defaults to 'IE'. 31 | token_pos (int, optional): The position of the token to be patched. If -1, the last token is used. Defaults to -1. 32 | random_seed (int, optional): The random seed for reproducibility. If None, no seed is set. Defaults to None. 33 | Returns: 34 | torch.Tensor: A tensor containing the metric results for each neuron. 
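    Example (illustrative sketch, not from the original code; assumes a loaded model and prompt list):
        >>> ie_scores = per_neuron_ap_experiment(model, prompts_and_answers,
        ...                                      hook_component=Component('mlp_post', layer=20),
        ...                                      token_pos=-1, random_seed=42)
        >>> ie_scores.shape  # (len(prompts_and_answers), d_mlp) for an 'mlp_post' component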
35 | """ 36 | if random_seed is not None: 37 | # Set random seed for reproducibility 38 | set_deterministic(seed=random_seed) 39 | 40 | embed_dim = get_hook_dim(model, hook_component.hook_name) 41 | metric_results = torch.zeros((len(prompts_and_answers), embed_dim), dtype=torch.float32) 42 | 43 | # Define a default hooking function, which works for patching MLP / full attention output activations 44 | # For specific head 45 | def hook_fn(value, hook, cache, neuron_idx, token_pos): 46 | if token_pos is None: 47 | if hook_component.head_idx is None: 48 | value[:, :, neuron_idx] = cache[:, :, neuron_idx] 49 | else: 50 | value[:, :, hook_component.head_idx, neuron_idx] = cache[:, :, hook_component.head_idx, neuron_idx] 51 | else: 52 | if hook_component.head_idx is None: 53 | value[:, token_pos, neuron_idx] = cache[:, token_pos, neuron_idx] 54 | else: 55 | value[:, token_pos, hook_component.head_idx, neuron_idx] = cache[:, token_pos, hook_component.head_idx, neuron_idx] 56 | return value 57 | 58 | # Choose a random corrupt prompt for each prompt, if not given 59 | if corrupt_prompts_and_answers is None: 60 | corrupt_prompts_and_answers = [] 61 | for prompt_idx in range(len(prompts_and_answers)): 62 | corrupt_prompt_idx = random.choice(list(set(range(len(prompts_and_answers))) - {prompt_idx})) 63 | corrupt_prompts_and_answers.append(prompts_and_answers[corrupt_prompt_idx]) 64 | 65 | clean_prompts, clean_answers = separate_prompts_and_answers(prompts_and_answers) 66 | corrupt_prompts, corrupt_answers = separate_prompts_and_answers(corrupt_prompts_and_answers) 67 | clean_labels = model.to_tokens(clean_answers, prepend_bos=False) 68 | corrupt_labels = model.to_tokens(corrupt_answers, prepend_bos=False) 69 | 70 | # Run both prompt batches to get the logits and activation cache 71 | clean_logits = model(clean_prompts, return_type='logits') 72 | _, corrupt_cache = model.run_with_cache(corrupt_prompts, return_type='logits') 73 | specific_hook_cache = corrupt_cache[hook_component.valid_hook_name()].detach().clone() 74 | del corrupt_cache 75 | torch.cuda.empty_cache() 76 | 77 | # Patch each neuron and measure the effect 78 | for neuron_idx in tqdm(range(embed_dim)): 79 | hook_fn_with_cache = partial(hook_fn, cache=specific_hook_cache, neuron_idx=neuron_idx, token_pos=token_pos) 80 | patched_logits = model.run_with_hooks(clean_prompts, 81 | fwd_hooks=[(hook_component.valid_hook_name(), hook_fn_with_cache)], 82 | return_type='logits') 83 | if metric == 'IE': 84 | metric_results[:, neuron_idx] = indirect_effect(clean_logits[:, -1].softmax(dim=-1), patched_logits[:, -1].softmax(dim=-1), clean_labels, corrupt_labels) 85 | else: 86 | raise ValueError(f"Unknown metric {metric}") 87 | 88 | return metric_results 89 | 90 | 91 | if __name__ == '__main__': 92 | # Code to run per_neuron_analysis as a background script because it takes too long to run in notebook. 
93 | print('Loading model, prompts, etc') 94 | os.environ['CUDA_VISIBLE_DEVICES'] = '6' 95 | device = 'cuda:0' 96 | torch.set_grad_enabled(False) 97 | gptj = lens.HookedTransformer.from_pretrained("EleutherAI/gpt-j-6b", fold_ln=True, center_unembed=True, center_writing_weights=True, device=device) 98 | gptj.eval() 99 | max_op = 100 100 | pos = -1 101 | 102 | results_output_path = fr'./data/addition/per_neuron_analysis_max_op={max_op}_operand_and_operator_corruptions.pkl' 103 | if os.path.exists(results_output_path): 104 | results_dict = pickle.load(open(results_output_path, 'rb')) 105 | else: 106 | results_dict = {} 107 | 108 | with open(fr'./data/gptj/correct_prompts_and_answers_max_op={max_op}.pkl', 'rb') as f: 109 | correct_prompts_and_answers = pickle.load(f) 110 | corrupt_prompts_and_answers = random.sample(correct_prompts_and_answers[0] + 111 | correct_prompts_and_answers[1] + 112 | correct_prompts_and_answers[2] + 113 | correct_prompts_and_answers[3], 114 | k=50) 115 | correct_prompts_and_answers = correct_prompts_and_answers[0] 116 | 117 | 118 | 119 | print(f'Correct prompts: {correct_prompts_and_answers}') 120 | print(f'Corrupt prompts: {corrupt_prompts_and_answers}') 121 | 122 | print('Running AP experiments') 123 | for mlp in range(23, -1, -1): 124 | print(f'Running mlp_post={mlp}') 125 | after_relu_ie = per_neuron_ap_experiment(gptj, correct_prompts_and_answers, 126 | hook_component=Component('mlp_post', layer=mlp), 127 | corrupt_prompts_and_answers=corrupt_prompts_and_answers, 128 | token_pos=pos, random_seed=42) 129 | results_dict[f'mlp_{mlp}_pos_{pos}_max_op_{max_op}_mlp_post'] = after_relu_ie.mean(dim=0).cpu() 130 | 131 | with open(results_output_path, 'wb') as f: 132 | pickle.dump(results_dict, f) -------------------------------------------------------------------------------- /eap/attr_patching.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import random 3 | import torch 4 | import transformer_lens as lens 5 | from tqdm import tqdm 6 | from typing import List, Tuple, Dict 7 | from general_utils import get_hook_dim 8 | from metrics import indirect_effect, logit_diff 9 | from component import Component 10 | 11 | from prompt_generation import separate_prompts_and_answers 12 | 13 | 14 | def node_attribution_patching( 15 | model: lens.HookedTransformer, 16 | prompts_and_answers: List[Tuple[str, str]], 17 | corrupt_prompts_and_answers: List[Tuple[str, str]] = None, 18 | mean_cache: Dict[Component, torch.Tensor] = None, 19 | attributed_hook_names: List[str] = ["mlp.hook_post", "hook_z"], 20 | metric: str = "IE", 21 | batch_size: int = 1, 22 | verbose: bool = True, 23 | ): 24 | """ 25 | Get a cache of attribution patching scores for all components of the model, 26 | estimating the ground truth activation patching result of each component. 27 | 28 | Args: 29 | model (lens.HookedTransformer): The model to be analyzed. 30 | prompts_and_answers (List[Tuple[str, str]]): A list of tuples containing prompts and answers. 31 | corrupt_prompts_and_answers (List[Tuple[str, str]], optional): A list of tuples containing corrupt prompts and answers. 32 | If mean ablations are used via the mean cache. Defaults to None. 33 | mean_cache (Dict[Component, torch.Tensor], optional): A dictionary containing the mean cache of the model. If corrupt_prompts_and_answers is not given, 34 | this cache is used to calculate the ablation. Defaults to None. 
35 | attributed_hook_names (List[str], optional): A list of hook names to be attributed. Defaults to ['mlp.hook_post', 'hook_z']. 36 | metric (str, optional): The metric to be used for the attribution. Defaults to 'IE' (indirect effect). 37 | batch_size (int, optional): The batch size to be used for the attribution. Defaults to 1. 38 | verbose (bool, optional): Whether to print a progress bar. Defaults to True. 39 | Returns: 40 | Dict[str, torch.Tensor]: A dictionary containing the attribution patching scores for each component. 41 | Each score tensor has the shape (len(prompts), *hook_dim). 42 | """ 43 | model.requires_grad_(True) 44 | 45 | # Node filter function 46 | should_measure_hook_filter = partial( 47 | should_measure_hook, measurable_hooks=attributed_hook_names 48 | ) 49 | 50 | # Choose a random corrupt prompt for each prompt, if not given 51 | assert ( 52 | corrupt_prompts_and_answers or mean_cache 53 | ), "Either corrupt prompts or mean cache must be provided for ablation attribution" 54 | use_counterfactual_ablation = corrupt_prompts_and_answers is not None 55 | 56 | if use_counterfactual_ablation: 57 | corrupt_prompts, corrupt_answers = separate_prompts_and_answers( 58 | corrupt_prompts_and_answers 59 | ) 60 | corrupt_labels = model.to_tokens(corrupt_answers, prepend_bos=False) 61 | 62 | prompts, answers = separate_prompts_and_answers(prompts_and_answers) 63 | labels = model.to_tokens(answers, prepend_bos=False) 64 | 65 | attr_patching_scores = {} 66 | 67 | it = ( 68 | tqdm(range(0, len(prompts), batch_size)) 69 | if verbose 70 | else range(0, len(prompts), batch_size) 71 | ) 72 | for idx in it: 73 | prompt_batch = prompts[idx : idx + batch_size] 74 | label_batch = labels[idx : idx + batch_size].to(device=model.cfg.device) 75 | 76 | if use_counterfactual_ablation: 77 | corrupt_prompt_batch = corrupt_prompts[idx : idx + batch_size] 78 | corrupt_label_batch = corrupt_labels[idx : idx + batch_size].to( 79 | device=model.cfg.device 80 | ) 81 | 82 | # Forward pass to get corrupt cache 83 | _, corrupt_cache = model.run_with_cache(corrupt_prompt_batch) 84 | corrupt_cache = { 85 | k: v 86 | for (k, v) in corrupt_cache.cache_dict.items() 87 | if should_measure_hook_filter(k) 88 | } 89 | else: 90 | corrupt_cache = mean_cache 91 | 92 | # Forward pass to get corrupt cache 93 | clean_logits_orig, clean_cache = model.run_with_cache(prompt_batch) 94 | clean_cache = { 95 | k: v 96 | for (k, v) in clean_cache.cache_dict.items() 97 | if should_measure_hook_filter(k) 98 | } 99 | 100 | # Calculate the difference between every two parallel activation cache elements 101 | diff_cache = {} 102 | for k in clean_cache: 103 | diff_cache[k] = corrupt_cache[k] - clean_cache[k] 104 | 105 | def backward_hook_fn(grad, hook, attr_patching_scores): 106 | # Gradient is multiplicated with the activation difference (between corrupt and clean prompts). 
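            # (i.e. a first-order estimate of the activation patching effect for this hook:
            # score = (corrupt activation - clean activation) * gradient of the metric w.r.t. the clean activation,
            # stored separately for each prompt in the current batch.)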
107 | if hook.name not in attr_patching_scores: 108 | attr_patching_scores[hook.name] = torch.zeros( 109 | (len(prompts),) + clean_cache[hook.name].shape[1:], device="cpu" 110 | ) 111 | attr_patching_scores[hook.name][idx : idx + batch_size] = ( 112 | diff_cache[hook.name] * grad 113 | ).cpu() 114 | 115 | model.reset_hooks() 116 | model.add_hook( 117 | name=should_measure_hook_filter, 118 | hook=partial(backward_hook_fn, attr_patching_scores=attr_patching_scores), 119 | dir="bwd", 120 | ) 121 | with torch.set_grad_enabled(True): 122 | clean_logits = model(prompt_batch, return_type="logits") 123 | if metric == "IE": 124 | if use_counterfactual_ablation: 125 | value = indirect_effect( 126 | clean_logits_orig[:, -1] 127 | .softmax(dim=-1) 128 | .to(device=model.cfg.device), 129 | clean_logits[:, -1].softmax(dim=-1).to(device=model.cfg.device), 130 | label_batch, 131 | corrupt_label_batch, 132 | ).mean(dim=0) 133 | else: 134 | # When using mean ablation in attribution, the IE metric becomes only the "clean" part in the original IE 135 | pre_ablation_probs = clean_logits_orig[:, -1].softmax(dim=-1) 136 | post_ablation_probs = clean_logits[:, -1].softmax(dim=-1) 137 | value = ( 138 | pre_ablation_probs.gather(1, label_batch) 139 | - post_ablation_probs.gather(1, label_batch) 140 | ) / post_ablation_probs.gather(1, label_batch) 141 | value = value.nan_to_num(0) 142 | value = value.squeeze(1).mean(dim=0) 143 | elif metric == "KL": 144 | kl_loss = torch.nn.KLDivLoss(reduction="batchmean", log_target=False) 145 | target = clean_logits_orig[:, -1].softmax(dim=-1) 146 | value = kl_loss(clean_logits[:, -1].softmax(dim=-1).log(), target) 147 | else: 148 | raise ValueError(f"Unknown metric {metric}") 149 | value.backward() 150 | model.zero_grad() 151 | 152 | del diff_cache 153 | 154 | model.reset_hooks() 155 | model.requires_grad_(False) 156 | torch.cuda.empty_cache() 157 | 158 | return attr_patching_scores 159 | 160 | 161 | def should_measure_hook(hook_name, measurable_hooks): 162 | if any([h in hook_name for h in measurable_hooks]): 163 | return True 164 | -------------------------------------------------------------------------------- /script_topk_neuron_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import pickle 4 | import random 5 | import torch 6 | import os 7 | from circuit import Circuit 8 | from circuit_utils import topk_effective_components 9 | from evaluation_utils import circuit_faithfulness_with_mean_ablation 10 | from component import Component 11 | from prompt_generation import OPERATORS, POSITIONS 12 | from general_utils import generate_activations, set_deterministic, load_model, get_model_consts, get_neuron_importance_scores 13 | 14 | 15 | torch.set_grad_enabled(False) 16 | max_op = 300 17 | 18 | 19 | def calc_mean_cache(model, model_name): 20 | eval_mean_cache_path = f'./data/{model_name}/mean_cache_for_evaluation_all_arithmetic_prompts_max_op=300.pt' 21 | if os.path.exists(eval_mean_cache_path): 22 | print('Mean cache file found, skipping creation') 23 | return 24 | 25 | print("Calculating mean cache") 26 | all_heads = [(l, h) for h in range(model.cfg.n_heads) for l in range(model.cfg.n_layers)] 27 | all_mlps = list(range(model.cfg.n_layers)) 28 | model.set_use_attn_result(True) 29 | all_components = [Component('z', layer=l, head=h) for (l, h) in all_heads] + [Component('result', layer=l, head=h) for (l, h) in all_heads] + \ 30 | [Component('mlp_post', layer=l) for l in all_mlps] + 
[Component('mlp_in', layer=l) for l in all_mlps] 31 | 32 | # Notice the prompts used for mean calculation are "illegal" prompts as well. This is to keep a balance between all operators. 33 | all_prompts = [f"{x}{operator}{y}=" for operator in OPERATORS for x in range(0, max_op) for y in range(0, max_op)] 34 | 35 | cached_activations = generate_activations(model, all_prompts, all_components, pos=None, reduce_mean=True) 36 | cached_activations = {c: a[None, ...].to(device='cpu') for c, a in zip(all_components, cached_activations)} 37 | torch.save(cached_activations, eval_mean_cache_path) 38 | 39 | 40 | def build_circuit(model, model_name, operator_idx, mlp_top_neurons): 41 | heads = [] 42 | 43 | ie_maps = torch.load(f'./data/{model_name}/ie_maps_activation_patching.pt') 44 | summed_seed_ie_maps = {} 45 | seeds = set([]) 46 | for op_idx, pos, seed in ie_maps.keys(): 47 | seeds.add(seed) 48 | if (op_idx, pos) not in summed_seed_ie_maps: 49 | summed_seed_ie_maps[(op_idx, pos)] = ie_maps[(op_idx, pos, seed)] 50 | else: 51 | summed_seed_ie_maps[(op_idx, pos)] += ie_maps[(op_idx, pos, seed)] 52 | ie_maps = {k: v / len(seeds) for (k, v) in summed_seed_ie_maps.items()} 53 | ie_maps = torch.stack([ie_maps[(operator_idx, pos)] for pos in POSITIONS]).mean(dim=0) 54 | 55 | if model_name == 'llama3-8b': 56 | if operator_idx == 0: 57 | # Addition 58 | heads = [Component('z', layer=l, head=h) for (l, h) in [(2, 2), (5, 3), (5, 31), (14, 12), (15, 13), (16, 21)]] 59 | elif operator_idx == 1: 60 | # Subtraction 61 | heads = [Component('z', layer=l, head=h) for (l, h) in [(2, 2), (13, 21), (13, 22), (14, 12), (15, 13), (16, 21)]] 62 | elif operator_idx == 2: 63 | # Multiplication 64 | heads = [Component('z', layer=l, head=h) for (l, h) in [(2, 2), (5, 30), (8, 15), (9, 26), (13, 18), (13, 21), (13, 22), 65 | (14, 12), (14, 13), (15, 8), (15, 13), (15, 14), (15, 15), (16, 3), 66 | (16, 21), (17, 24), (17, 26), (18, 16), (20, 2), (22, 1)]] 67 | elif operator_idx == 3: 68 | # Division 69 | heads = [Component('z', layer=l, head=h) for (l, h) in [(2, 2), (5, 31), (15, 13), (15, 14), (16, 21), (18, 16)]] 70 | elif model_name == 'gptj' or 'pythia-6.9b' in model_name or model_name == 'llama3-70b': 71 | heads += topk_effective_components(model, ie_maps, k=100 if model_name == 'llama3-70b' else 50, heads_only=True).keys() 72 | else: 73 | raise ValueError(f"Unknown model {model_name}") 74 | 75 | partial_mlp_layers = list(range(get_model_consts(model_name).first_heuristics_layer, model.cfg.n_layers)) 76 | full_mlps = [Component('mlp_post', layer=l) for l in range(model.cfg.n_layers) if l not in partial_mlp_layers] 77 | partial_mlps = [Component('mlp_post', layer=l, neurons=mlp_top_neurons[l]) for l in partial_mlp_layers] 78 | 79 | full_circuit = Circuit(model.cfg) 80 | for c in list(set(heads + full_mlps + partial_mlps)): 81 | full_circuit.add_component(c) 82 | return full_circuit 83 | 84 | 85 | def evaluate_circuit_with_topk_neurons(model, model_name, evaluation_prompts_and_answers): 86 | PROMPT_COUNT_TO_USE = 50 87 | 88 | # Load mean cache for ablations 89 | mean_cache = torch.load(f'./data/{model_name}/mean_cache_for_evaluation_all_arithmetic_prompts_max_op=300.pt') 90 | if mean_cache[list(mean_cache.keys())[0]].shape[0] != PROMPT_COUNT_TO_USE: 91 | # Mean cache was saved as a single vector, we repeat it for the length of the evaluation prompts 92 | mean_cache = {c: a.repeat(PROMPT_COUNT_TO_USE, 1, 1) for c, a in mean_cache.items()} 93 | 94 | # Settings 95 | k_values = torch.tensor(sorted(list(range(0, 500, 10)) + 
list(range(500, model.cfg.d_mlp, 50)) + [model.cfg.d_mlp])) 96 | seeds = [42, 412, 32879] 97 | 98 | # Load existing results 99 | results_file_path = f'./data/{model_name}/topk_neuron_faithfulness_evaluation_results.pt' 100 | if os.path.exists(results_file_path): 101 | faithfulness_per_k = torch.load(results_file_path) 102 | else: 103 | faithfulness_per_k = {} 104 | 105 | # Calculate faithfulness of the model for each operator and seed (The seed affects the prompts chosen for evaluation) 106 | for seed in seeds: 107 | for operator_idx in range(len(OPERATORS)): 108 | logging.info(f"Starting {operator_idx=}, {seed=}") 109 | if (operator_idx, seed) in faithfulness_per_k: 110 | logging.info(f"Found results file for {operator_idx=}, {seed=}") 111 | continue 112 | set_deterministic(seed) 113 | prompts_and_answers = random.sample(evaluation_prompts_and_answers[operator_idx], k=PROMPT_COUNT_TO_USE) 114 | 115 | mlppost_neuron_scores = get_neuron_importance_scores(model, model_name, operator_idx=operator_idx, pos=-1) # Ranking neurons according to Attribution patching 116 | 117 | faithfulness_per_k[(operator_idx, seed)] = torch.zeros((len(k_values),)) 118 | for i, k in enumerate(k_values): 119 | mlp_top_neurons = {} 120 | for mlp in range(1, model.cfg.n_layers): 121 | mlp_top_neurons[mlp] = mlppost_neuron_scores[mlp].topk(k).indices.tolist() 122 | full_circuit = build_circuit(model, model_name, operator_idx, mlp_top_neurons) 123 | faithfulness_per_k[(operator_idx, seed)][i] = circuit_faithfulness_with_mean_ablation(model, full_circuit, prompts_and_answers, mean_cache, metric='nl') 124 | logging.info(f"{k=}, faithfulness={faithfulness_per_k[(operator_idx, seed)][i].item()}") 125 | torch.save(faithfulness_per_k, results_file_path) 126 | 127 | 128 | def parse_args(): 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument('--model_name', type=str, help='Name of the model to be loaded') 131 | parser.add_argument('--model_path', type=str, help='Path to the model to be loaded') 132 | args = parser.parse_args() 133 | return args 134 | 135 | 136 | def main(): 137 | args = parse_args() 138 | model = load_model(args.model_name, args.model_path, "cuda") 139 | logging.info("Loaded model") 140 | 141 | # Pre-calculate mean cache, if doesn't exist 142 | calc_mean_cache(model, args.model_name) 143 | logging.info("Verified / Created mean cache") 144 | 145 | # Load prompts 146 | with open(fr'./data/{args.model_name}/large_prompts_and_answers_max_op=300.pkl', 'rb') as f: 147 | large_prompts_and_answers = pickle.load(f) 148 | large_prompts_and_answers = [[pa for pa in large_prompts_and_answers[op_idx] if pa[1] != '0'] for op_idx in range(len(OPERATORS))] # Drop prompts with zero for an answer due to bug (mean cache in Pythia leading to 0 logit, thus bad == good baseline in faithfulness function) 149 | 150 | # Evaluate the model's circuit 151 | evaluate_circuit_with_topk_neurons(model, args.model_name, large_prompts_and_answers) 152 | 153 | 154 | if __name__ == '__main__': 155 | logging.basicConfig(level=logging.INFO) 156 | main() -------------------------------------------------------------------------------- /script_eval_pythia_faithfulness_only_mutual_neurons.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import random 5 | import torch 6 | import logging 7 | from itertools import chain 8 | from circuit import Circuit 9 | from component import Component 10 | from circuit_utils import topk_effective_components 11 | 
from evaluation_utils import circuit_faithfulness_with_mean_ablation 12 | from general_utils import generate_activations, get_neuron_importance_scores, set_deterministic, load_model 13 | from heuristics_classification import load_heuristic_classes 14 | from prompt_generation import OPERATORS, POSITIONS 15 | from model_analysis_consts import PYTHIA_6_9B_CONSTS 16 | 17 | 18 | device = 'cuda' 19 | HEURISTIC_MATCH_THRESHOLD = 0.6 20 | PYTHIA_PREFIX = "pythia-6.9b" 21 | 22 | 23 | def load_mean_cache(model, model_name): 24 | """ 25 | Load (or calculate) the mean activation for each component in the model, to be used for evaluation. 26 | """ 27 | max_op = 300 28 | eval_mean_cache_path = f'./data/{model_name}/mean_cache_for_evaluation_all_arithmetic_prompts_max_op={max_op}.pt' 29 | if os.path.exists(eval_mean_cache_path): 30 | cached_activations = torch.load(eval_mean_cache_path) 31 | print('Loaded cached activations from file') 32 | else: 33 | all_heads = [(l, h) for h in range(model.cfg.n_heads) for l in range(model.cfg.n_layers)] 34 | all_mlps = list(range(model.cfg.n_layers)) 35 | model.set_use_attn_result(True) 36 | all_components = [Component('z', layer=l, head=h) for (l, h) in all_heads] + \ 37 | [Component('mlp_post', layer=l) for l in all_mlps] 38 | all_prompts = [f"{x}{operator}{y}=" for operator in OPERATORS for x in range(0, max_op) for y in range(0, max_op)] 39 | cached_activations = generate_activations(model, all_prompts, all_components, pos=None, reduce_mean=True, batch_size=64) 40 | cached_activations = {c: a[None, ...].to(device='cpu') for c, a in zip(all_components, cached_activations)} 41 | torch.save(cached_activations, eval_mean_cache_path) 42 | return cached_activations 43 | 44 | 45 | def build_circuit(model, mlp_neurons, operator_idx): 46 | # heads = [Component('z', layer=l, head=h) for l in range(0, model.cfg.n_layers) for h in range(0, model.cfg.n_heads)] 47 | ie_maps = torch.load(f'./data/{PYTHIA_PREFIX}/ie_maps_activation_patching.pt') 48 | summed_seed_ie_maps = {} 49 | seeds = set([]) 50 | for op_idx, pos, seed in ie_maps.keys(): 51 | seeds.add(seed) 52 | if (op_idx, pos) not in summed_seed_ie_maps: 53 | summed_seed_ie_maps[(op_idx, pos)] = ie_maps[(op_idx, pos, seed)] 54 | else: 55 | summed_seed_ie_maps[(op_idx, pos)] += ie_maps[(op_idx, pos, seed)] 56 | ie_maps = {k: v / len(seeds) for (k, v) in summed_seed_ie_maps.items()} 57 | ie_maps = torch.stack([ie_maps[(operator_idx, pos)] for pos in POSITIONS]).mean(dim=0) 58 | heads = topk_effective_components(model, ie_maps, k=50, heads_only=True).keys() 59 | 60 | partial_mlp_layers = list(range(PYTHIA_6_9B_CONSTS.first_heuristics_layer, model.cfg.n_layers)) 61 | full_mlps = [Component('mlp_post', layer=l) for l in range(model.cfg.n_layers) if l not in partial_mlp_layers] 62 | partial_mlps = [Component('mlp_post', layer=l, neurons=mlp_neurons[l]) for l in partial_mlp_layers] 63 | 64 | full_circuit = Circuit(model.cfg) 65 | for c in list(set(heads + full_mlps + partial_mlps)): 66 | full_circuit.add_component(c) 67 | return full_circuit 68 | 69 | 70 | def get_heuristic_neurons(model, model_name, operator_idx): 71 | heuristic_classes = load_heuristic_classes(f"./data/{model_name}", operator_idx, "HYBRID") 72 | 73 | # Filter by threshold 74 | heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if s >= HEURISTIC_MATCH_THRESHOLD] for name, layer_neuron_scores in heuristic_classes.items()} 75 | heuristic_classes = {name: lns for name, lns in heuristic_classes.items() if len(lns) > 0} 76 | 77 | 
heuristic_neurons = {layer: [n for (l, n) in set([(v[0], v[1]) for v in chain.from_iterable(heuristic_classes.values())]) if l == layer] for layer in range(model.cfg.n_layers)} 78 | return heuristic_neurons 79 | 80 | 81 | def get_intersection_neurons(model, model_name, operator_idx): 82 | """ 83 | Get the neurons who are classified into the same heuristic both in the tested model (supplied by model_name) as well as the last checkpoint (step 143K) model. 84 | """ 85 | # Generate the heuristic list in the last (GT) checkpoint 86 | step_to_compare_to = "143000" 87 | gt_model_name = f"{PYTHIA_PREFIX}-step{step_to_compare_to}" 88 | 89 | gt_heuristic_classes = load_heuristic_classes(f"./data/{gt_model_name}", operator_idx, "HYBRID") 90 | # Filter by threshold 91 | gt_heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if s >= HEURISTIC_MATCH_THRESHOLD] for name, layer_neuron_scores in gt_heuristic_classes.items()} 92 | gt_heuristic_classes = {name: lns for name, lns in gt_heuristic_classes.items() if len(lns) > 0} 93 | gt_heuristic_neuron_pairs = [(h_name, l, n) for h_name, lns in gt_heuristic_classes.items() for (l, n, s) in lns] 94 | 95 | # Generate the heuristic list in the current model 96 | heuristic_classes = load_heuristic_classes(f"./data/{model_name}", operator_idx, "HYBRID") 97 | # Filter by threshold 98 | heuristic_classes = {name: [(l, n, s) for (l, n, s) in layer_neuron_scores if s >= HEURISTIC_MATCH_THRESHOLD] for name, layer_neuron_scores in heuristic_classes.items()} 99 | heuristic_classes = {name: lns for name, lns in heuristic_classes.items() if len(lns) > 0} 100 | heuristic_neuron_pairs = [(h_name, l, n) for h_name, lns in heuristic_classes.items() for (l, n, s) in lns] 101 | 102 | # Get the intersection of the neurons 103 | mutual_neurons = list(set([(l, n) for (h_name, l, n) in set(gt_heuristic_neuron_pairs).intersection(set(heuristic_neuron_pairs))])) 104 | mutual_neurons = {layer: [n for (l, n) in mutual_neurons if l == layer] for layer in range(model.cfg.n_layers)} 105 | return mutual_neurons 106 | 107 | 108 | def get_topk_neurons_per_layer(model, model_name, k=200, operator_idx=0, pos=-1): 109 | """ 110 | Get a dictionary of {layer: [neurons]} where the neurons are the top-k neurons (sorted by indirect effect) in the given layer. 111 | """ 112 | neurons_scores = get_neuron_importance_scores(model, model_name, operator_idx=operator_idx, pos=pos) 113 | neurons_scores = {layer: neurons.topk(k).indices.tolist() for layer, neurons in neurons_scores.items()} 114 | return neurons_scores 115 | 116 | 117 | def calculate_faithfulness(model, model_name, mean_cache): 118 | """ 119 | Calculate the faithfulness of the arithmetic circuit, in few settings: 120 | 1. With all top neurons in each layer in the given model. (Should be high faithfulness, this is used as a sanity check and not presented in the figures). 121 | 2. With all neurons among the top neurons that implement heuristics. 122 | 3. With all neurons among the top neurons that implement heuristics and also implement the same heuristic in the final checkpoint. 123 | The results are saved and later analyzed into the figures shown in the relevant section in the paper. 
124 | """ 125 | with open(fr'./data/{model_name}/large_prompts_and_answers_max_op=300.pkl', 'rb') as f: 126 | large_prompts_and_answers = pickle.load(f) 127 | 128 | for operator_idx in range(len(OPERATORS)): 129 | prompts_and_answers = random.sample(large_prompts_and_answers[operator_idx], k=50) 130 | 131 | # Sanity check faithfulness 132 | topk_neurons = get_topk_neurons_per_layer(model, model_name, k=200) 133 | full_circuit = build_circuit(model, topk_neurons, operator_idx) 134 | sanity_faithfulness = circuit_faithfulness_with_mean_ablation(model, full_circuit, prompts_and_answers, mean_cache, metric='nl') 135 | 136 | # Get the heuristic neurons in each layer and Calculate faithfulness based on them 137 | mlp_neurons = get_heuristic_neurons(model, model_name, operator_idx) 138 | full_circuit = build_circuit(model, mlp_neurons, operator_idx) 139 | baseline_faithfulness = circuit_faithfulness_with_mean_ablation(model, full_circuit, prompts_and_answers, mean_cache, metric='nl') 140 | 141 | # Get the neurons n who belong to a heuristic h where (h,n) also appears in the last checkpoint and calculate faithfulness based on them 142 | mutual_neurons_with_final_step = get_intersection_neurons(model, model_name, operator_idx) 143 | full_circuit = build_circuit(model, mutual_neurons_with_final_step, operator_idx) 144 | mutual_faithfulness = circuit_faithfulness_with_mean_ablation(model, full_circuit, prompts_and_answers, mean_cache, metric='nl') 145 | 146 | # Save the results 147 | logging.info(f"Operator {operator_idx}: Sanity faithfulness: {sanity_faithfulness}, Baseline faithfulness: {baseline_faithfulness}, Mutual faithfulness: {mutual_faithfulness}") 148 | results_file_path = f'./data/pythia-6.9b-step143000/mutual_faithfulness_results.pt' 149 | if os.path.exists(results_file_path): 150 | results = torch.load(results_file_path) 151 | else: 152 | results = {} 153 | results[(model_name, operator_idx)] = (sanity_faithfulness, baseline_faithfulness, mutual_faithfulness) 154 | torch.save(results, './data/pythia-6.9b-step143000/mutual_faithfulness_results.pt') 155 | 156 | 157 | def parse_args(): 158 | parser = argparse.ArgumentParser() 159 | parser.add_argument('--model_name', type=str, help='Name of the model to be loaded') 160 | parser.add_argument('--model_path', type=str, help='Path to the model to be loaded') 161 | args = parser.parse_args() 162 | return args 163 | 164 | 165 | # Generate the heuristics intersection with a chosen training step (without categorization, weighted mean across all heuristics) 166 | def main(): 167 | """ 168 | This script organizes the experiments done for the first part of section 5 (Analysis across Pythia-6.9B training checkpoints). 
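    Illustrative invocation (model name and path are hypothetical examples):
        python script_eval_pythia_faithfulness_only_mutual_neurons.py --model_name pythia-6.9b-step100000 --model_path /path/to/checkpoint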
169 | """ 170 | torch.set_grad_enabled(False) 171 | 172 | args = parse_args() 173 | model_name = args.model_name 174 | model_path = args.model_path 175 | 176 | logging.info("Loading model") 177 | model = load_model(model_name, model_path, device, False) 178 | 179 | logging.info("Calculating mean cache") 180 | mean_cache = load_mean_cache(model, model_name) 181 | mean_cache = {c: a.repeat(50, 1, 1) for c, a in mean_cache.items()} 182 | 183 | print("Calculating faithfulness") 184 | set_deterministic(42) 185 | calculate_faithfulness(model, args.model_name, mean_cache) 186 | 187 | 188 | if __name__ == '__main__': 189 | logging.basicConfig(level = logging.INFO) 190 | main() 191 | 192 | 193 | 194 | -------------------------------------------------------------------------------- /prompt_generation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import transformer_lens as lens 4 | from tqdm import tqdm 5 | from typing import List, Tuple, Dict, Optional 6 | from general_utils import predict_answer 7 | from model_analysis_consts import LLAMA3_8B_CONSTS 8 | 9 | 10 | OPERATORS = ['+', '-', '*', '/'] 11 | OPERATOR_NAMES = ['addition', 'subtraction', 'multiplication', 'division'] 12 | POSITIONS = [1, 2, 3, 4] # All actual token positions (1 = op1, 2 = operator, 3 = op2, 4 = equals sign) 13 | 14 | 15 | def generate_prompts(model: lens.HookedTransformer, 16 | operand_ranges: Dict[str, Tuple[int, int]], 17 | validate_numerals: bool = True, 18 | correct_prompts: bool = True, 19 | num_prompts_per_operator: Optional[int] = 50, 20 | single_token_number_range: Tuple[int, int] = (0, LLAMA3_8B_CONSTS.max_single_token), 21 | additional_shots_per_operator: Dict[str, int] = None 22 | ): 23 | """ 24 | Generate arithmetic prompts of the form "x op y=" (without spaces), where x and y are integers and op is an operator. 25 | The prompts are filtered according to the arguments. 26 | 27 | Args: 28 | model (nn.Module): The model used to validate answer correctness. 29 | operand_ranges (dict): A dictionary of the form {operator: (operand_min, operand_max)}. 30 | validate_numerals (bool): If True, only prompts with valid numerals are returned (no "weird token" answers). 31 | correct_prompts (bool): If True, only prompts completed correctly by the model are returned. Otherwise, only prompts completed 32 | incorrectly by the model are returned. 33 | num_prompts_per_operator (int): The number of prompts to generate per operator. If None, all possible prompts are generated. 34 | 35 | Returns: 36 | List[List[Tuple[str, str]]]: The main list is indexed by the different operators, and each list contains num_prompts_per_operator tuples, 37 | where each tuple is of the form (prompt, answer).
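    Example (illustrative sketch, not from the original code; assumes a loaded HookedTransformer named `model`):
        >>> operand_ranges = {'+': (0, 100), '-': (0, 100)}
        >>> prompts_and_answers = generate_prompts(model, operand_ranges, num_prompts_per_operator=50)
        >>> prompts_and_answers[0][0]  # e.g. ('23+45=', '68') -- a (prompt, answer) pair for the '+' operator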
38 | """ 39 | prompts_and_answers = [] 40 | 41 | # assert num_prompts_per_operator > 0, 'num_prompts_per_operator must be a positive integer' 42 | for operator in operand_ranges.keys(): 43 | assert single_token_number_range[0] <= operand_ranges[operator][0] <= operand_ranges[operator][1] <= single_token_number_range[1], \ 44 | f'Invalid operand range for operator {operator}' 45 | 46 | for operator in operand_ranges.keys(): 47 | # Generate all possible prompts for the given operator within the operand limits 48 | operand_min, operand_max = operand_ranges[operator] 49 | all_operator_prompts = generate_all_prompts_for_operator(operator, operand_min, operand_max, single_token_number_range) 50 | 51 | if additional_shots_per_operator is not None: 52 | all_operator_prompts = [f"{additional_shots_per_operator[operator]}{p}" for p in all_operator_prompts] 53 | 54 | # Filter the prompts 55 | filtered_operator_prompts = filter_generated_prompts(model, all_operator_prompts, validate_numerals, correct_prompts) 56 | assert len(filtered_operator_prompts) > 0, f'No valid prompts for operator {operator} with given parameters' 57 | 58 | # Take k prompts, while maximizing the number of unique answers (so that during patching experiments, the clean and corrupt answers will be different) 59 | if num_prompts_per_operator is not None: 60 | filtered_operator_prompts = _maximize_unique_answers(filtered_operator_prompts, k=num_prompts_per_operator) 61 | 62 | prompts_and_answers.append(filtered_operator_prompts) 63 | return prompts_and_answers 64 | 65 | 66 | def generate_all_prompts_for_operator(operator: str, 67 | operand_min: int, 68 | operand_max: int, 69 | single_token_number_range: Tuple[int, int]) -> List[str]: 70 | """ 71 | Generate ALL possible "relevant" prompts for a given operator. 72 | Prompts are valid if there is no illegal operation (e.g. division by zero, negative result, etc). 73 | 74 | Args: 75 | operator (str): The operator to generate prompts for. 76 | operand_min (int): The minimum value for the operands. 77 | operand_max (int): The maximum value for the operands. 78 | Return: 79 | List[str]: A list of all possible relevant prompts for a given operator. 80 | """ 81 | all_operator_prompts = [] 82 | for operand1 in range(operand_min, operand_max): 83 | operand_2_range = _get_operand_range(operator, operand1, operand_min, operand_max, single_token_number_range[1]) 84 | for operand2 in operand_2_range: 85 | prompt = '{x}{op}{y}='.format(x=operand1, op=operator, y=operand2) 86 | answer = eval(prompt[:-1]) 87 | if single_token_number_range[0] <= answer <= single_token_number_range[1]: 88 | all_operator_prompts.append(prompt) 89 | return all_operator_prompts 90 | 91 | 92 | def separate_prompts_and_answers(prompts_and_answers: List[Tuple[str, str]]): 93 | """ 94 | Separates a list of (prompt, answer) tuples to two lists - one of prompts and one of answers. 95 | """ 96 | return [pa[0] for pa in prompts_and_answers], [pa[1] for pa in prompts_and_answers] 97 | 98 | 99 | def filter_generated_prompts(model: lens.HookedTransformer, 100 | prompts: List[str], 101 | validate_numerals: bool = True, 102 | correct_prompts: bool = True): 103 | """ 104 | Filters generated prompts according to the given arguments. 105 | 106 | Args: 107 | model (nn.Module): The model used to validate answer correctness. 108 | prompts (List[str]): The prompts to filter. 109 | validate_numerals (bool): If True, only prompts with valid numerals are returned (no "weird token" answers). 
110 | correct_prompts (bool): If True, only prompts completed correctly by the model are returned. Otherwise, only prompts completed 111 | incorrectly by the model are returned. 112 | 113 | Returns: 114 | List[Tuple[str, str]]: A list of (prompt, answer) tuples. 115 | """ 116 | all_filtered_prompts_and_answers = [] 117 | dataloader = torch.utils.data.DataLoader(prompts, batch_size=32, shuffle=False) 118 | for batch in tqdm(dataloader): 119 | filtered_prompts = batch 120 | answers = predict_answer(model, batch) 121 | 122 | # Use only prompts with numerical answers by the model 123 | if validate_numerals: 124 | numerical_indices = [i for i in range(len(answers)) if _is_number(answers[i])] 125 | filtered_prompts = [filtered_prompts[i] for i in numerical_indices] 126 | answers = [answers[i] for i in numerical_indices] 127 | 128 | # Use only prompts with correct answers (or incorrect, if `correct_prompts` is False) 129 | is_correct_answers = [_is_answer_correct(prompt, answer) for prompt, answer in zip(filtered_prompts, answers)] 130 | filtered_prompts = [filtered_prompts[i] for i in range(len(filtered_prompts)) if (correct_prompts and is_correct_answers[i]) or (not correct_prompts and not is_correct_answers[i])] 131 | answers = [answers[i] for i in range(len(answers)) if (correct_prompts and is_correct_answers[i]) or (not correct_prompts and not is_correct_answers[i])] 132 | all_filtered_prompts_and_answers.extend(list(zip(filtered_prompts, answers))) 133 | 134 | return all_filtered_prompts_and_answers 135 | 136 | 137 | def _maximize_unique_answers(rigorous_prompts_and_answers, k=50): 138 | """ 139 | Get a subset of prompts and answers with as many unique answers as possible. 140 | Args: 141 | rigorous_prompts_and_answers (list of tuples): A list of (prompt, answer) pairs. 142 | k (int, optional): The desired number of prompt-answer pairs in the output list. Defaults to 50. 143 | If there are less than k unique answers in the input list, there will be answer repetitions. 144 | Returns: 145 | list of tuples: A list of (prompt, answer) pairs with as many unique answers as possible, up to length `k`. 
146 | """ 147 | if len(rigorous_prompts_and_answers) < k: 148 | new_prompts_and_answers = rigorous_prompts_and_answers + random.choices(rigorous_prompts_and_answers, k=k-len(rigorous_prompts_and_answers)) 149 | random.shuffle(new_prompts_and_answers) 150 | return new_prompts_and_answers 151 | else: 152 | unique_answers = set() 153 | new_prompts_and_answers = [] 154 | random.shuffle(rigorous_prompts_and_answers) 155 | for prompt, answer in rigorous_prompts_and_answers: 156 | if answer not in unique_answers: 157 | unique_answers.add(answer) 158 | new_prompts_and_answers.append((prompt, answer)) 159 | if len(new_prompts_and_answers) < k: 160 | new_prompts_and_answers += random.choices(rigorous_prompts_and_answers, k=k-len(new_prompts_and_answers)) 161 | 162 | return new_prompts_and_answers[:k] 163 | 164 | 165 | def _get_operand_range(operator, previous_operand, operand_min, operand_max, max_single_token_value): 166 | if operator == '+': 167 | return range(operand_min, min(max_single_token_value - previous_operand, operand_max)) 168 | elif operator == '-': 169 | return range(operand_min, previous_operand + 1) 170 | elif operator == '*': 171 | if previous_operand == 0: 172 | return range(operand_min, operand_max) 173 | else: 174 | return range(operand_min, min((max_single_token_value // previous_operand) + 1, operand_max)) 175 | elif operator == '/': 176 | return range(max(1, operand_min), operand_max) 177 | else: 178 | raise ValueError(f'Operator {operator} is not supported') 179 | 180 | 181 | 182 | def _is_answer_correct(prompt: str, answer: str, convert_to_int: bool = True): 183 | """ 184 | Checks if an answer is a correct completion to a prompt. 185 | Whitespaces are ignored. 186 | 187 | Args: 188 | prompt (str): The prompt (for example '5+4=') 189 | answer (str): The answer (for example '9') 190 | convert_to_int (bool): If True, the ground truth answer is converted to an integer before comparison to the tested answer. 
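    Example (illustrative):
        >>> _is_answer_correct('5+4=', '9')
        True
        >>> _is_answer_correct('7-3=', ' 5')
        False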
191 | """ 192 | # Handle few-shot case 193 | few_shot_sep = ';' if ';' in prompt else (',' if ',' in prompt else None) 194 | if few_shot_sep is not None: 195 | prompt = prompt[prompt.rfind(few_shot_sep) + 1:] 196 | 197 | real_answer = eval(prompt.replace('=', '')) 198 | if convert_to_int: 199 | real_answer = int(real_answer) 200 | try: 201 | return real_answer == _to_number(answer) 202 | except ValueError: 203 | return False 204 | 205 | 206 | def _to_number(s: str): 207 | try: 208 | return int(s) 209 | except ValueError: 210 | return float(s) 211 | 212 | 213 | def _is_number(s: str, is_int=False): 214 | try: 215 | if is_int: 216 | int(s) 217 | else: 218 | float(s) 219 | return True 220 | except ValueError: 221 | return False 222 | 223 | 224 | def is_writing_of_number(s: str): 225 | word_to_number = { 226 | 'zero': 0, 'one': 1, 'two': 2, 'three': 3, 'four': 4, 227 | 'five': 5, 'six': 6, 'seven': 7, 'eight': 8, 'nine': 9, 228 | 'ten': 10, 'eleven': 11, 'twelve': 12, 'thirteen': 13, 'fourteen': 14, 229 | 'fifteen': 15, 'sixteen': 16, 'seventeen': 17, 'eighteen': 18, 'nineteen': 19, 230 | 'twenty': 20, 'thirty': 30, 'forty': 40, 'fifty': 50, 'sixty': 60, 231 | 'seventy': 70, 'eighty': 80, 'ninety': 90, 'hundred': 100, 'thousand': 1000, 232 | 'million': 1000000 233 | } 234 | 235 | words = s.split() 236 | for word in words: 237 | if word not in word_to_number: 238 | return False 239 | return True -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 |
5 | 6 | 7 |215 | Do large language models (LLMs) solve reasoning tasks by learning robust generalizable algorithms, or do they memorize training data? 216 | To investigate this question, we use arithmetic reasoning as a representative task. 217 | Using causal analysis, we identify a subset of the model (a circuit) that explains most of the model's behavior for basic arithmetic logic and examine its functionality. 218 | By zooming in on the level of individual circuit neurons, we discover a sparse set of important neurons that implement simple heuristics. Each heuristic identifies a numerical input pattern and outputs corresponding answers. 219 | We hypothesize that the combination of these heuristic neurons is the mechanism used to produce correct arithmetic answers. 220 | To test this, we categorize each neuron into several heuristic types — such as neurons that activate when an operand falls within a certain range — and find that the unordered combination of these heuristic types is the mechanism that explains most of the model's accuracy on arithmetic prompts. 221 | Finally, we demonstrate that this mechanism appears as the main source of arithmetic accuracy early in training. 222 | Overall, our experimental results across several LLMs show that LLMs perform arithmetic using neither robust algorithms nor memorization; rather, they rely on a bag of heuristics. 223 |
351 | Yaniv Nikankin, Anja Reusch, Aaron Mueller, Yonatan Belinkov, “Arithmetic Without Algorithms: Language Models Solve Math With a Bag of Heuristics”.
357 | @article{nikankin2024arithmetic,
358 | title={Arithmetic Without Algorithms: Language Models Solve Math With a Bag of Heuristics},
359 | author={Nikankin, Yaniv and Reusch, Anja and Mueller, Aaron and Belinkov, Yonatan},
360 | journal={arXiv preprint arXiv:2410.21272},
361 | year={2024}
362 | }