├── .gitattributes
├── .gitconfig
├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── acdc
├── TLACDCCorrespondence.py
├── TLACDCEdge.py
├── TLACDCExperiment.py
├── TLACDCInterpNode.py
├── __init__.py
├── acdc_graphics.py
├── acdc_utils.py
├── docstring
│ ├── __init__.py
│ ├── prompts.py
│ └── utils.py
├── global_cache.py
├── greaterthan
│ ├── __init__.py
│ └── utils.py
├── induction
│ ├── __init__.py
│ └── utils.py
├── ioi
│ ├── __init__.py
│ ├── ioi_dataset.py
│ └── utils.py
├── logic_gates
│ ├── __init__.py
│ └── utils.py
├── main.py
└── tracr_task
│ ├── __init__.py
│ └── utils.py
├── assets
└── acdc_finds_subgraph.png
├── experiments
├── all_roc_plots.py
├── code_for_or_gate_figure.ipynb
├── collect_data.py
├── docstring_results.py
├── induction_results.py
├── launch_abstract.py
├── launch_abstract_old.py
├── launch_all_sixteen_heads.py
├── launch_docstring.py
├── launch_induction.py
├── launch_sixteen_heads.py
├── launch_spreadsheet.py
├── launcher.py
└── results
│ ├── .gitignore
│ ├── auc_tables.tex
│ ├── canonical_circuits
│ ├── .gitignore
│ ├── Makefile
│ ├── greaterthan
│ │ ├── Makefile
│ │ └── layout.gv
│ └── ioi
│ │ ├── .gitignore
│ │ ├── Makefile
│ │ └── layout.gv
│ ├── plots
│ └── .gitignore
│ └── plots_data
│ ├── 16h-docstring-docstring_metric-False-0.json
│ ├── 16h-docstring-docstring_metric-False-1.json
│ ├── 16h-docstring-docstring_metric-True-0.json
│ ├── 16h-docstring-docstring_metric-True-1.json
│ ├── 16h-docstring-kl_div-False-0.json
│ ├── 16h-docstring-kl_div-False-1.json
│ ├── 16h-docstring-kl_div-True-0.json
│ ├── 16h-docstring-kl_div-True-1.json
│ ├── Makefile
│ ├── acdc-docstring-docstring_metric-False-0.json
│ ├── acdc-docstring-docstring_metric-False-1.json
│ ├── acdc-docstring-kl_div-False-0.json
│ ├── acdc-docstring-kl_div-False-1.json
│ ├── acdc-greaterthan-greaterthan-False-0.json
│ ├── acdc-greaterthan-greaterthan-True-0.json
│ ├── acdc-greaterthan-kl_div-False-0.json
│ ├── acdc-greaterthan-kl_div-True-0.json
│ ├── acdc-ioi-kl_div-False-0.json
│ ├── acdc-ioi-kl_div-False-1.json
│ ├── acdc-ioi-logit_diff-False-0.json
│ ├── acdc-ioi-logit_diff-False-1.json
│ ├── acdc-tracr-proportion-l2-False-0.json
│ ├── acdc-tracr-proportion-l2-False-1.json
│ ├── acdc-tracr-reverse-l2-False-0.json
│ ├── acdc-tracr-reverse-l2-False-1.json
│ ├── generate_makefile.py
│ ├── sp-docstring-docstring_metric-False-0.json
│ ├── sp-docstring-docstring_metric-False-1.json
│ ├── sp-docstring-docstring_metric-True-0.json
│ ├── sp-docstring-docstring_metric-True-1.json
│ ├── sp-docstring-kl_div-False-0.json
│ ├── sp-docstring-kl_div-False-1.json
│ ├── sp-docstring-kl_div-True-0.json
│ └── sp-docstring-kl_div-True-1.json
├── ims
├── current_paper_induction_json.json
├── current_paper_induction_zero.json
├── induction_json.json
├── make_jsons.py
└── my_new_plotly_graph.json
├── notebooks
├── _converted
│ └── .gitignore
├── auc_tables.py
├── colabs
│ ├── ACDC_Editing_Edges_Demo.ipynb
│ ├── ACDC_Implementation_Demo.ipynb
│ └── ACDC_Main_Demo.ipynb
├── convert_to_ipynb.sh
├── df_plots_data.py
├── easier_roc_plot.py
├── editing_edges.py
├── emacs_plotly_render.py
├── implementation_demo.py
├── make_plotly_plots.py
├── minimal_acdc_node_roc.py
├── pareto_plot.py
└── roc_plot_generator.py
├── poetry.lock
├── pyproject.toml
├── subnetwork_probing
├── README.md
├── create_reset_networks.py
├── launch_grid_fill.py
├── train.py
└── transformer_lens
│ ├── .pre-commit-config.yaml
│ ├── Attribution_Patching_Demo.ipynb
│ ├── Exploratory_Analysis_Demo.ipynb
│ ├── Interactive Neuroscope.ipynb
│ ├── LICENSE
│ ├── Main_Demo.ipynb
│ ├── No_Position_Experiment.ipynb
│ ├── Old_Demo.ipynb
│ ├── README.md
│ ├── Tracr_to_Transformer_Lens_Demo.ipynb
│ ├── activation_patching_in_TL_demo.py.ipynb
│ ├── easy_transformer
│ └── __init__.py
│ ├── further_comments.md
│ ├── poetry.lock
│ ├── pyproject.toml
│ ├── setup.py
│ ├── transformer_lens
│ ├── ActivationCache.py
│ ├── FactoredMatrix.py
│ ├── HookedTransformer.py
│ ├── HookedTransformerConfig.py
│ ├── __init__.py
│ ├── components.py
│ ├── evals.py
│ ├── hook_points.py
│ ├── ioi_dataset.py
│ ├── loading_from_pretrained.py
│ ├── make_docs.py
│ ├── model_properties_table.md
│ ├── past_key_value_caching.py
│ ├── patching.py
│ ├── torchtyping_helper.py
│ ├── train.py
│ └── utils.py
│ └── typing_demo.py
└── tests
├── acdc
├── test_acdc.py
└── test_greaterthan.py
└── subnetwork_probing
├── test_count_nodes.py
└── test_sp_launch.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb merge=nbdev-merge
2 | *.pt filter=lfs diff=lfs merge=lfs -text
3 |
--------------------------------------------------------------------------------
/.gitconfig:
--------------------------------------------------------------------------------
1 | # Generated by nbdev_install_hooks
2 | #
3 | # If you need to disable this instrumentation do:
4 | # git config --local --unset include.path
5 | #
6 | # To restore:
7 | # git config --local include.path ../.gitconfig
8 | #
9 | [merge "nbdev-merge"]
10 | name = resolve conflicts with nbdev_fix
11 | driver = nbdev_merge %O %A %B %P
12 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | out.txt
2 | out2.txt
3 | out1.txt
4 | acdc.egg-info/
5 | *.sage.py
6 | *.pt
7 | *.pkl
8 | better*.json
9 | histories/
10 | transformer_lens/media/
11 | profile*.json
12 | *.png
13 | *.gv
14 | !experiments/results/canonical_circuits/*/layout.gv
15 | __pycache__
16 | Testing_Notebook.ipynb
17 |
18 | scratch.py
19 | transformer_lens/scratch.py
20 |
21 | .vscode/
22 | .idea/
23 | wandb*
24 | transformer_lens\.egg*
25 | MANIFEST.in
26 | settings.ini
27 | _proc
28 | core.py
29 | nbs
30 | _modidx.py
31 | .ipynb_checkpoints
32 | env
33 | dist/
34 | docs/build
35 | .coverage
36 |
37 | # don't really know what these are..
38 | .hypothesis/
39 | arthur
40 | et_model_state_dict_the_(1).pt
41 | gpt2_hypothesis_tree_Sun Jan 29 23:24:51 2023.png
42 | gpt2_hypothesis_tree_Tue Jan 31 09:50:47 2023.png
43 | gpt2_hypothesis_tree_Tue Jan 31 10:04:54 2023.png
44 | gpt2_hypothesis_tree_Tue Jan 31 10:07:44 2023.png
45 | gpt2_hypothesis_tree_Tue Jan 31 10:34:57 2023.png
46 | gpt2_hypothesis_tree_Tue Jan 31 10:45:48 2023.png
47 | gpt2_hypothesis_tree_Tue Jan 31 11:05:51 2023.png
48 | gpt2_hypothesis_tree_Tue Jan 31 11:40:10 2023.png
49 | gpt2_hypothesis_tree_Tue Jan 31 11:45:14 2023.png
50 | hypothesis_tree.dot
51 | openwebtext-10k.jsonl
52 | openwebtext-10k.jsonl.gz
53 | profile.json
54 | profile2.json
55 | profile_trash.json
56 | pts/
57 | .coverage*
58 | .Ds_Store
59 | .pylintrc
60 | **/.DS_Store
61 |
62 | ob-jupyter
63 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
2 | LABEL org.opencontainers.image.source=https://github.com/arthurconmy/Automatic-Circuit-Discovery
3 | ENV DEBIAN_FRONTEND noninteractive
4 |
5 | RUN apt-get update -q \
6 | && apt-get install -y --no-install-recommends \
7 | wget git git-lfs \
8 | python3 python3-dev python3-pip python3-venv python3-setuptools python-is-python3 \
9 | libgl1-mesa-glx graphviz graphviz-dev \
10 | && apt-get clean \
11 | && rm -rf /var/lib/apt/lists/*
12 |
13 | # This venv only holds Poetry and its dependencies. They are isolated from the main project dependencies.
14 | ENV POETRY_HOME="/opt/poetry"
15 | RUN python3 -m venv $POETRY_HOME \
16 | # Here we use the pip inside $POETRY_HOME but afterwards we should not
17 | && "$POETRY_HOME/bin/pip" install poetry==1.4.2 \
18 | && rm -rf "${HOME}/.cache"
19 | ENV POETRY="${POETRY_HOME}/bin/poetry"
20 |
21 | WORKDIR "/Automatic-Circuit-Discovery"
22 | COPY --chown=root:root pyproject.toml poetry.lock ./
23 |
24 | # Don't create a virtualenv, the Docker container is already enough isolation
25 | RUN "$POETRY" config virtualenvs.create false \
26 | # Install dependencies
27 | && "$POETRY" install --no-root --no-interaction "--only=main,dev" \
28 | && rm -rf "${HOME}/.cache"
29 |
30 | # Copy whole repo
31 | COPY --chown=root:root . .
32 | # Abort if repo is dirty
33 | RUN if ! { [ -z "$(git status --porcelain --ignored=traditional)" ] \
34 | ; }; then exit 1; fi
35 |
36 | # Finally install this package
37 | RUN "$POETRY" install --only-root --no-interaction
38 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2023 Arthur Conmy, Adrià Garriga-Alonso
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | []() [](https://github.com/ArthurConmy/Automatic-Circuit-Discovery/pulls)
2 |
3 | # Automatic Circuit DisCovery
4 |
5 | 
6 |
7 | This is the accompanying code to the paper ["Towards Automated Circuit Discovery for Mechanistic Interpretability"](https://arxiv.org/abs/2304.14997) (NeurIPS 2023 Spotlight).
8 |
9 | ⚠️ You may wish to use the repo Auto Circuit by @UFO-101 as this codebase has many sharp edges. Nevertheless:
10 |
11 | * :zap: To run ACDC, see `acdc/main.py`, or this Colab notebook
12 | * :wrench: To see how to edit edges in computational graphs in models, see `notebooks/editing_edges.py` or this Colab notebook
13 | * :sparkle: To understand the low-level implementation of completely editable computational graphs, see this Colab notebook or `notebooks/implementation_demo.py`
14 |
15 | This library builds upon the abstractions (`HookPoint`s and standardised `HookedTransformer`s) from [TransformerLens](https://github.com/neelnanda-io/TransformerLens) :mag_right:
16 |
17 | ## Installation:
18 |
19 | First, install the system dependencies for either [Mac](#apple-mac-os-x) or [Linux](#penguin-ubuntu-linux).
20 |
21 | Then, you need Python 3.8+ and [Poetry](https://python-poetry.org/docs/) to install ACDC, like so
22 |
23 | ```bash
24 | git clone https://github.com/ArthurConmy/Automatic-Circuit-Discovery.git
25 | cd Automatic-Circuit-Discovery
26 | poetry env use 3.10 # Or be inside a conda or venv environment
27 | # Python 3.10 is recommended but use any Python version >= 3.8
28 | poetry install
29 | ```
30 |
31 | ### System Dependencies
32 |
33 | #### :penguin: Ubuntu Linux
34 |
35 | ```bash
36 | sudo apt-get update && sudo apt-get install libgl1-mesa-glx graphviz build-essential graphviz-dev
37 | ```
38 |
39 | You may also need `apt-get install python3.x-dev` where `x` is your Python version (also see [the issue](https://github.com/ArthurConmy/Automatic-Circuit-Discovery/issues/57) and [pygraphviz installation troubleshooting](https://pygraphviz.github.io/documentation/stable/install.html))
40 |
41 | #### :apple: Mac OS X
42 |
43 | On Mac, you need to let pip (inside poetry) know about the path to the Graphviz libraries.
44 |
45 | ```
46 | brew install graphviz
47 | export CFLAGS="-I$(brew --prefix graphviz)/include"
48 | export LDFLAGS="-L$(brew --prefix graphviz)/lib"
49 | ```
50 |
51 | ### Reproducing results
52 |
53 | To reproduce the Pareto Frontier of KL divergences against number of edges for ACDC runs, run `python experiments/launch_induction.py`. Similarly, `python experiments/launch_sixteen_heads.py` and `python subnetwork_probing/train.py` were used to generate individual data points for the other methods, using the CLI help. All these three commands can produce wandb runs. We use `notebooks/roc_plot_generator.py` to process data from wandb runs into JSON files (see `experiments/results/plots_data/Makefile` for the commands) and `notebooks/make_plotly_plots.py` to produce plots from these JSON files.
54 |
55 | ## Tests
56 |
57 | From the root directory, run
58 |
59 | ```bash
60 | pytest -vvv -m "not slow"
61 | ```
62 |
63 | This will only select tests not marked as `slow`. These tests take a _long_ time, and are good to run occasionally, but
64 | not every time.
65 |
66 | You can run the slow tests with
67 |
68 | ``` bash
69 | pytest -s -m slow
70 | ```
71 |
72 | ## Contributing
73 |
74 | We welcome issues where the code is unclear!
75 |
76 | If your PR affects the main demo, rerun
77 | ```bash
78 | chmod +x experiments/make_notebooks.sh
79 | ./experiments/make_notebooks.sh
80 | ```
81 | to automatically turn the `main.py` into a working demo and check that no errors arise. It is essential that the notebooks converted here consist only of `#%% [markdown]` markdown-only cells, and `#%%` cells with code.
82 |
83 | ## Citing ACDC
84 |
85 | If you use ACDC, please reach out! You can reference the work as follows:
86 |
87 | ```
88 | @inproceedings{conmy2023automated,
89 | title={Towards Automated Circuit Discovery for Mechanistic Interpretability},
90 | author={Arthur Conmy and Augustine N. Mavor-Parker and Aengus Lynch and Stefan Heimersheim and Adri{\`a} Garriga-Alonso},
91 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
92 | year={2023},
93 | eprint={2304.14997},
94 | archivePrefix={arXiv},
95 | primaryClass={cs.LG}
96 | }
97 | ```
98 |
99 | ## TODO
100 |
101 |
102 | Mostly finished TODO list
103 |
104 | [ x ] Make `TransformerLens` install be Neel's code not my PR
105 |
106 | [ x ] Add `hook_mlp_in` to `TransformerLens` and delete `hook_resid_mid` (and test to ensure no bad things?)
107 |
108 | [ x ] Delete `arthur-try-merge-tl` references from the repo
109 |
110 | [ x ] Make notebook on abstractions
111 |
112 | [ ? ] Fix huge edge sizes in Induction Main example and change that occurred
113 |
114 | [ x ] Find a better way to deal with the versioning on the Colabs installs...
115 |
116 | [ ] Neuron-level experiments
117 |
118 | [ ] Position-level experiments
119 |
120 | [ ] Edge gradient descent experiments
121 |
122 | [ ] Implement the circuit breaking paper
123 |
124 | [ x ] `tracr` and other dependencies better managed
125 |
126 | [ ? ] Make SP tests work (lots outdated so skipped) - and check SubnetworkProbing installs properly (no __init__.pys !!!)
127 |
128 | [ ? ] Make the 9 tests also failing on TransformerLens-main pass
129 |
130 | [ x ] Remove Codebase under construction
131 |
132 |
133 |
--------------------------------------------------------------------------------
/acdc/TLACDCEdge.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from collections import defaultdict
3 | from enum import Enum
4 | from typing import Optional, List
5 |
6 |
class EdgeType(Enum):
    """
    Property of edges in the computational graph - either

    ADDITION: the child (hook_name, index) is a sum of the parent (hook_name, index)s
    DIRECT_COMPUTATION The *single* child is a function of and only of the parent (e.g the value hooked by hook_q is a function of what hook_q_input saves).
    PLACEHOLDER generally like 2. but where there are generally multiple parents. Here in ACDC we just include these edges by default when we find them. Explained below?

    Q: Why do we do this?

    There are two answers to this question: A1 is an interactive notebook, see this Colab notebook, which is in this repo at notebooks/implementation_demo.py. A2 is an answer that is written here below, but probably not as clear as A1 (though shorter).

    A2: We need something inside TransformerLens to represent the edges of a computational graph.
    The object we choose is pairs (hook_name, index). For example the output of Layer 11 Heads is a hook (blocks.11.attn.hook_result) and to specify the 3rd head we add the index [:, :, 3]. Then we can build a computational graph on these!

    However, when we do ACDC there turn out to be two conflicting things "removing edges" wants to do:
    i) for things in the residual stream, we want to remove the sum of the effects from previous hooks
    ii) for things that are not linear we want to *recompute* e.g the result inside the hook
    blocks.11.attn.hook_result from a corrupted Q and normal K and V

    The easiest way I thought of reconciling these different cases, while also having a connected computational graph, is to have three types of edges: addition for the residual case, direct computation for easy cases where we can just replace hook_q with a cached value when we e.g cut it off from hook_q_input, and placeholder to make the graph connected (when hook_result is connected to hook_q and hook_k and hook_v)"""

    ADDITION = 0
    DIRECT_COMPUTATION = 1
    PLACEHOLDER = 2

    def __eq__(self, other):
        """Necessary because of extremely frustrating error that arises with load_ext autoreload (because this uses importlib under the hood: https://stackoverflow.com/questions/66458864/enum-comparison-become-false-after-reloading-module)"""

        assert isinstance(other, EdgeType)
        return self.value == other.value

    def __hash__(self):
        # BUG FIX: defining __eq__ implicitly sets __hash__ to None, which made
        # EdgeType members unhashable (unusable as dict keys or set members).
        # Hash by value, consistent with the value-based __eq__ above.
        return hash(self.value)
38 |
39 |
class Edge:
    """One edge of the computational graph: its type, whether it is currently
    present in the (sub)graph, and optionally the measured effect size of
    ablating it."""

    def __init__(
        self,
        edge_type: EdgeType,
        present: bool = True,
        effect_size: Optional[float] = None,
    ):
        self.effect_size = effect_size
        self.present = present
        self.edge_type = edge_type

    def __repr__(self) -> str:
        return "Edge({}, {})".format(self.edge_type, self.present)
53 |
class TorchIndex:
    """There is not a clean bijection between things we
    want in the computational graph, and things that are hooked
    (e.g hook_result covers all heads in a layer)

    `TorchIndex`s are essentially indices that say which part of the tensor is being affected.

    EXAMPLES: Initialise [:, :, 3] with TorchIndex([None, None, 3]) and [:] with TorchIndex([None])

    Also we want to be able to call e.g `my_dictionary[my_torch_index]` hence the hashable tuple stuff

    Note: ideally this would be integrated with transformer_lens.utils.Slice in future; they are accomplishing similar but different things"""

    def __init__(
        self,
        list_of_things_in_tuple: List,
    ):
        # check correct types: each entry must be None, an int, or a list of ints
        # NOTE(review): a list entry makes hashable_tuple unhashable and is
        # rejected by __repr__, so lists appear unused in practice — confirm.
        for arg in list_of_things_in_tuple:
            if type(arg) in [type(None), int]:
                continue
            else:
                assert isinstance(arg, list)
                assert all([type(x) == int for x in arg])

        # make an object that can be indexed into a tensor
        # (None becomes slice(None), i.e. the ":" slice)
        self.as_index = tuple([slice(None) if x is None else x for x in list_of_things_in_tuple])

        # make an object that can be hashed (so used as a dictionary key)
        self.hashable_tuple = tuple(list_of_things_in_tuple)

    def __hash__(self):
        return hash(self.hashable_tuple)

    def __eq__(self, other):
        # BUG FIX: comparing against a non-TorchIndex previously raised
        # AttributeError; returning NotImplemented lets Python fall back to
        # its default comparison (typically False) instead of crashing.
        if not isinstance(other, TorchIndex):
            return NotImplemented
        return self.hashable_tuple == other.hashable_tuple

    # some graphics things

    def __repr__(self, use_actual_colon=True) -> str: # graphviz, an old library used to dislike actual colons in strings, but this shouldn't be an issue anymore
        ret = "["
        for idx, x in enumerate(self.hashable_tuple):
            if idx > 0:
                ret += ", "
            if x is None:
                ret += ":" if use_actual_colon else "COLON"
            elif type(x) == int:
                ret += str(x)
            else:
                raise NotImplementedError(x)
        ret += "]"
        return ret

    def graphviz_index(self, use_actual_colon=True) -> str:
        return self.__repr__(use_actual_colon=use_actual_colon)
109 |
--------------------------------------------------------------------------------
/acdc/TLACDCInterpNode.py:
--------------------------------------------------------------------------------
1 | from acdc.TLACDCEdge import (
2 | TorchIndex,
3 | Edge,
4 | EdgeType,
5 | ) # these introduce several important classes !!!
6 | from typing import List, Dict, Optional, Tuple, Union, Set, Callable, TypeVar, Iterable, Any
7 |
class TLACDCInterpNode:
    """Represents one node in the computational graph, similar to ACDCInterpNode from the rust_circuit code

    But WARNING this has nodes closer to the input tokens as *parents* of nodes closer to the output tokens, the opposite of the rust_circuit code

    Params:
        name: name of the hook this node corresponds to
        index: the TorchIndex of the tensor slice that this node represents
        incoming_edge_type: how we deal with this node when we bump into it as a parent of another node
            (ADDITION: it's summed to make up the child; DIRECT_COMPUTATION: it's the sole node used to
            compute the child; PLACEHOLDER: connectivity-only edge)."""

    def __init__(self, name: str, index: TorchIndex, incoming_edge_type: EdgeType):

        self.name = name
        self.index = index

        # Parent/child links are filled in by TLACDCCorrespondence, not here.
        self.parents: List["TLACDCInterpNode"] = []
        self.children: List["TLACDCInterpNode"] = []

        self.incoming_edge_type = incoming_edge_type

    def _add_child(self, child_node: "TLACDCInterpNode"):
        """Use the method on TLACDCCorrespondence instead of this one"""
        self.children.append(child_node)

    def _add_parent(self, parent_node: "TLACDCInterpNode"):
        """Use the method on TLACDCCorrespondence instead of this one"""
        self.parents.append(parent_node)

    def __repr__(self):
        return f"TLACDCInterpNode({self.name}, {self.index})"

    def __str__(self) -> str:
        # BUG FIX: removed a dead `index_str` local that was computed but never
        # used (and crashed on indices with fewer than 3 entries).
        return f"{self.name}{self.index}"
42 |
43 | # ------------------
44 | # some munging utils
45 | # ------------------
46 |
def parse_interpnode(s: str) -> TLACDCInterpNode:
    """Parse a node string like "blocks.0.attn.hook_result[COL, COL, 9]" into a TLACDCInterpNode.

    The head index is recovered from the trailing characters of the bracketed
    part (two digits, then one digit); if neither parses, the node spans the
    whole tensor (TorchIndex([None])). Failures are printed and re-raised.
    """
    try:
        name, idx = s.split("[")
        name = name.replace("hook_resid_mid", "hook_mlp_in")
        try:
            idx = int(idx[-3:-1])  # two-digit head index, e.g. "..., 10]"
        except ValueError:  # int() failed; slicing a str never raises
            try:
                idx = int(idx[-2])  # single-digit head index
            except (ValueError, IndexError):  # not a digit, or bracket part too short
                idx = None
        return TLACDCInterpNode(name, TorchIndex([None, None, idx]) if idx is not None else TorchIndex([None]), EdgeType.ADDITION)

    except Exception as e:
        print(s, e)
        raise e
    # BUG FIX: removed an unreachable trailing `return` — the try block always
    # returns and the except block always re-raises.
65 |
def heads_to_nodes_to_mask(heads: List[Tuple[int, int]], return_dict=False):
    """For each (layer, head) pair, build the node strings to mask: the
    q/k/v input hooks and attention q/k/v hooks, then the hook_result
    entries, parsed into TLACDCInterpNodes (as a dict keyed by string when
    return_dict is True, otherwise as a list)."""
    node_strings = []
    for layer_idx, head_idx in heads:
        for letter in ["q", "k", "v"]:
            for inputting in [True, False]:
                if inputting:
                    hook = f"blocks.{layer_idx}.hook_{letter}_input"
                else:
                    hook = f"blocks.{layer_idx}.attn.hook_{letter}"
                node_strings.append(f"{hook}[COL, COL, {head_idx}]")

    for layer_idx, head_idx in heads:
        node_strings.append(f"blocks.{layer_idx}.attn.hook_result[COL, COL, {head_idx}]")

    if return_dict:
        return {s: parse_interpnode(s) for s in node_strings}
    return [parse_interpnode(s) for s in node_strings]
85 |
--------------------------------------------------------------------------------
/acdc/__init__.py:
--------------------------------------------------------------------------------
def check_transformer_lens_version():
    """Test that your TransformerLens version is up-to-date for ACDC
    by checking that `hook_mlp_in`s exist"""

    from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

    # The smallest possible config: we only need one block to exist.
    minimal_cfg = HookedTransformerConfig.from_dict(
        {
            "n_layers": 1,
            "d_model": 1,
            "n_ctx": 1,
            "d_head": 1,
            "act_fn": "gelu",
            "d_vocab": 0,
        }
    )

    from transformer_lens.HookedTransformer import HookedTransformer

    tiny_model = HookedTransformer(minimal_cfg)

    # Accessing hook_mlp_in raises if your TransformerLens is not sufficiently
    # up-to-date to provide MLP-input hooks.
    tiny_model.blocks[0].hook_mlp_in

check_transformer_lens_version()
--------------------------------------------------------------------------------
/acdc/docstring/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ArthurConmy/Automatic-Circuit-Discovery/bc99ace817974b5584b7ee203d596a8e2bbcd399/acdc/docstring/__init__.py
--------------------------------------------------------------------------------
/acdc/global_cache.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Union, Tuple, Literal, Dict
3 | from collections import OrderedDict
4 |
5 |
class GlobalCache: # this dict stores the activations from the forward pass
    """Class for managing several caches for passing activations around.

    Holds two OrderedDicts of activations — `online_cache` (the current
    forward pass) and `corrupted_cache` (activations from the corrupted
    datapoint) — plus a pair of device names, one per cache.
    """

    def __init__(self, device: Union[str, Tuple[str, str]] = "cuda"):
        # TODO find a way to make the device propagate when we to .to on the p
        # TODO make it essential first key is a str, second a TorchIndex, third a str

        if isinstance(device, str):
            device = (device, device)

        self.online_cache = OrderedDict()
        self.corrupted_cache = OrderedDict()
        # BUG FIX: was `(device, device)`, which nested the already-converted
        # tuple into ((d, d), (d, d)).
        self.device: Tuple[str, str] = device


    def clear(self, just_first_cache=False):
        """Reset the caches and release any cached GPU memory."""

        if not just_first_cache:
            self.online_cache = OrderedDict()
        else:
            raise NotImplementedError()
        # BUG FIX: __init__ takes a single (possibly tuple) device argument;
        # it was previously called with two positional arguments.
        self.__init__((self.device[0], self.device[1])) # lol

        import gc
        gc.collect()
        torch.cuda.empty_cache()

    def to(self, device, which_caches: Literal["online", "corrupted", "all"]="all"): #
        """Move the selected cache(s) to `device` and return self."""

        caches = []
        # BUG FIX: the conditions were inverted (`!= "online"` selected the
        # online cache), so e.g. which_caches="online" moved the corrupted one.
        if which_caches in ("online", "all"):
            self.device = (device, self.device[1])
            caches.append(self.online_cache)
        if which_caches in ("corrupted", "all"):
            self.device = (self.device[0], device)
            caches.append(self.corrupted_cache)

        # move all the parameters
        for cache in caches: # mutable means this works..
            # BUG FIX: Tensor.to is not in-place — rebind each entry. Also
            # removed a redundant outer loop that repeated this |cache| times.
            for k in list(cache.keys()):
                cache[k] = cache[k].to(device)

        return self
--------------------------------------------------------------------------------
/acdc/greaterthan/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ArthurConmy/Automatic-Circuit-Discovery/bc99ace817974b5584b7ee203d596a8e2bbcd399/acdc/greaterthan/__init__.py
--------------------------------------------------------------------------------
/acdc/induction/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ArthurConmy/Automatic-Circuit-Discovery/bc99ace817974b5584b7ee203d596a8e2bbcd399/acdc/induction/__init__.py
--------------------------------------------------------------------------------
/acdc/induction/utils.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from functools import partial
3 | from acdc.docstring.utils import AllDataThings
4 | import wandb
5 | import os
6 | from collections import defaultdict
7 | import pickle
8 | import torch
9 | import huggingface_hub
10 | import datetime
11 | from typing import Dict, Callable
12 | import torch
13 | import random
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | from typing import (
17 | List,
18 | Tuple,
19 | Dict,
20 | Any,
21 | Optional,
22 | )
23 | import warnings
24 | import networkx as nx
25 | from acdc.acdc_utils import (
26 | MatchNLLMetric,
27 | make_nd_dict,
28 | shuffle_tensor,
29 | )
30 |
31 | from acdc.TLACDCEdge import (
32 | TorchIndex,
33 | Edge,
34 | EdgeType,
35 | ) # these introduce several important classes !!!
36 | from transformer_lens import HookedTransformer
37 | from acdc.acdc_utils import kl_divergence, negative_log_probs
38 |
def get_model(device):
    """Load Redwood's 2-layer attention-only model, configured for ACDC."""
    model = HookedTransformer.from_pretrained(
        "redwood_attn_2l", # load Redwood's model
        # This model is a Shortformer, so these technical details are needed:
        center_writing_weights=False,
        center_unembed=False,
        fold_ln=False,
        device=device,
    )

    # standard ACDC options
    model.set_use_attn_result(True)
    model.set_use_split_qkv_input(True)
    return model
52 |
def get_validation_data(num_examples=None, seq_len=None, device=None):
    """Download and load the validation token tensor for the induction task.

    Returns the full tensor if num_examples is None; otherwise the first
    `num_examples` rows truncated to `seq_len` tokens each.
    """
    validation_fname = huggingface_hub.hf_hub_download(
        repo_id="ArthurConmy/redwood_attn_2l", filename="validation_data.pt"
    )
    validation_data = torch.load(validation_fname, map_location=device).long()

    if num_examples is None:
        return validation_data
    # BUG FIX: was `[:num_examples][:seq_len]`, which sliced the batch
    # dimension twice and ignored seq_len; slice batch and sequence dims
    # (consistent with get_mask_repeat_candidates).
    return validation_data[:num_examples, :seq_len]
63 |
def get_good_induction_candidates(num_examples=None, seq_len=None, device=None):
    """Not needed?

    Download and load the good-induction-candidates tensor; optionally
    truncated to the first `num_examples` rows and `seq_len` positions.
    """
    good_induction_candidates_fname = huggingface_hub.hf_hub_download(
        repo_id="ArthurConmy/redwood_attn_2l", filename="good_induction_candidates.pt"
    )
    good_induction_candidates = torch.load(good_induction_candidates_fname, map_location=device)

    if num_examples is None:
        return good_induction_candidates
    # BUG FIX: was `[:num_examples][:seq_len]`, which sliced the batch
    # dimension twice and ignored seq_len; slice batch and sequence dims
    # (consistent with get_mask_repeat_candidates).
    return good_induction_candidates[:num_examples, :seq_len]
75 |
def get_mask_repeat_candidates(num_examples=None, seq_len=None, device=None):
    """Load the mask of repeat-token (induction) candidate positions,
    truncated to the first `num_examples` rows and `seq_len` positions
    unless num_examples is None."""
    fname = huggingface_hub.hf_hub_download(
        repo_id="ArthurConmy/redwood_attn_2l", filename="mask_repeat_candidates.pkl"
    )
    mask = torch.load(fname, map_location=device)
    mask.requires_grad = False

    if num_examples is None:
        return mask
    return mask[:num_examples, :seq_len]
87 |
88 |
def get_all_induction_things(num_examples, seq_len, device, data_seed=42, metric="kl_div", return_one_element=True) -> AllDataThings:
    """Assemble model, data, and metrics for running ACDC on the induction task.

    Loads Redwood's 2-layer model, takes the first `num_examples` rows of the
    validation data as the validation split and the next `num_examples` rows as
    the test split (each truncated to `seq_len` tokens, with next-token labels),
    builds shuffled "patch" versions of each split, and wires up the requested
    metric against the clean model's log-probs.

    Params:
        num_examples: rows per split (validation and test)
        seq_len: sequence length; must be < the stored sequence length
        device: device to load model and data onto
        data_seed: seed for shuffling the patch data
        metric: "kl_div", "nll", or "match_nll" — anything else raises ValueError
        return_one_element: forwarded to the KL metric only
    """
    tl_model = get_model(device=device)

    validation_data_orig = get_validation_data(device=device)
    mask_orig = get_mask_repeat_candidates(num_examples=None, device=device) # None so we get all
    assert validation_data_orig.shape == mask_orig.shape

    # need room for the shifted labels at position seq_len+1
    assert seq_len <= validation_data_orig.shape[1]-1

    validation_slice = slice(0, num_examples)
    validation_data = validation_data_orig[validation_slice, :seq_len].contiguous()
    validation_labels = validation_data_orig[validation_slice, 1:seq_len+1].contiguous()
    validation_mask = mask_orig[validation_slice, :seq_len].contiguous()

    validation_patch_data = shuffle_tensor(validation_data, seed=data_seed).contiguous()

    test_slice = slice(num_examples, num_examples*2)
    test_data = validation_data_orig[test_slice, :seq_len].contiguous()
    test_labels = validation_data_orig[test_slice, 1:seq_len+1].contiguous()
    test_mask = mask_orig[test_slice, :seq_len].contiguous()

    # NOTE(review): an earlier comment here said "data_seed+1: different
    # shuffling", but the code uses the same `data_seed` as the validation
    # patch data — confirm whether a different seed was intended.
    test_patch_data = shuffle_tensor(test_data, seed=data_seed).contiguous()

    # Clean-model log-probs, used as the reference distribution for the metrics.
    with torch.no_grad():
        base_val_logprobs = F.log_softmax(tl_model(validation_data), dim=-1).detach()
        base_test_logprobs = F.log_softmax(tl_model(test_data), dim=-1).detach()

    if metric == "kl_div":
        validation_metric = partial(
            kl_divergence,
            base_model_logprobs=base_val_logprobs,
            mask_repeat_candidates=validation_mask,
            last_seq_element_only=False,
            return_one_element=return_one_element,
        )
    elif metric == "nll":
        validation_metric = partial(
            negative_log_probs,
            labels=validation_labels,
            mask_repeat_candidates=validation_mask,
            last_seq_element_only=False,
        )
    elif metric == "match_nll":
        validation_metric = MatchNLLMetric(
            labels=validation_labels, base_model_logprobs=base_val_logprobs, mask_repeat_candidates=validation_mask,
            last_seq_element_only=False,
        )
    else:
        raise ValueError(f"Unknown metric {metric}")

    # The test split always gets all three metrics, keyed by name.
    test_metrics = {
        "kl_div": partial(
            kl_divergence,
            base_model_logprobs=base_test_logprobs,
            mask_repeat_candidates=test_mask,
            last_seq_element_only=False,
        ),
        "nll": partial(
            negative_log_probs,
            labels=test_labels,
            mask_repeat_candidates=test_mask,
            last_seq_element_only=False,
        ),
        "match_nll": MatchNLLMetric(
            labels=test_labels, base_model_logprobs=base_test_logprobs, mask_repeat_candidates=test_mask,
            last_seq_element_only=False,
        ),
    }
    return AllDataThings(
        tl_model=tl_model,
        validation_metric=validation_metric,
        validation_data=validation_data,
        validation_labels=validation_labels,
        validation_mask=validation_mask,
        validation_patch_data=validation_patch_data,
        test_metrics=test_metrics,
        test_data=test_data,
        test_labels=test_labels,
        test_mask=test_mask,
        test_patch_data=test_patch_data,
    )
171 |
172 |
def one_item_per_batch(toks_int_values, toks_int_values_other, mask_rep, base_model_logprobs, kl_take_mean=True):
    """Returns each instance of induction as its own batch idx"""

    batch_size, seq_len = toks_int_values.shape
    device = toks_int_values.device

    end_positions = []
    data_rows = []
    patch_rows = []
    logprob_rows = []

    # One output row per masked (example, position) pair. The last position is
    # skipped: there is no following token, so no loss can be computed there.
    for row in range(batch_size):
        for pos in range(seq_len - 1):
            if mask_rep[row, pos]:
                end_positions.append(pos)
                data_rows.append(toks_int_values[row].cpu().clone())
                patch_rows.append(toks_int_values_other[row].cpu().clone())
                logprob_rows.append(base_model_logprobs[row].cpu().clone())

    toks_int_values_other_batch = torch.stack(patch_rows).to(device).clone()
    return_tensor = torch.stack(data_rows).to(device).clone()
    end_positions_tensor = torch.tensor(end_positions).long()

    # Keep only the logprobs at each row's end position.
    stacked_logprobs = torch.stack(logprob_rows)
    row_indices = torch.arange(len(end_positions_tensor))
    new_base_model_logprobs = stacked_logprobs[row_indices, end_positions_tensor].to(device).clone()

    metric = partial(
        kl_divergence,
        base_model_logprobs=new_base_model_logprobs,
        end_positions=end_positions_tensor,
        mask_repeat_candidates=None,  # positions are given explicitly, so no mask
        last_seq_element_only=False,
        return_one_element=False
    )

    return return_tensor, toks_int_values_other_batch, end_positions_tensor, metric
206 |
--------------------------------------------------------------------------------
/acdc/ioi/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ArthurConmy/Automatic-Circuit-Discovery/bc99ace817974b5584b7ee203d596a8e2bbcd399/acdc/ioi/__init__.py
--------------------------------------------------------------------------------
/acdc/logic_gates/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ArthurConmy/Automatic-Circuit-Discovery/bc99ace817974b5584b7ee203d596a8e2bbcd399/acdc/logic_gates/__init__.py
--------------------------------------------------------------------------------
/acdc/logic_gates/utils.py:
--------------------------------------------------------------------------------
1 | #%%
2 |
3 | from functools import partial
4 | import time
5 | import torch
6 | from typing import Literal, Optional
7 | from transformer_lens.HookedTransformer import HookedTransformer, HookedTransformerConfig
8 | from acdc.docstring.utils import AllDataThings
9 | from acdc.tracr_task.utils import get_perm
10 | from acdc.acdc_utils import kl_divergence
11 | import torch.nn.functional as F
12 |
13 | MAX_LOGIC_GATE_SEQ_LEN = 100_000 # Can be increased further provided numerics and memory do not explode
14 |
def get_logic_gate_model(mode: Literal["OR", "AND"] = "OR", seq_len: Optional[int]=None, device="cuda") -> HookedTransformer:
    """Construct a tiny transformer with hand-set weights computing a logic gate.

    Args:
        mode: "OR" (two attention heads each write a constant 1; the MLP ORs
            them) or "AND" (uniform attention detects whether every token is 1).
        seq_len: sequence length; must be 1 for "OR" mode.
        device: device to place the model on.

    Returns:
        A zero-gradient, float64 HookedTransformer with hand-written weights.

    Raises:
        ValueError: if ``mode`` is not "OR" or "AND".
    """
    # BUGFIX: this guard previously read `if seq_len is None:`, so omitting
    # seq_len made the assert compare None against ints and raise TypeError.
    # The bound check is only meaningful for explicitly-passed lengths.
    if seq_len is not None:
        assert 1 <= seq_len <= MAX_LOGIC_GATE_SEQ_LEN, "We need some bound on sequence length, but this can be increased if the variable at the top is increased"

    if mode == "OR":
        assert seq_len == 1
        cfg = HookedTransformerConfig.from_dict(
            {
                "n_layers": 1,
                "d_model": 2,
                "n_ctx": 1,
                "n_heads": 2,
                "d_head": 1,
                "act_fn": "relu",
                "d_vocab": 1,
                "d_mlp": 1,
                "d_vocab_out": 1,
                "normalization_type": None,
                "attn_only": False,
            }
        )
    elif mode == "AND":
        cfg = HookedTransformerConfig.from_dict(
            {
                "n_layers": 1,
                "d_model": 3,
                "n_ctx": seq_len,
                "n_heads": 1,
                "d_head": 1,
                "act_fn": "relu",
                "d_vocab": 2,
                "d_mlp": 1,
                "d_vocab_out": 1,
                "normalization_type": None,
            }
        )
    else:
        raise ValueError(f"mode {mode} not recognized")

    model = HookedTransformer(cfg).to(device)
    model.set_use_attn_result(True)
    model.set_use_split_qkv_input(True)
    if "use_hook_mlp_in" in model.cfg.to_dict():
        model.set_use_hook_mlp_in(True)
    model = model.to(torch.double)

    # Turn off model gradient so we can edit weights
    # And also set all the weights to 0
    for param in model.parameters():
        param.requires_grad = False
        param[:] = 0.0

    if mode == "AND":
        # Embed 1s as 1.0 in residual component 0
        model.embed.W_E[1, 0] = 1.0

        # No QK so uniform attention; this allows us to detect if everything is a 1 as the output into the channel 1 will be 1 not less than that

        # Output 1.0 into residual component 1 for all things present
        model.blocks[0].attn.W_V[0, 0, 0] = 1.0 # Shape [head_index d_model d_head]
        model.blocks[0].attn.W_O[0, 0, 1] = 1.0 # Shape [head_index d_head d_model]

        model.blocks[0].mlp.W_in[1, 0] = 1.0 # [d_model d_mlp]
        model.blocks[0].mlp.b_in[:] = -(MAX_LOGIC_GATE_SEQ_LEN-1)/MAX_LOGIC_GATE_SEQ_LEN # Unless everything in input is a 1, do not fire

        # Write the output to residual component 2
        # (TODO: I think we could get away with 2 components here?)
        model.blocks[0].mlp.W_out[0, 2] = MAX_LOGIC_GATE_SEQ_LEN # Shape [d_mlp d_model]

        model.unembed.W_U[2, 0] = 1.0 # Shape [d_model d_vocab_out]

    elif mode == "OR":

        # a0.0 and a0.1 are the two inputs to the OR gate; they always dump 1.0 into the residual stream
        # Both heads dump a 1 into the residual stream
        # We can test our circuit recovery methods with zero ablation to see if they recover either or both heads!
        model.blocks[0].attn.b_V[:, 0] = 1.0 # [num_heads, d_head]
        model.blocks[0].attn.W_O[:, 0, 0] = 1.0 # [num_heads, d_head, d_model]

        # mlp0 is an OR gate on the output of a0.0 and a0.1; it turns the sum S of their outputs into 1 if S >= 1 and 0 if S = 0
        model.blocks[0].mlp.W_in[0, 0] = -1.0 # [d_model d_mlp]
        model.blocks[0].mlp.b_in[:] = 1.0 # [d_mlp]

        model.blocks[0].mlp.W_out[0, 1] = -1.0
        model.blocks[0].mlp.b_out[:] = 1.0 # [d_model]

        model.unembed.W_U[1, 0] = 1.0 # shape [d_model d_vocab_out]

    else:
        raise ValueError(f"mode {mode} not recognized")

    return model
108 |
def test_and_logical_model():
    """
    Test that the AND gate model outputs 1 only on the all-ones input.
    """

    seq_len = 3
    and_model = get_logic_gate_model(mode="AND", seq_len=seq_len, device="cpu")

    # Enumerate all 2**seq_len binary input sequences.
    # BUGFIX/generalization: the binary width was hard-coded as "03b", which
    # silently broke the enumeration for any seq_len != 3; it now follows
    # seq_len.
    all_inputs = []
    for i in range(2**seq_len):
        row = torch.tensor([int(x) for x in f"{i:0{seq_len}b}"]).unsqueeze(0).long()
        all_inputs.append(row)
    batch = torch.cat(all_inputs, dim=0)

    and_output = and_model(batch)[:, -1, :]
    # Every row except the final (all-ones) row must output 0; the last row
    # must output (approximately) 1.
    assert torch.equal(and_output[:2**seq_len - 1], torch.zeros(2**seq_len - 1, 1))
    torch.testing.assert_close(and_output[2**seq_len - 1], torch.ones(1).to(torch.double))
126 |
127 | #%%
128 |
def get_all_logic_gate_things(mode: str = "AND", device=None, seq_len: Optional[int] = 5, num_examples: Optional[int] = 10, return_one_element: bool = False) -> AllDataThings:
    """Bundle the logic-gate model with data and metrics for circuit discovery.

    Only ``mode="OR"`` is supported. The single input token is a dummy: the
    OR model's attention heads write constants regardless of the input, and
    zero ablation is used for patching, so the patch data is just a copy.
    """
    assert mode == "OR"

    model = get_logic_gate_model(mode=mode, seq_len=seq_len, device=device)

    # Input is meaningless for the OR model; only heads 0 and 1 matter.
    data = torch.tensor([[0.0]]).long()
    correct_answers = data.clone().to(torch.double) + 1

    def validation_metric(output, correct):
        """Squared error on the final sequence position."""
        final = output[:, -1, :]
        assert final.shape == correct.shape
        squared_err = (final - correct) ** 2
        # NOTE(review): the branching looks inverted relative to the name
        # (`return_one_element=True` yields one value per example); behavior
        # is preserved exactly as in the original — confirm intent.
        if return_one_element:
            return squared_err.squeeze(1)
        return torch.mean(squared_err, dim=0)

    base_validation_logprobs = F.log_softmax(model(data)[:, -1], dim=-1)

    test_metrics = {
        "kl_div": partial(
            kl_divergence,
            base_model_logprobs=base_validation_logprobs,
            last_seq_element_only=True,
            base_model_probs_last_seq_element_only=False,
            return_one_element=return_one_element,
        ),
    }

    return AllDataThings(
        tl_model=model,
        validation_metric=partial(validation_metric, correct=correct_answers),
        validation_data=data,
        validation_labels=None,
        validation_mask=None,
        validation_patch_data=data.clone(),  # We're doing zero ablation so irrelevant
        test_metrics=test_metrics,
        test_data=data,
        test_labels=None,
        test_mask=None,
        test_patch_data=data.clone(),
    )
171 |
172 |
173 | # # # test_logical_models()
174 | # # %%
175 |
176 | # or_model = get_logic_gate_model(seq_len=1, device = "cpu")
177 | # logits, cache = or_model.run_with_cache(
178 | # torch.tensor([[0]]).to(torch.long),
179 | # )
180 | # print(logits)
181 |
182 | # # %%
183 |
184 | # for key in cache.keys():
185 | # print(key)
186 | # print(cache[key].shape)
187 | # print(cache[key])
188 | # print("\n\n\n")
189 | # # %%
190 | # #batch pos head_index d_head for hook_q
191 | # %%
192 |
--------------------------------------------------------------------------------
/acdc/tracr_task/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ArthurConmy/Automatic-Circuit-Discovery/bc99ace817974b5584b7ee203d596a8e2bbcd399/acdc/tracr_task/__init__.py
--------------------------------------------------------------------------------
/assets/acdc_finds_subgraph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ArthurConmy/Automatic-Circuit-Discovery/bc99ace817974b5584b7ee203d596a8e2bbcd399/assets/acdc_finds_subgraph.png
--------------------------------------------------------------------------------
/experiments/all_roc_plots.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | from experiments.launcher import KubernetesJob, launch
3 | import numpy as np
4 | import random
5 | from typing import List
6 |
7 |
# Tasks to sweep over when generating ROC plots.
TASKS = ["ioi", "docstring", "greaterthan", "tracr-reverse", "tracr-proportion"]

# Metrics available per task; "induction" is listed here but absent from TASKS.
METRICS_FOR_TASK = {
    "ioi": ["kl_div", "logit_diff"],
    "tracr-reverse": ["kl_div"],
    "tracr-proportion": ["kl_div", "l2"],
    "induction": ["kl_div", "nll"],
    "docstring": ["kl_div", "docstring_metric"],
    "greaterthan": ["kl_div", "greaterthan"],
}
18 |
19 |
def main():
    """Enqueue one roc_plot_generator invocation per configuration.

    Sweeps algorithm x reset_network x zero_ablation x task x metric and
    launches the resulting commands locally (no Kubernetes, not synchronous).
    """
    commands = []
    for alg in ["16h", "sp", "acdc"]:
        for reset_network in [0, 1]:
            for zero_ablation in [0, 1]:
                for task in TASKS:
                    for metric in METRICS_FOR_TASK[task]:
                        base_cmd = [
                            "python",
                            "notebooks/roc_plot_generator.py",
                            f"--task={task}",
                            f"--reset-network={reset_network}",
                            f"--metric={metric}",
                            f"--alg={alg}",
                        ]
                        extra = ["--zero-ablation"] if zero_ablation else []
                        commands.append(base_cmd + extra)

    launch(commands, name="plots", job=None, synchronous=False)
40 |
41 |
# Script entry point; also importable as a module without side effects.
if __name__ == "__main__":
    main()
44 |
--------------------------------------------------------------------------------
/experiments/collect_data.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import argparse
3 | import os
4 | from pathlib import Path
5 | from experiments.launcher import KubernetesJob, WandbIdentifier, launch
6 | import shlex
7 | import random
8 |
# Heuristic: anyone whose conda env name does not start with "arthur" is
# assumed to be Adria and gets the Kubernetes/Docker code paths below.
IS_ADRIA = not str(os.environ.get("CONDA_DEFAULT_ENV")).lower().startswith("arthur")
if IS_ADRIA:
    print("WARNING: IS_ADRIA=True, using Adria's Docker container")

#TASKS = ["ioi", "docstring", "greaterthan", "tracr-reverse", "tracr-proportion", "induction"]
TASKS = ["ioi", "docstring", "greaterthan", "induction"]

# Metrics collected per task (tracr tasks kept for when they are re-enabled).
METRICS_FOR_TASK = {
    "ioi": ["kl_div", "logit_diff"],
    "tracr-reverse": ["l2"],
    "tracr-proportion": ["l2"],
    "induction": ["kl_div", "nll"],
    "docstring": ["kl_div", "docstring_metric"],
    "greaterthan": ["kl_div", "greaterthan"],
}
24 |
25 |
def main(
    alg: str,
    task: str,
    job: KubernetesJob,
    testing: bool = False,
    mod_idx=0,
    num_processes=1,
):
    """Collect plot data for one (alg, task) pair across all configurations.

    Builds a roc_plot_generator command per (reset_network, zero_ablation,
    metric) combination, then either submits them via `launch` (Adria's
    Kubernetes path) or runs a strided subset locally with subprocess.
    `mod_idx`/`num_processes` shard the command list across processes.
    """
    # mod_idx= MPI.COMM_WORLD.Get_rank()
    # num_processes = MPI.COMM_WORLD.Get_size()

    # Output paths differ per user: Adria writes under $HOME/.cache, Arthur
    # writes into the repo tree.
    if IS_ADRIA:
        OUT_RELPATH = Path(".cache") / "plots_data"
        OUT_HOME_DIR = Path(os.environ["HOME"]) / OUT_RELPATH
    else:
        OUT_RELPATH = Path("experiments/results/arthur_plots_data") # trying to remove extra things from acdc/
        OUT_HOME_DIR = OUT_RELPATH

    # NOTE(review): assert is stripped under `python -O`; consider raising
    # FileNotFoundError explicitly if this check must always run.
    assert OUT_HOME_DIR.exists()

    # OUT_DIR is the path as seen from inside the container (Adria) or
    # relative to the repo (Arthur).
    if IS_ADRIA:
        OUT_DIR = Path("/root") / OUT_RELPATH
    else:
        OUT_DIR = OUT_RELPATH

    # Fixed seed so per-command --seed values are reproducible across reruns.
    seed = 1233778640
    random.seed(seed)

    commands = []
    for reset_network in [0, 1]:
        for zero_ablation in [0, 1]:
            for metric in METRICS_FOR_TASK[task]:
                # Canonical circuits don't exist for induction / KL runs.
                if alg == "canonical" and (task == "induction" or metric == "kl_div"):
                    continue

                command = [
                    "python",
                    "notebooks/roc_plot_generator.py",
                    f"--task={task}",
                    f"--reset-network={reset_network}",
                    f"--metric={metric}",
                    f"--alg={alg}",
                    f"--device={'cpu' if testing or not job.gpu else 'cuda'}",
                    f"--torch-num-threads={job.cpu}",
                    f"--out-dir={OUT_DIR}",
                    f"--seed={random.randint(0, 2**31-1)}",
                ]
                if zero_ablation:
                    command.append("--zero-ablation")

                # This one run is known to have a missing score; skip the check.
                if alg == "acdc" and task == "greaterthan" and metric == "kl_div" and not zero_ablation and not reset_network:
                    command.append("--ignore-missing-score")
                commands.append(command)

    if IS_ADRIA:
        launch(
            commands,
            name="collect_data",
            job=job,
            synchronous=True,
            just_print_commands=False,
            check_wandb=WandbIdentifier(f"agarriga-col-{alg}-{task[-5:]}-{{i:04d}}b", "collect", "acdc"),
        )

    else:
        # Local path: run every num_processes-th command starting at mod_idx,
        # so several invocations of this script can share the work.
        for command_idx in range(mod_idx, len(commands), num_processes): # commands:
            # run 4 in parallel
            command = commands[command_idx]
            print(f"Running command {command_idx} / {len(commands)}")
            print(" ".join(command))
            # NOTE(review): subprocess.run without check=True ignores failures.
            subprocess.run(command)
97 |
98 |
# Which tasks each algorithm is run on.
tasks_for = {
    "acdc": ["ioi", "greaterthan"],
    "16h": TASKS,
    "sp": TASKS,
    "canonical": TASKS,
}

if __name__ == "__main__":
    # BUGFIX: argument parsing previously ran at import time (module level),
    # which made this module un-importable from any program with different CLI
    # flags and triggered side effects on import. It now runs only as a script.
    parser = argparse.ArgumentParser()
    parser.add_argument("--i", type=int, default=0)
    parser.add_argument("--n", type=int, default=1)

    args = parser.parse_args()
    mod_idx = args.i
    num_processes = args.n

    for alg in ["acdc"]: # , "16h", "sp", "canonical"]:
        for task in tasks_for[alg]:
            main(
                alg,
                task,
                KubernetesJob(
                    container="ghcr.io/rhaps0dy/automatic-circuit-discovery:e1884e4",
                    cpu=6,
                    gpu=0 if not IS_ADRIA or task.startswith("tracr") or alg not in ["acdc", "canonical"] else 1,
                    mount_training=False,
                ),
                testing=False,
                mod_idx=mod_idx,
                num_processes=num_processes,
            )
130 |
--------------------------------------------------------------------------------
/experiments/docstring_results.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import pandas as pd
4 | import numpy as np
5 | import plotly.express as px
6 | import wandb
7 |
8 | import hashlib
9 | import os
10 |
11 | import plotly.io as pio
12 | import plotly.graph_objects as go
13 | from plotly.subplots import make_subplots
14 |
15 |
class EmacsRenderer(pio.base_renderers.ColabRenderer):
    """Plotly renderer that saves each figure to an HTML file and returns a
    link, for viewing figures from Emacs (ob-jupyter)."""

    save_dir = "ob-jupyter"
    base_url = "http://localhost:8888/files"  # no interpolation needed; was a pointless f-string

    def to_mimebundle(self, fig_dict):
        html = super().to_mimebundle(fig_dict)["text/html"]

        # Content-addressed filename so identical figures reuse one file.
        mhash = hashlib.md5(html.encode("utf-8")).hexdigest()
        # BUGFIX: replaces the check-then-mkdir pattern, which raced with
        # concurrent renders and failed if a parent directory was missing.
        os.makedirs(self.save_dir, exist_ok=True)
        fhtml = os.path.join(self.save_dir, mhash + ".html")
        with open(fhtml, "w") as f:
            f.write(html)

        return {"text/html": f'Click to open {fhtml}'}
31 |
32 |
33 | pio.renderers["emacs"] = EmacsRenderer()
34 |
35 |
def set_plotly_renderer(renderer="emacs"):
    """Set the global default plotly renderer (defaults to the Emacs one)."""
    pio.renderers.default = renderer
38 |
39 |
40 | set_plotly_renderer("emacs")
41 |
# wandb group names for the ACDC and subnetwork-probing sweeps to compare.
ACDC_GROUP = "adria-docstring3"
SP_GROUP = "docstring3"

# %%

# Pull every ACDC run in the group and collect its config + summary metrics
# into one DataFrame row, indexed by the numeric suffix of the run name.
api = wandb.Api()
all_runs = api.runs(path="remix_school-of-rock/acdc", filters={"group": ACDC_GROUP})

df = pd.DataFrame()
for r in all_runs:
    try:
        cfg = {k: r.config[k] for k in ["reset_network", "zero_ablation", "metric", "task", "threshold"]}
        d = {
            k: r.summary[k]
            for k in [
                "cur_metric",
                "num_edges",
                "test_docstring_metric",
                "test_docstring_stefan",
                "test_kl_div",
                "test_match_nll",
                "test_nll",
            ]
        }
    except KeyError as e:
        # Runs missing any field (e.g. crashed before logging) are skipped.
        print("problems with run ", r.name, e)
        continue
    d = dict(**cfg, **d)
    d["alg"] = "acdc"

    idx = int(r.name.split("-")[-1])
    df = pd.concat([df, pd.DataFrame(d, index=[idx])])
74 |
# %% Now subnetwork-probing runs

# Offset SP indices past the ACDC ones so the two sweeps don't collide.
# NOTE(review): annotated float, but used as an integer index offset; the
# resulting indices may become floats — confirm this is intended.
start_idx: float = df.index.max() + 1

all_runs = api.runs(path="remix_school-of-rock/induction-sp-replicate", filters={"group": SP_GROUP})
for r in all_runs:
    try:
        cfg = {k: r.config[k] for k in ["reset_subject", "zero_ablation", "loss_type", "lambda_reg"]}
        d = {
            k: r.summary[k]
            for k in [
                "number_of_edges",
                "specific_metric",
                "test_docstring_metric",
                "test_docstring_stefan",
                "test_kl_div",
                "test_match_nll",
                "test_nll",
            ]
        }
    except KeyError as e:
        print("problems with run ", r.name, e)
        continue
    # Rename SP's config keys to match the ACDC column names above.
    cfg["metric"] = cfg["loss_type"]
    del cfg["loss_type"]
    cfg["reset_network"] = cfg["reset_subject"]
    del cfg["reset_subject"]
    cfg["num_edges"] = d["number_of_edges"]
    # SP logs summed losses; divide by the averaging count to get per-run values.
    cfg["cur_metric"] = d["specific_metric"] / r.config["n_loss_average_runs"]
    for k in d.keys():
        if k.startswith("test_"):
            cfg[k] = d[k] / r.config["n_loss_average_runs"]
    cfg["alg"] = "subnetwork-probing"

    idx = int(r.name.split("-")[-1]) + start_idx
    df = pd.concat([df, pd.DataFrame(cfg, index=[idx])])
111 |
112 | # %%
113 |
114 | df.loc[:, "color"] = df.apply(lambda x: f"{x['alg']}-reset={x['reset_network']:.0f}", axis=1)
115 |
116 | # Scatter plot of num_edges vs cur_metric grouped by reset_network
117 |
118 | fig = px.scatter(
119 | df,
120 | x="num_edges",
121 | y="cur_metric",
122 | color="color",
123 | color_discrete_map={
124 | "acdc-reset=0": "red",
125 | "acdc-reset=1": "blue",
126 | "subnetwork-probing-reset=0": "orange",
127 | "subnetwork-probing-reset=1": "green",
128 | },
129 | facet_col="zero_ablation",
130 | facet_row="metric",
131 | facet_col_wrap=2,
132 | hover_data=["threshold", "lambda_reg"],
133 | title="Induction, TRAIN metric",
134 | )
135 | fig.show()
136 |
137 | # %%
138 |
139 | for test_metric in ["test_docstring_metric", "test_docstring_stefan", "test_kl_div", "test_match_nll", "test_nll"]:
140 | fig = px.scatter(
141 | df,
142 | x="num_edges",
143 | y=test_metric,
144 | color="color",
145 | color_discrete_map={
146 | "acdc-reset=0": "red",
147 | "acdc-reset=1": "blue",
148 | "subnetwork-probing-reset=0": "orange",
149 | "subnetwork-probing-reset=1": "green",
150 | },
151 | facet_col="zero_ablation",
152 | facet_row="metric",
153 | facet_col_wrap=2,
154 | hover_data=["threshold", "lambda_reg"],
155 | title=f"Induction, {test_metric} metric",
156 | )
157 | fig.show()
158 |
159 |
# %% Scatter plot for train vs test of every metric
fig = make_subplots()

for test_metric in ["test_docstring_metric", "test_docstring_stefan", "test_kl_div", "test_match_nll", "test_nll"]:
    # BUGFIX: str.lstrip("test_") strips *characters* from the set {t,e,s,_},
    # not the literal prefix; removeprefix strips exactly "test_". (It only
    # happened to give the right answer for these metric names.)
    this_df = df[(df["metric"] == test_metric.removeprefix("test_")) & (~df["zero_ablation"])]
    trace1 = go.Scatter(
        x=this_df["cur_metric"],
        y=this_df[test_metric],
        mode="markers",
        name=test_metric,
    )
    fig.add_trace(trace1)
fig.show()
173 |
# %% Compare each metric with the other metrics

for main_metric in ["test_docstring_metric", "test_docstring_stefan", "test_kl_div", "test_match_nll", "test_nll"]:
    fig = make_subplots()
    # BUGFIX: removeprefix instead of lstrip — lstrip strips a character set,
    # not the literal "test_" prefix (accidentally correct for these names).
    this_df = df[(df["metric"] == main_metric.removeprefix("test_")) & (~df["zero_ablation"])]
    for test_metric in ["test_docstring_metric", "test_docstring_stefan", "test_kl_div", "test_match_nll", "test_nll"]:
        trace1 = go.Scatter(
            x=this_df["cur_metric"],
            y=this_df[test_metric],
            mode="markers",
            name=test_metric,
        )
        fig.add_trace(trace1)
    # set title
    fig.update_layout(
        title_text=f"Comparison of {main_metric} with other metrics",
        xaxis_title=main_metric,
        yaxis_title="other metrics",
    )
    fig.show()
194 |
--------------------------------------------------------------------------------
/experiments/induction_results.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import pandas as pd
4 | import numpy as np
5 | import plotly.express as px
6 | import wandb
7 |
8 | import hashlib
9 | import os
10 |
11 | import plotly.io as pio
12 |
13 |
class EmacsRenderer(pio.base_renderers.ColabRenderer):
    """Plotly renderer that saves each figure to an HTML file and returns a
    link, for viewing figures from Emacs (ob-jupyter)."""

    save_dir = "ob-jupyter"
    base_url = "http://localhost:8888/files"  # no interpolation needed; was a pointless f-string

    def to_mimebundle(self, fig_dict):
        html = super().to_mimebundle(fig_dict)["text/html"]

        # Content-addressed filename so identical figures reuse one file.
        mhash = hashlib.md5(html.encode("utf-8")).hexdigest()
        # BUGFIX: replaces the check-then-mkdir pattern, which raced with
        # concurrent renders and failed if a parent directory was missing.
        os.makedirs(self.save_dir, exist_ok=True)
        fhtml = os.path.join(self.save_dir, mhash + ".html")
        with open(fhtml, "w") as f:
            f.write(html)

        return {"text/html": f'Click to open {fhtml}'}
29 |
30 |
31 | pio.renderers["emacs"] = EmacsRenderer()
32 |
33 |
def set_plotly_renderer(renderer="emacs"):
    """Set the global default plotly renderer (defaults to the Emacs one)."""
    pio.renderers.default = renderer
36 |
37 | set_plotly_renderer("emacs")
38 |
# wandb groups for the ACDC and subnetwork-probing sweeps being compared.
ACDC_GROUP = "adria-induction-3"
# ACDC_GROUP = "adria-docstring3"
SP_GROUP = "reset-with-nll-21"
# SP_GROUP = "docstring3"

# %%

# Pull every ACDC run in the group; rows indexed by the run-name suffix.
api = wandb.Api()
all_runs = api.runs(path="remix_school-of-rock/acdc", filters={"group": ACDC_GROUP})

df = pd.DataFrame()

total = len(all_runs)
failed = 0

for r in all_runs:
    try:
        d = {k: r.summary[k] for k in ["cur_metric", "test_specific_metric", "num_edges"]}
    except KeyError:
        # Runs missing summary fields (e.g. crashed) are counted, not kept.
        failed+=1
    else:
        idx = int(r.name.split("-")[-1])
        df = pd.concat([df, pd.DataFrame(d, index=[idx])])

# Sanity check: fail loudly if most runs were unusable.
# NOTE(review): divides by zero if the group has no runs at all.
assert failed/total < 0.5
64 |
65 | # %%
66 |
# 21 log-spaced thresholds from 10^-2 to 10^0.5, matching the launcher sweep.
thresholds = 10 ** np.linspace(-2, 0.5, 21)

# Reconstruct each run's configuration from its position in the sweep grid.
# NOTE(review): this assumes run indices (from the run-name suffix) enumerate
# the grid in exactly this nesting order — fragile; confirm against the
# launcher that produced ACDC_GROUP.
i = 0
for reset_network in [0, 1]:
    for zero_ablation in [0, 1]:
        for loss_type in ["kl_div", "nll", "match_nll"]:
            for threshold in thresholds:
                df.loc[i, "reset_network"] = reset_network
                df.loc[i, "zero_ablation"] = zero_ablation
                df.loc[i, "loss_type"] = loss_type
                df.loc[i, "threshold"] = threshold

                i += 1

df.loc[:, "alg"] = "acdc"
df.loc[:, "num_examples"] = 50
83 |
84 | # %% Now subnetwork-probing runs
85 |
sp_runs = []

# Collect subnetwork-probing runs; keys renamed to match the ACDC columns.
all_runs = api.runs(path="remix_school-of-rock/induction-sp-replicate", filters={"group": SP_GROUP})
for r in all_runs:
    try:
        cfg = {k: r.config[k] for k in ["reset_subject", "zero_ablation", "loss_type", "lambda_reg", "num_examples"]}
        d = {k: r.summary[k] for k in ["number_of_edges", "specific_metric", "test_specific_metric", "specific_metric_loss"]}
    except KeyError:
        # Silently skip runs missing any required field.
        continue
    cfg["reset_network"] = cfg["reset_subject"]
    del cfg["reset_subject"]
    cfg["num_edges"] = d["number_of_edges"]
    # SP logs summed losses; divide by the averaging count for per-run values.
    cfg["cur_metric"] = d["specific_metric"] / r.config["n_loss_average_runs"]
    cfg["test_specific_metric"] = d["test_specific_metric"] / r.config["n_loss_average_runs"]
    cfg["alg"] = "subnetwork-probing"
    sp_runs.append(cfg)


# Append SP rows after the ACDC ones so indices don't collide.
i = df.index.max() + 1
for r in sp_runs:
    df = pd.concat([df, pd.DataFrame(r, index=[i])])
    i += 1
108 |
109 | # %%
110 |
111 | df.loc[:, "color"] = df.apply(lambda x: f"{x['alg']}-reset={x['reset_network']:.0f}", axis=1)
112 |
113 | # Scatter plot of num_edges vs cur_metric grouped by reset_network
114 |
115 | fig = px.scatter(
116 | df,
117 | x="num_edges",
118 | y="cur_metric",
119 | color="color",
120 | color_discrete_map={
121 | "acdc-reset=0": "red",
122 | "acdc-reset=1": "blue",
123 | "subnetwork-probing-reset=0": "orange",
124 | "subnetwork-probing-reset=1": "green",
125 | },
126 | facet_col="zero_ablation",
127 | facet_row="loss_type",
128 | facet_col_wrap=2,
129 | hover_data=["threshold", "lambda_reg"],
130 | title="Induction, TRAIN metric",
131 | )
132 | fig.show()
133 |
134 | # %%
135 |
136 | fig = px.scatter(
137 | df,
138 | x="num_edges",
139 | y="test_specific_metric",
140 | color="color",
141 | color_discrete_map={
142 | "acdc-reset=0": "red",
143 | "acdc-reset=1": "blue",
144 | "subnetwork-probing-reset=0": "orange",
145 | "subnetwork-probing-reset=1": "green",
146 | },
147 | facet_col="zero_ablation",
148 | facet_row="loss_type",
149 | facet_col_wrap=2,
150 | hover_data=["threshold", "lambda_reg"],
151 | title="Induction, TEST metric",
152 | )
153 | fig.show()
154 |
--------------------------------------------------------------------------------
/experiments/launch_abstract.py:
--------------------------------------------------------------------------------
1 | from experiments.launcher import KubernetesJob, WandbIdentifier, launch
2 | import numpy as np
3 | import random
4 | from typing import List
5 |
def main(use_kubernetes: bool, testing: bool, CPU: int = 4):
    """Launch the docstring ACDC runs used for the paper's abstract figure.

    Args:
        use_kubernetes: submit via KubernetesJob instead of running locally.
        testing: if True, cap each run at a single epoch as a smoke test.
        CPU: CPU cores to request (also used as the torch thread count).
    """
    # BUGFIX: a leftover `testing = True` here used to shadow the parameter,
    # so every launch ran with --max-num-epochs=1 even when the caller passed
    # testing=False.
    task = "docstring"
    reset_network = 0
    kwargses = [
        {"threshold": 0.067, "metric": "docstring_metric"},
        {"threshold": 0.005, "metric": "kl_div"},
        {"threshold": 0.095, "metric": "kl_div"},
    ]

    # Fixed seed so per-run --seed values are reproducible across launches.
    seed = 1495006036
    random.seed(seed)

    wandb_identifier = WandbIdentifier(
        run_name="abstract-{i:05d}",
        group_name="abstract",
        project="acdc",
    )

    commands: List[List[str]] = []
    for kwargs in kwargses:
        command = [
            "python",
            "acdc/main.py",
            f"--task={task}",
            f"--threshold={kwargs['threshold']:.5f}",
            "--using-wandb",
            f"--wandb-run-name={wandb_identifier.run_name.format(i=len(commands))}",
            f"--wandb-group-name={wandb_identifier.group_name}",
            f"--wandb-project-name={wandb_identifier.project}",
            "--device=cuda",
            f"--torch-num-threads={CPU}",
            f"--reset-network={reset_network}",
            f"--seed={random.randint(0, 2**32 - 1)}",
            f"--metric={kwargs['metric']}",
            "--wandb-dir=/root/.cache/huggingface/tracr-training/acdc", # If it doesn't exist wandb will use /tmp
            f"--wandb-mode=online",
            f"--max-num-epochs={1 if testing else 40_000}",
        ]
        commands.append(command)

    launch(
        commands,
        name="acdc-docstring-abstract",
        job=None
        if not use_kubernetes
        else KubernetesJob(container="ghcr.io/rhaps0dy/automatic-circuit-discovery:1.6.1", cpu=CPU, gpu=1),
        check_wandb=wandb_identifier,
        just_print_commands=False,
    )
56 |
57 |
58 | if __name__ == "__main__":
59 | main(use_kubernetes=True, testing=False)
60 |
--------------------------------------------------------------------------------
/experiments/launch_abstract_old.py:
--------------------------------------------------------------------------------
1 | from experiments.launcher import KubernetesJob, WandbIdentifier, launch
2 | import numpy as np
3 | import random
4 |
def main(use_kubernetes: bool, testing: bool, CPU: int = 4):
    """Launch one ACDC docstring run per threshold/metric pair (old version).

    NOTE(review): the ``testing`` parameter is accepted but never used here.
    """
    task = "docstring"
    reset_network = 0
    # Threshold/metric pairs used for the abstract figure.
    kwargses = [
        {"threshold": 0.067, "metric": "docstring_metric"},
        {"threshold": 0.005, "metric": "kl_div"},
        {"threshold": 0.095, "metric": "kl_div"},
    ]

    # Fixed seed for reproducibility (random is seeded but not used below).
    seed = 1495006036
    random.seed(seed)

    # One launch() call per configuration, each with its own wandb identity.
    for kwargs in kwargses:
        wandb_identifier = WandbIdentifier(
            run_name=f"docstring_kl_{kwargs['threshold']:.5f}",
            group_name="default",
            project="acdc-abstract",
        )
        command = [
            "python",
            "acdc/main.py",
            f"--task={task}",
            f"--threshold={kwargs['threshold']:.5f}",
            "--using-wandb",
            "--wandb-entity-name=remix_school-of-rock",
            f"--wandb-run-name={wandb_identifier.run_name}",
            f"--wandb-project-name={wandb_identifier.project}",
        ]
        launch(
            [command],
            name=wandb_identifier.run_name,
            job=None
            if not use_kubernetes
            else KubernetesJob(container="ghcr.io/rhaps0dy/automatic-circuit-discovery:abstract-0.0", cpu=CPU, gpu=1),
            check_wandb=wandb_identifier,
            just_print_commands=False,
        )
42 |
43 |
44 | if __name__ == "__main__":
45 | main(use_kubernetes=True, testing=False)
46 |
--------------------------------------------------------------------------------
/experiments/launch_all_sixteen_heads.py:
--------------------------------------------------------------------------------
1 | from experiments.launcher import KubernetesJob, WandbIdentifier, launch
2 | import numpy as np
3 | import random
4 | from typing import List
5 |
# Metrics to evaluate per task in the sixteen-heads sweep.
METRICS_FOR_TASK = {
    "ioi": ["kl_div", "logit_diff"],
    "tracr-reverse": ["l2"],
    "tracr-proportion": ["kl_div", "l2"],
    "induction": ["kl_div", "nll"],
    "docstring": ["kl_div", "docstring_metric"],
    "greaterthan": ["greaterthan"], # "kl_div",
}


# CPU cores requested per job and passed as --torch-num-threads.
CPU = 2
17 |
def main(TASKS: list[str], job: KubernetesJob, name: str, group_name: str):
    """Queue sixteen-heads runs over (reset, ablation, task, metric) combos.

    Non-tracr tasks skip the (reset=0, ablation=0) cell, and "ioi"/"induction"
    additionally skip (reset=0, ablation=1); tracr tasks run every cell.
    """
    # Fixed seed so each command's --seed value is reproducible.
    random.seed(1259281515)

    wandb_identifier = WandbIdentifier(
        run_name=f"{name}-{{i:05d}}",
        group_name=group_name,
        project="acdc")

    commands: List[List[str]] = []
    for reset_network in [0, 1]:
        for zero_ablation in [0, 1]:
            for task in TASKS:
                for metric in METRICS_FOR_TASK[task]:
                    # Flattened skip logic for non-tracr tasks (see docstring).
                    skip = (
                        "tracr" not in task
                        and reset_network == 0
                        and (
                            zero_ablation == 0
                            or (zero_ablation == 1 and task in ["ioi", "induction"])
                        )
                    )
                    if skip:
                        continue

                    device = "cuda" if job.gpu else "cpu"
                    command = [
                        "python",
                        "experiments/launch_sixteen_heads.py",
                        f"--task={task}",
                        f"--wandb-run-name={wandb_identifier.run_name.format(i=len(commands))}",
                        f"--wandb-group={wandb_identifier.group_name}",
                        f"--wandb-project={wandb_identifier.project}",
                        f"--device={device}",
                        f"--reset-network={reset_network}",
                        f"--seed={random.randint(0, 2**32 - 1)}",
                        f"--metric={metric}",
                        f"--torch-num-threads={CPU}",
                        "--wandb-dir=/root/.cache/huggingface/tracr-training/16heads", # If it doesn't exist wandb will use /tmp
                        f"--wandb-mode=online",
                    ]
                    if zero_ablation:
                        command.append("--zero-ablation")
                    commands.append(command)

    launch(
        commands,
        name=wandb_identifier.run_name,
        job=job,
        check_wandb=wandb_identifier,
        just_print_commands=False,
        synchronous=True,
    )
67 |
68 |
if __name__ == "__main__":
    # Current sweep: tracr-reverse only, CPU-only container.
    main(
        # ["ioi", "greaterthan", "induction", "docstring"],
        ["tracr-reverse"],
        KubernetesJob(container="ghcr.io/rhaps0dy/automatic-circuit-discovery:1.7.1", cpu=CPU, gpu=0),
        "16h-redo",
        group_name="sixteen-heads",
    )
    # main(
    #     ["tracr-reverse", "tracr-proportion"],
    #     KubernetesJob(container="ghcr.io/rhaps0dy/automatic-circuit-discovery:1.6.1", cpu=4, gpu=0),
    #     "16h-tracr",
    # )
82 |
--------------------------------------------------------------------------------
/experiments/launch_docstring.py:
--------------------------------------------------------------------------------
1 | from experiments.launcher import KubernetesJob, launch
2 | import numpy as np
3 |
4 | CPU = 4
5 |
6 |
def main(testing: bool):
    """Launch the ACDC docstring-task sweep over thresholds, losses and ablations.

    When `testing`, a single threshold and a single epoch are run locally with
    wandb offline; otherwise the full sweep is submitted to Kubernetes.
    """
    # 21 log-spaced thresholds in [1e-2, 10**0.5].
    thresholds = 10 ** np.linspace(-2, 0.5, 21)
    seed = 516626229  # same seed for every run; variation comes from the other flags

    commands: list[list[str]] = []
    for reset_network in [0]:
        for zero_ablation in [0, 1]:
            for loss_type in ["kl_div", "docstring_metric", "docstring_stefan", "nll", "match_nll"]:
                for threshold in [1.0] if testing else thresholds:
                    command = [
                        "python",
                        # Inside the container the repo lives at /Automatic-Circuit-Discovery.
                        "acdc/main.py" if testing else "/Automatic-Circuit-Discovery/acdc/main.py",
                        "--task=docstring",
                        f"--threshold={threshold:.5f}",
                        "--using-wandb",
                        f"--wandb-run-name=agarriga-docstring-{len(commands):03d}",
                        "--wandb-group-name=adria-docstring",
                        "--device=cpu",
                        f"--reset-network={reset_network}",
                        f"--seed={seed}",
                        f"--metric={loss_type}",
                        f"--torch-num-threads={CPU}",
                        "--wandb-dir=/training/acdc",  # If it doesn't exist wandb will use /tmp
                        f"--wandb-mode={'offline' if testing else 'online'}",
                        f"--max-num-epochs={1 if testing else 100_000}",
                    ]
                    if zero_ablation:
                        command.append("--zero-ablation")

                    commands.append(command)

    launch(
        commands,
        name="acdc-docstring",
        job=None
        if testing
        else KubernetesJob(container="ghcr.io/rhaps0dy/automatic-circuit-discovery:1.2.10", cpu=CPU, gpu=0),
    )
45 |
46 |
if __name__ == "__main__":
    # Submit the full (non-testing) sweep.
    main(testing=False)
49 |
--------------------------------------------------------------------------------
/experiments/launch_induction.py:
--------------------------------------------------------------------------------
1 | from experiments.launcher import KubernetesJob, launch
2 | import subprocess
3 | import argparse
4 | import numpy as np
5 |
6 | CPU = 4
7 |
def main(
    testing: bool,
    is_adria: bool,
):
    """Launch the ACDC induction sweep.

    If `is_adria`, commands are submitted via `launch` (to Kubernetes unless
    `testing`); otherwise they are run locally, one after another.
    """
    # 21 log-spaced thresholds in [1e-2, 10**0.5].
    thresholds = 10 ** np.linspace(-2, 0.5, 21)
    seed = 424671755  # same seed for every run; variation comes from the other flags

    commands: list[list[str]] = []
    for reset_network in [0, 1]:
        for zero_ablation in [0, 1]:
            for loss_type in ["kl_div"]:
                for threshold in [1.0] if testing else thresholds:
                    command = [
                        "python",
                        # Inside the container the repo lives at /Automatic-Circuit-Discovery.
                        "acdc/main.py" if (not is_adria) else "/Automatic-Circuit-Discovery/acdc/main.py",
                        "--task=induction",
                        f"--threshold={threshold:.5f}",
                        "--using-wandb",
                        f"--wandb-run-name=agarriga-acdc-{len(commands):03d}",
                        "--wandb-group-name=adria-induction-3",
                        "--device=cpu",
                        f"--reset-network={reset_network}",
                        f"--seed={seed}",
                        f"--metric={loss_type}",
                        f"--torch-num-threads={CPU}",
                        "--wandb-dir=/training/acdc",
                        f"--wandb-mode={'offline' if testing else 'online'}",
                    ]
                    if zero_ablation:
                        command.append("--zero-ablation")

                    commands.append(command)

    if is_adria:
        launch(
            commands,
            name="acdc-induction",
            job=None
            if testing
            else KubernetesJob(container="ghcr.io/rhaps0dy/automatic-circuit-discovery:1.2.8", cpu=CPU, gpu=0),
        )
    else:
        # Local fallback: run each command synchronously.
        for command in commands:
            print("Running", command)
            subprocess.run(command)
54 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--testing", action="store_true")
    parser.add_argument("--is-adria", action="store_true")
    # Parse argv exactly once (the original called parse_args() twice).
    args = parser.parse_args()
    main(
        testing=args.testing,
        is_adria=args.is_adria,
    )
63 |
--------------------------------------------------------------------------------
/experiments/launch_spreadsheet.py:
--------------------------------------------------------------------------------
1 | from experiments.launcher import KubernetesJob, WandbIdentifier, launch
2 | import numpy as np
3 | import random
4 | from typing import List
5 |
6 | METRICS_FOR_TASK = {
7 | "ioi": ["kl_div", "logit_diff"],
8 | "tracr-reverse": ["l2"],
9 | "tracr-proportion": ["l2"],
10 | "induction": ["kl_div", "nll"],
11 | "docstring": ["kl_div", "docstring_metric"],
12 | "greaterthan": ["kl_div", "greaterthan"],
13 | }
14 |
15 | CPU = 4
16 |
def main(TASKS: list[str], group_name: str, run_name: str, testing: bool, use_kubernetes: bool, reset_networks: bool, abs_value_threshold: bool, use_gpu: bool=True):
    """Launch one acdc/main.py run per (task, metric, threshold) combination.

    Threshold grids are hand-picked per task/metric to cover each metric's
    typical value range (noted inline below).
    """
    NUM_SPACINGS = 5 if reset_networks else 21  # coarser sweep for reset networks
    base_thresholds = 10 ** np.linspace(-4, 0, 21)

    seed = 486887094
    random.seed(seed)  # makes the per-run --seed draws reproducible

    wandb_identifier = WandbIdentifier(
        run_name=run_name,
        group_name=group_name,
        project="acdc")

    commands: List[List[str]] = []
    for reset_network in [int(reset_networks)]:
        for zero_ablation in [0]:
            for task in TASKS:
                for metric in METRICS_FOR_TASK[task]:

                    # Per-task configuration: threshold grid, number of
                    # examples, and sequence length (-1 presumably means
                    # "task default" — TODO confirm).
                    # NOTE(review): num_examples and seq_len are computed here
                    # but never added to `command` below — confirm whether
                    # --num-examples/--seq-len flags were intended.
                    if task.startswith("tracr"):
                        # Typical metric value range: 0.0-0.1
                        thresholds = 10 ** np.linspace(-5, -1, 21)

                        if task == "tracr-reverse":
                            num_examples = 6
                            seq_len = 5
                        elif task == "tracr-proportion":
                            num_examples = 50
                            seq_len = 5
                        else:
                            raise ValueError("Unknown task")

                    elif task == "greaterthan":
                        if metric == "kl_div":
                            # Typical metric value range: 0.0-20
                            # thresholds = 10 ** np.linspace(-4, 0, NUM_SPACINGS)
                            thresholds = 10 ** np.linspace(-6, -4, 11)
                        elif metric == "greaterthan":
                            # Typical metric value range: -1.0 - 0.0
                            thresholds = 10 ** np.linspace(-3, -1, NUM_SPACINGS)
                        else:
                            raise ValueError("Unknown metric")
                        num_examples = 100
                        seq_len = -1
                    elif task == "docstring":
                        seq_len = 41
                        if metric == "kl_div":
                            # Typical metric value range: 0.0-10.0
                            thresholds = base_thresholds
                        elif metric == "docstring_metric":
                            # Typical metric value range: -1.0 - 0.0
                            thresholds = 10 ** np.linspace(-4, 0, 21)
                        else:
                            raise ValueError("Unknown metric")
                        num_examples = 50
                    elif task == "ioi":
                        num_examples = 100
                        seq_len = -1
                        if metric == "kl_div":
                            # Typical metric value range: 0.0-12.0
                            thresholds = 10 ** np.linspace(-6, 0, 31)
                        elif metric == "logit_diff":
                            # Typical metric value range: -0.31 -- -0.01
                            thresholds = 10 ** np.linspace(-4, 0, NUM_SPACINGS)
                        else:
                            raise ValueError("Unknown metric")
                    elif task == "induction":
                        seq_len = 300
                        num_examples = 50
                        if metric == "kl_div":
                            # Typical metric value range: 0.0-16.0
                            thresholds = base_thresholds
                        elif metric == "nll":
                            # Typical metric value range: 0.0-16.0
                            thresholds = base_thresholds
                        else:
                            raise ValueError("Unknown metric")
                    else:
                        raise ValueError("Unknown task")

                    for threshold in [1.0] if testing else thresholds:
                        command = [
                            "python",
                            "acdc/main.py",
                            f"--task={task}",
                            f"--threshold={threshold}",
                            "--using-wandb",
                            f"--wandb-run-name={wandb_identifier.run_name.format(i=len(commands))}",
                            f"--wandb-group-name={wandb_identifier.group_name}",
                            f"--wandb-project-name={wandb_identifier.project}",
                            # tracr models are tiny, so always run them on CPU.
                            f"--device={'cuda' if not testing else 'cpu'}" if "tracr" not in task else "--device=cpu",
                            f"--reset-network={reset_network}",
                            f"--seed={random.randint(0, 2**32 - 1)}",
                            f"--metric={metric}",
                            f"--torch-num-threads={CPU}",
                            "--wandb-dir=/root/.cache/huggingface/tracr-training/acdc", # If it doesn't exist wandb will use /tmp
                            f"--wandb-mode=online",
                            f"--max-num-epochs={1 if testing else 40_000}",
                        ]
                        if zero_ablation:
                            command.append("--zero-ablation")
                        if abs_value_threshold:
                            command.append("--abs-value-threshold")
                        commands.append(command)

    launch(
        commands,
        name="acdc-spreadsheet",
        job=None
        if not use_kubernetes
        else KubernetesJob(container="ghcr.io/rhaps0dy/automatic-circuit-discovery:181999f", cpu=CPU, gpu=int(use_gpu)),
        check_wandb=wandb_identifier,
        just_print_commands=False,
    )
130 |
131 |
if __name__ == "__main__":
    # Only trained networks (reset_networks=False) for this abs-value-threshold
    # IOI sweep on GPU.
    for reset_networks in [False]:
        main(
            TASKS=["ioi"],
            group_name="abs-value",
            run_name=f"agarriga-ioi-res{int(reset_networks)}-{{i:05d}}",
            testing=False,
            use_kubernetes=True,
            reset_networks=reset_networks,
            abs_value_threshold=True,
            use_gpu=True,
        )
144 |
145 | # if __name__ == "__main__":
146 | # for reset_networks in [False, True]:
147 | # main(
148 | # ["ioi", "greaterthan", "induction", "docstring"],
149 | # "reset-networks-neurips",
150 | # "agarriga-tracr3-{i:05d}",
151 | # testing=False,
152 | # use_kubernetes=True,
153 | # reset_networks=True,
154 | # )
155 | # main(
156 | # ["induction"],
157 | # "adria-induction-3",
158 | # "agarriga-induction-{i:05d}",
159 | # testing=False,
160 | # use_kubernetes=True,
161 | # reset_networks=False,
162 | # )
163 |
--------------------------------------------------------------------------------
/experiments/launcher.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import subprocess
3 | from typing import Optional, TextIO, List, Tuple
4 | import numpy as np
5 | import shlex
6 | import dataclasses
7 | import wandb
8 |
@dataclasses.dataclass(frozen=True)
class KubernetesJob:
    """Resource specification for a single `ctl job run` submission."""

    container: str  # container image to run
    cpu: int  # CPUs to request
    gpu: int  # GPUs to request
    mount_training: bool = False  # whether to mount the shared /training volume

    def mount_training_options(self) -> list[str]:
        """Extra `ctl` flags that mount the /training volume, if requested."""
        training_flags = [
            "--volume-mount=/training",
            "--volume-name=agarriga-models-training",
        ]
        return training_flags if self.mount_training else []
23 |
24 |
@dataclasses.dataclass(frozen=True)
class WandbIdentifier:
    """Identifies where a run lives on Weights & Biases."""

    # run_name may contain a "{i:05d}"-style placeholder that `launch` and the
    # sweep scripts fill in with the command index.
    run_name: str
    group_name: str
    project: str
30 |
31 |
def launch(commands: List[List[str]], name: str, job: Optional[KubernetesJob] = None, check_wandb: Optional[WandbIdentifier]=None, ids_for_worker=range(0, 10000000), synchronous=True, just_print_commands=False):
    """Run every command in `commands`, either locally or as Kubernetes jobs.

    Args:
        commands: argv lists, one per run.
        name: base name for jobs/log directories; when `check_wandb` is given
            it is replaced per-command by the formatted wandb run name.
        job: if given, submit each command to Kubernetes via `ctl job run`
            with these resources; otherwise run locally with subprocess.
        check_wandb: wandb identity used to derive per-run names.
        ids_for_worker: only command indices in this range are run, so several
            workers can split one command list between them.
        synchronous: for local runs, wait for each command before the next;
            otherwise run them in the background and collect logs at the end.
        just_print_commands: log what would run without running anything.
    """
    # Background local processes still to be waited on:
    # (command string, Popen handle, stdout file, stderr file).
    to_wait: List[Tuple[str, subprocess.Popen, TextIO, TextIO]] = []

    assert len(commands) <= 100_000, "Too many commands for 5 digits"

    print(f"Launching {len(commands)} jobs")
    for i, command in enumerate(commands):
        if i not in ids_for_worker:
            print(f"Skipping {name} because it's not my turn, {i} not in {ids_for_worker}")
            continue

        command_str = shlex.join(command)

        if check_wandb is not None:
            # HACK this is pretty vulnerable to duplicating work if the same run is launched in close succession,
            # it's more to be able to restart.
            # Deliberately shadows the `name` parameter so each run/job gets
            # the formatted wandb run name. (A wandb-API check for existing
            # runs used to live here; it is disabled.)
            name = check_wandb.run_name.format(i=i)

        print("Launching", name, command_str)
        if just_print_commands:
            continue

        if job is None:
            if synchronous:
                out = subprocess.run(command)
                assert out.returncode == 0, f"Command return={out.returncode} != 0"
            else:
                base_path = Path(f"/tmp/{name}")
                base_path.mkdir(parents=True, exist_ok=True)
                stdout = open(base_path / f"stdout_{i:05d}.txt", "w")
                stderr = open(base_path / f"stderr_{i:05d}.txt", "w")
                out = subprocess.Popen(command, stdout=stdout, stderr=stderr)
                to_wait.append((command_str, out, stdout, stderr))
        else:
            # Sanity check: GPU jobs must ask for cuda, CPU jobs must not.
            if "cuda" in command_str:
                assert job.gpu > 0
            else:
                assert job.gpu == 0

            subprocess.run(
                [
                    "ctl",
                    "job",
                    "run",
                    f"--name={name}",
                    "--shared-host-dir-slow-tolerant",
                    f"--container={job.container}",
                    f"--cpu={job.cpu}",
                    f"--gpu={job.gpu}",
                    "--login",
                    "--wandb",
                    "--never-restart",
                    f"--command={command_str}",
                    "--working-dir=/Automatic-Circuit-Discovery",
                    "--shared-host-dir=/home/agarriga/.cache",
                    "--shared-host-dir-mount=/root/.cache",
                    *job.mount_training_options(),
                ],
                check=True,
            )

    for (command, process, out, err) in to_wait:
        retcode = process.wait()
        # Close our writer handles (flushing them to disk) before reading the
        # logs back; the original leaked these file objects.
        out.close()
        err.close()
        with open(out.name, 'r') as f:
            stdout = f.read()
        with open(err.name, 'r') as f:
            stderr = f.read()

        # Report failures, and any NaNs that appeared in the logs.
        if retcode != 0 or "nan" in stdout.lower() or "nan" in stderr.lower():
            s = f""" Command {command} exited with code {retcode}.
stdout:
{stdout}
stderr:
{stderr}
"""
            print(s)
121 |
--------------------------------------------------------------------------------
/experiments/results/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea
3 | *.log
4 | tmp/
5 |
6 |
7 | plots_with_qkv_deleted
8 | plots_with_qkv_deleted_two
9 | plots_old_colors_rgb
10 | .auctex-auto/*
11 | auc_tables.aux
12 | auc_tables.pdf
13 | auc_tables.synctex.gz
14 |
--------------------------------------------------------------------------------
/experiments/results/auc_tables.tex:
--------------------------------------------------------------------------------
1 | \documentclass{article}
2 | \usepackage{booktabs}
3 | \usepackage{multirow}
4 | \begin{document}
5 |
6 | WARNING: you need to add some vertical and horizontal rows for this to look good
7 |
8 | \begin{table}
9 | \caption{Key=auc, weights\_type=trained, Random Ablation}
10 | \begin{tabular}{llllllll}
11 | \toprule
12 | & & \multicolumn{3}{r}{roc\_edges} & \multicolumn{3}{r}{roc\_nodes} \\
13 | & & ACDC & HISP & SP & ACDC & HISP & SP \\
14 | metric\_pretty & task & & & & & & \\
15 | \midrule
16 | \multirow[c]{3}{*}{KL} & docstring & 0.434 & 0.183 & \textbf{0.937} & 0.391 & 0.178 & \textbf{0.928} \\
17 | & greaterthan & \textbf{0.883} & 0.279 & 0.820 & \textbf{0.890} & 0.384 & 0.838 \\
18 | & ioi & 0.868 & 0.239 & \textbf{0.888} & \textbf{0.873} & 0.339 & 0.852 \\
19 | \multirow[c]{5}{*}{Loss} & docstring & \textbf{0.972} & 0.177 & 0.942 & 0.938 & 0.170 & \textbf{0.941} \\
20 | & greaterthan & 0.461 & 0.275 & \textbf{0.848} & 0.766 & 0.374 & \textbf{0.830} \\
21 | & ioi & 0.589 & 0.227 & \textbf{0.837} & 0.777 & 0.283 & \textbf{0.814} \\
22 | & tracr-proportion & 0.600 & \textbf{0.679} & 0.400 & 0.727 & \textbf{0.909} & 0.716 \\
23 | & tracr-reverse & 0.200 & \textbf{0.656} & 0.416 & 0.312 & \textbf{0.750} & 0.533 \\
24 | \bottomrule
25 | \end{tabular}
26 | \end{table}
27 | \begin{table}
28 | \caption{Key=auc, weights\_type=trained, Zero Ablation}
29 | \begin{tabular}{llllllll}
30 | \toprule
31 | & & \multicolumn{3}{r}{roc\_edges} & \multicolumn{3}{r}{roc\_nodes} \\
32 | & & ACDC & HISP & SP & ACDC & HISP & SP \\
33 | metric\_pretty & task & & & & & & \\
34 | \midrule
35 | \multirow[c]{3}{*}{KL} & docstring & \textbf{0.585} & 0.183 & 0.428 & 0.190 & 0.178 & \textbf{0.420} \\
36 | & greaterthan & 0.276 & \textbf{0.279} & 0.163 & \textbf{0.653} & 0.384 & 0.134 \\
37 | & ioi & 0.226 & 0.239 & \textbf{0.702} & 0.511 & 0.339 & \textbf{0.638} \\
38 | \multirow[c]{5}{*}{Loss} & docstring & \textbf{0.816} & 0.177 & 0.482 & \textbf{0.845} & 0.170 & 0.398 \\
39 | & greaterthan & 0.159 & 0.275 & \textbf{0.715} & 0.317 & 0.374 & \textbf{0.597} \\
40 | & ioi & 0.403 & 0.227 & \textbf{0.598} & \textbf{0.541} & 0.283 & 0.507 \\
41 | & tracr-proportion & \textbf{1.000} & 0.679 & 0.561 & \textbf{1.000} & 0.909 & 0.875 \\
42 | & tracr-reverse & \textbf{1.000} & 0.656 & 0.692 & \textbf{1.000} & 0.750 & 0.947 \\
43 | \bottomrule
44 | \end{tabular}
45 | \end{table}
46 | \end{document}
47 |
--------------------------------------------------------------------------------
/experiments/results/canonical_circuits/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea
3 | *.log
4 | tmp/
5 |
6 | *.pdf
7 |
--------------------------------------------------------------------------------
/experiments/results/canonical_circuits/Makefile:
--------------------------------------------------------------------------------
1 | tracr-proportion.gv:
2 | python ../../../notebooks/roc_plot_generator.py --task=tracr-proportion --only-save-canonical --metric=l2
3 |
4 | tracr-reverse.gv:
5 | python ../../../notebooks/roc_plot_generator.py --task=tracr-reverse --only-save-canonical --metric=l2
6 |
7 | ioi.gv ioi_heads.gv ioi_heads_qkv.gv:
8 | python ../../../notebooks/roc_plot_generator.py --task=ioi --only-save-canonical
9 |
10 | greaterthan.gv greaterthan_heads.gv greaterthan_heads_qkv.gv:
11 | python ../../../notebooks/roc_plot_generator.py --task=greaterthan --only-save-canonical
12 |
13 | docstring.gv:
14 | python ../../../notebooks/roc_plot_generator.py --task=docstring --only-save-canonical
15 |
16 | %.pdf: %.gv
17 | neato -Tpdf $< -o $@
18 |
19 | all: tracr-proportion.pdf tracr-reverse.pdf ioi.pdf greaterthan.pdf docstring.pdf ioi_heads.pdf ioi_heads_qkv.pdf greaterthan_heads.pdf greaterthan_heads_qkv.pdf
20 |
21 | ioi: ioi.pdf ioi_heads.pdf ioi_heads_qkv.pdf
22 |
23 | greaterthan: greaterthan.pdf greaterthan_heads.pdf greaterthan_heads_qkv.pdf
24 |
25 | tracr: tracr-proportion.pdf tracr-reverse.pdf
26 |
27 | docstring: docstring.pdf
28 |
--------------------------------------------------------------------------------
/experiments/results/canonical_circuits/greaterthan/Makefile:
--------------------------------------------------------------------------------
# Lay out each per-color subgraph with neato. $@ already names the full
# target (layout.#x.gv); the previous recipe wrote to layout.$@, i.e.
# layout.layout.#x.gv, so the target was never created and make re-ran it.
layout.\#%.gv: \#%.gv
	neato -o$@ $<

# Pack all laid-out subgraphs into a single transparent graph.
layout.gv: $(wildcard layout.*.gv)
	gvpack -Glayout=neato -Goverlap=false -Gsplines=true -Gbgcolor=transparent -o$@ $^
6 |
--------------------------------------------------------------------------------
/experiments/results/canonical_circuits/ioi/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea
3 | *.log
4 | tmp/
5 |
6 |
7 | \#cbd5e8.gv
8 | \#f0f0f0.gv
9 | \#fff2ae.gv
10 | layout.\#e7f2da.gv
11 | layout.\#fad6e9.gv
12 | d7f8ee.gv
13 | \#f9ecd7.gv
14 | \#fff6db.gv
15 | layout.\#ececf5.gv
16 | layout.\#fee7d5.gv
17 | out.bak.xgr
18 | out.gv.bak
19 |
--------------------------------------------------------------------------------
/experiments/results/canonical_circuits/ioi/Makefile:
--------------------------------------------------------------------------------
# Lay out each per-color subgraph with neato. $@ already names the full
# target (layout.#x.gv); the previous recipe wrote to layout.$@, i.e.
# layout.layout.#x.gv, so the target was never created and make re-ran it.
layout.\#%.gv: \#%.gv
	neato -o$@ $<

# Pack all laid-out subgraphs into a single transparent graph.
layout.gv: $(wildcard layout.*.gv)
	gvpack -Glayout=neato -Goverlap=false -Gsplines=true -Gbgcolor=transparent -o$@ $^
6 |
--------------------------------------------------------------------------------
/experiments/results/plots/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea
3 | *.log
4 | tmp/
5 |
6 |
7 | *.pdf
8 | data.csv
9 |
--------------------------------------------------------------------------------
/experiments/results/plots_data/acdc-tracr-proportion-l2-False-1.json:
--------------------------------------------------------------------------------
1 | {
2 | "reset": {
3 | "random_ablation": {
4 | "tracr-proportion": {
5 | "l2": {
6 | "ACDC": {
7 | "score": [
8 | Infinity,
9 | 0.1,
10 | 0.0631,
11 | 0.03981,
12 | 0.02512,
13 | 0.01585,
14 | 0.01,
15 | 0.00631,
16 | 0.00398,
17 | 0.00251,
18 | 0.00158,
19 | 0.001,
20 | 0.00063,
21 | 0.0004,
22 | 0.00025,
23 | 0.00016,
24 | 0.0001,
25 | 6e-05,
26 | 4e-05,
27 | 3e-05,
28 | 2e-05,
29 | 1e-05,
30 | -Infinity
31 | ],
32 | "test_kl_div": [
33 | Infinity,
34 | 0.010016628541052341,
35 | 0.010016628541052341,
36 | 0.010016628541052341,
37 | 0.010016628541052341,
38 | 0.010016628541052341,
39 | 0.010016628541052341,
40 | 0.010016628541052341,
41 | 0.010016628541052341,
42 | 0.010016628541052341,
43 | 0.010016628541052341,
44 | 0.010016628541052341,
45 | 0.010016628541052341,
46 | 0.010016628541052341,
47 | 0.010016628541052341,
48 | 0.010016628541052341,
49 | 0.010016628541052341,
50 | 0.010016628541052341,
51 | 0.010016628541052341,
52 | 0.010016628541052341,
53 | 0.010016628541052341,
54 | 0.010016628541052341,
55 | -Infinity
56 | ],
57 | "steps": [
58 | Infinity,
59 | 74,
60 | 74,
61 | 74,
62 | 74,
63 | 74,
64 | 74,
65 | 74,
66 | 74,
67 | 74,
68 | 74,
69 | 74,
70 | 74,
71 | 74,
72 | 74,
73 | 74,
74 | 74,
75 | 74,
76 | 74,
77 | 74,
78 | 74,
79 | 74,
80 | -Infinity
81 | ],
82 | "test_l2": [
83 | Infinity,
84 | 0.11851852387189865,
85 | 0.11851852387189865,
86 | 0.11851852387189865,
87 | 0.11851852387189865,
88 | 0.11851852387189865,
89 | 0.11851852387189865,
90 | 0.11851852387189865,
91 | 0.11851852387189865,
92 | 0.11851852387189865,
93 | 0.11851852387189865,
94 | 0.11851852387189865,
95 | 0.11851852387189865,
96 | 0.11851852387189865,
97 | 0.11851852387189865,
98 | 0.11851852387189865,
99 | 0.11851852387189865,
100 | 0.11851852387189865,
101 | 0.11851852387189865,
102 | 0.11851852387189865,
103 | 0.11851852387189865,
104 | 0.11851852387189865,
105 | -Infinity
106 | ],
107 | "edge_fpr": [
108 | 0.0,
109 | 0.0,
110 | 0.0,
111 | 0.0,
112 | 0.0,
113 | 0.0,
114 | 0.0,
115 | 0.0,
116 | 0.0,
117 | 0.0,
118 | 0.0,
119 | 0.0,
120 | 0.0,
121 | 0.0,
122 | 0.0,
123 | 0.0,
124 | 0.0,
125 | 0.0,
126 | 0.0,
127 | 0.0,
128 | 0.0,
129 | 0.0,
130 | 1.0
131 | ],
132 | "edge_tpr": [
133 | 0.0,
134 | 0.0,
135 | 0.0,
136 | 0.0,
137 | 0.0,
138 | 0.0,
139 | 0.0,
140 | 0.0,
141 | 0.0,
142 | 0.0,
143 | 0.0,
144 | 0.0,
145 | 0.0,
146 | 0.0,
147 | 0.0,
148 | 0.0,
149 | 0.0,
150 | 0.0,
151 | 0.0,
152 | 0.0,
153 | 0.0,
154 | 0.0,
155 | 1.0
156 | ],
157 | "edge_precision": [
158 | 1.0,
159 | 1,
160 | 1,
161 | 1,
162 | 1,
163 | 1,
164 | 1,
165 | 1,
166 | 1,
167 | 1,
168 | 1,
169 | 1,
170 | 1,
171 | 1,
172 | 1,
173 | 1,
174 | 1,
175 | 1,
176 | 1,
177 | 1,
178 | 1,
179 | 1,
180 | 0.0
181 | ],
182 | "n_edges": [
183 | NaN,
184 | 0,
185 | 0,
186 | 0,
187 | 0,
188 | 0,
189 | 0,
190 | 0,
191 | 0,
192 | 0,
193 | 0,
194 | 0,
195 | 0,
196 | 0,
197 | 0,
198 | 0,
199 | 0,
200 | 0,
201 | 0,
202 | 0,
203 | 0,
204 | 0,
205 | NaN
206 | ],
207 | "node_fpr": [
208 | 0.0,
209 | 0.0,
210 | 0.0,
211 | 0.0,
212 | 0.0,
213 | 0.0,
214 | 0.0,
215 | 0.0,
216 | 0.0,
217 | 0.0,
218 | 0.0,
219 | 0.0,
220 | 0.0,
221 | 0.0,
222 | 0.0,
223 | 0.0,
224 | 0.0,
225 | 0.0,
226 | 0.0,
227 | 0.0,
228 | 0.0,
229 | 0.0,
230 | 1.0
231 | ],
232 | "node_tpr": [
233 | 0.0,
234 | 0.0,
235 | 0.0,
236 | 0.0,
237 | 0.0,
238 | 0.0,
239 | 0.0,
240 | 0.0,
241 | 0.0,
242 | 0.0,
243 | 0.0,
244 | 0.0,
245 | 0.0,
246 | 0.0,
247 | 0.0,
248 | 0.0,
249 | 0.0,
250 | 0.0,
251 | 0.0,
252 | 0.0,
253 | 0.0,
254 | 0.0,
255 | 1.0
256 | ],
257 | "node_precision": [
258 | 1.0,
259 | 1,
260 | 1,
261 | 1,
262 | 1,
263 | 1,
264 | 1,
265 | 1,
266 | 1,
267 | 1,
268 | 1,
269 | 1,
270 | 1,
271 | 1,
272 | 1,
273 | 1,
274 | 1,
275 | 1,
276 | 1,
277 | 1,
278 | 1,
279 | 1,
280 | 0.0
281 | ],
282 | "n_nodes": [
283 | NaN,
284 | 0,
285 | 0,
286 | 0,
287 | 0,
288 | 0,
289 | 0,
290 | 0,
291 | 0,
292 | 0,
293 | 0,
294 | 0,
295 | 0,
296 | 0,
297 | 0,
298 | 0,
299 | 0,
300 | 0,
301 | 0,
302 | 0,
303 | 0,
304 | 0,
305 | NaN
306 | ]
307 | }
308 | }
309 | }
310 | }
311 | }
312 | }
--------------------------------------------------------------------------------
/experiments/results/plots_data/generate_makefile.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from experiments.launcher import KubernetesJob, launch
4 | import shlex
5 |
6 | TASKS = ["ioi", "docstring", "greaterthan", "tracr-reverse", "tracr-proportion", "induction"]
7 |
8 | METRICS_FOR_TASK = {
9 | "ioi": ["kl_div", "logit_diff"],
10 | "tracr-reverse": ["l2"],
11 | "tracr-proportion": ["l2"],
12 | "induction": ["kl_div", "nll"],
13 | "docstring": ["kl_div", "docstring_metric"],
14 | "greaterthan": ["kl_div", "greaterthan"],
15 | }
16 |
17 |
def main():
    """Regenerate the Makefile for this plots_data directory.

    Writes one rule per expected <alg>-<task>-<metric>-<zero>-<reset>.json file
    (each invoking roc_plot_generator.py), plus aggregate targets grouped by
    algorithm, weights, ablation and task. Finally, reports any files in the
    directory that should not exist (asserting none do) and any expected files
    that are missing.
    """
    OUT_DIR = Path(__file__).resolve().parent

    actual_files = set(os.listdir(OUT_DIR))
    # Filename groups used to build the aggregate Make targets below.
    trained_files = []
    reset_files = []
    sixteenh_files = []
    sp_files = []
    acdc_files = []
    canonical_files = []
    random_files = []
    zero_files = []
    ioi_files = []
    docstring_files = []
    greaterthan_files = []
    tracr_reverse_files = []
    tracr_proportion_files = []
    induction_files = []

    with open(OUT_DIR/ "Makefile", "w") as f:
        # Files that are allowed to exist in this directory.
        possible_files = {"generate_makefile.py", "Makefile"}

        for alg in ["16h", "sp", "acdc", "canonical"]:
            for reset_network in [0, 1]:
                for zero_ablation in [0, 1]:
                    for task in TASKS:
                        for metric in METRICS_FOR_TASK[task]:
                            if alg == "canonical" and metric == "kl_div":
                                # No need to repeat the canonical calculations for both train metrics
                                # (they're the same, nothing is trained)
                                continue

                            fname = f"{alg}-{task}-{metric}-{bool(zero_ablation)}-{reset_network}.json"
                            possible_files.add(fname)


                            command = [
                                "python",
                                "../../../notebooks/roc_plot_generator.py",
                                f"--task={task}",
                                f"--reset-network={reset_network}",
                                f"--metric={metric}",
                                f"--alg={alg}",
                            ]
                            if zero_ablation:
                                command.append("--zero-ablation")

                            # Rule: <fname>: \n\t<command>
                            f.write(fname + ":\n" + "\t" + shlex.join(command) + "\n\n")

                            if alg == "16h":
                                sixteenh_files.append(fname)
                            elif alg == "sp":
                                sp_files.append(fname)
                            elif alg == "acdc":
                                acdc_files.append(fname)
                            elif alg == "canonical":
                                canonical_files.append(fname)

                            if reset_network:
                                reset_files.append(fname)
                            else:
                                trained_files.append(fname)

                            if zero_ablation:
                                zero_files.append(fname)
                            else:
                                random_files.append(fname)

                            if task == "ioi":
                                ioi_files.append(fname)
                            elif task == "docstring":
                                docstring_files.append(fname)
                            elif task == "greaterthan":
                                greaterthan_files.append(fname)
                            elif task == "tracr-reverse":
                                tracr_reverse_files.append(fname)
                            elif task == "tracr-proportion":
                                tracr_proportion_files.append(fname)
                            elif task == "induction":
                                induction_files.append(fname)

        # Aggregate targets, one per grouping.
        f.write("all: " + " ".join(sorted(possible_files)) + "\n\n")
        f.write("16h: " + " ".join(sorted(sixteenh_files)) + "\n\n")
        f.write("sp: " + " ".join(sorted(sp_files)) + "\n\n")
        f.write("acdc: " + " ".join(sorted(acdc_files)) + "\n\n")
        f.write("canonical: " + " ".join(sorted(canonical_files)) + "\n\n")
        f.write("trained: " + " ".join(sorted(trained_files)) + "\n\n")
        f.write("reset: " + " ".join(sorted(reset_files)) + "\n\n")
        f.write("zero: " + " ".join(sorted(zero_files)) + "\n\n")
        f.write("random: " + " ".join(sorted(random_files)) + "\n\n")
        f.write("ioi: " + " ".join(sorted(ioi_files)) + "\n\n")
        f.write("docstring: " + " ".join(sorted(docstring_files)) + "\n\n")
        f.write("greaterthan: " + " ".join(sorted(greaterthan_files)) + "\n\n")
        f.write("tracr-reverse: " + " ".join(sorted(tracr_reverse_files)) + "\n\n")
        f.write("tracr-proportion: " + " ".join(sorted(tracr_proportion_files)) + "\n\n")
        f.write("induction: " + " ".join(sorted(induction_files)) + "\n\n")

    print(actual_files - possible_files)
    assert len(actual_files - possible_files) == 0, "There are files that shouldn't be there"

    missing_files = possible_files - actual_files
    print(f"Missing {len(missing_files)} files:")
    for missing_file in missing_files:
        print(missing_file)
122 |
if __name__ == "__main__":
    main()
125 |
--------------------------------------------------------------------------------
/ims/make_jsons.py:
--------------------------------------------------------------------------------
#%%

# Notebook-style script: splices x/y/color data from newer paper figures
# ("current_paper_induction_*.json") into an older combined induction plot,
# then displays the result. Run cell-by-cell (#%% markers).

import plotly

# Set renderers
plotly.io.renderers.default = "jupyterlab" # apparently the secret sauce to make latex work in notebooks (not requireing setting renderer to browser?)
# that failed

plotly.io.renderers.default = "browser"

# note: LATEX works in scripts,


# %%

fname = "induction_json.json"
plotly_graph = plotly.io.read_json(fname)

# %%

plotly_graph.show()

# %%

# Get new data
paper_corrupted_fname = "current_paper_induction_json.json"
paper_corrupted_figure = plotly.io.read_json(paper_corrupted_fname)
paper_zero_figure = plotly.io.read_json("current_paper_induction_zero.json")

# %%

# Trace names whose data we copy across figures.
scatter_names = [
    "ACDC",
    "SP",
    "HISP",
]

x_data = {}
y_data = {}
color_data = {}

figures = {
    "corrupted": paper_corrupted_figure,
    "zero": paper_zero_figure,
}

for figure_name, figure in figures.items():
    for name in scatter_names:
        # NOTE(review): these per-name empty lists are immediately superseded
        # by the (figure_name, name)-keyed entries below — looks vestigial.
        x_data[name] = []
        y_data[name] = []

        # First trace in the figure with this name supplies both x and y.
        x_data_element = [thing for thing in figure["data"] if thing["name"] == name][0]
        y_data_element = [thing for thing in figure["data"] if thing["name"] == name][0]

        x_data[(figure_name, name)] = x_data_element["x"]
        y_data[(figure_name, name)] = y_data_element["y"]

        color_data[(figure_name, name)] = x_data_element["marker"]["color"]

# %%

# Add this into the paper figure

the_ref = {}  # marker styles keyed by trace name, reused for the y2 subplot

for i in range(len(plotly_graph.data)):
    # y axis: "corrupted" (random-ablation) data.
    if plotly_graph.data[i]["yaxis"] == "y" and plotly_graph.data[i]["name"] in scatter_names:
        plotly_graph.data[i]["x"] = x_data[("corrupted", plotly_graph.data[i]["name"])]
        plotly_graph.data[i]["y"] = y_data[("corrupted", plotly_graph.data[i]["name"])]
        if plotly_graph.data[i]["name"] == "ACDC":
            cur_col_data = color_data[("corrupted", plotly_graph.data[i]["name"])]
            plotly_graph.data[i]["marker"]["color"] = cur_col_data
        the_ref[plotly_graph.data[i]["name"]] =plotly_graph.data[i]["marker"]

    # y2 axis: "zero" (zero-ablation) data.
    if plotly_graph.data[i]["yaxis"] == "y2" and plotly_graph.data[i]["name"] in scatter_names:
        plotly_graph.data[i]["x"] = x_data[("zero", plotly_graph.data[i]["name"])]
        plotly_graph.data[i]["y"] = y_data[("zero", plotly_graph.data[i]["name"])]
        # if plotly_graph.data[i]["name"] == "ACDC":
        # cur_col_data = color_data[("zero", plotly_graph.data[i]["name"])]
        # plotly_graph.data[i]["marker"]["color"] = cur_col_data
        # cur_col_data = color_data[("corrupted", plotly_graph.data[i]["name"])]


# Reuse the y-subplot marker styles on the y2 subplot so both match.
for i in range(len(plotly_graph.data)):
    if plotly_graph.data[i]["yaxis"] == "y2" and plotly_graph.data[i]["name"] in scatter_names:
        print(plotly_graph.data[i])
        print("Hey")
        plotly_graph.data[i]["marker"] = the_ref[plotly_graph.data[i]["name"]] # color_data[("corrupted", plotly_graph.data[i]["name"])]

# %%

plotly_graph.show()

# %%

# Collect indices of all traces on the y3 subplot.
lis = []

for i in range(len(plotly_graph.data)):
    if plotly_graph.data[i]["yaxis"] == "y3":
        # print(plotly_graph.data[i])
        # break
        # print(i)
        lis.append(i)

# %%

# Remove this from the paper figure
# (drop all y3 traces except the first, iterating in reverse so the
# remaining indices stay valid)

for l in sorted(lis[1:], reverse=True):
    plotly_graph.data = plotly_graph.data[:l] + plotly_graph.data[l+1:]

# %%

fig = plotly.io.read_json("my_new_plotly_graph.json")

# %%

fig.show()
# %%

# TODO add better legend
--------------------------------------------------------------------------------
/notebooks/_converted/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea
3 | *.log
4 | tmp/
5 |
6 |
7 | *.ipynb
8 |
--------------------------------------------------------------------------------
/notebooks/auc_tables.py:
--------------------------------------------------------------------------------
1 | #%%
2 |
3 | from IPython import get_ipython
4 | ipython = get_ipython()
5 | if ipython is not None:
6 | ipython.run_line_magic('load_ext', 'autoreload')
7 | ipython.run_line_magic('autoreload', '2')
8 | import io
9 | import numpy as np
10 |
11 | import pandas as pd
12 | from pathlib import Path
13 | from tabulate import tabulate
14 | import argparse
15 |
# CLI for generating AUC LaTeX tables from the collected plot data.
# NOTE: the usage string previously advertised a non-existent `--fname`
# flag; the real flags are `--in-fname` and `--out-fname`.
parser = argparse.ArgumentParser(
    usage="Generate AUC tables from CSV files. Pass the data.csv file via --in-fname, e.g python notebooks/auc_tables.py --in-fname=experiments/results/plots/data.csv"
)
parser.add_argument('--in-fname', type=str, default="experiments/results/plots/data.csv")
parser.add_argument('--out-fname', type=str, default="experiments/results/auc_tables.tex")
21 |
22 | if ipython is None:
23 | args = parser.parse_args()
24 | else: # make parsing arguments work in jupyter notebook
25 | args = parser.parse_args(args=[])
26 |
27 | data = pd.read_csv(args.in_fname)
28 |
29 | # %%
# Build one LaTeX table per (key, weights_type, ablation) into an in-memory
# buffer, then write them all out wrapped in a minimal LaTeX document.
with io.StringIO() as buf:
    for key in ["auc"]: # ["test_kl_div", "test_loss", "auc"]:
        for weights_type in ["trained"]: # ["reset", "trained"]
            df = data[(data["weights_type"] == weights_type)]
            # Collapse every non-KL metric into a single "other" bucket so the
            # pivot has one row per (pretty metric, task).
            df = df.replace({"metric": df.metric.map(lambda x: "other" if x != "kl_div" else x)})
            df = df.drop_duplicates(subset=["task", "method", "metric", "ablation_type", "plot_type"])

            def process_metric_pretty(row):
                # Human-readable metric label used in the table index.
                if row["metric"] == "kl_div":
                    return "KL"
                else:
                    return "Loss"

            df["metric_pretty"] = df.apply(process_metric_pretty, axis=1)
            out = df.drop("Unnamed: 0", axis=1).pivot_table(
                index=["metric_pretty", "task"], columns=["ablation_type", "plot_type", "method"], values=key
            )
            # Needed to handle non-AUC keys
            # out = out.applymap(lambda x: None if x == -1 else (texts[x] if isinstance(x, int) else x))
            out = out.dropna(axis=0)

            # %% Export as latex
            def export_table(out, name):
                """Render `out` as a LaTeX table into `buf`, bolding the
                largest value per (plot_type) group within each row."""
                def make_bold_column(row):
                    out = pd.Series(dtype=np.float64)
                    for plot_type in ["roc_edges", "roc_nodes"]:
                        the_max = row.loc[plot_type].max()
                        out = pd.concat(
                            [
                                out,
                                pd.Series(
                                    data=[
                                        f"\\textbf{{{x:.3f}}}" if x == the_max else f"{x:.3f}"
                                        for x in row.loc[plot_type]
                                    ],
                                    index=row.loc[[plot_type]].index,
                                ),
                            ]
                        )
                    return out

                old_out = out
                out = out.apply(make_bold_column, axis=1)
                out.columns = pd.MultiIndex.from_tuples(out.columns)
                out.style.to_latex(buf, hrules=True, environment="table", caption=name)

            export_table(out.random_ablation, f"Key={key}, weights_type={weights_type}, Random Ablation")
            export_table(out.zero_ablation, f"Key={key}, weights_type={weights_type}, Zero Ablation")

    # Wrap the buffered tables in a standalone LaTeX document; underscores
    # must be escaped for LaTeX.
    with open(args.out_fname, "w") as f:
        f.write(
            r"""\documentclass{article}
\usepackage{booktabs}
\usepackage{multirow}
\begin{document}

WARNING: you need to add some vertical and horizontal rows for this to look good

"""
        )
        f.write(buf.getvalue().replace("_", "\\_"))
        f.write(
            r"""\end{document}
"""
        )
95 |
--------------------------------------------------------------------------------
/notebooks/convert_to_ipynb.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # TODO add this all to the Makefile
4 |
5 | set -e
6 |
7 | # Check for --skip-run flag
8 | skip_run=false
9 | for arg in "$@"
10 | do
11 | if [ "$arg" == "--skip-run" ]; then
12 | skip_run=true
13 | fi
14 | done
15 |
16 | required_commands=("jupytext" "papermill" "python")
17 |
18 | # Loop over each required command
19 | for command in "${required_commands[@]}"; do
20 | # Check if the command is installed
21 | if ! command -v $command &> /dev/null
22 | then
23 | echo "$command could not be found"
24 | echo "Please install it using 'pip install $command'"
25 | exit
26 | fi
27 | done
28 |
29 | # Define the file paths
30 | declare -A file_paths
31 | file_paths=(
32 | ["notebooks/editing_edges.py"]="notebooks/_converted/editing_edges.ipynb notebooks/colabs/ACDC_Editing_Edges_Demo.ipynb"
33 | ["acdc/main.py"]="notebooks/_converted/main_demo.ipynb notebooks/colabs/ACDC_Main_Demo.ipynb"
34 | ["notebooks/implementation_demo.py"]="notebooks/_converted/implementation_demo.ipynb notebooks/colabs/ACDC_Implementation_Demo.ipynb"
35 | )
36 |
37 | # Loop over each file path
38 | for in_path in "${!file_paths[@]}"; do
39 | # Split the output paths
40 | IFS=' ' read -r -a out_paths <<< "${file_paths[$in_path]}"
41 |
42 | middle_path=${out_paths[0]}
43 | final_out_path=${out_paths[1]}
44 |
45 | # Run jupytext and papermill
46 | jupytext --to notebook "$in_path" -o "$middle_path"
47 |
48 | if ! $skip_run; then
49 | papermill "$middle_path" "$final_out_path" --kernel=python
50 |
51 | # TODO fix this; it seems some errored files are slipping through
52 | python -c "
53 | import nbformat
54 | nb = nbformat.read('$final_out_path', as_version=4)
55 | errors = [cell for cell in nb['cells'] if 'outputs' in cell and any(output.get('output_type') == 'error' for output in cell['outputs'])]
56 | if errors:
57 | raise Exception(f'Error: The following cells failed in notebook $final_out_path:\n{errors}')
58 | "
59 | else
60 | cp "$middle_path" "$final_out_path"
61 | fi
62 | done
--------------------------------------------------------------------------------
/notebooks/df_plots_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from warnings import warn
3 | import warnings
4 | from IPython import get_ipython
5 | if get_ipython() is not None:
6 | get_ipython().magic('load_ext autoreload')
7 | get_ipython().magic('autoreload 2')
8 |
9 | __file__ = os.path.join(get_ipython().run_line_magic('pwd', ''), "notebooks", "df_plots_data.py")
10 |
11 | from notebooks.emacs_plotly_render import set_plotly_renderer
12 | if "adria" in __file__:
13 | set_plotly_renderer("emacs")
14 |
15 | import plotly
16 | import numpy as np
17 | import json
18 | import wandb
19 | from acdc.acdc_graphics import dict_merge, pessimistic_auc
20 | import time
21 | from plotly.subplots import make_subplots
22 | import plotly.graph_objects as go
23 | import plotly.colors as pc
24 | from pathlib import Path
25 | import plotly.express as px
26 | import pandas as pd
27 | import argparse
28 |
29 | # %%
30 |
31 | DATA_DIR = Path(__file__).resolve().parent.parent / "experiments" / "results" / "plots_data"
32 | all_data = {}
33 |
34 | for fname in os.listdir(DATA_DIR):
35 | if fname.endswith(".json"):
36 | with open(DATA_DIR / fname, "r") as f:
37 | data = json.load(f)
38 | dict_merge(all_data, data)
39 |
40 |
41 |
42 | # %% Possibly convert all this data to pandas dataframe
43 |
rows = []
# Flatten the nested dict (weights_type -> ablation_type -> task -> metric
# -> alg -> parallel lists of values) into one pd.Series per measurement i.
for weights_type, v in all_data.items():
    for ablation_type, v2 in v.items():
        for task, v3 in v2.items():
            for metric, v4 in v3.items():
                for alg, v5 in v4.items():
                    for i in range(len(v5["score"])):
                        rows.append(pd.Series({
                            "weights_type": weights_type,
                            "ablation_type": ablation_type,
                            "task": task,
                            "metric": metric,
                            "alg": alg,
                            # the i-th element of every parallel list in v5
                            **{k: val[i] for k, val in v5.items()}}))

df = pd.DataFrame(rows)
60 |
61 | # %% Print KL
62 | present = df[(df["alg"] == "CANONICAL") & (df["weights_type"] == "trained") & (df["score"] == 1.0)][
63 | ["ablation_type", "task", "metric", "test_kl_div"]
64 | ]
65 | present.sort_values(["task", "ablation_type"])
66 |
67 | # %% Print Docstring stuff
68 |
# Print the canonical-circuit rows (trained weights, task-specific non-KL
# metric, finite score) for each task.
for TASK in ["docstring", "ioi", "greaterthan", "tracr-proportion", "tracr-reverse"]:
    print(TASK)
    present = df[
        (df["alg"] == "CANONICAL")
        & (df["weights_type"] == "trained")
        & (df["metric"] != "kl_div")
        & (df["task"] == TASK)
        & (np.isfinite(df["score"]))
    ][["ablation_type", "score", *filter(lambda x: x.startswith("test_"), df.columns)]]
    present["type"] = present["score"].map(lambda x: {0.0: "corrupted_model", 1.0: "clean_model", 0.5: "canonical"}[x])
    out = present[["ablation_type", "type", "test_docstring_metric", *filter(lambda x: x.startswith("test_"), present.columns)]]

    # Reassign instead of inplace=True: `out` is a fresh column selection,
    # and in-place dropna on it triggers pandas' chained-assignment warning.
    out = out.dropna(axis=1, how="all")
    print(out)
83 |
--------------------------------------------------------------------------------
/notebooks/easier_roc_plot.py:
--------------------------------------------------------------------------------
1 | # %%
2 | #
import plotly
import os
import numpy as np
import json
import wandb
import time
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from pathlib import Path
import plotly.express as px
import pandas as pd
import argparse
import plotly.colors as pc
# Fixed module path: the repo module is acdc/acdc_graphics.py (as used by
# notebooks/df_plots_data.py); `acdc.graphics` does not exist and made this
# script fail on import.
from acdc.acdc_graphics import dict_merge, pessimistic_auc

from notebooks.emacs_plotly_render import set_plotly_renderer

set_plotly_renderer("emacs")
21 |
22 |
23 | # %%
24 |
25 | DATA_DIR = Path("acdc") / "media" / "plots_data"
26 |
27 | all_data = {}
28 |
29 | for fname in os.listdir(DATA_DIR):
30 | if fname.endswith(".json"):
31 | with open(DATA_DIR / fname, "r") as f:
32 | data = json.load(f)
33 | dict_merge(all_data, data)
34 |
35 | # %%
36 |
37 |
def discard_non_pareto_optimal(points, cmp="gt"):
    """Return the sorted Pareto-optimal subset of `points`.

    A point (x, y) is dropped when some *other* point (x1, y1) has
    strictly smaller x and a `cmp`-better y (with the default
    cmp="gt", a strictly larger y).
    """
    def is_dominated(x, y):
        compare = f"__{cmp}__"
        return any(
            x1 < x and getattr(y1, compare)(y) and (x1, y1) != (x, y)
            for x1, y1 in points
        )

    return sorted(pt for pt in points if not is_dominated(*pt))
47 |
48 |
49 | fig = make_subplots(rows=1, cols=2, subplot_titles=["ROC Curves"], column_widths=[0.95, 0.05], horizontal_spacing=0.03)
50 |
51 | colorscales = {
52 | "ACDC": "Blues",
53 | "SP": "Greens",
54 | }
55 |
56 | for i, alg in enumerate(["ACDC", "SP"]):
57 | this_data = all_data["trained"]["random_ablation"]["ioi"]["logit_diff"][alg]
58 | x_data = this_data["edge_fpr"]
59 | y_data = this_data["edge_tpr"]
60 | scores = this_data["score"]
61 |
62 | log_scores = np.log10(scores)
63 | log_scores = np.nan_to_num(log_scores, nan=0.0, neginf=0.0, posinf=0.0)
64 |
65 | min_score = np.min(log_scores)
66 | max_score = np.max(log_scores)
67 |
68 | normalized_scores = (log_scores - min_score) / (max_score - min_score)
69 | normalized_scores[~np.isfinite(normalized_scores)] = 0.0
70 |
71 | points = list(zip(x_data, y_data))
72 | pareto_optimal = discard_non_pareto_optimal(points)
73 |
74 | methodof = "acdc"
75 |
76 | pareto_x_data, pareto_y_data = zip(*pareto_optimal)
77 | fig.add_trace(
78 | go.Scatter(
79 | x=pareto_x_data,
80 | y=pareto_y_data,
81 | name=methodof,
82 | mode="lines",
83 | line=dict(
84 | shape="hv",
85 | color=pc.sample_colorscale(pc.get_colorscale(colorscales[alg]), 0.7)[0],
86 | ),
87 | showlegend=False,
88 | hovertext=log_scores,
89 | ),
90 | row=1,
91 | col=1,
92 | )
93 |
94 | N_TICKS = 5
95 | tickvals = np.linspace(0, 1, N_TICKS)
96 | ticktext = 10 ** np.linspace(min_score, max_score, N_TICKS)
97 |
98 | fig.add_trace(
99 | go.Scatter(
100 | x=x_data,
101 | y=y_data,
102 | name=methodof,
103 | mode="markers",
104 | showlegend=False,
105 | marker=dict(
106 | size=7,
107 | color=normalized_scores,
108 | colorscale=colorscales[alg],
109 | symbol="circle",
110 | # colorbar=dict(
111 | # title="Log scores",
112 | # tickvals=tickvals, # positions for ticks
113 | # ticktext=["%.2e" % i for i in ticktext], # tick labels, formatted as strings
114 | # thickness=5,
115 | # # y=0.25,
116 | # # len=0.5,
117 | # x=1 + i * 0.1,
118 | # ),
119 | showscale=False,
120 | ),
121 | ),
122 | row=1,
123 | col=1,
124 | )
125 | nums = np.arange(200).reshape(2, 100).T.astype(float)
126 | nums[:20, :20] = np.nan
127 | fig.add_trace(
128 | go.Heatmap(
129 | z=nums,
130 | colorscale='Viridis',
131 | showscale=False,
132 | ),
133 | row=1,
134 | col=2,
135 | )
136 |
137 | fig.update_xaxes(showline=False, zeroline=False, showgrid=False, row=1, col=2, showticklabels=False, ticks="")
138 | fig.update_yaxes(showline=False, zeroline=False, showgrid=False, row=1, col=2, side="right")
139 | fig.show()
140 |
--------------------------------------------------------------------------------
/notebooks/editing_edges.py:
--------------------------------------------------------------------------------
1 | # %% [markdown]
2 | #
ACDC Editing Edges Demo
3 | #
4 | # This notebook gives a high-level overview of the main abstractions used in the ACDC codebase.
5 | #
6 | # If you are interested in models that are >=1B parameters, this library currently may be too slow and we would recommend you look at the path patching implementations in `TransformerLens` (for example, see this notebook)
7 | #
8 | # Setup
9 | #
10 | #
Janky code to do different setup when run in a Colab notebook vs VSCode (adapted from e.g this notebook)
11 | #
12 | # You can ignore warnings that "packages were previously imported in this runtime"
13 |
14 | #%%
15 |
16 | try:
17 | import google.colab
18 |
19 | IN_COLAB = True
20 | print("Running as a Colab notebook")
21 |
22 | import subprocess # to install graphviz dependencies
23 | command = ['apt-get', 'install', 'graphviz-dev']
24 | subprocess.run(command, check=True)
25 |
26 | from IPython import get_ipython
27 | ipython = get_ipython()
28 |
29 | ipython.run_line_magic( # install ACDC
30 | "pip",
31 | "install git+https://github.com/ArthurConmy/Automatic-Circuit-Discovery.git@d89f7fa9cbd095202f3940c889cb7c6bf5a9b516",
32 | )
33 |
34 | except Exception as e:
35 | IN_COLAB = False
36 | print("Running outside of Colab notebook")
37 |
38 | import numpy # crucial to not get cursed error
39 | import plotly
40 |
41 | plotly.io.renderers.default = "colab" # added by Arthur so running as a .py notebook with #%% generates .ipynb notebooks that display in colab
42 | # disable this option when developing rather than generating notebook outputs
43 |
44 | from IPython import get_ipython
45 |
46 | ipython = get_ipython()
47 | if ipython is not None:
48 | print("Running as a notebook")
49 | ipython.run_line_magic("load_ext", "autoreload") # type: ignore
50 | ipython.run_line_magic("autoreload", "2") # type: ignore
51 | else:
52 | print("Running as a .py script")
53 |
54 | # %% [markdown]
55 | # Imports etc
56 |
57 | #%%
58 |
59 | from transformer_lens.HookedTransformer import HookedTransformer
60 | from acdc.TLACDCExperiment import TLACDCExperiment
61 | from acdc.induction.utils import get_all_induction_things
62 | from acdc.acdc_utils import TorchIndex
63 | import torch
64 | import gc
65 |
66 | # %% [markdown]
67 | # Load in the model and data for the induction task
68 |
69 | #%%
70 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
71 | num_examples = 40
72 | seq_len = 50
73 |
74 | # load in a tl_model and grab some data
75 | all_induction_things = get_all_induction_things(
76 | num_examples=num_examples,
77 | seq_len=seq_len,
78 | device=DEVICE,
79 | )
80 |
81 | tl_model, toks_int_values, toks_int_values_other, metric, mask_rep = (
82 | all_induction_things.tl_model,
83 | all_induction_things.validation_data,
84 | all_induction_things.validation_patch_data,
85 | all_induction_things.validation_metric,
86 | all_induction_things.validation_mask,
87 | )
88 |
89 | # You should read the get_model function from that file to see what the tl_model is : )
90 |
91 | # %% [markdown]
92 | #
Ensure we stay under mem limit on small machines
93 |
94 | #%%
95 | gc.collect()
96 | torch.cuda.empty_cache()
97 |
98 | # %% [markdown]
99 | # Let's see an example from the dataset.
100 | # `|` separates tokens
101 |
102 | #%%
103 | EXAMPLE_NO = 33
104 | EXAMPLE_LENGTH = 36
105 |
106 | print(
107 | "|".join(tl_model.to_str_tokens(toks_int_values[EXAMPLE_NO, :EXAMPLE_LENGTH])),
108 | )
109 |
110 | #%% [markdown]
111 | # This dataset has several examples of induction! F -> #, mon -> ads
112 | # The `mask_rep` mask is a boolean mask of shape `(num_examples, seq_len)` that indicates where induction is present in the dataset
113 | # Let's see
114 |
115 | #%%
116 | for i in range(EXAMPLE_LENGTH):
117 | if mask_rep[EXAMPLE_NO, i]:
118 | print(f"At position {i} there is induction")
119 | print(tl_model.to_str_tokens(toks_int_values[EXAMPLE_NO:EXAMPLE_NO+1, i : i + 1]))
120 |
121 | # %% [markdown]
122 | #
Let's get the initial loss on the induction examples
123 |
124 | #%%
def get_loss(model, data, mask):
    """Average per-token loss over the positions selected by `mask`.

    The last sequence position of the mask is dropped so it lines up
    with the per-token loss, which has one entry per predicted token.
    """
    per_token = model(
        data,
        return_type="loss",
        loss_per_token=True,
    )
    relevant = mask[:, :-1].int()
    return (per_token * relevant).sum() / relevant.sum()
132 |
133 |
134 | print(f"Loss: {get_loss(tl_model, toks_int_values, mask_rep)}")
135 |
136 | #%% [markdown]
137 | #We will now wrap ACDC things inside an `experiment`for further experiments
138 | # For more advanced usage of the `TLACDCExperiment` object (the main object in this codebase), see the README for links to the `main.py` and its demos
139 |
140 | #%%
141 | experiment = TLACDCExperiment(
142 | model=tl_model,
143 | threshold=0.0,
144 | ds=toks_int_values,
145 | ref_ds=None, # This argument is the corrupted dataset from the ACDC paper. We're going to do zero ablation here so we omit this
146 | metric=metric,
147 | zero_ablation=True,
148 | hook_verbose=False,
149 | )
150 |
151 | # %% [markdown]
152 |
153 | # Usually, the `TLACDCExperiment` efficiently add hooks to the model in order to do ACDC runs fast.
154 | # For this tutorial, we'll add ALL the hooks so you can edit connections in the model as easily as possible.
155 |
156 | #%%
157 | experiment.model.reset_hooks()
158 | experiment.setup_model_hooks(
159 | add_sender_hooks=True,
160 | add_receiver_hooks=True,
161 | doing_acdc_runs=False,
162 | )
163 |
164 | # %% [markdown]
165 | # Let's take a look at the edges
166 |
167 | #%%
168 | for edge_indices, edge in experiment.corr.all_edges().items():
169 | # here's what's inside the edge
170 | receiver_name, receiver_index, sender_name, sender_index = edge_indices
171 |
172 | # for now, all edges should be present
173 | assert edge.present, edge_indices
174 |
175 | # %% [markdown]
176 | # Let's make a function that's able to turn off all the connections from the nodes to the output, except the induction head (1.5 and 1.6)
177 | # (we'll later turn ON all connections EXCEPT the induction heads)
178 |
179 | #%%
def change_direct_output_connections(exp, invert=False):
    """Toggle the direct edges into the end of the residual stream.

    With invert=False only the induction heads (1.5 and 1.6) keep their
    direct connection to the output; with invert=True every direct
    connection EXCEPT the induction heads' is kept. Mutates the edges'
    `present` flags in place and prints each change.
    """
    residual_stream_end_name = "blocks.1.hook_resid_post"
    residual_stream_end_index = TorchIndex([None])
    # The two induction heads: blocks.1 heads 5 and 6.
    induction_heads = [
        ("blocks.1.attn.hook_result", TorchIndex([None, None, 5])),
        ("blocks.1.attn.hook_result", TorchIndex([None, None, 6])),
    ]

    # All edges whose receiver is the final residual-stream hook.
    inputs_to_residual_stream_end = exp.corr.edges[residual_stream_end_name][
        residual_stream_end_index
    ]
    for sender_name in inputs_to_residual_stream_end:
        for sender_index in inputs_to_residual_stream_end[sender_name]:
            edge = inputs_to_residual_stream_end[sender_name][sender_index]
            is_induction_head = (sender_name, sender_index) in induction_heads

            if is_induction_head:
                edge.present = not invert
            else:
                edge.present = invert

            print(
                f"{'Adding' if (invert == is_induction_head) else 'Removing'} edge from {sender_name} {sender_index} to {residual_stream_end_name} {residual_stream_end_index}"
            )
206 |
207 | change_direct_output_connections(experiment)
208 | print(
209 | "Loss with only the induction head direct connections:",
210 | get_loss(experiment.model, toks_int_values, mask_rep).item(),
211 | )
212 |
213 | # %% [markdown]
214 | # Let's turn ON all the connections EXCEPT the induction heads
215 |
216 | #%%
217 | change_direct_output_connections(experiment, invert=True)
218 | print(
219 | "Loss without the induction head direct connections:",
220 | get_loss(experiment.model, toks_int_values, mask_rep).item(),
221 | )
222 |
223 | #%% [markdown]
224 | # That's much larger!
225 | # See acdc/main.py for how to run ACDC experiments; try `python acdc/main.py --help` or check the README for the links to this file
226 |
--------------------------------------------------------------------------------
/notebooks/emacs_plotly_render.py:
--------------------------------------------------------------------------------
1 | """A utils file from Adria"""
2 |
3 | import hashlib
4 | import os
5 |
6 | import plotly.io as pio
7 |
8 |
class EmacsRenderer(pio.base_renderers.ColabRenderer):
    """Plotly renderer for use from Emacs.

    Writes each figure to an HTML file under ``save_dir`` — named by the
    MD5 hash of its contents, so identical figures share one file — and
    returns a short "Click to open" message pointing at the file instead
    of embedding the figure HTML inline.
    """

    # Directory (relative to the CWD) where rendered figures are written.
    save_dir = "ob-jupyter"
    # Local Jupyter file server; presumably where `save_dir` is served
    # from — TODO confirm against the Jupyter setup.
    base_url = f"http://localhost:8888/files"

    def to_mimebundle(self, fig_dict):
        # Render via the Colab renderer, then swap the inline HTML out.
        html = super().to_mimebundle(fig_dict)["text/html"]

        # Content-addressed filename so re-renders of the same figure reuse it.
        mhash = hashlib.md5(html.encode("utf-8")).hexdigest()
        if not os.path.isdir(self.save_dir):
            os.mkdir(self.save_dir)
        fhtml = os.path.join(self.save_dir, mhash + ".html")
        with open(fhtml, "w") as f:
            f.write(html)

        return {"text/html": f'Click to open {fhtml}'}
24 |
25 |
26 | pio.renderers["emacs"] = EmacsRenderer()
27 |
28 |
def set_plotly_renderer(renderer="emacs"):
    """Make `renderer` (e.g. the "emacs" renderer above) plotly's default."""
    pio.renderers.default = renderer
31 |
--------------------------------------------------------------------------------
/notebooks/minimal_acdc_node_roc.py:
--------------------------------------------------------------------------------
1 | #%%
2 |
3 | from IPython import get_ipython
4 | ipython = get_ipython()
5 | if ipython is not None:
6 | ipython.magic("%load_ext autoreload")
7 | ipython.magic("%autoreload 2")
8 | import os
9 | from pathlib import Path
10 | import json
11 | import plotly.graph_objects as go
12 | import plotly.express as px
13 | import matplotlib.pyplot as plt
14 |
15 | #%%
16 |
17 | # Set your root directory here
18 | ROOT_DIR = Path("/home/arthur/Documents/Automatic-Circuit-Discovery")
19 | assert ROOT_DIR.exists(), f"I don't think your ROOT_DIR is correct (ROOT_DIR = {ROOT_DIR})"
20 |
21 | # %%
22 |
23 | TASK = "ioi"
24 | METRIC = "kl_div"
25 | FNAME = f"experiments/results/plots_data/acdc-{TASK}-{METRIC}-False-0.json"
26 | FPATH = ROOT_DIR / FNAME
27 | assert FPATH.exists(), f"I don't think your FNAME is correct (FPATH = {FPATH})"
28 |
29 | # %%
30 |
31 | data = json.load(open(FPATH, "r"))
32 |
33 | # %%
34 |
35 | relevant_data = data["trained"]["random_ablation"]["ioi"]["kl_div"]["ACDC"]
36 |
37 | # %%
38 |
39 | node_tpr = relevant_data["node_tpr"]
40 | node_fpr = relevant_data["node_fpr"]
41 |
42 | # %%
43 |
44 | # We would just plot these, but sometimes points are not on the Pareto frontier
45 |
def pareto_optimal_sublist(xs, ys):
    """Drop every (x, y) pair dominated by some other pair.

    A pair is dominated when another pair has strictly larger x and
    strictly smaller y. The survivors are returned as two parallel
    lists, stably sorted by ascending x.
    """
    pairs = list(zip(xs, ys))
    kept = [
        (x, y)
        for x, y in pairs
        if not any(x1 > x and y1 < y for x1, y1 in pairs)
    ]
    kept.sort(key=lambda pair: pair[0])
    return [x for x, _ in kept], [y for _, y in kept]
57 |
58 | # %%
59 |
60 | pareto_node_tpr, pareto_node_fpr = pareto_optimal_sublist(node_tpr, node_fpr)
61 |
62 | # %%
63 |
64 | # Thanks GPT-4 for this code
65 |
66 | # Create the plot
67 | plt.figure()
68 |
69 | # Plot the ROC curve
70 | plt.step(pareto_node_fpr, pareto_node_tpr, where='post')
71 |
72 | # Add titles and labels
73 | plt.title("ROC Curve of number of Nodes recovered by ACDC")
74 | plt.xlabel("False Positive Rate")
75 | plt.ylabel("True Positive Rate")
76 |
77 | # Show the plot
78 | plt.show()
79 |
80 | # %%
81 |
82 | # Original code from https://plotly.com/python/line-and-scatter/
83 |
84 | # I use plotly but it should be easy to adjust to matplotlib
85 | fig = go.Figure()
86 | fig.add_trace(
87 | go.Scatter(
88 | x=list(pareto_node_fpr),
89 | y=list(pareto_node_tpr),
90 | mode="lines",
91 | line=dict(shape="hv"),
92 | showlegend=False,
93 | ),
94 | )
95 |
96 | fig.update_layout(
97 | title="ROC Curve of number of Nodes recovered by ACDC",
98 | xaxis_title="False Positive Rate",
99 | yaxis_title="True Positive Rate",
100 | )
101 |
102 | fig.show()
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "acdc"
3 | version = "0.0.0" # This should automatically be set by the CD pipeline on release
4 | description = "ACDC: Automatic Circuit DisCovery implementation on top of TransformerLens"
5 | authors = ["Arthur Conmy, Adrià Garriga-Alonso"]
6 | license = "MIT"
7 | readme = "README.md"
8 | packages = [{include = "acdc"}, {include = "subnetwork_probing"}]
9 |
10 | [tool.poetry.dependencies]
11 | python = "^3.8"
12 | einops = "^0.6.0"
13 | numpy = [{ version = "^1.21", python = "<3.10" },
14 | { version = "^1.23", python = ">=3.10" }]
15 | torch = ">=1.10, <2.0"
16 | datasets = "^2.7.1"
17 | transformers = "^4.35.0"
18 | tokenizers = "^0.15.0"
19 | tqdm = "^4.64.1"
20 | pandas = "^1.1.5"
21 | wandb = "^0.13.5"
22 | torchtyping = "^0.1.4"
23 | huggingface-hub = "^0.16.0"
24 | cmapy = "^0.6.6"
25 | networkx = "^3.1"
26 | plotly = "^5.12.0"
27 | kaleido = "0.2.1"
28 | pygraphviz = "^1.11"
29 | tracr = {git = "https://github.com/deepmind/tracr.git", rev = "e75ecda"}
30 | transformer-lens = "1.6.1"
31 |
32 | [tool.poetry.group.dev.dependencies]
33 | pytest = "^7.2.0"
34 | pytest-cov = "^4.0.0"
35 | jupyterlab = "^3.5.0"
36 | jupyter = "^1.0.0"
37 |
38 | [build-system]
39 | requires = ["poetry-core"]
40 | build-backend = "poetry.core.masonry.api"
41 |
42 | [tool.pytest.ini_options]
43 | filterwarnings = [
44 | # Ignore numpy.distutils deprecation warning caused by pandas
45 | # More info: https://numpy.org/doc/stable/reference/distutils.html#module-numpy.distutils
46 | "ignore:distutils Version classes are deprecated:DeprecationWarning"
47 | ]
48 | markers = [
49 | "slow: marks tests as slow (deselect with '-m \"not slow\"')",
50 | ]
51 |
52 | [tool.black]
53 | line-length = 120
54 |
55 | [tool.isort]
56 | profile = "black"
57 | line_length = 120
58 | skip_gitignore = true
59 |
--------------------------------------------------------------------------------
/subnetwork_probing/README.md:
--------------------------------------------------------------------------------
1 | # Setup
2 |
3 | This implementation of Subnetwork Probing should install by default when installing the ACDC code.
4 |
5 | It hosts a fork of `transformer_lens` as a submodule. This should probably be changed in the future.
6 |
7 | The fork introduces the class `MaskedHookPoint`, which masks some of its values with either zero or stored activations.
8 | That's the only crucial difference with mainstream `transformer_lens`. Most of the complexity is in `train.py`.
9 |
10 | # Subnetwork Probing
11 |
12 | [Low-Complexity Probing via Finding Subnetworks](https://github.com/stevenxcao/subnetwork-probing)
13 | Steven Cao, Victor Sanh, Alexander M. Rush
14 | NAACL-HLT 2021
15 |
16 | # HISP
17 |
18 | [Are Sixteen Heads Really Better than One?](https://arxiv.org/abs/1905.10650) Michel et al 2019
19 |
--------------------------------------------------------------------------------
/subnetwork_probing/create_reset_networks.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from pprint import pprint
4 | import gc
5 |
6 | from acdc.greaterthan.utils import get_all_greaterthan_things
7 | from acdc.induction.utils import get_all_induction_things
8 | from acdc.ioi.utils import get_all_ioi_things
9 | from acdc.tracr_task.utils import get_all_tracr_things, get_tracr_model_input_and_tl_model
10 | from acdc.docstring.utils import get_all_docstring_things, AllDataThings
11 |
12 | torch.set_grad_enabled(False)
13 |
14 | DEVICE = "cuda"
15 |
def scramble_sd(seed, things: AllDataThings, name, scramble_heads=True, scramble_mlps=True, scramble_head_outputs: bool=False):
    """Build a "reset" state dict by permuting parts of the model's weights.

    Per-head attention parameters are shuffled across the head dimension
    (and, if `scramble_head_outputs`, also across the last dimension);
    MLP W_in columns / b_in entries are shuffled across neurons. The
    permuted state dict is saved to ``name + ".pt"`` and test metrics for
    both the original and permuted model are dumped to
    ``name + "_test_metrics.json"`` (and pretty-printed).

    Note: the model is left with the permuted weights loaded on return.
    """
    model = things.tl_model
    old_sd = model.state_dict()
    n_heads = model.cfg.n_heads
    n_neurons = model.cfg.d_mlp
    d_head = model.cfg.d_head
    torch.manual_seed(seed)  # make the permutations reproducible per seed

    if scramble_head_outputs and not scramble_heads:
        raise NotImplementedError

    sd = {}
    for k, v in old_sd.items():
        # Attention params whose leading dim is the head dim; keys ending in
        # "_O", "mask" or "IGNORE" are deliberately left intact.
        if scramble_heads and ("attn" in k and not (k.endswith("_O") or k.endswith("mask") or k.endswith("IGNORE"))):
            assert v.shape[0] == n_heads
            to_sd = v[torch.randperm(n_heads), ...].contiguous()
            if scramble_head_outputs:
                # Also permute the last (d_head) dimension.
                to_sd = to_sd[..., torch.randperm(d_head)].contiguous()
            sd[k] = to_sd
        elif scramble_mlps and ("mlp" in k):
            if k.endswith("W_in"):
                assert v.shape[1] == n_neurons
                sd[k] = v[:, torch.randperm(n_neurons)].contiguous()
            elif k.endswith("b_in"):
                assert v.shape[0] == n_neurons
                sd[k] = v[torch.randperm(n_neurons)].contiguous()
            else:
                # print(f"Leaving {k} intact")
                sd[k] = v
        else:
            # print(f"Leaving {k} intact")
            sd[k] = v
    # Sanity: permuting must not add/remove keys or change any shapes.
    assert sd.keys() == old_sd.keys()
    for k in sd.keys():
        assert sd[k].shape == old_sd[k].shape

    torch.save({k: v.cpu() for k, v in sd.items()}, name + ".pt")

    def all_metrics(logits):
        # Evaluate every configured test metric on the given logits.
        m = {}
        for k, v in things.test_metrics.items():
            m[k] = v(logits).item()
        return m

    metrics = {}
    metrics["trained_orig"] = all_metrics(model(things.test_data))
    metrics["trained_patched"] = all_metrics(model(things.test_patch_data))

    # Load the permuted weights and re-evaluate, so the JSON records both
    # the trained and the reset network's metrics.
    model.load_state_dict(sd)
    metrics["reset_orig"] = all_metrics(model(things.test_data))
    metrics["reset_patched"] = all_metrics(model(things.test_patch_data))

    with open(name + "_test_metrics.json", "w") as f:
        json.dump(metrics, f)
    pprint(metrics)
71 |
72 |
# %% IOI

# Each section below builds one task's data/model bundle, writes a scrambled
# ("reset") state dict plus its test metrics to disk via scramble_sd, then
# frees the bundle before the next task. The first positional argument is
# presumably an RNG seed fixed for reproducibility of the saved networks —
# TODO confirm against scramble_sd's signature.
things = get_all_ioi_things(num_examples=100, device=DEVICE, metric_name="kl_div")
scramble_sd(1504304416, things, "ioi_reset_heads_neurons")
del things
gc.collect()

# %% Tracr-reverse
# Tracr tasks are saved twice: once with head outputs scrambled, once without.
things = get_all_tracr_things(task="reverse", metric_name="kl_div", num_examples=6, device=DEVICE)
scramble_sd(1207775456, things, "tracr_reverse_reset_heads_head_outputs_neurons", scramble_head_outputs=True)
gc.collect()

scramble_sd(1666927681, things, "tracr_reverse_reset_heads_neurons", scramble_head_outputs=False)
del things
gc.collect()

# %% Tracr-proportion

things = get_all_tracr_things(task="proportion", metric_name="kl_div", num_examples=50, device=DEVICE)
scramble_sd(2126292961, things, "tracr_proportion_reset_heads_head_outputs_neurons", scramble_head_outputs=True)
gc.collect()

scramble_sd(913070797, things, "tracr_proportion_reset_heads_neurons", scramble_head_outputs=False)
del things
gc.collect()

# %% Induction

things = get_all_induction_things(num_examples=50, seq_len=300, device=DEVICE, metric="kl_div")
scramble_sd(2016630123, things, "induction_reset_heads_neurons")
del things
gc.collect()

# %% Docstring

things = get_all_docstring_things(num_examples=50, seq_len=2, device=DEVICE, metric_name="kl_div", correct_incorrect_wandb=False)
scramble_sd(814220622, things, "docstring_reset_heads_neurons")
del things
gc.collect()


# %% GreaterThan

things = get_all_greaterthan_things(num_examples=100, device=DEVICE, metric_name="kl_div")
scramble_sd(1028419464, things, "greaterthan_reset_heads_neurons")
del things
gc.collect()
120 |
--------------------------------------------------------------------------------
/subnetwork_probing/launch_grid_fill.py:
--------------------------------------------------------------------------------
1 | from experiments.launcher import KubernetesJob, WandbIdentifier, launch
2 | import numpy as np
3 | import random
4 | from typing import List, Optional
5 |
# Loss/metric names swept for each task. Every task uses KL divergence except
# tracr-reverse (L2 only); most tasks additionally sweep a task-specific metric.
METRICS_FOR_TASK = {
    "ioi": ["kl_div", "logit_diff"],
    "tracr-reverse": ["l2"],
    "tracr-proportion": ["kl_div", "l2"],
    "induction": ["kl_div", "nll"],
    "docstring": ["kl_div", "docstring_metric"],
    "greaterthan": ["kl_div", "greaterthan"],
}
14 |
15 |
def main(TASKS: list[str], job: Optional[KubernetesJob], name: str, testing: bool, reset_networks: bool):
    """Build and launch the subnetwork-probing training grid over TASKS.

    One `subnetwork_probing/train.py` command is generated per
    (zero_ablation, task, metric, lambda_reg) combination and handed to
    `launch`.

    Args:
        TASKS: task names; each must be a key of METRICS_FOR_TASK.
        job: Kubernetes job spec, or None to run locally (only allowed when
            `testing` is True).
        name: base name for the launch and the wandb run-name template.
        testing: if True, run one cheap configuration (single lambda, 1 epoch).
        reset_networks: train against reset (scrambled) subject networks,
            which also coarsens the regularisation sweeps.
    """
    # Reset networks get a coarse 5-point sweep; trained networks get 21.
    NUM_SPACINGS = 5 if reset_networks else 21
    # Dense sweep for wide-range metrics: log-spaced 0.01..1, then linear
    # 2..10, then linear 30..250 (the [1:] slices drop the duplicated seams).
    expensive_base_regularization_params = np.concatenate(
        [
            10 ** np.linspace(-2, 0, 11),
            np.linspace(1, 10, 10)[1:],
            np.linspace(10, 250, 13)[1:],
        ]
    )

    if reset_networks:
        base_regularization_params = 10 ** np.linspace(-2, 1.5, NUM_SPACINGS)
    else:
        base_regularization_params = expensive_base_regularization_params

    # {i:05d} is deliberately escaped: it is a template filled in per-run later.
    wandb_identifier = WandbIdentifier(
        run_name=f"{name}-res{int(reset_networks)}-{{i:05d}}",
        group_name="tracr-shuffled-redo",
        project="induction-sp-replicate")


    # Fixed seed so the per-command --seed values drawn below are reproducible.
    seed = 1507014021
    random.seed(seed)

    commands: List[List[str]] = []
    # NOTE: each command draws its --seed from `random` in iteration order, so
    # the loop nesting/order here determines which seed each config receives.
    for reset_network in [int(reset_networks)]:
        for zero_ablation in [0, 1]:
            for task in TASKS:
                for metric in METRICS_FOR_TASK[task]:
                    # Choose a regularisation sweep (and data sizes) matched to
                    # the typical loss magnitude of each task/metric pair.
                    if task.startswith("tracr"):
                        # Typical metric value range: 0.0-0.1
                        regularization_params = 10 ** np.linspace(-3, 0, 11)

                        if task == "tracr-reverse":
                            num_examples = 6
                            seq_len = 5
                        elif task == "tracr-proportion":
                            num_examples = 50
                            seq_len = 5
                        else:
                            raise ValueError("Unknown task")

                    elif task == "greaterthan":
                        if metric == "kl_div":
                            # Typical metric value range: 0.0-20
                            regularization_params = base_regularization_params
                        elif metric == "greaterthan":
                            # Typical metric value range: -1.0 - 0.0
                            regularization_params = 10 ** np.linspace(-4, 2, NUM_SPACINGS)
                        else:
                            raise ValueError("Unknown metric")
                        num_examples = 100
                        seq_len = -1
                    elif task == "docstring":
                        seq_len = 41
                        if metric == "kl_div":
                            # Typical metric value range: 0.0-10.0
                            regularization_params = expensive_base_regularization_params
                        elif metric == "docstring_metric":
                            # Typical metric value range: -1.0 - 0.0
                            regularization_params = 10 ** np.linspace(-4, 2, 21)
                        else:
                            raise ValueError("Unknown metric")
                        num_examples = 50
                    elif task == "ioi":
                        num_examples = 100
                        seq_len = -1
                        if metric == "kl_div":
                            # Typical metric value range: 0.0-12.0
                            regularization_params = base_regularization_params
                        elif metric == "logit_diff":
                            # Typical metric value range: -0.31 -- -0.01
                            regularization_params = 10 ** np.linspace(-4, 2, NUM_SPACINGS)
                        else:
                            raise ValueError("Unknown metric")
                    elif task == "induction":
                        seq_len = 300
                        num_examples = 50
                        if metric == "kl_div":
                            # Typical metric value range: 0.0-16.0
                            regularization_params = expensive_base_regularization_params
                        elif metric == "nll":
                            # Typical metric value range: 0.0-16.0
                            regularization_params = expensive_base_regularization_params
                        else:
                            raise ValueError("Unknown metric")
                    else:
                        raise ValueError("Unknown task")

                    # Local (job is None) runs are CPU-only and testing-only.
                    if job is None:
                        device = "cpu"
                        n_cpu = 4
                        assert testing
                    else:
                        device = "cuda" if job.gpu else "cpu"
                        n_cpu = job.cpu

                    for lambda_reg in [0.01] if testing else regularization_params:
                        command = [
                            "python",
                            "subnetwork_probing/train.py",
                            f"--task={task}",
                            f"--lambda-reg={lambda_reg:.3f}",
                            f"--wandb-name=agarriga-sp-{len(commands):05d}{'-optional' if task in ['induction', 'docstring'] else ''}",
                            "--wandb-project=induction-sp-replicate",
                            "--wandb-entity=remix_school-of-rock",
                            "--wandb-group=tracr-shuffled-redo",
                            f"--device={device}",
                            f"--epochs={1 if testing else 10000}",
                            f"--zero-ablation={zero_ablation}",
                            f"--reset-subject={reset_network}",
                            f"--seed={random.randint(0, 2**32 - 1)}",
                            f"--loss-type={metric}",
                            f"--num-examples={6 if testing else num_examples}",
                            f"--seq-len={seq_len}",
                            f"--n-loss-average-runs={1 if testing else 20}",
                            "--wandb-dir=/training",  # If it doesn't exist wandb will use /tmp
                            f"--wandb-mode={'offline' if testing else 'online'}",
                            f"--torch-num-threads={n_cpu}",
                        ]
                        commands.append(command)

    launch(
        commands,
        name=name,
        job=job,
        synchronous=True,
        check_wandb=wandb_identifier,
        just_print_commands=False,
    )
146 |
if __name__ == "__main__":
    # Launch the grid once with trained subjects and once with reset subjects,
    # currently for the tracr-reverse task only.
    for use_reset_networks in (False, True):
        for task_name in ("tracr-reverse",):
            # tracr models are tiny enough to run on CPU; other tasks get a GPU.
            grid_job = KubernetesJob(
                container="ghcr.io/rhaps0dy/automatic-circuit-discovery:1.7.2",
                cpu=4,
                gpu=0 if task_name.startswith("tracr") else 1,
                mount_training=False,
            )
            main(
                [task_name],
                grid_job,
                name=f"sp-{task_name}",
                testing=False,
                reset_networks=use_reset_networks,
            )
162 |
--------------------------------------------------------------------------------
/subnetwork_probing/transformer_lens/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/fastai/nbdev
3 | rev: 2.2.10
4 | hooks:
5 | - id: nbdev_clean
--------------------------------------------------------------------------------
/subnetwork_probing/transformer_lens/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 neelnanda-io
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/subnetwork_probing/transformer_lens/Old_Demo.ipynb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ArthurConmy/Automatic-Circuit-Discovery/bc99ace817974b5584b7ee203d596a8e2bbcd399/subnetwork_probing/transformer_lens/Old_Demo.ipynb
--------------------------------------------------------------------------------
/subnetwork_probing/transformer_lens/README.md:
--------------------------------------------------------------------------------
1 | # TransformerLens
2 |
3 | A modified version of TransformerLens.
4 |
5 | [](https://pypi.org/project/transformer-lens/)
6 |
7 | (Formerly known as EasyTransformer)
8 |
9 | ## [Start Here](https://neelnanda.io/transformer-lens-demo)
10 |
11 | ## A Library for Mechanistic Interpretability of Generative Language Models
12 |
13 | This is a library for doing [mechanistic interpretability](https://distill.pub/2020/circuits/zoom-in/) of GPT-2 Style language models. The goal of mechanistic interpretability is to take a trained model and reverse engineer the algorithms the model learned during training from its weights. It is a fact about the world today that we have computer programs that can essentially speak English at a human level (GPT-3, PaLM, etc), yet we have no idea how they work nor how to write one ourselves. This offends me greatly, and I would like to solve this!
14 |
15 | TransformerLens lets you load in an open source language model, like GPT-2, and exposes the internal activations of the model to you. You can cache any internal activation in the model, and add in functions to edit, remove or replace these activations as the model runs. The core design principle I've followed is to enable exploratory analysis. One of the most fun parts of mechanistic interpretability compared to normal ML is the extremely short feedback loops! The point of this library is to keep the gap between having an experiment idea and seeing the results as small as possible, to make it easy for **research to feel like play** and to enter a flow state. Part of what I aimed for is to make *my* experience of doing research easier and more fun, hopefully this transfers to you!
16 |
17 | I used to work for the [Anthropic interpretability team](transformer-circuits.pub), and I wrote this library because after I left and tried doing independent research, I got extremely frustrated by the state of open source tooling. There's a lot of excellent infrastructure like HuggingFace and DeepSpeed to *use* or *train* models, but very little to dig into their internals and reverse engineer how they work. **This library tries to solve that**, and to make it easy to get into the field even if you don't work at an industry org with real infrastructure! One of the great things about mechanistic interpretability is that you don't need large models or tons of compute. There are lots of important open problems that can be solved with a small model in a Colab notebook!
18 |
19 | The core features were heavily inspired by the interface to [Anthropic's excellent Garcon tool](https://transformer-circuits.pub/2021/garcon/index.html). Credit to Nelson Elhage and Chris Olah for building Garcon and showing me the value of good infrastructure for enabling exploratory research!
20 |
21 | ## Getting Started
22 |
23 | **Start with the [main demo](https://neelnanda.io/transformer-lens-demo) to learn how the library works, and the basic features**.
24 |
To see what using it for exploratory analysis in practice looks like, check out [my notebook analysing Indirect Object Identification](https://neelnanda.io/exploratory-analysis-demo) or [my recording of myself doing research](https://www.youtube.com/watch?v=yo4QvDn-vsU)!
26 |
Mechanistic interpretability is a very young and small field, and there are a *lot* of open problems - if you would like to help, please try working on one! **Check out my [list of concrete open problems](https://docs.google.com/document/d/1WONBzNqfKIxERejrrPlQMyKqg7jSFW92x5UMXNrMdPo/edit) to figure out where to start.** It begins with advice on skilling up, and key resources to check out.
28 |
If you're new to transformers, check out my [what is a transformer tutorial](https://neelnanda.io/transformer-tutorial) and [tutorial on coding GPT-2 from scratch](https://neelnanda.io/transformer-tutorial-2) (with [an accompanying template](https://neelnanda.io/transformer-template) to write one yourself!)
30 |
31 | ## Gallery
32 |
33 | User contributed examples of the library being used in action:
34 | * [Induction Heads Phase Change Replication](https://colab.research.google.com/github/ckkissane/induction-heads-transformer-lens/blob/main/Induction_Heads_Phase_Change.ipynb): A partial replication of [In-Context Learning and Induction Heads](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html) from Connor Kissane
35 |
36 | ## Advice for Reading the Code
37 |
38 | One significant design decision made was to have a single transformer implementation that could support a range of subtly different GPT-style models. This has the upside of interpretability code just working for arbitrary models when you change the model name in `HookedTransformer.from_pretrained`! But it has the significant downside that the code implementing the model (in `HookedTransformer.py` and `components.py`) can be difficult to read. I recommend starting with my [Clean Transformer Demo](https://neelnanda.io/transformer-solution), which is a clean, minimal implementation of GPT-2 with the same internal architecture and activation names as HookedTransformer, but is significantly clearer and better documented.
39 |
40 | ## Installation
41 |
42 | `pip install git+https://github.com/neelnanda-io/TransformerLens`
43 |
44 | Import the library with `import transformer_lens`
45 |
46 | (Note: This library used to be known as EasyTransformer, and some breaking changes have been made since the rename. If you need to use the old version with some legacy code, run `pip install git+https://github.com/neelnanda-io/TransformerLens@v1`.)
47 |
48 | ## Local Development
49 |
50 | ### DevContainer
51 |
52 | For a one-click setup of your development environment, this project includes a [DevContainer](https://containers.dev/). It can be used locally with [VS Code](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) or with [GitHub Codespaces](https://github.com/features/codespaces).
53 |
54 | ### Manual Setup
55 |
56 | This project uses [Poetry](https://python-poetry.org/docs/#installation) for package management. Install as follows (this will also setup your virtual environment):
57 |
58 | ```bash
59 | poetry config virtualenvs.in-project true
60 | poetry install --with dev
61 | ```
62 |
63 | Optionally, if you want Jupyter Lab you can run `poetry run pip install jupyterlab` (to install in the same virtual environment), and then run with `poetry run jupyter lab`.
64 |
65 | Then the library can be imported as `import transformer_lens`.
66 |
67 | ### Testing
68 |
69 | If adding a feature, please add unit tests for it to the tests folder, and check that it hasn't broken anything major using the existing tests (install pytest and run it in the root TransformerLens/ directory)
70 |
71 | ## Citation
72 |
73 | Please cite this library as:
74 | ```
75 | @misc{nandatransformerlens2022,
76 | title = {TransformerLens},
77 | author = {Nanda, Neel},
78 | url = {https://github.com/neelnanda-io/TransformerLens},
79 | year = {2022}
80 | }
81 | ```
82 | (This is my best guess for how citing software works, feel free to send a correction!)
83 | Also, if you're actually using this for your research, I'd love to chat! Reach out at neelnanda27@gmail.com
84 |
--------------------------------------------------------------------------------
/subnetwork_probing/transformer_lens/easy_transformer/__init__.py:
--------------------------------------------------------------------------------
# Back-compat shim: this package was renamed from easy_transformer to
# transformer_lens. Warn once at import time, then re-export everything
# from the new package so legacy imports keep working.
import logging

logging.warning("DEPRECATED: Library has been renamed, import transformer_lens instead")
from transformer_lens import *
5 |
--------------------------------------------------------------------------------
/subnetwork_probing/transformer_lens/further_comments.md:
--------------------------------------------------------------------------------
1 | # Further Details on Config Options
2 | ## Shortformer Attention (`positional_embeddings_type == "shortformer"`)
3 | Shortformer style models are a variant on GPT-2 style positional embeddings, which do not add positional embeddings into the residual stream but instead add it in to the queries and keys immediately before multiplying by W_Q and W_K, and NOT having it around for the values or MLPs. It's otherwise the same - the positional embeddings are absolute, and are learned. The positional embeddings are NOT added to the residual stream in the standard way, and instead the queries and keys are calculated as W_Q(res_stream + pos_embed) and W_K(res_stream + pos_embed). The values and MLPs are calculated as W_V(res_stream) and W_MLP(res_stream) and so don't have access to positional information. This is otherwise the same as GPT-2 style positional embeddings. This is a variant on the Shortformer model from the paper [Shortformer: The Benefits of Shorter Sequences in Language Modeling](https://arxiv.org/abs/2012.15832). It's morally similar to rotary, which also only gives keys & queries access to positional info
4 |
5 | The original intention was to use this to do more efficient caching: caching is hard with absolute positional embeddings, since you can't translate the context window without recomputing the entire thing, but easier if the prior values and residual stream terms are the same. I've mostly implemented it because it makes it easier for models to form induction heads. I'm not entirely sure why, though hypothesise that it's because there's two ways for induction heads to form with positional embeddings in the residual stream and only one with shortformer style positional embeddings.
6 |
7 | # Weight Processing
8 | ## What is LayerNorm Folding? (`fold_ln`)
9 | [LayerNorm](https://wandb.ai/wandb_fc/LayerNorm/reports/Layer-Normalization-in-Pytorch-With-Examples---VmlldzoxMjk5MTk1) is a common regularisation technique used in transformers. Annoyingly, unlike eg BatchNorm, it can't be turned off at inference time, it's a meaningful change to the mathematical function implemented by the transformer. From an interpretability perspective, this is a headache! And it's easy to shoot yourself in the foot by naively ignoring it - eg, making the mistake of saying neuron_pre = resid_mid @ W_in, rather than LayerNorm(resid_mid) @ W_in. This mistake is an OK approximation, but by folding in the LayerNorm we can do much better!
10 |
11 | TLDR: If we have LayerNorm (weights w_ln and b_ln) followed by a linear layer (W+b), we can reduce the LayerNorm to LayerNormPre (just centering & normalising) and follow it by a linear layer with `W_eff = w[:, None] * W` (element-wise multiplication) and `b_eff = b + b_ln @ W`. This is computationally equivalent, and it never makes sense to think of W and w_ln as separate objects, so HookedTransformer handles it for you when loading pre-trained weights - set fold_ln = False when loading a state dict if you want to turn this off
12 |
13 | Mathematically, LayerNorm is the following:
14 | ```
15 | x1 = x0 - x0.mean()
16 | x2 = x1 / ((x1**2).mean()).sqrt()
17 | x3 = x2 * w
18 | x4 = x3 + b
19 | ```
20 |
21 | Apart from dividing by the norm, these are all pretty straightforwards operations from a linear algebra perspective. And from an interpretability perspective, if anything is linear, it's really easy and you can mostly ignore it (everything breaks up into sums, you can freely change basis, don't need to track interference between terms, etc) - the hard part is engaging with non-linearities!
22 |
23 | A key thing to bear in mind is that EVERY time we read from the residual stream, we apply a LayerNorm - this gives us a lot of leverage to reason about it!
24 |
25 | So let's translate this into linear algebra notation.
26 | `x0` is a vector in `R^n`
27 |
28 | ```
29 | x1 = x0 - x0.mean()
30 | = x0 - (x0.mean()) * ones (broadcasting, ones=torch.ones(n))
31 | = x0 - (x0 @ ones/sqrt(n)) * ones/sqrt(n).
32 | ```
33 |
34 | ones has norm sqrt(n), so ones/sqrt(n) is the unit vector in the diagonal direction. We're just projecting x0 onto this (fixed) vector and subtracting that value off. Alternately, we're projecting onto the n-1 dimensional subspace orthogonal to ones.
35 |
36 | Since LayerNorm is applied EVERY time we read from the stream, the model just never uses the ones direction of the residual stream, so it's essentially just decreasing d_model by one. We can simulate this by just centering all matrices writing to the residual stream.
37 |
38 | Why is removing this dimension useful? I have no idea! I'm not convinced it is...
39 |
40 | ```
41 | x2 = x1 / ((x1**2).mean()).sqrt() (Ignoring eps)
42 | = (x1 / x1.norm()) * sqrt(n)
43 | ```
44 |
45 | This is a projection onto the unit sphere (well, sphere of radius sqrt(n) - the norm of ones). This is fundamentally non-linear, eg doubling the input keeps the output exactly the same.
46 |
47 | This is by far the most irritating part of LayerNorm. I THINK it's mostly useful for numerical stability reasons and not used to do useful computation by the model, but I could easily be wrong! And interpreting a circuit containing LayerNorm sounds like a nightmare...
48 |
In practice, you can mostly get away with ignoring this and treating the scaling factor as a constant, since it does apply across the entire residual stream for each token - this makes it a "global" property of the model's calculation, so for any specific question it hopefully doesn't matter that much. But when you're considering a sufficiently important circuit that it's a good fraction of the norm of the residual stream, it's probably worth thinking about.
50 |
51 | ```
52 | x3 = x2 * w
53 | = x2 @ W_ln
54 | ```
55 |
56 | (`W_ln` is a diagonal matrix with the weights of the LayerNorm - this is equivalent to element-wise multiplication)
57 | This is really easy to deal with - we're about to be input to a linear layer, and can say `(x2 @ W_ln) @ W = x2 @ (W_ln @ W) = x2 @ W_eff` - we can just fold the LayerNorm weights into the linear layer weights.
58 |
59 | `x4 = x3 + b` is similarly easy - `x4 @ W + B = x2 @ W_eff + B_eff`, where `W_eff = W_ln @ W` and `B_eff = B + b @ W`
60 |
61 | This function is calculating `W_eff` and `B_eff` for each layer reading from the residual stream and replacing W and B with those.
62 |
63 | A final optimisation we can make is to **center the reading weights**. x2 has mean 0, which means it's orthogonal to the vector of all ones (`x2 @ ones = x2.sum() = len(x2) * x2.mean()`). This means that the component of `W_eff` that's parallel to `ones` is irrelevant, and we can set that to zero. In code, this means `W_eff -= W_eff.mean(dim=0, keepdim=True)`. This doesn't change the computation but makes things a bit simpler.
64 |
65 | See this for more: https://transformer-circuits.pub/2021/framework/index.html#:~:text=Handling%20Layer%20Normalization
66 |
67 | ## Centering Writing Weights (`center_writing_weight`)
68 |
69 | A related idea to folding layernorm - *every* component reading an input from the residual stream is preceded by a LayerNorm, which means that the mean of a residual stream vector (ie the component in the direction of all ones) never matters. This means we can remove the all ones component of weights and biases whose output *writes* to the residual stream. Mathematically, `W_writing -= W_writing.mean(dim=1, keepdim=True)`
70 |
71 | ## Centering Unembed (`center_unembed`)
72 |
73 | The logits are fed into a softmax. Softmax is translation invariant (eg, adding 1 to every logit doesn't change the output), so we can simplify things by setting the mean of the logits to be zero. This is equivalent to setting the mean of every output vector of `W_U` to zero. In code, `W_U -= W_U.mean(dim=-1, keepdim=True)`
74 |
75 | ## Fold Value Biases (`fold_value_biases`)
76 |
77 | Each attention head has a value bias. Values are averaged to create mixed values (`z`), weighted by the attention pattern, but as the bias is constant, its contribution to `z` is exactly the same. The output of a head is `z @ W_O`, and so the value bias just linearly adds to the output of the head. This means that the value bias of a head has *nothing to do with the head*, and is just a constant added to the attention layer outputs. We can take the sum across these and `b_O` to get an "effective bias" for the layer. In code, we set `b_V=0.` and `b_O = (b_V @ W_O).sum(dim=0) + b_O`
78 |
79 | Technical derivation
80 |
`v = residual @ W_V[h] + broadcast_b_V[h]` for each head `h` (where `b_V` is broadcast up from shape `d_head` to shape `[position, d_head]`). And `z = pattern[h] @ v = pattern[h] @ residual @ W_V[h] + pattern[h] @ broadcast_b_V[h]`. Because `pattern[h]` is `[destination_position, source_position]` and `broadcast_b_V` is *constant* along the `(source_)position` dimension, we're basically just multiplying it by the sum of the pattern across the `source_position` dimension, which is just 1. So it remains exactly the same, and so is just broadcast across the destination positions.
82 |
83 |
--------------------------------------------------------------------------------
/subnetwork_probing/transformer_lens/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "transformer-lens"
3 | version = "0.0.0" # This is automatically set by the CD pipeline on release
4 | description = "An implementation of transformers tailored for mechanistic interpretability."
5 | authors = ["Neel Nanda <77788841+neelnanda-io@users.noreply.github.com>"]
6 | license = "MIT"
7 | readme = "README.md"
8 | packages = [{include = "transformer_lens"}]
9 |
10 | [tool.poetry.dependencies]
11 | python = "^3.7"
12 | einops = "^0.6.0"
13 | numpy = [{ version = "^1.21", python = "<3.10" },
14 | { version = "^1.23", python = ">=3.10" }]
15 | torch = "^1.10"
16 | datasets = "^2.7.1"
17 | transformers = "^4.25.1"
18 | tqdm = "^4.64.1"
19 | pandas = "^1.1.5"
20 | wandb = "^0.13.5"
21 | fancy-einsum = "^0.0.3"
22 | torchtyping = "^0.1.4"
23 | rich = "^12.6.0"
24 |
25 | [tool.poetry.group.dev.dependencies]
26 | pytest = "^7.2.0"
27 | mypy = "^0.991"
28 | jupyter = "^1.0.0"
29 | circuitsvis = "^1.38.1"
30 | plotly = "^5.12.0"
31 |
32 | [tool.poetry.group.jupyter.dependencies]
33 | jupyterlab = "^3.5.0"
34 |
35 | [build-system]
36 | requires = ["poetry-core"]
37 | build-backend = "poetry.core.masonry.api"
38 |
--------------------------------------------------------------------------------
/subnetwork_probing/transformer_lens/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
# Package metadata for the vendored transformer_lens copy.
setup(
    name="transformer_lens",
    version="0.1.0",
    packages=["transformer_lens"],
    license="LICENSE",
    description="An implementation of transformers tailored for mechanistic interpretability.",
    long_description=open("README.md").read(),
    install_requires=[
        "einops",
        "numpy",
        "torch",
        "datasets",  # fix: was listed twice in install_requires
        "transformers",
        "tqdm",
        "pandas",
        "wandb",
        "fancy_einsum",
        "torchtyping",
        "rich",
    ],
    extras_require={"dev": ["pytest", "mypy"]},
)
26 |
--------------------------------------------------------------------------------
/subnetwork_probing/transformer_lens/transformer_lens/FactoredMatrix.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import torch
3 | from typing import Optional, Union, Tuple, List, Dict
4 | from torchtyping import TensorType as TT
5 | from functools import lru_cache
6 | from . import utils as utils
7 | from .torchtyping_helper import (
8 | T as TH,
9 | ) # Need the import as since FactoredMatrix declares T (for transpose)
10 |
11 |
12 | class FactoredMatrix:
13 | """
14 | Class to represent low rank factored matrices, where the matrix is represented as a product of two matrices. Has utilities for efficient calculation of eigenvalues, norm and SVD.
15 | """
16 |
    def __init__(self, A: TT[..., TH.ldim, TH.mdim], B: TT[..., TH.mdim, TH.rdim]):
        """Store the two factors of A @ B, broadcasting their leading (batch) dims.

        A is [..., ldim, mdim] and B is [..., mdim, rdim]; the represented
        matrix is A @ B with shape [..., ldim, rdim].
        """
        self.A = A
        self.B = B
        assert self.A.size(-1) == self.B.size(
            -2
        ), f"Factored matrix must match on inner dimension, shapes were a: {self.A.shape}, b:{self.B.shape}"
        # ldim x rdim is the shape of the product; mdim is the hidden inner dim.
        self.ldim = self.A.size(-2)
        self.rdim = self.B.size(-1)
        self.mdim = self.B.size(-2)
        self.has_leading_dims = (self.A.ndim > 2) or (self.B.ndim > 2)
        # Broadcast the leading dims of A and B against each other so that both
        # factors carry identical batch shapes after construction.
        self.shape = torch.broadcast_shapes(self.A.shape[:-2], self.B.shape[:-2]) + (
            self.ldim,
            self.rdim,
        )
        self.A = self.A.broadcast_to(self.shape[:-2] + (self.ldim, self.mdim))
        self.B = self.B.broadcast_to(self.shape[:-2] + (self.mdim, self.rdim))
33 |
    def __matmul__(
        self, other: Union[TT[..., TH.rdim, TH.new_rdim], TT[TH.rdim], FactoredMatrix]
    ) -> Union[FactoredMatrix, TT[..., TH.ldim]]:
        """Right-multiply self @ other, staying in factored form where possible."""
        if isinstance(other, torch.Tensor):
            if other.ndim < 2:
                # It's a vector, so we collapse the factorisation and just return a vector
                # Squeezing/Unsqueezing is to preserve broadcasting working nicely
                return (self.A @ (self.B @ other.unsqueeze(-1))).squeeze(-1)
            else:
                assert (
                    other.size(-2) == self.rdim
                ), f"Right matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}"
                # Absorb `other` into whichever factor keeps the smaller inner dim.
                if self.rdim > self.mdim:
                    return FactoredMatrix(self.A, self.B @ other)
                else:
                    return FactoredMatrix(self.AB, other)
        elif isinstance(other, FactoredMatrix):
            # Associate left-to-right; each step re-factors via the tensor branch.
            return (self @ other.A) @ other.B
        # NOTE(review): implicitly returns None (not NotImplemented) for
        # unsupported operand types.
52 |
53 | def __rmatmul__(
54 | self, other: Union[TT[..., TH.new_rdim, TH.ldim], TT[TH.ldim], FactoredMatrix]
55 | ) -> Union[FactoredMatrix, TT[..., TH.rdim]]:
56 | if isinstance(other, torch.Tensor):
57 | assert (
58 | other.size(-1) == self.ldim
59 | ), f"Left matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}"
60 | if other.ndim < 2:
61 | # It's a vector, so we collapse the factorisation and just return a vector
62 | return ((other.unsqueeze(-2) @ self.A) @ self.B).squeeze(-1)
63 | elif self.ldim > self.mdim:
64 | return FactoredMatrix(other @ self.A, self.B)
65 | else:
66 | return FactoredMatrix(other, self.AB)
67 | elif isinstance(other, FactoredMatrix):
68 | return other.A @ (other.B @ self)
69 |
70 | @property
71 | def AB(self) -> TT[TH.leading_dims : ..., TH.ldim, TH.rdim]:
72 | """The product matrix - expensive to compute, and can consume a lot of GPU memory"""
73 | return self.A @ self.B
74 |
    @property
    def BA(self) -> TT[TH.leading_dims : ..., TH.rdim, TH.ldim]:
        """The reverse product. Only makes sense when ldim==rdim"""
        # B @ A is only a meaningful "reverse" of A @ B when the outer dims agree
        # (i.e. the represented matrix is square).
        assert (
            self.rdim == self.ldim
        ), f"Can only take ba if ldim==rdim, shapes were self: {self.shape}"
        return self.B @ self.A
82 |
83 | @property
84 | def T(self) -> FactoredMatrix:
85 | return FactoredMatrix(self.B.transpose(-2, -1), self.A.transpose(-2, -1))
86 |
@lru_cache(maxsize=None)
def svd(
    self,
) -> Tuple[
    TT[TH.leading_dims : ..., TH.ldim, TH.mdim],
    TT[TH.leading_dims : ..., TH.mdim],
    TT[TH.leading_dims : ..., TH.rdim, TH.mdim],
]:
    """
    Efficient algorithm for finding Singular Value Decomposition, a tuple (U, S, Vh) for matrix M st S is a vector and U, Vh are orthogonal matrices, and U @ S.diag() @ Vh.T == M

    (Note that Vh is given as the transpose of the obvious thing)

    NOTE(review): lru_cache on an instance method keys on `self` and keeps
    every cached FactoredMatrix alive for the cache's lifetime — confirm this
    is acceptable for long-running jobs. Also, torch.svd is deprecated in
    favour of torch.linalg.svd (which returns Vh rather than V).
    """
    # SVD each thin factor separately — cheap because mdim is the small dimension.
    Ua, Sa, Vha = torch.svd(self.A)
    Ub, Sb, Vhb = torch.svd(self.B)
    # Build the small middle matrix diag(Sa) @ Va^T @ Ub @ diag(Sb) and SVD it;
    # the diag multiplications are done via broadcasting instead of .diag().
    middle = Sa[..., :, None] * utils.transpose(Vha) @ Ub * Sb[..., None, :]
    Um, Sm, Vhm = torch.svd(middle)
    # Rotate the outer orthogonal bases by the middle SVD's bases; the middle
    # singular values are the singular values of the full product.
    U = Ua @ Um
    Vh = Vhb @ Vhm
    S = Sm
    return U, S, Vh
108 |
@property
def U(self) -> TT[TH.leading_dims : ..., TH.ldim, TH.mdim]:
    """Left singular vectors, taken from the cached SVD."""
    left_vectors, _, _ = self.svd()
    return left_vectors
112 |
@property
def S(self) -> TT[TH.leading_dims : ..., TH.mdim]:
    """Singular values, taken from the cached SVD."""
    _, singular_values, _ = self.svd()
    return singular_values
116 |
@property
def Vh(self) -> TT[TH.leading_dims : ..., TH.rdim, TH.mdim]:
    """Right singular vectors (transposed convention), from the cached SVD."""
    _, _, right_vectors = self.svd()
    return right_vectors
120 |
@property
def eigenvalues(self) -> TT[TH.leading_dims : ..., TH.mdim]:
    """Eigenvalues of AB.

    These equal the eigenvalues of BA (up to trailing zeros): if BAv = kv then
    ABAv = A(BAv) = kAv, so Av is an eigenvector of AB with eigenvalue k.
    Computed on BA because it is the smaller (mdim-sized) matrix.
    """
    decomposition = torch.linalg.eig(self.BA)
    return decomposition.eigenvalues
125 |
def __getitem__(self, idx: Union[int, Tuple]) -> FactoredMatrix:
    """Indexing - assumed to only apply to the leading dimensions."""
    if not isinstance(idx, tuple):
        idx = (idx,)
    # Count index entries that consume a dimension; None (newaxis) entries don't.
    length = len([i for i in idx if i is not None])
    if length <= len(self.shape) - 2:
        # Only leading (batch) dims are indexed — apply identically to both factors.
        return FactoredMatrix(self.A[idx], self.B[idx])
    elif length == len(self.shape) - 1:
        # The last entry indexes the row (ldim) dimension, which lives on A only;
        # B is indexed with the leading entries alone.
        return FactoredMatrix(self.A[idx], self.B[idx[:-1]])
    elif length == len(self.shape):
        # The final two entries index rows (→ A) and columns (→ B's rdim);
        # B keeps its full mdim via slice(None).
        return FactoredMatrix(
            self.A[idx[:-1]], self.B[idx[:-2] + (slice(None), idx[-1])]
        )
    else:
        raise ValueError(
            f"{idx} is too long an index for a FactoredMatrix with shape {self.shape}"
        )
143 |
def norm(self) -> TT[TH.leading_dims : ...]:
    """Frobenius norm, computed as sqrt(sum of squared singular values)."""
    squared_singular_values = self.S.pow(2)
    return squared_singular_values.sum(-1).sqrt()
149 |
def __repr__(self):
    """Compact debug summary showing the virtual shape and hidden dimension."""
    shape, hidden = self.shape, self.mdim
    return f"FactoredMatrix: Shape({shape}), Hidden Dim({hidden})"
152 |
def make_even(self) -> FactoredMatrix:
    """
    Returns the factored form of (U @ S.sqrt().diag(), S.sqrt().diag() @ Vh) where U, S, Vh are the SVD of the matrix. This is an equivalent factorisation, but more even - each half has half the singular values, and orthogonal rows/cols
    """
    # Hoist the common sqrt so the SVD-derived S is only transformed once.
    root_s = self.S.sqrt()
    left_half = self.U * root_s[..., None, :]
    right_half = root_s[..., :, None] * utils.transpose(self.Vh)
    return FactoredMatrix(left_half, right_half)
161 |
def get_corner(self, k=3):
    """Preview the top-left k x k corner of AB without building the full product."""
    top_rows = self.A[..., :k, :]
    left_cols = self.B[..., :, :k]
    return utils.get_corner(top_rows @ left_cols, k)
164 |
@property
def ndim(self) -> int:
    """Number of dimensions of the virtual product matrix, per self.shape."""
    full_shape = self.shape
    return len(full_shape)
168 |
def collapse_l(self) -> TT[TH.leading_dims : ..., TH.mdim, TH.rdim]:
    """
    Collapses the left side of the factorization by removing the orthogonal factor (given by self.U). Returns a (..., mdim, rdim) tensor
    """
    singular_scale = self.S[..., :, None]
    return singular_scale * utils.transpose(self.Vh)
174 |
def collapse_r(self) -> TT[TH.leading_dims : ..., TH.ldim, TH.mdim]:
    """
    Analogous to collapse_l, returns a (..., ldim, mdim) tensor
    """
    singular_scale = self.S[..., None, :]
    return self.U * singular_scale
180 |
def unsqueeze(self, k: int) -> FactoredMatrix:
    """Insert a size-1 dimension at index k into both factors."""
    expanded_a = self.A.unsqueeze(k)
    expanded_b = self.B.unsqueeze(k)
    return FactoredMatrix(expanded_a, expanded_b)
183 |
@property
def pair(
    self,
) -> Tuple[
    TT[TH.leading_dims : ..., TH.ldim, TH.mdim],
    TT[TH.leading_dims : ..., TH.mdim, TH.rdim],
]:
    """The raw (A, B) factor pair as a plain tuple."""
    return self.A, self.B
192 |
--------------------------------------------------------------------------------
/subnetwork_probing/transformer_lens/transformer_lens/__init__.py:
--------------------------------------------------------------------------------
1 | from . import hook_points
2 | from . import utils
3 | from . import evals
4 | from .past_key_value_caching import (
5 | HookedTransformerKeyValueCache,
6 | HookedTransformerKeyValueCacheEntry,
7 | )
8 | from . import components
9 | from .HookedTransformerConfig import HookedTransformerConfig
10 | from .FactoredMatrix import FactoredMatrix
11 | from .ActivationCache import ActivationCache
12 | from .HookedTransformer import HookedTransformer
13 | from . import loading_from_pretrained as loading
14 | from . import patching
15 | from . import train
16 |
17 | from .past_key_value_caching import (
18 | HookedTransformerKeyValueCache as EasyTransformerKeyValueCache,
19 | HookedTransformerKeyValueCacheEntry as EasyTransformerKeyValueCacheEntry,
20 | )
21 | from .HookedTransformer import HookedTransformer as EasyTransformer
22 | from .HookedTransformerConfig import HookedTransformerConfig as EasyTransformerConfig
23 |
--------------------------------------------------------------------------------
/subnetwork_probing/transformer_lens/transformer_lens/evals.py:
--------------------------------------------------------------------------------
1 | # %%
2 | """
3 | A file with some rough evals for models - I expect you to be likely better off using the HuggingFace evaluate library if you want to do anything properly, but this is here if you want it and want to eg cheaply and roughly compare models you've trained to baselines.
4 | """
5 |
6 | import torch
7 | import tqdm.auto as tqdm
8 | from datasets import load_dataset
9 | from . import HookedTransformer, HookedTransformerConfig, utils
10 | from torch.utils.data import DataLoader
11 | import einops
12 |
13 | # %%
def sanity_check(model):
    """
    Quick smoke test: feed one fixed paragraph (the opening of Circuits: Zoom In)
    through the model and return the loss. Rule of thumb: loss < 5 means the model
    is probably OK, loss > 7 means something has gone wrong.

    This is a very rough check and says little about real performance.
    """

    prompt = "Many important transition points in the history of science have been moments when science 'zoomed in.' At these points, we develop a visualization or tool that allows us to see the world in a new level of detail, and a new field of science develops to study the world through this lens."

    return model(prompt, return_type="loss")
24 |
25 |
26 | # %%
def make_wiki_data_loader(tokenizer, batch_size=8):
    """
    Build a shuffled DataLoader over Wikitext-2 (a dump of Wikipedia articles).
    Uses the train split because it's larger; note there's likely dataset leakage
    into GPT-2-style training data (though GPT-2 was reportedly trained on
    non-Wikipedia data).
    """
    raw_data = load_dataset("wikitext", "wikitext-2-v1", split="train")
    print(len(raw_data))
    tokenized = utils.tokenize_and_concatenate(raw_data, tokenizer)
    return DataLoader(tokenized, batch_size=batch_size, shuffle=True, drop_last=True)
40 |
41 |
def make_owt_data_loader(tokenizer, batch_size=8):
    """
    Build a shuffled DataLoader over a 10k-document sample of OpenWebText, an
    open-source replication of the GPT-2 training corpus (Reddit links with >3
    karma). Models trained on this corpus get very good performance here.
    """
    raw_data = load_dataset("stas/openwebtext-10k", split="train")
    print(len(raw_data))
    tokenized = utils.tokenize_and_concatenate(raw_data, tokenizer)
    return DataLoader(tokenized, batch_size=batch_size, shuffle=True, drop_last=True)
55 |
56 |
def make_pile_data_loader(tokenizer, batch_size=8):
    """
    Evaluate on a 10,000-document sample of the Pile (NeelNanda/pile-10k).

    (The previous docstring was copy-pasted from the OpenWebText loader; this
    function loads the Pile sample, not OpenWebText.)
    """
    pile_data = load_dataset("NeelNanda/pile-10k", split="train")
    print(len(pile_data))
    dataset = utils.tokenize_and_concatenate(pile_data, tokenizer)
    data_loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=True, drop_last=True
    )
    return data_loader
70 |
71 |
def make_code_data_loader(tokenizer, batch_size=8):
    """
    Build a shuffled DataLoader over the CodeParrot dataset (a dump of Python
    code). All models seem to get significantly lower loss here — even non-code
    models like GPT-2 — presumably because code is easier to predict than
    natural language.
    """
    raw_data = load_dataset("codeparrot/codeparrot-valid-v2-near-dedup", split="train")
    print(len(raw_data))
    # This dataset keeps its text under "content" rather than the default column.
    tokenized = utils.tokenize_and_concatenate(raw_data, tokenizer, column_name="content")
    return DataLoader(tokenized, batch_size=batch_size, shuffle=True, drop_last=True)
85 |
86 |
# Parallel lists: DATASET_NAMES[i] labels DATASET_LOADERS[i]; zipped together in evaluate().
DATASET_NAMES = ["wiki", "owt", "pile", "code"]
DATASET_LOADERS = [
    make_wiki_data_loader,
    make_owt_data_loader,
    make_pile_data_loader,
    make_code_data_loader,
]
94 |
95 | # %%
@torch.inference_mode()
def evaluate_on_dataset(model, data_loader, truncate=100):
    """
    Return the mean per-batch loss of `model` over up to `truncate` + 1 batches
    from `data_loader` (each batch is a dict with a "tokens" tensor).

    Fix: batches are moved to the model's own device instead of the previous
    hard-coded `.cuda()`, which crashed on CPU-only runs (this repo's tests run
    with device="cpu").
    """
    # Assumes model is an nn.Module with at least one parameter (true for
    # HookedTransformer) so we can read its device.
    device = next(model.parameters()).device
    running_loss = 0
    total = 0
    for batch in tqdm.tqdm(data_loader):
        loss = model(batch["tokens"].to(device), return_type="loss").mean()
        running_loss += loss.item()
        total += 1
        if total > truncate:
            break
    return running_loss / total
107 |
108 |
109 | # %%
@torch.inference_mode()
def induction_loss(
    model, tokenizer=None, batch_size=4, subseq_len=384, prepend_bos=True
):
    """
    Generates a batch of random sequences repeated twice, and measures model performance on the second half. Tests whether a model has induction heads.

    By default, prepends a beginning of string token (prepend_bos flag), which is useful to give models a resting position, and sometimes models were trained with this.

    Fix: tokens are created on the model's own device instead of the previous
    hard-coded `.cuda()`, which crashed on CPU-only runs.
    """
    # Assumes model is an nn.Module with parameters so we can read its device.
    device = next(model.parameters()).device
    # Make the repeated sequence
    first_half_tokens = torch.randint(
        100, 20000, (batch_size, subseq_len), device=device
    )
    repeated_tokens = einops.repeat(first_half_tokens, "b p -> b (2 p)")

    # Prepend a Beginning Of String token
    if prepend_bos:
        if tokenizer is None:
            tokenizer = model.tokenizer
        repeated_tokens[:, 0] = tokenizer.bos_token_id
    # Run the model, and extract the per token correct log prob
    logits = model(repeated_tokens, return_type="logits")
    correct_log_probs = utils.lm_cross_entropy_loss(
        logits, repeated_tokens, per_token=True
    )
    # Take the loss over the second half of the sequence
    return correct_log_probs[:, subseq_len + 1 :].mean()
135 |
136 |
137 | # %%
@torch.inference_mode()
def evaluate(model, truncate=100, batch_size=8, tokenizer=None):
    """
    Run evaluate_on_dataset over every registered eval corpus and return a dict
    mapping "<name>_loss" to the mean loss on that corpus (also printed).
    """
    if tokenizer is None:
        tokenizer = model.tokenizer
    results = {}
    for name, loader_fn in zip(DATASET_NAMES, DATASET_LOADERS):
        loader = loader_fn(tokenizer=tokenizer, batch_size=batch_size)
        mean_loss = evaluate_on_dataset(model, loader, truncate=truncate)
        print(f"{name}: {mean_loss}")
        results[f"{name}_loss"] = mean_loss
    return results
149 |
150 |
151 | # %%
152 |
--------------------------------------------------------------------------------
/subnetwork_probing/transformer_lens/transformer_lens/make_docs.py:
--------------------------------------------------------------------------------
1 | # %%
2 | from easy_transformer import loading
3 | from easy_transformer import utils
4 | from functools import lru_cache
5 |
6 | # %%
# Smoke-test the config loader on the smallest model.
cfg = loading.get_pretrained_model_config("solu-1l")
print(cfg)
# %%
"""
Structure:
d_model, d_mlp, d_head, d_vocab, act_fn, n_heads, n_layers, n_ctx, n_params,
Make an architecture table separately probs
tokenizer_name, training_data, has checkpoints
act_fn includes attn_only
architecture
Architecture should list weird shit to be aware of.
"""
import pandas as pd
import numpy as np

# Scratch check that DataFrame.to_markdown accepts a file object (writes test.md).
df = pd.DataFrame(np.random.randn(2, 2))
print(df.to_markdown(open("test.md", "w")))
24 | # %%
@lru_cache(maxsize=None)
def get_config(model_name):
    """Fetch (and memoise) the pretrained model config for `model_name`."""
    return loading.get_pretrained_model_config(model_name)
28 |
29 |
def _format_n_params(n_params):
    """Render a raw parameter count as a short human-readable string (e.g. "85M", "1.5B").

    Uses one decimal place just above each unit boundary, whole numbers otherwise.
    Raises ValueError for counts of 1e12 (1T) or more.
    """
    if n_params < 1e4:
        return f"{n_params/1e3:.1f}K"
    elif n_params < 1e6:
        return f"{round(n_params/1e3)}K"
    elif n_params < 1e7:
        return f"{n_params/1e6:.1f}M"
    elif n_params < 1e9:
        return f"{round(n_params/1e6)}M"
    elif n_params < 1e10:
        return f"{n_params/1e9:.1f}B"
    elif n_params < 1e12:
        return f"{round(n_params/1e9)}B"
    else:
        raise ValueError(f"Passed in {n_params} above 1T?")


def get_property(name, model_name):
    """Look up the display value for one table column (`name`) of one model.

    "act_fn" and "n_params" get special formatting; every other column is read
    straight from the config dict.
    """
    cfg = get_config(model_name)
    if name == "act_fn":
        # Collapse the activation-function variants into the table's labels.
        if cfg.attn_only:
            return "attn_only"
        elif cfg.act_fn in ("gelu_new", "gelu_fast"):
            return "gelu"
        elif cfg.act_fn == "solu_ln":
            return "solu"
        else:
            return cfg.act_fn
    if name == "n_params":
        return _format_n_params(cfg.n_params)
    else:
        return cfg.to_dict()[name]
61 |
62 |
# Columns for the model-properties table, in display order.
column_names = (
    "n_params, n_layers, d_model, n_heads, act_fn, n_ctx, d_vocab, d_head, d_mlp".split(
        ", "
    )
)
print(column_names)
# One row per model alias, one column per property above.
df = pd.DataFrame(
    {
        name: [
            get_property(name, model_name)
            for model_name in loading.DEFAULT_MODEL_ALIASES
        ]
        for name in column_names
    },
    index=loading.DEFAULT_MODEL_ALIASES,
)
# NOTE(review): `display` is only defined inside IPython/Jupyter; running this
# file as a plain script would NameError here — confirm notebook-only use.
display(df)
df.to_markdown(open("model_properties_table.md", "w"))
81 | # %%
82 |
--------------------------------------------------------------------------------
/subnetwork_probing/transformer_lens/transformer_lens/model_properties_table.md:
--------------------------------------------------------------------------------
1 | | | d_model | d_mlp | d_head | d_vocab | act_fn | n_heads | n_layers | n_ctx | n_params |
2 | |:-------------------------------|----------:|--------:|---------:|----------:|:----------|----------:|-----------:|--------:|:-----------|
3 | | gpt2-small | 768 | 3072 | 64 | 50257 | gelu | 12 | 12 | 1024 | 85M |
4 | | gpt2-medium | 1024 | 4096 | 64 | 50257 | gelu | 16 | 24 | 1024 | 302M |
5 | | gpt2-large | 1280 | 5120 | 64 | 50257 | gelu | 20 | 36 | 1024 | 708M |
6 | | gpt2-xl | 1600 | 6400 | 64 | 50257 | gelu | 25 | 48 | 1024 | 1.5B |
7 | | distillgpt2 | 768 | 3072 | 64 | 50257 | gelu | 12 | 6 | 1024 | 42M |
8 | | opt-125m | 768 | 3072 | 64 | 50272 | relu | 12 | 12 | 2048 | 85M |
9 | | opt-1.3b | 2048 | 8192 | 64 | 50272 | relu | 32 | 24 | 2048 | 1.2B |
10 | | opt-2.7b | 2560 | 10240 | 80 | 50272 | relu | 32 | 32 | 2048 | 2.5B |
11 | | opt-6.7b | 4096 | 16384 | 128 | 50272 | relu | 32 | 32 | 2048 | 6.4B |
12 | | opt-13b | 5120 | 20480 | 128 | 50272 | relu | 40 | 40 | 2048 | 13B |
13 | | opt-30b | 7168 | 28672 | 128 | 50272 | relu | 56 | 48 | 2048 | 30B |
14 | | opt-66b | 9216 | 36864 | 128 | 50272 | relu | 72 | 64 | 2048 | 65B |
15 | | gpt-neo-125M | 768 | 3072 | 64 | 50257 | gelu | 12 | 12 | 2048 | 85M |
16 | | gpt-neo-1.3B | 2048 | 8192 | 128 | 50257 | gelu | 16 | 24 | 2048 | 1.2B |
17 | | gpt-neo-2.7B | 2560 | 10240 | 128 | 50257 | gelu | 20 | 32 | 2048 | 2.5B |
18 | | gpt-j-6B | 4096 | 16384 | 256 | 50400 | gelu | 16 | 28 | 2048 | 5.6B |
19 | | gpt-neox-20b | 6144 | 24576 | 96 | 50432 | gelu_fast | 64 | 44 | 2048 | 20B |
20 | | stanford-gpt2-small-a | 768 | 3072 | 64 | 50257 | gelu | 12 | 12 | 1024 | 85M |
21 | | stanford-gpt2-small-b | 768 | 3072 | 64 | 50257 | gelu | 12 | 12 | 1024 | 85M |
22 | | stanford-gpt2-small-c | 768 | 3072 | 64 | 50257 | gelu | 12 | 12 | 1024 | 85M |
23 | | stanford-gpt2-small-d | 768 | 3072 | 64 | 50257 | gelu | 12 | 12 | 1024 | 85M |
24 | | stanford-gpt2-small-e | 768 | 3072 | 64 | 50257 | gelu | 12 | 12 | 1024 | 85M |
25 | | stanford-gpt2-medium-a | 1024 | 4096 | 64 | 50257 | gelu | 16 | 24 | 1024 | 302M |
26 | | stanford-gpt2-medium-b | 1024 | 4096 | 64 | 50257 | gelu | 16 | 24 | 1024 | 302M |
27 | | stanford-gpt2-medium-c | 1024 | 4096 | 64 | 50257 | gelu | 16 | 24 | 1024 | 302M |
28 | | stanford-gpt2-medium-d | 1024 | 4096 | 64 | 50257 | gelu | 16 | 24 | 1024 | 302M |
29 | | stanford-gpt2-medium-e | 1024 | 4096 | 64 | 50257 | gelu | 16 | 24 | 1024 | 302M |
30 | | pythia-70m | 512 | 2048 | 64 | 50304 | gelu | 8 | 6 | 2048 | 19M |
31 | | pythia-160m | 768 | 3072 | 64 | 50304 | gelu | 12 | 12 | 2048 | 85M |
32 | | pythia-410m | 1024 | 4096 | 64 | 50304 | gelu | 16 | 24 | 2048 | 302M |
33 | | pythia-1b | 2048 | 8192 | 256 | 50304 | gelu | 8 | 16 | 2048 | 805M |
34 | | pythia-1.4b | 2048 | 8192 | 128 | 50304 | gelu | 16 | 24 | 2048 | 1.2B |
35 | | pythia-6.9b | 4096 | 16384 | 128 | 50432 | gelu | 32 | 32 | 2048 | 6.4B |
36 | | pythia-12b | 5120 | 20480 | 128 | 50688 | gelu | 40 | 36 | 2048 | 11B |
37 | | pythia-70m-deduped | 512 | 2048 | 64 | 50304 | gelu | 8 | 6 | 2048 | 19M |
38 | | pythia-160m-deduped | 768 | 3072 | 64 | 50304 | gelu | 12 | 12 | 2048 | 85M |
39 | | EleutherAI/pythia-410m-deduped | 1024 | 4096 | 64 | 50304 | gelu | 16 | 24 | 2048 | 302M |
40 | | pythia-1.4b-deduped | 2048 | 8192 | 128 | 50304 | gelu | 16 | 24 | 2048 | 1.2B |
41 | | pythia-6.9b-deduped | 4096 | 16384 | 128 | 50432 | gelu | 32 | 32 | 2048 | 6.4B |
42 | | pythia-12b-deduped | 5120 | 20480 | 128 | 50688 | gelu | 40 | 36 | 2048 | 11B |
43 | | solu-1l-old | 1024 | 4096 | 64 | 50278 | solu | 16 | 1 | 1024 | 13M |
44 | | solu-2l-old | 736 | 2944 | 64 | 50278 | solu | 11 | 2 | 1024 | 13M |
45 | | solu-4l-old | 512 | 2048 | 64 | 50278 | solu | 8 | 4 | 1024 | 13M |
46 | | solu-6l-old | 768 | 3072 | 64 | 50278 | solu | 12 | 6 | 1024 | 42M |
47 | | solu-8l-old | 1024 | 4096 | 64 | 50278 | solu | 16 | 8 | 1024 | 101M |
48 | | solu-10l-old | 1280 | 5120 | 64 | 50278 | solu | 20 | 10 | 1024 | 197M |
49 | | solu-12l-old | 1536 | 6144 | 64 | 50278 | solu | 24 | 12 | 1024 | 340M |
50 | | solu-1l | 512 | 2048 | 64 | 48262 | solu | 8 | 1 | 1024 | 3.1M |
51 | | solu-2l | 512 | 2048 | 64 | 48262 | solu | 8 | 2 | 1024 | 6.3M |
52 | | solu-3l | 512 | 2048 | 64 | 48262 | solu | 8 | 3 | 1024 | 9.4M |
53 | | solu-4l | 512 | 2048 | 64 | 48262 | solu | 8 | 4 | 1024 | 13M |
54 | | solu-6l | 768 | 3072 | 64 | 48262 | solu | 12 | 6 | 1024 | 42M |
55 | | solu-8l | 1024 | 4096 | 64 | 48262 | solu | 16 | 8 | 1024 | 101M |
56 | | solu-10l | 1280 | 5120 | 64 | 48262 | solu | 20 | 10 | 1024 | 197M |
57 | | solu-12l | 1536 | 6144 | 64 | 48262 | solu | 24 | 12 | 1024 | 340M |
58 | | gelu-1l | 512 | 2048 | 64 | 48262 | gelu | 8 | 1 | 1024 | 3.1M |
59 | | gelu-2l | 512 | 2048 | 64 | 48262 | gelu | 8 | 2 | 1024 | 6.3M |
60 | | gelu-3l | 512 | 2048 | 64 | 48262 | gelu | 8 | 3 | 1024 | 9.4M |
61 | | gelu-4l | 512 | 2048 | 64 | 48262 | gelu | 8 | 4 | 1024 | 13M |
62 | | attn-only-1l | 512 | 2048 | 64 | 48262 | attn_only | 8 | 1 | 1024 | 1.0M |
63 | | attn-only-2l | 512 | 2048 | 64 | 48262 | attn_only | 8 | 2 | 1024 | 2.1M |
64 | | attn-only-3l | 512 | 2048 | 64 | 48262 | attn_only | 8 | 3 | 1024 | 3.1M |
65 | | attn-only-4l | 512 | 2048 | 64 | 48262 | attn_only | 8 | 4 | 1024 | 4.2M |
66 | | attn-only-2l-demo | 512 | 2048 | 64 | 50277 | attn_only | 8 | 2 | 1024 | 2.1M |
--------------------------------------------------------------------------------
/subnetwork_probing/transformer_lens/transformer_lens/past_key_value_caching.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from dataclasses import dataclass
4 | from typing import Union, Tuple, List, Dict, Any, Optional
5 | from .HookedTransformerConfig import HookedTransformerConfig
6 | from torchtyping import TensorType as TT
7 | from .torchtyping_helper import T
8 |
9 |
@dataclass
class HookedTransformerKeyValueCacheEntry:
    """Per-layer cache of past attention keys and values for incremental decoding.

    Both tensors grow along the position dimension (dim 1) as tokens are appended.
    """

    past_keys: TT[T.batch, T.pos_so_far, T.n_heads, T.d_head]
    past_values: TT[T.batch, T.pos_so_far, T.n_heads, T.d_head]

    @classmethod
    def init_cache_entry(
        cls,
        cfg: HookedTransformerConfig,
        device: torch.device,
        batch_size: int = 1,
    ):
        """Create an empty entry with zero cached positions."""
        empty_shape = (batch_size, 0, cfg.n_heads, cfg.d_head)
        return cls(
            past_keys=torch.empty(empty_shape, device=device),
            past_values=torch.empty(empty_shape, device=device),
        )

    def append(
        self,
        new_keys: TT[T.batch, T.new_tokens, T.n_heads, T.d_head],
        new_values: TT[T.batch, T.new_tokens, T.n_heads, T.d_head],
    ):
        """Extend the cache along the position dimension; return the full history."""
        self.past_keys = torch.cat([self.past_keys, new_keys], dim=1)
        self.past_values = torch.cat([self.past_values, new_values], dim=1)
        return self.past_keys, self.past_values
45 |
46 |
@dataclass
class HookedTransformerKeyValueCache:
    """
    A cache for storing past keys and values for the Transformer. This is important for generating text - we can cache a lot of past computation and avoid repeating ourselves!

    This cache is a list of HookedTransformerKeyValueCacheEntry objects, one for each layer in the Transformer. Each object stores a [batch, pos_so_far, n_heads, d_head] tensor for both keys and values, and each entry has an append method to add a single new key and value.

    Generation is assumed to be done by initializing with some prompt and then continuing iteratively one token at a time. So append only works for adding a single token's worth of keys and values, and but the cache can be initialized with many.

    """

    entries: List[HookedTransformerKeyValueCacheEntry]

    @classmethod
    def init_cache(
        cls, cfg: HookedTransformerConfig, device: torch.device, batch_size: int = 1
    ):
        """Create one empty per-layer entry for each of cfg.n_layers layers."""
        layer_entries = [
            HookedTransformerKeyValueCacheEntry.init_cache_entry(
                cfg, device, batch_size
            )
            for _ in range(cfg.n_layers)
        ]
        return cls(entries=layer_entries)

    def __getitem__(self, idx):
        # Allow cache[layer] as shorthand for cache.entries[layer].
        return self.entries[idx]
75 |
--------------------------------------------------------------------------------
/subnetwork_probing/transformer_lens/transformer_lens/torchtyping_helper.py:
--------------------------------------------------------------------------------
class T:
    """Helper class to get mypy to work with TorchTyping and solidify naming conventions as a byproduct.

    Examples:
        - `TT[T.batch, T.pos, T.d_model]`
        - `TT[T.num_components, T.batch_and_pos_dims:...]`
    """

    # Sequence / batch / cache dimension names.
    batch: str = "batch"
    pos: str = "pos"
    head_index: str = "head_index"
    length: str = "length"
    rotary_dim: str = "rotary_dim"
    new_tokens: str = "new_tokens"
    batch_and_pos_dims: str = "batch_and_pos_dims"
    layers_accumulated_over: str = "layers_accumulated_over"
    layers_covered: str = "layers_covered"
    past_kv_pos_offset: str = "past_kv_pos_offset"
    num_components: str = "num_components"
    num_neurons: str = "num_neurons"
    pos_so_far: str = "pos_so_far"
    # Model hyperparameter dimensions.
    n_ctx: str = "n_ctx"
    n_heads: str = "n_heads"
    n_layers: str = "n_layers"
    d_vocab: str = "d_vocab"
    d_vocab_out: str = "d_vocab_out"
    d_head: str = "d_head"
    d_mlp: str = "d_mlp"
    d_model: str = "d_model"

    # FactoredMatrix dimension names (left / right / middle of the factorisation).
    ldim: str = "ldim"
    rdim: str = "rdim"
    new_rdim: str = "new_rdim"
    mdim: str = "mdim"
    leading_dims: str = "leading_dims"
    leading_dims_left: str = "leading_dims_left"
    leading_dims_right: str = "leading_dims_right"

    # Generic placeholder dimensions.
    a: str = "a"
    b: str = "b"

    # Composite dimension expressions used where tensors are concatenated.
    pos_plus_past_kv_pos_offset = "pos + past_kv_pos_offset"
    d_vocab_plus_n_ctx = "d_vocab + n_ctx"
    pos_plus_new_tokens = "pos + new_tokens"
45 |
--------------------------------------------------------------------------------
/subnetwork_probing/transformer_lens/transformer_lens/train.py:
--------------------------------------------------------------------------------
1 | from . import HookedTransformer
2 | from . import HookedTransformerConfig
3 | from dataclasses import dataclass
4 | from typing import Optional, Callable
5 | from torch.utils.data import Dataset, DataLoader
6 | import torch.optim as optim
7 | import wandb
8 | import torch
9 | import torch.nn as nn
10 | from tqdm.auto import tqdm
11 | from einops import rearrange
12 |
13 |
@dataclass
class HookedTransformerTrainConfig:
    """
    Configuration class to store training hyperparameters for a training run of
    an HookedTransformer model.
    Args:
        num_epochs (int): Number of epochs to train for
        batch_size (int): Size of batches to use for training
        lr (float): Learning rate to use for training
        seed (int): Random seed to use for training
        momentum (float): Momentum to use for training (SGD optimizer only)
        max_grad_norm (float, *optional*): Maximum gradient norm to clip to
            during training
        weight_decay (float, *optional*): Weight decay to use for training
        optimizer_name (str): The name of the optimizer to use
        device (str, *optional*): Device to use for training
        warmup_steps (int, *optional*): Number of warmup steps to use for training
        save_every (int, *optional*): After how many batches should a checkpoint be saved
        save_dir (str, *optional*): Where to save checkpoints
        wandb (bool): Whether to use Weights and Biases for logging
        wandb_project_name (str, *optional*): Name of the Weights and Biases project to use
        print_every (int, *optional*): Print the loss every n steps
        max_steps (int, *optional*): Terminate the epoch after this many steps. Used for debugging.
    """

    num_epochs: int
    batch_size: int
    lr: float = 1e-3
    seed: int = 0
    momentum: float = 0.0
    max_grad_norm: Optional[float] = None
    weight_decay: Optional[float] = None
    optimizer_name: str = "Adam"
    device: Optional[str] = None
    warmup_steps: int = 0
    save_every: Optional[int] = None
    save_dir: Optional[str] = None
    wandb: bool = False
    wandb_project_name: Optional[str] = None
    print_every: Optional[int] = 50
    max_steps: Optional[int] = None
55 |
56 |
def train(
    model: HookedTransformer,
    config: HookedTransformerTrainConfig,
    dataset: Dataset,
) -> HookedTransformer:
    """
    Trains an HookedTransformer model on an autoregressive language modeling task.
    Args:
        model: The model to train
        config: The training configuration
        dataset: The dataset to train on - this function assumes the dataset is
        set up for autoregressive language modeling.
    Returns:
        The trained model
    """
    torch.manual_seed(config.seed)
    model.train()
    if config.wandb:
        if config.wandb_project_name is None:
            config.wandb_project_name = "easy-transformer"
        wandb.init(project=config.wandb_project_name, config=vars(config))

    # Auto-select device only when the caller did not pin one.
    if config.device is None:
        config.device = "cuda" if torch.cuda.is_available() else "cpu"

    if config.optimizer_name in ["Adam", "AdamW"]:
        # Weight decay in Adam is implemented badly, so use AdamW instead (see PyTorch AdamW docs)
        if config.weight_decay is not None:
            optimizer = optim.AdamW(
                model.parameters(),
                lr=config.lr,
                weight_decay=config.weight_decay,
            )
        else:
            optimizer = optim.Adam(
                model.parameters(),
                lr=config.lr,
            )
    elif config.optimizer_name == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=config.lr,
            weight_decay=config.weight_decay
            if config.weight_decay is not None
            else 0.0,
            momentum=config.momentum,
        )
    else:
        raise ValueError(f"Optimizer {config.optimizer_name} not supported")

    # Linear warmup from 0 to the configured lr over warmup_steps, then constant.
    scheduler = None
    if config.warmup_steps > 0:
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: min(1.0, step / config.warmup_steps),
        )

    dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

    model.to(config.device)

    for epoch in tqdm(range(1, config.num_epochs + 1)):
        samples = 0
        for step, batch in tqdm(enumerate(dataloader)):
            tokens = batch["tokens"].to(config.device)
            loss = model(tokens, return_type="loss")
            loss.backward()
            if config.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            optimizer.step()
            if config.warmup_steps > 0:
                assert scheduler is not None
                scheduler.step()
            # Grads are zeroed after the step rather than before backward();
            # equivalent here since nothing else accumulates into .grad.
            optimizer.zero_grad()

            samples += tokens.shape[0]

            if config.wandb:
                wandb.log(
                    {"train_loss": loss.item(), "samples": samples, "epoch": epoch}
                )

            if config.print_every is not None and step % config.print_every == 0:
                print(f"Epoch {epoch} Samples {samples} Step {step} Loss {loss.item()}")

            if (
                config.save_every is not None
                and step % config.save_every == 0
                and config.save_dir is not None
            ):
                torch.save(model.state_dict(), f"{config.save_dir}/model_{step}.pt")

            # max_steps caps each *epoch* (the break exits the inner loop only),
            # matching the docstring's "terminate the epoch" wording.
            if config.max_steps is not None and step >= config.max_steps:
                break

    return model
153 |
--------------------------------------------------------------------------------
/subnetwork_probing/transformer_lens/typing_demo.py:
--------------------------------------------------------------------------------
# %%
# Scratch notebook exploring how torchtyping + typeguard enforce named dimensions.

import torch as t
from torchtyping import TT as TT, patch_typeguard
from typeguard import typechecked
import einops

patch_typeguard()

ZimZam = TT["batch", "feature", float]


@typechecked
def test(x: ZimZam) -> ZimZam:
    return einops.rearrange(x, "f b -> f b")


x = t.rand((10000, 1), dtype=t.float32)

test(x)

# what if "batch" and "feature" now take on different values?

x = t.rand((20000, 2), dtype=t.float32)

test(x)

# ah so indeed batch and feature must only be consistent across a single function call

# now what if we repeat the same strings across type definitions?

ZimZam2 = TT["batch", "feature", float]


@typechecked
def test2(x: ZimZam2) -> ZimZam:
    return einops.rearrange(x, "f b -> f b")


@typechecked
def test3(x: ZimZam) -> ZimZam2:
    return einops.rearrange(x, "f b -> f b")


test2(x)
test3(x)

# so the right mental model is that the decorators register
# a dictionary whose keys are the dimension names and
# whose values are the sizes. and the values must be consistent
# across a single function call

# now let's watch the type checker fail


@typechecked
def test4(x: ZimZam) -> ZimZam:
    return einops.rearrange(x, "f b -> b f")


# NOTE(review): test4 is defined but never called here, so the expected
# type-check failure is not actually demonstrated — presumably invoked
# interactively in the notebook.

# %%
--------------------------------------------------------------------------------
/tests/acdc/test_acdc.py:
--------------------------------------------------------------------------------
1 | #%%
2 | from copy import deepcopy
3 | from typing import (
4 | List,
5 | Tuple,
6 | Dict,
7 | Any,
8 | Optional,
9 | Union,
10 | Callable,
11 | TypeVar,
12 | Iterable,
13 | Set,
14 | )
15 | import wandb
16 | import IPython
17 | import torch
18 |
19 | from tqdm import tqdm
20 | import random
21 | from functools import *
22 | import json
23 | import pathlib
24 | import warnings
25 | import time
26 | import networkx as nx
27 | import os
28 | import torch
29 | import huggingface_hub
30 | from enum import Enum
31 | import torch.nn as nn
32 | import torch.nn.functional as F
33 | import torch.optim as optim
34 | import numpy as np
35 | import einops
36 | from tqdm import tqdm
37 | import yaml
38 | import gc
39 | from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
40 |
41 | import matplotlib.pyplot as plt
42 | import plotly.express as px
43 | import plotly.io as pio
44 | from plotly.subplots import make_subplots
45 | import plotly.graph_objects as go
46 | from transformer_lens.hook_points import HookedRootModule, HookPoint
47 | from transformer_lens.HookedTransformer import (
48 | HookedTransformer,
49 | )
50 | from acdc.acdc_utils import (
51 | make_nd_dict,
52 | shuffle_tensor,
53 | ct,
54 | )
55 | from acdc.TLACDCEdge import (
56 | TorchIndex,
57 | Edge,
58 | EdgeType,
59 | )
60 | # these introduce several important classes !!!
61 |
62 | from acdc.TLACDCCorrespondence import TLACDCCorrespondence
63 | from acdc.TLACDCInterpNode import TLACDCInterpNode
64 | from acdc.TLACDCExperiment import TLACDCExperiment
65 |
66 | from collections import defaultdict, deque, OrderedDict
67 | from acdc.docstring.utils import get_all_docstring_things
68 | from acdc.greaterthan.utils import get_all_greaterthan_things
69 | from acdc.induction.utils import (
70 | get_all_induction_things,
71 | get_validation_data,
72 | get_good_induction_candidates,
73 | get_mask_repeat_candidates,
74 | )
75 | from acdc.ioi.utils import get_all_ioi_things
76 | from acdc.tracr_task.utils import get_all_tracr_things, get_tracr_model_input_and_tl_model
77 | from acdc.acdc_graphics import (
78 | build_colorscheme,
79 | show,
80 | )
81 | import pytest
82 | from pathlib import Path
83 |
@pytest.mark.slow
@pytest.mark.skip(reason="TODO fix")
def test_induction_several_steps():
    """Run ten ACDC steps on the induction task and compare every edge that
    received an effect size against a frozen reference dictionary."""
    num_examples = 400
    seq_len = 30
    # TODO initialize the `tl_model` with the right model
    all_induction_things = get_all_induction_things(
        num_examples=num_examples, seq_len=seq_len, device="cpu"
    )  # removed some randomize seq_len thing - hopefully unimportant
    tl_model = all_induction_things.tl_model
    toks_int_values = all_induction_things.validation_data
    toks_int_values_other = all_induction_things.validation_patch_data
    metric = all_induction_things.validation_metric

    gc.collect()
    torch.cuda.empty_cache()

    # set up the experiment object
    exp = TLACDCExperiment(
        model=tl_model,
        threshold=0.1,
        using_wandb=False,
        zero_ablation=False,
        ds=toks_int_values,
        ref_ds=toks_int_values_other,
        metric=metric,
        second_metric=None,
        verbose=True,
        indices_mode="reverse",
        names_mode="normal",
        corrupted_cache_cpu=True,
        online_cache_cpu=True,
        add_sender_hooks=True,  # attempting to be efficient...
        add_receiver_hooks=False,
        remove_redundant=True,
    )

    for _ in range(10):
        exp.step()

    edges_to_consider = {
        edge_tuple: edge
        for edge_tuple, edge in exp.corr.all_edges().items()
        if edge.effect_size is not None
    }

    # Frozen reference: (child, child_index, parent, parent_index) -> effect size.
    EDGE_EFFECTS = OrderedDict([
        (('blocks.1.hook_resid_post', TorchIndex([None]), 'blocks.1.attn.hook_result', TorchIndex([None, None, 6])), 0.6195546984672546),
        (('blocks.1.hook_resid_post', TorchIndex([None]), 'blocks.1.attn.hook_result', TorchIndex([None, None, 5])), 0.8417580723762512),
        (('blocks.1.hook_resid_post', TorchIndex([None]), 'blocks.0.attn.hook_result', TorchIndex([None, None, 5])), 0.1795809268951416),
        (('blocks.1.hook_resid_post', TorchIndex([None]), 'blocks.0.attn.hook_result', TorchIndex([None, None, 4])), 0.15076303482055664),
        (('blocks.1.hook_resid_post', TorchIndex([None]), 'blocks.0.attn.hook_result', TorchIndex([None, None, 3])), 0.11805805563926697),
        (('blocks.1.hook_resid_post', TorchIndex([None]), 'blocks.0.hook_resid_pre', TorchIndex([None])), 0.6345541179180145),
        (('blocks.1.attn.hook_q', TorchIndex([None, None, 6]), 'blocks.1.hook_q_input', TorchIndex([None, None, 6])), 1.4423644244670868),
        (('blocks.1.attn.hook_q', TorchIndex([None, None, 5]), 'blocks.1.hook_q_input', TorchIndex([None, None, 5])), 1.2416923940181732),
        (('blocks.1.attn.hook_k', TorchIndex([None, None, 6]), 'blocks.1.hook_k_input', TorchIndex([None, None, 6])), 1.4157390296459198),
        (('blocks.1.attn.hook_k', TorchIndex([None, None, 5]), 'blocks.1.hook_k_input', TorchIndex([None, None, 5])), 1.270191639661789),
        (('blocks.1.attn.hook_v', TorchIndex([None, None, 6]), 'blocks.1.hook_v_input', TorchIndex([None, None, 6])), 2.9806662499904633),
        (('blocks.1.attn.hook_v', TorchIndex([None, None, 5]), 'blocks.1.hook_v_input', TorchIndex([None, None, 5])), 2.7053256928920746),
        (('blocks.1.hook_v_input', TorchIndex([None, None, 6]), 'blocks.0.attn.hook_result', TorchIndex([None, None, 2])), 0.12778228521347046),
        (('blocks.1.hook_v_input', TorchIndex([None, None, 6]), 'blocks.0.hook_resid_pre', TorchIndex([None])), 1.8775241374969482),
    ])

    # Same edge set in both directions.
    extra = set(edges_to_consider.keys()) - set(EDGE_EFFECTS.keys())
    missing = set(EDGE_EFFECTS.keys()) - set(edges_to_consider.keys())
    assert not extra and not missing, (extra, missing, EDGE_EFFECTS.keys())

    for edge_tuple, edge in edges_to_consider.items():
        expected = EDGE_EFFECTS[edge_tuple]
        assert abs(edge.effect_size - expected) < 1e-5, (edge_tuple, edge.effect_size, expected)
143 |
@pytest.mark.slow
@pytest.mark.parametrize("task, metric", [
    ("tracr-proportion", "l2"),
    ("tracr-reverse", "l2"),
    ("docstring", "kl_div"),
    ("induction", "kl_div"),
    ("ioi", "kl_div"),
    ("greaterthan", "kl_div"),
])
def test_main_script(task, metric):
    """Smoke-test acdc/main.py end-to-end for one (task, metric) pair.

    Uses a huge threshold and --single-step so the run finishes quickly;
    a non-zero exit status fails the test via check_call.
    """
    import subprocess
    import sys

    main_path = Path(__file__).resolve().parent.parent.parent / "acdc" / "main.py"
    # Use sys.executable rather than a bare "python" so the subprocess runs
    # under the same interpreter/virtualenv as the test session.
    subprocess.check_call([
        sys.executable,
        str(main_path),
        f"--task={task}",
        "--threshold=1234",
        "--single-step",
        "--device=cpu",
        f"--metric={metric}",
    ])
158 |
def test_editing_edges_notebook():
    """Importing the notebook module executes it top-to-bottom as a smoke test."""
    import importlib

    importlib.import_module("notebooks.editing_edges")
163 |
@pytest.mark.parametrize("task", ["tracr-proportion", "tracr-reverse", "docstring", "induction", "ioi", "greaterthan"])
@pytest.mark.parametrize("zero_ablation", [False, True])
def test_full_correspondence_zero_kl(task, zero_ablation, device="cpu", metric_name="kl_div", num_examples=4, seq_len=10):
    """With every edge marked present, the experiment reproduces the clean model,
    so the KL divergence on the test data must be numerically zero."""
    # Dispatch table: task name -> zero-arg loader for that task's data/model bundle.
    task_loaders = {
        "tracr-proportion": lambda: get_all_tracr_things(task="proportion", num_examples=num_examples, device=device, metric_name="l2"),
        "tracr-reverse": lambda: get_all_tracr_things(task="reverse", num_examples=6, device=device, metric_name="l2"),
        "induction": lambda: get_all_induction_things(num_examples=100, seq_len=20, device=device, metric=metric_name),
        "ioi": lambda: get_all_ioi_things(num_examples=num_examples, device=device, metric_name=metric_name),
        "docstring": lambda: get_all_docstring_things(num_examples=num_examples, seq_len=seq_len, device=device, metric_name=metric_name, correct_incorrect_wandb=False),
        "greaterthan": lambda: get_all_greaterthan_things(num_examples=num_examples, metric_name=metric_name, device=device),
    }
    if task not in task_loaders:
        raise ValueError(task)
    things = task_loaders[task]()

    exp = TLACDCExperiment(
        model=things.tl_model,
        threshold=100_000,
        early_exit=False,
        using_wandb=False,
        zero_ablation=zero_ablation,
        ds=things.test_data,
        ref_ds=things.test_patch_data,
        metric=things.validation_metric,
        second_metric=None,
        verbose=True,
        use_pos_embed=False,  # In the case that this is True, the KL should not be zero.
        online_cache_cpu=True,
        corrupted_cache_cpu=True,
    )
    exp.setup_corrupted_cache()

    # Turn on every edge in a copy of the correspondence.
    full_corr = deepcopy(exp.corr)
    for edge in full_corr.all_edges().values():
        edge.present = True

    with torch.no_grad():
        out = exp.call_metric_with_corr(full_corr, things.test_metrics["kl_div"], things.test_data)
    assert abs(out) < 1e-6, f"{out} should be abs(out) < 1e-6"
206 |
--------------------------------------------------------------------------------
/tests/acdc/test_greaterthan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from acdc.greaterthan.utils import greaterthan_metric_reference, greaterthan_metric, get_year_data
3 | from acdc.ioi.utils import get_gpt2_small
4 |
5 |
def test_greaterthan_metric():
    """The fast greaterthan metric must agree with the reference implementation."""
    gpt2 = get_gpt2_small(device="cpu")
    year_data, _ = get_year_data(20, gpt2)
    model_logits = gpt2(year_data)

    reference = greaterthan_metric_reference(model_logits, year_data)
    actual = greaterthan_metric(model_logits, year_data)
    torch.testing.assert_close(actual, torch.as_tensor(reference))
--------------------------------------------------------------------------------
/tests/subnetwork_probing/test_count_nodes.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import pygraphviz as pgv
4 | from acdc.TLACDCCorrespondence import TLACDCCorrespondence
5 | from acdc.TLACDCInterpNode import TLACDCInterpNode
6 | from acdc.TLACDCEdge import EdgeType, TorchIndex
7 | from acdc.acdc_graphics import show
8 | import tempfile
9 | import os
10 |
11 | from subnetwork_probing.transformer_lens.transformer_lens.HookedTransformer import HookedTransformer
12 | import pytest
13 | from pathlib import Path
14 | import sys
15 |
16 | sys.path.append(str(Path(__file__).parent.parent / "code"))
17 |
18 | from subnetwork_probing.train import iterative_correspondence_from_mask, get_transformer_config
19 | import networkx as nx
20 | from acdc.TLACDCInterpNode import parse_interpnode
21 |
def delete_nested_dict(d: dict, keys: list):
    """Delete the value at the nested key path ``keys`` from ``d``, then prune
    any intermediate dictionaries along the path that became empty.

    A missing key anywhere along the path makes the call a no-op.
    """
    # inner_dicts[i] is the dict at depth i, i.e. the one holding keys[i].
    inner_dicts = [d]
    try:
        for k in keys[:-1]:
            inner_dicts.append(inner_dicts[-1][k])
        del inner_dicts[-1][keys[-1]]
    except KeyError:
        # Path does not exist; nothing to delete.
        return
    assert len(inner_dicts) == len(keys)

    # Walk back up the path, removing each dict left empty by the deletion.
    # (The original version re-deleted keys[-1] from the same parent first,
    # which always raised KeyError and returned, so pruning never happened.)
    for i in range(len(keys) - 2, -1, -1):
        if len(inner_dicts[i + 1]) > 0:
            break
        del inner_dicts[i][keys[i]]
42 |
def test_count_nodes():
    """Mask a fixed set of heads, render the remaining circuit, and check that
    both the rendered graph and the iterative correspondence report 41 edges."""
    nodes_to_mask_str = [
        "blocks.0.attn.hook_q[COL, COL, 0]",
        "blocks.0.attn.hook_k[COL, COL, 0]",
        "blocks.0.attn.hook_q[COL, COL, 1]",
        "blocks.0.attn.hook_k[COL, COL, 1]",
        "blocks.0.attn.hook_v[COL, COL, 1]",
        "blocks.0.attn.hook_q[COL, COL, 2]",
        "blocks.0.attn.hook_k[COL, COL, 2]",
        "blocks.0.attn.hook_v[COL, COL, 2]",
        "blocks.0.attn.hook_q[COL, COL, 3]",
        "blocks.0.attn.hook_k[COL, COL, 3]",
        "blocks.0.attn.hook_v[COL, COL, 3]",
        "blocks.0.attn.hook_q[COL, COL, 4]",
        "blocks.0.attn.hook_k[COL, COL, 4]",
        "blocks.0.attn.hook_v[COL, COL, 4]",
        "blocks.0.attn.hook_q[COL, COL, 5]",
        "blocks.0.attn.hook_k[COL, COL, 5]",
        "blocks.0.attn.hook_q[COL, COL, 6]",
        "blocks.0.attn.hook_k[COL, COL, 6]",
        "blocks.0.attn.hook_q[COL, COL, 7]",
        "blocks.0.attn.hook_k[COL, COL, 7]",
        "blocks.1.attn.hook_q[COL, COL, 0]",
        "blocks.1.attn.hook_k[COL, COL, 0]",
        "blocks.1.attn.hook_v[COL, COL, 0]",
        "blocks.1.attn.hook_q[COL, COL, 1]",
        "blocks.1.attn.hook_k[COL, COL, 1]",
        "blocks.1.attn.hook_v[COL, COL, 1]",
        "blocks.1.attn.hook_q[COL, COL, 2]",
        "blocks.1.attn.hook_k[COL, COL, 2]",
        "blocks.1.attn.hook_v[COL, COL, 2]",
        "blocks.1.attn.hook_q[COL, COL, 3]",
        "blocks.1.attn.hook_k[COL, COL, 3]",
        "blocks.1.attn.hook_q[COL, COL, 4]",
        "blocks.1.attn.hook_k[COL, COL, 4]",
        "blocks.1.attn.hook_q[COL, COL, 5]",
        "blocks.1.attn.hook_k[COL, COL, 5]",
        "blocks.1.attn.hook_q[COL, COL, 6]",
        "blocks.1.attn.hook_k[COL, COL, 6]",
        "blocks.1.attn.hook_v[COL, COL, 6]",
        "blocks.1.attn.hook_q[COL, COL, 7]",
        "blocks.1.attn.hook_k[COL, COL, 7]",
    ]
    nodes_to_mask = [parse_interpnode(s) for s in nodes_to_mask_str]
    # Also mask the corresponding *_input nodes (e.g. blocks.0.hook_q_input).
    input_nodes = []
    for node in nodes_to_mask:
        input_name = node.name.replace(".attn", "") + "_input"
        input_nodes.append(TLACDCInterpNode(input_name, node.index, EdgeType.ADDITION))
    nodes_to_mask += input_nodes

    cfg = get_transformer_config()
    model = HookedTransformer(cfg, is_masked=True)

    # Give an effect size to every edge whose endpoints are both unmasked.
    corr = TLACDCCorrespondence.setup_from_model(model)
    for child_hook_name, by_child_index in corr.edges.items():
        for child_index, by_parent_name in by_child_index.items():
            child_masked = any(
                child_hook_name == n.name and child_index == n.index
                for n in nodes_to_mask
            )
            for parent_hook_name, by_parent_index in by_parent_name.items():
                for parent_index, edge in by_parent_index.items():
                    parent_masked = any(
                        parent_hook_name == n.name and parent_index == n.index
                        for n in nodes_to_mask
                    )
                    if not child_masked and not parent_masked:
                        edge.effect_size = 1

    with tempfile.TemporaryDirectory() as tmpdir:
        g = show(corr, os.path.join(tmpdir, "out.png"), show_full_index=False)
        assert isinstance(g, pgv.AGraph)
        path = Path(tmpdir) / "out.gv"
        assert path.exists()
        # In advance I predict that it should be 41
        g2 = nx.nx_agraph.read_dot(path)

        # Drop nodes not on any embed -> output path.
        unreachable = [
            n
            for n in g2.nodes
            if not nx.has_path(g2, "embed", n) or not nx.has_path(g2, n, "")
        ]
        for n in unreachable:
            g2.remove_node(n)

        # Delete self-loops
        self_loops = [n for n in g2.nodes if g2.has_edge(n, n)]
        for n in self_loops:
            g2.remove_edge(n, n)
        assert len(g2.edges) == 41

    corr, _ = iterative_correspondence_from_mask(model, nodes_to_mask)
    assert corr.count_no_edges() == 41
--------------------------------------------------------------------------------
/tests/subnetwork_probing/test_sp_launch.py:
--------------------------------------------------------------------------------
1 | import subnetwork_probing.launch_grid_fill
2 | import pytest
3 |
@pytest.mark.slow
@pytest.mark.parametrize("reset_networks", [True, False])
def test_sp_grid(reset_networks):
    """Smoke-test the subnetwork-probing grid launcher across every task."""
    all_tasks = [
        "tracr-reverse",
        "tracr-proportion",
        "docstring",
        "induction",
        "greaterthan",
        "ioi",
    ]
    subnetwork_probing.launch_grid_fill.main(
        TASKS=all_tasks,
        job=None,
        name="sp-test",
        testing=True,
        reset_networks=reset_networks,
    )
--------------------------------------------------------------------------------