├── .gitattributes ├── .gitconfig ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── acdc ├── TLACDCCorrespondence.py ├── TLACDCEdge.py ├── TLACDCExperiment.py ├── TLACDCInterpNode.py ├── __init__.py ├── acdc_graphics.py ├── acdc_utils.py ├── docstring │ ├── __init__.py │ ├── prompts.py │ └── utils.py ├── global_cache.py ├── greaterthan │ ├── __init__.py │ └── utils.py ├── induction │ ├── __init__.py │ └── utils.py ├── ioi │ ├── __init__.py │ ├── ioi_dataset.py │ └── utils.py ├── logic_gates │ ├── __init__.py │ └── utils.py ├── main.py └── tracr_task │ ├── __init__.py │ └── utils.py ├── assets └── acdc_finds_subgraph.png ├── experiments ├── all_roc_plots.py ├── code_for_or_gate_figure.ipynb ├── collect_data.py ├── docstring_results.py ├── induction_results.py ├── launch_abstract.py ├── launch_abstract_old.py ├── launch_all_sixteen_heads.py ├── launch_docstring.py ├── launch_induction.py ├── launch_sixteen_heads.py ├── launch_spreadsheet.py ├── launcher.py └── results │ ├── .gitignore │ ├── auc_tables.tex │ ├── canonical_circuits │ ├── .gitignore │ ├── Makefile │ ├── greaterthan │ │ ├── Makefile │ │ └── layout.gv │ └── ioi │ │ ├── .gitignore │ │ ├── Makefile │ │ └── layout.gv │ ├── plots │ └── .gitignore │ └── plots_data │ ├── 16h-docstring-docstring_metric-False-0.json │ ├── 16h-docstring-docstring_metric-False-1.json │ ├── 16h-docstring-docstring_metric-True-0.json │ ├── 16h-docstring-docstring_metric-True-1.json │ ├── 16h-docstring-kl_div-False-0.json │ ├── 16h-docstring-kl_div-False-1.json │ ├── 16h-docstring-kl_div-True-0.json │ ├── 16h-docstring-kl_div-True-1.json │ ├── Makefile │ ├── acdc-docstring-docstring_metric-False-0.json │ ├── acdc-docstring-docstring_metric-False-1.json │ ├── acdc-docstring-kl_div-False-0.json │ ├── acdc-docstring-kl_div-False-1.json │ ├── acdc-greaterthan-greaterthan-False-0.json │ ├── acdc-greaterthan-greaterthan-True-0.json │ ├── acdc-greaterthan-kl_div-False-0.json │ ├── acdc-greaterthan-kl_div-True-0.json │ ├── acdc-ioi-kl_div-False-0.json │ ├── acdc-ioi-kl_div-False-1.json │ ├── acdc-ioi-logit_diff-False-0.json │ ├── acdc-ioi-logit_diff-False-1.json │ ├── acdc-tracr-proportion-l2-False-0.json │ ├── acdc-tracr-proportion-l2-False-1.json │ ├── acdc-tracr-reverse-l2-False-0.json │ ├── acdc-tracr-reverse-l2-False-1.json │ ├── generate_makefile.py │ ├── sp-docstring-docstring_metric-False-0.json │ ├── sp-docstring-docstring_metric-False-1.json │ ├── sp-docstring-docstring_metric-True-0.json │ ├── sp-docstring-docstring_metric-True-1.json │ ├── sp-docstring-kl_div-False-0.json │ ├── sp-docstring-kl_div-False-1.json │ ├── sp-docstring-kl_div-True-0.json │ └── sp-docstring-kl_div-True-1.json ├── ims ├── current_paper_induction_json.json ├── current_paper_induction_zero.json ├── induction_json.json ├── make_jsons.py └── my_new_plotly_graph.json ├── notebooks ├── _converted │ └── .gitignore ├── auc_tables.py ├── colabs │ ├── ACDC_Editing_Edges_Demo.ipynb │ ├── ACDC_Implementation_Demo.ipynb │ └── ACDC_Main_Demo.ipynb ├── convert_to_ipynb.sh ├── df_plots_data.py ├── easier_roc_plot.py ├── editing_edges.py ├── emacs_plotly_render.py ├── implementation_demo.py ├── make_plotly_plots.py ├── minimal_acdc_node_roc.py ├── pareto_plot.py └── roc_plot_generator.py ├── poetry.lock ├── pyproject.toml ├── subnetwork_probing ├── README.md ├── create_reset_networks.py ├── launch_grid_fill.py ├── train.py └── transformer_lens │ ├── .pre-commit-config.yaml │ ├── Attribution_Patching_Demo.ipynb │ ├── Exploratory_Analysis_Demo.ipynb │ ├── Interactive Neuroscope.ipynb │ ├── 
LICENSE │ ├── Main_Demo.ipynb │ ├── No_Position_Experiment.ipynb │ ├── Old_Demo.ipynb │ ├── README.md │ ├── Tracr_to_Transformer_Lens_Demo.ipynb │ ├── activation_patching_in_TL_demo.py.ipynb │ ├── easy_transformer │ └── __init__.py │ ├── further_comments.md │ ├── poetry.lock │ ├── pyproject.toml │ ├── setup.py │ ├── transformer_lens │ ├── ActivationCache.py │ ├── FactoredMatrix.py │ ├── HookedTransformer.py │ ├── HookedTransformerConfig.py │ ├── __init__.py │ ├── components.py │ ├── evals.py │ ├── hook_points.py │ ├── ioi_dataset.py │ ├── loading_from_pretrained.py │ ├── make_docs.py │ ├── model_properties_table.md │ ├── past_key_value_caching.py │ ├── patching.py │ ├── torchtyping_helper.py │ ├── train.py │ └── utils.py │ └── typing_demo.py └── tests ├── acdc ├── test_acdc.py └── test_greaterthan.py └── subnetwork_probing ├── test_count_nodes.py └── test_sp_launch.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb merge=nbdev-merge 2 | *.pt filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitconfig: -------------------------------------------------------------------------------- 1 | # Generated by nbdev_install_hooks 2 | # 3 | # If you need to disable this instrumentation do: 4 | # git config --local --unset include.path 5 | # 6 | # To restore: 7 | # git config --local include.path ../.gitconfig 8 | # 9 | [merge "nbdev-merge"] 10 | name = resolve conflicts with nbdev_fix 11 | driver = nbdev_merge %O %A %B %P 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | out.txt 2 | out2.txt 3 | out1.txt 4 | acdc.egg-info/ 5 | *.sage.py 6 | *.pt 7 | *.pkl 8 | better*.json 9 | histories/ 10 | transformer_lens/media/ 11 | profile*.json 12 | *.png 13 | *.gv 14 | !experiments/results/canonical_circuits/*/layout.gv 15 | __pycache__ 16 | Testing_Notebook.ipynb 17 | 18 | scratch.py 19 | transformer_lens/scratch.py 20 | 21 | .vscode/ 22 | .idea/ 23 | wandb* 24 | transformer_lens\.egg* 25 | MANIFEST.in 26 | settings.ini 27 | _proc 28 | core.py 29 | nbs 30 | _modidx.py 31 | .ipynb_checkpoints 32 | env 33 | dist/ 34 | docs/build 35 | .coverage 36 | 37 | # don't really know what these are.. 
38 | .hypothesis/ 39 | arthur 40 | et_model_state_dict_the_(1).pt 41 | gpt2_hypothesis_tree_Sun Jan 29 23:24:51 2023.png 42 | gpt2_hypothesis_tree_Tue Jan 31 09:50:47 2023.png 43 | gpt2_hypothesis_tree_Tue Jan 31 10:04:54 2023.png 44 | gpt2_hypothesis_tree_Tue Jan 31 10:07:44 2023.png 45 | gpt2_hypothesis_tree_Tue Jan 31 10:34:57 2023.png 46 | gpt2_hypothesis_tree_Tue Jan 31 10:45:48 2023.png 47 | gpt2_hypothesis_tree_Tue Jan 31 11:05:51 2023.png 48 | gpt2_hypothesis_tree_Tue Jan 31 11:40:10 2023.png 49 | gpt2_hypothesis_tree_Tue Jan 31 11:45:14 2023.png 50 | hypothesis_tree.dot 51 | openwebtext-10k.jsonl 52 | openwebtext-10k.jsonl.gz 53 | profile.json 54 | profile2.json 55 | profile_trash.json 56 | pts/ 57 | .coverage* 58 | .Ds_Store 59 | .pylintrc 60 | **/.DS_Store 61 | 62 | ob-jupyter 63 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04 2 | LABEL org.opencontainers.image.source=https://github.com/arthurconmy/Automatic-Circuit-Discovery 3 | ENV DEBIAN_FRONTEND noninteractive 4 | 5 | RUN apt-get update -q \ 6 | && apt-get install -y --no-install-recommends \ 7 | wget git git-lfs \ 8 | python3 python3-dev python3-pip python3-venv python3-setuptools python-is-python3 \ 9 | libgl1-mesa-glx graphviz graphviz-dev \ 10 | && apt-get clean \ 11 | && rm -rf /var/lib/apt/lists/* 12 | 13 | # This venv only holds Poetry and its dependencies. They are isolated from the main project dependencies. 14 | ENV POETRY_HOME="/opt/poetry" 15 | RUN python3 -m venv $POETRY_HOME \ 16 | # Here we use the pip inside $POETRY_HOME but afterwards we should not 17 | && "$POETRY_HOME/bin/pip" install poetry==1.4.2 \ 18 | && rm -rf "${HOME}/.cache" 19 | ENV POETRY="${POETRY_HOME}/bin/poetry" 20 | 21 | WORKDIR "/Automatic-Circuit-Discovery" 22 | COPY --chown=root:root pyproject.toml poetry.lock ./ 23 | 24 | # Don't create a virtualenv, the Docker container is already enough isolation 25 | RUN "$POETRY" config virtualenvs.create false \ 26 | # Install dependencies 27 | && "$POETRY" install --no-root --no-interaction "--only=main,dev" \ 28 | && rm -rf "${HOME}/.cache" 29 | 30 | # Copy whole repo 31 | COPY --chown=root:root . . 32 | # Abort if repo is dirty 33 | RUN if ! { [ -z "$(git status --porcelain --ignored=traditional)" ] \ 34 | ; }; then exit 1; fi 35 | 36 | # Finally install this package 37 | RUN "$POETRY" install --only-root --no-interaction 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 Arthur Conmy, Adrià Garriga-Alonso 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Python](https://img.shields.io/badge/python-3.8%2B-blue)]() [![Open Pull Requests](https://img.shields.io/github/issues-pr/ArthurConmy/Automatic-Circuit-Discovery.svg)](https://github.com/ArthurConmy/Automatic-Circuit-Discovery/pulls) 2 | 3 | # Automatic Circuit DisCovery 4 | 5 | ![](assets/acdc_finds_subgraph.png) 6 | 7 | This is the accompanying code to the paper ["Towards Automated Circuit Discovery for Mechanistic Interpretability"](https://arxiv.org/abs/2304.14997) (NeurIPS 2023 Spotlight). 8 | 9 | ⚠️ You may wish to use the repo Auto Circuit by @UFO-101 instead, as this codebase has many sharp edges. Nevertheless: 10 | 11 | * :zap: To run ACDC, see `acdc/main.py`, or this Colab notebook 12 | * :wrench: To see how to edit edges in computational graphs in models, see `notebooks/editing_edges.py` or this Colab notebook 13 | * :sparkle: To understand the low-level implementation of completely editable computational graphs, see this Colab notebook or `notebooks/implementation_demo.py` 14 | 15 | This library builds upon the abstractions (`HookPoint`s and standardised `HookedTransformer`s) from [TransformerLens](https://github.com/neelnanda-io/TransformerLens) :mag_right: 16 | 17 | ## Installation: 18 | 19 | First, install the system dependencies for either [Mac](#apple-mac-os-x) or [Linux](#penguin-ubuntu-linux). 20 | 21 | Then, you need Python 3.8+ and [Poetry](https://python-poetry.org/docs/) to install ACDC, like so 22 | 23 | ```bash 24 | git clone https://github.com/ArthurConmy/Automatic-Circuit-Discovery.git 25 | cd Automatic-Circuit-Discovery 26 | poetry env use 3.10 # Or be inside a conda or venv environment 27 | # Python 3.10 is recommended but use any Python version >= 3.8 28 | poetry install 29 | ``` 30 | 31 | ### System Dependencies 32 | 33 | #### :penguin: Ubuntu Linux 34 | 35 | ```bash 36 | sudo apt-get update && sudo apt-get install libgl1-mesa-glx graphviz build-essential graphviz-dev 37 | ``` 38 | 39 | You may also need `apt-get install python3.x-dev` where `x` is your Python version (also see [the issue](https://github.com/ArthurConmy/Automatic-Circuit-Discovery/issues/57) and [pygraphviz installation troubleshooting](https://pygraphviz.github.io/documentation/stable/install.html)) 40 | 41 | #### :apple: Mac OS X 42 | 43 | On Mac, you need to let pip (inside poetry) know about the path to the Graphviz libraries. 44 | 45 | ``` 46 | brew install graphviz 47 | export CFLAGS="-I$(brew --prefix graphviz)/include" 48 | export LDFLAGS="-L$(brew --prefix graphviz)/lib" 49 | ``` 50 | 51 | ### Reproducing results 52 | 53 | To reproduce the Pareto frontier of KL divergences against number of edges for ACDC runs, run `python experiments/launch_induction.py`. Similarly, `python experiments/launch_sixteen_heads.py` and `python subnetwork_probing/train.py` were used to generate individual data points for the other methods; see each script's CLI help. All three of these commands can produce wandb runs.
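For example, a typical set of invocations might look like the following (a sketch only — the flags shown are illustrative, so check each script's `--help` for the exact interface):

```bash
# One ACDC Pareto-frontier sweep for the induction task
python experiments/launch_induction.py

# A single sixteen-heads data point (flags as used in experiments/launch_all_sixteen_heads.py)
python experiments/launch_sixteen_heads.py --task=induction --metric=kl_div

# Subnetwork probing; run with --help to list the available options
python subnetwork_probing/train.py --help
```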
We use `notebooks/roc_plot_generator.py` to process data from wandb runs into JSON files (see `experiments/results/plots_data/Makefile` for the commands) and `notebooks/make_plotly_plots.py` to produce plots from these JSON files. 54 | 55 | ## Tests 56 | 57 | From the root directory, run 58 | 59 | ```bash 60 | pytest -vvv -m "not slow" 61 | ``` 62 | 63 | This will only select tests not marked as `slow`. These tests take a _long_ time, and are good to run occasionally, but 64 | not every time. 65 | 66 | You can run the slow tests with 67 | 68 | ``` bash 69 | pytest -s -m slow 70 | ``` 71 | 72 | ## Contributing 73 | 74 | We welcome issues where the code is unclear! 75 | 76 | If your PR affects the main demo, rerun 77 | ```bash 78 | chmod +x experiments/make_notebooks.sh 79 | ./experiments/make_notebooks.sh 80 | ``` 81 | to automatically turn the `main.py` into a working demo and check that no errors arise. It is essential that the notebooks converted here consist only of `#%% [markdown]` markdown-only cells, and `#%%` cells with code. 82 | 83 | ## Citing ACDC 84 | 85 | If you use ACDC, please reach out! You can reference the work as follows: 86 | 87 | ``` 88 | @inproceedings{conmy2023automated, 89 | title={Towards Automated Circuit Discovery for Mechanistic Interpretability}, 90 | author={Arthur Conmy and Augustine N. Mavor-Parker and Aengus Lynch and Stefan Heimersheim and Adri{\`a} Garriga-Alonso}, 91 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 92 | year={2023}, 93 | eprint={2304.14997}, 94 | archivePrefix={arXiv}, 95 | primaryClass={cs.LG} 96 | } 97 | ``` 98 | 99 | ## TODO 100 | 101 |
102 | Mostly finished TODO list 103 | 104 | [ x ] Make `TransformerLens` install be Neel's code not my PR 105 | 106 | [ x ] Add `hook_mlp_in` to `TransformerLens` and delete `hook_resid_mid` (and test to ensure no bad things?) 107 | 108 | [ x ] Delete `arthur-try-merge-tl` references from the repo 109 | 110 | [ x ] Make notebook on abstractions 111 | 112 | [ ? ] Fix huge edge sizes in Induction Main example and change that occurred 113 | 114 | [ x ] Find a better way to deal with the versioning on the Colabs installs... 115 | 116 | [ ] Neuron-level experiments 117 | 118 | [ ] Position-level experiments 119 | 120 | [ ] Edge gradient descent experiments 121 | 122 | [ ] Implement the circuit breaking paper 123 | 124 | [ x ] `tracr` and other dependencies better managed 125 | 126 | [ ? ] Make SP tests work (lots outdated so skipped) - and check SubnetworkProbing installs properly (no __init__.pys !!!) 127 | 128 | [ ? ] Make the 9 tests also failing on TransformerLens-main pass 129 | 130 | [ x ] Remove Codebase under construction 131 | 132 |
133 | -------------------------------------------------------------------------------- /acdc/TLACDCEdge.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import defaultdict 3 | from enum import Enum 4 | from typing import Optional, List 5 | 6 | 7 | class EdgeType(Enum): 8 | """ 9 | Property of edges in the computational graph - either 10 | 11 | ADDITION: the child (hook_name, index) is a sum of the parent (hook_name, index)s 12 | DIRECT_COMPUTATION: the *single* child is a function of the parent, and only of the parent (e.g the value hooked by hook_q is a function of what hook_q_input saves). 13 | PLACEHOLDER: like DIRECT_COMPUTATION, but where there are generally multiple parents. Here in ACDC we just include these edges by default when we find them. Explained below. 14 | 15 | Q: Why do we do this? 16 | 17 | There are two answers to this question: A1 is an interactive notebook, see this Colab notebook, which is in this repo at notebooks/implementation_demo.py. A2 is an answer that is written here below, but probably not as clear as A1 (though shorter). 18 | 19 | A2: We need something inside TransformerLens to represent the edges of a computational graph. 20 | The object we choose is pairs (hook_name, index). For example the output of Layer 11 Heads is a hook (blocks.11.attn.hook_result) and to specify the 3rd head we add the index [:, :, 3]. Then we can build a computational graph on these! 21 | 22 | However, when we do ACDC there turn out to be two conflicting things "removing edges" wants to do: 23 | i) for things in the residual stream, we want to remove the sum of the effects from previous hooks 24 | ii) for things that are not linear we want to *recompute* e.g the result inside the hook 25 | blocks.11.attn.hook_result from a corrupted Q and normal K and V 26 | 27 | The easiest way I thought of to reconcile these different cases, while also having a connected computational graph, is to have three types of edges: addition for the residual case, direct computation for easy cases where we can just replace hook_q with a cached value when we e.g cut it off from hook_q_input, and placeholder to make the graph connected (when hook_result is connected to hook_q and hook_k and hook_v)""" 28 | 29 | ADDITION = 0 30 | DIRECT_COMPUTATION = 1 31 | PLACEHOLDER = 2 32 | 33 | def __eq__(self, other): 34 | """Necessary because of extremely frustrating error that arises with load_ext autoreload (because this uses importlib under the hood: https://stackoverflow.com/questions/66458864/enum-comparison-become-false-after-reloading-module)""" 35 | 36 | assert isinstance(other, EdgeType) 37 | return self.value == other.value 38 | 39 | 40 | class Edge: 41 | def __init__( 42 | self, 43 | edge_type: EdgeType, 44 | present: bool = True, 45 | effect_size: Optional[float] = None, 46 | ): 47 | self.edge_type = edge_type 48 | self.present = present 49 | self.effect_size = effect_size 50 | 51 | def __repr__(self) -> str: 52 | return f"Edge({self.edge_type}, {self.present})" 53 | 54 | class TorchIndex: 55 | """There is not a clean bijection between things we 56 | want in the computational graph, and things that are hooked 57 | (e.g hook_result covers all heads in a layer) 58 | 59 | `TorchIndex`s are essentially indices that say which part of the tensor is being affected.
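Under the hood (a minimal illustration): `TorchIndex([None, None, 3]).as_index` is the tuple `(slice(None), slice(None), 3)`, so `tensor[TorchIndex([None, None, 3]).as_index]` selects the same elements as `tensor[:, :, 3]`.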
60 | 61 | EXAMPLES: Initialise [:, :, 3] with TorchIndex([None, None, 3]) and [:] with TorchIndex([None]) 62 | 63 | Also we want to be able to call e.g `my_dictionary[my_torch_index]` hence the hashable tuple stuff 64 | 65 | Note: ideally this would be integrated with transformer_lens.utils.Slice in future; they are accomplishing similar but different things""" 66 | 67 | def __init__( 68 | self, 69 | list_of_things_in_tuple: List, 70 | ): 71 | # check correct types 72 | for arg in list_of_things_in_tuple: 73 | if type(arg) in [type(None), int]: 74 | continue 75 | else: 76 | assert isinstance(arg, list) 77 | assert all([type(x) == int for x in arg]) 78 | 79 | # make an object that can be indexed into a tensor 80 | self.as_index = tuple([slice(None) if x is None else x for x in list_of_things_in_tuple]) 81 | 82 | # make an object that can be hashed (so used as a dictionary key) 83 | self.hashable_tuple = tuple(list_of_things_in_tuple) 84 | 85 | def __hash__(self): 86 | return hash(self.hashable_tuple) 87 | 88 | def __eq__(self, other): 89 | return self.hashable_tuple == other.hashable_tuple 90 | 91 | # some graphics things 92 | 93 | def __repr__(self, use_actual_colon=True) -> str: # graphviz (an old version of the library) used to dislike actual colons in strings, but this shouldn't be an issue anymore 94 | ret = "[" 95 | for idx, x in enumerate(self.hashable_tuple): 96 | if idx > 0: 97 | ret += ", " 98 | if x is None: 99 | ret += ":" if use_actual_colon else "COLON" 100 | elif type(x) == int: 101 | ret += str(x) 102 | else: 103 | raise NotImplementedError(x) 104 | ret += "]" 105 | return ret 106 | 107 | def graphviz_index(self, use_actual_colon=True) -> str: 108 | return self.__repr__(use_actual_colon=use_actual_colon) -------------------------------------------------------------------------------- /acdc/TLACDCInterpNode.py: -------------------------------------------------------------------------------- 1 | from acdc.TLACDCEdge import ( 2 | TorchIndex, 3 | Edge, 4 | EdgeType, 5 | ) # these introduce several important classes !!! 6 | from typing import List, Dict, Optional, Tuple, Union, Set, Callable, TypeVar, Iterable, Any 7 | 8 | class TLACDCInterpNode: 9 | """Represents one node in the computational graph, similar to ACDCInterpNode from the rust_circuit code 10 | 11 | But WARNING this has nodes closer to the input tokens as *parents* of nodes closer to the output tokens, the opposite of the rust_circuit code 12 | 13 | Params: 14 | name: name of the node 15 | index: the index of the tensor that this node represents 16 | incoming_edge_type: how we deal with this node when we bump into it as a parent of another node. Addition: it's summed to make up the child. Direct_computation: it's the sole node used to compute the child.
Placeholder: its edge to the child is included by default whenever we find it (see EdgeType.PLACEHOLDER).""" 17 | 18 | def __init__(self, name: str, index: TorchIndex, incoming_edge_type: EdgeType): 19 | 20 | self.name = name 21 | self.index = index 22 | 23 | self.parents: List["TLACDCInterpNode"] = [] 24 | self.children: List["TLACDCInterpNode"] = [] 25 | 26 | self.incoming_edge_type = incoming_edge_type 27 | 28 | def _add_child(self, child_node: "TLACDCInterpNode"): 29 | """Use the method on TLACDCCorrespondence instead of this one""" 30 | self.children.append(child_node) 31 | 32 | def _add_parent(self, parent_node: "TLACDCInterpNode"): 33 | """Use the method on TLACDCCorrespondence instead of this one""" 34 | self.parents.append(parent_node) 35 | 36 | def __repr__(self): 37 | return f"TLACDCInterpNode({self.name}, {self.index})" 38 | 39 | def __str__(self) -> str: 40 | return f"{self.name}{self.index}" 41 | 42 | # ------------------ 43 | # some munging utils 44 | # ------------------ 45 | 46 | def parse_interpnode(s: str) -> TLACDCInterpNode: 47 | try: 48 | name, idx = s.split("[") 49 | name = name.replace("hook_resid_mid", "hook_mlp_in") 50 | try: 51 | idx = int(idx[-3:-1]) 52 | except (ValueError, IndexError): 53 | try: 54 | idx = int(idx[-2]) 55 | except (ValueError, IndexError): 56 | idx = None 57 | return TLACDCInterpNode(name, TorchIndex([None, None, idx]) if idx is not None else TorchIndex([None]), EdgeType.ADDITION) 58 | 59 | except Exception as e: 60 | print(s, e) 61 | raise e 62 | 63 | def heads_to_nodes_to_mask(heads: List[Tuple[int, int]], return_dict=False): 64 | nodes_to_mask_strings = [ 65 | f"blocks.{layer_idx}{'.attn' if not inputting else ''}.hook_{letter}{'_input' if inputting else ''}[COL, COL, {head_idx}]" 66 | # for layer_idx in range(model.cfg.n_layers) 67 | # for head_idx in range(model.cfg.n_heads) 68 | for layer_idx, head_idx in heads 69 | for letter in ["q", "k", "v"] 70 | for inputting in [True, False] 71 | ] 72 | nodes_to_mask_strings.extend([ 73 | f"blocks.{layer_idx}.attn.hook_result[COL, COL, {head_idx}]" 74 | for layer_idx, head_idx in heads 75 | ]) 76 | 77 | if return_dict: 78 | return {s: parse_interpnode(s) for s in nodes_to_mask_strings} 79 | 80 | else: 81 | return [parse_interpnode(s) for s in nodes_to_mask_strings]
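# Usage sketch (an illustrative head list, not part of the module):
#
#     nodes = heads_to_nodes_to_mask([(9, 6)])
#
# builds the q/k/v and q/k/v-input nodes plus the hook_result node for head 9.6;
# e.g. parse_interpnode("blocks.9.attn.hook_result[COL, COL, 6]") gives
# TLACDCInterpNode(blocks.9.attn.hook_result, [:, :, 6]).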
-------------------------------------------------------------------------------- /acdc/__init__.py: -------------------------------------------------------------------------------- 1 | def check_transformer_lens_version(): 2 | """Test that your TransformerLens version is up-to-date for ACDC 3 | by checking that `hook_mlp_in`s exist""" 4 | 5 | from transformer_lens.HookedTransformerConfig import HookedTransformerConfig 6 | 7 | cfg = HookedTransformerConfig.from_dict( 8 | { 9 | "n_layers": 1, 10 | "d_model": 1, 11 | "n_ctx": 1, 12 | "d_head": 1, 13 | "act_fn": "gelu", 14 | "d_vocab": 0, 15 | } 16 | ) 17 | 18 | from transformer_lens.HookedTransformer import HookedTransformer 19 | mini_trans = HookedTransformer(cfg) 20 | 21 | mini_trans.blocks[0].hook_mlp_in # try and access the hook_mlp_in: if this fails, your TL is not sufficiently up-to-date 22 | 23 | check_transformer_lens_version() -------------------------------------------------------------------------------- /acdc/docstring/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArthurConmy/Automatic-Circuit-Discovery/bc99ace817974b5584b7ee203d596a8e2bbcd399/acdc/docstring/__init__.py -------------------------------------------------------------------------------- /acdc/global_cache.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Union, Tuple, Literal, Dict 3 | from collections import OrderedDict 4 | 5 | 6 | class GlobalCache: # this dict stores the activations from the forward pass 7 | """Class for managing several caches for passing activations around""" 8 | 9 | def __init__(self, device: Union[str, Tuple[str, str]] = "cuda"): 10 | # TODO find a way to make the device propagate when we call .to on the p 11 | # TODO make it essential first key is a str, second a TorchIndex, third a str 12 | 13 | if isinstance(device, str): 14 | device = (device, device) 15 | 16 | self.online_cache = OrderedDict() 17 | self.corrupted_cache = OrderedDict() 18 | self.device: Tuple[str, str] = device 19 | 20 | 21 | def clear(self, just_first_cache=False): 22 | 23 | if not just_first_cache: 24 | self.online_cache = OrderedDict() 25 | else: 26 | raise NotImplementedError() 27 | self.__init__(self.device) # re-create both caches on the stored devices 28 | 29 | import gc 30 | gc.collect() 31 | torch.cuda.empty_cache()
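# Usage sketch (illustrative; the key and shape are made up):
#
#     cache = GlobalCache(device="cpu")
#     cache.online_cache["blocks.0.hook_resid_pre"] = torch.zeros(1, 3, 8)
#     cache.to("cpu", which_caches="online")  # move the stored activations
#     cache.clear()  # drop the activations and free CUDA memory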
32 | 33 | def to(self, device, which_caches: Literal["online", "corrupted", "all"]="all"): 34 | 35 | caches = [] 36 | if which_caches != "corrupted": # i.e "online" or "all": move the online cache 37 | self.device = (device, self.device[1]) 38 | caches.append(self.online_cache) 39 | if which_caches != "online": # i.e "corrupted" or "all": move the corrupted cache 40 | self.device = (self.device[0], device) 41 | caches.append(self.corrupted_cache) 42 | 43 | # move all the stored activations (the caches are mutable, so this updates them in place) 44 | for cache in caches: 45 | for name in list(cache.keys()): 46 | cache[name] = cache[name].to(device) 47 | 48 | return self -------------------------------------------------------------------------------- /acdc/greaterthan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArthurConmy/Automatic-Circuit-Discovery/bc99ace817974b5584b7ee203d596a8e2bbcd399/acdc/greaterthan/__init__.py -------------------------------------------------------------------------------- /acdc/induction/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArthurConmy/Automatic-Circuit-Discovery/bc99ace817974b5584b7ee203d596a8e2bbcd399/acdc/induction/__init__.py -------------------------------------------------------------------------------- /acdc/induction/utils.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from functools import partial 3 | from acdc.docstring.utils import AllDataThings 4 | import wandb 5 | import os 6 | from collections import defaultdict 7 | import pickle 8 | import torch 9 | import huggingface_hub 10 | import datetime 11 | from typing import Dict, Callable 12 | import torch 13 | import random 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from typing import ( 17 | List, 18 | Tuple, 19 | Dict, 20 | Any, 21 | Optional, 22 | ) 23 | import warnings 24 | import networkx as nx 25 | from acdc.acdc_utils import ( 26 | MatchNLLMetric, 27 | make_nd_dict, 28 | shuffle_tensor, 29 | ) 30 | 31 | from acdc.TLACDCEdge import ( 32 | TorchIndex, 33 | Edge, 34 | EdgeType, 35 | ) # these introduce several important classes !!! 36 | from transformer_lens import HookedTransformer 37 | from acdc.acdc_utils import kl_divergence, negative_log_probs 38 | 39 | def get_model(device): 40 | tl_model = HookedTransformer.from_pretrained( 41 | "redwood_attn_2l", # load Redwood's model 42 | center_writing_weights=False, # these are needed as this model is a Shortformer; this is a technical detail 43 | center_unembed=False, 44 | fold_ln=False, 45 | device=device, 46 | ) 47 | 48 | # standard ACDC options 49 | tl_model.set_use_attn_result(True) 50 | tl_model.set_use_split_qkv_input(True) 51 | return tl_model
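# Usage sketch tying together the helpers in this file (argument values are
# illustrative, and it is assumed, as elsewhere in ACDC, that the validation
# metric is called on the model's logits):
#
#     things = get_all_induction_things(num_examples=50, seq_len=300, device="cuda")
#     logits = things.tl_model(things.validation_data)
#     val = things.validation_metric(logits)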
52 | 53 | def get_validation_data(num_examples=None, seq_len=None, device=None): 54 | validation_fname = huggingface_hub.hf_hub_download( 55 | repo_id="ArthurConmy/redwood_attn_2l", filename="validation_data.pt" 56 | ) 57 | validation_data = torch.load(validation_fname, map_location=device).long() 58 | 59 | if num_examples is None: 60 | return validation_data 61 | else: 62 | return validation_data[:num_examples, :seq_len] 63 | 64 | def get_good_induction_candidates(num_examples=None, seq_len=None, device=None): 65 | """Not needed?""" 66 | good_induction_candidates_fname = huggingface_hub.hf_hub_download( 67 | repo_id="ArthurConmy/redwood_attn_2l", filename="good_induction_candidates.pt" 68 | ) 69 | good_induction_candidates = torch.load(good_induction_candidates_fname, map_location=device) 70 | 71 | if num_examples is None: 72 | return good_induction_candidates 73 | else: 74 | return good_induction_candidates[:num_examples, :seq_len] 75 | 76 | def get_mask_repeat_candidates(num_examples=None, seq_len=None, device=None): 77 | mask_repeat_candidates_fname = huggingface_hub.hf_hub_download( 78 | repo_id="ArthurConmy/redwood_attn_2l", filename="mask_repeat_candidates.pkl" 79 | ) 80 | mask_repeat_candidates = torch.load(mask_repeat_candidates_fname, map_location=device) 81 | mask_repeat_candidates.requires_grad = False 82 | 83 | if num_examples is None: 84 | return mask_repeat_candidates 85 | else: 86 | return mask_repeat_candidates[:num_examples, :seq_len] 87 | 88 | 89 | def get_all_induction_things(num_examples, seq_len, device, data_seed=42, metric="kl_div", return_one_element=True) -> AllDataThings: 90 | tl_model = get_model(device=device) 91 | 92 | validation_data_orig = get_validation_data(device=device) 93 | mask_orig = get_mask_repeat_candidates(num_examples=None, device=device) # None so we get all 94 | assert validation_data_orig.shape == mask_orig.shape 95 | 96 | assert seq_len <= validation_data_orig.shape[1]-1 97 | 98 | validation_slice = slice(0, num_examples) 99 | validation_data = validation_data_orig[validation_slice, :seq_len].contiguous() 100 | validation_labels = validation_data_orig[validation_slice, 1:seq_len+1].contiguous() 101 | validation_mask = mask_orig[validation_slice, :seq_len].contiguous() 102 | 103 | validation_patch_data = shuffle_tensor(validation_data, seed=data_seed).contiguous() 104 | 105 | test_slice = slice(num_examples, num_examples*2) 106 | test_data = validation_data_orig[test_slice, :seq_len].contiguous() 107 | test_labels = validation_data_orig[test_slice, 1:seq_len+1].contiguous() 108 | test_mask = mask_orig[test_slice, :seq_len].contiguous() 109 | 110 | # data_seed+1: different shuffling 111 | test_patch_data = shuffle_tensor(test_data, seed=data_seed + 1).contiguous() 112 | 113 | with torch.no_grad(): 114 | base_val_logprobs = F.log_softmax(tl_model(validation_data), dim=-1).detach() 115 | base_test_logprobs = F.log_softmax(tl_model(test_data), dim=-1).detach() 116 | 117 | if metric ==
"kl_div": 118 | validation_metric = partial( 119 | kl_divergence, 120 | base_model_logprobs=base_val_logprobs, 121 | mask_repeat_candidates=validation_mask, 122 | last_seq_element_only=False, 123 | return_one_element=return_one_element, 124 | ) 125 | elif metric == "nll": 126 | validation_metric = partial( 127 | negative_log_probs, 128 | labels=validation_labels, 129 | mask_repeat_candidates=validation_mask, 130 | last_seq_element_only=False, 131 | ) 132 | elif metric == "match_nll": 133 | validation_metric = MatchNLLMetric( 134 | labels=validation_labels, base_model_logprobs=base_val_logprobs, mask_repeat_candidates=validation_mask, 135 | last_seq_element_only=False, 136 | ) 137 | else: 138 | raise ValueError(f"Unknown metric {metric}") 139 | 140 | test_metrics = { 141 | "kl_div": partial( 142 | kl_divergence, 143 | base_model_logprobs=base_test_logprobs, 144 | mask_repeat_candidates=test_mask, 145 | last_seq_element_only=False, 146 | ), 147 | "nll": partial( 148 | negative_log_probs, 149 | labels=test_labels, 150 | mask_repeat_candidates=test_mask, 151 | last_seq_element_only=False, 152 | ), 153 | "match_nll": MatchNLLMetric( 154 | labels=test_labels, base_model_logprobs=base_test_logprobs, mask_repeat_candidates=test_mask, 155 | last_seq_element_only=False, 156 | ), 157 | } 158 | return AllDataThings( 159 | tl_model=tl_model, 160 | validation_metric=validation_metric, 161 | validation_data=validation_data, 162 | validation_labels=validation_labels, 163 | validation_mask=validation_mask, 164 | validation_patch_data=validation_patch_data, 165 | test_metrics=test_metrics, 166 | test_data=test_data, 167 | test_labels=test_labels, 168 | test_mask=test_mask, 169 | test_patch_data=test_patch_data, 170 | ) 171 | 172 | 173 | def one_item_per_batch(toks_int_values, toks_int_values_other, mask_rep, base_model_logprobs, kl_take_mean=True): 174 | """Returns each instance of induction as its own batch idx""" 175 | 176 | end_positions = [] 177 | batch_size, seq_len = toks_int_values.shape 178 | new_tensors = [] 179 | 180 | toks_int_values_other_batch_list = [] 181 | new_base_model_logprobs_list = [] 182 | 183 | for i in range(batch_size): 184 | for j in range(seq_len - 1): # -1 because we don't know what follows the last token so can't calculate losses 185 | if mask_rep[i, j]: 186 | end_positions.append(j) 187 | new_tensors.append(toks_int_values[i].cpu().clone()) 188 | toks_int_values_other_batch_list.append(toks_int_values_other[i].cpu().clone()) 189 | new_base_model_logprobs_list.append(base_model_logprobs[i].cpu().clone()) 190 | 191 | toks_int_values_other_batch = torch.stack(toks_int_values_other_batch_list).to(toks_int_values.device).clone() 192 | return_tensor = torch.stack(new_tensors).to(toks_int_values.device).clone() 193 | end_positions_tensor = torch.tensor(end_positions).long() 194 | 195 | new_base_model_logprobs = torch.stack(new_base_model_logprobs_list)[torch.arange(len(end_positions_tensor)), end_positions_tensor].to(toks_int_values.device).clone() 196 | metric = partial( 197 | kl_divergence, 198 | base_model_logprobs=new_base_model_logprobs, 199 | end_positions=end_positions_tensor, 200 | mask_repeat_candidates=None, # !!! 
201 | last_seq_element_only=False, 202 | return_one_element=False 203 | ) 204 | 205 | return return_tensor, toks_int_values_other_batch, end_positions_tensor, metric 206 | -------------------------------------------------------------------------------- /acdc/ioi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArthurConmy/Automatic-Circuit-Discovery/bc99ace817974b5584b7ee203d596a8e2bbcd399/acdc/ioi/__init__.py -------------------------------------------------------------------------------- /acdc/logic_gates/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArthurConmy/Automatic-Circuit-Discovery/bc99ace817974b5584b7ee203d596a8e2bbcd399/acdc/logic_gates/__init__.py -------------------------------------------------------------------------------- /acdc/logic_gates/utils.py: -------------------------------------------------------------------------------- 1 | #%% 2 | 3 | from functools import partial 4 | import time 5 | import torch 6 | from typing import Literal, Optional 7 | from transformer_lens.HookedTransformer import HookedTransformer, HookedTransformerConfig 8 | from acdc.docstring.utils import AllDataThings 9 | from acdc.tracr_task.utils import get_perm 10 | from acdc.acdc_utils import kl_divergence 11 | import torch.nn.functional as F 12 | 13 | MAX_LOGIC_GATE_SEQ_LEN = 100_000 # Can be increased further provided numerics and memory do not explode 14 | 15 | def get_logic_gate_model(mode: Literal["OR", "AND"] = "OR", seq_len: Optional[int]=None, device="cuda") -> HookedTransformer: 16 | 17 | if seq_len is not None: 18 | assert 1 <= seq_len <= MAX_LOGIC_GATE_SEQ_LEN, "We need some bound on sequence length, but this can be increased if the variable at the top is increased" 19 | 20 | if mode == "OR": 21 | assert seq_len == 1 22 | cfg = HookedTransformerConfig.from_dict( 23 | { 24 | "n_layers": 1, 25 | "d_model": 2, 26 | "n_ctx": 1, 27 | "n_heads": 2, 28 | "d_head": 1, 29 | "act_fn": "relu", 30 | "d_vocab": 1, 31 | "d_mlp": 1, 32 | "d_vocab_out": 1, 33 | "normalization_type": None, 34 | "attn_only": False, 35 | } 36 | ) 37 | elif mode == "AND": 38 | cfg = HookedTransformerConfig.from_dict( 39 | { 40 | "n_layers": 1, 41 | "d_model": 3, 42 | "n_ctx": seq_len, 43 | "n_heads": 1, 44 | "d_head": 1, 45 | "act_fn": "relu", 46 | "d_vocab": 2, 47 | "d_mlp": 1, 48 | "d_vocab_out": 1, 49 | "normalization_type": None, 50 | } 51 | ) 52 | else: 53 | raise ValueError(f"mode {mode} not recognized") 54 | 55 | model = HookedTransformer(cfg).to(device) 56 | model.set_use_attn_result(True) 57 | model.set_use_split_qkv_input(True) 58 | if "use_hook_mlp_in" in model.cfg.to_dict(): 59 | model.set_use_hook_mlp_in(True) 60 | model = model.to(torch.double) 61 | 62 | # Turn off model gradient so we can edit weights 63 | # And also set all the weights to 0 64 | for param in model.parameters(): 65 | param.requires_grad = False 66 | param[:] = 0.0 67 | 68 | if mode == "AND": 69 | # Embed 1s as 1.0 in residual component 0 70 | model.embed.W_E[1, 0] = 1.0 71 | 72 | # No QK parameters, so attention is uniform; this lets us detect whether every position is a 1, since channel 1 then receives exactly 1 (and strictly less than 1 otherwise) 73 | 74 | # Output 1.0 into residual component 1 for all things present 75 | model.blocks[0].attn.W_V[0, 0, 0] = 1.0 # Shape [head_index d_model d_head] 76 | model.blocks[0].attn.W_O[0, 0, 1] = 1.0 # Shape [head_index d_head d_model] 77 | 78 | model.blocks[0].mlp.W_in[1, 0] = 1.0 # [d_model d_mlp] 79 | model.blocks[0].mlp.b_in[:] = -(MAX_LOGIC_GATE_SEQ_LEN-1)/MAX_LOGIC_GATE_SEQ_LEN # Unless everything in the input is a 1, do not fire 80 | 81 | # Write the output to residual component 2 82 | # (TODO: I think we could get away with 2 components here?) 83 | model.blocks[0].mlp.W_out[0, 2] = MAX_LOGIC_GATE_SEQ_LEN # Shape [d_mlp d_model] 84 | 85 | model.unembed.W_U[2, 0] = 1.0 # Shape [d_model d_vocab_out]
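# Sketch of why this computes AND (derived from the weights above): attention is
# uniform over the L input positions, so if k of them are 1s, channel 1 of the
# residual stream receives k/L. With M = MAX_LOGIC_GATE_SEQ_LEN, the MLP
# pre-activation is k/L - (M-1)/M. If k == L this equals 1/M > 0, and W_out
# rescales the ReLU output by M, writing exactly 1 to channel 2. If k < L then
# k/L <= (L-1)/L <= (M-1)/M (since L <= M), so the pre-activation is <= 0 and the
# model outputs 0.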
86 | 87 | elif mode == "OR": 88 | 89 | # a0.0 and a0.1 are the two inputs to the OR gate; they always dump 1.0 into the residual stream 90 | # Both heads dump a 1 into the residual stream 91 | # We can test our circuit recovery methods with zero ablation to see if they recover either or both heads! 92 | model.blocks[0].attn.b_V[:, 0] = 1.0 # [num_heads, d_head] 93 | model.blocks[0].attn.W_O[:, 0, 0] = 1.0 # [num_heads, d_head, d_model] 94 | 95 | # mlp0 is an OR gate on the outputs of a0.0 and a0.1; it turns the sum S of their outputs into 1 if S >= 1 and 0 if S = 0 96 | model.blocks[0].mlp.W_in[0, 0] = -1.0 # [d_model d_mlp] 97 | model.blocks[0].mlp.b_in[:] = 1.0 # [d_mlp] 98 | 99 | model.blocks[0].mlp.W_out[0, 1] = -1.0 100 | model.blocks[0].mlp.b_out[:] = 1.0 # [d_model] 101 | 102 | model.unembed.W_U[1, 0] = 1.0 # shape [d_model d_vocab_out] 103 | 104 | else: 105 | raise ValueError(f"mode {mode} not recognized") 106 | 107 | return model 108 | 109 | def test_and_logical_model(): 110 | """ 111 | Test that the AND gate works 112 | """ 113 | 114 | seq_len=3 115 | and_model = get_logic_gate_model(mode="AND", seq_len=seq_len, device = "cpu") 116 | 117 | all_inputs = [] 118 | for i in range(2**seq_len): 119 | input = torch.tensor([int(x) for x in f"{i:03b}"]).unsqueeze(0).long() 120 | all_inputs.append(input) 121 | input = torch.cat(all_inputs, dim=0) 122 | 123 | and_output = and_model(input)[:, -1, :] 124 | assert torch.equal(and_output[:2**seq_len - 1], torch.zeros(2**seq_len - 1, 1)) 125 | torch.testing.assert_close(and_output[2**seq_len - 1], torch.ones(1).to(torch.double)) 126 | 127 | #%% 128 | 129 | def get_all_logic_gate_things(mode: str = "OR", device=None, seq_len: Optional[int] = 5, num_examples: Optional[int] = 10, return_one_element: bool = False) -> AllDataThings: 130 | 131 | assert mode == "OR" 132 | 133 | model = get_logic_gate_model(mode=mode, seq_len=seq_len, device=device) 134 | # Convert the binary strings back to a tensor 135 | data = torch.tensor([[0.0]]).long() # Input is actually meaningless, all that matters is Attention Heads 0 and 1 136 | correct_answers = data.clone().to(torch.double) + 1 137 | 138 | def validation_metric(output, correct): 139 | output = output[:, -1, :] 140 | 141 | assert output.shape == correct.shape 142 | if not return_one_element: 143 | return torch.mean((output - correct)**2, dim=0) 144 | else: 145 | return ((output - correct)**2).squeeze(1) 146 | 147 | base_validation_logprobs = F.log_softmax(model(data)[:, -1], dim=-1) 148 | 149 | test_metrics = { 150 | "kl_div": partial( 151 | kl_divergence, 152 | base_model_logprobs=base_validation_logprobs, 153 | last_seq_element_only=True, 154 | base_model_probs_last_seq_element_only=False, 155 | return_one_element=return_one_element, 156 | ),} 157 | 158 | return AllDataThings( 159 | tl_model=model, 160 | validation_metric=partial(validation_metric, correct=correct_answers), 161 | validation_data=data, 162 | validation_labels=None, 163 | validation_mask=None, 164 | validation_patch_data=data.clone(), # We're doing zero ablation so irrelevant 165 |
test_metrics=test_metrics, 166 | test_data=data, 167 | test_labels=None, 168 | test_mask=None, 169 | test_patch_data=data.clone(), 170 | ) 171 | 172 | 173 | # # # test_logical_models() 174 | # # %% 175 | 176 | # or_model = get_logic_gate_model(seq_len=1, device = "cpu") 177 | # logits, cache = or_model.run_with_cache( 178 | # torch.tensor([[0]]).to(torch.long), 179 | # ) 180 | # print(logits) 181 | 182 | # # %% 183 | 184 | # for key in cache.keys(): 185 | # print(key) 186 | # print(cache[key].shape) 187 | # print(cache[key]) 188 | # print("\n\n\n") 189 | # # %% 190 | # #batch pos head_index d_head for hook_q 191 | # %% 192 | -------------------------------------------------------------------------------- /acdc/tracr_task/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArthurConmy/Automatic-Circuit-Discovery/bc99ace817974b5584b7ee203d596a8e2bbcd399/acdc/tracr_task/__init__.py -------------------------------------------------------------------------------- /assets/acdc_finds_subgraph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArthurConmy/Automatic-Circuit-Discovery/bc99ace817974b5584b7ee203d596a8e2bbcd399/assets/acdc_finds_subgraph.png -------------------------------------------------------------------------------- /experiments/all_roc_plots.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from experiments.launcher import KubernetesJob, launch 3 | import numpy as np 4 | import random 5 | from typing import List 6 | 7 | 8 | TASKS = ["ioi", "docstring", "greaterthan", "tracr-reverse", "tracr-proportion"] 9 | 10 | METRICS_FOR_TASK = { 11 | "ioi": ["kl_div", "logit_diff"], 12 | "tracr-reverse": ["kl_div"], 13 | "tracr-proportion": ["kl_div", "l2"], 14 | "induction": ["kl_div", "nll"], 15 | "docstring": ["kl_div", "docstring_metric"], 16 | "greaterthan": ["kl_div", "greaterthan"], 17 | } 18 | 19 | 20 | def main(): 21 | commands = [] 22 | for alg in ["16h", "sp", "acdc"]: 23 | for reset_network in [0, 1]: 24 | for zero_ablation in [0, 1]: 25 | for task in TASKS: 26 | for metric in METRICS_FOR_TASK[task]: 27 | command = [ 28 | "python", 29 | "notebooks/roc_plot_generator.py", 30 | f"--task={task}", 31 | f"--reset-network={reset_network}", 32 | f"--metric={metric}", 33 | f"--alg={alg}", 34 | ] 35 | if zero_ablation: 36 | command.append("--zero-ablation") 37 | commands.append(command) 38 | 39 | launch(commands, name="plots", job=None, synchronous=False) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /experiments/collect_data.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import argparse 3 | import os 4 | from pathlib import Path 5 | from experiments.launcher import KubernetesJob, WandbIdentifier, launch 6 | import shlex 7 | import random 8 | 9 | IS_ADRIA = not str(os.environ.get("CONDA_DEFAULT_ENV")).lower().startswith("arthur") 10 | if IS_ADRIA: 11 | print("WARNING: IS_ADRIA=True, using Adria's Docker container") 12 | 13 | #TASKS = ["ioi", "docstring", "greaterthan", "tracr-reverse", "tracr-proportion", "induction"] 14 | TASKS = ["ioi", "docstring", "greaterthan", "induction"] 15 | 16 | METRICS_FOR_TASK = { 17 | "ioi": ["kl_div", "logit_diff"], 18 | "tracr-reverse": ["l2"], 19 | "tracr-proportion": ["l2"], 20 | "induction": 
["kl_div", "nll"], 21 | "docstring": ["kl_div", "docstring_metric"], 22 | "greaterthan": ["kl_div", "greaterthan"], 23 | } 24 | 25 | 26 | def main( 27 | alg: str, 28 | task: str, 29 | job: KubernetesJob, 30 | testing: bool = False, 31 | mod_idx=0, 32 | num_processes=1, 33 | ): 34 | # mod_idx= MPI.COMM_WORLD.Get_rank() 35 | # num_processes = MPI.COMM_WORLD.Get_size() 36 | 37 | if IS_ADRIA: 38 | OUT_RELPATH = Path(".cache") / "plots_data" 39 | OUT_HOME_DIR = Path(os.environ["HOME"]) / OUT_RELPATH 40 | else: 41 | OUT_RELPATH = Path("experiments/results/arthur_plots_data") # trying to remove extra things from acdc/ 42 | OUT_HOME_DIR = OUT_RELPATH 43 | 44 | assert OUT_HOME_DIR.exists() 45 | 46 | if IS_ADRIA: 47 | OUT_DIR = Path("/root") / OUT_RELPATH 48 | else: 49 | OUT_DIR = OUT_RELPATH 50 | 51 | seed = 1233778640 52 | random.seed(seed) 53 | 54 | commands = [] 55 | for reset_network in [0, 1]: 56 | for zero_ablation in [0, 1]: 57 | for metric in METRICS_FOR_TASK[task]: 58 | if alg == "canonical" and (task == "induction" or metric == "kl_div"): 59 | continue 60 | 61 | command = [ 62 | "python", 63 | "notebooks/roc_plot_generator.py", 64 | f"--task={task}", 65 | f"--reset-network={reset_network}", 66 | f"--metric={metric}", 67 | f"--alg={alg}", 68 | f"--device={'cpu' if testing or not job.gpu else 'cuda'}", 69 | f"--torch-num-threads={job.cpu}", 70 | f"--out-dir={OUT_DIR}", 71 | f"--seed={random.randint(0, 2**31-1)}", 72 | ] 73 | if zero_ablation: 74 | command.append("--zero-ablation") 75 | 76 | if alg == "acdc" and task == "greaterthan" and metric == "kl_div" and not zero_ablation and not reset_network: 77 | command.append("--ignore-missing-score") 78 | commands.append(command) 79 | 80 | if IS_ADRIA: 81 | launch( 82 | commands, 83 | name="collect_data", 84 | job=job, 85 | synchronous=True, 86 | just_print_commands=False, 87 | check_wandb=WandbIdentifier(f"agarriga-col-{alg}-{task[-5:]}-{{i:04d}}b", "collect", "acdc"), 88 | ) 89 | 90 | else: 91 | for command_idx in range(mod_idx, len(commands), num_processes): # commands: 92 | # run 4 in parallel 93 | command = commands[command_idx] 94 | print(f"Running command {command_idx} / {len(commands)}") 95 | print(" ".join(command)) 96 | subprocess.run(command) 97 | 98 | 99 | tasks_for = { 100 | "acdc": ["ioi", "greaterthan"], 101 | "16h": TASKS, 102 | "sp": TASKS, 103 | "canonical": TASKS, 104 | } 105 | 106 | parser = argparse.ArgumentParser() 107 | parser.add_argument("--i", type=int, default=0) 108 | parser.add_argument("--n", type=int, default=1) 109 | 110 | args = parser.parse_args() 111 | mod_idx = args.i 112 | num_processes = args.n 113 | 114 | if __name__ == "__main__": 115 | for alg in ["acdc"]: # , "16h", "sp", "canonical"]: 116 | for task in tasks_for[alg]: 117 | main( 118 | alg, 119 | task, 120 | KubernetesJob( 121 | container="ghcr.io/rhaps0dy/automatic-circuit-discovery:e1884e4", 122 | cpu=6, 123 | gpu=0 if not IS_ADRIA or task.startswith("tracr") or alg not in ["acdc", "canonical"] else 1, 124 | mount_training=False, 125 | ), 126 | testing=False, 127 | mod_idx=mod_idx, 128 | num_processes=num_processes, 129 | ) 130 | -------------------------------------------------------------------------------- /experiments/docstring_results.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import pandas as pd 4 | import numpy as np 5 | import plotly.express as px 6 | import wandb 7 | 8 | import hashlib 9 | import os 10 | 11 | import plotly.io as pio 12 | import plotly.graph_objects as 
go 13 | from plotly.subplots import make_subplots 14 | 15 | 16 | class EmacsRenderer(pio.base_renderers.ColabRenderer): 17 | save_dir = "ob-jupyter" 18 | base_url = f"http://localhost:8888/files" 19 | 20 | def to_mimebundle(self, fig_dict): 21 | html = super().to_mimebundle(fig_dict)["text/html"] 22 | 23 | mhash = hashlib.md5(html.encode("utf-8")).hexdigest() 24 | if not os.path.isdir(self.save_dir): 25 | os.mkdir(self.save_dir) 26 | fhtml = os.path.join(self.save_dir, mhash + ".html") 27 | with open(fhtml, "w") as f: 28 | f.write(html) 29 | 30 | return {"text/html": f'Click to open {fhtml}'} 31 | 32 | 33 | pio.renderers["emacs"] = EmacsRenderer() 34 | 35 | 36 | def set_plotly_renderer(renderer="emacs"): 37 | pio.renderers.default = renderer 38 | 39 | 40 | set_plotly_renderer("emacs") 41 | 42 | ACDC_GROUP = "adria-docstring3" 43 | SP_GROUP = "docstring3" 44 | 45 | # %% 46 | 47 | api = wandb.Api() 48 | all_runs = api.runs(path="remix_school-of-rock/acdc", filters={"group": ACDC_GROUP}) 49 | 50 | df = pd.DataFrame() 51 | for r in all_runs: 52 | try: 53 | cfg = {k: r.config[k] for k in ["reset_network", "zero_ablation", "metric", "task", "threshold"]} 54 | d = { 55 | k: r.summary[k] 56 | for k in [ 57 | "cur_metric", 58 | "num_edges", 59 | "test_docstring_metric", 60 | "test_docstring_stefan", 61 | "test_kl_div", 62 | "test_match_nll", 63 | "test_nll", 64 | ] 65 | } 66 | except KeyError as e: 67 | print("problems with run ", r.name, e) 68 | continue 69 | d = dict(**cfg, **d) 70 | d["alg"] = "acdc" 71 | 72 | idx = int(r.name.split("-")[-1]) 73 | df = pd.concat([df, pd.DataFrame(d, index=[idx])]) 74 | 75 | # %% Now subnetwork-probing runs 76 | 77 | start_idx: float = df.index.max() + 1 78 | 79 | all_runs = api.runs(path="remix_school-of-rock/induction-sp-replicate", filters={"group": SP_GROUP}) 80 | for r in all_runs: 81 | try: 82 | cfg = {k: r.config[k] for k in ["reset_subject", "zero_ablation", "loss_type", "lambda_reg"]} 83 | d = { 84 | k: r.summary[k] 85 | for k in [ 86 | "number_of_edges", 87 | "specific_metric", 88 | "test_docstring_metric", 89 | "test_docstring_stefan", 90 | "test_kl_div", 91 | "test_match_nll", 92 | "test_nll", 93 | ] 94 | } 95 | except KeyError as e: 96 | print("problems with run ", r.name, e) 97 | continue 98 | cfg["metric"] = cfg["loss_type"] 99 | del cfg["loss_type"] 100 | cfg["reset_network"] = cfg["reset_subject"] 101 | del cfg["reset_subject"] 102 | cfg["num_edges"] = d["number_of_edges"] 103 | cfg["cur_metric"] = d["specific_metric"] / r.config["n_loss_average_runs"] 104 | for k in d.keys(): 105 | if k.startswith("test_"): 106 | cfg[k] = d[k] / r.config["n_loss_average_runs"] 107 | cfg["alg"] = "subnetwork-probing" 108 | 109 | idx = int(r.name.split("-")[-1]) + start_idx 110 | df = pd.concat([df, pd.DataFrame(cfg, index=[idx])]) 111 | 112 | # %% 113 | 114 | df.loc[:, "color"] = df.apply(lambda x: f"{x['alg']}-reset={x['reset_network']:.0f}", axis=1) 115 | 116 | # Scatter plot of num_edges vs cur_metric grouped by reset_network 117 | 118 | fig = px.scatter( 119 | df, 120 | x="num_edges", 121 | y="cur_metric", 122 | color="color", 123 | color_discrete_map={ 124 | "acdc-reset=0": "red", 125 | "acdc-reset=1": "blue", 126 | "subnetwork-probing-reset=0": "orange", 127 | "subnetwork-probing-reset=1": "green", 128 | }, 129 | facet_col="zero_ablation", 130 | facet_row="metric", 131 | facet_col_wrap=2, 132 | hover_data=["threshold", "lambda_reg"], 133 | title="Induction, TRAIN metric", 134 | ) 135 | fig.show() 136 | 137 | # %% 138 | 139 | for test_metric in 
["test_docstring_metric", "test_docstring_stefan", "test_kl_div", "test_match_nll", "test_nll"]: 140 | fig = px.scatter( 141 | df, 142 | x="num_edges", 143 | y=test_metric, 144 | color="color", 145 | color_discrete_map={ 146 | "acdc-reset=0": "red", 147 | "acdc-reset=1": "blue", 148 | "subnetwork-probing-reset=0": "orange", 149 | "subnetwork-probing-reset=1": "green", 150 | }, 151 | facet_col="zero_ablation", 152 | facet_row="metric", 153 | facet_col_wrap=2, 154 | hover_data=["threshold", "lambda_reg"], 155 | title=f"Induction, {test_metric} metric", 156 | ) 157 | fig.show() 158 | 159 | 160 | # %% Scatter plot for train vs test of every metric 161 | fig = make_subplots() 162 | 163 | for test_metric in ["test_docstring_metric", "test_docstring_stefan", "test_kl_div", "test_match_nll", "test_nll"]: 164 | this_df = df[(df["metric"] == test_metric.lstrip("test_")) & (~df["zero_ablation"])] 165 | trace1 = go.Scatter( 166 | x=this_df["cur_metric"], 167 | y=this_df[test_metric], 168 | mode="markers", 169 | name=test_metric, 170 | ) 171 | fig.add_trace(trace1) 172 | fig.show() 173 | 174 | # %% Compare each metric with the other metrics 175 | 176 | for main_metric in ["test_docstring_metric", "test_docstring_stefan", "test_kl_div", "test_match_nll", "test_nll"]: 177 | fig = make_subplots() 178 | this_df = df[(df["metric"] == main_metric.lstrip("test_")) & (~df["zero_ablation"])] 179 | for test_metric in ["test_docstring_metric", "test_docstring_stefan", "test_kl_div", "test_match_nll", "test_nll"]: 180 | trace1 = go.Scatter( 181 | x=this_df["cur_metric"], 182 | y=this_df[test_metric], 183 | mode="markers", 184 | name=test_metric, 185 | ) 186 | fig.add_trace(trace1) 187 | # set title 188 | fig.update_layout( 189 | title_text=f"Comparison of {main_metric} with other metrics", 190 | xaxis_title=main_metric, 191 | yaxis_title="other metrics", 192 | ) 193 | fig.show() 194 | -------------------------------------------------------------------------------- /experiments/induction_results.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import pandas as pd 4 | import numpy as np 5 | import plotly.express as px 6 | import wandb 7 | 8 | import hashlib 9 | import os 10 | 11 | import plotly.io as pio 12 | 13 | 14 | class EmacsRenderer(pio.base_renderers.ColabRenderer): 15 | save_dir = "ob-jupyter" 16 | base_url = f"http://localhost:8888/files" 17 | 18 | def to_mimebundle(self, fig_dict): 19 | html = super().to_mimebundle(fig_dict)["text/html"] 20 | 21 | mhash = hashlib.md5(html.encode("utf-8")).hexdigest() 22 | if not os.path.isdir(self.save_dir): 23 | os.mkdir(self.save_dir) 24 | fhtml = os.path.join(self.save_dir, mhash + ".html") 25 | with open(fhtml, "w") as f: 26 | f.write(html) 27 | 28 | return {"text/html": f'Click to open {fhtml}'} 29 | 30 | 31 | pio.renderers["emacs"] = EmacsRenderer() 32 | 33 | 34 | def set_plotly_renderer(renderer="emacs"): 35 | pio.renderers.default = renderer 36 | 37 | set_plotly_renderer("emacs") 38 | 39 | ACDC_GROUP = "adria-induction-3" 40 | # ACDC_GROUP = "adria-docstring3" 41 | SP_GROUP = "reset-with-nll-21" 42 | # SP_GROUP = "docstring3" 43 | 44 | # %% 45 | 46 | api = wandb.Api() 47 | all_runs = api.runs(path="remix_school-of-rock/acdc", filters={"group": ACDC_GROUP}) 48 | 49 | df = pd.DataFrame() 50 | 51 | total = len(all_runs) 52 | failed = 0 53 | 54 | for r in all_runs: 55 | try: 56 | d = {k: r.summary[k] for k in ["cur_metric", "test_specific_metric", "num_edges"]} 57 | except KeyError: 58 | failed+=1 
59 | else: 60 | idx = int(r.name.split("-")[-1]) 61 | df = pd.concat([df, pd.DataFrame(d, index=[idx])]) 62 | 63 | assert failed/total < 0.5 64 | 65 | # %% 66 | 67 | thresholds = 10 ** np.linspace(-2, 0.5, 21) 68 | 69 | i = 0 70 | for reset_network in [0, 1]: 71 | for zero_ablation in [0, 1]: 72 | for loss_type in ["kl_div", "nll", "match_nll"]: 73 | for threshold in thresholds: 74 | df.loc[i, "reset_network"] = reset_network 75 | df.loc[i, "zero_ablation"] = zero_ablation 76 | df.loc[i, "loss_type"] = loss_type 77 | df.loc[i, "threshold"] = threshold 78 | 79 | i += 1 80 | 81 | df.loc[:, "alg"] = "acdc" 82 | df.loc[:, "num_examples"] = 50 83 | 84 | # %% Now subnetwork-probing runs 85 | 86 | sp_runs = [] 87 | 88 | all_runs = api.runs(path="remix_school-of-rock/induction-sp-replicate", filters={"group": SP_GROUP}) 89 | for r in all_runs: 90 | try: 91 | cfg = {k: r.config[k] for k in ["reset_subject", "zero_ablation", "loss_type", "lambda_reg", "num_examples"]} 92 | d = {k: r.summary[k] for k in ["number_of_edges", "specific_metric", "test_specific_metric", "specific_metric_loss"]} 93 | except KeyError: 94 | continue 95 | cfg["reset_network"] = cfg["reset_subject"] 96 | del cfg["reset_subject"] 97 | cfg["num_edges"] = d["number_of_edges"] 98 | cfg["cur_metric"] = d["specific_metric"] / r.config["n_loss_average_runs"] 99 | cfg["test_specific_metric"] = d["test_specific_metric"] / r.config["n_loss_average_runs"] 100 | cfg["alg"] = "subnetwork-probing" 101 | sp_runs.append(cfg) 102 | 103 | 104 | i = df.index.max() + 1 105 | for r in sp_runs: 106 | df = pd.concat([df, pd.DataFrame(r, index=[i])]) 107 | i += 1 108 | 109 | # %% 110 | 111 | df.loc[:, "color"] = df.apply(lambda x: f"{x['alg']}-reset={x['reset_network']:.0f}", axis=1) 112 | 113 | # Scatter plot of num_edges vs cur_metric grouped by reset_network 114 | 115 | fig = px.scatter( 116 | df, 117 | x="num_edges", 118 | y="cur_metric", 119 | color="color", 120 | color_discrete_map={ 121 | "acdc-reset=0": "red", 122 | "acdc-reset=1": "blue", 123 | "subnetwork-probing-reset=0": "orange", 124 | "subnetwork-probing-reset=1": "green", 125 | }, 126 | facet_col="zero_ablation", 127 | facet_row="loss_type", 128 | facet_col_wrap=2, 129 | hover_data=["threshold", "lambda_reg"], 130 | title="Induction, TRAIN metric", 131 | ) 132 | fig.show() 133 | 134 | # %% 135 | 136 | fig = px.scatter( 137 | df, 138 | x="num_edges", 139 | y="test_specific_metric", 140 | color="color", 141 | color_discrete_map={ 142 | "acdc-reset=0": "red", 143 | "acdc-reset=1": "blue", 144 | "subnetwork-probing-reset=0": "orange", 145 | "subnetwork-probing-reset=1": "green", 146 | }, 147 | facet_col="zero_ablation", 148 | facet_row="loss_type", 149 | facet_col_wrap=2, 150 | hover_data=["threshold", "lambda_reg"], 151 | title="Induction, TEST metric", 152 | ) 153 | fig.show() -------------------------------------------------------------------------------- /experiments/launch_abstract.py: -------------------------------------------------------------------------------- 1 | from experiments.launcher import KubernetesJob, WandbIdentifier, launch 2 | import numpy as np 3 | import random 4 | from typing import List 5 | 6 | def main(use_kubernetes: bool, testing: bool, CPU: int = 4): 7 | 8 | task = "docstring" 9 | reset_network = 0 10 | kwargses = [ 11 | {"threshold": 0.067, "metric": "docstring_metric"}, 12 | {"threshold": 0.005, "metric": "kl_div"}, 13 | {"threshold": 0.095, "metric": "kl_div"}, 14 | ] 15 | 16 | seed = 1495006036 17 | random.seed(seed) 18 | 19 |
wandb_identifier = WandbIdentifier( 20 | run_name="abstract-{i:05d}", 21 | group_name="abstract", 22 | project="acdc", 23 | ) 24 | 25 | commands: List[List[str]] = [] 26 | for kwargs in kwargses: 27 | command = [ 28 | "python", 29 | "acdc/main.py", 30 | f"--task={task}", 31 | f"--threshold={kwargs['threshold']:.5f}", 32 | "--using-wandb", 33 | f"--wandb-run-name={wandb_identifier.run_name.format(i=len(commands))}", 34 | f"--wandb-group-name={wandb_identifier.group_name}", 35 | f"--wandb-project-name={wandb_identifier.project}", 36 | "--device=cuda", 37 | f"--torch-num-threads={CPU}", 38 | f"--reset-network={reset_network}", 39 | f"--seed={random.randint(0, 2**32 - 1)}", 40 | f"--metric={kwargs['metric']}", 41 | "--wandb-dir=/root/.cache/huggingface/tracr-training/acdc", # If it doesn't exist wandb will use /tmp 42 | f"--wandb-mode=online", 43 | f"--max-num-epochs={1 if testing else 40_000}", 44 | ] 45 | commands.append(command) 46 | 47 | launch( 48 | commands, 49 | name="acdc-docstring-abstract", 50 | job=None 51 | if not use_kubernetes 52 | else KubernetesJob(container="ghcr.io/rhaps0dy/automatic-circuit-discovery:1.6.1", cpu=CPU, gpu=1), 53 | check_wandb=wandb_identifier, 54 | just_print_commands=False, 55 | ) 56 | 57 | 58 | if __name__ == "__main__": 59 | main(use_kubernetes=True, testing=False) 60 | -------------------------------------------------------------------------------- /experiments/launch_abstract_old.py: -------------------------------------------------------------------------------- 1 | from experiments.launcher import KubernetesJob, WandbIdentifier, launch 2 | import numpy as np 3 | import random 4 | 5 | def main(use_kubernetes: bool, testing: bool, CPU: int = 4): 6 | task = "docstring" 7 | reset_network = 0 8 | kwargses = [ 9 | {"threshold": 0.067, "metric": "docstring_metric"}, 10 | {"threshold": 0.005, "metric": "kl_div"}, 11 | {"threshold": 0.095, "metric": "kl_div"}, 12 | ] 13 | 14 | seed = 1495006036 15 | random.seed(seed) 16 | 17 | for kwargs in kwargses: 18 | wandb_identifier = WandbIdentifier( 19 | run_name=f"docstring_kl_{kwargs['threshold']:.5f}", 20 | group_name="default", 21 | project="acdc-abstract", 22 | ) 23 | command = [ 24 | "python", 25 | "acdc/main.py", 26 | f"--task={task}", 27 | f"--threshold={kwargs['threshold']:.5f}", 28 | "--using-wandb", 29 | "--wandb-entity-name=remix_school-of-rock", 30 | f"--wandb-run-name={wandb_identifier.run_name}", 31 | f"--wandb-project-name={wandb_identifier.project}", 32 | ] 33 | launch( 34 | [command], 35 | name=wandb_identifier.run_name, 36 | job=None 37 | if not use_kubernetes 38 | else KubernetesJob(container="ghcr.io/rhaps0dy/automatic-circuit-discovery:abstract-0.0", cpu=CPU, gpu=1), 39 | check_wandb=wandb_identifier, 40 | just_print_commands=False, 41 | ) 42 | 43 | 44 | if __name__ == "__main__": 45 | main(use_kubernetes=True, testing=False) 46 | -------------------------------------------------------------------------------- /experiments/launch_all_sixteen_heads.py: -------------------------------------------------------------------------------- 1 | from experiments.launcher import KubernetesJob, WandbIdentifier, launch 2 | import numpy as np 3 | import random 4 | from typing import List 5 | 6 | METRICS_FOR_TASK = { 7 | "ioi": ["kl_div", "logit_diff"], 8 | "tracr-reverse": ["l2"], 9 | "tracr-proportion": ["kl_div", "l2"], 10 | "induction": ["kl_div", "nll"], 11 | "docstring": ["kl_div", "docstring_metric"], 12 | "greaterthan": ["greaterthan"], # "kl_div", 13 | } 14 | 15 | 16 | CPU = 2 17 | 18 | def main(TASKS: 
list[str], job: KubernetesJob, name: str, group_name: str): 19 | seed = 1259281515 20 | random.seed(seed) 21 | 22 | wandb_identifier = WandbIdentifier( 23 | run_name=f"{name}-{{i:05d}}", 24 | group_name=group_name, 25 | project="acdc") 26 | 27 | commands: List[List[str]] = [] 28 | for reset_network in [0, 1]: 29 | for zero_ablation in [0, 1]: 30 | for task in TASKS: 31 | for metric in METRICS_FOR_TASK[task]: 32 | if "tracr" not in task: 33 | if reset_network==0 and zero_ablation==0: 34 | continue 35 | if task in ["ioi", "induction"] and reset_network==0 and zero_ablation==1: 36 | continue 37 | 38 | command = [ 39 | "python", 40 | "experiments/launch_sixteen_heads.py", 41 | f"--task={task}", 42 | f"--wandb-run-name={wandb_identifier.run_name.format(i=len(commands))}", 43 | f"--wandb-group={wandb_identifier.group_name}", 44 | f"--wandb-project={wandb_identifier.project}", 45 | f"--device={'cuda' if job.gpu else 'cpu'}", 46 | f"--reset-network={reset_network}", 47 | f"--seed={random.randint(0, 2**32 - 1)}", 48 | f"--metric={metric}", 49 | f"--torch-num-threads={CPU}", 50 | "--wandb-dir=/root/.cache/huggingface/tracr-training/16heads", # If it doesn't exist wandb will use /tmp 51 | f"--wandb-mode=online", 52 | ] 53 | if zero_ablation: 54 | command.append("--zero-ablation") 55 | 56 | commands.append(command) 57 | 58 | 59 | launch( 60 | commands, 61 | name=wandb_identifier.run_name, 62 | job=job, 63 | check_wandb=wandb_identifier, 64 | just_print_commands=False, 65 | synchronous=True, 66 | ) 67 | 68 | 69 | if __name__ == "__main__": 70 | main( 71 | # ["ioi", "greaterthan", "induction", "docstring"], 72 | ["tracr-reverse"], 73 | KubernetesJob(container="ghcr.io/rhaps0dy/automatic-circuit-discovery:1.7.1", cpu=CPU, gpu=0), 74 | "16h-redo", 75 | group_name="sixteen-heads", 76 | ) 77 | # main( 78 | # ["tracr-reverse", "tracr-proportion"], 79 | # KubernetesJob(container="ghcr.io/rhaps0dy/automatic-circuit-discovery:1.6.1", cpu=4, gpu=0), 80 | # "16h-tracr", 81 | # ) 82 | -------------------------------------------------------------------------------- /experiments/launch_docstring.py: -------------------------------------------------------------------------------- 1 | from experiments.launcher import KubernetesJob, launch 2 | import numpy as np 3 | 4 | CPU = 4 5 | 6 | 7 | def main(testing: bool): 8 | thresholds = 10 ** np.linspace(-2, 0.5, 21) 9 | seed = 516626229 10 | 11 | commands: list[list[str]] = [] 12 | for reset_network in [0]: 13 | for zero_ablation in [0, 1]: 14 | for loss_type in ["kl_div", "docstring_metric", "docstring_stefan", "nll", "match_nll"]: 15 | for threshold in [1.0] if testing else thresholds: 16 | command = [ 17 | "python", 18 | "acdc/main.py" if testing else "/Automatic-Circuit-Discovery/acdc/main.py", 19 | "--task=docstring", 20 | f"--threshold={threshold:.5f}", 21 | "--using-wandb", 22 | f"--wandb-run-name=agarriga-docstring-{len(commands):03d}", 23 | "--wandb-group-name=adria-docstring", 24 | f"--device=cpu", 25 | f"--reset-network={reset_network}", 26 | f"--seed={seed}", 27 | f"--metric={loss_type}", 28 | f"--torch-num-threads={CPU}", 29 | "--wandb-dir=/training/acdc", # If it doesn't exist wandb will use /tmp 30 | f"--wandb-mode={'offline' if testing else 'online'}", 31 | f"--max-num-epochs={1 if testing else 100_000}", 32 | ] 33 | if zero_ablation: 34 | command.append("--zero-ablation") 35 | 36 | commands.append(command) 37 | 38 | launch( 39 | commands, 40 | name="acdc-docstring", 41 | job=None 42 | if testing 43 | else 
KubernetesJob(container="ghcr.io/rhaps0dy/automatic-circuit-discovery:1.2.10", cpu=CPU, gpu=0), 44 | ) 45 | 46 | 47 | if __name__ == "__main__": 48 | main(testing=False) 49 | -------------------------------------------------------------------------------- /experiments/launch_induction.py: -------------------------------------------------------------------------------- 1 | from experiments.launcher import KubernetesJob, launch 2 | import subprocess 3 | import argparse 4 | import numpy as np 5 | 6 | CPU = 4 7 | 8 | def main( 9 | testing: bool, 10 | is_adria: bool, 11 | ): 12 | thresholds = 10 ** np.linspace(-2, 0.5, 21) 13 | seed = 424671755 14 | 15 | commands: list[list[str]] = [] 16 | for reset_network in [0, 1]: 17 | for zero_ablation in [0, 1]: 18 | for loss_type in ["kl_div"]: 19 | for threshold in [1.0] if testing else thresholds: 20 | command = [ 21 | "python", 22 | "acdc/main.py" if (not is_adria) else "/Automatic-Circuit-Discovery/acdc/main.py", 23 | "--task=induction", 24 | f"--threshold={threshold:.5f}", 25 | "--using-wandb", 26 | f"--wandb-run-name=agarriga-acdc-{len(commands):03d}", 27 | "--wandb-group-name=adria-induction-3", 28 | f"--device=cpu", 29 | f"--reset-network={reset_network}", 30 | f"--seed={seed}", 31 | f"--metric={loss_type}", 32 | f"--torch-num-threads={CPU}", 33 | "--wandb-dir=/training/acdc", 34 | f"--wandb-mode={'offline' if testing else 'online'}", 35 | ] 36 | if zero_ablation: 37 | command.append("--zero-ablation") 38 | 39 | commands.append(command) 40 | 41 | if is_adria: 42 | launch( 43 | commands, 44 | name="acdc-induction", 45 | job=None 46 | if testing 47 | else KubernetesJob(container="ghcr.io/rhaps0dy/automatic-circuit-discovery:1.2.8", cpu=CPU, gpu=0), 48 | ) 49 | 50 | else: 51 | for command in commands: 52 | print("Running", command) 53 | subprocess.run(command) 54 | 55 | if __name__ == "__main__": 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument("--testing", action="store_true") 58 | parser.add_argument("--is-adria", action="store_true") 59 | main( 60 | testing=parser.parse_args().testing, 61 | is_adria=parser.parse_args().is_adria, 62 | ) 63 | -------------------------------------------------------------------------------- /experiments/launch_spreadsheet.py: -------------------------------------------------------------------------------- 1 | from experiments.launcher import KubernetesJob, WandbIdentifier, launch 2 | import numpy as np 3 | import random 4 | from typing import List 5 | 6 | METRICS_FOR_TASK = { 7 | "ioi": ["kl_div", "logit_diff"], 8 | "tracr-reverse": ["l2"], 9 | "tracr-proportion": ["l2"], 10 | "induction": ["kl_div", "nll"], 11 | "docstring": ["kl_div", "docstring_metric"], 12 | "greaterthan": ["kl_div", "greaterthan"], 13 | } 14 | 15 | CPU = 4 16 | 17 | def main(TASKS: list[str], group_name: str, run_name: str, testing: bool, use_kubernetes: bool, reset_networks: bool, abs_value_threshold: bool, use_gpu: bool=True): 18 | NUM_SPACINGS = 5 if reset_networks else 21 19 | base_thresholds = 10 ** np.linspace(-4, 0, 21) 20 | 21 | seed = 486887094 22 | random.seed(seed) 23 | 24 | wandb_identifier = WandbIdentifier( 25 | run_name=run_name, 26 | group_name=group_name, 27 | project="acdc") 28 | 29 | commands: List[List[str]] = [] 30 | for reset_network in [int(reset_networks)]: 31 | for zero_ablation in [0]: 32 | for task in TASKS: 33 | for metric in METRICS_FOR_TASK[task]: 34 | 35 | if task.startswith("tracr"): 36 | # Typical metric value range: 0.0-0.1 37 | thresholds = 10 ** np.linspace(-5, -1, 21) 38 | 39 | if task == 
"tracr-reverse": 40 | num_examples = 6 41 | seq_len = 5 42 | elif task == "tracr-proportion": 43 | num_examples = 50 44 | seq_len = 5 45 | else: 46 | raise ValueError("Unknown task") 47 | 48 | elif task == "greaterthan": 49 | if metric == "kl_div": 50 | # Typical metric value range: 0.0-20 51 | # thresholds = 10 ** np.linspace(-4, 0, NUM_SPACINGS) 52 | thresholds = 10 ** np.linspace(-6, -4, 11) 53 | elif metric == "greaterthan": 54 | # Typical metric value range: -1.0 - 0.0 55 | thresholds = 10 ** np.linspace(-3, -1, NUM_SPACINGS) 56 | else: 57 | raise ValueError("Unknown metric") 58 | num_examples = 100 59 | seq_len = -1 60 | elif task == "docstring": 61 | seq_len = 41 62 | if metric == "kl_div": 63 | # Typical metric value range: 0.0-10.0 64 | thresholds = base_thresholds 65 | elif metric == "docstring_metric": 66 | # Typical metric value range: -1.0 - 0.0 67 | thresholds = 10 ** np.linspace(-4, 0, 21) 68 | else: 69 | raise ValueError("Unknown metric") 70 | num_examples = 50 71 | elif task == "ioi": 72 | num_examples = 100 73 | seq_len = -1 74 | if metric == "kl_div": 75 | # Typical metric value range: 0.0-12.0 76 | thresholds = 10 ** np.linspace(-6, 0, 31) 77 | elif metric == "logit_diff": 78 | # Typical metric value range: -0.31 -- -0.01 79 | thresholds = 10 ** np.linspace(-4, 0, NUM_SPACINGS) 80 | else: 81 | raise ValueError("Unknown metric") 82 | elif task == "induction": 83 | seq_len = 300 84 | num_examples = 50 85 | if metric == "kl_div": 86 | # Typical metric value range: 0.0-16.0 87 | thresholds = base_thresholds 88 | elif metric == "nll": 89 | # Typical metric value range: 0.0-16.0 90 | thresholds = base_thresholds 91 | else: 92 | raise ValueError("Unknown metric") 93 | else: 94 | raise ValueError("Unknown task") 95 | 96 | for threshold in [1.0] if testing else thresholds: 97 | command = [ 98 | "python", 99 | "acdc/main.py", 100 | f"--task={task}", 101 | f"--threshold={threshold}", 102 | "--using-wandb", 103 | f"--wandb-run-name={wandb_identifier.run_name.format(i=len(commands))}", 104 | f"--wandb-group-name={wandb_identifier.group_name}", 105 | f"--wandb-project-name={wandb_identifier.project}", 106 | f"--device={'cuda' if not testing else 'cpu'}" if "tracr" not in task else "--device=cpu", 107 | f"--reset-network={reset_network}", 108 | f"--seed={random.randint(0, 2**32 - 1)}", 109 | f"--metric={metric}", 110 | f"--torch-num-threads={CPU}", 111 | "--wandb-dir=/root/.cache/huggingface/tracr-training/acdc", # If it doesn't exist wandb will use /tmp 112 | f"--wandb-mode=online", 113 | f"--max-num-epochs={1 if testing else 40_000}", 114 | ] 115 | if zero_ablation: 116 | command.append("--zero-ablation") 117 | if abs_value_threshold: 118 | command.append("--abs-value-threshold") 119 | commands.append(command) 120 | 121 | launch( 122 | commands, 123 | name="acdc-spreadsheet", 124 | job=None 125 | if not use_kubernetes 126 | else KubernetesJob(container="ghcr.io/rhaps0dy/automatic-circuit-discovery:181999f", cpu=CPU, gpu=int(use_gpu)), 127 | check_wandb=wandb_identifier, 128 | just_print_commands=False, 129 | ) 130 | 131 | 132 | if __name__ == "__main__": 133 | for reset_networks in [False]: 134 | main( 135 | TASKS=["ioi"], 136 | group_name="abs-value", 137 | run_name=f"agarriga-ioi-res{int(reset_networks)}-{{i:05d}}", 138 | testing=False, 139 | use_kubernetes=True, 140 | reset_networks=reset_networks, 141 | abs_value_threshold=True, 142 | use_gpu=True, 143 | ) 144 | 145 | # if __name__ == "__main__": 146 | # for reset_networks in [False, True]: 147 | # main( 148 | # ["ioi", 
"greaterthan", "induction", "docstring"], 149 | # "reset-networks-neurips", 150 | # "agarriga-tracr3-{i:05d}", 151 | # testing=False, 152 | # use_kubernetes=True, 153 | # reset_networks=True, 154 | # ) 155 | # main( 156 | # ["induction"], 157 | # "adria-induction-3", 158 | # "agarriga-induction-{i:05d}", 159 | # testing=False, 160 | # use_kubernetes=True, 161 | # reset_networks=False, 162 | # ) 163 | -------------------------------------------------------------------------------- /experiments/launcher.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import subprocess 3 | from typing import Optional, TextIO, List, Tuple 4 | import numpy as np 5 | import shlex 6 | import dataclasses 7 | import wandb 8 | 9 | @dataclasses.dataclass(frozen=True) 10 | class KubernetesJob: 11 | container: str 12 | cpu: int 13 | gpu: int 14 | mount_training: bool=False 15 | 16 | def mount_training_options(self) -> list[str]: 17 | if not self.mount_training: 18 | return [] 19 | return [ 20 | "--volume-mount=/training", 21 | "--volume-name=agarriga-models-training", 22 | ] 23 | 24 | 25 | @dataclasses.dataclass(frozen=True) 26 | class WandbIdentifier: 27 | run_name: str 28 | group_name: str 29 | project: str 30 | 31 | 32 | def launch(commands: List[List[str]], name: str, job: Optional[KubernetesJob] = None, check_wandb: Optional[WandbIdentifier]=None, ids_for_worker=range(0, 10000000), synchronous=True, just_print_commands=False): 33 | to_wait: List[Tuple[str, subprocess.Popen, TextIO, TextIO]] = [] 34 | 35 | assert len(commands) <= 100_000, "Too many commands for 5 digits" 36 | 37 | print(f"Launching {len(commands)} jobs") 38 | for i, command in enumerate(commands): 39 | if i not in ids_for_worker: 40 | print(f"Skipping {name} because it's not my turn, {i} not in {ids_for_worker}") 41 | continue 42 | 43 | command_str = shlex.join(command) 44 | 45 | 46 | if check_wandb is not None: 47 | # HACK this is pretty vulnerable to duplicating work if the same run is launched in close succession, 48 | # it's more to be able to restart 49 | # api = wandb.Api() 50 | name = check_wandb.run_name.format(i=i) 51 | # if name in existing_names: 52 | # print(f"Skipping {name} because it already exists") 53 | # continue 54 | 55 | # runs = api.runs(path=f"remix_school-of-rock/{check_wandb.project}", filters={"group": check_wandb.group_name}) 56 | # existing_names = existing_names.union({r.name for r in runs}) 57 | # print("Runs that exist: ", existing_names) 58 | # if name in existing_names: 59 | # print(f"Run {name} already exists, skipping") 60 | # continue 61 | 62 | print("Launching", name, command_str) 63 | if just_print_commands: 64 | continue 65 | 66 | if job is None: 67 | if synchronous: 68 | out = subprocess.run(command) 69 | assert out.returncode == 0, f"Command return={out.returncode} != 0" 70 | else: 71 | base_path = Path(f"/tmp/{name}") 72 | base_path.mkdir(parents=True, exist_ok=True) 73 | stdout = open(base_path / f"stdout_{i:05d}.txt", "w") 74 | stderr = open(base_path / f"stderr_{i:05d}.txt", "w") 75 | out = subprocess.Popen(command, stdout=stdout, stderr=stderr) 76 | to_wait.append((command_str, out, stdout, stderr)) 77 | else: 78 | if "cuda" in command_str: 79 | assert job.gpu > 0 80 | else: 81 | assert job.gpu == 0 82 | 83 | subprocess.run( 84 | [ 85 | "ctl", 86 | "job", 87 | "run", 88 | f"--name={name}", 89 | "--shared-host-dir-slow-tolerant", 90 | f"--container={job.container}", 91 | f"--cpu={job.cpu}", 92 | f"--gpu={job.gpu}", 93 | "--login", 94 | 
"--wandb", 95 | "--never-restart", 96 | f"--command={command_str}", 97 | "--working-dir=/Automatic-Circuit-Discovery", 98 | "--shared-host-dir=/home/agarriga/.cache", 99 | "--shared-host-dir-mount=/root/.cache", 100 | *job.mount_training_options(), 101 | ], 102 | check=True, 103 | ) 104 | i += 1 105 | 106 | for (command, process, out, err) in to_wait: 107 | retcode = process.wait() 108 | with open(out.name, 'r') as f: 109 | stdout = f.read() 110 | with open(err.name, 'r') as f: 111 | stderr = f.read() 112 | 113 | if retcode != 0 or "nan" in stdout.lower() or "nan" in stderr.lower(): 114 | s = f""" Command {command} exited with code {retcode}. 115 | stdout: 116 | {stdout} 117 | stderr: 118 | {stderr} 119 | """ 120 | print(s) 121 | -------------------------------------------------------------------------------- /experiments/results/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | *.log 4 | tmp/ 5 | 6 | 7 | plots_with_qkv_deleted 8 | plots_with_qkv_deleted_two 9 | plots_old_colors_rgb 10 | .auctex-auto/* 11 | auc_tables.aux 12 | auc_tables.pdf 13 | auc_tables.synctex.gz 14 | -------------------------------------------------------------------------------- /experiments/results/auc_tables.tex: -------------------------------------------------------------------------------- 1 | \documentclass{article} 2 | \usepackage{booktabs} 3 | \usepackage{multirow} 4 | \begin{document} 5 | 6 | WARNING: you need to add some vertical and horizontal rows for this to look good 7 | 8 | \begin{table} 9 | \caption{Key=auc, weights\_type=trained, Random Ablation} 10 | \begin{tabular}{llllllll} 11 | \toprule 12 | & & \multicolumn{3}{r}{roc\_edges} & \multicolumn{3}{r}{roc\_nodes} \\ 13 | & & ACDC & HISP & SP & ACDC & HISP & SP \\ 14 | metric\_pretty & task & & & & & & \\ 15 | \midrule 16 | \multirow[c]{3}{*}{KL} & docstring & 0.434 & 0.183 & \textbf{0.937} & 0.391 & 0.178 & \textbf{0.928} \\ 17 | & greaterthan & \textbf{0.883} & 0.279 & 0.820 & \textbf{0.890} & 0.384 & 0.838 \\ 18 | & ioi & 0.868 & 0.239 & \textbf{0.888} & \textbf{0.873} & 0.339 & 0.852 \\ 19 | \multirow[c]{5}{*}{Loss} & docstring & \textbf{0.972} & 0.177 & 0.942 & 0.938 & 0.170 & \textbf{0.941} \\ 20 | & greaterthan & 0.461 & 0.275 & \textbf{0.848} & 0.766 & 0.374 & \textbf{0.830} \\ 21 | & ioi & 0.589 & 0.227 & \textbf{0.837} & 0.777 & 0.283 & \textbf{0.814} \\ 22 | & tracr-proportion & 0.600 & \textbf{0.679} & 0.400 & 0.727 & \textbf{0.909} & 0.716 \\ 23 | & tracr-reverse & 0.200 & \textbf{0.656} & 0.416 & 0.312 & \textbf{0.750} & 0.533 \\ 24 | \bottomrule 25 | \end{tabular} 26 | \end{table} 27 | \begin{table} 28 | \caption{Key=auc, weights\_type=trained, Zero Ablation} 29 | \begin{tabular}{llllllll} 30 | \toprule 31 | & & \multicolumn{3}{r}{roc\_edges} & \multicolumn{3}{r}{roc\_nodes} \\ 32 | & & ACDC & HISP & SP & ACDC & HISP & SP \\ 33 | metric\_pretty & task & & & & & & \\ 34 | \midrule 35 | \multirow[c]{3}{*}{KL} & docstring & \textbf{0.585} & 0.183 & 0.428 & 0.190 & 0.178 & \textbf{0.420} \\ 36 | & greaterthan & 0.276 & \textbf{0.279} & 0.163 & \textbf{0.653} & 0.384 & 0.134 \\ 37 | & ioi & 0.226 & 0.239 & \textbf{0.702} & 0.511 & 0.339 & \textbf{0.638} \\ 38 | \multirow[c]{5}{*}{Loss} & docstring & \textbf{0.816} & 0.177 & 0.482 & \textbf{0.845} & 0.170 & 0.398 \\ 39 | & greaterthan & 0.159 & 0.275 & \textbf{0.715} & 0.317 & 0.374 & \textbf{0.597} \\ 40 | & ioi & 0.403 & 0.227 & \textbf{0.598} & \textbf{0.541} & 0.283 & 0.507 \\ 41 | & tracr-proportion & \textbf{1.000} & 
0.679 & 0.561 & \textbf{1.000} & 0.909 & 0.875 \\ 42 | & tracr-reverse & \textbf{1.000} & 0.656 & 0.692 & \textbf{1.000} & 0.750 & 0.947 \\ 43 | \bottomrule 44 | \end{tabular} 45 | \end{table} 46 | \end{document} 47 | -------------------------------------------------------------------------------- /experiments/results/canonical_circuits/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | *.log 4 | tmp/ 5 | 6 | *.pdf 7 | -------------------------------------------------------------------------------- /experiments/results/canonical_circuits/Makefile: -------------------------------------------------------------------------------- 1 | tracr-proportion.gv: 2 | python ../../../notebooks/roc_plot_generator.py --task=tracr-proportion --only-save-canonical --metric=l2 3 | 4 | tracr-reverse.gv: 5 | python ../../../notebooks/roc_plot_generator.py --task=tracr-reverse --only-save-canonical --metric=l2 6 | 7 | ioi.gv ioi_heads.gv ioi_heads_qkv.gv: 8 | python ../../../notebooks/roc_plot_generator.py --task=ioi --only-save-canonical 9 | 10 | greaterthan.gv greaterthan_heads.gv greaterthan_heads_qkv.gv: 11 | python ../../../notebooks/roc_plot_generator.py --task=greaterthan --only-save-canonical 12 | 13 | docstring.gv: 14 | python ../../../notebooks/roc_plot_generator.py --task=docstring --only-save-canonical 15 | 16 | %.pdf: %.gv 17 | neato -Tpdf $< -o $@ 18 | 19 | all: tracr-proportion.pdf tracr-reverse.pdf ioi.pdf greaterthan.pdf docstring.pdf ioi_heads.pdf ioi_heads_qkv.pdf greaterthan_heads.pdf greaterthan_heads_qkv.pdf 20 | 21 | ioi: ioi.pdf ioi_heads.pdf ioi_heads_qkv.pdf 22 | 23 | greaterthan: greaterthan.pdf greaterthan_heads.pdf greaterthan_heads_qkv.pdf 24 | 25 | tracr: tracr-proportion.pdf tracr-reverse.pdf 26 | 27 | docstring: docstring.pdf 28 | -------------------------------------------------------------------------------- /experiments/results/canonical_circuits/greaterthan/Makefile: -------------------------------------------------------------------------------- 1 | layout.\#%.gv: \#%.gv 2 | neato -olayout.$@ $< 3 | 4 | layout.gv: $(wildcard layout.*.gv) 5 | gvpack -Glayout=neato -Goverlap=false -Gsplines=true -Gbgcolor=transparent -o$@ $^ 6 | -------------------------------------------------------------------------------- /experiments/results/canonical_circuits/ioi/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | *.log 4 | tmp/ 5 | 6 | 7 | \#cbd5e8.gv 8 | \#f0f0f0.gv 9 | \#fff2ae.gv 10 | layout.\#e7f2da.gv 11 | layout.\#fad6e9.gv 12 | d7f8ee.gv 13 | \#f9ecd7.gv 14 | \#fff6db.gv 15 | layout.\#ececf5.gv 16 | layout.\#fee7d5.gv 17 | out.bak.xgr 18 | out.gv.bak 19 | -------------------------------------------------------------------------------- /experiments/results/canonical_circuits/ioi/Makefile: -------------------------------------------------------------------------------- 1 | layout.\#%.gv: \#%.gv 2 | neato -olayout.$@ $< 3 | 4 | layout.gv: $(wildcard layout.*.gv) 5 | gvpack -Glayout=neato -Goverlap=false -Gsplines=true -Gbgcolor=transparent -o$@ $^ 6 | -------------------------------------------------------------------------------- /experiments/results/plots/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | *.log 4 | tmp/ 5 | 6 | 7 | *.pdf 8 | data.csv 9 | -------------------------------------------------------------------------------- 
/experiments/results/plots_data/acdc-tracr-proportion-l2-False-1.json: -------------------------------------------------------------------------------- 1 | { 2 | "reset": { 3 | "random_ablation": { 4 | "tracr-proportion": { 5 | "l2": { 6 | "ACDC": { 7 | "score": [ 8 | Infinity, 9 | 0.1, 10 | 0.0631, 11 | 0.03981, 12 | 0.02512, 13 | 0.01585, 14 | 0.01, 15 | 0.00631, 16 | 0.00398, 17 | 0.00251, 18 | 0.00158, 19 | 0.001, 20 | 0.00063, 21 | 0.0004, 22 | 0.00025, 23 | 0.00016, 24 | 0.0001, 25 | 6e-05, 26 | 4e-05, 27 | 3e-05, 28 | 2e-05, 29 | 1e-05, 30 | -Infinity 31 | ], 32 | "test_kl_div": [ 33 | Infinity, 34 | 0.010016628541052341, 35 | 0.010016628541052341, 36 | 0.010016628541052341, 37 | 0.010016628541052341, 38 | 0.010016628541052341, 39 | 0.010016628541052341, 40 | 0.010016628541052341, 41 | 0.010016628541052341, 42 | 0.010016628541052341, 43 | 0.010016628541052341, 44 | 0.010016628541052341, 45 | 0.010016628541052341, 46 | 0.010016628541052341, 47 | 0.010016628541052341, 48 | 0.010016628541052341, 49 | 0.010016628541052341, 50 | 0.010016628541052341, 51 | 0.010016628541052341, 52 | 0.010016628541052341, 53 | 0.010016628541052341, 54 | 0.010016628541052341, 55 | -Infinity 56 | ], 57 | "steps": [ 58 | Infinity, 59 | 74, 60 | 74, 61 | 74, 62 | 74, 63 | 74, 64 | 74, 65 | 74, 66 | 74, 67 | 74, 68 | 74, 69 | 74, 70 | 74, 71 | 74, 72 | 74, 73 | 74, 74 | 74, 75 | 74, 76 | 74, 77 | 74, 78 | 74, 79 | 74, 80 | -Infinity 81 | ], 82 | "test_l2": [ 83 | Infinity, 84 | 0.11851852387189865, 85 | 0.11851852387189865, 86 | 0.11851852387189865, 87 | 0.11851852387189865, 88 | 0.11851852387189865, 89 | 0.11851852387189865, 90 | 0.11851852387189865, 91 | 0.11851852387189865, 92 | 0.11851852387189865, 93 | 0.11851852387189865, 94 | 0.11851852387189865, 95 | 0.11851852387189865, 96 | 0.11851852387189865, 97 | 0.11851852387189865, 98 | 0.11851852387189865, 99 | 0.11851852387189865, 100 | 0.11851852387189865, 101 | 0.11851852387189865, 102 | 0.11851852387189865, 103 | 0.11851852387189865, 104 | 0.11851852387189865, 105 | -Infinity 106 | ], 107 | "edge_fpr": [ 108 | 0.0, 109 | 0.0, 110 | 0.0, 111 | 0.0, 112 | 0.0, 113 | 0.0, 114 | 0.0, 115 | 0.0, 116 | 0.0, 117 | 0.0, 118 | 0.0, 119 | 0.0, 120 | 0.0, 121 | 0.0, 122 | 0.0, 123 | 0.0, 124 | 0.0, 125 | 0.0, 126 | 0.0, 127 | 0.0, 128 | 0.0, 129 | 0.0, 130 | 1.0 131 | ], 132 | "edge_tpr": [ 133 | 0.0, 134 | 0.0, 135 | 0.0, 136 | 0.0, 137 | 0.0, 138 | 0.0, 139 | 0.0, 140 | 0.0, 141 | 0.0, 142 | 0.0, 143 | 0.0, 144 | 0.0, 145 | 0.0, 146 | 0.0, 147 | 0.0, 148 | 0.0, 149 | 0.0, 150 | 0.0, 151 | 0.0, 152 | 0.0, 153 | 0.0, 154 | 0.0, 155 | 1.0 156 | ], 157 | "edge_precision": [ 158 | 1.0, 159 | 1, 160 | 1, 161 | 1, 162 | 1, 163 | 1, 164 | 1, 165 | 1, 166 | 1, 167 | 1, 168 | 1, 169 | 1, 170 | 1, 171 | 1, 172 | 1, 173 | 1, 174 | 1, 175 | 1, 176 | 1, 177 | 1, 178 | 1, 179 | 1, 180 | 0.0 181 | ], 182 | "n_edges": [ 183 | NaN, 184 | 0, 185 | 0, 186 | 0, 187 | 0, 188 | 0, 189 | 0, 190 | 0, 191 | 0, 192 | 0, 193 | 0, 194 | 0, 195 | 0, 196 | 0, 197 | 0, 198 | 0, 199 | 0, 200 | 0, 201 | 0, 202 | 0, 203 | 0, 204 | 0, 205 | NaN 206 | ], 207 | "node_fpr": [ 208 | 0.0, 209 | 0.0, 210 | 0.0, 211 | 0.0, 212 | 0.0, 213 | 0.0, 214 | 0.0, 215 | 0.0, 216 | 0.0, 217 | 0.0, 218 | 0.0, 219 | 0.0, 220 | 0.0, 221 | 0.0, 222 | 0.0, 223 | 0.0, 224 | 0.0, 225 | 0.0, 226 | 0.0, 227 | 0.0, 228 | 0.0, 229 | 0.0, 230 | 1.0 231 | ], 232 | "node_tpr": [ 233 | 0.0, 234 | 0.0, 235 | 0.0, 236 | 0.0, 237 | 0.0, 238 | 0.0, 239 | 0.0, 240 | 0.0, 241 | 0.0, 242 | 0.0, 243 | 0.0, 244 | 0.0, 245 | 0.0, 246 | 0.0, 
247 | 0.0, 248 | 0.0, 249 | 0.0, 250 | 0.0, 251 | 0.0, 252 | 0.0, 253 | 0.0, 254 | 0.0, 255 | 1.0 256 | ], 257 | "node_precision": [ 258 | 1.0, 259 | 1, 260 | 1, 261 | 1, 262 | 1, 263 | 1, 264 | 1, 265 | 1, 266 | 1, 267 | 1, 268 | 1, 269 | 1, 270 | 1, 271 | 1, 272 | 1, 273 | 1, 274 | 1, 275 | 1, 276 | 1, 277 | 1, 278 | 1, 279 | 1, 280 | 0.0 281 | ], 282 | "n_nodes": [ 283 | NaN, 284 | 0, 285 | 0, 286 | 0, 287 | 0, 288 | 0, 289 | 0, 290 | 0, 291 | 0, 292 | 0, 293 | 0, 294 | 0, 295 | 0, 296 | 0, 297 | 0, 298 | 0, 299 | 0, 300 | 0, 301 | 0, 302 | 0, 303 | 0, 304 | 0, 305 | NaN 306 | ] 307 | } 308 | } 309 | } 310 | } 311 | } 312 | } -------------------------------------------------------------------------------- /experiments/results/plots_data/generate_makefile.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from experiments.launcher import KubernetesJob, launch 4 | import shlex 5 | 6 | TASKS = ["ioi", "docstring", "greaterthan", "tracr-reverse", "tracr-proportion", "induction"] 7 | 8 | METRICS_FOR_TASK = { 9 | "ioi": ["kl_div", "logit_diff"], 10 | "tracr-reverse": ["l2"], 11 | "tracr-proportion": ["l2"], 12 | "induction": ["kl_div", "nll"], 13 | "docstring": ["kl_div", "docstring_metric"], 14 | "greaterthan": ["kl_div", "greaterthan"], 15 | } 16 | 17 | 18 | def main(): 19 | OUT_DIR = Path(__file__).resolve().parent 20 | 21 | actual_files = set(os.listdir(OUT_DIR)) 22 | trained_files = [] 23 | reset_files = [] 24 | sixteenh_files = [] 25 | sp_files = [] 26 | acdc_files = [] 27 | canonical_files = [] 28 | random_files = [] 29 | zero_files = [] 30 | ioi_files = [] 31 | docstring_files = [] 32 | greaterthan_files = [] 33 | tracr_reverse_files = [] 34 | tracr_proportion_files = [] 35 | induction_files = [] 36 | 37 | with open(OUT_DIR/ "Makefile", "w") as f: 38 | possible_files = {"generate_makefile.py", "Makefile"} 39 | 40 | for alg in ["16h", "sp", "acdc", "canonical"]: 41 | for reset_network in [0, 1]: 42 | for zero_ablation in [0, 1]: 43 | for task in TASKS: 44 | for metric in METRICS_FOR_TASK[task]: 45 | if alg == "canonical" and metric == "kl_div": 46 | # No need to repeat the canonical calculations for both train metrics 47 | # (they're the same, nothing is trained) 48 | continue 49 | 50 | fname = f"{alg}-{task}-{metric}-{bool(zero_ablation)}-{reset_network}.json" 51 | possible_files.add(fname) 52 | 53 | 54 | command = [ 55 | "python", 56 | "../../../notebooks/roc_plot_generator.py", 57 | f"--task={task}", 58 | f"--reset-network={reset_network}", 59 | f"--metric={metric}", 60 | f"--alg={alg}", 61 | ] 62 | if zero_ablation: 63 | command.append("--zero-ablation") 64 | 65 | f.write(fname + ":\n" + "\t" + shlex.join(command) + "\n\n") 66 | 67 | if alg == "16h": 68 | sixteenh_files.append(fname) 69 | elif alg == "sp": 70 | sp_files.append(fname) 71 | elif alg == "acdc": 72 | acdc_files.append(fname) 73 | elif alg == "canonical": 74 | canonical_files.append(fname) 75 | 76 | if reset_network: 77 | reset_files.append(fname) 78 | else: 79 | trained_files.append(fname) 80 | 81 | if zero_ablation: 82 | zero_files.append(fname) 83 | else: 84 | random_files.append(fname) 85 | 86 | if task == "ioi": 87 | ioi_files.append(fname) 88 | elif task == "docstring": 89 | docstring_files.append(fname) 90 | elif task == "greaterthan": 91 | greaterthan_files.append(fname) 92 | elif task == "tracr-reverse": 93 | tracr_reverse_files.append(fname) 94 | elif task == "tracr-proportion": 95 | tracr_proportion_files.append(fname) 96 | elif 
task == "induction": 97 | induction_files.append(fname) 98 | 99 | f.write("all: " + " ".join(sorted(possible_files)) + "\n\n") 100 | f.write("16h: " + " ".join(sorted(sixteenh_files)) + "\n\n") 101 | f.write("sp: " + " ".join(sorted(sp_files)) + "\n\n") 102 | f.write("acdc: " + " ".join(sorted(acdc_files)) + "\n\n") 103 | f.write("canonical: " + " ".join(sorted(canonical_files)) + "\n\n") 104 | f.write("trained: " + " ".join(sorted(trained_files)) + "\n\n") 105 | f.write("reset: " + " ".join(sorted(reset_files)) + "\n\n") 106 | f.write("zero: " + " ".join(sorted(zero_files)) + "\n\n") 107 | f.write("random: " + " ".join(sorted(random_files)) + "\n\n") 108 | f.write("ioi: " + " ".join(sorted(ioi_files)) + "\n\n") 109 | f.write("docstring: " + " ".join(sorted(docstring_files)) + "\n\n") 110 | f.write("greaterthan: " + " ".join(sorted(greaterthan_files)) + "\n\n") 111 | f.write("tracr-reverse: " + " ".join(sorted(tracr_reverse_files)) + "\n\n") 112 | f.write("tracr-proportion: " + " ".join(sorted(tracr_proportion_files)) + "\n\n") 113 | f.write("induction: " + " ".join(sorted(induction_files)) + "\n\n") 114 | 115 | print(actual_files - possible_files) 116 | assert len(actual_files - possible_files) == 0, "There are files that shouldn't be there" 117 | 118 | missing_files = possible_files - actual_files 119 | print(f"Missing {len(missing_files)} files:") 120 | for missing_file in missing_files: 121 | print(missing_file) 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /ims/make_jsons.py: -------------------------------------------------------------------------------- 1 | #%% 2 | 3 | import plotly 4 | 5 | # Set renderers 6 | plotly.io.renderers.default = "jupyterlab" # apparently the secret sauce to make latex work in notebooks (not requireing setting renderer to browser?) 
7 | # that failed 8 | 9 | plotly.io.renderers.default = "browser" 10 | 11 | # note: LATEX works in scripts, 12 | 13 | 14 | # %% 15 | 16 | fname = "induction_json.json" 17 | plotly_graph = plotly.io.read_json(fname) 18 | 19 | # %% 20 | 21 | plotly_graph.show() 22 | 23 | # %% 24 | 25 | # Get new data 26 | paper_corrupted_fname = "current_paper_induction_json.json" 27 | paper_corrupted_figure = plotly.io.read_json(paper_corrupted_fname) 28 | paper_zero_figure = plotly.io.read_json("current_paper_induction_zero.json") 29 | 30 | # %% 31 | 32 | scatter_names = [ 33 | "ACDC", 34 | "SP", 35 | "HISP", 36 | ] 37 | 38 | x_data = {} 39 | y_data = {} 40 | color_data = {} 41 | 42 | figures = { 43 | "corrupted": paper_corrupted_figure, 44 | "zero": paper_zero_figure, 45 | } 46 | 47 | for figure_name, figure in figures.items(): 48 | for name in scatter_names: 49 | x_data[name] = [] 50 | y_data[name] = [] 51 | 52 | x_data_element = [thing for thing in figure["data"] if thing["name"] == name][0] 53 | y_data_element = [thing for thing in figure["data"] if thing["name"] == name][0] 54 | 55 | x_data[(figure_name, name)] = x_data_element["x"] 56 | y_data[(figure_name, name)] = y_data_element["y"] 57 | 58 | color_data[(figure_name, name)] = x_data_element["marker"]["color"] 59 | 60 | # %% 61 | 62 | # Add this into the paper figure 63 | 64 | the_ref = {} 65 | 66 | for i in range(len(plotly_graph.data)): 67 | if plotly_graph.data[i]["yaxis"] == "y" and plotly_graph.data[i]["name"] in scatter_names: 68 | plotly_graph.data[i]["x"] = x_data[("corrupted", plotly_graph.data[i]["name"])] 69 | plotly_graph.data[i]["y"] = y_data[("corrupted", plotly_graph.data[i]["name"])] 70 | if plotly_graph.data[i]["name"] == "ACDC": 71 | cur_col_data = color_data[("corrupted", plotly_graph.data[i]["name"])] 72 | plotly_graph.data[i]["marker"]["color"] = cur_col_data 73 | the_ref[plotly_graph.data[i]["name"]] =plotly_graph.data[i]["marker"] 74 | 75 | if plotly_graph.data[i]["yaxis"] == "y2" and plotly_graph.data[i]["name"] in scatter_names: 76 | plotly_graph.data[i]["x"] = x_data[("zero", plotly_graph.data[i]["name"])] 77 | plotly_graph.data[i]["y"] = y_data[("zero", plotly_graph.data[i]["name"])] 78 | # if plotly_graph.data[i]["name"] == "ACDC": 79 | # cur_col_data = color_data[("zero", plotly_graph.data[i]["name"])] 80 | # plotly_graph.data[i]["marker"]["color"] = cur_col_data 81 | # cur_col_data = color_data[("corrupted", plotly_graph.data[i]["name"])] 82 | 83 | 84 | for i in range(len(plotly_graph.data)): 85 | if plotly_graph.data[i]["yaxis"] == "y2" and plotly_graph.data[i]["name"] in scatter_names: 86 | print(plotly_graph.data[i]) 87 | print("Hey") 88 | plotly_graph.data[i]["marker"] = the_ref[plotly_graph.data[i]["name"]] # color_data[("corrupted", plotly_graph.data[i]["name"])] 89 | 90 | # %% 91 | 92 | plotly_graph.show() 93 | 94 | # %% 95 | 96 | lis = [] 97 | 98 | for i in range(len(plotly_graph.data)): 99 | if plotly_graph.data[i]["yaxis"] == "y3": 100 | # print(plotly_graph.data[i]) 101 | # break 102 | # print(i) 103 | lis.append(i) 104 | 105 | # %% 106 | 107 | # Remove this from the paper figure 108 | 109 | for l in sorted(lis[1:], reverse=True): 110 | plotly_graph.data = plotly_graph.data[:l] + plotly_graph.data[l+1:] 111 | 112 | # %% 113 | 114 | fig = plotly.io.read_json("my_new_plotly_graph.json") 115 | 116 | # %% 117 | 118 | fig.show() 119 | # %% 120 | 121 | # TODO add better legend 122 | -------------------------------------------------------------------------------- /notebooks/_converted/.gitignore: 
-------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | *.log 4 | tmp/ 5 | 6 | 7 | *.ipynb 8 | -------------------------------------------------------------------------------- /notebooks/auc_tables.py: -------------------------------------------------------------------------------- 1 | #%% 2 | 3 | from IPython import get_ipython 4 | ipython = get_ipython() 5 | if ipython is not None: 6 | ipython.run_line_magic('load_ext', 'autoreload') 7 | ipython.run_line_magic('autoreload', '2') 8 | import io 9 | import numpy as np 10 | 11 | import pandas as pd 12 | from pathlib import Path 13 | from tabulate import tabulate 14 | import argparse 15 | 16 | parser = argparse.ArgumentParser( 17 | usage="Generate AUC tables from CSV files. Pass the data.csv file as an argument fname, e.g python notebooks/auc_tables.py --fname=experiments/results/plots/data.csv" 18 | ) 19 | parser.add_argument('--in-fname', type=str, default="experiments/results/plots/data.csv") 20 | parser.add_argument('--out-fname', type=str, default="experiments/results/auc_tables.tex") 21 | 22 | if ipython is None: 23 | args = parser.parse_args() 24 | else: # make parsing arguments work in jupyter notebook 25 | args = parser.parse_args(args=[]) 26 | 27 | data = pd.read_csv(args.in_fname) 28 | 29 | # %% 30 | with io.StringIO() as buf: 31 | for key in ["auc"]: # ["test_kl_div", "test_loss", "auc"]: 32 | for weights_type in ["trained"]: # ["reset", "trained"] 33 | df = data[(data["weights_type"] == weights_type)] 34 | df = df.replace({"metric": df.metric.map(lambda x: "other" if x != "kl_div" else x)}) 35 | df = df.drop_duplicates(subset=["task", "method", "metric", "ablation_type", "plot_type"]) 36 | 37 | def process_metric_pretty(row): 38 | if row["metric"] == "kl_div": 39 | return "KL" 40 | else: 41 | return "Loss" 42 | 43 | df["metric_pretty"] = df.apply(process_metric_pretty, axis=1) 44 | out = df.drop("Unnamed: 0", axis=1).pivot_table( 45 | index=["metric_pretty", "task"], columns=["ablation_type", "plot_type", "method"], values=key 46 | ) 47 | # Needed to handle non-AUC keys 48 | # out = out.applymap(lambda x: None if x == -1 else (texts[x] if isinstance(x, int) else x)) 49 | out = out.dropna(axis=0) 50 | 51 | # %% Export as latex 52 | def export_table(out, name): 53 | def make_bold_column(row): 54 | out = pd.Series(dtype=np.float64) 55 | for plot_type in ["roc_edges", "roc_nodes"]: 56 | the_max = row.loc[plot_type].max() 57 | out = pd.concat( 58 | [ 59 | out, 60 | pd.Series( 61 | data=[ 62 | f"\\textbf{{{x:.3f}}}" if x == the_max else f"{x:.3f}" 63 | for x in row.loc[plot_type] 64 | ], 65 | index=row.loc[[plot_type]].index, 66 | ), 67 | ] 68 | ) 69 | return out 70 | 71 | old_out = out 72 | out = out.apply(make_bold_column, axis=1) 73 | out.columns = pd.MultiIndex.from_tuples(out.columns) 74 | out.style.to_latex(buf, hrules=True, environment="table", caption=name) 75 | 76 | export_table(out.random_ablation, f"Key={key}, weights_type={weights_type}, Random Ablation") 77 | export_table(out.zero_ablation, f"Key={key}, weights_type={weights_type}, Zero Ablation") 78 | 79 | with open(args.out_fname, "w") as f: 80 | f.write( 81 | r"""\documentclass{article} 82 | \usepackage{booktabs} 83 | \usepackage{multirow} 84 | \begin{document} 85 | 86 | WARNING: you need to add some vertical and horizontal rows for this to look good 87 | 88 | """ 89 | ) 90 | f.write(buf.getvalue().replace("_", "\\_")) 91 | f.write( 92 | r"""\end{document} 93 | """ 94 | ) 95 | 
-------------------------------------------------------------------------------- /notebooks/convert_to_ipynb.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # TODO add this all to the Makefile 4 | 5 | set -e 6 | 7 | # Check for --skip-run flag 8 | skip_run=false 9 | for arg in "$@" 10 | do 11 | if [ "$arg" == "--skip-run" ]; then 12 | skip_run=true 13 | fi 14 | done 15 | 16 | required_commands=("jupytext" "papermill" "python") 17 | 18 | # Loop over each required command 19 | for command in "${required_commands[@]}"; do 20 | # Check if the command is installed 21 | if ! command -v $command &> /dev/null 22 | then 23 | echo "$command could not be found" 24 | echo "Please install it using 'pip install $command'" 25 | exit 26 | fi 27 | done 28 | 29 | # Define the file paths 30 | declare -A file_paths 31 | file_paths=( 32 | ["notebooks/editing_edges.py"]="notebooks/_converted/editing_edges.ipynb notebooks/colabs/ACDC_Editing_Edges_Demo.ipynb" 33 | ["acdc/main.py"]="notebooks/_converted/main_demo.ipynb notebooks/colabs/ACDC_Main_Demo.ipynb" 34 | ["notebooks/implementation_demo.py"]="notebooks/_converted/implementation_demo.ipynb notebooks/colabs/ACDC_Implementation_Demo.ipynb" 35 | ) 36 | 37 | # Loop over each file path 38 | for in_path in "${!file_paths[@]}"; do 39 | # Split the output paths 40 | IFS=' ' read -r -a out_paths <<< "${file_paths[$in_path]}" 41 | 42 | middle_path=${out_paths[0]} 43 | final_out_path=${out_paths[1]} 44 | 45 | # Run jupytext and papermill 46 | jupytext --to notebook "$in_path" -o "$middle_path" 47 | 48 | if ! $skip_run; then 49 | papermill "$middle_path" "$final_out_path" --kernel=python 50 | 51 | # TODO fix this; it seems some errored files are slipping through 52 | python -c " 53 | import nbformat 54 | nb = nbformat.read('$final_out_path', as_version=4) 55 | errors = [cell for cell in nb['cells'] if 'outputs' in cell and any(output.get('output_type') == 'error' for output in cell['outputs'])] 56 | if errors: 57 | raise Exception(f'Error: The following cells failed in notebook $final_out_path:\n{errors}') 58 | " 59 | else 60 | cp "$middle_path" "$final_out_path" 61 | fi 62 | done -------------------------------------------------------------------------------- /notebooks/df_plots_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from warnings import warn 3 | import warnings 4 | from IPython import get_ipython 5 | if get_ipython() is not None: 6 | get_ipython().magic('load_ext autoreload') 7 | get_ipython().magic('autoreload 2') 8 | 9 | __file__ = os.path.join(get_ipython().run_line_magic('pwd', ''), "notebooks", "df_plots_data.py") 10 | 11 | from notebooks.emacs_plotly_render import set_plotly_renderer 12 | if "adria" in __file__: 13 | set_plotly_renderer("emacs") 14 | 15 | import plotly 16 | import numpy as np 17 | import json 18 | import wandb 19 | from acdc.acdc_graphics import dict_merge, pessimistic_auc 20 | import time 21 | from plotly.subplots import make_subplots 22 | import plotly.graph_objects as go 23 | import plotly.colors as pc 24 | from pathlib import Path 25 | import plotly.express as px 26 | import pandas as pd 27 | import argparse 28 | 29 | # %% 30 | 31 | DATA_DIR = Path(__file__).resolve().parent.parent / "experiments" / "results" / "plots_data" 32 | all_data = {} 33 | 34 | for fname in os.listdir(DATA_DIR): 35 | if fname.endswith(".json"): 36 | with open(DATA_DIR / fname, "r") as f: 37 | data = json.load(f) 38 | dict_merge(all_data, 
data) 39 | 40 | 41 | 42 | # %% Possibly convert all this data to pandas dataframe 43 | 44 | rows = [] 45 | for weights_type, v in all_data.items(): 46 | for ablation_type, v2 in v.items(): 47 | for task, v3 in v2.items(): 48 | for metric, v4 in v3.items(): 49 | for alg, v5 in v4.items(): 50 | for i in range(len(v5["score"])): 51 | rows.append(pd.Series({ 52 | "weights_type": weights_type, 53 | "ablation_type": ablation_type, 54 | "task": task, 55 | "metric": metric, 56 | "alg": alg, 57 | **{k: val[i] for k, val in v5.items()}})) 58 | 59 | df = pd.DataFrame(rows) 60 | 61 | # %% Print KL 62 | present = df[(df["alg"] == "CANONICAL") & (df["weights_type"] == "trained") & (df["score"] == 1.0)][ 63 | ["ablation_type", "task", "metric", "test_kl_div"] 64 | ] 65 | present.sort_values(["task", "ablation_type"]) 66 | 67 | # %% Print Docstring stuff 68 | 69 | for TASK in ["docstring", "ioi", "greaterthan", "tracr-proportion", "tracr-reverse"]: 70 | print(TASK) 71 | present = df[ 72 | (df["alg"] == "CANONICAL") 73 | & (df["weights_type"] == "trained") 74 | & (df["metric"] != "kl_div") 75 | & (df["task"] == TASK) 76 | & (np.isfinite(df["score"])) 77 | ][["ablation_type", "score", *filter(lambda x: x.startswith("test_"), df.columns)]] 78 | present["type"] = present["score"].map(lambda x: {0.0: "corrupted_model", 1.0: "clean_model", 0.5: "canonical"}[x]) 79 | out = present[["ablation_type", "type", "test_docstring_metric", *filter(lambda x: x.startswith("test_"), present.columns)]] 80 | 81 | out.dropna(axis=1, how="all", inplace=True) 82 | print(out) 83 | -------------------------------------------------------------------------------- /notebooks/easier_roc_plot.py: -------------------------------------------------------------------------------- 1 | # %% 2 | # 3 | import plotly 4 | import os 5 | import numpy as np 6 | import json 7 | import wandb 8 | import time 9 | from plotly.subplots import make_subplots 10 | import plotly.graph_objects as go 11 | from pathlib import Path 12 | import plotly.express as px 13 | import pandas as pd 14 | import argparse 15 | import plotly.colors as pc 16 | from acdc.graphics import dict_merge, pessimistic_auc 17 | 18 | from notebooks.emacs_plotly_render import set_plotly_renderer 19 | 20 | set_plotly_renderer("emacs") 21 | 22 | 23 | # %% 24 | 25 | DATA_DIR = Path("acdc") / "media" / "plots_data" 26 | 27 | all_data = {} 28 | 29 | for fname in os.listdir(DATA_DIR): 30 | if fname.endswith(".json"): 31 | with open(DATA_DIR / fname, "r") as f: 32 | data = json.load(f) 33 | dict_merge(all_data, data) 34 | 35 | # %% 36 | 37 | 38 | def discard_non_pareto_optimal(points, cmp="gt"): 39 | ret = [] 40 | for x, y in points: 41 | for x1, y1 in points: 42 | if x1 < x and getattr(y1, f"__{cmp}__")(y) and (x1, y1) != (x, y): 43 | break 44 | else: 45 | ret.append((x, y)) 46 | return list(sorted(ret)) 47 | 48 | 49 | fig = make_subplots(rows=1, cols=2, subplot_titles=["ROC Curves"], column_widths=[0.95, 0.05], horizontal_spacing=0.03) 50 | 51 | colorscales = { 52 | "ACDC": "Blues", 53 | "SP": "Greens", 54 | } 55 | 56 | for i, alg in enumerate(["ACDC", "SP"]): 57 | this_data = all_data["trained"]["random_ablation"]["ioi"]["logit_diff"][alg] 58 | x_data = this_data["edge_fpr"] 59 | y_data = this_data["edge_tpr"] 60 | scores = this_data["score"] 61 | 62 | log_scores = np.log10(scores) 63 | log_scores = np.nan_to_num(log_scores, nan=0.0, neginf=0.0, posinf=0.0) 64 | 65 | min_score = np.min(log_scores) 66 | max_score = np.max(log_scores) 67 | 68 | normalized_scores = (log_scores - min_score) / 
(max_score - min_score) 69 | normalized_scores[~np.isfinite(normalized_scores)] = 0.0 70 | 71 | points = list(zip(x_data, y_data)) 72 | pareto_optimal = discard_non_pareto_optimal(points) 73 | 74 | methodof = "acdc" 75 | 76 | pareto_x_data, pareto_y_data = zip(*pareto_optimal) 77 | fig.add_trace( 78 | go.Scatter( 79 | x=pareto_x_data, 80 | y=pareto_y_data, 81 | name=methodof, 82 | mode="lines", 83 | line=dict( 84 | shape="hv", 85 | color=pc.sample_colorscale(pc.get_colorscale(colorscales[alg]), 0.7)[0], 86 | ), 87 | showlegend=False, 88 | hovertext=log_scores, 89 | ), 90 | row=1, 91 | col=1, 92 | ) 93 | 94 | N_TICKS = 5 95 | tickvals = np.linspace(0, 1, N_TICKS) 96 | ticktext = 10 ** np.linspace(min_score, max_score, N_TICKS) 97 | 98 | fig.add_trace( 99 | go.Scatter( 100 | x=x_data, 101 | y=y_data, 102 | name=methodof, 103 | mode="markers", 104 | showlegend=False, 105 | marker=dict( 106 | size=7, 107 | color=normalized_scores, 108 | colorscale=colorscales[alg], 109 | symbol="circle", 110 | # colorbar=dict( 111 | # title="Log scores", 112 | # tickvals=tickvals, # positions for ticks 113 | # ticktext=["%.2e" % i for i in ticktext], # tick labels, formatted as strings 114 | # thickness=5, 115 | # # y=0.25, 116 | # # len=0.5, 117 | # x=1 + i * 0.1, 118 | # ), 119 | showscale=False, 120 | ), 121 | ), 122 | row=1, 123 | col=1, 124 | ) 125 | nums = np.arange(200).reshape(2, 100).T.astype(float) 126 | nums[:20, :20] = np.nan 127 | fig.add_trace( 128 | go.Heatmap( 129 | z=nums, 130 | colorscale='Viridis', 131 | showscale=False, 132 | ), 133 | row=1, 134 | col=2, 135 | ) 136 | 137 | fig.update_xaxes(showline=False, zeroline=False, showgrid=False, row=1, col=2, showticklabels=False, ticks="") 138 | fig.update_yaxes(showline=False, zeroline=False, showgrid=False, row=1, col=2, side="right") 139 | fig.show() 140 | -------------------------------------------------------------------------------- /notebooks/editing_edges.py: -------------------------------------------------------------------------------- 1 | # %% [markdown] 2 | #
ACDC Editing Edges Demo
3 | # 4 | # This notebook gives a high-level overview of the main abstractions used in the ACDC codebase.
5 | # 6 | # If you are interested in models that are >=1B parameters, this library may currently be too slow, and we recommend you look at the path patching implementations in `TransformerLens` (for example, see this notebook)
7 | # 8 | # Setup
9 | # 10 | # Janky code to do different setup when run in a Colab notebook vs VSCode (adapted from, e.g., this notebook)
11 | # 12 | # You can ignore warnings that "packages were previously imported in this runtime"
13 | 14 | #%% 15 | 16 | try: 17 | import google.colab 18 | 19 | IN_COLAB = True 20 | print("Running as a Colab notebook") 21 | 22 | import subprocess # to install graphviz dependencies 23 | command = ['apt-get', 'install', 'graphviz-dev'] 24 | subprocess.run(command, check=True) 25 | 26 | from IPython import get_ipython 27 | ipython = get_ipython() 28 | 29 | ipython.run_line_magic( # install ACDC 30 | "pip", 31 | "install git+https://github.com/ArthurConmy/Automatic-Circuit-Discovery.git@d89f7fa9cbd095202f3940c889cb7c6bf5a9b516", 32 | ) 33 | 34 | except Exception as e: 35 | IN_COLAB = False 36 | print("Running outside of Colab notebook") 37 | 38 | import numpy # crucial to not get cursed error 39 | import plotly 40 | 41 | plotly.io.renderers.default = "colab" # added by Arthur so running as a .py notebook with #%% generates .ipynb notebooks that display in colab 42 | # disable this option when developing rather than generating notebook outputs 43 | 44 | from IPython import get_ipython 45 | 46 | ipython = get_ipython() 47 | if ipython is not None: 48 | print("Running as a notebook") 49 | ipython.run_line_magic("load_ext", "autoreload") # type: ignore 50 | ipython.run_line_magic("autoreload", "2") # type: ignore 51 | else: 52 | print("Running as a .py script") 53 | 54 | # %% [markdown] 55 | #
Imports etc
56 | 57 | #%% 58 | 59 | from transformer_lens.HookedTransformer import HookedTransformer 60 | from acdc.TLACDCExperiment import TLACDCExperiment 61 | from acdc.induction.utils import get_all_induction_things 62 | from acdc.acdc_utils import TorchIndex 63 | import torch 64 | import gc 65 | 66 | # %% [markdown] 67 | #
Load in the model and data for the induction task 68 | 69 | #%% 70 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 71 | num_examples = 40 72 | seq_len = 50 73 | 74 | # load in a tl_model and grab some data 75 | all_induction_things = get_all_induction_things( 76 | num_examples=num_examples, 77 | seq_len=seq_len, 78 | device=DEVICE, 79 | ) 80 | 81 | tl_model, toks_int_values, toks_int_values_other, metric, mask_rep = ( 82 | all_induction_things.tl_model, 83 | all_induction_things.validation_data, 84 | all_induction_things.validation_patch_data, 85 | all_induction_things.validation_metric, 86 | all_induction_things.validation_mask, 87 | ) 88 | 89 | # You should read the get_model function from that file to see what the tl_model is : ) 90 | 91 | # %% [markdown] 92 | #
Ensure we stay under mem limit on small machines
93 | 94 | #%% 95 | gc.collect() 96 | torch.cuda.empty_cache() 97 | 98 | # %% [markdown] 99 | #
Let's see an example from the dataset. 100 | # `|` separates tokens
101 | 102 | #%% 103 | EXAMPLE_NO = 33 104 | EXAMPLE_LENGTH = 36 105 | 106 | print( 107 | "|".join(tl_model.to_str_tokens(toks_int_values[EXAMPLE_NO, :EXAMPLE_LENGTH])), 108 | ) 109 | 110 | #%% [markdown] 111 | #
This dataset has several examples of induction! F -> #, mon -> ads 112 | # The `mask_rep` mask is a boolean mask of shape `(num_examples, seq_len)` that indicates where induction is present in the dataset
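# Aside, added for clarity (not part of the original notebook): a minimal sketch of using
# mask_rep directly, with the tensors loaded above:
num_induction_positions = int(mask_rep.sum())  # total number of positions flagged as induction
induction_tokens = toks_int_values[mask_rep]   # 1-D tensor of the token ids at exactly those positions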
113 | #
Let's see 114 | 115 | #%% 116 | for i in range(EXAMPLE_LENGTH): 117 | if mask_rep[EXAMPLE_NO, i]: 118 | print(f"At position {i} there is induction") 119 | print(tl_model.to_str_tokens(toks_int_values[EXAMPLE_NO:EXAMPLE_NO+1, i : i + 1])) 120 | 121 | # %% [markdown] 122 | #
Let's get the initial loss on the induction examples
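# Added note: with return_type="loss" and loss_per_token=True, TransformerLens returns a
# per-token loss of shape (batch, seq_len - 1), since the final position has no next-token
# target; this is why get_loss below truncates the mask with mask[:, :-1].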
123 | 124 | #%% 125 | def get_loss(model, data, mask): 126 | loss = model( 127 | data, 128 | return_type="loss", 129 | loss_per_token=True, 130 | ) 131 | return (loss * mask[:, :-1].int()).sum() / mask[:, :-1].int().sum() 132 | 133 | 134 | print(f"Loss: {get_loss(tl_model, toks_int_values, mask_rep)}") 135 | 136 | #%% [markdown] 137 | #
We will now wrap the ACDC machinery inside an `experiment` for further experiments 138 | # For more advanced usage of the `TLACDCExperiment` object (the main object in this codebase), see the README for links to `main.py` and its demos
139 | 140 | #%% 141 | experiment = TLACDCExperiment( 142 | model=tl_model, 143 | threshold=0.0, 144 | ds=toks_int_values, 145 | ref_ds=None, # This argument is the corrupted dataset from the ACDC paper. We're going to do zero ablation here so we omit this 146 | metric=metric, 147 | zero_ablation=True, 148 | hook_verbose=False, 149 | ) 150 | 151 | # %% [markdown] 152 | 153 | #
Usually, the `TLACDCExperiment` efficiently adds hooks to the model in order to do ACDC runs fast. 154 | # For this tutorial, we'll add ALL the hooks so you can edit connections in the model as easily as possible.
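# Added gloss (my reading of the codebase, not the authors' wording): "sender" hooks sit on
# upstream nodes and record or patch their outputs, while "receiver" hooks sit on downstream
# nodes and reassemble their inputs edge by edge; installing both kinds up front means a
# change to an edge's .present flag takes effect on the next forward pass.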
155 | 156 | #%% 157 | experiment.model.reset_hooks() 158 | experiment.setup_model_hooks( 159 | add_sender_hooks=True, 160 | add_receiver_hooks=True, 161 | doing_acdc_runs=False, 162 | ) 163 | 164 | # %% [markdown] 165 | # Let's take a look at the edges 166 | 167 | #%% 168 | for edge_indices, edge in experiment.corr.all_edges().items(): 169 | # here's what's inside the edge 170 | receiver_name, receiver_index, sender_name, sender_index = edge_indices 171 | 172 | # for now, all edges should be present 173 | assert edge.present, edge_indices 174 | 175 | # %% [markdown] 176 | #
Let's make a function that can turn off all the connections from the nodes to the output, except the induction heads (1.5 and 1.6) 177 | # (we'll later turn ON all connections EXCEPT the induction heads)
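# Added note (inferred from usage in this notebook, not an official docstring): in
# TorchIndex([None, None, 5]), the leading Nones slice over the batch and sequence dimensions
# of hook_result and the final integer selects attention head 5; TorchIndex([None]) addresses
# a whole residual-stream node such as blocks.1.hook_resid_post.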
178 | 179 | #%% 180 | def change_direct_output_connections(exp, invert=False): 181 | residual_stream_end_name = "blocks.1.hook_resid_post" 182 | residual_stream_end_index = TorchIndex([None]) 183 | induction_heads = [ 184 | ("blocks.1.attn.hook_result", TorchIndex([None, None, 5])), 185 | ("blocks.1.attn.hook_result", TorchIndex([None, None, 6])), 186 | ] 187 | 188 | inputs_to_residual_stream_end = exp.corr.edges[residual_stream_end_name][ 189 | residual_stream_end_index 190 | ] 191 | for sender_name in inputs_to_residual_stream_end: 192 | for sender_index in inputs_to_residual_stream_end[sender_name]: 193 | edge = inputs_to_residual_stream_end[sender_name][sender_index] 194 | is_induction_head = (sender_name, sender_index) in induction_heads 195 | 196 | if is_induction_head: 197 | edge.present = not invert 198 | 199 | else: 200 | edge.present = invert 201 | 202 | print( 203 | f"{'Adding' if (invert == is_induction_head) else 'Removing'} edge from {sender_name} {sender_index} to {residual_stream_end_name} {residual_stream_end_index}" 204 | ) 205 | 206 | 207 | change_direct_output_connections(experiment) 208 | print( 209 | "Loss with only the induction head direct connections:", 210 | get_loss(experiment.model, toks_int_values, mask_rep).item(), 211 | ) 212 | 213 | # %% [markdown] 214 | #

Let's turn ON all the connections EXCEPT the induction heads

215 | 216 | #%% 217 | change_direct_output_connections(experiment, invert=True) 218 | print( 219 | "Loss without the induction head direct connections:", 220 | get_loss(experiment.model, toks_int_values, mask_rep).item(), 221 | ) 222 | 223 | #%% [markdown] 224 | #

That loss is much larger!

225 | #

See `acdc/main.py` for how to run ACDC experiments: try `python acdc/main.py --help`, or check the README for links to this file
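As a final recap, here is a minimal sketch (our own addition, not part of the original demo) of how you might flip a single edge by hand using the objects defined above. The hook names and indices mirror the tutorial code, so double-check them against your version of the codebase:

```python
# Hedged sketch: re-enable just the head 1.5 -> output edge and re-measure the loss.
# Assumes `experiment`, `get_loss`, `toks_int_values`, `mask_rep` and `TorchIndex`
# are all in scope, as in the tutorial cells above.
sender_name = "blocks.1.attn.hook_result"   # output of attention head 1.5
sender_index = TorchIndex([None, None, 5])
receiver_name = "blocks.1.hook_resid_post"  # the direct connection to the output
receiver_index = TorchIndex([None])

edge = experiment.corr.edges[receiver_name][receiver_index][sender_name][sender_index]
edge.present = True  # flip this single connection back on

print(
    "Loss with the head 1.5 -> output edge restored:",
    get_loss(experiment.model, toks_int_values, mask_rep).item(),
)
```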

226 | -------------------------------------------------------------------------------- /notebooks/emacs_plotly_render.py: -------------------------------------------------------------------------------- 1 | """A utils file from Adria""" 2 | 3 | import hashlib 4 | import os 5 | 6 | import plotly.io as pio 7 | 8 | 9 | class EmacsRenderer(pio.base_renderers.ColabRenderer): 10 | save_dir = "ob-jupyter" 11 | base_url = "http://localhost:8888/files" 12 | 13 | def to_mimebundle(self, fig_dict): 14 | html = super().to_mimebundle(fig_dict)["text/html"] 15 | 16 | mhash = hashlib.md5(html.encode("utf-8")).hexdigest() 17 | if not os.path.isdir(self.save_dir): 18 | os.mkdir(self.save_dir) 19 | fhtml = os.path.join(self.save_dir, mhash + ".html") 20 | with open(fhtml, "w") as f: 21 | f.write(html) 22 | 23 | return {"text/html": f'<a href="{self.base_url}/{fhtml}" target="_blank">Click to open {fhtml}</a>'} 24 | 25 | 26 | pio.renderers["emacs"] = EmacsRenderer() 27 | 28 | 29 | def set_plotly_renderer(renderer="emacs"): 30 | pio.renderers.default = renderer 31 | -------------------------------------------------------------------------------- /notebooks/minimal_acdc_node_roc.py: -------------------------------------------------------------------------------- 1 | #%% 2 | 3 | from IPython import get_ipython 4 | ipython = get_ipython() 5 | if ipython is not None: 6 | ipython.magic("%load_ext autoreload") 7 | ipython.magic("%autoreload 2") 8 | import os 9 | from pathlib import Path 10 | import json 11 | import plotly.graph_objects as go 12 | import plotly.express as px 13 | import matplotlib.pyplot as plt 14 | 15 | #%% 16 | 17 | # Set your root directory here 18 | ROOT_DIR = Path("/home/arthur/Documents/Automatic-Circuit-Discovery") 19 | assert ROOT_DIR.exists(), f"I don't think your ROOT_DIR is correct (ROOT_DIR = {ROOT_DIR})" 20 | 21 | # %% 22 | 23 | TASK = "ioi" 24 | METRIC = "kl_div" 25 | FNAME = f"experiments/results/plots_data/acdc-{TASK}-{METRIC}-False-0.json" 26 | FPATH = ROOT_DIR / FNAME 27 | assert FPATH.exists(), f"I don't think your FNAME is correct (FPATH = {FPATH})" 28 | 29 | # %% 30 | 31 | data = json.load(open(FPATH, "r")) 32 | 33 | # %% 34 | 35 | relevant_data = data["trained"]["random_ablation"]["ioi"]["kl_div"]["ACDC"] 36 | 37 | # %% 38 | 39 | node_tpr = relevant_data["node_tpr"] 40 | node_fpr = relevant_data["node_fpr"] 41 | 42 | # %% 43 | 44 | # We would just plot these, but sometimes points are not on the Pareto frontier 45 | 46 | def pareto_optimal_sublist(xs, ys): 47 | retx, rety = [], [] 48 | for x, y in zip(xs, ys): 49 | for x1, y1 in zip(xs, ys): 50 | if x1 > x and y1 < y: 51 | break 52 | else: 53 | retx.append(x) 54 | rety.append(y) 55 | indices = sorted(range(len(retx)), key=lambda i: retx[i]) 56 | return [retx[i] for i in indices], [rety[i] for i in indices] 57 | 58 | # %% 59 | 60 | pareto_node_tpr, pareto_node_fpr = pareto_optimal_sublist(node_tpr, node_fpr) 61 | 62 | # %% 63 | 64 | # Thanks GPT-4 for this code 65 | 66 | # Create the plot 67 | plt.figure() 68 | 69 | # Plot the ROC curve 70 | plt.step(pareto_node_fpr, pareto_node_tpr, where='post') 71 | 72 | # Add titles and labels 73 | plt.title("ROC Curve of number of Nodes recovered by ACDC") 74 | plt.xlabel("False Positive Rate") 75 | plt.ylabel("True Positive Rate") 76 | 77 | # Show the plot 78 | plt.show() 79 | 80 | # %% 81 | 82 | # Original code from https://plotly.com/python/line-and-scatter/ 83 | 84 | # I use plotly but it should be easy to adjust to matplotlib 85 | fig = go.Figure() 86 | fig.add_trace( 87 | go.Scatter( 88 | x=list(pareto_node_fpr), 89 |
y=list(pareto_node_tpr), 90 | mode="lines", 91 | line=dict(shape="hv"), 92 | showlegend=False, 93 | ), 94 | ) 95 | 96 | fig.update_layout( 97 | title="ROC Curve of number of Nodes recovered by ACDC", 98 | xaxis_title="False Positive Rate", 99 | yaxis_title="True Positive Rate", 100 | ) 101 | 102 | fig.show() -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "acdc" 3 | version = "0.0.0" # This should automatically be set by the CD pipeline on release 4 | description = "ACDC: Automatic Circuit DisCovery implementation on top of TransformerLens" 5 | authors = ["Arthur Conmy, Adrià Garriga-Alonso"] 6 | license = "MIT" 7 | readme = "README.md" 8 | packages = [{include = "acdc"}, {include = "subnetwork_probing"}] 9 | 10 | [tool.poetry.dependencies] 11 | python = "^3.8" 12 | einops = "^0.6.0" 13 | numpy = [{ version = "^1.21", python = "<3.10" }, 14 | { version = "^1.23", python = ">=3.10" }] 15 | torch = ">=1.10, <2.0" 16 | datasets = "^2.7.1" 17 | transformers = "^4.35.0" 18 | tokenizers = "^0.15.0" 19 | tqdm = "^4.64.1" 20 | pandas = "^1.1.5" 21 | wandb = "^0.13.5" 22 | torchtyping = "^0.1.4" 23 | huggingface-hub = "^0.16.0" 24 | cmapy = "^0.6.6" 25 | networkx = "^3.1" 26 | plotly = "^5.12.0" 27 | kaleido = "0.2.1" 28 | pygraphviz = "^1.11" 29 | tracr = {git = "https://github.com/deepmind/tracr.git", rev = "e75ecda"} 30 | transformer-lens = "1.6.1" 31 | 32 | [tool.poetry.group.dev.dependencies] 33 | pytest = "^7.2.0" 34 | pytest-cov = "^4.0.0" 35 | jupyterlab = "^3.5.0" 36 | jupyter = "^1.0.0" 37 | 38 | [build-system] 39 | requires = ["poetry-core"] 40 | build-backend = "poetry.core.masonry.api" 41 | 42 | [tool.pytest.ini_options] 43 | filterwarnings = [ 44 | # Ignore numpy.distutils deprecation warning caused by pandas 45 | # More info: https://numpy.org/doc/stable/reference/distutils.html#module-numpy.distutils 46 | "ignore:distutils Version classes are deprecated:DeprecationWarning" 47 | ] 48 | markers = [ 49 | "slow: marks tests as slow (deselect with '-m \"not slow\"')", 50 | ] 51 | 52 | [tool.black] 53 | line-length = 120 54 | 55 | [tool.isort] 56 | profile = "black" 57 | line_length = 120 58 | skip_gitignore = true 59 | -------------------------------------------------------------------------------- /subnetwork_probing/README.md: -------------------------------------------------------------------------------- 1 | # Setup 2 | 3 | This implementation of Subnetwork Probing should install by default when installing the ACDC code. 4 | 5 | It hosts a fork of `transformer_lens` as a submodule. This should probably be changed in the future. 6 | 7 | The fork introduces the class `MaskedHookPoint`, which masks some of its values with either zero or stored activations. 8 | That's the only crucial difference with mainstream `transformer_lens`. Most of the complexity is in `train.py`. 9 | 10 | # Subnetwork Probing 11 | 12 | [Low-Complexity Probing via Finding Subnetworks](https://github.com/stevenxcao/subnetwork-probing) 13 | Steven Cao, Victor Sanh, Alexander M. 
Rush 14 | NAACL-HLT 2021 15 | 16 | # HISP 17 | 18 | [Are Sixteen Heads Really Better than One?](https://arxiv.org/abs/1905.10650) Michel et al 2019 19 | -------------------------------------------------------------------------------- /subnetwork_probing/create_reset_networks.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from pprint import pprint 4 | import gc 5 | 6 | from acdc.greaterthan.utils import get_all_greaterthan_things 7 | from acdc.induction.utils import get_all_induction_things 8 | from acdc.ioi.utils import get_all_ioi_things 9 | from acdc.tracr_task.utils import get_all_tracr_things, get_tracr_model_input_and_tl_model 10 | from acdc.docstring.utils import get_all_docstring_things, AllDataThings 11 | 12 | torch.set_grad_enabled(False) 13 | 14 | DEVICE = "cuda" 15 | 16 | def scramble_sd(seed, things: AllDataThings, name, scramble_heads=True, scramble_mlps=True, scramble_head_outputs: bool=False): 17 | model = things.tl_model 18 | old_sd = model.state_dict() 19 | n_heads = model.cfg.n_heads 20 | n_neurons = model.cfg.d_mlp 21 | d_head = model.cfg.d_head 22 | torch.manual_seed(seed) 23 | 24 | if scramble_head_outputs and not scramble_heads: 25 | raise NotImplementedError 26 | 27 | sd = {} 28 | for k, v in old_sd.items(): 29 | if scramble_heads and ("attn" in k and not (k.endswith("_O") or k.endswith("mask") or k.endswith("IGNORE"))): 30 | assert v.shape[0] == n_heads 31 | to_sd = v[torch.randperm(n_heads), ...].contiguous() 32 | if scramble_head_outputs: 33 | to_sd = to_sd[..., torch.randperm(d_head)].contiguous() 34 | sd[k] = to_sd 35 | elif scramble_mlps and ("mlp" in k): 36 | if k.endswith("W_in"): 37 | assert v.shape[1] == n_neurons 38 | sd[k] = v[:, torch.randperm(n_neurons)].contiguous() 39 | elif k.endswith("b_in"): 40 | assert v.shape[0] == n_neurons 41 | sd[k] = v[torch.randperm(n_neurons)].contiguous() 42 | else: 43 | # print(f"Leaving {k} intact") 44 | sd[k] = v 45 | else: 46 | # print(f"Leaving {k} intact") 47 | sd[k] = v 48 | assert sd.keys() == old_sd.keys() 49 | for k in sd.keys(): 50 | assert sd[k].shape == old_sd[k].shape 51 | 52 | torch.save({k: v.cpu() for k, v in sd.items()}, name + ".pt") 53 | 54 | def all_metrics(logits): 55 | m = {} 56 | for k, v in things.test_metrics.items(): 57 | m[k] = v(logits).item() 58 | return m 59 | 60 | metrics = {} 61 | metrics["trained_orig"] = all_metrics(model(things.test_data)) 62 | metrics["trained_patched"] = all_metrics(model(things.test_patch_data)) 63 | 64 | model.load_state_dict(sd) 65 | metrics["reset_orig"] = all_metrics(model(things.test_data)) 66 | metrics["reset_patched"] = all_metrics(model(things.test_patch_data)) 67 | 68 | with open(name + "_test_metrics.json", "w") as f: 69 | json.dump(metrics, f) 70 | pprint(metrics) 71 | 72 | 73 | # %% IOI 74 | 75 | things = get_all_ioi_things(num_examples=100, device=DEVICE, metric_name="kl_div") 76 | scramble_sd(1504304416, things, "ioi_reset_heads_neurons") 77 | del things 78 | gc.collect() 79 | 80 | # %% Tracr-reverse 81 | things = get_all_tracr_things(task="reverse", metric_name="kl_div", num_examples=6, device=DEVICE) 82 | scramble_sd(1207775456, things, "tracr_reverse_reset_heads_head_outputs_neurons", scramble_head_outputs=True) 83 | gc.collect() 84 | 85 | scramble_sd(1666927681, things, "tracr_reverse_reset_heads_neurons", scramble_head_outputs=False) 86 | del things 87 | gc.collect() 88 | 89 | # %% Tracr-proportion 90 | 91 | things = get_all_tracr_things(task="proportion", 
metric_name="kl_div", num_examples=50, device=DEVICE) 92 | scramble_sd(2126292961, things, "tracr_proportion_reset_heads_head_outputs_neurons", scramble_head_outputs=True) 93 | gc.collect() 94 | 95 | scramble_sd(913070797, things, "tracr_proportion_reset_heads_neurons", scramble_head_outputs=False) 96 | del things 97 | gc.collect() 98 | 99 | # %% Induction 100 | 101 | things = get_all_induction_things(num_examples=50, seq_len=300, device=DEVICE, metric="kl_div") 102 | scramble_sd(2016630123, things, "induction_reset_heads_neurons") 103 | del things 104 | gc.collect() 105 | 106 | # %% Docstring 107 | 108 | things = get_all_docstring_things(num_examples=50, seq_len=2, device=DEVICE, metric_name="kl_div", correct_incorrect_wandb=False) 109 | scramble_sd(814220622, things, "docstring_reset_heads_neurons") 110 | del things 111 | gc.collect() 112 | 113 | 114 | # %% GreaterThan 115 | 116 | things = get_all_greaterthan_things(num_examples=100, device=DEVICE, metric_name="kl_div") 117 | scramble_sd(1028419464, things, "greaterthan_reset_heads_neurons") 118 | del things 119 | gc.collect() 120 | -------------------------------------------------------------------------------- /subnetwork_probing/launch_grid_fill.py: -------------------------------------------------------------------------------- 1 | from experiments.launcher import KubernetesJob, WandbIdentifier, launch 2 | import numpy as np 3 | import random 4 | from typing import List, Optional 5 | 6 | METRICS_FOR_TASK = { 7 | "ioi": ["kl_div", "logit_diff"], 8 | "tracr-reverse": ["l2"], 9 | "tracr-proportion": ["kl_div", "l2"], 10 | "induction": ["kl_div", "nll"], 11 | "docstring": ["kl_div", "docstring_metric"], 12 | "greaterthan": ["kl_div", "greaterthan"], 13 | } 14 | 15 | 16 | def main(TASKS: list[str], job: Optional[KubernetesJob], name: str, testing: bool, reset_networks: bool): 17 | NUM_SPACINGS = 5 if reset_networks else 21 18 | expensive_base_regularization_params = np.concatenate( 19 | [ 20 | 10 ** np.linspace(-2, 0, 11), 21 | np.linspace(1, 10, 10)[1:], 22 | np.linspace(10, 250, 13)[1:], 23 | ] 24 | ) 25 | 26 | if reset_networks: 27 | base_regularization_params = 10 ** np.linspace(-2, 1.5, NUM_SPACINGS) 28 | else: 29 | base_regularization_params = expensive_base_regularization_params 30 | 31 | wandb_identifier = WandbIdentifier( 32 | run_name=f"{name}-res{int(reset_networks)}-{{i:05d}}", 33 | group_name="tracr-shuffled-redo", 34 | project="induction-sp-replicate") 35 | 36 | 37 | seed = 1507014021 38 | random.seed(seed) 39 | 40 | commands: List[List[str]] = [] 41 | for reset_network in [int(reset_networks)]: 42 | for zero_ablation in [0, 1]: 43 | for task in TASKS: 44 | for metric in METRICS_FOR_TASK[task]: 45 | if task.startswith("tracr"): 46 | # Typical metric value range: 0.0-0.1 47 | regularization_params = 10 ** np.linspace(-3, 0, 11) 48 | 49 | if task == "tracr-reverse": 50 | num_examples = 6 51 | seq_len = 5 52 | elif task == "tracr-proportion": 53 | num_examples = 50 54 | seq_len = 5 55 | else: 56 | raise ValueError("Unknown task") 57 | 58 | elif task == "greaterthan": 59 | if metric == "kl_div": 60 | # Typical metric value range: 0.0-20 61 | regularization_params = base_regularization_params 62 | elif metric == "greaterthan": 63 | # Typical metric value range: -1.0 - 0.0 64 | regularization_params = 10 ** np.linspace(-4, 2, NUM_SPACINGS) 65 | else: 66 | raise ValueError("Unknown metric") 67 | num_examples = 100 68 | seq_len = -1 69 | elif task == "docstring": 70 | seq_len = 41 71 | if metric == "kl_div": 72 | # Typical metric 
value range: 0.0-10.0 73 | regularization_params = expensive_base_regularization_params 74 | elif metric == "docstring_metric": 75 | # Typical metric value range: -1.0 - 0.0 76 | regularization_params = 10 ** np.linspace(-4, 2, 21) 77 | else: 78 | raise ValueError("Unknown metric") 79 | num_examples = 50 80 | elif task == "ioi": 81 | num_examples = 100 82 | seq_len = -1 83 | if metric == "kl_div": 84 | # Typical metric value range: 0.0-12.0 85 | regularization_params = base_regularization_params 86 | elif metric == "logit_diff": 87 | # Typical metric value range: -0.31 -- -0.01 88 | regularization_params = 10 ** np.linspace(-4, 2, NUM_SPACINGS) 89 | else: 90 | raise ValueError("Unknown metric") 91 | elif task == "induction": 92 | seq_len = 300 93 | num_examples = 50 94 | if metric == "kl_div": 95 | # Typical metric value range: 0.0-16.0 96 | regularization_params = expensive_base_regularization_params 97 | elif metric == "nll": 98 | # Typical metric value range: 0.0-16.0 99 | regularization_params = expensive_base_regularization_params 100 | else: 101 | raise ValueError("Unknown metric") 102 | else: 103 | raise ValueError("Unknown task") 104 | 105 | if job is None: 106 | device = "cpu" 107 | n_cpu = 4 108 | assert testing 109 | else: 110 | device = "cuda" if job.gpu else "cpu" 111 | n_cpu = job.cpu 112 | 113 | for lambda_reg in [0.01] if testing else regularization_params: 114 | command = [ 115 | "python", 116 | "subnetwork_probing/train.py", 117 | f"--task={task}", 118 | f"--lambda-reg={lambda_reg:.3f}", 119 | f"--wandb-name=agarriga-sp-{len(commands):05d}{'-optional' if task in ['induction', 'docstring'] else ''}", 120 | "--wandb-project=induction-sp-replicate", 121 | "--wandb-entity=remix_school-of-rock", 122 | "--wandb-group=tracr-shuffled-redo", 123 | f"--device={device}", 124 | f"--epochs={1 if testing else 10000}", 125 | f"--zero-ablation={zero_ablation}", 126 | f"--reset-subject={reset_network}", 127 | f"--seed={random.randint(0, 2**32 - 1)}", 128 | f"--loss-type={metric}", 129 | f"--num-examples={6 if testing else num_examples}", 130 | f"--seq-len={seq_len}", 131 | f"--n-loss-average-runs={1 if testing else 20}", 132 | "--wandb-dir=/training", # If it doesn't exist wandb will use /tmp 133 | f"--wandb-mode={'offline' if testing else 'online'}", 134 | f"--torch-num-threads={n_cpu}", 135 | ] 136 | commands.append(command) 137 | 138 | launch( 139 | commands, 140 | name=name, 141 | job=job, 142 | synchronous=True, 143 | check_wandb=wandb_identifier, 144 | just_print_commands=False, 145 | ) 146 | 147 | if __name__ == "__main__": 148 | for reset_networks in [False, True]: 149 | for task in ["tracr-reverse"]: 150 | main( 151 | [task], 152 | KubernetesJob( 153 | container="ghcr.io/rhaps0dy/automatic-circuit-discovery:1.7.2", 154 | cpu=4, 155 | gpu=0 if task.startswith("tracr") else 1, 156 | mount_training=False, 157 | ), 158 | name=f"sp-{task}", 159 | testing=False, 160 | reset_networks=reset_networks, 161 | ) 162 | -------------------------------------------------------------------------------- /subnetwork_probing/transformer_lens/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/fastai/nbdev 3 | rev: 2.2.10 4 | hooks: 5 | - id: nbdev_clean -------------------------------------------------------------------------------- /subnetwork_probing/transformer_lens/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 
neelnanda-io 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /subnetwork_probing/transformer_lens/Old_Demo.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArthurConmy/Automatic-Circuit-Discovery/bc99ace817974b5584b7ee203d596a8e2bbcd399/subnetwork_probing/transformer_lens/Old_Demo.ipynb -------------------------------------------------------------------------------- /subnetwork_probing/transformer_lens/README.md: -------------------------------------------------------------------------------- 1 | # TransformerLens 2 | 3 | A modified version of TransformerLens. 4 | 5 | [![Pypi](https://img.shields.io/pypi/v/transformer-lens)](https://pypi.org/project/transformer-lens/) 6 | 7 | (Formerly known as EasyTransformer) 8 | 9 | ## [Start Here](https://neelnanda.io/transformer-lens-demo) 10 | 11 | ## A Library for Mechanistic Interpretability of Generative Language Models 12 | 13 | This is a library for doing [mechanistic interpretability](https://distill.pub/2020/circuits/zoom-in/) of GPT-2 Style language models. The goal of mechanistic interpretability is to take a trained model and reverse engineer the algorithms the model learned during training from its weights. It is a fact about the world today that we have computer programs that can essentially speak English at a human level (GPT-3, PaLM, etc), yet we have no idea how they work nor how to write one ourselves. This offends me greatly, and I would like to solve this! 14 | 15 | TransformerLens lets you load in an open source language model, like GPT-2, and exposes the internal activations of the model to you. You can cache any internal activation in the model, and add in functions to edit, remove or replace these activations as the model runs. The core design principle I've followed is to enable exploratory analysis. One of the most fun parts of mechanistic interpretability compared to normal ML is the extremely short feedback loops! The point of this library is to keep the gap between having an experiment idea and seeing the results as small as possible, to make it easy for **research to feel like play** and to enter a flow state. Part of what I aimed for is to make *my* experience of doing research easier and more fun, hopefully this transfers to you! 
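To make that workflow concrete, here is a minimal sketch of the cache-then-intervene loop described above (the hook names follow TransformerLens conventions, but treat the details as illustrative rather than canonical):

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")

# 1. Cache every internal activation during a forward pass
logits, cache = model.run_with_cache("The Eiffel Tower is in")
print(cache["blocks.0.attn.hook_z"].shape)  # [batch, pos, head_index, d_head]

# 2. Edit an activation as the model runs, via a temporary hook
def zero_head_five(z, hook):
    z[:, :, 5, :] = 0.0  # ablate head 5's output at this layer
    return z

patched_logits = model.run_with_hooks(
    "The Eiffel Tower is in",
    fwd_hooks=[("blocks.0.attn.hook_z", zero_head_five)],
)
```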
16 | 17 | I used to work for the [Anthropic interpretability team](https://transformer-circuits.pub), and I wrote this library because after I left and tried doing independent research, I got extremely frustrated by the state of open source tooling. There's a lot of excellent infrastructure like HuggingFace and DeepSpeed to *use* or *train* models, but very little to dig into their internals and reverse engineer how they work. **This library tries to solve that**, and to make it easy to get into the field even if you don't work at an industry org with real infrastructure! One of the great things about mechanistic interpretability is that you don't need large models or tons of compute. There are lots of important open problems that can be solved with a small model in a Colab notebook! 18 | 19 | The core features were heavily inspired by the interface to [Anthropic's excellent Garcon tool](https://transformer-circuits.pub/2021/garcon/index.html). Credit to Nelson Elhage and Chris Olah for building Garcon and showing me the value of good infrastructure for enabling exploratory research! 20 | 21 | ## Getting Started 22 | 23 | **Start with the [main demo](https://neelnanda.io/transformer-lens-demo) to learn how the library works, and the basic features**. 24 | 25 | To see what using it for exploratory analysis in practice looks like, check out [my notebook analysing Indirect Object Identification](https://neelnanda.io/exploratory-analysis-demo) or [my recording of myself doing research](https://www.youtube.com/watch?v=yo4QvDn-vsU)! 26 | 27 | Mechanistic interpretability is a very young and small field, and there are a *lot* of open problems - if you would like to help, please try working on one! **Check out my [list of concrete open problems](https://docs.google.com/document/d/1WONBzNqfKIxERejrrPlQMyKqg7jSFW92x5UMXNrMdPo/edit) to figure out where to start.** It begins with advice on skilling up, and key resources to check out. 28 | 29 | If you're new to transformers, check out my [what is a transformer tutorial](https://neelnanda.io/transformer-tutorial) and [tutorial on coding GPT-2 from scratch](https://neelnanda.io/transformer-tutorial-2) (with [an accompanying template](https://neelnanda.io/transformer-template) to write one yourself)! 30 | 31 | ## Gallery 32 | 33 | User-contributed examples of the library being used in action: 34 | * [Induction Heads Phase Change Replication](https://colab.research.google.com/github/ckkissane/induction-heads-transformer-lens/blob/main/Induction_Heads_Phase_Change.ipynb): A partial replication of [In-Context Learning and Induction Heads](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html) from Connor Kissane 35 | 36 | ## Advice for Reading the Code 37 | 38 | One significant design decision made was to have a single transformer implementation that could support a range of subtly different GPT-style models. This has the upside of interpretability code just working for arbitrary models when you change the model name in `HookedTransformer.from_pretrained`! But it has the significant downside that the code implementing the model (in `HookedTransformer.py` and `components.py`) can be difficult to read. I recommend starting with my [Clean Transformer Demo](https://neelnanda.io/transformer-solution), which is a clean, minimal implementation of GPT-2 with the same internal architecture and activation names as HookedTransformer, but is significantly clearer and better documented. 
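For instance, a sketch of the kind of model-agnostic loop this design enables (the model names are illustrative; any checkpoint the loader supports should work):

```python
from transformer_lens import HookedTransformer

# The same analysis code runs across architectures just by swapping the name.
for name in ["gpt2", "gpt2-medium", "pythia-160m"]:
    model = HookedTransformer.from_pretrained(name)
    loss = model("The quick brown fox jumps over the lazy dog", return_type="loss")
    print(f"{name}: {model.cfg.n_layers} layers, loss {loss.item():.2f}")
```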
39 | 40 | ## Installation 41 | 42 | `pip install git+https://github.com/neelnanda-io/TransformerLens` 43 | 44 | Import the library with `import transformer_lens` 45 | 46 | (Note: This library used to be known as EasyTransformer, and some breaking changes have been made since the rename. If you need to use the old version with some legacy code, run `pip install git+https://github.com/neelnanda-io/TransformerLens@v1`.) 47 | 48 | ## Local Development 49 | 50 | ### DevContainer 51 | 52 | For a one-click setup of your development environment, this project includes a [DevContainer](https://containers.dev/). It can be used locally with [VS Code](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) or with [GitHub Codespaces](https://github.com/features/codespaces). 53 | 54 | ### Manual Setup 55 | 56 | This project uses [Poetry](https://python-poetry.org/docs/#installation) for package management. Install as follows (this will also set up your virtual environment): 57 | 58 | ```bash 59 | poetry config virtualenvs.in-project true 60 | poetry install --with dev 61 | ``` 62 | 63 | Optionally, if you want Jupyter Lab you can run `poetry run pip install jupyterlab` (to install in the same virtual environment), and then run with `poetry run jupyter lab`. 64 | 65 | Then the library can be imported as `import transformer_lens`. 66 | 67 | ### Testing 68 | 69 | If adding a feature, please add unit tests for it to the tests folder, and check that it hasn't broken anything major using the existing tests (install pytest and run it in the root TransformerLens/ directory). 70 | 71 | ## Citation 72 | 73 | Please cite this library as: 74 | ``` 75 | @misc{nandatransformerlens2022, 76 | title = {TransformerLens}, 77 | author = {Nanda, Neel}, 78 | url = {https://github.com/neelnanda-io/TransformerLens}, 79 | year = {2022} 80 | } 81 | ``` 82 | (This is my best guess for how citing software works, feel free to send a correction!) 83 | Also, if you're actually using this for your research, I'd love to chat! Reach out at neelnanda27@gmail.com 84 | -------------------------------------------------------------------------------- /subnetwork_probing/transformer_lens/easy_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logging.warning("DEPRECATED: Library has been renamed, import transformer_lens instead") 4 | from transformer_lens import * 5 | -------------------------------------------------------------------------------- /subnetwork_probing/transformer_lens/further_comments.md: -------------------------------------------------------------------------------- 1 | # Further Details on Config Options 2 | ## Shortformer Attention (`positional_embeddings_type == "shortformer"`) 3 | Shortformer style models use a variant of GPT-2 style positional embeddings, which are not added into the residual stream but instead added to the queries and keys immediately before multiplying by W_Q and W_K, and NOT kept around for the values or MLPs. It's otherwise the same - the positional embeddings are absolute, and are learned. The positional embeddings are NOT added to the residual stream in the standard way, and instead the queries and keys are calculated as W_Q(res_stream + pos_embed) and W_K(res_stream + pos_embed). The values and MLPs are calculated as W_V(res_stream) and W_MLP(res_stream) and so don't have access to positional information. 
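As a toy illustration of this scheme (our own sketch, not library code), position information enters only the query/key path:

```python
import torch

d_model, n_ctx = 64, 16
resid = torch.randn(n_ctx, d_model)      # residual stream: no positions added
pos_embed = torch.randn(n_ctx, d_model)  # learned absolute positional embeddings
W_Q, W_K, W_V = (torch.randn(d_model, d_model) for _ in range(3))

q = (resid + pos_embed) @ W_Q  # queries see positions...
k = (resid + pos_embed) @ W_K  # ...and so do keys
v = resid @ W_V                # values (like the MLPs) never see positions

pattern = torch.softmax(q @ k.T / d_model**0.5, dim=-1)
z = pattern @ v  # mixed values carry no positional embedding component
```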
This scheme is a variant on the Shortformer model from the paper [Shortformer: The Benefits of Shorter Sequences in Language Modeling](https://arxiv.org/abs/2012.15832). It's morally similar to rotary, which also only gives keys & queries access to positional info. 4 | 5 | The original intention was to use this to do more efficient caching: caching is hard with absolute positional embeddings, since you can't translate the context window without recomputing the entire thing, but easier if the prior values and residual stream terms are the same. I've mostly implemented it because it makes it easier for models to form induction heads. I'm not entirely sure why, though I hypothesise that it's because there are two ways for induction heads to form with positional embeddings in the residual stream and only one with shortformer style positional embeddings. 6 | 7 | # Weight Processing 8 | ## What is LayerNorm Folding? (`fold_ln`) 9 | [LayerNorm](https://wandb.ai/wandb_fc/LayerNorm/reports/Layer-Normalization-in-Pytorch-With-Examples---VmlldzoxMjk5MTk1) is a common regularisation technique used in transformers. Annoyingly, unlike eg BatchNorm, it can't be turned off at inference time: it's a meaningful change to the mathematical function implemented by the transformer. From an interpretability perspective, this is a headache! And it's easy to shoot yourself in the foot by naively ignoring it - eg, making the mistake of saying neuron_pre = resid_mid @ W_in, rather than LayerNorm(resid_mid) @ W_in. This mistake is an OK approximation, but by folding in the LayerNorm we can do much better! 10 | 11 | TLDR: If we have LayerNorm (weights w_ln and b_ln) followed by a linear layer (W+b), we can reduce the LayerNorm to LayerNormPre (just centering & normalising) and follow it by a linear layer with `W_eff = w[:, None] * W` (element-wise multiplication) and `b_eff = b + b_ln @ W`. This is computationally equivalent, and it never makes sense to think of W and w_ln as separate objects, so HookedTransformer handles it for you when loading pre-trained weights - set fold_ln = False when loading a state dict if you want to turn this off. 12 | 13 | Mathematically, LayerNorm is the following: 14 | ``` 15 | x1 = x0 - x0.mean() 16 | x2 = x1 / ((x1**2).mean()).sqrt() 17 | x3 = x2 * w 18 | x4 = x3 + b 19 | ``` 20 | 21 | Apart from dividing by the norm, these are all pretty straightforward operations from a linear algebra perspective. And from an interpretability perspective, if anything is linear, it's really easy and you can mostly ignore it (everything breaks up into sums, you can freely change basis, don't need to track interference between terms, etc) - the hard part is engaging with non-linearities! 22 | 23 | A key thing to bear in mind is that EVERY time we read from the residual stream, we apply a LayerNorm - this gives us a lot of leverage to reason about it! 24 | 25 | So let's translate this into linear algebra notation. 26 | `x0` is a vector in `R^n` 27 | 28 | ``` 29 | x1 = x0 - x0.mean() 30 | = x0 - (x0.mean()) * ones (broadcasting, ones=torch.ones(n)) 31 | = x0 - (x0 @ ones/sqrt(n)) * ones/sqrt(n). 32 | ``` 33 | 34 | ones has norm sqrt(n), so ones/sqrt(n) is the unit vector in the diagonal direction. We're just projecting x0 onto this (fixed) vector and subtracting that value off. Alternately, we're projecting onto the n-1 dimensional subspace orthogonal to ones. 
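A quick numerical check of this projection claim (illustrative, not from the original doc):

```python
import torch

n = 8
x0 = torch.randn(n)
ones_unit = torch.ones(n) / n**0.5           # unit vector in the ones direction

x1_mean = x0 - x0.mean()                     # LayerNorm's centering step
x1_proj = x0 - (x0 @ ones_unit) * ones_unit  # projection off the ones direction

assert torch.allclose(x1_mean, x1_proj, atol=1e-6)
```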
35 | 36 | Since LayerNorm is applied EVERY time we read from the stream, the model just never uses the ones direction of the residual stream, so it's essentially just decreasing d_model by one. We can simulate this by just centering all matrices writing to the residual stream. 37 | 38 | Why is removing this dimension useful? I have no idea! I'm not convinced it is... 39 | 40 | ``` 41 | x2 = x1 / ((x1**2).mean()).sqrt() (Ignoring eps) 42 | = (x1 / x1.norm()) * sqrt(n) 43 | ``` 44 | 45 | This is a projection onto the unit sphere (well, sphere of radius sqrt(n) - the norm of ones). This is fundamentally non-linear, eg doubling the input keeps the output exactly the same. 46 | 47 | This is by far the most irritating part of LayerNorm. I THINK it's mostly useful for numerical stability reasons and not used to do useful computation by the model, but I could easily be wrong! And interpreting a circuit containing LayerNorm sounds like a nightmare... 48 | 49 | In practice, you can mostly get away with ignoring this and treating the scaling factor as a constant, since it does apply across the entire residual stream for each token - this makes it a "global" property of the model's calculation, so for any specific question it hopefully doesn't matter that much. But when you're considering a sufficiently important circuit that it's a good fraction of the norm of the residual stream, it's probably worth thinking about. 50 | 51 | ``` 52 | x3 = x2 * w 53 | = x2 @ W_ln 54 | ``` 55 | 56 | (`W_ln` is a diagonal matrix with the weights of the LayerNorm - this is equivalent to element-wise multiplication) 57 | This is really easy to deal with - we're about to be input to a linear layer, and can say `(x2 @ W_ln) @ W = x2 @ (W_ln @ W) = x2 @ W_eff` - we can just fold the LayerNorm weights into the linear layer weights. 58 | 59 | `x4 = x3 + b` is similarly easy - `x4 @ W + B = x2 @ W_eff + B_eff`, where `W_eff = W_ln @ W` and `B_eff = B + b @ W` 60 | 61 | This function is calculating `W_eff` and `B_eff` for each layer reading from the residual stream and replacing W and B with those. 62 | 63 | A final optimisation we can make is to **center the reading weights**. x2 has mean 0, which means it's orthogonal to the vector of all ones (`x2 @ ones = x2.sum() = len(x2) * x2.mean()`). This means that the component of `W_eff` that's parallel to `ones` is irrelevant, and we can set that to zero. In code, this means `W_eff -= W_eff.mean(dim=0, keepdim=True)`. This doesn't change the computation but makes things a bit simpler. 64 | 65 | See this for more: https://transformer-circuits.pub/2021/framework/index.html#:~:text=Handling%20Layer%20Normalization 66 | 67 | ## Centering Writing Weights (`center_writing_weight`) 68 | 69 | A related idea to folding layernorm - *every* component reading an input from the residual stream is preceded by a LayerNorm, which means that the mean of a residual stream vector (ie the component in the direction of all ones) never matters. This means we can remove the all ones component of weights and biases whose output *writes* to the residual stream. Mathematically, `W_writing -= W_writing.mean(dim=1, keepdim=True)` 70 | 71 | ## Centering Unembed (`center_unembed`) 72 | 73 | The logits are fed into a softmax. Softmax is translation invariant (eg, adding 1 to every logit doesn't change the output), so we can simplify things by setting the mean of the logits to be zero. This is equivalent to setting the mean of every output vector of `W_U` to zero. 
In code, `W_U -= W_U.mean(dim=-1, keepdim=True)` 74 | 75 | ## Fold Value Biases (`fold_value_biases`) 76 | 77 | Each attention head has a value bias. Values are averaged to create mixed values (`z`), weighted by the attention pattern, but as the bias is constant, its contribution to `z` is the same at every destination position. The output of a head is `z @ W_O`, and so the value bias just linearly adds to the output of the head. This means that the value bias of a head has *nothing to do with the head*, and is just a constant added to the attention layer outputs. We can take the sum across these and `b_O` to get an "effective bias" for the layer. In code, we set `b_O = (b_V @ W_O).sum(dim=0) + b_O` and then `b_V = 0.` 78 | 79 |
Technical derivation 80 | 81 | `v = residual @ W_V[h] + broadcast_b_V[h]` for each head `h` (where `b_V` is broadcast up from shape `d_head` to shape `[position, d_head]`). And `z = pattern[h] @ v = pattern[h] @ residual @ W_V[h] + pattern[h] @ broadcast_b_V[h]`. Because `pattern[h]` is `[destination_position, source_position]` and `broadcast_b_V` is *constant* along the `(source_)position` dimension, we're basically just multiplying it by the sum of the pattern across the `source_position` dimension, which is just 1. So it remains exactly the same, and so is just broadcast across the destination positions. 82 |
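A small numerical sanity check of this derivation (our own sketch, not library code): because each attention pattern row sums to 1 over source positions, the value bias can be absorbed into an effective output bias.

```python
import torch

n_heads, d_head, d_model, pos = 4, 8, 32, 5
pattern = torch.softmax(torch.randn(n_heads, pos, pos), dim=-1)  # rows sum to 1
v = torch.randn(n_heads, pos, d_head)
b_V = torch.randn(n_heads, d_head)
W_O = torch.randn(n_heads, d_head, d_model)
b_O = torch.randn(d_model)

# With per-head value biases added before mixing
z = pattern @ (v + b_V[:, None, :])                    # [head, pos, d_head]
out = torch.einsum("hpd,hdm->pm", z, W_O) + b_O

# Folded: b_V zeroed, its contribution moved into an effective output bias
z_folded = pattern @ v
b_O_eff = b_O + torch.einsum("hd,hdm->m", b_V, W_O)
out_folded = torch.einsum("hpd,hdm->pm", z_folded, W_O) + b_O_eff

assert torch.allclose(out, out_folded, atol=1e-4)
```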
83 | -------------------------------------------------------------------------------- /subnetwork_probing/transformer_lens/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "transformer-lens" 3 | version = "0.0.0" # This is automatically set by the CD pipeline on release 4 | description = "An implementation of transformers tailored for mechanistic interpretability." 5 | authors = ["Neel Nanda <77788841+neelnanda-io@users.noreply.github.com>"] 6 | license = "MIT" 7 | readme = "README.md" 8 | packages = [{include = "transformer_lens"}] 9 | 10 | [tool.poetry.dependencies] 11 | python = "^3.7" 12 | einops = "^0.6.0" 13 | numpy = [{ version = "^1.21", python = "<3.10" }, 14 | { version = "^1.23", python = ">=3.10" }] 15 | torch = "^1.10" 16 | datasets = "^2.7.1" 17 | transformers = "^4.25.1" 18 | tqdm = "^4.64.1" 19 | pandas = "^1.1.5" 20 | wandb = "^0.13.5" 21 | fancy-einsum = "^0.0.3" 22 | torchtyping = "^0.1.4" 23 | rich = "^12.6.0" 24 | 25 | [tool.poetry.group.dev.dependencies] 26 | pytest = "^7.2.0" 27 | mypy = "^0.991" 28 | jupyter = "^1.0.0" 29 | circuitsvis = "^1.38.1" 30 | plotly = "^5.12.0" 31 | 32 | [tool.poetry.group.jupyter.dependencies] 33 | jupyterlab = "^3.5.0" 34 | 35 | [build-system] 36 | requires = ["poetry-core"] 37 | build-backend = "poetry.core.masonry.api" 38 | -------------------------------------------------------------------------------- /subnetwork_probing/transformer_lens/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="transformer_lens", 5 | version="0.1.0", 6 | packages=["transformer_lens"], 7 | license="LICENSE", 8 | description="An implementation of transformers tailored for mechanistic interpretability.", 9 | long_description=open("README.md").read(), 10 | install_requires=[ 11 | "einops", 12 | "numpy", 13 | "torch", 14 | "datasets", 15 | "transformers", 16 | "tqdm", 17 | "pandas", 18 | "datasets", 19 | "wandb", 20 | "fancy_einsum", 21 | "torchtyping", 22 | "rich", 23 | ], 24 | extras_require={"dev": ["pytest", "mypy"]}, 25 | ) 26 | -------------------------------------------------------------------------------- /subnetwork_probing/transformer_lens/transformer_lens/FactoredMatrix.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch 3 | from typing import Optional, Union, Tuple, List, Dict 4 | from torchtyping import TensorType as TT 5 | from functools import lru_cache 6 | from . import utils as utils 7 | from .torchtyping_helper import ( 8 | T as TH, 9 | ) # Need the import as since FactoredMatrix declares T (for transpose) 10 | 11 | 12 | class FactoredMatrix: 13 | """ 14 | Class to represent low rank factored matrices, where the matrix is represented as a product of two matrices. Has utilities for efficient calculation of eigenvalues, norm and SVD. 
15 | """ 16 | 17 | def __init__(self, A: TT[..., TH.ldim, TH.mdim], B: TT[..., TH.mdim, TH.rdim]): 18 | self.A = A 19 | self.B = B 20 | assert self.A.size(-1) == self.B.size( 21 | -2 22 | ), f"Factored matrix must match on inner dimension, shapes were a: {self.A.shape}, b:{self.B.shape}" 23 | self.ldim = self.A.size(-2) 24 | self.rdim = self.B.size(-1) 25 | self.mdim = self.B.size(-2) 26 | self.has_leading_dims = (self.A.ndim > 2) or (self.B.ndim > 2) 27 | self.shape = torch.broadcast_shapes(self.A.shape[:-2], self.B.shape[:-2]) + ( 28 | self.ldim, 29 | self.rdim, 30 | ) 31 | self.A = self.A.broadcast_to(self.shape[:-2] + (self.ldim, self.mdim)) 32 | self.B = self.B.broadcast_to(self.shape[:-2] + (self.mdim, self.rdim)) 33 | 34 | def __matmul__( 35 | self, other: Union[TT[..., TH.rdim, TH.new_rdim], TT[TH.rdim], FactoredMatrix] 36 | ) -> Union[FactoredMatrix, TT[..., TH.ldim]]: 37 | if isinstance(other, torch.Tensor): 38 | if other.ndim < 2: 39 | # It's a vector, so we collapse the factorisation and just return a vector 40 | # Squeezing/Unsqueezing is to preserve broadcasting working nicely 41 | return (self.A @ (self.B @ other.unsqueeze(-1))).squeeze(-1) 42 | else: 43 | assert ( 44 | other.size(-2) == self.rdim 45 | ), f"Right matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}" 46 | if self.rdim > self.mdim: 47 | return FactoredMatrix(self.A, self.B @ other) 48 | else: 49 | return FactoredMatrix(self.AB, other) 50 | elif isinstance(other, FactoredMatrix): 51 | return (self @ other.A) @ other.B 52 | 53 | def __rmatmul__( 54 | self, other: Union[TT[..., TH.new_rdim, TH.ldim], TT[TH.ldim], FactoredMatrix] 55 | ) -> Union[FactoredMatrix, TT[..., TH.rdim]]: 56 | if isinstance(other, torch.Tensor): 57 | assert ( 58 | other.size(-1) == self.ldim 59 | ), f"Left matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}" 60 | if other.ndim < 2: 61 | # It's a vector, so we collapse the factorisation and just return a vector 62 | return ((other.unsqueeze(-2) @ self.A) @ self.B).squeeze(-1) 63 | elif self.ldim > self.mdim: 64 | return FactoredMatrix(other @ self.A, self.B) 65 | else: 66 | return FactoredMatrix(other, self.AB) 67 | elif isinstance(other, FactoredMatrix): 68 | return other.A @ (other.B @ self) 69 | 70 | @property 71 | def AB(self) -> TT[TH.leading_dims : ..., TH.ldim, TH.rdim]: 72 | """The product matrix - expensive to compute, and can consume a lot of GPU memory""" 73 | return self.A @ self.B 74 | 75 | @property 76 | def BA(self) -> TT[TH.leading_dims : ..., TH.rdim, TH.ldim]: 77 | """The reverse product. 
Only makes sense when ldim==rdim""" 78 | assert ( 79 | self.rdim == self.ldim 80 | ), f"Can only take ba if ldim==rdim, shapes were self: {self.shape}" 81 | return self.B @ self.A 82 | 83 | @property 84 | def T(self) -> FactoredMatrix: 85 | return FactoredMatrix(self.B.transpose(-2, -1), self.A.transpose(-2, -1)) 86 | 87 | @lru_cache(maxsize=None) 88 | def svd( 89 | self, 90 | ) -> Tuple[ 91 | TT[TH.leading_dims : ..., TH.ldim, TH.mdim], 92 | TT[TH.leading_dims : ..., TH.mdim], 93 | TT[TH.leading_dims : ..., TH.rdim, TH.mdim], 94 | ]: 95 | """ 96 | Efficient algorithm for finding Singular Value Decomposition, a tuple (U, S, Vh) for matrix M st S is a vector and U, Vh are orthogonal matrices, and U @ S.diag() @ Vh.T == M 97 | 98 | (Note that Vh is given as the transpose of the obvious thing) 99 | """ 100 | Ua, Sa, Vha = torch.svd(self.A) 101 | Ub, Sb, Vhb = torch.svd(self.B) 102 | middle = Sa[..., :, None] * utils.transpose(Vha) @ Ub * Sb[..., None, :] 103 | Um, Sm, Vhm = torch.svd(middle) 104 | U = Ua @ Um 105 | Vh = Vhb @ Vhm 106 | S = Sm 107 | return U, S, Vh 108 | 109 | @property 110 | def U(self) -> TT[TH.leading_dims : ..., TH.ldim, TH.mdim]: 111 | return self.svd()[0] 112 | 113 | @property 114 | def S(self) -> TT[TH.leading_dims : ..., TH.mdim]: 115 | return self.svd()[1] 116 | 117 | @property 118 | def Vh(self) -> TT[TH.leading_dims : ..., TH.rdim, TH.mdim]: 119 | return self.svd()[2] 120 | 121 | @property 122 | def eigenvalues(self) -> TT[TH.leading_dims : ..., TH.mdim]: 123 | """Eigenvalues of AB are the same as for BA (apart from trailing zeros), because if BAv=kv ABAv = A(BAv)=kAv, so Av is an eigenvector of AB with eigenvalue k.""" 124 | return torch.linalg.eig(self.BA).eigenvalues 125 | 126 | def __getitem__(self, idx: Union[int, Tuple]) -> FactoredMatrix: 127 | """Indexing - assumed to only apply to the leading dimensions.""" 128 | if not isinstance(idx, tuple): 129 | idx = (idx,) 130 | length = len([i for i in idx if i is not None]) 131 | if length <= len(self.shape) - 2: 132 | return FactoredMatrix(self.A[idx], self.B[idx]) 133 | elif length == len(self.shape) - 1: 134 | return FactoredMatrix(self.A[idx], self.B[idx[:-1]]) 135 | elif length == len(self.shape): 136 | return FactoredMatrix( 137 | self.A[idx[:-1]], self.B[idx[:-2] + (slice(None), idx[-1])] 138 | ) 139 | else: 140 | raise ValueError( 141 | f"{idx} is too long an index for a FactoredMatrix with shape {self.shape}" 142 | ) 143 | 144 | def norm(self) -> TT[TH.leading_dims : ...]: 145 | """ 146 | Frobenius norm is sqrt(sum of squared singular values) 147 | """ 148 | return self.S.pow(2).sum(-1).sqrt() 149 | 150 | def __repr__(self): 151 | return f"FactoredMatrix: Shape({self.shape}), Hidden Dim({self.mdim})" 152 | 153 | def make_even(self) -> FactoredMatrix: 154 | """ 155 | Returns the factored form of (U @ S.sqrt().diag(), S.sqrt().diag() @ Vh) where U, S, Vh are the SVD of the matrix. 
This is an equivalent factorisation, but more even - each half has half the singular values, and orthogonal rows/cols 156 | """ 157 | return FactoredMatrix( 158 | self.U * self.S.sqrt()[..., None, :], 159 | self.S.sqrt()[..., :, None] * utils.transpose(self.Vh), 160 | ) 161 | 162 | def get_corner(self, k=3): 163 | return utils.get_corner(self.A[..., :k, :] @ self.B[..., :, :k], k) 164 | 165 | @property 166 | def ndim(self) -> int: 167 | return len(self.shape) 168 | 169 | def collapse_l(self) -> TT[TH.leading_dims : ..., TH.mdim, TH.rdim]: 170 | """ 171 | Collapses the left side of the factorization by removing the orthogonal factor (given by self.U). Returns a (..., mdim, rdim) tensor 172 | """ 173 | return self.S[..., :, None] * utils.transpose(self.Vh) 174 | 175 | def collapse_r(self) -> TT[TH.leading_dims : ..., TH.ldim, TH.mdim]: 176 | """ 177 | Analogous to collapse_l, returns a (..., ldim, mdim) tensor 178 | """ 179 | return self.U * self.S[..., None, :] 180 | 181 | def unsqueeze(self, k: int) -> FactoredMatrix: 182 | return FactoredMatrix(self.A.unsqueeze(k), self.B.unsqueeze(k)) 183 | 184 | @property 185 | def pair( 186 | self, 187 | ) -> Tuple[ 188 | TT[TH.leading_dims : ..., TH.ldim, TH.mdim], 189 | TT[TH.leading_dims : ..., TH.mdim, TH.rdim], 190 | ]: 191 | return (self.A, self.B) 192 | -------------------------------------------------------------------------------- /subnetwork_probing/transformer_lens/transformer_lens/__init__.py: -------------------------------------------------------------------------------- 1 | from . import hook_points 2 | from . import utils 3 | from . import evals 4 | from .past_key_value_caching import ( 5 | HookedTransformerKeyValueCache, 6 | HookedTransformerKeyValueCacheEntry, 7 | ) 8 | from . import components 9 | from .HookedTransformerConfig import HookedTransformerConfig 10 | from .FactoredMatrix import FactoredMatrix 11 | from .ActivationCache import ActivationCache 12 | from .HookedTransformer import HookedTransformer 13 | from . import loading_from_pretrained as loading 14 | from . import patching 15 | from . import train 16 | 17 | from .past_key_value_caching import ( 18 | HookedTransformerKeyValueCache as EasyTransformerKeyValueCache, 19 | HookedTransformerKeyValueCacheEntry as EasyTransformerKeyValueCacheEntry, 20 | ) 21 | from .HookedTransformer import HookedTransformer as EasyTransformer 22 | from .HookedTransformerConfig import HookedTransformerConfig as EasyTransformerConfig 23 | -------------------------------------------------------------------------------- /subnetwork_probing/transformer_lens/transformer_lens/evals.py: -------------------------------------------------------------------------------- 1 | # %% 2 | """ 3 | A file with some rough evals for models - I expect you to be likely better off using the HuggingFace evaluate library if you want to do anything properly, but this is here if you want it and want to eg cheaply and roughly compare models you've trained to baselines. 4 | """ 5 | 6 | import torch 7 | import tqdm.auto as tqdm 8 | from datasets import load_dataset 9 | from . import HookedTransformer, HookedTransformerConfig, utils 10 | from torch.utils.data import DataLoader 11 | import einops 12 | 13 | # %% 14 | def sanity_check(model): 15 | """ 16 | Very basic eval - just feeds a string into the model (in this case, the first paragraph of Circuits: Zoom In), and returns the loss. It's a rough and quick sanity check - if the loss is <5 the model is probably OK, if the loss is >7 something's gone wrong. 
17 | 18 | Note that this is a very basic eval, and doesn't really tell you much about the model's performance. 19 | """ 20 | 21 | text = "Many important transition points in the history of science have been moments when science 'zoomed in.' At these points, we develop a visualization or tool that allows us to see the world in a new level of detail, and a new field of science develops to study the world through this lens." 22 | 23 | return model(text, return_type="loss") 24 | 25 | 26 | # %% 27 | def make_wiki_data_loader(tokenizer, batch_size=8): 28 | """ 29 | Evaluate on Wikitext 2, a dump of Wikipedia articles. (Using the train set because it's larger, I don't really expect anyone to bother with quarantining the validation set nowadays.) 30 | 31 | Note there's likely to be dataset leakage into training data (though I believe GPT-2 was explicitly trained on non-Wikipedia data) 32 | """ 33 | wiki_data = load_dataset("wikitext", "wikitext-2-v1", split="train") 34 | print(len(wiki_data)) 35 | dataset = utils.tokenize_and_concatenate(wiki_data, tokenizer) 36 | data_loader = DataLoader( 37 | dataset, batch_size=batch_size, shuffle=True, drop_last=True 38 | ) 39 | return data_loader 40 | 41 | 42 | def make_owt_data_loader(tokenizer, batch_size=8): 43 | """ 44 | Evaluate on OpenWebText, an open source replication of the GPT-2 training corpus (Reddit links with >3 karma) 45 | 46 | I think the Mistral models were trained on this dataset, so they get very good performance. 47 | """ 48 | owt_data = load_dataset("stas/openwebtext-10k", split="train") 49 | print(len(owt_data)) 50 | dataset = utils.tokenize_and_concatenate(owt_data, tokenizer) 51 | data_loader = DataLoader( 52 | dataset, batch_size=batch_size, shuffle=True, drop_last=True 53 | ) 54 | return data_loader 55 | 56 | 57 | def make_pile_data_loader(tokenizer, batch_size=8): 58 | """ 59 | Evaluate on the Pile, a large and diverse corpus of text from many sources. 60 | 61 | The EleutherAI models (eg GPT-Neo and Pythia) were trained on this dataset. 62 | """ 63 | pile_data = load_dataset("NeelNanda/pile-10k", split="train") 64 | print(len(pile_data)) 65 | dataset = utils.tokenize_and_concatenate(pile_data, tokenizer) 66 | data_loader = DataLoader( 67 | dataset, batch_size=batch_size, shuffle=True, drop_last=True 68 | ) 69 | return data_loader 70 | 71 | 72 | def make_code_data_loader(tokenizer, batch_size=8): 73 | """ 74 | Evaluate on the CodeParrot dataset, a dump of Python code. All models seem to get significantly lower loss here (even non-code trained models like GPT-2), presumably because code is much easier to predict than natural language? 
75 | """ 76 | code_data = load_dataset("codeparrot/codeparrot-valid-v2-near-dedup", split="train") 77 | print(len(code_data)) 78 | dataset = utils.tokenize_and_concatenate( 79 | code_data, tokenizer, column_name="content" 80 | ) 81 | data_loader = DataLoader( 82 | dataset, batch_size=batch_size, shuffle=True, drop_last=True 83 | ) 84 | return data_loader 85 | 86 | 87 | DATASET_NAMES = ["wiki", "owt", "pile", "code"] 88 | DATASET_LOADERS = [ 89 | make_wiki_data_loader, 90 | make_owt_data_loader, 91 | make_pile_data_loader, 92 | make_code_data_loader, 93 | ] 94 | 95 | # %% 96 | @torch.inference_mode() 97 | def evaluate_on_dataset(model, data_loader, truncate=100): 98 | running_loss = 0 99 | total = 0 100 | for batch in tqdm.tqdm(data_loader): 101 | loss = model(batch["tokens"].cuda(), return_type="loss").mean() 102 | running_loss += loss.item() 103 | total += 1 104 | if total > truncate: 105 | break 106 | return running_loss / total 107 | 108 | 109 | # %% 110 | @torch.inference_mode() 111 | def induction_loss( 112 | model, tokenizer=None, batch_size=4, subseq_len=384, prepend_bos=True 113 | ): 114 | """ 115 | Generates a batch of random sequences repeated twice, and measures model performance on the second half. Tests whether a model has induction heads. 116 | 117 | By default, prepends a beginning of string token (prepend_bos flag), which is useful to give models a resting position, and sometimes models were trained with this. 118 | """ 119 | # Make the repeated sequence 120 | first_half_tokens = torch.randint(100, 20000, (batch_size, subseq_len)).cuda() 121 | repeated_tokens = einops.repeat(first_half_tokens, "b p -> b (2 p)") 122 | 123 | # Prepend a Beginning Of String token 124 | if prepend_bos: 125 | if tokenizer is None: 126 | tokenizer = model.tokenizer 127 | repeated_tokens[:, 0] = tokenizer.bos_token_id 128 | # Run the model, and extract the per token correct log prob 129 | logits = model(repeated_tokens, return_type="logits") 130 | correct_log_probs = utils.lm_cross_entropy_loss( 131 | logits, repeated_tokens, per_token=True 132 | ) 133 | # Take the loss over the second half of the sequence 134 | return correct_log_probs[:, subseq_len + 1 :].mean() 135 | 136 | 137 | # %% 138 | @torch.inference_mode() 139 | def evaluate(model, truncate=100, batch_size=8, tokenizer=None): 140 | if tokenizer is None: 141 | tokenizer = model.tokenizer 142 | losses = {} 143 | for data_name, data_loader_fn in zip(DATASET_NAMES, DATASET_LOADERS): 144 | data_loader = data_loader_fn(tokenizer=tokenizer, batch_size=batch_size) 145 | loss = evaluate_on_dataset(model, data_loader, truncate=truncate) 146 | print(f"{data_name}: {loss}") 147 | losses[f"{data_name}_loss"] = loss 148 | return losses 149 | 150 | 151 | # %% 152 | -------------------------------------------------------------------------------- /subnetwork_probing/transformer_lens/transformer_lens/make_docs.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from easy_transformer import loading 3 | from easy_transformer import utils 4 | from functools import lru_cache 5 | 6 | # %% 7 | cfg = loading.get_pretrained_model_config("solu-1l") 8 | print(cfg) 9 | # %% 10 | """ 11 | Structure: 12 | d_model, d_mlp, d_head, d_vocab, act_fn, n_heads, n_layers, n_ctx, n_params, 13 | Make an architecture table separately probs 14 | tokenizer_name, training_data, has checkpoints 15 | act_fn includes attn_only 16 | architecture 17 | Architecture should list weird shit to be aware of. 
18 | """ 19 | import pandas as pd 20 | import numpy as np 21 | 22 | df = pd.DataFrame(np.random.randn(2, 2)) 23 | print(df.to_markdown(open("test.md", "w"))) 24 | # %% 25 | @lru_cache(maxsize=None) 26 | def get_config(model_name): 27 | return loading.get_pretrained_model_config(model_name) 28 | 29 | 30 | def get_property(name, model_name): 31 | cfg = get_config(model_name) 32 | if name == "act_fn": 33 | if cfg.attn_only: 34 | return "attn_only" 35 | elif cfg.act_fn == "gelu_new": 36 | return "gelu" 37 | elif cfg.act_fn == "gelu_fast": 38 | return "gelu" 39 | elif cfg.act_fn == "solu_ln": 40 | return "solu" 41 | else: 42 | return cfg.act_fn 43 | if name == "n_params": 44 | n_params = cfg.n_params 45 | if n_params < 1e4: 46 | return f"{n_params/1e3:.1f}K" 47 | elif n_params < 1e6: 48 | return f"{round(n_params/1e3)}K" 49 | elif n_params < 1e7: 50 | return f"{n_params/1e6:.1f}M" 51 | elif n_params < 1e9: 52 | return f"{round(n_params/1e6)}M" 53 | elif n_params < 1e10: 54 | return f"{n_params/1e9:.1f}B" 55 | elif n_params < 1e12: 56 | return f"{round(n_params/1e9)}B" 57 | else: 58 | raise ValueError(f"Passed in {n_params} above 1T?") 59 | else: 60 | return cfg.to_dict()[name] 61 | 62 | 63 | column_names = ( 64 | "n_params, n_layers, d_model, n_heads, act_fn, n_ctx, d_vocab, d_head, d_mlp".split( 65 | ", " 66 | ) 67 | ) 68 | print(column_names) 69 | df = pd.DataFrame( 70 | { 71 | name: [ 72 | get_property(name, model_name) 73 | for model_name in loading.DEFAULT_MODEL_ALIASES 74 | ] 75 | for name in column_names 76 | }, 77 | index=loading.DEFAULT_MODEL_ALIASES, 78 | ) 79 | display(df) 80 | df.to_markdown(open("model_properties_table.md", "w")) 81 | # %% 82 | -------------------------------------------------------------------------------- /subnetwork_probing/transformer_lens/transformer_lens/model_properties_table.md: -------------------------------------------------------------------------------- 1 | | | d_model | d_mlp | d_head | d_vocab | act_fn | n_heads | n_layers | n_ctx | n_params | 2 | |:-------------------------------|----------:|--------:|---------:|----------:|:----------|----------:|-----------:|--------:|:-----------| 3 | | gpt2-small | 768 | 3072 | 64 | 50257 | gelu | 12 | 12 | 1024 | 85M | 4 | | gpt2-medium | 1024 | 4096 | 64 | 50257 | gelu | 16 | 24 | 1024 | 302M | 5 | | gpt2-large | 1280 | 5120 | 64 | 50257 | gelu | 20 | 36 | 1024 | 708M | 6 | | gpt2-xl | 1600 | 6400 | 64 | 50257 | gelu | 25 | 48 | 1024 | 1.5B | 7 | | distillgpt2 | 768 | 3072 | 64 | 50257 | gelu | 12 | 6 | 1024 | 42M | 8 | | opt-125m | 768 | 3072 | 64 | 50272 | relu | 12 | 12 | 2048 | 85M | 9 | | opt-1.3b | 2048 | 8192 | 64 | 50272 | relu | 32 | 24 | 2048 | 1.2B | 10 | | opt-2.7b | 2560 | 10240 | 80 | 50272 | relu | 32 | 32 | 2048 | 2.5B | 11 | | opt-6.7b | 4096 | 16384 | 128 | 50272 | relu | 32 | 32 | 2048 | 6.4B | 12 | | opt-13b | 5120 | 20480 | 128 | 50272 | relu | 40 | 40 | 2048 | 13B | 13 | | opt-30b | 7168 | 28672 | 128 | 50272 | relu | 56 | 48 | 2048 | 30B | 14 | | opt-66b | 9216 | 36864 | 128 | 50272 | relu | 72 | 64 | 2048 | 65B | 15 | | gpt-neo-125M | 768 | 3072 | 64 | 50257 | gelu | 12 | 12 | 2048 | 85M | 16 | | gpt-neo-1.3B | 2048 | 8192 | 128 | 50257 | gelu | 16 | 24 | 2048 | 1.2B | 17 | | gpt-neo-2.7B | 2560 | 10240 | 128 | 50257 | gelu | 20 | 32 | 2048 | 2.5B | 18 | | gpt-j-6B | 4096 | 16384 | 256 | 50400 | gelu | 16 | 28 | 2048 | 5.6B | 19 | | gpt-neox-20b | 6144 | 24576 | 96 | 50432 | gelu_fast | 64 | 44 | 2048 | 20B | 20 | | stanford-gpt2-small-a | 768 | 3072 | 64 | 50257 | gelu | 12 | 12 | 1024 
| 85M | 21 | | stanford-gpt2-small-b | 768 | 3072 | 64 | 50257 | gelu | 12 | 12 | 1024 | 85M | 22 | | stanford-gpt2-small-c | 768 | 3072 | 64 | 50257 | gelu | 12 | 12 | 1024 | 85M | 23 | | stanford-gpt2-small-d | 768 | 3072 | 64 | 50257 | gelu | 12 | 12 | 1024 | 85M | 24 | | stanford-gpt2-small-e | 768 | 3072 | 64 | 50257 | gelu | 12 | 12 | 1024 | 85M | 25 | | stanford-gpt2-medium-a | 1024 | 4096 | 64 | 50257 | gelu | 16 | 24 | 1024 | 302M | 26 | | stanford-gpt2-medium-b | 1024 | 4096 | 64 | 50257 | gelu | 16 | 24 | 1024 | 302M | 27 | | stanford-gpt2-medium-c | 1024 | 4096 | 64 | 50257 | gelu | 16 | 24 | 1024 | 302M | 28 | | stanford-gpt2-medium-d | 1024 | 4096 | 64 | 50257 | gelu | 16 | 24 | 1024 | 302M | 29 | | stanford-gpt2-medium-e | 1024 | 4096 | 64 | 50257 | gelu | 16 | 24 | 1024 | 302M | 30 | | pythia-70m | 512 | 2048 | 64 | 50304 | gelu | 8 | 6 | 2048 | 19M | 31 | | pythia-160m | 768 | 3072 | 64 | 50304 | gelu | 12 | 12 | 2048 | 85M | 32 | | pythia-410m | 1024 | 4096 | 64 | 50304 | gelu | 16 | 24 | 2048 | 302M | 33 | | pythia-1b | 2048 | 8192 | 256 | 50304 | gelu | 8 | 16 | 2048 | 805M | 34 | | pythia-1.4b | 2048 | 8192 | 128 | 50304 | gelu | 16 | 24 | 2048 | 1.2B | 35 | | pythia-6.9b | 4096 | 16384 | 128 | 50432 | gelu | 32 | 32 | 2048 | 6.4B | 36 | | pythia-12b | 5120 | 20480 | 128 | 50688 | gelu | 40 | 36 | 2048 | 11B | 37 | | pythia-70m-deduped | 512 | 2048 | 64 | 50304 | gelu | 8 | 6 | 2048 | 19M | 38 | | pythia-160m-deduped | 768 | 3072 | 64 | 50304 | gelu | 12 | 12 | 2048 | 85M | 39 | | EleutherAI/pythia-410m-deduped | 1024 | 4096 | 64 | 50304 | gelu | 16 | 24 | 2048 | 302M | 40 | | pythia-1.4b-deduped | 2048 | 8192 | 128 | 50304 | gelu | 16 | 24 | 2048 | 1.2B | 41 | | pythia-6.9b-deduped | 4096 | 16384 | 128 | 50432 | gelu | 32 | 32 | 2048 | 6.4B | 42 | | pythia-12b-deduped | 5120 | 20480 | 128 | 50688 | gelu | 40 | 36 | 2048 | 11B | 43 | | solu-1l-old | 1024 | 4096 | 64 | 50278 | solu | 16 | 1 | 1024 | 13M | 44 | | solu-2l-old | 736 | 2944 | 64 | 50278 | solu | 11 | 2 | 1024 | 13M | 45 | | solu-4l-old | 512 | 2048 | 64 | 50278 | solu | 8 | 4 | 1024 | 13M | 46 | | solu-6l-old | 768 | 3072 | 64 | 50278 | solu | 12 | 6 | 1024 | 42M | 47 | | solu-8l-old | 1024 | 4096 | 64 | 50278 | solu | 16 | 8 | 1024 | 101M | 48 | | solu-10l-old | 1280 | 5120 | 64 | 50278 | solu | 20 | 10 | 1024 | 197M | 49 | | solu-12l-old | 1536 | 6144 | 64 | 50278 | solu | 24 | 12 | 1024 | 340M | 50 | | solu-1l | 512 | 2048 | 64 | 48262 | solu | 8 | 1 | 1024 | 3.1M | 51 | | solu-2l | 512 | 2048 | 64 | 48262 | solu | 8 | 2 | 1024 | 6.3M | 52 | | solu-3l | 512 | 2048 | 64 | 48262 | solu | 8 | 3 | 1024 | 9.4M | 53 | | solu-4l | 512 | 2048 | 64 | 48262 | solu | 8 | 4 | 1024 | 13M | 54 | | solu-6l | 768 | 3072 | 64 | 48262 | solu | 12 | 6 | 1024 | 42M | 55 | | solu-8l | 1024 | 4096 | 64 | 48262 | solu | 16 | 8 | 1024 | 101M | 56 | | solu-10l | 1280 | 5120 | 64 | 48262 | solu | 20 | 10 | 1024 | 197M | 57 | | solu-12l | 1536 | 6144 | 64 | 48262 | solu | 24 | 12 | 1024 | 340M | 58 | | gelu-1l | 512 | 2048 | 64 | 48262 | gelu | 8 | 1 | 1024 | 3.1M | 59 | | gelu-2l | 512 | 2048 | 64 | 48262 | gelu | 8 | 2 | 1024 | 6.3M | 60 | | gelu-3l | 512 | 2048 | 64 | 48262 | gelu | 8 | 3 | 1024 | 9.4M | 61 | | gelu-4l | 512 | 2048 | 64 | 48262 | gelu | 8 | 4 | 1024 | 13M | 62 | | attn-only-1l | 512 | 2048 | 64 | 48262 | attn_only | 8 | 1 | 1024 | 1.0M | 63 | | attn-only-2l | 512 | 2048 | 64 | 48262 | attn_only | 8 | 2 | 1024 | 2.1M | 64 | | attn-only-3l | 512 | 2048 | 64 | 48262 | attn_only | 8 | 3 | 1024 | 3.1M | 65 | | 
attn-only-4l | 512 | 2048 | 64 | 48262 | attn_only | 8 | 4 | 1024 | 4.2M | 66 | | attn-only-2l-demo | 512 | 2048 | 64 | 50277 | attn_only | 8 | 2 | 1024 | 2.1M | -------------------------------------------------------------------------------- /subnetwork_probing/transformer_lens/transformer_lens/past_key_value_caching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from dataclasses import dataclass 4 | from typing import Union, Tuple, List, Dict, Any, Optional 5 | from .HookedTransformerConfig import HookedTransformerConfig 6 | from torchtyping import TensorType as TT 7 | from .torchtyping_helper import T 8 | 9 | 10 | @dataclass 11 | class HookedTransformerKeyValueCacheEntry: 12 | past_keys: TT[T.batch, T.pos_so_far, T.n_heads, T.d_head] 13 | past_values: TT[T.batch, T.pos_so_far, T.n_heads, T.d_head] 14 | 15 | @classmethod 16 | def init_cache_entry( 17 | cls, 18 | cfg: HookedTransformerConfig, 19 | device: torch.device, 20 | batch_size: int = 1, 21 | ): 22 | return cls( 23 | past_keys=torch.empty( 24 | (batch_size, 0, cfg.n_heads, cfg.d_head), device=device 25 | ), 26 | past_values=torch.empty( 27 | (batch_size, 0, cfg.n_heads, cfg.d_head), device=device 28 | ), 29 | ) 30 | 31 | def append( 32 | self, 33 | new_keys: TT[T.batch, T.new_tokens, T.n_heads, T.d_head], 34 | new_values: TT[T.batch, T.new_tokens, T.n_heads, T.d_head], 35 | ): 36 | updated_keys: TT[ 37 | "batch", "pos_so_far + new_tokens", "n_heads", "d_head" 38 | ] = torch.cat([self.past_keys, new_keys], dim=1) 39 | updated_values: TT[ 40 | "batch", "pos_so_far + new_tokens", "n_heads", "d_head" 41 | ] = torch.cat([self.past_values, new_values], dim=1) 42 | self.past_keys = updated_keys 43 | self.past_values = updated_values 44 | return updated_keys, updated_values 45 | 46 | 47 | @dataclass 48 | class HookedTransformerKeyValueCache: 49 | """ 50 | A cache for storing past keys and values for the Transformer. This is important for generating text - we can cache a lot of past computation and avoid repeating ourselves! 51 | 52 | This cache is a list of HookedTransformerKeyValueCacheEntry objects, one for each layer in the Transformer. Each object stores a [batch, pos_so_far, n_heads, d_head] tensor for both keys and values, and each entry has an append method to add a single new key and value. 53 | 54 | Generation is assumed to be done by initializing with some prompt and then continuing iteratively one token at a time. So append only works for adding a single token's worth of keys and values, but the cache can be initialized with many. 55 | 56 | """ 57 | 58 | entries: List[HookedTransformerKeyValueCacheEntry] 59 | 60 | @classmethod 61 | def init_cache( 62 | cls, cfg: HookedTransformerConfig, device: torch.device, batch_size: int = 1 63 | ): 64 | return cls( 65 | entries=[ 66 | HookedTransformerKeyValueCacheEntry.init_cache_entry( 67 | cfg, device, batch_size 68 | ) 69 | for _ in range(cfg.n_layers) 70 | ] 71 | ) 72 | 73 | def __getitem__(self, idx): 74 | return self.entries[idx] 75 | -------------------------------------------------------------------------------- /subnetwork_probing/transformer_lens/transformer_lens/torchtyping_helper.py: -------------------------------------------------------------------------------- 1 | class T: 2 | """Helper class to get mypy to work with TorchTyping and solidify naming conventions as a byproduct.
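Each attribute below is simply its own name as a string, e.g. T.batch == "batch", so `TT[T.batch, T.pos]` means exactly `TT["batch", "pos"]` while staying acceptable to mypy (a descriptive gloss of the class body that follows, not a guarantee beyond it).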
3 | 4 | Examples: 5 | - `TT[T.batch, T.pos, T.d_model]` 6 | - `TT[T.num_components, T.batch_and_pos_dims:...]` 7 | """ 8 | 9 | batch: str = "batch" 10 | pos: str = "pos" 11 | head_index: str = "head_index" 12 | length: str = "length" 13 | rotary_dim: str = "rotary_dim" 14 | new_tokens: str = "new_tokens" 15 | batch_and_pos_dims: str = "batch_and_pos_dims" 16 | layers_accumulated_over: str = "layers_accumulated_over" 17 | layers_covered: str = "layers_covered" 18 | past_kv_pos_offset: str = "past_kv_pos_offset" 19 | num_components: str = "num_components" 20 | num_neurons: str = "num_neurons" 21 | pos_so_far: str = "pos_so_far" 22 | n_ctx: str = "n_ctx" 23 | n_heads: str = "n_heads" 24 | n_layers: str = "n_layers" 25 | d_vocab: str = "d_vocab" 26 | d_vocab_out: str = "d_vocab_out" 27 | d_head: str = "d_head" 28 | d_mlp: str = "d_mlp" 29 | d_model: str = "d_model" 30 | 31 | ldim: str = "ldim" 32 | rdim: str = "rdim" 33 | new_rdim: str = "new_rdim" 34 | mdim: str = "mdim" 35 | leading_dims: str = "leading_dims" 36 | leading_dims_left: str = "leading_dims_left" 37 | leading_dims_right: str = "leading_dims_right" 38 | 39 | a: str = "a" 40 | b: str = "b" 41 | 42 | pos_plus_past_kv_pos_offset = "pos + past_kv_pos_offset" 43 | d_vocab_plus_n_ctx = "d_vocab + n_ctx" 44 | pos_plus_new_tokens = "pos + new_tokens" 45 | -------------------------------------------------------------------------------- /subnetwork_probing/transformer_lens/transformer_lens/train.py: -------------------------------------------------------------------------------- 1 | from . import HookedTransformer 2 | from . import HookedTransformerConfig 3 | from dataclasses import dataclass 4 | from typing import Optional, Callable 5 | from torch.utils.data import Dataset, DataLoader 6 | import torch.optim as optim 7 | import wandb 8 | import torch 9 | import torch.nn as nn 10 | from tqdm.auto import tqdm 11 | from einops import rearrange 12 | 13 | 14 | @dataclass 15 | class HookedTransformerTrainConfig: 16 | """ 17 | Configuration class to store training hyperparameters for a training run of 18 | a HookedTransformer model. 19 | Args: 20 | num_epochs (int): Number of epochs to train for 21 | batch_size (int): Size of batches to use for training 22 | lr (float): Learning rate to use for training 23 | seed (int): Random seed to use for training 24 | momentum (float): Momentum to use for training 25 | max_grad_norm (float, *optional*): Maximum gradient norm to clip to during 26 | training 27 | weight_decay (float, *optional*): Weight decay to use for training 28 | optimizer_name (str): The name of the optimizer to use 29 | device (str, *optional*): Device to use for training 30 | warmup_steps (int, *optional*): Number of warmup steps to use for training 31 | save_every (int, *optional*): After how many batches should a checkpoint be saved 32 | save_dir (str, *optional*): Where to save checkpoints 33 | wandb (bool): Whether to use Weights and Biases for logging 34 | wandb_project_name (str, *optional*): Name of the Weights and Biases project to use 35 | print_every (int, *optional*): Print the loss every n steps 36 | max_steps (int, *optional*): Terminate the epoch after this many steps. Used for debugging.
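Example (an illustrative sketch; assumes a `model` and a `dataset` already exist, where `dataset` yields dicts with a "tokens" key, as the `train` function below expects): cfg = HookedTransformerTrainConfig(num_epochs=1, batch_size=8, lr=1e-3, optimizer_name="AdamW", weight_decay=0.01); trained_model = train(model, cfg, dataset)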
37 | """ 38 | 39 | num_epochs: int 40 | batch_size: int 41 | lr: float = 1e-3 42 | seed: int = 0 43 | momentum: float = 0.0 44 | max_grad_norm: Optional[float] = None 45 | weight_decay: Optional[float] = None 46 | optimizer_name: str = "Adam" 47 | device: Optional[str] = None 48 | warmup_steps: int = 0 49 | save_every: Optional[int] = None 50 | save_dir: Optional[str] = None 51 | wandb: bool = False 52 | wandb_project_name: Optional[str] = None 53 | print_every: Optional[int] = 50 54 | max_steps: Optional[int] = None 55 | 56 | 57 | def train( 58 | model: HookedTransformer, 59 | config: HookedTransformerTrainConfig, 60 | dataset: Dataset, 61 | ) -> HookedTransformer: 62 | """ 63 | Trains an HookedTransformer model on an autoregressive language modeling task. 64 | Args: 65 | model: The model to train 66 | config: The training configuration 67 | dataset: The dataset to train on - this function assumes the dataset is 68 | set up for autoregressive language modeling. 69 | Returns: 70 | The trained model 71 | """ 72 | torch.manual_seed(config.seed) 73 | model.train() 74 | if config.wandb: 75 | if config.wandb_project_name is None: 76 | config.wandb_project_name = "easy-transformer" 77 | wandb.init(project=config.wandb_project_name, config=vars(config)) 78 | 79 | if config.device is None: 80 | config.device = "cuda" if torch.cuda.is_available() else "cpu" 81 | 82 | if config.optimizer_name in ["Adam", "AdamW"]: 83 | # Weight decay in Adam is implemented badly, so use AdamW instead (see PyTorch AdamW docs) 84 | if config.weight_decay is not None: 85 | optimizer = optim.AdamW( 86 | model.parameters(), 87 | lr=config.lr, 88 | weight_decay=config.weight_decay, 89 | ) 90 | else: 91 | optimizer = optim.Adam( 92 | model.parameters(), 93 | lr=config.lr, 94 | ) 95 | elif config.optimizer_name == "SGD": 96 | optimizer = optim.SGD( 97 | model.parameters(), 98 | lr=config.lr, 99 | weight_decay=config.weight_decay 100 | if config.weight_decay is not None 101 | else 0.0, 102 | momentum=config.momentum, 103 | ) 104 | else: 105 | raise ValueError(f"Optimizer {config.optimizer_name} not supported") 106 | 107 | scheduler = None 108 | if config.warmup_steps > 0: 109 | scheduler = optim.lr_scheduler.LambdaLR( 110 | optimizer, 111 | lr_lambda=lambda step: min(1.0, step / config.warmup_steps), 112 | ) 113 | 114 | dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True) 115 | 116 | model.to(config.device) 117 | 118 | for epoch in tqdm(range(1, config.num_epochs + 1)): 119 | samples = 0 120 | for step, batch in tqdm(enumerate(dataloader)): 121 | tokens = batch["tokens"].to(config.device) 122 | loss = model(tokens, return_type="loss") 123 | loss.backward() 124 | if config.max_grad_norm is not None: 125 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) 126 | optimizer.step() 127 | if config.warmup_steps > 0: 128 | assert scheduler is not None 129 | scheduler.step() 130 | optimizer.zero_grad() 131 | 132 | samples += tokens.shape[0] 133 | 134 | if config.wandb: 135 | wandb.log( 136 | {"train_loss": loss.item(), "samples": samples, "epoch": epoch} 137 | ) 138 | 139 | if config.print_every is not None and step % config.print_every == 0: 140 | print(f"Epoch {epoch} Samples {samples} Step {step} Loss {loss.item()}") 141 | 142 | if ( 143 | config.save_every is not None 144 | and step % config.save_every == 0 145 | and config.save_dir is not None 146 | ): 147 | torch.save(model.state_dict(), f"{config.save_dir}/model_{step}.pt") 148 | 149 | if config.max_steps is not None and step >= 
config.max_steps: 150 | break 151 | 152 | return model 153 | -------------------------------------------------------------------------------- /subnetwork_probing/transformer_lens/typing_demo.py: -------------------------------------------------------------------------------- 1 | # %% 2 | 3 | import torch as t 4 | from torchtyping import TensorType as TT, patch_typeguard 5 | from typeguard import typechecked 6 | import einops 7 | 8 | patch_typeguard() 9 | 10 | ZimZam = TT["batch", "feature", float] 11 | 12 | 13 | @typechecked 14 | def test(x: ZimZam) -> ZimZam: 15 | return einops.rearrange(x, "f b -> f b") 16 | 17 | 18 | x = t.rand((10000, 1), dtype=t.float32) 19 | 20 | test(x) 21 | 22 | # what if "batch" and "feature" now take on different values? 23 | 24 | x = t.rand((20000, 2), dtype=t.float32) 25 | 26 | test(x) 27 | 28 | # ah so indeed batch and feature must only be consistent across a single function call 29 | 30 | # now what if we repeat the same strings across type definitions? 31 | 32 | ZimZam2 = TT["batch", "feature", float] 33 | 34 | 35 | @typechecked 36 | def test2(x: ZimZam2) -> ZimZam: 37 | return einops.rearrange(x, "f b -> f b") 38 | 39 | 40 | @typechecked 41 | def test3(x: ZimZam) -> ZimZam2: 42 | return einops.rearrange(x, "f b -> f b") 43 | 44 | 45 | test2(x) 46 | test3(x) 47 | 48 | # so the right mental model is that the decorators register 49 | # a dictionary whose keys are the dimension names and 50 | # whose values are the sizes, and the values must be consistent 51 | # across a single function call 52 | 53 | # now let's watch the type checker fail 54 | 55 | 56 | @typechecked 57 | def test4(x: ZimZam) -> ZimZam: 58 | return einops.rearrange(x, "f b -> b f") 59 | 60 | 61 | # %% 62 | -------------------------------------------------------------------------------- /tests/acdc/test_acdc.py: -------------------------------------------------------------------------------- 1 | #%% 2 | from copy import deepcopy 3 | from typing import ( 4 | List, 5 | Tuple, 6 | Dict, 7 | Any, 8 | Optional, 9 | Union, 10 | Callable, 11 | TypeVar, 12 | Iterable, 13 | Set, 14 | ) 15 | import wandb 16 | import IPython 17 | import torch 18 | 19 | from tqdm import tqdm 20 | import random 21 | from functools import * 22 | import json 23 | import pathlib 24 | import warnings 25 | import time 26 | import networkx as nx 27 | import os 28 | import torch 29 | import huggingface_hub 30 | from enum import Enum 31 | import torch.nn as nn 32 | import torch.nn.functional as F 33 | import torch.optim as optim 34 | import numpy as np 35 | import einops 36 | from tqdm import tqdm 37 | import yaml 38 | import gc 39 | from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer 40 | 41 | import matplotlib.pyplot as plt 42 | import plotly.express as px 43 | import plotly.io as pio 44 | from plotly.subplots import make_subplots 45 | import plotly.graph_objects as go 46 | from transformer_lens.hook_points import HookedRootModule, HookPoint 47 | from transformer_lens.HookedTransformer import ( 48 | HookedTransformer, 49 | ) 50 | from acdc.acdc_utils import ( 51 | make_nd_dict, 52 | shuffle_tensor, 53 | ct, 54 | ) 55 | from acdc.TLACDCEdge import ( 56 | TorchIndex, 57 | Edge, 58 | EdgeType, 59 | ) 60 | # these imports introduce several important classes used throughout the tests below
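# (descriptive gloss: a TorchIndex wraps the index into a hook's output, e.g. TorchIndex([None, None, 5]) in EDGE_EFFECTS below picks out attention head 5, while Edge and EdgeType describe a connection between two such indexed nodes)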
61 | 62 | from acdc.TLACDCCorrespondence import TLACDCCorrespondence 63 | from acdc.TLACDCInterpNode import TLACDCInterpNode 64 | from acdc.TLACDCExperiment import TLACDCExperiment 65 | 66 | from collections import defaultdict, deque, OrderedDict 67 | from acdc.docstring.utils import get_all_docstring_things 68 | from acdc.greaterthan.utils import get_all_greaterthan_things 69 | from acdc.induction.utils import ( 70 | get_all_induction_things, 71 | get_validation_data, 72 | get_good_induction_candidates, 73 | get_mask_repeat_candidates, 74 | ) 75 | from acdc.ioi.utils import get_all_ioi_things 76 | from acdc.tracr_task.utils import get_all_tracr_things, get_tracr_model_input_and_tl_model 77 | from acdc.acdc_graphics import ( 78 | build_colorscheme, 79 | show, 80 | ) 81 | import pytest 82 | from pathlib import Path 83 | 84 | @pytest.mark.slow 85 | @pytest.mark.skip(reason="TODO fix") 86 | def test_induction_several_steps(): 87 | # get induction task stuff 88 | num_examples = 400 89 | seq_len = 30 90 | # TODO initialize the `tl_model` with the right model 91 | all_induction_things = get_all_induction_things(num_examples=num_examples, seq_len=seq_len, device="cpu") # removed some randomize seq_len thing - hopefully unimportant 92 | tl_model, toks_int_values, toks_int_values_other, metric = all_induction_things.tl_model, all_induction_things.validation_data, all_induction_things.validation_patch_data, all_induction_things.validation_metric 93 | 94 | gc.collect() 95 | torch.cuda.empty_cache() 96 | 97 | # initialise object 98 | exp = TLACDCExperiment( 99 | model=tl_model, 100 | threshold=0.1, 101 | using_wandb=False, 102 | zero_ablation=False, 103 | ds=toks_int_values, 104 | ref_ds=toks_int_values_other, 105 | metric=metric, 106 | second_metric=None, 107 | verbose=True, 108 | indices_mode="reverse", 109 | names_mode="normal", 110 | corrupted_cache_cpu=True, 111 | online_cache_cpu=True, 112 | add_sender_hooks=True, # attempting to be efficient... 
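# (hedged gloss of these two flags: sender hooks are the ones that store activations into the caches configured above, while receiver hooks edit a node's incoming activations; see TLACDCExperiment for the authoritative behavior)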
113 | add_receiver_hooks=False, 114 | remove_redundant=True, 115 | ) 116 | 117 | for STEP_IDX in range(10): 118 | exp.step() 119 | 120 | edges_to_consider = {edge_tuple: edge for edge_tuple, edge in exp.corr.all_edges().items() if edge.effect_size is not None} 121 | 122 | EDGE_EFFECTS = OrderedDict([ 123 | ( ('blocks.1.hook_resid_post', TorchIndex([None]), 'blocks.1.attn.hook_result', TorchIndex([None, None, 6])) , 0.6195546984672546 ), 124 | ( ('blocks.1.hook_resid_post', TorchIndex([None]), 'blocks.1.attn.hook_result', TorchIndex([None, None, 5])) , 0.8417580723762512 ), 125 | ( ('blocks.1.hook_resid_post', TorchIndex([None]), 'blocks.0.attn.hook_result', TorchIndex([None, None, 5])) , 0.1795809268951416 ), 126 | ( ('blocks.1.hook_resid_post', TorchIndex([None]), 'blocks.0.attn.hook_result', TorchIndex([None, None, 4])) , 0.15076303482055664 ), 127 | ( ('blocks.1.hook_resid_post', TorchIndex([None]), 'blocks.0.attn.hook_result', TorchIndex([None, None, 3])) , 0.11805805563926697 ), 128 | ( ('blocks.1.hook_resid_post', TorchIndex([None]), 'blocks.0.hook_resid_pre', TorchIndex([None])) , 0.6345541179180145 ), 129 | ( ('blocks.1.attn.hook_q', TorchIndex([None, None, 6]), 'blocks.1.hook_q_input', TorchIndex([None, None, 6])) , 1.4423644244670868 ), 130 | ( ('blocks.1.attn.hook_q', TorchIndex([None, None, 5]), 'blocks.1.hook_q_input', TorchIndex([None, None, 5])) , 1.2416923940181732 ), 131 | ( ('blocks.1.attn.hook_k', TorchIndex([None, None, 6]), 'blocks.1.hook_k_input', TorchIndex([None, None, 6])) , 1.4157390296459198 ), 132 | ( ('blocks.1.attn.hook_k', TorchIndex([None, None, 5]), 'blocks.1.hook_k_input', TorchIndex([None, None, 5])) , 1.270191639661789 ), 133 | ( ('blocks.1.attn.hook_v', TorchIndex([None, None, 6]), 'blocks.1.hook_v_input', TorchIndex([None, None, 6])) , 2.9806662499904633 ), 134 | ( ('blocks.1.attn.hook_v', TorchIndex([None, None, 5]), 'blocks.1.hook_v_input', TorchIndex([None, None, 5])) , 2.7053256928920746 ), 135 | ( ('blocks.1.hook_v_input', TorchIndex([None, None, 6]), 'blocks.0.attn.hook_result', TorchIndex([None, None, 2])) , 0.12778228521347046 ), 136 | ( ('blocks.1.hook_v_input', TorchIndex([None, None, 6]), 'blocks.0.hook_resid_pre', TorchIndex([None])) , 1.8775241374969482 ), 137 | ]) 138 | 139 | assert set(edges_to_consider.keys()) == set(EDGE_EFFECTS.keys()), (set(edges_to_consider.keys()) - set(EDGE_EFFECTS.keys()), set(EDGE_EFFECTS.keys()) - set(edges_to_consider.keys()), EDGE_EFFECTS.keys()) 140 | 141 | for edge_tuple, edge in edges_to_consider.items(): 142 | assert abs(edge.effect_size - EDGE_EFFECTS[edge_tuple]) < 1e-5, (edge_tuple, edge.effect_size, EDGE_EFFECTS[edge_tuple]) 143 | 144 | @pytest.mark.slow 145 | @pytest.mark.parametrize("task, metric", [ 146 | ("tracr-proportion", "l2"), 147 | ("tracr-reverse", "l2"), 148 | ("docstring", "kl_div"), 149 | ("induction", "kl_div"), 150 | ("ioi", "kl_div"), 151 | ("greaterthan", "kl_div"), 152 | ]) 153 | def test_main_script(task, metric): 154 | import subprocess 155 | 156 | main_path = Path(__file__).resolve().parent.parent.parent / "acdc" / "main.py" 157 | subprocess.check_call(["python", str(main_path), f"--task={task}", "--threshold=1234", "--single-step", "--device=cpu", f"--metric={metric}"]) 158 | 159 | def test_editing_edges_notebook(): 160 | import notebooks.editing_edges 161 | 162 | 163 | 164 | @pytest.mark.parametrize("task", ["tracr-proportion", "tracr-reverse", "docstring", "induction", "ioi", "greaterthan"]) 165 | @pytest.mark.parametrize("zero_ablation", [False, True]) 166 | def 
test_full_correspondence_zero_kl(task, zero_ablation, device="cpu", metric_name="kl_div", num_examples=4, seq_len=10): 167 | if task == "tracr-proportion": 168 | things = get_all_tracr_things(task="proportion", num_examples=num_examples, device=device, metric_name="l2") 169 | elif task == "tracr-reverse": 170 | things = get_all_tracr_things(task="reverse", num_examples=6, device=device, metric_name="l2") 171 | elif task == "induction": 172 | things = get_all_induction_things(num_examples=100, seq_len=20, device=device, metric=metric_name) 173 | elif task == "ioi": 174 | things = get_all_ioi_things(num_examples=num_examples, device=device, metric_name=metric_name) 175 | elif task == "docstring": 176 | things = get_all_docstring_things(num_examples=num_examples, seq_len=seq_len, device=device, metric_name=metric_name, correct_incorrect_wandb=False) 177 | elif task == "greaterthan": 178 | things = get_all_greaterthan_things(num_examples=num_examples, metric_name=metric_name, device=device) 179 | else: 180 | raise ValueError(task) 181 | 182 | exp = TLACDCExperiment( 183 | model=things.tl_model, 184 | threshold=100_000, 185 | early_exit=False, 186 | using_wandb=False, 187 | zero_ablation=zero_ablation, 188 | ds=things.test_data, 189 | ref_ds=things.test_patch_data, 190 | metric=things.validation_metric, 191 | second_metric=None, 192 | verbose=True, 193 | use_pos_embed=False, # In the case that this is True, the KL should not be zero. 194 | online_cache_cpu=True, 195 | corrupted_cache_cpu=True, 196 | ) 197 | exp.setup_corrupted_cache() 198 | 199 | corr = deepcopy(exp.corr) 200 | for e in corr.all_edges().values(): 201 | e.present = True 202 | 203 | with torch.no_grad(): 204 | out = exp.call_metric_with_corr(corr, things.test_metrics["kl_div"], things.test_data) 205 | assert abs(out) < 1e-6, f"{out} should be abs(out) < 1e-6" 206 | -------------------------------------------------------------------------------- /tests/acdc/test_greaterthan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from acdc.greaterthan.utils import greaterthan_metric_reference, greaterthan_metric, get_year_data 3 | from acdc.ioi.utils import get_gpt2_small 4 | 5 | 6 | def test_greaterthan_metric(): 7 | model = get_gpt2_small(device="cpu") 8 | data, _ = get_year_data(20, model) 9 | logits = model(data) 10 | 11 | expected = greaterthan_metric_reference(logits, data) 12 | actual = greaterthan_metric(logits, data) 13 | torch.testing.assert_close(actual, torch.as_tensor(expected)) 14 | -------------------------------------------------------------------------------- /tests/subnetwork_probing/test_count_nodes.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import pygraphviz as pgv 4 | from acdc.TLACDCCorrespondence import TLACDCCorrespondence 5 | from acdc.TLACDCInterpNode import TLACDCInterpNode 6 | from acdc.TLACDCEdge import EdgeType, TorchIndex 7 | from acdc.acdc_graphics import show 8 | import tempfile 9 | import os 10 | 11 | from subnetwork_probing.transformer_lens.transformer_lens.HookedTransformer import HookedTransformer 12 | import pytest 13 | from pathlib import Path 14 | import sys 15 | 16 | sys.path.append(str(Path(__file__).parent.parent / "code")) 17 | 18 | from subnetwork_probing.train import iterative_correspondence_from_mask, get_transformer_config 19 | import networkx as nx 20 | from acdc.TLACDCInterpNode import parse_interpnode 21 | 22 | def delete_nested_dict(d: dict, keys: list): 
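    """Delete the value at the nested key path `keys` inside `d`, returning silently if the path does not exist; the loop at the end then attempts to prune parent dicts that the deletion left empty. (Descriptive docstring inferred from the body below.)"""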
23 | inner_dicts = [d] 24 | try: 25 | for k in keys[:-1]: 26 | inner_dicts.append(inner_dicts[-1][k]) 27 | 28 | del inner_dicts[-1][keys[-1]] 29 | except KeyError: 30 | return 31 | assert len(inner_dicts) == len(keys) 32 | 33 | if len(inner_dicts[-1]) == 0: 34 | for k_to_delete, inner_dict in reversed(list(zip(keys, inner_dicts))): 35 | assert not isinstance(inner_dict, dict) or len(inner_dict) == 0 36 | try: 37 | del inner_dict[k_to_delete] 38 | except KeyError: 39 | return 40 | if len(inner_dict) > 0: 41 | break 42 | 43 | def test_count_nodes(): 44 | nodes_to_mask_str = [ 45 | "blocks.0.attn.hook_q[COL, COL, 0]", 46 | "blocks.0.attn.hook_k[COL, COL, 0]", 47 | "blocks.0.attn.hook_q[COL, COL, 1]", 48 | "blocks.0.attn.hook_k[COL, COL, 1]", 49 | "blocks.0.attn.hook_v[COL, COL, 1]", 50 | "blocks.0.attn.hook_q[COL, COL, 2]", 51 | "blocks.0.attn.hook_k[COL, COL, 2]", 52 | "blocks.0.attn.hook_v[COL, COL, 2]", 53 | "blocks.0.attn.hook_q[COL, COL, 3]", 54 | "blocks.0.attn.hook_k[COL, COL, 3]", 55 | "blocks.0.attn.hook_v[COL, COL, 3]", 56 | "blocks.0.attn.hook_q[COL, COL, 4]", 57 | "blocks.0.attn.hook_k[COL, COL, 4]", 58 | "blocks.0.attn.hook_v[COL, COL, 4]", 59 | "blocks.0.attn.hook_q[COL, COL, 5]", 60 | "blocks.0.attn.hook_k[COL, COL, 5]", 61 | "blocks.0.attn.hook_q[COL, COL, 6]", 62 | "blocks.0.attn.hook_k[COL, COL, 6]", 63 | "blocks.0.attn.hook_q[COL, COL, 7]", 64 | "blocks.0.attn.hook_k[COL, COL, 7]", 65 | "blocks.1.attn.hook_q[COL, COL, 0]", 66 | "blocks.1.attn.hook_k[COL, COL, 0]", 67 | "blocks.1.attn.hook_v[COL, COL, 0]", 68 | "blocks.1.attn.hook_q[COL, COL, 1]", 69 | "blocks.1.attn.hook_k[COL, COL, 1]", 70 | "blocks.1.attn.hook_v[COL, COL, 1]", 71 | "blocks.1.attn.hook_q[COL, COL, 2]", 72 | "blocks.1.attn.hook_k[COL, COL, 2]", 73 | "blocks.1.attn.hook_v[COL, COL, 2]", 74 | "blocks.1.attn.hook_q[COL, COL, 3]", 75 | "blocks.1.attn.hook_k[COL, COL, 3]", 76 | "blocks.1.attn.hook_q[COL, COL, 4]", 77 | "blocks.1.attn.hook_k[COL, COL, 4]", 78 | "blocks.1.attn.hook_q[COL, COL, 5]", 79 | "blocks.1.attn.hook_k[COL, COL, 5]", 80 | "blocks.1.attn.hook_q[COL, COL, 6]", 81 | "blocks.1.attn.hook_k[COL, COL, 6]", 82 | "blocks.1.attn.hook_v[COL, COL, 6]", 83 | "blocks.1.attn.hook_q[COL, COL, 7]", 84 | "blocks.1.attn.hook_k[COL, COL, 7]", 85 | ] 86 | nodes_to_mask = [parse_interpnode(s) for s in nodes_to_mask_str] 87 | nodes_to_mask2 = [ 88 | TLACDCInterpNode( 89 | n.name.replace(".attn", "") + "_input", n.index, EdgeType.ADDITION 90 | ) 91 | for n in nodes_to_mask 92 | ] 93 | nodes_to_mask += nodes_to_mask2 94 | 95 | cfg = get_transformer_config() 96 | model = HookedTransformer(cfg, is_masked=True) 97 | 98 | corr = TLACDCCorrespondence.setup_from_model(model) 99 | for child_hook_name in corr.edges: 100 | for child_index in corr.edges[child_hook_name]: 101 | for parent_hook_name in corr.edges[child_hook_name][child_index]: 102 | for parent_index in corr.edges[child_hook_name][child_index][ 103 | parent_hook_name 104 | ]: 105 | edge = corr.edges[child_hook_name][child_index][parent_hook_name][ 106 | parent_index 107 | ] 108 | 109 | if all( 110 | (child_hook_name != n.name or child_index != n.index) 111 | for n in nodes_to_mask 112 | ) and all( 113 | (parent_hook_name != n.name or parent_index != n.index) 114 | for n in nodes_to_mask 115 | ): 116 | edge.effect_size = 1 117 | 118 | with tempfile.TemporaryDirectory() as tmpdir: 119 | g = show(corr, os.path.join(tmpdir, "out.png"), show_full_index=False) 120 | assert isinstance(g, pgv.AGraph) 121 | path = Path(tmpdir) / "out.gv" 122 | assert path.exists() 123 | 
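# (gloss: show() above also wrote a graphviz .gv file alongside the .png; it is re-read with networkx below, pruned to nodes that have a path from "embed" and a path to the final node, and the surviving edge count is then compared against the correspondence's count_no_edges())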
# In advance I predict that it should be 41 124 | g2 = nx.nx_agraph.read_dot(path) 125 | 126 | to_delete = [] 127 | for n in g2.nodes: 128 | if not nx.has_path(g2, "embed", n) or not nx.has_path(g2, n, ""): 129 | to_delete.append(n) 130 | 131 | for n in to_delete: 132 | g2.remove_node(n) 133 | 134 | # Delete self-loops 135 | for n in g2.nodes: 136 | if g2.has_edge(n, n): 137 | g2.remove_edge(n, n) 138 | assert len(g2.edges) == 41 139 | 140 | corr, _ = iterative_correspondence_from_mask(model, nodes_to_mask) 141 | assert corr.count_no_edges() == 41 142 | -------------------------------------------------------------------------------- /tests/subnetwork_probing/test_sp_launch.py: -------------------------------------------------------------------------------- 1 | import subnetwork_probing.launch_grid_fill 2 | import pytest 3 | 4 | @pytest.mark.slow 5 | @pytest.mark.parametrize("reset_networks", [True, False]) 6 | def test_sp_grid(reset_networks): 7 | tasks = ["tracr-reverse", "tracr-proportion", "docstring", "induction", "greaterthan", "ioi"] 8 | subnetwork_probing.launch_grid_fill.main(TASKS=tasks, job=None, name="sp-test", testing=True, reset_networks=reset_networks) 9 | --------------------------------------------------------------------------------