├── .gitignore ├── .gitmodules ├── ACDCPPExperiment.py ├── README.md ├── docstring_task ├── .ipynb_checkpoints │ ├── Results Analysis-checkpoint.ipynb │ ├── abs_value_pruned_attrs_docstring-checkpoint.json │ └── acdcpp_docstring-checkpoint.ipynb ├── Results Analysis.ipynb ├── abs_value_num_passes_docstring.json ├── abs_value_pruned_attrs_docstring.json ├── abs_value_pruned_heads_docstring.json └── acdcpp_docstring.ipynb ├── drawings └── pruning.excalidraw ├── greaterthan_task ├── Results Analysis.ipynb ├── acdcpp_greaterthan.ipynb ├── minimal_acdc_node_roc.py └── results │ ├── greaterthan_absval_num_passes.json │ ├── greaterthan_absval_pruned_attrs.json │ ├── greaterthan_absval_pruned_heads.json │ ├── greaterthan_first_pass_num_passes.json │ ├── greaterthan_first_pass_pruned_attrs.json │ └── greaterthan_first_pass_pruned_heads.json ├── ioi_task ├── Results Analysis.ipynb ├── abs_value_num_passes.json ├── abs_value_pruned_attrs.json ├── abs_value_pruned_heads.json ├── acdcpp_on_edges_demo.ipynb ├── ims.zip ├── ioi_dataset.py ├── noabs_value_num_passes.json ├── noabs_value_pruned_attrs.json └── noabs_value_pruned_heads.json ├── requirements.txt ├── threshold_investigation ├── docstring_thresh.ipynb ├── greaterthan_thresh.ipynb └── ioi_thresh.ipynb ├── utils ├── graphics_utils.py └── prune_utils.py └── vast-startup.sh /.gitignore: -------------------------------------------------------------------------------- 1 | env/* 2 | .ipynb_checkpoints/* 3 | __pycache__/* 4 | utils/.ipynb_checkpoints/* 5 | utils/__pycache__/* 6 | ioi_task/.ipynb_checkpoints/* 7 | ioi_task/__pycache__/* 8 | ims/*.png 9 | ioi_task/ims/*.png 10 | docstring_task/ims/*.png 11 | .python-version 12 | ims/* 13 | .python-version 14 | greaterthan_task/ims/* 15 | .vscode -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "Automatic-Circuit-Discovery"] 2 | path = 
import os
import sys
sys.path.append('Automatic-Circuit-Discovery/')

from acdc.TLACDCExperiment import TLACDCExperiment
from utils.prune_utils import acdc_nodes, get_nodes
from utils.graphics_utils import show

from typing import Callable, List, Literal

from transformer_lens import HookedTransformer
import torch as t
from torch import Tensor
import warnings
from tqdm import tqdm


class ACDCPPExperiment:
    """Run ACDC++ (attribution-patching-based pruning) followed by classic ACDC
    over a sweep of thresholds, saving circuit graphs and per-threshold stats.

    For each threshold in ``thresholds`` the experiment:
      1. builds a fresh ``TLACDCExperiment``,
      2. prunes nodes/edges with attribution patching (``run_acdcpp``),
      3. runs the standard ACDC loop to convergence (``run_acdc``),
      4. saves graph images under ``ims/<run_name>/``.
    """

    def __init__(
        self,
        model: HookedTransformer,
        clean_data: Tensor,
        corr_data: Tensor,
        acdc_metric: Callable[[Tensor], Tensor],
        acdcpp_metric: Callable[[Tensor], Tensor],
        thresholds: List[float],
        run_name: str,
        save_graphs_after: float,
        verbose: bool = False,
        attr_absolute_val: bool = True,
        zero_ablation: bool = False,
        return_pruned_heads: bool = True,
        return_pruned_attr: bool = True,
        return_num_passes: bool = True,
        pass_tokens_to_metric: bool = False,
        pruning_mode: Literal["edge", "node"] = "node",
        no_pruned_nodes_attr: int = 10,
        **acdc_kwargs
    ):
        """Store experiment configuration.

        Args:
            model: The ``HookedTransformer`` to discover a circuit in.
            clean_data: Clean-run input tokens.
            corr_data: Corrupted (patch) input tokens.
            acdc_metric: Metric used by the classic ACDC loop.
            acdcpp_metric: Metric used for attribution-patching pruning.
            thresholds: Thresholds to sweep; one full ACDC++/ACDC run each.
            run_name: Subdirectory name under ``ims/`` for saved graphs.
            save_graphs_after: Only ACDC++ graphs for thresholds >= this value
                are saved (pre-ACDC graphs tend to be huge).
            verbose: Print progress information.
            attr_absolute_val: Prune on |attribution| rather than signed value.
            zero_ablation: Use zero ablation instead of corrupted-activation patching.
            return_pruned_heads / return_pruned_attr / return_num_passes:
                Currently unused; kept for interface compatibility.
            pass_tokens_to_metric: Currently unused; kept for interface compatibility.
            pruning_mode: Prune whole nodes or individual edges.
            no_pruned_nodes_attr: Number of ACDC++ pruning passes per threshold.
            **acdc_kwargs: Forwarded verbatim to ``TLACDCExperiment``.
        """
        self.model = model

        self.clean_data = clean_data
        self.corr_data = corr_data

        self.run_name = run_name
        self.verbose = verbose

        self.acdc_metric = acdc_metric
        self.acdcpp_metric = acdcpp_metric
        self.pass_tokens_to_metric = pass_tokens_to_metric

        self.thresholds = thresholds
        self.attr_absolute_val = attr_absolute_val
        self.zero_ablation = zero_ablation

        # For now, not using these (lol)
        self.return_pruned_heads = return_pruned_heads
        self.return_pruned_attr = return_pruned_attr
        self.return_num_passes = return_num_passes
        self.save_graphs_after = save_graphs_after
        self.pruning_mode: Literal["edge", "node"] = pruning_mode
        self.no_pruned_nodes_attr = no_pruned_nodes_attr

        if self.pruning_mode == "edge" and self.no_pruned_nodes_attr != 1:
            warnings.warn("I've been getting errors with no_pruned_nodes_attr > 1 with edge pruning, you may wish to switch to no_pruned_nodes_attr=1")

        self.acdc_args = acdc_kwargs
        if verbose:
            print('Set up model hooks')

    def setup_exp(self, threshold: float) -> TLACDCExperiment:
        """Build a fresh ``TLACDCExperiment`` for one threshold, with sender
        and receiver hooks installed and any stale hooks cleared."""
        exp = TLACDCExperiment(
            model=self.model,
            threshold=threshold,
            run_name=self.run_name,
            ds=self.clean_data,
            ref_ds=self.corr_data,
            metric=self.acdc_metric,
            zero_ablation=self.zero_ablation,
            # save_graphs_after=self.save_graphs_after,
            online_cache_cpu=False,
            corrupted_cache_cpu=False,
            verbose=self.verbose,
            **self.acdc_args
        )
        # Clear any hooks left over from a previous threshold's run.
        exp.model.reset_hooks()
        exp.setup_model_hooks(
            add_sender_hooks=True,
            add_receiver_hooks=True,
            doing_acdc_runs=False
        )

        return exp

    def run_acdcpp(self, exp: TLACDCExperiment, threshold: float):
        """Run ``no_pruned_nodes_attr`` attribution-patching pruning passes.

        Returns:
            Tuple of (remaining nodes in ``exp.corr``, attribution values from
            the *last* pruning pass — earlier passes' attributions are discarded).
        """
        if self.verbose:
            print('Running ACDC++')

        for _ in range(self.no_pruned_nodes_attr):
            pruned_nodes_attr = acdc_nodes(
                model=exp.model,
                clean_input=self.clean_data,
                corrupted_input=self.corr_data,
                metric=self.acdcpp_metric,
                threshold=threshold,
                exp=exp,
                verbose=self.verbose,
                attr_absolute_val=self.attr_absolute_val,
                mode=self.pruning_mode,
            )
            t.cuda.empty_cache()
        return (get_nodes(exp.corr), pruned_nodes_attr)

    def run_acdc(self, exp: TLACDCExperiment):
        """Run the classic ACDC loop until no candidate node remains.

        Returns:
            Tuple of (remaining nodes in ``exp.corr``, number of ACDC passes).
        """
        if self.verbose:
            print('Running ACDC')

        while exp.current_node:
            exp.step(testing=False)

        return (get_nodes(exp.corr), exp.num_passes)

    def run(self, save_after_acdcpp: bool = True, save_after_acdc: bool = True):
        """Execute the full threshold sweep.

        Args:
            save_after_acdcpp: Save the post-ACDC++ (pre-ACDC) graph image.
                Even when True, only thresholds >= ``save_graphs_after`` are
                saved, as these graphs tend to be HUGE.
            save_after_acdc: Save the final post-ACDC graph image.

        Returns:
            Tuple of three dicts keyed by threshold:
            ``pruned_heads`` ([acdcpp_nodes, acdc_nodes]), ``num_passes``
            (ACDC pass counts), and ``pruned_attrs`` (last-pass attributions).

        Note:
            Previously ``save_after_acdcpp``/``save_after_acdc`` were accepted
            but ignored; they are now honored. The defaults (both True)
            reproduce the old behavior exactly.
        """
        os.makedirs(f'ims/{self.run_name}', exist_ok=True)

        pruned_heads = {}
        num_passes = {}
        pruned_attrs = {}

        for threshold in tqdm(self.thresholds):
            exp = self.setup_exp(threshold)
            acdcpp_heads, attrs = self.run_acdcpp(exp, threshold)
            # Only applying threshold to this one as these graphs tend to be HUGE
            if save_after_acdcpp and threshold >= self.save_graphs_after:
                print('Saving ACDC++ Graph')
                show(exp.corr, fname=f'ims/{self.run_name}/thresh{threshold}_before_acdc.png')

            acdc_heads, passes = self.run_acdc(exp)

            if save_after_acdc:
                print('Saving ACDC Graph')
                show(exp.corr, fname=f'ims/{self.run_name}/thresh{threshold}_after_acdc.png')

            pruned_heads[threshold] = [acdcpp_heads, acdc_heads]
            num_passes[threshold] = passes
            pruned_attrs[threshold] = attrs

            # Drop the experiment (and its caches) before the next threshold;
            # a single empty_cache suffices (the original called it twice).
            del exp
            t.cuda.empty_cache()
        return pruned_heads, num_passes, pruned_attrs
8 | 9 | Please cite this work as: 10 | ``` 11 | @inproceedings{ 12 | syed2023attribution, 13 | title={Attribution Patching Outperforms Automated Circuit Discovery}, 14 | author={Aaquib Syed and Can Rager and Arthur Conmy}, 15 | booktitle={NeurIPS Workshop on Attributing Model Behavior at Scale}, 16 | year={2023}, 17 | url={https://openreview.net/forum?id=tiLbFR4bJW} 18 | } 19 | ``` 20 | -------------------------------------------------------------------------------- /docstring_task/.ipynb_checkpoints/abs_value_pruned_attrs_docstring-checkpoint.json: -------------------------------------------------------------------------------- 1 | {"0.005": {}, "0.01": {}, "0.015": {}, "0.02": {"L0H4": 0.01623706705868244, "L1H1": 0.01764504984021187, "L2H4": 0.019617609679698944}, "0.025": {"L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L2H4": 0.019617609679698944}, "0.030000000000000002": {"L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H3": 0.02843538671731949, "L2H4": 0.019617609679698944, "L3H2": 0.029797352850437164, "L3H4": 0.02627941593527794}, "0.034999999999999996": {"L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H3": 0.02843538671731949, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H4": 0.02627941593527794}, "0.04": {"L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H3": 0.02843538671731949, "L1H6": 0.03670835122466087, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.045": {"L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L2H4": 0.019617609679698944, "L2H7": 
0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.049999999999999996": {"L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.055": {"L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.06": {"L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 0.059216126799583435, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.065": {"L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 0.059216126799583435, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.07": {"L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 0.059216126799583435, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, 
"L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.07500000000000001": {"L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 0.059216126799583435, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.08": {"L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 0.059216126799583435, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.085": {"L0H3": 0.08162035793066025, "L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 0.059216126799583435, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.09000000000000001": {"L0H3": 0.08162035793066025, "L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 0.059216126799583435, "L2H1": 0.08805526793003082, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.095": {"L0H3": 0.08162035793066025, "L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 0.059216126799583435, "L2H1": 
0.08805526793003082, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.1": {"L0H3": 0.08162035793066025, "L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 0.059216126799583435, "L2H1": 0.08805526793003082, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.10500000000000001": {"L0H3": 0.08162035793066025, "L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 0.059216126799583435, "L2H1": 0.08805526793003082, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.11": {"L0H2": 0.10929969698190689, "L0H3": 0.08162035793066025, "L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H2": 0.10675916075706482, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 0.059216126799583435, "L2H1": 0.08805526793003082, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.115": {"L0H2": 0.10929969698190689, "L0H3": 0.08162035793066025, "L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H2": 0.10675916075706482, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 0.059216126799583435, "L2H1": 0.08805526793003082, "L2H4": 0.019617609679698944, "L2H7": 
0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.12000000000000001": {"L0H2": 0.10929969698190689, "L0H3": 0.08162035793066025, "L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H2": 0.10675916075706482, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 0.059216126799583435, "L2H1": 0.08805526793003082, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.125": {"L0H2": 0.10929969698190689, "L0H3": 0.08162035793066025, "L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H2": 0.10675916075706482, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 0.059216126799583435, "L2H1": 0.08805526793003082, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.13": {"L0H2": 0.10929969698190689, "L0H3": 0.08162035793066025, "L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H2": 0.10675916075706482, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 0.059216126799583435, "L2H1": 0.08805526793003082, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H3": 0.12964804470539093, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.135": {"L0H2": 0.10929969698190689, "L0H3": 0.08162035793066025, "L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H2": 0.10675916075706482, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 
0.059216126799583435, "L2H1": 0.08805526793003082, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H3": 0.12964804470539093, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.14": {"L0H2": 0.10929969698190689, "L0H3": 0.08162035793066025, "L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H2": 0.10675916075706482, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 0.059216126799583435, "L2H1": 0.08805526793003082, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H3": 0.12964804470539093, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.14500000000000002": {"L0H2": 0.10929969698190689, "L0H3": 0.08162035793066025, "L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L1H1": 0.01764504984021187, "L1H2": 0.10675916075706482, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 0.059216126799583435, "L2H1": 0.08805526793003082, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H3": 0.12964804470539093, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.15": {"L0H2": 0.10929969698190689, "L0H3": 0.08162035793066025, "L0H4": 0.01623706705868244, "L0H6": 0.02447594702243805, "L0H7": 0.14744149148464203, "L1H1": 0.01764504984021187, "L1H2": 0.10675916075706482, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 0.059216126799583435, "L2H1": 0.08805526793003082, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H3": 0.12964804470539093, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}, "0.155": {"L0H2": 0.10929969698190689, "L0H3": 0.08162035793066025, "L0H4": 
0.01623706705868244, "L0H6": 0.02447594702243805, "L0H7": 0.14744149148464203, "L1H1": 0.01764504984021187, "L1H2": 0.10675916075706482, "L1H3": 0.02843538671731949, "L1H5": 0.04031563550233841, "L1H6": 0.03670835122466087, "L1H7": 0.059216126799583435, "L2H1": 0.08805526793003082, "L2H4": 0.019617609679698944, "L2H7": 0.030115559697151184, "L3H1": 0.03058820776641369, "L3H2": 0.029797352850437164, "L3H3": 0.12964804470539093, "L3H4": 0.02627941593527794, "L3H5": 0.03959112614393234}} -------------------------------------------------------------------------------- /docstring_task/.ipynb_checkpoints/acdcpp_docstring-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "2861525f-e65f-4dc8-9774-84a5a292a2e5", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "cuda\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "import os\n", 19 | "import sys\n", 20 | "sys.path.append('..')\n", 21 | "sys.path.append('../Automatic-Circuit-Discovery/')\n", 22 | "\n", 23 | "import torch as t\n", 24 | "from torch import Tensor\n", 25 | "\n", 26 | "from acdc.docstring.utils import get_all_docstring_things\n", 27 | "device = t.device(\"cuda\" if t.cuda.is_available() else \"CPU\")\n", 28 | "print(device)" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "id": "5306b51d-26b3-47b8-9c08-f77b6a9350ff", 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "name": "stdout", 39 | "output_type": "stream", 40 | "text": [ 41 | "Loaded pretrained model attn-only-4l into HookedTransformer\n", 42 | "Moving model to device: cuda\n" 43 | ] 44 | } 45 | ], 46 | "source": [ 47 | "all_docstring_items, batched_prompts = get_all_docstring_things(num_examples=40, seq_len=5, device=device, metric_name='docstring_metric', correct_incorrect_wandb=False)\n", 48 | "\n", 49 | "tl_model = 
all_docstring_items.tl_model\n", 50 | "validation_metric = all_docstring_items.validation_metric\n", 51 | "validation_data = all_docstring_items.validation_data\n", 52 | "validation_labels = all_docstring_items.validation_labels\n", 53 | "validation_patch_data = all_docstring_items.validation_patch_data\n", 54 | "test_metrics = all_docstring_items.test_metrics\n", 55 | "test_data = all_docstring_items.test_data\n", 56 | "test_labels = all_docstring_items.test_labels\n", 57 | "test_patch_data = all_docstring_items.test_patch_data" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 3, 63 | "id": "15465560-54a2-4433-8c8a-e13a152849e5", 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "def abs_docstring_metric(logits):\n", 68 | " return -abs(test_metrics['docstring_metric'](logits))" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 4, 74 | "id": "73b607c0-cedd-417f-91bb-bd3a430f3068", 75 | "metadata": { 76 | "scrolled": true 77 | }, 78 | "outputs": [ 79 | { 80 | "name": "stderr", 81 | "output_type": "stream", 82 | "text": [ 83 | " 0%| | 0/1 [00:00", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", ""]], "0.01": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", ""]], "0.015": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", ""]], "0.02": [["", "", "", "", "", "", "", "", "", "", "", "", "", 
"", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", ""]], "0.025": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", ""]], "0.030000000000000002": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", ""]], "0.034999999999999996": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", ""]], "0.04": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", ""]], "0.045": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", ""]], "0.049999999999999996": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", ""]], "0.055": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", ""]], "0.06": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", ""], ["", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", ""]], "0.065": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", 
""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", ""]], "0.07": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", ""], ["", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", ""]], "0.07500000000000001": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", ""], ["", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", ""]], "0.08": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", ""], ["", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", ""]], "0.085": [["", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", ""], ["", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", ""]], "0.09000000000000001": [["", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", ""], ["", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.095": [["", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", ""], ["", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.1": [["", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", ""], ["", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.10500000000000001": [["", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", ""], ["", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.11": [["", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "embed", "", "", "", "", "", "", "", "", "", ""]], "0.115": [["", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "embed", "", "", "", "", "", "", "", "", "", ""]], "0.12000000000000001": [["", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "embed", "", "", "", "", "", "", "", "", ""]], "0.125": 
[["", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "embed", "", "", "", "", "", "", "", "", ""]], "0.13": [["", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "embed", "", "", "", "", "", "", "", "", ""]], "0.135": [["", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "embed", "", "", "", "", "", "", "", "", ""]], "0.14": [["", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "embed", "", "", "", "", "", "", "", "", ""]], "0.14500000000000002": [["", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "embed", "", "", "", "", "", "", "", "", ""]], "0.15": [["", "", "", "", "", "", "embed", "", "", "", "", "", "", "", ""], ["", "", "", "", "embed", "", "", "", "", "", "", "", "", ""]], "0.155": [["", "", "", "", "", "", "embed", "", "", "", "", "", "", "", ""], ["", "", "", "", "embed", "", "", "", "", "", "", "", "", ""]]} -------------------------------------------------------------------------------- /docstring_task/acdcpp_docstring.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "2861525f-e65f-4dc8-9774-84a5a292a2e5", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stderr", 11 | "output_type": "stream", 12 | "text": [ 13 | "/tmp/ipykernel_13315/1528838084.py:9: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", 14 | " ipython.magic('load_ext autoreload')\n", 15 | "/tmp/ipykernel_13315/1528838084.py:10: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", 16 | " ipython.magic('autoreload 2')\n" 17 | ] 18 | }, 19 | { 20 | "name": "stdout", 21 | "output_type": "stream", 22 | "text": [ 
23 | "cuda\n" 24 | ] 25 | } 26 | ], 27 | "source": [ 28 | "import os\n", 29 | "import sys\n", 30 | "sys.path.append('..')\n", 31 | "sys.path.append('../Automatic-Circuit-Discovery/')\n", 32 | "sys.path.append('../tracr/')\n", 33 | "import IPython\n", 34 | "ipython = get_ipython()\n", 35 | "if ipython is not None:\n", 36 | " ipython.magic('load_ext autoreload')\n", 37 | " ipython.magic('autoreload 2')\n", 38 | "import torch as t\n", 39 | "from torch import Tensor\n", 40 | "\n", 41 | "from acdc.docstring.utils import get_all_docstring_things\n", 42 | "device = t.device(\"cuda\" if t.cuda.is_available() else \"CPU\")\n", 43 | "print(device)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "id": "5306b51d-26b3-47b8-9c08-f77b6a9350ff", 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "Loaded pretrained model attn-only-4l into HookedTransformer\n", 57 | "Moving model to device: cuda\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "all_docstring_items = get_all_docstring_things(num_examples=40, seq_len=5, device=device, metric_name='docstring_metric', correct_incorrect_wandb=False)\n", 63 | "\n", 64 | "tl_model = all_docstring_items.tl_model\n", 65 | "validation_metric = all_docstring_items.validation_metric\n", 66 | "validation_data = all_docstring_items.validation_data\n", 67 | "validation_labels = all_docstring_items.validation_labels\n", 68 | "validation_patch_data = all_docstring_items.validation_patch_data\n", 69 | "test_metrics = all_docstring_items.test_metrics\n", 70 | "test_data = all_docstring_items.test_data\n", 71 | "test_labels = all_docstring_items.test_labels\n", 72 | "test_patch_data = all_docstring_items.test_patch_data" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 3, 78 | "id": "15465560-54a2-4433-8c8a-e13a152849e5", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "def abs_docstring_metric(logits):\n", 83 | " 
return -abs(test_metrics['docstring_metric'](logits))" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "id": "73b607c0-cedd-417f-91bb-bd3a430f3068", 90 | "metadata": { 91 | "scrolled": true 92 | }, 93 | "outputs": [], 94 | "source": [ 95 | "from ACDCPPExperiment import ACDCPPExperiment\n", 96 | "import numpy as np\n", 97 | "THRESHOLDS = [0.09] # np.arange(0.04, 0.16, 0.005)\n", 98 | "# I'm just using one threshold so I can move fast!\n", 99 | "\n", 100 | "tl_model.reset_hooks()\n", 101 | "RUN_NAME = 'abs_edges'\n", 102 | "acdcpp_exp = ACDCPPExperiment(\n", 103 | " tl_model,\n", 104 | " test_data,\n", 105 | " test_patch_data,\n", 106 | " test_metrics['docstring_metric'],\n", 107 | " abs_docstring_metric,\n", 108 | " THRESHOLDS,\n", 109 | " run_name=RUN_NAME,\n", 110 | " verbose=False,\n", 111 | " attr_absolute_val=True,\n", 112 | " save_graphs_after=0,\n", 113 | " pruning_mode='edge',\n", 114 | " no_pruned_nodes_attr=1,\n", 115 | ")\n", 116 | "pruned_heads, num_passes, pruned_attrs = acdcpp_exp.run()" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 19, 122 | "id": "7d0868fc", 123 | "metadata": {}, 124 | "outputs": [ 125 | { 126 | "name": "stderr", 127 | "output_type": "stream", 128 | "text": [ 129 | "WARNING:root:cache_all is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache\n", 130 | "WARNING:root:cache_all is deprecated and will eventually be removed, use add_caching_hooks or run_with_cache\n" 131 | ] 132 | }, 133 | { 134 | "name": "stdout", 135 | "output_type": "stream", 136 | "text": [ 137 | "ln_final.hook_normalized\n", 138 | "ln_final.hook_scale\n", 139 | "blocks.3.hook_resid_post\n", 140 | "blocks.3.hook_attn_out\n", 141 | "blocks.3.attn.hook_result\n", 142 | "blocks.3.attn.hook_z\n", 143 | "blocks.3.attn.hook_pattern\n", 144 | "blocks.3.attn.hook_attn_scores\n", 145 | "blocks.3.attn.hook_v\n", 146 | "blocks.3.attn.hook_k\n", 147 | "blocks.3.attn.hook_q\n", 148 | 
"blocks.3.ln1.hook_normalized\n", 149 | "blocks.3.ln1.hook_scale\n", 150 | "blocks.3.hook_v_input\n", 151 | "blocks.3.hook_k_input\n", 152 | "blocks.3.hook_q_input\n", 153 | "blocks.3.hook_resid_pre\n", 154 | "blocks.2.hook_resid_post\n", 155 | "blocks.2.hook_attn_out\n", 156 | "blocks.2.attn.hook_result\n", 157 | "blocks.2.attn.hook_z\n", 158 | "blocks.2.attn.hook_pattern\n", 159 | "blocks.2.attn.hook_attn_scores\n", 160 | "blocks.2.attn.hook_v\n", 161 | "blocks.2.attn.hook_k\n", 162 | "blocks.2.attn.hook_q\n", 163 | "blocks.2.ln1.hook_normalized\n", 164 | "blocks.2.ln1.hook_scale\n", 165 | "blocks.2.hook_v_input\n", 166 | "blocks.2.hook_k_input\n", 167 | "blocks.2.hook_q_input\n", 168 | "blocks.2.hook_resid_pre\n", 169 | "blocks.1.hook_resid_post\n", 170 | "blocks.1.hook_attn_out\n", 171 | "blocks.1.attn.hook_result\n", 172 | "blocks.1.attn.hook_z\n", 173 | "blocks.1.attn.hook_pattern\n", 174 | "blocks.1.attn.hook_attn_scores\n", 175 | "blocks.1.attn.hook_v\n", 176 | "blocks.1.attn.hook_k\n", 177 | "blocks.1.attn.hook_q\n", 178 | "blocks.1.ln1.hook_normalized\n", 179 | "blocks.1.ln1.hook_scale\n", 180 | "blocks.1.hook_v_input\n", 181 | "blocks.1.hook_k_input\n", 182 | "blocks.1.hook_q_input\n", 183 | "blocks.1.hook_resid_pre\n", 184 | "blocks.0.hook_resid_post\n", 185 | "blocks.0.hook_attn_out\n", 186 | "blocks.0.attn.hook_result\n", 187 | "blocks.0.attn.hook_z\n", 188 | "blocks.0.attn.hook_pattern\n", 189 | "blocks.0.attn.hook_attn_scores\n", 190 | "blocks.0.attn.hook_v\n", 191 | "blocks.0.attn.hook_k\n", 192 | "blocks.0.attn.hook_q\n", 193 | "blocks.0.ln1.hook_normalized\n", 194 | "blocks.0.ln1.hook_scale\n", 195 | "blocks.0.hook_v_input\n", 196 | "blocks.0.hook_k_input\n", 197 | "blocks.0.hook_q_input\n", 198 | "blocks.0.hook_resid_pre\n", 199 | "hook_pos_embed\n", 200 | "hook_embed\n", 201 | "self.current_node=TLACDCInterpNode(blocks.3.hook_resid_post, [:])\n" 202 | ] 203 | }, 204 | { 205 | "name": "stdout", 206 | "output_type": "stream", 207 | "text": [ 208 
| "('blocks.3.attn.hook_result', [:, :, 7], 'blocks.3.attn.hook_q', [:, :, 7]) Edge(EdgeType.PLACEHOLDER, True)\n", 209 | "('blocks.3.attn.hook_result', [:, :, 7], 'blocks.3.attn.hook_k', [:, :, 7]) Edge(EdgeType.PLACEHOLDER, True)\n", 210 | "('blocks.3.attn.hook_result', [:, :, 7], 'blocks.3.attn.hook_v', [:, :, 7]) Edge(EdgeType.PLACEHOLDER, True)\n", 211 | "('blocks.3.attn.hook_result', [:, :, 6], 'blocks.3.attn.hook_q', [:, :, 6]) Edge(EdgeType.PLACEHOLDER, True)\n", 212 | "('blocks.3.attn.hook_result', [:, :, 6], 'blocks.3.attn.hook_k', [:, :, 6]) Edge(EdgeType.PLACEHOLDER, True)\n", 213 | "('blocks.3.attn.hook_result', [:, :, 6], 'blocks.3.attn.hook_v', [:, :, 6]) Edge(EdgeType.PLACEHOLDER, True)\n", 214 | "('blocks.3.attn.hook_result', [:, :, 5], 'blocks.3.attn.hook_q', [:, :, 5]) Edge(EdgeType.PLACEHOLDER, True)\n", 215 | "('blocks.3.attn.hook_result', [:, :, 5], 'blocks.3.attn.hook_k', [:, :, 5]) Edge(EdgeType.PLACEHOLDER, True)\n", 216 | "('blocks.3.attn.hook_result', [:, :, 5], 'blocks.3.attn.hook_v', [:, :, 5]) Edge(EdgeType.PLACEHOLDER, True)\n", 217 | "('blocks.3.attn.hook_result', [:, :, 4], 'blocks.3.attn.hook_q', [:, :, 4]) Edge(EdgeType.PLACEHOLDER, True)\n", 218 | "('blocks.3.attn.hook_result', [:, :, 4], 'blocks.3.attn.hook_k', [:, :, 4]) Edge(EdgeType.PLACEHOLDER, True)\n", 219 | "('blocks.3.attn.hook_result', [:, :, 4], 'blocks.3.attn.hook_v', [:, :, 4]) Edge(EdgeType.PLACEHOLDER, True)\n", 220 | "('blocks.3.attn.hook_result', [:, :, 3], 'blocks.3.attn.hook_q', [:, :, 3]) Edge(EdgeType.PLACEHOLDER, True)\n", 221 | "('blocks.3.attn.hook_result', [:, :, 3], 'blocks.3.attn.hook_k', [:, :, 3]) Edge(EdgeType.PLACEHOLDER, True)\n", 222 | "('blocks.3.attn.hook_result', [:, :, 3], 'blocks.3.attn.hook_v', [:, :, 3]) Edge(EdgeType.PLACEHOLDER, True)\n", 223 | "('blocks.3.attn.hook_result', [:, :, 2], 'blocks.3.attn.hook_q', [:, :, 2]) Edge(EdgeType.PLACEHOLDER, True)\n", 224 | "('blocks.3.attn.hook_result', [:, :, 2], 'blocks.3.attn.hook_k', [:, :, 2]) 
Edge(EdgeType.PLACEHOLDER, True)\n", 225 | "('blocks.3.attn.hook_result', [:, :, 2], 'blocks.3.attn.hook_v', [:, :, 2]) Edge(EdgeType.PLACEHOLDER, True)\n", 226 | "('blocks.3.attn.hook_result', [:, :, 1], 'blocks.3.attn.hook_q', [:, :, 1]) Edge(EdgeType.PLACEHOLDER, True)\n", 227 | "('blocks.3.attn.hook_result', [:, :, 1], 'blocks.3.attn.hook_k', [:, :, 1]) Edge(EdgeType.PLACEHOLDER, True)\n", 228 | "('blocks.3.attn.hook_result', [:, :, 1], 'blocks.3.attn.hook_v', [:, :, 1]) Edge(EdgeType.PLACEHOLDER, True)\n", 229 | "('blocks.3.attn.hook_result', [:, :, 0], 'blocks.3.attn.hook_q', [:, :, 0]) Edge(EdgeType.PLACEHOLDER, True)\n", 230 | "('blocks.3.attn.hook_result', [:, :, 0], 'blocks.3.attn.hook_k', [:, :, 0]) Edge(EdgeType.PLACEHOLDER, True)\n", 231 | "('blocks.3.attn.hook_result', [:, :, 0], 'blocks.3.attn.hook_v', [:, :, 0]) Edge(EdgeType.PLACEHOLDER, True)\n", 232 | "('blocks.2.attn.hook_result', [:, :, 7], 'blocks.2.attn.hook_q', [:, :, 7]) Edge(EdgeType.PLACEHOLDER, True)\n", 233 | "('blocks.2.attn.hook_result', [:, :, 7], 'blocks.2.attn.hook_k', [:, :, 7]) Edge(EdgeType.PLACEHOLDER, True)\n", 234 | "('blocks.2.attn.hook_result', [:, :, 7], 'blocks.2.attn.hook_v', [:, :, 7]) Edge(EdgeType.PLACEHOLDER, True)\n", 235 | "('blocks.2.attn.hook_result', [:, :, 6], 'blocks.2.attn.hook_q', [:, :, 6]) Edge(EdgeType.PLACEHOLDER, True)\n", 236 | "('blocks.2.attn.hook_result', [:, :, 6], 'blocks.2.attn.hook_k', [:, :, 6]) Edge(EdgeType.PLACEHOLDER, True)\n", 237 | "('blocks.2.attn.hook_result', [:, :, 6], 'blocks.2.attn.hook_v', [:, :, 6]) Edge(EdgeType.PLACEHOLDER, True)\n", 238 | "('blocks.2.attn.hook_result', [:, :, 5], 'blocks.2.attn.hook_q', [:, :, 5]) Edge(EdgeType.PLACEHOLDER, True)\n", 239 | "('blocks.2.attn.hook_result', [:, :, 5], 'blocks.2.attn.hook_k', [:, :, 5]) Edge(EdgeType.PLACEHOLDER, True)\n", 240 | "('blocks.2.attn.hook_result', [:, :, 5], 'blocks.2.attn.hook_v', [:, :, 5]) Edge(EdgeType.PLACEHOLDER, True)\n", 241 | "('blocks.2.attn.hook_result', [:, 
:, 4], 'blocks.2.attn.hook_q', [:, :, 4]) Edge(EdgeType.PLACEHOLDER, True)\n", 242 | "('blocks.2.attn.hook_result', [:, :, 4], 'blocks.2.attn.hook_k', [:, :, 4]) Edge(EdgeType.PLACEHOLDER, True)\n", 243 | "('blocks.2.attn.hook_result', [:, :, 4], 'blocks.2.attn.hook_v', [:, :, 4]) Edge(EdgeType.PLACEHOLDER, True)\n", 244 | "('blocks.2.attn.hook_result', [:, :, 3], 'blocks.2.attn.hook_q', [:, :, 3]) Edge(EdgeType.PLACEHOLDER, True)\n", 245 | "('blocks.2.attn.hook_result', [:, :, 3], 'blocks.2.attn.hook_k', [:, :, 3]) Edge(EdgeType.PLACEHOLDER, True)\n", 246 | "('blocks.2.attn.hook_result', [:, :, 3], 'blocks.2.attn.hook_v', [:, :, 3]) Edge(EdgeType.PLACEHOLDER, True)\n", 247 | "('blocks.2.attn.hook_result', [:, :, 2], 'blocks.2.attn.hook_q', [:, :, 2]) Edge(EdgeType.PLACEHOLDER, True)\n", 248 | "('blocks.2.attn.hook_result', [:, :, 2], 'blocks.2.attn.hook_k', [:, :, 2]) Edge(EdgeType.PLACEHOLDER, True)\n", 249 | "('blocks.2.attn.hook_result', [:, :, 2], 'blocks.2.attn.hook_v', [:, :, 2]) Edge(EdgeType.PLACEHOLDER, True)\n", 250 | "('blocks.2.attn.hook_result', [:, :, 1], 'blocks.2.attn.hook_q', [:, :, 1]) Edge(EdgeType.PLACEHOLDER, True)\n", 251 | "('blocks.2.attn.hook_result', [:, :, 1], 'blocks.2.attn.hook_k', [:, :, 1]) Edge(EdgeType.PLACEHOLDER, True)\n", 252 | "('blocks.2.attn.hook_result', [:, :, 1], 'blocks.2.attn.hook_v', [:, :, 1]) Edge(EdgeType.PLACEHOLDER, True)\n", 253 | "('blocks.2.attn.hook_result', [:, :, 0], 'blocks.2.attn.hook_q', [:, :, 0]) Edge(EdgeType.PLACEHOLDER, True)\n", 254 | "('blocks.2.attn.hook_result', [:, :, 0], 'blocks.2.attn.hook_k', [:, :, 0]) Edge(EdgeType.PLACEHOLDER, True)\n", 255 | "('blocks.2.attn.hook_result', [:, :, 0], 'blocks.2.attn.hook_v', [:, :, 0]) Edge(EdgeType.PLACEHOLDER, True)\n", 256 | "('blocks.1.attn.hook_result', [:, :, 7], 'blocks.1.attn.hook_q', [:, :, 7]) Edge(EdgeType.PLACEHOLDER, True)\n", 257 | "('blocks.1.attn.hook_result', [:, :, 7], 'blocks.1.attn.hook_k', [:, :, 7]) Edge(EdgeType.PLACEHOLDER, True)\n", 
258 | "('blocks.1.attn.hook_result', [:, :, 7], 'blocks.1.attn.hook_v', [:, :, 7]) Edge(EdgeType.PLACEHOLDER, True)\n", 259 | "('blocks.1.attn.hook_result', [:, :, 6], 'blocks.1.attn.hook_q', [:, :, 6]) Edge(EdgeType.PLACEHOLDER, True)\n", 260 | "('blocks.1.attn.hook_result', [:, :, 6], 'blocks.1.attn.hook_k', [:, :, 6]) Edge(EdgeType.PLACEHOLDER, True)\n", 261 | "('blocks.1.attn.hook_result', [:, :, 6], 'blocks.1.attn.hook_v', [:, :, 6]) Edge(EdgeType.PLACEHOLDER, True)\n", 262 | "('blocks.1.attn.hook_result', [:, :, 5], 'blocks.1.attn.hook_q', [:, :, 5]) Edge(EdgeType.PLACEHOLDER, True)\n", 263 | "('blocks.1.attn.hook_result', [:, :, 5], 'blocks.1.attn.hook_k', [:, :, 5]) Edge(EdgeType.PLACEHOLDER, True)\n", 264 | "('blocks.1.attn.hook_result', [:, :, 5], 'blocks.1.attn.hook_v', [:, :, 5]) Edge(EdgeType.PLACEHOLDER, True)\n", 265 | "('blocks.1.attn.hook_result', [:, :, 4], 'blocks.1.attn.hook_q', [:, :, 4]) Edge(EdgeType.PLACEHOLDER, True)\n", 266 | "('blocks.1.attn.hook_result', [:, :, 4], 'blocks.1.attn.hook_k', [:, :, 4]) Edge(EdgeType.PLACEHOLDER, True)\n", 267 | "('blocks.1.attn.hook_result', [:, :, 4], 'blocks.1.attn.hook_v', [:, :, 4]) Edge(EdgeType.PLACEHOLDER, True)\n", 268 | "('blocks.1.attn.hook_result', [:, :, 3], 'blocks.1.attn.hook_q', [:, :, 3]) Edge(EdgeType.PLACEHOLDER, True)\n", 269 | "('blocks.1.attn.hook_result', [:, :, 3], 'blocks.1.attn.hook_k', [:, :, 3]) Edge(EdgeType.PLACEHOLDER, True)\n", 270 | "('blocks.1.attn.hook_result', [:, :, 3], 'blocks.1.attn.hook_v', [:, :, 3]) Edge(EdgeType.PLACEHOLDER, True)\n", 271 | "('blocks.1.attn.hook_result', [:, :, 2], 'blocks.1.attn.hook_q', [:, :, 2]) Edge(EdgeType.PLACEHOLDER, True)\n", 272 | "('blocks.1.attn.hook_result', [:, :, 2], 'blocks.1.attn.hook_k', [:, :, 2]) Edge(EdgeType.PLACEHOLDER, True)\n", 273 | "('blocks.1.attn.hook_result', [:, :, 2], 'blocks.1.attn.hook_v', [:, :, 2]) Edge(EdgeType.PLACEHOLDER, True)\n", 274 | "('blocks.1.attn.hook_result', [:, :, 1], 'blocks.1.attn.hook_q', [:, :, 
1]) Edge(EdgeType.PLACEHOLDER, True)\n", 275 | "('blocks.1.attn.hook_result', [:, :, 1], 'blocks.1.attn.hook_k', [:, :, 1]) Edge(EdgeType.PLACEHOLDER, True)\n", 276 | "('blocks.1.attn.hook_result', [:, :, 1], 'blocks.1.attn.hook_v', [:, :, 1]) Edge(EdgeType.PLACEHOLDER, True)\n", 277 | "('blocks.1.attn.hook_result', [:, :, 0], 'blocks.1.attn.hook_q', [:, :, 0]) Edge(EdgeType.PLACEHOLDER, True)\n", 278 | "('blocks.1.attn.hook_result', [:, :, 0], 'blocks.1.attn.hook_k', [:, :, 0]) Edge(EdgeType.PLACEHOLDER, True)\n", 279 | "('blocks.1.attn.hook_result', [:, :, 0], 'blocks.1.attn.hook_v', [:, :, 0]) Edge(EdgeType.PLACEHOLDER, True)\n", 280 | "('blocks.0.attn.hook_result', [:, :, 7], 'blocks.0.attn.hook_q', [:, :, 7]) Edge(EdgeType.PLACEHOLDER, True)\n", 281 | "('blocks.0.attn.hook_result', [:, :, 7], 'blocks.0.attn.hook_k', [:, :, 7]) Edge(EdgeType.PLACEHOLDER, True)\n", 282 | "('blocks.0.attn.hook_result', [:, :, 7], 'blocks.0.attn.hook_v', [:, :, 7]) Edge(EdgeType.PLACEHOLDER, True)\n", 283 | "('blocks.0.attn.hook_result', [:, :, 6], 'blocks.0.attn.hook_q', [:, :, 6]) Edge(EdgeType.PLACEHOLDER, True)\n", 284 | "('blocks.0.attn.hook_result', [:, :, 6], 'blocks.0.attn.hook_k', [:, :, 6]) Edge(EdgeType.PLACEHOLDER, True)\n", 285 | "('blocks.0.attn.hook_result', [:, :, 6], 'blocks.0.attn.hook_v', [:, :, 6]) Edge(EdgeType.PLACEHOLDER, True)\n", 286 | "('blocks.0.attn.hook_result', [:, :, 5], 'blocks.0.attn.hook_q', [:, :, 5]) Edge(EdgeType.PLACEHOLDER, True)\n", 287 | "('blocks.0.attn.hook_result', [:, :, 5], 'blocks.0.attn.hook_k', [:, :, 5]) Edge(EdgeType.PLACEHOLDER, True)\n", 288 | "('blocks.0.attn.hook_result', [:, :, 5], 'blocks.0.attn.hook_v', [:, :, 5]) Edge(EdgeType.PLACEHOLDER, True)\n", 289 | "('blocks.0.attn.hook_result', [:, :, 4], 'blocks.0.attn.hook_q', [:, :, 4]) Edge(EdgeType.PLACEHOLDER, True)\n", 290 | "('blocks.0.attn.hook_result', [:, :, 4], 'blocks.0.attn.hook_k', [:, :, 4]) Edge(EdgeType.PLACEHOLDER, True)\n", 291 | "('blocks.0.attn.hook_result', 
[:, :, 4], 'blocks.0.attn.hook_v', [:, :, 4]) Edge(EdgeType.PLACEHOLDER, True)\n", 292 | "('blocks.0.attn.hook_result', [:, :, 3], 'blocks.0.attn.hook_q', [:, :, 3]) Edge(EdgeType.PLACEHOLDER, True)\n", 293 | "('blocks.0.attn.hook_result', [:, :, 3], 'blocks.0.attn.hook_k', [:, :, 3]) Edge(EdgeType.PLACEHOLDER, True)\n", 294 | "('blocks.0.attn.hook_result', [:, :, 3], 'blocks.0.attn.hook_v', [:, :, 3]) Edge(EdgeType.PLACEHOLDER, True)\n", 295 | "('blocks.0.attn.hook_result', [:, :, 2], 'blocks.0.attn.hook_q', [:, :, 2]) Edge(EdgeType.PLACEHOLDER, True)\n", 296 | "('blocks.0.attn.hook_result', [:, :, 2], 'blocks.0.attn.hook_k', [:, :, 2]) Edge(EdgeType.PLACEHOLDER, True)\n", 297 | "('blocks.0.attn.hook_result', [:, :, 2], 'blocks.0.attn.hook_v', [:, :, 2]) Edge(EdgeType.PLACEHOLDER, True)\n", 298 | "('blocks.0.attn.hook_result', [:, :, 1], 'blocks.0.attn.hook_q', [:, :, 1]) Edge(EdgeType.PLACEHOLDER, True)\n", 299 | "('blocks.0.attn.hook_result', [:, :, 1], 'blocks.0.attn.hook_k', [:, :, 1]) Edge(EdgeType.PLACEHOLDER, True)\n", 300 | "('blocks.0.attn.hook_result', [:, :, 1], 'blocks.0.attn.hook_v', [:, :, 1]) Edge(EdgeType.PLACEHOLDER, True)\n", 301 | "('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_q', [:, :, 0]) Edge(EdgeType.PLACEHOLDER, True)\n", 302 | "('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_k', [:, :, 0]) Edge(EdgeType.PLACEHOLDER, True)\n", 303 | "('blocks.0.attn.hook_result', [:, :, 0], 'blocks.0.attn.hook_v', [:, :, 0]) Edge(EdgeType.PLACEHOLDER, True)\n" 304 | ] 305 | } 306 | ], 307 | "source": [ 308 | "from acdc.TLACDCEdge import TorchIndex\n", 309 | "e=acdcpp_exp.setup_exp(0.0)\n", 310 | "\n", 311 | "# Are there any bad placeholder edges?\n", 312 | "for edge_tuple, e in e.corr.all_edges().items():\n", 313 | " if \"placeholder\" in str(e).lower():\n", 314 | " print(edge_tuple, e)\n", 315 | "# No" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": null, 321 | "id": 
"2e7fa71a-eec4-4b89-afcf-6eeab1483b1b", 322 | "metadata": {}, 323 | "outputs": [], 324 | "source": [ 325 | "import json\n", 326 | "\n", 327 | "def convert_to_torch_index(index_list):\n", 328 | " return ''.join(['None' if i == ':' else i for i in index_list])\n", 329 | "\n", 330 | "for thresh in pruned_heads.keys():\n", 331 | " pruned_heads[thresh][0] = list(pruned_heads[thresh][0])\n", 332 | " pruned_heads[thresh][1] = list(pruned_heads[thresh][1])\n", 333 | " \n", 334 | "cleaned_attrs = {}\n", 335 | "for thresh in pruned_attrs.keys():\n", 336 | " cleaned_attrs[thresh] = []\n", 337 | " for ((e1, i1), (e2, i2)), attr in pruned_attrs[thresh].items():\n", 338 | " cleaned_attrs[thresh].append([e1, convert_to_torch_index(str(i1)), e2, convert_to_torch_index(str(i2)), attr])\n", 339 | " \n", 340 | "with open(f'{RUN_NAME}_pruned_heads_docstring.json', 'w') as f:\n", 341 | " json.dump(pruned_heads, f)\n", 342 | "with open(f'{RUN_NAME}_num_passes_docstring.json', 'w') as f:\n", 343 | " json.dump(num_passes, f)\n", 344 | "with open(f'{RUN_NAME}_pruned_attrs_docstring.json', 'w') as f:\n", 345 | " json.dump(cleaned_attrs, f)" 346 | ] 347 | } 348 | ], 349 | "metadata": { 350 | "kernelspec": { 351 | "display_name": "Python 3 (ipykernel)", 352 | "language": "python", 353 | "name": "python3" 354 | }, 355 | "language_info": { 356 | "codemirror_mode": { 357 | "name": "ipython", 358 | "version": 3 359 | }, 360 | "file_extension": ".py", 361 | "mimetype": "text/x-python", 362 | "name": "python", 363 | "nbconvert_exporter": "python", 364 | "pygments_lexer": "ipython3", 365 | "version": "3.10.0" 366 | } 367 | }, 368 | "nbformat": 4, 369 | "nbformat_minor": 5 370 | } 371 | -------------------------------------------------------------------------------- /greaterthan_task/acdcpp_greaterthan.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "6606875e", 7 | 
"metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import sys\n", 11 | "sys.path.append('../Automatic-Circuit-Discovery/')\n", 12 | "sys.path.append('..')\n", 13 | "\n", 14 | "from acdc.greaterthan.utils import get_all_greaterthan_things\n", 15 | "from ACDCPPExperiment import ACDCPPExperiment\n", 16 | "from transformer_lens import HookedTransformer\n", 17 | "\n", 18 | "import numpy as np\n", 19 | "import torch as t\n", 20 | "import tqdm.notebook as tqdm\n", 21 | "import json\n", 22 | "\n", 23 | "device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')\n", 24 | "print(f'Device: {device}')" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "07a16eab", 30 | "metadata": {}, 31 | "source": [ 32 | "# Model Setup" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "20df2bef", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "model = HookedTransformer.from_pretrained(\n", 43 | " 'gpt2-small',\n", 44 | " center_writing_weights=False,\n", 45 | " center_unembed=False,\n", 46 | " fold_ln=False,\n", 47 | " device=device,\n", 48 | ")\n", 49 | "model.set_use_hook_mlp_in(True)\n", 50 | "model.set_use_split_qkv_input(True)\n", 51 | "model.set_use_attn_result(True)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "id": "292dfbf6", 57 | "metadata": {}, 58 | "source": [ 59 | "# Dataset Setup" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "id": "601a7d92", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "# Make clean dataset and reference dataset\n", 70 | "N = 25\n", 71 | "\n", 72 | "things = get_all_greaterthan_things(\n", 73 | " num_examples=N, metric_name=\"greaterthan\", device=device\n", 74 | ")\n", 75 | "greaterthan_metric = things.validation_metric\n", 76 | "toks_int_values = things.validation_data # clean data x_i\n", 77 | "toks_int_values_other = things.validation_patch_data # corrupted data x_i'\n", 78 | "\n", 79 | 
"print(\"\\nClean dataset samples\")\n", 80 | "for i in range(5):\n", 81 | " print(model.tokenizer.decode(toks_int_values[i]))\n", 82 | "\n", 83 | "print(\"\\nReference dataset samples\")\n", 84 | "for i in range(5):\n", 85 | " print(model.tokenizer.decode(toks_int_values_other[i]))" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "id": "cf81ab6e", 91 | "metadata": {}, 92 | "source": [ 93 | "# Run Experiment" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "id": "56b08e9e-a140-4a97-a309-3210cc8f8ff3", 100 | "metadata": { 101 | "scrolled": true 102 | }, 103 | "outputs": [], 104 | "source": [ 105 | "THRESHOLDS = np.linspace(1e-4, 0.013, 30)\n", 106 | "RUN_NAME = 'greaterthan_edge_absval'\n", 107 | "acdcpp_exp = ACDCPPExperiment(model,\n", 108 | " toks_int_values,\n", 109 | " toks_int_values_other,\n", 110 | " greaterthan_metric,\n", 111 | " greaterthan_metric,\n", 112 | " THRESHOLDS,\n", 113 | " run_name=RUN_NAME,\n", 114 | " verbose=False,\n", 115 | " attr_absolute_val=True,\n", 116 | " save_graphs_after=0,\n", 117 | " pruning_mode=\"edge\",\n", 118 | " no_pruned_nodes_attr=1\n", 119 | " )\n", 120 | "pruned_heads, num_passes, pruned_attrs = acdcpp_exp.run()" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "id": "f3c0b5e5-7732-42da-b92e-687536aca96c", 126 | "metadata": {}, 127 | "source": [ 128 | "# Save Data" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "id": "e9fdca38-9c1a-45ee-8625-93c06b569533", 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "def convert_to_torch_index(index_list):\n", 139 | " return ''.join(['None' if i == ':' else i for i in index_list])\n", 140 | "\n", 141 | "for thresh in pruned_heads.keys():\n", 142 | " pruned_heads[thresh][0] = list(pruned_heads[thresh][0])\n", 143 | " pruned_heads[thresh][1] = list(pruned_heads[thresh][1])\n", 144 | "\n", 145 | "cleaned_attrs = {}\n", 146 | "for thresh in pruned_attrs.keys():\n", 147 | 
" cleaned_attrs[thresh] = []\n", 148 | " for ((e1, i1), (e2, i2)), attr in pruned_attrs[thresh].items():\n", 149 | " cleaned_attrs[thresh].append([e1, convert_to_torch_index(str(i1)), e2, convert_to_torch_index(str(i2)), attr])\n", 150 | " \n", 151 | "with open(f'{RUN_NAME}_pruned_heads.json', 'w') as f:\n", 152 | " json.dump(pruned_heads, f)\n", 153 | "with open(f'{RUN_NAME}_num_passes.json', 'w') as f:\n", 154 | " json.dump(num_passes, f)\n", 155 | "with open(f'{RUN_NAME}_pruned_attrs.json', 'w') as f:\n", 156 | " json.dump(cleaned_attrs, f)" 157 | ] 158 | } 159 | ], 160 | "metadata": { 161 | "kernelspec": { 162 | "display_name": "Python 3 (ipykernel)", 163 | "language": "python", 164 | "name": "python3" 165 | }, 166 | "language_info": { 167 | "codemirror_mode": { 168 | "name": "ipython", 169 | "version": 3 170 | }, 171 | "file_extension": ".py", 172 | "mimetype": "text/x-python", 173 | "name": "python", 174 | "nbconvert_exporter": "python", 175 | "pygments_lexer": "ipython3", 176 | "version": "3.10.0" 177 | } 178 | }, 179 | "nbformat": 4, 180 | "nbformat_minor": 5 181 | } 182 | -------------------------------------------------------------------------------- /greaterthan_task/minimal_acdc_node_roc.py: -------------------------------------------------------------------------------- 1 | #%% 2 | 3 | from IPython import get_ipython 4 | ipython = get_ipython() 5 | if ipython is not None: 6 | ipython.magic("%load_ext autoreload") 7 | ipython.magic("%autoreload 2") 8 | import os 9 | from pathlib import Path 10 | import json 11 | import plotly.graph_objects as go 12 | import plotly.express as px 13 | import matplotlib.pyplot as plt 14 | import pandas as pd 15 | import re 16 | import numpy as np 17 | 18 | #%% 19 | TASK = "greaterthan" 20 | METRIC = "greaterthan" 21 | 22 | # Set your root directory here 23 | ROOT_DIR = Path("/Users/canrager/acdcpp") 24 | assert ROOT_DIR.exists(), f"I don't think your ROOT_DIR is correct (ROOT_DIR = {ROOT_DIR})" 25 | 26 | # %% ACDCPP 27 | 
########################################
# Load the ACDC++ pruned-heads / num-passes results for the greater-than task.
FNAME = "greaterthan_task/results/greaterthan_absval_pruned_heads.json"
FPATH = ROOT_DIR / FNAME
assert FPATH.exists(), f"I don't think your FNAME is correct (FPATH = {FPATH})"

# %%

with open(FPATH, 'r') as f:
    pruned_heads = json.load(f)
with open(ROOT_DIR / 'greaterthan_task/results/greaterthan_absval_num_passes.json', 'r') as f:
    num_passes = json.load(f)

# %%

# FIX: the pattern here was the no-op '^$' (matches only the empty string),
# yet the code below reads group(1)/group(2), which would raise on any match.
# Restored the attention-head node pattern — ACDC names heads like "<a5.1>"
# (MLP nodes "<m3>" and "embed"/"pos_embed" are deliberately skipped) — and
# hoisted the compile out of the loops, where it was invariant.
# NOTE(review): confirm the node-name format against the results JSON.
attn_head_pttn = re.compile(r'^<a(\d+)\.(\d+)>$')

# cleaned_heads[thresh]['acdcpp'] = attention heads surviving the ACDC++ pass
# (index 0 in the JSON); ['acdc'] = heads surviving the subsequent ACDC run
# (index 1). Head names are normalized to "layer.head" strings.
cleaned_heads = {}

for thresh in pruned_heads.keys():
    cleaned_heads[thresh] = {'acdcpp': set(), 'acdc': set()}

    for i in range(2):
        for head in pruned_heads[thresh][i]:
            matched = attn_head_pttn.match(head)
            if matched:
                head_str = f'{matched.group(1)}.{matched.group(2)}'
                if i == 0:
                    cleaned_heads[thresh]['acdcpp'].add(head_str)
                else:
                    cleaned_heads[thresh]['acdc'].add(head_str)

#%%
# Ground-truth greater-than circuit heads ("layer.head"), used as the positive
# class for TPR/FPR computation.
true_baseline_heads = set(["5.1", "5.5", "6.1", "6.9", "7.10", "8.8", "8.11", "9.1"])
print(len(true_baseline_heads))

# All 12x12 attention heads of GPT-2 small form the universe of nodes.
all_heads = set()
for layer in range(12):
    for head in range(12):
        all_heads.add(f'{layer}.{head}')


# %%
# Seed the table with the threshold-0 point (nothing pruned: TPR = FPR = 1).
data = {
    'Threshold': [0],
    'ACDCpp TPR': [1],
    'ACDCpp TNR': [0],
    'ACDCpp FPR': [1],
    'ACDCpp FNR': [0],
    'TPR': [1],
    'TNR': [0],
    'FPR': [1],
    'FNR': [0],
    'Num Passes': [np.inf],
}

for thresh in cleaned_heads.keys():
    data['Threshold'].append(round(float(thresh), 3))  # Correct rounding error
    # Variables prefixed with pp_ are after ACDCpp only
    pp_heads = cleaned_heads[thresh]['acdcpp']
    heads = cleaned_heads[thresh]['acdc']

    # Confusion-matrix counts against the ground-truth circuit.
    pp_tp = len(pp_heads.intersection(true_baseline_heads))
    pp_tn = len((all_heads - true_baseline_heads).intersection(all_heads - pp_heads))
    pp_fp = len(pp_heads - true_baseline_heads)
    pp_fn = len(true_baseline_heads - pp_heads)

    tp = len(heads.intersection(true_baseline_heads))
    tn = len((all_heads - true_baseline_heads).intersection(all_heads - heads))
    fp = len(heads - true_baseline_heads)
    fn = len(true_baseline_heads - heads)

    # Denominators are fixed (8 positives, 136 negatives), so no zero-division.
    pp_tpr = pp_tp / (pp_tp + pp_fn)
    pp_tnr = pp_tn / (pp_tn + pp_fp)
    pp_fpr = 1 - pp_tnr
    pp_fnr = 1 - pp_tpr

    tpr = tp / (tp + fn)
    tnr = tn / (tn + fp)
    fpr = 1 - tnr
    fnr = 1 - tpr

    data['ACDCpp TPR'].append(pp_tpr)
    data['ACDCpp TNR'].append(pp_tnr)
    data['ACDCpp FPR'].append(pp_fpr)
    data['ACDCpp FNR'].append(pp_fnr)

    data['TPR'].append(tpr)
    data['TNR'].append(tnr)
    data['FPR'].append(fpr)
    data['FNR'].append(fnr)

    data['Num Passes'].append(num_passes[thresh])

df = pd.DataFrame(data)
# Add thresh inf to end of df (everything pruned: TPR = FPR = 0).
# Column order: Threshold, pp TPR/TNR/FPR/FNR, TPR/TNR/FPR/FNR, Num Passes.
row = [np.inf, 0, 1, 0, 1, 0, 1, 0, 1, 0]
df.loc[len(df)] = row

# %% ACDC
#######################################################
# %%

# Plain-ACDC baseline numbers from the Automatic-Circuit-Discovery repo.
FNAME = f"Automatic-Circuit-Discovery/experiments/results/plots_data/acdc-{TASK}-{METRIC}-False-0.json"
FPATH = ROOT_DIR / FNAME
assert FPATH.exists(), f"I don't think your FNAME is correct (FPATH = {FPATH})"

# %%

# FIX: was json.load(open(FPATH, "r")), which leaked the file handle.
with open(FPATH, "r") as f:
    acdc_data = json.load(f)

# %%

relevant_data = acdc_data["trained"]["random_ablation"][f"{TASK}"][f"{METRIC}"]["ACDC"]

# %%

node_tpr = relevant_data["node_tpr"]
node_fpr = relevant_data["node_fpr"]

#%%
print(relevant_data.keys())
print(relevant_data["node_precision"])

# %%


# We would just plot these, but sometimes points are not on the Pareto frontier

def pareto_optimal_sublist(xs, ys):
    """Return the Pareto-optimal subset of (x, y) points, sorted by x.

    A point is dropped iff some other point strictly dominates it, i.e. has
    a larger x (TPR) AND a smaller y (FPR). Returns two parallel lists.
    """
    retx, rety = [], []
    for x, y in zip(xs, ys):
        for x1, y1 in zip(xs, ys):
            if x1 > x and y1 < y:
                break  # dominated — skip this point
        else:
            retx.append(x)
            rety.append(y)
    indices = sorted(range(len(retx)), key=lambda i: retx[i])
    return [retx[i] for i in indices], [rety[i] for i in indices]

# %%

# data['TPR']/['FPR'] come from THIS repo's ACDCpp-then-ACDC run;
# node_tpr/node_fpr come from the plain-ACDC results file above.
acdcpp_pareto_node_tpr, acdcpp_pareto_node_fpr = pareto_optimal_sublist(data['TPR'], data['FPR'])
acdc_pareto_node_tpr, acdc_pareto_node_fpr = pareto_optimal_sublist(node_tpr, node_fpr)
# %%

# Thanks GPT-4 for this code

# Create the plot
plt.figure()

# Plot the ROC curve
# FIX: the two legend labels were swapped — the curve built from the
# Automatic-Circuit-Discovery results file is plain ACDC, and the curve
# built from this repo's `data` dict is ACDCpp + ACDC.
plt.step(acdc_pareto_node_fpr, acdc_pareto_node_tpr, where='post', label="ACDC only")
plt.step(acdcpp_pareto_node_fpr, acdcpp_pareto_node_tpr, where='post', label="ACDCpp + ACDC")

# Add titles and labels
plt.title("ROC Curve of number of Nodes recovered by ACDC")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")

plt.legend(loc="lower right")

# Show the plot
plt.show()

# %%
"0.009176315789473683": 285, "0.010427631578947369": 262, "0.011678947368421053": 253, "0.012930263157894736": 238, "0.01418157894736842": 207, "0.015432894736842104": 201, "0.016684210526315787": 201, "0.01793552631578947": 201, "0.019186842105263155": 201, "0.02043815789473684": 151, "0.021689473684210526": 151, "0.02294078947368421": 144, "0.024192105263157897": 144, "0.02544342105263158": 121, "0.026694736842105264": 121, "0.027946052631578948": 121, "0.02919736842105263": 121, "0.030448684210526315": 121, "0.0317": 121} -------------------------------------------------------------------------------- /greaterthan_task/results/greaterthan_absval_pruned_heads.json: -------------------------------------------------------------------------------- 1 | {"0.007925": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", ""]], "0.009176315789473683": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", ""]], "0.010427631578947369": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", ""]], "0.011678947368421053": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", ""]], "0.012930263157894736": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", ""]], "0.01418157894736842": [["", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", 
""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", ""]], "0.015432894736842104": [["", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", ""]], "0.016684210526315787": [["", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", ""]], "0.01793552631578947": [["", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", ""]], "0.019186842105263155": [["", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", ""]], "0.02043815789473684": [["", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", ""], ["embed", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.021689473684210526": [["", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", ""], ["embed", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.02294078947368421": [["embed", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""], ["embed", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.024192105263157897": [["embed", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""], ["embed", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.02544342105263158": [["embed", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""], ["embed", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.026694736842105264": [["embed", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""], ["embed", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.027946052631578948": [["embed", "", "", "", "", "", 
"", "", "", "", "", "", "", "", "", "", ""], ["embed", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.02919736842105263": [["embed", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""], ["embed", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.030448684210526315": [["embed", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""], ["embed", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.0317": [["embed", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""], ["embed", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]]} -------------------------------------------------------------------------------- /greaterthan_task/results/greaterthan_first_pass_num_passes.json: -------------------------------------------------------------------------------- 1 | {"0.01585": 201} -------------------------------------------------------------------------------- /greaterthan_task/results/greaterthan_first_pass_pruned_attrs.json: -------------------------------------------------------------------------------- 1 | {"0.01585": {"L0H0": 6.682943785563111e-05, "L0H2": 0.004924206528812647, "L0H3": 0.007153269834816456, "L0H4": 0.0038876731414347887, "L0H5": 0.010861625894904137, "L0H6": 0.00043244220432825387, "L0H7": 0.006177806295454502, "L0H8": 0.0005415635532699525, "L0H9": 0.001762608764693141, "L0H10": 0.003871350083500147, "L0H11": 0.001359875313937664, "L1H0": 0.0009917484130710363, "L1H1": 0.0009105786448344588, "L1H2": 0.00014150900824461132, "L1H3": 0.0020992462523281574, "L1H4": 0.00017145770834758878, "L1H5": 0.007083612494170666, "L1H6": 0.00016919116023927927, "L1H7": 0.0019029075047001243, "L1H8": 4.725949838757515e-05, "L1H9": 4.230861668474972e-05, "L1H10": 0.00037578342016786337, "L1H11": 0.0014855725457891822, "L2H0": 0.0019191649043932557, "L2H1": 0.00214808969758451, "L2H2": 0.0024934974499046803, "L2H3": 0.000635548320133239, "L2H4": 0.004467578139156103, "L2H5": 0.002218085341155529, 
"L2H6": 0.0012357577215880156, "L2H7": 0.00027195035363547504, "L2H8": 0.00040346378227695823, "L2H9": 0.0014996943064033985, "L2H10": 0.001109811128117144, "L2H11": 0.0006678666104562581, "L3H0": 0.0016048499383032322, "L3H1": 4.179199459031224e-05, "L3H2": 0.000517719890922308, "L3H3": 9.97509341686964e-05, "L3H4": 0.0016818811418488622, "L3H5": 4.4079264625906944e-05, "L3H6": 0.0007969538564793766, "L3H7": 0.0002894916106015444, "L3H8": 0.001279802294448018, "L3H9": 0.001918319845572114, "L3H10": 0.00046112603740766644, "L3H11": 0.001037092413753271, "L4H0": 0.00062256318051368, "L4H1": 0.0007489825366064906, "L4H2": 0.00022702667047269642, "L4H3": 0.00249772472307086, "L4H4": 0.004218580201268196, "L4H5": 0.0035213467199355364, "L4H6": 0.003603001357987523, "L4H7": 0.0022356221452355385, "L4H8": 0.0070498730055987835, "L4H9": 0.0008734700386412442, "L4H10": 0.00341726653277874, "L4H11": 0.0024168887175619602, "L5H0": 0.00030740260262973607, "L5H1": 0.0011553500080481172, "L5H2": 0.0015157873276621103, "L5H3": 0.00036283250665292144, "L5H4": 0.0009005400934256613, "L5H6": 0.00047112395986914635, "L5H7": 1.4019955415278673e-05, "L5H8": 0.0008736909367144108, "L5H9": 0.0006366747547872365, "L5H10": 0.0018326842691749334, "L5H11": 0.0006416576798073947, "L6H0": 0.00028642534743994474, "L6H1": 0.013869956135749817, "L6H2": 0.0002665279898792505, "L6H3": 0.001548783970065415, "L6H4": 0.0009254538454115391, "L6H5": 0.0030923117883503437, "L6H6": 0.0023855315521359444, "L6H7": 0.0006218899507075548, "L6H8": 0.0002659866586327553, "L6H9": 0.012234240770339966, "L6H10": 0.0002466221048962325, "L6H11": 0.0007808533846400678, "L7H0": 0.0006520153256133199, "L7H1": 0.001032261410728097, "L7H2": 0.00048802251694723964, "L7H3": 0.0012545776553452015, "L7H4": 0.000838662323076278, "L7H5": 0.0008072698838077486, "L7H6": 0.0066537377424538136, "L7H7": 0.006869889795780182, "L7H8": 0.0019015122670680285, "L7H9": 0.0009603225043974817, "L7H11": 0.004181566182523966, "L8H0": 
0.0007331593660637736, "L8H1": 0.004181223921477795, "L8H2": 0.000742992851883173, "L8H3": 0.0003785094013437629, "L8H4": 0.000790676916949451, "L8H5": 0.0012235743924975395, "L8H6": 0.0020309530664235353, "L8H7": 0.0008130870992317796, "L8H9": 0.0024738118518143892, "L8H10": 0.006889326497912407, "L9H0": 2.074235089821741e-05, "L9H2": 0.0001314622932113707, "L9H3": 0.002187237609177828, "L9H4": 0.0005771390860900283, "L9H5": 0.0027549793012440205, "L9H6": 0.0111205680295825, "L9H7": 0.00046370632480829954, "L9H8": 0.0003021408338099718, "L9H9": 0.005357271991670132, "L9H10": 0.0001644733129069209, "L9H11": 0.0002228685189038515, "L10H0": 0.0011773486621677876, "L10H1": 0.001437217928469181, "L10H2": 0.006685856729745865, "L10H3": 0.002048744587227702, "L10H4": 0.008055365644395351, "L10H5": 7.358308357652277e-05, "L10H6": 0.001933764317072928, "L10H8": 0.000896780751645565, "L10H9": 0.00019626232096925378, "L10H10": 0.0007632666965946555, "L10H11": 2.5391784220119007e-05, "L11H0": 0.0022834488190710545, "L11H1": 0.0003687163698486984, "L11H2": 0.0013249889016151428, "L11H3": 5.9077941841678694e-05, "L11H4": 0.0003791903145611286, "L11H5": 2.9664173780474812e-05, "L11H6": 0.00020392087753862143, "L11H7": 0.00014993123477324843, "L11H8": 0.002310845535248518, "L11H9": 1.3490192941389978e-05, "L11H10": 0.0018624679651111364, "L11H11": 3.710501914611086e-05}} -------------------------------------------------------------------------------- /greaterthan_task/results/greaterthan_first_pass_pruned_heads.json: -------------------------------------------------------------------------------- 1 | {"0.01585": [["", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", ""]]} -------------------------------------------------------------------------------- /ioi_task/abs_value_num_passes.json: -------------------------------------------------------------------------------- 
1 | {"0.005": 8600, "0.01": 3956, "0.015": 2815, "0.02": 1929, "0.025": 1536, "0.030000000000000002": 1021, "0.034999999999999996": 871, "0.04": 678, "0.045": 617, "0.049999999999999996": 607, "0.055": 410, "0.06": 383, "0.065": 341, "0.07": 322, "0.07500000000000001": 302, "0.08": 282, "0.085": 282, "0.09000000000000001": 282, "0.095": 260, "0.1": 260, "0.10500000000000001": 250, "0.11": 240, "0.115": 234, "0.12000000000000001": 188, "0.125": 188, "0.13": 188, "0.135": 168, "0.14": 161, "0.14500000000000002": 161, "0.15": 151} -------------------------------------------------------------------------------- /ioi_task/abs_value_pruned_heads.json: -------------------------------------------------------------------------------- 1 | {"0.005": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.01": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", 
"", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.015": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.02": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.025": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", ""]], "0.030000000000000002": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", ""]], "0.034999999999999996": [["", "", "", "", "", 
"", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.04": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.045": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.049999999999999996": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.055": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", ""]], "0.06": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", "", ""]], "0.065": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", ""]], "0.07": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", "", ""]], "0.07500000000000001": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", "", ""]], "0.08": [["", "", "", "", 
"", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", ""]], "0.085": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", ""]], "0.09000000000000001": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", "", ""]], "0.095": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", ""]], "0.1": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", ""]], "0.10500000000000001": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", ""]], "0.11": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", ""]], "0.115": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", ""]], "0.12000000000000001": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", ""]], "0.125": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", ""]], "0.13": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", ""], ["", "", "", "", "", ""]], "0.135": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", "", ""], ["", "", "", "", ""]], "0.14": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", 
"embed", "", "", "", "", "", "", "", ""], ["", "", "", "", ""]], "0.14500000000000002": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", ""], ["", "", "", "", ""]], "0.15": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", "", "", "", ""], ["", "", "", ""]]} -------------------------------------------------------------------------------- /ioi_task/acdcpp_on_edges_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "6606875e", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from IPython import get_ipython\n", 11 | "ipython = get_ipython()\n", 12 | "if ipython is not None:\n", 13 | " ipython.magic(\"%load_ext autoreload\")\n", 14 | " ipython.magic(\"%autoreload 2\")\n", 15 | "\n", 16 | "import os\n", 17 | "import sys\n", 18 | "sys.path.append('../Automatic-Circuit-Discovery/')\n", 19 | "sys.path.append('..')\n", 20 | "import torch\n", 21 | "import re\n", 22 | "\n", 23 | "import acdc\n", 24 | "from utils.prune_utils import get_3_caches, split_layers_and_heads\n", 25 | "from acdc.TLACDCExperiment import TLACDCExperiment\n", 26 | "from acdc.acdc_utils import TorchIndex, EdgeType\n", 27 | "import numpy as np\n", 28 | "import torch as t\n", 29 | "from torch import Tensor\n", 30 | "import einops\n", 31 | "import itertools\n", 32 | "\n", 33 | "from transformer_lens import HookedTransformer, ActivationCache\n", 34 | "\n", 35 | "import tqdm.notebook as tqdm\n", 36 | "import plotly\n", 37 | "from rich import print as rprint\n", 38 | "from rich.table import Table\n", 39 | "\n", 40 | "from jaxtyping import Float, Bool\n", 41 | "from typing import Callable, Tuple, Union, Dict, Optional\n", 42 | "\n", 43 | "device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')\n", 44 | "print(f'Device: {device}')" 45 | ] 46 | }, 47 | { 48 | 
"cell_type": "markdown", 49 | "id": "07a16eab", 50 | "metadata": {}, 51 | "source": [ 52 | "# Model Setup" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "20df2bef", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "model = HookedTransformer.from_pretrained(\n", 63 | " 'gpt2-small',\n", 64 | " center_writing_weights=False,\n", 65 | " center_unembed=False,\n", 66 | " fold_ln=False,\n", 67 | " device=device,\n", 68 | ")\n", 69 | "model.set_use_hook_mlp_in(True)\n", 70 | "model.set_use_split_qkv_input(True)\n", 71 | "model.set_use_attn_result(True)" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "id": "292dfbf6", 77 | "metadata": {}, 78 | "source": [ 79 | "# Dataset Setup" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "id": "601a7d92", 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "from ioi_dataset import IOIDataset, format_prompt, make_table\n", 90 | "N = 25\n", 91 | "clean_dataset = IOIDataset(\n", 92 | " prompt_type='mixed',\n", 93 | " N=N,\n", 94 | " tokenizer=model.tokenizer,\n", 95 | " prepend_bos=False,\n", 96 | " seed=1,\n", 97 | " device=device\n", 98 | ")\n", 99 | "corr_dataset = clean_dataset.gen_flipped_prompts('ABC->XYZ, BAB->XYZ')\n", 100 | "\n", 101 | "make_table(\n", 102 | " colnames = [\"IOI prompt\", \"IOI subj\", \"IOI indirect obj\", \"ABC prompt\"],\n", 103 | " cols = [\n", 104 | " map(format_prompt, clean_dataset.sentences),\n", 105 | " model.to_string(clean_dataset.s_tokenIDs).split(),\n", 106 | " model.to_string(clean_dataset.io_tokenIDs).split(),\n", 107 | " map(format_prompt, clean_dataset.sentences),\n", 108 | " ],\n", 109 | " title = \"Sentences from IOI vs ABC distribution\",\n", 110 | ")" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "id": "6657f126", 116 | "metadata": {}, 117 | "source": [ 118 | "# Metric Setup" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | 
"id": "d1b9d4d6", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "def ave_logit_diff(\n", 129 | " logits: Float[Tensor, 'batch seq d_vocab'],\n", 130 | " ioi_dataset: IOIDataset,\n", 131 | " per_prompt: bool = False\n", 132 | "):\n", 133 | " '''\n", 134 | " Return average logit difference between correct and incorrect answers\n", 135 | " '''\n", 136 | " # Get logits for indirect objects\n", 137 | " io_logits = logits[range(logits.size(0)), ioi_dataset.word_idx['end'], ioi_dataset.io_tokenIDs]\n", 138 | " s_logits = logits[range(logits.size(0)), ioi_dataset.word_idx['end'], ioi_dataset.s_tokenIDs]\n", 139 | " # Get logits for subject\n", 140 | " logit_diff = io_logits - s_logits\n", 141 | " return logit_diff if per_prompt else logit_diff.mean()\n", 142 | "\n", 143 | "with t.no_grad():\n", 144 | " clean_logits = model(clean_dataset.toks)\n", 145 | " corrupt_logits = model(corr_dataset.toks)\n", 146 | " clean_logit_diff = ave_logit_diff(clean_logits, clean_dataset).item()\n", 147 | " corrupt_logit_diff = ave_logit_diff(corrupt_logits, corr_dataset).item()\n", 148 | "\n", 149 | "def ioi_metric(\n", 150 | " logits: Float[Tensor, \"batch seq_len d_vocab\"],\n", 151 | " corrupted_logit_diff: float = corrupt_logit_diff,\n", 152 | " clean_logit_diff: float = clean_logit_diff,\n", 153 | " ioi_dataset: IOIDataset = clean_dataset\n", 154 | " ):\n", 155 | " patched_logit_diff = ave_logit_diff(logits, ioi_dataset)\n", 156 | " return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)\n", 157 | "\n", 158 | "def negative_ioi_metric(logits: Float[Tensor, \"batch seq_len d_vocab\"]):\n", 159 | " return -ioi_metric(logits)\n", 160 | " \n", 161 | "# Get clean and corrupt logit differences\n", 162 | "with t.no_grad():\n", 163 | " clean_metric = ioi_metric(clean_logits, corrupt_logit_diff, clean_logit_diff, clean_dataset)\n", 164 | " corrupt_metric = ioi_metric(corrupt_logits, corrupt_logit_diff, clean_logit_diff, corr_dataset)\n", 
165 | "\n", 166 | "print(f'Clean direction: {clean_logit_diff}, Corrupt direction: {corrupt_logit_diff}')\n", 167 | "print(f'Clean metric: {clean_metric}, Corrupt metric: {corrupt_metric}')" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "id": "cf81ab6e", 173 | "metadata": {}, 174 | "source": [ 175 | "# Run Experiment" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "id": "56b08e9e-a140-4a97-a309-3210cc8f8ff3", 182 | "metadata": { 183 | "scrolled": true 184 | }, 185 | "outputs": [], 186 | "source": [ 187 | "# get the 2 fwd and 1 bwd caches; cache \"normalized\" and \"result\" of attn layers\n", 188 | "clean_cache, corrupted_cache, clean_grad_cache = get_3_caches(\n", 189 | " model, \n", 190 | " clean_dataset.toks,\n", 191 | " corr_dataset.toks,\n", 192 | " metric=negative_ioi_metric,\n", 193 | " mode = \"edge\",\n", 194 | ")" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "id": "50407fbf", 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "clean_head_act = split_layers_and_heads(clean_cache.stack_head_results(), model=model)\n", 205 | "corr_head_act = split_layers_and_heads(corrupted_cache.stack_head_results(), model=model)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "id": "0112ada0", 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "stacked_grad_act = torch.zeros(\n", 216 | " 3, # QKV\n", 217 | " model.cfg.n_layers,\n", 218 | " model.cfg.n_heads,\n", 219 | " clean_head_act.shape[-3], # Batch\n", 220 | " clean_head_act.shape[-2], # Seq\n", 221 | " clean_head_act.shape[-1], # D\n", 222 | ")\n", 223 | "\n", 224 | "for letter_idx, letter in enumerate(\"qkv\"):\n", 225 | " for layer_idx in range(model.cfg.n_layers):\n", 226 | " stacked_grad_act[letter_idx, layer_idx] = einops.rearrange(clean_grad_cache[f\"blocks.{layer_idx}.hook_{letter}_input\"], \"batch seq n_heads d -> n_heads batch seq 
d\")" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "id": "a4d4f25d", 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "results = {}\n", 237 | "\n", 238 | "for upstream_layer_idx in range(model.cfg.n_layers):\n", 239 | " for upstream_head_idx in range(model.cfg.n_heads):\n", 240 | " for downstream_letter_idx, downstream_letter in enumerate(\"qkv\"):\n", 241 | " for downstream_layer_idx in range(upstream_layer_idx+1, model.cfg.n_layers):\n", 242 | " for downstream_head_idx in range(model.cfg.n_heads):\n", 243 | " results[\n", 244 | " (\n", 245 | " upstream_layer_idx,\n", 246 | " upstream_head_idx,\n", 247 | " downstream_letter,\n", 248 | " downstream_layer_idx,\n", 249 | " downstream_head_idx,\n", 250 | " )\n", 251 | " ] = (stacked_grad_act[downstream_letter_idx, downstream_layer_idx, downstream_head_idx].cpu() * (clean_head_act[upstream_layer_idx, upstream_head_idx] - corr_head_act[upstream_layer_idx, upstream_head_idx]).cpu()).sum()" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "id": "140a6ed6", 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": [ 261 | "sorted_results = sorted(results.items(), key=lambda x: x[1].abs(), reverse=True)" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": null, 267 | "id": "2ab2dd33", 268 | "metadata": {}, 269 | "outputs": [], 270 | "source": [ 271 | "print(\"Top 10 most important edges:\")\n", 272 | "for i in range(10):\n", 273 | " print(\n", 274 | " f\"{sorted_results[i][0][0]}:{sorted_results[i][0][1]} -> {sorted_results[i][0][3]}:{sorted_results[i][0][4]}\",\n", 275 | " )" 276 | ] 277 | } 278 | ], 279 | "metadata": { 280 | "kernelspec": { 281 | "display_name": "Python 3 (ipykernel)", 282 | "language": "python", 283 | "name": "python3" 284 | }, 285 | "language_info": { 286 | "codemirror_mode": { 287 | "name": "ipython", 288 | "version": 3 289 | }, 290 | "file_extension": ".py", 291 | "mimetype": 
"text/x-python", 292 | "name": "python", 293 | "nbconvert_exporter": "python", 294 | "pygments_lexer": "ipython3", 295 | "version": "3.10.11" 296 | } 297 | }, 298 | "nbformat": 4, 299 | "nbformat_minor": 5 300 | } 301 | -------------------------------------------------------------------------------- /ioi_task/ims.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Aaquib111/edge-attribution-patching/7124ef815b320383f2d29b0e2c2757075ed0c417/ioi_task/ims.zip -------------------------------------------------------------------------------- /ioi_task/ioi_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List, Optional 2 | import warnings 3 | import torch as t 4 | import numpy as np 5 | from transformers import AutoTokenizer 6 | import random 7 | import copy 8 | import re 9 | from rich import print as rprint 10 | from rich.table import Table 11 | 12 | NAMES = [ 13 | "Aaron", 14 | "Adam", 15 | "Alan", 16 | "Alex", 17 | "Alice", 18 | "Amy", 19 | "Anderson", 20 | "Andre", 21 | "Andrew", 22 | "Andy", 23 | "Anna", 24 | "Anthony", 25 | "Arthur", 26 | "Austin", 27 | "Blake", 28 | "Brandon", 29 | "Brian", 30 | "Carter", 31 | "Charles", 32 | "Charlie", 33 | "Christian", 34 | "Christopher", 35 | "Clark", 36 | "Cole", 37 | "Collins", 38 | "Connor", 39 | "Crew", 40 | "Crystal", 41 | "Daniel", 42 | "David", 43 | "Dean", 44 | "Edward", 45 | "Elizabeth", 46 | "Emily", 47 | "Eric", 48 | "Eva", 49 | "Ford", 50 | "Frank", 51 | "George", 52 | "Georgia", 53 | "Graham", 54 | "Grant", 55 | "Henry", 56 | "Ian", 57 | "Jack", 58 | "Jacob", 59 | "Jake", 60 | "James", 61 | "Jamie", 62 | "Jane", 63 | "Jason", 64 | "Jay", 65 | "Jennifer", 66 | "Jeremy", 67 | "Jessica", 68 | "John", 69 | "Jonathan", 70 | "Jordan", 71 | "Joseph", 72 | "Joshua", 73 | "Justin", 74 | "Kate", 75 | "Kelly", 76 | "Kevin", 77 | "Kyle", 78 | "Laura", 79 | "Leon", 80 | "Lewis", 81 | 
"Lisa", 82 | "Louis", 83 | "Luke", 84 | "Madison", 85 | "Marco", 86 | "Marcus", 87 | "Maria", 88 | "Mark", 89 | "Martin", 90 | "Mary", 91 | "Matthew", 92 | "Max", 93 | "Michael", 94 | "Michelle", 95 | "Morgan", 96 | "Patrick", 97 | "Paul", 98 | "Peter", 99 | "Prince", 100 | "Rachel", 101 | "Richard", 102 | "River", 103 | "Robert", 104 | "Roman", 105 | "Rose", 106 | "Ruby", 107 | "Russell", 108 | "Ryan", 109 | "Sarah", 110 | "Scott", 111 | "Sean", 112 | "Simon", 113 | "Stephen", 114 | "Steven", 115 | "Sullivan", 116 | "Taylor", 117 | "Thomas", 118 | "Tyler", 119 | "Victoria", 120 | "Warren", 121 | "William", 122 | ] 123 | 124 | ABC_TEMPLATES = [ 125 | "Then, [A], [B] and [C] went to the [PLACE]. [B] and [C] gave a [OBJECT] to [A]", 126 | "Afterwards [A], [B] and [C] went to the [PLACE]. [B] and [C] gave a [OBJECT] to [A]", 127 | "When [A], [B] and [C] arrived at the [PLACE], [B] and [C] gave a [OBJECT] to [A]", 128 | "Friends [A], [B] and [C] went to the [PLACE]. [B] and [C] gave a [OBJECT] to [A]", 129 | ] 130 | 131 | BAC_TEMPLATES = [ 132 | template.replace("[B]", "[A]", 1).replace("[A]", "[B]", 1) 133 | for template in ABC_TEMPLATES 134 | ] 135 | 136 | BABA_TEMPLATES = [ 137 | "Then, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]", 138 | "Then, [B] and [A] had a lot of fun at the [PLACE]. [B] gave a [OBJECT] to [A]", 139 | "Then, [B] and [A] were working at the [PLACE]. [B] decided to give a [OBJECT] to [A]", 140 | "Then, [B] and [A] were thinking about going to the [PLACE]. 
[B] wanted to give a [OBJECT] to [A]", 141 | "Then, [B] and [A] had a long argument, and afterwards [B] said to [A]", 142 | "After [B] and [A] went to the [PLACE], [B] gave a [OBJECT] to [A]", 143 | "When [B] and [A] got a [OBJECT] at the [PLACE], [B] decided to give it to [A]", 144 | "When [B] and [A] got a [OBJECT] at the [PLACE], [B] decided to give the [OBJECT] to [A]", 145 | "While [B] and [A] were working at the [PLACE], [B] gave a [OBJECT] to [A]", 146 | "While [B] and [A] were commuting to the [PLACE], [B] gave a [OBJECT] to [A]", 147 | "After the lunch, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]", 148 | "Afterwards, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]", 149 | "Then, [B] and [A] had a long argument. Afterwards [B] said to [A]", 150 | "The [PLACE] [B] and [A] went to had a [OBJECT]. [B] gave it to [A]", 151 | "Friends [B] and [A] found a [OBJECT] at the [PLACE]. [B] gave it to [A]", 152 | ] 153 | 154 | BABA_LONG_TEMPLATES = [ 155 | "Then in the morning, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]", 156 | "Then in the morning, [B] and [A] had a lot of fun at the [PLACE]. [B] gave a [OBJECT] to [A]", 157 | "Then in the morning, [B] and [A] were working at the [PLACE]. [B] decided to give a [OBJECT] to [A]", 158 | "Then in the morning, [B] and [A] were thinking about going to the [PLACE]. 
[B] wanted to give a [OBJECT] to [A]", 159 | "Then in the morning, [B] and [A] had a long argument, and afterwards [B] said to [A]", 160 | "After taking a long break [B] and [A] went to the [PLACE], [B] gave a [OBJECT] to [A]", 161 | "When soon afterwards [B] and [A] got a [OBJECT] at the [PLACE], [B] decided to give it to [A]", 162 | "When soon afterwards [B] and [A] got a [OBJECT] at the [PLACE], [B] decided to give the [OBJECT] to [A]", 163 | "While spending time together [B] and [A] were working at the [PLACE], [B] gave a [OBJECT] to [A]", 164 | "While spending time together [B] and [A] were commuting to the [PLACE], [B] gave a [OBJECT] to [A]", 165 | "After the lunch in the afternoon, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]", 166 | "Afterwards, while spending time together [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]", 167 | "Then in the morning afterwards, [B] and [A] had a long argument. Afterwards [B] said to [A]", 168 | "The local big [PLACE] [B] and [A] went to had a [OBJECT]. [B] gave it to [A]", 169 | "Friends separated at birth [B] and [A] found a [OBJECT] at the [PLACE]. [B] gave it to [A]", 170 | ] 171 | 172 | BABA_LATE_IOS = [ 173 | "Then, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]", 174 | "Then, [B] and [A] had a lot of fun at the [PLACE]. [B] gave a [OBJECT] to [A]", 175 | "Then, [B] and [A] were working at the [PLACE]. [B] decided to give a [OBJECT] to [A]", 176 | "Then, [B] and [A] were thinking about going to the [PLACE]. [B] wanted to give a [OBJECT] to [A]", 177 | "Then, [B] and [A] had a long argument and after that [B] said to [A]", 178 | "After the lunch, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]", 179 | "Afterwards, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]", 180 | "Then, [B] and [A] had a long argument. 
Afterwards [B] said to [A]", 181 | ] 182 | 183 | BABA_EARLY_IOS = [ 184 | "Then [B] and [A] went to the [PLACE], and [B] gave a [OBJECT] to [A]", 185 | "Then [B] and [A] had a lot of fun at the [PLACE], and [B] gave a [OBJECT] to [A]", 186 | "Then [B] and [A] were working at the [PLACE], and [B] decided to give a [OBJECT] to [A]", 187 | "Then [B] and [A] were thinking about going to the [PLACE], and [B] wanted to give a [OBJECT] to [A]", 188 | "Then [B] and [A] had a long argument, and after that [B] said to [A]", 189 | "After the lunch [B] and [A] went to the [PLACE], and [B] gave a [OBJECT] to [A]", 190 | "Afterwards [B] and [A] went to the [PLACE], and [B] gave a [OBJECT] to [A]", 191 | "Then [B] and [A] had a long argument, and afterwards [B] said to [A]", 192 | ] 193 | 194 | ABBA_TEMPLATES = BABA_TEMPLATES[:] 195 | ABBA_LATE_IOS = BABA_LATE_IOS[:] 196 | ABBA_EARLY_IOS = BABA_EARLY_IOS[:] 197 | 198 | for TEMPLATES in [ABBA_TEMPLATES, ABBA_LATE_IOS, ABBA_EARLY_IOS]: 199 | for i in range(len(TEMPLATES)): 200 | first_clause = True 201 | for j in range(1, len(TEMPLATES[i]) - 1): 202 | if TEMPLATES[i][j - 1 : j + 2] == "[B]" and first_clause: 203 | TEMPLATES[i] = TEMPLATES[i][:j] + "A" + TEMPLATES[i][j + 1 :] 204 | elif TEMPLATES[i][j - 1 : j + 2] == "[A]" and first_clause: 205 | first_clause = False 206 | TEMPLATES[i] = TEMPLATES[i][:j] + "B" + TEMPLATES[i][j + 1 :] 207 | 208 | VERBS = [" tried", " said", " decided", " wanted", " gave"] 209 | 210 | PLACES = [ 211 | "store", 212 | "garden", 213 | "restaurant", 214 | "school", 215 | "hospital", 216 | "office", 217 | "house", 218 | "station", 219 | ] 220 | 221 | OBJECTS = [ 222 | "ring", 223 | "kiss", 224 | "bone", 225 | "basketball", 226 | "computer", 227 | "necklace", 228 | "drink", 229 | "snack", 230 | ] 231 | 232 | 233 | def gen_prompt_uniform( 234 | templates, names, nouns_dict, N, symmetric, prefixes=None, abc=False 235 | ): 236 | nb_gen = 0 237 | ioi_prompts = [] 238 | while nb_gen < N: 239 | temp = 
random.choice(templates) 240 | temp_id = templates.index(temp) 241 | name_1 = "" 242 | name_2 = "" 243 | name_3 = "" 244 | while len(set([name_1, name_2, name_3])) < 3: 245 | name_1 = random.choice(names) 246 | name_2 = random.choice(names) 247 | name_3 = random.choice(names) 248 | 249 | nouns = {} 250 | ioi_prompt = {} 251 | for k in nouns_dict: 252 | nouns[k] = random.choice(nouns_dict[k]) 253 | ioi_prompt[k] = nouns[k] 254 | prompt = temp 255 | for k in nouns_dict: 256 | prompt = prompt.replace(k, nouns[k]) 257 | 258 | if prefixes is not None: 259 | L = random.randint(30, 40) 260 | pref = ".".join(random.choice(prefixes).split(".")[:L]) 261 | pref += "<|endoftext|>" 262 | else: 263 | pref = "" 264 | 265 | prompt1 = prompt.replace("[A]", name_1) 266 | prompt1 = prompt1.replace("[B]", name_2) 267 | if abc: 268 | prompt1 = prompt1.replace("[C]", name_3) 269 | prompt1 = pref + prompt1 270 | ioi_prompt["text"] = prompt1 271 | ioi_prompt["IO"] = name_1 272 | ioi_prompt["S"] = name_2 273 | ioi_prompt["TEMPLATE_IDX"] = temp_id 274 | ioi_prompts.append(ioi_prompt) 275 | if abc: 276 | ioi_prompts[-1]["C"] = name_3 277 | 278 | nb_gen += 1 279 | 280 | if symmetric and nb_gen < N: 281 | prompt2 = prompt.replace("[A]", name_2) 282 | prompt2 = prompt2.replace("[B]", name_1) 283 | prompt2 = pref + prompt2 284 | ioi_prompts.append( 285 | {"text": prompt2, "IO": name_2, "S": name_1, "TEMPLATE_IDX": temp_id} 286 | ) 287 | nb_gen += 1 288 | return ioi_prompts 289 | 290 | 291 | 292 | def flip_words_in_prompt(prompt: str, word1: str, word2: str, instances: Optional[Union[int, List[int]]] = None): 293 | ''' 294 | Flips instances of word `word1` with `word2` in the string `string`. 295 | 296 | By default it flips all instances, but the optional `instances` argument specifies which 297 | instances to flip (e.g. if instances = 0, then it only flips the 0th instance of either 298 | word1 or word2. 
299 | 300 | Examples of (arguments) -> return value: 301 | 302 | ("ABA", "A", "B") -> "BAB" 303 | ("ABA", "A", "B", 1) -> "AAA" 304 | ("ABA", "A", "B", [0, 1]) -> "BAA 305 | ''' 306 | split_prompt = re.split("({}|{})".format(word1, word2), prompt) 307 | indices_of_names = [i for i, s in enumerate(split_prompt) if s in (word1, word2)] 308 | indices_to_flip = [indices_of_names[i] for i in instances] 309 | for i in indices_to_flip: 310 | split_prompt[i] = word1 if split_prompt[i] == word2 else word1 311 | prompt = "".join(split_prompt) 312 | return prompt 313 | 314 | 315 | 316 | def gen_flipped_prompts(prompts: List[dict], templates_by_prompt: List[str], flip: str, names: List[str], seed: int) -> List[dict]: 317 | ''' 318 | Flip prompts in a way described by the flip argument. Returns new prompts. 319 | 320 | prompts: List[dict] 321 | list of prompts, each prompt is a dict with keys "S", "IO", "text", etc 322 | 323 | templates_by_prompt: List[str] 324 | each element is "ABBA" or "BABA" 325 | 326 | flip: str 327 | "ABB -> XYZ, BAB -> XYZ" means that the prompt "A and B went to [place], B gave [object] to A" becomes "X and Y went to [place], Z gave [object] to A" (and equivalent for the BABA case) 328 | 329 | names: List[str] 330 | list of names, for when flip involves random tokens 331 | 332 | seed: int 333 | provides reproducibility 334 | 335 | Note that we don't bother flipping the last token in the prompt (IO2), since 336 | we don't use it for anything (intuitively, we use this function to create 337 | datasets to provide us with corrupted signals, but we still use the IO2 from 338 | the original uncorrupted IOI database as our "correct answer", so we don't 339 | care about what the correct answer (IO2) for the corrupted set is). 
340 | ''' 341 | random.seed(seed) 342 | np.random.seed(seed) 343 | abba_flip, baba_flip = flip.split(",") 344 | flip_dict = { 345 | "ABB": [flip.strip() for flip in abba_flip.split("->")], 346 | "BAB": [flip.strip() for flip in baba_flip.split("->")] 347 | } 348 | 349 | new_prompts = [] 350 | 351 | for idx, (prompt, template) in enumerate(zip(prompts, templates_by_prompt)): 352 | 353 | flip_orig, flip_new = flip_dict[template[:-1]] 354 | 355 | prompt = copy.copy(prompt) 356 | 357 | # Get indices and original values of first three names int the prompt 358 | prompt_split = prompt["text"].split(" ") 359 | orig_names_and_posns = [(i, s) for i, s in enumerate(prompt_split) if s in names][:3] 360 | orig_names = list(zip(*orig_names_and_posns))[1] 361 | 362 | # Get a dictionary of the correspondence between orig names and letters in flip_orig 363 | # (and get a subdict for those names which are kept in flip_new) 364 | orig_names_key = { 365 | letter: s 366 | for s, letter in zip(orig_names, flip_orig) 367 | } 368 | kept_names_key = { 369 | k: v 370 | for k, v in orig_names_key.items() if k in flip_new 371 | } 372 | # This line will throw an error if flip_orig is wrong (e.g. 
if it says "SOS" but the 373 | # S1 and S2 tokens don't actually match 374 | assert len(orig_names_key) == len(set(flip_orig)) 375 | 376 | # Get all random names we'll need, in the form of a dictionary 377 | rand_names = { 378 | letter: np.random.choice(list(set(names) - set(orig_names))) 379 | for letter in set(flip_new) - set(flip_orig) 380 | } 381 | 382 | # Get a "full dictionary" which maps letters in flip_new to the new values they will have 383 | name_replacement_dict = {**kept_names_key, **rand_names} 384 | assert len(name_replacement_dict) == len(set(flip_new)), (name_replacement_dict, flip_new) 385 | 386 | # Populate the new names, with either random names or with the corresponding orig names 387 | for (i, s), letter in zip(orig_names_and_posns, flip_new): 388 | prompt_split[i] = name_replacement_dict[letter] 389 | 390 | # Join the prompt back together 391 | prompt["text"] = " ".join(prompt_split) 392 | 393 | # Change the identity of the S and IO tokens. 394 | # S token is just same as S2, but IO is a bit messier because it might not be 395 | # well-defined (it's defined as the unique non-duplicated name of the first 396 | # two). If it's ill-defined, WLOG set it to be the second name. 
397 | prompt["S"] = name_replacement_dict[flip_new[-1]] 398 | possible_IOs = [name_replacement_dict[letter] for letter in flip_new[:2] if list(flip_new).count(letter) == 1] 399 | # Case where IO is well-defined 400 | if len(possible_IOs) == 1: 401 | prompt["IO"] = possible_IOs[0] 402 | # Case where it isn't well-defined 403 | else: 404 | prompt["IO"] = name_replacement_dict[flip_new[1]] 405 | 406 | new_prompts.append(prompt) 407 | 408 | return new_prompts 409 | 410 | 411 | 412 | def get_name_idxs(prompts, tokenizer, idx_types=["IO", "S1", "S2"], prepend_bos=False): 413 | name_idx_dict = dict((idx_type, []) for idx_type in idx_types) 414 | for prompt in prompts: 415 | text_split = prompt["text"].split(" ") 416 | toks = tokenizer.tokenize(" ".join(text_split[:-1])) 417 | # Get the first instance of IO token 418 | name_idx_dict["IO"].append( 419 | toks.index(tokenizer.tokenize(" " + prompt["IO"])[0]) 420 | ) 421 | # Get the first instance of S token 422 | name_idx_dict["S1"].append( 423 | toks.index(tokenizer.tokenize(" " + prompt["S"])[0]) 424 | ) 425 | # Get the last instance of S token 426 | name_idx_dict["S2"].append( 427 | len(toks) - toks[::-1].index(tokenizer.tokenize(" " + prompt["S"])[0]) - 1 428 | ) 429 | 430 | return [ 431 | int(prepend_bos) + t.tensor(name_idx_dict[idx_type]) 432 | for idx_type in idx_types 433 | ] 434 | 435 | 436 | def get_word_idxs(prompts, word_list, tokenizer): 437 | """Get the index of the words in word_list in the prompts. 
Exactly one of the word_list word has to be present in each prompt""" 438 | idxs = [] 439 | tokenized_words = [ 440 | tokenizer.decode(tokenizer(word)["input_ids"][0]) for word in word_list 441 | ] 442 | for prompt in prompts: 443 | toks = [ 444 | tokenizer.decode(t) 445 | for t in tokenizer(prompt["text"], return_tensors="pt", padding=True)[ 446 | "input_ids" 447 | ][0] 448 | ] 449 | idx = None 450 | for i, w_tok in enumerate(tokenized_words): 451 | if word_list[i] in prompt["text"]: 452 | try: 453 | idx = toks.index(w_tok) 454 | if toks.count(w_tok) > 1: 455 | idx = len(toks) - toks[::-1].index(w_tok) - 1 456 | except: 457 | idx = toks.index(w_tok) 458 | # raise ValueError(toks, w_tok, prompt["text"]) 459 | if idx is None: 460 | raise ValueError(f"Word {word_list} and {i} not found {prompt}") 461 | idxs.append(idx) 462 | return t.tensor(idxs) 463 | 464 | 465 | def get_end_idxs(toks, tokenizer, name_tok_len=1, prepend_bos=False): 466 | relevant_idx = int(prepend_bos) 467 | # if the sentence begins with an end token 468 | # AND the model pads at the end with the same end token, 469 | # then we need make special arrangements 470 | 471 | pad_token_id = tokenizer.pad_token_id 472 | 473 | end_idxs_raw = [] 474 | for i in range(toks.shape[0]): 475 | if pad_token_id not in toks[i][1:]: 476 | end_idxs_raw.append(toks.shape[1]) 477 | continue 478 | nonzers = (toks[i] == pad_token_id).nonzero()[relevant_idx][0].item() 479 | end_idxs_raw.append(nonzers) 480 | end_idxs = t.tensor(end_idxs_raw) 481 | end_idxs = end_idxs - 1 - name_tok_len 482 | 483 | for i in range(toks.shape[0]): 484 | assert toks[i][end_idxs[i] + 1] != 0 and ( 485 | toks.shape[1] == end_idxs[i] + 2 or toks[i][end_idxs[i] + 2] == pad_token_id 486 | ), ( 487 | toks[i], 488 | end_idxs[i], 489 | toks[i].shape, 490 | "the END idxs aren't properly formatted", 491 | ) 492 | 493 | return end_idxs 494 | 495 | 496 | 497 | 498 | 499 | def get_idx_dict(ioi_prompts, tokenizer, prepend_bos=False, toks=None): 500 | 
(IO_idxs, S1_idxs, S2_idxs,) = get_name_idxs( 501 | ioi_prompts, 502 | tokenizer, 503 | idx_types=["IO", "S1", "S2"], 504 | prepend_bos=prepend_bos, 505 | ) 506 | 507 | end_idxs = get_end_idxs( 508 | toks, 509 | tokenizer, 510 | name_tok_len=1, 511 | prepend_bos=prepend_bos, 512 | ) 513 | 514 | punct_idxs = get_word_idxs(ioi_prompts, [",", "."], tokenizer) 515 | 516 | return { 517 | "IO": IO_idxs, 518 | "IO-1": IO_idxs - 1, 519 | "IO+1": IO_idxs + 1, 520 | "S1": S1_idxs, 521 | "S1-1": S1_idxs - 1, 522 | "S1+1": S1_idxs + 1, 523 | "S2": S2_idxs, 524 | "end": end_idxs, 525 | "starts": t.zeros_like(end_idxs), 526 | "punct": punct_idxs, 527 | } 528 | 529 | def format_prompt(sentence: str) -> str: 530 | '''Format a prompt by underlining names (for rich print)''' 531 | return re.sub("(" + "|".join(NAMES) + ")", lambda x: f"[u bold dark_orange]{x.group(0)}[/]", sentence) + "\n" 532 | 533 | 534 | def make_table(cols, colnames, title="", n_rows=5, decimals=4): 535 | '''Makes and displays a table, from cols rather than rows (using rich print)''' 536 | table = Table(*colnames, title=title) 537 | rows = list(zip(*cols)) 538 | f = lambda x: x if isinstance(x, str) else f"{x:.{decimals}f}" 539 | for row in rows[:n_rows]: 540 | table.add_row(*list(map(f, row))) 541 | rprint(table) 542 | 543 | class IOIDataset: 544 | def __init__( 545 | self, 546 | prompt_type: Union[ 547 | str, List[str] 548 | ], # if list, then it will be a list of templates 549 | N=500, 550 | tokenizer=None, 551 | prompts=None, 552 | symmetric=False, 553 | prefixes=None, 554 | nb_templates=None, 555 | prepend_bos=False, 556 | manual_word_idx=None, 557 | has_been_flipped:bool=False, 558 | seed=0, 559 | device="cuda" 560 | ): 561 | self.seed = seed 562 | random.seed(self.seed) 563 | np.random.seed(self.seed) 564 | if not ( 565 | N == 1 566 | or prepend_bos == False 567 | or tokenizer.bos_token_id == tokenizer.eos_token_id 568 | ): 569 | warnings.warn( 570 | "Probably word_idx will be calculated incorrectly due to 
this formatting" 571 | ) 572 | self.has_been_flipped = has_been_flipped 573 | assert not (symmetric and prompt_type == "ABC") 574 | assert ( 575 | (prompts is not None) or (not symmetric) or (N % 2 == 0) 576 | ), f"{symmetric} {N}" 577 | self.prompt_type = prompt_type 578 | 579 | if nb_templates is None: 580 | nb_templates = len(BABA_TEMPLATES) 581 | 582 | if prompt_type == "ABBA": 583 | self.templates = ABBA_TEMPLATES[:nb_templates].copy() 584 | elif prompt_type == "BABA": 585 | self.templates = BABA_TEMPLATES[:nb_templates].copy() 586 | elif prompt_type == "mixed": 587 | self.templates = ( 588 | BABA_TEMPLATES[: nb_templates // 2].copy() 589 | + ABBA_TEMPLATES[: nb_templates // 2].copy() 590 | ) 591 | random.shuffle(self.templates) 592 | elif prompt_type == "ABC": 593 | self.templates = ABC_TEMPLATES[:nb_templates].copy() 594 | elif prompt_type == "BAC": 595 | self.templates = BAC_TEMPLATES[:nb_templates].copy() 596 | elif prompt_type == "ABC mixed": 597 | self.templates = ( 598 | ABC_TEMPLATES[: nb_templates // 2].copy() 599 | + BAC_TEMPLATES[: nb_templates // 2].copy() 600 | ) 601 | random.shuffle(self.templates) 602 | elif isinstance(prompt_type, list): 603 | self.templates = prompt_type 604 | else: 605 | raise ValueError(prompt_type) 606 | 607 | if tokenizer is None: 608 | self.tokenizer = AutoTokenizer.from_pretrained("gpt2") 609 | self.tokenizer.pad_token = self.tokenizer.eos_token 610 | else: 611 | self.tokenizer = tokenizer 612 | 613 | self.prefixes = prefixes 614 | self.prompt_type = prompt_type 615 | if prompts is None: 616 | self.ioi_prompts = gen_prompt_uniform( # list of dict of the form {"text": "Alice and Bob bla bla. 
Bob gave bla to Alice", "IO": "Alice", "S": "Bob"} 617 | self.templates, 618 | NAMES, 619 | nouns_dict={"[PLACE]": PLACES, "[OBJECT]": OBJECTS}, 620 | N=N, 621 | symmetric=symmetric, 622 | prefixes=self.prefixes, 623 | abc=(prompt_type in ["ABC", "ABC mixed", "BAC"]), 624 | ) 625 | else: 626 | assert N == len(prompts), f"{N} and {len(prompts)}" 627 | self.ioi_prompts = prompts 628 | 629 | all_ids = [prompt["TEMPLATE_IDX"] for prompt in self.ioi_prompts] 630 | all_ids_ar = np.array(all_ids) 631 | self.groups = [] 632 | for id in list(set(all_ids)): 633 | self.groups.append(np.where(all_ids_ar == id)[0]) 634 | 635 | small_groups = [] 636 | for group in self.groups: 637 | if len(group) < 5: 638 | small_groups.append(len(group)) 639 | 640 | self.sentences = [ 641 | prompt["text"] for prompt in self.ioi_prompts 642 | ] # a list of strings. Renamed as this should NOT be forward passed 643 | 644 | self.templates_by_prompt = [] # for each prompt if it's ABBA or BABA 645 | for i in range(N): 646 | if self.sentences[i].index(self.ioi_prompts[i]["IO"]) < self.sentences[ 647 | i 648 | ].index(self.ioi_prompts[i]["S"]): 649 | self.templates_by_prompt.append("ABBA") 650 | else: 651 | self.templates_by_prompt.append("BABA") 652 | 653 | texts = [ 654 | (self.tokenizer.bos_token if prepend_bos else "") + prompt["text"] 655 | for prompt in self.ioi_prompts 656 | ] 657 | self.toks = t.Tensor(self.tokenizer(texts, padding=True).input_ids).long() 658 | 659 | self.word_idx = get_idx_dict( 660 | self.ioi_prompts, 661 | self.tokenizer, 662 | prepend_bos=prepend_bos, 663 | toks=self.toks, 664 | ) 665 | self.prepend_bos = prepend_bos 666 | if manual_word_idx is not None: 667 | self.word_idx = manual_word_idx 668 | 669 | self.N = N 670 | self.max_len = max( 671 | [ 672 | len(self.tokenizer(prompt["text"]).input_ids) 673 | for prompt in self.ioi_prompts 674 | ] 675 | ) 676 | 677 | self.io_tokenIDs = [ 678 | self.tokenizer.encode(" " + prompt["IO"])[0] for prompt in self.ioi_prompts 679 | ] 
680 | self.s_tokenIDs = [ 681 | self.tokenizer.encode(" " + prompt["S"])[0] for prompt in self.ioi_prompts 682 | ] 683 | 684 | self.tokenized_prompts = [] 685 | 686 | for i in range(self.N): 687 | self.tokenized_prompts.append( 688 | "|".join([self.tokenizer.decode(tok) for tok in self.toks[i]]) 689 | ) 690 | 691 | self.device = device 692 | self.to(device) 693 | 694 | def gen_flipped_prompts(self, flip): 695 | # Check if it's already been flipped (shouldn't string 2 flips together) 696 | if self.has_been_flipped: 697 | warnings.warn("This dataset has already been flipped. Generally, you should try and apply flips in one step, because this can lead to errors.") 698 | 699 | # Redefine seed (so it's different depending on what the flip is, e.g. we don't want (IO, RAND) then (S, RAND) to give us the same rand names) 700 | seed = self.seed + sum(map(ord, list("".join(flip)))) 701 | 702 | # Get flipped prompts 703 | flipped_prompts = gen_flipped_prompts(self.ioi_prompts, self.templates_by_prompt, flip, NAMES, seed) 704 | 705 | flipped_ioi_dataset = IOIDataset( 706 | prompt_type=self.prompt_type, 707 | N=self.N, 708 | tokenizer=self.tokenizer, 709 | prompts=flipped_prompts, 710 | prefixes=self.prefixes, 711 | prepend_bos=self.prepend_bos, 712 | manual_word_idx=self.word_idx, 713 | has_been_flipped=True, 714 | seed=seed 715 | ) 716 | return flipped_ioi_dataset 717 | 718 | def copy(self): 719 | copy_ioi_dataset = IOIDataset( 720 | prompt_type=self.prompt_type, 721 | N=self.N, 722 | tokenizer=self.tokenizer, 723 | prompts=self.ioi_prompts.copy(), 724 | prefixes=self.prefixes.copy() if self.prefixes is not None else self.prefixes, 725 | ) 726 | return copy_ioi_dataset 727 | 728 | def __getitem__(self, key): 729 | sliced_prompts = self.ioi_prompts[key] 730 | sliced_dataset = IOIDataset( 731 | prompt_type=self.prompt_type, 732 | N=len(sliced_prompts), 733 | tokenizer=self.tokenizer, 734 | prompts=sliced_prompts, 735 | prefixes=self.prefixes, 736 | prepend_bos=self.prepend_bos, 
737 | ) 738 | return sliced_dataset 739 | 740 | def __setitem__(self, key, value): 741 | raise NotImplementedError() 742 | 743 | def __delitem__(self, key): 744 | raise NotImplementedError() 745 | 746 | def __len__(self): 747 | return self.N 748 | 749 | def tokenized_prompts(self): 750 | return self.toks 751 | 752 | def to(self, device): 753 | self.toks = self.toks.to(device) 754 | return self -------------------------------------------------------------------------------- /ioi_task/noabs_value_num_passes.json: -------------------------------------------------------------------------------- 1 | {"0.005": 2960, "0.01": 1787, "0.015": 1034, "0.02": 824, "0.025": 691, "0.030000000000000002": 528, "0.034999999999999996": 421, "0.04": 386, "0.045": 386, "0.049999999999999996": 362, "0.055": 331, "0.06": 300, "0.065": 261, "0.07": 231, "0.07500000000000001": 195, "0.08": 199, "0.085": 159, "0.09000000000000001": 154, "0.095": 154, "0.1": 154, "0.10500000000000001": 154, "0.11": 154, "0.115": 154, "0.12000000000000001": 154, "0.125": 154, "0.13": 154, "0.135": 124, "0.14": 113, "0.14500000000000002": 113, "0.15": 113} -------------------------------------------------------------------------------- /ioi_task/noabs_value_pruned_heads.json: -------------------------------------------------------------------------------- 1 | {"0.005": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", "", ""]], "0.01": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", 
"", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", ""]], "0.015": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", ""]], "0.02": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", ""]], "0.025": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", ""]], "0.030000000000000002": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.034999999999999996": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.04": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", 
""]], "0.045": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.049999999999999996": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.055": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.06": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.065": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", "", "", "", "", "", "", "", ""]], "0.07": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", "", "", "", "", "", "", ""]], "0.07500000000000001": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", "", "", "", "", "", "", ""]], "0.08": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", "", "", "", "", ""]], "0.085": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", "", "", ""]], "0.09000000000000001": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", ""]], "0.095": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", ""]], "0.1": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", ""]], "0.10500000000000001": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", ""]], "0.11": 
[["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", ""]], "0.115": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", ""]], "0.12000000000000001": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", ""]], "0.125": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", ""]], "0.13": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", ""]], "0.135": [["", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "", "embed", ""], ["", "", "", "", "", ""]], "0.14": [["", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", ""], ["", "", "", "", ""]], "0.14500000000000002": [["", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", ""], ["", "", "", "", ""]], "0.15": [["", "", "", "", "", "", "", "", "", "", "", "", "", "embed", "", "", "", ""], ["", "", "", "", ""]]} -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.8.5 2 | aiosignal==1.3.1 3 | anyio==3.7.1 4 | appdirs==1.4.4 5 | argon2-cffi==21.3.0 6 | argon2-cffi-bindings==21.2.0 7 | arrow==1.2.3 8 | asttokens @ file:///opt/conda/conda-bld/asttokens_1646925590279/work 9 | astunparse==1.6.3 10 | async-lru==2.0.4 11 | async-timeout==4.0.3 12 | attrs==23.1.0 13 | Babel==2.12.1 14 | backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work 15 | bash_kernel==0.9.1 16 | beartype==0.14.1 17 | beautifulsoup4 @ file:///croot/beautifulsoup4-split_1681493039619/work 18 | bleach==6.0.0 19 | boltons @ file:///croot/boltons_1677628692245/work 20 | brotlipy==0.7.0 21 | certifi @ file:///croot/certifi_1683875369620/work/certifi 22 | cffi @ file:///croot/cffi_1670423208954/work 23 | chardet 
@ file:///home/builder/ci_310/chardet_1640804867535/work 24 | charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work 25 | click==8.1.7 26 | comm==0.1.4 27 | conda==23.3.1 28 | conda-build==3.24.0 29 | conda-content-trust @ file:///tmp/abs_5952f1c8-355c-4855-ad2e-538535021ba5h26t22e5/croots/recipe/conda-content-trust_1658126371814/work 30 | conda-package-handling @ file:///croot/conda-package-handling_1672865015732/work 31 | conda_package_streaming @ file:///croot/conda-package-streaming_1670508151586/work 32 | cryptography @ file:///croot/cryptography_1677533068310/work 33 | datasets==2.14.5 34 | debugpy==1.6.7.post1 35 | decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work 36 | defusedxml==0.7.1 37 | dill==0.3.7 38 | dnspython==2.3.0 39 | docker-pycreds==0.4.0 40 | einops==0.6.1 41 | exceptiongroup==1.1.1 42 | executing @ file:///opt/conda/conda-bld/executing_1646925071911/work 43 | expecttest==0.1.4 44 | fancy-einsum==0.0.3 45 | fastjsonschema==2.18.0 46 | filelock @ file:///croot/filelock_1672387128942/work 47 | fqdn==1.5.1 48 | frozenlist==1.4.0 49 | fsspec==2023.6.0 50 | gitdb==4.0.10 51 | GitPython==3.1.35 52 | glob2 @ file:///home/linux1/recipes/ci/glob2_1610991677669/work 53 | gmpy2 @ file:///tmp/build/80754af9/gmpy2_1645455533097/work 54 | huggingface-hub==0.16.4 55 | hypothesis==6.75.2 56 | idna @ file:///croot/idna_1666125576474/work 57 | iniconfig==2.0.0 58 | ipykernel==6.25.1 59 | ipython @ file:///croot/ipython_1680701871216/work 60 | ipython-genutils==0.2.0 61 | ipywidgets==8.1.0 62 | isoduration==20.11.0 63 | jaxtyping==0.2.21 64 | jedi @ file:///tmp/build/80754af9/jedi_1644315229345/work 65 | Jinja2 @ file:///croot/jinja2_1666908132255/work 66 | json5==0.9.14 67 | jsonpatch @ file:///tmp/build/80754af9/jsonpatch_1615747632069/work 68 | jsonpointer==2.1 69 | jsonschema==4.19.0 70 | jsonschema-specifications==2023.7.1 71 | jupyter==1.0.0 72 | jupyter-archive==3.3.4 73 | jupyter-console==6.6.3 74 | 
jupyter-events==0.7.0 75 | jupyter-http-over-ws==0.0.8 76 | jupyter-lsp==2.2.0 77 | jupyter_client==8.3.0 78 | jupyter_core==5.3.1 79 | jupyter_server==2.7.0 80 | jupyter_server_terminals==0.4.4 81 | jupyterlab==4.0.5 82 | jupyterlab-pygments==0.2.2 83 | jupyterlab-widgets==3.0.8 84 | jupyterlab_server==2.24.0 85 | libarchive-c @ file:///tmp/build/80754af9/python-libarchive-c_1617780486945/work 86 | markdown-it-py==3.0.0 87 | MarkupSafe @ file:///opt/conda/conda-bld/markupsafe_1654597864307/work 88 | matplotlib-inline @ file:///opt/conda/conda-bld/matplotlib-inline_1662014470464/work 89 | mdurl==0.1.2 90 | mistune==3.0.1 91 | mkl-fft==1.3.6 92 | mkl-random @ file:///work/mkl/mkl_random_1682950433854/work 93 | mkl-service==2.4.0 94 | mpmath==1.3.0 95 | multidict==6.0.4 96 | multiprocess==0.70.15 97 | nbclient==0.8.0 98 | nbconvert==7.7.3 99 | nbformat==5.9.2 100 | nbzip==0.1.0 101 | nest-asyncio==1.5.7 102 | networkx==3.1 103 | notebook==7.0.2 104 | notebook_shim==0.2.3 105 | numpy @ file:///work/mkl/numpy_and_numpy_base_1682953417311/work 106 | overrides==7.4.0 107 | packaging @ file:///croot/packaging_1678965309396/work 108 | pandas==2.1.0 109 | pandocfilters==1.5.0 110 | parso @ file:///opt/conda/conda-bld/parso_1641458642106/work 111 | pathtools==0.1.2 112 | pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work 113 | pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work 114 | Pillow==9.4.0 115 | pkginfo @ file:///croot/pkginfo_1679431160147/work 116 | platformdirs==3.10.0 117 | plotly==5.16.1 118 | pluggy @ file:///tmp/build/80754af9/pluggy_1648024709248/work 119 | prometheus-client==0.17.1 120 | prompt-toolkit @ file:///croot/prompt-toolkit_1672387306916/work 121 | protobuf==4.24.3 122 | psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work 123 | ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 124 | pure-eval @ 
file:///opt/conda/conda-bld/pure_eval_1646925070566/work 125 | pyarrow==13.0.0 126 | pycosat @ file:///croot/pycosat_1666805502580/work 127 | pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work 128 | Pygments @ file:///croot/pygments_1683671804183/work 129 | pygraphviz==1.11 130 | pyOpenSSL @ file:///croot/pyopenssl_1677607685877/work 131 | PySocks @ file:///home/builder/ci_310/pysocks_1640793678128/work 132 | pytest==7.4.0 133 | python-dateutil==2.8.2 134 | python-etcd==0.4.5 135 | python-json-logger==2.0.7 136 | pytz @ file:///croot/pytz_1671697431263/work 137 | PyYAML @ file:///croot/pyyaml_1670514731622/work 138 | pyzmq==25.1.1 139 | qtconsole==5.4.3 140 | QtPy==2.3.1 141 | referencing==0.30.2 142 | regex==2023.8.8 143 | requests @ file:///croot/requests_1682607517574/work 144 | rfc3339-validator==0.1.4 145 | rfc3986-validator==0.1.1 146 | rich==13.5.2 147 | rpds-py==0.9.2 148 | ruamel.yaml @ file:///croot/ruamel.yaml_1666304550667/work 149 | ruamel.yaml.clib @ file:///croot/ruamel.yaml.clib_1666302247304/work 150 | safetensors==0.3.3 151 | Send2Trash==1.8.2 152 | sentry-sdk==1.30.0 153 | setproctitle==1.3.2 154 | six @ file:///tmp/build/80754af9/six_1644875935023/work 155 | smmap==5.0.0 156 | sniffio==1.3.0 157 | sortedcontainers==2.4.0 158 | soupsieve @ file:///croot/soupsieve_1680518478486/work 159 | stack-data @ file:///opt/conda/conda-bld/stack_data_1646927590127/work 160 | sympy==1.12 161 | tenacity==8.2.3 162 | terminado==0.17.1 163 | tinycss2==1.2.1 164 | tokenizers==0.13.3 165 | tomli @ file:///opt/conda/conda-bld/tomli_1657175507142/work 166 | toolz @ file:///croot/toolz_1667464077321/work 167 | torch==2.0.1 168 | torchaudio==2.0.2 169 | torchdata @ file:///__w/_temp/conda_build_env/conda-bld/torchdata_1682362130135/work 170 | torchelastic==0.2.2 171 | torchtext==0.15.2 172 | torchvision==0.15.2 173 | tornado==6.3.3 174 | tqdm @ file:///croot/tqdm_1679561862951/work 175 | traitlets @ file:///croot/traitlets_1671143879854/work 176 | 
transformer-lens==1.6.0 177 | transformers==4.33.1 178 | triton==2.0.0 179 | typeguard==4.1.3 180 | types-dataclasses==0.6.6 181 | typing_extensions==4.7.1 182 | tzdata==2023.3 183 | uri-template==1.3.0 184 | urllib3 @ file:///croot/urllib3_1680254681959/work 185 | wandb==0.15.10 186 | wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work 187 | webcolors==1.13 188 | webencodings==0.5.1 189 | websocket-client==1.6.1 190 | widgetsnbextension==4.0.8 191 | xxhash==3.3.0 192 | yarl==1.9.2 193 | zstandard @ file:///croot/zstandard_1677013143055/work 194 | -------------------------------------------------------------------------------- /utils/graphics_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('Automatic-Circuit-Discovery/') 3 | 4 | from acdc.TLACDCInterpNode import TLACDCInterpNode 5 | from acdc.acdc_utils import EdgeType 6 | import numpy as np 7 | 8 | import pygraphviz as pgv 9 | from pathlib import Path 10 | 11 | import tqdm.notebook as tqdm 12 | 13 | from typing import Union, Dict, Optional 14 | 15 | 16 | def get_node_name(node: TLACDCInterpNode, show_full_index=True): 17 | """Node name for use in pretty graphs""" 18 | 19 | if not show_full_index: 20 | name = "" 21 | qkv_substrings = [f"hook_{letter}" for letter in ["q", "k", "v"]] 22 | qkv_input_substrings = [f"hook_{letter}_input" for letter in ["q", "k", "v"]] 23 | 24 | # Handle embedz 25 | if "resid_pre" in node.name: 26 | assert "0" in node.name and not any([str(i) in node.name for i in range(1, 10)]) 27 | name += "embed" 28 | if len(node.index.hashable_tuple) > 2: 29 | name += f"_[{node.index.hashable_tuple[2]}]" 30 | return name 31 | 32 | elif "embed" in node.name: 33 | name = "pos_embeds" if "pos" in node.name else "token_embeds" 34 | 35 | # Handle q_input and hook_q etc 36 | elif any([node.name.endswith(qkv_input_substring) for qkv_input_substring in qkv_input_substrings]): 37 | relevant_letter = None 38 
| for letter, qkv_substring in zip(["q", "k", "v"], qkv_substrings): 39 | if qkv_substring in node.name: 40 | assert relevant_letter is None 41 | relevant_letter = letter 42 | name += "a" + node.name.split(".")[1] + "." + str(node.index.hashable_tuple[2]) + "_" + relevant_letter 43 | 44 | # Handle attention hook_result 45 | elif "hook_result" in node.name or any([qkv_substring in node.name for qkv_substring in qkv_substrings]): 46 | name = "a" + node.name.split(".")[1] + "." + str(node.index.hashable_tuple[2]) 47 | 48 | # Handle MLPs 49 | elif node.name.endswith("resid_mid"): 50 | raise ValueError("We removed resid_mid annotations. Call these mlp_in now.") 51 | elif node.name.endswith("mlp_out") or node.name.endswith("mlp_in"): 52 | name = "m" + node.name.split(".")[1] 53 | 54 | # Handle resid_post 55 | elif "resid_post" in node.name: 56 | name += "resid_post" 57 | 58 | else: 59 | raise ValueError(f"Unrecognized node name {node.name}") 60 | 61 | else: 62 | 63 | name = node.name + str(node.index.graphviz_index(use_actual_colon=True)) 64 | 65 | return "<" + name + ">" 66 | 67 | def generate_random_color(colorscheme: str) -> str: 68 | """ 69 | https://stackoverflow.com/questions/28999287/generate-random-colors-rgb 70 | """ 71 | def rgb2hex(rgb): 72 | """ 73 | https://stackoverflow.com/questions/3380726/converting-an-rgb-color-tuple-to-a-hexidecimal-string 74 | """ 75 | return "#{:02x}{:02x}{:02x}".format(rgb[0], rgb[1], rgb[2]) 76 | 77 | return rgb2hex((np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256))) 78 | 79 | def build_colorscheme(correspondence, colorscheme: str = "Pastel2", show_full_index=True) -> Dict[str, str]: 80 | colors = {} 81 | for node in correspondence.nodes(): 82 | colors[get_node_name(node, show_full_index=show_full_index)] = generate_random_color(colorscheme) 83 | return colors 84 | 85 | def show( 86 | correspondence: TLACDCInterpNode, 87 | fname=None, 88 | colorscheme: Union[Dict, str] = "Pastel2", 89 | 
def show(
    correspondence: TLACDCInterpNode,
    fname=None,
    colorscheme: Union[Dict, str] = "Pastel2",
    minimum_penwidth: float = 0.3,
    show_full_index: bool = False,
    remove_self_loops: bool = True,
    remove_qkv: bool = True,
    layout: str = "dot",
    edge_type_colouring: bool = False,
    show_placeholders: bool = False,
    seed: Optional[int] = None,
):
    """
    Render the present edges of `correspondence` as a pygraphviz graph.

    Arguments:
        correspondence: an object exposing `.edges` (nested child->index->parent->index
            dict of Edge objects) and `.graph` (hook name -> index -> node). The
            annotation says TLACDCInterpNode, but a correspondence-like object is
            what is actually indexed here — presumably TLACDCCorrespondence; verify.
        fname: if given, per-color group files and <base>.gv are written, and when
            fname does not already end in ".gv" the graph is rendered to fname.
            May be a str or a Path (it is normalized with str() below).
        colorscheme: a matplotlib-style scheme name (random colors are generated via
            build_colorscheme) or a precomputed {node_name: color} dict.
        minimum_penwidth: edges are drawn at twice this width.
        show_full_index: forwarded to get_node_name.
        remove_self_loops: skip edges whose pretty names collapse to the same node.
        remove_qkv: drop per-q/k/v nodes and strip _q/_k/_v suffixes from names.
        layout: graphviz layout engine stored on the graph.
        edge_type_colouring: currently unused in this implementation.
        show_placeholders: include PLACEHOLDER edges if True.
        seed: seeds numpy's RNG so generated colors are reproducible.

    Returns the pygraphviz AGraph.
    """
    g = pgv.AGraph(directed=True, bgcolor="transparent", overlap="false", splines="true", layout=layout)

    if seed is not None:
        np.random.seed(seed)  # makes build_colorscheme's random colors reproducible

    groups = {}
    if isinstance(colorscheme, str):
        colors = build_colorscheme(correspondence, colorscheme, show_full_index=show_full_index)
    else:
        # Caller supplied explicit colors: also group node names by color so the
        # per-group layout files can be written below.
        colors = colorscheme
        for name, color in colors.items():
            if color not in groups:
                groups[color] = [name]
            else:
                groups[color].append(name)

    # Reuse node positions from a previously cached layout, if one exists.
    node_pos = {}
    if fname is not None:
        base_fname = ".".join(str(fname).split(".")[:-1])

        base_path = Path(base_fname)
        fpath = base_path / "layout.gv"
        if fpath.exists():
            g_pos = pgv.AGraph()
            g_pos.read(fpath)
            for node in g_pos.nodes():
                node_pos[node.name] = node.attr["pos"]

    for child_hook_name in correspondence.edges:
        for child_index in correspondence.edges[child_hook_name]:
            for parent_hook_name in correspondence.edges[child_hook_name][child_index]:
                for parent_index in correspondence.edges[child_hook_name][child_index][parent_hook_name]:
                    edge = correspondence.edges[child_hook_name][child_index][parent_hook_name][parent_index]

                    parent = correspondence.graph[parent_hook_name][parent_index]
                    child = correspondence.graph[child_hook_name][child_index]

                    parent_name = get_node_name(parent, show_full_index=show_full_index)
                    child_name = get_node_name(child, show_full_index=show_full_index)

                    if remove_qkv:
                        if any(qkv in child_name or qkv in parent_name for qkv in ['_q_', '_k_', '_v_']):
                            continue
                        parent_name = parent_name.replace("_q>", ">").replace("_k>", ">").replace("_v>", ">")
                        child_name = child_name.replace("_q>", ">").replace("_k>", ">").replace("_v>", ">")

                    if remove_self_loops and parent_name == child_name:
                        # Important this go after the qkv removal
                        continue

                    if edge.present and (edge.edge_type != EdgeType.PLACEHOLDER or show_placeholders):
                        for node_name in [parent_name, child_name]:
                            maybe_pos = {}
                            if node_name in node_pos:
                                maybe_pos["pos"] = node_pos[node_name]
                            g.add_node(
                                node_name,
                                fillcolor=colors[node_name],
                                color="black",
                                style="filled, rounded",
                                shape="box",
                                fontname="Helvetica",
                                **maybe_pos,
                            )

                        g.add_edge(
                            parent_name,
                            child_name,
                            penwidth=str(minimum_penwidth * 2),
                            color=colors[parent_name],
                        )

    if fname is not None:
        # BUGFIX: the original called fname.endswith(...) / g.draw(path=fname, ...)
        # directly, which raises AttributeError when fname is a Path — even though
        # str(fname) is already used above, showing Path inputs are expected.
        # Normalize once and use the string form throughout.
        fname_str = str(fname)
        base_fname = ".".join(fname_str.split(".")[:-1])

        base_path = Path(base_fname)
        base_path.mkdir(exist_ok=True)
        # One auxiliary graph per color group, chained with invisible high-weight
        # edges so external layout tools keep same-colored nodes together.
        for k, s in groups.items():
            g2 = pgv.AGraph(directed=True, bgcolor="transparent", overlap="false", splines="true", layout="neato")
            for node_name in s:
                g2.add_node(
                    node_name,
                    style="filled, rounded",
                    shape="box",
                )
            for i in range(len(s)):
                for j in range(i + 1, len(s)):
                    g2.add_edge(s[i], s[j], style="invis", weight=200)
            g2.write(path=base_path / f"{k}.gv")

        g.write(path=base_fname + ".gv")

        if not fname_str.endswith(".gv"):  # turn the .gv file into a .png file
            g.draw(path=fname_str, prog="dot")

    return g
def remove_redundant_node(exp, node, safe=True, allow_fails=True):
    """
    Disconnect `node` and everything downstream of it that becomes orphaned,
    by breadth-first traversal of exp.corr.

    Arguments:
        exp: a TLACDCExperiment whose `corr` graph is mutated in place
        node: the starting node to disconnect
        safe: if True, first assert that no incoming edge of `node` is still
            present (i.e. the node is genuinely unused) before removing it
        allow_fails: if True, swallow KeyErrors from remove_edge and keep going
    """
    if safe:
        # Sanity check: a node still referenced by a present incoming edge
        # should not be removed.
        for parent_name in exp.corr.edges[node.name][node.index]:
            for parent_index in exp.corr.edges[node.name][node.index][parent_name]:
                if exp.corr.edges[node.name][node.index][parent_name][parent_index].present:
                    raise Exception(f"You should not be removing a node that is still used by another node {node} {(parent_name, parent_index)}")

    # Manual BFS queue: bfs_idx is the read pointer, appends extend the frontier.
    bfs = [node]
    bfs_idx = 0

    while bfs_idx < len(bfs):
        cur_node = bfs[bfs_idx]
        bfs_idx += 1

        children = exp.corr.graph[cur_node.name][cur_node.index].children

        for child_node in children:
            # Skip children that have no edge entry back to cur_node at this index.
            if not cur_node.index in exp.corr.edges[child_node.name][child_node.index][cur_node.name]:
                #print(f'\t CANT remove edge {cur_node.name}, {cur_node.index} <-> {child_node.name}, {child_node.index}')
                continue

            try:
                #print(f'\t Removing edge {cur_node.name}, {cur_node.index} <-> {child_node.name}, {child_node.index}')
                exp.corr.remove_edge(
                    child_node.name, child_node.index, cur_node.name, cur_node.index
                )
            except KeyError as e:
                print("Got an error", e)
                if allow_fails:
                    continue
                else:
                    raise e

            # If the child has no remaining present incoming edges it is now
            # orphaned — enqueue it so its own subtree gets disconnected too.
            remove_this = True
            for parent_of_child_name in exp.corr.edges[child_node.name][child_node.index]:
                for parent_of_child_index in exp.corr.edges[child_node.name][child_node.index][parent_of_child_name]:
                    if exp.corr.edges[child_node.name][child_node.index][parent_of_child_name][parent_of_child_index].present:
                        remove_this = False
                        break
                if not remove_this:
                    break

            if remove_this and child_node not in bfs:
                bfs.append(child_node)

def remove_node(exp, node):
    '''
    Method that removes node from model. Assumes children point towards
    the end of the residual stream and parents point towards the beginning.

    exp: A TLACDCExperiment object with a reverse top sorted graph
    node: A TLACDCInterpNode describing the node to remove
    '''
    # First mark and remove all edges pointing INTO the node...
    remove_edges = []
    for p_name in exp.corr.edges[node.name][node.index]:
        for p_idx in exp.corr.edges[node.name][node.index][p_name]:
            edge = exp.corr.edges[node.name][node.index][p_name][p_idx]
            remove_edges.append((node.name, node.index, p_name, p_idx))
            edge.present = False
    for n_name, n_idx, p_name, p_idx in remove_edges:
        #print(f'\t Removing edge {p_name}, {p_idx} <-> {n_name}, {n_idx}')
        exp.corr.remove_edge(
            n_name, n_idx, p_name, p_idx
        )
    # ...then remove all outgoing edges (and any newly-orphaned descendants)
    # via BFS. safe=False because we just removed the incoming edges ourselves.
    remove_redundant_node(exp, node, safe=False)

def find_attn_node(exp, layer, head):
    # The hook_result node for a single attention head (indexed [:, :, head]).
    return exp.corr.graph[f'blocks.{layer}.attn.hook_result'][TorchIndex([None, None, head])]

def find_attn_node_qkv(exp, layer, head):
    # All six q/k/v-related nodes for one head: the attn hook_q/k/v outputs
    # and the corresponding hook_q/k/v_input nodes.
    nodes = []
    for qkv in ['q', 'k', 'v']:
        nodes.append(exp.corr.graph[f'blocks.{layer}.attn.hook_{qkv}'][TorchIndex([None, None, head])])
        nodes.append(exp.corr.graph[f'blocks.{layer}.hook_{qkv}_input'][TorchIndex([None, None, head])])
    return nodes
def split_layers_and_heads(act: Tensor, model: HookedTransformer) -> Tensor:
    # Reshape a stacked-heads activation tensor from (layer*head, batch, seq, d_model)
    # to (layer, head, batch, seq, d_model).
    return einops.rearrange(act, '(layer head) batch seq d_model -> layer head batch seq d_model',
                            layer=model.cfg.n_layers,
                            head=model.cfg.n_heads)

# Hook-name filter used in node mode: LN1 outputs and per-head attention results.
hook_filter = lambda name: name.endswith("ln1.hook_normalized") or name.endswith("attn.hook_result")

def get_3_caches(model, clean_input, corrupted_input, metric, mode: Literal["node", "edge"]="node"):
    """
    Run the model three times and collect the caches needed for attribution
    patching: (1) forward activations on the clean input, (2) gradients of
    `metric` on the clean input, and (3) forward activations on the corrupted
    input. Which hook points are cached depends on `mode`.

    Returns (clean_cache, corrupted_cache, clean_grad_cache) as ActivationCaches.
    """
    # cache the activations and gradients of the clean inputs
    model.reset_hooks()
    clean_cache = {}

    def forward_cache_hook(act, hook):
        clean_cache[hook.name] = act.detach()

    # Edge mode caches every "outgoing" hook point (component outputs).
    edge_acdcpp_outgoing_filter = lambda name: name.endswith(("hook_result", "hook_mlp_out", "blocks.0.hook_resid_pre", "hook_q", "hook_k", "hook_v"))
    model.add_hook(hook_filter if mode == "node" else edge_acdcpp_outgoing_filter, forward_cache_hook, "fwd")

    clean_grad_cache = {}

    def backward_cache_hook(act, hook):
        clean_grad_cache[hook.name] = act.detach()

    # Edge mode's backward pass caches gradients at the "incoming" hook points
    # (component inputs and the final residual stream).
    incoming_ends = ["hook_q_input", "hook_k_input", "hook_v_input", f"blocks.{model.cfg.n_layers-1}.hook_resid_post"]
    if not model.cfg.attn_only:
        incoming_ends.append("hook_mlp_in")
    edge_acdcpp_back_filter = lambda name: name.endswith(tuple(incoming_ends + ["hook_q", "hook_k", "hook_v"]))
    model.add_hook(hook_filter if mode=="node" else edge_acdcpp_back_filter, backward_cache_hook, "bwd")
    value = metric(model(clean_input))


    value.backward()

    # cache the activations of the corrupted inputs
    model.reset_hooks()
    corrupted_cache = {}

    def forward_corrupted_cache_hook(act, hook):
        corrupted_cache[hook.name] = act.detach()

    model.add_hook(hook_filter if mode == "node" else edge_acdcpp_outgoing_filter, forward_corrupted_cache_hook, "fwd")
    model(corrupted_input)
    model.reset_hooks()

    # Wrap the raw dicts so callers get ActivationCache conveniences
    # (e.g. stack_head_results).
    clean_cache = ActivationCache(clean_cache, model)
    corrupted_cache = ActivationCache(corrupted_cache, model)
    clean_grad_cache = ActivationCache(clean_grad_cache, model)
    return clean_cache, corrupted_cache, clean_grad_cache

def get_nodes(correspondence):
    """
    Collect the set of pretty node names that participate in at least one
    present, non-placeholder edge of `correspondence`. Mirrors the filtering
    done by graphics_utils.show (qkv collapsing, self-loop removal).
    """
    nodes = set()
    for child_hook_name in correspondence.edges:
        for child_index in correspondence.edges[child_hook_name]:
            for parent_hook_name in correspondence.edges[child_hook_name][child_index]:
                for parent_index in correspondence.edges[child_hook_name][child_index][parent_hook_name]:
                    edge = correspondence.edges[child_hook_name][child_index][parent_hook_name][parent_index]

                    parent = correspondence.graph[parent_hook_name][parent_index]
                    child = correspondence.graph[child_hook_name][child_index]

                    parent_name = get_node_name(parent, show_full_index=False)
                    child_name = get_node_name(child, show_full_index=False)

                    # Skip per-q/k/v nodes and fold _q/_k/_v suffixes into the head name.
                    if any(qkv in child_name or qkv in parent_name for qkv in ['_q_', '_k_', '_v_']):
                        continue
                    parent_name = parent_name.replace("_q>", ">").replace("_k>", ">").replace("_v>", ">")
                    child_name = child_name.replace("_q>", ">").replace("_k>", ">").replace("_v>", ">")

                    if parent_name == child_name:
                        # Important this go after the qkv removal
                        continue

                    if edge.present and edge.edge_type != EdgeType.PLACEHOLDER:
                        #print(f'Edge from {parent_name=} to {child_name=}')
                        for node_name in [parent_name, child_name]:
                            nodes.add(node_name)
    return nodes
def acdc_nodes(model: HookedTransformer,
               clean_input: Tensor,
               corrupted_input: Tensor,
               metric: Callable[[Tensor], Tensor],
               threshold: float,
               exp: TLACDCExperiment,
               verbose: bool = False,
               attr_absolute_val: bool = False,
               mode: Literal["node", "edge", "edge_activation_patching"]="node",
               ) -> Dict: # TODO label this dict more precisely for the edge vs node methods
    '''
    Runs attribution-patching-based ACDC on the model, using the given metric and data.

    In "node" mode, whole attention heads whose attribution falls below
    `threshold` are pruned from exp.corr, and {(layer, head): attribution} is
    returned for the pruned heads. In "edge" mode, every candidate
    (parent, child) edge is scored and sub-threshold edges are removed;
    {(parent, child): attribution} is returned for ALL scored edges.
    ("edge_activation_patching" is unfinished.)

    Arguments:
        model: the model to prune
        clean_input: the input to the model that should elicit the behavior we're looking for
        corrupted_input: the input to the model that should elicit random behavior
        metric: the metric to use to compare the model's performance on the clean and corrupted inputs
        threshold: the threshold below which to prune
        exp: the TLACDCExperiment whose correspondence graph is mutated in place
        verbose: print details about what is (and is not) pruned
        attr_absolute_val: whether to take the absolute value of the attribution before thresholding
        mode: "node" to prune heads, "edge" to prune individual edges
    '''
    # get the 2 fwd and 1 bwd caches; cache "normalized" and "result" of attn layers
    clean_cache, corrupted_cache, clean_grad_cache = get_3_caches(model, clean_input, corrupted_input, metric, mode=mode)
    if mode == "node":
        # compute first-order Taylor approximation for each node to get the attribution
        clean_head_act = clean_cache.stack_head_results()
        corr_head_act = corrupted_cache.stack_head_results()
        clean_grad_act = clean_grad_cache.stack_head_results()

        # compute attributions of each node
        node_attr = (clean_head_act - corr_head_act) * clean_grad_act
        # separate layers and heads, sum over d_model (to complete the dot product), batch, and seq
        node_attr = split_layers_and_heads(node_attr, model).sum((2, 3, 4))

        if attr_absolute_val:
            node_attr = node_attr.abs()
        # Free the large caches before doing graph surgery.
        del clean_cache
        del clean_head_act
        del corrupted_cache
        del corr_head_act
        del clean_grad_cache
        del clean_grad_act
        t.cuda.empty_cache()
        # prune all nodes whose attribution is below the threshold
        should_prune = node_attr < threshold
        pruned_nodes_attr = {}

        for layer, head in itertools.product(range(model.cfg.n_layers), range(model.cfg.n_heads)):
            if should_prune[layer, head]:
                # REMOVING NODE
                if verbose:
                    print(f'PRUNING L{layer}H{head} with attribution {node_attr[layer, head]}')
                # Find the corresponding node in computation graph
                node = find_attn_node(exp, layer, head)
                if verbose:
                    print(f'\tFound node {node.name}')
                # Prune node
                remove_node(exp, node)
                if verbose:
                    print(f'\tRemoved node {node.name}')
                pruned_nodes_attr[(layer, head)] = node_attr[layer, head].item()

                # REMOVING QKV: also drop the head's q/k/v hook nodes and inputs
                qkv_nodes = find_attn_node_qkv(exp, layer, head)
                for node in qkv_nodes:
                    remove_node(exp, node)
        return pruned_nodes_attr

    elif mode.startswith("edge"):
        # Setup the downstream components whose incoming edges can be scored
        relevant_nodes: List = [node for node in exp.corr.nodes() if node.incoming_edge_type in [EdgeType.ADDITION, EdgeType.DIRECT_COMPUTATION]]
        results: Dict[Tuple[ModelComponent, ModelComponent], float] = {} # We use a list of floats as we may be splitting by position

        for relevant_node in tqdm(relevant_nodes, desc="Edge pruning"): # TODO ideally we should batch compute things in this loop
            parents = set([ModelComponent(hook_point_name=node.name, index=node.index, incoming_edge_type=str(node.incoming_edge_type)) for node in relevant_node.parents])
            downstream_component = ModelComponent(hook_point_name=relevant_node.name, index=relevant_node.index, incoming_edge_type=str(relevant_node.incoming_edge_type))
            for parent in parents:
                if "." in parent.hook_point_name and "." in downstream_component.hook_point_name: # hook_embed and hook_pos_embed have no "." but should always be connected anyway
                    upstream_layer = int(parent.hook_point_name.split(".")[1])
                    downstream_layer = int(downstream_component.hook_point_name.split(".")[1])
                    if upstream_layer > downstream_layer:
                        continue
                    if upstream_layer == downstream_layer and (parent.hook_point_name.endswith("mlp_out") or downstream_component.hook_point_name.endswith(("q_input", "k_input", "v_input"))):
                        # Other cases where upstream is actually after downstream!
                        continue

                if mode == "edge_activation_patching":
                    # Compute the activation patching for this current node
                    # TODO: unfinished — the hook is defined but never attached.

                    def my_current_hook(act, hook):
                        act[downstream_component.index] = corrupted_cache[parent.hook_point_name][parent.index.as_index]
                        return act

                    # Get metric from model, while adding the hook at downstream_component.name

                else:
                    # BUGFIX: this print was unconditional in the original, spamming
                    # one line per candidate edge; gate it on verbose like the
                    # matching NOT PRUNING print below.
                    if verbose:
                        print(f'Pruning {parent=} {downstream_component=}')
                    # For ADDITION edges the parent's output flows directly into the
                    # child, so diff the parent activation; for DIRECT_COMPUTATION
                    # edges diff the child's own input.
                    fwd_cache_hook_name = parent.hook_point_name if downstream_component.incoming_edge_type == str(EdgeType.ADDITION) else downstream_component.hook_point_name
                    fwd_cache_index = parent.index if downstream_component.incoming_edge_type == str(EdgeType.ADDITION) else downstream_component.index
                    # First-order attribution: grad_child . (clean - corrupted)
                    current_result = (clean_grad_cache[downstream_component.hook_point_name][downstream_component.index.as_index] * (clean_cache[fwd_cache_hook_name][fwd_cache_index.as_index] - corrupted_cache[fwd_cache_hook_name][fwd_cache_index.as_index])).sum()

                    if attr_absolute_val:
                        current_result = current_result.abs()
                    results[parent, downstream_component] = current_result.item()
                    # for position in exp.positions: # TODO add this back in!

                if mode == "edge_activation_patching":
                    pass

                else:
                    edge_tuple = (downstream_component.hook_point_name, downstream_component.index, parent.hook_point_name, parent.index)
                    should_prune = current_result < threshold

                    if should_prune:
                        exp.corr.edges[edge_tuple[0]][edge_tuple[1]][edge_tuple[2]][edge_tuple[3]].present = False
                        exp.corr.remove_edge(*edge_tuple)

                    else:
                        if verbose: # Putting this here since tons of things get pruned when doing edges!
                            print(f'NOT PRUNING {parent=} {downstream_component=} with attribution {current_result}')
        t.cuda.empty_cache()
        return results

    else:
        raise Exception(f"Mode {mode} not supported")